Improve Variable interface (#5127)

* Improve Variable interface

* Address comments from @apaszke and @colesbury

* string ::operator= is not noexcept

* Remove ir.h from tracer_state.h to improve build times

* Make Variable a struct and pack SavedVariable fields

* Implement as_variable_ref

* grad_fn_ptr() -> grad_fn_unsafe()

* Reduce hackiness of set_type hack

* Include variable.h and edge.h in tracer_state.h because it uses them

* class Variable -> struct Variable because Windows cant even

* Make Variable::output_nr uint32_t instead of int

* Add comment about tracing state

* Replaced more static_cast<Variable&> and improve docs

* Remove SavedVariable destructor and construct members in init list

* Clarify docs for Variable

* Variable::set_version -> set_version_counter
This commit is contained in:
Peter Goldsborough
2018-02-12 20:26:26 -08:00
committed by Soumith Chintala
parent 0ef10385b2
commit 2d5fbe6e0d
38 changed files with 948 additions and 630 deletions

View File

@ -37,6 +37,7 @@ BreakAfterJavaFieldAnnotations: false
BreakStringLiterals: false
ColumnLimit: 80
CommentPragmas: '^ IWYU pragma:'
CompactNamespaces: true
ConstructorInitializerAllOnOneLineOrOnePerLine: true
ConstructorInitializerIndentWidth: 4
ContinuationIndentWidth: 4

View File

@ -16,6 +16,7 @@ struct Storage;
struct TensorImpl : public Retainable {
explicit TensorImpl(Type * type)
: is_scalar(false), type_(type) {}
Type & type() const {
return *type_;
}
@ -49,7 +50,7 @@ struct TensorImpl : public Retainable {
void setScalar(bool s) {
is_scalar = s;
}
private:
protected:
bool is_scalar;
Type * type_;
};

View File

@ -474,6 +474,7 @@ main_sources = [
"torch/csrc/jit/python_ir.cpp",
"torch/csrc/jit/test_jit.cpp",
"torch/csrc/jit/tracer.cpp",
"torch/csrc/jit/tracer_state.cpp",
"torch/csrc/jit/python_tracer.cpp",
"torch/csrc/jit/passes/shape_analysis.cpp",
"torch/csrc/jit/interned_strings.cpp",

View File

@ -129,7 +129,7 @@ def process_function(func):
name = arg['name']
if arg['type'] == 'Tensor' or (arg['type'] == 'Scalar' and is_output):
saved_variables.append('SavedVariable {}_;'.format(name))
release_variables.append('{}_.data.reset();'.format(name))
release_variables.append('{}_.reset_data();'.format(name))
ptr = 'shared_from_this()' if is_output else ''
unpack.append('auto {} = {}_.unpack({});'.format(name, name, ptr))
elif arg['type'] == 'TensorList':

View File

@ -478,7 +478,7 @@ Tensor select_backward_scalar(Tensor grad, const Tensor & input, const Tensor &
#ifdef WITH_SCALARS
grad_input.masked_fill_(input == value, grad);
#else
auto grad_data = static_cast<Variable&>(grad).data();
auto grad_data = as_variable_ref(grad).data();
grad_input.masked_fill_(input == value, Scalar(grad_data[0]));
#endif
return grad_input;
@ -1088,9 +1088,9 @@ std::tuple<Tensor, Tensor, Tensor> batchnorm_double_backward(
for (auto s : input.sizes().slice(2)) {
M *= s;
}
auto mu = unsqueeze_dim1(make_variable(training ? save_mean : running_mean), input);
auto mu = unsqueeze_dim1(make_variable(training ? save_mean : running_mean, /*requires_grad=*/false), input);
auto input_sub_mu = input - mu;
auto sigma2_eps_neg_1_2 = unsqueeze_dim1(make_variable(training ? save_std : running_var.add(Scalar(eps)).pow(-0.5)), input);
auto sigma2_eps_neg_1_2 = unsqueeze_dim1(make_variable(training ? save_std : running_var.add(Scalar(eps)).pow(-0.5), /*requires_grad=*/false), input);
auto sigma2_eps_neg_1 = sigma2_eps_neg_1_2.pow(2);
auto sigma2_eps_neg_3_2 = sigma2_eps_neg_1_2.pow(3);

View File

@ -5,6 +5,7 @@
#include "torch/csrc/autograd/variable.h"
#include "torch/csrc/autograd/function.h"
#include "torch/csrc/autograd/edge.h"
#include "torch/csrc/autograd/grad_mode.h"
#include "torch/csrc/autograd/saved_variable.h"
#include "torch/csrc/autograd/generated/Functions.h"
@ -28,7 +29,6 @@ using namespace at;
using namespace torch::autograd::generated;
namespace torch { namespace autograd {
// Helper methods for working with Attributes (torch/csrc/jit/attributes.h)
// The overloaded accessors are convenient for the generated code (since we
@ -74,7 +74,7 @@ std::unique_ptr<Storage> VariableType::storageWithAllocator(int64_t size, std::u
return baseType->storageWithAllocator(size, std::move(allocator));
}
Tensor VariableType::unsafeTensorFromTH(void * th_pointer, bool retain) const {
return make_variable(baseType->unsafeTensorFromTH(th_pointer, retain), false);
return make_variable(baseType->unsafeTensorFromTH(th_pointer, retain), /*requires_grad=*/false);
}
std::unique_ptr<Generator> VariableType::generator() const {
return baseType->generator();
@ -164,7 +164,7 @@ Variable & VariableType::checked_cast_variable(const Tensor & t, const char * na
runtime_error("Expected object of type Variable but found type %s for argument #%d '%s'",
t.type().toString(), pos, name);
}
return static_cast<Variable&>(const_cast<Tensor&>(t));
return as_variable_ref(const_cast<Tensor&>(t));
}
Tensor & VariableType::unpack(const Tensor & t, const char * name, int pos) {
@ -207,49 +207,35 @@ static std::vector<SavedVariable> make_saved_variable_list(TensorList tensors) {
return SavedVariable{tensor, false /* is output */}; });
}
template <typename... Tensors, size_t... Is>
std::tuple<Tensors...> as_variable_impl(
std::tuple<Tensors...> tensors,
Indices<Is...>) {
// Expand the integer parameter pack into a sequence of Variable
// constructions. This turns into (boolean omitted):
// Variable(std::get<0>(tensors)), Variable(std::get<1>(tensors)), ...
return std::tuple<Tensors...>(
make_variable(std::get<Is>(tensors), /*requires_grad=*/false)...);
}
template <typename... Tensors>
std::tuple<Tensors...> as_variable(std::tuple<Tensors...> tensors) {
// `sizeof...(Tensors)` gets us the size of the `Tensors` parameter pack at
// compile time. We use it to parameterize a `MakeIndices` class, which will
// expand into an Indices object containing the numbers 0 to
// sizeof...(Tensors) - 1.
return as_variable_impl(
tensors, typename MakeIndices<sizeof...(Tensors)>::indices());
}
static Tensor as_variable(Tensor tensor) {
return make_variable(std::move(tensor));
}
static std::tuple<Tensor, Tensor>
as_variable(std::tuple<Tensor, Tensor> tensors) {
return std::make_tuple<>(
make_variable(std::move(std::get<0>(tensors))),
make_variable(std::move(std::get<1>(tensors))));
}
static std::tuple<Tensor, Tensor, Tensor>
as_variable(std::tuple<Tensor, Tensor, Tensor> tensors) {
return std::make_tuple<>(
make_variable(std::move(std::get<0>(tensors))),
make_variable(std::move(std::get<1>(tensors))),
make_variable(std::move(std::get<2>(tensors))));
}
static std::tuple<Tensor, Tensor, Tensor, Tensor>
as_variable(std::tuple<Tensor, Tensor, Tensor, Tensor> tensors) {
return std::make_tuple<>(
make_variable(std::move(std::get<0>(tensors))),
make_variable(std::move(std::get<1>(tensors))),
make_variable(std::move(std::get<2>(tensors))),
make_variable(std::move(std::get<3>(tensors))));
}
static std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor>
as_variable(std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> tensors) {
return std::make_tuple<>(
make_variable(std::move(std::get<0>(tensors))),
make_variable(std::move(std::get<1>(tensors))),
make_variable(std::move(std::get<2>(tensors))),
make_variable(std::move(std::get<3>(tensors))),
make_variable(std::move(std::get<4>(tensors)))
);
return make_variable(std::move(tensor), /*requires_grad=*/false);
}
static std::vector<Tensor> as_variable(TensorList tl) {
std::vector<Tensor> variables;
for (auto& t : tl) {
variables.emplace_back(make_variable(std::move(t)));
variables.emplace_back(make_variable(std::move(t), /*requires_grad=*/false));
}
return variables;
}
@ -316,20 +302,20 @@ static void throw_error_out_requires_grad(const char* name) {
static void rebase_history(Tensor& tensor, std::shared_ptr<Function> grad_fn) {
if (grad_fn && tensor.defined()) {
auto& var = static_cast<Variable&>(tensor);
auto& var = as_variable_ref(tensor);
grad_fn->num_inputs = 1;
var.rebase_history(0, std::move(grad_fn));
var.rebase_history({std::move(grad_fn), 0});
}
}
static void rebase_history(TensorList tensors, std::shared_ptr<Function> grad_fn) {
if (grad_fn) {
grad_fn->num_inputs = tensors.size();
int output_nr = 0;
uint32_t output_nr = 0;
for (auto& tensor : tensors) {
if (tensor.defined()) {
auto& var = static_cast<Variable&>(const_cast<Tensor&>(tensor));
var.rebase_history(output_nr, grad_fn);
auto& var = as_variable_ref(const_cast<Tensor&>(tensor));
var.rebase_history({grad_fn, output_nr});
}
output_nr++;
}
@ -340,22 +326,20 @@ static void rebase_history(TensorList tensors, std::shared_ptr<Function> grad_fn
// overload for functions with multiple differentiable outputs.
static void set_history(Tensor& tensor, std::shared_ptr<Function> grad_fn) {
if (grad_fn && tensor.defined()) {
auto& var = static_cast<Variable&>(tensor);
auto& var = as_variable_ref(tensor);
grad_fn->num_inputs = 1;
var.get()->output_nr = 0;
var.get()->_grad_fn = std::move(grad_fn);
var.set_gradient_edge({std::move(grad_fn), 0});
}
}
static void set_history(TensorList tensors, std::shared_ptr<Function> grad_fn) {
if (grad_fn) {
grad_fn->num_inputs = tensors.size();
int64_t output_nr = 0;
uint32_t output_nr = 0;
for (auto& tensor : tensors) {
if (tensor.defined()) {
auto& var = static_cast<Variable&>(const_cast<Tensor&>(tensor));
var.get()->output_nr = output_nr;
var.get()->_grad_fn = grad_fn;
auto& var = as_variable_ref(const_cast<Tensor&>(tensor));
var.set_gradient_edge({grad_fn, output_nr});
}
output_nr++;
}
@ -378,9 +362,8 @@ template<typename... Args> inline variable_list flatten(Args&&... args) {
return out; // RVO
}
static void increment_version(const Tensor & t) {
auto& var = static_cast<const Variable&>(t);
var.version_counter().increment();
static void increment_version(Tensor & t) {
as_variable_ref(t).bump_version();
}
static bool isFloatingPoint(ScalarType s) {
@ -411,7 +394,7 @@ Tensor & VariableType::s_copy_(Tensor & self, const Tensor & src, bool non_block
Tensor & VariableType::resize_(Tensor & self, IntList size) const {
auto& self_ = unpack(self, "self", 0);
if (static_cast<Variable&>(self).requires_grad()) {
if (as_variable_ref(self).requires_grad()) {
at::runtime_error("cannot resize variables that require grad");
}
baseType->resize_(self_, size);
@ -421,7 +404,7 @@ Tensor & VariableType::resize_(Tensor & self, IntList size) const {
Tensor & VariableType::resize_as_(Tensor & self, const Tensor & the_template) const {
auto& self_ = unpack(self, "self", 0);
auto& the_template_ = unpack(the_template, "the_template", 1);
if (static_cast<Variable&>(self).requires_grad()) {
if (as_variable_ref(self).requires_grad()) {
at::runtime_error("cannot resize variables that require grad");
}
baseType->resize_as_(self_, the_template_);

View File

@ -3,6 +3,10 @@
// ${generated_comment}
#include <ATen/ATen.h>
#include <cstdint> // for size_t
#include <functional> // for function
#include <memory> // for unique_ptr
#include <string>
#include <vector>
@ -56,7 +60,6 @@ private:
static at::Tensor unpack_opt(const Tensor & t, const char * name, int pos);
static std::vector<at::Tensor> unpack(at::TensorList tl, const char *name, int pos);
private:
at::Type* baseType;
std::string str;
};

View File

@ -25,7 +25,7 @@ using namespace torch::autograd::utils;
namespace torch { namespace autograd {
static Tensor set_requires_grad(Tensor self, bool requires_grad) {
static_cast<Variable&>(self).get()->_requires_grad = requires_grad;
as_variable_ref(self).set_requires_grad(requires_grad);
return self;
}
@ -70,7 +70,7 @@ static PyObject * THPVariable_from_numpy(PyObject* module, PyObject* arg)
{
HANDLE_TH_ERRORS
auto data = torch::utils::tensor_from_numpy(arg);
return THPVariable_Wrap(make_variable(std::move(data)));
return THPVariable_Wrap(make_variable(std::move(data), /*requires_grad=*/false));
END_HANDLE_TH_ERRORS
}

View File

@ -6,8 +6,7 @@
#include "torch/csrc/utils/hash.h"
namespace torch {
namespace autograd {
namespace torch { namespace autograd {
struct Function;
@ -15,8 +14,8 @@ struct Function;
struct Edge {
Edge() noexcept : function(nullptr), input_nr(0) {}
Edge(const std::shared_ptr<Function>& function_, uint32_t input_nr_) noexcept
: function(function_), input_nr(input_nr_) {}
Edge(std::shared_ptr<Function> function_, uint32_t input_nr_) noexcept
: function(std::move(function_)), input_nr(input_nr_) {}
/// Convenience method to test if an edge is valid.
bool is_valid() const noexcept {
@ -38,8 +37,7 @@ struct Edge {
/// The identifier of a particular input to the function.
uint32_t input_nr;
};
} // namespace autograd
} // namespace torch
}} // namespace torch::autograd
// The idiomatic way of enabling use of a custom type as the key of hash
// containers in C++11. This method removes the requirement of having to pass

View File

@ -1,6 +1,5 @@
#pragma once
#include <memory>
#include <vector>
// A hook that's called on gradients

View File

@ -5,6 +5,8 @@
#include "torch/csrc/autograd/functions/utils.h"
#include "torch/csrc/utils/auto_gpu.h"
#include <ATen/ATen.h>
#include <memory>
#include <utility>
@ -19,7 +21,7 @@ auto DelayedError::apply(const variable_list& inputs) -> variable_list {
outputs.reserve(inputs.size());
for (auto& var : inputs) {
// FIXME: share version counters
outputs.emplace_back(var.defined() ? var.data() : Tensor());
outputs.emplace_back(var.defined() ? var.data() : at::Tensor());
}
return wrap_outputs(inputs, std::move(outputs), [&](function_list&& next_functions) {
return std::make_shared<Error>(msg, std::move(next_functions));

View File

@ -4,6 +4,7 @@
#include "torch/csrc/autograd/python_engine.h"
#include "torch/csrc/autograd/edge.h"
#include "torch/csrc/autograd/function.h"
#include "torch/csrc/autograd/edge.h"
#include <cstdint>
#include <memory>
@ -264,11 +265,10 @@ bool Eval::replaceSubgraph(const variable_list& inputs, const variable_list& _ou
// This output is already rebased. This happens when there
// the same Variable has been returned multiple times, and
// is repeated in this list.
if (output.get()->_grad_fn.get() == this) {
if (output.grad_fn_unsafe() == this) {
auto replicate = std::make_shared<Replicate>();
replicate->next_functions.emplace_back(this_shared, output.output_nr());
output.get()->_grad_fn = replicate;
output.get()->output_nr = 0;
output.set_gradient_edge({std::move(replicate), 0});
repeated_outputs.emplace(&output);
}
// NOTE: this check should be fairly cheap, and the set shouldn't
@ -277,8 +277,7 @@ bool Eval::replaceSubgraph(const variable_list& inputs, const variable_list& _ou
auto & replicate = output.grad_fn();
replicate->next_functions.emplace_back(this_shared, num_inputs++);
} else {
output.get()->_grad_fn = this_shared;
output.get()->output_nr = num_inputs++;
output.set_gradient_edge(Edge(this_shared, num_inputs++));
}
}

View File

@ -1,4 +1,6 @@
#include "torch/csrc/autograd/functions/utils.h"
#include "torch/csrc/autograd/edge.h"
#include "torch/csrc/autograd/function.h"
#include "torch/csrc/autograd/variable.h"
@ -14,7 +16,7 @@ variable_list wrap_outputs(const variable_list& inputs, tensor_list&& outputs,
if (!any_variable_requires_grad(inputs)) {
for (auto& output : outputs) {
if (output.defined()) {
result.emplace_back(make_variable(output, false));
result.push_back(make_variable(output, /*requires_grad=*/false));
} else {
result.emplace_back();
}
@ -23,7 +25,7 @@ variable_list wrap_outputs(const variable_list& inputs, tensor_list&& outputs,
auto grad_fn = ctr(get_next_functions(inputs));
for (auto& output : outputs) {
if (output.defined()) {
result.emplace_back(make_variable(output, grad_fn));
result.push_back(make_variable(output, Edge(grad_fn, grad_fn->num_inputs++)));
} else {
++grad_fn->num_inputs;
result.emplace_back();

View File

@ -142,10 +142,10 @@ PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwar
THPUtils_assert(THPVariable_Check(input),
"all inputs have to be Variables, but got %s", THPUtils_typename(input));
THPVariable *input_var = (THPVariable*)input;
int output_nr = input_var->cdata.output_nr();
const auto output_nr = input_var->cdata.output_nr();
auto grad_fn = input_var->cdata.grad_fn();
if (!grad_fn) {
grad_fn = input_var->cdata.get()->grad_accumulator.lock();
grad_fn = input_var->cdata.try_get_grad_accumulator();
}
THPUtils_assert(input_var->cdata.requires_grad(),
"One of the differentiated Variables does not require grad");

View File

@ -323,7 +323,7 @@ static std::vector<PyObject*> _mark_dirty(THPFunction *self)
dirty_inputs.push_back(obj);
auto variable = (THPVariable*)obj;
variable->cdata.version_counter().increment();
variable->cdata.bump_version();
}
// We're not going to ever need this so let's remove references now
Py_CLEAR(self->dirty_tensors);
@ -368,14 +368,14 @@ static void _wrap_outputs(THPFunction *self,
}
if (THPModule_isTensor(obj)) {
// temporarily wrap tensors as variables until the classes are merged
return make_variable(createTensor(obj));
return make_variable(createTensor(obj), /*requires_grad=*/false);
}
throw TypeError("%s.forward: expected Variable (got %s) for return value %d",
Py_TYPE(self)->tp_name, Py_TYPE(obj)->tp_name, i);
};
// Sets the grad_fn and output_nr of an output Variable.
auto set_history = [&](Variable& var, int output_nr, bool is_input, bool is_modified,
auto set_history = [&](Variable& var, uint32_t output_nr, bool is_input, bool is_modified,
bool is_differentiable) {
if (!is_differentiable) {
if (!var.requires_grad()) return;
@ -393,24 +393,22 @@ static void _wrap_outputs(THPFunction *self,
}
// If the input was modified, transplant the grad_fn in the graph:
// grad_fn <- variable <- self ==> grad_fn <- self <- variable
var.get()->grad.reset();
var.get()->hooks.clear();
if (auto grad_acc_fn = var.get()->grad_accumulator.lock()) {
var.reset_grad();
var.clear_hooks();
if (auto grad_acc_fn = var.try_get_grad_accumulator()) {
auto grad_acc = dynamic_cast<AccumulateGrad*>(grad_acc_fn.get());
grad_acc->variable.reset();
}
if (cdata) {
var.rebase_history(output_nr, cdata);
var.rebase_history({cdata, output_nr});
}
} else if (is_input) {
// An input has been returned, but it wasn't modified. Return it as a view
// so that we can attach a new grad_fn to the Variable.
var = var.slice();
var.get()->output_nr = output_nr;
var.get()->_grad_fn = cdata;
var.set_gradient_edge({cdata, output_nr});
} else if (cdata) {
var.get()->output_nr = output_nr;
var.get()->_grad_fn = cdata;
var.set_gradient_edge({cdata, output_nr});
}
};
@ -457,7 +455,7 @@ static void _save_variables(THPFunction* self)
self->saved_variables.emplace_back(variable->cdata, is_output);
} else if (THPModule_isTensor(obj)) {
// TODO: remove once Variable and Tensor classes are merged
auto var = make_variable(createTensor(obj), false);
auto var = make_variable(createTensor(obj), /*requires_grad=*/false);
self->saved_variables.emplace_back(std::move(var), false);
} else {
throw TypeError(

View File

@ -1,12 +1,11 @@
#include "torch/csrc/autograd/python_variable.h"
#include <structmember.h>
#include "THP.h"
#include "torch/csrc/DynamicTypes.h"
#include "torch/csrc/Types.h"
#include "torch/csrc/autograd/python_cpp_function.h"
#include "torch/csrc/autograd/python_hook.h"
#include "torch/csrc/autograd/edge.h"
#include "torch/csrc/autograd/python_variable_indexing.h"
#include "torch/csrc/autograd/functions/accumulate_grad.h"
#include "torch/csrc/autograd/utils/wrap_outputs.h"
@ -16,7 +15,15 @@
#include "torch/csrc/Exceptions.h"
#include "torch/csrc/Size.h"
#include "torch/csrc/autograd/variable.h"
#include "torch/csrc/autograd/edge.h"
#include "torch/csrc/autograd/generated/VariableType.h"
#include "torch/csrc/jit/tracer_state.h"
#include <ATen/ATen.h>
#include <list>
#include <memory>
#include <structmember.h>
using namespace at;
using namespace torch::autograd;
@ -35,11 +42,13 @@ static PyObject* THPVariable_NewWithVar(PyTypeObject* type, Variable var)
if (obj) {
auto v = (THPVariable*) obj;
new (&v->cdata) Variable(std::move(var));
v->cdata.get()->pyobj = obj;
if (auto fn = dynamic_cast<PyFunction*>(v->cdata.get()->_grad_fn.get())) {
v->cdata.set_pyobj(obj);
if (auto fn = dynamic_cast<PyFunction*>(v->cdata.grad_fn_unsafe())) {
// Create a new reference to the THPFunction. This ensures that ref count
// of the THPFunction is at least the number of referring THPVariables.
v->cdata.get()->_grad_fn = THPFunction_asFunction((THPFunction*)fn->obj);
const auto output_nr = v->cdata.output_nr();
v->cdata.set_gradient_edge(
{THPFunction_asFunction((THPFunction*)fn->obj), output_nr});
}
}
return obj;
@ -57,7 +66,7 @@ PyObject * THPVariable_Wrap(Variable var)
}
#endif
if (auto obj = var.get()->pyobj) {
if (auto obj = var.pyobj()) {
Py_INCREF(obj);
return obj;
}
@ -84,7 +93,7 @@ static int THPVariable_traverse(THPVariable *self, visitproc visit, void *arg)
// for more details about the race condition involving traversing the grad_fn
// and the python GC.
if (self->cdata.defined()) {
for (auto& hook : self->cdata.hooks()) {
for (const auto& hook : self->cdata.hooks()) {
if (auto pyhook = dynamic_cast<PyFunctionPreHook*>(hook.get())) {
Py_VISIT(pyhook->dict);
}
@ -98,10 +107,10 @@ static int THPVariable_clear(THPVariable *self)
Py_CLEAR(self->data);
Py_CLEAR(self->backward_hooks);
if (self->cdata.defined()) {
if (auto grad_acc = self->cdata.get()->grad_accumulator.lock()) {
if (auto grad_acc = self->cdata.try_get_grad_accumulator()) {
grad_acc->pre_hooks.clear();
}
self->cdata.get()->pyobj = nullptr;
self->cdata.set_pyobj(nullptr);
}
self->cdata.reset();
return 0;
@ -154,13 +163,15 @@ PyObject *THPVariable_pynew(PyTypeObject *type, PyObject *args, PyObject *kwds)
Variable var;
if (grad_fn) {
auto grad_fn_ = THPFunction_asFunction((THPFunction*)grad_fn);
var = make_variable(torch::createTensor(data), grad_fn_);
Edge edge(grad_fn_, grad_fn_->num_inputs++);
var = make_variable(torch::createTensor(data), std::move(edge));
} else {
var = make_variable(torch::createTensor(data), requires_grad);
}
if (name)
var.name() = std::string(name);
if (name) {
var.set_name(name);
}
PyObject* self = THPVariable_NewWithVar(type, std::move(var));
if (self) {
@ -223,7 +234,7 @@ int THPVariable_set_grad_fn(THPVariable *self, PyObject *obj)
{
HANDLE_TH_ERRORS
THPUtils_assertRet(-1, obj == Py_None, "_grad_fn can be only set to None");
self->cdata.get()->_grad_fn = nullptr;
self->cdata.set_gradient_edge(Edge());
return 0;
END_HANDLE_TH_ERRORS_RET(-1)
}
@ -246,30 +257,6 @@ PyObject * THPVariable_get_data(THPVariable *self)
END_HANDLE_TH_ERRORS
}
namespace {
// XXX: This is a hack to access private TensorImpl::type_
// http://bloglitb.blogspot.com/2011/12/access-to-private-members-safer.html
// This is currently needed because module.float() changes the type of the
// data field of each variable. We should fix this and not allow changing the
// type of var.data.
template<typename Tag, typename Tag::type M>
struct Rob {
friend typename Tag::type get(Tag) {
return M;
}
};
struct TensorImpl_Type {
typedef Type* TensorImpl::*type;
friend type get(TensorImpl_Type);
};
template struct Rob<TensorImpl_Type, &TensorImpl::type_>;
}
int THPVariable_set_data(THPVariable *self, PyObject *data)
{
HANDLE_TH_ERRORS
@ -282,7 +269,7 @@ int THPVariable_set_data(THPVariable *self, PyObject *data)
if (&self->cdata.data().type() != &tensor.type()) {
// we change the type of var.data so we must change the type of var
auto newType = VariableType::getType(tensor);
self->cdata.get()->*get(TensorImpl_Type()) = newType;
self->cdata.temporary_hack_set_type(newType);
}
self->cdata.data() = tensor;
return 0;
@ -301,7 +288,7 @@ int THPVariable_set_grad(THPVariable *self, PyObject *py_grad)
HANDLE_TH_ERRORS
auto& var = self->cdata;
if (py_grad == Py_None) {
var.grad().reset();
var.reset_grad();
return 0;
}
@ -342,7 +329,8 @@ int THPVariable_set_volatile(THPVariable *self, PyObject *obj)
PyObject *THPVariable_get_output_nr(THPVariable *self)
{
HANDLE_TH_ERRORS
return PyInt_FromLong(self->cdata.output_nr());
const auto output_nr = static_cast<long>(self->cdata.output_nr());
return PyInt_FromLong(output_nr);
END_HANDLE_TH_ERRORS
}
@ -368,7 +356,7 @@ int THPVariable_set_requires_grad(THPVariable *self, PyObject *obj)
THPUtils_setError("you can only change requires_grad flags of leaf variables.%s", hint);
return -1;
}
var.get()->_requires_grad = (obj == Py_True);
var.set_requires_grad(obj == Py_True);
return 0;
END_HANDLE_TH_ERRORS_RET(-1)
}
@ -400,9 +388,9 @@ int THPVariable_set_backwards_hooks(THPVariable *self, PyObject *obj)
Py_XINCREF(obj);
Py_XDECREF(self->backward_hooks);
self->backward_hooks = obj;
self->cdata.hooks().clear();
self->cdata.clear_hooks();
if (obj) {
self->cdata.hooks().emplace_back(new PyFunctionPreHook(obj, 0));
self->cdata.add_hook(std::make_shared<PyFunctionPreHook>(obj, 0));
}
return 0;
END_HANDLE_TH_ERRORS_RET(-1)

View File

@ -3,8 +3,10 @@
#include "torch/csrc/DynamicTypes.h"
#include "torch/csrc/Exceptions.h"
#include "torch/csrc/THP_export.h"
#include "torch/csrc/autograd/function.h"
#include "torch/csrc/autograd/python_variable.h"
#include "torch/csrc/autograd/utils/wrap_outputs.h"
#include "torch/csrc/autograd/variable.h"
#include "torch/csrc/utils/python_compat.h"
#include "torch/csrc/utils/python_numbers.h"
#include "torch/csrc/utils/tensor_new.h"
@ -115,7 +117,7 @@ static Variable valueToTensor(const Type & type, PyObject* value) {
return type.scalarTensor(Scalar(THPUtils_unpackDouble(value)));
}
if (THPModule_isTensor(value)) {
return make_variable(createTensor(value));
return make_variable(createTensor(value), /*requires_grad=*/false);
}
throw TypeError("can't assign a %s to a %s", Py_TYPE(value)->tp_name, type.toString());
}
@ -152,7 +154,7 @@ static Variable applySlicing(const Variable& self, PyObject* index, variable_lis
} else if (THPVariable_Check(obj)) {
handle_var(reinterpret_cast<THPVariable*>(obj)->cdata);
} else if (THPModule_isTensor(obj)) {
handle_var(make_variable(createTensor(obj)));
handle_var(make_variable(createTensor(obj), /*requires_grad=*/false));
} else if (PySequence_Check(obj)) {
handle_var(sequenceToVariable(self.type(), obj));
} else {
@ -270,7 +272,7 @@ PyObject* THPVariable_getitem(PyObject* self, PyObject* index) {
variable_list variableIndices;
Variable sliced = applySlicing(self_, holder.get(), variableIndices);
if (variableIndices.empty()) {
if (sliced.get() == self_.get()) {
if (sliced.is_same(self_)) {
// ensure we return a shallow copy for things like x[...]
sliced = at::alias(sliced);
}

View File

@ -1,49 +1,57 @@
#include "torch/csrc/autograd/saved_variable.h"
#include "torch/csrc/autograd/edge.h"
#include "torch/csrc/autograd/function.h"
#include "torch/csrc/autograd/variable.h"
#include "torch/csrc/jit/tracer_state.h"
using namespace at;
#include <ATen/Tensor.h>
#include <cstdint>
#include <list>
#include <memory>
namespace torch { namespace autograd {
SavedVariable::SavedVariable(const Variable& variable, bool is_output)
: SavedVariable() {
if (!variable.defined()) {
return;
}
data = variable.data();
requires_grad = variable.requires_grad();
expected_version = variable.current_version();
version = variable.get()->version_counter.save();
has_grad_fn = !variable.is_leaf();
output_nr = variable.output_nr();
if (!has_grad_fn) {
grad_accumulator = variable.grad_accumulator();
}
if (!is_output) {
_grad_fn = variable.grad_fn();
}
if (variable.tracing_state()) {
tracing_state.reset(new jit::tracer::ValueTracingState(*variable.tracing_state()));
: output_nr_(variable.output_nr()),
requires_grad_(variable.requires_grad()),
has_grad_fn_(!variable.is_leaf()) {
if (variable.defined()) {
was_default_constructed_ = false;
// These copies are all shared_ptr copies, so slightly more expensive.
// Do them here instead of in the init list in case data is undefined.
data_ = variable.data();
if (variable.is_leaf()) {
grad_accumulator_ = variable.grad_accumulator();
} else if (!is_output) {
grad_fn_ = variable.grad_fn();
}
version_counter_ = variable.version_counter();
saved_version_ = version_counter_.current_version();
if (variable.has_tracing_state()) {
tracing_state_.reset(
new jit::tracer::ValueTracingState(variable.tracing_state()));
}
}
}
auto SavedVariable::unpack(std::shared_ptr<Function> saved_for) const -> Variable {
if (!data.defined()) {
if (version.defined()) {
Variable SavedVariable::unpack(std::shared_ptr<Function> saved_for) const {
if (!data_.defined()) {
if (!was_default_constructed_) {
throw std::runtime_error(ERR_BACKWARD_TWICE);
}
return Variable();
}
if (version.is_modified()) {
if (saved_version_ != version_counter_.current_version()) {
throw std::runtime_error(
"one of the variables needed for gradient computation has been "
"modified by an inplace operation");
}
auto grad_fn = _grad_fn;
if (has_grad_fn && !grad_fn) {
auto grad_fn = grad_fn_;
if (has_grad_fn_ && !grad_fn) {
if (!saved_for) {
// If saving the grad_fn would create a circular reference, then it must
// be passed in to the unpack function.
@ -57,20 +65,22 @@ auto SavedVariable::unpack(std::shared_ptr<Function> saved_for) const -> Variabl
// in-place functions on unpacked variables.
Variable var;
if (grad_fn) {
var = make_variable(data, output_nr, std::move(grad_fn));
var = make_variable(data_, Edge(std::move(grad_fn), output_nr_));
} else {
var = make_variable(data, requires_grad);
var = make_variable(data_, requires_grad_);
}
var.version_counter() = version;
var.set_version_counter(saved_version_);
// If a Variable is a leaf (no grad_fn saved), and it requires_grad, then we
// should have saved the grad accumulator. Even if the Variable no longer
// alive, the accumulator should be kept alive by the references in the graph).
if (requires_grad && !var.grad_fn() && grad_accumulator.expired())
// alive, the accumulator should be kept alive by the references in the
// graph).
if (requires_grad_ && !var.grad_fn() && grad_accumulator_.expired())
throw std::logic_error("No grad accumulator for a saved leaf!");
var.get()->grad_accumulator = grad_accumulator;
if (tracing_state)
var.tracing_state().reset(new jit::tracer::ValueTracingState(*tracing_state));
var.set_grad_accumulator(grad_accumulator_);
if (tracing_state_) {
var.set_tracing_state(new jit::tracer::ValueTracingState(*tracing_state_));
}
return var;
}

View File

@ -1,46 +1,55 @@
#pragma once
#include <mutex>
#include <memory>
#include <functional>
#include "torch/csrc/autograd/variable_version.h"
#include "torch/csrc/jit/tracer_state.h"
#include <ATen/ATen.h>
#include "torch/csrc/jit/tracer_state.h"
#include "torch/csrc/autograd/variable.h"
#include "torch/csrc/autograd/variable_version.h"
#include "torch/csrc/Types.h"
#include <cstdint>
#include <list>
#include <memory>
namespace torch { namespace autograd {
struct Variable;
struct Function;
extern const char* ERR_BACKWARD_TWICE;
struct SavedVariable {
SavedVariable()
: data()
, has_grad_fn(false)
, version()
, requires_grad(false)
, expected_version(-1) {}
/// A snapshot of a variable at a certain version. A `SavedVariable` stores
/// enough information to reconstruct a variable from a certain point in time.
class SavedVariable {
public:
SavedVariable() = default;
SavedVariable(const Variable& variable, bool is_output);
SavedVariable(SavedVariable&&) = default;
SavedVariable& operator=(SavedVariable&&) = default;
/// Reconstructs the saved variable. Pass `saved_for` as the gradient
/// function if constructing the `SavedVariable` with it would have caused a
/// circular reference.
Variable unpack(std::shared_ptr<Function> saved_for = nullptr) const;
void reset_data() {
return data_.reset();
}
private:
at::Tensor data_;
at::Tensor data;
// The gradient function associated with this node. If has_grad_fn
// is false, then this is a leaf node. Note that the grad_fn is not saved if
// it would create a circular reference. In that case, the grad_fn must be
// passed in to the unpack function when reconstructing the Variable.
bool has_grad_fn;
std::shared_ptr<Function> _grad_fn;
std::weak_ptr<Function> grad_accumulator;
SavedVersion version;
bool requires_grad;
int expected_version;
int output_nr;
std::unique_ptr<jit::tracer::ValueTracingState> tracing_state;
std::shared_ptr<Function> grad_fn_;
std::weak_ptr<Function> grad_accumulator_;
std::unique_ptr<jit::tracer::ValueTracingState> tracing_state_;
VariableVersion version_counter_;
Variable unpack(std::shared_ptr<Function> saved_for=nullptr) const;
uint32_t saved_version_;
uint32_t output_nr_;
bool was_default_constructed_ = true;
bool requires_grad_;
bool has_grad_fn_;
};
}} // namespace torch::autograd

View File

@ -1,84 +1,82 @@
#include "Python.h"
#include "torch/csrc/autograd/variable.h"
#include "torch/csrc/assertions.h"
#include "torch/csrc/autograd/generated/VariableType.h"
#include "torch/csrc/autograd/generated/Functions.h"
#include "torch/csrc/autograd/edge.h"
#include "torch/csrc/autograd/function.h"
#include "torch/csrc/autograd/functions/accumulate_grad.h"
#include "torch/csrc/autograd/functions/tensor.h"
#include "torch/csrc/autograd/generated/Functions.h"
#include "torch/csrc/autograd/generated/VariableType.h"
#include "torch/csrc/autograd/variable_version.h"
#include "torch/csrc/jit/tracer_state.h"
#include "torch/csrc/utils/auto_unique_ptr.h"
#include <ATen/ATen.h>
#include <list>
#include <memory>
using namespace at;
#include <mutex>
#include <stdexcept>
#include <string>
#include <vector>
namespace torch { namespace autograd {
Variable make_variable(at::Tensor data, std::shared_ptr<Function> grad_fn) {
// TODO: If you ever want to support returning an undefined tensor from
// a function, you'll have to uncomment the line below. Not sure if
// we actually want to support this.
// if (!data.defined()) return Variable();
TORCH_ASSERT(grad_fn);
int output_nr = grad_fn->num_inputs++;
return make_variable(std::move(data), output_nr, std::move(grad_fn));
}
VariableImpl::VariableImpl(Tensor data_, bool requires_grad, int output_nr, std::shared_ptr<Function> grad_fn)
: TensorImpl(VariableType::getType(data_))
, data(std::move(data_))
, grad()
, _grad_fn(std::move(grad_fn))
, version_counter()
, _requires_grad(requires_grad)
, is_view(false)
, output_nr(output_nr)
, pyobj(nullptr) {
TORCH_ASSERTM(!_grad_fn || !_requires_grad, "_requires_grad should be false if grad_fn is set");
Variable::Impl::Impl(at::Tensor data_, bool requires_grad_, Edge gradient_edge_)
: TensorImpl(VariableType::getType(data_)),
data(std::move(data_)),
grad_fn(std::move(gradient_edge_.function)),
requires_grad(requires_grad_),
is_view(false),
output_nr(gradient_edge_.input_nr),
pyobj(nullptr) {
TORCH_ASSERTM(
!grad_fn || !requires_grad,
"_requires_grad should be false if grad_fn is set");
if (!data.defined()) {
throw std::runtime_error("data is undefined");
}
}
VariableImpl::~VariableImpl() {
}
Variable::Impl::~Impl() = default;
const char * VariableImpl::toString() const {
const char* Variable::Impl::toString() const {
return "Variable";
}
IntList VariableImpl::sizes() const {
IntList Variable::Impl::sizes() const {
return data.sizes();
}
IntList VariableImpl::strides() const {
IntList Variable::Impl::strides() const {
return data.strides();
}
int64_t VariableImpl::dim() const {
int64_t Variable::Impl::dim() const {
return data.dim();
}
const char * VariableImpl::typeString() {
const char* Variable::Impl::typeString() {
return "VariableType";
}
void * VariableImpl::unsafeGetTH(bool retain) {
void* Variable::Impl::unsafeGetTH(bool retain) {
return data.unsafeGetTH(retain);
}
std::unique_ptr<at::Storage> VariableImpl::storage() {
std::unique_ptr<at::Storage> Variable::Impl::storage() {
return data.storage();
}
Scalar VariableImpl::localScalar() {
Scalar Variable::Impl::localScalar() {
return data.pImpl->localScalar();
}
std::shared_ptr<Function> VariableImpl::get_grad_accumulator() {
if (_grad_fn) {
throw std::logic_error("get_grad_accumulator() should be only called on leaf Variables");
std::shared_ptr<Function> Variable::Impl::get_grad_accumulator() {
if (grad_fn) {
throw std::logic_error(
"get_grad_accumulator() should be only called on leaf Variables");
}
if (!_requires_grad) {
if (!requires_grad) {
return nullptr;
}
@ -92,11 +90,12 @@ std::shared_ptr<Function> VariableImpl::get_grad_accumulator() {
return result;
}
VariableViewImpl::VariableViewImpl(Variable base_, at::Tensor data_, int output_nr,
std::shared_ptr<Function> grad_fn)
: VariableImpl(std::move(data_), false, output_nr, std::move(grad_fn))
, base(std::move(base_))
, attr_version(0) {
Variable::ViewImpl::ViewImpl(
Variable base_,
at::Tensor data_,
Edge gradient_edge_)
: Variable::Impl(std::move(data_), false, std::move(gradient_edge_)),
base(std::move(base_)) {
TORCH_ASSERTM(base.defined(), "base is undefined");
if (base.is_view()) {
base = base.base();
@ -106,52 +105,71 @@ VariableViewImpl::VariableViewImpl(Variable base_, at::Tensor data_, int output_
attr_version = version_counter.current_version();
}
std::shared_ptr<Function>& VariableViewImpl::get_grad_fn() {
std::shared_ptr<Function>& Variable::ViewImpl::get_grad_fn() {
std::lock_guard<std::mutex> lock(mutex);
if (!_grad_fn && !base.requires_grad()) {
return _grad_fn;
if (!grad_fn && !base.requires_grad()) {
return grad_fn;
}
auto current_version = version_counter.current_version();
if (attr_version != current_version) {
TORCH_ASSERT(output_nr == 0);
auto fn = std::make_shared<generated::AsStridedBackward>();
fn->self_geometry = TensorGeometry(base);
fn->self_geometry = at::TensorGeometry(base);
fn->size = sizes();
fn->stride = strides();
fn->storage_offset = data.storage_offset();
fn->set_next_functions(get_next_functions(base));
fn->num_inputs = 1;
_grad_fn = std::move(fn);
grad_fn = std::move(fn);
attr_version = current_version;
}
return _grad_fn;
return grad_fn;
}
void VariableViewImpl::rebase_history(int output_nr, std::shared_ptr<Function> grad_fn) {
TORCH_ASSERT(output_nr == 0);
TORCH_ASSERT(grad_fn);
TORCH_ASSERTM(grad_fn->num_inputs == 1, "Functions which modify views in-place must return a single Variable");
this->output_nr = output_nr;
base.output_nr() = 0;
base.get()->_grad_fn = std::make_shared<CopySlices>(
base, TensorGeometry(data), std::move(grad_fn));
get_grad_fn(); // trigger an update to the view's grad_fn
void Variable::ViewImpl::rebase_history(Edge gradient_edge) {
TORCH_ASSERT(gradient_edge.input_nr == 0);
TORCH_ASSERT(gradient_edge.function);
TORCH_ASSERTM(
gradient_edge.function->num_inputs == 1,
"Functions which modify views in-place must return a single Variable");
this->output_nr = gradient_edge.input_nr;
auto copy_slices = std::make_shared<CopySlices>(
base, at::TensorGeometry(data), std::move(gradient_edge.function));
base.set_gradient_edge({std::move(copy_slices), 0});
get_grad_fn(); // trigger an update to the view's grad_fn
}
void Variable::rebase_history(Edge gradient_edge) {
TORCH_ASSERT(gradient_edge.function != nullptr);
if (is_view()) {
auto& impl = static_cast<Variable::ViewImpl&>(*get());
impl.rebase_history(std::move(gradient_edge));
} else {
set_gradient_edge(std::move(gradient_edge));
}
}
Variable Variable::detach() const {
Variable detached = make_variable(data());
detached.version_counter() = version_counter();
auto detached = make_variable(data(), /*requires_grad=*/false);
detached.set_version_counter(version_counter());
return detached;
}
void Variable::detach_() {
if (is_view()) {
throw std::runtime_error("Can't detach views in-place. Use detach() instead");
throw std::runtime_error(
"Can't detach views in-place. Use detach() instead");
}
get()->_requires_grad = false;
output_nr() = 0;
get()->_grad_fn = nullptr;
set_requires_grad(false);
set_gradient_edge(Edge());
}
void Variable::set_tracing_state(
jit::tracer::ValueTracingState* new_tracing_state) {
get()->tracing_state.reset(new_tracing_state);
}
jit::tracer::ValueTracingState& Variable::tracing_state() const noexcept {
return *get()->tracing_state;
}
}} // namespace torch::autograd

View File

@ -1,46 +1,187 @@
#pragma once
// A wrapper around at::Tensor to represent autograd Variables. Variables
// can be implicitly converted to an at::Tensor.
#include <Python.h>
#include "torch/csrc/autograd/edge.h"
#include "torch/csrc/autograd/function_hook.h"
#include "torch/csrc/autograd/variable_version.h"
#include "torch/csrc/utils/auto_unique_ptr.h"
#include <mutex>
#include <memory>
#include <vector>
#include <functional>
#include <ATen/ATen.h>
#include "torch/csrc/assertions.h"
#include "torch/csrc/jit/ir.h"
#include "torch/csrc/jit/tracer_state.h"
#include "torch/csrc/autograd/function_hook.h"
#include "torch/csrc/utils/auto_unique_ptr.h"
#include "torch/csrc/autograd/variable_version.h"
#include "torch/csrc/autograd/edge.h"
#include "torch/csrc/Types.h"
#include <list>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>
#include <vector>
namespace torch {
namespace autograd {
struct Function;
} // namespace autograd
namespace jit { namespace tracer {
// Has to be forward declared because tracer_state.h has a dependency on
// variable.h.
struct ValueTracingStateElem;
using ValueTracingState = std::list<ValueTracingStateElem>;
}} // namespace jit::tracer
} // namespace torch
namespace torch { namespace autograd {
using at::Tensor;
struct VariableImpl;
///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Variable
///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// A `Variable` augments a `Tensor` with the ability to interact in our
/// autograd machinery. Conceptually, `Variable`s travel along `Edge`s between
/// `Function`s in the autograd graph. A `Variable` can either be a leaf, like a
/// weight in a neural network, or an interior variable, when it is the result
/// of an operation between variables. Every `Variable` also stores another
/// `Variable` called its `grad` (gradient). If the variable is a leaf, its
/// gradient will be accumulated into this variable.
///
/// Gradient Edges
///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Furthermore, `Variable`s have the notion of a `gradient_edge`, which is the
/// edge in the autograd graph that connects the variable to a particular input
/// of the gradient function that will be invoked with the variable during the
/// backward pass. More precisely, this gradient function can be one of two
/// things:
/// 1. A `grad_fn`, if the variable is in the interior of the graph. This is the
/// gradient of the function that produced the variable.
/// 2. A `grad_accumulator`, if the variable is a leaf, which accumulates a
/// scalar gradient value into its `grad` variable.
///
/// Versioning
///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Another major feature of `Variable`s are *versions*. Versions are
/// incremented when an in-place mutation of a variable occurs. Versions are
/// useful when constructing `SavedVariable`s, which take a snapshot of a
/// `Variable` at a certain version. You can retrieve a `Variable`'s version
/// through its `current_version()` method.
///
/// Views
///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// It is possible for a `Variable` to be a *view* of another `Variable`, in
/// which case it tracks that `Variable`'s data and autograd history. Beyond
/// construction, the interface of a view is identical to that of a regular
/// `Variable`. You can determine whether `Variable` is in fact a view by
/// probing its `is_view()` method.
///
/// Interface
///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// `Variable` inherits from `Tensor` and thus its API is a superset of that of
/// `Tensor`. This means you can perform all the usual mathematical and other
/// operations you can perform on `Tensor`s also on `Variable`s. Furthermore,
/// `Variable` and `Tensor` actually convert implicitly between each other. You
/// can thus call functions defined on `Tensor`s also with `Variable`s. For
/// this, the `Variable` class allows implicit construction from `Tensor`. It is
/// the responsibility of calling code to ensure that this constructor is
/// invoked only when the `Tensor`'s dynamic type is actually `Variable`. Most
/// notably, it is *not* correct to construct a brand new `Variable` from a
/// `Tensor` using this constructor. To do so, you must use the `make_variable`
/// free function instead. To create a view variable, use `make_variable_view`.
///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
struct Variable : public at::Tensor {
inline Variable(VariableImpl * self, bool retain);
Variable() : Tensor() {}
Variable(const Variable & rhs) : Tensor(rhs) {}
Variable(Variable && rhs) noexcept : Tensor(std::move(rhs)) {}
/// Default constructor.
Variable() = default;
// Implicitly casts a Tensor to a Variable. This should only be called on
// Tensors which you know are actually Variables.
/*implicit*/ Variable(Tensor const & rhs) : Tensor(rhs) {}
/*implicit*/ Variable(Tensor && rhs) noexcept : Tensor(std::move(rhs)) {}
// Factory Functions
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// NOTE: These factory functions have to be friends to access the
// `Variable::Impl`. As a side effect, it allows us to keep them in the class.
/// Creates a `Variable` that is a *view* of another (*base*) variable.
/// The `gradient_edge` is an optional (gradient_function, input_number) pair.
friend Variable
make_variable_view(Variable base, at::Tensor data, Edge gradient_edge);
/// Creates a `Variable` from the given `Tensor`. `requires_grad` should be
/// set only for leaves, and determines whether the `Variable` will accumulate
/// gradients. NOTE: `data` must *not* be a `Variable` already. Its dynamic
/// type *must* be `Tensor`.
friend Variable make_variable(at::Tensor data, bool requires_grad);
/// Creates a `Variable` from the given `Tensor` and specify a `gradient_edge`,
/// i.e. a (function, input_nr) pair specifying the function in the autograd
/// graph, and what particular input of that function, this variable is
/// connected to.
friend Variable make_variable(at::Tensor data, Edge gradient_edge);
// Tensor Conversions
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// "Downcasts" a `Tensor` into a `Variable`. Only call this on tensors you
// know are Variables.
/*implicit*/ Variable(at::Tensor const& rhs) : at::Tensor(rhs) {}
/*implicit*/ Variable(at::Tensor&& rhs) noexcept
: at::Tensor(std::move(rhs)) {}
// NOTE: Assignment operators to Tensor come for free from the constructors.
/// Downcasts the `Tensor` reference to a `Variable` reference. If compiling
/// in DEBUG mode and the tensor's dynamic type is not in fact `Variable`,
/// throws a `std::runtime_error` exception.
/// NOTE: Has to be a friend function because runtime type information is
/// available only for `TensorImpl`/`Impl` and not the `Tensor`/`Variable`
/// classes, as the latter are not polymorphic classes (`Tensor` has no
/// virtual methods).
friend Variable& as_variable_ref(at::Tensor& tensor);
const at::Tensor& data() const noexcept;
at::Tensor& data() noexcept;
// Gradient Function and Edges
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Gets the gradient function of the `Variable`. If this is a leaf variable,
/// the pointer returned will be null.
const std::shared_ptr<Function>& grad_fn() const;
/// Gets the raw gradient function pointer, whatever it currently is.
Function* grad_fn_unsafe() const;
/// Sets the gradient accumulator of the `Variable`. This is only applicable
/// to leaf variables. Interior variables should call `set_gradient_edge()`.
void set_grad_accumulator(std::weak_ptr<Function> grad_accumulator);
/// Attempts to get a pointer to the gradient accumulator of the `Variable`,
/// if it still exists. If the gradient accumulator function has been
/// destroyed, returns a `nullptr`.
std::shared_ptr<Function> try_get_grad_accumulator() const;
/// Gets the gradient accumulator of the `Variable` if it has one, or else
/// create one on the fly and return it.
std::shared_ptr<Function> grad_accumulator() const;
/// Sets the gradient edge -- i.e. `grad_fn` and `input_nr` -- of the
/// `Variable`.
/// NOTE: This will always set the `grad_fn`, even if this is a leaf
/// variable, and never the `grad_accumulator`. For the latter, use
/// `set_grad_accumulator`. This allows late construction of an interior
/// `Variable`.
/// You will likely want to call `rebase_history()` if this call is involved
/// in an in-place modification of a `Variable`.
void set_gradient_edge(Edge&& edge) noexcept;
/// Returns the "canonical" gradient edge of this `Variable`, i.e. either the
/// gradient function if this is an interior `Variable`, or the gradient
/// accumulator otherwise. If the `Variable` is interior, the returned `Edge`
/// will store the input index of the `Function` to which this variable is
/// connected in its `input_nr` field. For leaves, the `input_nr` is always
/// zero. Note that `set_gradient_edge` and `gradient_edge` are not
/// symmetric. You must use `set_gradient_edge` to set the `grad_fn` and
/// `set_grad_accumulator` to set the accumulator.
Edge gradient_edge() const {
// If grad_fn is null (as is the case for a leaf node), we instead
// interpret the gradient function to be a grad accumulator,
// which will accumulate its inputs into the grad property of the
// variable. These nodes get suppressed in some situations,
// see "suppress grad accumulation" below. Note that only variables which
// have `requires_grad = True` can have grad accumulators.
// interpret the gradient function to be a gradient accumulator, which will
// accumulate its inputs into the grad property of the variable. These
// nodes get suppressed in some situations, see "suppress gradient
// accumulation" below. Note that only variables which have `requires_grad =
// True` can have gradient accumulators.
if (const auto& gradient = grad_fn()) {
return Edge(gradient, output_nr());
} else {
@ -48,282 +189,434 @@ struct Variable : public at::Tensor {
}
}
inline VariableImpl* get() const;
/// Returns the input index of the gradient `Function` to which this `Variable`
/// is connected.
uint32_t output_nr() const noexcept;
inline const Tensor & data() const;
inline Tensor & data();
/// True if this `Variable` is a leaf and thus does not have a `grad_fn`.
bool is_leaf() const noexcept;
inline Tensor opt_data() const;
// The Grad Variable
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
inline const Variable & grad() const;
inline Variable & grad();
/// Accesses the gradient `Variable` of this `Variable`.
const Variable& grad() const noexcept;
Variable& grad() noexcept;
void reset_grad() noexcept;
inline bool is_leaf() const;
/// Sets the `requires_grad` property of `Variable`. This should be true for
/// leaf variables that want to accumulate gradients, and false for all other
/// variables.
void set_requires_grad(bool requires_grad) noexcept;
bool requires_grad() const noexcept;
inline const std::shared_ptr<Function>& grad_fn() const;
// Versions
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Updates the grad_fn of an existing Variable. Called after in-place modifications.
// XXX: this should be called only _after_ the version counter is implemented.
inline void rebase_history(int output_nr, std::shared_ptr<Function> grad_fn);
/// Increments the version count of this `Variable`.
void bump_version() noexcept;
void set_version_counter(const VariableVersion& version_counter) noexcept;
std::shared_ptr<Function> grad_accumulator() const;
/// Retrieves this `Variable`s version counter.
const VariableVersion& version_counter() const noexcept;
/// Retrieves the current value of the `Variable`'s version counter.
/// Equivalent to calling `version_counter().current_version()`.
uint32_t current_version() const noexcept;
// Autograd Graph Interaction
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Update the `grad_fn` of an existing Variable. Called after in-place
/// modifications.
void rebase_history(Edge gradient_edge);
/// Returns a copy of this `Variable` that is detached from its autograd graph
/// and has a blank version. This method is OK to call if the `Variable` is a
/// view.
Variable detach() const;
/// Like `detach()`, but removes this `Variable` in-place. This method may
/// only be called on non-view `Variable`s. You can use `is_view()` to check
/// this. If this `Variable` is a view, throws an `std::runtime_error()`.
void detach_();
inline const std::vector<std::shared_ptr<FunctionPreHook>>& hooks() const;
inline std::vector<std::shared_ptr<FunctionPreHook>>& hooks();
// Hooks
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
inline auto_unique_ptr<jit::tracer::ValueTracingState>& tracing_state() const;
void add_hook(std::shared_ptr<FunctionPreHook> hook);
const std::vector<std::shared_ptr<FunctionPreHook>>& hooks() const noexcept;
void clear_hooks();
inline int current_version() const;
// JIT Tracing
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
inline VariableVersion& version_counter() const;
void set_tracing_state(jit::tracer::ValueTracingState* new_tracing_state);
jit::tracer::ValueTracingState& tracing_state() const noexcept;
inline const int& output_nr() const;
inline int& output_nr();
/// Returns true if the `Variable`'s tracing state is not null.
bool has_tracing_state() const noexcept;
inline bool requires_grad() const;
// View Variables
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
inline bool is_view() const;
inline Variable& base() const;
/// Returns true if this `Variable` is a view of another `Variable`.
bool is_view() const noexcept;
inline const std::string& name() const;
inline std::string& name();
/// Returns the `Variable` that this `Variable` is a view of. If this
/// `Variable` is not a view, throw a `std::runtime_error`.
const Variable& base() const;
inline Variable & operator=(Variable && rhs) &;
inline Variable & operator=(const Variable & rhs) &;
inline Variable & operator=(Tensor && rhs) &;
inline Variable & operator=(const Tensor & rhs) &;
// Miscellaneous
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Compares this `Variable` to another `Variable` (or `Tensor`) via
/// pointer-equality.
bool is_same(const Variable& other) const noexcept {
return this->pImpl == other.pImpl;
}
void set_name(const std::string& name);
const std::string& name() const noexcept;
PyObject* pyobj() const noexcept;
void set_pyobj(PyObject* pyobj) noexcept;
// Hacks!
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Sets the type of the underlying `Tensor`. Used for a bad (hopefully)
/// temporary hack in python_variable.h. If removed, also remove the `using
/// at::TensorImpl::type_;` in `Variable::Impl`.
void temporary_hack_set_type(at::Type*) noexcept;
private:
/// Private implementation struct of the `Variable`. This struct declaration
/// and the `get()` method which exposes it shall forever remain private and
/// never be exposed to the public interface of this class.
struct Impl;
struct ViewImpl;
// Private Methods
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Variable(Variable::Impl* self, bool retain);
Impl* get() const noexcept;
};
struct VariableImpl : public at::TensorImpl {
public:
VariableImpl(at::Tensor data, bool requires_grad=false, int output_nr=0,
std::shared_ptr<Function> grad_fn=nullptr);
virtual ~VariableImpl();
virtual const char * toString() const override;
virtual at::IntList sizes() const override;
virtual at::IntList strides() const override;
virtual int64_t dim() const override;
virtual at::Scalar localScalar() override;
virtual void * unsafeGetTH(bool retain) override;
virtual std::unique_ptr<at::Storage> storage() override;
static const char * typeString();
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Variable::Impl
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
struct Variable::Impl : public at::TensorImpl {
explicit Impl(
at::Tensor data_,
bool requires_grad_ = false,
Edge edge = Edge());
virtual ~Impl();
const char* toString() const override;
at::IntList sizes() const override;
at::IntList strides() const override;
int64_t dim() const override;
at::Scalar localScalar() override;
void* unsafeGetTH(bool retain) override;
std::unique_ptr<at::Storage> storage() override;
static const char* typeString();
public:
std::shared_ptr<Function> get_grad_accumulator();
virtual std::shared_ptr<Function>& get_grad_fn() { return _grad_fn; }
virtual std::shared_ptr<Function>& get_grad_fn() {
return grad_fn;
}
// Make this field public so we can access it from `Variable`. Part of
// temporary_hack_set_type.
using at::TensorImpl::type_;
std::string name;
at::Tensor data;
Variable grad;
std::shared_ptr<Function> _grad_fn;
std::shared_ptr<Function> grad_fn;
std::weak_ptr<Function> grad_accumulator;
VariableVersion version_counter;
std::vector<std::shared_ptr<FunctionPreHook>> hooks;
std::weak_ptr<Function> grad_accumulator;
// Mutex to ensure that concurrent read operations that modify internal state
// are still thread-safe. Used by get_grad_fn and get_grad_accumulator.
std::mutex mutex;
bool _requires_grad; // only meaningful on leaf variables (must be false otherwise)
bool requires_grad; // only meaningful on leaf variables (must be false
// otherwise)
bool is_view;
// The "output number" of this variable; e.g., if this variable
// was the second output of a function, then output_nr == 1.
// We use this to make sure we can setup the backwards trace
// correctly when this variable is passed to another function.
int output_nr;
PyObject *pyobj; // weak reference
uint32_t output_nr;
PyObject* pyobj; // weak reference
std::string name;
// Mutex to ensure that concurrent read operations that modify internal
// state are still thread-safe. Used by get_grad_fn and
// get_grad_accumulator.
std::mutex mutex;
// For use in torch::jit::tracer
auto_unique_ptr<jit::tracer::ValueTracingState> tracing_state;
friend struct VariableType;
};
// A Variable that is a view on another Variable. The base and view share the
// same version_counter. The _grad_fn field of the Variable may become stale
// due to in-place modifications of the shared data. Accesses should go through
// get_grad_fn(). All other fields are always valid.
struct VariableViewImpl : public VariableImpl {
VariableViewImpl(Variable base, at::Tensor data, int output_nr, std::shared_ptr<Function> grad_fn);
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Variable::ViewImpl
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Gets the up-to-date grad_fn. If the shared data or base was modified, we
// re-create the grad_fn to express the up-to-date view relationship between
// this and the base Variable.
/// A Variable that is a view on another Variable. The base and view share the
/// same version_counter. The grad_fn field of the Variable may become stale
/// due to in-place modifications of the shared data. Accesses should go
/// through get_grad_fn(). All other fields are always valid.
struct Variable::ViewImpl : public Variable::Impl {
ViewImpl(Variable base_, at::Tensor data_, Edge gradient_edge);
/// Gets the up-to-date grad_fn. If the shared data or base was modified, we
/// re-create the grad_fn to express the up-to-date view relationship between
/// this and the base Variable.
virtual std::shared_ptr<Function>& get_grad_fn() override;
// Called after in-place modifications. Modifies the grad_fn of the base
// Variable.
void rebase_history(int output_nr, std::shared_ptr<Function> grad_fn);
/// Called after in-place modifications. Modifies the grad_fn of the base
/// Variable.
void rebase_history(Edge gradient_edge);
// The base Variable (never a view)
/// The base `Variable` (never a view).
Variable base;
// The value of the version_counter at the time grad_fn was created. The
// _grad_fn field is stale if attr_version != version_counter.current_version()
int attr_version;
/// The value of the version_counter at the time grad_fn was created. The
/// grad_fn field is stale if attr_version !=
/// version_counter.current_version().
uint32_t attr_version;
};
inline Variable make_variable(at::Tensor data, bool requires_grad=false) {
if (!data.defined()) {
return Variable();
}
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Variable Implementation
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
namespace detail {
inline at::Tensor handle_scalars(at::Tensor& data) {
#ifndef WITH_SCALARS
if (data.dim() == 0) {
// don't expose 0-dim tensors to Variable API.
data = data.as_strided_({1}, {1});
// Don't expose 0-dim tensors to Variable API.
return data.as_strided_({1}, {1});
}
#endif
return data;
}
} // namespace detail
return Variable(new VariableImpl(std::move(data), requires_grad), false);
// Factory Functions
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
inline Variable make_variable_view(
Variable base,
at::Tensor data,
Edge gradient_edge = Edge()) {
if (data.defined()) {
data = detail::handle_scalars(data);
auto impl = new Variable::ViewImpl(
std::move(base), std::move(data), std::move(gradient_edge));
return Variable(impl, /*retain=*/false);
}
return Variable();
}
inline Variable make_variable(at::Tensor data, int output_nr, std::shared_ptr<Function> grad_fn) {
if (!data.defined()) {
return Variable();
inline Variable make_variable(at::Tensor data, bool requires_grad) {
if (data.defined()) {
auto impl = new Variable::Impl(detail::handle_scalars(data), requires_grad);
return Variable(impl, /*retain=*/false);
}
return Variable();
}
#ifndef WITH_SCALARS
if (data.dim() == 0) {
// don't expose 0-dim tensors to Variable API.
data = data.as_strided_({1}, {1});
inline Variable make_variable(at::Tensor data, Edge gradient_edge) {
if (data.defined()) {
auto impl = new Variable::Impl(
detail::handle_scalars(data), false, std::move(gradient_edge));
return Variable(impl, /*retain=*/false);
}
return Variable();
}
// Tensor Conversion
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
inline Variable& as_variable_ref(at::Tensor& tensor) {
#ifdef DEBUG
// dynamic_cast will return a nullptr if the `TensorImpl`'s dynamic type is
// not `Variable::Impl`.
if (dynamic_cast<Variable::Impl*>(tensor.get()) == nullptr) {
throw std::runtime_error(
"Attempted to cast a Tensor to a Variable, but "
"the dynamic type of the value is not Variable.");
}
#endif
return Variable(new VariableImpl(std::move(data), false, output_nr, std::move(grad_fn)), false);
return static_cast<Variable&>(tensor);
}
Variable make_variable(at::Tensor data, std::shared_ptr<Function> grad_fn);
inline Variable make_variable_view(Variable base, at::Tensor data, int output_nr=0,
std::shared_ptr<Function> grad_fn=nullptr) {
if (!data.defined()) {
return Variable();
}
#ifndef WITH_SCALARS
if (data.dim() == 0) {
// don't expose 0-dim tensors to Variable API.
data = data.as_strided_({1}, {1});
}
#endif
return Variable(new VariableViewImpl(std::move(base), std::move(data), output_nr, std::move(grad_fn)), false);
}
inline Variable::Variable(VariableImpl * self, bool retain) : Tensor(self, retain) {
}
inline VariableImpl* Variable::get() const {
return static_cast<VariableImpl*>(pImpl);
}
inline const Tensor & Variable::data() const {
return get()->data;
}
inline Tensor & Variable::data() {
inline const at::Tensor& Variable::data() const noexcept {
return get()->data;
}
inline Tensor Variable::opt_data() const {
if (!defined()) {
return Tensor();
}
return data();
inline at::Tensor& Variable::data() noexcept {
return get()->data;
}
inline const Variable & Variable::grad() const {
return get()->grad;
}
inline Variable & Variable::grad() {
return get()->grad;
}
inline bool Variable::is_leaf() const {
return get()->_grad_fn == nullptr;
}
// Gradient Function and Edges
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
inline const std::shared_ptr<Function>& Variable::grad_fn() const {
return get()->get_grad_fn();
};
inline void Variable::rebase_history(int output_nr, std::shared_ptr<Function> grad_fn) {
TORCH_ASSERT(grad_fn);
if (is_view()) {
auto& impl = static_cast<VariableViewImpl&>(*get());
impl.rebase_history(output_nr, std::move(grad_fn));
} else {
get()->output_nr = output_nr;
get()->_grad_fn = std::move(grad_fn);
}
}
inline Function* Variable::grad_fn_unsafe() const {
return get()->grad_fn.get();
}
inline void Variable::set_grad_accumulator(
std::weak_ptr<Function> grad_accumulator) {
get()->grad_accumulator = std::move(grad_accumulator);
}
inline std::shared_ptr<Function> Variable::try_get_grad_accumulator() const {
return get()->grad_accumulator.lock();
}
inline std::shared_ptr<Function> Variable::grad_accumulator() const {
return get()->get_grad_accumulator();
};
}
inline const std::vector<std::shared_ptr<FunctionPreHook>>& Variable::hooks() const {
return get()->hooks;
};
inline std::vector<std::shared_ptr<FunctionPreHook>>& Variable::hooks() {
return get()->hooks;
};
inline void Variable::set_gradient_edge(Edge&& edge) noexcept {
get()->grad_fn = std::move(edge.function);
get()->output_nr = edge.input_nr;
}
inline auto_unique_ptr<jit::tracer::ValueTracingState>& Variable::tracing_state() const {
return get()->tracing_state;
};
inline uint32_t Variable::output_nr() const noexcept {
return get()->output_nr;
}
inline int Variable::current_version() const {
inline bool Variable::is_leaf() const noexcept {
return get()->grad_fn == nullptr;
}
// The Grad Variable
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
inline const Variable& Variable::grad() const noexcept {
return get()->grad;
}
inline Variable& Variable::grad() noexcept {
return get()->grad;
}
inline void Variable::reset_grad() noexcept {
get()->grad.reset();
}
inline void Variable::set_requires_grad(bool requires_grad) noexcept {
get()->requires_grad = requires_grad;
}
inline bool Variable::requires_grad() const noexcept {
return get()->requires_grad || get()->grad_fn ||
(is_view() && base().requires_grad());
}
// Versions
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
inline void Variable::set_version_counter(
const VariableVersion& version_counter) noexcept {
get()->version_counter = version_counter;
}
inline void Variable::bump_version() noexcept {
get()->version_counter.bump();
}
inline uint32_t Variable::current_version() const noexcept {
return get()->version_counter.current_version();
}
inline VariableVersion& Variable::version_counter() const {
inline const VariableVersion& Variable::version_counter() const noexcept {
return get()->version_counter;
}
inline const int& Variable::output_nr() const {
return get()->output_nr;
// Hooks
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
inline void Variable::add_hook(std::shared_ptr<FunctionPreHook> hook) {
get()->hooks.push_back(std::move(hook));
}
inline int& Variable::output_nr() {
return get()->output_nr;
inline const std::vector<std::shared_ptr<FunctionPreHook>>& Variable::hooks()
const noexcept {
return get()->hooks;
}
inline bool Variable::requires_grad() const {
return get()->_requires_grad || get()->_grad_fn || (is_view() && base().requires_grad());
inline void Variable::clear_hooks() {
get()->hooks.clear();
}
inline const std::string& Variable::name() const {
return get()->name;
}
inline std::string& Variable::name() {
return get()->name;
// JIT Tracing
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
inline bool Variable::has_tracing_state() const noexcept {
return get()->tracing_state != nullptr;
}
inline bool Variable::is_view()const {
// View Variables
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
inline bool Variable::is_view() const noexcept {
return get()->is_view;
}
inline Variable& Variable::base() const {
inline const Variable& Variable::base() const {
if (is_view()) {
return static_cast<VariableViewImpl&>(*get()).base;
return static_cast<Variable::ViewImpl*>(get())->base;
}
throw std::runtime_error("Can't get base of non-view");
}
inline Variable & Variable::operator=(Variable && rhs) & {
rhs.swap(*this);
return *this;
// Miscellaneous
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
inline void Variable::set_name(const std::string& name) {
get()->name = name;
}
inline Variable & Variable::operator=(const Variable & rhs) & {
Variable(rhs).swap(*this);
return *this;
inline const std::string& Variable::name() const noexcept {
return get()->name;
}
inline Variable & Variable::operator=(Tensor && rhs) & {
rhs.swap(*this);
return *this;
inline void Variable::set_pyobj(PyObject* pyobj) noexcept {
get()->pyobj = pyobj;
}
inline Variable & Variable::operator=(const Tensor & rhs) & {
Variable(rhs).swap(*this);
return *this;
inline PyObject* Variable::pyobj() const noexcept {
return get()->pyobj;
}
// Hacks!
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
inline void Variable::temporary_hack_set_type(at::Type* new_type) noexcept {
get()->type_ = new_type;
}
// Private Methods
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
inline Variable::Variable(Variable::Impl* self, bool retain)
: at::Tensor(self, retain) {}
inline Variable::Impl* Variable::get() const noexcept {
return static_cast<Variable::Impl*>(pImpl);
}
}} // namespace torch::autograd

View File

@ -1,96 +1,38 @@
#pragma once
#include <atomic>
#include <cstdint>
#include <memory>
// Every Variable has a version counter. Version counters are incremented
// whenever the data or shape of a tensor changes through Variable operations.
// whenever the data or shape of a tensor changes through Variable operations.
// These are typicallly in-place operations. Version counters are used to
// detect modifications to saved varaibles which would result in incorrect
// detect modifications to saved variables which would result in incorrect
// gradient calculations. Version counters may be shared between Variables:
//
// 1. A view shares the version counter of the base Variable
// 2. Detached variables share the version counter of the source
// 3. Unpacked saved variables share the version counter of the source
// 1. A view shares the version counter of the base Variable,
// 2. Detached variables share the version counter of the source,
// 3. Unpacked saved variables share the version counter of the source.
namespace torch { namespace autograd {
struct VersionBlock {
VersionBlock() : version() {}
// monotonically increasing version
std::atomic<int> version;
};
struct SavedVersion;
struct VariableVersion {
VariableVersion() : version_block(std::make_shared<VersionBlock>()) {}
VariableVersion(const VariableVersion&) = delete;
VariableVersion(VariableVersion&&) = delete;
public:
// NOTE: As of C++11 and 14, default-constructing a std::atomic variable
// leaves it in a persistently undefined state. See
// https://cplusplus.github.io/LWG/issue2334.
VariableVersion(uint32_t version = 0)
: version_block_(std::make_shared<std::atomic<uint32_t>>(version)) {}
// increment the version counter
void increment() { version_block->version++; }
// current version
int current_version() const { return version_block->version.load(); }
// creates a saved reference with the current version and the counter
inline SavedVersion save() const;
// Uses another variable's version counter. Used for variables which share storages
// NOTE: not thread-safe to call this from multiple threads without synchronization
// because shared_ptr assignment isn't thread-safe.
VariableVersion& operator=(const VariableVersion& other) {
version_block = other.version_block;
return *this;
void bump() noexcept {
version_block_->fetch_add(1);
}
// Uses the version counter from a SavedVariable
// NOTE: not thread-safe to call this from multiple threads without synchronization
inline VariableVersion& operator=(const SavedVersion& other);
uint32_t current_version() const noexcept {
return version_block_->load();
}
private:
friend struct SavedVersion;
std::shared_ptr<VersionBlock> version_block; // always non-null
private:
std::shared_ptr<std::atomic<uint32_t>> version_block_;
};
// The version counter used in SavedVariables. Saves the expected_version (the
// version at the time of save) and a reference to the version counter's
// version_block.
struct SavedVersion {
SavedVersion() {}
SavedVersion(const VariableVersion& version)
: expected_version(version.current_version())
, version_block(version.version_block) {}
// if the version_block has been modified since when it was saved
bool is_modified() const {
return expected_version != version_block->version.load();
}
// true if the version_block is defined
bool defined() const {
return static_cast<bool>(version_block);
}
private:
friend struct VariableVersion;
int expected_version;
std::shared_ptr<VersionBlock> version_block; // may be null
};
SavedVersion VariableVersion::save() const {
return SavedVersion(*this);
}
VariableVersion& VariableVersion::operator=(const SavedVersion& other) {
if (!other.version_block) {
throw std::runtime_error(
"Can't take version counter from empty SavedVersion. File a bug report.");
}
version_block = other.version_block;
return *this;
}
}} // namespace torch::autograd

View File

@ -15,6 +15,7 @@
#include "torch/csrc/jit/passes/batch_mm.h"
#include "torch/csrc/autograd/function.h"
#include "torch/csrc/autograd/edge.h"
#include <unordered_map>
@ -82,26 +83,26 @@ private:
// inplace to avoid allocations
tensor_list unwrapVariables(variable_tensor_list && list) const {
for(auto & v : list) {
v = v.defined() ? static_cast<Variable&>(v).data() : at::Tensor();
v = v.defined() ? autograd::as_variable_ref(v).data() : at::Tensor();
}
return std::move(list);
}
// inplace to avoid allocations
variable_tensor_list wrapTensors(tensor_list && list) const {
for(auto & v : list) {
v = autograd::make_variable(v);
v = autograd::make_variable(v, /*requires_grad=*/false);
}
return variable_tensor_list(std::move(list));
}
// Capture (save) inputs that would be required to subsequently run backwards
void captureInputs(ExecutionPlanAutogradFunction & grad_fn, variable_tensor_list & inputs) const {
for(auto offset : grad.df_input_captured_inputs) {
grad_fn.captures.emplace_back(static_cast<Variable&>(inputs[offset]), false);
grad_fn.captures.emplace_back(autograd::as_variable_ref(inputs[offset]), false);
}
}
void captureOutputs(ExecutionPlanAutogradFunction & grad_fn, variable_tensor_list & outputs) const {
for(auto offset : grad.df_input_captured_outputs) {
grad_fn.captures.emplace_back(static_cast<Variable&>(outputs[offset]), true);
grad_fn.captures.emplace_back(autograd::as_variable_ref(outputs[offset]), true);
}
}
@ -111,7 +112,7 @@ private:
// hook up the outputs of df to the gradient functions of the inputs that require
// gradients
for(auto idx : grad.df_output_vjps) {
auto & v = static_cast<Variable&>(inputs[idx]);
auto & v = autograd::as_variable_ref(inputs[idx]);
// TODO: this kinda stuff is _way_ to low level to the public API of variable.
// Why do I have to care here whether v has a grad_fn or grad accumulator?
// Why do I have to care here about output_nr? I just want to say
@ -133,15 +134,12 @@ private:
// this is currently intentionally not done here so we can get an idea of our
// perf before introducing overhead for correctness
for(auto idx : grad.df_input_vjps) {
auto & o = static_cast<Variable&>(outputs[idx]);
auto impl = o.get();
// Note: we have to set this up in place, or we have to
// throw away and reallocate variables that were already created in
// wrapTensors. We should add an API for this, and more generally
// we need to clean up the fields of Variable.
impl->_grad_fn = grad_fn;
impl->output_nr = grad_fn->num_inputs++;
impl->_requires_grad = true;
// Note: we have to set this up in place, or we have to throw away and
// reallocate variables that were already created in wrapTensors. We
// should add an API for this.
auto& output = autograd::as_variable_ref(outputs[idx]);
output.set_gradient_edge(autograd::Edge(grad_fn, grad_fn->num_inputs++));
output.set_requires_grad(true);
}
captureOutputs(*grad_fn, outputs);
// drop the temporary outputs so that we return the same number of

View File

@ -1,11 +1,13 @@
#include "Python.h"
#include "interpreter.h"
#include "torch/csrc/jit/ir.h"
#include "torch/csrc/autograd/profiler.h"
#include "torch/csrc/jit/generated/aten_dispatch.h"
#include "torch/csrc/jit/pybind.h"
#include "torch/csrc/utils/auto_gil.h"
#include "torch/csrc/autograd/variable.h"
#include "torch/csrc/autograd/edge.h"
#include "torch/csrc/autograd/python_variable.h"
#include "torch/csrc/autograd/python_engine.h"
#include "torch/csrc/autograd/functions/special.h"
@ -62,12 +64,12 @@ struct HandleBuilder {
}
autograd::Variable addInput(at::Retainable* input, const VariableFlags & flags_) {
if(handle && flags_.requires_grad) {
auto gradient_edge = autograd::Edge(
handle->forward_inputs, handle->forward_inputs->num_inputs++);
return autograd::make_variable(
unsafeToTensorShare(input),
handle->forward_inputs->num_inputs++,
handle->forward_inputs);
unsafeToTensorShare(input), std::move(gradient_edge));
} else {
return autograd::make_variable(unsafeToTensorShare(input));
return autograd::make_variable(unsafeToTensorShare(input), /*requires_grad=*/false);
}
}
at::Retainable* addOutput(const autograd::Variable & output) {

View File

@ -1,7 +1,13 @@
#include "Python.h"
#include "torch/csrc/autograd/edge.h"
#include "torch/csrc/autograd/variable.h"
#include "torch/csrc/autograd/function.h"
#include "torch/csrc/jit/interpreter_autograd_function.h"
#include "torch/csrc/jit/ir.h"
#include "torch/csrc/jit/tracer_state.h"
#include <ATen/ATen.h>
#include <algorithm>
#include <memory>
@ -135,9 +141,10 @@ autograd::variable_list InterpreterAutogradFunction::apply(
auto & flags = details.output_flags[i];
if (flags.requires_grad) { // See Note [Null-edge pruning]
if (!grad_fn) make_grad_fn();
result.push_back(autograd::make_variable(toutputs[i], grad_fn));
autograd::Edge edge(grad_fn, grad_fn->num_inputs++);
result.push_back(autograd::make_variable(toutputs[i], std::move(edge)));
} else {
result.push_back(autograd::make_variable(toutputs[i], false));
result.push_back(autograd::make_variable(toutputs[i], /*requires_grad=*/false));
}
}

View File

@ -437,7 +437,7 @@ public:
return outputs_.back();
}
void eraseOutput(size_t i);
Block * addBlock();
void eraseBlock(size_t i);

View File

@ -1,11 +1,15 @@
#pragma once
#include <Python.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "torch/csrc/DynamicTypes.h"
#include "torch/csrc/THP.h"
#include "torch/csrc/autograd/variable.h"
#include "torch/csrc/jit/interned_strings.h"
#include "torch/csrc/jit/tracer.h"
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
namespace py = pybind11;

View File

@ -113,7 +113,7 @@ struct CompiledFunction {
std::size_t num_captured = fn_.captured_vars_.size();
// Check that no captured Variables were replaced by enter. It's hard to handle that.
for (std::size_t i = num_all_inputs - num_captured; i < num_all_inputs; ++i) {
TORCH_EXPECTM(input_info.vars[i].get() == new_vars[i].get(),
TORCH_EXPECTM(input_info.vars[i].is_same(new_vars[i]),
"Some of the Variables captured by the JIT are repeated");
}
// Now only arguments to this function could have changed. Slice their vars out, and

View File

@ -463,7 +463,7 @@ struct ADTestSpec {
std::vector<Variable> make_vars() const {
std::vector<Variable> out;
for (const auto & m : input_meta) {
out.emplace_back(make_variable(at::CPU(at::kFloat).tensor(m).normal_(), true));
out.emplace_back(autograd::make_variable(at::CPU(at::kFloat).tensor(m).normal_(), /*requires_grad=*/true));
}
return out;
}

View File

@ -8,6 +8,9 @@
#include "torch/csrc/utils/auto_gil.h"
#include "torch/csrc/utils/python_strings.h"
#include <string>
#include <sstream>
#include <memory>
#include <frameobject.h>
#include <patchlevel.h>

View File

@ -7,6 +7,7 @@
#include "torch/csrc/utils/variadic.h"
#include "torch/csrc/autograd/function_hook.h"
#include "torch/csrc/autograd/variable.h"
#include "torch/csrc/utils/auto_unique_ptr.h"
#include <memory>
#include <mutex>
@ -26,12 +27,13 @@ std::string getPythonInterpreterStackTrace();
namespace detail {
inline ValueTracingStateElem* getValueState(const std::shared_ptr<TracingState>& state, const Variable& var, bool alloc = true) {
for (auto it = var.tracing_state()->begin(); it != var.tracing_state()->end();) {
auto& tracing_state = var.tracing_state();
for (auto it = tracing_state.begin(); it != tracing_state.end();) {
auto ts = it->state.lock();
// GC of invalidated tracing states
if (!ts) {
auto current_it = it++;
var.tracing_state()->erase(current_it);
tracing_state.erase(current_it);
continue;
} else if (ts == state) {
return &(*it);
@ -39,8 +41,8 @@ inline ValueTracingStateElem* getValueState(const std::shared_ptr<TracingState>&
++it;
}
if (alloc) {
var.tracing_state()->emplace_front();
auto & vts = var.tracing_state()->front();
tracing_state.emplace_front();
auto & vts = tracing_state.front();
vts.state = state;
return &vts;
} else {
@ -70,8 +72,8 @@ inline std::vector<VariableFlags> getVarFlags(const variable_list& vars) {
// need it (in most cases if we have a variable_list it is already
// flattened).
inline bool isTracingVar(const Variable& var) {
if (!var.defined() || !var.tracing_state()) return false;
return std::any_of(var.tracing_state()->begin(), var.tracing_state()->end(), detail::isElemActive);
if (!var.defined() || !var.has_tracing_state()) return false;
return std::any_of(var.tracing_state().begin(), var.tracing_state().end(), detail::isElemActive);
}
inline bool isTracingVar(at::ArrayRef<Variable> vars) {
@ -104,8 +106,8 @@ inline bool isTracing(Args&&... args) {
inline std::shared_ptr<TracingState> getTracingState(const variable_list& vars) {
std::shared_ptr<TracingState> state;
for (auto& var : vars) {
if (!var.defined() || !var.tracing_state()) continue;
for (auto & vts : *var.tracing_state()) {
if (!var.defined() || !var.has_tracing_state()) continue;
for (auto & vts : var.tracing_state()) {
auto var_state = vts.state.lock();
if (!var_state || !var_state->active) continue;
if (!state) state = var_state;

View File

@ -0,0 +1,38 @@
#include "torch/csrc/jit/tracer_state.h"
#include "torch/csrc/autograd/edge.h"
#include "torch/csrc/autograd/variable.h"
#include "torch/csrc/jit/ir.h"
#include <atomic>
#include <cstdint>
#include <list>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
namespace torch { namespace jit { namespace tracer {
TracingState::TracingState(size_t num_stages)
: graph(new Graph()),
active(false),
num_stages(num_stages),
eval_count(0),
var_flags(num_stages),
output_edges(num_stages) {}
TracingState::~TracingState() = default;
bool TracingState::is_complete() const {
return !is_expired() && graph->stage() == num_stages - 1;
}
void TracingState::push_scope(const std::string& scope_name) {
graph->push_scope(scope_name);
}
void TracingState::pop_scope() {
graph->pop_scope();
}
}}} // namespace torch::jit::tracer

View File

@ -1,28 +1,28 @@
#pragma once
#include "torch/csrc/jit/ir.h"
#include "torch/csrc/assertions.h"
#include "torch/csrc/autograd/edge.h"
#include "torch/csrc/autograd/variable.h"
#include <memory>
#include <mutex>
#include <vector>
#include <atomic>
#include <cstdint>
#include <list>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
namespace torch { namespace autograd {
struct Variable;
struct Function;
}}
namespace torch { namespace jit {
struct Graph;
struct Value;
struct VariableFlags;
}} // namespace torch::jit
namespace torch { namespace jit { namespace tracer {
using torch::autograd::Variable;
using torch::autograd::Function;
using edge_list = std::vector<autograd::Edge>;
using variable_list = std::vector<Variable>;
using variable_list = std::vector<autograd::Variable>;
// TracingState tracks the necessary state when we are tracing the execution of
// autograd code; most importantly, it holds a reference to the actual IR
@ -34,26 +34,19 @@ using variable_list = std::vector<Variable>;
// from arising when a variable that participated in a trace outlives the
// actual trace itself.
using io_variable_flags_list =
std::vector<std::pair<std::vector<VariableFlags>, std::vector<VariableFlags>>>;
using io_variable_flags_list = std::vector<
std::pair<std::vector<VariableFlags>, std::vector<VariableFlags>>>;
struct TracingState : public std::enable_shared_from_this<TracingState> {
TracingState(std::size_t num_stages)
: graph(new Graph())
, active(false)
, num_stages(num_stages)
, eval_count(0)
, var_flags(num_stages)
, output_edges(num_stages) {}
explicit TracingState(size_t num_stages);
~TracingState();
// XXX: graph can be NULL if it's a failed trace (failed = didn't capture all
// the stages we care about)
std::shared_ptr<Graph> graph;
bool active;
// Used to free the Graph as soon as we know this trace will fail
std::size_t num_stages;
std::atomic<std::size_t> eval_count;
size_t num_stages;
std::atomic<size_t> eval_count;
// void* is an unsafe TH. NON-OWNING, so it might get invalidated.
// TODO: Perhaps, turn this into an owning reference. The buffers
@ -66,23 +59,17 @@ struct TracingState : public std::enable_shared_from_this<TracingState> {
std::mutex mutex;
variable_list inputs; // Used only for the duration of first stage
std::unique_lock<std::mutex> lock() { return std::unique_lock<std::mutex>(mutex); };
std::unique_lock<std::mutex> lock() {
return std::unique_lock<std::mutex>(mutex);
}
bool is_expired() const {
bool is_expired() const noexcept {
return !graph;
}
bool is_complete() const {
return !is_expired() && graph->stage() == num_stages - 1;
}
void push_scope(const std::string& scope_name) {
graph->push_scope(scope_name);
}
void pop_scope() {
graph->pop_scope();
}
bool is_complete() const;
void push_scope(const std::string& scope_name);
void pop_scope();
};
struct ValueTracingStateElem {
@ -102,4 +89,4 @@ struct FunctionTracingState {
bool in_eval_subgraph = false;
};
}}}
}}} // namespace torch::jit::tracer

View File

@ -1,5 +1,7 @@
#include "torch/csrc/jit/variable_flags.h"
#include "torch/csrc/autograd/variable.h"
#include "torch/csrc/jit/tracer_state.h"
using torch::autograd::Variable;

View File

@ -2,7 +2,6 @@
#include <functional>
#include <vector>
#include <ATen/ATen.h>
namespace torch {

View File

@ -93,14 +93,14 @@ static Tensor new_from_data(ScalarType scalarType, PyObject* data) {
}
#ifdef WITH_NUMPY
if (PyArray_Check(data)) {
return autograd::make_variable(tensor_from_numpy(data), false);
return autograd::make_variable(tensor_from_numpy(data), /*requires_grad=*/false);
}
#endif
auto sizes = compute_sizes(data);
auto tensor = autograd::make_variable(CPU(scalarType).tensor(sizes), false);
// TODO: we should pass tensor.sizes() rather than sizes, but this doesn't works
// if scalars are disabled because the size changes without WITH_SCALARS.
auto tensor = autograd::make_variable(CPU(scalarType).tensor(sizes), /*requires_grad=*/false);
recursive_store(
(char*)tensor.data_ptr(), sizes, tensor.strides(), 0,
scalarType, tensor.type().elementSizeInBytes(), data);
@ -198,7 +198,7 @@ Tensor legacy_tensor_ctor(const Type& type, PyObject* args, PyObject* kwargs) {
}
static Tensor set_requires_grad(Tensor self, bool requires_grad) {
static_cast<torch::autograd::Variable&>(self).get()->_requires_grad = requires_grad;
static_cast<torch::autograd::Variable&>(self).set_requires_grad(requires_grad);
return self;
}

View File

@ -1,12 +1,15 @@
#include "tuple_parser.h"
#include <string>
#include "torch/csrc/DynamicTypes.h"
#include "torch/csrc/autograd/python_variable.h"
#include "python_strings.h"
#include "python_numbers.h"
#include <string>
#include <stdexcept>
#include <vector>
namespace torch {
TupleParser::TupleParser(PyObject* args, int num_args) : args(args), idx(0) {

View File

@ -3,6 +3,11 @@
#include <ATen/ATen.h>
#include "torch/csrc/autograd/variable.h"
#include <cstdint>
#include <utility>
#include <tuple>
#include <type_traits>
namespace torch {
// This class allows you to write variadic functions which
@ -84,4 +89,23 @@ inline size_t count_variables(Args&&... args) {
return CountVariables().apply(std::forward<Args>(args)...).out;
}
//===----------------------------------------------------------------------===//
// std::index_sequence shim for C++11
//===----------------------------------------------------------------------===//
// A container of type-template parameter indices.
template<size_t... Is>
struct Indices {};
// Decrements the index N, adds N-1 to the list of indices and forwards
// whatever we arleady have.
template<size_t N, size_t... Is>
struct MakeIndices : MakeIndices<N-1, N-1, Is...> {};
// Partial specialization that forms our base case. When N is zero, we stop
// and define a typedef that will be visible to earlier classes due to
// inheritance. The typedef we define is an index list containing the numbers
// 0 through N-1.
template<size_t... Is>
struct MakeIndices<0, Is...> { using indices = Indices<Is...>; };
} // namespace torch