Add torch.library.custom_op (#122344)
This is the entrypoint for defining an opaque/blackbox (i.e., PyTorch will never peek into it) custom op. In this PR, you can specify backend impls and the abstract impl for this op.

NB: most of this PR is docstrings; please don't be intimidated by the line count.

There are a number of interesting features:
- We infer the schema from type hints. A follow-up adds the ability to manually specify a schema.
- Name inference. The user needs to manually specify an op name for now; a follow-up adds the ability to infer a name automatically (this is a little tricky).
- custom_op registrations can override each other. This makes them more pleasant to work with in environments like Colab.
- We require that the outputs of the custom_op do not alias any inputs or each other. We enforce this via a runtime check, but can relax it into an opcheck test if it really matters in the future.

Test Plan:
- new tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122344
Approved by: https://github.com/ezyang, https://github.com/albanD
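For orientation, the API exercised by the new tests looks roughly like this. This is a minimal sketch distilled from the tests in the diff below; the `mylib` namespace and the NumPy-backed kernel are illustrative, not part of the commit.

    import torch
    from torch import Tensor

    # The schema is inferred from the type hints ("(Tensor x, float y) -> Tensor").
    @torch.library.custom_op("mylib::add", mutated_args=())
    def add(x: Tensor, y: float) -> Tensor:
        x_np = x.numpy(force=True)
        return torch.from_numpy(x_np + y).to(x.device)

    # Without a fake (abstract) impl, torch.compile/export cannot trace the op.
    @add.register_fake
    def _(x, y):
        return torch.empty_like(x)

    x = torch.randn(3)
    assert torch.allclose(add(x, 3.14), x + 3.14)

Re-running the `@torch.library.custom_op` decorator with the same name replaces the earlier registration (the Colab-friendly override behavior), and returning a view of an input raises a runtime error per the no-aliasing rule.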
@@ -12,16 +12,17 @@ import typing

 import torch._custom_ops as custom_ops

 import torch.testing._internal.custom_op_db
 import torch.testing._internal.optests as optests
 import torch.utils.cpp_extension
 from functorch import make_fx
 from torch import Tensor
 from torch._custom_op.impl import custom_op, CustomOp, infer_schema
 from torch._utils_internal import get_file_path_2
+from torch.testing._internal import custom_op_db
 from torch.testing._internal.common_cuda import TEST_CUDA
-from torch.testing._internal.custom_op_db import custom_op_db
+from torch.testing._internal.custom_op_db import numpy_nonzero
 from typing import *  # noqa: F403

 import numpy as np


 class CustomOpTestCaseBase(TestCase):
@@ -323,7 +324,7 @@ class TestCustomOpTesting(CustomOpTestCaseBase):
         ):
             optests.opcheck(op, (x,), {})

-    @ops(custom_op_db, dtypes=OpDTypes.any_one)
+    @ops(custom_op_db.custom_op_db, dtypes=OpDTypes.any_one)
     def test_opcheck_opinfo(self, device, dtype, op):
         for sample_input in op.sample_inputs(
             device, dtype, requires_grad=op.supports_autograd
@@ -331,7 +332,7 @@ class TestCustomOpTesting(CustomOpTestCaseBase):
             args = [sample_input.input] + list(sample_input.args)
             kwargs = sample_input.kwargs
             if op.op in (
-                torch.ops._torch_testing.numpy_nonzero,
+                numpy_nonzero._opoverload,
                 torch.ops._torch_testing.numpy_nms,
             ):
                 ctx = self.assertRaisesRegex(optests.OpCheckError, "failed with")
@@ -1452,7 +1453,7 @@ class TestCustomOp(CustomOpTestCaseBase):
     def test_meta_for_data_dependent_shape_operation(self):
         x = torch.randn(10, device="meta")
         with self.assertRaisesRegex(RuntimeError, "data-dependent output shape"):
-            torch.ops._torch_testing.numpy_nonzero(x)
+            numpy_nonzero(x)

     def test_basic_make_fx(self):
         # More serious tests are in our CustomOp opinfo db,
@@ -1494,31 +1495,16 @@ class TestCustomOp(CustomOpTestCaseBase):
         with self.assertRaisesRegex(NotImplementedError, "no Tensor inputs"):
             op((1, 2, 3))

     def test_abstract_registration_location(self):
-        custom_op = torch._custom_op.impl._find_custom_op(
-            "_torch_testing::numpy_nonzero"
-        )
+        source = torch._library.simple_registry.singleton.find(
+            "_torch_testing::numpy_nonzero"
+        ).abstract_impl.kernel.source
         self.assertRegex(source, r".*custom_op_db.py:\d+")

     def test_data_dependent_basic(self):
-        def f(x):
-            return torch.ops._torch_testing.numpy_nonzero(x)
-
         x = torch.randn(5, 5)
-        gm = make_fx(f, tracing_mode="symbolic")(x)
+        gm = make_fx(numpy_nonzero, tracing_mode="symbolic")(x)
         self.assertTrue("nonzero" in gm.code)

     def test_data_dependent_fake_tracing(self):
-        def f(x):
-            return torch.ops._torch_testing.numpy_nonzero(x)
-
         x = torch.randn(5, 5)
         # We've updated to attempt to use unbacked symints even for fake
         # tracing
-        make_fx(f, tracing_mode="fake")(x)
+        make_fx(numpy_nonzero, tracing_mode="fake")(x)

     def test_symints(self):
         def f(x):
@@ -1552,15 +1538,17 @@ def forward(self, x_1):

         @torch.compile(backend=cnt)
         def f(x):
-            return torch.ops._torch_testing.numpy_nonzero(x.clone()).clone()
+            return numpy_nonzero(x.clone()).clone()

         f(torch.randn(10))

-        self.assertEqual(
-            dict(counters["graph_break"]),
-            {
-                "dynamic shape operator: _torch_testing.numpy_nonzero.default; to enable, set torch._dynamo.config.capture_dynamic_output_shape_ops = True": 1  # noqa: B950
-            },
+        self.assertEqual(len(counters["graph_break"]), 1)
+        self.assertEqual(next(iter(counters["graph_break"].values())), 1)
+        self.assertExpectedInline(
+            next(iter(counters["graph_break"].keys())).replace(";", "\n"),
+            """\
+dynamic shape operator: _torch_testing.numpy_nonzero.default
+ to enable, set torch._dynamo.config.capture_dynamic_output_shape_ops = True""",
         )

         # pre-existing problem: torch.compile(dynamic=True) will, by default,
@@ -2057,6 +2045,222 @@ class MiniOpTest(CustomOpTestCaseBase):
         y = op(x)


+class TestCustomOpAPI(TestCase):
+    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
+    def test_basic(self):
+        @torch.library.custom_op("_torch_testing::add", mutated_args=())
+        def add(x: Tensor, y: float) -> Tensor:
+            x_np = x.numpy(force=True)
+            out_np = x_np + y
+            return torch.from_numpy(out_np).to(x.device)
+
+        x = torch.randn(3)
+        y = 3.14
+        z = add(x, y)
+        self.assertEqual(z, x + y)
+
+        cpu_called = False
+
+        @add.register_impl("cpu")
+        def _(x, y):
+            nonlocal cpu_called
+            cpu_called = True
+            x_np = x.numpy()
+            out_np = x_np + y
+            return torch.from_numpy(out_np)
+
+        z = add(x, y)
+        self.assertEqual(z, x + y)
+        self.assertTrue(cpu_called)
+
+    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
+    @unittest.skipIf(
+        sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
+    )
+    def test_fake(self):
+        @torch.library.custom_op("_torch_testing::add", mutated_args=())
+        def add(x: Tensor, y: float) -> Tensor:
+            x_np = x.cpu().numpy()
+            out_np = x_np + y
+            return torch.from_numpy(out_np).to(x.device)
+
+        x = torch.randn(3)
+        y = 3.14
+        z = add(x, y)
+        self.assertEqual(z, x + y)
+
+        try:
+            with torch._subclasses.fake_tensor.FakeTensorMode():
+                x = torch.randn(3)
+                add(x, y)
+            raise AssertionError("should not be hit")
+        except RuntimeError as e:
+            abstract_impl_error_msg = str(e)
+        abstract_impl_error_msg = re.sub(
+            r"0x.*>\)>", "0xDEADBEEF>)>", abstract_impl_error_msg
+        ).replace(". ", ".\n")
+        self.assertExpectedInline(
+            abstract_impl_error_msg,
+            """\
+There was no fake impl registered for <CustomOpDef(_torch_testing::add)>.
+This is necessary for torch.compile/export/fx tracing to work.
+Please use `add.register_fake` to add an fake impl.""",
+        )
+
+        if not IS_WINDOWS:
+
+            @torch.compile(backend="eager")
+            def f(x, y):
+                return add(x, y)
+
+            x = torch.randn(3)
+            with self.assertRaisesRegex(RuntimeError, "no fake impl"):
+                f(x, y)
+
+        abstract_called = False
+
+        @add.register_fake
+        def _(x, y):
+            nonlocal abstract_called
+            abstract_called = True
+            return torch.empty_like(x)
+
+        with torch._subclasses.fake_tensor.FakeTensorMode():
+            x = torch.randn(3)
+            z = add(x, y)
+            self.assertEqual(z.shape, x.shape)
+            self.assertTrue(abstract_called)
+
+    @skipIfTorchDynamo("recursive dynamo")
+    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work on windows")
+    @unittest.skipIf(
+        sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
+    )
+    def test_compile(self):
+        called_impl = False
+        called_abstract = False
+
+        @torch.library.custom_op("_torch_testing::linear", mutated_args=())
+        def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
+            nonlocal called_impl
+            called_impl = True
+            x_np = x.numpy()
+            w_np = weight.numpy()
+            b_np = bias.numpy()
+            out_np = np.add(x_np @ w_np.T, bias)
+            return out_np
+
+        @custom_linear.register_fake
+        def _(x, weight, bias):
+            nonlocal called_abstract
+            called_abstract = True
+            assert x.dim() == 2
+            assert weight.dim() == 2
+            assert bias.dim() == 1
+            assert x.shape[1] == weight.shape[1]
+            assert weight.shape[0] == bias.shape[0]
+            assert x.device == weight.device
+            return x.new_empty(x.size(0), weight.size(0))
+
+        x = torch.randn(2, 2)
+        weight = torch.randn(2, 2)
+        bias = torch.randn(2)
+        out = torch.compile(custom_linear, backend="eager", fullgraph=True)(
+            x, weight, bias
+        )
+        self.assertEqual(out, torch.nn.functional.linear(x, weight, bias))
+        self.assertTrue(called_impl)
+        self.assertTrue(called_abstract)
+
+    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
+    def test_replacement(self):
+        @torch.library.custom_op("_torch_testing::f", mutated_args=())
+        def f(x: Tensor) -> Tensor:
+            return x.sin()
+
+        x = torch.randn(3)
+        y = f(x)
+        self.assertEqual(y, x.sin())
+
+        @torch.library.custom_op("_torch_testing::f", mutated_args=())
+        def f(x: Tensor) -> Tensor:
+            return x.cos()
+
+        y = f(x)
+        self.assertEqual(y, x.cos())
+
+    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
+    @unittest.skipIf(not TEST_CUDA, "requires CUDA")
+    def test_split_device(self):
+        cpu_call_count = 0
+        cuda_call_count = 0
+
+        @torch.library.custom_op(
+            "_torch_testing::f", mutated_args=(), device_types="cpu"
+        )
+        def f(x: Tensor) -> Tensor:
+            nonlocal cpu_call_count
+            cpu_call_count += 1
+            x_np = x.numpy()
+            out_np = np.sin(x_np)
+            return torch.from_numpy(out_np)
+
+        @f.register_impl("cuda")
+        def _(x: Tensor) -> Tensor:
+            nonlocal cuda_call_count
+            cuda_call_count += 1
+            x_np = x.cpu().numpy()
+            out_np = np.sin(x_np)
+            return torch.from_numpy(out_np).to(x.device)
+
+        x = torch.randn(3)
+        y = f(x)
+        self.assertEqual(y, x.sin())
+        self.assertEqual(cpu_call_count, 1)
+        self.assertEqual(cuda_call_count, 0)
+
+        x = x.cuda()
+        y = f(x)
+        self.assertEqual(y, x.sin())
+        self.assertEqual(cpu_call_count, 1)
+        self.assertEqual(cuda_call_count, 1)
+
+    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
+    @unittest.skipIf(not TEST_CUDA, "requires CUDA")
+    def test_multi_types(self):
+        @torch.library.custom_op(
+            "_torch_testing::f", mutated_args=(), device_types=("cpu", "cuda")
+        )
+        def f(x: Tensor) -> Tensor:
+            x_np = x.cpu().numpy()
+            out_np = np.sin(x_np)
+            return torch.from_numpy(out_np).to(x.device)
+
+        x = torch.randn(3)
+        y = f(x)
+        self.assertEqual(y, x.sin())
+        x = x.cuda()
+        y = f(x)
+        self.assertEqual(y, x.sin())
+
+    def test_disallows_output_aliasing(self):
+        @torch.library.custom_op("_torch_testing::f", mutated_args=())
+        def f(x: Tensor) -> Tensor:
+            return x.view(-1)
+
+        x = torch.randn(3)
+        with self.assertRaisesRegex(RuntimeError, "may not alias"):
+            f(x)
+
+        @torch.library.custom_op("_torch_testing::f", mutated_args=())
+        def f(x: Tensor) -> Tensor:
+            return x
+
+        x = torch.randn(3)
+        with self.assertRaisesRegex(RuntimeError, "may not alias"):
+            f(x)
+
+
 class MiniOpTestOther(CustomOpTestCaseBase):
     test_ns = "mini_op_test"

@ -2296,6 +2500,7 @@ opcheck(op, args, kwargs, test_utils="test_schema")
|
||||
only_for = ("cpu", "cuda")
|
||||
instantiate_device_type_tests(TestCustomOpTesting, globals(), only_for=only_for)
|
||||
instantiate_parametrized_tests(TestCustomOp)
|
||||
instantiate_parametrized_tests(TestCustomOpAPI)
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|