Compare commits

...

1 Commit

0754b9153a [DTensor] bad cases for explicit_order_placements
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
2025-11-16 17:00:05 -08:00
2 changed files with 102 additions and 2 deletions

File 1 of 2

@@ -1,6 +1,7 @@
# Owner(s): ["oncall: distributed"]
import itertools
import math
from contextlib import nullcontext
from typing import Any
@@ -46,6 +47,27 @@ c10d_functional = torch.ops.c10d_functional
class LocalTest(TestCase):
    def test_explicit_order_placements_error(self):
        # no shard at the end
        with self.assertRaises(NotImplementedError):
            _explicit_order_placements(
                [2],
                [_StridedShard(0, split_factor=2)],
            )
        # two _StridedShards, still no plain shard at the end
        with self.assertRaises(NotImplementedError):
            _explicit_order_placements(
                [2, 3],
                [_StridedShard(0, split_factor=1), _StridedShard(0, split_factor=1)],
            )
        # there is a shard, but shard.dim does not match - equivalent to no shard at the end
        with self.assertRaises(NotImplementedError):
            _explicit_order_placements(
                [2, 3], [_StridedShard(0, split_factor=1), Shard(1)]
            )
        # shard at the end, but split_factor does not match mesh.size()
        with self.assertRaises(NotImplementedError):
            _explicit_order_placements(
                [2, 3], [_StridedShard(0, split_factor=4), Shard(0)]
            )

    def test_explicit_order_placements(self):
        # mesh_shape: ShapeType, placements: Sequence[Placement]
        test_cases = [
@@ -155,7 +177,7 @@ class LocalTest(TestCase):
class UtilTest(DTensorTestBase):
    @property
    def world_size(self):
-        return 8
+        return 2

    def _compute_start_end_offsets(self, global_offset, local_size, n_dim):
        offset = []
@@ -292,6 +314,62 @@ class UtilTest(DTensorTestBase):
                global_tensor[dim0_start:dim0_end, dim1_start:dim1_end],
            )

    @with_comms
    def test_strided_shard_compute_local_shape_and_global_offset_1D(self):
        assert self.world_size == 2
        device_mesh = init_device_mesh(self.device_type, (2,))
        batch_size, seq_len = 2, 3
        nelem = batch_size * seq_len
        global_tensor = torch.arange(nelem).view(batch_size * seq_len)
        global_shape = global_tensor.size()

        placements = (_StridedShard(dim=0, split_factor=batch_size),)
        dtensor = distribute_tensor(
            global_tensor, device_mesh, (Replicate(),)
        ).redistribute(device_mesh, placements)

        local_size, global_offset = compute_local_shape_and_global_offset(
            global_shape, device_mesh, placements
        )
        # the computed local shape must match the shape of the actual local shard
        self.assertEqual(local_size, dtensor._local_tensor.shape)

    @with_comms
    def test_strided_shard_compute_local_shape_and_global_offset_2D(self):
        device_mesh = init_device_mesh(self.device_type, (2, 2))
        batch_size, seq_len, dim1 = 2, 3, 3
        nelem = batch_size * seq_len * dim1
        global_tensor = torch.arange(nelem).view(batch_size * seq_len * dim1)
        global_shape = global_tensor.size()

        placements = (
            _StridedShard(dim=0, split_factor=batch_size),
            _StridedShard(
                dim=0,
                split_factor=batch_size
                * math.ceil(seq_len * 1.0 / device_mesh.size(0)),
            ),
        )
        dtensor = distribute_tensor(
            global_tensor, device_mesh, (Replicate(), Replicate())
        ).redistribute(device_mesh, placements)

        local_size, global_offset = compute_local_shape_and_global_offset(
            global_shape, device_mesh, placements
        )
        # the computed local shape must match the shape of the actual local shard
        self.assertEqual(local_size, dtensor._local_tensor.shape)

    @with_comms
    def test_fsdp_tp_meta_compute(self):
        # FSDP + TP sharding
@@ -335,6 +413,25 @@ class UtilTest(DTensorTestBase):
        self.assertEqual(local_shape[0], expected_shapes[rank])
        self.assertEqual(global_offset[0], expected_offsets[rank])

    @with_comms
    def test_strided_shard_2d_meta_compute(self):
        # FSDP + TP uneven sharding
        tp_size = 2
        dp_size = self.world_size // tp_size
        global_mesh = init_device_mesh(
            self.device_type, (dp_size, tp_size), mesh_dim_names=("dp", "tp")
        )
        global_tensor_shape = torch.Size([15, 5])
        placements = [_StridedShard(0, split_factor=tp_size), Shard(0)]

        local_shape, global_offset = compute_local_shape_and_global_offset(
            global_tensor_shape, global_mesh, placements
        )
        rank = global_mesh.get_rank()
        expected_shapes = [2, 2, 2, 2, 2, 2, 2, 1]
        expected_offsets = [0, 8, 2, 10, 4, 12, 6, 14]
        self.assertEqual(local_shape[0], expected_shapes[rank])
        self.assertEqual(global_offset[0], expected_offsets[rank])

    @with_comms
    def test_hsdp_tp_meta_compute(self):
        # HSDP + TP sharding
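
Where the hard-coded expectations in test_strided_shard_2d_meta_compute come from: with placements [_StridedShard(0, split_factor=tp_size), Shard(0)] on a (dp=4, tp=2) mesh, dim 0 is effectively sharded across tp first and then across dp. The sketch below is a standalone plain-Python derivation of the expected shapes and offsets for 15 rows; the helper names are ours and the torch.chunk-style uneven splitting is an assumption, not code from this PR.

import math

def chunk_sizes(n, k):
    # torch.chunk-style split: ceil(n/k)-sized chunks first, smaller remainder last
    full = math.ceil(n / k)
    sizes, remaining = [], n
    for _ in range(k):
        take = min(full, remaining)
        sizes.append(take)
        remaining -= take
    return sizes

def shard_dim0(total_rows, dp, tp):
    # shard across tp first (inner mesh dim), then dp, matching the
    # (mesh_dim, Shard) order that _explicit_order_placements produces
    shapes, offsets = {}, {}
    tp_sizes = chunk_sizes(total_rows, tp)
    tp_starts = [sum(tp_sizes[:i]) for i in range(tp)]
    for t in range(tp):
        dp_sizes = chunk_sizes(tp_sizes[t], dp)
        dp_starts = [tp_starts[t] + sum(dp_sizes[:i]) for i in range(dp)]
        for d in range(dp):
            rank = d * tp + t  # row-major rank layout on the (dp, tp) mesh
            shapes[rank], offsets[rank] = dp_sizes[d], dp_starts[d]
    n = dp * tp
    return [shapes[r] for r in range(n)], [offsets[r] for r in range(n)]

shapes, offsets = shard_dim0(15, dp=4, tp=2)
assert shapes == [2, 2, 2, 2, 2, 2, 2, 1]
assert offsets == [0, 8, 2, 10, 4, 12, 6, 14]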

File 2 of 2

@@ -105,7 +105,10 @@ def _explicit_order_placements(
                        )
                    aggregate_size *= mesh_shape[strided_mesh_dim]
                    ordered.append((strided_mesh_dim, Shard(p.dim)))
    if len(deferred_strided_placements) != 0:
        raise NotImplementedError(
            f"cannot convert {placements} into explicit order because of unresolved _StridedShard {deferred_strided_placements.items()}"
        )
    return ordered
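
For context on the check this hunk adds, below is a simplified standalone model of the conversion, reconstructed by us from the hunk's surrounding context rather than copied from the PR: each _StridedShard is deferred until a later Shard on the same tensor dim resolves it, resolved strided shards are appended after that Shard, and any placement left unresolved at the end now raises NotImplementedError.

from dataclasses import dataclass

@dataclass(frozen=True)
class Shard:
    dim: int

@dataclass(frozen=True)
class StridedShard:  # stand-in for _StridedShard
    dim: int
    split_factor: int

def explicit_order_placements(mesh_shape, placements):
    ordered = []
    deferred = {}  # tensor dim -> list of (mesh_dim, StridedShard)
    for mesh_dim, p in enumerate(placements):
        if isinstance(p, StridedShard):
            deferred.setdefault(p.dim, []).append((mesh_dim, p))
        else:
            ordered.append((mesh_dim, p))
            if isinstance(p, Shard) and p.dim in deferred:
                aggregate_size = mesh_shape[mesh_dim]
                for strided_mesh_dim, strided in reversed(deferred.pop(p.dim)):
                    if strided.split_factor != aggregate_size:
                        raise NotImplementedError("split_factor mismatch")
                    aggregate_size *= mesh_shape[strided_mesh_dim]
                    ordered.append((strided_mesh_dim, Shard(p.dim)))
    if deferred:
        # the check this PR adds: a StridedShard was never resolved
        raise NotImplementedError(f"unresolved StridedShard(s): {deferred}")
    return ordered

# FSDP-over-TP style input: the strided shard resolves against the later Shard(0),
# so dim 0 is ordered as (tp mesh dim first, then dp mesh dim)
print(explicit_order_placements((4, 2), [StridedShard(0, split_factor=2), Shard(0)]))
# -> [(1, Shard(dim=0)), (0, Shard(dim=0))]

# no resolving Shard at the end, as in test_explicit_order_placements_error
try:
    explicit_order_placements((2,), [StridedShard(0, split_factor=2)])
except NotImplementedError as exc:
    print(exc)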