Expose arbitrary cpp autograd functions to Python (#11082)

Summary:
This is needed because the JIT declares some custom autograd functions.

colesbury
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11082

Differential Revision: D9580456

Pulled By: apaszke

fbshipit-source-id: 6bf00c1188a20b2ee6ecf60e5a0099f8263ad55a
This commit is contained in:
Adam Paszke
2018-08-30 14:15:59 -07:00
committed by Facebook Github Bot
parent 93bd291e55
commit f0142faab0
3 changed files with 27 additions and 5 deletions

View File

@ -11,6 +11,7 @@
#include "torch/csrc/autograd/python_hook.h"
#include "torch/csrc/autograd/python_anomaly_mode.h"
#include "torch/csrc/utils/auto_gil.h"
#include "torch/csrc/utils/python_strings.h"
#include "torch/csrc/DynamicTypes.h"
#include "torch/csrc/Exceptions.h"
@ -152,6 +153,10 @@ PyObject* THPCppFunction_register_hook(PyObject* self, PyObject* hook)
return registerFunctionHook(fn, hook);
}
PyObject* THPCppFunction_name(PyObject* self) {
auto& fn = *((THPCppFunction*)self)->cdata;
return THPUtils_packString(fn.name());
}
static struct PyMethodDef default_methods[] = {
THP_FUNCTION_DEFAULT_METHODS,
@ -184,8 +189,19 @@ PyTypeObject* _initFunctionPyTypeObject(PyTypeObject& type, const char* name,
static std::unordered_map<std::type_index, THPObjectPtr> cpp_function_types;
struct DefaultFunctionType {
DefaultFunctionType() {
_initFunctionPyTypeObject(type, "CppFunction", nullptr, nullptr);
Py_INCREF(&type);
}
PyTypeObject type;
};
PyObject* functionToPyObject(std::shared_ptr<Function> cdata)
{
static DefaultFunctionType default_type;
if (!cdata) {
Py_RETURN_NONE;
}
@ -201,12 +217,13 @@ PyObject* functionToPyObject(std::shared_ptr<Function> cdata)
} else {
auto& fn = *cdata;
auto it = cpp_function_types.find(std::type_index(typeid(fn)));
PyTypeObject* type;
if (it == cpp_function_types.end()) {
return PyErr_Format(PyExc_TypeError,
"Don't know how to create Python object for %s", typeid(fn).name());
type = &default_type.type;
} else {
type = (PyTypeObject*)it->second.get();
}
PyTypeObject* type = (PyTypeObject*)it->second.get();
THPObjectPtr obj(type->tp_alloc(type, 0));
if (!obj) return nullptr;
THPCppFunction* f = (THPCppFunction*)obj.get();