[c10d] Remove deprecated multi-gpu-per-thread APIs (#114156)

As of today, PyTorch Distributed's preferred programming model is one device per thread, as exemplified by the APIs in its documentation. The multi-GPU functions (which operate on multiple GPUs per CPU thread) have been deprecated for three releases. This change removes them ahead of the 2.2 release.
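As a minimal migration sketch (assuming one process per GPU launched by a tool such as torchrun, which supplies LOCAL_RANK and the rendezvous settings), the removed list-of-tensors call maps to a single-tensor call per rank:

    # Removed style: one process drives all local GPUs with a list of tensors.
    #   tensor_list = [torch.ones(1).cuda(i) for i in range(torch.cuda.device_count())]
    #   dist.all_reduce_multigpu(tensor_list)
    # Current style: one process (and one thread) per GPU, a single tensor each.
    import os
    import torch
    import torch.distributed as dist

    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    tensor = torch.ones(1, device="cuda")
    dist.all_reduce(tensor)  # every rank ends up with world_size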

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114156
Approved by: https://github.com/albanD, https://github.com/fduwjj, https://github.com/H-Huang
Ke Wen
2023-11-21 03:50:19 +00:00
committed by PyTorch MergeBot
parent f67696f45e
commit dc65f6c601
7 changed files with 12 additions and 657 deletions

View File

@@ -505,17 +505,14 @@ coverage_ignore_functions = [
"all_gather",
"all_gather_coalesced",
"all_gather_into_tensor",
"all_gather_multigpu",
"all_gather_object",
"all_reduce",
"all_reduce_coalesced",
"all_reduce_multigpu",
"all_to_all",
"all_to_all_single",
"barrier",
"batch_isend_irecv",
"broadcast",
"broadcast_multigpu",
"broadcast_object_list",
"destroy_process_group",
"gather",
@@ -543,9 +540,7 @@ coverage_ignore_functions = [
"new_subgroups_by_enumeration",
"recv",
"reduce",
"reduce_multigpu",
"reduce_scatter",
"reduce_scatter_multigpu",
"reduce_scatter_tensor",
"scatter",
"scatter_object_list",

View File

@@ -483,72 +483,11 @@ Multi-GPU collective functions
------------------------------
.. warning::
    The multi-GPU functions (which stand for multiple GPUs per CPU thread) are
    deprecated. As of today, PyTorch Distributed's preferred programming model
    is one device per thread, as exemplified by the APIs in this document. If
    you are a backend developer and want to support multiple devices per thread,
    please contact PyTorch Distributed's maintainers.

.. warning::
    The multi-GPU functions will be deprecated. If you must use them, please revisit our documentation later.

If you have more than one GPU on each node, when using the NCCL and Gloo backend,
:func:`~torch.distributed.broadcast_multigpu`
:func:`~torch.distributed.all_reduce_multigpu`
:func:`~torch.distributed.reduce_multigpu`
:func:`~torch.distributed.all_gather_multigpu` and
:func:`~torch.distributed.reduce_scatter_multigpu` support distributed collective
operations among multiple GPUs within each node. These functions can potentially
improve the overall distributed training performance and be easily used by
passing a list of tensors. Each Tensor in the passed tensor list needs
to be on a separate GPU device of the host where the function is called. Note
that the length of the tensor list needs to be identical among all the
distributed processes. Also note that currently the multi-GPU collective
functions are only supported by the NCCL backend.
For example, if the system we use for distributed training has 2 nodes, each
of which has 8 GPUs. On each of the 16 GPUs, there is a tensor that we would
like to all-reduce. The following code can serve as a reference:
Code running on Node 0
::
import torch
import torch.distributed as dist
dist.init_process_group(backend="nccl",
init_method="file:///distributed_test",
world_size=2,
rank=0)
tensor_list = []
for dev_idx in range(torch.cuda.device_count()):
tensor_list.append(torch.FloatTensor([1]).cuda(dev_idx))
dist.all_reduce_multigpu(tensor_list)
Code running on Node 1
::
import torch
import torch.distributed as dist
dist.init_process_group(backend="nccl",
init_method="file:///distributed_test",
world_size=2,
rank=1)
tensor_list = []
for dev_idx in range(torch.cuda.device_count()):
tensor_list.append(torch.FloatTensor([1]).cuda(dev_idx))
dist.all_reduce_multigpu(tensor_list)
After the call, all 16 tensors on the two nodes will have the all-reduced value
of 16
.. autofunction:: broadcast_multigpu
.. autofunction:: all_reduce_multigpu
.. autofunction:: reduce_multigpu
.. autofunction:: all_gather_multigpu
.. autofunction:: reduce_scatter_multigpu
.. _distributed-launch:
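For comparison, a minimal sketch of the removed 2-node, 16-GPU all-reduce example in the one-device-per-thread style; it assumes 8 processes per node launched with torchrun (which supplies LOCAL_RANK and the rendezvous settings), so the file:// init method above is not needed:

    # launch the same script with torchrun on each node
    # (8 processes per node, 2 nodes => world size 16)
    import os
    import torch
    import torch.distributed as dist

    dist.init_process_group(backend="nccl")            # one rank per GPU
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    tensor = torch.ones(1, device="cuda")
    dist.all_reduce(tensor)                             # tensor holds 16.0 on every rank
    dist.destroy_process_group()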

View File

@@ -2311,18 +2311,13 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
// Bump collective counter
seq_++;
// Currently, the API permits two scenarios where inputs.size() and
// outputs.size() are > 0.
// 1. If the call was a _coalesced call, all inputs must be on the same
// device.
// The group of nccl calls applies the collective separately to each input,
// but the group as a whole should be efficient, and might even execute as
// a single fused kernel.
// 2. If the call was a _multigpu call, all inputs must be on different
// devices.
// The nccl group applies the collective across them (eg, if the collective
// is an allreduce, the output on each device contains contributions summed
// across `inputs' tensors).
const auto devices = getDeviceList(inputs);
const bool inputs_same_dev = (devices.size() == 1);
const auto key = getKeyFromDevices(devices);
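A minimal sketch of the remaining (coalesced) scenario via the public all_reduce_coalesced wrapper, assuming an initialized NCCL process group and one GPU already selected for this rank:

    # All tensors live on this rank's single device; the backend may launch the
    # per-tensor reductions as one nccl group and potentially fuse them.
    import torch
    import torch.distributed as dist

    device = torch.device("cuda", torch.cuda.current_device())
    grads = [torch.ones(4, device=device), torch.ones(8, device=device)]
    dist.all_reduce_coalesced(grads)  # each tensor is all-reduced in-place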

View File

@@ -705,7 +705,7 @@ Additionally, ``MAX``, ``MIN`` and ``PRODUCT`` are not supported for complex ten
The values of this class can be accessed as attributes, e.g., ``ReduceOp.SUM``.
They are used in specifying strategies for reduction collectives, e.g.,
:func:`reduce`, :func:`all_reduce_multigpu`, etc.
This class does not support ``__members__`` property.)");
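A small usage sketch of ReduceOp with the remaining collectives, assuming an initialized process group and the current CUDA device already set for this rank:

    import torch
    import torch.distributed as dist

    t = torch.tensor([float(dist.get_rank())], device="cuda")
    dist.all_reduce(t, op=dist.ReduceOp.MAX)      # t now holds world_size - 1 on every rank
    dist.reduce(t, dst=0, op=dist.ReduceOp.SUM)   # only rank 0 receives the reduced result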

View File

@@ -42,17 +42,17 @@ DistStoreError = torch._C._DistStoreError
__all__ = [
'Backend', 'BackendConfig', 'GroupMember', 'P2POp', 'all_gather', 'all_gather_coalesced',
'all_gather_multigpu', 'all_gather_object', 'all_reduce',
'all_reduce_coalesced', 'all_reduce_multigpu', 'all_to_all',
'all_to_all_single', 'barrier', 'batch_isend_irecv', 'broadcast',
'broadcast_multigpu', 'broadcast_object_list', 'destroy_process_group',
'gather', 'gather_object', 'get_backend_config', 'get_backend', 'get_rank',
'get_world_size', 'group', 'init_process_group', 'irecv',
'is_gloo_available', 'is_initialized', 'is_mpi_available', 'is_backend_available',
'is_nccl_available', 'is_torchelastic_launched', 'is_ucc_available',
'isend', 'monitored_barrier', 'new_group', 'new_subgroups',
'new_subgroups_by_enumeration', 'recv', 'reduce', 'reduce_multigpu',
'reduce_scatter', 'reduce_scatter_multigpu', 'scatter',
'scatter_object_list', 'send', 'supports_complex',
'AllreduceCoalescedOptions', 'AllreduceOptions', 'AllToAllOptions',
'BarrierOptions', 'BroadcastOptions', 'GatherOptions', 'PrefixStore',
@@ -1851,66 +1851,6 @@ def batch_isend_irecv(p2p_op_list):
return reqs
@_exception_logger
def broadcast_multigpu(tensor_list, src, group=None, async_op=False, src_tensor=0):
"""
Broadcasts the tensor to the whole group with multiple GPU tensors per node.
``tensor`` must have the same number of elements in all the GPUs from
all processes participating in the collective. each tensor in the list must
be on a different GPU
Only nccl and gloo backend are currently supported
tensors should only be GPU tensors
Args:
tensor_list (List[Tensor]): Tensors that participate in the collective
operation. If ``src`` is the rank, then the specified ``src_tensor``
element of ``tensor_list`` (``tensor_list[src_tensor]``) will be
broadcast to all other tensors (on different GPUs) in the src process
and all tensors in ``tensor_list`` of other non-src processes.
You also need to make sure that ``len(tensor_list)`` is the same
for all the distributed processes calling this function.
src (int): Source rank.
group (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used.
async_op (bool, optional): Whether this op should be an async op
src_tensor (int, optional): Source tensor rank within ``tensor_list``
Returns:
Async work handle, if async_op is set to True.
None, if not async_op or if not part of the group
"""
warnings.warn(
"torch.distributed.broadcast_multigpu will be deprecated. If you must "
"use it, please revisit our documentation later at "
"https://pytorch.org/docs/master/distributed.html#multi-gpu-collective-functions"
)
if _rank_not_in_group(group):
_warn_not_in_group("broadcast_multigpu")
return
opts = BroadcastOptions()
opts.rootRank = src
opts.rootTensor = src_tensor
opts.asyncOp = async_op
if group is None or group is GroupMember.WORLD:
default_pg = _get_default_group()
work = default_pg.broadcast(tensor_list, opts)
else:
group_src_rank = get_group_rank(group, src)
opts.rootRank = group_src_rank
work = group.broadcast(tensor_list, opts)
if async_op:
return work
else:
work.wait()
@_exception_logger
def broadcast(tensor, src, group=None, async_op=False):
"""
@@ -1954,68 +1894,6 @@ def broadcast(tensor, src, group=None, async_op=False):
else:
work.wait()
@_exception_logger
def all_reduce_multigpu(tensor_list, op=ReduceOp.SUM, group=None, async_op=False):
r"""
Reduces the tensor data across all machines in a way that all get the final result.
This function reduces a number of tensors on every node,
while each tensor resides on different GPUs.
Therefore, the input tensor in the tensor list needs to be GPU tensors.
Also, each tensor in the tensor list needs to reside on a different GPU.
After the call, all ``tensor`` in ``tensor_list`` is going to be bitwise
identical in all processes.
Complex tensors are supported.
Only nccl and gloo backend is currently supported
tensors should only be GPU tensors
Args:
tensor_list (List[Tensor]): List of input and output tensors of
the collective. The function operates in-place and requires that
each tensor to be a GPU tensor on different GPUs.
You also need to make sure that ``len(tensor_list)`` is the same for
all the distributed processes calling this function.
op (optional): One of the values from
``torch.distributed.ReduceOp``
enum. Specifies an operation used for element-wise reductions.
group (ProcessGroup, optional): The process group to work on. If
``None``, the default process group will be used.
async_op (bool, optional): Whether this op should be an async op
Returns:
Async work handle, if async_op is set to True.
None, if not async_op or if not part of the group
"""
warnings.warn(
"torch.distributed.all_reduce_multigpu will be deprecated. If you must "
"use it, please revisit our documentation later at "
"https://pytorch.org/docs/master/distributed.html#multi-gpu-collective-functions"
)
if _rank_not_in_group(group):
return
tensor_list = [
t if not t.is_complex() else torch.view_as_real(t) for t in tensor_list
]
opts = AllreduceOptions()
opts.reduceOp = op
if group is None:
default_pg = _get_default_group()
work = default_pg.allreduce(tensor_list, opts)
else:
work = group.allreduce(tensor_list, opts)
if async_op:
return work
else:
work.wait()
@_exception_logger
def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False):
"""
@@ -2159,69 +2037,6 @@ def all_reduce_coalesced(tensors, op=ReduceOp.SUM, group=None, async_op=False):
else:
work.wait()
@_exception_logger
def reduce_multigpu(
tensor_list, dst, op=ReduceOp.SUM, group=None, async_op=False, dst_tensor=0
):
"""
Reduces the tensor data on multiple GPUs across all machines.
Each tensor in ``tensor_list`` should reside on a separate GPU.
Only the GPU of ``tensor_list[dst_tensor]`` on the process with rank ``dst``
is going to receive the final result.
Only nccl backend is currently supported
tensors should only be GPU tensors
Args:
tensor_list (List[Tensor]): Input and output GPU tensors of the
collective. The function operates in-place.
You also need to make sure that ``len(tensor_list)`` is the same for
all the distributed processes calling this function.
dst (int): Destination rank
op (optional): One of the values from
``torch.distributed.ReduceOp``
enum. Specifies an operation used for element-wise reductions.
group (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used.
async_op (bool, optional): Whether this op should be an async op
dst_tensor (int, optional): Destination tensor rank within
``tensor_list``
Returns:
Async work handle, if async_op is set to True.
None, otherwise
"""
warnings.warn(
"torch.distributed.reduce_multigpu will be deprecated. If you must "
"use it, please revisit our documentation later at "
"https://pytorch.org/docs/master/distributed.html#multi-gpu-collective-functions"
)
if _rank_not_in_group(group):
_warn_not_in_group("reduce_multigpu")
return
opts = ReduceOptions()
opts.reduceOp = op
opts.rootRank = dst
opts.rootTensor = dst_tensor
if group is None or group is GroupMember.WORLD:
default_pg = _get_default_group()
work = default_pg.reduce(tensor_list, opts)
else:
group_dst_rank = get_group_rank(group, dst)
opts.rootRank = group_dst_rank
work = group.reduce(tensor_list, opts)
if async_op:
return work
else:
work.wait()
@_exception_logger
def reduce(tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
"""
@@ -2267,83 +2082,6 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
else:
work.wait()
@_exception_logger
def all_gather_multigpu(
output_tensor_lists, input_tensor_list, group=None, async_op=False
):
"""
Gathers tensors from the whole group in a list.
Each tensor in ``tensor_list`` should reside on a separate GPU
Only nccl backend is currently supported
tensors should only be GPU tensors
Complex tensors are supported.
Args:
output_tensor_lists (List[List[Tensor]]): Output lists. It should
contain correctly-sized tensors on each GPU to be used for output
of the collective, e.g. ``output_tensor_lists[i]`` contains the
all_gather result that resides on the GPU of
``input_tensor_list[i]``.
Note that each element of ``output_tensor_lists`` has the size of
``world_size * len(input_tensor_list)``, since the function all
gathers the result from every single GPU in the group. To interpret
each element of ``output_tensor_lists[i]``, note that
``input_tensor_list[j]`` of rank k will be appear in
``output_tensor_lists[i][k * world_size + j]``
Also note that ``len(output_tensor_lists)``, and the size of each
element in ``output_tensor_lists`` (each element is a list,
therefore ``len(output_tensor_lists[i])``) need to be the same
for all the distributed processes calling this function.
input_tensor_list (List[Tensor]): List of tensors(on different GPUs) to
be broadcast from current process.
Note that ``len(input_tensor_list)`` needs to be the same for
all the distributed processes calling this function.
group (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used.
async_op (bool, optional): Whether this op should be an async op
Returns:
Async work handle, if async_op is set to True.
None, if not async_op or if not part of the group
"""
warnings.warn(
"torch.distributed.all_gather_multigpu will be deprecated. If you must "
"use it, please revisit our documentation later at "
"https://pytorch.org/docs/master/distributed.html#multi-gpu-collective-functions"
)
if _rank_not_in_group(group):
_warn_not_in_group("all_gather_multigpu")
return
output_tensor_lists = [
[t if not t.is_complex() else torch.view_as_real(t) for t in l]
for l in output_tensor_lists
]
input_tensor_list = [
t if not t.is_complex() else torch.view_as_real(t) for t in input_tensor_list
]
if group is None:
default_pg = _get_default_group()
work = default_pg.allgather(output_tensor_lists, input_tensor_list)
else:
work = group.allgather(output_tensor_lists, input_tensor_list)
if async_op:
return work
else:
work.wait()
def _object_to_tensor(obj, device):
f = io.BytesIO()
_pickler(f).dump(obj)
@@ -3235,77 +2973,6 @@ def scatter(tensor, scatter_list=None, src=0, group=None, async_op=False):
work.wait()
@_exception_logger
def reduce_scatter_multigpu(
output_tensor_list, input_tensor_lists, op=ReduceOp.SUM, group=None, async_op=False
):
"""
Reduce and scatter a list of tensors to the whole group.
Only nccl backend is currently supported.
Each tensor in ``output_tensor_list`` should reside on a separate GPU, as
should each list of tensors in ``input_tensor_lists``.
Args:
output_tensor_list (List[Tensor]): Output tensors (on different GPUs)
to receive the result of the operation.
Note that ``len(output_tensor_list)`` needs to be the same for all
the distributed processes calling this function.
input_tensor_lists (List[List[Tensor]]): Input lists. It should
contain correctly-sized tensors on each GPU to be used for input of
the collective, e.g. ``input_tensor_lists[i]`` contains the
reduce_scatter input that resides on the GPU of
``output_tensor_list[i]``.
Note that each element of ``input_tensor_lists`` has the size of
``world_size * len(output_tensor_list)``, since the function
scatters the result from every single GPU in the group. To
interpret each element of ``input_tensor_lists[i]``, note that
``output_tensor_list[j]`` of rank k receives the reduce-scattered
result from ``input_tensor_lists[i][k * world_size + j]``
Also note that ``len(input_tensor_lists)``, and the size of each
element in ``input_tensor_lists`` (each element is a list,
therefore ``len(input_tensor_lists[i])``) need to be the same for
all the distributed processes calling this function.
group (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used.
async_op (bool, optional): Whether this op should be an async op.
Returns:
Async work handle, if async_op is set to True.
None, if not async_op or if not part of the group.
"""
warnings.warn(
"torch.distributed.reduce_scatter_multigpu will be deprecated. If you must "
"use it, please revisit our documentation later at "
"https://pytorch.org/docs/master/distributed.html#multi-gpu-collective-functions"
)
if _rank_not_in_group(group):
_warn_not_in_group("reduce_scatter_multigpu")
return
opts = ReduceScatterOptions()
opts.reduceOp = op
if group is None:
default_pg = _get_default_group()
work = default_pg.reduce_scatter(output_tensor_list, input_tensor_lists, opts)
else:
work = group.reduce_scatter(output_tensor_list, input_tensor_lists, opts)
if async_op:
return work
else:
work.wait()
@_exception_logger
def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=None, async_op=False):
"""
@@ -4299,7 +3966,6 @@ def _get_process_group_store(pg: ProcessGroup) -> Store:
# This ops are not friently to TorchDynamo. So, we decide to disallow these ops
# in FX graph, allowing them to run them on eager, with torch.compile.
dynamo_unsupported_distributed_c10d_ops = [
all_reduce_multigpu,
recv,
all_gather_object,
all_gather_coalesced,
@@ -4311,14 +3977,10 @@ dynamo_unsupported_distributed_c10d_ops = [
gather,
broadcast_object_list,
barrier,
reduce_multigpu,
scatter,
scatter_object_list,
reduce,
reduce_scatter_multigpu,
all_gather,
broadcast_multigpu,
all_gather_multigpu,
reduce_scatter,
all_gather_into_tensor,
broadcast,
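As the comment above notes, ops on this list are kept out of the FX graph and run eagerly under torch.compile. A minimal sketch of that behavior, assuming an initialized process group (the step function here is hypothetical):

    import torch
    import torch.distributed as dist

    @torch.compile
    def step(x):
        y = x * 2              # captured and compiled
        dist.all_reduce(y)     # on the list above: graph break, runs eagerly
        return y + 1           # compilation resumes after the break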

View File

@@ -458,15 +458,6 @@ def init_multigpu_helper(world_size: int, backend: str):
nGPUs = torch.cuda.device_count()
visible_devices = range(nGPUs)
if backend == "nccl":
# This is a hack for a known NCCL issue using multiprocess
# in conjunction with multiple threads to manage different GPUs which
# may cause ncclCommInitRank to fail.
# http://docs.nvidia.com/deeplearning/sdk/nccl-release-notes/rel_2.1.4.html#rel_2.1.4
# It slows down the performance of collective operations.
# Without this setting NCCL might throw unhandled error.
os.environ["NCCL_MAX_NRINGS"] = "1"
# If rank is less than or equal to number of available GPU's
# then each rank can be mapped to corresponding GPU.
nGPUs_per_process = 1

View File

@@ -4162,233 +4162,6 @@ class DistributedTest:
group, group_id, rank = self._init_full_group_test()
self._test_barrier_helper(group, group_id, rank)
def _test_broadcast_multigpu_helper(self, group, group_id, rank, rank_to_GPU):
for src in group:
expected_tensor = _build_tensor(src + 1)
tensors = [
_build_tensor(src + 1, -1).cuda(device=i) for i in rank_to_GPU[rank]
]
if rank == src:
tensors[0] = expected_tensor.cuda(device=rank_to_GPU[rank][0])
dist.broadcast_multigpu(tensors, src, group_id)
for tensor in tensors:
self.assertEqual(tensor, expected_tensor)
self._barrier()
@skip_but_pass_in_sandcastle_if(
BACKEND == "mpi", "MPI doesn't support broadcast multigpu"
)
@skip_but_pass_in_sandcastle_if(
BACKEND == "nccl", "NCCL broadcast multigpu skipped"
)
@skip_if_no_gpu
def test_broadcast_multigpu(self):
group, group_id, rank = self._init_global_test()
rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
self._test_broadcast_multigpu_helper(group, group_id, rank, rank_to_GPU)
def _test_all_reduce_multigpu_helper(
self,
group,
group_id,
rank,
rank_to_GPU,
op,
master_value,
worker_value,
expected_value,
dtype=torch.float,
):
for src in group:
curr_value = master_value if rank == src else worker_value
tensors = [
_build_tensor(src + 1, curr_value, dtype=dtype).cuda(device=i)
for i in rank_to_GPU[rank]
]
self.call_dist_op(
":all_reduce",
False,
dist.all_reduce_multigpu,
tensors,
op,
group_id,
)
expected_tensor = _build_tensor(src + 1, expected_value, dtype=dtype)
for tensor in tensors:
self.assertEqual(tensor, expected_tensor)
self._barrier()
@skip_but_pass_in_sandcastle_if(
BACKEND == "mpi", "MPI doesn't support broadcast multigpu"
)
@skip_but_pass_in_sandcastle_if(
BACKEND == "nccl", "CUDA all_reduce multigpu skipped for NCCL"
)
@skip_but_pass_in_sandcastle_if(
BACKEND == "ucc" and IS_SANDCASTLE, "Skipped internally"
)
@skip_if_no_gpu
def test_all_reduce_multigpu(self):
group, group_id, rank = self._init_global_test()
rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
self._test_all_reduce_multigpu_helper(
group,
group_id,
rank,
rank_to_GPU,
dist.ReduceOp.SUM,
2,
10,
(2 + 10 * (len(group) - 1)) * len(rank_to_GPU[0]),
)
@skip_but_pass_in_sandcastle_if(
BACKEND == "mpi", "MPI doesn't support broadcast multigpu"
)
@skip_but_pass_in_sandcastle_if(
BACKEND == "nccl", "CUDA all_reduce multigpu skipped for NCCL"
)
@skip_but_pass_in_sandcastle_if(
BACKEND == "ucc" and IS_SANDCASTLE, "Skipped internally"
)
@skip_if_no_gpu
def test_all_reduce_multigpu_complex(self):
group, group_id, rank = self._init_global_test()
rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
self._test_all_reduce_multigpu_helper(
group,
group_id,
rank,
rank_to_GPU,
dist.ReduceOp.SUM,
complex(2, 3),
complex(10, 11),
(complex(2, 3) + complex(10, 11) * (len(group) - 1))
* len(rank_to_GPU[0]),
dtype=torch.cfloat,
)
def _test_reduce_multigpu_helper(
self,
group,
group_id,
rank,
rank_to_GPU,
op,
master_value,
worker_value,
expected_value,
):
for src in group:
tensor_value = master_value if rank == src else worker_value
tensors = [
_build_tensor(src + 1, tensor_value).cuda(device=i)
for i in rank_to_GPU[rank]
]
self.call_dist_op(
":reduce",
False,
dist.reduce_multigpu,
tensors,
src,
op,
group_id,
expect_event=len(tensors) == 1,
tensor_shapes=[tensors[0].shape],
)
if rank == src:
expected_tensor = _build_tensor(src + 1, expected_value)
self.assertEqual(tensors[0], expected_tensor)
self._barrier()
@skip_but_pass_in_sandcastle_if(
BACKEND != "nccl", "Only Nccl backend supports reduce multigpu"
)
@skip_if_no_gpu
def test_reduce_multigpu(self):
group, group_id, rank = self._init_global_test()
rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
device_id = rank_to_GPU[rank][0]
torch.cuda.set_device(device_id)
self._test_reduce_multigpu_helper(
group,
group_id,
rank,
rank_to_GPU,
dist.ReduceOp.SUM,
2,
10,
(2 + 10 * (len(group) - 1)) * len(rank_to_GPU[0]),
)
def _test_all_gather_multigpu_helper(
self, group, group_id, rank, rank_to_GPU, dtype=torch.float
):
for dest in group:
tensors = [
_build_tensor(dest + 1, dtype=dtype).cuda(device=i)
for i in rank_to_GPU[rank]
]
# construct expected output along with
# a place holder to receive all gather results
output_tensors = []
expected_output = []
output_per_gpu = (
[_build_tensor(dest + 1, -1, dtype=dtype)]
* len(rank_to_GPU[0])
* len(group)
)
expected_per_gpu = (
[_build_tensor(dest + 1, dtype=dtype)]
* len(rank_to_GPU[0])
* len(group)
)
for gpu in rank_to_GPU[rank]:
output_tensors.append([t.cuda(device=gpu) for t in output_per_gpu])
expected_output.append(
[t.cuda(device=gpu) for t in expected_per_gpu]
)
self.call_dist_op(
":all_gather",
False,
dist.all_gather_multigpu,
output_tensors,
tensors,
group_id,
expect_event=len(expected_output) == 1,
)
self.assertEqual(output_tensors, expected_output)
self._barrier()
@skip_but_pass_in_sandcastle_if(
BACKEND != "nccl", "Only Nccl backend supports allgather multigpu"
)
@skip_if_no_gpu
def test_all_gather_multigpu(self):
group, group_id, rank = self._init_global_test()
rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
device_id = rank_to_GPU[rank][0]
torch.cuda.set_device(device_id)
self._test_all_gather_multigpu_helper(group, group_id, rank, rank_to_GPU)
@skip_but_pass_in_sandcastle_if(
BACKEND != "nccl", "Only Nccl backend supports allgather multigpu"
)
@skip_if_no_gpu
def test_all_gather_multigpu_complex(self):
group, group_id, rank = self._init_global_test()
rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
device_id = rank_to_GPU[rank][0]
torch.cuda.set_device(device_id)
self._test_all_gather_multigpu_helper(
group, group_id, rank, rank_to_GPU, dtype=torch.cfloat
)
def _model_step(self, model):
for param in model.parameters():
if param.grad is not None:
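For comparison with the deleted tests above, a minimal sketch of the same broadcast check written in the one-device-per-process style; it reuses the _build_tensor, rank_to_GPU, and _barrier helpers shown in the removed tests, and the method name is hypothetical:

    def _test_broadcast_single_device_helper(self, group, group_id, rank, rank_to_GPU):
        device = rank_to_GPU[rank][0]          # exactly one GPU per process
        torch.cuda.set_device(device)
        for src in group:
            expected = _build_tensor(src + 1).cuda(device=device)
            tensor = (
                expected.clone()
                if rank == src
                else _build_tensor(src + 1, -1).cuda(device=device)
            )
            dist.broadcast(tensor, src, group_id)
            self.assertEqual(tensor, expected)
        self._barrier()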