Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-14 22:25:03 +08:00)
[export] stop gap strict export v2 enable and testing. (#167236)
Summary: Added a new flag, "use_legacy_dynamo_graph_capture", which defaults to True and is set to False only by the updated test_strict_export_v2.py. In addition to this flag, we also fall back to the legacy tracer when any of the following features is used:
1. dynamic shapes
2. preserve module call signature
3. retracing
4. draft mode

Test Plan: test_strict_export_v2.py

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167236
Approved by: https://github.com/tugsbayasgalan
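For reference, a minimal sketch of how a caller might opt into the non-legacy (strict v2) capture path by patching the export config, mirroring the mocked_strict_export_v2 wrapper in the diff below. The module M and its inputs are hypothetical; the sketch assumes both flags live in torch._export.config (as referenced in this patch) and that the usual config .patch() context manager is available there.

import torch
from torch._export import config  # export config module referenced in this patch
from torch.export import export

class M(torch.nn.Module):
    def forward(self, x):
        return x.sin() + 1

# Hypothetical usage: use_legacy_dynamo_graph_capture defaults to True, so the
# legacy dynamo tracer is used unless it is patched to False together with
# use_new_tracer_experimental=True.
with config.patch(use_new_tracer_experimental=True, use_legacy_dynamo_graph_capture=False):
    ep = export(M(), (torch.randn(4, 4),), strict=True)
print(ep)

Note that even with the flag patched to False, this patch still falls back to the legacy tracer when dynamic shapes, preserve_module_call_signature, retracing, or draft mode are in play (see the use_legacy_dynamo_graph_capture() helper added in the diff).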
Committed by: PyTorch MergeBot
Parent: 05c6a06b2b
Commit: 3260bf3b19
@@ -590,6 +590,7 @@ class TestExport(TestCase):
        inp = ([torch.ones(1, 3)], torch.ones(1, 3))
        self._test_export_same_as_eager(f, inp)

    @testing.expectedFailureStrictV2
    @skipIfCrossRef
    def test_custom_tag_metadata_re_export(self):
        class Foo(torch.nn.Module):

@@ -1026,6 +1027,7 @@ graph():
        dynamic_shapes = {"x": (dim0_x, dim1_x)}
        export(Foo(), inputs, dynamic_shapes=dynamic_shapes)

    @testing.expectedFailureStrictV2
    def test_no_tensor_computation(self):
        class Module(torch.nn.Module):
            def forward(self, x, y):

@@ -1361,6 +1363,7 @@ def forward(self, primals, tangents):
        # instead of the scripted function, so we get x.sin()
        self.assertEqual(res, x.sin())

    @testing.expectedFailureStrictV2
    def test_no_tensor_computation_2(self):
        class Module(torch.nn.Module):
            def forward(self, x, y):

@@ -1379,6 +1382,7 @@ graph():
    return (x,)""",
        )

    @testing.expectedFailureStrictV2
    def test_no_tensor_computation_3(self):
        class Module(torch.nn.Module):
            def forward(self, x, y):

@@ -1397,6 +1401,7 @@ graph():
    return (5,)""",
        )

    @testing.expectedFailureStrictV2
    def test_no_tensor_computation_4(self):
        class Module(torch.nn.Module):
            def forward(self, x, y):
@@ -1939,6 +1944,7 @@ graph():
        for vr_upper in vr_upper_bounds:
            self.assertEqual(vr_upper, 1)

    @testing.expectedFailureStrictV2
    def test_detect_leak_strict(self):
        class Foo(torch.nn.Module):
            def __init__(self):

@@ -2687,6 +2693,7 @@ class GraphModule(torch.nn.Module):
        gm = export(m, (torch.rand(64, 64),))
        torch.export.unflatten(gm)

    @testing.expectedFailureStrictV2
    def test_unflatten_closure(self):
        class Dummy(torch.nn.Module):
            def forward(self, fn, x):

@@ -4192,6 +4199,7 @@ def forward(self, p_linear_weight, p_linear_bias, x):
            if str(sym) in ["u0", "s0"]:
                self.assertEqual(vr.lower, 1)

    @testing.expectedFailureStrictV2
    def test_duplicate_modules_with_non_persistent_buffers(self):
        class FooWithBuf(torch.nn.Module):
            def __init__(self):

@@ -4835,6 +4843,7 @@ def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, b_
        table.materialize()
        self.assertFalse(torch.ops.mylib.foo123.default in table)

    @testing.expectedFailureStrictV2
    def test_if_post_autograd_op_preserved(self):
        class Foo(torch.nn.Module):
            def forward(self, x):

@@ -7223,6 +7232,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
    @testing.expectedFailureSerDer  # we don't save placeholder metadata
    @testing.expectedFailureCppSerDes  # we don't save placeholder metadata
    @testing.expectedFailureSerDerNonStrict
    @testing.expectedFailureStrictV2
    def test_linear_conv(self):
        strict = True
@@ -8821,6 +8831,7 @@ def forward(self, x):
            )
        )

    @testing.expectedFailureStrictV2
    def test_automatic_constrain_size(self):
        class M(torch.nn.Module):
            def forward(self, x, y):

@@ -8932,6 +8943,7 @@ def forward(self, x):
        ):
            ep.graph_module.while_loop_body_graph_0(torch.tensor([5]), torch.zeros(1))

    @testing.expectedFailureStrictV2
    def test_constrain_decomp(self) -> None:
        class M(torch.nn.Module):
            def __init__(self) -> None:

@@ -9570,6 +9582,7 @@ def forward(self, b_a_buffer, x):
        self.assertTrue(torch.allclose(ep.module()(xs), module_out))

    @requires_cuda_and_triton
    @testing.expectedFailureStrictV2
    def test_export_associative_scan_lifted_buffers(self):
        if "cpp_runtime_nonstrict" in self.id():
            self.skipTest("TODO Unexpected success in OSS but not in fbcode.")

@@ -9660,6 +9673,7 @@ def forward(self, b_a_buffer, x):
            len([node for node in gm.graph.nodes if node.op == "placeholder"]), 2
        )

    @testing.expectedFailureStrictV2
    def test_no_check_is_size_error(self):
        class Module(torch.nn.Module):
            def forward(self, x):

@@ -9813,6 +9827,7 @@ def forward(self, b_a_buffer, x):
        self.assertEqual(len(ep.graph_signature.input_specs), 4)
        self.assertTrue(torch.allclose(ep.module()(*inp), transform.module()(*inp)))

    @testing.expectedFailureStrictV2
    def test_tensor_attribute_zero_args(self):
        class Foo(torch.nn.Module):
            def __init__(self, value):
@@ -9826,6 +9841,7 @@ def forward(self, b_a_buffer, x):
        ep = export(m, ())
        self.assertEqual(ep.graph_signature.lifted_tensor_constants, ["x"])

    @testing.expectedFailureStrictV2
    def test_preserve_shape_dynamism_for_unused_inputs(self):
        torch.export.register_dataclass(
            Inp3,

@@ -9995,6 +10011,7 @@ def forward(self, p_lin_weight, p_lin_bias, x):
        )

    @unittest.skipIf(IS_FBCODE, "We can't customize decomp in fbcode")
    @testing.expectedFailureStrictV2
    def test_export_decomp_torture_case_2(self):
        class MyLinear(torch.nn.Module):
            def __init__(self) -> None:

@@ -10130,6 +10147,7 @@ def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, c_
            # expected 4, but got 7
            ep_v2.module()(*test_inp)

    @testing.expectedFailureStrictV2
    def test_constant_output(self):
        class ModuleConstant(torch.nn.Module):
            def __init__(self) -> None:

@@ -10214,6 +10232,7 @@ def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, c_
            # expected >= 3, but got 2
            ep.module()(*test_inp)

    @testing.expectedFailureStrictV2
    def test_nested_module(self):
        class M1(torch.nn.Module):
            def forward(self, x):

@@ -10251,6 +10270,7 @@ graph():
        unflattened = unflatten(ep)
        self.assertTrue(torch.allclose(unflattened(*inps), M2()(*inps)))

    @testing.expectedFailureStrictV2
    def test_nested_module_with_init_buffer(self):
        class M1(torch.nn.Module):
            def __init__(self) -> None:
@@ -10378,6 +10398,7 @@ graph():
        ep = export(m, sample_inputs)
        self.assertEqual(ep.module()(*sample_inputs), m(*sample_inputs))

    @testing.expectedFailureStrictV2
    def test_lazy_module_kwargs(self):
        class LazyModule(torch.nn.modules.lazy.LazyModuleMixin, torch.nn.Module):
            def initialize_parameters(self, *args, **kwargs):

@@ -12251,6 +12272,7 @@ graph():
        ep.module()(x)

    @testing.expectedFailureCppRuntime
    @testing.expectedFailureStrictV2
    def test_symint_input_basic(self):
        class M(torch.nn.Module):
            def forward(self, x, y):

@@ -12970,6 +12992,7 @@ def forward(self, c_submod_params, x):
        ufm = torch.export.unflatten(ep)
        self.assertTrue(torch.allclose(ufm(*inp), epm(*inp)))

    @testing.expectedFailureStrictV2
    def test_unflatten_multiple_graphs_shared_submodule(self):
        class N(torch.nn.Module):
            def forward(self, x, b):

@@ -14021,6 +14044,7 @@ def forward(self, x):
    return (foo_functional,)""",
        )

    @testing.expectedFailureStrictV2
    def test_placeholder_naming_order(self):
        # See https://github.com/pytorch/pytorch/issues/143732

@@ -14072,6 +14096,7 @@ def forward(self, x):
        ).run_decompositions()
        ep.module()(torch.ones(4, 4), **kwargs)

    @testing.expectedFailureStrictV2
    def test_placeholder_naming_order_variadic(self):
        class Mod(torch.nn.Module):
            def forward(self, a, b, c, **kwargs):
@@ -14096,6 +14121,7 @@ def forward(self, x):
        ):
            export(Foo(), (torch.randn(4, 4),), strict=False)

    @testing.expectedFailureStrictV2
    def test_placeholder_naming_collisions(self):
        # test collisions between nested user inputs
        class Foo(torch.nn.Module):

@@ -14168,6 +14194,7 @@ def forward(self, x):
        self.assertEqual(expected_names_and_ops, real_names_and_ops)

    @skipIfCrossRef  # Dynamo changes the order of ops under Torch function modes
    @testing.expectedFailureStrictV2
    def test_placeholder_naming_collisions_hoo_subgraphs(self):
        # test collisions between user inputs, top-level nodes, and HOO subgraph nodes
        class Foo(torch.nn.Module):

@@ -14245,6 +14272,7 @@ def forward(self, x):
        ]
        self.assertEqual(expected_getattr_names, real_getattr_names)

    @testing.expectedFailureStrictV2
    def test_constant_input_naming(self):
        class Foo(torch.nn.Module):
            def forward(self, x, y, div="floor"):

@@ -14936,6 +14964,7 @@ graph():
        ]
        self.assertEqual(len(repeat_nodes), 0)

    @testing.expectedFailureStrictV2
    def test_checks_to_constrain_range(self):
        class Foo(torch.nn.Module):
            def forward(self, x, y):

@@ -15270,6 +15299,7 @@ graph():
            Block(torch.randn(4, 4), torch.randn(4, 4))
        )

    @testing.expectedFailureStrictV2
    def test_enum_str(self):
        class TensorDim(str, enum.Enum):
            DDP = "ddp"
@@ -15431,6 +15461,7 @@ def forward(self, x):
    return (getitem_3, cos_1)""",
        )

    @testing.expectedFailureStrictV2
    def test_run_decompositions_keep_metadata(self):
        """Make sure the metadata is kept after exported program run_decompositions."""

@@ -15460,6 +15491,7 @@ def forward(self, x):
        for node in decomposed_program.graph.nodes:
            self.assertEqual(node.meta["custom"]["my_field"], "dummy")

    @testing.expectedFailureStrictV2
    def test_run_decompositions_keep_tensor_constant_metadata(self):
        """Make sure the metadata of tensor constants are kept after run_decompositions."""

@@ -16091,6 +16123,7 @@ def forward(self, x):

    @testing.expectedFailureSerDer  # T195866111
    @testing.expectedFailureSerDerNonStrict
    @testing.expectedFailureStrictV2
    def test_hints_wrapper(self):
        strict = True

@@ -16665,6 +16698,7 @@ def forward(self, args_0):
    return (abs_1,)""",
        )

    @testing.expectedFailureStrictV2
    def test_sdpa_gqa(self):
        from torch.nn.attention import sdpa_kernel, SDPBackend
@@ -15,7 +15,7 @@ test_classes = {}

def mocked_strict_export_v2(*args, **kwargs):
    # If user already specified strict, don't make it strict
    with config.patch(use_new_tracer_experimental=True):
        with config.patch(use_legacy_dynamo_graph_capture=False):
            if "strict" in kwargs:
                return export(*args, **kwargs)
            return export(*args, **kwargs, strict=True)
@@ -1001,10 +1001,24 @@ def get_traced_fn(mod: Any) -> tuple[FunctionType, Optional[object]]:
    import inspect

    if isinstance(mod, torch.nn.Module):
        if len(mod._forward_pre_hooks) == 0 and len(mod._forward_hooks) == 0:
        # Mirrored from NNModuleVariable.call_function:
        # https://github.com/pytorch/pytorch/blob/main/torch/_dynamo/variables/nn_module.py#L1035
        if (
            len(mod._forward_pre_hooks) == 0
            and len(mod._forward_hooks) == 0
            and len(torch.nn.modules.module._global_forward_pre_hooks) == 0
            and len(torch.nn.modules.module._global_forward_hooks) == 0
            and len(mod._backward_pre_hooks) == 0
            and len(mod._backward_hooks) == 0
            and len(torch.nn.modules.module._global_backward_pre_hooks) == 0
            and len(torch.nn.modules.module._global_backward_hooks) == 0
        ):
            mod = mod.forward
    elif isinstance(mod, torch.fx.GraphModule):
        mod = mod._call_impl
    else:
        mod = mod.__call__

    if hasattr(mod, "__self__"):
        # pyrefly: ignore [missing-attribute]
        return mod.__func__, mod.__self__
@@ -637,7 +637,7 @@ def dynamo_graph_capture_for_export(
        pyt.in_shuffle_graph,
        pyt.out_shuffle_graph,
        tree_leaf_names,
        pyt.root,
        graph_module if isinstance(pyt.root, torch.nn.Module) else pyt.root,
    )  # type: ignore[attr-defined]
    normalize_graph_module(graph_module)
    if pyt.root is not None:

@@ -648,6 +648,10 @@ def dynamo_graph_capture_for_export(
        graph_module._non_persistent_buffers_set = (
            pyt.root._non_persistent_buffers_set.copy()
        )
        annotations = torch.nn.Module.__dict__.get("__annotations__", None)
        for name, value in pyt.root.__dict__.items():
            if annotations and name not in annotations:
                graph_module.__dict__[name] = value
    graph_module._in_spec = pyt.in_spec
    graph_module._out_spec = pyt.out_spec
    assert not hasattr(graph_module, "_in_shuffle_graph")
@@ -33,6 +33,9 @@ error_on_lifted_constant_tensors = True
# being ready to handle auto_functionalized_v2.
enable_auto_functionalized_v2_for_export = not is_fbcode()

use_legacy_dynamo_graph_capture = True


if TYPE_CHECKING:
    from torch.utils._config_typing import *  # noqa: F401, F403
@@ -12,6 +12,7 @@ from collections.abc import Callable
from contextlib import contextmanager, ExitStack, nullcontext
from itertools import chain
from typing import Any, Optional, TYPE_CHECKING, TypeAlias, Union
from unittest import mock


if TYPE_CHECKING:
@@ -274,6 +275,24 @@ def _extract_fake_inputs(gm, args, kwargs):
        else:
            fake_vals.append(node.meta.get("example_value"))

    if in_shuffle_graph := getattr(gm, "_in_shuffle_graph", None):
        flat_args = pytree.tree_leaves((args, kwargs))
        node_map = {
            node: i
            for i, node in enumerate(
                next(iter(reversed(in_shuffle_graph.graph.nodes))).args[0]
            )
            if node.op == "placeholder"
        }
        new_fake_inps: list[Any] = []
        for i, node in enumerate(
            in_shuffle_graph.graph.find_nodes(op="placeholder")[1:]
        ):
            if node in node_map:
                new_fake_inps.append(fake_inps[node_map[node]])
            else:
                new_fake_inps.append(flat_args[i])
        fake_inps = new_fake_inps
    # We get both because now we might have a combination of symint and tensor
    # inputs, and we want to check that the shape env is consistent between
    # both. Unfortunately we can't see what fake mode is attached to the shape
@@ -798,6 +817,16 @@ def _export_to_torch_ir(
        prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
    )

    def use_legacy_dynamo_graph_capture() -> bool:
        return bool(
            constraints  # dynamic shape
            or dynamic_shapes  # dynamic shape
            or isinstance(f, torch.fx.GraphModule)  # retracing
            or preserve_module_call_signature  # unflatten
            or torch._functorch.config.fake_tensor_propagate_real_tensors  # draft
            or torch._export.config.use_legacy_dynamo_graph_capture
        )

    with torch._dynamo.config.patch(dataclasses.asdict(dynamo_cfg)):
        try:
            module_call_specs: dict[str, dict[str, pytree.TreeSpec]] = (
@@ -812,11 +841,20 @@ def _export_to_torch_ir(
            if torch._export.config.use_new_tracer_experimental:
                from torch._dynamo.functional_export import (
                    _dynamo_graph_capture_for_export,
                    dynamo_graph_capture_for_export,
                )

                gm_torch_level = _dynamo_graph_capture_for_export(
                    f, constraints=constraints, dynamic_shapes=dynamic_shapes
                )(*args, **kwargs)
                if use_legacy_dynamo_graph_capture():
                    dynamo_graph_capture = _dynamo_graph_capture_for_export(
                        f, constraints=constraints, dynamic_shapes=dynamic_shapes
                    )
                else:
                    dynamo_graph_capture = dynamo_graph_capture_for_export(f)
                # We can't serialize entire fake mode yet, so this is to make sure
                # things like copy.deepcopy(ep.graph_module) not crash.
                # see test_export.py::test_custom_tag_metadata_re_export
                # Once we delete the old strict export, we can use
                gm_torch_level = dynamo_graph_capture(*args, **kwargs)
                # We can't serialize entire fake mode yet, so this is to make sure
                # things like copy.deepcopy(ep.graph_module) not crash.
                # see test_export.py::test_custom_tag_metadata_re_export
@@ -1568,7 +1606,11 @@ def _strict_export(
    }

    tx = TracingContext(dynamo_fake_mode)
    with dynamo_fake_mode, tracing(tx):
    with (
        dynamo_fake_mode,
        tracing(tx),
        mock.patch.object(dynamo_fake_mode, "allow_non_fake_inputs", True),
    ):
        aten_export_artifact = _to_aten_func(
            gm_torch_level,
            # NOTE: graph module expects only positional args
@@ -1709,8 +1709,11 @@ def _convert_guards_to_code(graph_module):
    py_printer = torch.fx.experimental.symbolic_shapes.ShapeGuardPythonPrinter(
        shape_env.var_to_sources, lambda s: s.name(), shape_env.var_to_sources
    )
    return [
    ret = [
        py_printer.doprint(guard.expr)
        for guard in shape_env.guards
        if guard.expr.free_symbols.issubset(local_vars)
    ]
    # TODO Figure out how to resolve guards containing weight sizes.
    # This is not a big deal as _guards_code is mostly empty today.
    return [guard for guard in ret if "L['self']" not in guard]