Fix placement reconstruct (#164937)

If a DTensor is returned from the compiled region, the object is created by the fx graph call itself, so we never need to reconstruct it. But if it escapes as a side effect (e.g., a Placement stored on the module), we do need to reconstruct it in the generated bytecode.
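Concretely, the pattern that needs reconstruction looks like the following minimal sketch (illustrative only, not code from this PR; it assumes a process group and a 1-D device mesh created with init_device_mesh already exist):

import torch
from torch.distributed.tensor import distribute_tensor, Shard


class StashPlacement(torch.nn.Module):
    def forward(self, x):
        # Returning the DTensor itself is handled by the fx graph call, but
        # stashing its placement on `self` is a side effect: Dynamo must emit
        # bytecode that rebuilds the Placement object (e.g. Shard(0)) after
        # the graph runs.
        self.placement = x.placements[0]
        return x + x


# Illustrative usage, mirroring the test added below:
# x_dt = distribute_tensor(torch.randn(4, 8), device_mesh, placements=[Shard(0)])
# gm = _dynamo_graph_capture_for_export(StashPlacement())(x_dt)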

Differential Revision: [D84159000](https://our.internmc.facebook.com/intern/diff/D84159000)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164937
Approved by: https://github.com/StrongerXi
Tugsbayasgalan Manlaibaatar
2025-10-08 12:49:37 -07:00
committed by PyTorch MergeBot
parent 724463d5a2
commit afeec56a5a
2 changed files with 68 additions and 1 deletion


@@ -11,7 +11,8 @@ from torch._functorch.aot_autograd import aot_export_joint_with_descriptors
 from torch._functorch.partitioners import min_cut_rematerialization_partition
 from torch._guards import tracing, TracingContext
 from torch.distributed.device_mesh import init_device_mesh
-from torch.distributed.tensor import distribute_tensor, Replicate
+from torch.distributed.tensor import distribute_tensor, Partial, Replicate, Shard
+from torch.distributed.tensor._api import DTensor
 from torch.distributed.tensor._dtensor_spec import DTensorSpec
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
@@ -39,6 +40,21 @@ class SimpleModel(torch.nn.Module):
         return self.mlp_1(self.mlp_0(input))
 
 
+class EinsumModel(torch.nn.Module):
+    """Simple model that uses einsum with DTensor inputs and returns DTensor."""
+
+    def __init__(self):
+        super().__init__()
+        self.placement = None
+
+    def forward(self, x, y, z):
+        result = torch.einsum("bsh,hd->bsd", x, y)
+        self.placement = result.placements[0]
+        self.placement_2 = y.placements[0]
+        self.placement_3 = z.placements[0]
+        return result
+
+
 class SimpleModelDynamicShapes(torch.nn.Module):
     def __init__(self, device):
         super().__init__()
@@ -334,6 +350,32 @@ class DTensorExportTest(TestCase):
             """[[4, 10], [4], [10, 4], [10], [s22, 10], [s22, 10]]""",
         )
 
+    def test_einsum_dtensor_export(self):
+        """Test exporting a model with einsum that has DTensor inputs/outputs with side effects"""
+        world_size = 4
+        # Create device mesh
+        device_mesh = init_device_mesh(self.device_type, mesh_shape=(world_size,))
+        model = EinsumModel()
+
+        x = torch.randn(4, 8, 16)
+        x_dtensor = distribute_tensor(x, device_mesh, placements=[Shard(0)])
+
+        # y: [16, 16] replicated
+        y = torch.randn(16, 16)
+        z = torch.randn(16, 16)
+        y_dtensor = distribute_tensor(y, device_mesh, placements=[Replicate()])
+        z_dtensor = DTensor.from_local(z, device_mesh, placements=[Partial()])
+
+        # Run model to verify it works
+        output = model(x_dtensor, y_dtensor, z_dtensor)
+
+        with torch._dynamo.config.patch(install_free_tensors=True):
+            # TODO: switch to use the official graph_capture API once it is ready
+            gm = _dynamo_graph_capture_for_export(model)(
+                x_dtensor, y_dtensor, z_dtensor
+            )
+        output_gm = gm(x_dtensor, y_dtensor, z_dtensor)
+        self.assertEqual(output, output_gm)
 
 
 instantiate_parametrized_tests(DTensorExportTest)


@@ -29,6 +29,7 @@ from torch.fx.experimental._backward_state import BackwardState
 from .. import compiled_autograd, variables
 from .._trace_wrapped_higher_order_op import trace_wrapped
 from ..bytecode_transformation import create_call_function
+from ..exc import unimplemented_v2
 from ..external_utils import call_module_hooks_from_backward_state
 from ..guards import GuardBuilder, install_guard
@@ -231,6 +232,30 @@ class PlacementVariable(DistributedVariable):
         return super().call_method(tx, name, args, kwargs)
 
+    def reconstruct(self, codegen):
+        # Reconstruct the Placement object by calling its constructor
+        # e.g., Shard(0), Replicate(), Partial()
+        from torch.distributed.tensor.placement_types import Partial, Replicate, Shard
+
+        placement_type = type(self.value)
+
+        # Load the placement class
+        codegen.add_push_null(
+            lambda: codegen.load_import_from(
+                "torch.distributed.tensor.placement_types", placement_type.__name__
+            )
+        )
+
+        # For Shard, we need to pass the dim argument
+        if isinstance(self.value, Shard):
+            codegen(ConstantVariable.create(self.value.dim))
+            codegen.extend_output(create_call_function(1, False))
+        # Replicate and Partial have no required args
+        elif istype(self.value, (Replicate, Partial)):
+            codegen.extend_output(create_call_function(0, False))
+        else:
+            super().reconstruct(codegen)
+
 
 class DeviceMeshVariable(DistributedVariable):
     @staticmethod