Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
[ONNX] New export logic leveraging ExportedProgram and ONNX IR (#132530)

PR 1/n:
- Move code from torch-onnx at commit 395495e566 into torch.onnx and fix the imports.
- Integrate the new export logic with the torch.onnx.export API and include a basic set of tests.
- Refactor the API for the change.
- Improve documentation.

Subsequent PRs will add more tests and docs.

Fixes https://github.com/pytorch/pytorch/issues/129277
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132530
Approved by: https://github.com/titaiwangms, https://github.com/malfet
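With this change, passing ``dynamo=True`` (or an ``ExportedProgram`` as the model) routes ``torch.onnx.export`` through the new ExportedProgram/ONNX IR path. A minimal sketch of the integrated API, distilled from the new tests in this PR (the model and shapes are illustrative, not taken from the diff):

```python
import torch


class TwoInputModel(torch.nn.Module):  # stand-in model, not from this PR
    def forward(self, x, b):
        return (x + b).relu()


example_args = (torch.randn(1, 1, 2), torch.randn(1, 1, 2))

# New path: dynamo=True selects the ExportedProgram-based exporter.
onnx_program = torch.onnx.export(TwoInputModel(), example_args, dynamo=True)

# Passing an already exported program also selects the new exporter.
exported = torch.export.export(TwoInputModel(), example_args)
onnx_program = torch.onnx.export(exported, example_args)
```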
Committed by: PyTorch MergeBot
Parent: 06cc2e83f0
Commit: e8fc1e0118
@@ -715,5 +715,5 @@ Classes
     :template: classtemplate.rst

     JitScalarType
-    torch.onnx.verification.GraphInfo
-    torch.onnx.verification.VerificationOptions
+    verification.GraphInfo
+    verification.VerificationOptions
mypy.ini (12 lines changed)
@@ -165,9 +165,6 @@ ignore_missing_imports = True
 [mypy-tensorboard.*]
 ignore_missing_imports = True

-[mypy-onnx.*]
-ignore_missing_imports = True
-
 [mypy-matplotlib.*]
 ignore_missing_imports = True

@@ -301,5 +298,14 @@ ignore_missing_imports = True
 # Third party dependencies that are optional.
 #

+[mypy-onnx.*]
+ignore_missing_imports = True
+
+[mypy-onnxruntime.*]
+ignore_missing_imports = True
+
+[mypy-onnxscript.*]
+ignore_missing_imports = True
+
 [mypy-redis]
 ignore_missing_imports = True
@@ -163,222 +163,5 @@ class TestDynamoExportAPI(common_utils.TestCase):
         )


-class TestONNXExportWithDynamo(common_utils.TestCase):
-    def test_args_normalization_with_no_kwargs(self):
-        exported_program = torch.export.export(
-            SampleModelTwoInputs(),
-            (
-                torch.randn(1, 1, 2),
-                torch.randn(1, 1, 2),
-            ),
-        )
-        onnx_program_from_new_exporter = torch.onnx.dynamo_export(
-            exported_program, torch.randn(1, 1, 2), torch.randn(1, 1, 2)
-        )
-        onnx_program_from_old_exporter = torch.onnx.export(
-            SampleModelTwoInputs(),
-            (torch.randn(1, 1, 2), torch.randn(1, 1, 2)),
-            dynamo=True,
-        )
-        self.assertEqual(
-            onnx_program_from_new_exporter.model_proto,
-            onnx_program_from_old_exporter.model_proto,
-        )
-
-    def test_args_is_tensor_not_tuple(self):
-        exported_program = torch.export.export(SampleModel(), (torch.randn(1, 1, 2),))
-        onnx_program_from_new_exporter = torch.onnx.dynamo_export(
-            exported_program, torch.randn(1, 1, 2)
-        )
-        onnx_program_from_old_exporter = torch.onnx.export(
-            SampleModel(), torch.randn(1, 1, 2), dynamo=True
-        )
-        self.assertEqual(
-            onnx_program_from_new_exporter.model_proto,
-            onnx_program_from_old_exporter.model_proto,
-        )
-
-    def test_args_normalization_with_kwargs(self):
-        exported_program = torch.export.export(
-            SampleModelTwoInputs(), (torch.randn(1, 1, 2),), {"b": torch.randn(1, 1, 2)}
-        )
-        onnx_program_from_new_exporter = torch.onnx.dynamo_export(
-            exported_program, torch.randn(1, 1, 2), b=torch.randn(1, 1, 2)
-        )
-        onnx_program_from_old_exporter = torch.onnx.export(
-            SampleModelTwoInputs(),
-            (torch.randn(1, 1, 2), {"b": torch.randn(1, 1, 2)}),
-            dynamo=True,
-        )
-        self.assertEqual(
-            onnx_program_from_new_exporter.model_proto,
-            onnx_program_from_old_exporter.model_proto,
-        )
-
-    def test_args_normalization_with_empty_dict_at_the_tail(self):
-        exported_program = torch.export.export(
-            SampleModelTwoInputs(), (torch.randn(1, 1, 2),), {"b": torch.randn(1, 1, 2)}
-        )
-        onnx_program_from_new_exporter = torch.onnx.dynamo_export(
-            exported_program, torch.randn(1, 1, 2), b=torch.randn(1, 1, 2)
-        )
-        onnx_program_from_old_exporter = torch.onnx.export(
-            SampleModelTwoInputs(),
-            (torch.randn(1, 1, 2), {"b": torch.randn(1, 1, 2)}),
-            dynamo=True,
-        )
-        self.assertEqual(
-            onnx_program_from_new_exporter.model_proto,
-            onnx_program_from_old_exporter.model_proto,
-        )
-
-    def test_dynamic_axes_enable_dynamic_shapes_with_fully_specified_axes(self):
-        exported_program = torch.export.export(
-            SampleModelForDynamicShapes(),
-            (
-                torch.randn(2, 2, 3),
-                torch.randn(2, 2, 3),
-            ),
-            dynamic_shapes={
-                "x": {
-                    0: torch.export.Dim("customx_dim_0"),
-                    1: torch.export.Dim("customx_dim_1"),
-                    2: torch.export.Dim("customx_dim_2"),
-                },
-                "b": {
-                    0: torch.export.Dim("customb_dim_0"),
-                    1: torch.export.Dim("customb_dim_1"),
-                    2: torch.export.Dim("customb_dim_2"),
-                },
-            },
-        )
-        onnx_program_from_new_exporter = torch.onnx.dynamo_export(
-            exported_program,
-            torch.randn(2, 2, 3),
-            b=torch.randn(2, 2, 3),
-        )
-        onnx_program_from_old_exporter = torch.onnx.export(
-            SampleModelForDynamicShapes(),
-            (torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}),
-            dynamic_axes={
-                "x": {0: "customx_dim_0", 1: "customx_dim_1", 2: "customx_dim_2"},
-                "b": {0: "customb_dim_0", 1: "customb_dim_1", 2: "customb_dim_2"},
-            },
-            dynamo=True,
-        )
-        self.assertEqual(
-            onnx_program_from_new_exporter.model_proto,
-            onnx_program_from_old_exporter.model_proto,
-        )
-
-    def test_dynamic_axes_enable_dynamic_shapes_with_default_axe_names(self):
-        exported_program = torch.export.export(
-            SampleModelForDynamicShapes(),
-            (
-                torch.randn(2, 2, 3),
-                torch.randn(2, 2, 3),
-            ),
-            dynamic_shapes={
-                "x": {
-                    0: torch.export.Dim("customx_dim_0"),
-                    1: torch.export.Dim("customx_dim_1"),
-                    2: torch.export.Dim("customx_dim_2"),
-                },
-                "b": {
-                    0: torch.export.Dim("customb_dim_0"),
-                    1: torch.export.Dim("customb_dim_1"),
-                    2: torch.export.Dim("customb_dim_2"),
-                },
-            },
-        )
-        onnx_program_from_new_exporter = torch.onnx.dynamo_export(
-            exported_program,
-            torch.randn(2, 2, 3),
-            b=torch.randn(2, 2, 3),
-        )
-        onnx_program_from_old_exporter = torch.onnx.export(
-            SampleModelForDynamicShapes(),
-            (torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}),
-            dynamic_axes={
-                "x": [0, 1, 2],
-                "b": [0, 1, 2],
-            },
-            dynamo=True,
-        )
-        self.assertEqual(
-            onnx_program_from_new_exporter.model_proto,
-            onnx_program_from_old_exporter.model_proto,
-        )
-
-    def test_dynamic_axes_supports_partial_dynamic_shapes(self):
-        exported_program = torch.export.export(
-            SampleModelForDynamicShapes(),
-            (
-                torch.randn(2, 2, 3),
-                torch.randn(2, 2, 3),
-            ),
-            dynamic_shapes={
-                "x": None,
-                "b": {
-                    0: torch.export.Dim("customb_dim_0"),
-                    1: torch.export.Dim("customb_dim_1"),
-                    2: torch.export.Dim("customb_dim_2"),
-                },
-            },
-        )
-        onnx_program_from_new_exporter = torch.onnx.dynamo_export(
-            exported_program,
-            torch.randn(2, 2, 3),
-            b=torch.randn(2, 2, 3),
-        )
-        onnx_program_from_old_exporter = torch.onnx.export(
-            SampleModelForDynamicShapes(),
-            (torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}),
-            dynamic_axes={
-                "b": [0, 1, 2],
-            },
-            dynamo=True,
-        )
-        self.assertEqual(
-            onnx_program_from_new_exporter.model_proto,
-            onnx_program_from_old_exporter.model_proto,
-        )
-
-    def test_dynamic_shapes_hit_constraints_in_dynamo(self):
-        # SampleModelTwoInputs has constraints because of the add of two inputs,
-        # so the two input shapes are related.
-        with self.assertRaisesRegex(
-            torch._dynamo.exc.UserError,
-            "Constraints violated",
-        ):
-            _ = torch.onnx.export(
-                SampleModelTwoInputs(),
-                (torch.randn(2, 2, 3), torch.randn(2, 2, 3)),
-                dynamic_axes={
-                    "x": {0: "x_dim_0", 1: "x_dim_1", 2: "x_dim_2"},
-                    "b": {0: "b_dim_0", 1: "b_dim_1", 2: "b_dim_2"},
-                },
-                dynamo=True,
-            )
-
-    def test_saved_f_exists_after_export(self):
-        with common_utils.TemporaryFileName(suffix=".onnx") as path:
-            _ = torch.onnx.export(
-                SampleModel(), torch.randn(1, 1, 2), path, dynamo=True
-            )
-            self.assertTrue(os.path.exists(path))
-
-    def test_raises_error_when_input_is_script_module(self):
-        class ScriptModule(torch.jit.ScriptModule):
-            def forward(self, x):
-                return x
-
-        with self.assertRaisesRegex(
-            TypeError,
-            "Dynamo export does not support ScriptModule or ScriptFunction.",
-        ):
-            _ = torch.onnx.export(ScriptModule(), torch.randn(1, 1, 2), dynamo=True)
-
-
 if __name__ == "__main__":
     common_utils.run_tests()
test/onnx/exporter/README.md (new file, 1 line)
@@ -0,0 +1 @@
Directory for all ExportedProgram exporter logic.
test/onnx/exporter/test_api.py (new file, 120 lines)
@@ -0,0 +1,120 @@
# Owner(s): ["module: onnx"]
"""Simple API tests for the ONNX exporter."""

from __future__ import annotations

import os

import torch
from torch.onnx._internal import exporter
from torch.testing._internal import common_utils


class SampleModel(torch.nn.Module):
    def forward(self, x):
        y = x + 1
        z = y.relu()
        return (y, z)


class SampleModelTwoInputs(torch.nn.Module):
    def forward(self, x, b):
        y = x + b
        z = y.relu()
        return (y, z)


class SampleModelForDynamicShapes(torch.nn.Module):
    def forward(self, x, b):
        return x.relu(), b.sigmoid()


class TestExportAPIDynamo(common_utils.TestCase):
    """Tests for the ONNX exporter API when dynamo=True."""

    def test_args_normalization_with_no_kwargs(self):
        onnx_program = torch.onnx.export(
            SampleModelTwoInputs(),
            (torch.randn(1, 1, 2), torch.randn(1, 1, 2)),
            dynamo=True,
        )
        assert onnx_program
        exporter.verify_onnx_program(onnx_program)

    def test_args_normalization_with_kwargs(self):
        onnx_program = torch.onnx.export(
            SampleModelTwoInputs(),
            (torch.randn(1, 1, 2), {"b": torch.randn(1, 1, 2)}),
            dynamo=True,
        )
        assert onnx_program
        exporter.verify_onnx_program(onnx_program)

    def test_args_normalization_with_empty_dict_at_the_tail(self):
        onnx_program = torch.onnx.export(
            SampleModelTwoInputs(),
            (torch.randn(1, 1, 2), {"b": torch.randn(1, 1, 2)}),
            dynamo=True,
        )
        assert onnx_program
        exporter.verify_onnx_program(onnx_program)

    def test_dynamic_axes_enable_dynamic_shapes_with_fully_specified_axes(self):
        onnx_program = torch.onnx.export(
            SampleModelForDynamicShapes(),
            (torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}),
            dynamic_axes={
                "x": {0: "customx_dim_0", 1: "customx_dim_1", 2: "customx_dim_2"},
                "b": {0: "customb_dim_0", 1: "customb_dim_1", 2: "customb_dim_2"},
            },
            dynamo=True,
        )
        assert onnx_program
        exporter.verify_onnx_program(onnx_program)

    def test_dynamic_axes_enable_dynamic_shapes_with_default_axe_names(self):
        onnx_program = torch.onnx.export(
            SampleModelForDynamicShapes(),
            (torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}),
            dynamic_axes={
                "x": [0, 1, 2],
                "b": [0, 1, 2],
            },
            dynamo=True,
        )
        assert onnx_program
        exporter.verify_onnx_program(onnx_program)

    def test_dynamic_axes_supports_partial_dynamic_shapes(self):
        onnx_program = torch.onnx.export(
            SampleModelForDynamicShapes(),
            (torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}),
            dynamic_axes={
                "b": [0, 1, 2],
            },
            dynamo=True,
        )
        assert onnx_program
        exporter.verify_onnx_program(onnx_program)

    def test_saved_f_exists_after_export(self):
        with common_utils.TemporaryFileName(suffix=".onnx") as path:
            _ = torch.onnx.export(
                SampleModel(), (torch.randn(1, 1, 2),), path, dynamo=True
            )
            self.assertTrue(os.path.exists(path))

    def test_export_supports_script_module(self):
        class ScriptModule(torch.nn.Module):
            def forward(self, x):
                return x

        onnx_program = torch.onnx.export(
            torch.jit.script(ScriptModule()), (torch.randn(1, 1, 2),), dynamo=True
        )
        assert onnx_program
        exporter.verify_onnx_program(onnx_program)


if __name__ == "__main__":
    common_utils.run_tests()
@@ -286,6 +286,25 @@ class TestPublicBindings(TestCase):
         # do not get imported by public code.
         private_allowlist = {
             "torch._inductor.codegen.cuda.cuda_kernel",
+            # TODO(#133647): Remove the onnx._internal entries after
+            # onnx and onnxscript are installed in CI.
+            "torch.onnx._internal.exporter",
+            "torch.onnx._internal.exporter._analysis",
+            "torch.onnx._internal.exporter._building",
+            "torch.onnx._internal.exporter._capture_strategies",
+            "torch.onnx._internal.exporter._compat",
+            "torch.onnx._internal.exporter._core",
+            "torch.onnx._internal.exporter._decomp",
+            "torch.onnx._internal.exporter._dispatching",
+            "torch.onnx._internal.exporter._fx_passes",
+            "torch.onnx._internal.exporter._ir_passes",
+            "torch.onnx._internal.exporter._isolated",
+            "torch.onnx._internal.exporter._onnx_program",
+            "torch.onnx._internal.exporter._registration",
+            "torch.onnx._internal.exporter._reporting",
+            "torch.onnx._internal.exporter._schemas",
+            "torch.onnx._internal.exporter._tensors",
+            "torch.onnx._internal.exporter._verification",
             "torch.onnx._internal.fx._pass",
             "torch.onnx._internal.fx.analysis",
             "torch.onnx._internal.fx.analysis.unsupported_nodes",
@ -1,65 +1,5 @@
|
||||
# mypy: allow-untyped-defs
|
||||
from torch import _C
|
||||
from torch._C import _onnx as _C_onnx
|
||||
from torch._C._onnx import OperatorExportTypes, TensorProtoDataType, TrainingMode
|
||||
|
||||
from ._exporter_states import ExportTypes
|
||||
from ._internal.onnxruntime import (
|
||||
is_onnxrt_backend_supported,
|
||||
OrtBackend as _OrtBackend,
|
||||
OrtBackendOptions as _OrtBackendOptions,
|
||||
OrtExecutionProvider as _OrtExecutionProvider,
|
||||
)
|
||||
from ._type_utils import JitScalarType
|
||||
from .errors import CheckerError # Backwards compatibility
|
||||
from .utils import (
|
||||
_optimize_graph,
|
||||
_run_symbolic_function,
|
||||
_run_symbolic_method,
|
||||
export,
|
||||
export_to_pretty_string,
|
||||
is_in_onnx_export,
|
||||
register_custom_op_symbolic,
|
||||
select_model_mode_for_export,
|
||||
unregister_custom_op_symbolic,
|
||||
)
|
||||
|
||||
|
||||
from . import ( # usort: skip. Keep the order instead of sorting lexicographically
|
||||
_deprecation,
|
||||
errors,
|
||||
symbolic_caffe2,
|
||||
symbolic_helper,
|
||||
symbolic_opset7,
|
||||
symbolic_opset8,
|
||||
symbolic_opset9,
|
||||
symbolic_opset10,
|
||||
symbolic_opset11,
|
||||
symbolic_opset12,
|
||||
symbolic_opset13,
|
||||
symbolic_opset14,
|
||||
symbolic_opset15,
|
||||
symbolic_opset16,
|
||||
symbolic_opset17,
|
||||
symbolic_opset18,
|
||||
symbolic_opset19,
|
||||
symbolic_opset20,
|
||||
utils,
|
||||
)
|
||||
|
||||
|
||||
from ._internal._exporter_legacy import ( # usort: skip. needs to be last to avoid circular import
|
||||
DiagnosticOptions,
|
||||
ExportOptions,
|
||||
ONNXProgram,
|
||||
ONNXProgramSerializer,
|
||||
ONNXRuntimeOptions,
|
||||
InvalidExportOptionsError,
|
||||
OnnxExporterError,
|
||||
OnnxRegistry,
|
||||
dynamo_export,
|
||||
enable_fake_mode,
|
||||
)
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
__all__ = [
|
||||
@ -115,6 +55,74 @@ __all__ = [
|
||||
"is_onnxrt_backend_supported",
|
||||
]
|
||||
|
||||
from typing import Any, Collection, Mapping, Sequence, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
from torch import _C
|
||||
from torch._C import _onnx as _C_onnx
|
||||
from torch._C._onnx import OperatorExportTypes, TensorProtoDataType, TrainingMode
|
||||
|
||||
from ._exporter_states import ExportTypes
|
||||
from ._internal.onnxruntime import (
|
||||
is_onnxrt_backend_supported,
|
||||
OrtBackend as _OrtBackend,
|
||||
OrtBackendOptions as _OrtBackendOptions,
|
||||
OrtExecutionProvider as _OrtExecutionProvider,
|
||||
)
|
||||
from ._type_utils import JitScalarType
|
||||
from .errors import CheckerError # Backwards compatibility
|
||||
from .utils import (
|
||||
_optimize_graph,
|
||||
_run_symbolic_function,
|
||||
_run_symbolic_method,
|
||||
export_to_pretty_string,
|
||||
is_in_onnx_export,
|
||||
register_custom_op_symbolic,
|
||||
select_model_mode_for_export,
|
||||
unregister_custom_op_symbolic,
|
||||
)
|
||||
|
||||
|
||||
from . import ( # usort: skip. Keep the order instead of sorting lexicographically
|
||||
_deprecation,
|
||||
errors,
|
||||
symbolic_caffe2,
|
||||
symbolic_helper,
|
||||
symbolic_opset7,
|
||||
symbolic_opset8,
|
||||
symbolic_opset9,
|
||||
symbolic_opset10,
|
||||
symbolic_opset11,
|
||||
symbolic_opset12,
|
||||
symbolic_opset13,
|
||||
symbolic_opset14,
|
||||
symbolic_opset15,
|
||||
symbolic_opset16,
|
||||
symbolic_opset17,
|
||||
symbolic_opset18,
|
||||
symbolic_opset19,
|
||||
symbolic_opset20,
|
||||
utils,
|
||||
)
|
||||
|
||||
|
||||
from ._internal._exporter_legacy import ( # usort: skip. needs to be last to avoid circular import
|
||||
DiagnosticOptions,
|
||||
ExportOptions,
|
||||
ONNXProgram,
|
||||
ONNXProgramSerializer,
|
||||
ONNXRuntimeOptions,
|
||||
InvalidExportOptionsError,
|
||||
OnnxExporterError,
|
||||
OnnxRegistry,
|
||||
dynamo_export,
|
||||
enable_fake_mode,
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import os
|
||||
|
||||
# Set namespace for exposed private names
|
||||
ExportTypes.__module__ = "torch.onnx"
|
||||
JitScalarType.__module__ = "torch.onnx"
|
||||
@ -137,6 +145,257 @@ producer_name = "pytorch"
|
||||
producer_version = _C_onnx.PRODUCER_VERSION
|
||||
|
||||
|
||||
def export(
|
||||
model: torch.nn.Module
|
||||
| torch.export.ExportedProgram
|
||||
| torch.jit.ScriptModule
|
||||
| torch.jit.ScriptFunction,
|
||||
args: tuple[Any, ...],
|
||||
f: str | os.PathLike | None = None,
|
||||
*,
|
||||
kwargs: dict[str, Any] | None = None,
|
||||
export_params: bool = True,
|
||||
verbose: bool | None = None,
|
||||
input_names: Sequence[str] | None = None,
|
||||
output_names: Sequence[str] | None = None,
|
||||
opset_version: int | None = None,
|
||||
dynamic_axes: Mapping[str, Mapping[int, str]]
|
||||
| Mapping[str, Sequence[int]]
|
||||
| None = None,
|
||||
keep_initializers_as_inputs: bool = False,
|
||||
dynamo: bool = False,
|
||||
# Dynamo only options
|
||||
external_data: bool = True,
|
||||
dynamic_shapes: dict[str, Any] | tuple[Any, ...] | list[Any] | None = None,
|
||||
report: bool = False,
|
||||
verify: bool = False,
|
||||
profile: bool = False,
|
||||
dump_exported_program: bool = False,
|
||||
artifacts_dir: str | os.PathLike = ".",
|
||||
fallback: bool = False,
|
||||
# Deprecated options
|
||||
training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL,
|
||||
operator_export_type: _C_onnx.OperatorExportTypes = _C_onnx.OperatorExportTypes.ONNX,
|
||||
do_constant_folding: bool = True,
|
||||
custom_opsets: Mapping[str, int] | None = None,
|
||||
export_modules_as_functions: bool | Collection[type[torch.nn.Module]] = False,
|
||||
autograd_inlining: bool = True,
|
||||
**_: Any, # ignored options
|
||||
) -> Any | None:
|
||||
r"""Exports a model into ONNX format.
|
||||
|
||||
Args:
|
||||
model: The model to be exported.
|
||||
args: Example positional inputs. Any non-Tensor arguments will be hard-coded into the
|
||||
exported model; any Tensor arguments will become inputs of the exported model,
|
||||
in the order they occur in the tuple.
|
||||
f: Path to the output ONNX model file. E.g. "model.onnx".
|
||||
kwargs: Optional example keyword inputs.
|
||||
export_params: If false, parameters (weights) will not be exported.
|
||||
verbose: Whether to enable verbose logging.
|
||||
input_names: names to assign to the input nodes of the graph, in order.
|
||||
output_names: names to assign to the output nodes of the graph, in order.
|
||||
opset_version: The version of the
|
||||
`default (ai.onnx) opset <https://github.com/onnx/onnx/blob/master/docs/Operators.md>`_
|
||||
to target. Must be >= 7.
|
||||
dynamic_axes:
|
||||
|
||||
By default the exported model will have the shapes of all input and output tensors
|
||||
set to exactly match those given in ``args``. To specify axes of tensors as
|
||||
dynamic (i.e. known only at run-time), set ``dynamic_axes`` to a dict with schema:
|
||||
|
||||
* KEY (str): an input or output name. Each name must also be provided in ``input_names`` or
|
||||
``output_names``.
|
||||
* VALUE (dict or list): If a dict, keys are axis indices and values are axis names. If a
|
||||
list, each element is an axis index.
|
||||
|
||||
For example::
|
||||
|
||||
class SumModule(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return torch.sum(x, dim=1)
|
||||
|
||||
torch.onnx.export(
|
||||
SumModule(),
|
||||
(torch.ones(2, 2),),
|
||||
"onnx.pb",
|
||||
input_names=["x"],
|
||||
output_names=["sum"]
|
||||
)
|
||||
|
||||
Produces::
|
||||
|
||||
input {
|
||||
name: "x"
|
||||
...
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2 # axis 0
|
||||
}
|
||||
dim {
|
||||
dim_value: 2 # axis 1
|
||||
...
|
||||
output {
|
||||
name: "sum"
|
||||
...
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2 # axis 0
|
||||
...
|
||||
|
||||
While::
|
||||
|
||||
torch.onnx.export(
|
||||
SumModule(),
|
||||
(torch.ones(2, 2),),
|
||||
"onnx.pb",
|
||||
input_names=["x"],
|
||||
output_names=["sum"],
|
||||
dynamic_axes={
|
||||
# dict value: manually named axes
|
||||
"x": {0: "my_custom_axis_name"},
|
||||
# list value: automatic names
|
||||
"sum": [0],
|
||||
}
|
||||
)
|
||||
|
||||
Produces::
|
||||
|
||||
input {
|
||||
name: "x"
|
||||
...
|
||||
shape {
|
||||
dim {
|
||||
dim_param: "my_custom_axis_name" # axis 0
|
||||
}
|
||||
dim {
|
||||
dim_value: 2 # axis 1
|
||||
...
|
||||
output {
|
||||
name: "sum"
|
||||
...
|
||||
shape {
|
||||
dim {
|
||||
dim_param: "sum_dynamic_axes_1" # axis 0
|
||||
...
|
||||
|
||||
keep_initializers_as_inputs: If True, all the
|
||||
initializers (typically corresponding to model weights) in the
|
||||
exported graph will also be added as inputs to the graph. If False,
|
||||
then initializers are not added as inputs to the graph, and only
|
||||
the user inputs are added as inputs.
|
||||
|
||||
Set this to True if you intend to supply model weights at runtime.
|
||||
Set it to False if the weights are static to allow for better optimizations
|
||||
(e.g. constant folding) by backends/runtimes.
|
||||
|
||||
dynamo: Whether to export the model with ``torch.export`` ExportedProgram instead of TorchScript.
|
||||
external_data: Whether to save the model weights as an external data file.
|
||||
This is required for models with large weights that exceed the ONNX file size limit (2GB).
|
||||
When False, the weights are saved in the ONNX file with the model architecture.
|
||||
dynamic_shapes: A dictionary of dynamic shapes for the model inputs. Refer to
|
||||
:func:`torch.export.export` for more details.
|
||||
report: Whether to generate a markdown report for the export process.
|
||||
verify: Whether to verify the exported model using ONNX Runtime.
|
||||
profile: Whether to profile the export process.
|
||||
dump_exported_program: Whether to dump the :class:`torch.export.ExportedProgram` to a file.
|
||||
This is useful for debugging the exporter.
|
||||
artifacts_dir: The directory to save the debugging artifacts like the report and the serialized
|
||||
exported program.
|
||||
fallback: Whether to fallback to the TorchScript exporter if the dynamo exporter fails.
|
||||
|
||||
training: Deprecated option. Instead, set the training mode of the model before exporting.
|
||||
operator_export_type: Deprecated option. Only ONNX is supported.
|
||||
do_constant_folding: Deprecated option. The exported graph is always optimized.
|
||||
custom_opsets: Deprecated.
|
||||
A dictionary:
|
||||
|
||||
* KEY (str): opset domain name
|
||||
* VALUE (int): opset version
|
||||
|
||||
If a custom opset is referenced by ``model`` but not mentioned in this dictionary,
|
||||
the opset version is set to 1. Only custom opset domain name and version should be
|
||||
indicated through this argument.
|
||||
export_modules_as_functions: Deprecated option.
|
||||
|
||||
Flag to enable
|
||||
exporting all ``nn.Module`` forward calls as local functions in ONNX. Or a set to indicate the
|
||||
particular types of modules to export as local functions in ONNX.
|
||||
This feature requires ``opset_version`` >= 15, otherwise the export will fail. This is because
|
||||
``opset_version`` < 15 implies IR version < 8, which means no local function support.
|
||||
Module variables will be exported as function attributes. There are two categories of function
|
||||
attributes.
|
||||
|
||||
1. Annotated attributes: class variables that have type annotations via
|
||||
`PEP 526-style <https://www.python.org/dev/peps/pep-0526/#class-and-instance-variable-annotations>`_
|
||||
will be exported as attributes.
|
||||
Annotated attributes are not used inside the subgraph of ONNX local function because
|
||||
they are not created by PyTorch JIT tracing, but they may be used by consumers
|
||||
to determine whether or not to replace the function with a particular fused kernel.
|
||||
|
||||
2. Inferred attributes: variables that are used by operators inside the module. Attribute names
|
||||
will have prefix "inferred::". This is to differentiate from predefined attributes retrieved from
|
||||
python module annotations. Inferred attributes are used inside the subgraph of ONNX local function.
|
||||
|
||||
* ``False`` (default): export ``nn.Module`` forward calls as fine grained nodes.
|
||||
* ``True``: export all ``nn.Module`` forward calls as local function nodes.
|
||||
* Set of type of nn.Module: export ``nn.Module`` forward calls as local function nodes,
|
||||
only if the type of the ``nn.Module`` is found in the set.
|
||||
autograd_inlining: Deprecated.
|
||||
Flag used to control whether to inline autograd functions.
|
||||
Refer to https://github.com/pytorch/pytorch/pull/74765 for more details.
|
||||
"""
|
||||
if dynamo is True or isinstance(model, torch.export.ExportedProgram):
|
||||
from torch.onnx._internal import exporter
|
||||
|
||||
if isinstance(args, torch.Tensor):
|
||||
args = (args,)
|
||||
return exporter.export_compat(
|
||||
model,
|
||||
args,
|
||||
f,
|
||||
kwargs=kwargs,
|
||||
export_params=export_params,
|
||||
verbose=verbose,
|
||||
input_names=input_names,
|
||||
output_names=output_names,
|
||||
opset_version=opset_version,
|
||||
dynamic_axes=dynamic_axes,
|
||||
keep_initializers_as_inputs=keep_initializers_as_inputs,
|
||||
external_data=external_data,
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
report=report,
|
||||
verify=verify,
|
||||
profile=profile,
|
||||
dump_exported_program=dump_exported_program,
|
||||
artifacts_dir=artifacts_dir,
|
||||
fallback=fallback,
|
||||
)
|
||||
else:
|
||||
from torch.onnx.utils import export
|
||||
|
||||
export(
|
||||
model,
|
||||
args,
|
||||
f, # type: ignore[arg-type]
|
||||
kwargs=kwargs,
|
||||
export_params=export_params,
|
||||
verbose=verbose is True,
|
||||
input_names=input_names,
|
||||
output_names=output_names,
|
||||
opset_version=opset_version,
|
||||
dynamic_axes=dynamic_axes,
|
||||
keep_initializers_as_inputs=keep_initializers_as_inputs,
|
||||
training=training,
|
||||
operator_export_type=operator_export_type,
|
||||
do_constant_folding=do_constant_folding,
|
||||
custom_opsets=custom_opsets,
|
||||
export_modules_as_functions=export_modules_as_functions,
|
||||
autograd_inlining=autograd_inlining,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
# TODO(justinchuby): Deprecate these logging functions in favor of the new diagnostic module.
|
||||
|
||||
# Returns True iff ONNX logging is turned on.
|
||||
|
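Under ``dynamo=True``, the intent of the ``dynamic_axes`` example in the docstring above can also be expressed with ``dynamic_shapes``, which is forwarded to :func:`torch.export.export`. A sketch based on the signature above (shapes and names are illustrative):

```python
import torch


class SumModule(torch.nn.Module):
    def forward(self, x):
        return torch.sum(x, dim=1)


batch = torch.export.Dim("batch")  # symbolic dimension for axis 0
torch.onnx.export(
    SumModule(),
    (torch.ones(2, 2),),
    "onnx.pb",
    dynamo=True,
    dynamic_shapes={"x": {0: batch}},  # axis 0 of input "x" is dynamic
)
```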
torch/onnx/_internal/_lazy_import.py (new file, 38 lines)
@@ -0,0 +1,38 @@
"""Utility to lazily import modules."""
# mypy: allow-untyped-defs
from __future__ import annotations

import importlib
from typing import Any, TYPE_CHECKING


class _LazyModule:
    """Lazily import a module."""

    def __init__(self, module_name: str) -> None:
        self._name = module_name
        self._module: Any = None

    def __repr__(self) -> str:
        return f"<lazy module '{self._name}'>"

    def __getattr__(self, attr):
        if self._module is None:
            self._module = importlib.import_module(".", self._name)
        return getattr(self._module, attr)


# Import the following modules during type checking to enable code intelligence features,
# such as auto-completion in tools like pylance, even when these modules are not explicitly
# imported in user code.
# NOTE: Add additional used imports here.
if TYPE_CHECKING:
    import onnx
    import onnxscript

    onnxscript_ir = onnxscript.ir

else:
    onnx = _LazyModule("onnx")
    onnxscript = _LazyModule("onnxscript")
    onnxscript_ir = _LazyModule("onnxscript.ir")
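The effect of the wrapper above, as a hypothetical standalone snippet (``json`` is just a convenient stand-in module): constructing ``_LazyModule`` performs no import, the first attribute access triggers ``importlib``, and every later access delegates to the real module.

```python
lazy_json = _LazyModule("json")  # nothing is imported yet

print(repr(lazy_json))  # <lazy module 'json'>
print(lazy_json.dumps({"a": 1}))  # first attribute access imports json, then delegates
```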
torch/onnx/_internal/exporter/__init__.py (new file, 16 lines)
@@ -0,0 +1,16 @@
__all__ = [
    "ONNXRegistry",
    "ONNXProgram",
    "analyze",
    "export",
    "exported_program_to_ir",
    "verify_onnx_program",
    "export_compat",
]

from ._analysis import analyze
from ._compat import export_compat
from ._core import export, exported_program_to_ir
from ._onnx_program import ONNXProgram
from ._registration import ONNXRegistry
from ._verification import verify_onnx_program
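The new tests in ``test/onnx/exporter/test_api.py`` exercise this package roughly as follows (a sketch; ``SampleModel`` mirrors the helper defined there, and this is internal API that may change):

```python
import torch
from torch.onnx._internal import exporter


class SampleModel(torch.nn.Module):  # mirrors the helper in test_api.py
    def forward(self, x):
        y = x + 1
        return (y, y.relu())


onnx_program = torch.onnx.export(SampleModel(), (torch.randn(1, 1, 2),), dynamo=True)
assert onnx_program
exporter.verify_onnx_program(onnx_program)  # numerically checks the exported model
```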
torch/onnx/_internal/exporter/_analysis.py (new file, 250 lines)
@@ -0,0 +1,250 @@
"""Compatibility analyzer for PyTorch models."""

# mypy: allow-untyped-defs
# flake8: noqa: B950 We do not need flake8 as it complains about line length
from __future__ import annotations

import dataclasses
import textwrap
import traceback
from collections import defaultdict
from typing import TYPE_CHECKING

import onnxscript

import torch
import torch._export.serde.schema
from torch.export import graph_signature
from torch.onnx._internal.exporter import _dispatching, _registration


if TYPE_CHECKING:
    import torch.fx


@dataclasses.dataclass
class ModelInfo:
    """Information about the model."""

    parameter_count: defaultdict[torch.dtype, int] = dataclasses.field(
        default_factory=lambda: defaultdict(int)
    )
    buffer_count: defaultdict[torch.dtype, int] = dataclasses.field(
        default_factory=lambda: defaultdict(int)
    )
    fx_node_count: int = 0
    fx_node_op_count: defaultdict[str, int] = dataclasses.field(
        default_factory=lambda: defaultdict(int)
    )
    fx_node_target_count: defaultdict[str, int] = dataclasses.field(
        default_factory=lambda: defaultdict(int)
    )
    dispatch_failures: list[tuple[torch.fx.Node, str]] = dataclasses.field(
        default_factory=list
    )
    inputs: dict[str, torch._export.serde.schema.TensorMeta] = dataclasses.field(
        default_factory=dict
    )
    outputs: dict[str, torch._export.serde.schema.TensorMeta] = dataclasses.field(
        default_factory=dict
    )


def _count_weights(
    exported_program: torch.export.ExportedProgram,
) -> tuple[defaultdict[torch.dtype, int], defaultdict[torch.dtype, int]]:
    """Count the size of the parameters in the exported program."""

    parameter_count: defaultdict[torch.dtype, int] = defaultdict(int)
    buffer_count: defaultdict[torch.dtype, int] = defaultdict(int)
    for parameter in exported_program.parameters():
        dtype = parameter.dtype
        parameter_count[dtype] += parameter.numel()

    for buffer in exported_program.buffers():
        dtype = buffer.dtype
        buffer_count[dtype] += buffer.numel()

    return parameter_count, buffer_count


def _format_model_info(model_info: ModelInfo) -> str:
    """Format the information about the model."""
    lines = [
        textwrap.dedent(
            f"""\
            PyTorch ONNX Conversion Analysis

            ## Model Information

            The model has {sum(model_info.parameter_count.values())} parameters and {sum(model_info.buffer_count.values())} buffers (non-trainable parameters).
            Number of parameters per dtype:
            ```python
            {model_info.parameter_count}
            ```
            Number of buffers per dtype:
            ```python
            {model_info.buffer_count}
            ```
            """
        ),
        "Inputs:",
        *[f"- `{name}`: `{meta}`" for name, meta in model_info.inputs.items()],
        "",
        "Outputs:",
        *[f"- `{name}`: `{meta}`" for name, meta in model_info.outputs.items()],
        "",
        f"The FX graph has {model_info.fx_node_count} nodes in total. Number of FX nodes per op:",
    ]
    for op, count in model_info.fx_node_op_count.items():
        lines.append(f"- `{op}`: {count}")
    lines.append("\n")
    lines.append("Of the call_function nodes, the counts of operators used are:\n")
    sorted_targets = sorted(
        model_info.fx_node_target_count.items(), key=lambda x: x[1], reverse=True
    )
    for target, count in sorted_targets:
        lines.append(f"- `{target}`: {count}")

    lines.append("")
    lines.append("## ONNX Conversion Information")
    lines.append("")

    if model_info.dispatch_failures:
        lines.append(
            "The model contains operators the dispatcher could not find registered ONNX decompositions for. "
            "This may be due to missing implementations, decompositions not registered "
            "correctly, or a bug in the dispatcher."
        )
        lines.append("")
        lines.append("Errors grouped by operator:\n")

        target_to_nodes = defaultdict(list)
        for node, _ in model_info.dispatch_failures:
            target_to_nodes[str(node.target)].append(node)

        target_to_messages = {}
        for node, message in model_info.dispatch_failures:
            if str(node.target) not in target_to_messages:
                target_to_messages[str(node.target)] = message

        for target, nodes in sorted(
            target_to_nodes.items(), key=lambda x: x[0], reverse=True
        ):
            message = textwrap.indent(
                f"{target_to_messages[target]}. Example node: `{nodes[0].format_node()}`. All nodes: `{nodes}`",
                "    ",
            )
            lines.append(f"- `{target}`: {message}")
    else:
        lines.append("All operators in the model have registered ONNX decompositions.")

    return "\n".join(lines)


def _get_io_specs(exported_program: torch.export.ExportedProgram) -> tuple[dict, dict]:
    """Get the input and output specs of the exported program."""

    nodes: dict[str, torch.fx.Node] = {
        node.name: node for node in exported_program.graph.nodes
    }
    user_inputs = [
        spec
        for spec in exported_program.graph_signature.input_specs
        if spec.kind == graph_signature.InputKind.USER_INPUT
    ]
    user_outputs = [
        spec
        for spec in exported_program.graph_signature.output_specs
        if spec.kind == graph_signature.OutputKind.USER_OUTPUT
    ]
    inputs: dict[str, torch._export.serde.schema.TensorMeta] = {}
    outputs: dict[str, torch._export.serde.schema.TensorMeta] = {}
    for spec in user_inputs:
        if isinstance(spec.arg, graph_signature.ConstantArgument):
            continue
        name = spec.arg.name
        # FIXME: tensor_meta is None sometimes when the exported program still knows the shape/type
        inputs[name] = nodes[name].meta["tensor_meta"]
    for spec in user_outputs:
        if isinstance(spec.arg, graph_signature.ConstantArgument):
            continue
        name = spec.arg.name
        outputs[name] = nodes[name].meta["tensor_meta"]
    return inputs, outputs


def _count_fx_targets(
    exported_program: torch.export.ExportedProgram,
) -> defaultdict[str, int]:
    """Count the number of targets for each node in the exported program."""
    fx_node_target_count: defaultdict[str, int] = defaultdict(int)
    for node in exported_program.graph.nodes:
        if node.op == "call_function":
            fx_node_target_count[str(node.target)] += 1
    return fx_node_target_count


def analyze(
    exported_program: torch.export.ExportedProgram,
    registry: _registration.ONNXRegistry | None = None,
    file=None,
) -> None:
    """Analyze the compatibility of the exported program."""
    # Get basic information about the model
    model_info = ModelInfo()
    model_info.parameter_count, model_info.buffer_count = _count_weights(
        exported_program
    )
    model_info.fx_node_count = len(exported_program.graph.nodes)
    model_info.fx_node_target_count = _count_fx_targets(exported_program)
    inputs, outputs = _get_io_specs(exported_program)
    model_info.inputs = inputs
    model_info.outputs = outputs

    if registry is None:
        # Trigger op registration
        from onnxscript.function_libs.torch_lib import ops  # noqa: F401

        del ops
        registry = _registration.ONNXRegistry.from_torchlib(
            onnxscript.function_libs.torch_lib.registration.default_registry  # type: ignore[arg-type]
        )

    # Try to find ops for every node in the graph
    for node in exported_program.graph.nodes:
        model_info.fx_node_op_count[node.op] += 1
        if node.op == "call_function":
            try:
                onnx_function, message = _dispatching.dispatch(node, registry)
            except Exception as e:
                message = "Critical Error in dispatcher:\n"
                formatted_exception = "\n".join(
                    traceback.format_exception(type(e), e, e.__traceback__)
                )
                message += f"```pytb\n{formatted_exception}\n```\n"
                onnx_function = None
            if onnx_function is None:
                model_info.dispatch_failures.append((node, message))

    # Print the results
    report = _format_model_info(model_info)
    print(report, file=file, flush=True)


def compare_ops(
    program_a: torch.export.ExportedProgram, program_b: torch.export.ExportedProgram
) -> tuple[set[str], set[str]]:
    """Compare and get unique ops in two exported programs.

    Args:
        program_a: The first exported program.
        program_b: The second exported program.

    Returns:
        A tuple of two sets, where the first set contains the unique ops in the first program
        and the second set contains the unique ops in the second program.
    """
    program_a_ops = set(_count_fx_targets(program_a))
    program_b_ops = set(_count_fx_targets(program_b))
    return program_a_ops - program_b_ops, program_b_ops - program_a_ops
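A sketch of how the analyzer might be invoked (the model is a stand-in; ``analyze`` prints its markdown report to ``file``, or stdout when ``file`` is None):

```python
import io

import torch
from torch.onnx._internal.exporter import _analysis


class TinyModel(torch.nn.Module):  # hypothetical example model
    def forward(self, x):
        return x.relu().sum(dim=1)


exported = torch.export.export(TinyModel(), (torch.randn(2, 3),))
report_buffer = io.StringIO()
_analysis.analyze(exported, file=report_buffer)  # writes the compatibility report
print(report_buffer.getvalue())
```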
516
torch/onnx/_internal/exporter/_building.py
Normal file
516
torch/onnx/_internal/exporter/_building.py
Normal file
@ -0,0 +1,516 @@
|
||||
"""NOTES:
|
||||
|
||||
We need a typing module that will handling Python to ONNX type promotion for use.
|
||||
For example, if we have torch.ops.aten.add(Tensor, 1.0), we need to promote 1.0
|
||||
to the same type as Tensor. The same thing needs to work for
|
||||
torch.ops.aten.add(1.0, Tensor) as well, which means we need a mechanism to`
|
||||
"""
|
||||
|
||||
# mypy: allow-untyped-defs
|
||||
# mypy: disable-error-code=union-attr
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import inspect
|
||||
import logging
|
||||
from typing import Any, Mapping, Sequence, TYPE_CHECKING, Union
|
||||
|
||||
import onnxscript
|
||||
from onnxscript import evaluator, ir
|
||||
from onnxscript.ir import convenience as ir_convenience
|
||||
|
||||
import torch
|
||||
from torch.onnx._internal.exporter import _schemas, _tensors, errors
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import onnx
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# TODO(justinchuby): Update ValidAttributeType to ir_convenience.SupportedAttrTypes
|
||||
ValidAttributeType = Union[
|
||||
ir.TensorProtocol, int, float, bool, str, Sequence[int], Sequence[float], None
|
||||
]
|
||||
|
||||
AllowedArgType = Union[ir.Value, Sequence[ir.Value], ValidAttributeType]
|
||||
|
||||
|
||||
# Logic for adapting inputs from general Python or PyTorch inputs to ONNX ir.Value
|
||||
def _construct_named_inputs_and_attrs(
|
||||
signature: _schemas.OpSignature,
|
||||
args: Sequence[AllowedArgType],
|
||||
kwargs: Mapping[str, AllowedArgType],
|
||||
) -> tuple[dict[str, AllowedArgType], dict[str, ValidAttributeType]]:
|
||||
"""Construct two mappings: name to inputs and named to attributes based on the signature and args/kwargs.
|
||||
|
||||
This function uses the OpSignature to determine which argument in args and kwargs corresponds to
|
||||
which parameter in the signature. ONNX node inputs are stored in named_inputs, and attributes are
|
||||
stored in named_attrs. If an _optional input_ is not provided, it is filled with None.
|
||||
|
||||
Args:
|
||||
signature: The OpSignature for the node.
|
||||
args: The positional arguments for the node.
|
||||
kwargs: The keyword arguments for the node.
|
||||
|
||||
Returns:
|
||||
A tuple of two mappings: named_inputs and named_attrs.
|
||||
|
||||
Raises:
|
||||
ValueError: If a required parameter is not provided.
|
||||
"""
|
||||
# 1. Construct the (named_inputs, named_attrs) mapping based on (args, kwargs) and the signature.
|
||||
# a. Loop over all parameters in the signature and args together
|
||||
# b. Depending on param.is_input, Record named_inputs[param.name] = arg or named_attrs[param.name] = arg
|
||||
# c. Handle kwargs as well
|
||||
# d. Fill in None if the input is not provided
|
||||
named_inputs = {}
|
||||
named_attrs = {}
|
||||
reversed_args_stack = list(reversed(args))
|
||||
for param in signature.params:
|
||||
if isinstance(param, _schemas.Parameter):
|
||||
# Handle inputs
|
||||
if reversed_args_stack:
|
||||
# First exhaust the positional arguments
|
||||
if param.variadic:
|
||||
# Handle variadic arguments
|
||||
named_inputs[param.name] = tuple(args)
|
||||
reversed_args_stack.clear()
|
||||
else:
|
||||
named_inputs[param.name] = reversed_args_stack.pop() # type: ignore[assignment]
|
||||
elif param.name in kwargs:
|
||||
named_inputs[param.name] = kwargs[param.name] # type: ignore[assignment]
|
||||
elif param.required:
|
||||
raise ValueError(
|
||||
f"Required parameter '{param.name}' is not provided. "
|
||||
f"Signature: {signature}. Args: {args}. Kwargs: {kwargs}."
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"Optional parameter '%s' is not provided. Added as None. Signature: %s",
|
||||
param.name,
|
||||
signature,
|
||||
)
|
||||
named_inputs[param.name] = None # type: ignore[assignment]
|
||||
else:
|
||||
# Handle attributes
|
||||
attribute: ValidAttributeType | ir.Attr
|
||||
assert isinstance(
|
||||
param, _schemas.AttributeParameter
|
||||
), f"Expected AttributeParameter, got {type(param)}"
|
||||
if reversed_args_stack:
|
||||
# First exhaust the positional arguments
|
||||
attribute = reversed_args_stack.pop() # type: ignore[assignment]
|
||||
elif param.name in kwargs:
|
||||
attribute = kwargs[param.name] # type: ignore[assignment]
|
||||
elif param.default is not None:
|
||||
attribute = param.default
|
||||
else:
|
||||
attribute = None
|
||||
|
||||
if attribute is None:
|
||||
if param.required:
|
||||
raise ValueError(
|
||||
f"Required attribute '{param.name}' is not provided. "
|
||||
f"Signature: {signature}. Args: {args}. Kwargs: {kwargs}."
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"Optional attribute '%s' is None. Dropped. Signature: %s",
|
||||
param.name,
|
||||
signature,
|
||||
)
|
||||
continue
|
||||
|
||||
if isinstance(attribute, ir.Attr):
|
||||
# Turn the attribute from an default value into an actual parameter for the node
|
||||
attr_copied = copy.copy(attribute)
|
||||
# Make sure the name is the same as the parameter name and not the name of the default parameter
|
||||
attr_copied.name = param.name
|
||||
attribute = attr_copied
|
||||
|
||||
if isinstance(attribute, int) and param.type == ir.AttributeType.FLOAT:
|
||||
# Convert the attribute to float if needed. This happens in PyTorch
|
||||
# where an attribute marked as float can be passed as an int.
|
||||
attribute = float(attribute)
|
||||
named_attrs[param.name] = attribute
|
||||
return named_inputs, named_attrs # type: ignore[return-value]
|
||||
|
||||
|
||||
def _resolve_parameter_dtypes(
|
||||
signature: _schemas.OpSignature, named_inputs: Mapping[str, AllowedArgType]
|
||||
) -> Mapping[_schemas.TypeConstraintParam, ir.TypeProtocol]:
|
||||
"""Determine which parameter takes which type.
|
||||
|
||||
Handle non-tensor input corner cases and type promotion.
|
||||
|
||||
Requires:
|
||||
All ir.Value in name_inputs should have type set. Their type should be
|
||||
compatible with the type_constraint of the corresponding parameter in the signature.
|
||||
|
||||
Args:
|
||||
signature: The OpSignature for the node.
|
||||
named_inputs: The mapping of parameter names to their arguments.
|
||||
|
||||
Returns:
|
||||
A mapping of Constraint names to ir.TypeProtocol.
|
||||
"""
|
||||
# a. Create type_binding: dict[str, ir.TypeProtocol]
|
||||
# b. Iterate over all named_inputs
|
||||
# b0. Find the corresponding parameter in the signature
|
||||
# b1. If the argument is a Python constant, skip.
|
||||
# b2. If the argument is a ir.Value, Bind {constraint: arg.type}.
|
||||
type_binding = {}
|
||||
for name, arg in named_inputs.items():
|
||||
param = signature.params_map[name]
|
||||
assert isinstance(
|
||||
param, _schemas.Parameter
|
||||
), f"Expected Parameter, got {type(param)}"
|
||||
if isinstance(arg, (int, float, bool, str, Sequence, torch.Tensor)):
|
||||
# Skip the Python constants because we do not know what dtype they should take yet
|
||||
continue
|
||||
elif isinstance(arg, ir.Value):
|
||||
if arg.type is None:
|
||||
# Skip the ir.Value if the type is not set
|
||||
continue
|
||||
# NOTE: We assume arg.type is compatible with the type_constraint
|
||||
assert arg.type is not None, f"Expected type to be set for {arg}"
|
||||
# TODO(justinchuby): Implement type promotion logic here.
|
||||
type_binding[param.type_constraint] = arg.type
|
||||
return type_binding
|
||||
|
||||
|
||||
def _process_python_constants_and_sequences(
|
||||
signature: _schemas.OpSignature,
|
||||
named_inputs: dict[str, AllowedArgType],
|
||||
type_binding: Mapping[_schemas.TypeConstraintParam, ir.TypeProtocol],
|
||||
constant_farm: dict[
|
||||
tuple[
|
||||
bool | int | float | str | ir.TensorProtocol | tuple[int] | tuple[float],
|
||||
ir.DataType,
|
||||
],
|
||||
ir.Value,
|
||||
],
|
||||
opset: onnxscript.values.Opset,
|
||||
) -> dict[str, ir.Value | None]:
|
||||
"""Convert Python constants to Constant nodes and list to Sequence nodes based on the dtype information.
|
||||
|
||||
The added constants will be replacing values in named_inputs in place.
|
||||
|
||||
Args:
|
||||
signature: The OpSignature for the node.
|
||||
named_inputs: The mapping of parameter names to their arguments.
|
||||
type_binding: A mapping of Constraint names to ir.DataType.
|
||||
constant_farm: A dictionary of {(py_value, ir.DataType): ir.Value} to store the deduplicated constants.
|
||||
opset: The Opset to use for creating Constant nodes.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# 3. Convert Python constants to Constant nodes based on the dtype information;
|
||||
# construct sequences
|
||||
# a. Iterate over all parameters in the signature the second time
|
||||
# b. If the parameter is in to_resolve_type:
|
||||
# - If param.constraint in type_binding,
|
||||
# Get the constant from constant_farm (deduplicated);
|
||||
# otherwise set named_inputs[param.name] = Constant(value, dtype=type_binding[param.constraint])
|
||||
# - Otherwise, set named_inputs[param.name] = Constant(value)
|
||||
for name, arg in named_inputs.items():
|
||||
param = signature.params_map[name]
|
||||
assert isinstance(
|
||||
param, _schemas.Parameter
|
||||
), f"Expected Parameter, got {type(param)}"
|
||||
|
||||
if isinstance(arg, ir.Value):
|
||||
# TODO(justinchuby): Cast the ir.Value here if needed
|
||||
continue
|
||||
if (
|
||||
isinstance(arg, Sequence)
|
||||
and len(arg) > 0
|
||||
and all(isinstance(val, ir.Value) for val in arg)
|
||||
):
|
||||
# Skip the sequence of ir.Value. This is a variadic input or a Sequence input
|
||||
# NOTE: Variadic operators like Max can be called with mixed ir.Value and Python constants
|
||||
# like `Max(0, ir.Value())`
|
||||
# We need to convert the Python constants to Constant nodes
|
||||
# NOTE: Important to check that arg is not empty because we need to treat it as list[int] or list[float]
|
||||
continue
|
||||
# if param.variadic:
|
||||
# # FXIME: Handle variadic inputs and sequence inputs differently
|
||||
# raise NotImplementedError
|
||||
# TODO: Find a way to recursively build constants. Maybe extract the logic out.
|
||||
# FIXME: I am here
|
||||
|
||||
assert isinstance(
|
||||
param, _schemas.Parameter
|
||||
), f"Expected Parameter, got {type(param)}"
|
||||
|
||||
if param.type_constraint in type_binding:
|
||||
# A known dtype is available
|
||||
dtype = type_binding[param.type_constraint].dtype
|
||||
elif len(param.type_constraint.allowed_types) == 1:
|
||||
# Only one type is allowed
|
||||
dtype = next(iter(param.type_constraint.allowed_types)).dtype
|
||||
else:
|
||||
# No dtype information available. Infer from the Python constant
|
||||
if isinstance(arg, bool):
|
||||
dtype = ir.DataType.BOOL
|
||||
elif isinstance(arg, float):
|
||||
dtype = ir.DataType.FLOAT
|
||||
elif isinstance(arg, int):
|
||||
dtype = ir.DataType.INT64
|
||||
elif isinstance(arg, str):
|
||||
dtype = ir.DataType.STRING
|
||||
elif isinstance(arg, (tuple, list)) and all(
|
||||
isinstance(val, int) for val in arg
|
||||
):
|
||||
dtype = ir.DataType.INT64
|
||||
elif isinstance(arg, (tuple, list)) and any(
|
||||
isinstance(val, float) for val in arg
|
||||
):
|
||||
# NOTE: if any float is present, the dtype is float
|
||||
dtype = ir.DataType.FLOAT
|
||||
elif isinstance(arg, (ir.Tensor, ir.TensorProtocol)):
|
||||
dtype = arg.dtype
|
||||
elif arg is None:
|
||||
dtype = ir.DataType.UNDEFINED
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Constant input '{arg}' of type '{type(arg)}' is not supported"
|
||||
)
|
||||
|
||||
if arg is None:
|
||||
constant_value = None
|
||||
elif not isinstance(arg, (ir.Tensor, ir.TensorProtocol)):
|
||||
# Deduplicate the constants
|
||||
if isinstance(arg, (tuple, list)):
|
||||
# Make the arg hashable
|
||||
arg = tuple(arg) # noqa: PLW2901
|
||||
constant_value = constant_farm.get((arg, dtype)) # type: ignore[arg-type]
|
||||
if constant_value is None:
|
||||
constant_tensor = ir.tensor(value=arg, dtype=dtype) # type: ignore[arg-type]
|
||||
constant_value = opset.Constant(value=constant_tensor)
|
||||
constant_farm[(arg, dtype)] = constant_value # type: ignore[arg-type,index]
|
||||
else:
|
||||
constant_value = opset.Constant(value=arg)
|
||||
|
||||
named_inputs[param.name] = constant_value
|
||||
return named_inputs # type: ignore[return-value]
|
||||
|
||||
|
||||
def _construct_node(
|
||||
signature: _schemas.OpSignature,
|
||||
named_inputs: Mapping[str, ir.Value | None],
|
||||
named_attrs: Mapping[str, ValidAttributeType],
|
||||
opset: onnxscript.values.Opset,
|
||||
) -> ir.Node:
|
||||
"""Construct the node with the inputs and attributes.
|
||||
|
||||
Variadic inputs are flattened.
|
||||
|
||||
Args:
|
||||
signature: The OpSignature for the node.
|
||||
named_inputs: The mapping of parameter names to their arguments. When we
|
||||
do not have the schema of an operator, we do not know the names of
|
||||
the inputs, in which case the names can be anything because they
|
||||
are not used in this function. The data structure is passed in for
|
||||
consistency with the other functions.
|
||||
named_attrs: The mapping of attribute names to their values.
|
||||
"""
|
||||
inputs: list[Any] = []
|
||||
# Flatten variadic inputs
|
||||
for value in named_inputs.values():
|
||||
if isinstance(value, Sequence):
|
||||
inputs.extend(value)
|
||||
else:
|
||||
inputs.append(value)
|
||||
|
||||
# Construct and filter out None attributes
|
||||
attributes = [
|
||||
attr
|
||||
for attr in ir_convenience.convert_attributes(named_attrs)
|
||||
if attr.value is not None
|
||||
]
|
||||
outputs = [_tensors.SymbolicTensor(opset) for _ in signature.outputs]
|
||||
return ir.Node(
|
||||
signature.domain,
|
||||
signature.name,
|
||||
inputs=inputs,
|
||||
attributes=attributes,
|
||||
outputs=outputs,
|
||||
)
|
||||
|
||||
|
||||
class OpRecorder(evaluator.Evaluator):
|
||||
"""An onnxscript Evaluator that captures the graph into torchscript."""
|
||||
|
||||
def __init__(
|
||||
self, opset: onnxscript.values.Opset, constant_farm: dict[Any, ir.Value]
|
||||
):
|
||||
self.nodes: list[ir.Node] = []
|
||||
self.opset = opset
|
||||
self.functions: dict[ir.OperatorIdentifier, onnxscript.OnnxFunction] = {}
|
||||
self.constant_farm = constant_farm
|
||||
|
||||
def _call_op(
|
||||
self,
|
||||
op_signature: _schemas.OpSignature,
|
||||
named_inputs: dict[str, AllowedArgType],
|
||||
named_attrs: dict[str, ValidAttributeType],
|
||||
) -> Sequence[_tensors.SymbolicTensor]:
|
||||
"""Record nodes for the given opschema and arguments.
|
||||
|
||||
Args:
|
||||
op_signature: The OpSchema containing the node signature.
|
||||
named_inputs: The mapping of parameter names to their arguments.
|
||||
named_attrs: The mapping of attribute names to their values.
|
||||
"""
|
||||
type_binding = _resolve_parameter_dtypes(op_signature, named_inputs)
|
||||
try:
|
||||
converted_named_inputs = _process_python_constants_and_sequences(
|
||||
op_signature, named_inputs, type_binding, self.constant_farm, self.opset
|
||||
)
|
||||
except Exception as e:
|
||||
raise errors.GraphConstructionError(
|
||||
f"Error processing Python constants for operator '{op_signature.domain}::{op_signature.name}'. "
|
||||
f"named_inputs={named_inputs}, named_attrs={named_attrs}, opset={self.opset}, op_signature={op_signature}."
|
||||
) from e
|
||||
|
||||
try:
|
||||
self.nodes.append(
|
||||
node := _construct_node(
|
||||
op_signature, converted_named_inputs, named_attrs, self.opset
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
raise errors.GraphConstructionError(
|
||||
f"Error constructing node for operator '{op_signature.domain}::{op_signature.name}'. "
|
||||
f"named_inputs={named_inputs}, converted_named_inputs={converted_named_inputs}, "
|
||||
f"named_attrs={named_attrs}, opset={self.opset}, op_signature={op_signature}."
|
||||
) from e
|
||||
return node.outputs # type: ignore[return-value]
|

    def eval(
        self,
        schema: onnx.defs.OpSchema,
        args: Sequence[AllowedArgType],  # type: ignore[override]
        kwargs: Mapping[str, AllowedArgType],
    ) -> _tensors.SymbolicTensor | Sequence[_tensors.SymbolicTensor]:
        try:
            op_signature = _schemas.OpSignature.from_opschema(schema)
            named_inputs, named_attrs = _construct_named_inputs_and_attrs(
                op_signature, args, kwargs
            )
            # TODO(justinchuby): Handle cast
            if schema.name == "CastLike":
                assert len(named_inputs) == 2
                # Skip CastLike if the input and output types are the same
                src_input = named_inputs["input"]
                target_type = named_inputs["target_type"]

                if (
                    isinstance(src_input, ir.Value)
                    and isinstance(target_type, ir.Value)
                    and src_input.dtype is not None
                    and target_type.dtype is not None
                ):
                    # dtypes are available
                    if src_input.dtype == target_type.dtype:
                        # Same type. No cast needed
                        return src_input  # type: ignore[return-value]
                    else:
                        # Create a Cast node
                        return self.opset.Cast(src_input, to=target_type.dtype)  # type: ignore[union-attr,return-value]

            outputs = self._call_op(op_signature, named_inputs, named_attrs)
            if len(outputs) == 1:
                return outputs[0]
            return outputs
        except Exception as e:
            raise errors.GraphConstructionError(
                f"Error calling operator '{schema.name}' with args {args} and kwargs {kwargs}."
            ) from e

    def eval_function(  # type: ignore[override]
        self,
        function: onnxscript.OnnxFunction,
        args: Sequence[AllowedArgType],
        kwargs: Mapping[str, AllowedArgType],
    ) -> _tensors.SymbolicTensor | Sequence[_tensors.SymbolicTensor] | bool | int:
        try:
            # Special cases for handling IsScalar and Rank
            if function.name == "IsScalar":
                if len(args) != 1:
                    raise TypeError(
                        f"Expected 1 positional argument for function '{function}', got {len(args)}."
                    )
                if isinstance(args[0], _tensors.SymbolicTensor):
                    if args[0].rank is not None:
                        return args[0].rank == 0
                    else:
                        # Fall through to add_function_call
                        pass
                elif isinstance(args[0], Sequence):
                    return False
                else:
                    # Python constants are scalars
                    return True
            if function.name == "Rank":
                if len(args) != 1:
                    raise TypeError(
                        f"Expected 1 positional argument for function '{function}', got {len(args)}."
                    )
                if isinstance(args[0], _tensors.SymbolicTensor):
                    if args[0].rank is not None:
                        return args[0].rank
                    else:
                        # Fall through to add_function_call
                        pass
                elif isinstance(args[0], Sequence):
                    if all(isinstance(arg, (int, float)) for arg in args[0]):
                        return 1
                    else:
                        # Fall through to add_function_call
                        pass
                else:
                    # Python constants are scalars
                    return 0

            # NOTE: signature is written to the function in the registration process
            # TODO: Upstream signature to ONNX Function
            if hasattr(function, "signature"):
                op_signature = function.signature
            else:
                op_signature = _schemas.OpSignature.from_function(
                    function, function.function_ir.domain, function.name
                )

            named_inputs, named_attrs = _construct_named_inputs_and_attrs(
                op_signature, args, kwargs
            )

            # NOTE: We need to call traceable functions after the
            # _construct_named_inputs_and_attrs call because it filters out
            # the unexpected kwargs for us.
            if function.traceable:
                # Trace the function call instead of adding the function as a node
                return function.function(**named_inputs, **named_attrs)

            outputs = self._call_op(op_signature, named_inputs, named_attrs)

            self.functions[(function.function_ir.domain, function.name, "")] = function
            if len(outputs) == 1:
                return outputs[0]
            return outputs
        except Exception as e:
            try:
                source_file = inspect.getsourcefile(function.function)
                _, lineno = inspect.getsourcelines(function.function)
            except Exception:
                source_file = lineno = None
            raise errors.GraphConstructionError(
                f"Error calling function '{function.name}' with args {args} and kwargs {kwargs}."
                # Parenthesized so that a missing source location degrades to an
                # empty suffix instead of blanking the whole message.
                + (
                    f" The function is defined at '{source_file}:{lineno}'."
                    if source_file
                    else ""
                )
            ) from e
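
For context, a hedged sketch of how ``OpRecorder`` plugs into onnxscript tracing; ``onnxscript.opset18`` and the empty ``constant_farm`` are illustrative assumptions, not part of this diff::

    import onnxscript
    from onnxscript import evaluator

    recorder = OpRecorder(onnxscript.opset18, constant_farm={})
    with evaluator.default_as(recorder):
        # Calls into onnxscript ops made inside this block are recorded as
        # ir.Node objects instead of being evaluated eagerly.
        ...
    print(recorder.nodes)  # the captured nodes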

torch/onnx/_internal/exporter/_capture_strategies.py (new file, +335)
@@ -0,0 +1,335 @@
"""Strategies for capturing ExportedPrograms."""

# mypy: allow-untyped-defs
from __future__ import annotations

import abc
import dataclasses
import datetime
import pathlib
from typing import Any, Callable, TYPE_CHECKING

import torch
from torch._export import converter as _torchscript_converter
from torch.utils import _pytree


if TYPE_CHECKING:
    import os


def _verbose_printer(verbose: bool | None) -> Callable[..., None]:
    """Return a print function gated on `verbose`."""
    if verbose is False:
        return lambda *_, **__: None
    return lambda *args, **kwargs: print("[torch.onnx]", *args, **kwargs)


def _take_first_line(text: str) -> str:
    """Take the first line of a text."""
    lines = text.split("\n", maxsplit=1)
    first_line = lines[0]
    if len(lines) > 1:
        first_line += "[...]"
    return first_line


@dataclasses.dataclass
class Result:
    exported_program: torch.export.ExportedProgram | None
    strategy: str
    exception: Exception | None = None

    @property
    def success(self) -> bool:
        return self.exported_program is not None


class CaptureStrategy(abc.ABC):
    """Strategy for capturing a module as ExportedProgram.

    To use a strategy, create an instance and call it with the model, args, kwargs, and dynamic_shapes.
    Example::

        strategy = TorchExportStrategy(verbose=True)
        result = strategy(model, args, kwargs, dynamic_shapes)
    """

    def __init__(
        self,
        *,
        verbose: bool = False,
        dump: bool = False,
        artifacts_dir: str | os.PathLike = ".",
        timestamp: str | None = None,
    ):
        """Initialize the strategy.

        Args:
            verbose: Whether to print verbose messages.
            dump: Whether to dump the intermediate artifacts to a file.
            artifacts_dir: The directory to write dumped artifacts to.
            timestamp: The timestamp used to tag dumped artifacts. Defaults to the current time.
        """
        self._verbose_print = _verbose_printer(verbose)
        self._dump = dump
        self._artifacts_dir = pathlib.Path(artifacts_dir)
        self._timestamp = timestamp or datetime.datetime.now().strftime(
            "%Y-%m-%d_%H-%M-%S-%f"
        )

    def __call__(
        self,
        model: torch.nn.Module | torch.jit.ScriptFunction,
        args: tuple[Any, ...],
        kwargs: dict[str, Any] | None,
        dynamic_shapes,
    ) -> Result:
        self._enter(model)
        if kwargs is None:
            kwargs = {}
        try:
            exported_program = self._capture(model, args, kwargs, dynamic_shapes)
        except Exception as e:
            self._failure(model, e)
            return Result(
                exported_program=None,
                strategy=self.__class__.__name__,
                exception=e,
            )
        self._success(model)
        # Use the class name (not `self.__call__.__name__`) so success and
        # failure results report the strategy consistently.
        return Result(exported_program, strategy=self.__class__.__name__)

    @abc.abstractmethod
    def _capture(
        self, model, args, kwargs, dynamic_shapes
    ) -> torch.export.ExportedProgram:
        raise NotImplementedError

    def _enter(self, model: torch.nn.Module | torch.jit.ScriptFunction) -> None:
        return

    def _success(self, model: torch.nn.Module | torch.jit.ScriptFunction) -> None:
        return

    def _failure(
        self, model: torch.nn.Module | torch.jit.ScriptFunction, e: Exception
    ) -> None:
        return


class TorchExportStrategy(CaptureStrategy):
    def _capture(
        self, model, args, kwargs, dynamic_shapes
    ) -> torch.export.ExportedProgram:
        return torch.export.export(
            model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes
        )

    def _enter(self, model) -> None:
        model_repr = _take_first_line(repr(model))
        self._verbose_print(
            f"Obtain model graph for `{model_repr}` with `torch.export.export`..."
        )

    def _success(self, model) -> None:
        model_repr = _take_first_line(repr(model))
        self._verbose_print(
            f"Obtain model graph for `{model_repr}` with `torch.export.export`... ✅"
        )

    def _failure(self, model, e) -> None:
        del e  # Unused
        model_repr = _take_first_line(repr(model))
        self._verbose_print(
            f"Obtain model graph for `{model_repr}` with `torch.export.export`... ❌"
        )


class TorchExportNonStrictStrategy(CaptureStrategy):
    def _capture(
        self, model, args, kwargs, dynamic_shapes
    ) -> torch.export.ExportedProgram:
        return torch.export.export(
            model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes, strict=False
        )

    def _enter(self, model) -> None:
        model_repr = _take_first_line(repr(model))
        self._verbose_print(
            f"Obtain model graph for `{model_repr}` with `torch.export.export(..., strict=False)`..."
        )

    def _success(self, model) -> None:
        model_repr = _take_first_line(repr(model))
        self._verbose_print(
            f"Obtain model graph for `{model_repr}` with `torch.export.export(..., strict=False)`... ✅"
        )

    def _failure(self, model, e) -> None:
        del e  # Unused
        model_repr = _take_first_line(repr(model))
        self._verbose_print(
            f"Obtain model graph for `{model_repr}` with `torch.export.export(..., strict=False)`... ❌"
        )


class JitTraceConvertStrategy(CaptureStrategy):
    def _capture(
        self, model, args, kwargs, dynamic_shapes
    ) -> torch.export.ExportedProgram:
        del dynamic_shapes  # Unused

        flattened_args, spec = _pytree.tree_flatten((args, kwargs))
        flattened_args = tuple(flattened_args)

        # Since torch.jit.trace only accepts Tensors as inputs, we filter
        # out non-Tensor arguments and reconstruct the arguments after entering
        # the WrappedModel.
        tensor_placeholder = object()
        non_tensor_args = [
            arg if not isinstance(arg, torch.Tensor) else tensor_placeholder
            for arg in flattened_args
        ]
        tensor_args = tuple(
            arg for arg in flattened_args if isinstance(arg, torch.Tensor)
        )

        class WrappedModel(torch.nn.Module):
            """Wrap the model so that it takes flattened arguments."""

            def __init__(self, m):
                super().__init__()
                self.model = m

            def forward(self, *_args):
                # Take the non-Tensor arguments list as a starting point and
                # replace the tensor_placeholder with the actual tensor arguments
                # from _args.
                reconstructed_flattened_args = non_tensor_args.copy()
                _args_iter = iter(_args)
                for i, arg in enumerate(reconstructed_flattened_args):
                    if arg is tensor_placeholder:
                        reconstructed_flattened_args[i] = next(_args_iter)
                # Unflatten the arguments and kwargs to pass to the model.
                unflattened_args, unflattened_kwargs = _pytree.tree_unflatten(
                    reconstructed_flattened_args, spec
                )
                results = self.model(*unflattened_args, **unflattened_kwargs)
                if not isinstance(results, tuple):
                    results = (results,)
                flattened_results, _ = _pytree.tree_flatten(results)
                if len(flattened_results) == 1:
                    return flattened_results[0]
                return tuple(flattened_results)

        jit_model = torch.jit.trace(
            WrappedModel(model),
            example_inputs=tensor_args,
            check_trace=False,
            strict=False,
        )
        if self._dump:
            program_path = self._artifacts_dir / f"onnx_export_{self._timestamp}.pt"
            try:
                torch.jit.save(jit_model, program_path)
            except Exception as e:
                self._verbose_print(
                    f"Failed to save Torch Script model due to an error: {e}"
                )
            else:
                self._verbose_print(
                    f"Torch Script model has been saved to '{program_path}'."
                )
        return _torchscript_converter.TS2EPConverter(
            jit_model, flattened_args
        ).convert()

    def _enter(self, model) -> None:
        model_repr = _take_first_line(repr(model))
        self._verbose_print(
            f"Obtain model graph for `{model_repr}` with Torch Script..."
        )

    def _success(self, model) -> None:
        model_repr = _take_first_line(repr(model))
        self._verbose_print(
            f"Obtain model graph for `{model_repr}` with Torch Script... ✅"
        )

    def _failure(self, model, e) -> None:
        del e  # Unused
        model_repr = _take_first_line(repr(model))
        self._verbose_print(
            f"Obtain model graph for `{model_repr}` with Torch Script... ❌"
        )


class LegacyDynamoStrategy(CaptureStrategy):
    """Strategy implemented by the ONNX team using internal dynamo APIs and custom fx passes."""

    def _capture(
        self, model, args, kwargs, dynamic_shapes
    ) -> torch.export.ExportedProgram:
        # NOTE: Import here to prevent circular dependency
        from torch.onnx._internal.fx import diagnostics, passes

        graph_module, _ = torch._dynamo.export(
            model,
            tracing_mode="symbolic",
            dynamic_shapes=dynamic_shapes,
        )(
            *args,
            **kwargs,
        )
        torch._dynamo.reset()

        diagnostic_context = diagnostics.DiagnosticContext(
            "torch.onnx.export",
            torch.__version__,
        )

        flattened_args, _ = _pytree.tree_flatten((args, kwargs))
        flattened_args = tuple(flattened_args)

        # ONNX does not support views and mutations.
        # Functionalize to get a semantically equivalent graph without mutations.
        graph_module = passes.Functionalize(
            diagnostic_context,
            graph_module,
            enable_dynamic_axes=bool(dynamic_shapes),
        ).run(*flattened_args)

        # Input mutations are detected and distilled after the `Functionalize` pass.
        # Remove them since ONNX inference does not need them.
        graph_module = passes.RemoveInputMutation(diagnostic_context, graph_module).run(
            *flattened_args
        )

        # Use torch.export to recapture the GraphModule into an ExportedProgram.
        return torch.export.export(graph_module, flattened_args)

    def _enter(self, model) -> None:
        model_repr = _take_first_line(repr(model))
        self._verbose_print(
            f"Obtain model graph for `{model_repr}` with internal Dynamo APIs..."
        )

    def _success(self, model) -> None:
        model_repr = _take_first_line(repr(model))
        self._verbose_print(
            f"Obtain model graph for `{model_repr}` with internal Dynamo APIs... ✅"
        )

    def _failure(self, model, e) -> None:
        del e  # Unused
        model_repr = _take_first_line(repr(model))
        self._verbose_print(
            f"Obtain model graph for `{model_repr}` with internal Dynamo APIs... ❌"
        )


CAPTURE_STRATEGIES = (
    TorchExportStrategy,
    TorchExportNonStrictStrategy,
    JitTraceConvertStrategy,
    LegacyDynamoStrategy,
)
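
For context, a hedged sketch of how a caller might walk ``CAPTURE_STRATEGIES`` in order until one succeeds; the real driver lives in ``_core.py``, whose diff is suppressed below::

    def capture_first_success(model, args, kwargs=None, dynamic_shapes=None):
        for strategy_cls in CAPTURE_STRATEGIES:
            result = strategy_cls(verbose=True)(model, args, kwargs, dynamic_shapes)
            if result.success:
                return result.exported_program
        raise RuntimeError("All capture strategies failed")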

torch/onnx/_internal/exporter/_compat.py (new file, +225)
@@ -0,0 +1,225 @@
"""Compatibility functions for the torch.onnx.export API."""

# mypy: allow-untyped-defs
# mypy: disable-error-code=attr-defined
from __future__ import annotations

import inspect
import logging
from typing import Any, Mapping, Sequence, TYPE_CHECKING

import onnx

import torch
import torch.export
from torch.onnx._internal.exporter import _core, _onnx_program


if TYPE_CHECKING:
    import os

logger = logging.getLogger(__name__)


def _signature(model) -> inspect.Signature:
    should_be_callable = getattr(model, "forward", model)
    if callable(should_be_callable):
        return inspect.signature(should_be_callable)
    raise ValueError("model has no forward method and is not callable")


def _from_dynamic_axes_to_dynamic_shapes(
    model,
    dynamic_axes=None,
    input_names: Sequence[str] | None = None,
) -> dict[str, Any] | None:
    """Convert the legacy dynamic_axes format into the torch.export dynamic_shapes format.

    dynamic_axes examples:
    (1) dynamic_axes = {"x": {0: "my_custom_axis_name_1"}, "y": {1: "my_custom_axis_name_2"}}
    (2) dynamic_axes = {"x": [0], "y": [1]}

    these will be converted to dynamic_shapes respectively:
    (1) dynamic_shapes = {"x": {0: Dim("my_custom_axis_name_1")}, "y": {1: Dim("my_custom_axis_name_2")}}
    (2) dynamic_shapes = {"x": {0: Dim("x_dim_0")}, "y": {1: Dim("y_dim_1")}}  # auto-generated dim names
    """
    # https://github.com/pytorch/pytorch/pull/128371
    # 1. The function does not need to provide dynamic_shapes to torch.export.export
    if dynamic_axes is None:
        return None

    if input_names is None:
        input_names = []

    sig = _signature(model)
    if len(input_names) > len(sig.parameters):
        raise ValueError(
            f"Number of input names ({len(input_names)}) should not be greater than "
            f"the number of model inputs ({len(sig.parameters)})"
        )
    input_names_to_model_inputs = {}
    for idx, param_name in enumerate(sig.parameters):
        if idx < len(input_names):
            input_names_to_model_inputs[input_names[idx]] = param_name
        else:
            input_names_to_model_inputs[param_name] = param_name

    # NOTE: torch.export.export does not support input name assignment,
    # so we need to map input names to model inputs to create dynamic_shapes
    # for the exported program
    dynamic_shapes_to_exported_program = {}
    for input_name, axes in dynamic_axes.items():
        # input_name can come either from input_names or from the model inputs
        if input_name not in input_names_to_model_inputs:
            raise ValueError(
                f"dynamic axis: {input_name} is not found in the input names: {input_names}"
            )
        model_input_name = input_names_to_model_inputs[input_name]
        if isinstance(axes, dict):
            dynamic_shapes_to_exported_program[model_input_name] = {
                k: torch.export.Dim(v) for k, v in axes.items()
            }
        elif isinstance(axes, list):
            dynamic_shapes_to_exported_program[model_input_name] = {
                k: torch.export.Dim(f"{model_input_name}_dim_{k}") for k in axes
            }
        else:
            raise TypeError(
                f"dynamic_axes value must be either a dict or a list, but got {type(axes)}"
            )
    # torch.export.export needs static dims to be present in dynamic_shapes
    # for all input tensors, so we need to add them with None
    for input_name in sig.parameters:
        if input_name not in dynamic_shapes_to_exported_program:
            dynamic_shapes_to_exported_program[input_name] = None  # type: ignore[assignment]

    return dynamic_shapes_to_exported_program
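
A minimal sketch of the conversion above, using a hypothetical two-input model::

    class TwoInputs(torch.nn.Module):
        def forward(self, x, y):
            return x + y

    # {"x": [0]} becomes {"x": {0: Dim("x_dim_0")}, "y": None}: the static
    # input "y" is padded with None because torch.export.export expects an
    # entry for every model input.
    shapes = _from_dynamic_axes_to_dynamic_shapes(TwoInputs(), dynamic_axes={"x": [0]})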


def _get_torch_export_args(
    args: tuple[Any, ...],
    kwargs: dict[str, Any] | None,
) -> tuple[tuple[Any, ...], dict[str, Any] | None]:
    """Obtain the args and kwargs for torch.export.export from the legacy torch.onnx.export arguments.

    The legacy API allows kwargs to be supplied as a trailing dict in `args`.
    """
    if not kwargs and args and isinstance(args[-1], dict):
        kwargs = args[-1]
        args = args[:-1]
    return args, kwargs
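
For illustration, the legacy calling convention this helper normalizes (tensors are placeholders)::

    args, kwargs = _get_torch_export_args(
        (torch.randn(1, 2), {"b": torch.randn(1, 2)}), None
    )
    # args == (tensor,) and kwargs == {"b": tensor}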


def _convert_version(path: str | os.PathLike, opset_version: int) -> None:
    """Convert the ONNX file to a specific opset version."""
    model = onnx.load(path, load_external_data=False)
    model = onnx.version_converter.convert_version(model, opset_version)
    onnx.save(model, path)


def export_compat(
    model: torch.nn.Module
    | torch.export.ExportedProgram
    | torch.jit.ScriptModule
    | torch.jit.ScriptFunction,
    args: tuple[Any, ...],
    f: str | os.PathLike | None = None,
    *,
    kwargs: dict[str, Any] | None = None,
    export_params: bool = True,
    verbose: bool | None = None,
    input_names: Sequence[str] | None = None,
    output_names: Sequence[str] | None = None,
    opset_version: int | None = None,
    dynamic_axes: Mapping[str, Mapping[int, str]]
    | Mapping[str, Sequence[int]]
    | None = None,
    dynamic_shapes: dict[str, Any] | tuple[Any, ...] | list[Any] | None = None,
    keep_initializers_as_inputs: bool = False,
    external_data: bool = True,
    report: bool = False,
    verify: bool = False,
    profile: bool = False,
    dump_exported_program: bool = False,
    artifacts_dir: str | os.PathLike = ".",
    fallback: bool = False,
    **_,
) -> _onnx_program.ONNXProgram | None:
    if isinstance(model, torch.export.ExportedProgram):
        # The model is already an ExportedProgram, so args, kwargs, and
        # dynamic_shapes are not used
        dynamic_shapes = dynamic_shapes or {}
    else:
        args, kwargs = _get_torch_export_args(args, kwargs)
        if dynamic_shapes is None and dynamic_axes is not None:
            dynamic_shapes = _from_dynamic_axes_to_dynamic_shapes(
                model, dynamic_axes, input_names
            )

    should_convert_version = False

    try:
        onnx_program = _core.export(
            model,
            args,
            kwargs,
            registry=None,
            dynamic_shapes=dynamic_shapes,
            input_names=input_names,
            output_names=output_names,
            profile=profile,
            report=report,
            verify=verify,
            dump_exported_program=dump_exported_program,
            artifacts_dir=artifacts_dir,
            verbose=verbose,
        )

        if f is not None:
            # Save the initializers as external data when requested (the default)
            # to reduce the size of the ONNX file
            onnx_program.save(
                f,
                include_initializers=export_params,
                keep_initializers_as_inputs=keep_initializers_as_inputs,
                external_data=external_data,
            )
        if (
            opset_version is not None
            and opset_version != onnx_program.model.opset_imports.get("")
        ):
            should_convert_version = True

    except Exception as e:
        if fallback:
            if verbose is not False:
                print(
                    "[torch.onnx] Falling back to legacy torch.onnx.export due "
                    f"to the following error: {e}",
                )
            torch.onnx.utils.export(
                model,  # type: ignore[arg-type]
                args,
                f,  # type: ignore[arg-type]
                kwargs=kwargs,
                export_params=export_params,
                input_names=input_names,
                output_names=output_names,
                opset_version=17,  # TODO(justinchuby): Hard coded to 17 for now
                dynamic_axes=dynamic_axes,
                keep_initializers_as_inputs=keep_initializers_as_inputs,
            )
            onnx_program = None
            if opset_version is None:
                opset_version = 18
            if opset_version != 17:
                should_convert_version = True
        else:
            raise

    if f is not None and should_convert_version:
        assert opset_version is not None
        if verbose is not False:
            print(
                f"[torch.onnx] Converting the ONNX file to opset version {opset_version}..."
            )
        _convert_version(f, opset_version)

    return onnx_program
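
A hedged usage sketch; the supported entry point is ``torch.onnx.export``, which forwards here, and ``Model`` is a hypothetical ``nn.Module``::

    program = export_compat(
        Model(),
        (torch.randn(1, 3),),
        "model.onnx",
        opset_version=18,
        fallback=True,  # retry with the legacy TorchScript exporter on failure
    )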

torch/onnx/_internal/exporter/_core.py (new file, +1344)
(Diff suppressed because the file is too large.)

torch/onnx/_internal/exporter/_decomp.py (new file, +74)
@@ -0,0 +1,74 @@
"""Build decomp table from PyTorch."""

# mypy: allow-untyped-defs
from __future__ import annotations

from typing import Callable, TYPE_CHECKING

import torch
import torch._ops


if TYPE_CHECKING:
    from torch.onnx._internal.exporter import _registration


def get_onnx_implemented_overloads(
    registry: _registration.ONNXRegistry,
) -> list[torch._ops.OperatorBase]:
    """
    Creates a list of OperatorBase objects for the PyTorch operations that have ONNX implementations.

    Args:
        registry: The ONNX registry for PyTorch.

    Returns:
        A list of OperatorBase objects representing ONNX-supported PyTorch operations.
    """
    registered_ops: list[torch._ops.OperatorBase] = []
    for op_namespace in (torch.ops.aten, torch.ops.prims):
        op_names = dir(op_namespace)
        for op_name in op_names:
            op_overload_packet = getattr(op_namespace, op_name)
            if not isinstance(op_overload_packet, torch._ops.OpOverloadPacket):
                continue

            for overload_name in op_overload_packet.overloads():
                op_overload = getattr(op_overload_packet, overload_name)
                if registry.is_registered(op_overload):
                    registered_ops.append(op_overload)
    return registered_ops


def create_onnx_friendly_decomposition_table(
    registry,
) -> dict[torch._ops.OperatorBase, Callable]:
    """
    This function creates a dictionary of op overloads and their decomposition functions
    for ops that do not have ONNX symbolic functions. If an op already has an ONNX symbolic function,
    its decomposition function is excluded from the table. The decomposition table is a subset of PyTorch's
    built-in aten-to-aten decompositions.

    Args:
        registry: The ONNX registry for PyTorch.

    Returns:
        Dict[torch._ops.OperatorBase, Callable]: A dictionary that maps op overloads to their corresponding
        decomposition functions.
    """
    decomposition_table: dict[torch._ops.OperatorBase, Callable] = {}
    onnx_registered_ops = set(get_onnx_implemented_overloads(registry))

    # NOTE: If we import torch._decomp, we will get RuntimeError: Only a single
    # TORCH_LIBRARY can be used to register the namespace nvprims; please put all of your
    # definitions in a single TORCH_LIBRARY block.
    for op_overload, decomp_fn in torch._decomp.decomposition_table.items():  # type: ignore[attr-defined]
        # Skip the decomposition for op_overload as long as op_overload has a corresponding ONNX
        # symbolic function.
        # NOTE: Do not skip torch._refs decomps. They are fine because otherwise the model is
        # not exportable anyways.
        if op_overload in onnx_registered_ops:
            continue
        decomposition_table[op_overload] = decomp_fn

    return decomposition_table
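
A hedged sketch of how the table is consumed (this mirrors ``decompose_with_registry`` in ``_fx_passes.py`` below; ``registry`` and ``exported_program`` are assumed to exist)::

    table = create_onnx_friendly_decomposition_table(registry)
    # Decompose only ops without an ONNX implementation; registered ops survive.
    exported_program = exported_program.run_decompositions(table)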

torch/onnx/_internal/exporter/_dispatching.py (new file, +345)
@@ -0,0 +1,345 @@
# mypy: allow-untyped-defs
from __future__ import annotations

import logging
from typing import Sequence

import onnxscript
from onnxscript import ir

import torch
import torch.fx
from torch.onnx._internal.exporter import _registration, _schemas


logger = logging.getLogger(__name__)

# Define utilities to convert PyTorch data types so users do not need to specify manually
_TORCH_DTYPE_TO_ONNX_COMPATIBLE: dict[torch.dtype, ir.DataType] = {
    torch.bfloat16: ir.DataType.BFLOAT16,
    torch.bool: ir.DataType.BOOL,
    torch.complex128: ir.DataType.DOUBLE,
    torch.complex64: ir.DataType.FLOAT,
    torch.float16: ir.DataType.FLOAT16,
    torch.float32: ir.DataType.FLOAT,
    torch.float64: ir.DataType.DOUBLE,
    torch.float8_e4m3fn: ir.DataType.FLOAT8E4M3FN,
    torch.float8_e4m3fnuz: ir.DataType.FLOAT8E4M3FNUZ,
    torch.float8_e5m2: ir.DataType.FLOAT8E5M2,
    torch.float8_e5m2fnuz: ir.DataType.FLOAT8E5M2FNUZ,
    torch.int16: ir.DataType.INT16,
    torch.int32: ir.DataType.INT32,
    torch.int64: ir.DataType.INT64,
    torch.int8: ir.DataType.INT8,
    torch.uint8: ir.DataType.UINT8,
}


def _torch_dtype_to_onnx_compatible_dtype(dtype: torch.dtype) -> ir.DataType:
    return _TORCH_DTYPE_TO_ONNX_COMPATIBLE[dtype]


def _attribute_type_compatible_with_arg(
    attr: _schemas.AttributeParameter,
    value: ir.Value | int | float | bool | Sequence[int] | Sequence[float] | None,
) -> bool:
    """Check if the attribute type is compatible with the argument."""
    if isinstance(value, bool):
        return attr.type is ir.AttributeType.INT
    if isinstance(value, str):
        return attr.type is ir.AttributeType.STRING
    if isinstance(value, int):
        return attr.type in {ir.AttributeType.INT, ir.AttributeType.FLOAT}
    if isinstance(value, float):
        return attr.type is ir.AttributeType.FLOAT
    if isinstance(value, complex):
        return False
    if isinstance(value, Sequence):
        if attr.type is ir.AttributeType.INTS:
            return all(isinstance(i, int) for i in value)
        if attr.type is ir.AttributeType.FLOATS:
            return all(isinstance(i, (int, float)) for i in value)
    if isinstance(value, torch.dtype):
        return attr.type is ir.AttributeType.INT
    if isinstance(value, (torch.device, torch.memory_format, torch.layout)):
        return attr.type is ir.AttributeType.STRING
    if value is None and not attr.required:
        # An optional attribute is not supplied
        return True
    return False


def _param_type_compatible_with_arg(
    param: _schemas.Parameter,
    value: ir.TypeProtocol
    | str
    | int
    | float
    | complex
    | Sequence[int]
    | Sequence[float]
    | None,
    assigned_types: dict[str, ir.TypeProtocol],
) -> bool:
    # Handle Python types first
    if isinstance(value, bool):  # noqa: SIM102
        if param.type_constraint.allowed_types & {ir.TensorType(ir.DataType.BOOL)}:
            return True
    if isinstance(value, int) and param.type_constraint.allowed_types & {
        ir.TensorType(ir.DataType.INT4),
        ir.TensorType(ir.DataType.INT8),
        ir.TensorType(ir.DataType.INT16),
        ir.TensorType(ir.DataType.INT32),
        ir.TensorType(ir.DataType.INT64),
        # Int inputs can be cast to a float too
        ir.TensorType(ir.DataType.FLOAT8E4M3FN),
        ir.TensorType(ir.DataType.FLOAT8E4M3FNUZ),
        ir.TensorType(ir.DataType.FLOAT8E5M2),
        ir.TensorType(ir.DataType.FLOAT8E5M2FNUZ),
        ir.TensorType(ir.DataType.FLOAT16),
        ir.TensorType(ir.DataType.FLOAT),
        ir.TensorType(ir.DataType.DOUBLE),
    }:
        return True
    if isinstance(value, float) and param.type_constraint.allowed_types & {
        ir.TensorType(ir.DataType.FLOAT8E4M3FN),
        ir.TensorType(ir.DataType.FLOAT8E4M3FNUZ),
        ir.TensorType(ir.DataType.FLOAT8E5M2),
        ir.TensorType(ir.DataType.FLOAT8E5M2FNUZ),
        ir.TensorType(ir.DataType.FLOAT16),
        ir.TensorType(ir.DataType.FLOAT),
        ir.TensorType(ir.DataType.DOUBLE),
    }:
        return True
    if isinstance(value, complex) and param.type_constraint.allowed_types & {
        ir.TensorType(ir.DataType.FLOAT),
        ir.TensorType(ir.DataType.DOUBLE),
        ir.TensorType(ir.DataType.COMPLEX64),
        ir.TensorType(ir.DataType.COMPLEX128),
    }:
        return True
    if isinstance(value, str):  # noqa: SIM102
        if param.type_constraint.allowed_types & {ir.TensorType(ir.DataType.STRING)}:
            return True
    if isinstance(value, (list, tuple)):
        if param.type_constraint.allowed_types & {
            ir.TensorType(ir.DataType.INT32),
            ir.TensorType(ir.DataType.INT64),
            ir.TensorType(ir.DataType.FLOAT),
            ir.TensorType(ir.DataType.DOUBLE),
            ir.SequenceType(ir.TensorType(ir.DataType.INT32)),
            ir.SequenceType(ir.TensorType(ir.DataType.INT64)),
            ir.SequenceType(ir.TensorType(ir.DataType.FLOAT)),
            ir.SequenceType(ir.TensorType(ir.DataType.DOUBLE)),
        } and all(isinstance(i, int) for i in value):
            # We will just allow any fx node and trust that the overload handles it
            return True
        if param.type_constraint.allowed_types & {
            ir.TensorType(ir.DataType.FLOAT),
            ir.TensorType(ir.DataType.DOUBLE),
            ir.SequenceType(ir.TensorType(ir.DataType.FLOAT)),
            ir.SequenceType(ir.TensorType(ir.DataType.DOUBLE)),
        } and all(isinstance(i, (int, float)) for i in value):
            # We will just allow any fx node and trust that the overload handles it
            return True
    if value is None and not param.required:
        # An optional parameter is not supplied
        return True

    if not isinstance(value, ir.TypeProtocol):
        return False

    # Then check tensor types
    if param.type_constraint.name in assigned_types:
        # If a typevar is already bound, check if the value has the same type
        assigned_type = assigned_types[param.type_constraint.name]
        return assigned_type == value
    # If the typevar is not bound, bind it to the value type
    if value in param.type_constraint.allowed_types:
        # TODO: Maybe just check dtype? Being more strict here for now
        assigned_types[param.type_constraint.name] = value
        return True
    return False


def _get_type_from_tensor(
    tensor: torch.Tensor | Sequence[torch.Tensor],
) -> ir.TypeProtocol:
    if isinstance(tensor, torch.Tensor):
        return ir.TensorType(_torch_dtype_to_onnx_compatible_dtype(tensor.dtype))
    first_tensor = next((item for item in tensor if item is not None), None)
    if first_tensor is None:
        return ir.SequenceType(ir.TensorType(ir.DataType.UNDEFINED))
    return ir.SequenceType(
        ir.TensorType(_torch_dtype_to_onnx_compatible_dtype(first_tensor.dtype))
    )


def _get_first_tensor_in_node_list(
    nodes: Sequence[torch.fx.Node | None],
) -> torch.Tensor | None:
    for node in nodes:
        if (
            node is not None
            and "val" in node.meta
            and isinstance(node.meta["val"], torch.Tensor)
        ):
            return node.meta["val"]
    return None


def _get_named_fx_node_args(node: torch.fx.Node) -> dict[str, torch.fx.node.Argument]:
    # FIXME: node.target may not have a schema
    torch_schema: torch.FunctionSchema = node.target._schema  # type: ignore[union-attr]
    node_args = {}
    for arg, schema_arg in zip(node.args, torch_schema.arguments):
        node_args[schema_arg.name] = arg

    node_args.update(node.kwargs)
    return node_args


def get_matching_overload(
    node: torch.fx.Node,
    overloads: Sequence[onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction],
) -> tuple[onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction | None, str]:
    """Get the overload that matches the node's arguments.

    Args:
        node: The node to match.
        overloads: The overloads to match against.

    Returns:
        A tuple containing the matched overload and a string describing the reason for failure or success.
    """
    named_args = _get_named_fx_node_args(node)
    # FIXME: node.target may be a builtin and not have a schema
    # FIXME: Handle when we don't know the names of the arguments
    schema_args: dict[str, torch.Argument] = {
        arg.name: arg
        for arg in node.target._schema.arguments  # type: ignore[union-attr]
    }
    failure_messages: list[str] = []
    for overload in overloads:
        assigned_types: dict[str, ir.TypeProtocol] = {}
        fail_reason = ""
        if not hasattr(overload, "signature"):
            # When an overload does not have a signature, we assume it is a custom op and should be matched
            return (
                overload,
                "The overload does not have a signature. Assuming it is a custom op and matching it.",
            )
        for param in overload.signature:
            if param.name not in schema_args and param.required:
                # We don't need to handle variadic inputs as there are none.
                # A required parameter is not supplied.
                fail_reason = "Required parameter not supplied"
                break

            # Get the argument
            if param.name in named_args:
                # Provided in Node args
                arg = named_args[param.name]
            elif (
                param.name in schema_args
                and schema_args[param.name].has_default_value()
            ):
                # Provided in schema args
                arg = schema_args[param.name].default_value
            elif param.has_default():
                # Provided in the ONNX op definition
                arg = param.default
            else:
                fail_reason = "Parameter not provided"
                break

            if isinstance(param, _schemas.Parameter):
                if isinstance(arg, torch.Tensor):
                    arg = _get_type_from_tensor(arg)  # type: ignore[assignment]
                if isinstance(arg, (list, tuple)) and any(
                    isinstance(t, torch.fx.Node) for t in arg
                ):
                    first_tensor = _get_first_tensor_in_node_list(arg)
                    assert first_tensor is not None
                    # FIXME: Handle symfloat here
                    arg = ir.SequenceType(_get_type_from_tensor(first_tensor))  # type: ignore[assignment]
                elif isinstance(arg, torch.fx.Node):
                    meta_val = arg.meta["val"]
                    arg = _get_type_from_tensor(meta_val)  # type: ignore[assignment]
                # TODO: Handle None attributes
                # FIXME: Handle symfloat etc.
                # Handle tensors and Python values
                if not _param_type_compatible_with_arg(param, arg, assigned_types):  # type: ignore[arg-type]
                    fail_reason = (
                        f"Parameter type not compatible with argument: param=`{param}`, "
                        f"assigned_types=`{assigned_types}`, arg=`{arg}`"
                    )
                    break
            elif isinstance(param, _schemas.AttributeParameter):
                if not _attribute_type_compatible_with_arg(param, arg):  # type: ignore[arg-type]
                    fail_reason = f"Attribute type not compatible with argument: param=`{param}`, arg=`{arg}`"
                    break
        if not fail_reason:
            return overload, "Successfully matched overload"
        else:
            failure_messages.append(
                f"- Failed to match overload `{overload}`: {fail_reason}"
            )
    return (
        None,
        f"No overload matched the node `{node.format_node()}`.\n"
        + "\n".join(failure_messages),
    )


def _arg_has_complex_dtype(arg) -> bool:
    """Check if the node has complex dtype recursively."""
    if (
        isinstance(arg, torch.fx.Node)
        and "val" in arg.meta
        and isinstance(arg.meta["val"], torch.Tensor)
        and torch.is_complex(arg.meta["val"])
    ):
        return True
    elif isinstance(arg, list):
        return any(_arg_has_complex_dtype(item) for item in arg)
    return False


def dispatch(
    node: torch.fx.Node, registry: _registration.ONNXRegistry
) -> tuple[onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction | None, str]:
    """Dispatch a node to an ONNX function based on the node's target and the ONNX registry.

    Args:
        node: The node to dispatch.
        registry: The ONNX registry to use for dispatching.

    Returns:
        A tuple containing the matched ONNX function and a string describing the reason for failure or success.
    """
    # TODO: Handle when node does not have a target
    decomp_metas = registry.get_decomps(node.target)  # type: ignore[arg-type]
    # Determine if the node has complex inputs.
    is_complex = any(_arg_has_complex_dtype(arg) for arg in node.args) or any(
        _arg_has_complex_dtype(arg) for arg in node.kwargs.values()
    )
    if is_complex:
        decomp_metas = [decomp for decomp in decomp_metas if decomp.is_complex]
        if not decomp_metas:
            return None, "No decompositions registered for the complex-valued input"
    else:
        decomp_metas = [decomp for decomp in decomp_metas if not decomp.is_complex]
        if not decomp_metas:
            return None, "No decompositions registered for the real-valued input"

    if len(decomp_metas) == 1:
        return (
            decomp_metas[0].onnx_function,
            "Fast path: Only one decomposition is defined",
        )

    overload, message = get_matching_overload(
        node, [decomp.onnx_function for decomp in decomp_metas]
    )
    return overload, message
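
For context, a hedged sketch of driving ``dispatch`` over an FX graph; ``registry`` is an assumed ``ONNXRegistry`` instance::

    for fx_node in exported_program.graph.nodes:
        if fx_node.op != "call_function":
            continue
        onnx_function, message = dispatch(fx_node, registry)
        if onnx_function is None:
            raise RuntimeError(f"Unsupported op {fx_node.target}: {message}")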

torch/onnx/_internal/exporter/_fx_passes.py (new file, +72)
@@ -0,0 +1,72 @@
# mypy: allow-untyped-defs
from __future__ import annotations

import torch
import torch.export
import torch.fx
from torch.onnx._internal.exporter import _decomp, _registration
from torch.onnx._internal.fx import diagnostics, passes


_ATEN_ASSERTION_TARGETS = frozenset(
    {
        torch.ops.aten.sym_constrain_range_for_size.default,
        torch.ops.aten._assert_async.msg,
    }
)


def decompose_with_registry(
    exported_program: torch.export.ExportedProgram, registry: _registration.ONNXRegistry
) -> torch.export.ExportedProgram:
    """Decompose the exported program with the given registry.

    This function is factored out so it shows up clearly in profiler results.
    """
    decomp_table = _decomp.create_onnx_friendly_decomposition_table(registry)
    onnx_registered_ops = set(_decomp.get_onnx_implemented_overloads(registry))
    # Try to preserve some known CompositeImplicitAutograd ops
    aten = torch.ops.aten
    to_preserve = {
        aten._upsample_bilinear2d_aa.default,
        aten._upsample_nearest_exact1d.vec,
        aten._upsample_nearest_exact2d.vec,
        aten._upsample_nearest_exact3d.vec,
        aten.group_norm.default,
        aten.linear.default,
        aten.upsample_bilinear2d.default,
        aten.upsample_bilinear2d.vec,
        aten.upsample_linear1d.default,
        aten.upsample_linear1d.vec,
        aten.upsample_nearest1d.default,
        aten.upsample_nearest1d.vec,
        aten.upsample_nearest2d.default,
        aten.upsample_nearest2d.vec,
        aten.upsample_nearest3d.default,
        aten.upsample_nearest3d.vec,
        aten.upsample_trilinear3d.default,
        aten.upsample_trilinear3d.vec,
    }
    # We can only preserve implemented ops
    can_preserve = tuple(to_preserve.intersection(onnx_registered_ops))
    return exported_program.run_decompositions(decomp_table, _preserve_ops=can_preserve)


def insert_type_promotion_nodes(
    graph_module: torch.fx.GraphModule,
) -> torch.fx.GraphModule:
    """In-place pass to insert explicit type promotion nodes."""
    diagnostic_context = diagnostics.DiagnosticContext(
        "torch.onnx.export",
        torch.__version__,
    )
    return passes.InsertTypePromotion(diagnostic_context, graph_module).run()


def remove_assertion_nodes(graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Remove all assertion and check nodes from the FX graph."""
    for node in graph_module.graph.nodes:
        if node.op == "call_function" and node.target in _ATEN_ASSERTION_TARGETS:
            graph_module.graph.erase_node(node)
    graph_module.recompile()
    return graph_module
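
A hedged sketch of how these passes might chain; the actual ordering lives in ``_core.py`` (suppressed above), and ``registry`` is assumed::

    exported_program = decompose_with_registry(exported_program, registry)
    graph_module = remove_assertion_nodes(exported_program.graph_module)
    graph_module = insert_type_promotion_nodes(graph_module)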

torch/onnx/_internal/exporter/_ir_passes.py (new file, +41)
@@ -0,0 +1,41 @@
# mypy: allow-untyped-defs
from __future__ import annotations

import logging
from typing import Sequence

from onnxscript import ir


logger = logging.getLogger(__name__)


def rename_inputs(model: ir.Model, new_names: Sequence[str]) -> None:
    # TODO: Ensure the names do not have duplicates
    for input, new_name in zip(model.graph.inputs, new_names):
        input.metadata_props["pkg.torch.onnx.original_node_name"] = str(input.name)
        input.name = new_name


def rename_outputs(model: ir.Model, new_names: Sequence[str]) -> None:
    for output, new_name in zip(model.graph.outputs, new_names):
        output.metadata_props["pkg.torch.onnx.original_node_name"] = str(output.name)
        output.name = new_name


def add_torchlib_common_imports(model: ir.Model) -> None:
    """Hack to add torchlib common imports to the model."""

    try:
        # TODO(justinchuby): Remove this hack and improve onnxscript
        from onnxscript.function_libs.torch_lib.ops import common as common_ops

        model.opset_imports["pkg.onnxscript.torch_lib.common"] = 1
        rank_func = ir.serde.deserialize_function(common_ops.Rank.to_function_proto())
        is_scalar_func = ir.serde.deserialize_function(
            common_ops.IsScalar.to_function_proto()
        )
        model.functions[rank_func.identifier()] = rank_func
        model.functions[is_scalar_func.identifier()] = is_scalar_func
    except Exception:
        logger.exception("Failed to add torchlib common imports to the model.")
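
A minimal usage sketch; ``model`` is assumed to be an onnxscript ``ir.Model`` produced by the exporter::

    rename_inputs(model, ["image"])
    rename_outputs(model, ["logits"])
    # The original names survive in
    # metadata_props["pkg.torch.onnx.original_node_name"].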

torch/onnx/_internal/exporter/_isolated.py (new file, +55)
@@ -0,0 +1,55 @@
"""Isolated calls to methods that may segfault."""

# mypy: allow-untyped-defs
from __future__ import annotations

import multiprocessing
import os
import warnings
from typing import Callable


_IS_WINDOWS = os.name == "nt"


def _call_function_and_return_exception(func, args, kwargs):
    """Call the function and return the exception if one is raised."""
    try:
        return func(*args, **kwargs)
    except Exception as e:
        return e


def safe_call(func: Callable, *args, **kwargs):
    """Call a function in a separate process.

    Args:
        func: The function to call.
        args: The positional arguments to pass to the function.
        kwargs: The keyword arguments to pass to the function.

    Returns:
        The return value of the function.

    Raises:
        Exception: If the function raised an exception.
    """
    if _IS_WINDOWS:
        # On Windows, we cannot create a new process with fork.
        warnings.warn(
            f"A new process is not created for {func} on Windows.", stacklevel=1
        )
        return func(*args, **kwargs)

    with multiprocessing.get_context("fork").Pool(1) as pool:
        # It is important to fork a process here to prevent the main logic from
        # running again when the user does not place it under an
        # `if __name__ == "__main__":` block.
        result = pool.apply_async(
            _call_function_and_return_exception, (func, args, kwargs)
        )
        result = result.get(timeout=5)
    if isinstance(result, Exception):
        raise result
    return result
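
A minimal usage sketch with a hypothetical function; note the child process must return within the 5-second timeout hard-coded above::

    def double(x):
        return x * 2

    print(safe_call(double, 21))  # 42, computed in a forked child process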

torch/onnx/_internal/exporter/_onnx_program.py (new file, +288)
@@ -0,0 +1,288 @@
# mypy: allow-untyped-defs
# mypy: disable-error-code="attr-defined,name-defined"
from __future__ import annotations


__all__ = ["ONNXProgram"]

import gc
import logging
import os
import pathlib
import tempfile
import textwrap
from typing import Callable, IO, Sequence, TYPE_CHECKING

import torch
from torch.onnx._internal import _lazy_import
from torch.utils import _pytree as pytree


onnx = _lazy_import.onnx
ir = _lazy_import.onnxscript_ir


if TYPE_CHECKING:
    import onnxruntime as ort

logger = logging.getLogger(__name__)


def _ort_session_initializer(model: str | bytes) -> ort.InferenceSession:
    """Initialize an ONNX Runtime inference session with the specified model."""
    import onnxruntime as ort

    session_options = ort.SessionOptions()
    session_options.log_severity_level = 3  # 3: Error
    possible_providers = (
        "CUDAExecutionProvider",
        "CPUExecutionProvider",
    )
    available_providers = set(ort.get_available_providers())
    providers = [
        provider for provider in possible_providers if provider in available_providers
    ]
    return ort.InferenceSession(
        model, providers=providers, sess_options=session_options
    )


class ONNXProgram:
    """A substitute class for `torch.onnx.ONNXProgram`."""

    def __init__(self, model: ir.Model, exported_program: torch.export.ExportedProgram):
        self.model: ir.Model = model
        self.exported_program = exported_program
        self._inference_session: ort.InferenceSession | None = None
        self._tempdir: tempfile.TemporaryDirectory | None = None

    def __repr__(self) -> str:
        return f"""\
ONNXProgram(
    model=
{textwrap.indent(str(self.model), ' ' * 8)}
    ,
    exported_program=
{textwrap.indent(str(self.exported_program), ' ' * 8)}
)
"""

    def __call__(self, *args, **kwargs) -> Sequence[torch.Tensor]:
        """Run the ONNX model with the same arguments you would provide to the GraphModule."""
        import onnxruntime as ort

        flatten_args = _process_args(args, kwargs)

        if self._inference_session is None:
            self.initialize_inference_session()

        assert self._inference_session is not None

        # We don't expect non-tensor inputs
        ort_input = {
            k.name: v.numpy(force=True)
            for k, v in zip(self.model.graph.inputs, flatten_args)
        }
        run_options = ort.RunOptions()
        run_options.log_severity_level = 3  # 3: Error
        logger.debug("Running the inference session with %s arguments.", len(ort_input))
        outputs = self._inference_session.run(None, ort_input, run_options=run_options)
        logger.debug("Inference session run completed.")
        # TODO(justinchuby): Maybe output complex tensors as needed
        return tuple(torch.from_numpy(output) for output in outputs)

    @property
    def model_proto(self) -> onnx.ModelProto:
        """Compatibility property for `torch.onnx.ONNXProgram.model_proto`."""
        return ir.serde.serialize_model(self.model)

    def save(
        self,
        destination: str | os.PathLike | IO[bytes],
        *,
        include_initializers: bool = True,
        keep_initializers_as_inputs: bool = False,
        external_data: bool | None = None,
        **_,
    ):
        """Save the ONNX model to the specified destination.

        When `external_data` is `True` or the model is larger than 2GB,
        the weights are saved as external data in a separate file.

        Args:
            destination: The path to save the ONNX model to.
            include_initializers: Whether to include the initializers in the saved model.
            keep_initializers_as_inputs: Whether to keep the initializers as inputs in the saved model.
                If `True`, the initializers are added as inputs to the model, which means they can be
                overwritten by providing the initializers as model inputs.
            external_data: Whether to save the weights as external data in a separate file.

        Raises:
            TypeError: If `external_data` is `True` and `destination` is not a file path.
        """
        if not include_initializers:
            self.model.graph.initializers.clear()
            logger.warning(
                "The initializers have been removed from the model. This is destructive. "
                "Developers: Please implement ir.Model copy() and remove initializers on the copied model."
            )
        if keep_initializers_as_inputs:
            self.model.graph.inputs.extend(self.model.graph.initializers.values())  # type: ignore[arg-type]
            logger.warning(
                "The initializers have been added as inputs to the model. This is destructive. "
                "Developers: Please implement ir.Model copy() and remove initializers on the copied model."
            )
        proto = ir.serde.serialize_model(self.model)
        byte_size = proto.ByteSize()
        model_too_large = byte_size >= 1 << 31
        if external_data or model_too_large:
            # TODO: Create an IR pass to handle external tensors conversion
            if model_too_large:
                logger.warning(
                    "The serialized ONNX model is larger than 2GB (%s). "
                    "Saving the weights as external data in a separate file.",
                    byte_size,
                )
            if not isinstance(destination, (str, os.PathLike)):
                raise TypeError(
                    "Saving the weights as external data is only supported when destination is a file path"
                )
            destination_path = pathlib.Path(destination)
            # Store the external data next to the model file
            data_path = f"{destination_path.name}.data"
            onnx.save_model(
                proto,
                destination,
                save_as_external_data=True,
                location=data_path,
            )
        else:
            onnx.save_model(proto, destination)

    def initialize_inference_session(
        self,
        initializer: Callable[
            [str | bytes], ort.InferenceSession
        ] = _ort_session_initializer,
    ) -> None:
        """Initialize the ONNX Runtime inference session.

        Args:
            initializer: The function to initialize the ONNX Runtime inference
                session with the specified model. By default, it uses the
                :func:`_ort_session_initializer` function.
        """
        # TODO(justinchuby): Allow different inference options
        logger.debug("Initializing the inference session.")
        proto = ir.serde.serialize_model(self.model)
        byte_size = proto.ByteSize()
        model_too_large = byte_size >= 1 << 31

        if model_too_large:
            logger.debug(
                "The serialized ONNX model is larger than 2GB (%s).", byte_size
            )
            # Save the model to a temporary file if too large
            self._tempdir = tempfile.TemporaryDirectory(ignore_cleanup_errors=True)
            model_path = os.path.join(self._tempdir.name, "model.onnx")
            data_path = "model.onnx.data"
            onnx.save_model(
                proto,
                model_path,
                save_as_external_data=True,
                location=data_path,
            )
            model = model_path
        else:
            model = proto.SerializeToString()  # type: ignore[assignment]

        self._inference_session = initializer(model)
        logger.debug("Inference session initialized.")

    def release(self) -> None:
        """Release the inference session.

        You may call this method to release the resources used by the inference session.
        """
        # Release the inference session first so that the model file can be deleted
        if self._inference_session is not None:
            self._inference_session = None
            gc.collect()
        if self._tempdir is not None:
            self._tempdir.cleanup()
            self._tempdir = None
def _process_args(args, kwargs) -> tuple[torch.Tensor, ...]:
    """Process input arguments for the ONNX model."""
    args = _flatten_inputs(args, kwargs)
    args = _remove_none_from_inputs(args)
    args = _remove_non_tensor(args)
    args = _convert_complex_to_real_representation(args)
    return args


def _flatten_inputs(model_args, model_kwargs):
    flattened_args, _ = pytree.tree_flatten((model_args, model_kwargs))
    return flattened_args


def _remove_none_from_inputs(model_args):
    return tuple(arg for arg in model_args if arg is not None)


def _remove_non_tensor(model_args):
    """Remove the non-tensor input arguments.

    Dynamo does not support non-tensor input arguments (https://github.com/pytorch/pytorch/issues/99534).

    Specifically, it puts the input into the graph as an empty node that nothing consumes.
    The concrete value is embedded into the graph as a constant arg of a target node. Meta
    suggests in this case that one should rewrite the model code to make it a tensor if the
    input value is supposed to change at runtime. We might need to further investigate
    the feasibility of that suggestion.

    For example,

        def func(x, b=1.0):
            y = x + b
            z = y.relu()
            return (y, z)

        x = torch.randn(1, 1, 2, dtype=torch.float32)
        gm_fun, _ = dynamo.export(func, x, b=8.0, aten_graph=True, tracing_mode="real")

        # class GraphModule(torch.nn.Module):
        #     def forward(self, x, b):
        #         arg0: f32[1, 1, 2], arg1, = fx_pytree.tree_flatten_spec(([x, b], {}), self._in_spec)
        #         # File: path/to/pytorch/test_constant_input.py:5, code: y = x + b
        #         add_tensor: f32[1, 1, 2] = torch.ops.aten.add.Tensor(arg0, 8.0);  arg0 = None

        #         # File: path/to/pytorch/test_constant_input.py:6, code: z = y.relu()
        #         relu_default: f32[1, 1, 2] = torch.ops.aten.relu.default(add_tensor)
        #         return pytree.tree_unflatten([add_tensor, relu_default], self._out_spec)

    The empty torch.fx.Node input leads to a mismatched number of inputs with PyTorch,
    since it is ignored in the ONNX graph. Thus, we delete the useless input here.
    """

    return tuple(
        arg for arg in model_args if not isinstance(arg, (int, float, bool, str))
    )


def _convert_complex_to_real_representation(model_args):
    """Convert complex dtype tensors to real representation tensors.

    ONNX does not support complex dtype tensors. Thus, we convert complex dtype tensors
    to real representation tensors (i.e., float dtype tensors with an extra dimension
    representing the real and imaginary parts of the complex number).
    """
    return tuple(
        torch.view_as_real(arg.resolve_conj())
        if isinstance(arg, torch.Tensor) and arg.is_complex()
        else arg
        for arg in model_args
    )
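# For instance, `_convert_complex_to_real_representation` maps a complex tensor
# of shape (2,) to a float tensor of shape (2, 2), with the real and imaginary
# parts along the trailing axis:
#
#   >>> z = torch.tensor([1 + 2j, 3 - 4j])
#   >>> _convert_complex_to_real_representation((z,))[0]
#   tensor([[ 1.,  2.],
#           [ 3., -4.]])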
275
torch/onnx/_internal/exporter/_registration.py
Normal file
@ -0,0 +1,275 @@
"""Module for handling ATen to ONNX functions registration.
|
||||
|
||||
https://github.com/pytorch/pytorch/blob/6aa5bb1a76dee8112f1a9e7c194c790b5cdc6462/torch/onnx/_internal/fx/registration.py
|
||||
"""
|
||||
|
||||
# NOTE: Why do we need a different registry than the one in torchlib?
|
||||
# The registry in torchlib is used to register functions that are already implemented in
|
||||
# torchlib, and is designed to be a static singleton. It does not take into account custom ops or different
|
||||
# opsets etc. The registry implemented for the exporter is designed to be modifiable at
|
||||
# export time by users, and is designed with dispatching in mind.
|
||||
|
||||
# mypy: allow-untyped-defs
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import logging
|
||||
import math
|
||||
import operator
|
||||
import types
|
||||
import typing
|
||||
from typing import Callable, Literal, Mapping, Union
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
import torch
|
||||
import torch._ops
|
||||
from torch.onnx._internal.exporter import _schemas
|
||||
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
import onnxscript
|
||||
from onnxscript.function_libs.torch_lib import registration as torchlib_registration
|
||||
|
||||
_DEFAULT_OPSET_VERSION = 18
|
||||
|
||||
|
||||
TorchOp: TypeAlias = Union[torch._ops.OpOverload, types.BuiltinFunctionType, Callable]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class OnnxDecompMeta:
|
||||
"""A wrapper of onnx-script function with additional metadata.
|
||||
|
||||
onnx_function: The onnx-script function from torchlib.
|
||||
fx_target: The PyTorch node callable target.
|
||||
is_custom: Whether the function is a custom function.
|
||||
is_complex: Whether the function is a function that handles complex valued inputs.
|
||||
device: The device the function is registered to. If None, it is registered to all devices.
|
||||
"""
|
||||
|
||||
onnx_function: onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction
|
||||
fx_target: TorchOp
|
||||
is_custom: bool = False
|
||||
is_complex: bool = False
|
||||
device: Literal["cuda", "cpu"] | str | None = None # noqa: PYI051
|
||||
|
||||
|
||||
def _get_overload(qualified_name: str) -> torch._ops.OpOverload | None:
|
||||
"""Obtain the torch op from <namespace>::<op_name>[.<overload>]"""
|
||||
# TODO(justinchuby): Handle arbitrary custom ops
|
||||
namespace, opname_overload = qualified_name.split("::")
|
||||
op_name, *maybe_overload = opname_overload.split(".", 1)
|
||||
if namespace == "_operator":
|
||||
# Builtin functions
|
||||
return getattr(operator, op_name)
|
||||
if namespace == "math":
|
||||
return getattr(math, op_name)
|
||||
if namespace == "torchvision":
|
||||
try:
|
||||
import torchvision.ops # type: ignore[import-untyped]
|
||||
except ImportError:
|
||||
logger.warning("torchvision is not installed. Skipping %s", qualified_name)
|
||||
return None
|
||||
try:
|
||||
return getattr(torchvision.ops, op_name)
|
||||
except AttributeError:
|
||||
logger.warning("Failed to find torchvision op '%s'", qualified_name)
|
||||
return None
|
||||
except Exception:
|
||||
logger.exception("Failed to find torchvision op '%s'", qualified_name)
|
||||
try:
|
||||
op_packet = getattr(getattr(torch.ops, namespace), op_name)
|
||||
if maybe_overload:
|
||||
overload = maybe_overload[0]
|
||||
elif "default" in op_packet._overload_names or "" in op_packet._overload_names:
|
||||
# Has a default overload
|
||||
overload = "default"
|
||||
else:
|
||||
logger.warning(
|
||||
"'%s' does not have a 'default' overload. This could be an error in specifying the op name. Ignoring.",
|
||||
qualified_name,
|
||||
stacklevel=1,
|
||||
)
|
||||
return None
|
||||
|
||||
return getattr(op_packet, overload) # type: ignore[call-overload]
|
||||
except AttributeError:
|
||||
if qualified_name.endswith("getitem"):
|
||||
# This is a special case where we registered the function incorrectly,
|
||||
# but for BC reasons (pt<=2.4) we need to keep it.
|
||||
return None
|
||||
logger.info("'%s' is not found in this version of PyTorch.", qualified_name)
|
||||
return None
|
||||
except Exception:
|
||||
logger.exception("Failed to find torch op '%s'", qualified_name)
|
||||
return None
|
||||
|
||||
|
||||
class ONNXRegistry:
|
||||
"""Registry for ONNX functions.
|
||||
|
||||
The registry maintains a mapping from qualified names to symbolic functions under a
|
||||
fixed opset version. It supports registering custom onnx-script functions and for
|
||||
dispatcher to dispatch calls to the appropriate function.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initializes the registry"""
|
||||
|
||||
# TODO: Design multi-opset version support
|
||||
self._opset_version = _DEFAULT_OPSET_VERSION
|
||||
|
||||
self.functions: dict[TorchOp | str, list[OnnxDecompMeta]] = {}
|
||||
|
||||
@property
|
||||
def opset_version(self) -> int:
|
||||
"""The ONNX opset version the exporter should target.
|
||||
|
||||
Defaults to the latest supported ONNX opset version: 18.
|
||||
The default version will increment over time as ONNX continues to evolve.
|
||||
"""
|
||||
|
||||
return self._opset_version
|
||||
|
||||
@classmethod
|
||||
def from_torchlib(
|
||||
cls,
|
||||
torchlib_registry: Mapping[str, torchlib_registration.OverloadedFunction]
|
||||
| None = None,
|
||||
) -> ONNXRegistry:
|
||||
"""Populates the registry with ATen functions from torchlib.
|
||||
|
||||
Args:
|
||||
torchlib_registry: The torchlib registry to use for populating the registry.
|
||||
"""
|
||||
registry = cls()
|
||||
if torchlib_registry is None:
|
||||
from onnxscript.function_libs.torch_lib import (
|
||||
registration as torchlib_registration,
|
||||
)
|
||||
|
||||
torchlib_registry = torchlib_registration.default_registry # type: ignore[assignment]
|
||||
for qualified_name, aten_overloads_func in torchlib_registry.items(): # type: ignore[union-attr]
|
||||
try:
|
||||
# NOTE: This is heavily guarded with try-except because we don't want
|
||||
# to fail the entire registry population if one function fails.
|
||||
if qualified_name.startswith("internal::"):
|
||||
# Skip the custom defined internal functions
|
||||
continue
|
||||
target = _get_overload(qualified_name)
|
||||
if target is None:
|
||||
continue
|
||||
for overload_func in aten_overloads_func.overloads:
|
||||
overload_func.signature = _schemas.OpSignature.from_function(
|
||||
overload_func,
|
||||
overload_func.function_ir.domain,
|
||||
overload_func.name,
|
||||
)
|
||||
onnx_decomposition = OnnxDecompMeta(
|
||||
onnx_function=overload_func,
|
||||
fx_target=target,
|
||||
is_custom=False,
|
||||
is_complex=False,
|
||||
)
|
||||
registry._register(target, onnx_decomposition)
|
||||
|
||||
for complex_func in aten_overloads_func.complex:
|
||||
overload_func.signature = _schemas.OpSignature.from_function(
|
||||
overload_func,
|
||||
overload_func.function_ir.domain,
|
||||
overload_func.name,
|
||||
)
|
||||
onnx_decomposition = OnnxDecompMeta(
|
||||
onnx_function=complex_func,
|
||||
fx_target=target,
|
||||
is_custom=False,
|
||||
is_complex=True,
|
||||
)
|
||||
registry._register(target, onnx_decomposition)
|
||||
except Exception:
|
||||
logger.exception("Failed to register '%s'. Skipped", qualified_name)
|
||||
continue
|
||||
return registry
|
||||
|
||||
def _register(
|
||||
self,
|
||||
target: TorchOp,
|
||||
onnx_decomposition: OnnxDecompMeta,
|
||||
) -> None:
|
||||
"""Registers a OnnxDecompMeta to an operator.
|
||||
|
||||
Args:
|
||||
target: The PyTorch node callable target.
|
||||
onnx_decomposition: The OnnxDecompMeta to register.
|
||||
"""
|
||||
target_or_name: str | TorchOp
|
||||
if isinstance(target, torch._ops.OpOverload):
|
||||
# Get the qualified name of the aten op because torch._ops.OpOverload lookup in
|
||||
# a dictionary is unreliable for some reason.
|
||||
target_or_name = target.name()
|
||||
else:
|
||||
target_or_name = target
|
||||
if onnx_decomposition.is_custom:
|
||||
self.functions.setdefault(target_or_name, []).insert(0, onnx_decomposition)
|
||||
else:
|
||||
self.functions.setdefault(target_or_name, []).append(onnx_decomposition)
|
||||
|
||||
def register_op(
|
||||
self,
|
||||
target: TorchOp,
|
||||
function: onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction,
|
||||
is_complex: bool = False,
|
||||
) -> None:
|
||||
"""Registers a custom operator: torch.ops.<namespace>.<op_name>.<overload>.
|
||||
|
||||
Args:
|
||||
target: The PyTorch node callable target.
|
||||
function: The onnx-script function to register.
|
||||
is_complex: Whether the function is a function that handles complex valued inputs.
|
||||
"""
|
||||
onnx_decomposition = OnnxDecompMeta(
|
||||
onnx_function=function,
|
||||
fx_target=target,
|
||||
is_custom=True,
|
||||
is_complex=is_complex,
|
||||
)
|
||||
self._register(target, onnx_decomposition)
|
||||
|
||||
def get_decomps(self, target: TorchOp) -> list[OnnxDecompMeta]:
|
||||
"""Returns a list of OnnxDecompMeta for the given op: torch.ops.<namespace>.<op_name>.<overload>.
|
||||
|
||||
The list is ordered by the time of registration. The custom operators should come
|
||||
first in the list.
|
||||
|
||||
Args:
|
||||
target: The PyTorch node callable target.
|
||||
Returns:
|
||||
A list of OnnxDecompMeta corresponding to the given name, or None if
|
||||
the name is not in the registry.
|
||||
"""
|
||||
target_or_name: str | TorchOp
|
||||
if isinstance(target, torch._ops.OpOverload):
|
||||
# Get the qualified name of the aten op because torch._ops.OpOverload lookup in
|
||||
# a dictionary is unreliable for some reason.
|
||||
target_or_name = target.name()
|
||||
else:
|
||||
target_or_name = target
|
||||
decomps = self.functions.get(target_or_name, [])
|
||||
return sorted(decomps, key=lambda x: x.is_custom, reverse=True)
|
||||
|
||||
def is_registered(self, target: TorchOp) -> bool:
|
||||
"""Returns whether the given op is registered: torch.ops.<namespace>.<op_name>.<overload>.
|
||||
|
||||
Args:
|
||||
target: The PyTorch node callable target.
|
||||
|
||||
Returns:
|
||||
True if the given op is registered, otherwise False.
|
||||
"""
|
||||
return bool(self.get_decomps(target))
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(functions={self.functions})"
|
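A short usage sketch of the registry above; the custom onnxscript function is
hypothetical and only illustrates the registration flow:

import onnxscript
from onnxscript import opset18 as op

import torch
from torch.onnx._internal.exporter import _registration


@onnxscript.script()
def custom_gelu(x: onnxscript.FLOAT) -> onnxscript.FLOAT:
    # Sigmoid-based GELU approximation, purely illustrative
    return op.Mul(x, op.Sigmoid(op.Mul(x, op.Constant(value_float=1.702))))


registry = _registration.ONNXRegistry.from_torchlib()
registry.register_op(torch.ops.aten.gelu.default, custom_gelu)
assert registry.is_registered(torch.ops.aten.gelu.default)
# Custom functions are inserted first and sorted first, so dispatch prefers them:
assert registry.get_decomps(torch.ops.aten.gelu.default)[0].is_custom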
193
torch/onnx/_internal/exporter/_reporting.py
Normal file
@ -0,0 +1,193 @@
# mypy: allow-untyped-defs
from __future__ import annotations

import dataclasses
import re
from typing import TYPE_CHECKING

from torch.onnx._internal.exporter import _analysis, _registration, _verification


if TYPE_CHECKING:
    import os

    from onnxscript import ir

    import torch


@dataclasses.dataclass
class ExportStatus:
    # Whether torch.export.export() succeeds
    torch_export: bool | None = None
    # Whether torch.export.export(..., strict=False) succeeds
    torch_export_non_strict: bool | None = None
    # Whether torch.jit.trace succeeds
    torch_jit: bool | None = None
    # Whether ONNX translation succeeds
    onnx_translation: bool | None = None
    # Whether ONNX model passes onnx.checker.check_model
    onnx_checker: bool | None = None
    # Whether ONNX model runs successfully with ONNX Runtime
    onnx_runtime: bool | None = None
    # Whether the output of the ONNX model is accurate
    output_accuracy: bool | None = None


def _status_emoji(status: bool | None) -> str:
    if status is None:
        return "⚪"
    return "✅" if status else "❌"


def _format_export_status(status: ExportStatus) -> str:
    return (
        f"```\n"
        f"{_status_emoji(status.torch_export)} Obtain model graph with `torch.export.export`\n"
        f"{_status_emoji(status.torch_export_non_strict)} Obtain model graph with `torch.export.export(..., strict=False)`\n"
        f"{_status_emoji(status.torch_jit)} Obtain model graph with `torch.jit.trace`\n"
        f"{_status_emoji(status.onnx_translation)} Translate the graph into ONNX\n"
        f"{_status_emoji(status.onnx_checker)} Run `onnx.checker` on the ONNX model\n"
        f"{_status_emoji(status.onnx_runtime)} Execute the model with ONNX Runtime\n"
        f"{_status_emoji(status.output_accuracy)} Validate model output accuracy\n"
        f"```\n\n"
    )


def _strip_color_from_string(text: str) -> str:
    # This regular expression matches ANSI escape codes
    # https://github.com/pytorch/pytorch/blob/9554a9af8788c57e1c5222c39076a5afcf0998ae/torch/_dynamo/utils.py#L2785-L2788
    ansi_escape = re.compile(r"\x1B[@-_][0-?]*[ -/]*[@-~]")
    return ansi_escape.sub("", text)


def _format_exported_program(exported_program: torch.export.ExportedProgram) -> str:
    # Adapted from https://github.com/pytorch/pytorch/pull/128476
    # to remove colors
    # Even though we can call graph_module.print_readable directly, since the
    # colored option was added only recently, we can't guarantee that the
    # version of PyTorch used by the user has this option. Therefore, we
    # still call str(ExportedProgram)
    text = f"```python\n{_strip_color_from_string(str(exported_program))}\n```\n\n"
    return text


def construct_report_file_name(timestamp: str, status: ExportStatus) -> str:
    # Status could be None. So we need to check for False explicitly.
    if not (status.torch_export or status.torch_export_non_strict or status.torch_jit):
        # All strategies failed
        postfix = "pt_export"
    elif status.onnx_translation is False:
        postfix = "conversion"
    elif status.onnx_checker is False:
        postfix = "checker"
    elif status.onnx_runtime is False:
        postfix = "runtime"
    elif status.output_accuracy is False:
        postfix = "accuracy"
    elif status.torch_export is False or status.torch_export_non_strict is False:
        # Some strategies failed
        postfix = "strategies"
    else:
        postfix = "success"
    return f"onnx_export_{timestamp}_{postfix}.md"


def format_decomp_comparison(
    pre_decomp_unique_ops: set[str],
    post_decomp_unique_ops: set[str],
) -> str:
    """Format the decomposition comparison result.

    Args:
        pre_decomp_unique_ops: The unique ops in the ExportedProgram before decomposition.
        post_decomp_unique_ops: The unique ops in the ExportedProgram after decomposition.

    Returns:
        The formatted comparison result.
    """
    return (
        f"Ops exist only in the ExportedProgram before decomposition: `{sorted(pre_decomp_unique_ops)}`\n\n"
        f"Ops exist only in the ExportedProgram after decomposition: `{sorted(post_decomp_unique_ops)}`\n"
    )


def format_verification_infos(
    verification_infos: list[_verification.VerificationInfo],
) -> str:
    """Format the verification result.

    Args:
        verification_infos: The verification result.

    Returns:
        The formatted verification result.
    """
    return "\n".join(
        f"`{info.name}`: `abs_diff={info.absolute_difference:e}`, `rel_diff={info.relative_difference:e}`"
        for info in verification_infos
    )


def create_torch_export_error_report(
    filename: str | os.PathLike,
    formatted_traceback: str,
    *,
    export_status: ExportStatus,
    profile_result: str | None,
):
    with open(filename, "w", encoding="utf-8") as f:
        f.write("# PyTorch ONNX Conversion Error Report\n\n")
        f.write(_format_export_status(export_status))
        f.write("Error message:\n\n")
        f.write("```pytb\n")
        f.write(formatted_traceback)
        f.write("```\n\n")
        if profile_result is not None:
            f.write("## Profiling result\n\n")
            f.write("```\n")
            f.write(profile_result)
            f.write("```\n")


def create_onnx_export_report(
    filename: str | os.PathLike,
    formatted_traceback: str,
    program: torch.export.ExportedProgram,
    *,
    decomp_comparison: str | None = None,
    export_status: ExportStatus,
    profile_result: str | None,
    model: ir.Model | None = None,
    registry: _registration.ONNXRegistry | None = None,
    verification_result: str | None = None,
):
    with open(filename, "w", encoding="utf-8") as f:
        f.write("# PyTorch ONNX Conversion Report\n\n")
        f.write(_format_export_status(export_status))
        f.write("## Error messages\n\n")
        f.write("```pytb\n")
        f.write(formatted_traceback)
        f.write("\n```\n\n")
        f.write("## Exported program\n\n")
        f.write(_format_exported_program(program))
        if model is not None:
            f.write("## ONNX model\n\n")
            f.write("```python\n")
            f.write(str(model))
            f.write("\n```\n\n")
        f.write("## Analysis\n\n")
        _analysis.analyze(program, file=f, registry=registry)
        if decomp_comparison is not None:
            f.write("\n## Decomposition comparison\n\n")
            f.write(decomp_comparison)
            f.write("\n")
        if verification_result is not None:
            f.write("\n## Verification results\n\n")
            f.write(verification_result)
            f.write("\n")
        if profile_result is not None:
            f.write("\n## Profiling result\n\n")
            f.write("```\n")
            f.write(profile_result)
            f.write("```\n")
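As a concrete example of the branch order in `construct_report_file_name`, a run
where graph capture succeeds but ONNX translation fails is filed under
`conversion` (the timestamp is illustrative):

status = ExportStatus(torch_export=True, onnx_translation=False)
assert (
    construct_report_file_name("2024-08-05_12-00-00", status)
    == "onnx_export_2024-08-05_12-00-00_conversion.md"
)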
548
torch/onnx/_internal/exporter/_schemas.py
Normal file
@ -0,0 +1,548 @@
# mypy: allow-untyped-defs
from __future__ import annotations

import collections.abc
import dataclasses
import inspect
import logging
import types
import typing
from typing import Any, Iterator, Mapping, Optional, Sequence, TypeVar, Union

import onnx

import onnxscript
from onnxscript import ir


logger = logging.getLogger(__name__)


# A special value to indicate that the default value is not specified
class _Empty:
    def __repr__(self):
        return "_EMPTY_DEFAULT"


_EMPTY_DEFAULT = _Empty()

# Map from python type to corresponding ONNX AttributeProto type
_PY_TYPE_TO_ATTR_TYPE = {
    float: ir.AttributeType.FLOAT,
    int: ir.AttributeType.INT,
    str: ir.AttributeType.STRING,
    bool: ir.AttributeType.INT,
    ir.Tensor: ir.AttributeType.TENSOR,
    ir.TensorProtocol: ir.AttributeType.TENSOR,
    ir.Graph: ir.AttributeType.GRAPH,
    ir.GraphProtocol: ir.AttributeType.GRAPH,
}

# Map from python type to corresponding ONNX AttributeProto type,
# for repeated (i.e., list of) values
_LIST_TYPE_TO_ATTR_TYPE = {
    float: ir.AttributeType.FLOATS,
    int: ir.AttributeType.INTS,
    str: ir.AttributeType.STRINGS,
    bool: ir.AttributeType.INTS,
    ir.Tensor: ir.AttributeType.TENSORS,
    ir.TensorProtocol: ir.AttributeType.TENSORS,
    ir.Graph: ir.AttributeType.GRAPHS,
    ir.GraphProtocol: ir.AttributeType.GRAPHS,
}

_ALL_VALUE_TYPES = (
    {ir.TensorType(dtype) for dtype in ir.DataType}
    | {ir.SequenceType(ir.TensorType(dtype)) for dtype in ir.DataType}
    | {ir.OptionalType(ir.TensorType(dtype)) for dtype in ir.DataType}
)

# TypeAnnotationValue represents the (value of) valid type-annotations recognized
# by ONNX Script. Currently, it supports
# - float, int, str (primitive attribute types)
# - Sequence[float], Sequence[int], Sequence[str] (attribute types)
# - Tensor types
# - Sequence[Tensor] types
# - Union of above 2
# - TypeVars with above bounds
# - Above types with annotation attached
TypeAnnotationValue = Any


@dataclasses.dataclass(frozen=True)
class TypeConstraintParam:
    """Type constraint for a parameter.

    Attributes:
        name: Name of the parameter. E.g. "TFloat"
        allowed_types: Allowed types for the parameter.
    """

    name: str
    allowed_types: set[ir.TypeProtocol]
    description: str = ""

    def __hash__(self) -> int:
        return hash((self.name, tuple(self.allowed_types)))

    def __str__(self) -> str:
        allowed_types_str = " | ".join(str(t) for t in self.allowed_types)
        return f"{self.name}={allowed_types_str}"

    @classmethod
    def any_tensor(cls, name: str, description: str = "") -> TypeConstraintParam:
        return cls(name, {ir.TensorType(dtype) for dtype in ir.DataType}, description)

    @classmethod
    def any_value(cls, name: str, description: str = "") -> TypeConstraintParam:
        return cls(name, _ALL_VALUE_TYPES, description)  # type: ignore[arg-type]


@dataclasses.dataclass(frozen=True)
class Parameter:
    """A formal parameter of an operator."""

    name: str
    type_constraint: TypeConstraintParam
    required: bool
    variadic: bool
    default: Any = _EMPTY_DEFAULT
    # TODO: Add other properties too

    def __str__(self) -> str:
        type_str = self.type_constraint.name
        if self.has_default():
            return f"{self.name}: {type_str} = {self.default}"
        return f"{self.name}: {type_str}"

    def has_default(self) -> bool:
        return self.default is not _EMPTY_DEFAULT


@dataclasses.dataclass(frozen=True)
class AttributeParameter:
    name: str
    type: ir.AttributeType
    required: bool
    default: ir.Attr | None = None

    def __str__(self) -> str:
        type_str = self.type.name
        if self.has_default():
            return f"{self.name}: {type_str} = {self.default}"
        return f"{self.name}: {type_str}"

    def has_default(self) -> bool:
        return self.default is not None


def _get_type_from_str(
    type_str: str,
) -> ir.TensorType | ir.SequenceType | ir.OptionalType:
    """Convert a type_str from ONNX Opschema to ir.TypeProtocol.

    A type str has the form of "tensor(float)" or composite type like "seq(tensor(float))".
    """
    # TODO: Upstream this to IR

    # Split the type_str into nested type constructors and the dtype
    # 1. Remove the ending ")"
    stripped = type_str.rstrip(")")
    # 2. Split the type_str by "("
    type_parts = stripped.split("(")

    # Convert the dtype to ir.DataType
    dtype = ir.DataType[type_parts[-1].upper()]

    # Create a placeholder type first
    type_: ir.TypeProtocol = ir.TensorType(ir.DataType.UNDEFINED)

    # Construct the type
    for type_part in reversed(type_parts[:-1]):
        if type_part == "tensor":
            type_ = ir.TensorType(dtype)
        elif type_part == "seq":
            type_ = ir.SequenceType(type_)
        elif type_part == "optional":
            type_ = ir.OptionalType(type_)
        else:
            raise ValueError(f"Unknown type part: '{type_part}' in type '{type_str}'")
    return type_  # type: ignore[return-value]
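# A couple of parses the function above accepts (the reprs shown are schematic
# and may differ across onnxscript versions):
#
#   >>> _get_type_from_str("tensor(float)")
#   Tensor(FLOAT)
#   >>> _get_type_from_str("optional(seq(tensor(int64)))")
#   Optional(Sequence(Tensor(INT64)))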

def _convert_formal_parameter(
    param: onnx.defs.OpSchema.FormalParameter,
    type_constraints: Mapping[str, TypeConstraintParam],
) -> Parameter:
    """Convert a formal parameter from ONNX Opschema to Parameter."""
    if param.type_str in type_constraints:
        type_constraint = type_constraints[param.type_str]
    else:
        # param.type_str can be a plain type like 'int64'.
        type_constraint = TypeConstraintParam(
            name=param.name,
            allowed_types={_get_type_from_str(param.type_str)},
        )
    return Parameter(
        name=param.name,
        type_constraint=type_constraint,
        required=param.option != onnx.defs.OpSchema.FormalParameterOption.Optional,
        variadic=param.option == onnx.defs.OpSchema.FormalParameterOption.Variadic,
    )


def _is_optional(type_: type) -> bool:
    """Returns whether a type_ is an Optional."""
    origin_type = typing.get_origin(type_)
    if origin_type is Union and type(None) in typing.get_args(type_):
        # typing.Optional[X] / typing.Union[X, None]
        return True
    if origin_type is Optional:
        # Bare typing.Optional
        return True
    if (
        hasattr(types, "UnionType")
        and origin_type is types.UnionType
        and type(None) in typing.get_args(type_)
    ):
        # X | None syntax (Python >= 3.10)
        return True
    return False


def _get_attr_type(type_: type) -> ir.AttributeType:
    """Obtain the type of the attribute from a Python class."""
    try:
        if type_ in _PY_TYPE_TO_ATTR_TYPE:
            return _PY_TYPE_TO_ATTR_TYPE[type_]
        origin_type = typing.get_origin(type_)
        if origin_type is None:
            return ir.AttributeType.UNDEFINED
        if origin_type in (
            collections.abc.Sequence,
            Sequence,
            typing.List,
            list,
            typing.Tuple,
            tuple,
        ):
            inner_type = typing.get_args(type_)[0]
            if inner_type in _LIST_TYPE_TO_ATTR_TYPE:
                return _LIST_TYPE_TO_ATTR_TYPE[inner_type]
    except TypeError:
        logger.warning("TypeError when checking %s.", type_, exc_info=True)
    return ir.AttributeType.UNDEFINED


def _get_type_constraint_name(type_: TypeAnnotationValue) -> str | None:
    """Returns the name of the type constraint for a given type annotation.

    Args:
        type_: A Python type.

    Returns:
        The name of the type constraint if it is a TypeVar.
        - Prefixes the name with "Sequence_" if the type annotation is a Sequence[].
    """
    if isinstance(type_, TypeVar):
        return type_.__name__
    if _is_optional(type_):
        subtypes = typing.get_args(type_)
        for subtype in subtypes:
            if subtype is type(None):
                continue
            type_param_name = _get_type_constraint_name(subtype)
            return type_param_name if type_param_name else None
    origin_type = typing.get_origin(type_)
    if isinstance(origin_type, type) and issubclass(origin_type, Sequence):
        subtypes = typing.get_args(type_)
        type_param_name = _get_type_constraint_name(subtypes[0])
        return f"Sequence_{type_param_name}" if type_param_name else None
    return None


def _get_allowed_types_from_type_annotation(
    type_: TypeAnnotationValue,
) -> set[ir.TypeProtocol]:
    """Obtain the allowed types from a type annotation."""
    if type_ is onnxscript.onnx_types.TensorType:
        # Any tensor type
        return {ir.TensorType(dtype) for dtype in ir.DataType}

    allowed_types: set[ir.TypeProtocol]

    if isinstance(type_, TypeVar):
        allowed_types = set()
        if constraints := type_.__constraints__:
            for constraint in constraints:
                allowed_types.update(
                    _get_allowed_types_from_type_annotation(constraint)
                )
        else:
            bound = type_.__bound__
            if bound is None:
                allowed_types = _ALL_VALUE_TYPES  # type: ignore[assignment]
            else:
                allowed_types.update(_get_allowed_types_from_type_annotation(bound))
        return allowed_types
    if hasattr(type_, "dtype"):
        # A single tensor type like INT64, FLOAT, etc.
        return {ir.TensorType(ir.DataType(type_.dtype))}
    if _is_optional(type_):
        allowed_types = set()
        subtypes = typing.get_args(type_)
        for subtype in subtypes:
            if subtype is type(None):
                continue
            allowed_types.update(_get_allowed_types_from_type_annotation(subtype))
        # NOTE: We do not consider dynamic optional types like optional(float) because they are not very useful.
        return allowed_types

    origin_type = typing.get_origin(type_)
    if origin_type is Union:
        allowed_types = set()
        subtypes = typing.get_args(type_)
        for subtype in subtypes:
            assert subtype is not type(
                None
            ), "Union should not contain None type because it is handled by _is_optional."
            allowed_types.update(_get_allowed_types_from_type_annotation(subtype))
        return allowed_types

    if isinstance(origin_type, type) and issubclass(origin_type, Sequence):
        subtypes = typing.get_args(type_)
        return {
            ir.SequenceType(t)
            for t in _get_allowed_types_from_type_annotation(subtypes[0])
        }

    # Allow everything by default
    return _ALL_VALUE_TYPES  # type: ignore[return-value]


@dataclasses.dataclass
class OpSignature:
    """Schema for an operator.

    Attributes:
        domain: Domain of the operator. E.g. "".
        name: Name of the operator. E.g. "Add".
        overload: Overload name of the operator.
        params: Input parameters. When the op is an ONNX function definition,
            the order is according to the function signature. This means we can
            interleave ONNX inputs and ONNX attributes in the list.
        outputs: Output parameters.
    """

    domain: str
    name: str
    overload: str
    params: Sequence[Parameter | AttributeParameter]
    outputs: Sequence[Parameter]
    params_map: Mapping[str, Parameter | AttributeParameter] = dataclasses.field(
        init=False, repr=False
    )

    def __post_init__(self):
        self.params_map = {param.name: param for param in self.params}

    def get(self, name: str) -> Parameter | AttributeParameter:
        return self.params_map[name]

    def __contains__(self, name: str) -> bool:
        return name in self.params_map

    def __iter__(self) -> Iterator[Parameter | AttributeParameter]:
        return iter(self.params)

    def __str__(self) -> str:
        domain = self.domain or "''"
        # TODO: Double check the separator for overload
        overload = f"::{self.overload}" if self.overload else ""
        params = ", ".join(str(param) for param in self.params)
        outputs = ", ".join(str(param.type_constraint.name) for param in self.outputs)
        type_constraints = {}
        for param in self.params:
            if isinstance(param, Parameter):
                type_constraints[param.type_constraint.name] = param.type_constraint
        for param in self.outputs:
            type_constraints[param.type_constraint.name] = param.type_constraint
        type_constraints_str = ", ".join(
            str(type_constraint) for type_constraint in type_constraints.values()
        )
        return f"{domain}::{self.name}{overload}({params}) -> ({outputs}) where {type_constraints_str}"

    @classmethod
    def from_opschema(cls, opschema: onnx.defs.OpSchema) -> OpSignature:
        """Produce an OpSignature from an ONNX Opschema."""
        type_constraints = {
            constraint.type_param_str: TypeConstraintParam(
                name=constraint.type_param_str,
                allowed_types={
                    _get_type_from_str(type_str)
                    for type_str in constraint.allowed_type_strs
                },
                description=constraint.description,
            )
            for constraint in opschema.type_constraints
        }

        params = [
            _convert_formal_parameter(param, type_constraints)
            for param in opschema.inputs
        ]

        for param in opschema.attributes.values():
            default_attr = (
                ir.serde.deserialize_attribute(param.default_value)
                if param.default_value is not None
                else None
            )
            if default_attr is not None:
                # Set the name of the default attribute because it may have a different name from the parameter
                default_attr.name = param.name
            params.append(
                AttributeParameter(
                    name=param.name,
                    type=ir.AttributeType(param.type),  # type: ignore[arg-type]
                    required=param.required,
                    default=default_attr,  # type: ignore[arg-type]
                )
            )

        outputs = [
            _convert_formal_parameter(param, type_constraints)
            for param in opschema.outputs
        ]

        return cls(
            domain=opschema.domain,
            name=opschema.name,
            overload="",
            params=params,
            outputs=outputs,
        )

    @classmethod
    def from_function(
        cls, func, domain: str, name: str | None = None, overload: str = ""
    ) -> OpSignature:
        """Produce an OpSignature from a function using type annotation."""

        py_signature = inspect.signature(func)
        # Not using inspect.get_annotations because typing.get_type_hints seems to handle more cases
        # https://github.com/python/cpython/issues/102405
        type_hints = typing.get_type_hints(func)

        params = []
        # Create a mapping from type to a unique name
        type_constraints: dict[str, TypeConstraintParam] = {}

        for param in py_signature.parameters.values():
            if param.name not in type_hints:
                logger.warning(
                    "Missing annotation for parameter '%s' from %s. Treating as an Input.",
                    param.name,
                    py_signature,
                )
                type_constraints[param.name] = TypeConstraintParam.any_value(
                    f"T_{param.name}"
                )
            else:
                type_ = type_hints[param.name]
                if (attr_type := _get_attr_type(type_)) != ir.AttributeType.UNDEFINED:
                    # Construct the default attribute
                    if param.default is not inspect.Parameter.empty:
                        # TODO: Use ir_convenience instead to handle int as float
                        default = ir.Attr(param.name, attr_type, param.default)
                    else:
                        default = None
                    params.append(
                        AttributeParameter(
                            name=param.name,
                            type=attr_type,
                            required=param.default is inspect.Parameter.empty,
                            default=default,
                        )
                    )
                else:
                    # Obtain the type constraint from the type annotation

                    # 1. Get a type constraint name from the type annotation
                    # If the type annotation is a TypeVar or Optional[TypeVar], get its name
                    # Otherwise, name it T_{param.name}
                    type_constraint_name = _get_type_constraint_name(type_)
                    if type_constraint_name is None:
                        type_constraint_name = f"T_{param.name}"

                    # 2. If the type constraint param is already initialized, use it
                    if type_constraint_name in type_constraints:
                        type_constraint = type_constraints[type_constraint_name]
                    else:
                        # 3. Otherwise, create a new TypeConstraintParam
                        type_constraint = TypeConstraintParam(
                            name=type_constraint_name,
                            allowed_types=_get_allowed_types_from_type_annotation(
                                type_
                            ),
                        )
                        type_constraints[type_constraint_name] = type_constraint
                    # 4. Create Parameter
                    params.append(
                        Parameter(  # type: ignore[arg-type]
                            name=param.name,
                            type_constraint=type_constraint,
                            required=param.default is inspect.Parameter.empty,
                            # TODO: Handle variadic
                            variadic=False,
                            default=param.default
                            if param.default is not inspect.Parameter.empty
                            else _EMPTY_DEFAULT,
                        )
                    )

        return_type = type_hints.get("return")

        outputs = []
        if return_type is None:
            # No returns
            pass
        else:
            if typing.get_origin(return_type) is tuple:
                # Multiple returns
                return_types = typing.get_args(return_type)
            else:
                return_types = [return_type]  # type: ignore[assignment]

            for i, return_type_i in enumerate(return_types):
                if (
                    return_param_name := _get_type_constraint_name(return_type_i)
                ) in type_constraints:
                    type_constraint = type_constraints[return_param_name]
                else:
                    return_param_name = f"TReturn{i}"
                    type_constraint = TypeConstraintParam(
                        name=return_param_name,
                        allowed_types=_get_allowed_types_from_type_annotation(
                            return_type_i
                        ),
                    )
                    type_constraints[return_param_name] = type_constraint
                outputs.append(
                    Parameter(
                        name=return_param_name,
                        type_constraint=type_constraint,
                        required=True,
                        variadic=False,
                        default=_EMPTY_DEFAULT,
                    )
                )

        return cls(
            domain=domain,
            name=name or func.__name__,
            overload=overload,
            params=params,
            outputs=outputs,
        )
98
torch/onnx/_internal/exporter/_tensors.py
Normal file
@ -0,0 +1,98 @@
"""Subclass of ir.Value that supports Python operators."""
|
||||
|
||||
# mypy: allow-untyped-defs
|
||||
from __future__ import annotations
|
||||
|
||||
import onnxscript
|
||||
from onnxscript import ir
|
||||
|
||||
|
||||
class SymbolicTensor(ir.Value):
|
||||
"""A subclass of ir.Value that supports Python operators."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
opset: onnxscript.values.Opset,
|
||||
name: str | None = None,
|
||||
shape: ir.Shape | None = None,
|
||||
type: ir.TypeProtocol | None = None,
|
||||
doc_string: str | None = None,
|
||||
const_value: ir.TensorProtocol | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
name=name,
|
||||
shape=shape,
|
||||
type=type,
|
||||
doc_string=doc_string,
|
||||
const_value=const_value,
|
||||
)
|
||||
self._opset = opset
|
||||
|
||||
@property
|
||||
def rank(self) -> int | None:
|
||||
if self.shape is None:
|
||||
return None
|
||||
return len(self.shape)
|
||||
|
||||
# TODO: Implement indexing
|
||||
|
||||
def __mod__(self, other):
|
||||
if self.dtype in {
|
||||
ir.DataType.FLOAT,
|
||||
ir.DataType.DOUBLE,
|
||||
ir.DataType.FLOAT16,
|
||||
ir.DataType.BFLOAT16,
|
||||
}:
|
||||
return self._opset.Mod(self, other, fmod=1)
|
||||
return self._opset.Mod(self, other)
|
||||
|
||||
def __ne__(self, other):
|
||||
return self._opset.Not(self._opset.Equal(self, other))
|
||||
|
||||
def __neg__(self):
|
||||
return self._opset.Neg(self)
|
||||
|
||||
def __add__(self, other):
|
||||
return self._opset.Add(self, other)
|
||||
|
||||
def __radd__(self, other):
|
||||
return self._opset.Add(other, self)
|
||||
|
||||
def __rand__(self, other):
|
||||
return self._opset.And(other, self)
|
||||
|
||||
def __mul__(self, other):
|
||||
return self._opset.Mul(self, other)
|
||||
|
||||
def __rmul__(self, other):
|
||||
return self._opset.Mul(other, self)
|
||||
|
||||
def __matmul__(self, other):
|
||||
return self._opset.MatMul(self, other)
|
||||
|
||||
def __pow__(self, other):
|
||||
return self._opset.Pow(self, other)
|
||||
|
||||
def __sub__(self, other):
|
||||
return self._opset.Sub(self, other)
|
||||
|
||||
def __rsub__(self, other):
|
||||
return self._opset.Sub(other, self)
|
||||
|
||||
def __truediv__(self, other):
|
||||
return self._opset.Div(self, other)
|
||||
|
||||
def __lt__(self, other):
|
||||
return self._opset.Less(self, other)
|
||||
|
||||
def __le__(self, other):
|
||||
return self._opset.LessOrEqual(self, other)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self._opset.Equal(self, other)
|
||||
|
||||
def __ge__(self, other):
|
||||
return self._opset.GreaterOrEqual(self, other)
|
||||
|
||||
def __gt__(self, other):
|
||||
return self._opset.Greater(self, other)
|
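Because each dunder above lowers to an ONNX op through the stored opset, plain
Python expressions build graph nodes; a sketch, assuming `x` and `y` are
SymbolicTensor values created with the same opset:

z = x * y + 1.0   # emits Mul, then Add
mask = x >= y     # emits GreaterOrEqual
r = x % y         # emits Mod, with fmod=1 when x has a floating-point dtype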
79
torch/onnx/_internal/exporter/_verification.py
Normal file
@ -0,0 +1,79 @@
# mypy: allow-untyped-defs
from __future__ import annotations

import dataclasses
from typing import Any, TYPE_CHECKING

import torch
from torch.utils import _pytree as pytree


if TYPE_CHECKING:
    from torch.onnx._internal.exporter import _onnx_program


@dataclasses.dataclass
class VerificationInfo:
    name: str
    absolute_difference: float
    relative_difference: float
    expected_dtype: torch.dtype
    actual_dtype: torch.dtype
    # NOTE: We don't need to include shape because the expected shape is already known
    # and checked by the runtime


def _compare_tensors(
    expected: torch.Tensor,
    actual: torch.Tensor,
) -> tuple[float, float]:
    # Move tensors to the same device
    expected = expected.detach().cpu()
    actual = actual.detach().cpu()
    absolute_difference = torch.abs(expected - actual).max().item()
    eps = 1e-7
    relative_difference = (
        (torch.abs(expected - actual) / (torch.abs(expected) + eps)).max().item()
    )
    return absolute_difference, relative_difference


def verify_onnx_program(
    onnx_program: _onnx_program.ONNXProgram,
    args: tuple[Any, ...] | None = None,
    kwargs: dict[str, Any] | None = None,
) -> list[VerificationInfo]:
    exported_program = onnx_program.exported_program
    if args is None and kwargs is None:
        # User did not provide example inputs, use the default example inputs
        if exported_program.example_inputs is None:
            raise ValueError(
                "No example inputs provided and the exported_program does not contain example inputs. "
                "Please provide arguments to verify the ONNX program."
            )
        args, kwargs = exported_program.example_inputs
    if args is None:
        args = ()
    if kwargs is None:
        kwargs = {}
    torch_module = exported_program.module()
    torch_outputs, _ = pytree.tree_flatten(torch_module(*args, **kwargs))
    onnx_outputs = onnx_program(*args, **kwargs)
    results = []
    for torch_output, onnx_output, output_val in zip(
        torch_outputs, onnx_outputs, onnx_program.model.graph.outputs
    ):
        name = output_val.name
        absolute_difference, relative_difference = _compare_tensors(
            torch_output, onnx_output
        )
        results.append(
            VerificationInfo(
                name=str(name),
                absolute_difference=absolute_difference,
                relative_difference=relative_difference,
                expected_dtype=torch_output.dtype,
                actual_dtype=onnx_output.dtype,
            )
        )
    return results
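A minimal sketch of driving the verifier end to end; the model and inputs are
illustrative, and `torch.onnx.export(..., dynamo=True)` returning an
ONNXProgram follows the API change in this PR:

model = torch.nn.Linear(2, 2)
onnx_program = torch.onnx.export(model, (torch.randn(1, 2),), dynamo=True)
for info in verify_onnx_program(onnx_program):
    print(f"{info.name}: abs_diff={info.absolute_difference:e}, rel_diff={info.relative_difference:e}")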
30
torch/onnx/_internal/exporter/errors.py
Normal file
@ -0,0 +1,30 @@
class ExporterError(RuntimeError):
    """Error during export."""


class TorchExportError(ExporterError):
    """Error during torch.export.export."""


class OnnxConversionError(ExporterError):
    """Error during ONNX conversion."""


class DispatchError(OnnxConversionError):
    """Error during ONNX Function dispatching."""


class GraphConstructionError(OnnxConversionError):
    """Error during graph construction."""


class OnnxCheckerError(ExporterError):
    """Error during ONNX model checking."""


class OnnxRuntimeError(ExporterError):
    """Error during ONNX Runtime execution."""


class OnnxValidationError(ExporterError):
    """Output value mismatch."""
@ -172,10 +172,7 @@ def _get_torch_export_args(


def export(
    model: torch.nn.Module
    | torch.jit.ScriptModule
    | torch.jit.ScriptFunction
    | torch.export.ExportedProgram,
    model: torch.nn.Module | torch.jit.ScriptModule | torch.jit.ScriptFunction,
    args: tuple[Any, ...] | torch.Tensor,
    f: str | None = None,
    *,
@ -191,13 +188,11 @@ def export(
    dynamic_axes: Mapping[str, Mapping[int, str]]
    | Mapping[str, Sequence[int]]
    | None = None,
    dynamic_shapes: dict[str, Any] | tuple[Any, ...] | list[Any] | None = None,
    keep_initializers_as_inputs: bool | None = None,
    custom_opsets: Mapping[str, int] | None = None,
    export_modules_as_functions: bool | Collection[type[torch.nn.Module]] = False,
    autograd_inlining: bool | None = True,
    dynamo: bool = False,
) -> torch.onnx.ONNXProgram | None:
    autograd_inlining: bool = True,
) -> None:
    r"""Exports a model into ONNX format.

    If ``model`` is not a :class:`torch.jit.ScriptModule` nor a
@ -491,8 +486,6 @@ def export(
        autograd_inlining: Flag used to control whether to inline autograd functions.
            Refer to https://github.com/pytorch/pytorch/pull/74765 for more details.

        dynamo: Whether to export the model with Dynamo instead of TorchScript.

    Raises:
        :class:`torch.onnx.errors.CheckerError`: If the ONNX checker detects an invalid ONNX graph.
        :class:`torch.onnx.errors.UnsupportedOperatorError`: If the ONNX graph cannot be exported because it
@ -515,65 +508,29 @@ def export(
        )

    args = (args,) if isinstance(args, torch.Tensor) else args
    if kwargs is not None:
        args = args + (kwargs,)

    if dynamo:
        if isinstance(model, (torch.jit.ScriptModule, torch.jit.ScriptFunction)):
            raise TypeError(
                "Dynamo export does not support ScriptModule or ScriptFunction."
            )
        # TODO(justinchuby): Remove the warning once logic migration is done
        warnings.warn(
            "export_params, verbose, training, input_names, output_names, operator_export_type, opset_version, "
            "do_constant_folding, keep_initializers_as_inputs, custom_opsets, export_modules_as_functions, and "
            "autograd_inlining are not supported for dynamo export at the moment."
        )
        args, kwargs = _get_torch_export_args(args, kwargs)
        if isinstance(model, torch.export.ExportedProgram):
            exported_program = model
        else:
            if dynamic_shapes is None and dynamic_axes is not None:
                dynamic_shapes = _from_dynamic_axes_to_dynamic_shapes(
                    model, dynamic_axes, input_names
                )
            exported_program = torch.export.export(
                model, args=args, kwargs=kwargs, dynamic_shapes=dynamic_shapes  # type: ignore[arg-type]
            )
        if kwargs is None:
            # TODO(justinchuby): dynamo_export requires kwargs to be unpacked. Once migration is done
            # we can pass kwargs as None
            kwargs = {}
        onnx_program = torch.onnx.dynamo_export(exported_program, *args, **kwargs)
        if f is not None:
            onnx_program.save(f)
        return onnx_program
    _export(
        model,
        args,
        f,
        export_params,
        verbose,
        training,
        input_names,
        output_names,
        operator_export_type=operator_export_type,
        opset_version=opset_version,
        do_constant_folding=do_constant_folding,
        dynamic_axes=dynamic_axes,
        keep_initializers_as_inputs=keep_initializers_as_inputs,
        custom_opsets=custom_opsets,
        export_modules_as_functions=export_modules_as_functions,
        autograd_inlining=autograd_inlining,
    )

    else:
        # Torch Script export path
        if f is None:
            raise ValueError("Export destination must be specified when dynamo=False.")
        if kwargs is not None:
            args = args + (kwargs,)

        _export(
            model,
            args,
            f,
            export_params,
            verbose,
            training,
            input_names,
            output_names,
            operator_export_type=operator_export_type,
            opset_version=opset_version,
            do_constant_folding=do_constant_folding,
            dynamic_axes=dynamic_axes,
            keep_initializers_as_inputs=keep_initializers_as_inputs,
            custom_opsets=custom_opsets,
            export_modules_as_functions=export_modules_as_functions,
            autograd_inlining=autograd_inlining,
        )

    return None
    return None


def _is_constant_tensor_list(node):
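With the change above, `dynamo=True` routes export through `torch.export.export`
and `torch.onnx.dynamo_export`, and the ONNXProgram is returned rather than only
written to `f`; a minimal call sketch with an illustrative model and path:

model = torch.nn.Linear(2, 2)
onnx_program = torch.onnx.export(model, (torch.randn(1, 2),), dynamo=True)
if onnx_program is not None:
    onnx_program.save("linear.onnx")  # equivalently, pass f="linear.onnx" to export()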
@ -1531,7 +1488,7 @@ def _export(
    custom_opsets=None,
    add_node_names=True,
    onnx_shape_inference=True,
    export_modules_as_functions=False,
    export_modules_as_functions: Any = False,
    autograd_inlining=True,
):
    assert GLOBALS.in_onnx_export is False
@ -1560,9 +1517,7 @@ def _export(
            f"Exporting to ONNX opset version {opset_version} is not supported "
            f"by 'torch.onnx.export()'. "
            f"The highest opset version supported is {_constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET}. "
            f"To use a newer opset version, consider 'torch.onnx.dynamo_export()'. "
            f"Note that dynamo_export() is in preview. Please report errors with "
            f"dynamo_export() as Github issues to https://github.com/pytorch/pytorch/issues.",
            f"To use a newer opset version, consider 'torch.onnx.export(..., dynamo=True)'. ",
            category=errors.OnnxExporterWarning,
        )
