mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
* Rename arguments, code clean up. * Refactor functions to smaller reusable functions. Pull Request resolved: https://github.com/pytorch/pytorch/pull/77289 Approved by: https://github.com/justinchuby, https://github.com/garymm
47 lines
1.2 KiB
Python
47 lines
1.2 KiB
Python
# Owner(s): ["module: onnx"]
|
|
|
|
import unittest
|
|
|
|
import onnxruntime # noqa: F401
|
|
from test_models import TestModels
|
|
from test_pytorch_onnx_onnxruntime import run_model_test
|
|
|
|
import torch
|
|
|
|
|
|
def exportTest(self, model, inputs, rtol=1e-2, atol=1e-7, opset_versions=None):
    """Run ``run_model_test`` on ``model`` once per requested ONNX opset.

    Args:
        model: The model under test; passed to ``run_model_test`` as-is and,
            when scripting is enabled, also through ``torch.jit.script``.
        inputs: Forwarded to ``run_model_test`` as ``input_args``.
        rtol: Relative tolerance forwarded to ``run_model_test``.
        atol: Absolute tolerance forwarded to ``run_model_test``.
        opset_versions: Opset versions to iterate over. ``None`` (or any
            other falsy value, e.g. an empty list) selects opsets 7 to 14.
    """
    versions_to_check = opset_versions or [7, 8, 9, 10, 11, 12, 13, 14]

    for version in versions_to_check:
        # The settings are stored on the test instance, which is then handed
        # to run_model_test (presumably where they are read -- confirm there).
        self.opset_version = version
        self.onnx_shape_inference = True
        run_model_test(self, model, input_args=inputs, rtol=rtol, atol=atol)

        # The scripted variant only runs when the class opts in and the
        # opset is newer than 11.
        if not (self.is_script_test_enabled and version > 11):
            continue
        scripted_model = torch.jit.script(model)
        run_model_test(self, scripted_model, input_args=inputs, rtol=rtol, atol=atol)
|
|
|
|
|
|
# Rebuild the imported TestModels directly on top of unittest.TestCase: its
# class namespace is copied, scripting checks are switched off, and exportTest
# is replaced by the function defined in this file. The name TestModels is
# rebound to the new class.
TestModels = type(
    "TestModels",
    (unittest.TestCase,),
    {
        **TestModels.__dict__,
        "is_script_test_enabled": False,
        "exportTest": exportTest,
    },
)
|
|
|
|
|
|
# Model tests for scripting with the new JIT APIs and shape inference.
# NOTE: TestModels here is the class rebound earlier in this file, so the
# copied namespace already holds is_script_test_enabled=False; the entry
# below overrides it with True.
TestModels_new_jit_API = type(
    "TestModels_new_jit_API",
    (unittest.TestCase,),
    {
        **TestModels.__dict__,
        "exportTest": exportTest,
        "is_script_test_enabled": True,
        "onnx_shape_inference": True,
    },
)
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed as a script: hand control to the unittest command-line runner.
    unittest.main()
|