[ONNX] Support symbolic arguments in onnx exporter (#157734)

Prior to this PR, `torch.onnx.export(..., dynamo=True, verify=True, report=True)` does not support symbolic arguments. For example:

```python
class M(torch.nn.Module):
    def forward(self, a, x):
        return a + torch.tensor(1) + x

op = torch.onnx.export(
    M(),
    (1, torch.ones(2)),
    dynamic_shapes=(torch.export.Dim.DYNAMIC, {0: torch.export.Dim.DYNAMIC}),
    dynamo=True,
    report=True,
)
```

Symbolic arguments are like constant arguments in that they don't have tensor_meta either. Moreover, torch.export.export supports models whose inputs include constants, which differs from the legacy issue https://github.com/pytorch/pytorch/issues/99534, where we tried to get the FX graph directly from dynamo export. Thus, `_remove_non_tensor` is deleted from args processing.
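For context, a minimal sketch (reusing `M` from the example above; not part of this PR) showing that torch.export.export accepts the non-tensor input and records it as a ConstantArgument when no dynamic shape is requested for it:

```python
import torch
from torch.export import graph_signature

# With no dynamic_shapes entry for `a`, torch.export specializes the int input
# and keeps it in the graph signature as a ConstantArgument instead of
# dropping it.
ep = torch.export.export(M(), (1, torch.ones(2)))
assert any(
    isinstance(spec.arg, graph_signature.ConstantArgument)
    for spec in ep.graph_signature.input_specs
)
```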

NOTE: If a ConstantArgument shows up in the exported_program, it is kept only to align the number of inputs with the nn.Module signature; it is irrelevant to the model graph, which is why the input is omitted from the ONNX model.
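To make the NOTE concrete, a minimal sketch mirroring the new test below (`M` as above):

```python
onnx_program = torch.onnx.export(
    M(),
    (1, torch.ones(2)),
    dynamic_shapes=(None, {0: torch.export.Dim.DYNAMIC}),
    dynamo=True,
)
# `a` is a ConstantArgument in the ExportedProgram, but it never reaches the
# ONNX graph, so the model exposes only the tensor input `x`.
assert len(onnx_program.model.graph.inputs) == 1
```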

The test `test_constant_argument_user_input_is_omitted_in_onnx_graph` depends on #157719.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157734
Approved by: https://github.com/justinchuby
Author: Ti-Tai Wang
Date: 2025-07-09 21:15:40 +00:00
Committed by: PyTorch MergeBot
Parent: 163f0d8f2a
Commit: 08e9dd280f
4 changed files with 120 additions and 64 deletions

```diff
@@ -4,6 +4,7 @@
 from __future__ import annotations

 import io
+import logging
 import os

 import numpy as np
@@ -73,6 +74,68 @@ class TestExportAPIDynamo(common_utils.TestCase):
             (torch.randn(1, 1, 2), torch.randn(1, 1, 2)),
         )

+    def test_symbolic_argument_user_input_is_supported_by_report_and_call(self):
+        class constant_plus_tensor_inputs(torch.nn.Module):
+            def forward(self, a, x):
+                return a + torch.tensor(1) + x
+
+        # Capture log output
+        log_capture = io.StringIO()
+        log_handler = logging.StreamHandler(log_capture)
+        log_handler.setLevel(logging.ERROR)
+
+        # Get the logger used in _core.py
+        logger = logging.getLogger("torch.onnx._internal.exporter._core")
+        original_level = logger.level
+        logger.addHandler(log_handler)
+        logger.setLevel(logging.ERROR)
+
+        try:
+            with common_utils.TemporaryDirectoryName() as temp_dir:
+                self.assert_export(
+                    constant_plus_tensor_inputs(),
+                    (
+                        1,
+                        torch.ones(2),
+                    ),
+                    dynamic_shapes=(
+                        torch.export.Dim.DYNAMIC,
+                        {0: torch.export.Dim.DYNAMIC},
+                    ),
+                    report=True,
+                    artifacts_dir=temp_dir,
+                )
+
+                # Check that the expected errors were not logged
+                log_output = log_capture.getvalue()
+                self.assertNotIn("Failed to save report due to an error", log_output)
+                self.assertNotIn("KeyError: 'tensor_meta'", log_output)
+
+                # Note: we don't call assert_onnx_program here because it would
+                # fail due to an input name mismatch issue
+        finally:
+            # Clean up logging
+            logger.removeHandler(log_handler)
+            logger.setLevel(original_level)
+
+    def test_constant_argument_user_input_is_omitted_in_onnx_graph(self):
+        class constant_plus_tensor_inputs(torch.nn.Module):
+            def forward(self, a, x):
+                return a + torch.tensor(1) + x
+
+        onnx_program = torch.onnx.export(
+            constant_plus_tensor_inputs(),
+            (
+                1,
+                torch.ones(2),
+            ),
+            dynamic_shapes=(
+                None,
+                {0: torch.export.Dim.DYNAMIC},
+            ),
+            dynamo=True,
+        )
+        self.assertEqual(len(onnx_program.model.graph.inputs), 1)
+
     def test_dynamic_axes_enable_dynamic_shapes_with_fully_specified_axes(self):
         self.assert_export(
             SampleModelForDynamicShapes(),
```
```diff
@@ -520,7 +520,7 @@ def dynamo_export(
                     dynamic_shape[i] = torch.export.Dim.AUTO
                 return dynamic_shape
             else:
-                return None
+                return torch.export.Dim.AUTO

         # model_args could be nested
         dynamic_shapes = _pytree.tree_map(
@@ -529,7 +529,6 @@
         )
     else:
         dynamic_shapes = None
-
     return _compat.export_compat(
         model,  # type: ignore[arg-type]
         model_args,
```

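A hedged note on the change above (not from the PR text): returning `torch.export.Dim.AUTO` instead of `None` for non-tensor leaves lets torch.export decide whether such an input should be captured symbolically, rather than pinning it to its trace-time value. Illustration with `M` from above:

```python
import torch

# Dim.AUTO on the int leaf allows torch.export to treat `a` symbolically when
# needed; the previous None would have specialized it as a constant.
ep = torch.export.export(
    M(),
    (1, torch.ones(2)),
    dynamic_shapes=(torch.export.Dim.AUTO, {0: torch.export.Dim.AUTO}),
)
```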
```diff
@@ -159,22 +159,48 @@ def _get_io_specs(exported_program: torch.export.ExportedProgram) -> tuple[dict,
         for spec in exported_program.graph_signature.output_specs
         if spec.kind == graph_signature.OutputKind.USER_OUTPUT
     ]
-    inputs: dict[str, torch._export.serde.schema.TensorMeta] = {}
-    outputs: dict[str, torch._export.serde.schema.TensorMeta] = {}
+    inputs: dict[str, torch._export.serde.schema.TensorMeta | str] = {}
+    outputs: dict[str, torch._export.serde.schema.TensorMeta | str] = {}
     for spec in user_inputs:
-        if isinstance(spec.arg, graph_signature.ConstantArgument):
-            continue
-        name = spec.arg.name
-        # FIXME: tensor_meta is None sometimes when the exported program still knows the shape/type
-        inputs[name] = nodes[name].meta["tensor_meta"]
+        inputs = _log_spec_into_io_specs(spec, nodes, inputs)
     for spec in user_outputs:
-        if isinstance(spec.arg, graph_signature.ConstantArgument):
-            continue
-        name = spec.arg.name
-        outputs[name] = nodes[name].meta["tensor_meta"]
+        outputs = _log_spec_into_io_specs(spec, nodes, outputs)
     return inputs, outputs


+def _log_spec_into_io_specs(
+    spec: graph_signature.InputSpec,
+    nodes: dict[str, torch.fx.Node],
+    inputs_or_outputs: dict[str, torch._export.serde.schema.TensorMeta | str],
+) -> dict[str, torch._export.serde.schema.TensorMeta | str]:
+    # If dynamic is set on a constant input, it becomes a
+    # symbolic argument, which is not a tensor.
+    if isinstance(spec.arg, graph_signature.ConstantArgument):
+        # Constant input does not have tensor_meta.
+        return inputs_or_outputs
+    # Symbolic arguments are not tensors, so they do not have tensor_meta,
+    # but we need to provide a string representation for them to inform users.
+    name = spec.arg.name
+    if isinstance(
+        spec.arg,
+        (
+            graph_signature.SymIntArgument,
+            graph_signature.SymFloatArgument,
+            graph_signature.SymBoolArgument,
+        ),
+    ):
+        argument_to_str: dict[type[graph_signature.ArgumentSpec], str] = {
+            graph_signature.SymIntArgument: "SymInt",
+            graph_signature.SymFloatArgument: "SymFloat",
+            graph_signature.SymBoolArgument: "SymBool",
+        }
+        inputs_or_outputs[name] = argument_to_str[type(spec.arg)]
+        return inputs_or_outputs
+    # FIXME: tensor_meta is None sometimes when the exported program still knows the shape/type
+    inputs_or_outputs[name] = nodes[name].meta["tensor_meta"]
+    return inputs_or_outputs


 def _count_fx_targets(
     exported_program: torch.export.ExportedProgram,
 ) -> defaultdict[str, int]:
```

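For context on `_log_spec_into_io_specs` above, a minimal sketch (reusing `M`) of what a symbolic user input looks like in the graph signature; it is a SymIntArgument with no tensor_meta, so the report records the string "SymInt" for it:

```python
import torch
from torch.export import graph_signature

ep = torch.export.export(
    M(),
    (1, torch.ones(2)),
    dynamic_shapes=(torch.export.Dim.DYNAMIC, {0: torch.export.Dim.DYNAMIC}),
)
for spec in ep.graph_signature.input_specs:
    if isinstance(spec.arg, graph_signature.SymIntArgument):
        print(spec.arg.name, "-> SymInt")  # recorded as a string in the report
```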
```diff
@@ -15,6 +15,8 @@
 import textwrap
 import warnings
 from typing import Any, Callable, TYPE_CHECKING

+import numpy as np
+
 import torch
 from torch.onnx._internal._lazy_import import onnx, onnxscript_apis, onnxscript_ir as ir
 from torch.onnx._internal.exporter import _dynamic_shapes, _ir_passes
@@ -117,31 +119,40 @@ def _create_value_mapping(graph: ir.Graph) -> dict[str, ir.Value]:
     return values


-def _to_ort_value(tensor: torch.Tensor) -> ort.OrtValue:
+def _to_ort_value(input: torch.Tensor | int | float | str | bool) -> ort.OrtValue:
     """Convert a PyTorch tensor to an ONNX Runtime OrtValue."""
     import onnxruntime as ort

     from torch.onnx._internal.exporter import _core

-    if tensor.dtype == torch.bfloat16 or tensor.dtype in _NP_UNSUPPORTED_DTYPES_8BIT:
+    if isinstance(input, (int, float, str, bool)):
+        # Convert scalar values to OrtValue
+        dtype_mapping = {
+            int: np.int64,
+            float: np.float32,
+        }
+        dtype = dtype_mapping.get(type(input), None)
+        return ort.OrtValue.ortvalue_from_numpy(np.array(input, dtype=dtype))
+
+    if input.dtype == torch.bfloat16 or input.dtype in _NP_UNSUPPORTED_DTYPES_8BIT:
         if hasattr(ort.OrtValue, "ortvalue_from_numpy_with_onnx_type"):
             # This requires ONNX Runtime 1.21 or newer
-            if tensor.dtype == torch.bfloat16:
+            if input.dtype == torch.bfloat16:
                 uint_type = torch.uint16
             else:
                 uint_type = torch.uint8
-            onnx_type = _core.torch_dtype_to_onnx_dtype(tensor.dtype)
+            onnx_type = _core.torch_dtype_to_onnx_dtype(input.dtype)
             # Make tensor contiguous to ensure view() works
-            tensor = tensor.contiguous()
+            input = input.contiguous()
             return ort.OrtValue.ortvalue_from_numpy_with_onnx_type(
-                tensor.view(uint_type).numpy(force=True), onnx_element_type=onnx_type
+                input.view(uint_type).numpy(force=True), onnx_element_type=onnx_type
             )
         raise RuntimeError(
-            f"Failed to convert tensor of type '{tensor.dtype}' to OrtValue. "
+            f"Failed to convert tensor of type '{input.dtype}' to OrtValue. "
             "Please ensure that ONNX Runtime is built with DLPack support or is the latest version"
         )
     # TODO(#151064): Use dlpack when ORT properly supports it
-    return ort.OrtValue.ortvalue_from_numpy(tensor.numpy(force=True))
+    return ort.OrtValue.ortvalue_from_numpy(input.numpy(force=True))


 def _from_ort_value(value: ort.OrtValue) -> torch.Tensor:
```
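As a standalone illustration of the new scalar path in `_to_ort_value` (the helper name below is hypothetical; assumes numpy and onnxruntime are installed):

```python
import numpy as np
import onnxruntime as ort


def scalar_to_ort_value(value):
    # Python int -> int64, float -> float32; bool/str fall through with
    # dtype=None and let numpy infer the element type, matching
    # dtype_mapping.get(type(input), None) above.
    dtype_mapping = {int: np.int64, float: np.float32}
    dtype = dtype_mapping.get(type(value), None)
    return ort.OrtValue.ortvalue_from_numpy(np.array(value, dtype=dtype))


print(scalar_to_ort_value(3).numpy())    # 3, dtype int64
print(scalar_to_ort_value(2.5).numpy())  # 2.5, dtype float32
```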
```diff
@@ -208,7 +219,6 @@ ONNXProgram(
         assert self._inference_session is not None

-        # We don't expect non-tensor as inputs
         ort_input = {
             k.name: _to_ort_value(v)
             for k, v in zip(self.model.graph.inputs, flatten_args)
@@ -414,7 +424,6 @@ def _process_args(args, kwargs) -> tuple[torch.Tensor, ...]:
     """Process input arguments for the ONNX model."""
     args = _flatten_inputs(args, kwargs)
     args = _remove_none_from_inputs(args)
-    args = _remove_non_tensor(args)
     args = _convert_complex_to_real_representation(args)
     return args
@@ -428,47 +437,6 @@ def _remove_none_from_inputs(model_args):
     return tuple(arg for arg in model_args if arg is not None)


-def _remove_non_tensor(model_args):
-    """Remove the non-tensor input arguments.
-
-    Dynamo does not support non-tensor input arguments (https://github.com/pytorch/pytorch/issues/99534).
-
-    Specifically, it does put the input into graph with an empty node, but consumed by no ones.
-    The concrete value is embedded into the graph as a constant arg of a target node. Meta
-    suggests in this case that one should rewrite the model code to make it tensor if the
-    input value is supposed to change at runtime. We might need to further investigate
-    the feasibility of that suggestion.
-
-    For example,
-
-        def func(x, b=1.0):
-            y = x + b
-            z = y.relu()
-            return (y, z)
-
-        x = torch.randn(1, 1, 2, dtype=torch.float32)
-        gm_fun, _ = dynamo.export(func, x, b=8.0, aten_graph=True, tracing_mode="real")
-
-        # class GraphModule(torch.nn.Module):
-        #     def forward(self, x, b):
-        #         arg0: f32[1, 1, 2], arg1, = fx_pytree.tree_flatten_spec(([x, b], {}), self._in_spec)
-        #         # File: path/to/pytorch/test_constant_input.py:5, code: y = x + b
-        #         add_tensor: f32[1, 1, 2] = torch.ops.aten.add.Tensor(arg0, 8.0);  arg0 = None
-        #         # File: path/to/pytorch/test_constant_input.py:6, code: z = y.relu()
-        #         relu_default: f32[1, 1, 2] = torch.ops.aten.relu.default(add_tensor)
-        #         return pytree.tree_unflatten([add_tensor, relu_default], self._out_spec)
-
-    Empty torch.fx.Node input leading to a mismatched number of input with PyTorch, as
-    it's ignored in ONNX graph. Thus, we delete the useless input here.
-    """
-    return tuple(
-        arg for arg in model_args if not isinstance(arg, (int, float, bool, str))
-    )
-
-
 def _convert_complex_to_real_representation(model_args):
     """Convert complex dtype tensors to real representation tensors.
```