mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
This commit is contained in:
committed by
Adam Paszke
parent
6a55d86234
commit
8fbab83c2a
@ -284,11 +284,16 @@ int THPVariable_set_requires_grad(THPVariable *self, PyObject *obj)
|
||||
HANDLE_TH_ERRORS
|
||||
THPUtils_assertRet(-1, PyBool_Check(obj), "requires_grad must be a bool");
|
||||
auto& var = self->cdata;
|
||||
auto requires_grad = (obj == Py_True);
|
||||
if (!var.is_leaf()) {
|
||||
THPUtils_setError(autograd::utils::requires_grad_leaf_error(obj == Py_True).c_str());
|
||||
return -1;
|
||||
}
|
||||
var.set_requires_grad(obj == Py_True);
|
||||
if (requires_grad && !var.is_floating_point()) {
|
||||
THPUtils_setError("only Tensors of floating point dtype can require gradients");
|
||||
return -1;
|
||||
}
|
||||
var.set_requires_grad(requires_grad);
|
||||
return 0;
|
||||
END_HANDLE_TH_ERRORS_RET(-1)
|
||||
}
|
||||
|
Reference in New Issue
Block a user