Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Revert "Fix _StridedShard
incorrect split (#165533)"
This reverts commit dfc8a1c5ddc8401197e9ab546e03b0f745edc27b. Reverted https://github.com/pytorch/pytorch/pull/165533 on behalf of https://github.com/seemethere due to Causing a merge conflict internally, see D84829161 ([comment](https://github.com/pytorch/pytorch/pull/165533#issuecomment-3416143176))
@@ -20,7 +20,6 @@ from torch.distributed.tensor._collective_utils import shard_dim_alltoall
 from torch.distributed.tensor._dtensor_spec import ShardOrderEntry
 from torch.distributed.tensor._redistribute import redistribute_local_tensor
 from torch.distributed.tensor.debug import CommDebugMode
-from torch.distributed.tensor.placement_types import _StridedShard
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     parametrize,
@@ -1146,22 +1145,6 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
             sharded_dt, mesh, tgt_placement, shard_order=None
         )
 
-    @with_comms
-    def test_shard_order_same_data_as_strided_shard(self):
-        device_mesh = init_device_mesh(self.device_type, (4, 2))
-        x = torch.randn(8, 4, device=self.device_type)
-        # specify right-to-left order use _StridedShard
-        strided_placement = [_StridedShard(-2, split_factor=2), Shard(-2)]
-        x_strided_dt = distribute_tensor(x, device_mesh, strided_placement)
-        # specify right-to-left order use ordered shard
-        x_ordered_dt = self.distribute_tensor(
-            x,
-            device_mesh,
-            placements=[Shard(0), Shard(0)],
-            shard_order=(ShardOrderEntry(tensor_dim=0, mesh_dims=(1, 0)),),
-        )
-        self.assertEqual(x_ordered_dt.to_local(), x_strided_dt.to_local())
-
 
 if __name__ == "__main__":
     run_tests()
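For context, the test removed by this revert asserted that sharding dim 0 of an 8x4 tensor across a (4, 2) mesh in right-to-left mesh-dim order yields the same local data whether it is written with _StridedShard placements or with an explicit ShardOrderEntry shard order (the latter built through a DTensorTestBase helper). The sketch below reproduces only the _StridedShard side outside the test harness. It is illustrative rather than part of the PR: it assumes an 8-process launch (e.g. torchrun --nproc-per-node=8) with the gloo backend on CPU, and the main() wrapper and printed shape are additions of this note, not the original test.

# Illustrative sketch (not from the PR): the _StridedShard half of the
# reverted test, runnable under `torchrun --nproc-per-node=8` on CPU/gloo.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor
from torch.distributed.tensor.placement_types import _StridedShard  # private API


def main() -> None:
    # 4x2 mesh over 8 ranks: mesh dim 0 has size 4, mesh dim 1 has size 2.
    mesh = init_device_mesh("cpu", (4, 2))
    x = torch.randn(8, 4)

    # Right-to-left order on tensor dim 0 (written as -2 in the test): mesh
    # dim 1 splits first, then mesh dim 0 splits again, hence split_factor=2.
    strided_placement = [_StridedShard(-2, split_factor=2), Shard(-2)]
    x_strided_dt = distribute_tensor(x, mesh, strided_placement)

    # With 8 ranks sharding dim 0 of an 8x4 tensor, each rank holds a 1x4
    # local shard. The reverted test compared this local tensor against the
    # one produced by placements=[Shard(0), Shard(0)] plus
    # shard_order=(ShardOrderEntry(tensor_dim=0, mesh_dims=(1, 0)),), built
    # through a DTensorTestBase helper rather than the public API.
    print(f"rank {dist.get_rank()}: local shard {x_strided_dt.to_local().shape}")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()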