mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[ONNX] Support renaming in dynamic axes to shapes conversion (#165769)
Discovered in ##165748 This PR also deprecates the conversion. ONNX exporter team does not intend to maintain the conversion in long term. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165769 Approved by: https://github.com/justinchuby
This commit is contained in:
committed by
PyTorch MergeBot
parent
e9f4999985
commit
543ddbf44c
@ -202,6 +202,51 @@ class TestExportAPIDynamo(common_utils.TestCase):
|
|||||||
dynamic_axes={"b": [0, 1, 2], "b_out": [0, 1, 2]},
|
dynamic_axes={"b": [0, 1, 2], "b_out": [0, 1, 2]},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_from_dynamic_axes_to_dynamic_shapes_deprecation_warning(self):
|
||||||
|
with self.assertWarnsRegex(
|
||||||
|
DeprecationWarning,
|
||||||
|
"from_dynamic_axes_to_dynamic_shapes is deprecated and will be removed in a future release. "
|
||||||
|
"This function converts 'dynamic_axes' format \\(including custom axis names\\) to 'dynamic_shapes' format. "
|
||||||
|
"Instead of relying on this conversion, provide 'dynamic_shapes' directly with custom names.",
|
||||||
|
):
|
||||||
|
self.assert_export(
|
||||||
|
SampleModelForDynamicShapes(),
|
||||||
|
(torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}),
|
||||||
|
dynamic_axes={
|
||||||
|
"x": [0, 1, 2],
|
||||||
|
"b": [0, 1, 2],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_from_dynamic_axes_to_dynamic_shapes_keeps_custom_axis_names(self):
|
||||||
|
model = SampleModelForDynamicShapes()
|
||||||
|
input = (
|
||||||
|
torch.randn(2, 2, 3),
|
||||||
|
{"b": torch.randn(2, 2, 3)},
|
||||||
|
)
|
||||||
|
dynamic_axes = {
|
||||||
|
"x": {0: "customx_x_0", 1: "customx_x_1", 2: "customx_x_2"},
|
||||||
|
"b": {0: "customb_b_0", 1: "customb_b_1", 2: "customb_b_2"},
|
||||||
|
"x_out": {0: "customx_out_x_0", 1: "customx_out_x_1", 2: "customx_out_x_2"},
|
||||||
|
"b_out": {0: "customb_out_b_0", 1: "customb_out_b_1", 2: "customb_out_b_2"},
|
||||||
|
}
|
||||||
|
onnx_program = torch.onnx.export(
|
||||||
|
model,
|
||||||
|
input,
|
||||||
|
dynamic_axes=dynamic_axes,
|
||||||
|
input_names=["x", "b"],
|
||||||
|
output_names=["x_out", "b_out"],
|
||||||
|
dynamo=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check whether the dynamic dimension names are preserved
|
||||||
|
self.assertIs(onnx_program.model.graph.inputs[0].shape[0].value, "customx_x_0")
|
||||||
|
self.assertIs(onnx_program.model.graph.inputs[0].shape[1].value, "customx_x_1")
|
||||||
|
self.assertIs(onnx_program.model.graph.inputs[0].shape[2].value, "customx_x_2")
|
||||||
|
self.assertIs(onnx_program.model.graph.inputs[1].shape[0].value, "customb_b_0")
|
||||||
|
self.assertIs(onnx_program.model.graph.inputs[1].shape[1].value, "customb_b_1")
|
||||||
|
self.assertIs(onnx_program.model.graph.inputs[1].shape[2].value, "customb_b_2")
|
||||||
|
|
||||||
def test_saved_f_exists_after_export(self):
|
def test_saved_f_exists_after_export(self):
|
||||||
with common_utils.TemporaryFileName(suffix=".onnx") as path:
|
with common_utils.TemporaryFileName(suffix=".onnx") as path:
|
||||||
_ = torch.onnx.export(
|
_ = torch.onnx.export(
|
||||||
|
@ -39,6 +39,15 @@ def from_dynamic_axes_to_dynamic_shapes(
|
|||||||
|
|
||||||
Detail on Dim.DYNAMIC: `#133620 <https://github.com/pytorch/pytorch/pull/133620>`_
|
Detail on Dim.DYNAMIC: `#133620 <https://github.com/pytorch/pytorch/pull/133620>`_
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
warnings.warn(
|
||||||
|
"from_dynamic_axes_to_dynamic_shapes is deprecated and will be removed in a future release. "
|
||||||
|
"This function converts 'dynamic_axes' format (including custom axis names) to 'dynamic_shapes' format. "
|
||||||
|
"Instead of relying on this conversion, provide 'dynamic_shapes' directly with custom names.",
|
||||||
|
DeprecationWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
|
||||||
# https://github.com/pytorch/pytorch/pull/128371
|
# https://github.com/pytorch/pytorch/pull/128371
|
||||||
# 1. The function does not need to provide dynamic_shapes to torch.export.export
|
# 1. The function does not need to provide dynamic_shapes to torch.export.export
|
||||||
if dynamic_axes is None:
|
if dynamic_axes is None:
|
||||||
@ -62,9 +71,8 @@ def from_dynamic_axes_to_dynamic_shapes(
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The axis in dynamic_axes must be in the form of: dict[int, str] or list[int]."
|
"The axis in dynamic_axes must be in the form of: dict[int, str] or list[int]."
|
||||||
)
|
)
|
||||||
dynamic_shapes[input_name] = {
|
# str will be converted to Dim.DYNAMIC in convert_str_to_export_dim
|
||||||
k: torch.export.Dim.DYNAMIC for k, _ in axes.items()
|
dynamic_shapes[input_name] = axes
|
||||||
}
|
|
||||||
elif isinstance(axes, list):
|
elif isinstance(axes, list):
|
||||||
if any(not isinstance(k, int) for k in axes):
|
if any(not isinstance(k, int) for k in axes):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
Reference in New Issue
Block a user