Propagate dynamo shape_env to make_fx (#96437)

Currently, when the assume_static_by_default flag is set, dynamo won't produce any symbols for input tensors. But when the dynamo-generated graph is handed to make_fx via torchdynamo.export(aten_graph=True), there is no way to forward this flag. We fix this by passing the fake tensors dynamo traced with directly to make_fx, calling make_fx in "real" mode on them, so the re-trace reuses dynamo's shape environment instead of creating its own symbols.
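A minimal sketch of the resulting flow (mirroring the test added below; f is just an example function):

    import torch
    import torch._dynamo
    from torch._dynamo import config

    def f(x):
        if x.shape[0] > 3:
            return x.sin()
        return x.cos()

    # With assume_static_by_default set, dynamo records no input symbols, and
    # export hands the fake tensors it traced with straight to make_fx.
    with config.patch(assume_static_by_default=True, dynamic_shapes=True):
        gm, _ = torch._dynamo.export(
            f, torch.ones(6, 4), aten_graph=True, tracing_mode="symbolic"
        )
    # Every shape recorded in gm's node metadata is a concrete int.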

Note that this is a modified version of https://github.com/pytorch/pytorch/pull/96143.

Differential Revision: [D43994693](https://our.internmc.facebook.com/intern/diff/D43994693)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96437
Approved by: https://github.com/jansel, https://github.com/ezyang
Tugsbayasgalan Manlaibaatar
2023-03-29 15:25:02 -07:00
committed by PyTorch MergeBot
parent 7257de6eac
commit 3a22916c7a
8 changed files with 82 additions and 15 deletions

View File

@@ -185,6 +185,7 @@ def _has_potential_branch_input_alias(branch, inputs):
"""
try:
gm = make_fx(branch)(*inputs)
except UnsupportedAliasMutationException:
# this can happen when nested cond is
# functionalized

View File

@@ -15,7 +15,10 @@ from functorch.experimental.control_flow import cond
from torch._dynamo import config
from torch._export import dynamic_dim
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import ConstraintViolationError
from torch.fx.experimental.symbolic_shapes import (
ConstraintViolationError,
is_concrete_int,
)
from torch.testing._internal import common_utils
@@ -2281,6 +2284,28 @@ class ExportTests(torch._dynamo.test_case.TestCase):
constraints = [dynamic_dim(y, 0)]
torch._dynamo.export(my_dyn_fn, y, constraints=constraints)
@config.patch(assume_static_by_default=True, dynamic_shapes=True)
def test_propagate_assume_static_by_default(self):
def f(x):
if x.shape[0] > 3:
return x.sin()
return x.cos()
gm, _ = torch._dynamo.export(
f, torch.ones(6, 4), aten_graph=True, tracing_mode="symbolic"
)
for node in gm.graph.nodes:
val = node.meta.get("val", None)
if val is not None:
shapes = val.shape
# there should be no symbols
for shape in shapes:
self.assertTrue(is_concrete_int(shape))
# this should be captured as static, as export won't generate any symbols.
self.assertEqual(gm(torch.ones(2, 4)), torch.ones(2, 4).sin())
common_utils.instantiate_parametrized_tests(ExportTests)

View File

@@ -32,8 +32,8 @@ from typing import Any, Callable, Dict, Optional, Set, Type, TYPE_CHECKING, Unio
import builtins
__all__ = [
'typename', 'is_tensor', 'is_storage', 'set_default_tensor_type',
'set_default_device',
'typename', 'is_tensor', 'is_storage',
'set_default_tensor_type', 'set_default_device',
'set_rng_state', 'get_rng_state', 'manual_seed', 'initial_seed', 'seed',
'save', 'load', 'set_printoptions', 'chunk', 'split', 'stack', 'matmul',
'no_grad', 'enable_grad', 'rand', 'randn', 'inference_mode',
@@ -49,7 +49,7 @@ __all__ = [
'set_float32_matmul_precision', 'get_float32_matmul_precision',
'set_warn_always', 'is_warn_always_enabled', 'SymInt', 'SymFloat',
'SymBool', 'sym_not',
'sym_int', 'sym_float', 'sym_max', 'sym_min', 'compile', 'vmap'
'sym_int', 'sym_float', 'sym_max', 'sym_min', 'compile', 'vmap',
]
################################################################################

View File

@@ -51,6 +51,8 @@ from .utils import compile_times
log = logging.getLogger(__name__)
from torch._dispatch.python import enable_python_dispatcher
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx.experimental import proxy_tensor
always_optimize_code_objects = utils.ExactWeakKeyDictionary()
@@ -635,6 +637,7 @@ def export(
graph = None
out_guards = None
graph_captured_input = None
example_fake_inputs = []
graph_captured_result: Optional[Tuple[torch.Tensor, ...]] = None
def produce_matching(source_args, candidate_args):
@@ -680,6 +683,9 @@
), "Tried to emit a second graph during export. Tracing through 'f' must produce a single graph."
graph = gm
nonlocal example_fake_inputs
example_fake_inputs = example_inputs
def result_capturing_wrapper(*graph_inputs):
nonlocal graph_captured_result
nonlocal graph_captured_input
@@ -699,7 +705,10 @@
):
opt_f = optimize_assert(
dynamo_normalization_capturing_compiler,
hooks=Hooks(guard_export_fn=guard_export_print, guard_fail_fn=None),
hooks=Hooks(
guard_export_fn=guard_export_print,
guard_fail_fn=None,
),
export=True,
export_constraints=constraints,
dynamic=(tracing_mode == "symbolic"),
@@ -763,12 +772,19 @@
with torch.fx.traceback.preserve_node_meta():
return torch.fx.Interpreter(graph).run(*args)
graph = make_fx(
graph_with_interpreter,
decomposition_table=decomposition_table,
tracing_mode=tracing_mode,
_allow_non_fake_inputs=True,
)(*graph_captured_input)
fake_tensor_mode = null_context()
for val in example_fake_inputs:
if isinstance(val, FakeTensor):
fake_tensor_mode = val.fake_mode
break
with enable_python_dispatcher(), fake_tensor_mode:
graph = make_fx(
graph_with_interpreter,
decomposition_table=decomposition_table,
tracing_mode="real",
_allow_non_fake_inputs=True,
)(*example_fake_inputs)
new_graph = ChangeInputOutputSignature(
graph,
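A standalone sketch of the pattern this hunk introduces (the helper name retrace_with_dynamo_fakes is hypothetical): reuse dynamo's FakeTensorMode, and with it dynamo's shape_env, when re-tracing through make_fx, instead of letting make_fx build a fresh symbolic environment:

    from contextlib import nullcontext
    from torch._dispatch.python import enable_python_dispatcher
    from torch._subclasses.fake_tensor import FakeTensor
    from torch.fx.experimental.proxy_tensor import make_fx

    def retrace_with_dynamo_fakes(fn, example_fake_inputs):
        # Find the FakeTensorMode dynamo traced with; fall back to a no-op context.
        fake_mode = nullcontext()
        for val in example_fake_inputs:
            if isinstance(val, FakeTensor):
                fake_mode = val.fake_mode
                break
        # tracing_mode="real" is safe here: the inputs are already fake, so
        # make_fx traces through them without creating a second fake environment.
        with enable_python_dispatcher(), fake_mode:
            return make_fx(fn, tracing_mode="real", _allow_non_fake_inputs=True)(
                *example_fake_inputs
            )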

View File

@@ -203,6 +203,8 @@ class OutputGraph(fx.Tracer, Checkpointable[OutputGraphState]):
)
if config.dynamic_shapes
else None,
# TODO (tmanlaibaatar) Remove this once we always lift params and buffers
allow_non_fake_inputs=True if self.export else False,
)
self.tracing_context: TracingContext = TracingContext(fake_mode)
if config.dynamic_shapes:

View File

@@ -17,7 +17,7 @@ def _get_tensor_constant_from_node(node, m):
# fuse conv bn weights, inplace modification of the graph_module and graph
def _fuse_conv_bn_(m: GraphModule) -> None:
for n in m.graph.nodes:
if n.op != "call_function" or n.target != torch.ops.aten.native_batch_norm.default:
if n.op != "call_function" or n.target != torch.ops.aten._native_batch_norm_legit_no_training.default:
continue
bn_op = n
n = bn_op.args[0]
@@ -39,7 +39,7 @@ def _fuse_conv_bn_(m: GraphModule) -> None:
bn_rm = _get_tensor_constant_from_node(bn_op.args[3], m)
# bn running variance
bn_rv = _get_tensor_constant_from_node(bn_op.args[4], m)
bn_eps = bn_op.args[7]
bn_eps = bn_op.args[6]
fused_weight, fused_bias = fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b, transpose=False)
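Aside (the schemas below are stated from the aten ops as I understand them, not from this diff): eps moves from args[7] to args[6] because _native_batch_norm_legit_no_training drops the training flag from native_batch_norm's argument list:

    # native_batch_norm(input, weight, bias, running_mean, running_var,
    #                   training, momentum, eps)                # eps is args[7]
    # _native_batch_norm_legit_no_training(input, weight, bias, running_mean,
    #                                      running_var, momentum, eps)  # eps is args[6]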

View File

@@ -20,7 +20,12 @@ import weakref
import operator
from torch.utils._stats import count
from torch.utils._python_dispatch import TorchDispatchMode, _pop_mode_temporarily, _get_current_dispatch_mode
from torch.utils._python_dispatch import (
TorchDispatchMode,
_pop_mode_temporarily,
_get_current_dispatch_mode,
)
from torch._subclasses import FakeTensor
from .symbolic_shapes import ShapeEnv, SymDispatchMode, SymNode
from torch.fx import Proxy

View File

@@ -47,7 +47,7 @@ from sympy.core.logic import fuzzy_and, fuzzy_or
aten = torch._ops.ops.aten # type: ignore[has-type]
__all__ = [
"has_symbolic_sizes_strides", "create_contiguous", "ShapeEnv",
"has_symbolic_sizes_strides", "create_contiguous", "ShapeEnv", "is_concrete_int",
"SymDispatchMode", "FloorDiv", "guard_int", "guard_float", "guard_scalar", "wrap_node",
"method_to_operator", "hint_int", "SYMPY_INTERP",
]
@@ -128,6 +128,24 @@ def has_hint(a):
return a.node.has_hint()
return True
def is_concrete_int(a: Union[int, SymInt]):
r""" Utility to check if underlying object
in SymInt is concrete value. Also returns
true if integer is passed in.
Args:
a (SymInt or int): Object to test if it int
"""
assert isinstance(a, SymInt) or isinstance(a, int)
if isinstance(a, int):
return True
if isinstance(a.node.expr, sympy.core.numbers.Integer):
return True
return False
# Returns True if every size dim on the tensor has a hint
# TODO: Should this include strides too? For now it doesn't matter,
# that's quite an obscure case
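A short usage sketch for the new helper (example values made up): is_concrete_int returns True for plain ints and for SymInts whose underlying sympy expression is an integer constant.

    import torch
    from torch.fx.experimental.symbolic_shapes import is_concrete_int

    t = torch.ones(6, 4)
    # Plain Python ints are always concrete.
    assert all(is_concrete_int(s) for s in t.shape)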