Fix inconsistent test and add new tracer as config (#162558)

It is better to have the new tracer as a global config that can be manipulated easily. I also believe a dynamo-like config infra is more useful than relying on a custom way of patching things.
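For example, instead of threading a private keyword argument through export(), callers can flip the flag via the config module. A minimal sketch of the intended usage, mirroring the test changes below (the toy Plus1 module and its inputs are made up for illustration):

    import torch
    from torch._export import config


    class Plus1(torch.nn.Module):  # hypothetical toy module, not part of this PR
        def forward(self, x):
            return x + 1


    # Scoped override via the config patcher, replacing the old
    # _use_new_tracer_experimental=True keyword argument.
    with config.patch(use_new_tracer_experimental=True):
        ep = torch.export.export(Plus1(), (torch.randn(2, 2),))
    print(ep)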

Differential Revision: [D82478649](https://our.internmc.facebook.com/intern/diff/D82478649)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162558
Approved by: https://github.com/zhxchen17
ghstack dependencies: #162557
Tugsbayasgalan Manlaibaatar
2025-09-16 14:42:32 -07:00
committed by PyTorch MergeBot
parent 0e9e3cf996
commit 0e9f9c3a61
7 changed files with 61 additions and 60 deletions


@@ -29,6 +29,7 @@ from torch._decomp import decomposition_table, get_decompositions
 from torch._dynamo._trace_wrapped_higher_order_op import mod_index
 from torch._dynamo.test_case import TestCase
 from torch._dynamo.testing import normalize_gm
+from torch._export import config
 from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse
 from torch._export.utils import (
     get_buffer,
@@ -1727,22 +1728,12 @@ class GraphModule(torch.nn.Module):
         trigger = 0
         target = 2
         args = (x, trigger, target)
-        ep = export(m, args, dynamic_shapes=(None, Dim.DYNAMIC, Dim.DYNAMIC))
-        if is_training_ir_strict_test(self._testMethodName):
-            # In strict mode export's result capturing compiler, we create
-            # 2 new symints when re-fakifying the symint inputs.
-            # Then in run_decompositions, ep.range_constraints was updated
-            # where it checks the var_to_range and put the two newly added ones into the range_constraints.
-            self.assertExpectedInline(
-                str(tuple(ep.range_constraints.values())),
-                """(VR[0, int_oo], VR[0, int_oo], VR[-int_oo, int_oo], VR[-int_oo, int_oo])""",
-            )
-        else:
+        with config.patch(use_new_tracer_experimental=True):
+            ep = export(m, args, dynamic_shapes=(None, Dim.DYNAMIC, Dim.DYNAMIC))
+        self.assertExpectedInline(
+            str(tuple(ep.range_constraints.values())),
+            """(VR[0, int_oo], VR[0, int_oo])""",
+        )
         self.assertEqual(m(*args), ep.module()(*args))

     def test_cond_branches_return_same_int(self):
@@ -13812,14 +13803,14 @@ def forward(self, x, y):
         inputs = (torch.randn(10, 72),)
         dx, dy = dims("dx", "dy")
         for use_new_tracer in [True, False]:
-            ep = torch.export._trace._export(
-                Mod4Reshape(),
-                inputs,
-                dynamic_shapes={"x": (dx, dy)},
-                prefer_deferred_runtime_asserts_over_guards=True,
-                pre_dispatch=True,
-                _use_new_tracer_experimental=use_new_tracer,
-            )
+            with torch._export.config.patch(use_new_tracer_experimental=use_new_tracer):
+                ep = torch.export._trace._export(
+                    Mod4Reshape(),
+                    inputs,
+                    dynamic_shapes={"x": (dx, dy)},
+                    prefer_deferred_runtime_asserts_over_guards=True,
+                    pre_dispatch=True,
+                )
             out1 = ep.module()(torch.randn(8, 7))
             self.assertEqual(out1.shape, torch.ones(7, 4, 2).shape)
             out2 = ep.module()(torch.randn(12, 11))


@@ -6,6 +6,7 @@ except ImportError:
     import test_export  # @manual=fbcode//caffe2/test:test_export-library
     import testing  # @manual=fbcode//caffe2/test:test_export-library

+from torch._export import config
 from torch.export import export
@@ -14,12 +15,10 @@ test_classes = {}

 def mocked_strict_export_v2(*args, **kwargs):
     # If user already specified strict, don't make it strict
-    if "strict" in kwargs:
-        if kwargs["strict"]:
-            return export(*args, **kwargs, _use_new_tracer_experimental=True)
-        else:
+    with config.patch(use_new_tracer_experimental=True):
+        if "strict" in kwargs:
             return export(*args, **kwargs)
-    return export(*args, **kwargs, strict=True, _use_new_tracer_experimental=True)
+        return export(*args, **kwargs, strict=True)

 def make_dynamic_cls(cls):


@@ -9,6 +9,7 @@ from parameterized import parameterized_class
 import torch
 import torch._dynamo as torchdynamo
 from torch import Tensor
+from torch._export import config
 from torch._export.utils import register_dataclass_as_pytree_node
 from torch.export import export, register_dataclass
 from torch.export._swap import _swap_modules
@@ -378,13 +379,13 @@ def forward(self, x, y):
                 return x + torch.matmul(inputs.a, inputs.b)

         for use_new_tracer in [True, False]:
-            ep = export(
-                Foo(),
-                (torch.randn(2, 2),),
-                {"inputs": CustomInput(torch.randn(2, 3), torch.randn(3, 2))},
-                strict=self.strict,
-                _use_new_tracer_experimental=use_new_tracer,
-            )
+            with config.patch(use_new_tracer_experimental=use_new_tracer):
+                ep = export(
+                    Foo(),
+                    (torch.randn(2, 2),),
+                    {"inputs": CustomInput(torch.randn(2, 3), torch.randn(3, 2))},
+                    strict=self.strict,
+                )
             swapped = _swap_modules(ep, {})
             inp_args = (torch.randn(2, 2),)
             inp_kwargs = {"inputs": CustomInput(torch.randn(2, 3), torch.randn(3, 2))}
@@ -408,13 +409,13 @@ def forward(self, x, y):
                 return x + torch.matmul(inputs.a, inputs.b)

         # shouldn't error
-        _ = export(
-            Foo(),
-            (torch.randn(2, 2),),
-            {"inputs": CustomInput(torch.randn(2, 3), torch.randn(3, 2))},
-            strict=self.strict,
-            _use_new_tracer_experimental=True,
-        )
+        with config.patch(use_new_tracer_experimental=True):
+            _ = export(
+                Foo(),
+                (torch.randn(2, 2),),
+                {"inputs": CustomInput(torch.randn(2, 3), torch.randn(3, 2))},
+                strict=self.strict,
+            )

     def test_custom_output(self):
         @dataclass


@@ -47,6 +47,7 @@ from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
 from .wrappers import _wrap_submodules
 from .utils import _materialize_cpp_cia_ops
+from . import config

 if TYPE_CHECKING:
     from torch._C._aoti import AOTIModelContainerRunner
@@ -65,7 +66,6 @@ class ExportDynamoConfig:
 # is called multiple times.
 @lru_cache
 def aot_compile_warning():
-    from torch._inductor import config

     log.warning("+============================+")
     log.warning("| !!! WARNING !!! |")
@@ -124,11 +124,11 @@ def aot_compile(
     """
     from torch.export._trace import _export_to_torch_ir
     from torch._inductor.decomposition import select_decomp_table
-    from torch._inductor import config
+    from torch._inductor import config as inductor_config

     aot_compile_warning()

-    if config.is_predispatch:
+    if inductor_config.is_predispatch:
         gm = torch.export._trace._export(f, args, kwargs, dynamic_shapes, pre_dispatch=True).module()
     else:
         # We want to export to Torch IR here to utilize the pre_grad passes in

torch/_export/config.py (new file)

@@ -0,0 +1,26 @@
+"""
+Configuration module for torch.export.export.
+
+This module contains various configuration flags and settings that control torch.export's
+behavior, including:
+
+- Runtime behavior flags
+- Debugging and development options
+"""
+
+import sys
+from typing import Any, TYPE_CHECKING
+
+from torch.utils._config_module import install_config_module
+
+# this flag controls whether we use new functional tracer. It
+# should be True in the long term.
+use_new_tracer_experimental = False
+
+if TYPE_CHECKING:
+    from torch.utils._config_typing import *  # noqa: F401, F403
+
+    def _make_closure_patcher(**changes: Any) -> Any: ...
+
+install_config_module(sys.modules[__name__])
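Because the module is installed via install_config_module, it behaves like torch._dynamo.config: the flag can be assigned directly or overridden for a scoped region. A rough usage sketch (not part of the diff; assumes the standard ConfigModule behavior that install_config_module provides):

    from torch._export import config

    # Global override: stays in effect until changed back.
    config.use_new_tracer_experimental = True
    config.use_new_tracer_experimental = False

    # Scoped override: the previous value is restored when the block exits.
    with config.patch(use_new_tracer_experimental=True):
        assert config.use_new_tracer_experimental
    assert not config.use_new_tracer_experimental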


@@ -70,7 +70,6 @@ def export_for_training(
     strict: bool = False,
     preserve_module_call_signature: tuple[str, ...] = (),
     prefer_deferred_runtime_asserts_over_guards: bool = False,
-    _use_new_tracer_experimental: bool = False,
 ) -> ExportedProgram:
     """
     :func:`export_for_training` takes any nn.Module along with example inputs, and produces a traced graph representing
@@ -160,7 +159,6 @@ def export_for_training(
         strict=strict,
         preserve_module_call_signature=preserve_module_call_signature,
         prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
-        _use_new_tracer_experimental=_use_new_tracer_experimental,
     )
@@ -173,7 +171,6 @@ def export(
     strict: bool = False,
     preserve_module_call_signature: tuple[str, ...] = (),
     prefer_deferred_runtime_asserts_over_guards: bool = False,
-    _use_new_tracer_experimental: bool = False,
 ) -> ExportedProgram:
     """
     :func:`export` takes any nn.Module along with example inputs, and produces a traced graph representing
@@ -286,7 +283,6 @@ def export(
             preserve_module_call_signature=preserve_module_call_signature,
             pre_dispatch=True,
             prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
-            _use_new_tracer_experimental=_use_new_tracer_experimental,
         )
     except Exception as e:
         draft_export_msg = (


@@ -757,7 +757,6 @@ def _export_to_torch_ir(
     preserve_module_call_signature: tuple[str, ...] = (),
     disable_constraint_solver: bool = False,
     prefer_deferred_runtime_asserts_over_guards: bool = False,
-    _use_new_tracer_experimental: bool = False,
     restore_fqn: bool = True,
     _log_export_usage: bool = True,
     same_signature: bool = True,
@@ -815,7 +814,7 @@
         f, preserve_module_call_signature, module_call_specs
     )
     with ctx, _ignore_backend_decomps():
-        if _use_new_tracer_experimental:
+        if torch._export.config.use_new_tracer_experimental:
             from torch._dynamo.functional_export import (
                 _dynamo_graph_capture_for_export,
             )
@@ -1422,7 +1421,6 @@ def _strict_export(
     orig_in_spec: TreeSpec,
     prefer_deferred_runtime_asserts_over_guards: bool,
     _to_aten_func: Callable,
-    _use_new_tracer_experimental: bool = False,
 ) -> ExportArtifact:
     """
     _to_aten_func can either be `_export_to_aten_ir_make_fx` or `_export_to_aten_ir`
@@ -1437,7 +1435,6 @@
         restore_fqn=False,  # don't need to restore because we will do it later
         prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
         _log_export_usage=False,
-        _use_new_tracer_experimental=_use_new_tracer_experimental,
     )

     # We detect the fake_mode by looking at gm_torch_level's placeholders, this is the fake_mode created in dynamo.
@@ -2058,7 +2055,6 @@ def _export_for_training(
     strict: bool = True,
     preserve_module_call_signature: tuple[str, ...] = (),
     prefer_deferred_runtime_asserts_over_guards: bool = False,
-    _use_new_tracer_experimental: bool = False,
 ) -> ExportedProgram:
     global _EXPORT_MODULE_HIERARCHY
     _EXPORT_MODULE_HIERARCHY = _get_module_hierarchy(mod)
@@ -2074,13 +2070,7 @@
     original_state_dict = _get_original_state_dict(mod)

     # Call the appropriate export function based on the strictness of tracing.
-    export_func = (
-        functools.partial(
-            _strict_export, _use_new_tracer_experimental=_use_new_tracer_experimental
-        )
-        if strict
-        else _non_strict_export
-    )
+    export_func = _strict_export if strict else _non_strict_export

     alive_fake_input_ids_before_export: list[int] = []
@@ -2209,7 +2199,6 @@ def _export(
     preserve_module_call_signature: tuple[str, ...] = (),
     pre_dispatch: bool = False,
     prefer_deferred_runtime_asserts_over_guards: bool = False,
-    _use_new_tracer_experimental: bool = False,
 ) -> ExportedProgram:
     """
     Traces either an nn.Module's forward function or just a callable with PyTorch
@@ -2285,7 +2274,6 @@ def _export(
         strict=strict,
         preserve_module_call_signature=preserve_module_call_signature,
         prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
-        _use_new_tracer_experimental=_use_new_tracer_experimental,
     )
     dtrace_structured("exported_program", payload_fn=lambda: str(ep))
     return ep