[DDP] Support for multiple backwards (#59359)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59359

Move `prepare_for_backward` into the `_DDPSink` backward pass instead of calling it in the DDP forward pass, so that we can run multiple backwards in DDP with `retain_graph=True`.
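
A minimal Python sketch of the usage this enables (illustrative only: the model, data, and process group setup are assumptions, not part of this diff; it assumes `torch.distributed.init_process_group` has already been called on each rank):

```python
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

# Assumes torch.distributed.init_process_group(...) has already run on this rank.
model = DDP(nn.Linear(10, 10))
inputs = torch.randn(20, 10)

loss = model(inputs).sum()
# First backward reduces gradients across ranks as usual.
loss.backward(retain_graph=True)
# Second backward over the same graph; previously this pattern was not
# supported because prepare_for_backward only ran during the forward pass.
loss.backward()
```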

ghstack-source-id: 131774159

Test Plan: CI

Reviewed By: zhaojuanmao

Differential Revision: D28855226

fbshipit-source-id: 6b7b25d75b7696f5b5629078233433f97663d61c
Rohan Varma
2021-06-18 09:18:28 -07:00
committed by Facebook GitHub Bot
parent 3815a013ed
commit d5df274ea5
6 changed files with 285 additions and 84 deletions

reducer.cpp

@@ -119,6 +119,7 @@ Reducer::Reducer(
gradient_as_bucket_view_(gradient_as_bucket_view),
local_used_maps_reduced_(false),
num_iterations_(0),
num_backward_calls_(0),
num_buckets_ready_(0),
has_rebuilt_bucket_(false),
bucket_bytes_cap_(bucket_bytes_cap),
@@ -299,12 +300,12 @@ bool Reducer::dynamic_graph_find_unused() {
return !static_graph_ && find_unused_parameters_;
}
bool Reducer::static_graph_first_iteration() {
return static_graph_ && num_iterations_ == 1;
bool Reducer::static_graph_first_bwd() {
return static_graph_ && num_backward_calls_ == 1;
}
bool Reducer::static_graph_after_first_iteration() {
return static_graph_ && num_iterations_ > 1;
bool Reducer::static_graph_after_first_bwd() {
return static_graph_ && num_backward_calls_ > 1;
}
void Reducer::initialize_local_used_map() {
@@ -402,8 +403,7 @@ void Reducer::mark_variable_ready_dense(size_t variable_index) {
// Gradient is undefined. When find_unused_parameters=True, ensure it is
// not marked as locally used, otherwise we will be allreducing zeros
// instead of not touching the .grad field of the parameter.
if (this->dynamic_graph_find_unused() ||
this->static_graph_first_iteration()) {
if (this->dynamic_graph_find_unused() || this->static_graph_first_bwd()) {
REDUCER_CHECK(
local_used_maps_[0][variable_index].item<int>() == 0,
logger_,
@@ -483,6 +483,15 @@ void Reducer::push_rebuilt_params_for_all_indices() {
}
void Reducer::push_rebuilt_params(const size_t& index) {
// NOTE: We don't check this in should_rebuild_buckets because that controls
// whether we should push rebuilt params and whether to actually kick off the
// process to rebuild buckets; if we checked this in should_rebuild_buckets,
// the latter would break.
if (all_rebuilt_params_pushed_) {
// We only enter here in the case we are calling multiple backwards with
// retain_graph=True in the iteration before rebuilding buckets.
return;
}
rebuilt_params_.push_back(replicas_[0][index]);
rebuilt_param_indices_.push_back(index);
}
@@ -569,7 +578,7 @@ void Reducer::autograd_hook(size_t index) {
}
// See Note [Skip allreducing local_used_maps_dev]
if (dynamic_graph_find_unused() || static_graph_first_iteration()) {
if (dynamic_graph_find_unused() || static_graph_first_bwd()) {
// Since it gets here, this param has been used for this iteration. We want
// to mark it in local_used_maps_. During no_sync session, the same var can
be set multiple times, which is OK as it does not affect correctness. As
@@ -587,7 +596,7 @@ void Reducer::autograd_hook(size_t index) {
});
}
if (static_graph_first_iteration()) {
if (static_graph_first_bwd()) {
numGradHooksTriggeredMap_[index] += 1;
return;
}
@@ -614,7 +623,7 @@ void Reducer::autograd_hook(size_t index) {
// will be broadcasted and initialized.
// If it is static graph, after 1st iteration, check if a variable
// is ready for communication based on numGradHooksTriggeredMap_.
if (static_graph_after_first_iteration()) {
if (static_graph_after_first_bwd()) {
REDUCER_CHECK(
numGradHooksTriggeredMapPerIteration_[index] > 0,
logger_,
@@ -825,7 +834,7 @@ void Reducer::mark_variable_ready(size_t variable_index) {
}
// Check that all buckets were completed and had their work kicked off.
TORCH_INTERNAL_ASSERT(next_bucket_ == buckets_.size());
if (static_graph_after_first_iteration() && should_rebuild_buckets()) {
if (static_graph_after_first_bwd() && should_rebuild_buckets()) {
for (const auto& unused_index : unused_parameters_) {
push_rebuilt_params(unused_index);
}
@@ -906,9 +915,11 @@ void Reducer::initialize_buckets(
this->rpc_context_.set(ThreadLocalDistAutogradContext::getContextPtr());
#endif
// This shouldn't be called if we're expecting autograd hooks to fire.
// Note that we check !require_finalize instead of !expect_autograd_hooks
// since the latter is set in the forward pass, and the former indicates
// at least one gradient hook has fired and we are in autograd execution.
REDUCER_CHECK(
!expect_autograd_hooks_,
!require_finalize_,
logger_,
"`initialize_buckets` must NOT be called during autograd execution.");
@@ -1061,6 +1072,10 @@ void Reducer::initialize_buckets(
buckets_.push_back(std::move(bucket));
}
// Need to reset bucket.pending and variable.pending as buckets have been
// re-initialized and they must be appropriately set before the next backward
// pass.
reset_bucket_counting();
}
// (see Note: "Gradient Layout Contract" in initialize_buckets).
@@ -1138,14 +1153,29 @@ void Reducer::populate_bucket_views_out(
}
}
void Reducer::prepare_for_forward() {
void Reducer::prepare_for_forward(bool will_run_grad_reduction) {
std::lock_guard<std::mutex> lock(mutex_);
expect_autograd_hooks_ = will_run_grad_reduction;
// Maintain compatibility with the previous behavior, where prepare_for_forward
// was not called at all if will_run_grad_reduction is False.
if (!expect_autograd_hooks_) {
return;
}
num_iterations_++;
if (should_collect_runtime_stats()) {
record_forward_compute_start_time();
}
}
void Reducer::reset_variable_counting() {
// Reset unused parameter accounting.
has_marked_unused_parameters_ = false;
// Reset per iteration marked ready parameters.
perIterationReadyParams_.clear();
// Reset bucket counting.
reset_bucket_counting();
}
void Reducer::reset_bucket_counting() {
next_bucket_ = 0;
// Reset num_buckets_ready_ at the beginning of backward computation
@@ -1227,22 +1257,12 @@ void Reducer::search_unused_parameters(
void Reducer::prepare_for_backward(
const std::vector<torch::autograd::Variable>& outputs) {
std::lock_guard<std::mutex> lock(mutex_);
++num_backward_calls_;
backward_compute_start_time_ = current_time_in_nanos();
if (should_collect_runtime_stats()) {
record_backward_compute_start_time();
}
// Reset accounting.
expect_autograd_hooks_ = true;
reset_bucket_counting();
// Reset unused parameter accounting.
has_marked_unused_parameters_ = false;
// Reset per iteration marked ready parameters.
perIterationReadyParams_.clear();
// If static graph is not set, search graph to detect unused parameters.
// When static graph is set, unused_parameters_ will be detected and will
// not change after 1st iteration.
@@ -1402,9 +1422,9 @@ void Reducer::save_thread_local_state() {
}
void Reducer::finalize_backward() {
// No longer expect autograd hooks to fire after this function returns.
// Note that we don't reset expect_autograd_hooks_ so that we can re-run
// backwards with retain_graph=True.
TORCH_INTERNAL_ASSERT(expect_autograd_hooks_);
expect_autograd_hooks_ = false;
// No longer require call to finalize after this function returns.
TORCH_INTERNAL_ASSERT(require_finalize_);
@@ -1445,7 +1465,7 @@ void Reducer::finalize_backward() {
}
// See Note [Skip allreducing local_used_maps_dev]
if (dynamic_graph_find_unused() || static_graph_first_iteration()) {
if (dynamic_graph_find_unused() || static_graph_first_bwd()) {
// Due to the lazy wait, it is possible that reduction of the current
iteration is still going when the one for the next iteration gets kicked off.
// In such a case, we want to wait explicitly to make sure the reduction does
@@ -1466,6 +1486,15 @@ void Reducer::finalize_backward() {
local_used_maps_reduced_ = false;
}
// Reset various accounting variables, including bucket counting, to ensure we
// can appropriately launch allreduce for each bucket in the next backward pass.
reset_variable_counting();
// If we populated rebuilt params list in this backward call, avoid
// repopulating in subsequent backward calls. In particular this is needed to
// avoid re-pushing parameters when calling multiple backwards with
// retain_graph=True.
all_rebuilt_params_pushed_ = all_rebuilt_params_pushed_ || !rebuilt_params_.empty();
if (should_collect_runtime_stats()) {
record_backward_comm_end_time();
}

reducer.hpp

@@ -83,9 +83,9 @@ class TORCH_API Reducer {
// a call to this function can simply be omitted.
void prepare_for_backward(const std::vector<at::Tensor>& outputs);
// Called at the begginning of forward() inside DistributedDataParallel,
// Called at the beginning of forward() inside DistributedDataParallel,
// right now it captures the starting time of the forward pass in each iteration.
void prepare_for_forward();
void prepare_for_forward(bool will_run_grad_reduction = true);
// Returns the relative time in nanoseconds when gradients were ready,
// with respect to the time `prepare_for_backward` was called. The outer
@@ -157,6 +157,12 @@ class TORCH_API Reducer {
// Delay all reduce to be after all gradients' calculation is complete.
void delay_all_reduce();
bool static_graph_first_bwd();
// Resets various counters the Reducer uses to manage internal state, such as
// the buckets that need to be reduced across workers.
void reset_variable_counting();
// Weak reference to associated DDP logger. The reference is weak to avoid
// refcycle between reducer and logger.
void set_logger(std::weak_ptr<c10d::Logger> logger);
@@ -178,6 +184,8 @@ class TORCH_API Reducer {
std::vector<std::pair<uintptr_t, std::shared_ptr<torch::autograd::Node>>>
hooks_;
// Whether we need to run autograd hooks (only false if the user runs under
// the no_grad or no_sync context manager).
bool expect_autograd_hooks_;
bool require_finalize_;
size_t next_bucket_;
@@ -365,7 +373,13 @@ class TORCH_API Reducer {
std::vector<VariableLocator> variable_locators_;
// track the number of iterations to synchronize grads in training so far.
// This is the number of calls to the forward pass, not necessarily equal to
// the number of calls to the backward pass.
long num_iterations_;
// Number of times backward() has been called. This is mainly used for static
// graph training to know when to populate the map of how many times grad
// hooks have been triggered.
long num_backward_calls_;
// track the number of buckets that have been ready for
// communication calls like allReduce or communication hooks.
int num_buckets_ready_;
@@ -392,7 +406,13 @@ class TORCH_API Reducer {
bool is_multi_device_module_ = false;
// Following variables are to help build dynamic bucket order
// Whether the process of rebuilding buckets has occurred.
bool has_rebuilt_bucket_;
// Flag indicating all rebuilt param indices have been pushed. This is needed
// because there can be multiple calls to backward with retain_graph=True
// without a forward that actually rebuilds the buckets. In this case, we use
// this flag to avoid pushing parameters multiple times.
bool all_rebuilt_params_pushed_{false};
std::vector<at::Tensor> rebuilt_params_;
std::vector<int64_t> rebuilt_param_indices_;
const int64_t bucket_bytes_cap_;
@@ -457,8 +477,7 @@ class TORCH_API Reducer {
// get current cuda stream
const c10::Stream get_current_stream();
bool dynamic_graph_find_unused();
bool static_graph_first_iteration();
bool static_graph_after_first_iteration();
bool static_graph_after_first_bwd();
// comm_hook_ is used to access the DDP communication hook if registered.
std::unique_ptr<CommHookInterface> comm_hook_;
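
For reference, the pattern the summary describes amounts to wrapping the forward outputs in a custom autograd `Function` whose backward performs the per-backward bookkeeping, so that bookkeeping re-runs for every `retain_graph=True` backward call. Below is a simplified, self-contained sketch of that idea; `_ToyReducer` and `_SinkExample` are hypothetical stand-ins, not the actual `_DDPSink`/`Reducer` implementation.

```python
import torch


class _ToyReducer:
    """Hypothetical stand-in for the C++ Reducer's per-backward bookkeeping."""

    def __init__(self):
        self.num_backward_calls = 0

    def prepare_for_backward(self):
        # Mirrors Reducer::prepare_for_backward bumping num_backward_calls_.
        self.num_backward_calls += 1


class _SinkExample(torch.autograd.Function):
    """Identity in forward; triggers reducer bookkeeping in backward, so it
    fires once per backward call rather than once per forward call."""

    @staticmethod
    def forward(ctx, reducer, *inputs):
        ctx.reducer = reducer
        return inputs  # pass activations through unchanged

    @staticmethod
    def backward(ctx, *grad_outputs):
        ctx.reducer.prepare_for_backward()
        # No gradient for the reducer argument; pass gradients through for inputs.
        return (None, *grad_outputs)


reducer = _ToyReducer()
x = torch.randn(4, requires_grad=True)
(out,) = _SinkExample.apply(reducer, x * 2)
loss = out.sum()
loss.backward(retain_graph=True)
loss.backward()
assert reducer.num_backward_calls == 2  # bookkeeping ran once per backward
```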