[ONNX] Support symbolic arguments in onnx exporter (#157734)

Previous to this PR, torch.onnx.export(..., dynamo=True, veriy=True, report=True) does not support symbolic arguments. Such examples are like follwing:

```python
class M(torch.nn.Module):
    def forward(self, a, x):
        return a + torch.tensor(1) + x

op = torch.onnx.export(M(), (1, torch.ones(2)),
                       dynamic_shapes=(torch.export.Dim.DYNAMIC, {0: torch.export.Dim.DYNAMIC}),
                       dynamo=True, report=True)
```

symbolic arguments are like constant arguments that they don't have tensor_meta wither. Besides, torch.export.export supports model inputs having constants, which is different from the legacy issue: https://github.com/pytorch/pytorch/issues/99534 where we tried to get the FX directly from dynamo export. Thus, `_remove_non_tensor` is deleted from args processing.

NOTE: If the ConstantArugment shows up in exported_program, it was kept to align the length of inputs to nn.Module, but it's irrelevant to the model graph, hwich is why in ONNX model the input is omitted.

The test `test_constant_argument_user_input_is_omitted_in_onnx_graph` needs #157719
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157734
Approved by: https://github.com/justinchuby
This commit is contained in:
Ti-Tai Wang
2025-07-09 21:15:40 +00:00
committed by PyTorch MergeBot
parent 163f0d8f2a
commit 08e9dd280f
4 changed files with 120 additions and 64 deletions

View File

@ -4,6 +4,7 @@
from __future__ import annotations
import io
import logging
import os
import numpy as np
@ -73,6 +74,68 @@ class TestExportAPIDynamo(common_utils.TestCase):
(torch.randn(1, 1, 2), torch.randn(1, 1, 2)),
)
def test_symbolic_argument_user_input_is_supported_by_report_and_call(self):
class constant_plus_tensor_inputs(torch.nn.Module):
def forward(self, a, x):
return a + torch.tensor(1) + x
# Capture log output
log_capture = io.StringIO()
log_handler = logging.StreamHandler(log_capture)
log_handler.setLevel(logging.ERROR)
# Get the logger used in _core.py
logger = logging.getLogger("torch.onnx._internal.exporter._core")
original_level = logger.level
logger.addHandler(log_handler)
logger.setLevel(logging.ERROR)
try:
with common_utils.TemporaryDirectoryName() as temp_dir:
self.assert_export(
constant_plus_tensor_inputs(),
(
1,
torch.ones(2),
),
dynamic_shapes=(
torch.export.Dim.DYNAMIC,
{0: torch.export.Dim.DYNAMIC},
),
report=True,
artifacts_dir=temp_dir,
)
# Check if the expected error was logged
log_output = log_capture.getvalue()
self.assertNotIn("Failed to save report due to an error", log_output)
self.assertNotIn("KeyError: 'tensor_meta'", log_output)
# Note: We don't call assert_onnx_program here because it will fail
# due to the input name mismatch issue mentioned in your error
finally:
# Clean up logging
logger.removeHandler(log_handler)
logger.setLevel(original_level)
def test_constant_argument_user_input_is_omitted_in_onnx_graph(self):
class constant_plus_tensor_inputs(torch.nn.Module):
def forward(self, a, x):
return a + torch.tensor(1) + x
onnx_program = torch.onnx.export(
constant_plus_tensor_inputs(),
(
1,
torch.ones(2),
),
dynamic_shapes=(
None,
{0: torch.export.Dim.DYNAMIC},
),
dynamo=True,
)
self.assertEqual(len(onnx_program.model.graph.inputs), 1)
def test_dynamic_axes_enable_dynamic_shapes_with_fully_specified_axes(self):
self.assert_export(
SampleModelForDynamicShapes(),