Add torch.library.opcheck (#124496)

This PR:
- exposes torch.testing._internal.optests.opcheck as torch.library.opcheck
- Adds support for CustomOpDef (aka functions decorated with torch.library.custom_op) to opcheck.

Test Plan:
- Updated tests
- We validated opcheck's design internally.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124496
Approved by: https://github.com/williamwen42
@@ -4,13 +4,23 @@ torch.library
 .. currentmodule:: torch.library

 torch.library is a collection of APIs for extending PyTorch's core library
-of operators. It contains utilities for creating new custom operators as
-well as extending operators defined with PyTorch's C++ operator
+of operators. It contains utilities for testing custom operators, creating new
+custom operators, and extending operators defined with PyTorch's C++ operator
 registration APIs (e.g. aten operators).

 For a detailed guide on effectively using these APIs, please see
 `this gdoc <https://docs.google.com/document/d/1W--T6wz8IY8fOI0Vm8BF44PdBgs283QvpelJZWieQWQ/edit>`_

+Testing custom ops
+------------------
+
+Use :func:`torch.library.opcheck` to test custom ops for incorrect usage of the
+Python torch.library and/or C++ TORCH_LIBRARY APIs. Also, if your operator supports
+training, use :func:`torch.autograd.gradcheck` to test that the gradients are
+mathematically correct.
+
+.. autofunction:: opcheck
+
 Creating new custom ops in Python
 ---------------------------------
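The documentation added above pairs opcheck with gradcheck. As a rough illustration of that workflow, here is a minimal sketch; the mylib::numpy_scale op and its registrations are hypothetical and not taken from this PR:

    import torch
    from torch import Tensor

    # Hypothetical trainable custom op, used only to illustrate the docs above.
    @torch.library.custom_op("mylib::numpy_scale", mutates_args=())
    def numpy_scale(x: Tensor) -> Tensor:
        return torch.from_numpy(x.numpy(force=True) * 2.0).to(x.device)

    @numpy_scale.register_fake
    def _(x):
        return torch.empty_like(x)

    def backward(ctx, grad):
        return grad * 2.0

    numpy_scale.register_autograd(backward)

    x = torch.randn(3, dtype=torch.double, requires_grad=True)
    # opcheck validates the torch.library registration...
    torch.library.opcheck(numpy_scale, (x,))
    # ...while gradcheck validates that the registered backward is mathematically correct.
    torch.autograd.gradcheck(lambda t: numpy_scale(t), (x,))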
@@ -131,7 +131,7 @@ class TestCustomOpTesting(CustomOpTestCaseBase):
         with self.assertRaisesRegex(
             optests.OpCheckError, "Argument x is not defined as mutable but was mutated"
         ):
-            optests.opcheck(op, (x,), {})
+            torch.library.opcheck(op, (x,), {})

     def test_incorrect_schema_view(self, device):
         lib = self.lib()
@@ -167,7 +167,7 @@ class TestCustomOpTesting(CustomOpTestCaseBase):
             optests.OpCheckError,
             "Argument x is not defined to alias output but was aliasing",
         ):
-            optests.opcheck(op, (x,), {})
+            torch.library.opcheck(op, (x,), {})

     def test_missing_abstract_impl(self, device):
         lib = self.lib()
@@ -196,7 +196,7 @@ class TestCustomOpTesting(CustomOpTestCaseBase):
             optests.OpCheckError,
             "_test_custom_op.foo.default",
         ):
-            optests.opcheck(op, (x,), {})
+            torch.library.opcheck(op, (x,), {})

     def test_incorrect_abstract_impl(self, device):
         lib = self.lib()
@@ -234,7 +234,7 @@ class TestCustomOpTesting(CustomOpTestCaseBase):

         x = torch.tensor([0, 1.0], requires_grad=True)
         with self.assertRaisesRegex(optests.OpCheckError, "Shapes .* are not equal"):
-            optests.opcheck(op, (x,), {})
+            torch.library.opcheck(op, (x,), {})

     def test_missing_functionalization(self, device):
         lib = self.lib()
@@ -269,7 +269,7 @@ class TestCustomOpTesting(CustomOpTestCaseBase):
             optests.OpCheckError,
             "We only support functionalizing operators whose outputs do not have alias annotations",
         ):
-            optests.opcheck(op, (y,), {})
+            torch.library.opcheck(op, (y,), {})

     def test_autograd_registered_at_backend(self, device):
         lib = self.lib()
@@ -295,7 +295,7 @@ class TestCustomOpTesting(CustomOpTestCaseBase):
             torch.testing._internal.optests.OpCheckError,
             "does not have an autograd kernel",
         ):
-            optests.opcheck(op, (x,), {})
+            torch.library.opcheck(op, (x,), {})

         # I'm not sure why this is necessary
         del lib
@@ -323,7 +323,7 @@ class TestCustomOpTesting(CustomOpTestCaseBase):
         with self.assertRaisesRegex(
             optests.OpCheckError, "eager-mode PyTorch vs AOTAutograd"
         ):
-            optests.opcheck(op, (x,), {})
+            torch.library.opcheck(op, (x,), {})

     @ops(custom_op_db.custom_op_db, dtypes=OpDTypes.any_one)
     def test_opcheck_opinfo(self, device, dtype, op):
@@ -332,7 +332,7 @@ class TestCustomOpTesting(CustomOpTestCaseBase):
         ):
             args = [sample_input.input] + list(sample_input.args)
             kwargs = sample_input.kwargs
-            optests.opcheck(
+            torch.library.opcheck(
                 op.op,
                 args,
                 kwargs,
@@ -352,7 +352,7 @@ class TestCustomOpTesting(CustomOpTestCaseBase):
         with self.assertRaisesRegex(
             optests.OpCheckError, "Autograd has not been implemented for operator"
         ):
-            optests.opcheck(self.get_op(f"{self.test_ns}::foo"), (x,), {})
+            torch.library.opcheck(self.get_op(f"{self.test_ns}::foo"), (x,), {})

     def test_autograd_registration_check_autograd_kernel(self, device):
         lib = self.lib()
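These tests drive opcheck through deliberately misregistered operators and assert that each problem is surfaced. A condensed sketch of one such failure mode, a fake kernel that reports the wrong output shape; the mylib::double_it op below is made up for illustration and is not part of this PR:

    import torch
    from torch import Tensor

    @torch.library.custom_op("mylib::double_it", mutates_args=())
    def double_it(x: Tensor) -> Tensor:
        return x.repeat(2)  # real output has shape (2 * N,)

    @double_it.register_fake
    def _(x):
        return torch.empty_like(x)  # wrong: claims the output has the input's shape

    x = torch.randn(3)
    # With raise_exception=False the failure is reported in the result dict
    # rather than raised, similar to what the tests above assert.
    result = torch.library.opcheck(double_it, (x,), raise_exception=False)
    print(result["test_faketensor"])  # not "SUCCESS"; describes the shape mismatch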
@@ -2933,10 +2933,10 @@ opcheck(op, args, kwargs, test_utils="test_schema")
     def test_opcheck(self):
         x = torch.randn(3, requires_grad=True)
         with self.assertRaisesRegex(ValueError, "OpOverload"):
-            optests.opcheck(torch.sin, (x,))
+            torch.library.opcheck(torch.sin, (x,))
         with self.assertRaisesRegex(ValueError, "test_utils to be subset of"):
-            optests.opcheck(torch.ops.aten.sin.default, (x,), test_utils="blah")
-        result = optests.opcheck(torch.ops.aten.sin.default, (x,))
+            torch.library.opcheck(torch.ops.aten.sin.default, (x,), test_utils="blah")
+        result = torch.library.opcheck(torch.ops.aten.sin.default, (x,))

         self.assertEqual(
             result,
@@ -2948,7 +2948,7 @@ opcheck(op, args, kwargs, test_utils="test_schema")
             },
         )

-        result = optests.opcheck(
+        result = torch.library.opcheck(
             torch.ops.aten.sin.default, (x,), test_utils="test_schema"
         )
         self.assertEqual(
@@ -2958,7 +2958,7 @@ opcheck(op, args, kwargs, test_utils="test_schema")
             },
         )

-        result = optests.opcheck(
+        result = torch.library.opcheck(
             torch.ops.aten.sin.default,
             (x,),
             test_utils=["test_schema", "test_faketensor"],
@@ -2971,6 +2971,21 @@ opcheck(op, args, kwargs, test_utils="test_schema")
             },
         )

+    def test_opcheck_customopdef(self):
+        sample_inputs = [
+            (torch.randn(3),),
+            (torch.randn(3, requires_grad=True),),
+        ]
+        if torch.cuda.is_available():
+            sample_inputs.extend(
+                [
+                    (torch.randn(3, device="cuda"),),
+                    (torch.randn(3, device="cuda", requires_grad=True),),
+                ]
+            )
+        for args in sample_inputs:
+            torch.library.opcheck(custom_op_db.numpy_cube, args)
+
     def test_is_inside_opcheck_mode(self):
         self.assertFalse(optests.is_inside_opcheck_mode())
         with optests.generate_tests.OpCheckMode(
@@ -2982,9 +2997,9 @@ opcheck(op, args, kwargs, test_utils="test_schema")
         op = op_with_incorrect_schema(self, "foo")
         x = torch.randn(3)
         with self.assertRaisesRegex(Exception, "is not defined to alias output"):
-            optests.opcheck(op, (x,))
+            torch.library.opcheck(op, (x,))

-        result = optests.opcheck(op, (x,), raise_exception=False)
+        result = torch.library.opcheck(op, (x,), raise_exception=False)
         self.assertTrue(isinstance(result["test_schema"], RuntimeError))
         del result["test_schema"]
         self.assertEqual(
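Mirroring test_opcheck above, a short usage sketch of the return value and the test_utils validation; the dict values shown are what the assertions above check for, typically "SUCCESS" per passing check:

    import torch

    x = torch.randn(3, requires_grad=True)

    # Run only a subset of the checks against an existing operator.
    result = torch.library.opcheck(
        torch.ops.aten.sin.default, (x,), test_utils=("test_schema", "test_faketensor")
    )
    print(result)  # e.g. {"test_schema": "SUCCESS", "test_faketensor": "SUCCESS"}

    # Unknown test names are rejected before any check runs.
    try:
        torch.library.opcheck(torch.ops.aten.sin.default, (x,), test_utils="blah")
    except ValueError as err:
        print(err)  # expects test_utils to be a subset of the supported checks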
torch/library.py | 105
@@ -1,5 +1,5 @@
 from ._ops import OpOverload
-from typing import Any, Optional, Set, List, Union, Callable
+from typing import Any, Optional, Set, List, Union, Callable, Tuple, Dict, Sequence
 import traceback
 import torch
 import weakref
@@ -9,7 +9,7 @@ import re
 import contextlib
 import sys
 import warnings
-from torch._library.custom_ops import custom_op, _maybe_get_opdef, device_types_t
+from torch._library.custom_ops import custom_op, _maybe_get_opdef, device_types_t, CustomOpDef
 import torch._library as _library

@@ -743,3 +743,104 @@ def get_ctx() -> "torch._library.abstract_impl.AbstractImplCtx":
     (see :func:`torch.library.register_fake` for more usage details.
     """
     return torch._library.abstract_impl.global_ctx_getter()
+
+
+_OPCHECK_DEFAULT_UTILS = (
+    "test_schema",
+    "test_autograd_registration",
+    "test_faketensor",
+    "test_aot_dispatch_dynamic",
+)
+
+
+def opcheck(
+    op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket, CustomOpDef],
+    args: Tuple[Any, ...],
+    kwargs: Optional[Dict[str, Any]] = None,
+    *,
+    test_utils: Union[str, Sequence[str]] = _OPCHECK_DEFAULT_UTILS,
+    raise_exception: bool = True,
+) -> Dict[str, str]:
+    """Given an operator and some sample arguments, tests if the operator is
+    registered correctly.
+
+    That is, when you use the torch.library/TORCH_LIBRARY APIs to create a
+    custom op, you specify metadata (e.g. mutability info) about the custom op,
+    and these APIs require that the functions you pass them satisfy certain
+    properties (e.g. no data pointer access in the fake/meta/abstract kernel).
+    ``opcheck`` tests these metadata and properties.
+
+    Concretely, we test the following:
+    - test_schema: if the operator's schema is correct.
+    - test_autograd_registration: if autograd was registered correctly.
+    - test_faketensor: If the operator has a FakeTensor kernel
+      (and if it is correct). The FakeTensor kernel is necessary
+      (but not sufficient) for the operator to work with PyTorch compilation
+      APIs (torch.compile/export/FX).
+    - test_aot_dispatch_dynamic: If the operator has correct behavior
+      with PyTorch compilation APIs (torch.compile/export/FX).
+      This checks that the outputs (and gradients, if applicable) are the
+      same under eager-mode PyTorch and torch.compile.
+      This test is a superset of ``test_faketensor``.
+
+    For best results, please call ``opcheck`` multiple times with a
+    representative set of inputs. If your operator supports
+    autograd, please use ``opcheck`` with inputs with ``requires_grad = True``;
+    if your operator supports multiple devices (e.g. CPU and CUDA), please
+    use ``opcheck`` with inputs on all supported devices.
+
+    Args:
+        op: The operator. Must either be a function decorated with
+            :func:`torch.library.custom_op` or an OpOverload/OpOverloadPacket
+            found in torch.ops.* (e.g. torch.ops.aten.sin, torch.ops.mylib.foo)
+        args: The args to the operator
+        kwargs: The kwargs to the operator
+        test_utils: Tests that we should run. Default: all of them.
+            Example: ("test_schema", "test_faketensor")
+        raise_exception: If we should raise an exception on the first
+            error. If False, we will return a dict with information
+            on if each test passed or not.
+
+    .. warning::
+
+        opcheck and :func:`torch.autograd.gradcheck` test different things;
+        opcheck tests if your usage of torch.library APIs is correct while
+        :func:`torch.autograd.gradcheck` tests if your autograd formula is
+        mathematically correct. Use both to test custom ops that support
+        gradient computation.
+
+    Example:
+
+        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
+        >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
+        >>> def numpy_mul(x: Tensor, y: float) -> Tensor:
+        >>>     x_np = x.numpy(force=True)
+        >>>     z_np = x_np * y
+        >>>     return torch.from_numpy(z_np).to(x.device)
+        >>>
+        >>> @numpy_mul.register_fake
+        >>> def _(x, y):
+        >>>     return torch.empty_like(x)
+        >>>
+        >>> def setup_context(ctx, inputs, output):
+        >>>     x, y = inputs
+        >>>     ctx.y = y
+        >>>
+        >>> def backward(ctx, grad):
+        >>>     return grad * ctx.y, None
+        >>>
+        >>> numpy_mul.register_autograd(backward, setup_context=setup_context)
+        >>>
+        >>> sample_inputs = [
+        >>>     (torch.randn(3), 3.14),
+        >>>     (torch.randn(2, 3, device='cuda'), 2.718),
+        >>>     (torch.randn(1, 10, requires_grad=True), 1.234),
+        >>>     (torch.randn(64, 64, device='cuda', requires_grad=True), 90.18),
+        >>> ]
+        >>>
+        >>> for args in sample_inputs:
+        >>>     torch.library.opcheck(numpy_mul, args)
+
+    """
+    import torch.testing._internal.optests as optests
+    return optests.opcheck(op, args, kwargs, test_utils=test_utils, raise_exception=raise_exception)
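Per the new signature, op may be a CustomOpDef, an OpOverloadPacket, or an OpOverload. A small sketch exercising all three forms on one hypothetical op (mylib::my_sin is invented for illustration and is not part of this PR):

    import torch
    from torch import Tensor

    @torch.library.custom_op("mylib::my_sin", mutates_args=())
    def my_sin(x: Tensor) -> Tensor:
        return torch.sin(x)

    @my_sin.register_fake
    def _(x):
        return torch.empty_like(x)

    x = torch.randn(3)
    torch.library.opcheck(my_sin, (x,))                          # CustomOpDef
    torch.library.opcheck(torch.ops.mylib.my_sin, (x,))          # OpOverloadPacket
    torch.library.opcheck(torch.ops.mylib.my_sin.default, (x,))  # OpOverload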
@@ -10,7 +10,7 @@ import re
 import tempfile
 import threading
 import unittest
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

 import torch

@@ -18,6 +18,7 @@ import torch._dynamo

 import torch.utils._pytree as pytree
 from torch._dynamo.utils import clone_input
+from torch._library.custom_ops import CustomOpDef
 from torch._subclasses.schema_check_mode import SchemaCheckMode
 from torch._utils_internal import get_file_path_2
 from torch.overrides import TorchFunctionMode
@@ -620,48 +621,19 @@ def should_print_better_repro() -> None:


 def opcheck(
-    op: torch._ops.OperatorBase,
+    op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket, CustomOpDef],
     args: Tuple[Any, ...],
     kwargs: Optional[Dict[str, Any]] = None,
     *,
-    test_utils: Union[str, List[str]] = DEFAULT_TEST_UTILS,
+    test_utils: Union[str, Sequence[str]] = DEFAULT_TEST_UTILS,
     raise_exception: bool = True,
 ) -> Dict[str, str]:
-    """Given an operator and some sample arguments, tests if the operator is
-    registered correctly.
-
-    We test the following (which are important for correctness in eager-mode
-    PyTorch and with torch.compile):
-    - test_schema: if the operator's schema is correct.
-    - test_autograd_registration: if autograd was registered correctly,
-      i.e. to the correct DispatchKey.
-    - test_faketensor: If the operator has a FakeTensor implementation
-      (and if it is correct).
-    - test_aot_dispatch_static: If the operator works with
-      AOTAutograd/AOTDispatch, which is one of the parts in the PT2 stack.
-      Checks that the outputs (and gradients, if they are computable)
-      of the operator are the same under eager-mode PyTorch and torch.compile.
-    - test_aot_dispatch_dynamic: Same as aot_dispatch_static, but
-      tests dynamic shapes instead of static shapes.
-
-    For best results, please call ``opcheck`` multiple times with a
-    representative set of inputs. For example, if your operator supports
-    autograd, please use ``opcheck`` with inputs that require_grad.
-
-    Args:
-        op: The operator. Should look like torch.ops.aten.foo
-        args: The args to the operator
-        kwargs: The kwargs to the operator
-        test_utils: Tests that we should run. Default: all of them.
-            Example: ["test_schema", "test_faketensor"]
-        raise_exception: If we should raise an exception on the first
-            error. If False, we will return a dict with information
-            on if each test passed or not.
-
-    """
+    """See torch.library.opcheck for docstring"""

     if kwargs is None:
         kwargs = {}
+    if isinstance(op, CustomOpDef):
+        op = op._opoverload
+    if isinstance(op, torch._ops.OpOverloadPacket):
+        op = resolve_unique_overload_or_throw(op)
     if not isinstance(op, torch._ops.OpOverload):
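The hunk above normalizes the accepted op types before validation: a CustomOpDef is unwrapped to its OpOverload, and an OpOverloadPacket is resolved to its unique overload via resolve_unique_overload_or_throw, whose body is not shown in this diff. The sketch below is only a guess at that helper's behavior, built on the public OpOverloadPacket.overloads() accessor:

    import torch

    def resolve_unique_overload_sketch(
        packet: torch._ops.OpOverloadPacket,
    ) -> torch._ops.OpOverload:
        # Assumed behavior: a packet can only be checked if it has exactly one
        # overload, which is then looked up by name on the packet.
        names = packet.overloads()  # e.g. ["default"] for an op defined with torch.library.custom_op
        if len(names) != 1:
            raise RuntimeError(
                f"opcheck can only test a packet with a unique overload, got {names}"
            )
        return getattr(packet, names[0])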