Files
pytorch/test/onnx/test_models_onnxruntime.py
BowenBao 6883b0ce9f [ONNX][WIP] Refactor verification.py
* Rename arguments, code clean up.
* Refactor functions to smaller reusable functions.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77289

Approved by: https://github.com/justinchuby, https://github.com/garymm
2022-05-31 18:49:39 +00:00

47 lines
1.2 KiB
Python

# Owner(s): ["module: onnx"]
import unittest
import onnxruntime # noqa: F401
from test_models import TestModels
from test_pytorch_onnx_onnxruntime import run_model_test
import torch
def exportTest(self, model, inputs, rtol=1e-2, atol=1e-7, opset_versions=None):
    """Call ``run_model_test`` on ``model`` once per requested opset version.

    Before each run, ``self.opset_version`` is set to the current version and
    ``self.onnx_shape_inference`` is set to ``True``. When
    ``self.is_script_test_enabled`` is truthy and the version is greater than
    11, the model is additionally passed through ``torch.jit.script`` and the
    scripted result is checked with a second ``run_model_test`` call.

    Args:
        model: Model handed to ``run_model_test`` (and to ``torch.jit.script``).
        inputs: Forwarded to ``run_model_test`` as ``input_args``.
        rtol: Forwarded to ``run_model_test`` as ``rtol``.
        atol: Forwarded to ``run_model_test`` as ``atol``.
        opset_versions: Iterable of opset versions to test. ``None`` or any
            other falsy value selects versions 7 through 14.
    """
    versions = opset_versions or [7, 8, 9, 10, 11, 12, 13, 14]
    tolerances = {"rtol": rtol, "atol": atol}

    for version in versions:
        self.opset_version = version
        self.onnx_shape_inference = True
        run_model_test(self, model, input_args=inputs, **tolerances)

        # The scripted variant only runs when enabled and for opsets above 11.
        if not (self.is_script_test_enabled and version > 11):
            continue

        scripted = torch.jit.script(model)
        run_model_test(self, scripted, input_args=inputs, **tolerances)
# Rebind TestModels to a new unittest.TestCase subclass whose namespace is a
# copy of the imported class's attributes, with script tests switched off and
# this module's exportTest attached.
TestModels = type(
    "TestModels",
    (unittest.TestCase,),
    {
        **TestModels.__dict__,
        "is_script_test_enabled": False,
        "exportTest": exportTest,
    },
)
# Model tests for scripting with the new JIT APIs and shape inference.
# The namespace is copied from whatever TestModels is bound to at this point
# in the module, then overridden so that script tests are enabled.
TestModels_new_jit_API = type(
    "TestModels_new_jit_API",
    (unittest.TestCase,),
    {
        **TestModels.__dict__,
        "exportTest": exportTest,
        "is_script_test_enabled": True,
        "onnx_shape_inference": True,
    },
)
if __name__ == "__main__":
    # Executed as a script: hand control to unittest's command-line runner.
    unittest.main()