Add torch.library.custom_op (#122344)

This is the entrypoint for defining an opaque/blackbox custom op (i.e. one
that PyTorch will never peek into). In this PR, you can specify backend impls
and the abstract impl for this op.

NB: most of this PR is docstrings; please don't be intimidated by the
line count.

There are a number of interesting features (see the usage sketch after this
list):
- schema inference: we infer the schema from the function's type hints. In a
  followup I add the ability to manually specify a schema.
- naming: for now, the user needs to manually specify an op name. In a
  followup we add the ability to automatically infer a name (this is a
  little tricky).
- custom_op registrations can override each other. This makes them
  more pleasant to work with in environments like Colab.
- we require that the outputs of the custom_op do not alias any inputs
  or each other. We enforce this via a runtime check, but can relax this
  into an opcheck test in the future if it really matters.
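
A minimal sketch of the intended usage, based on the tests added below.
The op name mylib::numpy_add is hypothetical; register_impl, register_fake,
and the mutated_args keyword are the hooks this PR's tests exercise:

    import torch
    from torch import Tensor

    # The schema "(Tensor x, float y) -> Tensor" is inferred from the
    # type hints; PyTorch treats the body as an opaque blackbox.
    @torch.library.custom_op("mylib::numpy_add", mutated_args=())
    def numpy_add(x: Tensor, y: float) -> Tensor:
        return torch.from_numpy(x.numpy(force=True) + y).to(x.device)

    # Backend-specific impl: overrides the default implementation on CPU.
    @numpy_add.register_impl("cpu")
    def _(x: Tensor, y: float) -> Tensor:
        return torch.from_numpy(x.numpy() + y)

    # Abstract (fake) impl so torch.compile/export/fx tracing can run
    # the op on FakeTensors, without real data.
    @numpy_add.register_fake
    def _(x: Tensor, y: float) -> Tensor:
        return torch.empty_like(x)

    x = torch.randn(3)
    assert torch.allclose(numpy_add(x, 3.14), x + 3.14)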

Test Plan:
- new tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122344
Approved by: https://github.com/ezyang, https://github.com/albanD
rzou
2024-04-03 06:42:09 -07:00
committed by PyTorch MergeBot
parent aa16c0163f
commit 44c0c0fc0f
9 changed files with 632 additions and 55 deletions

test/test_custom_ops.py

@@ -12,16 +12,17 @@ import typing
 import torch._custom_ops as custom_ops
 import torch.testing._internal.custom_op_db
 import torch.testing._internal.optests as optests
 import torch.utils.cpp_extension
 from functorch import make_fx
 from torch import Tensor
 from torch._custom_op.impl import custom_op, CustomOp, infer_schema
 from torch._utils_internal import get_file_path_2
+from torch.testing._internal import custom_op_db
 from torch.testing._internal.common_cuda import TEST_CUDA
-from torch.testing._internal.custom_op_db import custom_op_db
+from torch.testing._internal.custom_op_db import numpy_nonzero
 from typing import *  # noqa: F403

 import numpy as np


 class CustomOpTestCaseBase(TestCase):
@@ -323,7 +324,7 @@ class TestCustomOpTesting(CustomOpTestCaseBase):
         ):
             optests.opcheck(op, (x,), {})

-    @ops(custom_op_db, dtypes=OpDTypes.any_one)
+    @ops(custom_op_db.custom_op_db, dtypes=OpDTypes.any_one)
     def test_opcheck_opinfo(self, device, dtype, op):
         for sample_input in op.sample_inputs(
             device, dtype, requires_grad=op.supports_autograd
@@ -331,7 +332,7 @@ class TestCustomOpTesting(CustomOpTestCaseBase):
             args = [sample_input.input] + list(sample_input.args)
             kwargs = sample_input.kwargs
             if op.op in (
-                torch.ops._torch_testing.numpy_nonzero,
+                numpy_nonzero._opoverload,
                 torch.ops._torch_testing.numpy_nms,
             ):
                 ctx = self.assertRaisesRegex(optests.OpCheckError, "failed with")
@@ -1452,7 +1453,7 @@ class TestCustomOp(CustomOpTestCaseBase):
     def test_meta_for_data_dependent_shape_operation(self):
         x = torch.randn(10, device="meta")
         with self.assertRaisesRegex(RuntimeError, "data-dependent output shape"):
-            torch.ops._torch_testing.numpy_nonzero(x)
+            numpy_nonzero(x)

     def test_basic_make_fx(self):
         # More serious tests are in our CustomOp opinfo db,
@@ -1494,31 +1495,16 @@ class TestCustomOp(CustomOpTestCaseBase):
         with self.assertRaisesRegex(NotImplementedError, "no Tensor inputs"):
             op((1, 2, 3))

     def test_abstract_registration_location(self):
-        custom_op = torch._custom_op.impl._find_custom_op(
-            "_torch_testing::numpy_nonzero"
-        )
         source = torch._library.simple_registry.singleton.find(
             "_torch_testing::numpy_nonzero"
         ).abstract_impl.kernel.source
         self.assertRegex(source, r".*custom_op_db.py:\d+")

     def test_data_dependent_basic(self):
-        def f(x):
-            return torch.ops._torch_testing.numpy_nonzero(x)
-
         x = torch.randn(5, 5)
-        gm = make_fx(f, tracing_mode="symbolic")(x)
+        gm = make_fx(numpy_nonzero, tracing_mode="symbolic")(x)
         self.assertTrue("nonzero" in gm.code)

     def test_data_dependent_fake_tracing(self):
-        def f(x):
-            return torch.ops._torch_testing.numpy_nonzero(x)
-
         x = torch.randn(5, 5)
         # We've updated to attempt to use unbacked symints even for fake
         # tracing
-        make_fx(f, tracing_mode="fake")(x)
+        make_fx(numpy_nonzero, tracing_mode="fake")(x)

     def test_symints(self):
         def f(x):
@@ -1552,15 +1538,17 @@ def forward(self, x_1):
         @torch.compile(backend=cnt)
         def f(x):
-            return torch.ops._torch_testing.numpy_nonzero(x.clone()).clone()
+            return numpy_nonzero(x.clone()).clone()

         f(torch.randn(10))

-        self.assertEqual(
-            dict(counters["graph_break"]),
-            {
-                "dynamic shape operator: _torch_testing.numpy_nonzero.default; to enable, set torch._dynamo.config.capture_dynamic_output_shape_ops = True": 1  # noqa: B950
-            },
-        )
+        self.assertEqual(len(counters["graph_break"]), 1)
+        self.assertEqual(next(iter(counters["graph_break"].values())), 1)
+        self.assertExpectedInline(
+            next(iter(counters["graph_break"].keys())).replace(";", "\n"),
+            """\
+dynamic shape operator: _torch_testing.numpy_nonzero.default
+to enable, set torch._dynamo.config.capture_dynamic_output_shape_ops = True""",
+        )

         # pre-existing problem: torch.compile(dynamic=True) will, by default,
@@ -2057,6 +2045,222 @@ class MiniOpTest(CustomOpTestCaseBase):
         y = op(x)


+class TestCustomOpAPI(TestCase):
+    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
+    def test_basic(self):
+        @torch.library.custom_op("_torch_testing::add", mutated_args=())
+        def add(x: Tensor, y: float) -> Tensor:
+            x_np = x.numpy(force=True)
+            out_np = x_np + y
+            return torch.from_numpy(out_np).to(x.device)
+
+        x = torch.randn(3)
+        y = 3.14
+        z = add(x, y)
+        self.assertEqual(z, x + y)
+
+        cpu_called = False
+
+        @add.register_impl("cpu")
+        def _(x, y):
+            nonlocal cpu_called
+            cpu_called = True
+            x_np = x.numpy()
+            out_np = x_np + y
+            return torch.from_numpy(out_np)
+
+        z = add(x, y)
+        self.assertEqual(z, x + y)
+        self.assertTrue(cpu_called)
+
+    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
+    @unittest.skipIf(
+        sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
+    )
+    def test_fake(self):
+        @torch.library.custom_op("_torch_testing::add", mutated_args=())
+        def add(x: Tensor, y: float) -> Tensor:
+            x_np = x.cpu().numpy()
+            out_np = x_np + y
+            return torch.from_numpy(out_np).to(x.device)
+
+        x = torch.randn(3)
+        y = 3.14
+        z = add(x, y)
+        self.assertEqual(z, x + y)
+
+        try:
+            with torch._subclasses.fake_tensor.FakeTensorMode():
+                x = torch.randn(3)
+                add(x, y)
+            raise AssertionError("should not be hit")
+        except RuntimeError as e:
+            abstract_impl_error_msg = str(e)
+        abstract_impl_error_msg = re.sub(
+            r"0x.*>\)>", "0xDEADBEEF>)>", abstract_impl_error_msg
+        ).replace(". ", ".\n")
+        self.assertExpectedInline(
+            abstract_impl_error_msg,
+            """\
+There was no fake impl registered for <CustomOpDef(_torch_testing::add)>.
+This is necessary for torch.compile/export/fx tracing to work.
+Please use `add.register_fake` to add an fake impl.""",
+        )
+
+        if not IS_WINDOWS:
+
+            @torch.compile(backend="eager")
+            def f(x, y):
+                return add(x, y)
+
+            x = torch.randn(3)
+            with self.assertRaisesRegex(RuntimeError, "no fake impl"):
+                f(x, y)
+
+        abstract_called = False
+
+        @add.register_fake
+        def _(x, y):
+            nonlocal abstract_called
+            abstract_called = True
+            return torch.empty_like(x)
+
+        with torch._subclasses.fake_tensor.FakeTensorMode():
+            x = torch.randn(3)
+            z = add(x, y)
+            self.assertEqual(z.shape, x.shape)
+            self.assertTrue(abstract_called)
+
+    @skipIfTorchDynamo("recursive dynamo")
+    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work on windows")
+    @unittest.skipIf(
+        sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
+    )
+    def test_compile(self):
+        called_impl = False
+        called_abstract = False
+
+        @torch.library.custom_op("_torch_testing::linear", mutated_args=())
+        def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
+            nonlocal called_impl
+            called_impl = True
+            x_np = x.numpy()
+            w_np = weight.numpy()
+            b_np = bias.numpy()
+            out_np = np.add(x_np @ w_np.T, b_np)
+            return torch.from_numpy(out_np)
+
+        @custom_linear.register_fake
+        def _(x, weight, bias):
+            nonlocal called_abstract
+            called_abstract = True
+            assert x.dim() == 2
+            assert weight.dim() == 2
+            assert bias.dim() == 1
+            assert x.shape[1] == weight.shape[1]
+            assert weight.shape[0] == bias.shape[0]
+            assert x.device == weight.device
+            return x.new_empty(x.size(0), weight.size(0))
+
+        x = torch.randn(2, 2)
+        weight = torch.randn(2, 2)
+        bias = torch.randn(2)
+        out = torch.compile(custom_linear, backend="eager", fullgraph=True)(
+            x, weight, bias
+        )
+        self.assertEqual(out, torch.nn.functional.linear(x, weight, bias))
+        self.assertTrue(called_impl)
+        self.assertTrue(called_abstract)
+
+    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
+    def test_replacement(self):
+        @torch.library.custom_op("_torch_testing::f", mutated_args=())
+        def f(x: Tensor) -> Tensor:
+            return x.sin()
+
+        x = torch.randn(3)
+        y = f(x)
+        self.assertEqual(y, x.sin())
+
+        @torch.library.custom_op("_torch_testing::f", mutated_args=())
+        def f(x: Tensor) -> Tensor:
+            return x.cos()
+
+        y = f(x)
+        self.assertEqual(y, x.cos())
+
+    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
+    @unittest.skipIf(not TEST_CUDA, "requires CUDA")
+    def test_split_device(self):
+        cpu_call_count = 0
+        cuda_call_count = 0
+
+        @torch.library.custom_op(
+            "_torch_testing::f", mutated_args=(), device_types="cpu"
+        )
+        def f(x: Tensor) -> Tensor:
+            nonlocal cpu_call_count
+            cpu_call_count += 1
+            x_np = x.numpy()
+            out_np = np.sin(x_np)
+            return torch.from_numpy(out_np)
+
+        @f.register_impl("cuda")
+        def _(x: Tensor) -> Tensor:
+            nonlocal cuda_call_count
+            cuda_call_count += 1
+            x_np = x.cpu().numpy()
+            out_np = np.sin(x_np)
+            return torch.from_numpy(out_np).to(x.device)
+
+        x = torch.randn(3)
+        y = f(x)
+        self.assertEqual(y, x.sin())
+        self.assertEqual(cpu_call_count, 1)
+        self.assertEqual(cuda_call_count, 0)
+
+        x = x.cuda()
+        y = f(x)
+        self.assertEqual(y, x.sin())
+        self.assertEqual(cpu_call_count, 1)
+        self.assertEqual(cuda_call_count, 1)
+
+    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
+    @unittest.skipIf(not TEST_CUDA, "requires CUDA")
+    def test_multi_types(self):
+        @torch.library.custom_op(
+            "_torch_testing::f", mutated_args=(), device_types=("cpu", "cuda")
+        )
+        def f(x: Tensor) -> Tensor:
+            x_np = x.cpu().numpy()
+            out_np = np.sin(x_np)
+            return torch.from_numpy(out_np).to(x.device)
+
+        x = torch.randn(3)
+        y = f(x)
+        self.assertEqual(y, x.sin())
+        x = x.cuda()
+        y = f(x)
+        self.assertEqual(y, x.sin())
+
+    def test_disallows_output_aliasing(self):
+        @torch.library.custom_op("_torch_testing::f", mutated_args=())
+        def f(x: Tensor) -> Tensor:
+            return x.view(-1)
+
+        x = torch.randn(3)
+        with self.assertRaisesRegex(RuntimeError, "may not alias"):
+            f(x)
+
+        @torch.library.custom_op("_torch_testing::f", mutated_args=())
+        def f(x: Tensor) -> Tensor:
+            return x
+
+        x = torch.randn(3)
+        with self.assertRaisesRegex(RuntimeError, "may not alias"):
+            f(x)
+
+
 class MiniOpTestOther(CustomOpTestCaseBase):
     test_ns = "mini_op_test"
@@ -2296,6 +2500,7 @@ opcheck(op, args, kwargs, test_utils="test_schema")
 only_for = ("cpu", "cuda")
 instantiate_device_type_tests(TestCustomOpTesting, globals(), only_for=only_for)
 instantiate_parametrized_tests(TestCustomOp)
+instantiate_parametrized_tests(TestCustomOpAPI)

 if __name__ == "__main__":
     run_tests()