mirror of https://github.com/pytorch/pytorch.git, synced 2025-11-18 09:34:57 +08:00

revert change of case

committed by PyTorch MergeBot
parent d990b72872
commit 2ace9e465a
@@ -35,8 +35,6 @@ from torch.distributed.tensor.parallel import (
 from torch.distributed.tensor.placement_types import _StridedShard
-from torch.testing import make_tensor
 from torch.testing._internal.common_utils import IS_FBCODE, run_tests, skipIfHpu
-from torch.testing._internal.common_device_type import skipXPUIf
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     create_local_tensor_test_class,
     DTensorTestBase,
@@ -650,7 +648,7 @@ DTensorTestWithLocalTensor = create_local_tensor_test_class(
 class DTensorMeshTest(DTensorTestBase):
     @property
     def world_size(self):
-        return min(8, torch.accelerator.device_count())
+        return 8

     def sub_mesh_assert_equal(self, mesh, exp_in_mesh, exp_out_of_mesh, tensor):
         if self.rank in mesh:
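The deleted line capped the test world size at the number of visible accelerator devices; the restored line pins it to 8 ranks. A minimal sketch of the two strategies, assuming a recent PyTorch build where the torch.accelerator module is available (not part of the commit):

import torch

def dynamic_world_size(cap: int = 8) -> int:
    # Reverted behavior: never request more ranks than visible devices.
    return min(cap, torch.accelerator.device_count())

def fixed_world_size() -> int:
    # Restored behavior: the test always assumes 8 ranks.
    return 8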
@@ -700,7 +698,7 @@ class DTensorMeshTest(DTensorTestBase):
         self.assertEqual(sharded_tensor.to_local().shape, torch.Size([3, 3]))

         mesh_2d = DeviceMesh(
-            self.device_type, torch.arange(self.world_size).reshape(2, self.world_size // 2)
+            self.device_type, torch.arange(self.world_size).reshape(2, 4)
         )

         with mesh_2d:
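The two reshape calls agree only when world_size == 8; with a dynamic world size on fewer devices, the restored hardcoded 4 would not match. A quick sketch of the equivalence under the fixed assumption:

import torch

world_size = 8  # the value the restored test assumes
assert torch.equal(
    torch.arange(world_size).reshape(2, world_size // 2),
    torch.arange(8).reshape(2, 4),
)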
@@ -1066,7 +1064,7 @@ DTensorMeshTestWithLocalTensor = create_local_tensor_test_class(
 class TestDTensorPlacementTypes(DTensorTestBase):
     @property
     def world_size(self):
-        return min(8, torch.accelerator.device_count())
+        return 8

     def _create_tensor(self, size):
         # Keep everything deterministic.
@@ -1082,7 +1080,7 @@ class TestDTensorPlacementTypes(DTensorTestBase):
         mesh = self.build_device_mesh()
         shard_placement = Shard(0)

-        for size in range(self.world_size):
+        for size in range(8):
             tensor = self._create_tensor(size)
             splitted_tensor_list, pad_sizes = shard_placement._split_tensor(
                 tensor,
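For context, a hedged sketch of why the loop walks sizes 0 through 7: tensors shorter than the mesh split unevenly, which exercises the padding path that _split_tensor reports via pad_sizes. The sketch approximates the split with the public torch.chunk; the private DTensor helper additionally pads the short shards.

import torch

world_size = 8  # fixed by the revert above
for size in range(world_size):
    t = torch.arange(size, dtype=torch.float32)
    # chunk may return fewer than world_size pieces for short tensors;
    # Shard(0)._split_tensor pads those ranks instead.
    chunks = t.chunk(world_size, dim=0)
    print(f"size={size}: shard sizes {[c.numel() for c in chunks]}")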