mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-05 16:44:58 +08:00
[PyTorch][Dist] Trigger pre/post hooks of output function nodes under distributed autograd (#34501)
Summary: # Goals Do the following things during a distributed backward pass. 1. Accumulate the gradient of a variable to RPC context once the gradient is ready instead of at the very end of the backward pass. 2. Run post/pre hooks installed in`AccumulateGrad` nodes once the gradient is ready for the variable. Currently, the hooks in `AccumulateGrad` are not executed just because the function `AccumulateGrad` itself is not even evaluated by the local engine. 3. Make it extensible to support post hooks installed by DDP's reducer. # Introduce GradCapturePreHook ## Why do we need this? ### Root issue: * dist engine uses the autograd.grad-like API on the vanilla engine and then in the Future callback populates the context with the gradients. This is a bad emulation of the .backward() call on the vanilla engine. ### Practical issue: * The leaf’s hook are not called (because associated with the AccumulateGrad that is not call in the autograd.grad-like API). Modules like DDP rely on these hooks. * The Future is marked as completed before the context is actually populated with the grads leading to unexpected behavior on the user side. * The Future callback is only called at the complete end of the backward and so too late for DDP if they want to overlap compute/transfert. ### Proposed solution: * Provide hooks in the autograd.grad-like API that will allow the distributed engine to populate the context and call the hooks to better emulate the .backward call. ## Who can install a grad capture pre-hook? This will be an internal hook at C++ level and it won’t be exposed to PyThon code. Only call-sites directly interacting with the local engine can install such hooks. ## Signature The returned `grad` will be captured. ``` virtual const torch::Tensor& grad operator()(const torch::Tensor& grads) = 0; ``` ## Where are hooks installed? Grad capture pre-hooks are install in GraphTask::ExecInfo::Capture. ExecInfo is per node. Every backward run will have its own GraphTask instance. ## When/How will hooks be called? When the local engine captures the grads for a node, all grad capture pre hooks are called one by one in the order they are added. The output grads of the hooks will replace the original grads. The output of the last hook will be used for grad capturing. Pull Request resolved: https://github.com/pytorch/pytorch/pull/34501 Test Plan: All existing tests should pass. ``` python setup.py develop python test/distributed/rpc/test_dist_autograd_spawn.py DistAutogradTestWithSpawn.test_post_hooks ``` Differential Revision: D20953673 Pulled By: hczhu fbshipit-source-id: 543b3844823330ea9f9856bab7c5cb2679290a53
This commit is contained in:
committed by
Facebook GitHub Bot
parent
97d3a8495d
commit
ea97fa1f2a
@ -631,9 +631,12 @@ void Engine::evaluate_function(
|
||||
if (auto* capture_vec = fn_info.captures_.get()) {
|
||||
// Lock mutex for writing to graph_task->captured_vars_.
|
||||
std::lock_guard<std::mutex> lock(graph_task->mutex_);
|
||||
for (auto capture : *capture_vec) {
|
||||
graph_task->captured_vars_[capture.output_idx_] =
|
||||
inputs[capture.input_idx_];
|
||||
for (const auto& capture : *capture_vec) {
|
||||
auto& captured_grad = graph_task->captured_vars_[capture.output_idx_];
|
||||
captured_grad = inputs[capture.input_idx_];
|
||||
for (auto& hook : capture.hooks_) {
|
||||
captured_grad = (*hook)(captured_grad);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!fn_info.needed_) {
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
// Engine implements backpropagation from output variables and their gradients
|
||||
// to "root" variables (variables created by the user with requires_grad=True).
|
||||
|
||||
#include <ATen/Tensor.h>
|
||||
#include <ATen/ThreadLocalState.h>
|
||||
#include <torch/csrc/WindowsTorchApiMacro.h>
|
||||
#include <torch/csrc/autograd/anomaly_mode.h>
|
||||
@ -64,10 +65,25 @@ struct GraphTask {
|
||||
|
||||
struct ExecInfo {
|
||||
struct Capture {
|
||||
Capture(const Capture&) = delete;
|
||||
Capture(Capture&&) = default;
|
||||
|
||||
Capture(int input_idx, int output_idx)
|
||||
: input_idx_(input_idx), output_idx_(output_idx) {}
|
||||
int input_idx_; // within Node inputs
|
||||
int output_idx_; // within the output vector of a GraphTask
|
||||
|
||||
// This hook will be executed after a grad is captured. The captured
|
||||
// grad will be replaced by the return value of the hook.
|
||||
struct GradCaptureHook {
|
||||
virtual ~GradCaptureHook() = default;
|
||||
virtual at::Tensor operator()(const at::Tensor& grad) = 0;
|
||||
};
|
||||
// The hooks will be called one by one in the order as they were added.
|
||||
// The input grad of a hook will be the output of its preceding hook. The
|
||||
// first hook will take the captured grad as the input. The output of the
|
||||
// last hook will replace the captured grad.
|
||||
std::vector<std::unique_ptr<GradCaptureHook>> hooks_;
|
||||
};
|
||||
|
||||
bool should_execute() const {
|
||||
|
||||
@ -544,7 +544,7 @@ static void _trace_post_record(
|
||||
return;
|
||||
}
|
||||
|
||||
node->i_(attr::inplace, is_inplace);
|
||||
node->i_(jit::attr::inplace, is_inplace);
|
||||
|
||||
// Isolate C variable ptrs in a vector
|
||||
int num_outputs = PyTuple_GET_SIZE(output_objects);
|
||||
|
||||
@ -35,7 +35,7 @@ static bool in_bad_fork = false; // True for children forked after cuda init
|
||||
// Called in the forked child if cuda has already been initialized
|
||||
static void forked_child() {
|
||||
in_bad_fork = true;
|
||||
utils::set_run_yet_variable_to_false();
|
||||
torch::utils::set_run_yet_variable_to_false();
|
||||
state = nullptr;
|
||||
}
|
||||
#endif
|
||||
|
||||
@ -71,6 +71,7 @@ class TORCH_API DistAutogradContext {
|
||||
friend class BackwardPassCleanupGuard;
|
||||
friend class DistEngine;
|
||||
friend class RecvRpcBackward;
|
||||
friend class DistAccumulateGradCaptureHook;
|
||||
|
||||
// Record that we would like to accumulate the provided gradient on the given
|
||||
// variable.
|
||||
|
||||
@ -23,6 +23,52 @@ using torch::autograd::variable_list;
|
||||
static constexpr char* kNumBackwardPasses = "num_current_backward_passes";
|
||||
static constexpr char* kNumAutogradContexts = "num_autograd_contexts";
|
||||
|
||||
// This hook does 3 things:
|
||||
// 1. Call pre hooks of the original AccumulateGrad to modify the input grad.
|
||||
// 2. Accumuate the gard to RPC context.
|
||||
// 3. Call post hooks of the original AccumulateGrad.
|
||||
class DistAccumulateGradCaptureHook
|
||||
: public GraphTask::ExecInfo::Capture::GradCaptureHook {
|
||||
public:
|
||||
DistAccumulateGradCaptureHook(
|
||||
std::shared_ptr<AccumulateGrad> accumulateGrad,
|
||||
ContextPtr autogradContext)
|
||||
: accumulateGrad_(std::move(accumulateGrad)),
|
||||
autogradContext_(std::move(autogradContext)) {}
|
||||
|
||||
at::Tensor operator()(const at::Tensor& grad) override {
|
||||
variable_list inputGrads = {grad};
|
||||
// It's intended that pre/post hooks are still called even if the grad is
|
||||
// undenfined here.
|
||||
for (const auto& hook : accumulateGrad_->pre_hooks()) {
|
||||
inputGrads = (*hook)(inputGrads);
|
||||
}
|
||||
|
||||
// It is possible that the grad is not defined since a separate
|
||||
// invocation of the autograd engine on the same node might actually
|
||||
// compute this gradient.
|
||||
if (inputGrads[0].defined()) {
|
||||
// There are 3 internal references to 'inputGrads[0]' at this moment:
|
||||
// 1. 'inputGrads[0]' in this function.
|
||||
// 2. 'graph_task->captured_vars_' on the callsite in the local engine.
|
||||
// 3. 'InputBuffer& inputs' on the callsite as the inputs of the
|
||||
// function node.
|
||||
autogradContext_->accumulateGrad(
|
||||
accumulateGrad_->variable, inputGrads[0], 3 /* num_expected_refs */);
|
||||
}
|
||||
|
||||
const variable_list kEmptyOuput;
|
||||
for (const auto& hook : accumulateGrad_->post_hooks()) {
|
||||
(*hook)(kEmptyOuput, inputGrads);
|
||||
}
|
||||
return inputGrads[0];
|
||||
}
|
||||
|
||||
private:
|
||||
std::shared_ptr<AccumulateGrad> accumulateGrad_;
|
||||
ContextPtr autogradContext_;
|
||||
};
|
||||
|
||||
DistEngine::DistEngine()
|
||||
: initializedContextIds_(), engine_(Engine::get_default_engine()) {}
|
||||
|
||||
@ -179,6 +225,24 @@ void DistEngine::computeDependencies(
|
||||
// Create a dummy GraphRoot and run init_to_execute with it.
|
||||
GraphRoot dummyRoot(edges, {});
|
||||
graphTask->init_to_execute(dummyRoot, outputEdges);
|
||||
for (auto& mapEntry : graphTask->exec_info_) {
|
||||
auto& execInfo = mapEntry.second;
|
||||
if (!execInfo.captures_) {
|
||||
continue;
|
||||
}
|
||||
auto fn = mapEntry.first;
|
||||
// There may be nodes other than 'AccumulateGrad', e.g. RecvRPCBackward,
|
||||
// to be captured.
|
||||
if (auto accumulateGradFn = dynamic_cast<AccumulateGrad*>(fn)) {
|
||||
for (auto& capture : *execInfo.captures_) {
|
||||
capture.hooks_.push_back(
|
||||
std::make_unique<DistAccumulateGradCaptureHook>(
|
||||
std::dynamic_pointer_cast<AccumulateGrad>(
|
||||
accumulateGradFn->shared_from_this()),
|
||||
autogradContext));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Mark all 'RecvRPCBackward' as needing execution.
|
||||
for (const auto& recvBackwardEdge : recvBackwardEdges) {
|
||||
@ -226,23 +290,6 @@ std::shared_ptr<rpc::FutureMessage> DistEngine::runEngineAndAccumulateGradients(
|
||||
try {
|
||||
const variable_list& grads = futureGrads.constValue();
|
||||
TORCH_INTERNAL_ASSERT(grads.size() == outputEdges.size());
|
||||
|
||||
// Accumulate all the gradients in the context.
|
||||
for (size_t i = 0; i < grads.size(); i++) {
|
||||
// It is possible that the grad is not defined since a separate
|
||||
// invocation of the autograd engine on the same node might actually
|
||||
// compute this gradient. Also accumulate grads only for
|
||||
// AccumulateGrad function.
|
||||
if (grads[i].defined() &&
|
||||
dynamic_cast<AccumulateGrad*>(outputEdges[i].function.get())) {
|
||||
auto& variable =
|
||||
std::static_pointer_cast<AccumulateGrad>(outputEdges[i].function)
|
||||
->variable;
|
||||
autogradContext->accumulateGrad(
|
||||
variable, grads[i], 1 /* num_expected_refs */);
|
||||
}
|
||||
}
|
||||
|
||||
accumulateGradFuture->markCompleted(rpc::Message());
|
||||
} catch (std::exception& e) {
|
||||
accumulateGradFuture->setErrorIfNeeded(e.what());
|
||||
|
||||
@ -2091,6 +2091,37 @@ class DistAutogradTest(RpcAgentTestFixture):
|
||||
# refcount.
|
||||
self.assertTrue(p_g == p_a)
|
||||
|
||||
@dist_init
|
||||
def test_post_hooks(self):
|
||||
self.hook_called_times = 0
|
||||
|
||||
def post_hook_add_one(output_grads, input_grads):
|
||||
self.hook_called_times += 1
|
||||
return output_grads
|
||||
|
||||
def post_hook_add_two(output_grads, input_grads):
|
||||
self.hook_called_times += 2
|
||||
return output_grads
|
||||
|
||||
t = torch.rand(10, 10, requires_grad=True)
|
||||
a = t + t
|
||||
|
||||
# Register post hooks
|
||||
accumulate_grad_0 = a.grad_fn.next_functions[0][0]
|
||||
accumulate_grad_0.register_hook(post_hook_add_one)
|
||||
accumulate_grad_0.register_hook(post_hook_add_two)
|
||||
|
||||
accumulate_grad_1 = a.grad_fn.next_functions[1][0]
|
||||
accumulate_grad_1.register_hook(post_hook_add_two)
|
||||
|
||||
with dist_autograd.context() as context_id:
|
||||
loss = a.sum()
|
||||
dist_autograd.backward(context_id, [loss])
|
||||
self.assertEqual(5, self.hook_called_times)
|
||||
grads = dist_autograd.get_gradients(context_id)
|
||||
self.assertEqual(1, len(grads))
|
||||
self.assertTrue(t in grads)
|
||||
|
||||
@unittest.skipIf(
|
||||
not torch._six.PY3,
|
||||
"Pytorch distributed autograd package does not support python2",
|
||||
|
||||
Reference in New Issue
Block a user