Compare commits

...

15 Commits

Author SHA1 Message Date
949ea834b8 revert register_constant 2025-09-29 16:13:12 -07:00
5ebd3d09a1 cleanup 2025-09-29 15:36:25 -07:00
61f71fb889 lint 2025-09-29 15:20:02 -07:00
22f413bfd1 workaround tensor_parallel_transformation issues 2025-09-29 15:20:02 -07:00
4dfcfa6b1b clean up test 2025-09-29 15:20:02 -07:00
4a89cdee21 fix test 2025-09-29 15:20:02 -07:00
1ca1d268f6 revert torch/_guards.py 2025-09-29 15:20:00 -07:00
5b547f8ef0 cleanup 2025-09-29 15:19:48 -07:00
1039d1dc2c revert torch/export/_trace.py 2025-09-29 15:19:48 -07:00
920168cc7b revert torch/_subclasses/meta_utils.py 2025-09-29 15:19:48 -07:00
fbcebfc276 cleanup test 2025-09-29 15:19:48 -07:00
6d4ca00f9c pytree.register_constant for DTensorSpec 2025-09-29 15:19:48 -07:00
e301c555e9 fixes 2025-09-29 15:19:48 -07:00
dcd94f54e6 hacks 2025-09-29 15:19:48 -07:00
suo 2f7d5f8b87 [rfc] Supporting exporting a model with DTensor params/inputs
As titled. One issue was that our fake mode detection didn't understand
DTensor.

RFC because:
- I'm a DTensor noob, so I don't know if this is the right way to use DTensor.
- I don't like making torch/_guards.py aware of DTensor; looking for
  suggestions on alternative ways to structure the code.
2025-09-29 15:19:46 -07:00
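
For context, the approach this stack settles on (commit 6d4ca00f9c, "pytree.register_constant for DTensorSpec") is a one-line pytree registration. A minimal sketch, using only names that appear in the diff below; the rationale in the comment is my reading, not the author's:

import torch.utils._pytree as pytree
from torch.distributed.tensor._dtensor_spec import DTensorSpec

# DTensorSpec is carried as metadata when DTensor values are flattened, so
# strict export appears to need it registered as a constant pytree leaf before
# it can trace DTensor params/inputs (see
# strict_export_and_aot_export_joint_with_descriptors in the diff below).
pytree.register_constant(DTensorSpec)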


@@ -0,0 +1,158 @@
# Owner(s): ["oncall: distributed"]
import contextlib
import unittest

import torch
import torch.distributed as dist
from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
from torch._functorch.aot_autograd import aot_export_joint_with_descriptors
from torch._functorch.partitioners import min_cut_rematerialization_partition
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Replicate
from torch.distributed.tensor._dtensor_spec import DTensorSpec
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    RowwiseParallel,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    requires_cuda,
    run_tests,
    TestCase,
)
from torch.testing._internal.distributed._tensor.common_dtensor import MLPModule
from torch.testing._internal.distributed.fake_pg import FakeStore


class SimpleModel(torch.nn.Module):
    def __init__(self, device):
        super().__init__()
        self.mlp_0 = MLPModule(device)
        self.mlp_1 = MLPModule(device)

    def forward(self, input):
        return self.mlp_1(self.mlp_0(input))
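

# MLPModule (from common_dtensor) is, roughly, net1 (nn.Linear) -> ReLU ->
# net2 (nn.Linear); the parallelize_plan in _run_test below targets those
# "net1"/"net2" submodules by name.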


def strict_export_and_aot_export_joint_with_descriptors(model, inputs):
    # needed for strict export
    torch.utils._pytree.register_constant(DTensorSpec)

    # install_free_tensors is required for dynamo to work
    with torch._dynamo.config.patch(
        install_free_tensors=True, inline_inbuilt_nn_modules=True
    ):
        with torch._export.utils._disable_aten_to_metadata_assertions():
            ep = torch.export.export(model, (inputs,), strict=True)

    # The joint_gm produced here is missing the backward region, due to an
    # incompatibility between ep.module() and aot_export_joint_with_descriptors.
    # Keeping this here to show the issue.
    return aot_export_joint_with_descriptors_alone(ep.module(), inputs)


def graph_capture_and_aot_export_joint_with_descriptors(model, inputs):
    with torch._dynamo.config.patch(install_free_tensors=True):
        # TODO: switch to use the official graph_capture API once it is ready
        gm = _dynamo_graph_capture_for_export(model)(inputs)
    return aot_export_joint_with_descriptors_alone(gm, inputs)


def aot_export_joint_with_descriptors_alone(model, inputs):
    with contextlib.ExitStack() as stack:
        joint_with_descriptors = aot_export_joint_with_descriptors(
            stack,
            model,
            (inputs,),
        )
        return joint_with_descriptors.graph_module
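

# Note: aot_export_joint_with_descriptors takes the caller's ExitStack as its
# first argument, presumably so that any contexts it opens stay active while
# the joint graph is consumed; that would be why the return above happens
# inside the `with` block rather than after it.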


def _count_op(gm, target):
    return sum(1 for node in gm.graph.nodes if node.target == target)
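

# Used by the assertions below, e.g.
#   _count_op(joint_gm, torch.ops._c10d_functional.all_reduce.default)
# counts how many functional all_reduce calls appear in a captured graph.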


@requires_cuda
class DTensorExportTest(TestCase):
    def tearDown(self):
        super().tearDown()
        dist.destroy_process_group()

    def setUp(self):
        super().setUp()
        self.world_size = 8
        store = FakeStore()
        dist.init_process_group(
            backend="fake", rank=0, world_size=self.world_size, store=store
        )
        self.device_type = "cuda"
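
    # The "fake" backend with FakeStore gives us an 8-rank process group on a
    # single process: collectives appear in the captured graphs but no real
    # communication happens, so the test does not need 8 real ranks.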

    def _run_test(self, export_fn):
        dp_degree = 2
        tp_degree = self.world_size // dp_degree

        # 2-D mesh is [dp, tp]
        mesh_2d = init_device_mesh(
            self.device_type,
            mesh_shape=(dp_degree, tp_degree),
            mesh_dim_names=["dp", "tp"],
        )

        model = SimpleModel(self.device_type)
        parallelize_plan = {
            "mlp_0.net1": ColwiseParallel(),
            "mlp_0.net2": RowwiseParallel(),
            "mlp_1.net1": ColwiseParallel(),
            "mlp_1.net2": RowwiseParallel(),
        }
        tp_model = parallelize_module(model, mesh_2d["tp"], parallelize_plan)

        inputs = torch.rand(20, 10, device=self.device_type)
        inputs = distribute_tensor(inputs, mesh_2d["tp"], placements=[Replicate()])

        joint_gm = export_fn(tp_model, inputs)
        fw_gm, bw_gm = min_cut_rematerialization_partition(
            joint_gm, None, num_fwd_outputs=1
        )

        self.assertEqual(
            _count_op(joint_gm, torch.ops._c10d_functional.all_reduce.default),
            3,
        )
        self.assertEqual(
            _count_op(fw_gm, torch.ops._c10d_functional.all_reduce.default),
            2,
        )
        self.assertEqual(
            _count_op(bw_gm, torch.ops._c10d_functional.all_reduce.default),
            1,
        )

    @parametrize(
        "export_fn",
        [
            graph_capture_and_aot_export_joint_with_descriptors,
            aot_export_joint_with_descriptors_alone,
        ],
    )
    def test_export_parallelize_module_with_dtensor_input(self, export_fn):
        self._run_test(export_fn)

    # Running aot_export_joint_with_descriptors on a strict-exported
    # exported_program.module() produces a joint graph whose backward region
    # is missing, hence the expected failure.
    @unittest.expectedFailure
    def test_strict_export_parallelize_module_with_dtensor_input(self):
        self._run_test(strict_export_and_aot_export_joint_with_descriptors)


instantiate_parametrized_tests(DTensorExportTest)


if __name__ == "__main__":
    run_tests()