Constant folding for dynamic shape node (#129686)

Extend constant folding to dynamic shape nodes; only pointwise ops and a restricted set of other ops are supported.

We support dynamic shapes by limiting constant folding to ops that are guaranteed to produce uniform values (full, pointwise ops, and views) and by running these operators on tensors of shape 1. This also eliminates the memory overhead of constant folding.
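For illustration, here is a minimal sketch (modeled on the test_constant_fold_uniform_value_dynamic test added below) of a graph whose full/pointwise chain can now be folded away even when compiled with dynamic=True:

    import torch

    def fn(x):
        # `a` is uniform (all ones) and `b` is uniform (all zeros), so the
        # folder can replace the chain with a constant and `x + b` reduces
        # to `x`; no fused pointwise kernel should be emitted for it.
        a = torch.full(x.shape, 1, dtype=x.dtype, device=x.device)
        b = a - 1
        return x + b

    compiled = torch.compile(fn, dynamic=True)
    x = torch.randn(2, 4)
    assert torch.equal(fn(x), compiled(x))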

Taken over from https://github.com/pytorch/pytorch/pull/128937

joint work with @imzhuhl

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129686
Approved by: https://github.com/Chillee
ghstack dependencies: #130367
Author: eellison, 2024-07-15 11:21:55 -07:00
Committed by: PyTorch MergeBot
parent ea4f310ff1
commit 9ab8d47f9d
9 changed files with 199 additions and 21 deletions

View File

@ -94,7 +94,7 @@ hf_Bert_large,pass,6
hf_BigBird,fail_to_run,3
hf_BigBird,pass,6

(CSV columns: name, accuracy, graph_breaks)
View File

@ -1522,7 +1522,7 @@ class CPUReproTests(TestCase):
def test_int_div(self):
def fn(x, y):
s3 = x.size(1)
a = torch.zeros((1 + s3) // 2)
a = torch.ones((1 + s3) // 2)
a += y
return a, s3

View File

@ -1151,7 +1151,7 @@ if HAS_CUDA and not TEST_WITH_ASAN:
for _ in range(3):
out = foo(inp)
node = self.curr_node()
self.assertEqual(len(list(node.path_live_weakrefs())), 2)
self.assertEqual(len(list(node.path_live_weakrefs())), 1)
@torch.compile(mode="reduce-overhead")
def foo(x):

View File

@ -5153,7 +5153,7 @@ class CommonTemplate:
def fn(x, y):
z = y.item()
torch._check(z // 2 == 3)
return x + x.new_zeros(z)
return x + x.new_ones(z)
self.common(
fn,
@ -11171,7 +11171,7 @@ if HAS_GPU and not TEST_WITH_ASAN:
UniformValueConstantFolder(mod).run()
# there are a couple extra tensors created in `insertable_tensor_check`
self.assertTrue(max_live_tensors == 4)
self.assertTrue(max_live_tensors == 3)
# See https://github.com/pytorch/pytorch/issues/100348
def test_inductor_detach_view(self):

View File

@ -323,6 +323,9 @@ test_failures = {
"test_list_clearing_dynamic_shapes": TestFailure(
("cpu", "cuda", "xpu"), is_skip=True
),
"test_dropout_trivial_1_dynamic_shapes": TestFailure(
("cpu", "cuda", "xpu"), is_skip=True
),
"test_dropout2_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu"), is_skip=True),
"test_dropout3_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu"), is_skip=True),
"test_masked_fill_promotion_dynamic_shapes": TestFailure(

View File

@ -18,7 +18,9 @@ from torch._inductor.codegen.common import device_codegens, register_backend_for
from torch._inductor.codegen.cpp import CppScheduling
from torch._inductor.codegen.wrapper import WrapperCodeGen
from torch._inductor.test_case import TestCase
from torch._inductor.utils import run_and_get_code
from torch._inductor.virtualized import V
from torch.testing import FileCheck
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,
onlyCPU,
@ -146,6 +148,51 @@ class TestInductorDynamic(TestCase):
TestCase.tearDown(self)
torch._dynamo.reset()
def test_constant_fold_uniform_value_dynamic(self, device):
def full_add_zero(x):
a = torch.full(x.shape, 1, dtype=x.dtype, device=x.device)
b = a - 1
return x + b
def full_mul_one(x):
a = torch.full(x.shape, -1, dtype=x.dtype, device=x.device)
b = 2 + a
return x * b
def full_view_op(x):
a = torch.ones([1], dtype=x.dtype, device=x.device)
a = a[:, None]
return x * a
def full_mul_symint(x):
a = torch.full(x.shape, -1, dtype=x.dtype, device=x.device)
b = 2 + a
return b * x.shape[0]
fns = (full_add_zero, full_mul_one, full_view_op)
x = torch.randn((2, 4), device=device)
y = torch.randn((3, 4), device=device)
for dynamic in [False, True]:
torch._dynamo.reset()
for fn in fns:
ref = fn(x)
fn_c = torch.compile(fn, dynamic=dynamic)
actual, source_codes = run_and_get_code(fn_c, x)
if fn is not full_mul_symint:
# due to constant folding, fn returns x directly.
if device == "cpu":
FileCheck().check_not("cpp_fused").run(source_codes[0])
else:
FileCheck().check_not("triton.jit").run(source_codes[0])
self.assertEqual(ref, actual)
self.assertEqual(fn(x), fn_c(x))
self.assertEqual(fn(y), fn_c(y))
def test_arange_dynamic(self, device):
def fn(a):
batch_size = a.numel()

View File

@ -4591,9 +4591,10 @@ else:
# FIXME: move to test distributions
@deviceCountAtLeast(2)
@onlyCUDA
@skipIfTorchInductor("FIXME: error not thrown")
def test_multinomial_gpu_device_constrain(self, devices):
x = torch.empty(3, device=devices[0])
y = torch.empty(3, device=devices[1])
y = torch.empty(3, device=devices[1], dtype=torch.long)
self.assertRaisesRegex(
RuntimeError, "Expected all tensors to be on the same device",
lambda: torch.multinomial(x, 2, out=y))

View File

@ -60,6 +60,13 @@ class ConstantFolder(torch.fx.Interpreter):
# is the output
self.user_to_last_uses = self.node_to_last_non_output_use()
def _support_dynamic_shape(self):
# ConstantFolder does not support dynamic shapes yet
return False
def _deduce_value(self, node):
return super().run_node(node)
def is_impure(self, node: torch.fx.node.Node):
if (
node.target == torch.ops.prims.convert_element_type.default
@ -159,7 +166,9 @@ class ConstantFolder(torch.fx.Interpreter):
):
return self.unknown_value
out = super().run_node(node)
out = self._deduce_value(node)
if out == self.unknown_value:
return self.unknown_value
if node.op != "get_attr" and isinstance(out, torch.Tensor):
if out.device.type == "meta":
@ -194,10 +203,13 @@ class ConstantFolder(torch.fx.Interpreter):
self.node_replacements[node] = tensor
def run(self):
env = {}
env: Dict[torch.fx.Node, Any] = {}
self.insert_placerholder_values(env)
return super().run(initial_env=env)
def insert_placerholder_values(self, env: Dict[torch.fx.Node, Any]) -> None:
for n in self.module.graph.find_nodes(op="placeholder"):
env[n] = self.unknown_value
return super().run(initial_env=env)
@torch.utils._python_dispatch._disable_current_modes()

View File

@ -3,11 +3,13 @@ import itertools
import logging
import typing
from collections import Counter
from typing import Dict, List, Set, Union
from typing import Any, Dict, List, Set, Union
import torch
import torch._guards
import torch.utils._pytree as pytree
from torch._inductor.constant_folding import ConstantFolder
from torch._inductor.fx_passes.dedupe_symint_uses import _SymHashingDict
from torch.fx.experimental.symbolic_shapes import statically_known_true
from torch.fx.passes.graph_transform_observer import GraphTransformObserver
from torch.multiprocessing.reductions import StorageWeakRef
@ -201,21 +203,116 @@ class UniformValueConstantFolder(ConstantFolder):
# see: [constant folding refining of symints]
self.node_replacements_shapes: Dict[torch.fx.Node, List[int]] = {}
# initialize symint -> node mapping so that we can
# use symint nodes in full constructors
self.symint_nodes = _SymHashingDict()
for n in self.module.graph.nodes:
if "val" in n.meta and isinstance(n.meta["val"], torch.SymInt):
self.symint_nodes[n.meta["val"]] = n
# reference from torch/_functorch/partitioners.py:get_default_op_list
self.view_op_packets = [
aten.squeeze,
aten.unsqueeze,
aten.alias,
aten.view,
aten.slice,
aten.t,
prims.broadcast_in_dim,
aten.expand,
aten.as_strided,
aten.permute,
]
self.indexing_op_packets = {
aten.slice,
}
def _support_dynamic_shape(self):
return True
def insertable_tensor_check(self, t: torch.Tensor) -> bool:
# TODO - we could also Tensors which get replaced with arange here
return (
t.numel() != 0
and bool((t == t.flatten()[0]).all())
and torch._C._has_storage(t)
and t.layout == torch.strided
)
return True
def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None:
self.node_replacements[node] = tensor.flatten()[0].item()
self.node_replacements_shapes[node] = node.meta["val"].shape
self.constant_data_ptrs[node] = StorageWeakRef(tensor.untyped_storage())
shape = list(tensor.shape)
assert all(type(dim) is int for dim in shape)
self.node_replacements_shapes[node] = shape
def insert_placerholder_values(self, env: Dict[torch.fx.Node, Any]) -> None:
for n in self.module.graph.find_nodes(op="placeholder"):
if "val" in n.meta and isinstance(n.meta["val"], torch.SymInt):
env[n] = n.meta["val"]
else:
env[n] = self.unknown_value
def _deduce_value(self, node: torch.fx.Node):
# deduce value for full-like nodes
# 1. for constructors, substitute value is a tensor of size [1]
# 2. for view ops/indexing, substitute value is the same as the input
# 3. for pointwise ops, run node to get the substitute value
# 4. deal with some special ops
# otherwise, stop deducing values and return the unknown value
# TODO: cat, more indexing
# TODO - do on cpu to avoid syncs
# single-elem attrs
if node.op == "get_attr" or (
node.op == "call_function"
and node.target == torch.ops.aten.lift_fresh_copy.default
):
out = super(ConstantFolder, self).run_node(node)
if isinstance(out, torch.Tensor) and out.numel() == 1:
return out
# constructors ops
if (
node.op == "call_function"
and node.target == aten.full.default
and len(node.args) == 2
):
args, kwargs = self.fetch_args_kwargs_from_env(node)
new_args = [[1], args[1]]
return aten.full.default(*new_args, **node.kwargs)
# handle before view ops because this changes value
if node.target == aten.view.dtype:
return super(ConstantFolder, self).run_node(node)
# view ops, return input tensor, the first argument
if hasattr(node.target, "overloadpacket") and (
node.target.overloadpacket in self.view_op_packets
or node.target.overloadpacket in self.indexing_op_packets
):
assert isinstance(node.args[0], torch.fx.Node)
return self.env[node.args[0]]
# we don't want to return unknown value for symints so that we can
# still constant fold through their use in constructors or views
# if we see them in a pointwise node (e.g., tensor * symint)
# we will bail
if "val" in node.meta and isinstance(node.meta["val"], torch.SymInt):
return node.meta["val"]
# pointwise ops
if isinstance(node.target, torch._ops.OpOverload) and (
torch.Tag.pointwise in node.target.tags
or node.target is torch.ops.aten.scalar_tensor.default
):
args, kwargs = self.fetch_args_kwargs_from_env(node)
flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs)
if any(isinstance(inp, torch.SymInt) for inp in flattened_inputs):
return self.unknown_value
# we run the ops with dim 1, so remove memory_format to avoid error
kwargs = dict(kwargs)
kwargs.pop("memory_format", None)
return node.target(*args, **kwargs)
return self.unknown_value
@torch.utils._python_dispatch._disable_current_modes()
@ -262,6 +359,10 @@ def constant_fold_uniform_value(gm: torch.fx.GraphModule):
if not fake_tensor.is_contiguous(memory_format=torch.contiguous_format):
continue
# TODO - not sure about lossy uint->python value->uint conversions
if fake_tensor.dtype in (torch.uint8, torch.uint16, torch.uint32, torch.uint64):
continue
if constant_data_ptr_count[cf.constant_data_ptrs[node]] > 1:
continue
@ -280,10 +381,24 @@ def constant_fold_uniform_value(gm: torch.fx.GraphModule):
):
torch._check(runtime_size == compile_time_size)
# replace SymInt as Node before creating a new full node
# e.g. (1, s0) -> (1, arg0_1)
node_shape = node_replacements_shapes[node]
if not all(
not isinstance(s, torch.SymInt) or s in cf.symint_nodes
for s in node_shape
):
continue
shapes = [
cf.symint_nodes[s] if isinstance(s, torch.SymInt) else s
for s in node_replacements_shapes[node]
]
# zeros and ones just get traced into full, so we insert those
new_node = graph.call_function(
aten.full.default,
args=(node_replacements_shapes[node], value),
args=(shapes, value),
kwargs={
"dtype": fake_tensor.dtype,
"layout": torch.strided,