Revert "Fix _StridedShard incorrect split (#165533)"

This reverts commit dfc8a1c5ddc8401197e9ab546e03b0f745edc27b.

Reverted https://github.com/pytorch/pytorch/pull/165533 on behalf of https://github.com/seemethere due to Causing a merge conflict internally, see D84829161 ([comment](https://github.com/pytorch/pytorch/pull/165533#issuecomment-3416143176))
This commit is contained in:
PyTorch MergeBot
2025-10-17 15:57:01 +00:00
parent 935ccdbe75
commit 85c5433d38
3 changed files with 52 additions and 82 deletions

View File

@ -20,7 +20,6 @@ from torch.distributed.tensor._collective_utils import shard_dim_alltoall
from torch.distributed.tensor._dtensor_spec import ShardOrderEntry
from torch.distributed.tensor._redistribute import redistribute_local_tensor
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.placement_types import _StridedShard
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
@ -1146,22 +1145,6 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
sharded_dt, mesh, tgt_placement, shard_order=None
)
@with_comms
def test_shard_order_same_data_as_strided_shard(self):
device_mesh = init_device_mesh(self.device_type, (4, 2))
x = torch.randn(8, 4, device=self.device_type)
# specify right-to-left order use _StridedShard
strided_placement = [_StridedShard(-2, split_factor=2), Shard(-2)]
x_strided_dt = distribute_tensor(x, device_mesh, strided_placement)
# specify right-to-left order use ordered shard
x_ordered_dt = self.distribute_tensor(
x,
device_mesh,
placements=[Shard(0), Shard(0)],
shard_order=(ShardOrderEntry(tensor_dim=0, mesh_dims=(1, 0)),),
)
self.assertEqual(x_ordered_dt.to_local(), x_strided_dt.to_local())
if __name__ == "__main__":
run_tests()