Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Fix replacement reconstruct (#164937)
If we return a DTensor, the object is created via the fx graph call, so we never need to reconstruct it. But if there is a side effect, we do need to reconstruct it.

Differential Revision: [D84159000](https://our.internmc.facebook.com/intern/diff/D84159000)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164937
Approved by: https://github.com/StrongerXi
Committed by: PyTorch MergeBot
Parent: 724463d5a2
Commit: afeec56a5a
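To make the scenario concrete, here is a minimal, hypothetical sketch (not part of this PR) of the pattern the fix targets: the traced function returns a DTensor, which Dynamo builds via the fx graph, but it also stashes a Placement on the module as a side effect, and that stashed object must be reconstructed in the generated bytecode. The module name and the single-process gloo setup are assumptions for illustration only.

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor


class StashPlacement(torch.nn.Module):
    """Hypothetical module: returns a DTensor (rebuilt via the fx graph) and also
    stores a Placement on self, the side effect that forces Dynamo to
    reconstruct the Placement object after the graph runs."""

    def forward(self, x):
        result = x + 1
        self.placement = result.placements[0]  # side effect on the module
        return result


if __name__ == "__main__":
    # Single-process world so the sketch runs without torchrun (assumption).
    dist.init_process_group(
        "gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
    )
    mesh = init_device_mesh("cpu", mesh_shape=(1,))
    x = distribute_tensor(torch.randn(4, 4), mesh, placements=[Shard(0)])

    model = StashPlacement()
    compiled = torch.compile(model, backend="eager")
    out = compiled(x)
    print(type(out), model.placement)  # DTensor output, Shard(dim=0) side effect
    dist.destroy_process_group()
```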
@@ -11,7 +11,8 @@ from torch._functorch.aot_autograd import aot_export_joint_with_descriptors
 from torch._functorch.partitioners import min_cut_rematerialization_partition
 from torch._guards import tracing, TracingContext
 from torch.distributed.device_mesh import init_device_mesh
-from torch.distributed.tensor import distribute_tensor, Replicate
+from torch.distributed.tensor import distribute_tensor, Partial, Replicate, Shard
+from torch.distributed.tensor._api import DTensor
 from torch.distributed.tensor._dtensor_spec import DTensorSpec
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
@@ -39,6 +40,21 @@ class SimpleModel(torch.nn.Module):
         return self.mlp_1(self.mlp_0(input))
 
 
+class EinsumModel(torch.nn.Module):
+    """Simple model that uses einsum with DTensor inputs and returns DTensor."""
+
+    def __init__(self):
+        super().__init__()
+        self.placement = None
+
+    def forward(self, x, y, z):
+        result = torch.einsum("bsh,hd->bsd", x, y)
+        self.placement = result.placements[0]
+        self.placement_2 = y.placements[0]
+        self.placement_3 = z.placements[0]
+        return result
+
+
 class SimpleModelDynamicShapes(torch.nn.Module):
     def __init__(self, device):
         super().__init__()
@@ -334,6 +350,32 @@ class DTensorExportTest(TestCase):
             """[[4, 10], [4], [10, 4], [10], [s22, 10], [s22, 10]]""",
         )
 
+    def test_einsum_dtensor_export(self):
+        """Test exporting a model with einsum that has DTensor inputs/outputs with side effects"""
+        world_size = 4
+        # Create device mesh
+        device_mesh = init_device_mesh(self.device_type, mesh_shape=(world_size,))
+        model = EinsumModel()
+
+        x = torch.randn(4, 8, 16)
+        x_dtensor = distribute_tensor(x, device_mesh, placements=[Shard(0)])
+
+        # y: [16, 16] replicated
+        y = torch.randn(16, 16)
+        z = torch.randn(16, 16)
+        y_dtensor = distribute_tensor(y, device_mesh, placements=[Replicate()])
+        z_dtensor = DTensor.from_local(z, device_mesh, placements=[Partial()])
+
+        # Run model to verify it works
+        output = model(x_dtensor, y_dtensor, z_dtensor)
+        with torch._dynamo.config.patch(install_free_tensors=True):
+            # TODO: switch to use the official graph_capture API once it is ready
+            gm = _dynamo_graph_capture_for_export(model)(
+                x_dtensor, y_dtensor, z_dtensor
+            )
+        output_gm = gm(x_dtensor, y_dtensor, z_dtensor)
+        self.assertEqual(output, output_gm)
+
+
 instantiate_parametrized_tests(DTensorExportTest)
 
 
@@ -29,6 +29,7 @@ from torch.fx.experimental._backward_state import BackwardState
 
 from .. import compiled_autograd, variables
 from .._trace_wrapped_higher_order_op import trace_wrapped
+from ..bytecode_transformation import create_call_function
 from ..exc import unimplemented_v2
 from ..external_utils import call_module_hooks_from_backward_state
 from ..guards import GuardBuilder, install_guard
@@ -231,6 +232,30 @@ class PlacementVariable(DistributedVariable):
 
         return super().call_method(tx, name, args, kwargs)
 
+    def reconstruct(self, codegen):
+        # Reconstruct the Placement object by calling its constructor
+        # e.g., Shard(0), Replicate(), Partial()
+        from torch.distributed.tensor.placement_types import Partial, Replicate, Shard
+
+        placement_type = type(self.value)
+
+        # Load the placement class
+        codegen.add_push_null(
+            lambda: codegen.load_import_from(
+                "torch.distributed.tensor.placement_types", placement_type.__name__
+            )
+        )
+
+        # For Shard, we need to pass the dim argument
+        if isinstance(self.value, Shard):
+            codegen(ConstantVariable.create(self.value.dim))
+            codegen.extend_output(create_call_function(1, False))
+        # Replicate and Partial have no required args
+        elif istype(self.value, (Replicate, Partial)):
+            codegen.extend_output(create_call_function(0, False))
+        else:
+            super().reconstruct(codegen)
+
 
 class DeviceMeshVariable(DistributedVariable):
     @staticmethod
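For reference, the bytecode emitted by this reconstruct path is roughly equivalent to re-creating the placement by calling its constructor with the minimal required arguments. Below is a minimal sketch of that behavior in plain Python; the helper name rebuild_placement is ours, not part of the PR, and the placement types are the ones the diff itself imports.

```python
from torch.distributed.tensor.placement_types import Partial, Replicate, Shard


def rebuild_placement(placement):
    # Mirrors the codegen above: Shard carries its dim argument, while
    # Replicate and Partial (exact types) take no required constructor args.
    if isinstance(placement, Shard):
        return Shard(placement.dim)
    if type(placement) in (Replicate, Partial):
        return type(placement)()
    raise NotImplementedError(type(placement))


assert rebuild_placement(Shard(0)) == Shard(0)
assert rebuild_placement(Replicate()) == Replicate()
assert rebuild_placement(Partial()) == Partial()
```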