Add pure view support in autograd Function (#164467)

Fix https://github.com/pytorch/pytorch/issues/73604

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164467
Approved by: https://github.com/ezyang, https://github.com/soulitzer
Author: albanD
Date: 2025-10-03 14:33:11 -04:00
Committed by: PyTorch MergeBot
Parent: f006aee601
Commit: 10335ffb2c
5 changed files with 125 additions and 60 deletions
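A minimal usage sketch (not part of the diff) of the flag this PR adds, based on the test changes below: a custom Function whose forward returns a real view of its single tensor input can set ctx._is_pure_view = True so that an in-place modification of the output makes autograd replay the view to rebuild the graph instead of raising the usual "is a view and is being modified inplace ... created inside a custom Function" error. The FirstTwo class, the tensor sizes, and the ctx.input_size attribute are illustrative assumptions; only ctx._is_pure_view comes from this PR.

import torch
from torch.autograd import Function

# Hypothetical single-input/single-output view Function (name and shapes are
# illustrative). The only API used from this PR is ctx._is_pure_view.
class FirstTwo(Function):
    @staticmethod
    def forward(ctx, x):
        # Declare that the output is a "pure" view of the input, i.e. replaying
        # the view is enough to get a correct backward (see the diffs below).
        ctx._is_pure_view = True
        ctx.input_size = x.size()
        return x.narrow(0, 0, 2)

    @staticmethod
    def backward(ctx, grad):
        # Scatter the gradient back into the shape of the input.
        gx = grad.new_zeros(ctx.input_size)
        gx.narrow(0, 0, 2).copy_(grad)
        return gx

base = torch.randn(4, requires_grad=True)
x = base.clone()       # non-leaf base, so in-place on a view of it is allowed
y = FirstTwo.apply(x)
y.mul_(2)              # in-place on the custom view: with _is_pure_view set,
                       # autograd replays the narrow instead of erroring out
y.sum().backward()     # gradients flow to `base` through the replayed view

Pure view Functions are currently limited to exactly one input tensor and one output tensor; see the TORCH_CHECK added in the C++ changes below.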

View File

@@ -8347,7 +8347,8 @@ for shape in [(1,), ()]:
class IdOneOutput(Function):
@staticmethod
def forward(ctx, a, b, make_view):
def forward(ctx, a, make_view, pure_view):
ctx._is_pure_view = pure_view
if make_view:
a = a.narrow(0, 0, 2)
else:
@@ -8361,7 +8362,8 @@ for shape in [(1,), ()]:
class IdTwoOutput(Function):
@staticmethod
def forward(ctx, a, b, make_view):
def forward(ctx, a, b, make_view, pure_view):
ctx._is_pure_view = pure_view
if make_view:
a = a.narrow(0, 0, 2)
else:
@@ -8375,11 +8377,12 @@ for shape in [(1,), ()]:
ga_nz[0] = False
else:
ga_nz[0] = True
return ga + gab, gab, None
return ga + gab, gab, None, None
class ViewOfTemp(Function):
@staticmethod
def forward(ctx, a, make_view):
def forward(ctx, a, make_view, pure_view):
ctx._is_pure_view = pure_view
ctx.save_for_backward(a)
if make_view:
a = a.narrow(0, 0, 2)
@@ -8394,7 +8397,7 @@ for shape in [(1,), ()]:
(a,) = ctx.saved_tensors
res = torch.zeros_like(a)
res.select(0, 0).copy_(grad)
return res, None
return res, None, None
fn_id_to_inplace_on_view_err_msg = {
"one_output": (
@@ -8403,71 +8406,96 @@ for shape in [(1,), ()]:
),
"two_output": (
"Output 0 of IdTwoOutputBackward is a view and is being modified inplace."
" This view is the output of a function that returns multiple views."
" This view is the output of a function that returns multiple views.",
"Pure view custom Function can only have one input Tensor and one output Tensor."
" Open an issue if you need to support more.",
),
"view_of_temp": (
"Output 0 of ViewOfTempBackward is a view and is being "
"modified inplace. This view was created inside a custom Function"
"modified inplace. This view was created inside a custom Function",
"a view of a leaf Variable that requires grad is being used in an in-place operation",
),
}
for fn_id in ["one_output", "two_output", "view_of_temp"]:
for inplace in [True, False]:
for make_view in [True, False]:
# Used for special casing the tests below
output_is_a_view = make_view or fn_id == "view_of_temp"
for pure_view in [True, False]:
# Used for special casing the tests below
output_is_a_view = make_view or fn_id == "view_of_temp"
def fn(a, b):
# never modify a, b inplace for gradcheck
a = a.clone()
b = b.clone()
if fn_id == "two_output":
tmp1, tmp2 = IdTwoOutput.apply(a, b, make_view)
if inplace:
tmp1 += 3
tmp2 += 3
def fn(a, b):
# never modify a, b inplace for gradcheck
a = a.clone()
b = b.clone()
if fn_id == "two_output":
tmp1, tmp2 = IdTwoOutput.apply(
a, b, make_view, pure_view
)
if inplace:
tmp1 += 3
tmp2 += 3
else:
tmp1 = tmp1 + 3
tmp2 = tmp2 + 3
tmp = tmp1 * tmp2
else:
tmp1 = tmp1 + 3
tmp2 = tmp2 + 3
tmp = tmp1 * tmp2
if fn_id == "one_output":
tmp = IdOneOutput.apply(a, make_view, pure_view)
else:
tmp = ViewOfTemp.apply(a + b, make_view, pure_view)
if inplace:
tmp += 3
else:
tmp = tmp + 3
return tmp.sum()
a = torch.ones(2, dtype=dtype, requires_grad=True)
b = torch.ones(2, dtype=dtype, requires_grad=True)
err_msg = fn_id_to_inplace_on_view_err_msg[fn_id][
int(pure_view)
]
will_raise_error = (
(pure_view and fn_id == "two_output")
or (pure_view and fn_id == "view_of_temp" and inplace)
or (not pure_view and inplace and output_is_a_view)
)
if will_raise_error:
with self.assertRaisesRegex(RuntimeError, err_msg):
gradcheck(fn, (a, b), check_batched_grad=False)
else:
if fn_id == "one_output":
tmp = IdOneOutput.apply(a, b, make_view)
else:
tmp = ViewOfTemp.apply(a + b, make_view)
if inplace:
tmp += 3
else:
tmp = tmp + 3
gradcheck(fn, (a, b), check_batched_grad=False)
return tmp.sum()
# Was the custom backward called properly
bw_called[0] = 0
ga_nz[0] = True # For the case where the backward is called
a = torch.ones(2, dtype=dtype, requires_grad=True)
b = torch.ones(2, dtype=dtype, requires_grad=True)
expected_called = 1
expected_ga_nz = True
err_msg = fn_id_to_inplace_on_view_err_msg[fn_id]
if will_raise_error:
expected_called = 0
with self.assertRaisesRegex(RuntimeError, err_msg):
fn(a, b)
else:
fn(a, b).abs().backward()
if not inplace or not output_is_a_view:
gradcheck(fn, (a, b), check_batched_grad=False)
if (
fn_id == "one_output"
and inplace
and output_is_a_view
and pure_view
):
# We expect the op to have been replayed and we leveraged the pure view
# to re-create the graph, so the original backward was not called
expected_called = 0
# Was the custom backward called properly
bw_called[0] = 0
ga_nz[0] = True # For the case where the backward is called
if inplace and output_is_a_view:
with self.assertRaisesRegex(RuntimeError, err_msg):
fn(a, b)
else:
fn(a, b).abs().backward()
expected_called = 1
expected_ga_nz = True
if output_is_a_view and inplace:
expected_called = 0
self.assertTrue(bw_called[0] == expected_called)
self.assertTrue(ga_nz[0] == expected_ga_nz)
self.assertTrue(bw_called[0] == expected_called)
self.assertTrue(ga_nz[0] == expected_ga_nz)
def test_autograd_simple_views_python(self):
self._do_test_autograd_simple_views_python(torch.double)

View File

@@ -261,7 +261,8 @@ static optional_variable_list _process_backward_mode_ad(
const at::ArrayRef<std::optional<Variable>> raw_outputs,
const std::shared_ptr<Node>& cdata,
const std::unordered_set<at::TensorImpl*>& to_save_if_setup_context,
const _view_as_self_fn_t& view_as_self_fn) {
const _view_as_self_fn_t& view_as_self_fn,
bool pure_view) {
auto num_outputs = raw_outputs.size();
#ifndef STRIP_ERROR_MESSAGES
@@ -404,7 +405,8 @@ static optional_variable_list _process_backward_mode_ad(
if (!(is_input && is_modified) && var.is_view()) {
// is_view() => diff_view_meta
auto diff_view_meta = impl::get_view_autograd_meta(var);
diff_view_meta->set_creation_meta(CreationMeta::IN_CUSTOM_FUNCTION);
diff_view_meta->set_creation_meta(
pure_view ? CreationMeta::DEFAULT : CreationMeta::IN_CUSTOM_FUNCTION);
}
if (is_differentiable) {
@@ -448,13 +450,20 @@ optional_variable_list _wrap_outputs(
const std::shared_ptr<Node>& cdata,
const _jvp_fn_t& jvp_user_function,
const std::unordered_set<at::TensorImpl*>& to_save_if_setup_context,
const _view_as_self_fn_t& view_as_self_fn) {
const _view_as_self_fn_t& view_as_self_fn,
bool pure_view) {
std::unordered_map<at::TensorImpl*, size_t> inputs_mapping;
inputs_mapping.reserve(input_vars.size());
for (const auto i : c10::irange(input_vars.size())) {
inputs_mapping.emplace(input_vars[i].unsafeGetTensorImpl(), i);
}
// Limit pure views to 1-1 mapping as it is unclear if it is even
// possible to have a pure view for N-1 or 1-N.
TORCH_CHECK(
!pure_view || (input_vars.size() == 1 && raw_outputs.size() == 1),
"Pure view custom Function can only have one input Tensor and one output Tensor. Open an issue if you need to support more.");
auto outputs = _process_backward_mode_ad(
inputs_mapping,
non_differentiable,
@@ -462,7 +471,8 @@ optional_variable_list _wrap_outputs(
raw_outputs,
cdata,
to_save_if_setup_context,
view_as_self_fn);
view_as_self_fn,
pure_view);
// This must happen after the backward processing as we expect the
// computations happening here to track backward mode gradients.

View File

@@ -24,7 +24,8 @@ TORCH_API std::vector<std::optional<Variable>> _wrap_outputs(
const std::shared_ptr<Node>& cdata,
const _jvp_fn_t& jvp_user_function,
const std::unordered_set<at::TensorImpl*>& to_save_if_setup_context,
const _view_as_self_fn_t& view_as_self_fn);
const _view_as_self_fn_t& view_as_self_fn,
bool pure_view);
TORCH_API void check_variable_result(
const at::TensorBase& original,
@@ -523,7 +524,8 @@ auto Function<T>::apply(Args&&... args)
is_executable ? node : nullptr,
jvp_fn,
{},
view_as_self_fn);
view_as_self_fn,
false);
node->output_info_.reserve(wrapped_outputs.size());
for (auto& output : wrapped_outputs) {

View File

@@ -539,6 +539,7 @@ static PyObject* THPFunction_new(
new (&self->saved_variables) std::vector<SavedVariable>();
new (&self->is_variable_input) std::vector<bool>();
self->materialize_grads = true;
self->pure_view = false;
self->materialize_non_diff_grads = true;
return obj;
}
@@ -716,7 +717,8 @@ static void _wrap_outputs(
cdata_if_executable,
jvp_user_function,
to_save_if_setup_context,
view_as_self_fn);
view_as_self_fn,
self->pure_view);
for (const auto i : c10::irange(num_outputs)) {
PyObject* obj = PyTuple_GetItem(raw_output, i);
@@ -1441,6 +1443,20 @@ int THPFunction_set_materialize_grads(
END_HANDLE_TH_ERRORS_RET(-1)
}
int THPFunction_set_pure_view(
THPFunction* self,
PyObject* value,
void* unused) {
HANDLE_TH_ERRORS
if (!PyBool_Check(value)) {
THPUtils_invalidArguments(value, nullptr, "set_pure_view", 1, "(bool)");
return -1;
}
self->pure_view = (value == Py_True);
return 0;
END_HANDLE_TH_ERRORS_RET(-1)
}
PyObject* THPFunction_get_materialize_non_diff_grads(
THPFunction* self,
void* _unused) {
@@ -1715,6 +1731,11 @@ static struct PyGetSetDef THPFunction_properties[] = {
(setter)THPFunction_set_materialize_grads,
nullptr,
nullptr},
{"_is_pure_view",
nullptr,
(setter)THPFunction_set_pure_view,
nullptr,
nullptr},
{"_materialize_non_diff_grads",
(getter)THPFunction_get_materialize_non_diff_grads,
(setter)THPFunction_set_materialize_non_diff_grads,

View File

@@ -109,6 +109,10 @@ struct THPFunction {
// Default is true.
bool materialize_grads;
// boolean indicating whether the function is a "pure view", meaning that
// replaying the view is enough to get a correct backward.
bool pure_view;
// boolean indicating whether to materialize output grad tensors
// corresponding to non-differentiable outputs. Normally, someone would
// already get this behavior by switching off materialize_grads,