Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[C10D] Support group_dst/group_src in c10d send/recv object_list (#140847)
Also add mypy annotations.

Partially addresses RFC 0042 (https://github.com/pytorch/rfcs/pull/71).
See more details/motivation in https://github.com/pytorch/pytorch/pull/140460.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140847
Approved by: https://github.com/H-Huang
ghstack dependencies: #140843
Committed by: PyTorch MergeBot
Parent: c82c46ccc7
Commit: 98e6e69b1b
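The core of the change: send_object_list/recv_object_list can now address the peer by its rank inside ``group`` (``group_dst``/``group_src``) instead of its global rank (``dst``/``src``). A minimal usage sketch, assuming init_process_group has already run on at least two ranks and that the new keywords behave as described in the docstring changes below:

    import torch.distributed as dist

    # Assumes dist.init_process_group(...) has already been called on every rank.
    subgroup = dist.new_group([0, 1])  # group rank i happens to equal global rank i here

    if dist.get_rank() == 1:
        # Address the peer by its rank *within* subgroup, not its global rank.
        dist.send_object_list([{"answer": 42}], group_dst=0, group=subgroup)
    elif dist.get_rank() == 0:
        objects = [None]
        # `src` and `group_src` are mutually exclusive; give exactly one of them.
        dist.recv_object_list(objects, group_src=1, group=subgroup)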
@@ -3928,7 +3928,10 @@ class LargeCommTest(test_c10d_common.AbstractLargeCommTest, MultiProcessTestCase
         "set_device",
         [SetDeviceMethod.TORCH_CUDA_SET, SetDeviceMethod.COLLECTIVE_ARGUMENT],
     )
-    def test_send_recv_object_list_subgroup(self, set_device: SetDeviceMethod):
+    @parametrize("group_rank", [True, False])
+    def test_send_recv_object_list_subgroup(
+        self, set_device: SetDeviceMethod, group_rank
+    ):
         world_size = 4
         if self.rank >= world_size:
             return
@@ -3940,12 +3943,22 @@ class LargeCommTest(test_c10d_common.AbstractLargeCommTest, MultiProcessTestCase
         device = torch.device("cuda:%d" % self.rank)
         if self.rank == 0 or self.rank == 2:
             x = [{}]
-            c10d.recv_object_list(x, src=self.rank + 1, group=subgroup, device=device)
+            if group_rank:
+                c10d.recv_object_list(x, group_src=1, group=subgroup, device=device)
+            else:
+                c10d.recv_object_list(
+                    x, src=self.rank + 1, group=subgroup, device=device
+                )
             expected = [{"rank": self.rank + 1}]
             self.assertEqual(x, expected)
         else:
             x = [{"rank": self.rank}]
-            c10d.send_object_list(x, dst=self.rank - 1, group=subgroup, device=device)
+            if group_rank:
+                c10d.send_object_list(x, group_dst=0, group=subgroup, device=device)
+            else:
+                c10d.send_object_list(
+                    x, dst=self.rank - 1, group=subgroup, device=device
+                )

     @requires_nccl()
     @skip_if_lt_x_gpu(4)
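In the test above, group_src=1 and group_dst=0 replace src=self.rank + 1 and dst=self.rank - 1 because a group rank names a different global rank in each subgroup. A small sketch of that mapping, assuming pair subgroups [0, 1] and [2, 3] as the expected values in the test imply; get_global_rank/get_group_rank are the existing public helpers, not part of this diff:

    import torch.distributed as dist

    # Every rank must participate in every new_group call, even for groups it will not join.
    group_a = dist.new_group([0, 1])
    group_b = dist.new_group([2, 3])
    subgroup = group_a if dist.get_rank() in (0, 1) else group_b

    # Translate between the two rank spaces.
    peer_global_rank = dist.get_global_rank(subgroup, 1)            # 1 in [0, 1], 3 in [2, 3]
    my_group_rank = dist.get_group_rank(subgroup, dist.get_rank())  # 0 or 1 inside the pair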
@@ -3087,7 +3087,13 @@ def gather_object(


 @_exception_logger
-def send_object_list(object_list, dst, group=None, device=None):
+def send_object_list(
+    object_list: List[Any],
+    dst: Optional[int] = None,
+    group: Optional[ProcessGroup] = None,
+    device: Optional[torch.device] = None,
+    group_dst: Optional[int] = None,
+):
     """
     Sends picklable objects in ``object_list`` synchronously.

@@ -3105,7 +3111,8 @@ def send_object_list(object_list, dst, group=None, device=None):
         device (``torch.device``, optional): If not None, the objects are
             serialized and converted to tensors which are moved to the
             ``device`` before sending. Default is ``None``.
-
+        group_dst (int, optional): Destination rank on ``group``.
+            Must specify one of ``dst`` and ``group_dst`` but not both
     Returns:
         ``None``.

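The device argument matters because the objects never travel as Python objects: they are pickled into a byte tensor plus a companion size tensor, and those tensors are what actually get sent, on ``device`` if one is given. A rough sketch of that conversion, assuming plain pickle; the real helper in distributed_c10d.py is not shown in this diff and may differ in detail:

    import io
    import pickle

    import torch

    def object_to_tensor_sketch(obj, device):
        # Pickle the object and expose the bytes as a uint8 tensor on `device`.
        buf = io.BytesIO()
        pickle.dump(obj, buf)
        payload = bytearray(buf.getvalue())
        byte_tensor = torch.frombuffer(payload, dtype=torch.uint8).to(device)
        size_tensor = torch.tensor([byte_tensor.numel()], dtype=torch.long, device=device)
        return byte_tensor, size_tensor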
@@ -3143,11 +3150,9 @@ def send_object_list(object_list, dst, group=None, device=None):
         >>> objects
         ['foo', 12, {1: 2}]
     """
-    if get_rank() == dst:
-        raise ValueError(
-            "Invalid destination rank: destination rank should not be the same as "
-            "the rank of the current process."
-        )
+    group = _group_or_default_group(group)
+    group_dst = _canonicalize_group_rank(group, dst, group_dst)
+    _check_not_self_rank(group, group_dst, "destination")

     if _rank_not_in_group(group):
         _warn_not_in_group("send_object_list")
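The removed inline self-send check is folded into shared helpers whose bodies are not part of this diff. A plausible sketch of the checks they perform, with hypothetical names so as not to imply these are the exact implementations:

    import torch.distributed as dist

    def canonicalize_group_rank_sketch(group, global_rank, group_rank):
        # Exactly one of the global rank and the group rank may be supplied.
        if (global_rank is None) == (group_rank is None):
            raise ValueError("Must specify exactly one of a global rank and a group rank")
        if group_rank is None:
            group_rank = dist.get_group_rank(group, global_rank)  # translate global -> group rank
        return group_rank

    def check_not_self_rank_sketch(group, group_rank, rank_desc):
        # Sending to (or receiving from) yourself is always an error.
        if group_rank == dist.get_group_rank(group, dist.get_rank()):
            raise ValueError(
                f"Invalid {rank_desc} rank: {rank_desc} rank should not be the same as "
                "the rank of the current process."
            )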
@@ -3167,7 +3172,7 @@ def send_object_list(object_list, dst, group=None, device=None):
     object_sizes_tensor = torch.cat(size_list)

     # Send object sizes
-    send(object_sizes_tensor, dst=dst, group=group)
+    send(object_sizes_tensor, group_dst=group_dst, group=group)

     # Concatenate and send serialized object tensors
     # Note: torch.cat will do an extra memory copy to the current device, if the tensor_list
@@ -3177,11 +3182,17 @@ def send_object_list(object_list, dst, group=None, device=None):
     else:
         object_tensor = torch.cat(tensor_list)

-    send(object_tensor, dst=dst, group=group)
+    send(object_tensor, group_dst=group_dst, group=group)


 @_exception_logger
-def recv_object_list(object_list, src=None, group=None, device=None):
+def recv_object_list(
+    object_list: List[Any],
+    src: Optional[int] = None,
+    group: Optional[ProcessGroup] = None,
+    device: Optional[torch.device] = None,
+    group_src: Optional[int] = None,
+):
     """
     Receives picklable objects in ``object_list`` synchronously.

@@ -3197,6 +3208,7 @@ def recv_object_list(object_list, src=None, group=None, device=None):
             the default process group will be used. Default is ``None``.
         device (``torch.device``, optional): If not None, receives on this device.
             Default is ``None``.
+        group_src (int, optional): Destination rank on ``group``. Invalid to specify both ``src`` and ``group_src``.

     Returns:
         Sender rank. -1 if rank is not part of the group. If rank is part of the group,
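The return contract is unchanged: the call reports the sender's rank, or -1 when the caller is not part of the group. A brief usage sketch, assuming a subgroup created as in the test above and one CUDA device per rank:

    import torch
    import torch.distributed as dist

    # `subgroup` is assumed to exist, e.g. created with dist.new_group(...) as in the test above.
    objects = [None]
    device = torch.device("cuda:%d" % dist.get_rank())
    sender = dist.recv_object_list(objects, group_src=1, group=subgroup, device=device)
    if sender == -1:
        pass  # this process is not a member of `subgroup`; nothing was received
    else:
        print(f"received {objects[0]!r} from rank {sender}")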
@@ -3252,7 +3264,7 @@ def recv_object_list(object_list, src=None, group=None, device=None):
     )

     # Receive object sizes
-    rank_sizes = recv(object_sizes_tensor, src=src, group=group)
+    rank_sizes = recv(object_sizes_tensor, src=src, group=group, group_src=group_src)

     # Tensor to receive serialized objects into.
     object_tensor = torch.empty(  # type: ignore[call-overload]
@@ -3261,7 +3273,7 @@ def recv_object_list(object_list, src=None, group=None, device=None):
         device=current_device,
     )

-    rank_objects = recv(object_tensor, src=src, group=group)
+    rank_objects = recv(object_tensor, src=src, group=group, group_src=group_src)
     assert (
         rank_sizes == rank_objects
     ), "Mismatch in return ranks for object sizes and objects."