Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Improve Variable interface (#5127)
* Improve Variable interface
* Address comments from @apaszke and @colesbury
* string ::operator= is not noexcept
* Remove ir.h from tracer_state.h to improve build times
* Make Variable a struct and pack SavedVariable fields
* Implement as_variable_ref
* grad_fn_ptr() -> grad_fn_unsafe()
* Reduce hackiness of set_type hack
* Include variable.h and edge.h in tracer_state.h because it uses them
* class Variable -> struct Variable because Windows cant even
* Make Variable::output_nr uint32_t instead of int
* Add comment about tracing state
* Replaced more static_cast<Variable&> and improve docs
* Remove SavedVariable destructor and construct members in init list
* Clarify docs for Variable
* Variable::set_version -> set_version_counter
Committed by Soumith Chintala
Parent: 0ef10385b2
Commit: 2d5fbe6e0d
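As an orientation for the hunks that follow: a minimal sketch, assuming only the headers touched in this commit (variable.h, edge.h), of how the renamed calls fit together. The function and argument names below are placeholders, not code from the diff.

#include "torch/csrc/autograd/edge.h"
#include "torch/csrc/autograd/variable.h"

using namespace torch::autograd;

void interface_sketch(at::Tensor data, at::Tensor& maybe_variable,
                      std::shared_ptr<Function> fn) {
  // make_variable now takes an explicit requires_grad flag.
  Variable leaf = make_variable(std::move(data), /*requires_grad=*/false);

  // as_variable_ref replaces ad-hoc static_cast<Variable&> downcasts.
  Variable& var = as_variable_ref(maybe_variable);

  // grad_fn and output_nr are now set together through a single Edge.
  var.set_gradient_edge(Edge(std::move(fn), /*input_nr=*/0));

  // The version counter is bumped through the Variable itself.
  var.bump_version();
}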
@ -37,6 +37,7 @@ BreakAfterJavaFieldAnnotations: false
BreakStringLiterals: false
ColumnLimit: 80
CommentPragmas: '^ IWYU pragma:'
CompactNamespaces: true
ConstructorInitializerAllOnOneLineOrOnePerLine: true
ConstructorInitializerIndentWidth: 4
ContinuationIndentWidth: 4
@ -16,6 +16,7 @@ struct Storage;
|
||||
struct TensorImpl : public Retainable {
|
||||
explicit TensorImpl(Type * type)
|
||||
: is_scalar(false), type_(type) {}
|
||||
|
||||
Type & type() const {
|
||||
return *type_;
|
||||
}
|
||||
@ -49,7 +50,7 @@ struct TensorImpl : public Retainable {
|
||||
void setScalar(bool s) {
|
||||
is_scalar = s;
|
||||
}
|
||||
private:
|
||||
protected:
|
||||
bool is_scalar;
|
||||
Type * type_;
|
||||
};
|
||||
|
1 setup.py
@ -474,6 +474,7 @@ main_sources = [
|
||||
"torch/csrc/jit/python_ir.cpp",
|
||||
"torch/csrc/jit/test_jit.cpp",
|
||||
"torch/csrc/jit/tracer.cpp",
|
||||
"torch/csrc/jit/tracer_state.cpp",
|
||||
"torch/csrc/jit/python_tracer.cpp",
|
||||
"torch/csrc/jit/passes/shape_analysis.cpp",
|
||||
"torch/csrc/jit/interned_strings.cpp",
|
||||
|
@ -129,7 +129,7 @@ def process_function(func):
|
||||
name = arg['name']
|
||||
if arg['type'] == 'Tensor' or (arg['type'] == 'Scalar' and is_output):
|
||||
saved_variables.append('SavedVariable {}_;'.format(name))
|
||||
release_variables.append('{}_.data.reset();'.format(name))
|
||||
release_variables.append('{}_.reset_data();'.format(name))
|
||||
ptr = 'shared_from_this()' if is_output else ''
|
||||
unpack.append('auto {} = {}_.unpack({});'.format(name, name, ptr))
|
||||
elif arg['type'] == 'TensorList':
|
||||
|
@ -478,7 +478,7 @@ Tensor select_backward_scalar(Tensor grad, const Tensor & input, const Tensor &
|
||||
#ifdef WITH_SCALARS
|
||||
grad_input.masked_fill_(input == value, grad);
|
||||
#else
|
||||
auto grad_data = static_cast<Variable&>(grad).data();
|
||||
auto grad_data = as_variable_ref(grad).data();
|
||||
grad_input.masked_fill_(input == value, Scalar(grad_data[0]));
|
||||
#endif
|
||||
return grad_input;
|
||||
@ -1088,9 +1088,9 @@ std::tuple<Tensor, Tensor, Tensor> batchnorm_double_backward(
|
||||
for (auto s : input.sizes().slice(2)) {
|
||||
M *= s;
|
||||
}
|
||||
auto mu = unsqueeze_dim1(make_variable(training ? save_mean : running_mean), input);
|
||||
auto mu = unsqueeze_dim1(make_variable(training ? save_mean : running_mean, /*requires_grad=*/false), input);
|
||||
auto input_sub_mu = input - mu;
|
||||
auto sigma2_eps_neg_1_2 = unsqueeze_dim1(make_variable(training ? save_std : running_var.add(Scalar(eps)).pow(-0.5)), input);
|
||||
auto sigma2_eps_neg_1_2 = unsqueeze_dim1(make_variable(training ? save_std : running_var.add(Scalar(eps)).pow(-0.5), /*requires_grad=*/false), input);
|
||||
auto sigma2_eps_neg_1 = sigma2_eps_neg_1_2.pow(2);
|
||||
auto sigma2_eps_neg_3_2 = sigma2_eps_neg_1_2.pow(3);
|
||||
|
||||
|
@ -5,6 +5,7 @@
|
||||
|
||||
#include "torch/csrc/autograd/variable.h"
|
||||
#include "torch/csrc/autograd/function.h"
|
||||
#include "torch/csrc/autograd/edge.h"
|
||||
#include "torch/csrc/autograd/grad_mode.h"
|
||||
#include "torch/csrc/autograd/saved_variable.h"
|
||||
#include "torch/csrc/autograd/generated/Functions.h"
|
||||
@ -28,7 +29,6 @@ using namespace at;
|
||||
using namespace torch::autograd::generated;
|
||||
|
||||
namespace torch { namespace autograd {
|
||||
|
||||
// Helper methods for working with Attributes (torch/csrc/jit/attributes.h)
|
||||
|
||||
// The overloaded accessors are convenient for the generated code (since we
|
||||
@ -74,7 +74,7 @@ std::unique_ptr<Storage> VariableType::storageWithAllocator(int64_t size, std::u
|
||||
return baseType->storageWithAllocator(size, std::move(allocator));
|
||||
}
|
||||
Tensor VariableType::unsafeTensorFromTH(void * th_pointer, bool retain) const {
|
||||
return make_variable(baseType->unsafeTensorFromTH(th_pointer, retain), false);
|
||||
return make_variable(baseType->unsafeTensorFromTH(th_pointer, retain), /*requires_grad=*/false);
|
||||
}
|
||||
std::unique_ptr<Generator> VariableType::generator() const {
|
||||
return baseType->generator();
|
||||
@ -164,7 +164,7 @@ Variable & VariableType::checked_cast_variable(const Tensor & t, const char * na
|
||||
runtime_error("Expected object of type Variable but found type %s for argument #%d '%s'",
|
||||
t.type().toString(), pos, name);
|
||||
}
|
||||
return static_cast<Variable&>(const_cast<Tensor&>(t));
|
||||
return as_variable_ref(const_cast<Tensor&>(t));
|
||||
}
|
||||
|
||||
Tensor & VariableType::unpack(const Tensor & t, const char * name, int pos) {
|
||||
@ -207,49 +207,35 @@ static std::vector<SavedVariable> make_saved_variable_list(TensorList tensors) {
|
||||
return SavedVariable{tensor, false /* is output */}; });
|
||||
}
|
||||
|
||||
template <typename... Tensors, size_t... Is>
|
||||
std::tuple<Tensors...> as_variable_impl(
|
||||
std::tuple<Tensors...> tensors,
|
||||
Indices<Is...>) {
|
||||
// Expand the integer parameter pack into a sequence of Variable
|
||||
// constructions. This turns into (boolean omitted):
|
||||
// Variable(std::get<0>(tensors)), Variable(std::get<1>(tensors)), ...
|
||||
return std::tuple<Tensors...>(
|
||||
make_variable(std::get<Is>(tensors), /*requires_grad=*/false)...);
|
||||
}
|
||||
|
||||
template <typename... Tensors>
|
||||
std::tuple<Tensors...> as_variable(std::tuple<Tensors...> tensors) {
|
||||
// `sizeof...(Tensors)` gets us the size of the `Tensors` parameter pack at
|
||||
// compile time. We use it to parameterize a `MakeIndices` class, which will
|
||||
// expand into an Indices object containing the numbers 0 to
|
||||
// sizeof...(Tensors) - 1.
|
||||
return as_variable_impl(
|
||||
tensors, typename MakeIndices<sizeof...(Tensors)>::indices());
|
||||
}
|
||||
|
||||
static Tensor as_variable(Tensor tensor) {
|
||||
return make_variable(std::move(tensor));
|
||||
}
|
||||
|
||||
static std::tuple<Tensor, Tensor>
|
||||
as_variable(std::tuple<Tensor, Tensor> tensors) {
|
||||
return std::make_tuple<>(
|
||||
make_variable(std::move(std::get<0>(tensors))),
|
||||
make_variable(std::move(std::get<1>(tensors))));
|
||||
}
|
||||
|
||||
static std::tuple<Tensor, Tensor, Tensor>
|
||||
as_variable(std::tuple<Tensor, Tensor, Tensor> tensors) {
|
||||
return std::make_tuple<>(
|
||||
make_variable(std::move(std::get<0>(tensors))),
|
||||
make_variable(std::move(std::get<1>(tensors))),
|
||||
make_variable(std::move(std::get<2>(tensors))));
|
||||
}
|
||||
|
||||
static std::tuple<Tensor, Tensor, Tensor, Tensor>
|
||||
as_variable(std::tuple<Tensor, Tensor, Tensor, Tensor> tensors) {
|
||||
return std::make_tuple<>(
|
||||
make_variable(std::move(std::get<0>(tensors))),
|
||||
make_variable(std::move(std::get<1>(tensors))),
|
||||
make_variable(std::move(std::get<2>(tensors))),
|
||||
make_variable(std::move(std::get<3>(tensors))));
|
||||
}
|
||||
|
||||
static std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor>
|
||||
as_variable(std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> tensors) {
|
||||
return std::make_tuple<>(
|
||||
make_variable(std::move(std::get<0>(tensors))),
|
||||
make_variable(std::move(std::get<1>(tensors))),
|
||||
make_variable(std::move(std::get<2>(tensors))),
|
||||
make_variable(std::move(std::get<3>(tensors))),
|
||||
make_variable(std::move(std::get<4>(tensors)))
|
||||
);
|
||||
return make_variable(std::move(tensor), /*requires_grad=*/false);
|
||||
}
|
||||
|
||||
static std::vector<Tensor> as_variable(TensorList tl) {
|
||||
std::vector<Tensor> variables;
|
||||
for (auto& t : tl) {
|
||||
variables.emplace_back(make_variable(std::move(t)));
|
||||
variables.emplace_back(make_variable(std::move(t), /*requires_grad=*/false));
|
||||
}
|
||||
return variables;
|
||||
}
|
||||
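The comments in the hunk above describe expanding an integer parameter pack so that every element of a tuple gets wrapped individually. A standalone illustration of that pattern, written against std::index_sequence (C++14) rather than the MakeIndices helper used in this file, with a plain "+ 1" standing in for make_variable:

#include <cstdio>
#include <tuple>
#include <utility>

template <typename... Ts, std::size_t... Is>
std::tuple<Ts...> add_one_impl(std::tuple<Ts...> values,
                               std::index_sequence<Is...>) {
  // Expands to: (std::get<0>(values) + 1), (std::get<1>(values) + 1), ...
  return std::tuple<Ts...>((std::get<Is>(values) + 1)...);
}

template <typename... Ts>
std::tuple<Ts...> add_one(std::tuple<Ts...> values) {
  // index_sequence_for<Ts...> produces the indices 0 .. sizeof...(Ts) - 1.
  return add_one_impl(std::move(values), std::index_sequence_for<Ts...>{});
}

int main() {
  auto result = add_one(std::make_tuple(1, 2.5));
  std::printf("%d %.1f\n", std::get<0>(result), std::get<1>(result)); // 2 3.5
}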
@ -316,20 +302,20 @@ static void throw_error_out_requires_grad(const char* name) {
|
||||
|
||||
static void rebase_history(Tensor& tensor, std::shared_ptr<Function> grad_fn) {
|
||||
if (grad_fn && tensor.defined()) {
|
||||
auto& var = static_cast<Variable&>(tensor);
|
||||
auto& var = as_variable_ref(tensor);
|
||||
grad_fn->num_inputs = 1;
|
||||
var.rebase_history(0, std::move(grad_fn));
|
||||
var.rebase_history({std::move(grad_fn), 0});
|
||||
}
|
||||
}
|
||||
|
||||
static void rebase_history(TensorList tensors, std::shared_ptr<Function> grad_fn) {
|
||||
if (grad_fn) {
|
||||
grad_fn->num_inputs = tensors.size();
|
||||
int output_nr = 0;
|
||||
uint32_t output_nr = 0;
|
||||
for (auto& tensor : tensors) {
|
||||
if (tensor.defined()) {
|
||||
auto& var = static_cast<Variable&>(const_cast<Tensor&>(tensor));
|
||||
var.rebase_history(output_nr, grad_fn);
|
||||
auto& var = as_variable_ref(const_cast<Tensor&>(tensor));
|
||||
var.rebase_history({grad_fn, output_nr});
|
||||
}
|
||||
output_nr++;
|
||||
}
|
||||
@ -340,22 +326,20 @@ static void rebase_history(TensorList tensors, std::shared_ptr<Function> grad_fn
|
||||
// overload for functions with multiple differentiable outputs.
|
||||
static void set_history(Tensor& tensor, std::shared_ptr<Function> grad_fn) {
|
||||
if (grad_fn && tensor.defined()) {
|
||||
auto& var = static_cast<Variable&>(tensor);
|
||||
auto& var = as_variable_ref(tensor);
|
||||
grad_fn->num_inputs = 1;
|
||||
var.get()->output_nr = 0;
|
||||
var.get()->_grad_fn = std::move(grad_fn);
|
||||
var.set_gradient_edge({std::move(grad_fn), 0});
|
||||
}
|
||||
}
|
||||
|
||||
static void set_history(TensorList tensors, std::shared_ptr<Function> grad_fn) {
|
||||
if (grad_fn) {
|
||||
grad_fn->num_inputs = tensors.size();
|
||||
int64_t output_nr = 0;
|
||||
uint32_t output_nr = 0;
|
||||
for (auto& tensor : tensors) {
|
||||
if (tensor.defined()) {
|
||||
auto& var = static_cast<Variable&>(const_cast<Tensor&>(tensor));
|
||||
var.get()->output_nr = output_nr;
|
||||
var.get()->_grad_fn = grad_fn;
|
||||
auto& var = as_variable_ref(const_cast<Tensor&>(tensor));
|
||||
var.set_gradient_edge({grad_fn, output_nr});
|
||||
}
|
||||
output_nr++;
|
||||
}
|
||||
@ -378,9 +362,8 @@ template<typename... Args> inline variable_list flatten(Args&&... args) {
|
||||
return out; // RVO
|
||||
}
|
||||
|
||||
static void increment_version(const Tensor & t) {
|
||||
auto& var = static_cast<const Variable&>(t);
|
||||
var.version_counter().increment();
|
||||
static void increment_version(Tensor & t) {
|
||||
as_variable_ref(t).bump_version();
|
||||
}
|
||||
|
||||
static bool isFloatingPoint(ScalarType s) {
|
||||
@ -411,7 +394,7 @@ Tensor & VariableType::s_copy_(Tensor & self, const Tensor & src, bool non_block
|
||||
|
||||
Tensor & VariableType::resize_(Tensor & self, IntList size) const {
|
||||
auto& self_ = unpack(self, "self", 0);
|
||||
if (static_cast<Variable&>(self).requires_grad()) {
|
||||
if (as_variable_ref(self).requires_grad()) {
|
||||
at::runtime_error("cannot resize variables that require grad");
|
||||
}
|
||||
baseType->resize_(self_, size);
|
||||
@ -421,7 +404,7 @@ Tensor & VariableType::resize_(Tensor & self, IntList size) const {
|
||||
Tensor & VariableType::resize_as_(Tensor & self, const Tensor & the_template) const {
|
||||
auto& self_ = unpack(self, "self", 0);
|
||||
auto& the_template_ = unpack(the_template, "the_template", 1);
|
||||
if (static_cast<Variable&>(self).requires_grad()) {
|
||||
if (as_variable_ref(self).requires_grad()) {
|
||||
at::runtime_error("cannot resize variables that require grad");
|
||||
}
|
||||
baseType->resize_as_(self_, the_template_);
|
||||
|
@ -3,6 +3,10 @@
|
||||
// ${generated_comment}
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
#include <cstdint> // for size_t
|
||||
#include <functional> // for function
|
||||
#include <memory> // for unique_ptr
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
@ -56,7 +60,6 @@ private:
|
||||
static at::Tensor unpack_opt(const Tensor & t, const char * name, int pos);
|
||||
static std::vector<at::Tensor> unpack(at::TensorList tl, const char *name, int pos);
|
||||
|
||||
private:
|
||||
at::Type* baseType;
|
||||
std::string str;
|
||||
};
|
||||
|
@ -25,7 +25,7 @@ using namespace torch::autograd::utils;
|
||||
namespace torch { namespace autograd {
|
||||
|
||||
static Tensor set_requires_grad(Tensor self, bool requires_grad) {
|
||||
static_cast<Variable&>(self).get()->_requires_grad = requires_grad;
|
||||
as_variable_ref(self).set_requires_grad(requires_grad);
|
||||
return self;
|
||||
}
|
||||
|
||||
@ -70,7 +70,7 @@ static PyObject * THPVariable_from_numpy(PyObject* module, PyObject* arg)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
auto data = torch::utils::tensor_from_numpy(arg);
|
||||
return THPVariable_Wrap(make_variable(std::move(data)));
|
||||
return THPVariable_Wrap(make_variable(std::move(data), /*requires_grad=*/false));
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
|
@ -6,8 +6,7 @@
|
||||
|
||||
#include "torch/csrc/utils/hash.h"
|
||||
|
||||
namespace torch {
|
||||
namespace autograd {
|
||||
namespace torch { namespace autograd {
|
||||
|
||||
struct Function;
|
||||
|
||||
@ -15,8 +14,8 @@ struct Function;
|
||||
struct Edge {
|
||||
Edge() noexcept : function(nullptr), input_nr(0) {}
|
||||
|
||||
Edge(const std::shared_ptr<Function>& function_, uint32_t input_nr_) noexcept
|
||||
: function(function_), input_nr(input_nr_) {}
|
||||
Edge(std::shared_ptr<Function> function_, uint32_t input_nr_) noexcept
|
||||
: function(std::move(function_)), input_nr(input_nr_) {}
|
||||
|
||||
/// Convenience method to test if an edge is valid.
|
||||
bool is_valid() const noexcept {
|
||||
@ -38,8 +37,7 @@ struct Edge {
|
||||
/// The identifier of a particular input to the function.
|
||||
uint32_t input_nr;
|
||||
};
|
||||
} // namespace autograd
|
||||
} // namespace torch
|
||||
}} // namespace torch::autograd
|
||||
|
||||
// The idiomatic way of enabling use of a custom type as the key of hash
|
||||
// containers in C++11. This method removes the requirement of having to pass
|
||||
|
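Based only on the declarations in the hunk above, a brief, hedged sketch of how the reworked Edge is meant to be used; edge_sketch and fn are illustrative names, not part of the diff.

#include "torch/csrc/autograd/edge.h"
#include "torch/csrc/autograd/function.h"

#include <memory>
#include <utility>

void edge_sketch(std::shared_ptr<torch::autograd::Function> fn) {
  // A (gradient function, input index) pair; the constructor now takes the
  // shared_ptr by value and moves it into place.
  torch::autograd::Edge edge(std::move(fn), /*input_nr=*/1);
  if (edge.is_valid()) {
    // edge.function and edge.input_nr identify one input slot of the function.
  }
  torch::autograd::Edge empty; // default: function == nullptr, input_nr == 0
}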
@ -1,6 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
// A hook that's called on gradients
|
||||
|
@ -5,6 +5,8 @@
|
||||
#include "torch/csrc/autograd/functions/utils.h"
|
||||
#include "torch/csrc/utils/auto_gpu.h"
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
@ -19,7 +21,7 @@ auto DelayedError::apply(const variable_list& inputs) -> variable_list {
|
||||
outputs.reserve(inputs.size());
|
||||
for (auto& var : inputs) {
|
||||
// FIXME: share version counters
|
||||
outputs.emplace_back(var.defined() ? var.data() : Tensor());
|
||||
outputs.emplace_back(var.defined() ? var.data() : at::Tensor());
|
||||
}
|
||||
return wrap_outputs(inputs, std::move(outputs), [&](function_list&& next_functions) {
|
||||
return std::make_shared<Error>(msg, std::move(next_functions));
|
||||
|
@ -4,6 +4,7 @@
|
||||
#include "torch/csrc/autograd/python_engine.h"
|
||||
#include "torch/csrc/autograd/edge.h"
|
||||
#include "torch/csrc/autograd/function.h"
|
||||
#include "torch/csrc/autograd/edge.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
@ -264,11 +265,10 @@ bool Eval::replaceSubgraph(const variable_list& inputs, const variable_list& _ou
|
||||
// This output is already rebased. This happens when there
|
||||
// the same Variable has been returned multiple times, and
|
||||
// is repeated in this list.
|
||||
if (output.get()->_grad_fn.get() == this) {
|
||||
if (output.grad_fn_unsafe() == this) {
|
||||
auto replicate = std::make_shared<Replicate>();
|
||||
replicate->next_functions.emplace_back(this_shared, output.output_nr());
|
||||
output.get()->_grad_fn = replicate;
|
||||
output.get()->output_nr = 0;
|
||||
output.set_gradient_edge({std::move(replicate), 0});
|
||||
repeated_outputs.emplace(&output);
|
||||
}
|
||||
// NOTE: this check should be fairly cheap, and the set shouldn't
|
||||
@ -277,8 +277,7 @@ bool Eval::replaceSubgraph(const variable_list& inputs, const variable_list& _ou
|
||||
auto & replicate = output.grad_fn();
|
||||
replicate->next_functions.emplace_back(this_shared, num_inputs++);
|
||||
} else {
|
||||
output.get()->_grad_fn = this_shared;
|
||||
output.get()->output_nr = num_inputs++;
|
||||
output.set_gradient_edge(Edge(this_shared, num_inputs++));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,4 +1,6 @@
|
||||
#include "torch/csrc/autograd/functions/utils.h"
|
||||
|
||||
#include "torch/csrc/autograd/edge.h"
|
||||
#include "torch/csrc/autograd/function.h"
|
||||
#include "torch/csrc/autograd/variable.h"
|
||||
|
||||
@ -14,7 +16,7 @@ variable_list wrap_outputs(const variable_list& inputs, tensor_list&& outputs,
|
||||
if (!any_variable_requires_grad(inputs)) {
|
||||
for (auto& output : outputs) {
|
||||
if (output.defined()) {
|
||||
result.emplace_back(make_variable(output, false));
|
||||
result.push_back(make_variable(output, /*requires_grad=*/false));
|
||||
} else {
|
||||
result.emplace_back();
|
||||
}
|
||||
@ -23,7 +25,7 @@ variable_list wrap_outputs(const variable_list& inputs, tensor_list&& outputs,
|
||||
auto grad_fn = ctr(get_next_functions(inputs));
|
||||
for (auto& output : outputs) {
|
||||
if (output.defined()) {
|
||||
result.emplace_back(make_variable(output, grad_fn));
|
||||
result.push_back(make_variable(output, Edge(grad_fn, grad_fn->num_inputs++)));
|
||||
} else {
|
||||
++grad_fn->num_inputs;
|
||||
result.emplace_back();
|
||||
|
@ -142,10 +142,10 @@ PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwar
|
||||
THPUtils_assert(THPVariable_Check(input),
|
||||
"all inputs have to be Variables, but got %s", THPUtils_typename(input));
|
||||
THPVariable *input_var = (THPVariable*)input;
|
||||
int output_nr = input_var->cdata.output_nr();
|
||||
const auto output_nr = input_var->cdata.output_nr();
|
||||
auto grad_fn = input_var->cdata.grad_fn();
|
||||
if (!grad_fn) {
|
||||
grad_fn = input_var->cdata.get()->grad_accumulator.lock();
|
||||
grad_fn = input_var->cdata.try_get_grad_accumulator();
|
||||
}
|
||||
THPUtils_assert(input_var->cdata.requires_grad(),
|
||||
"One of the differentiated Variables does not require grad");
|
||||
|
@ -323,7 +323,7 @@ static std::vector<PyObject*> _mark_dirty(THPFunction *self)
|
||||
|
||||
dirty_inputs.push_back(obj);
|
||||
auto variable = (THPVariable*)obj;
|
||||
variable->cdata.version_counter().increment();
|
||||
variable->cdata.bump_version();
|
||||
}
|
||||
// We're not going to ever need this so let's remove references now
|
||||
Py_CLEAR(self->dirty_tensors);
|
||||
@ -368,14 +368,14 @@ static void _wrap_outputs(THPFunction *self,
|
||||
}
|
||||
if (THPModule_isTensor(obj)) {
|
||||
// temporarily wrap tensors as variables until the classes are merged
|
||||
return make_variable(createTensor(obj));
|
||||
return make_variable(createTensor(obj), /*requires_grad=*/false);
|
||||
}
|
||||
throw TypeError("%s.forward: expected Variable (got %s) for return value %d",
|
||||
Py_TYPE(self)->tp_name, Py_TYPE(obj)->tp_name, i);
|
||||
};
|
||||
|
||||
// Sets the grad_fn and output_nr of an output Variable.
|
||||
auto set_history = [&](Variable& var, int output_nr, bool is_input, bool is_modified,
|
||||
auto set_history = [&](Variable& var, uint32_t output_nr, bool is_input, bool is_modified,
|
||||
bool is_differentiable) {
|
||||
if (!is_differentiable) {
|
||||
if (!var.requires_grad()) return;
|
||||
@ -393,24 +393,22 @@ static void _wrap_outputs(THPFunction *self,
|
||||
}
|
||||
// If the input was modified, transplant the grad_fn in the graph:
|
||||
// grad_fn <- variable <- self ==> grad_fn <- self <- variable
|
||||
var.get()->grad.reset();
|
||||
var.get()->hooks.clear();
|
||||
if (auto grad_acc_fn = var.get()->grad_accumulator.lock()) {
|
||||
var.reset_grad();
|
||||
var.clear_hooks();
|
||||
if (auto grad_acc_fn = var.try_get_grad_accumulator()) {
|
||||
auto grad_acc = dynamic_cast<AccumulateGrad*>(grad_acc_fn.get());
|
||||
grad_acc->variable.reset();
|
||||
}
|
||||
if (cdata) {
|
||||
var.rebase_history(output_nr, cdata);
|
||||
var.rebase_history({cdata, output_nr});
|
||||
}
|
||||
} else if (is_input) {
|
||||
// An input has been returned, but it wasn't modified. Return it as a view
|
||||
// so that we can attach a new grad_fn to the Variable.
|
||||
var = var.slice();
|
||||
var.get()->output_nr = output_nr;
|
||||
var.get()->_grad_fn = cdata;
|
||||
var.set_gradient_edge({cdata, output_nr});
|
||||
} else if (cdata) {
|
||||
var.get()->output_nr = output_nr;
|
||||
var.get()->_grad_fn = cdata;
|
||||
var.set_gradient_edge({cdata, output_nr});
|
||||
}
|
||||
};
|
||||
|
||||
@ -457,7 +455,7 @@ static void _save_variables(THPFunction* self)
|
||||
self->saved_variables.emplace_back(variable->cdata, is_output);
|
||||
} else if (THPModule_isTensor(obj)) {
|
||||
// TODO: remove once Variable and Tensor classes are merged
|
||||
auto var = make_variable(createTensor(obj), false);
|
||||
auto var = make_variable(createTensor(obj), /*requires_grad=*/false);
|
||||
self->saved_variables.emplace_back(std::move(var), false);
|
||||
} else {
|
||||
throw TypeError(
|
||||
|
@ -1,12 +1,11 @@
|
||||
#include "torch/csrc/autograd/python_variable.h"
|
||||
|
||||
#include <structmember.h>
|
||||
|
||||
#include "THP.h"
|
||||
#include "torch/csrc/DynamicTypes.h"
|
||||
#include "torch/csrc/Types.h"
|
||||
#include "torch/csrc/autograd/python_cpp_function.h"
|
||||
#include "torch/csrc/autograd/python_hook.h"
|
||||
#include "torch/csrc/autograd/edge.h"
|
||||
#include "torch/csrc/autograd/python_variable_indexing.h"
|
||||
#include "torch/csrc/autograd/functions/accumulate_grad.h"
|
||||
#include "torch/csrc/autograd/utils/wrap_outputs.h"
|
||||
@ -16,7 +15,15 @@
|
||||
#include "torch/csrc/Exceptions.h"
|
||||
#include "torch/csrc/Size.h"
|
||||
#include "torch/csrc/autograd/variable.h"
|
||||
#include "torch/csrc/autograd/edge.h"
|
||||
#include "torch/csrc/autograd/generated/VariableType.h"
|
||||
#include "torch/csrc/jit/tracer_state.h"
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
#include <list>
|
||||
#include <memory>
|
||||
#include <structmember.h>
|
||||
|
||||
using namespace at;
|
||||
using namespace torch::autograd;
|
||||
@ -35,11 +42,13 @@ static PyObject* THPVariable_NewWithVar(PyTypeObject* type, Variable var)
|
||||
if (obj) {
|
||||
auto v = (THPVariable*) obj;
|
||||
new (&v->cdata) Variable(std::move(var));
|
||||
v->cdata.get()->pyobj = obj;
|
||||
if (auto fn = dynamic_cast<PyFunction*>(v->cdata.get()->_grad_fn.get())) {
|
||||
v->cdata.set_pyobj(obj);
|
||||
if (auto fn = dynamic_cast<PyFunction*>(v->cdata.grad_fn_unsafe())) {
|
||||
// Create a new reference to the THPFunction. This ensures that ref count
|
||||
// of the THPFunction is at least the number of referring THPVariables.
|
||||
v->cdata.get()->_grad_fn = THPFunction_asFunction((THPFunction*)fn->obj);
|
||||
const auto output_nr = v->cdata.output_nr();
|
||||
v->cdata.set_gradient_edge(
|
||||
{THPFunction_asFunction((THPFunction*)fn->obj), output_nr});
|
||||
}
|
||||
}
|
||||
return obj;
|
||||
@ -57,7 +66,7 @@ PyObject * THPVariable_Wrap(Variable var)
|
||||
}
|
||||
#endif
|
||||
|
||||
if (auto obj = var.get()->pyobj) {
|
||||
if (auto obj = var.pyobj()) {
|
||||
Py_INCREF(obj);
|
||||
return obj;
|
||||
}
|
||||
@ -84,7 +93,7 @@ static int THPVariable_traverse(THPVariable *self, visitproc visit, void *arg)
|
||||
// for more details about the race condition involving traversing the grad_fn
|
||||
// and the python GC.
|
||||
if (self->cdata.defined()) {
|
||||
for (auto& hook : self->cdata.hooks()) {
|
||||
for (const auto& hook : self->cdata.hooks()) {
|
||||
if (auto pyhook = dynamic_cast<PyFunctionPreHook*>(hook.get())) {
|
||||
Py_VISIT(pyhook->dict);
|
||||
}
|
||||
@ -98,10 +107,10 @@ static int THPVariable_clear(THPVariable *self)
|
||||
Py_CLEAR(self->data);
|
||||
Py_CLEAR(self->backward_hooks);
|
||||
if (self->cdata.defined()) {
|
||||
if (auto grad_acc = self->cdata.get()->grad_accumulator.lock()) {
|
||||
if (auto grad_acc = self->cdata.try_get_grad_accumulator()) {
|
||||
grad_acc->pre_hooks.clear();
|
||||
}
|
||||
self->cdata.get()->pyobj = nullptr;
|
||||
self->cdata.set_pyobj(nullptr);
|
||||
}
|
||||
self->cdata.reset();
|
||||
return 0;
|
||||
@ -154,13 +163,15 @@ PyObject *THPVariable_pynew(PyTypeObject *type, PyObject *args, PyObject *kwds)
|
||||
Variable var;
|
||||
if (grad_fn) {
|
||||
auto grad_fn_ = THPFunction_asFunction((THPFunction*)grad_fn);
|
||||
var = make_variable(torch::createTensor(data), grad_fn_);
|
||||
Edge edge(grad_fn_, grad_fn_->num_inputs++);
|
||||
var = make_variable(torch::createTensor(data), std::move(edge));
|
||||
} else {
|
||||
var = make_variable(torch::createTensor(data), requires_grad);
|
||||
}
|
||||
|
||||
if (name)
|
||||
var.name() = std::string(name);
|
||||
if (name) {
|
||||
var.set_name(name);
|
||||
}
|
||||
|
||||
PyObject* self = THPVariable_NewWithVar(type, std::move(var));
|
||||
if (self) {
|
||||
@ -223,7 +234,7 @@ int THPVariable_set_grad_fn(THPVariable *self, PyObject *obj)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
THPUtils_assertRet(-1, obj == Py_None, "_grad_fn can be only set to None");
|
||||
self->cdata.get()->_grad_fn = nullptr;
|
||||
self->cdata.set_gradient_edge(Edge());
|
||||
return 0;
|
||||
END_HANDLE_TH_ERRORS_RET(-1)
|
||||
}
|
||||
@ -246,30 +257,6 @@ PyObject * THPVariable_get_data(THPVariable *self)
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// XXX: This is a hack to access private TensorImpl::type_
|
||||
// http://bloglitb.blogspot.com/2011/12/access-to-private-members-safer.html
|
||||
// This is currently needed because module.float() changes the type of the
|
||||
// data field of each variable. We should fix this and not allow changing the
|
||||
// type of var.data.
|
||||
|
||||
template<typename Tag, typename Tag::type M>
|
||||
struct Rob {
|
||||
friend typename Tag::type get(Tag) {
|
||||
return M;
|
||||
}
|
||||
};
|
||||
|
||||
struct TensorImpl_Type {
|
||||
typedef Type* TensorImpl::*type;
|
||||
friend type get(TensorImpl_Type);
|
||||
};
|
||||
|
||||
template struct Rob<TensorImpl_Type, &TensorImpl::type_>;
|
||||
|
||||
}
|
||||
|
||||
int THPVariable_set_data(THPVariable *self, PyObject *data)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
@ -282,7 +269,7 @@ int THPVariable_set_data(THPVariable *self, PyObject *data)
|
||||
if (&self->cdata.data().type() != &tensor.type()) {
|
||||
// we change the type of var.data so we must change the type of var
|
||||
auto newType = VariableType::getType(tensor);
|
||||
self->cdata.get()->*get(TensorImpl_Type()) = newType;
|
||||
self->cdata.temporary_hack_set_type(newType);
|
||||
}
|
||||
self->cdata.data() = tensor;
|
||||
return 0;
|
||||
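The block removed above relied on the explicit-instantiation loophole from the linked blog post to reach the private TensorImpl::type_ member; this commit replaces that mechanism with temporary_hack_set_type. A self-contained illustration of the trick itself, using a made-up Widget class rather than TensorImpl:

#include <cstdio>

class Widget {
  int secret_ = 42; // private
};

// Explicit instantiation may name private members, so the friend function
// defined inside Rob leaks a pointer-to-member to the enclosing scope.
template <typename Tag, typename Tag::type M>
struct Rob {
  friend typename Tag::type get(Tag) { return M; }
};

struct WidgetSecret {
  typedef int Widget::*type;
  friend type get(WidgetSecret);
};

template struct Rob<WidgetSecret, &Widget::secret_>;

int main() {
  Widget w;
  std::printf("%d\n", w.*get(WidgetSecret{})); // prints 42
}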
@ -301,7 +288,7 @@ int THPVariable_set_grad(THPVariable *self, PyObject *py_grad)
|
||||
HANDLE_TH_ERRORS
|
||||
auto& var = self->cdata;
|
||||
if (py_grad == Py_None) {
|
||||
var.grad().reset();
|
||||
var.reset_grad();
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -342,7 +329,8 @@ int THPVariable_set_volatile(THPVariable *self, PyObject *obj)
|
||||
PyObject *THPVariable_get_output_nr(THPVariable *self)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
return PyInt_FromLong(self->cdata.output_nr());
|
||||
const auto output_nr = static_cast<long>(self->cdata.output_nr());
|
||||
return PyInt_FromLong(output_nr);
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
@ -368,7 +356,7 @@ int THPVariable_set_requires_grad(THPVariable *self, PyObject *obj)
|
||||
THPUtils_setError("you can only change requires_grad flags of leaf variables.%s", hint);
|
||||
return -1;
|
||||
}
|
||||
var.get()->_requires_grad = (obj == Py_True);
|
||||
var.set_requires_grad(obj == Py_True);
|
||||
return 0;
|
||||
END_HANDLE_TH_ERRORS_RET(-1)
|
||||
}
|
||||
@ -400,9 +388,9 @@ int THPVariable_set_backwards_hooks(THPVariable *self, PyObject *obj)
|
||||
Py_XINCREF(obj);
|
||||
Py_XDECREF(self->backward_hooks);
|
||||
self->backward_hooks = obj;
|
||||
self->cdata.hooks().clear();
|
||||
self->cdata.clear_hooks();
|
||||
if (obj) {
|
||||
self->cdata.hooks().emplace_back(new PyFunctionPreHook(obj, 0));
|
||||
self->cdata.add_hook(std::make_shared<PyFunctionPreHook>(obj, 0));
|
||||
}
|
||||
return 0;
|
||||
END_HANDLE_TH_ERRORS_RET(-1)
|
||||
|
@ -3,8 +3,10 @@
|
||||
#include "torch/csrc/DynamicTypes.h"
|
||||
#include "torch/csrc/Exceptions.h"
|
||||
#include "torch/csrc/THP_export.h"
|
||||
#include "torch/csrc/autograd/function.h"
|
||||
#include "torch/csrc/autograd/python_variable.h"
|
||||
#include "torch/csrc/autograd/utils/wrap_outputs.h"
|
||||
#include "torch/csrc/autograd/variable.h"
|
||||
#include "torch/csrc/utils/python_compat.h"
|
||||
#include "torch/csrc/utils/python_numbers.h"
|
||||
#include "torch/csrc/utils/tensor_new.h"
|
||||
@ -115,7 +117,7 @@ static Variable valueToTensor(const Type & type, PyObject* value) {
|
||||
return type.scalarTensor(Scalar(THPUtils_unpackDouble(value)));
|
||||
}
|
||||
if (THPModule_isTensor(value)) {
|
||||
return make_variable(createTensor(value));
|
||||
return make_variable(createTensor(value), /*requires_grad=*/false);
|
||||
}
|
||||
throw TypeError("can't assign a %s to a %s", Py_TYPE(value)->tp_name, type.toString());
|
||||
}
|
||||
@ -152,7 +154,7 @@ static Variable applySlicing(const Variable& self, PyObject* index, variable_lis
|
||||
} else if (THPVariable_Check(obj)) {
|
||||
handle_var(reinterpret_cast<THPVariable*>(obj)->cdata);
|
||||
} else if (THPModule_isTensor(obj)) {
|
||||
handle_var(make_variable(createTensor(obj)));
|
||||
handle_var(make_variable(createTensor(obj), /*requires_grad=*/false));
|
||||
} else if (PySequence_Check(obj)) {
|
||||
handle_var(sequenceToVariable(self.type(), obj));
|
||||
} else {
|
||||
@ -270,7 +272,7 @@ PyObject* THPVariable_getitem(PyObject* self, PyObject* index) {
|
||||
variable_list variableIndices;
|
||||
Variable sliced = applySlicing(self_, holder.get(), variableIndices);
|
||||
if (variableIndices.empty()) {
|
||||
if (sliced.get() == self_.get()) {
|
||||
if (sliced.is_same(self_)) {
|
||||
// ensure we return a shallow copy for things like x[...]
|
||||
sliced = at::alias(sliced);
|
||||
}
|
||||
|
@ -1,49 +1,57 @@
|
||||
#include "torch/csrc/autograd/saved_variable.h"
|
||||
|
||||
#include "torch/csrc/autograd/edge.h"
|
||||
#include "torch/csrc/autograd/function.h"
|
||||
#include "torch/csrc/autograd/variable.h"
|
||||
#include "torch/csrc/jit/tracer_state.h"
|
||||
|
||||
using namespace at;
|
||||
#include <ATen/Tensor.h>
|
||||
|
||||
#include <cstdint>
|
||||
#include <list>
|
||||
#include <memory>
|
||||
|
||||
namespace torch { namespace autograd {
|
||||
|
||||
SavedVariable::SavedVariable(const Variable& variable, bool is_output)
|
||||
: SavedVariable() {
|
||||
if (!variable.defined()) {
|
||||
return;
|
||||
}
|
||||
data = variable.data();
|
||||
requires_grad = variable.requires_grad();
|
||||
expected_version = variable.current_version();
|
||||
version = variable.get()->version_counter.save();
|
||||
has_grad_fn = !variable.is_leaf();
|
||||
output_nr = variable.output_nr();
|
||||
if (!has_grad_fn) {
|
||||
grad_accumulator = variable.grad_accumulator();
|
||||
}
|
||||
if (!is_output) {
|
||||
_grad_fn = variable.grad_fn();
|
||||
}
|
||||
if (variable.tracing_state()) {
|
||||
tracing_state.reset(new jit::tracer::ValueTracingState(*variable.tracing_state()));
|
||||
: output_nr_(variable.output_nr()),
|
||||
requires_grad_(variable.requires_grad()),
|
||||
has_grad_fn_(!variable.is_leaf()) {
|
||||
if (variable.defined()) {
|
||||
was_default_constructed_ = false;
|
||||
// These copies are all shared_ptr copies, so slightly more expensive.
|
||||
// Do them here instead of in the init list in case data is undefined.
|
||||
data_ = variable.data();
|
||||
if (variable.is_leaf()) {
|
||||
grad_accumulator_ = variable.grad_accumulator();
|
||||
} else if (!is_output) {
|
||||
grad_fn_ = variable.grad_fn();
|
||||
}
|
||||
version_counter_ = variable.version_counter();
|
||||
saved_version_ = version_counter_.current_version();
|
||||
if (variable.has_tracing_state()) {
|
||||
tracing_state_.reset(
|
||||
new jit::tracer::ValueTracingState(variable.tracing_state()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto SavedVariable::unpack(std::shared_ptr<Function> saved_for) const -> Variable {
|
||||
if (!data.defined()) {
|
||||
if (version.defined()) {
|
||||
Variable SavedVariable::unpack(std::shared_ptr<Function> saved_for) const {
|
||||
if (!data_.defined()) {
|
||||
if (!was_default_constructed_) {
|
||||
throw std::runtime_error(ERR_BACKWARD_TWICE);
|
||||
}
|
||||
return Variable();
|
||||
}
|
||||
|
||||
if (version.is_modified()) {
|
||||
if (saved_version_ != version_counter_.current_version()) {
|
||||
throw std::runtime_error(
|
||||
"one of the variables needed for gradient computation has been "
|
||||
"modified by an inplace operation");
|
||||
}
|
||||
|
||||
auto grad_fn = _grad_fn;
|
||||
if (has_grad_fn && !grad_fn) {
|
||||
auto grad_fn = grad_fn_;
|
||||
if (has_grad_fn_ && !grad_fn) {
|
||||
if (!saved_for) {
|
||||
// If saving the grad_fn would create a circular reference, then it must
|
||||
// be passed in to the unpack function.
|
||||
@ -57,20 +65,22 @@ auto SavedVariable::unpack(std::shared_ptr<Function> saved_for) const -> Variabl
|
||||
// in-place functions on unpacked variables.
|
||||
Variable var;
|
||||
if (grad_fn) {
|
||||
var = make_variable(data, output_nr, std::move(grad_fn));
|
||||
var = make_variable(data_, Edge(std::move(grad_fn), output_nr_));
|
||||
} else {
|
||||
var = make_variable(data, requires_grad);
|
||||
var = make_variable(data_, requires_grad_);
|
||||
}
|
||||
var.version_counter() = version;
|
||||
var.set_version_counter(saved_version_);
|
||||
|
||||
// If a Variable is a leaf (no grad_fn saved), and it requires_grad, then we
|
||||
// should have saved the grad accumulator. Even if the Variable no longer
|
||||
// alive, the accumulator should be kept alive by the references in the graph).
|
||||
if (requires_grad && !var.grad_fn() && grad_accumulator.expired())
|
||||
// alive, the accumulator should be kept alive by the references in the
|
||||
// graph).
|
||||
if (requires_grad_ && !var.grad_fn() && grad_accumulator_.expired())
|
||||
throw std::logic_error("No grad accumulator for a saved leaf!");
|
||||
var.get()->grad_accumulator = grad_accumulator;
|
||||
if (tracing_state)
|
||||
var.tracing_state().reset(new jit::tracer::ValueTracingState(*tracing_state));
|
||||
var.set_grad_accumulator(grad_accumulator_);
|
||||
if (tracing_state_) {
|
||||
var.set_tracing_state(new jit::tracer::ValueTracingState(*tracing_state_));
|
||||
}
|
||||
|
||||
return var;
|
||||
}
|
||||
|
@ -1,46 +1,55 @@
|
||||
#pragma once
|
||||
|
||||
#include <mutex>
|
||||
#include <memory>
|
||||
#include <functional>
|
||||
#include "torch/csrc/autograd/variable_version.h"
|
||||
#include "torch/csrc/jit/tracer_state.h"
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
#include "torch/csrc/jit/tracer_state.h"
|
||||
#include "torch/csrc/autograd/variable.h"
|
||||
#include "torch/csrc/autograd/variable_version.h"
|
||||
#include "torch/csrc/Types.h"
|
||||
#include <cstdint>
|
||||
#include <list>
|
||||
#include <memory>
|
||||
|
||||
namespace torch { namespace autograd {
|
||||
|
||||
struct Variable;
|
||||
struct Function;
|
||||
|
||||
extern const char* ERR_BACKWARD_TWICE;
|
||||
|
||||
struct SavedVariable {
|
||||
SavedVariable()
|
||||
: data()
|
||||
, has_grad_fn(false)
|
||||
, version()
|
||||
, requires_grad(false)
|
||||
, expected_version(-1) {}
|
||||
|
||||
/// A snapshot of a variable at a certain version. A `SavedVariable` stores
|
||||
/// enough information to reconstruct a variable from a certain point in time.
|
||||
class SavedVariable {
|
||||
public:
|
||||
SavedVariable() = default;
|
||||
SavedVariable(const Variable& variable, bool is_output);
|
||||
SavedVariable(SavedVariable&&) = default;
|
||||
SavedVariable& operator=(SavedVariable&&) = default;
|
||||
|
||||
/// Reconstructs the saved variable. Pass `saved_for` as the gradient
|
||||
/// function if constructing the `SavedVariable` with it would have caused a
|
||||
/// circular reference.
|
||||
Variable unpack(std::shared_ptr<Function> saved_for = nullptr) const;
|
||||
|
||||
void reset_data() {
|
||||
return data_.reset();
|
||||
}
|
||||
|
||||
private:
|
||||
at::Tensor data_;
|
||||
|
||||
at::Tensor data;
|
||||
// The gradient function associated with this node. If has_grad_fn
|
||||
// is false, then this is a leaf node. Note that the grad_fn is not saved if
|
||||
// it would create a circular reference. In that case, the grad_fn must be
|
||||
// passed in to the unpack function when reconstructing the Variable.
|
||||
bool has_grad_fn;
|
||||
std::shared_ptr<Function> _grad_fn;
|
||||
std::weak_ptr<Function> grad_accumulator;
|
||||
SavedVersion version;
|
||||
bool requires_grad;
|
||||
int expected_version;
|
||||
int output_nr;
|
||||
std::unique_ptr<jit::tracer::ValueTracingState> tracing_state;
|
||||
std::shared_ptr<Function> grad_fn_;
|
||||
std::weak_ptr<Function> grad_accumulator_;
|
||||
std::unique_ptr<jit::tracer::ValueTracingState> tracing_state_;
|
||||
VariableVersion version_counter_;
|
||||
|
||||
Variable unpack(std::shared_ptr<Function> saved_for=nullptr) const;
|
||||
uint32_t saved_version_;
|
||||
uint32_t output_nr_;
|
||||
bool was_default_constructed_ = true;
|
||||
bool requires_grad_;
|
||||
bool has_grad_fn_;
|
||||
};
|
||||
|
||||
}} // namespace torch::autograd
|
||||
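A hedged sketch of the save/unpack lifecycle the class above implements, using only members declared in this header; the surrounding function and variable names are placeholders.

#include "torch/csrc/autograd/saved_variable.h"
#include "torch/csrc/autograd/variable.h"

using torch::autograd::SavedVariable;
using torch::autograd::Variable;

void saved_variable_sketch(const Variable& input) {
  // Snapshot the variable; is_output = false, so its grad_fn is captured too.
  SavedVariable saved(input, /*is_output=*/false);

  // If `input` is later mutated in place, its version counter advances and
  // unpack() throws ("modified by an inplace operation").
  Variable restored = saved.unpack();

  // Generated Function::release_variables() calls this; unpacking again after
  // the data is gone raises ERR_BACKWARD_TWICE.
  saved.reset_data();
}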
|
@ -1,84 +1,82 @@
|
||||
#include "Python.h"
|
||||
#include "torch/csrc/autograd/variable.h"
|
||||
|
||||
#include "torch/csrc/assertions.h"
|
||||
#include "torch/csrc/autograd/generated/VariableType.h"
|
||||
#include "torch/csrc/autograd/generated/Functions.h"
|
||||
#include "torch/csrc/autograd/edge.h"
|
||||
#include "torch/csrc/autograd/function.h"
|
||||
#include "torch/csrc/autograd/functions/accumulate_grad.h"
|
||||
#include "torch/csrc/autograd/functions/tensor.h"
|
||||
#include "torch/csrc/autograd/generated/Functions.h"
|
||||
#include "torch/csrc/autograd/generated/VariableType.h"
|
||||
#include "torch/csrc/autograd/variable_version.h"
|
||||
#include "torch/csrc/jit/tracer_state.h"
|
||||
#include "torch/csrc/utils/auto_unique_ptr.h"
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
#include <list>
|
||||
#include <memory>
|
||||
|
||||
using namespace at;
|
||||
#include <mutex>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace torch { namespace autograd {
|
||||
|
||||
Variable make_variable(at::Tensor data, std::shared_ptr<Function> grad_fn) {
|
||||
// TODO: If you ever want to support returning an undefined tensor from
|
||||
// a function, you'll have to uncomment the line below. Not sure if
|
||||
// we actually want to support this.
|
||||
// if (!data.defined()) return Variable();
|
||||
TORCH_ASSERT(grad_fn);
|
||||
int output_nr = grad_fn->num_inputs++;
|
||||
return make_variable(std::move(data), output_nr, std::move(grad_fn));
|
||||
}
|
||||
|
||||
VariableImpl::VariableImpl(Tensor data_, bool requires_grad, int output_nr, std::shared_ptr<Function> grad_fn)
|
||||
: TensorImpl(VariableType::getType(data_))
|
||||
, data(std::move(data_))
|
||||
, grad()
|
||||
, _grad_fn(std::move(grad_fn))
|
||||
, version_counter()
|
||||
, _requires_grad(requires_grad)
|
||||
, is_view(false)
|
||||
, output_nr(output_nr)
|
||||
, pyobj(nullptr) {
|
||||
TORCH_ASSERTM(!_grad_fn || !_requires_grad, "_requires_grad should be false if grad_fn is set");
|
||||
Variable::Impl::Impl(at::Tensor data_, bool requires_grad_, Edge gradient_edge_)
|
||||
: TensorImpl(VariableType::getType(data_)),
|
||||
data(std::move(data_)),
|
||||
grad_fn(std::move(gradient_edge_.function)),
|
||||
requires_grad(requires_grad_),
|
||||
is_view(false),
|
||||
output_nr(gradient_edge_.input_nr),
|
||||
pyobj(nullptr) {
|
||||
TORCH_ASSERTM(
|
||||
!grad_fn || !requires_grad,
|
||||
"_requires_grad should be false if grad_fn is set");
|
||||
if (!data.defined()) {
|
||||
throw std::runtime_error("data is undefined");
|
||||
}
|
||||
}
|
||||
|
||||
VariableImpl::~VariableImpl() {
|
||||
}
|
||||
Variable::Impl::~Impl() = default;
|
||||
|
||||
const char * VariableImpl::toString() const {
|
||||
const char* Variable::Impl::toString() const {
|
||||
return "Variable";
|
||||
}
|
||||
|
||||
IntList VariableImpl::sizes() const {
|
||||
IntList Variable::Impl::sizes() const {
|
||||
return data.sizes();
|
||||
}
|
||||
|
||||
IntList VariableImpl::strides() const {
|
||||
IntList Variable::Impl::strides() const {
|
||||
return data.strides();
|
||||
}
|
||||
|
||||
int64_t VariableImpl::dim() const {
|
||||
int64_t Variable::Impl::dim() const {
|
||||
return data.dim();
|
||||
}
|
||||
|
||||
const char * VariableImpl::typeString() {
|
||||
const char* Variable::Impl::typeString() {
|
||||
return "VariableType";
|
||||
}
|
||||
|
||||
void * VariableImpl::unsafeGetTH(bool retain) {
|
||||
void* Variable::Impl::unsafeGetTH(bool retain) {
|
||||
return data.unsafeGetTH(retain);
|
||||
}
|
||||
|
||||
std::unique_ptr<at::Storage> VariableImpl::storage() {
|
||||
std::unique_ptr<at::Storage> Variable::Impl::storage() {
|
||||
return data.storage();
|
||||
}
|
||||
|
||||
Scalar VariableImpl::localScalar() {
|
||||
Scalar Variable::Impl::localScalar() {
|
||||
return data.pImpl->localScalar();
|
||||
}
|
||||
|
||||
std::shared_ptr<Function> VariableImpl::get_grad_accumulator() {
|
||||
if (_grad_fn) {
|
||||
throw std::logic_error("get_grad_accumulator() should be only called on leaf Variables");
|
||||
std::shared_ptr<Function> Variable::Impl::get_grad_accumulator() {
|
||||
if (grad_fn) {
|
||||
throw std::logic_error(
|
||||
"get_grad_accumulator() should be only called on leaf Variables");
|
||||
}
|
||||
if (!_requires_grad) {
|
||||
if (!requires_grad) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
@ -92,11 +90,12 @@ std::shared_ptr<Function> VariableImpl::get_grad_accumulator() {
|
||||
return result;
|
||||
}
|
||||
|
||||
VariableViewImpl::VariableViewImpl(Variable base_, at::Tensor data_, int output_nr,
|
||||
std::shared_ptr<Function> grad_fn)
|
||||
: VariableImpl(std::move(data_), false, output_nr, std::move(grad_fn))
|
||||
, base(std::move(base_))
|
||||
, attr_version(0) {
|
||||
Variable::ViewImpl::ViewImpl(
|
||||
Variable base_,
|
||||
at::Tensor data_,
|
||||
Edge gradient_edge_)
|
||||
: Variable::Impl(std::move(data_), false, std::move(gradient_edge_)),
|
||||
base(std::move(base_)) {
|
||||
TORCH_ASSERTM(base.defined(), "base is undefined");
|
||||
if (base.is_view()) {
|
||||
base = base.base();
|
||||
@ -106,52 +105,71 @@ VariableViewImpl::VariableViewImpl(Variable base_, at::Tensor data_, int output_
|
||||
attr_version = version_counter.current_version();
|
||||
}
|
||||
|
||||
std::shared_ptr<Function>& VariableViewImpl::get_grad_fn() {
|
||||
std::shared_ptr<Function>& Variable::ViewImpl::get_grad_fn() {
|
||||
std::lock_guard<std::mutex> lock(mutex);
|
||||
if (!_grad_fn && !base.requires_grad()) {
|
||||
return _grad_fn;
|
||||
if (!grad_fn && !base.requires_grad()) {
|
||||
return grad_fn;
|
||||
}
|
||||
auto current_version = version_counter.current_version();
|
||||
if (attr_version != current_version) {
|
||||
TORCH_ASSERT(output_nr == 0);
|
||||
auto fn = std::make_shared<generated::AsStridedBackward>();
|
||||
fn->self_geometry = TensorGeometry(base);
|
||||
fn->self_geometry = at::TensorGeometry(base);
|
||||
fn->size = sizes();
|
||||
fn->stride = strides();
|
||||
fn->storage_offset = data.storage_offset();
|
||||
fn->set_next_functions(get_next_functions(base));
|
||||
fn->num_inputs = 1;
|
||||
_grad_fn = std::move(fn);
|
||||
grad_fn = std::move(fn);
|
||||
attr_version = current_version;
|
||||
}
|
||||
return _grad_fn;
|
||||
return grad_fn;
|
||||
}
|
||||
|
||||
void VariableViewImpl::rebase_history(int output_nr, std::shared_ptr<Function> grad_fn) {
|
||||
TORCH_ASSERT(output_nr == 0);
|
||||
TORCH_ASSERT(grad_fn);
|
||||
TORCH_ASSERTM(grad_fn->num_inputs == 1, "Functions which modify views in-place must return a single Variable");
|
||||
this->output_nr = output_nr;
|
||||
base.output_nr() = 0;
|
||||
base.get()->_grad_fn = std::make_shared<CopySlices>(
|
||||
base, TensorGeometry(data), std::move(grad_fn));
|
||||
get_grad_fn(); // trigger an update to the view's grad_fn
|
||||
void Variable::ViewImpl::rebase_history(Edge gradient_edge) {
|
||||
TORCH_ASSERT(gradient_edge.input_nr == 0);
|
||||
TORCH_ASSERT(gradient_edge.function);
|
||||
TORCH_ASSERTM(
|
||||
gradient_edge.function->num_inputs == 1,
|
||||
"Functions which modify views in-place must return a single Variable");
|
||||
this->output_nr = gradient_edge.input_nr;
|
||||
auto copy_slices = std::make_shared<CopySlices>(
|
||||
base, at::TensorGeometry(data), std::move(gradient_edge.function));
|
||||
base.set_gradient_edge({std::move(copy_slices), 0});
|
||||
get_grad_fn(); // trigger an update to the view's grad_fn
|
||||
}
|
||||
|
||||
void Variable::rebase_history(Edge gradient_edge) {
|
||||
TORCH_ASSERT(gradient_edge.function != nullptr);
|
||||
if (is_view()) {
|
||||
auto& impl = static_cast<Variable::ViewImpl&>(*get());
|
||||
impl.rebase_history(std::move(gradient_edge));
|
||||
} else {
|
||||
set_gradient_edge(std::move(gradient_edge));
|
||||
}
|
||||
}
|
||||
|
||||
Variable Variable::detach() const {
|
||||
Variable detached = make_variable(data());
|
||||
detached.version_counter() = version_counter();
|
||||
auto detached = make_variable(data(), /*requires_grad=*/false);
|
||||
detached.set_version_counter(version_counter());
|
||||
return detached;
|
||||
}
|
||||
|
||||
void Variable::detach_() {
|
||||
if (is_view()) {
|
||||
throw std::runtime_error("Can't detach views in-place. Use detach() instead");
|
||||
throw std::runtime_error(
|
||||
"Can't detach views in-place. Use detach() instead");
|
||||
}
|
||||
get()->_requires_grad = false;
|
||||
output_nr() = 0;
|
||||
get()->_grad_fn = nullptr;
|
||||
set_requires_grad(false);
|
||||
set_gradient_edge(Edge());
|
||||
}
|
||||
|
||||
void Variable::set_tracing_state(
|
||||
jit::tracer::ValueTracingState* new_tracing_state) {
|
||||
get()->tracing_state.reset(new_tracing_state);
|
||||
}
|
||||
|
||||
jit::tracer::ValueTracingState& Variable::tracing_state() const noexcept {
|
||||
return *get()->tracing_state;
|
||||
}
|
||||
}} // namespace torch::autograd
|
||||
|
@ -1,46 +1,187 @@
|
||||
#pragma once
|
||||
|
||||
// A wrapper around at::Tensor to represent autograd Variables. Variables
|
||||
// can be implicitly converted to an at::Tensor.
|
||||
#include <Python.h>
|
||||
|
||||
#include "torch/csrc/autograd/edge.h"
|
||||
#include "torch/csrc/autograd/function_hook.h"
|
||||
#include "torch/csrc/autograd/variable_version.h"
|
||||
#include "torch/csrc/utils/auto_unique_ptr.h"
|
||||
|
||||
#include <mutex>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <functional>
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
#include "torch/csrc/assertions.h"
|
||||
#include "torch/csrc/jit/ir.h"
|
||||
#include "torch/csrc/jit/tracer_state.h"
|
||||
#include "torch/csrc/autograd/function_hook.h"
|
||||
#include "torch/csrc/utils/auto_unique_ptr.h"
|
||||
#include "torch/csrc/autograd/variable_version.h"
|
||||
#include "torch/csrc/autograd/edge.h"
|
||||
#include "torch/csrc/Types.h"
|
||||
#include <list>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace torch {
|
||||
namespace autograd {
|
||||
struct Function;
|
||||
} // namespace autograd
|
||||
namespace jit { namespace tracer {
|
||||
// Has to be forward declared because tracer_state.h has a dependency on
|
||||
// variable.h.
|
||||
struct ValueTracingStateElem;
|
||||
using ValueTracingState = std::list<ValueTracingStateElem>;
|
||||
}} // namespace jit::tracer
|
||||
} // namespace torch
|
||||
|
||||
namespace torch { namespace autograd {
|
||||
|
||||
using at::Tensor;
|
||||
struct VariableImpl;
|
||||
///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Variable
///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// A `Variable` augments a `Tensor` with the ability to interact in our
/// autograd machinery. Conceptually, `Variable`s travel along `Edge`s between
/// `Function`s in the autograd graph. A `Variable` can either be a leaf, like a
/// weight in a neural network, or an interior variable, when it is the result
/// of an operation between variables. Every `Variable` also stores another
/// `Variable` called its `grad` (gradient). If the variable is a leaf, its
/// gradient will be accumulated into this variable.
///
/// Gradient Edges
///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Furthermore, `Variable`s have the notion of a `gradient_edge`, which is the
/// edge in the autograd graph that connects the variable to a particular input
/// of the gradient function that will be invoked with the variable during the
/// backward pass. More precisely, this gradient function can be one of two
/// things:
/// 1. A `grad_fn`, if the variable is in the interior of the graph. This is the
/// gradient of the function that produced the variable.
/// 2. A `grad_accumulator`, if the variable is a leaf, which accumulates a
/// scalar gradient value into its `grad` variable.
///
/// Versioning
///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Another major feature of `Variable`s are *versions*. Versions are
/// incremented when an in-place mutation of a variable occurs. Versions are
/// useful when constructing `SavedVariable`s, which take a snapshot of a
/// `Variable` at a certain version. You can retrieve a `Variable`'s version
/// through its `current_version()` method.
///
/// Views
///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// It is possible for a `Variable` to be a *view* of another `Variable`, in
/// which case it tracks that `Variable`'s data and autograd history. Beyond
/// construction, the interface of a view is identical to that of a regular
/// `Variable`. You can determine whether `Variable` is in fact a view by
/// probing its `is_view()` method.
///
/// Interface
///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// `Variable` inherits from `Tensor` and thus its API is a superset of that of
/// `Tensor`. This means you can perform all the usual mathematical and other
/// operations you can perform on `Tensor`s also on `Variable`s. Furthermore,
/// `Variable` and `Tensor` actually convert implicitly between each other. You
/// can thus call functions defined on `Tensor`s also with `Variable`s. For
/// this, the `Variable` class allows implicit construction from `Tensor`. It is
/// the responsibility of calling code to ensure that this constructor is
/// invoked only when the `Tensor`'s dynamic type is actually `Variable`. Most
/// notably, it is *not* correct to construct a brand new `Variable` from a
/// `Tensor` using this constructor. To do so, you must use the `make_variable`
/// free function instead. To create a view variable, use `make_variable_view`.
///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
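The documentation block above is prose; a short, hedged sketch of the leaf/interior distinction it describes, using only declarations visible in this header. The tensor and function arguments are placeholders.

#include "torch/csrc/autograd/edge.h"
#include "torch/csrc/autograd/variable.h"

using namespace torch::autograd;

void variable_sketch(at::Tensor data, std::shared_ptr<Function> fn) {
  // A leaf: no grad_fn; gradients accumulate into leaf.grad().
  Variable leaf = make_variable(data, /*requires_grad=*/true);
  // Its canonical gradient edge points at the grad accumulator (input_nr == 0).
  Edge leaf_edge = leaf.gradient_edge();

  // An interior variable instead records which input of which Function
  // produced it, via a (grad_fn, input_nr) gradient edge.
  Variable interior = make_variable(std::move(data), Edge(std::move(fn), 0));
  Edge interior_edge = interior.gradient_edge(); // {fn, interior.output_nr()}
}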
|
||||
struct Variable : public at::Tensor {
|
||||
inline Variable(VariableImpl * self, bool retain);
|
||||
Variable() : Tensor() {}
|
||||
Variable(const Variable & rhs) : Tensor(rhs) {}
|
||||
Variable(Variable && rhs) noexcept : Tensor(std::move(rhs)) {}
|
||||
/// Default constructor.
|
||||
Variable() = default;
|
||||
|
||||
// Implicitly casts a Tensor to a Variable. This should only be called on
|
||||
// Tensors which you know are actually Variables.
|
||||
/*implicit*/ Variable(Tensor const & rhs) : Tensor(rhs) {}
/*implicit*/ Variable(Tensor && rhs) noexcept : Tensor(std::move(rhs)) {}
// Factory Functions
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

// NOTE: These factory functions have to be friends to access the
// `Variable::Impl`. As a side effect, it allows us to keep them in the class.

/// Creates a `Variable` that is a *view* of another (*base*) variable.
/// The `gradient_edge` is an optional (gradient_function, input_number) pair.
friend Variable
make_variable_view(Variable base, at::Tensor data, Edge gradient_edge);

/// Creates a `Variable` from the given `Tensor`. `requires_grad` should be
/// set only for leaves, and determines whether the `Variable` will accumulate
/// gradients. NOTE: `data` must *not* be a `Variable` already. Its dynamic
/// type *must* be `Tensor`.
friend Variable make_variable(at::Tensor data, bool requires_grad);

/// Creates a `Variable` from the given `Tensor` and specify a `gradient_edge`,
/// i.e. a (function, input_nr) pair specifying the function in the autograd
/// graph, and what particular input of that function, this variable is
/// connected to.
friend Variable make_variable(at::Tensor data, Edge gradient_edge);

// Tensor Conversions
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

// "Downcasts" a `Tensor` into a `Variable`. Only call this on tensors you
// know are Variables.
/*implicit*/ Variable(at::Tensor const& rhs) : at::Tensor(rhs) {}
/*implicit*/ Variable(at::Tensor&& rhs) noexcept
: at::Tensor(std::move(rhs)) {}

// NOTE: Assignment operators to Tensor come for free from the constructors.

/// Downcasts the `Tensor` reference to a `Variable` reference. If compiling
/// in DEBUG mode and the tensor's dynamic type is not in fact `Variable`,
/// throws a `std::runtime_error` exception.
/// NOTE: Has to be a friend function because runtime type information is
/// available only for `TensorImpl`/`Impl` and not the `Tensor`/`Variable`
/// classes, as the latter are not polymorphic classes (`Tensor` has no
/// virtual methods).
friend Variable& as_variable_ref(at::Tensor& tensor);

const at::Tensor& data() const noexcept;
at::Tensor& data() noexcept;

// Gradient Function and Edges
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Gets the gradient function of the `Variable`. If this is a leaf variable,
/// the pointer returned will be null.
const std::shared_ptr<Function>& grad_fn() const;

/// Gets the raw gradient function pointer, whatever it currently is.
Function* grad_fn_unsafe() const;

/// Sets the gradient accumulator of the `Variable`. This is only applicable
/// to leaf variables. Interior variables should call `set_gradient_edge()`.
void set_grad_accumulator(std::weak_ptr<Function> grad_accumulator);

/// Attempts to get a pointer to the gradient accumulator of the `Variable`,
/// if it still exists. If the gradient accumulator function has been
/// destroyed, returns a `nullptr`.
std::shared_ptr<Function> try_get_grad_accumulator() const;

/// Gets the gradient accumulator of the `Variable` if it has one, or else
/// create one on the fly and return it.
std::shared_ptr<Function> grad_accumulator() const;

/// Sets the gradient edge -- i.e. `grad_fn` and `input_nr` -- of the
/// `Variable`.
/// NOTE: This will always set the `grad_fn`, even if this is a leaf
/// variable, and never the `grad_accumulator`. For the latter, use
/// `set_grad_accumulator`. This allows late construction of an interior
/// `Variable`.
/// You will likely want to call `rebase_history()` if this call is involved
/// in an in-place modification of a `Variable`.
void set_gradient_edge(Edge&& edge) noexcept;

/// Returns the "canonical" gradient edge of this `Variable`, i.e. either the
/// gradient function if this is an interior `Variable`, or the gradient
/// accumulator otherwise. If the `Variable` is interior, the returned `Edge`
/// will store the input index of the `Function` to which this variable is
/// connected in its `input_nr` field. For leaves, the `input_nr` is always
/// zero. Note that `set_gradient_edge` and `gradient_edge` are not
/// symmetric. You must use `set_gradient_edge` to set the `grad_fn` and
/// `set_grad_accumulator` to set the accumulator.
Edge gradient_edge() const {
// If grad_fn is null (as is the case for a leaf node), we instead
// interpret the gradient function to be a grad accumulator,
// which will accumulate its inputs into the grad property of the
// variable. These nodes get suppressed in some situations,
// see "suppress grad accumulation" below. Note that only variables which
// have `requires_grad = True` can have grad accumulators.
// interpret the gradient function to be a gradient accumulator, which will
// accumulate its inputs into the grad property of the variable. These
// nodes get suppressed in some situations, see "suppress gradient
// accumulation" below. Note that only variables which have `requires_grad =
// True` can have gradient accumulators.
if (const auto& gradient = grad_fn()) {
return Edge(gradient, output_nr());
} else {
@ -48,282 +189,434 @@ struct Variable : public at::Tensor {
}
}

inline VariableImpl* get() const;
/// Returns the input index of the gradient `Function` to which this `Variable`
/// is connected.
uint32_t output_nr() const noexcept;

inline const Tensor & data() const;
inline Tensor & data();
/// True if this `Variable` is a leaf and thus does not have a `grad_fn`.
bool is_leaf() const noexcept;

inline Tensor opt_data() const;
// The Grad Variable
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

inline const Variable & grad() const;
inline Variable & grad();
/// Accesses the gradient `Variable` of this `Variable`.
const Variable& grad() const noexcept;
Variable& grad() noexcept;
void reset_grad() noexcept;

inline bool is_leaf() const;
/// Sets the `requires_grad` property of `Variable`. This should be true for
/// leaf variables that want to accumulate gradients, and false for all other
/// variables.
void set_requires_grad(bool requires_grad) noexcept;
bool requires_grad() const noexcept;

inline const std::shared_ptr<Function>& grad_fn() const;
// Versions
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

// Updates the grad_fn of an existing Variable. Called after in-place modifications.
// XXX: this should be called only _after_ the version counter is implemented.
inline void rebase_history(int output_nr, std::shared_ptr<Function> grad_fn);
/// Increments the version count of this `Variable`.
void bump_version() noexcept;
void set_version_counter(const VariableVersion& version_counter) noexcept;

std::shared_ptr<Function> grad_accumulator() const;
/// Retrieves this `Variable`s version counter.
const VariableVersion& version_counter() const noexcept;

/// Retrieves the current value of the `Variable`'s version counter.
/// Equivalent to calling `version_counter().current_version()`.
uint32_t current_version() const noexcept;

// Autograd Graph Interaction
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Update the `grad_fn` of an existing Variable. Called after in-place
/// modifications.
void rebase_history(Edge gradient_edge);

/// Returns a copy of this `Variable` that is detached from its autograd graph
/// and has a blank version. This method is OK to call if the `Variable` is a
/// view.
Variable detach() const;

/// Like `detach()`, but removes this `Variable` in-place. This method may
/// only be called on non-view `Variable`s. You can use `is_view()` to check
/// this. If this `Variable` is a view, throws an `std::runtime_error()`.
void detach_();

inline const std::vector<std::shared_ptr<FunctionPreHook>>& hooks() const;
inline std::vector<std::shared_ptr<FunctionPreHook>>& hooks();
// Hooks
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

inline auto_unique_ptr<jit::tracer::ValueTracingState>& tracing_state() const;
void add_hook(std::shared_ptr<FunctionPreHook> hook);
const std::vector<std::shared_ptr<FunctionPreHook>>& hooks() const noexcept;
void clear_hooks();

inline int current_version() const;
// JIT Tracing
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

inline VariableVersion& version_counter() const;
void set_tracing_state(jit::tracer::ValueTracingState* new_tracing_state);
jit::tracer::ValueTracingState& tracing_state() const noexcept;

inline const int& output_nr() const;
inline int& output_nr();
/// Returns true if the `Variable`'s tracing state is not null.
bool has_tracing_state() const noexcept;

inline bool requires_grad() const;
// View Variables
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

inline bool is_view() const;
inline Variable& base() const;
/// Returns true if this `Variable` is a view of another `Variable`.
bool is_view() const noexcept;

inline const std::string& name() const;
inline std::string& name();
/// Returns the `Variable` that this `Variable` is a view of. If this
/// `Variable` is not a view, throw a `std::runtime_error`.
const Variable& base() const;

inline Variable & operator=(Variable && rhs) &;
inline Variable & operator=(const Variable & rhs) &;
inline Variable & operator=(Tensor && rhs) &;
inline Variable & operator=(const Tensor & rhs) &;
// Miscellaneous
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Compares this `Variable` to another `Variable` (or `Tensor`) via
/// pointer-equality.
bool is_same(const Variable& other) const noexcept {
return this->pImpl == other.pImpl;
}

void set_name(const std::string& name);
const std::string& name() const noexcept;

PyObject* pyobj() const noexcept;
void set_pyobj(PyObject* pyobj) noexcept;

// Hacks!
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Sets the type of the underlying `Tensor`. Used for a bad (hopefully)
/// temporary hack in python_variable.h. If removed, also remove the `using
/// at::TensorImpl::type_;` in `Variable::Impl`.
void temporary_hack_set_type(at::Type*) noexcept;

private:
/// Private implementation struct of the `Variable`. This struct declaration
/// and the `get()` method which exposes it shall forever remain private and
/// never be exposed to the public interface of this class.
struct Impl;
struct ViewImpl;

// Private Methods
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Variable(Variable::Impl* self, bool retain);
Impl* get() const noexcept;
};

struct VariableImpl : public at::TensorImpl {
public:
VariableImpl(at::Tensor data, bool requires_grad=false, int output_nr=0,
std::shared_ptr<Function> grad_fn=nullptr);
virtual ~VariableImpl();
virtual const char * toString() const override;
virtual at::IntList sizes() const override;
virtual at::IntList strides() const override;
virtual int64_t dim() const override;
virtual at::Scalar localScalar() override;
virtual void * unsafeGetTH(bool retain) override;
virtual std::unique_ptr<at::Storage> storage() override;
static const char * typeString();
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Variable::Impl
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

struct Variable::Impl : public at::TensorImpl {
explicit Impl(
at::Tensor data_,
bool requires_grad_ = false,
Edge edge = Edge());

virtual ~Impl();

const char* toString() const override;
at::IntList sizes() const override;
at::IntList strides() const override;
int64_t dim() const override;
at::Scalar localScalar() override;
void* unsafeGetTH(bool retain) override;
std::unique_ptr<at::Storage> storage() override;
static const char* typeString();

public:
std::shared_ptr<Function> get_grad_accumulator();
virtual std::shared_ptr<Function>& get_grad_fn() { return _grad_fn; }
virtual std::shared_ptr<Function>& get_grad_fn() {
return grad_fn;
}

// Make this field public so we can access it from `Variable`. Part of
// temporary_hack_set_type.
using at::TensorImpl::type_;

std::string name;
at::Tensor data;

Variable grad;
std::shared_ptr<Function> _grad_fn;
std::shared_ptr<Function> grad_fn;
std::weak_ptr<Function> grad_accumulator;

VariableVersion version_counter;
std::vector<std::shared_ptr<FunctionPreHook>> hooks;
std::weak_ptr<Function> grad_accumulator;
// Mutex to ensure that concurrent read operations that modify internal state
// are still thread-safe. Used by get_grad_fn and get_grad_accumulator.
std::mutex mutex;
bool _requires_grad; // only meaningful on leaf variables (must be false otherwise)

bool requires_grad; // only meaningful on leaf variables (must be false
// otherwise)
bool is_view;
// The "output number" of this variable; e.g., if this variable
// was the second output of a function, then output_nr == 1.
// We use this to make sure we can setup the backwards trace
// correctly when this variable is passed to another function.
int output_nr;
PyObject *pyobj; // weak reference
uint32_t output_nr;
PyObject* pyobj; // weak reference

std::string name;
// Mutex to ensure that concurrent read operations that modify internal
// state are still thread-safe. Used by get_grad_fn and
// get_grad_accumulator.
std::mutex mutex;

// For use in torch::jit::tracer
auto_unique_ptr<jit::tracer::ValueTracingState> tracing_state;
friend struct VariableType;
};

// A Variable that is a view on another Variable. The base and view share the
// same version_counter. The _grad_fn field of the Variable may become stale
// due to in-place modifications of the shared data. Accesses should go through
// get_grad_fn(). All other fields are always valid.
struct VariableViewImpl : public VariableImpl {
VariableViewImpl(Variable base, at::Tensor data, int output_nr, std::shared_ptr<Function> grad_fn);
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Variable::ViewImpl
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

// Gets the up-to-date grad_fn. If the shared data or base was modified, we
// re-create the grad_fn to express the up-to-date view relationship between
// this and the base Variable.
/// A Variable that is a view on another Variable. The base and view share the
/// same version_counter. The grad_fn field of the Variable may become stale
/// due to in-place modifications of the shared data. Accesses should go
/// through get_grad_fn(). All other fields are always valid.
struct Variable::ViewImpl : public Variable::Impl {
ViewImpl(Variable base_, at::Tensor data_, Edge gradient_edge);

/// Gets the up-to-date grad_fn. If the shared data or base was modified, we
/// re-create the grad_fn to express the up-to-date view relationship between
/// this and the base Variable.
virtual std::shared_ptr<Function>& get_grad_fn() override;

// Called after in-place modifications. Modifies the grad_fn of the base
// Variable.
void rebase_history(int output_nr, std::shared_ptr<Function> grad_fn);
/// Called after in-place modifications. Modifies the grad_fn of the base
/// Variable.
void rebase_history(Edge gradient_edge);

// The base Variable (never a view)
/// The base `Variable` (never a view).
Variable base;

// The value of the version_counter at the time grad_fn was created. The
// _grad_fn field is stale if attr_version != version_counter.current_version()
int attr_version;
/// The value of the version_counter at the time grad_fn was created. The
/// grad_fn field is stale if attr_version !=
/// version_counter.current_version().
uint32_t attr_version;
};

inline Variable make_variable(at::Tensor data, bool requires_grad=false) {
if (!data.defined()) {
return Variable();
}
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Variable Implementation
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

namespace detail {
inline at::Tensor handle_scalars(at::Tensor& data) {
#ifndef WITH_SCALARS
if (data.dim() == 0) {
// don't expose 0-dim tensors to Variable API.
data = data.as_strided_({1}, {1});
// Don't expose 0-dim tensors to Variable API.
return data.as_strided_({1}, {1});
}
#endif
return data;
}
} // namespace detail

return Variable(new VariableImpl(std::move(data), requires_grad), false);
// Factory Functions
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

inline Variable make_variable_view(
Variable base,
at::Tensor data,
Edge gradient_edge = Edge()) {
if (data.defined()) {
data = detail::handle_scalars(data);
auto impl = new Variable::ViewImpl(
std::move(base), std::move(data), std::move(gradient_edge));
return Variable(impl, /*retain=*/false);
}
return Variable();
}

inline Variable make_variable(at::Tensor data, int output_nr, std::shared_ptr<Function> grad_fn) {
if (!data.defined()) {
return Variable();
inline Variable make_variable(at::Tensor data, bool requires_grad) {
if (data.defined()) {
auto impl = new Variable::Impl(detail::handle_scalars(data), requires_grad);
return Variable(impl, /*retain=*/false);
}
return Variable();
}

#ifndef WITH_SCALARS
if (data.dim() == 0) {
// don't expose 0-dim tensors to Variable API.
data = data.as_strided_({1}, {1});
inline Variable make_variable(at::Tensor data, Edge gradient_edge) {
if (data.defined()) {
auto impl = new Variable::Impl(
detail::handle_scalars(data), false, std::move(gradient_edge));
return Variable(impl, /*retain=*/false);
}
return Variable();
}

// Tensor Conversion
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

inline Variable& as_variable_ref(at::Tensor& tensor) {
#ifdef DEBUG
// dynamic_cast will return a nullptr if the `TensorImpl`'s dynamic type is
// not `Variable::Impl`.
if (dynamic_cast<Variable::Impl*>(tensor.get()) == nullptr) {
throw std::runtime_error(
"Attempted to cast a Tensor to a Variable, but "
"the dynamic type of the value is not Variable.");
}
#endif

return Variable(new VariableImpl(std::move(data), false, output_nr, std::move(grad_fn)), false);
return static_cast<Variable&>(tensor);
}

Variable make_variable(at::Tensor data, std::shared_ptr<Function> grad_fn);

inline Variable make_variable_view(Variable base, at::Tensor data, int output_nr=0,
std::shared_ptr<Function> grad_fn=nullptr) {
if (!data.defined()) {
return Variable();
}

#ifndef WITH_SCALARS
if (data.dim() == 0) {
// don't expose 0-dim tensors to Variable API.
data = data.as_strided_({1}, {1});
}
#endif

return Variable(new VariableViewImpl(std::move(base), std::move(data), output_nr, std::move(grad_fn)), false);
}

inline Variable::Variable(VariableImpl * self, bool retain) : Tensor(self, retain) {
}

inline VariableImpl* Variable::get() const {
return static_cast<VariableImpl*>(pImpl);
}

inline const Tensor & Variable::data() const {
return get()->data;
}
inline Tensor & Variable::data() {
inline const at::Tensor& Variable::data() const noexcept {
return get()->data;
}

inline Tensor Variable::opt_data() const {
if (!defined()) {
return Tensor();
}
return data();
inline at::Tensor& Variable::data() noexcept {
return get()->data;
}

inline const Variable & Variable::grad() const {
return get()->grad;
}
inline Variable & Variable::grad() {
return get()->grad;
}

inline bool Variable::is_leaf() const {
return get()->_grad_fn == nullptr;
}
// Gradient Function and Edges
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

inline const std::shared_ptr<Function>& Variable::grad_fn() const {
return get()->get_grad_fn();
};
inline void Variable::rebase_history(int output_nr, std::shared_ptr<Function> grad_fn) {
TORCH_ASSERT(grad_fn);
if (is_view()) {
auto& impl = static_cast<VariableViewImpl&>(*get());
impl.rebase_history(output_nr, std::move(grad_fn));
} else {
get()->output_nr = output_nr;
get()->_grad_fn = std::move(grad_fn);
}
}

inline Function* Variable::grad_fn_unsafe() const {
return get()->grad_fn.get();
}

inline void Variable::set_grad_accumulator(
std::weak_ptr<Function> grad_accumulator) {
get()->grad_accumulator = std::move(grad_accumulator);
}

inline std::shared_ptr<Function> Variable::try_get_grad_accumulator() const {
return get()->grad_accumulator.lock();
}

inline std::shared_ptr<Function> Variable::grad_accumulator() const {
return get()->get_grad_accumulator();
};
}

inline const std::vector<std::shared_ptr<FunctionPreHook>>& Variable::hooks() const {
return get()->hooks;
};
inline std::vector<std::shared_ptr<FunctionPreHook>>& Variable::hooks() {
return get()->hooks;
};
inline void Variable::set_gradient_edge(Edge&& edge) noexcept {
get()->grad_fn = std::move(edge.function);
get()->output_nr = edge.input_nr;
}

inline auto_unique_ptr<jit::tracer::ValueTracingState>& Variable::tracing_state() const {
return get()->tracing_state;
};
inline uint32_t Variable::output_nr() const noexcept {
return get()->output_nr;
}

inline int Variable::current_version() const {
inline bool Variable::is_leaf() const noexcept {
return get()->grad_fn == nullptr;
}

// The Grad Variable
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

inline const Variable& Variable::grad() const noexcept {
return get()->grad;
}

inline Variable& Variable::grad() noexcept {
return get()->grad;
}

inline void Variable::reset_grad() noexcept {
get()->grad.reset();
}

inline void Variable::set_requires_grad(bool requires_grad) noexcept {
get()->requires_grad = requires_grad;
}

inline bool Variable::requires_grad() const noexcept {
return get()->requires_grad || get()->grad_fn ||
(is_view() && base().requires_grad());
}

// Versions
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

inline void Variable::set_version_counter(
const VariableVersion& version_counter) noexcept {
get()->version_counter = version_counter;
}

inline void Variable::bump_version() noexcept {
get()->version_counter.bump();
}

inline uint32_t Variable::current_version() const noexcept {
return get()->version_counter.current_version();
}

inline VariableVersion& Variable::version_counter() const {
inline const VariableVersion& Variable::version_counter() const noexcept {
return get()->version_counter;
}

inline const int& Variable::output_nr() const {
return get()->output_nr;
// Hooks
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

inline void Variable::add_hook(std::shared_ptr<FunctionPreHook> hook) {
get()->hooks.push_back(std::move(hook));
}

inline int& Variable::output_nr() {
return get()->output_nr;
inline const std::vector<std::shared_ptr<FunctionPreHook>>& Variable::hooks()
const noexcept {
return get()->hooks;
}

inline bool Variable::requires_grad() const {
return get()->_requires_grad || get()->_grad_fn || (is_view() && base().requires_grad());
inline void Variable::clear_hooks() {
get()->hooks.clear();
}

inline const std::string& Variable::name() const {
return get()->name;
}
inline std::string& Variable::name() {
return get()->name;
// JIT Tracing
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

inline bool Variable::has_tracing_state() const noexcept {
return get()->tracing_state != nullptr;
}

inline bool Variable::is_view()const {
// View Variables
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

inline bool Variable::is_view() const noexcept {
return get()->is_view;
}
inline Variable& Variable::base() const {

inline const Variable& Variable::base() const {
if (is_view()) {
return static_cast<VariableViewImpl&>(*get()).base;
return static_cast<Variable::ViewImpl*>(get())->base;
}
throw std::runtime_error("Can't get base of non-view");
}

inline Variable & Variable::operator=(Variable && rhs) & {
rhs.swap(*this);
return *this;
// Miscellaneous
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

inline void Variable::set_name(const std::string& name) {
get()->name = name;
}
inline Variable & Variable::operator=(const Variable & rhs) & {
Variable(rhs).swap(*this);
return *this;

inline const std::string& Variable::name() const noexcept {
return get()->name;
}
inline Variable & Variable::operator=(Tensor && rhs) & {
rhs.swap(*this);
return *this;

inline void Variable::set_pyobj(PyObject* pyobj) noexcept {
get()->pyobj = pyobj;
}
inline Variable & Variable::operator=(const Tensor & rhs) & {
Variable(rhs).swap(*this);
return *this;

inline PyObject* Variable::pyobj() const noexcept {
return get()->pyobj;
}

// Hacks!
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

inline void Variable::temporary_hack_set_type(at::Type* new_type) noexcept {
get()->type_ = new_type;
}

// Private Methods
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

inline Variable::Variable(Variable::Impl* self, bool retain)
: at::Tensor(self, retain) {}

inline Variable::Impl* Variable::get() const noexcept {
return static_cast<Variable::Impl*>(pImpl);
}

}} // namespace torch::autograd
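
A rough usage sketch of the reworked interface above (illustrative only, not part of the commit: the tensor argument and the `fn` gradient function are placeholders):

// Illustrative only: exercises the new factory functions and accessors.
#include "torch/csrc/autograd/variable.h"

using namespace torch::autograd;

void variable_interface_sketch(at::Tensor data, std::shared_ptr<Function> fn) {
  // A leaf that accumulates gradients; its gradient_edge() is a grad accumulator.
  Variable leaf = make_variable(data, /*requires_grad=*/true);
  Edge leaf_edge = leaf.gradient_edge();

  // An interior variable produced by `fn` as its first output.
  Variable interior = make_variable(data, Edge(fn, 0));
  uint32_t nr = interior.output_nr();  // == 0

  // Treat a Tensor reference known to hold a Variable as a Variable.
  at::Tensor& tensor_ref = interior;
  Variable& var_ref = as_variable_ref(tensor_ref);
  var_ref.bump_version();  // e.g. after an in-place modification
  (void)leaf_edge; (void)nr;
}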
@ -1,96 +1,38 @@
#pragma once

#include <atomic>
#include <cstdint>
#include <memory>

// Every Variable has a version counter. Version counters are incremented
// whenever the data or shape of a tensor changes through Variable operations.
// whenever the data or shape of a tensor changes through Variable operations.
// These are typicallly in-place operations. Version counters are used to
// detect modifications to saved varaibles which would result in incorrect
// detect modifications to saved variables which would result in incorrect
// gradient calculations. Version counters may be shared between Variables:
//
// 1. A view shares the version counter of the base Variable
// 2. Detached variables share the version counter of the source
// 3. Unpacked saved variables share the version counter of the source
// 1. A view shares the version counter of the base Variable,
// 2. Detached variables share the version counter of the source,
// 3. Unpacked saved variables share the version counter of the source.

namespace torch { namespace autograd {

struct VersionBlock {
VersionBlock() : version() {}

// monotonically increasing version
std::atomic<int> version;
};

struct SavedVersion;

struct VariableVersion {
VariableVersion() : version_block(std::make_shared<VersionBlock>()) {}
VariableVersion(const VariableVersion&) = delete;
VariableVersion(VariableVersion&&) = delete;
public:
// NOTE: As of C++11 and 14, default-constructing a std::atomic variable
// leaves it in a persistently undefined state. See
// https://cplusplus.github.io/LWG/issue2334.
VariableVersion(uint32_t version = 0)
: version_block_(std::make_shared<std::atomic<uint32_t>>(version)) {}

// increment the version counter
void increment() { version_block->version++; }

// current version
int current_version() const { return version_block->version.load(); }

// creates a saved reference with the current version and the counter
inline SavedVersion save() const;

// Uses another variable's version counter. Used for variables which share storages
// NOTE: not thread-safe to call this from multiple threads without synchronization
// because shared_ptr assignment isn't thread-safe.
VariableVersion& operator=(const VariableVersion& other) {
version_block = other.version_block;
return *this;
void bump() noexcept {
version_block_->fetch_add(1);
}

// Uses the version counter from a SavedVariable
// NOTE: not thread-safe to call this from multiple threads without synchronization
inline VariableVersion& operator=(const SavedVersion& other);
uint32_t current_version() const noexcept {
return version_block_->load();
}

private:
friend struct SavedVersion;
std::shared_ptr<VersionBlock> version_block; // always non-null
private:
std::shared_ptr<std::atomic<uint32_t>> version_block_;
};

// The version counter used in SavedVariables. Saves the expected_version (the
// version at the time of save) and a reference to the version counter's
// version_block.
struct SavedVersion {
SavedVersion() {}
SavedVersion(const VariableVersion& version)
: expected_version(version.current_version())
, version_block(version.version_block) {}

// if the version_block has been modified since when it was saved
bool is_modified() const {
return expected_version != version_block->version.load();
}

// true if the version_block is defined
bool defined() const {
return static_cast<bool>(version_block);
}

private:
friend struct VariableVersion;
int expected_version;
std::shared_ptr<VersionBlock> version_block; // may be null
};

SavedVersion VariableVersion::save() const {
return SavedVersion(*this);
}

VariableVersion& VariableVersion::operator=(const SavedVersion& other) {
if (!other.version_block) {
throw std::runtime_error(
"Can't take version counter from empty SavedVersion. File a bug report.");
}
version_block = other.version_block;
return *this;
}

}} // namespace torch::autograd
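
A minimal sketch of how the simplified counter above is meant to be used (illustrative only; the SavedVariable bookkeeping is paraphrased, not quoted from this commit, and the include path for the header is assumed):

#include <cstdint>

// Assumes torch::autograd::VariableVersion as declared in the header above.
using torch::autograd::VariableVersion;

void version_counter_sketch() {
  VariableVersion counter;                     // starts at version 0
  uint32_t saved = counter.current_version();  // what a SavedVariable would record

  counter.bump();                              // an in-place op bumps the counter

  bool modified_since_save = (saved != counter.current_version());
  // A SavedVariable performing this comparison at unpack time is what detects
  // "variable needed for gradient computation has been modified" situations.
  (void)modified_since_save;
}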
@ -15,6 +15,7 @@
#include "torch/csrc/jit/passes/batch_mm.h"

#include "torch/csrc/autograd/function.h"
#include "torch/csrc/autograd/edge.h"

#include <unordered_map>

@ -82,26 +83,26 @@ private:
// inplace to avoid allocations
tensor_list unwrapVariables(variable_tensor_list && list) const {
for(auto & v : list) {
v = v.defined() ? static_cast<Variable&>(v).data() : at::Tensor();
v = v.defined() ? autograd::as_variable_ref(v).data() : at::Tensor();
}
return std::move(list);
}
// inplace to avoid allocations
variable_tensor_list wrapTensors(tensor_list && list) const {
for(auto & v : list) {
v = autograd::make_variable(v);
v = autograd::make_variable(v, /*requires_grad=*/false);
}
return variable_tensor_list(std::move(list));
}
// Capture (save) inputs that would be required to subsequently run backwards
void captureInputs(ExecutionPlanAutogradFunction & grad_fn, variable_tensor_list & inputs) const {
for(auto offset : grad.df_input_captured_inputs) {
grad_fn.captures.emplace_back(static_cast<Variable&>(inputs[offset]), false);
grad_fn.captures.emplace_back(autograd::as_variable_ref(inputs[offset]), false);
}
}
void captureOutputs(ExecutionPlanAutogradFunction & grad_fn, variable_tensor_list & outputs) const {
for(auto offset : grad.df_input_captured_outputs) {
grad_fn.captures.emplace_back(static_cast<Variable&>(outputs[offset]), true);
grad_fn.captures.emplace_back(autograd::as_variable_ref(outputs[offset]), true);
}
}

@ -111,7 +112,7 @@ private:
// hook up the outputs of df to the gradient functions of the inputs that require
// gradients
for(auto idx : grad.df_output_vjps) {
auto & v = static_cast<Variable&>(inputs[idx]);
auto & v = autograd::as_variable_ref(inputs[idx]);
// TODO: this kinda stuff is _way_ to low level to the public API of variable.
// Why do I have to care here whether v has a grad_fn or grad accumulator?
// Why do I have to care here about output_nr? I just want to say
@ -133,15 +134,12 @@ private:
// this is currently intentionally not done here so we can get an idea of our
// perf before introducing overhead for correctness
for(auto idx : grad.df_input_vjps) {
auto & o = static_cast<Variable&>(outputs[idx]);
auto impl = o.get();
// Note: we have to set this up in place, or we have to
// throw away and reallocate variables that were already created in
// wrapTensors. We should add an API for this, and more generally
// we need to clean up the fields of Variable.
impl->_grad_fn = grad_fn;
impl->output_nr = grad_fn->num_inputs++;
impl->_requires_grad = true;
// Note: we have to set this up in place, or we have to throw away and
// reallocate variables that were already created in wrapTensors. We
// should add an API for this.
auto& output = autograd::as_variable_ref(outputs[idx]);
output.set_gradient_edge(autograd::Edge(grad_fn, grad_fn->num_inputs++));
output.set_requires_grad(true);
}
captureOutputs(*grad_fn, outputs);
// drop the temporary outputs so that we return the same number of
@ -1,11 +1,13 @@
#include "Python.h"
#include "interpreter.h"

#include "torch/csrc/jit/ir.h"
#include "torch/csrc/autograd/profiler.h"
#include "torch/csrc/jit/generated/aten_dispatch.h"
#include "torch/csrc/jit/pybind.h"
#include "torch/csrc/utils/auto_gil.h"
#include "torch/csrc/autograd/variable.h"
#include "torch/csrc/autograd/edge.h"
#include "torch/csrc/autograd/python_variable.h"
#include "torch/csrc/autograd/python_engine.h"
#include "torch/csrc/autograd/functions/special.h"
@ -62,12 +64,12 @@ struct HandleBuilder {
}
autograd::Variable addInput(at::Retainable* input, const VariableFlags & flags_) {
if(handle && flags_.requires_grad) {
auto gradient_edge = autograd::Edge(
handle->forward_inputs, handle->forward_inputs->num_inputs++);
return autograd::make_variable(
unsafeToTensorShare(input),
handle->forward_inputs->num_inputs++,
handle->forward_inputs);
unsafeToTensorShare(input), std::move(gradient_edge));
} else {
return autograd::make_variable(unsafeToTensorShare(input));
return autograd::make_variable(unsafeToTensorShare(input), /*requires_grad=*/false);
}
}
at::Retainable* addOutput(const autograd::Variable & output) {

@ -1,7 +1,13 @@
#include "Python.h"

#include "torch/csrc/autograd/edge.h"
#include "torch/csrc/autograd/variable.h"
#include "torch/csrc/autograd/function.h"
#include "torch/csrc/jit/interpreter_autograd_function.h"
#include "torch/csrc/jit/ir.h"
#include "torch/csrc/jit/tracer_state.h"

#include <ATen/ATen.h>

#include <algorithm>
#include <memory>
@ -135,9 +141,10 @@ autograd::variable_list InterpreterAutogradFunction::apply(
auto & flags = details.output_flags[i];
if (flags.requires_grad) { // See Note [Null-edge pruning]
if (!grad_fn) make_grad_fn();
result.push_back(autograd::make_variable(toutputs[i], grad_fn));
autograd::Edge edge(grad_fn, grad_fn->num_inputs++);
result.push_back(autograd::make_variable(toutputs[i], std::move(edge)));
} else {
result.push_back(autograd::make_variable(toutputs[i], false));
result.push_back(autograd::make_variable(toutputs[i], /*requires_grad=*/false));
}
}

@ -437,7 +437,7 @@ public:
return outputs_.back();
}
void eraseOutput(size_t i);


Block * addBlock();
void eraseBlock(size_t i);

@ -1,11 +1,15 @@
#pragma once

#include <Python.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "torch/csrc/DynamicTypes.h"
#include "torch/csrc/THP.h"
#include "torch/csrc/autograd/variable.h"
#include "torch/csrc/jit/interned_strings.h"
#include "torch/csrc/jit/tracer.h"

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

namespace py = pybind11;

@ -113,7 +113,7 @@ struct CompiledFunction {
std::size_t num_captured = fn_.captured_vars_.size();
// Check that no captured Variables were replaced by enter. It's hard to handle that.
for (std::size_t i = num_all_inputs - num_captured; i < num_all_inputs; ++i) {
TORCH_EXPECTM(input_info.vars[i].get() == new_vars[i].get(),
TORCH_EXPECTM(input_info.vars[i].is_same(new_vars[i]),
"Some of the Variables captured by the JIT are repeated");
}
// Now only arguments to this function could have changed. Slice their vars out, and

@ -463,7 +463,7 @@ struct ADTestSpec {
std::vector<Variable> make_vars() const {
std::vector<Variable> out;
for (const auto & m : input_meta) {
out.emplace_back(make_variable(at::CPU(at::kFloat).tensor(m).normal_(), true));
out.emplace_back(autograd::make_variable(at::CPU(at::kFloat).tensor(m).normal_(), /*requires_grad=*/true));
}
return out;
}

@ -8,6 +8,9 @@
#include "torch/csrc/utils/auto_gil.h"
#include "torch/csrc/utils/python_strings.h"

#include <string>
#include <sstream>
#include <memory>
#include <frameobject.h>
#include <patchlevel.h>

@ -7,6 +7,7 @@
#include "torch/csrc/utils/variadic.h"
#include "torch/csrc/autograd/function_hook.h"
#include "torch/csrc/autograd/variable.h"
#include "torch/csrc/utils/auto_unique_ptr.h"

#include <memory>
#include <mutex>
@ -26,12 +27,13 @@ std::string getPythonInterpreterStackTrace();
namespace detail {

inline ValueTracingStateElem* getValueState(const std::shared_ptr<TracingState>& state, const Variable& var, bool alloc = true) {
for (auto it = var.tracing_state()->begin(); it != var.tracing_state()->end();) {
auto& tracing_state = var.tracing_state();
for (auto it = tracing_state.begin(); it != tracing_state.end();) {
auto ts = it->state.lock();
// GC of invalidated tracing states
if (!ts) {
auto current_it = it++;
var.tracing_state()->erase(current_it);
tracing_state.erase(current_it);
continue;
} else if (ts == state) {
return &(*it);
@ -39,8 +41,8 @@ inline ValueTracingStateElem* getValueState(const std::shared_ptr<TracingState>&
++it;
}
if (alloc) {
var.tracing_state()->emplace_front();
auto & vts = var.tracing_state()->front();
tracing_state.emplace_front();
auto & vts = tracing_state.front();
vts.state = state;
return &vts;
} else {
@ -70,8 +72,8 @@ inline std::vector<VariableFlags> getVarFlags(const variable_list& vars) {
// need it (in most cases if we have a variable_list it is already
// flattened).
inline bool isTracingVar(const Variable& var) {
if (!var.defined() || !var.tracing_state()) return false;
return std::any_of(var.tracing_state()->begin(), var.tracing_state()->end(), detail::isElemActive);
if (!var.defined() || !var.has_tracing_state()) return false;
return std::any_of(var.tracing_state().begin(), var.tracing_state().end(), detail::isElemActive);
}

inline bool isTracingVar(at::ArrayRef<Variable> vars) {
@ -104,8 +106,8 @@ inline bool isTracing(Args&&... args) {
inline std::shared_ptr<TracingState> getTracingState(const variable_list& vars) {
std::shared_ptr<TracingState> state;
for (auto& var : vars) {
if (!var.defined() || !var.tracing_state()) continue;
for (auto & vts : *var.tracing_state()) {
if (!var.defined() || !var.has_tracing_state()) continue;
for (auto & vts : var.tracing_state()) {
auto var_state = vts.state.lock();
if (!var_state || !var_state->active) continue;
if (!state) state = var_state;

38  torch/csrc/jit/tracer_state.cpp  Normal file
@ -0,0 +1,38 @@
#include "torch/csrc/jit/tracer_state.h"
#include "torch/csrc/autograd/edge.h"
#include "torch/csrc/autograd/variable.h"
#include "torch/csrc/jit/ir.h"

#include <atomic>
#include <cstdint>
#include <list>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

namespace torch { namespace jit { namespace tracer {
TracingState::TracingState(size_t num_stages)
: graph(new Graph()),
active(false),
num_stages(num_stages),
eval_count(0),
var_flags(num_stages),
output_edges(num_stages) {}

TracingState::~TracingState() = default;

bool TracingState::is_complete() const {
return !is_expired() && graph->stage() == num_stages - 1;
}

void TracingState::push_scope(const std::string& scope_name) {
graph->push_scope(scope_name);
}

void TracingState::pop_scope() {
graph->pop_scope();
}
}}} // namespace torch::jit::tracer
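
Moving these definitions out of line is presumably what lets tracer_state.h stop including ir.h; a minimal sketch of the general pattern, under hypothetical names (not the actual PyTorch headers):

// state.h -- only a forward declaration of Graph is needed here.
#include <memory>
#include <string>

struct Graph;  // incomplete type is fine for a shared_ptr member

struct State {
  State();   // declared here, defined in state.cpp where Graph is complete
  ~State();  // ditto: the implicit destructor would otherwise need Graph's definition

  std::shared_ptr<Graph> graph;
  void push_scope(const std::string& name);  // body touches Graph, so it lives in state.cpp
};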
@ -1,28 +1,28 @@
#pragma once

#include "torch/csrc/jit/ir.h"
#include "torch/csrc/assertions.h"
#include "torch/csrc/autograd/edge.h"
#include "torch/csrc/autograd/variable.h"

#include <memory>
#include <mutex>
#include <vector>
#include <atomic>
#include <cstdint>
#include <list>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

namespace torch { namespace autograd {

struct Variable;
struct Function;

}}
namespace torch { namespace jit {
struct Graph;
struct Value;
struct VariableFlags;
}} // namespace torch::jit

namespace torch { namespace jit { namespace tracer {

using torch::autograd::Variable;
using torch::autograd::Function;
using edge_list = std::vector<autograd::Edge>;
using variable_list = std::vector<Variable>;
using variable_list = std::vector<autograd::Variable>;

// TracingState tracks the necessary state when we are tracing the execution of
// autograd code; most importantly, it holds a reference to the actual IR
@ -34,26 +34,19 @@ using variable_list = std::vector<Variable>;
// from arising when a variable that participated in a trace outlives the
// actual trace itself.

using io_variable_flags_list =
std::vector<std::pair<std::vector<VariableFlags>, std::vector<VariableFlags>>>;
using io_variable_flags_list = std::vector<
std::pair<std::vector<VariableFlags>, std::vector<VariableFlags>>>;

struct TracingState : public std::enable_shared_from_this<TracingState> {
TracingState(std::size_t num_stages)
: graph(new Graph())
, active(false)
, num_stages(num_stages)
, eval_count(0)
, var_flags(num_stages)
, output_edges(num_stages) {}
explicit TracingState(size_t num_stages);
~TracingState();

// XXX: graph can be NULL if it's a failed trace (failed = didn't capture all
// the stages we care about)
std::shared_ptr<Graph> graph;
bool active;

// Used to free the Graph as soon as we know this trace will fail
std::size_t num_stages;
std::atomic<std::size_t> eval_count;
size_t num_stages;
std::atomic<size_t> eval_count;

// void* is an unsafe TH. NON-OWNING, so it might get invalidated.
// TODO: Perhaps, turn this into an owning reference. The buffers
@ -66,23 +59,17 @@ struct TracingState : public std::enable_shared_from_this<TracingState> {
std::mutex mutex;
variable_list inputs; // Used only for the duration of first stage

std::unique_lock<std::mutex> lock() { return std::unique_lock<std::mutex>(mutex); };
std::unique_lock<std::mutex> lock() {
return std::unique_lock<std::mutex>(mutex);
}

bool is_expired() const {
bool is_expired() const noexcept {
return !graph;
}

bool is_complete() const {
return !is_expired() && graph->stage() == num_stages - 1;
}

void push_scope(const std::string& scope_name) {
graph->push_scope(scope_name);
}

void pop_scope() {
graph->pop_scope();
}
bool is_complete() const;
void push_scope(const std::string& scope_name);
void pop_scope();
};

struct ValueTracingStateElem {
@ -102,4 +89,4 @@ struct FunctionTracingState {
bool in_eval_subgraph = false;
};

}}}
}}} // namespace torch::jit::tracer

@ -1,5 +1,7 @@
#include "torch/csrc/jit/variable_flags.h"

#include "torch/csrc/autograd/variable.h"
#include "torch/csrc/jit/tracer_state.h"

using torch::autograd::Variable;

@ -2,7 +2,6 @@

#include <functional>
#include <vector>
#include <ATen/ATen.h>

namespace torch {

@ -93,14 +93,14 @@ static Tensor new_from_data(ScalarType scalarType, PyObject* data) {
}
#ifdef WITH_NUMPY
if (PyArray_Check(data)) {
return autograd::make_variable(tensor_from_numpy(data), false);
return autograd::make_variable(tensor_from_numpy(data), /*requires_grad=*/false);
}
#endif

auto sizes = compute_sizes(data);
auto tensor = autograd::make_variable(CPU(scalarType).tensor(sizes), false);
// TODO: we should pass tensor.sizes() rather than sizes, but this doesn't works
// if scalars are disabled because the size changes without WITH_SCALARS.
auto tensor = autograd::make_variable(CPU(scalarType).tensor(sizes), /*requires_grad=*/false);
recursive_store(
(char*)tensor.data_ptr(), sizes, tensor.strides(), 0,
scalarType, tensor.type().elementSizeInBytes(), data);
@ -198,7 +198,7 @@ Tensor legacy_tensor_ctor(const Type& type, PyObject* args, PyObject* kwargs) {
}

static Tensor set_requires_grad(Tensor self, bool requires_grad) {
static_cast<torch::autograd::Variable&>(self).get()->_requires_grad = requires_grad;
static_cast<torch::autograd::Variable&>(self).set_requires_grad(requires_grad);
return self;
}

@ -1,12 +1,15 @@
#include "tuple_parser.h"

#include <string>

#include "torch/csrc/DynamicTypes.h"
#include "torch/csrc/autograd/python_variable.h"
#include "python_strings.h"
#include "python_numbers.h"

#include <string>
#include <stdexcept>
#include <vector>

namespace torch {

TupleParser::TupleParser(PyObject* args, int num_args) : args(args), idx(0) {

@ -3,6 +3,11 @@
#include <ATen/ATen.h>
#include "torch/csrc/autograd/variable.h"

#include <cstdint>
#include <utility>
#include <tuple>
#include <type_traits>

namespace torch {

// This class allows you to write variadic functions which
@ -84,4 +89,23 @@ inline size_t count_variables(Args&&... args) {
return CountVariables().apply(std::forward<Args>(args)...).out;
}

//===----------------------------------------------------------------------===//
// std::index_sequence shim for C++11
//===----------------------------------------------------------------------===//

// A container of type-template parameter indices.
template<size_t... Is>
struct Indices {};

// Decrements the index N, adds N-1 to the list of indices and forwards
// whatever we arleady have.
template<size_t N, size_t... Is>
struct MakeIndices : MakeIndices<N-1, N-1, Is...> {};

// Partial specialization that forms our base case. When N is zero, we stop
// and define a typedef that will be visible to earlier classes due to
// inheritance. The typedef we define is an index list containing the numbers
// 0 through N-1.
template<size_t... Is>
struct MakeIndices<0, Is...> { using indices = Indices<Is...>; };
} // namespace torch
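
A quick illustration of how an index shim like this is typically consumed, for example to expand a tuple into a function call (hedged sketch: apply/apply_impl are hypothetical helpers, not part of this commit):

#include <cstddef>
#include <tuple>
#include <utility>

// Assumes torch::Indices / torch::MakeIndices from the header above.
template <typename F, typename Tuple, size_t... Is>
auto apply_impl(F f, Tuple&& t, torch::Indices<Is...>)
    -> decltype(f(std::get<Is>(std::forward<Tuple>(t))...)) {
  return f(std::get<Is>(std::forward<Tuple>(t))...);
}

template <typename F, typename... Ts>
auto apply(F f, std::tuple<Ts...> t)
    -> decltype(apply_impl(f, std::move(t),
                           typename torch::MakeIndices<sizeof...(Ts)>::indices())) {
  // MakeIndices<N>::indices is Indices<0, 1, ..., N-1>, which drives the pack
  // expansion above so every tuple element is pulled out in order.
  return apply_impl(f, std::move(t),
                    typename torch::MakeIndices<sizeof...(Ts)>::indices());
}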