Fix cpp node instance check (#125875)

Mostly visible when calling multi_grad_hook and thus using this to test it.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125875
Approved by: https://github.com/jackiexu1992, https://github.com/ezyang
This commit is contained in:
albanD
2024-05-11 21:31:10 +00:00
committed by PyTorch MergeBot
parent 07d6ab5aa2
commit 6ffc94fa62
2 changed files with 65 additions and 7 deletions

View File

@ -20,8 +20,7 @@
using namespace torch::autograd;
namespace torch {
namespace autograd {
namespace torch::autograd {
namespace {
@ -227,6 +226,7 @@ PyTypeObject* _initFunctionPyTypeObject(
const char* name,
PyGetSetDef* function_properties,
PyMethodDef* function_methods) {
type.ob_base = {PyObject_HEAD_INIT(nullptr) 0};
// NOLINTNEXTLINE(misc-redundant-expression)
type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC;
type.tp_name = name;
@ -251,15 +251,17 @@ static std::unordered_set<PyTypeObject*> cpp_function_types_set;
struct DefaultFunctionType {
DefaultFunctionType() : type() {
_initFunctionPyTypeObject(type, "CppFunction", nullptr, nullptr);
Py_INCREF(&type);
}
PyTypeObject type;
};
PyObject* functionToPyObject(const std::shared_ptr<Node>& cdata) {
PyTypeObject* get_default_type() {
static DefaultFunctionType default_type;
return &(default_type.type);
}
PyObject* functionToPyObject(const std::shared_ptr<Node>& cdata) {
if (!cdata) {
Py_RETURN_NONE;
}
@ -278,7 +280,7 @@ PyObject* functionToPyObject(const std::shared_ptr<Node>& cdata) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
PyTypeObject* type;
if (it == cpp_function_types_map.end()) {
type = &default_type.type;
type = get_default_type();
} else {
type = (PyTypeObject*)it->second.get();
}
@ -305,6 +307,9 @@ void registerCppFunction(const std::type_info& type, PyTypeObject* pytype) {
bool THPCppFunction_Check(PyObject* obj) {
THPObjectPtr type = THPObjectPtr(PyObject_Type(obj));
if ((PyTypeObject*)type.get() == get_default_type()) {
return true;
}
if (cpp_function_types_set.find((PyTypeObject*)type.get()) ==
cpp_function_types_set.end()) {
return false;
@ -374,5 +379,4 @@ PyObject* registerFunctionPreHook(Node& fn, PyObject* hook) {
return handle;
}
} // namespace autograd
} // namespace torch
} // namespace torch::autograd