[ONNX] Default to dynamo export (#159646)

Set `dynamo=True` by default and enable fallback to the legacy TorchScript exporter.

1. Implemented compatible behavior so that `BytesIO` objects are accepted as `f`
2. Updated tests to explicitly set `dynamo=False`

#151693
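With this change, a bare `torch.onnx.export(...)` call goes through the dynamo-based exporter and, because `fallback=True` is now also the default, falls back to the TorchScript exporter if that fails. A minimal sketch of the before/after behavior (the `Model` module, shapes, and file name are illustrative placeholders, not part of this PR):

```python
import io

import torch


class Model(torch.nn.Module):
    def forward(self, x):
        return x.relu()


x = torch.randn(2, 3)

# New default: equivalent to dynamo=True, fallback=True.
torch.onnx.export(Model(), (x,), "model.onnx")

# The legacy TorchScript exporter must now be requested explicitly,
# which is what the test updates below do.
torch.onnx.export(Model(), (x,), io.BytesIO(), dynamo=False)
```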

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159646
Approved by: https://github.com/titaiwangms
Author: Justin Chu
Date: 2025-09-02 22:45:55 +00:00
Committed by: PyTorch MergeBot
Parent: e4bd0ff4f8
Commit: bd39e47fee
9 changed files with 126 additions and 1257 deletions


@@ -199,6 +199,7 @@ class TestDynamicShapes(common_utils.TestCase):
             filename,
             dynamic_axes=dynamic_axes,
             input_names=input_names,
+            dynamo=False,
         )
         onnx_model = onnx.load(filename)


@@ -67,6 +67,7 @@ def check_onnx_opsets_operator(
             training=training,
             input_names=input_names,
             dynamic_axes=dynamic_axes,
+            dynamo=False,
         )
         model = onnx.load(io.BytesIO(f.getvalue()))
         check_onnx_opset_operator(model, ops[opset_version], opset_version)


@@ -86,14 +86,20 @@ class TestONNXScriptExport(common_utils.TestCase):
         x = torch.randn(1, 2, 3, 4, requires_grad=True)
         model_selu = torch.nn.SELU()
         selu_onnx = io.BytesIO()
-        torch.onnx.export(model_selu, x, selu_onnx, opset_version=self.opset_version)
+        torch.onnx.export(
+            model_selu, x, selu_onnx, opset_version=self.opset_version, dynamo=False
+        )

         N, C = 3, 4
         y = torch.randn(N, C)
         model_layer_norm = torch.nn.LayerNorm(C)
         layer_norm_onnx = io.BytesIO()
         torch.onnx.export(
-            model_layer_norm, y, layer_norm_onnx, opset_version=self.opset_version
+            model_layer_norm,
+            y,
+            layer_norm_onnx,
+            opset_version=self.opset_version,
+            dynamo=False,
         )

         # 4. test on models
@@ -156,7 +162,11 @@ class TestONNXScriptExport(common_utils.TestCase):
         saved_model = io.BytesIO()
         torch.onnx.export(
-            torch.jit.script(model), inputs, f=saved_model, opset_version=15
+            torch.jit.script(model),
+            inputs,
+            f=saved_model,
+            opset_version=15,
+            dynamo=False,
         )
         loop_selu_proto = onnx.load(io.BytesIO(saved_model.getvalue()))
         self.assertEqual(len(loop_selu_proto.functions), 1)

File diff suppressed because it is too large


@@ -897,7 +897,11 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
         # export succeeds, but running ORT through run_test would fail because the exported model
         # has the inputs flattened into 3 inputs.
         torch.onnx.export(
-            model, (x, {"y": (y0, y1)}), io.BytesIO(), opset_version=self.opset_version
+            model,
+            (x, {"y": (y0, y1)}),
+            io.BytesIO(),
+            opset_version=self.opset_version,
+            dynamo=False,
         )

     def test_primitive_input_integer(self):
@@ -10791,6 +10795,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
             opset_version=self.opset_version,
             do_constant_folding=False,
             training=torch.onnx.TrainingMode.TRAINING,
+            dynamo=False,
         )
         ort_sess = verification._ort_session(model_onnx)
         ort_outs = verification._run_onnx(ort_sess, (x,))
@@ -10806,6 +10811,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
             opset_version=self.opset_version,
             do_constant_folding=False,
             training=torch.onnx.TrainingMode.TRAINING,
+            dynamo=False,
         )
         ort_outs = verification._run_onnx(ort_sess, (x,))
         assert not torch.all(torch.eq(x, torch.from_numpy(ort_outs[0])))
@@ -10839,6 +10845,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
             opset_version=self.opset_version,
             do_constant_folding=False,
             training=torch.onnx.TrainingMode.TRAINING,
+            dynamo=False,
         )
         ort_sess = verification._ort_session(model_onnx)
         ort_outs = verification._run_onnx(ort_sess, (x,))
@@ -10864,6 +10871,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
             opset_version=self.opset_version,
             do_constant_folding=False,
             training=torch.onnx.TrainingMode.TRAINING,
+            dynamo=False,
         )
         ort_sess = verification._ort_session(model_onnx)
         ort_outs = verification._run_onnx(ort_sess, (x,))
@@ -12624,7 +12632,11 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
         dummy_input = (torch.tensor([expected_mean]), torch.tensor([expected_std]))
         model_onnx = io.BytesIO()
         torch.onnx.export(
-            model_export, dummy_input, model_onnx, opset_version=self.opset_version
+            model_export,
+            dummy_input,
+            model_onnx,
+            opset_version=self.opset_version,
+            dynamo=False,
         )
         ort_sess = verification._ort_session(model_onnx)
         ort_out = verification._run_onnx(ort_sess, inputs=dummy_input)
@@ -12655,7 +12667,11 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
         model_onnx = io.BytesIO()
         test_inputs = ()
         torch.onnx.export(
-            model_export, test_inputs, model_onnx, opset_version=self.opset_version
+            model_export,
+            test_inputs,
+            model_onnx,
+            opset_version=self.opset_version,
+            dynamo=False,
         )
         ort_sess = verification._ort_session(model_onnx)
         ort_out = verification._run_onnx(ort_sess, inputs=test_inputs)
@@ -12698,7 +12714,11 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
         dummy_input = (torch.tensor([expected_min]), torch.tensor([expected_max]))
         model_onnx = io.BytesIO()
         torch.onnx.export(
-            model_export, dummy_input, model_onnx, opset_version=self.opset_version
+            model_export,
+            dummy_input,
+            model_onnx,
+            opset_version=self.opset_version,
+            dynamo=False,
         )
         ort_sess = verification._ort_session(model_onnx)
@@ -13705,6 +13725,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
             # Ensure condition is not constant
             dynamic_axes={"x": {0: dynamic_axis_name}},
             input_names=["x"],
+            dynamo=False,
         )
         exported = onnx.load_from_string(f.getvalue())
         expected_elem_type = JitScalarType.from_value(x).onnx_type()


@@ -396,6 +396,7 @@ class TestONNXCustomOpShapeInference(pytorch_test_common.ExportTestCase):
             f,
             opset_version=self.opset_version,
             custom_opsets={"com.microsoft": 1},
+            dynamo=False,
         )

         model_proto = onnx.load(io.BytesIO(f.getvalue()))
@@ -430,6 +431,7 @@ class TestONNXCustomOpShapeInference(pytorch_test_common.ExportTestCase):
             f,
             opset_version=self.opset_version,
             custom_opsets={"com.microsoft": 1},
+            dynamo=False,
         )

         model_proto = onnx.load(io.BytesIO(f.getvalue()))
@@ -468,6 +470,7 @@ class TestONNXCustomOpShapeInference(pytorch_test_common.ExportTestCase):
             custom_opsets={"com.microsoft": 1},
             input_names=["x"],
             dynamic_axes={"x": {0: "batch"}},
+            dynamo=False,
         )

         model_proto = onnx.load(io.BytesIO(f.getvalue()))
@@ -508,6 +511,7 @@ class TestONNXCustomOpShapeInference(pytorch_test_common.ExportTestCase):
             f,
             opset_version=self.opset_version,
             custom_opsets={"com.microsoft": 1},
+            dynamo=False,
         )

         model_proto = onnx.load(io.BytesIO(f.getvalue()))


@@ -111,7 +111,9 @@ class TestUtilityFuns(_BaseTestCase):
         x = torch.randn(3, 4)
         f = io.BytesIO()
         try:
-            torch.onnx.export(MyModule(), x, f, opset_version=self.opset_version)
+            torch.onnx.export(
+                MyModule(), x, f, opset_version=self.opset_version, dynamo=False
+            )
         except ValueError:
             self.assertFalse(torch.onnx.is_in_onnx_export())
@@ -638,7 +640,7 @@ class TestUtilityFuns(_BaseTestCase):
         model = torch.nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
         x = torch.randn(1, 32, 224, 224)
         f = io.BytesIO()
-        torch.onnx.export(model, x, f)
+        torch.onnx.export(model, x, f, dynamo=False)
         onnx_model = onnx.load(io.BytesIO(f.getvalue()))

         self.assertEqual(len(onnx_model.graph.initializer), 0)
@@ -651,10 +653,17 @@ class TestUtilityFuns(_BaseTestCase):
         def is_model_stripped(f, verbose=None):
             if verbose is None:
-                torch.onnx.export(MyModule(), x, f, opset_version=self.opset_version)
+                torch.onnx.export(
+                    MyModule(), x, f, opset_version=self.opset_version, dynamo=False
+                )
             else:
                 torch.onnx.export(
-                    MyModule(), x, f, verbose=verbose, opset_version=self.opset_version
+                    MyModule(),
+                    x,
+                    f,
+                    verbose=verbose,
+                    opset_version=self.opset_version,
+                    dynamo=False,
                 )
             model = onnx.load(io.BytesIO(f.getvalue()))
             model_strip = copy.copy(model)
@@ -677,7 +686,9 @@ class TestUtilityFuns(_BaseTestCase):
             "exporter, please use 'attribute' module to "
             "unwrap model from torch.nn.DataParallel. Try ",
         ):
-            torch.onnx.export(model, x, f, opset_version=self.opset_version)
+            torch.onnx.export(
+                model, x, f, opset_version=self.opset_version, dynamo=False
+            )

     @skipIfUnsupportedMinOpsetVersion(11)
     def test_sequence_dim(self):
@@ -701,6 +712,7 @@ class TestUtilityFuns(_BaseTestCase):
             opset_version=self.opset_version,
             input_names=["x", "y"],
             dynamic_axes={"y": [1]},
+            dynamo=False,
         )
         onnx_model = onnx.load(io.BytesIO(f.getvalue()))
         loop_output_value_info_proto = onnx_model.graph.output[0]
@@ -712,7 +724,9 @@ class TestUtilityFuns(_BaseTestCase):
         # Case 2: no dynamic axes.
         f = io.BytesIO()
         y = torch.randn(2, 3)
-        torch.onnx.export(script_model, (x, y), f, opset_version=self.opset_version)
+        torch.onnx.export(
+            script_model, (x, y), f, opset_version=self.opset_version, dynamo=False
+        )
         onnx_model = onnx.load(io.BytesIO(f.getvalue()))
         loop_output_value_info_proto = onnx_model.graph.output[0]
         ref_value_info_proto = onnx.helper.make_tensor_sequence_value_info(
@@ -739,6 +753,7 @@ class TestUtilityFuns(_BaseTestCase):
             f,
             opset_version=self.opset_version,
             training=torch.onnx.TrainingMode.TRAINING,
+            dynamo=False,
         )
         # verify that the model state is preserved
         self.assertEqual(model.training, old_state)
@@ -752,6 +767,7 @@ class TestUtilityFuns(_BaseTestCase):
             f,
             opset_version=self.opset_version,
             training=torch.onnx.TrainingMode.EVAL,
+            dynamo=False,
         )
         # verify that the model state is preserved
         self.assertEqual(model.training, old_state)
@@ -779,7 +795,9 @@ class TestUtilityFuns(_BaseTestCase):
         # jit.freeze removes the training attribute in the module
         module = torch.jit.freeze(module)

-        torch.onnx.export(module, (x,), io.BytesIO(), opset_version=self.opset_version)
+        torch.onnx.export(
+            module, (x,), io.BytesIO(), opset_version=self.opset_version, dynamo=False
+        )

     @skipIfUnsupportedMinOpsetVersion(15)
     def test_local_function(self):
@@ -828,6 +846,7 @@ class TestUtilityFuns(_BaseTestCase):
                 torch.nn.Dropout,
                 torch.nn.LayerNorm,
             },
+            dynamo=False,
         )

         onnx_model = onnx.load(io.BytesIO(f.getvalue()))
@@ -862,6 +881,7 @@ class TestUtilityFuns(_BaseTestCase):
             f,
             opset_version=self.opset_version,
             export_modules_as_functions={torch.nn.CELU},
+            dynamo=False,
         )

         onnx_model = onnx.load(io.BytesIO(f.getvalue()))
@@ -877,6 +897,7 @@ class TestUtilityFuns(_BaseTestCase):
             f,
             opset_version=self.opset_version,
             export_modules_as_functions=set(),
+            dynamo=False,
         )

         onnx_model = onnx.load(io.BytesIO(f.getvalue()))
@@ -891,6 +912,7 @@ class TestUtilityFuns(_BaseTestCase):
             f,
             opset_version=self.opset_version,
             export_modules_as_functions=True,
+            dynamo=False,
         )

         onnx_model = onnx.load(io.BytesIO(f.getvalue()))
@@ -927,6 +949,7 @@ class TestUtilityFuns(_BaseTestCase):
             f,
             opset_version=self.opset_version,
             export_modules_as_functions={NWithOverloads},
+            dynamo=False,
         )

         onnx_model = onnx.load(io.BytesIO(f.getvalue()))
@@ -956,6 +979,7 @@ class TestUtilityFuns(_BaseTestCase):
             export_modules_as_functions=True,
             opset_version=self.opset_version,
             do_constant_folding=False,
+            dynamo=False,
         )

         onnx_model = onnx.load(io.BytesIO(f.getvalue()))
@@ -988,6 +1012,7 @@ class TestUtilityFuns(_BaseTestCase):
             f,
             export_modules_as_functions=True,
             opset_version=self.opset_version,
+            dynamo=False,
         )

         onnx_model = onnx.load(io.BytesIO(f.getvalue()))
@@ -1053,6 +1078,7 @@ class TestUtilityFuns(_BaseTestCase):
             export_modules_as_functions=True,
             opset_version=self.opset_version,
             verbose=True,  # Allows the test case to print `Skipping module attribute 'freeze'`
+            dynamo=False,
         )

     def test_node_scope(self):
@@ -1297,6 +1323,7 @@ class TestUtilityFuns(_BaseTestCase):
             f,
             opset_version=self.opset_version,
             custom_opsets={"com.microsoft": 1},
+            dynamo=False,
         )

         graph = onnx.load(io.BytesIO(f.getvalue()))
@@ -1317,7 +1344,9 @@ class TestUtilityFuns(_BaseTestCase):
         model = torch.nn.GELU(approximate="none")
         x = torch.randn(3, 3)
         f = io.BytesIO()
-        torch.onnx.export(model, (x,), f, opset_version=self.opset_version)
+        torch.onnx.export(
+            model, (x,), f, opset_version=self.opset_version, dynamo=False
+        )
         graph = onnx.load(io.BytesIO(f.getvalue()))

         self.assertEqual(graph.graph.node[0].op_type, "Gelu")
@@ -1344,6 +1373,7 @@ class TestUtilityFuns(_BaseTestCase):
             f,
             opset_version=self.opset_version,
             custom_opsets={"com.microsoft": 1},
+            dynamo=False,
         )

         graph = onnx.load(io.BytesIO(f.getvalue()))
@@ -1647,6 +1677,7 @@ class TestUtilityFuns(_BaseTestCase):
             f,
             opset_version=self.opset_version,
             keep_initializers_as_inputs=True,
+            dynamo=False,
         )
         graph = onnx.load(io.BytesIO(f.getvalue()))
         self.assertEqual(graph.graph.input[1].name, "in_weight")
@ -1679,13 +1710,19 @@ class TestUtilityFuns(_BaseTestCase):
]
f = io.BytesIO()
torch.onnx.export(module, torch.ones(1, 10), f, output_names=["y"])
torch.onnx.export(
module, torch.ones(1, 10), f, output_names=["y"], dynamo=False
)
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
for n in onnx_model.graph.node:
self.assertIn(n.name, ref_node_names)
torch.onnx.export(
torch.jit.script(module), torch.ones(1, 10), f, output_names=["y"]
torch.jit.script(module),
torch.ones(1, 10),
f,
output_names=["y"],
dynamo=False,
)
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
for n in onnx_model.graph.node:
@@ -1728,6 +1765,7 @@ class TestUtilityFuns(_BaseTestCase):
             f,
             training=TrainingMode.TRAINING,
             opset_version=self.opset_version,
+            dynamo=False,
         )
         graph = onnx.load(io.BytesIO(f.getvalue()))
         self.assertSetEqual({i.name for i in graph.graph.initializer}, param_name_set)
@@ -1740,6 +1778,7 @@ class TestUtilityFuns(_BaseTestCase):
             f,
             training=TrainingMode.PRESERVE,
             opset_version=self.opset_version,
+            dynamo=False,
         )
         graph = onnx.load(io.BytesIO(f.getvalue()))
         self.assertSetEqual({i.name for i in graph.graph.initializer}, param_name_set)
@@ -1747,7 +1786,9 @@ class TestUtilityFuns(_BaseTestCase):
         # Test eval mode.
         model.eval()
         f = io.BytesIO()
-        torch.onnx.export(model, (x,), f, opset_version=self.opset_version)
+        torch.onnx.export(
+            model, (x,), f, opset_version=self.opset_version, dynamo=False
+        )
         graph = onnx.load(io.BytesIO(f.getvalue()))
         param_name_set.remove("param2")
         self.assertSetEqual({i.name for i in graph.graph.initializer}, param_name_set)
@@ -1776,7 +1817,9 @@ class TestUtilityFuns(_BaseTestCase):
         x = torch.randn(3, 3, device=torch.device("cpu"))
         y = torch.randn(3, 3, device=torch.device("cuda"))
         f = io.BytesIO()
-        torch.onnx.export(Model(), (x, y), f, opset_version=self.opset_version)
+        torch.onnx.export(
+            Model(), (x, y), f, opset_version=self.opset_version, dynamo=False
+        )
         graph = onnx.load(io.BytesIO(f.getvalue()))
         self.assertSetEqual({i.name for i in graph.graph.initializer}, {"w_cpu"})
@@ -1817,6 +1860,7 @@ class TestUtilityFuns(_BaseTestCase):
             dynamic_axes=dynamic_axes,
             verbose=True,
             keep_initializers_as_inputs=True,
+            dynamo=False,
         )

         graph = onnx.load(io.BytesIO(f.getvalue()))
@@ -1844,7 +1888,7 @@ class TestUtilityFuns(_BaseTestCase):
         f = io.BytesIO()
         x = torch.randn(1, 32, 224, 224)
-        torch.onnx.export(Model(), x, f)
+        torch.onnx.export(Model(), x, f, dynamo=False)
         onnx_model = onnx.load(io.BytesIO(f.getvalue()))

         # aten::upsample converts to onnx::resize
         resize_nodes = [n for n in onnx_model.graph.node if n.op_type == "Resize"]
@@ -1876,7 +1920,7 @@ class TestUtilityFuns(_BaseTestCase):
         self.assertExpectedRaisesInline(
             AssertionError,
             lambda: torch.onnx.export(
-                model, (x,), f, opset_version=_onnx_opset_version
+                model, (x,), f, opset_version=_onnx_opset_version, dynamo=False
             ),
             (
                 "A mismatch between the number of arguments (2) and their descriptors (1) was found at symbolic function "


@@ -74,7 +74,7 @@ def export(
     | Mapping[str, Sequence[int]]
     | None = None,
     keep_initializers_as_inputs: bool = False,
-    dynamo: bool = False,
+    dynamo: bool = True,
     # Dynamo only options
     external_data: bool = True,
     dynamic_shapes: dict[str, Any] | tuple[Any, ...] | list[Any] | None = None,
@@ -86,7 +86,7 @@
     profile: bool = False,
     dump_exported_program: bool = False,
     artifacts_dir: str | os.PathLike = ".",
-    fallback: bool = False,
+    fallback: bool = True,
     # Deprecated options
     training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL,
     operator_export_type: _C_onnx.OperatorExportTypes = _C_onnx.OperatorExportTypes.ONNX,


@@ -4,6 +4,7 @@
 # mypy: disable-error-code=attr-defined
 from __future__ import annotations

+import io
 import logging
 import warnings
 from collections.abc import Mapping, Sequence
@@ -11,7 +12,7 @@ from typing import Any, Callable, TYPE_CHECKING

 import torch
 from torch.onnx import _constants as onnx_constants
-from torch.onnx._internal._lazy_import import onnxscript_apis, onnxscript_ir as ir
+from torch.onnx._internal._lazy_import import onnx, onnxscript_apis, onnxscript_ir as ir
 from torch.onnx._internal.exporter import (
     _constants,
     _core,
@@ -61,12 +62,12 @@ def export_compat(
     keep_initializers_as_inputs: bool = False,
     external_data: bool = True,
     report: bool = False,
-    optimize: bool = False,
+    optimize: bool = True,
     verify: bool = False,
     profile: bool = False,
     dump_exported_program: bool = False,
     artifacts_dir: str | os.PathLike = ".",
-    fallback: bool = False,
+    fallback: bool = True,
     # Legacy export parameters for fallback
     legacy_export_kwargs: dict[str, Any] | None = None,
 ) -> _onnx_program.ONNXProgram:
@@ -211,11 +212,23 @@
             onnx_program.optimize()

     if f is not None:
-        onnx_program.save(
-            f,
-            include_initializers=export_params,
-            keep_initializers_as_inputs=keep_initializers_as_inputs,
-            external_data=external_data,
-        )
+        if isinstance(f, io.BytesIO):
+            # For legacy export compatibility, we allow f to be a BytesIO object.
+            # This is not explicitly supported but we may need to maintain the
+            # behavior indefinitely.
+            warnings.warn(
+                "Saving ONNX model to a BytesIO object is deprecated. "
+                "Please use a file path instead.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            onnx.save(onnx_program.model_proto, f)
+        else:
+            onnx_program.save(
+                f,
+                include_initializers=export_params,
+                keep_initializers_as_inputs=keep_initializers_as_inputs,
+                external_data=external_data,
+            )

     return onnx_program
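Taken together with the `export_compat` change above, an existing caller that exports into a `BytesIO` keeps working under the new default: the model proto is written into the buffer via `onnx.save` and a `DeprecationWarning` is emitted. A minimal sketch of the preserved round-trip (the toy module and input are placeholders):

```python
import io
import warnings

import onnx
import torch


class Model(torch.nn.Module):
    def forward(self, x):
        return x + 1


f = io.BytesIO()
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # dynamo=True is now the default; a BytesIO destination takes the
    # compatibility branch added above, which warns and calls onnx.save().
    torch.onnx.export(Model(), (torch.randn(2),), f)

assert any(issubclass(w.category, DeprecationWarning) for w in caught)
# The buffer still holds a complete ModelProto, as with the legacy exporter.
onnx.load(io.BytesIO(f.getvalue()))
```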