Refactoring _wrap_outputs to remove Python dependence.

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/22631

Test Plan:
test suite

Imported from OSS

Differential Revision: D16185040

fbshipit-source-id: 9b83749f6c9cd05d13f54a3bb4801e263293252b
mal, 2019-07-10 11:58:30 -07:00, committed by Facebook Github Bot
parent ec1b669d23, commit 58e20638f7
5 changed files with 129 additions and 76 deletions

torch/csrc/autograd/python_function.cpp

@@ -8,6 +8,7 @@
 #include <ATen/ATen.h>
 #include <torch/csrc/THP.h>
+#include <torch/csrc/autograd/custom_function.h>
 #include <torch/csrc/autograd/grad_mode.h>
 #include <torch/csrc/autograd/functions/accumulate_grad.h>
 #include <torch/csrc/autograd/functions/basic_ops.h>
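Note: the newly included torch/csrc/autograd/custom_function.h is presumably where the Python-free wrapping helper lives. Inferring only from the call site later in this diff (_wrap_outputs(inputs, non_differentiable, dirty_inputs, raw_output_vars, cdata)), its declaration could look roughly like the sketch below; the return type, the const/reference qualifiers, and the exact type of cdata are assumptions, not copied from the PR.

    // Hedged sketch of a plausible declaration for the shared helper; inferred from
    // the call site in this diff, not copied from custom_function.h.
    #include <memory>
    #include <unordered_set>
    #include <vector>

    #include <ATen/ATen.h>                     // at::TensorImpl
    #include <torch/csrc/autograd/variable.h>  // torch::autograd::Variable
    #include <torch/csrc/autograd/function.h>  // autograd graph node type (assumed)

    namespace torch { namespace autograd {

    std::vector<Variable> _wrap_outputs(                               // return type is an assumption
        const std::unordered_set<at::TensorImpl*>& inputs,             // forward inputs, by TensorImpl identity
        const std::unordered_set<at::TensorImpl*>& non_differentiable, // tensors marked non-differentiable
        const std::unordered_set<at::TensorImpl*>& dirty_inputs,       // inputs modified in place
        const std::vector<Variable>& raw_outputs,                      // raw forward outputs to wrap
        const std::shared_ptr<Function>& cdata);                       // grad_fn to attach (type assumed)

    }} // namespace torch::autograd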
@@ -314,23 +315,24 @@ using t2var_type = std::unordered_map<PyObject *, THPVariable *>;
 // Bump the counters of all recorded dirty input tensors, adding each of them
 // into dirty_inputs. Also does some sanity checking.
-static std::vector<PyObject*> _mark_dirty(THPFunction *self)
+static std::unordered_set<at::TensorImpl*> _mark_dirty(THPFunction *self)
 {
   // Increase versions of modified tensors
-  std::vector<PyObject*> dirty_inputs;
+  std::unordered_set<at::TensorImpl*> dirty_inputs;
   if (!self->dirty_tensors) return dirty_inputs;
   THPFunction_assert(PyTuple_Check(self->dirty_tensors), "autograd "
       "internal error: dirty_tensors attribute is expected to be a tuple "
       "but is %s", THPUtils_typename(self->dirty_tensors));
   Py_ssize_t num_dirty = PyTuple_GET_SIZE(self->dirty_tensors);
+  dirty_inputs.reserve(num_dirty);
   for (int i = 0; i < num_dirty; i++) {
     PyObject *obj = PyTuple_GET_ITEM(self->dirty_tensors, i);
     THPFunction_assert(THPVariable_Check(obj), "mark_dirty can "
         "only accept variables, but argument %d is of type %s", i,
         THPUtils_typename(obj));
-    dirty_inputs.push_back(obj);
+    dirty_inputs.insert(((THPVariable*)obj)->cdata.unsafeGetTensorImpl());
     auto variable = (THPVariable*)obj;
     variable->cdata.bump_version();
   }
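Keying dirty_inputs on at::TensorImpl* rather than on the PyObject* wrapper is what removes the Python dependence here: every reference to the same tensor shares one TensorImpl no matter how many Python wrappers exist, while a view gets its own TensorImpl. A minimal standalone libtorch sketch (not part of the PR) of that identity property:

    // Minimal standalone sketch (libtorch, not part of this PR) showing that
    // at::TensorImpl* is a Python-independent identity key for a tensor.
    #include <cassert>
    #include <unordered_set>

    #include <torch/torch.h>

    int main() {
      torch::Tensor a = torch::ones({2, 2});
      torch::Tensor alias = a;           // same tensor object, same TensorImpl
      torch::Tensor view = a.view({4});  // shares storage but has its own TensorImpl

      std::unordered_set<at::TensorImpl*> dirty;
      dirty.insert(a.unsafeGetTensorImpl());

      assert(dirty.count(alias.unsafeGetTensorImpl()) == 1);  // alias is the same tensor
      assert(dirty.count(view.unsafeGetTensorImpl()) == 0);   // the view is a distinct tensor
      return 0;
    }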
@@ -339,7 +341,7 @@ static std::vector<PyObject*> _mark_dirty(THPFunction *self)
   return dirty_inputs;
 }
-static std::unordered_set<PyObject*> _parse_non_differentiable(THPFunction *self);
+static std::unordered_set<at::TensorImpl*> _parse_non_differentiable(THPFunction *self);
 // Given a Python tuple of raw output tensors (raw_output), set each of
 // the corresponding entries in a different Python tuple (outputs) with
@@ -362,15 +364,6 @@ static void _wrap_outputs(THPFunction *self,
     self->output_info.reserve(num_outputs);
   }
-  std::unordered_set<PyObject*> inputs;
-  int num_inputs = PyTuple_GET_SIZE(inputs_tuple);
-  for (int i = 0; i < num_inputs; i++) {
-    inputs.emplace(PyTuple_GET_ITEM(inputs_tuple, i));
-  }
-  auto non_differentiable = _parse_non_differentiable(self);
-  auto dirty_inputs = _mark_dirty(self);
   auto as_variable = [&](PyObject* obj, int i) -> Variable {
     if (THPVariable_Check(obj)) {
       return ((THPVariable*)obj)->cdata;
@@ -379,72 +372,31 @@ static void _wrap_outputs(THPFunction *self,
         Py_TYPE(self)->tp_name, Py_TYPE(obj)->tp_name, i);
   };
-  // Sets the grad_fn and output_nr of an output Variable.
-  auto set_history = [&](Variable& var, uint32_t output_nr, bool is_input, bool is_modified,
-                         bool is_differentiable) {
-    if (!is_differentiable) {
-      if (!var.requires_grad()) {
-        return;
-      }
-      // NB: we don't support returning non-differentiable views that could require grad
-      if (var.is_view()) {
-        throw std::runtime_error("Returning Variables sharing storage with other Variables "
-            "that require grad is not supported in Python functions. "
-            "Please submit a feature request if you hit this error.");
-      }
-      // Return detached aliases of inputs, instead of changing their requires_grad
-      // property.
-      if (is_input) {
-        var = var.detach();
-      } else {
-        var.detach_();
-      }
-    } else if (is_modified) {
-      if (var.is_leaf() && var.requires_grad()) {
-        throw std::runtime_error("a leaf Variable that requires grad has been used in an in-place operation.");
-      }
-      // If the input was modified, transplant the grad_fn in the graph:
-      // grad_fn <- variable <- self ==> grad_fn <- self <- variable
-      var.grad().reset();
-      var.clear_hooks();
-      if (auto grad_acc_fn = var.try_get_grad_accumulator()) {
-        auto grad_acc = dynamic_cast<AccumulateGrad*>(grad_acc_fn.get());
-        grad_acc->variable.reset();
-      }
-      if (cdata) {
-        var.rebase_history({cdata, output_nr});
-      }
-    } else if (is_input) {
-      // An input has been returned, but it wasn't modified. Return it as a view
-      // so that we can attach a new grad_fn to the Variable.
-      var = var.view_as(var);
-      var.set_gradient_edge({cdata, output_nr});
-    } else if (cdata) {
-      var.set_gradient_edge({cdata, output_nr});
+  std::unordered_set<at::TensorImpl*> inputs;
+  int num_inputs = PyTuple_GET_SIZE(inputs_tuple);
+  for (int i = 0; i < num_inputs; i++) {
+    PyObject* obj = PyTuple_GET_ITEM(inputs_tuple, i);
+    if (THPVariable_Check(obj)) {
+      inputs.emplace(((THPVariable*)obj)->cdata.unsafeGetTensorImpl());
     }
-  };
-  for (int i = 0; i < num_outputs; i++) {
+  }
+  auto non_differentiable = _parse_non_differentiable(self);
+  auto dirty_inputs = _mark_dirty(self);
+  std::vector<Variable> raw_output_vars;
+  raw_output_vars.reserve(num_outputs);
+  for(int i = 0; i < num_outputs; ++i){
     PyObject* obj = PyTuple_GET_ITEM(raw_output, i);
-    bool is_input = inputs.count(obj) > 0;
-    bool is_modified = std::find(dirty_inputs.begin(), dirty_inputs.end(), obj) != dirty_inputs.end();
-    bool is_differentiable = is_executable && non_differentiable.count(obj) == 0;
-    // Note that output Variables may be repeated. In that case, the last call
-    // to set_history wins.
-    auto var = as_variable(obj, i);
-    if (cdata) {
-      auto output_nr = cdata->add_input_metadata(var);
-      AT_ASSERT(i == (int)output_nr);
-    }
-    set_history(var, i, is_input, is_modified, is_differentiable);
+    raw_output_vars.push_back(as_variable(obj,i));
+  }
+  auto wrapped_outputs = _wrap_outputs(inputs, non_differentiable, dirty_inputs, raw_output_vars, cdata);
+  for (int i = 0; i < num_outputs; i++) {
     if (is_executable) {
-      self->output_info.emplace_back(var);
+      self->output_info.emplace_back(wrapped_outputs[i]);
     }
-    PyTuple_SET_ITEM(outputs, i, THPVariable_Wrap(var));
+    PyTuple_SET_ITEM(outputs, i, THPVariable_Wrap(wrapped_outputs[i]));
   }
 }
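With set_history removed from this file, the per-output bookkeeping that used PyObject* identity (is_input / is_modified / is_differentiable) presumably moves behind the shared _wrap_outputs call, driven by the three TensorImpl* sets built above. A simplified sketch of that classification step, written against the public libtorch API rather than the PR's internals (OutputFlags and classify_output are hypothetical names):

    // Hedged, simplified sketch of the per-output classification formerly done here
    // with PyObject* identity; names (OutputFlags, classify_output) are hypothetical.
    #include <unordered_set>

    #include <torch/torch.h>

    struct OutputFlags {
      bool is_input;           // the output is one of the forward inputs
      bool is_modified;        // that input was marked dirty (modified in place)
      bool is_differentiable;  // not explicitly marked non-differentiable
    };

    OutputFlags classify_output(
        const torch::Tensor& out,
        const std::unordered_set<at::TensorImpl*>& inputs,
        const std::unordered_set<at::TensorImpl*>& non_differentiable,
        const std::unordered_set<at::TensorImpl*>& dirty_inputs,
        bool is_executable) {
      at::TensorImpl* impl = out.unsafeGetTensorImpl();
      return OutputFlags{
          inputs.count(impl) > 0,
          dirty_inputs.count(impl) > 0,
          is_executable && non_differentiable.count(impl) == 0};
    }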
@@ -480,10 +432,10 @@ static void _save_variables(THPFunction* self)
 }
 // Mark requires_grad = 0 on non-differentiable variables (as per non_differentiable)
-static std::unordered_set<PyObject*>
+static std::unordered_set<at::TensorImpl*>
 _parse_non_differentiable(THPFunction *self)
 {
-  std::unordered_set<PyObject*> set;
+  std::unordered_set<at::TensorImpl*> set;
   if (!self->non_differentiable) return set;
   THPFunction_assert(PyTuple_Check(self->non_differentiable), "autograd "
@@ -495,7 +447,7 @@ _parse_non_differentiable(THPFunction *self)
     PyObject *t = PyTuple_GET_ITEM(self->non_differentiable, i);
     THPFunction_assert(THPVariable_Check(t), "mark_non_differentiable "
         "only accepts variable arguments, but got %s", THPUtils_typename(t));
-    set.insert(t);
+    set.insert(((THPVariable*)t)->cdata.unsafeGetTensorImpl());
   }
   Py_CLEAR(self->non_differentiable);
   return set;