mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-07 10:01:39 +08:00
This is a "re-implementation" of compiled autograd. The idea is that: - we leverage the existing autograd graph to construct a Python function that is able to run the autograd graph - then, we run torch.compile over this function This resolves some of the issues we have with the existing compiled autograd. - We're able to graph break in unsupported C++ autograd nodes - The existing compiled autograd uses make_fx to construct the autograd graph before applying torch.compile over that autograd graph. This requires unsound assumptions about input strides and Tensor subclasses. By replicated what PyTorch autograd does in Python, this POC does not have this problem. More on the motivation over at https://docs.google.com/document/d/11KZw4MGoZOLDWQbv6NWxscNUC7lu97M4IVMqfcbkdqA/edit
114 lines
5.0 KiB
C++
114 lines
5.0 KiB
C++
#pragma once
|
|
|
|
#include <torch/csrc/python_headers.h>
|
|
#include <memory>
|
|
#include <typeinfo>
|
|
|
|
#include <torch/csrc/Exceptions.h>
|
|
#include <torch/csrc/autograd/function.h>
|
|
#include <torch/csrc/utils/object_ptr.h>
|
|
|
|
namespace torch::autograd {
|
|
|
|
struct THPCppFunction {
|
|
PyObject_HEAD
|
|
std::shared_ptr<Node> cdata;
|
|
};
|
|
|
|
template <typename Ctor>
|
|
PyObject* CppFunction_pynew(
|
|
PyTypeObject* type,
|
|
PyObject* args,
|
|
PyObject* kwds) {
|
|
THPObjectPtr obj(type->tp_alloc(type, 0));
|
|
if (!obj)
|
|
return nullptr;
|
|
THPCppFunction* f = (THPCppFunction*)obj.get();
|
|
HANDLE_TH_ERRORS
|
|
new (&f->cdata) std::shared_ptr<Node>(Ctor()(args));
|
|
END_HANDLE_TH_ERRORS
|
|
if (!f->cdata) {
|
|
return nullptr;
|
|
}
|
|
return obj.release();
|
|
}
|
|
|
|
#define THP_FUNCTION_DEFAULT_METHODS \
|
|
{(char*)"_register_hook_dict", \
|
|
THPCppFunction_register_hook_dict, \
|
|
METH_O, \
|
|
nullptr}, \
|
|
{(char*)"pre_hooks", THPCppFunction_pre_hooks, METH_NOARGS, nullptr}, \
|
|
{(char*)"post_hooks", THPCppFunction_post_hooks, METH_NOARGS, nullptr}, \
|
|
{(char*)"is_traceable", THPCppFunction_is_traceable, METH_NOARGS, nullptr}, \
|
|
{(char*)"register_hook", THPCppFunction_register_hook, METH_O, nullptr}, \
|
|
{(char*)"register_prehook", \
|
|
THPCppFunction_register_prehook, \
|
|
METH_O, \
|
|
nullptr}, \
|
|
{(char*)"name", THPCppFunction_name, METH_NOARGS, nullptr}, \
|
|
{(char*)"_sequence_nr", \
|
|
THPCppFunction_sequence_nr, \
|
|
METH_NOARGS, \
|
|
nullptr}, \
|
|
{ \
|
|
(char*)"_set_sequence_nr", THPCppFunction_set_sequence_nr, METH_O, nullptr \
|
|
}
|
|
|
|
#define THP_FUNCTION_DEFAULT_PROPERTIES \
|
|
{(char*)"next_functions", \
|
|
THPCppFunction_next_functions, \
|
|
nullptr, \
|
|
nullptr, \
|
|
nullptr}, \
|
|
{(char*)"requires_grad", \
|
|
THPCppFunction_requires_grad, \
|
|
nullptr, \
|
|
nullptr, \
|
|
nullptr}, \
|
|
{(char*)"metadata", THPCppFunction_metadata, nullptr, nullptr, nullptr}, \
|
|
{ \
|
|
(char*)"_input_metadata", THPCppFunction_input_metadata, nullptr, nullptr, \
|
|
nullptr \
|
|
}
|
|
|
|
PyObject* THPCppFunction_next_functions(PyObject* self, void* _unused);
|
|
PyObject* THPCppFunction_metadata(PyObject* self, void* _unused);
|
|
PyObject* THPCppFunction_requires_grad(PyObject* self, void* _unused);
|
|
PyObject* THPCppFunction_register_hook_dict(PyObject* self, PyObject* _var);
|
|
PyObject* THPCppFunction_register_hook(PyObject* self, PyObject* hook);
|
|
PyObject* THPCppFunction_register_prehook(PyObject* self, PyObject* hook);
|
|
PyObject* THPCppFunction_get_prehooks(PyObject* self, PyObject* noargs);
|
|
|
|
PyObject* THPCppFunction_name(PyObject* self, PyObject* noargs);
|
|
PyObject* THPCppFunction_sequence_nr(PyObject* self, PyObject* noargs);
|
|
PyObject* THPCppFunction_input_metadata(PyObject* self, void* _unused);
|
|
|
|
PyTypeObject* _initFunctionPyTypeObject(
|
|
PyTypeObject& type,
|
|
const char* name,
|
|
PyGetSetDef* function_properties,
|
|
PyMethodDef* function_methods);
|
|
|
|
PyObject* registerFunctionHook(Node& fn, PyObject* hook);
|
|
|
|
PyObject* registerFunctionPreHook(Node& fn, PyObject* hook);
|
|
|
|
template <typename Ctor>
|
|
PyTypeObject* createForwardFunctionPyTypeObject(
|
|
PyTypeObject& type,
|
|
const char* name,
|
|
PyGetSetDef* function_properties = nullptr,
|
|
PyMethodDef* function_methods = nullptr) {
|
|
type.tp_new = &CppFunction_pynew<Ctor>;
|
|
return _initFunctionPyTypeObject(
|
|
type, name, function_properties, function_methods);
|
|
}
|
|
|
|
void registerCppFunction(const std::type_info& type, PyTypeObject* pytype);
|
|
PyObject* functionToPyObject(const std::shared_ptr<Node>& cdata);
|
|
|
|
bool THPCppFunction_Check(PyObject* obj);
|
|
|
|
} // namespace torch::autograd
|