mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[ONNX] Default to dynamo export (#159646)
Set dynamo=True and enable fallback. 1. Implemented the compatible behavior where BytesIO objects as `f` is accepted 2. Update tests to explicitly set dynamo=False #151693 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159646 Approved by: https://github.com/titaiwangms
This commit is contained in:
committed by
PyTorch MergeBot
parent
e4bd0ff4f8
commit
bd39e47fee
@ -199,6 +199,7 @@ class TestDynamicShapes(common_utils.TestCase):
|
||||
filename,
|
||||
dynamic_axes=dynamic_axes,
|
||||
input_names=input_names,
|
||||
dynamo=False,
|
||||
)
|
||||
onnx_model = onnx.load(filename)
|
||||
|
||||
|
@ -67,6 +67,7 @@ def check_onnx_opsets_operator(
|
||||
training=training,
|
||||
input_names=input_names,
|
||||
dynamic_axes=dynamic_axes,
|
||||
dynamo=False,
|
||||
)
|
||||
model = onnx.load(io.BytesIO(f.getvalue()))
|
||||
check_onnx_opset_operator(model, ops[opset_version], opset_version)
|
||||
|
@ -86,14 +86,20 @@ class TestONNXScriptExport(common_utils.TestCase):
|
||||
x = torch.randn(1, 2, 3, 4, requires_grad=True)
|
||||
model_selu = torch.nn.SELU()
|
||||
selu_onnx = io.BytesIO()
|
||||
torch.onnx.export(model_selu, x, selu_onnx, opset_version=self.opset_version)
|
||||
torch.onnx.export(
|
||||
model_selu, x, selu_onnx, opset_version=self.opset_version, dynamo=False
|
||||
)
|
||||
|
||||
N, C = 3, 4
|
||||
y = torch.randn(N, C)
|
||||
model_layer_norm = torch.nn.LayerNorm(C)
|
||||
layer_norm_onnx = io.BytesIO()
|
||||
torch.onnx.export(
|
||||
model_layer_norm, y, layer_norm_onnx, opset_version=self.opset_version
|
||||
model_layer_norm,
|
||||
y,
|
||||
layer_norm_onnx,
|
||||
opset_version=self.opset_version,
|
||||
dynamo=False,
|
||||
)
|
||||
|
||||
# 4. test on models
|
||||
@ -156,7 +162,11 @@ class TestONNXScriptExport(common_utils.TestCase):
|
||||
|
||||
saved_model = io.BytesIO()
|
||||
torch.onnx.export(
|
||||
torch.jit.script(model), inputs, f=saved_model, opset_version=15
|
||||
torch.jit.script(model),
|
||||
inputs,
|
||||
f=saved_model,
|
||||
opset_version=15,
|
||||
dynamo=False,
|
||||
)
|
||||
loop_selu_proto = onnx.load(io.BytesIO(saved_model.getvalue()))
|
||||
self.assertEqual(len(loop_selu_proto.functions), 1)
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -897,7 +897,11 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
|
||||
# export succeeds, but running ORT through run_test would fail because the exported model
|
||||
# has the inputs flattened into 3 inputs.
|
||||
torch.onnx.export(
|
||||
model, (x, {"y": (y0, y1)}), io.BytesIO(), opset_version=self.opset_version
|
||||
model,
|
||||
(x, {"y": (y0, y1)}),
|
||||
io.BytesIO(),
|
||||
opset_version=self.opset_version,
|
||||
dynamo=False,
|
||||
)
|
||||
|
||||
def test_primitive_input_integer(self):
|
||||
@ -10791,6 +10795,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
|
||||
opset_version=self.opset_version,
|
||||
do_constant_folding=False,
|
||||
training=torch.onnx.TrainingMode.TRAINING,
|
||||
dynamo=False,
|
||||
)
|
||||
ort_sess = verification._ort_session(model_onnx)
|
||||
ort_outs = verification._run_onnx(ort_sess, (x,))
|
||||
@ -10806,6 +10811,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
|
||||
opset_version=self.opset_version,
|
||||
do_constant_folding=False,
|
||||
training=torch.onnx.TrainingMode.TRAINING,
|
||||
dynamo=False,
|
||||
)
|
||||
ort_outs = verification._run_onnx(ort_sess, (x,))
|
||||
assert not torch.all(torch.eq(x, torch.from_numpy(ort_outs[0])))
|
||||
@ -10839,6 +10845,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
|
||||
opset_version=self.opset_version,
|
||||
do_constant_folding=False,
|
||||
training=torch.onnx.TrainingMode.TRAINING,
|
||||
dynamo=False,
|
||||
)
|
||||
ort_sess = verification._ort_session(model_onnx)
|
||||
ort_outs = verification._run_onnx(ort_sess, (x,))
|
||||
@ -10864,6 +10871,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
|
||||
opset_version=self.opset_version,
|
||||
do_constant_folding=False,
|
||||
training=torch.onnx.TrainingMode.TRAINING,
|
||||
dynamo=False,
|
||||
)
|
||||
ort_sess = verification._ort_session(model_onnx)
|
||||
ort_outs = verification._run_onnx(ort_sess, (x,))
|
||||
@ -12624,7 +12632,11 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
|
||||
dummy_input = (torch.tensor([expected_mean]), torch.tensor([expected_std]))
|
||||
model_onnx = io.BytesIO()
|
||||
torch.onnx.export(
|
||||
model_export, dummy_input, model_onnx, opset_version=self.opset_version
|
||||
model_export,
|
||||
dummy_input,
|
||||
model_onnx,
|
||||
opset_version=self.opset_version,
|
||||
dynamo=False,
|
||||
)
|
||||
ort_sess = verification._ort_session(model_onnx)
|
||||
ort_out = verification._run_onnx(ort_sess, inputs=dummy_input)
|
||||
@ -12655,7 +12667,11 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
|
||||
model_onnx = io.BytesIO()
|
||||
test_inputs = ()
|
||||
torch.onnx.export(
|
||||
model_export, test_inputs, model_onnx, opset_version=self.opset_version
|
||||
model_export,
|
||||
test_inputs,
|
||||
model_onnx,
|
||||
opset_version=self.opset_version,
|
||||
dynamo=False,
|
||||
)
|
||||
ort_sess = verification._ort_session(model_onnx)
|
||||
ort_out = verification._run_onnx(ort_sess, inputs=test_inputs)
|
||||
@ -12698,7 +12714,11 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
|
||||
dummy_input = (torch.tensor([expected_min]), torch.tensor([expected_max]))
|
||||
model_onnx = io.BytesIO()
|
||||
torch.onnx.export(
|
||||
model_export, dummy_input, model_onnx, opset_version=self.opset_version
|
||||
model_export,
|
||||
dummy_input,
|
||||
model_onnx,
|
||||
opset_version=self.opset_version,
|
||||
dynamo=False,
|
||||
)
|
||||
ort_sess = verification._ort_session(model_onnx)
|
||||
|
||||
@ -13705,6 +13725,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
|
||||
# Ensure condition is not constant
|
||||
dynamic_axes={"x": {0: dynamic_axis_name}},
|
||||
input_names=["x"],
|
||||
dynamo=False,
|
||||
)
|
||||
exported = onnx.load_from_string(f.getvalue())
|
||||
expected_elem_type = JitScalarType.from_value(x).onnx_type()
|
||||
|
@ -396,6 +396,7 @@ class TestONNXCustomOpShapeInference(pytorch_test_common.ExportTestCase):
|
||||
f,
|
||||
opset_version=self.opset_version,
|
||||
custom_opsets={"com.microsoft": 1},
|
||||
dynamo=False,
|
||||
)
|
||||
|
||||
model_proto = onnx.load(io.BytesIO(f.getvalue()))
|
||||
@ -430,6 +431,7 @@ class TestONNXCustomOpShapeInference(pytorch_test_common.ExportTestCase):
|
||||
f,
|
||||
opset_version=self.opset_version,
|
||||
custom_opsets={"com.microsoft": 1},
|
||||
dynamo=False,
|
||||
)
|
||||
|
||||
model_proto = onnx.load(io.BytesIO(f.getvalue()))
|
||||
@ -468,6 +470,7 @@ class TestONNXCustomOpShapeInference(pytorch_test_common.ExportTestCase):
|
||||
custom_opsets={"com.microsoft": 1},
|
||||
input_names=["x"],
|
||||
dynamic_axes={"x": {0: "batch"}},
|
||||
dynamo=False,
|
||||
)
|
||||
|
||||
model_proto = onnx.load(io.BytesIO(f.getvalue()))
|
||||
@ -508,6 +511,7 @@ class TestONNXCustomOpShapeInference(pytorch_test_common.ExportTestCase):
|
||||
f,
|
||||
opset_version=self.opset_version,
|
||||
custom_opsets={"com.microsoft": 1},
|
||||
dynamo=False,
|
||||
)
|
||||
|
||||
model_proto = onnx.load(io.BytesIO(f.getvalue()))
|
||||
|
@ -111,7 +111,9 @@ class TestUtilityFuns(_BaseTestCase):
|
||||
x = torch.randn(3, 4)
|
||||
f = io.BytesIO()
|
||||
try:
|
||||
torch.onnx.export(MyModule(), x, f, opset_version=self.opset_version)
|
||||
torch.onnx.export(
|
||||
MyModule(), x, f, opset_version=self.opset_version, dynamo=False
|
||||
)
|
||||
except ValueError:
|
||||
self.assertFalse(torch.onnx.is_in_onnx_export())
|
||||
|
||||
@ -638,7 +640,7 @@ class TestUtilityFuns(_BaseTestCase):
|
||||
model = torch.nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
|
||||
x = torch.randn(1, 32, 224, 224)
|
||||
f = io.BytesIO()
|
||||
torch.onnx.export(model, x, f)
|
||||
torch.onnx.export(model, x, f, dynamo=False)
|
||||
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
|
||||
self.assertEqual(len(onnx_model.graph.initializer), 0)
|
||||
|
||||
@ -651,10 +653,17 @@ class TestUtilityFuns(_BaseTestCase):
|
||||
|
||||
def is_model_stripped(f, verbose=None):
|
||||
if verbose is None:
|
||||
torch.onnx.export(MyModule(), x, f, opset_version=self.opset_version)
|
||||
torch.onnx.export(
|
||||
MyModule(), x, f, opset_version=self.opset_version, dynamo=False
|
||||
)
|
||||
else:
|
||||
torch.onnx.export(
|
||||
MyModule(), x, f, verbose=verbose, opset_version=self.opset_version
|
||||
MyModule(),
|
||||
x,
|
||||
f,
|
||||
verbose=verbose,
|
||||
opset_version=self.opset_version,
|
||||
dynamo=False,
|
||||
)
|
||||
model = onnx.load(io.BytesIO(f.getvalue()))
|
||||
model_strip = copy.copy(model)
|
||||
@ -677,7 +686,9 @@ class TestUtilityFuns(_BaseTestCase):
|
||||
"exporter, please use 'attribute' module to "
|
||||
"unwrap model from torch.nn.DataParallel. Try ",
|
||||
):
|
||||
torch.onnx.export(model, x, f, opset_version=self.opset_version)
|
||||
torch.onnx.export(
|
||||
model, x, f, opset_version=self.opset_version, dynamo=False
|
||||
)
|
||||
|
||||
@skipIfUnsupportedMinOpsetVersion(11)
|
||||
def test_sequence_dim(self):
|
||||
@ -701,6 +712,7 @@ class TestUtilityFuns(_BaseTestCase):
|
||||
opset_version=self.opset_version,
|
||||
input_names=["x", "y"],
|
||||
dynamic_axes={"y": [1]},
|
||||
dynamo=False,
|
||||
)
|
||||
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
|
||||
loop_output_value_info_proto = onnx_model.graph.output[0]
|
||||
@ -712,7 +724,9 @@ class TestUtilityFuns(_BaseTestCase):
|
||||
# Case 2: no dynamic axes.
|
||||
f = io.BytesIO()
|
||||
y = torch.randn(2, 3)
|
||||
torch.onnx.export(script_model, (x, y), f, opset_version=self.opset_version)
|
||||
torch.onnx.export(
|
||||
script_model, (x, y), f, opset_version=self.opset_version, dynamo=False
|
||||
)
|
||||
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
|
||||
loop_output_value_info_proto = onnx_model.graph.output[0]
|
||||
ref_value_info_proto = onnx.helper.make_tensor_sequence_value_info(
|
||||
@ -739,6 +753,7 @@ class TestUtilityFuns(_BaseTestCase):
|
||||
f,
|
||||
opset_version=self.opset_version,
|
||||
training=torch.onnx.TrainingMode.TRAINING,
|
||||
dynamo=False,
|
||||
)
|
||||
# verify that the model state is preserved
|
||||
self.assertEqual(model.training, old_state)
|
||||
@ -752,6 +767,7 @@ class TestUtilityFuns(_BaseTestCase):
|
||||
f,
|
||||
opset_version=self.opset_version,
|
||||
training=torch.onnx.TrainingMode.EVAL,
|
||||
dynamo=False,
|
||||
)
|
||||
# verify that the model state is preserved
|
||||
self.assertEqual(model.training, old_state)
|
||||
@ -779,7 +795,9 @@ class TestUtilityFuns(_BaseTestCase):
|
||||
# jit.freeze removes the training attribute in the module
|
||||
module = torch.jit.freeze(module)
|
||||
|
||||
torch.onnx.export(module, (x,), io.BytesIO(), opset_version=self.opset_version)
|
||||
torch.onnx.export(
|
||||
module, (x,), io.BytesIO(), opset_version=self.opset_version, dynamo=False
|
||||
)
|
||||
|
||||
@skipIfUnsupportedMinOpsetVersion(15)
|
||||
def test_local_function(self):
|
||||
@ -828,6 +846,7 @@ class TestUtilityFuns(_BaseTestCase):
|
||||
torch.nn.Dropout,
|
||||
torch.nn.LayerNorm,
|
||||
},
|
||||
dynamo=False,
|
||||
)
|
||||
|
||||
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
|
||||
@ -862,6 +881,7 @@ class TestUtilityFuns(_BaseTestCase):
|
||||
f,
|
||||
opset_version=self.opset_version,
|
||||
export_modules_as_functions={torch.nn.CELU},
|
||||
dynamo=False,
|
||||
)
|
||||
|
||||
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
|
||||
@ -877,6 +897,7 @@ class TestUtilityFuns(_BaseTestCase):
|
||||
f,
|
||||
opset_version=self.opset_version,
|
||||
export_modules_as_functions=set(),
|
||||
dynamo=False,
|
||||
)
|
||||
|
||||
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
|
||||
@ -891,6 +912,7 @@ class TestUtilityFuns(_BaseTestCase):
|
||||
f,
|
||||
opset_version=self.opset_version,
|
||||
export_modules_as_functions=True,
|
||||
dynamo=False,
|
||||
)
|
||||
|
||||
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
|
||||
@ -927,6 +949,7 @@ class TestUtilityFuns(_BaseTestCase):
|
||||
f,
|
||||
opset_version=self.opset_version,
|
||||
export_modules_as_functions={NWithOverloads},
|
||||
dynamo=False,
|
||||
)
|
||||
|
||||
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
|
||||
@ -956,6 +979,7 @@ class TestUtilityFuns(_BaseTestCase):
|
||||
export_modules_as_functions=True,
|
||||
opset_version=self.opset_version,
|
||||
do_constant_folding=False,
|
||||
dynamo=False,
|
||||
)
|
||||
|
||||
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
|
||||
@ -988,6 +1012,7 @@ class TestUtilityFuns(_BaseTestCase):
|
||||
f,
|
||||
export_modules_as_functions=True,
|
||||
opset_version=self.opset_version,
|
||||
dynamo=False,
|
||||
)
|
||||
|
||||
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
|
||||
@ -1053,6 +1078,7 @@ class TestUtilityFuns(_BaseTestCase):
|
||||
export_modules_as_functions=True,
|
||||
opset_version=self.opset_version,
|
||||
verbose=True, # Allows the test case to print `Skipping module attribute 'freeze'`
|
||||
dynamo=False,
|
||||
)
|
||||
|
||||
def test_node_scope(self):
|
||||
@ -1297,6 +1323,7 @@ class TestUtilityFuns(_BaseTestCase):
|
||||
f,
|
||||
opset_version=self.opset_version,
|
||||
custom_opsets={"com.microsoft": 1},
|
||||
dynamo=False,
|
||||
)
|
||||
|
||||
graph = onnx.load(io.BytesIO(f.getvalue()))
|
||||
@ -1317,7 +1344,9 @@ class TestUtilityFuns(_BaseTestCase):
|
||||
model = torch.nn.GELU(approximate="none")
|
||||
x = torch.randn(3, 3)
|
||||
f = io.BytesIO()
|
||||
torch.onnx.export(model, (x,), f, opset_version=self.opset_version)
|
||||
torch.onnx.export(
|
||||
model, (x,), f, opset_version=self.opset_version, dynamo=False
|
||||
)
|
||||
graph = onnx.load(io.BytesIO(f.getvalue()))
|
||||
|
||||
self.assertEqual(graph.graph.node[0].op_type, "Gelu")
|
||||
@ -1344,6 +1373,7 @@ class TestUtilityFuns(_BaseTestCase):
|
||||
f,
|
||||
opset_version=self.opset_version,
|
||||
custom_opsets={"com.microsoft": 1},
|
||||
dynamo=False,
|
||||
)
|
||||
|
||||
graph = onnx.load(io.BytesIO(f.getvalue()))
|
||||
@ -1647,6 +1677,7 @@ class TestUtilityFuns(_BaseTestCase):
|
||||
f,
|
||||
opset_version=self.opset_version,
|
||||
keep_initializers_as_inputs=True,
|
||||
dynamo=False,
|
||||
)
|
||||
graph = onnx.load(io.BytesIO(f.getvalue()))
|
||||
self.assertEqual(graph.graph.input[1].name, "in_weight")
|
||||
@ -1679,13 +1710,19 @@ class TestUtilityFuns(_BaseTestCase):
|
||||
]
|
||||
f = io.BytesIO()
|
||||
|
||||
torch.onnx.export(module, torch.ones(1, 10), f, output_names=["y"])
|
||||
torch.onnx.export(
|
||||
module, torch.ones(1, 10), f, output_names=["y"], dynamo=False
|
||||
)
|
||||
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
|
||||
for n in onnx_model.graph.node:
|
||||
self.assertIn(n.name, ref_node_names)
|
||||
|
||||
torch.onnx.export(
|
||||
torch.jit.script(module), torch.ones(1, 10), f, output_names=["y"]
|
||||
torch.jit.script(module),
|
||||
torch.ones(1, 10),
|
||||
f,
|
||||
output_names=["y"],
|
||||
dynamo=False,
|
||||
)
|
||||
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
|
||||
for n in onnx_model.graph.node:
|
||||
@ -1728,6 +1765,7 @@ class TestUtilityFuns(_BaseTestCase):
|
||||
f,
|
||||
training=TrainingMode.TRAINING,
|
||||
opset_version=self.opset_version,
|
||||
dynamo=False,
|
||||
)
|
||||
graph = onnx.load(io.BytesIO(f.getvalue()))
|
||||
self.assertSetEqual({i.name for i in graph.graph.initializer}, param_name_set)
|
||||
@ -1740,6 +1778,7 @@ class TestUtilityFuns(_BaseTestCase):
|
||||
f,
|
||||
training=TrainingMode.PRESERVE,
|
||||
opset_version=self.opset_version,
|
||||
dynamo=False,
|
||||
)
|
||||
graph = onnx.load(io.BytesIO(f.getvalue()))
|
||||
self.assertSetEqual({i.name for i in graph.graph.initializer}, param_name_set)
|
||||
@ -1747,7 +1786,9 @@ class TestUtilityFuns(_BaseTestCase):
|
||||
# Test eval mode.
|
||||
model.eval()
|
||||
f = io.BytesIO()
|
||||
torch.onnx.export(model, (x,), f, opset_version=self.opset_version)
|
||||
torch.onnx.export(
|
||||
model, (x,), f, opset_version=self.opset_version, dynamo=False
|
||||
)
|
||||
graph = onnx.load(io.BytesIO(f.getvalue()))
|
||||
param_name_set.remove("param2")
|
||||
self.assertSetEqual({i.name for i in graph.graph.initializer}, param_name_set)
|
||||
@ -1776,7 +1817,9 @@ class TestUtilityFuns(_BaseTestCase):
|
||||
x = torch.randn(3, 3, device=torch.device("cpu"))
|
||||
y = torch.randn(3, 3, device=torch.device("cuda"))
|
||||
f = io.BytesIO()
|
||||
torch.onnx.export(Model(), (x, y), f, opset_version=self.opset_version)
|
||||
torch.onnx.export(
|
||||
Model(), (x, y), f, opset_version=self.opset_version, dynamo=False
|
||||
)
|
||||
graph = onnx.load(io.BytesIO(f.getvalue()))
|
||||
self.assertSetEqual({i.name for i in graph.graph.initializer}, {"w_cpu"})
|
||||
|
||||
@ -1817,6 +1860,7 @@ class TestUtilityFuns(_BaseTestCase):
|
||||
dynamic_axes=dynamic_axes,
|
||||
verbose=True,
|
||||
keep_initializers_as_inputs=True,
|
||||
dynamo=False,
|
||||
)
|
||||
|
||||
graph = onnx.load(io.BytesIO(f.getvalue()))
|
||||
@ -1844,7 +1888,7 @@ class TestUtilityFuns(_BaseTestCase):
|
||||
|
||||
f = io.BytesIO()
|
||||
x = torch.randn(1, 32, 224, 224)
|
||||
torch.onnx.export(Model(), x, f)
|
||||
torch.onnx.export(Model(), x, f, dynamo=False)
|
||||
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
|
||||
# aten::upsample converts to onnx::resize
|
||||
resize_nodes = [n for n in onnx_model.graph.node if n.op_type == "Resize"]
|
||||
@ -1876,7 +1920,7 @@ class TestUtilityFuns(_BaseTestCase):
|
||||
self.assertExpectedRaisesInline(
|
||||
AssertionError,
|
||||
lambda: torch.onnx.export(
|
||||
model, (x,), f, opset_version=_onnx_opset_version
|
||||
model, (x,), f, opset_version=_onnx_opset_version, dynamo=False
|
||||
),
|
||||
(
|
||||
"A mismatch between the number of arguments (2) and their descriptors (1) was found at symbolic function "
|
||||
|
@ -74,7 +74,7 @@ def export(
|
||||
| Mapping[str, Sequence[int]]
|
||||
| None = None,
|
||||
keep_initializers_as_inputs: bool = False,
|
||||
dynamo: bool = False,
|
||||
dynamo: bool = True,
|
||||
# Dynamo only options
|
||||
external_data: bool = True,
|
||||
dynamic_shapes: dict[str, Any] | tuple[Any, ...] | list[Any] | None = None,
|
||||
@ -86,7 +86,7 @@ def export(
|
||||
profile: bool = False,
|
||||
dump_exported_program: bool = False,
|
||||
artifacts_dir: str | os.PathLike = ".",
|
||||
fallback: bool = False,
|
||||
fallback: bool = True,
|
||||
# Deprecated options
|
||||
training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL,
|
||||
operator_export_type: _C_onnx.OperatorExportTypes = _C_onnx.OperatorExportTypes.ONNX,
|
||||
|
@ -4,6 +4,7 @@
|
||||
# mypy: disable-error-code=attr-defined
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import logging
|
||||
import warnings
|
||||
from collections.abc import Mapping, Sequence
|
||||
@ -11,7 +12,7 @@ from typing import Any, Callable, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
from torch.onnx import _constants as onnx_constants
|
||||
from torch.onnx._internal._lazy_import import onnxscript_apis, onnxscript_ir as ir
|
||||
from torch.onnx._internal._lazy_import import onnx, onnxscript_apis, onnxscript_ir as ir
|
||||
from torch.onnx._internal.exporter import (
|
||||
_constants,
|
||||
_core,
|
||||
@ -61,12 +62,12 @@ def export_compat(
|
||||
keep_initializers_as_inputs: bool = False,
|
||||
external_data: bool = True,
|
||||
report: bool = False,
|
||||
optimize: bool = False,
|
||||
optimize: bool = True,
|
||||
verify: bool = False,
|
||||
profile: bool = False,
|
||||
dump_exported_program: bool = False,
|
||||
artifacts_dir: str | os.PathLike = ".",
|
||||
fallback: bool = False,
|
||||
fallback: bool = True,
|
||||
# Legacy export parameters for fallback
|
||||
legacy_export_kwargs: dict[str, Any] | None = None,
|
||||
) -> _onnx_program.ONNXProgram:
|
||||
@ -211,11 +212,23 @@ def export_compat(
|
||||
onnx_program.optimize()
|
||||
|
||||
if f is not None:
|
||||
onnx_program.save(
|
||||
f,
|
||||
include_initializers=export_params,
|
||||
keep_initializers_as_inputs=keep_initializers_as_inputs,
|
||||
external_data=external_data,
|
||||
)
|
||||
if isinstance(f, io.BytesIO):
|
||||
# For legacy export compatibility, we allow f to be a BytesIO object.
|
||||
# This is not explicitly supported but we may need to maintain the
|
||||
# behavior indefinitely.
|
||||
warnings.warn(
|
||||
"Saving ONNX model to a BytesIO object is deprecated. "
|
||||
"Please use a file path instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
onnx.save(onnx_program.model_proto, f)
|
||||
else:
|
||||
onnx_program.save(
|
||||
f,
|
||||
include_initializers=export_params,
|
||||
keep_initializers_as_inputs=keep_initializers_as_inputs,
|
||||
external_data=external_data,
|
||||
)
|
||||
|
||||
return onnx_program
|
||||
|
Reference in New Issue
Block a user