# Owner(s): ["oncall: distributed"] import torch from torch.distributed.distributed_c10d import _get_default_group from torch.distributed.fsdp._shard_utils import ( _create_chunk_dtensor, _create_chunk_sharded_tensor, ) from torch.testing._internal.common_fsdp import FSDPTest from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, skip_if_lt_x_gpu, with_comms, ) device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu" class TestShardUtilsDistributed(FSDPTest): @property def world_size(self): return 2 def _create_tensor(self, *size): # Keep everything deterministic. torch.manual_seed(0) return torch.rand(*size).to(device=device_type) @skip_if_lt_x_gpu(2) def test_create_chunk_sharded_tensor(self): for size in ((1,), (1, 6), (12,), (12, 6), (25,), (25, 6)): tensor = self._create_tensor(*size) sharded_tensor = _create_chunk_sharded_tensor( tensor, self.rank, self.world_size, torch.accelerator.device_count(), _get_default_group(), ) output = ( torch.empty(*size).to(device=device_type) if self.rank == 0 else None ) sharded_tensor.gather(0, output) if self.rank == 0: self.assertEqual(tensor, output) class TestShardUtilsDistributedDTensor(DTensorTestBase): @property def world_size(self): return 2 def _create_tensor(self, *size): # Keep everything deterministic. torch.manual_seed(0) return torch.rand(*size).to(device=device_type) @with_comms @skip_if_lt_x_gpu(2) def test_create_chunk_dtensor(self): device_mesh = self.build_device_mesh() for size in ((1,), (1, 6), (12,), (12, 6), (25,), (25, 6)): tensor = self._create_tensor(*size) tensor_chunks = torch.chunk(tensor, self.world_size, dim=0) dtensor = _create_chunk_dtensor(tensor, self.rank, device_mesh) local_tensor = dtensor.to_local() if local_tensor.numel() != 0: self.assertEqual(local_tensor, tensor_chunks[self.rank]) else: self.assertEqual(self.rank >= len(tensor_chunks), True) if __name__ == "__main__": run_tests()