[ONNX] Refactor test_op_consistency.py and test_fx_op_consistency.py (#100172)

## Summary
<!--
copilot:summary
-->
### <samp>🤖 Generated by Copilot at 9255aa3</samp>

This pull request refactors the ONNX operator consistency tests so that `./test/onnx/test_op_consistency.py` and `./test/onnx/test_fx_op_consistency.py` share a common module, `./test/onnx/onnx_test_common.py`, which now defines the constants, types, classes, and helper functions they both need. This improves code quality, readability, and maintainability.
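
For example, after this change a consumer test file no longer redefines `DecorateMeta`, the dtype groups, or the `skip`/`fixme`/`xfail` helpers. The minimal sketch below is condensed from the diffs in this commit rather than copied verbatim from either test file:

```python
# Condensed sketch of a consumer test file after the refactor; the helpers
# (skip, fixme, BOOL_TYPES, INT_TYPES, reason_* formatters) now live in
# ./test/onnx/onnx_test_common.py instead of being redefined per test file.
import onnx_test_common
import torch
from onnx_test_common import fixme, skip

EXPECTED_SKIPS_OR_FAILS = (
    skip(
        "ceil",
        dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
        reason=onnx_test_common.reason_onnx_does_not_support("Ceil"),
    ),
    fixme(
        "ceil",
        dtypes=[torch.float64],
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("Ceil", ["f64"]),
    ),
)
```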

## Walkthrough
<!--
copilot:walkthrough
-->
### <samp>🤖 Generated by Copilot at 9255aa3</samp>

*  Refactor the common code for testing ONNX operators from different files into `./test/onnx/onnx_test_common.py` ([link](https://github.com/pytorch/pytorch/pull/100172/files?diff=unified&w=0#diff-1b38383dc1a0228a835d83bb7c4ba2d0c1bcd41297be5c6336572c525846166eL10-R24), [link](https://github.com/pytorch/pytorch/pull/100172/files?diff=unified&w=0#diff-1b38383dc1a0228a835d83bb7c4ba2d0c1bcd41297be5c6336572c525846166eR33), [link](https://github.com/pytorch/pytorch/pull/100172/files?diff=unified&w=0#diff-1b38383dc1a0228a835d83bb7c4ba2d0c1bcd41297be5c6336572c525846166eR367-R623))
* Remove the unused and duplicated imports, constants, types, and classes for testing ONNX operators from `./test/onnx/test_fx_op_consistency.py` and `./test/onnx/test_op_consistency.py` ([link](https://github.com/pytorch/pytorch/pull/100172/files?diff=unified&w=0#diff-db2f78a51511bb172cbfde1b2f68272b8b33049abe2571cded27bcd0f3ae5fa4L28-R29), [link](https://github.com/pytorch/pytorch/pull/100172/files?diff=unified&w=0#diff-db2f78a51511bb172cbfde1b2f68272b8b33049abe2571cded27bcd0f3ae5fa4L43-R42), [link](https://github.com/pytorch/pytorch/pull/100172/files?diff=unified&w=0#diff-e968c9cb6fc6631cab526cb3a9fe66358c4c6e757e2a223a224b976471bcb753L28-R29), [link](https://github.com/pytorch/pytorch/pull/100172/files?diff=unified&w=0#diff-e968c9cb6fc6631cab526cb3a9fe66358c4c6e757e2a223a224b976471bcb753L43-R44))
* Import the `onnx_test_common` module and its `fixme`, `skip`, and `xfail` helpers in `./test/onnx/test_fx_op_consistency.py` and `./test/onnx/test_op_consistency.py`; usage is sketched after this list ([link](https://github.com/pytorch/pytorch/pull/100172/files?diff=unified&w=0#diff-db2f78a51511bb172cbfde1b2f68272b8b33049abe2571cded27bcd0f3ae5fa4R36), [link](https://github.com/pytorch/pytorch/pull/100172/files?diff=unified&w=0#diff-e968c9cb6fc6631cab526cb3a9fe66358c4c6e757e2a223a224b976471bcb753L37-R37))
* Update the references to the constants, types, functions, and classes for testing ONNX operators in `./test/onnx/test_fx_op_consistency.py` and `./test/onnx/test_op_consistency.py` to use the definitions from `./test/onnx/onnx_test_common.py` ([link](https://github.com/pytorch/pytorch/pull/100172/files?diff=unified&w=0#diff-db2f78a51511bb172cbfde1b2f68272b8b33049abe2571cded27bcd0f3ae5fa4L324-R80), [link](https://github.com/pytorch/pytorch/pull/100172/files?diff=unified&w=0#diff-db2f78a51511bb172cbfde1b2f68272b8b33049abe2571cded27bcd0f3ae5fa4L389-R135), [link](https://github.com/pytorch/pytorch/pull/100172/files?diff=unified&w=0#diff-db2f78a51511bb172cbfde1b2f68272b8b33049abe2571cded27bcd0f3ae5fa4L405-R151), [link](https://github.com/pytorch/pytorch/pull/100172/files?diff=unified&w=0#diff-db2f78a51511bb172cbfde1b2f68272b8b33049abe2571cded27bcd0f3ae5fa4L455-R204), [link](https://github.com/pytorch/pytorch/pull/100172/files?diff=unified&w=0#diff-e968c9cb6fc6631cab526cb3a9fe66358c4c6e757e2a223a224b976471bcb753L333-R107), [link](https://github.com/pytorch/pytorch/pull/100172/files?diff=unified&w=0#diff-e968c9cb6fc6631cab526cb3a9fe66358c4c6e757e2a223a224b976471bcb753L434-R183), [link](https://github.com/pytorch/pytorch/pull/100172/files?diff=unified&w=0#diff-e968c9cb6fc6631cab526cb3a9fe66358c4c6e757e2a223a224b976471bcb753L448-R197), [link](https://github.com/pytorch/pytorch/pull/100172/files?diff=unified&w=0#diff-e968c9cb6fc6631cab526cb3a9fe66358c4c6e757e2a223a224b976471bcb753L494-R246))
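
To make the import and reference updates above concrete, the per-opset decoration loop at the bottom of each test file now also goes through the shared module. The sketch below follows the loop visible in the diffs; `OPS_DB` and the keyword arguments after `"test_output_match"` are assumed from the surrounding code, since the call is truncated in the hunks shown here:

```python
# Sketch of the per-opset decoration loop from the bottom of the updated test
# files; the trailing keyword arguments are assumed, not copied verbatim.
import onnx_test_common
from torch.testing._internal import common_methods_invocations

OPS_DB = common_methods_invocations.op_db  # the full OpInfo database
EXPECTED_SKIPS_OR_FAILS = ()  # in practice, the tuple sketched in the Summary above

for opset in onnx_test_common.TESTED_OPSETS:
    # The name needs to match the parameterized_class name.
    test_class_name = f"TestOnnxModelOutputConsistency_opset{opset}"
    onnx_test_common.add_decorate_info(
        OPS_DB,
        test_class_name,
        "test_output_match",
        opset=opset,
        skip_or_xfails=EXPECTED_SKIPS_OR_FAILS,
    )
```
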
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100172
Approved by: https://github.com/justinchuby
Author: AllenTiTaiWang
Date: 2023-04-27 15:12:40 +00:00
Committed by: PyTorch MergeBot
Parent: 61917a006d
Commit: 1dba53cbab
3 changed files with 307 additions and 541 deletions

test/onnx/onnx_test_common.py

@@ -7,8 +7,21 @@ import copy
import dataclasses
import io
import os
import unittest
import warnings
from typing import Any, Callable, List, Mapping, Optional, Sequence, Tuple, Type, Union
from typing import (
Any,
Callable,
Collection,
Iterable,
List,
Mapping,
Optional,
Sequence,
Tuple,
Type,
Union,
)
import numpy as np
@@ -17,6 +30,7 @@ import pytorch_test_common
import torch
from torch.onnx import _constants, verification
from torch.onnx._internal import _beartype
from torch.testing._internal.opinfo import core as opinfo_core
from torch.types import Number
_NumericType = Union[Number, torch.Tensor, np.ndarray]
@@ -350,3 +364,260 @@ def _compare_pytorch_onnx_with_ort(
torch.testing.assert_close(
ref_output, torch.tensor(ort_output), rtol=rtol, atol=atol
)
# The min onnx opset version to test for
MIN_ONNX_OPSET_VERSION = 9
# The max onnx opset version to test for
MAX_ONNX_OPSET_VERSION = _constants.ONNX_MAX_OPSET
TESTED_OPSETS = range(MIN_ONNX_OPSET_VERSION, MAX_ONNX_OPSET_VERSION + 1)
# TODO(titaiwang): Change this when more versions are supported
# The min onnx opset version to test for
FX_MIN_ONNX_OPSET_VERSION = 18
# The max onnx opset version to test for
FX_MAX_ONNX_OPSET_VERSION = 18
FX_TESTED_OPSETS = range(FX_MIN_ONNX_OPSET_VERSION, FX_MAX_ONNX_OPSET_VERSION + 1)
BOOL_TYPES = (torch.bool,)
INT_TYPES = (
torch.int8,
torch.int16,
torch.int32,
torch.int64,
torch.uint8,
)
QINT_TYPES = (
torch.qint8,
torch.quint8,
)
FLOAT_TYPES = (
torch.float16,
torch.float32,
torch.float64,
)
COMPLEX_TYPES = (
torch.complex32,
torch.complex64,
torch.complex128,
)
TESTED_DTYPES = (
# Boolean
torch.bool,
# Integers
*INT_TYPES,
# Floating types
*FLOAT_TYPES,
)
@dataclasses.dataclass
class DecorateMeta:
"""Information about a test case to skip or xfail.
Adapted from functorch: functorch/test/common_utils.py
Attributes:
op_name: The name of the operator.
variant_name: The name of the OpInfo variant.
decorator: The decorator to apply to the test case.
opsets: The opsets to apply the decorator to.
dtypes: The dtypes to apply the decorator to.
reason: The reason for skipping.
"""
op_name: str
variant_name: str
decorator: Callable
opsets: Optional[Collection[Union[int, Callable[[int], bool]]]]
dtypes: Optional[Collection[torch.dtype]]
reason: str
matcher: Optional[Callable[[Any], Any]] = None
def contains_opset(self, opset: int) -> bool:
if self.opsets is None:
return True
return any(
opset == opset_spec if isinstance(opset_spec, int) else opset_spec(opset)
for opset_spec in self.opsets
)
def xfail(
op_name: str,
variant_name: str = "",
*,
reason: str,
opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None,
dtypes: Optional[Collection[torch.dtype]] = None,
):
"""Expects a OpInfo test to fail.
Args:
op_name: The name of the operator.
variant_name: The name of the variant.
opsets: The opsets to expect the failure. e.g. [9, 10] or [opsets_before(11)]
dtypes: The dtypes to expect the failure.
reason: The reason for the failure.
"""
return DecorateMeta(
op_name=op_name,
variant_name=variant_name,
decorator=unittest.expectedFailure,
opsets=opsets,
dtypes=dtypes,
reason=reason,
)
def skip(
op_name: str,
variant_name: str = "",
*,
reason: str,
opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None,
dtypes: Optional[Collection[torch.dtype]] = None,
matcher: Optional[Callable[[Any], Any]] = None,
):
"""Skips a test case in OpInfo that we don't care about.
Likely because ONNX does not support the use case or it is by design.
Args:
op_name: The name of the operator.
variant_name: The name of the variant.
opsets: The opsets to expect the failure. e.g. [9, 10] or [opsets_before(11)]
dtypes: The dtypes to expect the failure.
reason: The reason for the failure.
matcher: A function that matches the test sample input. It is used only when
skip is in the SKIP_SUBTESTS list.
"""
return DecorateMeta(
op_name=op_name,
variant_name=variant_name,
decorator=unittest.skip(f"Skip: {reason}"),
opsets=opsets,
dtypes=dtypes,
reason=reason,
matcher=matcher,
)
def fixme(
op_name: str,
variant_name: str = "",
*,
reason: str,
opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None,
dtypes: Optional[Collection[torch.dtype]] = None,
matcher: Optional[Callable[[Any], Any]] = None,
):
"""Skips a test case in OpInfo. It should be eventually fixed.
Args:
op_name: The name of the operator.
variant_name: The name of the variant.
opsets: The opsets to expect the failure. e.g. [9, 10] or [opsets_before(11)]
dtypes: The dtypes to expect the failure.
reason: The reason for the failure.
matcher: A function that matches the test sample input. It is used only when
fixme is in the SKIP_SUBTESTS list.
"""
return DecorateMeta(
op_name=op_name,
variant_name=variant_name,
decorator=unittest.skip(f"To fix: {reason}"),
opsets=opsets,
dtypes=dtypes,
reason=reason,
matcher=matcher,
)
def add_decorate_info(
all_opinfos: Sequence[opinfo_core.OpInfo],
test_class_name: str,
base_test_name: str,
opset: int,
skip_or_xfails: Iterable[DecorateMeta],
):
"""Decorates OpInfo tests with decorators based on the skip_or_xfails list.
Args:
all_opinfos: All OpInfos.
test_class_name: The name of the test class.
base_test_name: The name of the test method.
opset: The opset to decorate for.
skip_or_xfails: DecorateMeta's.
"""
ops_mapping = {(info.name, info.variant_test_name): info for info in all_opinfos}
for decorate_meta in skip_or_xfails:
if not decorate_meta.contains_opset(opset):
# Skip does not apply to this opset
continue
opinfo = ops_mapping.get((decorate_meta.op_name, decorate_meta.variant_name))
assert (
opinfo is not None
), f"Couldn't find OpInfo for {decorate_meta}. Did you need to specify variant_name?"
decorators = list(opinfo.decorators)
new_decorator = opinfo_core.DecorateInfo(
decorate_meta.decorator,
test_class_name,
base_test_name,
dtypes=decorate_meta.dtypes,
)
decorators.append(new_decorator)
opinfo.decorators = tuple(decorators)
# This decorator doesn't modify fn in any way
def wrapped(fn):
return fn
return wrapped
def opsets_before(opset: int) -> Callable[[int], bool]:
"""Returns a comparison function that decides if the given opset is before the specified."""
def compare(other_opset: int):
return other_opset < opset
return compare
def opsets_after(opset: int) -> Callable[[int], bool]:
"""Returns a comparison function that decides if the given opset is after the specified."""
def compare(other_opset: int):
return other_opset > opset
return compare
def reason_onnx_runtime_does_not_support(
operator: str, dtypes: Optional[Sequence[str]] = None
) -> str:
"""Formats the reason: ONNX Runtime doesn't support the given dtypes."""
return f"{operator} on {dtypes or 'dtypes'} not supported by ONNX Runtime"
def reason_onnx_does_not_support(
operator: str, dtypes: Optional[Sequence[str]] = None
) -> str:
"""Formats the reason: ONNX doesn't support the given dtypes."""
return f"{operator} on {dtypes or 'certain dtypes'} not supported by the ONNX Spec"
def reason_jit_tracer_error(info: str) -> str:
"""Formats the reason: JIT tracer errors."""
return f"JIT tracer error on {info}"
def reason_flaky() -> str:
"""Formats the reason: test is flaky."""
return "flaky test"

test/onnx/test_fx_op_consistency.py

@@ -25,274 +25,20 @@ Note:
from __future__ import annotations
import copy
import dataclasses
import unittest
import warnings
from typing import Any, Callable, Collection, Iterable, Optional, Sequence, Tuple, Union
from typing import Optional, Tuple
import onnx_test_common
import parameterized
import torch
from onnx_test_common import fixme, skip, xfail
from torch.testing._internal import (
common_device_type,
common_methods_invocations,
common_utils,
)
from torch.testing._internal.opinfo import core as opinfo_core
# TODO(titaiwang): Change this when more versions are supported
# The min onnx opset version to test for
MIN_ONNX_OPSET_VERSION = 18
# The max onnx opset version to test for
MAX_ONNX_OPSET_VERSION = 18
TESTED_OPSETS = range(MIN_ONNX_OPSET_VERSION, MAX_ONNX_OPSET_VERSION + 1)
BOOL_TYPES = (torch.bool,)
INT_TYPES = (
torch.int8,
torch.int16,
torch.int32,
torch.int64,
torch.uint8,
)
QINT_TYPES = (
torch.qint8,
torch.quint8,
)
FLOAT_TYPES = (
torch.float16,
torch.float32,
torch.float64,
)
COMPLEX_TYPES = (
torch.complex32,
torch.complex64,
torch.complex128,
)
TESTED_DTYPES = (
# Boolean
torch.bool,
# Integers
*INT_TYPES,
# Floating types
*FLOAT_TYPES,
)
@dataclasses.dataclass
class DecorateMeta:
"""Information about a test case to skip or xfail.
Adapted from functorch: functorch/test/common_utils.py
Attributes:
op_name: The name of the operator.
variant_name: The name of the OpInfo variant.
decorator: The decorator to apply to the test case.
opsets: The opsets to apply the decorator to.
dtypes: The dtypes to apply the decorator to.
reason: The reason for skipping.
"""
op_name: str
variant_name: str
decorator: Callable
opsets: Optional[Collection[Union[int, Callable[[int], bool]]]]
dtypes: Optional[Collection[torch.dtype]]
reason: str
matcher: Optional[Callable[[Any], Any]] = None
def contains_opset(self, opset: int) -> bool:
if self.opsets is None:
return True
return any(
opset == opset_spec if isinstance(opset_spec, int) else opset_spec(opset)
for opset_spec in self.opsets
)
def xfail(
op_name: str,
variant_name: str = "",
*,
reason: str,
opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None,
dtypes: Optional[Collection[torch.dtype]] = None,
):
"""Expects a OpInfo test to fail.
Args:
op_name: The name of the operator.
variant_name: The name of the variant.
opsets: The opsets to expect the failure. e.g. [9, 10] or [opsets_before(11)]
dtypes: The dtypes to expect the failure.
reason: The reason for the failure.
"""
return DecorateMeta(
op_name=op_name,
variant_name=variant_name,
decorator=unittest.expectedFailure,
opsets=opsets,
dtypes=dtypes,
reason=reason,
)
def skip(
op_name: str,
variant_name: str = "",
*,
reason: str,
opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None,
dtypes: Optional[Collection[torch.dtype]] = None,
matcher: Optional[Callable[[Any], Any]] = None,
):
"""Skips a test case in OpInfo that we don't care about.
Likely because ONNX does not support the use case or it is by design.
Args:
op_name: The name of the operator.
variant_name: The name of the variant.
opsets: The opsets to expect the failure. e.g. [9, 10] or [opsets_before(11)]
dtypes: The dtypes to expect the failure.
reason: The reason for the failure.
matcher: A function that matches the test sample input. It is used only when
skip is in the SKIP_SUBTESTS list.
"""
return DecorateMeta(
op_name=op_name,
variant_name=variant_name,
decorator=unittest.skip(f"Skip: {reason}"),
opsets=opsets,
dtypes=dtypes,
reason=reason,
matcher=matcher,
)
def fixme(
op_name: str,
variant_name: str = "",
*,
reason: str,
opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None,
dtypes: Optional[Collection[torch.dtype]] = None,
matcher: Optional[Callable[[Any], Any]] = None,
):
"""Skips a test case in OpInfo. It should be eventually fixed.
Args:
op_name: The name of the operator.
variant_name: The name of the variant.
opsets: The opsets to expect the failure. e.g. [9, 10] or [opsets_before(11)]
dtypes: The dtypes to expect the failure.
reason: The reason for the failure.
matcher: A function that matches the test sample input. It is used only when
fixme is in the SKIP_SUBTESTS list.
"""
return DecorateMeta(
op_name=op_name,
variant_name=variant_name,
decorator=unittest.skip(f"To fix: {reason}"),
opsets=opsets,
dtypes=dtypes,
reason=reason,
matcher=matcher,
)
def add_decorate_info(
all_opinfos: Sequence[opinfo_core.OpInfo],
test_class_name: str,
base_test_name: str,
opset: int,
skip_or_xfails: Iterable[DecorateMeta],
):
"""Decorates OpInfo tests with decorators based on the skip_or_xfails list.
Args:
all_opinfos: All OpInfos.
test_class_name: The name of the test class.
base_test_name: The name of the test method.
opset: The opset to decorate for.
skip_or_xfails: DecorateMeta's.
"""
ops_mapping = {(info.name, info.variant_test_name): info for info in all_opinfos}
for decorate_meta in skip_or_xfails:
if not decorate_meta.contains_opset(opset):
# Skip does not apply to this opset
continue
opinfo = ops_mapping.get((decorate_meta.op_name, decorate_meta.variant_name))
assert (
opinfo is not None
), f"Couldn't find OpInfo for {decorate_meta}. Did you need to specify variant_name?"
decorators = list(opinfo.decorators)
new_decorator = opinfo_core.DecorateInfo(
decorate_meta.decorator,
test_class_name,
base_test_name,
dtypes=decorate_meta.dtypes,
)
decorators.append(new_decorator)
opinfo.decorators = tuple(decorators)
# This decorator doesn't modify fn in any way
def wrapped(fn):
return fn
return wrapped
def opsets_before(opset: int) -> Callable[[int], bool]:
"""Returns a comparison function that decides if the given opset is before the specified."""
def compare(other_opset: int):
return other_opset < opset
return compare
def opsets_after(opset: int) -> Callable[[int], bool]:
"""Returns a comparison function that decides if the given opset is after the specified."""
def compare(other_opset: int):
return other_opset > opset
return compare
def reason_onnx_runtime_does_not_support(
operator: str, dtypes: Optional[Sequence[str]] = None
) -> str:
"""Formats the reason: ONNX Runtime doesn't support the given dtypes."""
return f"{operator} on {dtypes or 'dtypes'} not supported by ONNX Runtime"
def reason_onnx_does_not_support(
operator: str, dtypes: Optional[Sequence[str]] = None
) -> str:
"""Formats the reason: ONNX doesn't support the given dtypes."""
return f"{operator} on {dtypes or 'certain dtypes'} not supported by the ONNX Spec"
def reason_jit_tracer_error(info: str) -> str:
"""Formats the reason: JIT tracer errors."""
return f"JIT tracer error on {info}"
def reason_flaky() -> str:
"""Formats the reason: test is flaky."""
return "flaky test"
# Modify this section ##########################################################
# NOTE: Modify this section as more ops are supported. The list should be sorted
@@ -321,17 +67,17 @@ TESTED_OPS: frozenset[str] = frozenset(
# e.g. the test is flaky or some tests pass. Otherwise, use xfail.
# Use skip if we don't care about the test passing, e.g. ONNX doesn't support the usage.
# Use xfail if a test fails now and we want to eventually fix the test.
EXPECTED_SKIPS_OR_FAILS: Tuple[DecorateMeta, ...] = (
EXPECTED_SKIPS_OR_FAILS: Tuple[onnx_test_common.DecorateMeta, ...] = (
skip(
"ceil", dtypes=BOOL_TYPES + INT_TYPES,
reason=reason_onnx_does_not_support("Ceil")
"ceil", dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
reason=onnx_test_common.reason_onnx_does_not_support("Ceil")
),
fixme("ceil", dtypes=[torch.float64], reason=reason_onnx_runtime_does_not_support("Ceil", ["f64"])),
fixme("ceil", dtypes=[torch.float64], reason=onnx_test_common.reason_onnx_runtime_does_not_support("Ceil", ["f64"])),
xfail("unflatten", reason="AssertionError: Expected 1 inputs, got 3 (https://github.com/pytorch/pytorch/issues/99534)"),
)
# fmt: on
SKIP_SUBTESTS: tuple[DecorateMeta, ...] = (
SKIP_SUBTESTS: tuple[onnx_test_common.DecorateMeta, ...] = (
fixme(
"unflatten",
reason="Logic not implemented for size 0 inputs in op.Reshape",
@@ -386,7 +132,7 @@ def _get_test_class_name(cls, num, params_dict) -> str:
"name": f"TestOnnxModelOutputConsistency_opset{opset}",
"opset_version": opset,
}
for opset in TESTED_OPSETS
for opset in onnx_test_common.FX_TESTED_OPSETS
],
class_name_func=_get_test_class_name,
)
@@ -402,7 +148,7 @@ class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime):
@common_device_type.ops(
[op for op in OPS_DB if op.name in TESTED_OPS],
allowed_dtypes=TESTED_DTYPES,
allowed_dtypes=onnx_test_common.TESTED_DTYPES,
)
def test_output_match(self, device: str, dtype: torch.dtype, op):
"""Test the ONNX exporter."""
@@ -452,10 +198,10 @@ class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime):
)
for opset in TESTED_OPSETS:
for opset in onnx_test_common.FX_TESTED_OPSETS:
# The name needs to match the parameterized_class name.
test_class_name = f"TestOnnxModelOutputConsistency_opset{opset}"
add_decorate_info(
onnx_test_common.add_decorate_info(
OPS_DB,
test_class_name,
"test_output_match",

test/onnx/test_op_consistency.py

@@ -25,272 +25,21 @@ Note:
from __future__ import annotations
import copy
import dataclasses
import unittest
import warnings
from typing import Any, Callable, Collection, Iterable, Optional, Sequence, Tuple, Union
from typing import Optional, Tuple
import onnx_test_common
import parameterized
import torch
from torch.onnx import _constants
# For readability, these two are allowed to be imported as function
from onnx_test_common import fixme, skip
from torch.testing._internal import (
common_device_type,
common_methods_invocations,
common_utils,
)
from torch.testing._internal.opinfo import core as opinfo_core
# The min onnx opset version to test for
MIN_ONNX_OPSET_VERSION = 9
# The max onnx opset version to test for
MAX_ONNX_OPSET_VERSION = _constants.ONNX_MAX_OPSET
TESTED_OPSETS = range(MIN_ONNX_OPSET_VERSION, MAX_ONNX_OPSET_VERSION + 1)
BOOL_TYPES = (torch.bool,)
INT_TYPES = (
torch.int8,
torch.int16,
torch.int32,
torch.int64,
torch.uint8,
)
QINT_TYPES = (
torch.qint8,
torch.quint8,
)
FLOAT_TYPES = (
torch.float16,
torch.float32,
torch.float64,
)
COMPLEX_TYPES = (
torch.complex32,
torch.complex64,
torch.complex128,
)
TESTED_DTYPES = (
# Boolean
torch.bool,
# Integers
*INT_TYPES,
# Floating types
*FLOAT_TYPES,
)
@dataclasses.dataclass
class DecorateMeta:
"""Information about a test case to skip or xfail.
Adapted from functorch: functorch/test/common_utils.py
Attributes:
op_name: The name of the operator.
variant_name: The name of the OpInfo variant.
decorator: The decorator to apply to the test case.
opsets: The opsets to apply the decorator to.
dtypes: The dtypes to apply the decorator to.
reason: The reason for skipping.
"""
op_name: str
variant_name: str
decorator: Callable
opsets: Optional[Collection[Union[int, Callable[[int], bool]]]]
dtypes: Optional[Collection[torch.dtype]]
reason: str
matcher: Optional[Callable[[Any], Any]] = None
def contains_opset(self, opset: int) -> bool:
if self.opsets is None:
return True
return any(
opset == opset_spec if isinstance(opset_spec, int) else opset_spec(opset)
for opset_spec in self.opsets
)
def xfail(
op_name: str,
variant_name: str = "",
*,
reason: str,
opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None,
dtypes: Optional[Collection[torch.dtype]] = None,
):
"""Expects a OpInfo test to fail.
Args:
op_name: The name of the operator.
variant_name: The name of the variant.
opsets: The opsets to expect the failure. e.g. [9, 10] or [opsets_before(11)]
dtypes: The dtypes to expect the failure.
reason: The reason for the failure.
"""
return DecorateMeta(
op_name=op_name,
variant_name=variant_name,
decorator=unittest.expectedFailure,
opsets=opsets,
dtypes=dtypes,
reason=reason,
)
def skip(
op_name: str,
variant_name: str = "",
*,
reason: str,
opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None,
dtypes: Optional[Collection[torch.dtype]] = None,
matcher: Optional[Callable[[Any], Any]] = None,
):
"""Skips a test case in OpInfo that we don't care about.
Likely because ONNX does not support the use case or it is by design.
Args:
op_name: The name of the operator.
variant_name: The name of the variant.
opsets: The opsets to expect the failure. e.g. [9, 10] or [opsets_before(11)]
dtypes: The dtypes to expect the failure.
reason: The reason for the failure.
matcher: A function that matches the test sample input. It is used only when
skip is in the SKIP_SUBTESTS list.
"""
return DecorateMeta(
op_name=op_name,
variant_name=variant_name,
decorator=unittest.skip(f"Skip: {reason}"),
opsets=opsets,
dtypes=dtypes,
reason=reason,
matcher=matcher,
)
def fixme(
op_name: str,
variant_name: str = "",
*,
reason: str,
opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None,
dtypes: Optional[Collection[torch.dtype]] = None,
matcher: Optional[Callable[[Any], Any]] = None,
):
"""Skips a test case in OpInfo. It should be eventually fixed.
Args:
op_name: The name of the operator.
variant_name: The name of the variant.
opsets: The opsets to expect the failure. e.g. [9, 10] or [opsets_before(11)]
dtypes: The dtypes to expect the failure.
reason: The reason for the failure.
matcher: A function that matches the test sample input. It is used only when
fixme is in the SKIP_SUBTESTS list.
"""
return DecorateMeta(
op_name=op_name,
variant_name=variant_name,
decorator=unittest.skip(f"To fix: {reason}"),
opsets=opsets,
dtypes=dtypes,
reason=reason,
matcher=matcher,
)
def add_decorate_info(
all_opinfos: Sequence[opinfo_core.OpInfo],
test_class_name: str,
base_test_name: str,
opset: int,
skip_or_xfails: Iterable[DecorateMeta],
):
"""Decorates OpInfo tests with decorators based on the skip_or_xfails list.
Args:
all_opinfos: All OpInfos.
test_class_name: The name of the test class.
base_test_name: The name of the test method.
opset: The opset to decorate for.
skip_or_xfails: DecorateMeta's.
"""
ops_mapping = {(info.name, info.variant_test_name): info for info in all_opinfos}
for decorate_meta in skip_or_xfails:
if not decorate_meta.contains_opset(opset):
# Skip does not apply to this opset
continue
opinfo = ops_mapping.get((decorate_meta.op_name, decorate_meta.variant_name))
assert (
opinfo is not None
), f"Couldn't find OpInfo for {decorate_meta}. Did you need to specify variant_name?"
decorators = list(opinfo.decorators)
new_decorator = opinfo_core.DecorateInfo(
decorate_meta.decorator,
test_class_name,
base_test_name,
dtypes=decorate_meta.dtypes,
)
decorators.append(new_decorator)
opinfo.decorators = tuple(decorators)
# This decorator doesn't modify fn in any way
def wrapped(fn):
return fn
return wrapped
def opsets_before(opset: int) -> Callable[[int], bool]:
"""Returns a comparison function that decides if the given opset is before the specified."""
def compare(other_opset: int):
return other_opset < opset
return compare
def opsets_after(opset: int) -> Callable[[int], bool]:
"""Returns a comparison function that decides if the given opset is after the specified."""
def compare(other_opset: int):
return other_opset > opset
return compare
def reason_onnx_runtime_does_not_support(
operator: str, dtypes: Optional[Sequence[str]] = None
) -> str:
"""Formats the reason: ONNX Runtime doesn't support the given dtypes."""
return f"{operator} on {dtypes or 'dtypes'} not supported by ONNX Runtime"
def reason_onnx_does_not_support(
operator: str, dtypes: Optional[Sequence[str]] = None
) -> str:
"""Formats the reason: ONNX doesn't support the given dtypes."""
return f"{operator} on {dtypes or 'certain dtypes'} not supported by the ONNX Spec"
def reason_jit_tracer_error(info: str) -> str:
"""Formats the reason: JIT tracer errors."""
return f"JIT tracer error on {info}"
def reason_flaky() -> str:
"""Formats the reason: test is flaky."""
return "flaky test"
# Modify this section ##########################################################
@@ -330,32 +79,32 @@ TESTED_OPS: frozenset[str] = frozenset(
# e.g. the test is flaky or some tests pass. Otherwise, use xfail.
# Use skip if we don't care about the test passing, e.g. ONNX doesn't support the usage.
# Use xfail if a test fails now and we want to eventually fix the test.
EXPECTED_SKIPS_OR_FAILS: Tuple[DecorateMeta, ...] = (
EXPECTED_SKIPS_OR_FAILS: Tuple[onnx_test_common.DecorateMeta, ...] = (
skip(
"atan", dtypes=BOOL_TYPES + INT_TYPES,
reason=reason_onnx_does_not_support("Atan")
"atan", dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
reason=onnx_test_common.reason_onnx_does_not_support("Atan")
),
fixme("atan", dtypes=[torch.float64], reason=reason_onnx_runtime_does_not_support("Atan", ["f64"])),
fixme("atan", dtypes=[torch.float64], reason=onnx_test_common.reason_onnx_runtime_does_not_support("Atan", ["f64"])),
skip(
"atan2", dtypes=BOOL_TYPES + INT_TYPES,
reason=reason_onnx_does_not_support("Atan")
"atan2", dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
reason=onnx_test_common.reason_onnx_does_not_support("Atan")
),
fixme("atan2", dtypes=[torch.float64], reason=reason_onnx_runtime_does_not_support("Atan", ["f64"])),
fixme("atan2", dtypes=[torch.float64], reason=onnx_test_common.reason_onnx_runtime_does_not_support("Atan", ["f64"])),
skip(
"ceil", dtypes=BOOL_TYPES + INT_TYPES,
reason=reason_onnx_does_not_support("Ceil")
"ceil", dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
reason=onnx_test_common.reason_onnx_does_not_support("Ceil")
),
fixme("ceil", dtypes=[torch.float64], reason=reason_onnx_runtime_does_not_support("Ceil", ["f64"])),
skip("nn.functional.scaled_dot_product_attention", opsets=[opsets_before(14)], reason="Need Trilu."),
fixme("ceil", dtypes=[torch.float64], reason=onnx_test_common.reason_onnx_runtime_does_not_support("Ceil", ["f64"])),
skip("nn.functional.scaled_dot_product_attention", opsets=[onnx_test_common.opsets_before(14)], reason="Need Trilu."),
fixme("nn.functional.scaled_dot_product_attention", reason="fixme: ORT crashes on Windows, segfaults randomly on Linux"),
skip("sqrt", dtypes=BOOL_TYPES, reason=reason_onnx_does_not_support("Sqrt")),
skip("stft", opsets=[opsets_before(17)], reason=reason_onnx_does_not_support("STFT")),
skip("tile", opsets=[opsets_before(13)], reason=reason_onnx_does_not_support("Tile")),
fixme("unflatten", opsets=[opsets_before(13)], reason="Helper function is needed to support legacy ops."),
skip("sqrt", dtypes=onnx_test_common.BOOL_TYPES, reason=onnx_test_common.reason_onnx_does_not_support("Sqrt")),
skip("stft", opsets=[onnx_test_common.opsets_before(17)], reason=onnx_test_common.reason_onnx_does_not_support("STFT")),
skip("tile", opsets=[onnx_test_common.opsets_before(13)], reason=onnx_test_common.reason_onnx_does_not_support("Tile")),
fixme("unflatten", opsets=[onnx_test_common.opsets_before(13)], reason="Helper function is needed to support legacy ops."),
)
# fmt: on
SKIP_SUBTESTS: tuple[DecorateMeta, ...] = (
SKIP_SUBTESTS: tuple[onnx_test_common.DecorateMeta, ...] = (
skip(
"nn.functional.scaled_dot_product_attention",
matcher=lambda sample: sample.kwargs.get("dropout_p") != 0.0,
@@ -431,7 +180,7 @@ def _get_test_class_name(cls, num, params_dict) -> str:
"name": f"TestOnnxModelOutputConsistency_opset{opset}",
"opset_version": opset,
}
for opset in TESTED_OPSETS
for opset in onnx_test_common.TESTED_OPSETS
],
class_name_func=_get_test_class_name,
)
@@ -445,7 +194,7 @@ class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime):
@common_device_type.ops(
[op for op in OPS_DB if op.name in TESTED_OPS],
allowed_dtypes=TESTED_DTYPES,
allowed_dtypes=onnx_test_common.TESTED_DTYPES,
)
def test_output_match(self, device: str, dtype: torch.dtype, op):
"""Test the ONNX exporter."""
@@ -491,10 +240,10 @@ class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime):
self.run_test(model, inputs, rtol=rtol, atol=atol)
for opset in TESTED_OPSETS:
for opset in onnx_test_common.TESTED_OPSETS:
# The name needs to match the parameterized_class name.
test_class_name = f"TestOnnxModelOutputConsistency_opset{opset}"
add_decorate_info(
onnx_test_common.add_decorate_info(
OPS_DB,
test_class_name,
"test_output_match",