mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-06 09:17:11 +08:00
Summary: There were two problems with SN + DP: 1. In SN, the updated _u vector is saved back to module via a `setattr`. However, in DP, everything is run on a replica, so those updates are lost. 2. In DP, the buffers are broadcast via a `broadcast_coalesced`, so on replicas they are all views. Therefore, the `detach_` call won't work. Fixes are: 1. Update _u vector in-place so, by the shared storage between 1st replica and the parallelized module, the update is retained 2. Do not call `detach_`. 3. Added comments in SN about the subtlety. 4. Added a note to the DP doc on this particular behavior of DP. cc crcrpar taesung89 yaoshengfu Fixes https://github.com/pytorch/pytorch/issues/11476 Pull Request resolved: https://github.com/pytorch/pytorch/pull/12671 Differential Revision: D10410232 Pulled By: SsnL fbshipit-source-id: c447951844a30366d8c196bf9436340e88f3b6d9
190 lines
7.9 KiB
Python
190 lines
7.9 KiB
Python
import operator
|
|
import torch
|
|
import warnings
|
|
from ..modules import Module
|
|
from .scatter_gather import scatter_kwargs, gather
|
|
from .replicate import replicate
|
|
from .parallel_apply import parallel_apply
|
|
from torch.cuda._utils import _get_device_index
|
|
|
|
|
|
def _check_balance(device_ids):
|
|
imbalance_warn = """
|
|
There is an imbalance between your GPUs. You may want to exclude GPU {} which
|
|
has less than 75% of the memory or cores of GPU {}. You can do so by setting
|
|
the device_ids argument to DataParallel, or by setting the CUDA_VISIBLE_DEVICES
|
|
environment variable."""
|
|
device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
|
|
dev_props = [torch.cuda.get_device_properties(i) for i in device_ids]
|
|
|
|
def warn_imbalance(get_prop):
|
|
values = [get_prop(props) for props in dev_props]
|
|
min_pos, min_val = min(enumerate(values), key=operator.itemgetter(1))
|
|
max_pos, max_val = max(enumerate(values), key=operator.itemgetter(1))
|
|
if min_val / max_val < 0.75:
|
|
warnings.warn(imbalance_warn.format(device_ids[min_pos], device_ids[max_pos]))
|
|
return True
|
|
return False
|
|
|
|
if warn_imbalance(lambda props: props.total_memory):
|
|
return
|
|
if warn_imbalance(lambda props: props.multi_processor_count):
|
|
return
|
|
|
|
|
|
class DataParallel(Module):
    r"""Implements data parallelism at the module level.

    This container parallelizes the application of the given :attr:`module` by
    splitting the input across the specified devices by chunking in the batch
    dimension (other objects will be copied once per device). In the forward
    pass, the module is replicated on each device, and each replica handles a
    portion of the input. During the backwards pass, gradients from each replica
    are summed into the original module.

    The batch size should be larger than the number of GPUs used.

    See also: :ref:`cuda-nn-dataparallel-instead`

    Arbitrary positional and keyword inputs are allowed to be passed into
    DataParallel, but tensors are handled specially: they will be scattered
    on the dim specified (default 0). Primitive types will be broadcast, but
    all other types will be a shallow copy and can be corrupted if written to
    in the model's forward pass.

    The parallelized :attr:`module` must have its parameters and buffers on
    ``device_ids[0]`` before running this :class:`~torch.nn.DataParallel`
    module.

    If CUDA is not available, :attr:`module` is stored unchanged and
    ``forward`` simply calls it on the given inputs.

    .. warning::
        In each forward, :attr:`module` is **replicated** on each device, so any
        updates to the running module in ``forward`` will be lost. For example,
        if :attr:`module` has a counter attribute that is incremented in each
        ``forward``, it will always stay at the initial value because the update
        is done on the replicas which are destroyed after ``forward``. However,
        :class:`~torch.nn.DataParallel` guarantees that the replica on
        ``device_ids[0]`` will have its parameters and buffers sharing storage
        with the base parallelized :attr:`module`. So **in-place** updates to the
        parameters or buffers on ``device_ids[0]`` will be recorded. E.g.,
        :class:`~torch.nn.BatchNorm2d` and :func:`~torch.nn.utils.spectral_norm`
        rely on this behavior to update the buffers.

    .. warning::
        Forward and backward hooks defined on :attr:`module` and its submodules
        will be invoked ``len(device_ids)`` times, each with inputs located on
        a particular device. Particularly, the hooks are only guaranteed to be
        executed in correct order with respect to operations on corresponding
        devices. For example, it is not guaranteed that hooks set via
        :meth:`~torch.nn.Module.register_forward_pre_hook` be executed before
        `all` ``len(device_ids)`` :meth:`~torch.nn.Module.forward` calls, but
        that each such hook be executed before the corresponding
        :meth:`~torch.nn.Module.forward` call of that device.

    .. warning::
        When :attr:`module` returns a scalar (i.e., 0-dimensional tensor) in
        :func:`forward`, this wrapper will return a vector of length equal to
        number of devices used in data parallelism, containing the result from
        each device.

    .. note::
        There is a subtlety in using the
        ``pack sequence -> recurrent network -> unpack sequence`` pattern in a
        :class:`~torch.nn.Module` wrapped in :class:`~torch.nn.DataParallel`.
        See :ref:`pack-rnn-unpack-with-data-parallelism` section in FAQ for
        details.


    Args:
        module (Module): module to be parallelized
        device_ids (list of int or torch.device): CUDA devices (default: all devices)
        output_device (int or torch.device): device location of output (default: device_ids[0])
        dim (int): dimension along which inputs are scattered and outputs are
            gathered (default: 0)

    Attributes:
        module (Module): the module to be parallelized

    Example::

        >>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
        >>> output = net(input_var)
    """

    # TODO: update notes/cuda.rst when this class handles 8+ GPUs well

    def __init__(self, module, device_ids=None, output_device=None, dim=0):
        super(DataParallel, self).__init__()

        # Set unconditionally so these attributes also exist on the CPU-only path.
        self.dim = dim
        self.module = module

        if not torch.cuda.is_available():
            # No CUDA: an empty device list makes forward() call the module directly.
            self.device_ids = []
            return

        if device_ids is None:
            device_ids = list(range(torch.cuda.device_count()))
        if output_device is None:
            output_device = device_ids[0]

        # Normalize ints / torch.device objects into plain CUDA device indices.
        self.device_ids = [_get_device_index(x, True) for x in device_ids]
        self.output_device = _get_device_index(output_device, True)

        _check_balance(self.device_ids)

        if len(self.device_ids) == 1:
            # Single device: no replication happens in forward(), so move the
            # module there once, using the normalized index.
            self.module.cuda(self.device_ids[0])

    def forward(self, *inputs, **kwargs):
        """Scatter the inputs, run a replica per device, and gather the outputs.

        Without CUDA devices, the wrapped module is called directly.
        """
        if not self.device_ids:
            return self.module(*inputs, **kwargs)
        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
        if len(self.device_ids) == 1:
            # Only one device: skip replication and run the module itself.
            return self.module(*inputs[0], **kwargs[0])
        # Scatter may produce fewer chunks than devices (presumably when the
        # batch is small), so only replicate onto as many devices as needed.
        replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
        outputs = self.parallel_apply(replicas, inputs, kwargs)
        return self.gather(outputs, self.output_device)

    def replicate(self, module, device_ids):
        """Return copies of ``module`` for each device in ``device_ids``."""
        return replicate(module, device_ids)

    def scatter(self, inputs, kwargs, device_ids):
        """Split positional and keyword inputs across ``device_ids`` along ``self.dim``."""
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def parallel_apply(self, replicas, inputs, kwargs):
        """Run each replica on its chunk of inputs on the matching device."""
        return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])

    def gather(self, outputs, output_device):
        """Collect per-device outputs onto ``output_device`` along ``self.dim``."""
        return gather(outputs, output_device, dim=self.dim)
|
|
|
|
|
|
def data_parallel(module, inputs, device_ids=None, output_device=None, dim=0, module_kwargs=None):
    r"""Evaluates module(input) in parallel across the GPUs given in device_ids.

    This is the functional version of the DataParallel module.

    Args:
        module (Module): the module to evaluate in parallel
        inputs (Tensor or tuple): inputs to the module; a non-tuple value is
            treated as a single positional argument
        device_ids (list of int or torch.device): GPU ids on which to replicate module
            (default: all visible CUDA devices)
        output_device (int or torch.device): GPU location of the output. Use -1 to
            indicate the CPU. (default: device_ids[0])
        dim (int): dimension along which inputs are scattered and outputs are
            gathered (default: 0)
        module_kwargs (dict, optional): keyword arguments passed to ``module``
    Returns:
        a Tensor containing the result of module(input) located on
        output_device. If CUDA is not available, ``module`` is simply called
        on ``inputs``, mirroring :class:`DataParallel`.
    """
    if not isinstance(inputs, tuple):
        inputs = (inputs,)

    # Nothing to parallelize over without CUDA; mirror DataParallel.forward and
    # run the module directly instead of failing on device_ids[0].
    if not torch.cuda.is_available():
        return module(*inputs, **(module_kwargs or {}))

    if device_ids is None:
        device_ids = list(range(torch.cuda.device_count()))

    if output_device is None:
        output_device = device_ids[0]

    # Normalize ints / torch.device objects into plain indices, as DataParallel does.
    device_ids = [_get_device_index(x, True) for x in device_ids]
    output_device = _get_device_index(output_device, True)

    inputs, module_kwargs = scatter_kwargs(inputs, module_kwargs, device_ids, dim)
    if len(device_ids) == 1:
        return module(*inputs[0], **module_kwargs[0])
    # Scatter may yield fewer chunks than devices; only use as many as needed.
    used_device_ids = device_ids[:len(inputs)]
    replicas = replicate(module, used_device_ids)
    outputs = parallel_apply(replicas, inputs, module_kwargs, used_device_ids)
    return gather(outputs, output_device, dim)
|