[ONNX] Refactor test_op_consistenct.py and test_fx_op_consistency.py (#100172)

## Summary
<!--
copilot:summary
-->
### <samp>🤖 Generated by Copilot at 9255aa3</samp>

This pull request refactors the ONNX operator testing code to use a common module `./test/onnx/onnx_test_common.py` that defines constants, types, classes, and functions for testing ONNX operators. This improves the code quality, readability, and maintainability.

## Walkthrough
<!--
copilot:walkthrough
-->
### <samp>🤖 Generated by Copilot at 9255aa3</samp>

*  Refactor the common code for testing ONNX operators from different files into `./test/onnx/onnx_test_common.py` ([link](https://github.com/pytorch/pytorch/pull/100172/files?diff=unified&w=0#diff-1b38383dc1a0228a835d83bb7c4ba2d0c1bcd41297be5c6336572c525846166eL10-R24), [link](https://github.com/pytorch/pytorch/pull/100172/files?diff=unified&w=0#diff-1b38383dc1a0228a835d83bb7c4ba2d0c1bcd41297be5c6336572c525846166eR33), [link](https://github.com/pytorch/pytorch/pull/100172/files?diff=unified&w=0#diff-1b38383dc1a0228a835d83bb7c4ba2d0c1bcd41297be5c6336572c525846166eR367-R623))
* Remove the unused and duplicated imports, constants, types, and classes for testing ONNX operators from `./test/onnx/test_fx_op_consistency.py` and `./test/onnx/test_op_consistency.py` ([link](https://github.com/pytorch/pytorch/pull/100172/files?diff=unified&w=0#diff-db2f78a51511bb172cbfde1b2f68272b8b33049abe2571cded27bcd0f3ae5fa4L28-R29), [link](https://github.com/pytorch/pytorch/pull/100172/files?diff=unified&w=0#diff-db2f78a51511bb172cbfde1b2f68272b8b33049abe2571cded27bcd0f3ae5fa4L43-R42), [link](https://github.com/pytorch/pytorch/pull/100172/files?diff=unified&w=0#diff-e968c9cb6fc6631cab526cb3a9fe66358c4c6e757e2a223a224b976471bcb753L28-R29), [link](https://github.com/pytorch/pytorch/pull/100172/files?diff=unified&w=0#diff-e968c9cb6fc6631cab526cb3a9fe66358c4c6e757e2a223a224b976471bcb753L43-R44))
* Import the `unittest`, `opinfo_core`, and `onnx_test_common` modules and the `fixme`, `skip`, and `xfail` functions in `./test/onnx/test_fx_op_consistency.py` and `./test/onnx/test_op_consistency.py` ( [link](https://github.com/pytorch/pytorch/pull/100172/files?diff=unified&w=0#diff-db2f78a51511bb172cbfde1b2f68272b8b33049abe2571cded27bcd0f3ae5fa4R36), [link](https://github.com/pytorch/pytorch/pull/100172/files?diff=unified&w=0#diff-e968c9cb6fc6631cab526cb3a9fe66358c4c6e757e2a223a224b976471bcb753L37-R37))
* Update the references to the constants, types, functions, and classes for testing ONNX operators in `./test/onnx/test_fx_op_consistency.py` and `./test/onnx/test_op_consistency.py` to use the definitions from `./test/onnx/onnx_test_common.py` ([link](https://github.com/pytorch/pytorch/pull/100172/files?diff=unified&w=0#diff-db2f78a51511bb172cbfde1b2f68272b8b33049abe2571cded27bcd0f3ae5fa4L324-R80), [link](https://github.com/pytorch/pytorch/pull/100172/files?diff=unified&w=0#diff-db2f78a51511bb172cbfde1b2f68272b8b33049abe2571cded27bcd0f3ae5fa4L389-R135), [link](https://github.com/pytorch/pytorch/pull/100172/files?diff=unified&w=0#diff-db2f78a51511bb172cbfde1b2f68272b8b33049abe2571cded27bcd0f3ae5fa4L405-R151), [link](https://github.com/pytorch/pytorch/pull/100172/files?diff=unified&w=0#diff-db2f78a51511bb172cbfde1b2f68272b8b33049abe2571cded27bcd0f3ae5fa4L455-R204), [link](https://github.com/pytorch/pytorch/pull/100172/files?diff=unified&w=0#diff-e968c9cb6fc6631cab526cb3a9fe66358c4c6e757e2a223a224b976471bcb753L333-R107), [link](https://github.com/pytorch/pytorch/pull/100172/files?diff=unified&w=0#diff-e968c9cb6fc6631cab526cb3a9fe66358c4c6e757e2a223a224b976471bcb753L434-R183), [link](https://github.com/pytorch/pytorch/pull/100172/files?diff=unified&w=0#diff-e968c9cb6fc6631cab526cb3a9fe66358c4c6e757e2a223a224b976471bcb753L448-R197), [link](https://github.com/pytorch/pytorch/pull/100172/files?diff=unified&w=0#diff-e968c9cb6fc6631cab526cb3a9fe66358c4c6e757e2a223a224b976471bcb753L494-R246))
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100172
Approved by: https://github.com/justinchuby
This commit is contained in:
AllenTiTaiWang
2023-04-27 15:12:40 +00:00
committed by PyTorch MergeBot
parent 61917a006d
commit 1dba53cbab
3 changed files with 307 additions and 541 deletions

View File

@ -7,8 +7,21 @@ import copy
import dataclasses import dataclasses
import io import io
import os import os
import unittest
import warnings import warnings
from typing import Any, Callable, List, Mapping, Optional, Sequence, Tuple, Type, Union from typing import (
Any,
Callable,
Collection,
Iterable,
List,
Mapping,
Optional,
Sequence,
Tuple,
Type,
Union,
)
import numpy as np import numpy as np
@ -17,6 +30,7 @@ import pytorch_test_common
import torch import torch
from torch.onnx import _constants, verification from torch.onnx import _constants, verification
from torch.onnx._internal import _beartype from torch.onnx._internal import _beartype
from torch.testing._internal.opinfo import core as opinfo_core
from torch.types import Number from torch.types import Number
_NumericType = Union[Number, torch.Tensor, np.ndarray] _NumericType = Union[Number, torch.Tensor, np.ndarray]
@ -350,3 +364,260 @@ def _compare_pytorch_onnx_with_ort(
torch.testing.assert_close( torch.testing.assert_close(
ref_output, torch.tensor(ort_output), rtol=rtol, atol=atol ref_output, torch.tensor(ort_output), rtol=rtol, atol=atol
) )
# The min onnx opset version to test for
MIN_ONNX_OPSET_VERSION = 9
# The max onnx opset version to test for
MAX_ONNX_OPSET_VERSION = _constants.ONNX_MAX_OPSET
TESTED_OPSETS = range(MIN_ONNX_OPSET_VERSION, MAX_ONNX_OPSET_VERSION + 1)
# TODO(titaiwang): Change this when more versions are supported
# The min onnx opset version to test for
FX_MIN_ONNX_OPSET_VERSION = 18
# The max onnx opset version to test for
FX_MAX_ONNX_OPSET_VERSION = 18
FX_TESTED_OPSETS = range(FX_MIN_ONNX_OPSET_VERSION, FX_MAX_ONNX_OPSET_VERSION + 1)
BOOL_TYPES = (torch.bool,)
INT_TYPES = (
torch.int8,
torch.int16,
torch.int32,
torch.int64,
torch.uint8,
)
QINT_TYPES = (
torch.qint8,
torch.quint8,
)
FLOAT_TYPES = (
torch.float16,
torch.float32,
torch.float64,
)
COMPLEX_TYPES = (
torch.complex32,
torch.complex64,
torch.complex128,
)
TESTED_DTYPES = (
# Boolean
torch.bool,
# Integers
*INT_TYPES,
# Floating types
*FLOAT_TYPES,
)
@dataclasses.dataclass
class DecorateMeta:
"""Information about a test case to skip or xfail.
Adapted from functorch: functorch/test/common_utils.py
Attributes:
op_name: The name of the operator.
variant_name: The name of the OpInfo variant.
decorator: The decorator to apply to the test case.
opsets: The opsets to apply the decorator to.
dtypes: The dtypes to apply the decorator to.
reason: The reason for skipping.
"""
op_name: str
variant_name: str
decorator: Callable
opsets: Optional[Collection[Union[int, Callable[[int], bool]]]]
dtypes: Optional[Collection[torch.dtype]]
reason: str
matcher: Optional[Callable[[Any], Any]] = None
def contains_opset(self, opset: int) -> bool:
if self.opsets is None:
return True
return any(
opset == opset_spec if isinstance(opset_spec, int) else opset_spec(opset)
for opset_spec in self.opsets
)
def xfail(
op_name: str,
variant_name: str = "",
*,
reason: str,
opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None,
dtypes: Optional[Collection[torch.dtype]] = None,
):
"""Expects a OpInfo test to fail.
Args:
op_name: The name of the operator.
variant_name: The name of the variant.
opsets: The opsets to expect the failure. e.g. [9, 10] or [opsets_before(11)]
dtypes: The dtypes to expect the failure.
reason: The reason for the failure.
"""
return DecorateMeta(
op_name=op_name,
variant_name=variant_name,
decorator=unittest.expectedFailure,
opsets=opsets,
dtypes=dtypes,
reason=reason,
)
def skip(
op_name: str,
variant_name: str = "",
*,
reason: str,
opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None,
dtypes: Optional[Collection[torch.dtype]] = None,
matcher: Optional[Callable[[Any], Any]] = None,
):
"""Skips a test case in OpInfo that we don't care about.
Likely because ONNX does not support the use case or it is by design.
Args:
op_name: The name of the operator.
variant_name: The name of the variant.
opsets: The opsets to expect the failure. e.g. [9, 10] or [opsets_before(11)]
dtypes: The dtypes to expect the failure.
reason: The reason for the failure.
matcher: A function that matches the test sample input. It is used only when
skip is in the SKIP_SUBTESTS list.
"""
return DecorateMeta(
op_name=op_name,
variant_name=variant_name,
decorator=unittest.skip(f"Skip: {reason}"),
opsets=opsets,
dtypes=dtypes,
reason=reason,
matcher=matcher,
)
def fixme(
op_name: str,
variant_name: str = "",
*,
reason: str,
opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None,
dtypes: Optional[Collection[torch.dtype]] = None,
matcher: Optional[Callable[[Any], Any]] = None,
):
"""Skips a test case in OpInfo. It should be eventually fixed.
Args:
op_name: The name of the operator.
variant_name: The name of the variant.
opsets: The opsets to expect the failure. e.g. [9, 10] or [opsets_before(11)]
dtypes: The dtypes to expect the failure.
reason: The reason for the failure.
matcher: A function that matches the test sample input. It is used only when
fixme is in the SKIP_SUBTESTS list.
"""
return DecorateMeta(
op_name=op_name,
variant_name=variant_name,
decorator=unittest.skip(f"To fix: {reason}"),
opsets=opsets,
dtypes=dtypes,
reason=reason,
matcher=matcher,
)
def add_decorate_info(
all_opinfos: Sequence[opinfo_core.OpInfo],
test_class_name: str,
base_test_name: str,
opset: int,
skip_or_xfails: Iterable[DecorateMeta],
):
"""Decorates OpInfo tests with decorators based on the skip_or_xfails list.
Args:
all_opinfos: All OpInfos.
test_class_name: The name of the test class.
base_test_name: The name of the test method.
opset: The opset to decorate for.
skip_or_xfails: DecorateMeta's.
"""
ops_mapping = {(info.name, info.variant_test_name): info for info in all_opinfos}
for decorate_meta in skip_or_xfails:
if not decorate_meta.contains_opset(opset):
# Skip does not apply to this opset
continue
opinfo = ops_mapping.get((decorate_meta.op_name, decorate_meta.variant_name))
assert (
opinfo is not None
), f"Couldn't find OpInfo for {decorate_meta}. Did you need to specify variant_name?"
decorators = list(opinfo.decorators)
new_decorator = opinfo_core.DecorateInfo(
decorate_meta.decorator,
test_class_name,
base_test_name,
dtypes=decorate_meta.dtypes,
)
decorators.append(new_decorator)
opinfo.decorators = tuple(decorators)
# This decorator doesn't modify fn in any way
def wrapped(fn):
return fn
return wrapped
def opsets_before(opset: int) -> Callable[[int], bool]:
"""Returns a comparison function that decides if the given opset is before the specified."""
def compare(other_opset: int):
return other_opset < opset
return compare
def opsets_after(opset: int) -> Callable[[int], bool]:
"""Returns a comparison function that decides if the given opset is after the specified."""
def compare(other_opset: int):
return other_opset > opset
return compare
def reason_onnx_runtime_does_not_support(
operator: str, dtypes: Optional[Sequence[str]] = None
) -> str:
"""Formats the reason: ONNX Runtime doesn't support the given dtypes."""
return f"{operator} on {dtypes or 'dtypes'} not supported by ONNX Runtime"
def reason_onnx_does_not_support(
operator: str, dtypes: Optional[Sequence[str]] = None
) -> str:
"""Formats the reason: ONNX doesn't support the given dtypes."""
return f"{operator} on {dtypes or 'certain dtypes'} not supported by the ONNX Spec"
def reason_jit_tracer_error(info: str) -> str:
"""Formats the reason: JIT tracer errors."""
return f"JIT tracer error on {info}"
def reason_flaky() -> str:
"""Formats the reason: test is flaky."""
return "flaky test"

View File

@ -25,274 +25,20 @@ Note:
from __future__ import annotations from __future__ import annotations
import copy import copy
import dataclasses
import unittest
import warnings import warnings
from typing import Any, Callable, Collection, Iterable, Optional, Sequence, Tuple, Union from typing import Optional, Tuple
import onnx_test_common import onnx_test_common
import parameterized import parameterized
import torch import torch
from onnx_test_common import fixme, skip, xfail
from torch.testing._internal import ( from torch.testing._internal import (
common_device_type, common_device_type,
common_methods_invocations, common_methods_invocations,
common_utils, common_utils,
) )
from torch.testing._internal.opinfo import core as opinfo_core
# TODO(titaiwang): Change this when more versions are supported
# The min onnx opset version to test for
MIN_ONNX_OPSET_VERSION = 18
# The max onnx opset version to test for
MAX_ONNX_OPSET_VERSION = 18
TESTED_OPSETS = range(MIN_ONNX_OPSET_VERSION, MAX_ONNX_OPSET_VERSION + 1)
BOOL_TYPES = (torch.bool,)
INT_TYPES = (
torch.int8,
torch.int16,
torch.int32,
torch.int64,
torch.uint8,
)
QINT_TYPES = (
torch.qint8,
torch.quint8,
)
FLOAT_TYPES = (
torch.float16,
torch.float32,
torch.float64,
)
COMPLEX_TYPES = (
torch.complex32,
torch.complex64,
torch.complex128,
)
TESTED_DTYPES = (
# Boolean
torch.bool,
# Integers
*INT_TYPES,
# Floating types
*FLOAT_TYPES,
)
@dataclasses.dataclass
class DecorateMeta:
"""Information about a test case to skip or xfail.
Adapted from functorch: functorch/test/common_utils.py
Attributes:
op_name: The name of the operator.
variant_name: The name of the OpInfo variant.
decorator: The decorator to apply to the test case.
opsets: The opsets to apply the decorator to.
dtypes: The dtypes to apply the decorator to.
reason: The reason for skipping.
"""
op_name: str
variant_name: str
decorator: Callable
opsets: Optional[Collection[Union[int, Callable[[int], bool]]]]
dtypes: Optional[Collection[torch.dtype]]
reason: str
matcher: Optional[Callable[[Any], Any]] = None
def contains_opset(self, opset: int) -> bool:
if self.opsets is None:
return True
return any(
opset == opset_spec if isinstance(opset_spec, int) else opset_spec(opset)
for opset_spec in self.opsets
)
def xfail(
op_name: str,
variant_name: str = "",
*,
reason: str,
opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None,
dtypes: Optional[Collection[torch.dtype]] = None,
):
"""Expects a OpInfo test to fail.
Args:
op_name: The name of the operator.
variant_name: The name of the variant.
opsets: The opsets to expect the failure. e.g. [9, 10] or [opsets_before(11)]
dtypes: The dtypes to expect the failure.
reason: The reason for the failure.
"""
return DecorateMeta(
op_name=op_name,
variant_name=variant_name,
decorator=unittest.expectedFailure,
opsets=opsets,
dtypes=dtypes,
reason=reason,
)
def skip(
op_name: str,
variant_name: str = "",
*,
reason: str,
opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None,
dtypes: Optional[Collection[torch.dtype]] = None,
matcher: Optional[Callable[[Any], Any]] = None,
):
"""Skips a test case in OpInfo that we don't care about.
Likely because ONNX does not support the use case or it is by design.
Args:
op_name: The name of the operator.
variant_name: The name of the variant.
opsets: The opsets to expect the failure. e.g. [9, 10] or [opsets_before(11)]
dtypes: The dtypes to expect the failure.
reason: The reason for the failure.
matcher: A function that matches the test sample input. It is used only when
skip is in the SKIP_SUBTESTS list.
"""
return DecorateMeta(
op_name=op_name,
variant_name=variant_name,
decorator=unittest.skip(f"Skip: {reason}"),
opsets=opsets,
dtypes=dtypes,
reason=reason,
matcher=matcher,
)
def fixme(
op_name: str,
variant_name: str = "",
*,
reason: str,
opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None,
dtypes: Optional[Collection[torch.dtype]] = None,
matcher: Optional[Callable[[Any], Any]] = None,
):
"""Skips a test case in OpInfo. It should be eventually fixed.
Args:
op_name: The name of the operator.
variant_name: The name of the variant.
opsets: The opsets to expect the failure. e.g. [9, 10] or [opsets_before(11)]
dtypes: The dtypes to expect the failure.
reason: The reason for the failure.
matcher: A function that matches the test sample input. It is used only when
fixme is in the SKIP_SUBTESTS list.
"""
return DecorateMeta(
op_name=op_name,
variant_name=variant_name,
decorator=unittest.skip(f"To fix: {reason}"),
opsets=opsets,
dtypes=dtypes,
reason=reason,
matcher=matcher,
)
def add_decorate_info(
all_opinfos: Sequence[opinfo_core.OpInfo],
test_class_name: str,
base_test_name: str,
opset: int,
skip_or_xfails: Iterable[DecorateMeta],
):
"""Decorates OpInfo tests with decorators based on the skip_or_xfails list.
Args:
all_opinfos: All OpInfos.
test_class_name: The name of the test class.
base_test_name: The name of the test method.
opset: The opset to decorate for.
skip_or_xfails: DecorateMeta's.
"""
ops_mapping = {(info.name, info.variant_test_name): info for info in all_opinfos}
for decorate_meta in skip_or_xfails:
if not decorate_meta.contains_opset(opset):
# Skip does not apply to this opset
continue
opinfo = ops_mapping.get((decorate_meta.op_name, decorate_meta.variant_name))
assert (
opinfo is not None
), f"Couldn't find OpInfo for {decorate_meta}. Did you need to specify variant_name?"
decorators = list(opinfo.decorators)
new_decorator = opinfo_core.DecorateInfo(
decorate_meta.decorator,
test_class_name,
base_test_name,
dtypes=decorate_meta.dtypes,
)
decorators.append(new_decorator)
opinfo.decorators = tuple(decorators)
# This decorator doesn't modify fn in any way
def wrapped(fn):
return fn
return wrapped
def opsets_before(opset: int) -> Callable[[int], bool]:
"""Returns a comparison function that decides if the given opset is before the specified."""
def compare(other_opset: int):
return other_opset < opset
return compare
def opsets_after(opset: int) -> Callable[[int], bool]:
"""Returns a comparison function that decides if the given opset is after the specified."""
def compare(other_opset: int):
return other_opset > opset
return compare
def reason_onnx_runtime_does_not_support(
operator: str, dtypes: Optional[Sequence[str]] = None
) -> str:
"""Formats the reason: ONNX Runtime doesn't support the given dtypes."""
return f"{operator} on {dtypes or 'dtypes'} not supported by ONNX Runtime"
def reason_onnx_does_not_support(
operator: str, dtypes: Optional[Sequence[str]] = None
) -> str:
"""Formats the reason: ONNX doesn't support the given dtypes."""
return f"{operator} on {dtypes or 'certain dtypes'} not supported by the ONNX Spec"
def reason_jit_tracer_error(info: str) -> str:
"""Formats the reason: JIT tracer errors."""
return f"JIT tracer error on {info}"
def reason_flaky() -> str:
"""Formats the reason: test is flaky."""
return "flaky test"
# Modify this section ########################################################## # Modify this section ##########################################################
# NOTE: Modify this section as more ops are supported. The list should be sorted # NOTE: Modify this section as more ops are supported. The list should be sorted
@ -321,17 +67,17 @@ TESTED_OPS: frozenset[str] = frozenset(
# e.g. the test is flaky or some tests pass. Otherwise, use xfail. # e.g. the test is flaky or some tests pass. Otherwise, use xfail.
# Use skip if we don't care about the test passing, e.g. ONNX doesn't support the usage. # Use skip if we don't care about the test passing, e.g. ONNX doesn't support the usage.
# Use xfail if a test fails now and we want to eventually fix the test. # Use xfail if a test fails now and we want to eventually fix the test.
EXPECTED_SKIPS_OR_FAILS: Tuple[DecorateMeta, ...] = ( EXPECTED_SKIPS_OR_FAILS: Tuple[onnx_test_common.DecorateMeta, ...] = (
skip( skip(
"ceil", dtypes=BOOL_TYPES + INT_TYPES, "ceil", dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
reason=reason_onnx_does_not_support("Ceil") reason=onnx_test_common.reason_onnx_does_not_support("Ceil")
), ),
fixme("ceil", dtypes=[torch.float64], reason=reason_onnx_runtime_does_not_support("Ceil", ["f64"])), fixme("ceil", dtypes=[torch.float64], reason=onnx_test_common.reason_onnx_runtime_does_not_support("Ceil", ["f64"])),
xfail("unflatten", reason="AssertionError: Expected 1 inputs, got 3 (https://github.com/pytorch/pytorch/issues/99534)"), xfail("unflatten", reason="AssertionError: Expected 1 inputs, got 3 (https://github.com/pytorch/pytorch/issues/99534)"),
) )
# fmt: on # fmt: on
SKIP_SUBTESTS: tuple[DecorateMeta, ...] = ( SKIP_SUBTESTS: tuple[onnx_test_common.DecorateMeta, ...] = (
fixme( fixme(
"unflatten", "unflatten",
reason="Logic not implemented for size 0 inputs in op.Reshape", reason="Logic not implemented for size 0 inputs in op.Reshape",
@ -386,7 +132,7 @@ def _get_test_class_name(cls, num, params_dict) -> str:
"name": f"TestOnnxModelOutputConsistency_opset{opset}", "name": f"TestOnnxModelOutputConsistency_opset{opset}",
"opset_version": opset, "opset_version": opset,
} }
for opset in TESTED_OPSETS for opset in onnx_test_common.FX_TESTED_OPSETS
], ],
class_name_func=_get_test_class_name, class_name_func=_get_test_class_name,
) )
@ -402,7 +148,7 @@ class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime):
@common_device_type.ops( @common_device_type.ops(
[op for op in OPS_DB if op.name in TESTED_OPS], [op for op in OPS_DB if op.name in TESTED_OPS],
allowed_dtypes=TESTED_DTYPES, allowed_dtypes=onnx_test_common.TESTED_DTYPES,
) )
def test_output_match(self, device: str, dtype: torch.dtype, op): def test_output_match(self, device: str, dtype: torch.dtype, op):
"""Test the ONNX exporter.""" """Test the ONNX exporter."""
@ -452,10 +198,10 @@ class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime):
) )
for opset in TESTED_OPSETS: for opset in onnx_test_common.FX_TESTED_OPSETS:
# The name needs to match the parameterized_class name. # The name needs to match the parameterized_class name.
test_class_name = f"TestOnnxModelOutputConsistency_opset{opset}" test_class_name = f"TestOnnxModelOutputConsistency_opset{opset}"
add_decorate_info( onnx_test_common.add_decorate_info(
OPS_DB, OPS_DB,
test_class_name, test_class_name,
"test_output_match", "test_output_match",

View File

@ -25,272 +25,21 @@ Note:
from __future__ import annotations from __future__ import annotations
import copy import copy
import dataclasses
import unittest
import warnings import warnings
from typing import Any, Callable, Collection, Iterable, Optional, Sequence, Tuple, Union from typing import Optional, Tuple
import onnx_test_common import onnx_test_common
import parameterized import parameterized
import torch import torch
from torch.onnx import _constants
# For readability, these two are allowed to be imported as function
from onnx_test_common import fixme, skip
from torch.testing._internal import ( from torch.testing._internal import (
common_device_type, common_device_type,
common_methods_invocations, common_methods_invocations,
common_utils, common_utils,
) )
from torch.testing._internal.opinfo import core as opinfo_core
# The min onnx opset version to test for
MIN_ONNX_OPSET_VERSION = 9
# The max onnx opset version to test for
MAX_ONNX_OPSET_VERSION = _constants.ONNX_MAX_OPSET
TESTED_OPSETS = range(MIN_ONNX_OPSET_VERSION, MAX_ONNX_OPSET_VERSION + 1)
BOOL_TYPES = (torch.bool,)
INT_TYPES = (
torch.int8,
torch.int16,
torch.int32,
torch.int64,
torch.uint8,
)
QINT_TYPES = (
torch.qint8,
torch.quint8,
)
FLOAT_TYPES = (
torch.float16,
torch.float32,
torch.float64,
)
COMPLEX_TYPES = (
torch.complex32,
torch.complex64,
torch.complex128,
)
TESTED_DTYPES = (
# Boolean
torch.bool,
# Integers
*INT_TYPES,
# Floating types
*FLOAT_TYPES,
)
@dataclasses.dataclass
class DecorateMeta:
"""Information about a test case to skip or xfail.
Adapted from functorch: functorch/test/common_utils.py
Attributes:
op_name: The name of the operator.
variant_name: The name of the OpInfo variant.
decorator: The decorator to apply to the test case.
opsets: The opsets to apply the decorator to.
dtypes: The dtypes to apply the decorator to.
reason: The reason for skipping.
"""
op_name: str
variant_name: str
decorator: Callable
opsets: Optional[Collection[Union[int, Callable[[int], bool]]]]
dtypes: Optional[Collection[torch.dtype]]
reason: str
matcher: Optional[Callable[[Any], Any]] = None
def contains_opset(self, opset: int) -> bool:
if self.opsets is None:
return True
return any(
opset == opset_spec if isinstance(opset_spec, int) else opset_spec(opset)
for opset_spec in self.opsets
)
def xfail(
op_name: str,
variant_name: str = "",
*,
reason: str,
opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None,
dtypes: Optional[Collection[torch.dtype]] = None,
):
"""Expects a OpInfo test to fail.
Args:
op_name: The name of the operator.
variant_name: The name of the variant.
opsets: The opsets to expect the failure. e.g. [9, 10] or [opsets_before(11)]
dtypes: The dtypes to expect the failure.
reason: The reason for the failure.
"""
return DecorateMeta(
op_name=op_name,
variant_name=variant_name,
decorator=unittest.expectedFailure,
opsets=opsets,
dtypes=dtypes,
reason=reason,
)
def skip(
op_name: str,
variant_name: str = "",
*,
reason: str,
opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None,
dtypes: Optional[Collection[torch.dtype]] = None,
matcher: Optional[Callable[[Any], Any]] = None,
):
"""Skips a test case in OpInfo that we don't care about.
Likely because ONNX does not support the use case or it is by design.
Args:
op_name: The name of the operator.
variant_name: The name of the variant.
opsets: The opsets to expect the failure. e.g. [9, 10] or [opsets_before(11)]
dtypes: The dtypes to expect the failure.
reason: The reason for the failure.
matcher: A function that matches the test sample input. It is used only when
skip is in the SKIP_SUBTESTS list.
"""
return DecorateMeta(
op_name=op_name,
variant_name=variant_name,
decorator=unittest.skip(f"Skip: {reason}"),
opsets=opsets,
dtypes=dtypes,
reason=reason,
matcher=matcher,
)
def fixme(
op_name: str,
variant_name: str = "",
*,
reason: str,
opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None,
dtypes: Optional[Collection[torch.dtype]] = None,
matcher: Optional[Callable[[Any], Any]] = None,
):
"""Skips a test case in OpInfo. It should be eventually fixed.
Args:
op_name: The name of the operator.
variant_name: The name of the variant.
opsets: The opsets to expect the failure. e.g. [9, 10] or [opsets_before(11)]
dtypes: The dtypes to expect the failure.
reason: The reason for the failure.
matcher: A function that matches the test sample input. It is used only when
fixme is in the SKIP_SUBTESTS list.
"""
return DecorateMeta(
op_name=op_name,
variant_name=variant_name,
decorator=unittest.skip(f"To fix: {reason}"),
opsets=opsets,
dtypes=dtypes,
reason=reason,
matcher=matcher,
)
def add_decorate_info(
all_opinfos: Sequence[opinfo_core.OpInfo],
test_class_name: str,
base_test_name: str,
opset: int,
skip_or_xfails: Iterable[DecorateMeta],
):
"""Decorates OpInfo tests with decorators based on the skip_or_xfails list.
Args:
all_opinfos: All OpInfos.
test_class_name: The name of the test class.
base_test_name: The name of the test method.
opset: The opset to decorate for.
skip_or_xfails: DecorateMeta's.
"""
ops_mapping = {(info.name, info.variant_test_name): info for info in all_opinfos}
for decorate_meta in skip_or_xfails:
if not decorate_meta.contains_opset(opset):
# Skip does not apply to this opset
continue
opinfo = ops_mapping.get((decorate_meta.op_name, decorate_meta.variant_name))
assert (
opinfo is not None
), f"Couldn't find OpInfo for {decorate_meta}. Did you need to specify variant_name?"
decorators = list(opinfo.decorators)
new_decorator = opinfo_core.DecorateInfo(
decorate_meta.decorator,
test_class_name,
base_test_name,
dtypes=decorate_meta.dtypes,
)
decorators.append(new_decorator)
opinfo.decorators = tuple(decorators)
# This decorator doesn't modify fn in any way
def wrapped(fn):
return fn
return wrapped
def opsets_before(opset: int) -> Callable[[int], bool]:
"""Returns a comparison function that decides if the given opset is before the specified."""
def compare(other_opset: int):
return other_opset < opset
return compare
def opsets_after(opset: int) -> Callable[[int], bool]:
"""Returns a comparison function that decides if the given opset is after the specified."""
def compare(other_opset: int):
return other_opset > opset
return compare
def reason_onnx_runtime_does_not_support(
operator: str, dtypes: Optional[Sequence[str]] = None
) -> str:
"""Formats the reason: ONNX Runtime doesn't support the given dtypes."""
return f"{operator} on {dtypes or 'dtypes'} not supported by ONNX Runtime"
def reason_onnx_does_not_support(
operator: str, dtypes: Optional[Sequence[str]] = None
) -> str:
"""Formats the reason: ONNX doesn't support the given dtypes."""
return f"{operator} on {dtypes or 'certain dtypes'} not supported by the ONNX Spec"
def reason_jit_tracer_error(info: str) -> str:
"""Formats the reason: JIT tracer errors."""
return f"JIT tracer error on {info}"
def reason_flaky() -> str:
"""Formats the reason: test is flaky."""
return "flaky test"
# Modify this section ########################################################## # Modify this section ##########################################################
@ -330,32 +79,32 @@ TESTED_OPS: frozenset[str] = frozenset(
# e.g. the test is flaky or some tests pass. Otherwise, use xfail. # e.g. the test is flaky or some tests pass. Otherwise, use xfail.
# Use skip if we don't care about the test passing, e.g. ONNX doesn't support the usage. # Use skip if we don't care about the test passing, e.g. ONNX doesn't support the usage.
# Use xfail if a test fails now and we want to eventually fix the test. # Use xfail if a test fails now and we want to eventually fix the test.
EXPECTED_SKIPS_OR_FAILS: Tuple[DecorateMeta, ...] = ( EXPECTED_SKIPS_OR_FAILS: Tuple[onnx_test_common.DecorateMeta, ...] = (
skip( skip(
"atan", dtypes=BOOL_TYPES + INT_TYPES, "atan", dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
reason=reason_onnx_does_not_support("Atan") reason=onnx_test_common.reason_onnx_does_not_support("Atan")
), ),
fixme("atan", dtypes=[torch.float64], reason=reason_onnx_runtime_does_not_support("Atan", ["f64"])), fixme("atan", dtypes=[torch.float64], reason=onnx_test_common.reason_onnx_runtime_does_not_support("Atan", ["f64"])),
skip( skip(
"atan2", dtypes=BOOL_TYPES + INT_TYPES, "atan2", dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
reason=reason_onnx_does_not_support("Atan") reason=onnx_test_common.reason_onnx_does_not_support("Atan")
), ),
fixme("atan2", dtypes=[torch.float64], reason=reason_onnx_runtime_does_not_support("Atan", ["f64"])), fixme("atan2", dtypes=[torch.float64], reason=onnx_test_common.reason_onnx_runtime_does_not_support("Atan", ["f64"])),
skip( skip(
"ceil", dtypes=BOOL_TYPES + INT_TYPES, "ceil", dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
reason=reason_onnx_does_not_support("Ceil") reason=onnx_test_common.reason_onnx_does_not_support("Ceil")
), ),
fixme("ceil", dtypes=[torch.float64], reason=reason_onnx_runtime_does_not_support("Ceil", ["f64"])), fixme("ceil", dtypes=[torch.float64], reason=onnx_test_common.reason_onnx_runtime_does_not_support("Ceil", ["f64"])),
skip("nn.functional.scaled_dot_product_attention", opsets=[opsets_before(14)], reason="Need Trilu."), skip("nn.functional.scaled_dot_product_attention", opsets=[onnx_test_common.opsets_before(14)], reason="Need Trilu."),
fixme("nn.functional.scaled_dot_product_attention", reason="fixme: ORT crashes on Windows, segfaults randomly on Linux"), fixme("nn.functional.scaled_dot_product_attention", reason="fixme: ORT crashes on Windows, segfaults randomly on Linux"),
skip("sqrt", dtypes=BOOL_TYPES, reason=reason_onnx_does_not_support("Sqrt")), skip("sqrt", dtypes=onnx_test_common.BOOL_TYPES, reason=onnx_test_common.reason_onnx_does_not_support("Sqrt")),
skip("stft", opsets=[opsets_before(17)], reason=reason_onnx_does_not_support("STFT")), skip("stft", opsets=[onnx_test_common.opsets_before(17)], reason=onnx_test_common.reason_onnx_does_not_support("STFT")),
skip("tile", opsets=[opsets_before(13)], reason=reason_onnx_does_not_support("Tile")), skip("tile", opsets=[onnx_test_common.opsets_before(13)], reason=onnx_test_common.reason_onnx_does_not_support("Tile")),
fixme("unflatten", opsets=[opsets_before(13)], reason="Helper function is needed to support legacy ops."), fixme("unflatten", opsets=[onnx_test_common.opsets_before(13)], reason="Helper function is needed to support legacy ops."),
) )
# fmt: on # fmt: on
SKIP_SUBTESTS: tuple[DecorateMeta, ...] = ( SKIP_SUBTESTS: tuple[onnx_test_common.DecorateMeta, ...] = (
skip( skip(
"nn.functional.scaled_dot_product_attention", "nn.functional.scaled_dot_product_attention",
matcher=lambda sample: sample.kwargs.get("dropout_p") != 0.0, matcher=lambda sample: sample.kwargs.get("dropout_p") != 0.0,
@ -431,7 +180,7 @@ def _get_test_class_name(cls, num, params_dict) -> str:
"name": f"TestOnnxModelOutputConsistency_opset{opset}", "name": f"TestOnnxModelOutputConsistency_opset{opset}",
"opset_version": opset, "opset_version": opset,
} }
for opset in TESTED_OPSETS for opset in onnx_test_common.TESTED_OPSETS
], ],
class_name_func=_get_test_class_name, class_name_func=_get_test_class_name,
) )
@ -445,7 +194,7 @@ class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime):
@common_device_type.ops( @common_device_type.ops(
[op for op in OPS_DB if op.name in TESTED_OPS], [op for op in OPS_DB if op.name in TESTED_OPS],
allowed_dtypes=TESTED_DTYPES, allowed_dtypes=onnx_test_common.TESTED_DTYPES,
) )
def test_output_match(self, device: str, dtype: torch.dtype, op): def test_output_match(self, device: str, dtype: torch.dtype, op):
"""Test the ONNX exporter.""" """Test the ONNX exporter."""
@ -491,10 +240,10 @@ class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime):
self.run_test(model, inputs, rtol=rtol, atol=atol) self.run_test(model, inputs, rtol=rtol, atol=atol)
for opset in TESTED_OPSETS: for opset in onnx_test_common.TESTED_OPSETS:
# The name needs to match the parameterized_class name. # The name needs to match the parameterized_class name.
test_class_name = f"TestOnnxModelOutputConsistency_opset{opset}" test_class_name = f"TestOnnxModelOutputConsistency_opset{opset}"
add_decorate_info( onnx_test_common.add_decorate_info(
OPS_DB, OPS_DB,
test_class_name, test_class_name,
"test_output_match", "test_output_match",