Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[DCP] fix dcp gather_object/scatter_object_list (#147675)
For `dist.gather_object` / `dist.scatter_object_list`, `dst` / `src` is the destination/source rank on the global process group, regardless of the `group` argument; `_DistWrapper` was passing its group-local coordinator rank instead. This change records the coordinator's global rank (via `dist.get_global_rank`) and passes that to these collectives.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147675
Approved by: https://github.com/MeetVadakkanchery
Committed by PyTorch MergeBot
parent 1d7fc0c681
commit 3d62e81a1e
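For context, `dist.gather_object` and `dist.scatter_object_list` interpret `dst` / `src` as ranks on the global (default) process group even when a subgroup is passed via `group`. Below is a minimal sketch of the rank translation the fix performs, assuming an already-initialized 4-rank default process group (e.g. launched with torchrun); the subgroup layout here is illustrative and not taken from the patch.

import torch.distributed as dist

# Assumes dist.init_process_group(...) has already run on 4 ranks.
# Every rank must create both subgroups; each rank then uses the one it belongs to.
group_a = dist.new_group([0, 1])
group_b = dist.new_group([2, 3])
group = group_a if dist.get_rank() < 2 else group_b

coordinator_rank = 0  # group-local rank of the coordinator within `group`
# dst must be a rank on the global process group, so translate the
# group-local coordinator rank before calling the collective.
global_dst = dist.get_global_rank(group, coordinator_rank)  # 0 for {0, 1}, 2 for {2, 3}

gather_list = (
    [None] * dist.get_world_size(group)
    if dist.get_rank(group) == coordinator_rank
    else None
)
dist.gather_object(dist.get_rank(), gather_list, dst=global_dst, group=group)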
@@ -4,6 +4,7 @@ import io
 import sys
 
 import torch
+import torch.distributed as dist
 from torch.distributed._shard.sharded_tensor import (
     Shard,
     ShardedTensor,
@@ -14,12 +15,21 @@ from torch.distributed._shard.sharded_tensor.metadata import TensorProperties
 from torch.distributed.c10d_logger import _c10d_logger
 from torch.distributed.checkpoint.logger import _dcp_logger
 from torch.distributed.checkpoint.metadata import MetadataIndex
-from torch.distributed.checkpoint.utils import _create_file_view, find_state_dict_object
+from torch.distributed.checkpoint.utils import (
+    _create_file_view,
+    _DistWrapper,
+    find_state_dict_object,
+)
 from torch.testing._internal.common_utils import (
     run_tests,
     TEST_WITH_DEV_DBG_ASAN,
     TestCase,
 )
+from torch.testing._internal.distributed._tensor.common_dtensor import (
+    DTensorTestBase,
+    skip_if_lt_x_gpu,
+    with_comms,
+)
 from torch.testing._internal.distributed.distributed_utils import with_fake_comms
 
 
@@ -185,5 +195,53 @@ class TestReaderView(TestCase):
         self.assertEqual(ba, b"VWXYZ\0\0\0")
 
 
+class TestDistWrapper(DTensorTestBase):
+    @property
+    def world_size(self):
+        return min(4, torch.cuda.device_count())
+
+    @with_comms
+    @skip_if_lt_x_gpu(4)
+    def test_gather_object(self):
+        mesh_2d = dist.init_device_mesh(self.device_type, (2, self.world_size // 2))
+        torch.random.manual_seed(dist.get_rank())
+
+        dist_wrapper = _DistWrapper(
+            mesh_2d.get_group(1), use_dist=True, coordinator_rank=0
+        )
+
+        rank = mesh_2d.get_rank()
+        half_world_size = self.world_size // 2
+        gathered_objects = dist_wrapper.gather_object(rank)
+        expected_objects = (
+            list(range(rank, rank + half_world_size))
+            if rank % half_world_size == 0
+            else None
+        )
+        assert gathered_objects == expected_objects
+
+    @with_comms
+    @skip_if_lt_x_gpu(4)
+    def test_scatter_object(self):
+        mesh_2d = dist.init_device_mesh(self.device_type, (2, self.world_size // 2))
+        torch.random.manual_seed(dist.get_rank())
+
+        dist_wrapper = _DistWrapper(
+            mesh_2d.get_group(1), use_dist=True, coordinator_rank=0
+        )
+
+        rank = mesh_2d.get_rank()
+        half_world_size = self.world_size // 2
+
+        objects = (
+            list(range(rank, rank + half_world_size))
+            if rank % half_world_size == 0
+            else None
+        )
+        scattered_objects = dist_wrapper.scatter_object(objects)
+        expected_objects = rank
+        assert scattered_objects == expected_objects
+
+
 if __name__ == "__main__":
     run_tests()
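For reference, on the 4-rank 2x2 mesh exercised by these tests, the mesh dim-1 groups are {0, 1} and {2, 3}, and each group's coordinator is its first member (global ranks 0 and 2). The plain-Python sketch below (no distributed setup required) just reproduces the arithmetic the tests assert:

# Mirrors the expectations in TestDistWrapper for world_size == 4.
world_size = 4
half_world_size = world_size // 2  # size of each dim-1 group

for rank in range(world_size):
    # gather_object(rank): only each group's coordinator (rank % half_world_size == 0)
    # receives the list of ranks in its group; every other rank gets None.
    gathered = (
        list(range(rank, rank + half_world_size))
        if rank % half_world_size == 0
        else None
    )
    # scatter_object(objects): every rank gets its own global rank back.
    scattered = rank
    print(f"rank {rank}: gathered={gathered}, scattered={scattered}")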
@@ -92,9 +92,15 @@ class _DistWrapper:
         self.use_dist = use_dist
         self.coordinator_rank = coordinator_rank
         if self.use_dist:
+            self.global_coordinator_rank = (
+                dist.get_global_rank(group, coordinator_rank)
+                if group is not None
+                else coordinator_rank
+            )
             self.rank = dist.get_rank(group)
             self.is_coordinator = self.rank == coordinator_rank
         else:
+            self.global_coordinator_rank = 0
             self.rank = 0
             self.is_coordinator = True
 
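The `if group is not None` guard above handles the default-process-group case: on the default group, group-local and global ranks coincide, so `coordinator_rank` can be used directly. A tiny sketch of that equivalence, assuming an initialized default process group:

import torch.distributed as dist

# On the default (WORLD) process group, group-local rank == global rank,
# so no translation through dist.get_global_rank is needed.
assert dist.get_global_rank(dist.group.WORLD, dist.get_rank()) == dist.get_rank()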
@@ -129,7 +135,7 @@
             dist.gather_object(
                 obj=object,
                 object_gather_list=gather_objs if self.is_coordinator else None,
-                dst=self.coordinator_rank,
+                dst=self.global_coordinator_rank,
                 group=self.group,
             )
             result = gather_objs
@@ -156,7 +162,7 @@
             dist.scatter_object_list(
                 scatter_object_output_list=gather_result,
                 scatter_object_input_list=object_list if self.is_coordinator else None,
-                src=self.coordinator_rank,
+                src=self.global_coordinator_rank,
                 group=self.group,
             )
 