Refactor attribute names in autograd

This commit is contained in:
Adam Paszke
2017-03-15 13:26:02 -07:00
committed by Soumith Chintala
parent 2197e4c766
commit 2ca787fcf4
33 changed files with 645 additions and 593 deletions

View File

@ -39,17 +39,17 @@ PyObject * THPVariable_Wrap(const std::shared_ptr<Variable>& var)
return var->pyobj;
}
// This function DOES NOT steal a reference to data and creator
// To create a leaf Variable pass NULL as creator.
PyObject * THPVariable_New(PyObject *data, PyObject *creator, bool requires_grad, bool is_volatile)
// This function DOES NOT steal a reference to data and grad_fn
PyObject * THPVariable_New(PyObject *data, PyObject *_grad_fn)
{
THPUtils_assert(THPModule_isTensor(data), "data must be a Tensor");
THPUtils_assert(!creator || THPFunction_Check(creator), "creator must be a Function");
auto v = std::make_shared<Variable>(torch::createTensor(data), requires_grad, is_volatile);
THPUtils_assert(THPFunction_Check(_grad_fn), "grad_fn must be a Function");
THPFunction *grad_fn = (THPFunction*)_grad_fn;
auto v = std::make_shared<Variable>(torch::createTensor(data), grad_fn->cdata.is_executable, false);
PyObject* obj = THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, v);
if (obj) {
v->pyobj = obj;
v->creator = THPFunction_asFunction((THPFunction*)creator);
v->grad_fn = THPFunction_asFunction((THPFunction*)grad_fn);
((THPVariable*)obj)->data = data;
Py_INCREF(data);
}
@ -57,9 +57,16 @@ PyObject * THPVariable_New(PyObject *data, PyObject *creator, bool requires_grad
}
// This function DOES NOT steal a reference to data
PyObject * THPVariable_NewVolatile(PyObject *data)
PyObject * THPVariable_NewVolatile(PyObject *data, bool is_leaf)
{
return THPVariable_New(data, nullptr, false, true);
auto v = std::make_shared<Variable>(torch::createTensor(data), false, true, is_leaf);
PyObject* obj = THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, v);
if (obj) {
v->pyobj = obj;
((THPVariable*)obj)->data = data;
Py_INCREF(data);
}
return obj;
}
static int THPVariable_traverse(THPVariable *self, visitproc visit, void *arg)
@ -67,7 +74,7 @@ static int THPVariable_traverse(THPVariable *self, visitproc visit, void *arg)
Py_VISIT(self->data);
Py_VISIT(self->backward_hooks);
if (self->cdata) {
if (auto fn = dynamic_cast<PyFunction*>(self->cdata->creator.get())) {
if (auto fn = dynamic_cast<PyFunction*>(self->cdata->grad_fn.get())) {
Py_VISIT(fn->obj);
}
for (auto& hook : self->cdata->pre_hooks) {
@ -102,17 +109,17 @@ PyObject *THPVariable_pynew(PyTypeObject *type, PyObject *args, PyObject *kwds)
{
THPObjectPtr _data;
PyObject *data = NULL;
PyObject *creator = NULL;
PyObject *grad_fn = NULL;
char is_volatile = 0;
char requires_grad = 0;
const char *accepted_args[] = {"data", "creator", "volatile", "requires_grad", NULL};
const char *accepted_args[] = {"data", "grad_fn", "volatile", "requires_grad", NULL};
if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OObb", (char**)accepted_args,
&data, &creator, &is_volatile, &requires_grad))
&data, &grad_fn, &is_volatile, &requires_grad))
return NULL;
if (creator == Py_None)
creator = NULL;
if (grad_fn == Py_None)
grad_fn = NULL;
if (data == NULL || data == Py_None) {
// For legacy serialization code, create an empty tensor temporarily.
@ -123,9 +130,9 @@ PyObject *THPVariable_pynew(PyTypeObject *type, PyObject *args, PyObject *kwds)
THPUtils_assert(!(is_volatile && requires_grad),
"Variable can't be volatile and require_grad at the same time!");
THPUtils_assert(!creator || THPFunction_Check(creator),
"Variable creator has to be a Function object or None, but got %s",
THPUtils_typename(creator));
THPUtils_assert(!grad_fn || THPFunction_Check(grad_fn),
"Variable grad_fn has to be a Function object or None, but got %s",
THPUtils_typename(grad_fn));
THPUtils_assert(THPModule_isTensor(data), "Variable data has to "
"be a tensor, but got %s", THPUtils_typename(data));
@ -133,7 +140,7 @@ PyObject *THPVariable_pynew(PyTypeObject *type, PyObject *args, PyObject *kwds)
PyObject* self = THPVariable_NewWithVar(type, var);
if (self) {
var->pyobj = self;
var->creator = THPFunction_asFunction((THPFunction*)creator);
var->grad_fn = THPFunction_asFunction((THPFunction*)grad_fn);
((THPVariable*)self)->cdata = var;
((THPVariable*)self)->data = data;
Py_INCREF(data);
@ -148,13 +155,13 @@ int THPVariable_pyinit(PyObject *self, PyObject *args, PyObject *kwds)
// The 'data' argument is optional in __new__ to handle legacy serialized
// Variables.
PyObject *data;
PyObject *creator = NULL;
PyObject *grad_fn = NULL;
char is_volatile = 0;
char requires_grad = 0;
const char *accepted_args[] = {"data", "creator", "volatile", "requires_grad", NULL};
const char *accepted_args[] = {"data", "grad_fn", "volatile", "requires_grad", NULL};
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|Obb", (char**)accepted_args,
&data, &creator, &is_volatile, &requires_grad))
&data, &grad_fn, &is_volatile, &requires_grad))
return -1;
return 0;
@ -169,22 +176,27 @@ PyObject *THPVariable_get_version(THPVariable *self)
return PyInt_FromLong(**var.version_counter);
}
PyObject *THPVariable_get_creator(THPVariable *self)
PyObject *THPVariable_get_grad_fn(THPVariable *self)
{
auto& var = *self->cdata;
if (!var.creator) {
if (!var.grad_fn) {
Py_RETURN_NONE;
}
return functionToPyObject(var.creator);
return functionToPyObject(var.grad_fn);
}
int THPVariable_set_creator(THPVariable *self, PyObject *obj)
int THPVariable_set_grad_fn(THPVariable *self, PyObject *obj)
{
THPUtils_assertRet(-1, obj == Py_None, "_creator can be only set to None");
self->cdata->creator = nullptr;
THPUtils_assertRet(-1, obj == Py_None, "_grad_fn can be only set to None");
self->cdata->grad_fn = nullptr;
return 0;
}
PyObject *THPVariable_is_leaf(THPVariable *self)
{
return PyBool_FromLong(self->cdata->is_leaf);
}
PyObject * THPVariable_get_data(THPVariable *self)
{
if (!self->data) {
@ -247,7 +259,7 @@ PyObject *THPVariable_get_volatile(THPVariable *self)
int THPVariable_set_volatile(THPVariable *self, PyObject *obj)
{
THPUtils_assertRet(-1, PyBool_Check(obj), "volatile must be a bool");
THPUtils_assertRet(-1, !self->cdata->creator,
THPUtils_assertRet(-1, self->cdata->is_leaf,
"volatile can only be set on leaf variables");
auto& var = *self->cdata;
var.is_volatile = (obj == Py_True);
@ -270,7 +282,7 @@ int THPVariable_set_requires_grad(THPVariable *self, PyObject *obj)
{
THPUtils_assertRet(-1, PyBool_Check(obj), "requires_grad must be a bool");
auto& var = *self->cdata;
if (var.creator) {
if (!var.is_leaf) {
const char *hint = "";
if (obj == Py_False) {
hint = " If you want to use a computed variable in a subgraph "
@ -280,7 +292,7 @@ int THPVariable_set_requires_grad(THPVariable *self, PyObject *obj)
THPUtils_setError("you can only change requires_grad flags of leaf variables.%s", hint);
return -1;
}
var.requires_grad = (obj == Py_True);
var.is_executable = var.requires_grad = (obj == Py_True);
return 0;
}
@ -310,8 +322,9 @@ int THPVariable_set_backwards_hooks(THPVariable *self, PyObject *obj)
static struct PyGetSetDef THPVariable_properties[] = {
{"_version", (getter)THPVariable_get_version, NULL, NULL, NULL},
{"creator", (getter)THPVariable_get_creator, NULL, NULL, NULL},
{"_creator", (getter)THPVariable_get_creator, (setter)THPVariable_set_creator, NULL, NULL},
{"grad_fn", (getter)THPVariable_get_grad_fn, NULL, NULL, NULL},
{"_grad_fn", (getter)THPVariable_get_grad_fn, (setter)THPVariable_set_grad_fn, NULL, NULL},
{"is_leaf", (getter)THPVariable_is_leaf, NULL, NULL, NULL},
{"data", (getter)THPVariable_get_data, (setter)THPVariable_set_data, NULL, NULL},
{"_grad", (getter)THPVariable_get_raw_grad, (setter)THPVariable_set_raw_grad, NULL, NULL},
{"grad", (getter)THPVariable_get_grad, NULL, NULL, NULL},