[ONNX] New export logic leveraging ExportedProgram and ONNX IR (#132530)

1/n PR to

- Move code from torch-onnx at commit 395495e566 into torch.onnx and fix imports.
- Integrate the new export logic with the torch.onnx.export API and include a basic set of tests.
- Refactor the API for the change.
- Improve documentation.

Subsequent PRs will add more tests and docs.
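
For reference, a minimal usage sketch of the new path as exercised by the tests added in this PR (the model and output path are illustrative; requires onnx and onnxscript to be installed):

```python
import torch


class SampleModelTwoInputs(torch.nn.Module):
    def forward(self, x, b):
        y = x + b
        return y, y.relu()


# dynamo=True routes torch.onnx.export through torch.export (ExportedProgram)
# and the new ONNX IR based exporter instead of the TorchScript exporter.
onnx_program = torch.onnx.export(
    SampleModelTwoInputs(),
    (torch.randn(1, 1, 2), torch.randn(1, 1, 2)),
    "model.onnx",  # optional output path
    dynamo=True,
)
```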

Fix https://github.com/pytorch/pytorch/issues/129277
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132530
Approved by: https://github.com/titaiwangms, https://github.com/malfet
Justin Chu
2024-08-18 13:14:17 -07:00
committed by PyTorch MergeBot
parent 92151c814b
commit 5fab35d77c
28 changed files with 5321 additions and 356 deletions

View File

@ -165,9 +165,6 @@ ignore_missing_imports = True
[mypy-tensorboard.*]
ignore_missing_imports = True
[mypy-onnx.*]
ignore_missing_imports = True
[mypy-matplotlib.*]
ignore_missing_imports = True
@ -301,5 +298,14 @@ ignore_missing_imports = True
# Third party dependencies that are optional.
#
[mypy-onnx.*]
ignore_missing_imports = True
[mypy-onnxruntime.*]
ignore_missing_imports = True
[mypy-onnxscript.*]
ignore_missing_imports = True
[mypy-redis]
ignore_missing_imports = True

View File

@ -203,222 +203,5 @@ class TestLargeProtobufONNXProgramSerializerAPI(common_utils.TestCase):
serializer.serialize(onnx_program, io.BytesIO())
class TestONNXExportWithDynamo(common_utils.TestCase):
def test_args_normalization_with_no_kwargs(self):
exported_program = torch.export.export(
SampleModelTwoInputs(),
(
torch.randn(1, 1, 2),
torch.randn(1, 1, 2),
),
)
onnx_program_from_new_exporter = torch.onnx.dynamo_export(
exported_program, torch.randn(1, 1, 2), torch.randn(1, 1, 2)
)
onnx_program_from_old_exporter = torch.onnx.export(
SampleModelTwoInputs(),
(torch.randn(1, 1, 2), torch.randn(1, 1, 2)),
dynamo=True,
)
self.assertEqual(
onnx_program_from_new_exporter.model_proto,
onnx_program_from_old_exporter.model_proto,
)
def test_args_is_tensor_not_tuple(self):
exported_program = torch.export.export(SampleModel(), (torch.randn(1, 1, 2),))
onnx_program_from_new_exporter = torch.onnx.dynamo_export(
exported_program, torch.randn(1, 1, 2)
)
onnx_program_from_old_exporter = torch.onnx.export(
SampleModel(), torch.randn(1, 1, 2), dynamo=True
)
self.assertEqual(
onnx_program_from_new_exporter.model_proto,
onnx_program_from_old_exporter.model_proto,
)
def test_args_normalization_with_kwargs(self):
exported_program = torch.export.export(
SampleModelTwoInputs(), (torch.randn(1, 1, 2),), {"b": torch.randn(1, 1, 2)}
)
onnx_program_from_new_exporter = torch.onnx.dynamo_export(
exported_program, torch.randn(1, 1, 2), b=torch.randn(1, 1, 2)
)
onnx_program_from_old_exporter = torch.onnx.export(
SampleModelTwoInputs(),
(torch.randn(1, 1, 2), {"b": torch.randn(1, 1, 2)}),
dynamo=True,
)
self.assertEqual(
onnx_program_from_new_exporter.model_proto,
onnx_program_from_old_exporter.model_proto,
)
def test_args_normalization_with_empty_dict_at_the_tail(self):
exported_program = torch.export.export(
SampleModelTwoInputs(), (torch.randn(1, 1, 2),), {"b": torch.randn(1, 1, 2)}
)
onnx_program_from_new_exporter = torch.onnx.dynamo_export(
exported_program, torch.randn(1, 1, 2), b=torch.randn(1, 1, 2)
)
onnx_program_from_old_exporter = torch.onnx.export(
SampleModelTwoInputs(),
(torch.randn(1, 1, 2), {"b": torch.randn(1, 1, 2)}),
dynamo=True,
)
self.assertEqual(
onnx_program_from_new_exporter.model_proto,
onnx_program_from_old_exporter.model_proto,
)
def test_dynamic_axes_enable_dynamic_shapes_with_fully_specified_axes(self):
exported_program = torch.export.export(
SampleModelForDynamicShapes(),
(
torch.randn(2, 2, 3),
torch.randn(2, 2, 3),
),
dynamic_shapes={
"x": {
0: torch.export.Dim("customx_dim_0"),
1: torch.export.Dim("customx_dim_1"),
2: torch.export.Dim("customx_dim_2"),
},
"b": {
0: torch.export.Dim("customb_dim_0"),
1: torch.export.Dim("customb_dim_1"),
2: torch.export.Dim("customb_dim_2"),
},
},
)
onnx_program_from_new_exporter = torch.onnx.dynamo_export(
exported_program,
torch.randn(2, 2, 3),
b=torch.randn(2, 2, 3),
)
onnx_program_from_old_exporter = torch.onnx.export(
SampleModelForDynamicShapes(),
(torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}),
dynamic_axes={
"x": {0: "customx_dim_0", 1: "customx_dim_1", 2: "customx_dim_2"},
"b": {0: "customb_dim_0", 1: "customb_dim_1", 2: "customb_dim_2"},
},
dynamo=True,
)
self.assertEqual(
onnx_program_from_new_exporter.model_proto,
onnx_program_from_old_exporter.model_proto,
)
def test_dynamic_axes_enable_dynamic_shapes_with_default_axe_names(self):
exported_program = torch.export.export(
SampleModelForDynamicShapes(),
(
torch.randn(2, 2, 3),
torch.randn(2, 2, 3),
),
dynamic_shapes={
"x": {
0: torch.export.Dim("customx_dim_0"),
1: torch.export.Dim("customx_dim_1"),
2: torch.export.Dim("customx_dim_2"),
},
"b": {
0: torch.export.Dim("customb_dim_0"),
1: torch.export.Dim("customb_dim_1"),
2: torch.export.Dim("customb_dim_2"),
},
},
)
onnx_program_from_new_exporter = torch.onnx.dynamo_export(
exported_program,
torch.randn(2, 2, 3),
b=torch.randn(2, 2, 3),
)
onnx_program_from_old_exporter = torch.onnx.export(
SampleModelForDynamicShapes(),
(torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}),
dynamic_axes={
"x": [0, 1, 2],
"b": [0, 1, 2],
},
dynamo=True,
)
self.assertEqual(
onnx_program_from_new_exporter.model_proto,
onnx_program_from_old_exporter.model_proto,
)
def test_dynamic_axes_supports_partial_dynamic_shapes(self):
exported_program = torch.export.export(
SampleModelForDynamicShapes(),
(
torch.randn(2, 2, 3),
torch.randn(2, 2, 3),
),
dynamic_shapes={
"x": None,
"b": {
0: torch.export.Dim("customb_dim_0"),
1: torch.export.Dim("customb_dim_1"),
2: torch.export.Dim("customb_dim_2"),
},
},
)
onnx_program_from_new_exporter = torch.onnx.dynamo_export(
exported_program,
torch.randn(2, 2, 3),
b=torch.randn(2, 2, 3),
)
onnx_program_from_old_exporter = torch.onnx.export(
SampleModelForDynamicShapes(),
(torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}),
dynamic_axes={
"b": [0, 1, 2],
},
dynamo=True,
)
self.assertEqual(
onnx_program_from_new_exporter.model_proto,
onnx_program_from_old_exporter.model_proto,
)
def test_dynamic_shapes_hit_constraints_in_dynamo(self):
# SampleModelTwoInputs has constraints because it adds the two inputs,
# so the two input shapes are related.
with self.assertRaisesRegex(
torch._dynamo.exc.UserError,
"Constraints violated",
):
_ = torch.onnx.export(
SampleModelTwoInputs(),
(torch.randn(2, 2, 3), torch.randn(2, 2, 3)),
dynamic_axes={
"x": {0: "x_dim_0", 1: "x_dim_1", 2: "x_dim_2"},
"b": {0: "b_dim_0", 1: "b_dim_1", 2: "b_dim_2"},
},
dynamo=True,
)
def test_saved_f_exists_after_export(self):
with common_utils.TemporaryFileName(suffix=".onnx") as path:
_ = torch.onnx.export(
SampleModel(), torch.randn(1, 1, 2), path, dynamo=True
)
self.assertTrue(os.path.exists(path))
def test_raises_error_when_input_is_script_module(self):
class ScriptModule(torch.jit.ScriptModule):
def forward(self, x):
return x
with self.assertRaisesRegex(
TypeError,
"Dynamo export does not support ScriptModule or ScriptFunction.",
):
_ = torch.onnx.export(ScriptModule(), torch.randn(1, 1, 2), dynamo=True)
if __name__ == "__main__":
common_utils.run_tests()

View File

@ -0,0 +1 @@
Directory for all ExportedProgram exporter logic.

View File

@ -0,0 +1,120 @@
# Owner(s): ["module: onnx"]
"""Simple API tests for the ONNX exporter."""
from __future__ import annotations
import os
import torch
from torch.onnx._internal import exporter
from torch.testing._internal import common_utils
class SampleModel(torch.nn.Module):
def forward(self, x):
y = x + 1
z = y.relu()
return (y, z)
class SampleModelTwoInputs(torch.nn.Module):
def forward(self, x, b):
y = x + b
z = y.relu()
return (y, z)
class SampleModelForDynamicShapes(torch.nn.Module):
def forward(self, x, b):
return x.relu(), b.sigmoid()
class TestExportAPIDynamo(common_utils.TestCase):
"""Tests for the ONNX exporter API when dynamo=True."""
def test_args_normalization_with_no_kwargs(self):
onnx_program = torch.onnx.export(
SampleModelTwoInputs(),
(torch.randn(1, 1, 2), torch.randn(1, 1, 2)),
dynamo=True,
)
assert onnx_program
exporter.verify_onnx_program(onnx_program)
def test_args_normalization_with_kwargs(self):
onnx_program = torch.onnx.export(
SampleModelTwoInputs(),
(torch.randn(1, 1, 2), {"b": torch.randn(1, 1, 2)}),
dynamo=True,
)
assert onnx_program
exporter.verify_onnx_program(onnx_program)
def test_args_normalization_with_empty_dict_at_the_tail(self):
onnx_program = torch.onnx.export(
SampleModelTwoInputs(),
(torch.randn(1, 1, 2), {"b": torch.randn(1, 1, 2)}),
dynamo=True,
)
assert onnx_program
exporter.verify_onnx_program(onnx_program)
def test_dynamic_axes_enable_dynamic_shapes_with_fully_specified_axes(self):
onnx_program = torch.onnx.export(
SampleModelForDynamicShapes(),
(torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}),
dynamic_axes={
"x": {0: "customx_dim_0", 1: "customx_dim_1", 2: "customx_dim_2"},
"b": {0: "customb_dim_0", 1: "customb_dim_1", 2: "customb_dim_2"},
},
dynamo=True,
)
assert onnx_program
exporter.verify_onnx_program(onnx_program)
def test_dynamic_axes_enable_dynamic_shapes_with_default_axe_names(self):
onnx_program = torch.onnx.export(
SampleModelForDynamicShapes(),
(torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}),
dynamic_axes={
"x": [0, 1, 2],
"b": [0, 1, 2],
},
dynamo=True,
)
assert onnx_program
exporter.verify_onnx_program(onnx_program)
def test_dynamic_axes_supports_partial_dynamic_shapes(self):
onnx_program = torch.onnx.export(
SampleModelForDynamicShapes(),
(torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}),
dynamic_axes={
"b": [0, 1, 2],
},
dynamo=True,
)
assert onnx_program
exporter.verify_onnx_program(onnx_program)
def test_saved_f_exists_after_export(self):
with common_utils.TemporaryFileName(suffix=".onnx") as path:
_ = torch.onnx.export(
SampleModel(), (torch.randn(1, 1, 2),), path, dynamo=True
)
self.assertTrue(os.path.exists(path))
def test_export_supports_script_module(self):
class ScriptModule(torch.nn.Module):
def forward(self, x):
return x
onnx_program = torch.onnx.export(
torch.jit.script(ScriptModule()), (torch.randn(1, 1, 2),), dynamo=True
)
assert onnx_program
exporter.verify_onnx_program(onnx_program)
if __name__ == "__main__":
common_utils.run_tests()

View File

@ -286,6 +286,25 @@ class TestPublicBindings(TestCase):
# do not get imported by public code.
private_allowlist = {
"torch._inductor.codegen.cuda.cuda_kernel",
# TODO(#133647): Remove the onnx._internal entries after
# onnx and onnxscript are installed in CI.
"torch.onnx._internal.exporter",
"torch.onnx._internal.exporter._analysis",
"torch.onnx._internal.exporter._building",
"torch.onnx._internal.exporter._capture_strategies",
"torch.onnx._internal.exporter._compat",
"torch.onnx._internal.exporter._core",
"torch.onnx._internal.exporter._decomp",
"torch.onnx._internal.exporter._dispatching",
"torch.onnx._internal.exporter._fx_passes",
"torch.onnx._internal.exporter._ir_passes",
"torch.onnx._internal.exporter._isolated",
"torch.onnx._internal.exporter._onnx_program",
"torch.onnx._internal.exporter._registration",
"torch.onnx._internal.exporter._reporting",
"torch.onnx._internal.exporter._schemas",
"torch.onnx._internal.exporter._tensors",
"torch.onnx._internal.exporter._verification",
"torch.onnx._internal.fx._pass",
"torch.onnx._internal.fx.analysis",
"torch.onnx._internal.fx.analysis.unsupported_nodes",

View File

@ -1,65 +1,5 @@
# mypy: allow-untyped-defs
from torch import _C
from torch._C import _onnx as _C_onnx
from torch._C._onnx import OperatorExportTypes, TensorProtoDataType, TrainingMode
from ._exporter_states import ExportTypes
from ._internal.onnxruntime import (
is_onnxrt_backend_supported,
OrtBackend as _OrtBackend,
OrtBackendOptions as _OrtBackendOptions,
OrtExecutionProvider as _OrtExecutionProvider,
)
from ._type_utils import JitScalarType
from .errors import CheckerError # Backwards compatibility
from .utils import (
_optimize_graph,
_run_symbolic_function,
_run_symbolic_method,
export,
export_to_pretty_string,
is_in_onnx_export,
register_custom_op_symbolic,
select_model_mode_for_export,
unregister_custom_op_symbolic,
)
from . import ( # usort: skip. Keep the order instead of sorting lexicographically
_deprecation,
errors,
symbolic_caffe2,
symbolic_helper,
symbolic_opset7,
symbolic_opset8,
symbolic_opset9,
symbolic_opset10,
symbolic_opset11,
symbolic_opset12,
symbolic_opset13,
symbolic_opset14,
symbolic_opset15,
symbolic_opset16,
symbolic_opset17,
symbolic_opset18,
symbolic_opset19,
symbolic_opset20,
utils,
)
from ._internal._exporter_legacy import ( # usort: skip. needs to be last to avoid circular import
DiagnosticOptions,
ExportOptions,
ONNXProgram,
ONNXProgramSerializer,
ONNXRuntimeOptions,
InvalidExportOptionsError,
OnnxExporterError,
OnnxRegistry,
dynamo_export,
enable_fake_mode,
)
from __future__ import annotations
__all__ = [
@ -115,6 +55,74 @@ __all__ = [
"is_onnxrt_backend_supported",
]
from typing import Any, Collection, Mapping, Sequence, TYPE_CHECKING
import torch
from torch import _C
from torch._C import _onnx as _C_onnx
from torch._C._onnx import OperatorExportTypes, TensorProtoDataType, TrainingMode
from ._exporter_states import ExportTypes
from ._internal.onnxruntime import (
is_onnxrt_backend_supported,
OrtBackend as _OrtBackend,
OrtBackendOptions as _OrtBackendOptions,
OrtExecutionProvider as _OrtExecutionProvider,
)
from ._type_utils import JitScalarType
from .errors import CheckerError # Backwards compatibility
from .utils import (
_optimize_graph,
_run_symbolic_function,
_run_symbolic_method,
export_to_pretty_string,
is_in_onnx_export,
register_custom_op_symbolic,
select_model_mode_for_export,
unregister_custom_op_symbolic,
)
from . import ( # usort: skip. Keep the order instead of sorting lexicographically
_deprecation,
errors,
symbolic_caffe2,
symbolic_helper,
symbolic_opset7,
symbolic_opset8,
symbolic_opset9,
symbolic_opset10,
symbolic_opset11,
symbolic_opset12,
symbolic_opset13,
symbolic_opset14,
symbolic_opset15,
symbolic_opset16,
symbolic_opset17,
symbolic_opset18,
symbolic_opset19,
symbolic_opset20,
utils,
)
from ._internal._exporter_legacy import ( # usort: skip. needs to be last to avoid circular import
DiagnosticOptions,
ExportOptions,
ONNXProgram,
ONNXProgramSerializer,
ONNXRuntimeOptions,
InvalidExportOptionsError,
OnnxExporterError,
OnnxRegistry,
dynamo_export,
enable_fake_mode,
)
if TYPE_CHECKING:
import os
# Set namespace for exposed private names
ExportTypes.__module__ = "torch.onnx"
JitScalarType.__module__ = "torch.onnx"
@ -137,6 +145,257 @@ producer_name = "pytorch"
producer_version = _C_onnx.PRODUCER_VERSION
def export(
model: torch.nn.Module
| torch.export.ExportedProgram
| torch.jit.ScriptModule
| torch.jit.ScriptFunction,
args: tuple[Any, ...],
f: str | os.PathLike | None = None,
*,
kwargs: dict[str, Any] | None = None,
export_params: bool = True,
verbose: bool | None = None,
input_names: Sequence[str] | None = None,
output_names: Sequence[str] | None = None,
opset_version: int | None = None,
dynamic_axes: Mapping[str, Mapping[int, str]]
| Mapping[str, Sequence[int]]
| None = None,
keep_initializers_as_inputs: bool = False,
dynamo: bool = False,
# Dynamo only options
external_data: bool = True,
dynamic_shapes: dict[str, Any] | tuple[Any, ...] | list[Any] | None = None,
report: bool = False,
verify: bool = False,
profile: bool = False,
dump_exported_program: bool = False,
artifacts_dir: str | os.PathLike = ".",
fallback: bool = False,
# Deprecated options
training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL,
operator_export_type: _C_onnx.OperatorExportTypes = _C_onnx.OperatorExportTypes.ONNX,
do_constant_folding: bool = True,
custom_opsets: Mapping[str, int] | None = None,
export_modules_as_functions: bool | Collection[type[torch.nn.Module]] = False,
autograd_inlining: bool = True,
**_: Any, # ignored options
) -> Any | None:
r"""Exports a model into ONNX format.
Args:
model: The model to be exported.
args: Example positional inputs. Any non-Tensor arguments will be hard-coded into the
exported model; any Tensor arguments will become inputs of the exported model,
in the order they occur in the tuple.
f: Path to the output ONNX model file. E.g. "model.onnx".
kwargs: Optional example keyword inputs.
export_params: If false, parameters (weights) will not be exported.
verbose: Whether to enable verbose logging.
input_names: names to assign to the input nodes of the graph, in order.
output_names: names to assign to the output nodes of the graph, in order.
opset_version: The version of the
`default (ai.onnx) opset <https://github.com/onnx/onnx/blob/master/docs/Operators.md>`_
to target. Must be >= 7.
dynamic_axes:
By default the exported model will have the shapes of all input and output tensors
set to exactly match those given in ``args``. To specify axes of tensors as
dynamic (i.e. known only at run-time), set ``dynamic_axes`` to a dict with schema:
* KEY (str): an input or output name. Each name must also be provided in ``input_names`` or
``output_names``.
* VALUE (dict or list): If a dict, keys are axis indices and values are axis names. If a
list, each element is an axis index.
For example::
class SumModule(torch.nn.Module):
def forward(self, x):
return torch.sum(x, dim=1)
torch.onnx.export(
SumModule(),
(torch.ones(2, 2),),
"onnx.pb",
input_names=["x"],
output_names=["sum"]
)
Produces::
input {
name: "x"
...
shape {
dim {
dim_value: 2 # axis 0
}
dim {
dim_value: 2 # axis 1
...
output {
name: "sum"
...
shape {
dim {
dim_value: 2 # axis 0
...
While::
torch.onnx.export(
SumModule(),
(torch.ones(2, 2),),
"onnx.pb",
input_names=["x"],
output_names=["sum"],
dynamic_axes={
# dict value: manually named axes
"x": {0: "my_custom_axis_name"},
# list value: automatic names
"sum": [0],
}
)
Produces::
input {
name: "x"
...
shape {
dim {
dim_param: "my_custom_axis_name" # axis 0
}
dim {
dim_value: 2 # axis 1
...
output {
name: "sum"
...
shape {
dim {
dim_param: "sum_dynamic_axes_1" # axis 0
...
keep_initializers_as_inputs: If True, all the
initializers (typically corresponding to model weights) in the
exported graph will also be added as inputs to the graph. If False,
then initializers are not added as inputs to the graph, and only
the user inputs are added as inputs.
Set this to True if you intend to supply model weights at runtime.
Set it to False if the weights are static to allow for better optimizations
(e.g. constant folding) by backends/runtimes.
dynamo: Whether to export the model with ``torch.export`` ExportedProgram instead of TorchScript.
external_data: Whether to save the model weights as an external data file.
This is required for models with large weights that exceed the ONNX file size limit (2GB).
When False, the weights are saved in the ONNX file with the model architecture.
dynamic_shapes: A dictionary of dynamic shapes for the model inputs. Refer to
:func:`torch.export.export` for more details.
report: Whether to generate a markdown report for the export process.
verify: Whether to verify the exported model using ONNX Runtime.
profile: Whether to profile the export process.
dump_exported_program: Whether to dump the :class:`torch.export.ExportedProgram` to a file.
This is useful for debugging the exporter.
artifacts_dir: The directory to save the debugging artifacts like the report and the serialized
exported program.
fallback: Whether to fall back to the TorchScript exporter if the dynamo exporter fails.
training: Deprecated option. Instead, set the training mode of the model before exporting.
operator_export_type: Deprecated option. Only ONNX is supported.
do_constant_folding: Deprecated option. The exported graph is always optimized.
custom_opsets: Deprecated.
A dictionary:
* KEY (str): opset domain name
* VALUE (int): opset version
If a custom opset is referenced by ``model`` but not mentioned in this dictionary,
the opset version is set to 1. Only custom opset domain name and version should be
indicated through this argument.
export_modules_as_functions: Deprecated option.
Flag to enable
exporting all ``nn.Module`` forward calls as local functions in ONNX. Or a set to indicate the
particular types of modules to export as local functions in ONNX.
This feature requires ``opset_version`` >= 15, otherwise the export will fail. This is because
``opset_version`` < 15 implies IR version < 8, which means no local function support.
Module variables will be exported as function attributes. There are two categories of function
attributes.
1. Annotated attributes: class variables that have type annotations via
`PEP 526-style <https://www.python.org/dev/peps/pep-0526/#class-and-instance-variable-annotations>`_
will be exported as attributes.
Annotated attributes are not used inside the subgraph of ONNX local function because
they are not created by PyTorch JIT tracing, but they may be used by consumers
to determine whether or not to replace the function with a particular fused kernel.
2. Inferred attributes: variables that are used by operators inside the module. Attribute names
will have prefix "inferred::". This is to differentiate from predefined attributes retrieved from
python module annotations. Inferred attributes are used inside the subgraph of ONNX local function.
* ``False`` (default): export ``nn.Module`` forward calls as fine grained nodes.
* ``True``: export all ``nn.Module`` forward calls as local function nodes.
* Set of type of nn.Module: export ``nn.Module`` forward calls as local function nodes,
only if the type of the ``nn.Module`` is found in the set.
autograd_inlining: Deprecated.
Flag used to control whether to inline autograd functions.
Refer to https://github.com/pytorch/pytorch/pull/74765 for more details.
"""
if dynamo is True or isinstance(model, torch.export.ExportedProgram):
from torch.onnx._internal import exporter
if isinstance(args, torch.Tensor):
args = (args,)
return exporter.export_compat(
model,
args,
f,
kwargs=kwargs,
export_params=export_params,
verbose=verbose,
input_names=input_names,
output_names=output_names,
opset_version=opset_version,
dynamic_axes=dynamic_axes,
keep_initializers_as_inputs=keep_initializers_as_inputs,
external_data=external_data,
dynamic_shapes=dynamic_shapes,
report=report,
verify=verify,
profile=profile,
dump_exported_program=dump_exported_program,
artifacts_dir=artifacts_dir,
fallback=fallback,
)
else:
from torch.onnx.utils import export
export(
model,
args,
f, # type: ignore[arg-type]
kwargs=kwargs,
export_params=export_params,
verbose=verbose is True,
input_names=input_names,
output_names=output_names,
opset_version=opset_version,
dynamic_axes=dynamic_axes,
keep_initializers_as_inputs=keep_initializers_as_inputs,
training=training,
operator_export_type=operator_export_type,
do_constant_folding=do_constant_folding,
custom_opsets=custom_opsets,
export_modules_as_functions=export_modules_as_functions,
autograd_inlining=autograd_inlining,
)
return None
@_deprecation.deprecated(
since="1.12.0", removed_in="2.0", instructions="use `torch.onnx.export` instead"
)
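
As a complement to the ``dynamic_axes`` examples in the docstring above, a minimal sketch of the dynamo path using ``dynamic_shapes`` instead (the Dim name and output path are illustrative):

```python
import torch


class SumModule(torch.nn.Module):
    def forward(self, x):
        return torch.sum(x, dim=1)


batch = torch.export.Dim("batch")  # symbolic dim, as in torch.export.export
torch.onnx.export(
    SumModule(),
    (torch.ones(2, 2),),
    "sum.onnx",
    dynamic_shapes={"x": {0: batch}},  # keyed by forward() argument name
    dynamo=True,
)
```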

View File

@ -0,0 +1,38 @@
"""Utility to lazily import modules."""
# mypy: allow-untyped-defs
from __future__ import annotations
import importlib
from typing import Any, TYPE_CHECKING
class _LazyModule:
"""Lazily import a module."""
def __init__(self, module_name: str) -> None:
self._name = module_name
self._module: Any = None
def __repr__(self) -> str:
return f"<lazy module '{self._name}'>"
def __getattr__(self, attr):
if self._module is None:
self._module = importlib.import_module(".", self._name)
return getattr(self._module, attr)
# Import the following modules during type checking to enable code intelligence features,
# such as auto-completion in tools like pylance, even when these modules are not explicitly
# imported in user code.
# NOTE: Add additional used imports here.
if TYPE_CHECKING:
import onnx
import onnxscript
onnxscript_ir = onnxscript.ir
else:
onnx = _LazyModule("onnx")
onnxscript = _LazyModule("onnxscript")
onnxscript_ir = _LazyModule("onnxscript.ir")
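
A standalone sketch of the same lazy-import pattern using only the standard library, to illustrate that the wrapped module is imported only on first attribute access (names here are illustrative, not part of the diff):

```python
import importlib
from typing import Any


class LazyModule:
    """Defer importing a module until one of its attributes is accessed."""

    def __init__(self, module_name: str) -> None:
        self._name = module_name
        self._module: Any = None

    def __getattr__(self, attr: str) -> Any:
        if self._module is None:
            # The real import happens here, on first use.
            self._module = importlib.import_module(self._name)
        return getattr(self._module, attr)


json = LazyModule("json")        # nothing imported yet
print(json.dumps({"ok": True}))  # "json" is imported on this first access
```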

View File

@ -0,0 +1,16 @@
__all__ = [
"ONNXRegistry",
"ONNXProgram",
"analyze",
"export",
"exported_program_to_ir",
"verify_onnx_program",
"export_compat",
]
from ._analysis import analyze
from ._compat import export_compat
from ._core import export, exported_program_to_ir
from ._onnx_program import ONNXProgram
from ._registration import ONNXRegistry
from ._verification import verify_onnx_program

View File

@ -0,0 +1,250 @@
"""Compatibility analyzer for PyTorch models."""
# mypy: allow-untyped-defs
# flake8: noqa: B950 We do not need flake8 as it complains about line length
from __future__ import annotations
import dataclasses
import textwrap
import traceback
from collections import defaultdict
from typing import TYPE_CHECKING
import onnxscript
import torch
import torch._export.serde.schema
from torch.export import graph_signature
from torch.onnx._internal.exporter import _dispatching, _registration
if TYPE_CHECKING:
import torch.fx
@dataclasses.dataclass
class ModelInfo:
"""Information about the model."""
parameter_count: defaultdict[torch.dtype, int] = dataclasses.field(
default_factory=lambda: defaultdict(int)
)
buffer_count: defaultdict[torch.dtype, int] = dataclasses.field(
default_factory=lambda: defaultdict(int)
)
fx_node_count: int = 0
fx_node_op_count: defaultdict[str, int] = dataclasses.field(
default_factory=lambda: defaultdict(int)
)
fx_node_target_count: defaultdict[str, int] = dataclasses.field(
default_factory=lambda: defaultdict(int)
)
dispatch_failures: list[tuple[torch.fx.Node, str]] = dataclasses.field(
default_factory=list
)
inputs: dict[str, torch._export.serde.schema.TensorMeta] = dataclasses.field(
default_factory=dict
)
outputs: dict[str, torch._export.serde.schema.TensorMeta] = dataclasses.field(
default_factory=dict
)
def _count_weights(
exported_program: torch.export.ExportedProgram,
) -> tuple[defaultdict[torch.dtype, int], defaultdict[torch.dtype, int]]:
"""Count the size of the parameters in the exported program."""
parameter_count: defaultdict[torch.dtype, int] = defaultdict(int)
buffer_count: defaultdict[torch.dtype, int] = defaultdict(int)
for parameter in exported_program.parameters():
dtype = parameter.dtype
parameter_count[dtype] += parameter.numel()
for buffer in exported_program.buffers():
dtype = buffer.dtype
buffer_count[dtype] += buffer.numel()
return parameter_count, buffer_count
def _format_model_info(model_info: ModelInfo) -> str:
"""Format the information about the model."""
lines = [
textwrap.dedent(
f"""\
PyTorch ONNX Conversion Analysis
## Model Information
The model has {sum(model_info.parameter_count.values())} parameters and {sum(model_info.buffer_count.values())} buffers (non-trainable parameters).
Number of parameters per dtype:
```python
{model_info.parameter_count}
```
Number of buffers per dtype:
```python
{model_info.buffer_count}
```
"""
),
"Inputs:",
*[f"- `{name}`: `{meta}`" for name, meta in model_info.inputs.items()],
"",
"Outputs:",
*[f"- `{name}`: `{meta}`" for name, meta in model_info.outputs.items()],
"",
f"The FX graph has {model_info.fx_node_count} nodes in total. Number of FX nodes per op:",
]
for op, count in model_info.fx_node_op_count.items():
lines.append(f"- `{op}`: {count}")
lines.append("\n")
lines.append("Of the call_function nodes, the counts of operators used are:\n")
sorted_targets = sorted(
model_info.fx_node_target_count.items(), key=lambda x: x[1], reverse=True
)
for target, count in sorted_targets:
lines.append(f"- `{target}`: {count}")
lines.append("")
lines.append("## ONNX Conversion Information")
lines.append("")
if model_info.dispatch_failures:
lines.append(
"The model contains operators the dispatcher could not find registered ONNX decompositions for. "
"This may be due to missing implementations, decompositions not registered "
"correctly, or a bug in the dispatcher."
)
lines.append("")
lines.append("Errors grouped by operator:\n")
target_to_nodes = defaultdict(list)
for node, _ in model_info.dispatch_failures:
target_to_nodes[str(node.target)].append(node)
target_to_messages = {}
for node, message in model_info.dispatch_failures:
if str(node.target) not in target_to_messages:
target_to_messages[str(node.target)] = message
for target, nodes in sorted(
target_to_nodes.items(), key=lambda x: x[0], reverse=True
):
message = textwrap.indent(
f"{target_to_messages[target]}. Example node: `{nodes[0].format_node()}`. All nodes: `{nodes}`",
" ",
)
lines.append(f"- `{target}`: {message}")
else:
lines.append("All operators in the model have registered ONNX decompositions.")
return "\n".join(lines)
def _get_io_specs(exported_program: torch.export.ExportedProgram) -> tuple[dict, dict]:
"""Get the input and output specs of the exported program."""
nodes: dict[str, torch.fx.Node] = {
node.name: node for node in exported_program.graph.nodes
}
user_inputs = [
spec
for spec in exported_program.graph_signature.input_specs
if spec.kind == graph_signature.InputKind.USER_INPUT
]
user_outputs = [
spec
for spec in exported_program.graph_signature.output_specs
if spec.kind == graph_signature.OutputKind.USER_OUTPUT
]
inputs: dict[str, torch._export.serde.schema.TensorMeta] = {}
outputs: dict[str, torch._export.serde.schema.TensorMeta] = {}
for spec in user_inputs:
if isinstance(spec.arg, graph_signature.ConstantArgument):
continue
name = spec.arg.name
# FIXME: tensor_meta is sometimes None even though the exported program still knows the shape/type
inputs[name] = nodes[name].meta["tensor_meta"]
for spec in user_outputs:
if isinstance(spec.arg, graph_signature.ConstantArgument):
continue
name = spec.arg.name
outputs[name] = nodes[name].meta["tensor_meta"]
return inputs, outputs
def _count_fx_targets(
exported_program: torch.export.ExportedProgram,
) -> defaultdict[str, int]:
"""Count the number of targets for each node in the exported program."""
fx_node_target_count: defaultdict[str, int] = defaultdict(int)
for node in exported_program.graph.nodes:
if node.op == "call_function":
fx_node_target_count[str(node.target)] += 1
return fx_node_target_count
def analyze(
exported_program: torch.export.ExportedProgram,
registry: _registration.ONNXRegistry | None = None,
file=None,
) -> None:
"""Analyze the compatibility of the exported program."""
# Get basic information about the model
model_info = ModelInfo()
model_info.parameter_count, model_info.buffer_count = _count_weights(
exported_program
)
model_info.fx_node_count = len(exported_program.graph.nodes)
model_info.fx_node_target_count = _count_fx_targets(exported_program)
inputs, outputs = _get_io_specs(exported_program)
model_info.inputs = inputs
model_info.outputs = outputs
if registry is None:
# Trigger op registration
from onnxscript.function_libs.torch_lib import ops # noqa: F401
del ops
registry = _registration.ONNXRegistry.from_torchlib(
onnxscript.function_libs.torch_lib.registration.default_registry # type: ignore[arg-type]
)
# Try to find ops for every node in the graph
for node in exported_program.graph.nodes:
model_info.fx_node_op_count[node.op] += 1
if node.op == "call_function":
try:
onnx_function, message = _dispatching.dispatch(node, registry)
except Exception as e:
message = "Critical Error in dispatcher:\n"
formatted_exception = "\n".join(
traceback.format_exception(type(e), e, e.__traceback__)
)
message += f"```pytb\n{formatted_exception}\n```\n"
onnx_function = None
if onnx_function is None:
model_info.dispatch_failures.append((node, message))
# Print the results
report = _format_model_info(model_info)
print(report, file=file, flush=True)
def compare_ops(
program_a: torch.export.ExportedProgram, program_b: torch.export.ExportedProgram
) -> tuple[set[str], set[str]]:
"""Compare and get unique ops in two exported programs.
Args:
program_a: The first exported program.
program_b: The second exported program.
Returns:
A tuple of two sets, where the first set contains the unique ops in the first program
and the second set contains the unique ops in the second program.
"""
program_a_ops = set(_count_fx_targets(program_a))
program_b_ops = set(_count_fx_targets(program_b))
return program_a_ops - program_b_ops, program_b_ops - program_a_ops
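
A minimal usage sketch of the analysis helper added here (the model is a placeholder; the call assumes onnxscript with its torch_lib ops is installed so the default registry can be built):

```python
import torch
from torch.onnx._internal.exporter import _analysis


class Model(torch.nn.Module):
    def forward(self, x):
        return x.relu()


ep = torch.export.export(Model(), (torch.randn(2, 3),))
# Prints a markdown-style compatibility report listing parameters, FX ops,
# and any call_function nodes with no registered ONNX decomposition.
_analysis.analyze(ep)
```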

View File

@ -0,0 +1,516 @@
"""NOTES:
We need a typing module that will handle Python to ONNX type promotion for us.
For example, if we have torch.ops.aten.add(Tensor, 1.0), we need to promote 1.0
to the same type as Tensor. The same thing needs to work for
torch.ops.aten.add(1.0, Tensor) as well, which means we need a mechanism that
handles promotion regardless of argument order.
"""
# mypy: allow-untyped-defs
# mypy: disable-error-code=union-attr
from __future__ import annotations
import copy
import inspect
import logging
from typing import Any, Mapping, Sequence, TYPE_CHECKING, Union
import onnxscript
from onnxscript import evaluator, ir
from onnxscript.ir import convenience as ir_convenience
import torch
from torch.onnx._internal.exporter import _schemas, _tensors, errors
if TYPE_CHECKING:
import onnx
logger = logging.getLogger(__name__)
# TODO(justinchuby): Update ValidAttributeType to ir_convenience.SupportedAttrTypes
ValidAttributeType = Union[
ir.TensorProtocol, int, float, bool, str, Sequence[int], Sequence[float], None
]
AllowedArgType = Union[ir.Value, Sequence[ir.Value], ValidAttributeType]
# Logic for adapting inputs from general Python or PyTorch inputs to ONNX ir.Value
def _construct_named_inputs_and_attrs(
signature: _schemas.OpSignature,
args: Sequence[AllowedArgType],
kwargs: Mapping[str, AllowedArgType],
) -> tuple[dict[str, AllowedArgType], dict[str, ValidAttributeType]]:
"""Construct two mappings: name to inputs and named to attributes based on the signature and args/kwargs.
This function uses the OpSignature to determine which argument in args and kwargs corresponds to
which parameter in the signature. ONNX node inputs are stored in named_inputs, and attributes are
stored in named_attrs. If an _optional input_ is not provided, it is filled with None.
Args:
signature: The OpSignature for the node.
args: The positional arguments for the node.
kwargs: The keyword arguments for the node.
Returns:
A tuple of two mappings: named_inputs and named_attrs.
Raises:
ValueError: If a required parameter is not provided.
"""
# 1. Construct the (named_inputs, named_attrs) mapping based on (args, kwargs) and the signature.
# a. Loop over all parameters in the signature and args together
# b. Depending on param.is_input, Record named_inputs[param.name] = arg or named_attrs[param.name] = arg
# c. Handle kwargs as well
# d. Fill in None if the input is not provided
named_inputs = {}
named_attrs = {}
reversed_args_stack = list(reversed(args))
for param in signature.params:
if isinstance(param, _schemas.Parameter):
# Handle inputs
if reversed_args_stack:
# First exhaust the positional arguments
if param.variadic:
# Handle variadic arguments
named_inputs[param.name] = tuple(args)
reversed_args_stack.clear()
else:
named_inputs[param.name] = reversed_args_stack.pop() # type: ignore[assignment]
elif param.name in kwargs:
named_inputs[param.name] = kwargs[param.name] # type: ignore[assignment]
elif param.required:
raise ValueError(
f"Required parameter '{param.name}' is not provided. "
f"Signature: {signature}. Args: {args}. Kwargs: {kwargs}."
)
else:
logger.debug(
"Optional parameter '%s' is not provided. Added as None. Signature: %s",
param.name,
signature,
)
named_inputs[param.name] = None # type: ignore[assignment]
else:
# Handle attributes
attribute: ValidAttributeType | ir.Attr
assert isinstance(
param, _schemas.AttributeParameter
), f"Expected AttributeParameter, got {type(param)}"
if reversed_args_stack:
# First exhaust the positional arguments
attribute = reversed_args_stack.pop() # type: ignore[assignment]
elif param.name in kwargs:
attribute = kwargs[param.name] # type: ignore[assignment]
elif param.default is not None:
attribute = param.default
else:
attribute = None
if attribute is None:
if param.required:
raise ValueError(
f"Required attribute '{param.name}' is not provided. "
f"Signature: {signature}. Args: {args}. Kwargs: {kwargs}."
)
else:
logger.debug(
"Optional attribute '%s' is None. Dropped. Signature: %s",
param.name,
signature,
)
continue
if isinstance(attribute, ir.Attr):
# Turn the attribute from a default value into an actual parameter for the node
attr_copied = copy.copy(attribute)
# Make sure the name is the same as the parameter name and not the name of the default parameter
attr_copied.name = param.name
attribute = attr_copied
if isinstance(attribute, int) and param.type == ir.AttributeType.FLOAT:
# Convert the attribute to float if needed. This happens in PyTorch
# where an attribute marked as float can be passed as an int.
attribute = float(attribute)
named_attrs[param.name] = attribute
return named_inputs, named_attrs # type: ignore[return-value]
def _resolve_parameter_dtypes(
signature: _schemas.OpSignature, named_inputs: Mapping[str, AllowedArgType]
) -> Mapping[_schemas.TypeConstraintParam, ir.TypeProtocol]:
"""Determine which parameter takes which type.
Handle non-tensor input corner cases and type promotion.
Requires:
All ir.Value in named_inputs should have their type set. Their type should be
compatible with the type_constraint of the corresponding parameter in the signature.
Args:
signature: The OpSignature for the node.
named_inputs: The mapping of parameter names to their arguments.
Returns:
A mapping of Constraint names to ir.TypeProtocol.
"""
# a. Create type_binding: dict[str, ir.TypeProtocol]
# b. Iterate over all named_inputs
# b0. Find the corresponding parameter in the signature
# b1. If the argument is a Python constant, skip.
# b2. If the argument is a ir.Value, Bind {constraint: arg.type}.
type_binding = {}
for name, arg in named_inputs.items():
param = signature.params_map[name]
assert isinstance(
param, _schemas.Parameter
), f"Expected Parameter, got {type(param)}"
if isinstance(arg, (int, float, bool, str, Sequence, torch.Tensor)):
# Skip the Python constants because we do not know what dtype they should take yet
continue
elif isinstance(arg, ir.Value):
if arg.type is None:
# Skip the ir.Value if the type is not set
continue
# NOTE: We assume arg.type is compatible with the type_constraint
assert arg.type is not None, f"Expected type to be set for {arg}"
# TODO(justinchuby): Implement type promotion logic here.
type_binding[param.type_constraint] = arg.type
return type_binding
def _process_python_constants_and_sequences(
signature: _schemas.OpSignature,
named_inputs: dict[str, AllowedArgType],
type_binding: Mapping[_schemas.TypeConstraintParam, ir.TypeProtocol],
constant_farm: dict[
tuple[
bool | int | float | str | ir.TensorProtocol | tuple[int] | tuple[float],
ir.DataType,
],
ir.Value,
],
opset: onnxscript.values.Opset,
) -> dict[str, ir.Value | None]:
"""Convert Python constants to Constant nodes and list to Sequence nodes based on the dtype information.
The added constants will be replacing values in named_inputs in place.
Args:
signature: The OpSignature for the node.
named_inputs: The mapping of parameter names to their arguments.
type_binding: A mapping of Constraint names to ir.DataType.
constant_farm: A dictionary of {(py_value, ir.DataType): ir.Value} to store the deduplicated constants.
opset: The Opset to use for creating Constant nodes.
Returns:
The updated named_inputs mapping, with Python constants replaced by ir.Value constants.
"""
# 3. Convert Python constants to Constant nodes based on the dtype information;
# construct sequences
# a. Iterate over all parameters in the signature the second time
# b. If the parameter is in to_resolve_type:
# - If param.constraint in type_binding,
# Get the constant from constant_farm (deduplicated);
# otherwise set named_inputs[param.name] = Constant(value, dtype=type_binding[param.constraint])
# - Otherwise, set named_inputs[param.name] = Constant(value)
for name, arg in named_inputs.items():
param = signature.params_map[name]
assert isinstance(
param, _schemas.Parameter
), f"Expected Parameter, got {type(param)}"
if isinstance(arg, ir.Value):
# TODO(justinchuby): Cast the ir.Value here if needed
continue
if (
isinstance(arg, Sequence)
and len(arg) > 0
and all(isinstance(val, ir.Value) for val in arg)
):
# Skip the sequence of ir.Value. This is a variadic input or a Sequence input
# NOTE: Variadic operators like Max can be called with mixed ir.Value and Python constants
# like `Max(0, ir.Value())`
# We need to convert the Python constants to Constant nodes
# NOTE: Important to check that arg is not empty because we need to treat it as list[int] or list[float]
continue
# if param.variadic:
# # FIXME: Handle variadic inputs and sequence inputs differently
# raise NotImplementedError
# TODO: Find a way to recursively build constants. Maybe extract the logic out.
# FIXME: I am here
assert isinstance(
param, _schemas.Parameter
), f"Expected Parameter, got {type(param)}"
if param.type_constraint in type_binding:
# A known dtype is available
dtype = type_binding[param.type_constraint].dtype
elif len(param.type_constraint.allowed_types) == 1:
# Only one type is allowed
dtype = next(iter(param.type_constraint.allowed_types)).dtype
else:
# No dtype information available. Infer from the Python constant
if isinstance(arg, bool):
dtype = ir.DataType.BOOL
elif isinstance(arg, float):
dtype = ir.DataType.FLOAT
elif isinstance(arg, int):
dtype = ir.DataType.INT64
elif isinstance(arg, str):
dtype = ir.DataType.STRING
elif isinstance(arg, (tuple, list)) and all(
isinstance(val, int) for val in arg
):
dtype = ir.DataType.INT64
elif isinstance(arg, (tuple, list)) and any(
isinstance(val, float) for val in arg
):
# NOTE: if any float is present, the dtype is float
dtype = ir.DataType.FLOAT
elif isinstance(arg, (ir.Tensor, ir.TensorProtocol)):
dtype = arg.dtype
elif arg is None:
dtype = ir.DataType.UNDEFINED
else:
raise TypeError(
f"Constant input '{arg}' of type '{type(arg)}' is not supported"
)
if arg is None:
constant_value = None
elif not isinstance(arg, (ir.Tensor, ir.TensorProtocol)):
# Deduplicate the constants
if isinstance(arg, (tuple, list)):
# Make the arg hashable
arg = tuple(arg) # noqa: PLW2901
constant_value = constant_farm.get((arg, dtype)) # type: ignore[arg-type]
if constant_value is None:
constant_tensor = ir.tensor(value=arg, dtype=dtype) # type: ignore[arg-type]
constant_value = opset.Constant(value=constant_tensor)
constant_farm[(arg, dtype)] = constant_value # type: ignore[arg-type,index]
else:
constant_value = opset.Constant(value=arg)
named_inputs[param.name] = constant_value
return named_inputs # type: ignore[return-value]
def _construct_node(
signature: _schemas.OpSignature,
named_inputs: Mapping[str, ir.Value | None],
named_attrs: Mapping[str, ValidAttributeType],
opset: onnxscript.values.Opset,
) -> ir.Node:
"""Construct the node with the inputs and attributes.
Variadic inputs are flattened.
Args:
signature: The OpSignature for the node.
named_inputs: The mapping of parameter names to their arguments. When we
do not have the schema of an operator, we do not know the names of
the inputs, in which case the names can be anything because they
are not used in this function. The data structure is passed in for
consistency with the other functions.
named_attrs: The mapping of attribute names to their values.
"""
inputs: list[Any] = []
# Flatten variadic inputs
for value in named_inputs.values():
if isinstance(value, Sequence):
inputs.extend(value)
else:
inputs.append(value)
# Construct and filter out None attributes
attributes = [
attr
for attr in ir_convenience.convert_attributes(named_attrs)
if attr.value is not None
]
outputs = [_tensors.SymbolicTensor(opset) for _ in signature.outputs]
return ir.Node(
signature.domain,
signature.name,
inputs=inputs,
attributes=attributes,
outputs=outputs,
)
class OpRecorder(evaluator.Evaluator):
"""An onnxscript Evaluator that captures the graph into torchscript."""
def __init__(
self, opset: onnxscript.values.Opset, constant_farm: dict[Any, ir.Value]
):
self.nodes: list[ir.Node] = []
self.opset = opset
self.functions: dict[ir.OperatorIdentifier, onnxscript.OnnxFunction] = {}
self.constant_farm = constant_farm
def _call_op(
self,
op_signature: _schemas.OpSignature,
named_inputs: dict[str, AllowedArgType],
named_attrs: dict[str, ValidAttributeType],
) -> Sequence[_tensors.SymbolicTensor]:
"""Record nodes for the given opschema and arguments.
Args:
op_signature: The OpSchema containing the node signature.
named_inputs: The mapping of parameter names to their arguments.
named_attrs: The mapping of attribute names to their values.
"""
type_binding = _resolve_parameter_dtypes(op_signature, named_inputs)
try:
converted_named_inputs = _process_python_constants_and_sequences(
op_signature, named_inputs, type_binding, self.constant_farm, self.opset
)
except Exception as e:
raise errors.GraphConstructionError(
f"Error processing Python constants for operator '{op_signature.domain}::{op_signature.name}'. "
f"named_inputs={named_inputs}, named_attrs={named_attrs}, opset={self.opset}, op_signature={op_signature}."
) from e
try:
self.nodes.append(
node := _construct_node(
op_signature, converted_named_inputs, named_attrs, self.opset
)
)
except Exception as e:
raise errors.GraphConstructionError(
f"Error constructing node for operator '{op_signature.domain}::{op_signature.name}'. "
f"named_inputs={named_inputs}, converted_named_inputs={converted_named_inputs}, "
f"named_attrs={named_attrs}, opset={self.opset}, op_signature={op_signature}."
) from e
return node.outputs # type: ignore[return-value]
def eval(
self,
schema: onnx.defs.OpSchema,
args: Sequence[AllowedArgType], # type: ignore[override]
kwargs: Mapping[str, AllowedArgType],
) -> _tensors.SymbolicTensor | Sequence[_tensors.SymbolicTensor]:
try:
op_signature = _schemas.OpSignature.from_opschema(schema)
named_inputs, named_attrs = _construct_named_inputs_and_attrs(
op_signature, args, kwargs
)
# TODO(justinchuby): Handle cast
if schema.name == "CastLike":
assert len(named_inputs) == 2
# Skip CastLike if the input and output types are the same
src_input = named_inputs["input"]
target_type = named_inputs["target_type"]
if (
isinstance(src_input, ir.Value)
and isinstance(target_type, ir.Value)
and src_input.dtype is not None
and target_type.dtype is not None
):
# dtypes are available
if src_input.dtype == target_type.dtype:
# Same type. No cast needed
return src_input # type: ignore[return-value]
else:
# Create a Cast node
return self.opset.Cast(src_input, to=target_type.dtype) # type: ignore[union-attr,return-value]
outputs = self._call_op(op_signature, named_inputs, named_attrs)
if len(outputs) == 1:
return outputs[0]
return outputs
except Exception as e:
raise errors.GraphConstructionError(
f"Error calling operator '{schema.name}' with args {args} and kwargs {kwargs}."
) from e
def eval_function( # type: ignore[override]
self,
function: onnxscript.OnnxFunction,
args: Sequence[AllowedArgType],
kwargs: Mapping[str, AllowedArgType],
) -> _tensors.SymbolicTensor | Sequence[_tensors.SymbolicTensor] | bool | int:
try:
# Special cases for handling IsScalar and Rank
if function.name == "IsScalar":
if len(args) != 1:
raise TypeError(
f"Expected 1 positional argument for function '{function}', got {len(args)}."
)
if isinstance(args[0], _tensors.SymbolicTensor):
if args[0].rank is not None:
return args[0].rank == 0
else:
# Fall to call add_function_call
pass
elif isinstance(args[0], Sequence):
return False
else:
# Python constants are scalars
return True
if function.name == "Rank":
if len(args) != 1:
raise TypeError(
f"Expected 1 positional argument for function '{function}', got {len(args)}."
)
if isinstance(args[0], _tensors.SymbolicTensor):
if args[0].rank is not None:
return args[0].rank
else:
# Fall to call add_function_call
pass
elif isinstance(args[0], Sequence):
if all(isinstance(arg, (int, float)) for arg in args[0]):
return 1
else:
# Fall to call add_function_call
pass
else:
# Python constants are scalars
return 0
# NOTE: signature is written to function in the registration process
# TODO: Upstream signature to ONNX Function
if hasattr(function, "signature"):
op_signature = function.signature
else:
op_signature = _schemas.OpSignature.from_function(
function, function.function_ir.domain, function.name
)
named_inputs, named_attrs = _construct_named_inputs_and_attrs(
op_signature, args, kwargs
)
# NOTE: We need to call traceable functions after the _construct_named_inputs_and_attrs
# call because it will filter out the unexpected kwargs for us.
if function.traceable:
# Trace the function call instead of adding the function as a node
return function.function(**named_inputs, **named_attrs)
outputs = self._call_op(op_signature, named_inputs, named_attrs)
self.functions[(function.function_ir.domain, function.name, "")] = function
if len(outputs) == 1:
return outputs[0]
return outputs
except Exception as e:
try:
source_file = inspect.getsourcefile(function.function)
_, lineno = inspect.getsourcelines(function.function)
except Exception:
source_file = lineno = None
raise errors.GraphConstructionError(
f"Error calling function '{function.name}' with args {args} and kwargs {kwargs}."
+ f" The function is defined at '{source_file}:{lineno}'."
if source_file
else ""
) from e

View File

@ -0,0 +1,335 @@
"""Strategies for capturing ExportedPrograms."""
# mypy: allow-untyped-defs
from __future__ import annotations
import abc
import dataclasses
import datetime
import pathlib
from typing import Any, Callable, TYPE_CHECKING
import torch
from torch._export import converter as _torchscript_converter
from torch.utils import _pytree
if TYPE_CHECKING:
import os
def _verbose_printer(verbose: bool | None) -> Callable[..., None]:
"""Prints messages based on `verbose`."""
if verbose is False:
return lambda *_, **__: None
return lambda *args, **kwargs: print("[torch.onnx]", *args, **kwargs)
def _take_first_line(text: str) -> str:
"""Take the first line of a text."""
lines = text.split("\n", maxsplit=1)
first_line = lines[0]
if len(lines) > 1:
first_line += "[...]"
return first_line
@dataclasses.dataclass
class Result:
exported_program: torch.export.ExportedProgram | None
strategy: str
exception: Exception | None = None
@property
def success(self) -> bool:
return self.exported_program is not None
class CaptureStrategy(abc.ABC):
"""Strategy for capturing a module as ExportedProgram.
To use a strategy, create an instance and call it with the model, args, kwargs, and dynamic_shapes.
Example::
strategy = TorchExportStrategy(verbose=True)
result = strategy(model, args, kwargs, dynamic_shapes)
"""
def __init__(
self,
*,
verbose: bool = False,
dump: bool = False,
artifacts_dir: str | os.PathLike = ".",
timestamp: str | None = None,
):
"""Initialize the strategy.
Args:
verbose: Whether to print verbose messages.
dump: Whether to dump the intermediate artifacts to a file.
"""
self._verbose_print = _verbose_printer(verbose)
self._dump = dump
self._artifacts_dir = pathlib.Path(artifacts_dir)
self._timestamp = timestamp or datetime.datetime.now().strftime(
"%Y-%m-%d_%H-%M-%S-%f"
)
def __call__(
self,
model: torch.nn.Module | torch.jit.ScriptFunction,
args: tuple[Any, ...],
kwargs: dict[str, Any] | None,
dynamic_shapes,
) -> Result:
self._enter(model)
if kwargs is None:
kwargs = {}
try:
exported_program = self._capture(model, args, kwargs, dynamic_shapes)
except Exception as e:
self._failure(model, e)
return Result(
exported_program=None,
strategy=self.__class__.__name__,
exception=e,
)
self._success(model)
return Result(exported_program, strategy=self.__class__.__name__)
@abc.abstractmethod
def _capture(
self, model, args, kwargs, dynamic_shapes
) -> torch.export.ExportedProgram:
raise NotImplementedError
def _enter(self, model: torch.nn.Module | torch.jit.ScriptFunction) -> None:
return
def _success(self, model: torch.nn.Module | torch.jit.ScriptFunction) -> None:
return
def _failure(
self, model: torch.nn.Module | torch.jit.ScriptFunction, e: Exception
) -> None:
return
class TorchExportStrategy(CaptureStrategy):
def _capture(
self, model, args, kwargs, dynamic_shapes
) -> torch.export.ExportedProgram:
return torch.export.export(
model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes
)
def _enter(self, model) -> None:
model_repr = _take_first_line(repr(model))
self._verbose_print(
f"Obtain model graph for `{model_repr}` with `torch.export.export`..."
)
def _success(self, model) -> None:
model_repr = _take_first_line(repr(model))
self._verbose_print(
f"Obtain model graph for `{model_repr}` with `torch.export.export`... ✅"
)
def _failure(self, model, e) -> None:
del e # Unused
model_repr = _take_first_line(repr(model))
self._verbose_print(
f"Obtain model graph for `{model_repr}` with `torch.export.export`... ❌"
)
class TorchExportNonStrictStrategy(CaptureStrategy):
def _capture(
self, model, args, kwargs, dynamic_shapes
) -> torch.export.ExportedProgram:
return torch.export.export(
model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes, strict=False
)
def _enter(self, model) -> None:
model_repr = _take_first_line(repr(model))
self._verbose_print(
f"Obtain model graph for `{model_repr}` with `torch.export.export(..., strict=False)`..."
)
def _success(self, model) -> None:
model_repr = _take_first_line(repr(model))
self._verbose_print(
f"Obtain model graph for `{model_repr}` with `torch.export.export(..., strict=False)`... ✅"
)
def _failure(self, model, e) -> None:
del e # Unused
model_repr = _take_first_line(repr(model))
self._verbose_print(
f"Obtain model graph for `{model_repr}` with `torch.export.export(..., strict=False)`... ❌"
)
class JitTraceConvertStrategy(CaptureStrategy):
def _capture(
self, model, args, kwargs, dynamic_shapes
) -> torch.export.ExportedProgram:
del dynamic_shapes # Unused
flattened_args, spec = _pytree.tree_flatten((args, kwargs))
flattened_args = tuple(flattened_args)
# Since torch.jit.trace only accepts Tensors as inputs, we filter
# out non-Tensor arguments and reconstruct the arguments after entering
# the WrappedModel.
tensor_placeholder = object()
non_tensor_args = [
arg if not isinstance(arg, torch.Tensor) else tensor_placeholder
for arg in flattened_args
]
tensor_args = tuple(
arg for arg in flattened_args if isinstance(arg, torch.Tensor)
)
class WrappedModel(torch.nn.Module):
"""Wrap the model so that it takes flattened arguments."""
def __init__(self, m):
super().__init__()
self.model = m
def forward(self, *_args):
# Take the non-Tensor arguments list as a starting point and
# replace the tensor_placeholder with the actual tensor arguments
# from _args.
reconstructed_flattened_args = non_tensor_args.copy()
_args_iter = iter(_args)
for i, arg in enumerate(reconstructed_flattened_args):
if arg is tensor_placeholder:
reconstructed_flattened_args[i] = next(_args_iter)
# Unflatten the arguments and kwargs to pass to the model.
unflattened_args, unflattened_kwargs = _pytree.tree_unflatten(
reconstructed_flattened_args, spec
)
results = self.model(*unflattened_args, **unflattened_kwargs)
if not isinstance(results, tuple):
results = (results,)
flattened_results, _ = _pytree.tree_flatten(results)
if len(flattened_results) == 1:
return flattened_results[0]
return tuple(flattened_results)
jit_model = torch.jit.trace(
WrappedModel(model),
example_inputs=tensor_args,
check_trace=False,
strict=False,
)
if self._dump:
program_path = self._artifacts_dir / f"onnx_export_{self._timestamp}.pt"
try:
torch.jit.save(jit_model, program_path)
except Exception as e:
self._verbose_print(
f"Failed to save Torch Script model due to an error: {e}"
)
else:
self._verbose_print(
f"Torch Script model has been saved to '{program_path}'."
)
return _torchscript_converter.TS2EPConverter(
jit_model, flattened_args
).convert()
def _enter(self, model) -> None:
model_repr = _take_first_line(repr(model))
self._verbose_print(
f"Obtain model graph for `{model_repr}` with Torch Script..."
)
def _success(self, model) -> None:
model_repr = _take_first_line(repr(model))
self._verbose_print(
f"Obtain model graph for `{model_repr}` with Torch Script... ✅"
)
def _failure(self, model, e) -> None:
del e # Unused
model_repr = _take_first_line(repr(model))
self._verbose_print(
f"Obtain model graph for `{model_repr}` with Torch Script... ❌"
)
class LegacyDynamoStrategy(CaptureStrategy):
"""Strategy implemented by the ONNX team using internal dynamo APIs and custom fx passes."""
def _capture(
self, model, args, kwargs, dynamic_shapes
) -> torch.export.ExportedProgram:
# NOTE: Import here to prevent circular dependency
from torch.onnx._internal.fx import diagnostics, passes
graph_module, _ = torch._dynamo.export(
model,
tracing_mode="symbolic",
dynamic_shapes=dynamic_shapes,
)(
*args,
**kwargs,
)
torch._dynamo.reset()
diagnostic_context = diagnostics.DiagnosticContext(
"torch.onnx.export",
torch.__version__,
)
flattened_args, _ = _pytree.tree_flatten((args, kwargs))
flattened_args = tuple(flattened_args)
# ONNX does not support views and mutations.
# Functionalize to get a semantically equivalent graph without mutations.
graph_module = passes.Functionalize(
diagnostic_context,
graph_module,
enable_dynamic_axes=bool(dynamic_shapes),
).run(*flattened_args)
# Input mutations are detected and distilled after `Functionalize` pass.
# Remove them since ONNX inference does not need them.
graph_module = passes.RemoveInputMutation(diagnostic_context, graph_module).run(
*flattened_args
)
# Use torch.export to recapture the GraphModule into an ExportedProgram.
return torch.export.export(graph_module, flattened_args)
def _enter(self, model) -> None:
model_repr = _take_first_line(repr(model))
self._verbose_print(
f"Obtain model graph for `{model_repr}` with internal Dynamo APIs..."
)
def _success(self, model) -> None:
model_repr = _take_first_line(repr(model))
self._verbose_print(
f"Obtain model graph for `{model_repr}` with internal Dynamo APIs... ✅"
)
def _failure(self, model, e) -> None:
del e # Unused
model_repr = _take_first_line(repr(model))
self._verbose_print(
f"Obtain model graph for `{model_repr}` with internal Dynamo APIs... ❌"
)
CAPTURE_STRATEGIES = (
TorchExportStrategy,
TorchExportNonStrictStrategy,
JitTraceConvertStrategy,
LegacyDynamoStrategy,
)

View File

@ -0,0 +1,225 @@
"""Compatibility functions for the torch.onnx.export API."""
# mypy: allow-untyped-defs
# mypy: disable-error-code=attr-defined
from __future__ import annotations
import inspect
import logging
from typing import Any, Mapping, Sequence, TYPE_CHECKING
import onnx
import torch
import torch.export
from torch.onnx._internal.exporter import _core, _onnx_program
if TYPE_CHECKING:
import os
logger = logging.getLogger(__name__)
def _signature(model) -> inspect.Signature:
should_be_callable = getattr(model, "forward", model)
if callable(should_be_callable):
return inspect.signature(should_be_callable)
raise ValueError("model has no forward method and is not callable")
def _from_dynamic_axes_to_dynamic_shapes(
model,
dynamic_axes=None,
input_names: Sequence[str] | None = None,
) -> dict[str, Any] | None:
"""
dynamic_axes examples:
(1) dynamic_axes = {"x": {0: "my_custom_axis_name_1"}, "y": {1: "my_custom_axis_name_2"}}
(2) dynamic_axes = {"x": [0], "y": [1]}
these will be converted to dynamic_shapes respectively:
(1) dynamic_shapes = {"x": {0: Dim("my_custom_axis_name_1")}, "y": {1: Dim("my_custom_axis_name_2")}}
(2) dynamic_shapes = {"x": {0: Dim("x_dim_0")}, "y": {1: Dim("y_dim_1")}} # auto-generated dim names
"""
# https://github.com/pytorch/pytorch/pull/128371
# 1. The function does not need to provide dynamic_shapes to torch.export.export
if dynamic_axes is None:
return None
if input_names is None:
input_names = []
sig = _signature(model)
if len(input_names) > len(sig.parameters):
raise ValueError(
f"Number of input names ({len(input_names)}) should not be greater than "
f"the number of model inputs ({len(sig.parameters)})"
)
input_names_to_model_inputs = {}
for idx, param_name in enumerate(sig.parameters):
if idx < len(input_names):
input_names_to_model_inputs[input_names[idx]] = param_name
else:
input_names_to_model_inputs[param_name] = param_name
# NOTE: torch.export.export does not support input names assignment,
# so we need to map input names to model inputs to create dynamic_shapes
# for the exported program
dynamic_shapes_to_exported_program = {}
for input_name, axes in dynamic_axes.items():
# input_name can be either from input_names or from the model inputs
if input_name not in input_names_to_model_inputs:
raise ValueError(
f"dynamic axis: {input_name} is not found in the input names: {input_names}"
)
model_input_name = input_names_to_model_inputs[input_name]
if isinstance(axes, dict):
dynamic_shapes_to_exported_program[model_input_name] = {
k: torch.export.Dim(v) for k, v in axes.items()
}
elif isinstance(axes, list):
dynamic_shapes_to_exported_program[model_input_name] = {
k: torch.export.Dim(f"{model_input_name}_dim_{k}") for k in axes
}
else:
raise TypeError(
f"dynamic_axes value must be either a dict or a list, but got {type(axes)}"
)
# torch.export.export needs static dims to be present in dynamic_shapes
# for all input tensors, so we add them as None
for input_name in sig.parameters:
if input_name not in dynamic_shapes_to_exported_program:
dynamic_shapes_to_exported_program[input_name] = None # type: ignore[assignment]
return dynamic_shapes_to_exported_program
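# A small illustrative sketch of the conversion above (assumption: a model whose
# forward signature is `forward(self, x, y)`):
#   dynamic_axes = {"x": {0: "batch"}, "y": [1]}
#   _from_dynamic_axes_to_dynamic_shapes(model, dynamic_axes, input_names=None)
#   # -> {"x": {0: Dim("batch")}, "y": {1: Dim("y_dim_1")}}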
def _get_torch_export_args(
args: tuple[Any, ...],
kwargs: dict[str, Any] | None,
) -> tuple[tuple[Any, ...], dict[str, Any] | None]:
"""Obtain the arguments for torch.onnx.export from the model and the input arguments."""
if not kwargs and args and isinstance(args[-1], dict):
kwargs = args[-1]
args = args[:-1]
return args, kwargs
def _convert_version(path: str | os.PathLike, opset_version: int) -> None:
"""Convert the ONNX file to a specific version."""
model = onnx.load(path, load_external_data=False)
model = onnx.version_converter.convert_version(model, opset_version)
onnx.save(model, path)
def export_compat(
model: torch.nn.Module
| torch.export.ExportedProgram
| torch.jit.ScriptModule
| torch.jit.ScriptFunction,
args: tuple[Any, ...],
f: str | os.PathLike | None = None,
*,
kwargs: dict[str, Any] | None = None,
export_params: bool = True,
verbose: bool | None = None,
input_names: Sequence[str] | None = None,
output_names: Sequence[str] | None = None,
opset_version: int | None = None,
dynamic_axes: Mapping[str, Mapping[int, str]]
| Mapping[str, Sequence[int]]
| None = None,
dynamic_shapes: dict[str, Any] | tuple[Any, ...] | list[Any] | None = None,
keep_initializers_as_inputs: bool = False,
external_data: bool = True,
report: bool = False,
verify: bool = False,
profile: bool = False,
dump_exported_program: bool = False,
artifacts_dir: str | os.PathLike = ".",
fallback: bool = False,
**_,
) -> _onnx_program.ONNXProgram | None:
if isinstance(model, torch.export.ExportedProgram):
# When the model is already an ExportedProgram, the args, kwargs, and dynamic_shapes
# are not used
dynamic_shapes = dynamic_shapes or {}
else:
args, kwargs = _get_torch_export_args(args, kwargs)
if dynamic_shapes is None and dynamic_axes is not None:
dynamic_shapes = _from_dynamic_axes_to_dynamic_shapes(
model, dynamic_axes, input_names
)
should_convert_version = False
try:
onnx_program = _core.export(
model,
args,
kwargs,
registry=None,
dynamic_shapes=dynamic_shapes,
input_names=input_names,
output_names=output_names,
profile=profile,
report=report,
verify=verify,
dump_exported_program=dump_exported_program,
artifacts_dir=artifacts_dir,
verbose=verbose,
)
if f is not None:
# Save the initializers as external data when requested to reduce the size of the ONNX file
onnx_program.save(
f,
include_initializers=export_params,
keep_initializers_as_inputs=keep_initializers_as_inputs,
external_data=external_data,
)
if (
opset_version is not None
and opset_version != onnx_program.model.opset_imports.get("")
):
should_convert_version = True
except Exception as e:
if fallback:
if verbose is not False:
print(
"[torch.onnx] Falling back to legacy torch.onnx.export due "
f"to the following error: {e}",
)
torch.onnx.utils.export(
model, # type: ignore[arg-type]
args,
f, # type: ignore[arg-type]
kwargs=kwargs,
export_params=export_params,
input_names=input_names,
output_names=output_names,
opset_version=17, # TODO(justinchuby): Hard coded to 17 for now
dynamic_axes=dynamic_axes,
keep_initializers_as_inputs=keep_initializers_as_inputs,
)
onnx_program = None
if opset_version is None:
opset_version = 18
if opset_version != 17:
should_convert_version = True
else:
raise
if f is not None and should_convert_version:
assert opset_version is not None
if verbose is not False:
print(
f"[torch.onnx] Converting the ONNX file to opset version {opset_version}..."
)
_convert_version(f, opset_version)
return onnx_program

File diff suppressed because it is too large

View File

@ -0,0 +1,74 @@
"""Build decomp table from PyTorch."""
# mypy: allow-untyped-defs
from __future__ import annotations
from typing import Callable, TYPE_CHECKING
import torch
import torch._ops
if TYPE_CHECKING:
from torch.onnx._internal.exporter import _registration
def get_onnx_implemented_overloads(
registry: _registration.ONNXRegistry,
) -> list[torch._ops.OperatorBase]:
"""
Creates a list of OperatorBase objects that represent ONNX-supported PyTorch operations.
Args:
registry: The ONNX registry for PyTorch.
Returns:
A list of OperatorBase objects representing ONNX-supported PyTorch operations.
"""
registered_ops: list[torch._ops.OperatorBase] = []
for op_namespace in (torch.ops.aten, torch.ops.prims):
op_names = dir(op_namespace)
for op_name in op_names:
op_overload_packet = getattr(op_namespace, op_name)
if not isinstance(op_overload_packet, torch._ops.OpOverloadPacket):
continue
for overload_name in op_overload_packet.overloads():
op_overload = getattr(op_overload_packet, overload_name)
if registry.is_registered(op_overload):
registered_ops.append(op_overload)
return registered_ops
def create_onnx_friendly_decomposition_table(
registry,
) -> dict[torch._ops.OperatorBase, Callable]:
"""
This function creates a dictionary of op overloads and their decomposition functions
for ops that do not have ONNX symbolic functions. If an op already has an ONNX symbolic function,
its decomposition function is excluded from the table. The decomposition table is a subset of PyTorch's
built-in aten-to-aten decomposition.
Args:
registry: The ONNX registry for PyTorch.
Returns:
Dict[torch._ops.OperatorBase, Callable]: A dictionary that maps op overloads to their corresponding
decomposition functions.
"""
decomposition_table: dict[torch._ops.OperatorBase, Callable] = {}
onnx_registered_ops = set(get_onnx_implemented_overloads(registry))
# NOTE: If we import torch._decomp, we will get RuntimeError: Only a single
# TORCH_LIBRARY can be used to register the namespace nvprims; please put all of your
# definitions in a single TORCH_LIBRARY block.
for op_overload, decomp_fn in torch._decomp.decomposition_table.items(): # type: ignore[attr-defined]
# Skip decomposition for op_overload as long as that op_overload has a corresponding ONNX
# symbolic function.
# NOTE: Do not skip torch._refs decomps. Keeping them is fine because without them
# the model is not exportable anyway.
if op_overload in onnx_registered_ops:
continue
decomposition_table[op_overload] = decomp_fn
return decomposition_table
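The two helpers above are typically combined as follows; this is only a sketch that assumes an already captured `exported_program` and uses the registry defined elsewhere in this PR:

from torch.onnx._internal.exporter import _decomp, _registration

# `exported_program` is assumed to come from torch.export.export(...)
registry = _registration.ONNXRegistry.from_torchlib()
decomp_table = _decomp.create_onnx_friendly_decomposition_table(registry)
# Ops with ONNX implementations are preserved; everything else is decomposed
# into aten ops that the registry does know how to translate.
decomposed_program = exported_program.run_decompositions(decomp_table)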

View File

@ -0,0 +1,345 @@
# mypy: allow-untyped-defs
from __future__ import annotations
import logging
from typing import Sequence
import onnxscript
from onnxscript import ir
import torch
import torch.fx
from torch.onnx._internal.exporter import _registration, _schemas
logger = logging.getLogger(__name__)
# Define utilities to convert PyTorch data types so users do not need to specify manually
_TORCH_DTYPE_TO_ONNX_COMPATIBLE: dict[torch.dtype, ir.DataType] = {
torch.bfloat16: ir.DataType.BFLOAT16,
torch.bool: ir.DataType.BOOL,
torch.complex128: ir.DataType.DOUBLE,
torch.complex64: ir.DataType.FLOAT,
torch.float16: ir.DataType.FLOAT16,
torch.float32: ir.DataType.FLOAT,
torch.float64: ir.DataType.DOUBLE,
torch.float8_e4m3fn: ir.DataType.FLOAT8E4M3FN,
torch.float8_e4m3fnuz: ir.DataType.FLOAT8E4M3FNUZ,
torch.float8_e5m2: ir.DataType.FLOAT8E5M2,
torch.float8_e5m2fnuz: ir.DataType.FLOAT8E5M2FNUZ,
torch.int16: ir.DataType.INT16,
torch.int32: ir.DataType.INT32,
torch.int64: ir.DataType.INT64,
torch.int8: ir.DataType.INT8,
torch.uint8: ir.DataType.UINT8,
}
def _torch_dtype_to_onnx_compatible_dtype(dtype: torch.dtype) -> ir.DataType:
return _TORCH_DTYPE_TO_ONNX_COMPATIBLE[dtype]
def _attribute_type_compatible_with_arg(
attr: _schemas.AttributeParameter,
value: ir.Value | int | float | bool | Sequence[int] | Sequence[float] | None,
) -> bool:
"""Check if the attribute type is compatible with the argument."""
if isinstance(value, bool):
return attr.type is ir.AttributeType.INT
if isinstance(value, str):
return attr.type is ir.AttributeType.STRING
if isinstance(value, int):
return attr.type in {ir.AttributeType.INT, ir.AttributeType.FLOAT}
if isinstance(value, float):
return attr.type is ir.AttributeType.FLOAT
if isinstance(value, complex):
return False
if isinstance(value, Sequence):
if attr.type is ir.AttributeType.INTS:
return all(isinstance(i, int) for i in value)
if attr.type is ir.AttributeType.FLOATS:
return all(isinstance(i, (int, float)) for i in value)
if isinstance(value, torch.dtype):
return attr.type is ir.AttributeType.INT
if isinstance(value, (torch.device, torch.memory_format, torch.layout)):
return attr.type is ir.AttributeType.STRING
if value is None and not attr.required:
# An optional attribute is not supplied
return True
return False
def _param_type_compatible_with_arg(
param: _schemas.Parameter,
value: ir.TypeProtocol
| str
| int
| float
| complex
| Sequence[int]
| Sequence[float]
| None,
assigned_types: dict[str, ir.TypeProtocol],
) -> bool:
# Handle Python types first
if isinstance(value, bool): # noqa: SIM102
if param.type_constraint.allowed_types & {ir.TensorType(ir.DataType.BOOL)}:
return True
if isinstance(value, int) and param.type_constraint.allowed_types & {
ir.TensorType(ir.DataType.INT4),
ir.TensorType(ir.DataType.INT8),
ir.TensorType(ir.DataType.INT16),
ir.TensorType(ir.DataType.INT32),
ir.TensorType(ir.DataType.INT64),
# Int inputs can be cast to a float too
ir.TensorType(ir.DataType.FLOAT8E4M3FN),
ir.TensorType(ir.DataType.FLOAT8E4M3FNUZ),
ir.TensorType(ir.DataType.FLOAT8E5M2),
ir.TensorType(ir.DataType.FLOAT8E5M2FNUZ),
ir.TensorType(ir.DataType.FLOAT16),
ir.TensorType(ir.DataType.FLOAT),
ir.TensorType(ir.DataType.DOUBLE),
}:
return True
if isinstance(value, float) and param.type_constraint.allowed_types & {
ir.TensorType(ir.DataType.FLOAT8E4M3FN),
ir.TensorType(ir.DataType.FLOAT8E4M3FNUZ),
ir.TensorType(ir.DataType.FLOAT8E5M2),
ir.TensorType(ir.DataType.FLOAT8E5M2FNUZ),
ir.TensorType(ir.DataType.FLOAT16),
ir.TensorType(ir.DataType.FLOAT),
ir.TensorType(ir.DataType.DOUBLE),
}:
return True
if isinstance(value, complex) and param.type_constraint.allowed_types & {
ir.TensorType(ir.DataType.FLOAT),
ir.TensorType(ir.DataType.DOUBLE),
ir.TensorType(ir.DataType.COMPLEX64),
ir.TensorType(ir.DataType.COMPLEX128),
}:
return True
if isinstance(value, str): # noqa: SIM102
if param.type_constraint.allowed_types & {ir.TensorType(ir.DataType.STRING)}:
return True
if isinstance(value, (list, tuple)):
if param.type_constraint.allowed_types & {
ir.TensorType(ir.DataType.INT32),
ir.TensorType(ir.DataType.INT64),
ir.TensorType(ir.DataType.FLOAT),
ir.TensorType(ir.DataType.DOUBLE),
ir.SequenceType(ir.TensorType(ir.DataType.INT32)),
ir.SequenceType(ir.TensorType(ir.DataType.INT64)),
ir.SequenceType(ir.TensorType(ir.DataType.FLOAT)),
ir.SequenceType(ir.TensorType(ir.DataType.DOUBLE)),
} and all(isinstance(i, (int)) for i in value):
# We will just allow any fx node and trust that the overload handles it
return True
if param.type_constraint.allowed_types & {
ir.TensorType(ir.DataType.FLOAT),
ir.TensorType(ir.DataType.DOUBLE),
ir.SequenceType(ir.TensorType(ir.DataType.FLOAT)),
ir.SequenceType(ir.TensorType(ir.DataType.DOUBLE)),
} and all(isinstance(i, (int, float)) for i in value):
# We will just allow any fx node and trust that the overload handles it
return True
if value is None and not param.required:
# An optional parameter is not supplied
return True
if not isinstance(value, ir.TypeProtocol):
return False
# Then check tensor types
if param.type_constraint.name in assigned_types:
# If a typevar is already bound, check if the value has the same type
assigned_type = assigned_types[param.type_constraint.name]
return assigned_type == value
# If the typevar is not bound, bind it to the value type
if value in param.type_constraint.allowed_types:
# TODO: Maybe just check dtype? Being more strict here for now
assigned_types[param.type_constraint.name] = value
return True
return False
def _get_type_from_tensor(
tensor: torch.Tensor | Sequence[torch.Tensor],
) -> ir.TypeProtocol:
if isinstance(tensor, torch.Tensor):
return ir.TensorType(_torch_dtype_to_onnx_compatible_dtype(tensor.dtype))
first_tensor = next((item for item in tensor if item is not None), None)
if first_tensor is None:
return ir.SequenceType(ir.TensorType(ir.DataType.UNDEFINED))
return ir.SequenceType(
ir.TensorType(_torch_dtype_to_onnx_compatible_dtype(first_tensor.dtype))
)
def _get_first_tensor_in_node_list(
nodes: Sequence[torch.fx.Node | None],
) -> torch.Tensor | None:
for node in nodes:
if (
node is not None
and "val" in node.meta
and isinstance(node.meta["val"], torch.Tensor)
):
return node.meta["val"]
return None
def _get_named_fx_node_args(node: torch.fx.Node) -> dict[str, torch.fx.node.Argument]:
# FIXME: node.target may not have a schema
torch_schema: torch.FunctionSchema = node.target._schema # type: ignore[union-attr]
node_args = {}
for arg, schema_arg in zip(node.args, torch_schema.arguments):
node_args[schema_arg.name] = arg
node_args.update(node.kwargs)
return node_args
def get_matching_overload(
node: torch.fx.Node,
overloads: Sequence[onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction],
) -> tuple[onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction | None, str]:
"""Get the overload that matches the node's arguments.
Args:
node: The node to match.
overloads: The overloads to match against.
Returns:
A tuple containing the matched overload and a string describing the reason for failure or success.
"""
named_args = _get_named_fx_node_args(node)
# FIXME: node.target may be a builtin and not have a schema
# FIXME: Handle when we don't know the names of the arguments
schema_args: dict[str, torch.Argument] = {
arg.name: arg
for arg in node.target._schema.arguments # type: ignore[union-attr]
}
failure_messages: list[str] = []
for overload in overloads:
assigned_types: dict[str, ir.TypeProtocol] = {}
fail_reason = ""
if not hasattr(overload, "signature"):
# When an overload does not have a signature, we assume it is a custom op and should be matched
return (
overload,
"The overload does not have a signature. Assuming it is a custom op and matching it.",
)
for param in overload.signature:
if param.name not in schema_args and param.required:
# We don't need to handle variadic inputs as there is none.
# A required parameter is not supplied.
fail_reason = "Required parameter not supplied"
break
# Get the argument
if param.name in named_args:
# Provided in Node args
arg = named_args[param.name]
elif (
param.name in schema_args
and schema_args[param.name].has_default_value()
):
# Provided in schema args
arg = schema_args[param.name].default_value
elif param.has_default():
# Provided in the ONNX op definition
arg = param.default
else:
fail_reason = "Parameter not provided"
break
if isinstance(param, _schemas.Parameter):
if isinstance(arg, torch.Tensor):
arg = _get_type_from_tensor(arg) # type: ignore[assignment]
if isinstance(arg, (list, tuple)) and any(
isinstance(t, torch.fx.Node) for t in arg
):
first_tensor = _get_first_tensor_in_node_list(arg)
assert first_tensor is not None
# FIXME: Handle symfloat here
arg = ir.SequenceType(_get_type_from_tensor(first_tensor)) # type: ignore[assignment]
elif isinstance(arg, torch.fx.Node):
meta_val = arg.meta["val"]
arg = _get_type_from_tensor(meta_val) # type: ignore[assignment]
# TODO: Handle None attributes
# FIXME: Handle symfloat etc.
# Handle tensors and Python values
if not _param_type_compatible_with_arg(param, arg, assigned_types): # type: ignore[arg-type]
fail_reason = (
f"Parameter type not compatible with argument: param=`{param}`, "
f"assigned_types=`{assigned_types}`, arg=`{arg}`"
)
break
elif isinstance(param, _schemas.AttributeParameter):
if not _attribute_type_compatible_with_arg(param, arg): # type: ignore[arg-type]
fail_reason = f"Attribute type not compatible with argument: param=`{param}`, arg=`{arg}`"
break
if not fail_reason:
return overload, "Successfully matched overload"
else:
failure_messages.append(
f"- Failed to match overload `{overload}`: {fail_reason}"
)
return (
None,
f"No overload matched the node `{node.format_node()}`.\n"
+ "\n".join(failure_messages),
)
def _arg_has_complex_dtype(arg) -> bool:
"""Check if the node has complex dtype recursively."""
if (
isinstance(arg, torch.fx.Node)
and "val" in arg.meta
and isinstance(arg.meta["val"], torch.Tensor)
and torch.is_complex(arg.meta["val"])
):
return True
elif isinstance(arg, list):
return any(_arg_has_complex_dtype(item) for item in arg)
return False
def dispatch(
node: torch.fx.Node, registry: _registration.ONNXRegistry
) -> tuple[onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction | None, str]:
"""Dispatch a node to an ONNX function based on the node's target and the ONNX registry.
Args:
node: The node to dispatch.
registry: The ONNX registry to use for dispatching.
Returns:
A tuple containing the matched ONNX function and a string describing the reason for failure or success.
"""
# TODO: Handle when node does not have a target
decomp_metas = registry.get_decomps(node.target) # type: ignore[arg-type]
# Determine if the node has complex inputs.
is_complex = any(_arg_has_complex_dtype(arg) for arg in node.args) or any(
_arg_has_complex_dtype(arg) for arg in node.kwargs.values()
)
if is_complex:
decomp_metas = [decomp for decomp in decomp_metas if decomp.is_complex]
if not decomp_metas:
return None, "No decompositions registered for the complex-valued input"
else:
decomp_metas = [decomp for decomp in decomp_metas if not decomp.is_complex]
if not decomp_metas:
return None, "No decompositions registered for the real-valued input"
if len(decomp_metas) == 1:
return (
decomp_metas[0].onnx_function,
"Fast path: Only one decomposition is defined",
)
overload, message = get_matching_overload(
node, [decomp.onnx_function for decomp in decomp_metas]
)
return overload, message
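As a rough usage sketch (assuming an `exported_program` captured earlier), `dispatch` is called once per `call_function` node to pick the ONNX function used for translation:

from torch.onnx._internal.exporter import _dispatching, _registration

# `exported_program` is assumed to come from torch.export.export(...)
registry = _registration.ONNXRegistry.from_torchlib()
for node in exported_program.graph.nodes:
    if node.op != "call_function":
        continue
    onnx_function, message = _dispatching.dispatch(node, registry)
    if onnx_function is None:
        print(f"No ONNX function for {node.target}: {message}")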

View File

@ -0,0 +1,72 @@
# mypy: allow-untyped-defs
from __future__ import annotations
import torch
import torch.export
import torch.fx
from torch.onnx._internal.exporter import _decomp, _registration
from torch.onnx._internal.fx import diagnostics, passes
_ATEN_ASSERTION_TARGETS = frozenset(
{
torch.ops.aten.sym_constrain_range_for_size.default,
torch.ops.aten._assert_async.msg,
}
)
def decompose_with_registry(
exported_program: torch.export.ExportedProgram, registry: _registration.ONNXRegistry
) -> torch.export.ExportedProgram:
"""Decompose the exported program with the given registry.
This function is needed so it shows clearly on the profiler results.
"""
decomp_table = _decomp.create_onnx_friendly_decomposition_table(registry)
onnx_registered_ops = set(_decomp.get_onnx_implemented_overloads(registry))
# Try to preserve some known CompositeImplicitAutograd ops
aten = torch.ops.aten
to_preserve = {
aten._upsample_bilinear2d_aa.default,
aten._upsample_nearest_exact1d.vec,
aten._upsample_nearest_exact2d.vec,
aten._upsample_nearest_exact3d.vec,
aten.group_norm.default,
aten.linear.default,
aten.upsample_bilinear2d.default,
aten.upsample_bilinear2d.vec,
aten.upsample_linear1d.default,
aten.upsample_linear1d.vec,
aten.upsample_nearest1d.default,
aten.upsample_nearest1d.vec,
aten.upsample_nearest2d.default,
aten.upsample_nearest2d.vec,
aten.upsample_nearest3d.default,
aten.upsample_nearest3d.vec,
aten.upsample_trilinear3d.default,
aten.upsample_trilinear3d.vec,
}
# We can only preserve implemented ops
can_preserve = tuple(to_preserve.intersection(onnx_registered_ops))
return exported_program.run_decompositions(decomp_table, _preserve_ops=can_preserve)
def insert_type_promotion_nodes(
graph_module: torch.fx.GraphModule,
) -> torch.fx.GraphModule:
"""Inplace pass to insert explicit type promotion nodes."""
diagnostic_context = diagnostics.DiagnosticContext(
"torch.onnx.export",
torch.__version__,
)
return passes.InsertTypePromotion(diagnostic_context, graph_module).run()
def remove_assertion_nodes(graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
"""Remove all assertion and check nodes from the FX graph"""
for node in graph_module.graph.nodes:
if node.op == "call_function" and node.target in _ATEN_ASSERTION_TARGETS:
graph_module.graph.erase_node(node)
graph_module.recompile()
return graph_module
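A hedged sketch of how these passes might be chained; the exact ordering used by the exporter lives in `_core`, so this only illustrates the individual APIs, assuming an existing `exported_program`:

from torch.onnx._internal.exporter import _fx_passes, _registration

# `exported_program` is assumed to come from torch.export.export(...)
registry = _registration.ONNXRegistry.from_torchlib()
program = _fx_passes.decompose_with_registry(exported_program, registry)
graph_module = _fx_passes.remove_assertion_nodes(program.graph_module)
graph_module = _fx_passes.insert_type_promotion_nodes(graph_module)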

View File

@ -0,0 +1,41 @@
# mypy: allow-untyped-defs
from __future__ import annotations
import logging
from typing import Sequence
from onnxscript import ir
logger = logging.getLogger(__name__)
def rename_inputs(model: ir.Model, new_names: Sequence[str]) -> None:
# TODO: Ensure the names do not have duplicates
for input, new_name in zip(model.graph.inputs, new_names):
input.metadata_props["pkg.torch.onnx.original_node_name"] = str(input.name)
input.name = new_name
def rename_outputs(model: ir.Model, new_names: Sequence[str]) -> None:
for output, new_name in zip(model.graph.outputs, new_names):
output.metadata_props["pkg.torch.onnx.original_node_name"] = str(output.name)
output.name = new_name
def add_torchlib_common_imports(model: ir.Model) -> None:
"""Hack to add torchlib common imports to the model."""
try:
# TODO(justinchuby): Remove this hack and improve onnxscript
from onnxscript.function_libs.torch_lib.ops import common as common_ops
model.opset_imports["pkg.onnxscript.torch_lib.common"] = 1
rank_func = ir.serde.deserialize_function(common_ops.Rank.to_function_proto())
is_scalar_func = ir.serde.deserialize_function(
common_ops.IsScalar.to_function_proto()
)
model.functions[rank_func.identifier()] = rank_func
model.functions[is_scalar_func.identifier()] = is_scalar_func
except Exception:
logger.exception("Failed to add torchlib common imports to the model.")
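A minimal sketch of the passes in this file, assuming `model` is the onnxscript `ir.Model` held by an `ONNXProgram` (e.g. `onnx_program.model`) and that the chosen names are purely illustrative:

from torch.onnx._internal.exporter import _ir_passes

# `model` is assumed to be an onnxscript ir.Model; the names are placeholders
_ir_passes.rename_inputs(model, ["input_ids", "attention_mask"])
_ir_passes.rename_outputs(model, ["logits"])
_ir_passes.add_torchlib_common_imports(model)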

View File

@ -0,0 +1,55 @@
"""Isolated calls to methods that may segfault."""
# mypy: allow-untyped-defs
from __future__ import annotations
import multiprocessing
import os
import warnings
from typing import Callable
_IS_WINDOWS = os.name == "nt"
def _call_function_and_return_exception(func, args, kwargs):
"""Call a function and return an exception if there is one."""
try:
return func(*args, **kwargs)
except Exception as e:
return e
def safe_call(func: Callable, *args, **kwargs):
"""Call a function in a separate process.
Args:
func: The function to call.
args: The positional arguments to pass to the function.
kwargs: The keyword arguments to pass to the function.
Returns:
The return value of the function.
Raises:
Exception: If the function raised an exception.
"""
if _IS_WINDOWS:
# On Windows, we cannot create a new process with fork.
warnings.warn(
f"A new process is not created for {func} on Windows.", stacklevel=1
)
return func(*args, **kwargs)
with multiprocessing.get_context("fork").Pool(1) as pool:
# It is important to fork a process here to prevent the main logic from
# running again when the user does not place it under an `if __name__ == "__main__":`
# block.
result = pool.apply_async(
_call_function_and_return_exception, (func, args, kwargs)
)
result = result.get(timeout=5)
if isinstance(result, Exception):
raise result
return result
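For illustration, a hypothetical caller that runs the crash-prone ONNX checker through `safe_call` so a native crash cannot take down the exporting process; the model path is an assumption:

import onnx

from torch.onnx._internal.exporter import _isolated

# "model.onnx" is a placeholder path for an already saved model
_isolated.safe_call(onnx.checker.check_model, "model.onnx", full_check=True)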

View File

@ -0,0 +1,288 @@
# mypy: allow-untyped-defs
# mypy: disable-error-code="attr-defined,name-defined"
from __future__ import annotations
__all__ = ["ONNXProgram"]
import gc
import logging
import os
import pathlib
import tempfile
import textwrap
from typing import Callable, IO, Sequence, TYPE_CHECKING
import torch
from torch.onnx._internal import _lazy_import
from torch.utils import _pytree as pytree
onnx = _lazy_import.onnx
ir = _lazy_import.onnxscript_ir
if TYPE_CHECKING:
import onnxruntime as ort
logger = logging.getLogger(__name__)
def _ort_session_initializer(model: str | bytes) -> ort.InferenceSession:
"""Initialize an ONNX Runtime inference session with the specified model."""
import onnxruntime as ort
session_options = ort.SessionOptions()
session_options.log_severity_level = 3 # 3: Error
possible_providers = (
"CUDAExecutionProvider",
"CPUExecutionProvider",
)
available_providers = set(ort.get_available_providers())
providers = [
provider for provider in possible_providers if provider in available_providers
]
return ort.InferenceSession(
model, providers=providers, sess_options=session_options
)
class ONNXProgram:
"""A substitute class for `torch.onnx.ONNXProgram`."""
def __init__(self, model: ir.Model, exported_program: torch.export.ExportedProgram):
self.model: ir.Model = model
self.exported_program = exported_program
self._inference_session: ort.InferenceSession | None = None
self._tempdir: tempfile.TemporaryDirectory | None = None
def __repr__(self) -> str:
return f"""\
ONNXProgram(
model=
{textwrap.indent(str(self.model), ' ' * 8)}
,
exported_program=
{textwrap.indent(str(self.exported_program), ' ' * 8)}
)
"""
def __call__(self, *args, **kwargs) -> Sequence[torch.Tensor]:
"""Run the ONNX model with the same arguments you would provide to the GraphModule."""
import onnxruntime as ort
flatten_args = _process_args(args, kwargs)
if self._inference_session is None:
self.initialize_inference_session()
assert self._inference_session is not None
# We don't expect non-tensors as inputs
ort_input = {
k.name: v.numpy(force=True)
for k, v in zip(self.model.graph.inputs, flatten_args)
}
run_options = ort.RunOptions()
run_options.log_severity_level = 3 # 3: Error
logger.debug("Running the inference session with %s arguments.", len(ort_input))
outputs = self._inference_session.run(None, ort_input, run_options=run_options)
logger.debug("Inference session run completed.")
# TODO(justinchuby): Maybe output complex tensors as needed
return tuple(torch.from_numpy(output) for output in outputs)
@property
def model_proto(self) -> onnx.ModelProto:
"""Compatibility property for `torch.onnx.ONNXProgram.model_proto`."""
return ir.serde.serialize_model(self.model)
def save(
self,
destination: str | os.PathLike | IO[bytes],
*,
include_initializers: bool = True,
keep_initializers_as_inputs: bool = False,
external_data: bool | None = None,
**_,
):
"""Save the ONNX model to the specified destination.
When `external_data` is `True` or the model is larger than 2GB,
the weights are saved as external data in a separate file.
Args:
destination: The path to save the ONNX model to.
include_initializers: Whether to include the initializers in the saved model.
keep_initializers_as_inputs: Whether to keep the initializers as inputs in the saved model.
If `True`, the initializers are added as inputs to the model, which means they can be overwritten
by providing the initializers as model inputs.
external_data: Whether to save the weights as external data in a separate file.
Raises:
TypeError: If `external_data` is `True` and `destination` is not a file path.
"""
if not include_initializers:
self.model.graph.initializers.clear()
logger.warning(
"The initializers have been removed from the model. This is destructive. "
"Developers: Please implement ir.Model copy() and remove initializers on the copied model."
)
if keep_initializers_as_inputs:
self.model.graph.inputs.extend(self.model.graph.initializers.values()) # type: ignore[arg-type]
logger.warning(
"The initializers have been added as inputs to the model. This is destructive. "
"Developers: Please implement ir.Model copy() and remove initializers on the copied model."
)
proto = ir.serde.serialize_model(self.model)
byte_size = proto.ByteSize()
model_too_large = (byte_size) >= 1 << 31
if external_data or model_too_large:
# TODO: Create an IR pass to handle external tensors conversion
if model_too_large:
logger.warning(
"The serialized ONNX model is larger than 2GB (%s). "
"Saving the weights as external data in a separate file.",
byte_size,
)
if not isinstance(destination, (str, os.PathLike)):
raise TypeError(
"Saving the weights as external data is only supported when destination is a file path"
)
destination_path = pathlib.Path(destination)
# Save the external data next to the model file
data_path = f"{destination_path.name}.data"
onnx.save_model(
proto,
destination,
save_as_external_data=True,
location=data_path,
)
else:
onnx.save_model(proto, destination)
def initialize_inference_session(
self,
initializer: Callable[
[str | bytes], ort.InferenceSession
] = _ort_session_initializer,
) -> None:
"""Initialize the ONNX Runtime inference session.
Args:
initializer: The function to initialize the ONNX Runtime inference
session with the specified model. By default, it uses the
:func:`_ort_session_initializer` function.
"""
# TODO(justinchuby): Allow different inference options
logger.debug("Initializing the inference session.")
proto = ir.serde.serialize_model(self.model)
byte_size = proto.ByteSize()
model_too_large = (byte_size) >= 1 << 31
if model_too_large:
logger.debug(
"The serialized ONNX model is larger than 2GB (%s).", byte_size
)
# Save the model to a temporary file if too large
self._tempdir = tempfile.TemporaryDirectory(ignore_cleanup_errors=True)
model_path = os.path.join(self._tempdir.name, "model.onnx")
data_path = "model.onnx.data"
onnx.save_model(
proto,
model_path,
save_as_external_data=True,
location=data_path,
)
model = model_path
else:
model = proto.SerializeToString() # type: ignore[assignment]
self._inference_session = initializer(model)
logger.debug("Inference session initialized.")
def release(self) -> None:
"""Release the inference session.
You may call this method to release the resources used by the inference session.
"""
# Release the inference session first so that the model file can be deleted
if self._inference_session is not None:
self._inference_session = None
gc.collect()
if self._tempdir is not None:
self._tempdir.cleanup()
self._tempdir = None
def _process_args(args, kwargs) -> tuple[torch.Tensor, ...]:
"""Process input arguments for the ONNX model."""
args = _flatten_inputs(args, kwargs)
args = _remove_none_from_inputs(args)
args = _remove_non_tensor(args)
args = _convert_complex_to_real_representation(args)
return args
def _flatten_inputs(model_args, model_kwargs):
flattened_args, _ = pytree.tree_flatten((model_args, model_kwargs))
return flattened_args
def _remove_none_from_inputs(model_args):
return tuple(arg for arg in model_args if arg is not None)
def _remove_non_tensor(model_args):
"""Remove the non-tensor input arguments.
Dynamo does not support non-tensor input arguments (https://github.com/pytorch/pytorch/issues/99534).
Specifically, Dynamo adds the input to the graph as an empty node that no one consumes.
The concrete value is embedded into the graph as a constant arg of a target node. Meta
suggests that in this case one should rewrite the model code to make the input a tensor
if its value is supposed to change at runtime. We might need to further investigate
the feasibility of that suggestion.
For example,
def func(x, b=1.0):
y = x + b
z = y.relu()
return (y, z)
x = torch.randn(1, 1, 2, dtype=torch.float32)
gm_fun, _ = dynamo.export(func, x, b=8.0, aten_graph=True, tracing_mode="real")
# class GraphModule(torch.nn.Module):
# def forward(self, x, b):
# arg0: f32[1, 1, 2], arg1, = fx_pytree.tree_flatten_spec(([x, b], {}), self._in_spec)
# # File: path/to/pytorch/test_constant_input.py:5, code: y = x + b
# add_tensor: f32[1, 1, 2] = torch.ops.aten.add.Tensor(arg0, 8.0); arg0 = None
# # File: path/to/pytorch/test_constant_input.py:6, code: z = y.relu()
# relu_default: f32[1, 1, 2] = torch.ops.aten.relu.default(add_tensor)
# return pytree.tree_unflatten([add_tensor, relu_default], self._out_spec)
The empty torch.fx.Node input leads to a mismatched number of inputs with PyTorch, as
it is ignored in the ONNX graph. Thus, we delete the useless input here.
"""
return tuple(
arg for arg in model_args if not isinstance(arg, (int, float, bool, str))
)
def _convert_complex_to_real_representation(model_args):
"""Convert complex dtype tensors to real representation tensors.
ONNX does not support complex dtype tensors. Thus, we convert complex dtype tensors
to real representation tensors (i.e., float dtype tensors with an extra dimension
representing the real and imaginary parts of the complex number).
"""
return tuple(
torch.view_as_real(arg.resolve_conj())
if isinstance(arg, torch.Tensor) and arg.is_complex()
else arg
for arg in model_args
)
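A rough end-to-end usage sketch for `ONNXProgram`, assuming `onnx_program` was produced by the new exporter and the model takes a single tensor and returns a single output (the shape is illustrative):

import torch

# `onnx_program` is assumed to come from the exporter; shape and paths are illustrative
x = torch.randn(2, 3)
(ort_out,) = onnx_program(x)      # executes the ONNX model through ONNX Runtime
onnx_program.save("model.onnx")   # weights go to "model.onnx.data" when >2GB or external_data=True
onnx_program.release()            # frees the inference session and any temporary files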

View File

@ -0,0 +1,275 @@
"""Module for handling ATen to ONNX functions registration.
https://github.com/pytorch/pytorch/blob/6aa5bb1a76dee8112f1a9e7c194c790b5cdc6462/torch/onnx/_internal/fx/registration.py
"""
# NOTE: Why do we need a different registry than the one in torchlib?
# The registry in torchlib is used to register functions that are already implemented in
# torchlib, and is designed to be a static singleton. It does not take into account custom ops or different
# opsets etc. The registry implemented for the exporter is designed to be modifiable at
# export time by users, and is designed with dispatching in mind.
# mypy: allow-untyped-defs
from __future__ import annotations
import dataclasses
import logging
import math
import operator
import types
import typing
from typing import Callable, Literal, Mapping, Union
from typing_extensions import TypeAlias
import torch
import torch._ops
from torch.onnx._internal.exporter import _schemas
if typing.TYPE_CHECKING:
import onnxscript
from onnxscript.function_libs.torch_lib import registration as torchlib_registration
_DEFAULT_OPSET_VERSION = 18
TorchOp: TypeAlias = Union[torch._ops.OpOverload, types.BuiltinFunctionType, Callable]
logger = logging.getLogger(__name__)
@dataclasses.dataclass(frozen=True)
class OnnxDecompMeta:
"""A wrapper of onnx-script function with additional metadata.
onnx_function: The onnx-script function from torchlib.
fx_target: The PyTorch node callable target.
is_custom: Whether the function is a custom function.
is_complex: Whether the function is a function that handles complex valued inputs.
device: The device the function is registered to. If None, it is registered to all devices.
"""
onnx_function: onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction
fx_target: TorchOp
is_custom: bool = False
is_complex: bool = False
device: Literal["cuda", "cpu"] | str | None = None # noqa: PYI051
def _get_overload(qualified_name: str) -> torch._ops.OpOverload | None:
"""Obtain the torch op from <namespace>::<op_name>[.<overload>]"""
# TODO(justinchuby): Handle arbitrary custom ops
namespace, opname_overload = qualified_name.split("::")
op_name, *maybe_overload = opname_overload.split(".", 1)
if namespace == "_operator":
# Builtin functions
return getattr(operator, op_name)
if namespace == "math":
return getattr(math, op_name)
if namespace == "torchvision":
try:
import torchvision.ops # type: ignore[import-untyped]
except ImportError:
logger.warning("torchvision is not installed. Skipping %s", qualified_name)
return None
try:
return getattr(torchvision.ops, op_name)
except AttributeError:
logger.warning("Failed to find torchvision op '%s'", qualified_name)
return None
except Exception:
logger.exception("Failed to find torchvision op '%s'", qualified_name)
try:
op_packet = getattr(getattr(torch.ops, namespace), op_name)
if maybe_overload:
overload = maybe_overload[0]
elif "default" in op_packet._overload_names or "" in op_packet._overload_names:
# Has a default overload
overload = "default"
else:
logger.warning(
"'%s' does not have a 'default' overload. This could be an error in specifying the op name. Ignoring.",
qualified_name,
stacklevel=1,
)
return None
return getattr(op_packet, overload) # type: ignore[call-overload]
except AttributeError:
if qualified_name.endswith("getitem"):
# This is a special case where we registered the function incorrectly,
# but for BC reasons (pt<=2.4) we need to keep it.
return None
logger.info("'%s' is not found in this version of PyTorch.", qualified_name)
return None
except Exception:
logger.exception("Failed to find torch op '%s'", qualified_name)
return None
class ONNXRegistry:
"""Registry for ONNX functions.
The registry maintains a mapping from qualified names to symbolic functions under a
fixed opset version. It supports registering custom onnx-script functions and allows
the dispatcher to dispatch calls to the appropriate function.
"""
def __init__(self) -> None:
"""Initializes the registry"""
# TODO: Design multi-opset version support
self._opset_version = _DEFAULT_OPSET_VERSION
self.functions: dict[TorchOp | str, list[OnnxDecompMeta]] = {}
@property
def opset_version(self) -> int:
"""The ONNX opset version the exporter should target.
Defaults to the latest supported ONNX opset version: 18.
The default version will increment over time as ONNX continues to evolve.
"""
return self._opset_version
@classmethod
def from_torchlib(
cls,
torchlib_registry: Mapping[str, torchlib_registration.OverloadedFunction]
| None = None,
) -> ONNXRegistry:
"""Populates the registry with ATen functions from torchlib.
Args:
torchlib_registry: The torchlib registry to use for populating the registry.
"""
registry = cls()
if torchlib_registry is None:
from onnxscript.function_libs.torch_lib import (
registration as torchlib_registration,
)
torchlib_registry = torchlib_registration.default_registry # type: ignore[assignment]
for qualified_name, aten_overloads_func in torchlib_registry.items(): # type: ignore[union-attr]
try:
# NOTE: This is heavily guarded with try-except because we don't want
# to fail the entire registry population if one function fails.
if qualified_name.startswith("internal::"):
# Skip the custom defined internal functions
continue
target = _get_overload(qualified_name)
if target is None:
continue
for overload_func in aten_overloads_func.overloads:
overload_func.signature = _schemas.OpSignature.from_function(
overload_func,
overload_func.function_ir.domain,
overload_func.name,
)
onnx_decomposition = OnnxDecompMeta(
onnx_function=overload_func,
fx_target=target,
is_custom=False,
is_complex=False,
)
registry._register(target, onnx_decomposition)
for complex_func in aten_overloads_func.complex:
complex_func.signature = _schemas.OpSignature.from_function(
complex_func,
complex_func.function_ir.domain,
complex_func.name,
)
onnx_decomposition = OnnxDecompMeta(
onnx_function=complex_func,
fx_target=target,
is_custom=False,
is_complex=True,
)
registry._register(target, onnx_decomposition)
except Exception:
logger.exception("Failed to register '%s'. Skipped", qualified_name)
continue
return registry
def _register(
self,
target: TorchOp,
onnx_decomposition: OnnxDecompMeta,
) -> None:
"""Registers an OnnxDecompMeta to an operator.
Args:
target: The PyTorch node callable target.
onnx_decomposition: The OnnxDecompMeta to register.
"""
target_or_name: str | TorchOp
if isinstance(target, torch._ops.OpOverload):
# Get the qualified name of the aten op because torch._ops.OpOverload lookup in
# a dictionary is unreliable for some reason.
target_or_name = target.name()
else:
target_or_name = target
if onnx_decomposition.is_custom:
self.functions.setdefault(target_or_name, []).insert(0, onnx_decomposition)
else:
self.functions.setdefault(target_or_name, []).append(onnx_decomposition)
def register_op(
self,
target: TorchOp,
function: onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction,
is_complex: bool = False,
) -> None:
"""Registers a custom operator: torch.ops.<namespace>.<op_name>.<overload>.
Args:
target: The PyTorch node callable target.
function: The onnx-script function to register.
is_complex: Whether the function is a function that handles complex valued inputs.
"""
onnx_decomposition = OnnxDecompMeta(
onnx_function=function,
fx_target=target,
is_custom=True,
is_complex=is_complex,
)
self._register(target, onnx_decomposition)
def get_decomps(self, target: TorchOp) -> list[OnnxDecompMeta]:
"""Returns a list of OnnxDecompMeta for the given op: torch.ops.<namespace>.<op_name>.<overload>.
The list is ordered by the time of registration. The custom operators should come
first in the list.
Args:
target: The PyTorch node callable target.
Returns:
A list of OnnxDecompMeta corresponding to the given name, or an empty list if
the name is not in the registry.
"""
target_or_name: str | TorchOp
if isinstance(target, torch._ops.OpOverload):
# Get the qualified name of the aten op because torch._ops.OpOverload lookup in
# a dictionary is unreliable for some reason.
target_or_name = target.name()
else:
target_or_name = target
decomps = self.functions.get(target_or_name, [])
return sorted(decomps, key=lambda x: x.is_custom, reverse=True)
def is_registered(self, target: TorchOp) -> bool:
"""Returns whether the given op is registered: torch.ops.<namespace>.<op_name>.<overload>.
Args:
target: The PyTorch node callable target.
Returns:
True if the given op is registered, otherwise False.
"""
return bool(self.get_decomps(target))
def __repr__(self) -> str:
return f"{self.__class__.__name__}(functions={self.functions})"
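A hedged sketch of registering a custom translation through `ONNXRegistry.register_op`; the onnxscript decorator usage and the toy sigmoid-based GELU body are illustrative assumptions, not part of this module:

import onnxscript
from onnxscript import FLOAT
from onnxscript import opset18 as op

import torch
from torch.onnx._internal.exporter import _registration

registry = _registration.ONNXRegistry.from_torchlib()

@onnxscript.script()
def custom_gelu(x: FLOAT[...]) -> FLOAT[...]:
    # Toy sigmoid approximation of GELU, purely for illustration.
    return op.Mul(x, op.Sigmoid(op.Mul(x, op.Constant(value_float=1.702))))

registry.register_op(torch.ops.aten.gelu.default, custom_gelu)
assert registry.is_registered(torch.ops.aten.gelu.default)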

View File

@ -0,0 +1,193 @@
# mypy: allow-untyped-defs
from __future__ import annotations
import dataclasses
import re
from typing import TYPE_CHECKING
from torch.onnx._internal.exporter import _analysis, _registration, _verification
if TYPE_CHECKING:
import os
from onnxscript import ir
import torch
@dataclasses.dataclass
class ExportStatus:
# Whether torch.export.export() succeeds
torch_export: bool | None = None
# Whether torch.export.export(..., strict=False) succeeds
torch_export_non_strict: bool | None = None
# Whether torch.jit.trace succeeds
torch_jit: bool | None = None
# Whether ONNX translation succeeds
onnx_translation: bool | None = None
# Whether ONNX model passes onnx.checker.check_model
onnx_checker: bool | None = None
# Whether ONNX model runs successfully with ONNX Runtime
onnx_runtime: bool | None = None
# Whether the output of the ONNX model is accurate
output_accuracy: bool | None = None
def _status_emoji(status: bool | None) -> str:
if status is None:
return "⚪"
return "✅" if status else "❌"
def _format_export_status(status: ExportStatus) -> str:
return (
f"```\n"
f"{_status_emoji(status.torch_export)} Obtain model graph with `torch.export.export`\n"
f"{_status_emoji(status.torch_export_non_strict)} Obtain model graph with `torch.export.export(..., strict=False)`\n"
f"{_status_emoji(status.torch_jit)} Obtain model graph with `torch.jit.trace`\n"
f"{_status_emoji(status.onnx_translation)} Translate the graph into ONNX\n"
f"{_status_emoji(status.onnx_checker)} Run `onnx.checker` on the ONNX model\n"
f"{_status_emoji(status.onnx_runtime)} Execute the model with ONNX Runtime\n"
f"{_status_emoji(status.output_accuracy)} Validate model output accuracy\n"
f"```\n\n"
)
def _strip_color_from_string(text: str) -> str:
# This regular expression matches ANSI escape codes
# https://github.com/pytorch/pytorch/blob/9554a9af8788c57e1c5222c39076a5afcf0998ae/torch/_dynamo/utils.py#L2785-L2788
ansi_escape = re.compile(r"\x1B[@-_][0-?]*[ -/]*[@-~]")
return ansi_escape.sub("", text)
def _format_exported_program(exported_program: torch.export.ExportedProgram) -> str:
# Adapted from https://github.com/pytorch/pytorch/pull/128476
# to remove colors
# Even though we can call graph_module.print_readable directly, since the
# colored option was added only recently, we can't guarantee that the
# version of PyTorch used by the user has this option. Therefore, we
# still call str(ExportedProgram)
text = f"```python\n{_strip_color_from_string(str(exported_program))}\n```\n\n"
return text
def construct_report_file_name(timestamp: str, status: ExportStatus) -> str:
# Status could be None. So we need to check for False explicitly.
if not (status.torch_export or status.torch_export_non_strict or status.torch_jit):
# All strategies failed
postfix = "pt_export"
elif status.onnx_translation is False:
postfix = "conversion"
elif status.onnx_checker is False:
postfix = "checker"
elif status.onnx_runtime is False:
postfix = "runtime"
elif status.output_accuracy is False:
postfix = "accuracy"
elif status.torch_export is False or status.torch_export_non_strict is False:
# Some strategies failed
postfix = "strategies"
else:
postfix = "success"
return f"onnx_export_{timestamp}_{postfix}.md"
def format_decomp_comparison(
pre_decomp_unique_ops: set[str],
post_decomp_unique_ops: set[str],
) -> str:
"""Format the decomposition comparison result.
Args:
pre_decomp_unique_ops: The unique ops in the ExportedProgram before decomposition.
post_decomp_unique_ops: The unique ops in the ExportedProgram after decomposition.
Returns:
The formatted comparison result.
"""
return (
f"Ops exist only in the ExportedProgram before decomposition: `{sorted(pre_decomp_unique_ops)}`\n\n"
f"Ops exist only in the ExportedProgram after decomposition: `{sorted(post_decomp_unique_ops)}`\n"
)
def format_verification_infos(
verification_infos: list[_verification.VerificationInfo],
) -> str:
"""Format the verification result.
Args:
verification_infos: The verification result.
Returns:
The formatted verification result.
"""
return "\n".join(
f"`{info.name}`: `abs_diff={info.absolute_difference:e}`, `rel_diff={info.relative_difference:e}`"
for info in verification_infos
)
def create_torch_export_error_report(
filename: str | os.PathLike,
formatted_traceback: str,
*,
export_status: ExportStatus,
profile_result: str | None,
):
with open(filename, "w", encoding="utf-8") as f:
f.write("# PyTorch ONNX Conversion Error Report\n\n")
f.write(_format_export_status(export_status))
f.write("Error message:\n\n")
f.write("```pytb\n")
f.write(formatted_traceback)
f.write("```\n\n")
if profile_result is not None:
f.write("## Profiling result\n\n")
f.write("```\n")
f.write(profile_result)
f.write("```\n")
def create_onnx_export_report(
filename: str | os.PathLike,
formatted_traceback: str,
program: torch.export.ExportedProgram,
*,
decomp_comparison: str | None = None,
export_status: ExportStatus,
profile_result: str | None,
model: ir.Model | None = None,
registry: _registration.ONNXRegistry | None = None,
verification_result: str | None = None,
):
with open(filename, "w", encoding="utf-8") as f:
f.write("# PyTorch ONNX Conversion Report\n\n")
f.write(_format_export_status(export_status))
f.write("## Error messages\n\n")
f.write("```pytb\n")
f.write(formatted_traceback)
f.write("\n```\n\n")
f.write("## Exported program\n\n")
f.write(_format_exported_program(program))
if model is not None:
f.write("## ONNX model\n\n")
f.write("```python\n")
f.write(str(model))
f.write("\n```\n\n")
f.write("## Analysis\n\n")
_analysis.analyze(program, file=f, registry=registry)
if decomp_comparison is not None:
f.write("\n## Decomposition comparison\n\n")
f.write(decomp_comparison)
f.write("\n")
if verification_result is not None:
f.write("\n## Verification results\n\n")
f.write(verification_result)
f.write("\n")
if profile_result is not None:
f.write("\n## Profiling result\n\n")
f.write("```\n")
f.write(profile_result)
f.write("```\n")
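To make the status-to-filename mapping concrete, here is a small sketch; the timestamp string is an arbitrary example:

from torch.onnx._internal.exporter import _reporting

status = _reporting.ExportStatus(torch_export=True, onnx_translation=False)
# Capture succeeded but ONNX translation failed, so the report gets the "conversion" postfix.
print(_reporting.construct_report_file_name("2024-08-18_12-00-00", status))
# onnx_export_2024-08-18_12-00-00_conversion.md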

View File

@ -0,0 +1,548 @@
# mypy: allow-untyped-defs
from __future__ import annotations
import collections.abc
import dataclasses
import inspect
import logging
import types
import typing
from typing import Any, Iterator, Mapping, Optional, Sequence, TypeVar, Union
import onnx
import onnxscript
from onnxscript import ir
logger = logging.getLogger(__name__)
# A special value to indicate that the default value is not specified
class _Empty:
def __repr__(self):
return "_EMPTY_DEFAULT"
_EMPTY_DEFAULT = _Empty()
# Map from python type to corresponding ONNX AttributeProto type
_PY_TYPE_TO_ATTR_TYPE = {
float: ir.AttributeType.FLOAT,
int: ir.AttributeType.INT,
str: ir.AttributeType.STRING,
bool: ir.AttributeType.INT,
ir.Tensor: ir.AttributeType.TENSOR,
ir.TensorProtocol: ir.AttributeType.TENSOR,
ir.Graph: ir.AttributeType.GRAPH,
ir.GraphProtocol: ir.AttributeType.GRAPH,
}
# Map from python type to corresponding ONNX AttributeProto type,
# for repeated (i.e., list of) values
_LIST_TYPE_TO_ATTR_TYPE = {
float: ir.AttributeType.FLOATS,
int: ir.AttributeType.INTS,
str: ir.AttributeType.STRINGS,
bool: ir.AttributeType.INTS,
ir.Tensor: ir.AttributeType.TENSORS,
ir.TensorProtocol: ir.AttributeType.TENSORS,
ir.Graph: ir.AttributeType.GRAPHS,
ir.GraphProtocol: ir.AttributeType.GRAPHS,
}
_ALL_VALUE_TYPES = (
{ir.TensorType(dtype) for dtype in ir.DataType}
| {ir.SequenceType(ir.TensorType(dtype)) for dtype in ir.DataType}
| {ir.OptionalType(ir.TensorType(dtype)) for dtype in ir.DataType}
)
# TypeAnnotationValue represents the (value of) valid type-annotations recognized
# by ONNX Script. Currently, it supports
# - float, int, str (primitive attribute types)
# - Sequence[float], Sequence[int], Sequence[str] (attribute types)
# - Tensor types
# - Sequence[Tensor] types
# - Union of above 2
# - TypeVars with above bounds
# - Above types with annotation attached
TypeAnnotationValue = Any
@dataclasses.dataclass(frozen=True)
class TypeConstraintParam:
"""Type constraint for a parameter.
Attributes:
name: Name of the parameter. E.g. "TFloat"
allowed_types: Allowed types for the parameter.
"""
name: str
allowed_types: set[ir.TypeProtocol]
description: str = ""
def __hash__(self) -> int:
return hash((self.name, tuple(self.allowed_types)))
def __str__(self) -> str:
allowed_types_str = " | ".join(str(t) for t in self.allowed_types)
return f"{self.name}={allowed_types_str}"
@classmethod
def any_tensor(cls, name: str, description: str = "") -> TypeConstraintParam:
return cls(name, {ir.TensorType(dtype) for dtype in ir.DataType}, description)
@classmethod
def any_value(cls, name: str, description: str = "") -> TypeConstraintParam:
return cls(name, _ALL_VALUE_TYPES, description) # type: ignore[arg-type]
@dataclasses.dataclass(frozen=True)
class Parameter:
"""A formal parameter of an operator."""
name: str
type_constraint: TypeConstraintParam
required: bool
variadic: bool
default: Any = _EMPTY_DEFAULT
# TODO: Add other properties too
def __str__(self) -> str:
type_str = self.type_constraint.name
if self.has_default():
return f"{self.name}: {type_str} = {self.default}"
return f"{self.name}: {type_str}"
def has_default(self) -> bool:
return self.default is not _EMPTY_DEFAULT
@dataclasses.dataclass(frozen=True)
class AttributeParameter:
name: str
type: ir.AttributeType
required: bool
default: ir.Attr | None = None
def __str__(self) -> str:
type_str = self.type.name
if self.has_default():
return f"{self.name}: {type_str} = {self.default}"
return f"{self.name}: {type_str}"
def has_default(self) -> bool:
return self.default is not None
def _get_type_from_str(
type_str: str,
) -> ir.TensorType | ir.SequenceType | ir.OptionalType:
"""Convert a type_str from an ONNX OpSchema to an ir.TypeProtocol.
A type str has the form of "tensor(float)" or a composite type like "seq(tensor(float))".
"""
# TODO: Upstream this to IR
# Split the type_str into sequence types and dtypes
# 1. Remove the ending ")"
stripped = type_str.rstrip(")")
# 2. Split the type_str by "("
type_parts = stripped.split("(")
# Convert the dtype to ir.DataType
dtype = ir.DataType[type_parts[-1].upper()]
# Create a placeholder type first
type_: ir.TypeProtocol = ir.TensorType(ir.DataType.UNDEFINED)
# Construct the type
for type_part in reversed(type_parts[:-1]):
if type_part == "tensor":
type_ = ir.TensorType(dtype)
elif type_part == "seq":
type_ = ir.SequenceType(type_)
elif type_part == "optional":
type_ = ir.OptionalType(type_)
else:
raise ValueError(f"Unknown type part: '{type_part}' in type '{type_str}'")
return type_ # type: ignore[return-value]
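# A few illustrative conversions (derived from the parsing logic above):
#   _get_type_from_str("tensor(float)")          -> ir.TensorType(ir.DataType.FLOAT)
#   _get_type_from_str("seq(tensor(int64))")     -> ir.SequenceType(ir.TensorType(ir.DataType.INT64))
#   _get_type_from_str("optional(tensor(bool))") -> ir.OptionalType(ir.TensorType(ir.DataType.BOOL))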
def _convert_formal_parameter(
param: onnx.defs.OpSchema.FormalParameter,
type_constraints: Mapping[str, TypeConstraintParam],
) -> Parameter:
"""Convert a formal parameter from ONNX Opschema to Parameter."""
if param.type_str in type_constraints:
type_constraint = type_constraints[param.type_str]
else:
# param.type_str can be a plain type like 'int64'.
type_constraint = TypeConstraintParam(
name=param.name,
allowed_types={_get_type_from_str(param.type_str)},
)
return Parameter(
name=param.name,
type_constraint=type_constraint,
required=param.option != onnx.defs.OpSchema.FormalParameterOption.Optional,
variadic=param.option == onnx.defs.OpSchema.FormalParameterOption.Variadic,
)
def _is_optional(type_: type) -> bool:
"""Returns whether a type_ is an Optional."""
origin_type = typing.get_origin(type_)
if origin_type is Union and type(None) in typing.get_args(type_):
# Python < 3.10
return True
if origin_type is Optional:
# Python >= 3.10
return True
if (
hasattr(types, "UnionType")
and origin_type is types.UnionType
and type(None) in typing.get_args(type_)
):
# Python >= 3.10
return True
return False
def _get_attr_type(type_: type) -> ir.AttributeType:
"""Obtain the type of the attribute from a Python class."""
try:
if type_ in _PY_TYPE_TO_ATTR_TYPE:
return _PY_TYPE_TO_ATTR_TYPE[type_]
origin_type = typing.get_origin(type_)
if origin_type is None:
return ir.AttributeType.UNDEFINED
if origin_type in (
collections.abc.Sequence,
Sequence,
typing.List,
list,
typing.Tuple,
tuple,
):
inner_type = typing.get_args(type_)[0]
if inner_type in _LIST_TYPE_TO_ATTR_TYPE:
return _LIST_TYPE_TO_ATTR_TYPE[inner_type]
except TypeError:
logger.warning("TypeError when checking %s.", type_, exc_info=True)
return ir.AttributeType.UNDEFINED
def _get_type_constraint_name(type_: TypeAnnotationValue) -> str | None:
"""Returns the name of the type constraint for a given type annotation.
Args:
type_: A Python type.
Returns:
The name of the type constraint if it is a TypeVar.
- Prefixes the name with "Sequence_" if the type annotation is a Sequence[].
"""
if isinstance(type_, TypeVar):
return type_.__name__
if _is_optional(type_):
subtypes = typing.get_args(type_)
for subtype in subtypes:
if subtype is type(None):
continue
type_param_name = _get_type_constraint_name(subtype)
return type_param_name if type_param_name else None
origin_type = typing.get_origin(type_)
if isinstance(origin_type, type) and issubclass(origin_type, Sequence):
subtypes = typing.get_args(type_)
type_param_name = _get_type_constraint_name(subtypes[0])
return f"Sequence_{type_param_name}" if type_param_name else None
return None
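# For example, given TFloat = TypeVar("TFloat"):
#   _get_type_constraint_name(TFloat)           -> "TFloat"
#   _get_type_constraint_name(Optional[TFloat]) -> "TFloat"
#   _get_type_constraint_name(Sequence[TFloat]) -> "Sequence_TFloat"
#   _get_type_constraint_name(int)              -> None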
def _get_allowed_types_from_type_annotation(
type_: TypeAnnotationValue,
) -> set[ir.TypeProtocol]:
"""Obtain the allowed types from a type annotation."""
if type_ is onnxscript.onnx_types.TensorType:
# Any tensor type
return {ir.TensorType(dtype) for dtype in ir.DataType}
allowed_types: set[ir.TypeProtocol]
if isinstance(type_, TypeVar):
allowed_types = set()
if constraints := type_.__constraints__:
for constraint in constraints:
allowed_types.update(
_get_allowed_types_from_type_annotation(constraint)
)
else:
bound = type_.__bound__
if bound is None:
allowed_types = _ALL_VALUE_TYPES # type: ignore[assignment]
else:
allowed_types.update(_get_allowed_types_from_type_annotation(bound))
return allowed_types
if hasattr(type_, "dtype"):
# A single tensor type like INT64, FLOAT, etc.
return {ir.TensorType(ir.DataType(type_.dtype))}
if _is_optional(type_):
allowed_types = set()
subtypes = typing.get_args(type_)
for subtype in subtypes:
if subtype is type(None):
continue
allowed_types.update(_get_allowed_types_from_type_annotation(subtype))
# NOTE: We do not consider dynamic optional types like optional(float) because they are not very useful.
return allowed_types
origin_type = typing.get_origin(type_)
if origin_type is Union:
allowed_types = set()
subtypes = typing.get_args(type_)
for subtype in subtypes:
assert subtype is not type(
None
), "Union should not contain None type because it is handled by _is_optional."
allowed_types.update(_get_allowed_types_from_type_annotation(subtype))
return allowed_types
if isinstance(origin_type, type) and issubclass(origin_type, Sequence):
subtypes = typing.get_args(type_)
return {
ir.SequenceType(t)
for t in _get_allowed_types_from_type_annotation(subtypes[0])
}
# Allow everything by default
return _ALL_VALUE_TYPES # type: ignore[return-value]
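# For illustration, assuming INT64 is an onnxscript tensor type carrying a dtype:
#   _get_allowed_types_from_type_annotation(INT64)
#       -> {ir.TensorType(ir.DataType.INT64)}
#   _get_allowed_types_from_type_annotation(Sequence[INT64])
#       -> {ir.SequenceType(ir.TensorType(ir.DataType.INT64))}
# and a bare TypeVar with no bound or constraints falls back to _ALL_VALUE_TYPES.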
@dataclasses.dataclass
class OpSignature:
"""Schema for an operator.
Attributes:
domain: Domain of the operator. E.g. "".
name: Name of the operator. E.g. "Add".
overload: Overload name of the operator.
params: Input parameters. When the op is an ONNX function definition,
the order is according to the function signature. This means we can
interleave ONNX inputs and ONNX attributes in the list.
outputs: Output parameters.
"""
domain: str
name: str
overload: str
params: Sequence[Parameter | AttributeParameter]
outputs: Sequence[Parameter]
params_map: Mapping[str, Parameter | AttributeParameter] = dataclasses.field(
init=False, repr=False
)
def __post_init__(self):
self.params_map = {param.name: param for param in self.params}
def get(self, name: str) -> Parameter | AttributeParameter:
return self.params_map[name]
def __contains__(self, name: str) -> bool:
return name in self.params_map
def __iter__(self) -> Iterator[Parameter | AttributeParameter]:
return iter(self.params)
def __str__(self) -> str:
domain = self.domain or "''"
# TODO: Double check the separator for overload
overload = f"::{self.overload}" if self.overload else ""
params = ", ".join(str(param) for param in self.params)
outputs = ", ".join(str(param.type_constraint.name) for param in self.outputs)
type_constraints = {}
for param in self.params:
if isinstance(param, Parameter):
type_constraints[param.type_constraint.name] = param.type_constraint
for param in self.outputs:
type_constraints[param.type_constraint.name] = param.type_constraint
type_constraints_str = ", ".join(
str(type_constraint) for type_constraint in type_constraints.values()
)
return f"{domain}::{self.name}{overload}({params}) -> ({outputs}) where {type_constraints_str}"
@classmethod
def from_opschema(cls, opschema: onnx.defs.OpSchema) -> OpSignature:
"""Produce an OpSignature from an ONNX Opschema."""
type_constraints = {
constraint.type_param_str: TypeConstraintParam(
name=constraint.type_param_str,
allowed_types={
_get_type_from_str(type_str)
for type_str in constraint.allowed_type_strs
},
description=constraint.description,
)
for constraint in opschema.type_constraints
}
params = [
_convert_formal_parameter(param, type_constraints)
for param in opschema.inputs
]
for param in opschema.attributes.values():
default_attr = (
ir.serde.deserialize_attribute(param.default_value)
if param.default_value is not None
else None
)
if default_attr is not None:
# Set the name of the default attribute because it may have a different name from the parameter
default_attr.name = param.name
params.append(
AttributeParameter(
name=param.name,
type=ir.AttributeType(param.type), # type: ignore[arg-type]
required=param.required,
default=default_attr, # type: ignore[arg-type]
)
)
outputs = [
_convert_formal_parameter(param, type_constraints)
for param in opschema.outputs
]
return cls(
domain=opschema.domain,
name=opschema.name,
overload="",
params=params,
outputs=outputs,
)
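# For example, assuming the installed onnx package exposes the standard opset schemas:
#   sig = OpSignature.from_opschema(onnx.defs.get_schema("Gemm"))
# sig.params then lists the inputs A, B, C followed by the attributes
# alpha, beta, transA and transB, each with its default value.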
@classmethod
def from_function(
cls, func, domain: str, name: str | None = None, overload: str = ""
) -> OpSignature:
"""Produce an OpSignature from a function using type annotation."""
py_signature = inspect.signature(func)
# Not using inspect.get_annotations because typing.get_type_hints seems to handle more cases
# https://github.com/python/cpython/issues/102405
type_hints = typing.get_type_hints(func)
params = []
# Create a mapping from type to a unique name
type_constraints: dict[str, TypeConstraintParam] = {}
for param in py_signature.parameters.values():
if param.name not in type_hints:
logger.warning(
"Missing annotation for parameter '%s' from %s. Treating as an Input.",
param.name,
py_signature,
)
type_constraints[param.name] = TypeConstraintParam.any_value(
f"T_{param.name}"
)
else:
type_ = type_hints[param.name]
if (attr_type := _get_attr_type(type_)) != ir.AttributeType.UNDEFINED:
# Construct the default attribute
if param.default is not inspect.Parameter.empty:
# TODO: Use ir_convenience instead to handle int as float
default = ir.Attr(param.name, attr_type, param.default)
else:
default = None
params.append(
AttributeParameter(
name=param.name,
type=attr_type,
required=param.default is inspect.Parameter.empty,
default=default,
)
)
else:
# Obtain the type constraint from the type annotation
# 1. Get a type constraint name from the type annotation
# If the type annotation is a TypeVar or Optional[TypeVar], get its name
# Otherwise, name it T_{param.name}
type_constraint_name = _get_type_constraint_name(type_)
if type_constraint_name is None:
type_constraint_name = f"T_{param.name}"
# 2. If the type constraint param is already initialized, use it
if type_constraint_name in type_constraints:
type_constraint = type_constraints[type_constraint_name]
else:
# 3. Otherwise, create a new TypeConstraintParam
type_constraint = TypeConstraintParam(
name=type_constraint_name,
allowed_types=_get_allowed_types_from_type_annotation(
type_
),
)
type_constraints[type_constraint_name] = type_constraint
# 4. Create Parameter
params.append(
Parameter( # type: ignore[arg-type]
name=param.name,
type_constraint=type_constraint,
required=param.default is inspect.Parameter.empty,
# TODO: Handle variadic
variadic=False,
default=param.default
if param.default is not inspect.Parameter.empty
else _EMPTY_DEFAULT,
)
)
return_type = type_hints.get("return")
outputs = []
if return_type is None:
# No returns
pass
else:
if typing.get_origin(return_type) is tuple:
# Multiple returns
return_types = typing.get_args(return_type)
else:
return_types = [return_type] # type: ignore[assignment]
for i, return_type_i in enumerate(return_types):
if (
return_param_name := _get_type_constraint_name(return_type_i)
) in type_constraints:
type_constraint = type_constraints[return_param_name]
else:
return_param_name = f"TReturn{i}"
type_constraint = TypeConstraintParam(
name=return_param_name,
allowed_types=_get_allowed_types_from_type_annotation(
return_type_i
),
)
type_constraints[return_param_name] = type_constraint
outputs.append(
Parameter(
name=return_param_name,
type_constraint=type_constraint,
required=True,
variadic=False,
default=_EMPTY_DEFAULT,
)
)
return cls(
domain=domain,
name=name or func.__name__,
overload=overload,
params=params,
outputs=outputs,
)
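# A rough sketch of the function path, using a hypothetical annotated implementation:
#   TFloat = TypeVar("TFloat")
#   def sample_add(x: TFloat, y: TFloat, alpha: float = 1.0) -> TFloat: ...
#   sig = OpSignature.from_function(sample_add, domain="custom", name="SampleAdd")
# Here x and y become required Parameters sharing the "TFloat" type constraint,
# alpha becomes an optional AttributeParameter with default 1.0, and the single
# output reuses the "TFloat" constraint.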

View File

@ -0,0 +1,98 @@
"""Subclass of ir.Value that supports Python operators."""
# mypy: allow-untyped-defs
from __future__ import annotations
import onnxscript
from onnxscript import ir
class SymbolicTensor(ir.Value):
"""A subclass of ir.Value that supports Python operators."""
def __init__(
self,
opset: onnxscript.values.Opset,
name: str | None = None,
shape: ir.Shape | None = None,
type: ir.TypeProtocol | None = None,
doc_string: str | None = None,
const_value: ir.TensorProtocol | None = None,
):
super().__init__(
name=name,
shape=shape,
type=type,
doc_string=doc_string,
const_value=const_value,
)
self._opset = opset
@property
def rank(self) -> int | None:
if self.shape is None:
return None
return len(self.shape)
# TODO: Implement indexing
def __mod__(self, other):
if self.dtype in {
ir.DataType.FLOAT,
ir.DataType.DOUBLE,
ir.DataType.FLOAT16,
ir.DataType.BFLOAT16,
}:
return self._opset.Mod(self, other, fmod=1)
return self._opset.Mod(self, other)
def __ne__(self, other):
return self._opset.Not(self._opset.Equal(self, other))
def __neg__(self):
return self._opset.Neg(self)
def __add__(self, other):
return self._opset.Add(self, other)
def __radd__(self, other):
return self._opset.Add(other, self)
def __rand__(self, other):
return self._opset.And(other, self)
def __mul__(self, other):
return self._opset.Mul(self, other)
def __rmul__(self, other):
return self._opset.Mul(other, self)
def __matmul__(self, other):
return self._opset.MatMul(self, other)
def __pow__(self, other):
return self._opset.Pow(self, other)
def __sub__(self, other):
return self._opset.Sub(self, other)
def __rsub__(self, other):
return self._opset.Sub(other, self)
def __truediv__(self, other):
return self._opset.Div(self, other)
def __lt__(self, other):
return self._opset.Less(self, other)
def __le__(self, other):
return self._opset.LessOrEqual(self, other)
def __eq__(self, other):
return self._opset.Equal(self, other)
def __ge__(self, other):
return self._opset.GreaterOrEqual(self, other)
def __gt__(self, other):
return self._opset.Greater(self, other)
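# A rough usage sketch, assuming `x` and `y` are SymbolicTensor values created with
# the same opset during graph construction:
#   z = x + y      # dispatches to opset.Add(x, y)
#   r = x % y      # opset.Mod(x, y, fmod=1) for floating-point dtypes, opset.Mod(x, y) otherwise
#   mask = x != y  # opset.Not(opset.Equal(x, y))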

View File

@ -0,0 +1,79 @@
# mypy: allow-untyped-defs
from __future__ import annotations
import dataclasses
from typing import Any, TYPE_CHECKING
import torch
from torch.utils import _pytree as pytree
if TYPE_CHECKING:
from torch.onnx._internal.exporter import _onnx_program
@dataclasses.dataclass
class VerificationInfo:
name: str
absolute_difference: float
relative_difference: float
expected_dtype: torch.dtype
actual_dtype: torch.dtype
# NOTE: We don't need to include shape because the expected shape is already known
# and checked by the runtime
def _compare_tensors(
expected: torch.Tensor,
actual: torch.Tensor,
) -> tuple[float, float]:
# Move tensors to the same device
expected = expected.detach().cpu()
actual = actual.detach().cpu()
absolute_difference = torch.abs(expected - actual).max().item()
eps = 1e-7
relative_difference = (
(torch.abs(expected - actual) / (torch.abs(expected) + eps)).max().item()
)
return absolute_difference, relative_difference
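# For example, expected=[1.0, 2.0] and actual=[1.0, 2.5] yields an absolute
# difference of 0.5 and a relative difference of roughly 0.25 (0.5 / 2.0).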
def verify_onnx_program(
onnx_program: _onnx_program.ONNXProgram,
args: tuple[Any, ...] | None = None,
kwargs: dict[str, Any] | None = None,
) -> list[VerificationInfo]:
exported_program = onnx_program.exported_program
if args is None and kwargs is None:
# User did not provide example inputs, use the default example inputs
if exported_program.example_inputs is None:
raise ValueError(
"No example inputs provided and the exported_program does not contain example inputs. "
"Please provide arguments to verify the ONNX program."
)
args, kwargs = exported_program.example_inputs
if args is None:
args = ()
if kwargs is None:
kwargs = {}
torch_module = exported_program.module()
torch_outputs, _ = pytree.tree_flatten(torch_module(*args, **kwargs))
onnx_outputs = onnx_program(*args, **kwargs)
results = []
for torch_output, onnx_output, output_val in zip(
torch_outputs, onnx_outputs, onnx_program.model.graph.outputs
):
name = output_val.name
absolute_difference, relative_difference = _compare_tensors(
torch_output, onnx_output
)
results.append(
VerificationInfo(
name=str(name),
absolute_difference=absolute_difference,
relative_difference=relative_difference,
expected_dtype=torch_output.dtype,
actual_dtype=onnx_output.dtype,
)
)
return results
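# A minimal usage sketch, assuming the exported program retains its example inputs
# (the model definition is a placeholder):
#   onnx_program = torch.onnx.export(MyModel(), (torch.randn(1, 3),), dynamo=True)
#   for info in verify_onnx_program(onnx_program):
#       print(info.name, info.absolute_difference, info.relative_difference)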

View File

@ -0,0 +1,30 @@
class ExporterError(RuntimeError):
"""Error during export."""
class TorchExportError(ExporterError):
"""Error during torch.export.export."""
class OnnxConversionError(ExporterError):
"""Error during ONNX conversion."""
class DispatchError(OnnxConversionError):
"""Error during ONNX Funtion dispatching."""
class GraphConstructionError(OnnxConversionError):
"""Error during graph construction."""
class OnnxCheckerError(ExporterError):
"""Error during ONNX model checking."""
class OnnxRuntimeError(ExporterError):
"""Error during ONNX Runtime execution."""
class OnnxValidationError(ExporterError):
"""Output value mismatch."""

View File

@ -554,7 +554,7 @@ class FxOnnxInterpreter:
)
with diagnostic.log_section(logging.DEBUG, "ONNX Graph:"):
diagnostic.debug("```\n%s\n```", onnxscript_graph.torch_graph)
diagnostic.debug("```\n%s\n```", onnxscript_graph.torch_graph) # type: ignore[attr-defined]
return onnxscript_graph
@ -655,7 +655,7 @@ class FxOnnxInterpreter:
# function signature in OpSchema, and find the best matched overload.
symbolic_fn = onnxfunction_dispatcher.dispatch(
node=node,
onnx_args=onnx_args,
onnx_args=onnx_args, # type: ignore[arg-type]
onnx_kwargs=onnx_kwargs,
diagnostic_context=self.diagnostic_context,
)
@ -781,7 +781,7 @@ class FxOnnxInterpreter:
outputs: onnxscript_graph_building.TorchScriptTensor | tuple[
onnxscript_graph_building.TorchScriptTensor, ...
] = parent_onnxscript_graph.add_module_call(
] = parent_onnxscript_graph.add_module_call( # type: ignore[assignment]
unique_module_name, sub_onnxscript_graph, onnx_args
)

View File

@ -61,7 +61,7 @@ def _create_tensor_proto_with_external_data(
tensor_proto = onnx.TensorProto() # type: ignore[attr-defined]
tensor_proto.name = name
tensor_proto.data_type = scalar_type.onnx_type()
tensor_proto.data_type = scalar_type.onnx_type() # type: ignore[assignment]
tensor_proto.dims.extend(tensor.shape)
tensor_proto.data_location = onnx.TensorProto.EXTERNAL # type: ignore[attr-defined]

View File

@ -172,10 +172,7 @@ def _get_torch_export_args(
def export(
model: torch.nn.Module
| torch.jit.ScriptModule
| torch.jit.ScriptFunction
| torch.export.ExportedProgram,
model: torch.nn.Module | torch.jit.ScriptModule | torch.jit.ScriptFunction,
args: tuple[Any, ...] | torch.Tensor,
f: str | None = None,
*,
@ -191,13 +188,11 @@ def export(
dynamic_axes: Mapping[str, Mapping[int, str]]
| Mapping[str, Sequence[int]]
| None = None,
dynamic_shapes: dict[str, Any] | tuple[Any, ...] | list[Any] | None = None,
keep_initializers_as_inputs: bool | None = None,
custom_opsets: Mapping[str, int] | None = None,
export_modules_as_functions: bool | Collection[type[torch.nn.Module]] = False,
autograd_inlining: bool | None = True,
dynamo: bool = False,
) -> torch.onnx.ONNXProgram | None:
autograd_inlining: bool = True,
) -> None:
r"""Exports a model into ONNX format.
If ``model`` is not a :class:`torch.jit.ScriptModule` nor a
@ -491,8 +486,6 @@ def export(
autograd_inlining: Flag used to control whether to inline autograd functions.
Refer to https://github.com/pytorch/pytorch/pull/74765 for more details.
dynamo: Whether to export the model with Dynamo instead of TorchScript.
Raises:
:class:`torch.onnx.errors.CheckerError`: If the ONNX checker detects an invalid ONNX graph.
:class:`torch.onnx.errors.UnsupportedOperatorError`: If the ONNX graph cannot be exported because it
@ -515,65 +508,29 @@ def export(
)
args = (args,) if isinstance(args, torch.Tensor) else args
if kwargs is not None:
args = args + (kwargs,)
if dynamo:
if isinstance(model, (torch.jit.ScriptModule, torch.jit.ScriptFunction)):
raise TypeError(
"Dynamo export does not support ScriptModule or ScriptFunction."
)
# TODO(justinchuby): Remove the warning once logic migration is done
warnings.warn(
"export_params, verbose, training, input_names, output_names, operator_export_type, opset_version, "
"do_constant_folding, keep_initializers_as_inputs, custom_opsets, export_modules_as_functions, and "
"autograd_inlining are not supported for dynamo export at the moment."
)
args, kwargs = _get_torch_export_args(args, kwargs)
if isinstance(model, torch.export.ExportedProgram):
exported_program = model
else:
if dynamic_shapes is None and dynamic_axes is not None:
dynamic_shapes = _from_dynamic_axes_to_dynamic_shapes(
model, dynamic_axes, input_names
)
exported_program = torch.export.export(
model, args=args, kwargs=kwargs, dynamic_shapes=dynamic_shapes # type: ignore[arg-type]
)
if kwargs is None:
# TODO(justinchuby): dynamo_export requires kwargs to be unpacked. Once migration is done
# we can pass kwargs as None
kwargs = {}
onnx_program = torch.onnx.dynamo_export(exported_program, *args, **kwargs)
if f is not None:
onnx_program.save(f)
return onnx_program
_export(
model,
args,
f,
export_params,
verbose,
training,
input_names,
output_names,
operator_export_type=operator_export_type,
opset_version=opset_version,
do_constant_folding=do_constant_folding,
dynamic_axes=dynamic_axes,
keep_initializers_as_inputs=keep_initializers_as_inputs,
custom_opsets=custom_opsets,
export_modules_as_functions=export_modules_as_functions,
autograd_inlining=autograd_inlining,
)
else:
# Torch Script export path
if f is None:
raise ValueError("Export destination must be specified when dynamo=False.")
if kwargs is not None:
args = args + (kwargs,)
_export(
model,
args,
f,
export_params,
verbose,
training,
input_names,
output_names,
operator_export_type=operator_export_type,
opset_version=opset_version,
do_constant_folding=do_constant_folding,
dynamic_axes=dynamic_axes,
keep_initializers_as_inputs=keep_initializers_as_inputs,
custom_opsets=custom_opsets,
export_modules_as_functions=export_modules_as_functions,
autograd_inlining=autograd_inlining,
)
return None
return None
def _is_constant_tensor_list(node):
@ -1531,7 +1488,7 @@ def _export(
custom_opsets=None,
add_node_names=True,
onnx_shape_inference=True,
export_modules_as_functions=False,
export_modules_as_functions: Any = False,
autograd_inlining=True,
):
assert GLOBALS.in_onnx_export is False
@ -1560,9 +1517,7 @@ def _export(
f"Exporting to ONNX opset version {opset_version} is not supported. "
f"by 'torch.onnx.export()'. "
f"The highest opset version supported is {_constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET}. "
f"To use a newer opset version, consider 'torch.onnx.dynamo_export()'. "
f"Note that dynamo_export() is in preview. Please report errors with "
f"dynamo_export() as Github issues to https://github.com/pytorch/pytorch/issues.",
f"To use a newer opset version, consider 'torch.onnx.export(..., dynamo=True)'. ",
category=errors.OnnxExporterWarning,
)