## Before

Previously, compiled autograd (CA) would always unpack all saved variables stored in the autograd graph before executing it. This meant that we couldn't capture unpack hooks as part of the CA graph, and they would fire out of order with respect to other backward hooks. For memory-saving APIs built on top of saved tensor hooks, such as non-reentrant checkpointing and offloading, we couldn't achieve any savings because all activations would be recomputed/loaded and active at the same time, making those APIs a no-op.

## After

We add unpack hooks into the CA graph so that they can be executed progressively. The Python hook and the hook input themselves are wrapped by non-traceable code, so CA polyfills the wrapping as:

```python
# pseudocode
class SavedVariable:
    def unpack(self):
        if self.hook:
            return self.hook(self.packed_data)
        else:
            return self.packed_data

# This approach won't directly work when we add support for Forward AD or
# double-backward.
```

When directly executing the CA graph (without torch.compiling it) under checkpointing/offloading, the memory profile is expected to match that of the eager autograd engine. If an AOT backward is in the autograd graph, the memory profile is expected to be better than with the eager autograd engine, since we can now delay the unpacking of saved activations into the AOT backward's execution. All tests pass when running the CA graph directly; the remaining issues are in Dynamo.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147242
Approved by: https://github.com/jansel
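For context, here is a minimal runnable sketch (not taken from the PR) of the public `torch.autograd.graph.saved_tensors_hooks` API whose ordering is at stake: under the eager engine each unpack hook fires right before the backward node that consumes the saved tensor, and this change lets compiled autograd preserve that interleaving instead of unpacking everything up front.

```python
import torch

def pack(t):
    print("pack:", tuple(t.shape))
    return t

def unpack(t):
    print("unpack:", tuple(t.shape))
    return t

with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    x = torch.randn(4, requires_grad=True)
    y = x.sin()  # sin saves its input x for backward -> pack(x)
    z = y.exp()  # exp saves its output z for backward -> pack(z)

# Under the eager engine, each unpack fires just before the backward node that
# needs it (exp's backward, then sin's backward), interleaved with other
# backward hooks.
z.sum().backward()
```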
#pragma once

#include <c10/core/SafePyObject.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/autograd/forward_grad.h>
#include <torch/csrc/autograd/saved_variable_hooks.h>

#include <ATen/core/Tensor.h>

#include <cstdint>
#include <memory>

namespace torch::autograd {

using Variable = at::Tensor;
struct Node;

TORCH_API extern const char* ERR_BACKWARD_TWICE;

/// A snapshot of a variable at a certain version. A `SavedVariable` stores
/// enough information to reconstruct a variable from a certain point in time.
class TORCH_API SavedVariable {
 public:
  SavedVariable() = default;
  SavedVariable(
      const Variable& variable,
      bool is_output,
      bool is_inplace_on_view = false);
  SavedVariable(
      const std::optional<Variable>& variable,
      bool is_output,
      bool is_inplace_on_view = false);
  SavedVariable(const SavedVariable&) = delete;
  SavedVariable(SavedVariable&&) = default;
  SavedVariable& operator=(const SavedVariable&) = delete;
  SavedVariable& operator=(SavedVariable&&) = default;
  ~SavedVariable() {
    if (fw_grad_) {
      // See note [ Using ForwardGrad ]
      fw_grad_->clear();
    }
  }

  /// Reconstructs the saved variable. Pass `saved_for` as the gradient
  /// function if constructing the `SavedVariable` with it would have caused a
  /// circular reference.
  Variable unpack(std::shared_ptr<Node> saved_for = nullptr) const;

  void register_hooks(std::unique_ptr<SavedVariableHooks>&& hooks);

  void reset_data();

  bool has_hooks() const {
    return (bool)hooks_;
  }

  // Used by compiled autograd
  std::optional<std::pair<c10::SafePyObject, c10::SafePyObject>>
  retrieve_unpack_hook_data() const {
    if (!hooks_) {
      return std::nullopt;
    }
    return hooks_->retrieve_unpack_hook_data();
  }

 private:
  // This field contains either:
  // 1. the variable to save
  // 2. or its tensor_data.
  // If storing the variable itself would create a circular reference,
  // we fall into the second case and its metadata is also saved separately.
  // In that case, the grad_fn must be passed in to the unpack function when
  // reconstructing the Variable (except when we are doing an inplace operation
  // on a view, see below). The field saved_original_ below reflects the two
  // cases: its value is true in the first case and false in the second case.
  // The value data_.defined() can be false in three cases:
  // 1. SavedVariable was constructed without a Tensor (the value to save is
  // None), in that case was_default_constructed_ will be kept at true
  // 2. The saved variable has been released by calling
  // SavedVariable::reset_data(), typically during the backward pass
  // 3. Hooks have been registered. In that case, hooks_ will be defined
  // instead. Note that the value of saved_original_ only reflects what happened
  // during the construction of the SavedVariable. If saved_original_ is true,
  // we saved the original tensor in data_, but if the user registers hooks, we
  // will no longer have it (despite saved_original_ still being true).
  at::Tensor data_;

  // This field is used to store the forward AD gradients associated with
  // the saved Tensor. Note that this shared_ptr must never be shared with
  // either the saved Tensor or the unpacked Tensor. See note [ Using
  // ForwardGrad ]
  std::shared_ptr<ForwardGrad> fw_grad_;

  // Weak version of grad_fn_ that prevents leaks in rebase_history() for
  // inplace views.
  // This variable is used when the user chooses to create a SavedVariable with
  // is_inplace_on_view = true.
  // In that case, the grad_fn passed in to the unpack function at unwrapping
  // time is unused.
  std::weak_ptr<Node> weak_grad_fn_;

  uint32_t saved_version_ = 0;
  uint32_t output_nr_ = 0;
  bool was_default_constructed_ = true;
  bool is_inplace_on_view_ = false;
  bool saved_original_ = false;
  bool is_leaf_ = false;
  bool is_output_ = false;

  // Hooks are a pair of functions pack_hook/unpack_hook that provide
  // fine-grained control over how the SavedVariable should save its data.
  // pack_hook is called upon registration, while unpack_hook is called when
  // unpacking.
  std::unique_ptr<SavedVariableHooks> hooks_;
  // Fields grad_fn_, grad_accumulator_, and requires_grad_ are only used if
  // hooks are defined. They are set before pack_hook is called and used after
  // unpack_hook is called.
  std::shared_ptr<Node> grad_fn_;
  // For the usual case where a leaf tensor is the input, we expect its
  // grad_acc to be kept alive by the graph. The reason SavedVariable holds
  // an owning reference is to support the case where a custom autograd
  // Function saves an intermediate.
  std::shared_ptr<Node> grad_accumulator_;
  bool requires_grad_ = false;

  void save_metadata(const Variable& data);
  static std::unique_ptr<SavedVariableHooks> get_default_hooks();
  void set_hooks_and_pack_data(
      std::unique_ptr<SavedVariableHooks>&& hooks,
      const Variable& data);
};
} // namespace torch::autograd
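The internals above surface at the Python level roughly as follows; a hedged sketch (assuming the usual public APIs, not part of this header): `register_hooks` is the plumbing behind saved-tensor-hook context managers such as `torch.autograd.graph.save_on_cpu`, and releasing saved data via `reset_data()` during backward is why a second `backward()` call fails with the `ERR_BACKWARD_TWICE` message.

```python
import torch

x = torch.randn(3, requires_grad=True)
y = (x * x).sum()  # the Mul node holds SavedVariables for its inputs

y.backward()        # the engine releases saved data (see reset_data() above)
try:
    y.backward()    # the saved variables are gone now
except RuntimeError as err:
    print(err)      # error text corresponds to ERR_BACKWARD_TWICE

# save_on_cpu() installs default pack/unpack hooks (routed through
# register_hooks above) that move saved tensors to CPU when packed and back
# to the original device when unpacked during backward.
with torch.autograd.graph.save_on_cpu():
    w = x.sin().sum()
w.backward()
```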