[DCP] fix dcp gather_object/scatter_object_list (#147675)

`dist.gather_object`'s `dst` argument (and likewise `dist.scatter_object_list`'s `src`) is documented as the `Destination rank on global process group (regardless of group argument)`. `_DistWrapper` was passing its group-local `coordinator_rank` directly, so when it wrapped a subgroup the collectives targeted the wrong rank; it now translates the coordinator rank to a global rank first.
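
A minimal sketch of the underlying issue (the helper name and setup below are illustrative, not part of this PR): when the collective runs on a subgroup, the group-local coordinator rank has to be translated to a global rank before being passed as `dst`/`src`.

```python
import torch.distributed as dist


def gather_on_subgroup_coordinator(obj, group, coordinator_rank=0):
    # `coordinator_rank` is local to `group`, but dist.gather_object treats
    # `dst` as a rank on the *global* process group regardless of `group`,
    # so the group-local rank must be translated first.
    global_dst = dist.get_global_rank(group, coordinator_rank)
    is_coordinator = dist.get_rank() == global_dst
    output = [None] * group.size() if is_coordinator else None
    dist.gather_object(
        obj,
        object_gather_list=output,
        dst=global_dst,  # passing `coordinator_rank` here would target the wrong rank
        group=group,
    )
    return output
```

Passing the group-local rank directly only happens to work for the subgroup that contains global rank 0; for any other subgroup the collective targets a rank outside the group.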

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147675
Approved by: https://github.com/MeetVadakkanchery
Author: lanzongwei.lan
Date: 2025-03-06 21:20:34 +00:00
Committed by: PyTorch MergeBot
Parent: 1d7fc0c681
Commit: 3d62e81a1e
2 changed files with 67 additions and 3 deletions

Changed file 1 of 2:

@@ -4,6 +4,7 @@ import io
import sys
import torch
+import torch.distributed as dist
from torch.distributed._shard.sharded_tensor import (
Shard,
ShardedTensor,
@@ -14,12 +15,21 @@ from torch.distributed._shard.sharded_tensor.metadata import TensorProperties
from torch.distributed.c10d_logger import _c10d_logger
from torch.distributed.checkpoint.logger import _dcp_logger
from torch.distributed.checkpoint.metadata import MetadataIndex
-from torch.distributed.checkpoint.utils import _create_file_view, find_state_dict_object
+from torch.distributed.checkpoint.utils import (
+    _create_file_view,
+    _DistWrapper,
+    find_state_dict_object,
+)
from torch.testing._internal.common_utils import (
run_tests,
TEST_WITH_DEV_DBG_ASAN,
TestCase,
)
+from torch.testing._internal.distributed._tensor.common_dtensor import (
+    DTensorTestBase,
+    skip_if_lt_x_gpu,
+    with_comms,
+)
from torch.testing._internal.distributed.distributed_utils import with_fake_comms
@@ -185,5 +195,53 @@ class TestReaderView(TestCase):
self.assertEqual(ba, b"VWXYZ\0\0\0")
+
+
+class TestDistWrapper(DTensorTestBase):
+    @property
+    def world_size(self):
+        return min(4, torch.cuda.device_count())
+
+    @with_comms
+    @skip_if_lt_x_gpu(4)
+    def test_gather_object(self):
+        mesh_2d = dist.init_device_mesh(self.device_type, (2, self.world_size // 2))
+        torch.random.manual_seed(dist.get_rank())
+        dist_wrapper = _DistWrapper(
+            mesh_2d.get_group(1), use_dist=True, coordinator_rank=0
+        )
+        rank = mesh_2d.get_rank()
+        half_world_size = self.world_size // 2
+        gathered_objects = dist_wrapper.gather_object(rank)
+        expected_objects = (
+            list(range(rank, rank + half_world_size))
+            if rank % half_world_size == 0
+            else None
+        )
+        assert gathered_objects == expected_objects
+
+    @with_comms
+    @skip_if_lt_x_gpu(4)
+    def test_scatter_object(self):
+        mesh_2d = dist.init_device_mesh(self.device_type, (2, self.world_size // 2))
+        torch.random.manual_seed(dist.get_rank())
+        dist_wrapper = _DistWrapper(
+            mesh_2d.get_group(1), use_dist=True, coordinator_rank=0
+        )
+        rank = mesh_2d.get_rank()
+        half_world_size = self.world_size // 2
+        objects = (
+            list(range(rank, rank + half_world_size))
+            if rank % half_world_size == 0
+            else None
+        )
+        scattered_objects = dist_wrapper.scatter_object(objects)
+        expected_objects = rank
+        assert scattered_objects == expected_objects
if __name__ == "__main__":
run_tests()

Changed file 2 of 2:

@@ -92,9 +92,15 @@ class _DistWrapper:
self.use_dist = use_dist
self.coordinator_rank = coordinator_rank
if self.use_dist:
+            self.global_coordinator_rank = (
+                dist.get_global_rank(group, coordinator_rank)
+                if group is not None
+                else coordinator_rank
+            )
self.rank = dist.get_rank(group)
self.is_coordinator = self.rank == coordinator_rank
else:
+            self.global_coordinator_rank = 0
self.rank = 0
self.is_coordinator = True
@@ -129,7 +135,7 @@ class _DistWrapper:
dist.gather_object(
obj=object,
object_gather_list=gather_objs if self.is_coordinator else None,
-                dst=self.coordinator_rank,
+                dst=self.global_coordinator_rank,
group=self.group,
)
result = gather_objs
@@ -156,7 +162,7 @@ class _DistWrapper:
dist.scatter_object_list(
scatter_object_output_list=gather_result,
scatter_object_input_list=object_list if self.is_coordinator else None,
-            src=self.coordinator_rank,
+            src=self.global_coordinator_rank,
group=self.group,
)