Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 13:44:15 +08:00)
[c10d] Remove deprecated multi-gpu-per-thread APIs (#114156)
As of today, PyTorch Distributed's preferred programming model is one device per thread, as exemplified by the APIs in its documentation. The multi-GPU functions (which operate on multiple GPUs per CPU thread) have been deprecated for three releases; this removes them ahead of the 2.2 release.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114156
Approved by: https://github.com/albanD, https://github.com/fduwjj, https://github.com/H-Huang
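For context (not part of this commit), a minimal sketch of the one-device-per-process pattern that the remaining APIs assume; the ``LOCAL_RANK`` variable and launcher behavior are assumptions (e.g. torchrun sets them):

::

    import os
    import torch
    import torch.distributed as dist

    # One process per GPU: each rank owns a single device and passes a single
    # tensor to each collective, instead of a list of per-GPU tensors.
    dist.init_process_group(backend="nccl")      # rank/world size read from the env (assumed launcher)
    local_rank = int(os.environ["LOCAL_RANK"])   # assumed to be set by the launcher
    torch.cuda.set_device(local_rank)

    t = torch.ones(1, device=f"cuda:{local_rank}")
    dist.all_reduce(t, op=dist.ReduceOp.SUM)     # replaces all_reduce_multigpu([t0, t1, ...])

    dist.destroy_process_group()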
@@ -505,17 +505,14 @@ coverage_ignore_functions = [
     "all_gather",
     "all_gather_coalesced",
     "all_gather_into_tensor",
-    "all_gather_multigpu",
     "all_gather_object",
     "all_reduce",
     "all_reduce_coalesced",
-    "all_reduce_multigpu",
     "all_to_all",
     "all_to_all_single",
     "barrier",
     "batch_isend_irecv",
     "broadcast",
-    "broadcast_multigpu",
     "broadcast_object_list",
     "destroy_process_group",
     "gather",
@@ -543,9 +540,7 @@ coverage_ignore_functions = [
     "new_subgroups_by_enumeration",
     "recv",
     "reduce",
-    "reduce_multigpu",
     "reduce_scatter",
-    "reduce_scatter_multigpu",
     "reduce_scatter_tensor",
     "scatter",
     "scatter_object_list",
@@ -483,72 +483,11 @@ Multi-GPU collective functions
 ------------------------------

 .. warning::
-    The multi-GPU functions will be deprecated. If you must use them, please revisit our documentation later.
-
-If you have more than one GPU on each node, when using the NCCL and Gloo backend,
-:func:`~torch.distributed.broadcast_multigpu`
-:func:`~torch.distributed.all_reduce_multigpu`
-:func:`~torch.distributed.reduce_multigpu`
-:func:`~torch.distributed.all_gather_multigpu` and
-:func:`~torch.distributed.reduce_scatter_multigpu` support distributed collective
-operations among multiple GPUs within each node. These functions can potentially
-improve the overall distributed training performance and be easily used by
-passing a list of tensors. Each Tensor in the passed tensor list needs
-to be on a separate GPU device of the host where the function is called. Note
-that the length of the tensor list needs to be identical among all the
-distributed processes. Also note that currently the multi-GPU collective
-functions are only supported by the NCCL backend.
-
-For example, if the system we use for distributed training has 2 nodes, each
-of which has 8 GPUs. On each of the 16 GPUs, there is a tensor that we would
-like to all-reduce. The following code can serve as a reference:
-
-Code running on Node 0
-
-::
-
-    import torch
-    import torch.distributed as dist
-
-    dist.init_process_group(backend="nccl",
-                            init_method="file:///distributed_test",
-                            world_size=2,
-                            rank=0)
-    tensor_list = []
-    for dev_idx in range(torch.cuda.device_count()):
-        tensor_list.append(torch.FloatTensor([1]).cuda(dev_idx))
-
-    dist.all_reduce_multigpu(tensor_list)
-
-Code running on Node 1
-
-::
-
-    import torch
-    import torch.distributed as dist
-
-    dist.init_process_group(backend="nccl",
-                            init_method="file:///distributed_test",
-                            world_size=2,
-                            rank=1)
-    tensor_list = []
-    for dev_idx in range(torch.cuda.device_count()):
-        tensor_list.append(torch.FloatTensor([1]).cuda(dev_idx))
-
-    dist.all_reduce_multigpu(tensor_list)
-
-After the call, all 16 tensors on the two nodes will have the all-reduced value
-of 16
-
-.. autofunction:: broadcast_multigpu
-
-.. autofunction:: all_reduce_multigpu
-
-.. autofunction:: reduce_multigpu
-
-.. autofunction:: all_gather_multigpu
-
-.. autofunction:: reduce_scatter_multigpu
+    The multi-GPU functions (which stand for multiple GPUs per CPU thread) are
+    deprecated. As of today, PyTorch Distributed's preferred programming model
+    is one device per thread, as exemplified by the APIs in this document. If
+    you are a backend developer and want to support multiple devices per thread,
+    please contact PyTorch Distributed's maintainers.


 .. _distributed-launch:
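The removed two-node example above maps directly onto the one-device-per-process model: instead of 2 processes each driving 8 GPUs, launch 16 processes with one GPU each and call the regular ``all_reduce``. A hedged sketch (the torchrun invocation and ``LOCAL_RANK`` handling are assumptions, not part of the diff):

::

    # One possible launch (assumed), run on each node:
    #   torchrun --nnodes=2 --nproc_per_node=8 \
    #       --rdzv_backend=c10d --rdzv_endpoint=<node0-host>:29500 script.py
    import os
    import torch
    import torch.distributed as dist

    dist.init_process_group(backend="nccl")      # world_size=16, one rank per GPU
    local_rank = int(os.environ["LOCAL_RANK"])   # 0..7 on each node (set by the launcher)
    torch.cuda.set_device(local_rank)

    tensor = torch.ones(1, device=f"cuda:{local_rank}")
    dist.all_reduce(tensor)                      # every rank now holds the value 16.0

    dist.destroy_process_group()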
@@ -2311,18 +2311,13 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
   // Bump collective counter
   seq_++;

-  // Currently, the API permits two scenarios where inputs.size() and
+  // Currently, the API permits one scenario where inputs.size() and
   // outputs.size() are > 0.
   // 1. If the call was a _coalesced call, all inputs must be on the same
   // device.
   // The group of nccl calls applies the collective separately to each input,
   // but the group as a whole should be efficient, and might even execute as
   // a single fused kernel.
-  // 2. If the call was a _multigpu call, all inputs must be on different
-  // devices.
-  // The nccl group applies the collective across them (eg, if the collective
-  // is an allreduce, the output on each device contains contributions summed
-  // across `inputs' tensors).
   const auto devices = getDeviceList(inputs);
   const bool inputs_same_dev = (devices.size() == 1);
   const auto key = getKeyFromDevices(devices);
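The coalesced path kept in the comment above remains available from Python; a hedged sketch (assuming a process group is already initialized and this rank has set its CUDA device) of issuing several all-reduces as one coalesced group:

::

    import torch
    import torch.distributed as dist

    # Assumes dist.init_process_group("nccl") has already run on this rank.
    tensors = [torch.ones(4, device="cuda") * (i + 1) for i in range(3)]

    # All inputs live on the same device; the backend may batch the three
    # all-reduces into a single NCCL group call, as the comment describes.
    dist.all_reduce_coalesced(tensors, op=dist.ReduceOp.SUM)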
@@ -705,7 +705,7 @@ Additionally, ``MAX``, ``MIN`` and ``PRODUCT`` are not supported for complex tensors.

 The values of this class can be accessed as attributes, e.g., ``ReduceOp.SUM``.
 They are used in specifying strategies for reduction collectives, e.g.,
-:func:`reduce`, :func:`all_reduce_multigpu`, etc.
+:func:`reduce`.

 This class does not support ``__members__`` property.)");
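To illustrate the docstring being edited here, a small hedged sketch of selecting a reduction strategy through ``ReduceOp`` (an initialized default process group is assumed):

::

    import torch
    import torch.distributed as dist

    # Assumes dist.init_process_group(...) has already been called.
    t = torch.tensor([float(dist.get_rank() + 1)])

    # ReduceOp values are accessed as attributes and passed to reduction
    # collectives; here every rank's value is multiplied together onto rank 0.
    dist.reduce(t, dst=0, op=dist.ReduceOp.PRODUCT)
    if dist.get_rank() == 0:
        print(t)  # product of (rank + 1) over all ranks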
@@ -42,17 +42,17 @@ DistStoreError = torch._C._DistStoreError

 __all__ = [
     'Backend', 'BackendConfig', 'GroupMember', 'P2POp', 'all_gather', 'all_gather_coalesced',
-    'all_gather_multigpu', 'all_gather_object', 'all_reduce',
-    'all_reduce_coalesced', 'all_reduce_multigpu', 'all_to_all',
+    'all_gather_object', 'all_reduce',
+    'all_reduce_coalesced', 'all_to_all',
     'all_to_all_single', 'barrier', 'batch_isend_irecv', 'broadcast',
-    'broadcast_multigpu', 'broadcast_object_list', 'destroy_process_group',
+    'broadcast_object_list', 'destroy_process_group',
     'gather', 'gather_object', 'get_backend_config', 'get_backend', 'get_rank',
     'get_world_size', 'group', 'init_process_group', 'irecv',
     'is_gloo_available', 'is_initialized', 'is_mpi_available', 'is_backend_available',
     'is_nccl_available', 'is_torchelastic_launched', 'is_ucc_available',
     'isend', 'monitored_barrier', 'new_group', 'new_subgroups',
-    'new_subgroups_by_enumeration', 'recv', 'reduce', 'reduce_multigpu',
-    'reduce_scatter', 'reduce_scatter_multigpu', 'scatter',
+    'new_subgroups_by_enumeration', 'recv', 'reduce',
+    'reduce_scatter', 'scatter',
     'scatter_object_list', 'send', 'supports_complex',
     'AllreduceCoalescedOptions', 'AllreduceOptions', 'AllToAllOptions',
     'BarrierOptions', 'BroadcastOptions', 'GatherOptions', 'PrefixStore',
@@ -1851,66 +1851,6 @@ def batch_isend_irecv(p2p_op_list):
     return reqs


-@_exception_logger
-def broadcast_multigpu(tensor_list, src, group=None, async_op=False, src_tensor=0):
-    """
-    Broadcasts the tensor to the whole group with multiple GPU tensors per node.
-
-    ``tensor`` must have the same number of elements in all the GPUs from
-    all processes participating in the collective. each tensor in the list must
-    be on a different GPU
-
-    Only nccl and gloo backend are currently supported
-    tensors should only be GPU tensors
-
-    Args:
-        tensor_list (List[Tensor]): Tensors that participate in the collective
-            operation. If ``src`` is the rank, then the specified ``src_tensor``
-            element of ``tensor_list`` (``tensor_list[src_tensor]``) will be
-            broadcast to all other tensors (on different GPUs) in the src process
-            and all tensors in ``tensor_list`` of other non-src processes.
-            You also need to make sure that ``len(tensor_list)`` is the same
-            for all the distributed processes calling this function.
-
-        src (int): Source rank.
-        group (ProcessGroup, optional): The process group to work on. If None,
-            the default process group will be used.
-        async_op (bool, optional): Whether this op should be an async op
-        src_tensor (int, optional): Source tensor rank within ``tensor_list``
-
-    Returns:
-        Async work handle, if async_op is set to True.
-        None, if not async_op or if not part of the group
-
-    """
-    warnings.warn(
-        "torch.distributed.broadcast_multigpu will be deprecated. If you must "
-        "use it, please revisit our documentation later at "
-        "https://pytorch.org/docs/master/distributed.html#multi-gpu-collective-functions"
-    )
-
-    if _rank_not_in_group(group):
-        _warn_not_in_group("broadcast_multigpu")
-        return
-
-    opts = BroadcastOptions()
-    opts.rootRank = src
-    opts.rootTensor = src_tensor
-    opts.asyncOp = async_op
-
-    if group is None or group is GroupMember.WORLD:
-        default_pg = _get_default_group()
-        work = default_pg.broadcast(tensor_list, opts)
-    else:
-        group_src_rank = get_group_rank(group, src)
-        opts.rootRank = group_src_rank
-        work = group.broadcast(tensor_list, opts)
-    if async_op:
-        return work
-    else:
-        work.wait()
-
-
 @_exception_logger
 def broadcast(tensor, src, group=None, async_op=False):
     """
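For reference, a hedged sketch of the per-process replacement for the removed ``broadcast_multigpu`` (one rank per GPU and an initialized NCCL group are assumed):

::

    import torch
    import torch.distributed as dist

    # Assumes dist.init_process_group("nccl") has run and torch.cuda.set_device()
    # was called with this rank's GPU.
    tensor = torch.empty(1024, device="cuda")
    if dist.get_rank() == 0:
        tensor.fill_(42.0)  # only the source rank holds the payload

    # A single tensor per rank replaces the old tensor_list/src_tensor pair.
    dist.broadcast(tensor, src=0)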
@@ -1954,68 +1894,6 @@ def broadcast(tensor, src, group=None, async_op=False):
     else:
         work.wait()

-@_exception_logger
-def all_reduce_multigpu(tensor_list, op=ReduceOp.SUM, group=None, async_op=False):
-    r"""
-    Reduces the tensor data across all machines in a way that all get the final result.
-
-    This function reduces a number of tensors on every node,
-    while each tensor resides on different GPUs.
-    Therefore, the input tensor in the tensor list needs to be GPU tensors.
-    Also, each tensor in the tensor list needs to reside on a different GPU.
-
-    After the call, all ``tensor`` in ``tensor_list`` is going to be bitwise
-    identical in all processes.
-
-    Complex tensors are supported.
-
-    Only nccl and gloo backend is currently supported
-    tensors should only be GPU tensors
-
-    Args:
-        tensor_list (List[Tensor]): List of input and output tensors of
-            the collective. The function operates in-place and requires that
-            each tensor to be a GPU tensor on different GPUs.
-            You also need to make sure that ``len(tensor_list)`` is the same for
-            all the distributed processes calling this function.
-        op (optional): One of the values from
-            ``torch.distributed.ReduceOp``
-            enum. Specifies an operation used for element-wise reductions.
-        group (ProcessGroup, optional): The process group to work on. If
-            ``None``, the default process group will be used.
-        async_op (bool, optional): Whether this op should be an async op
-
-    Returns:
-        Async work handle, if async_op is set to True.
-        None, if not async_op or if not part of the group
-
-    """
-    warnings.warn(
-        "torch.distributed.all_reduce_multigpu will be deprecated. If you must "
-        "use it, please revisit our documentation later at "
-        "https://pytorch.org/docs/master/distributed.html#multi-gpu-collective-functions"
-    )
-
-    if _rank_not_in_group(group):
-        return
-
-    tensor_list = [
-        t if not t.is_complex() else torch.view_as_real(t) for t in tensor_list
-    ]
-
-    opts = AllreduceOptions()
-    opts.reduceOp = op
-    if group is None:
-        default_pg = _get_default_group()
-        work = default_pg.allreduce(tensor_list, opts)
-    else:
-        work = group.allreduce(tensor_list, opts)
-
-    if async_op:
-        return work
-    else:
-        work.wait()
-
 @_exception_logger
 def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False):
     """
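Likewise, the single-tensor ``all_reduce`` covers what the removed wrapper did, including the complex-tensor handling it performed through ``view_as_real``; a hedged sketch assuming an initialized group and one GPU per rank:

::

    import torch
    import torch.distributed as dist

    # Assumes an initialized process group and torch.cuda.set_device(local_rank).
    real = torch.full((4,), float(dist.get_rank()), device="cuda")
    dist.all_reduce(real, op=dist.ReduceOp.SUM)   # every rank ends up with the same sum

    # Complex tensors are supported directly; no manual view_as_real is needed.
    cplx = torch.full((4,), complex(1, 1), dtype=torch.cfloat, device="cuda")
    dist.all_reduce(cplx, op=dist.ReduceOp.SUM)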
@@ -2159,69 +2037,6 @@ def all_reduce_coalesced(tensors, op=ReduceOp.SUM, group=None, async_op=False):
     else:
         work.wait()

-@_exception_logger
-def reduce_multigpu(
-    tensor_list, dst, op=ReduceOp.SUM, group=None, async_op=False, dst_tensor=0
-):
-    """
-    Reduces the tensor data on multiple GPUs across all machines.
-
-    Each tensor in ``tensor_list`` should reside on a separate GPU.
-
-    Only the GPU of ``tensor_list[dst_tensor]`` on the process with rank ``dst``
-    is going to receive the final result.
-
-    Only nccl backend is currently supported
-    tensors should only be GPU tensors
-
-    Args:
-        tensor_list (List[Tensor]): Input and output GPU tensors of the
-            collective. The function operates in-place.
-            You also need to make sure that ``len(tensor_list)`` is the same for
-            all the distributed processes calling this function.
-        dst (int): Destination rank
-        op (optional): One of the values from
-            ``torch.distributed.ReduceOp``
-            enum. Specifies an operation used for element-wise reductions.
-        group (ProcessGroup, optional): The process group to work on. If None,
-            the default process group will be used.
-        async_op (bool, optional): Whether this op should be an async op
-        dst_tensor (int, optional): Destination tensor rank within
-            ``tensor_list``
-
-    Returns:
-        Async work handle, if async_op is set to True.
-        None, otherwise
-
-    """
-    warnings.warn(
-        "torch.distributed.reduce_multigpu will be deprecated. If you must "
-        "use it, please revisit our documentation later at "
-        "https://pytorch.org/docs/master/distributed.html#multi-gpu-collective-functions"
-    )
-
-    if _rank_not_in_group(group):
-        _warn_not_in_group("reduce_multigpu")
-        return
-
-    opts = ReduceOptions()
-    opts.reduceOp = op
-    opts.rootRank = dst
-    opts.rootTensor = dst_tensor
-
-    if group is None or group is GroupMember.WORLD:
-        default_pg = _get_default_group()
-        work = default_pg.reduce(tensor_list, opts)
-    else:
-        group_dst_rank = get_group_rank(group, dst)
-        opts.rootRank = group_dst_rank
-        work = group.reduce(tensor_list, opts)
-
-    if async_op:
-        return work
-    else:
-        work.wait()
-
 @_exception_logger
 def reduce(tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
     """
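A hedged sketch of the per-process ``reduce`` that replaces ``reduce_multigpu``; only the destination rank ends up with the reduced value (an initialized group and one GPU per rank are assumed):

::

    import torch
    import torch.distributed as dist

    # Assumes an initialized process group and one CUDA device per rank.
    t = torch.ones(8, device="cuda")

    # Only the destination rank (dst=0) receives the reduced result; the
    # contents on other ranks are unspecified after the call.
    dist.reduce(t, dst=0, op=dist.ReduceOp.SUM)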
@@ -2267,83 +2082,6 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
     else:
         work.wait()

-@_exception_logger
-def all_gather_multigpu(
-    output_tensor_lists, input_tensor_list, group=None, async_op=False
-):
-    """
-    Gathers tensors from the whole group in a list.
-
-    Each tensor in ``tensor_list`` should reside on a separate GPU
-
-    Only nccl backend is currently supported
-    tensors should only be GPU tensors
-
-    Complex tensors are supported.
-
-    Args:
-        output_tensor_lists (List[List[Tensor]]): Output lists. It should
-            contain correctly-sized tensors on each GPU to be used for output
-            of the collective, e.g. ``output_tensor_lists[i]`` contains the
-            all_gather result that resides on the GPU of
-            ``input_tensor_list[i]``.
-
-            Note that each element of ``output_tensor_lists`` has the size of
-            ``world_size * len(input_tensor_list)``, since the function all
-            gathers the result from every single GPU in the group. To interpret
-            each element of ``output_tensor_lists[i]``, note that
-            ``input_tensor_list[j]`` of rank k will be appear in
-            ``output_tensor_lists[i][k * world_size + j]``
-
-            Also note that ``len(output_tensor_lists)``, and the size of each
-            element in ``output_tensor_lists`` (each element is a list,
-            therefore ``len(output_tensor_lists[i])``) need to be the same
-            for all the distributed processes calling this function.
-
-        input_tensor_list (List[Tensor]): List of tensors(on different GPUs) to
-            be broadcast from current process.
-            Note that ``len(input_tensor_list)`` needs to be the same for
-            all the distributed processes calling this function.
-
-        group (ProcessGroup, optional): The process group to work on. If None,
-            the default process group will be used.
-        async_op (bool, optional): Whether this op should be an async op
-
-    Returns:
-        Async work handle, if async_op is set to True.
-        None, if not async_op or if not part of the group
-
-    """
-    warnings.warn(
-        "torch.distributed.all_gather_multigpu will be deprecated. If you must "
-        "use it, please revisit our documentation later at "
-        "https://pytorch.org/docs/master/distributed.html#multi-gpu-collective-functions"
-    )
-
-    if _rank_not_in_group(group):
-        _warn_not_in_group("all_gather_multigpu")
-        return
-
-    output_tensor_lists = [
-        [t if not t.is_complex() else torch.view_as_real(t) for t in l]
-        for l in output_tensor_lists
-    ]
-    input_tensor_list = [
-        t if not t.is_complex() else torch.view_as_real(t) for t in input_tensor_list
-    ]
-
-    if group is None:
-        default_pg = _get_default_group()
-        work = default_pg.allgather(output_tensor_lists, input_tensor_list)
-    else:
-        work = group.allgather(output_tensor_lists, input_tensor_list)
-
-    if async_op:
-        return work
-    else:
-        work.wait()
-
-
 def _object_to_tensor(obj, device):
     f = io.BytesIO()
     _pickler(f).dump(obj)
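The indexing rule in the removed docstring collapses to a much simpler one with the surviving ``all_gather``, where ``output[k]`` holds rank k's input; a hedged sketch assuming an initialized group and one GPU per rank:

::

    import torch
    import torch.distributed as dist

    # Assumes an initialized process group and one CUDA device per rank.
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    inp = torch.full((2,), float(rank), device="cuda")
    out = [torch.empty(2, device="cuda") for _ in range(world_size)]

    dist.all_gather(out, inp)
    # out[k] now equals rank k's input tensor.
    assert torch.equal(out[rank], inp)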
@@ -3235,77 +2973,6 @@ def scatter(tensor, scatter_list=None, src=0, group=None, async_op=False):
         work.wait()


-@_exception_logger
-def reduce_scatter_multigpu(
-    output_tensor_list, input_tensor_lists, op=ReduceOp.SUM, group=None, async_op=False
-):
-    """
-    Reduce and scatter a list of tensors to the whole group.
-
-    Only nccl backend is currently supported.
-
-    Each tensor in ``output_tensor_list`` should reside on a separate GPU, as
-    should each list of tensors in ``input_tensor_lists``.
-
-    Args:
-        output_tensor_list (List[Tensor]): Output tensors (on different GPUs)
-            to receive the result of the operation.
-
-            Note that ``len(output_tensor_list)`` needs to be the same for all
-            the distributed processes calling this function.
-
-        input_tensor_lists (List[List[Tensor]]): Input lists. It should
-            contain correctly-sized tensors on each GPU to be used for input of
-            the collective, e.g. ``input_tensor_lists[i]`` contains the
-            reduce_scatter input that resides on the GPU of
-            ``output_tensor_list[i]``.
-
-            Note that each element of ``input_tensor_lists`` has the size of
-            ``world_size * len(output_tensor_list)``, since the function
-            scatters the result from every single GPU in the group. To
-            interpret each element of ``input_tensor_lists[i]``, note that
-            ``output_tensor_list[j]`` of rank k receives the reduce-scattered
-            result from ``input_tensor_lists[i][k * world_size + j]``
-
-            Also note that ``len(input_tensor_lists)``, and the size of each
-            element in ``input_tensor_lists`` (each element is a list,
-            therefore ``len(input_tensor_lists[i])``) need to be the same for
-            all the distributed processes calling this function.
-
-        group (ProcessGroup, optional): The process group to work on. If None,
-            the default process group will be used.
-        async_op (bool, optional): Whether this op should be an async op.
-
-    Returns:
-        Async work handle, if async_op is set to True.
-        None, if not async_op or if not part of the group.
-
-    """
-    warnings.warn(
-        "torch.distributed.reduce_scatter_multigpu will be deprecated. If you must "
-        "use it, please revisit our documentation later at "
-        "https://pytorch.org/docs/master/distributed.html#multi-gpu-collective-functions"
-    )
-
-    if _rank_not_in_group(group):
-        _warn_not_in_group("reduce_scatter_multigpu")
-        return
-
-    opts = ReduceScatterOptions()
-    opts.reduceOp = op
-
-    if group is None:
-        default_pg = _get_default_group()
-        work = default_pg.reduce_scatter(output_tensor_list, input_tensor_lists, opts)
-    else:
-        work = group.reduce_scatter(output_tensor_list, input_tensor_lists, opts)
-
-    if async_op:
-        return work
-    else:
-        work.wait()
-
-
 @_exception_logger
 def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=None, async_op=False):
     """
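A hedged sketch of the per-process ``reduce_scatter`` that remains (initialized group and one GPU per rank assumed); each rank contributes ``world_size`` chunks and receives the reduction of its own chunk index:

::

    import torch
    import torch.distributed as dist

    # Assumes an initialized process group and one CUDA device per rank.
    world_size = dist.get_world_size()

    # Chunk k is reduced across all ranks and delivered to rank k's `output`.
    inputs = [torch.full((4,), float(k), device="cuda") for k in range(world_size)]
    output = torch.empty(4, device="cuda")

    dist.reduce_scatter(output, inputs, op=dist.ReduceOp.SUM)
    # With identical inputs on every rank, output == world_size * rank here.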
@@ -4299,7 +3966,6 @@ def _get_process_group_store(pg: ProcessGroup) -> Store:
 # This ops are not friently to TorchDynamo. So, we decide to disallow these ops
 # in FX graph, allowing them to run them on eager, with torch.compile.
 dynamo_unsupported_distributed_c10d_ops = [
-    all_reduce_multigpu,
     recv,
     all_gather_object,
     all_gather_coalesced,
@@ -4311,14 +3977,10 @@ dynamo_unsupported_distributed_c10d_ops = [
     gather,
     broadcast_object_list,
     barrier,
-    reduce_multigpu,
     scatter,
     scatter_object_list,
     reduce,
-    reduce_scatter_multigpu,
     all_gather,
-    broadcast_multigpu,
-    all_gather_multigpu,
     reduce_scatter,
     all_gather_into_tensor,
     broadcast,
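The comment above says these collectives are kept out of FX graphs and run eagerly under ``torch.compile``; a hedged sketch of what that looks like from user code (an initialized process group is assumed, and the graph-break behavior is inferred from that comment rather than demonstrated by this diff):

::

    import torch
    import torch.distributed as dist

    # Assumes an initialized process group. `dist.broadcast` is listed in
    # dynamo_unsupported_distributed_c10d_ops, so it is kept out of the traced
    # FX graph and executed eagerly when the surrounding function is compiled.
    @torch.compile
    def step(x):
        y = x * 2                 # traced and compiled
        dist.broadcast(y, src=0)  # runs eagerly, outside the captured graph
        return y + 1

    out = step(torch.ones(4, device="cuda"))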
@@ -458,15 +458,6 @@ def init_multigpu_helper(world_size: int, backend: str):
     nGPUs = torch.cuda.device_count()
     visible_devices = range(nGPUs)

-    if backend == "nccl":
-        # This is a hack for a known NCCL issue using multiprocess
-        # in conjunction with multiple threads to manage different GPUs which
-        # may cause ncclCommInitRank to fail.
-        # http://docs.nvidia.com/deeplearning/sdk/nccl-release-notes/rel_2.1.4.html#rel_2.1.4
-        # It slows down the performance of collective operations.
-        # Without this setting NCCL might throw unhandled error.
-        os.environ["NCCL_MAX_NRINGS"] = "1"
-
     # If rank is less than or equal to number of available GPU's
     # then each rank can be mapped to corresponding GPU.
     nGPUs_per_process = 1
@@ -4162,233 +4162,6 @@ class DistributedTest:
             group, group_id, rank = self._init_full_group_test()
             self._test_barrier_helper(group, group_id, rank)

-        def _test_broadcast_multigpu_helper(self, group, group_id, rank, rank_to_GPU):
-            for src in group:
-                expected_tensor = _build_tensor(src + 1)
-                tensors = [
-                    _build_tensor(src + 1, -1).cuda(device=i) for i in rank_to_GPU[rank]
-                ]
-                if rank == src:
-                    tensors[0] = expected_tensor.cuda(device=rank_to_GPU[rank][0])
-
-                dist.broadcast_multigpu(tensors, src, group_id)
-                for tensor in tensors:
-                    self.assertEqual(tensor, expected_tensor)
-            self._barrier()
-
-        @skip_but_pass_in_sandcastle_if(
-            BACKEND == "mpi", "MPI doesn't support broadcast multigpu"
-        )
-        @skip_but_pass_in_sandcastle_if(
-            BACKEND == "nccl", "NCCL broadcast multigpu skipped"
-        )
-        @skip_if_no_gpu
-        def test_broadcast_multigpu(self):
-            group, group_id, rank = self._init_global_test()
-            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
-            self._test_broadcast_multigpu_helper(group, group_id, rank, rank_to_GPU)
-
-        def _test_all_reduce_multigpu_helper(
-            self,
-            group,
-            group_id,
-            rank,
-            rank_to_GPU,
-            op,
-            master_value,
-            worker_value,
-            expected_value,
-            dtype=torch.float,
-        ):
-            for src in group:
-                curr_value = master_value if rank == src else worker_value
-                tensors = [
-                    _build_tensor(src + 1, curr_value, dtype=dtype).cuda(device=i)
-                    for i in rank_to_GPU[rank]
-                ]
-                self.call_dist_op(
-                    ":all_reduce",
-                    False,
-                    dist.all_reduce_multigpu,
-                    tensors,
-                    op,
-                    group_id,
-                )
-                expected_tensor = _build_tensor(src + 1, expected_value, dtype=dtype)
-                for tensor in tensors:
-                    self.assertEqual(tensor, expected_tensor)
-
-            self._barrier()
-
-        @skip_but_pass_in_sandcastle_if(
-            BACKEND == "mpi", "MPI doesn't support broadcast multigpu"
-        )
-        @skip_but_pass_in_sandcastle_if(
-            BACKEND == "nccl", "CUDA all_reduce multigpu skipped for NCCL"
-        )
-        @skip_but_pass_in_sandcastle_if(
-            BACKEND == "ucc" and IS_SANDCASTLE, "Skipped internally"
-        )
-        @skip_if_no_gpu
-        def test_all_reduce_multigpu(self):
-            group, group_id, rank = self._init_global_test()
-            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
-            self._test_all_reduce_multigpu_helper(
-                group,
-                group_id,
-                rank,
-                rank_to_GPU,
-                dist.ReduceOp.SUM,
-                2,
-                10,
-                (2 + 10 * (len(group) - 1)) * len(rank_to_GPU[0]),
-            )
-
-        @skip_but_pass_in_sandcastle_if(
-            BACKEND == "mpi", "MPI doesn't support broadcast multigpu"
-        )
-        @skip_but_pass_in_sandcastle_if(
-            BACKEND == "nccl", "CUDA all_reduce multigpu skipped for NCCL"
-        )
-        @skip_but_pass_in_sandcastle_if(
-            BACKEND == "ucc" and IS_SANDCASTLE, "Skipped internally"
-        )
-        @skip_if_no_gpu
-        def test_all_reduce_multigpu_complex(self):
-            group, group_id, rank = self._init_global_test()
-            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
-            self._test_all_reduce_multigpu_helper(
-                group,
-                group_id,
-                rank,
-                rank_to_GPU,
-                dist.ReduceOp.SUM,
-                complex(2, 3),
-                complex(10, 11),
-                (complex(2, 3) + complex(10, 11) * (len(group) - 1))
-                * len(rank_to_GPU[0]),
-                dtype=torch.cfloat,
-            )
-
-        def _test_reduce_multigpu_helper(
-            self,
-            group,
-            group_id,
-            rank,
-            rank_to_GPU,
-            op,
-            master_value,
-            worker_value,
-            expected_value,
-        ):
-            for src in group:
-                tensor_value = master_value if rank == src else worker_value
-                tensors = [
-                    _build_tensor(src + 1, tensor_value).cuda(device=i)
-                    for i in rank_to_GPU[rank]
-                ]
-                self.call_dist_op(
-                    ":reduce",
-                    False,
-                    dist.reduce_multigpu,
-                    tensors,
-                    src,
-                    op,
-                    group_id,
-                    expect_event=len(tensors) == 1,
-                    tensor_shapes=[tensors[0].shape],
-                )
-                if rank == src:
-                    expected_tensor = _build_tensor(src + 1, expected_value)
-                    self.assertEqual(tensors[0], expected_tensor)
-
-            self._barrier()
-
-        @skip_but_pass_in_sandcastle_if(
-            BACKEND != "nccl", "Only Nccl backend supports reduce multigpu"
-        )
-        @skip_if_no_gpu
-        def test_reduce_multigpu(self):
-            group, group_id, rank = self._init_global_test()
-            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
-            device_id = rank_to_GPU[rank][0]
-            torch.cuda.set_device(device_id)
-            self._test_reduce_multigpu_helper(
-                group,
-                group_id,
-                rank,
-                rank_to_GPU,
-                dist.ReduceOp.SUM,
-                2,
-                10,
-                (2 + 10 * (len(group) - 1)) * len(rank_to_GPU[0]),
-            )
-
-        def _test_all_gather_multigpu_helper(
-            self, group, group_id, rank, rank_to_GPU, dtype=torch.float
-        ):
-            for dest in group:
-                tensors = [
-                    _build_tensor(dest + 1, dtype=dtype).cuda(device=i)
-                    for i in rank_to_GPU[rank]
-                ]
-
-                # construct expected output along with
-                # a place holder to receive all gather results
-                output_tensors = []
-                expected_output = []
-                output_per_gpu = (
-                    [_build_tensor(dest + 1, -1, dtype=dtype)]
-                    * len(rank_to_GPU[0])
-                    * len(group)
-                )
-                expected_per_gpu = (
-                    [_build_tensor(dest + 1, dtype=dtype)]
-                    * len(rank_to_GPU[0])
-                    * len(group)
-                )
-                for gpu in rank_to_GPU[rank]:
-                    output_tensors.append([t.cuda(device=gpu) for t in output_per_gpu])
-                    expected_output.append(
-                        [t.cuda(device=gpu) for t in expected_per_gpu]
-                    )
-                self.call_dist_op(
-                    ":all_gather",
-                    False,
-                    dist.all_gather_multigpu,
-                    output_tensors,
-                    tensors,
-                    group_id,
-                    expect_event=len(expected_output) == 1,
-                )
-                self.assertEqual(output_tensors, expected_output)
-
-            self._barrier()
-
-        @skip_but_pass_in_sandcastle_if(
-            BACKEND != "nccl", "Only Nccl backend supports allgather multigpu"
-        )
-        @skip_if_no_gpu
-        def test_all_gather_multigpu(self):
-            group, group_id, rank = self._init_global_test()
-            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
-            device_id = rank_to_GPU[rank][0]
-            torch.cuda.set_device(device_id)
-            self._test_all_gather_multigpu_helper(group, group_id, rank, rank_to_GPU)
-
-        @skip_but_pass_in_sandcastle_if(
-            BACKEND != "nccl", "Only Nccl backend supports allgather multigpu"
-        )
-        @skip_if_no_gpu
-        def test_all_gather_multigpu_complex(self):
-            group, group_id, rank = self._init_global_test()
-            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
-            device_id = rank_to_GPU[rank][0]
-            torch.cuda.set_device(device_id)
-            self._test_all_gather_multigpu_helper(
-                group, group_id, rank, rank_to_GPU, dtype=torch.cfloat
-            )
-
         def _model_step(self, model):
             for param in model.parameters():
                 if param.grad is not None: