Add view support for library custom Function (#164520)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164520
Approved by: https://github.com/soulitzer, https://github.com/ezyang
Author: albanD
Date: 2025-10-08 17:14:59 -04:00
Committed by: PyTorch MergeBot
Parent: eaa02655ea
Commit: 24d69c57cb
7 changed files with 262 additions and 56 deletions

View File

@@ -1292,12 +1292,6 @@ torch::Tensor view_op(const torch::Tensor& self) {
return self.alias();
}
torch::Tensor view_op_with_extra_arg(
const torch::Tensor& self,
const torch::Tensor& other) {
return self.alias();
}
std::vector<torch::Tensor> ret_tensor_vector_view(
const torch::Tensor& self,
const torch::Tensor& other) {
@@ -1534,35 +1528,9 @@ TEST(TestAutogradNotImplementedFallback, ViewOp) {
// Test inplace on view
auto t = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
// raise on rebase_history when it refreshes grad_fn
ASSERT_THROWS_WITH(
v1.add_(t), "which does not have a derivative implemented is forbidden");
// base should not be aware of the views, so this is still okay
// this works as we can properly replay the view given by the user
v1.add_(t);
b1.add_(t);
ASSERT_THROWS_WITH(
v1.grad_fn(),
"which does not have a derivative implemented is forbidden");
}
TEST(TestAutogradNotImplementedFallback, ViewOpWithExtraArg) {
REGISTER_TEST_OP(
"view_op_with_extra_arg",
"_test::view_op_with_extra_arg(Tensor(a) self, Tensor other) -> Tensor(a)",
view_op_with_extra_arg);
auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
"_test::view_op_with_extra_arg", "");
auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
return callOpUnboxed<
torch::Tensor,
const torch::Tensor&,
const torch::Tensor&>(opHandle, _1, _2);
};
assertBasicChecks(op);
auto a = torch::tensor({1.}, {torch::kFloat32});
auto b = torch::tensor({2.}, {torch::kFloat32});
auto out1 = op(a, b);
ASSERT_TRUE(out1.is_view());
ASSERT_EQ(out1._base().unsafeGetTensorImpl(), a.unsafeGetTensorImpl());
}
TEST(TestAutogradNotImplementedFallback, RetTensorVectorView) {

View File

@@ -22,7 +22,6 @@ import torch._custom_ops as custom_ops
import torch.testing._internal.optests as optests
import torch.utils._pytree as pytree
import torch.utils.cpp_extension
from functorch import make_fx
from torch import Tensor
from torch._custom_op.impl import CustomOp, infer_schema
from torch._library.fake_profile import (
@@ -37,6 +36,7 @@ from torch._library.fake_profile import (
from torch._library.infer_schema import tuple_to_list
from torch._library.opaque_object import make_opaque, OpaqueType
from torch._utils_internal import get_file_path_2 # @manual
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch.testing._internal import custom_op_db
from torch.testing._internal.common_cuda import TEST_CUDA
@@ -57,6 +57,7 @@ from torch.testing._internal.common_utils import (
TestCase,
)
from torch.testing._internal.custom_op_db import numpy_nonzero
from torch.testing._internal.two_tensor import TwoTensor
# Shadowed by `torch.testing._internal.common_utils.custom_op`
@@ -2553,6 +2554,111 @@ class TestCustomOpAPI(TestCase):
sin_(x)
self.assertEqual(x, expected)
@skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
def test_subclass_accessor_view_error(self):
@torch.library.custom_op(
"_torch_testing::_failing_two_tensor_accessor",
mutates_args=(),
schema="(Tensor(a) tx, SymInt idx) -> Tensor(a)",
)
def _failing_two_tensor_accessor(tx, idx):
return tx.view_as(tx)
def noop(*args):
pass
_failing_two_tensor_accessor.register_autograd(noop, setup_context=noop)
t = torch.rand(2)
with self.assertRaisesRegex(
RuntimeError, "Custom ops that are views do not support SymInt."
):
torch.ops._torch_testing._failing_two_tensor_accessor(t, 2)
@torch.library.custom_op(
"_torch_testing::_failing_two_tensor_accessor_list",
mutates_args=(),
schema="(Tensor(a) tx, SymInt[] idx) -> Tensor(a)",
)
def _failing_two_tensor_accessor_list(tx, idx):
return tx.view_as(tx)
def noop(*args):
pass
_failing_two_tensor_accessor_list.register_autograd(noop, setup_context=noop)
t = torch.rand(2)
with self.assertRaisesRegex(
RuntimeError, "Custom ops that are views do not support SymInt."
):
torch.ops._torch_testing._failing_two_tensor_accessor_list(t, (2,))
@skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
def test_subclass_accessor_view(self):
class MyTwoTensor(TwoTensor):
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
if func is torch.ops._torch_testing._two_tensor_accessor.default:
self.assertIsInstance(args[0], MyTwoTensor)
self.assertIn(args[1], (0, 1))
if args[1] == 0:
res = args[0].a
else:
res = args[0].b
# Always return a fresh Tensor!
return res.view_as(res)
return super().__torch_dispatch__(func, types, args, kwargs)
@torch.library.custom_op(
"_torch_testing::_two_tensor_accessor",
mutates_args=(),
schema="(Tensor(a) tx, int idx) -> Tensor(a)",
)
def _two_tensor_accessor(tx, idx):
raise RuntimeError("Should never be called")
def backward(ctx, gO):
gI = gO.clone()
if ctx.idx == 0:
return MyTwoTensor(gI, torch.zeros_like(gO)), None
else:
return MyTwoTensor(torch.zeros_like(gO), gI), None
def setup_ctx(ctx, inputs, output):
ctx._is_pure_view = True
ctx.idx = inputs[1]
_two_tensor_accessor.register_autograd(backward, setup_context=setup_ctx)
x = torch.rand(3)
y = torch.rand(3)
z = MyTwoTensor(x, y, requires_grad=True)
res = torch.ops._torch_testing._two_tensor_accessor(z, 0)
res.sum().backward()
self.assertEqual(res, x)
self.assertTrue(res._is_view())
self.assertTrue(res._base is z)
self.assertEqual(z.grad, torch.ones_like(z.grad))
res = torch.ops._torch_testing._two_tensor_accessor(z, 1)
res.sum().backward()
self.assertEqual(res, y)
self.assertTrue(res._is_view())
self.assertTrue(res._base is z)
self.assertEqual(z.grad, TwoTensor(torch.ones(3), torch.ones(3)))
leaf = MyTwoTensor(torch.rand(3), torch.rand(3), requires_grad=True)
non_leaf = leaf.clone()
view_a = torch.ops._torch_testing._two_tensor_accessor(non_leaf, 0)
self.assertTrue(view_a._is_view())
self.assertTrue(view_a._base is non_leaf)
view_a *= 2
self.assertEqual(non_leaf.a, view_a)
self.assertNotEqual(leaf.a, view_a)
non_leaf.sum().backward()
self.assertEqual(leaf.grad, MyTwoTensor(2 * torch.ones(3), torch.ones(3)))
@skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
def test_kwarg_only_tensors(self):
with self.assertRaisesRegex(NotImplementedError, "kwarg-only Tensor args"):
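The tests above exercise the new user-facing pattern end to end: a custom op whose schema carries a Tensor(a) alias annotation, an autograd registration whose setup_context marks the op as a pure view, and a hard error when the schema contains SymInt arguments. Below is a minimal sketch of the same pattern on a plain tensor; the op name mylib::identity_view and the final assertion are illustrative, not part of this PR.

import torch

@torch.library.custom_op(
    "mylib::identity_view",
    mutates_args=(),
    schema="(Tensor(a) x) -> Tensor(a)",
)
def identity_view(x):
    # Always return a fresh Tensor object that aliases the input's storage.
    return x.view_as(x)

def backward(ctx, grad_out):
    return grad_out.clone()

def setup_context(ctx, inputs, output):
    # Mark the op as a pure view so autograd can replay it, mirroring the test above.
    ctx._is_pure_view = True

identity_view.register_autograd(backward, setup_context=setup_context)

base = torch.rand(3, requires_grad=True)
out = torch.ops.mylib.identity_view(base)
# Expected to hold with the new view support, matching the assertions in the tests above.
assert out._is_view() and out._base is base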

View File

@@ -348,13 +348,15 @@ class CustomOpDef:
fn = self._backend_fns[device_type]
return inspect.getmodule(fn)
utils._c_check_aliasing_constraint(
self._name,
args,
kwargs,
result,
get_module,
)
schema = self._opoverload._schema
if not schema._is_view_op():
utils._c_check_aliasing_constraint(
self._name,
args,
kwargs,
result,
get_module,
)
return result
if device_type is None:
@@ -587,7 +589,7 @@
"""
schema = self._opoverload._schema
if not utils.is_functional_schema(schema):
if not utils.is_functional_schema(schema, allow_valid_view=True):
raise RuntimeError(
f"Cannot register autograd formula for non-functional operator "
f"{self} with schema {schema}. Please create "
@@ -632,20 +634,27 @@
autograd_impl = autograd.make_autograd_impl(self._opoverload, self)
lib.impl(self._name, autograd_impl, "Autograd", with_keyset=True)
schema = self._opoverload._schema
if schema._is_view_op() or schema.is_mutable:
lib.m.register_ad_inplace_or_view_fallback(self._name) # type: ignore[union-attr]
if schema.is_mutable:
mutated_idxs, mutated_keys = utils.mutated_args_kwargs(schema)
original_kernel = torch._C._dispatch_get_computed_kernel_for_dispatch_key(
f"{lib.ns}::{self._name}", "ADInplaceOrView"
)
def adinplaceorview_impl(keyset, *args, **kwargs):
# Handle the mutated idx the user gave us explicitly
for idx in mutated_idxs:
increment_version(args[idx])
for key in mutated_keys:
increment_version(kwargs[key])
with _C._AutoDispatchBelowADInplaceOrView():
return self._opoverload.redispatch(
keyset & _C._after_ADInplaceOrView_keyset, *args, **kwargs
)
# Handle view + mutation that are in the schema
return original_kernel.call_boxed(keyset, *args, **kwargs)
lib.impl(
self._name,
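The registration above keys off two schema predicates: schema._is_view_op() (the op returns an alias of one of its inputs) and schema.is_mutable (the op writes to an input). When either holds, the generic ADInplaceOrView fallback is installed; for mutable ops, adinplaceorview_impl additionally bumps the version counters of the user-declared mutates_args before deferring to the original kernel for whatever aliasing the schema itself declares. A small, hedged illustration of the two predicates on parsed schemas (the mylib::* names are placeholders, and the printed values reflect how the branch above treats them):

import torch

view_schema = torch._C.parse_schema("mylib::get_row(Tensor(a) x, int i) -> Tensor(a)")
inplace_schema = torch._C.parse_schema("mylib::sin_(Tensor(a!) x) -> ()")
functional_schema = torch._C.parse_schema("mylib::sin(Tensor x) -> Tensor")

print(view_schema._is_view_op(), view_schema.is_mutable)              # True False
print(inplace_schema._is_view_op(), inplace_schema.is_mutable)        # False True
print(functional_schema._is_view_op(), functional_schema.is_mutable)  # False False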

View File

@@ -7,6 +7,7 @@ from typing import Any, Literal, Optional, overload, Union
import torch
import torch.utils._pytree as pytree
import torchgen
from torch import _C, _utils_internal
from torch._ops import OpOverload
@@ -74,12 +75,15 @@ def is_builtin(op: OpOverload) -> bool:
return op.namespace in {"aten", "prim", "prims"}
def is_functional_schema(schema: Any) -> bool:
def is_functional_schema(schema: Any, *, allow_valid_view: bool = False) -> bool:
"""Check if the schema is functional.
An operator is functional if:
- it does not mutate any of its inputs
- it does not return a view on any of its inputs
- If no views are allowed:
- it does not return a view on any of its inputs
- If valid views are allowed:
- it is either not a view, or it is a view with a single Tensor input and a single Tensor output
- it has at least one return
"""
@@ -90,8 +94,31 @@ def is_functional_schema(schema: Any) -> bool:
is_non_mutating_view = len(rets) > 0 and any(
r.alias_info is not None and not r.alias_info.is_write for r in rets
)
num_tensor_inputs = 0
num_tensor_outputs = 0
if isinstance(schema, torch.FunctionSchema):
for arg in schema.arguments:
if isinstance(arg.type, torch.TensorType):
num_tensor_inputs += 1
for ret in schema.returns:
if isinstance(ret.type, torch.TensorType):
num_tensor_outputs += 1
elif isinstance(schema, torchgen.model.FunctionSchema):
for argument in schema.arguments.flat_non_out:
if argument.type.is_tensor_like():
num_tensor_inputs += 1
for ret_arg in schema.returns:
if ret_arg.type.is_tensor_like():
num_tensor_outputs += 1
if is_non_mutating_view:
return False
return allow_valid_view and (
num_tensor_inputs == 1 and num_tensor_outputs == 1
)
if not schema.returns:
return False
return True
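With allow_valid_view=True, a schema that returns a non-mutating alias still counts as functional as long as it has exactly one Tensor input and one Tensor output, which is the only shape of view op supported here. A quick sketch of the resulting behavior (the op names are placeholders):

import torch
from torch._library.utils import is_functional_schema

single_tensor_view = torch._C.parse_schema("mylib::alias(Tensor(a) x, int i) -> Tensor(a)")
two_tensor_view = torch._C.parse_schema("mylib::pick(Tensor(a) x, Tensor y) -> Tensor(a)")

print(is_functional_schema(single_tensor_view))                         # False: it returns a view
print(is_functional_schema(single_tensor_view, allow_valid_view=True))  # True: one Tensor in, one Tensor out
print(is_functional_schema(two_tensor_view, allow_valid_view=True))     # False: more than one Tensor input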

View File

@@ -478,6 +478,57 @@ torch::CppFunction autogradNotImplementedFallback() {
&autogradNotImplementedFallbackImpl>();
}
struct GenericViewFunc : public ViewFunc {
GenericViewFunc(
torch::jit::Stack non_tensor_stack,
size_t aliased_input_idx_val,
c10::OperatorHandle op)
: non_tensor_stack_(non_tensor_stack),
aliased_input_idx_val_(aliased_input_idx_val),
op_(op) {
// This should report saved Tensors and SymInts.
// We already have an assert that ensures there are no Tensors here
// by making sure there is only one Tensor input.
// We also verify there are no SymInts here for now.
// Both can be lifted if the visit and clone logic is updated.
const auto& schema = op_.schema();
for (const auto& arg : schema.arguments()) {
TORCH_CHECK(
arg.real_type()->kind() != c10::TypeKind::SymIntType,
"Custom ops that are views do not support SymInt. Please file an issue if you need it.");
for (const auto& ct : arg.real_type()->containedTypes()) {
TORCH_CHECK(
ct->kind() != c10::TypeKind::SymIntType,
"Custom ops that are views do not support SymInt. Please file an issue if you need it.");
}
}
}
at::Tensor operator()(const at::Tensor& new_base) const override {
torch::jit::Stack local_stack = non_tensor_stack_;
local_stack.at(aliased_input_idx_val_) = c10::IValue(new_base);
op_.callBoxed(local_stack);
auto& result = local_stack[local_stack.size() - 1];
TORCH_CHECK(
result.isTensor(),
"ADInplaceOrView fallback view replay did not return a Tensor");
return result.toTensor();
}
std::unique_ptr<ViewFunc> clone_and_set(
std::optional<std::vector<c10::SymInt>> = std::nullopt,
std::optional<std::vector<at::Tensor>> = std::nullopt) const override {
return std::make_unique<GenericViewFunc>(
non_tensor_stack_, aliased_input_idx_val_, op_);
}
private:
torch::jit::Stack non_tensor_stack_;
size_t aliased_input_idx_val_;
c10::OperatorHandle op_;
};
static void autogradNotImplementedInplaceOrViewFallbackImpl(
const c10::OperatorHandle& op,
c10::DispatchKeySet dispatch_keys,
@@ -553,6 +604,18 @@ static void autogradNotImplementedInplaceOrViewFallbackImpl(
"input and the first output (the output can be a vector of tensors). Please change the "
"order of your operator's parameters so that this is the case.");
const bool is_view = aliased_input_idx.has_value();
size_t aliased_input_idx_val;
// Save inputs before we redispatch down
torch::jit::Stack non_tensor_stack;
if (is_view) {
// Note that this won't be used if a TensorList is returned.
aliased_input_idx_val = aliased_input_idx.value();
non_tensor_stack.reserve(num_arguments);
for (const auto i : c10::irange(num_arguments)) {
non_tensor_stack.push_back((*stack)[stack_start + i]);
}
}
{
at::AutoDispatchBelowADInplaceOrView guard;
@@ -608,13 +671,32 @@
auto result = std::move(aliased_output);
stack->at(stack->size() - num_returns + aliased_output_idx) = result;
} else {
c10::IValue& aliased_output_iv =
(*stack)[stack->size() - num_returns + aliased_output_idx];
TORCH_CHECK(aliased_output_iv.isTensor());
TORCH_CHECK(
num_returns == 1,
"ADInplaceOrView fallback only support single output view functions");
// Remove the Tensor from the original stack
for (const auto i : c10::irange(num_arguments)) {
if (non_tensor_stack[i].isTensor()) {
TORCH_CHECK(
i == aliased_input_idx_val,
"Internal error in ADInplaceOrView fallback, unknown Tensor in the stack");
non_tensor_stack[i] = {};
}
}
auto view_func = std::make_unique<GenericViewFunc>(
non_tensor_stack, aliased_input_idx_val, op);
auto result = as_view(
/* base=*/aliased_input,
/* tensor=*/std::move(aliased_output_iv).toTensor(),
/* is_bw_differentiable=*/true,
/* is_fw_differentiable=*/true,
/* view_func=*/std::move(erroring_view_func),
/* view_func=*/std::move(view_func),
/* rev_view_func=*/erroring_rev_view_func,
/* creation_meta=*/
InferenceMode::is_enabled()
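Previously the fallback installed erroring_view_func, so any later attempt by autograd to replay the view (for example when an in-place op on the view refreshes its grad_fn) raised the "does not have a derivative implemented is forbidden" error that the C++ test earlier in this commit no longer expects. GenericViewFunc replaces that with a real replay: it keeps the op handle and the non-Tensor arguments of the original call, and when autograd needs the view recomputed it swaps the new base into the saved argument stack and calls the op again. In Python terms the replay amounts to roughly the following (a conceptual sketch only, not a callable API):

import torch

def replay_view(op, saved_args, aliased_input_idx, new_base):
    # saved_args holds the original call arguments with the single Tensor
    # argument blanked out; substitute the new base at that position and
    # re-run the op to produce the refreshed view.
    args = list(saved_args)
    args[aliased_input_idx] = new_base
    result = op(*args)
    assert isinstance(result, torch.Tensor), "view replay must return a Tensor"
    return result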

View File

@@ -16,6 +16,7 @@
#include <c10/core/SafePyObject.h>
#include <torch/csrc/PyInterpreter.h>
#include <torch/csrc/autograd/autograd_not_implemented_fallback.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/utils/tensor_new.h>
@@ -494,7 +495,20 @@ void initDispatchBindings(PyObject* module) {
"",
py::arg("dispatch"),
py::arg("func"),
py::arg("with_keyset") = false);
py::arg("with_keyset") = false)
.def(
"register_ad_inplace_or_view_fallback",
[](const py::object& self, const char* name) {
HANDLE_TH_ERRORS
auto& lib = self.cast<torch::Library&>();
lib.impl(
name,
c10::DispatchKey::ADInplaceOrView,
torch::autograd::autogradNotImplementedInplaceOrViewFallback());
END_HANDLE_TH_ERRORS_PYBIND
},
"",
py::arg("name"));
m.def(
"_dispatch_library",

View File

@@ -9,7 +9,7 @@ from torch.utils._python_dispatch import return_and_correct_aliasing
# A simple tensor subclass that holds two tensors internally, and runs every op on both tensors.
class TwoTensor(torch.Tensor):
@staticmethod
def __new__(cls, a, b, outer_size=None, outer_stride=None):
def __new__(cls, a, b, outer_size=None, outer_stride=None, *, requires_grad=None):
if outer_size is None:
outer_size = a.size()
if outer_stride is None:
@@ -28,7 +28,7 @@ class TwoTensor(torch.Tensor):
kwargs["storage_offset"] = a.storage_offset()
kwargs["device"] = a.device
kwargs["layout"] = a.layout
kwargs["requires_grad"] = a.requires_grad
kwargs["requires_grad"] = requires_grad or a.requires_grad
kwargs["dtype"] = a.dtype
out = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)
@@ -39,7 +39,7 @@
@torch._disable_dynamo
@mark_subclass_constructor_exportable_experimental
def __init__(self, a, b, outer_size=None, outer_stride=None):
def __init__(self, a, b, outer_size=None, outer_stride=None, *, requires_grad=None):
self.a = a
self.b = b
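The new keyword lets a TwoTensor be constructed as a grad-requiring leaf even when its components do not require grad, which is what the view tests above rely on (MyTwoTensor(x, y, requires_grad=True)). For example:

import torch
from torch.testing._internal.two_tensor import TwoTensor

a, b = torch.rand(3), torch.rand(3)
t = TwoTensor(a, b, requires_grad=True)
print(t.requires_grad)  # True: set explicitly via the new keyword
print(a.requires_grad)  # False: the components are untouched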