Enable local tensor mode on DTensor view ops test (#165596)
While enabling this test, we discovered a lack of support for sub meshes. This change adds limited sub mesh support by properly computing rank coordinates for a given sub mesh. The implementation follows a similar approach to the one used for collectives: we infer all sub meshes for the given dimensions and compute each rank's coordinates with respect to its sub mesh.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165596
Approved by: https://github.com/ezyang
Committed by: PyTorch MergeBot
Parent: 7d0f872cb3
Commit: 2cd5fd1588
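The approach described above can be pictured without the DTensor machinery: slice a root mesh of ranks down to a sub mesh, then locate each rank in the sub mesh's tensor of ranks with a nonzero() lookup, which is the shape of the computation the patch performs per simulated rank. The sketch below is illustrative only; the 2x3 mesh layout and the coords_in_submesh helper are assumptions for the example, not PyTorch APIs.

```python
# Illustrative sketch only: a toy analogue of the per-rank coordinate lookup,
# not the DTensor/DeviceMesh API. Assumes a 2x3 "root mesh" of ranks 0..5.
import torch

root_mesh = torch.arange(6).reshape(2, 3)  # dims ("a", "b")

def coords_in_submesh(submesh: torch.Tensor, rank: int) -> list[int] | None:
    """Return the rank's coordinates within `submesh`, or None if absent."""
    hits = (submesh == rank).nonzero().tolist()
    assert len(hits) in (0, 1)  # a rank appears at most once in a mesh
    return hits[0] if hits else None

# Slicing dim "b" yields one 1-D sub mesh per row of the root mesh; rank 4
# lives in the second sub mesh at coordinate [1], while rank 0 is absent.
submesh_b = root_mesh[1]                  # ranks [3, 4, 5]
print(coords_in_submesh(submesh_b, 4))    # [1]
print(coords_in_submesh(submesh_b, 0))    # None
```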
@@ -1023,7 +1023,7 @@ class DTensorMeshTest(DTensorTestBase):
 DTensorMeshTestWithLocalTensor = create_local_tensor_test_class(
     DTensorMeshTest,
     skipped_tests=[
-        # Submeshes are not supported by local tensor mode
+        # Test asserts must be rewritten for local tensor
         "test_from_local_sub_mesh",
         "test_default_value_sub_mesh",
         "test_redistribute_sub_mesh",
@@ -30,6 +30,7 @@ from torch.distributed.tensor.debug import CommDebugMode
 from torch.distributed.tensor.placement_types import _StridedShard, Placement
 from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.distributed._tensor.common_dtensor import (
+    create_local_tensor_test_class,
     DTensorTestBase,
     with_comms,
 )
@@ -647,7 +648,7 @@ class TestViewOps(DTensorTestBase):
     @with_comms
     def test_squeeze_(self):
         mesh_2d = init_device_mesh(self.device_type, (3, 2), mesh_dim_names=("a", "b"))
-        torch.manual_seed(self.rank)
+        self.init_manual_seed_for_rank()
         x = torch.randn((1, 4), device=self.device_type)
         dist_x = DTensor.from_local(x, mesh_2d, [Partial(), Shard(1)])
         self._test_op_on_dtensor(
@@ -664,5 +665,13 @@ class TestViewOps(DTensorTestBase):
         self.assertEqual(dist_x.placements, [Partial(), Shard(0)])
 
 
+TestViewOpsWithLocalTensor = create_local_tensor_test_class(
+    TestViewOps,
+    skipped_tests=[
+        # Comparing data pointers is not supported for local tensor
+        "test_dtensor_view_op_uneven",
+    ],
+)
+
 if __name__ == "__main__":
     run_tests()
@@ -57,8 +57,9 @@ import torch
 from torch import Size, SymBool, SymInt, Tensor
 from torch._C import DispatchKey, DispatchKeySet, ScriptObject
 from torch._export.wrappers import mark_subclass_constructor_exportable_experimental
-from torch.distributed import DeviceMesh
+from torch.distributed import DeviceMesh, ProcessGroup
 from torch.distributed._functional_collectives import AsyncCollectiveTensor
+from torch.distributed.distributed_c10d import _get_default_group
 from torch.fx.experimental._constant_symnode import ConstantIntNode
 from torch.nested._internal.nested_int import NestedIntNode
 from torch.utils import _pytree as pytree
@@ -112,6 +113,9 @@ def _for_each_rank_run_func(
     alias: bool = True,
 ) -> Any:
     flat_args, args_spec = pytree.tree_flatten((args, kwargs))
+    flat_args = [
+        a.wait() if isinstance(a, AsyncCollectiveTensor) else a for a in flat_args
+    ]
 
     cpu_state = torch.get_rng_state()
     devices, states = get_device_states((args, kwargs))
@@ -250,6 +254,13 @@ class LocalIntNode:
             {r: self._local_ints[r] * _int_on_rank(other, r) for r in self._local_ints}
         )
 
+    def floordiv(
+        self, other: "int | LocalIntNode | ConstantIntNode"
+    ) -> "LocalIntNode | ConstantIntNode":
+        return LocalIntNode(
+            {r: self._local_ints[r] // _int_on_rank(other, r) for r in self._local_ints}
+        )
+
     def mod(
         self, other: "int | LocalIntNode | ConstantIntNode"
     ) -> "LocalIntNode | ConstantIntNode":
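As a rough illustration of the per-rank semantics the new floordiv follows (each rank can hold a different integer, and the operation is applied rank-wise), here is a toy stand-in; it is not the real LocalIntNode and does not model SymInt integration or mixed operand types.

```python
# Toy stand-in for per-rank integer arithmetic (not the real LocalIntNode):
# each "local int" keeps one value per rank, and floordiv applies // rank-wise.
class ToyLocalInt:
    def __init__(self, local_ints: dict[int, int]) -> None:
        self.local_ints = local_ints

    def floordiv(self, other: int) -> "ToyLocalInt":
        return ToyLocalInt({r: v // other for r, v in self.local_ints.items()})

# Rank 0 holds 7 elements, rank 1 holds 5; dividing by 2 stays per-rank.
sizes = ToyLocalInt({0: 7, 1: 5})
print(sizes.floordiv(2).local_ints)  # {0: 3, 1: 2}
```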
@@ -595,7 +606,7 @@ class LocalTensorMode(TorchDispatchMode):
         # For LocalTensors, verify they have compatible ranks
         for a in flat_args:
             if isinstance(a, LocalTensor):
-                assert a._ranks == self.ranks, (
+                assert a._ranks <= self.ranks, (
                     f"Input LocalTensor {a} and LocalTensorMode must be configured for the same ranks"
                 )
 
@@ -696,15 +707,28 @@ class _LocalDeviceMesh:
         lm = local_tensor_mode()
         assert lm is not None, "Unexpectedly not in LocalTensorMode"
 
-        rank_coords = (self.mesh == lm.rank_map(lambda r: torch.tensor(r))).nonzero()
-        # NB: unlike the regular mechanism, we don't allow for MPMD
-        assert rank_coords.size(0) == 1
-        assert isinstance(rank_coords[0], LocalTensor)
-        coords: list[dict[int, int]] = [{} for _ in range(rank_coords.size(1))]
-        for r, v in rank_coords[0]._local_tensors.items():
-            for i, x in enumerate(v.tolist()):
-                coords[i][r] = x
+        root_mesh = self._get_root_mesh()
+        submesh_dims = self.mesh_dim_names
+
+        coords: list[dict[int, int]] = [{} for _ in range(self.ndim)]
+        old_get_rank = DeviceMesh.get_rank  # type: ignore[assignment]
+        try:
+            for r in lm.ranks:
+                DeviceMesh.get_rank = lambda self: r  # type: ignore[method-assign]
+                submesh = (
+                    root_mesh
+                    if submesh_dims is None
+                    else root_mesh.__getitem__(submesh_dims)
+                )
+                rank_coords = (submesh.mesh == r).nonzero().tolist()
+                assert len(rank_coords) in (0, 1)
+                if len(rank_coords) == 0:
+                    continue
+                for d, c in enumerate(rank_coords[0]):
+                    coords[d][r] = c
+        finally:
+            DeviceMesh.get_rank = old_get_rank  # type: ignore[method-assign]
 
         out = [torch.SymInt(LocalIntNode(c)) for c in coords]
 
         return out  # type: ignore[return-value]
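For context, the loop above builds one {rank: coordinate} map per mesh dimension and later wraps each map in a SymInt backed by LocalIntNode. A minimal standalone sketch of that accumulation, assuming a toy 3x2 mesh of ranks and plain dicts in place of DeviceMesh/LocalIntNode:

```python
# Toy accumulation of per-dimension {rank: coordinate} maps over a 3x2 mesh.
# Mirrors the loop structure in the hunk above, without DeviceMesh/LocalIntNode.
import torch

mesh = torch.arange(6).reshape(3, 2)            # ranks laid out as 3x2
coords: list[dict[int, int]] = [{} for _ in range(mesh.ndim)]

for r in range(6):                              # one pass per participating rank
    hits = (mesh == r).nonzero().tolist()
    if not hits:
        continue                                # rank not in this (sub) mesh
    for d, c in enumerate(hits[0]):
        coords[d][r] = c

print(coords[0])  # dim 0 coordinate per rank: {0: 0, 1: 0, 2: 1, 3: 1, 4: 2, 5: 2}
print(coords[1])  # dim 1 coordinate per rank: {0: 0, 1: 1, 2: 0, 3: 1, 4: 0, 5: 1}
```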
@@ -643,6 +643,11 @@ class _StridedShard(Shard):
 
         return replicate_tensor.contiguous()
 
+    @staticmethod
+    @maybe_run_for_local_tensor
+    def _local_shard_size(sharded_indices: list[torch.Tensor], rank: int) -> int:
+        return len(sharded_indices[rank])
+
     def _local_shard_size_and_offset(
         self,
         curr_local_size: int,
@@ -665,7 +670,7 @@ class _StridedShard(Shard):
         # squeeze back to 1D indices tensor
         sharded_indices = [shard.view(-1) for shard in sharded_indices]
 
-        local_shard_size = len(sharded_indices[rank])
+        local_shard_size = _StridedShard._local_shard_size(sharded_indices, rank)
 
         # offsets from _StridedShard is never used
         return local_shard_size, None
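The new _local_shard_size helper simply measures a rank's slice of the precomputed index tensors, and wrapping it with @maybe_run_for_local_tensor lets local tensor mode evaluate it once per simulated rank. A standalone sketch of the size computation on made-up indices (toy data, not the DTensor sharding logic):

```python
# Toy illustration: per-rank index tensors for an uneven split of 5 elements
# across 2 ranks, and the local shard size each rank ends up with.
import torch

sharded_indices = list(torch.arange(5).tensor_split(2))  # [tensor([0, 1, 2]), tensor([3, 4])]

def local_shard_size(sharded_indices: list[torch.Tensor], rank: int) -> int:
    return len(sharded_indices[rank])

print(local_shard_size(sharded_indices, 0))  # 3
print(local_shard_size(sharded_indices, 1))  # 2
```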
@@ -381,6 +381,9 @@ class DTensorTestBase(MultiProcessTestCase):
         backend = dist.get_default_backend_for_device(DEVICE_TYPE)
         return backend
 
+    def init_manual_seed_for_rank(self) -> None:
+        torch.manual_seed(self.rank)
+
     def build_device_mesh(self) -> DeviceMesh:
         return init_device_mesh(self.device_type, (self.world_size,))
 
@@ -735,6 +738,9 @@ class LocalDTensorTestBase(DTensorTestBase):
     def _spawn_processes(self) -> None:
         pass
 
+    def init_manual_seed_for_rank(self) -> None:
+        torch.manual_seed(0)
+
 
 def make_wrapped(fn, ctxs):
     @functools.wraps(fn)