mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
* Rename arguments, code clean up. * Refactor functions to smaller reusable functions. Pull Request resolved: https://github.com/pytorch/pytorch/pull/77289 Approved by: https://github.com/justinchuby, https://github.com/garymm
97 lines
2.8 KiB
Python
97 lines
2.8 KiB
Python
# Owner(s): ["module: onnx"]
|
|
import unittest
|
|
|
|
import onnxruntime
|
|
|
|
import torch
|
|
from torch._C import parse_ir
|
|
from torch.onnx import verification
|
|
|
|
|
|
def _jit_graph_to_onnx_model(graph, operator_export_type, opset_version):
    """Export a ``torch::jit::Graph`` to a serialized ONNX ``ModelProto``.

    Test-only helper: it keeps just the steps needed for IR graph conversion
    and does not interact with real PyTorch modules or PyTorch tensor inputs.

    Args:
        graph: The JIT graph to convert.
        operator_export_type: Export mode forwarded to both ``_optimize_graph``
            and ``Graph._export_onnx``.
        opset_version: Target ONNX opset version.

    Returns:
        The first element returned by ``Graph._export_onnx`` (the serialized
        model).
    """
    from torch.onnx.symbolic_helper import _set_onnx_shape_inference, _set_opset_version
    from torch.onnx.utils import _optimize_graph

    # Some ops' symbolic functions generate sub-graphs based on their inputs'
    # types, so shape inference must be enabled before conversion.
    _set_onnx_shape_inference(True)
    _set_opset_version(opset_version)

    optimized_graph = _optimize_graph(graph, operator_export_type, params_dict={})

    # Positional arguments for the C++ binding, in the exact order it expects.
    # NOTE(review): the per-slot names below are inferred from
    # torch.onnx.utils._export of the same release, not from this file --
    # confirm against the binding before relying on them.
    export_args = (
        {},  # params_dict
        opset_version,
        {},  # dynamic_axes
        False,  # defer_weight_export
        operator_export_type,
        False,  # strip_doc_string
        False,  # keep_initializers_as_inputs
        {},  # custom_opsets
        True,  # add_node_names
        "",  # model file location
        {},  # node_attr_to_name
    )
    serialized_model, _, _, _ = optimized_graph._export_onnx(*export_args)
    return serialized_model
|
|
|
|
|
|
class _TestJITIRToONNX:
    """Abstract base class for JIT-IR-to-ONNX export test cases.

    Intentionally not a sub-class of unittest.TestCase so that unittest / pytest
    don't run it directly. unittest.TestCase is mixed in as another base class
    when creating concrete sub-types; see MakeTestCase().

    Each test parses a TorchScript IR string, runs it with the JIT interpreter,
    exports the same graph to ONNX, runs the exported model with ONNX Runtime,
    and compares the two sets of outputs.
    """

    # ONNX opset to export to. Sub-classes must override.
    opset_version = -1
    # Execution providers passed to onnxruntime.InferenceSession.
    ort_providers = ["CPUExecutionProvider"]
    # Relative / absolute tolerances for comparing ONNX Runtime outputs with
    # the JIT interpreter outputs. Sub-classes may override them, like the
    # attributes above.
    rtol = 1e-3
    atol = 1e-7

    def run_test(self, graph_ir, example_inputs):
        """Check that ``graph_ir`` yields matching outputs in JIT and ONNX Runtime.

        Args:
            graph_ir: TorchScript IR source text, parsed with ``parse_ir``.
            example_inputs: Inputs fed both to the JIT interpreter and to the
                exported ONNX model.
        """
        graph = parse_ir(graph_ir)
        # Reference outputs come from interpreting the parsed graph directly.
        jit_outs = torch._C._jit_interpret_graph(graph, example_inputs)

        onnx_proto = _jit_graph_to_onnx_model(
            graph, torch.onnx.OperatorExportTypes.ONNX, self.opset_version
        )
        ort_sess = onnxruntime.InferenceSession(
            onnx_proto, providers=self.ort_providers
        )
        ort_outs = verification._run_ort(ort_sess, example_inputs)

        verification._compare_ort_pytorch_outputs(
            ort_outs, jit_outs, rtol=self.rtol, atol=self.atol
        )

    def test_example_ir(self):
        """Export a single aten::add of two 2x3 float tensors."""
        graph_ir = """
        graph(%1 : Float(2, 3),
              %2 : Float(2, 3)):
          %3 : int = prim::Constant[value=1]()
          %4 : Float(2, 3) = aten::add(%1, %2, %3)
          return (%4)
        """
        a = torch.randn(2, 3)
        b = torch.randn(2, 3)
        self.run_test(graph_ir, (a, b))
|
|
|
|
|
|
def MakeTestCase(opset_version: int) -> type:
    """Create a concrete ``unittest.TestCase`` subclass for one ONNX opset.

    The attributes and test methods of ``_TestJITIRToONNX`` are copied into a
    new class named ``TestJITIRToONNX_opset{opset_version}`` whose only base is
    ``unittest.TestCase``, with ``opset_version`` overridden.

    Args:
        opset_version: ONNX opset version the generated test case exports to.

    Returns:
        The newly created ``unittest.TestCase`` subclass.
    """
    name = f"TestJITIRToONNX_opset{opset_version}"
    # Skip the mixin's own ``__dict__`` / ``__weakref__`` descriptors. They are
    # bound to ``_TestJITIRToONNX``. If copied, type() keeps them and they
    # shadow the ones inherited from unittest.TestCase, so
    # ``vars(instance)`` / ``instance.__dict__`` would raise TypeError.
    namespace = {
        key: value
        for key, value in _TestJITIRToONNX.__dict__.items()
        if key not in ("__dict__", "__weakref__")
    }
    namespace["opset_version"] = opset_version
    return type(name, (unittest.TestCase,), namespace)
|
|
|
|
|
|
# Concrete, runnable test case targeting ONNX opset 14; bound at module level
# so test runners can collect it.
TestJITIRToONNX_opset14 = MakeTestCase(opset_version=14)


if __name__ == "__main__":
    # Allow running this file directly as a script.
    unittest.main()
|