Files
pytorch/test/distributed/fsdp/test_shard_utils.py
Zeng, Xiangdong c6392fcc06 [2/N] Port 3 fsdp distributed test cases to Intel GPU (#160940)
As part of https://github.com/pytorch/pytorch/issues/114850, we are porting distributed tests to Intel GPU. This is the second PR for FSDP distributed test cases; the first was https://github.com/pytorch/pytorch/pull/160158.
We enable Intel GPU with the following methods, keeping the original code style as much as possible:
- Use `torch.accelerator.current_accelerator()` to determine the accelerator backend (see the sketch below)
- Enable XPU for some test paths
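
A minimal sketch of that device-selection pattern (the example tensor is illustrative, not part of the test file):

```python
import torch

# Prefer the active accelerator backend (e.g. "cuda" or "xpu"); fall back to CPU.
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"

# Tensors created in the tests are then moved to the selected backend.
t = torch.rand(4, 4).to(device=device_type)
```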

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160940
Approved by: https://github.com/guangyey, https://github.com/d4l3k
2025-09-17 10:45:28 +00:00

81 lines
2.4 KiB
Python

# Owner(s): ["oncall: distributed"]

import torch
from torch.distributed.distributed_c10d import _get_default_group
from torch.distributed.fsdp._shard_utils import (
    _create_chunk_dtensor,
    _create_chunk_sharded_tensor,
)
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    skip_if_lt_x_gpu,
    with_comms,
)

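# Prefer the active accelerator backend (e.g. CUDA or XPU); fall back to CPU.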
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"


class TestShardUtilsDistributed(FSDPTest):
    @property
    def world_size(self):
        return 2

    def _create_tensor(self, *size):
        # Keep everything deterministic.
        torch.manual_seed(0)
        return torch.rand(*size).to(device=device_type)

    @skip_if_lt_x_gpu(2)
    def test_create_chunk_sharded_tensor(self):
        for size in ((1,), (1, 6), (12,), (12, 6), (25,), (25, 6)):
            tensor = self._create_tensor(*size)
            # Shard the tensor across the default process group, gather it back
            # on rank 0, and check that the round trip preserves the data.
            sharded_tensor = _create_chunk_sharded_tensor(
                tensor,
                self.rank,
                self.world_size,
                torch.accelerator.device_count(),
                _get_default_group(),
            )
            output = (
                torch.empty(*size).to(device=device_type) if self.rank == 0 else None
            )
            sharded_tensor.gather(0, output)
            if self.rank == 0:
                self.assertEqual(tensor, output)


class TestShardUtilsDistributedDTensor(DTensorTestBase):
    @property
    def world_size(self):
        return 2

    def _create_tensor(self, *size):
        # Keep everything deterministic.
        torch.manual_seed(0)
        return torch.rand(*size).to(device=device_type)

    @with_comms
    @skip_if_lt_x_gpu(2)
    def test_create_chunk_dtensor(self):
        device_mesh = self.build_device_mesh()
        for size in ((1,), (1, 6), (12,), (12, 6), (25,), (25, 6)):
            tensor = self._create_tensor(*size)
            tensor_chunks = torch.chunk(tensor, self.world_size, dim=0)
            # Each rank's local shard must match the corresponding torch.chunk
            # slice; ranks beyond the number of chunks hold an empty local tensor.
            dtensor = _create_chunk_dtensor(tensor, self.rank, device_mesh)
            local_tensor = dtensor.to_local()
            if local_tensor.numel() != 0:
                self.assertEqual(local_tensor, tensor_chunks[self.rank])
            else:
                self.assertEqual(self.rank >= len(tensor_chunks), True)
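
# Running this file directly executes the tests through PyTorch's multi-process
# test harness, which spawns `world_size` workers per test case, e.g.:
#   python test/distributed/fsdp/test_shard_utils.py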

if __name__ == "__main__":
    run_tests()