[ONNX] Support renaming in dynamic axes to shapes conversion (#165769)

Discovered in ##165748

This PR also deprecates the conversion. ONNX exporter team does not intend to maintain the conversion in long term.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165769
Approved by: https://github.com/justinchuby
This commit is contained in:
Ti-Tai Wang
2025-10-18 01:11:16 +00:00
committed by PyTorch MergeBot
parent e9f4999985
commit 543ddbf44c
2 changed files with 56 additions and 3 deletions

View File

@ -202,6 +202,51 @@ class TestExportAPIDynamo(common_utils.TestCase):
dynamic_axes={"b": [0, 1, 2], "b_out": [0, 1, 2]},
)
def test_from_dynamic_axes_to_dynamic_shapes_deprecation_warning(self):
with self.assertWarnsRegex(
DeprecationWarning,
"from_dynamic_axes_to_dynamic_shapes is deprecated and will be removed in a future release. "
"This function converts 'dynamic_axes' format \\(including custom axis names\\) to 'dynamic_shapes' format. "
"Instead of relying on this conversion, provide 'dynamic_shapes' directly with custom names.",
):
self.assert_export(
SampleModelForDynamicShapes(),
(torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}),
dynamic_axes={
"x": [0, 1, 2],
"b": [0, 1, 2],
},
)
def test_from_dynamic_axes_to_dynamic_shapes_keeps_custom_axis_names(self):
model = SampleModelForDynamicShapes()
input = (
torch.randn(2, 2, 3),
{"b": torch.randn(2, 2, 3)},
)
dynamic_axes = {
"x": {0: "customx_x_0", 1: "customx_x_1", 2: "customx_x_2"},
"b": {0: "customb_b_0", 1: "customb_b_1", 2: "customb_b_2"},
"x_out": {0: "customx_out_x_0", 1: "customx_out_x_1", 2: "customx_out_x_2"},
"b_out": {0: "customb_out_b_0", 1: "customb_out_b_1", 2: "customb_out_b_2"},
}
onnx_program = torch.onnx.export(
model,
input,
dynamic_axes=dynamic_axes,
input_names=["x", "b"],
output_names=["x_out", "b_out"],
dynamo=True,
)
# Check whether the dynamic dimension names are preserved
self.assertIs(onnx_program.model.graph.inputs[0].shape[0].value, "customx_x_0")
self.assertIs(onnx_program.model.graph.inputs[0].shape[1].value, "customx_x_1")
self.assertIs(onnx_program.model.graph.inputs[0].shape[2].value, "customx_x_2")
self.assertIs(onnx_program.model.graph.inputs[1].shape[0].value, "customb_b_0")
self.assertIs(onnx_program.model.graph.inputs[1].shape[1].value, "customb_b_1")
self.assertIs(onnx_program.model.graph.inputs[1].shape[2].value, "customb_b_2")
def test_saved_f_exists_after_export(self):
with common_utils.TemporaryFileName(suffix=".onnx") as path:
_ = torch.onnx.export(

View File

@ -39,6 +39,15 @@ def from_dynamic_axes_to_dynamic_shapes(
Detail on Dim.DYNAMIC: `#133620 <https://github.com/pytorch/pytorch/pull/133620>`_
"""
warnings.warn(
"from_dynamic_axes_to_dynamic_shapes is deprecated and will be removed in a future release. "
"This function converts 'dynamic_axes' format (including custom axis names) to 'dynamic_shapes' format. "
"Instead of relying on this conversion, provide 'dynamic_shapes' directly with custom names.",
DeprecationWarning,
stacklevel=2,
)
# https://github.com/pytorch/pytorch/pull/128371
# 1. The function does not need to provide dynamic_shapes to torch.export.export
if dynamic_axes is None:
@ -62,9 +71,8 @@ def from_dynamic_axes_to_dynamic_shapes(
raise ValueError(
"The axis in dynamic_axes must be in the form of: dict[int, str] or list[int]."
)
dynamic_shapes[input_name] = {
k: torch.export.Dim.DYNAMIC for k, _ in axes.items()
}
# str will be converted to Dim.DYNAMIC in convert_str_to_export_dim
dynamic_shapes[input_name] = axes
elif isinstance(axes, list):
if any(not isinstance(k, int) for k in axes):
raise ValueError(