Compare commits

1 commit

Commit a2a41a9b58: [export] Codemod unittests to use new graph capture API

Summary:
as title.

Test Plan:
pytest test/functorch/test_aot_joint_with_descriptors.py
pytest test/higher_order_ops/test_local_map.py

Date: 2025-11-04 07:50:48 -08:00
2 changed files with 16 additions and 35 deletions
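
For orientation, here is a minimal before/after sketch of the pattern this codemod moves the tests to. The model and inputs are placeholders, and the import of tracing from torch._guards is an assumption based on how the test helpers use it; the capture entry point and the tracing_context metadata key are taken from the diff below.

import torch
import torch.fx.traceback as fx_traceback
from torch._dynamo.functional_export import dynamo_graph_capture_for_export
from torch._guards import tracing  # assumed source of tracing(), per the test helpers

model = torch.nn.Linear(3, 2)   # placeholder module
inputs = (torch.randn(4, 3),)   # placeholder example inputs

# Old pattern (removed in this commit): call the private
# _dynamo_graph_capture_for_export under
# torch._dynamo.config.patch(install_free_tensors=True) and rebuild a
# TracingContext from gm.meta["fake_mode"].
# New pattern: call the public API directly and reuse the tracing context
# it stores on the captured module's metadata.
with fx_traceback.preserve_node_meta():
    gm = dynamo_graph_capture_for_export(model)(*inputs)
tracing_context = gm.meta.get("tracing_context", None)
with tracing(tracing_context):
    pass  # e.g. aot_export_joint_with_descriptors(...), as in the diffs below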

test/functorch/test_aot_joint_with_descriptors.py

@@ -13,7 +13,7 @@ import torch.fx.traceback as fx_traceback
 import torch.nn as nn
 import torch.utils._pytree as pytree
 from torch._decomp import decomposition_table
-from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
+from torch._dynamo.functional_export import dynamo_graph_capture_for_export
 from torch._dynamo.testing import normalize_gm
 from torch._functorch._aot_autograd.descriptors import (
     BufferAOTInput,
@@ -48,17 +48,13 @@ from torch.testing._internal.common_utils import (
 def graph_capture(model, inputs, with_export):
     gm = model
-    fake_mode = None
+    tracing_context = None
     if with_export:
-        with (
-            torch._dynamo.config.patch(install_free_tensors=True),
-            fx_traceback.preserve_node_meta(),
-        ):
-            # TODO: switch to use the official graph_capture API once it is ready
-            gm = _dynamo_graph_capture_for_export(model)(*inputs)
-            fake_mode = gm.meta.get("fake_mode", None)
+        with fx_traceback.preserve_node_meta():
+            gm = dynamo_graph_capture_for_export(model)(*inputs)
+            tracing_context = gm.meta.get("tracing_context", None)

-    with tracing(TracingContext(fake_mode)):
+    with tracing(tracing_context):
         with ExitStack() as stack:
             joint_with_descriptors = aot_export_joint_with_descriptors(
                 stack,
@@ -325,7 +321,7 @@ class inner_f(torch.nn.Module):
         inputs = (torch.randn(4, 3),)
         kwargs = {"scale": torch.tensor(2.0)}
-        gm = _dynamo_graph_capture_for_export(model)(*inputs, **kwargs)
+        gm = dynamo_graph_capture_for_export(model)(*inputs, **kwargs)

         with ExitStack() as stack:
             # Export joint with descriptors
@@ -356,8 +352,8 @@ class inner_f(torch.nn.Module):
         primals,
         tangents,
     ):
-        primals_1: "f32[2, 3]" # ParamAOTInput(target='L__self___linear_weight')
-        primals_2: "f32[2]" # ParamAOTInput(target='L__self___linear_bias')
+        primals_1: "f32[2, 3]" # ParamAOTInput(target='linear.weight')
+        primals_2: "f32[2]" # ParamAOTInput(target='linear.bias')
         primals_3: "f32[4, 3]" # PlainAOTInput(idx=0)
         primals_4: "f32[]" # PlainAOTInput(idx=1)
         tangents_1: "f32[4, 2]" # TangentAOTInput(output=PlainAOTOutput(idx=0))
@@ -379,8 +375,8 @@ class inner_f(torch.nn.Module):
         transpose_3: "f32[2, 3]" = torch.ops.prims.transpose.default(transpose_2, [1, 0]); transpose_2 = None
         return pytree.tree_unflatten([
             mul_2, # PlainAOTOutput(idx=0)
-            transpose_3, # GradAOTOutput(grad_of=ParamAOTInput(target='L__self___linear_weight'))
-            as_strided, # GradAOTOutput(grad_of=ParamAOTInput(target='L__self___linear_bias'))
+            transpose_3, # GradAOTOutput(grad_of=ParamAOTInput(target='linear.weight'))
+            as_strided, # GradAOTOutput(grad_of=ParamAOTInput(target='linear.bias'))
             None, # None
             None, # None
         ], self._out_spec)""",
@@ -1063,9 +1059,11 @@ class inner_f(torch.nn.Module):
             str(custom_metadata),
             """\
 ('call_function', 'new_empty', {'pp_stage': 0})
+('get_attr', '_tensor_constant0', {'pp_stage': 0})
 ('call_function', 'index_put', {'pp_stage': 0})
 ('call_function', 'slice_2', {'pp_stage': 0})
 ('call_function', 'slice_backward', {'pp_stage': 0})
+('get_attr', '_tensor_constant0_1', {'pp_stage': 0})
 ('call_function', 'index', {'pp_stage': 0})""",
         )
@@ -1082,7 +1080,7 @@ class inner_f(torch.nn.Module):
         model = SimpleLinear()
         inputs = (torch.randn(4, 3),)
-        gm = _dynamo_graph_capture_for_export(model)(*inputs)
+        gm = dynamo_graph_capture_for_export(model)(*inputs)
         fake_mode = gm.meta.get("fake_mode", None)
         with tracing(TracingContext(fake_mode)):

test/higher_order_ops/test_local_map.py

@@ -15,6 +15,7 @@ import torch._inductor.decomposition
 import torch.fx.traceback as fx_traceback
 import torch.nn.functional as F
 from torch import nn
+from torch._dynamo.functional_export import dynamo_graph_capture_for_export
 from torch._dynamo.variables.higher_order_ops import LocalMapWrappedHigherOrderVariable
 from torch._functorch.aot_autograd import aot_export_joint_with_descriptors
 from torch._subclasses.fake_tensor import FakeTensorMode
@@ -51,24 +52,6 @@ def enable_local_map_wrapping():
     yield


-def _export(model: torch.nn.Module, inputs: tuple[Any]) -> torch.nn.Module:
-    from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
-    from torch.export._trace import _restore_state_dict
-
-    """
-    Thin wrapper around graph capture output that restores the
-    original calling convention and attribute fqn. TODO:
-    1) Use bytecode for calling convention instead of pytree for more
-    seamless UX.
-    2) Attach guards
-    3) Be more careful about tensor constants names.
-    """
-    with torch._dynamo.config.patch(install_free_tensors=True):
-        gm = _dynamo_graph_capture_for_export(model)(*inputs)
-    _restore_state_dict(model, gm)
-    return gm
-
-
 def ap_style_initial_capture(
     model: torch.nn.Module, inputs_fn: Callable
 ) -> torch.nn.Module:
@@ -90,7 +73,7 @@ def ap_style_initial_capture(
         enable_local_map_wrapping(),
         torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing(),
     ):
-        torch_ir_with_fqn = _export(model, inputs)
+        torch_ir_with_fqn = dynamo_graph_capture_for_export(model)(*inputs)
         unused = ExitStack()
         joint_with_descriptors = aot_export_joint_with_descriptors(
             unused,
unused,