Summary: Avoids the following deprecation warning:

```python
loss.backward(*args, **kwargs)
/usr/local/lib/python3.7/dist-packages/torch/tensor.py:245: in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
/usr/local/lib/python3.7/dist-packages/torch/autograd/__init__.py:147: in backward
    allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag
/usr/local/lib/python3.7/dist-packages/torch/autograd/function.py:89: in apply
    return self._forward_cls.backward(self, *args)  # type: ignore
/usr/local/lib/python3.7/dist-packages/torch/nn/parallel/_functions.py:34: in backward
    return (None,) + ReduceAddCoalesced.apply(ctx.input_device, ctx.num_inputs, *grad_outputs)
/usr/local/lib/python3.7/dist-packages/torch/nn/parallel/_functions.py:45: in forward
    return comm.reduce_add_coalesced(grads_, destination)
/usr/local/lib/python3.7/dist-packages/torch/nn/parallel/comm.py:143: in reduce_add_coalesced
    flat_result = reduce_add(flat_tensors, destination)
/usr/local/lib/python3.7/dist-packages/torch/nn/parallel/comm.py:96: in reduce_add
    nccl.reduce(inputs, output=result, root=root_index)
/usr/local/lib/python3.7/dist-packages/torch/cuda/nccl.py:69: in reduce
    _check_sequence_type(inputs)
/usr/local/lib/python3.7/dist-packages/torch/cuda/nccl.py:48: in _check_sequence_type
    if not isinstance(inputs, collections.Container) or isinstance(inputs, torch.Tensor):
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

name = 'Container'

    def __getattr__(name):
        # For backwards compatibility, continue to make the collections ABCs
        # through Python 3.6 available through the collections module.
        # Note, no new collections ABCs were added in Python 3.7
        if name in _collections_abc.__all__:
            obj = getattr(_collections_abc, name)
            import warnings
            warnings.warn("Using or importing the ABCs from 'collections' instead "
                          "of from 'collections.abc' is deprecated since Python 3.3,"
                          "and in 3.9 it will stop working",
>                         DeprecationWarning, stacklevel=2)
E   DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated since Python 3.3,and in 3.9 it will stop working

/usr/lib/python3.7/collections/__init__.py:52: DeprecationWarning
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/72239
Reviewed By: ngimel
Differential Revision: D34387815
Pulled By: mruberry
fbshipit-source-id: 30c9b4fe518351bc9a6f211269e27ee3ab73a13c
(cherry picked from commit 1f68cdfac5875b56893b6b7ab3e8db96897f128b)
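For context, the fix is a one-line spelling change: the collection ABCs live in `collections.abc`, and the old top-level `collections` aliases warn from Python 3.3 onward and are removed in 3.10. A minimal standalone sketch of the difference (not part of the diff):

```python
import collections.abc

inputs = [1, 2, 3]

# Deprecated alias, warned about since Python 3.3 and removed in Python 3.10:
#     isinstance(inputs, collections.Container)
# Supported spelling, which this change switches to:
print(isinstance(inputs, collections.abc.Container))  # True
```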
114 lines · 3.9 KiB · Python
import collections
import warnings

import torch.cuda
from typing import Optional, Sequence, Union


__all__ = ['all_reduce', 'reduce', 'broadcast', 'all_gather', 'reduce_scatter']

SUM = 0  # ncclRedOp_t


def is_available(tensors):
    """Return True if NCCL can operate on ``tensors``: every tensor must be
    a dense, contiguous CUDA tensor, each on a distinct device."""
    if not hasattr(torch._C, '_nccl_all_reduce'):
        warnings.warn('PyTorch is not compiled with NCCL support')
        return False

    devices = set()
    for tensor in tensors:
        if tensor.is_sparse:
            return False
        if not tensor.is_contiguous():
            return False
        if not tensor.is_cuda:
            return False
        device = tensor.get_device()
        if device in devices:
            # NCCL accepts at most one tensor per device.
            return False
        devices.add(device)

    return True


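# Illustrative gating pattern (hypothetical caller code, not part of this
# module; assumes an NCCL build of PyTorch and two CUDA devices):
#
#     import torch
#     from torch.cuda import nccl
#
#     tensors = [torch.randn(1024, device=f'cuda:{i}') for i in range(2)]
#     if nccl.is_available(tensors):
#         nccl.all_reduce(tensors)  # in place: every tensor ends up with the sum

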
def version():
    # The NCCL version is packed into one integer:
    # (major << 32) | (minor << 16) | patch.
    ver = torch._C._nccl_version()
    major = ver >> 32
    minor = (ver >> 16) & 65535
    patch = ver & 65535
    return (major, minor, patch)


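# Worked example of the decoding above (illustrative values): NCCL 2.10.3 is
# packed as (2 << 32) | (10 << 16) | 3 == 8590589955, so version() returns
# (2, 10, 3).

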
def unique_id():
    return torch._C._nccl_unique_id()


def init_rank(num_ranks, uid, rank):
    return torch._C._nccl_init_rank(num_ranks, uid, rank)


def _check_sequence_type(inputs: Union[torch.Tensor, Sequence[torch.Tensor]]) -> None:
    # Note: collections.abc.Container, not the deprecated top-level
    # collections.Container alias -- this is the change described in the
    # commit message above.
    if not isinstance(inputs, collections.abc.Container) or isinstance(inputs, torch.Tensor):
        raise TypeError("Inputs should be a collection of tensors")


def all_reduce(inputs, outputs=None, op=SUM, streams=None, comms=None):
    _check_sequence_type(inputs)
    if outputs is None:
        # In-place reduction: results are written back into the inputs.
        outputs = inputs
    _check_sequence_type(outputs)
    torch._C._nccl_all_reduce(inputs, outputs, op, streams, comms)


# `output` used to be `outputs`, taking in a list of tensors. So we have two
# arguments for BC reasons.
def reduce(inputs: Sequence[torch.Tensor],
           output: Optional[Union[torch.Tensor, Sequence[torch.Tensor]]] = None,
           root: int = 0,
           op: int = SUM,
           streams: Optional[Sequence[torch.cuda.Stream]] = None,
           comms=None, *,
           outputs: Optional[Sequence[torch.Tensor]] = None) -> None:
    _check_sequence_type(inputs)
    _output: torch.Tensor
    if outputs is not None:
        if output is not None:
            raise ValueError(
                "'output' and 'outputs' cannot both be specified. 'outputs' is deprecated in "
                "favor of 'output', taking in a single output tensor. The signature of reduce is: "
                "reduce(inputs, output=None, root=0, op=SUM, streams=None, comms=None).")
        else:
            warnings.warn(
                "nccl.reduce with an output tensor list is deprecated. "
                "Please specify a single output tensor with argument 'output' instead.")
            _output = outputs[root]
    elif not isinstance(output, torch.Tensor) and isinstance(output, collections.abc.Sequence):
        # User called the old API positionally, with a list of output tensors.
        warnings.warn(
            "nccl.reduce with an output tensor list is deprecated. "
            "Please specify a single output tensor.")
        _output = output[root]
    else:
        _output = inputs[root] if output is None else output
    torch._C._nccl_reduce(inputs, _output, root, op, streams, comms)


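# Call-form comparison for reduce() (illustrative, hypothetical caller code):
#
#     # Current API: reduce into one output tensor on the root device.
#     nccl.reduce(tensors, output=result, root=0)
#
#     # Deprecated list-of-outputs form; the BC paths above still accept it,
#     # but warn and use outputs[root] as the real destination.
#     nccl.reduce(tensors, outputs=tensors)

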
def broadcast(inputs: Sequence[torch.Tensor], root: int = 0, streams=None, comms=None) -> None:
    _check_sequence_type(inputs)
    torch._C._nccl_broadcast(inputs, root, streams, comms)


def all_gather(inputs: Sequence[torch.Tensor], outputs: Sequence[torch.Tensor], streams=None, comms=None) -> None:
    _check_sequence_type(inputs)
    _check_sequence_type(outputs)
    torch._C._nccl_all_gather(inputs, outputs, streams, comms)


def reduce_scatter(inputs: Sequence[torch.Tensor],
                   outputs: Sequence[torch.Tensor],
                   op: int = SUM,
                   streams=None, comms=None) -> None:
    _check_sequence_type(inputs)
    _check_sequence_type(outputs)
    torch._C._nccl_reduce_scatter(inputs, outputs, op, streams, comms)
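
For orientation, a minimal end-to-end sketch of the in-place collective path that the traceback above exercises (hypothetical driver code, not part of the file; assumes an NCCL build of PyTorch and at least two CUDA devices):

```python
import torch
from torch.cuda import nccl

# One contiguous CUDA tensor per device, as is_available() requires.
tensors = [torch.ones(4, device=f'cuda:{i}') for i in range(2)]

if nccl.is_available(tensors):
    # Sum-reduce across devices; each tensor now holds the elementwise sum.
    nccl.all_reduce(tensors, op=nccl.SUM)
    print(tensors[0])  # tensor([2., 2., 2., 2.], device='cuda:0')

    # Reduce into a single output tensor on the root device -- the code path
    # DataParallel's backward reaches via comm.reduce_add.
    result = torch.empty_like(tensors[0])
    nccl.reduce(tensors, output=result, root=0)
```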