mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-22 22:25:10 +08:00
Add ability to register prehooks to grad_fn (#83226)
This simply replicates the implementation of PyFunctionPostHooks Fixes https://github.com/pytorch/pytorch/issues/83120 Pull Request resolved: https://github.com/pytorch/pytorch/pull/83226 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
02cfefb48c
commit
b567742038
@ -161,6 +161,11 @@ PyObject* THPCppFunction_register_hook(PyObject* self, PyObject* hook) {
|
||||
return registerFunctionHook(fn, hook);
|
||||
}
|
||||
|
||||
PyObject* THPCppFunction_register_prehook(PyObject* self, PyObject* hook) {
|
||||
auto& fn = *((THPCppFunction*)self)->cdata;
|
||||
return registerFunctionPreHook(fn, hook);
|
||||
}
|
||||
|
||||
PyObject* THPCppFunction_name(PyObject* self, PyObject* noargs) {
|
||||
auto& fn = *((THPCppFunction*)self)->cdata;
|
||||
return THPUtils_packString(fn.name());
|
||||
@ -282,5 +287,35 @@ PyObject* registerFunctionHook(Node& fn, PyObject* hook) {
|
||||
return handle;
|
||||
}
|
||||
|
||||
// This is almost a copy of the function above except post -> pre
|
||||
PyObject* registerFunctionPreHook(Node& fn, PyObject* hook) {
|
||||
PyObject* dict = Py_None;
|
||||
for (const auto& hook : fn.pre_hooks()) {
|
||||
if (auto pyhook = dynamic_cast<PyFunctionPreHook*>(hook.get())) {
|
||||
dict = pyhook->dict;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
THPObjectPtr register_fn(
|
||||
PyObject_GetAttrString(THPFunctionClass, "_register_hook"));
|
||||
if (!register_fn)
|
||||
return nullptr;
|
||||
THPObjectPtr res(
|
||||
PyObject_CallFunctionObjArgs(register_fn.get(), dict, hook, nullptr));
|
||||
if (!res)
|
||||
return nullptr;
|
||||
|
||||
if (dict == Py_None) {
|
||||
dict = PyTuple_GET_ITEM(res.get(), 0);
|
||||
std::unique_ptr<FunctionPreHook> hook(new PyFunctionPreHook(dict));
|
||||
fn.add_pre_hook(std::move(hook));
|
||||
}
|
||||
|
||||
PyObject* handle = PyTuple_GET_ITEM(res.get(), 1);
|
||||
Py_INCREF(handle);
|
||||
return handle;
|
||||
}
|
||||
|
||||
} // namespace autograd
|
||||
} // namespace torch
|
||||
|
Reference in New Issue
Block a user