testing infra and some fixes (#162183)
This PR is quite large in that it covers most of the rough edges in the new strict export flow:
1. Handle nn_module_stack correctly now that we are tracing a wrapper module.
2. module_call_spec needs to be queried from the source directly, because we are no longer running the bytecode.
3. Correct input and output handling.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162183
Approved by: https://github.com/zhxchen17
ghstack dependencies: #162167
This commit is contained in: d8b6622bb6 (parent: a965f09793), committed by PyTorch MergeBot.
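For orientation before the diff (not part of the commit): the new flow is opted into through the experimental flag this stack threads from torch.export.export down to dynamo. A minimal sketch assuming only the public export API plus the flag added here; the module M is illustrative:

import torch

class M(torch.nn.Module):
    def forward(self, x):
        return x + 1

# strict=True plus the experimental flag routes tracing through the new
# dynamo-based strict export flow introduced by this stack.
ep = torch.export.export(
    M(), (torch.randn(3),), strict=True, _use_new_tracer_experimental=True
)
print(ep.module()(torch.randn(3)))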
test/export/test_export.py

@@ -217,12 +217,19 @@ TRAINING_IR_DECOMP_STRICT_SUFFIX = "_training_ir_to_decomp_strict"
 TRAINING_IR_DECOMP_NON_STRICT_SUFFIX = "_training_ir_to_decomp_nonstrict"
 CPP_RUNTIME_STRICT_SUFFIX = "_cpp_runtime_strict"
 CPP_RUNTIME_NONSTRICT_SUFFIX = "_cpp_runtime_nonstrict"
+STRICT_EXPORT_V2_SUFFIX = "_strict_export_v2"


 # Now the default mode is non-strict, so original unamended test names
 # should be treated as non-strict
 def is_non_strict_test(test_name):
-    return not test_name.endswith(STRICT_SUFFIX)
+    return not test_name.endswith(STRICT_SUFFIX) and not test_name.endswith(
+        STRICT_EXPORT_V2_SUFFIX
+    )
+
+
+def is_strict_v2_test(test_name):
+    return test_name.endswith(STRICT_EXPORT_V2_SUFFIX)


 def is_inline_and_install_strict_test(test_name: str) -> bool:
@@ -9736,6 +9743,7 @@ graph():
         inputs = {"x": torch.randn(3, 3), "y": torch.randn(3, 3)}
         self.assertEqual(ep.module()(**inputs), m(**inputs))

+    @testing.expectedFailureStrictV2  # AssertionError: RuntimeError not raised
     def test_retrace_pre_autograd(self):
         class Foo(torch.nn.Module):
             def __init__(self) -> None:
@@ -11764,6 +11772,7 @@ graph():
         self.assertEqual(ep.module()(3, 5), 8)
         self.assertEqual(ep.module()(5, 4), 9)

+    @testing.expectedFailureStrictV2  # ValueError: Found conflicts between user-specified and inferred ranges
     def test_dynamic_shapes_bounds(self):
         class M(torch.nn.Module):
             """
@@ -12070,6 +12079,8 @@ graph():

         test(export(M(), inp))

+    # Preserving signature hook is messing with dynamo tracing
+    @testing.expectedFailureStrictV2
     def test_unflatten_multiple_graphs_state(self):
         class N(torch.nn.Module):
             def __init__(self):
@@ -13688,7 +13699,7 @@ def forward(self, x, y):

         inputs = (torch.randn(10, 72),)
         dx, dy = dims("dx", "dy")
-        ep = torch.export.export(
+        ep = torch.export._trace._export(
             Mod4Reshape(),
             inputs,
             dynamic_shapes={"x": (dx, dy)},
@@ -14536,6 +14547,14 @@ graph():
         if is_inline_and_install_strict_test(self._testMethodName):
             self.assertEqual(filtered_nn_module_stack[0], "mod_list_1.2")
             self.assertEqual(filtered_nn_module_stack[1], "mod_list_1.2")
+        # This is fine since both of these will be deprecated soon.
+        elif is_strict_v2_test(self._testMethodName) and IS_FBCODE:
+            self.assertEqual(
+                filtered_nn_module_stack[0], "mod_list_1.slice(2, 3, None).0"
+            )
+            self.assertEqual(
+                filtered_nn_module_stack[1], "mod_list_2.slice(4, 5, None).0"
+            )
         else:
             self.assertEqual(
                 filtered_nn_module_stack[0], "mod_list_1.slice(2, 3, None).2"
@@ -15374,6 +15393,7 @@ class GraphModule(torch.nn.Module):
         )

     @testing.expectedFailureStrict  # test_hop doesn't have a dynamo implementation
+    @testing.expectedFailureStrictV2  # test_hop doesn't have a dynamo implementation
     @testing.expectedFailureRetraceability  # test_hop doesn't have a dynamo implementation
     @testing.expectedFailureTrainingIRToRunDecomp  # test_hop doesn't have a dynamo implementation
     @testing.expectedFailureSerDerNonStrict  # TODO: serde torch.FunctionSchema is not implemented yet
test/export/test_strict_export_v2.py (new file, 54 lines)
@@ -0,0 +1,54 @@
+# Owner(s): ["oncall: export"]
+
+try:
+    from . import test_export, testing
+except ImportError:
+    import test_export  # @manual=fbcode//caffe2/test:test_export-library
+    import testing  # @manual=fbcode//caffe2/test:test_export-library
+
+from torch.export import export
+
+
+test_classes = {}
+
+
+def mocked_strict_export_v2(*args, **kwargs):
+    # If user already specified strict, don't make it strict
+    if "strict" in kwargs:
+        if kwargs["strict"]:
+            return export(*args, **kwargs, _use_new_tracer_experimental=True)
+        else:
+            return export(*args, **kwargs)
+    return export(*args, **kwargs, strict=True, _use_new_tracer_experimental=True)
+
+
+def make_dynamic_cls(cls):
+    cls_prefix = "StrictExportV2"
+
+    test_class = testing.make_test_cls_with_mocked_export(
+        cls,
+        cls_prefix,
+        test_export.STRICT_EXPORT_V2_SUFFIX,
+        mocked_strict_export_v2,
+        xfail_prop="_expected_failure_strict_v2",
+    )
+
+    test_classes[test_class.__name__] = test_class
+    # REMOVING THIS LINE WILL STOP TESTS FROM RUNNING
+    globals()[test_class.__name__] = test_class
+    test_class.__module__ = __name__
+    return test_class
+
+
+tests = [
+    test_export.TestDynamismExpression,
+    test_export.TestExport,
+]
+for test in tests:
+    make_dynamic_cls(test)
+del test
+
+if __name__ == "__main__":
+    from torch._dynamo.test_case import run_tests
+
+    run_tests()
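For readers unfamiliar with this test-generation pattern: make_test_cls_with_mocked_export clones each TestCase with export swapped for the mock above. A conceptual sketch only (the helper's real internals live in test/export/testing.py; the names below are illustrative, not the actual implementation):

def sketch_make_test_cls(cls, prefix, suffix, mocked_export):
    # Copy every test_* method under a suffixed name onto a new subclass, so
    # e.g. TestExport.test_basic would run as
    # StrictExportV2TestExport.test_basic_strict_export_v2.
    body = {
        name + suffix: getattr(cls, name)
        for name in dir(cls)
        if name.startswith("test_")
    }
    return type(prefix + cls.__name__, (cls,), body)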
test/export/testing.py

@@ -257,6 +257,12 @@ def expectedFailureTrainingIRToRunDecompNonStrict(fn):
     return fn


+# Controls tests generated in test/export/test_export_strict_v2.py
+def expectedFailureStrictV2(fn):
+    fn._expected_failure_strict_v2 = True
+    return fn
+
+
 # Controls tests generated in test/export/test_export_strict.py
 def expectedFailureStrict(fn):
     fn._expected_failure_strict = True
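The decorator only tags the function; the xfail_prop argument passed to make_test_cls_with_mocked_export (see the new test file above) is what turns the tag into an expected failure. A minimal sketch of that gating, with sketch_apply_xfail as a hypothetical name:

import unittest

def sketch_apply_xfail(fn, xfail_prop="_expected_failure_strict_v2"):
    # Tests tagged by @expectedFailureStrictV2 carry this attribute; the
    # generated copy is then wrapped so it is allowed to fail.
    return unittest.expectedFailure(fn) if getattr(fn, xfail_prop, False) else fn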
torch/_dynamo/aot_compile.py

@@ -258,6 +258,12 @@ def aot_compile_fullgraph(
     backend_input.graph_module._backend_id = backend_input.backend_id  # type: ignore[assignment]
     output_graph = dynamo_output.tracer_output.output_graph
     assert output_graph is not None
+    assert backend_input is not None
     import_sources = output_graph.import_sources
     with (
         torch._guards.tracing(TracingContext(backend_input.fake_mode)),
torch/_dynamo/convert_frame.py

@@ -95,6 +95,7 @@ from .cache_size import (
 )
 from .eval_frame import (
     always_optimize_code_objects,
+    Constraint,
     dynamo_tls,
     skip_code,
     TorchPatcher,

@@ -894,7 +895,8 @@ class CaptureOutput:
     """

     dynamo_output: DynamoOutput
-    backend_input: BackendInput
+    # BackendInput can be None when dynamo didn't compile any graph (no tensor op)
+    backend_input: Optional[BackendInput]


 @dataclass

@@ -907,7 +909,10 @@ class FrameInfo:


 def fullgraph_capture(
-    frame: FrameInfo, *, _is_export_deprecated_do_not_use: bool = False
+    frame: FrameInfo,
+    *,
+    constraints: Optional[list[Constraint]] = None,
+    _is_export_deprecated_do_not_use: bool = False,
 ) -> CaptureOutput:
     """
     A standalone function which takes a frame and returns dynamo captured graph

@@ -951,6 +956,7 @@ def fullgraph_capture(
         frame.closure,
         compiler_fn=fullgraph_compiler,
         export=_is_export_deprecated_do_not_use,
+        export_constraints=constraints,  # type: ignore[arg-type]
         one_graph=True,
         restart_reasons=set(),
     )

@@ -966,7 +972,6 @@ def fullgraph_capture(
             cur_exn = cur_exn.__cause__
         raise e.with_traceback(None) from e.__cause__  # User compiler error

-    assert backend_input is not None
     return CaptureOutput(dynamo_output, backend_input)
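Because backend_input is now Optional, downstream consumers must handle a graphless capture (a frame with no tensor ops). A sketch of the consumer-side branch, mirroring what functional_export below does; frame and constraints are assumed to be in scope:

out = fullgraph_capture(
    frame, constraints=constraints, _is_export_deprecated_do_not_use=True
)
if out.backend_input is not None:
    graph = out.backend_input.graph_module
    fake_mode = out.backend_input.fake_mode
else:
    # Nothing was compiled: fall back to an empty graph that returns None.
    graph = torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
    graph.graph.output(None)
    graph.recompile()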
torch/_dynamo/functional_export.py

@@ -1,17 +1,70 @@
 import builtins
 import inspect
 import logging
 import traceback
 from collections import namedtuple
-from typing import Any, Callable
+from typing import Any, Callable, Optional, Union

+import sympy
+
 import torch
 import torch.fx
 import torch.utils._pytree as pytree
 from torch._dynamo.convert_frame import FrameInfo, fullgraph_capture, get_compile_id
 from torch._dynamo.eval_frame import argument_names
 from torch._dynamo.utils import dynamo_timed, get_metrics_context
 from torch._guards import compile_context, CompileContext
+from torch.export.dynamic_shapes import _RelaxedConstraint, Constraint
+from torch.fx import Node
+from torch.fx.experimental.symbolic_shapes import (
+    ConstraintViolationError,
+    DimDynamic,
+    StatelessSymbolicContext,
+)
 from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo


+log = logging.getLogger(__name__)
+
+
+def clean_nn_module_stack(graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    for node in graph_module.graph.nodes:
+        if "nn_module_stack" in node.meta:
+            nn_module_stack = node.meta["nn_module_stack"].copy()
+            first_key = next(iter(nn_module_stack.keys()))
+            if "export_root" in first_key:
+                del nn_module_stack[first_key]
+                nn_module_stack_corrected = {}
+                for k, v in nn_module_stack.items():
+                    k_new = "".join(k.split("__export_root"))
+                    child_name, child_class = v
+                    child_name = child_name.replace("._export_root", "")
+                    nn_module_stack_corrected[k_new] = (child_name, child_class)
+                node.meta["nn_module_stack"] = nn_module_stack_corrected
+    return graph_module
+
+
+def clean_export_root(graph_module: torch.fx.GraphModule) -> None:
+    """Remove export_root artifacts from FX graph in-place"""
+
+    # Clean parameter names: L__self____export_root_param -> L__self___param
+    def clean_name(name) -> str:
+        return name.replace("__export_root_", "_") if "__export_root_" in name else name
+
+    # Update get_attr nodes in-place
+    for node in graph_module.graph.nodes:
+        if node.op == "get_attr":
+            old_target = node.target
+            new_target = clean_name(old_target)
+            if new_target != old_target:
+                node.target = new_target
+                # Move the parameter to the new name
+                if hasattr(graph_module, old_target):
+                    param = torch.fx.graph_module._get_attr(graph_module, old_target)
+                    torch.fx.graph_module._set_attr(graph_module, new_target, param)
+                    torch.fx.graph_module._del_attr(graph_module, old_target)
+
+
 class ModuleToTrace(torch.nn.Module):
     def __init__(self, foo: Any, in_spec: Any) -> None:
         super().__init__()
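To make the renaming concrete: the tracer wraps the user module under an _export_root attribute, and the cleaners above scrub that token out of nn_module_stack keys and parameter FQNs. A self-contained check of the substitution clean_export_root uses:

def clean_name(name: str) -> str:
    # Same rule as in clean_export_root above.
    return name.replace("__export_root_", "_") if "__export_root_" in name else name

# L__self____export_root_param -> L__self___param, per the comment above.
assert clean_name("L__self____export_root_param") == "L__self___param"
assert clean_name("L__self___other") == "L__self___other"  # untouched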
@@ -28,29 +81,214 @@ class ModuleToTrace(torch.nn.Module):
 ExportTracerOutput = namedtuple("ExportTracerOutput", ["flat_args", "out_spec"])


+# mypy: disable-error-code="no-untyped-def,var-annotated,assignment,index,operator"
+class DynamoGraphTransformer(torch.fx.Transformer):
+    """Graph transformer for dynamo export that flattens inputs/outputs without complex matching."""
+
+    def __init__(
+        self,
+        module: torch.fx.GraphModule,
+        flat_inputs: list[Any],
+        flat_args_dynamic_dims: list[set[int]],
+        graph_input_order: dict[int, int],
+        graph_output_map: dict[int, tuple[str, Any]],
+        fake_mode: Optional[Any] = None,
+    ) -> None:
+        super().__init__(module)
+
+        assert len(flat_args_dynamic_dims) == len(flat_inputs)
+
+        self.flat_inputs = flat_inputs
+        self.flat_args_dynamic_dims = flat_args_dynamic_dims
+        self.graph_input_order = graph_input_order
+        self.graph_output_map = graph_output_map
+        self.fake_mode = fake_mode
+
+        # Get original placeholders and output
+        self.placeholders = [n for n in module.graph.nodes if n.op == "placeholder"]
+        self.output_node = next(n for n in module.graph.nodes if n.op == "output")
+
+        # Create new flattened input placeholders
+        self.new_input_nodes: dict[int, torch.fx.Node] = {}
+        self._create_flattened_inputs()
+
+        # Iterator for replacing old placeholders
+        self.old_to_new_mapping = {}
+        self._create_placeholder_mapping()
+
+    def _create_flattened_inputs(self) -> None:
+        """Create new placeholder nodes for flattened inputs with proper fake tensors."""
+        for i in range(len(self.flat_inputs)):
+            placeholder = super().placeholder(f"arg_{i}", (), {})
+
+            # Check if this user input (index i) maps to a graph placeholder
+            if i in self.graph_input_order:
+                # graph_input_order[i] gives us which graph placeholder this user input corresponds to
+                graph_placeholder_idx = self.graph_input_order[i]
+                if graph_placeholder_idx < len(self.placeholders):
+                    orig_placeholder = self.placeholders[graph_placeholder_idx]
+                    # Copy other metadata but not "val" yet
+                    for key, value in orig_placeholder.meta.items():
+                        if key != "val":
+                            placeholder.node.meta[key] = value
+
+            # Always ensure we have proper "val" metadata from fake tensor
+            if self.fake_mode is not None and isinstance(
+                self.flat_inputs[i], torch.Tensor
+            ):
+                placeholder.node.meta["val"] = self.fake_mode.from_tensor(
+                    self.flat_inputs[i],
+                    symbolic_context=StatelessSymbolicContext(
+                        dynamic_sizes=[
+                            (
+                                DimDynamic.DYNAMIC
+                                if d in self.flat_args_dynamic_dims[i]
+                                else DimDynamic.STATIC
+                            )
+                            for d in range(len(self.flat_inputs[i].shape))
+                        ],
+                        constraint_sizes=[None] * len(self.flat_inputs[i].shape),
+                    ),
+                )
+            elif hasattr(self.flat_inputs[i], "val"):  # _IntWrapper case
+                placeholder.node.meta["val"] = self.flat_inputs[i].val
+            else:
+                placeholder.node.meta["val"] = self.flat_inputs[i]
+
+            self.new_input_nodes[i] = placeholder
+
+    def _create_placeholder_mapping(self) -> None:
+        """Create mapping from old placeholders to new ones."""
+        # graph_input_order maps: user_input_index -> graph_placeholder_index
+        # We need to create: old_graph_placeholder -> new_user_input_placeholder
+        for user_input_idx, graph_placeholder_idx in self.graph_input_order.items():
+            if graph_placeholder_idx < len(self.placeholders):
+                old_placeholder = self.placeholders[graph_placeholder_idx]
+                new_placeholder = self.new_input_nodes[user_input_idx]
+                self.old_to_new_mapping[old_placeholder] = new_placeholder
+
+    def placeholder(self, target, args, kwargs) -> Any:
+        """Replace old placeholders with new flattened ones."""
+        # Return the corresponding new placeholder
+        if self.current_node in self.old_to_new_mapping:
+            new_arg = self.old_to_new_mapping[self.current_node]
+
+            # Copy over additional metadata from current node, but don't overwrite "val"
+            for key in ["tensor_dict", "example_value", "unbacked_bindings"]:
+                if key in self.current_node.meta:
+                    new_arg.node.meta[key] = self.current_node.meta[key]
+
+            # Only copy "val" if we don't already have a good one
+            if "val" in self.current_node.meta and "val" not in new_arg.node.meta:
+                new_arg.node.meta["val"] = self.current_node.meta["val"]
+
+            return new_arg
+        else:
+            # Shouldn't happen if mapping is correct, but fallback
+            return super().placeholder(target, args, kwargs)
+
+    def output(self, target, args, kwargs) -> Any:
+        """Transform output according to graph_output_map."""
+        original_outputs = args[0]
+
+        # Build new output list based on graph_output_map
+        new_outputs = []
+        for i in sorted(self.graph_output_map.keys()):
+            output_type, val = self.graph_output_map[i]
+
+            if output_type == "graph_out":
+                new_outputs.append(original_outputs[val])
+            elif output_type == "input":
+                input_idx = val.index
+                new_outputs.append(self.new_input_nodes[input_idx])
+            elif output_type == "constant":
+                new_outputs.append(val)
+
+        return super().output(target, (tuple(new_outputs),), {})
+
+    def run_node(self, node: Node) -> Any:
+        """Run node transformation and preserve metadata."""
+        self.current_node = node
+        result = super().run_node(node)
+
+        # Copy important metadata
+        if hasattr(result, "node") and result.node is not node:
+            for key in ["val", "example_value", "unbacked_bindings"]:
+                if key in node.meta:
+                    result.node.meta[key] = node.meta[key]
+
+        # Preserve node names (except output)
+        if node.op != "output" and hasattr(node, "name"):
+            result.node._rename(node.name)
+
+        return result
+
+    def transform(self) -> torch.fx.GraphModule:
+        """Perform the graph transformation and copy module metadata."""
+        result_gm = super().transform()
+
+        # Copy module metadata like the original implementation
+        if hasattr(self.module, "meta"):
+            if "dynamo_flat_name_to_original_fqn" in self.module.meta:
+                result_gm.meta["dynamo_flat_name_to_original_fqn"] = self.module.meta[
+                    "dynamo_flat_name_to_original_fqn"
+                ]
+            if "dynamo_compile_id" in self.module.meta:
+                result_gm.meta["dynamo_compile_id"] = self.module.meta[
+                    "dynamo_compile_id"
+                ]
+
+        return result_gm
+
+
 def _dynamo_graph_capture_for_export(
-    mod: torch.nn.Module,
+    mod: Callable[..., Any],
+    *,
+    constraints: Optional[list[Constraint]] = None,
+    dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None,
 ) -> Callable[..., torch.fx.GraphModule]:
     """
-    This is lower level API that is used for export to capture dynamo level
-    torch IR.
+    Improved dynamo graph capture using transformer approach with proper fake tensor handling.

-    Notable TODOs:
+    This function creates a capture instance that handles:
+    1. PyTree flattening/unflattening with proper input ordering
+    2. Dynamo graph capture with export-specific context
+    3. FX graph transformation for export compatibility
+    4. Proper fake tensor metadata preservation
+    5. Dynamic dimension constraint handling
+
+    Notable improvements over manual approach:
+    - Uses FX Transformer for cleaner graph manipulation
+    - Properly handles fake tensor metadata and dynamic dimensions
+    - Preserves all necessary metadata for export
+    - More robust error handling and edge case management
+
+    TODO:
     1. Are we actually gonna run the bytecode?
     2. Need to attach guards
     """

+    _dynamic_shapes = dynamic_shapes
+    _constraints = constraints
+
     def inner(*args: Any, **kwargs: Any) -> torch.fx.GraphModule:
         flat_inputs, in_spec = pytree.tree_flatten((args, kwargs))
         module_to_trace = ModuleToTrace(mod, in_spec)

         signature = inspect.signature(module_to_trace.forward)

         bound_arguments = signature.bind(*flat_inputs)
         bound_arguments.apply_defaults()

-        f_locals = {"self": module_to_trace, **bound_arguments.arguments}
+        constraints: Optional[list[Constraint]] = _constraints
+        dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = (
+            _dynamic_shapes
+        )
+
+        from . import reset  # type: ignore[attr-defined]
+
+        reset()
+
+        f_locals = {"self": module_to_trace, **bound_arguments.arguments}
         frame = FrameInfo(
             module_to_trace.forward.__func__.__code__,  # type: ignore[attr-defined]
             module_to_trace.forward.__func__.__globals__,  # type: ignore[attr-defined]
@@ -60,7 +298,14 @@ def _dynamo_graph_capture_for_export(
         )

         dynamo_config_ctx = torch._dynamo.config.patch(
-            "log_graph_in_out_metadata", True
+            specialize_int=True,
+            specialize_float=True,
+            assume_static_by_default=True,
+            automatic_dynamic_shapes=False,
+            capture_dynamic_output_shape_ops=True,
+            capture_scalar_outputs=True,
+            prefer_deferred_runtime_asserts_over_guards=False,
+            log_graph_in_out_metadata=True,
         )

         with (
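These are the usual export-flavored dynamo settings (everything static by default, scalar and dynamic-output capture enabled), replacing the single positional patch of log_graph_in_out_metadata. For reference, config.patch accepts keyword form and works as a standalone context manager; a small sketch with an arbitrary workload:

import torch

with torch._dynamo.config.patch(
    assume_static_by_default=True, capture_scalar_outputs=True
):
    fn = torch.compile(lambda x: x.sum().item() + 1)
    fn(torch.randn(4))  # scalar .item() capture is allowed under this patch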
@@ -69,74 +314,137 @@ def _dynamo_graph_capture_for_export(
             dynamo_timed("fullgraph_capture"),
             dynamo_config_ctx,
         ):
-            out = fullgraph_capture(frame, _is_export_deprecated_do_not_use=True)
+            out = fullgraph_capture(
+                frame,
+                constraints=_constraints,
+                _is_export_deprecated_do_not_use=True,
+            )

         assert out.dynamo_output.tracer_output.output_graph is not None

         # Extract export metadata from the new location
         export_metadata = (
             out.dynamo_output.tracer_output.output_graph.export_metadata
         )
         graph_inputs = export_metadata.graph_input_idx_to_local_source
-        output_return_type = export_metadata.output_return_type
+        # We need to extract out_spec here because we are not actually running the bytecode
+        graph_output_map = export_metadata.output_return_type
+        out_spec = export_metadata.out_spec
+        module_call_spec = export_metadata.module_call_spec

+        example_inputs: list[Any] = []
         if out.backend_input is not None:
             graph = out.backend_input.graph_module
             fake_mode = out.backend_input.fake_mode
+            example_inputs = out.backend_input.example_inputs
         else:
             graph = torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
             graph.graph.output(None)
             graph.recompile()
             fake_mode = out.dynamo_output.tracer_output.output_graph.fake_mode

-        # It is not guaranteed that dynamo puts inputs in right order, so we need to
-        # map the actual user order to the dynamo order.
-        graph_input_order: dict[int, int] = {}
-        for inp in graph_inputs:
-            source = graph_inputs[inp]
-            assert isinstance(source, torch._dynamo.source.GetItemSource)
-            graph_input_order[source.index] = len(graph_input_order)
+        # Compute dynamic dimensions for each input based on constraints
+        flat_args_dynamic_dims = [
+            {
+                c.dim
+                for c in (constraints or ())
+                if (
+                    c.t_id == id(x)
+                    and not isinstance(c, _RelaxedConstraint)
+                    and c.constraint_range.vr.lower != c.constraint_range.vr.upper
+                )
+            }
+            for x in flat_inputs
+        ]

-        placeholders = [n for n in list(graph.graph.nodes) if n.op == "placeholder"]
-        output = next(n for n in list(graph.graph.nodes) if n.op == "output")
-        # Sometimes there can be empty inputs
-        anchor = placeholders[0] if len(placeholders) > 0 else output
-        inp_to_node = {}
+        # Create input order mapping from dynamo's internal order to user order
+        graph_input_order: dict[int, int] = {}
+        for inp in graph_inputs:
+            source = graph_inputs[inp]
+            assert isinstance(source, torch._dynamo.source.GetItemSource)
+            graph_input_order[source.index] = len(graph_input_order)

-        with graph.graph.inserting_before(anchor):
-            for i in range(len(flat_inputs)):
-                node_new = graph.graph.placeholder(f"arg_{i}")
-                if i in graph_input_order:
-                    placeholders[graph_input_order[i]]
-                    node_new.meta = placeholders[graph_input_order[i]].meta.copy()
-                inp_to_node[i] = node_new
+        for real_idx, graph_idx in graph_input_order.items():
+            flat_inputs[real_idx] = example_inputs[graph_idx]

-        new_args = []
-        for i in output_return_type:
-            type, val = output_return_type[i]
-            if type == "graph_out":
-                new_args.append(output.args[0][val])
-            if type == "input":
-                input_idx = val.index
-                new_args.append(inp_to_node[input_idx])
-            if type == "constant":
-                new_args.append(val)
-        output.args = (tuple(new_args),)
+        # Use FX transformer to rebuild the graph cleanly
+        transformed_graph = DynamoGraphTransformer(
+            graph,
+            flat_inputs,
+            flat_args_dynamic_dims,
+            graph_input_order,
+            graph_output_map,
+            fake_mode,
+        ).transform()

-        for src_idx, i in graph_input_order.items():
-            old = placeholders[src_idx]
-            new = inp_to_node[i]
-            old.replace_all_uses_with(new)
-            graph.graph.erase_node(old)
-
-        # Dynamo uses _lazyGraphModule, so we need to force recompile
-        from torch.fx._lazy_graph_module import _LazyGraphModule
-
-        _LazyGraphModule.force_recompile(graph)
-
-        graph.graph._codegen = _PyTreeCodeGen(
+        # Set up PyTree codegen for proper input/output handling
+        transformed_graph.graph._codegen = _PyTreeCodeGen(
             _PyTreeInfo(
-                argument_names(signature, args, kwargs),  # type: ignore[arg-type]
+                argument_names(inspect.signature(mod.forward), args, kwargs),  # type: ignore[attr-defined, arg-type]
                 in_spec,
                 out_spec,
             )
         )
+        transformed_graph.recompile()

-        graph.recompile()
-        return graph
+        clean_nn_module_stack(transformed_graph)
+        clean_export_root(transformed_graph)
+
+        transformed_graph.meta["module_call_specs"] = module_call_spec
+
+        constraint_violation_error = None
+        try:
+            # Check if we have any constraint violations
+            check_fn = out.dynamo_output.build_guards(
+                module_to_trace.forward.__code__
+            ).guard_manager
+            check_fn.check(f_locals)
+        except (
+            ConstraintViolationError,
+            torch.utils._sympy.value_ranges.ValueRangeError,
+        ) as e:
+            constraint_violation_error = e
+
+        if (
+            (shape_env := getattr(fake_mode, "shape_env", None)) is not None
+            and (dim_constraints := shape_env.dim_constraints) is not None
+            and not isinstance(
+                module_to_trace.forward,
+                (torch._ops.OpOverloadPacket, torch._ops.OpOverload),
+            )
+        ):
+            dim_constraints.solve()
+            forced_specializations = dim_constraints.forced_specializations()
+            msg = dim_constraints.prettify_results(
+                inspect.signature(mod.forward),  # type: ignore[attr-defined]
+                dynamic_shapes,
+                constraint_violation_error,
+                forced_specializations,
+            )
+            if constraint_violation_error:
+                constraint_violation_error.args = (
+                    constraint_violation_error.args[0] + msg,
+                )
+            else:
+                if forced_specializations:
+                    constraint_violation_error = ConstraintViolationError(msg)
+                else:
+                    log.info(
+                        "Summary of dimension constraints:%s",
+                        msg,
+                    )
+
+            # Error if we have any constraints on static values
+            for k in shape_env.var_to_range.keys():
+                if isinstance(k, sympy.Integer):
+                    constraint_violation_error = ConstraintViolationError(
+                        f"{''.join(traceback.format_list(shape_env.var_to_stack[k]))}\n"
+                        "It appears that you're trying to set a constraint on a "
+                        f"value which we evaluated to have a static value of {k}. "
+                        'Set TORCH_LOGS="+export" for more information.'
+                    )
+        if constraint_violation_error:
+            raise constraint_violation_error
+
+        return transformed_graph

     return inner
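End to end, the returned callable traces on first invocation and hands back the transformed module. A usage sketch, assuming mod, args, and kwargs as at the call site added to torch/export/_trace.py further down:

from torch._dynamo.functional_export import _dynamo_graph_capture_for_export

gm = _dynamo_graph_capture_for_export(
    mod, constraints=None, dynamic_shapes=None
)(*args, **kwargs)
print(gm.graph)  # dynamo-level torch IR, with module_call_specs in gm.meta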
torch/_dynamo/output_graph.py

@@ -379,6 +379,10 @@ class ExportMetaData:
     out_spec: Union[torch.utils._pytree.TreeSpec, torch.utils._pytree.LeafSpec] = (
         torch.utils._pytree._LEAF_SPEC
     )
+    module_call_spec: dict[
+        str,
+        dict[str, Union[torch.utils._pytree.TreeSpec, torch.utils._pytree.LeafSpec]],
+    ] = dc_field(default_factory=dict)


 def get_builtins_dict(global_scope: Scope) -> dict[str, Any]:

@@ -1695,6 +1699,19 @@ class OutputGraph(OutputGraphGuardsState):
                 if isinstance(
                     mut_type, (AttributeMutationExisting, ValueMutationExisting)
                 ):
+                    if isinstance(var, UserDefinedDictVariable) and isinstance(
+                        var.value, _ExportModuleSpecTrackerDict
+                    ):
+                        for k, v in var.items.items():
+                            specs = {}
+                            for k_spec, val in v.items.items():
+                                specs[k_spec.vt.as_python_constant()] = (
+                                    val.as_python_constant()
+                                )
+                            assert ["in_spec", "out_spec"] == list(specs.keys())
+                            self.export_metadata.module_call_spec[
+                                k.vt.as_python_constant()
+                            ] = specs
                     # export uses tracepoint pass to dump submodule inp/out spec
                     # into global state, so we filter it here
                     if not (
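The collected metadata ends up keyed by submodule path, with exactly an in_spec and an out_spec per entry (the assert above enforces this). A hedged sketch of the resulting shape, with an illustrative submodule path:

import torch
import torch.utils._pytree as pytree

# Illustrative only: one preserved submodule whose input is (tensor,) / {}
# and whose output is a single tensor.
example_module_call_spec = {
    "sub.mod": {
        "in_spec": pytree.tree_flatten(((torch.randn(2),), {}))[1],
        "out_spec": pytree.tree_flatten(torch.randn(2))[1],
    }
}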
torch/export/__init__.py

@@ -70,6 +70,7 @@ def export_for_training(
     strict: bool = False,
     preserve_module_call_signature: tuple[str, ...] = (),
     prefer_deferred_runtime_asserts_over_guards: bool = False,
+    _use_new_tracer_experimental: bool = False,
 ) -> ExportedProgram:
     """
     :func:`export_for_training` takes any nn.Module along with example inputs, and produces a traced graph representing

@@ -159,6 +160,7 @@ def export_for_training(
         strict=strict,
         preserve_module_call_signature=preserve_module_call_signature,
         prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
+        _use_new_tracer_experimental=_use_new_tracer_experimental,
     )


@@ -171,6 +173,7 @@ def export(
     strict: bool = False,
     preserve_module_call_signature: tuple[str, ...] = (),
     prefer_deferred_runtime_asserts_over_guards: bool = False,
+    _use_new_tracer_experimental: bool = False,
 ) -> ExportedProgram:
     """
     :func:`export` takes any nn.Module along with example inputs, and produces a traced graph representing

@@ -283,6 +286,7 @@ def export(
             preserve_module_call_signature=preserve_module_call_signature,
             pre_dispatch=True,
             prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
+            _use_new_tracer_experimental=_use_new_tracer_experimental,
         )
     except Exception as e:
         draft_export_msg = (
torch/export/_trace.py

@@ -757,6 +757,7 @@ def _export_to_torch_ir(
     preserve_module_call_signature: tuple[str, ...] = (),
     disable_constraint_solver: bool = False,
     prefer_deferred_runtime_asserts_over_guards: bool = False,
+    _use_new_tracer_experimental: bool = False,
     restore_fqn: bool = True,
     _log_export_usage: bool = True,
     same_signature: bool = True,

@@ -809,20 +810,31 @@ def _export_to_torch_ir(
             f, preserve_module_call_signature, module_call_specs
         )
         with ctx, _ignore_backend_decomps():
-            gm_torch_level, _ = torch._dynamo.export(
-                f,
-                dynamic_shapes=dynamic_shapes,  # type: ignore[arg-type]
-                constraints=constraints,  # type: ignore[arg-type]
-                assume_static_by_default=True,
-                tracing_mode="symbolic",
-                disable_constraint_solver=disable_constraint_solver,
-                prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
-                _log_export_usage=_log_export_usage,
-                same_signature=same_signature,
-            )(
-                *args,
-                **kwargs,
-            )
+            if _use_new_tracer_experimental:
+                from torch._dynamo.functional_export import (
+                    _dynamo_graph_capture_for_export,
+                )
+
+                gm_torch_level = _dynamo_graph_capture_for_export(
+                    f, constraints=constraints, dynamic_shapes=dynamic_shapes
+                )(*args, **kwargs)
+
+            else:
+                gm_torch_level, _ = torch._dynamo.export(
+                    f,
+                    dynamic_shapes=dynamic_shapes,  # type: ignore[arg-type]
+                    constraints=constraints,  # type: ignore[arg-type]
+                    assume_static_by_default=True,
+                    tracing_mode="symbolic",
+                    disable_constraint_solver=disable_constraint_solver,
+                    prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
+                    _log_export_usage=_log_export_usage,
+                    same_signature=same_signature,
+                )(
+                    *args,
+                    **kwargs,
+                )
+            gm_torch_level.meta["module_call_specs"] = module_call_specs
     except (ConstraintViolationError, ValueRangeError) as e:
         raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e))  # noqa: B904
     except GuardOnDataDependentSymNode as e:

@@ -832,8 +844,6 @@ def _export_to_torch_ir(
             case_name="constrain_as_size_example",
         )

-    gm_torch_level.meta["module_call_specs"] = module_call_specs
-
     if isinstance(f, torch.nn.Module) and restore_fqn:
         _restore_state_dict(f, gm_torch_level)

@@ -1407,6 +1417,7 @@ def _strict_export(
     orig_in_spec: TreeSpec,
     prefer_deferred_runtime_asserts_over_guards: bool,
     _to_aten_func: Callable,
+    _use_new_tracer_experimental: bool = False,
 ) -> ExportArtifact:
     """
     _to_aten_func can either be `_export_to_aten_ir_make_fx` or `_export_to_aten_ir`

@@ -1421,6 +1432,7 @@ def _strict_export(
         restore_fqn=False,  # don't need to restore because we will do it later
         prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
         _log_export_usage=False,
+        _use_new_tracer_experimental=_use_new_tracer_experimental,
     )

     # We detect the fake_mode by looking at gm_torch_level's placeholders, this is the fake_mode created in dynamo.

@@ -2041,6 +2053,7 @@ def _export_for_training(
     strict: bool = True,
     preserve_module_call_signature: tuple[str, ...] = (),
     prefer_deferred_runtime_asserts_over_guards: bool = False,
+    _use_new_tracer_experimental: bool = False,
 ) -> ExportedProgram:
     global _EXPORT_MODULE_HIERARCHY
     _EXPORT_MODULE_HIERARCHY = _get_module_hierarchy(mod)

@@ -2056,7 +2069,13 @@ def _export_for_training(
     original_state_dict = _get_original_state_dict(mod)

     # Call the appropriate export function based on the strictness of tracing.
-    export_func = _strict_export if strict else _non_strict_export
+    export_func = (
+        functools.partial(
+            _strict_export, _use_new_tracer_experimental=_use_new_tracer_experimental
+        )
+        if strict
+        else _non_strict_export
+    )

     alive_fake_input_ids_before_export: list[int] = []

@@ -2185,6 +2204,7 @@ def _export(
     preserve_module_call_signature: tuple[str, ...] = (),
     pre_dispatch: bool = False,
     prefer_deferred_runtime_asserts_over_guards: bool = False,
+    _use_new_tracer_experimental: bool = False,
 ) -> ExportedProgram:
     """
     Traces either an nn.Module's forward function or just a callable with PyTorch

@@ -2260,6 +2280,7 @@ def _export(
         strict=strict,
         preserve_module_call_signature=preserve_module_call_signature,
         prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
+        _use_new_tracer_experimental=_use_new_tracer_experimental,
     )
     dtrace_structured("exported_program", payload_fn=lambda: str(ep))
     return ep
torch/fx/graph_module.py

@@ -302,6 +302,12 @@ def _has_attr(model: torch.nn.Module, attr_name: str):
     return hasattr(t, field)


+def _set_attr(model: torch.nn.Module, attr_name: str, value):
+    attr_names = attr_name.split(".")
+    t = _get_attr_via_attr_list(model, attr_names[:-1])
+    setattr(t, attr_names[-1], value)
+
+
 def _print_readable(
     module,
     module_name,
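_set_attr resolves a dotted path one attribute at a time and assigns on the final owner, symmetric to the existing _get_attr_via_attr_list helper (this is the function clean_export_root calls via torch.fx.graph_module._set_attr above). A small usage sketch; note nn.Module requires a Parameter when overwriting a registered parameter:

import torch
from torch.fx.graph_module import _set_attr

model = torch.nn.Sequential(torch.nn.Linear(2, 2))
_set_attr(model, "0.bias", torch.nn.Parameter(torch.zeros(2)))
assert torch.equal(model[0].bias, torch.zeros(2))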