Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Add view support for library custom Function (#164520)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164520 Approved by: https://github.com/soulitzer, https://github.com/ezyang
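The user-facing pattern this PR enables, sketched minimally from the Python tests added below. The op name mylib::first_row and its backward are hypothetical illustrations, not part of the change: a custom op whose schema carries a Tensor(a) alias annotation may now return a view of its input, provided autograd is registered and the context is flagged with ctx._is_pure_view.

import torch

@torch.library.custom_op(
    "mylib::first_row",  # hypothetical namespace/op name, for illustration only
    mutates_args=(),
    schema="(Tensor(a) x) -> Tensor(a)",
)
def first_row(x):
    return x[0]  # returns a view of the input, allowed because of Tensor(a)

def backward(ctx, grad):
    gx = torch.zeros(ctx.x_shape)
    gx[0] = grad
    return gx

def setup_context(ctx, inputs, output):
    ctx._is_pure_view = True  # tell autograd this op is a pure view
    ctx.x_shape = inputs[0].shape

first_row.register_autograd(backward, setup_context=setup_context)

t = torch.rand(3, 4, requires_grad=True)
out = torch.ops.mylib.first_row(t)
assert out._is_view() and out._base is t
out.sum().backward()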
@ -1292,12 +1292,6 @@ torch::Tensor view_op(const torch::Tensor& self) {
  return self.alias();
}

torch::Tensor view_op_with_extra_arg(
    const torch::Tensor& self,
    const torch::Tensor& other) {
  return self.alias();
}

std::vector<torch::Tensor> ret_tensor_vector_view(
    const torch::Tensor& self,
    const torch::Tensor& other) {
@ -1534,35 +1528,9 @@ TEST(TestAutogradNotImplementedFallback, ViewOp) {
  // Test inplace on view
  auto t = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);

  // raise on rebase_history when it refreshes grad_fn
  ASSERT_THROWS_WITH(
      v1.add_(t), "which does not have a derivative implemented is forbidden");
  // base should not be aware of the views, so this is still okay
  // this works as we can properly replay the view given by the user
  v1.add_(t);
  b1.add_(t);
  ASSERT_THROWS_WITH(
      v1.grad_fn(),
      "which does not have a derivative implemented is forbidden");
}

TEST(TestAutogradNotImplementedFallback, ViewOpWithExtraArg) {
  REGISTER_TEST_OP(
      "view_op_with_extra_arg",
      "_test::view_op_with_extra_arg(Tensor(a) self, Tensor other) -> Tensor(a)",
      view_op_with_extra_arg);
  auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
      "_test::view_op_with_extra_arg", "");
  auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
    return callOpUnboxed<
        torch::Tensor,
        const torch::Tensor&,
        const torch::Tensor&>(opHandle, _1, _2);
  };
  assertBasicChecks(op);
  auto a = torch::tensor({1.}, {torch::kFloat32});
  auto b = torch::tensor({2.}, {torch::kFloat32});
  auto out1 = op(a, b);
  ASSERT_TRUE(out1.is_view());
  ASSERT_EQ(out1._base().unsafeGetTensorImpl(), a.unsafeGetTensorImpl());
}

TEST(TestAutogradNotImplementedFallback, RetTensorVectorView) {
@ -22,7 +22,6 @@ import torch._custom_ops as custom_ops
import torch.testing._internal.optests as optests
import torch.utils._pytree as pytree
import torch.utils.cpp_extension
from functorch import make_fx
from torch import Tensor
from torch._custom_op.impl import CustomOp, infer_schema
from torch._library.fake_profile import (
@ -37,6 +36,7 @@ from torch._library.fake_profile import (
from torch._library.infer_schema import tuple_to_list
from torch._library.opaque_object import make_opaque, OpaqueType
from torch._utils_internal import get_file_path_2  # @manual
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch.testing._internal import custom_op_db
from torch.testing._internal.common_cuda import TEST_CUDA
@ -57,6 +57,7 @@ from torch.testing._internal.common_utils import (
    TestCase,
)
from torch.testing._internal.custom_op_db import numpy_nonzero
from torch.testing._internal.two_tensor import TwoTensor


# Shadowed by `torch.testing._internal.common_utils.custom_op`
@ -2553,6 +2554,111 @@ class TestCustomOpAPI(TestCase):
        sin_(x)
        self.assertEqual(x, expected)

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_subclass_accessor_view_error(self):
        @torch.library.custom_op(
            "_torch_testing::_failing_two_tensor_accessor",
            mutates_args=(),
            schema="(Tensor(a) tx, SymInt idx) -> Tensor(a)",
        )
        def _failing_two_tensor_accessor(tx, idx):
            return tx.view_as(tx)

        def noop(*args):
            pass

        _failing_two_tensor_accessor.register_autograd(noop, setup_context=noop)

        t = torch.rand(2)
        with self.assertRaisesRegex(
            RuntimeError, "Custom ops that are views do not support SymInt."
        ):
            torch.ops._torch_testing._failing_two_tensor_accessor(t, 2)

        @torch.library.custom_op(
            "_torch_testing::_failing_two_tensor_accessor_list",
            mutates_args=(),
            schema="(Tensor(a) tx, SymInt[] idx) -> Tensor(a)",
        )
        def _failing_two_tensor_accessor_list(tx, idx):
            return tx.view_as(tx)

        def noop(*args):
            pass

        _failing_two_tensor_accessor_list.register_autograd(noop, setup_context=noop)

        t = torch.rand(2)
        with self.assertRaisesRegex(
            RuntimeError, "Custom ops that are views do not support SymInt."
        ):
            torch.ops._torch_testing._failing_two_tensor_accessor_list(t, (2,))

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_subclass_accessor_view(self):
        class MyTwoTensor(TwoTensor):
            @classmethod
            def __torch_dispatch__(cls, func, types, args, kwargs):
                if func is torch.ops._torch_testing._two_tensor_accessor.default:
                    self.assertIsInstance(args[0], MyTwoTensor)
                    self.assertIn(args[1], (0, 1))
                    if args[1] == 0:
                        res = args[0].a
                    else:
                        res = args[0].b
                    # Always return a fresh Tensor!
                    return res.view_as(res)
                return super().__torch_dispatch__(func, types, args, kwargs)

        @torch.library.custom_op(
            "_torch_testing::_two_tensor_accessor",
            mutates_args=(),
            schema="(Tensor(a) tx, int idx) -> Tensor(a)",
        )
        def _two_tensor_accessor(tx, idx):
            raise RuntimeError("Should never be called")

        def backward(ctx, gO):
            gI = gO.clone()
            if ctx.idx == 0:
                return MyTwoTensor(gI, torch.zeros_like(gO)), None
            else:
                return MyTwoTensor(torch.zeros_like(gO), gI), None

        def setup_ctx(ctx, inputs, output):
            ctx._is_pure_view = True
            ctx.idx = inputs[1]

        _two_tensor_accessor.register_autograd(backward, setup_context=setup_ctx)

        x = torch.rand(3)
        y = torch.rand(3)
        z = MyTwoTensor(x, y, requires_grad=True)
        res = torch.ops._torch_testing._two_tensor_accessor(z, 0)
        res.sum().backward()
        self.assertEqual(res, x)
        self.assertTrue(res._is_view())
        self.assertTrue(res._base is z)
        self.assertEqual(z.grad, torch.ones_like(z.grad))

        res = torch.ops._torch_testing._two_tensor_accessor(z, 1)
        res.sum().backward()
        self.assertEqual(res, y)
        self.assertTrue(res._is_view())
        self.assertTrue(res._base is z)
        self.assertEqual(z.grad, TwoTensor(torch.ones(3), torch.ones(3)))

        leaf = MyTwoTensor(torch.rand(3), torch.rand(3), requires_grad=True)
        non_leaf = leaf.clone()
        view_a = torch.ops._torch_testing._two_tensor_accessor(non_leaf, 0)
        self.assertTrue(view_a._is_view())
        self.assertTrue(view_a._base is non_leaf)
        view_a *= 2
        self.assertEqual(non_leaf.a, view_a)
        self.assertNotEqual(leaf.a, view_a)
        non_leaf.sum().backward()
        self.assertEqual(leaf.grad, MyTwoTensor(2 * torch.ones(3), torch.ones(3)))

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_kwarg_only_tensors(self):
        with self.assertRaisesRegex(NotImplementedError, "kwarg-only Tensor args"):
@ -348,6 +348,8 @@ class CustomOpDef:
                fn = self._backend_fns[device_type]
                return inspect.getmodule(fn)

            schema = self._opoverload._schema
            if not schema._is_view_op():
                utils._c_check_aliasing_constraint(
                    self._name,
                    args,
@ -587,7 +589,7 @@ class CustomOpDef:

        """
        schema = self._opoverload._schema
        if not utils.is_functional_schema(schema):
        if not utils.is_functional_schema(schema, allow_valid_view=True):
            raise RuntimeError(
                f"Cannot register autograd formula for non-functional operator "
                f"{self} with schema {schema}. Please create "
@ -632,20 +634,27 @@ class CustomOpDef:

        autograd_impl = autograd.make_autograd_impl(self._opoverload, self)
        lib.impl(self._name, autograd_impl, "Autograd", with_keyset=True)

        schema = self._opoverload._schema

        if schema._is_view_op() or schema.is_mutable:
            lib.m.register_ad_inplace_or_view_fallback(self._name)  # type: ignore[union-attr]

        if schema.is_mutable:
            mutated_idxs, mutated_keys = utils.mutated_args_kwargs(schema)

            original_kernel = torch._C._dispatch_get_computed_kernel_for_dispatch_key(
                f"{lib.ns}::{self._name}", "ADInplaceOrView"
            )

            def adinplaceorview_impl(keyset, *args, **kwargs):
                # Handle the mutated idx the user gave us explicitly
                for idx in mutated_idxs:
                    increment_version(args[idx])
                for key in mutated_keys:
                    increment_version(kwargs[key])
                with _C._AutoDispatchBelowADInplaceOrView():
                    return self._opoverload.redispatch(
                        keyset & _C._after_ADInplaceOrView_keyset, *args, **kwargs
                    )
                # Handle view + mutation that are in the schema
                return original_kernel.call_boxed(keyset, *args, **kwargs)

            lib.impl(
                self._name,
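In adinplaceorview_impl above, each argument the user declared as mutated gets its autograd version counter bumped. A quick illustration of the counter this relies on, using only existing tensor attributes (the snippet itself is not part of the diff):

import torch

t = torch.ones(3)
before = t._version
t.add_(1)  # built-in in-place ops bump the version counter automatically
assert t._version == before + 1
# The increment_version calls above give custom mutable ops the same behavior.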
@ -7,6 +7,7 @@ from typing import Any, Literal, Optional, overload, Union

import torch
import torch.utils._pytree as pytree
import torchgen
from torch import _C, _utils_internal
from torch._ops import OpOverload

@ -74,12 +75,15 @@ def is_builtin(op: OpOverload) -> bool:
    return op.namespace in {"aten", "prim", "prims"}


def is_functional_schema(schema: Any) -> bool:
def is_functional_schema(schema: Any, *, allow_valid_view: bool = False) -> bool:
    """Check if the schema is functional.

    An operator is functional if:
    - it does not mutate any of its inputs
    - if no views are allowed:
        - it does not return a view on any of its inputs
    - if valid views are allowed:
        - it is not a view, or it is a view with a single input Tensor and a single output Tensor
    - it has at least one return
    """

@ -90,8 +94,31 @@ def is_functional_schema(schema: Any) -> bool:
    is_non_mutating_view = len(rets) > 0 and any(
        r.alias_info is not None and not r.alias_info.is_write for r in rets
    )
    num_tensor_inputs = 0
    num_tensor_outputs = 0

    if isinstance(schema, torch.FunctionSchema):
        for arg in schema.arguments:
            if isinstance(arg.type, torch.TensorType):
                num_tensor_inputs += 1

        for ret in schema.returns:
            if isinstance(ret.type, torch.TensorType):
                num_tensor_outputs += 1

    elif isinstance(schema, torchgen.model.FunctionSchema):
        for argument in schema.arguments.flat_non_out:
            if argument.type.is_tensor_like():
                num_tensor_inputs += 1

        for ret_arg in schema.returns:
            if ret_arg.type.is_tensor_like():
                num_tensor_outputs += 1

    if is_non_mutating_view:
        return False
        return allow_valid_view and (
            num_tensor_inputs == 1 and num_tensor_outputs == 1
        )
    if not schema.returns:
        return False
    return True
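The counting logic above only accepts a view schema with exactly one Tensor input and one Tensor output. A small sketch of the same checks applied to a hand-written schema string; the op name is made up, while torch._C.parse_schema, alias_info, and TensorType are the attributes the diff itself uses:

import torch

s = torch._C.parse_schema("mylib::first_row(Tensor(a) x) -> Tensor(a)")
ret = s.returns[0]
# A return that aliases an input without writing to it marks a non-mutating view.
assert ret.alias_info is not None and not ret.alias_info.is_write
# allow_valid_view=True additionally requires one Tensor input and one Tensor output.
num_tensor_inputs = sum(isinstance(a.type, torch.TensorType) for a in s.arguments)
num_tensor_outputs = sum(isinstance(r.type, torch.TensorType) for r in s.returns)
assert num_tensor_inputs == 1 and num_tensor_outputs == 1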
@ -478,6 +478,57 @@ torch::CppFunction autogradNotImplementedFallback() {
      &autogradNotImplementedFallbackImpl>();
}

struct GenericViewFunc : public ViewFunc {
  GenericViewFunc(
      torch::jit::Stack non_tensor_stack,
      size_t aliased_input_idx_val,
      c10::OperatorHandle op)
      : non_tensor_stack_(non_tensor_stack),
        aliased_input_idx_val_(aliased_input_idx_val),
        op_(op) {
    // This should report saved Tensors and SymInts.
    // We already have an assert that ensures there are no Tensors here
    // by making sure there is only one Tensor input.
    // We also verify there are no SymInts here for now.
    // Both can be lifted if the visit and clone logic get updated.
    const auto& schema = op_.schema();
    for (const auto& arg : schema.arguments()) {
      TORCH_CHECK(
          arg.real_type()->kind() != c10::TypeKind::SymIntType,
          "Custom ops that are views do not support SymInt. Please file an issue if you need it.");
      for (const auto& ct : arg.real_type()->containedTypes()) {
        TORCH_CHECK(
            ct->kind() != c10::TypeKind::SymIntType,
            "Custom ops that are views do not support SymInt. Please file an issue if you need it.");
      }
    }
  }

  at::Tensor operator()(const at::Tensor& new_base) const override {
    torch::jit::Stack local_stack = non_tensor_stack_;
    local_stack.at(aliased_input_idx_val_) = c10::IValue(new_base);

    op_.callBoxed(local_stack);
    auto& result = local_stack[local_stack.size() - 1];
    TORCH_CHECK(
        result.isTensor(),
        "ADInplaceOrView fallback view replay did not return a Tensor");
    return result.toTensor();
  }

  std::unique_ptr<ViewFunc> clone_and_set(
      std::optional<std::vector<c10::SymInt>> = std::nullopt,
      std::optional<std::vector<at::Tensor>> = std::nullopt) const override {
    return std::make_unique<GenericViewFunc>(
        non_tensor_stack_, aliased_input_idx_val_, op_);
  }

 private:
  torch::jit::Stack non_tensor_stack_;
  size_t aliased_input_idx_val_;
  c10::OperatorHandle op_;
};

static void autogradNotImplementedInplaceOrViewFallbackImpl(
    const c10::OperatorHandle& op,
    c10::DispatchKeySet dispatch_keys,
@ -553,6 +604,18 @@ static void autogradNotImplementedInplaceOrViewFallbackImpl(
      "input and the first output (the output can be a vector of tensors). Please change the "
      "order of your operator's parameters so that this is the case.");
  const bool is_view = aliased_input_idx.has_value();
  size_t aliased_input_idx_val;

  // Save inputs before we redispatch down
  torch::jit::Stack non_tensor_stack;
  if (is_view) {
    // Note that this won't be used if a TensorList is returned.
    aliased_input_idx_val = aliased_input_idx.value();
    non_tensor_stack.reserve(num_arguments);
    for (const auto i : c10::irange(num_arguments)) {
      non_tensor_stack.push_back((*stack)[stack_start + i]);
    }
  }

  {
    at::AutoDispatchBelowADInplaceOrView guard;
@ -608,13 +671,32 @@ static void autogradNotImplementedInplaceOrViewFallbackImpl(
      auto result = std::move(aliased_output);
      stack->at(stack->size() - num_returns + aliased_output_idx) = result;
    } else {
      c10::IValue& aliased_output_iv =
          (*stack)[stack->size() - num_returns + aliased_output_idx];
      TORCH_CHECK(aliased_output_iv.isTensor());
      TORCH_CHECK(
          num_returns == 1,
          "ADInplaceOrView fallback only support single output view functions");

      // Remove the Tensor from the original stack
      for (const auto i : c10::irange(num_arguments)) {
        if (non_tensor_stack[i].isTensor()) {
          TORCH_CHECK(
              i == aliased_input_idx_val,
              "Internal error in ADInplaceOrView fallback, unknown Tensor in the stack");
          non_tensor_stack[i] = {};
        }
      }

      auto view_func = std::make_unique<GenericViewFunc>(
          non_tensor_stack, aliased_input_idx_val, op);

      auto result = as_view(
          /* base=*/aliased_input,
          /* tensor=*/std::move(aliased_output_iv).toTensor(),
          /* is_bw_differentiable=*/true,
          /* is_fw_differentiable=*/true,
          /* view_func=*/std::move(erroring_view_func),
          /* view_func=*/std::move(view_func),
          /* rev_view_func=*/erroring_rev_view_func,
          /* creation_meta=*/
          InferenceMode::is_enabled()
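GenericViewFunc is what lets autograd replay the view when its base changes: it stores the original call's non-Tensor arguments plus the operator handle, and re-runs the op with a replacement base. A conceptual Python sketch of the same idea, using a hypothetical helper rather than the actual C++ implementation:

def make_view_replay(op, args, aliased_input_idx):
    # Keep everything except the single Tensor input (mirrors non_tensor_stack_).
    saved = list(args)
    saved[aliased_input_idx] = None
    def replay(new_base):
        call_args = list(saved)
        call_args[aliased_input_idx] = new_base  # splice in the new base tensor
        return op(*call_args)
    return replay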
@ -16,6 +16,7 @@

#include <c10/core/SafePyObject.h>
#include <torch/csrc/PyInterpreter.h>
#include <torch/csrc/autograd/autograd_not_implemented_fallback.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/utils/tensor_new.h>
@ -494,7 +495,20 @@ void initDispatchBindings(PyObject* module) {
          "",
          py::arg("dispatch"),
          py::arg("func"),
          py::arg("with_keyset") = false);
          py::arg("with_keyset") = false)
      .def(
          "register_ad_inplace_or_view_fallback",
          [](const py::object& self, const char* name) {
            HANDLE_TH_ERRORS
            auto& lib = self.cast<torch::Library&>();
            lib.impl(
                name,
                c10::DispatchKey::ADInplaceOrView,
                torch::autograd::autogradNotImplementedInplaceOrViewFallback());
            END_HANDLE_TH_ERRORS_PYBIND
          },
          "",
          py::arg("name"));

  m.def(
      "_dispatch_library",
@ -9,7 +9,7 @@ from torch.utils._python_dispatch import return_and_correct_aliasing
# A simple tensor subclass that holds two tensors internally, and runs every op on both tensors.
class TwoTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, a, b, outer_size=None, outer_stride=None):
    def __new__(cls, a, b, outer_size=None, outer_stride=None, *, requires_grad=None):
        if outer_size is None:
            outer_size = a.size()
        if outer_stride is None:
@ -28,7 +28,7 @@ class TwoTensor(torch.Tensor):
        kwargs["storage_offset"] = a.storage_offset()
        kwargs["device"] = a.device
        kwargs["layout"] = a.layout
        kwargs["requires_grad"] = a.requires_grad
        kwargs["requires_grad"] = requires_grad or a.requires_grad
        kwargs["dtype"] = a.dtype
        out = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)

@ -39,7 +39,7 @@ class TwoTensor(torch.Tensor):

    @torch._disable_dynamo
    @mark_subclass_constructor_exportable_experimental
    def __init__(self, a, b, outer_size=None, outer_stride=None):
    def __init__(self, a, b, outer_size=None, outer_stride=None, *, requires_grad=None):
        self.a = a
        self.b = b
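The TwoTensor change above adds a keyword-only requires_grad argument so the wrapper can require grad even when its components do not, which the new tests rely on. Usage, assuming the test helper import path shown earlier in the diff:

import torch
from torch.testing._internal.two_tensor import TwoTensor

z = TwoTensor(torch.rand(3), torch.rand(3), requires_grad=True)
assert z.requires_grad and not z.a.requires_grad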