Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
Propagate dynamo shape_env to make_fx (#96437)
Currently, when the assume_static_by_default flag is set, dynamo does not produce any symbols for input tensors. However, when the dynamo-generated graph is passed on to make_fx via torchdynamo.export(aten_graph=True), there is no way to propagate this flag. We enable this by passing the fake tensors dynamo used directly to make_fx and calling make_fx in "real" mode with those fake tensors. Note that this is a modified version of https://github.com/pytorch/pytorch/pull/96143.

Differential Revision: [D44561753](https://our.internmc.facebook.com/intern/diff/D44561753)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96437

Approved by: https://github.com/jansel, https://github.com/ezyang
Committed by: PyTorch MergeBot
Parent: 0eab3ab51e
Commit: 75ac6fdcdd
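
As a rough illustration of the behavior described in the commit message, the sketch below is adapted from the test this PR adds (test_propagate_assume_static_by_default); setting the config attributes directly instead of using config.patch is purely for illustration:

    import torch
    import torch._dynamo
    from torch.fx.experimental.symbolic_shapes import is_concrete_int

    torch._dynamo.config.dynamic_shapes = True
    torch._dynamo.config.assume_static_by_default = True

    def f(x):
        if x.shape[0] > 3:
            return x.sin()
        return x.cos()

    # Export to an aten-level graph. With assume_static_by_default=True, dynamo
    # specializes the input shapes, so even under tracing_mode="symbolic" the
    # exported graph's metadata should contain no symbolic sizes.
    gm, _ = torch._dynamo.export(
        f, torch.ones(6, 4), aten_graph=True, tracing_mode="symbolic"
    )

    for node in gm.graph.nodes:
        val = node.meta.get("val", None)
        if isinstance(val, torch.Tensor):
            assert all(is_concrete_int(s) for s in val.shape)

As in the added test, the captured graph is specialized for the traced shape: calling gm on a tensor with a different leading dimension still takes the sin() branch, because export does not re-evaluate the shape check.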
@@ -185,6 +185,7 @@ def _has_potential_branch_input_alias(branch, inputs):
    """
    try:
        gm = make_fx(branch)(*inputs)

    except UnsupportedAliasMutationException:
        # this can happen when nested cond is
        # functionalized
@@ -15,7 +15,10 @@ from functorch.experimental.control_flow import cond
from torch._dynamo import config
from torch._export import dynamic_dim
from torch.fx.experimental.proxy_tensor import make_fx
-from torch.fx.experimental.symbolic_shapes import ConstraintViolationError
+from torch.fx.experimental.symbolic_shapes import (
+    ConstraintViolationError,
+    is_concrete_int,
+)
from torch.testing._internal import common_utils


@@ -1631,6 +1634,7 @@ class ExportTests(torch._dynamo.test_case.TestCase):

        self.assertEqual(gm(*inp), f(*inp))

+    @config.patch(assume_static_by_default=False)
    def test_export_symbolic_shape(self):
        def f(x: torch.Tensor) -> torch.Tensor:
            return torch.empty(x.shape[0] * 2)
@@ -1645,7 +1649,7 @@ class ExportTests(torch._dynamo.test_case.TestCase):

        self.assertTrue(has_sym_size)

-    @config.patch(dynamic_shapes=True)
+    @config.patch(dynamic_shapes=True, assume_static_by_default=False)
    def test_dynamic_slicing(self):
        def f(x):
            return x[: x.shape[0] - 2, x.shape[1] - 1 :: 2]
@@ -2296,6 +2300,28 @@ class ExportTests(torch._dynamo.test_case.TestCase):
            f, torch.randn(5, 6), aten_graph=True, tracing_mode="symbolic"
        )

+    @config.patch(assume_static_by_default=True, dynamic_shapes=True)
+    def test_propagate_assume_static_by_default(self):
+        def f(x):
+            if x.shape[0] > 3:
+                return x.sin()
+            return x.cos()
+
+        gm, _ = torch._dynamo.export(
+            f, torch.ones(6, 4), aten_graph=True, tracing_mode="symbolic"
+        )
+
+        for node in gm.graph.nodes:
+            val = node.meta.get("val", None)
+            if val is not None:
+                shapes = val.shape
+                # there should be no symbols
+                for shape in shapes:
+                    self.assertTrue(is_concrete_int(shape))
+
+        # this should be captured as static, as export won't generate any symbols.
+        self.assertEqual(gm(torch.ones(2, 4)), torch.ones(2, 4).sin())
+

common_utils.instantiate_parametrized_tests(ExportTests)

@@ -2605,6 +2605,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
        self.assertTrue(same(buffer_ref, buffer_test))

    @torch._dynamo.config.patch("dynamic_shapes", True)
+    @torch._dynamo.config.patch("assume_static_by_default", False)
    def test_dynamic_shapes_right_side(self):
        def f(x):
            return torch.ones(5 * x.shape[0])
@@ -2700,6 +2701,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
        )

    @torch._dynamo.config.patch("dynamic_shapes", True)
+    @torch._dynamo.config.patch("assume_static_by_default", False)
    def test_tensor_split(self):
        def f(x):
            return torch.split(x, x.shape[0] // 2, dim=0)[0]
@@ -220,7 +220,7 @@ class TestQuantizePT2E(QuantizationTestCase):

        node_occurrence = {
            ns.call_function(torch.ops.aten.convolution.default): 1,
-            ns.call_function(torch.ops.aten.native_batch_norm.default): 1,
+            ns.call_function(torch.ops.aten._native_batch_norm_legit_no_training.default): 1,
        }
        self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)

@@ -234,7 +234,7 @@ class TestQuantizePT2E(QuantizationTestCase):
        # make sure bn is fused into conv
        node_occurrence = {
            ns.call_function(torch.ops.aten.convolution.default): 1,
-            ns.call_function(torch.ops.aten.native_batch_norm.default): 0,
+            ns.call_function(torch.ops.aten._native_batch_norm_legit_no_training.default): 0,
        }
        self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)

@@ -32,8 +32,8 @@ from typing import Any, Callable, Dict, Optional, Set, Type, TYPE_CHECKING, Union
import builtins

__all__ = [
-    'typename', 'is_tensor', 'is_storage', 'set_default_tensor_type',
-    'set_default_device',
+    'typename', 'is_tensor', 'is_storage',
+    'set_default_tensor_type', 'set_default_device',
    'set_rng_state', 'get_rng_state', 'manual_seed', 'initial_seed', 'seed',
    'save', 'load', 'set_printoptions', 'chunk', 'split', 'stack', 'matmul',
    'no_grad', 'enable_grad', 'rand', 'randn', 'inference_mode',
@@ -49,7 +49,7 @@ __all__ = [
    'set_float32_matmul_precision', 'get_float32_matmul_precision',
    'set_warn_always', 'is_warn_always_enabled', 'SymInt', 'SymFloat',
    'SymBool', 'sym_not',
-    'sym_int', 'sym_float', 'sym_max', 'sym_min', 'compile', 'vmap'
+    'sym_int', 'sym_float', 'sym_max', 'sym_min', 'compile', 'vmap',
]

################################################################################
@@ -52,6 +52,8 @@ from .utils import compile_times

log = logging.getLogger(__name__)

+from torch._dispatch.python import enable_python_dispatcher
+from torch._subclasses.fake_tensor import FakeTensor
from torch.fx.experimental import proxy_tensor

always_optimize_code_objects = utils.ExactWeakKeyDictionary()
@@ -653,6 +655,7 @@ def export(
    graph = None
    out_guards = None
    graph_captured_input = None
+    example_fake_inputs = []
    graph_captured_result: Optional[Tuple[torch.Tensor, ...]] = None

    def produce_matching(source_args, candidate_args):
@@ -698,6 +701,9 @@ def export(
        ), "Tried to emit a second graph during export. Tracing through 'f' must produce a single graph."
        graph = gm

+        nonlocal example_fake_inputs
+        example_fake_inputs = example_inputs
+
    def result_capturing_wrapper(*graph_inputs):
        nonlocal graph_captured_result
        nonlocal graph_captured_input
@@ -717,7 +723,10 @@ def export(
    ):
        opt_f = optimize_assert(
            dynamo_normalization_capturing_compiler,
-            hooks=Hooks(guard_export_fn=guard_export_print, guard_fail_fn=None),
+            hooks=Hooks(
+                guard_export_fn=guard_export_print,
+                guard_fail_fn=None,
+            ),
            export=True,
            export_constraints=constraints,
            dynamic=(tracing_mode == "symbolic"),
@@ -781,12 +790,19 @@ def export(
            with torch.fx.traceback.preserve_node_meta():
                return torch.fx.Interpreter(graph).run(*args)

-        graph = make_fx(
-            graph_with_interpreter,
-            decomposition_table=decomposition_table,
-            tracing_mode=tracing_mode,
-            _allow_non_fake_inputs=True,
-        )(*graph_captured_input)
+        fake_tensor_mode = null_context()
+        for val in example_fake_inputs:
+            if isinstance(val, FakeTensor):
+                fake_tensor_mode = val.fake_mode
+                break
+
+        with enable_python_dispatcher(), fake_tensor_mode:
+            graph = make_fx(
+                graph_with_interpreter,
+                decomposition_table=decomposition_table,
+                tracing_mode="real",
+                _allow_non_fake_inputs=True,
+            )(*example_fake_inputs)

        new_graph = ChangeInputOutputSignature(
            graph,
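
The hunk above is the core of the change: instead of re-running make_fx symbolically over real inputs, export re-enters the FakeTensorMode that produced dynamo's fake example inputs and traces in "real" mode, so those fake tensors (and any ShapeEnv they carry) are reused rather than re-fakeified. A standalone sketch of that pattern, assuming current make_fx and FakeTensor APIs; trace_with_dynamo_fakes is a hypothetical helper, not part of this PR:

    from contextlib import nullcontext

    from torch._dispatch.python import enable_python_dispatcher
    from torch._subclasses.fake_tensor import FakeTensor
    from torch.fx.experimental.proxy_tensor import make_fx

    def trace_with_dynamo_fakes(fn, example_inputs):
        # If any example input is already a FakeTensor, re-enter the mode that
        # created it; otherwise fall back to a no-op context (the PR uses
        # dynamo's null_context() for the same purpose).
        fake_mode = nullcontext()
        for val in example_inputs:
            if isinstance(val, FakeTensor):
                fake_mode = val.fake_mode
                break

        # Tracing in "real" mode under the existing fake mode keeps the fake
        # tensors and their shape environment instead of building new ones.
        with enable_python_dispatcher(), fake_mode:
            return make_fx(
                fn,
                tracing_mode="real",
                _allow_non_fake_inputs=True,
            )(*example_inputs)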
@@ -205,6 +205,8 @@ class OutputGraph(fx.Tracer, Checkpointable[OutputGraphState]):
            )
            if config.dynamic_shapes
            else None,
+            # TODO (tmanlaibaatar) Remove this once we always lift params and buffers
+            allow_non_fake_inputs=True if self.export else False,
        )
        self.tracing_context: TracingContext = TracingContext(fake_mode)
        if config.dynamic_shapes:
@@ -17,7 +17,7 @@ def _get_tensor_constant_from_node(node, m):
# fuse conv bn weights, inplace modification of the graph_module and graph
def _fuse_conv_bn_(m: GraphModule) -> None:
    for n in m.graph.nodes:
-        if n.op != "call_function" or n.target != torch.ops.aten.native_batch_norm.default:
+        if n.op != "call_function" or n.target != torch.ops.aten._native_batch_norm_legit_no_training.default:
            continue
        bn_op = n
        n = bn_op.args[0]
@@ -39,7 +39,7 @@ def _fuse_conv_bn_(m: GraphModule) -> None:
        bn_rm = _get_tensor_constant_from_node(bn_op.args[3], m)
        # bn running variance
        bn_rv = _get_tensor_constant_from_node(bn_op.args[4], m)
-        bn_eps = bn_op.args[7]
+        bn_eps = bn_op.args[6]

        fused_weight, fused_bias = fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b, transpose=transpose)

@@ -20,7 +20,12 @@ import weakref
import operator
from torch.utils._stats import count

-from torch.utils._python_dispatch import TorchDispatchMode, _pop_mode_temporarily, _get_current_dispatch_mode
+from torch.utils._python_dispatch import (
+    TorchDispatchMode,
+    _pop_mode_temporarily,
+    _get_current_dispatch_mode,
+)

from torch._subclasses import FakeTensor
from .symbolic_shapes import ShapeEnv, SymDispatchMode, SymNode
from torch.fx import Proxy
@@ -47,7 +47,7 @@ from sympy.core.logic import fuzzy_and, fuzzy_or
aten = torch._ops.ops.aten  # type: ignore[has-type]

__all__ = [
-    "has_symbolic_sizes_strides", "create_contiguous", "ShapeEnv",
+    "has_symbolic_sizes_strides", "create_contiguous", "ShapeEnv", "is_concrete_int",
    "SymDispatchMode", "FloorDiv", "guard_int", "guard_float", "guard_scalar", "wrap_node",
    "method_to_operator", "hint_int", "SYMPY_INTERP",
]
@@ -128,6 +128,24 @@ def has_hint(a):
        return a.node.has_hint()
    return True

+def is_concrete_int(a: Union[int, SymInt]):
+    r""" Utility to check if the underlying object
+    in a SymInt is a concrete value. Also returns
+    true if an integer is passed in.
+
+    Args:
+        a (SymInt or int): Object to test
+    """
+    assert isinstance(a, (SymInt, int))
+
+    if isinstance(a, int):
+        return True
+
+    if isinstance(a.node.expr, sympy.core.numbers.Integer):
+        return True
+
+    return False
+
# Returns True if every size dim on the tensor has a hint
# TODO: Should this include strides too? For now it doesn't matter,
# that's quite an obscure case
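
A short usage sketch for the new is_concrete_int helper, assuming the FakeTensorMode and ShapeEnv constructors used elsewhere in this diff; the tensor sizes are illustrative:

    import torch
    from torch._subclasses.fake_tensor import FakeTensorMode
    from torch.fx.experimental.symbolic_shapes import ShapeEnv, is_concrete_int

    # Plain Python ints count as concrete.
    assert is_concrete_int(3)

    # Fake tensors allocated from literal sizes under a ShapeEnv-backed mode
    # still carry plain-int sizes, so every dimension reports as concrete.
    with FakeTensorMode(shape_env=ShapeEnv()):
        t = torch.empty(8, 2)
    assert all(is_concrete_int(s) for s in t.shape)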