# Owner(s): ["module: onnx"]

"""Test consistency between the output values of torch.onnx exported operators
and torch operators given the same inputs.

Usage:

    pytest test/onnx/test_op_consistency.py

To run tests on a specific operator (e.g. torch.ceil):

    pytest test/onnx/test_op_consistency.py -k ceil
    pytest test/onnx/test_op_consistency.py -k nn_functional_scaled_dot_product_attention

Read more on Running and writing tests:
    https://github.com/pytorch/pytorch/wiki/Running-and-writing-tests

Note:

    When new ops are supported, please scroll down to modify the
    EXPECTED_SKIPS_OR_FAILS and TESTED_OPS lists. See "Modify this section".
"""

from __future__ import annotations

import copy
from typing import Optional

import onnx_test_common
import parameterized

# For readability, these two are allowed to be imported as functions
from onnx_test_common import skip, xfail

import torch
from torch.testing._internal import (
    common_device_type,
    common_methods_invocations,
    common_utils,
)


OPS_DB = copy.deepcopy(common_methods_invocations.op_db)

# Modify this section ##########################################################
# NOTE: Modify this section as more ops are supported. The list should be sorted
# alphabetically.
#
# For example, to add a test for torch.ceil:
# 1. Add "ceil" to TESTED_OPS then run pytest.
# 2. If the test fails, fix the error or add a new entry to EXPECTED_SKIPS_OR_FAILS,
#    as in the example entry sketched below.
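#
# As a concrete example, the entry for torch.ceil in EXPECTED_SKIPS_OR_FAILS
# below reads:
#
#     xfail(
#         "ceil", dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
#         reason=onnx_test_common.reason_onnx_does_not_support("Ceil")
#     ),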

# TODO: Directly modify DecorateInfo in each OpInfo in op_db when all ops are enabled.
# Ops to be tested for numerical consistency between onnx and pytorch
# TODO: https://github.com/pytorch/pytorch/issues/102211
TESTED_OPS: frozenset[str] = frozenset(
    [
        "atan",
        "atan2",
        # "atleast_1d",  # How to support list input?
        # "atleast_2d",
        # "atleast_3d",
        "broadcast_to",
        "ceil",
        "expand",
        "flatten",
        "hstack",
        "logical_not",
        # "logit",
        "nn.functional.scaled_dot_product_attention",
        "repeat",
        "round",
        # "scatter_add",
        # "scatter_reduce",
        "sqrt",
        "stft",
        "t",
        "tile",
        "unflatten",
        "vstack",
    ]
)

# fmt: off
# Turn off black formatting to keep the list compact

# Expected failures for onnx export.
# The list should be sorted alphabetically by op name.
# Q: When should I use fixme vs skip vs xfail?
# A: Prefer xfail over skip when possible.
#    - If a test now passes (xpass) because previous errors have been fixed,
#      remove the corresponding xfail.
#    - If a test does not fail consistently, use skip.
EXPECTED_SKIPS_OR_FAILS: tuple[onnx_test_common.DecorateMeta, ...] = (
    skip(
        "atan", dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
        reason=onnx_test_common.reason_onnx_does_not_support("Atan")
    ),
    xfail("atan", dtypes=[torch.float64], reason=onnx_test_common.reason_onnx_runtime_does_not_support("Atan", ["f64"])),
    skip(
        "atan2", dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
        reason=onnx_test_common.reason_onnx_does_not_support("Atan")
    ),
    xfail(
        "atan2", dtypes=[torch.float64],
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("Atan", ["f64"])
    ),
    xfail(
        "ceil", dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
        reason=onnx_test_common.reason_onnx_does_not_support("Ceil")
    ),
    skip("hstack", opsets=[onnx_test_common.opsets_before(11)],
         reason=onnx_test_common.reason_onnx_does_not_support("ConcatFromSequence")),
    xfail(
        "logit",
        dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
        reason=onnx_test_common.reason_onnx_does_not_support("Log", "bool, int"),
    ),
    skip("nn.functional.scaled_dot_product_attention", opsets=[onnx_test_common.opsets_before(14)], reason="Need Trilu."),
    skip("nn.functional.scaled_dot_product_attention", reason="fixme: ORT crashes on Windows, segfaults randomly on Linux"),
    xfail("round", opsets=[onnx_test_common.opsets_before(11)],
          reason=onnx_test_common.reason_onnx_does_not_support("Round")),
    xfail("round", variant_name="decimals_0", opsets=[onnx_test_common.opsets_before(11)],
          reason=onnx_test_common.reason_onnx_does_not_support("Round")),
    xfail("round", variant_name="decimals_3", opsets=[onnx_test_common.opsets_before(11)],
          reason=onnx_test_common.reason_onnx_does_not_support("Round")),
    xfail("round", variant_name="decimals_neg_3", opsets=[onnx_test_common.opsets_before(11)],
          reason=onnx_test_common.reason_onnx_does_not_support("Round")),
    skip("scatter_reduce", variant_name="amin", opsets=[onnx_test_common.opsets_before(16)],
         reason=onnx_test_common.reason_onnx_does_not_support("ScatterElements with reduction")),
    skip("scatter_reduce", variant_name="amax", opsets=[onnx_test_common.opsets_before(16)],
         reason=onnx_test_common.reason_onnx_does_not_support("ScatterElements with reduction")),
    skip("scatter_reduce", variant_name="prod", opsets=[onnx_test_common.opsets_before(16)],
         reason=onnx_test_common.reason_onnx_does_not_support("ScatterElements with reduction")),
    xfail("scatter_reduce", variant_name="mean",
          reason=onnx_test_common.reason_onnx_does_not_support("ScatterElements with reduction=mean")),
    skip("scatter_reduce", variant_name="sum", opsets=[onnx_test_common.opsets_before(16)],
         reason=onnx_test_common.reason_onnx_does_not_support("ScatterElements with reduction")),
    xfail(
        "scatter_reduce",
        variant_name="sum",
        dtypes=(torch.float16,),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=sum", "float16"),
    ),
    xfail(
        "scatter_reduce",
        variant_name="prod",
        dtypes=(torch.float16,),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=prod", "float16"),
    ),
    xfail(
        "scatter_reduce",
        variant_name="amin",
        dtypes=onnx_test_common.BOOL_TYPES + (torch.float16,),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=amin", "float16"),
    ),
    xfail(
        "scatter_reduce",
        variant_name="amax",
        dtypes=onnx_test_common.BOOL_TYPES + (torch.float16,),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=amax", "float16"),
    ),
    xfail(
        "scatter_reduce",
        variant_name="mean",
        reason="ONNX doesn't support reduce='mean' option",
    ),
    skip("sqrt", dtypes=onnx_test_common.BOOL_TYPES, reason=onnx_test_common.reason_onnx_does_not_support("Sqrt")),
    skip("stft", opsets=[onnx_test_common.opsets_before(17)], reason=onnx_test_common.reason_onnx_does_not_support("STFT")),
    xfail("stft",
          reason=onnx_test_common.reason_onnx_runtime_does_not_support("STFT", "Regression on ORT=1.15 4 percent difference")),
    skip("tile", opsets=[onnx_test_common.opsets_before(13)], reason=onnx_test_common.reason_onnx_does_not_support("Tile")),
    xfail("unflatten", opsets=[onnx_test_common.opsets_before(13)], reason="Helper function is needed to support legacy ops."),
    skip("vstack", opsets=[onnx_test_common.opsets_before(11)],
         reason=onnx_test_common.reason_onnx_does_not_support("ConcatFromSequence")),
)
# fmt: on

SKIP_XFAIL_SUBTESTS: tuple[onnx_test_common.DecorateMeta, ...] = (
    skip(
        "nn.functional.scaled_dot_product_attention",
        matcher=lambda sample: sample.kwargs.get("dropout_p") != 0.0,
        reason="dropout is random so the results do not match",
    ),
    skip(
        "repeat",
        reason="Empty repeats value leads to an invalid graph",
        matcher=lambda sample: not sample.args[0],
    ),
    skip(
        "scatter_reduce",
        # ONNX does not have an include_self parameter; its default behavior matches include_self=True
        matcher=lambda sample: sample.kwargs.get("include_self") is False,
        reason="ONNX doesn't support include_self=False option",
    ),
    skip(
        "stft",
        reason="ONNX STFT does not support complex results",
        matcher=lambda sample: sample.kwargs.get("return_complex") is True,
    ),
    skip(
        "tile",
        matcher=lambda sample: any(dim == 0 for dim in sample.input.shape)
        or not sample.input.shape,
        reason="Logic not implemented for size 0 inputs in op.Reshape",
    ),
    skip(
        "unflatten",
        reason="Logic not implemented for size 0 inputs in op.Reshape",
        matcher=lambda sample: any(dim == 0 for dim in sample.input.shape),
    ),
)

# END OF SECTION TO MODIFY #####################################################

OP_WITH_SKIPPED_XFAIL_SUBTESTS = frozenset(meta.op_name for meta in SKIP_XFAIL_SUBTESTS)
ALL_OPS_IN_DB = frozenset(op_info.name for op_info in OPS_DB)
# Assert all ops in TESTED_OPS are in OPS_DB
assert TESTED_OPS.issubset(ALL_OPS_IN_DB), f"{TESTED_OPS - ALL_OPS_IN_DB} not in OPS_DB"


class SingleOpModel(torch.nn.Module):
    """Test model to wrap around a single op for export."""

    def __init__(self, op, kwargs):
        super().__init__()
        self.operator = op
        self.kwargs = kwargs

    def forward(self, *args):
        return self.operator(*args, **self.kwargs)
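
# Illustrative usage sketch (not part of the test flow; the output file name
# is a hypothetical placeholder): wrap a single operator so the exporter can
# trace it as a model.
#
#     model = SingleOpModel(torch.ceil, {})
#     torch.onnx.export(model, (torch.randn(2, 3),), "single_op.onnx")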


def _should_skip_xfail_test_sample(
    op_name: str, sample
) -> tuple[Optional[str], Optional[str]]:
    """Returns the test behavior and reason if a test sample should be skipped or xfailed."""
    if op_name not in OP_WITH_SKIPPED_XFAIL_SUBTESTS:
        return None, None
    for decorator_meta in SKIP_XFAIL_SUBTESTS:
        # Linear search on SKIP_XFAIL_SUBTESTS. That's fine because the list is small.
        if decorator_meta.op_name == op_name:
            assert decorator_meta.matcher is not None, "Matcher must be defined"
            if decorator_meta.matcher(sample):
                return decorator_meta.test_behavior, decorator_meta.reason
    return None, None
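
# Sketch of the expected return values (assuming the skip/xfail helpers record
# test_behavior as "skip"/"xfail"): a matched sample yields e.g.
# ("skip", "dropout is random so the results do not match"), while an
# unmatched sample yields (None, None).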


def _get_test_class_name(cls, num, params_dict) -> str:
    del cls  # unused
    del num  # unused
    return params_dict["name"]


@parameterized.parameterized_class(
    [
        {
            "name": f"TestOnnxModelOutputConsistency_opset{opset}",
            "opset_version": opset,
        }
        for opset in onnx_test_common.TESTED_OPSETS
    ],
    class_name_func=_get_test_class_name,
)
class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime):
    """Test output consistency between exported ONNX models and PyTorch eager mode.

    This is a parameterized test suite.
    """

    opset_version = -1

    @common_device_type.ops(
        [op for op in OPS_DB if op.name in TESTED_OPS],
        allowed_dtypes=onnx_test_common.INT_TYPES
        + onnx_test_common.FLOAT_TYPES
        + onnx_test_common.BOOL_TYPES,
    )
    def test_output_match(self, device: str, dtype: torch.dtype, op):
        """Test the ONNX exporter."""
        # device is provided by instantiate_device_type_tests, but we only want to run on CPU.
        assert device == "cpu"

        samples = op.sample_inputs(
            device,
            dtype,
            requires_grad=False,
        )

        for i, cpu_sample in enumerate(samples):
            inputs = (cpu_sample.input, *cpu_sample.args)
            # Provide the repr to subtest because tensors are not serializable in parallel test runs
            with self.subTest(
                opset=self.opset_version,
                sample_num=i,
                inputs=repr(inputs),
                kwargs=repr(cpu_sample.kwargs),
            ):
                test_behavior, reason = _should_skip_xfail_test_sample(
                    op.name, cpu_sample
                )
                with onnx_test_common.normal_xfail_skip_test_behaviors(
                    test_behavior, reason
                ):
                    model = SingleOpModel(op, cpu_sample.kwargs)
                    model.eval()

                    if dtype in (torch.float32, torch.float64):
                        # Relax atol and rtol based on empirical results.
                        # The current most relaxed values are for aten::stft.
                        rtol = 1e-5
                        atol = 2e-5
                    else:
                        rtol = None
                        atol = None
                    # Run the test
                    self.run_test(model, inputs, rtol=rtol, atol=atol)


for opset in onnx_test_common.TESTED_OPSETS:
    # The name needs to match the parameterized_class name.
    test_class_name = f"TestOnnxModelOutputConsistency_opset{opset}"
    onnx_test_common.add_decorate_info(
        OPS_DB,
        test_class_name,
        "test_output_match",
        opset=opset,
        skip_or_xfails=EXPECTED_SKIPS_OR_FAILS,
    )
    common_device_type.instantiate_device_type_tests(
        globals()[test_class_name], globals(), only_for="cpu"
    )


if __name__ == "__main__":
    common_utils.run_tests()