# Owner(s): ["module: onnx"]
"""Simple API tests for the ONNX exporter."""

from __future__ import annotations

import io
import logging
import os

import numpy as np
from onnxscript import BOOL, FLOAT, ir, opset18 as op

import torch
import torch.onnx._flags
from torch.onnx._internal.exporter import _testing as onnx_testing
from torch.testing._internal import common_utils


class SampleModel(torch.nn.Module):
    def forward(self, x):
        y = x + 1
        z = y.relu()
        return (y, z)


class SampleModelTwoInputs(torch.nn.Module):
    def forward(self, x, b):
        y = x + b
        z = y.relu()
        return (y, z)


class SampleModelForDynamicShapes(torch.nn.Module):
    def forward(self, x, b):
        return x.relu(), b.sigmoid()


class NestedModelForDynamicShapes(torch.nn.Module):
    def forward(
        self,
        x: torch.Tensor,
        ys: list[torch.Tensor],
        zs: dict[str, torch.Tensor],
        c: torch.Tensor,
    ):
        y = ys[0] + ys[1] + zs["a"] + zs["b"]
        w = 5
        if x.shape[0] < 3 and c.shape[0] != 4:
            return x + w, x + y, c
        else:
            return x - w, x - y, c
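# NestedModelForDynamicShapes exercises nested inputs (a list and a dict of
# tensors) together with a branch on input shapes; torch.export resolves the
# shape condition while tracing, so a single branch is captured in the graph.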


class SampleModelForDimOne(torch.nn.Module):
    def forward(self, x, y, z):
        return torch.cat((x, y), axis=1) + z


class TestExportAPIDynamo(common_utils.TestCase):
    """Tests for the ONNX exporter API when dynamo=True."""

    def assert_export(
        self, *args, strategy: str | None = "TorchExportNonStrictStrategy", **kwargs
    ):
        onnx_program = torch.onnx.export(
            *args, **kwargs, dynamo=True, fallback=False, verbose=False
        )
        assert onnx_program is not None
        onnx_testing.assert_onnx_program(onnx_program, strategy=strategy)
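    # assert_export is the pattern used throughout this class: export with
    # dynamo=True and fallback=False, then let onnx_testing.assert_onnx_program
    # run the exported ONNX model and compare its outputs against eager PyTorch.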

    def test_args_normalization_with_no_kwargs(self):
        self.assert_export(
            SampleModelTwoInputs(),
            (torch.randn(1, 1, 2), torch.randn(1, 1, 2)),
        )

    def test_symbolic_argument_user_input_is_supported_by_report_and_call(self):
        class constant_plus_tensor_inputs(torch.nn.Module):
            def forward(self, a, x):
                return a + torch.tensor(1) + x

        # Capture log output
        log_capture = io.StringIO()
        log_handler = logging.StreamHandler(log_capture)
        log_handler.setLevel(logging.ERROR)
        # Get the logger used in _core.py
        logger = logging.getLogger("torch.onnx._internal.exporter._core")
        original_level = logger.level
        logger.addHandler(log_handler)
        logger.setLevel(logging.ERROR)

        try:
            with common_utils.TemporaryDirectoryName() as temp_dir:
                self.assert_export(
                    constant_plus_tensor_inputs(),
                    (
                        1,
                        torch.ones(2),
                    ),
                    dynamic_shapes=(
                        torch.export.Dim.DYNAMIC,
                        {0: torch.export.Dim.DYNAMIC},
                    ),
                    report=True,
                    artifacts_dir=temp_dir,
                )
                # Verify that report generation did not log an error
                log_output = log_capture.getvalue()
                self.assertNotIn("Failed to save report due to an error", log_output)
                self.assertNotIn("KeyError: 'tensor_meta'", log_output)
                # Note: assert_onnx_program is not called here because it would
                # fail due to a known input-name mismatch for symbolic
                # (non-tensor) user inputs.

        finally:
            # Clean up logging
            logger.removeHandler(log_handler)
            logger.setLevel(original_level)

    def test_constant_argument_user_input_is_omitted_in_onnx_graph(self):
        class constant_plus_tensor_inputs(torch.nn.Module):
            def forward(self, a, x):
                return a + torch.tensor(1) + x

        onnx_program = torch.onnx.export(
            constant_plus_tensor_inputs(),
            (
                1,
                torch.ones(2),
            ),
            dynamic_shapes=(
                None,
                {0: torch.export.Dim.DYNAMIC},
            ),
            dynamo=True,
        )

        self.assertEqual(len(onnx_program.model.graph.inputs), 1)
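    # The constant (non-tensor) argument `a` is baked into the graph as a
    # literal during export, leaving a single tensor input in the ONNX model.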

    def test_dynamic_axes_enable_dynamic_shapes_with_fully_specified_axes(self):
        self.assert_export(
            SampleModelForDynamicShapes(),
            (torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}),
            dynamic_axes={
                "x": {0: "customx_dim_0", 1: "customx_dim_1", 2: "customx_dim_2"},
                "b": {0: "customb_dim_0", 1: "customb_dim_1", 2: "customb_dim_2"},
            },
        )
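    # dynamic_axes is the legacy TorchScript-exporter interface; with
    # dynamo=True the exporter translates these axis specs into torch.export
    # dynamic_shapes internally, which is what the tests below exercise.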

    def test_dynamic_axes_enable_dynamic_shapes_with_default_axe_names(self):
        self.assert_export(
            SampleModelForDynamicShapes(),
            (torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}),
            dynamic_axes={
                "x": [0, 1, 2],
                "b": [0, 1, 2],
            },
        )

    def test_dynamic_axes_supports_partial_dynamic_shapes(self):
        self.assert_export(
            SampleModelForDynamicShapes(),
            (torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}),
            input_names=["x", "b"],
            dynamic_axes={
                "b": [0, 1, 2],
            },
        )

    def test_dynamic_axes_supports_output_names(self):
        self.assert_export(
            SampleModelForDynamicShapes(),
            (torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}),
            input_names=["x", "b"],
            dynamic_axes={
                "b": [0, 1, 2],
            },
        )
        self.assert_export(
            SampleModelForDynamicShapes(),
            (
                torch.randn(2, 2, 3),
                torch.randn(2, 2, 3),
            ),
            input_names=["x", "b"],
            output_names=["x_out", "b_out"],
            dynamic_axes={"b": [0, 1, 2], "b_out": [0, 1, 2]},
        )

    def test_saved_f_exists_after_export(self):
        with common_utils.TemporaryFileName(suffix=".onnx") as path:
            _ = torch.onnx.export(
                SampleModel(), (torch.randn(1, 1, 2),), path, dynamo=True
            )
            self.assertTrue(os.path.exists(path))

    def test_dynamic_shapes_with_fully_specified_axes(self):
        ep = torch.export.export(
            SampleModelForDynamicShapes(),
            (
                torch.randn(2, 2, 3),
                torch.randn(2, 2, 3),
            ),
            dynamic_shapes={
                "x": {
                    0: torch.export.Dim("customx_dim_0"),
                    1: torch.export.Dim("customx_dim_1"),
                    2: torch.export.Dim("customx_dim_2"),
                },
                "b": {
                    0: torch.export.Dim("customb_dim_0"),
                    1: torch.export.Dim("customb_dim_1"),
                    2: torch.export.Dim("customb_dim_2"),
                },
            },
            strict=True,
        )

        self.assert_export(ep, strategy=None)

    def test_partial_dynamic_shapes(self):
        self.assert_export(
            SampleModelForDynamicShapes(),
            (
                torch.randn(2, 2, 3),
                torch.randn(2, 2, 3),
            ),
            dynamic_shapes={
                "x": None,
                "b": {
                    0: torch.export.Dim("customb_dim_0"),
                    1: torch.export.Dim("customb_dim_1"),
                    2: torch.export.Dim("customb_dim_2"),
                },
            },
        )

    def test_dynamic_shapes_supports_nested_input_model_with_input_names_assigned(self):
        # kwargs can still be renamed as long as the order is preserved
        input_names = ["input_x", "input_y", "input_z", "d", "e", "f"]

        dynamic_axes = {
            "input_x": {0: "dim"},
            "input_y": {0: "dim"},
            "input_z": {0: "dim"},
            "d": {0: "dim"},
            "e": {0: "dim"},
        }

        model = NestedModelForDynamicShapes()
        input = (
            torch.ones(5),
            [torch.zeros(5), torch.ones(5)],
            {"a": torch.zeros(5), "b": torch.ones(5)},
            torch.ones(4),
        )

        self.assert_export(
            model, input, dynamic_axes=dynamic_axes, input_names=input_names
        )

        # Check whether inputs are dynamically shaped: every input except the
        # last one (c) should have a symbolic (named) first dimension.
        onnx_program = torch.onnx.export(
            model,
            input,
            dynamic_axes=dynamic_axes,
            input_names=input_names,
            dynamo=True,
        )
        self.assertTrue(
            all(
                [
                    input.type.tensor_type.shape.dim[0].dim_param
                    for input in onnx_program.model_proto.graph.input
                ][:-1]
            )
        )

    def test_upgraded_torchlib_impl(self):
        class GeluModel(torch.nn.Module):
            def forward(self, input):
                # Use GELU activation function
                return torch.nn.functional.gelu(input, approximate="tanh")

        input = (torch.randn(1, 3, 4, 4),)
        onnx_program_op18 = torch.onnx.export(
            GeluModel(),
            input,
            opset_version=18,
            dynamo=True,
        )
        all_nodes_op18 = [n.op_type for n in onnx_program_op18.model.graph]
        self.assertIn("Tanh", all_nodes_op18)
        self.assertNotIn("Gelu", all_nodes_op18)

        onnx_program_op20 = torch.onnx.export(
            GeluModel(),
            input,
            opset_version=20,
            dynamo=True,
        )
        all_nodes_op20 = [n.op_type for n in onnx_program_op20.model.graph]
        self.assertIn("Gelu", all_nodes_op20)

    def test_refine_dynamic_shapes_with_onnx_export(self):
        # NOTE: From test/export/test_export.py

        # refine lower, upper bound
        class TestRefineDynamicShapeModel(torch.nn.Module):
            def forward(self, x, y):
                if x.shape[0] >= 6 and y.shape[0] <= 16:
                    return x * 2.0, y + 1

        inps = (torch.randn(16), torch.randn(12))
        dynamic_shapes = {
            "x": (torch.export.Dim("dx"),),
            "y": (torch.export.Dim("dy"),),
        }
        self.assert_export(
            TestRefineDynamicShapeModel(), inps, dynamic_shapes=dynamic_shapes
        )

    def test_zero_output_aten_node(self):
        class Model(torch.nn.Module):
            def forward(self, x):
                torch.ops.aten._assert_async.msg(torch.tensor(True), "assertion failed")
                return x + x

        input = torch.randn(2)
        self.assert_export(Model(), (input,))
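    # aten._assert_async.msg produces no tensor outputs; this verifies the
    # exporter tolerates zero-output nodes instead of failing on them.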

    def test_export_successful_when_dynamic_dimension_is_one(self):
        self.assert_export(
            SampleModelForDimOne(),
            (torch.randn(1, 3), torch.randn(1, 5), torch.randn(1, 8)),
            dynamic_shapes=(
                {0: "batch", 1: "sequence"},
                {0: "batch", 1: "sequence"},
                {0: "batch", 1: "sequence"},
            ),
        )


class TestCustomTranslationTable(common_utils.TestCase):
    def test_custom_translation_table_overrides_ops(self):
        from onnxscript import opset18 as op

        class Model(torch.nn.Module):
            def forward(self, x, y):
                return x + y

        def custom_add(self, other):
            # Replace add with sub
            return op.Sub(self, other)

        custom_translation_table = {torch.ops.aten.add.Tensor: custom_add}

        onnx_program = torch.onnx.export(
            Model(),
            (torch.randn(2, 2), torch.randn(2, 2)),
            custom_translation_table=custom_translation_table,
            dynamo=True,
        )
        all_nodes = [n.op_type for n in onnx_program.model.graph]
        self.assertIn("Sub", all_nodes)
        self.assertNotIn("Add", all_nodes)

    def test_custom_translation_table_supports_overloading_ops(self):
        class Model(torch.nn.Module):
            def forward(self, x, y):
                return torch.ops.aten.logical_and.default(x, y)

        def custom_add_bool(self: BOOL, other: BOOL) -> BOOL:
            # Replace logical_and with Sub for BOOL inputs
            return op.Sub(self, other)

        def custom_add(self: FLOAT, other: FLOAT) -> FLOAT:
            # Replace logical_and with Mul for FLOAT inputs
            return op.Mul(self, other)

        custom_translation_table = {
            torch.ops.aten.logical_and.default: [custom_add, custom_add_bool],
        }

        onnx_program = torch.onnx.export(
            Model(),
            (torch.tensor(1, dtype=torch.bool), torch.tensor(1, dtype=torch.bool)),
            custom_translation_table=custom_translation_table,
            dynamo=True,
        )
        all_nodes = [n.op_type for n in onnx_program.model.graph]
        # The dispatcher should pick the correct overload based on the input types
        self.assertIn("Sub", all_nodes)
        self.assertNotIn("Add", all_nodes)
        self.assertNotIn("Mul", all_nodes)

    def test_custom_translation_table_supports_custom_op_as_target(self):
        # Define the custom op and use it in the model
        @torch.library.custom_op("custom::add", mutates_args=())
        def custom_add(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
            return a + b

        @custom_add.register_fake
        def _(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
            return torch.empty_like(a) + torch.empty_like(b)

        class Model(torch.nn.Module):
            def forward(self, x, y):
                return custom_add(x, y)

        def onnx_add(self: FLOAT, other: FLOAT) -> FLOAT:
            # Replace add with Sub
            return op.Sub(self, other)

        custom_translation_table = {
            torch.ops.custom.add.default: onnx_add,
        }

        onnx_program = torch.onnx.export(
            Model(),
            (torch.tensor(1, dtype=torch.bool), torch.tensor(1, dtype=torch.bool)),
            custom_translation_table=custom_translation_table,
            dynamo=True,
        )
        all_nodes = [n.op_type for n in onnx_program.model.graph]
        self.assertIn("Sub", all_nodes)
        self.assertNotIn("Add", all_nodes)

    def test_custom_translation_table_supports_custom_op_with_its_decomp(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define(
                "mylib::foo",
                "(Tensor a, Tensor b) -> Tensor",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )

            @torch.library.impl("mylib::foo", "CompositeImplicitAutograd", lib=lib)
            @torch.library.register_fake("mylib::foo")
            def foo_impl(a, b):
                return a + b

            class M(torch.nn.Module):
                def forward(self, x, y):
                    return torch.ops.mylib.foo(x, y)

            def onnx_add(self: FLOAT, other: FLOAT) -> FLOAT:
                # Replace add with Sub
                return op.Sub(self, other)

            # With the custom op defined, we can use it in the model
            # and replace it with a custom translation table
            custom_translation_table = {
                torch.ops.mylib.foo.default: onnx_add,
            }
            onnx_program = torch.onnx.export(
                M(),
                (torch.ones(3, 3), torch.ones(3, 3)),
                custom_translation_table=custom_translation_table,
                dynamo=True,
            )
            all_nodes = [n.op_type for n in onnx_program.model.graph]
            self.assertIn("Sub", all_nodes)
            self.assertNotIn("Add", all_nodes)

            # Without the custom translation, the op decomposes to its
            # CompositeImplicitAutograd implementation (a + b)
            onnx_program_decomp = torch.onnx.export(
                M(), (torch.ones(3, 3), torch.ones(3, 3)), dynamo=True
            )
            all_nodes_decomp = [n.op_type for n in onnx_program_decomp.model.graph]
            self.assertIn("Add", all_nodes_decomp)
            self.assertNotIn("Sub", all_nodes_decomp)


class TestFakeTensorExport(common_utils.TestCase):
    """Test exporting in fake mode."""

    def test_onnx_program_raises_when_model_defined_in_fake_mode(self):
        with torch.onnx.enable_fake_mode():

            class Model(torch.nn.Module):
                def __init__(self):
                    super().__init__()
                    self.weight = torch.nn.Parameter(torch.tensor(42.0))

                def forward(self, x):
                    return self.weight + x

            onnx_program = torch.onnx.export(
                Model(), (torch.tensor(1.0),), dynamo=True, optimize=False
            )
            assert onnx_program is not None
            # Converting to model proto triggers to_bytes, which serializes the tensors
            with self.assertRaises(Exception):
                # The fake tensors need to be replaced with real tensors first
                _ = onnx_program.model_proto

        with self.assertRaises(Exception):
            # It doesn't matter if it is called inside or outside of the enable_fake_mode() context
            _ = onnx_program.model_proto

        # If we replace the fake tensors with concrete ones, serialization succeeds.
        # This needs to happen outside of the fake context
        onnx_program.apply_weights({"weight": torch.tensor(42.0)})
        onnx_model = ir.serde.deserialize_model(onnx_program.model_proto)
        np.testing.assert_allclose(
            onnx_model.graph.initializers["weight"].const_value.numpy(), 42.0
        )

    def test_onnx_program_save_raises_when_model_initialized_in_fake_mode(self):
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.tensor(42.0))

            def forward(self, x):
                return self.weight + x

        with torch.onnx.enable_fake_mode():
            onnx_program = torch.onnx.export(
                Model(), (torch.tensor(1.0),), dynamo=True, optimize=False
            )
            assert onnx_program is not None
            # Converting to model proto triggers to_bytes, which serializes the tensors
            with self.assertRaises(Exception):
                # The fake tensors need to be replaced with real tensors first
                _ = onnx_program.model_proto

        with self.assertRaises(Exception):
            # It doesn't matter if it is called inside or outside of the enable_fake_mode() context
            _ = onnx_program.model_proto

        # If we replace the fake tensors with concrete ones, serialization succeeds.
        # This needs to happen outside of the fake context
        onnx_program.apply_weights({"weight": torch.tensor(42.0)})
        onnx_model = ir.serde.deserialize_model(onnx_program.model_proto)
        np.testing.assert_allclose(
            onnx_model.graph.initializers["weight"].const_value.numpy(), 42.0
        )

    def test_onnx_program_save_succeeds_when_export_and_save_in_fake_mode(self):
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.tensor(42.0))

            def forward(self, x):
                return self.weight + x

        real_model = Model()

        with torch.onnx.enable_fake_mode():
            onnx_program = torch.onnx.export(
                real_model, (torch.tensor(1.0),), dynamo=True, optimize=False
            )

            assert onnx_program is not None
            # Converting to model proto triggers to_bytes, which serializes the tensors.
            # Note that even though we are calling .model_proto (equivalently .save()) in fake mode,
            # the concrete tensors are maintained.
            # This is due to the usage of torch._subclasses.fake_tensor.unset_fake_temporarily() in
            # TorchTensor.tobytes()
            onnx_model = ir.serde.deserialize_model(onnx_program.model_proto)
            np.testing.assert_allclose(
                onnx_model.graph.initializers["weight"].const_value.numpy(), 42.0
            )

        # This works inside or outside the fake mode
        onnx_model = ir.serde.deserialize_model(onnx_program.model_proto)
        np.testing.assert_allclose(
            onnx_model.graph.initializers["weight"].const_value.numpy(), 42.0
        )

    def test_is_in_onnx_export(self):
        class Mod(torch.nn.Module):
            def forward(self, x):
                def f(x):
                    return x.sin() if torch.onnx.is_in_onnx_export() else x.cos()

                return f(x)

        self.assertFalse(torch.onnx.is_in_onnx_export())
        onnx_program = torch.onnx.export(
            Mod(),
            (torch.randn(3, 4),),
            dynamo=True,
            fallback=False,
        )
        self.assertFalse(torch.onnx.is_in_onnx_export())

        node_names = [n.op_type for n in onnx_program.model.graph]
        self.assertIn("Sin", node_names)

    def test_torchscript_exporter_raises_deprecation_warning(self):
        # Test that a deprecation warning is raised when using the TorchScript exporter
        with self.assertWarnsRegex(
            DeprecationWarning, "You are using the legacy TorchScript-based ONNX export"
        ):
            torch.onnx.export(
                SampleModel(), (torch.randn(1, 1, 2),), io.BytesIO(), dynamo=False
            )

    def test_model_output_can_be_none(self):
        class ModelWithNoneOutput(torch.nn.Module):
            def forward(self, x):
                return x + 1, None

        onnx_program = torch.onnx.export(
            ModelWithNoneOutput(),
            (torch.randn(1, 1, 2),),
            dynamo=True,
        )
        onnx_testing.assert_onnx_program(onnx_program)


if __name__ == "__main__":
    common_utils.run_tests()