Add torch.library.opcheck (#124496)

This PR:
- Exposes torch.testing._internal.optests.opcheck as
  torch.library.opcheck.
- Adds support for CustomOpDef (i.e. functions decorated with
  torch.library.custom_op) to opcheck (see the usage sketch below).
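
For context, a minimal sketch of the new entry point in action. The `mylib::scale` op, its NumPy-based implementation, and the sample inputs are illustrative assumptions, not part of this PR:

```python
import torch

# Hypothetical custom op for illustration only: multiplies a tensor by a
# Python float via NumPy, mirroring the style of the torch.library docs.
@torch.library.custom_op("mylib::scale", mutates_args=())
def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    x_np = x.numpy(force=True)
    return torch.from_numpy(x_np * factor).to(x.device)

# A fake (meta) kernel so the op works under FakeTensor / torch.compile.
@scale.register_fake
def _(x, factor):
    return torch.empty_like(x)

# opcheck now accepts the CustomOpDef returned by the decorator directly,
# in addition to OpOverload/OpOverloadPacket objects such as torch.ops.mylib.scale.
for args in [(torch.randn(3), 2.0), (torch.randn(2, 3), 0.5)]:
    torch.library.opcheck(scale, args)
```

Passing the CustomOpDef itself, rather than looking up `torch.ops.mylib.scale.default`, is the convenience this PR adds.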

Test Plan:
- Updated tests
- We validated opcheck's design internally.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124496
Approved by: https://github.com/williamwen42
Author: rzou
Date: 2024-04-23 10:31:46 -07:00
Committed by: PyTorch MergeBot
Parent: 763dc26e59
Commit: 4ceb44c40d
4 changed files with 153 additions and 55 deletions


@ -4,13 +4,23 @@ torch.library
.. currentmodule:: torch.library
torch.library is a collection of APIs for extending PyTorch's core library
of operators. It contains utilities for creating new custom operators as
well as extending operators defined with PyTorch's C++ operator
of operators. It contains utilities for testing custom operators, creating new
custom operators, and extending operators defined with PyTorch's C++ operator
registration APIs (e.g. aten operators).
For a detailed guide on effectively using these APIs, please see
`this gdoc <https://docs.google.com/document/d/1W--T6wz8IY8fOI0Vm8BF44PdBgs283QvpelJZWieQWQ/edit>`_
Testing custom ops
------------------
Use :func:`torch.library.opcheck` to test custom ops for incorrect usage of the
Python torch.library and/or C++ TORCH_LIBRARY APIs. Also, if your operator supports
training, use :func:`torch.autograd.gradcheck` to test that the gradients are
mathematically correct.
.. autofunction:: opcheck
Creating new custom ops in Python
---------------------------------
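
The documentation text above recommends pairing opcheck with torch.autograd.gradcheck for ops that support training. A minimal sketch of that pairing, assuming the hypothetical `mylib::scale` op from the earlier sketch is in scope:

```python
import torch

# Add an autograd rule to the hypothetical mylib::scale op so that both
# opcheck and gradcheck apply.
def setup_context(ctx, inputs, output):
    _, factor = inputs
    ctx.factor = factor

def backward(ctx, grad):
    # d(x * factor) / dx = factor; the float argument gets no gradient.
    return grad * ctx.factor, None

scale.register_autograd(backward, setup_context=setup_context)

x = torch.randn(3, dtype=torch.double, requires_grad=True)
# opcheck validates the torch.library registration (schema, fake kernel,
# autograd registration, torch.compile behavior)...
torch.library.opcheck(scale, (x, 2.0))
# ...while gradcheck numerically validates the backward formula itself.
torch.autograd.gradcheck(lambda t: scale(t, 2.0), (x,))
```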


@ -131,7 +131,7 @@ class TestCustomOpTesting(CustomOpTestCaseBase):
with self.assertRaisesRegex(
optests.OpCheckError, "Argument x is not defined as mutable but was mutated"
):
optests.opcheck(op, (x,), {})
torch.library.opcheck(op, (x,), {})
def test_incorrect_schema_view(self, device):
lib = self.lib()
@ -167,7 +167,7 @@ class TestCustomOpTesting(CustomOpTestCaseBase):
optests.OpCheckError,
"Argument x is not defined to alias output but was aliasing",
):
optests.opcheck(op, (x,), {})
torch.library.opcheck(op, (x,), {})
def test_missing_abstract_impl(self, device):
lib = self.lib()
@ -196,7 +196,7 @@ class TestCustomOpTesting(CustomOpTestCaseBase):
optests.OpCheckError,
"_test_custom_op.foo.default",
):
optests.opcheck(op, (x,), {})
torch.library.opcheck(op, (x,), {})
def test_incorrect_abstract_impl(self, device):
lib = self.lib()
@ -234,7 +234,7 @@ class TestCustomOpTesting(CustomOpTestCaseBase):
x = torch.tensor([0, 1.0], requires_grad=True)
with self.assertRaisesRegex(optests.OpCheckError, "Shapes .* are not equal"):
optests.opcheck(op, (x,), {})
torch.library.opcheck(op, (x,), {})
def test_missing_functionalization(self, device):
lib = self.lib()
@ -269,7 +269,7 @@ class TestCustomOpTesting(CustomOpTestCaseBase):
optests.OpCheckError,
"We only support functionalizing operators whose outputs do not have alias annotations",
):
optests.opcheck(op, (y,), {})
torch.library.opcheck(op, (y,), {})
def test_autograd_registered_at_backend(self, device):
lib = self.lib()
@ -295,7 +295,7 @@ class TestCustomOpTesting(CustomOpTestCaseBase):
torch.testing._internal.optests.OpCheckError,
"does not have an autograd kernel",
):
optests.opcheck(op, (x,), {})
torch.library.opcheck(op, (x,), {})
# I'm not sure why this is necessary
del lib
@ -323,7 +323,7 @@ class TestCustomOpTesting(CustomOpTestCaseBase):
with self.assertRaisesRegex(
optests.OpCheckError, "eager-mode PyTorch vs AOTAutograd"
):
optests.opcheck(op, (x,), {})
torch.library.opcheck(op, (x,), {})
@ops(custom_op_db.custom_op_db, dtypes=OpDTypes.any_one)
def test_opcheck_opinfo(self, device, dtype, op):
@ -332,7 +332,7 @@ class TestCustomOpTesting(CustomOpTestCaseBase):
):
args = [sample_input.input] + list(sample_input.args)
kwargs = sample_input.kwargs
optests.opcheck(
torch.library.opcheck(
op.op,
args,
kwargs,
@ -352,7 +352,7 @@ class TestCustomOpTesting(CustomOpTestCaseBase):
with self.assertRaisesRegex(
optests.OpCheckError, "Autograd has not been implemented for operator"
):
optests.opcheck(self.get_op(f"{self.test_ns}::foo"), (x,), {})
torch.library.opcheck(self.get_op(f"{self.test_ns}::foo"), (x,), {})
def test_autograd_registration_check_autograd_kernel(self, device):
lib = self.lib()
@ -2933,10 +2933,10 @@ opcheck(op, args, kwargs, test_utils="test_schema")
def test_opcheck(self):
x = torch.randn(3, requires_grad=True)
with self.assertRaisesRegex(ValueError, "OpOverload"):
optests.opcheck(torch.sin, (x,))
torch.library.opcheck(torch.sin, (x,))
with self.assertRaisesRegex(ValueError, "test_utils to be subset of"):
optests.opcheck(torch.ops.aten.sin.default, (x,), test_utils="blah")
result = optests.opcheck(torch.ops.aten.sin.default, (x,))
torch.library.opcheck(torch.ops.aten.sin.default, (x,), test_utils="blah")
result = torch.library.opcheck(torch.ops.aten.sin.default, (x,))
self.assertEqual(
result,
@ -2948,7 +2948,7 @@ opcheck(op, args, kwargs, test_utils="test_schema")
},
)
result = optests.opcheck(
result = torch.library.opcheck(
torch.ops.aten.sin.default, (x,), test_utils="test_schema"
)
self.assertEqual(
@ -2958,7 +2958,7 @@ opcheck(op, args, kwargs, test_utils="test_schema")
},
)
result = optests.opcheck(
result = torch.library.opcheck(
torch.ops.aten.sin.default,
(x,),
test_utils=["test_schema", "test_faketensor"],
@ -2971,6 +2971,21 @@ opcheck(op, args, kwargs, test_utils="test_schema")
},
)
def test_opcheck_customopdef(self):
sample_inputs = [
(torch.randn(3),),
(torch.randn(3, requires_grad=True),),
]
if torch.cuda.is_available():
sample_inputs.extend(
[
(torch.randn(3, device="cuda"),),
(torch.randn(3, device="cuda", requires_grad=True),),
]
)
for args in sample_inputs:
torch.library.opcheck(custom_op_db.numpy_cube, args)
def test_is_inside_opcheck_mode(self):
self.assertFalse(optests.is_inside_opcheck_mode())
with optests.generate_tests.OpCheckMode(
@ -2982,9 +2997,9 @@ opcheck(op, args, kwargs, test_utils="test_schema")
op = op_with_incorrect_schema(self, "foo")
x = torch.randn(3)
with self.assertRaisesRegex(Exception, "is not defined to alias output"):
optests.opcheck(op, (x,))
torch.library.opcheck(op, (x,))
result = optests.opcheck(op, (x,), raise_exception=False)
result = torch.library.opcheck(op, (x,), raise_exception=False)
self.assertTrue(isinstance(result["test_schema"], RuntimeError))
del result["test_schema"]
self.assertEqual(


@ -1,5 +1,5 @@
from ._ops import OpOverload
from typing import Any, Optional, Set, List, Union, Callable
from typing import Any, Optional, Set, List, Union, Callable, Tuple, Dict, Sequence
import traceback
import torch
import weakref
@ -9,7 +9,7 @@ import re
import contextlib
import sys
import warnings
from torch._library.custom_ops import custom_op, _maybe_get_opdef, device_types_t
from torch._library.custom_ops import custom_op, _maybe_get_opdef, device_types_t, CustomOpDef
import torch._library as _library
@ -743,3 +743,104 @@ def get_ctx() -> "torch._library.abstract_impl.AbstractImplCtx":
(see :func:`torch.library.register_fake` for more usage details).
"""
return torch._library.abstract_impl.global_ctx_getter()
_OPCHECK_DEFAULT_UTILS = (
"test_schema",
"test_autograd_registration",
"test_faketensor",
"test_aot_dispatch_dynamic",
)
def opcheck(
op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket, CustomOpDef],
args: Tuple[Any, ...],
kwargs: Optional[Dict[str, Any]] = None,
*,
test_utils: Union[str, Sequence[str]] = _OPCHECK_DEFAULT_UTILS,
raise_exception: bool = True,
) -> Dict[str, str]:
"""Given an operator and some sample arguments, tests if the operator is
registered correctly.
That is, when you use the torch.library/TORCH_LIBRARY APIs to create a
custom op, you specify metadata (e.g. mutability info) about the custom op,
and these APIs require that the functions you pass them satisfy certain
properties (e.g. no data pointer access in the fake/meta/abstract kernel).
``opcheck`` tests this metadata and these properties.
Concretely, we test the following:
- test_schema: if the operator's schema is correct.
- test_autograd_registration: if autograd was registered correctly.
- test_faketensor: If the operator has a FakeTensor kernel
(and if it is correct). The FakeTensor kernel is necessary (
but not sufficient) for the operator to work with PyTorch compilation
APIs (torch.compile/export/FX).
- test_aot_dispatch_dynamic: If the operator has correct behavior
with PyTorch compilation APIs (torch.compile/export/FX).
This checks that the outputs (and gradients, if applicable) are the
same under eager-mode PyTorch and torch.compile.
This test is a superset of ``test_faketensor``.
For best results, please call ``opcheck`` multiple times with a
representative set of inputs. If your operator supports
autograd, please use ``opcheck`` with inputs with ``requires_grad = True``;
if your operator supports multiple devices (e.g. CPU and CUDA), please
use ``opcheck`` with inputs on all supported devices.
Args:
op: The operator. Must either be a function decorated with
:func:`torch.library.custom_op` or an OpOverload/OpOverloadPacket
found in torch.ops.* (e.g. torch.ops.aten.sin, torch.ops.mylib.foo)
args: The args to the operator
kwargs: The kwargs to the operator
test_utils: Tests that we should run. Default: all of them.
Example: ("test_schema", "test_faketensor")
raise_exception: If we should raise an exception on the first
error. If False, we will return a dict with information
on if each test passed or not.
.. warning::
opcheck and :func:`torch.autograd.gradcheck` test different things;
opcheck tests if your usage of torch.library APIs is correct while
:func:`torch.autograd.gradcheck` tests if your autograd formula is
mathematically correct. Use both to test custom ops that support
gradient computation.
Example:
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
>>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
>>> def numpy_mul(x: Tensor, y: float) -> Tensor:
>>>     x_np = x.numpy(force=True)
>>>     z_np = x_np * y
>>>     return torch.from_numpy(z_np).to(x.device)
>>>
>>> @numpy_mul.register_fake
>>> def _(x, y):
>>>     return torch.empty_like(x)
>>>
>>> def setup_context(ctx, inputs, output):
>>>     _, y = inputs
>>>     ctx.y = y
>>>
>>> def backward(ctx, grad):
>>>     return grad * ctx.y, None
>>>
>>> numpy_mul.register_autograd(backward, setup_context=setup_context)
>>>
>>> sample_inputs = [
>>>     (torch.randn(3), 3.14),
>>>     (torch.randn(2, 3, device='cuda'), 2.718),
>>>     (torch.randn(1, 10, requires_grad=True), 1.234),
>>>     (torch.randn(64, 64, device='cuda', requires_grad=True), 90.18),
>>> ]
>>>
>>> for args in sample_inputs:
>>>     torch.library.opcheck(numpy_mul, args)
"""
import torch.testing._internal.optests as optests
return optests.opcheck(op, args, kwargs, test_utils=test_utils, raise_exception=raise_exception)
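
As a usage note for the `test_utils` and `raise_exception` parameters documented above, a small sketch, again assuming the hypothetical `mylib::scale` op from the earlier sketches is in scope:

```python
import torch

x = torch.randn(3)

# Run only a subset of the checks.
torch.library.opcheck(scale, (x, 2.0), test_utils=("test_schema", "test_faketensor"))

# Collect per-test results instead of raising on the first failure.
results = torch.library.opcheck(scale, (x, 2.0), raise_exception=False)
print(results)  # e.g. {"test_schema": "SUCCESS", "test_faketensor": "SUCCESS", ...}
```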


@ -10,7 +10,7 @@ import re
import tempfile
import threading
import unittest
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
import torch
@ -18,6 +18,7 @@ import torch._dynamo
import torch.utils._pytree as pytree
from torch._dynamo.utils import clone_input
from torch._library.custom_ops import CustomOpDef
from torch._subclasses.schema_check_mode import SchemaCheckMode
from torch._utils_internal import get_file_path_2
from torch.overrides import TorchFunctionMode
@ -620,48 +621,19 @@ def should_print_better_repro() -> None:
def opcheck(
op: torch._ops.OperatorBase,
op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket, CustomOpDef],
args: Tuple[Any, ...],
kwargs: Optional[Dict[str, Any]] = None,
*,
test_utils: Union[str, List[str]] = DEFAULT_TEST_UTILS,
test_utils: Union[str, Sequence[str]] = DEFAULT_TEST_UTILS,
raise_exception: bool = True,
) -> Dict[str, str]:
"""Given an operator and some sample arguments, tests if the operator is
registered correctly.
We test the following (which are important for correctness in eager-mode
PyTorch and with torch.compile):
- test_schema: if the operator's schema is correct.
- test_autograd_registration: if autograd was registered correctly,
i.e. to the correct DispatchKey.
- test_faketensor: If the operator has a FakeTensor implementation
(and if it is correct).
- test_aot_dispatch_static: If the operator works with
AOTAutograd/AOTDispatch, which is one of the parts in the PT2 stack.
Checks that the outputs (and gradients, if they are computable)
of the operator are the same under eager-mode PyTorch and torch.compile.
- test_aot_dispatch_dynamic: Same as aot_dispatch_static, but
tests dynamic shapes instead of static shapes.
For best results, please call ``opcheck`` multiple times with a
representative set of inputs. For example, if your operator supports
autograd, please use ``opcheck`` with inputs that require_grad.
Args:
op: The operator. Should look like torch.ops.aten.foo
args: The args to the operator
kwargs: The kwargs to the operator
test_utils: Tests that we should run. Default: all of them.
Example: ["test_schema", "test_faketensor"]
raise_exception: If we should raise an exception on the first
error. If False, we will return a dict with information
on if each test passed or not.
"""
"""See torch.library.opcheck for docstring"""
if kwargs is None:
kwargs = {}
if isinstance(op, CustomOpDef):
op = op._opoverload
if isinstance(op, torch._ops.OpOverloadPacket):
op = resolve_unique_overload_or_throw(op)
if not isinstance(op, torch._ops.OpOverload):