[BE][autograd Function] Raise an error if input is returned as-is and saved for forward or backward in setup_context (#97212)
Fixes https://github.com/pytorch/pytorch/issues/96887

We now error out in BOTH cases: when the graph is created and when it is not. This is still BC-breaking, but less severe because it is limited to the case where someone overrides setup_context. This makes the setup_context and non-setup_context versions diverge in behavior:
- With the non-setup_context version, saved variables are assumed to have the grad_fn of the inputs.
- With the setup_context version, we now raise an error for this case.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97212
Approved by: https://github.com/zou3519
Committed by: PyTorch MergeBot
Parent: c7fa648ea1
Commit: f3aca45a16
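Illustrative only (not part of the commit): a minimal Python sketch of the behavior this change enforces, assuming the setup_context-style torch.autograd.Function API; the class names below are hypothetical.

```python
import torch

# Pattern this PR now rejects: forward() returns an input unchanged and
# setup_context() saves that same tensor for backward (or forward AD).
class ReturnsInputAndSaves(torch.autograd.Function):
    @staticmethod
    def forward(x):
        return x  # input returned as-is

    @staticmethod
    def setup_context(ctx, inputs, output):
        ctx.save_for_backward(output)  # saves the as-is input

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

# Workaround suggested by the new error message: return (and save) a view
# of the input instead, so the saved tensor has its own grad_fn.
class ReturnsViewAndSaves(torch.autograd.Function):
    @staticmethod
    def forward(x):
        return x.view_as(x)

    @staticmethod
    def setup_context(ctx, inputs, output):
        ctx.save_for_backward(output)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

x = torch.randn(3, requires_grad=True)
ReturnsViewAndSaves.apply(x).sum().backward()  # ok
try:
    ReturnsInputAndSaves.apply(x)  # raises at apply() time after this change
except RuntimeError as e:
    print(e)  # "A input that has been returned as-is as output is being saved for backward. ..."
```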
@@ -254,6 +254,28 @@ TEST(AutogradAPITests, AnomalyMode) {
   double_backward_produce_nan(true);
 }
 
+TEST(CustomAutogradTest, CustomFunctionReturnInputAsIsAndSavesIt) {
+  struct MyFunction : public Function<MyFunction> {
+    static Variable forward(
+        AutogradContext* ctx,
+        Variable var1,
+        Variable var2) {
+      ctx->save_for_backward({var1, var2});
+      return var1 * var2, var1;
+    }
+
+    static variable_list backward(
+        AutogradContext* ctx,
+        variable_list grad_output) {
+      return {};
+    }
+  };
+
+  Variable x = torch::randn({5, 5}, torch::requires_grad());
+  Variable y = torch::randn({5, 5}, torch::requires_grad());
+  MyFunction::apply(x, y);
+}
+
 TEST(CustomAutogradTest, CustomFunction) {
   struct MyFunction : public Function<MyFunction> {
     static Variable forward(
@@ -1036,6 +1036,80 @@ class TestAutogradFunction(TestCase):
         grad(f)(y, x)
         grad(grad(f))(y, x)
 
+    @parametrize("inner_requires_grad", [True, False])
+    @parametrize("save_for", ["jvp", "vjp"])
+    @parametrize("save_tensors", ["input", "output", "neither"])
+    @parametrize("mark_dirty", [True, False])
+    def test_function_returns_input(self, device, inner_requires_grad, save_for, save_tensors, mark_dirty):
+        class A(torch.autograd.Function):
+            @staticmethod
+            def forward(x):
+                return x
+
+            @staticmethod
+            def setup_context(ctx, inputs, output):
+                if save_for == "jvp":
+                    save_fn = ctx.save_for_forward
+                else:
+                    save_fn = ctx.save_for_backward
+
+                if mark_dirty:
+                    ctx.mark_dirty(inputs[0])
+
+                if save_tensors == "input":
+                    save_fn(inputs[0])
+                elif save_tensors == "output":
+                    save_fn(output)
+                elif save_tensors == "neither":
+                    pass
+
+            @staticmethod
+            def backward(ctx, grad_output):
+                return grad_output
+
+            @staticmethod
+            def jvp(ctx, x_t):
+                # NB: the logic to check ctx.save_for_forward happens
+                # before we reach this!
+                if mark_dirty:
+                    x_t.add_(0)
+                return x_t
+
+        def fn(x):
+            return A.apply(x.clone())
+
+        err_msg = "A input that has been returned as-is"
+
+        a = torch.tensor(2., device=device, requires_grad=inner_requires_grad)
+        a_t = torch.tensor(2., device=device, requires_grad=inner_requires_grad)
+        if save_tensors in ("input", "output") and not mark_dirty:
+            with self.assertRaisesRegex(RuntimeError, err_msg):
+                grad(fn)(a)
+            with self.assertRaisesRegex(RuntimeError, err_msg):
+                jvp(fn, (a,), (a_t,))
+        else:
+            grad(fn)(a)
+            jvp(fn, (a,), (a_t,))
+
+        a = torch.tensor(2., device=device, requires_grad=inner_requires_grad).clone()
+        a_t = torch.tensor(2., device=device, requires_grad=inner_requires_grad).clone()
+
+        if save_tensors in ("input", "output") and not mark_dirty:
+            with self.assertRaisesRegex(RuntimeError, err_msg):
+                A.apply(a)
+            with self.assertRaisesRegex(RuntimeError, err_msg):
+                with fwAD.dual_level():
+                    A.apply(fwAD.make_dual(a, a_t))
+        elif mark_dirty:
+            b = A.apply(a)
+            if mark_dirty:
+                self.assertTrue(a is b)
+            with fwAD.dual_level():
+                a_dual = fwAD.make_dual(a, a_t)
+                b_dual = A.apply(a_dual)
+            if mark_dirty:
+                self.assertTrue(a_dual is b_dual)
+
     def test_needs_input_grads(self, device):
         class A(torch.autograd.Function):
             @staticmethod
@@ -273,18 +273,28 @@ optional_variable_list _process_backward_mode_ad(
     const std::unordered_set<at::TensorImpl*>& non_differentiable,
     const std::unordered_set<at::TensorImpl*>& dirty_inputs,
     const at::ArrayRef<c10::optional<Variable>> raw_outputs,
-    const std::shared_ptr<Node>& cdata) {
+    const std::shared_ptr<Node>& cdata,
+    const std::unordered_set<at::TensorImpl*>& to_save_if_setup_context) {
   int num_outputs = raw_outputs.size();
 
+  const char* error_msg_input_returned_as_is =
+      "A input that has been returned as-is as output is being saved for backward. "
+      "This is not supported if you override setup_context. You should return and "
+      "save a view of the input instead, e.g. with x.view_as(x) or setup ctx inside "
+      "the forward function itself.";
+
   // Sets the grad_fn and output_nr of an output Variable.
   auto set_history = [&](Variable& var,
                          uint32_t output_nr,
                          bool is_input,
                          bool is_modified,
-                         bool is_differentiable) {
+                         bool is_differentiable,
+                         bool is_saved_and_setup_context) {
    if (!is_differentiable) {
      if (!var.requires_grad()) {
        if (is_input && !is_modified) {
+         TORCH_CHECK(
+             !is_saved_and_setup_context, error_msg_input_returned_as_is)
         var = _view_as_self_with_no_grad(var);
        }
        return;
@@ -339,6 +349,7 @@ optional_variable_list _process_backward_mode_ad(
         impl::rebase_history(var, {cdata, output_nr});
       }
     } else if (is_input) {
+      TORCH_CHECK(!is_saved_and_setup_context, error_msg_input_returned_as_is)
       var = _view_as_self_with_no_grad(var);
       impl::set_gradient_edge(var, {cdata, output_nr});
     } else if (cdata) {
@@ -370,12 +381,20 @@ optional_variable_list _process_backward_mode_ad(
     bool is_differentiable = cdata &&
         non_differentiable.count(out_tensor_impl) == 0 &&
         isDifferentiableType(var.scalar_type());
+    bool is_saved_and_setup_context =
+        to_save_if_setup_context.count(out_tensor_impl) > 0;
 
     if (cdata) {
       auto output_nr = cdata->add_input_metadata(var);
       AT_ASSERT(i == (int)output_nr);
     }
-    set_history(var, i, is_input, is_modified, is_differentiable);
+    set_history(
+        var,
+        i,
+        is_input,
+        is_modified,
+        is_differentiable,
+        is_saved_and_setup_context);
 
     // For deprecation cycle. Can be removed after 1.6. In the case where we
     // detected a view in no grad mode during the forward, only warn the user
@@ -427,7 +446,8 @@ optional_variable_list _wrap_outputs(
     const std::unordered_set<at::TensorImpl*>& dirty_inputs,
     const at::ArrayRef<c10::optional<Variable>> raw_outputs,
     const std::shared_ptr<Node>& cdata,
-    _jvp_fn_t jvp_user_function) {
+    _jvp_fn_t jvp_user_function,
+    const std::unordered_set<at::TensorImpl*>& to_save_if_setup_context) {
   std::unordered_map<at::TensorImpl*, size_t> inputs_mapping;
   inputs_mapping.reserve(input_vars.size());
   for (const auto i : c10::irange(input_vars.size())) {
@@ -435,7 +455,12 @@ optional_variable_list _wrap_outputs(
   }
 
   auto outputs = _process_backward_mode_ad(
-      inputs_mapping, non_differentiable, dirty_inputs, raw_outputs, cdata);
+      inputs_mapping,
+      non_differentiable,
+      dirty_inputs,
+      raw_outputs,
+      cdata,
+      to_save_if_setup_context);
 
   // This must happen after the backward processing as we expect the
   // computations happening here to track backward mode gradients.
@@ -20,7 +20,8 @@ TORCH_API std::vector<c10::optional<Variable>> _wrap_outputs(
     const std::unordered_set<at::TensorImpl*>& dirty_inputs,
     const at::ArrayRef<c10::optional<Variable>> raw_outputs,
     const std::shared_ptr<Node>& cdata,
-    _jvp_fn_t jvp_user_function);
+    _jvp_fn_t jvp_user_function,
+    const std::unordered_set<at::TensorImpl*>& to_save_if_setup_context);
 
 TORCH_API void check_variable_result(
     const at::TensorBase& original,
@@ -308,7 +309,8 @@ auto Function<T>::apply(Args&&... args)
       node->ctx_.get_and_bump_dirty(),
       to_optional(outputs),
       is_executable ? node : nullptr,
-      jvp_fn);
+      jvp_fn,
+      {});
 
   node->output_info_.reserve(wrapped_outputs.size());
   for (auto& output : wrapped_outputs) {
@@ -366,7 +366,8 @@ static void _wrap_outputs(
     const variable_list& input_vars,
     PyObject* raw_output,
     PyObject* outputs,
-    bool is_executable) {
+    bool is_executable,
+    const std::unordered_set<at::TensorImpl*>& to_save_if_setup_context) {
   auto cdata_if_executable = is_executable ? cdata : nullptr;
   Py_ssize_t num_outputs = PyTuple_GET_SIZE(raw_output);
   if (is_executable) {
@@ -462,7 +463,8 @@ static void _wrap_outputs(
       dirty_inputs,
       raw_output_vars,
       cdata_if_executable,
-      std::move(jvp_user_function));
+      std::move(jvp_user_function),
+      to_save_if_setup_context);
 
   for (const auto i : c10::irange(num_outputs)) {
     PyObject* obj = PyTuple_GetItem(raw_output, i);
@@ -482,36 +484,77 @@ static void _wrap_outputs(
   }
 }
 
+static void _get_tensors_to_save(
+    THPFunction* self,
+    std::unordered_set<at::TensorImpl*>& to_save_if_setup_context,
+    std::vector<c10::optional<at::Tensor>>& tensors_to_save,
+    bool overridden_setup_context,
+    bool is_executable) {
+  if (self->saved_for_forward && overridden_setup_context) {
+    // We look at saved_for_forward here purely for the purpose of populating
+    // to_save_if_setup_context, the actual saving is not done here.
+    THPFunction_assert(
+        PyTuple_Check(self->saved_for_forward),
+        "autograd internal "
+        "error: saved_for_forward attribute is expected to be a tuple but is %s",
+        THPUtils_typename(self->saved_for_forward));
+    Py_ssize_t num_saved_for_forward =
+        PyTuple_GET_SIZE(self->saved_for_forward);
+    for (const auto i : c10::irange(num_saved_for_forward)) {
+      PyObject* obj = PyTuple_GET_ITEM(self->saved_for_forward, i);
+      if (THPVariable_Check(obj)) {
+        const auto& tensor = THPVariable_Unpack(obj);
+        to_save_if_setup_context.insert(tensor.unsafeGetTensorImpl());
+      }
+    }
+  }
+  if (self->to_save) {
+    THPFunction_assert(
+        PyTuple_Check(self->to_save),
+        "autograd internal "
+        "error: to_save attribute is expected to be a tuple but is %s",
+        THPUtils_typename(self->to_save));
+
+    Py_ssize_t num_saved = PyTuple_GET_SIZE(self->to_save);
+    for (const auto i : c10::irange(num_saved)) {
+      PyObject* obj = PyTuple_GET_ITEM(self->to_save, i);
+      if (obj == Py_None) {
+        tensors_to_save.push_back(c10::nullopt);
+        continue;
+      } else if (THPVariable_Check(obj)) {
+        const auto& tensor = THPVariable_Unpack(obj);
+        if (overridden_setup_context) {
+          to_save_if_setup_context.insert(tensor.unsafeGetTensorImpl());
+        }
+        if (is_executable) {
+          tensors_to_save.push_back(tensor);
+        }
+      } else {
+        throw torch::TypeError(
+            "save_for_backward can only save variables, but argument %ld is of "
+            "type %s",
+            i,
+            Py_TYPE(obj)->tp_name);
+      }
+    }
+  }
+}
 // Save any variables that requested by to_save
 static void _save_variables(
+    const std::vector<c10::optional<at::Tensor>>& tensors_to_save,
     const std::shared_ptr<PyNode>& cdata_ptr,
     THPFunction* self) {
-  if (!self->to_save)
-    return;
-
-  THPFunction_assert(
-      PyTuple_Check(self->to_save),
-      "autograd internal "
-      "error: to_save attribute is expected to be a tuple but is %s",
-      THPUtils_typename(self->to_save));
-  Py_ssize_t num_saved = PyTuple_GET_SIZE(self->to_save);
+  size_t num_saved = tensors_to_save.size();
   self->saved_variables.clear();
   self->saved_variables.reserve(num_saved);
-  for (const auto i : c10::irange(num_saved)) {
-    PyObject* obj = PyTuple_GET_ITEM(self->to_save, i);
-    if (obj == Py_None) {
+  for (const auto& opt_tensor : tensors_to_save) {
+    if (!opt_tensor.has_value()) {
       self->saved_variables.emplace_back();
-      continue;
-    } else if (THPVariable_Check(obj)) {
-      const auto& tensor = THPVariable_Unpack(obj);
-      bool is_output = tensor.grad_fn().get() == cdata_ptr.get();
-      self->saved_variables.emplace_back(tensor, is_output);
     } else {
-      throw torch::TypeError(
-          "save_for_backward can only save variables, but argument %ld is of "
-          "type %s",
-          i,
-          Py_TYPE(obj)->tp_name);
+      bool is_output = opt_tensor.value().grad_fn().get() == cdata_ptr.get();
+      self->saved_variables.emplace_back(opt_tensor.value(), is_output);
     }
   }
   // Free .to_save
@@ -757,7 +800,8 @@ PyObject* process_outputs(
     PyObject* inputs,
     THPObjectPtr&& raw_output,
     bool is_executable,
-    torch::jit::Node* node) {
+    torch::jit::Node* node,
+    bool overridden_setup_context) {
   bool unpack_output = ensure_tuple(raw_output);
 
   auto num_outputs = PyTuple_GET_SIZE(raw_output.get());
@@ -777,9 +821,24 @@ PyObject* process_outputs(
     }
   }
 
+  std::unordered_set<at::TensorImpl*> to_save_if_setup_context{};
+  std::vector<c10::optional<at::Tensor>> tensors_to_save{};
+  _get_tensors_to_save(
+      grad_fn,
+      to_save_if_setup_context,
+      tensors_to_save,
+      overridden_setup_context,
+      is_executable);
+
   bool is_inplace = static_cast<bool>(grad_fn->dirty_tensors);
   _wrap_outputs(
-      cdata, grad_fn, unpacked.input_vars, raw_output, outputs, is_executable);
+      cdata,
+      grad_fn,
+      unpacked.input_vars,
+      raw_output,
+      outputs,
+      is_executable,
+      to_save_if_setup_context);
   _trace_post_record(
       node, op_obj, unpacked.input_vars, outputs, is_inplace, unpack_output);
 
@@ -787,7 +846,7 @@ PyObject* process_outputs(
   // wrapping as the outputs must have their grad_fn/fw_grad properly set before
   // we save them.
   if (is_executable) {
-    _save_variables(cdata, grad_fn);
+    _save_variables(tensors_to_save, cdata, grad_fn);
   } else {
     // Remove unnecessary attributes
     Py_XDECREF(grad_fn->to_save);
@@ -1010,7 +1069,8 @@ PyObject* THPFunction_apply(PyObject* cls, PyObject* inputs) {
       inputs,
       std::move(output),
       is_executable,
-      node);
+      node,
+      overridden_setup_context);
   END_HANDLE_TH_ERRORS
 }
 