mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[ONNX] Support renaming in dynamic axes to shapes conversion (#165769)
Discovered in ##165748 This PR also deprecates the conversion. ONNX exporter team does not intend to maintain the conversion in long term. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165769 Approved by: https://github.com/justinchuby
This commit is contained in:
committed by
PyTorch MergeBot
parent
e9f4999985
commit
543ddbf44c
@ -202,6 +202,51 @@ class TestExportAPIDynamo(common_utils.TestCase):
|
||||
dynamic_axes={"b": [0, 1, 2], "b_out": [0, 1, 2]},
|
||||
)
|
||||
|
||||
def test_from_dynamic_axes_to_dynamic_shapes_deprecation_warning(self):
|
||||
with self.assertWarnsRegex(
|
||||
DeprecationWarning,
|
||||
"from_dynamic_axes_to_dynamic_shapes is deprecated and will be removed in a future release. "
|
||||
"This function converts 'dynamic_axes' format \\(including custom axis names\\) to 'dynamic_shapes' format. "
|
||||
"Instead of relying on this conversion, provide 'dynamic_shapes' directly with custom names.",
|
||||
):
|
||||
self.assert_export(
|
||||
SampleModelForDynamicShapes(),
|
||||
(torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}),
|
||||
dynamic_axes={
|
||||
"x": [0, 1, 2],
|
||||
"b": [0, 1, 2],
|
||||
},
|
||||
)
|
||||
|
||||
def test_from_dynamic_axes_to_dynamic_shapes_keeps_custom_axis_names(self):
|
||||
model = SampleModelForDynamicShapes()
|
||||
input = (
|
||||
torch.randn(2, 2, 3),
|
||||
{"b": torch.randn(2, 2, 3)},
|
||||
)
|
||||
dynamic_axes = {
|
||||
"x": {0: "customx_x_0", 1: "customx_x_1", 2: "customx_x_2"},
|
||||
"b": {0: "customb_b_0", 1: "customb_b_1", 2: "customb_b_2"},
|
||||
"x_out": {0: "customx_out_x_0", 1: "customx_out_x_1", 2: "customx_out_x_2"},
|
||||
"b_out": {0: "customb_out_b_0", 1: "customb_out_b_1", 2: "customb_out_b_2"},
|
||||
}
|
||||
onnx_program = torch.onnx.export(
|
||||
model,
|
||||
input,
|
||||
dynamic_axes=dynamic_axes,
|
||||
input_names=["x", "b"],
|
||||
output_names=["x_out", "b_out"],
|
||||
dynamo=True,
|
||||
)
|
||||
|
||||
# Check whether the dynamic dimension names are preserved
|
||||
self.assertIs(onnx_program.model.graph.inputs[0].shape[0].value, "customx_x_0")
|
||||
self.assertIs(onnx_program.model.graph.inputs[0].shape[1].value, "customx_x_1")
|
||||
self.assertIs(onnx_program.model.graph.inputs[0].shape[2].value, "customx_x_2")
|
||||
self.assertIs(onnx_program.model.graph.inputs[1].shape[0].value, "customb_b_0")
|
||||
self.assertIs(onnx_program.model.graph.inputs[1].shape[1].value, "customb_b_1")
|
||||
self.assertIs(onnx_program.model.graph.inputs[1].shape[2].value, "customb_b_2")
|
||||
|
||||
def test_saved_f_exists_after_export(self):
|
||||
with common_utils.TemporaryFileName(suffix=".onnx") as path:
|
||||
_ = torch.onnx.export(
|
||||
|
@ -39,6 +39,15 @@ def from_dynamic_axes_to_dynamic_shapes(
|
||||
|
||||
Detail on Dim.DYNAMIC: `#133620 <https://github.com/pytorch/pytorch/pull/133620>`_
|
||||
"""
|
||||
|
||||
warnings.warn(
|
||||
"from_dynamic_axes_to_dynamic_shapes is deprecated and will be removed in a future release. "
|
||||
"This function converts 'dynamic_axes' format (including custom axis names) to 'dynamic_shapes' format. "
|
||||
"Instead of relying on this conversion, provide 'dynamic_shapes' directly with custom names.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
# https://github.com/pytorch/pytorch/pull/128371
|
||||
# 1. The function does not need to provide dynamic_shapes to torch.export.export
|
||||
if dynamic_axes is None:
|
||||
@ -62,9 +71,8 @@ def from_dynamic_axes_to_dynamic_shapes(
|
||||
raise ValueError(
|
||||
"The axis in dynamic_axes must be in the form of: dict[int, str] or list[int]."
|
||||
)
|
||||
dynamic_shapes[input_name] = {
|
||||
k: torch.export.Dim.DYNAMIC for k, _ in axes.items()
|
||||
}
|
||||
# str will be converted to Dim.DYNAMIC in convert_str_to_export_dim
|
||||
dynamic_shapes[input_name] = axes
|
||||
elif isinstance(axes, list):
|
||||
if any(not isinstance(k, int) for k in axes):
|
||||
raise ValueError(
|
||||
|
Reference in New Issue
Block a user