Enable local tensor mode on DTensor view ops test (#165596)

While enabling this test we discovered a lack of support for sub meshes. Added limited support
for sub meshes by properly computing rank coordinates for a given sub mesh. The implementation
follows a similar approach to collectives: we infer all sub meshes for the given dimensions and
compute each rank's coordinates with respect to its sub mesh.
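
As a rough illustration of the coordinate computation described above (a minimal standalone sketch, not the code in this PR; the helper submesh_coords, the tensor-based mesh representation, and the example mesh are hypothetical):

    # Hypothetical sketch: compute each rank's coordinates with respect to its sub mesh,
    # given a root mesh stored as a tensor of ranks and the dims that span the sub meshes.
    import torch

    def submesh_coords(
        root_mesh: torch.Tensor, dims: tuple[int, ...]
    ) -> dict[int, tuple[int, ...]]:
        # Dimensions not selected enumerate the distinct sub meshes; every rank appears
        # in exactly one of them, mirroring how collectives infer sub groups.
        other_dims = [d for d in range(root_mesh.ndim) if d not in dims]
        permuted = root_mesh.permute(*other_dims, *dims)
        sub_shape = permuted.shape[len(other_dims):]
        coords: dict[int, tuple[int, ...]] = {}
        for submesh in permuted.reshape(-1, *sub_shape):
            # Walk every coordinate of this sub mesh and record it for the rank found there.
            index_grid = torch.cartesian_prod(*(torch.arange(s) for s in sub_shape))
            for coord in index_grid.reshape(-1, len(sub_shape)):
                rank = int(submesh[tuple(coord.tolist())])
                coords[rank] = tuple(coord.tolist())
        return coords

    # A 3x2 mesh with dim names ("a", "b"): slicing out dim 1 ("b") yields three 1-D
    # sub meshes [0, 1], [2, 3], [4, 5]; rank 3 has coordinate (1,) within its sub mesh.
    mesh = torch.arange(6).reshape(3, 2)
    assert submesh_coords(mesh, (1,))[3] == (1,)

The actual change below instead patches DeviceMesh.get_rank per simulated rank and reuses DeviceMesh's own slicing to obtain each sub mesh; the sketch only illustrates the coordinate arithmetic.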

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165596
Approved by: https://github.com/ezyang
Dzmitry Huba
2025-10-16 09:43:58 -07:00
committed by PyTorch MergeBot
parent 7d0f872cb3
commit 2cd5fd1588
5 changed files with 57 additions and 13 deletions

View File

@@ -1023,7 +1023,7 @@ class DTensorMeshTest(DTensorTestBase):
 DTensorMeshTestWithLocalTensor = create_local_tensor_test_class(
     DTensorMeshTest,
     skipped_tests=[
-        # Submeshes are not supported by local tensor mode
+        # Test asserts must be rewritten for local tensor
         "test_from_local_sub_mesh",
         "test_default_value_sub_mesh",
         "test_redistribute_sub_mesh",

View File

@@ -30,6 +30,7 @@ from torch.distributed.tensor.debug import CommDebugMode
 from torch.distributed.tensor.placement_types import _StridedShard, Placement
 from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.distributed._tensor.common_dtensor import (
+    create_local_tensor_test_class,
     DTensorTestBase,
     with_comms,
 )
@@ -647,7 +648,7 @@ class TestViewOps(DTensorTestBase):
     @with_comms
     def test_squeeze_(self):
         mesh_2d = init_device_mesh(self.device_type, (3, 2), mesh_dim_names=("a", "b"))
-        torch.manual_seed(self.rank)
+        self.init_manual_seed_for_rank()
         x = torch.randn((1, 4), device=self.device_type)
         dist_x = DTensor.from_local(x, mesh_2d, [Partial(), Shard(1)])
         self._test_op_on_dtensor(
@@ -664,5 +665,13 @@ class TestViewOps(DTensorTestBase):
         self.assertEqual(dist_x.placements, [Partial(), Shard(0)])
 
 
+TestViewOpsWithLocalTensor = create_local_tensor_test_class(
+    TestViewOps,
+    skipped_tests=[
+        # Comparing data pointers is not supported for local tensor
+        "test_dtensor_view_op_uneven",
+    ],
+)
+
 if __name__ == "__main__":
     run_tests()

View File

@@ -57,8 +57,9 @@ import torch
 from torch import Size, SymBool, SymInt, Tensor
 from torch._C import DispatchKey, DispatchKeySet, ScriptObject
 from torch._export.wrappers import mark_subclass_constructor_exportable_experimental
-from torch.distributed import DeviceMesh
+from torch.distributed import DeviceMesh, ProcessGroup
 from torch.distributed._functional_collectives import AsyncCollectiveTensor
+from torch.distributed.distributed_c10d import _get_default_group
 from torch.fx.experimental._constant_symnode import ConstantIntNode
 from torch.nested._internal.nested_int import NestedIntNode
 from torch.utils import _pytree as pytree
@@ -112,6 +113,9 @@ def _for_each_rank_run_func(
     alias: bool = True,
 ) -> Any:
     flat_args, args_spec = pytree.tree_flatten((args, kwargs))
+    flat_args = [
+        a.wait() if isinstance(a, AsyncCollectiveTensor) else a for a in flat_args
+    ]
 
     cpu_state = torch.get_rng_state()
     devices, states = get_device_states((args, kwargs))
@@ -250,6 +254,13 @@ class LocalIntNode:
             {r: self._local_ints[r] * _int_on_rank(other, r) for r in self._local_ints}
         )
 
+    def floordiv(
+        self, other: "int | LocalIntNode | ConstantIntNode"
+    ) -> "LocalIntNode | ConstantIntNode":
+        return LocalIntNode(
+            {r: self._local_ints[r] // _int_on_rank(other, r) for r in self._local_ints}
+        )
+
     def mod(
         self, other: "int | LocalIntNode | ConstantIntNode"
     ) -> "LocalIntNode | ConstantIntNode":
@@ -595,7 +606,7 @@ class LocalTensorMode(TorchDispatchMode):
 
         # For LocalTensors, verify they have compatible ranks
         for a in flat_args:
             if isinstance(a, LocalTensor):
-                assert a._ranks == self.ranks, (
+                assert a._ranks <= self.ranks, (
                     f"Input LocalTensor {a} and LocalTensorMode must be configured for the same ranks"
                 )
@@ -696,15 +707,28 @@ class _LocalDeviceMesh:
         lm = local_tensor_mode()
         assert lm is not None, "Unexpectedly not in LocalTensorMode"
 
-        rank_coords = (self.mesh == lm.rank_map(lambda r: torch.tensor(r))).nonzero()
-        # NB: unlike the regular mechanism, we don't allow for MPMD
-        assert rank_coords.size(0) == 1
-        assert isinstance(rank_coords[0], LocalTensor)
-        coords: list[dict[int, int]] = [{} for _ in range(rank_coords.size(1))]
-        for r, v in rank_coords[0]._local_tensors.items():
-            for i, x in enumerate(v.tolist()):
-                coords[i][r] = x
+        root_mesh = self._get_root_mesh()
+        submesh_dims = self.mesh_dim_names
+        coords: list[dict[int, int]] = [{} for _ in range(self.ndim)]
+        old_get_rank = DeviceMesh.get_rank  # type: ignore[assignment]
+        try:
+            for r in lm.ranks:
+                DeviceMesh.get_rank = lambda self: r  # type: ignore[method-assign]
+                submesh = (
+                    root_mesh
+                    if submesh_dims is None
+                    else root_mesh.__getitem__(submesh_dims)
+                )
+                rank_coords = (submesh.mesh == r).nonzero().tolist()
+                assert len(rank_coords) in (0, 1)
+                if len(rank_coords) == 0:
+                    continue
+                for d, c in enumerate(rank_coords[0]):
+                    coords[d][r] = c
+        finally:
+            DeviceMesh.get_rank = old_get_rank  # type: ignore[method-assign]
 
         out = [torch.SymInt(LocalIntNode(c)) for c in coords]
         return out  # type: ignore[return-value]

View File

@@ -643,6 +643,11 @@ class _StridedShard(Shard):
         return replicate_tensor.contiguous()
 
+    @staticmethod
+    @maybe_run_for_local_tensor
+    def _local_shard_size(sharded_indices: list[torch.Tensor], rank: int) -> int:
+        return len(sharded_indices[rank])
+
     def _local_shard_size_and_offset(
         self,
         curr_local_size: int,
@@ -665,7 +670,7 @@ class _StridedShard(Shard):
         # squeeze back to 1D indices tensor
         sharded_indices = [shard.view(-1) for shard in sharded_indices]
 
-        local_shard_size = len(sharded_indices[rank])
+        local_shard_size = _StridedShard._local_shard_size(sharded_indices, rank)
 
         # offsets from _StridedShard is never used
         return local_shard_size, None

View File

@@ -381,6 +381,9 @@ class DTensorTestBase(MultiProcessTestCase):
         backend = dist.get_default_backend_for_device(DEVICE_TYPE)
         return backend
 
+    def init_manual_seed_for_rank(self) -> None:
+        torch.manual_seed(self.rank)
+
     def build_device_mesh(self) -> DeviceMesh:
         return init_device_mesh(self.device_type, (self.world_size,))
@@ -735,6 +738,9 @@ class LocalDTensorTestBase(DTensorTestBase):
     def _spawn_processes(self) -> None:
         pass
 
+    def init_manual_seed_for_rank(self) -> None:
+        torch.manual_seed(0)
+
 
 def make_wrapped(fn, ctxs):
     @functools.wraps(fn)