mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
Pass variable_list of inputs to _wrap_outputs
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/23037 Test Plan: Imported from OSS Differential Revision: D16380071 fbshipit-source-id: ae3333c02ef8a3c09b95bec7b8e92ce649553615
This commit is contained in:
@ -355,7 +355,7 @@ static std::unordered_set<at::TensorImpl*> _parse_non_differentiable(THPFunction
|
||||
// do in this case. After this method is run, t2var is extended with
|
||||
// mappings for output tensors as well.
|
||||
static void _wrap_outputs(THPFunction *self,
|
||||
PyObject* inputs_tuple, PyObject *raw_output, PyObject *outputs, bool is_executable)
|
||||
const variable_list &input_vars, PyObject *raw_output, PyObject *outputs, bool is_executable)
|
||||
{
|
||||
auto cdata = is_executable ? THPFunction_asFunction(self) : nullptr;
|
||||
Py_ssize_t num_outputs = PyTuple_GET_SIZE(raw_output);
|
||||
@ -372,15 +372,6 @@ static void _wrap_outputs(THPFunction *self,
|
||||
Py_TYPE(self)->tp_name, Py_TYPE(obj)->tp_name, i);
|
||||
};
|
||||
|
||||
std::unordered_set<at::TensorImpl*> inputs;
|
||||
int num_inputs = PyTuple_GET_SIZE(inputs_tuple);
|
||||
for (int i = 0; i < num_inputs; i++) {
|
||||
PyObject* obj = PyTuple_GET_ITEM(inputs_tuple, i);
|
||||
if (THPVariable_Check(obj)) {
|
||||
inputs.emplace(((THPVariable*)obj)->cdata.unsafeGetTensorImpl());
|
||||
}
|
||||
}
|
||||
|
||||
auto non_differentiable = _parse_non_differentiable(self);
|
||||
auto dirty_inputs = _mark_dirty(self);
|
||||
|
||||
@ -391,7 +382,7 @@ static void _wrap_outputs(THPFunction *self,
|
||||
raw_output_vars.push_back(as_variable(obj,i));
|
||||
}
|
||||
|
||||
auto wrapped_outputs = _wrap_outputs(inputs, non_differentiable, dirty_inputs, raw_output_vars, cdata);
|
||||
auto wrapped_outputs = _wrap_outputs(input_vars, non_differentiable, dirty_inputs, raw_output_vars, cdata);
|
||||
for (int i = 0; i < num_outputs; i++) {
|
||||
if (is_executable) {
|
||||
self->output_info.emplace_back(wrapped_outputs[i]);
|
||||
@ -600,7 +591,7 @@ PyObject* process_outputs(PyObject *op_obj, THPFunction* grad_fn, const Unpacked
|
||||
}
|
||||
|
||||
bool is_inplace = static_cast<bool>(grad_fn->dirty_tensors);
|
||||
_wrap_outputs(grad_fn, inputs, raw_output, outputs, is_executable);
|
||||
_wrap_outputs(grad_fn, unpacked.input_vars, raw_output, outputs, is_executable);
|
||||
_trace_post_record(node, op_obj, unpacked.input_vars, outputs, is_inplace, unpack_output);
|
||||
if (is_executable) {
|
||||
_save_variables(grad_fn);
|
||||
|
Reference in New Issue
Block a user