commit 3f655277d4 (parent: bcede143bd)
Committed by: PyTorch MergeBot

Add tensor post accumulate grad hook API (#107063)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107063
Approved by: https://github.com/albanD, https://github.com/soulitzer
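This PR exposes a per-tensor hook point that fires after a gradient has been accumulated into a leaf tensor's .grad field. As a quick orientation before the C++ diff, here is a minimal usage sketch of the public entry point the PR adds, Tensor.register_post_accumulate_grad_hook; the learning-rate value and hook body are illustrative only:

import torch

# Sketch of the user-facing API added by this PR: the hook runs after the
# tensor's gradient has been accumulated into .grad during backward.
w = torch.randn(3, requires_grad=True)

def sgd_step(param):
    # `param` is the tensor itself; .grad is already populated here.
    with torch.no_grad():
        param -= 0.01 * param.grad  # 0.01 is an illustrative learning rate
    param.grad = None  # free the gradient eagerly

handle = w.register_post_accumulate_grad_hook(sgd_step)
(w * 2.0).sum().backward()  # sgd_step fires once w.grad is accumulated
handle.remove()             # the handle removes the hook like other tensor hooks

The hunks below are against the THPVariable Python bindings (torch/csrc/autograd/python_variable.cpp in the PyTorch tree) and add the backing post_accumulate_grad_hooks slot plus its lifecycle management.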
@@ -36,6 +36,7 @@
 
 #include <ATen/ATen.h>
 
+#include <autograd/function_hook.h>
 #include <c10/core/SymIntArrayRef.h>
 #include <structmember.h>
 #include <cstdint>
@@ -427,6 +428,7 @@ static int THPVariable_clear(THPVariable* self) {
     return 0;
   }
   Py_CLEAR(self->backward_hooks);
+  Py_CLEAR(self->post_accumulate_grad_hooks);
   const auto& tensor = THPVariable_Unpack(self);
   if (tensor.defined()) {
     // Two situations to consider:
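THPVariable_clear implements the type's tp_clear slot, so the new Py_CLEAR lets CPython's cycle collector drop the hooks object when a tensor participates in a reference cycle, e.g. a hook closure that captures its own tensor. A small sketch of the behavior this enables, printing rather than asserting since tensor resurrection corner cases are handled separately in this function:

import gc
import weakref
import torch

def make_cycle():
    t = torch.randn(2, requires_grad=True)
    # The closure captures t, so t -> hooks slot -> closure -> t is a cycle.
    t.register_post_accumulate_grad_hook(lambda p: print(t.shape))
    return weakref.ref(t)

ref = make_cycle()
gc.collect()          # the collector can now break the cycle via tp_clear
print(ref() is None)  # expected True: the tensor was reclaimed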
@@ -1162,6 +1164,47 @@ int THPVariable_set_backwards_hooks(
   END_HANDLE_TH_ERRORS_RET(-1)
 }
 
+PyObject* THPVariable_get_post_accumulate_grad_hooks(
+    THPVariable* self,
+    void* unused) {
+  HANDLE_TH_ERRORS
+  if (check_has_torch_function((PyObject*)self)) {
+    return handle_torch_function_getter(self, "_post_accumulate_grad_hooks");
+  }
+  if (self->post_accumulate_grad_hooks) {
+    Py_INCREF(self->post_accumulate_grad_hooks);
+    return self->post_accumulate_grad_hooks;
+  }
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
+int THPVariable_set_post_accumulate_grad_hooks(
+    THPVariable* self,
+    PyObject* obj,
+    void* unused) {
+  HANDLE_TH_ERRORS
+  if (check_has_torch_function((PyObject*)self)) {
+    return handle_torch_function_setter(
+        self, "_post_accumulate_grad_hooks", obj);
+  }
+  THPUtils_assertRet(
+      -1, obj, "Deletion of _post_accumulate_grad_hooks not allowed!");
+  if (obj == Py_None) {
+    obj = nullptr;
+  }
+  Py_XINCREF(obj);
+  Py_CLEAR(self->post_accumulate_grad_hooks);
+  self->post_accumulate_grad_hooks = obj;
+  const auto& tensor = THPVariable_Unpack(self);
+  if (obj) {
+    torch::autograd::impl::set_post_acc_grad_hooks(
+        tensor, std::make_unique<PyFunctionTensorPostAccGradHooks>(obj));
+  }
+  return 0;
+  END_HANDLE_TH_ERRORS_RET(-1)
+}
+
 PyObject* THPVariable_get_base(THPVariable* self, void* unused) {
   HANDLE_TH_ERRORS
   if (check_has_torch_function((PyObject*)self)) {
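Together these implement the _post_accumulate_grad_hooks property: the getter hands back the stored object (or None), while the setter dispatches to __torch_function__ overrides first, rejects outright deletion, treats None as "clear", and mirrors any real object into the C++ autograd layer via set_post_acc_grad_hooks. A sketch of the resulting Python-visible behavior; the exact container type that register_post_accumulate_grad_hook stores here is an implementation detail, a dict-like object is assumed:

import torch

t = torch.randn(2, requires_grad=True)
print(t._post_accumulate_grad_hooks)   # None: the getter returns Py_None when unset

t.register_post_accumulate_grad_hook(lambda p: None)
print(t._post_accumulate_grad_hooks)   # the stored hooks container (dict-like, assumed)

t._post_accumulate_grad_hooks = None   # allowed: the setter maps None to nullptr
# del t._post_accumulate_grad_hooks    # would raise: deletion is rejected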
@@ -1476,6 +1519,11 @@ static struct PyGetSetDef THPVariable_properties[] = {
      (setter)THPVariable_set_backwards_hooks,
      nullptr,
      nullptr},
+    {"_post_accumulate_grad_hooks",
+     (getter)THPVariable_get_post_accumulate_grad_hooks,
+     (setter)THPVariable_set_post_accumulate_grad_hooks,
+     nullptr,
+     nullptr},
     {"name", (getter)THPVariable_get_name, nullptr, nullptr, nullptr},
     {"shape", (getter)THPVariable_get_shape, nullptr, nullptr, nullptr},
     {"is_cuda", (getter)THPVariable_is_cuda, nullptr, nullptr, nullptr},
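The new PyGetSetDef entry installs that getter/setter pair as a C-level descriptor on the tensor type, the CPython analogue of a Python property. A rough pure-Python model of the entry's behavior (names are illustrative, not PyTorch's):

class FakeTensor:
    """Rough pure-Python model of the new PyGetSetDef entry (illustrative)."""

    def __init__(self):
        self._slot = None  # stands in for self->post_accumulate_grad_hooks

    @property
    def _post_accumulate_grad_hooks(self):
        return self._slot  # getter: stored object, or None when unset

    @_post_accumulate_grad_hooks.setter
    def _post_accumulate_grad_hooks(self, obj):
        self._slot = obj  # setter: None clears, any other object is stored

    # No deleter is defined, so `del ft._post_accumulate_grad_hooks` raises,
    # mirroring the C setter's rejection of deletion.

ft = FakeTensor()
ft._post_accumulate_grad_hooks = {0: (lambda p: None)}
print(ft._post_accumulate_grad_hooks)  # {0: <function ...>}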
@@ -2033,6 +2081,7 @@ static int THPVariable_subclass_traverse(
 
   // Finally traverse THPVariable special stuff
   Py_VISIT(var->backward_hooks);
+  Py_VISIT(var->post_accumulate_grad_hooks);
   if (!var->cdata.unsafeIsBorrowed()) {
     const auto& tensor = THPVariable_Unpack(var);
     if (tensor.defined()) {
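THPVariable_subclass_traverse implements tp_traverse, and the added Py_VISIT reports the new slot to the garbage collector, which is what makes cycles through the hooks discoverable in the first place (the tp_clear change above only breaks cycles the collector has found). One hedged way to observe this from Python, since gc.get_referents is answered by tp_traverse:

import gc
import torch

t = torch.randn(2, requires_grad=True)
t.register_post_accumulate_grad_hook(lambda p: None)

hooks = t._post_accumulate_grad_hooks
# gc.get_referents invokes tp_traverse, so the Py_VISIT added here should
# surface the hooks container among the tensor's referents.
print(any(r is hooks for r in gc.get_referents(t)))  # expected True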