Mirror of https://github.com/pytorch/pytorch.git, synced 2025-11-05 00:14:54 +08:00
Move most methods off Variable into torch::autograd::impl functions. (#29665)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/29665

Our intention is to merge the static distinction between Tensor and Variable. Ordinarily, this would entail merging the methods of Tensor and Variable. But there are a lot of "private"-ish methods on Variable that we don't actually want to dump onto the Tensor class. So, as prep work, we move all of those methods off of Variable and into the torch::autograd::impl namespace (impl as in: please don't use this, end users). This ends up being a fairly large patch, because all of the call sites have to play ball too.

While I was on the topic, I also moved the touched functions into the C++ file, so that modifying them would not trigger a recompilation of all of torch.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Differential Revision: D18496169

Pulled By: ezyang

fbshipit-source-id: afb203252620ec274be596b3e7b1d84d321bad3a
Committed by: Facebook Github Bot
Parent: 38340f59fd
Commit: 1ab2f043ba
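For orientation before the diff: below is a minimal sketch of what the moved declarations might look like in the torch::autograd::impl namespace. The function names are taken from the call sites in this patch; the exact parameter and return types (const Variable&, std::shared_ptr<Node>, the FunctionPreHook vector) are assumptions inferred from usage, not the verbatim header.

// Sketch only: declarations inferred from the call sites in this patch.
// Exact parameter and return types are assumptions, not the real header.
namespace torch { namespace autograd { namespace impl {

  // Formerly Variable::set_pyobj / Variable::pyobj
  void set_pyobj(const Variable& self, PyObject* pyobj);
  PyObject* pyobj(const Variable& self);

  // Formerly Variable::add_hook / Variable::clear_hooks / Variable::hooks
  void add_hook(const Variable& self, std::shared_ptr<FunctionPreHook> hook);
  const std::vector<std::shared_ptr<FunctionPreHook>>& hooks(const Variable& self);
  void clear_hooks(const Variable& self);

  // Formerly Variable::try_get_grad_accumulator
  std::shared_ptr<Node> try_get_grad_accumulator(const Variable& self);

}}} // namespace torch::autograd::impl

Each hunk below then rewrites one call site from the old method syntax to the new free-function syntax.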
@@ -57,7 +57,7 @@ static PyObject* THPVariable_NewWithVar(PyTypeObject* type, Variable var)
   if (obj) {
     auto v = (THPVariable*) obj;
     new (&v->cdata) Variable(std::move(var));
-    v->cdata.set_pyobj(obj);
+    torch::autograd::impl::set_pyobj(v->cdata, obj);
   }
   return obj;
 }
@@ -68,7 +68,7 @@ PyObject * THPVariable_Wrap(Variable var)
     Py_RETURN_NONE;
   }
 
-  if (auto obj = var.pyobj()) {
+  if (auto obj = torch::autograd::impl::pyobj(var)) {
     Py_INCREF(obj);
     return obj;
   }
@@ -94,7 +94,7 @@ static int THPVariable_traverse(THPVariable *self, visitproc visit, void *arg)
   // for more details about the race condition involving traversing the grad_fn
   // and the python GC.
   if (self->cdata.defined()) {
-    for (const auto& hook : self->cdata.hooks()) {
+    for (const auto& hook : torch::autograd::impl::hooks(self->cdata)) {
       if (auto pyhook = dynamic_cast<PyFunctionPreHook*>(hook.get())) {
         Py_VISIT(pyhook->dict);
       }
@@ -107,7 +107,7 @@ static int THPVariable_clear(THPVariable *self)
 {
   Py_CLEAR(self->backward_hooks);
   if (self->cdata.defined()) {
-    if (auto grad_acc = self->cdata.try_get_grad_accumulator()) {
+    if (auto grad_acc = torch::autograd::impl::try_get_grad_accumulator(self->cdata)) {
       grad_acc->pre_hooks().clear();
     }
     // We must clear the pyobj field in the base C++ Variable, to ensure
@@ -123,7 +123,7 @@ static int THPVariable_clear(THPVariable *self)
     // objects stay live, buster! See
     // https://github.com/pytorch/pytorch/issues/22884 for an example of
     // this actually showing up.
-    self->cdata.set_pyobj(nullptr);
+    torch::autograd::impl::set_pyobj(self->cdata, nullptr);
   }
   self->cdata.reset();
   return 0;
@@ -419,9 +419,9 @@ int THPVariable_set_backwards_hooks(THPVariable *self, PyObject *obj, void *unus
   Py_XINCREF(obj);
   Py_XDECREF(self->backward_hooks);
   self->backward_hooks = obj;
-  self->cdata.clear_hooks();
+  torch::autograd::impl::clear_hooks(self->cdata);
   if (obj) {
-    self->cdata.add_hook(std::make_shared<PyFunctionPreHook>(obj, 0));
+    torch::autograd::impl::add_hook(self->cdata, std::make_shared<PyFunctionPreHook>(obj, 0));
   }
   return 0;
   END_HANDLE_TH_ERRORS_RET(-1)