testing infra and some fixes (#162183)

This PR is quite large in that it covers most of the rough edges in the new strict export flow:

1. Handle nn_module_stack correctly now that we are tracing a wrapper module.
2. module_call_spec needs to be queried from the source directly because we no longer run the bytecode.
3. Correct input and output handling.
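
A minimal sketch of how the new flow is exercised end to end, assuming only the public torch.export.export entry point plus the private _use_new_tracer_experimental flag added in this PR; the module and inputs are hypothetical:

import torch
from torch.export import export

class M(torch.nn.Module):  # hypothetical example module
    def forward(self, x):
        return x + 1

# strict=True together with the experimental flag routes export through the new
# dynamo-based strict tracer exercised by the test suite below.
ep = export(M(), (torch.randn(3),), strict=True, _use_new_tracer_experimental=True)
print(ep.module()(torch.randn(3)))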

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162183
Approved by: https://github.com/zhxchen17
ghstack dependencies: #162167
Tugsbayasgalan Manlaibaatar
2025-09-08 11:34:12 -07:00
committed by PyTorch MergeBot
parent a965f09793
commit d8b6622bb6
10 changed files with 520 additions and 78 deletions

View File

@ -217,12 +217,19 @@ TRAINING_IR_DECOMP_STRICT_SUFFIX = "_training_ir_to_decomp_strict"
TRAINING_IR_DECOMP_NON_STRICT_SUFFIX = "_training_ir_to_decomp_nonstrict"
CPP_RUNTIME_STRICT_SUFFIX = "_cpp_runtime_strict"
CPP_RUNTIME_NONSTRICT_SUFFIX = "_cpp_runtime_nonstrict"
STRICT_EXPORT_V2_SUFFIX = "_strict_export_v2"
# Now the default mode is non-strict, so original unamended test names
# should be treated as non-strict
def is_non_strict_test(test_name):
return not test_name.endswith(STRICT_SUFFIX)
return not test_name.endswith(STRICT_SUFFIX) and not test_name.endswith(
STRICT_EXPORT_V2_SUFFIX
)
def is_strict_v2_test(test_name):
return test_name.endswith(STRICT_EXPORT_V2_SUFFIX)
def is_inline_and_install_strict_test(test_name: str) -> bool:
@ -9736,6 +9743,7 @@ graph():
inputs = {"x": torch.randn(3, 3), "y": torch.randn(3, 3)}
self.assertEqual(ep.module()(**inputs), m(**inputs))
@testing.expectedFailureStrictV2 # AssertionError: RuntimeError not raised
def test_retrace_pre_autograd(self):
class Foo(torch.nn.Module):
def __init__(self) -> None:
@ -11764,6 +11772,7 @@ graph():
self.assertEqual(ep.module()(3, 5), 8)
self.assertEqual(ep.module()(5, 4), 9)
@testing.expectedFailureStrictV2 # ValueError: Found conflicts between user-specified and inferred ranges
def test_dynamic_shapes_bounds(self):
class M(torch.nn.Module):
"""
@ -12070,6 +12079,8 @@ graph():
test(export(M(), inp))
# The preserve-signature hook interferes with dynamo tracing
@testing.expectedFailureStrictV2
def test_unflatten_multiple_graphs_state(self):
class N(torch.nn.Module):
def __init__(self):
@ -13688,7 +13699,7 @@ def forward(self, x, y):
inputs = (torch.randn(10, 72),)
dx, dy = dims("dx", "dy")
ep = torch.export.export(
ep = torch.export._trace._export(
Mod4Reshape(),
inputs,
dynamic_shapes={"x": (dx, dy)},
@ -14536,6 +14547,14 @@ graph():
if is_inline_and_install_strict_test(self._testMethodName):
self.assertEqual(filtered_nn_module_stack[0], "mod_list_1.2")
self.assertEqual(filtered_nn_module_stack[1], "mod_list_1.2")
# This is fine since both of these will be deprecated soon.
elif is_strict_v2_test(self._testMethodName) and IS_FBCODE:
self.assertEqual(
filtered_nn_module_stack[0], "mod_list_1.slice(2, 3, None).0"
)
self.assertEqual(
filtered_nn_module_stack[1], "mod_list_2.slice(4, 5, None).0"
)
else:
self.assertEqual(
filtered_nn_module_stack[0], "mod_list_1.slice(2, 3, None).2"
@ -15374,6 +15393,7 @@ class GraphModule(torch.nn.Module):
)
@testing.expectedFailureStrict # test_hop doesn't have a dynamo implementation
@testing.expectedFailureStrictV2 # test_hop doesn't have a dynamo implementation
@testing.expectedFailureRetraceability # test_hop doesn't have a dynamo implementation
@testing.expectedFailureTrainingIRToRunDecomp # test_hop doesn't have a dynamo implementation
@testing.expectedFailureSerDerNonStrict # TODO: serde torch.FunctionSchema is not implemented yet

View File

@ -0,0 +1,54 @@
# Owner(s): ["oncall: export"]
try:
from . import test_export, testing
except ImportError:
import test_export # @manual=fbcode//caffe2/test:test_export-library
import testing # @manual=fbcode//caffe2/test:test_export-library
from torch.export import export
test_classes = {}
def mocked_strict_export_v2(*args, **kwargs):
# If the user already specified strict, respect it rather than forcing strict=True
if "strict" in kwargs:
if kwargs["strict"]:
return export(*args, **kwargs, _use_new_tracer_experimental=True)
else:
return export(*args, **kwargs)
return export(*args, **kwargs, strict=True, _use_new_tracer_experimental=True)
def make_dynamic_cls(cls):
cls_prefix = "StrictExportV2"
test_class = testing.make_test_cls_with_mocked_export(
cls,
cls_prefix,
test_export.STRICT_EXPORT_V2_SUFFIX,
mocked_strict_export_v2,
xfail_prop="_expected_failure_strict_v2",
)
test_classes[test_class.__name__] = test_class
# REMOVING THIS LINE WILL STOP TESTS FROM RUNNING
globals()[test_class.__name__] = test_class
test_class.__module__ = __name__
return test_class
tests = [
test_export.TestDynamismExpression,
test_export.TestExport,
]
for test in tests:
make_dynamic_cls(test)
del test
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
run_tests()
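
For reference, a small sketch of how the generated test names are classified by the helpers in test_export.py; the concrete name is hypothetical and assumes make_test_cls_with_mocked_export appends STRICT_EXPORT_V2_SUFFIX to each copied test:

from test_export import is_non_strict_test, is_strict_v2_test

name = "test_basic_strict_export_v2"  # hypothetical generated test name
assert is_strict_v2_test(name)        # picked up by the strict-v2 expectations
assert not is_non_strict_test(name)   # no longer treated as a non-strict test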

View File

@ -257,6 +257,12 @@ def expectedFailureTrainingIRToRunDecompNonStrict(fn):
return fn
# Controls tests generated in test/export/test_export_strict_v2.py
def expectedFailureStrictV2(fn):
fn._expected_failure_strict_v2 = True
return fn
# Controls tests generated in test/export/test_export_strict.py
def expectedFailureStrict(fn):
fn._expected_failure_strict = True
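
A short illustration of how the new marker above is consumed, assuming make_test_cls_with_mocked_export looks up the attribute named by xfail_prop when generating the StrictExportV2 classes:

@expectedFailureStrictV2
def test_something(self):  # hypothetical test method
    ...

# The generated class finds the marker via xfail_prop="_expected_failure_strict_v2"
# and turns the copied test into an expected failure.
assert getattr(test_something, "_expected_failure_strict_v2", False)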

View File

@ -258,6 +258,7 @@ def aot_compile_fullgraph(
backend_input.graph_module._backend_id = backend_input.backend_id # type: ignore[assignment]
output_graph = dynamo_output.tracer_output.output_graph
assert output_graph is not None
assert backend_input is not None
import_sources = output_graph.import_sources
with (
torch._guards.tracing(TracingContext(backend_input.fake_mode)),

View File

@ -95,6 +95,7 @@ from .cache_size import (
)
from .eval_frame import (
always_optimize_code_objects,
Constraint,
dynamo_tls,
skip_code,
TorchPatcher,
@ -894,7 +895,8 @@ class CaptureOutput:
"""
dynamo_output: DynamoOutput
backend_input: BackendInput
# BackendInput can be None when dynamo didn't compile any graph (no tensor ops)
backend_input: Optional[BackendInput]
@dataclass
@ -907,7 +909,10 @@ class FrameInfo:
def fullgraph_capture(
frame: FrameInfo, *, _is_export_deprecated_do_not_use: bool = False
frame: FrameInfo,
*,
constraints: Optional[list[Constraint]] = None,
_is_export_deprecated_do_not_use: bool = False,
) -> CaptureOutput:
"""
A standalone function which takes a frame and returns dynamo captured graph
@ -951,6 +956,7 @@ def fullgraph_capture(
frame.closure,
compiler_fn=fullgraph_compiler,
export=_is_export_deprecated_do_not_use,
export_constraints=constraints, # type: ignore[arg-type]
one_graph=True,
restart_reasons=set(),
)
@ -966,7 +972,6 @@ def fullgraph_capture(
cur_exn = cur_exn.__cause__
raise e.with_traceback(None) from e.__cause__ # User compiler error
assert backend_input is not None
return CaptureOutput(dynamo_output, backend_input)
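
A sketch of the updated calling convention, under the assumptions that a FrameInfo has already been built (as functional_export.py does below) and that backend_input must now be checked for None:

import torch
from torch._dynamo.convert_frame import FrameInfo, fullgraph_capture

def capture(frame: FrameInfo, constraints=None) -> torch.fx.GraphModule:
    out = fullgraph_capture(
        frame,
        constraints=constraints,  # export dynamic-shape constraints, may be None
        _is_export_deprecated_do_not_use=True,
    )
    if out.backend_input is not None:  # dynamo compiled a graph
        return out.backend_input.graph_module
    # No tensor ops in the frame: synthesize an empty graph, as the export path does.
    gm = torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
    gm.graph.output(None)
    gm.recompile()
    return gm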

View File

@ -1,17 +1,70 @@
import builtins
import inspect
import logging
import traceback
from collections import namedtuple
from typing import Any, Callable
from typing import Any, Callable, Optional, Union
import sympy
import torch
import torch.fx
import torch.utils._pytree as pytree
from torch._dynamo.convert_frame import FrameInfo, fullgraph_capture, get_compile_id
from torch._dynamo.eval_frame import argument_names
from torch._dynamo.utils import dynamo_timed, get_metrics_context
from torch._guards import compile_context, CompileContext
from torch.export.dynamic_shapes import _RelaxedConstraint, Constraint
from torch.fx import Node
from torch.fx.experimental.symbolic_shapes import (
ConstraintViolationError,
DimDynamic,
StatelessSymbolicContext,
)
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
log = logging.getLogger(__name__)
def clean_nn_module_stack(graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
for node in graph_module.graph.nodes:
if "nn_module_stack" in node.meta:
nn_module_stack = node.meta["nn_module_stack"].copy()
first_key = next(iter(nn_module_stack.keys()))
if "export_root" in first_key:
del nn_module_stack[first_key]
nn_module_stack_corrected = {}
for k, v in nn_module_stack.items():
k_new = "".join(k.split("__export_root"))
child_name, child_class = v
child_name = child_name.replace("._export_root", "")
nn_module_stack_corrected[k_new] = (child_name, child_class)
node.meta["nn_module_stack"] = nn_module_stack_corrected
return graph_module
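# Illustrative example of the rewrite above, following the
# "L__self____export_root_param -> L__self___param" naming noted in clean_export_root:
#   {"L__self____export_root_sub": ("L['self']._export_root.sub", Sub)}
# becomes
#   {"L__self___sub": ("L['self'].sub", Sub)}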
def clean_export_root(graph_module: torch.fx.GraphModule) -> None:
"""Remove export_root artifacts from FX graph in-place"""
# Clean parameter names: L__self____export_root_param -> L__self___param
def clean_name(name) -> str:
return name.replace("__export_root_", "_") if "__export_root_" in name else name
# Update get_attr nodes in-place
for node in graph_module.graph.nodes:
if node.op == "get_attr":
old_target = node.target
new_target = clean_name(old_target)
if new_target != old_target:
node.target = new_target
# Move the parameter to the new name
if hasattr(graph_module, old_target):
param = torch.fx.graph_module._get_attr(graph_module, old_target)
torch.fx.graph_module._set_attr(graph_module, new_target, param)
torch.fx.graph_module._del_attr(graph_module, old_target)
class ModuleToTrace(torch.nn.Module):
def __init__(self, foo: Any, in_spec: Any) -> None:
super().__init__()
@ -28,29 +81,214 @@ class ModuleToTrace(torch.nn.Module):
ExportTracerOutput = namedtuple("ExportTracerOutput", ["flat_args", "out_spec"])
# mypy: disable-error-code="no-untyped-def,var-annotated,assignment,index,operator"
class DynamoGraphTransformer(torch.fx.Transformer):
"""Graph transformer for dynamo export that flattens inputs/outputs without complex matching."""
def __init__(
self,
module: torch.fx.GraphModule,
flat_inputs: list[Any],
flat_args_dynamic_dims: list[set[int]],
graph_input_order: dict[int, int],
graph_output_map: dict[int, tuple[str, Any]],
fake_mode: Optional[Any] = None,
) -> None:
super().__init__(module)
assert len(flat_args_dynamic_dims) == len(flat_inputs)
self.flat_inputs = flat_inputs
self.flat_args_dynamic_dims = flat_args_dynamic_dims
self.graph_input_order = graph_input_order
self.graph_output_map = graph_output_map
self.fake_mode = fake_mode
# Get original placeholders and output
self.placeholders = [n for n in module.graph.nodes if n.op == "placeholder"]
self.output_node = next(n for n in module.graph.nodes if n.op == "output")
# Create new flattened input placeholders
self.new_input_nodes: dict[int, torch.fx.Node] = {}
self._create_flattened_inputs()
# Iterator for replacing old placeholders
self.old_to_new_mapping = {}
self._create_placeholder_mapping()
def _create_flattened_inputs(self) -> None:
"""Create new placeholder nodes for flattened inputs with proper fake tensors."""
for i in range(len(self.flat_inputs)):
placeholder = super().placeholder(f"arg_{i}", (), {})
# Check if this user input (index i) maps to a graph placeholder
if i in self.graph_input_order:
# graph_input_order[i] gives us which graph placeholder this user input corresponds to
graph_placeholder_idx = self.graph_input_order[i]
if graph_placeholder_idx < len(self.placeholders):
orig_placeholder = self.placeholders[graph_placeholder_idx]
# Copy other metadata but not "val" yet
for key, value in orig_placeholder.meta.items():
if key != "val":
placeholder.node.meta[key] = value
# Always ensure we have proper "val" metadata from fake tensor
if self.fake_mode is not None and isinstance(
self.flat_inputs[i], torch.Tensor
):
placeholder.node.meta["val"] = self.fake_mode.from_tensor(
self.flat_inputs[i],
symbolic_context=StatelessSymbolicContext(
dynamic_sizes=[
(
DimDynamic.DYNAMIC
if d in self.flat_args_dynamic_dims[i]
else DimDynamic.STATIC
)
for d in range(len(self.flat_inputs[i].shape))
],
constraint_sizes=[None] * len(self.flat_inputs[i].shape),
),
)
elif hasattr(self.flat_inputs[i], "val"): # _IntWrapper case
placeholder.node.meta["val"] = self.flat_inputs[i].val
else:
placeholder.node.meta["val"] = self.flat_inputs[i]
self.new_input_nodes[i] = placeholder
def _create_placeholder_mapping(self) -> None:
"""Create mapping from old placeholders to new ones."""
# graph_input_order maps: user_input_index -> graph_placeholder_index
# We need to create: old_graph_placeholder -> new_user_input_placeholder
for user_input_idx, graph_placeholder_idx in self.graph_input_order.items():
if graph_placeholder_idx < len(self.placeholders):
old_placeholder = self.placeholders[graph_placeholder_idx]
new_placeholder = self.new_input_nodes[user_input_idx]
self.old_to_new_mapping[old_placeholder] = new_placeholder
def placeholder(self, target, args, kwargs) -> Any:
"""Replace old placeholders with new flattened ones."""
# Return the corresponding new placeholder
if self.current_node in self.old_to_new_mapping:
new_arg = self.old_to_new_mapping[self.current_node]
# Copy over additional metadata from current node, but don't overwrite "val"
for key in ["tensor_dict", "example_value", "unbacked_bindings"]:
if key in self.current_node.meta:
new_arg.node.meta[key] = self.current_node.meta[key]
# Only copy "val" if we don't already have a good one
if "val" in self.current_node.meta and "val" not in new_arg.node.meta:
new_arg.node.meta["val"] = self.current_node.meta["val"]
return new_arg
else:
# Shouldn't happen if mapping is correct, but fallback
return super().placeholder(target, args, kwargs)
def output(self, target, args, kwargs) -> Any:
"""Transform output according to graph_output_map."""
original_outputs = args[0]
# Build new output list based on graph_output_map
new_outputs = []
for i in sorted(self.graph_output_map.keys()):
output_type, val = self.graph_output_map[i]
if output_type == "graph_out":
new_outputs.append(original_outputs[val])
elif output_type == "input":
input_idx = val.index
new_outputs.append(self.new_input_nodes[input_idx])
elif output_type == "constant":
new_outputs.append(val)
return super().output(target, (tuple(new_outputs),), {})
def run_node(self, node: Node) -> Any:
"""Run node transformation and preserve metadata."""
self.current_node = node
result = super().run_node(node)
# Copy important metadata
if hasattr(result, "node") and result.node is not node:
for key in ["val", "example_value", "unbacked_bindings"]:
if key in node.meta:
result.node.meta[key] = node.meta[key]
# Preserve node names (except output)
if node.op != "output" and hasattr(node, "name"):
result.node._rename(node.name)
return result
def transform(self) -> torch.fx.GraphModule:
"""Perform the graph transformation and copy module metadata."""
result_gm = super().transform()
# Copy module metadata like the original implementation
if hasattr(self.module, "meta"):
if "dynamo_flat_name_to_original_fqn" in self.module.meta:
result_gm.meta["dynamo_flat_name_to_original_fqn"] = self.module.meta[
"dynamo_flat_name_to_original_fqn"
]
if "dynamo_compile_id" in self.module.meta:
result_gm.meta["dynamo_compile_id"] = self.module.meta[
"dynamo_compile_id"
]
return result_gm
def _dynamo_graph_capture_for_export(
mod: torch.nn.Module,
mod: Callable[..., Any],
*,
constraints: Optional[list[Constraint]] = None,
dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None,
) -> Callable[..., torch.fx.GraphModule]:
"""
This is lower level API that is used for export to capture dynamo level
torch IR.
Improved dynamo graph capture using transformer approach with proper fake tensor handling.
Notable TODOs:
This function creates a capture instance that handles:
1. PyTree flattening/unflattening with proper input ordering
2. Dynamo graph capture with export-specific context
3. FX graph transformation for export compatibility
4. Proper fake tensor metadata preservation
5. Dynamic dimension constraint handling
Notable improvements over manual approach:
- Uses FX Transformer for cleaner graph manipulation
- Properly handles fake tensor metadata and dynamic dimensions
- Preserves all necessary metadata for export
- More robust error handling and edge case management
TODO:
1. Are we actually gonna run the bytecode?
2. Need to attach guards
"""
_dynamic_shapes = dynamic_shapes
_constraints = constraints
def inner(*args: Any, **kwargs: Any) -> torch.fx.GraphModule:
flat_inputs, in_spec = pytree.tree_flatten((args, kwargs))
module_to_trace = ModuleToTrace(mod, in_spec)
signature = inspect.signature(module_to_trace.forward)
bound_arguments = signature.bind(*flat_inputs)
bound_arguments.apply_defaults()
f_locals = {"self": module_to_trace, **bound_arguments.arguments}
constraints: Optional[list[Constraint]] = _constraints
dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = (
_dynamic_shapes
)
from . import reset # type: ignore[attr-defined]
reset()
f_locals = {"self": module_to_trace, **bound_arguments.arguments}
frame = FrameInfo(
module_to_trace.forward.__func__.__code__, # type: ignore[attr-defined]
module_to_trace.forward.__func__.__globals__, # type: ignore[attr-defined]
@ -60,7 +298,14 @@ def _dynamo_graph_capture_for_export(
)
dynamo_config_ctx = torch._dynamo.config.patch(
"log_graph_in_out_metadata", True
specialize_int=True,
specialize_float=True,
assume_static_by_default=True,
automatic_dynamic_shapes=False,
capture_dynamic_output_shape_ops=True,
capture_scalar_outputs=True,
prefer_deferred_runtime_asserts_over_guards=False,
log_graph_in_out_metadata=True,
)
with (
@ -69,74 +314,137 @@ def _dynamo_graph_capture_for_export(
dynamo_timed("fullgraph_capture"),
dynamo_config_ctx,
):
out = fullgraph_capture(frame, _is_export_deprecated_do_not_use=True)
out = fullgraph_capture(
frame,
constraints=_constraints,
_is_export_deprecated_do_not_use=True,
)
assert out.dynamo_output.tracer_output.output_graph is not None
# Extract export metadata from the new location
export_metadata = (
out.dynamo_output.tracer_output.output_graph.export_metadata
)
graph_inputs = export_metadata.graph_input_idx_to_local_source
output_return_type = export_metadata.output_return_type
# We need to extract out_spec here because we are not actually running the bytecode
graph_output_map = export_metadata.output_return_type
out_spec = export_metadata.out_spec
module_call_spec = export_metadata.module_call_spec
example_inputs: list[Any] = []
if out.backend_input is not None:
graph = out.backend_input.graph_module
fake_mode = out.backend_input.fake_mode
example_inputs = out.backend_input.example_inputs
else:
graph = torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
graph.graph.output(None)
graph.recompile()
fake_mode = out.dynamo_output.tracer_output.output_graph.fake_mode
# It is not guaranteed that dynamo puts inputs in the right order, so we need to
# map the actual user order to the dynamo order.
graph_input_order: dict[int, int] = {}
for inp in graph_inputs:
source = graph_inputs[inp]
assert isinstance(source, torch._dynamo.source.GetItemSource)
graph_input_order[source.index] = len(graph_input_order)
# Compute dynamic dimensions for each input based on constraints
flat_args_dynamic_dims = [
{
c.dim
for c in (constraints or ())
if (
c.t_id == id(x)
and not isinstance(c, _RelaxedConstraint)
and c.constraint_range.vr.lower != c.constraint_range.vr.upper
)
}
for x in flat_inputs
]
placeholders = [n for n in list(graph.graph.nodes) if n.op == "placeholder"]
output = next(n for n in list(graph.graph.nodes) if n.op == "output")
# Sometimes there can be empty inputs
anchor = placeholders[0] if len(placeholders) > 0 else output
inp_to_node = {}
# Create input order mapping from dynamo's internal order to user order
graph_input_order: dict[int, int] = {}
for inp in graph_inputs:
source = graph_inputs[inp]
assert isinstance(source, torch._dynamo.source.GetItemSource)
graph_input_order[source.index] = len(graph_input_order)
with graph.graph.inserting_before(anchor):
for i in range(len(flat_inputs)):
node_new = graph.graph.placeholder(f"arg_{i}")
if i in graph_input_order:
placeholders[graph_input_order[i]]
node_new.meta = placeholders[graph_input_order[i]].meta.copy()
inp_to_node[i] = node_new
for real_idx, graph_idx in graph_input_order.items():
flat_inputs[real_idx] = example_inputs[graph_idx]
new_args = []
for i in output_return_type:
type, val = output_return_type[i]
if type == "graph_out":
new_args.append(output.args[0][val])
if type == "input":
input_idx = val.index
new_args.append(inp_to_node[input_idx])
if type == "constant":
new_args.append(val)
output.args = (tuple(new_args),)
# Use FX transformer to rebuild the graph cleanly
transformed_graph = DynamoGraphTransformer(
graph,
flat_inputs,
flat_args_dynamic_dims,
graph_input_order,
graph_output_map,
fake_mode,
).transform()
for src_idx, i in graph_input_order.items():
old = placeholders[src_idx]
new = inp_to_node[i]
old.replace_all_uses_with(new)
graph.graph.erase_node(old)
# Dynamo uses _LazyGraphModule, so we need to force recompile
from torch.fx._lazy_graph_module import _LazyGraphModule
_LazyGraphModule.force_recompile(graph)
graph.graph._codegen = _PyTreeCodeGen(
# Set up PyTree codegen for proper input/output handling
transformed_graph.graph._codegen = _PyTreeCodeGen(
_PyTreeInfo(
argument_names(signature, args, kwargs), # type: ignore[arg-type]
argument_names(inspect.signature(mod.forward), args, kwargs), # type: ignore[attr-defined, arg-type]
in_spec,
out_spec,
)
)
transformed_graph.recompile()
graph.recompile()
return graph
clean_nn_module_stack(transformed_graph)
clean_export_root(transformed_graph)
transformed_graph.meta["module_call_specs"] = module_call_spec
constraint_violation_error = None
try:
# Check if we have any constraint violations
check_fn = out.dynamo_output.build_guards(
module_to_trace.forward.__code__
).guard_manager
check_fn.check(f_locals)
except (
ConstraintViolationError,
torch.utils._sympy.value_ranges.ValueRangeError,
) as e:
constraint_violation_error = e
if (
(shape_env := getattr(fake_mode, "shape_env", None)) is not None
and (dim_constraints := shape_env.dim_constraints) is not None
and not isinstance(
module_to_trace.forward,
(torch._ops.OpOverloadPacket, torch._ops.OpOverload),
)
):
dim_constraints.solve()
forced_specializations = dim_constraints.forced_specializations()
msg = dim_constraints.prettify_results(
inspect.signature(mod.forward), # type: ignore[attr-defined]
dynamic_shapes,
constraint_violation_error,
forced_specializations,
)
if constraint_violation_error:
constraint_violation_error.args = (
constraint_violation_error.args[0] + msg,
)
else:
if forced_specializations:
constraint_violation_error = ConstraintViolationError(msg)
else:
log.info(
"Summary of dimension constraints:%s",
msg,
)
# Error if we have any constraints on static values
for k in shape_env.var_to_range.keys():
if isinstance(k, sympy.Integer):
constraint_violation_error = ConstraintViolationError(
f"{''.join(traceback.format_list(shape_env.var_to_stack[k]))}\n"
"It appears that you're trying to set a constraint on a "
f"value which we evaluated to have a static value of {k}. "
'Set TORCH_LOGS="+export" for more information.'
)
if constraint_violation_error:
raise constraint_violation_error
return transformed_graph
return inner
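
A minimal usage sketch for the rewritten entry point, mirroring the call site added in torch/export/_trace.py further down in this diff; the toy module is hypothetical:

import torch
from torch._dynamo.functional_export import _dynamo_graph_capture_for_export

class M(torch.nn.Module):  # hypothetical toy module
    def forward(self, x):
        return x.sin()

gm = _dynamo_graph_capture_for_export(M())(torch.randn(2, 2))
# gm is a torch-IR-level GraphModule with PyTree codegen installed and
# gm.meta["module_call_specs"] populated from dynamo's export metadata.
print(gm.graph)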

View File

@ -379,6 +379,10 @@ class ExportMetaData:
out_spec: Union[torch.utils._pytree.TreeSpec, torch.utils._pytree.LeafSpec] = (
torch.utils._pytree._LEAF_SPEC
)
module_call_spec: dict[
str,
dict[str, Union[torch.utils._pytree.TreeSpec, torch.utils._pytree.LeafSpec]],
] = dc_field(default_factory=dict)
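# Illustrative shape of the collected specs (the submodule FQN is hypothetical):
#   {"sub.linear": {"in_spec": <TreeSpec>, "out_spec": <TreeSpec>}}
# Populated below in OutputGraph from _ExportModuleSpecTrackerDict mutations seen
# during tracing.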
def get_builtins_dict(global_scope: Scope) -> dict[str, Any]:
@ -1695,6 +1699,19 @@ class OutputGraph(OutputGraphGuardsState):
if isinstance(
mut_type, (AttributeMutationExisting, ValueMutationExisting)
):
if isinstance(var, UserDefinedDictVariable) and isinstance(
var.value, _ExportModuleSpecTrackerDict
):
for k, v in var.items.items():
specs = {}
for k_spec, val in v.items.items():
specs[k_spec.vt.as_python_constant()] = (
val.as_python_constant()
)
assert ["in_spec", "out_spec"] == list(specs.keys())
self.export_metadata.module_call_spec[
k.vt.as_python_constant()
] = specs
# export uses tracepoint pass to dump submodule inp/out spec
# into global state, so we filter it here
if not (

View File

@ -70,6 +70,7 @@ def export_for_training(
strict: bool = False,
preserve_module_call_signature: tuple[str, ...] = (),
prefer_deferred_runtime_asserts_over_guards: bool = False,
_use_new_tracer_experimental: bool = False,
) -> ExportedProgram:
"""
:func:`export_for_training` takes any nn.Module along with example inputs, and produces a traced graph representing
@ -159,6 +160,7 @@ def export_for_training(
strict=strict,
preserve_module_call_signature=preserve_module_call_signature,
prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
_use_new_tracer_experimental=_use_new_tracer_experimental,
)
@ -171,6 +173,7 @@ def export(
strict: bool = False,
preserve_module_call_signature: tuple[str, ...] = (),
prefer_deferred_runtime_asserts_over_guards: bool = False,
_use_new_tracer_experimental: bool = False,
) -> ExportedProgram:
"""
:func:`export` takes any nn.Module along with example inputs, and produces a traced graph representing
@ -283,6 +286,7 @@ def export(
preserve_module_call_signature=preserve_module_call_signature,
pre_dispatch=True,
prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
_use_new_tracer_experimental=_use_new_tracer_experimental,
)
except Exception as e:
draft_export_msg = (

View File

@ -757,6 +757,7 @@ def _export_to_torch_ir(
preserve_module_call_signature: tuple[str, ...] = (),
disable_constraint_solver: bool = False,
prefer_deferred_runtime_asserts_over_guards: bool = False,
_use_new_tracer_experimental: bool = False,
restore_fqn: bool = True,
_log_export_usage: bool = True,
same_signature: bool = True,
@ -809,20 +810,31 @@ def _export_to_torch_ir(
f, preserve_module_call_signature, module_call_specs
)
with ctx, _ignore_backend_decomps():
gm_torch_level, _ = torch._dynamo.export(
f,
dynamic_shapes=dynamic_shapes, # type: ignore[arg-type]
constraints=constraints, # type: ignore[arg-type]
assume_static_by_default=True,
tracing_mode="symbolic",
disable_constraint_solver=disable_constraint_solver,
prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
_log_export_usage=_log_export_usage,
same_signature=same_signature,
)(
*args,
**kwargs,
)
if _use_new_tracer_experimental:
from torch._dynamo.functional_export import (
_dynamo_graph_capture_for_export,
)
gm_torch_level = _dynamo_graph_capture_for_export(
f, constraints=constraints, dynamic_shapes=dynamic_shapes
)(*args, **kwargs)
else:
gm_torch_level, _ = torch._dynamo.export(
f,
dynamic_shapes=dynamic_shapes, # type: ignore[arg-type]
constraints=constraints, # type: ignore[arg-type]
assume_static_by_default=True,
tracing_mode="symbolic",
disable_constraint_solver=disable_constraint_solver,
prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
_log_export_usage=_log_export_usage,
same_signature=same_signature,
)(
*args,
**kwargs,
)
gm_torch_level.meta["module_call_specs"] = module_call_specs
except (ConstraintViolationError, ValueRangeError) as e:
raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e)) # noqa: B904
except GuardOnDataDependentSymNode as e:
@ -832,8 +844,6 @@ def _export_to_torch_ir(
case_name="constrain_as_size_example",
)
gm_torch_level.meta["module_call_specs"] = module_call_specs
if isinstance(f, torch.nn.Module) and restore_fqn:
_restore_state_dict(f, gm_torch_level)
@ -1407,6 +1417,7 @@ def _strict_export(
orig_in_spec: TreeSpec,
prefer_deferred_runtime_asserts_over_guards: bool,
_to_aten_func: Callable,
_use_new_tracer_experimental: bool = False,
) -> ExportArtifact:
"""
_to_aten_func can either be `_export_to_aten_ir_make_fx` or `_export_to_aten_ir`
@ -1421,6 +1432,7 @@ def _strict_export(
restore_fqn=False, # don't need to restore because we will do it later
prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
_log_export_usage=False,
_use_new_tracer_experimental=_use_new_tracer_experimental,
)
# We detect the fake_mode by looking at gm_torch_level's placeholders, this is the fake_mode created in dynamo.
@ -2041,6 +2053,7 @@ def _export_for_training(
strict: bool = True,
preserve_module_call_signature: tuple[str, ...] = (),
prefer_deferred_runtime_asserts_over_guards: bool = False,
_use_new_tracer_experimental: bool = False,
) -> ExportedProgram:
global _EXPORT_MODULE_HIERARCHY
_EXPORT_MODULE_HIERARCHY = _get_module_hierarchy(mod)
@ -2056,7 +2069,13 @@ def _export_for_training(
original_state_dict = _get_original_state_dict(mod)
# Call the appropriate export function based on the strictness of tracing.
export_func = _strict_export if strict else _non_strict_export
export_func = (
functools.partial(
_strict_export, _use_new_tracer_experimental=_use_new_tracer_experimental
)
if strict
else _non_strict_export
)
alive_fake_input_ids_before_export: list[int] = []
@ -2185,6 +2204,7 @@ def _export(
preserve_module_call_signature: tuple[str, ...] = (),
pre_dispatch: bool = False,
prefer_deferred_runtime_asserts_over_guards: bool = False,
_use_new_tracer_experimental: bool = False,
) -> ExportedProgram:
"""
Traces either an nn.Module's forward function or just a callable with PyTorch
@ -2260,6 +2280,7 @@ def _export(
strict=strict,
preserve_module_call_signature=preserve_module_call_signature,
prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
_use_new_tracer_experimental=_use_new_tracer_experimental,
)
dtrace_structured("exported_program", payload_fn=lambda: str(ep))
return ep

View File

@ -302,6 +302,12 @@ def _has_attr(model: torch.nn.Module, attr_name: str):
return hasattr(t, field)
def _set_attr(model: torch.nn.Module, attr_name: str, value):
attr_names = attr_name.split(".")
t = _get_attr_via_attr_list(model, attr_names[:-1])
setattr(t, attr_names[-1], value)
def _print_readable(
module,
module_name,