testing infra and some fixes (#162183)
This PR is quite large in that it covers most of the rough edges in the new strict export flow:
1. Handle nn_module_stack correctly now that we are tracing a wrapper module.
2. module_call_spec needs to be queried from source directly because we are not running the bytecode anymore.
3. Correct input and output handling.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162183
Approved by: https://github.com/zhxchen17
ghstack dependencies: #162167
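For orientation, the new strict-v2 path is reached through the regular export API with the experimental tracer flag that this PR adds; the toy module and inputs below are illustrative only, mirroring how mocked_strict_export_v2 in the new test file calls it:

    import torch
    from torch.export import export

    class Toy(torch.nn.Module):  # illustrative module, not part of this PR
        def forward(self, x):
            return x.sin()

    # Strict export routed through the new dynamo graph capture:
    ep = export(Toy(), (torch.randn(4),), strict=True, _use_new_tracer_experimental=True)
    print(ep.graph)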
Commit d8b6622bb6 (parent a965f09793), committed by PyTorch MergeBot.
@@ -217,12 +217,19 @@ TRAINING_IR_DECOMP_STRICT_SUFFIX = "_training_ir_to_decomp_strict"
 TRAINING_IR_DECOMP_NON_STRICT_SUFFIX = "_training_ir_to_decomp_nonstrict"
 CPP_RUNTIME_STRICT_SUFFIX = "_cpp_runtime_strict"
 CPP_RUNTIME_NONSTRICT_SUFFIX = "_cpp_runtime_nonstrict"
+STRICT_EXPORT_V2_SUFFIX = "_strict_export_v2"


 # Now default mode is non strict, so original unammended test names
 # should be treated as non-strict
 def is_non_strict_test(test_name):
-    return not test_name.endswith(STRICT_SUFFIX)
+    return not test_name.endswith(STRICT_SUFFIX) and not test_name.endswith(
+        STRICT_EXPORT_V2_SUFFIX
+    )
+
+
+def is_strict_v2_test(test_name):
+    return test_name.endswith(STRICT_EXPORT_V2_SUFFIX)


 def is_inline_and_install_strict_test(test_name: str) -> bool:
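The routing above is purely name-based. A minimal sketch of how a generated test name is classified (the example test name is hypothetical):

    STRICT_EXPORT_V2_SUFFIX = "_strict_export_v2"

    def is_strict_v2_test(test_name):
        return test_name.endswith(STRICT_EXPORT_V2_SUFFIX)

    assert is_strict_v2_test("test_retrace_pre_autograd" + STRICT_EXPORT_V2_SUFFIX)
    assert not is_strict_v2_test("test_retrace_pre_autograd_strict")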
@@ -9736,6 +9743,7 @@ graph():
         inputs = {"x": torch.randn(3, 3), "y": torch.randn(3, 3)}
         self.assertEqual(ep.module()(**inputs), m(**inputs))

+    @testing.expectedFailureStrictV2  # AssertionError: RuntimeError not raised
     def test_retrace_pre_autograd(self):
         class Foo(torch.nn.Module):
             def __init__(self) -> None:
@@ -11764,6 +11772,7 @@ graph():
         self.assertEqual(ep.module()(3, 5), 8)
         self.assertEqual(ep.module()(5, 4), 9)

+    @testing.expectedFailureStrictV2  # ValueError: Found conflicts between user-specified and inferred ranges
     def test_dynamic_shapes_bounds(self):
         class M(torch.nn.Module):
             """
@@ -12070,6 +12079,8 @@ graph():

         test(export(M(), inp))

+    # Preserving signature hook is messing with dynamo tracing
+    @testing.expectedFailureStrictV2
     def test_unflatten_multiple_graphs_state(self):
         class N(torch.nn.Module):
             def __init__(self):
@@ -13688,7 +13699,7 @@ def forward(self, x, y):

         inputs = (torch.randn(10, 72),)
         dx, dy = dims("dx", "dy")
-        ep = torch.export.export(
+        ep = torch.export._trace._export(
             Mod4Reshape(),
             inputs,
             dynamic_shapes={"x": (dx, dy)},
@@ -14536,6 +14547,14 @@ graph():
         if is_inline_and_install_strict_test(self._testMethodName):
             self.assertEqual(filtered_nn_module_stack[0], "mod_list_1.2")
             self.assertEqual(filtered_nn_module_stack[1], "mod_list_1.2")
+        # This is fine since both of these will be deprecated soon.
+        elif is_strict_v2_test(self._testMethodName) and IS_FBCODE:
+            self.assertEqual(
+                filtered_nn_module_stack[0], "mod_list_1.slice(2, 3, None).0"
+            )
+            self.assertEqual(
+                filtered_nn_module_stack[1], "mod_list_2.slice(4, 5, None).0"
+            )
         else:
             self.assertEqual(
                 filtered_nn_module_stack[0], "mod_list_1.slice(2, 3, None).2"
@@ -15374,6 +15393,7 @@ class GraphModule(torch.nn.Module):
         )

     @testing.expectedFailureStrict  # test_hop doesn't have a dynamo implementation
+    @testing.expectedFailureStrictV2  # test_hop doesn't have a dynamo implementation
     @testing.expectedFailureRetraceability  # test_hop doesn't have a dynamo implementation
     @testing.expectedFailureTrainingIRToRunDecomp  # test_hop doesn't have a dynamo implementation
     @testing.expectedFailureSerDerNonStrict  # TODO: serde torch.FunctionSchema is not implemented yet
test/export/test_strict_export_v2.py (new file, 54 lines)
@@ -0,0 +1,54 @@
+# Owner(s): ["oncall: export"]
+
+try:
+    from . import test_export, testing
+except ImportError:
+    import test_export  # @manual=fbcode//caffe2/test:test_export-library
+    import testing  # @manual=fbcode//caffe2/test:test_export-library
+
+from torch.export import export
+
+
+test_classes = {}
+
+
+def mocked_strict_export_v2(*args, **kwargs):
+    # If user already specified strict, don't make it strict
+    if "strict" in kwargs:
+        if kwargs["strict"]:
+            return export(*args, **kwargs, _use_new_tracer_experimental=True)
+        else:
+            return export(*args, **kwargs)
+    return export(*args, **kwargs, strict=True, _use_new_tracer_experimental=True)
+
+
+def make_dynamic_cls(cls):
+    cls_prefix = "StrictExportV2"
+
+    test_class = testing.make_test_cls_with_mocked_export(
+        cls,
+        cls_prefix,
+        test_export.STRICT_EXPORT_V2_SUFFIX,
+        mocked_strict_export_v2,
+        xfail_prop="_expected_failure_strict_v2",
+    )
+
+    test_classes[test_class.__name__] = test_class
+    # REMOVING THIS LINE WILL STOP TESTS FROM RUNNING
+    globals()[test_class.__name__] = test_class
+    test_class.__module__ = __name__
+    return test_class
+
+
+tests = [
+    test_export.TestDynamismExpression,
+    test_export.TestExport,
+]
+for test in tests:
+    make_dynamic_cls(test)
+del test
+
+if __name__ == "__main__":
+    from torch._dynamo.test_case import run_tests
+
+    run_tests()
@@ -257,6 +257,12 @@ def expectedFailureTrainingIRToRunDecompNonStrict(fn):
     return fn


+# Controls tests generated in test/export/test_export_strict_v2.py
+def expectedFailureStrictV2(fn):
+    fn._expected_failure_strict_v2 = True
+    return fn
+
+
 # Controls tests generated in test/export/test_export_strict.py
 def expectedFailureStrict(fn):
     fn._expected_failure_strict = True
@@ -258,6 +258,7 @@ def aot_compile_fullgraph(
     backend_input.graph_module._backend_id = backend_input.backend_id  # type: ignore[assignment]
     output_graph = dynamo_output.tracer_output.output_graph
     assert output_graph is not None
+    assert backend_input is not None
     import_sources = output_graph.import_sources
     with (
         torch._guards.tracing(TracingContext(backend_input.fake_mode)),
@@ -95,6 +95,7 @@ from .cache_size import (
 )
 from .eval_frame import (
     always_optimize_code_objects,
+    Constraint,
     dynamo_tls,
     skip_code,
     TorchPatcher,
@@ -894,7 +895,8 @@ class CaptureOutput:
     """

     dynamo_output: DynamoOutput
-    backend_input: BackendInput
+    # BackendInput can be None when dynamo didn't compile any graph (no tensor op)
+    backend_input: Optional[BackendInput]


 @dataclass
@@ -907,7 +909,10 @@ class FrameInfo:


 def fullgraph_capture(
-    frame: FrameInfo, *, _is_export_deprecated_do_not_use: bool = False
+    frame: FrameInfo,
+    *,
+    constraints: Optional[list[Constraint]] = None,
+    _is_export_deprecated_do_not_use: bool = False,
 ) -> CaptureOutput:
     """
     A standalone function which takes a frame and returns dynamo captured graph
@@ -951,6 +956,7 @@ def fullgraph_capture(
         frame.closure,
         compiler_fn=fullgraph_compiler,
         export=_is_export_deprecated_do_not_use,
+        export_constraints=constraints,  # type: ignore[arg-type]
         one_graph=True,
         restart_reasons=set(),
     )
@@ -966,7 +972,6 @@ def fullgraph_capture(
             cur_exn = cur_exn.__cause__
         raise e.with_traceback(None) from e.__cause__  # User compiler error

-    assert backend_input is not None
     return CaptureOutput(dynamo_output, backend_input)

@@ -1,17 +1,70 @@
 import builtins
 import inspect
+import logging
+import traceback
 from collections import namedtuple
-from typing import Any, Callable
+from typing import Any, Callable, Optional, Union

+import sympy
+
 import torch
+import torch.fx
 import torch.utils._pytree as pytree
 from torch._dynamo.convert_frame import FrameInfo, fullgraph_capture, get_compile_id
 from torch._dynamo.eval_frame import argument_names
 from torch._dynamo.utils import dynamo_timed, get_metrics_context
 from torch._guards import compile_context, CompileContext
+from torch.export.dynamic_shapes import _RelaxedConstraint, Constraint
+from torch.fx import Node
+from torch.fx.experimental.symbolic_shapes import (
+    ConstraintViolationError,
+    DimDynamic,
+    StatelessSymbolicContext,
+)
 from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo


+log = logging.getLogger(__name__)
+
+
+def clean_nn_module_stack(graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    for node in graph_module.graph.nodes:
+        if "nn_module_stack" in node.meta:
+            nn_module_stack = node.meta["nn_module_stack"].copy()
+            first_key = next(iter(nn_module_stack.keys()))
+            if "export_root" in first_key:
+                del nn_module_stack[first_key]
+            nn_module_stack_corrected = {}
+            for k, v in nn_module_stack.items():
+                k_new = "".join(k.split("__export_root"))
+                child_name, child_class = v
+                child_name = child_name.replace("._export_root", "")
+                nn_module_stack_corrected[k_new] = (child_name, child_class)
+            node.meta["nn_module_stack"] = nn_module_stack_corrected
+    return graph_module
+
+
+def clean_export_root(graph_module: torch.fx.GraphModule) -> None:
+    """Remove export_root artifacts from FX graph in-place"""
+
+    # Clean parameter names: L__self____export_root_param -> L__self___param
+    def clean_name(name) -> str:
+        return name.replace("__export_root_", "_") if "__export_root_" in name else name
+
+    # Update get_attr nodes in-place
+    for node in graph_module.graph.nodes:
+        if node.op == "get_attr":
+            old_target = node.target
+            new_target = clean_name(old_target)
+            if new_target != old_target:
+                node.target = new_target
+                # Move the parameter to the new name
+                if hasattr(graph_module, old_target):
+                    param = torch.fx.graph_module._get_attr(graph_module, old_target)
+                    torch.fx.graph_module._set_attr(graph_module, new_target, param)
+                    torch.fx.graph_module._del_attr(graph_module, old_target)
+
+
 class ModuleToTrace(torch.nn.Module):
     def __init__(self, foo: Any, in_spec: Any) -> None:
         super().__init__()
@@ -28,29 +81,214 @@ class ModuleToTrace(torch.nn.Module):
 ExportTracerOutput = namedtuple("ExportTracerOutput", ["flat_args", "out_spec"])


+# mypy: disable-error-code="no-untyped-def,var-annotated,assignment,index,operator"
+class DynamoGraphTransformer(torch.fx.Transformer):
+    """Graph transformer for dynamo export that flattens inputs/outputs without complex matching."""
+
+    def __init__(
+        self,
+        module: torch.fx.GraphModule,
+        flat_inputs: list[Any],
+        flat_args_dynamic_dims: list[set[int]],
+        graph_input_order: dict[int, int],
+        graph_output_map: dict[int, tuple[str, Any]],
+        fake_mode: Optional[Any] = None,
+    ) -> None:
+        super().__init__(module)
+
+        assert len(flat_args_dynamic_dims) == len(flat_inputs)
+
+        self.flat_inputs = flat_inputs
+        self.flat_args_dynamic_dims = flat_args_dynamic_dims
+        self.graph_input_order = graph_input_order
+        self.graph_output_map = graph_output_map
+        self.fake_mode = fake_mode
+
+        # Get original placeholders and output
+        self.placeholders = [n for n in module.graph.nodes if n.op == "placeholder"]
+        self.output_node = next(n for n in module.graph.nodes if n.op == "output")
+
+        # Create new flattened input placeholders
+        self.new_input_nodes: dict[int, torch.fx.Node] = {}
+        self._create_flattened_inputs()
+
+        # Iterator for replacing old placeholders
+        self.old_to_new_mapping = {}
+        self._create_placeholder_mapping()
+
+    def _create_flattened_inputs(self) -> None:
+        """Create new placeholder nodes for flattened inputs with proper fake tensors."""
+        for i in range(len(self.flat_inputs)):
+            placeholder = super().placeholder(f"arg_{i}", (), {})
+
+            # Check if this user input (index i) maps to a graph placeholder
+            if i in self.graph_input_order:
+                # graph_input_order[i] gives us which graph placeholder this user input corresponds to
+                graph_placeholder_idx = self.graph_input_order[i]
+                if graph_placeholder_idx < len(self.placeholders):
+                    orig_placeholder = self.placeholders[graph_placeholder_idx]
+                    # Copy other metadata but not "val" yet
+                    for key, value in orig_placeholder.meta.items():
+                        if key != "val":
+                            placeholder.node.meta[key] = value
+
+            # Always ensure we have proper "val" metadata from fake tensor
+            if self.fake_mode is not None and isinstance(
+                self.flat_inputs[i], torch.Tensor
+            ):
+                placeholder.node.meta["val"] = self.fake_mode.from_tensor(
+                    self.flat_inputs[i],
+                    symbolic_context=StatelessSymbolicContext(
+                        dynamic_sizes=[
+                            (
+                                DimDynamic.DYNAMIC
+                                if d in self.flat_args_dynamic_dims[i]
+                                else DimDynamic.STATIC
+                            )
+                            for d in range(len(self.flat_inputs[i].shape))
+                        ],
+                        constraint_sizes=[None] * len(self.flat_inputs[i].shape),
+                    ),
+                )
+            elif hasattr(self.flat_inputs[i], "val"):  # _IntWrapper case
+                placeholder.node.meta["val"] = self.flat_inputs[i].val
+            else:
+                placeholder.node.meta["val"] = self.flat_inputs[i]
+
+            self.new_input_nodes[i] = placeholder
+
+    def _create_placeholder_mapping(self) -> None:
+        """Create mapping from old placeholders to new ones."""
+        # graph_input_order maps: user_input_index -> graph_placeholder_index
+        # We need to create: old_graph_placeholder -> new_user_input_placeholder
+        for user_input_idx, graph_placeholder_idx in self.graph_input_order.items():
+            if graph_placeholder_idx < len(self.placeholders):
+                old_placeholder = self.placeholders[graph_placeholder_idx]
+                new_placeholder = self.new_input_nodes[user_input_idx]
+                self.old_to_new_mapping[old_placeholder] = new_placeholder
+
+    def placeholder(self, target, args, kwargs) -> Any:
+        """Replace old placeholders with new flattened ones."""
+        # Return the corresponding new placeholder
+        if self.current_node in self.old_to_new_mapping:
+            new_arg = self.old_to_new_mapping[self.current_node]
+
+            # Copy over additional metadata from current node, but don't overwrite "val"
+            for key in ["tensor_dict", "example_value", "unbacked_bindings"]:
+                if key in self.current_node.meta:
+                    new_arg.node.meta[key] = self.current_node.meta[key]
+
+            # Only copy "val" if we don't already have a good one
+            if "val" in self.current_node.meta and "val" not in new_arg.node.meta:
+                new_arg.node.meta["val"] = self.current_node.meta["val"]
+
+            return new_arg
+        else:
+            # Shouldn't happen if mapping is correct, but fallback
+            return super().placeholder(target, args, kwargs)
+
+    def output(self, target, args, kwargs) -> Any:
+        """Transform output according to graph_output_map."""
+        original_outputs = args[0]
+
+        # Build new output list based on graph_output_map
+        new_outputs = []
+        for i in sorted(self.graph_output_map.keys()):
+            output_type, val = self.graph_output_map[i]
+
+            if output_type == "graph_out":
+                new_outputs.append(original_outputs[val])
+            elif output_type == "input":
+                input_idx = val.index
+                new_outputs.append(self.new_input_nodes[input_idx])
+            elif output_type == "constant":
+                new_outputs.append(val)
+
+        return super().output(target, (tuple(new_outputs),), {})
+
+    def run_node(self, node: Node) -> Any:
+        """Run node transformation and preserve metadata."""
+        self.current_node = node
+        result = super().run_node(node)
+
+        # Copy important metadata
+        if hasattr(result, "node") and result.node is not node:
+            for key in ["val", "example_value", "unbacked_bindings"]:
+                if key in node.meta:
+                    result.node.meta[key] = node.meta[key]
+
+            # Preserve node names (except output)
+            if node.op != "output" and hasattr(node, "name"):
+                result.node._rename(node.name)
+
+        return result
+
+    def transform(self) -> torch.fx.GraphModule:
+        """Perform the graph transformation and copy module metadata."""
+        result_gm = super().transform()
+
+        # Copy module metadata like the original implementation
+        if hasattr(self.module, "meta"):
+            if "dynamo_flat_name_to_original_fqn" in self.module.meta:
+                result_gm.meta["dynamo_flat_name_to_original_fqn"] = self.module.meta[
+                    "dynamo_flat_name_to_original_fqn"
+                ]
+            if "dynamo_compile_id" in self.module.meta:
+                result_gm.meta["dynamo_compile_id"] = self.module.meta[
+                    "dynamo_compile_id"
+                ]
+
+        return result_gm
+
+
 def _dynamo_graph_capture_for_export(
-    mod: torch.nn.Module,
+    mod: Callable[..., Any],
+    *,
+    constraints: Optional[list[Constraint]] = None,
+    dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None,
 ) -> Callable[..., torch.fx.GraphModule]:
     """
-    This is lower level API that is used for export to capture dynamo level
-    torch IR.
+    Improved dynamo graph capture using transformer approach with proper fake tensor handling.

-    Notable TODOs:
+    This function creates a capture instance that handles:
+    1. PyTree flattening/unflattening with proper input ordering
+    2. Dynamo graph capture with export-specific context
+    3. FX graph transformation for export compatibility
+    4. Proper fake tensor metadata preservation
+    5. Dynamic dimension constraint handling
+
+    Notable improvements over manual approach:
+    - Uses FX Transformer for cleaner graph manipulation
+    - Properly handles fake tensor metadata and dynamic dimensions
+    - Preserves all necessary metadata for export
+    - More robust error handling and edge case management
+
+    TODO:
     1. Are we actually gonna run the bytecode?
     2. Need to attach guards
     """
+
+    _dynamic_shapes = dynamic_shapes
+    _constraints = constraints
+
     def inner(*args: Any, **kwargs: Any) -> torch.fx.GraphModule:
         flat_inputs, in_spec = pytree.tree_flatten((args, kwargs))
         module_to_trace = ModuleToTrace(mod, in_spec)
+
         signature = inspect.signature(module_to_trace.forward)
+
         bound_arguments = signature.bind(*flat_inputs)
         bound_arguments.apply_defaults()
-        f_locals = {"self": module_to_trace, **bound_arguments.arguments}
+
+        constraints: Optional[list[Constraint]] = _constraints
+        dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = (
+            _dynamic_shapes
+        )
+
+        from . import reset  # type: ignore[attr-defined]
+
+        reset()
+
+        f_locals = {"self": module_to_trace, **bound_arguments.arguments}
         frame = FrameInfo(
             module_to_trace.forward.__func__.__code__,  # type: ignore[attr-defined]
             module_to_trace.forward.__func__.__globals__,  # type: ignore[attr-defined]
|
|||||||
)
|
)
|
||||||
|
|
||||||
dynamo_config_ctx = torch._dynamo.config.patch(
|
dynamo_config_ctx = torch._dynamo.config.patch(
|
||||||
"log_graph_in_out_metadata", True
|
specialize_int=True,
|
||||||
|
specialize_float=True,
|
||||||
|
assume_static_by_default=True,
|
||||||
|
automatic_dynamic_shapes=False,
|
||||||
|
capture_dynamic_output_shape_ops=True,
|
||||||
|
capture_scalar_outputs=True,
|
||||||
|
prefer_deferred_runtime_asserts_over_guards=False,
|
||||||
|
log_graph_in_out_metadata=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
with (
|
with (
|
||||||
@@ -69,74 +314,137 @@ def _dynamo_graph_capture_for_export(
             dynamo_timed("fullgraph_capture"),
             dynamo_config_ctx,
         ):
-            out = fullgraph_capture(frame, _is_export_deprecated_do_not_use=True)
+            out = fullgraph_capture(
+                frame,
+                constraints=_constraints,
+                _is_export_deprecated_do_not_use=True,
+            )

         assert out.dynamo_output.tracer_output.output_graph is not None

+        # Extract export metadata from the new location
         export_metadata = (
             out.dynamo_output.tracer_output.output_graph.export_metadata
         )
         graph_inputs = export_metadata.graph_input_idx_to_local_source
-        output_return_type = export_metadata.output_return_type
-        # We need to extract out_spec here because we are not actually running the bytecode
+        graph_output_map = export_metadata.output_return_type
         out_spec = export_metadata.out_spec
+        module_call_spec = export_metadata.module_call_spec

-        graph = out.backend_input.graph_module
-
-        # It is not guaranteed that dynamo puts inputs in right order, so we need to
-        # map the actual user order to the dynamo order.
-        graph_input_order: dict[int, int] = {}
-        for inp in graph_inputs:
-            source = graph_inputs[inp]
-            assert isinstance(source, torch._dynamo.source.GetItemSource)
-            graph_input_order[source.index] = len(graph_input_order)
-
-        placeholders = [n for n in list(graph.graph.nodes) if n.op == "placeholder"]
-        output = next(n for n in list(graph.graph.nodes) if n.op == "output")
-        # Sometimes there can be empty inputs
-        anchor = placeholders[0] if len(placeholders) > 0 else output
-        inp_to_node = {}
-
-        with graph.graph.inserting_before(anchor):
-            for i in range(len(flat_inputs)):
-                node_new = graph.graph.placeholder(f"arg_{i}")
-                if i in graph_input_order:
-                    placeholders[graph_input_order[i]]
-                    node_new.meta = placeholders[graph_input_order[i]].meta.copy()
-                inp_to_node[i] = node_new
-
-        new_args = []
-        for i in output_return_type:
-            type, val = output_return_type[i]
-            if type == "graph_out":
-                new_args.append(output.args[0][val])
-            if type == "input":
-                input_idx = val.index
-                new_args.append(inp_to_node[input_idx])
-            if type == "constant":
-                new_args.append(val)
-        output.args = (tuple(new_args),)
-
-        for src_idx, i in graph_input_order.items():
-            old = placeholders[src_idx]
-            new = inp_to_node[i]
-            old.replace_all_uses_with(new)
-            graph.graph.erase_node(old)
-
-        # Dynamo uses _lazyGraphModule, so we need to force recompile
-        from torch.fx._lazy_graph_module import _LazyGraphModule
-
-        _LazyGraphModule.force_recompile(graph)
-
-        graph.graph._codegen = _PyTreeCodeGen(
+        example_inputs: list[Any] = []
+        if out.backend_input is not None:
+            graph = out.backend_input.graph_module
+            fake_mode = out.backend_input.fake_mode
+            example_inputs = out.backend_input.example_inputs
+        else:
+            graph = torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
+            graph.graph.output(None)
+            graph.recompile()
+            fake_mode = out.dynamo_output.tracer_output.output_graph.fake_mode
+
+        # Compute dynamic dimensions for each input based on constraints
+        flat_args_dynamic_dims = [
+            {
+                c.dim
+                for c in (constraints or ())
+                if (
+                    c.t_id == id(x)
+                    and not isinstance(c, _RelaxedConstraint)
+                    and c.constraint_range.vr.lower != c.constraint_range.vr.upper
+                )
+            }
+            for x in flat_inputs
+        ]
+
+        # Create input order mapping from dynamo's internal order to user order
+        graph_input_order: dict[int, int] = {}
+        for inp in graph_inputs:
+            source = graph_inputs[inp]
+            assert isinstance(source, torch._dynamo.source.GetItemSource)
+            graph_input_order[source.index] = len(graph_input_order)
+
+        for real_idx, graph_idx in graph_input_order.items():
+            flat_inputs[real_idx] = example_inputs[graph_idx]
+
+        # Use FX transformer to rebuild the graph cleanly
+        transformed_graph = DynamoGraphTransformer(
+            graph,
+            flat_inputs,
+            flat_args_dynamic_dims,
+            graph_input_order,
+            graph_output_map,
+            fake_mode,
+        ).transform()
+
+        # Set up PyTree codegen for proper input/output handling
+        transformed_graph.graph._codegen = _PyTreeCodeGen(
             _PyTreeInfo(
-                argument_names(signature, args, kwargs),  # type: ignore[arg-type]
+                argument_names(inspect.signature(mod.forward), args, kwargs),  # type: ignore[attr-defined, arg-type]
                 in_spec,
                 out_spec,
             )
         )
+        transformed_graph.recompile()

-        graph.recompile()
-        return graph
+        clean_nn_module_stack(transformed_graph)
+        clean_export_root(transformed_graph)
+
+        transformed_graph.meta["module_call_specs"] = module_call_spec
+
+        constraint_violation_error = None
+        try:
+            # Check if we have any constraint violations
+            check_fn = out.dynamo_output.build_guards(
+                module_to_trace.forward.__code__
+            ).guard_manager
+            check_fn.check(f_locals)
+        except (
+            ConstraintViolationError,
+            torch.utils._sympy.value_ranges.ValueRangeError,
+        ) as e:
+            constraint_violation_error = e
+
+        if (
+            (shape_env := getattr(fake_mode, "shape_env", None)) is not None
+            and (dim_constraints := shape_env.dim_constraints) is not None
+            and not isinstance(
+                module_to_trace.forward,
+                (torch._ops.OpOverloadPacket, torch._ops.OpOverload),
+            )
+        ):
+            dim_constraints.solve()
+            forced_specializations = dim_constraints.forced_specializations()
+            msg = dim_constraints.prettify_results(
+                inspect.signature(mod.forward),  # type: ignore[attr-defined]
+                dynamic_shapes,
+                constraint_violation_error,
+                forced_specializations,
+            )
+            if constraint_violation_error:
+                constraint_violation_error.args = (
+                    constraint_violation_error.args[0] + msg,
+                )
+            else:
+                if forced_specializations:
+                    constraint_violation_error = ConstraintViolationError(msg)
+                else:
+                    log.info(
+                        "Summary of dimension constraints:%s",
+                        msg,
+                    )
+
+            # Error if we have any constraints on static values
+            for k in shape_env.var_to_range.keys():
+                if isinstance(k, sympy.Integer):
+                    constraint_violation_error = ConstraintViolationError(
+                        f"{''.join(traceback.format_list(shape_env.var_to_stack[k]))}\n"
+                        "It appears that you're trying to set a constraint on a "
+                        f"value which we evaluated to have a static value of {k}. "
+                        'Set TORCH_LOGS="+export" for more information.'
+                    )
+        if constraint_violation_error:
+            raise constraint_violation_error
+
+        return transformed_graph

     return inner
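For context, a minimal sketch of how this capture entry point is meant to be driven end to end; the example module M and inputs are hypothetical, and the call shape mirrors the _export_to_torch_ir call site further down in this PR:

    import torch
    from torch._dynamo.functional_export import _dynamo_graph_capture_for_export

    class M(torch.nn.Module):  # hypothetical example module
        def forward(self, x):
            return x + 1

    capture = _dynamo_graph_capture_for_export(M(), constraints=None, dynamic_shapes=None)
    gm = capture(torch.randn(3))  # returns a torch.fx.GraphModule with flattened I/O
    print(gm.graph)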
@@ -379,6 +379,10 @@ class ExportMetaData:
     out_spec: Union[torch.utils._pytree.TreeSpec, torch.utils._pytree.LeafSpec] = (
         torch.utils._pytree._LEAF_SPEC
     )
+    module_call_spec: dict[
+        str,
+        dict[str, Union[torch.utils._pytree.TreeSpec, torch.utils._pytree.LeafSpec]],
+    ] = dc_field(default_factory=dict)


 def get_builtins_dict(global_scope: Scope) -> dict[str, Any]:
@@ -1695,6 +1699,19 @@ class OutputGraph(OutputGraphGuardsState):
             if isinstance(
                 mut_type, (AttributeMutationExisting, ValueMutationExisting)
             ):
+                if isinstance(var, UserDefinedDictVariable) and isinstance(
+                    var.value, _ExportModuleSpecTrackerDict
+                ):
+                    for k, v in var.items.items():
+                        specs = {}
+                        for k_spec, val in v.items.items():
+                            specs[k_spec.vt.as_python_constant()] = (
+                                val.as_python_constant()
+                            )
+                        assert ["in_spec", "out_spec"] == list(specs.keys())
+                        self.export_metadata.module_call_spec[
+                            k.vt.as_python_constant()
+                        ] = specs
                 # export uses tracepoint pass to dump submodule inp/out spec
                 # into global state, so we filter it here
                 if not (
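The module_call_spec payload collected here is keyed by submodule path and, per the assertion above, always carries exactly an in_spec and an out_spec. A toy illustration of its shape (the "submod.linear" key is hypothetical):

    import torch
    import torch.utils._pytree as pytree

    _, in_spec = pytree.tree_flatten(((torch.randn(2),), {}))
    _, out_spec = pytree.tree_flatten(torch.randn(2))
    module_call_spec = {"submod.linear": {"in_spec": in_spec, "out_spec": out_spec}}
    assert list(module_call_spec["submod.linear"].keys()) == ["in_spec", "out_spec"]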
@@ -70,6 +70,7 @@ def export_for_training(
     strict: bool = False,
     preserve_module_call_signature: tuple[str, ...] = (),
     prefer_deferred_runtime_asserts_over_guards: bool = False,
+    _use_new_tracer_experimental: bool = False,
 ) -> ExportedProgram:
     """
     :func:`export_for_training` takes any nn.Module along with example inputs, and produces a traced graph representing
@@ -159,6 +160,7 @@ def export_for_training(
         strict=strict,
         preserve_module_call_signature=preserve_module_call_signature,
         prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
+        _use_new_tracer_experimental=_use_new_tracer_experimental,
     )


@@ -171,6 +173,7 @@ def export(
     strict: bool = False,
     preserve_module_call_signature: tuple[str, ...] = (),
     prefer_deferred_runtime_asserts_over_guards: bool = False,
+    _use_new_tracer_experimental: bool = False,
 ) -> ExportedProgram:
     """
     :func:`export` takes any nn.Module along with example inputs, and produces a traced graph representing
@@ -283,6 +286,7 @@ def export(
             preserve_module_call_signature=preserve_module_call_signature,
             pre_dispatch=True,
             prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
+            _use_new_tracer_experimental=_use_new_tracer_experimental,
         )
     except Exception as e:
         draft_export_msg = (
@@ -757,6 +757,7 @@ def _export_to_torch_ir(
     preserve_module_call_signature: tuple[str, ...] = (),
     disable_constraint_solver: bool = False,
     prefer_deferred_runtime_asserts_over_guards: bool = False,
+    _use_new_tracer_experimental: bool = False,
     restore_fqn: bool = True,
     _log_export_usage: bool = True,
     same_signature: bool = True,
@@ -809,20 +810,31 @@ def _export_to_torch_ir(
             f, preserve_module_call_signature, module_call_specs
         )
         with ctx, _ignore_backend_decomps():
-            gm_torch_level, _ = torch._dynamo.export(
-                f,
-                dynamic_shapes=dynamic_shapes,  # type: ignore[arg-type]
-                constraints=constraints,  # type: ignore[arg-type]
-                assume_static_by_default=True,
-                tracing_mode="symbolic",
-                disable_constraint_solver=disable_constraint_solver,
-                prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
-                _log_export_usage=_log_export_usage,
-                same_signature=same_signature,
-            )(
-                *args,
-                **kwargs,
-            )
+            if _use_new_tracer_experimental:
+                from torch._dynamo.functional_export import (
+                    _dynamo_graph_capture_for_export,
+                )
+
+                gm_torch_level = _dynamo_graph_capture_for_export(
+                    f, constraints=constraints, dynamic_shapes=dynamic_shapes
+                )(*args, **kwargs)
+
+            else:
+                gm_torch_level, _ = torch._dynamo.export(
+                    f,
+                    dynamic_shapes=dynamic_shapes,  # type: ignore[arg-type]
+                    constraints=constraints,  # type: ignore[arg-type]
+                    assume_static_by_default=True,
+                    tracing_mode="symbolic",
+                    disable_constraint_solver=disable_constraint_solver,
+                    prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
+                    _log_export_usage=_log_export_usage,
+                    same_signature=same_signature,
+                )(
+                    *args,
+                    **kwargs,
+                )
+            gm_torch_level.meta["module_call_specs"] = module_call_specs
     except (ConstraintViolationError, ValueRangeError) as e:
         raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e))  # noqa: B904
     except GuardOnDataDependentSymNode as e:
@@ -832,8 +844,6 @@ def _export_to_torch_ir(
                 case_name="constrain_as_size_example",
             )

-    gm_torch_level.meta["module_call_specs"] = module_call_specs
-
     if isinstance(f, torch.nn.Module) and restore_fqn:
         _restore_state_dict(f, gm_torch_level)

@@ -1407,6 +1417,7 @@ def _strict_export(
     orig_in_spec: TreeSpec,
     prefer_deferred_runtime_asserts_over_guards: bool,
     _to_aten_func: Callable,
+    _use_new_tracer_experimental: bool = False,
 ) -> ExportArtifact:
     """
     _to_aten_func can either be `_export_to_aten_ir_make_fx` or `_export_to_aten_ir`
@@ -1421,6 +1432,7 @@ def _strict_export(
         restore_fqn=False,  # don't need to restore because we will do it later
         prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
         _log_export_usage=False,
+        _use_new_tracer_experimental=_use_new_tracer_experimental,
     )

     # We detect the fake_mode by looking at gm_torch_level's placeholders, this is the fake_mode created in dynamo.
@@ -2041,6 +2053,7 @@ def _export_for_training(
     strict: bool = True,
     preserve_module_call_signature: tuple[str, ...] = (),
     prefer_deferred_runtime_asserts_over_guards: bool = False,
+    _use_new_tracer_experimental: bool = False,
 ) -> ExportedProgram:
     global _EXPORT_MODULE_HIERARCHY
     _EXPORT_MODULE_HIERARCHY = _get_module_hierarchy(mod)
@@ -2056,7 +2069,13 @@ def _export_for_training(
     original_state_dict = _get_original_state_dict(mod)

     # Call the appropriate export function based on the strictness of tracing.
-    export_func = _strict_export if strict else _non_strict_export
+    export_func = (
+        functools.partial(
+            _strict_export, _use_new_tracer_experimental=_use_new_tracer_experimental
+        )
+        if strict
+        else _non_strict_export
+    )

     alive_fake_input_ids_before_export: list[int] = []

@@ -2185,6 +2204,7 @@ def _export(
     preserve_module_call_signature: tuple[str, ...] = (),
     pre_dispatch: bool = False,
     prefer_deferred_runtime_asserts_over_guards: bool = False,
+    _use_new_tracer_experimental: bool = False,
 ) -> ExportedProgram:
     """
     Traces either an nn.Module's forward function or just a callable with PyTorch
@@ -2260,6 +2280,7 @@ def _export(
             strict=strict,
             preserve_module_call_signature=preserve_module_call_signature,
             prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
+            _use_new_tracer_experimental=_use_new_tracer_experimental,
         )
         dtrace_structured("exported_program", payload_fn=lambda: str(ep))
         return ep
@@ -302,6 +302,12 @@ def _has_attr(model: torch.nn.Module, attr_name: str):
     return hasattr(t, field)


+def _set_attr(model: torch.nn.Module, attr_name: str, value):
+    attr_names = attr_name.split(".")
+    t = _get_attr_via_attr_list(model, attr_names[:-1])
+    setattr(t, attr_names[-1], value)
+
+
 def _print_readable(
     module,
     module_name,