mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[ONNX] Use onnxruntime to run fx tests (#94638)
- Enable the mnist test - Removed `max_pool2d` in the test because we don't have the op yet. - Add aten::convolution - Bump onnxscript version Pull Request resolved: https://github.com/pytorch/pytorch/pull/94638 Approved by: https://github.com/BowenBao, https://github.com/wschin, https://github.com/titaiwangms
This commit is contained in:
committed by
PyTorch MergeBot
parent
9dd7e83676
commit
a27bd42bb9
@ -64,7 +64,7 @@ if [[ "$BUILD_ENVIRONMENT" == *onnx* ]]; then
|
||||
# TODO: change this when onnx reference patch is released.
|
||||
pip install --no-use-pep517 'onnx @ git+https://github.com/onnx/onnx@be441bf70f93369d30d1e12fd97e27d2beb75b12'
|
||||
# TODO: change this when onnx-script is on testPypi
|
||||
pip install 'onnx-script @ git+https://github.com/microsoft/onnx-script@78ea55b888de88bfdadce7c3f6f3f83fa1404c7f'
|
||||
pip install 'onnx-script @ git+https://github.com/microsoft/onnx-script@a71e35bcd72537bf7572536ee57250a0c0488bf6'
|
||||
# numba requires numpy <= 1.20, onnxruntime requires numpy >= 1.21.
|
||||
# We don't actually need it for our tests, but it's imported if it's present, so uninstall.
|
||||
pip uninstall -q --yes numba
|
||||
|
@ -17,7 +17,6 @@ import torch
|
||||
import transformers # type: ignore[import]
|
||||
from torch import nn
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode
|
||||
from torch.nn import functional as F
|
||||
from torch.onnx._internal import diagnostics, fx as fx_onnx
|
||||
from torch.testing._internal import common_utils
|
||||
from torch.utils import _pytree as pytree
|
||||
@ -54,7 +53,7 @@ def _run_test_with_fx_to_onnx_exporter_reference_runtime(
|
||||
)
|
||||
|
||||
ref_outputs, _ = pytree.tree_flatten(model(*input_args))
|
||||
ort_outputs = _run_onnx_reference_runtime(onnx_model, input_args)
|
||||
ort_outputs = _run_ort(onnx_model, input_args)
|
||||
for ref_output, ort_output in zip(ref_outputs, ort_outputs):
|
||||
torch.testing.assert_close(
|
||||
ref_output, torch.tensor(ort_output), rtol=rtol, atol=atol
|
||||
@ -101,15 +100,12 @@ class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
|
||||
# Commenting this line and removing related files.
|
||||
# self.run_test_with_fx_to_onnx_exporter(func, (tensor_x,), {"b": 500.0})
|
||||
|
||||
@unittest.skip(
|
||||
"Conv Op is not supported at the time. https://github.com/microsoft/onnx-script/issues/397"
|
||||
)
|
||||
def test_mnist(self):
|
||||
class MNISTModel(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv1 = nn.Conv2d(1, 32, 3, 1, bias=True)
|
||||
self.conv2 = nn.Conv2d(32, 64, 3, 1, bias=True)
|
||||
self.conv2 = nn.Conv2d(32, 64, 3, 2, bias=True)
|
||||
self.fc1 = nn.Linear(9216, 128, bias=True)
|
||||
self.fc2 = nn.Linear(128, 10, bias=True)
|
||||
|
||||
@ -118,7 +114,6 @@ class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
|
||||
tensor_x = torch.sigmoid(tensor_x)
|
||||
tensor_x = self.conv2(tensor_x)
|
||||
tensor_x = torch.sigmoid(tensor_x)
|
||||
tensor_x = F.max_pool2d(tensor_x, 2)
|
||||
tensor_x = torch.flatten(tensor_x, 1)
|
||||
tensor_x = self.fc1(tensor_x)
|
||||
tensor_x = torch.sigmoid(tensor_x)
|
||||
@ -175,9 +170,7 @@ class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
|
||||
)
|
||||
|
||||
ref_outputs, _ = pytree.tree_flatten(model(**inputs, return_dict=False))
|
||||
ort_outputs = _run_onnx_reference_runtime(
|
||||
onnx_model, (input_ids, attention_mask)
|
||||
)
|
||||
ort_outputs = _run_ort(onnx_model, (input_ids, attention_mask))
|
||||
assert len(ref_outputs) == len(ort_outputs)
|
||||
assert len(ref_outputs) == 5
|
||||
for ref_output, ort_output in zip(ref_outputs, ort_outputs):
|
||||
@ -244,6 +237,7 @@ class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
|
||||
fake_model,
|
||||
*fake_args,
|
||||
use_binary_format=False,
|
||||
opset_version=self.opset_version,
|
||||
)
|
||||
|
||||
# Tasks done by the following block.
|
||||
@ -271,7 +265,7 @@ class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
|
||||
# Original outputs.
|
||||
ref_outputs, _ = pytree.tree_flatten(model(*args, **kwargs))
|
||||
# ORT outputs.
|
||||
ort_outputs = _run_onnx_reference_runtime(
|
||||
ort_outputs = _run_ort(
|
||||
os.path.join(tmp_folder, onnx_model_location),
|
||||
(arg for arg in args if arg is not None),
|
||||
)
|
||||
|
@ -67,9 +67,7 @@ _ATENLIB_FUNCTIONS = {
|
||||
"aten::addmm": ops.core.aten_addmm,
|
||||
"aten::amax": ops.core.aten_amax,
|
||||
"aten::amin": ops.core.aten_amin,
|
||||
# "aten::arange": ops.core.aten_arange_start_step,
|
||||
"aten::arange": ops.core.aten_arange_start,
|
||||
# "aten::arange": ops.core.aten_arange,
|
||||
"aten::asin": ops.core.aten_asin,
|
||||
"aten::asinh": ops.core.aten_asinh,
|
||||
"aten::atan": ops.core.aten_atan,
|
||||
@ -80,6 +78,7 @@ _ATENLIB_FUNCTIONS = {
|
||||
"aten::clamp_min": ops.core.aten_clamp_min,
|
||||
"aten::clamp": ops.core.aten_clamp,
|
||||
"aten::clone": ops.core.aten_clone,
|
||||
"aten::convolution": ops.core.aten_convolution,
|
||||
"aten::cos": ops.core.aten_cos,
|
||||
"aten::cosh": ops.core.aten_cosh,
|
||||
"aten::detach": ops.core.aten_detach,
|
||||
@ -519,7 +518,13 @@ def _validate_op_between_ort_torch(
|
||||
|
||||
for ort_output, expected_output in zip(ort_outputs, expected_outputs):
|
||||
try:
|
||||
torch.testing.assert_close(expected_output.numpy(), ort_output)
|
||||
torch.testing.assert_close(
|
||||
expected_output.numpy(),
|
||||
ort_output,
|
||||
check_device=False,
|
||||
atol=10e-4,
|
||||
rtol=10e-3,
|
||||
)
|
||||
except AssertionError as e:
|
||||
warnings.warn(
|
||||
f"Suppressed AssertionError:\n{e}.\n"
|
||||
|
Reference in New Issue
Block a user