[ONNX] Refactor torchscript based exporter (#161323)

Refactor the TorchScript-based exporter logic by moving it into a single (private) location for better code management. The original public module and method APIs are preserved.

- Updated module paths in `torch/csrc/autograd/python_function.cpp` accordingly
- Removed `check_onnx_broadcast` from `torch/autograd/_functions/utils.py` because it is private and unused

@albanD / @soulitzer could you review changes in `torch/csrc/autograd/python_function.cpp` and
`torch/autograd/_functions/utils.py`? Thanks!

## BC Breaking
- **Deprecated members in `torch.onnx.verification` are removed**
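A minimal sketch (not part of this PR) of the resulting compatibility surface, assuming a build that includes this change; the specific attribute checks are illustrative only:

```python
import torch.onnx
from torch.onnx import verification

# Public entry points are preserved even though their implementation now lives
# under torch.onnx._internal.torchscript_exporter.
assert callable(torch.onnx.register_custom_op_symbolic)
assert callable(torch.onnx.select_model_mode_for_export)

# Deprecated members of torch.onnx.verification no longer exist (e.g. GraphInfo).
assert not hasattr(verification, "GraphInfo")
```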

Differential Revision: [D81236421](https://our.internmc.facebook.com/intern/diff/D81236421)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161323
Approved by: https://github.com/titaiwangms, https://github.com/angelayi
Author: Justin Chu
Date: 2025-08-29 15:11:16 -07:00
Committed by: PyTorch MergeBot
Parent: 793fc12aff
Commit: 524b78d4f6
63 changed files with 19034 additions and 18923 deletions

View File

@ -113,7 +113,6 @@ also be interested in reading our [development wiki](https://github.com/pytorch/
.. autofunction:: register_custom_op_symbolic
.. autofunction:: unregister_custom_op_symbolic
.. autofunction:: select_model_mode_for_export
.. autoclass:: JitScalarType
```
```{eval-rst}

View File

@ -1,4 +1,5 @@
# torch.onnx.verification
```{eval-rst}
.. automodule:: torch.onnx.verification
```
@ -11,23 +12,3 @@
.. autoclass:: VerificationInfo
:members:
```
```{eval-rst}
.. autofunction:: verify
```
## Deprecated
The following classes and functions are deprecated.
<!-- Some deprecated members are not publicly shown -->
```{eval-rst}
.. py:class:: check_export_model_diff
.. py:class:: GraphInfo
.. py:class:: GraphInfoPrettyPrinter
.. py:class:: OnnxBackend
.. py:class:: OnnxTestCaseRepro
.. py:class:: VerificationOptions
.. py:function:: find_mismatch
.. py:function:: verify_aten_graph
```

View File

@ -4,7 +4,7 @@
from collections.abc import Sequence
from torch.onnx import errors
from torch.onnx._internal import registration
from torch.onnx._internal.torchscript_exporter import registration
from torch.testing._internal import common_utils

View File

@ -17,7 +17,8 @@ import pytorch_test_common
import torch
from torch import export as torch_export
from torch.onnx import _constants, verification
from torch.onnx import _constants
from torch.onnx._internal.torchscript_exporter import verification
from torch.testing._internal import common_utils
from torch.testing._internal.opinfo import core as opinfo_core
from torch.types import Number

View File

@ -5,13 +5,11 @@ from onnx_test_common import run_model_test
import torch
from torch.onnx import OperatorExportTypes
from torch.onnx._globals import GLOBALS
from torch.onnx.utils import _model_to_graph
from torch.testing._internal import common_utils
class TestAutogradFuns(pytorch_test_common.ExportTestCase):
opset_version = GLOBALS.export_onnx_opset_version
opset_version = 20
keep_initializers_as_inputs = False
onnx_shape_inference = True
@ -133,7 +131,7 @@ class TestAutogradFuns(pytorch_test_common.ExportTestCase):
input = torch.ones(1, 5)
# Test ONNX_FALLTHROUGH_MODE
graph, _, _ = _model_to_graph(
graph, _, _ = torch.onnx.utils._model_to_graph(
model,
(input,),
operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
@ -142,7 +140,7 @@ class TestAutogradFuns(pytorch_test_common.ExportTestCase):
self.assertEqual(next(iter).kind(), "prim::PythonOp")
# Test ATEN_FALLBACK_MODE
graph, _, _ = _model_to_graph(
graph, _, _ = torch.onnx.utils._model_to_graph(
model,
(input,),
operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK,

View File

@ -11,7 +11,7 @@ import torch
import torch.onnx
from torch.nn import Module
from torch.onnx import producer_name, producer_version
from torch.onnx._globals import GLOBALS
from torch.onnx._internal.torchscript_exporter._globals import GLOBALS
from torch.testing._internal import common_utils

View File

@ -10,7 +10,7 @@ import onnxscript
from onnxscript.onnx_types import FLOAT
import torch
from torch.onnx._internal import jit_utils
from torch.onnx._internal.torchscript_exporter import jit_utils
from torch.testing._internal import common_utils

View File

@ -9,7 +9,7 @@ import onnxscript
from onnxscript.onnx_types import FLOAT
import torch
from torch.onnx._internal import jit_utils
from torch.onnx._internal.torchscript_exporter import jit_utils
from torch.testing._internal import common_utils

View File

@ -4,8 +4,11 @@ import pytorch_test_common
from pytorch_test_common import skipIfNoCuda
import torch
from torch.onnx import verification
from torch.onnx._globals import GLOBALS
from torch.onnx._internal.torchscript_exporter import verification
from torch.onnx._internal.torchscript_exporter._globals import GLOBALS
from torch.onnx._internal.torchscript_exporter.utils import (
_trigger_symbolic_function_registration,
)
from torch.testing._internal import common_utils
@ -20,6 +23,7 @@ def _jit_graph_to_onnx_model(graph, operator_export_type, opset_version):
"""
GLOBALS.export_onnx_opset_version = opset_version
_trigger_symbolic_function_registration()
graph = torch.onnx.utils._optimize_graph(
graph, operator_export_type, params_dict={}
)

View File

@ -22,7 +22,7 @@ import torch
import torch.nn.functional as F
from torch import Tensor
from torch.onnx import symbolic_helper, utils
from torch.onnx._internal import registration
from torch.onnx._internal.torchscript_exporter import registration
from torch.testing._internal import common_quantization, common_utils, jit_utils
@ -430,9 +430,8 @@ class TestONNXExport(pytorch_test_common.ExportTestCase):
torch.randn(3, 4, requires_grad=True),
mocks=[
unittest.mock.patch(
"torch.onnx._internal.registration.registry.get_function_group",
"torch.onnx._internal.torchscript_exporter.registration.registry.get_function_group",
side_effect=break_is_registered_op_api,
# wraps=registration.registry.get_function_group
)
],
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,

View File

@ -41,7 +41,9 @@ from pytorch_test_common import (
import torch
from torch import Tensor
from torch.nn.utils import rnn as rnn_utils
from torch.onnx import errors, verification
from torch.onnx import errors
from torch.onnx._internal.torchscript_exporter import verification
from torch.onnx._internal.torchscript_exporter._type_utils import JitScalarType
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import skipIfNoLapack
@ -13705,7 +13707,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
input_names=["x"],
)
exported = onnx.load_from_string(f.getvalue())
expected_elem_type = torch.onnx.JitScalarType.from_value(x).onnx_type()
expected_elem_type = JitScalarType.from_value(x).onnx_type()
expected_output_type = onnx.helper.make_optional_type_proto(
onnx.helper.make_tensor_type_proto(expected_elem_type, (dynamic_axis_name,))
)

View File

@ -10,8 +10,8 @@ from pytorch_test_common import skipIfUnsupportedMinOpsetVersion
import torch
from torch.onnx import _constants, utils
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import jit_utils
from torch.onnx._internal.torchscript_exporter import jit_utils
from torch.onnx._internal.torchscript_exporter._globals import GLOBALS
from torch.testing._internal import common_utils

View File

@ -3,7 +3,7 @@
import torch
from torch.onnx import symbolic_helper
from torch.onnx._globals import GLOBALS
from torch.onnx._internal.torchscript_exporter._globals import GLOBALS
from torch.testing._internal import common_utils

View File

@ -1,11 +1,9 @@
# Owner(s): ["module: onnx"]
import copy
import functools
import io
import re
import warnings
from typing import Callable
import onnx
@ -23,7 +21,7 @@ import torch
import torch.onnx
import torch.utils.cpp_extension
from torch.onnx import _constants, OperatorExportTypes, TrainingMode, utils
from torch.onnx._globals import GLOBALS
from torch.onnx._internal.torchscript_exporter._globals import GLOBALS
from torch.onnx.symbolic_helper import _unpack_list, parse_args
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import skipIfNoLapack
@ -86,86 +84,6 @@ class _BaseTestCase(pytorch_test_common.ExportTestCase):
return graph, params_dict, torch_out
@common_utils.instantiate_parametrized_tests
class TestUnconvertibleOps(pytorch_test_common.ExportTestCase):
"""Unit tests for the `unconvertible_ops` function."""
def setUp(self):
class EinsumModule(torch.nn.Module):
def forward(self, x):
return torch.einsum("ii", x)
self.einsum_module = EinsumModule()
def test_it_returns_graph_and_unconvertible_ops_at_lower_opset_version(self):
x = torch.randn(4, 4)
# Einsum is supported since opset 12. It should be unconvertible at opset 9.
graph, unconvertible_ops = utils.unconvertible_ops(
self.einsum_module, (x,), opset_version=9
)
nodes = graph.nodes()
self.assertEqual(next(nodes).kind(), "prim::Constant")
self.assertEqual(next(nodes).kind(), "prim::ListConstruct")
self.assertEqual(next(nodes).kind(), "prim::Constant")
self.assertEqual(next(nodes).kind(), "aten::einsum")
self.assertEqual(unconvertible_ops, ["aten::einsum"])
@common_utils.parametrize(
"jit_function",
[
common_utils.subtest(
functools.partial(torch.jit.trace, example_inputs=torch.randn(4, 4)),
name="traced",
),
common_utils.subtest(torch.jit.script, name="scripted"),
],
)
def test_it_returns_unconvertible_ops_at_lower_opset_version_for_jit_module(
self, jit_function: Callable
):
module = jit_function(self.einsum_module)
x = torch.randn(4, 4)
# Einsum is supported since opset 12. It should be unconvertible at opset 9.
_, unconvertible_ops = utils.unconvertible_ops(module, (x,), opset_version=9)
self.assertEqual(unconvertible_ops, ["aten::einsum"])
@common_utils.parametrize(
"jit_function",
[
common_utils.subtest(lambda x: x, name="nn_module"),
common_utils.subtest(
functools.partial(torch.jit.trace, example_inputs=torch.randn(4, 4)),
name="traced",
),
common_utils.subtest(torch.jit.script, name="scripted"),
],
)
def test_it_returns_empty_list_when_all_ops_convertible(
self, jit_function: Callable
):
module = jit_function(self.einsum_module)
x = torch.randn(4, 4)
# Einsum is supported since opset 12
_, unconvertible_ops = utils.unconvertible_ops(module, (x,), opset_version=12)
self.assertEqual(unconvertible_ops, [])
def test_it_returns_empty_list_when_model_contains_supported_inplace_ops(self):
class SkipConnectionModule(torch.nn.Module):
def forward(self, x):
out = x
out += x
out = torch.nn.functional.relu(out, inplace=True)
return out
module = SkipConnectionModule()
x = torch.randn(4, 4)
_, unconvertible_ops = utils.unconvertible_ops(module, (x,), opset_version=13)
self.assertEqual(unconvertible_ops, [])
@parameterized.parameterized_class(
[
{"opset_version": opset}

View File

@ -13,7 +13,8 @@ import pytorch_test_common
from packaging import version
import torch
from torch.onnx import _constants, _experimental, verification
from torch.onnx import _constants
from torch.onnx._internal.torchscript_exporter import _experimental, verification
from torch.testing._internal import common_utils

View File

@ -21,8 +21,6 @@ import torch.utils.cpp_extension
import torch.utils.data
from torch._utils import try_import
from torch._utils_internal import deprecated
from torch.autograd._functions.utils import check_onnx_broadcast
from torch.onnx.symbolic_opset9 import _prepare_onnx_paddings
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,
@ -790,65 +788,6 @@ class TestCollectEnv(TestCase):
self.assertTrue(info_output.count("\n") >= 17)
class TestONNXUtils(TestCase):
def test_prepare_onnx_paddings(self):
sizes = [2, 3, 4]
pad = [1, 2, 3, 4]
paddings = _prepare_onnx_paddings(len(sizes), pad)
self.assertEqual(paddings, [0, 3, 1, 0, 4, 2])
def test_check_onnx_broadcast(self):
def try_check_onnx_broadcast(dims1, dims2, expect_broadcast, expect_fail):
broadcast = True
fail = False
try:
broadcast = check_onnx_broadcast(dims1, dims2)
except ValueError:
fail = True
self.assertEqual(broadcast, expect_broadcast)
self.assertEqual(fail, expect_fail)
# Case 1, check the case when len(dims1) < len(dims2) and numel(dims2) > 1
dims1 = [3, 4]
dims2 = [2, 3, 4]
try_check_onnx_broadcast(dims1, dims2, True, True)
# Case 2, check the case when len(dims1) < len(dims2) and numel(dims2) == 1
dims1 = [3, 4]
dims2 = [1, 1, 1]
try_check_onnx_broadcast(dims1, dims2, True, False)
# Case 3, check the case when len(dims1) > len(dims2) and numel(dims2) == 1
dims1 = [1, 1]
dims2 = [1]
try_check_onnx_broadcast(dims1, dims2, True, False)
# Case 4, check the case when len(dims1) > len(dims2) and dims1[x:] == dims2
dims1 = [2, 3, 4]
dims2 = [3, 4]
try_check_onnx_broadcast(dims1, dims2, True, False)
# Case 5, check the case when len(dims1) > len(dims2), but dims1[x:] != dims2
dims1 = [2, 3, 4]
dims2 = [1, 4]
try_check_onnx_broadcast(dims1, dims2, True, True)
# Case 6, check the equal case, no broadcast
dims1 = [3, 4]
dims2 = [3, 4]
try_check_onnx_broadcast(dims1, dims2, False, False)
# Case 7, check the case when len(dims1) == len(dims2), but dims1 != dims2
dims1 = [3, 4]
dims2 = [1, 4]
try_check_onnx_broadcast(dims1, dims2, True, True)
# Case 8, check the case when len(dims1) == len(dims2) and numel(s2) == 1
dims1 = [3, 4]
dims2 = [1, 1]
try_check_onnx_broadcast(dims1, dims2, True, False)
class TestHipify(TestCase):
def test_import_hipify(self):
from torch.utils.hipify import hipify_python # noqa: F401

View File

@ -1,6 +1,4 @@
# mypy: allow-untyped-defs
import operator
from functools import reduce
def maybe_view(tensor, size, check_same_size=True):
@ -26,38 +24,3 @@ def maybe_unexpand(tensor, old_size, check_same_size=True):
for dim in expanded_dims:
tensor = tensor.sum(dim, keepdim=True)
return tensor
# Check whether the op enable broadcasting, and whether it is supported by ONNX.
# If dims1 and dims2 are different, then broadcast is True.
# We always assume the combination of dims1 and dims2 is broadcastable.
# The following types of broadcasting are supported in ONNX:
# 1) Only one element in dims2, such as dims2 = [1, 1]
# 2) dims2 is suffix of dims1, such as dims1 = [2, 3, 4], and dims2 = [3, 4]
# Details can be found here: https://github.com/onnx/onnx/blob/master/docs/Operators.md#Gemm
def check_onnx_broadcast(dims1, dims2):
broadcast = False
supported = True
len1 = len(dims1)
len2 = len(dims2)
numel2 = reduce(operator.mul, dims2)
if len1 < len2:
broadcast = True
if numel2 != 1:
supported = False
elif len1 > len2:
broadcast = True
if numel2 != 1 and dims1[len1 - len2 :] != dims2:
supported = False
else:
if dims1 != dims2:
broadcast = True
if numel2 != 1:
supported = False
if not supported:
raise ValueError(
f"Numpy style broadcasting is not supported in ONNX. Input dims are: {dims1}, {dims2}"
)
return broadcast

View File

@ -1053,7 +1053,8 @@ void _trace_post_record(
}
}
}
py::object onnx_globals = py::module::import("torch.onnx._globals");
py::object onnx_globals =
py::module::import("torch.onnx._internal.torchscript_exporter._globals");
py::bool_ is_in_onnx_export =
py::module::import("torch.onnx.__init__").attr("is_in_onnx_export");
py::bool_ is_autograd_inlining_enabled =

View File

@ -261,9 +261,10 @@ void NodeToONNX(
py::dict& env,
py::set& values_in_env) {
py::object onnx = py::module::import("torch.onnx");
py::object onnx_globals = py::module::import("torch.onnx._globals");
py::object onnx_registration =
py::module::import("torch.onnx._internal.registration");
py::object onnx_globals =
py::module::import("torch.onnx._internal.torchscript_exporter._globals");
py::object onnx_registration = py::module::import(
"torch.onnx._internal.torchscript_exporter.registration");
// Setup all the lambda helper functions.

View File

@ -4,92 +4,3 @@ Torch->ONNX converter / exporter.
- User-facing docs: https://pytorch.org/docs/main/onnx.html
- Developer docs: https://github.com/pytorch/pytorch/wiki/PyTorch-ONNX-exporter
> Read the following if you are contributing to `torch.onnx`
## Symbolic functions Opsets
Opset 9 is the base version. It is selected as the base version because
1. It is the first opset version supported by PyTorch export.
2. Opset 9 is more robust than previous opset versions. Opset versions like 7/8 have limitations
that certain basic operators cannot be expressed in ONNX. Instead of basing on these limitations,
we chose to handle them as special cases separately.
Backward support for opset versions beyond opset 7 is not in our roadmap.
For opset versions other than 9, by default they will inherit the symbolic functions defined in
symbolic_opset9.py.
To extend support for updated operators in different opset versions on top of opset 9,
simply add the updated symbolic functions in the respective symbolic_opset{version}.py file.
Check out topk in symbolic_opset10.py, and upsample_nearest2d in symbolic_opset8.py for example.
## Editing Symbolic Files
- Use the internal `registration.onnx_symbolic` decorator to register a new symbolic function. Search for `def reshape(g, self, shape):` to see an example.
- Parameter names must *exactly* match the names in
aten/src/ATen/native/native_functions.yaml, because
dispatch is done with keyword arguments.
- Looking for inplace ops? They're detected by
`_jit_pass_onnx_remove_inplace_ops_for_onnx`, and
transparently dispatched to their non inplace versions in
"run_symbolic_function". See Note [Export inplace](#export-inplace)
### A note on Tensor types
In general, we should avoid depending on the type of Tensor Values contained
within the trace graph. However, this is sometimes unavoidable (due to ONNX
spec requirements, etc). The TensorType object has accessors for these properties that return the property if it is statically known and return nullopt otherwise.
In general, we should prefer to rely on the least specific information possible.
For example, not relying on tensor properties at all is better than relying
on the number of dimensions which is better than relying on
concrete shapes. Doing so will make the export symbolics
more robust to different graphs.
### Extra context for symbolic functions
The first argument of a symbolic function is always a `GraphContext` object.
`GraphContext` contains all methods defined in a `torch.Graph` object and context
for the symbolic function.
In general, symbolic functions only require inputs and attributes to
the original node. An example of a symbolic function needing context is
`prim::Loop`. It needs access to the sub-block of the original node.
### Export inplace
It would be better for us to export inplace annotations,
than to not export them, since it is useful information that can
help the target of an ONNX export export more efficiently. However,
ONNX doesn't currently formalize inplace. Fortunately, it's sound to drop
inplace annotations, but we are losing information this way.
### Pointwise by scalar
What happens if you add a tensor with a constant (e.g., x + 2)? There are
some moving parts to implementing the ONNX translation in this case:
- By the time we get the scalar in a symbolic function here, it is no longer a
Python long/float, but a PyTorch tensor with `numel == 1` (eventually, we want
it to be a zero dim tensor but this change has not happened yet.) However, the
type of this scalar is *exactly* what the user wrote in Python, which may not
match the tensor it is being added to. PyTorch will do implicit conversions on
scalars; however, ONNX will not, so we must do the conversion ourselves. This
is what `symbolic_helper._if_scalar_type_as()` and
`_jit_pass_onnx_scalar_type_analysis` does.
- Dispatch to these functions takes advantage of an outrageous coincidence
between the tensor and scalar name. When we add two tensors together,
you get the dispatch:
add(*[self, other], **{"alpha": alpha})
When you add a tensor and a scalar, you get the dispatch:
add(*[self], **{"other": other, "alpha": alpha})
By having the argument name line up with the name of the scalar attribute
if it exists, we can write a single function for both overloads.

View File

@ -6,78 +6,43 @@ __all__ = [
# Modules
"errors",
"ops",
"symbolic_helper",
"utils",
# All opsets
"symbolic_opset7",
"symbolic_opset8",
"symbolic_opset9",
"symbolic_opset10",
"symbolic_opset11",
"symbolic_opset12",
"symbolic_opset13",
"symbolic_opset14",
"symbolic_opset15",
"symbolic_opset16",
"symbolic_opset17",
"symbolic_opset18",
"symbolic_opset19",
"symbolic_opset20",
# Enums
"OperatorExportTypes",
"TrainingMode",
"TensorProtoDataType",
"JitScalarType",
# Public functions
"export",
"is_in_onnx_export",
"select_model_mode_for_export",
"register_custom_op_symbolic",
"unregister_custom_op_symbolic",
# Base error
"OnnxExporterError",
"ONNXProgram",
]
from typing import Any, Callable, TYPE_CHECKING
from typing_extensions import deprecated
import torch
from torch._C import _onnx as _C_onnx
from torch._C._onnx import OperatorExportTypes, TensorProtoDataType, TrainingMode
from torch._C._onnx import ( # Deprecated members that are excluded from __all__
OperatorExportTypes as OperatorExportTypes,
TensorProtoDataType as TensorProtoDataType,
TrainingMode as TrainingMode,
)
from . import errors, ops
from ._internal.exporter._onnx_program import ONNXProgram
from ._type_utils import JitScalarType
from .errors import OnnxExporterError
from .utils import (
from ._internal.torchscript_exporter import ( # Deprecated members that are excluded from __all__
symbolic_helper,
symbolic_opset10,
symbolic_opset9,
utils,
)
from ._internal.torchscript_exporter._type_utils import (
JitScalarType, # Deprecated members that are excluded from __all__
)
from ._internal.torchscript_exporter.utils import ( # Deprecated members that are excluded from __all__
_run_symbolic_function,
_run_symbolic_method,
register_custom_op_symbolic,
select_model_mode_for_export,
unregister_custom_op_symbolic,
)
from . import ( # usort: skip. Keep the order instead of sorting lexicographically
errors,
ops,
symbolic_helper,
symbolic_opset7,
symbolic_opset8,
symbolic_opset9,
symbolic_opset10,
symbolic_opset11,
symbolic_opset12,
symbolic_opset13,
symbolic_opset14,
symbolic_opset15,
symbolic_opset16,
symbolic_opset17,
symbolic_opset18,
symbolic_opset19,
symbolic_opset20,
utils,
)
from .errors import OnnxExporterError
if TYPE_CHECKING:
@ -85,10 +50,10 @@ if TYPE_CHECKING:
from collections.abc import Collection, Mapping, Sequence
# Set namespace for exposed private names
JitScalarType.__module__ = "torch.onnx"
ONNXProgram.__module__ = "torch.onnx"
OnnxExporterError.__module__ = "torch.onnx"
# TODO(justinchuby): Remove these two properties
producer_name = "pytorch"
producer_version = _C_onnx.PRODUCER_VERSION
@ -385,7 +350,7 @@ def export(
else:
import warnings
from torch.onnx.utils import export
from ._internal.torchscript_exporter.utils import export
warnings.warn(
"You are using the legacy TorchScript-based ONNX export. Starting in PyTorch 2.9, "
@ -429,7 +394,7 @@ def export(
def is_in_onnx_export() -> bool:
"""Returns whether it is in the middle of ONNX export."""
from torch.onnx._globals import GLOBALS
from torch.onnx._internal.exporter import _flags
from torch.onnx._internal.torchscript_exporter._globals import GLOBALS
return GLOBALS.in_onnx_export or _flags._is_onnx_exporting

View File

@ -0,0 +1,91 @@
# TorchScript Exporter
> [!NOTE]
> This directory hosts code for the legacy TorchScript-based ONNX exporter. It is *deprecated* since PyTorch 2.9 and should be removed along with TorchScript.
## Symbolic functions Opsets
Opset 9 is the base version. It is selected as the base version because
1. It is the first opset version supported by PyTorch export.
2. Opset 9 is more robust than previous opset versions. Opset versions like 7/8 have limitations
that certain basic operators cannot be expressed in ONNX. Instead of basing on these limitations,
we chose to handle them as special cases separately.
Backward support for opset versions beyond opset 7 is not in our roadmap.
For opset versions other than 9, by default they will inherit the symbolic functions defined in
symbolic_opset9.py.
To extend support for updated operators in different opset versions on top of opset 9,
simply add the updated symbolic functions in the respective symbolic_opset{version}.py file.
Check out topk in symbolic_opset10.py, and upsample_nearest2d in symbolic_opset8.py for example.
## Editing Symbolic Files
- Use the internal `registration.onnx_symbolic` decorator to register a new symbolic function. Search for `def reshape(g, self, shape):` to see an example; a minimal sketch also follows this list.
- Parameter names must *exactly* match the names in
aten/src/ATen/native/native_functions.yaml, because
dispatch is done with keyword arguments.
- Looking for inplace ops? They're detected by
`_jit_pass_onnx_remove_inplace_ops_for_onnx`, and
transparently dispatched to their non inplace versions in
"run_symbolic_function". See Note [Export inplace](#export-inplace)
### A note on Tensor types
In general, we should avoid depending on the type of Tensor Values contained
within the trace graph. However, this is sometimes unavoidable (due to ONNX
spec requirements, etc). The TensorType object has accessors for these properties that return the property if it is statically known and return nullopt otherwise.
In general, we should prefer to rely on the least specific information possible.
For example, not relying on tensor properties at all is better than relying
on the number of dimensions which is better than relying on
concrete shapes. Doing so will make the export symbolics
more robust to different graphs.
### Extra context for symbolic functions
The first argument of a symbolic function is always a `GraphContext` object.
`GraphContext` contains all methods defined in a `torch.Graph` object and context
for the symbolic function.
In general, symbolic functions only require inputs and attributes to
the original node. An example of a symbolic function needing context is
`prim::Loop`. It needs access to the sub-block of the original node.
### Export inplace
It would be better for us to export inplace annotations,
than to not export them, since it is useful information that can
help the target of an ONNX export export more efficiently. However,
ONNX doesn't currently formalize inplace. Fortunately, it's sound to drop
inplace annotations, but we are losing information this way.
### Pointwise by scalar
What happens if you add a tensor with a constant (e.g., x + 2)? There are
some moving parts to implementing the ONNX translation in this case:
- By the time we get the scalar in a symbolic function here, it is no longer a
Python long/float, but a PyTorch tensor with `numel == 1` (eventually, we want
it to be a zero dim tensor but this change has not happened yet.) However, the
type of this scalar is *exactly* what the user wrote in Python, which may not
match the tensor it is being added to. PyTorch will do implicit conversions on
scalars; however, ONNX will not, so we must do the conversion ourselves. This
is what `symbolic_helper._if_scalar_type_as()` and
`_jit_pass_onnx_scalar_type_analysis` does.
- Dispatch to these functions takes advantage of an outrageous coincidence
between the tensor and scalar name. When we add two tensors together,
you get the dispatch:
add(*[self, other], **{"alpha": alpha})
When you add a tensor and a scalar, you get the dispatch:
add(*[self], **{"other": other, "alpha": alpha})
By having the argument name line up with the name of the scalar attribute
if it exists, we can write a single function for both overloads.

View File

@ -1,9 +1,6 @@
# mypy: allow-untyped-defs
"""Utilities for manipulating the torch.Graph object and the torchscript."""
# TODO(justinchuby): Move more of the symbolic helper functions here and expose
# them to the user.
from __future__ import annotations
import dataclasses
@ -14,8 +11,8 @@ from typing import Any
import torch
from torch import _C
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import registration
from torch.onnx._internal.torchscript_exporter import registration
from torch.onnx._internal.torchscript_exporter._globals import GLOBALS
_ATTR_PATTERN = re.compile("^(.+)_(([ifstgz])|(ty))$")
@ -89,7 +86,6 @@ class GraphContext:
The value representing the single output of this operator (see the `outputs`
keyword argument for multi-return nodes).
"""
# FIXME(justinchuby): Add the return type back once we know how to handle mypy
return _add_op(self, opname, *raw_args, outputs=outputs, **kwargs)
def aten_op(self, operator: str, *args, overload_name: str = "", **kwargs):
@ -211,8 +207,6 @@ def _add_op(
The set of operators and the inputs/attributes they take
is documented at https://github.com/onnx/onnx/blob/master/docs/Operators.md
This function is monkey-patched onto Graph.
Args:
graph_context: The Torch Graph or Block.
opname: The ONNX operator name, e.g., `Abs` or `Add`, or an operator qualified
@ -337,7 +331,6 @@ def _add_attribute(node: _C.Node, key: str, value: Any, aten: bool):
return getattr(node, f"{kind}_")(name, value)
# TODO: Expose this to user when migrating symbolic helper functions to here.
def _is_tensor(x: _C.Value) -> bool:
return x.type().isSubtypeOf(_C.TensorType.get())

View File

@ -9,10 +9,9 @@ import shutil
from typing import Any, TYPE_CHECKING
import torch
import torch.jit._trace
import torch.serialization
from torch.onnx import errors
from torch.onnx._internal import jit_utils, registration
from torch.onnx._internal.torchscript_exporter import jit_utils, registration
if TYPE_CHECKING:

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -0,0 +1,465 @@
# mypy: allow-untyped-defs
# mypy: disable-error-code=arg-type
from __future__ import annotations
import functools
import sys
import torch
from torch._C import _onnx as _C_onnx
from torch.onnx import errors
from torch.onnx._internal.torchscript_exporter import (
_type_utils,
jit_utils,
registration,
symbolic_helper,
symbolic_opset9 as opset9,
utils,
)
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md
# This file exports ONNX ops for opset 12
__all__ = [
"argmax",
"argmin",
"binary_cross_entropy_with_logits",
"celu",
"cross_entropy_loss",
"dropout",
"einsum",
"ge",
"le",
"native_dropout",
"nll_loss",
"nll_loss2d",
"nll_loss_nd",
"outer",
"pow",
"tensordot",
"unfold",
]
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=12)
def _einsum_helper(g: jit_utils.GraphContext, equation, tensors):
if not tensors:
raise RuntimeError("Einsum inputs are empty.")
# ONNX does not support bool for Einsum inputs.
if symbolic_helper._is_bool(tensors[0]):
tensors = [
g.op("Cast", tensor, to_i=_C_onnx.TensorProtoDataType.INT64)
for tensor in tensors
]
return g.op(
"Cast",
g.op("Einsum", *tensors, equation_s=equation),
to_i=_C_onnx.TensorProtoDataType.BOOL,
)
else:
return g.op("Einsum", *tensors, equation_s=equation)
@_onnx_symbolic("aten::einsum")
@symbolic_helper.parse_args("s", "v", "is")
def einsum(g: jit_utils.GraphContext, equation, tensor_list, path=None):
tensors = symbolic_helper._unpack_list(tensor_list)
return _einsum_helper(g, equation, tensors)
@_onnx_symbolic("aten::outer")
@symbolic_helper.parse_args("v", "v")
def outer(g: jit_utils.GraphContext, input, other):
# make sure to cast other to self's type
if _type_utils.JitScalarType.from_value(
other, _type_utils.JitScalarType.UNDEFINED
) != _type_utils.JitScalarType.from_value(input):
other = g.op(
"Cast",
other,
to_i=_type_utils.JitScalarType.from_value(input).onnx_type(),
)
return _einsum_helper(g, "i,j->ij", [input, other])
def _dropout_returns_masked_input_and_mask(
g: jit_utils.GraphContext, input: torch._C.Value, p: float, train: bool
) -> tuple[torch._C.Value, torch._C.Value | None]:
symbolic_helper.check_training_mode(train, "dropout")
# In eval mode, dropout is non-op. That is, if the node's
# train param is set to False, dropout just returns its inputs.
if not train:
return input, None
p = g.op("Constant", value_t=torch.tensor(p))
t = g.op("Constant", value_t=torch.tensor(train, dtype=torch.bool))
r, mask = g.op("Dropout", input, p, t, outputs=2)
return r, mask
@_onnx_symbolic("aten::dropout")
@symbolic_helper.parse_args("v", "f", "b")
def dropout(g: jit_utils.GraphContext, input, p, train):
masked, _ = _dropout_returns_masked_input_and_mask(g, input, p, train)
return masked
@_onnx_symbolic("aten::native_dropout")
@symbolic_helper.parse_args("v", "f", "b")
def native_dropout(g: jit_utils.GraphContext, input, p, train):
return _dropout_returns_masked_input_and_mask(g, input, p, train)
@_onnx_symbolic("aten::nll_loss")
def nll_loss(g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index):
# none reduction : onnx::Constant[value={0}]
# mean reduction : onnx::Constant[value={1}]
# sum reduction : onnx::Constant[value={2}]
reduction = symbolic_helper._maybe_get_const(reduction, "i")
reduction_vals = ["none", "mean", "sum"]
reduction = reduction_vals[reduction]
# in onnx NegativeLogLikelihoodLoss specification, ignore_index is optional without default value.
# therefore we need to set ignore_index attribute even if it is not specified (e.g. ignore_index=-100).
ignore_index = symbolic_helper._maybe_get_const(ignore_index, "i")
if weight.node().mustBeNone():
nllloss = g.op(
"NegativeLogLikelihoodLoss",
self,
target,
reduction_s=reduction,
ignore_index_i=ignore_index,
)
else:
nllloss = g.op(
"NegativeLogLikelihoodLoss",
self,
target,
weight,
reduction_s=reduction,
ignore_index_i=ignore_index,
)
return nllloss
@_onnx_symbolic("aten::nll_loss2d")
def nll_loss2d(
g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index
):
return nll_loss(g, self, target, weight, reduction, ignore_index)
@_onnx_symbolic("aten::nll_loss_nd")
def nll_loss_nd(
g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index
):
return nll_loss(g, self, target, weight, reduction, ignore_index)
@_onnx_symbolic("aten::cross_entropy_loss")
def cross_entropy_loss(
g: jit_utils.GraphContext,
self,
target,
weight,
reduction,
ignore_index,
label_smoothing,
):
# none reduction : onnx::Constant[value={0}]
# mean reduction : onnx::Constant[value={1}]
# sum reduction : onnx::Constant[value={2}]
reduction = symbolic_helper._maybe_get_const(reduction, "i")
reduction_vals = ["none", "mean", "sum"]
reduction = reduction_vals[reduction]
label_smoothing = symbolic_helper._maybe_get_const(label_smoothing, "f")
if label_smoothing is not None and label_smoothing > 0.0:
raise errors.SymbolicValueError(
"Unsupported: ONNX does not support label_smoothing", self
)
# in onnx SoftmaxCrossEntropyLoss specification, ignore_index is optional without default value.
# therefore we need to set ignore_index attribute even if it is not specified (e.g. ignore_index=-100).
ignore_index = symbolic_helper._maybe_get_const(ignore_index, "i")
if weight.node().mustBeNone():
celoss = g.op(
"SoftmaxCrossEntropyLoss",
self,
target,
reduction_s=reduction,
ignore_index_i=ignore_index,
)
else:
celoss = g.op(
"SoftmaxCrossEntropyLoss",
self,
target,
weight,
reduction_s=reduction,
ignore_index_i=ignore_index,
)
return celoss
@_onnx_symbolic("aten::binary_cross_entropy_with_logits")
@symbolic_helper.parse_args("v", "v", "v", "v", "i")
def binary_cross_entropy_with_logits(
g: jit_utils.GraphContext, input, target, weight, pos_weight, reduction
):
p = g.op("Constant", value_t=torch.tensor([1]))
sig_x = opset9.sigmoid(g, input)
log_sig_x = opset9.log(g, sig_x)
sub_1_x = opset9.sub(g, p, sig_x)
sub_1_y = opset9.sub(g, p, target)
log_1_x = opset9.log(g, sub_1_x)
if pos_weight is None or symbolic_helper._is_none(pos_weight):
output = opset9.neg(
g,
opset9.add(
g, opset9.mul(g, target, log_sig_x), opset9.mul(g, sub_1_y, log_1_x)
),
)
else:
output = opset9.neg(
g,
opset9.add(
g,
opset9.mul(g, opset9.mul(g, target, log_sig_x), pos_weight),
opset9.mul(g, sub_1_y, log_1_x),
),
)
if weight is not None and not symbolic_helper._is_none(weight):
output = opset9.mul(g, weight, output)
reduction = symbolic_helper._maybe_get_const(reduction, "i")
if reduction == 0:
return output
elif reduction == 1:
return g.op("ReduceMean", output, keepdims_i=0)
elif reduction == 2:
return g.op("ReduceSum", output, keepdims_i=0)
else:
return symbolic_helper._onnx_unsupported(
"binary_cross_entropy_with_logits with reduction other than none, mean, or sum",
input,
)
@_onnx_symbolic("aten::celu")
def celu(g: jit_utils.GraphContext, self, alpha):
alpha = symbolic_helper._maybe_get_const(alpha, "f")
# if the input is of type double cast it to float
if (
_type_utils.JitScalarType.from_value(self, _type_utils.JitScalarType.UNDEFINED)
== _type_utils.JitScalarType.DOUBLE
):
self = g.op("Cast", self, to_i=_C_onnx.TensorProtoDataType.FLOAT)
out = g.op("Celu", self, alpha_f=alpha)
return g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.DOUBLE)
return g.op("Celu", self, alpha_f=alpha)
@_onnx_symbolic("aten::argmax")
@symbolic_helper.parse_args("v", "v", "b")
def argmax(
g: jit_utils.GraphContext,
input: torch._C.Value,
dim: torch._C.Value,
keepdim: bool,
):
return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMax")
@_onnx_symbolic("aten::argmin")
@symbolic_helper.parse_args("v", "v", "b")
def argmin(
g: jit_utils.GraphContext,
input: torch._C.Value,
dim: torch._C.Value,
keepdim: bool,
):
return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMin")
@_onnx_symbolic("aten::pow")
def pow(g: jit_utils.GraphContext, self, exponent):
return g.op("Pow", self, exponent)
@_onnx_symbolic("aten::ge")
def ge(g: jit_utils.GraphContext, input, other):
return g.op("GreaterOrEqual", input, other)
@_onnx_symbolic("aten::le")
def le(g: jit_utils.GraphContext, input, other):
return g.op("LessOrEqual", input, other)
@_onnx_symbolic("aten::unfold")
@symbolic_helper.parse_args("v", "i", "v", "v")
def unfold(g: jit_utils.GraphContext, input, dimension, size, step):
const_size = symbolic_helper._maybe_get_const(size, "i")
const_step = symbolic_helper._maybe_get_const(step, "i")
if not symbolic_helper._is_value(const_size) and not symbolic_helper._is_value(
const_step
):
return opset9.unfold(g, input, dimension, const_size, const_step)
sizedim = symbolic_helper._get_tensor_dim_size(input, dimension)
if sizedim is not None:
low_start = g.op("Constant", value_t=torch.tensor(0))
low_end = g.op("Constant", value_t=torch.tensor(sizedim))
hi_end = g.op("Constant", value_t=torch.tensor(sizedim + 1))
low_indices = g.op("Range", low_start, low_end, step)
hi_indices = g.op("Range", size, hi_end, step)
low_size = symbolic_helper._size_helper(
g, low_indices, g.op("Constant", value_t=torch.tensor(0))
)
hi_size = symbolic_helper._size_helper(
g, hi_indices, g.op("Constant", value_t=torch.tensor(0))
)
ndim = symbolic_helper._get_tensor_rank(input)
assert ndim is not None
perm = list(range(0, ndim))
perm.append(perm.pop(dimension))
unsqueeze_list = []
loop_condition = g.op("Constant", value_t=torch.tensor(1))
loop_condition = g.op(
"Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL
)
loop_len = g.op("Min", low_size, hi_size)
loop, (loop_context,), _ = jit_utils.add_op_with_blocks(
g, "Loop", loop_len, loop_condition, n_blocks=1
)
loop_block = loop_context.block
block_input_iter = utils._add_input_to_block(loop_block)
cond = utils._add_input_to_block(loop_block) # noqa: F841
starts = loop_context.op("Gather", low_indices, block_input_iter)
ends = loop_context.op("Gather", hi_indices, block_input_iter)
axes = loop_context.op("Constant", value_t=torch.tensor([2]))
starts = symbolic_helper._unsqueeze_helper(loop_context, starts, [0])
ends = symbolic_helper._unsqueeze_helper(loop_context, ends, [0])
stack = loop_context.op("Slice", input, starts, ends, axes)
unsqueeze = symbolic_helper._unsqueeze_helper(
loop_context, loop_context.op("Transpose", stack, perm_i=perm), [dimension]
)
unsqueeze_list.append(unsqueeze)
concat = loop_context.op("Concat", *unsqueeze_list, axis_i=0)
cond_out = loop_context.op(
"Cast", loop_condition, _C_onnx.TensorProtoDataType.BOOL
)
utils._add_output_to_block(loop_block, cond_out)
utils._add_output_to_block(loop_block, concat)
loop_output = loop.node().output()
perm = [0, 1, 2, 3, 4]
perm[0], perm[dimension + 1] = perm[dimension + 1], perm[0]
transpose = g.op("Transpose", loop_output, perm_i=perm)
squeeze = symbolic_helper._squeeze_helper(g, transpose, [0])
return squeeze
return symbolic_helper._unimplemented("Unfold", "input size not accessible")
@_onnx_symbolic("aten::tensordot")
@symbolic_helper.parse_args("v", "v", "is", "is", "v")
def tensordot(g: jit_utils.GraphContext, input_a, input_b, dims_a, dims_b, out=None):
if out is not None:
symbolic_helper._unimplemented(
"Tensordot", "Out parameter is not supported for tensordot."
)
dim_count_a = symbolic_helper._get_tensor_rank(input_a)
if dim_count_a is None:
raise errors.SymbolicValueError(
"Unsupported: ONNX export of tensordot for tensor(input_a) of unknown rank.",
input_a,
)
dim_count_b = symbolic_helper._get_tensor_rank(input_b)
if dim_count_b is None:
raise errors.SymbolicValueError(
"Unsupported: ONNX export of tensordot for tensor(input_b) of unknown rank.",
input_b,
)
dims_a = [
(dims_a[i] + dim_count_a) if (dims_a[i] < 0) else dims_a[i]
for i in range(len(dims_a))
]
dims_b = [
(dims_b[i] + dim_count_b) if (dims_b[i] < 0) else dims_b[i]
for i in range(len(dims_b))
]
left_dims_a = [i for i in range(dim_count_a) if (i not in dims_a)]
left_dims_b = [i for i in range(dim_count_b) if (i not in dims_b)]
new_input_a = opset9.permute(g, input_a, left_dims_a + dims_a)
new_input_b = opset9.permute(g, input_b, dims_b + left_dims_b)
input_shape = g.op("Shape", new_input_a)
left_sizes_a = symbolic_helper._slice_helper(
g, input_shape, axes=[0], starts=[0], ends=[len(left_dims_a)]
)
shape_sizes = [
left_sizes_a,
g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
]
output_a = opset9._reshape_from_tensor(g, new_input_a, shape_sizes)
input_shape = g.op("Shape", output_a)
slices = symbolic_helper._slice_helper(
g, input_shape, axes=[0], starts=[-1], ends=[sys.maxsize]
)
shape_sizes = [
g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
slices,
]
output_a = opset9._reshape_from_tensor(g, new_input_a, shape_sizes)
input_shape = g.op("Shape", new_input_b)
left_sizes_b = symbolic_helper._slice_helper(
g, input_shape, axes=[0], starts=[len(dims_b)], ends=[sys.maxsize]
)
slices = symbolic_helper._slice_helper(
g, input_shape, axes=[0], starts=[0], ends=[len(dims_b)]
)
shape_sizes = [
slices,
g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
]
output_b = opset9._reshape_from_tensor(g, new_input_b, shape_sizes)
input_shape = g.op("Shape", output_b)
slices = symbolic_helper._slice_helper(
g, input_shape, axes=[0], starts=[-1], ends=[sys.maxsize]
)
shape_sizes = [
g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
slices,
]
output_b = opset9._reshape_from_tensor(g, new_input_b, shape_sizes)
output = einsum(g, "ij,jk->ik", g.op("prim::ListConstruct", *[output_a, output_b]))
shape_sizes = [left_sizes_a, left_sizes_b]
return opset9._reshape_from_tensor(g, output, shape_sizes)

File diff suppressed because it is too large

View File

@ -0,0 +1,296 @@
# mypy: allow-untyped-defs
# mypy: disable-error-code=arg-type
"""This file exports ONNX ops for opset 14.
Note [ONNX operators that are added/updated in opset 14]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
New operators:
HardSwish, Trilu
Updated operators:
Reshape
Add, Sub, Mul, Div
GRU, LSTM, RNN
BatchNorm, Cumsum, Relu
"""
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md
from __future__ import annotations
import functools
import torch
from torch.onnx import _constants
from torch.onnx._internal.torchscript_exporter import (
_type_utils,
jit_utils,
registration,
symbolic_helper,
)
from torch.onnx._internal.torchscript_exporter._globals import GLOBALS
__all__ = [
"hardswish",
"tril",
"triu",
"reshape",
"batch_norm",
"quantized_hardswish",
"scaled_dot_product_attention",
]
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=14)
@_onnx_symbolic("aten::hardswish")
@symbolic_helper.parse_args("v")
def hardswish(g: jit_utils.GraphContext, self):
return g.op("HardSwish", self)
@_onnx_symbolic("aten::tril")
def tril(g: jit_utils.GraphContext, self, diagonal, out=None):
return g.op("Trilu", self, diagonal, upper_i=0)
@_onnx_symbolic("aten::triu")
def triu(g: jit_utils.GraphContext, self, diagonal, out=None):
return g.op("Trilu", self, diagonal, upper_i=1)
@_onnx_symbolic("aten::reshape")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "v")
def reshape(g: jit_utils.GraphContext, self, shape):
# NOTE: Due to bug in ORT https://github.com/microsoft/onnxruntime/issues/10664
# Reshape export cannot utilize the new allowzero attribute introduced in opset 14.
return symbolic_helper._reshape_helper(g, self, shape, allowzero=0)
@_onnx_symbolic("aten::batch_norm")
@symbolic_helper.parse_args("v", "v", "v", "v", "v", "i", "f", "f", "i")
def batch_norm(
g: jit_utils.GraphContext,
input,
weight,
bias,
running_mean,
running_var,
training,
momentum,
eps,
cudnn_enabled,
):
if (
torch.is_autocast_enabled()
and not symbolic_helper.args_have_same_dtype(
[input, weight, bias, running_mean, running_var]
)
and GLOBALS.export_onnx_opset_version < 15
):
return symbolic_helper._onnx_opset_unsupported_detailed(
"BatchNormalization",
14,
15,
"All input tensors must have the same `dtype`."
" Turn off Autocast or export using opset version 15.",
input,
)
symbolic_helper.check_training_mode(training, "batch_norm")
weight, bias, running_mean, running_var = symbolic_helper._batchnorm_helper(
g, input, weight, bias, running_mean, running_var
)
out = g.op(
"BatchNormalization",
input,
weight,
bias,
running_mean,
running_var,
epsilon_f=eps,
momentum_f=1 - momentum,
training_mode_i=0 if not training else 1,
outputs=1 if not training else 3,
)
if not training:
return out
else:
res, new_running_mean, new_running_var = out
new_running_mean.setType(running_mean.type())
new_running_var.setType(running_var.type())
return res
@_onnx_symbolic("quantized::hardswish")
def quantized_hardswish(g: jit_utils.GraphContext, x, op_scale, op_zero_point):
x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
output = hardswish(g, x)
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
# Ported from
# https://github.com/microsoft/onnxscript/blob/6b1b81700b4523f31d8c6d3321e5d8ef5d42b764/onnxscript/function_libs/torch_aten/ops/nn.py#L1504
# aten_scaled_dot_product_attention
# NOTE: Need op.Trilu
@_onnx_symbolic("aten::scaled_dot_product_attention")
@symbolic_helper.parse_args("v", "v", "v", "v", "f", "b", "v", "b")
def scaled_dot_product_attention(
g: jit_utils.GraphContext,
query: torch._C.Value,
key: torch._C.Value,
value: torch._C.Value,
attn_mask: torch._C.Value | None = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: torch._C.Value | None = None,
enable_gqa: bool = False,
):
assert (not is_causal) or (is_causal and symbolic_helper._is_none(attn_mask)), (
"is_causal and attn_mask cannot be set at the same time"
)
assert not enable_gqa, (
"conversion of scaled_dot_product_attention not implemented if enable_gqa is True"
)
if symbolic_helper._is_none(scale):
scale = _attention_scale(g, query)
if is_causal:
attn_mask = _causal_attention_mask(g, query, key)
# Swap the last two axes of key
# NOTE: onnx-script has different logic here, because the attribute perms in
# transpose needs list of ints
key_shape_builtin = symbolic_helper._get_tensor_rank(key)
key_transposed_axes = list(range(key_shape_builtin))
key_transposed_axes[-1], key_transposed_axes[-2] = (
key_transposed_axes[-2],
key_transposed_axes[-1],
)
key_transposed = g.op("Transpose", key, perm_i=key_transposed_axes)
# https://github.com/pytorch/pytorch/blob/12da0c70378b5be9135c6fda62a9863bce4a4818/aten/src/ATen/native/transformers/attention.cpp#L653
# Scale q, k before matmul for stability see https://tinyurl.com/sudb9s96 for math
query_scaled = g.op("Mul", query, g.op("Sqrt", scale))
key_transposed_scaled = g.op("Mul", key_transposed, g.op("Sqrt", scale))
mul_qk = g.op("MatMul", query_scaled, key_transposed_scaled)
if symbolic_helper._is_none(attn_mask):
mul_qk_add = mul_qk
attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1)
elif (
_type_utils.JitScalarType.from_value(attn_mask)
== _type_utils.JitScalarType.BOOL
):
# Turn the Boolean mask to float: attn_mask.masked_fill(not attn_mask, -float('inf'))
const_zero = g.op("Constant", value_t=torch.tensor([0.0]))
const_neg_inf = g.op("Constant", value_t=torch.tensor([-float("inf")]))
attn_mask = g.op("Where", attn_mask, const_zero, const_neg_inf)
mul_qk_add = g.op("Add", mul_qk, attn_mask)
attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1)
# When using scaled dot product attention with a boolean mask, the softmax operation might return NaN values
# due to the presence of -inf in an entire row (padding tokens), resulting in 0/0 (NaN) in the softmax output.
# This is because there's no safe softmax imp in ONNX, so we need to handle NaN values explicitly to match
# the behavior of PyTorch with boolean masks.
attn_weight = g.op("Where", g.op("IsNaN", attn_weight), const_zero, attn_weight)
elif _type_utils.JitScalarType.from_value(attn_mask) in (
_type_utils.JitScalarType.FLOAT,
_type_utils.JitScalarType.HALF,
_type_utils.JitScalarType.BFLOAT16,
):
mul_qk_add = g.op("Add", mul_qk, attn_mask)
attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1)
else:
raise ValueError(
f"Unsupported type for attn_mask: {_type_utils.JitScalarType.from_value(attn_mask)}"
)
if dropout_p != 0:
attn_weight = g.op(
"Dropout",
attn_weight,
g.op("Constant", value_t=torch.tensor(dropout_p, dtype=torch.float)),
)
return g.op("MatMul", attn_weight, value)
def _attention_scale(
g: jit_utils.GraphContext, query: torch._C.Value
) -> torch._C.Value:
"""Calculate the scale factor for the attention result.
Args:
query: Tensor of shape [..., L, E]
Returns:
Scalar scale factor := 1 / math.sqrt(query.size(-1))
"""
query_shape = g.op("Shape", query)
query_shape_last = g.op(
"Slice",
query_shape,
g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)),
g.op(
"Constant", value_t=torch.tensor([_constants.INT64_MAX], dtype=torch.int64)
),
)
embedding_size = g.op(
"Cast",
query_shape_last,
to_i=_type_utils.JitScalarType.from_value(query).onnx_type(),
)
const_one = g.op("Constant", value_t=torch.tensor([1.0], dtype=torch.float))
scale = g.op("Div", const_one, g.op("Sqrt", embedding_size))
# Add a Cast to convert the scale back to original type
scale = g.op(
"Cast",
scale,
to_i=_type_utils.JitScalarType.from_value(query).onnx_type(),
)
return scale
def _causal_attention_mask(
g: jit_utils.GraphContext, query: torch._C.Value, key: torch._C.Value
) -> torch._C.Value:
"""Create a causal mask for the given query and key tensors.
Equivalent to::
mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
attn_mask = torch.zeros(L, S, dtype=torch.float)
attn_mask = attn_mask.masked_fill(not mask, -float("inf"))
Args:
query: Tensor of shape [..., L, E]
key: Tensor of shape [..., S, E]
Returns:
Tensor of shape [L, S]
"""
query_shape = g.op("Shape", query)
key_shape = g.op("Shape", key)
last_idx = g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))
second_last_idx = g.op("Constant", value_t=torch.tensor([-2], dtype=torch.int64))
target_length = g.op("Slice", query_shape, second_last_idx, last_idx)
source_length = g.op("Slice", key_shape, second_last_idx, last_idx)
# attn_mask = torch.ones(L, S) := {
size = g.op("Concat", target_length, source_length, axis_i=0)
const_one = g.op("Constant", value_t=torch.tensor([1.0]))
attn_mask = g.op("Expand", const_one, size)
# }
attn_mask = g.op("Trilu", attn_mask, upper_i=0)
# The causal mask has 0s in the lower triangle and -inf in the upper triangle.
const_zero = g.op("Constant", value_t=torch.tensor([0.0]))
const_neg_inf = g.op("Constant", value_t=torch.tensor([-float("inf")]))
attn_mask = g.op(
"Where", g.op("Equal", attn_mask, const_zero), const_neg_inf, const_zero
)
return attn_mask

View File

@ -0,0 +1,84 @@
# mypy: allow-untyped-defs
"""This file exports ONNX ops for opset 15.
Note [ONNX operators that are added/updated in opset 15]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
https://github.com/onnx/onnx/blob/master/docs/Changelog.md#version-15-of-the-default-onnx-operator-set
New operators:
Bernoulli
CastLike
Optional
OptionalGetElement
OptionalHasElement
Updated operators:
BatchNormalization https://github.com/onnx/onnx/pull/3545
Backwards compatible
TODO: test coverage for mixed types inputs.
Pow https://github.com/onnx/onnx/pull/3412
Backwards compatible
TODO: bfloat16 support.
Shape https://github.com/onnx/onnx/pull/3580
Backwards compatible
TODO: optional start/end attribute.
"""
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md
import functools
import torch
from torch import _C
from torch.onnx._internal.torchscript_exporter import (
jit_utils,
registration,
symbolic_helper,
symbolic_opset9 as opset9,
)
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=15)
@_onnx_symbolic("aten::__is_")
def aten__is_(g: jit_utils.GraphContext, self, other):
if symbolic_helper._is_none(other):
if isinstance(self.type(), _C.OptionalType):
none = g.op("OptionalHasElement", self)
return g.op("Not", none)
else:
return g.op("Constant", value_t=torch.BoolTensor([0]))
return opset9.eq(g, self, other)
@_onnx_symbolic("aten::__isnot_")
@opset9.wrap_logical_op_with_negation # type: ignore[has-type]
def aten__isnot_(g: jit_utils.GraphContext, self, other):
return aten__is_(g, self, other)
@_onnx_symbolic("aten::bernoulli")
def bernoulli(g: jit_utils.GraphContext, input, p=None, generator=None, out=None):
if out is not None and not symbolic_helper._is_none(out):
symbolic_helper._unimplemented(
"Bernoulli", "out parameter is not supported for bernoulli", input
)
if generator is not None and not symbolic_helper._is_none(generator):
symbolic_helper._unimplemented(
"Bernoulli", "generator is not supported for bernoulli", input
)
if p is None or symbolic_helper._is_none(p):
return g.op("Bernoulli", input)
return opset9.bernoulli(g, input, p, generator, out)
@_onnx_symbolic("prim::unchecked_cast")
def prim_unchecked_cast(g: jit_utils.GraphContext, self):
# exists to refine the type of the Value
# if x is Optional[Tensor], unchecked_cast will cast
# x to Tensor, so the rest of the graph knows that x is a Tensor.
if isinstance(self.type(), _C.OptionalType):
return g.op("OptionalGetElement", self)
return self

View File

@ -0,0 +1,191 @@
# mypy: allow-untyped-defs
"""This file exports ONNX ops for opset 16.
Note [ONNX Operators that are added/updated in opset 16]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-16-of-the-default-onnx-operator-set
New operators:
GridSample https://github.com/onnx/onnx/pull/3557
Updated operators:
Identity
If
LeakyRelu
Loop
PRelu
RoiAlign
Scan
ScatterElements
ScatterND
Where
GreaterOrEqual
LessOrEqual
"""
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md
import functools
import torch
from torch.nn.functional import (
GRID_SAMPLE_INTERPOLATION_MODES,
GRID_SAMPLE_PADDING_MODES,
)
from torch.onnx import errors
from torch.onnx._internal.torchscript_exporter import (
_type_utils,
jit_utils,
registration,
symbolic_helper,
utils,
)
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=16)
# note (mkozuki): Why `grid_sampler` instead of `grid_sample`?
# Because `torch.nn.functional.grid_sample` calls `torch.grid_sampler`.
@_onnx_symbolic("aten::grid_sampler")
@symbolic_helper.parse_args("v", "v", "i", "i", "b")
def grid_sampler(
g: jit_utils.GraphContext,
input,
grid,
mode_enum,
padding_mode_enum,
align_corners,
):
# Check the input and grid tensor rank beforehand.
if symbolic_helper._get_tensor_rank(input) == 5:
return symbolic_helper._onnx_unsupported("GridSample with 5D volumetric input")
mode_s = {v: k for k, v in GRID_SAMPLE_INTERPOLATION_MODES.items()}[mode_enum] # type: ignore[call-arg]
padding_mode_s = {v: k for k, v in GRID_SAMPLE_PADDING_MODES.items()}[ # type: ignore[call-arg]
padding_mode_enum
]
return g.op(
"GridSample",
input,
grid,
align_corners_i=int(align_corners),
mode_s=mode_s,
padding_mode_s=padding_mode_s,
)
@_onnx_symbolic("aten::scatter_add")
@symbolic_helper.parse_args("v", "i", "v", "v")
def scatter_add(g: jit_utils.GraphContext, self, dim, index, src):
src_type = _type_utils.JitScalarType.from_value(
src, _type_utils.JitScalarType.UNDEFINED
)
src_sizes = symbolic_helper._get_tensor_sizes(src)
index_sizes = symbolic_helper._get_tensor_sizes(index)
if len(src_sizes) != len(index_sizes):
return symbolic_helper._unimplemented(
"scatter_add",
f"`index` ({index_sizes}) should have the same dimensionality as `src` ({src_sizes})",
)
# PyTorch only allows index shape <= src shape, so we can only consider
# taking index as subset size to src, like PyTorch does. When sizes for src
# and index are not matched or there are dynamic axes, we take index shape to
# slice src to accommodate.
if src_sizes != index_sizes or None in index_sizes:
adjusted_shape = g.op("Shape", index)
starts = g.op("Constant", value_t=torch.tensor([0] * len(index_sizes)))
src = g.op("Slice", src, starts, adjusted_shape)
src = symbolic_helper._maybe_get_scalar(src)
if symbolic_helper._is_value(src):
return g.op("ScatterElements", self, index, src, axis_i=dim, reduction_s="add")
else:
# Check if scalar "src" has same type as self (PyTorch allows different
# type for scalar src (but not when src is tensor)). If not, insert Cast node.
if _type_utils.JitScalarType.from_value(self) != src_type:
src = g.op(
"Cast",
src,
to_i=_type_utils.JitScalarType.from_value(self).onnx_type(),
)
return g.op(
"ScatterElements",
self,
index,
src,
axis_i=dim,
reduction_s="add",
)
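# Worked example (illustrative only): the eager semantics the symbolic above
# mirrors -- values from `src` are accumulated into `self` at the positions
# named by `index` along `dim`, and `index` may cover only a subset of `src`.
def _example_scatter_add_eager():
    dest = torch.zeros(3)
    index = torch.tensor([0, 2])
    src = torch.tensor([1.0, 5.0])
    result = dest.scatter_add(0, index, src)
    # result is tensor([1., 0., 5.])
    return result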
@_onnx_symbolic("aten::scatter_reduce")
@symbolic_helper.parse_args("v", "i", "v", "v", "s", "b")
def scatter_reduce(
g: jit_utils.GraphContext,
self: torch._C.Value,
dim: int,
index: torch._C.Value,
src: torch._C.Value,
reduce: str,
include_self: bool,
):
if reduce == "mean":
raise errors.OnnxExporterError(
"ONNX does not support mean reduction for scatter_reduce"
)
if not include_self:
raise errors.OnnxExporterError(
"ONNX does not support include_self=False for scatter_reduce"
)
reduce_mode = { # convert torch string name to onnx string name
"mean": "none", # 'mean' doesn't support in ONNX 1.14 definition
"sum": "add",
"prod": "mul",
"amin": "min",
"amax": "max",
}
onnx_reduce = reduce_mode[reduce]
self_rank = g.op("Size", g.op("Shape", self))
# if self_rank == 0: # assert (index_rank == 0 and rank_src == 0)
self_rank_is_zero = g.op(
"Equal", self_rank, g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64))
)
if_op, (if_context, else_context), _ = jit_utils.add_op_with_blocks(
g, "If", self_rank_is_zero, n_blocks=2, outputs=3
)
neg_1 = if_context.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))
self_reshape = if_context.op("Reshape", self, neg_1)
utils._add_output_to_block(if_context.block, self_reshape)
index_reshape = if_context.op("Reshape", index, neg_1)
utils._add_output_to_block(if_context.block, index_reshape)
src_reshape = if_context.op("Reshape", src, neg_1)
utils._add_output_to_block(if_context.block, src_reshape)
self_identity = else_context.op("Identity", self)
utils._add_output_to_block(else_context.block, self_identity)
    index_identity = else_context.op("Identity", index)
    utils._add_output_to_block(else_context.block, index_identity)
src_identity = else_context.op("Identity", src)
utils._add_output_to_block(else_context.block, src_identity)
result = g.op("ScatterElements", *if_op, axis_i=dim, reduction_s=onnx_reduce)
# if self_rank == 0:
if_op, (if_context, else_context), _ = jit_utils.add_op_with_blocks(
g, "If", self_rank_is_zero, n_blocks=2, outputs=1
)
result_squeezed = if_context.op("Squeeze", result)
utils._add_output_to_block(if_context.block, result_squeezed)
result_identity = else_context.op("Identity", result)
utils._add_output_to_block(else_context.block, result_identity)
result_final = if_op.node().output()
return result_final
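# Worked example (illustrative only): with reduce="amax" and include_self=True,
# each destination slot keeps the maximum of its original value and every `src`
# element scattered into it, which is what the ScatterElements lowering above
# computes for non-scalar inputs.
def _example_scatter_reduce_eager():
    dest = torch.tensor([1.0, 2.0, 3.0])
    index = torch.tensor([0, 0, 2])
    src = torch.tensor([5.0, 4.0, 1.0])
    result = dest.scatter_reduce(0, index, src, reduce="amax", include_self=True)
    # result is tensor([5., 2., 3.])
    return result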

View File

@ -0,0 +1,244 @@
# mypy: allow-untyped-defs
# mypy: disable-error-code=arg-type
"""This file exports ONNX ops for opset 17.
Note [ONNX Operators that are added/updated in opset 17]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-17-of-the-default-onnx-operator-set
New operators:
BlackmanWindow
DFT
HammingWindow
HannWindow
LayerNormalization
MelWeightMatrix
STFT
SequenceMap
"""
import functools
from collections.abc import Sequence
from typing import Optional
import torch
from torch import _C
from torch.onnx import errors
from torch.onnx._internal.torchscript_exporter import (
_type_utils,
jit_utils,
registration,
symbolic_helper,
)
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md
__all__ = ["layer_norm", "stft", "quantized_layer_norm"]
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=17)
@_onnx_symbolic("aten::layer_norm")
@symbolic_helper.parse_args("v", "is", "v", "v", "f", "none")
def layer_norm(
g: jit_utils.GraphContext,
input: _C.Value,
normalized_shape: Sequence[int],
weight: _C.Value,
bias: _C.Value,
eps: float,
cudnn_enable: bool,
):
# normalized_shape: input shape from an expected input of size
# axis: The first normalization dimension.
# layer_norm normalizes on the last D dimensions,
# where D is the size of normalized_shape
axis = -len(normalized_shape)
scalar_type = _type_utils.JitScalarType.from_value(
input, _type_utils.JitScalarType.FLOAT
)
dtype = scalar_type.dtype()
if symbolic_helper._is_none(weight):
weight_value = torch.ones(normalized_shape, dtype=dtype)
weight = g.op("Constant", value_t=weight_value)
if symbolic_helper._is_none(bias):
bias_value = torch.zeros(normalized_shape, dtype=dtype)
bias = g.op("Constant", value_t=bias_value)
return g.op(
"LayerNormalization",
input,
weight,
bias,
epsilon_f=eps,
axis_i=axis,
)
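# Usage sketch (hypothetical, illustrative only): exporting nn.LayerNorm at
# opset 17 goes through the symbolic above and should emit a single ONNX
# LayerNormalization node instead of the decomposed pattern used by older
# opsets.
def _example_export_layer_norm(path: str = "layer_norm.onnx") -> None:
    torch.onnx.export(
        torch.nn.LayerNorm(8),
        (torch.rand(2, 4, 8),),
        path,
        opset_version=17,
    )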
@_onnx_symbolic("quantized::layer_norm")
def quantized_layer_norm(
g: jit_utils.GraphContext,
x,
normalized_shape,
weight,
bias,
eps,
op_scale,
op_zero_point,
):
x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
output = layer_norm(g, x, normalized_shape, weight, bias, eps, False)
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
def _compute_edge_sizes(n_fft, window_size):
"""Helper function to compute the sizes of the edges (left and right)
of a given window centered within an FFT size."""
left = (n_fft - window_size) // 2
right = n_fft - left - window_size
return left, right
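# Worked example (illustrative only): centering a 5-sample window inside an
# 8-point FFT pads one zero on the left and two on the right, because
# (8 - 5) // 2 == 1 and 8 - 1 - 5 == 2.
def _example_edge_sizes():
    left, right = _compute_edge_sizes(n_fft=8, window_size=5)
    assert (left, right) == (1, 2)
    return left, right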
@_onnx_symbolic("aten::stft")
@symbolic_helper.parse_args("v", "i", "i", "i", "v", "b", "b", "b", "b")
def stft(
g: jit_utils.GraphContext,
input: _C.Value,
n_fft: int,
hop_length: Optional[int] = None,
win_length: Optional[int] = None,
window: Optional[_C.Value] = None,
normalized: bool = False,
onesided: Optional[bool] = True,
return_complex: Optional[bool] = False,
align_to_window: Optional[bool] = None,
) -> _C.Value:
"""Associates `torch.stft` with the `STFT` ONNX operator.
Note that torch.stft calls _VF.stft, without centering or padding options.
Hence, this function does not contain these two arguments.
See torch.stft source code for more info.
Args:
g: Graph to write the ONNX representation into
input: Input tensor for the transformation
n_fft: FFT size
        hop_length: Size of the hop. Defaults to `n_fft // 4`
win_length: Size of the analysis window. Defaults to `n_fft`
window: Analysis window. Defaults to a window of all ones
normalized: Whether to return a normalized STFT
onesided: Whether to return only half (+1) of the results, given the
symmetry of the STFT
return_complex: Whether to return the complex value (Note: Must be
`False` or `None`)
Returns:
op: Operator for torch.stft associated with STFT (ONNX)
"""
# Checks
if return_complex:
raise errors.SymbolicValueError(
msg="STFT does not currently support complex types", value=input
)
if align_to_window is not None:
raise errors.SymbolicValueError(
msg="STFT does not currently support the align_to_window option",
value=input,
) # TODO(#145944): add compatibility with align_to_window option.
# Get STFT sizes
frame_step_value = hop_length if hop_length is not None else n_fft // 4
frame_step_const = g.op(
"Constant", value_t=torch.tensor(frame_step_value, dtype=torch.int64)
)
frame_length_const = g.op(
"Constant", value_t=torch.tensor(n_fft, dtype=torch.int64)
)
# Pre-process input if needed
signal = input
signal_rank = symbolic_helper._get_tensor_rank(signal)
if signal_rank == 1:
# Add batch dimension
signal = g.op(
"Unsqueeze",
signal,
g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)),
)
elif signal_rank is None or signal_rank > 2:
raise errors.SymbolicValueError(
msg="STFT can only take inputs of 1 [signal] or 2 [batch, signal] dimensions. "
f"Current rank of signal is {signal_rank}, please reduce it.",
value=input,
)
# Get window and make sure it's the same size as `win_length` or `n_fft`
n_win = symbolic_helper._get_tensor_dim_size(window, dim=0)
if n_win is not None:
win_length_default = win_length if win_length else n_fft
        assert n_win == win_length_default, (
            "Analysis window size must equal `win_length` or `n_fft`. "
            f"Please set `win_length` or `n_fft` to match `window` size ({n_win})"
        )
# Center window around zeros if needed (required by ONNX's STFT)
if n_win < n_fft:
left, right = _compute_edge_sizes(n_fft, n_win)
left_win = g.op("Constant", value_t=torch.zeros(left))
right_win = g.op("Constant", value_t=torch.zeros(right))
window = g.op("Concat", left_win, window, right_win, axis_i=0)
# Create window, if needed
if symbolic_helper._is_none(window):
if win_length:
if win_length > n_fft:
raise errors.SymbolicValueError(
msg="The analysis window can't be longer than the size of the FFT. "
f"Please set `win_length` ({win_length}) to `n_fft` ({n_fft}) or less.",
value=input,
)
# Center window, if needed
left, right = _compute_edge_sizes(n_fft, win_length)
torch_window = torch.hstack(
(torch.zeros(left), torch.ones(win_length), torch.zeros(right))
)
else:
# Rectangle window
torch_window = torch.ones(n_fft)
assert torch_window.shape[0] == n_fft
window = g.op("Constant", value_t=torch_window)
window = g.op(
"Cast", window, to_i=_type_utils.JitScalarType.from_value(signal).onnx_type()
)
# Run STFT
result = g.op(
"STFT",
signal,
frame_step_const,
window,
frame_length_const,
onesided_i=1 if onesided is None or onesided else 0,
)
# Transpose to mimic torch.stft's behavior
result = g.op("Transpose", result, perm_i=[0, 2, 1, 3])
# Remove batch dimension, if needed
if signal_rank == 1:
result = g.op(
"Squeeze",
result,
g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)),
)
# Normalize, if needed
if normalized:
sqrt_nfft = torch.sqrt(torch.tensor(n_fft, dtype=signal.type().dtype()))
result = g.op("Div", result, g.op("Constant", value_t=sqrt_nfft))
return result
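# Usage sketch (hypothetical wrapper, illustrative only): exporting a module
# that calls torch.stft with return_complex=False at opset 17 exercises the
# symbolic above and should lower to an ONNX STFT node followed by a Transpose
# that restores torch.stft's output layout.
def _example_export_stft(path: str = "stft.onnx") -> None:
    class _Stft(torch.nn.Module):
        def forward(self, x):
            return torch.stft(
                x, n_fft=16, hop_length=4, center=False, return_complex=False
            )

    signal = torch.rand(2, 64)  # [batch, signal]
    torch.onnx.export(_Stft(), (signal,), path, opset_version=17)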

View File

@ -0,0 +1,270 @@
# mypy: allow-untyped-defs
"""This file exports ONNX ops for opset 18.
Note [ONNX Operators that are added/updated in opset 18]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-18-of-the-default-onnx-operator-set
New operators:
BitwiseAnd
CenterCropPad
Col2Im
Mish
OptionalGetElement
OptionalHasElement
Pad
Resize
ScatterElements
ScatterND
Split
"""
import functools
from collections.abc import Sequence
from typing import Optional
import torch
from torch import _C
from torch.onnx._internal.torchscript_exporter import (
_type_utils,
jit_utils,
registration,
symbolic_helper,
symbolic_opset9 as opset9,
)
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py
__all__ = [
"col2im",
]
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=18)
@_onnx_symbolic("aten::__and_")
@_onnx_symbolic("aten::bitwise_and")
def __and_(g: jit_utils.GraphContext, self, other):
# do type promotion (scalars don't seem to apply)
args = [self, other]
# type promotion doesn't happen with torch.bitwise_and(tensor, scalar)
prom_args = [arg for arg in args if symbolic_helper._get_tensor_rank(arg)]
if len(prom_args) == 0:
prom_args = args
promotion_jit_type = symbolic_helper._type_promote_from_values(*prom_args)
self = symbolic_helper._maybe_cast_to_type(g, self, promotion_jit_type)
other = symbolic_helper._maybe_cast_to_type(g, other, promotion_jit_type)
if promotion_jit_type == _type_utils.JitScalarType.BOOL:
return g.op("And", self, other)
return g.op("BitwiseAnd", self, other)
@_onnx_symbolic("aten::col2im")
@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is")
def col2im(
g,
input: _C.Value,
output_size: _C.Value,
kernel_size: _C.Value,
dilation: Sequence[int],
padding: Sequence[int],
stride: Sequence[int],
):
# convert [i0, i1, ..., in] into [i0, i0, i1, i1, ..., in, in]
adjusted_padding: list[int] = []
for pad in padding:
adjusted_padding.extend(pad for _ in range(2))
num_dimensional_axis = symbolic_helper._get_tensor_sizes(output_size)[0]
if not adjusted_padding:
adjusted_padding = [0, 0] * num_dimensional_axis
if not dilation:
dilation = [1] * num_dimensional_axis
if not stride:
stride = [1] * num_dimensional_axis
return g.op(
"Col2Im",
input,
output_size,
kernel_size,
dilations_i=dilation,
pads_i=adjusted_padding,
strides_i=stride,
)
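# Worked example (illustrative only): the duplication loop above turns the
# per-dimension padding [1, 2] from aten::col2im into [1, 1, 2, 2] (each value
# repeated once) before it is passed to Col2Im as `pads_i`.
def _example_adjusted_padding():
    padding = [1, 2]
    adjusted: list[int] = []
    for pad in padding:
        adjusted.extend(pad for _ in range(2))
    assert adjusted == [1, 1, 2, 2]
    return adjusted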
@_onnx_symbolic(
"aten::mean", decorate=[symbolic_helper._apply_params("ReduceMean", "mean")]
)
@_onnx_symbolic(
"aten::prod",
decorate=[
symbolic_helper._apply_params(
"ReduceProd", "prod", allow_multi_dim_support=False
)
],
)
def _reduce_with_dtype(onnx_op: str, name: str, allow_multi_dim_support: bool = True):
return symbolic_helper._reduce_with_dtype_helper(
onnx_op, name, allow_multi_dim_support
)
@_onnx_symbolic("aten::native_layer_norm")
@symbolic_helper.quantized_args(True, False, False, False)
@symbolic_helper.parse_args("v", "is", "v", "v", "f")
def _native_layer_norm(
g: jit_utils.GraphContext,
input: _C.Value,
normalized_shape: Sequence[int],
weight: _C.Value,
bias: _C.Value,
eps: float,
) -> tuple[_C.Value, _C.Value, _C.Value]:
return opset9.native_layer_norm(g, input, normalized_shape, weight, bias, eps)
@_onnx_symbolic("aten::glu")
@symbolic_helper.parse_args("v", "i")
def _glu(g: jit_utils.GraphContext, input, dim):
dim_size = symbolic_helper._get_tensor_dim_size(input, dim)
if dim_size is not None:
assert dim_size % 2 == 0
first, second = g.op("Split", input, axis_i=dim, num_outputs_i=2, outputs=2)
return g.op("Mul", first, g.op("Sigmoid", second))
@_onnx_symbolic("aten::max")
# torch.max (same for torch.min) actually has two interfaces smashed together:
# torch.max(x, dim, keepdim) and torch.max(x, y)
# TODO(justinchuby): Support multiple quantized args in output
def max(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
return symbolic_helper._max_helper(g, self, dim_or_y, keepdim)
@_onnx_symbolic("aten::maximum")
@symbolic_helper.quantized_args(True, True)
def maximum(g: jit_utils.GraphContext, input, other):
return max(g, input, dim_or_y=other)
@_onnx_symbolic("aten::min")
# TODO(justinchuby): Support multiple quantized args in output
def min(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
return symbolic_helper._min_helper(g, self, dim_or_y, keepdim)
@_onnx_symbolic("aten::minimum")
@symbolic_helper.quantized_args(True, True)
def minimum(g: jit_utils.GraphContext, input, other):
return min(g, input, dim_or_y=other)
@_onnx_symbolic("aten::amax")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "is", "i")
def amax(g: jit_utils.GraphContext, self, dim, keepdim):
axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long))
return g.op("ReduceMax", self, axes, keepdims_i=keepdim)
@_onnx_symbolic("aten::amin")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "is", "i")
def amin(g: jit_utils.GraphContext, self, dim, keepdim):
axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long))
return g.op("ReduceMin", self, axes, keepdims_i=keepdim)
@_onnx_symbolic("aten::aminmax")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "v", "i")
def aminmax(g: jit_utils.GraphContext, self, dim, keepdim):
if not symbolic_helper._is_none(dim):
dim = symbolic_helper._get_const(dim, "i", "dim")
axes = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long))
return g.op("ReduceMin", self, axes, keepdims_i=keepdim), g.op(
"ReduceMax", self, axes, keepdims_i=keepdim
)
else:
return g.op("ReduceMin", self, keepdims_i=keepdim), g.op(
"ReduceMax", self, keepdims_i=keepdim
)
@_onnx_symbolic("aten::var_mean")
def _var_mean(g: jit_utils.GraphContext, input, *args):
if len(args) == 1:
return symbolic_helper._var_mean_helper(g, input, None, args[0], None)
else:
return symbolic_helper._var_mean_helper(g, input, *args)
@_onnx_symbolic("aten::logsumexp")
@symbolic_helper.parse_args("v", "is", "i")
def _logsumexp(g: jit_utils.GraphContext, input, dim, keepdim):
if dim is None:
return g.op("ReduceLogSumExp", input, keepdims_i=0)
else:
axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long))
return g.op("ReduceLogSumExp", input, axes, keepdims_i=keepdim)
@_onnx_symbolic("aten::linalg_matrix_norm")
@symbolic_helper.parse_args("v", "v", "is", "b", "v")
def _linalg_matrix_norm(
g: jit_utils.GraphContext,
self: torch._C.Value,
ord: torch._C.Value,
dim: list[int],
keepdim: bool,
dtype: torch._C.Value,
):
return opset9.linalg_matrix_norm(g, self, ord, dim, keepdim, dtype)
@_onnx_symbolic("aten::embedding_bag")
@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i")
def embedding_bag(
g: jit_utils.GraphContext,
embedding_matrix,
indices,
offsets,
scale_grad_by_freq,
mode,
sparse,
per_sample_weights,
include_last_offset,
padding_idx,
):
return symbolic_helper._embedding_bag_helper(
g,
embedding_matrix,
indices,
offsets,
scale_grad_by_freq,
mode,
sparse,
per_sample_weights,
include_last_offset,
padding_idx,
)
@_onnx_symbolic("aten::linalg_vector_norm")
@symbolic_helper.parse_args("v", "f", "is", "b", "v")
def linalg_vector_norm(
g: jit_utils.GraphContext,
self: torch._C.Value,
ord: float,
dim: Optional[Sequence[int]],
keepdim: bool,
dtype: torch._C.Value,
):
return symbolic_helper._linalg_vector_norm_helper(g, self, ord, dim, keepdim, dtype)

View File

@ -0,0 +1,31 @@
"""This file exports ONNX ops for opset 19.
Note [ONNX Operators that are added/updated in opset 19]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-19-of-the-default-onnx-operator-set
New operators:
AveragePool
Cast
CastLike
Constant
DeformConv
DequantizeLinear
Equal
Identity
If
Loop
Pad
QuantizeLinear
Reshape
Resize
Scan
Shape
Size
"""
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py
__all__: list[str] = []

View File

@ -0,0 +1,95 @@
# mypy: allow-untyped-defs
"""This file exports ONNX ops for opset 20.
Note [ONNX Operators that are added/updated in opset 20]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-20-of-the-default-onnx-operator-set
New operators:
AffineGrid
ConstantOfShape
DFT
Gelu
GridSample
ImageDecoder
IsInf
IsNaN
ReduceMax
ReduceMin
RegexFullMatch
StringConcat
StringSplit
"""
import functools
import torch.nn.functional as F
from torch import _C
from torch.onnx._internal.torchscript_exporter import (
jit_utils,
registration,
symbolic_helper,
)
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py
__all__ = ["_grid_sampler", "_affine_grid_generator", "gelu"]
def convert_grid_sample_mode(mode_s):
return (
"linear" if mode_s == "bilinear" else "cubic" if mode_s == "bicubic" else mode_s
)
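# Worked example (illustrative only): opset 20 renamed the GridSample
# interpolation modes, so the torch-side strings map as follows.
def _example_mode_conversion():
    assert convert_grid_sample_mode("bilinear") == "linear"
    assert convert_grid_sample_mode("bicubic") == "cubic"
    assert convert_grid_sample_mode("nearest") == "nearest"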
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=20)
@_onnx_symbolic("aten::grid_sampler")
@symbolic_helper.parse_args("v", "v", "i", "i", "b")
def _grid_sampler(
g: jit_utils.GraphContext,
input: _C.Value,
grid: _C.Value,
mode_enum: int,
padding_mode_enum: int,
align_corners: bool,
):
mode_s = {v: k for k, v in F.GRID_SAMPLE_INTERPOLATION_MODES.items()}[mode_enum] # type: ignore[call-arg, index]
# mode string changes at https://onnx.ai/onnx/operators/text_diff_GridSample_16_20.html
mode_s = convert_grid_sample_mode(mode_s)
padding_mode_s = {v: k for k, v in F.GRID_SAMPLE_PADDING_MODES.items()}[ # type: ignore[call-arg, index]
padding_mode_enum # type: ignore[index]
]
return g.op(
"GridSample",
input,
grid,
align_corners_i=int(align_corners),
mode_s=mode_s,
padding_mode_s=padding_mode_s,
)
@_onnx_symbolic("aten::affine_grid_generator")
@symbolic_helper.parse_args("v", "v", "b")
def _affine_grid_generator(
g: jit_utils.GraphContext,
theta: _C.Value,
size: _C.Value,
align_corners: bool,
):
return g.op(
"AffineGrid",
theta,
size,
align_corners_i=int(align_corners),
)
@_onnx_symbolic("aten::gelu")
@symbolic_helper.parse_args("v", "s")
def gelu(g: jit_utils.GraphContext, self: _C.Value, approximate: str = "none"):
return g.op("Gelu", self, approximate_s=approximate)

View File

@ -0,0 +1,71 @@
# mypy: allow-untyped-defs
"""
Note [ONNX operators that are added/updated from opset 7 to opset 8]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
New operators:
Expand
Updated operators:
Min, Max, Sum, Mean: supports multidirectional broadcasting.
MaxPool: added optional indices output.
Scan
"""
import functools
import warnings
from torch.onnx._internal.torchscript_exporter import (
jit_utils,
registration,
symbolic_helper,
symbolic_opset9 as opset9,
)
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=7)
block_listed_operators = (
"scan",
"expand",
"expand_as",
"meshgrid",
"adaptive_max_pool1d",
"adaptive_max_pool2d",
"adaptive_max_pool3d",
"max_pool1d_with_indices",
"max_pool2d_with_indices",
"max_pool3d_with_indices",
)
# NOTE: max, min, sum, mean: broadcasting is not supported in opset 7.
# torch.max (same for torch.min) actually has two interfaces smashed together:
# torch.max(x, dim, keepdim) and torch.max(x, y)
@_onnx_symbolic("aten::max")
def max(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
# torch.max(input, other)
if keepdim is None and dim_or_y is not None:
warnings.warn(
"Multidirectional broadcasting is not supported in opset 7. "
"This might cause the onnx model to be incorrect, if inputs to max operators "
"have different shapes"
)
return opset9.max(g, self, dim_or_y, keepdim)
@_onnx_symbolic("aten::min")
def min(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
# torch.min(input, other)
if keepdim is None and dim_or_y is not None:
warnings.warn(
"Multidirectional broadcasting is not supported in opset 7. "
"This might cause the onnx model to be incorrect, if inputs to min operators "
"have different shapes"
)
return opset9.min(g, self, dim_or_y, keepdim)
for block_listed_op in block_listed_operators:
_onnx_symbolic(f"aten::{block_listed_op}")(
symbolic_helper._block_list_in_opset(block_listed_op)
)

View File

@ -0,0 +1,469 @@
# mypy: allow-untyped-defs
"""
Note [ONNX operators that are added/updated from opset 8 to opset 9]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
New operators:
Compress
ConstantOfShape
EyeLike
MaxUnpool
OneHot
Sinh
Cosh
Asinh
Acosh
Atanh
Shrink
IsNaN
Sign
Erf
Scatter
Where
NonZero
TfIdfVectorizer
MeanVarianceNormalization
Updated operators:
BatchNormalization: removed spatial attribute.
Greater, Less, Constant, MatMul, PRelu, Gemm, Flatten: more data types{integers} supported.
Cast: more data types{string} supported.
Upsample: moved scales from attribute to input.
Scan
"""
import functools
import warnings
import torch
from torch._C import _onnx as _C_onnx
from torch.onnx import errors
from torch.onnx._internal.torchscript_exporter import (
_type_utils,
jit_utils,
registration,
symbolic_helper,
symbolic_opset9 as opset9,
)
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=8)
block_listed_operators = (
"nonzero",
"where",
"scatter",
"scatter_add",
"erf",
"sign",
"isnan",
"gather",
"arange",
"masked_fill",
"index_fill",
"index_copy",
"repeat_interleave",
"any",
"all",
)
for block_listed_op in block_listed_operators:
_onnx_symbolic(f"aten::{block_listed_op}")(
symbolic_helper._block_list_in_opset(block_listed_op)
)
@_onnx_symbolic(
"aten::upsample_nearest1d",
decorate=[symbolic_helper._apply_params("upsample_nearest1d", 3, "nearest")],
)
@_onnx_symbolic(
"aten::upsample_nearest2d",
decorate=[symbolic_helper._apply_params("upsample_nearest2d", 4, "nearest")],
)
@_onnx_symbolic(
"aten::upsample_nearest3d",
decorate=[symbolic_helper._apply_params("upsample_nearest3d", 5, "nearest")],
)
@_onnx_symbolic(
"aten::upsample_linear1d",
decorate=[symbolic_helper._apply_params("upsample_linear1d", 3, "linear")],
)
@_onnx_symbolic(
"aten::upsample_bilinear2d",
decorate=[symbolic_helper._apply_params("upsample_bilinear2d", 4, "linear")],
)
@_onnx_symbolic(
"aten::upsample_trilinear3d",
decorate=[symbolic_helper._apply_params("upsample_trilinear3d", 5, "linear")],
)
def _interpolate(name, dim, interpolate_mode):
def symbolic_fn(g, input, output_size, *args):
scales, align_corners = symbolic_helper._get_interpolate_attributes(
g, interpolate_mode, args
)
symbolic_helper._interpolate_warning(interpolate_mode)
align_corners = symbolic_helper._maybe_get_scalar(align_corners)
if align_corners:
return symbolic_helper._unimplemented(name, "align_corners == True", input)
output_size = symbolic_helper._maybe_get_const(output_size, "is")
if symbolic_helper._is_value(output_size):
return symbolic_helper._unimplemented(
name, "torch._C.Value (output_size) indexing"
)
if scales is None:
scales = [
1.0
if i < 2
else float(output_size[-(dim - i)])
/ float(input.type().sizes()[-(dim - i)])
for i in range(0, dim)
]
return g.op("Upsample", input, mode_s=interpolate_mode, scales_f=scales)
return symbolic_fn
@_onnx_symbolic("aten::__interpolate")
def __interpolate(
g: jit_utils.GraphContext,
input,
size,
scale_factor,
mode,
align_corners,
recompute_scale_factor,
antialias,
):
align_corners = symbolic_helper._maybe_get_const(align_corners, "b")
if not symbolic_helper._is_none(align_corners) and align_corners:
return symbolic_helper._unimplemented("interpolate", "align_corners == True")
if not symbolic_helper._is_none(scale_factor) and symbolic_helper._is_value(
scale_factor
):
return symbolic_helper._unimplemented(
"interpolate", "dynamic scales in opset 8"
)
if not symbolic_helper._is_none(size) and symbolic_helper._is_value(size):
return symbolic_helper._unimplemented("interpolate", "dynamic size in opset 8")
scales, mode = symbolic_helper._interpolate_get_scales_and_mode(
g, input, size, scale_factor, mode, align_corners
)
return g.op("Upsample", input, mode_s=mode, scales_f=scales)
# NOTE: We should create a wrapper for this kind of operation, after resolving the shape/type propagation
# issue for "cast" operators. Some symbolic functions depend on shape information of input tensor, which
# is lost after casting.
def _try_cast_integer_to_float(g: jit_utils.GraphContext, *args):
floating_scalar_types = {
_type_utils.JitScalarType.HALF,
_type_utils.JitScalarType.FLOAT,
_type_utils.JitScalarType.DOUBLE,
}
old_type = None
    # Cast the input tensor to Float if its scalarType is known and is not a floating-point type.
# If casting is performed, return the old scalarType, otherwise return None.
arg0_type = _type_utils.JitScalarType.from_value(
args[0], _type_utils.JitScalarType.UNDEFINED
)
if arg0_type != _type_utils.JitScalarType.UNDEFINED:
old_type = arg0_type
if old_type not in floating_scalar_types:
old_type = old_type.scalar_name() # type: ignore[assignment]
args = tuple(
g.op("Cast", arg, to_i=_C_onnx.TensorProtoDataType.FLOAT)
for arg in args
)
else:
return (None,) + args
else:
warnings.warn(
"Only floating datatype is supported for these operators: "
"{Greater, Less, MatMul, PRelu, Gemm, Flatten}. This might cause "
"the onnx model to be incorrect, if inputs have integer datatypes."
)
return (old_type,) + args
def _cast_to_type(g: jit_utils.GraphContext, input, to_type):
if to_type is None:
return input
return getattr(opset9, f"_cast_{to_type}")(g, input, False)
def _comparison_operator(g: jit_utils.GraphContext, input, other, op_name):
other = symbolic_helper._maybe_get_scalar(other)
other = symbolic_helper._if_scalar_type_as(other, input)
_, input, other = _try_cast_integer_to_float(g, input, other)
return g.op(op_name, input, other)
# NOTE: For symbolics {gt, lt, bmm, matmul, prelu, mm, addmm, view, flatten},
# integer input type not supported in opset8. Cast to float if possible.
@_onnx_symbolic("aten::gt")
def gt(g: jit_utils.GraphContext, input, other):
return _comparison_operator(g, input, other, "Greater")
@_onnx_symbolic("aten::lt")
def lt(g: jit_utils.GraphContext, input, other):
return _comparison_operator(g, input, other, "Less")
@_onnx_symbolic("aten::bmm")
def bmm(g: jit_utils.GraphContext, self, other):
if symbolic_helper._try_get_scalar_type(self):
old_type, self, other = _try_cast_integer_to_float(g, self, other)
return _cast_to_type(g, g.op("MatMul", self, other), old_type)
else:
return g.op("MatMul", self, other)
@_onnx_symbolic("aten::matmul")
def matmul(g: jit_utils.GraphContext, self, other):
return bmm(g, self, other)
@_onnx_symbolic("aten::prelu")
def prelu(g: jit_utils.GraphContext, self, weight):
self_rank = symbolic_helper._get_tensor_rank(self)
weight_sizes = symbolic_helper._get_tensor_sizes(weight)
if self_rank is not None and self_rank > 2:
weight = g.op("Unsqueeze", weight, axes_i=list(range(1, self_rank - 1)))
elif self_rank == 0 and weight_sizes == [1]:
# self and weight are both scalar but weight has rank == 1, squeeze weight.
weight = symbolic_helper._squeeze_helper(g, weight, [0])
if symbolic_helper._try_get_scalar_type(self):
old_type, self, weight = _try_cast_integer_to_float(g, self, weight)
return _cast_to_type(g, g.op("PRelu", self, weight), old_type)
else:
return g.op("PRelu", self, weight)
@_onnx_symbolic("aten::mm")
def mm(g: jit_utils.GraphContext, self, other):
    # Create a dummy C tensor. Only needed for API purposes; the value is
    # irrelevant since beta = 0.
scalar_type = symbolic_helper._try_get_scalar_type(self, other)
if scalar_type is None:
raise errors.SymbolicValueError(
"mm can only operate on tensors with known types", self
)
zero_constant = g.op(
"Constant",
value_t=torch.tensor([0], dtype=scalar_type.dtype()),
)
if symbolic_helper._try_get_scalar_type(self):
old_type, self, other, zero_constant = _try_cast_integer_to_float(
g, self, other, zero_constant
)
return _cast_to_type(
g,
g.op("Gemm", self, other, zero_constant, beta_f=0.0, alpha_f=1.0),
old_type,
)
return g.op("Gemm", self, other, zero_constant, beta_f=0.0, alpha_f=1.0)
@_onnx_symbolic("aten::addmm")
@symbolic_helper.parse_args("v", "v", "v", "t", "t")
def addmm(g: jit_utils.GraphContext, self, mat1, mat2, beta, alpha):
if symbolic_helper._try_get_scalar_type(self):
old_type, self, mat1, mat2 = _try_cast_integer_to_float(g, self, mat1, mat2)
return _cast_to_type(
g,
g.op(
"Gemm",
mat1,
mat2,
self,
beta_f=symbolic_helper._scalar(beta),
alpha_f=symbolic_helper._scalar(alpha),
),
old_type,
)
else:
return g.op(
"Gemm",
mat1,
mat2,
self,
beta_f=symbolic_helper._scalar(beta),
alpha_f=symbolic_helper._scalar(alpha),
)
@_onnx_symbolic("aten::flatten")
def flatten(g: jit_utils.GraphContext, input, start_dim, end_dim):
start_dim_i = symbolic_helper._get_const(start_dim, "i", "start_dim")
end_dim_i = symbolic_helper._get_const(end_dim, "i", "end_dim")
dim = input.type().dim()
if end_dim_i < 0:
end_dim_i = dim + end_dim_i
# use ONNX's Flatten operator for cases where the output shape is 2D
if start_dim_i == 1 and end_dim_i == dim - 1:
if symbolic_helper._try_get_scalar_type(input):
old_type, input = _try_cast_integer_to_float(g, input)
return _cast_to_type(
g, g.op("Flatten", input, axis_i=start_dim_i), old_type
)
else:
return g.op("Flatten", input, axis_i=start_dim_i)
if start_dim_i == 0 and end_dim_i == dim - 2:
if symbolic_helper._try_get_scalar_type(input):
old_type, input = _try_cast_integer_to_float(g, input)
return _cast_to_type(
g, g.op("Flatten", input, axis_i=end_dim_i + 1), old_type
)
else:
return g.op("Flatten", input, axis_i=end_dim_i + 1)
return opset9.flatten(g, input, start_dim, end_dim)
def _constant_fill(g: jit_utils.GraphContext, sizes, dtype: int, const_value):
if dtype is None:
scalar_type = _type_utils.JitScalarType.FLOAT
else:
scalar_type = _type_utils.JitScalarType(dtype)
if not scalar_type.dtype().is_floating_point:
result = g.op(
"ConstantFill",
sizes,
dtype_i=_type_utils.JitScalarType.FLOAT.onnx_type(),
input_as_shape_i=1,
value_f=const_value,
)
return g.op("Cast", result, to_i=scalar_type.onnx_type())
else:
return g.op(
"ConstantFill",
sizes,
dtype_i=scalar_type.onnx_type(),
input_as_shape_i=1,
value_f=const_value,
)
@_onnx_symbolic("aten::empty")
@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
def empty(
g: jit_utils.GraphContext,
sizes,
dtype,
layout,
device,
pin_memory=False,
memory_format=None,
):
return zeros(g, sizes, dtype, layout, device, pin_memory)
@_onnx_symbolic("aten::empty_like")
@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
def empty_like(
g: jit_utils.GraphContext,
input,
dtype,
layout,
device,
pin_memory=False,
memory_format=None,
):
return zeros_like(g, input, dtype, layout, device, pin_memory)
@_onnx_symbolic("aten::zeros")
@symbolic_helper.parse_args("v", "i", "v", "v", "v")
def zeros(g: jit_utils.GraphContext, sizes, dtype, layout, device, pin_memory=False):
# NOTE: no way to set device and layout in ONNX, so we ignore it
return _constant_fill(g, sizes, dtype, 0)
@_onnx_symbolic("aten::zeros_like")
@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
def zeros_like(
g: jit_utils.GraphContext,
input,
dtype,
layout,
device,
pin_memory=False,
memory_format=None,
):
shape = g.op("Shape", input)
return _constant_fill(g, shape, dtype, 0)
@_onnx_symbolic("aten::ones")
@symbolic_helper.parse_args("v", "i", "v", "v", "v")
def ones(g: jit_utils.GraphContext, sizes, dtype, layout, device, pin_memory=False):
return _constant_fill(g, sizes, dtype, 1)
@_onnx_symbolic("aten::ones_like")
@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
def ones_like(
g: jit_utils.GraphContext,
input,
dtype,
layout,
device,
pin_memory=False,
memory_format=None,
):
shape = g.op("Shape", input)
return _constant_fill(g, shape, dtype, 1)
@_onnx_symbolic("aten::full")
def full(
g: jit_utils.GraphContext, sizes, value, dtype, layout, device, pin_memory=False
):
const_value = symbolic_helper._maybe_get_const(value, "t")
if symbolic_helper._is_value(const_value):
tmp = zeros(g, sizes, dtype, layout, device)
return opset9.add(g, tmp, value, g.op("Constant", value_t=torch.tensor(1)))
else:
dtype = symbolic_helper._get_const(dtype, "i", "dtype")
return _constant_fill(g, sizes, dtype, const_value)
@_onnx_symbolic("aten::full_like")
@symbolic_helper.parse_args("v", "f", "i", "v", "v", "v", "v")
def full_like(
g: jit_utils.GraphContext,
input,
fill_value,
dtype,
layout,
device,
pin_memory=False,
memory_format=None,
):
shape = g.op("Shape", input)
return _constant_fill(g, shape, dtype, fill_value)
@_onnx_symbolic("aten::repeat")
def repeat(g: jit_utils.GraphContext, self, repeats):
if not symbolic_helper._is_value(repeats):
repeats = g.op("Constant", value_t=torch.LongTensor(repeats))
if symbolic_helper._is_packed_list(repeats):
repeat_size_len = len(symbolic_helper._unpack_list(repeats))
else:
const_repeats = symbolic_helper._maybe_get_const(repeats, "is")
repeat_size_len = len(const_repeats)
if self.isCompleteTensor():
sizes = self.type().sizes()
diff_dims = repeat_size_len - len(sizes)
if diff_dims > 0:
self = opset9.view(
g, self, g.op("Constant", value_t=torch.tensor([1] * diff_dims + sizes))
)
return g.op("Tile", self, repeats)

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -1,464 +1,8 @@
# mypy: allow-untyped-defs
# mypy: disable-error-code=arg-type
"""Backward compatibility module for torch.onnx.symbolic_opset12."""
from __future__ import annotations
import functools
import sys
import torch
from torch._C import _onnx as _C_onnx
from torch.onnx import (
_type_utils,
errors,
symbolic_helper,
symbolic_opset9 as opset9,
utils,
)
from torch.onnx._internal import jit_utils, registration
__all__: list[str] = []
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md
# This file exports ONNX ops for opset 12
__all__ = [
"argmax",
"argmin",
"binary_cross_entropy_with_logits",
"celu",
"cross_entropy_loss",
"dropout",
"einsum",
"ge",
"le",
"native_dropout",
"nll_loss",
"nll_loss2d",
"nll_loss_nd",
"outer",
"pow",
"tensordot",
"unfold",
]
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=12)
def _einsum_helper(g: jit_utils.GraphContext, equation, tensors):
if not tensors:
raise RuntimeError("Einsum inputs are empty.")
# ONNX does not support bool for Einsum inputs.
if symbolic_helper._is_bool(tensors[0]):
tensors = [
g.op("Cast", tensor, to_i=_C_onnx.TensorProtoDataType.INT64)
for tensor in tensors
]
return g.op(
"Cast",
g.op("Einsum", *tensors, equation_s=equation),
to_i=_C_onnx.TensorProtoDataType.BOOL,
)
else:
return g.op("Einsum", *tensors, equation_s=equation)
@_onnx_symbolic("aten::einsum")
@symbolic_helper.parse_args("s", "v", "is")
def einsum(g: jit_utils.GraphContext, equation, tensor_list, path=None):
tensors = symbolic_helper._unpack_list(tensor_list)
return _einsum_helper(g, equation, tensors)
@_onnx_symbolic("aten::outer")
@symbolic_helper.parse_args("v", "v")
def outer(g: jit_utils.GraphContext, input, other):
# make sure to cast other to self's type
if _type_utils.JitScalarType.from_value(
other, _type_utils.JitScalarType.UNDEFINED
) != _type_utils.JitScalarType.from_value(input):
other = g.op(
"Cast",
other,
to_i=_type_utils.JitScalarType.from_value(input).onnx_type(),
)
return _einsum_helper(g, "i,j->ij", [input, other])
def _dropout_returns_masked_input_and_mask(
g: jit_utils.GraphContext, input: torch._C.Value, p: float, train: bool
) -> tuple[torch._C.Value, torch._C.Value | None]:
symbolic_helper.check_training_mode(train, "dropout")
    # In eval mode, dropout is a no-op. That is, if the node's
    # train param is set to False, dropout just returns its input unchanged.
if not train:
return input, None
p = g.op("Constant", value_t=torch.tensor(p))
t = g.op("Constant", value_t=torch.tensor(train, dtype=torch.bool))
r, mask = g.op("Dropout", input, p, t, outputs=2)
return r, mask
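# Usage sketch (hypothetical, illustrative only): in eval mode the helper above
# returns the input untouched, so the exported graph contains no Dropout node;
# in training mode it emits Dropout with a ratio input and a training-mode flag
# and also returns the mask.
def _example_export_eval_dropout(path: str = "dropout.onnx") -> None:
    model = torch.nn.Dropout(p=0.5).eval()  # eval(): dropout folds away
    torch.onnx.export(model, (torch.rand(2, 3),), path, opset_version=12)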
@_onnx_symbolic("aten::dropout")
@symbolic_helper.parse_args("v", "f", "b")
def dropout(g: jit_utils.GraphContext, input, p, train):
masked, _ = _dropout_returns_masked_input_and_mask(g, input, p, train)
return masked
@_onnx_symbolic("aten::native_dropout")
@symbolic_helper.parse_args("v", "f", "b")
def native_dropout(g: jit_utils.GraphContext, input, p, train):
return _dropout_returns_masked_input_and_mask(g, input, p, train)
@_onnx_symbolic("aten::nll_loss")
def nll_loss(g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index):
# none reduction : onnx::Constant[value={0}]
# mean reduction : onnx::Constant[value={1}]
# sum reduction : onnx::Constant[value={2}]
reduction = symbolic_helper._maybe_get_const(reduction, "i")
reduction_vals = ["none", "mean", "sum"]
reduction = reduction_vals[reduction]
    # In the ONNX NegativeLogLikelihoodLoss specification, ignore_index is optional and has no default value,
    # so we need to set the ignore_index attribute even if it is not specified (e.g. ignore_index=-100).
ignore_index = symbolic_helper._maybe_get_const(ignore_index, "i")
if weight.node().mustBeNone():
nllloss = g.op(
"NegativeLogLikelihoodLoss",
self,
target,
reduction_s=reduction,
ignore_index_i=ignore_index,
)
else:
nllloss = g.op(
"NegativeLogLikelihoodLoss",
self,
target,
weight,
reduction_s=reduction,
ignore_index_i=ignore_index,
)
return nllloss
@_onnx_symbolic("aten::nll_loss2d")
def nll_loss2d(
g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index
):
return nll_loss(g, self, target, weight, reduction, ignore_index)
@_onnx_symbolic("aten::nll_loss_nd")
def nll_loss_nd(
g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index
):
return nll_loss(g, self, target, weight, reduction, ignore_index)
@_onnx_symbolic("aten::cross_entropy_loss")
def cross_entropy_loss(
g: jit_utils.GraphContext,
self,
target,
weight,
reduction,
ignore_index,
label_smoothing,
):
# none reduction : onnx::Constant[value={0}]
# mean reduction : onnx::Constant[value={1}]
# sum reduction : onnx::Constant[value={2}]
reduction = symbolic_helper._maybe_get_const(reduction, "i")
reduction_vals = ["none", "mean", "sum"]
reduction = reduction_vals[reduction]
label_smoothing = symbolic_helper._maybe_get_const(label_smoothing, "f")
if label_smoothing is not None and label_smoothing > 0.0:
raise errors.SymbolicValueError(
"Unsupported: ONNX does not support label_smoothing", self
)
    # In the ONNX SoftmaxCrossEntropyLoss specification, ignore_index is optional and has no default value,
    # so we need to set the ignore_index attribute even if it is not specified (e.g. ignore_index=-100).
ignore_index = symbolic_helper._maybe_get_const(ignore_index, "i")
if weight.node().mustBeNone():
celoss = g.op(
"SoftmaxCrossEntropyLoss",
self,
target,
reduction_s=reduction,
ignore_index_i=ignore_index,
)
else:
celoss = g.op(
"SoftmaxCrossEntropyLoss",
self,
target,
weight,
reduction_s=reduction,
ignore_index_i=ignore_index,
)
return celoss
@_onnx_symbolic("aten::binary_cross_entropy_with_logits")
@symbolic_helper.parse_args("v", "v", "v", "v", "i")
def binary_cross_entropy_with_logits(
g: jit_utils.GraphContext, input, target, weight, pos_weight, reduction
):
p = g.op("Constant", value_t=torch.tensor([1]))
sig_x = opset9.sigmoid(g, input)
log_sig_x = opset9.log(g, sig_x)
sub_1_x = opset9.sub(g, p, sig_x)
sub_1_y = opset9.sub(g, p, target)
log_1_x = opset9.log(g, sub_1_x)
if pos_weight is None or symbolic_helper._is_none(pos_weight):
output = opset9.neg(
g,
opset9.add(
g, opset9.mul(g, target, log_sig_x), opset9.mul(g, sub_1_y, log_1_x)
),
)
else:
output = opset9.neg(
g,
opset9.add(
g,
opset9.mul(g, opset9.mul(g, target, log_sig_x), pos_weight),
opset9.mul(g, sub_1_y, log_1_x),
),
)
if weight is not None and not symbolic_helper._is_none(weight):
output = opset9.mul(g, weight, output)
reduction = symbolic_helper._maybe_get_const(reduction, "i")
if reduction == 0:
return output
elif reduction == 1:
return g.op("ReduceMean", output, keepdims_i=0)
elif reduction == 2:
return g.op("ReduceSum", output, keepdims_i=0)
else:
return symbolic_helper._onnx_unsupported(
"binary_cross_entropy_with_logits with reduction other than none, mean, or sum",
input,
)
@_onnx_symbolic("aten::celu")
def celu(g: jit_utils.GraphContext, self, alpha):
alpha = symbolic_helper._maybe_get_const(alpha, "f")
# if the input is of type double cast it to float
if (
_type_utils.JitScalarType.from_value(self, _type_utils.JitScalarType.UNDEFINED)
== _type_utils.JitScalarType.DOUBLE
):
self = g.op("Cast", self, to_i=_C_onnx.TensorProtoDataType.FLOAT)
out = g.op("Celu", self, alpha_f=alpha)
return g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.DOUBLE)
return g.op("Celu", self, alpha_f=alpha)
@_onnx_symbolic("aten::argmax")
@symbolic_helper.parse_args("v", "v", "b")
def argmax(
g: jit_utils.GraphContext,
input: torch._C.Value,
dim: torch._C.Value,
keepdim: bool,
):
return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMax")
@_onnx_symbolic("aten::argmin")
@symbolic_helper.parse_args("v", "v", "b")
def argmin(
g: jit_utils.GraphContext,
input: torch._C.Value,
dim: torch._C.Value,
keepdim: bool,
):
return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMin")
@_onnx_symbolic("aten::pow")
def pow(g: jit_utils.GraphContext, self, exponent):
return g.op("Pow", self, exponent)
@_onnx_symbolic("aten::ge")
def ge(g: jit_utils.GraphContext, input, other):
return g.op("GreaterOrEqual", input, other)
@_onnx_symbolic("aten::le")
def le(g: jit_utils.GraphContext, input, other):
return g.op("LessOrEqual", input, other)
@_onnx_symbolic("aten::unfold")
@symbolic_helper.parse_args("v", "i", "v", "v")
def unfold(g: jit_utils.GraphContext, input, dimension, size, step):
const_size = symbolic_helper._maybe_get_const(size, "i")
const_step = symbolic_helper._maybe_get_const(step, "i")
if not symbolic_helper._is_value(const_size) and not symbolic_helper._is_value(
const_step
):
return opset9.unfold(g, input, dimension, const_size, const_step)
sizedim = symbolic_helper._get_tensor_dim_size(input, dimension)
if sizedim is not None:
low_start = g.op("Constant", value_t=torch.tensor(0))
low_end = g.op("Constant", value_t=torch.tensor(sizedim))
hi_end = g.op("Constant", value_t=torch.tensor(sizedim + 1))
low_indices = g.op("Range", low_start, low_end, step)
hi_indices = g.op("Range", size, hi_end, step)
low_size = symbolic_helper._size_helper(
g, low_indices, g.op("Constant", value_t=torch.tensor(0))
)
hi_size = symbolic_helper._size_helper(
g, hi_indices, g.op("Constant", value_t=torch.tensor(0))
)
ndim = symbolic_helper._get_tensor_rank(input)
assert ndim is not None
perm = list(range(0, ndim))
perm.append(perm.pop(dimension))
unsqueeze_list = []
loop_condition = g.op("Constant", value_t=torch.tensor(1))
loop_condition = g.op(
"Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL
)
loop_len = g.op("Min", low_size, hi_size)
loop, (loop_context,), _ = jit_utils.add_op_with_blocks(
g, "Loop", loop_len, loop_condition, n_blocks=1
)
loop_block = loop_context.block
block_input_iter = utils._add_input_to_block(loop_block)
cond = utils._add_input_to_block(loop_block) # noqa: F841
starts = loop_context.op("Gather", low_indices, block_input_iter)
ends = loop_context.op("Gather", hi_indices, block_input_iter)
axes = loop_context.op("Constant", value_t=torch.tensor([2]))
starts = symbolic_helper._unsqueeze_helper(loop_context, starts, [0])
ends = symbolic_helper._unsqueeze_helper(loop_context, ends, [0])
stack = loop_context.op("Slice", input, starts, ends, axes)
unsqueeze = symbolic_helper._unsqueeze_helper(
loop_context, loop_context.op("Transpose", stack, perm_i=perm), [dimension]
)
unsqueeze_list.append(unsqueeze)
concat = loop_context.op("Concat", *unsqueeze_list, axis_i=0)
cond_out = loop_context.op(
"Cast", loop_condition, _C_onnx.TensorProtoDataType.BOOL
)
utils._add_output_to_block(loop_block, cond_out)
utils._add_output_to_block(loop_block, concat)
loop_output = loop.node().output()
perm = [0, 1, 2, 3, 4]
perm[0], perm[dimension + 1] = perm[dimension + 1], perm[0]
transpose = g.op("Transpose", loop_output, perm_i=perm)
squeeze = symbolic_helper._squeeze_helper(g, transpose, [0])
return squeeze
return symbolic_helper._unimplemented("Unfold", "input size not accessible")
@_onnx_symbolic("aten::tensordot")
@symbolic_helper.parse_args("v", "v", "is", "is", "v")
def tensordot(g: jit_utils.GraphContext, input_a, input_b, dims_a, dims_b, out=None):
if out is not None:
symbolic_helper._unimplemented(
"Tensordot", "Out parameter is not supported for tensordot."
)
dim_count_a = symbolic_helper._get_tensor_rank(input_a)
if dim_count_a is None:
raise errors.SymbolicValueError(
"Unsupported: ONNX export of tensordot for tensor(input_a) of unknown rank.",
input_a,
)
dim_count_b = symbolic_helper._get_tensor_rank(input_b)
if dim_count_b is None:
raise errors.SymbolicValueError(
"Unsupported: ONNX export of tensordot for tensor(input_b) of unknown rank.",
input_b,
)
dims_a = [
(dims_a[i] + dim_count_a) if (dims_a[i] < 0) else dims_a[i]
for i in range(len(dims_a))
]
dims_b = [
(dims_b[i] + dim_count_b) if (dims_b[i] < 0) else dims_b[i]
for i in range(len(dims_b))
]
left_dims_a = [i for i in range(dim_count_a) if (i not in dims_a)]
left_dims_b = [i for i in range(dim_count_b) if (i not in dims_b)]
new_input_a = opset9.permute(g, input_a, left_dims_a + dims_a)
new_input_b = opset9.permute(g, input_b, dims_b + left_dims_b)
input_shape = g.op("Shape", new_input_a)
left_sizes_a = symbolic_helper._slice_helper(
g, input_shape, axes=[0], starts=[0], ends=[len(left_dims_a)]
)
shape_sizes = [
left_sizes_a,
g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
]
output_a = opset9._reshape_from_tensor(g, new_input_a, shape_sizes)
input_shape = g.op("Shape", output_a)
slices = symbolic_helper._slice_helper(
g, input_shape, axes=[0], starts=[-1], ends=[sys.maxsize]
)
shape_sizes = [
g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
slices,
]
output_a = opset9._reshape_from_tensor(g, new_input_a, shape_sizes)
input_shape = g.op("Shape", new_input_b)
left_sizes_b = symbolic_helper._slice_helper(
g, input_shape, axes=[0], starts=[len(dims_b)], ends=[sys.maxsize]
)
slices = symbolic_helper._slice_helper(
g, input_shape, axes=[0], starts=[0], ends=[len(dims_b)]
)
shape_sizes = [
slices,
g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
]
output_b = opset9._reshape_from_tensor(g, new_input_b, shape_sizes)
input_shape = g.op("Shape", output_b)
slices = symbolic_helper._slice_helper(
g, input_shape, axes=[0], starts=[-1], ends=[sys.maxsize]
)
shape_sizes = [
g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
slices,
]
output_b = opset9._reshape_from_tensor(g, new_input_b, shape_sizes)
output = einsum(g, "ij,jk->ik", g.op("prim::ListConstruct", *[output_a, output_b]))
shape_sizes = [left_sizes_a, left_sizes_b]
return opset9._reshape_from_tensor(g, output, shape_sizes)
from torch.onnx._internal.torchscript_exporter.symbolic_opset12 import * # noqa: F401,F403

File diff suppressed because it is too large Load Diff

View File

@ -1,291 +1,8 @@
# mypy: allow-untyped-defs
# mypy: disable-error-code=arg-type
"""This file exports ONNX ops for opset 14.
"""Backward compatibility module for torch.onnx.symbolic_opset14."""
Note [ONNX operators that are added/updated in opset 14]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
New operators:
HardSwish, Trilu
Updated operators:
Reshape
Add, Sub, Mul, Div
GRU, LSTM, RNN
BatchNorm, Cumsum, Relu
"""
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md
from __future__ import annotations
import functools
import torch
from torch.onnx import _constants, _type_utils, symbolic_helper
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import jit_utils, registration
__all__: list[str] = []
__all__ = [
"hardswish",
"tril",
"triu",
"reshape",
"batch_norm",
"quantized_hardswish",
"scaled_dot_product_attention",
]
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=14)
@_onnx_symbolic("aten::hardswish")
@symbolic_helper.parse_args("v")
def hardswish(g: jit_utils.GraphContext, self):
return g.op("HardSwish", self)
@_onnx_symbolic("aten::tril")
def tril(g: jit_utils.GraphContext, self, diagonal, out=None):
return g.op("Trilu", self, diagonal, upper_i=0)
@_onnx_symbolic("aten::triu")
def triu(g: jit_utils.GraphContext, self, diagonal, out=None):
return g.op("Trilu", self, diagonal, upper_i=1)
@_onnx_symbolic("aten::reshape")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "v")
def reshape(g: jit_utils.GraphContext, self, shape):
# NOTE: Due to bug in ORT https://github.com/microsoft/onnxruntime/issues/10664
# Reshape export cannot utilize the new allowzero attribute introduced in opset 14.
return symbolic_helper._reshape_helper(g, self, shape, allowzero=0)
@_onnx_symbolic("aten::batch_norm")
@symbolic_helper.parse_args("v", "v", "v", "v", "v", "i", "f", "f", "i")
def batch_norm(
g: jit_utils.GraphContext,
input,
weight,
bias,
running_mean,
running_var,
training,
momentum,
eps,
cudnn_enabled,
):
if (
torch.is_autocast_enabled()
and not symbolic_helper.args_have_same_dtype(
[input, weight, bias, running_mean, running_var]
)
and GLOBALS.export_onnx_opset_version < 15
):
return symbolic_helper._onnx_opset_unsupported_detailed(
"BatchNormalization",
14,
15,
"All input tensors must have the same `dtype`."
" Turn off Autocast or export using opset version 15.",
input,
)
symbolic_helper.check_training_mode(training, "batch_norm")
weight, bias, running_mean, running_var = symbolic_helper._batchnorm_helper(
g, input, weight, bias, running_mean, running_var
)
out = g.op(
"BatchNormalization",
input,
weight,
bias,
running_mean,
running_var,
epsilon_f=eps,
momentum_f=1 - momentum,
training_mode_i=0 if not training else 1,
outputs=1 if not training else 3,
)
if not training:
return out
else:
res, new_running_mean, new_running_var = out
new_running_mean.setType(running_mean.type())
new_running_var.setType(running_var.type())
return res
@_onnx_symbolic("quantized::hardswish")
def quantized_hardswish(g: jit_utils.GraphContext, x, op_scale, op_zero_point):
x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
output = hardswish(g, x)
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
# Ported from
# https://github.com/microsoft/onnxscript/blob/6b1b81700b4523f31d8c6d3321e5d8ef5d42b764/onnxscript/function_libs/torch_aten/ops/nn.py#L1504
# aten_scaled_dot_product_attention
# NOTE: Need op.Trilu
@_onnx_symbolic("aten::scaled_dot_product_attention")
@symbolic_helper.parse_args("v", "v", "v", "v", "f", "b", "v", "b")
def scaled_dot_product_attention(
g: jit_utils.GraphContext,
query: torch._C.Value,
key: torch._C.Value,
value: torch._C.Value,
attn_mask: torch._C.Value | None = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: torch._C.Value | None = None,
enable_gqa: bool = False,
):
assert (not is_causal) or (is_causal and symbolic_helper._is_none(attn_mask)), (
"is_causal and attn_mask cannot be set at the same time"
)
assert not enable_gqa, (
"conversion of scaled_dot_product_attention not implemented if enable_gqa is True"
)
if symbolic_helper._is_none(scale):
scale = _attention_scale(g, query)
if is_causal:
attn_mask = _causal_attention_mask(g, query, key)
# Swap the last two axes of key
# NOTE: onnx-script has different logic here, because the attribute perms in
# transpose needs list of ints
key_shape_builtin = symbolic_helper._get_tensor_rank(key)
key_transposed_axes = list(range(key_shape_builtin))
key_transposed_axes[-1], key_transposed_axes[-2] = (
key_transposed_axes[-2],
key_transposed_axes[-1],
)
key_transposed = g.op("Transpose", key, perm_i=key_transposed_axes)
# https://github.com/pytorch/pytorch/blob/12da0c70378b5be9135c6fda62a9863bce4a4818/aten/src/ATen/native/transformers/attention.cpp#L653
# Scale q, k before matmul for stability see https://tinyurl.com/sudb9s96 for math
query_scaled = g.op("Mul", query, g.op("Sqrt", scale))
key_transposed_scaled = g.op("Mul", key_transposed, g.op("Sqrt", scale))
mul_qk = g.op("MatMul", query_scaled, key_transposed_scaled)
if symbolic_helper._is_none(attn_mask):
mul_qk_add = mul_qk
attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1)
elif (
_type_utils.JitScalarType.from_value(attn_mask)
== _type_utils.JitScalarType.BOOL
):
# Turn the Boolean mask to float: attn_mask.masked_fill(not attn_mask, -float('inf'))
const_zero = g.op("Constant", value_t=torch.tensor([0.0]))
const_neg_inf = g.op("Constant", value_t=torch.tensor([-float("inf")]))
attn_mask = g.op("Where", attn_mask, const_zero, const_neg_inf)
mul_qk_add = g.op("Add", mul_qk, attn_mask)
attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1)
# When using scaled dot product attention with a boolean mask, the softmax operation might return NaN values
# due to the presence of -inf in an entire row (padding tokens), resulting in 0/0 (NaN) in the softmax output.
        # This is because there's no safe softmax implementation in ONNX, so we need to handle NaN values explicitly to match
# the behavior of PyTorch with boolean masks.
attn_weight = g.op("Where", g.op("IsNaN", attn_weight), const_zero, attn_weight)
elif _type_utils.JitScalarType.from_value(attn_mask) in (
_type_utils.JitScalarType.FLOAT,
_type_utils.JitScalarType.HALF,
_type_utils.JitScalarType.BFLOAT16,
):
mul_qk_add = g.op("Add", mul_qk, attn_mask)
attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1)
else:
raise ValueError(
f"Unsupported type for attn_mask: {_type_utils.JitScalarType.from_value(attn_mask)}"
)
if dropout_p != 0:
attn_weight = g.op(
"Dropout",
attn_weight,
g.op("Constant", value_t=torch.tensor(dropout_p, dtype=torch.float)),
)
return g.op("MatMul", attn_weight, value)
def _attention_scale(
g: jit_utils.GraphContext, query: torch._C.Value
) -> torch._C.Value:
"""Calculate the scale factor for the attention result.
Args:
query: Tensor of shape [..., L, E]
Returns:
Scalar scale factor := 1 / math.sqrt(query.size(-1))
"""
query_shape = g.op("Shape", query)
query_shape_last = g.op(
"Slice",
query_shape,
g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)),
g.op(
"Constant", value_t=torch.tensor([_constants.INT64_MAX], dtype=torch.int64)
),
)
embedding_size = g.op(
"Cast",
query_shape_last,
to_i=_type_utils.JitScalarType.from_value(query).onnx_type(),
)
const_one = g.op("Constant", value_t=torch.tensor([1.0], dtype=torch.float))
scale = g.op("Div", const_one, g.op("Sqrt", embedding_size))
# Add a Cast to convert the scale back to original type
scale = g.op(
"Cast",
scale,
to_i=_type_utils.JitScalarType.from_value(query).onnx_type(),
)
return scale
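# Numerical sketch (illustrative only) of the sqrt-splitting used in the attention
# symbolic above: multiplying query and key each by sqrt(scale) before the matmul
# is the same as scaling the full product q @ k^T once by scale.
import math
import torch
q, k = torch.randn(2, 3, 8), torch.randn(2, 4, 8)
s = 1.0 / math.sqrt(q.size(-1))
lhs = (q * math.sqrt(s)) @ (k * math.sqrt(s)).transpose(-2, -1)
rhs = (q @ k.transpose(-2, -1)) * s
assert torch.allclose(lhs, rhs, atol=1e-6)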
def _causal_attention_mask(
g: jit_utils.GraphContext, query: torch._C.Value, key: torch._C.Value
) -> torch._C.Value:
"""Create a causal mask for the given query and key tensors.
Equivalent to::
mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
attn_mask = torch.zeros(L, S, dtype=torch.float)
attn_mask = attn_mask.masked_fill(mask.logical_not(), -float("inf"))
Args:
query: Tensor of shape [..., L, E]
key: Tensor of shape [..., S, E]
Returns:
Tensor of shape [L, S]
"""
query_shape = g.op("Shape", query)
key_shape = g.op("Shape", key)
last_idx = g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))
second_last_idx = g.op("Constant", value_t=torch.tensor([-2], dtype=torch.int64))
target_length = g.op("Slice", query_shape, second_last_idx, last_idx)
source_length = g.op("Slice", key_shape, second_last_idx, last_idx)
# attn_mask = torch.ones(L, S) := {
size = g.op("Concat", target_length, source_length, axis_i=0)
const_one = g.op("Constant", value_t=torch.tensor([1.0]))
attn_mask = g.op("Expand", const_one, size)
# }
attn_mask = g.op("Trilu", attn_mask, upper_i=0)
# The causal mask has 0s in the lower triangle and -inf in the upper triangle.
const_zero = g.op("Constant", value_t=torch.tensor([0.0]))
const_neg_inf = g.op("Constant", value_t=torch.tensor([-float("inf")]))
attn_mask = g.op(
"Where", g.op("Equal", attn_mask, const_zero), const_neg_inf, const_zero
)
return attn_mask
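# Eager-mode sketch (illustrative only) of the mask built above for L=3 queries and
# S=4 keys: zeros on and below the diagonal, -inf above it.
import torch
tril = torch.ones(3, 4).tril(diagonal=0)
causal = torch.where(tril == 0, torch.tensor(float("-inf")), torch.tensor(0.0))
# tensor([[0., -inf, -inf, -inf],
#         [0.,   0., -inf, -inf],
#         [0.,   0.,   0., -inf]])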
from torch.onnx._internal.torchscript_exporter.symbolic_opset14 import * # noqa: F401,F403


@ -1,80 +1,8 @@
# mypy: allow-untyped-defs
"""This file exports ONNX ops for opset 15.
"""Backward compatibility module for torch.onnx.symbolic_opset15."""
Note [ONNX operators that are added/updated in opset 15]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
https://github.com/onnx/onnx/blob/master/docs/Changelog.md#version-15-of-the-default-onnx-operator-set
New operators:
Bernoulli
CastLike
Optional
OptionalGetElement
OptionalHasElement
Updated operators:
BatchNormalization https://github.com/onnx/onnx/pull/3545
Backwards compatible
TODO: test coverage for mixed types inputs.
Pow https://github.com/onnx/onnx/pull/3412
Backwards compatible
TODO: bfloat16 support.
Shape https://github.com/onnx/onnx/pull/3580
Backwards compatible
TODO: optional start/end attribute.
"""
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md
import functools
import torch
from torch import _C
from torch.onnx import symbolic_helper, symbolic_opset9 as opset9
from torch.onnx._internal import jit_utils, registration
from __future__ import annotations
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=15)
__all__: list[str] = []
@_onnx_symbolic("aten::__is_")
def aten__is_(g: jit_utils.GraphContext, self, other):
if symbolic_helper._is_none(other):
if isinstance(self.type(), _C.OptionalType):
none = g.op("OptionalHasElement", self)
return g.op("Not", none)
else:
return g.op("Constant", value_t=torch.BoolTensor([0]))
return opset9.eq(g, self, other)
@_onnx_symbolic("aten::__isnot_")
@opset9.wrap_logical_op_with_negation # type: ignore[has-type]
def aten__isnot_(g: jit_utils.GraphContext, self, other):
return aten__is_(g, self, other)
@_onnx_symbolic("aten::bernoulli")
def bernoulli(g: jit_utils.GraphContext, input, p=None, generator=None, out=None):
if out is not None and not symbolic_helper._is_none(out):
symbolic_helper._unimplemented(
"Bernoulli", "out parameter is not supported for bernoulli", input
)
if generator is not None and not symbolic_helper._is_none(generator):
symbolic_helper._unimplemented(
"Bernoulli", "generator is not supported for bernoulli", input
)
if p is None or symbolic_helper._is_none(p):
return g.op("Bernoulli", input)
return opset9.bernoulli(g, input, p, generator, out)
@_onnx_symbolic("prim::unchecked_cast")
def prim_unchecked_cast(g: jit_utils.GraphContext, self):
# exists to refine the type of the Value
# if x is Optional[Tensor], unchecked_cast will cast
# x to Tensor, so the rest of the graph knows that x is a Tensor.
if isinstance(self.type(), _C.OptionalType):
return g.op("OptionalGetElement", self)
return self
from torch.onnx._internal.torchscript_exporter.symbolic_opset15 import * # noqa: F401,F403


@ -1,185 +1,8 @@
# mypy: allow-untyped-defs
"""This file exports ONNX ops for opset 16.
"""Backward compatibility module for torch.onnx.symbolic_opset16."""
Note [ONNX Operators that are added/updated in opset 16]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-16-of-the-default-onnx-operator-set
New operators:
GridSample https://github.com/onnx/onnx/pull/3557
Updated operators:
Identity
If
LeakyRelu
Loop
PRelu
RoiAlign
Scan
ScatterElements
ScatterND
Where
GreaterOrEqual
LessOrEqual
"""
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md
import functools
import torch
from torch.nn.functional import (
GRID_SAMPLE_INTERPOLATION_MODES,
GRID_SAMPLE_PADDING_MODES,
)
from torch.onnx import _type_utils, errors, symbolic_helper, utils
from torch.onnx._internal import jit_utils, registration
from __future__ import annotations
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=16)
__all__: list[str] = []
# note (mkozuki): Why `grid_sampler` instead of `grid_sample`?
# Because `torch.nn.functional.grid_sample` calls `torch.grid_sampler`.
@_onnx_symbolic("aten::grid_sampler")
@symbolic_helper.parse_args("v", "v", "i", "i", "b")
def grid_sampler(
g: jit_utils.GraphContext,
input,
grid,
mode_enum,
padding_mode_enum,
align_corners,
):
# Check the input and grid tensor rank beforehand.
if symbolic_helper._get_tensor_rank(input) == 5:
return symbolic_helper._onnx_unsupported("GridSample with 5D volumetric input")
mode_s = {v: k for k, v in GRID_SAMPLE_INTERPOLATION_MODES.items()}[mode_enum] # type: ignore[call-arg]
padding_mode_s = {v: k for k, v in GRID_SAMPLE_PADDING_MODES.items()}[ # type: ignore[call-arg]
padding_mode_enum
]
return g.op(
"GridSample",
input,
grid,
align_corners_i=int(align_corners),
mode_s=mode_s,
padding_mode_s=padding_mode_s,
)
@_onnx_symbolic("aten::scatter_add")
@symbolic_helper.parse_args("v", "i", "v", "v")
def scatter_add(g: jit_utils.GraphContext, self, dim, index, src):
src_type = _type_utils.JitScalarType.from_value(
src, _type_utils.JitScalarType.UNDEFINED
)
src_sizes = symbolic_helper._get_tensor_sizes(src)
index_sizes = symbolic_helper._get_tensor_sizes(index)
if len(src_sizes) != len(index_sizes):
return symbolic_helper._unimplemented(
"scatter_add",
f"`index` ({index_sizes}) should have the same dimensionality as `src` ({src_sizes})",
)
# PyTorch only allows index shape <= src shape, so we can only consider
# taking index as subset size to src, like PyTorch does. When sizes for src
# and index are not matched or there are dynamic axes, we take index shape to
# slice src to accommodate.
if src_sizes != index_sizes or None in index_sizes:
adjusted_shape = g.op("Shape", index)
starts = g.op("Constant", value_t=torch.tensor([0] * len(index_sizes)))
src = g.op("Slice", src, starts, adjusted_shape)
src = symbolic_helper._maybe_get_scalar(src)
if symbolic_helper._is_value(src):
return g.op("ScatterElements", self, index, src, axis_i=dim, reduction_s="add")
else:
# Check if scalar "src" has same type as self (PyTorch allows different
# type for scalar src (but not when src is tensor)). If not, insert Cast node.
if _type_utils.JitScalarType.from_value(self) != src_type:
src = g.op(
"Cast",
src,
to_i=_type_utils.JitScalarType.from_value(self).onnx_type(),
)
return g.op(
"ScatterElements",
self,
index,
src,
axis_i=dim,
reduction_s="add",
)
@_onnx_symbolic("aten::scatter_reduce")
@symbolic_helper.parse_args("v", "i", "v", "v", "s", "b")
def scatter_reduce(
g: jit_utils.GraphContext,
self: torch._C.Value,
dim: int,
index: torch._C.Value,
src: torch._C.Value,
reduce: str,
include_self: bool,
):
if reduce == "mean":
raise errors.OnnxExporterError(
"ONNX does not support mean reduction for scatter_reduce"
)
if not include_self:
raise errors.OnnxExporterError(
"ONNX does not support include_self=False for scatter_reduce"
)
reduce_mode = { # convert torch string name to onnx string name
"mean": "none", # 'mean' doesn't support in ONNX 1.14 definition
"sum": "add",
"prod": "mul",
"amin": "min",
"amax": "max",
}
onnx_reduce = reduce_mode[reduce]
self_rank = g.op("Size", g.op("Shape", self))
# if self_rank == 0: # assert (index_rank == 0 and rank_src == 0)
self_rank_is_zero = g.op(
"Equal", self_rank, g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64))
)
if_op, (if_context, else_context), _ = jit_utils.add_op_with_blocks(
g, "If", self_rank_is_zero, n_blocks=2, outputs=3
)
neg_1 = if_context.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))
self_reshape = if_context.op("Reshape", self, neg_1)
utils._add_output_to_block(if_context.block, self_reshape)
index_reshape = if_context.op("Reshape", index, neg_1)
utils._add_output_to_block(if_context.block, index_reshape)
src_reshape = if_context.op("Reshape", src, neg_1)
utils._add_output_to_block(if_context.block, src_reshape)
self_identity = else_context.op("Identity", self)
utils._add_output_to_block(else_context.block, self_identity)
index_identitye = else_context.op("Identity", index)
utils._add_output_to_block(else_context.block, index_identitye)
src_identity = else_context.op("Identity", src)
utils._add_output_to_block(else_context.block, src_identity)
result = g.op("ScatterElements", *if_op, axis_i=dim, reduction_s=onnx_reduce)
# if self_rank == 0:
if_op, (if_context, else_context), _ = jit_utils.add_op_with_blocks(
g, "If", self_rank_is_zero, n_blocks=2, outputs=1
)
result_squeezed = if_context.op("Squeeze", result)
utils._add_output_to_block(if_context.block, result_squeezed)
result_identity = else_context.op("Identity", result)
utils._add_output_to_block(else_context.block, result_identity)
result_final = if_op.node().output()
return result_final
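# Eager-mode example (illustrative only) of the semantics lowered above to
# ScatterElements with reduction="add": with include_self=True, src values are
# summed into the indexed positions of self.
import torch
base = torch.zeros(5)
index_example = torch.tensor([0, 1, 1, 3])
src_example = torch.tensor([1.0, 2.0, 3.0, 4.0])
out = base.scatter_reduce(0, index_example, src_example, reduce="sum", include_self=True)
# out: tensor([1., 5., 0., 4., 0.])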
from torch.onnx._internal.torchscript_exporter.symbolic_opset16 import * # noqa: F401,F403


@ -1,239 +1,8 @@
# mypy: allow-untyped-defs
# mypy: disable-error-code=arg-type
"""This file exports ONNX ops for opset 17.
"""Backward compatibility module for torch.onnx.symbolic_opset17."""
Note [ONNX Operators that are added/updated in opset 17]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-17-of-the-default-onnx-operator-set
New operators:
BlackmanWindow
DFT
HammingWindow
HannWindow
LayerNormalization
MelWeightMatrix
STFT
SequenceMap
"""
import functools
from collections.abc import Sequence
from typing import Optional
import torch
from torch import _C
from torch.onnx import _type_utils, errors, symbolic_helper
from torch.onnx._internal import jit_utils, registration
from __future__ import annotations
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md
__all__: list[str] = []
__all__ = ["layer_norm", "stft", "quantized_layer_norm"]
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=17)
@_onnx_symbolic("aten::layer_norm")
@symbolic_helper.parse_args("v", "is", "v", "v", "f", "none")
def layer_norm(
g: jit_utils.GraphContext,
input: _C.Value,
normalized_shape: Sequence[int],
weight: _C.Value,
bias: _C.Value,
eps: float,
cudnn_enable: bool,
):
# normalized_shape: input shape from an expected input of size
# axis: The first normalization dimension.
# layer_norm normalizes on the last D dimensions,
# where D is the size of normalized_shape
axis = -len(normalized_shape)
scalar_type = _type_utils.JitScalarType.from_value(
input, _type_utils.JitScalarType.FLOAT
)
dtype = scalar_type.dtype()
if symbolic_helper._is_none(weight):
weight_value = torch.ones(normalized_shape, dtype=dtype)
weight = g.op("Constant", value_t=weight_value)
if symbolic_helper._is_none(bias):
bias_value = torch.zeros(normalized_shape, dtype=dtype)
bias = g.op("Constant", value_t=bias_value)
return g.op(
"LayerNormalization",
input,
weight,
bias,
epsilon_f=eps,
axis_i=axis,
)
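# Eager-mode sketch (illustrative only, assumed shapes) of the axis convention used
# above: LayerNormalization normalizes over the trailing len(normalized_shape)
# dimensions, hence axis = -len(normalized_shape).
import torch
import torch.nn.functional as F
x = torch.randn(2, 3, 4, 5)
shape = (4, 5)  # axis = -len(shape) = -2
expected = F.layer_norm(x, shape)
mean = x.mean(dim=(-2, -1), keepdim=True)
var = x.var(dim=(-2, -1), unbiased=False, keepdim=True)
assert torch.allclose(expected, (x - mean) / torch.sqrt(var + 1e-5), atol=1e-5)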
@_onnx_symbolic("quantized::layer_norm")
def quantized_layer_norm(
g: jit_utils.GraphContext,
x,
normalized_shape,
weight,
bias,
eps,
op_scale,
op_zero_point,
):
x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
output = layer_norm(g, x, normalized_shape, weight, bias, eps, False)
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
def _compute_edge_sizes(n_fft, window_size):
"""Helper function to compute the sizes of the edges (left and right)
of a given window centered within an FFT size."""
left = (n_fft - window_size) // 2
right = n_fft - left - window_size
return left, right
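# Quick check (illustrative only) of _compute_edge_sizes above: centering a
# 6-sample window inside an 8-point FFT leaves one zero on each side, and any
# uneven remainder goes to the right edge.
assert _compute_edge_sizes(8, 6) == (1, 1)
assert _compute_edge_sizes(10, 7) == (1, 2)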
@_onnx_symbolic("aten::stft")
@symbolic_helper.parse_args("v", "i", "i", "i", "v", "b", "b", "b", "b")
def stft(
g: jit_utils.GraphContext,
input: _C.Value,
n_fft: int,
hop_length: Optional[int] = None,
win_length: Optional[int] = None,
window: Optional[_C.Value] = None,
normalized: bool = False,
onesided: Optional[bool] = True,
return_complex: Optional[bool] = False,
align_to_window: Optional[bool] = None,
) -> _C.Value:
"""Associates `torch.stft` with the `STFT` ONNX operator.
Note that torch.stft calls _VF.stft, without centering or padding options.
Hence, this function does not contain these two arguments.
See torch.stft source code for more info.
Args:
g: Graph to write the ONNX representation into
input: Input tensor for the transformation
n_fft: FFT size
hop_length: Size of the hop. Defaults to `floor(n_fft // 4)`
win_length: Size of the analysis window. Defaults to `n_fft`
window: Analysis window. Defaults to a window of all ones
normalized: Whether to return a normalized STFT
onesided: Whether to return only half (+1) of the results, given the
symmetry of the STFT
return_complex: Whether to return the complex value (Note: Must be
`False` or `None`)
Returns:
op: Operator for torch.stft associated with STFT (ONNX)
"""
# Checks
if return_complex:
raise errors.SymbolicValueError(
msg="STFT does not currently support complex types", value=input
)
if align_to_window is not None:
raise errors.SymbolicValueError(
msg="STFT does not currently support the align_to_window option",
value=input,
) # TODO(#145944): add compatibility with align_to_window option.
# Get STFT sizes
frame_step_value = hop_length if hop_length is not None else n_fft // 4
frame_step_const = g.op(
"Constant", value_t=torch.tensor(frame_step_value, dtype=torch.int64)
)
frame_length_const = g.op(
"Constant", value_t=torch.tensor(n_fft, dtype=torch.int64)
)
# Pre-process input if needed
signal = input
signal_rank = symbolic_helper._get_tensor_rank(signal)
if signal_rank == 1:
# Add batch dimension
signal = g.op(
"Unsqueeze",
signal,
g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)),
)
elif signal_rank is None or signal_rank > 2:
raise errors.SymbolicValueError(
msg="STFT can only take inputs of 1 [signal] or 2 [batch, signal] dimensions. "
f"Current rank of signal is {signal_rank}, please reduce it.",
value=input,
)
# Get window and make sure it's the same size as `win_length` or `n_fft`
n_win = symbolic_helper._get_tensor_dim_size(window, dim=0)
if n_win is not None:
win_length_default = win_length if win_length else n_fft
assert n_win == win_length_default, (
"Analysis window size must equal `win_length` or `n_fft`. "
f"Please, set `win_length` or `n_fft` to match `window` size ({n_win})",
)
# Center window around zeros if needed (required by ONNX's STFT)
if n_win < n_fft:
left, right = _compute_edge_sizes(n_fft, n_win)
left_win = g.op("Constant", value_t=torch.zeros(left))
right_win = g.op("Constant", value_t=torch.zeros(right))
window = g.op("Concat", left_win, window, right_win, axis_i=0)
# Create window, if needed
if symbolic_helper._is_none(window):
if win_length:
if win_length > n_fft:
raise errors.SymbolicValueError(
msg="The analysis window can't be longer than the size of the FFT. "
f"Please set `win_length` ({win_length}) to `n_fft` ({n_fft}) or less.",
value=input,
)
# Center window, if needed
left, right = _compute_edge_sizes(n_fft, win_length)
torch_window = torch.hstack(
(torch.zeros(left), torch.ones(win_length), torch.zeros(right))
)
else:
# Rectangle window
torch_window = torch.ones(n_fft)
assert torch_window.shape[0] == n_fft
window = g.op("Constant", value_t=torch_window)
window = g.op(
"Cast", window, to_i=_type_utils.JitScalarType.from_value(signal).onnx_type()
)
# Run STFT
result = g.op(
"STFT",
signal,
frame_step_const,
window,
frame_length_const,
onesided_i=1 if onesided is None or onesided else 0,
)
# Transpose to mimic torch.stft's behavior
result = g.op("Transpose", result, perm_i=[0, 2, 1, 3])
# Remove batch dimension, if needed
if signal_rank == 1:
result = g.op(
"Squeeze",
result,
g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)),
)
# Normalize, if needed
if normalized:
sqrt_nfft = torch.sqrt(torch.tensor(n_fft, dtype=signal.type().dtype()))
result = g.op("Div", result, g.op("Constant", value_t=sqrt_nfft))
return result
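# Shape sketch (illustrative only, assumed sizes) for the lowering above: a
# [batch, signal] input with n_fft=8, hop_length=2 and center=False gives
# [batch, n_fft // 2 + 1, frames] from torch.stft; the exported graph produces the
# same layout with a trailing real/imag axis of size 2 after the Transpose.
import torch
x = torch.randn(3, 32)
spec = torch.stft(
    x, n_fft=8, hop_length=2, window=torch.ones(8), center=False, return_complex=True
)
# spec.shape: torch.Size([3, 5, 13]) -> 5 = 8 // 2 + 1 bins, 13 = 1 + (32 - 8) // 2 frames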
from torch.onnx._internal.torchscript_exporter.symbolic_opset17 import * # noqa: F401,F403


@ -1,265 +1,8 @@
# mypy: allow-untyped-defs
"""This file exports ONNX ops for opset 18.
"""Backward compatibility module for torch.onnx.symbolic_opset18."""
Note [ONNX Operators that are added/updated in opset 18]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-18-of-the-default-onnx-operator-set
New operators:
BitwiseAnd
CenterCropPad
Col2Im
Mish
OptionalGetElement
OptionalHasElement
Pad
Resize
ScatterElements
ScatterND
Split
"""
import functools
from collections.abc import Sequence
from typing import Optional
import torch
from torch import _C
from torch.onnx import _type_utils, symbolic_helper, symbolic_opset9 as opset9
from torch.onnx._internal import jit_utils, registration
from __future__ import annotations
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py
__all__: list[str] = []
__all__ = [
"col2im",
]
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=18)
@_onnx_symbolic("aten::__and_")
@_onnx_symbolic("aten::bitwise_and")
def __and_(g: jit_utils.GraphContext, self, other):
# do type promotion (scalars don't seem to apply)
args = [self, other]
# type promotion doesn't happen with torch.bitwise_and(tensor, scalar)
prom_args = [arg for arg in args if symbolic_helper._get_tensor_rank(arg)]
if len(prom_args) == 0:
prom_args = args
promotion_jit_type = symbolic_helper._type_promote_from_values(*prom_args)
self = symbolic_helper._maybe_cast_to_type(g, self, promotion_jit_type)
other = symbolic_helper._maybe_cast_to_type(g, other, promotion_jit_type)
if promotion_jit_type == _type_utils.JitScalarType.BOOL:
return g.op("And", self, other)
return g.op("BitwiseAnd", self, other)
@_onnx_symbolic("aten::col2im")
@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is")
def col2im(
g,
input: _C.Value,
output_size: _C.Value,
kernel_size: _C.Value,
dilation: Sequence[int],
padding: Sequence[int],
stride: Sequence[int],
):
# convert [i0, i1, ..., in] into [i0, i0, i1, i1, ..., in, in]
adjusted_padding: list[int] = []
for pad in padding:
adjusted_padding.extend(pad for _ in range(2))
num_dimensional_axis = symbolic_helper._get_tensor_sizes(output_size)[0]
if not adjusted_padding:
adjusted_padding = [0, 0] * num_dimensional_axis
if not dilation:
dilation = [1] * num_dimensional_axis
if not stride:
stride = [1] * num_dimensional_axis
return g.op(
"Col2Im",
input,
output_size,
kernel_size,
dilations_i=dilation,
pads_i=adjusted_padding,
strides_i=stride,
)
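# Small check (illustrative only) of the padding adjustment above: each
# per-dimension pad is duplicated, so [1, 2] becomes [1, 1, 2, 2].
padding_example = [1, 2]
adjusted_example: list[int] = []
for pad in padding_example:
    adjusted_example.extend(pad for _ in range(2))
assert adjusted_example == [1, 1, 2, 2]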
@_onnx_symbolic(
"aten::mean", decorate=[symbolic_helper._apply_params("ReduceMean", "mean")]
)
@_onnx_symbolic(
"aten::prod",
decorate=[
symbolic_helper._apply_params(
"ReduceProd", "prod", allow_multi_dim_support=False
)
],
)
def _reduce_with_dtype(onnx_op: str, name: str, allow_multi_dim_support: bool = True):
return symbolic_helper._reduce_with_dtype_helper(
onnx_op, name, allow_multi_dim_support
)
@_onnx_symbolic("aten::native_layer_norm")
@symbolic_helper.quantized_args(True, False, False, False)
@symbolic_helper.parse_args("v", "is", "v", "v", "f")
def _native_layer_norm(
g: jit_utils.GraphContext,
input: _C.Value,
normalized_shape: Sequence[int],
weight: _C.Value,
bias: _C.Value,
eps: float,
) -> tuple[_C.Value, _C.Value, _C.Value]:
return opset9.native_layer_norm(g, input, normalized_shape, weight, bias, eps)
@_onnx_symbolic("aten::glu")
@symbolic_helper.parse_args("v", "i")
def _glu(g: jit_utils.GraphContext, input, dim):
dim_size = symbolic_helper._get_tensor_dim_size(input, dim)
if dim_size is not None:
assert dim_size % 2 == 0
first, second = g.op("Split", input, axis_i=dim, num_outputs_i=2, outputs=2)
return g.op("Mul", first, g.op("Sigmoid", second))
@_onnx_symbolic("aten::max")
# torch.max (same for torch.min) actually has two interfaces smashed together:
# torch.max(x, dim, keepdim) and torch.max(x, y)
# TODO(justinchuby): Support multiple quantized args in output
def max(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
return symbolic_helper._max_helper(g, self, dim_or_y, keepdim)
@_onnx_symbolic("aten::maximum")
@symbolic_helper.quantized_args(True, True)
def maximum(g: jit_utils.GraphContext, input, other):
return max(g, input, dim_or_y=other)
@_onnx_symbolic("aten::min")
# TODO(justinchuby): Support multiple quantized args in output
def min(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
return symbolic_helper._min_helper(g, self, dim_or_y, keepdim)
@_onnx_symbolic("aten::minimum")
@symbolic_helper.quantized_args(True, True)
def minimum(g: jit_utils.GraphContext, input, other):
return min(g, input, dim_or_y=other)
@_onnx_symbolic("aten::amax")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "is", "i")
def amax(g: jit_utils.GraphContext, self, dim, keepdim):
axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long))
return g.op("ReduceMax", self, axes, keepdims_i=keepdim)
@_onnx_symbolic("aten::amin")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "is", "i")
def amin(g: jit_utils.GraphContext, self, dim, keepdim):
axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long))
return g.op("ReduceMin", self, axes, keepdims_i=keepdim)
@_onnx_symbolic("aten::aminmax")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "v", "i")
def aminmax(g: jit_utils.GraphContext, self, dim, keepdim):
if not symbolic_helper._is_none(dim):
dim = symbolic_helper._get_const(dim, "i", "dim")
axes = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long))
return g.op("ReduceMin", self, axes, keepdims_i=keepdim), g.op(
"ReduceMax", self, axes, keepdims_i=keepdim
)
else:
return g.op("ReduceMin", self, keepdims_i=keepdim), g.op(
"ReduceMax", self, keepdims_i=keepdim
)
@_onnx_symbolic("aten::var_mean")
def _var_mean(g: jit_utils.GraphContext, input, *args):
if len(args) == 1:
return symbolic_helper._var_mean_helper(g, input, None, args[0], None)
else:
return symbolic_helper._var_mean_helper(g, input, *args)
@_onnx_symbolic("aten::logsumexp")
@symbolic_helper.parse_args("v", "is", "i")
def _logsumexp(g: jit_utils.GraphContext, input, dim, keepdim):
if dim is None:
return g.op("ReduceLogSumExp", input, keepdims_i=0)
else:
axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long))
return g.op("ReduceLogSumExp", input, axes, keepdims_i=keepdim)
@_onnx_symbolic("aten::linalg_matrix_norm")
@symbolic_helper.parse_args("v", "v", "is", "b", "v")
def _linalg_matrix_norm(
g: jit_utils.GraphContext,
self: torch._C.Value,
ord: torch._C.Value,
dim: list[int],
keepdim: bool,
dtype: torch._C.Value,
):
return opset9.linalg_matrix_norm(g, self, ord, dim, keepdim, dtype)
@_onnx_symbolic("aten::embedding_bag")
@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i")
def embedding_bag(
g: jit_utils.GraphContext,
embedding_matrix,
indices,
offsets,
scale_grad_by_freq,
mode,
sparse,
per_sample_weights,
include_last_offset,
padding_idx,
):
return symbolic_helper._embedding_bag_helper(
g,
embedding_matrix,
indices,
offsets,
scale_grad_by_freq,
mode,
sparse,
per_sample_weights,
include_last_offset,
padding_idx,
)
@_onnx_symbolic("aten::linalg_vector_norm")
@symbolic_helper.parse_args("v", "f", "is", "b", "v")
def linalg_vector_norm(
g: jit_utils.GraphContext,
self: torch._C.Value,
ord: float,
dim: Optional[Sequence[int]],
keepdim: bool,
dtype: torch._C.Value,
):
return symbolic_helper._linalg_vector_norm_helper(g, self, ord, dim, keepdim, dtype)
from torch.onnx._internal.torchscript_exporter.symbolic_opset18 import * # noqa: F401,F403


@ -1,31 +1,8 @@
"""This file exports ONNX ops for opset 19.
"""Backward compatibility module for torch.onnx.symbolic_opset19."""
Note [ONNX Operators that are added/updated in opset 19]
from __future__ import annotations
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-19-of-the-default-onnx-operator-set
New operators:
AveragePool
Cast
CastLike
Constant
DeformConv
DequantizeLinear
Equal
Identity
If
Loop
Pad
QuantizeLinear
Reshape
Resize
Scan
Shape
Size
"""
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py
__all__: list[str] = []
from torch.onnx._internal.torchscript_exporter.symbolic_opset19 import * # noqa: F401,F403


@ -1,92 +1,8 @@
# mypy: allow-untyped-defs
"""This file exports ONNX ops for opset 20.
"""Backward compatibility module for torch.onnx.symbolic_opset20."""
Note [ONNX Operators that are added/updated in opset 20]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-20-of-the-default-onnx-operator-set
New operators:
AffineGrid
ConstantOfShape
DFT
Gelu
GridSample
ImageDecoder
IsInf
IsNaN
ReduceMax
ReduceMin
RegexFullMatch
StringConcat
StringSplit
"""
import functools
import torch.nn.functional as F
from torch import _C
from torch.onnx import symbolic_helper
from torch.onnx._internal import jit_utils, registration
from __future__ import annotations
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py
__all__: list[str] = []
__all__ = ["_grid_sampler", "_affine_grid_generator", "gelu"]
def convert_grid_sample_mode(mode_s):
return (
"linear" if mode_s == "bilinear" else "cubic" if mode_s == "bicubic" else mode_s
)
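# Quick check (illustrative only) of the renaming above, which follows the mode
# names introduced by GridSample in opset 20.
assert convert_grid_sample_mode("bilinear") == "linear"
assert convert_grid_sample_mode("bicubic") == "cubic"
assert convert_grid_sample_mode("nearest") == "nearest"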
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=20)
@_onnx_symbolic("aten::grid_sampler")
@symbolic_helper.parse_args("v", "v", "i", "i", "b")
def _grid_sampler(
g: jit_utils.GraphContext,
input: _C.Value,
grid: _C.Value,
mode_enum: int,
padding_mode_enum: int,
align_corners: bool,
):
mode_s = {v: k for k, v in F.GRID_SAMPLE_INTERPOLATION_MODES.items()}[mode_enum] # type: ignore[call-arg, index]
# mode string changes at https://onnx.ai/onnx/operators/text_diff_GridSample_16_20.html
mode_s = convert_grid_sample_mode(mode_s)
padding_mode_s = {v: k for k, v in F.GRID_SAMPLE_PADDING_MODES.items()}[ # type: ignore[call-arg, index]
padding_mode_enum # type: ignore[index]
]
return g.op(
"GridSample",
input,
grid,
align_corners_i=int(align_corners),
mode_s=mode_s,
padding_mode_s=padding_mode_s,
)
@_onnx_symbolic("aten::affine_grid_generator")
@symbolic_helper.parse_args("v", "v", "b")
def _affine_grid_generator(
g: jit_utils.GraphContext,
theta: _C.Value,
size: _C.Value,
align_corners: bool,
):
return g.op(
"AffineGrid",
theta,
size,
align_corners_i=int(align_corners),
)
@_onnx_symbolic("aten::gelu")
@symbolic_helper.parse_args("v", "s")
def gelu(g: jit_utils.GraphContext, self: _C.Value, approximate: str = "none"):
return g.op("Gelu", self, approximate_s=approximate)
from torch.onnx._internal.torchscript_exporter.symbolic_opset20 import * # noqa: F401,F403


@ -1,67 +1,8 @@
# mypy: allow-untyped-defs
"""
Note [ONNX operators that are added/updated from opset 7 to opset 8]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
New operators:
Expand
"""Backward compatibility module for torch.onnx.symbolic_opset7."""
Updated operators:
Min, Max, Sum, Mean: supports multidirectional broadcasting.
MaxPool: added optional indices output.
Scan
"""
import functools
import warnings
from torch.onnx import symbolic_helper, symbolic_opset9 as opset9
from torch.onnx._internal import jit_utils, registration
from __future__ import annotations
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=7)
__all__: list[str] = []
block_listed_operators = (
"scan",
"expand",
"expand_as",
"meshgrid",
"adaptive_max_pool1d",
"adaptive_max_pool2d",
"adaptive_max_pool3d",
"max_pool1d_with_indices",
"max_pool2d_with_indices",
"max_pool3d_with_indices",
)
# NOTE: max, min, sum, mean: broadcasting is not supported in opset 7.
# torch.max (same for torch.min) actually has two interfaces smashed together:
# torch.max(x, dim, keepdim) and torch.max(x, y)
@_onnx_symbolic("aten::max")
def max(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
# torch.max(input, other)
if keepdim is None and dim_or_y is not None:
warnings.warn(
"Multidirectional broadcasting is not supported in opset 7. "
"This might cause the onnx model to be incorrect, if inputs to max operators "
"have different shapes"
)
return opset9.max(g, self, dim_or_y, keepdim)
@_onnx_symbolic("aten::min")
def min(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
# torch.min(input, other)
if keepdim is None and dim_or_y is not None:
warnings.warn(
"Multidirectional broadcasting is not supported in opset 7. "
"This might cause the onnx model to be incorrect, if inputs to min operators "
"have different shapes"
)
return opset9.min(g, self, dim_or_y, keepdim)
for block_listed_op in block_listed_operators:
_onnx_symbolic(f"aten::{block_listed_op}")(
symbolic_helper._block_list_in_opset(block_listed_op)
)
from torch.onnx._internal.torchscript_exporter.symbolic_opset7 import * # noqa: F401,F403


@ -1,463 +1,8 @@
# mypy: allow-untyped-defs
"""
Note [ONNX operators that are added/updated from opset 8 to opset 9]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
New operators:
Compress
ConstantOfShape
EyeLike
MaxUnpool
OneHot
Sinh
Cosh
Asinh
Acosh
Atanh
Shrink
IsNaN
Sign
Erf
Scatter
Where
NonZero
TfIdfVectorizer
MeanVarianceNormalization
"""Backward compatibility module for torch.onnx.symbolic_opset8."""
Updated operators:
BatchNormalization: removed spatial attribute.
Greater, Less, Constant, MatMul, PRelu, Gemm, Flatten: more data types (integers) supported.
Cast: more data types (string) supported.
Upsample: moved scales from attribute to input.
Scan
"""
import functools
import warnings
import torch
from torch._C import _onnx as _C_onnx
from torch.onnx import _type_utils, errors, symbolic_helper, symbolic_opset9 as opset9
from torch.onnx._internal import jit_utils, registration
from __future__ import annotations
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=8)
__all__: list[str] = []
block_listed_operators = (
"nonzero",
"where",
"scatter",
"scatter_add",
"erf",
"sign",
"isnan",
"gather",
"arange",
"masked_fill",
"index_fill",
"index_copy",
"repeat_interleave",
"any",
"all",
)
for block_listed_op in block_listed_operators:
_onnx_symbolic(f"aten::{block_listed_op}")(
symbolic_helper._block_list_in_opset(block_listed_op)
)
@_onnx_symbolic(
"aten::upsample_nearest1d",
decorate=[symbolic_helper._apply_params("upsample_nearest1d", 3, "nearest")],
)
@_onnx_symbolic(
"aten::upsample_nearest2d",
decorate=[symbolic_helper._apply_params("upsample_nearest2d", 4, "nearest")],
)
@_onnx_symbolic(
"aten::upsample_nearest3d",
decorate=[symbolic_helper._apply_params("upsample_nearest3d", 5, "nearest")],
)
@_onnx_symbolic(
"aten::upsample_linear1d",
decorate=[symbolic_helper._apply_params("upsample_linear1d", 3, "linear")],
)
@_onnx_symbolic(
"aten::upsample_bilinear2d",
decorate=[symbolic_helper._apply_params("upsample_bilinear2d", 4, "linear")],
)
@_onnx_symbolic(
"aten::upsample_trilinear3d",
decorate=[symbolic_helper._apply_params("upsample_trilinear3d", 5, "linear")],
)
def _interpolate(name, dim, interpolate_mode):
def symbolic_fn(g, input, output_size, *args):
scales, align_corners = symbolic_helper._get_interpolate_attributes(
g, interpolate_mode, args
)
symbolic_helper._interpolate_warning(interpolate_mode)
align_corners = symbolic_helper._maybe_get_scalar(align_corners)
if align_corners:
return symbolic_helper._unimplemented(name, "align_corners == True", input)
output_size = symbolic_helper._maybe_get_const(output_size, "is")
if symbolic_helper._is_value(output_size):
return symbolic_helper._unimplemented(
name, "torch._C.Value (output_size) indexing"
)
if scales is None:
scales = [
1.0
if i < 2
else float(output_size[-(dim - i)])
/ float(input.type().sizes()[-(dim - i)])
for i in range(0, dim)
]
return g.op("Upsample", input, mode_s=interpolate_mode, scales_f=scales)
return symbolic_fn
@_onnx_symbolic("aten::__interpolate")
def __interpolate(
g: jit_utils.GraphContext,
input,
size,
scale_factor,
mode,
align_corners,
recompute_scale_factor,
antialias,
):
align_corners = symbolic_helper._maybe_get_const(align_corners, "b")
if not symbolic_helper._is_none(align_corners) and align_corners:
return symbolic_helper._unimplemented("interpolate", "align_corners == True")
if not symbolic_helper._is_none(scale_factor) and symbolic_helper._is_value(
scale_factor
):
return symbolic_helper._unimplemented(
"interpolate", "dynamic scales in opset 8"
)
if not symbolic_helper._is_none(size) and symbolic_helper._is_value(size):
return symbolic_helper._unimplemented("interpolate", "dynamic size in opset 8")
scales, mode = symbolic_helper._interpolate_get_scales_and_mode(
g, input, size, scale_factor, mode, align_corners
)
return g.op("Upsample", input, mode_s=mode, scales_f=scales)
# NOTE: We should create a wrapper for this kind of operation, after resolving the shape/type propagation
# issue for "cast" operators. Some symbolic functions depend on shape information of input tensor, which
# is lost after casting.
def _try_cast_integer_to_float(g: jit_utils.GraphContext, *args):
floating_scalar_types = {
_type_utils.JitScalarType.HALF,
_type_utils.JitScalarType.FLOAT,
_type_utils.JitScalarType.DOUBLE,
}
old_type = None
# Cast the input tensor to Float if its scalarType is known and is not floating number.
# If casting is performed, return the old scalarType, otherwise return None.
arg0_type = _type_utils.JitScalarType.from_value(
args[0], _type_utils.JitScalarType.UNDEFINED
)
if arg0_type != _type_utils.JitScalarType.UNDEFINED:
old_type = arg0_type
if old_type not in floating_scalar_types:
old_type = old_type.scalar_name() # type: ignore[assignment]
args = tuple(
g.op("Cast", arg, to_i=_C_onnx.TensorProtoDataType.FLOAT)
for arg in args
)
else:
return (None,) + args
else:
warnings.warn(
"Only floating datatype is supported for these operators: "
"{Greater, Less, MatMul, PRelu, Gemm, Flatten}. This might cause "
"the onnx model to be incorrect, if inputs have integer datatypes."
)
return (old_type,) + args
def _cast_to_type(g: jit_utils.GraphContext, input, to_type):
if to_type is None:
return input
return getattr(opset9, f"_cast_{to_type}")(g, input, False)
def _comparison_operator(g: jit_utils.GraphContext, input, other, op_name):
other = symbolic_helper._maybe_get_scalar(other)
other = symbolic_helper._if_scalar_type_as(other, input)
_, input, other = _try_cast_integer_to_float(g, input, other)
return g.op(op_name, input, other)
# NOTE: For symbolics {gt, lt, bmm, matmul, prelu, mm, addmm, view, flatten},
# integer input type not supported in opset8. Cast to float if possible.
@_onnx_symbolic("aten::gt")
def gt(g: jit_utils.GraphContext, input, other):
return _comparison_operator(g, input, other, "Greater")
@_onnx_symbolic("aten::lt")
def lt(g: jit_utils.GraphContext, input, other):
return _comparison_operator(g, input, other, "Less")
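# Eager-mode sketch (illustrative only) of the workaround above: opset-8
# Greater/Less only handle floating-point inputs, so integer tensors are cast to
# float first; the comparison result is unchanged.
import torch
a = torch.tensor([1, 5, 3])
b = torch.tensor([2, 2, 3])
assert torch.equal(a.float() > b.float(), a > b)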
@_onnx_symbolic("aten::bmm")
def bmm(g: jit_utils.GraphContext, self, other):
if symbolic_helper._try_get_scalar_type(self):
old_type, self, other = _try_cast_integer_to_float(g, self, other)
return _cast_to_type(g, g.op("MatMul", self, other), old_type)
else:
return g.op("MatMul", self, other)
@_onnx_symbolic("aten::matmul")
def matmul(g: jit_utils.GraphContext, self, other):
return bmm(g, self, other)
@_onnx_symbolic("aten::prelu")
def prelu(g: jit_utils.GraphContext, self, weight):
self_rank = symbolic_helper._get_tensor_rank(self)
weight_sizes = symbolic_helper._get_tensor_sizes(weight)
if self_rank is not None and self_rank > 2:
weight = g.op("Unsqueeze", weight, axes_i=list(range(1, self_rank - 1)))
elif self_rank == 0 and weight_sizes == [1]:
# self and weight are both scalar but weight has rank == 1, squeeze weight.
weight = symbolic_helper._squeeze_helper(g, weight, [0])
if symbolic_helper._try_get_scalar_type(self):
old_type, self, weight = _try_cast_integer_to_float(g, self, weight)
return _cast_to_type(g, g.op("PRelu", self, weight), old_type)
else:
return g.op("PRelu", self, weight)
@_onnx_symbolic("aten::mm")
def mm(g: jit_utils.GraphContext, self, other):
# Create a dummy C tensor. Only needed for API purposes; its value is
# irrelevant since beta = 0.
scalar_type = symbolic_helper._try_get_scalar_type(self, other)
if scalar_type is None:
raise errors.SymbolicValueError(
"mm can only operate on tensors with known types", self
)
zero_constant = g.op(
"Constant",
value_t=torch.tensor([0], dtype=scalar_type.dtype()),
)
if symbolic_helper._try_get_scalar_type(self):
old_type, self, other, zero_constant = _try_cast_integer_to_float(
g, self, other, zero_constant
)
return _cast_to_type(
g,
g.op("Gemm", self, other, zero_constant, beta_f=0.0, alpha_f=1.0),
old_type,
)
return g.op("Gemm", self, other, zero_constant, beta_f=0.0, alpha_f=1.0)
@_onnx_symbolic("aten::addmm")
@symbolic_helper.parse_args("v", "v", "v", "t", "t")
def addmm(g: jit_utils.GraphContext, self, mat1, mat2, beta, alpha):
if symbolic_helper._try_get_scalar_type(self):
old_type, self, mat1, mat2 = _try_cast_integer_to_float(g, self, mat1, mat2)
return _cast_to_type(
g,
g.op(
"Gemm",
mat1,
mat2,
self,
beta_f=symbolic_helper._scalar(beta),
alpha_f=symbolic_helper._scalar(alpha),
),
old_type,
)
else:
return g.op(
"Gemm",
mat1,
mat2,
self,
beta_f=symbolic_helper._scalar(beta),
alpha_f=symbolic_helper._scalar(alpha),
)
@_onnx_symbolic("aten::flatten")
def flatten(g: jit_utils.GraphContext, input, start_dim, end_dim):
start_dim_i = symbolic_helper._get_const(start_dim, "i", "start_dim")
end_dim_i = symbolic_helper._get_const(end_dim, "i", "end_dim")
dim = input.type().dim()
if end_dim_i < 0:
end_dim_i = dim + end_dim_i
# use ONNX's Flatten operator for cases where the output shape is 2D
if start_dim_i == 1 and end_dim_i == dim - 1:
if symbolic_helper._try_get_scalar_type(input):
old_type, input = _try_cast_integer_to_float(g, input)
return _cast_to_type(
g, g.op("Flatten", input, axis_i=start_dim_i), old_type
)
else:
return g.op("Flatten", input, axis_i=start_dim_i)
if start_dim_i == 0 and end_dim_i == dim - 2:
if symbolic_helper._try_get_scalar_type(input):
old_type, input = _try_cast_integer_to_float(g, input)
return _cast_to_type(
g, g.op("Flatten", input, axis_i=end_dim_i + 1), old_type
)
else:
return g.op("Flatten", input, axis_i=end_dim_i + 1)
return opset9.flatten(g, input, start_dim, end_dim)
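# Eager-mode check (illustrative only) of the two 2-D special cases above, the
# shapes that map onto a single ONNX Flatten node.
import torch
x = torch.randn(2, 3, 4, 5)
assert torch.flatten(x, 1, 3).shape == (2, 60)  # start_dim=1 -> Flatten(axis_i=1)
assert torch.flatten(x, 0, 2).shape == (24, 5)  # end_dim=dim-2 -> Flatten(axis_i=3)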
def _constant_fill(g: jit_utils.GraphContext, sizes, dtype: int, const_value):
if dtype is None:
scalar_type = _type_utils.JitScalarType.FLOAT
else:
scalar_type = _type_utils.JitScalarType(dtype)
if not scalar_type.dtype().is_floating_point:
result = g.op(
"ConstantFill",
sizes,
dtype_i=_type_utils.JitScalarType.FLOAT.onnx_type(),
input_as_shape_i=1,
value_f=const_value,
)
return g.op("Cast", result, to_i=scalar_type.onnx_type())
else:
return g.op(
"ConstantFill",
sizes,
dtype_i=scalar_type.onnx_type(),
input_as_shape_i=1,
value_f=const_value,
)
@_onnx_symbolic("aten::empty")
@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
def empty(
g: jit_utils.GraphContext,
sizes,
dtype,
layout,
device,
pin_memory=False,
memory_format=None,
):
return zeros(g, sizes, dtype, layout, device, pin_memory)
@_onnx_symbolic("aten::empty_like")
@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
def empty_like(
g: jit_utils.GraphContext,
input,
dtype,
layout,
device,
pin_memory=False,
memory_format=None,
):
return zeros_like(g, input, dtype, layout, device, pin_memory)
@_onnx_symbolic("aten::zeros")
@symbolic_helper.parse_args("v", "i", "v", "v", "v")
def zeros(g: jit_utils.GraphContext, sizes, dtype, layout, device, pin_memory=False):
# NOTE: no way to set device and layout in ONNX, so we ignore it
return _constant_fill(g, sizes, dtype, 0)
@_onnx_symbolic("aten::zeros_like")
@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
def zeros_like(
g: jit_utils.GraphContext,
input,
dtype,
layout,
device,
pin_memory=False,
memory_format=None,
):
shape = g.op("Shape", input)
return _constant_fill(g, shape, dtype, 0)
@_onnx_symbolic("aten::ones")
@symbolic_helper.parse_args("v", "i", "v", "v", "v")
def ones(g: jit_utils.GraphContext, sizes, dtype, layout, device, pin_memory=False):
return _constant_fill(g, sizes, dtype, 1)
@_onnx_symbolic("aten::ones_like")
@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
def ones_like(
g: jit_utils.GraphContext,
input,
dtype,
layout,
device,
pin_memory=False,
memory_format=None,
):
shape = g.op("Shape", input)
return _constant_fill(g, shape, dtype, 1)
@_onnx_symbolic("aten::full")
def full(
g: jit_utils.GraphContext, sizes, value, dtype, layout, device, pin_memory=False
):
const_value = symbolic_helper._maybe_get_const(value, "t")
if symbolic_helper._is_value(const_value):
tmp = zeros(g, sizes, dtype, layout, device)
return opset9.add(g, tmp, value, g.op("Constant", value_t=torch.tensor(1)))
else:
dtype = symbolic_helper._get_const(dtype, "i", "dtype")
return _constant_fill(g, sizes, dtype, const_value)
@_onnx_symbolic("aten::full_like")
@symbolic_helper.parse_args("v", "f", "i", "v", "v", "v", "v")
def full_like(
g: jit_utils.GraphContext,
input,
fill_value,
dtype,
layout,
device,
pin_memory=False,
memory_format=None,
):
shape = g.op("Shape", input)
return _constant_fill(g, shape, dtype, fill_value)
@_onnx_symbolic("aten::repeat")
def repeat(g: jit_utils.GraphContext, self, repeats):
if not symbolic_helper._is_value(repeats):
repeats = g.op("Constant", value_t=torch.LongTensor(repeats))
if symbolic_helper._is_packed_list(repeats):
repeat_size_len = len(symbolic_helper._unpack_list(repeats))
else:
const_repeats = symbolic_helper._maybe_get_const(repeats, "is")
repeat_size_len = len(const_repeats)
if self.isCompleteTensor():
sizes = self.type().sizes()
diff_dims = repeat_size_len - len(sizes)
if diff_dims > 0:
self = opset9.view(
g, self, g.op("Constant", value_t=torch.tensor([1] * diff_dims + sizes))
)
return g.op("Tile", self, repeats)
from torch.onnx._internal.torchscript_exporter.symbolic_opset8 import * # noqa: F401,F403

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large