mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Fix cpp node instance check (#125875)
Mostly visible when calling multi_grad_hook and thus using this to test it. Pull Request resolved: https://github.com/pytorch/pytorch/pull/125875 Approved by: https://github.com/jackiexu1992, https://github.com/ezyang
This commit is contained in:
@ -20,8 +20,7 @@
|
||||
|
||||
using namespace torch::autograd;
|
||||
|
||||
namespace torch {
|
||||
namespace autograd {
|
||||
namespace torch::autograd {
|
||||
|
||||
namespace {
|
||||
|
||||
@ -227,6 +226,7 @@ PyTypeObject* _initFunctionPyTypeObject(
|
||||
const char* name,
|
||||
PyGetSetDef* function_properties,
|
||||
PyMethodDef* function_methods) {
|
||||
type.ob_base = {PyObject_HEAD_INIT(nullptr) 0};
|
||||
// NOLINTNEXTLINE(misc-redundant-expression)
|
||||
type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC;
|
||||
type.tp_name = name;
|
||||
@ -251,15 +251,17 @@ static std::unordered_set<PyTypeObject*> cpp_function_types_set;
|
||||
struct DefaultFunctionType {
|
||||
DefaultFunctionType() : type() {
|
||||
_initFunctionPyTypeObject(type, "CppFunction", nullptr, nullptr);
|
||||
Py_INCREF(&type);
|
||||
}
|
||||
|
||||
PyTypeObject type;
|
||||
};
|
||||
|
||||
PyObject* functionToPyObject(const std::shared_ptr<Node>& cdata) {
|
||||
PyTypeObject* get_default_type() {
|
||||
static DefaultFunctionType default_type;
|
||||
return &(default_type.type);
|
||||
}
|
||||
|
||||
PyObject* functionToPyObject(const std::shared_ptr<Node>& cdata) {
|
||||
if (!cdata) {
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
@ -278,7 +280,7 @@ PyObject* functionToPyObject(const std::shared_ptr<Node>& cdata) {
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
||||
PyTypeObject* type;
|
||||
if (it == cpp_function_types_map.end()) {
|
||||
type = &default_type.type;
|
||||
type = get_default_type();
|
||||
} else {
|
||||
type = (PyTypeObject*)it->second.get();
|
||||
}
|
||||
@ -305,6 +307,9 @@ void registerCppFunction(const std::type_info& type, PyTypeObject* pytype) {
|
||||
|
||||
bool THPCppFunction_Check(PyObject* obj) {
|
||||
THPObjectPtr type = THPObjectPtr(PyObject_Type(obj));
|
||||
if ((PyTypeObject*)type.get() == get_default_type()) {
|
||||
return true;
|
||||
}
|
||||
if (cpp_function_types_set.find((PyTypeObject*)type.get()) ==
|
||||
cpp_function_types_set.end()) {
|
||||
return false;
|
||||
@ -374,5 +379,4 @@ PyObject* registerFunctionPreHook(Node& fn, PyObject* hook) {
|
||||
return handle;
|
||||
}
|
||||
|
||||
} // namespace autograd
|
||||
} // namespace torch
|
||||
} // namespace torch::autograd
|
||||
|
Reference in New Issue
Block a user