mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[ONNX] Refactor test_op_consistenct.py and test_fx_op_consistency.py (#100172)
## Summary <!-- copilot:summary --> ### <samp>🤖 Generated by Copilot at 9255aa3</samp> This pull request refactors the ONNX operator testing code to use a common module `./test/onnx/onnx_test_common.py` that defines constants, types, classes, and functions for testing ONNX operators. This improves the code quality, readability, and maintainability. ## Walkthrough <!-- copilot:walkthrough --> ### <samp>🤖 Generated by Copilot at 9255aa3</samp> * Refactor the common code for testing ONNX operators from different files into `./test/onnx/onnx_test_common.py` ([link](https://github.com/pytorch/pytorch/pull/100172/files?diff=unified&w=0#diff-1b38383dc1a0228a835d83bb7c4ba2d0c1bcd41297be5c6336572c525846166eL10-R24), [link](https://github.com/pytorch/pytorch/pull/100172/files?diff=unified&w=0#diff-1b38383dc1a0228a835d83bb7c4ba2d0c1bcd41297be5c6336572c525846166eR33), [link](https://github.com/pytorch/pytorch/pull/100172/files?diff=unified&w=0#diff-1b38383dc1a0228a835d83bb7c4ba2d0c1bcd41297be5c6336572c525846166eR367-R623)) * Remove the unused and duplicated imports, constants, types, and classes for testing ONNX operators from `./test/onnx/test_fx_op_consistency.py` and `./test/onnx/test_op_consistency.py` ([link](https://github.com/pytorch/pytorch/pull/100172/files?diff=unified&w=0#diff-db2f78a51511bb172cbfde1b2f68272b8b33049abe2571cded27bcd0f3ae5fa4L28-R29), [link](https://github.com/pytorch/pytorch/pull/100172/files?diff=unified&w=0#diff-db2f78a51511bb172cbfde1b2f68272b8b33049abe2571cded27bcd0f3ae5fa4L43-R42), [link](https://github.com/pytorch/pytorch/pull/100172/files?diff=unified&w=0#diff-e968c9cb6fc6631cab526cb3a9fe66358c4c6e757e2a223a224b976471bcb753L28-R29), [link](https://github.com/pytorch/pytorch/pull/100172/files?diff=unified&w=0#diff-e968c9cb6fc6631cab526cb3a9fe66358c4c6e757e2a223a224b976471bcb753L43-R44)) * Import the `unittest`, `opinfo_core`, and `onnx_test_common` modules and the `fixme`, `skip`, and `xfail` functions in `./test/onnx/test_fx_op_consistency.py` and `./test/onnx/test_op_consistency.py` ( [link](https://github.com/pytorch/pytorch/pull/100172/files?diff=unified&w=0#diff-db2f78a51511bb172cbfde1b2f68272b8b33049abe2571cded27bcd0f3ae5fa4R36), [link](https://github.com/pytorch/pytorch/pull/100172/files?diff=unified&w=0#diff-e968c9cb6fc6631cab526cb3a9fe66358c4c6e757e2a223a224b976471bcb753L37-R37)) * Update the references to the constants, types, functions, and classes for testing ONNX operators in `./test/onnx/test_fx_op_consistency.py` and `./test/onnx/test_op_consistency.py` to use the definitions from `./test/onnx/onnx_test_common.py` ([link](https://github.com/pytorch/pytorch/pull/100172/files?diff=unified&w=0#diff-db2f78a51511bb172cbfde1b2f68272b8b33049abe2571cded27bcd0f3ae5fa4L324-R80), [link](https://github.com/pytorch/pytorch/pull/100172/files?diff=unified&w=0#diff-db2f78a51511bb172cbfde1b2f68272b8b33049abe2571cded27bcd0f3ae5fa4L389-R135), [link](https://github.com/pytorch/pytorch/pull/100172/files?diff=unified&w=0#diff-db2f78a51511bb172cbfde1b2f68272b8b33049abe2571cded27bcd0f3ae5fa4L405-R151), [link](https://github.com/pytorch/pytorch/pull/100172/files?diff=unified&w=0#diff-db2f78a51511bb172cbfde1b2f68272b8b33049abe2571cded27bcd0f3ae5fa4L455-R204), [link](https://github.com/pytorch/pytorch/pull/100172/files?diff=unified&w=0#diff-e968c9cb6fc6631cab526cb3a9fe66358c4c6e757e2a223a224b976471bcb753L333-R107), [link](https://github.com/pytorch/pytorch/pull/100172/files?diff=unified&w=0#diff-e968c9cb6fc6631cab526cb3a9fe66358c4c6e757e2a223a224b976471bcb753L434-R183), [link](https://github.com/pytorch/pytorch/pull/100172/files?diff=unified&w=0#diff-e968c9cb6fc6631cab526cb3a9fe66358c4c6e757e2a223a224b976471bcb753L448-R197), [link](https://github.com/pytorch/pytorch/pull/100172/files?diff=unified&w=0#diff-e968c9cb6fc6631cab526cb3a9fe66358c4c6e757e2a223a224b976471bcb753L494-R246)) Pull Request resolved: https://github.com/pytorch/pytorch/pull/100172 Approved by: https://github.com/justinchuby
This commit is contained in:
committed by
PyTorch MergeBot
parent
61917a006d
commit
1dba53cbab
@ -7,8 +7,21 @@ import copy
|
|||||||
import dataclasses
|
import dataclasses
|
||||||
import io
|
import io
|
||||||
import os
|
import os
|
||||||
|
import unittest
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Any, Callable, List, Mapping, Optional, Sequence, Tuple, Type, Union
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
Collection,
|
||||||
|
Iterable,
|
||||||
|
List,
|
||||||
|
Mapping,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Tuple,
|
||||||
|
Type,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@ -17,6 +30,7 @@ import pytorch_test_common
|
|||||||
import torch
|
import torch
|
||||||
from torch.onnx import _constants, verification
|
from torch.onnx import _constants, verification
|
||||||
from torch.onnx._internal import _beartype
|
from torch.onnx._internal import _beartype
|
||||||
|
from torch.testing._internal.opinfo import core as opinfo_core
|
||||||
from torch.types import Number
|
from torch.types import Number
|
||||||
|
|
||||||
_NumericType = Union[Number, torch.Tensor, np.ndarray]
|
_NumericType = Union[Number, torch.Tensor, np.ndarray]
|
||||||
@ -350,3 +364,260 @@ def _compare_pytorch_onnx_with_ort(
|
|||||||
torch.testing.assert_close(
|
torch.testing.assert_close(
|
||||||
ref_output, torch.tensor(ort_output), rtol=rtol, atol=atol
|
ref_output, torch.tensor(ort_output), rtol=rtol, atol=atol
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# The min onnx opset version to test for
|
||||||
|
MIN_ONNX_OPSET_VERSION = 9
|
||||||
|
# The max onnx opset version to test for
|
||||||
|
MAX_ONNX_OPSET_VERSION = _constants.ONNX_MAX_OPSET
|
||||||
|
TESTED_OPSETS = range(MIN_ONNX_OPSET_VERSION, MAX_ONNX_OPSET_VERSION + 1)
|
||||||
|
|
||||||
|
# TODO(titaiwang): Change this when more versions are supported
|
||||||
|
# The min onnx opset version to test for
|
||||||
|
FX_MIN_ONNX_OPSET_VERSION = 18
|
||||||
|
# The max onnx opset version to test for
|
||||||
|
FX_MAX_ONNX_OPSET_VERSION = 18
|
||||||
|
FX_TESTED_OPSETS = range(FX_MIN_ONNX_OPSET_VERSION, FX_MAX_ONNX_OPSET_VERSION + 1)
|
||||||
|
|
||||||
|
BOOL_TYPES = (torch.bool,)
|
||||||
|
|
||||||
|
INT_TYPES = (
|
||||||
|
torch.int8,
|
||||||
|
torch.int16,
|
||||||
|
torch.int32,
|
||||||
|
torch.int64,
|
||||||
|
torch.uint8,
|
||||||
|
)
|
||||||
|
|
||||||
|
QINT_TYPES = (
|
||||||
|
torch.qint8,
|
||||||
|
torch.quint8,
|
||||||
|
)
|
||||||
|
|
||||||
|
FLOAT_TYPES = (
|
||||||
|
torch.float16,
|
||||||
|
torch.float32,
|
||||||
|
torch.float64,
|
||||||
|
)
|
||||||
|
|
||||||
|
COMPLEX_TYPES = (
|
||||||
|
torch.complex32,
|
||||||
|
torch.complex64,
|
||||||
|
torch.complex128,
|
||||||
|
)
|
||||||
|
|
||||||
|
TESTED_DTYPES = (
|
||||||
|
# Boolean
|
||||||
|
torch.bool,
|
||||||
|
# Integers
|
||||||
|
*INT_TYPES,
|
||||||
|
# Floating types
|
||||||
|
*FLOAT_TYPES,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class DecorateMeta:
|
||||||
|
"""Information about a test case to skip or xfail.
|
||||||
|
|
||||||
|
Adapted from functorch: functorch/test/common_utils.py
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
op_name: The name of the operator.
|
||||||
|
variant_name: The name of the OpInfo variant.
|
||||||
|
decorator: The decorator to apply to the test case.
|
||||||
|
opsets: The opsets to apply the decorator to.
|
||||||
|
dtypes: The dtypes to apply the decorator to.
|
||||||
|
reason: The reason for skipping.
|
||||||
|
"""
|
||||||
|
|
||||||
|
op_name: str
|
||||||
|
variant_name: str
|
||||||
|
decorator: Callable
|
||||||
|
opsets: Optional[Collection[Union[int, Callable[[int], bool]]]]
|
||||||
|
dtypes: Optional[Collection[torch.dtype]]
|
||||||
|
reason: str
|
||||||
|
matcher: Optional[Callable[[Any], Any]] = None
|
||||||
|
|
||||||
|
def contains_opset(self, opset: int) -> bool:
|
||||||
|
if self.opsets is None:
|
||||||
|
return True
|
||||||
|
return any(
|
||||||
|
opset == opset_spec if isinstance(opset_spec, int) else opset_spec(opset)
|
||||||
|
for opset_spec in self.opsets
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def xfail(
|
||||||
|
op_name: str,
|
||||||
|
variant_name: str = "",
|
||||||
|
*,
|
||||||
|
reason: str,
|
||||||
|
opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None,
|
||||||
|
dtypes: Optional[Collection[torch.dtype]] = None,
|
||||||
|
):
|
||||||
|
"""Expects a OpInfo test to fail.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
op_name: The name of the operator.
|
||||||
|
variant_name: The name of the variant.
|
||||||
|
opsets: The opsets to expect the failure. e.g. [9, 10] or [opsets_before(11)]
|
||||||
|
dtypes: The dtypes to expect the failure.
|
||||||
|
reason: The reason for the failure.
|
||||||
|
"""
|
||||||
|
return DecorateMeta(
|
||||||
|
op_name=op_name,
|
||||||
|
variant_name=variant_name,
|
||||||
|
decorator=unittest.expectedFailure,
|
||||||
|
opsets=opsets,
|
||||||
|
dtypes=dtypes,
|
||||||
|
reason=reason,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def skip(
|
||||||
|
op_name: str,
|
||||||
|
variant_name: str = "",
|
||||||
|
*,
|
||||||
|
reason: str,
|
||||||
|
opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None,
|
||||||
|
dtypes: Optional[Collection[torch.dtype]] = None,
|
||||||
|
matcher: Optional[Callable[[Any], Any]] = None,
|
||||||
|
):
|
||||||
|
"""Skips a test case in OpInfo that we don't care about.
|
||||||
|
|
||||||
|
Likely because ONNX does not support the use case or it is by design.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
op_name: The name of the operator.
|
||||||
|
variant_name: The name of the variant.
|
||||||
|
opsets: The opsets to expect the failure. e.g. [9, 10] or [opsets_before(11)]
|
||||||
|
dtypes: The dtypes to expect the failure.
|
||||||
|
reason: The reason for the failure.
|
||||||
|
matcher: A function that matches the test sample input. It is used only when
|
||||||
|
skip is in the SKIP_SUBTESTS list.
|
||||||
|
"""
|
||||||
|
return DecorateMeta(
|
||||||
|
op_name=op_name,
|
||||||
|
variant_name=variant_name,
|
||||||
|
decorator=unittest.skip(f"Skip: {reason}"),
|
||||||
|
opsets=opsets,
|
||||||
|
dtypes=dtypes,
|
||||||
|
reason=reason,
|
||||||
|
matcher=matcher,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def fixme(
|
||||||
|
op_name: str,
|
||||||
|
variant_name: str = "",
|
||||||
|
*,
|
||||||
|
reason: str,
|
||||||
|
opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None,
|
||||||
|
dtypes: Optional[Collection[torch.dtype]] = None,
|
||||||
|
matcher: Optional[Callable[[Any], Any]] = None,
|
||||||
|
):
|
||||||
|
"""Skips a test case in OpInfo. It should be eventually fixed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
op_name: The name of the operator.
|
||||||
|
variant_name: The name of the variant.
|
||||||
|
opsets: The opsets to expect the failure. e.g. [9, 10] or [opsets_before(11)]
|
||||||
|
dtypes: The dtypes to expect the failure.
|
||||||
|
reason: The reason for the failure.
|
||||||
|
matcher: A function that matches the test sample input. It is used only when
|
||||||
|
fixme is in the SKIP_SUBTESTS list.
|
||||||
|
"""
|
||||||
|
return DecorateMeta(
|
||||||
|
op_name=op_name,
|
||||||
|
variant_name=variant_name,
|
||||||
|
decorator=unittest.skip(f"To fix: {reason}"),
|
||||||
|
opsets=opsets,
|
||||||
|
dtypes=dtypes,
|
||||||
|
reason=reason,
|
||||||
|
matcher=matcher,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def add_decorate_info(
|
||||||
|
all_opinfos: Sequence[opinfo_core.OpInfo],
|
||||||
|
test_class_name: str,
|
||||||
|
base_test_name: str,
|
||||||
|
opset: int,
|
||||||
|
skip_or_xfails: Iterable[DecorateMeta],
|
||||||
|
):
|
||||||
|
"""Decorates OpInfo tests with decorators based on the skip_or_xfails list.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
all_opinfos: All OpInfos.
|
||||||
|
test_class_name: The name of the test class.
|
||||||
|
base_test_name: The name of the test method.
|
||||||
|
opset: The opset to decorate for.
|
||||||
|
skip_or_xfails: DecorateMeta's.
|
||||||
|
"""
|
||||||
|
ops_mapping = {(info.name, info.variant_test_name): info for info in all_opinfos}
|
||||||
|
for decorate_meta in skip_or_xfails:
|
||||||
|
if not decorate_meta.contains_opset(opset):
|
||||||
|
# Skip does not apply to this opset
|
||||||
|
continue
|
||||||
|
opinfo = ops_mapping.get((decorate_meta.op_name, decorate_meta.variant_name))
|
||||||
|
assert (
|
||||||
|
opinfo is not None
|
||||||
|
), f"Couldn't find OpInfo for {decorate_meta}. Did you need to specify variant_name?"
|
||||||
|
decorators = list(opinfo.decorators)
|
||||||
|
new_decorator = opinfo_core.DecorateInfo(
|
||||||
|
decorate_meta.decorator,
|
||||||
|
test_class_name,
|
||||||
|
base_test_name,
|
||||||
|
dtypes=decorate_meta.dtypes,
|
||||||
|
)
|
||||||
|
decorators.append(new_decorator)
|
||||||
|
opinfo.decorators = tuple(decorators)
|
||||||
|
|
||||||
|
# This decorator doesn't modify fn in any way
|
||||||
|
def wrapped(fn):
|
||||||
|
return fn
|
||||||
|
|
||||||
|
return wrapped
|
||||||
|
|
||||||
|
|
||||||
|
def opsets_before(opset: int) -> Callable[[int], bool]:
|
||||||
|
"""Returns a comparison function that decides if the given opset is before the specified."""
|
||||||
|
|
||||||
|
def compare(other_opset: int):
|
||||||
|
return other_opset < opset
|
||||||
|
|
||||||
|
return compare
|
||||||
|
|
||||||
|
|
||||||
|
def opsets_after(opset: int) -> Callable[[int], bool]:
|
||||||
|
"""Returns a comparison function that decides if the given opset is after the specified."""
|
||||||
|
|
||||||
|
def compare(other_opset: int):
|
||||||
|
return other_opset > opset
|
||||||
|
|
||||||
|
return compare
|
||||||
|
|
||||||
|
|
||||||
|
def reason_onnx_runtime_does_not_support(
|
||||||
|
operator: str, dtypes: Optional[Sequence[str]] = None
|
||||||
|
) -> str:
|
||||||
|
"""Formats the reason: ONNX Runtime doesn't support the given dtypes."""
|
||||||
|
return f"{operator} on {dtypes or 'dtypes'} not supported by ONNX Runtime"
|
||||||
|
|
||||||
|
|
||||||
|
def reason_onnx_does_not_support(
|
||||||
|
operator: str, dtypes: Optional[Sequence[str]] = None
|
||||||
|
) -> str:
|
||||||
|
"""Formats the reason: ONNX doesn't support the given dtypes."""
|
||||||
|
return f"{operator} on {dtypes or 'certain dtypes'} not supported by the ONNX Spec"
|
||||||
|
|
||||||
|
|
||||||
|
def reason_jit_tracer_error(info: str) -> str:
|
||||||
|
"""Formats the reason: JIT tracer errors."""
|
||||||
|
return f"JIT tracer error on {info}"
|
||||||
|
|
||||||
|
|
||||||
|
def reason_flaky() -> str:
|
||||||
|
"""Formats the reason: test is flaky."""
|
||||||
|
return "flaky test"
|
||||||
|
@ -25,274 +25,20 @@ Note:
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
import dataclasses
|
|
||||||
import unittest
|
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Any, Callable, Collection, Iterable, Optional, Sequence, Tuple, Union
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import onnx_test_common
|
import onnx_test_common
|
||||||
|
|
||||||
import parameterized
|
import parameterized
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from onnx_test_common import fixme, skip, xfail
|
||||||
from torch.testing._internal import (
|
from torch.testing._internal import (
|
||||||
common_device_type,
|
common_device_type,
|
||||||
common_methods_invocations,
|
common_methods_invocations,
|
||||||
common_utils,
|
common_utils,
|
||||||
)
|
)
|
||||||
from torch.testing._internal.opinfo import core as opinfo_core
|
|
||||||
|
|
||||||
# TODO(titaiwang): Change this when more versions are supported
|
|
||||||
# The min onnx opset version to test for
|
|
||||||
MIN_ONNX_OPSET_VERSION = 18
|
|
||||||
# The max onnx opset version to test for
|
|
||||||
MAX_ONNX_OPSET_VERSION = 18
|
|
||||||
|
|
||||||
TESTED_OPSETS = range(MIN_ONNX_OPSET_VERSION, MAX_ONNX_OPSET_VERSION + 1)
|
|
||||||
|
|
||||||
BOOL_TYPES = (torch.bool,)
|
|
||||||
|
|
||||||
INT_TYPES = (
|
|
||||||
torch.int8,
|
|
||||||
torch.int16,
|
|
||||||
torch.int32,
|
|
||||||
torch.int64,
|
|
||||||
torch.uint8,
|
|
||||||
)
|
|
||||||
|
|
||||||
QINT_TYPES = (
|
|
||||||
torch.qint8,
|
|
||||||
torch.quint8,
|
|
||||||
)
|
|
||||||
|
|
||||||
FLOAT_TYPES = (
|
|
||||||
torch.float16,
|
|
||||||
torch.float32,
|
|
||||||
torch.float64,
|
|
||||||
)
|
|
||||||
|
|
||||||
COMPLEX_TYPES = (
|
|
||||||
torch.complex32,
|
|
||||||
torch.complex64,
|
|
||||||
torch.complex128,
|
|
||||||
)
|
|
||||||
|
|
||||||
TESTED_DTYPES = (
|
|
||||||
# Boolean
|
|
||||||
torch.bool,
|
|
||||||
# Integers
|
|
||||||
*INT_TYPES,
|
|
||||||
# Floating types
|
|
||||||
*FLOAT_TYPES,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
|
||||||
class DecorateMeta:
|
|
||||||
"""Information about a test case to skip or xfail.
|
|
||||||
|
|
||||||
Adapted from functorch: functorch/test/common_utils.py
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
op_name: The name of the operator.
|
|
||||||
variant_name: The name of the OpInfo variant.
|
|
||||||
decorator: The decorator to apply to the test case.
|
|
||||||
opsets: The opsets to apply the decorator to.
|
|
||||||
dtypes: The dtypes to apply the decorator to.
|
|
||||||
reason: The reason for skipping.
|
|
||||||
"""
|
|
||||||
|
|
||||||
op_name: str
|
|
||||||
variant_name: str
|
|
||||||
decorator: Callable
|
|
||||||
opsets: Optional[Collection[Union[int, Callable[[int], bool]]]]
|
|
||||||
dtypes: Optional[Collection[torch.dtype]]
|
|
||||||
reason: str
|
|
||||||
matcher: Optional[Callable[[Any], Any]] = None
|
|
||||||
|
|
||||||
def contains_opset(self, opset: int) -> bool:
|
|
||||||
if self.opsets is None:
|
|
||||||
return True
|
|
||||||
return any(
|
|
||||||
opset == opset_spec if isinstance(opset_spec, int) else opset_spec(opset)
|
|
||||||
for opset_spec in self.opsets
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def xfail(
|
|
||||||
op_name: str,
|
|
||||||
variant_name: str = "",
|
|
||||||
*,
|
|
||||||
reason: str,
|
|
||||||
opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None,
|
|
||||||
dtypes: Optional[Collection[torch.dtype]] = None,
|
|
||||||
):
|
|
||||||
"""Expects a OpInfo test to fail.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
op_name: The name of the operator.
|
|
||||||
variant_name: The name of the variant.
|
|
||||||
opsets: The opsets to expect the failure. e.g. [9, 10] or [opsets_before(11)]
|
|
||||||
dtypes: The dtypes to expect the failure.
|
|
||||||
reason: The reason for the failure.
|
|
||||||
"""
|
|
||||||
return DecorateMeta(
|
|
||||||
op_name=op_name,
|
|
||||||
variant_name=variant_name,
|
|
||||||
decorator=unittest.expectedFailure,
|
|
||||||
opsets=opsets,
|
|
||||||
dtypes=dtypes,
|
|
||||||
reason=reason,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def skip(
|
|
||||||
op_name: str,
|
|
||||||
variant_name: str = "",
|
|
||||||
*,
|
|
||||||
reason: str,
|
|
||||||
opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None,
|
|
||||||
dtypes: Optional[Collection[torch.dtype]] = None,
|
|
||||||
matcher: Optional[Callable[[Any], Any]] = None,
|
|
||||||
):
|
|
||||||
"""Skips a test case in OpInfo that we don't care about.
|
|
||||||
|
|
||||||
Likely because ONNX does not support the use case or it is by design.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
op_name: The name of the operator.
|
|
||||||
variant_name: The name of the variant.
|
|
||||||
opsets: The opsets to expect the failure. e.g. [9, 10] or [opsets_before(11)]
|
|
||||||
dtypes: The dtypes to expect the failure.
|
|
||||||
reason: The reason for the failure.
|
|
||||||
matcher: A function that matches the test sample input. It is used only when
|
|
||||||
skip is in the SKIP_SUBTESTS list.
|
|
||||||
"""
|
|
||||||
return DecorateMeta(
|
|
||||||
op_name=op_name,
|
|
||||||
variant_name=variant_name,
|
|
||||||
decorator=unittest.skip(f"Skip: {reason}"),
|
|
||||||
opsets=opsets,
|
|
||||||
dtypes=dtypes,
|
|
||||||
reason=reason,
|
|
||||||
matcher=matcher,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def fixme(
|
|
||||||
op_name: str,
|
|
||||||
variant_name: str = "",
|
|
||||||
*,
|
|
||||||
reason: str,
|
|
||||||
opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None,
|
|
||||||
dtypes: Optional[Collection[torch.dtype]] = None,
|
|
||||||
matcher: Optional[Callable[[Any], Any]] = None,
|
|
||||||
):
|
|
||||||
"""Skips a test case in OpInfo. It should be eventually fixed.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
op_name: The name of the operator.
|
|
||||||
variant_name: The name of the variant.
|
|
||||||
opsets: The opsets to expect the failure. e.g. [9, 10] or [opsets_before(11)]
|
|
||||||
dtypes: The dtypes to expect the failure.
|
|
||||||
reason: The reason for the failure.
|
|
||||||
matcher: A function that matches the test sample input. It is used only when
|
|
||||||
fixme is in the SKIP_SUBTESTS list.
|
|
||||||
"""
|
|
||||||
return DecorateMeta(
|
|
||||||
op_name=op_name,
|
|
||||||
variant_name=variant_name,
|
|
||||||
decorator=unittest.skip(f"To fix: {reason}"),
|
|
||||||
opsets=opsets,
|
|
||||||
dtypes=dtypes,
|
|
||||||
reason=reason,
|
|
||||||
matcher=matcher,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def add_decorate_info(
|
|
||||||
all_opinfos: Sequence[opinfo_core.OpInfo],
|
|
||||||
test_class_name: str,
|
|
||||||
base_test_name: str,
|
|
||||||
opset: int,
|
|
||||||
skip_or_xfails: Iterable[DecorateMeta],
|
|
||||||
):
|
|
||||||
"""Decorates OpInfo tests with decorators based on the skip_or_xfails list.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
all_opinfos: All OpInfos.
|
|
||||||
test_class_name: The name of the test class.
|
|
||||||
base_test_name: The name of the test method.
|
|
||||||
opset: The opset to decorate for.
|
|
||||||
skip_or_xfails: DecorateMeta's.
|
|
||||||
"""
|
|
||||||
ops_mapping = {(info.name, info.variant_test_name): info for info in all_opinfos}
|
|
||||||
for decorate_meta in skip_or_xfails:
|
|
||||||
if not decorate_meta.contains_opset(opset):
|
|
||||||
# Skip does not apply to this opset
|
|
||||||
continue
|
|
||||||
opinfo = ops_mapping.get((decorate_meta.op_name, decorate_meta.variant_name))
|
|
||||||
assert (
|
|
||||||
opinfo is not None
|
|
||||||
), f"Couldn't find OpInfo for {decorate_meta}. Did you need to specify variant_name?"
|
|
||||||
decorators = list(opinfo.decorators)
|
|
||||||
new_decorator = opinfo_core.DecorateInfo(
|
|
||||||
decorate_meta.decorator,
|
|
||||||
test_class_name,
|
|
||||||
base_test_name,
|
|
||||||
dtypes=decorate_meta.dtypes,
|
|
||||||
)
|
|
||||||
decorators.append(new_decorator)
|
|
||||||
opinfo.decorators = tuple(decorators)
|
|
||||||
|
|
||||||
# This decorator doesn't modify fn in any way
|
|
||||||
def wrapped(fn):
|
|
||||||
return fn
|
|
||||||
|
|
||||||
return wrapped
|
|
||||||
|
|
||||||
|
|
||||||
def opsets_before(opset: int) -> Callable[[int], bool]:
|
|
||||||
"""Returns a comparison function that decides if the given opset is before the specified."""
|
|
||||||
|
|
||||||
def compare(other_opset: int):
|
|
||||||
return other_opset < opset
|
|
||||||
|
|
||||||
return compare
|
|
||||||
|
|
||||||
|
|
||||||
def opsets_after(opset: int) -> Callable[[int], bool]:
|
|
||||||
"""Returns a comparison function that decides if the given opset is after the specified."""
|
|
||||||
|
|
||||||
def compare(other_opset: int):
|
|
||||||
return other_opset > opset
|
|
||||||
|
|
||||||
return compare
|
|
||||||
|
|
||||||
|
|
||||||
def reason_onnx_runtime_does_not_support(
|
|
||||||
operator: str, dtypes: Optional[Sequence[str]] = None
|
|
||||||
) -> str:
|
|
||||||
"""Formats the reason: ONNX Runtime doesn't support the given dtypes."""
|
|
||||||
return f"{operator} on {dtypes or 'dtypes'} not supported by ONNX Runtime"
|
|
||||||
|
|
||||||
|
|
||||||
def reason_onnx_does_not_support(
|
|
||||||
operator: str, dtypes: Optional[Sequence[str]] = None
|
|
||||||
) -> str:
|
|
||||||
"""Formats the reason: ONNX doesn't support the given dtypes."""
|
|
||||||
return f"{operator} on {dtypes or 'certain dtypes'} not supported by the ONNX Spec"
|
|
||||||
|
|
||||||
|
|
||||||
def reason_jit_tracer_error(info: str) -> str:
|
|
||||||
"""Formats the reason: JIT tracer errors."""
|
|
||||||
return f"JIT tracer error on {info}"
|
|
||||||
|
|
||||||
|
|
||||||
def reason_flaky() -> str:
|
|
||||||
"""Formats the reason: test is flaky."""
|
|
||||||
return "flaky test"
|
|
||||||
|
|
||||||
|
|
||||||
# Modify this section ##########################################################
|
# Modify this section ##########################################################
|
||||||
# NOTE: Modify this section as more ops are supported. The list should be sorted
|
# NOTE: Modify this section as more ops are supported. The list should be sorted
|
||||||
@ -321,17 +67,17 @@ TESTED_OPS: frozenset[str] = frozenset(
|
|||||||
# e.g. the test is flaky or some tests pass. Otherwise, use xfail.
|
# e.g. the test is flaky or some tests pass. Otherwise, use xfail.
|
||||||
# Use skip if we don't care about the test passing, e.g. ONNX doesn't support the usage.
|
# Use skip if we don't care about the test passing, e.g. ONNX doesn't support the usage.
|
||||||
# Use xfail if a test fails now and we want to eventually fix the test.
|
# Use xfail if a test fails now and we want to eventually fix the test.
|
||||||
EXPECTED_SKIPS_OR_FAILS: Tuple[DecorateMeta, ...] = (
|
EXPECTED_SKIPS_OR_FAILS: Tuple[onnx_test_common.DecorateMeta, ...] = (
|
||||||
skip(
|
skip(
|
||||||
"ceil", dtypes=BOOL_TYPES + INT_TYPES,
|
"ceil", dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
|
||||||
reason=reason_onnx_does_not_support("Ceil")
|
reason=onnx_test_common.reason_onnx_does_not_support("Ceil")
|
||||||
),
|
),
|
||||||
fixme("ceil", dtypes=[torch.float64], reason=reason_onnx_runtime_does_not_support("Ceil", ["f64"])),
|
fixme("ceil", dtypes=[torch.float64], reason=onnx_test_common.reason_onnx_runtime_does_not_support("Ceil", ["f64"])),
|
||||||
xfail("unflatten", reason="AssertionError: Expected 1 inputs, got 3 (https://github.com/pytorch/pytorch/issues/99534)"),
|
xfail("unflatten", reason="AssertionError: Expected 1 inputs, got 3 (https://github.com/pytorch/pytorch/issues/99534)"),
|
||||||
)
|
)
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
SKIP_SUBTESTS: tuple[DecorateMeta, ...] = (
|
SKIP_SUBTESTS: tuple[onnx_test_common.DecorateMeta, ...] = (
|
||||||
fixme(
|
fixme(
|
||||||
"unflatten",
|
"unflatten",
|
||||||
reason="Logic not implemented for size 0 inputs in op.Reshape",
|
reason="Logic not implemented for size 0 inputs in op.Reshape",
|
||||||
@ -386,7 +132,7 @@ def _get_test_class_name(cls, num, params_dict) -> str:
|
|||||||
"name": f"TestOnnxModelOutputConsistency_opset{opset}",
|
"name": f"TestOnnxModelOutputConsistency_opset{opset}",
|
||||||
"opset_version": opset,
|
"opset_version": opset,
|
||||||
}
|
}
|
||||||
for opset in TESTED_OPSETS
|
for opset in onnx_test_common.FX_TESTED_OPSETS
|
||||||
],
|
],
|
||||||
class_name_func=_get_test_class_name,
|
class_name_func=_get_test_class_name,
|
||||||
)
|
)
|
||||||
@ -402,7 +148,7 @@ class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime):
|
|||||||
|
|
||||||
@common_device_type.ops(
|
@common_device_type.ops(
|
||||||
[op for op in OPS_DB if op.name in TESTED_OPS],
|
[op for op in OPS_DB if op.name in TESTED_OPS],
|
||||||
allowed_dtypes=TESTED_DTYPES,
|
allowed_dtypes=onnx_test_common.TESTED_DTYPES,
|
||||||
)
|
)
|
||||||
def test_output_match(self, device: str, dtype: torch.dtype, op):
|
def test_output_match(self, device: str, dtype: torch.dtype, op):
|
||||||
"""Test the ONNX exporter."""
|
"""Test the ONNX exporter."""
|
||||||
@ -452,10 +198,10 @@ class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
for opset in TESTED_OPSETS:
|
for opset in onnx_test_common.FX_TESTED_OPSETS:
|
||||||
# The name needs to match the parameterized_class name.
|
# The name needs to match the parameterized_class name.
|
||||||
test_class_name = f"TestOnnxModelOutputConsistency_opset{opset}"
|
test_class_name = f"TestOnnxModelOutputConsistency_opset{opset}"
|
||||||
add_decorate_info(
|
onnx_test_common.add_decorate_info(
|
||||||
OPS_DB,
|
OPS_DB,
|
||||||
test_class_name,
|
test_class_name,
|
||||||
"test_output_match",
|
"test_output_match",
|
||||||
|
@ -25,272 +25,21 @@ Note:
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
import dataclasses
|
|
||||||
import unittest
|
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Any, Callable, Collection, Iterable, Optional, Sequence, Tuple, Union
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import onnx_test_common
|
import onnx_test_common
|
||||||
import parameterized
|
import parameterized
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.onnx import _constants
|
|
||||||
|
# For readability, these two are allowed to be imported as function
|
||||||
|
from onnx_test_common import fixme, skip
|
||||||
from torch.testing._internal import (
|
from torch.testing._internal import (
|
||||||
common_device_type,
|
common_device_type,
|
||||||
common_methods_invocations,
|
common_methods_invocations,
|
||||||
common_utils,
|
common_utils,
|
||||||
)
|
)
|
||||||
from torch.testing._internal.opinfo import core as opinfo_core
|
|
||||||
|
|
||||||
# The min onnx opset version to test for
|
|
||||||
MIN_ONNX_OPSET_VERSION = 9
|
|
||||||
# The max onnx opset version to test for
|
|
||||||
MAX_ONNX_OPSET_VERSION = _constants.ONNX_MAX_OPSET
|
|
||||||
|
|
||||||
TESTED_OPSETS = range(MIN_ONNX_OPSET_VERSION, MAX_ONNX_OPSET_VERSION + 1)
|
|
||||||
|
|
||||||
BOOL_TYPES = (torch.bool,)
|
|
||||||
|
|
||||||
INT_TYPES = (
|
|
||||||
torch.int8,
|
|
||||||
torch.int16,
|
|
||||||
torch.int32,
|
|
||||||
torch.int64,
|
|
||||||
torch.uint8,
|
|
||||||
)
|
|
||||||
|
|
||||||
QINT_TYPES = (
|
|
||||||
torch.qint8,
|
|
||||||
torch.quint8,
|
|
||||||
)
|
|
||||||
|
|
||||||
FLOAT_TYPES = (
|
|
||||||
torch.float16,
|
|
||||||
torch.float32,
|
|
||||||
torch.float64,
|
|
||||||
)
|
|
||||||
|
|
||||||
COMPLEX_TYPES = (
|
|
||||||
torch.complex32,
|
|
||||||
torch.complex64,
|
|
||||||
torch.complex128,
|
|
||||||
)
|
|
||||||
|
|
||||||
TESTED_DTYPES = (
|
|
||||||
# Boolean
|
|
||||||
torch.bool,
|
|
||||||
# Integers
|
|
||||||
*INT_TYPES,
|
|
||||||
# Floating types
|
|
||||||
*FLOAT_TYPES,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
|
||||||
class DecorateMeta:
|
|
||||||
"""Information about a test case to skip or xfail.
|
|
||||||
|
|
||||||
Adapted from functorch: functorch/test/common_utils.py
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
op_name: The name of the operator.
|
|
||||||
variant_name: The name of the OpInfo variant.
|
|
||||||
decorator: The decorator to apply to the test case.
|
|
||||||
opsets: The opsets to apply the decorator to.
|
|
||||||
dtypes: The dtypes to apply the decorator to.
|
|
||||||
reason: The reason for skipping.
|
|
||||||
"""
|
|
||||||
|
|
||||||
op_name: str
|
|
||||||
variant_name: str
|
|
||||||
decorator: Callable
|
|
||||||
opsets: Optional[Collection[Union[int, Callable[[int], bool]]]]
|
|
||||||
dtypes: Optional[Collection[torch.dtype]]
|
|
||||||
reason: str
|
|
||||||
matcher: Optional[Callable[[Any], Any]] = None
|
|
||||||
|
|
||||||
def contains_opset(self, opset: int) -> bool:
|
|
||||||
if self.opsets is None:
|
|
||||||
return True
|
|
||||||
return any(
|
|
||||||
opset == opset_spec if isinstance(opset_spec, int) else opset_spec(opset)
|
|
||||||
for opset_spec in self.opsets
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def xfail(
|
|
||||||
op_name: str,
|
|
||||||
variant_name: str = "",
|
|
||||||
*,
|
|
||||||
reason: str,
|
|
||||||
opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None,
|
|
||||||
dtypes: Optional[Collection[torch.dtype]] = None,
|
|
||||||
):
|
|
||||||
"""Expects a OpInfo test to fail.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
op_name: The name of the operator.
|
|
||||||
variant_name: The name of the variant.
|
|
||||||
opsets: The opsets to expect the failure. e.g. [9, 10] or [opsets_before(11)]
|
|
||||||
dtypes: The dtypes to expect the failure.
|
|
||||||
reason: The reason for the failure.
|
|
||||||
"""
|
|
||||||
return DecorateMeta(
|
|
||||||
op_name=op_name,
|
|
||||||
variant_name=variant_name,
|
|
||||||
decorator=unittest.expectedFailure,
|
|
||||||
opsets=opsets,
|
|
||||||
dtypes=dtypes,
|
|
||||||
reason=reason,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def skip(
|
|
||||||
op_name: str,
|
|
||||||
variant_name: str = "",
|
|
||||||
*,
|
|
||||||
reason: str,
|
|
||||||
opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None,
|
|
||||||
dtypes: Optional[Collection[torch.dtype]] = None,
|
|
||||||
matcher: Optional[Callable[[Any], Any]] = None,
|
|
||||||
):
|
|
||||||
"""Skips a test case in OpInfo that we don't care about.
|
|
||||||
|
|
||||||
Likely because ONNX does not support the use case or it is by design.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
op_name: The name of the operator.
|
|
||||||
variant_name: The name of the variant.
|
|
||||||
opsets: The opsets to expect the failure. e.g. [9, 10] or [opsets_before(11)]
|
|
||||||
dtypes: The dtypes to expect the failure.
|
|
||||||
reason: The reason for the failure.
|
|
||||||
matcher: A function that matches the test sample input. It is used only when
|
|
||||||
skip is in the SKIP_SUBTESTS list.
|
|
||||||
"""
|
|
||||||
return DecorateMeta(
|
|
||||||
op_name=op_name,
|
|
||||||
variant_name=variant_name,
|
|
||||||
decorator=unittest.skip(f"Skip: {reason}"),
|
|
||||||
opsets=opsets,
|
|
||||||
dtypes=dtypes,
|
|
||||||
reason=reason,
|
|
||||||
matcher=matcher,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def fixme(
|
|
||||||
op_name: str,
|
|
||||||
variant_name: str = "",
|
|
||||||
*,
|
|
||||||
reason: str,
|
|
||||||
opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None,
|
|
||||||
dtypes: Optional[Collection[torch.dtype]] = None,
|
|
||||||
matcher: Optional[Callable[[Any], Any]] = None,
|
|
||||||
):
|
|
||||||
"""Skips a test case in OpInfo. It should be eventually fixed.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
op_name: The name of the operator.
|
|
||||||
variant_name: The name of the variant.
|
|
||||||
opsets: The opsets to expect the failure. e.g. [9, 10] or [opsets_before(11)]
|
|
||||||
dtypes: The dtypes to expect the failure.
|
|
||||||
reason: The reason for the failure.
|
|
||||||
matcher: A function that matches the test sample input. It is used only when
|
|
||||||
fixme is in the SKIP_SUBTESTS list.
|
|
||||||
"""
|
|
||||||
return DecorateMeta(
|
|
||||||
op_name=op_name,
|
|
||||||
variant_name=variant_name,
|
|
||||||
decorator=unittest.skip(f"To fix: {reason}"),
|
|
||||||
opsets=opsets,
|
|
||||||
dtypes=dtypes,
|
|
||||||
reason=reason,
|
|
||||||
matcher=matcher,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def add_decorate_info(
|
|
||||||
all_opinfos: Sequence[opinfo_core.OpInfo],
|
|
||||||
test_class_name: str,
|
|
||||||
base_test_name: str,
|
|
||||||
opset: int,
|
|
||||||
skip_or_xfails: Iterable[DecorateMeta],
|
|
||||||
):
|
|
||||||
"""Decorates OpInfo tests with decorators based on the skip_or_xfails list.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
all_opinfos: All OpInfos.
|
|
||||||
test_class_name: The name of the test class.
|
|
||||||
base_test_name: The name of the test method.
|
|
||||||
opset: The opset to decorate for.
|
|
||||||
skip_or_xfails: DecorateMeta's.
|
|
||||||
"""
|
|
||||||
ops_mapping = {(info.name, info.variant_test_name): info for info in all_opinfos}
|
|
||||||
for decorate_meta in skip_or_xfails:
|
|
||||||
if not decorate_meta.contains_opset(opset):
|
|
||||||
# Skip does not apply to this opset
|
|
||||||
continue
|
|
||||||
opinfo = ops_mapping.get((decorate_meta.op_name, decorate_meta.variant_name))
|
|
||||||
assert (
|
|
||||||
opinfo is not None
|
|
||||||
), f"Couldn't find OpInfo for {decorate_meta}. Did you need to specify variant_name?"
|
|
||||||
decorators = list(opinfo.decorators)
|
|
||||||
new_decorator = opinfo_core.DecorateInfo(
|
|
||||||
decorate_meta.decorator,
|
|
||||||
test_class_name,
|
|
||||||
base_test_name,
|
|
||||||
dtypes=decorate_meta.dtypes,
|
|
||||||
)
|
|
||||||
decorators.append(new_decorator)
|
|
||||||
opinfo.decorators = tuple(decorators)
|
|
||||||
|
|
||||||
# This decorator doesn't modify fn in any way
|
|
||||||
def wrapped(fn):
|
|
||||||
return fn
|
|
||||||
|
|
||||||
return wrapped
|
|
||||||
|
|
||||||
|
|
||||||
def opsets_before(opset: int) -> Callable[[int], bool]:
|
|
||||||
"""Returns a comparison function that decides if the given opset is before the specified."""
|
|
||||||
|
|
||||||
def compare(other_opset: int):
|
|
||||||
return other_opset < opset
|
|
||||||
|
|
||||||
return compare
|
|
||||||
|
|
||||||
|
|
||||||
def opsets_after(opset: int) -> Callable[[int], bool]:
|
|
||||||
"""Returns a comparison function that decides if the given opset is after the specified."""
|
|
||||||
|
|
||||||
def compare(other_opset: int):
|
|
||||||
return other_opset > opset
|
|
||||||
|
|
||||||
return compare
|
|
||||||
|
|
||||||
|
|
||||||
def reason_onnx_runtime_does_not_support(
|
|
||||||
operator: str, dtypes: Optional[Sequence[str]] = None
|
|
||||||
) -> str:
|
|
||||||
"""Formats the reason: ONNX Runtime doesn't support the given dtypes."""
|
|
||||||
return f"{operator} on {dtypes or 'dtypes'} not supported by ONNX Runtime"
|
|
||||||
|
|
||||||
|
|
||||||
def reason_onnx_does_not_support(
|
|
||||||
operator: str, dtypes: Optional[Sequence[str]] = None
|
|
||||||
) -> str:
|
|
||||||
"""Formats the reason: ONNX doesn't support the given dtypes."""
|
|
||||||
return f"{operator} on {dtypes or 'certain dtypes'} not supported by the ONNX Spec"
|
|
||||||
|
|
||||||
|
|
||||||
def reason_jit_tracer_error(info: str) -> str:
|
|
||||||
"""Formats the reason: JIT tracer errors."""
|
|
||||||
return f"JIT tracer error on {info}"
|
|
||||||
|
|
||||||
|
|
||||||
def reason_flaky() -> str:
|
|
||||||
"""Formats the reason: test is flaky."""
|
|
||||||
return "flaky test"
|
|
||||||
|
|
||||||
|
|
||||||
# Modify this section ##########################################################
|
# Modify this section ##########################################################
|
||||||
@ -330,32 +79,32 @@ TESTED_OPS: frozenset[str] = frozenset(
|
|||||||
# e.g. the test is flaky or some tests pass. Otherwise, use xfail.
|
# e.g. the test is flaky or some tests pass. Otherwise, use xfail.
|
||||||
# Use skip if we don't care about the test passing, e.g. ONNX doesn't support the usage.
|
# Use skip if we don't care about the test passing, e.g. ONNX doesn't support the usage.
|
||||||
# Use xfail if a test fails now and we want to eventually fix the test.
|
# Use xfail if a test fails now and we want to eventually fix the test.
|
||||||
EXPECTED_SKIPS_OR_FAILS: Tuple[DecorateMeta, ...] = (
|
EXPECTED_SKIPS_OR_FAILS: Tuple[onnx_test_common.DecorateMeta, ...] = (
|
||||||
skip(
|
skip(
|
||||||
"atan", dtypes=BOOL_TYPES + INT_TYPES,
|
"atan", dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
|
||||||
reason=reason_onnx_does_not_support("Atan")
|
reason=onnx_test_common.reason_onnx_does_not_support("Atan")
|
||||||
),
|
),
|
||||||
fixme("atan", dtypes=[torch.float64], reason=reason_onnx_runtime_does_not_support("Atan", ["f64"])),
|
fixme("atan", dtypes=[torch.float64], reason=onnx_test_common.reason_onnx_runtime_does_not_support("Atan", ["f64"])),
|
||||||
skip(
|
skip(
|
||||||
"atan2", dtypes=BOOL_TYPES + INT_TYPES,
|
"atan2", dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
|
||||||
reason=reason_onnx_does_not_support("Atan")
|
reason=onnx_test_common.reason_onnx_does_not_support("Atan")
|
||||||
),
|
),
|
||||||
fixme("atan2", dtypes=[torch.float64], reason=reason_onnx_runtime_does_not_support("Atan", ["f64"])),
|
fixme("atan2", dtypes=[torch.float64], reason=onnx_test_common.reason_onnx_runtime_does_not_support("Atan", ["f64"])),
|
||||||
skip(
|
skip(
|
||||||
"ceil", dtypes=BOOL_TYPES + INT_TYPES,
|
"ceil", dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
|
||||||
reason=reason_onnx_does_not_support("Ceil")
|
reason=onnx_test_common.reason_onnx_does_not_support("Ceil")
|
||||||
),
|
),
|
||||||
fixme("ceil", dtypes=[torch.float64], reason=reason_onnx_runtime_does_not_support("Ceil", ["f64"])),
|
fixme("ceil", dtypes=[torch.float64], reason=onnx_test_common.reason_onnx_runtime_does_not_support("Ceil", ["f64"])),
|
||||||
skip("nn.functional.scaled_dot_product_attention", opsets=[opsets_before(14)], reason="Need Trilu."),
|
skip("nn.functional.scaled_dot_product_attention", opsets=[onnx_test_common.opsets_before(14)], reason="Need Trilu."),
|
||||||
fixme("nn.functional.scaled_dot_product_attention", reason="fixme: ORT crashes on Windows, segfaults randomly on Linux"),
|
fixme("nn.functional.scaled_dot_product_attention", reason="fixme: ORT crashes on Windows, segfaults randomly on Linux"),
|
||||||
skip("sqrt", dtypes=BOOL_TYPES, reason=reason_onnx_does_not_support("Sqrt")),
|
skip("sqrt", dtypes=onnx_test_common.BOOL_TYPES, reason=onnx_test_common.reason_onnx_does_not_support("Sqrt")),
|
||||||
skip("stft", opsets=[opsets_before(17)], reason=reason_onnx_does_not_support("STFT")),
|
skip("stft", opsets=[onnx_test_common.opsets_before(17)], reason=onnx_test_common.reason_onnx_does_not_support("STFT")),
|
||||||
skip("tile", opsets=[opsets_before(13)], reason=reason_onnx_does_not_support("Tile")),
|
skip("tile", opsets=[onnx_test_common.opsets_before(13)], reason=onnx_test_common.reason_onnx_does_not_support("Tile")),
|
||||||
fixme("unflatten", opsets=[opsets_before(13)], reason="Helper function is needed to support legacy ops."),
|
fixme("unflatten", opsets=[onnx_test_common.opsets_before(13)], reason="Helper function is needed to support legacy ops."),
|
||||||
)
|
)
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
SKIP_SUBTESTS: tuple[DecorateMeta, ...] = (
|
SKIP_SUBTESTS: tuple[onnx_test_common.DecorateMeta, ...] = (
|
||||||
skip(
|
skip(
|
||||||
"nn.functional.scaled_dot_product_attention",
|
"nn.functional.scaled_dot_product_attention",
|
||||||
matcher=lambda sample: sample.kwargs.get("dropout_p") != 0.0,
|
matcher=lambda sample: sample.kwargs.get("dropout_p") != 0.0,
|
||||||
@ -431,7 +180,7 @@ def _get_test_class_name(cls, num, params_dict) -> str:
|
|||||||
"name": f"TestOnnxModelOutputConsistency_opset{opset}",
|
"name": f"TestOnnxModelOutputConsistency_opset{opset}",
|
||||||
"opset_version": opset,
|
"opset_version": opset,
|
||||||
}
|
}
|
||||||
for opset in TESTED_OPSETS
|
for opset in onnx_test_common.TESTED_OPSETS
|
||||||
],
|
],
|
||||||
class_name_func=_get_test_class_name,
|
class_name_func=_get_test_class_name,
|
||||||
)
|
)
|
||||||
@ -445,7 +194,7 @@ class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime):
|
|||||||
|
|
||||||
@common_device_type.ops(
|
@common_device_type.ops(
|
||||||
[op for op in OPS_DB if op.name in TESTED_OPS],
|
[op for op in OPS_DB if op.name in TESTED_OPS],
|
||||||
allowed_dtypes=TESTED_DTYPES,
|
allowed_dtypes=onnx_test_common.TESTED_DTYPES,
|
||||||
)
|
)
|
||||||
def test_output_match(self, device: str, dtype: torch.dtype, op):
|
def test_output_match(self, device: str, dtype: torch.dtype, op):
|
||||||
"""Test the ONNX exporter."""
|
"""Test the ONNX exporter."""
|
||||||
@ -491,10 +240,10 @@ class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime):
|
|||||||
self.run_test(model, inputs, rtol=rtol, atol=atol)
|
self.run_test(model, inputs, rtol=rtol, atol=atol)
|
||||||
|
|
||||||
|
|
||||||
for opset in TESTED_OPSETS:
|
for opset in onnx_test_common.TESTED_OPSETS:
|
||||||
# The name needs to match the parameterized_class name.
|
# The name needs to match the parameterized_class name.
|
||||||
test_class_name = f"TestOnnxModelOutputConsistency_opset{opset}"
|
test_class_name = f"TestOnnxModelOutputConsistency_opset{opset}"
|
||||||
add_decorate_info(
|
onnx_test_common.add_decorate_info(
|
||||||
OPS_DB,
|
OPS_DB,
|
||||||
test_class_name,
|
test_class_name,
|
||||||
"test_output_match",
|
"test_output_match",
|
||||||
|
Reference in New Issue
Block a user