Pass variable_list of inputs to _wrap_outputs

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/23037

Test Plan: Imported from OSS

Differential Revision: D16380071

fbshipit-source-id: ae3333c02ef8a3c09b95bec7b8e92ce649553615
This commit is contained in:
mal
2019-07-19 12:18:50 -07:00
committed by Facebook Github Bot
parent 2ee0f0bc3a
commit 44493a623e
3 changed files with 13 additions and 16 deletions

View File

@ -355,7 +355,7 @@ static std::unordered_set<at::TensorImpl*> _parse_non_differentiable(THPFunction
// do in this case. After this method is run, t2var is extended with
// mappings for output tensors as well.
static void _wrap_outputs(THPFunction *self,
PyObject* inputs_tuple, PyObject *raw_output, PyObject *outputs, bool is_executable)
const variable_list &input_vars, PyObject *raw_output, PyObject *outputs, bool is_executable)
{
auto cdata = is_executable ? THPFunction_asFunction(self) : nullptr;
Py_ssize_t num_outputs = PyTuple_GET_SIZE(raw_output);
@ -372,15 +372,6 @@ static void _wrap_outputs(THPFunction *self,
Py_TYPE(self)->tp_name, Py_TYPE(obj)->tp_name, i);
};
std::unordered_set<at::TensorImpl*> inputs;
int num_inputs = PyTuple_GET_SIZE(inputs_tuple);
for (int i = 0; i < num_inputs; i++) {
PyObject* obj = PyTuple_GET_ITEM(inputs_tuple, i);
if (THPVariable_Check(obj)) {
inputs.emplace(((THPVariable*)obj)->cdata.unsafeGetTensorImpl());
}
}
auto non_differentiable = _parse_non_differentiable(self);
auto dirty_inputs = _mark_dirty(self);
@ -391,7 +382,7 @@ static void _wrap_outputs(THPFunction *self,
raw_output_vars.push_back(as_variable(obj,i));
}
auto wrapped_outputs = _wrap_outputs(inputs, non_differentiable, dirty_inputs, raw_output_vars, cdata);
auto wrapped_outputs = _wrap_outputs(input_vars, non_differentiable, dirty_inputs, raw_output_vars, cdata);
for (int i = 0; i < num_outputs; i++) {
if (is_executable) {
self->output_info.emplace_back(wrapped_outputs[i]);
@ -600,7 +591,7 @@ PyObject* process_outputs(PyObject *op_obj, THPFunction* grad_fn, const Unpacked
}
bool is_inplace = static_cast<bool>(grad_fn->dirty_tensors);
_wrap_outputs(grad_fn, inputs, raw_output, outputs, is_executable);
_wrap_outputs(grad_fn, unpacked.input_vars, raw_output, outputs, is_executable);
_trace_post_record(node, op_obj, unpacked.input_vars, outputs, is_inplace, unpack_output);
if (is_executable) {
_save_variables(grad_fn);