Files
pytorch/test/onnx/test_pytorch_jit_onnx.py
Justin Chu 524b78d4f6 [ONNX] Refactor torchscript based exporter (#161323)
Refactor the TorchScript-based exporter logic, moving it to a single (private) location for better code management. The original public module and method APIs are preserved.

- Updated module paths in `torch/csrc/autograd/python_function.cpp` accordingly
- Removed `check_onnx_broadcast` from `torch/autograd/_functions/utils.py` because it is private and unused

@albanD / @soulitzer could you review changes in `torch/csrc/autograd/python_function.cpp` and
`torch/autograd/_functions/utils.py`? Thanks!

## BC Breaking
- **Deprecated members in `torch.onnx.verification` are removed**

Differential Revision: [D81236421](https://our.internmc.facebook.com/intern/diff/D81236421)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161323
Approved by: https://github.com/titaiwangms, https://github.com/angelayi
2025-09-02 16:10:30 +00:00

204 lines
6.5 KiB
Python

# Owner(s): ["module: onnx"]
import onnxruntime
import pytorch_test_common
from pytorch_test_common import skipIfNoCuda
import torch
from torch.onnx._internal.torchscript_exporter import verification
from torch.onnx._internal.torchscript_exporter._globals import GLOBALS
from torch.onnx._internal.torchscript_exporter.utils import (
_trigger_symbolic_function_registration,
)
from torch.testing._internal import common_utils
def _jit_graph_to_onnx_model(graph, operator_export_type, opset_version):
    r"""Serialize a torch::jit::Graph into an ONNX ModelProto (test helper).

    Only the steps needed to convert the IR graph are performed: the target
    opset is recorded on ``GLOBALS``, symbolic function registration is
    triggered, the graph goes through ``torch.onnx.utils._optimize_graph``
    and the optimized graph is exported. No PyTorch module and no tensor
    inputs are involved.

    Args:
        graph: The torch::jit::Graph to convert.
        operator_export_type: Export mode, forwarded to both the optimization
            pass and the export call.
        opset_version: ONNX opset version to export with.

    Returns:
        The first of the four values returned by ``_export_onnx`` (the
        serialized model).
    """
    # The opset must be set before the symbolic functions are registered and
    # the graph is optimized, exactly in this order.
    GLOBALS.export_onnx_opset_version = opset_version
    _trigger_symbolic_function_registration()

    onnx_graph = torch.onnx.utils._optimize_graph(
        graph, operator_export_type, params_dict={}
    )

    # NOTE(review): the arguments are purely positional; their meaning is
    # defined by the ``_export_onnx`` binding, which is not visible in this
    # file. Keep the order untouched.
    serialized_model, _, _, _ = onnx_graph._export_onnx(
        {},
        opset_version,
        {},
        False,
        operator_export_type,
        False,
        False,
        {},
        True,
        "",
        {},
    )
    return serialized_model
class _TestJITIRToONNX:
    """Abstract base class holding the JIT-IR-to-ONNX test cases.

    Intentionally not a sub-class of unittest.TestCase so that unittest / pytest
    don't run it directly. unittest.TestCase is mixed in as another base class
    when creating concrete sub-types. See MakeTestCase().
    """

    opset_version = -1  # Sub-classes must override
    ort_providers = ["CPUExecutionProvider"]
    check_shape = True
    check_dtype = True
    ignore_none = True  # True for tracing, and False for scripting

    def run_test(self, graph_ir, example_inputs, parse_tensor_constants=False):
        """Check that a textual IR graph gives the same results in JIT and ONNX.

        The IR text is parsed into a graph, the graph is interpreted with
        ``example_inputs`` to obtain the reference outputs, then exported to
        ONNX with ``self.opset_version`` and executed through an onnxruntime
        session using ``self.ort_providers``. Both sets of outputs are handed
        to the verification helper for comparison.

        Args:
            graph_ir: Textual IR of the graph under test.
            example_inputs: Tuple of inputs fed to both the JIT interpreter
                and the onnxruntime session.
            parse_tensor_constants: Forwarded to ``torch._C.parse_ir``.
        """
        graph = torch._C.parse_ir(graph_ir, parse_tensor_constants)
        jit_outs = torch._C._jit_interpret_graph(graph, example_inputs)

        onnx_proto = _jit_graph_to_onnx_model(
            graph, torch.onnx.OperatorExportTypes.ONNX, self.opset_version
        )
        ort_sess = onnxruntime.InferenceSession(
            onnx_proto, providers=self.ort_providers
        )
        ort_outs = verification._run_onnx(ort_sess, example_inputs)

        options = verification.VerificationOptions(
            rtol=1e-3,
            atol=1e-7,
            check_shape=self.check_shape,
            check_dtype=self.check_dtype,
            ignore_none=self.ignore_none,
            acceptable_error_percentage=None,
        )
        verification._compare_onnx_pytorch_outputs(
            ort_outs,
            jit_outs,
            options,
        )

    def test_example_ir(self):
        """Element-wise ``aten::add`` with a constant third argument."""
        graph_ir = """
        graph(%1 : Float(2, 3),
              %2 : Float(2, 3)):
          %3 : int = prim::Constant[value=1]()
          %4 : Float(2, 3) = aten::add(%1, %2, %3)
          return (%4)
        """
        a = torch.randn(2, 3)
        b = torch.randn(2, 3)
        self.run_test(graph_ir, (a, b))

    def test_where_constants(self):
        """``aten::where`` whose third operand is a Double tensor constant."""
        graph_ir = """
        graph(%0 : Bool(8, device=cpu),
              %1 : Float(8, device=cpu)):
          %3 : Double(device=cpu) = prim::Constant[value={0.}]()
          %4 : Float(8) = aten::where(%0, %1, %3)
          return (%4)
        """
        a = torch.zeros(8, dtype=bool)
        b = torch.zeros(8)
        # The constant in the IR is a tensor literal, so tensor-constant
        # parsing has to be switched on.
        self.run_test(graph_ir, (a, b), parse_tensor_constants=True)

    def test_add_sub_with_graph_inputs(self):
        """``add`` / ``sub`` / ``rsub`` with the third argument as a graph input."""
        for op in ["add", "sub", "rsub"]:
            # Each op runs as its own sub-test so that one failing op is
            # reported by name and does not prevent the others from running.
            # subTest comes from the unittest.TestCase base that is mixed in
            # when the concrete class is created.
            with self.subTest(op=op):
                graph_ir = f"""
                graph(%1 : Float(2, 3),
                      %2 : Float(2, 3),
                      %3 : int):
                  %4 : Float(2, 3) = aten::{op}(%1, %2, %3)
                  return (%4)
                """
                a = torch.randn(2, 3)
                b = torch.randn(2, 3)
                self.run_test(graph_ir, (a, b, 2))

    def test_native_layer_norm(self):
        """``aten::native_layer_norm`` returning all three outputs."""
        graph_ir = """
        graph(%x : Float(2, 3, 2),
              %w : Float(3, 2),
              %b : Float(3, 2)):
          %5 : int = prim::Constant[value=3]()
          %6 : int = prim::Constant[value=2]()
          %7 : int[] = prim::ListConstruct(%5, %6)
          %10 : float = prim::Constant[value=1.0000000000000001e-05]()
          %11 : Float(2, 3, 2), %12 : Float(2, 1, 1), %13 : Float(2, 1, 1) = aten::native_layer_norm(%x, %7, %w, %b, %10)
          return (%11, %12, %13)
        """
        x = torch.randn(2, 3, 2)
        w = torch.randn(3, 2)
        b = torch.randn(3, 2)
        self.run_test(graph_ir, (x, w, b))

    def test_convolution(self):
        """``aten::convolution`` on untyped ``Tensor`` inputs without a bias."""
        graph_ir = """
        graph(%1 : Tensor,
              %2 : Tensor):
          %3 : NoneType = prim::Constant()
          %4 : int[] = prim::Constant[value=[1, 1]]()
          %5 : int[] = prim::Constant[value=[0, 0]]()
          %6 : bool = prim::Constant[value=0]()
          %7 : int = prim::Constant[value=1]()
          %8 : Tensor = aten::convolution(%1, %2, %3, %4, %5, %4, %6, %5, %7)
          return (%8)
        """
        x = torch.randn(8, 1, 5, 5)
        w = torch.randn(4, 1, 3, 3)
        self.run_test(graph_ir, (x, w))

    def test_log_softmax(self):
        """``aten::_log_softmax`` with ``half_to_float`` disabled."""
        graph_ir = """
        graph(%x: Tensor):
          %half_to_float: bool = prim::Constant[value=0]()
          %dim: int = prim::Constant[value=1]()
          %y = aten::_log_softmax(%x, %dim, %half_to_float)
          return (%y)
        """
        x = torch.randn(5, 2)
        self.run_test(graph_ir, (x,))

    @skipIfNoCuda
    def test_log_softmax_half_to_float(self):
        """``aten::_log_softmax`` with ``half_to_float`` enabled on a CUDA half tensor."""
        graph_ir = """
        graph(%x: Tensor):
          %half_to_float: bool = prim::Constant[value=1]()
          %dim: int = prim::Constant[value=1]()
          %y = aten::_log_softmax(%x, %dim, %half_to_float)
          return (%y)
        """
        x = torch.randn(5, 2).half().to("cuda")
        self.run_test(graph_ir, (x,))

    def test_native_dropout(self):
        """``aten::native_dropout`` with probability 0.0 and the training flag set."""
        graph_ir = """
        graph(%1 : Float(2, 3)):
          %2 : float = prim::Constant[value=0.0]()
          %training : bool = prim::Constant[value=1]()
          %3 : Tensor, %4 : Tensor = aten::native_dropout(%1, %2, %training)
          return (%3, %4)
        """
        a = torch.randn(2, 3)
        self.run_test(graph_ir, (a,))
def MakeTestCase(opset_version: int) -> type:
    """Create a concrete, runnable test class for one ONNX opset.

    The attributes of ``_TestJITIRToONNX`` are copied into a new class derived
    from ``pytorch_test_common.ExportTestCase`` and ``opset_version`` is
    overridden with the requested value.

    Args:
        opset_version: ONNX opset version the generated tests export with.

    Returns:
        A new class named ``TestJITIRToONNX_opset<opset_version>``.
    """
    name = f"TestJITIRToONNX_opset{opset_version}"
    # The ``__dict__`` / ``__weakref__`` entries of a class namespace are
    # descriptors bound to that specific class. Copying them into the new
    # class would shadow the base class's own descriptors and make
    # ``vars(instance)`` / ``instance.__dict__`` raise TypeError, so they are
    # left out of the copied namespace.
    namespace = {
        key: value
        for key, value in _TestJITIRToONNX.__dict__.items()
        if key not in ("__dict__", "__weakref__")
    }
    namespace["opset_version"] = opset_version
    return type(name, (pytorch_test_common.ExportTestCase,), namespace)
# Concrete test class for opset 14; bound to a module-level name so that the
# test runner can discover it.
TestJITIRToONNX_opset14 = MakeTestCase(14)

# Run the tests when this file is executed as a script.
if __name__ == "__main__":
    common_utils.run_tests()