Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/30915

Since we now have C++14, we don't need these c10::guts helpers anymore.

ghstack-source-id: 95777609
Test Plan: waitforsandcastle
Differential Revision: D18869639
fbshipit-source-id: 97716f932297c64c6e814410ac47b444c33d4e2e
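The gist of the change described above (an illustrative sketch, not the actual diff; the call site below is hypothetical):

    // Before: pre-C++14 helper from c10::guts
    impl->set_autograd_meta(c10::guts::make_unique<AutogradMeta>());
    // After: plain C++14 standard library, as used throughout this file
    impl->set_autograd_meta(std::make_unique<AutogradMeta>());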
371 lines
13 KiB
C++
#include <torch/csrc/autograd/variable.h>

#include <torch/csrc/autograd/autograd.h>
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/functions/accumulate_grad.h>
#include <torch/csrc/autograd/functions/tensor.h>
#include <torch/csrc/autograd/generated/Functions.h>

#include <ATen/core/VariableHooksInterface.h>

#include <ATen/ATen.h>
#include <c10/util/Exception.h>

#include <list>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>
#include <typeinfo>
#include <vector>

namespace torch {
namespace autograd {

DifferentiableViewMeta::DifferentiableViewMeta(at::TensorImpl* self_impl, Variable base)
    : AutogradMeta(self_impl, false) {
  base_ = std::move(base);
  TORCH_CHECK(base_.defined(), "base is undefined");
  if (base_.is_view()) {
    base_ = base_.base();
  }
  is_view_ = true;
  self_impl->set_version_counter(impl::version_counter(base_));
  attr_version = self_impl->version_counter().current_version();
}

DifferentiableViewMeta::~DifferentiableViewMeta() {
  base_.reset();
}

namespace {

at::Tensor singleton_undefined_tensor;

struct ConcreteAutogradMetaFactory : public c10::impl::AutogradMetaFactory {
  std::unique_ptr<c10::AutogradMetaInterface> make() const override {
    return std::make_unique<AutogradMeta>();
  }
  const at::Tensor& undefined_tensor() const override {
    return singleton_undefined_tensor;
  }
};

ConcreteAutogradMetaFactory meta_factory;

static c10::impl::AutogradMetaFactoryRegisterer meta_factory_registerer(&meta_factory);

} // namespace

namespace impl {

AutogradMeta* materialize_autograd_meta(const Variable& self) {
  TORCH_CHECK(self.defined(), "cannot call materialize_autograd_meta() on undefined tensor");
  auto p = self.unsafeGetTensorImpl();
  if (!p->autograd_meta()) {
    p->set_autograd_meta(std::make_unique<AutogradMeta>());
  }
  return get_autograd_meta(self);
}
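
// Illustrative sketch (editor's note, not part of the original file): materialize_autograd_meta()
// lazily allocates the AutogradMeta, whereas get_autograd_meta() below may return null:
//
//   auto* maybe_meta = impl::get_autograd_meta(t);    // nullptr until autograd state exists
//   auto* meta = impl::materialize_autograd_meta(t);  // allocates on first use, never nullptr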

void rebase_history(const Variable& self, Edge gradient_edge) {
  AT_ASSERT(gradient_edge.function != nullptr);
  if (self.is_view()) {
    // NB: is_view() ==> get_autograd_meta()
    auto diff_view_meta = static_cast<DifferentiableViewMeta*>(get_autograd_meta(self));
    AT_ASSERT(gradient_edge.input_nr == 0);
    AT_ASSERT(gradient_edge.function);
    TORCH_CHECK(
        gradient_edge.function->num_inputs() == 1,
        "Functions which modify views in-place must return a single Variable");
    diff_view_meta->output_nr_ = gradient_edge.input_nr;
    auto copy_slices = std::make_shared<CopySlices>(
        diff_view_meta->base_, at::TensorGeometry(self), std::move(gradient_edge.function));
    set_gradient_edge(diff_view_meta->base_, {std::move(copy_slices), 0});
    self.grad_fn(); // trigger an update to the view's grad_fn
  } else {
    set_gradient_edge(self, std::move(gradient_edge));
  }
}
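
// Illustrative sketch (editor's note, not part of the original file): rebase_history() is invoked
// by in-place operations. For a view, the new history is attached to the base via CopySlices
// rather than to the view itself, e.g. (hypothetical user code):
//
//   auto base = torch::randn({4}, torch::requires_grad());
//   auto view = base.narrow(0, 0, 2);
//   view.mul_(2);  // in-place op rebases history; the view's grad_fn is recomputed lazily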

void create_cpp_hook(const Variable& self) {
  auto &list = materialize_autograd_meta(self)->cpp_hooks_list;
  list.reset(new hooks_list());
  std::unique_ptr<FunctionPreHook> hook_ptr(new CppFunctionPreHook(list, self.output_nr()));
  clear_hooks(self);
  add_hook(self, std::make_shared<CppFunctionPreHook>(list, 0));
  auto fn = self.grad_fn();
  if (fn) {
    fn->add_pre_hook(std::move(hook_ptr));
  }
}

void set_grad_accumulator(const Variable& self,
                          std::weak_ptr<Node> grad_accumulator) {
  materialize_autograd_meta(self)->grad_accumulator_ = std::move(grad_accumulator);
}

std::shared_ptr<Node> try_get_grad_accumulator(const Variable& self) {
  if (get_autograd_meta(self)) {
    return get_autograd_meta(self)->grad_accumulator_.lock();
  } else {
    return nullptr;
  }
}

std::shared_ptr<Node> grad_accumulator(const Variable& self) {
  auto autograd_meta = get_autograd_meta(self);
  if (!autograd_meta) {
    return nullptr;
  }
  if (autograd_meta->grad_fn_) {
    throw std::logic_error(
        "grad_accumulator() should be only called on leaf Variables");
  }
  if (!autograd_meta->requires_grad_) {
    return nullptr;
  }

  std::lock_guard<std::mutex> lock(autograd_meta->mutex_);

  auto result = autograd_meta->grad_accumulator_.lock();
  if (result)
    return result;

  c10::raw::intrusive_ptr::incref(self.unsafeGetTensorImpl());
  auto intrusive_from_this = c10::intrusive_ptr<at::TensorImpl>::reclaim(self.unsafeGetTensorImpl());
  result = std::make_shared<AccumulateGrad>(Variable(std::move(intrusive_from_this)));
  autograd_meta->grad_accumulator_ = result;
  return result;
}
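
// Illustrative sketch (editor's note, not part of the original file): only leaf Variables with
// requires_grad=true get an AccumulateGrad node, and it is cached through a weak_ptr:
//
//   auto leaf = torch::ones({2}, torch::requires_grad());
//   auto acc = impl::grad_accumulator(leaf);    // creates the AccumulateGrad node
//   auto again = impl::grad_accumulator(leaf);  // returns the same node while it is alive
//   AT_ASSERT(acc == again);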

Edge gradient_edge(const Variable& self) {
  // If grad_fn is null (as is the case for a leaf node), we instead
  // interpret the gradient function to be a gradient accumulator, which will
  // accumulate its inputs into the grad property of the variable. These
  // nodes get suppressed in some situations, see "suppress gradient
  // accumulation" below. Note that only variables which have `requires_grad =
  // True` can have gradient accumulators.
  if (const auto& gradient = self.grad_fn()) {
    return Edge(gradient, self.output_nr());
  } else {
    return Edge(grad_accumulator(self), 0);
  }
}
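
// Illustrative sketch (editor's note, not part of the original file; the variable names are
// placeholders): for an interior Variable the edge points at its grad_fn, for a leaf it points
// at the (possibly newly created) gradient accumulator:
//
//   auto e1 = impl::gradient_edge(some_result);  // {some_result.grad_fn(), output_nr}
//   auto e2 = impl::gradient_edge(some_leaf);    // {grad_accumulator(some_leaf), 0}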

void set_gradient_edge(const Variable& self, Edge edge) {
  auto* meta = materialize_autograd_meta(self);
  meta->grad_fn_ = std::move(edge.function);
  meta->output_nr_ = edge.input_nr;
}

Node* grad_fn_unsafe(const Variable& self) {
  if (get_autograd_meta(self)) {
    return get_autograd_meta(self)->grad_fn_.get();
  } else {
    return nullptr;
  }
}

// Versions
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

void set_version_counter(
    const Variable& self,
    const c10::VariableVersion& version_counter) {
  TORCH_CHECK(self.defined(), "cannot call set_version_counter() on undefined tensor");
  self.unsafeGetTensorImpl()->set_version_counter(version_counter);
}

void bump_version(const Variable& self) {
  TORCH_CHECK(self.defined(), "cannot call bump_version() on undefined tensor");
  self.unsafeGetTensorImpl()->bump_version();
}

const c10::VariableVersion& version_counter(const Variable& self) {
  TORCH_CHECK(self.defined(), "cannot call version_counter() on undefined tensor");
  return self.unsafeGetTensorImpl()->version_counter();
}
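
// Illustrative sketch (editor's note, not part of the original file): in-place kernels are
// expected to bump the version counter so that tensors saved for backward can detect staleness:
//
//   auto v0 = impl::version_counter(t).current_version();
//   impl::bump_version(t);  // normally done by the in-place op itself
//   AT_ASSERT(impl::version_counter(t).current_version() == v0 + 1);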

// Hooks
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

void add_hook(const Variable& self, std::shared_ptr<FunctionPreHook> hook) {
  materialize_autograd_meta(self)->hooks_.push_back(std::move(hook));
}

namespace {
  std::vector<std::shared_ptr<FunctionPreHook>> empty_singleton;
}

// TODO: Return an ArrayRef instead (and delete the singleton while you're at
// it)
const std::vector<std::shared_ptr<FunctionPreHook>>& hooks(const Variable& self) {
  if (get_autograd_meta(self)) {
    return get_autograd_meta(self)->hooks_;
  } else {
    return empty_singleton;
  }
}

void clear_hooks(const Variable& self) {
  // This is a little goofy, but usually this should be a no-op
  materialize_autograd_meta(self)->hooks_.clear();
}

void set_name(const Variable& self, const std::string& name) {
  materialize_autograd_meta(self)->name_ = name;
}

// Miscellaneous
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

void set_pyobj(const Variable& self, PyObject* pyobj) {
  TORCH_CHECK(self.defined(), "cannot call set_pyobj() on undefined tensor");
  self.unsafeGetTensorImpl()->set_pyobj(pyobj);
}

PyObject* pyobj(const Variable& self) {
  TORCH_CHECK(self.defined(), "cannot call pyobj() on undefined tensor");
  return self.unsafeGetTensorImpl()->pyobj();
}

AutogradMeta* get_autograd_meta(const Variable& self) {
  // NB: could return null
  TORCH_CHECK(self.defined(), "cannot call get_autograd_meta() on undefined tensor");
  return static_cast<AutogradMeta*>(self.unsafeGetTensorImpl()->autograd_meta());
}

} // namespace impl

using at::Tensor;

struct VariableHooks final : at::impl::VariableHooksInterface {
  Tensor tensor_data(const Tensor&) const override;
  Tensor variable_data(const Tensor&) const override;
  const std::shared_ptr<torch::autograd::Node>& grad_fn(const Tensor&) const override;
  unsigned _register_hook(const Tensor&, std::function<Tensor(const Tensor&)> hook) const override;
  void remove_hook(const Tensor&, unsigned pos) const override;
  bool is_view(const Tensor&) const override;
  const Tensor& base(const Tensor&) const override;
  const std::string& name(const Tensor&) const override;
};

VariableHooks variableHooks;
at::impl::VariableHooksRegisterer registerVariableHooks(&variableHooks);

Tensor VariableHooks::variable_data(const Tensor& self) const {
  TORCH_CHECK(self.defined(), "cannot call variable_data() on undefined tensor");
  auto self_impl_copy = self.unsafeGetTensorImpl()->shallow_copy_and_detach(
      /*version_counter=*/0,
      /*allow_tensor_metadata_change=*/false);
  self_impl_copy->set_autograd_meta(nullptr);
  return at::Tensor(self_impl_copy);
}

Tensor VariableHooks::tensor_data(const Tensor& self) const {
  TORCH_CHECK(self.defined(), "cannot call tensor_data() on undefined tensor");
  auto self_impl_copy = self.unsafeGetTensorImpl()->shallow_copy_and_detach(
      /*version_counter=*/self.unsafeGetTensorImpl()->version_counter(),
      /*allow_tensor_metadata_change=*/self.unsafeGetTensorImpl()->allow_tensor_metadata_change());
  return at::Tensor(self_impl_copy);
}
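
// Illustrative sketch (editor's note, not part of the original file): both functions return a
// shallow copy of the TensorImpl, but variable_data() drops autograd metadata and starts a fresh
// version counter, while tensor_data() shares the original one:
//
//   Tensor v = variableHooks.variable_data(t);  // no AutogradMeta; version counter reset to 0
//   Tensor d = variableHooks.tensor_data(t);    // keeps t's version counter and metadata-change flag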

// View Variables
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

bool VariableHooks::is_view(const Tensor& self) const {
  if (torch::autograd::impl::get_autograd_meta(self)) {
    return torch::autograd::impl::get_autograd_meta(self)->is_view_;
  } else {
    return false;
  }
}

const Tensor& VariableHooks::base(const Tensor& self) const {
  if (self.is_view()) {
    // is_view() implies get_autograd_meta()
    auto diff_view_meta = static_cast<torch::autograd::DifferentiableViewMeta*>(torch::autograd::impl::get_autograd_meta(self));
    return diff_view_meta->base_;
  } else {
    throw std::runtime_error("Can't get base of non-view Variable");
  }
}

namespace {
  std::string singleton_string;
}

const std::string& VariableHooks::name(const Tensor& self) const {
  TORCH_CHECK(self.defined(), "cannot call name() on undefined tensor");
  if (torch::autograd::impl::get_autograd_meta(self)) {
    return torch::autograd::impl::get_autograd_meta(self)->name_;
  } else {
    return singleton_string;
  }
}

namespace {
  std::shared_ptr<torch::autograd::Node> singleton_shared_ptr;
}

const std::shared_ptr<torch::autograd::Node>& VariableHooks::grad_fn(const Tensor& self) const {
  if (self.is_view()) {
    // NB: is_view() ==> get_autograd_meta()
    auto diff_view_meta = static_cast<torch::autograd::DifferentiableViewMeta*>(torch::autograd::impl::get_autograd_meta(self));
    std::lock_guard<std::mutex> lock(diff_view_meta->mutex_);
    if (!diff_view_meta->grad_fn_ && !diff_view_meta->base_.requires_grad()) {
      return diff_view_meta->grad_fn_;
    }
    auto current_version = self._version();
    if (diff_view_meta->attr_version != current_version) {
      AT_ASSERT(diff_view_meta->output_nr_ == 0);
      auto fn = std::make_shared<torch::autograd::generated::AsStridedBackward>();
      fn->self_geometry = at::TensorGeometry(diff_view_meta->base_);
      fn->size = self.sizes().vec();
      fn->stride = self.strides().vec();
      fn->storage_offset = self.storage_offset();
      fn->set_next_edges(torch::autograd::collect_next_edges(diff_view_meta->base_));
      fn->add_input_metadata(
          diff_view_meta->base_.options(),
          self.sizes(), // Note: sizes(), not base_.sizes(), is intentional
          diff_view_meta->base_.device());
      diff_view_meta->grad_fn_ = std::move(fn);
      diff_view_meta->attr_version = current_version;
    }
    return diff_view_meta->grad_fn_;
  } else {
    if (torch::autograd::impl::get_autograd_meta(self)) {
      return torch::autograd::impl::get_autograd_meta(self)->grad_fn_;
    } else {
      return singleton_shared_ptr;
    }
  }
}
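
// Illustrative sketch (editor's note, not part of the original file): the version check above
// lazily rebuilds a view's grad_fn after an in-place modification, e.g. (hypothetical, with
// `base` a tensor that already has gradient history):
//
//   auto view = base.narrow(0, 0, 1);
//   view.add_(1);              // bumps the shared version counter, rebases history onto base
//   auto fn = view.grad_fn();  // attr_version mismatch -> a fresh AsStridedBackward node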

void VariableHooks::remove_hook(const Tensor& self, unsigned pos) const {
  auto &list = torch::autograd::impl::materialize_autograd_meta(self)->cpp_hooks_list;
  TORCH_CHECK(list && pos < list->size(), "Invalid index, no hook at position ", pos);
  // Hook will be ignored
  (*list)[pos] = nullptr;
}

unsigned VariableHooks::_register_hook(const Tensor& self, std::function<Tensor(const Tensor&)> hook) const {
  TORCH_CHECK(self.requires_grad(), "cannot register a hook on a variable that "
              "doesn't require gradient");
  // NB: materialize_autograd_meta unnecessary due to requires grad check
  auto &list = torch::autograd::impl::get_autograd_meta(self)->cpp_hooks_list;
  if (!list) {
    torch::autograd::impl::create_cpp_hook(self);
  }
  unsigned idx = list->size();
  list->push_back(hook);
  return idx;
}
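
// Illustrative sketch (editor's note, not part of the original file): these two methods back the
// C++ gradient-hook machinery; the index returned by _register_hook can later be passed to
// remove_hook (here `t` is a hypothetical tensor with requires_grad=true):
//
//   auto idx = variableHooks._register_hook(t, [](const Tensor& grad) { return grad * 2; });
//   variableHooks.remove_hook(t, idx);  // the slot is nulled out and the hook is skipped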

}} // namespace torch::autograd