Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Make DistributedDataParallel use new reducer (#18953)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18953

This removes the Python-side bucketing code from DistributedDataParallel and replaces it with calls to the new C++-based bucketing and reducing code. To confirm this works well, we ran a test with both the previous and the new implementation and confirmed they are numerically equivalent. Performance improves by a couple of percent or more, including for single-machine, multi-GPU runs.

Closes #13273.

Reviewed By: mrshenli
Differential Revision: D14580911
fbshipit-source-id: 44e76f8b0b7e58dd6c91644e3df4660ca2ee4ae2
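For context, the code path exercised by the new reducer is the ordinary DistributedDataParallel workflow: wrap the module, run forward, and call backward so that autograd hooks hand gradient buckets to the reducer. Below is a minimal usage sketch, not part of this PR; it assumes a single-process gloo group, a placeholder address/port, and a toy model, and details such as CPU-module support may differ in the exact snapshot this diff is from.

import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

# Single-process "world" so the sketch runs standalone; real jobs launch one
# process per GPU/node. MASTER_ADDR/MASTER_PORT values are arbitrary placeholders.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

model = DistributedDataParallel(nn.Linear(10, 10))
loss = model(torch.randn(2, 10)).sum()
loss.backward()  # gradients are bucketed and allreduced by the reducer
dist.destroy_process_group()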
Committed by: Facebook GitHub Bot
Parent: 6ed57e052d
Commit: a0263ec047
@@ -1934,6 +1934,52 @@ class ReducerTest(TestCase):
         optimizer.step()


+class ComputeBucketAssignmentTest(TestCase):
+    def test_single_limit_single_dtype(self):
+        tensors = [
+            torch.empty([100], dtype=torch.float),
+            torch.empty([200], dtype=torch.float),
+            torch.empty([100], dtype=torch.float),
+            torch.empty([50], dtype=torch.float),
+        ]
+        result = dist._compute_bucket_assignment_by_size(tensors, [400])
+        self.assertEqual([[0], [1], [2], [3]], result)
+
+    def test_single_limit_multi_dtype(self):
+        tensors = [
+            torch.empty([50], dtype=torch.float),
+            torch.empty([25], dtype=torch.double),
+            torch.empty([50], dtype=torch.float),
+            torch.empty([25], dtype=torch.double),
+            torch.empty([50], dtype=torch.float),
+            torch.empty([25], dtype=torch.double),
+        ]
+        result = dist._compute_bucket_assignment_by_size(tensors, [400])
+        self.assertEqual([[0, 2], [1, 3], [4], [5]], result)
+
+    def test_multi_limit_single_dtype(self):
+        tensors = [
+            torch.empty([10], dtype=torch.float),
+            torch.empty([10], dtype=torch.float),
+            torch.empty([10], dtype=torch.float),
+            torch.empty([10], dtype=torch.float),
+        ]
+        result = dist._compute_bucket_assignment_by_size(tensors, [40, 80])
+        self.assertEqual([[0], [1, 2], [3]], result)
+
+    def test_multi_limit_multi_dtype(self):
+        tensors = [
+            torch.empty([50], dtype=torch.float),
+            torch.empty([25], dtype=torch.double),
+            torch.empty([50], dtype=torch.float),
+            torch.empty([25], dtype=torch.double),
+            torch.empty([50], dtype=torch.float),
+            torch.empty([25], dtype=torch.double),
+        ]
+        result = dist._compute_bucket_assignment_by_size(tensors, [200, 400])
+        self.assertEqual([[0], [1], [2, 4], [3, 5]], result)
+
+
 if __name__ == '__main__':
     assert not torch.cuda._initialized, "test_distributed must not have initialized CUDA context on main process"
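The expected assignments follow directly from byte sizes: the bucketing accumulates numel() * element_size(), and float and double tensors never share a bucket. For example, a float32 tensor with 100 elements is exactly 400 bytes, so it fills a 400-byte bucket on its own, while two 25-element float64 tensors (200 bytes each) share one. A quick arithmetic check, plain Python and independent of the binding:

import torch

def nbytes(shape, dtype):
    # Byte size the bucketing logic accumulates: numel * element_size.
    t = torch.empty(shape, dtype=dtype)
    return t.numel() * t.element_size()

assert nbytes([100], torch.float) == 400   # fills a 400-byte bucket alone
assert nbytes([25], torch.double) == 200   # two of these fit a 400-byte bucket
assert nbytes([10], torch.float) == 40     # matches the [40, 80] limits above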
@@ -500,6 +500,13 @@ They are used in specifying strategies for reduction collectives, e.g.,
       py::call_guard<py::gil_scoped_release>());
 #endif

+  module.def(
+      "_compute_bucket_assignment_by_size",
+      &::c10d::compute_bucket_assignment_by_size,
+      py::arg("tensors"),
+      py::arg("bucket_size"),
+      py::call_guard<py::gil_scoped_release>());
+
   Py_RETURN_TRUE;
 }
@@ -7,6 +7,7 @@
 #include <torch/csrc/autograd/function_hook.h>
 #include <torch/csrc/autograd/functions/accumulate_grad.h>
 #include <torch/csrc/autograd/profiler.h>
+#include <torch/csrc/utils/hash.h>
 #include <torch/csrc/utils/memory.h>

 namespace c10d {
@@ -361,6 +362,13 @@ void Reducer::prepare_for_backward(
     bucket.pending = bucket.replicas.size();
   }

+  // If no outputs are specified, we assume that autograd hooks for ALL
+  // variables will be called, and we don't have to search the autograd graph
+  // for presence of these hooks.
+  if (outputs.empty()) {
+    return;
+  }
+
   // Seed queue with the grad functions of all outputs.
   for (const auto& output : outputs) {
     const auto& grad_fn = output.grad_fn();
@@ -433,4 +441,106 @@ void Reducer::finalize_backward() {
   }
 }

+namespace {
+
+// Tensors may be coalesced into buckets. Buckets must contain tensors of
+// the same type, on the same device, so a bucket can identified by a
+// composite key of a tensor's type identifier and its device.
+struct BucketKey {
+  BucketKey(c10::ScalarType type, c10::Device device)
+      : type(std::move(type)), device(std::move(device)) {}
+
+  const c10::ScalarType type;
+  const c10::Device device;
+
+  // See torch/csrc/utils/hash.h for dispatch code.
+  static size_t hash(const BucketKey& key) {
+    return torch::get_hash(key.type, key.device);
+  }
+};
+
+inline bool operator==(const BucketKey& lhs, const BucketKey& rhs) {
+  return lhs.type == rhs.type && lhs.device == rhs.device;
+}
+
+} // namespace
+
+// This is equivalent to take_tensors but returns indices into the
+// tensor list argument for bucket assignment. Also, it is aware
+// of device placement and will not allow buckets to span devices.
+std::vector<std::vector<size_t>> compute_bucket_assignment_by_size(
+    const std::vector<at::Tensor>& tensors,
+    std::vector<size_t> bucket_size_limits) {
+  std::vector<std::vector<size_t>> result;
+  result.reserve(tensors.size());
+
+  // Keep iterator into the size_limit vector by tensor type and device.
+  // This is done so that we can use the consecutive bucket limits per type.
+  std::unordered_map<
+      BucketKey,
+      std::vector<size_t>::iterator,
+      torch::hash<BucketKey>>
+      bucket_size_limit_iterators;
+
+  // Local accumulator type for a single bucket.
+  struct BucketAccumulator {
+    std::vector<size_t> indices;
+    size_t size = 0;
+  };
+
+  // Keep vector of indices and size accumulator by tensor type and device.
+  std::unordered_map<BucketKey, BucketAccumulator, torch::hash<BucketKey>>
+      buckets;
+
+  for (size_t i = 0; i < tensors.size(); i++) {
+    const auto& tensor = tensors[i];
+    AT_ASSERTM(!tensor.is_sparse(), "No support for sparse tensors.");
+    auto key = BucketKey(tensor.scalar_type(), tensor.device());
+    auto& bucket = buckets[key];
+    bucket.indices.push_back(i);
+    bucket.size += tensor.numel() * tensor.element_size();
+
+    // Initialize bucket size limit iterator if necessary.
+    if (bucket_size_limit_iterators.count(key) == 0) {
+      bucket_size_limit_iterators[key] = bucket_size_limits.begin();
+    }
+
+    auto& bucket_size_limit_iterator = bucket_size_limit_iterators[key];
+    const auto bucket_size_limit = *bucket_size_limit_iterator;
+    if (bucket.size >= bucket_size_limit) {
+      result.emplace_back(std::move(bucket.indices));
+      bucket = BucketAccumulator();
+
+      // Advance to the next bucket size limit for this type/device.
+      auto next = bucket_size_limit_iterator + 1;
+      if (next != bucket_size_limits.end()) {
+        bucket_size_limit_iterator = next;
+      }
+    }
+  }
+
+  // Add remaining buckets.
+  for (auto& it : buckets) {
+    auto& bucket = it.second;
+    if (!bucket.indices.empty()) {
+      result.emplace_back(std::move(bucket.indices));
+    }
+  }
+
+  // Sort resulting buckets by the minimum tensor index they include.
+  // We assume that the order of the tensors is the order in which they are
+  // used (or the reverse order in which their gradients are produced).
+  // This sorting step ensures that the buckets are ready in consecutive order.
+  std::sort(
+      result.begin(),
+      result.end(),
+      [](const std::vector<size_t>& a, const std::vector<size_t>& b) {
+        const auto amin = std::min_element(a.begin(), a.end());
+        const auto bmin = std::min_element(b.begin(), b.end());
+        return *amin < *bmin;
+      });
+
+  return result;
+}
+
 } // namespace c10d
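To make the control flow of compute_bucket_assignment_by_size easier to follow, here is a simplified Python model of the same greedy policy; it is an illustrative sketch, not the bound C++ function. Tensors are grouped per (dtype, device) key, a bucket is flushed once its accumulated byte count reaches the current limit, the per-key limit then advances but sticks at the last entry, leftovers are flushed at the end, and buckets are ordered by their smallest tensor index.

from collections import defaultdict

def bucket_by_size(tensors, size_limits):
    # tensors: list of (nbytes, dtype, device) tuples; returns lists of indices.
    limit_pos = {}                               # (dtype, device) -> index into size_limits
    open_buckets = defaultdict(lambda: ([], 0))  # key -> (indices, accumulated bytes)
    result = []
    for i, (nbytes, dtype, device) in enumerate(tensors):
        key = (dtype, device)
        indices, size = open_buckets[key]
        indices.append(i)
        size += nbytes
        pos = limit_pos.setdefault(key, 0)
        if size >= size_limits[pos]:
            result.append(indices)
            open_buckets[key] = ([], 0)
            # Advance to the next limit for this key, but stay on the last one.
            limit_pos[key] = min(pos + 1, len(size_limits) - 1)
        else:
            open_buckets[key] = (indices, size)
    # Flush remaining buckets, then sort by the minimum index they contain.
    result.extend(ind for ind, _ in open_buckets.values() if ind)
    result.sort(key=min)
    return result

# Mirrors test_multi_limit_single_dtype: four 40-byte tensors, limits [40, 80].
print(bucket_by_size([(40, "float", "cpu")] * 4, [40, 80]))  # [[0], [1, 2], [3]]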
@@ -139,4 +139,8 @@ class Reducer {
   std::vector<std::vector<int64_t>> backward_stats_;
 };

+std::vector<std::vector<size_t>> compute_bucket_assignment_by_size(
+    const std::vector<at::Tensor>& tensors,
+    std::vector<size_t> bucket_size);
+
 } // namespace c10d
@@ -209,15 +209,19 @@ class DistributedDataParallel(Module):
         self.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
         self.output_device = _get_device_index(output_device, True)
         self.broadcast_buffers = broadcast_buffers
-        self.check_reduction = check_reduction
+        if check_reduction:
+            # This argument is no longer used since the reducer
+            # will ensure reduction completes even if some parameters
+            # do not receive gradients.
+            pass

         MB = 1024 * 1024

         # used for intra-node param sync and inter-node sync as well
-        self.broadcast_bucket_size = 250 * MB
+        self.broadcast_bucket_size = int(250 * MB)

         # reduction bucket size
-        self.bucket_bytes_cap = bucket_cap_mb * MB
+        self.bucket_bytes_cap = int(bucket_cap_mb * MB)

         # Sync params and buffers
         module_states = list(self.module.state_dict().values())
@@ -254,60 +258,26 @@ class DistributedDataParallel(Module):
         self.modules_params = [list(m.parameters()) for m in self._module_copies]
         self.modules_buffers = [list(m.buffers()) for m in self._module_copies]

-        # This is a triply-nested list where the "dimensions" are: devices, buckets, bucket_elems
-        param_buckets = []
-
-        # Split the parameters into buckets and by types as well
-        # We only need to bucket and reduce parameters that require grad and
-        # this is also true for backward since only the backward hooks for
-        # parameters that require grad will be registered with gradient
-        # reduction functions
-        params_to_bucket = [[] for _ in self._module_copies]
-        for dev_idx, m in enumerate(self._module_copies):
-            for p in m.parameters():
-                if p.requires_grad:
-                    params_to_bucket[dev_idx].append(p)
-
-        param_buckets = [dist._dist_bucket_tensors(dev_params_to_bucket,
-                                                   int(self.bucket_bytes_cap),
-                                                   fine_grained=False)
-                         for dev_params_to_bucket in params_to_bucket]
-
-        self.bucket_sizes = []
-        self.bucket_map = {}
-
-        # We transpose param_buckets, so the loop is over buckets.
-        # param_buckets_tuple is a doubly-nested list with "dims": devices, bucket_elems
-        for bucket_idx, param_buckets_tuple in enumerate(zip(*param_buckets)):
-            self.bucket_sizes.append(0)
-            # Now, we transpose again, so we iterate over bucket_elems, but getting tuples
-            # of params from each device.
-            for param_tuple in zip(*param_buckets_tuple):
-                if not param_tuple[0].requires_grad:
-                    continue
-                for p in param_tuple:
-                    self.bucket_map[p] = (bucket_idx, self.bucket_sizes[bucket_idx])
-                self.bucket_sizes[bucket_idx] += 1
-
-        self.buckets = [[[None for _ in range(self.bucket_sizes[i])]
-                        for _ in range(len(self.device_ids))] for i in range(len(self.bucket_sizes))]
-        # The number of params ready in each bucket
-        self.buckets_ready_size = [[0 for _ in range(len(self.device_ids))] for i in range(len(self.bucket_sizes))]
-
-        # coalesced bucket for only device 0
-        self.buckets_coalesced = [[] for _ in range(len(self.bucket_sizes))]
-        # We will always reduce the bucket following the reverse order
-        # that is, alway reduces following the order of: n - 1, n - 2, ..., 0
-        self.next_bucket = len(self.bucket_sizes) - 1
-        # When all buckets are reduced, this will be set to True. This flag is
-        # useful for sanity checks to ensure that each iteration's backward has
-        # always reduced all buckets
-        self.all_buckets_reduced = False
-        self.check_previous_reduction = False
-        self.ready_buckets_not_reduced = set()
-        self.reduction_works = [None for _ in range(len(self.bucket_sizes))]
-        self.devs_ready = [0 for _ in range(len(self.bucket_sizes))]
-        self._register_grad_hooks()
+        param_list = [
+            list(filter(lambda p: p.requires_grad, module.parameters()))
+            for module in self._module_copies]
+
+        # The bucket size limit is specified in the constructor.
+        # Additionally, we allow for a single small bucket for parameters
+        # that are defined first, such that their gradients don't spill into
+        # a much larger bucket, adding unnecessary latency after gradient
+        # computation finishes. Experiments showed 1MB is a reasonable value.
+        bucket_indices = dist._compute_bucket_assignment_by_size(
+            param_list[0],
+            [1024 * 1024, self.bucket_bytes_cap])
+
+        # Note: reverse list of buckets because we want to approximate the
+        # order in which their gradients are produced, and assume they
+        # are used in the forward pass in the order they are defined.
+        self.reducer = dist.Reducer(
+            param_list,
+            list(reversed(bucket_indices)),
+            self.process_group)

         # passing a handle to torch.nn.SyncBatchNorm layer
         self._passing_sync_batchnorm_handle(self._module_copies)
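The two-entry limit list above implements the comment: the earliest-defined parameters go into a small 1 MiB starter bucket so their gradients become reducible quickly, and all later buckets are capped at bucket_bytes_cap (25 MB by default). A sketch of how one could inspect the resulting assignment for a toy model, calling the private helper directly as the unit tests above do; the helper's return shape here matches this diff and may differ in later releases:

import torch
import torch.distributed as dist
import torch.nn as nn

model = nn.Sequential(nn.Linear(512, 512), nn.Linear(512, 512))
params = [p for p in model.parameters() if p.requires_grad]

# Same two-limit scheme as the constructor above: 1 MiB starter bucket, then
# buckets capped at bucket_cap_mb (25 MB by default). The first 512x512 float32
# weight is exactly 1 MiB, so it fills the starter bucket by itself.
bucket_indices = dist._compute_bucket_assignment_by_size(
    params, [1024 * 1024, 25 * 1024 * 1024])
print(list(reversed(bucket_indices)))  # buckets in approximate reduction order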
@@ -315,15 +285,13 @@ class DistributedDataParallel(Module):
     def __getstate__(self):
         self._check_default_group()
         attrs = copy.copy(self.__dict__)
-        del attrs['process_group'], \
-            attrs['default_streams'], \
-            attrs['_grad_accs']
+        del attrs['process_group']
+        del attrs['reducer']
         return attrs

     def __setstate__(self, state):
         # If serializable, then the process group should be the default one
         self.process_group = _get_default_group()
-        self.check_previous_reduction = False
         super(DistributedDataParallel, self).__setstate__(state)
         self._ddp_init_helper()
@@ -342,32 +310,28 @@ class DistributedDataParallel(Module):
                                "init_process_group and have not passed "
                                "process_group argument to DDP constructor")

-    def _check_previous_reduction(self):
-        if not self.training:
-            return
-        # self.check_previous_reduction will be False in the first iteration
-        # and is then toggled to True for all future iterations.
-        if self.check_previous_reduction is False:
-            self.check_previous_reduction = True
-        else:
-            if not self.all_buckets_reduced:
-                raise RuntimeError("Not all gradients have been reduced from "
-                                   "the backward of the previous iteration. "
-                                   "This is an unexpected and fatal error. "
-                                   "Please check and ensure that the model's "
-                                   "parameters are not changed after you wrap "
-                                   "up the model with DistributedDataParallel.")
-        self.all_buckets_reduced = False
-
     def forward(self, *inputs, **kwargs):
-        if self.check_reduction:
-            self._check_previous_reduction()
         inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
         self._sync_params()
         if len(self.device_ids) == 1:
-            return self.module(*inputs[0], **kwargs[0])
-        outputs = self.parallel_apply(self._module_copies[:len(inputs)], inputs, kwargs)
-        return self.gather(outputs, self.output_device)
+            output = self.module(*inputs[0], **kwargs[0])
+        else:
+            outputs = self.parallel_apply(self._module_copies[:len(inputs)], inputs, kwargs)
+            output = self.gather(outputs, self.output_device)
+
+        # We'll return the output object verbatim since it is a freeform object.
+        # We need to find any tensors in this object, though, because we need to
+        # figure out which parameters were used during this forward pass,
+        # to ensure we short circuit reduction for any unused parameters.
+        output_tensors = []
+        if isinstance(output, torch.Tensor):
+            output_tensors = [output]
+        if isinstance(output, (list, tuple)):
+            def istensor(obj):
+                return isinstance(obj, torch.Tensor)
+            output_tensors = list(filter(istensor, output))
+        self.reducer.prepare_for_backward(output_tensors)
+        return output

     def scatter(self, inputs, kwargs, device_ids):
         return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
@@ -379,7 +343,6 @@ class DistributedDataParallel(Module):
         return gather(outputs, output_device, dim=self.dim)

     def train(self, mode=True):
-        self.check_previous_reduction = False
         super(DistributedDataParallel, self).train(mode)
         for module in self._module_copies[1:]:
             module.train(mode)
@@ -398,6 +361,13 @@ class DistributedDataParallel(Module):
                                              self.modules_params[1:]):
                 for tensor, param in zip(tensors, module_params):
                     param.set_(tensor)
+                    # Assume we have just run the optimizer and zeroed the
+                    # grads of the parameters on the root model. We need
+                    # to zero the grads on all model replicas as well.
+                    # This snippet is copied from torch.optim.Optimizer.
+                    if param.grad is not None:
+                        param.grad.detach_()
+                        param.grad.zero_()

         # module buffer sync
         if self.broadcast_buffers and len(self.modules_buffers[0]) > 0:
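The added zeroing mirrors the detach-then-zero idiom used by torch.optim.Optimizer.zero_grad at the time: keep the existing .grad storage, drop any autograd history on it, and clear the values in place. A standalone rendering of that idiom, illustrative only:

import torch

def zero_grads(params):
    # Same pattern as the replica-sync code above: reuse .grad storage,
    # detach it from any graph, and zero it in place.
    for p in params:
        if p.grad is not None:
            p.grad.detach_()
            p.grad.zero_()

w = torch.randn(3, requires_grad=True)
(w * 2).sum().backward()
zero_grads([w])
assert torch.equal(w.grad, torch.zeros(3))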
@@ -419,99 +389,3 @@ class DistributedDataParallel(Module):
             for layer in module.modules():
                 if isinstance(layer, torch.nn.modules.SyncBatchNorm):
                     layer._specify_ddp_gpu_num(len(self.device_ids))
-
-    def _register_grad_hooks(self):
-        self._grad_accs = []  # need to keep them in scope
-
-        # default stream tracking to launch nccl reduce kernels
-        self.default_streams = []
-        for dev_id in self.device_ids:
-            with torch.cuda.device(dev_id):
-                self.default_streams.append(torch.cuda.current_stream())
-
-        for device_idx, module in enumerate(self._module_copies):
-            for p in module.parameters():
-                if p.requires_grad:
-                    p_tmp = p.expand_as(p)
-                    grad_acc = p_tmp.grad_fn.next_functions[0][0]
-                    grad_acc.register_hook(self._make_param_hook(p, device_idx))
-                    self._grad_accs.append(grad_acc)
-
-    def _make_param_hook(self, param, device_idx):
-        bucket_idx, bucket_offset = self.bucket_map[param]
-
-        def distributed_data_parallel_hook(*unused):
-            if param.grad.requires_grad:
-                raise RuntimeError("DistributedDataParallel only works "
-                                   "with gradients that don't require grad")
-            bucket = self.buckets[bucket_idx][device_idx]
-            bucket[bucket_offset] = param.grad.data
-            self.buckets_ready_size[bucket_idx][device_idx] += 1
-
-            # We can flush these and save memory for replicas
-            if device_idx > 0:
-                param.grad = None
-                with torch.no_grad():
-                    param.set_()
-
-            # Current device's bucket is full
-            if self.buckets_ready_size[bucket_idx][device_idx] == self.bucket_sizes[bucket_idx]:
-                self.devs_ready[bucket_idx] += 1
-                if self.devs_ready[bucket_idx] < len(self.device_ids):
-                    return
-
-                # Now all devices's buckets with index: bucket_idx are ready
-                if bucket_idx == self.next_bucket:
-                    self._queue_reduction(bucket_idx)
-                    self.next_bucket -= 1
-                    # Now reduce anything that is ready but not yet reduced
-                    if len(self.ready_buckets_not_reduced) > 0:
-                        sorted_todo = sorted(self.ready_buckets_not_reduced, reverse=True)
-                        for i in sorted_todo:
-                            # Nothing can be reduced now
-                            if i < self.next_bucket:
-                                break
-                            self._queue_reduction(i)
-                            self.ready_buckets_not_reduced.remove(i)
-                            if i == self.next_bucket:
-                                self.next_bucket -= 1
-                else:
-                    self.ready_buckets_not_reduced.add(bucket_idx)
-
-                # When all devices' buckets
-                if self.next_bucket == -1:
-                    # A final sync for all the reduction works
-                    self._sync_reduction_works()
-                    self.all_buckets_reduced = True
-
-        return distributed_data_parallel_hook
-
-    def _queue_reduction(self, bucket_idx):
-        # _queue_reduction will use a seperate CUDA stream to coalesce
-        # the small tensors to achieve more parallelisms, before passing the
-        # coalesced tensor into the c10d CUDA stream for reduction
-        result = dist._queue_reduction(self.process_group,
-                                       self.buckets[bucket_idx],
-                                       self.device_ids)
-        self.reduction_works[bucket_idx] = result[0]
-        self.buckets_coalesced[bucket_idx] = result[1]
-
-    def _sync_reduction_works(self):
-        # Now only work on the first GPU of self.device_ids
-        # _sync_reduction will use a seperate CUDA stream to uncoalesce
-        # the coalesced tensors to achieve more parallelisms
-        for bucket_idx, grads_batch in enumerate(self.buckets):
-            dist._sync_reduction(self.reduction_works[bucket_idx],
-                                 grads_batch[0],
-                                 self.buckets_coalesced[bucket_idx])
-
-        # Reset the module states
-        self.next_bucket = len(self.bucket_sizes) - 1
-        self.ready_buckets_not_reduced = set()
-        self.reduction_works = [None for _ in range(len(self.bucket_sizes))]
-        self.devs_ready = [0 for _ in range(len(self.bucket_sizes))]
-
-        self.buckets = [[[None for _ in range(self.bucket_sizes[i])]
-                        for _ in range(len(self.device_ids))] for i in range(len(self.bucket_sizes))]
-        self.buckets_coalesced = [[] for _ in range(len(self.bucket_sizes))]
-        self.buckets_ready_size = [[0 for _ in range(len(self.device_ids))] for i in range(len(self.bucket_sizes))]