mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix inconsistent test and add new tracer as config (#162558)
It is better to have the new tracer as global config that can be manipulated easily. Also I believe dynamo-like config infra is useful instead of relying on custom way of patching stuff. Differential Revision: [D82478649](https://our.internmc.facebook.com/intern/diff/D82478649) Pull Request resolved: https://github.com/pytorch/pytorch/pull/162558 Approved by: https://github.com/zhxchen17 ghstack dependencies: #162557
This commit is contained in:
committed by
PyTorch MergeBot
parent
0e9e3cf996
commit
0e9f9c3a61
@ -29,6 +29,7 @@ from torch._decomp import decomposition_table, get_decompositions
|
||||
from torch._dynamo._trace_wrapped_higher_order_op import mod_index
|
||||
from torch._dynamo.test_case import TestCase
|
||||
from torch._dynamo.testing import normalize_gm
|
||||
from torch._export import config
|
||||
from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse
|
||||
from torch._export.utils import (
|
||||
get_buffer,
|
||||
@ -1727,22 +1728,12 @@ class GraphModule(torch.nn.Module):
|
||||
trigger = 0
|
||||
target = 2
|
||||
args = (x, trigger, target)
|
||||
ep = export(m, args, dynamic_shapes=(None, Dim.DYNAMIC, Dim.DYNAMIC))
|
||||
if is_training_ir_strict_test(self._testMethodName):
|
||||
# In strict mode export's result capturing compiler, we create
|
||||
# 2 new symints when re-fakifying the symint inputs.
|
||||
# Then in run_decompositions, ep.range_constraints was updated
|
||||
# where it checks the var_to_range and put the two newly added ones into the range_constraints.
|
||||
self.assertExpectedInline(
|
||||
str(tuple(ep.range_constraints.values())),
|
||||
"""(VR[0, int_oo], VR[0, int_oo], VR[-int_oo, int_oo], VR[-int_oo, int_oo])""",
|
||||
)
|
||||
else:
|
||||
with config.patch(use_new_tracer_experimental=True):
|
||||
ep = export(m, args, dynamic_shapes=(None, Dim.DYNAMIC, Dim.DYNAMIC))
|
||||
self.assertExpectedInline(
|
||||
str(tuple(ep.range_constraints.values())),
|
||||
"""(VR[0, int_oo], VR[0, int_oo])""",
|
||||
)
|
||||
|
||||
self.assertEqual(m(*args), ep.module()(*args))
|
||||
|
||||
def test_cond_branches_return_same_int(self):
|
||||
@ -13812,14 +13803,14 @@ def forward(self, x, y):
|
||||
inputs = (torch.randn(10, 72),)
|
||||
dx, dy = dims("dx", "dy")
|
||||
for use_new_tracer in [True, False]:
|
||||
ep = torch.export._trace._export(
|
||||
Mod4Reshape(),
|
||||
inputs,
|
||||
dynamic_shapes={"x": (dx, dy)},
|
||||
prefer_deferred_runtime_asserts_over_guards=True,
|
||||
pre_dispatch=True,
|
||||
_use_new_tracer_experimental=use_new_tracer,
|
||||
)
|
||||
with torch._export.config.patch(use_new_tracer_experimental=use_new_tracer):
|
||||
ep = torch.export._trace._export(
|
||||
Mod4Reshape(),
|
||||
inputs,
|
||||
dynamic_shapes={"x": (dx, dy)},
|
||||
prefer_deferred_runtime_asserts_over_guards=True,
|
||||
pre_dispatch=True,
|
||||
)
|
||||
out1 = ep.module()(torch.randn(8, 7))
|
||||
self.assertEqual(out1.shape, torch.ones(7, 4, 2).shape)
|
||||
out2 = ep.module()(torch.randn(12, 11))
|
||||
|
@ -6,6 +6,7 @@ except ImportError:
|
||||
import test_export # @manual=fbcode//caffe2/test:test_export-library
|
||||
import testing # @manual=fbcode//caffe2/test:test_export-library
|
||||
|
||||
from torch._export import config
|
||||
from torch.export import export
|
||||
|
||||
|
||||
@ -14,12 +15,10 @@ test_classes = {}
|
||||
|
||||
def mocked_strict_export_v2(*args, **kwargs):
|
||||
# If user already specified strict, don't make it strict
|
||||
if "strict" in kwargs:
|
||||
if kwargs["strict"]:
|
||||
return export(*args, **kwargs, _use_new_tracer_experimental=True)
|
||||
else:
|
||||
with config.patch(use_new_tracer_experimental=True):
|
||||
if "strict" in kwargs:
|
||||
return export(*args, **kwargs)
|
||||
return export(*args, **kwargs, strict=True, _use_new_tracer_experimental=True)
|
||||
return export(*args, **kwargs, strict=True)
|
||||
|
||||
|
||||
def make_dynamic_cls(cls):
|
||||
|
@ -9,6 +9,7 @@ from parameterized import parameterized_class
|
||||
import torch
|
||||
import torch._dynamo as torchdynamo
|
||||
from torch import Tensor
|
||||
from torch._export import config
|
||||
from torch._export.utils import register_dataclass_as_pytree_node
|
||||
from torch.export import export, register_dataclass
|
||||
from torch.export._swap import _swap_modules
|
||||
@ -378,13 +379,13 @@ def forward(self, x, y):
|
||||
return x + torch.matmul(inputs.a, inputs.b)
|
||||
|
||||
for use_new_tracer in [True, False]:
|
||||
ep = export(
|
||||
Foo(),
|
||||
(torch.randn(2, 2),),
|
||||
{"inputs": CustomInput(torch.randn(2, 3), torch.randn(3, 2))},
|
||||
strict=self.strict,
|
||||
_use_new_tracer_experimental=use_new_tracer,
|
||||
)
|
||||
with config.patch(use_new_tracer_experimental=use_new_tracer):
|
||||
ep = export(
|
||||
Foo(),
|
||||
(torch.randn(2, 2),),
|
||||
{"inputs": CustomInput(torch.randn(2, 3), torch.randn(3, 2))},
|
||||
strict=self.strict,
|
||||
)
|
||||
swapped = _swap_modules(ep, {})
|
||||
inp_args = (torch.randn(2, 2),)
|
||||
inp_kwargs = {"inputs": CustomInput(torch.randn(2, 3), torch.randn(3, 2))}
|
||||
@ -408,13 +409,13 @@ def forward(self, x, y):
|
||||
return x + torch.matmul(inputs.a, inputs.b)
|
||||
|
||||
# shouldn't error
|
||||
_ = export(
|
||||
Foo(),
|
||||
(torch.randn(2, 2),),
|
||||
{"inputs": CustomInput(torch.randn(2, 3), torch.randn(3, 2))},
|
||||
strict=self.strict,
|
||||
_use_new_tracer_experimental=True,
|
||||
)
|
||||
with config.patch(use_new_tracer_experimental=True):
|
||||
_ = export(
|
||||
Foo(),
|
||||
(torch.randn(2, 2),),
|
||||
{"inputs": CustomInput(torch.randn(2, 3), torch.randn(3, 2))},
|
||||
strict=self.strict,
|
||||
)
|
||||
|
||||
def test_custom_output(self):
|
||||
@dataclass
|
||||
|
@ -47,6 +47,7 @@ from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
|
||||
|
||||
from .wrappers import _wrap_submodules
|
||||
from .utils import _materialize_cpp_cia_ops
|
||||
from . import config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._C._aoti import AOTIModelContainerRunner
|
||||
@ -65,7 +66,6 @@ class ExportDynamoConfig:
|
||||
# is called multiple times.
|
||||
@lru_cache
|
||||
def aot_compile_warning():
|
||||
from torch._inductor import config
|
||||
|
||||
log.warning("+============================+")
|
||||
log.warning("| !!! WARNING !!! |")
|
||||
@ -124,11 +124,11 @@ def aot_compile(
|
||||
"""
|
||||
from torch.export._trace import _export_to_torch_ir
|
||||
from torch._inductor.decomposition import select_decomp_table
|
||||
from torch._inductor import config
|
||||
from torch._inductor import config as inductor_config
|
||||
|
||||
aot_compile_warning()
|
||||
|
||||
if config.is_predispatch:
|
||||
if inductor_config.is_predispatch:
|
||||
gm = torch.export._trace._export(f, args, kwargs, dynamic_shapes, pre_dispatch=True).module()
|
||||
else:
|
||||
# We want to export to Torch IR here to utilize the pre_grad passes in
|
||||
|
26
torch/_export/config.py
Normal file
26
torch/_export/config.py
Normal file
@ -0,0 +1,26 @@
|
||||
"""
|
||||
Configuration module for torch.export.export.
|
||||
|
||||
This module contains various configuration flags and settings that control torch.export's
|
||||
behavior, including:
|
||||
- Runtime behavior flags
|
||||
- Debugging and development options
|
||||
"""
|
||||
|
||||
import sys
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
from torch.utils._config_module import install_config_module
|
||||
|
||||
|
||||
# this flag controls whether we use new functional tracer. It
|
||||
# should be True in the long term.
|
||||
use_new_tracer_experimental = False
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch.utils._config_typing import * # noqa: F401, F403
|
||||
|
||||
def _make_closure_patcher(**changes: Any) -> Any: ...
|
||||
|
||||
|
||||
install_config_module(sys.modules[__name__])
|
@ -70,7 +70,6 @@ def export_for_training(
|
||||
strict: bool = False,
|
||||
preserve_module_call_signature: tuple[str, ...] = (),
|
||||
prefer_deferred_runtime_asserts_over_guards: bool = False,
|
||||
_use_new_tracer_experimental: bool = False,
|
||||
) -> ExportedProgram:
|
||||
"""
|
||||
:func:`export_for_training` takes any nn.Module along with example inputs, and produces a traced graph representing
|
||||
@ -160,7 +159,6 @@ def export_for_training(
|
||||
strict=strict,
|
||||
preserve_module_call_signature=preserve_module_call_signature,
|
||||
prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
|
||||
_use_new_tracer_experimental=_use_new_tracer_experimental,
|
||||
)
|
||||
|
||||
|
||||
@ -173,7 +171,6 @@ def export(
|
||||
strict: bool = False,
|
||||
preserve_module_call_signature: tuple[str, ...] = (),
|
||||
prefer_deferred_runtime_asserts_over_guards: bool = False,
|
||||
_use_new_tracer_experimental: bool = False,
|
||||
) -> ExportedProgram:
|
||||
"""
|
||||
:func:`export` takes any nn.Module along with example inputs, and produces a traced graph representing
|
||||
@ -286,7 +283,6 @@ def export(
|
||||
preserve_module_call_signature=preserve_module_call_signature,
|
||||
pre_dispatch=True,
|
||||
prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
|
||||
_use_new_tracer_experimental=_use_new_tracer_experimental,
|
||||
)
|
||||
except Exception as e:
|
||||
draft_export_msg = (
|
||||
|
@ -757,7 +757,6 @@ def _export_to_torch_ir(
|
||||
preserve_module_call_signature: tuple[str, ...] = (),
|
||||
disable_constraint_solver: bool = False,
|
||||
prefer_deferred_runtime_asserts_over_guards: bool = False,
|
||||
_use_new_tracer_experimental: bool = False,
|
||||
restore_fqn: bool = True,
|
||||
_log_export_usage: bool = True,
|
||||
same_signature: bool = True,
|
||||
@ -815,7 +814,7 @@ def _export_to_torch_ir(
|
||||
f, preserve_module_call_signature, module_call_specs
|
||||
)
|
||||
with ctx, _ignore_backend_decomps():
|
||||
if _use_new_tracer_experimental:
|
||||
if torch._export.config.use_new_tracer_experimental:
|
||||
from torch._dynamo.functional_export import (
|
||||
_dynamo_graph_capture_for_export,
|
||||
)
|
||||
@ -1422,7 +1421,6 @@ def _strict_export(
|
||||
orig_in_spec: TreeSpec,
|
||||
prefer_deferred_runtime_asserts_over_guards: bool,
|
||||
_to_aten_func: Callable,
|
||||
_use_new_tracer_experimental: bool = False,
|
||||
) -> ExportArtifact:
|
||||
"""
|
||||
_to_aten_func can either be `_export_to_aten_ir_make_fx` or `_export_to_aten_ir`
|
||||
@ -1437,7 +1435,6 @@ def _strict_export(
|
||||
restore_fqn=False, # don't need to restore because we will do it later
|
||||
prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
|
||||
_log_export_usage=False,
|
||||
_use_new_tracer_experimental=_use_new_tracer_experimental,
|
||||
)
|
||||
|
||||
# We detect the fake_mode by looking at gm_torch_level's placeholders, this is the fake_mode created in dynamo.
|
||||
@ -2058,7 +2055,6 @@ def _export_for_training(
|
||||
strict: bool = True,
|
||||
preserve_module_call_signature: tuple[str, ...] = (),
|
||||
prefer_deferred_runtime_asserts_over_guards: bool = False,
|
||||
_use_new_tracer_experimental: bool = False,
|
||||
) -> ExportedProgram:
|
||||
global _EXPORT_MODULE_HIERARCHY
|
||||
_EXPORT_MODULE_HIERARCHY = _get_module_hierarchy(mod)
|
||||
@ -2074,13 +2070,7 @@ def _export_for_training(
|
||||
original_state_dict = _get_original_state_dict(mod)
|
||||
|
||||
# Call the appropriate export function based on the strictness of tracing.
|
||||
export_func = (
|
||||
functools.partial(
|
||||
_strict_export, _use_new_tracer_experimental=_use_new_tracer_experimental
|
||||
)
|
||||
if strict
|
||||
else _non_strict_export
|
||||
)
|
||||
export_func = _strict_export if strict else _non_strict_export
|
||||
|
||||
alive_fake_input_ids_before_export: list[int] = []
|
||||
|
||||
@ -2209,7 +2199,6 @@ def _export(
|
||||
preserve_module_call_signature: tuple[str, ...] = (),
|
||||
pre_dispatch: bool = False,
|
||||
prefer_deferred_runtime_asserts_over_guards: bool = False,
|
||||
_use_new_tracer_experimental: bool = False,
|
||||
) -> ExportedProgram:
|
||||
"""
|
||||
Traces either an nn.Module's forward function or just a callable with PyTorch
|
||||
@ -2285,7 +2274,6 @@ def _export(
|
||||
strict=strict,
|
||||
preserve_module_call_signature=preserve_module_call_signature,
|
||||
prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
|
||||
_use_new_tracer_experimental=_use_new_tracer_experimental,
|
||||
)
|
||||
dtrace_structured("exported_program", payload_fn=lambda: str(ep))
|
||||
return ep
|
||||
|
Reference in New Issue
Block a user