Mirror of https://github.com/pytorch/pytorch.git
Synced 2025-10-21 13:44:15 +08:00
Refactoring _wrap_outputs to remove python dependence.
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/22631

Test Plan: test suite

Imported from OSS

Differential Revision: D16185040

fbshipit-source-id: 9b83749f6c9cd05d13f54a3bb4801e263293252b
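Note on the target of this refactor: the diff below stops building `PyObject*`-keyed sets inside `python_function.cpp` and instead hands plain C++ data (`at::TensorImpl*` identity sets plus `Variable`s) to a `_wrap_outputs` overload declared in the newly included `torch/csrc/autograd/custom_function.h`. That declaration is not shown in this commit view; reconstructed from the call site in the `_wrap_outputs` hunk below, it plausibly looks like the following sketch (parameter names, the return type, and the `std::shared_ptr<Function>` type of `cdata` are assumptions, not quoted from the commit):

    // Hedged reconstruction from the call site below; the real declaration
    // lives in torch/csrc/autograd/custom_function.h. Every parameter is
    // Python-free, which is the point of the refactor.
    variable_list _wrap_outputs(
        const std::unordered_set<at::TensorImpl*>& inputs,             // forward inputs, by identity
        const std::unordered_set<at::TensorImpl*>& non_differentiable, // from mark_non_differentiable
        const std::unordered_set<at::TensorImpl*>& dirty_inputs,       // from mark_dirty
        const std::vector<Variable>& raw_outputs,                      // forward outputs as Variables
        const std::shared_ptr<Function>& cdata);                       // grad_fn node; may be null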
@@ -8,6 +8,7 @@
 #include <ATen/ATen.h>
 
 #include <torch/csrc/THP.h>
+#include <torch/csrc/autograd/custom_function.h>
 #include <torch/csrc/autograd/grad_mode.h>
 #include <torch/csrc/autograd/functions/accumulate_grad.h>
 #include <torch/csrc/autograd/functions/basic_ops.h>
@@ -314,23 +315,24 @@ using t2var_type = std::unordered_map<PyObject *, THPVariable *>;
 
 // Bump the counters of all recorded dirty input tensors, adding each of them
 // into dirty_inputs. Also does some sanity checking.
-static std::vector<PyObject*> _mark_dirty(THPFunction *self)
+static std::unordered_set<at::TensorImpl*> _mark_dirty(THPFunction *self)
 {
   // Increase versions of modified tensors
-  std::vector<PyObject*> dirty_inputs;
+  std::unordered_set<at::TensorImpl*> dirty_inputs;
   if (!self->dirty_tensors) return dirty_inputs;
 
   THPFunction_assert(PyTuple_Check(self->dirty_tensors), "autograd "
       "internal error: dirty_tensors attribute is expected to be a tuple "
       "but is %s", THPUtils_typename(self->dirty_tensors));
   Py_ssize_t num_dirty = PyTuple_GET_SIZE(self->dirty_tensors);
+  dirty_inputs.reserve(num_dirty);
   for (int i = 0; i < num_dirty; i++) {
     PyObject *obj = PyTuple_GET_ITEM(self->dirty_tensors, i);
     THPFunction_assert(THPVariable_Check(obj), "mark_dirty can "
         "only accept variables, but argument %d is of type %s", i,
         THPUtils_typename(obj));
 
-    dirty_inputs.push_back(obj);
+    dirty_inputs.insert(((THPVariable*)obj)->cdata.unsafeGetTensorImpl());
     auto variable = (THPVariable*)obj;
     variable->cdata.bump_version();
   }
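`_mark_dirty` now keys dirty tensors by `at::TensorImpl*` rather than by the `PyObject*` wrapper. The `TensorImpl*` is the stable C++ identity of a tensor, so membership tests still work when different handles (or different Python wrappers) refer to the same tensor; the raw pointers are non-owning and are only consulted while the forward's tensors are alive. A minimal standalone illustration of that identity property, using only public ATen calls (this snippet is not part of the commit):

    #include <ATen/ATen.h>
    #include <unordered_set>

    int main() {
      at::Tensor a = at::ones({2, 2});
      at::Tensor b = a;  // second handle to the same TensorImpl

      std::unordered_set<at::TensorImpl*> seen;
      seen.insert(a.unsafeGetTensorImpl());

      // Membership is by tensor identity, independent of which handle
      // (or which Python wrapper) we happen to hold.
      bool found = seen.count(b.unsafeGetTensorImpl()) > 0;  // true
      return found ? 0 : 1;
    }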
@@ -339,7 +341,7 @@ static std::vector<PyObject*> _mark_dirty(THPFunction *self)
   return dirty_inputs;
 }
 
-static std::unordered_set<PyObject*> _parse_non_differentiable(THPFunction *self);
+static std::unordered_set<at::TensorImpl*> _parse_non_differentiable(THPFunction *self);
 
 // Given a Python tuple of raw output tensors (raw_output), set each of
 // the corresponding entries in a different Python tuple (outputs) with
@@ -362,15 +364,6 @@ static void _wrap_outputs(THPFunction *self,
     self->output_info.reserve(num_outputs);
   }
 
-  std::unordered_set<PyObject*> inputs;
-  int num_inputs = PyTuple_GET_SIZE(inputs_tuple);
-  for (int i = 0; i < num_inputs; i++) {
-    inputs.emplace(PyTuple_GET_ITEM(inputs_tuple, i));
-  }
-
-  auto non_differentiable = _parse_non_differentiable(self);
-  auto dirty_inputs = _mark_dirty(self);
-
   auto as_variable = [&](PyObject* obj, int i) -> Variable {
     if (THPVariable_Check(obj)) {
       return ((THPVariable*)obj)->cdata;
@@ -379,72 +372,31 @@ static void _wrap_outputs(THPFunction *self,
         Py_TYPE(self)->tp_name, Py_TYPE(obj)->tp_name, i);
   };
 
-  // Sets the grad_fn and output_nr of an output Variable.
-  auto set_history = [&](Variable& var, uint32_t output_nr, bool is_input, bool is_modified,
-                         bool is_differentiable) {
-    if (!is_differentiable) {
-      if (!var.requires_grad()) {
-        return;
-      }
-      // NB: we don't support returning non-differentiable views that could require grad
-      if (var.is_view()) {
-        throw std::runtime_error("Returning Variables sharing storage with other Variables "
-                                 "that require grad is not supported in Python functions. "
-                                 "Please submit a feature request if you hit this error.");
-      }
-      // Return detached aliases of inputs, instead of changing their requires_grad
-      // property.
-      if (is_input) {
-        var = var.detach();
-      } else {
-        var.detach_();
-      }
-    } else if (is_modified) {
-      if (var.is_leaf() && var.requires_grad()) {
-        throw std::runtime_error("a leaf Variable that requires grad has been used in an in-place operation.");
-      }
-      // If the input was modified, transplant the grad_fn in the graph:
-      // grad_fn <- variable <- self  ==>  grad_fn <- self <- variable
-      var.grad().reset();
-      var.clear_hooks();
-      if (auto grad_acc_fn = var.try_get_grad_accumulator()) {
-        auto grad_acc = dynamic_cast<AccumulateGrad*>(grad_acc_fn.get());
-        grad_acc->variable.reset();
-      }
-      if (cdata) {
-        var.rebase_history({cdata, output_nr});
-      }
-    } else if (is_input) {
-      // An input has been returned, but it wasn't modified. Return it as a view
-      // so that we can attach a new grad_fn to the Variable.
-      var = var.view_as(var);
-      var.set_gradient_edge({cdata, output_nr});
-    } else if (cdata) {
-      var.set_gradient_edge({cdata, output_nr});
-    }
-  };
-
-  for (int i = 0; i < num_outputs; i++) {
-    PyObject* obj = PyTuple_GET_ITEM(raw_output, i);
-
-    bool is_input = inputs.count(obj) > 0;
-    bool is_modified = std::find(dirty_inputs.begin(), dirty_inputs.end(), obj) != dirty_inputs.end();
-    bool is_differentiable = is_executable && non_differentiable.count(obj) == 0;
-
-    // Note that output Variables may be repeated. In that case, the last call
-    // to set_history wins.
-    auto var = as_variable(obj, i);
-    if (cdata) {
-      auto output_nr = cdata->add_input_metadata(var);
-      AT_ASSERT(i == (int)output_nr);
-    }
-    set_history(var, i, is_input, is_modified, is_differentiable);
+  std::unordered_set<at::TensorImpl*> inputs;
+  int num_inputs = PyTuple_GET_SIZE(inputs_tuple);
+  for (int i = 0; i < num_inputs; i++) {
+    PyObject* obj = PyTuple_GET_ITEM(inputs_tuple, i);
+    if (THPVariable_Check(obj)) {
+      inputs.emplace(((THPVariable*)obj)->cdata.unsafeGetTensorImpl());
+    }
+  }
 
+  auto non_differentiable = _parse_non_differentiable(self);
+  auto dirty_inputs = _mark_dirty(self);
+
+  std::vector<Variable> raw_output_vars;
+  raw_output_vars.reserve(num_outputs);
+  for(int i = 0; i < num_outputs; ++i){
+    PyObject* obj = PyTuple_GET_ITEM(raw_output, i);
+    raw_output_vars.push_back(as_variable(obj,i));
+  }
+
+  auto wrapped_outputs = _wrap_outputs(inputs, non_differentiable, dirty_inputs, raw_output_vars, cdata);
+  for (int i = 0; i < num_outputs; i++) {
     if (is_executable) {
-      self->output_info.emplace_back(var);
+      self->output_info.emplace_back(wrapped_outputs[i]);
     }
-
-    PyTuple_SET_ITEM(outputs, i, THPVariable_Wrap(var));
+    PyTuple_SET_ITEM(outputs, i, THPVariable_Wrap(wrapped_outputs[i]));
   }
 }
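The `set_history` machinery deleted in this hunk is not dropped; this commit moves it behind the shared C++ `_wrap_outputs` in `torch/csrc/autograd/custom_function.cpp`, which is not part of this view. A rough sketch of how that side would consume the identity sets built here, mirroring the deleted Python-side loop (the structure, names, and the use of `cdata` in place of the old `is_executable` flag are assumptions):

    // Sketch only, assuming a set_history helper with the same rules as the
    // deleted lambda above; the real implementation is in custom_function.cpp.
    variable_list _wrap_outputs(
        const std::unordered_set<at::TensorImpl*>& inputs,
        const std::unordered_set<at::TensorImpl*>& non_differentiable,
        const std::unordered_set<at::TensorImpl*>& dirty_inputs,
        const std::vector<Variable>& raw_outputs,
        const std::shared_ptr<Function>& cdata) {
      int num_outputs = raw_outputs.size();
      variable_list outputs;
      outputs.reserve(num_outputs);
      for (int i = 0; i < num_outputs; ++i) {
        Variable var = raw_outputs[i];
        auto impl = var.unsafeGetTensorImpl();
        // Classify each output by tensor identity instead of PyObject identity.
        bool is_input = inputs.count(impl) > 0;
        bool is_modified = dirty_inputs.count(impl) > 0;
        bool is_differentiable = cdata && non_differentiable.count(impl) == 0;
        if (cdata) {
          auto output_nr = cdata->add_input_metadata(var);
          AT_ASSERT(i == (int)output_nr);
        }
        set_history(var, i, is_input, is_modified, is_differentiable);
        outputs.push_back(var);
      }
      return outputs;
    }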
@@ -480,10 +432,10 @@ static void _save_variables(THPFunction* self)
 }
 
 // Mark requires_grad = 0 on non-differentiable variables (as per non_differentiable)
-static std::unordered_set<PyObject*>
+static std::unordered_set<at::TensorImpl*>
 _parse_non_differentiable(THPFunction *self)
 {
-  std::unordered_set<PyObject*> set;
+  std::unordered_set<at::TensorImpl*> set;
   if (!self->non_differentiable) return set;
 
   THPFunction_assert(PyTuple_Check(self->non_differentiable), "autograd "
@@ -495,7 +447,7 @@ _parse_non_differentiable(THPFunction *self)
     PyObject *t = PyTuple_GET_ITEM(self->non_differentiable, i);
     THPFunction_assert(THPVariable_Check(t), "mark_non_differentiable "
        "only accepts variable arguments, but got %s", THPUtils_typename(t));
-    set.insert(t);
+    set.insert(((THPVariable*)t)->cdata.unsafeGetTensorImpl());
   }
   Py_CLEAR(self->non_differentiable);
   return set;
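What removing the Python dependence buys: any C++ caller can now reuse the same output-wrapping logic without holding the GIL or touching `PyObject*`. A hypothetical caller (the function and its names are illustrative only; no such helper exists in this commit) might look like:

    // Hypothetical C++ caller: builds the identity sets itself and never
    // touches a PyObject. Shown only to illustrate the new seam.
    variable_list wrap_for_cpp(
        const std::vector<Variable>& input_vars,
        const std::vector<Variable>& raw_outputs,
        const std::shared_ptr<Function>& grad_fn) {
      std::unordered_set<at::TensorImpl*> inputs;
      for (const auto& v : input_vars) {
        inputs.insert(v.unsafeGetTensorImpl());
      }
      // Nothing marked dirty or non-differentiable in this sketch.
      std::unordered_set<at::TensorImpl*> non_differentiable;
      std::unordered_set<at::TensorImpl*> dirty_inputs;
      return _wrap_outputs(inputs, non_differentiable, dirty_inputs,
                           raw_outputs, grad_fn);
    }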