mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Refactor attribute names in autograd
This commit is contained in:
committed by
Soumith Chintala
parent
2197e4c766
commit
2ca787fcf4
@ -39,17 +39,17 @@ PyObject * THPVariable_Wrap(const std::shared_ptr<Variable>& var)
|
||||
return var->pyobj;
|
||||
}
|
||||
|
||||
// This function DOES NOT steal a reference to data and creator
|
||||
// To create a leaf Variable pass NULL as creator.
|
||||
PyObject * THPVariable_New(PyObject *data, PyObject *creator, bool requires_grad, bool is_volatile)
|
||||
// This function DOES NOT steal a reference to data and grad_fn
|
||||
PyObject * THPVariable_New(PyObject *data, PyObject *_grad_fn)
|
||||
{
|
||||
THPUtils_assert(THPModule_isTensor(data), "data must be a Tensor");
|
||||
THPUtils_assert(!creator || THPFunction_Check(creator), "creator must be a Function");
|
||||
auto v = std::make_shared<Variable>(torch::createTensor(data), requires_grad, is_volatile);
|
||||
THPUtils_assert(THPFunction_Check(_grad_fn), "grad_fn must be a Function");
|
||||
THPFunction *grad_fn = (THPFunction*)_grad_fn;
|
||||
auto v = std::make_shared<Variable>(torch::createTensor(data), grad_fn->cdata.is_executable, false);
|
||||
PyObject* obj = THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, v);
|
||||
if (obj) {
|
||||
v->pyobj = obj;
|
||||
v->creator = THPFunction_asFunction((THPFunction*)creator);
|
||||
v->grad_fn = THPFunction_asFunction((THPFunction*)grad_fn);
|
||||
((THPVariable*)obj)->data = data;
|
||||
Py_INCREF(data);
|
||||
}
|
||||
@ -57,9 +57,16 @@ PyObject * THPVariable_New(PyObject *data, PyObject *creator, bool requires_grad
|
||||
}
|
||||
|
||||
// This function DOES NOT steal a reference to data
|
||||
PyObject * THPVariable_NewVolatile(PyObject *data)
|
||||
PyObject * THPVariable_NewVolatile(PyObject *data, bool is_leaf)
|
||||
{
|
||||
return THPVariable_New(data, nullptr, false, true);
|
||||
auto v = std::make_shared<Variable>(torch::createTensor(data), false, true, is_leaf);
|
||||
PyObject* obj = THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, v);
|
||||
if (obj) {
|
||||
v->pyobj = obj;
|
||||
((THPVariable*)obj)->data = data;
|
||||
Py_INCREF(data);
|
||||
}
|
||||
return obj;
|
||||
}
|
||||
|
||||
static int THPVariable_traverse(THPVariable *self, visitproc visit, void *arg)
|
||||
@ -67,7 +74,7 @@ static int THPVariable_traverse(THPVariable *self, visitproc visit, void *arg)
|
||||
Py_VISIT(self->data);
|
||||
Py_VISIT(self->backward_hooks);
|
||||
if (self->cdata) {
|
||||
if (auto fn = dynamic_cast<PyFunction*>(self->cdata->creator.get())) {
|
||||
if (auto fn = dynamic_cast<PyFunction*>(self->cdata->grad_fn.get())) {
|
||||
Py_VISIT(fn->obj);
|
||||
}
|
||||
for (auto& hook : self->cdata->pre_hooks) {
|
||||
@ -102,17 +109,17 @@ PyObject *THPVariable_pynew(PyTypeObject *type, PyObject *args, PyObject *kwds)
|
||||
{
|
||||
THPObjectPtr _data;
|
||||
PyObject *data = NULL;
|
||||
PyObject *creator = NULL;
|
||||
PyObject *grad_fn = NULL;
|
||||
char is_volatile = 0;
|
||||
char requires_grad = 0;
|
||||
|
||||
const char *accepted_args[] = {"data", "creator", "volatile", "requires_grad", NULL};
|
||||
const char *accepted_args[] = {"data", "grad_fn", "volatile", "requires_grad", NULL};
|
||||
if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OObb", (char**)accepted_args,
|
||||
&data, &creator, &is_volatile, &requires_grad))
|
||||
&data, &grad_fn, &is_volatile, &requires_grad))
|
||||
return NULL;
|
||||
|
||||
if (creator == Py_None)
|
||||
creator = NULL;
|
||||
if (grad_fn == Py_None)
|
||||
grad_fn = NULL;
|
||||
|
||||
if (data == NULL || data == Py_None) {
|
||||
// For legacy serialization code, create an empty tensor temporarily.
|
||||
@ -123,9 +130,9 @@ PyObject *THPVariable_pynew(PyTypeObject *type, PyObject *args, PyObject *kwds)
|
||||
|
||||
THPUtils_assert(!(is_volatile && requires_grad),
|
||||
"Variable can't be volatile and require_grad at the same time!");
|
||||
THPUtils_assert(!creator || THPFunction_Check(creator),
|
||||
"Variable creator has to be a Function object or None, but got %s",
|
||||
THPUtils_typename(creator));
|
||||
THPUtils_assert(!grad_fn || THPFunction_Check(grad_fn),
|
||||
"Variable grad_fn has to be a Function object or None, but got %s",
|
||||
THPUtils_typename(grad_fn));
|
||||
THPUtils_assert(THPModule_isTensor(data), "Variable data has to "
|
||||
"be a tensor, but got %s", THPUtils_typename(data));
|
||||
|
||||
@ -133,7 +140,7 @@ PyObject *THPVariable_pynew(PyTypeObject *type, PyObject *args, PyObject *kwds)
|
||||
PyObject* self = THPVariable_NewWithVar(type, var);
|
||||
if (self) {
|
||||
var->pyobj = self;
|
||||
var->creator = THPFunction_asFunction((THPFunction*)creator);
|
||||
var->grad_fn = THPFunction_asFunction((THPFunction*)grad_fn);
|
||||
((THPVariable*)self)->cdata = var;
|
||||
((THPVariable*)self)->data = data;
|
||||
Py_INCREF(data);
|
||||
@ -148,13 +155,13 @@ int THPVariable_pyinit(PyObject *self, PyObject *args, PyObject *kwds)
|
||||
// The 'data' argument is optional in __new__ to handle legacy serialized
|
||||
// Variables.
|
||||
PyObject *data;
|
||||
PyObject *creator = NULL;
|
||||
PyObject *grad_fn = NULL;
|
||||
char is_volatile = 0;
|
||||
char requires_grad = 0;
|
||||
|
||||
const char *accepted_args[] = {"data", "creator", "volatile", "requires_grad", NULL};
|
||||
const char *accepted_args[] = {"data", "grad_fn", "volatile", "requires_grad", NULL};
|
||||
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|Obb", (char**)accepted_args,
|
||||
&data, &creator, &is_volatile, &requires_grad))
|
||||
&data, &grad_fn, &is_volatile, &requires_grad))
|
||||
return -1;
|
||||
|
||||
return 0;
|
||||
@ -169,22 +176,27 @@ PyObject *THPVariable_get_version(THPVariable *self)
|
||||
return PyInt_FromLong(**var.version_counter);
|
||||
}
|
||||
|
||||
PyObject *THPVariable_get_creator(THPVariable *self)
|
||||
PyObject *THPVariable_get_grad_fn(THPVariable *self)
|
||||
{
|
||||
auto& var = *self->cdata;
|
||||
if (!var.creator) {
|
||||
if (!var.grad_fn) {
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
return functionToPyObject(var.creator);
|
||||
return functionToPyObject(var.grad_fn);
|
||||
}
|
||||
|
||||
int THPVariable_set_creator(THPVariable *self, PyObject *obj)
|
||||
int THPVariable_set_grad_fn(THPVariable *self, PyObject *obj)
|
||||
{
|
||||
THPUtils_assertRet(-1, obj == Py_None, "_creator can be only set to None");
|
||||
self->cdata->creator = nullptr;
|
||||
THPUtils_assertRet(-1, obj == Py_None, "_grad_fn can be only set to None");
|
||||
self->cdata->grad_fn = nullptr;
|
||||
return 0;
|
||||
}
|
||||
|
||||
PyObject *THPVariable_is_leaf(THPVariable *self)
|
||||
{
|
||||
return PyBool_FromLong(self->cdata->is_leaf);
|
||||
}
|
||||
|
||||
PyObject * THPVariable_get_data(THPVariable *self)
|
||||
{
|
||||
if (!self->data) {
|
||||
@ -247,7 +259,7 @@ PyObject *THPVariable_get_volatile(THPVariable *self)
|
||||
int THPVariable_set_volatile(THPVariable *self, PyObject *obj)
|
||||
{
|
||||
THPUtils_assertRet(-1, PyBool_Check(obj), "volatile must be a bool");
|
||||
THPUtils_assertRet(-1, !self->cdata->creator,
|
||||
THPUtils_assertRet(-1, self->cdata->is_leaf,
|
||||
"volatile can only be set on leaf variables");
|
||||
auto& var = *self->cdata;
|
||||
var.is_volatile = (obj == Py_True);
|
||||
@ -270,7 +282,7 @@ int THPVariable_set_requires_grad(THPVariable *self, PyObject *obj)
|
||||
{
|
||||
THPUtils_assertRet(-1, PyBool_Check(obj), "requires_grad must be a bool");
|
||||
auto& var = *self->cdata;
|
||||
if (var.creator) {
|
||||
if (!var.is_leaf) {
|
||||
const char *hint = "";
|
||||
if (obj == Py_False) {
|
||||
hint = " If you want to use a computed variable in a subgraph "
|
||||
@ -280,7 +292,7 @@ int THPVariable_set_requires_grad(THPVariable *self, PyObject *obj)
|
||||
THPUtils_setError("you can only change requires_grad flags of leaf variables.%s", hint);
|
||||
return -1;
|
||||
}
|
||||
var.requires_grad = (obj == Py_True);
|
||||
var.is_executable = var.requires_grad = (obj == Py_True);
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -310,8 +322,9 @@ int THPVariable_set_backwards_hooks(THPVariable *self, PyObject *obj)
|
||||
|
||||
static struct PyGetSetDef THPVariable_properties[] = {
|
||||
{"_version", (getter)THPVariable_get_version, NULL, NULL, NULL},
|
||||
{"creator", (getter)THPVariable_get_creator, NULL, NULL, NULL},
|
||||
{"_creator", (getter)THPVariable_get_creator, (setter)THPVariable_set_creator, NULL, NULL},
|
||||
{"grad_fn", (getter)THPVariable_get_grad_fn, NULL, NULL, NULL},
|
||||
{"_grad_fn", (getter)THPVariable_get_grad_fn, (setter)THPVariable_set_grad_fn, NULL, NULL},
|
||||
{"is_leaf", (getter)THPVariable_is_leaf, NULL, NULL, NULL},
|
||||
{"data", (getter)THPVariable_get_data, (setter)THPVariable_set_data, NULL, NULL},
|
||||
{"_grad", (getter)THPVariable_get_raw_grad, (setter)THPVariable_set_raw_grad, NULL, NULL},
|
||||
{"grad", (getter)THPVariable_get_grad, NULL, NULL, NULL},
|
||||
|
Reference in New Issue
Block a user