testing infra and some fixes (#162183)

This PR is quite large in that it covers most of the rough edges in the new strict export flow (a usage sketch follows the list):

1. Handle nn_module_stack correctly now that we are tracing a wrapper module.
2. module_call_spec now has to be queried from the source directly, because we are no longer running the bytecode.
3. Correct input and output handling.
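
For orientation, the whole flow is gated behind a private flag on torch.export.export. A minimal sketch of how a strict-v2 export is invoked (the toy module is illustrative, not from this PR):

import torch
from torch.export import export

class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(3, 3)

    def forward(self, x):
        return self.linear(x)

# strict=True plus the new private flag routes export through the new tracer
ep = export(M(), (torch.randn(2, 3),), strict=True, _use_new_tracer_experimental=True)
print(ep.module()(torch.randn(2, 3)).shape)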

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162183
Approved by: https://github.com/zhxchen17
ghstack dependencies: #162167
Tugsbayasgalan Manlaibaatar
2025-09-08 11:34:12 -07:00
committed by PyTorch MergeBot
parent a965f09793
commit d8b6622bb6
10 changed files with 520 additions and 78 deletions

View File

@@ -217,12 +217,19 @@ TRAINING_IR_DECOMP_STRICT_SUFFIX = "_training_ir_to_decomp_strict"
 TRAINING_IR_DECOMP_NON_STRICT_SUFFIX = "_training_ir_to_decomp_nonstrict"
 CPP_RUNTIME_STRICT_SUFFIX = "_cpp_runtime_strict"
 CPP_RUNTIME_NONSTRICT_SUFFIX = "_cpp_runtime_nonstrict"
+STRICT_EXPORT_V2_SUFFIX = "_strict_export_v2"

 # Now the default mode is non-strict, so original unamended test names
 # should be treated as non-strict
 def is_non_strict_test(test_name):
-    return not test_name.endswith(STRICT_SUFFIX)
+    return not test_name.endswith(STRICT_SUFFIX) and not test_name.endswith(
+        STRICT_EXPORT_V2_SUFFIX
+    )
+
+
+def is_strict_v2_test(test_name):
+    return test_name.endswith(STRICT_EXPORT_V2_SUFFIX)

 def is_inline_and_install_strict_test(test_name: str) -> bool:
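
As a quick illustration of the naming scheme (the test names below are hypothetical, and STRICT_SUFFIX is assumed to be "_strict"):

assert is_non_strict_test("test_basic")                        # default mode is non-strict
assert not is_non_strict_test("test_basic_strict_export_v2")   # v2 names count as strict
assert is_strict_v2_test("test_basic_strict_export_v2")
assert not is_strict_v2_test("test_basic_strict")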
@@ -9736,6 +9743,7 @@ graph():
         inputs = {"x": torch.randn(3, 3), "y": torch.randn(3, 3)}
         self.assertEqual(ep.module()(**inputs), m(**inputs))

+    @testing.expectedFailureStrictV2  # AssertionError: RuntimeError not raised
     def test_retrace_pre_autograd(self):
         class Foo(torch.nn.Module):
             def __init__(self) -> None:
@@ -11764,6 +11772,7 @@ graph():
         self.assertEqual(ep.module()(3, 5), 8)
         self.assertEqual(ep.module()(5, 4), 9)

+    @testing.expectedFailureStrictV2  # ValueError: Found conflicts between user-specified and inferred ranges
     def test_dynamic_shapes_bounds(self):
         class M(torch.nn.Module):
             """
@@ -12070,6 +12079,8 @@ graph():
         test(export(M(), inp))

+    # Preserving signature hook is messing with dynamo tracing
+    @testing.expectedFailureStrictV2
     def test_unflatten_multiple_graphs_state(self):
         class N(torch.nn.Module):
             def __init__(self):
@@ -13688,7 +13699,7 @@ def forward(self, x, y):
         inputs = (torch.randn(10, 72),)
         dx, dy = dims("dx", "dy")
-        ep = torch.export.export(
+        ep = torch.export._trace._export(
             Mod4Reshape(),
             inputs,
             dynamic_shapes={"x": (dx, dy)},
@@ -14536,6 +14547,14 @@ graph():
         if is_inline_and_install_strict_test(self._testMethodName):
             self.assertEqual(filtered_nn_module_stack[0], "mod_list_1.2")
             self.assertEqual(filtered_nn_module_stack[1], "mod_list_1.2")
+        # This is fine since both of these will be deprecated soon.
+        elif is_strict_v2_test(self._testMethodName) and IS_FBCODE:
+            self.assertEqual(
+                filtered_nn_module_stack[0], "mod_list_1.slice(2, 3, None).0"
+            )
+            self.assertEqual(
+                filtered_nn_module_stack[1], "mod_list_2.slice(4, 5, None).0"
+            )
         else:
             self.assertEqual(
                 filtered_nn_module_stack[0], "mod_list_1.slice(2, 3, None).2"
@@ -15374,6 +15393,7 @@ class GraphModule(torch.nn.Module):
     )

     @testing.expectedFailureStrict  # test_hop doesn't have a dynamo implementation
+    @testing.expectedFailureStrictV2  # test_hop doesn't have a dynamo implementation
     @testing.expectedFailureRetraceability  # test_hop doesn't have a dynamo implementation
     @testing.expectedFailureTrainingIRToRunDecomp  # test_hop doesn't have a dynamo implementation
     @testing.expectedFailureSerDerNonStrict  # TODO: serde torch.FunctionSchema is not implemented yet

View File

@@ -0,0 +1,54 @@
# Owner(s): ["oncall: export"]

try:
    from . import test_export, testing
except ImportError:
    import test_export  # @manual=fbcode//caffe2/test:test_export-library
    import testing  # @manual=fbcode//caffe2/test:test_export-library

from torch.export import export


test_classes = {}


def mocked_strict_export_v2(*args, **kwargs):
    # If the user already specified strict, respect it; only route through the
    # new tracer when the export is strict
    if "strict" in kwargs:
        if kwargs["strict"]:
            return export(*args, **kwargs, _use_new_tracer_experimental=True)
        else:
            return export(*args, **kwargs)
    return export(*args, **kwargs, strict=True, _use_new_tracer_experimental=True)


def make_dynamic_cls(cls):
    cls_prefix = "StrictExportV2"

    test_class = testing.make_test_cls_with_mocked_export(
        cls,
        cls_prefix,
        test_export.STRICT_EXPORT_V2_SUFFIX,
        mocked_strict_export_v2,
        xfail_prop="_expected_failure_strict_v2",
    )
    test_classes[test_class.__name__] = test_class
    # REMOVING THIS LINE WILL STOP TESTS FROM RUNNING
    globals()[test_class.__name__] = test_class
    test_class.__module__ = __name__
    return test_class


tests = [
    test_export.TestDynamismExpression,
    test_export.TestExport,
]
for test in tests:
    make_dynamic_cls(test)
del test

if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()
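
For reference, a behavior sketch of the mocked export above (M and inputs are hypothetical placeholders):

mocked_strict_export_v2(M(), inputs)                # defaults to strict + new tracer
mocked_strict_export_v2(M(), inputs, strict=True)   # strict, with the new tracer
mocked_strict_export_v2(M(), inputs, strict=False)  # plain non-strict export, old path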

View File

@@ -257,6 +257,12 @@ def expectedFailureTrainingIRToRunDecompNonStrict(fn):
     return fn


+# Controls tests generated in test/export/test_export_strict_v2.py
+def expectedFailureStrictV2(fn):
+    fn._expected_failure_strict_v2 = True
+    return fn
+
+
 # Controls tests generated in test/export/test_export_strict.py
 def expectedFailureStrict(fn):
     fn._expected_failure_strict = True
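
A usage sketch for the new marker, assuming a test module that imports testing the way the export suite does; the generator above (via xfail_prop="_expected_failure_strict_v2") turns the attribute into an expected failure on the cloned class:

import testing  # as importable from within test/export

class TestExportExample:  # hypothetical stand-in for TestExport
    @testing.expectedFailureStrictV2  # known gap in the v2 tracer
    def test_something(self):
        ...

assert TestExportExample.test_something._expected_failure_strict_v2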

View File

@@ -258,6 +258,7 @@ def aot_compile_fullgraph(
     backend_input.graph_module._backend_id = backend_input.backend_id  # type: ignore[assignment]
     output_graph = dynamo_output.tracer_output.output_graph
     assert output_graph is not None
+    assert backend_input is not None
     import_sources = output_graph.import_sources
     with (
         torch._guards.tracing(TracingContext(backend_input.fake_mode)),

View File

@@ -95,6 +95,7 @@ from .cache_size import (
 )
 from .eval_frame import (
     always_optimize_code_objects,
+    Constraint,
     dynamo_tls,
     skip_code,
     TorchPatcher,
@@ -894,7 +895,8 @@ class CaptureOutput:
     """

     dynamo_output: DynamoOutput
-    backend_input: BackendInput
+    # BackendInput can be None when dynamo didn't compile any graph (no tensor op)
+    backend_input: Optional[BackendInput]


 @dataclass
@@ -907,7 +909,10 @@ class FrameInfo:
 def fullgraph_capture(
-    frame: FrameInfo, *, _is_export_deprecated_do_not_use: bool = False
+    frame: FrameInfo,
+    *,
+    constraints: Optional[list[Constraint]] = None,
+    _is_export_deprecated_do_not_use: bool = False,
 ) -> CaptureOutput:
     """
     A standalone function which takes a frame and returns dynamo captured graph
@@ -951,6 +956,7 @@ def fullgraph_capture(
         frame.closure,
         compiler_fn=fullgraph_compiler,
         export=_is_export_deprecated_do_not_use,
+        export_constraints=constraints,  # type: ignore[arg-type]
         one_graph=True,
         restart_reasons=set(),
     )
@@ -966,7 +972,6 @@ def fullgraph_capture(
             cur_exn = cur_exn.__cause__
         raise e.with_traceback(None) from e.__cause__  # User compiler error

-    assert backend_input is not None
     return CaptureOutput(dynamo_output, backend_input)
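
With backend_input now optional, a sketch of the pattern callers are expected to follow, mirroring what functional_export.py below does when dynamo compiled no graph:

out = fullgraph_capture(frame, _is_export_deprecated_do_not_use=True)
if out.backend_input is not None:
    gm = out.backend_input.graph_module
else:
    # no tensor ops were traced; fall back to an empty FX graph
    gm = torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())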

View File

@@ -1,17 +1,70 @@
 import builtins
 import inspect
+import logging
+import traceback
 from collections import namedtuple
-from typing import Any, Callable
+from typing import Any, Callable, Optional, Union
+
+import sympy

 import torch
+import torch.fx
 import torch.utils._pytree as pytree
 from torch._dynamo.convert_frame import FrameInfo, fullgraph_capture, get_compile_id
 from torch._dynamo.eval_frame import argument_names
 from torch._dynamo.utils import dynamo_timed, get_metrics_context
 from torch._guards import compile_context, CompileContext
+from torch.export.dynamic_shapes import _RelaxedConstraint, Constraint
+from torch.fx import Node
+from torch.fx.experimental.symbolic_shapes import (
+    ConstraintViolationError,
+    DimDynamic,
+    StatelessSymbolicContext,
+)
 from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo


+log = logging.getLogger(__name__)
+
+
+def clean_nn_module_stack(graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    for node in graph_module.graph.nodes:
+        if "nn_module_stack" in node.meta:
+            nn_module_stack = node.meta["nn_module_stack"].copy()
+            first_key = next(iter(nn_module_stack.keys()))
+            if "export_root" in first_key:
+                del nn_module_stack[first_key]
+
+            nn_module_stack_corrected = {}
+            for k, v in nn_module_stack.items():
+                k_new = "".join(k.split("__export_root"))
+                child_name, child_class = v
+                child_name = child_name.replace("._export_root", "")
+                nn_module_stack_corrected[k_new] = (child_name, child_class)
+
+            node.meta["nn_module_stack"] = nn_module_stack_corrected
+    return graph_module
+
+
+def clean_export_root(graph_module: torch.fx.GraphModule) -> None:
+    """Remove export_root artifacts from FX graph in-place"""
+
+    # Clean parameter names: L__self____export_root_param -> L__self___param
+    def clean_name(name) -> str:
+        return name.replace("__export_root_", "_") if "__export_root_" in name else name
+
+    # Update get_attr nodes in-place
+    for node in graph_module.graph.nodes:
+        if node.op == "get_attr":
+            old_target = node.target
+            new_target = clean_name(old_target)
+            if new_target != old_target:
+                node.target = new_target
+                # Move the parameter to the new name
+                if hasattr(graph_module, old_target):
+                    param = torch.fx.graph_module._get_attr(graph_module, old_target)
+                    torch.fx.graph_module._set_attr(graph_module, new_target, param)
+                    torch.fx.graph_module._del_attr(graph_module, old_target)
+
+
 class ModuleToTrace(torch.nn.Module):
     def __init__(self, foo: Any, in_spec: Any) -> None:
         super().__init__()
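
To make the clean_nn_module_stack rewrite concrete, an illustrative before/after for one node's nn_module_stack (the keys, paths, and classes are hypothetical): the wrapper's own entry is dropped and the _export_root infix is stripped from the rest.

# before
{"L__self____export_root": ("L['self']._export_root", Wrapper),
 "L__self____export_root_sub": ("L['self']._export_root.sub", Sub)}
# after
{"L__self___sub": ("L['self'].sub", Sub)}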
@@ -28,29 +81,214 @@ class ModuleToTrace(torch.nn.Module):
 ExportTracerOutput = namedtuple("ExportTracerOutput", ["flat_args", "out_spec"])


+# mypy: disable-error-code="no-untyped-def,var-annotated,assignment,index,operator"
+class DynamoGraphTransformer(torch.fx.Transformer):
+    """Graph transformer for dynamo export that flattens inputs/outputs without complex matching."""
+
+    def __init__(
+        self,
+        module: torch.fx.GraphModule,
+        flat_inputs: list[Any],
+        flat_args_dynamic_dims: list[set[int]],
+        graph_input_order: dict[int, int],
+        graph_output_map: dict[int, tuple[str, Any]],
+        fake_mode: Optional[Any] = None,
+    ) -> None:
+        super().__init__(module)
+
+        assert len(flat_args_dynamic_dims) == len(flat_inputs)
+
+        self.flat_inputs = flat_inputs
+        self.flat_args_dynamic_dims = flat_args_dynamic_dims
+        self.graph_input_order = graph_input_order
+        self.graph_output_map = graph_output_map
+        self.fake_mode = fake_mode
+
+        # Get original placeholders and output
+        self.placeholders = [n for n in module.graph.nodes if n.op == "placeholder"]
+        self.output_node = next(n for n in module.graph.nodes if n.op == "output")
+
+        # Create new flattened input placeholders
+        self.new_input_nodes: dict[int, torch.fx.Node] = {}
+        self._create_flattened_inputs()
+
+        # Iterator for replacing old placeholders
+        self.old_to_new_mapping = {}
+        self._create_placeholder_mapping()
+
+    def _create_flattened_inputs(self) -> None:
+        """Create new placeholder nodes for flattened inputs with proper fake tensors."""
+        for i in range(len(self.flat_inputs)):
+            placeholder = super().placeholder(f"arg_{i}", (), {})
+
+            # Check if this user input (index i) maps to a graph placeholder
+            if i in self.graph_input_order:
+                # graph_input_order[i] gives us which graph placeholder this user input corresponds to
+                graph_placeholder_idx = self.graph_input_order[i]
+                if graph_placeholder_idx < len(self.placeholders):
+                    orig_placeholder = self.placeholders[graph_placeholder_idx]
+                    # Copy other metadata but not "val" yet
+                    for key, value in orig_placeholder.meta.items():
+                        if key != "val":
+                            placeholder.node.meta[key] = value
+
+            # Always ensure we have proper "val" metadata from fake tensor
+            if self.fake_mode is not None and isinstance(
+                self.flat_inputs[i], torch.Tensor
+            ):
+                placeholder.node.meta["val"] = self.fake_mode.from_tensor(
+                    self.flat_inputs[i],
+                    symbolic_context=StatelessSymbolicContext(
+                        dynamic_sizes=[
+                            (
+                                DimDynamic.DYNAMIC
+                                if d in self.flat_args_dynamic_dims[i]
+                                else DimDynamic.STATIC
+                            )
+                            for d in range(len(self.flat_inputs[i].shape))
+                        ],
+                        constraint_sizes=[None] * len(self.flat_inputs[i].shape),
+                    ),
+                )
+            elif hasattr(self.flat_inputs[i], "val"):  # _IntWrapper case
+                placeholder.node.meta["val"] = self.flat_inputs[i].val
+            else:
+                placeholder.node.meta["val"] = self.flat_inputs[i]
+
+            self.new_input_nodes[i] = placeholder
+
+    def _create_placeholder_mapping(self) -> None:
+        """Create mapping from old placeholders to new ones."""
+        # graph_input_order maps: user_input_index -> graph_placeholder_index
+        # We need to create: old_graph_placeholder -> new_user_input_placeholder
+        for user_input_idx, graph_placeholder_idx in self.graph_input_order.items():
+            if graph_placeholder_idx < len(self.placeholders):
+                old_placeholder = self.placeholders[graph_placeholder_idx]
+                new_placeholder = self.new_input_nodes[user_input_idx]
+                self.old_to_new_mapping[old_placeholder] = new_placeholder
+
+    def placeholder(self, target, args, kwargs) -> Any:
+        """Replace old placeholders with new flattened ones."""
+        # Return the corresponding new placeholder
+        if self.current_node in self.old_to_new_mapping:
+            new_arg = self.old_to_new_mapping[self.current_node]
+            # Copy over additional metadata from current node, but don't overwrite "val"
+            for key in ["tensor_dict", "example_value", "unbacked_bindings"]:
+                if key in self.current_node.meta:
+                    new_arg.node.meta[key] = self.current_node.meta[key]
+            # Only copy "val" if we don't already have a good one
+            if "val" in self.current_node.meta and "val" not in new_arg.node.meta:
+                new_arg.node.meta["val"] = self.current_node.meta["val"]
+            return new_arg
+        else:
+            # Shouldn't happen if mapping is correct, but fallback
+            return super().placeholder(target, args, kwargs)
+
+    def output(self, target, args, kwargs) -> Any:
+        """Transform output according to graph_output_map."""
+        original_outputs = args[0]
+
+        # Build new output list based on graph_output_map
+        new_outputs = []
+        for i in sorted(self.graph_output_map.keys()):
+            output_type, val = self.graph_output_map[i]
+            if output_type == "graph_out":
+                new_outputs.append(original_outputs[val])
+            elif output_type == "input":
+                input_idx = val.index
+                new_outputs.append(self.new_input_nodes[input_idx])
+            elif output_type == "constant":
+                new_outputs.append(val)
+
+        return super().output(target, (tuple(new_outputs),), {})
+
+    def run_node(self, node: Node) -> Any:
+        """Run node transformation and preserve metadata."""
+        self.current_node = node
+        result = super().run_node(node)
+
+        # Copy important metadata
+        if hasattr(result, "node") and result.node is not node:
+            for key in ["val", "example_value", "unbacked_bindings"]:
+                if key in node.meta:
+                    result.node.meta[key] = node.meta[key]
+            # Preserve node names (except output)
+            if node.op != "output" and hasattr(node, "name"):
+                result.node._rename(node.name)
+
+        return result
+
+    def transform(self) -> torch.fx.GraphModule:
+        """Perform the graph transformation and copy module metadata."""
+        result_gm = super().transform()
+
+        # Copy module metadata like the original implementation
+        if hasattr(self.module, "meta"):
+            if "dynamo_flat_name_to_original_fqn" in self.module.meta:
+                result_gm.meta["dynamo_flat_name_to_original_fqn"] = self.module.meta[
+                    "dynamo_flat_name_to_original_fqn"
+                ]
+            if "dynamo_compile_id" in self.module.meta:
+                result_gm.meta["dynamo_compile_id"] = self.module.meta[
+                    "dynamo_compile_id"
+                ]
+
+        return result_gm
+
+
 def _dynamo_graph_capture_for_export(
-    mod: torch.nn.Module,
+    mod: Callable[..., Any],
+    *,
+    constraints: Optional[list[Constraint]] = None,
+    dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None,
 ) -> Callable[..., torch.fx.GraphModule]:
     """
-    This is lower level API that is used for export to capture dynamo level
-    torch IR.
+    Improved dynamo graph capture using transformer approach with proper fake tensor handling.

-    Notable TODOs:
+    This function creates a capture instance that handles:
+    1. PyTree flattening/unflattening with proper input ordering
+    2. Dynamo graph capture with export-specific context
+    3. FX graph transformation for export compatibility
+    4. Proper fake tensor metadata preservation
+    5. Dynamic dimension constraint handling
+
+    Notable improvements over manual approach:
+    - Uses FX Transformer for cleaner graph manipulation
+    - Properly handles fake tensor metadata and dynamic dimensions
+    - Preserves all necessary metadata for export
+    - More robust error handling and edge case management
+
+    TODO:
     1. Are we actually gonna run the bytecode?
     2. Need to attach guards
     """
+    _dynamic_shapes = dynamic_shapes
+    _constraints = constraints

     def inner(*args: Any, **kwargs: Any) -> torch.fx.GraphModule:
         flat_inputs, in_spec = pytree.tree_flatten((args, kwargs))
         module_to_trace = ModuleToTrace(mod, in_spec)

         signature = inspect.signature(module_to_trace.forward)
         bound_arguments = signature.bind(*flat_inputs)
         bound_arguments.apply_defaults()
-        f_locals = {"self": module_to_trace, **bound_arguments.arguments}
+        constraints: Optional[list[Constraint]] = _constraints
+        dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = (
+            _dynamic_shapes
+        )
+
+        from . import reset  # type: ignore[attr-defined]
+
+        reset()
+
+        f_locals = {"self": module_to_trace, **bound_arguments.arguments}
         frame = FrameInfo(
             module_to_trace.forward.__func__.__code__,  # type: ignore[attr-defined]
             module_to_trace.forward.__func__.__globals__,  # type: ignore[attr-defined]
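
A small worked example of the bookkeeping the transformer consumes, under the mapping direction documented above (user_input_index -> graph_placeholder_index):

graph_input_order = {0: 1, 2: 0}
# user input 0 feeds dynamo placeholder 1, user input 2 feeds placeholder 0;
# user input 1 never reached the graph, but arg_1 is still emitted so the
# flattened calling convention matches the user signature.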
@@ -60,7 +298,14 @@ def _dynamo_graph_capture_for_export(
         )

         dynamo_config_ctx = torch._dynamo.config.patch(
-            "log_graph_in_out_metadata", True
+            specialize_int=True,
+            specialize_float=True,
+            assume_static_by_default=True,
+            automatic_dynamic_shapes=False,
+            capture_dynamic_output_shape_ops=True,
+            capture_scalar_outputs=True,
+            prefer_deferred_runtime_asserts_over_guards=False,
+            log_graph_in_out_metadata=True,
         )

         with (
@@ -69,74 +314,137 @@ def _dynamo_graph_capture_for_export(
             dynamo_timed("fullgraph_capture"),
             dynamo_config_ctx,
         ):
-            out = fullgraph_capture(frame, _is_export_deprecated_do_not_use=True)
+            out = fullgraph_capture(
+                frame,
+                constraints=_constraints,
+                _is_export_deprecated_do_not_use=True,
+            )
         assert out.dynamo_output.tracer_output.output_graph is not None
+
+        # Extract export metadata from the new location
         export_metadata = (
             out.dynamo_output.tracer_output.output_graph.export_metadata
         )
         graph_inputs = export_metadata.graph_input_idx_to_local_source
-        output_return_type = export_metadata.output_return_type
+        graph_output_map = export_metadata.output_return_type
+        # We need to extract out_spec here because we are not actually running the bytecode
         out_spec = export_metadata.out_spec
+        module_call_spec = export_metadata.module_call_spec

-        graph = out.backend_input.graph_module
+        example_inputs: list[Any] = []
+        if out.backend_input is not None:
+            graph = out.backend_input.graph_module
+            fake_mode = out.backend_input.fake_mode
+            example_inputs = out.backend_input.example_inputs
+        else:
+            graph = torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
+            graph.graph.output(None)
+            graph.recompile()
+            fake_mode = out.dynamo_output.tracer_output.output_graph.fake_mode

-        # It is not guaranteed that dynamo puts inputs in right order, so we need to
-        # map the actual user order to the dynamo order.
+        # Compute dynamic dimensions for each input based on constraints
+        flat_args_dynamic_dims = [
+            {
+                c.dim
+                for c in (constraints or ())
+                if (
+                    c.t_id == id(x)
+                    and not isinstance(c, _RelaxedConstraint)
+                    and c.constraint_range.vr.lower != c.constraint_range.vr.upper
+                )
+            }
+            for x in flat_inputs
+        ]
+
+        # Create input order mapping from dynamo's internal order to user order
         graph_input_order: dict[int, int] = {}
         for inp in graph_inputs:
             source = graph_inputs[inp]
             assert isinstance(source, torch._dynamo.source.GetItemSource)
             graph_input_order[source.index] = len(graph_input_order)

-        placeholders = [n for n in list(graph.graph.nodes) if n.op == "placeholder"]
-        output = next(n for n in list(graph.graph.nodes) if n.op == "output")
-        # Sometimes there can be empty inputs
-        anchor = placeholders[0] if len(placeholders) > 0 else output
-        inp_to_node = {}
-
-        with graph.graph.inserting_before(anchor):
-            for i in range(len(flat_inputs)):
-                node_new = graph.graph.placeholder(f"arg_{i}")
-                if i in graph_input_order:
-                    node_new.meta = placeholders[graph_input_order[i]].meta.copy()
-                inp_to_node[i] = node_new
-
-        new_args = []
-        for i in output_return_type:
-            type, val = output_return_type[i]
-            if type == "graph_out":
-                new_args.append(output.args[0][val])
-            if type == "input":
-                input_idx = val.index
-                new_args.append(inp_to_node[input_idx])
-            if type == "constant":
-                new_args.append(val)
-        output.args = (tuple(new_args),)
-
-        for src_idx, i in graph_input_order.items():
-            old = placeholders[src_idx]
-            new = inp_to_node[i]
-            old.replace_all_uses_with(new)
-            graph.graph.erase_node(old)
-
-        # Dynamo uses _lazyGraphModule, so we need to force recompile
-        from torch.fx._lazy_graph_module import _LazyGraphModule
-
-        _LazyGraphModule.force_recompile(graph)
-
-        graph.graph._codegen = _PyTreeCodeGen(
+        for real_idx, graph_idx in graph_input_order.items():
+            flat_inputs[real_idx] = example_inputs[graph_idx]
+
+        # Use FX transformer to rebuild the graph cleanly
+        transformed_graph = DynamoGraphTransformer(
+            graph,
+            flat_inputs,
+            flat_args_dynamic_dims,
+            graph_input_order,
+            graph_output_map,
+            fake_mode,
+        ).transform()
+
+        # Set up PyTree codegen for proper input/output handling
+        transformed_graph.graph._codegen = _PyTreeCodeGen(
             _PyTreeInfo(
-                argument_names(signature, args, kwargs),  # type: ignore[arg-type]
+                argument_names(inspect.signature(mod.forward), args, kwargs),  # type: ignore[attr-defined, arg-type]
                 in_spec,
                 out_spec,
             )
         )
+        transformed_graph.recompile()

-        graph.recompile()
-        return graph
+        clean_nn_module_stack(transformed_graph)
+        clean_export_root(transformed_graph)
+        transformed_graph.meta["module_call_specs"] = module_call_spec
+
+        constraint_violation_error = None
+        try:
+            # Check if we have any constraint violations
+            check_fn = out.dynamo_output.build_guards(
+                module_to_trace.forward.__code__
+            ).guard_manager
+            check_fn.check(f_locals)
+        except (
+            ConstraintViolationError,
+            torch.utils._sympy.value_ranges.ValueRangeError,
+        ) as e:
+            constraint_violation_error = e
+
+        if (
+            (shape_env := getattr(fake_mode, "shape_env", None)) is not None
+            and (dim_constraints := shape_env.dim_constraints) is not None
+            and not isinstance(
+                module_to_trace.forward,
+                (torch._ops.OpOverloadPacket, torch._ops.OpOverload),
+            )
+        ):
+            dim_constraints.solve()
+            forced_specializations = dim_constraints.forced_specializations()
+            msg = dim_constraints.prettify_results(
+                inspect.signature(mod.forward),  # type: ignore[attr-defined]
+                dynamic_shapes,
+                constraint_violation_error,
+                forced_specializations,
+            )
+            if constraint_violation_error:
+                constraint_violation_error.args = (
+                    constraint_violation_error.args[0] + msg,
+                )
+            else:
+                if forced_specializations:
+                    constraint_violation_error = ConstraintViolationError(msg)
+                else:
+                    log.info(
+                        "Summary of dimension constraints:%s",
+                        msg,
+                    )
+
+            # Error if we have any constraints on static values
+            for k in shape_env.var_to_range.keys():
+                if isinstance(k, sympy.Integer):
+                    constraint_violation_error = ConstraintViolationError(
+                        f"{''.join(traceback.format_list(shape_env.var_to_stack[k]))}\n"
+                        "It appears that you're trying to set a constraint on a "
+                        f"value which we evaluated to have a static value of {k}. "
+                        'Set TORCH_LOGS="+export" for more information.'
+                    )
+        if constraint_violation_error:
+            raise constraint_violation_error
+
+        return transformed_graph

     return inner
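
End to end, the new entry point can be exercised directly; a minimal sketch with a toy module (real callers go through _export_to_torch_ir, shown later in this diff):

import torch
from torch._dynamo.functional_export import _dynamo_graph_capture_for_export

class Add(torch.nn.Module):
    def forward(self, x, y):
        return x + y

# returns a torch.fx.GraphModule capturing dynamo-level torch IR
gm = _dynamo_graph_capture_for_export(Add())(torch.randn(3), torch.randn(3))
print(gm.graph)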

View File

@@ -379,6 +379,10 @@ class ExportMetaData:
     out_spec: Union[torch.utils._pytree.TreeSpec, torch.utils._pytree.LeafSpec] = (
         torch.utils._pytree._LEAF_SPEC
     )
+    module_call_spec: dict[
+        str,
+        dict[str, Union[torch.utils._pytree.TreeSpec, torch.utils._pytree.LeafSpec]],
+    ] = dc_field(default_factory=dict)


 def get_builtins_dict(global_scope: Scope) -> dict[str, Any]:
def get_builtins_dict(global_scope: Scope) -> dict[str, Any]: def get_builtins_dict(global_scope: Scope) -> dict[str, Any]:
@@ -1695,6 +1699,19 @@ class OutputGraph(OutputGraphGuardsState):
                 if isinstance(
                     mut_type, (AttributeMutationExisting, ValueMutationExisting)
                 ):
+                    if isinstance(var, UserDefinedDictVariable) and isinstance(
+                        var.value, _ExportModuleSpecTrackerDict
+                    ):
+                        for k, v in var.items.items():
+                            specs = {}
+                            for k_spec, val in v.items.items():
+                                specs[k_spec.vt.as_python_constant()] = (
+                                    val.as_python_constant()
+                                )
+                            assert ["in_spec", "out_spec"] == list(specs.keys())
+                            self.export_metadata.module_call_spec[
+                                k.vt.as_python_constant()
+                            ] = specs
                     # export uses tracepoint pass to dump submodule inp/out spec
                     # into global state, so we filter it here
                     if not (
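
The collected metadata ends up keyed by submodule path, one entry per preserved module call; illustratively (the path and spec values here are hypothetical):

# export_metadata.module_call_spec
# {"sub.child": {"in_spec": <TreeSpec>, "out_spec": <TreeSpec>}}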

View File

@@ -70,6 +70,7 @@ def export_for_training(
     strict: bool = False,
     preserve_module_call_signature: tuple[str, ...] = (),
     prefer_deferred_runtime_asserts_over_guards: bool = False,
+    _use_new_tracer_experimental: bool = False,
 ) -> ExportedProgram:
     """
     :func:`export_for_training` takes any nn.Module along with example inputs, and produces a traced graph representing
@@ -159,6 +160,7 @@ def export_for_training(
         strict=strict,
         preserve_module_call_signature=preserve_module_call_signature,
         prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
+        _use_new_tracer_experimental=_use_new_tracer_experimental,
     )
@@ -171,6 +173,7 @@ def export(
     strict: bool = False,
     preserve_module_call_signature: tuple[str, ...] = (),
     prefer_deferred_runtime_asserts_over_guards: bool = False,
+    _use_new_tracer_experimental: bool = False,
 ) -> ExportedProgram:
     """
     :func:`export` takes any nn.Module along with example inputs, and produces a traced graph representing
@@ -283,6 +286,7 @@ def export(
             preserve_module_call_signature=preserve_module_call_signature,
             pre_dispatch=True,
             prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
+            _use_new_tracer_experimental=_use_new_tracer_experimental,
         )
     except Exception as e:
         draft_export_msg = (
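
Both public entry points accept the new flag with unchanged defaults; a compact sketch with a toy module:

import torch
from torch.export import export_for_training

class M(torch.nn.Module):
    def forward(self, x):
        return x * 2

ep = export_for_training(
    M(), (torch.randn(4),), strict=True, _use_new_tracer_experimental=True
)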

View File

@@ -757,6 +757,7 @@ def _export_to_torch_ir(
     preserve_module_call_signature: tuple[str, ...] = (),
     disable_constraint_solver: bool = False,
     prefer_deferred_runtime_asserts_over_guards: bool = False,
+    _use_new_tracer_experimental: bool = False,
     restore_fqn: bool = True,
     _log_export_usage: bool = True,
     same_signature: bool = True,
@@ -809,20 +810,31 @@ def _export_to_torch_ir(
             f, preserve_module_call_signature, module_call_specs
         )
         with ctx, _ignore_backend_decomps():
-            gm_torch_level, _ = torch._dynamo.export(
-                f,
-                dynamic_shapes=dynamic_shapes,  # type: ignore[arg-type]
-                constraints=constraints,  # type: ignore[arg-type]
-                assume_static_by_default=True,
-                tracing_mode="symbolic",
-                disable_constraint_solver=disable_constraint_solver,
-                prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
-                _log_export_usage=_log_export_usage,
-                same_signature=same_signature,
-            )(
-                *args,
-                **kwargs,
-            )
+            if _use_new_tracer_experimental:
+                from torch._dynamo.functional_export import (
+                    _dynamo_graph_capture_for_export,
+                )
+
+                gm_torch_level = _dynamo_graph_capture_for_export(
+                    f, constraints=constraints, dynamic_shapes=dynamic_shapes
+                )(*args, **kwargs)
+            else:
+                gm_torch_level, _ = torch._dynamo.export(
+                    f,
+                    dynamic_shapes=dynamic_shapes,  # type: ignore[arg-type]
+                    constraints=constraints,  # type: ignore[arg-type]
+                    assume_static_by_default=True,
+                    tracing_mode="symbolic",
+                    disable_constraint_solver=disable_constraint_solver,
+                    prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
+                    _log_export_usage=_log_export_usage,
+                    same_signature=same_signature,
+                )(
+                    *args,
+                    **kwargs,
+                )
+            gm_torch_level.meta["module_call_specs"] = module_call_specs
     except (ConstraintViolationError, ValueRangeError) as e:
         raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e))  # noqa: B904
     except GuardOnDataDependentSymNode as e:
@@ -832,8 +844,6 @@ def _export_to_torch_ir(
             case_name="constrain_as_size_example",
         )

-    gm_torch_level.meta["module_call_specs"] = module_call_specs
-
     if isinstance(f, torch.nn.Module) and restore_fqn:
         _restore_state_dict(f, gm_torch_level)
@@ -1407,6 +1417,7 @@ def _strict_export(
     orig_in_spec: TreeSpec,
     prefer_deferred_runtime_asserts_over_guards: bool,
     _to_aten_func: Callable,
+    _use_new_tracer_experimental: bool = False,
 ) -> ExportArtifact:
     """
     _to_aten_func can either be `_export_to_aten_ir_make_fx` or `_export_to_aten_ir`
@@ -1421,6 +1432,7 @@ def _strict_export(
         restore_fqn=False,  # don't need to restore because we will do it later
         prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
         _log_export_usage=False,
+        _use_new_tracer_experimental=_use_new_tracer_experimental,
     )

     # We detect the fake_mode by looking at gm_torch_level's placeholders, this is the fake_mode created in dynamo.
@@ -2041,6 +2053,7 @@ def _export_for_training(
     strict: bool = True,
     preserve_module_call_signature: tuple[str, ...] = (),
     prefer_deferred_runtime_asserts_over_guards: bool = False,
+    _use_new_tracer_experimental: bool = False,
 ) -> ExportedProgram:
     global _EXPORT_MODULE_HIERARCHY
     _EXPORT_MODULE_HIERARCHY = _get_module_hierarchy(mod)
@@ -2056,7 +2069,13 @@ def _export_for_training(
     original_state_dict = _get_original_state_dict(mod)

     # Call the appropriate export function based on the strictness of tracing.
-    export_func = _strict_export if strict else _non_strict_export
+    export_func = (
+        functools.partial(
+            _strict_export, _use_new_tracer_experimental=_use_new_tracer_experimental
+        )
+        if strict
+        else _non_strict_export
+    )

     alive_fake_input_ids_before_export: list[int] = []
@@ -2185,6 +2204,7 @@ def _export(
     preserve_module_call_signature: tuple[str, ...] = (),
     pre_dispatch: bool = False,
     prefer_deferred_runtime_asserts_over_guards: bool = False,
+    _use_new_tracer_experimental: bool = False,
 ) -> ExportedProgram:
     """
     Traces either an nn.Module's forward function or just a callable with PyTorch
@@ -2260,6 +2280,7 @@ def _export(
         strict=strict,
         preserve_module_call_signature=preserve_module_call_signature,
         prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
+        _use_new_tracer_experimental=_use_new_tracer_experimental,
     )
     dtrace_structured("exported_program", payload_fn=lambda: str(ep))
     return ep

View File

@@ -302,6 +302,12 @@ def _has_attr(model: torch.nn.Module, attr_name: str):
     return hasattr(t, field)


+def _set_attr(model: torch.nn.Module, attr_name: str, value):
+    attr_names = attr_name.split(".")
+    t = _get_attr_via_attr_list(model, attr_names[:-1])
+    setattr(t, attr_names[-1], value)
+
+
 def _print_readable(
     module,
     module_name,
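
A usage sketch for the new helper: it mirrors the existing dotted-path getters and is what clean_export_root (earlier in this diff) relies on to relocate parameters. Assuming this file is torch/fx/graph_module.py:

import torch
from torch.fx.graph_module import _set_attr

root = torch.nn.Module()
root.sub = torch.nn.Module()
_set_attr(root, "sub.weight", torch.nn.Parameter(torch.zeros(2)))
assert root.sub.weight.shape == (2,)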