Mirror of https://github.com/pytorch/pytorch.git
Synced 2025-10-29 03:04:55 +08:00
[BE][autograd Function] Raise an error if input is returned as-is and saved for forward or backward in setup_context (#97212)
Fixes https://github.com/pytorch/pytorch/issues/96887

We now error out in BOTH cases: when the graph is created and when it is not. This is still BC-breaking, but less severe, because it is limited to the case where someone uses setup_context. It makes the setup_context and non-setup_context versions diverge in behavior:

- With the non-setup_context version, saved variables are assumed to have the grad_fn of the inputs.
- With the setup_context version, we now raise an error for this case.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97212
Approved by: https://github.com/zou3519
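For context, here is a minimal sketch (class and variable names are illustrative, not taken from the patch) of the setup_context pattern this change starts rejecting:

```python
import torch

class MyIdentity(torch.autograd.Function):
    @staticmethod
    def forward(x):
        # The input tensor is returned as-is: no clone, no view.
        return x

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, = inputs
        # Saving an input that forward() returned as-is. After this
        # change, autograd raises an error for this, whether or not
        # a graph is being created.
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

x = torch.randn(3, requires_grad=True)
y = MyIdentity.apply(x)  # expected to raise after this change
```

A typical way to keep such a function working is to return a view instead of the input itself, e.g. `return x.view_as(x)` in forward, so that the saved tensor is a proper output of the autograd node.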
Committed by: PyTorch MergeBot
Parent: c7fa648ea1
Commit: f3aca45a16
torch/csrc/autograd/python_function.cpp
@@ -366,7 +366,8 @@ static void _wrap_outputs(
     const variable_list& input_vars,
     PyObject* raw_output,
     PyObject* outputs,
-    bool is_executable) {
+    bool is_executable,
+    const std::unordered_set<at::TensorImpl*>& to_save_if_setup_context) {
   auto cdata_if_executable = is_executable ? cdata : nullptr;
   Py_ssize_t num_outputs = PyTuple_GET_SIZE(raw_output);
   if (is_executable) {
@@ -462,7 +463,8 @@ static void _wrap_outputs(
       dirty_inputs,
       raw_output_vars,
       cdata_if_executable,
-      std::move(jvp_user_function));
+      std::move(jvp_user_function),
+      to_save_if_setup_context);

   for (const auto i : c10::irange(num_outputs)) {
     PyObject* obj = PyTuple_GetItem(raw_output, i);
@@ -482,36 +484,77 @@ static void _wrap_outputs(
   }
 }

+static void _get_tensors_to_save(
+    THPFunction* self,
+    std::unordered_set<at::TensorImpl*>& to_save_if_setup_context,
+    std::vector<c10::optional<at::Tensor>>& tensors_to_save,
+    bool overridden_setup_context,
+    bool is_executable) {
+  if (self->saved_for_forward && overridden_setup_context) {
+    // We look at saved_for_forward here purely for the purpose of populating
+    // to_save_if_setup_context, the actual saving is not done here.
+    THPFunction_assert(
+        PyTuple_Check(self->saved_for_forward),
+        "autograd internal "
+        "error: saved_for_forward attribute is expected to be a tuple but is %s",
+        THPUtils_typename(self->saved_for_forward));
+    Py_ssize_t num_saved_for_forward =
+        PyTuple_GET_SIZE(self->saved_for_forward);
+    for (const auto i : c10::irange(num_saved_for_forward)) {
+      PyObject* obj = PyTuple_GET_ITEM(self->saved_for_forward, i);
+      if (THPVariable_Check(obj)) {
+        const auto& tensor = THPVariable_Unpack(obj);
+        to_save_if_setup_context.insert(tensor.unsafeGetTensorImpl());
+      }
+    }
+  }
+  if (self->to_save) {
+    THPFunction_assert(
+        PyTuple_Check(self->to_save),
+        "autograd internal "
+        "error: to_save attribute is expected to be a tuple but is %s",
+        THPUtils_typename(self->to_save));

+    Py_ssize_t num_saved = PyTuple_GET_SIZE(self->to_save);
+    for (const auto i : c10::irange(num_saved)) {
+      PyObject* obj = PyTuple_GET_ITEM(self->to_save, i);
+      if (obj == Py_None) {
+        tensors_to_save.push_back(c10::nullopt);
+        continue;
+      } else if (THPVariable_Check(obj)) {
+        const auto& tensor = THPVariable_Unpack(obj);
+        if (overridden_setup_context) {
+          to_save_if_setup_context.insert(tensor.unsafeGetTensorImpl());
+        }
+        if (is_executable) {
+          tensors_to_save.push_back(tensor);
+        }
+      } else {
+        throw torch::TypeError(
+            "save_for_backward can only save variables, but argument %ld is of "
+            "type %s",
+            i,
+            Py_TYPE(obj)->tp_name);
+      }
+    }
+  }
+}
 // Save any variables that requested by to_save
 static void _save_variables(
+    const std::vector<c10::optional<at::Tensor>>& tensors_to_save,
     const std::shared_ptr<PyNode>& cdata_ptr,
     THPFunction* self) {
   if (!self->to_save)
     return;

-  THPFunction_assert(
-      PyTuple_Check(self->to_save),
-      "autograd internal "
-      "error: to_save attribute is expected to be a tuple but is %s",
-      THPUtils_typename(self->to_save));
-  Py_ssize_t num_saved = PyTuple_GET_SIZE(self->to_save);
+  size_t num_saved = tensors_to_save.size();
   self->saved_variables.clear();
   self->saved_variables.reserve(num_saved);
-  for (const auto i : c10::irange(num_saved)) {
-    PyObject* obj = PyTuple_GET_ITEM(self->to_save, i);
-    if (obj == Py_None) {
+  for (const auto& opt_tensor : tensors_to_save) {
+    if (!opt_tensor.has_value()) {
       self->saved_variables.emplace_back();
-      continue;
-    } else if (THPVariable_Check(obj)) {
-      const auto& tensor = THPVariable_Unpack(obj);
-      bool is_output = tensor.grad_fn().get() == cdata_ptr.get();
-      self->saved_variables.emplace_back(tensor, is_output);
     } else {
-      throw torch::TypeError(
-          "save_for_backward can only save variables, but argument %ld is of "
-          "type %s",
-          i,
-          Py_TYPE(obj)->tp_name);
+      bool is_output = opt_tensor.value().grad_fn().get() == cdata_ptr.get();
+      self->saved_variables.emplace_back(opt_tensor.value(), is_output);
     }
   }
   // Free .to_save
@@ -757,7 +800,8 @@ PyObject* process_outputs(
     PyObject* inputs,
     THPObjectPtr&& raw_output,
     bool is_executable,
-    torch::jit::Node* node) {
+    torch::jit::Node* node,
+    bool overridden_setup_context) {
   bool unpack_output = ensure_tuple(raw_output);

   auto num_outputs = PyTuple_GET_SIZE(raw_output.get());
@@ -777,9 +821,24 @@ PyObject* process_outputs(
     }
   }

+  std::unordered_set<at::TensorImpl*> to_save_if_setup_context{};
+  std::vector<c10::optional<at::Tensor>> tensors_to_save{};
+  _get_tensors_to_save(
+      grad_fn,
+      to_save_if_setup_context,
+      tensors_to_save,
+      overridden_setup_context,
+      is_executable);
+
   bool is_inplace = static_cast<bool>(grad_fn->dirty_tensors);
   _wrap_outputs(
-      cdata, grad_fn, unpacked.input_vars, raw_output, outputs, is_executable);
+      cdata,
+      grad_fn,
+      unpacked.input_vars,
+      raw_output,
+      outputs,
+      is_executable,
+      to_save_if_setup_context);
   _trace_post_record(
       node, op_obj, unpacked.input_vars, outputs, is_inplace, unpack_output);
@@ -787,7 +846,7 @@ PyObject* process_outputs(
   // wrapping as the outputs must have their grad_fn/fw_grad properly set before
   // we save them.
   if (is_executable) {
-    _save_variables(cdata, grad_fn);
+    _save_variables(tensors_to_save, cdata, grad_fn);
   } else {
     // Remove unnecessary attributes
     Py_XDECREF(grad_fn->to_save);
@@ -1010,7 +1069,8 @@ PyObject* THPFunction_apply(PyObject* cls, PyObject* inputs) {
       inputs,
       std::move(output),
       is_executable,
-      node);
+      node,
+      overridden_setup_context);
   END_HANDLE_TH_ERRORS
 }