pytorch/torch/csrc/autograd/variable.h
Edward Yang 25261a4776 Merge Tensor and Variable. (#28620)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/28620

All Tensors are Variables now, they just happen to have requires_grad=False. Tensors ALWAYS have `VariableTensorId` in their type set.

When constructing this patch, I had to make decisions about what I would fix in this patch, and what I would leave for follow up PRs. Here is the cleanup that happens in this patch:

- The `is_variable` property is removed from TensorOptions. I removed this immediately because, unlike Tensor::is_variable, TensorOptions::is_variable doesn't respect our VariableTensorId thread-local state. This means there were a bunch of places where TensorOptions::is_variable was false, which is obviously bogus in a world where Tensor and Variable are merged. Instead of keeping the method as a function that always returns true, I opted to remove it entirely (it's not public API). All places where we set `is_variable` are deleted.
  - Knock-on effect: there is no longer a separate DeprecatedTypeProperties for the variable and non-variable versions of a type.
  - Knock-on effect: instead of asserting on TensorOptions::is_variable, we just test `at::impl::variable_is_excluded()`.
- There is now only one copy of the cuDNN RNN dropout cache, not two (I'm not sure why we had two to begin with)

Some cleanup that doesn't happen in this patch:
- Eliminating unnecessary uses of `make_variable`
- Eliminating `Tensor::is_variable`

The most subtle part of this patch is retaining tracing behavior: the fact that everything is a Variable means that more code gets routed to VariableType than before; this can change traces. I identified two places where we didn't appropriately turn off VariableType, mostly factory functions:

- `torch.tensor` must turn off VariableType before invoking `at::empty` to construct the tensor, as it subsequently does direct data access
- `tensor_slow` (invoked when you pass a Python scalar to a tensor argument) must turn off VariableType before calling `scalar_to_tensor`, so the scalar gets traced as a constant rather than as a call to `scalar_to_tensor`.

Honestly, these are all giant hacks, and should be replaced with a more specialized guard that just toggles tracing.
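
For reference, a minimal sketch of the kind of guard involved (an assumption for illustration: it uses the `at::AutoNonVariableTypeMode` RAII guard, and `options` is a placeholder):

    at::Tensor tensor;
    {
      at::AutoNonVariableTypeMode non_var_type_mode(true);
      // Construction is excluded from VariableType here, so it does not show up in traces.
      tensor = at::empty({2, 2}, options);
    }
    // ... subsequent direct data access on `tensor` ...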

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Reviewed By: dreiss

Differential Revision: D18171156

Pulled By: ezyang

fbshipit-source-id: 5b6a045beba37492647e350190f495114e86504d
2019-11-04 14:59:57 -08:00


#pragma once
#include <torch/csrc/utils/python_stub.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/function_hook.h>
#include <torch/csrc/autograd/cpp_hook.h>
#include <ATen/ATen.h>
#include <ATen/NamedTensorUtils.h>
#include <c10/util/Exception.h>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
namespace torch { namespace autograd {
struct Node;
///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Variable
///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// A `Variable` augments a `Tensor` with the ability to interact in our
/// autograd machinery. Conceptually, `Variable`s travel along `Edge`s between
/// `Node`s in the autograd graph. A `Variable` can either be a leaf, like a
/// weight in a neural network, or an interior variable, when it is the result
/// of an operation between variables. Every `Variable` also stores another
/// `Variable` called its `grad` (gradient). If the variable is a leaf, its
/// gradient will be accumulated into this variable.
///
/// Every Tensor is a Variable, but sometimes we colloquially refer to Variables
/// that don't require gradients as Tensors (since none of the autograd
/// machinery for Variables applies). Historically, Variables and Tensors
/// were separate concepts, but we've merged them together.
///
/// Gradient Edges
///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Furthermore, `Variable`s have the notion of a `gradient_edge`, which is the
/// edge in the autograd graph that connects the variable to a particular input
/// of the gradient function that will be invoked with the variable during the
/// backward pass. More precisely, this gradient function can be one of two
/// things:
/// 1. A `grad_fn`, if the variable is in the interior of the graph. This is the
/// gradient of the function that produced the variable.
/// 2. A `grad_accumulator`, if the variable is a leaf, which accumulates a
/// scalar gradient value into its `grad` variable.
///
/// Versioning
///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Another major feature of `Variable`s is *versions*. Versions are
/// incremented when an in-place mutation of a variable occurs. Versions are
/// useful when constructing `SavedVariable`s, which take a snapshot of a
/// `Variable` at a certain version. You can retrieve a `Variable`'s version
/// through its `current_version()` method.
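///
/// For illustration (a sketch, not part of the original header; uses the C++
/// frontend and assumes in-place ops such as `add_` bump the version counter):
///
///   Variable v = torch::ones({3});
///   uint32_t v0 = v.current_version();
///   v.add_(1);                          // in-place mutation
///   // v.current_version() is now greater than v0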
///
/// Views
///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// It is possible for a `Variable` to be a *view* of another `Variable`, in
/// which case it tracks that `Variable`'s data and autograd history. Beyond
/// construction, the interface of a view is identical to that of a regular
/// `Variable`. You can determine whether a `Variable` is in fact a view by
/// probing its `is_view()` method. Note that the *view* semantics are only
/// meaningful for `Variable` relations that are relevant to autograd. For
/// example, if you hide your code from autograd using `.no_grad()`, the
/// `Variable`s will not be registered as having view relations, even if they
/// share storage.
/// See NOTE [ Autograd View Variables ] for more details.
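///
/// For illustration (a sketch, not part of the original header; uses the C++
/// frontend API):
///
///   Variable base = torch::ones({4}, torch::requires_grad());
///   Variable view = base.view({2, 2});  // a differentiable view of `base`
///   // view.is_view() == true, and view.base() refers to `base`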
///
///
/// Interface
///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// `Variable` inherits from `Tensor` and thus its API is a superset of that of
/// `Tensor`. This means you can perform all the usual mathematical and other
/// operations you can perform on `Tensor`s also on `Variable`s. Furthermore,
/// `Variable` and `Tensor` actually convert implicitly between each other. You
/// can thus call functions defined on `Tensor`s also with `Variable`s. For
/// this, the `Variable` class allows implicit construction from `Tensor`.
///
/// Our intention is to eliminate the Variable class in the near future, or make
/// it so that only internal code uses it to do internal operations.
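///
/// For illustration (a sketch, not part of the original header; uses the public
/// C++ frontend API):
///
///   auto x = torch::ones({2, 2}, torch::requires_grad());
///   auto y = (x * x).sum();
///   y.backward();                       // gradients flow back to the leaves
///   // x.grad() now holds dy/dx, i.e. 2 * x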
///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
struct AutogradMeta;
struct DifferentiableViewMeta;
struct TORCH_API Variable : public at::Tensor {
/// Default constructor.
Variable() = default;
Variable(c10::intrusive_ptr<at::TensorImpl> self);
// Tensor Conversions
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// "Downcasts" a `Tensor` into a `Variable`.
/*implicit*/ Variable(at::Tensor const& rhs) : at::Tensor(rhs) {
}
/*implicit*/ Variable(at::Tensor&& rhs)
: at::Tensor(std::move(rhs)) {
}
// NOTE: Assignment operators to Tensor come for free from the constructors.
/// NOTE: This is similar to the legacy `.data()` function on `Variable`, and is intended
/// to be used from functions that need to access the `Variable`'s equivalent `Tensor`
/// (i.e. `Tensor` that shares the same storage and tensor metadata with the `Variable`).
///
/// One notable difference with the legacy `.data()` function is that changes to the
/// returned `Tensor`'s tensor metadata (e.g. sizes / strides / storage / storage_offset)
/// will not update the original `Variable`, due to the fact that this function
/// shallow-copies the `Variable`'s underlying TensorImpl.
at::Tensor tensor_data() const noexcept;
/// NOTE: `var.variable_data()` in C++ has the same semantics as `tensor.data`
/// in Python, which creates a new `Variable` that shares the same storage and
/// tensor metadata with the original `Variable`, but with a completely new
/// autograd history.
///
/// NOTE: If we change the tensor metadata (e.g. sizes / strides /
/// storage / storage_offset) of a variable created from `var.variable_data()`, those
/// changes will not update the original variable `var`. In `.variable_data()`, we set
/// `allow_tensor_metadata_change_` to false to make such changes explicitly illegal,
/// in order to prevent users from changing metadata of `var.variable_data()`
/// and expecting the original variable `var` to also be updated.
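///
/// For illustration (a sketch, not part of the original header):
///
///   Variable v = torch::ones({2, 2}, torch::requires_grad());
///   at::Tensor d = v.variable_data();
///   // d shares storage with v but carries no autograd history;
///   // metadata changes on d are disallowed and do not affect v.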
at::Tensor variable_data() const noexcept;
// Gradient Node and Edges
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Gets the gradient function of the `Variable`. If this is a leaf variable,
/// the pointer returned will be null.
///
/// For View Variables:
/// Gets the up-to-date grad_fn. If the shared data or base was modified, we
/// re-create the grad_fn to express the up-to-date view relationship between
/// this and the base Variable.
const std::shared_ptr<Node>& grad_fn() const;
/// Gets the raw gradient function pointer, whatever it currently is.
Node* grad_fn_unsafe() const;
private:
/// Set the gradient accumulator of the `Variable`. This is only applicable to
/// leaf variables. Interior variables should call `set_gradient_edge()`.
void set_grad_accumulator(std::weak_ptr<Node> grad_accumulator);
// Only user of set_grad_accumulator
friend class SavedVariable;
public:
/// Attempts to get a pointer to the gradient accumulator of the `Variable`,
/// if it still exists. If the gradient accumulator function has been
/// destroyed, returns a `nullptr`.
std::shared_ptr<Node> try_get_grad_accumulator() const;
/// Gets the gradient accumulator of the `Variable` if it has one, or else
/// create one on the fly and return it.
std::shared_ptr<Node> grad_accumulator() const;
/// Returns the "canonical" gradient edge of this `Variable`, i.e. either the
/// gradient function if this is an interior `Variable`, or the gradient
/// accumulator otherwise. If the `Variable` is interior, the returned `Edge`
/// will store the input index of the `Node` to which this variable is
/// connected in its `input_nr` field. For leaves, the `input_nr` is always
/// zero. Note that `set_gradient_edge` and `gradient_edge` are not
/// symmetric. You must use `set_gradient_edge` to set the `grad_fn` and
/// `set_grad_accumulator` to set the accumulator.
Edge gradient_edge() const {
// If grad_fn is null (as is the case for a leaf node), we instead
// interpret the gradient function to be a gradient accumulator, which will
// accumulate its inputs into the grad property of the variable. These
// nodes get suppressed in some situations, see "suppress gradient
// accumulation" below. Note that only variables which have `requires_grad =
// True` can have gradient accumulators.
if (const auto& gradient = grad_fn()) {
return Edge(gradient, output_nr());
} else {
return Edge(grad_accumulator(), 0);
}
}
/// Returns a copy of this `Variable` that is detached from its autograd graph
/// and has a blank version. This method is OK to call if the `Variable` is a
/// view.
/// NOTE: Previously, if we change the tensor metadata (e.g. sizes / strides /
/// storage / storage_offset) of a tensor created from `detach()`, those metadata
/// in the original tensor will also be updated. However, the new behavior is that
/// those metadata changes to the detached tensor will not update the original tensor
/// anymore, and in the `detach()` function we need to set `allow_tensor_metadata_change_`
/// to false to make such changes explicitly illegal, in order to prevent users from
/// changing metadata of the detached tensor and expecting the original tensor to also
/// be updated.
Variable detach() const;
/// Like `detach()`, but detaches this `Variable` in-place. This method may
/// only be called on non-view `Variable`s. You can use `is_view()` to check
/// this. If this `Variable` is a view, throws a `std::runtime_error`.
void detach_();
/// Computes the gradient of current tensor w.r.t. graph leaves.
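///
/// For illustration (a sketch, not part of the original header; passes an
/// explicit gradient of ones, mirroring the implicit gradient Python uses for
/// scalar outputs):
///
///   Variable w = torch::rand({3}, torch::requires_grad());
///   Variable y = (w * 2).sum();
///   y.backward(torch::ones_like(y), /*keep_graph=*/false, /*create_graph=*/false);
///   // w.grad() now holds a tensor filled with 2s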
void backward(
const Tensor& gradient,
bool keep_graph,
bool create_graph) const;
/// Sets the tensor data held by this `Variable` to be the same as `new_data`.
/// It requires that `new_data` and `Variable` have compatible tensor type, by
/// checking `_has_compatible_shallow_copy_type(this, new_data)`.
void set_data(const at::Tensor &new_data) const;
/// Set the gradient edge -- i.e. `grad_fn` and `input_nr` -- of the
/// `Variable`.
/// NOTE: This will always set the `grad_fn`, even if this is a leaf variable,
/// and never the `grad_accumulator`. For the latter, use
/// `set_grad_accumulator`. This allows late construction of an interior
/// `Variable`.
void set_gradient_edge(Edge edge) noexcept;
/// Returns the input index of the gradient `Node` to which this
/// `Variable` is connected. Note: input indexes of the gradient `Node`
/// correspond to output indexes of the corresponding forward `Node`.
uint32_t output_nr() const noexcept;
/// True if this `Variable` is a leaf and thus does not have a `grad_fn`.
bool is_leaf() const noexcept;
// Versions
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Increments the version count of this `Variable`.
void bump_version() noexcept;
void set_version_counter(const c10::VariableVersion& version_counter) noexcept;
/// Retrieves this `Variable`'s version counter.
const c10::VariableVersion& version_counter() const noexcept;
/// Retrieves the current value of the `Variable`'s version counter.
/// Equivalent to calling `version_counter().current_version()`.
uint32_t current_version() const noexcept;
// Autograd Graph Interaction
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Update the `grad_fn` of an existing Variable. Called after in-place
/// modifications.
///
/// For View Variables:
/// Called after in-place modifications. Modifies the grad_fn of the base
/// Variable.
void rebase_history(Edge gradient_edge);
// Hooks
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
void add_hook(std::shared_ptr<FunctionPreHook> hook);
const std::vector<std::shared_ptr<FunctionPreHook>>& hooks() const noexcept;
void clear_hooks();
template <typename T>
using hook_return_void_t = c10::guts::enable_if_t<std::is_void<typename std::result_of<T&(Variable)>::type>::value, unsigned>;
template <typename T>
using hook_return_var_t = c10::guts::enable_if_t<std::is_same<typename std::result_of<T&(Variable)>::type, Variable>::value, unsigned>;
// Remove hook at given position
void remove_hook(unsigned pos);
// Returns the index of the hook in the list, which can be used to remove the hook
// Register a hook with no return value
template <typename T>
hook_return_void_t<T> register_hook(T&& hook);
// Register a hook with variable return value
template <typename T>
hook_return_var_t<T> register_hook(T&& hook);
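// For illustration of register_hook / remove_hook (a sketch, not part of the
// original header):
//
//   Variable v = torch::ones({2}, torch::requires_grad());
//   unsigned pos = v.register_hook([](Variable grad) { /* inspect grad */ });
//   // ... later ...
//   v.remove_hook(pos);                  // unregister by index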
// View Variables
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Returns true if this `Variable` is a view of another `Variable`.
bool is_view() const noexcept;
/// Returns the `Variable` that this `Variable` is a view of. If this
/// `Variable` is not a view, throw a `std::runtime_error`.
const Variable& base() const;
// Miscellaneous
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
void set_name(const std::string& name);
const std::string& name() const noexcept;
PyObject* pyobj() const noexcept;
void set_pyobj(PyObject* pyobj) noexcept;
public:
// WARNING: This may return a nullptr. If you require AutogradMeta to return
// a materialized structure, use materialize_autograd_meta instead.
AutogradMeta* get_autograd_meta() const noexcept;
private:
// Private Methods
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
at::TensorImpl* get() const;
void create_cpp_hook();
// Returns the current autograd meta, materializing it if it was previously
// none. This counts as a *mutating* operation, so do not call it on
// "read-only" operators; in particular, this is NOT thread safe
AutogradMeta* materialize_autograd_meta();
};
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// AutogradMeta
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Each `Variable` has one unique `AutogradMeta` struct, which stores autograd
/// metadata fields that are necessary for tracking the Variable's autograd history.
/// As an optimization, a Variable may store a nullptr, in lieu of a default
/// constructed AutogradMeta.
struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {
std::string name_;
Variable grad_;
std::shared_ptr<Node> grad_fn_;
std::weak_ptr<Node> grad_accumulator_;
std::vector<std::shared_ptr<FunctionPreHook>> hooks_;
std::shared_ptr<hooks_list> cpp_hooks_list;
// Only meaningful on leaf variables (must be false otherwise)
bool requires_grad_;
bool is_view_;
// The "output number" of this variable; e.g., if this variable
// was the second output of a function, then output_nr == 1.
// We use this to make sure we can setup the backwards trace
// correctly when this variable is passed to another function.
uint32_t output_nr_;
// Mutex to ensure that concurrent read operations that modify internal
// state are still thread-safe. Used by grad_fn() and
// grad_accumulator().
std::mutex mutex_;
/// Sets the `requires_grad` property of `Variable`. This should be true for
/// leaf variables that want to accumulate gradients, and false for all other
/// variables.
void set_requires_grad(bool requires_grad, at::TensorImpl* self_impl) override {
TORCH_CHECK(
!requires_grad || at::isFloatingType(at::typeMetaToScalarType(self_impl->dtype())),
"Only Tensors of floating point dtype can require gradients");
requires_grad_ = requires_grad;
}
bool requires_grad() const override {
return requires_grad_ || grad_fn_;
}
/// Accesses the gradient `Variable` of this `Variable`.
Variable& grad() override {
return grad_;
}
const Variable& grad() const override {
return grad_;
}
AutogradMeta(
at::TensorImpl* self_impl = nullptr,
bool requires_grad = false,
Edge gradient_edge = Edge());
};
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// DifferentiableViewMeta
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// NOTE [ Autograd View Variables ]
///
/// Many operations return a Variable that shares storage with an input Variable.
/// The returned Variable is called a **view** Variable on the input **base**
/// Variable.
///
/// In PyTorch, we have two types of views: differentiable views, and
/// non-differentiable views. In either type, to support proper version
/// checking, the base and view Variables must always share the same
/// version_counter.
///
///
/// Differentiable Views
/// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Differentiable views are the view variables where you want gradients to flow
/// back to the base variables. Out-of-place operations on views are quite
/// straightforward, but in-place ones are very tricky. Even if the base
/// variable does not require grad when we create the view, we still need to
/// track the view relation, because future in-place ops may require
/// backpropagating through it. For example, we need to support
///
/// (1) in-place operation on view, e.g.,
///
/// # Have:
/// # base.requires_grad = False
/// # var.requires_grad = True
/// base[1] = var # i.e., base[1].copy_(var)
/// torch.autograd.grad(base.sum(), var) <- should return an all ones tensor
///
/// (2) in-place operation on base after view is created, e.g.,
///
/// # Have:
/// # base.requires_grad = False
/// # var.requires_grad = True
/// view = base[1]
/// base.copy_(var)
/// torch.autograd.grad(view.sum(), var) <- should return a tensor with
///     var[1] filled with all ones and zeros everywhere else
///
/// DifferentiableViewMeta is created to support gradient tracking of
/// such **in-place** operations. In particular,
/// + if an in-place op is done on base, the grad_fn field of the view may
/// become stale. So accesses should always go through grad_fn(), which
/// reconstructs an updated grad_fn if the version_counter has incremented.
/// All other fields are always valid.
/// + if an in-place op is done on view, in rebase_history() of view, which is
/// called after every in-place op in VariableType.cpp, the grad_fn of base
/// is updated.
///
///
/// Non-Differentiable Views
/// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// In certain cases, although function outputs share storage with inputs, they
/// will **never** require gradient history tracking. Instead of registering the
/// view relation via DifferentiableViewMeta in autograd, the views will be
/// using usual AutogradMeta and just share the version counters with the base
/// Variables.
/// Such views include:
/// 1. Views created from .detach()
/// 2. Views that are non-differentiable by their nature.
/// E.g., `sparse_tensor.indices()` is an integral view on a (possibly)
/// floating point tensor.
/// See top of `derivatives.yaml` on how to specify that outputs of a
/// function are non-differentiable.
/// These are called non-differentiable views as the gradients do not flow
/// through the view relation.
/// Relevant logic for non-differentiable views is implemented in
/// make_variable_view below, and wrap_output of gen_variable_type.py.
struct TORCH_API DifferentiableViewMeta : public AutogradMeta {
/// The base `Variable` (never a view).
Variable base_;
/// The value of the version_counter at the time grad_fn was created. The
/// grad_fn field is stale if attr_version !=
/// version_counter.current_version().
uint32_t attr_version;
bool requires_grad() const override {
return requires_grad_ || grad_fn_ || (is_view_ && base_.requires_grad());
}
DifferentiableViewMeta(at::TensorImpl* self_impl, Variable base);
~DifferentiableViewMeta();
};
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Variable Implementation
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Factory Functions
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Creates a `Variable` that is a *view* of another (*base*) variable.
/// The `gradient_edge` is an optional (gradient_function, input_number) pair.
/// `is_differentiable` is a bool that specifies whether this view is
/// differentiable, i.e., whether the relation should be tracked by autograd.
/// See NOTE [ Autograd View Variables ] for details.
/// NOTE: `allow_tensor_metadata_change` is set to true by default, because there
/// are a lot of call sites to these factory functions that need to change the
/// variable's size or storage afterwards, and they don't expect the original
/// tensor (where the variable is created from) to be updated. Setting
/// `allow_tensor_metadata_change_` to false by default would unnecessarily
/// prevent those changes from happening and is undesirable.
// See NOTE [ Autograd View Variables ] for details.
inline Variable make_variable_view(
Variable base,
at::Tensor data,
bool is_differentiable = true,
bool allow_tensor_metadata_change = true) {
if (data.defined()) {
if (is_differentiable) {
/// Differentiable view. Track history with DifferentiableViewMeta.
auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach(
/*version_counter=*/0,
/*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
data_impl_copy->set_autograd_meta(c10::guts::make_unique<DifferentiableViewMeta>(
data_impl_copy.get(), std::move(base)));
return Variable(data_impl_copy);
} else {
/// Non-differentiable view. Just share version counter.
auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach(
/*version_counter=*/base.version_counter(),
/*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
data_impl_copy->set_autograd_meta(nullptr);
return Variable(data_impl_copy);
}
}
return Variable();
}
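// For illustration (a sketch, not part of the original header):
//
//   Variable base = make_variable(at::ones({4}), /*requires_grad=*/true);
//   Variable view = make_variable_view(base, base.slice(/*dim=*/0, /*start=*/0, /*end=*/2));
//   // `view` shares storage with `base` and is registered as a differentiable view.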
/// Creates a `Variable` from the given `Tensor`, copying its underlying `TensorImpl`.
/// `requires_grad` should be
/// set only for leaves, and determines whether the `Variable` will accumulate
/// gradients. NOTE: `data` must *not* be a `Variable` already. Its dynamic
/// type *must* be `Tensor`.
///
/// TODO: Eliminate this function as much as possible, as it can be expressed
/// more clearly as detach() or a no-op in most call sites (especially when
/// there is only one use of the variable).
inline Variable make_variable(
at::Tensor data,
bool requires_grad = false,
bool allow_tensor_metadata_change = true) {
if (data.defined()) {
if (data.getIntrusivePtr().use_count() == 1 && data.getIntrusivePtr()->unique_version()) {
auto data_impl = data.getIntrusivePtr();
data_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
if (requires_grad) {
data_impl->set_autograd_meta(c10::guts::make_unique<AutogradMeta>(data_impl.get(), requires_grad));
} else {
data_impl->set_autograd_meta(nullptr);
}
return Variable(std::move(data_impl));
} else {
auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach(
/*version_counter=*/0,
/*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
if (requires_grad) {
data_impl_copy->set_autograd_meta(c10::guts::make_unique<AutogradMeta>(
data_impl_copy.get(), requires_grad));
} else {
data_impl_copy->set_autograd_meta(nullptr);
}
return Variable(data_impl_copy);
}
}
return Variable();
}
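// For illustration (a sketch, not part of the original header):
//
//   Variable w = make_variable(at::ones({3, 3}), /*requires_grad=*/true);
//   // `w` is a leaf Variable; after a backward pass, gradients accumulate into w.grad().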
/// Creates a `Variable` from the given `Tensor`, copying its underlying `TensorImpl`.
/// `gradient_edge` should be a (function, input_nr) pair specifying the function
/// in the autograd graph, and what particular input of that function, this
/// variable is connected to.
inline Variable make_variable(
at::Tensor data,
Edge gradient_edge,
bool allow_tensor_metadata_change = true) {
if (data.defined()) {
auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach(
/*version_counter=*/0,
/*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
data_impl_copy->set_autograd_meta(c10::guts::make_unique<AutogradMeta>(
data_impl_copy.get(), false, std::move(gradient_edge)));
return Variable(data_impl_copy);
}
return Variable();
}
// Tensor Conversion
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// In the old days, these casts were checked, but now that every Tensor
// is a Variable this cast is always valid
inline Variable& as_variable_ref(at::Tensor& tensor) {
return static_cast<Variable&>(tensor);
}
inline const Variable& as_variable_ref(const at::Tensor& tensor) {
return static_cast<const Variable&>(tensor);
}
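// For illustration (a sketch, not part of the original header):
//
//   void bump(at::Tensor& t) {
//     // Any Tensor can be treated as a Variable now; this bumps its version counter.
//     as_variable_ref(t).bump_version();
//   }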
inline at::Tensor Variable::tensor_data() const noexcept {
auto self_impl_copy = get()->shallow_copy_and_detach(
/*version_counter=*/get()->version_counter(),
/*allow_tensor_metadata_change=*/get()->allow_tensor_metadata_change());
return at::Tensor(self_impl_copy);
}
inline at::Tensor Variable::variable_data() const noexcept {
auto self_impl_copy = get()->shallow_copy_and_detach(
/*version_counter=*/0,
/*allow_tensor_metadata_change=*/false);
self_impl_copy->set_autograd_meta(nullptr);
return at::Tensor(self_impl_copy);
}
// Gradient Node and Edges
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
inline Node* Variable::grad_fn_unsafe() const {
if (get_autograd_meta()) {
return get_autograd_meta()->grad_fn_.get();
} else {
return nullptr;
}
}
inline void Variable::set_grad_accumulator(
std::weak_ptr<Node> grad_accumulator) {
materialize_autograd_meta()->grad_accumulator_ = std::move(grad_accumulator);
}
inline std::shared_ptr<Node> Variable::try_get_grad_accumulator() const {
if (get_autograd_meta()) {
return get_autograd_meta()->grad_accumulator_.lock();
} else {
return nullptr;
}
}
inline Variable Variable::detach() const {
auto var = make_variable_view(*this, *this, /*is_differentiable=*/false, /*allow_tensor_metadata_change=*/false);
#ifdef BUILD_NAMEDTENSOR
at::namedinference::propagate_names(var, *this);
#endif
return var;
}
inline void Variable::set_gradient_edge(Edge edge) noexcept {
auto* meta = materialize_autograd_meta();
meta->grad_fn_ = std::move(edge.function);
meta->output_nr_ = edge.input_nr;
}
inline uint32_t Variable::output_nr() const noexcept {
if (get_autograd_meta()) {
return get_autograd_meta()->output_nr_;
} else {
return 0;
}
}
inline bool Variable::is_leaf() const noexcept {
if (get_autograd_meta()) {
return get_autograd_meta()->grad_fn_ == nullptr;
} else {
return true;
}
}
// Versions
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
inline void Variable::set_version_counter(
const c10::VariableVersion& version_counter) noexcept {
unsafeGetTensorImpl()->set_version_counter(version_counter);
}
inline void Variable::bump_version() noexcept {
unsafeGetTensorImpl()->bump_version();
}
inline uint32_t Variable::current_version() const noexcept {
return unsafeGetTensorImpl()->version_counter().current_version();
}
inline const c10::VariableVersion& Variable::version_counter() const noexcept {
return unsafeGetTensorImpl()->version_counter();
}
// Hooks
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
inline void Variable::add_hook(std::shared_ptr<FunctionPreHook> hook) {
materialize_autograd_meta()->hooks_.push_back(std::move(hook));
}
namespace {
std::vector<std::shared_ptr<FunctionPreHook>> empty_singleton;
}
// TODO: Return an ArrayRef instead
inline const std::vector<std::shared_ptr<FunctionPreHook>>& Variable::hooks()
const noexcept {
if (get_autograd_meta()) {
return get_autograd_meta()->hooks_;
} else {
return empty_singleton;
}
}
inline void Variable::clear_hooks() {
// This is a little goofy, but usually this should be a no-op
materialize_autograd_meta()->hooks_.clear();
}
template <typename T>
auto Variable::register_hook(T&& hook) -> Variable::hook_return_void_t<T> {
TORCH_CHECK(requires_grad(), "cannot register a hook on a variable that "
"doesn't require gradient");
// NB: materialize_autograd_meta unnecessary due to requires grad check
auto &list = get_autograd_meta()->cpp_hooks_list;
if(!list) {
create_cpp_hook();
}
unsigned idx = list->size();
// Wrap the void-returning hook so the stored std::function still has a
// Variable return type (the wrapper simply returns an undefined Variable)
std::function<void(Variable)> fn(hook);
list->emplace_back([fn](Variable grad){
fn(grad);
return Variable();});
return idx;
}
template <typename T>
auto Variable::register_hook(T&& hook) -> Variable::hook_return_var_t<T> {
TORCH_CHECK(requires_grad(), "cannot register a hook on a variable that "
"doesn't require gradient");
// NB: materialize_autograd_meta unnecessary due to requires grad check
auto &list = get_autograd_meta()->cpp_hooks_list;
if(!list) {
create_cpp_hook();
}
unsigned idx = list->size();
list->push_back(hook);
return idx;
}
// View Variables
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
inline bool Variable::is_view() const noexcept {
if (get_autograd_meta()) {
return get_autograd_meta()->is_view_;
} else {
return false;
}
}
inline const Variable& Variable::base() const {
if (is_view()) {
// is_view() implies get_autograd_meta()
auto diff_view_meta = static_cast<DifferentiableViewMeta*>(get_autograd_meta());
return diff_view_meta->base_;
} else {
throw std::runtime_error("Can't get base of non-view Variable");
}
}
// Miscellaneous
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
inline void Variable::set_name(const std::string& name) {
materialize_autograd_meta()->name_ = name;
}
namespace {
std::string singleton_string;
}
inline const std::string& Variable::name() const noexcept {
if (get_autograd_meta()) {
return get_autograd_meta()->name_;
} else {
return singleton_string;
}
}
inline void Variable::set_pyobj(PyObject* pyobj) noexcept {
get()->set_pyobj(pyobj);
}
inline PyObject* Variable::pyobj() const noexcept {
return get()->pyobj();
}
inline AutogradMeta* Variable::get_autograd_meta() const noexcept {
// NB: could return null
return static_cast<AutogradMeta*>(get()->autograd_meta());
}
// Private Methods
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
inline Variable::Variable(c10::intrusive_ptr<at::TensorImpl> self)
: at::Tensor(std::move(self)) {}
inline at::TensorImpl* Variable::get() const {
TORCH_CHECK(defined(), "Called Variable::get() on an undefined Variable");
return unsafeGetTensorImpl();
}
}} // namespace torch::autograd