Add a requires_grad_() function to tensors. (#6771)
@@ -14,6 +14,7 @@
 #include "torch/csrc/autograd/functions/accumulate_grad.h"
 #include "torch/csrc/autograd/function.h"
 #include "torch/csrc/autograd/generated/VariableType.h"
+#include "torch/csrc/autograd/utils/python_error_messages.h"
 #include "torch/csrc/autograd/utils/wrap_outputs.h"
 #include "torch/csrc/jit/tracer_state.h"
 #include "torch/csrc/tensor/python_tensor.h"
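The header added above is new in this commit, and its contents are not part of this diff. Below is a minimal, hypothetical sketch of what torch/csrc/autograd/utils/python_error_messages.h could provide; the signature is inferred from the call site in the next hunk (requires_grad_leaf_error takes a bool and the result is used via .c_str(), i.e. a std::string), and the message text comes from the inline code that hunk removes.

// Hypothetical sketch of torch/csrc/autograd/utils/python_error_messages.h;
// the real header is not shown in this diff.
#pragma once

#include <string>

namespace torch { namespace autograd { namespace utils {

inline std::string requires_grad_leaf_error(bool requires_grad) {
  std::string msg = "you can only change requires_grad flags of leaf variables.";
  if (!requires_grad) {
    // Append the detach() hint only when the caller is clearing the flag,
    // matching the behavior of the inline code this commit removes below.
    msg += " If you want to use a computed variable in a subgraph "
           "that doesn't require differentiation use "
           "var_no_grad = var.detach().";
  }
  return msg;
}

}}} // namespace torch::autograd::utils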
@@ -290,13 +291,7 @@ int THPVariable_set_requires_grad(THPVariable *self, PyObject *obj)
   THPUtils_assertRet(-1, PyBool_Check(obj), "requires_grad must be a bool");
   auto& var = self->cdata;
   if (!var.is_leaf()) {
-    const char *hint = "";
-    if (obj == Py_False) {
-      hint = " If you want to use a computed variable in a subgraph "
-          "that doesn't require differentiation use "
-          "var_no_grad = var.detach().";
-    }
-    THPUtils_setError("you can only change requires_grad flags of leaf variables.%s", hint);
+    THPUtils_setError(autograd::utils::requires_grad_leaf_error(obj == Py_True).c_str());
     return -1;
   }
   var.set_requires_grad(obj == Py_True);
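Moving the message into a shared helper, rather than building the hint string inline as the removed lines did, presumably lets the requires_grad property setter shown here and the new requires_grad_() method this commit adds report an identical leaf-only error. Passing obj == Py_True preserves the old behavior: the var.detach() hint is appended only when the caller is trying to clear the flag.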