Add a requires_grad_() function to tensors. (#6771)

gchanan
2018-04-19 13:47:24 -04:00
committed by GitHub
parent f6da2fd944
commit d0b0edf27a
4 changed files with 65 additions and 7 deletions
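For context, requires_grad_() is the in-place, chainable counterpart of the requires_grad property: it sets the flag on the tensor and returns the tensor itself. A brief usage sketch (standard torch API, illustration only, not part of the diff):

import torch

# requires_grad_() flips the flag in place and returns the same tensor,
# so it can be chained directly onto construction.
w = torch.zeros(4, 3).requires_grad_()   # argument defaults to True
b = torch.zeros(3).requires_grad_(True)

loss = w.sum() + b.sum()
loss.backward()
print(w.grad.shape, b.grad.shape)        # gradients are now populated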


@@ -14,6 +14,7 @@
 #include "torch/csrc/autograd/functions/accumulate_grad.h"
 #include "torch/csrc/autograd/function.h"
 #include "torch/csrc/autograd/generated/VariableType.h"
+#include "torch/csrc/autograd/utils/python_error_messages.h"
 #include "torch/csrc/autograd/utils/wrap_outputs.h"
 #include "torch/csrc/jit/tracer_state.h"
 #include "torch/csrc/tensor/python_tensor.h"
@@ -290,13 +291,7 @@ int THPVariable_set_requires_grad(THPVariable *self, PyObject *obj)
   THPUtils_assertRet(-1, PyBool_Check(obj), "requires_grad must be a bool");
   auto& var = self->cdata;
   if (!var.is_leaf()) {
-    const char *hint = "";
-    if (obj == Py_False) {
-      hint = " If you want to use a computed variable in a subgraph "
-             "that doesn't require differentiation use "
-             "var_no_grad = var.detach().";
-    }
-    THPUtils_setError("you can only change requires_grad flags of leaf variables.%s", hint);
+    THPUtils_setError(autograd::utils::requires_grad_leaf_error(obj == Py_True).c_str());
     return -1;
   }
   var.set_requires_grad(obj == Py_True);
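The removed branch hand-built the .detach() hint only for the requires_grad=False case; the new requires_grad_leaf_error helper centralizes that message. A minimal sketch of the user-facing behavior this setter enforces (assuming current leaf/non-leaf semantics; not part of the diff):

import torch

x = torch.ones(3, requires_grad=True)   # leaf tensor: flag may be toggled
y = x * 2                               # non-leaf: produced by an autograd op

x.requires_grad_(False)                 # fine, x is a leaf
try:
    y.requires_grad = False             # setter rejects non-leaf tensors
except RuntimeError as err:
    print(err)                          # message suggests y.detach() instead

z = y.detach()                          # detached result is a leaf again
z.requires_grad_()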