From 24d69c57cbaa94cc828dbbdf83c889f5f244ae28 Mon Sep 17 00:00:00 2001
From: albanD
Date: Wed, 8 Oct 2025 17:14:59 -0400
Subject: [PATCH] Add view support for library custom Function (#164520)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164520
Approved by: https://github.com/soulitzer, https://github.com/ezyang
---
 test/cpp/api/autograd.cpp                     |  36 +-----
 test/test_custom_ops.py                       | 108 +++++++++++++++++-
 torch/_library/custom_ops.py                  |  35 +++---
 torch/_library/utils.py                       |  33 +++++-
 .../autograd_not_implemented_fallback.cpp     |  84 +++++++++++++-
 torch/csrc/utils/python_dispatch.cpp          |  16 ++-
 torch/testing/_internal/two_tensor.py         |   6 +-
 7 files changed, 262 insertions(+), 56 deletions(-)

diff --git a/test/cpp/api/autograd.cpp b/test/cpp/api/autograd.cpp
index 7b6d65ca8e6d..b7e75acb659d 100644
--- a/test/cpp/api/autograd.cpp
+++ b/test/cpp/api/autograd.cpp
@@ -1292,12 +1292,6 @@ torch::Tensor view_op(const torch::Tensor& self) {
   return self.alias();
 }
 
-torch::Tensor view_op_with_extra_arg(
-    const torch::Tensor& self,
-    const torch::Tensor& other) {
-  return self.alias();
-}
-
 std::vector<torch::Tensor> ret_tensor_vector_view(
     const torch::Tensor& self,
     const torch::Tensor& other) {
@@ -1534,35 +1528,9 @@ TEST(TestAutogradNotImplementedFallback, ViewOp) {
 
   // Test inplace on view
   auto t = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
-  // raise on rebase_history when it refreshes grad_fn
-  ASSERT_THROWS_WITH(
-      v1.add_(t), "which does not have a derivative implemented is forbidden");
-  // base should not be aware of the views, so this is still okay
+  // this works as we can properly replay the view given by the user
+  v1.add_(t);
   b1.add_(t);
-  ASSERT_THROWS_WITH(
-      v1.grad_fn(),
-      "which does not have a derivative implemented is forbidden");
-}
-
-TEST(TestAutogradNotImplementedFallback, ViewOpWithExtraArg) {
-  REGISTER_TEST_OP(
-      "view_op_with_extra_arg",
-      "_test::view_op_with_extra_arg(Tensor(a) self, Tensor other) -> Tensor(a)",
-      view_op_with_extra_arg);
-  auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
-      "_test::view_op_with_extra_arg", "");
-  auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
-    return callOpUnboxed<
-        torch::Tensor,
-        const torch::Tensor&,
-        const torch::Tensor&>(opHandle, _1, _2);
-  };
-  assertBasicChecks(op);
-  auto a = torch::tensor({1.}, {torch::kFloat32});
-  auto b = torch::tensor({2.}, {torch::kFloat32});
-  auto out1 = op(a, b);
-  ASSERT_TRUE(out1.is_view());
-  ASSERT_EQ(out1._base().unsafeGetTensorImpl(), a.unsafeGetTensorImpl());
 }
 
 TEST(TestAutogradNotImplementedFallback, RetTensorVectorView) {
diff --git a/test/test_custom_ops.py b/test/test_custom_ops.py
index e84031e6d4fa..4838e73f1f4c 100644
--- a/test/test_custom_ops.py
+++ b/test/test_custom_ops.py
@@ -22,7 +22,6 @@ import torch._custom_ops as custom_ops
 import torch.testing._internal.optests as optests
 import torch.utils._pytree as pytree
 import torch.utils.cpp_extension
-from functorch import make_fx
 from torch import Tensor
 from torch._custom_op.impl import CustomOp, infer_schema
 from torch._library.fake_profile import (
@@ -37,6 +36,7 @@ from torch._library.fake_profile import (
 from torch._library.infer_schema import tuple_to_list
 from torch._library.opaque_object import make_opaque, OpaqueType
 from torch._utils_internal import get_file_path_2  # @manual
+from torch.fx.experimental.proxy_tensor import make_fx
 from torch.fx.experimental.symbolic_shapes import ShapeEnv
 from torch.testing._internal import custom_op_db
 from torch.testing._internal.common_cuda import TEST_CUDA
@@ -57,6 +57,7 @@ from torch.testing._internal.common_utils import (
     TestCase,
 )
 from torch.testing._internal.custom_op_db import numpy_nonzero
+from torch.testing._internal.two_tensor import TwoTensor


 # Shadowed by `torch.testing._internal.common_utils.custom_op`
@@ -2553,6 +2554,111 @@ class TestCustomOpAPI(TestCase):
             sin_(x)
         self.assertEqual(x, expected)
 
+    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
+    def test_subclass_accessor_view_error(self):
+        @torch.library.custom_op(
+            "_torch_testing::_failing_two_tensor_accessor",
+            mutates_args=(),
+            schema="(Tensor(a) tx, SymInt idx) -> Tensor(a)",
+        )
+        def _failing_two_tensor_accessor(tx, idx):
+            return tx.view_as(tx)
+
+        def noop(*args):
+            pass
+
+        _failing_two_tensor_accessor.register_autograd(noop, setup_context=noop)
+
+        t = torch.rand(2)
+        with self.assertRaisesRegex(
+            RuntimeError, "Custom ops that are views do not support SymInt."
+        ):
+            torch.ops._torch_testing._failing_two_tensor_accessor(t, 2)
+
+        @torch.library.custom_op(
+            "_torch_testing::_failing_two_tensor_accessor_list",
+            mutates_args=(),
+            schema="(Tensor(a) tx, SymInt[] idx) -> Tensor(a)",
+        )
+        def _failing_two_tensor_accessor_list(tx, idx):
+            return tx.view_as(tx)
+
+        def noop(*args):
+            pass
+
+        _failing_two_tensor_accessor_list.register_autograd(noop, setup_context=noop)
+
+        t = torch.rand(2)
+        with self.assertRaisesRegex(
+            RuntimeError, "Custom ops that are views do not support SymInt."
+        ):
+            torch.ops._torch_testing._failing_two_tensor_accessor_list(t, (2,))
+
+    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
+    def test_subclass_accessor_view(self):
+        class MyTwoTensor(TwoTensor):
+            @classmethod
+            def __torch_dispatch__(cls, func, types, args, kwargs):
+                if func is torch.ops._torch_testing._two_tensor_accessor.default:
+                    self.assertIsInstance(args[0], MyTwoTensor)
+                    self.assertIn(args[1], (0, 1))
+                    if args[1] == 0:
+                        res = args[0].a
+                    else:
+                        res = args[0].b
+                    # Always return a fresh Tensor!
+                    return res.view_as(res)
+                return super().__torch_dispatch__(func, types, args, kwargs)
+
+        @torch.library.custom_op(
+            "_torch_testing::_two_tensor_accessor",
+            mutates_args=(),
+            schema="(Tensor(a) tx, int idx) -> Tensor(a)",
+        )
+        def _two_tensor_accessor(tx, idx):
+            raise RuntimeError("Should never be called")
+
+        def backward(ctx, gO):
+            gI = gO.clone()
+            if ctx.idx == 0:
+                return MyTwoTensor(gI, torch.zeros_like(gO)), None
+            else:
+                return MyTwoTensor(torch.zeros_like(gO), gI), None
+
+        def setup_ctx(ctx, inputs, output):
+            ctx._is_pure_view = True
+            ctx.idx = inputs[1]
+
+        _two_tensor_accessor.register_autograd(backward, setup_context=setup_ctx)
+
+        x = torch.rand(3)
+        y = torch.rand(3)
+        z = MyTwoTensor(x, y, requires_grad=True)
+        res = torch.ops._torch_testing._two_tensor_accessor(z, 0)
+        res.sum().backward()
+        self.assertEqual(res, x)
+        self.assertTrue(res._is_view())
+        self.assertTrue(res._base is z)
+        self.assertEqual(z.grad, torch.ones_like(z.grad))
+
+        res = torch.ops._torch_testing._two_tensor_accessor(z, 1)
+        res.sum().backward()
+        self.assertEqual(res, y)
+        self.assertTrue(res._is_view())
+        self.assertTrue(res._base is z)
+        self.assertEqual(z.grad, TwoTensor(torch.ones(3), torch.ones(3)))
+
+        leaf = MyTwoTensor(torch.rand(3), torch.rand(3), requires_grad=True)
+        non_leaf = leaf.clone()
+        view_a = torch.ops._torch_testing._two_tensor_accessor(non_leaf, 0)
+        self.assertTrue(view_a._is_view())
+        self.assertTrue(view_a._base is non_leaf)
+        view_a *= 2
+        self.assertEqual(non_leaf.a, view_a)
+        self.assertNotEqual(leaf.a, view_a)
+        non_leaf.sum().backward()
+        self.assertEqual(leaf.grad, MyTwoTensor(2 * torch.ones(3), torch.ones(3)))
+
     @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
     def test_kwarg_only_tensors(self):
         with self.assertRaisesRegex(NotImplementedError, "kwarg-only Tensor args"):
diff --git a/torch/_library/custom_ops.py b/torch/_library/custom_ops.py
index cb9f97d651a4..faa066a987f6 100644
--- a/torch/_library/custom_ops.py
+++ b/torch/_library/custom_ops.py
@@ -348,13 +348,15 @@ class CustomOpDef:
                 fn = self._backend_fns[device_type]
                 return inspect.getmodule(fn)
 
-            utils._c_check_aliasing_constraint(
-                self._name,
-                args,
-                kwargs,
-                result,
-                get_module,
-            )
+            schema = self._opoverload._schema
+            if not schema._is_view_op():
+                utils._c_check_aliasing_constraint(
+                    self._name,
+                    args,
+                    kwargs,
+                    result,
+                    get_module,
+                )
             return result
 
         if device_type is None:
@@ -587,7 +589,7 @@ class CustomOpDef:
 
         """
         schema = self._opoverload._schema
-        if not utils.is_functional_schema(schema):
+        if not utils.is_functional_schema(schema, allow_valid_view=True):
             raise RuntimeError(
                 f"Cannot register autograd formula for non-functional operator "
                 f"{self} with schema {schema}. Please create "
@@ -632,20 +634,27 @@ class CustomOpDef:
         autograd_impl = autograd.make_autograd_impl(self._opoverload, self)
         lib.impl(self._name, autograd_impl, "Autograd", with_keyset=True)
 
-        schema = self._opoverload._schema
+
+        if schema._is_view_op() or schema.is_mutable:
+            lib.m.register_ad_inplace_or_view_fallback(self._name)  # type: ignore[union-attr]
+
         if schema.is_mutable:
             mutated_idxs, mutated_keys = utils.mutated_args_kwargs(schema)
 
+            original_kernel = torch._C._dispatch_get_computed_kernel_for_dispatch_key(
+                f"{lib.ns}::{self._name}", "ADInplaceOrView"
+            )
+
             def adinplaceorview_impl(keyset, *args, **kwargs):
+                # Handle the mutated idxs the user gave us explicitly
                 for idx in mutated_idxs:
                     increment_version(args[idx])
                 for key in mutated_keys:
                     increment_version(kwargs[key])
-                with _C._AutoDispatchBelowADInplaceOrView():
-                    return self._opoverload.redispatch(
-                        keyset & _C._after_ADInplaceOrView_keyset, *args, **kwargs
-                    )
+                # Handle view + mutation that are in the schema
+                return original_kernel.call_boxed(keyset, *args, **kwargs)
 
             lib.impl(
                 self._name,
diff --git a/torch/_library/utils.py b/torch/_library/utils.py
index 8b62619aceb4..edbe86992b6a 100644
--- a/torch/_library/utils.py
+++ b/torch/_library/utils.py
@@ -7,6 +7,7 @@ from typing import Any, Literal, Optional, overload, Union
 
 import torch
 import torch.utils._pytree as pytree
+import torchgen
 from torch import _C, _utils_internal
 from torch._ops import OpOverload
 
@@ -74,12 +75,15 @@ def is_builtin(op: OpOverload) -> bool:
     return op.namespace in {"aten", "prim", "prims"}
 
 
-def is_functional_schema(schema: Any) -> bool:
+def is_functional_schema(schema: Any, *, allow_valid_view: bool = False) -> bool:
     """Check if the schema is functional.
 
     An operator is functional if:
     - it does not mutate any of its inputs
-    - it does not return a view on any of its inputs
+    - If no views are allowed
+      - it does not return a view on any of its inputs
+    - If valid views are allowed
+      - it is either not a view, or it is a view with a single Tensor input and a single Tensor output
    - it has at least one return
     """
@@ -90,8 +94,31 @@ def is_functional_schema(schema: Any) -> bool:
     is_non_mutating_view = len(rets) > 0 and any(
         r.alias_info is not None and not r.alias_info.is_write for r in rets
     )
+    num_tensor_inputs = 0
+    num_tensor_outputs = 0
+
+    if isinstance(schema, torch.FunctionSchema):
+        for arg in schema.arguments:
+            if isinstance(arg.type, torch.TensorType):
+                num_tensor_inputs += 1
+
+        for ret in schema.returns:
+            if isinstance(ret.type, torch.TensorType):
+                num_tensor_outputs += 1
+
+    elif isinstance(schema, torchgen.model.FunctionSchema):
+        for argument in schema.arguments.flat_non_out:
+            if argument.type.is_tensor_like():
+                num_tensor_inputs += 1
+
+        for ret_arg in schema.returns:
+            if ret_arg.type.is_tensor_like():
+                num_tensor_outputs += 1
+
     if is_non_mutating_view:
-        return False
+        return allow_valid_view and (
+            num_tensor_inputs == 1 and num_tensor_outputs == 1
+        )
     if not schema.returns:
         return False
     return True
diff --git a/torch/csrc/autograd/autograd_not_implemented_fallback.cpp b/torch/csrc/autograd/autograd_not_implemented_fallback.cpp
index 2b44dd5905e9..3d4ab7104293 100644
--- a/torch/csrc/autograd/autograd_not_implemented_fallback.cpp
+++ b/torch/csrc/autograd/autograd_not_implemented_fallback.cpp
@@ -478,6 +478,57 @@ torch::CppFunction autogradNotImplementedFallback() {
       &autogradNotImplementedFallbackImpl>();
 }
 
+struct GenericViewFunc : public ViewFunc {
+  GenericViewFunc(
+      torch::jit::Stack non_tensor_stack,
+      size_t aliased_input_idx_val,
+      c10::OperatorHandle op)
+      : non_tensor_stack_(non_tensor_stack),
+        aliased_input_idx_val_(aliased_input_idx_val),
+        op_(op) {
+    // This should report saved Tensors and SymInts.
+    // We already have an assert that ensures there are no Tensors here
+    // by making sure there is only one Tensor input.
+    // We also verify there are no SymInts here for now.
+    // Both can be lifted if the visit and clone logic gets updated.
+    const auto& schema = op_.schema();
+    for (const auto& arg : schema.arguments()) {
+      TORCH_CHECK(
+          arg.real_type()->kind() != c10::TypeKind::SymIntType,
+          "Custom ops that are views do not support SymInt. Please file an issue if you need it.");
+      for (const auto& ct : arg.real_type()->containedTypes()) {
+        TORCH_CHECK(
+            ct->kind() != c10::TypeKind::SymIntType,
+            "Custom ops that are views do not support SymInt. Please file an issue if you need it.");
+      }
+    }
+  }
+
+  at::Tensor operator()(const at::Tensor& new_base) const override {
+    torch::jit::Stack local_stack = non_tensor_stack_;
+    local_stack.at(aliased_input_idx_val_) = c10::IValue(new_base);
+
+    op_.callBoxed(local_stack);
+    auto& result = local_stack[local_stack.size() - 1];
+    TORCH_CHECK(
+        result.isTensor(),
+        "ADInplaceOrView fallback view replay did not return a Tensor");
+    return result.toTensor();
+  }
+
+  std::unique_ptr<ViewFunc> clone_and_set(
+      std::optional<std::vector<c10::SymInt>> = std::nullopt,
+      std::optional<std::vector<at::Tensor>> = std::nullopt) const override {
+    return std::make_unique<GenericViewFunc>(
+        non_tensor_stack_, aliased_input_idx_val_, op_);
+  }
+
+ private:
+  torch::jit::Stack non_tensor_stack_;
+  size_t aliased_input_idx_val_;
+  c10::OperatorHandle op_;
+};
+
 static void autogradNotImplementedInplaceOrViewFallbackImpl(
     const c10::OperatorHandle& op,
     c10::DispatchKeySet dispatch_keys,
@@ -553,6 +604,18 @@ static void autogradNotImplementedInplaceOrViewFallbackImpl(
       "input and the first output (the output can be a vector of tensors). Please change the "
      "order of your operator's parameters so that this is the case.");
   const bool is_view = aliased_input_idx.has_value();
+  size_t aliased_input_idx_val;
+
+  // Save inputs before we redispatch down
+  torch::jit::Stack non_tensor_stack;
+  if (is_view) {
+    // Note that this won't be used if a TensorList is returned.
+    aliased_input_idx_val = aliased_input_idx.value();
+    non_tensor_stack.reserve(num_arguments);
+    for (const auto i : c10::irange(num_arguments)) {
+      non_tensor_stack.push_back((*stack)[stack_start + i]);
+    }
+  }
 
   {
     at::AutoDispatchBelowADInplaceOrView guard;
@@ -608,13 +671,32 @@
         auto result = std::move(aliased_output);
         stack->at(stack->size() - num_returns + aliased_output_idx) = result;
       } else {
+        c10::IValue& aliased_output_iv =
+            (*stack)[stack->size() - num_returns + aliased_output_idx];
         TORCH_CHECK(aliased_output_iv.isTensor());
+        TORCH_CHECK(
+            num_returns == 1,
+            "ADInplaceOrView fallback only supports single output view functions");
+
+        // Remove the Tensor from the original stack
+        for (const auto i : c10::irange(num_arguments)) {
+          if (non_tensor_stack[i].isTensor()) {
+            TORCH_CHECK(
+                i == aliased_input_idx_val,
+                "Internal error in ADInplaceOrView fallback, unknown Tensor in the stack");
+            non_tensor_stack[i] = {};
+          }
+        }
+
+        auto view_func = std::make_unique<GenericViewFunc>(
+            non_tensor_stack, aliased_input_idx_val, op);
+
         auto result = as_view(
             /* base=*/aliased_input,
             /* tensor=*/std::move(aliased_output_iv).toTensor(),
             /* is_bw_differentiable=*/true,
             /* is_fw_differentiable=*/true,
-            /* view_func=*/std::move(erroring_view_func),
+            /* view_func=*/std::move(view_func),
             /* rev_view_func=*/erroring_rev_view_func,
             /* creation_meta=*/
             InferenceMode::is_enabled()
diff --git a/torch/csrc/utils/python_dispatch.cpp b/torch/csrc/utils/python_dispatch.cpp
index 9d6eb35c7178..f97b6ac0ba9b 100644
--- a/torch/csrc/utils/python_dispatch.cpp
+++ b/torch/csrc/utils/python_dispatch.cpp
@@ -16,6 +16,7 @@
 #include
 #include
+#include <torch/csrc/autograd/autograd_not_implemented_fallback.h>
 #include
 #include
 #include
@@ -494,7 +495,20 @@
           "",
           py::arg("dispatch"),
           py::arg("func"),
-          py::arg("with_keyset") = false);
+          py::arg("with_keyset") = false)
+      .def(
+          "register_ad_inplace_or_view_fallback",
+          [](const py::object& self, const char* name) {
+            HANDLE_TH_ERRORS
+            auto& lib = self.cast<torch::Library&>();
+            lib.impl(
+                name,
+                c10::DispatchKey::ADInplaceOrView,
+                torch::autograd::autogradNotImplementedInplaceOrViewFallback());
+            END_HANDLE_TH_ERRORS_PYBIND
+          },
+          "",
+          py::arg("name"));
 
   m.def(
       "_dispatch_library",
diff --git a/torch/testing/_internal/two_tensor.py b/torch/testing/_internal/two_tensor.py
index f0bdbf2d4ef6..3a503c741e88 100644
--- a/torch/testing/_internal/two_tensor.py
+++ b/torch/testing/_internal/two_tensor.py
@@ -9,7 +9,7 @@ from torch.utils._python_dispatch import return_and_correct_aliasing
 # A simple tensor subclass that holds two tensors internally, and runs every op on both tensors.
 class TwoTensor(torch.Tensor):
     @staticmethod
-    def __new__(cls, a, b, outer_size=None, outer_stride=None):
+    def __new__(cls, a, b, outer_size=None, outer_stride=None, *, requires_grad=None):
         if outer_size is None:
             outer_size = a.size()
         if outer_stride is None:
@@ -28,7 +28,7 @@ class TwoTensor(torch.Tensor):
         kwargs["storage_offset"] = a.storage_offset()
         kwargs["device"] = a.device
         kwargs["layout"] = a.layout
-        kwargs["requires_grad"] = a.requires_grad
+        kwargs["requires_grad"] = requires_grad or a.requires_grad
         kwargs["dtype"] = a.dtype
         out = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)
 
@@ -39,7 +39,7 @@ class TwoTensor(torch.Tensor):
 
     @torch._disable_dynamo
     @mark_subclass_constructor_exportable_experimental
-    def __init__(self, a, b, outer_size=None, outer_stride=None):
+    def __init__(self, a, b, outer_size=None, outer_stride=None, *, requires_grad=None):
         self.a = a
         self.b = b
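
Usage note (not part of the patch): with the change above, a torch.library.custom_op whose schema marks the output as aliasing an input (e.g. "(Tensor(a) x) -> Tensor(a)") can return a genuine autograd view once an autograd formula is registered. The following is a minimal sketch adapted from test_subclass_accessor_view in the diff; the "mylib::my_alias" name and the shapes are illustrative only, and setting ctx._is_pure_view mirrors what that test does.

    import torch

    @torch.library.custom_op(
        "mylib::my_alias", mutates_args=(), schema="(Tensor(a) x) -> Tensor(a)"
    )
    def my_alias(x):
        # Backend impl: return a fresh Tensor that aliases the input.
        return x.view_as(x)

    def backward(ctx, grad):
        # Gradient of an identity-like view is just the incoming gradient.
        return grad

    def setup_context(ctx, inputs, output):
        ctx._is_pure_view = True  # as in the test above

    my_alias.register_autograd(backward, setup_context=setup_context)

    x = torch.rand(3, requires_grad=True)
    y = torch.ops.mylib.my_alias(x)
    assert y._is_view() and y._base is x  # the output is now a real view of x
    y.sum().backward()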