[ONNX] Use onnxruntime to run fx tests (#94638)

- Enable the MNIST test
- Remove `max_pool2d` from the test because we don't have the op yet
- Add `aten::convolution`
- Bump the onnxscript version
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94638
Approved by: https://github.com/BowenBao, https://github.com/wschin, https://github.com/titaiwangms
Justin Chu
2023-02-11 15:32:03 +00:00
committed by PyTorch MergeBot
parent 9dd7e83676
commit a27bd42bb9
3 changed files with 14 additions and 15 deletions


@@ -64,7 +64,7 @@ if [[ "$BUILD_ENVIRONMENT" == *onnx* ]]; then
   # TODO: change this when onnx reference patch is released.
   pip install --no-use-pep517 'onnx @ git+https://github.com/onnx/onnx@be441bf70f93369d30d1e12fd97e27d2beb75b12'
   # TODO: change this when onnx-script is on testPypi
-  pip install 'onnx-script @ git+https://github.com/microsoft/onnx-script@78ea55b888de88bfdadce7c3f6f3f83fa1404c7f'
+  pip install 'onnx-script @ git+https://github.com/microsoft/onnx-script@a71e35bcd72537bf7572536ee57250a0c0488bf6'
   # numba requires numpy <= 1.20, onnxruntime requires numpy >= 1.21.
   # We don't actually need it for our tests, but it's imported if it's present, so uninstall.
   pip uninstall -q --yes numba
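
Note: per the TODOs above, both onnx and onnx-script are installed from pinned git commits because the needed changes are not yet published on (test)PyPI; bumping the onnx-script pin presumably pulls in the aten_convolution implementation that gets registered in the dispatch table below.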


@@ -17,7 +17,6 @@ import torch
 import transformers  # type: ignore[import]
 from torch import nn
 from torch._subclasses.fake_tensor import FakeTensorMode
-from torch.nn import functional as F
 from torch.onnx._internal import diagnostics, fx as fx_onnx
 from torch.testing._internal import common_utils
 from torch.utils import _pytree as pytree
@@ -54,7 +53,7 @@ def _run_test_with_fx_to_onnx_exporter_reference_runtime(
     )
     ref_outputs, _ = pytree.tree_flatten(model(*input_args))
-    ort_outputs = _run_onnx_reference_runtime(onnx_model, input_args)
+    ort_outputs = _run_ort(onnx_model, input_args)
     for ref_output, ort_output in zip(ref_outputs, ort_outputs):
         torch.testing.assert_close(
             ref_output, torch.tensor(ort_output), rtol=rtol, atol=atol
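
For context on the swap above: `_run_ort` is the test helper that replaces the ONNX reference runtime with ONNX Runtime. Its body is not shown in this diff; what follows is a minimal sketch of such a runner, assuming `onnx_model` is a model path or serialized bytes and `pytorch_inputs` is a sequence of tensors:

import onnxruntime  # type: ignore[import]

def _run_ort(onnx_model, pytorch_inputs):
    # Build an inference session; ORT accepts a file path or serialized bytes.
    session = onnxruntime.InferenceSession(
        onnx_model, providers=["CPUExecutionProvider"]
    )
    # Bind positional tensors to the graph's input names.
    input_names = [graph_input.name for graph_input in session.get_inputs()]
    feeds = {
        name: tensor.cpu().numpy()
        for name, tensor in zip(input_names, pytorch_inputs)
    }
    # `None` requests every model output; results come back as numpy arrays.
    return session.run(None, feeds)
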
@@ -101,15 +100,12 @@ class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
         # Commenting this line and removing related files.
         # self.run_test_with_fx_to_onnx_exporter(func, (tensor_x,), {"b": 500.0})
-    @unittest.skip(
-        "Conv Op is not supported at the time. https://github.com/microsoft/onnx-script/issues/397"
-    )
     def test_mnist(self):
         class MNISTModel(nn.Module):
             def __init__(self):
                 super().__init__()
                 self.conv1 = nn.Conv2d(1, 32, 3, 1, bias=True)
-                self.conv2 = nn.Conv2d(32, 64, 3, 1, bias=True)
+                self.conv2 = nn.Conv2d(32, 64, 3, 2, bias=True)
                 self.fc1 = nn.Linear(9216, 128, bias=True)
                 self.fc2 = nn.Linear(128, 10, bias=True)
@@ -118,7 +114,6 @@ class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
                 tensor_x = torch.sigmoid(tensor_x)
                 tensor_x = self.conv2(tensor_x)
                 tensor_x = torch.sigmoid(tensor_x)
-                tensor_x = F.max_pool2d(tensor_x, 2)
                 tensor_x = torch.flatten(tensor_x, 1)
                 tensor_x = self.fc1(tensor_x)
                 tensor_x = torch.sigmoid(tensor_x)
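
Why the pair of changes above keeps the test consistent: dropping `F.max_pool2d(tensor_x, 2)` and changing `conv2` from stride 1 to stride 2 both take the feature map from 26x26 down to 12x12, so the flattened size stays at the 9216 features `fc1` expects. A quick check of the arithmetic, assuming 1x28x28 MNIST inputs and no padding:

def conv_out(size, kernel, stride):
    # Output spatial size of a valid (padding=0) convolution.
    return (size - kernel) // stride + 1

h = conv_out(28, 3, 1)     # conv1: 26
# Old path: conv2 stride 1 -> 24, then max_pool2d(2) -> 12.
# New path: conv2 stride 2 -> 12 directly.
h = conv_out(h, 3, 2)      # 12
assert h * h * 64 == 9216  # matches nn.Linear(9216, 128)
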
@@ -175,9 +170,7 @@ class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
         )
         ref_outputs, _ = pytree.tree_flatten(model(**inputs, return_dict=False))
-        ort_outputs = _run_onnx_reference_runtime(
-            onnx_model, (input_ids, attention_mask)
-        )
+        ort_outputs = _run_ort(onnx_model, (input_ids, attention_mask))
         assert len(ref_outputs) == len(ort_outputs)
         assert len(ref_outputs) == 5
         for ref_output, ort_output in zip(ref_outputs, ort_outputs):
@@ -244,6 +237,7 @@ class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
             fake_model,
             *fake_args,
             use_binary_format=False,
+            opset_version=self.opset_version,
         )
         # Tasks done by the following block.
@@ -271,7 +265,7 @@ class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
         # Original outputs.
         ref_outputs, _ = pytree.tree_flatten(model(*args, **kwargs))
         # ORT outputs.
-        ort_outputs = _run_onnx_reference_runtime(
+        ort_outputs = _run_ort(
             os.path.join(tmp_folder, onnx_model_location),
             (arg for arg in args if arg is not None),
         )


@@ -67,9 +67,7 @@ _ATENLIB_FUNCTIONS = {
     "aten::addmm": ops.core.aten_addmm,
     "aten::amax": ops.core.aten_amax,
     "aten::amin": ops.core.aten_amin,
-    # "aten::arange": ops.core.aten_arange_start_step,
     "aten::arange": ops.core.aten_arange_start,
-    # "aten::arange": ops.core.aten_arange,
     "aten::asin": ops.core.aten_asin,
     "aten::asinh": ops.core.aten_asinh,
     "aten::atan": ops.core.aten_atan,
@@ -80,6 +78,7 @@ _ATENLIB_FUNCTIONS = {
     "aten::clamp_min": ops.core.aten_clamp_min,
     "aten::clamp": ops.core.aten_clamp,
     "aten::clone": ops.core.aten_clone,
+    "aten::convolution": ops.core.aten_convolution,
     "aten::cos": ops.core.aten_cos,
     "aten::cosh": ops.core.aten_cosh,
     "aten::detach": ops.core.aten_detach,
@@ -519,7 +518,13 @@ def _validate_op_between_ort_torch(
     for ort_output, expected_output in zip(ort_outputs, expected_outputs):
         try:
-            torch.testing.assert_close(expected_output.numpy(), ort_output)
+            torch.testing.assert_close(
+                expected_output.numpy(),
+                ort_output,
+                check_device=False,
+                atol=10e-4,
+                rtol=10e-3,
+            )
         except AssertionError as e:
             warnings.warn(
                 f"Suppressed AssertionError:\n{e}.\n"