Mirror of https://github.com/pytorch/pytorch.git
Replace Variable.volatile with torch.no_grad() (#3970)
This removes volatile from Variable. The functionality is mostly replaced by a global (thread-local) flag, which is controlled by torch.set_grad_enabled() and the context manager torch.no_grad(). In C++, the flag is exposed through GradMode::is_enabled() and GradMode::set_enabled().

Fixes #3627
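
For readers migrating code, a minimal sketch of the user-facing change (assuming the 0.3-era Variable API; everything beyond the names in the commit message is illustrative):

    import torch
    from torch.autograd import Variable

    x = Variable(torch.ones(2, 2), requires_grad=True)

    # Before this commit: y = Variable(torch.ones(2, 2), volatile=True)
    # After it: suppress gradient tracking with the context manager...
    with torch.no_grad():
        y = x * 2
    assert y.requires_grad is False

    # ...or by toggling the global (thread-local) flag directly.
    torch.set_grad_enabled(False)
    z = x * 2
    torch.set_grad_enabled(True)
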
torch/csrc/autograd/python_variable.cpp
@@ -22,6 +22,10 @@ using namespace torch::autograd;
 
 PyObject *THPVariableClass = NULL;
 
+static const char* VOLATILE_WARNING =
+    "volatile was removed and now has no effect. Use "
+    "`with torch.no_grad():` instead.";
+
 // Creates a new Python object for a Variable. The Variable must not already
 // have a PyObject* associated with it.
 static PyObject* THPVariable_NewWithVar(PyTypeObject* type, Variable var)
@@ -31,7 +35,7 @@ static PyObject* THPVariable_NewWithVar(PyTypeObject* type, Variable var)
   auto v = (THPVariable*) obj;
   new (&v->cdata) Variable(std::move(var));
   v->cdata.get()->pyobj = obj;
-  if (auto fn = dynamic_cast<PyFunction*>(v->cdata.grad_fn().get())) {
+  if (auto fn = dynamic_cast<PyFunction*>(v->cdata.get()->_grad_fn.get())) {
     // Create a new reference to the THPFunction. This ensures that ref count
     // of the THPFunction is at least the number of referring THPVariables.
     v->cdata.get()->_grad_fn = THPFunction_asFunction((THPFunction*)fn->obj);
@@ -58,30 +62,6 @@ PyObject * THPVariable_Wrap(Variable var)
   return THPVariable_NewWithVar((PyTypeObject *)THPVariableClass, std::move(var));
 }
 
-// This function DOES NOT steal a reference to data
-PyObject * THPVariable_NewVolatile(PyObject *data)
-{
-  Variable v = make_variable(torch::createTensor(data), false, true);
-  PyObject* obj = THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, std::move(v));
-  if (obj) {
-    ((THPVariable*)obj)->data = data;
-    Py_INCREF(data);
-  }
-  return obj;
-}
-
-// This function DOES NOT steal a reference to data
-PyObject * THPVariable_NewLeaf(PyObject *data)
-{
-  Variable v = make_variable(torch::createTensor(data));
-  PyObject* obj = THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, std::move(v));
-  if (obj) {
-    ((THPVariable*)obj)->data = data;
-    Py_INCREF(data);
-  }
-  return obj;
-}
-
 static int THPVariable_traverse(THPVariable *self, visitproc visit, void *arg)
 {
   Py_VISIT(self->data);
@@ -161,6 +141,10 @@ PyObject *THPVariable_pynew(PyTypeObject *type, PyObject *args, PyObject *kwds)
     data = _data.get();
   }
 
+  if (is_volatile) {
+    PyErr_WarnEx(PyExc_UserWarning, VOLATILE_WARNING, 1);
+  }
+
   THPUtils_assert(!(is_volatile && requires_grad),
       "Variable can't be volatile and require_grad at the same time!");
   THPUtils_assert(!grad_fn || THPFunction_Check(grad_fn),
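
The effect of this hunk on the Python constructor, sketched below: volatile=True is still accepted but now raises the UserWarning defined in the first hunk and no longer influences the resulting variable (a sketch, not part of the patch):

    import warnings
    import torch
    from torch.autograd import Variable

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        v = Variable(torch.ones(1), volatile=True)  # warns instead of creating a volatile variable

    assert any("volatile was removed" in str(w.message) for w in caught)
    assert v.requires_grad is False  # volatile no longer affects the result
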
@@ -174,7 +158,7 @@ PyObject *THPVariable_pynew(PyTypeObject *type, PyObject *args, PyObject *kwds)
     auto grad_fn_ = THPFunction_asFunction((THPFunction*)grad_fn);
     var = make_variable(torch::createTensor(data), grad_fn_);
   } else {
-    var = make_variable(torch::createTensor(data), requires_grad, is_volatile);
+    var = make_variable(torch::createTensor(data), requires_grad);
   }
 
   if (name)
@@ -314,58 +298,47 @@ PyObject *THPVariable_get_grad(THPVariable *self)
   END_HANDLE_TH_ERRORS
 }
 
-int THPVariable_set_grad(THPVariable *self, PyObject *other)
+int THPVariable_set_grad(THPVariable *self, PyObject *py_grad)
 {
   HANDLE_TH_ERRORS
   auto& var = self->cdata;
-  if (other == Py_None) {
+  if (py_grad == Py_None) {
     var.grad().reset();
     return 0;
   }
 
-  THPUtils_assertRet(-1, THPVariable_Check(other),
-      "expected Variable or None (got %s)", THPUtils_typename(other));
-  THPUtils_assertRet(-1, self != (THPVariable*)other,
+  THPUtils_assertRet(-1, THPVariable_Check(py_grad),
+      "expected Variable or None (got %s)", THPUtils_typename(py_grad));
+  THPUtils_assertRet(-1, self != (THPVariable*)py_grad,
       "can't assign Variable as its own grad");
 
-  auto& data = var.data();
-  auto& other_var = ((THPVariable*)other)->cdata;
-  auto& other_data = other_var.data();
+  auto& grad = ((THPVariable*)py_grad)->cdata;
+  auto& sparseType = var.type().toBackend(var.is_cuda() ? kSparseCUDA : kSparseCPU);
 
   // Make sure the data is ok
-  THPUtils_assertRet(-1, other_data.type().ID() == data.type().ID(),
+  THPUtils_assertRet(-1, grad.type() == var.type() || grad.type() == sparseType,
       "assigned grad has data of a different type");
-  THPUtils_assertRet(-1, other_data.type().is_cuda() == data.type().is_cuda(),
-      "assigned grad has data located on a different device");
-  if (data.type().is_cuda()) {
-    THPUtils_assertRet(-1, other_data.get_device() == data.get_device(),
+  if (var.type().is_cuda()) {
+    THPUtils_assertRet(-1, grad.get_device() == var.get_device(),
         "assigned grad has data located on a different device");
   }
-  THPUtils_assertRet(-1, other_data.sizes().vec() == data.sizes().vec(),
+  THPUtils_assertRet(-1, grad.sizes().equals(var.sizes()),
       "assigned grad has data of a different size");
 
-  var.grad() = other_var;
+  var.grad() = grad;
   return 0;
   END_HANDLE_TH_ERRORS_RET(-1)
 }
 
 PyObject *THPVariable_get_volatile(THPVariable *self)
 {
-  HANDLE_TH_ERRORS
-  auto& var = self->cdata;
-  return PyBool_FromLong(var.is_volatile());
-  END_HANDLE_TH_ERRORS
+  const char* msg = "volatile was removed (Variable.volatile is always False)";
+  PyErr_WarnEx(PyExc_UserWarning, msg, 1);
+  Py_RETURN_FALSE;
 }
 
 int THPVariable_set_volatile(THPVariable *self, PyObject *obj)
 {
-  HANDLE_TH_ERRORS
-  THPUtils_assertRet(-1, PyBool_Check(obj), "volatile must be a bool");
-  THPUtils_assertRet(-1, !self->cdata.grad_fn(),
-      "volatile can only be set on leaf variables");
-  self->cdata.is_volatile() = (obj == Py_True);
-  return 0;
-  END_HANDLE_TH_ERRORS_RET(-1)
+  return PyErr_WarnEx(PyExc_UserWarning, VOLATILE_WARNING, 1);
 }
 
 PyObject *THPVariable_get_output_nr(THPVariable *self)
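
A sketch of the Python-visible behavior after this hunk (illustrative, not part of the patch): assigning .grad is validated against the variable's type (dense or the matching sparse type), device, and size, while the .volatile property degrades to a warning stub.

    import torch
    from torch.autograd import Variable

    v = Variable(torch.ones(2, 2))
    v.grad = Variable(torch.zeros(2, 2))   # accepted: same type, device, and size
    try:
        v.grad = Variable(torch.zeros(3))  # rejected: "assigned grad has data of a different size"
    except RuntimeError as e:
        print(e)

    # Reading .volatile warns and always returns False; writing it warns and is a no-op.
    print(v.volatile)   # UserWarning: volatile was removed (Variable.volatile is always False)
    v.volatile = True   # UserWarning: volatile was removed and now has no effect. ...
    print(v.volatile)   # still False
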
@@ -387,7 +360,7 @@ int THPVariable_set_requires_grad(THPVariable *self, PyObject *obj)
   HANDLE_TH_ERRORS
   THPUtils_assertRet(-1, PyBool_Check(obj), "requires_grad must be a bool");
   auto& var = self->cdata;
-  if (var.grad_fn()) {
+  if (!var.is_leaf()) {
     const char *hint = "";
     if (obj == Py_False) {
       hint = " If you want to use a computed variable in a subgraph "
@@ -397,7 +370,7 @@ int THPVariable_set_requires_grad(THPVariable *self, PyObject *obj)
     THPUtils_setError("you can only change requires_grad flags of leaf variables.%s", hint);
     return -1;
   }
-  var.requires_grad() = (obj == Py_True);
+  var.get()->_requires_grad = (obj == Py_True);
   return 0;
   END_HANDLE_TH_ERRORS_RET(-1)
 }
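
These last two hunks route the guard through Variable::is_leaf() and write the flag directly onto the underlying implementation. In Python terms (a sketch under the same era API):

    import torch
    from torch.autograd import Variable

    x = Variable(torch.ones(1), requires_grad=True)
    y = x * 2                    # non-leaf: produced by an operation

    x.requires_grad = False      # allowed: x is a leaf
    try:
        y.requires_grad = False  # rejected for non-leaf variables
    except RuntimeError as e:
        print(e)  # "you can only change requires_grad flags of leaf variables. ..."
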