From af32d16a71681ca05c6d410fb1b9cee091d4577d Mon Sep 17 00:00:00 2001 From: albanD Date: Mon, 6 Oct 2025 18:21:03 +0000 Subject: [PATCH] Add pure view support in autograd Function (#164736) This is the same as https://github.com/pytorch/pytorch/pull/164467 But it needs to be co-deved due to internal insanity. Pull Request resolved: https://github.com/pytorch/pytorch/pull/164736 Approved by: https://github.com/soulitzer --- test/test_autograd.py | 134 ++++++++++++++---------- torch/csrc/autograd/custom_function.cpp | 18 +++- torch/csrc/autograd/custom_function.h | 6 +- torch/csrc/autograd/python_function.cpp | 23 +++- torch/csrc/autograd/python_function.h | 4 + 5 files changed, 125 insertions(+), 60 deletions(-) diff --git a/test/test_autograd.py b/test/test_autograd.py index 74ecefce0821..021659b81122 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -8223,7 +8223,8 @@ for shape in [(1,), ()]: class IdOneOutput(Function): @staticmethod - def forward(ctx, a, b, make_view): + def forward(ctx, a, make_view, pure_view): + ctx._is_pure_view = pure_view if make_view: a = a.narrow(0, 0, 2) else: @@ -8237,7 +8238,8 @@ for shape in [(1,), ()]: class IdTwoOutput(Function): @staticmethod - def forward(ctx, a, b, make_view): + def forward(ctx, a, b, make_view, pure_view): + ctx._is_pure_view = pure_view if make_view: a = a.narrow(0, 0, 2) else: @@ -8251,11 +8253,12 @@ for shape in [(1,), ()]: ga_nz[0] = False else: ga_nz[0] = True - return ga + gab, gab, None + return ga + gab, gab, None, None class ViewOfTemp(Function): @staticmethod - def forward(ctx, a, make_view): + def forward(ctx, a, make_view, pure_view): + ctx._is_pure_view = pure_view ctx.save_for_backward(a) if make_view: a = a.narrow(0, 0, 2) @@ -8270,7 +8273,7 @@ for shape in [(1,), ()]: (a,) = ctx.saved_tensors res = torch.zeros_like(a) res.select(0, 0).copy_(grad) - return res, None + return res, None, None fn_id_to_inplace_on_view_err_msg = { "one_output": ( @@ -8279,71 +8282,96 @@ for shape in [(1,), ()]: ), "two_output": ( "Output 0 of IdTwoOutputBackward is a view and is being modified inplace." - " This view is the output of a function that returns multiple views." + " This view is the output of a function that returns multiple views.", + "Pure view custom Function can only have one input Tensor and one output Tensor." + " Open an issue if you need to support more.", ), "view_of_temp": ( "Output 0 of ViewOfTempBackward is a view and is being " - "modified inplace. This view was created inside a custom Function" + "modified inplace. 
This view was created inside a custom Function", + "a view of a leaf Variable that requires grad is being used in an in-place operation", ), } for fn_id in ["one_output", "two_output", "view_of_temp"]: for inplace in [True, False]: for make_view in [True, False]: - # Used for special casing the tests below - output_is_a_view = make_view or fn_id == "view_of_temp" + for pure_view in [True, False]: + # Used for special casing the tests below + output_is_a_view = make_view or fn_id == "view_of_temp" - def fn(a, b): - # never modify a, b inplace for gracheck - a = a.clone() - b = b.clone() - if fn_id == "two_output": - tmp1, tmp2 = IdTwoOutput.apply(a, b, make_view) - if inplace: - tmp1 += 3 - tmp2 += 3 + def fn(a, b): + # never modify a, b inplace for gracheck + a = a.clone() + b = b.clone() + if fn_id == "two_output": + tmp1, tmp2 = IdTwoOutput.apply( + a, b, make_view, pure_view + ) + if inplace: + tmp1 += 3 + tmp2 += 3 + else: + tmp1 = tmp1 + 3 + tmp2 = tmp2 + 3 + tmp = tmp1 * tmp2 else: - tmp1 = tmp1 + 3 - tmp2 = tmp2 + 3 - tmp = tmp1 * tmp2 + if fn_id == "one_output": + tmp = IdOneOutput.apply(a, make_view, pure_view) + else: + tmp = ViewOfTemp.apply(a + b, make_view, pure_view) + if inplace: + tmp += 3 + else: + tmp = tmp + 3 + + return tmp.sum() + + a = torch.ones(2, dtype=dtype, requires_grad=True) + b = torch.ones(2, dtype=dtype, requires_grad=True) + + err_msg = fn_id_to_inplace_on_view_err_msg[fn_id][ + int(pure_view) + ] + + will_raise_error = ( + (pure_view and fn_id == "two_output") + or (pure_view and fn_id == "view_of_temp" and inplace) + or (not pure_view and inplace and output_is_a_view) + ) + + if will_raise_error: + with self.assertRaisesRegex(RuntimeError, err_msg): + gradcheck(fn, (a, b), check_batched_grad=False) else: - if fn_id == "one_output": - tmp = IdOneOutput.apply(a, b, make_view) - else: - tmp = ViewOfTemp.apply(a + b, make_view) - if inplace: - tmp += 3 - else: - tmp = tmp + 3 + gradcheck(fn, (a, b), check_batched_grad=False) - return tmp.sum() + # Was the custom backward called properly + bw_called[0] = 0 + ga_nz[0] = True # For the case where the backward is called - a = torch.ones(2, dtype=dtype, requires_grad=True) - b = torch.ones(2, dtype=dtype, requires_grad=True) + expected_called = 1 + expected_ga_nz = True - err_msg = fn_id_to_inplace_on_view_err_msg[fn_id] + if will_raise_error: + expected_called = 0 + with self.assertRaisesRegex(RuntimeError, err_msg): + fn(a, b) + else: + fn(a, b).abs().backward() - if not inplace or not output_is_a_view: - gradcheck(fn, (a, b), check_batched_grad=False) + if ( + fn_id == "one_output" + and inplace + and output_is_a_view + and pure_view + ): + # We expect the op to have been replayed and we leveraged the pure view + # to re-create the graph, so the original backward was not called + expected_called = 0 - # Was the custom backward called properly - bw_called[0] = 0 - ga_nz[0] = True # For the case where the backward is called - - if inplace and output_is_a_view: - with self.assertRaisesRegex(RuntimeError, err_msg): - fn(a, b) - else: - fn(a, b).abs().backward() - - expected_called = 1 - expected_ga_nz = True - - if output_is_a_view and inplace: - expected_called = 0 - - self.assertTrue(bw_called[0] == expected_called) - self.assertTrue(ga_nz[0] == expected_ga_nz) + self.assertTrue(bw_called[0] == expected_called) + self.assertTrue(ga_nz[0] == expected_ga_nz) def test_autograd_simple_views_python(self): self._do_test_autograd_simple_views_python(torch.double) diff --git a/torch/csrc/autograd/custom_function.cpp 
b/torch/csrc/autograd/custom_function.cpp
index 135637c26acf..4104dac14a5f 100644
--- a/torch/csrc/autograd/custom_function.cpp
+++ b/torch/csrc/autograd/custom_function.cpp
@@ -261,7 +261,8 @@ static optional_variable_list _process_backward_mode_ad(
     const at::ArrayRef<std::optional<Variable>> raw_outputs,
     const std::shared_ptr<Node>& cdata,
     const std::unordered_set<at::TensorImpl*>& to_save_if_setup_context,
-    const _view_as_self_fn_t& view_as_self_fn) {
+    const _view_as_self_fn_t& view_as_self_fn,
+    bool pure_view) {
   auto num_outputs = raw_outputs.size();
 
 #ifndef STRIP_ERROR_MESSAGES
@@ -404,7 +405,8 @@ static optional_variable_list _process_backward_mode_ad(
     if (!(is_input && is_modified) && var.is_view()) {
       // is_view() => diff_view_meta
       auto diff_view_meta = impl::get_view_autograd_meta(var);
-      diff_view_meta->set_creation_meta(CreationMeta::IN_CUSTOM_FUNCTION);
+      diff_view_meta->set_creation_meta(
+          pure_view ? CreationMeta::DEFAULT : CreationMeta::IN_CUSTOM_FUNCTION);
     }
 
     if (is_differentiable) {
@@ -448,13 +450,20 @@ optional_variable_list _wrap_outputs(
     const std::shared_ptr<Node>& cdata,
     const _jvp_fn_t& jvp_user_function,
     const std::unordered_set<at::TensorImpl*>& to_save_if_setup_context,
-    const _view_as_self_fn_t& view_as_self_fn) {
+    const _view_as_self_fn_t& view_as_self_fn,
+    bool pure_view) {
   std::unordered_map<at::TensorImpl*, size_t> inputs_mapping;
   inputs_mapping.reserve(input_vars.size());
   for (const auto i : c10::irange(input_vars.size())) {
     inputs_mapping.emplace(input_vars[i].unsafeGetTensorImpl(), i);
   }
 
+  // Limit pure views to 1-1 mapping as it is unclear if it is even
+  // possible to have a pure view for N-1 or 1-N.
+  TORCH_CHECK(
+      !pure_view || (input_vars.size() == 1 && raw_outputs.size() == 1),
+      "Pure view custom Function can only have one input Tensor and one output Tensor. Open an issue if you need to support more.");
+
   auto outputs = _process_backward_mode_ad(
       inputs_mapping,
       non_differentiable,
@@ -462,7 +471,8 @@ optional_variable_list _wrap_outputs(
       raw_outputs,
       cdata,
       to_save_if_setup_context,
-      view_as_self_fn);
+      view_as_self_fn,
+      pure_view);
 
   // This must happen after the backward processing as we expect the
   // computations happening here to track backward mode gradients.
diff --git a/torch/csrc/autograd/custom_function.h b/torch/csrc/autograd/custom_function.h
index d15a02e12ebb..3b9cf755f4c2 100644
--- a/torch/csrc/autograd/custom_function.h
+++ b/torch/csrc/autograd/custom_function.h
@@ -24,7 +24,8 @@ TORCH_API std::vector<std::optional<Variable>> _wrap_outputs(
     const std::shared_ptr<Node>& cdata,
     const _jvp_fn_t& jvp_user_function,
     const std::unordered_set<at::TensorImpl*>& to_save_if_setup_context,
-    const _view_as_self_fn_t& view_as_self_fn);
+    const _view_as_self_fn_t& view_as_self_fn,
+    bool pure_view);
 
 TORCH_API void check_variable_result(
     const at::TensorBase& original,
@@ -523,7 +524,8 @@ auto Function<T>::apply(Args&&... args)
       is_executable ? node : nullptr,
       jvp_fn,
       {},
-      view_as_self_fn);
+      view_as_self_fn,
+      false);
 
   node->output_info_.reserve(wrapped_outputs.size());
   for (auto& output : wrapped_outputs) {
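The TORCH_CHECK above restricts the new pure-view mode to a strict one-input/one-output mapping, and the CreationMeta change makes a pure-view output behave like an ordinary view (CreationMeta::DEFAULT) instead of the restricted IN_CUSTOM_FUNCTION case. As a rough illustration of what the check rejects (a sketch only; the class below is hypothetical and not part of this patch, while the _is_pure_view attribute and the error text come from the patch itself):

    import torch
    from torch.autograd import Function

    class PairView(Function):
        # Hypothetical Function with two tensor inputs; setting the pure-view
        # flag on it violates the 1-1 restriction enforced in _wrap_outputs.
        @staticmethod
        def forward(ctx, a, b):
            ctx._is_pure_view = True  # setter added by this patch
            ctx.a_len = a.size(0)
            return a.narrow(0, 0, 1)

        @staticmethod
        def backward(ctx, grad_out):
            grad_a = grad_out.new_zeros(ctx.a_len)
            grad_a[:1] = grad_out
            return grad_a, None

    a = torch.ones(3, requires_grad=True)
    b = torch.ones(3, requires_grad=True)
    # Expected to fail with:
    # RuntimeError: Pure view custom Function can only have one input Tensor
    # and one output Tensor. Open an issue if you need to support more.
    out = PairView.apply(a, b)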
diff --git a/torch/csrc/autograd/python_function.cpp b/torch/csrc/autograd/python_function.cpp
index 27aea84d849a..b4378faf8d3e 100644
--- a/torch/csrc/autograd/python_function.cpp
+++ b/torch/csrc/autograd/python_function.cpp
@@ -539,6 +539,7 @@ static PyObject* THPFunction_new(
   new (&self->saved_variables) std::vector<SavedVariable>();
   new (&self->is_variable_input) std::vector<bool>();
   self->materialize_grads = true;
+  self->pure_view = false;
   self->materialize_non_diff_grads = true;
   return obj;
 }
@@ -716,7 +717,8 @@ static void _wrap_outputs(
       cdata_if_executable,
       jvp_user_function,
       to_save_if_setup_context,
-      view_as_self_fn);
+      view_as_self_fn,
+      self->pure_view);
 
   for (const auto i : c10::irange(num_outputs)) {
     PyObject* obj = PyTuple_GetItem(raw_output, i);
@@ -1456,6 +1458,20 @@ int THPFunction_set_materialize_grads(
   END_HANDLE_TH_ERRORS_RET(-1)
 }
 
+int THPFunction_set_pure_view(
+    THPFunction* self,
+    PyObject* value,
+    void* unused) {
+  HANDLE_TH_ERRORS
+  if (!PyBool_Check(value)) {
+    THPUtils_invalidArguments(value, nullptr, "set_pure_view", 1, "(bool)");
+    return -1;
+  }
+  self->pure_view = (value == Py_True);
+  return 0;
+  END_HANDLE_TH_ERRORS_RET(-1)
+}
+
 PyObject* THPFunction_get_materialize_non_diff_grads(
     THPFunction* self,
     void* _unused) {
@@ -1730,6 +1746,11 @@ static struct PyGetSetDef THPFunction_properties[] = {
      (setter)THPFunction_set_materialize_grads,
      nullptr,
      nullptr},
+    {"_is_pure_view",
+     nullptr,
+     (setter)THPFunction_set_pure_view,
+     nullptr,
+     nullptr},
     {"_materialize_non_diff_grads",
      (getter)THPFunction_get_materialize_non_diff_grads,
      (setter)THPFunction_set_materialize_non_diff_grads,
diff --git a/torch/csrc/autograd/python_function.h b/torch/csrc/autograd/python_function.h
index e24399c10aa3..4b22c40725f9 100644
--- a/torch/csrc/autograd/python_function.h
+++ b/torch/csrc/autograd/python_function.h
@@ -109,6 +109,10 @@ struct THPFunction {
   // Default is true.
   bool materialize_grads;
 
+  // boolean indicating whether the function is a "pure view", meaning that
+  // replaying the view is enough to get a correct backward.
+  bool pure_view;
+
   // boolean indicating whether to materialize output grad tensors
   // corresponding to non-differentiable outputs. Normally, someone would
   // already get this behavior by switching off materialize_grads,
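Taken together, these changes let a Python autograd.Function opt in to pure-view semantics by setting _is_pure_view on its ctx, as the updated tests above do. A minimal end-to-end sketch of the intended usage (the class name, the clone, and the zero-padding backward are illustrative assumptions, not part of the patch):

    import torch
    from torch.autograd import Function

    class FirstTwo(Function):
        # The single output is a pure view of the single input, so replaying
        # the narrow is enough for autograd to rebuild a correct backward.
        @staticmethod
        def forward(ctx, x):
            ctx._is_pure_view = True  # flag exposed via the setter above
            ctx.full_len = x.size(0)
            return x.narrow(0, 0, 2)

        @staticmethod
        def backward(ctx, grad_out):
            grad_in = grad_out.new_zeros(ctx.full_len)
            grad_in[:2] = grad_out
            return grad_in

    base = torch.ones(4, requires_grad=True)
    x = base.clone()   # non-leaf base, mirroring the tests above
    y = FirstTwo.apply(x)
    y += 1             # without the flag this in-place op hits the
                       # "created inside a custom Function" error; with it,
                       # y is treated as an ordinary view and autograd can
                       # replay the view instead of the custom backward.
    y.sum().backward()
    # base.grad == tensor([1., 1., 0., 0.])

Per the test expectations above, the custom backward is not invoked in this in-place case (bw_called stays at 0); it is still used when the output is left unmodified.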