mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Fixes #135419. PS: There are 104 empty-output nodes; I suggest we add them one by one as we run into them. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135591 (Approved by: https://github.com/justinchuby)
211 lines
6.6 KiB
Python
211 lines
6.6 KiB
Python
# Owner(s): ["module: onnx"]
|
|
"""Simple API tests for the ONNX exporter."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
|
|
import torch
|
|
from torch.onnx._internal.exporter import testing as onnx_testing
|
|
from torch.testing._internal import common_utils
|
|
|
|
|
|
class SampleModel(torch.nn.Module):
    """Single-input module returning ``x + 1`` and its ReLU."""

    def forward(self, x):
        """Return ``(x + 1, relu(x + 1))`` as a tuple."""
        shifted = x + 1
        activated = shifted.relu()
        return (shifted, activated)
|
|
|
|
|
|
class SampleModelTwoInputs(torch.nn.Module):
    """Two-input module returning ``x + b`` and its ReLU."""

    def forward(self, x, b):
        """Return ``(x + b, relu(x + b))`` as a tuple."""
        total = x + b
        activated = total.relu()
        return (total, activated)
|
|
|
|
|
|
class SampleModelForDynamicShapes(torch.nn.Module):
    """Two-input module whose outputs each depend on exactly one input."""

    def forward(self, x, b):
        """Return ``(relu(x), sigmoid(b))`` as a tuple."""
        x_out = x.relu()
        b_out = b.sigmoid()
        return x_out, b_out
|
|
|
|
|
|
class TestExportAPIDynamo(common_utils.TestCase):
    """Tests for the ONNX exporter API when dynamo=True."""

    def assert_export(self, *args, **kwargs):
        """Export with ``dynamo=True`` and validate the resulting ONNX program.

        All positional and keyword arguments are forwarded unchanged to
        ``torch.onnx.export``. Callers must not pass ``dynamo`` themselves,
        since it is always supplied here.
        """
        onnx_program = torch.onnx.export(*args, **kwargs, dynamo=True)
        # unittest assertion instead of a bare ``assert``, which is stripped
        # when Python runs with ``-O``.
        self.assertIsNotNone(onnx_program)
        onnx_testing.assert_onnx_program(onnx_program)

    def test_args_normalization_with_no_kwargs(self):
        """Export a two-input model given only positional example inputs."""
        self.assert_export(
            SampleModelTwoInputs(),
            (torch.randn(1, 1, 2), torch.randn(1, 1, 2)),
        )

    def test_dynamic_axes_enable_dynamic_shapes_with_fully_specified_axes(self):
        """Export with ``dynamic_axes`` naming every axis of both inputs."""
        self.assert_export(
            SampleModelForDynamicShapes(),
            (torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}),
            dynamic_axes={
                "x": {0: "customx_dim_0", 1: "customx_dim_1", 2: "customx_dim_2"},
                "b": {0: "customb_dim_0", 1: "customb_dim_1", 2: "customb_dim_2"},
            },
        )

    def test_dynamic_axes_enable_dynamic_shapes_with_default_axe_names(self):
        """Export with ``dynamic_axes`` given as index lists (no custom names)."""
        self.assert_export(
            SampleModelForDynamicShapes(),
            (torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}),
            dynamic_axes={
                "x": [0, 1, 2],
                "b": [0, 1, 2],
            },
        )

    def test_dynamic_axes_supports_partial_dynamic_shapes(self):
        """Export with ``dynamic_axes`` specified for input ``b`` only."""
        self.assert_export(
            SampleModelForDynamicShapes(),
            (torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}),
            dynamic_axes={
                "b": [0, 1, 2],
            },
        )

    def test_dynamic_axes_supports_output_names(self):
        """Export with ``dynamic_axes`` that also references an output name."""
        self.assert_export(
            SampleModelForDynamicShapes(),
            (torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}),
            dynamic_axes={
                "b": [0, 1, 2],
            },
        )
        # Second export: explicit input/output names, with the output
        # ``b_out`` listed in ``dynamic_axes`` alongside the input ``b``.
        self.assert_export(
            SampleModelForDynamicShapes(),
            (
                torch.randn(2, 2, 3),
                torch.randn(2, 2, 3),
            ),
            input_names=["x", "b"],
            output_names=["x_out", "b_out"],
            dynamic_axes={"b": [0, 1, 2], "b_out": [0, 1, 2]},
        )

    def test_saved_f_exists_after_export(self):
        """Exporting with a file path must leave a file at that path."""
        with common_utils.TemporaryFileName(suffix=".onnx") as path:
            torch.onnx.export(
                SampleModel(), (torch.randn(1, 1, 2),), path, dynamo=True
            )
            self.assertTrue(os.path.exists(path))

    def test_export_supports_script_module(self):
        """Export a module that was first compiled with ``torch.jit.script``."""

        class ScriptModule(torch.nn.Module):
            def forward(self, x):
                return x

        self.assert_export(torch.jit.script(ScriptModule()), (torch.randn(1, 1, 2),))

    def test_dynamic_shapes_with_fully_specified_axes(self):
        """Export an ``ExportedProgram`` built with every axis marked dynamic."""
        exported_program = torch.export.export(
            SampleModelForDynamicShapes(),
            (
                torch.randn(2, 2, 3),
                torch.randn(2, 2, 3),
            ),
            dynamic_shapes={
                "x": {
                    0: torch.export.Dim("customx_dim_0"),
                    1: torch.export.Dim("customx_dim_1"),
                    2: torch.export.Dim("customx_dim_2"),
                },
                "b": {
                    0: torch.export.Dim("customb_dim_0"),
                    1: torch.export.Dim("customb_dim_1"),
                    2: torch.export.Dim("customb_dim_2"),
                },
            },
        )

        self.assert_export(exported_program)

    def test_partial_dynamic_shapes(self):
        """Export with ``dynamic_shapes`` where ``x`` is ``None`` and ``b`` is dynamic."""
        self.assert_export(
            SampleModelForDynamicShapes(),
            (
                torch.randn(2, 2, 3),
                torch.randn(2, 2, 3),
            ),
            dynamic_shapes={
                "x": None,
                "b": {
                    0: torch.export.Dim("customb_dim_0"),
                    1: torch.export.Dim("customb_dim_1"),
                    2: torch.export.Dim("customb_dim_2"),
                },
            },
        )

    def test_auto_convert_all_axes_to_dynamic_shapes_with_dynamo_export(self):
        """Run ``torch.onnx.dynamo_export`` on nested inputs with dynamic shapes on."""
        flag = "TORCH_ONNX_USE_EXPERIMENTAL_LOGIC"
        previous = os.environ.get(flag)

        def restore_flag():
            # Put the environment back exactly as it was so the setting does
            # not leak into later tests or child processes.
            if previous is None:
                os.environ.pop(flag, None)
            else:
                os.environ[flag] = previous

        # Register the cleanup before mutating so it runs even if the test fails.
        self.addCleanup(restore_flag)
        os.environ[flag] = "1"
        self.assertEqual(os.environ.get(flag), "1")

        class Nested(torch.nn.Module):
            def forward(self, x):
                (a0, a1), (b0, b1), (c0, c1, c2) = x
                return a0 + a1 + b0 + b1 + c0 + c1 + c2

        inputs = (
            (1, 2),
            (
                torch.randn(4, 4),
                torch.randn(4, 4),
            ),
            (
                torch.randn(4, 4),
                torch.randn(4, 4),
                torch.randn(4, 4),
            ),
        )

        onnx_program = torch.onnx.dynamo_export(
            Nested(),
            inputs,
            export_options=torch.onnx.ExportOptions(dynamic_shapes=True),
        )
        self.assertIsNotNone(onnx_program)
        onnx_testing.assert_onnx_program(onnx_program)

    def test_refine_dynamic_shapes_with_onnx_export(self):
        """Export a model whose forward branches on the sizes of its inputs."""
        # NOTE: From test/export/test_export.py

        # refine lower, upper bound
        class TestRefineDynamicShapeModel(torch.nn.Module):
            def forward(self, x, y):
                # Only this branch returns a value; the example inputs below
                # (sizes 16 and 12) satisfy the condition.
                if x.shape[0] >= 6 and y.shape[0] <= 16:
                    return x * 2.0, y + 1

        inps = (torch.randn(16), torch.randn(12))
        dynamic_shapes = {
            "x": (torch.export.Dim("dx"),),
            "y": (torch.export.Dim("dy"),),
        }
        self.assert_export(
            TestRefineDynamicShapeModel(), inps, dynamic_shapes=dynamic_shapes
        )

    def test_zero_output_aten_node(self):
        """Export a model containing ``aten._assert_async.msg``, which has no outputs."""

        class Model(torch.nn.Module):
            def forward(self, x):
                torch.ops.aten._assert_async.msg(torch.tensor(True), "assertion failed")
                return x + x

        example = torch.randn(2)
        # Pass a real 1-tuple: ``(example)`` without the trailing comma is
        # just the tensor itself, not a tuple of positional args.
        self.assert_export(Model(), (example,))
|
|
|
|
|
|
# Run this module's tests via common_utils.run_tests() when executed as a script.
if __name__ == "__main__":
    common_utils.run_tests()
|