Check input flags in Traceable

This commit is contained in:
Adam Paszke
2017-09-01 11:13:37 -07:00
parent 230721e198
commit ea888c1905
3 changed files with 123 additions and 51 deletions

View File

@ -188,6 +188,7 @@ PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwar
THPUtils_assert(PyTuple_Check(inputs), "inputs argument has to be a tuple");
int num_inputs = PyTuple_GET_SIZE(inputs);
ctx.outputs = PyTuple_New(num_inputs);
if (!ctx.outputs) return NULL;
// First, find all relevant functions and fill ctx.output_map
for (int i = 0; i < num_inputs; ++i) {
PyObject *input = PyTuple_GET_ITEM(inputs, i);
@ -202,6 +203,8 @@ PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwar
}
THPUtils_assert(grad_fn, "One of the differentiated Variables appears to not have "
"been used in the graph");
THPUtils_assert(grad_fn->is_executable, "One of the differentiated Variables does "
"not require grad");
auto& fn_info = ctx.output_map[grad_fn];
fn_info.first.emplace_back(output_nr, i);
fn_info.second = is_leaf;