revert change of case

This commit is contained in:
Liao, Wei
2025-09-03 14:50:47 +08:00
committed by PyTorch MergeBot
parent d990b72872
commit 2ace9e465a

View File

@ -35,8 +35,6 @@ from torch.distributed.tensor.parallel import (
from torch.distributed.tensor.placement_types import _StridedShard
from torch.testing import make_tensor
from torch.testing._internal.common_utils import IS_FBCODE, run_tests, skipIfHpu
from torch.testing._internal.common_device_type import skipXPUIf
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.distributed._tensor.common_dtensor import (
create_local_tensor_test_class,
DTensorTestBase,
@ -650,7 +648,7 @@ DTensorTestWithLocalTensor = create_local_tensor_test_class(
class DTensorMeshTest(DTensorTestBase):
@property
def world_size(self):
return min(8, torch.accelerator.device_count())
return 8
def sub_mesh_assert_equal(self, mesh, exp_in_mesh, exp_out_of_mesh, tensor):
if self.rank in mesh:
@ -700,7 +698,7 @@ class DTensorMeshTest(DTensorTestBase):
self.assertEqual(sharded_tensor.to_local().shape, torch.Size([3, 3]))
mesh_2d = DeviceMesh(
self.device_type, torch.arange(self.world_size).reshape(2, self.world_size // 2)
self.device_type, torch.arange(self.world_size).reshape(2, 4)
)
with mesh_2d:
@ -1066,7 +1064,7 @@ DTensorMeshTestWithLocalTensor = create_local_tensor_test_class(
class TestDTensorPlacementTypes(DTensorTestBase):
@property
def world_size(self):
return min(8, torch.accelerator.device_count())
return 8
def _create_tensor(self, size):
# Keep everything deterministic.
@ -1082,7 +1080,7 @@ class TestDTensorPlacementTypes(DTensorTestBase):
mesh = self.build_device_mesh()
shard_placement = Shard(0)
for size in range(self.world_size):
for size in range(8):
tensor = self._create_tensor(size)
splitted_tensor_list, pad_sizes = shard_placement._split_tensor(
tensor,