Add ability to register prehooks to grad_fn (#83226)

This simply replicates the implementation of PyFunctionPostHooks

Fixes https://github.com/pytorch/pytorch/issues/83120
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83226
Approved by: https://github.com/albanD
This commit is contained in:
soulitzer
2022-08-12 17:07:19 -04:00
committed by PyTorch MergeBot
parent 02cfefb48c
commit b567742038
6 changed files with 267 additions and 0 deletions

View File

@ -161,6 +161,11 @@ PyObject* THPCppFunction_register_hook(PyObject* self, PyObject* hook) {
return registerFunctionHook(fn, hook);
}
PyObject* THPCppFunction_register_prehook(PyObject* self, PyObject* hook) {
auto& fn = *((THPCppFunction*)self)->cdata;
return registerFunctionPreHook(fn, hook);
}
PyObject* THPCppFunction_name(PyObject* self, PyObject* noargs) {
auto& fn = *((THPCppFunction*)self)->cdata;
return THPUtils_packString(fn.name());
@ -282,5 +287,35 @@ PyObject* registerFunctionHook(Node& fn, PyObject* hook) {
return handle;
}
// This is almost a copy of the function above except post -> pre
PyObject* registerFunctionPreHook(Node& fn, PyObject* hook) {
PyObject* dict = Py_None;
for (const auto& hook : fn.pre_hooks()) {
if (auto pyhook = dynamic_cast<PyFunctionPreHook*>(hook.get())) {
dict = pyhook->dict;
break;
}
}
THPObjectPtr register_fn(
PyObject_GetAttrString(THPFunctionClass, "_register_hook"));
if (!register_fn)
return nullptr;
THPObjectPtr res(
PyObject_CallFunctionObjArgs(register_fn.get(), dict, hook, nullptr));
if (!res)
return nullptr;
if (dict == Py_None) {
dict = PyTuple_GET_ITEM(res.get(), 0);
std::unique_ptr<FunctionPreHook> hook(new PyFunctionPreHook(dict));
fn.add_pre_hook(std::move(hook));
}
PyObject* handle = PyTuple_GET_ITEM(res.get(), 1);
Py_INCREF(handle);
return handle;
}
} // namespace autograd
} // namespace torch