[BE][autograd Function] Raise an error if input is returned as-is and saved for forward or backward in setup_context (#97212)

Fixes https://github.com/pytorch/pytorch/issues/96887

We error out in BOTH cases: when the graph is created and when it is not.
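
For illustration, here is a minimal sketch of the pattern that now errors, assuming a PyTorch build that includes this change (the class name `PassThrough` is made up, not from the PR). Note the error fires both when the input requires grad and when it does not:

```python
import torch

class PassThrough(torch.autograd.Function):
    @staticmethod
    def forward(x):
        # The input tensor is returned as-is (no view is taken).
        return x

    @staticmethod
    def setup_context(ctx, inputs, output):
        # Saving the as-is-returned input (equivalently, the output) is what
        # this change turns into an error.
        ctx.save_for_backward(inputs[0])

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

for requires_grad in (False, True):  # no graph vs. graph created
    t = torch.randn(3, requires_grad=requires_grad)
    try:
        PassThrough.apply(t)
    except RuntimeError as e:
        print(f"requires_grad={requires_grad}: {e}")
```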

Still BC-breaking, but less severe because the error is limited to the case where someone overrides setup_context.
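
The workaround suggested by the new error message is to return a view of the input rather than the input itself. A hedged sketch of the pattern that keeps working (again with a hypothetical class name):

```python
import torch

class PassThroughFixed(torch.autograd.Function):
    @staticmethod
    def forward(x):
        # Return a view of the input rather than the input tensor itself.
        return x.view_as(x)

    @staticmethod
    def setup_context(ctx, inputs, output):
        # The output is now a distinct tensor (a view), so saving it is fine.
        ctx.save_for_backward(output)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

t = torch.randn(3, requires_grad=True)
out = PassThroughFixed.apply(t)
out.sum().backward()
print(t.grad)  # all ones
```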

This makes the setup_context and non-setup_context versions diverge in their behavior:
- With the non-setup_context version, saved variables are assumed to have the grad_fn of the inputs.
- With the setup_context version, we now raise an error for this case (see the sketch below).
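
To make the divergence concrete, here is a sketch of the equivalent non-setup_context (ctx-in-forward) Function, with a made-up name; it keeps the pre-existing behavior and does not raise:

```python
import torch

class PassThroughOldStyle(torch.autograd.Function):
    # Non-setup_context style: ctx is handled inside forward itself.
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x  # input returned as-is; still allowed in this style

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

t = torch.randn(3, requires_grad=True)
out = PassThroughOldStyle.apply(t)  # no error; old behavior is preserved
out.sum().backward()
print(t.grad)  # all ones
```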

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97212
Approved by: https://github.com/zou3519
soulitzer (2023-03-27 18:20:43 -04:00), committed by PyTorch MergeBot
parent c7fa648ea1
commit f3aca45a16
5 changed files with 216 additions and 33 deletions

View File

@@ -254,6 +254,28 @@ TEST(AutogradAPITests, AnomalyMode) {
  double_backward_produce_nan(true);
}

TEST(CustomAutogradTest, CustomFunctionReturnInputAsIsAndSavesIt) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(
        AutogradContext* ctx,
        Variable var1,
        Variable var2) {
      ctx->save_for_backward({var1, var2});
      // The comma operator discards var1 * var2; the input var1 is returned
      // as-is, which is exactly the case this test exercises.
      return var1 * var2, var1;
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      return {};
    }
  };

  Variable x = torch::randn({5, 5}, torch::requires_grad());
  Variable y = torch::randn({5, 5}, torch::requires_grad());
  MyFunction::apply(x, y);
}

TEST(CustomAutogradTest, CustomFunction) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(

View File

@@ -1036,6 +1036,80 @@ class TestAutogradFunction(TestCase):
        grad(f)(y, x)
        grad(grad(f))(y, x)

    @parametrize("inner_requires_grad", [True, False])
    @parametrize("save_for", ["jvp", "vjp"])
    @parametrize("save_tensors", ["input", "output", "neither"])
    @parametrize("mark_dirty", [True, False])
    def test_function_returns_input(self, device, inner_requires_grad, save_for, save_tensors, mark_dirty):
        class A(torch.autograd.Function):
            @staticmethod
            def forward(x):
                return x

            @staticmethod
            def setup_context(ctx, inputs, output):
                if save_for == "jvp":
                    save_fn = ctx.save_for_forward
                else:
                    save_fn = ctx.save_for_backward

                if mark_dirty:
                    ctx.mark_dirty(inputs[0])

                if save_tensors == "input":
                    save_fn(inputs[0])
                elif save_tensors == "output":
                    save_fn(output)
                elif save_tensors == "neither":
                    pass

            @staticmethod
            def backward(ctx, grad_output):
                return grad_output

            @staticmethod
            def jvp(ctx, x_t):
                # NB: the logic to check ctx.save_for_forward happens
                # before we reach this!
                if mark_dirty:
                    x_t.add_(0)
                return x_t

        def fn(x):
            return A.apply(x.clone())

        err_msg = "A input that has been returned as-is"

        a = torch.tensor(2., device=device, requires_grad=inner_requires_grad)
        a_t = torch.tensor(2., device=device, requires_grad=inner_requires_grad)
        if save_tensors in ("input", "output") and not mark_dirty:
            with self.assertRaisesRegex(RuntimeError, err_msg):
                grad(fn)(a)
            with self.assertRaisesRegex(RuntimeError, err_msg):
                jvp(fn, (a,), (a_t,))
        else:
            grad(fn)(a)
            jvp(fn, (a,), (a_t,))

        a = torch.tensor(2., device=device, requires_grad=inner_requires_grad).clone()
        a_t = torch.tensor(2., device=device, requires_grad=inner_requires_grad).clone()

        if save_tensors in ("input", "output") and not mark_dirty:
            with self.assertRaisesRegex(RuntimeError, err_msg):
                A.apply(a)
            with self.assertRaisesRegex(RuntimeError, err_msg):
                with fwAD.dual_level():
                    A.apply(fwAD.make_dual(a, a_t))
        elif mark_dirty:
            b = A.apply(a)
            if mark_dirty:
                self.assertTrue(a is b)

            with fwAD.dual_level():
                a_dual = fwAD.make_dual(a, a_t)
                b_dual = A.apply(a_dual)
            if mark_dirty:
                self.assertTrue(a_dual is b_dual)

    def test_needs_input_grads(self, device):
        class A(torch.autograd.Function):
            @staticmethod

View File

@@ -273,18 +273,28 @@ optional_variable_list _process_backward_mode_ad(
const std::unordered_set<at::TensorImpl*>& non_differentiable,
const std::unordered_set<at::TensorImpl*>& dirty_inputs,
const at::ArrayRef<c10::optional<Variable>> raw_outputs,
const std::shared_ptr<Node>& cdata) {
const std::shared_ptr<Node>& cdata,
const std::unordered_set<at::TensorImpl*>& to_save_if_setup_context) {
int num_outputs = raw_outputs.size();
const char* error_msg_input_returned_as_is =
"A input that has been returned as-is as output is being saved for backward. "
"This is not supported if you override setup_context. You should return and "
"save a view of the input instead, e.g. with x.view_as(x) or setup ctx inside "
"the forward function itself.";
// Sets the grad_fn and output_nr of an output Variable.
auto set_history = [&](Variable& var,
uint32_t output_nr,
bool is_input,
bool is_modified,
bool is_differentiable) {
bool is_differentiable,
bool is_saved_and_setup_context) {
if (!is_differentiable) {
if (!var.requires_grad()) {
if (is_input && !is_modified) {
TORCH_CHECK(
!is_saved_and_setup_context, error_msg_input_returned_as_is)
var = _view_as_self_with_no_grad(var);
}
return;
@@ -339,6 +349,7 @@ optional_variable_list _process_backward_mode_ad(
impl::rebase_history(var, {cdata, output_nr});
}
} else if (is_input) {
TORCH_CHECK(!is_saved_and_setup_context, error_msg_input_returned_as_is)
var = _view_as_self_with_no_grad(var);
impl::set_gradient_edge(var, {cdata, output_nr});
} else if (cdata) {
@@ -370,12 +381,20 @@ optional_variable_list _process_backward_mode_ad(
bool is_differentiable = cdata &&
non_differentiable.count(out_tensor_impl) == 0 &&
isDifferentiableType(var.scalar_type());
bool is_saved_and_setup_context =
to_save_if_setup_context.count(out_tensor_impl) > 0;
if (cdata) {
auto output_nr = cdata->add_input_metadata(var);
AT_ASSERT(i == (int)output_nr);
}
set_history(var, i, is_input, is_modified, is_differentiable);
set_history(
var,
i,
is_input,
is_modified,
is_differentiable,
is_saved_and_setup_context);
// For deprecation cycle. Can be removed after 1.6. In the case where we
// detected a view in no grad mode during the forward, only warn the user
@@ -427,7 +446,8 @@ optional_variable_list _wrap_outputs(
const std::unordered_set<at::TensorImpl*>& dirty_inputs,
const at::ArrayRef<c10::optional<Variable>> raw_outputs,
const std::shared_ptr<Node>& cdata,
_jvp_fn_t jvp_user_function) {
_jvp_fn_t jvp_user_function,
const std::unordered_set<at::TensorImpl*>& to_save_if_setup_context) {
std::unordered_map<at::TensorImpl*, size_t> inputs_mapping;
inputs_mapping.reserve(input_vars.size());
for (const auto i : c10::irange(input_vars.size())) {
@@ -435,7 +455,12 @@ }
}
auto outputs = _process_backward_mode_ad(
inputs_mapping, non_differentiable, dirty_inputs, raw_outputs, cdata);
inputs_mapping,
non_differentiable,
dirty_inputs,
raw_outputs,
cdata,
to_save_if_setup_context);
// This must happen after the backward processing as we expect the
// computations happening here to track backward mode gradients.

View File

@@ -20,7 +20,8 @@ TORCH_API std::vector<c10::optional<Variable>> _wrap_outputs(
const std::unordered_set<at::TensorImpl*>& dirty_inputs,
const at::ArrayRef<c10::optional<Variable>> raw_outputs,
const std::shared_ptr<Node>& cdata,
_jvp_fn_t jvp_user_function);
_jvp_fn_t jvp_user_function,
const std::unordered_set<at::TensorImpl*>& to_save_if_setup_context);
TORCH_API void check_variable_result(
const at::TensorBase& original,
@@ -308,7 +309,8 @@ auto Function<T>::apply(Args&&... args)
node->ctx_.get_and_bump_dirty(),
to_optional(outputs),
is_executable ? node : nullptr,
jvp_fn);
jvp_fn,
{});
node->output_info_.reserve(wrapped_outputs.size());
for (auto& output : wrapped_outputs) {

View File

@@ -366,7 +366,8 @@ static void _wrap_outputs(
const variable_list& input_vars,
PyObject* raw_output,
PyObject* outputs,
bool is_executable) {
bool is_executable,
const std::unordered_set<at::TensorImpl*>& to_save_if_setup_context) {
auto cdata_if_executable = is_executable ? cdata : nullptr;
Py_ssize_t num_outputs = PyTuple_GET_SIZE(raw_output);
if (is_executable) {
@@ -462,7 +463,8 @@ static void _wrap_outputs(
dirty_inputs,
raw_output_vars,
cdata_if_executable,
std::move(jvp_user_function));
std::move(jvp_user_function),
to_save_if_setup_context);
for (const auto i : c10::irange(num_outputs)) {
PyObject* obj = PyTuple_GetItem(raw_output, i);
@@ -482,36 +484,77 @@ static void _wrap_outputs(
}
}
static void _get_tensors_to_save(
THPFunction* self,
std::unordered_set<at::TensorImpl*>& to_save_if_setup_context,
std::vector<c10::optional<at::Tensor>>& tensors_to_save,
bool overridden_setup_context,
bool is_executable) {
if (self->saved_for_forward && overridden_setup_context) {
// We look at saved_for_forward here purely for the purpose of populating
// to_save_if_setup_context, the actual saving is not done here.
THPFunction_assert(
PyTuple_Check(self->saved_for_forward),
"autograd internal "
"error: saved_for_forward attribute is expected to be a tuple but is %s",
THPUtils_typename(self->saved_for_forward));
Py_ssize_t num_saved_for_forward =
PyTuple_GET_SIZE(self->saved_for_forward);
for (const auto i : c10::irange(num_saved_for_forward)) {
PyObject* obj = PyTuple_GET_ITEM(self->saved_for_forward, i);
if (THPVariable_Check(obj)) {
const auto& tensor = THPVariable_Unpack(obj);
to_save_if_setup_context.insert(tensor.unsafeGetTensorImpl());
}
}
}
if (self->to_save) {
THPFunction_assert(
PyTuple_Check(self->to_save),
"autograd internal "
"error: to_save attribute is expected to be a tuple but is %s",
THPUtils_typename(self->to_save));
Py_ssize_t num_saved = PyTuple_GET_SIZE(self->to_save);
for (const auto i : c10::irange(num_saved)) {
PyObject* obj = PyTuple_GET_ITEM(self->to_save, i);
if (obj == Py_None) {
tensors_to_save.push_back(c10::nullopt);
continue;
} else if (THPVariable_Check(obj)) {
const auto& tensor = THPVariable_Unpack(obj);
if (overridden_setup_context) {
to_save_if_setup_context.insert(tensor.unsafeGetTensorImpl());
}
if (is_executable) {
tensors_to_save.push_back(tensor);
}
} else {
throw torch::TypeError(
"save_for_backward can only save variables, but argument %ld is of "
"type %s",
i,
Py_TYPE(obj)->tp_name);
}
}
}
}
// Save any variables requested by to_save
static void _save_variables(
const std::vector<c10::optional<at::Tensor>>& tensors_to_save,
const std::shared_ptr<PyNode>& cdata_ptr,
THPFunction* self) {
if (!self->to_save)
return;
THPFunction_assert(
PyTuple_Check(self->to_save),
"autograd internal "
"error: to_save attribute is expected to be a tuple but is %s",
THPUtils_typename(self->to_save));
Py_ssize_t num_saved = PyTuple_GET_SIZE(self->to_save);
size_t num_saved = tensors_to_save.size();
self->saved_variables.clear();
self->saved_variables.reserve(num_saved);
for (const auto i : c10::irange(num_saved)) {
PyObject* obj = PyTuple_GET_ITEM(self->to_save, i);
if (obj == Py_None) {
for (const auto& opt_tensor : tensors_to_save) {
if (!opt_tensor.has_value()) {
self->saved_variables.emplace_back();
continue;
} else if (THPVariable_Check(obj)) {
const auto& tensor = THPVariable_Unpack(obj);
bool is_output = tensor.grad_fn().get() == cdata_ptr.get();
self->saved_variables.emplace_back(tensor, is_output);
} else {
throw torch::TypeError(
"save_for_backward can only save variables, but argument %ld is of "
"type %s",
i,
Py_TYPE(obj)->tp_name);
bool is_output = opt_tensor.value().grad_fn().get() == cdata_ptr.get();
self->saved_variables.emplace_back(opt_tensor.value(), is_output);
}
}
// Free .to_save
@@ -757,7 +800,8 @@ PyObject* process_outputs(
PyObject* inputs,
THPObjectPtr&& raw_output,
bool is_executable,
torch::jit::Node* node) {
torch::jit::Node* node,
bool overridden_setup_context) {
bool unpack_output = ensure_tuple(raw_output);
auto num_outputs = PyTuple_GET_SIZE(raw_output.get());
@@ -777,9 +821,24 @@ }
}
}
std::unordered_set<at::TensorImpl*> to_save_if_setup_context{};
std::vector<c10::optional<at::Tensor>> tensors_to_save{};
_get_tensors_to_save(
grad_fn,
to_save_if_setup_context,
tensors_to_save,
overridden_setup_context,
is_executable);
bool is_inplace = static_cast<bool>(grad_fn->dirty_tensors);
_wrap_outputs(
cdata, grad_fn, unpacked.input_vars, raw_output, outputs, is_executable);
cdata,
grad_fn,
unpacked.input_vars,
raw_output,
outputs,
is_executable,
to_save_if_setup_context);
_trace_post_record(
node, op_obj, unpacked.input_vars, outputs, is_inplace, unpack_output);
@@ -787,7 +846,7 @@ PyObject* process_outputs(
// wrapping as the outputs must have their grad_fn/fw_grad properly set before
// we save them.
if (is_executable) {
_save_variables(cdata, grad_fn);
_save_variables(tensors_to_save, cdata, grad_fn);
} else {
// Remove unnecessary attributes
Py_XDECREF(grad_fn->to_save);
@@ -1010,7 +1069,8 @@ PyObject* THPFunction_apply(PyObject* cls, PyObject* inputs) {
inputs,
std::move(output),
is_executable,
node);
node,
overridden_setup_context);
END_HANDLE_TH_ERRORS
}