Constant folding for dynamic shape node (#129686)
Extend constant folding to graphs with dynamic shape nodes; only pointwise ops and a restricted set of other ops are supported. Dynamic shapes are handled by limiting constant folding to ops that are guaranteed to produce uniform values (full, pointwise ops, and views) and by running these operators on tensors of shape 1, which also eliminates the memory overhead of constant folding.

Taken over from https://github.com/pytorch/pytorch/pull/128937; joint work with @imzhuhl.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129686
Approved by: https://github.com/Chillee
ghstack dependencies: #130367
Committed by: PyTorch MergeBot
Parent: ea4f310ff1
Commit: 9ab8d47f9d
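To illustrate the pattern described in the commit message, the new test added below (test_constant_fold_uniform_value_dynamic) exercises functions whose intermediate tensors are provably uniform. A minimal standalone sketch of one of them, assuming a PyTorch build that includes this change (not part of the diff itself):

    import torch

    def full_mul_one(x):
        # `a` is provably uniform (-1 everywhere), so `2 + a` is uniform (1 everywhere)
        # and the multiply can be constant-folded away even when x has dynamic shapes.
        a = torch.full(x.shape, -1, dtype=x.dtype, device=x.device)
        b = 2 + a
        return x * b

    fn_c = torch.compile(full_mul_one, dynamic=True)
    x = torch.randn(2, 4)
    torch.testing.assert_close(full_mul_one(x), fn_c(x))

With the pass enabled, the compiled graph reduces to returning x directly, which the test verifies by checking that no cpp_fused / triton.jit kernels are emitted.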
@@ -94,7 +94,7 @@ hf_Bert_large,pass,6
-hf_BigBird,fail_to_run,3
+hf_BigBird,pass,6
@@ -1522,7 +1522,7 @@ class CPUReproTests(TestCase):
     def test_int_div(self):
         def fn(x, y):
             s3 = x.size(1)
-            a = torch.zeros((1 + s3) // 2)
+            a = torch.ones((1 + s3) // 2)
             a += y
             return a, s3
@@ -1151,7 +1151,7 @@ if HAS_CUDA and not TEST_WITH_ASAN:
             for _ in range(3):
                 out = foo(inp)
                 node = self.curr_node()
-                self.assertEqual(len(list(node.path_live_weakrefs())), 2)
+                self.assertEqual(len(list(node.path_live_weakrefs())), 1)

             @torch.compile(mode="reduce-overhead")
             def foo(x):
@@ -5153,7 +5153,7 @@ class CommonTemplate:
         def fn(x, y):
             z = y.item()
             torch._check(z // 2 == 3)
-            return x + x.new_zeros(z)
+            return x + x.new_ones(z)

         self.common(
             fn,
@@ -11171,7 +11171,7 @@ if HAS_GPU and not TEST_WITH_ASAN:
             UniformValueConstantFolder(mod).run()

             # there are a couple extra tensors created in `insertable_tensor_check`
-            self.assertTrue(max_live_tensors == 4)
+            self.assertTrue(max_live_tensors == 3)

         # See https://github.com/pytorch/pytorch/issues/100348
         def test_inductor_detach_view(self):
@@ -323,6 +323,9 @@ test_failures = {
     "test_list_clearing_dynamic_shapes": TestFailure(
         ("cpu", "cuda", "xpu"), is_skip=True
     ),
+    "test_dropout_trivial_1_dynamic_shapes": TestFailure(
+        ("cpu", "cuda", "xpu"), is_skip=True
+    ),
     "test_dropout2_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu"), is_skip=True),
     "test_dropout3_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu"), is_skip=True),
     "test_masked_fill_promotion_dynamic_shapes": TestFailure(
@@ -18,7 +18,9 @@ from torch._inductor.codegen.common import device_codegens, register_backend_for
 from torch._inductor.codegen.cpp import CppScheduling
 from torch._inductor.codegen.wrapper import WrapperCodeGen
 from torch._inductor.test_case import TestCase
+from torch._inductor.utils import run_and_get_code
 from torch._inductor.virtualized import V
+from torch.testing import FileCheck
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests,
     onlyCPU,
@@ -146,6 +148,51 @@ class TestInductorDynamic(TestCase):
         TestCase.tearDown(self)
         torch._dynamo.reset()

+    def test_constant_fold_uniform_value_dynamic(self, device):
+        def full_add_zero(x):
+            a = torch.full(x.shape, 1, dtype=x.dtype, device=x.device)
+            b = a - 1
+            return x + b
+
+        def full_mul_one(x):
+            a = torch.full(x.shape, -1, dtype=x.dtype, device=x.device)
+            b = 2 + a
+            return x * b
+
+        def full_view_op(x):
+            a = torch.ones([1], dtype=x.dtype, device=x.device)
+            a = a[:, None]
+            return x * a
+
+        def full_mul_symint(x):
+            a = torch.full(x.shape, -1, dtype=x.dtype, device=x.device)
+            b = 2 + a
+            return b * x.shape[0]
+
+        fns = (full_add_zero, full_mul_one, full_view_op)
+
+        x = torch.randn((2, 4), device=device)
+        y = torch.randn((3, 4), device=device)
+
+        for dynamic in [False, True]:
+            torch._dynamo.reset()
+            for fn in fns:
+                ref = fn(x)
+                fn_c = torch.compile(fn, dynamic=dynamic)
+
+                actual, source_codes = run_and_get_code(fn_c, x)
+
+                if fn is not full_mul_symint:
+                    # due to constant folding, fn returns x directly.
+                    if device == "cpu":
+                        FileCheck().check_not("cpp_fused").run(source_codes[0])
+                    else:
+                        FileCheck().check_not("triton.jit").run(source_codes[0])
+
+                self.assertEqual(ref, actual)
+                self.assertEqual(fn(x), fn_c(x))
+                self.assertEqual(fn(y), fn_c(y))
+
     def test_arange_dynamic(self, device):
         def fn(a):
             batch_size = a.numel()
@@ -4591,9 +4591,10 @@ else:
     # FIXME: move to test distributions
    @deviceCountAtLeast(2)
    @onlyCUDA
+   @skipIfTorchInductor("FIXME: error not thrown")
    def test_multinomial_gpu_device_constrain(self, devices):
        x = torch.empty(3, device=devices[0])
-       y = torch.empty(3, device=devices[1])
+       y = torch.empty(3, device=devices[1], dtype=torch.long)
        self.assertRaisesRegex(
            RuntimeError, "Expected all tensors to be on the same device",
            lambda: torch.multinomial(x, 2, out=y))
@@ -60,6 +60,13 @@ class ConstantFolder(torch.fx.Interpreter):
         # is the output
         self.user_to_last_uses = self.node_to_last_non_output_use()

+    def _support_dynamic_shape(self):
+        # ConstantFolder not support dynamic shape now
+        return False
+
+    def _deduce_value(self, node):
+        return super().run_node(node)
+
     def is_impure(self, node: torch.fx.node.Node):
         if (
             node.target == torch.ops.prims.convert_element_type.default
@@ -159,7 +166,9 @@ class ConstantFolder(torch.fx.Interpreter):
         ):
             return self.unknown_value

-        out = super().run_node(node)
+        out = self._deduce_value(node)
+        if out == self.unknown_value:
+            return self.unknown_value

         if node.op != "get_attr" and isinstance(out, torch.Tensor):
             if out.device.type == "meta":
@@ -194,10 +203,13 @@ class ConstantFolder(torch.fx.Interpreter):
         self.node_replacements[node] = tensor

     def run(self):
-        env = {}
+        env: Dict[torch.fx.Node, Any] = {}
+        self.insert_placerholder_values(env)
+        return super().run(initial_env=env)
+
+    def insert_placerholder_values(self, env: Dict[torch.fx.Node, Any]) -> None:
         for n in self.module.graph.find_nodes(op="placeholder"):
             env[n] = self.unknown_value
-        return super().run(initial_env=env)


 @torch.utils._python_dispatch._disable_current_modes()
@@ -3,11 +3,13 @@ import itertools
 import logging
 import typing
 from collections import Counter
-from typing import Dict, List, Set, Union
+from typing import Any, Dict, List, Set, Union

 import torch
 import torch._guards
+import torch.utils._pytree as pytree
 from torch._inductor.constant_folding import ConstantFolder
+from torch._inductor.fx_passes.dedupe_symint_uses import _SymHashingDict
 from torch.fx.experimental.symbolic_shapes import statically_known_true
 from torch.fx.passes.graph_transform_observer import GraphTransformObserver
 from torch.multiprocessing.reductions import StorageWeakRef
@@ -201,21 +203,116 @@ class UniformValueConstantFolder(ConstantFolder):
         # see: [constant folding refining of symints]
         self.node_replacements_shapes: Dict[torch.fx.Node, List[int]] = {}

+        # initialize symint -> node mapping so that we can
+        # use symint nodes in full constructors
+        self.symint_nodes = _SymHashingDict()
+        for n in self.module.graph.nodes:
+            if "val" in n.meta and isinstance(n.meta["val"], torch.SymInt):
+                self.symint_nodes[n.meta["val"]] = n
+
+        # reference from torch/_funtorch/partitioners.py:get_default_op_list
+        self.view_op_packets = [
+            aten.squeeze,
+            aten.unsqueeze,
+            aten.alias,
+            aten.view,
+            aten.slice,
+            aten.t,
+            prims.broadcast_in_dim,
+            aten.expand,
+            aten.as_strided,
+            aten.permute,
+        ]
+
+        self.indexing_op_packets = {
+            aten.slice,
+        }
+
+    def _support_dynamic_shape(self):
+        return True
+
     def insertable_tensor_check(self, t: torch.Tensor) -> bool:
         # TODO - we could also Tensors which get replaced with arange here
-        return (
-            t.numel() != 0
-            and bool((t == t.flatten()[0]).all())
-            and torch._C._has_storage(t)
-            and t.layout == torch.strided
-        )
+        return True

     def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None:
         self.node_replacements[node] = tensor.flatten()[0].item()
-        self.node_replacements_shapes[node] = node.meta["val"].shape
         self.constant_data_ptrs[node] = StorageWeakRef(tensor.untyped_storage())
+        shape = list(tensor.shape)
+        assert all(type(dim) is int for dim in shape)
+        self.node_replacements_shapes[node] = shape
+
+    def insert_placerholder_values(self, env: Dict[torch.fx.Node, Any]) -> None:
+        for n in self.module.graph.find_nodes(op="placeholder"):
+            if "val" in n.meta and isinstance(n.meta["val"], torch.SymInt):
+                env[n] = n.meta["val"]
+            else:
+                env[n] = self.unknown_value
+
+    def _deduce_value(self, node: torch.fx.Node):
+        # deduce value for full-like nodes
+        # 1. for constructors, substitute value is a tensor of size [1]
+        # 2. for view ops/indexing, substitute value is the same as the input
+        # 3. for pointwise ops, run node to get the substitute value
+        # 4. deal with some special ops
+        # otherwise, stop deduce value and return unknown value
+
+        # TODO: cat, more indexing
+        # TODO - do on cpu to avoid syncs
+
+        # single-elem attrs
+        if node.op == "get_attr" or (
+            node.op == "call_function"
+            and node.target == torch.ops.aten.lift_fresh_copy.default
+        ):
+            out = super(ConstantFolder, self).run_node(node)
+            if isinstance(out, torch.Tensor) and out.numel() == 1:
+                return out
+
+        # constructors ops
+        if (
+            node.op == "call_function"
+            and node.target == aten.full.default
+            and len(node.args) == 2
+        ):
+            args, kwargs = self.fetch_args_kwargs_from_env(node)
+            new_args = [[1], args[1]]
+            return aten.full.default(*new_args, **node.kwargs)
+
+        # handle before view ops because this changes value
+        if node.target == aten.view.dtype:
+            return super(ConstantFolder, self).run_node(node)
+
+        # view ops, return input tensor, the first argument
+        if hasattr(node.target, "overloadpacket") and (
+            node.target.overloadpacket in self.view_op_packets
+            or node.target.overloadpacket in self.indexing_op_packets
+        ):
+            assert isinstance(node.args[0], torch.fx.Node)
+            return self.env[node.args[0]]
+
+        # we don't want to return unknown value for symints so that we can
+        # still constant fold through their use in constructors or views
+        # if we see them in a pointwise node (e.g., tensor * symint)
+        # we will bail
+        if "val" in node.meta and isinstance(node.meta["val"], torch.SymInt):
+            return node.meta["val"]

+        # pointwise ops
+        if isinstance(node.target, torch._ops.OpOverload) and (
+            torch.Tag.pointwise in node.target.tags
+            or node.target is torch.ops.aten.scalar_tensor.default
+        ):
+            args, kwargs = self.fetch_args_kwargs_from_env(node)
+            flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs)
+
+            if any(isinstance(inp, torch.SymInt) for inp in flattened_inputs):
+                return self.unknown_value
+
+            # we run the ops with dim 1, so remove memory_format to avoid error
+            kwargs = dict(kwargs)
+            kwargs.pop("memory_format", None)
+
+            return node.target(*args, **kwargs)
+
+        return self.unknown_value


 @torch.utils._python_dispatch._disable_current_modes()
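As a rough illustration of the `_deduce_value` strategy above (a standalone sketch, not code from this PR): a full() constructor is stood in for by a size-[1] tensor, pointwise ops are run on that tiny tensor, and the resulting scalar becomes the uniform value recorded for the node.

    import torch

    # Hypothetical helper, for illustration only: deduce the uniform value of
    # `2 + torch.full(shape, fill_value)` without materializing the full tensor.
    def deduce_uniform_value(fill_value, dtype):
        a = torch.full([1], fill_value, dtype=dtype)  # size-[1] stand-in for the constructor
        b = 2 + a                                      # pointwise op runs on the stand-in
        return b.item()                                # uniform value for the folded node

    assert deduce_uniform_value(-1, torch.float32) == 1.0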
@@ -262,6 +359,10 @@ def constant_fold_uniform_value(gm: torch.fx.GraphModule):
         if not fake_tensor.is_contiguous(memory_format=torch.contiguous_format):
             continue

+        # TODO - not sure about lossy uint->python value->uint conversions
+        if fake_tensor.dtype in (torch.uint8, torch.uint16, torch.uint32, torch.uint64):
+            continue
+
         if constant_data_ptr_count[cf.constant_data_ptrs[node]] > 1:
             continue

@@ -280,10 +381,24 @@ def constant_fold_uniform_value(gm: torch.fx.GraphModule):
         ):
             torch._check(runtime_size == compile_time_size)

+        # replace SymInt as Node before creating a new full node
+        # e.g. (1, s0) -> (1, arg0_1)
+        node_shape = node_replacements_shapes[node]
+        if not all(
+            not isinstance(s, torch.SymInt) or s in cf.symint_nodes
+            for s in node_shape
+        ):
+            continue
+
+        shapes = [
+            cf.symint_nodes[s] if isinstance(s, torch.SymInt) else s
+            for s in node_replacements_shapes[node]
+        ]
+
         # zeros and ones just get traced into full, so we insert those
         new_node = graph.call_function(
             aten.full.default,
-            args=(node_replacements_shapes[node], value),
+            args=(shapes, value),
             kwargs={
                 "dtype": fake_tensor.dtype,
                 "layout": torch.strided,