[BE][autograd Function] Raise an error if input is returned as-is and saved for forward or backward in setup_context (#97212)

Fixes https://github.com/pytorch/pytorch/issues/96887

We now error out in BOTH cases: when the graph is created and when it is not.

This is still BC-breaking, but less severe because the change is limited to the case where someone uses setup_context.

This makes the setup_context and non-setup_context versions diverge in behavior (see the sketch below):
- With the non-setup_context version, saved variables are assumed to have the grad_fn of the inputs.
- With the setup_context version, we now raise an error for this case.
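A minimal sketch of the now-erroring pattern (`MyIdentity` is an illustrative name, and the exact error message may differ):

```python
import torch

class MyIdentity(torch.autograd.Function):
    @staticmethod
    def forward(x):
        # The input is returned as-is, so the output shares its TensorImpl.
        return x

    @staticmethod
    def setup_context(ctx, inputs, output):
        # Saving a tensor that is both an input and an as-is output is
        # ambiguous (input's grad_fn vs. this node), so it now errors.
        ctx.save_for_backward(output)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

x = torch.randn(3, requires_grad=True)
y = MyIdentity.apply(x)  # now raises, whether or not a graph is created
```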

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97212
Approved by: https://github.com/zou3519
Author: soulitzer
Date: 2023-03-27 18:20:43 -04:00
Committed by: PyTorch MergeBot
Parent: c7fa648ea1
Commit: f3aca45a16
5 changed files with 216 additions and 33 deletions

torch/csrc/autograd/python_function.cpp

@@ -366,7 +366,8 @@ static void _wrap_outputs(
     const variable_list& input_vars,
     PyObject* raw_output,
     PyObject* outputs,
-    bool is_executable) {
+    bool is_executable,
+    const std::unordered_set<at::TensorImpl*>& to_save_if_setup_context) {
   auto cdata_if_executable = is_executable ? cdata : nullptr;
   Py_ssize_t num_outputs = PyTuple_GET_SIZE(raw_output);
   if (is_executable) {
@@ -462,7 +463,8 @@ static void _wrap_outputs(
       dirty_inputs,
       raw_output_vars,
       cdata_if_executable,
-      std::move(jvp_user_function));
+      std::move(jvp_user_function),
+      to_save_if_setup_context);
 
   for (const auto i : c10::irange(num_outputs)) {
     PyObject* obj = PyTuple_GetItem(raw_output, i);
@@ -482,36 +484,77 @@ static void _wrap_outputs(
   }
 }
 
+static void _get_tensors_to_save(
+    THPFunction* self,
+    std::unordered_set<at::TensorImpl*>& to_save_if_setup_context,
+    std::vector<c10::optional<at::Tensor>>& tensors_to_save,
+    bool overridden_setup_context,
+    bool is_executable) {
+  if (self->saved_for_forward && overridden_setup_context) {
+    // We look at saved_for_forward here purely for the purpose of populating
+    // to_save_if_setup_context, the actual saving is not done here.
+    THPFunction_assert(
+        PyTuple_Check(self->saved_for_forward),
+        "autograd internal "
+        "error: saved_for_forward attribute is expected to be a tuple but is %s",
+        THPUtils_typename(self->saved_for_forward));
+    Py_ssize_t num_saved_for_forward =
+        PyTuple_GET_SIZE(self->saved_for_forward);
+    for (const auto i : c10::irange(num_saved_for_forward)) {
+      PyObject* obj = PyTuple_GET_ITEM(self->saved_for_forward, i);
+      if (THPVariable_Check(obj)) {
+        const auto& tensor = THPVariable_Unpack(obj);
+        to_save_if_setup_context.insert(tensor.unsafeGetTensorImpl());
+      }
+    }
+  }
+  if (self->to_save) {
+    THPFunction_assert(
+        PyTuple_Check(self->to_save),
+        "autograd internal "
+        "error: to_save attribute is expected to be a tuple but is %s",
+        THPUtils_typename(self->to_save));
+    Py_ssize_t num_saved = PyTuple_GET_SIZE(self->to_save);
+    for (const auto i : c10::irange(num_saved)) {
+      PyObject* obj = PyTuple_GET_ITEM(self->to_save, i);
+      if (obj == Py_None) {
+        tensors_to_save.push_back(c10::nullopt);
+        continue;
+      } else if (THPVariable_Check(obj)) {
+        const auto& tensor = THPVariable_Unpack(obj);
+        if (overridden_setup_context) {
+          to_save_if_setup_context.insert(tensor.unsafeGetTensorImpl());
+        }
+        if (is_executable) {
+          tensors_to_save.push_back(tensor);
+        }
+      } else {
+        throw torch::TypeError(
+            "save_for_backward can only save variables, but argument %ld is of "
+            "type %s",
+            i,
+            Py_TYPE(obj)->tp_name);
+      }
+    }
+  }
+}
+
 // Save any variables that requested by to_save
 static void _save_variables(
+    const std::vector<c10::optional<at::Tensor>>& tensors_to_save,
     const std::shared_ptr<PyNode>& cdata_ptr,
     THPFunction* self) {
   if (!self->to_save)
     return;
-
-  THPFunction_assert(
-      PyTuple_Check(self->to_save),
-      "autograd internal "
-      "error: to_save attribute is expected to be a tuple but is %s",
-      THPUtils_typename(self->to_save));
-  Py_ssize_t num_saved = PyTuple_GET_SIZE(self->to_save);
+  size_t num_saved = tensors_to_save.size();
   self->saved_variables.clear();
   self->saved_variables.reserve(num_saved);
-  for (const auto i : c10::irange(num_saved)) {
-    PyObject* obj = PyTuple_GET_ITEM(self->to_save, i);
-    if (obj == Py_None) {
+  for (const auto& opt_tensor : tensors_to_save) {
+    if (!opt_tensor.has_value()) {
       self->saved_variables.emplace_back();
-      continue;
-    } else if (THPVariable_Check(obj)) {
-      const auto& tensor = THPVariable_Unpack(obj);
-      bool is_output = tensor.grad_fn().get() == cdata_ptr.get();
-      self->saved_variables.emplace_back(tensor, is_output);
     } else {
-      throw torch::TypeError(
-          "save_for_backward can only save variables, but argument %ld is of "
-          "type %s",
-          i,
-          Py_TYPE(obj)->tp_name);
+      bool is_output = opt_tensor.value().grad_fn().get() == cdata_ptr.get();
+      self->saved_variables.emplace_back(opt_tensor.value(), is_output);
     }
   }
   // Free .to_save
@@ -757,7 +800,8 @@ PyObject* process_outputs(
     PyObject* inputs,
     THPObjectPtr&& raw_output,
     bool is_executable,
-    torch::jit::Node* node) {
+    torch::jit::Node* node,
+    bool overridden_setup_context) {
   bool unpack_output = ensure_tuple(raw_output);
 
   auto num_outputs = PyTuple_GET_SIZE(raw_output.get());
@@ -777,9 +821,24 @@ PyObject* process_outputs(
     }
   }
 
+  std::unordered_set<at::TensorImpl*> to_save_if_setup_context{};
+  std::vector<c10::optional<at::Tensor>> tensors_to_save{};
+  _get_tensors_to_save(
+      grad_fn,
+      to_save_if_setup_context,
+      tensors_to_save,
+      overridden_setup_context,
+      is_executable);
+
   bool is_inplace = static_cast<bool>(grad_fn->dirty_tensors);
   _wrap_outputs(
-      cdata, grad_fn, unpacked.input_vars, raw_output, outputs, is_executable);
+      cdata,
+      grad_fn,
+      unpacked.input_vars,
+      raw_output,
+      outputs,
+      is_executable,
+      to_save_if_setup_context);
 
   _trace_post_record(
       node, op_obj, unpacked.input_vars, outputs, is_inplace, unpack_output);
@@ -787,7 +846,7 @@ PyObject* process_outputs(
   // wrapping as the outputs must have their grad_fn/fw_grad properly set before
   // we save them.
   if (is_executable) {
-    _save_variables(cdata, grad_fn);
+    _save_variables(tensors_to_save, cdata, grad_fn);
   } else {
     // Remove unnecessary attributes
     Py_XDECREF(grad_fn->to_save);
@@ -1010,7 +1069,8 @@ PyObject* THPFunction_apply(PyObject* cls, PyObject* inputs) {
       inputs,
       std::move(output),
       is_executable,
-      node);
+      node,
+      overridden_setup_context);
   END_HANDLE_TH_ERRORS
 }
 
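For user code that hits the new error, a possible workaround (a sketch, not taken from this PR) is to return a view of the input so the saved tensor is unambiguously an output of the node:

```python
import torch

class MyIdentity(torch.autograd.Function):
    @staticmethod
    def forward(x):
        # Returning a view gives the output its own TensorImpl, so saving
        # it in setup_context is no longer ambiguous.
        return x.view_as(x)

    @staticmethod
    def setup_context(ctx, inputs, output):
        ctx.save_for_backward(output)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

x = torch.randn(3, requires_grad=True)
y = MyIdentity.apply(x)  # no error
y.sum().backward()
assert torch.equal(x.grad, torch.ones_like(x))
```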