## Before

Previously, compiled autograd (CA) would always unpack all saved variables stored in the autograd graph before executing it. This meant that unpack hooks could not be captured as part of the CA graph, and they would fire out of order with respect to other backward hooks. For memory-saving APIs built on top of saved tensor hooks, such as non-reentrant checkpointing and offloading, we couldn't achieve any savings because all activations would be recomputed/loaded and active at the same time, making them a no-op.

## After

We add unpack hooks into the CA graph so that they can be executed progressively. The Python hook and its input are wrapped by non-traceable code, so CA polyfills the wrapping as:

```python
# pseudocode
class SavedVariable:
    def unpack(self):
        if self.hook:
            return self.hook(self.packed_data)
        else:
            return self.packed_data

# This approach won't directly work when we add support for Forward AD
# or double-backward.
```

When the CA graph is executed directly (without torch.compiling it) under checkpointing/offloading, the memory profile is expected to match the eager autograd engine. If an AOT backward is in the autograd graph, the memory profile is expected to be better than the eager autograd engine, since we can now delay unpacking saved activations until the AOT backward's execution. All tests pass when running the CA graph directly; the remaining issues are in Dynamo.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147242
Approved by: https://github.com/jansel
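For context, the saved-tensor-hook machinery implemented in the C++ file below is exposed in Python through `torch.autograd.graph.saved_tensors_hooks`. The sketch below is illustrative only (it is not code from this PR; the CPU-offloading hook pair is just an example): `pack_hook` runs when an activation is saved during the forward pass (`set_hooks_and_pack_data`), and `unpack_hook` runs when a backward node needs the activation again (`SavedVariable::unpack`). The savings described above depend on those unpack calls firing lazily, node by node, rather than all before the backward graph executes.

```python
# Illustrative sketch of saved tensor hooks (not code from this PR).
# pack_hook offloads each saved activation to CPU; unpack_hook brings it
# back when the corresponding backward node runs.
import torch

def pack_hook(t):
    # Called when SavedVariable packs an activation during the forward pass.
    return (t.device, t.to("cpu"))

def unpack_hook(packed):
    # Called from SavedVariable::unpack() when backward needs the value.
    device, t = packed
    return t.to(device)

x = torch.randn(1024, 1024, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    y = (x * x).sum()  # mul saves its inputs -> pack_hook fires here
y.backward()           # unpack_hook fires as the mul backward node executes
```

The built-in `torch.autograd.graph.save_on_cpu()` context manager packages the same offloading pattern; with the change described above, compiled autograd can record these unpack calls as nodes in its graph so that each one fires just before the backward node that consumes it.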
#include <torch/csrc/autograd/saved_variable.h>

#include <torch/csrc/autograd/anomaly_mode.h>
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/autograd/variable.h>

#include <ATen/Tensor.h>

#include <memory>
#include <sstream>

namespace torch::autograd {

SavedVariable::SavedVariable(
    const Variable& variable,
    bool is_output,
    bool is_inplace_on_view) {
  if (variable.defined()) {
    // Note [Inference tensor cannot be saved for backward]
    // Invariant:
    //   You can't save an inference tensor for backwards.
    // If an inference tensor was saved for backward in an autograd session and
    // then you reenter inference mode and make an inplace update to the tensor
    // without bumping version_counter, it'll lead to silent wrong result when
    // you do backward() for the previous autograd session. Technically we
    // don't have to check here since it'll fail when querying `current_version`
    // on the inference tensor, but we can give a much better error message
    // here.
    //
    // Note in the documentation we say "inference tensor cannot participate
    // in autograd" which is more restrictive than the invariant. In practice
    // the check is more permissive and only error out when an inference tensor
    // is saved for backward. Whether a tensor is saved for backward is
    // determined by derivative formula and thus varies op by op, so by saying
    // "no inference tensor in autograd" it's easier for users to understand and
    // follow.
    TORCH_CHECK(
        !variable.is_inference(),
        "Inference tensors cannot be saved for backward. To work around "
        "you can make a clone to get a normal tensor and use it in autograd.")

    was_default_constructed_ = false;
    saved_version_ = variable._version();
    is_leaf_ = variable.is_leaf();
    is_output_ = is_output;
    is_inplace_on_view_ = is_inplace_on_view;

    if (is_inplace_on_view) {
      TORCH_INTERNAL_ASSERT(!is_leaf_ && is_output);
      weak_grad_fn_ = variable.grad_fn();
    }

    auto maybe_hooks = get_default_hooks();

    // Avoid wrapped numbers from being leaked to the user
    if (maybe_hooks && !variable.unsafeGetTensorImpl()->is_wrapped_number()) {
      save_metadata(variable);
      set_hooks_and_pack_data(std::move(maybe_hooks), variable);
      TORCH_INTERNAL_ASSERT(!data_.defined());
      return;
    }

    // If the variable is a leaf or is not an output, we can safely save the
    // original variable without running the risk of reference cycles.
    // 1. If the variable is not an output, its grad_fn has already been fully
    // created and in particular will be a different Node than the one
    // we are currently constructing (the one that owns this SavedVariable).
    // 2. If the variable is a leaf, it only has weak reference to the
    // grad_accumulator which cannot create a cycle. In those cases, we save the
    // original variable and don't need further processing.
    if (!is_output || is_leaf_) {
      saved_original_ = true;
      data_ = variable;
      return;
    }

    save_metadata(variable);

    // Only do this if we actually need to.
    data_ = variable.tensor_data();
  }
}

void SavedVariable::save_metadata(const Variable& data) {
  // Save output number, version counter and fw_grad if needed

  output_nr_ = data.output_nr();

  if (is_leaf_) {
    grad_accumulator_ = impl::grad_accumulator(data);
    requires_grad_ = data.requires_grad();
  } else if (!is_output_) {
    grad_fn_ = data.grad_fn();
  }

  // TODO(albanD) This needs to be updated when moving to multiple levels
  const auto& fw_grad = data._fw_grad(/* level */ 0);
  if (fw_grad.defined()) {
    fw_grad_ = std::make_shared<ForwardGrad>();
    fw_grad_->set_value(fw_grad, /* level */ 0);
  }
}

std::unique_ptr<SavedVariableHooks> SavedVariable::get_default_hooks() {
  return Engine::get_default_engine().get_default_saved_variable_hooks();
}

void SavedVariable::reset_data() {
  hooks_.reset();
  grad_fn_.reset();
  data_.reset();
}

SavedVariable::SavedVariable(
    const std::optional<Variable>& variable,
    bool is_output,
    bool is_inplace_on_view)
    : SavedVariable(
          variable.has_value() ? *variable : Variable(),
          is_output,
          is_inplace_on_view) {}

Variable SavedVariable::unpack(std::shared_ptr<Node> saved_for) const {
  if (was_default_constructed_) {
    return Variable();
  }

  if (!data_.defined()) {
    TORCH_CHECK(hooks_, ERR_BACKWARD_TWICE);
  }

  // We want grad_fn here to provide the most helpful debug message to the user
  // if versions don't match

  std::shared_ptr<Node> grad_fn;
  if (is_inplace_on_view_) {
    grad_fn = weak_grad_fn_.lock();
  } else if (!hooks_) {
    grad_fn = saved_original_ ? data_.grad_fn() : nullptr;
  } else {
    grad_fn = grad_fn_;
  }

  if (!is_leaf_ && !grad_fn) {
    // This issue was introduced when we added logic to save the original
    // because now we rely on data_.grad_fn(), but can be unreliable if the
    // autograd_meta of that saved tensor is cleared with an in-place detach.
    // As a simple fix, we choose to disallow that behavior here even though
    // it makes behavior inconsistent depending on whether you are saving
    // input or output.
    TORCH_CHECK(
        saved_for,
        "Trying to use a saved tensor that has been detached in-place, i.e. with .detach_()."
        "This is not supported, please use out-of-place `.detach()` instead");
    grad_fn = std::move(saved_for);
  }

  // Only check version counter in the case without hooks
  // If user provides hooks, we can't track versions through the hooks
  if (!hooks_) {
    auto current_version = impl::version_counter(data_).current_version();

    if (saved_version_ != current_version) {
      std::stringstream message;
      message
          << "one of the variables needed for gradient computation has been "
             "modified by an inplace operation: ["
          << data_.toString() << " ";
      if (data_.is_nested()) {
        message << data_._nested_tensor_size() << "]";
      } else {
        message << data_.sizes() << "]";
      }
      if (grad_fn) {
        message << ", which is output " << output_nr_ << " of "
                << grad_fn->name() << ",";
      }
      message << " is at version " << current_version << "; expected version "
              << saved_version_ << " instead.";
      if (!AnomalyMode::is_enabled()) {
        message << " Hint: enable anomaly detection to find the operation "
                   "that failed to compute its gradient, with torch.autograd."
                   "set_detect_anomaly(True).";
      } else {
        message
            << " Hint: the backtrace further above shows the operation "
               "that failed to compute its gradient. The variable in question "
               "was changed in there or anywhere later. Good luck!";
      }
      TORCH_CHECK(false, message.str());
    }
  }

  // The version counter is correct.
  // Additionally, if we deal with a non-leaf variable, we have its correct
  // grad_fn.

  // If we have the original variable, we simply return it
  if (!hooks_ && saved_original_) {
    return data_;
  }

  auto data = hooks_ ? hooks_->call_unpack_hook() : data_;

  if (!grad_fn && !requires_grad_ && !data.requires_grad() &&
      !(fw_grad_ && !fw_grad_->empty())) {
    // Avoid detaching if we don't need to.
    return data;
  }

  // NB: saved views are unpacked as normal Variables (not views) even though
  // they still share the same storage. This works only because we never call
  // in-place functions on unpacked variables.
  Variable var;
  if (grad_fn) {
    var = make_variable(data, Edge(std::move(grad_fn), output_nr_));
  } else {
    var = make_variable(data, requires_grad_);
  }

  impl::set_grad_accumulator(var, grad_accumulator_);
  impl::set_version_counter(var, impl::version_counter(data));

  // NB: var here is never a view so there is no need to make anything special
  // for the case where the saved Tensor was a view. This whole argument relies
  // on the fact that the Tensor returned by this function is never
  // modified in-place.
  if (fw_grad_ && !fw_grad_->empty()) {
    // TODO(albanD) This needs to be updated when moving to multiple levels
    auto new_fw_grad = fw_grad_->value(/* level */ 0);
    var._set_fw_grad(new_fw_grad, /* level */ 0, /* is_inplace_op */ false);
  }

  return var;
}

void SavedVariable::set_hooks_and_pack_data(
    std::unique_ptr<SavedVariableHooks>&& hooks,
    const Variable& data) {
  hooks_ = std::move(hooks);
  at::NoGradGuard guard;
  const auto version = impl::version_counter(data).current_version();
  hooks_->call_pack_hook(saved_original_ ? data.detach() : data);
  TORCH_CHECK(
      version == impl::version_counter(data).current_version(),
      "A saved tensor pack hook is modifying its input in place. "
      "Tensors provided as input to pack hook can not be modified by "
      "in-place operations as this can lead to unexpected side-effects. "
      "Please open an issue if you need to perform in-place operations on "
      "the input to a pack hook.");
}

void SavedVariable::register_hooks(
    std::unique_ptr<SavedVariableHooks>&& hooks) {
  TORCH_INTERNAL_ASSERT(hooks);
  TORCH_CHECK(
      !hooks_,
      "Calling register_hooks on a saved tensor whose hooks have already been set. "
      "Hint: only one pair of hooks is allowed at a time.");
  if (!data_.defined()) {
    if (!was_default_constructed_) {
      TORCH_CHECK(
          false,
          "Calling register_hooks on a saved tensor after it has been freed. "
          "Saved intermediate values of the graph are freed when you call "
          ".backward() or autograd.grad(). Specify retain_graph=True if you "
          "need to backward through the graph a second time or if you need to "
          "access saved variables after calling backward.");
    } else {
      TORCH_CHECK(
          false,
          "Calling register_hooks on a saved tensor with value None is forbidden");
    }
  }
  // If we didn't save the original variable, we already saved metadata
  if (saved_original_) {
    save_metadata(data_);
  }
  set_hooks_and_pack_data(std::move(hooks), data_);
  data_.reset();
}

const char* ERR_BACKWARD_TWICE =
    "Trying to backward through the graph a second time (or directly access saved "
    "tensors after they have already been freed). Saved intermediate values "
    "of the graph are freed when you call .backward() or autograd.grad(). Specify "
    "retain_graph=True if you need to backward through the graph a second time or "
    "if you need to access saved tensors after calling backward.";

} // namespace torch::autograd