As described in https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk/edit, this PR changes the CustomOp API to be private and adds new public wrappers around it so that the user does not need to know about the "CustomOp" object. We've effectively changed the "CustomOp" object to be some metadata about the operator that the user does not directly interact with. The "updated custom op API" is in torch._custom_ops. Pending good customer feedback, we will promote this module to torch.custom_ops.

NB: I cannot move around the older torch._custom_op APIs yet because people are already using them.

Test Plan:
- I changed all of our tests to use the new `torch._custom_ops` module instead of the old CustomOp API.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105947
Approved by: https://github.com/soulitzer
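For reference, the `torch._custom_ops` surface exercised by the tests below looks roughly like this minimal sketch, assembled from the calls used in this file; the `mylib::mysin` namespace/operator name and the sin-based kernel are illustrative placeholders, not part of the PR.

import torch
import torch._custom_ops as custom_ops

# Declare the operator; the schema is inferred from the type annotations.
@custom_ops.custom_op("mylib::mysin")
def mysin(x: torch.Tensor) -> torch.Tensor:
    raise NotImplementedError()

# Backend implementation (applies to all supported device types when
# device_types is not given).
@custom_ops.impl("mylib::mysin")
def mysin_impl(x):
    return x.sin()

# Abstract (meta) implementation so FakeTensor tracing and torch.compile work.
@custom_ops.impl_abstract("mylib::mysin")
def mysin_abstract(x):
    return torch.empty_like(x)

# Optional autograd support: what to save, and how to map grads back to inputs.
@custom_ops.impl_save_for_backward("mylib::mysin")
def mysin_save_for_backward(inputs, output):
    return inputs.x

@custom_ops.impl_backward("mylib::mysin")
def mysin_backward(ctx, saved, grad):
    return {"x": grad * saved.cos()}

x = torch.randn(3, requires_grad=True)
y = torch.ops.mylib.mysin(x)
y.sum().backward()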
# Owner(s): ["module: custom-operators"]

from torch.testing._internal.common_utils import *  # noqa: F403
from torch.testing._internal.common_device_type import *  # noqa: F403
import collections

import itertools
import re
import typing

import torch._custom_ops as custom_ops

import torch.testing._internal.custom_op_db
from functorch import make_fx
from torch import Tensor
from torch._custom_op.impl import custom_op, CustomOp
from torch.testing._internal.custom_op_db import custom_op_db
from torch.testing._internal.optests.compile_check import operator_compile_check
from typing import *  # noqa: F403


@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
class TestCustomOpTesting(TestCase):
    def setUp(self):
        self.test_ns = "_test_custom_op"
        self.libraries = []

    def tearDown(self):
        import torch._custom_op

        keys = list(torch._custom_op.impl.global_registry.keys())
        for key in keys:
            if not key.startswith(f"{self.test_ns}::"):
                continue
            torch._custom_op.impl.global_registry[key]._destroy()
        if hasattr(torch.ops, self.test_ns):
            del torch.ops._test_custom_op
        for lib in self.libraries:
            del lib.m
        del self.libraries

    def ns(self):
        return getattr(torch.ops, self.test_ns)

    def lib(self):
        result = torch.library.Library(self.test_ns, "FRAGMENT")
        self.libraries.append(result)
        return result

    def get_op(self, qualname):
        ns, name = qualname.split("::")
        return getattr(getattr(torch.ops, ns), name).default

    def test_incorrect_schema_mutation(self, device):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                guard = torch._C._AutoDispatchBelowAutograd()
                try:
                    return op(x)
                finally:
                    del guard

            @staticmethod
            def backward(ctx, gx):
                return gx

        def foo_impl(x):
            x.sin_()
            return x.clone()

        lib.impl("foo", Foo.apply, "Autograd")
        lib.impl("foo", foo_impl, "CPU")
        lib.impl("foo", foo_impl, "CUDA")

        def f(x):
            x = x.clone()
            v = x.view_as(x)
            y = op(v)
            return x

        x = torch.tensor(3.14159 / 3, requires_grad=True, device=device)
        with self.assertRaisesRegex(
            RuntimeError, "Argument x is not defined as mutable but was mutated"
        ):
            operator_compile_check(f, (x,), {})

    def test_incorrect_schema_view(self, device):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                # Emulate AutoDispatchBelowADInplaceOrView, which is not bound into python
                with torch._C._AutoDispatchBelowAutograd():
                    with torch._C._ExcludeDispatchKeyGuard(
                        torch._C.DispatchKeySet(torch._C.DispatchKey.ADInplaceOrView)
                    ):
                        return op(x)

            @staticmethod
            def backward(ctx, gx):
                return gx

        def foo_impl(x):
            return x.view_as(x)

        def foo_meta(x):
            return x.view_as(x)

        lib.impl("foo", Foo.apply, "Autograd")
        lib.impl("foo", foo_impl, "CPU")
        lib.impl("foo", foo_meta, "Meta")

        def f(x):
            x = x.clone()
            y = op(x)
            x.sin_()
            return y

        x = torch.tensor(3.14159 / 3, requires_grad=True)
        with self.assertRaisesRegex(
            RuntimeError, "Argument x is not defined to alias output but was aliasing"
        ):
            operator_compile_check(f, (x,), {})

    def test_missing_abstract_impl(self, device):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                with torch._C._AutoDispatchBelowAutograd():
                    return op(x)

            @staticmethod
            def backward(ctx, gx):
                return 2 * gx

        def foo_impl(x):
            return torch.tensor(x.cpu().numpy() ** 2, device=x.device)

        lib.impl("foo", Foo.apply, "Autograd")
        lib.impl("foo", foo_impl, "CPU")
        lib.impl("foo", foo_impl, "CUDA")

        def f(x):
            y = op(x)
            return y.sum(0)

        x = torch.tensor([0, 1.0], requires_grad=True)
        with self.assertRaisesRegex(
            torch._subclasses.fake_tensor.UnsupportedOperatorException,
            "_test_custom_op.foo.default",
        ):
            operator_compile_check(f, (x,), {})

    def test_incorrect_abstract_impl(self, device):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                # Emulate AutoDispatchBelowADInplaceOrView, which is not bound into python
                guard = torch._C._AutoDispatchBelowAutograd()
                guard2 = torch._C.ExcludeDispatchKeyGuard(
                    torch._C.DispatchKeySet(torch._C.DispatchKey.ADInplaceOrView)
                )
                try:
                    return op(x)
                finally:
                    del guard
                    del guard2

            @staticmethod
            def backward(ctx, gx):
                return gx

        def foo_impl(x):
            return x**2

        def foo_meta(x):
            return x.unsqueeze(1) ** 2

        lib.impl("foo", Foo.apply, "Autograd")
        lib.impl("foo", foo_impl, "CPU")
        lib.impl("foo", foo_impl, "CUDA")
        lib.impl("foo", foo_meta, "Meta")

        def f(x):
            y = op(x)
            return y.sum(0)

        x = torch.tensor([0, 1.0], requires_grad=True)
        with self.assertRaisesRegex(RuntimeError, "Shapes .* are not equal"):
            operator_compile_check(f, (x,), {})

    def test_missing_functionalization(self, device):
        lib = self.lib()
        lib.define("foo(Tensor(a!) x) -> Tensor(a!)")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                ctx.mark_dirty(x)
                with torch._C._AutoDispatchBelowAutograd():
                    return op(x)

            @staticmethod
            def backward(ctx, gx):
                return gx

        def foo_impl(x):
            return x.sin_()

        def foo_meta(x):
            return x

        lib.impl("foo", Foo.apply, "Autograd")
        lib.impl("foo", foo_impl, "CPU")
        lib.impl("foo", foo_impl, "CUDA")
        lib.impl("foo", foo_meta, "Meta")

        def f(x):
            x = x.clone()
            y = op(x)
            return y.sum(0)

        x = torch.tensor([0, 1.0], requires_grad=True)
        with self.assertRaisesRegex(
            RuntimeError,
            "Getting these operators to work with functionalization requires some extra work",
        ):
            operator_compile_check(f, (x,), {})

    def test_autograd_registered_at_backend(self, device):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                return x.clone()

            @staticmethod
            def backward(ctx, gx):
                return gx * 0.5

        lib.impl("foo", Foo.apply, "CPU")
        lib.impl("foo", Foo.apply, "CUDA")
        lib.impl("foo", lambda x: x.clone(), "Meta")

        def f(x):
            y = op(x)
            return x + y

        x = torch.randn([], requires_grad=True)

        with self.assertRaisesRegex(AssertionError, "mismatched requires_grad-ness"):
            operator_compile_check(f, (x,), {})

        # I'm not sure why this is necessary
        del lib

    def test_global_state_mutation(self, device):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            invoked = 0

            @staticmethod
            def forward(ctx, x):
                Foo.invoked += 1
                return x.clone() * Foo.invoked

            @staticmethod
            def backward(ctx, gx):
                return gx

        lib.impl("foo", Foo.apply, "CompositeImplicitAutograd")

        def f(x):
            return op(x)

        x = torch.tensor(3.14159 / 3, requires_grad=True)
        with self.assertRaisesRegex(AssertionError, "not completely traceable"):
            operator_compile_check(f, (x,), {})

    @ops(custom_op_db, dtypes=OpDTypes.any_one)
    def test_operator_compile_check_op(self, device, dtype, op):
        for sample_input in op.sample_inputs(
            device, dtype, requires_grad=op.supports_autograd
        ):
            dynamic_only = op.name in ("NumpyNMSCustomOp", "NumpyNonzeroCustomOp")
            args = [sample_input.input] + list(sample_input.args)
            kwargs = sample_input.kwargs
            operator_compile_check(
                op.op,
                args,
                kwargs,
                supports_autograd=op.supports_autograd,
                dynamic_only=dynamic_only,
                fullgraph=False,  # Dynamo graph breaks on CustomOp today
            )

    def test_operator_compile_check_fails_basic(self, device):
        @custom_op(f"{self.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            ...

        @foo.impl(["cpu", "cuda"])
        def foo_impl(x):
            return x.sum()

        x = torch.randn(3, device=device, requires_grad=True)
        # Triggers the CustomOp autograd NYI error
        with self.assertRaisesRegex(
            RuntimeError, "Autograd has not been implemented for operator"
        ):
            operator_compile_check(
                lambda x: self.get_op(f"{self.test_ns}::foo")(x), (x,), {}
            )

    def test_assert_raises_regex(self, device):
        from torch.testing._internal.optests.aot_autograd import assert_raises_regex

        with assert_raises_regex(RuntimeError, "c"):
            raise RuntimeError("abcd")
        with assert_raises_regex(RuntimeError, "c.*"):
            raise RuntimeError("abcd")
        with self.assertRaisesRegex(AssertionError, "instead got"):
            with assert_raises_regex(RuntimeError, "c.*"):
                raise ValueError("abcd")
        with self.assertRaisesRegex(AssertionError, "Expected exception"):
            with assert_raises_regex(RuntimeError, "c.*"):
                pass
        with self.assertRaisesRegex(AssertionError, "to match regex"):
            with assert_raises_regex(RuntimeError, "f"):
                raise RuntimeError("abcd")


class TestCustomOp(TestCase):
    test_ns = "_test_custom_op"

    def tearDown(self):
        import torch._custom_op

        keys = list(torch._custom_op.impl.global_registry.keys())
        for key in keys:
            if not key.startswith(f"{TestCustomOp.test_ns}::"):
                continue
            torch._custom_op.impl.global_registry[key]._destroy()

    def get_op(self, qualname):
        ns, name = qualname.split("::")
        return getattr(getattr(torch.ops, ns), name).default

    def test_invalid_schemas(self):
        # function schema validation goes through torchgen, so this is just a
        # basic test.
        with self.assertRaisesRegex(AssertionError, "Invalid function schema: foo"):
            custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo", "(")

    def test_name_must_match(self):
        with self.assertRaisesRegex(ValueError, "to have name"):

            @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
            def baz(x: Tensor) -> Tensor:
                raise NotImplementedError()

    def test_unsupported_schemas(self):
        with self.assertRaisesRegex(ValueError, "does not support non-functional"):
            custom_ops.custom_op(
                f"{TestCustomOp.test_ns}::foo", "(Tensor(a!) x) -> Tensor(a)"
            )(foo)
        with self.assertRaisesRegex(ValueError, "does not support view functions"):
            custom_ops.custom_op(
                f"{TestCustomOp.test_ns}::foo", "(Tensor(a) x) -> Tensor(a)"
            )(foo)
        with self.assertRaisesRegex(ValueError, "no outputs"):
            custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo", "(Tensor x) -> ()")(
                foo
            )
        with self.assertRaisesRegex(ValueError, "self"):
            custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo", "(Tensor self) -> ()")(
                foo
            )

    # Tests for the older custom_op API
    def test_schema_matches_signature(self):
        with self.assertRaisesRegex(ValueError, "signature to match"):

            @custom_op(f"{TestCustomOp.test_ns}::blah", "(Tensor y) -> Tensor")
            def blah(x):
                pass

        with self.assertRaisesRegex(ValueError, "signature to match"):

            @custom_op(
                f"{TestCustomOp.test_ns}::blah2", "(Tensor x, *, Tensor y) -> Tensor"
            )
            def blah2(x, y):
                pass

        with self.assertRaisesRegex(ValueError, "signature to match"):

            @custom_op(
                f"{TestCustomOp.test_ns}::blah3",
                "(Tensor x, *, Tensor w, Tensor z) -> Tensor",
            )
            def blah3(x, *, y, z):
                pass

        with self.assertRaisesRegex(ValueError, "signature to match"):

            @custom_op(
                f"{TestCustomOp.test_ns}::blah4",
                "(Tensor x, *, Tensor z, Tensor y) -> Tensor",
            )
            def blah4(x, *, y, z):
                pass

        with self.assertRaisesRegex(ValueError, "not supported"):

            @custom_op(f"{TestCustomOp.test_ns}::blah5", "(Tensor x) -> Tensor")
            def blah5(*args):
                pass

        with self.assertRaisesRegex(ValueError, "not supported"):

            @custom_op(
                f"{TestCustomOp.test_ns}::blah6", "(*, Tensor z, Tensor y) -> Tensor"
            )
            def blah6(**kwargs):
                pass

        with self.assertRaisesRegex(ValueError, "default arguments"):

            @custom_op(
                f"{TestCustomOp.test_ns}::blah7", "(Tensor x, *, Tensor y) -> Tensor"
            )
            def blah7(x=1, *, y):
                pass

        with self.assertRaisesRegex(ValueError, "default arguments"):

            @custom_op(
                f"{TestCustomOp.test_ns}::blah8", "(Tensor x, *, Tensor y) -> Tensor"
            )
            def blah8(x, *, y=1):
                pass

        # kwonly-arg works
        @custom_op(
            f"{TestCustomOp.test_ns}::blah9", "(Tensor x, *, Tensor y) -> Tensor"
        )
        def blah9(x, *, y):
            pass

    # Tests for the older custom_op API
    def test_unsupported_annotation_categories(self):
        with self.assertRaisesRegex(ValueError, "varargs"):

            @custom_op(f"{TestCustomOp.test_ns}::foo")
            def foo(*args):
                raise NotImplementedError()

            del foo

        with self.assertRaisesRegex(ValueError, "varkwargs"):

            @custom_op(f"{TestCustomOp.test_ns}::foo")
            def foo(**kwargs):
                raise NotImplementedError()

            del foo

        with self.assertRaisesRegex(ValueError, "must have a type annotation"):

            @custom_op(f"{TestCustomOp.test_ns}::foo")
            def foo(x):
                raise NotImplementedError()

            del foo

        with self.assertRaisesRegex(ValueError, "default value"):

            @custom_op(f"{TestCustomOp.test_ns}::foo")
            def foo(x: Optional[Tensor] = None):
                raise NotImplementedError()

            del foo

        with self.assertRaisesRegex(ValueError, "default value"):

            @custom_op(f"{TestCustomOp.test_ns}::foo")
            def foo(x: Optional[Tensor] = None):
                raise NotImplementedError()

            del foo

        with self.assertRaisesRegex(ValueError, "either Tensor or a Tuple"):

            @custom_op(f"{TestCustomOp.test_ns}::foo")
            def foo(x: Tensor) -> int:
                raise NotImplementedError()

            del foo

        with self.assertRaisesRegex(ValueError, "either Tensor or a Tuple"):

            @custom_op(f"{TestCustomOp.test_ns}::foo")
            def foo(x: Tensor) -> Tuple[Tensor, int]:
                raise NotImplementedError()

            del foo

        with self.assertRaisesRegex(ValueError, "either Tensor or a Tuple"):

            @custom_op(f"{TestCustomOp.test_ns}::foo")
            def foo(x: Tensor) -> Tuple[Tensor, ...]:
                raise NotImplementedError()

            del foo

    def test_supported_param_types(self):
        def generate_examples(typ):
            if typ is int:
                return [17]
            if typ is float:
                return [3.14]
            if typ is bool:
                return [True]
            if typ is str:
                return ["foo"]
            if typ is torch.dtype:
                return [torch.float32]
            if typ is torch.device:
                return [torch.device("cpu")]
            if typ == torch.types.Number:
                return [2.718]
            if typ is torch.Tensor:
                return [torch.tensor(3)]
            if typ == Optional[torch.types.Number]:
                return [None, 2.718]
            origin = typing.get_origin(typ)
            if origin is Union:
                args = typing.get_args(typ)
                assert len(args) == 2 and (
                    args[0] is type(None) or args[1] is type(None)
                )
                elt = args[0] if args[1] is type(None) else args[1]
                return generate_examples(elt) + [None]
            if origin is collections.abc.Sequence:
                args = typing.get_args(typ)
                assert len(args) == 1
                examples = generate_examples(args[0])
                return list(itertools.product(examples, examples)) + []
            raise AssertionError(f"unsupported param type {typ}")

        for typ in torch._custom_op.impl.SUPPORTED_PARAM_TYPES:

            @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
            def foo(x: Tensor, y: typ) -> Tensor:
                raise NotImplementedError()

            yeet = None

            @custom_ops.impl(f"{TestCustomOp.test_ns}::foo", device_types=["cpu"])
            def foo_cpu(x, y):
                nonlocal yeet
                yeet = y
                return x.clone()

            try:
                for example in generate_examples(typ):
                    op = self.get_op(f"{self.test_ns}::foo")
                    op(torch.randn([]), example)
                    self.assertEqual(yeet, example, msg=f"{typ} {example}")
                    yeet = None
            finally:
                custom_ops._destroy(f"{TestCustomOp.test_ns}::foo")

    def test_sequences(self):
        # Sequence[int] gets automagically turned into int[] in the schema.
        # This test checks that we actually do support arbitrary sequence types.
        class MySequence(collections.abc.Sequence):
            def __init__(self):
                self._container = [1, 2, 3]

            def __getitem__(self, idx):
                return self._container[idx]

            def __len__(self):
                return len(self._container)

        @custom_ops.custom_op(f"{self.test_ns}::foo")
        def foo(x: torch.Tensor, sizes: Sequence[int]) -> torch.Tensor:
            raise NotImplementedError()

        called = 0

        @custom_ops.impl(f"{self.test_ns}::foo", device_types="cpu")
        def foo_cpu(x, sizes):
            nonlocal called
            called += 1
            # Dispatcher will normalize the sequence type into a List
            self.assertEqual(sizes, [1, 2, 3])
            return x.clone()

        x = torch.randn([])
        seq = MySequence()
        op = self.get_op(f"{self.test_ns}::foo")
        op(x, seq)
        self.assertEqual(called, 1)

    def test_unsupported_param_types(self):
        # Not comprehensive (it doesn't need to be), just a check that our mechanism works
        with self.assertRaisesRegex(ValueError, "unsupported type"):

            @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
            def foo(x: Tensor, y: List[Optional[int]]) -> Tensor:
                raise NotImplementedError()

            del foo

        with self.assertRaisesRegex(ValueError, "unsupported type"):
            # int[N] in Dispatcher is a bit wild, so we don't try to support it.
            @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
            def foo(x: Tensor, y: Tuple[int, int]) -> Tensor:
                raise NotImplementedError()

            del foo

        with self.assertRaisesRegex(ValueError, "unsupported type"):
            # We could theoretically support this, but the syntax for supporting
            # int[] is Sequence[int]
            @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
            def foo(x: Tensor, y: List[int]) -> Tensor:
                raise NotImplementedError()

            del foo

        with self.assertRaisesRegex(ValueError, "unsupported type"):

            @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
            def foo(x: Tensor, y: Callable) -> Tensor:
                raise NotImplementedError()

            del foo

    def test_supported_schemas(self):
        # All of these should already be tested by PyTorch codegen
        # (we share the same mechanism), but here's a sanity check.
        schemas = [
            "(Tensor x) -> Tensor",
            "(Tensor x) -> Tensor y",
            "(Tensor[] x) -> Tensor y",
            "(Tensor x) -> (Tensor, Tensor)",
            "(Tensor x) -> (Tensor y, Tensor z)",
            "(Tensor x) -> (Tensor y, Tensor z)",
        ]
        other_schemas = [
            "(Tensor x, Tensor w) -> (Tensor y, Tensor z)",
            "(Tensor x, Tensor w) -> (Tensor, Tensor)",
            "(Tensor x, Tensor w) -> Tensor",
            "(Tensor? x, Tensor w) -> Tensor",
            "(Tensor? x, Tensor[] w) -> Tensor",
            "(Tensor x, int[] w) -> Tensor",
            "(Tensor x, SymInt[] w) -> Tensor",
            "(Tensor x, Scalar w) -> Tensor",
            "(Tensor x, float w) -> Tensor",
            "(Tensor x, float? w) -> Tensor",
            "(Tensor x, bool[] w) -> Tensor",
        ]

        for schema in schemas:
            custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo", schema)
            custom_ops._destroy(f"{TestCustomOp.test_ns}::foo")
        for schema in other_schemas:
            custom_ops.custom_op(f"{TestCustomOp.test_ns}::bar", schema)
            custom_ops._destroy(f"{TestCustomOp.test_ns}::bar")

    def test_reserved_ns(self):
        from torch._custom_op.impl import RESERVED_NS

        for ns in RESERVED_NS:
            with self.assertRaisesRegex(ValueError, "is a reserved namespace"):
                custom_ops.custom_op(f"{ns}::foo", "(Tensor x) -> Tensor")

            with self.assertRaisesRegex(ValueError, "is a reserved namespace"):

                @custom_ops.custom_op(f"{ns}::foo2")
                def foo2(x: torch.Tensor) -> torch.Tensor:
                    raise NotImplementedError()

    def test_private_ctor(self):
        with self.assertRaisesRegex(RuntimeError, "CustomOp constructor is private"):
            CustomOp(None, None, None, None, None)

    def test_lifetime(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError()

        custom_op = torch._custom_op.impl.get_op(f"{TestCustomOp.test_ns}::foo")

        # We can't define an op multiple times,
        with self.assertRaisesRegex(RuntimeError, "multiple times"):

            @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
            def foo(x: torch.Tensor) -> torch.Tensor:  # noqa: F811
                raise NotImplementedError()

        # Unless we delete the original op.
        custom_ops._destroy(f"{TestCustomOp.test_ns}::foo")

        # Smoke test
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:  # noqa: F811
            raise NotImplementedError()

        custom_ops._destroy(f"{TestCustomOp.test_ns}::foo")

    def test_autograd_notimplemented(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:  # noqa: F811
            raise NotImplementedError()

        x = torch.randn(3, requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        with self.assertRaisesRegex(RuntimeError, "Autograd has not been implemented"):
            op(x)
        custom_ops._destroy(f"{TestCustomOp.test_ns}::foo")
        del foo

        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: Sequence[torch.Tensor]) -> torch.Tensor:
            raise NotImplementedError()

        x = torch.randn(3, requires_grad=True)
        y = torch.randn(3)
        op = self.get_op(f"{self.test_ns}::foo")
        with self.assertRaisesRegex(RuntimeError, "Autograd has not been implemented"):
            op([y, x])
        custom_ops._destroy(f"{TestCustomOp.test_ns}::foo")
        del foo

        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError()

        x = torch.randn(3, requires_grad=True)
        y = torch.randn(3)
        op = self.get_op(f"{self.test_ns}::foo")
        with self.assertRaisesRegex(RuntimeError, "Autograd has not been implemented"):
            op(y, x)
        custom_ops._destroy(f"{TestCustomOp.test_ns}::foo")

    def test_autograd_notimplemented_gradmode(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError()

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(x, y):
            return x * y

        x = torch.randn(3, requires_grad=True)
        y = torch.randn(3)
        op = self.get_op(f"{self.test_ns}::foo")
        with torch.no_grad():
            # Shouldn't raise, because we are in no_grad
            op(y, x)

    def test_impl_cpu(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError()

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo", device_types="cpu")
        def foo_cpu(x):
            return x.sin()

        x = torch.randn(3)
        op = self.get_op(f"{self.test_ns}::foo")
        result = op(x)
        self.assertEqual(result, foo_cpu(x))

    def test_impl_invalid_devices(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError()

        def foo_impl(x):
            return x.sin()

        from torch._custom_op.impl import SUPPORTED_DEVICE_TYPE_TO_KEY

        for device_type in SUPPORTED_DEVICE_TYPE_TO_KEY.keys():
            # Smoke test: should not raise error
            custom_ops.impl(f"{TestCustomOp.test_ns}::foo", device_types=device_type)(
                foo_impl
            )

        # Not supported by this API: we can either support them in the future
        # or provide some other CustomOp.def_* function. This depends on how
        # common the use cases are.
        for invalid_type in ["hip", "xla", "mkldnn", ["cpu", "hip"]]:
            with self.assertRaisesRegex(ValueError, "we only support device_type"):
                custom_ops.impl(
                    f"{TestCustomOp.test_ns}::foo", device_types=invalid_type
                )(foo_impl)

    def test_backward_partially_registered(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError()

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(x):
            return x.sin()

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return grad * saved.cos()

        x = torch.randn([], requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        with self.assertRaisesRegex(
            RuntimeError, "unable to find a 'save_for_backward'"
        ):
            y = op(x)
            y.backward()

    def test_save_for_backward_inputs_are_namedtuple(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError()

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(x):
            return x.sin()

        hit = 0

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):
            nonlocal hit
            hit += 1
            self.assertTrue(isinstance(inputs, tuple))
            self.assertEqual(list(inputs._asdict().keys()), ["x"])
            return inputs.x

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return {"x": grad * saved.cos()}

        x = torch.randn([], requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        y = op(x)
        self.assertEqual(hit, 1)
        y.backward()
        self.assertEqual(hit, 1)

    def test_backward_returns_dict(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError()

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(x):
            return x.sin()

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):
            return inputs.x

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return grad * saved.cos()

        x = torch.randn([], requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        y = op(x)
        with self.assertRaisesRegex(RuntimeError, "to be a dict"):
            y.backward()

    def test_backward_dict_invalid_keys(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError()

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(x):
            return x.sin()

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):
            return inputs.x

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return {"x": grad * saved.cos(), "y": None}

        x = torch.randn([], requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        y = op(x)
        with self.assertRaisesRegex(RuntimeError, "to have keys {'x'}"):
            y.backward()

    def test_backward_dict_grad_for_nontensor(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor, dim: int) -> torch.Tensor:
            raise NotImplementedError()

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(x, dim):
            return x.sin()

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):
            return inputs.x

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return {"x": grad * saved.cos(), "dim": None}

        x = torch.randn([], requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        y = op(x, 32)
        with self.assertRaisesRegex(RuntimeError, "non-Tensor-like types"):
            y.backward()

    def test_backward_dict_requires_keys_for_input_tensors(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError()

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(x, y):
            return x.sin()

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):
            return inputs.x

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return {"x": grad * saved.cos()}

        x = torch.randn([], requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        y = op(x, x)
        with self.assertRaisesRegex(RuntimeError, r"to have keys {.*'y'.*}"):
            y.backward()

    def test_backward_dict_requires_keys_for_input_optional_tensors(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor, y: Optional[torch.Tensor]) -> torch.Tensor:
            raise NotImplementedError()

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(x, y):
            return x.sin()

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):
            return inputs.x

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return {"x": grad * saved.cos()}

        x = torch.randn([], requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        y = op(x, None)
        with self.assertRaisesRegex(RuntimeError, r"to have keys {.*'y'.*}"):
            y.backward()

    def test_backward_grads_are_tensor_or_none(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError()

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(x):
            return x.sin()

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):
            return inputs.x

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return {"x": (grad * saved.cos(),)}

        x = torch.randn([], requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        y = op(x)
        with self.assertRaisesRegex(RuntimeError, "either None or a Tensor"):
            y.backward()

    def test_backward_tensorlist_input_requires_list_grads_with_same_numel(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor:
            raise NotImplementedError()

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(xs):
            return xs[0].sin()

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):
            return inputs.xs[0]

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return {"xs": [grad * saved.cos(), None]}

        xs = [torch.randn([], requires_grad=True) for _ in range(3)]
        op = self.get_op(f"{self.test_ns}::foo")
        y = op(xs)
        with self.assertRaisesRegex(RuntimeError, "3 gradients but got 2"):
            y.backward()

    def test_backward_tensorlist_input_requires_list_grads_none_or_Tensor(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor:
            raise NotImplementedError()

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(xs):
            return xs[0].sin()

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):
            return inputs.xs[0]

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return {"xs": [grad * saved.cos(), None, (None,)]}

        xs = [torch.randn([], requires_grad=True) for _ in range(3)]
        op = self.get_op(f"{self.test_ns}::foo")
        y = op(xs)
        with self.assertRaisesRegex(RuntimeError, "None or Tensor"):
            y.backward()

    def test_backward_tensorlist_input_requires_list_grads(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor:
            raise NotImplementedError()

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(xs):
            return xs[0].sin()

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):
            return inputs.xs[0]

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return {"xs": None}

        xs = [torch.randn([], requires_grad=True) for _ in range(3)]
        op = self.get_op(f"{self.test_ns}::foo")
        y = op(xs)
        with self.assertRaisesRegex(RuntimeError, "list of gradients"):
            y.backward()

    def test_backward_output_differentiability_type(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor:
            raise NotImplementedError()

        with self.assertRaisesRegex(RuntimeError, "output_differentiability"):

            @custom_ops.impl_backward(
                f"{TestCustomOp.test_ns}::foo", output_differentiability=True
            )
            def foo_backward(ctx, saved, grad):
                return {"xs": None}

    def test_backward_output_differentiability_numel(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(xs: Sequence[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
            raise NotImplementedError()

        with self.assertRaisesRegex(RuntimeError, "output_differentiability"):

            @custom_ops.impl_backward(
                f"{TestCustomOp.test_ns}::foo", output_differentiability=[True]
            )
            def foo_backward(ctx, saved, grad):
                return {"xs": None}

    @unittest.skipIf(not TEST_CUDA, "requires CUDA")
    def test_impl_separate(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError()

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo", device_types="cpu")
        def foo_cpu(x):
            return x.sin()

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo", device_types="cuda")
        def foo_cuda(x):
            return x.cos()

        x = torch.randn(3)
        op = self.get_op(f"{self.test_ns}::foo")
        result = op(x)
        self.assertEqual(result, foo_cpu(x))

        x_cuda = x.cuda()
        op = self.get_op(f"{self.test_ns}::foo")
        result = op(x_cuda)
        self.assertEqual(result, foo_cuda(x_cuda))

    @unittest.skipIf(not TEST_CUDA, "requires CUDA")
    def test_impl_multiple(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError()

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(x):
            return x.cos()

        op = self.get_op(f"{self.test_ns}::foo")
        x = torch.randn(3)
        result = op(x)
        self.assertEqual(result, foo_impl(x))

        x_cuda = x.cuda()
        result = op(x_cuda)
        self.assertEqual(result, foo_impl(x_cuda))

    def test_impl_meta(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor, dim: int) -> torch.Tensor:
            raise NotImplementedError()

        @custom_ops.impl_abstract(f"{TestCustomOp.test_ns}::foo")
        def foo_meta(x, dim):
            output_shape = list(x.shape)
            del output_shape[dim]
            return x.new_empty(output_shape)

        x = torch.randn(2, 3, device="meta")
        op = self.get_op(f"{self.test_ns}::foo")
        result = op(x, 1)
        self.assertEqual(result.shape, foo_meta(x, 1).shape)

    def test_duplicate_impl(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor, dim: int) -> torch.Tensor:
            raise NotImplementedError()

        @custom_ops.impl_abstract(f"{TestCustomOp.test_ns}::foo")
        def foo_meta(x, dim):
            output_shape = list(x.shape)
            del output_shape[dim]
            return x.new_empty(output_shape)

        with self.assertRaisesRegex(
            RuntimeError, r"already has a abstract impl.*at .*test_custom_ops.py:\d+"
        ):

            @custom_ops.impl_abstract(f"{TestCustomOp.test_ns}::foo")
            def foo_meta2(x, dim):
                output_shape = list(x.shape)
                del output_shape[dim]
                return x.new_empty(output_shape)

    def test_new_data_dependent_symint(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError()

        @custom_ops.impl_abstract(f"{TestCustomOp.test_ns}::foo")
        def foo_meta(x):
            ctx = torch._custom_op.impl.get_ctx()
            with self.assertRaisesRegex(ValueError, "greater than or equal to 2"):
                ctx.create_unbacked_symint(min=1)
            with self.assertRaisesRegex(ValueError, "greater than or equal to 2"):
                ctx.create_unbacked_symint(min=-1)
            with self.assertRaisesRegex(ValueError, "SymInt"):
                ctx.create_unbacked_symint(max=x.numel())
            return torch.clone(x)

        x = torch.randn(2, 3, device="cpu")
        op = self.get_op(f"{self.test_ns}::foo")
        make_fx(op, tracing_mode="symbolic")(x)

    def test_meta_for_data_dependent_shape_operation(self):
        x = torch.randn(10, device="meta")
        with self.assertRaisesRegex(RuntimeError, "data-dependent output shape"):
            torch.ops._torch_testing.numpy_nonzero(x)

    def test_basic_make_fx(self):
        # More serious tests are in our CustomOp opinfo db,
        # this one is just a sanity check.
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError()

        @custom_ops.impl_abstract(f"{TestCustomOp.test_ns}::foo")
        def foo_meta(x):
            return x.sum()

        x = torch.randn(3)
        op = self.get_op(f"{self.test_ns}::foo")
        gm = make_fx(op, tracing_mode="symbolic")(x)
        self.assertTrue(f"{TestCustomOp.test_ns}.foo" in gm.code)

    def test_not_implemented_error(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError()

        x = torch.randn(3)
        op = self.get_op(f"{self.test_ns}::foo")
        with self.assertRaisesRegex(NotImplementedError, "cpu impl registered"):
            op(x)

        x = torch.randn(3, device="meta")
        with self.assertRaisesRegex(NotImplementedError, "abstract impl registered"):
            op(x)

        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::bar")
        def bar(sizes: Sequence[int]) -> torch.Tensor:
            raise NotImplementedError()

        op = self.get_op(f"{self.test_ns}::bar")
        with self.assertRaisesRegex(NotImplementedError, "no Tensor inputs"):
            op((1, 2, 3))

    def test_abstract_registration_location(self):
        custom_op = torch._custom_op.impl._find_custom_op(
            "_torch_testing::numpy_nonzero"
        )
        loc = custom_op._get_impl("abstract").location
        matches = re.match(r".*custom_op_db.py:\d+", loc)
        self.assertIsNotNone(matches)

    def test_data_dependent_basic(self):
        def f(x):
            return torch.ops._torch_testing.numpy_nonzero(x)

        x = torch.randn(5, 5)
        gm = make_fx(f, tracing_mode="symbolic")(x)
        self.assertTrue("nonzero" in gm.code)

    def test_data_dependent_fake_tracing(self):
        def f(x):
            return torch.ops._torch_testing.numpy_nonzero(x)

        x = torch.randn(5, 5)
        with self.assertRaises(
            torch._subclasses.fake_tensor.DynamicOutputShapeException
        ):
            make_fx(f, tracing_mode="fake")(x)

    def test_symints(self):
        def f(x):
            return torch.ops._torch_testing.numpy_view_copy(x, x.shape)

        x = torch.randn(2, 3, 4)
        gm = make_fx(f, tracing_mode="symbolic")(x)
        result = gm(x)
        self.assertEqual(result, f(x))
        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, x_1):
    sym_size = torch.ops.aten.sym_size(x_1, 0)
    sym_size_1 = torch.ops.aten.sym_size(x_1, 1)
    sym_size_2 = torch.ops.aten.sym_size(x_1, 2)
    numpy_view_copy = torch.ops._torch_testing.numpy_view_copy.default(x_1, [sym_size, sym_size_1, sym_size_2]); x_1 = sym_size = sym_size_1 = sym_size_2 = None
    return numpy_view_copy""",  # noqa: B950
        )

    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work on windows")
    def test_data_dependent_compile(self):
        import torch._dynamo.testing
        from torch._dynamo.utils import counters

        counters.clear()
        cnt = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnt)
        def f(x):
            return torch.ops._torch_testing.numpy_nonzero(x.clone()).clone()

        f(torch.randn(10))

        self.assertEqual(
            dict(counters["graph_break"]),
            {"dynamic shape operator: _torch_testing.numpy_nonzero.default": 1},
        )

    # pre-existing problem: torch.compile(dynamic=True) will, by default,
    # graph break on data-dependent operations. Eventually we'll make it so
    # that it never graph breaks on data-dependent operations.
    @unittest.expectedFailure
    def test_data_dependent_nms_dynamic_compile(self):
        import torch._dynamo.testing
        from torch._dynamo.utils import counters

        counters.clear()
        cnt = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnt, dynamic=True)
        def f(x, s, i):
            return torch.ops._torch_testing.numpy_nms(x.clone(), s, i).clone()

        f(torch.randn(20, 4), torch.randn(20), 0.1)

        self.assertEqual(len(counters["graph_break"]), 0)

only_for = ("cpu", "cuda")
instantiate_device_type_tests(TestCustomOpTesting, globals(), only_for=only_for)

if __name__ == "__main__":
    run_tests()