Mirror of https://github.com/pytorch/pytorch.git
[ONNX] Set USE_EXPERIMENTAL_LOGIC to True (#137296)
This sets dynamo_export to use the new export logic. The legacy dynamo export logic will be removed as a follow-up.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137296
Approved by: https://github.com/titaiwangms
Committed by: PyTorch MergeBot
Parent: 5aa5a5763e
Commit: af43b445a5
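For context, a minimal sketch of the user-facing call path this flag now routes through the new exporter; the model and input shape are illustrative, mirroring the SampleModel used in the removed tests below:

    import torch

    class SampleModel(torch.nn.Module):
        def forward(self, x):
            y = x + 1
            return y.relu()

    # With USE_EXPERIMENTAL_LOGIC defaulting to True, dynamo_export now uses the
    # ExportedProgram-based export path rather than the legacy dynamo export logic.
    onnx_program = torch.onnx.dynamo_export(SampleModel(), torch.randn(1, 1, 2))
    onnx_program.save("sample_model.onnx")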
@@ -1,89 +0,0 @@
-# Owner(s): ["module: onnx"]
-import io
-
-import onnx
-
-import torch
-from torch.onnx import dynamo_export, ExportOptions, ONNXProgram
-from torch.onnx._internal._exporter_legacy import ResolvedExportOptions
-from torch.testing._internal import common_utils
-
-
-class SampleModel(torch.nn.Module):
-    def forward(self, x):
-        y = x + 1
-        z = y.relu()
-        return (y, z)
-
-
-class SampleModelTwoInputs(torch.nn.Module):
-    def forward(self, x, b):
-        y = x + b
-        z = y.relu()
-        return (y, z)
-
-
-class SampleModelForDynamicShapes(torch.nn.Module):
-    def forward(self, x, b):
-        return x.relu(), b.sigmoid()
-
-
-class TestExportOptionsAPI(common_utils.TestCase):
-    def test_dynamic_shapes_default(self):
-        options = ResolvedExportOptions(ExportOptions())
-        self.assertFalse(options.dynamic_shapes)
-
-    def test_dynamic_shapes_explicit(self):
-        options = ResolvedExportOptions(ExportOptions(dynamic_shapes=None))
-        self.assertFalse(options.dynamic_shapes)
-        options = ResolvedExportOptions(ExportOptions(dynamic_shapes=True))
-        self.assertTrue(options.dynamic_shapes)
-        options = ResolvedExportOptions(ExportOptions(dynamic_shapes=False))
-        self.assertFalse(options.dynamic_shapes)
-
-
-class TestDynamoExportAPI(common_utils.TestCase):
-    def test_default_export(self):
-        output = dynamo_export(SampleModel(), torch.randn(1, 1, 2))
-        self.assertIsInstance(output, ONNXProgram)
-        self.assertIsInstance(output.model_proto, onnx.ModelProto)
-
-    def test_export_with_options(self):
-        self.assertIsInstance(
-            dynamo_export(
-                SampleModel(),
-                torch.randn(1, 1, 2),
-                export_options=ExportOptions(
-                    dynamic_shapes=True,
-                ),
-            ),
-            ONNXProgram,
-        )
-
-    def test_save_to_file_default_serializer(self):
-        with common_utils.TemporaryFileName() as path:
-            dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save(path)
-            onnx.load(path)
-
-    def test_save_to_existing_buffer_default_serializer(self):
-        buffer = io.BytesIO()
-        dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save(buffer)
-        onnx.load(buffer)
-
-    def test_raise_from_diagnostic_warning_when_diagnostic_option_warning_as_error_is_true(
-        self,
-    ):
-        with self.assertRaises(torch.onnx.OnnxExporterError):
-            dynamo_export(
-                SampleModel(),
-                torch.randn(1, 1, 2),
-                export_options=ExportOptions(
-                    diagnostic_options=torch.onnx.DiagnosticOptions(
-                        warnings_as_errors=True
-                    )
-                ),
-            )
-
-
-if __name__ == "__main__":
-    common_utils.run_tests()
@@ -1,431 +0,0 @@
-# Owner(s): ["module: onnx"]
-"""Unit tests for the internal registration wrapper module."""
-
-from __future__ import annotations
-
-import operator
-from typing import TypeVar, Union
-
-import onnxscript  # type: ignore[import]
-from onnxscript import BFLOAT16, DOUBLE, FLOAT, FLOAT16  # type: ignore[import]
-from onnxscript.onnx_opset import opset15 as op  # type: ignore[import]
-
-import torch
-import torch.fx
-from torch.onnx._internal.fx import diagnostics, onnxfunction_dispatcher, registration
-from torch.testing._internal import common_utils
-
-
-# TODO: this can only be global. https://github.com/microsoft/onnxscript/issues/805
-TCustomFloat = TypeVar("TCustomFloat", bound=Union[FLOAT16, FLOAT, DOUBLE, BFLOAT16])
-
-
-class TestRegistration(common_utils.TestCase):
-    def setUp(self) -> None:
-        self.registry = torch.onnx.OnnxRegistry()
-        self.custom_domain = onnxscript.values.Opset(domain="custom", version=1)
-
-    def tearDown(self) -> None:
-        internal_name_instance = registration.OpName.from_name_parts(
-            namespace="test", op_name="test_op"
-        )
-        self.registry._registry.pop(internal_name_instance, None)
-
-    def test_register_custom_op_registers_custom_function(self):
-        self.assertFalse(self.registry.is_registered_op("test", "test_op", "default"))
-
-        @onnxscript.script(self.custom_domain)
-        def custom_add(x, y):
-            return op.Add(x, y)
-
-        self.registry.register_op(custom_add, "test", "test_op", "default")
-        self.assertTrue(self.registry.is_registered_op("test", "test_op", "default"))
-
-        # Test on get_ops
-        function_group = self.registry.get_op_functions("test", "test_op", "default")
-        self.assertIsNotNone(function_group)
-        self.assertEqual({func.onnx_function for func in function_group}, {custom_add})  # type: ignore[arg-type]
-
-    def test_custom_onnx_symbolic_joins_existing_function(self):
-        self.assertFalse(self.registry.is_registered_op("test", "test_op"))
-
-        @onnxscript.script(self.custom_domain)
-        def test_original(x, y):
-            return op.Add(x, y)
-
-        # default has to be specified, as we are not using the registration.OpName
-        internal_name_instance = registration.OpName.from_name_parts(
-            namespace="test", op_name="test_op", overload="default"
-        )
-        symbolic_fn = registration.ONNXFunction(
-            test_original, op_full_name=internal_name_instance.qualified_name()
-        )
-        self.registry._register(internal_name_instance, symbolic_fn)
-        self.assertTrue(self.registry.is_registered_op("test", "test_op"))
-
-        @onnxscript.script(self.custom_domain)
-        def test_custom(x, y):
-            return op.Add(x, y)
-
-        self.registry.register_op(test_custom, "test", "test_op")
-
-        function_group = self.registry.get_op_functions("test", "test_op")
-        assert function_group is not None
-        # The order does matter (list)
-        self.assertEqual(
-            [func.onnx_function for func in function_group],
-            [test_original, test_custom],
-        )
-
-
-@common_utils.instantiate_parametrized_tests
-class TestDispatcher(common_utils.TestCase):
-    def setUp(self):
-        self.registry = torch.onnx.OnnxRegistry()
-        self.diagnostic_context = diagnostics.DiagnosticContext(
-            "torch.onnx.dynamo_export", torch.__version__
-        )
-        self.dispatcher = onnxfunction_dispatcher.OnnxFunctionDispatcher(
-            self.registry, self.diagnostic_context
-        )
-
-    @common_utils.parametrize(
-        "node, expected_name",
-        [
-            common_utils.subtest(
-                (
-                    torch.fx.Node(
-                        graph=torch.fx.Graph(),
-                        name="aten::add.Tensor",
-                        op="call_function",
-                        target=torch.ops.aten.add.Tensor,  # type: ignore[attr-defined]
-                        args=(torch.tensor(3), torch.tensor(4)),
-                        kwargs={},
-                    ),
-                    ("aten", "add", "Tensor"),
-                ),
-                name="get_Opoverload_name",
-            ),
-            common_utils.subtest(
-                (
-                    torch.fx.Node(
-                        graph=torch.fx.Graph(),
-                        name="aten::sym_size",
-                        op="call_function",
-                        target=torch.ops.aten.sym_size,
-                        args=(),
-                        kwargs={},
-                    ),
-                    ("aten", "sym_size", None),
-                ),
-                name="get_Opoverloadpacket_name",
-            ),
-            common_utils.subtest(
-                (
-                    torch.fx.Node(
-                        graph=torch.fx.Graph(),
-                        name="builtin_add",
-                        op="call_function",
-                        target=operator.add,
-                        args=(1, 2),
-                        kwargs={},
-                    ),
-                    ("_operator", "add", None),
-                ),
-                name="get_builtin_op_name",
-            ),
-        ],
-    )
-    def test_get_aten_name_on_supported_fx_node(
-        self, node: torch.fx.Node, expected_name: str
-    ):
-        expected_name_class = registration.OpName.from_name_parts(*expected_name)
-        self.assertEqual(
-            self.dispatcher._get_aten_name(node, self.diagnostic_context),
-            expected_name_class,
-        )
-
-    @common_utils.parametrize(
-        "node",
-        [
-            common_utils.subtest(
-                torch.fx.Node(
-                    graph=torch.fx.Graph(),
-                    name="aten::add",
-                    op="call_function",
-                    target=torch.ops.aten.add,
-                    args=(),
-                    kwargs={},
-                ),
-                name="unsupported_Opoverloadpacket_name",
-            ),
-            common_utils.subtest(
-                torch.fx.Node(
-                    graph=torch.fx.Graph(),
-                    name="builtin_add",
-                    op="call_function",
-                    target=operator.add,
-                    args=("A", "B"),
-                    kwargs={},
-                ),
-                name="unsupported_input_dtypes_for_builtin_op",
-            ),
-            common_utils.subtest(
-                torch.fx.Node(
-                    graph=torch.fx.Graph(),
-                    name="aten::made_up_node",
-                    op="call_function",
-                    target=lambda: None,
-                    args=(),
-                    kwargs={},
-                ),
-                name="unsupported_target_function",
-            ),
-        ],
-    )
-    def test_get_aten_name_on_unsupported_fx_node(self, node: torch.fx.Node):
-        with self.assertRaises(RuntimeError):
-            self.dispatcher._get_aten_name(node, self.diagnostic_context)
-
-    def test_get_function_overloads_gives_overload_fall_back_default(self):
-        # Test fall back to default op name
-        node_overload = torch.fx.Node(
-            graph=torch.fx.Graph(),
-            name="aten::add.Tensor",
-            op="call_function",
-            target=torch.ops.aten.add.Tensor,  # type: ignore[attr-defined]
-            args=(torch.tensor(3), torch.tensor(4)),
-            kwargs={},
-        )
-        node_overloadpacket = torch.fx.Node(
-            graph=torch.fx.Graph(),
-            name="aten::add",
-            op="call_function",
-            target=torch.ops.aten.add.Tensor,  # type: ignore[attr-defined]
-            args=(),
-            kwargs={},
-        )
-
-        self.assertEqual(
-            self.dispatcher.get_function_overloads(
-                node_overload, self.diagnostic_context
-            ),
-            self.dispatcher.get_function_overloads(
-                node_overloadpacket,
-                self.diagnostic_context,
-            ),
-        )
-
-        # Non-registered op
-        unsupported_op_node = torch.fx.Node(
-            graph=torch.fx.Graph(),
-            name="aten::made_up_node",
-            op="call_function",
-            target=lambda: None,
-            args=(),
-            kwargs={},
-        )
-        with self.assertRaises(RuntimeError):
-            self.dispatcher.get_function_overloads(
-                unsupported_op_node,
-                self.diagnostic_context,
-            )
-
-    @common_utils.parametrize(
-        "node",
-        [
-            common_utils.subtest(
-                torch.fx.Node(
-                    graph=torch.fx.Graph(),
-                    name="aten::add.Tensor",
-                    op="call_function",
-                    target=torch.ops.aten.add.Tensor,  # type: ignore[attr-defined]
-                    args=(torch.tensor(3.0), torch.tensor(4.0)),
-                    kwargs={},
-                ),
-                name="nearest_match",
-            ),
-            common_utils.subtest(
-                torch.fx.Node(
-                    graph=torch.fx.Graph(),
-                    name="aten::add.Tensor",
-                    op="call_function",
-                    target=torch.ops.aten.add.Tensor,  # type: ignore[attr-defined]
-                    args=(torch.tensor(3.0), torch.tensor(4.0)),
-                    kwargs={"alpha": 1},
-                ),
-                name="perfect_match_with_kwargs",
-            ),
-        ],
-    )
-    def test_find_the_perfect_or_nearest_match_onnxfunction_gives_custom_ops_precedence(
-        self, node
-    ):
-        custom_domain = onnxscript.values.Opset(domain="custom", version=1)
-
-        @onnxscript.script(custom_domain)
-        def test_custom_op(
-            x: TCustomFloat, y: TCustomFloat, alpha: int = 1
-        ) -> TCustomFloat:
-            return op.Add(x, y)
-
-        @onnxscript.script(custom_domain)
-        def test_default_op(
-            x: TCustomFloat, y: TCustomFloat, alpha: int = 1
-        ) -> TCustomFloat:
-            return op.Add(x, y)
-
-        op_full_name = "test::test_op"
-
-        custom_overloads = [
-            registration.ONNXFunction(
-                test_custom_op, op_full_name=op_full_name, is_custom=True
-            )
-        ]
-        function_overloads = [
-            registration.ONNXFunction(test_default_op, op_full_name=op_full_name)
-        ] + custom_overloads
-
-        symbolic_fn = self.dispatcher._find_the_perfect_or_nearest_match_onnxfunction(
-            node,
-            function_overloads,
-            node.args,
-            node.kwargs,
-            self.diagnostic_context,
-        )
-        self.assertEqual(symbolic_fn, test_custom_op)
-
-    @common_utils.parametrize(
-        "node",
-        [
-            common_utils.subtest(
-                torch.fx.Node(
-                    graph=torch.fx.Graph(),
-                    name="aten::add.Tensor",
-                    op="call_function",
-                    target=torch.ops.aten.add.Tensor,  # type: ignore[attr-defined]
-                    args=(torch.tensor(3.0), torch.tensor(4.0)),
-                    kwargs={"attr": None},
-                ),
-                name="perfect_match_with_ignoring_none_attribute",
-            ),
-            common_utils.subtest(
-                torch.fx.Node(
-                    graph=torch.fx.Graph(),
-                    name="aten::add.Tensor",
-                    op="call_function",
-                    target=torch.ops.aten.add.Tensor,  # type: ignore[attr-defined]
-                    args=(torch.tensor(3.0), torch.tensor(4.0)),
-                    kwargs={"unrelated": None},
-                ),
-                name="perfect_match_with_ignoring_unrelated_none_attribute",
-            ),
-        ],
-    )
-    def test_find_the_perfect_or_nearest_match_onnxfunction_ignores_attribute_with_none(
-        self, node
-    ):
-        custom_domain = onnxscript.values.Opset(domain="custom", version=1)
-
-        @onnxscript.script(custom_domain)
-        def test_op_attribute(
-            x: TCustomFloat, y: TCustomFloat, attr: int
-        ) -> TCustomFloat:
-            return op.Add(x, y)
-
-        @onnxscript.script(custom_domain)
-        def test_op(x: TCustomFloat, y: TCustomFloat) -> TCustomFloat:
-            return op.Add(x, y)
-
-        op_full_name = "test::test_op"
-
-        function_overloads = [
-            registration.ONNXFunction(test_op_attribute, op_full_name=op_full_name),
-            registration.ONNXFunction(test_op, op_full_name=op_full_name),
-        ]
-
-        symbolic_fn = self.dispatcher._find_the_perfect_or_nearest_match_onnxfunction(
-            node,
-            function_overloads,
-            node.args,
-            node.kwargs,
-            self.diagnostic_context,
-        )
-        self.assertEqual(symbolic_fn, test_op)
-
-    @common_utils.parametrize(
-        "node",
-        [
-            common_utils.subtest(
-                torch.fx.Node(
-                    graph=torch.fx.Graph(),
-                    name="aten::add.Tensor",
-                    op="call_function",
-                    target=torch.ops.aten.add.Tensor,  # type: ignore[attr-defined]
-                    args=(torch.tensor(3.0), torch.tensor(4.0)),
-                    kwargs={},
-                ),
-                name="nearest_match",
-            ),
-            common_utils.subtest(
-                torch.fx.Node(
-                    graph=torch.fx.Graph(),
-                    name="aten::add.Tensor",
-                    op="call_function",
-                    target=torch.ops.aten.add.Tensor,  # type: ignore[attr-defined]
-                    args=(torch.tensor(3.0), torch.tensor(4.0)),
-                    kwargs={"alpha": 1},
-                ),
-                name="perfect_match_with_kwargs",
-            ),
-        ],
-    )
-    def test_find_the_perfect_or_nearest_match_onnxfunction_gives_tie_breaks_to_registered_order(
-        self, node
-    ):
-        custom_domain = onnxscript.values.Opset(domain="custom", version=1)
-
-        @onnxscript.script(custom_domain)
-        def test_second_custom_op(
-            x: TCustomFloat, y: TCustomFloat, alpha: int = 1
-        ) -> TCustomFloat:
-            return op.Add(x, y)
-
-        @onnxscript.script(custom_domain)
-        def test_third_custom_op(
-            x: TCustomFloat, y: TCustomFloat, alpha: int = 1
-        ) -> TCustomFloat:
-            return op.Add(x, y)
-
-        @onnxscript.script(custom_domain)
-        def test_first_custom_op(
-            x: TCustomFloat, y: TCustomFloat, alpha: int = 1
-        ) -> TCustomFloat:
-            return op.Add(x, y)
-
-        op_full_name = "aten::add"
-
-        function_overloads = [
-            registration.ONNXFunction(
-                test_first_custom_op, op_full_name=op_full_name, is_custom=True
-            ),
-            registration.ONNXFunction(
-                test_second_custom_op, op_full_name=op_full_name, is_custom=True
-            ),
-            registration.ONNXFunction(
-                test_third_custom_op, op_full_name=op_full_name, is_custom=True
-            ),
-        ]
-
-        symbolic_fn = self.dispatcher._find_the_perfect_or_nearest_match_onnxfunction(
-            node,
-            function_overloads,
-            node.args,
-            node.kwargs,
-            self.diagnostic_context,
-        )
-        self.assertEqual(symbolic_fn, test_third_custom_op)
-
-
-if __name__ == "__main__":
-    common_utils.run_tests()
@@ -3,7 +3,10 @@
 
 from __future__ import annotations
 
+import logging
+
+import torchvision
 import transformers
 
 import torch
 from torch.onnx._internal.exporter import _testing as onnx_testing
@@ -12,6 +15,13 @@ from torch.testing._internal import common_utils
 
 @common_utils.instantiate_parametrized_tests
 class DynamoExporterTest(common_utils.TestCase):
+    def export(self, model, args=(), kwargs=None, **options) -> torch.onnx.ONNXProgram:
+        onnx_program = torch.onnx.export(
+            model, args, kwargs=kwargs, dynamo=True, fallback=False, **options
+        )
+        assert onnx_program is not None
+        return onnx_program
+
     def test_insert_contiguous_between_transpose_and_view(self):
         class Model(torch.nn.Module):
             def forward(self, query, key, value):
@@ -30,9 +40,7 @@ class DynamoExporterTest(common_utils.TestCase):
         ep = torch.export.export(model, (query, key, value), strict=False)
         self.assertNotIn("call_method", str(ep.graph))
 
-        onnx_program = torch.onnx.export(
-            model, (query, key, value), dynamo=True, fallback=False
-        )
+        onnx_program = self.export(model, (query, key, value))
         onnx_testing.assert_onnx_program(onnx_program, atol=1e-3, rtol=1)
 
     def test_constant_complex(self):
@@ -46,7 +54,7 @@ class DynamoExporterTest(common_utils.TestCase):
            [[1.0 + 2.0j, 3.0 + 4.0j], [5.0 + 6.0j, 7.0 + 8.0j]], dtype=torch.complex64
         )
 
-        onnx_program = torch.onnx.export(MulModule(), (x,), dynamo=True)
+        onnx_program = self.export(MulModule(), (x,))
         onnx_testing.assert_onnx_program(onnx_program)
 
     def test_pow_does_not_trigger_type_promotion(self):
@@ -56,7 +64,7 @@ class DynamoExporterTest(common_utils.TestCase):
 
         x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float16)
 
-        onnx_program = torch.onnx.export(Model(), (x,), dynamo=True)
+        onnx_program = self.export(Model(), (x,))
         onnx_testing.assert_onnx_program(onnx_program)
         self.assertNotIn("Cast", [node.op_type for node in onnx_program.model.graph])
 
@@ -72,12 +80,7 @@ class DynamoExporterTest(common_utils.TestCase):
             y = torch.cond(x.sum() > 0, true_fn, false_fn, [x])
             return y
 
-        onnx_program = torch.onnx.export(
-            CondModel(),
-            (torch.tensor([1, 2]),),
-            dynamo=True,
-            fallback=False,
-        )
+        onnx_program = self.export(CondModel(), (torch.tensor([1, 2]),))
         onnx_model = onnx_program.model
         self.assertIn("If", [node.op_type for node in onnx_model.graph])
         onnx_testing.assert_onnx_program(onnx_program)
@@ -117,12 +120,7 @@ class DynamoExporterTest(common_utils.TestCase):
             y = torch.cond(x.sum() > 0, true_fn, false_fn, [x])
             return y
 
-        onnx_program = torch.onnx.export(
-            CondModel(),
-            (torch.tensor([1, 2]),),
-            dynamo=True,
-            fallback=False,
-        )
+        onnx_program = self.export(CondModel(), (torch.tensor([1, 2]),))
        onnx_testing.assert_onnx_program(onnx_program)
         onnx_testing.assert_onnx_program(onnx_program, args=(torch.tensor([0, 0]),))
         onnx_testing.assert_onnx_program(onnx_program, args=(torch.tensor([43, 43]),))
@@ -141,11 +139,105 @@ class DynamoExporterTest(common_utils.TestCase):
             torch.tensor([0.1, 0.2]),
             0,
         )
-        onnx_program = torch.onnx.export(VisionModel(), args, dynamo=True)
+        onnx_program = self.export(VisionModel(), args)
         onnx_testing.assert_onnx_program(onnx_program)
 
     # TODO(justinchuby): Test multi-output HOPs
 
+    def test_empty(self):
+        def func(x):
+            return torch.empty(x.size(), dtype=torch.int64)
+
+        # Since `torch.empty` returns tensor with uninitialized data, we cannot
+        # test this under `test_fx_to_onnx_with_onnxruntime.py` with result comparison.
+        _ = self.export(func, (torch.randn(1, 2),))
+
+    def test_multiple_outputs_op_with_evaluator(self):
+        class TopKModel(torch.nn.Module):
+            def forward(self, x):
+                values, _ = torch.topk(x, 3)
+                return torch.sum(values)
+
+        onnx_program = self.export(
+            TopKModel(), (torch.arange(1.0, 6.0, requires_grad=True),)
+        )
+        onnx_testing.assert_onnx_program(onnx_program)
+
+    def test_exported_program_torch_distributions_normal_Normal(self):
+        class Model(torch.nn.Module):
+            def __init__(self) -> None:
+                self.normal = torch.distributions.normal.Normal(0, 1)
+                super().__init__()
+
+            def forward(self, x):
+                return self.normal.sample(x.shape)
+
+        with torch.no_grad():
+            exported_program = torch.export.export(
+                Model(), args=(torch.randn(2),), strict=False
+            )
+            _ = self.export(exported_program)
+
+    @common_utils.parametrize(
+        "float8_type",
+        [
+            common_utils.subtest(
+                torch.float8_e5m2,
+                name="torch_float8_e5m2",
+            ),
+            common_utils.subtest(
+                torch.float8_e5m2fnuz,
+                name="torch_float8_e5m2fnuz",
+            ),
+            common_utils.subtest(
+                torch.float8_e4m3fn,
+                name="torch_float8_e4m3fn",
+            ),
+            common_utils.subtest(
+                torch.float8_e4m3fnuz,
+                name="torch_float8_e4m3fnuz",
+            ),
+        ],
+    )
+    def test_float8_support(self, float8_type):
+        class Float8Module(torch.nn.Module):
+            def forward(self, input: torch.Tensor):
+                input = input.to(float8_type)
+                return input
+
+        _ = self.export(Float8Module(), (torch.randn(1, 2),))
+
+    def test_export_with_logging_logger(self):
+        logger = logging.getLogger(__name__)
+
+        class LoggingLoggerModule(torch.nn.Module):
+            def forward(self, x):
+                logger.log("abc")
+                return x + 1
+
+        onnx_program = self.export(LoggingLoggerModule(), (torch.tensor(1),))
+        onnx_testing.assert_onnx_program(onnx_program)
+
+    def test_export_with_hf_logging_logger(self):
+        logger = transformers.utils.logging.get_logger(__name__)
+
+        class HFLoggingLoggerModule(torch.nn.Module):
+            def forward(self, x):
+                logger.warning_once("abc")
+                return x + 1
+
+        onnx_program = self.export(HFLoggingLoggerModule(), (torch.tensor(1),))
+        onnx_testing.assert_onnx_program(onnx_program)
+
+    def test_export_with_print(self):
+        class PrintModule(torch.nn.Module):
+            def forward(self, x):
+                print("abc")
+                return x + 1
+
+        onnx_program = self.export(PrintModule(), (torch.tensor(1),))
+        onnx_testing.assert_onnx_program(onnx_program)
+
 
 if __name__ == "__main__":
     common_utils.run_tests()
@@ -1,6 +1,4 @@
 # Owner(s): ["module: onnx"]
-import pytorch_test_common
-
 import torch
 import torch._dynamo
 import torch.fx
@@ -57,208 +55,6 @@ class TestFxPasses(common_utils.TestCase):
             nodes
         ), f"Expected all names to be unique, got {nodes}"
 
-    def test_onnx_dynamo_export_raises_when_model_contains_unsupported_fx_nodes(self):
-        @torch.library.custom_op(
-            "mylibrary::foo_op", device_types="cpu", mutates_args=()
-        )
-        def foo_op(x: torch.Tensor) -> torch.Tensor:
-            return x + 1
-
-        @torch.library.custom_op(
-            "mylibrary::bar_op", device_types="cpu", mutates_args=()
-        )
-        def bar_op(x: torch.Tensor) -> torch.Tensor:
-            return x + 2
-
-        @foo_op.register_fake
-        def _(x):
-            return torch.empty_like(x)
-
-        @bar_op.register_fake
-        def _(x):
-            return torch.empty_like(x)
-
-        def func(x, y, z):
-            return foo_op(x) + bar_op(y) + z
-
-        x = torch.randn(3)
-        y = torch.randn(3)
-        z = torch.randn(3)
-        with self.assertRaises(torch.onnx.OnnxExporterError) as ctx:
-            torch.onnx.dynamo_export(func, x, y, z)
-        inner_exception = ctx.exception.__cause__
-        self.assertRegex(
-            str(inner_exception),
-            r"Unsupported FX nodes.*mylibrary\.foo_op.*mylibrary\.bar_op",
-        )
-
-        torch._dynamo.reset()
-
-
-@common_utils.instantiate_parametrized_tests
-class TestModularizePass(common_utils.TestCase):
-    @pytorch_test_common.xfail(
-        error_message="'torch_nn_modules_activation_GELU_used_gelu_1' not found",
-        reason="optimizer",
-    )
-    @common_utils.parametrize(
-        "is_exported_program",
-        [
-            common_utils.subtest(
-                True,
-                name="exported_program",
-            ),
-            common_utils.subtest(
-                False,
-                name="nn_module",
-            ),
-        ],
-    )
-    def test_modularize_pass_succeeds_when_submodule_output_is_unused(
-        self, is_exported_program
-    ):
-        # This is an ill-formed model, but exporter must not crash.
-        # It is illegal for submodule to have zero output. For modularization pass it can happen
-        # when the submodule output is unused, so no inner node is connected to any outer
-        # nodes.
-        # However, this also means the entire submodule should be erased by DCE. Hence
-        # it should never occur.
-        #
-        # Minified repro from Background_Matting. https://github.com/pytorch/benchmark/issues/1768
-        class TestModule(torch.nn.Module):
-            def __init__(self) -> None:
-                super().__init__()
-                self.unused_relu = torch.nn.ReLU()
-                self.used_gelu = torch.nn.GELU()
-
-            def forward(self, x, y):
-                result = self.used_gelu(x + y)
-                unused_relu_result = self.unused_relu(x)  # noqa: F841
-                return result
-
-        if is_exported_program:
-            model = torch.export.export(
-                TestModule(), args=(torch.randn(3), torch.randn(3)), strict=True
-            )
-        else:
-            model = TestModule()
-
-        onnx_program = torch.onnx.dynamo_export(model, torch.randn(3), torch.randn(3))
-        model_proto = onnx_program.model_proto
-        function_proto_names = [function.name for function in model_proto.functions]
-        self.assertIn(
-            "torch_nn_modules_activation_GELU_used_gelu_1", function_proto_names
-        )
-        self.assertFalse(any("ReLU" in name for name in function_proto_names))
-
-    @pytorch_test_common.xfail(
-        error_message="'torch_nn_modules_activation_ReLU_relu_1' not found",
-        reason="optimizer",
-    )
-    @common_utils.parametrize(
-        "is_exported_program",
-        [
-            common_utils.subtest(
-                True,
-                name="exported_program",
-            ),
-            common_utils.subtest(
-                False,
-                name="nn_module",
-            ),
-        ],
-    )
-    def test_modularize_pass_succeeds_when_a_submodule_is_called_multiple_times(
-        self, is_exported_program
-    ):
-        class TestModule(torch.nn.Module):
-            def __init__(self) -> None:
-                super().__init__()
-                self.relu = torch.nn.ReLU()
-
-            def forward(self, x, y):
-                out = x + y
-                out = self.relu(out)
-                out = out + x
-                out = self.relu(out)
-                return out
-
-        if is_exported_program:
-            model = torch.export.export(
-                TestModule(), args=(torch.randn(3), torch.randn(3)), strict=True
-            )
-        else:
-            model = TestModule()
-
-        onnx_program = torch.onnx.dynamo_export(model, torch.randn(3), torch.randn(3))
-        model_proto = onnx_program.model_proto
-        function_proto_names = [function.name for function in model_proto.functions]
-        self.assertIn("torch_nn_modules_activation_ReLU_relu_1", function_proto_names)
-        self.assertIn("torch_nn_modules_activation_ReLU_relu_2", function_proto_names)
-
-    @pytorch_test_common.xfail(
-        error_message="'torch_nn_modules_activation_ReLU_inner_module_relu_1' not found",
-        reason="optimizer",
-    )
-    @common_utils.parametrize(
-        "is_exported_program",
-        [
-            common_utils.subtest(
-                True,
-                name="exported_program",
-            ),
-            common_utils.subtest(
-                False,
-                name="nn_module",
-            ),
-        ],
-    )
-    def test_modularize_pass_succeeds_when_a_submodule_is_called_from_multiple_layers(
-        self, is_exported_program
-    ):
-        # Minified repro from basic_gnn_edgecnn.
-        class InnerModule(torch.nn.Module):
-            def __init__(self) -> None:
-                super().__init__()
-                self.relu = torch.nn.ReLU()
-
-            def forward(self, x):
-                return self.relu(x)
-
-        class TestModule(torch.nn.Module):
-            def __init__(self) -> None:
-                super().__init__()
-                self.inner_module = InnerModule()
-
-            def forward(self, x, y):
-                out = x + y
-                out = self.inner_module(out)
-                out = out + x
-                out = self.inner_module.relu(out)
-                return out
-
-        if is_exported_program:
-            model = torch.export.export(
-                TestModule(), args=(torch.randn(3), torch.randn(3)), strict=True
-            )
-        else:
-            model = TestModule()
-
-        onnx_program = torch.onnx.dynamo_export(model, torch.randn(3), torch.randn(3))
-        model_proto = onnx_program.model_proto
-        function_proto_names = [function.name for function in model_proto.functions]
-        self.assertIn(
-            "torch_nn_modules_activation_ReLU_inner_module_relu_1", function_proto_names
-        )
-        self.assertIn(
-            "torch_nn_modules_activation_ReLU_inner_module_relu_2", function_proto_names
-        )
-        # local module qualified name is unstable in test environment depending on different test
-        # invocation methods.
-        self.assertTrue(
-            any("InnerModule_inner_module_1" in name for name in function_proto_names)
-        )
-
 
 if __name__ == "__main__":
     common_utils.run_tests()
@@ -1,322 +0,0 @@
-# Owner(s): ["module: onnx"]
-from __future__ import annotations
-
-import logging
-import tempfile
-
-import onnx
-import onnx.inliner
-
-import pytorch_test_common
-import transformers  # type: ignore[import]
-
-import torch
-from torch import nn
-from torch._subclasses import fake_tensor
-from torch.nn import functional as F
-from torch.onnx import dynamo_export, ExportOptions
-from torch.testing._internal import common_utils
-
-
-@common_utils.instantiate_parametrized_tests
-class TestFxToOnnx(pytorch_test_common.ExportTestCase):
-    def setUp(self):
-        super().setUp()
-        self.export_options = ExportOptions()
-
-    def tearDown(self):
-        super().tearDown()
-
-    def test_simple_function(self):
-        def func(x):
-            y = x + 1
-            z = y.relu()
-            return (y, z)
-
-        _ = dynamo_export(
-            func, torch.randn(1, 1, 2), export_options=self.export_options
-        )
-
-    def test_empty(self):
-        # Since `torch.empty` returns tensor with uninitialized data, we cannot
-        # test this under `test_fx_to_onnx_with_onnxruntime.py` with result comparison.
-        def func(x):
-            return torch.empty(x.size(), dtype=torch.int64)
-
-        tensor_x = torch.randn(1, 1, 2)
-        _ = dynamo_export(func, tensor_x, export_options=self.export_options)
-
-    def test_args_used_for_export_is_not_converted_to_fake_tensors(self):
-        def func(x, y):
-            return x + y
-
-        tensor_x = torch.randn(1, 1, 2)
-        tensor_y = torch.randn(1, 1, 2)
-        _ = dynamo_export(func, tensor_x, tensor_y, export_options=self.export_options)
-        self.assertNotIsInstance(tensor_x, fake_tensor.FakeTensor)
-        self.assertNotIsInstance(tensor_y, fake_tensor.FakeTensor)
-
-    def test_mnist_exported_with_no_warnings(self):
-        class MNISTModel(nn.Module):
-            def __init__(self) -> None:
-                super().__init__()
-                self.conv1 = nn.Conv2d(1, 32, 3, 1, bias=False)
-                self.conv2 = nn.Conv2d(32, 64, 3, 1, bias=False)
-                self.fc1 = nn.Linear(9216, 128, bias=False)
-                self.fc2 = nn.Linear(128, 10, bias=False)
-
-            def forward(self, tensor_x: torch.Tensor):
-                tensor_x = self.conv1(tensor_x)
-                tensor_x = F.sigmoid(tensor_x)
-                tensor_x = self.conv2(tensor_x)
-                tensor_x = F.sigmoid(tensor_x)
-                tensor_x = F.max_pool2d(tensor_x, 2)
-                tensor_x = torch.flatten(tensor_x, 1)
-                tensor_x = self.fc1(tensor_x)
-                tensor_x = F.sigmoid(tensor_x)
-                tensor_x = self.fc2(tensor_x)
-                output = F.log_softmax(tensor_x, dim=1)
-                return output
-
-        tensor_x = torch.rand((64, 1, 28, 28), dtype=torch.float32)
-        onnx_program = dynamo_export(MNISTModel(), tensor_x)
-        assert onnx_program is not None
-
-    def test_trace_only_op_with_evaluator(self):
-        model_input = torch.tensor([[1.0, 2.0, 3.0], [1.0, 1.0, 2.0]])
-
-        class ArgminArgmaxModel(torch.nn.Module):
-            def forward(self, input):
-                return (
-                    torch.argmin(input),
-                    torch.argmax(input),
-                    torch.argmin(input, keepdim=True),
-                    torch.argmax(input, keepdim=True),
-                    torch.argmin(input, dim=0, keepdim=True),
-                    torch.argmax(input, dim=1, keepdim=True),
-                )
-
-        _ = dynamo_export(
-            ArgminArgmaxModel(), model_input, export_options=self.export_options
-        )
-
-    def test_multiple_outputs_op_with_evaluator(self):
-        class TopKModel(torch.nn.Module):
-            def forward(self, x):
-                values, _ = torch.topk(x, 3)
-                return torch.sum(values)
-
-        x = torch.arange(1.0, 6.0, requires_grad=True)
-
-        _ = dynamo_export(TopKModel(), x, export_options=self.export_options)
-
-    def test_dynamo_export_retains_readable_parameter_and_buffer_names(self):
-        class SubModule(torch.nn.Module):
-            def __init__(self) -> None:
-                super().__init__()
-                self.conv2 = nn.Conv2d(32, 64, 3, 1, bias=False)
-                self.fc1 = nn.Linear(9216, 128, bias=False)
-                self.buffer = torch.nn.Buffer(torch.randn(1, 128))
-
-            def forward(self, tensor_x: torch.Tensor):
-                tensor_x = self.conv2(tensor_x)
-                tensor_x = F.sigmoid(tensor_x)
-                tensor_x = F.max_pool2d(tensor_x, 2)
-                tensor_x = torch.flatten(tensor_x, 1)
-                tensor_x = self.fc1(tensor_x)
-                tensor_x = tensor_x + self.buffer
-                tensor_x = F.sigmoid(tensor_x)
-                return tensor_x
-
-        class MNISTModel(nn.Module):
-            def __init__(self) -> None:
-                super().__init__()
-                self.conv1 = nn.Conv2d(1, 32, 3, 1, bias=False)
-                self.submodule = SubModule()
-                self.fc2 = nn.Linear(128, 10, bias=False)
-
-            def forward(self, tensor_x: torch.Tensor):
-                tensor_x = self.conv1(tensor_x)
-                tensor_x = F.sigmoid(tensor_x)
-                tensor_x = self.submodule(tensor_x)
-                tensor_x = self.fc2(tensor_x)
-                output = F.log_softmax(tensor_x, dim=1)
-                return output
-
-        tensor_x = torch.rand((64, 1, 28, 28), dtype=torch.float32)
-
-        model = MNISTModel()
-        onnx_program = torch.onnx.dynamo_export(model, tensor_x)
-        model_proto = onnx_program.model_proto
-
-        # NOTE: initializers could be optimized away by onnx optimizer
-        onnx_initilizers = {init.name for init in model_proto.graph.initializer}
-        torch_weights = {*model.state_dict().keys()}
-        self.assertTrue(onnx_initilizers.issubset(torch_weights))
-
-    def test_fake_tensor_mode_simple(self):
-        class Model(torch.nn.Module):
-            def __init__(self) -> None:
-                super().__init__()
-                self.linear = torch.nn.Linear(2, 2)
-
-            def forward(self, x):
-                out = self.linear(x)
-                return out
-
-        with torch.onnx.enable_fake_mode() as fake_context:
-            x = torch.rand(5, 2, 2)
-            model = Model()
-            export_options = ExportOptions(fake_context=fake_context)
-            onnx_program = torch.onnx.dynamo_export(
-                model, x, export_options=export_options
-            )
-
-        assert (
-            onnx_program is not None
-        ), "ONNXProgram must be created on successful export"
-
-        onnx_program.apply_weights(Model().state_dict())
-
-        assert (
-            onnx_program.model_proto is not None
-        ), "A model protobuf must be created on a successful export"
-        onnx.checker.check_model(onnx_program.model_proto, full_check=True)
-
-    def test_exported_program_torch_distributions_normal_Normal(self):
-        class Model(torch.nn.Module):
-            def __init__(self) -> None:
-                self.normal = torch.distributions.normal.Normal(0, 1)
-                super().__init__()
-
-            def forward(self, x):
-                return self.normal.sample(x.shape)
-
-        x = torch.randn(2, 3)
-        with torch.no_grad():
-            exported_program = torch.export.export(Model(), args=(x,), strict=True)
-            _ = torch.onnx.dynamo_export(
-                exported_program,
-                x,
-            )
-
-    def test_aten_div_no_opmath_type_promotion(self):
-        class Model(torch.nn.Module):
-            def forward(self, input):
-                return input / 2
-
-        model = Model()
-        input = torch.randn(3, 5, requires_grad=True, dtype=torch.float16)
-
-        model_proto = torch.onnx.dynamo_export(model, input).model_proto
-        model_proto = onnx.inliner.inline_local_functions(model_proto)
-        div_node = next(
-            node for node in model_proto.graph.node if node.op_type == "Div"
-        )
-        # The input of Div node should be the input of the model,
-        # with no Cast node in between.
-        self.assertEqual(div_node.input[0], model_proto.graph.input[0].name)
-
-    @common_utils.parametrize(
-        "float8_type",
-        [
-            common_utils.subtest(
-                torch.float8_e5m2,
-                name="torch_float8_e5m2",
-            ),
-            common_utils.subtest(
-                torch.float8_e5m2fnuz,
-                name="torch_float8_e5m2fnuz",
-            ),
-            common_utils.subtest(
-                torch.float8_e4m3fn,
-                name="torch_float8_e4m3fn",
-            ),
-            common_utils.subtest(
-                torch.float8_e4m3fnuz,
-                name="torch_float8_e4m3fnuz",
-            ),
-        ],
-    )
-    def test_float8_support(self, float8_type):
-        class Float8Module(torch.nn.Module):
-            def forward(self, input: torch.Tensor):
-                input = input.to(float8_type)
-                return input + torch.tensor(1.0, dtype=float8_type)
-
-        # NOTE: shape inference error raised in optimizer due to unsupported dtype
-        with self.assertWarnsOnceRegex(
-            UserWarning, "ONNXScript optimizer failed. Skipping optimization."
-        ):
-            _ = torch.onnx.dynamo_export(Float8Module(), torch.randn(1, 2, 3, 4))
-
-    def test_export_with_logging_logger(self):
-        logger = logging.getLogger(__name__)
-
-        class LoggingLoggerModule(torch.nn.Module):
-            def forward(self, x):
-                logger.log("abc")
-                return x + 1
-
-        input = torch.randn(2, 3)
-        model = LoggingLoggerModule()
-        _ = torch.onnx.dynamo_export(model, input)
-
-    def test_export_with_hf_logging_logger(self):
-        logger = transformers.utils.logging.get_logger(__name__)
-
-        class HFLoggingLoggerModule(torch.nn.Module):
-            def forward(self, x):
-                logger.warning_once("abc")
-                return x + 1
-
-        input = torch.randn(2, 3)
-        model = HFLoggingLoggerModule()
-        _ = torch.onnx.dynamo_export(model, input)
-
-    def test_checkpoint_cast(self):
-        model_id = "openai/whisper-large-v3"
-        feature_extractor = transformers.WhisperFeatureExtractor(feature_size=128)
-        batch = 4
-
-        with torch.onnx.enable_fake_mode() as ctx:
-            model = transformers.AutoModelForSpeechSeq2Seq.from_pretrained(
-                model_id, low_cpu_mem_usage=False, use_safetensors=False
-            )
-            input = {
-                "input_features": torch.randn(
-                    (
-                        batch,
-                        feature_extractor.feature_size,
-                        feature_extractor.nb_max_frames,
-                    )
-                ),
-                "decoder_input_ids": torch.tensor([[1, 1]]) * 8001,
-                "return_dict": False,
-            }
-
-        export_options = torch.onnx.ExportOptions(fake_context=ctx)
-        onnx_program = torch.onnx.dynamo_export(
-            model, **input, export_options=export_options
-        )
-        with tempfile.NamedTemporaryFile(suffix=".onnx") as tmp_onnx_file:
-            onnx_program.save(
-                tmp_onnx_file.name,
-                keep_initializers_as_inputs=True,
-                include_initializers=False,
-            )
-            onnx.checker.check_model(tmp_onnx_file.name, full_check=True)
-
-    def test_export_with_print(self):
-        class PrintModule(torch.nn.Module):
-            def forward(self, x):
-                print("abc")
-                return x + 1
-
-        input = torch.randn(2, 3)
-        model = PrintModule()
-        _ = torch.onnx.dynamo_export(model, input)
-
-
-if __name__ == "__main__":
-    common_utils.run_tests()
@@ -1,76 +0,0 @@
-# Owner(s): ["module: onnx"]
-from __future__ import annotations
-
-import onnx
-import onnx.inliner
-
-import pytorch_test_common
-
-import torch
-from torch.testing._internal import common_utils
-
-
-def assert_op_in_onnx_model(model: onnx.ModelProto, op_type: str):
-    inlined = onnx.inliner.inline_local_functions(model)
-    for node in inlined.graph.node:
-        if node.op_type == op_type:
-            return
-    raise AssertionError(f"Op {op_type} not found in model")
-
-
-class TestDynamoExportDecompSkip(pytorch_test_common.ExportTestCase):
-    def test_upsample_bilinear2d(self):
-        class TestModel(torch.nn.Module):
-            def __init__(self) -> None:
-                super().__init__()
-                self.upsample = torch.nn.Upsample(scale_factor=2, mode="bilinear")
-
-            def forward(self, x):
-                return self.upsample(x)
-
-        onnx_program = torch.onnx.dynamo_export(TestModel(), torch.randn(1, 1, 2, 2))
-        # If decomposition is skipped, the model will contain a Resize op instead of fine grained subgraph.
-        assert_op_in_onnx_model(onnx_program.model_proto, "Resize")
-
-    def test_upsample_bilinear2d_output_size(self):
-        def func(x: torch.Tensor):
-            return torch.nn.functional.interpolate(x, size=(4, 4), mode="bilinear")
-
-        onnx_program = torch.onnx.dynamo_export(func, torch.randn(1, 1, 2, 2))
-        # If decomposition is skipped, the model will contain a Resize op instead of fine grained subgraph.
-        assert_op_in_onnx_model(onnx_program.model_proto, "Resize")
-
-    def test_upsample_trilinear3d(self):
-        class TestModel(torch.nn.Module):
-            def __init__(self) -> None:
-                super().__init__()
-                self.upsample = torch.nn.Upsample(scale_factor=2, mode="trilinear")
-
-            def forward(self, x):
-                return self.upsample(x)
-
-        onnx_program = torch.onnx.dynamo_export(TestModel(), torch.randn(1, 1, 2, 2, 3))
-        # If decomposition is skipped, the model will contain a Resize op instead of fine grained subgraph.
-        assert_op_in_onnx_model(onnx_program.model_proto, "Resize")
-
-    def test_upsample_trilinear3d_output_size(self):
-        def func(x: torch.Tensor):
-            return torch.nn.functional.interpolate(x, size=(4, 4, 4), mode="trilinear")
-
-        onnx_program = torch.onnx.dynamo_export(func, torch.randn(1, 1, 2, 2, 3))
-        # If decomposition is skipped, the model will contain a Resize op instead of fine grained subgraph.
-        assert_op_in_onnx_model(onnx_program.model_proto, "Resize")
-
-    def test_instance_norm(self):
-        class TestModel(torch.nn.Module):
-            def forward(self, x):
-                return torch.nn.functional.instance_norm(x)
-
-        onnx_program = torch.onnx.dynamo_export(TestModel(), torch.randn(1, 1, 2, 2))
-        # If decomposition is skipped, the model will contain an InstanceNormalization op
-        # instead of BatchNormalization op w/ training=True.
-        assert_op_in_onnx_model(onnx_program.model_proto, "InstanceNormalization")
-
-
-if __name__ == "__main__":
-    common_utils.run_tests()
@@ -46,4 +46,5 @@ def _load_boolean_flag(
 USE_EXPERIMENTAL_LOGIC: bool = _load_boolean_flag(
     "TORCH_ONNX_USE_EXPERIMENTAL_LOGIC",
     this_will="use ExportedProgram and the new torch.onnx export logic",
+    default=True,
 )
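Because the flag is loaded from the environment by _load_boolean_flag, the legacy path can still be selected explicitly until its planned removal. A minimal sketch, assuming the loader treats "0" as false and is evaluated at import time (neither detail is shown in this diff):

    import os

    # Assumption: set the opt-out value before importing torch so the flag is
    # read as False when torch.onnx initializes.
    os.environ["TORCH_ONNX_USE_EXPERIMENTAL_LOGIC"] = "0"

    import torch  # noqa: E402

    class SampleModel(torch.nn.Module):
        def forward(self, x):
            return (x + 1).relu()

    # With the flag disabled, dynamo_export falls back to the legacy export logic.
    onnx_program = torch.onnx.dynamo_export(SampleModel(), torch.randn(1, 1, 2))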