Make DistributedDataParallel use new reducer (#18953)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18953

This removes the Python-side bucketing code from DistributedDataParallel
and replaces it with calls to the new C++-based bucketing and reducing
code. To confirm this works correctly, we ran a test with both the
previous and the new implementation and confirmed that they are
numerically equivalent.

Performance improves by a couple of percent or more, including in the
single-machine, multi-GPU runs.

Closes #13273.
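For orientation, here is a minimal usage sketch (not part of this patch) of the code path the new reducer covers; the backend choice, model, and data below are placeholders, and the snippet assumes it runs in a per-rank process with the usual MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE environment set up.

# Minimal usage sketch (illustrative only; not from this patch).
import torch
import torch.distributed as dist
import torch.nn as nn

dist.init_process_group(backend="gloo")  # backend is illustrative

model = nn.Linear(10, 10)
ddp = nn.parallel.DistributedDataParallel(model)
optimizer = torch.optim.SGD(ddp.parameters(), lr=0.1)
inputs, target = torch.randn(4, 10), torch.randn(4, 10)

optimizer.zero_grad()
output = ddp(inputs)           # forward() now registers the output tensors with the reducer
loss = nn.functional.mse_loss(output, target)
loss.backward()                # autograd hooks feed gradients to the C++ reducer,
                               # which buckets and allreduces them
optimizer.step()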

Reviewed By: mrshenli

Differential Revision: D14580911

fbshipit-source-id: 44e76f8b0b7e58dd6c91644e3df4660ca2ee4ae2
Pieter Noordhuis
2019-04-15 12:24:43 -07:00
committed by Facebook Github Bot
parent 6ed57e052d
commit a0263ec047
5 changed files with 219 additions and 178 deletions


@@ -1934,6 +1934,52 @@ class ReducerTest(TestCase):
             optimizer.step()
 
 
+class ComputeBucketAssignmentTest(TestCase):
+    def test_single_limit_single_dtype(self):
+        tensors = [
+            torch.empty([100], dtype=torch.float),
+            torch.empty([200], dtype=torch.float),
+            torch.empty([100], dtype=torch.float),
+            torch.empty([50], dtype=torch.float),
+        ]
+        result = dist._compute_bucket_assignment_by_size(tensors, [400])
+        self.assertEqual([[0], [1], [2], [3]], result)
+
+    def test_single_limit_multi_dtype(self):
+        tensors = [
+            torch.empty([50], dtype=torch.float),
+            torch.empty([25], dtype=torch.double),
+            torch.empty([50], dtype=torch.float),
+            torch.empty([25], dtype=torch.double),
+            torch.empty([50], dtype=torch.float),
+            torch.empty([25], dtype=torch.double),
+        ]
+        result = dist._compute_bucket_assignment_by_size(tensors, [400])
+        self.assertEqual([[0, 2], [1, 3], [4], [5]], result)
+
+    def test_multi_limit_single_dtype(self):
+        tensors = [
+            torch.empty([10], dtype=torch.float),
+            torch.empty([10], dtype=torch.float),
+            torch.empty([10], dtype=torch.float),
+            torch.empty([10], dtype=torch.float),
+        ]
+        result = dist._compute_bucket_assignment_by_size(tensors, [40, 80])
+        self.assertEqual([[0], [1, 2], [3]], result)
+
+    def test_multi_limit_multi_dtype(self):
+        tensors = [
+            torch.empty([50], dtype=torch.float),
+            torch.empty([25], dtype=torch.double),
+            torch.empty([50], dtype=torch.float),
+            torch.empty([25], dtype=torch.double),
+            torch.empty([50], dtype=torch.float),
+            torch.empty([25], dtype=torch.double),
+        ]
+        result = dist._compute_bucket_assignment_by_size(tensors, [200, 400])
+        self.assertEqual([[0], [1], [2, 4], [3, 5]], result)
+
+
 if __name__ == '__main__':
     assert not torch.cuda._initialized, "test_distributed must not have initialized CUDA context on main process"
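The expected assignments in these tests follow directly from element sizes; a short worked check (illustrative only, assuming 4-byte torch.float and 8-byte torch.double elements):

# Illustrative arithmetic behind the expected assignments above.
#
# test_single_limit_single_dtype: byte sizes are 400, 800, 400, 200 against a
# single 400-byte limit. A bucket is closed as soon as it reaches the limit,
# so the first three tensors each close a bucket by themselves; the trailing
# 200-byte tensor is flushed into its own bucket, giving [[0], [1], [2], [3]].
single_dtype_bytes = [100 * 4, 200 * 4, 100 * 4, 50 * 4]

# test_single_limit_multi_dtype: floats and doubles are bucketed separately.
# Two 200-byte floats fill a 400-byte bucket ([0, 2]), as do two 200-byte
# doubles ([1, 3]); the leftovers are flushed as [4] and [5], and buckets are
# then sorted by their smallest tensor index.
multi_dtype_bytes = [50 * 4, 25 * 8, 50 * 4, 25 * 8, 50 * 4, 25 * 8]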


@@ -500,6 +500,13 @@ They are used in specifying strategies for reduction collectives, e.g.,
       py::call_guard<py::gil_scoped_release>());
 #endif
 
+  module.def(
+      "_compute_bucket_assignment_by_size",
+      &::c10d::compute_bucket_assignment_by_size,
+      py::arg("tensors"),
+      py::arg("bucket_size"),
+      py::call_guard<py::gil_scoped_release>());
+
   Py_RETURN_TRUE;
 }
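Given a build with distributed support, the binding can then be exercised directly from Python; a tiny hypothetical example (the bucket size limits are in bytes):

# Hypothetical call through the new binding.
import torch
import torch.distributed as dist

tensors = [torch.empty(100, dtype=torch.float) for _ in range(2)]  # 400 bytes each
print(dist._compute_bucket_assignment_by_size(tensors, [1024 * 1024]))
# Both tensors fit under the 1 MiB limit, so they share a bucket: [[0, 1]]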


@@ -7,6 +7,7 @@
 #include <torch/csrc/autograd/function_hook.h>
 #include <torch/csrc/autograd/functions/accumulate_grad.h>
 #include <torch/csrc/autograd/profiler.h>
+#include <torch/csrc/utils/hash.h>
 #include <torch/csrc/utils/memory.h>
 
 namespace c10d {
@@ -361,6 +362,13 @@ void Reducer::prepare_for_backward(
     bucket.pending = bucket.replicas.size();
   }
 
+  // If no outputs are specified, we assume that autograd hooks for ALL
+  // variables will be called, and we don't have to search the autograd graph
+  // for presence of these hooks.
+  if (outputs.empty()) {
+    return;
+  }
+
   // Seed queue with the grad functions of all outputs.
   for (const auto& output : outputs) {
     const auto& grad_fn = output.grad_fn();
@@ -433,4 +441,106 @@ void Reducer::finalize_backward() {
   }
 }
 
+namespace {
+
+// Tensors may be coalesced into buckets. Buckets must contain tensors of
+// the same type, on the same device, so a bucket can be identified by a
+// composite key of a tensor's type identifier and its device.
+struct BucketKey {
+  BucketKey(c10::ScalarType type, c10::Device device)
+      : type(std::move(type)), device(std::move(device)) {}
+
+  const c10::ScalarType type;
+  const c10::Device device;
+
+  // See torch/csrc/utils/hash.h for dispatch code.
+  static size_t hash(const BucketKey& key) {
+    return torch::get_hash(key.type, key.device);
+  }
+};
+
+inline bool operator==(const BucketKey& lhs, const BucketKey& rhs) {
+  return lhs.type == rhs.type && lhs.device == rhs.device;
+}
+
+} // namespace
+
+// This is equivalent to take_tensors but returns indices into the
+// tensor list argument for bucket assignment. Also, it is aware
+// of device placement and will not allow buckets to span devices.
+std::vector<std::vector<size_t>> compute_bucket_assignment_by_size(
+    const std::vector<at::Tensor>& tensors,
+    std::vector<size_t> bucket_size_limits) {
+  std::vector<std::vector<size_t>> result;
+  result.reserve(tensors.size());
+
+  // Keep iterator into the size_limit vector by tensor type and device.
+  // This is done so that we can use the consecutive bucket limits per type.
+  std::unordered_map<
+      BucketKey,
+      std::vector<size_t>::iterator,
+      torch::hash<BucketKey>>
+      bucket_size_limit_iterators;
+
+  // Local accumulator type for a single bucket.
+  struct BucketAccumulator {
+    std::vector<size_t> indices;
+    size_t size = 0;
+  };
+
+  // Keep vector of indices and size accumulator by tensor type and device.
+  std::unordered_map<BucketKey, BucketAccumulator, torch::hash<BucketKey>>
+      buckets;
+
+  for (size_t i = 0; i < tensors.size(); i++) {
+    const auto& tensor = tensors[i];
+    AT_ASSERTM(!tensor.is_sparse(), "No support for sparse tensors.");
+
+    auto key = BucketKey(tensor.scalar_type(), tensor.device());
+    auto& bucket = buckets[key];
+    bucket.indices.push_back(i);
+    bucket.size += tensor.numel() * tensor.element_size();
+
+    // Initialize bucket size limit iterator if necessary.
+    if (bucket_size_limit_iterators.count(key) == 0) {
+      bucket_size_limit_iterators[key] = bucket_size_limits.begin();
+    }
+
+    auto& bucket_size_limit_iterator = bucket_size_limit_iterators[key];
+    const auto bucket_size_limit = *bucket_size_limit_iterator;
+    if (bucket.size >= bucket_size_limit) {
+      result.emplace_back(std::move(bucket.indices));
+      bucket = BucketAccumulator();
+
+      // Advance to the next bucket size limit for this type/device.
+      auto next = bucket_size_limit_iterator + 1;
+      if (next != bucket_size_limits.end()) {
+        bucket_size_limit_iterator = next;
+      }
+    }
+  }
+
+  // Add remaining buckets.
+  for (auto& it : buckets) {
+    auto& bucket = it.second;
+    if (!bucket.indices.empty()) {
+      result.emplace_back(std::move(bucket.indices));
+    }
+  }
+
+  // Sort resulting buckets by the minimum tensor index they include.
+  // We assume that the order of the tensors is the order in which they are
+  // used (or the reverse order in which their gradients are produced).
+  // This sorting step ensures that the buckets are ready in consecutive order.
+  std::sort(
+      result.begin(),
+      result.end(),
+      [](const std::vector<size_t>& a, const std::vector<size_t>& b) {
+        const auto amin = std::min_element(a.begin(), a.end());
+        const auto bmin = std::min_element(b.begin(), b.end());
+        return *amin < *bmin;
+      });
+
+  return result;
+}
+
 } // namespace c10d
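For readers more comfortable with Python, here is a sketch of the same greedy, per-(dtype, device) assignment; it mirrors the C++ above but is only an illustration, not code from this patch.

# Illustrative Python sketch of compute_bucket_assignment_by_size.
from collections import defaultdict

def compute_bucket_assignment_by_size(tensors, bucket_size_limits):
    result = []
    # Per (dtype, device) state: accumulated indices/bytes plus the index of
    # the size limit currently in effect for that key.
    state_by_key = defaultdict(lambda: {"indices": [], "size": 0, "limit_idx": 0})
    for i, tensor in enumerate(tensors):
        state = state_by_key[(tensor.dtype, tensor.device)]
        state["indices"].append(i)
        state["size"] += tensor.numel() * tensor.element_size()
        if state["size"] >= bucket_size_limits[state["limit_idx"]]:
            # Close the bucket and move on to the next (larger) limit, if any.
            result.append(state["indices"])
            state["indices"], state["size"] = [], 0
            if state["limit_idx"] + 1 < len(bucket_size_limits):
                state["limit_idx"] += 1
    # Flush partially filled buckets.
    for state in state_by_key.values():
        if state["indices"]:
            result.append(state["indices"])
    # Order buckets by the smallest tensor index they contain, so they are
    # ready in roughly the order in which gradients become available.
    result.sort(key=min)
    return result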


@@ -139,4 +139,8 @@ class Reducer {
   std::vector<std::vector<int64_t>> backward_stats_;
 };
 
+std::vector<std::vector<size_t>> compute_bucket_assignment_by_size(
+    const std::vector<at::Tensor>& tensors,
+    std::vector<size_t> bucket_size);
+
 } // namespace c10d


@@ -209,15 +209,19 @@ class DistributedDataParallel(Module):
         self.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
         self.output_device = _get_device_index(output_device, True)
         self.broadcast_buffers = broadcast_buffers
-        self.check_reduction = check_reduction
+        if check_reduction:
+            # This argument is no longer used since the reducer
+            # will ensure reduction completes even if some parameters
+            # do not receive gradients.
+            pass
 
         MB = 1024 * 1024
 
         # used for intra-node param sync and inter-node sync as well
-        self.broadcast_bucket_size = 250 * MB
+        self.broadcast_bucket_size = int(250 * MB)
 
         # reduction bucket size
-        self.bucket_bytes_cap = bucket_cap_mb * MB
+        self.bucket_bytes_cap = int(bucket_cap_mb * MB)
 
         # Sync params and buffers
         module_states = list(self.module.state_dict().values())
@@ -254,60 +258,26 @@ class DistributedDataParallel(Module):
         self.modules_params = [list(m.parameters()) for m in self._module_copies]
         self.modules_buffers = [list(m.buffers()) for m in self._module_copies]
 
-        # This is a triply-nested list where the "dimensions" are: devices, buckets, bucket_elems
-        param_buckets = []
-
-        # Split the parameters into buckets and by types as well
-        # We only need to bucket and reduce parameters that require grad and
-        # this is also true for backward since only the backward hooks for
-        # parameters that require grad will be registered with gradient
-        # reduction functions
-        params_to_bucket = [[] for _ in self._module_copies]
-        for dev_idx, m in enumerate(self._module_copies):
-            for p in m.parameters():
-                if p.requires_grad:
-                    params_to_bucket[dev_idx].append(p)
-
-        param_buckets = [dist._dist_bucket_tensors(dev_params_to_bucket,
-                                                   int(self.bucket_bytes_cap),
-                                                   fine_grained=False)
-                         for dev_params_to_bucket in params_to_bucket]
-
-        self.bucket_sizes = []
-        self.bucket_map = {}
-
-        # We transpose param_buckets, so the loop is over buckets.
-        # param_buckets_tuple is a doubly-nested list with "dims": devices, bucket_elems
-        for bucket_idx, param_buckets_tuple in enumerate(zip(*param_buckets)):
-            self.bucket_sizes.append(0)
-            # Now, we transpose again, so we iterate over bucket_elems, but getting tuples
-            # of params from each device.
-            for param_tuple in zip(*param_buckets_tuple):
-                if not param_tuple[0].requires_grad:
-                    continue
-                for p in param_tuple:
-                    self.bucket_map[p] = (bucket_idx, self.bucket_sizes[bucket_idx])
-                self.bucket_sizes[bucket_idx] += 1
-
-        self.buckets = [[[None for _ in range(self.bucket_sizes[i])]
-                        for _ in range(len(self.device_ids))] for i in range(len(self.bucket_sizes))]
-        # The number of params ready in each bucket
-        self.buckets_ready_size = [[0 for _ in range(len(self.device_ids))] for i in range(len(self.bucket_sizes))]
-        # coalesced bucket for only device 0
-        self.buckets_coalesced = [[] for _ in range(len(self.bucket_sizes))]
-        # We will always reduce the bucket following the reverse order,
-        # that is, always following the order of: n - 1, n - 2, ..., 0
-        self.next_bucket = len(self.bucket_sizes) - 1
-        # When all buckets are reduced, this will be set to True. This flag is
-        # useful for sanity checks to ensure that each iteration's backward has
-        # always reduced all buckets
-        self.all_buckets_reduced = False
-        self.check_previous_reduction = False
-        self.ready_buckets_not_reduced = set()
-        self.reduction_works = [None for _ in range(len(self.bucket_sizes))]
-        self.devs_ready = [0 for _ in range(len(self.bucket_sizes))]
-        self._register_grad_hooks()
+        param_list = [
+            list(filter(lambda p: p.requires_grad, module.parameters()))
+            for module in self._module_copies]
+
+        # The bucket size limit is specified in the constructor.
+        # Additionally, we allow for a single small bucket for parameters
+        # that are defined first, such that their gradients don't spill into
+        # a much larger bucket, adding unnecessary latency after gradient
+        # computation finishes. Experiments showed 1MB is a reasonable value.
+        bucket_indices = dist._compute_bucket_assignment_by_size(
+            param_list[0],
+            [1024 * 1024, self.bucket_bytes_cap])
+
+        # Note: reverse list of buckets because we want to approximate the
+        # order in which their gradients are produced, and assume they
+        # are used in the forward pass in the order they are defined.
+        self.reducer = dist.Reducer(
+            param_list,
+            list(reversed(bucket_indices)),
+            self.process_group)
 
         # passing a handle to torch.nn.SyncBatchNorm layer
         self._passing_sync_batchnorm_handle(self._module_copies)
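To make the reversal above concrete, a small illustration with made-up bucket indices (not code from this patch):

# Illustration only: suppose the assignment for the first replica came back as
bucket_indices = [[0, 1], [2, 3, 4], [5, 6, 7, 8]]
# Low-index parameters are defined (and used) first, so their gradients are
# produced last during backward. Reversing the list means the reducer's
# bucket 0 holds the parameters whose gradients typically arrive first:
reversed_indices = list(reversed(bucket_indices))
# reversed_indices == [[5, 6, 7, 8], [2, 3, 4], [0, 1]]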
@@ -315,15 +285,13 @@ class DistributedDataParallel(Module):
     def __getstate__(self):
         self._check_default_group()
         attrs = copy.copy(self.__dict__)
-        del attrs['process_group'], \
-            attrs['default_streams'], \
-            attrs['_grad_accs']
+        del attrs['process_group']
+        del attrs['reducer']
         return attrs
 
     def __setstate__(self, state):
         # If serializable, then the process group should be the default one
         self.process_group = _get_default_group()
-        self.check_previous_reduction = False
         super(DistributedDataParallel, self).__setstate__(state)
         self._ddp_init_helper()
@@ -342,32 +310,28 @@ class DistributedDataParallel(Module):
                                "init_process_group and have not passed "
                                "process_group argument to DDP constructor")
 
-    def _check_previous_reduction(self):
-        if not self.training:
-            return
-        # self.check_previous_reduction will be False in the first iteration
-        # and is then toggled to True for all future iterations.
-        if self.check_previous_reduction is False:
-            self.check_previous_reduction = True
-        else:
-            if not self.all_buckets_reduced:
-                raise RuntimeError("Not all gradients have been reduced from "
-                                   "the backward of the previous iteration. "
-                                   "This is an unexpected and fatal error. "
-                                   "Please check and ensure that the model's "
-                                   "parameters are not changed after you wrap "
-                                   "up the model with DistributedDataParallel.")
-        self.all_buckets_reduced = False
-
     def forward(self, *inputs, **kwargs):
-        if self.check_reduction:
-            self._check_previous_reduction()
         inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
         self._sync_params()
         if len(self.device_ids) == 1:
-            return self.module(*inputs[0], **kwargs[0])
-        outputs = self.parallel_apply(self._module_copies[:len(inputs)], inputs, kwargs)
-        return self.gather(outputs, self.output_device)
+            output = self.module(*inputs[0], **kwargs[0])
+        else:
+            outputs = self.parallel_apply(self._module_copies[:len(inputs)], inputs, kwargs)
+            output = self.gather(outputs, self.output_device)
+
+        # We'll return the output object verbatim since it is a freeform object.
+        # We need to find any tensors in this object, though, because we need to
+        # figure out which parameters were used during this forward pass, to
+        # ensure we short circuit reduction for any unused parameters.
+        output_tensors = []
+        if isinstance(output, torch.Tensor):
+            output_tensors = [output]
+        if isinstance(output, (list, tuple)):
+            def istensor(obj):
+                return isinstance(obj, torch.Tensor)
+            output_tensors = list(filter(istensor, output))
+        self.reducer.prepare_for_backward(output_tensors)
+        return output
 
     def scatter(self, inputs, kwargs, device_ids):
         return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
@@ -379,7 +343,6 @@ class DistributedDataParallel(Module):
         return gather(outputs, output_device, dim=self.dim)
 
     def train(self, mode=True):
-        self.check_previous_reduction = False
         super(DistributedDataParallel, self).train(mode)
         for module in self._module_copies[1:]:
             module.train(mode)
@@ -398,6 +361,13 @@ class DistributedDataParallel(Module):
                                               self.modules_params[1:]):
                 for tensor, param in zip(tensors, module_params):
                     param.set_(tensor)
+                    # Assume we have just run the optimizer and zeroed the
+                    # grads of the parameters on the root model. We need
+                    # to zero the grads on all model replicas as well.
+                    # This snippet is copied from torch.optim.Optimizer.
+                    if param.grad is not None:
+                        param.grad.detach_()
+                        param.grad.zero_()
 
         # module buffer sync
         if self.broadcast_buffers and len(self.modules_buffers[0]) > 0:
@@ -419,99 +389,3 @@ class DistributedDataParallel(Module):
             for layer in module.modules():
                 if isinstance(layer, torch.nn.modules.SyncBatchNorm):
                     layer._specify_ddp_gpu_num(len(self.device_ids))
-
-    def _register_grad_hooks(self):
-        self._grad_accs = []  # need to keep them in scope
-
-        # default stream tracking to launch nccl reduce kernels
-        self.default_streams = []
-        for dev_id in self.device_ids:
-            with torch.cuda.device(dev_id):
-                self.default_streams.append(torch.cuda.current_stream())
-
-        for device_idx, module in enumerate(self._module_copies):
-            for p in module.parameters():
-                if p.requires_grad:
-                    p_tmp = p.expand_as(p)
-                    grad_acc = p_tmp.grad_fn.next_functions[0][0]
-                    grad_acc.register_hook(self._make_param_hook(p, device_idx))
-                    self._grad_accs.append(grad_acc)
-
-    def _make_param_hook(self, param, device_idx):
-        bucket_idx, bucket_offset = self.bucket_map[param]
-
-        def distributed_data_parallel_hook(*unused):
-            if param.grad.requires_grad:
-                raise RuntimeError("DistributedDataParallel only works "
-                                   "with gradients that don't require grad")
-            bucket = self.buckets[bucket_idx][device_idx]
-            bucket[bucket_offset] = param.grad.data
-            self.buckets_ready_size[bucket_idx][device_idx] += 1
-
-            # We can flush these and save memory for replicas
-            if device_idx > 0:
-                param.grad = None
-                with torch.no_grad():
-                    param.set_()
-
-            # Current device's bucket is full
-            if self.buckets_ready_size[bucket_idx][device_idx] == self.bucket_sizes[bucket_idx]:
-                self.devs_ready[bucket_idx] += 1
-                if self.devs_ready[bucket_idx] < len(self.device_ids):
-                    return
-
-                # Now all devices' buckets with index: bucket_idx are ready
-                if bucket_idx == self.next_bucket:
-                    self._queue_reduction(bucket_idx)
-                    self.next_bucket -= 1
-                    # Now reduce anything that is ready but not yet reduced
-                    if len(self.ready_buckets_not_reduced) > 0:
-                        sorted_todo = sorted(self.ready_buckets_not_reduced, reverse=True)
-                        for i in sorted_todo:
-                            # Nothing can be reduced now
-                            if i < self.next_bucket:
-                                break
-                            self._queue_reduction(i)
-                            self.ready_buckets_not_reduced.remove(i)
-                            if i == self.next_bucket:
-                                self.next_bucket -= 1
-                else:
-                    self.ready_buckets_not_reduced.add(bucket_idx)
-
-                # When all devices' buckets are reduced
-                if self.next_bucket == -1:
-                    # A final sync for all the reduction works
-                    self._sync_reduction_works()
-                    self.all_buckets_reduced = True
-
-        return distributed_data_parallel_hook
-
-    def _queue_reduction(self, bucket_idx):
-        # _queue_reduction will use a separate CUDA stream to coalesce
-        # the small tensors to achieve more parallelism, before passing the
-        # coalesced tensor into the c10d CUDA stream for reduction
-        result = dist._queue_reduction(self.process_group,
-                                       self.buckets[bucket_idx],
-                                       self.device_ids)
-        self.reduction_works[bucket_idx] = result[0]
-        self.buckets_coalesced[bucket_idx] = result[1]
-
-    def _sync_reduction_works(self):
-        # Now only work on the first GPU of self.device_ids
-        # _sync_reduction will use a separate CUDA stream to uncoalesce
-        # the coalesced tensors to achieve more parallelism
-        for bucket_idx, grads_batch in enumerate(self.buckets):
-            dist._sync_reduction(self.reduction_works[bucket_idx],
-                                 grads_batch[0],
-                                 self.buckets_coalesced[bucket_idx])
-
-        # Reset the module states
-        self.next_bucket = len(self.bucket_sizes) - 1
-        self.ready_buckets_not_reduced = set()
-        self.reduction_works = [None for _ in range(len(self.bucket_sizes))]
-        self.devs_ready = [0 for _ in range(len(self.bucket_sizes))]
-
-        self.buckets = [[[None for _ in range(self.bucket_sizes[i])]
-                        for _ in range(len(self.device_ids))] for i in range(len(self.bucket_sizes))]
-        self.buckets_coalesced = [[] for _ in range(len(self.bucket_sizes))]
-        self.buckets_ready_size = [[0 for _ in range(len(self.device_ids))] for i in range(len(self.bucket_sizes))]