[export] Stopgap: enable and test strict export v2. (#167236)

Summary:
Added a new flag, "use_legacy_dynamo_graph_capture", which defaults to True and is set to False only in the updated test_strict_export_v2.py.
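
As an illustration, the v2 path can be exercised by patching the export config the same way the updated test does. This is a minimal sketch, assuming config here refers to torch._export.config as in the test diff further below:

import torch
import torch._export.config as config

class M(torch.nn.Module):
    def forward(self, x):
        return x.sin()

# Mirrors the two nested config.patch calls in mocked_strict_export_v2 below;
# folding them into a single patch is equivalent.
with config.patch(use_new_tracer_experimental=True, use_legacy_dynamo_graph_capture=False):
    ep = torch.export.export(M(), (torch.randn(4, 4),), strict=True)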

In addition to this flag, we also fall back to the legacy tracer when any of the following features are used (see the sketch after this list):
1. dynamic shapes
2. preserve module call signature
3. retracing
4. draft mode
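
Conceptually, the fallback decision added to torch/export/_trace.py reduces to the helper sketched here; the parameter names stand in for the closure variables of _export_to_torch_ir (the actual helper appears in the diff further below):

import torch

def _should_use_legacy_capture(f, constraints, dynamic_shapes,
                               preserve_module_call_signature) -> bool:
    # Sketch of use_legacy_dynamo_graph_capture() added in this PR.
    return bool(
        constraints                              # dynamic shapes via constraints
        or dynamic_shapes                        # dynamic shapes
        or isinstance(f, torch.fx.GraphModule)   # retracing
        or preserve_module_call_signature        # needed for unflatten
        or torch._functorch.config.fake_tensor_propagate_real_tensors  # draft mode
        or torch._export.config.use_legacy_dynamo_graph_capture        # flag default: True
    )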

Test Plan:
test_strict_export_v2.py

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167236
Approved by: https://github.com/tugsbayasgalan
Zhengxu Chen
2025-11-12 03:33:37 +00:00
committed by PyTorch MergeBot
parent 05c6a06b2b
commit 3260bf3b19
7 changed files with 108 additions and 8 deletions

View File

@ -590,6 +590,7 @@ class TestExport(TestCase):
inp = ([torch.ones(1, 3)], torch.ones(1, 3))
self._test_export_same_as_eager(f, inp)
@testing.expectedFailureStrictV2
@skipIfCrossRef
def test_custom_tag_metadata_re_export(self):
class Foo(torch.nn.Module):
@ -1026,6 +1027,7 @@ graph():
dynamic_shapes = {"x": (dim0_x, dim1_x)}
export(Foo(), inputs, dynamic_shapes=dynamic_shapes)
@testing.expectedFailureStrictV2
def test_no_tensor_computation(self):
class Module(torch.nn.Module):
def forward(self, x, y):
@ -1361,6 +1363,7 @@ def forward(self, primals, tangents):
# instead of the scripted function, so we get x.sin()
self.assertEqual(res, x.sin())
@testing.expectedFailureStrictV2
def test_no_tensor_computation_2(self):
class Module(torch.nn.Module):
def forward(self, x, y):
@ -1379,6 +1382,7 @@ graph():
return (x,)""",
)
@testing.expectedFailureStrictV2
def test_no_tensor_computation_3(self):
class Module(torch.nn.Module):
def forward(self, x, y):
@ -1397,6 +1401,7 @@ graph():
return (5,)""",
)
@testing.expectedFailureStrictV2
def test_no_tensor_computation_4(self):
class Module(torch.nn.Module):
def forward(self, x, y):
@ -1939,6 +1944,7 @@ graph():
for vr_upper in vr_upper_bounds:
self.assertEqual(vr_upper, 1)
@testing.expectedFailureStrictV2
def test_detect_leak_strict(self):
class Foo(torch.nn.Module):
def __init__(self):
@ -2687,6 +2693,7 @@ class GraphModule(torch.nn.Module):
gm = export(m, (torch.rand(64, 64),))
torch.export.unflatten(gm)
@testing.expectedFailureStrictV2
def test_unflatten_closure(self):
class Dummy(torch.nn.Module):
def forward(self, fn, x):
@ -4192,6 +4199,7 @@ def forward(self, p_linear_weight, p_linear_bias, x):
if str(sym) in ["u0", "s0"]:
self.assertEqual(vr.lower, 1)
@testing.expectedFailureStrictV2
def test_duplicate_modules_with_non_persistent_buffers(self):
class FooWithBuf(torch.nn.Module):
def __init__(self):
@ -4835,6 +4843,7 @@ def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, b_
table.materialize()
self.assertFalse(torch.ops.mylib.foo123.default in table)
@testing.expectedFailureStrictV2
def test_if_post_autograd_op_preserved(self):
class Foo(torch.nn.Module):
def forward(self, x):
@ -7223,6 +7232,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
@testing.expectedFailureSerDer # we don't save placeholder metadata
@testing.expectedFailureCppSerDes # we don't save placeholder metadata
@testing.expectedFailureSerDerNonStrict
@testing.expectedFailureStrictV2
def test_linear_conv(self):
strict = True
@ -8821,6 +8831,7 @@ def forward(self, x):
)
)
@testing.expectedFailureStrictV2
def test_automatic_constrain_size(self):
class M(torch.nn.Module):
def forward(self, x, y):
@ -8932,6 +8943,7 @@ def forward(self, x):
):
ep.graph_module.while_loop_body_graph_0(torch.tensor([5]), torch.zeros(1))
@testing.expectedFailureStrictV2
def test_constrain_decomp(self) -> None:
class M(torch.nn.Module):
def __init__(self) -> None:
@ -9570,6 +9582,7 @@ def forward(self, b_a_buffer, x):
self.assertTrue(torch.allclose(ep.module()(xs), module_out))
@requires_cuda_and_triton
@testing.expectedFailureStrictV2
def test_export_associative_scan_lifted_buffers(self):
if "cpp_runtime_nonstrict" in self.id():
self.skipTest("TODO Unexpected success in OSS but not in fbcode.")
@ -9660,6 +9673,7 @@ def forward(self, b_a_buffer, x):
len([node for node in gm.graph.nodes if node.op == "placeholder"]), 2
)
@testing.expectedFailureStrictV2
def test_no_check_is_size_error(self):
class Module(torch.nn.Module):
def forward(self, x):
@ -9813,6 +9827,7 @@ def forward(self, b_a_buffer, x):
self.assertEqual(len(ep.graph_signature.input_specs), 4)
self.assertTrue(torch.allclose(ep.module()(*inp), transform.module()(*inp)))
@testing.expectedFailureStrictV2
def test_tensor_attribute_zero_args(self):
class Foo(torch.nn.Module):
def __init__(self, value):
@ -9826,6 +9841,7 @@ def forward(self, b_a_buffer, x):
ep = export(m, ())
self.assertEqual(ep.graph_signature.lifted_tensor_constants, ["x"])
@testing.expectedFailureStrictV2
def test_preserve_shape_dynamism_for_unused_inputs(self):
torch.export.register_dataclass(
Inp3,
@ -9995,6 +10011,7 @@ def forward(self, p_lin_weight, p_lin_bias, x):
)
@unittest.skipIf(IS_FBCODE, "We can't customize decomp in fbcode")
@testing.expectedFailureStrictV2
def test_export_decomp_torture_case_2(self):
class MyLinear(torch.nn.Module):
def __init__(self) -> None:
@ -10130,6 +10147,7 @@ def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, c_
# expected 4, but got 7
ep_v2.module()(*test_inp)
@testing.expectedFailureStrictV2
def test_constant_output(self):
class ModuleConstant(torch.nn.Module):
def __init__(self) -> None:
@ -10214,6 +10232,7 @@ def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, c_
# expected >= 3, but got 2
ep.module()(*test_inp)
@testing.expectedFailureStrictV2
def test_nested_module(self):
class M1(torch.nn.Module):
def forward(self, x):
@ -10251,6 +10270,7 @@ graph():
unflattened = unflatten(ep)
self.assertTrue(torch.allclose(unflattened(*inps), M2()(*inps)))
@testing.expectedFailureStrictV2
def test_nested_module_with_init_buffer(self):
class M1(torch.nn.Module):
def __init__(self) -> None:
@ -10378,6 +10398,7 @@ graph():
ep = export(m, sample_inputs)
self.assertEqual(ep.module()(*sample_inputs), m(*sample_inputs))
@testing.expectedFailureStrictV2
def test_lazy_module_kwargs(self):
class LazyModule(torch.nn.modules.lazy.LazyModuleMixin, torch.nn.Module):
def initialize_parameters(self, *args, **kwargs):
@ -12251,6 +12272,7 @@ graph():
ep.module()(x)
@testing.expectedFailureCppRuntime
@testing.expectedFailureStrictV2
def test_symint_input_basic(self):
class M(torch.nn.Module):
def forward(self, x, y):
@ -12970,6 +12992,7 @@ def forward(self, c_submod_params, x):
ufm = torch.export.unflatten(ep)
self.assertTrue(torch.allclose(ufm(*inp), epm(*inp)))
@testing.expectedFailureStrictV2
def test_unflatten_multiple_graphs_shared_submodule(self):
class N(torch.nn.Module):
def forward(self, x, b):
@ -14021,6 +14044,7 @@ def forward(self, x):
return (foo_functional,)""",
)
@testing.expectedFailureStrictV2
def test_placeholder_naming_order(self):
# See https://github.com/pytorch/pytorch/issues/143732
@ -14072,6 +14096,7 @@ def forward(self, x):
).run_decompositions()
ep.module()(torch.ones(4, 4), **kwargs)
@testing.expectedFailureStrictV2
def test_placeholder_naming_order_variadic(self):
class Mod(torch.nn.Module):
def forward(self, a, b, c, **kwargs):
@ -14096,6 +14121,7 @@ def forward(self, x):
):
export(Foo(), (torch.randn(4, 4),), strict=False)
@testing.expectedFailureStrictV2
def test_placeholder_naming_collisions(self):
# test collisions between nested user inputs
class Foo(torch.nn.Module):
@ -14168,6 +14194,7 @@ def forward(self, x):
self.assertEqual(expected_names_and_ops, real_names_and_ops)
@skipIfCrossRef # Dynamo changes the order of ops under Torch function modes
@testing.expectedFailureStrictV2
def test_placeholder_naming_collisions_hoo_subgraphs(self):
# test collisions between user inputs, top-level nodes, and HOO subgraph nodes
class Foo(torch.nn.Module):
@ -14245,6 +14272,7 @@ def forward(self, x):
]
self.assertEqual(expected_getattr_names, real_getattr_names)
@testing.expectedFailureStrictV2
def test_constant_input_naming(self):
class Foo(torch.nn.Module):
def forward(self, x, y, div="floor"):
@ -14936,6 +14964,7 @@ graph():
]
self.assertEqual(len(repeat_nodes), 0)
@testing.expectedFailureStrictV2
def test_checks_to_constrain_range(self):
class Foo(torch.nn.Module):
def forward(self, x, y):
@ -15270,6 +15299,7 @@ graph():
Block(torch.randn(4, 4), torch.randn(4, 4))
)
@testing.expectedFailureStrictV2
def test_enum_str(self):
class TensorDim(str, enum.Enum):
DDP = "ddp"
@ -15431,6 +15461,7 @@ def forward(self, x):
return (getitem_3, cos_1)""",
)
@testing.expectedFailureStrictV2
def test_run_decompositions_keep_metadata(self):
"""Make sure the metadata is kept after exported program run_decompositions."""
@ -15460,6 +15491,7 @@ def forward(self, x):
for node in decomposed_program.graph.nodes:
self.assertEqual(node.meta["custom"]["my_field"], "dummy")
@testing.expectedFailureStrictV2
def test_run_decompositions_keep_tensor_constant_metadata(self):
"""Make sure the metadata of tensor constants are kept after run_decompositions."""
@ -16091,6 +16123,7 @@ def forward(self, x):
@testing.expectedFailureSerDer # T195866111
@testing.expectedFailureSerDerNonStrict
@testing.expectedFailureStrictV2
def test_hints_wrapper(self):
strict = True
@ -16665,6 +16698,7 @@ def forward(self, args_0):
return (abs_1,)""",
)
@testing.expectedFailureStrictV2
def test_sdpa_gqa(self):
from torch.nn.attention import sdpa_kernel, SDPBackend

View File

@ -15,7 +15,7 @@ test_classes = {}
def mocked_strict_export_v2(*args, **kwargs):
# If user already specified strict, don't make it strict
with config.patch(use_new_tracer_experimental=True):
with config.patch(use_legacy_dynamo_graph_capture=False):
if "strict" in kwargs:
return export(*args, **kwargs)
return export(*args, **kwargs, strict=True)

View File

@ -1001,10 +1001,24 @@ def get_traced_fn(mod: Any) -> tuple[FunctionType, Optional[object]]:
import inspect
if isinstance(mod, torch.nn.Module):
if len(mod._forward_pre_hooks) == 0 and len(mod._forward_hooks) == 0:
# Mirrored from NNModuleVariable.call_function:
# https://github.com/pytorch/pytorch/blob/main/torch/_dynamo/variables/nn_module.py#L1035
if (
len(mod._forward_pre_hooks) == 0
and len(mod._forward_hooks) == 0
and len(torch.nn.modules.module._global_forward_pre_hooks) == 0
and len(torch.nn.modules.module._global_forward_hooks) == 0
and len(mod._backward_pre_hooks) == 0
and len(mod._backward_hooks) == 0
and len(torch.nn.modules.module._global_backward_pre_hooks) == 0
and len(torch.nn.modules.module._global_backward_hooks) == 0
):
mod = mod.forward
elif isinstance(mod, torch.fx.GraphModule):
mod = mod._call_impl
else:
mod = mod.__call__
if hasattr(mod, "__self__"):
# pyrefly: ignore [missing-attribute]
return mod.__func__, mod.__self__

View File

@ -637,7 +637,7 @@ def dynamo_graph_capture_for_export(
pyt.in_shuffle_graph,
pyt.out_shuffle_graph,
tree_leaf_names,
pyt.root,
graph_module if isinstance(pyt.root, torch.nn.Module) else pyt.root,
) # type: ignore[attr-defined]
normalize_graph_module(graph_module)
if pyt.root is not None:
@ -648,6 +648,10 @@ def dynamo_graph_capture_for_export(
graph_module._non_persistent_buffers_set = (
pyt.root._non_persistent_buffers_set.copy()
)
annotations = torch.nn.Module.__dict__.get("__annotations__", None)
for name, value in pyt.root.__dict__.items():
if annotations and name not in annotations:
graph_module.__dict__[name] = value
graph_module._in_spec = pyt.in_spec
graph_module._out_spec = pyt.out_spec
assert not hasattr(graph_module, "_in_shuffle_graph")

View File

@ -33,6 +33,9 @@ error_on_lifted_constant_tensors = True
# being ready to handle auto_functionalized_v2.
enable_auto_functionalized_v2_for_export = not is_fbcode()
use_legacy_dynamo_graph_capture = True
if TYPE_CHECKING:
from torch.utils._config_typing import * # noqa: F401, F403

View File

@ -12,6 +12,7 @@ from collections.abc import Callable
from contextlib import contextmanager, ExitStack, nullcontext
from itertools import chain
from typing import Any, Optional, TYPE_CHECKING, TypeAlias, Union
from unittest import mock
if TYPE_CHECKING:
@ -274,6 +275,24 @@ def _extract_fake_inputs(gm, args, kwargs):
else:
fake_vals.append(node.meta.get("example_value"))
if in_shuffle_graph := getattr(gm, "_in_shuffle_graph", None):
flat_args = pytree.tree_leaves((args, kwargs))
node_map = {
node: i
for i, node in enumerate(
next(iter(reversed(in_shuffle_graph.graph.nodes))).args[0]
)
if node.op == "placeholder"
}
new_fake_inps: list[Any] = []
for i, node in enumerate(
in_shuffle_graph.graph.find_nodes(op="placeholder")[1:]
):
if node in node_map:
new_fake_inps.append(fake_inps[node_map[node]])
else:
new_fake_inps.append(flat_args[i])
fake_inps = new_fake_inps
# We get both because now we might have a combination of symint and tensor
# inputs, and we want to check that the shape env is consistent between
# both. Unfortunately we can't see what fake mode is attached to the shape
@ -798,6 +817,16 @@ def _export_to_torch_ir(
prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
)
def use_legacy_dynamo_graph_capture() -> bool:
return bool(
constraints # dynamic shape
or dynamic_shapes # dynamic shape
or isinstance(f, torch.fx.GraphModule) # retracing
or preserve_module_call_signature # unflatten
or torch._functorch.config.fake_tensor_propagate_real_tensors # draft
or torch._export.config.use_legacy_dynamo_graph_capture
)
with torch._dynamo.config.patch(dataclasses.asdict(dynamo_cfg)):
try:
module_call_specs: dict[str, dict[str, pytree.TreeSpec]] = (
@ -812,11 +841,20 @@ def _export_to_torch_ir(
if torch._export.config.use_new_tracer_experimental:
from torch._dynamo.functional_export import (
_dynamo_graph_capture_for_export,
dynamo_graph_capture_for_export,
)
gm_torch_level = _dynamo_graph_capture_for_export(
f, constraints=constraints, dynamic_shapes=dynamic_shapes
)(*args, **kwargs)
if use_legacy_dynamo_graph_capture():
dynamo_graph_capture = _dynamo_graph_capture_for_export(
f, constraints=constraints, dynamic_shapes=dynamic_shapes
)
else:
dynamo_graph_capture = dynamo_graph_capture_for_export(f)
# We can't serialize entire fake mode yet, so this is to make sure
# things like copy.deepcopy(ep.graph_module) not crash.
# see test_export.py::test_custom_tag_metadata_re_export
# Once we delete the old strict export, we can use
gm_torch_level = dynamo_graph_capture(*args, **kwargs)
# We can't serialize entire fake mode yet, so this is to make sure
# things like copy.deepcopy(ep.graph_module) not crash.
# see test_export.py::test_custom_tag_metadata_re_export
@ -1568,7 +1606,11 @@ def _strict_export(
}
tx = TracingContext(dynamo_fake_mode)
with dynamo_fake_mode, tracing(tx):
with (
dynamo_fake_mode,
tracing(tx),
mock.patch.object(dynamo_fake_mode, "allow_non_fake_inputs", True),
):
aten_export_artifact = _to_aten_func(
gm_torch_level,
# NOTE: graph module expects only positional args

View File

@ -1709,8 +1709,11 @@ def _convert_guards_to_code(graph_module):
py_printer = torch.fx.experimental.symbolic_shapes.ShapeGuardPythonPrinter(
shape_env.var_to_sources, lambda s: s.name(), shape_env.var_to_sources
)
return [
ret = [
py_printer.doprint(guard.expr)
for guard in shape_env.guards
if guard.expr.free_symbols.issubset(local_vars)
]
# TODO Figure out how to resolve guards containing weight sizes.
# This is not a big deal as _guards_code is mostly empty today.
return [guard for guard in ret if "L['self']" not in guard]