mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Reduce hook registration code duplication (#91418)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91418 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
8191c49f82
commit
ae52750d91
@ -261,6 +261,30 @@ void registerCppFunction(const std::type_info& type, PyTypeObject* pytype) {
|
||||
cpp_function_types_set.insert(pytype);
|
||||
}
|
||||
|
||||
bool THPCppFunction_Check(PyObject* obj) {
|
||||
THPObjectPtr type = THPObjectPtr(PyObject_Type(obj));
|
||||
if (cpp_function_types_set.find((PyTypeObject*)type.get()) ==
|
||||
cpp_function_types_set.end()) {
|
||||
return false;
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
PyObject* callRegisterFn(PyObject* dict, PyObject* hook) {
|
||||
THPObjectPtr register_fn(
|
||||
PyObject_GetAttrString(THPFunctionClass, "_register_hook"));
|
||||
if (!register_fn) {
|
||||
return nullptr;
|
||||
}
|
||||
THPObjectPtr res(
|
||||
PyObject_CallFunctionObjArgs(register_fn.get(), dict, hook, nullptr));
|
||||
if (!res) {
|
||||
return nullptr;
|
||||
}
|
||||
return res.release();
|
||||
}
|
||||
|
||||
PyObject* registerFunctionHook(Node& fn, PyObject* hook) {
|
||||
PyObject* dict = Py_None;
|
||||
for (const auto& hook : fn.post_hooks()) {
|
||||
@ -269,16 +293,10 @@ PyObject* registerFunctionHook(Node& fn, PyObject* hook) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
THPObjectPtr register_fn(
|
||||
PyObject_GetAttrString(THPFunctionClass, "_register_hook"));
|
||||
if (!register_fn)
|
||||
THPObjectPtr res{callRegisterFn(dict, hook)};
|
||||
if (!res) {
|
||||
return nullptr;
|
||||
THPObjectPtr res(
|
||||
PyObject_CallFunctionObjArgs(register_fn.get(), dict, hook, nullptr));
|
||||
if (!res)
|
||||
return nullptr;
|
||||
|
||||
}
|
||||
if (dict == Py_None) {
|
||||
dict = PyTuple_GET_ITEM(res.get(), 0);
|
||||
std::unique_ptr<FunctionPostHook> hook(new PyFunctionPostHook(dict));
|
||||
@ -299,16 +317,10 @@ PyObject* registerFunctionPreHook(Node& fn, PyObject* hook) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
THPObjectPtr register_fn(
|
||||
PyObject_GetAttrString(THPFunctionClass, "_register_hook"));
|
||||
if (!register_fn)
|
||||
THPObjectPtr res{callRegisterFn(dict, hook)};
|
||||
if (!res) {
|
||||
return nullptr;
|
||||
THPObjectPtr res(
|
||||
PyObject_CallFunctionObjArgs(register_fn.get(), dict, hook, nullptr));
|
||||
if (!res)
|
||||
return nullptr;
|
||||
|
||||
}
|
||||
if (dict == Py_None) {
|
||||
dict = PyTuple_GET_ITEM(res.get(), 0);
|
||||
std::unique_ptr<FunctionPreHook> hook(new PyFunctionPreHook(dict));
|
||||
@ -320,15 +332,5 @@ PyObject* registerFunctionPreHook(Node& fn, PyObject* hook) {
|
||||
return handle;
|
||||
}
|
||||
|
||||
bool THPCppFunction_Check(PyObject* obj) {
|
||||
THPObjectPtr type = THPObjectPtr(PyObject_Type(obj));
|
||||
if (cpp_function_types_set.find((PyTypeObject*)type.get()) ==
|
||||
cpp_function_types_set.end()) {
|
||||
return false;
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace autograd
|
||||
} // namespace torch
|
||||
|
Reference in New Issue
Block a user