[DDP] Support for multiple backwards (#59359)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59359

Move `prepare_for_backward` into the `_DDPSink` backward instead of calling it in the DDP forward pass, so that we can run multiple backwards in DDP with `retain_graph=True`.

ghstack-source-id: 131774159

Test Plan: CI

Reviewed By: zhaojuanmao

Differential Revision: D28855226

fbshipit-source-id: 6b7b25d75b7696f5b5629078233433f97663d61c
Committed by: Facebook GitHub Bot
Parent: 3815a013ed
Commit: d5df274ea5
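For context, the usage pattern this change enables looks roughly like the sketch below. This is not part of the commit; the process-group setup, model, and losses are placeholder assumptions.

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Assumes a launcher such as torchrun has set the rendezvous environment
# variables; "gloo" is used here only to keep the sketch CPU-only.
dist.init_process_group("gloo")
model = DDP(torch.nn.Linear(10, 10))

out = model(torch.randn(20, 10))
loss1 = out.sum()
loss2 = (out * out).mean()

# Previously the Reducer was armed once per forward (prepare_for_backward was
# called from DDP's forward), so a second backward over the same graph could
# not re-run gradient reduction. With prepare_for_backward now invoked from
# _DDPSink's backward, each call below re-arms the Reducer and both backwards
# allreduce gradients.
loss1.backward(retain_graph=True)
loss2.backward()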
Changes to reducer.cpp (Reducer implementation):

@@ -119,6 +119,7 @@ Reducer::Reducer(
       gradient_as_bucket_view_(gradient_as_bucket_view),
       local_used_maps_reduced_(false),
       num_iterations_(0),
+      num_backward_calls_(0),
       num_buckets_ready_(0),
       has_rebuilt_bucket_(false),
       bucket_bytes_cap_(bucket_bytes_cap),
@@ -299,12 +300,12 @@ bool Reducer::dynamic_graph_find_unused() {
   return !static_graph_ && find_unused_parameters_;
 }
 
-bool Reducer::static_graph_first_iteration() {
-  return static_graph_ && num_iterations_ == 1;
+bool Reducer::static_graph_first_bwd() {
+  return static_graph_ && num_backward_calls_ == 1;
 }
 
-bool Reducer::static_graph_after_first_iteration() {
-  return static_graph_ && num_iterations_ > 1;
+bool Reducer::static_graph_after_first_bwd() {
+  return static_graph_ && num_backward_calls_ > 1;
 }
 
 void Reducer::initialize_local_used_map() {
@@ -402,8 +403,7 @@ void Reducer::mark_variable_ready_dense(size_t variable_index) {
       // Gradient is undefined. When find_unused_parameters=True, ensure it is
       // not marked as locally used, otherwise we will be allreducing zero's
       // instead of not touching .grad field of parameter.
-      if (this->dynamic_graph_find_unused() ||
-          this->static_graph_first_iteration()) {
+      if (this->dynamic_graph_find_unused() || this->static_graph_first_bwd()) {
         REDUCER_CHECK(
             local_used_maps_[0][variable_index].item<int>() == 0,
             logger_,
@@ -483,6 +483,15 @@ void Reducer::push_rebuilt_params_for_all_indices() {
 }
 
 void Reducer::push_rebuilt_params(const size_t& index) {
+  // NOTE: We don't check this in should_rebuild_bucket because that controls
+  // whether we should push rebuilt params and whether to actually kick off
+  // process to rebuild buckets, if we check this in should_rebuild_buckets then
+  // the latter would break.
+  if (all_rebuilt_params_pushed_) {
+    // We only enter here in the case we are calling multiple backwards with
+    // retain_graph=True in the iteration before rebuilding buckets.
+    return;
+  }
   rebuilt_params_.push_back(replicas_[0][index]);
   rebuilt_param_indices_.push_back(index);
 }
@@ -569,7 +578,7 @@ void Reducer::autograd_hook(size_t index) {
   }
 
   // See Note [Skip allreducing local_used_maps_dev]
-  if (dynamic_graph_find_unused() || static_graph_first_iteration()) {
+  if (dynamic_graph_find_unused() || static_graph_first_bwd()) {
     // Since it gets here, this param has been used for this iteration. We want
     // to mark it in local_used_maps_. During no_sync session, the same var can
     // be set multiple times, which is OK as does not affect correctness. As
@@ -587,7 +596,7 @@ void Reducer::autograd_hook(size_t index) {
     });
   }
 
-  if (static_graph_first_iteration()) {
+  if (static_graph_first_bwd()) {
     numGradHooksTriggeredMap_[index] += 1;
     return;
   }
@@ -614,7 +623,7 @@ void Reducer::autograd_hook(size_t index) {
   // will be broadcasted and initialized.
   // If it is static graph, after 1st iteration, check if a variable
   // is ready for communication based on numGradHooksTriggeredMap_.
-  if (static_graph_after_first_iteration()) {
+  if (static_graph_after_first_bwd()) {
     REDUCER_CHECK(
         numGradHooksTriggeredMapPerIteration_[index] > 0,
         logger_,
@@ -825,7 +834,7 @@ void Reducer::mark_variable_ready(size_t variable_index) {
     }
     // Check that all buckets were completed and had their work kicked off.
     TORCH_INTERNAL_ASSERT(next_bucket_ == buckets_.size());
-    if (static_graph_after_first_iteration() && should_rebuild_buckets()) {
+    if (static_graph_after_first_bwd() && should_rebuild_buckets()) {
       for (const auto& unused_index : unused_parameters_) {
         push_rebuilt_params(unused_index);
       }
@@ -906,9 +915,11 @@ void Reducer::initialize_buckets(
   this->rpc_context_.set(ThreadLocalDistAutogradContext::getContextPtr());
 #endif
 
   // This shouldn't be called if we're expecting autograd hooks to fire.
+  // Note that we check !require_finalize instead of !expect_autograd_hooks
+  // since the latter is set in forward pass, and the former indicates
+  // at least one gradient hook has fired and we are in autograd execution.
   REDUCER_CHECK(
-      !expect_autograd_hooks_,
+      !require_finalize_,
       logger_,
       "`initialize_buckets` must NOT be called during autograd execution.");
 
@@ -1061,6 +1072,10 @@ void Reducer::initialize_buckets(
 
     buckets_.push_back(std::move(bucket));
   }
+  // Need to reset bucket.pending and variable.pending as buckets have been
+  // re-initialized and they must be appropriately set before the next backward
+  // pass.
+  reset_bucket_counting();
 }
 
 // (see Note: "Gradient Layout Contract" in initialize_buckets).
@@ -1138,14 +1153,29 @@ void Reducer::populate_bucket_views_out(
   }
 }
 
-void Reducer::prepare_for_forward() {
+void Reducer::prepare_for_forward(bool will_run_grad_reduction) {
   std::lock_guard<std::mutex> lock(mutex_);
+  expect_autograd_hooks_ = will_run_grad_reduction;
+  // To maintain compatibility with current version, where prepare_for_forward
+  // is not called if will_run_grad_reduction is False.
+  if (!expect_autograd_hooks_) {
+    return;
+  }
   num_iterations_++;
   if (should_collect_runtime_stats()) {
     record_forward_compute_start_time();
   }
 }
 
+void Reducer::reset_variable_counting() {
+  // Reset unused parameter accounting.
+  has_marked_unused_parameters_ = false;
+  // Reset per iteration marked ready parameters.
+  perIterationReadyParams_.clear();
+  // Reset bucket counting.
+  reset_bucket_counting();
+}
+
 void Reducer::reset_bucket_counting() {
   next_bucket_ = 0;
   // Reset num_buckets_ready_ at the beginning of backward computation
@@ -1227,22 +1257,12 @@ void Reducer::search_unused_parameters(
 void Reducer::prepare_for_backward(
     const std::vector<torch::autograd::Variable>& outputs) {
   std::lock_guard<std::mutex> lock(mutex_);
 
+  ++num_backward_calls_;
   backward_compute_start_time_ = current_time_in_nanos();
   if (should_collect_runtime_stats()) {
     record_backward_compute_start_time();
   }
 
-  // Reset accounting.
-  expect_autograd_hooks_ = true;
-
-  reset_bucket_counting();
-
-  // Reset unused parameter accounting.
-  has_marked_unused_parameters_ = false;
-  // Reset per iteration marked ready parameters.
-  perIterationReadyParams_.clear();
-
   // If static graph is not set, search graph to detect unused parameters.
   // When static graph is set, unused_parameters_ will be detected and will
   // not change after 1st iteration.
@@ -1402,9 +1422,9 @@ void Reducer::save_thread_local_state() {
 }
 
 void Reducer::finalize_backward() {
-  // No longer expect autograd hooks to fire after this function returns.
+  // Note that we don't reset expect_autograd_hooks_ so that we can re-run
+  // backwards with retain_graph=True.
   TORCH_INTERNAL_ASSERT(expect_autograd_hooks_);
-  expect_autograd_hooks_ = false;
 
   // No longer require call to finalize after this function returns.
   TORCH_INTERNAL_ASSERT(require_finalize_);
@@ -1445,7 +1465,7 @@ void Reducer::finalize_backward() {
   }
 
   // See Note [Skip allreducing local_used_maps_dev]
-  if (dynamic_graph_find_unused() || static_graph_first_iteration()) {
+  if (dynamic_graph_find_unused() || static_graph_first_bwd()) {
     // Due to the lazy wait, it is possible that reduction of the current
     // iteration is still going when the one for next iteration gets kicked off.
     // For such case, we want to wait explicitly to make sure the reduction does
@@ -1466,6 +1486,15 @@ void Reducer::finalize_backward() {
     local_used_maps_reduced_ = false;
   }
 
+  // Reset various accounting variables including bucket counting to ensure we
+  // can appropriately launch allreduce for each bucket in the next backwards.
+  reset_variable_counting();
+  // If we populated rebuilt params list in this backward call, avoid
+  // repopulating in subsequent backward calls. In particular this is needed to
+  // avoid re-pushing parameters when calling multiple backwards with
+  // retain_graph=True.
+  all_rebuilt_params_pushed_ = all_rebuilt_params_pushed_ || !rebuilt_params_.empty();
+
   if (should_collect_runtime_stats()) {
     record_backward_comm_end_time();
   }

Changes to reducer.h (Reducer declaration):

@@ -83,9 +83,9 @@ class TORCH_API Reducer {
   // a call to this function can simply be omitted.
   void prepare_for_backward(const std::vector<at::Tensor>& outputs);
 
-  // Called at the begginning of forward() inside DistributedDataParallel,
+  // Called at the beginning of forward() inside DistributedDataParallel,
   // right now it caputures the starting time of forward in each iteration.
-  void prepare_for_forward();
+  void prepare_for_forward(bool will_run_grad_reduction = true);
 
   // Returns the relative time in nanoseconds when gradients were ready,
   // with respect to the time `prepare_for_backward` was called. The outer
@@ -157,6 +157,12 @@ class TORCH_API Reducer {
   // Delay all reduce to be after all gradients' calculation is complete.
   void delay_all_reduce();
 
+  bool static_graph_first_bwd();
+
+  // Resets various counters Reducer uses to manager internal state such as
+  // buckets that need to be reduced across workers.
+  void reset_variable_counting();
+
   // Weak reference to associated DDP logger. The reference is weak to avoid
   // refcycle between reducer and logger.
   void set_logger(std::weak_ptr<c10d::Logger> logger);
@@ -178,6 +184,8 @@ class TORCH_API Reducer {
   std::vector<std::pair<uintptr_t, std::shared_ptr<torch::autograd::Node>>>
       hooks_;
 
+  // Whether we need to run autograd hooks (only false if user runs with
+  // no_grad or no_sync context manager)
   bool expect_autograd_hooks_;
   bool require_finalize_;
   size_t next_bucket_;
@@ -365,7 +373,13 @@ class TORCH_API Reducer {
   std::vector<VariableLocator> variable_locators_;
 
   // track the number of iterations to synchronize grads in training so far.
+  // This is the number of calls to the forward pass, not necessarily equal to
+  // number of calls to backward pass.
   long num_iterations_;
+  // Number of times backward() has been called. This is mainly used for static
+  // graph training to know when to populate the map of how many times grad
+  // hooks have been triggered.
+  long num_backward_calls_;
   // track the number of buckets that have been ready for
   // communication calls like allReduce or communication hooks.
   int num_buckets_ready_;
@@ -392,7 +406,13 @@ class TORCH_API Reducer {
   bool is_multi_device_module_ = false;
 
   // Following variables are to help build dynamic bucket order
   // Whether the process of rebuilding buckets has occured.
   bool has_rebuilt_bucket_;
+  // Flag indicating all rebuilt param indices have been pushed. This is needed
+  // because there can be multiple calls to backward with retain_graph=True
+  // without a forward that actually rebuilds the buckets. In this case, we use
+  // this flag to avoid pushing parameters multiple times.
+  bool all_rebuilt_params_pushed_{false};
   std::vector<at::Tensor> rebuilt_params_;
   std::vector<int64_t> rebuilt_param_indices_;
   const int64_t bucket_bytes_cap_;
@@ -457,8 +477,7 @@ class TORCH_API Reducer {
   // get current cuda stream
   const c10::Stream get_current_stream();
   bool dynamic_graph_find_unused();
-  bool static_graph_first_iteration();
-  bool static_graph_after_first_iteration();
+  bool static_graph_after_first_bwd();
 
   // comm_hook_ is used to access the DDP communication hook if registered.
   std::unique_ptr<CommHookInterface> comm_hook_;
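To illustrate the "move `prepare_for_backward` into `_DDPSink` backward" part of the summary, the following is a hypothetical, heavily simplified Python sketch of a `_DDPSink`-style autograd function. The real `_DDPSink` in torch/nn/parallel/distributed.py differs in detail (for example, it hands the actual forward outputs to the Reducer so unused parameters can be found); the sketch only shows where the Reducer now gets re-armed.

import torch

class _DDPSinkSketch(torch.autograd.Function):
    # Hypothetical stand-in for _DDPSink, for illustration only.

    @staticmethod
    def forward(ctx, reducer, *inputs):
        # Pass activations through unchanged; this node exists only so that
        # code runs at the start of every backward pass.
        ctx.reducer = reducer
        return inputs

    @staticmethod
    def backward(ctx, *grad_outputs):
        # With this commit the Reducer is (re-)armed here, once per backward
        # call, rather than once per forward. That is what lets
        # loss.backward(retain_graph=True) run gradient reduction repeatedly.
        # (The real implementation passes the forward outputs instead of an
        # empty list.)
        ctx.reducer.prepare_for_backward([])
        return (None, *grad_outputs)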