Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/68691

TraceType is a sharded file, so by including only the specific operator headers each shard needs, we ensure that changing one (non-method) operator requires recompiling only one shard. This also changes all the included autograd and JIT headers from including `ATen/ATen.h` to including just `ATen/core/Tensor.h`.

Test Plan: Imported from OSS
Reviewed By: gchanan
Differential Revision: D33336948
Pulled By: albanD
fbshipit-source-id: 4e40371592b9a5a7e7fcd1d8cecae11ffb873113
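For context, the include discipline described in the summary looks roughly like the sketch below. This is only an illustration, not the actual diff: the `AT_PER_OPERATOR_HEADERS` guard is the mechanism ATen uses to gate per-operator includes, and the two `ATen/ops/*.h` headers shown are placeholders for whichever operators a given shard actually calls.

// Before: pulls in every operator declaration, so touching any operator
// forces every including file (and every TraceType shard) to rebuild.
#include <ATen/ATen.h>

// After: only the Tensor class definition, plus the handful of operator
// headers this shard really uses.
#include <ATen/core/Tensor.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/add.h>  // placeholder: one header per operator used
#include <ATen/ops/mul.h>  // placeholder
#endif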
105 lines
4.3 KiB
C++
#pragma once

#include <torch/csrc/Export.h>
#include <torch/csrc/autograd/forward_grad.h>
#include <torch/csrc/autograd/saved_variable_hooks.h>

#include <ATen/core/Tensor.h>

#include <cstdint>
#include <memory>

namespace torch { namespace autograd {

using Variable = at::Tensor;
struct Node;

TORCH_API extern const char* ERR_BACKWARD_TWICE;

/// A snapshot of a variable at a certain version. A `SavedVariable` stores
/// enough information to reconstruct a variable from a certain point in time.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
class TORCH_API SavedVariable {
 public:
  SavedVariable() = default;
  SavedVariable(const Variable& variable, bool is_output, bool is_inplace_on_view=false);
  SavedVariable(const c10::optional<Variable>& variable, bool is_output, bool is_inplace_on_view=false);
  SavedVariable(SavedVariable&&) = default;
  SavedVariable& operator=(SavedVariable&&) = default;
  ~SavedVariable() {
    if (fw_grad_) {
      // See note [ Using ForwardGrad ]
      fw_grad_->clear();
    }
  }

  /// Reconstructs the saved variable. Pass `saved_for` as the gradient
  /// function if constructing the `SavedVariable` with it would have caused a
  /// circular reference.
  /// (See the usage sketch after this header.)
  Variable unpack(std::shared_ptr<Node> saved_for = nullptr) const;

  void register_hooks(std::unique_ptr<SavedVariableHooks>&& hooks);

  void reset_data();

 private:
  // This field contains either:
  // 1. the variable to save, or
  // 2. its tensor_data.
  // If storing the variable itself would create a circular reference,
  // we fall into the second case and its metadata is also saved separately.
  // In that case, the grad_fn must be passed in to the unpack function when
  // reconstructing the Variable (except when we are doing an inplace operation
  // on a view, see below).
  // The field saved_original_ below reflects the two cases: its value is true
  // in the first case and false in the second case.
  // The value data_.defined() can be false in three cases:
  // 1. SavedVariable was constructed without a Tensor (the value to save is
  //    None); in that case was_default_constructed_ is kept at true.
  // 2. The saved variable has been released by calling
  //    SavedVariable::reset_data(), typically during the backward pass.
  // 3. Hooks have been registered. In that case, hooks_ will be defined instead.
  // Note that the value of saved_original_ only reflects what happened during
  // the construction of the SavedVariable. If saved_original_ is true, we saved
  // the original tensor in data_, but if the user registers hooks, we will no
  // longer have it (despite saved_original_ still being true).
  at::Tensor data_;

  // This field is used to store the forward AD gradients associated with
  // the saved Tensor. Note that this shared_ptr must never be shared with
  // either the saved Tensor or the unpacked Tensor. See note [ Using ForwardGrad ]
  std::shared_ptr<ForwardGrad> fw_grad_;

  // Weak version of grad_fn_ that prevents leaks in rebase_history() for
  // inplace views.
  // This variable is used when the user chooses to create a SavedVariable with
  // is_inplace_on_view = true.
  // In that case, the grad_fn passed in to the unpack function at unwrapping
  // time is unused.
  std::weak_ptr<Node> weak_grad_fn_;
  c10::VariableVersion version_counter_;

  uint32_t saved_version_ = 0;
  uint32_t output_nr_ = 0;
  bool was_default_constructed_ = true;
  bool is_inplace_on_view_ = false;
  bool saved_original_ = false;
  bool is_leaf_ = false;
  bool is_output_ = false;

  // Hooks are a pair of functions, pack_hook/unpack_hook, that provide
  // fine-grained control over how the SavedVariable should save its data.
  // pack_hook is called upon registration, while unpack_hook is called when
  // unpacking. (See the hook sketch after this header.)
  std::unique_ptr<SavedVariableHooks> hooks_;
  // Fields grad_fn_, grad_accumulator_, and requires_grad_ are only used if
  // hooks are defined. They are set before pack_hook is called and used after
  // unpack_hook is called.
  std::shared_ptr<Node> grad_fn_;
  std::weak_ptr<Node> grad_accumulator_;
  bool requires_grad_ = false;

  void save_metadata(const Variable& data);
  static std::unique_ptr<SavedVariableHooks> get_default_hooks();
  void set_hooks_and_pack_data(std::unique_ptr<SavedVariableHooks>&& hooks, const Variable& data);
};
}} // namespace torch::autograd
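Usage sketch. For orientation, here is a minimal sketch of how a backward `Node` typically holds and unpacks a `SavedVariable`, modeled loosely on the generated autograd functions. Only the `SavedVariable` constructor and `unpack` signatures come from the header above; `Node` and `variable_list` are assumed to come from `torch/csrc/autograd/function.h`, and the `MyExpBackward` node itself is hypothetical.

#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/saved_variable.h>

namespace torch { namespace autograd {

// Hypothetical backward node for `result = self.exp()`, which reuses the
// forward result when computing the gradient (d(exp(x))/dx = exp(x)).
struct MyExpBackward : public Node {
  // is_output=true: `result` is produced by the op this node is the backward
  // of, so its grad_fn is this node. Holding the variable directly would
  // create a reference cycle; SavedVariable stores it in a cycle-free form.
  SavedVariable result_;

  variable_list apply(variable_list&& grads) override {
    // Pass this node as `saved_for` so unpack() can reattach the tensor to
    // the graph without the SavedVariable ever owning this node.
    auto result = result_.unpack(shared_from_this());
    return {grads[0].mul(result)};
  }
};

// At forward time the node is populated roughly like:
//   auto node = std::make_shared<MyExpBackward>();
//   node->result_ = SavedVariable(result, /*is_output=*/true);

}} // namespace torch::autograd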
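Hook sketch. The pack_hook/unpack_hook mechanism behind `register_hooks` can be driven from C++ by subclassing `SavedVariableHooks`. The sketch below assumes the `call_pack_hook`/`call_unpack_hook` virtual interface declared in `torch/csrc/autograd/saved_variable_hooks.h`; the CPU-offload policy is just an illustrative choice, not part of this header.

#include <torch/csrc/autograd/saved_variable.h>
#include <torch/csrc/autograd/saved_variable_hooks.h>

namespace torch { namespace autograd {

// Hypothetical hook that parks the packed tensor on CPU and moves it back to
// its original device when the SavedVariable is unpacked.
struct CpuOffloadHooks : public SavedVariableHooks {
  void call_pack_hook(const at::Tensor& tensor) override {
    device_ = tensor.device();
    packed_ = tensor.to(at::kCPU);
  }
  at::Tensor call_unpack_hook() override {
    return packed_.to(device_);
  }

 private:
  at::Tensor packed_;
  at::Device device_{at::kCPU};
};

// Registration on an existing SavedVariable `saved`:
//   saved.register_hooks(std::make_unique<CpuOffloadHooks>());

}} // namespace torch::autograd

Consistent with the comment on `hooks_` above, the pack hook runs eagerly at registration, so `data_` can be released right away; the unpack hook runs lazily inside `unpack()` when the backward pass needs the tensor again.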