mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[ONNX] Support symbolic arguments in onnx exporter (#157734)
Prior to this PR, torch.onnx.export(..., dynamo=True, verify=True, report=True) did not support symbolic arguments. An example is the following: ```python class M(torch.nn.Module): def forward(self, a, x): return a + torch.tensor(1) + x op = torch.onnx.export(M(), (1, torch.ones(2)), dynamic_shapes=(torch.export.Dim.DYNAMIC, {0: torch.export.Dim.DYNAMIC}), dynamo=True, report=True) ``` Symbolic arguments are like constant arguments in that they don't have tensor_meta either. Besides, torch.export.export supports model inputs that are constants, which is different from the legacy issue https://github.com/pytorch/pytorch/issues/99534, where we tried to get the FX graph directly from dynamo export. Thus, `_remove_non_tensor` is deleted from args processing. NOTE: If a ConstantArgument shows up in the exported_program, it was kept to align the number of inputs with the nn.Module, but it is irrelevant to the model graph, which is why the input is omitted in the ONNX model. The test `test_constant_argument_user_input_is_omitted_in_onnx_graph` needs #157719 Pull Request resolved: https://github.com/pytorch/pytorch/pull/157734 Approved by: https://github.com/justinchuby
This commit is contained in:
committed by
PyTorch MergeBot
parent
163f0d8f2a
commit
08e9dd280f
@ -4,6 +4,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
@ -73,6 +74,68 @@ class TestExportAPIDynamo(common_utils.TestCase):
|
||||
(torch.randn(1, 1, 2), torch.randn(1, 1, 2)),
|
||||
)
|
||||
|
||||
def test_symbolic_argument_user_input_is_supported_by_report_and_call(self):
|
||||
class constant_plus_tensor_inputs(torch.nn.Module):
|
||||
def forward(self, a, x):
|
||||
return a + torch.tensor(1) + x
|
||||
|
||||
# Capture log output
|
||||
log_capture = io.StringIO()
|
||||
log_handler = logging.StreamHandler(log_capture)
|
||||
log_handler.setLevel(logging.ERROR)
|
||||
# Get the logger used in _core.py
|
||||
logger = logging.getLogger("torch.onnx._internal.exporter._core")
|
||||
original_level = logger.level
|
||||
logger.addHandler(log_handler)
|
||||
logger.setLevel(logging.ERROR)
|
||||
|
||||
try:
|
||||
with common_utils.TemporaryDirectoryName() as temp_dir:
|
||||
self.assert_export(
|
||||
constant_plus_tensor_inputs(),
|
||||
(
|
||||
1,
|
||||
torch.ones(2),
|
||||
),
|
||||
dynamic_shapes=(
|
||||
torch.export.Dim.DYNAMIC,
|
||||
{0: torch.export.Dim.DYNAMIC},
|
||||
),
|
||||
report=True,
|
||||
artifacts_dir=temp_dir,
|
||||
)
|
||||
# Check that the previously observed errors were not logged
|
||||
log_output = log_capture.getvalue()
|
||||
self.assertNotIn("Failed to save report due to an error", log_output)
|
||||
self.assertNotIn("KeyError: 'tensor_meta'", log_output)
|
||||
# Note: We don't call assert_onnx_program here because it will fail
|
||||
# due to the input name mismatch issue mentioned in your error
|
||||
|
||||
finally:
|
||||
# Clean up logging
|
||||
logger.removeHandler(log_handler)
|
||||
logger.setLevel(original_level)
|
||||
|
||||
def test_constant_argument_user_input_is_omitted_in_onnx_graph(self):
|
||||
class constant_plus_tensor_inputs(torch.nn.Module):
|
||||
def forward(self, a, x):
|
||||
return a + torch.tensor(1) + x
|
||||
|
||||
onnx_program = torch.onnx.export(
|
||||
constant_plus_tensor_inputs(),
|
||||
(
|
||||
1,
|
||||
torch.ones(2),
|
||||
),
|
||||
dynamic_shapes=(
|
||||
None,
|
||||
{0: torch.export.Dim.DYNAMIC},
|
||||
),
|
||||
dynamo=True,
|
||||
)
|
||||
|
||||
self.assertEqual(len(onnx_program.model.graph.inputs), 1)
|
||||
|
||||
def test_dynamic_axes_enable_dynamic_shapes_with_fully_specified_axes(self):
|
||||
self.assert_export(
|
||||
SampleModelForDynamicShapes(),
|
||||
|
@ -520,7 +520,7 @@ def dynamo_export(
|
||||
dynamic_shape[i] = torch.export.Dim.AUTO
|
||||
return dynamic_shape
|
||||
else:
|
||||
return None
|
||||
return torch.export.Dim.AUTO
|
||||
|
||||
# model_args could be nested
|
||||
dynamic_shapes = _pytree.tree_map(
|
||||
@ -529,7 +529,6 @@ def dynamo_export(
|
||||
)
|
||||
else:
|
||||
dynamic_shapes = None
|
||||
|
||||
return _compat.export_compat(
|
||||
model, # type: ignore[arg-type]
|
||||
model_args,
|
||||
|
@ -159,22 +159,48 @@ def _get_io_specs(exported_program: torch.export.ExportedProgram) -> tuple[dict,
|
||||
for spec in exported_program.graph_signature.output_specs
|
||||
if spec.kind == graph_signature.OutputKind.USER_OUTPUT
|
||||
]
|
||||
inputs: dict[str, torch._export.serde.schema.TensorMeta] = {}
|
||||
outputs: dict[str, torch._export.serde.schema.TensorMeta] = {}
|
||||
inputs: dict[str, torch._export.serde.schema.TensorMeta | str] = {}
|
||||
outputs: dict[str, torch._export.serde.schema.TensorMeta | str] = {}
|
||||
for spec in user_inputs:
|
||||
if isinstance(spec.arg, graph_signature.ConstantArgument):
|
||||
continue
|
||||
name = spec.arg.name
|
||||
# FIXME: tensor_meta is None sometimes when the exported program still knows the shape/type
|
||||
inputs[name] = nodes[name].meta["tensor_meta"]
|
||||
inputs = _log_spec_into_io_specs(spec, nodes, inputs)
|
||||
for spec in user_outputs:
|
||||
if isinstance(spec.arg, graph_signature.ConstantArgument):
|
||||
continue
|
||||
name = spec.arg.name
|
||||
outputs[name] = nodes[name].meta["tensor_meta"]
|
||||
outputs = _log_spec_into_io_specs(spec, nodes, outputs)
|
||||
return inputs, outputs
|
||||
|
||||
|
||||
def _log_spec_into_io_specs(
|
||||
spec: graph_signature.InputSpec,
|
||||
nodes: dict[str, torch.fx.Node],
|
||||
inputs_or_outputs: dict[str, torch._export.serde.schema.TensorMeta | str],
|
||||
) -> dict[str, torch._export.serde.schema.TensorMeta | str]:
|
||||
# If dynamic is set to a constant input, it becomes a
|
||||
# symbolic argument, which is not a tensor.
|
||||
if isinstance(spec.arg, graph_signature.ConstantArgument):
|
||||
# Constant input does not have tensor_meta.
|
||||
return inputs_or_outputs
|
||||
# Symbolic arguments are not tensors, so it does not have tensor_meta,
|
||||
# but we need to provide a string representation for them to inform users.
|
||||
name = spec.arg.name
|
||||
if isinstance(
|
||||
spec.arg,
|
||||
(
|
||||
graph_signature.SymIntArgument,
|
||||
graph_signature.SymFloatArgument,
|
||||
graph_signature.SymBoolArgument,
|
||||
),
|
||||
):
|
||||
argument_to_str: dict[type[graph_signature.ArgumentSpec], str] = {
|
||||
graph_signature.SymIntArgument: "SymInt",
|
||||
graph_signature.SymFloatArgument: "SymFloat",
|
||||
graph_signature.SymBoolArgument: "SymBool",
|
||||
}
|
||||
inputs_or_outputs[name] = argument_to_str[type(spec.arg)]
|
||||
return inputs_or_outputs
|
||||
# FIXME: tensor_meta is None sometimes when the exported program still knows the shape/type
|
||||
inputs_or_outputs[name] = nodes[name].meta["tensor_meta"]
|
||||
return inputs_or_outputs
|
||||
|
||||
|
||||
def _count_fx_targets(
|
||||
exported_program: torch.export.ExportedProgram,
|
||||
) -> defaultdict[str, int]:
|
||||
|
@ -15,6 +15,8 @@ import textwrap
|
||||
import warnings
|
||||
from typing import Any, Callable, TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from torch.onnx._internal._lazy_import import onnx, onnxscript_apis, onnxscript_ir as ir
|
||||
from torch.onnx._internal.exporter import _dynamic_shapes, _ir_passes
|
||||
@ -117,31 +119,40 @@ def _create_value_mapping(graph: ir.Graph) -> dict[str, ir.Value]:
|
||||
return values
|
||||
|
||||
|
||||
def _to_ort_value(tensor: torch.Tensor) -> ort.OrtValue:
|
||||
def _to_ort_value(input: torch.Tensor | int | float | str | bool) -> ort.OrtValue:
|
||||
"""Convert a PyTorch tensor to an ONNX Runtime OrtValue."""
|
||||
import onnxruntime as ort
|
||||
|
||||
from torch.onnx._internal.exporter import _core
|
||||
|
||||
if tensor.dtype == torch.bfloat16 or tensor.dtype in _NP_UNSUPPORTED_DTYPES_8BIT:
|
||||
if isinstance(input, (int, float, str, bool)):
|
||||
# Convert scalar values to OrtValue
|
||||
dtype_mapping = {
|
||||
int: np.int64,
|
||||
float: np.float32,
|
||||
}
|
||||
dtype = dtype_mapping.get(type(input), None)
|
||||
return ort.OrtValue.ortvalue_from_numpy(np.array(input, dtype=dtype))
|
||||
|
||||
if input.dtype == torch.bfloat16 or input.dtype in _NP_UNSUPPORTED_DTYPES_8BIT:
|
||||
if hasattr(ort.OrtValue, "ortvalue_from_numpy_with_onnx_type"):
|
||||
# This requires ONNX Runtime 1.21 or newer
|
||||
if tensor.dtype == torch.bfloat16:
|
||||
if input.dtype == torch.bfloat16:
|
||||
uint_type = torch.uint16
|
||||
else:
|
||||
uint_type = torch.uint8
|
||||
onnx_type = _core.torch_dtype_to_onnx_dtype(tensor.dtype)
|
||||
onnx_type = _core.torch_dtype_to_onnx_dtype(input.dtype)
|
||||
# Make tensor contiguous to ensure view() works
|
||||
tensor = tensor.contiguous()
|
||||
input = input.contiguous()
|
||||
return ort.OrtValue.ortvalue_from_numpy_with_onnx_type(
|
||||
tensor.view(uint_type).numpy(force=True), onnx_element_type=onnx_type
|
||||
input.view(uint_type).numpy(force=True), onnx_element_type=onnx_type
|
||||
)
|
||||
raise RuntimeError(
|
||||
f"Failed to convert tensor of type '{tensor.dtype}' to OrtValue. "
|
||||
f"Failed to convert tensor of type '{input.dtype}' to OrtValue. "
|
||||
"Please ensure that ONNX Runtime is built with DLPack support or is the latest version"
|
||||
)
|
||||
# TODO(#151064): Use dlpack when ORT properly supports it
|
||||
return ort.OrtValue.ortvalue_from_numpy(tensor.numpy(force=True))
|
||||
return ort.OrtValue.ortvalue_from_numpy(input.numpy(force=True))
|
||||
|
||||
|
||||
def _from_ort_value(value: ort.OrtValue) -> torch.Tensor:
|
||||
@ -208,7 +219,6 @@ ONNXProgram(
|
||||
|
||||
assert self._inference_session is not None
|
||||
|
||||
# We don't expect non-tensor as inputs
|
||||
ort_input = {
|
||||
k.name: _to_ort_value(v)
|
||||
for k, v in zip(self.model.graph.inputs, flatten_args)
|
||||
@ -414,7 +424,6 @@ def _process_args(args, kwargs) -> tuple[torch.Tensor, ...]:
|
||||
"""Process input arguments for the ONNX model."""
|
||||
args = _flatten_inputs(args, kwargs)
|
||||
args = _remove_none_from_inputs(args)
|
||||
args = _remove_non_tensor(args)
|
||||
args = _convert_complex_to_real_representation(args)
|
||||
return args
|
||||
|
||||
@ -428,47 +437,6 @@ def _remove_none_from_inputs(model_args):
|
||||
return tuple(arg for arg in model_args if arg is not None)
|
||||
|
||||
|
||||
def _remove_non_tensor(model_args):
|
||||
"""Remove the non-tensor input arguments.
|
||||
|
||||
Dynamo does not support non-tensor input arguments (https://github.com/pytorch/pytorch/issues/99534).
|
||||
|
||||
Specifically, it does put the input into graph with an empty node, but consumed by no ones.
|
||||
The concrete value is embedded into the graph as a constant arg of a target node. Meta
|
||||
suggests in this case that one should rewrite the model code to make it tensor if the
|
||||
input value is supposed to change at runtime. We might need to further investigate
|
||||
the feasibility of that suggestion.
|
||||
|
||||
For example,
|
||||
|
||||
def func(x, b=1.0):
|
||||
y = x + b
|
||||
z = y.relu()
|
||||
return (y, z)
|
||||
|
||||
x = torch.randn(1, 1, 2, dtype=torch.float32)
|
||||
gm_fun, _ = dynamo.export(func, x, b=8.0, aten_graph=True, tracing_mode="real")
|
||||
|
||||
# class GraphModule(torch.nn.Module):
|
||||
# def forward(self, x, b):
|
||||
# arg0: f32[1, 1, 2], arg1, = fx_pytree.tree_flatten_spec(([x, b], {}), self._in_spec)
|
||||
# # File: path/to/pytorch/test_constant_input.py:5, code: y = x + b
|
||||
# add_tensor: f32[1, 1, 2] = torch.ops.aten.add.Tensor(arg0, 8.0); arg0 = None
|
||||
|
||||
# # File: path/to/pytorch/test_constant_input.py:6, code: z = y.relu()
|
||||
# relu_default: f32[1, 1, 2] = torch.ops.aten.relu.default(add_tensor)
|
||||
# return pytree.tree_unflatten([add_tensor, relu_default], self._out_spec)
|
||||
|
||||
Empty torch.fx.Node input leading to a mismatched number of input with PyTorch, as
|
||||
it's ignored in ONNX graph. Thus, we delete the useless input here.
|
||||
|
||||
"""
|
||||
|
||||
return tuple(
|
||||
arg for arg in model_args if not isinstance(arg, (int, float, bool, str))
|
||||
)
|
||||
|
||||
|
||||
def _convert_complex_to_real_representation(model_args):
|
||||
"""Convert complex dtype tensors to real representation tensors.
|
||||
|
||||
|
Reference in New Issue
Block a user