mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Expose arbitrary cpp autograd functions to Python (#11082)
Summary: This is needed because the JIT declares some custom autograd functions. colesbury Pull Request resolved: https://github.com/pytorch/pytorch/pull/11082 Differential Revision: D9580456 Pulled By: apaszke fbshipit-source-id: 6bf00c1188a20b2ee6ecf60e5a0099f8263ad55a
This commit is contained in:
committed by
Facebook Github Bot
parent
93bd291e55
commit
f0142faab0
@ -11,6 +11,7 @@
|
||||
#include "torch/csrc/autograd/python_hook.h"
|
||||
#include "torch/csrc/autograd/python_anomaly_mode.h"
|
||||
#include "torch/csrc/utils/auto_gil.h"
|
||||
#include "torch/csrc/utils/python_strings.h"
|
||||
#include "torch/csrc/DynamicTypes.h"
|
||||
#include "torch/csrc/Exceptions.h"
|
||||
|
||||
@ -152,6 +153,10 @@ PyObject* THPCppFunction_register_hook(PyObject* self, PyObject* hook)
|
||||
return registerFunctionHook(fn, hook);
|
||||
}
|
||||
|
||||
PyObject* THPCppFunction_name(PyObject* self) {
|
||||
auto& fn = *((THPCppFunction*)self)->cdata;
|
||||
return THPUtils_packString(fn.name());
|
||||
}
|
||||
|
||||
static struct PyMethodDef default_methods[] = {
|
||||
THP_FUNCTION_DEFAULT_METHODS,
|
||||
@ -184,8 +189,19 @@ PyTypeObject* _initFunctionPyTypeObject(PyTypeObject& type, const char* name,
|
||||
|
||||
static std::unordered_map<std::type_index, THPObjectPtr> cpp_function_types;
|
||||
|
||||
struct DefaultFunctionType {
|
||||
DefaultFunctionType() {
|
||||
_initFunctionPyTypeObject(type, "CppFunction", nullptr, nullptr);
|
||||
Py_INCREF(&type);
|
||||
}
|
||||
|
||||
PyTypeObject type;
|
||||
};
|
||||
|
||||
PyObject* functionToPyObject(std::shared_ptr<Function> cdata)
|
||||
{
|
||||
static DefaultFunctionType default_type;
|
||||
|
||||
if (!cdata) {
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
@ -201,12 +217,13 @@ PyObject* functionToPyObject(std::shared_ptr<Function> cdata)
|
||||
} else {
|
||||
auto& fn = *cdata;
|
||||
auto it = cpp_function_types.find(std::type_index(typeid(fn)));
|
||||
PyTypeObject* type;
|
||||
if (it == cpp_function_types.end()) {
|
||||
return PyErr_Format(PyExc_TypeError,
|
||||
"Don't know how to create Python object for %s", typeid(fn).name());
|
||||
type = &default_type.type;
|
||||
} else {
|
||||
type = (PyTypeObject*)it->second.get();
|
||||
}
|
||||
|
||||
PyTypeObject* type = (PyTypeObject*)it->second.get();
|
||||
THPObjectPtr obj(type->tp_alloc(type, 0));
|
||||
if (!obj) return nullptr;
|
||||
THPCppFunction* f = (THPCppFunction*)obj.get();
|
||||
|
Reference in New Issue
Block a user