Reduce hook registration code duplication (#91418)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91418
Approved by: https://github.com/albanD
This commit is contained in:
soulitzer
2022-12-27 17:35:51 -05:00
committed by PyTorch MergeBot
parent 8191c49f82
commit ae52750d91

View File

@ -261,6 +261,30 @@ void registerCppFunction(const std::type_info& type, PyTypeObject* pytype) {
cpp_function_types_set.insert(pytype);
}
bool THPCppFunction_Check(PyObject* obj) {
THPObjectPtr type = THPObjectPtr(PyObject_Type(obj));
if (cpp_function_types_set.find((PyTypeObject*)type.get()) ==
cpp_function_types_set.end()) {
return false;
} else {
return true;
}
}
PyObject* callRegisterFn(PyObject* dict, PyObject* hook) {
THPObjectPtr register_fn(
PyObject_GetAttrString(THPFunctionClass, "_register_hook"));
if (!register_fn) {
return nullptr;
}
THPObjectPtr res(
PyObject_CallFunctionObjArgs(register_fn.get(), dict, hook, nullptr));
if (!res) {
return nullptr;
}
return res.release();
}
PyObject* registerFunctionHook(Node& fn, PyObject* hook) {
PyObject* dict = Py_None;
for (const auto& hook : fn.post_hooks()) {
@ -269,16 +293,10 @@ PyObject* registerFunctionHook(Node& fn, PyObject* hook) {
break;
}
}
THPObjectPtr register_fn(
PyObject_GetAttrString(THPFunctionClass, "_register_hook"));
if (!register_fn)
THPObjectPtr res{callRegisterFn(dict, hook)};
if (!res) {
return nullptr;
THPObjectPtr res(
PyObject_CallFunctionObjArgs(register_fn.get(), dict, hook, nullptr));
if (!res)
return nullptr;
}
if (dict == Py_None) {
dict = PyTuple_GET_ITEM(res.get(), 0);
std::unique_ptr<FunctionPostHook> hook(new PyFunctionPostHook(dict));
@ -299,16 +317,10 @@ PyObject* registerFunctionPreHook(Node& fn, PyObject* hook) {
break;
}
}
THPObjectPtr register_fn(
PyObject_GetAttrString(THPFunctionClass, "_register_hook"));
if (!register_fn)
THPObjectPtr res{callRegisterFn(dict, hook)};
if (!res) {
return nullptr;
THPObjectPtr res(
PyObject_CallFunctionObjArgs(register_fn.get(), dict, hook, nullptr));
if (!res)
return nullptr;
}
if (dict == Py_None) {
dict = PyTuple_GET_ITEM(res.get(), 0);
std::unique_ptr<FunctionPreHook> hook(new PyFunctionPreHook(dict));
@ -320,15 +332,5 @@ PyObject* registerFunctionPreHook(Node& fn, PyObject* hook) {
return handle;
}
bool THPCppFunction_Check(PyObject* obj) {
THPObjectPtr type = THPObjectPtr(PyObject_Type(obj));
if (cpp_function_types_set.find((PyTypeObject*)type.get()) ==
cpp_function_types_set.end()) {
return false;
} else {
return true;
}
}
} // namespace autograd
} // namespace torch