mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Check input flags in Traceable
This commit is contained in:
@ -188,6 +188,7 @@ PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwar
|
||||
THPUtils_assert(PyTuple_Check(inputs), "inputs argument has to be a tuple");
|
||||
int num_inputs = PyTuple_GET_SIZE(inputs);
|
||||
ctx.outputs = PyTuple_New(num_inputs);
|
||||
if (!ctx.outputs) return NULL;
|
||||
// First, find all relevant functions and fill ctx.output_map
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
PyObject *input = PyTuple_GET_ITEM(inputs, i);
|
||||
@ -202,6 +203,8 @@ PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwar
|
||||
}
|
||||
THPUtils_assert(grad_fn, "One of the differentiated Variables appears to not have "
|
||||
"been used in the graph");
|
||||
THPUtils_assert(grad_fn->is_executable, "One of the differentiated Variables does "
|
||||
"not require grad");
|
||||
auto& fn_info = ctx.output_map[grad_fn];
|
||||
fn_info.first.emplace_back(output_nr, i);
|
||||
fn_info.second = is_leaf;
|
||||
|
Reference in New Issue
Block a user