Mirror of https://github.com/pytorch/pytorch.git
Previously, our AST was a DAG, where shared Nodes indicated that a computation should be reused. This commit rewrites the IR into a new functional representation which represents sharing explicitly using variable bindings. We offer a few justifications for this new style:

1. The new representation is not all that different from the old one; it is about as easy to construct, and the lack of an explicit graph doesn't negatively impact our ability to interpret the graph, since we've chosen, as a matter of design, NOT to have the IR participate in the actual execution of a graph.

2. The new let-binding representation has an implicit ordering, which we can use to conveniently keep track of the original order in which the trace showed up. This automatically gives us a topsort, and gives us an easier-to-read textual representation of our IR:

       %14 = Embedding %11, %0, -1, None, 2, False, False
       %15 = Dropout %14, 0.2, True, False
       %16 = Index %12, 0
       %17 = Index %12, 1
       %18 = Index %13, 0
       %19 = Index %13, 1
       %20 = Index %15, 0
       %21 = Linear %20, %1, %3
       %22 = Linear %16, %2, %4

3. It moves us closer to a Futhark-style language (http://futhark-lang.org/publications/pldi17.pdf).

Major aspects of the diff:

- Node is replaced with Expr and Arg, a pair of mutually recursive structures which represent our new language. In BNF, the language looks like this:

      a ::= c | %i
      e ::= %i, ... = e | PyOp e, ... | Ret %i, ...

  Technically, Ret is not actually a return (no control flow is involved); it just tuples up a series of tensors (identified by variables). One important invariant is that locals are always tensors; they are never constants (this is asymmetric with Args). A simplified sketch of what this structure might look like is given after this list.

- Arguments support Python constants. This is an important piece because many operators take extra Python literals like integers and tuples in order to specify extra parameters about how an operator operates. Adding this was essential to getting word_language_model to work.

- As both Expr and Arg have multiple variants, there is new infrastructure for doing case analysis on the variants using ExprVisitor and ArgVisitor. The strategy here is adapted from WebAssembly's visitors, although we have generalized it to permit arbitrary argument forwarding, which is necessary to support tail-recursive visitor calls. TCO is important because our interpreter may recurse arbitrarily deep into a stack of nested lets. If users wish, they can also manually case on the type tag.

- Tracing is now turned on and off using _tracer_enter/_tracer_exit in torch._C. _tracer_enter accepts a list of variables which are to be treated as arguments; _tracer_exit accepts the list of traced variables which should be returned when you reexecute the trace, and returns the trace expression which can be reexecuted. GlobalTracingState is a global variable which tracks whether or not we are tracing.

- You use run_forward to execute a trace on some set of parameters.

- When under tracing, variables keep track, via trace_local, of the names of their corresponding variables in the IR.
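To make the shape of Expr and Arg concrete, here is a rough Python sketch of the language above. It is purely illustrative: the class names (Let, PyOp, Ret, Local, Constant) and the explicit body field on the binding form are one possible reading of the BNF and of the nested-let description, not the actual C++ definitions.

    from dataclasses import dataclass
    from typing import List, Union

    # a ::= c | %i  -- an Arg is either a Python constant or a local (%i)
    @dataclass
    class Local:
        """%i: a temporary; locals are always tensors, never constants."""
        index: int

    @dataclass
    class Constant:
        """c: a Python literal argument (int, tuple, None, ...)."""
        value: object

    Arg = Union[Local, Constant]

    # e ::= %i, ... = e | PyOp e, ... | Ret %i, ...
    @dataclass
    class PyOp:
        """Apply a Python operator to some Args."""
        name: str
        args: List[Arg]

    @dataclass
    class Ret:
        """Tuple up a series of locals; not real control flow."""
        outputs: List[Local]

    @dataclass
    class Let:
        """%i, ... = rhs; then continue in `body` (the next binding or a Ret)."""
        targets: List[Local]
        rhs: 'Expr'
        body: 'Expr'

    Expr = Union[Let, PyOp, Ret]

    def render(e: Expr) -> str:
        # Walk the chain of nested lets iteratively; the C++ ExprVisitor relies
        # on argument forwarding / tail calls for the same reason: the chain
        # can be arbitrarily deep. This simple printer only handles PyOp
        # right-hand sides.
        lines = []
        while isinstance(e, Let) and isinstance(e.rhs, PyOp):
            lhs = ", ".join(f"%{t.index}" for t in e.targets)
            rhs = ", ".join(f"%{a.index}" if isinstance(a, Local) else repr(a.value)
                            for a in e.rhs.args)
            lines.append(f"{lhs} = {e.rhs.name} {rhs}")
            e = e.body
        if isinstance(e, Ret):
            lines.append("Ret " + ", ".join(f"%{o.index}" for o in e.outputs))
        return "\n".join(lines)

    # The first two bindings of the trace shown above:
    trace = Let([Local(14)],
                PyOp("Embedding", [Local(11), Local(0), Constant(-1), Constant(None),
                                   Constant(2), Constant(False), Constant(False)]),
                Let([Local(15)],
                    PyOp("Dropout", [Local(14), Constant(0.2), Constant(True), Constant(False)]),
                    Ret([Local(15)])))
    print(render(trace))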
Here is a simple runner which leaks memory but can be used to JIT models:

    import torch.autograd.function as F
    import torch._C

    def jit(model):
        import types
        real_forward = model.forward
        def forward(self, *args):
            def flatten(x):
                return tuple(F._iter_variables(x))
            if not hasattr(self, "saved_trace"):
                torch._C._tracer_enter(tuple(self.parameters()) + flatten(args))
                out = real_forward(*args)
                self.saved_trace = torch._C._tracer_exit(flatten(out))
                self.saved_outs = out
                return out
            else:
                flat_out = Variable._execution_engine.run_forward(self.saved_trace, tuple(self.parameters()) + flatten(args))
                return F._unflatten(flat_out, self.saved_outs)

Major problems:

- Sanity checking is spotty at best, especially when users pass in variables.
- The interpreter leaks tensor memory from the store. When we add back def-use we should be able to deallocate tensors as soon as we know they are no longer necessary.
- The interpreter needs to reach feature parity with the old execution engine. From there, we need to see if backwards can be subsumed as well.
- I still have no confidence that memory is managed correctly everywhere. This requires a close look.
- Rather than return an *open* expression as a trace, we should return a *lambda* instead, which knows how many formal parameters it requires.
- The IR is not introspectable from Python at the moment, but this is simply a matter of implementing all the binding code.
- The tracer is NOT reentrant (you can't trace while you're inside a trace). Furthermore, no sanity checking is done if you try to incorrectly reuse things from one trace in another.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
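As written, the runner constructs a replacement forward but the snippet stops before showing how it is attached to the model or invoked. Below is a hypothetical usage sketch: the types.MethodType binding is an assumption suggested by the otherwise unused import types, and jit, model, and x are stand-in names for illustration, not part of the commit.

    import torch
    import torch.nn as nn
    from torch.autograd import Variable

    # Hypothetical final step of jit() above, binding the tracing wrapper onto
    # the instance so that model(...) goes through it:
    #     model.forward = types.MethodType(forward, model)

    model = nn.Linear(10, 10)
    jit(model)                      # patch model.forward with the tracing wrapper

    x = Variable(torch.randn(1, 10))
    out1 = model(x)                 # first call: runs eagerly, records saved_trace
    out2 = model(x)                 # later calls: replay the trace via run_forward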
428 lines · 14 KiB · C++
#include "torch/csrc/autograd/python_variable.h"
|
|
#include "torch/csrc/autograd/python_ir.h"
|
|
|
|
#include <structmember.h>
|
|
|
|
#include "THP.h"
|
|
#include "torch/csrc/DynamicTypes.h"
|
|
#include "torch/csrc/Types.h"
|
|
#include "torch/csrc/autograd/python_cpp_function.h"
|
|
#include "torch/csrc/autograd/python_hook.h"
|
|
#include "torch/csrc/autograd/functions/accumulate_grad.h"
|
|
#include "torch/csrc/cuda/AutoGPU.h"
|
|
#include "torch/csrc/utils/auto_gil.h"
|
|
#include "torch/csrc/Exceptions.h"
|
|
|
|
|
|
using namespace torch::autograd;
|
|
|
|
PyObject *THPVariableClass = NULL;
|
|
|
|
static PyObject* THPVariable_NewWithVar(PyTypeObject* type, std::shared_ptr<Variable> var)
{
  PyObject* obj = type->tp_alloc(type, 0);
  if (obj) {
    auto v = (THPVariable*) obj;
    new (&v->cdata) std::shared_ptr<Variable>(std::move(var));
    if (auto fn = dynamic_cast<PyFunction*>(v->cdata->grad_fn.get())) {
      // Create a new reference to the THPFunction. This ensures that ref count
      // of the THPFunction is at least the number of referring THPVariables.
      v->cdata->grad_fn = THPFunction_asFunction((THPFunction*)fn->obj);
    }
  }
  return obj;
}

PyObject * THPVariable_Wrap(const std::shared_ptr<Variable>& var)
{
  if (!var) {
    Py_RETURN_NONE;
  } else if (var->pyobj) {
    Py_INCREF(var->pyobj);
  } else {
    var->pyobj = THPVariable_NewWithVar((PyTypeObject *)THPVariableClass, var);
    THPVariable* py_var = (THPVariable*)var->pyobj;
    py_var->data = torch::createPyObject(var->data);
  }
  return var->pyobj;
}

// This function DOES NOT steal a reference to data
PyObject * THPVariable_NewWithFunction(PyObject *data, const std::shared_ptr<torch::autograd::Function>& grad_fn)
{
  THPUtils_assert(THPModule_isTensor(data), "data must be a Tensor");
  auto v = std::make_shared<Variable>(torch::createTensor(data), grad_fn->is_executable, false);
  v->grad_fn = grad_fn;
  PyObject* obj = THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, v);
  if (obj) {
    v->pyobj = obj;
    Py_INCREF(data);
    ((THPVariable*)obj)->data = data;
  }
  return obj;
}

// This function DOES NOT steal a reference to data
PyObject * THPVariable_NewVolatile(PyObject *data)
{
  auto v = std::make_shared<Variable>(torch::createTensor(data), false, true);
  PyObject* obj = THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, v);
  if (obj) {
    v->pyobj = obj;
    ((THPVariable*)obj)->data = data;
    Py_INCREF(data);
  }
  return obj;
}

// This function DOES NOT steal a reference to data
PyObject * THPVariable_NewLeaf(PyObject *data)
{
  auto v = std::make_shared<Variable>(torch::createTensor(data), false, false);
  PyObject* obj = THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, v);
  if (obj) {
    v->pyobj = obj;
    ((THPVariable*)obj)->data = data;
    Py_INCREF(data);
  }
  return obj;
}

static int THPVariable_traverse(THPVariable *self, visitproc visit, void *arg)
{
  Py_VISIT(self->data);
  Py_VISIT(self->backward_hooks);
  if (self->cdata) {
    if (auto fn = dynamic_cast<PyFunction*>(self->cdata->grad_fn.get())) {
      Py_VISIT(fn->obj);
    }
    for (auto& hook : self->cdata->hooks) {
      if (auto pyhook = dynamic_cast<PyFunctionPreHook*>(hook.get())) {
        Py_VISIT(pyhook->dict);
      }
    }
  }
  return 0;
}

static int THPVariable_clear(THPVariable *self)
{
  Py_CLEAR(self->data);
  Py_CLEAR(self->backward_hooks);
  if (self->cdata) {
    if (auto grad_acc = self->cdata->grad_accumulator.lock()) {
      grad_acc->pre_hooks.clear();
    }
    self->cdata->pyobj = nullptr;
  }
  self->cdata.reset();
  return 0;
}

static void THPVariable_dealloc(THPVariable* self)
{
  PyObject_GC_UnTrack(self);
  THPVariable_clear(self);
  self->cdata.~shared_ptr<Variable>();
  Py_TYPE(self)->tp_free((PyObject*)self);
}

PyObject *THPVariable_pynew(PyTypeObject *type, PyObject *args, PyObject *kwds)
{
  THPObjectPtr _data;
  PyObject *data = NULL;
  PyObject *grad_fn = NULL;
  char is_volatile = 0;
  char requires_grad = 0;

  const char *accepted_args[] = {"data", "requires_grad", "volatile", "_grad_fn", NULL};
  if (!PyArg_ParseTupleAndKeywords(args, kwds, "|ObbO", (char**)accepted_args,
      &data, &requires_grad, &is_volatile, &grad_fn))
    return NULL;

  if (grad_fn == Py_None)
    grad_fn = NULL;

  if (data == NULL || data == Py_None) {
    // For legacy serialization code, create an empty tensor temporarily.
    at::Tensor tensor = at::CPU(at::kFloat).tensor();
    _data = torch::createPyObject(tensor);
    data = _data.get();
  }

  THPUtils_assert(!(is_volatile && requires_grad),
          "Variable can't be volatile and require_grad at the same time!");
  THPUtils_assert(!grad_fn || THPFunction_Check(grad_fn),
          "Variable _grad_fn has to be a Function object or None, but got %s",
          THPUtils_typename(grad_fn));
  THPUtils_assert(THPModule_isTensor(data), "Variable data has to "
          "be a tensor, but got %s", THPUtils_typename(data));

  std::shared_ptr<Variable> var;
  if (grad_fn) {
    var = std::make_shared<Variable>(torch::createTensor(data), THPFunction_asFunction((THPFunction*)grad_fn));
  } else {
    var = std::make_shared<Variable>(torch::createTensor(data), requires_grad, is_volatile);
  }
  PyObject* self = THPVariable_NewWithVar(type, var);
  if (self) {
    var->pyobj = self;
    ((THPVariable*)self)->cdata = var;
    ((THPVariable*)self)->data = data;
    Py_INCREF(data);
  }

  return self;
}

int THPVariable_pyinit(PyObject *self, PyObject *args, PyObject *kwds)
{
  // Ensures that calls to Variable() and subclasses contain data argument.
  // The 'data' argument is optional in __new__ to handle legacy serialized
  // Variables.
  PyObject *data;
  PyObject *grad_fn = NULL;
  char is_volatile = 0;
  char requires_grad = 0;

  const char *accepted_args[] = {"data", "requires_grad", "volatile", "_grad_fn", NULL};
  if (!PyArg_ParseTupleAndKeywords(args, kwds, "|ObbO", (char**)accepted_args,
      &data, &requires_grad, &is_volatile, &grad_fn))
    return -1;

  return 0;
}

typedef PyObject *(*getter)(PyObject *, void *);
typedef int (*setter)(PyObject *, PyObject *, void *);

PyObject *THPVariable_get_version(THPVariable *self)
{
  auto& var = *self->cdata;
  return PyInt_FromLong(**var.version_counter);
}

PyObject *THPVariable_get_grad_fn(THPVariable *self)
{
  auto& var = *self->cdata;
  if (!var.grad_fn) {
    Py_RETURN_NONE;
  }
  return functionToPyObject(var.grad_fn);
}

int THPVariable_set_grad_fn(THPVariable *self, PyObject *obj)
{
  THPUtils_assertRet(-1, obj == Py_None, "_grad_fn can be only set to None");
  self->cdata->grad_fn = nullptr;
  return 0;
}

PyObject *THPVariable_is_leaf(THPVariable *self)
{
  return PyBool_FromLong(!self->cdata->grad_fn);
}

PyObject * THPVariable_get_data(THPVariable *self)
{
  if (!self->data) {
    self->data = torch::createPyObject(self->cdata->data);
  }
  Py_INCREF(self->data);
  return self->data;
}

int THPVariable_set_data(THPVariable *self, PyObject *data)
{
  THPUtils_assertRet(-1, THPModule_isTensor(data), "Variable data has to "
      "be a tensor, but got %s", THPUtils_typename(data));
  Py_INCREF(data);
  Py_XDECREF(self->data);
  self->data = data;
  auto& var = *self->cdata;
  auto tensor = torch::createTensor(data);
  var.data.swap(tensor);
  return 0;
}

PyObject *THPVariable_get_grad(THPVariable *self)
{
  auto& var = *self->cdata;
  if (!var.grad) {
    Py_RETURN_NONE;
  }
  return THPVariable_Wrap(var.grad);
}

int THPVariable_set_grad(THPVariable *self, PyObject *other)
{
  auto& var = *self->cdata;
  if (other == Py_None) {
    var.grad.reset();
    return 0;
  }

  THPUtils_assertRet(-1, THPVariable_Check(other),
      "expected Variable or None (got %s)", THPUtils_typename(other));
  THPUtils_assertRet(-1, self != (THPVariable*)other,
      "can't assign Variable as its own grad");

  auto& other_var = ((THPVariable*)other)->cdata;

  // Make sure the data is ok
  THPUtils_assertRet(-1, other_var->data.type().ID() == var.data.type().ID(),
      "assigned grad has data of a different type");
  THPUtils_assertRet(-1, other_var->data.type().isCuda() == var.data.type().isCuda(),
      "assigned grad has data located on a different device");
  if (var.data.type().isCuda()) {
    THPUtils_assertRet(-1, other_var->data.get_device() == var.data.get_device(),
        "assigned grad has data located on a different device");
  }
  THPUtils_assertRet(-1, other_var->data.sizes().vec() == var.data.sizes().vec(),
      "assigned grad has data of a different size");

  var.grad = other_var;
  if (auto grad_acc = var.grad_accumulator.lock()) {
    ((AccumulateGrad*)grad_acc.get())->variable_grad = other_var;
  }
  return 0;
}

PyObject *THPVariable_get_volatile(THPVariable *self)
{
  auto& var = *self->cdata;
  return PyBool_FromLong(var.is_volatile);
}

int THPVariable_set_volatile(THPVariable *self, PyObject *obj)
{
  THPUtils_assertRet(-1, PyBool_Check(obj), "volatile must be a bool");
  THPUtils_assertRet(-1, !self->cdata->grad_fn,
      "volatile can only be set on leaf variables");
  auto& var = *self->cdata;
  var.is_volatile = (obj == Py_True);
  return 0;
}

PyObject *THPVariable_get_output_nr(THPVariable *self)
{
  auto& var = *self->cdata;
  return PyInt_FromLong(var.output_nr);
}

PyObject *THPVariable_get_requires_grad(THPVariable *self)
{
  auto& var = *self->cdata;
  return PyBool_FromLong(var.requires_grad);
}

int THPVariable_set_requires_grad(THPVariable *self, PyObject *obj)
{
  THPUtils_assertRet(-1, PyBool_Check(obj), "requires_grad must be a bool");
  auto& var = *self->cdata;
  if (var.grad_fn) {
    const char *hint = "";
    if (obj == Py_False) {
      hint = " If you want to use a computed variable in a subgraph "
             "that doesn't require differentiation use "
             "var_no_grad = var.detach().";
    }
    THPUtils_setError("you can only change requires_grad flags of leaf variables.%s", hint);
    return -1;
  }
  var.requires_grad = obj == Py_True;
  if (auto grad_accumulator = var.grad_accumulator.lock()) {
    grad_accumulator->is_executable = var.requires_grad;
  }
  return 0;
}

PyObject *THPVariable_get_backwards_hooks(THPVariable *self)
{
  if (self->backward_hooks) {
    Py_INCREF(self->backward_hooks);
    return self->backward_hooks;
  }
  Py_RETURN_NONE;
}

int THPVariable_set_backwards_hooks(THPVariable *self, PyObject *obj)
{
  if (obj == Py_None) {
    obj = nullptr;
  }
  Py_XINCREF(obj);
  Py_XDECREF(self->backward_hooks);
  self->backward_hooks = obj;
  self->cdata->hooks.clear();
  if (obj) {
    self->cdata->hooks.emplace_back(new PyFunctionPreHook(obj, 0));
  }
  return 0;
}

static struct PyGetSetDef THPVariable_properties[] = {
  {"_version", (getter)THPVariable_get_version, NULL, NULL, NULL},
  {"grad_fn", (getter)THPVariable_get_grad_fn, NULL, NULL, NULL},
  {"_grad_fn", (getter)THPVariable_get_grad_fn, (setter)THPVariable_set_grad_fn, NULL, NULL},
  {"is_leaf", (getter)THPVariable_is_leaf, NULL, NULL, NULL},
  {"data", (getter)THPVariable_get_data, (setter)THPVariable_set_data, NULL, NULL},
  {"_grad", (getter)THPVariable_get_grad, (setter)THPVariable_set_grad, NULL, NULL}, // only for legacy reasons
  {"grad", (getter)THPVariable_get_grad, (setter)THPVariable_set_grad, NULL, NULL},
  {"volatile", (getter)THPVariable_get_volatile, (setter)THPVariable_set_volatile, NULL, NULL},
  {"output_nr", (getter)THPVariable_get_output_nr, NULL, NULL, NULL},
  {"requires_grad", (getter)THPVariable_get_requires_grad, (setter)THPVariable_set_requires_grad, NULL, NULL},
  {"_backward_hooks", (getter)THPVariable_get_backwards_hooks, (setter)THPVariable_set_backwards_hooks, NULL, NULL},
  {NULL}
};

PyTypeObject THPVariableType = {
  PyVarObject_HEAD_INIT(NULL, 0)
  "torch._C._VariableBase",              /* tp_name */
  sizeof(THPVariable),                   /* tp_basicsize */
  0,                                     /* tp_itemsize */
  (destructor)THPVariable_dealloc,       /* tp_dealloc */
  0,                                     /* tp_print */
  0,                                     /* tp_getattr */
  0,                                     /* tp_setattr */
  0,                                     /* tp_reserved */
  0,                                     /* tp_repr */
  0,                                     /* tp_as_number */
  0,                                     /* tp_as_sequence */
  0,                                     /* tp_as_mapping */
  0,                                     /* tp_hash */
  0,                                     /* tp_call */
  0,                                     /* tp_str */
  0,                                     /* tp_getattro */
  0,                                     /* tp_setattro */
  0,                                     /* tp_as_buffer */
  Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC, /* tp_flags */
  NULL,                                  /* tp_doc */
  (traverseproc)THPVariable_traverse,    /* tp_traverse */
  (inquiry)THPVariable_clear,            /* tp_clear */
  0,                                     /* tp_richcompare */
  0,                                     /* tp_weaklistoffset */
  0,                                     /* tp_iter */
  0,                                     /* tp_iternext */
  0,                                     /* tp_methods */
  0,                                     /* tp_members */
  THPVariable_properties,                /* tp_getset */
  0,                                     /* tp_base */
  0,                                     /* tp_dict */
  0,                                     /* tp_descr_get */
  0,                                     /* tp_descr_set */
  0,                                     /* tp_dictoffset */
  THPVariable_pyinit,                    /* tp_init */
  0,                                     /* tp_alloc */
  THPVariable_pynew                      /* tp_new */
};

bool THPVariable_initModule(PyObject *module)
{
  if (PyType_Ready(&THPVariableType) < 0)
    return false;
  Py_INCREF(&THPVariableType);
  PyModule_AddObject(module, "_VariableBase", (PyObject *)&THPVariableType);
  return true;
}