only Tensors of floating point dtype can require gradients (see #7021) (#7034)

This commit is contained in:
Thomas Viehmann
2018-04-30 10:20:00 +02:00
committed by Adam Paszke
parent 6a55d86234
commit 8fbab83c2a
5 changed files with 55 additions and 6 deletions

View File

@ -284,11 +284,16 @@ int THPVariable_set_requires_grad(THPVariable *self, PyObject *obj)
HANDLE_TH_ERRORS
THPUtils_assertRet(-1, PyBool_Check(obj), "requires_grad must be a bool");
auto& var = self->cdata;
auto requires_grad = (obj == Py_True);
if (!var.is_leaf()) {
THPUtils_setError(autograd::utils::requires_grad_leaf_error(obj == Py_True).c_str());
return -1;
}
var.set_requires_grad(obj == Py_True);
if (requires_grad && !var.is_floating_point()) {
THPUtils_setError("only Tensors of floating point dtype can require gradients");
return -1;
}
var.set_requires_grad(requires_grad);
return 0;
END_HANDLE_TH_ERRORS_RET(-1)
}