From 543ddbf44c06640b424abf72a6469dddc829809f Mon Sep 17 00:00:00 2001 From: Ti-Tai Wang Date: Sat, 18 Oct 2025 01:11:16 +0000 Subject: [PATCH] [ONNX] Support renaming in dynamic axes to shapes conversion (#165769) Discovered in ##165748 This PR also deprecates the conversion. ONNX exporter team does not intend to maintain the conversion in long term. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165769 Approved by: https://github.com/justinchuby --- test/onnx/exporter/test_api.py | 45 +++++++++++++++++++ .../_internal/exporter/_dynamic_shapes.py | 14 ++++-- 2 files changed, 56 insertions(+), 3 deletions(-) diff --git a/test/onnx/exporter/test_api.py b/test/onnx/exporter/test_api.py index 24a9176bbe5b..7e6a487e18f5 100644 --- a/test/onnx/exporter/test_api.py +++ b/test/onnx/exporter/test_api.py @@ -202,6 +202,51 @@ class TestExportAPIDynamo(common_utils.TestCase): dynamic_axes={"b": [0, 1, 2], "b_out": [0, 1, 2]}, ) + def test_from_dynamic_axes_to_dynamic_shapes_deprecation_warning(self): + with self.assertWarnsRegex( + DeprecationWarning, + "from_dynamic_axes_to_dynamic_shapes is deprecated and will be removed in a future release. " + "This function converts 'dynamic_axes' format \\(including custom axis names\\) to 'dynamic_shapes' format. " + "Instead of relying on this conversion, provide 'dynamic_shapes' directly with custom names.", + ): + self.assert_export( + SampleModelForDynamicShapes(), + (torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}), + dynamic_axes={ + "x": [0, 1, 2], + "b": [0, 1, 2], + }, + ) + + def test_from_dynamic_axes_to_dynamic_shapes_keeps_custom_axis_names(self): + model = SampleModelForDynamicShapes() + input = ( + torch.randn(2, 2, 3), + {"b": torch.randn(2, 2, 3)}, + ) + dynamic_axes = { + "x": {0: "customx_x_0", 1: "customx_x_1", 2: "customx_x_2"}, + "b": {0: "customb_b_0", 1: "customb_b_1", 2: "customb_b_2"}, + "x_out": {0: "customx_out_x_0", 1: "customx_out_x_1", 2: "customx_out_x_2"}, + "b_out": {0: "customb_out_b_0", 1: "customb_out_b_1", 2: "customb_out_b_2"}, + } + onnx_program = torch.onnx.export( + model, + input, + dynamic_axes=dynamic_axes, + input_names=["x", "b"], + output_names=["x_out", "b_out"], + dynamo=True, + ) + + # Check whether the dynamic dimension names are preserved + self.assertIs(onnx_program.model.graph.inputs[0].shape[0].value, "customx_x_0") + self.assertIs(onnx_program.model.graph.inputs[0].shape[1].value, "customx_x_1") + self.assertIs(onnx_program.model.graph.inputs[0].shape[2].value, "customx_x_2") + self.assertIs(onnx_program.model.graph.inputs[1].shape[0].value, "customb_b_0") + self.assertIs(onnx_program.model.graph.inputs[1].shape[1].value, "customb_b_1") + self.assertIs(onnx_program.model.graph.inputs[1].shape[2].value, "customb_b_2") + def test_saved_f_exists_after_export(self): with common_utils.TemporaryFileName(suffix=".onnx") as path: _ = torch.onnx.export( diff --git a/torch/onnx/_internal/exporter/_dynamic_shapes.py b/torch/onnx/_internal/exporter/_dynamic_shapes.py index 3b04ab85a886..20651017f3ea 100644 --- a/torch/onnx/_internal/exporter/_dynamic_shapes.py +++ b/torch/onnx/_internal/exporter/_dynamic_shapes.py @@ -39,6 +39,15 @@ def from_dynamic_axes_to_dynamic_shapes( Detail on Dim.DYNAMIC: `#133620 `_ """ + + warnings.warn( + "from_dynamic_axes_to_dynamic_shapes is deprecated and will be removed in a future release. " + "This function converts 'dynamic_axes' format (including custom axis names) to 'dynamic_shapes' format. " + "Instead of relying on this conversion, provide 'dynamic_shapes' directly with custom names.", + DeprecationWarning, + stacklevel=2, + ) + # https://github.com/pytorch/pytorch/pull/128371 # 1. The function does not need to provide dynamic_shapes to torch.export.export if dynamic_axes is None: @@ -62,9 +71,8 @@ def from_dynamic_axes_to_dynamic_shapes( raise ValueError( "The axis in dynamic_axes must be in the form of: dict[int, str] or list[int]." ) - dynamic_shapes[input_name] = { - k: torch.export.Dim.DYNAMIC for k, _ in axes.items() - } + # str will be converted to Dim.DYNAMIC in convert_str_to_export_dim + dynamic_shapes[input_name] = axes elif isinstance(axes, list): if any(not isinstance(k, int) for k in axes): raise ValueError(