Compare commits

...

1 Commit

SHA1: c362ec0d5d
Date: 2025-11-12 10:40:55 -08:00

[export] Codemod more tests to use dynamo_graph_capture_for_export

Summary:
as title.

Test Plan:
CI
5 changed files with 23 additions and 47 deletions
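
The change is mechanical: call sites that went through the private `_dynamo_graph_capture_for_export` helper (typically wrapped in `torch._dynamo.config.patch(install_free_tensors=True)`) now call the public `dynamo_graph_capture_for_export` directly, and the duplicate test parametrizations that exercised the old path are dropped. A minimal before/after sketch of the call-site pattern follows; the `Toy` module and its input are invented for illustration, only the two entry points come from this commit.

# Illustrative only: "Toy" and its input are made up; the two entry points are
# the ones this commit migrates between (see the hunks below).
import torch
from torch._dynamo.functional_export import dynamo_graph_capture_for_export


class Toy(torch.nn.Module):
    def forward(self, x):
        return x.sin() + 1


x = torch.randn(4, 4)

# Old pattern (removed below): private helper plus a config patch.
# with torch._dynamo.config.patch(install_free_tensors=True):
#     gm = _dynamo_graph_capture_for_export(Toy())(x)

# New pattern (used throughout this commit): the public wrapper, no patch needed.
gm = dynamo_graph_capture_for_export(Toy())(x)
out = gm(x)  # per the tests below, the captured module matches eager output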


@@ -6,10 +6,7 @@ import unittest
 import torch
 import torch.distributed as dist
 import torch.fx.traceback as fx_traceback
-from torch._dynamo.functional_export import (
-    _dynamo_graph_capture_for_export,
-    dynamo_graph_capture_for_export,
-)
+from torch._dynamo.functional_export import dynamo_graph_capture_for_export
 from torch._functorch.aot_autograd import aot_export_joint_with_descriptors
 from torch._functorch.partitioners import min_cut_rematerialization_partition
 from torch._guards import tracing, TracingContext
@@ -153,17 +150,6 @@ def graph_capture_and_aot_export_joint_with_descriptors_v2(model, args, kwargs=N
         return aot_export_joint_with_descriptors_alone(gm, args, kwargs)
-def graph_capture_and_aot_export_joint_with_descriptors(model, args, kwargs=None):
-    if kwargs is None:
-        kwargs = {}
-    with torch._dynamo.config.patch(install_free_tensors=True):
-        # TODO: switch to use the official graph_capture API once it is ready
-        gm = _dynamo_graph_capture_for_export(model)(*args, **kwargs)
-    fake_mode = gm.meta.get("fake_mode", None)
-    with tracing(TracingContext(fake_mode)):
-        return aot_export_joint_with_descriptors_alone(gm, args, kwargs)
 def aot_export_joint_with_descriptors_alone(model, args, kwargs=None):
     if kwargs is None:
         kwargs = {}
@@ -360,7 +346,6 @@ class DTensorExportTest(TestCase):
         "export_fn",
         [
             graph_capture_and_aot_export_joint_with_descriptors_v2,
-            graph_capture_and_aot_export_joint_with_descriptors,
             aot_export_joint_with_descriptors_alone,
         ],
     )
@@ -386,10 +371,6 @@ class DTensorExportTest(TestCase):
                 graph_capture_and_aot_export_joint_with_descriptors_v2,
                 "[[4, 10], [4], [10, 4], [10], [4, 10], [4], [10, 4], [10], [s64, 10], [s64, 10]]",
             ),
-            (
-                graph_capture_and_aot_export_joint_with_descriptors,
-                "[[4, 10], [4], [10, 4], [10], [s22, 10], [s22, 10]]",
-            ),
         ],
     )
     def test_dynamic_shapes(self, export_fn_with_answer):
@@ -434,7 +415,6 @@ class DTensorExportTest(TestCase):
         "export_fn",
         [
             dynamo_graph_capture_for_export,
-            _dynamo_graph_capture_for_export,
         ],
     )
     def test_einsum_dtensor_export(self, export_fn):
@@ -456,11 +436,7 @@ class DTensorExportTest(TestCase):
         # Run model to verify it works
         output = model(*inputs)
-        with torch._dynamo.config.patch(
-            install_free_tensors=(export_fn is _dynamo_graph_capture_for_export)
-        ):
-            # TODO: switch to use the official graph_capture API once it is ready
-            gm = export_fn(model)(*inputs)
+        gm = export_fn(model)(*inputs)
         output_gm = gm(*inputs)
         self.assertEqual(output, output_gm)
@@ -468,7 +444,6 @@ class DTensorExportTest(TestCase):
         "export_fn",
         [
             graph_capture_and_aot_export_joint_with_descriptors_v2,
-            graph_capture_and_aot_export_joint_with_descriptors,
         ],
     )
     def test_flex_attention_dtensor_export(self, export_fn):
@@ -531,7 +506,7 @@ class DTensorExportTest(TestCase):
             return nest_fn(leaf) + 1
         z = torch.randn(16, 16)
-        gm = graph_capture_and_aot_export_joint_with_descriptors(fn, (z,))
+        gm = graph_capture_and_aot_export_joint_with_descriptors_v2(fn, (z,))
         self.assertEqual(fn(z), gm(z)[0])
@@ -546,7 +521,7 @@ class DTensorExportTest(TestCase):
         y = torch.randint(1, (10,)).bool()
         x_dt = distribute_tensor(x, device_mesh, placements=[Replicate()])
         y_dt = distribute_tensor(y, device_mesh, placements=[Replicate()])
-        _dynamo_graph_capture_for_export(Foo())(x_dt, y_dt)
+        dynamo_graph_capture_for_export(Foo())(x_dt, y_dt)

         class Bar(torch.nn.Module):
             def forward(self, x):
@@ -556,25 +531,25 @@ class DTensorExportTest(TestCase):
         x = torch.randint(1000, (4, 64, 16))
         x_dt = distribute_tensor(x, device_mesh, placements=[Replicate()])
-        gm = _dynamo_graph_capture_for_export(Bar())(x_dt)
+        gm = dynamo_graph_capture_for_export(Bar())(x_dt)
         self.assertExpectedInline(
             str(gm.graph).strip(),
             """\
 graph():
-    %l_flat_args_0_ : [num_users=2] = placeholder[target=arg_0]
-    %max_1 : [num_users=1] = call_method[target=max](args = (%l_flat_args_0_,), kwargs = {})
+    %l_x_ : torch.distributed.tensor.DTensor [num_users=2] = placeholder[target=L_x_]
+    %max_1 : [num_users=1] = call_method[target=max](args = (%l_x_,), kwargs = {})
     %clamp : [num_users=1] = call_function[target=torch.clamp](args = (%max_1,), kwargs = {min: 1})
     %item : [num_users=2] = call_method[target=item](args = (%clamp,), kwargs = {})
     %ge_1 : [num_users=1] = call_function[target=operator.ge](args = (%item, 1), kwargs = {})
     %_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge_1, Runtime assertion failed for expression u0 >= 1 on node 'ge_1'), kwargs = {})
-    %res : [num_users=2] = call_function[target=operator.getitem](args = (%l_flat_args_0_, slice(None, item, None)), kwargs = {})
-    %getattr_1 : [num_users=1] = call_function[target=builtins.getattr](args = (%res, _local_tensor), kwargs = {})
+    %getitem : [num_users=2] = call_function[target=operator.getitem](args = (%l_x_, slice(None, item, None)), kwargs = {})
+    %getattr_1 : [num_users=1] = call_function[target=builtins.getattr](args = (%getitem, _local_tensor), kwargs = {})
     %sym_size_int : [num_users=2] = call_function[target=torch.ops.aten.sym_size.int](args = (%getattr_1, 0), kwargs = {})
     %ge_2 : [num_users=1] = call_function[target=operator.ge](args = (%sym_size_int, 0), kwargs = {})
     %_assert_scalar_default_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge_2, Runtime assertion failed for expression u2 >= 0 on node 'ge_2'), kwargs = {})
     %le : [num_users=1] = call_function[target=operator.le](args = (%sym_size_int, 4), kwargs = {})
     %_assert_scalar_default_2 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%le, Runtime assertion failed for expression u2 <= 4 on node 'le'), kwargs = {})
-    return (res,)""", # noqa: B950
+    return (getitem,)""", # noqa: B950
         )


@@ -962,7 +962,7 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
         x = (torch.randn(4, 16, requires_grad=True),)
         with self.assertRaisesRegex(Exception, "weight = self.linear.w"):
-            torch._dynamo.functional_export._dynamo_graph_capture_for_export(Model())(x)
+            torch._dynamo.functional_export.dynamo_graph_capture_for_export(Model())(x)

 instantiate_parametrized_tests(ExceptionTests)


@@ -8146,7 +8146,6 @@ class ReproTestsDevice(torch._dynamo.test_case.TestCase):
             unsafe_grad(y) # should not warn
             self.assertEqual(len(w), 1)

-    @torch._dynamo.config.patch(install_free_tensors=True)
     def test_partial_export(self):
         class Foo(torch.nn.Module):
             def __init__(self):
@@ -8166,14 +8165,14 @@ class ReproTestsDevice(torch._dynamo.test_case.TestCase):
             def forward(self, a, b):
                 return a + b

-        from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
+        from torch._dynamo.functional_export import dynamo_graph_capture_for_export

         foo = Foo()
         foo.parallelize()
         x = torch.randn(4, 4, dtype=torch.float32)
         y = torch.randn(4, 4, dtype=torch.float32)
         ref = foo(x, y)
-        gm = _dynamo_graph_capture_for_export(foo)(x, y)
+        gm = dynamo_graph_capture_for_export(foo)(x, y)
         res = gm(x, y)
         self.assertEqual(res, ref)


@@ -387,9 +387,9 @@ def forward(self, x):
         export_inputs = ((dct, lst, 56), {})
         eager_inputs = copy.deepcopy(export_inputs)

-        from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
+        from torch._dynamo.functional_export import dynamo_graph_capture_for_export

-        graph_module = _dynamo_graph_capture_for_export(Foo())(
+        graph_module = dynamo_graph_capture_for_export(Foo())(
             *export_inputs[0], **export_inputs[1]
         )
@@ -406,9 +406,9 @@ def forward(self, x):
         export_inputs = ((torch.randn(4, 4),), {})
         eager_inputs = copy.deepcopy(export_inputs)

-        from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
+        from torch._dynamo.functional_export import dynamo_graph_capture_for_export

-        graph_module = _dynamo_graph_capture_for_export(Foo())(
+        graph_module = dynamo_graph_capture_for_export(Foo())(
             *export_inputs[0], **export_inputs[1]
         )


@@ -11,6 +11,7 @@ import sympy
 import torch
 import torch.fx
 import torch.utils._pytree as pytree
+from torch._dispatch.python import enable_python_dispatcher
 from torch._dynamo.convert_frame import CaptureOutput, fullgraph_capture, get_traced_fn
 from torch._dynamo.eval_frame import argument_names, check_user_input_output
 from torch._dynamo.exc import UserErrorType
@@ -579,9 +580,10 @@ def pytreeify(
     fake_mode = torch._dynamo.utils.detect_fake_mode(flat_out_shuffle_args)
     if fake_mode and fake_mode.shape_env is None:
         fake_mode.shape_env = ShapeEnv()
-    out_shuffle_graph = make_fx(
-        out_shuffle, tracing_mode="symbolic", proxy_module_inputs=True
-    )(*flat_out_shuffle_args)
+    with enable_python_dispatcher():
+        out_shuffle_graph = make_fx(
+            out_shuffle, tracing_mode="real", proxy_module_inputs=True
+        )(*flat_out_shuffle_args)
     _normalize_shuffle_graph(out_shuffle_graph)

     assert out_shuffle.out_spec is not None
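
The last hunk is the only functional change outside the tests: pytreeify now traces its output-shuffle function with make_fx in "real" tracing mode under enable_python_dispatcher() instead of tracing_mode="symbolic". Below is a minimal, self-contained sketch of that tracing pattern; the `shuffle` function is a stand-in rather than the pytreeify internals, and the real call above additionally passes proxy_module_inputs=True.

# Sketch only: traces a toy function the same way the hunk above traces
# out_shuffle, i.e. make_fx with tracing_mode="real" under the Python dispatcher.
import torch
from torch._dispatch.python import enable_python_dispatcher
from torch.fx.experimental.proxy_tensor import make_fx


def shuffle(a, b):
    # stand-in for an output-reordering function
    return b + 1, a * 2


example_args = (torch.randn(2), torch.randn(2))

with enable_python_dispatcher():
    shuffle_graph = make_fx(shuffle, tracing_mode="real")(*example_args)

print(shuffle_graph.graph)  # FX graph of the add and mul recorded above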