Add tensor post accumulate grad hook API (#107063)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107063
Approved by: https://github.com/albanD, https://github.com/soulitzer
This commit is contained in:
Jane Xu
2023-08-21 15:17:14 -07:00
committed by PyTorch MergeBot
parent bcede143bd
commit 3f655277d4
14 changed files with 379 additions and 8 deletions

View File

@ -36,6 +36,7 @@
#include <ATen/ATen.h>
#include <autograd/function_hook.h>
#include <c10/core/SymIntArrayRef.h>
#include <structmember.h>
#include <cstdint>
@ -427,6 +428,7 @@ static int THPVariable_clear(THPVariable* self) {
return 0;
}
Py_CLEAR(self->backward_hooks);
Py_CLEAR(self->post_accumulate_grad_hooks);
const auto& tensor = THPVariable_Unpack(self);
if (tensor.defined()) {
// Two situations to consider:
@ -1162,6 +1164,47 @@ int THPVariable_set_backwards_hooks(
END_HANDLE_TH_ERRORS_RET(-1)
}
PyObject* THPVariable_get_post_accumulate_grad_hooks(
THPVariable* self,
void* unused) {
HANDLE_TH_ERRORS
if (check_has_torch_function((PyObject*)self)) {
return handle_torch_function_getter(self, "_post_accumulate_grad_hooks");
}
if (self->post_accumulate_grad_hooks) {
Py_INCREF(self->post_accumulate_grad_hooks);
return self->post_accumulate_grad_hooks;
}
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
int THPVariable_set_post_accumulate_grad_hooks(
THPVariable* self,
PyObject* obj,
void* unused) {
HANDLE_TH_ERRORS
if (check_has_torch_function((PyObject*)self)) {
return handle_torch_function_setter(
self, "_post_accumulate_grad_hooks", obj);
}
THPUtils_assertRet(
-1, obj, "Deletion of _post_accumulate_grad_hooks not allowed!");
if (obj == Py_None) {
obj = nullptr;
}
Py_XINCREF(obj);
Py_CLEAR(self->post_accumulate_grad_hooks);
self->post_accumulate_grad_hooks = obj;
const auto& tensor = THPVariable_Unpack(self);
if (obj) {
torch::autograd::impl::set_post_acc_grad_hooks(
tensor, std::make_unique<PyFunctionTensorPostAccGradHooks>(obj));
}
return 0;
END_HANDLE_TH_ERRORS_RET(-1)
}
PyObject* THPVariable_get_base(THPVariable* self, void* unused) {
HANDLE_TH_ERRORS
if (check_has_torch_function((PyObject*)self)) {
@ -1476,6 +1519,11 @@ static struct PyGetSetDef THPVariable_properties[] = {
(setter)THPVariable_set_backwards_hooks,
nullptr,
nullptr},
{"_post_accumulate_grad_hooks",
(getter)THPVariable_get_post_accumulate_grad_hooks,
(setter)THPVariable_set_post_accumulate_grad_hooks,
nullptr,
nullptr},
{"name", (getter)THPVariable_get_name, nullptr, nullptr, nullptr},
{"shape", (getter)THPVariable_get_shape, nullptr, nullptr, nullptr},
{"is_cuda", (getter)THPVariable_is_cuda, nullptr, nullptr, nullptr},
@ -2033,6 +2081,7 @@ static int THPVariable_subclass_traverse(
// Finally traverse THPVariable special stuff
Py_VISIT(var->backward_hooks);
Py_VISIT(var->post_accumulate_grad_hooks);
if (!var->cdata.unsafeIsBorrowed()) {
const auto& tensor = THPVariable_Unpack(var);
if (tensor.defined()) {