Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[ONNX] Refactor torchscript based exporter (#161323)
Refactor the TorchScript-based exporter logic, moving it to a single (private) location for better code management. The original public module and method APIs are preserved.

- Updated module paths in `torch/csrc/autograd/python_function.cpp` accordingly
- Removed `check_onnx_broadcast` from `torch/autograd/_functions/utils.py` because it is private and unused

@albanD / @soulitzer could you review the changes in `torch/csrc/autograd/python_function.cpp` and `torch/autograd/_functions/utils.py`? Thanks!

## BC Breaking

- **Deprecated members in `torch.onnx.verification` are removed**

Differential Revision: [D81236421](https://our.internmc.facebook.com/intern/diff/D81236421)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161323

Approved by: https://github.com/titaiwangms, https://github.com/angelayi
Committed by: PyTorch MergeBot
Parent: 793fc12aff
Commit: 524b78d4f6
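The refactor is mechanical from a caller's perspective: public `torch.onnx` entry points keep their names, while internal modules move under `torch.onnx._internal.torchscript_exporter`. A minimal before/after sketch of the import change, using module names that appear in the hunks below (the snippet itself is illustrative and not part of the commit; `dynamo=False` is assumed to select the legacy TorchScript exporter):

```python
# Old internal layout (before this commit):
#   from torch.onnx._internal import jit_utils, registration
#   from torch.onnx._globals import GLOBALS
# New private home for the TorchScript exporter (after this commit):
from torch.onnx._internal.torchscript_exporter import jit_utils, registration
from torch.onnx._internal.torchscript_exporter._globals import GLOBALS

import torch

# Public APIs are unchanged and still reach the legacy exporter:
model = torch.nn.Linear(4, 2)
torch.onnx.export(model, (torch.randn(1, 4),), "linear.onnx", dynamo=False)
```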
@ -113,7 +113,6 @@ also be interested in reading our [development wiki](https://github.com/pytorch/
.. autofunction:: register_custom_op_symbolic
.. autofunction:: unregister_custom_op_symbolic
.. autofunction:: select_model_mode_for_export
.. autoclass:: JitScalarType
```

```{eval-rst}
@ -1,4 +1,5 @@
# torch.onnx.verification

```{eval-rst}
.. automodule:: torch.onnx.verification
```
@ -11,23 +12,3 @@
.. autoclass:: VerificationInfo
    :members:
```

```{eval-rst}
.. autofunction:: verify
```

## Deprecated

The following classes and functions are deprecated.

<!-- Some deprecated members are not publicly shown -->
```{eval-rst}
.. py:class:: check_export_model_diff
.. py:class:: GraphInfo
.. py:class:: GraphInfoPrettyPrinter
.. py:class:: OnnxBackend
.. py:class:: OnnxTestCaseRepro
.. py:class:: VerificationOptions
.. py:function:: find_mismatch
.. py:function:: verify_aten_graph
```
@ -4,7 +4,7 @@
|
||||
from collections.abc import Sequence
|
||||
|
||||
from torch.onnx import errors
|
||||
from torch.onnx._internal import registration
|
||||
from torch.onnx._internal.torchscript_exporter import registration
|
||||
from torch.testing._internal import common_utils
|
||||
|
||||
|
||||
|
@ -17,7 +17,8 @@ import pytorch_test_common
|
||||
|
||||
import torch
|
||||
from torch import export as torch_export
|
||||
from torch.onnx import _constants, verification
|
||||
from torch.onnx import _constants
|
||||
from torch.onnx._internal.torchscript_exporter import verification
|
||||
from torch.testing._internal import common_utils
|
||||
from torch.testing._internal.opinfo import core as opinfo_core
|
||||
from torch.types import Number
|
||||
|
@ -5,13 +5,11 @@ from onnx_test_common import run_model_test
|
||||
|
||||
import torch
|
||||
from torch.onnx import OperatorExportTypes
|
||||
from torch.onnx._globals import GLOBALS
|
||||
from torch.onnx.utils import _model_to_graph
|
||||
from torch.testing._internal import common_utils
|
||||
|
||||
|
||||
class TestAutogradFuns(pytorch_test_common.ExportTestCase):
|
||||
opset_version = GLOBALS.export_onnx_opset_version
|
||||
opset_version = 20
|
||||
keep_initializers_as_inputs = False
|
||||
onnx_shape_inference = True
|
||||
|
||||
@ -133,7 +131,7 @@ class TestAutogradFuns(pytorch_test_common.ExportTestCase):
|
||||
input = torch.ones(1, 5)
|
||||
|
||||
# Test ONNX_FALLTHROUGH_MODE
|
||||
graph, _, _ = _model_to_graph(
|
||||
graph, _, _ = torch.onnx.utils._model_to_graph(
|
||||
model,
|
||||
(input,),
|
||||
operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
|
||||
@ -142,7 +140,7 @@ class TestAutogradFuns(pytorch_test_common.ExportTestCase):
|
||||
self.assertEqual(next(iter).kind(), "prim::PythonOp")
|
||||
|
||||
# Test ATEN_FALLBACK_MODE
|
||||
graph, _, _ = _model_to_graph(
|
||||
graph, _, _ = torch.onnx.utils._model_to_graph(
|
||||
model,
|
||||
(input,),
|
||||
operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK,
|
||||
|
@ -11,7 +11,7 @@ import torch
|
||||
import torch.onnx
|
||||
from torch.nn import Module
|
||||
from torch.onnx import producer_name, producer_version
|
||||
from torch.onnx._globals import GLOBALS
|
||||
from torch.onnx._internal.torchscript_exporter._globals import GLOBALS
|
||||
from torch.testing._internal import common_utils
|
||||
|
||||
|
||||
|
@ -10,7 +10,7 @@ import onnxscript
|
||||
from onnxscript.onnx_types import FLOAT
|
||||
|
||||
import torch
|
||||
from torch.onnx._internal import jit_utils
|
||||
from torch.onnx._internal.torchscript_exporter import jit_utils
|
||||
from torch.testing._internal import common_utils
|
||||
|
||||
|
||||
|
@ -9,7 +9,7 @@ import onnxscript
|
||||
from onnxscript.onnx_types import FLOAT
|
||||
|
||||
import torch
|
||||
from torch.onnx._internal import jit_utils
|
||||
from torch.onnx._internal.torchscript_exporter import jit_utils
|
||||
from torch.testing._internal import common_utils
|
||||
|
||||
|
||||
|
@ -4,8 +4,11 @@ import pytorch_test_common
|
||||
from pytorch_test_common import skipIfNoCuda
|
||||
|
||||
import torch
|
||||
from torch.onnx import verification
|
||||
from torch.onnx._globals import GLOBALS
|
||||
from torch.onnx._internal.torchscript_exporter import verification
|
||||
from torch.onnx._internal.torchscript_exporter._globals import GLOBALS
|
||||
from torch.onnx._internal.torchscript_exporter.utils import (
|
||||
_trigger_symbolic_function_registration,
|
||||
)
|
||||
from torch.testing._internal import common_utils
|
||||
|
||||
|
||||
@ -20,6 +23,7 @@ def _jit_graph_to_onnx_model(graph, operator_export_type, opset_version):
|
||||
"""
|
||||
|
||||
GLOBALS.export_onnx_opset_version = opset_version
|
||||
_trigger_symbolic_function_registration()
|
||||
graph = torch.onnx.utils._optimize_graph(
|
||||
graph, operator_export_type, params_dict={}
|
||||
)
|
||||
|
@ -22,7 +22,7 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
from torch.onnx import symbolic_helper, utils
|
||||
from torch.onnx._internal import registration
|
||||
from torch.onnx._internal.torchscript_exporter import registration
|
||||
from torch.testing._internal import common_quantization, common_utils, jit_utils
|
||||
|
||||
|
||||
@ -430,9 +430,8 @@ class TestONNXExport(pytorch_test_common.ExportTestCase):
|
||||
torch.randn(3, 4, requires_grad=True),
|
||||
mocks=[
|
||||
unittest.mock.patch(
|
||||
"torch.onnx._internal.registration.registry.get_function_group",
|
||||
"torch.onnx._internal.torchscript_exporter.registration.registry.get_function_group",
|
||||
side_effect=break_is_registered_op_api,
|
||||
# wraps=registration.registry.get_function_group
|
||||
)
|
||||
],
|
||||
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
|
||||
|
@ -41,7 +41,9 @@ from pytorch_test_common import (
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.nn.utils import rnn as rnn_utils
|
||||
from torch.onnx import errors, verification
|
||||
from torch.onnx import errors
|
||||
from torch.onnx._internal.torchscript_exporter import verification
|
||||
from torch.onnx._internal.torchscript_exporter._type_utils import JitScalarType
|
||||
from torch.testing._internal import common_utils
|
||||
from torch.testing._internal.common_utils import skipIfNoLapack
|
||||
|
||||
@ -13705,7 +13707,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
|
||||
input_names=["x"],
|
||||
)
|
||||
exported = onnx.load_from_string(f.getvalue())
|
||||
expected_elem_type = torch.onnx.JitScalarType.from_value(x).onnx_type()
|
||||
expected_elem_type = JitScalarType.from_value(x).onnx_type()
|
||||
expected_output_type = onnx.helper.make_optional_type_proto(
|
||||
onnx.helper.make_tensor_type_proto(expected_elem_type, (dynamic_axis_name,))
|
||||
)
|
||||
|
@ -10,8 +10,8 @@ from pytorch_test_common import skipIfUnsupportedMinOpsetVersion
|
||||
|
||||
import torch
|
||||
from torch.onnx import _constants, utils
|
||||
from torch.onnx._globals import GLOBALS
|
||||
from torch.onnx._internal import jit_utils
|
||||
from torch.onnx._internal.torchscript_exporter import jit_utils
|
||||
from torch.onnx._internal.torchscript_exporter._globals import GLOBALS
|
||||
from torch.testing._internal import common_utils
|
||||
|
||||
|
||||
|
@ -3,7 +3,7 @@
|
||||
|
||||
import torch
|
||||
from torch.onnx import symbolic_helper
|
||||
from torch.onnx._globals import GLOBALS
|
||||
from torch.onnx._internal.torchscript_exporter._globals import GLOBALS
|
||||
from torch.testing._internal import common_utils
|
||||
|
||||
|
||||
|
@ -1,11 +1,9 @@
|
||||
# Owner(s): ["module: onnx"]
|
||||
|
||||
import copy
|
||||
import functools
|
||||
import io
|
||||
import re
|
||||
import warnings
|
||||
from typing import Callable
|
||||
|
||||
import onnx
|
||||
|
||||
@ -23,7 +21,7 @@ import torch
|
||||
import torch.onnx
|
||||
import torch.utils.cpp_extension
|
||||
from torch.onnx import _constants, OperatorExportTypes, TrainingMode, utils
|
||||
from torch.onnx._globals import GLOBALS
|
||||
from torch.onnx._internal.torchscript_exporter._globals import GLOBALS
|
||||
from torch.onnx.symbolic_helper import _unpack_list, parse_args
|
||||
from torch.testing._internal import common_utils
|
||||
from torch.testing._internal.common_utils import skipIfNoLapack
|
||||
@ -86,86 +84,6 @@ class _BaseTestCase(pytorch_test_common.ExportTestCase):
|
||||
return graph, params_dict, torch_out
|
||||
|
||||
|
||||
@common_utils.instantiate_parametrized_tests
|
||||
class TestUnconvertibleOps(pytorch_test_common.ExportTestCase):
|
||||
"""Unit tests for the `unconvertible_ops` function."""
|
||||
|
||||
def setUp(self):
|
||||
class EinsumModule(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return torch.einsum("ii", x)
|
||||
|
||||
self.einsum_module = EinsumModule()
|
||||
|
||||
def test_it_returns_graph_and_unconvertible_ops_at_lower_opset_version(self):
|
||||
x = torch.randn(4, 4)
|
||||
|
||||
# Einsum is supported since opset 12. It should be unconvertible at opset 9.
|
||||
graph, unconvertible_ops = utils.unconvertible_ops(
|
||||
self.einsum_module, (x,), opset_version=9
|
||||
)
|
||||
nodes = graph.nodes()
|
||||
self.assertEqual(next(nodes).kind(), "prim::Constant")
|
||||
self.assertEqual(next(nodes).kind(), "prim::ListConstruct")
|
||||
self.assertEqual(next(nodes).kind(), "prim::Constant")
|
||||
self.assertEqual(next(nodes).kind(), "aten::einsum")
|
||||
self.assertEqual(unconvertible_ops, ["aten::einsum"])
|
||||
|
||||
@common_utils.parametrize(
|
||||
"jit_function",
|
||||
[
|
||||
common_utils.subtest(
|
||||
functools.partial(torch.jit.trace, example_inputs=torch.randn(4, 4)),
|
||||
name="traced",
|
||||
),
|
||||
common_utils.subtest(torch.jit.script, name="scripted"),
|
||||
],
|
||||
)
|
||||
def test_it_returns_unconvertible_ops_at_lower_opset_version_for_jit_module(
|
||||
self, jit_function: Callable
|
||||
):
|
||||
module = jit_function(self.einsum_module)
|
||||
x = torch.randn(4, 4)
|
||||
|
||||
# Einsum is supported since opset 12. It should be unconvertible at opset 9.
|
||||
_, unconvertible_ops = utils.unconvertible_ops(module, (x,), opset_version=9)
|
||||
self.assertEqual(unconvertible_ops, ["aten::einsum"])
|
||||
|
||||
@common_utils.parametrize(
|
||||
"jit_function",
|
||||
[
|
||||
common_utils.subtest(lambda x: x, name="nn_module"),
|
||||
common_utils.subtest(
|
||||
functools.partial(torch.jit.trace, example_inputs=torch.randn(4, 4)),
|
||||
name="traced",
|
||||
),
|
||||
common_utils.subtest(torch.jit.script, name="scripted"),
|
||||
],
|
||||
)
|
||||
def test_it_returns_empty_list_when_all_ops_convertible(
|
||||
self, jit_function: Callable
|
||||
):
|
||||
module = jit_function(self.einsum_module)
|
||||
x = torch.randn(4, 4)
|
||||
|
||||
# Einsum is supported since opset 12
|
||||
_, unconvertible_ops = utils.unconvertible_ops(module, (x,), opset_version=12)
|
||||
self.assertEqual(unconvertible_ops, [])
|
||||
|
||||
def test_it_returns_empty_list_when_model_contains_supported_inplace_ops(self):
|
||||
class SkipConnectionModule(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
out = x
|
||||
out += x
|
||||
out = torch.nn.functional.relu(out, inplace=True)
|
||||
return out
|
||||
|
||||
module = SkipConnectionModule()
|
||||
x = torch.randn(4, 4)
|
||||
_, unconvertible_ops = utils.unconvertible_ops(module, (x,), opset_version=13)
|
||||
self.assertEqual(unconvertible_ops, [])
|
||||
|
||||
|
||||
@parameterized.parameterized_class(
|
||||
[
|
||||
{"opset_version": opset}
|
||||
|
@ -13,7 +13,8 @@ import pytorch_test_common
|
||||
from packaging import version
|
||||
|
||||
import torch
|
||||
from torch.onnx import _constants, _experimental, verification
|
||||
from torch.onnx import _constants
|
||||
from torch.onnx._internal.torchscript_exporter import _experimental, verification
|
||||
from torch.testing._internal import common_utils
|
||||
|
||||
|
||||
|
@ -21,8 +21,6 @@ import torch.utils.cpp_extension
|
||||
import torch.utils.data
|
||||
from torch._utils import try_import
|
||||
from torch._utils_internal import deprecated
|
||||
from torch.autograd._functions.utils import check_onnx_broadcast
|
||||
from torch.onnx.symbolic_opset9 import _prepare_onnx_paddings
|
||||
from torch.testing._internal.common_cuda import TEST_MULTIGPU
|
||||
from torch.testing._internal.common_device_type import (
|
||||
instantiate_device_type_tests,
|
||||
@ -790,65 +788,6 @@ class TestCollectEnv(TestCase):
|
||||
self.assertTrue(info_output.count("\n") >= 17)
|
||||
|
||||
|
||||
class TestONNXUtils(TestCase):
|
||||
def test_prepare_onnx_paddings(self):
|
||||
sizes = [2, 3, 4]
|
||||
pad = [1, 2, 3, 4]
|
||||
paddings = _prepare_onnx_paddings(len(sizes), pad)
|
||||
self.assertEqual(paddings, [0, 3, 1, 0, 4, 2])
|
||||
|
||||
def test_check_onnx_broadcast(self):
|
||||
def try_check_onnx_broadcast(dims1, dims2, expect_broadcast, expect_fail):
|
||||
broadcast = True
|
||||
fail = False
|
||||
try:
|
||||
broadcast = check_onnx_broadcast(dims1, dims2)
|
||||
except ValueError:
|
||||
fail = True
|
||||
self.assertEqual(broadcast, expect_broadcast)
|
||||
self.assertEqual(fail, expect_fail)
|
||||
|
||||
# Case 1, check the case when len(dims1) < len(dims2) and numel(dims2) > 1
|
||||
dims1 = [3, 4]
|
||||
dims2 = [2, 3, 4]
|
||||
try_check_onnx_broadcast(dims1, dims2, True, True)
|
||||
|
||||
# Case 2, check the case when len(dims1) < len(dims2) and numel(dims2) == 1
|
||||
dims1 = [3, 4]
|
||||
dims2 = [1, 1, 1]
|
||||
try_check_onnx_broadcast(dims1, dims2, True, False)
|
||||
|
||||
# Case 3, check the case when len(dims1) > len(dims2) and numel(dims2) == 1
|
||||
dims1 = [1, 1]
|
||||
dims2 = [1]
|
||||
try_check_onnx_broadcast(dims1, dims2, True, False)
|
||||
|
||||
# Case 4, check the case when len(dims1) > len(dims2) and dims1[x:] == dims2
|
||||
dims1 = [2, 3, 4]
|
||||
dims2 = [3, 4]
|
||||
try_check_onnx_broadcast(dims1, dims2, True, False)
|
||||
|
||||
# Case 5, check the case when len(dims1) > len(dims2), but dims1[x:] != dims2
|
||||
dims1 = [2, 3, 4]
|
||||
dims2 = [1, 4]
|
||||
try_check_onnx_broadcast(dims1, dims2, True, True)
|
||||
|
||||
# Case 6, check the equal case, no broadcast
|
||||
dims1 = [3, 4]
|
||||
dims2 = [3, 4]
|
||||
try_check_onnx_broadcast(dims1, dims2, False, False)
|
||||
|
||||
# Case 7, check the case when len(dims1) == len(dims2), but dims1 != dims2
|
||||
dims1 = [3, 4]
|
||||
dims2 = [1, 4]
|
||||
try_check_onnx_broadcast(dims1, dims2, True, True)
|
||||
|
||||
# Case 8, check the case when len(dims1) == len(dims2) and numel(s2) == 1
|
||||
dims1 = [3, 4]
|
||||
dims2 = [1, 1]
|
||||
try_check_onnx_broadcast(dims1, dims2, True, False)
|
||||
|
||||
|
||||
class TestHipify(TestCase):
|
||||
def test_import_hipify(self):
|
||||
from torch.utils.hipify import hipify_python # noqa: F401
|
||||
|
@ -1,6 +1,4 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import operator
|
||||
from functools import reduce
|
||||
|
||||
|
||||
def maybe_view(tensor, size, check_same_size=True):
|
||||
@ -26,38 +24,3 @@ def maybe_unexpand(tensor, old_size, check_same_size=True):
|
||||
for dim in expanded_dims:
|
||||
tensor = tensor.sum(dim, keepdim=True)
|
||||
return tensor
|
||||
|
||||
|
||||
# Check whether the op enable broadcasting, and whether it is supported by ONNX.
|
||||
# If dims1 and dims2 are different, then broadcast is True.
|
||||
# We always assume the combination of dims1 and dims2 is broadcastable.
|
||||
# The following types of broadcasting are supported in ONNX:
|
||||
# 1) Only one element in dims2, such as dims2 = [1, 1]
|
||||
# 2) dims2 is suffix of dims1, such as dims1 = [2, 3, 4], and dims2 = [3, 4]
|
||||
# Details can be found here: https://github.com/onnx/onnx/blob/master/docs/Operators.md#Gemm
|
||||
def check_onnx_broadcast(dims1, dims2):
|
||||
broadcast = False
|
||||
supported = True
|
||||
len1 = len(dims1)
|
||||
len2 = len(dims2)
|
||||
|
||||
numel2 = reduce(operator.mul, dims2)
|
||||
if len1 < len2:
|
||||
broadcast = True
|
||||
if numel2 != 1:
|
||||
supported = False
|
||||
elif len1 > len2:
|
||||
broadcast = True
|
||||
if numel2 != 1 and dims1[len1 - len2 :] != dims2:
|
||||
supported = False
|
||||
else:
|
||||
if dims1 != dims2:
|
||||
broadcast = True
|
||||
if numel2 != 1:
|
||||
supported = False
|
||||
|
||||
if not supported:
|
||||
raise ValueError(
|
||||
f"Numpy style broadcasting is not supported in ONNX. Input dims are: {dims1}, {dims2}"
|
||||
)
|
||||
return broadcast
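For reference, the rules spelled out in the comment above are exactly what the removed `TestONNXUtils.test_check_onnx_broadcast` cases earlier in this diff exercised. A condensed illustration, assuming the function as it is defined here (it is deleted by this commit):

```python
check_onnx_broadcast([2, 3, 4], [3, 4])  # True: dims2 is a suffix of dims1
check_onnx_broadcast([3, 4], [1, 1])     # True: dims2 has a single element
check_onnx_broadcast([3, 4], [3, 4])     # False: equal shapes, no broadcast needed
check_onnx_broadcast([3, 4], [1, 4])     # raises ValueError: unsupported in ONNX
```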
|
||||
|
@ -1053,7 +1053,8 @@ void _trace_post_record(
|
||||
}
|
||||
}
|
||||
}
|
||||
py::object onnx_globals = py::module::import("torch.onnx._globals");
|
||||
py::object onnx_globals =
|
||||
py::module::import("torch.onnx._internal.torchscript_exporter._globals");
|
||||
py::bool_ is_in_onnx_export =
|
||||
py::module::import("torch.onnx.__init__").attr("is_in_onnx_export");
|
||||
py::bool_ is_autograd_inlining_enabled =
|
||||
|
@ -261,9 +261,10 @@ void NodeToONNX(
|
||||
py::dict& env,
|
||||
py::set& values_in_env) {
|
||||
py::object onnx = py::module::import("torch.onnx");
|
||||
py::object onnx_globals = py::module::import("torch.onnx._globals");
|
||||
py::object onnx_registration =
|
||||
py::module::import("torch.onnx._internal.registration");
|
||||
py::object onnx_globals =
|
||||
py::module::import("torch.onnx._internal.torchscript_exporter._globals");
|
||||
py::object onnx_registration = py::module::import(
|
||||
"torch.onnx._internal.torchscript_exporter.registration");
|
||||
|
||||
// Setup all the lambda helper functions.
|
||||
|
||||
|
@ -4,92 +4,3 @@ Torch->ONNX converter / exporter.
|
||||
|
||||
- User-facing docs: https://pytorch.org/docs/main/onnx.html
|
||||
- Developer docs: https://github.com/pytorch/pytorch/wiki/PyTorch-ONNX-exporter
|
||||
|
||||
> Read the following if you are contributing to `torch.onnx`
|
||||
|
||||
## Symbolic functions Opsets
|
||||
|
||||
Opset 9 is the base version. It is selected as the base version because
|
||||
|
||||
1. It is the first opset version supported by PyTorch export.
|
||||
2. Opset 9 is more robust than previous opset versions. Opset versions like 7/8 have limitations
|
||||
that certain basic operators cannot be expressed in ONNX. Rather than designing around these limitations,
|
||||
we chose to handle them as special cases separately.
|
||||
|
||||
Backward support for opset versions beyond opset 7 is not in our roadmap.
|
||||
|
||||
For opset versions other than 9, by default they will inherit the symbolic functions defined in
|
||||
symbolic_opset9.py.
|
||||
|
||||
To extend support for updated operators in different opset versions on top of opset 9,
|
||||
simply add the updated symbolic functions in the respective symbolic_opset{version}.py file.
|
||||
Check out topk in symbolic_opset10.py, and upsample_nearest2d in symbolic_opset8.py for example.
|
||||
|
||||
## Editing Symbolic Files
|
||||
|
||||
- Use the internal `registration.onnx_symbolic` decorator to register a new symbolic function. Search for `def reshape(g, self, shape):` to see an example.
|
||||
- Parameter names must *exactly* match the names in
|
||||
aten/src/ATen/native/native_functions.yaml, because
|
||||
dispatch is done with keyword arguments.
|
||||
- Looking for inplace ops? They're detected by
|
||||
`_jit_pass_onnx_remove_inplace_ops_for_onnx`, and
|
||||
transparently dispatched to their non inplace versions in
|
||||
"run_symbolic_function". See Note [Export inplace](#export-inplace)
|
||||
|
||||
### A note on Tensor types
|
||||
|
||||
In general, we should avoid depending on the type of Tensor Values contained
|
||||
within the trace graph. However, this is sometimes unavoidable (due to ONNX
|
||||
spec requirements, etc). The TensorType object has accessors for these properties that return the property if it is statically known and return nullopt otherwise.
|
||||
|
||||
In general, we should prefer to rely on the least specific information possible.
|
||||
For example, not relying on tensor properties at all is better than relying
|
||||
on the number of dimensions which is better than relying on
|
||||
concrete shapes. Doing so will make the export symbolics
|
||||
more robust to different graphs.
|
||||
|
||||
### Extra context for symbolic functions
|
||||
|
||||
The first argument of a symbolic function is always a `GraphContext` object.
|
||||
|
||||
`GraphContext` contains all methods defined in a `torch.Graph` object and context
|
||||
for the symbolic function.
|
||||
|
||||
In general, symbolic functions only require inputs and attributes to
|
||||
the original node. An example of a symbolic function needing context is
|
||||
`prim::Loop`. It needs access to the sub-block of the original node.
|
||||
|
||||
### Export inplace
|
||||
|
||||
It would be better for us to export inplace annotations,
|
||||
than to not export them, since it is useful information that can
|
||||
help the consumer of an ONNX export run more efficiently. However,
|
||||
ONNX doesn't currently formalize inplace. Fortunately, it's sound to drop
|
||||
inplace annotations, but we are losing information this way.
|
||||
|
||||
### Pointwise by scalar
|
||||
|
||||
What happens if you add a tensor with a constant (e.g., x + 2)? There are
|
||||
some moving parts to implementing the ONNX translation in this case:
|
||||
|
||||
- By the time we get the scalar in a symbolic function here, it is no longer a
|
||||
Python long/float, but a PyTorch tensor with `numel == 1` (eventually, we want
|
||||
it to be a zero dim tensor but this change has not happened yet.) However, the
|
||||
type of this scalar is *exactly* what the user wrote in Python, which may not
|
||||
match the tensor it is being added to. PyTorch will do implicit conversions on
|
||||
scalars; however, ONNX will not, so we must do the conversion ourselves. This
|
||||
is what `symbolic_helper._if_scalar_type_as()` and
|
||||
`_jit_pass_onnx_scalar_type_analysis` does.
|
||||
|
||||
- Dispatch to these functions takes advantage of an outrageous coincidence
|
||||
between the tensor and scalar name. When we add two tensors together,
|
||||
you get the dispatch:
|
||||
|
||||
add(*[self, other], **{"alpha": alpha})
|
||||
|
||||
When you add a tensor and a scalar, you get the dispatch:
|
||||
|
||||
add(*[self], **{"other": other, "alpha": alpha})
|
||||
|
||||
By having the argument name line up with the name of the scalar attribute
|
||||
if it exists, we can write a single function for both overloads.
|
||||
|
@ -6,78 +6,43 @@ __all__ = [
|
||||
# Modules
|
||||
"errors",
|
||||
"ops",
|
||||
"symbolic_helper",
|
||||
"utils",
|
||||
# All opsets
|
||||
"symbolic_opset7",
|
||||
"symbolic_opset8",
|
||||
"symbolic_opset9",
|
||||
"symbolic_opset10",
|
||||
"symbolic_opset11",
|
||||
"symbolic_opset12",
|
||||
"symbolic_opset13",
|
||||
"symbolic_opset14",
|
||||
"symbolic_opset15",
|
||||
"symbolic_opset16",
|
||||
"symbolic_opset17",
|
||||
"symbolic_opset18",
|
||||
"symbolic_opset19",
|
||||
"symbolic_opset20",
|
||||
# Enums
|
||||
"OperatorExportTypes",
|
||||
"TrainingMode",
|
||||
"TensorProtoDataType",
|
||||
"JitScalarType",
|
||||
# Public functions
|
||||
"export",
|
||||
"is_in_onnx_export",
|
||||
"select_model_mode_for_export",
|
||||
"register_custom_op_symbolic",
|
||||
"unregister_custom_op_symbolic",
|
||||
# Base error
|
||||
"OnnxExporterError",
|
||||
"ONNXProgram",
|
||||
]
|
||||
|
||||
from typing import Any, Callable, TYPE_CHECKING
|
||||
from typing_extensions import deprecated
|
||||
|
||||
import torch
|
||||
from torch._C import _onnx as _C_onnx
|
||||
from torch._C._onnx import OperatorExportTypes, TensorProtoDataType, TrainingMode
|
||||
from torch._C._onnx import ( # Deprecated members that are excluded from __all__
|
||||
OperatorExportTypes as OperatorExportTypes,
|
||||
TensorProtoDataType as TensorProtoDataType,
|
||||
TrainingMode as TrainingMode,
|
||||
)
|
||||
|
||||
from . import errors, ops
|
||||
from ._internal.exporter._onnx_program import ONNXProgram
|
||||
from ._type_utils import JitScalarType
|
||||
from .errors import OnnxExporterError
|
||||
from .utils import (
|
||||
from ._internal.torchscript_exporter import ( # Deprecated members that are excluded from __all__
|
||||
symbolic_helper,
|
||||
symbolic_opset10,
|
||||
symbolic_opset9,
|
||||
utils,
|
||||
)
|
||||
from ._internal.torchscript_exporter._type_utils import (
|
||||
JitScalarType, # Deprecated members that are excluded from __all__
|
||||
)
|
||||
from ._internal.torchscript_exporter.utils import ( # Deprecated members that are excluded from __all__
|
||||
_run_symbolic_function,
|
||||
_run_symbolic_method,
|
||||
register_custom_op_symbolic,
|
||||
select_model_mode_for_export,
|
||||
unregister_custom_op_symbolic,
|
||||
)
|
||||
|
||||
|
||||
from . import ( # usort: skip. Keep the order instead of sorting lexicographically
|
||||
errors,
|
||||
ops,
|
||||
symbolic_helper,
|
||||
symbolic_opset7,
|
||||
symbolic_opset8,
|
||||
symbolic_opset9,
|
||||
symbolic_opset10,
|
||||
symbolic_opset11,
|
||||
symbolic_opset12,
|
||||
symbolic_opset13,
|
||||
symbolic_opset14,
|
||||
symbolic_opset15,
|
||||
symbolic_opset16,
|
||||
symbolic_opset17,
|
||||
symbolic_opset18,
|
||||
symbolic_opset19,
|
||||
symbolic_opset20,
|
||||
utils,
|
||||
)
|
||||
from .errors import OnnxExporterError
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -85,10 +50,10 @@ if TYPE_CHECKING:
|
||||
from collections.abc import Collection, Mapping, Sequence
|
||||
|
||||
# Set namespace for exposed private names
|
||||
JitScalarType.__module__ = "torch.onnx"
|
||||
ONNXProgram.__module__ = "torch.onnx"
|
||||
OnnxExporterError.__module__ = "torch.onnx"
|
||||
|
||||
# TODO(justinchuby): Remove these two properties
|
||||
producer_name = "pytorch"
|
||||
producer_version = _C_onnx.PRODUCER_VERSION
|
||||
|
||||
@ -385,7 +350,7 @@ def export(
|
||||
else:
|
||||
import warnings
|
||||
|
||||
from torch.onnx.utils import export
|
||||
from ._internal.torchscript_exporter.utils import export
|
||||
|
||||
warnings.warn(
|
||||
"You are using the legacy TorchScript-based ONNX export. Starting in PyTorch 2.9, "
|
||||
@ -429,7 +394,7 @@ def export(
|
||||
|
||||
def is_in_onnx_export() -> bool:
|
||||
"""Returns whether it is in the middle of ONNX export."""
|
||||
from torch.onnx._globals import GLOBALS
|
||||
from torch.onnx._internal.exporter import _flags
|
||||
from torch.onnx._internal.torchscript_exporter._globals import GLOBALS
|
||||
|
||||
return GLOBALS.in_onnx_export or _flags._is_onnx_exporting
|
||||
|
New file: torch/onnx/_internal/torchscript_exporter/README.md (+91 lines)
@ -0,0 +1,91 @@
|
||||
# TorchScript Exporter
|
||||
|
||||
> [!NOTE]
|
||||
> This directory hosts code for the legacy TorchScript-based ONNX exporter. It is *deprecated* since PyTorch 2.9 and should be removed along with TorchScript.
|
||||
|
||||
## Symbolic functions Opsets
|
||||
|
||||
Opset 9 is the base version. It is selected as the base version because
|
||||
|
||||
1. It is the first opset version supported by PyTorch export.
|
||||
2. Opset 9 is more robust than previous opset versions. Opset versions like 7/8 have limitations
|
||||
that certain basic operators cannot be expressed in ONNX. Rather than designing around these limitations,
|
||||
we chose to handle them as special cases separately.
|
||||
|
||||
Backward support for opset versions beyond opset 7 is not in our roadmap.
|
||||
|
||||
For opset versions other than 9, by default they will inherit the symbolic functions defined in
|
||||
symbolic_opset9.py.
|
||||
|
||||
To extend support for updated operators in different opset versions on top of opset 9,
|
||||
simply add the updated symbolic functions in the respective symbolic_opset{version}.py file.
|
||||
Check out topk in symbolic_opset10.py, and upsample_nearest2d in symbolic_opset8.py for example.
|
||||
|
||||
## Editing Symbolic Files
|
||||
|
||||
- Use the internal `registration.onnx_symbolic` decorator to register a new symbolic function (a minimal sketch follows this list). Search for `def reshape(g, self, shape):` to see an example.
|
||||
- Parameter names must *exactly* match the names in
|
||||
aten/src/ATen/native/native_functions.yaml, because
|
||||
dispatch is done with keyword arguments.
|
||||
- Looking for inplace ops? They're detected by
|
||||
`_jit_pass_onnx_remove_inplace_ops_for_onnx`, and
|
||||
transparently dispatched to their non inplace versions in
|
||||
"run_symbolic_function". See Note [Export inplace](#export-inplace)
|
||||
|
||||
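The decorator mentioned in the first bullet is used like this. The sketch below mirrors `ge` from the `symbolic_opset12.py` file added later in this diff, so only the framing comments are new; the import path reflects the new private location:

```python
import functools

from torch.onnx._internal.torchscript_exporter import jit_utils, registration

# Partial that pins the opset this file provides symbolic functions for.
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=12)


@_onnx_symbolic("aten::ge")
def ge(g: jit_utils.GraphContext, input, other):
    # Map aten::ge onto ONNX GreaterOrEqual, which exists since opset 12.
    return g.op("GreaterOrEqual", input, other)
```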
### A note on Tensor types
|
||||
|
||||
In general, we should avoid depending on the type of Tensor Values contained
|
||||
within the trace graph. However, this is sometimes unavoidable (due to ONNX
|
||||
spec requirements, etc). The TensorType object has accessors for these properties that return the property if it is statically known and return nullopt otherwise.
|
||||
|
||||
In general, we should prefer to rely on the least specific information possible.
|
||||
For example, not relying on tensor properties at all is better than relying
|
||||
on the number of dimensions which is better than relying on
|
||||
concrete shapes. Doing so will make the export symbolics
|
||||
more robust to different graphs.
|
||||
|
||||
### Extra context for symbolic functions
|
||||
|
||||
The first argument of a symbolic function is always a `GraphContext` object.
|
||||
|
||||
`GraphContext` contains all methods defined in a `torch.Graph` object and context
|
||||
for the symbolic function.
|
||||
|
||||
In general, symbolic functions only require inputs and attributes to
|
||||
the original node. An example of a symbolic function needing context is
|
||||
`prim::Loop`. It needs access to the sub-block of the original node.
|
||||
|
||||
### Export inplace
|
||||
|
||||
It would be better for us to export inplace annotations,
|
||||
than to not export them, since it is useful information that can
|
||||
help the consumer of an ONNX export run more efficiently. However,
|
||||
ONNX doesn't currently formalize inplace. Fortunately, it's sound to drop
|
||||
inplace annotations, but we are losing information this way.
|
||||
|
||||
### Pointwise by scalar
|
||||
|
||||
What happens if you add a tensor with a constant (e.g., x + 2)? There are
|
||||
some moving parts to implementing the ONNX translation in this case:
|
||||
|
||||
- By the time we get the scalar in a symbolic function here, it is no longer a
|
||||
Python long/float, but a PyTorch tensor with `numel == 1` (eventually, we want
|
||||
it to be a zero dim tensor but this change has not happened yet.) However, the
|
||||
type of this scalar is *exactly* what the user wrote in Python, which may not
|
||||
match the tensor it is being added to. PyTorch will do implicit conversions on
|
||||
scalars; however, ONNX will not, so we must do the conversion ourselves. This
|
||||
is what `symbolic_helper._if_scalar_type_as()` and
|
||||
`_jit_pass_onnx_scalar_type_analysis` does.
|
||||
|
||||
- Dispatch to these functions takes advantage of an outrageous coincidence
|
||||
between the tensor and scalar name. When we add two tensors together,
|
||||
you get the dispatch:
|
||||
|
||||
add(*[self, other], **{"alpha": alpha})
|
||||
|
||||
When you add a tensor and a scalar, you get the dispatch:
|
||||
|
||||
add(*[self], **{"other": other, "alpha": alpha})
|
||||
|
||||
By having the argument name line up with the name of the scalar attribute
|
||||
if it exists, we can write a single function for both overloads.
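Concretely, that means a single symbolic signature serves both overloads. A simplified sketch (the real opset 9 `add` also folds a non-unit `alpha` into `other`; this version only shows the dispatch point):

```python
from torch.onnx._internal.torchscript_exporter import jit_utils


def add(g: jit_utils.GraphContext, self, other, alpha=None):
    # tensor + tensor -> add(*[self, other], **{"alpha": alpha})
    # tensor + scalar -> add(*[self], **{"other": other, "alpha": alpha})
    # Because the scalar arrives under the keyword name "other", the same
    # definition handles both call patterns.
    return g.op("Add", self, other)
```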
|
@ -1,9 +1,6 @@
|
||||
# mypy: allow-untyped-defs
|
||||
"""Utilities for manipulating the torch.Graph object and the torchscript."""
|
||||
|
||||
# TODO(justinchuby): Move more of the symbolic helper functions here and expose
|
||||
# them to the user.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
@ -14,8 +11,8 @@ from typing import Any
|
||||
|
||||
import torch
|
||||
from torch import _C
|
||||
from torch.onnx._globals import GLOBALS
|
||||
from torch.onnx._internal import registration
|
||||
from torch.onnx._internal.torchscript_exporter import registration
|
||||
from torch.onnx._internal.torchscript_exporter._globals import GLOBALS
|
||||
|
||||
|
||||
_ATTR_PATTERN = re.compile("^(.+)_(([ifstgz])|(ty))$")
|
||||
@ -89,7 +86,6 @@ class GraphContext:
|
||||
The value representing the single output of this operator (see the `outputs`
|
||||
keyword argument for multi-return nodes).
|
||||
"""
|
||||
# FIXME(justinchuby): Add the return type back once we know how to handle mypy
|
||||
return _add_op(self, opname, *raw_args, outputs=outputs, **kwargs)
|
||||
|
||||
def aten_op(self, operator: str, *args, overload_name: str = "", **kwargs):
|
||||
@ -211,8 +207,6 @@ def _add_op(
|
||||
The set of operators and the inputs/attributes they take
|
||||
is documented at https://github.com/onnx/onnx/blob/master/docs/Operators.md
|
||||
|
||||
This function is monkey-patched onto Graph.
|
||||
|
||||
Args:
|
||||
graph_context: The Torch Graph or Block.
|
||||
opname: The ONNX operator name, e.g., `Abs` or `Add`, or an operator qualified
|
||||
@ -337,7 +331,6 @@ def _add_attribute(node: _C.Node, key: str, value: Any, aten: bool):
|
||||
return getattr(node, f"{kind}_")(name, value)
|
||||
|
||||
|
||||
# TODO: Expose this to user when migrating symbolic helper functions to here.
|
||||
def _is_tensor(x: _C.Value) -> bool:
|
||||
return x.type().isSubtypeOf(_C.TensorType.get())
|
||||
|
@ -9,10 +9,9 @@ import shutil
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.jit._trace
|
||||
import torch.serialization
|
||||
from torch.onnx import errors
|
||||
from torch.onnx._internal import jit_utils, registration
|
||||
from torch.onnx._internal.torchscript_exporter import jit_utils, registration
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
New file: torch/onnx/_internal/torchscript_exporter/symbolic_helper.py (+2380 lines; diff suppressed because it is too large)
New file: torch/onnx/_internal/torchscript_exporter/symbolic_opset10.py (+1187 lines; diff suppressed because it is too large)
New file: torch/onnx/_internal/torchscript_exporter/symbolic_opset11.py (+1472 lines; diff suppressed because it is too large)
New file: torch/onnx/_internal/torchscript_exporter/symbolic_opset12.py (+465 lines)
@ -0,0 +1,465 @@
|
||||
# mypy: allow-untyped-defs
|
||||
# mypy: disable-error-code=arg-type
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import sys
|
||||
|
||||
import torch
|
||||
from torch._C import _onnx as _C_onnx
|
||||
from torch.onnx import errors
|
||||
from torch.onnx._internal.torchscript_exporter import (
|
||||
_type_utils,
|
||||
jit_utils,
|
||||
registration,
|
||||
symbolic_helper,
|
||||
symbolic_opset9 as opset9,
|
||||
utils,
|
||||
)
|
||||
|
||||
|
||||
# EDITING THIS FILE? READ THIS FIRST!
|
||||
# see Note [Edit Symbolic Files] in README.md
|
||||
|
||||
# This file exports ONNX ops for opset 12
|
||||
|
||||
__all__ = [
|
||||
"argmax",
|
||||
"argmin",
|
||||
"binary_cross_entropy_with_logits",
|
||||
"celu",
|
||||
"cross_entropy_loss",
|
||||
"dropout",
|
||||
"einsum",
|
||||
"ge",
|
||||
"le",
|
||||
"native_dropout",
|
||||
"nll_loss",
|
||||
"nll_loss2d",
|
||||
"nll_loss_nd",
|
||||
"outer",
|
||||
"pow",
|
||||
"tensordot",
|
||||
"unfold",
|
||||
]
|
||||
|
||||
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=12)
|
||||
|
||||
|
||||
def _einsum_helper(g: jit_utils.GraphContext, equation, tensors):
|
||||
if not tensors:
|
||||
raise RuntimeError("Einsum inputs are empty.")
|
||||
# ONNX does not support bool for Einsum inputs.
|
||||
if symbolic_helper._is_bool(tensors[0]):
|
||||
tensors = [
|
||||
g.op("Cast", tensor, to_i=_C_onnx.TensorProtoDataType.INT64)
|
||||
for tensor in tensors
|
||||
]
|
||||
return g.op(
|
||||
"Cast",
|
||||
g.op("Einsum", *tensors, equation_s=equation),
|
||||
to_i=_C_onnx.TensorProtoDataType.BOOL,
|
||||
)
|
||||
else:
|
||||
return g.op("Einsum", *tensors, equation_s=equation)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::einsum")
|
||||
@symbolic_helper.parse_args("s", "v", "is")
|
||||
def einsum(g: jit_utils.GraphContext, equation, tensor_list, path=None):
|
||||
tensors = symbolic_helper._unpack_list(tensor_list)
|
||||
return _einsum_helper(g, equation, tensors)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::outer")
|
||||
@symbolic_helper.parse_args("v", "v")
|
||||
def outer(g: jit_utils.GraphContext, input, other):
|
||||
# make sure to cast other to self's type
|
||||
if _type_utils.JitScalarType.from_value(
|
||||
other, _type_utils.JitScalarType.UNDEFINED
|
||||
) != _type_utils.JitScalarType.from_value(input):
|
||||
other = g.op(
|
||||
"Cast",
|
||||
other,
|
||||
to_i=_type_utils.JitScalarType.from_value(input).onnx_type(),
|
||||
)
|
||||
return _einsum_helper(g, "i,j->ij", [input, other])
|
||||
|
||||
|
||||
def _dropout_returns_masked_input_and_mask(
|
||||
g: jit_utils.GraphContext, input: torch._C.Value, p: float, train: bool
|
||||
) -> tuple[torch._C.Value, torch._C.Value | None]:
|
||||
symbolic_helper.check_training_mode(train, "dropout")
|
||||
# In eval mode, dropout is non-op. That is, if the node's
|
||||
# train param is set to False, dropout just returns its inputs.
|
||||
if not train:
|
||||
return input, None
|
||||
p = g.op("Constant", value_t=torch.tensor(p))
|
||||
t = g.op("Constant", value_t=torch.tensor(train, dtype=torch.bool))
|
||||
r, mask = g.op("Dropout", input, p, t, outputs=2)
|
||||
return r, mask
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::dropout")
|
||||
@symbolic_helper.parse_args("v", "f", "b")
|
||||
def dropout(g: jit_utils.GraphContext, input, p, train):
|
||||
masked, _ = _dropout_returns_masked_input_and_mask(g, input, p, train)
|
||||
return masked
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::native_dropout")
|
||||
@symbolic_helper.parse_args("v", "f", "b")
|
||||
def native_dropout(g: jit_utils.GraphContext, input, p, train):
|
||||
return _dropout_returns_masked_input_and_mask(g, input, p, train)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::nll_loss")
|
||||
def nll_loss(g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index):
|
||||
# none reduction : onnx::Constant[value={0}]
|
||||
# mean reduction : onnx::Constant[value={1}]
|
||||
# sum reduction : onnx::Constant[value={2}]
|
||||
reduction = symbolic_helper._maybe_get_const(reduction, "i")
|
||||
reduction_vals = ["none", "mean", "sum"]
|
||||
reduction = reduction_vals[reduction]
|
||||
|
||||
# in onnx NegativeLogLikelihoodLoss specification, ignore_index is optional without default value.
|
||||
# therefore we need to set ignore_index attribute even if it is not specified (e.g. ignore_index=-100).
|
||||
ignore_index = symbolic_helper._maybe_get_const(ignore_index, "i")
|
||||
if weight.node().mustBeNone():
|
||||
nllloss = g.op(
|
||||
"NegativeLogLikelihoodLoss",
|
||||
self,
|
||||
target,
|
||||
reduction_s=reduction,
|
||||
ignore_index_i=ignore_index,
|
||||
)
|
||||
else:
|
||||
nllloss = g.op(
|
||||
"NegativeLogLikelihoodLoss",
|
||||
self,
|
||||
target,
|
||||
weight,
|
||||
reduction_s=reduction,
|
||||
ignore_index_i=ignore_index,
|
||||
)
|
||||
|
||||
return nllloss
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::nll_loss2d")
|
||||
def nll_loss2d(
|
||||
g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index
|
||||
):
|
||||
return nll_loss(g, self, target, weight, reduction, ignore_index)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::nll_loss_nd")
|
||||
def nll_loss_nd(
|
||||
g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index
|
||||
):
|
||||
return nll_loss(g, self, target, weight, reduction, ignore_index)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::cross_entropy_loss")
|
||||
def cross_entropy_loss(
|
||||
g: jit_utils.GraphContext,
|
||||
self,
|
||||
target,
|
||||
weight,
|
||||
reduction,
|
||||
ignore_index,
|
||||
label_smoothing,
|
||||
):
|
||||
# none reduction : onnx::Constant[value={0}]
|
||||
# mean reduction : onnx::Constant[value={1}]
|
||||
# sum reduction : onnx::Constant[value={2}]
|
||||
reduction = symbolic_helper._maybe_get_const(reduction, "i")
|
||||
reduction_vals = ["none", "mean", "sum"]
|
||||
reduction = reduction_vals[reduction]
|
||||
|
||||
label_smoothing = symbolic_helper._maybe_get_const(label_smoothing, "f")
|
||||
if label_smoothing is not None and label_smoothing > 0.0:
|
||||
raise errors.SymbolicValueError(
|
||||
"Unsupported: ONNX does not support label_smoothing", self
|
||||
)
|
||||
|
||||
# in onnx SoftmaxCrossEntropyLoss specification, ignore_index is optional without default value.
|
||||
# therefore we need to set ignore_index attribute even if it is not specified (e.g. ignore_index=-100).
|
||||
ignore_index = symbolic_helper._maybe_get_const(ignore_index, "i")
|
||||
if weight.node().mustBeNone():
|
||||
celoss = g.op(
|
||||
"SoftmaxCrossEntropyLoss",
|
||||
self,
|
||||
target,
|
||||
reduction_s=reduction,
|
||||
ignore_index_i=ignore_index,
|
||||
)
|
||||
else:
|
||||
celoss = g.op(
|
||||
"SoftmaxCrossEntropyLoss",
|
||||
self,
|
||||
target,
|
||||
weight,
|
||||
reduction_s=reduction,
|
||||
ignore_index_i=ignore_index,
|
||||
)
|
||||
|
||||
return celoss
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::binary_cross_entropy_with_logits")
|
||||
@symbolic_helper.parse_args("v", "v", "v", "v", "i")
|
||||
def binary_cross_entropy_with_logits(
|
||||
g: jit_utils.GraphContext, input, target, weight, pos_weight, reduction
|
||||
):
|
||||
p = g.op("Constant", value_t=torch.tensor([1]))
|
||||
sig_x = opset9.sigmoid(g, input)
|
||||
log_sig_x = opset9.log(g, sig_x)
|
||||
sub_1_x = opset9.sub(g, p, sig_x)
|
||||
sub_1_y = opset9.sub(g, p, target)
|
||||
log_1_x = opset9.log(g, sub_1_x)
|
||||
if pos_weight is None or symbolic_helper._is_none(pos_weight):
|
||||
output = opset9.neg(
|
||||
g,
|
||||
opset9.add(
|
||||
g, opset9.mul(g, target, log_sig_x), opset9.mul(g, sub_1_y, log_1_x)
|
||||
),
|
||||
)
|
||||
else:
|
||||
output = opset9.neg(
|
||||
g,
|
||||
opset9.add(
|
||||
g,
|
||||
opset9.mul(g, opset9.mul(g, target, log_sig_x), pos_weight),
|
||||
opset9.mul(g, sub_1_y, log_1_x),
|
||||
),
|
||||
)
|
||||
|
||||
if weight is not None and not symbolic_helper._is_none(weight):
|
||||
output = opset9.mul(g, weight, output)
|
||||
|
||||
reduction = symbolic_helper._maybe_get_const(reduction, "i")
|
||||
if reduction == 0:
|
||||
return output
|
||||
elif reduction == 1:
|
||||
return g.op("ReduceMean", output, keepdims_i=0)
|
||||
elif reduction == 2:
|
||||
return g.op("ReduceSum", output, keepdims_i=0)
|
||||
else:
|
||||
return symbolic_helper._onnx_unsupported(
|
||||
"binary_cross_entropy_with_logits with reduction other than none, mean, or sum",
|
||||
input,
|
||||
)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::celu")
|
||||
def celu(g: jit_utils.GraphContext, self, alpha):
|
||||
alpha = symbolic_helper._maybe_get_const(alpha, "f")
|
||||
# if the input is of type double cast it to float
|
||||
if (
|
||||
_type_utils.JitScalarType.from_value(self, _type_utils.JitScalarType.UNDEFINED)
|
||||
== _type_utils.JitScalarType.DOUBLE
|
||||
):
|
||||
self = g.op("Cast", self, to_i=_C_onnx.TensorProtoDataType.FLOAT)
|
||||
out = g.op("Celu", self, alpha_f=alpha)
|
||||
return g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.DOUBLE)
|
||||
|
||||
return g.op("Celu", self, alpha_f=alpha)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::argmax")
|
||||
@symbolic_helper.parse_args("v", "v", "b")
|
||||
def argmax(
|
||||
g: jit_utils.GraphContext,
|
||||
input: torch._C.Value,
|
||||
dim: torch._C.Value,
|
||||
keepdim: bool,
|
||||
):
|
||||
return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMax")
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::argmin")
|
||||
@symbolic_helper.parse_args("v", "v", "b")
|
||||
def argmin(
|
||||
g: jit_utils.GraphContext,
|
||||
input: torch._C.Value,
|
||||
dim: torch._C.Value,
|
||||
keepdim: bool,
|
||||
):
|
||||
return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMin")
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::pow")
|
||||
def pow(g: jit_utils.GraphContext, self, exponent):
|
||||
return g.op("Pow", self, exponent)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::ge")
|
||||
def ge(g: jit_utils.GraphContext, input, other):
|
||||
return g.op("GreaterOrEqual", input, other)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::le")
|
||||
def le(g: jit_utils.GraphContext, input, other):
|
||||
return g.op("LessOrEqual", input, other)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::unfold")
|
||||
@symbolic_helper.parse_args("v", "i", "v", "v")
|
||||
def unfold(g: jit_utils.GraphContext, input, dimension, size, step):
|
||||
const_size = symbolic_helper._maybe_get_const(size, "i")
|
||||
const_step = symbolic_helper._maybe_get_const(step, "i")
|
||||
if not symbolic_helper._is_value(const_size) and not symbolic_helper._is_value(
|
||||
const_step
|
||||
):
|
||||
return opset9.unfold(g, input, dimension, const_size, const_step)
|
||||
|
||||
sizedim = symbolic_helper._get_tensor_dim_size(input, dimension)
|
||||
if sizedim is not None:
|
||||
low_start = g.op("Constant", value_t=torch.tensor(0))
|
||||
low_end = g.op("Constant", value_t=torch.tensor(sizedim))
|
||||
hi_end = g.op("Constant", value_t=torch.tensor(sizedim + 1))
|
||||
low_indices = g.op("Range", low_start, low_end, step)
|
||||
hi_indices = g.op("Range", size, hi_end, step)
|
||||
|
||||
low_size = symbolic_helper._size_helper(
|
||||
g, low_indices, g.op("Constant", value_t=torch.tensor(0))
|
||||
)
|
||||
hi_size = symbolic_helper._size_helper(
|
||||
g, hi_indices, g.op("Constant", value_t=torch.tensor(0))
|
||||
)
|
||||
|
||||
ndim = symbolic_helper._get_tensor_rank(input)
|
||||
assert ndim is not None
|
||||
perm = list(range(0, ndim))
|
||||
perm.append(perm.pop(dimension))
|
||||
|
||||
unsqueeze_list = []
|
||||
loop_condition = g.op("Constant", value_t=torch.tensor(1))
|
||||
loop_condition = g.op(
|
||||
"Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL
|
||||
)
|
||||
loop_len = g.op("Min", low_size, hi_size)
|
||||
|
||||
loop, (loop_context,), _ = jit_utils.add_op_with_blocks(
|
||||
g, "Loop", loop_len, loop_condition, n_blocks=1
|
||||
)
|
||||
|
||||
loop_block = loop_context.block
|
||||
block_input_iter = utils._add_input_to_block(loop_block)
|
||||
cond = utils._add_input_to_block(loop_block) # noqa: F841
|
||||
|
||||
starts = loop_context.op("Gather", low_indices, block_input_iter)
|
||||
ends = loop_context.op("Gather", hi_indices, block_input_iter)
|
||||
axes = loop_context.op("Constant", value_t=torch.tensor([2]))
|
||||
starts = symbolic_helper._unsqueeze_helper(loop_context, starts, [0])
|
||||
ends = symbolic_helper._unsqueeze_helper(loop_context, ends, [0])
|
||||
stack = loop_context.op("Slice", input, starts, ends, axes)
|
||||
|
||||
unsqueeze = symbolic_helper._unsqueeze_helper(
|
||||
loop_context, loop_context.op("Transpose", stack, perm_i=perm), [dimension]
|
||||
)
|
||||
unsqueeze_list.append(unsqueeze)
|
||||
concat = loop_context.op("Concat", *unsqueeze_list, axis_i=0)
|
||||
|
||||
cond_out = loop_context.op(
|
||||
"Cast", loop_condition, _C_onnx.TensorProtoDataType.BOOL
|
||||
)
|
||||
utils._add_output_to_block(loop_block, cond_out)
|
||||
utils._add_output_to_block(loop_block, concat)
|
||||
|
||||
loop_output = loop.node().output()
|
||||
perm = [0, 1, 2, 3, 4]
|
||||
perm[0], perm[dimension + 1] = perm[dimension + 1], perm[0]
|
||||
transpose = g.op("Transpose", loop_output, perm_i=perm)
|
||||
squeeze = symbolic_helper._squeeze_helper(g, transpose, [0])
|
||||
|
||||
return squeeze
|
||||
|
||||
return symbolic_helper._unimplemented("Unfold", "input size not accessible")
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::tensordot")
|
||||
@symbolic_helper.parse_args("v", "v", "is", "is", "v")
|
||||
def tensordot(g: jit_utils.GraphContext, input_a, input_b, dims_a, dims_b, out=None):
|
||||
if out is not None:
|
||||
symbolic_helper._unimplemented(
|
||||
"Tensordot", "Out parameter is not supported for tensordot."
|
||||
)
|
||||
|
||||
dim_count_a = symbolic_helper._get_tensor_rank(input_a)
|
||||
if dim_count_a is None:
|
||||
raise errors.SymbolicValueError(
|
||||
"Unsupported: ONNX export of tensordot for tensor(input_a) of unknown rank.",
|
||||
input_a,
|
||||
)
|
||||
|
||||
dim_count_b = symbolic_helper._get_tensor_rank(input_b)
|
||||
if dim_count_b is None:
|
||||
raise errors.SymbolicValueError(
|
||||
"Unsupported: ONNX export of tensordot for tensor(input_b) of unknown rank.",
|
||||
input_b,
|
||||
)
|
||||
|
||||
dims_a = [
|
||||
(dims_a[i] + dim_count_a) if (dims_a[i] < 0) else dims_a[i]
|
||||
for i in range(len(dims_a))
|
||||
]
|
||||
dims_b = [
|
||||
(dims_b[i] + dim_count_b) if (dims_b[i] < 0) else dims_b[i]
|
||||
for i in range(len(dims_b))
|
||||
]
|
||||
|
||||
left_dims_a = [i for i in range(dim_count_a) if (i not in dims_a)]
|
||||
left_dims_b = [i for i in range(dim_count_b) if (i not in dims_b)]
|
||||
|
||||
new_input_a = opset9.permute(g, input_a, left_dims_a + dims_a)
|
||||
new_input_b = opset9.permute(g, input_b, dims_b + left_dims_b)
|
||||
|
||||
input_shape = g.op("Shape", new_input_a)
|
||||
left_sizes_a = symbolic_helper._slice_helper(
|
||||
g, input_shape, axes=[0], starts=[0], ends=[len(left_dims_a)]
|
||||
)
|
||||
shape_sizes = [
|
||||
left_sizes_a,
|
||||
g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
|
||||
]
|
||||
output_a = opset9._reshape_from_tensor(g, new_input_a, shape_sizes)
|
||||
|
||||
input_shape = g.op("Shape", output_a)
|
||||
slices = symbolic_helper._slice_helper(
|
||||
g, input_shape, axes=[0], starts=[-1], ends=[sys.maxsize]
|
||||
)
|
||||
shape_sizes = [
|
||||
g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
|
||||
slices,
|
||||
]
|
||||
output_a = opset9._reshape_from_tensor(g, new_input_a, shape_sizes)
|
||||
|
||||
input_shape = g.op("Shape", new_input_b)
|
||||
left_sizes_b = symbolic_helper._slice_helper(
|
||||
g, input_shape, axes=[0], starts=[len(dims_b)], ends=[sys.maxsize]
|
||||
)
|
||||
slices = symbolic_helper._slice_helper(
|
||||
g, input_shape, axes=[0], starts=[0], ends=[len(dims_b)]
|
||||
)
|
||||
shape_sizes = [
|
||||
slices,
|
||||
g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
|
||||
]
|
||||
output_b = opset9._reshape_from_tensor(g, new_input_b, shape_sizes)
|
||||
|
||||
input_shape = g.op("Shape", output_b)
|
||||
slices = symbolic_helper._slice_helper(
|
||||
g, input_shape, axes=[0], starts=[-1], ends=[sys.maxsize]
|
||||
)
|
||||
shape_sizes = [
|
||||
g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
|
||||
slices,
|
||||
]
|
||||
output_b = opset9._reshape_from_tensor(g, new_input_b, shape_sizes)
|
||||
|
||||
output = einsum(g, "ij,jk->ik", g.op("prim::ListConstruct", *[output_a, output_b]))
|
||||
|
||||
shape_sizes = [left_sizes_a, left_sizes_b]
|
||||
return opset9._reshape_from_tensor(g, output, shape_sizes)
|
New file: torch/onnx/_internal/torchscript_exporter/symbolic_opset13.py (+1113 lines; diff suppressed because it is too large)
New file: torch/onnx/_internal/torchscript_exporter/symbolic_opset14.py (+296 lines)
@ -0,0 +1,296 @@
|
||||
# mypy: allow-untyped-defs
|
||||
# mypy: disable-error-code=arg-type
|
||||
"""This file exports ONNX ops for opset 14.
|
||||
|
||||
Note [ONNX operators that are added/updated in opset 14]
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
New operators:
|
||||
HardSwish, Trilu
|
||||
|
||||
Updated operators:
|
||||
Reshape
|
||||
Add, Sub, Mul, Div
|
||||
GRU, LSTM, RNN
|
||||
BatchNorm, Cumsum, Relu
|
||||
"""
|
||||
|
||||
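For readers skimming this new file: opset 14 is reached through the unchanged public API. A small usage sketch (the model and output file name are placeholders; `dynamo=False` is assumed to select the legacy exporter implemented in this directory):

```python
import torch

model = torch.nn.Hardswish()  # HardSwish gets a dedicated ONNX op in opset 14
example_input = (torch.randn(2, 8),)

torch.onnx.export(
    model, example_input, "hardswish_opset14.onnx", opset_version=14, dynamo=False
)
```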
# EDITING THIS FILE? READ THIS FIRST!
|
||||
# see Note [Edit Symbolic Files] in README.md
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
|
||||
import torch
|
||||
from torch.onnx import _constants
|
||||
from torch.onnx._internal.torchscript_exporter import (
|
||||
_type_utils,
|
||||
jit_utils,
|
||||
registration,
|
||||
symbolic_helper,
|
||||
)
|
||||
from torch.onnx._internal.torchscript_exporter._globals import GLOBALS
|
||||
|
||||
|
||||
__all__ = [
|
||||
"hardswish",
|
||||
"tril",
|
||||
"triu",
|
||||
"reshape",
|
||||
"batch_norm",
|
||||
"quantized_hardswish",
|
||||
"scaled_dot_product_attention",
|
||||
]
|
||||
|
||||
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=14)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::hardswish")
|
||||
@symbolic_helper.parse_args("v")
|
||||
def hardswish(g: jit_utils.GraphContext, self):
|
||||
return g.op("HardSwish", self)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::tril")
|
||||
def tril(g: jit_utils.GraphContext, self, diagonal, out=None):
|
||||
return g.op("Trilu", self, diagonal, upper_i=0)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::triu")
|
||||
def triu(g: jit_utils.GraphContext, self, diagonal, out=None):
|
||||
return g.op("Trilu", self, diagonal, upper_i=1)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::reshape")
|
||||
@symbolic_helper.quantized_args(True)
|
||||
@symbolic_helper.parse_args("v", "v")
|
||||
def reshape(g: jit_utils.GraphContext, self, shape):
|
||||
# NOTE: Due to bug in ORT https://github.com/microsoft/onnxruntime/issues/10664
|
||||
# Reshape export cannot utilize the new allowzero attribute introduced in opset 14.
|
||||
return symbolic_helper._reshape_helper(g, self, shape, allowzero=0)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::batch_norm")
|
||||
@symbolic_helper.parse_args("v", "v", "v", "v", "v", "i", "f", "f", "i")
|
||||
def batch_norm(
|
||||
g: jit_utils.GraphContext,
|
||||
input,
|
||||
weight,
|
||||
bias,
|
||||
running_mean,
|
||||
running_var,
|
||||
training,
|
||||
momentum,
|
||||
eps,
|
||||
cudnn_enabled,
|
||||
):
|
||||
if (
|
||||
torch.is_autocast_enabled()
|
||||
and not symbolic_helper.args_have_same_dtype(
|
||||
[input, weight, bias, running_mean, running_var]
|
||||
)
|
||||
and GLOBALS.export_onnx_opset_version < 15
|
||||
):
|
||||
return symbolic_helper._onnx_opset_unsupported_detailed(
|
||||
"BatchNormalization",
|
||||
14,
|
||||
15,
|
||||
"All input tensors must have the same `dtype`."
|
||||
" Turn off Autocast or export using opset version 15.",
|
||||
input,
|
||||
)
|
||||
|
||||
symbolic_helper.check_training_mode(training, "batch_norm")
|
||||
weight, bias, running_mean, running_var = symbolic_helper._batchnorm_helper(
|
||||
g, input, weight, bias, running_mean, running_var
|
||||
)
|
||||
out = g.op(
|
||||
"BatchNormalization",
|
||||
input,
|
||||
weight,
|
||||
bias,
|
||||
running_mean,
|
||||
running_var,
|
||||
epsilon_f=eps,
|
||||
momentum_f=1 - momentum,
|
||||
training_mode_i=0 if not training else 1,
|
||||
outputs=1 if not training else 3,
|
||||
)
|
||||
if not training:
|
||||
return out
|
||||
else:
|
||||
res, new_running_mean, new_running_var = out
|
||||
new_running_mean.setType(running_mean.type())
|
||||
new_running_var.setType(running_var.type())
|
||||
return res
|
||||
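PyTorch and ONNX define BatchNorm momentum in opposite directions (PyTorch weights the new batch statistic, ONNX weights the retained running statistic), which is why the symbolic passes `momentum_f=1 - momentum`. A minimal numeric sketch of the equivalence (illustrative only, not part of this diff):

```python
# PyTorch: running = (1 - momentum) * running + momentum * batch_stat
# ONNX:    running = onnx_momentum * running + (1 - onnx_momentum) * batch_stat
running, batch_stat, momentum = 0.5, 2.0, 0.1
onnx_momentum = 1 - momentum

pytorch_update = (1 - momentum) * running + momentum * batch_stat
onnx_update = onnx_momentum * running + (1 - onnx_momentum) * batch_stat
print(pytorch_update, onnx_update)  # 0.65 0.65 -> same running statistic
```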
|
||||
|
||||
@_onnx_symbolic("quantized::hardswish")
|
||||
def quantized_hardswish(g: jit_utils.GraphContext, x, op_scale, op_zero_point):
|
||||
x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
|
||||
|
||||
output = hardswish(g, x)
|
||||
|
||||
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
|
||||
|
||||
|
||||
# Ported from
|
||||
# https://github.com/microsoft/onnxscript/blob/6b1b81700b4523f31d8c6d3321e5d8ef5d42b764/onnxscript/function_libs/torch_aten/ops/nn.py#L1504
|
||||
# aten_scaled_dot_product_attention
|
||||
# NOTE: Need op.Trilu
|
||||
@_onnx_symbolic("aten::scaled_dot_product_attention")
|
||||
@symbolic_helper.parse_args("v", "v", "v", "v", "f", "b", "v", "b")
|
||||
def scaled_dot_product_attention(
|
||||
g: jit_utils.GraphContext,
|
||||
query: torch._C.Value,
|
||||
key: torch._C.Value,
|
||||
value: torch._C.Value,
|
||||
attn_mask: torch._C.Value | None = None,
|
||||
dropout_p: float = 0.0,
|
||||
is_causal: bool = False,
|
||||
scale: torch._C.Value | None = None,
|
||||
enable_gqa: bool = False,
|
||||
):
|
||||
assert (not is_causal) or (is_causal and symbolic_helper._is_none(attn_mask)), (
|
||||
"is_causal and attn_mask cannot be set at the same time"
|
||||
)
|
||||
assert not enable_gqa, (
|
||||
"conversion of scaled_dot_product_attention not implemented if enable_gqa is True"
|
||||
)
|
||||
|
||||
if symbolic_helper._is_none(scale):
|
||||
scale = _attention_scale(g, query)
|
||||
|
||||
if is_causal:
|
||||
attn_mask = _causal_attention_mask(g, query, key)
|
||||
|
||||
# Swap the last two axes of key
|
||||
# NOTE: onnx-script has different logic here, because the `perm` attribute of
# Transpose requires a list of ints
|
||||
key_shape_builtin = symbolic_helper._get_tensor_rank(key)
|
||||
key_transposed_axes = list(range(key_shape_builtin))
|
||||
key_transposed_axes[-1], key_transposed_axes[-2] = (
|
||||
key_transposed_axes[-2],
|
||||
key_transposed_axes[-1],
|
||||
)
|
||||
key_transposed = g.op("Transpose", key, perm_i=key_transposed_axes)
|
||||
|
||||
# https://github.com/pytorch/pytorch/blob/12da0c70378b5be9135c6fda62a9863bce4a4818/aten/src/ATen/native/transformers/attention.cpp#L653
|
||||
# Scale q and k before the matmul for numerical stability; see https://tinyurl.com/sudb9s96 for the math
|
||||
query_scaled = g.op("Mul", query, g.op("Sqrt", scale))
|
||||
key_transposed_scaled = g.op("Mul", key_transposed, g.op("Sqrt", scale))
|
||||
mul_qk = g.op("MatMul", query_scaled, key_transposed_scaled)
|
||||
|
||||
if symbolic_helper._is_none(attn_mask):
|
||||
mul_qk_add = mul_qk
|
||||
attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1)
|
||||
elif (
|
||||
_type_utils.JitScalarType.from_value(attn_mask)
|
||||
== _type_utils.JitScalarType.BOOL
|
||||
):
|
||||
# Turn the Boolean mask to float: attn_mask.masked_fill(not attn_mask, -float('inf'))
|
||||
const_zero = g.op("Constant", value_t=torch.tensor([0.0]))
|
||||
const_neg_inf = g.op("Constant", value_t=torch.tensor([-float("inf")]))
|
||||
attn_mask = g.op("Where", attn_mask, const_zero, const_neg_inf)
|
||||
mul_qk_add = g.op("Add", mul_qk, attn_mask)
|
||||
attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1)
|
||||
# When using scaled dot product attention with a boolean mask, the softmax operation might return NaN values
|
||||
# due to the presence of -inf in an entire row (padding tokens), resulting in 0/0 (NaN) in the softmax output.
|
||||
# This is because there is no safe softmax implementation in ONNX, so we need to handle NaN values explicitly to match
|
||||
# the behavior of PyTorch with boolean masks.
|
||||
attn_weight = g.op("Where", g.op("IsNaN", attn_weight), const_zero, attn_weight)
|
||||
elif _type_utils.JitScalarType.from_value(attn_mask) in (
|
||||
_type_utils.JitScalarType.FLOAT,
|
||||
_type_utils.JitScalarType.HALF,
|
||||
_type_utils.JitScalarType.BFLOAT16,
|
||||
):
|
||||
mul_qk_add = g.op("Add", mul_qk, attn_mask)
|
||||
attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported type for attn_mask: {_type_utils.JitScalarType.from_value(attn_mask)}"
|
||||
)
|
||||
|
||||
if dropout_p != 0:
|
||||
attn_weight = g.op(
|
||||
"Dropout",
|
||||
attn_weight,
|
||||
g.op("Constant", value_t=torch.tensor(dropout_p, dtype=torch.float)),
|
||||
)
|
||||
|
||||
return g.op("MatMul", attn_weight, value)
|
||||
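The `Mul`-by-`Sqrt(scale)` trick above distributes the attention scale across `q` and the transposed `k`, so the product is still `scale * (q @ k_T)` while the intermediates stay smaller. A rough eager-mode sketch of the same math (illustrative shapes, not part of this diff):

```python
import math

import torch
import torch.nn.functional as F

q = torch.randn(2, 4, 8, 16)
k = torch.randn(2, 4, 8, 16)
v = torch.randn(2, 4, 8, 16)
scale = 1.0 / math.sqrt(q.size(-1))

# (q * sqrt(scale)) @ (k_T * sqrt(scale)) == scale * (q @ k_T)
scores = (q * math.sqrt(scale)) @ (k.transpose(-2, -1) * math.sqrt(scale))
out = torch.softmax(scores, dim=-1) @ v

print(torch.allclose(out, F.scaled_dot_product_attention(q, k, v), atol=1e-4))  # True (up to float tolerance)
```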
|
||||
|
||||
def _attention_scale(
|
||||
g: jit_utils.GraphContext, query: torch._C.Value
|
||||
) -> torch._C.Value:
|
||||
"""Calculate the scale factor for the attention result.
|
||||
|
||||
Args:
|
||||
query: Tensor of shape [..., L, E]
|
||||
|
||||
Returns:
|
||||
Scalar scale factor := 1 / math.sqrt(query.size(-1))
|
||||
"""
|
||||
query_shape = g.op("Shape", query)
|
||||
query_shape_last = g.op(
|
||||
"Slice",
|
||||
query_shape,
|
||||
g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)),
|
||||
g.op(
|
||||
"Constant", value_t=torch.tensor([_constants.INT64_MAX], dtype=torch.int64)
|
||||
),
|
||||
)
|
||||
embedding_size = g.op(
|
||||
"Cast",
|
||||
query_shape_last,
|
||||
to_i=_type_utils.JitScalarType.from_value(query).onnx_type(),
|
||||
)
|
||||
const_one = g.op("Constant", value_t=torch.tensor([1.0], dtype=torch.float))
|
||||
scale = g.op("Div", const_one, g.op("Sqrt", embedding_size))
|
||||
# Add a Cast to convert the scale back to original type
|
||||
scale = g.op(
|
||||
"Cast",
|
||||
scale,
|
||||
to_i=_type_utils.JitScalarType.from_value(query).onnx_type(),
|
||||
)
|
||||
return scale
|
||||
|
||||
|
||||
def _causal_attention_mask(
|
||||
g: jit_utils.GraphContext, query: torch._C.Value, key: torch._C.Value
|
||||
) -> torch._C.Value:
|
||||
"""Create a causal mask for the given query and key tensors.
|
||||
|
||||
Equivalent to::
|
||||
mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
|
||||
attn_mask = torch.zeros(L, S, dtype=torch.float)
|
||||
attn_mask = attn_mask.masked_fill(not mask, -float("inf"))
|
||||
|
||||
Args:
|
||||
query: Tensor of shape [..., L, E]
|
||||
key: Tensor of shape [..., S, E]
|
||||
|
||||
Returns:
|
||||
Tensor of shape [L, S]
|
||||
"""
|
||||
|
||||
query_shape = g.op("Shape", query)
|
||||
key_shape = g.op("Shape", key)
|
||||
|
||||
last_idx = g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))
|
||||
second_last_idx = g.op("Constant", value_t=torch.tensor([-2], dtype=torch.int64))
|
||||
target_length = g.op("Slice", query_shape, second_last_idx, last_idx)
|
||||
source_length = g.op("Slice", key_shape, second_last_idx, last_idx)
|
||||
# attn_mask = torch.ones(L, S) := {
|
||||
size = g.op("Concat", target_length, source_length, axis_i=0)
|
||||
const_one = g.op("Constant", value_t=torch.tensor([1.0]))
|
||||
attn_mask = g.op("Expand", const_one, size)
|
||||
# }
|
||||
attn_mask = g.op("Trilu", attn_mask, upper_i=0)
|
||||
# The causal mask has 0s in the lower triangle and -inf in the upper triangle.
|
||||
const_zero = g.op("Constant", value_t=torch.tensor([0.0]))
|
||||
const_neg_inf = g.op("Constant", value_t=torch.tensor([-float("inf")]))
|
||||
attn_mask = g.op(
|
||||
"Where", g.op("Equal", attn_mask, const_zero), const_neg_inf, const_zero
|
||||
)
|
||||
return attn_mask
|
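For reference, the `Trilu`/`Where` sequence above reproduces the mask from the docstring; written out as runnable PyTorch (note the elementwise `~` rather than `not`), for L = S = 4:

```python
import torch

L, S = 4, 4
mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
attn_mask = torch.zeros(L, S).masked_fill(~mask, -float("inf"))
print(attn_mask)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])
```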
@ -0,0 +1,84 @@
|
||||
# mypy: allow-untyped-defs
|
||||
"""This file exports ONNX ops for opset 15.
|
||||
|
||||
Note [ONNX operators that are added/updated in opset 15]
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
https://github.com/onnx/onnx/blob/master/docs/Changelog.md#version-15-of-the-default-onnx-operator-set
|
||||
New operators:
|
||||
Bernoulli
|
||||
CastLike
|
||||
Optional
|
||||
OptionalGetElement
|
||||
OptionalHasElement
|
||||
|
||||
Updated operators:
|
||||
BatchNormalization https://github.com/onnx/onnx/pull/3545
|
||||
Backwards compatible
|
||||
TODO: test coverage for mixed types inputs.
|
||||
Pow https://github.com/onnx/onnx/pull/3412
|
||||
Backwards compatible
|
||||
TODO: bfloat16 support.
|
||||
Shape https://github.com/onnx/onnx/pull/3580
|
||||
Backwards compatible
|
||||
TODO: optional start/end attribute.
|
||||
"""
|
||||
|
||||
# EDITING THIS FILE? READ THIS FIRST!
|
||||
# see Note [Edit Symbolic Files] in README.md
|
||||
|
||||
import functools
|
||||
|
||||
import torch
|
||||
from torch import _C
|
||||
from torch.onnx._internal.torchscript_exporter import (
|
||||
jit_utils,
|
||||
registration,
|
||||
symbolic_helper,
|
||||
symbolic_opset9 as opset9,
|
||||
)
|
||||
|
||||
|
||||
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=15)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::__is_")
|
||||
def aten__is_(g: jit_utils.GraphContext, self, other):
|
||||
if symbolic_helper._is_none(other):
|
||||
if isinstance(self.type(), _C.OptionalType):
|
||||
none = g.op("OptionalHasElement", self)
|
||||
return g.op("Not", none)
|
||||
else:
|
||||
return g.op("Constant", value_t=torch.BoolTensor([0]))
|
||||
return opset9.eq(g, self, other)
|
||||
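A sketch of the TorchScript pattern that exercises this symbolic, an `is None` check on an `Optional[Tensor]` input (the module and names are illustrative, not part of this diff):

```python
from typing import Optional

import torch


class AddBias(torch.nn.Module):
    def forward(self, x: torch.Tensor, bias: Optional[torch.Tensor] = None):
        if bias is None:  # identity comparison on an Optional[Tensor]
            return x
        return x + bias


scripted = torch.jit.script(AddBias())
# With opset >= 15 this check can map to OptionalHasElement + Not instead of a
# constant, so it stays correct for optional graph inputs.
print(scripted.graph)
```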
|
||||
|
||||
@_onnx_symbolic("aten::__isnot_")
|
||||
@opset9.wrap_logical_op_with_negation # type: ignore[has-type]
|
||||
def aten__isnot_(g: jit_utils.GraphContext, self, other):
|
||||
return aten__is_(g, self, other)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::bernoulli")
|
||||
def bernoulli(g: jit_utils.GraphContext, input, p=None, generator=None, out=None):
|
||||
if out is not None and not symbolic_helper._is_none(out):
|
||||
symbolic_helper._unimplemented(
|
||||
"Bernoulli", "out parameter is not supported for bernoulli", input
|
||||
)
|
||||
if generator is not None and not symbolic_helper._is_none(generator):
|
||||
symbolic_helper._unimplemented(
|
||||
"Bernoulli", "generator is not supported for bernoulli", input
|
||||
)
|
||||
if p is None or symbolic_helper._is_none(p):
|
||||
return g.op("Bernoulli", input)
|
||||
return opset9.bernoulli(g, input, p, generator, out)
|
||||
|
||||
|
||||
@_onnx_symbolic("prim::unchecked_cast")
|
||||
def prim_unchecked_cast(g: jit_utils.GraphContext, self):
|
||||
# exists to refine the type of the Value
|
||||
# if x is Optional[Tensor], unchecked_cast will cast
|
||||
# x to Tensor, so the rest of the graph knows that x is a Tensor.
|
||||
if isinstance(self.type(), _C.OptionalType):
|
||||
return g.op("OptionalGetElement", self)
|
||||
|
||||
return self
|
torch/onnx/_internal/torchscript_exporter/symbolic_opset16.py
@ -0,0 +1,191 @@
|
||||
# mypy: allow-untyped-defs
|
||||
"""This file exports ONNX ops for opset 16.
|
||||
|
||||
Note [ONNX Operators that are added/updated in opset 16]
|
||||
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-16-of-the-default-onnx-operator-set
|
||||
New operators:
|
||||
GridSample https://github.com/onnx/onnx/pull/3557
|
||||
|
||||
Updated operators:
|
||||
Identity
|
||||
If
|
||||
LeakyRelu
|
||||
Loop
|
||||
PRelu
|
||||
RoiAlign
|
||||
Scan
|
||||
ScatterElements
|
||||
ScatterND
|
||||
Where
|
||||
GreaterOrEqual
|
||||
LessOrEqual
|
||||
"""
|
||||
|
||||
# EDITING THIS FILE? READ THIS FIRST!
|
||||
# see Note [Edit Symbolic Files] in README.md
|
||||
|
||||
import functools
|
||||
|
||||
import torch
|
||||
from torch.nn.functional import (
|
||||
GRID_SAMPLE_INTERPOLATION_MODES,
|
||||
GRID_SAMPLE_PADDING_MODES,
|
||||
)
|
||||
from torch.onnx import errors
|
||||
from torch.onnx._internal.torchscript_exporter import (
|
||||
_type_utils,
|
||||
jit_utils,
|
||||
registration,
|
||||
symbolic_helper,
|
||||
utils,
|
||||
)
|
||||
|
||||
|
||||
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=16)
|
||||
|
||||
|
||||
# note (mkozuki): Why `grid_sampler` instead of `grid_sample`?
|
||||
# Because `torch.nn.functional.grid_sample` calls `torch.grid_sampler`.
|
||||
@_onnx_symbolic("aten::grid_sampler")
|
||||
@symbolic_helper.parse_args("v", "v", "i", "i", "b")
|
||||
def grid_sampler(
|
||||
g: jit_utils.GraphContext,
|
||||
input,
|
||||
grid,
|
||||
mode_enum,
|
||||
padding_mode_enum,
|
||||
align_corners,
|
||||
):
|
||||
# Check the input and grid tensor rank beforehand.
|
||||
if symbolic_helper._get_tensor_rank(input) == 5:
|
||||
return symbolic_helper._onnx_unsupported("GridSample with 5D volumetric input")
|
||||
mode_s = {v: k for k, v in GRID_SAMPLE_INTERPOLATION_MODES.items()}[mode_enum] # type: ignore[call-arg]
|
||||
padding_mode_s = {v: k for k, v in GRID_SAMPLE_PADDING_MODES.items()}[ # type: ignore[call-arg]
|
||||
padding_mode_enum
|
||||
]
|
||||
return g.op(
|
||||
"GridSample",
|
||||
input,
|
||||
grid,
|
||||
align_corners_i=int(align_corners),
|
||||
mode_s=mode_s,
|
||||
padding_mode_s=padding_mode_s,
|
||||
)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::scatter_add")
|
||||
@symbolic_helper.parse_args("v", "i", "v", "v")
|
||||
def scatter_add(g: jit_utils.GraphContext, self, dim, index, src):
|
||||
src_type = _type_utils.JitScalarType.from_value(
|
||||
src, _type_utils.JitScalarType.UNDEFINED
|
||||
)
|
||||
src_sizes = symbolic_helper._get_tensor_sizes(src)
|
||||
index_sizes = symbolic_helper._get_tensor_sizes(index)
|
||||
|
||||
if len(src_sizes) != len(index_sizes):
|
||||
return symbolic_helper._unimplemented(
|
||||
"scatter_add",
|
||||
f"`index` ({index_sizes}) should have the same dimensionality as `src` ({src_sizes})",
|
||||
)
|
||||
|
||||
# PyTorch only allows index shape <= src shape, so we can only consider
|
||||
# taking index as subset size to src, like PyTorch does. When sizes for src
|
||||
# and index are not matched or there are dynamic axes, we take index shape to
|
||||
# slice src to accommodate.
|
||||
if src_sizes != index_sizes or None in index_sizes:
|
||||
adjusted_shape = g.op("Shape", index)
|
||||
starts = g.op("Constant", value_t=torch.tensor([0] * len(index_sizes)))
|
||||
src = g.op("Slice", src, starts, adjusted_shape)
|
||||
|
||||
src = symbolic_helper._maybe_get_scalar(src)
|
||||
if symbolic_helper._is_value(src):
|
||||
return g.op("ScatterElements", self, index, src, axis_i=dim, reduction_s="add")
|
||||
else:
|
||||
# Check if scalar "src" has same type as self (PyTorch allows different
|
||||
# type for scalar src (but not when src is tensor)). If not, insert Cast node.
|
||||
if _type_utils.JitScalarType.from_value(self) != src_type:
|
||||
src = g.op(
|
||||
"Cast",
|
||||
src,
|
||||
to_i=_type_utils.JitScalarType.from_value(self).onnx_type(),
|
||||
)
|
||||
|
||||
return g.op(
|
||||
"ScatterElements",
|
||||
self,
|
||||
index,
|
||||
src,
|
||||
axis_i=dim,
|
||||
reduction_s="add",
|
||||
)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::scatter_reduce")
|
||||
@symbolic_helper.parse_args("v", "i", "v", "v", "s", "b")
|
||||
def scatter_reduce(
|
||||
g: jit_utils.GraphContext,
|
||||
self: torch._C.Value,
|
||||
dim: int,
|
||||
index: torch._C.Value,
|
||||
src: torch._C.Value,
|
||||
reduce: str,
|
||||
include_self: bool,
|
||||
):
|
||||
if reduce == "mean":
|
||||
raise errors.OnnxExporterError(
|
||||
"ONNX does not support mean reduction for scatter_reduce"
|
||||
)
|
||||
if not include_self:
|
||||
raise errors.OnnxExporterError(
|
||||
"ONNX does not support include_self=False for scatter_reduce"
|
||||
)
|
||||
|
||||
reduce_mode = {  # convert torch string name to onnx string name
"mean": "none",  # 'mean' is not supported in the ONNX 1.14 definition
|
||||
"sum": "add",
|
||||
"prod": "mul",
|
||||
"amin": "min",
|
||||
"amax": "max",
|
||||
}
|
||||
onnx_reduce = reduce_mode[reduce]
|
||||
|
||||
self_rank = g.op("Size", g.op("Shape", self))
|
||||
|
||||
# if self_rank == 0: # assert (index_rank == 0 and rank_src == 0)
|
||||
self_rank_is_zero = g.op(
|
||||
"Equal", self_rank, g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64))
|
||||
)
|
||||
if_op, (if_context, else_context), _ = jit_utils.add_op_with_blocks(
|
||||
g, "If", self_rank_is_zero, n_blocks=2, outputs=3
|
||||
)
|
||||
neg_1 = if_context.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))
|
||||
|
||||
self_reshape = if_context.op("Reshape", self, neg_1)
|
||||
utils._add_output_to_block(if_context.block, self_reshape)
|
||||
index_reshape = if_context.op("Reshape", index, neg_1)
|
||||
utils._add_output_to_block(if_context.block, index_reshape)
|
||||
src_reshape = if_context.op("Reshape", src, neg_1)
|
||||
utils._add_output_to_block(if_context.block, src_reshape)
|
||||
|
||||
self_identity = else_context.op("Identity", self)
|
||||
utils._add_output_to_block(else_context.block, self_identity)
|
||||
index_identitye = else_context.op("Identity", index)
|
||||
utils._add_output_to_block(else_context.block, index_identitye)
|
||||
src_identity = else_context.op("Identity", src)
|
||||
utils._add_output_to_block(else_context.block, src_identity)
|
||||
|
||||
result = g.op("ScatterElements", *if_op, axis_i=dim, reduction_s=onnx_reduce)
|
||||
|
||||
# if self_rank == 0:
|
||||
if_op, (if_context, else_context), _ = jit_utils.add_op_with_blocks(
|
||||
g, "If", self_rank_is_zero, n_blocks=2, outputs=1
|
||||
)
|
||||
result_squeezed = if_context.op("Squeeze", result)
|
||||
utils._add_output_to_block(if_context.block, result_squeezed)
|
||||
result_identity = else_context.op("Identity", result)
|
||||
utils._add_output_to_block(else_context.block, result_identity)
|
||||
result_final = if_op.node().output()
|
||||
|
||||
return result_final
|
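For reference, the eager semantics being mapped here: PyTorch's `reduce="sum"` corresponds to ONNX `ScatterElements` with `reduction="add"` (values below are illustrative):

```python
import torch

base = torch.zeros(5)
index = torch.tensor([0, 1, 1, 3])
src = torch.tensor([1.0, 2.0, 3.0, 4.0])

# Positions 1 accumulates 2 + 3; untouched positions keep their original value.
out = torch.scatter_reduce(base, 0, index, src, reduce="sum", include_self=True)
print(out)  # tensor([1., 5., 0., 4., 0.])
```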
torch/onnx/_internal/torchscript_exporter/symbolic_opset17.py
@ -0,0 +1,244 @@
|
||||
# mypy: allow-untyped-defs
|
||||
# mypy: disable-error-code=arg-type
|
||||
"""This file exports ONNX ops for opset 17.
|
||||
|
||||
Note [ONNX Operators that are added/updated in opset 17]
|
||||
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-17-of-the-default-onnx-operator-set
|
||||
New operators:
|
||||
BlackmanWindow
|
||||
DFT
|
||||
HammingWindow
|
||||
HannWindow
|
||||
LayerNormalization
|
||||
MelWeightMatrix
|
||||
STFT
|
||||
SequenceMap
|
||||
"""
|
||||
|
||||
import functools
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import _C
|
||||
from torch.onnx import errors
|
||||
from torch.onnx._internal.torchscript_exporter import (
|
||||
_type_utils,
|
||||
jit_utils,
|
||||
registration,
|
||||
symbolic_helper,
|
||||
)
|
||||
|
||||
|
||||
# EDITING THIS FILE? READ THIS FIRST!
|
||||
# see Note [Edit Symbolic Files] in README.md
|
||||
|
||||
__all__ = ["layer_norm", "stft", "quantized_layer_norm"]
|
||||
|
||||
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=17)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::layer_norm")
|
||||
@symbolic_helper.parse_args("v", "is", "v", "v", "f", "none")
|
||||
def layer_norm(
|
||||
g: jit_utils.GraphContext,
|
||||
input: _C.Value,
|
||||
normalized_shape: Sequence[int],
|
||||
weight: _C.Value,
|
||||
bias: _C.Value,
|
||||
eps: float,
|
||||
cudnn_enable: bool,
|
||||
):
|
||||
# normalized_shape: input shape from an expected input of size
|
||||
# axis: The first normalization dimension.
|
||||
# layer_norm normalizes on the last D dimensions,
|
||||
# where D is the size of normalized_shape
|
||||
axis = -len(normalized_shape)
|
||||
scalar_type = _type_utils.JitScalarType.from_value(
|
||||
input, _type_utils.JitScalarType.FLOAT
|
||||
)
|
||||
dtype = scalar_type.dtype()
|
||||
if symbolic_helper._is_none(weight):
|
||||
weight_value = torch.ones(normalized_shape, dtype=dtype)
|
||||
weight = g.op("Constant", value_t=weight_value)
|
||||
if symbolic_helper._is_none(bias):
|
||||
bias_value = torch.zeros(normalized_shape, dtype=dtype)
|
||||
bias = g.op("Constant", value_t=bias_value)
|
||||
return g.op(
|
||||
"LayerNormalization",
|
||||
input,
|
||||
weight,
|
||||
bias,
|
||||
epsilon_f=eps,
|
||||
axis_i=axis,
|
||||
)
|
||||
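The `axis = -len(normalized_shape)` computed above matches the eager behaviour, where the normalized dimensions are always the trailing ones; a small illustrative check (not part of this diff):

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 5, 8)
normalized_shape = [8]  # normalize over the last dim -> ONNX LayerNormalization axis = -1
out = F.layer_norm(x, normalized_shape, weight=torch.ones(8), bias=torch.zeros(8), eps=1e-5)
print(out.shape)                              # torch.Size([2, 5, 8])
print(out.mean(dim=-1).abs().max() < 1e-5)    # each normalized slice has ~zero mean
```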
|
||||
|
||||
@_onnx_symbolic("quantized::layer_norm")
|
||||
def quantized_layer_norm(
|
||||
g: jit_utils.GraphContext,
|
||||
x,
|
||||
normalized_shape,
|
||||
weight,
|
||||
bias,
|
||||
eps,
|
||||
op_scale,
|
||||
op_zero_point,
|
||||
):
|
||||
x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
|
||||
|
||||
output = layer_norm(g, x, normalized_shape, weight, bias, eps, False)
|
||||
|
||||
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
|
||||
|
||||
|
||||
def _compute_edge_sizes(n_fft, window_size):
|
||||
"""Helper function to compute the sizes of the edges (left and right)
|
||||
of a given window centered within an FFT size."""
|
||||
left = (n_fft - window_size) // 2
|
||||
right = n_fft - left - window_size
|
||||
return left, right
|
||||
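A quick worked example of the edge computation above, centering a 5-sample window inside an 8-point FFT frame:

```python
n_fft, window_size = 8, 5
left = (n_fft - window_size) // 2
right = n_fft - left - window_size
print(left, right)  # 1 2 -> padded window: [0, w0, w1, w2, w3, w4, 0, 0]
```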
|
||||
|
||||
@_onnx_symbolic("aten::stft")
|
||||
@symbolic_helper.parse_args("v", "i", "i", "i", "v", "b", "b", "b", "b")
|
||||
def stft(
|
||||
g: jit_utils.GraphContext,
|
||||
input: _C.Value,
|
||||
n_fft: int,
|
||||
hop_length: Optional[int] = None,
|
||||
win_length: Optional[int] = None,
|
||||
window: Optional[_C.Value] = None,
|
||||
normalized: bool = False,
|
||||
onesided: Optional[bool] = True,
|
||||
return_complex: Optional[bool] = False,
|
||||
align_to_window: Optional[bool] = None,
|
||||
) -> _C.Value:
|
||||
"""Associates `torch.stft` with the `STFT` ONNX operator.
|
||||
Note that torch.stft calls _VF.stft, without centering or padding options.
|
||||
Hence, this function does not contain these two arguments.
|
||||
See torch.stft source code for more info.
|
||||
|
||||
Args:
|
||||
g: Graph to write the ONNX representation into
|
||||
input: Input tensor for the transformation
|
||||
n_fft: FFT size
|
||||
hop_length: Size of the hop. Defaults to `n_fft // 4`
|
||||
win_length: Size of the analysis window. Defaults to `n_fft`
|
||||
window: Analysis window. Defaults to a window of all ones
|
||||
normalized: Whether to return a normalized STFT
|
||||
onesided: Whether to return only half (+1) of the results, given the
|
||||
symmetry of the STFT
|
||||
return_complex: Whether to return the complex value (Note: Must be
|
||||
`False` or `None`)
|
||||
|
||||
Returns:
|
||||
op: Operator for torch.stft associated with STFT (ONNX)
|
||||
"""
|
||||
# Checks
|
||||
if return_complex:
|
||||
raise errors.SymbolicValueError(
|
||||
msg="STFT does not currently support complex types", value=input
|
||||
)
|
||||
|
||||
if align_to_window is not None:
|
||||
raise errors.SymbolicValueError(
|
||||
msg="STFT does not currently support the align_to_window option",
|
||||
value=input,
|
||||
) # TODO(#145944): add compatibility with align_to_window option.
|
||||
|
||||
# Get STFT sizes
|
||||
frame_step_value = hop_length if hop_length is not None else n_fft // 4
|
||||
frame_step_const = g.op(
|
||||
"Constant", value_t=torch.tensor(frame_step_value, dtype=torch.int64)
|
||||
)
|
||||
frame_length_const = g.op(
|
||||
"Constant", value_t=torch.tensor(n_fft, dtype=torch.int64)
|
||||
)
|
||||
|
||||
# Pre-process input if needed
|
||||
signal = input
|
||||
signal_rank = symbolic_helper._get_tensor_rank(signal)
|
||||
if signal_rank == 1:
|
||||
# Add batch dimension
|
||||
signal = g.op(
|
||||
"Unsqueeze",
|
||||
signal,
|
||||
g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)),
|
||||
)
|
||||
elif signal_rank is None or signal_rank > 2:
|
||||
raise errors.SymbolicValueError(
|
||||
msg="STFT can only take inputs of 1 [signal] or 2 [batch, signal] dimensions. "
|
||||
f"Current rank of signal is {signal_rank}, please reduce it.",
|
||||
value=input,
|
||||
)
|
||||
|
||||
# Get window and make sure it's the same size as `win_length` or `n_fft`
|
||||
n_win = symbolic_helper._get_tensor_dim_size(window, dim=0)
|
||||
if n_win is not None:
|
||||
win_length_default = win_length if win_length else n_fft
|
||||
assert n_win == win_length_default, (
"Analysis window size must equal `win_length` or `n_fft`. "
f"Please set `win_length` or `n_fft` to match `window` size ({n_win})"
)
|
||||
|
||||
# Center window around zeros if needed (required by ONNX's STFT)
|
||||
if n_win < n_fft:
|
||||
left, right = _compute_edge_sizes(n_fft, n_win)
|
||||
left_win = g.op("Constant", value_t=torch.zeros(left))
|
||||
right_win = g.op("Constant", value_t=torch.zeros(right))
|
||||
window = g.op("Concat", left_win, window, right_win, axis_i=0)
|
||||
|
||||
# Create window, if needed
|
||||
if symbolic_helper._is_none(window):
|
||||
if win_length:
|
||||
if win_length > n_fft:
|
||||
raise errors.SymbolicValueError(
|
||||
msg="The analysis window can't be longer than the size of the FFT. "
|
||||
f"Please set `win_length` ({win_length}) to `n_fft` ({n_fft}) or less.",
|
||||
value=input,
|
||||
)
|
||||
|
||||
# Center window, if needed
|
||||
left, right = _compute_edge_sizes(n_fft, win_length)
|
||||
torch_window = torch.hstack(
|
||||
(torch.zeros(left), torch.ones(win_length), torch.zeros(right))
|
||||
)
|
||||
else:
|
||||
# Rectangle window
|
||||
torch_window = torch.ones(n_fft)
|
||||
assert torch_window.shape[0] == n_fft
|
||||
window = g.op("Constant", value_t=torch_window)
|
||||
window = g.op(
|
||||
"Cast", window, to_i=_type_utils.JitScalarType.from_value(signal).onnx_type()
|
||||
)
|
||||
|
||||
# Run STFT
|
||||
result = g.op(
|
||||
"STFT",
|
||||
signal,
|
||||
frame_step_const,
|
||||
window,
|
||||
frame_length_const,
|
||||
onesided_i=1 if onesided is None or onesided else 0,
|
||||
)
|
||||
|
||||
# Transpose to mimic torch.stft's behavior
|
||||
result = g.op("Transpose", result, perm_i=[0, 2, 1, 3])
|
||||
|
||||
# Remove batch dimension, if needed
|
||||
if signal_rank == 1:
|
||||
result = g.op(
|
||||
"Squeeze",
|
||||
result,
|
||||
g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)),
|
||||
)
|
||||
|
||||
# Normalize, if needed
|
||||
if normalized:
|
||||
sqrt_nfft = torch.sqrt(torch.tensor(n_fft, dtype=signal.type().dtype()))
|
||||
result = g.op("Div", result, g.op("Constant", value_t=sqrt_nfft))
|
||||
|
||||
return result
|
torch/onnx/_internal/torchscript_exporter/symbolic_opset18.py
@ -0,0 +1,270 @@
|
||||
# mypy: allow-untyped-defs
|
||||
"""This file exports ONNX ops for opset 18.
|
||||
|
||||
Note [ONNX Operators that are added/updated in opset 18]
|
||||
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-18-of-the-default-onnx-operator-set
|
||||
New operators:
|
||||
BitwiseAnd
|
||||
CenterCropPad
|
||||
Col2Im
|
||||
Mish
|
||||
OptionalGetElement
|
||||
OptionalHasElement
|
||||
Pad
|
||||
Resize
|
||||
ScatterElements
|
||||
ScatterND
|
||||
Split
|
||||
"""
|
||||
|
||||
import functools
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import _C
|
||||
from torch.onnx._internal.torchscript_exporter import (
|
||||
_type_utils,
|
||||
jit_utils,
|
||||
registration,
|
||||
symbolic_helper,
|
||||
symbolic_opset9 as opset9,
|
||||
)
|
||||
|
||||
|
||||
# EDITING THIS FILE? READ THIS FIRST!
|
||||
# see Note [Edit Symbolic Files] in symbolic_helper.py
|
||||
|
||||
__all__ = [
|
||||
"col2im",
|
||||
]
|
||||
|
||||
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=18)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::__and_")
|
||||
@_onnx_symbolic("aten::bitwise_and")
|
||||
def __and_(g: jit_utils.GraphContext, self, other):
|
||||
# do type promotion (scalars don't seem to apply)
|
||||
args = [self, other]
|
||||
# type promotion doesn't happen with torch.bitwise_and(tensor, scalar)
|
||||
prom_args = [arg for arg in args if symbolic_helper._get_tensor_rank(arg)]
|
||||
if len(prom_args) == 0:
|
||||
prom_args = args
|
||||
promotion_jit_type = symbolic_helper._type_promote_from_values(*prom_args)
|
||||
self = symbolic_helper._maybe_cast_to_type(g, self, promotion_jit_type)
|
||||
other = symbolic_helper._maybe_cast_to_type(g, other, promotion_jit_type)
|
||||
if promotion_jit_type == _type_utils.JitScalarType.BOOL:
|
||||
return g.op("And", self, other)
|
||||
return g.op("BitwiseAnd", self, other)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::col2im")
|
||||
@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is")
|
||||
def col2im(
|
||||
g,
|
||||
input: _C.Value,
|
||||
output_size: _C.Value,
|
||||
kernel_size: _C.Value,
|
||||
dilation: Sequence[int],
|
||||
padding: Sequence[int],
|
||||
stride: Sequence[int],
|
||||
):
|
||||
# convert [i0, i1, ..., in] into [i0, i0, i1, i1, ..., in, in]
|
||||
adjusted_padding: list[int] = []
|
||||
for pad in padding:
|
||||
adjusted_padding.extend(pad for _ in range(2))
|
||||
|
||||
num_dimensional_axis = symbolic_helper._get_tensor_sizes(output_size)[0]
|
||||
if not adjusted_padding:
|
||||
adjusted_padding = [0, 0] * num_dimensional_axis
|
||||
|
||||
if not dilation:
|
||||
dilation = [1] * num_dimensional_axis
|
||||
|
||||
if not stride:
|
||||
stride = [1] * num_dimensional_axis
|
||||
|
||||
return g.op(
|
||||
"Col2Im",
|
||||
input,
|
||||
output_size,
|
||||
kernel_size,
|
||||
dilations_i=dilation,
|
||||
pads_i=adjusted_padding,
|
||||
strides_i=stride,
|
||||
)
|
||||
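In eager mode `aten::col2im` is roughly what `torch.nn.functional.fold` lowers to, so a quick way to see the shapes involved (sizes are illustrative):

```python
import torch
import torch.nn.functional as F

# 4 sliding (2, 2) blocks over 3 channels, folded back into a 3x3 image.
cols = torch.randn(1, 3 * 2 * 2, 4)  # [N, C * prod(kernel_size), n_blocks]
out = F.fold(cols, output_size=(3, 3), kernel_size=(2, 2))
print(out.shape)  # torch.Size([1, 3, 3, 3])
```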
|
||||
|
||||
@_onnx_symbolic(
|
||||
"aten::mean", decorate=[symbolic_helper._apply_params("ReduceMean", "mean")]
|
||||
)
|
||||
@_onnx_symbolic(
|
||||
"aten::prod",
|
||||
decorate=[
|
||||
symbolic_helper._apply_params(
|
||||
"ReduceProd", "prod", allow_multi_dim_support=False
|
||||
)
|
||||
],
|
||||
)
|
||||
def _reduce_with_dtype(onnx_op: str, name: str, allow_multi_dim_support: bool = True):
|
||||
return symbolic_helper._reduce_with_dtype_helper(
|
||||
onnx_op, name, allow_multi_dim_support
|
||||
)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::native_layer_norm")
|
||||
@symbolic_helper.quantized_args(True, False, False, False)
|
||||
@symbolic_helper.parse_args("v", "is", "v", "v", "f")
|
||||
def _native_layer_norm(
|
||||
g: jit_utils.GraphContext,
|
||||
input: _C.Value,
|
||||
normalized_shape: Sequence[int],
|
||||
weight: _C.Value,
|
||||
bias: _C.Value,
|
||||
eps: float,
|
||||
) -> tuple[_C.Value, _C.Value, _C.Value]:
|
||||
return opset9.native_layer_norm(g, input, normalized_shape, weight, bias, eps)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::glu")
|
||||
@symbolic_helper.parse_args("v", "i")
|
||||
def _glu(g: jit_utils.GraphContext, input, dim):
|
||||
dim_size = symbolic_helper._get_tensor_dim_size(input, dim)
|
||||
if dim_size is not None:
|
||||
assert dim_size % 2 == 0
|
||||
|
||||
first, second = g.op("Split", input, axis_i=dim, num_outputs_i=2, outputs=2)
|
||||
return g.op("Mul", first, g.op("Sigmoid", second))
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::max")
|
||||
# torch.max (same for torch.min) actually has two interfaces smashed together:
|
||||
# torch.max(x, dim, keepdim) and torch.max(x, y)
|
||||
# TODO(justinchuby): Support multiple quantized args in output
|
||||
def max(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
|
||||
return symbolic_helper._max_helper(g, self, dim_or_y, keepdim)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::maximum")
|
||||
@symbolic_helper.quantized_args(True, True)
|
||||
def maximum(g: jit_utils.GraphContext, input, other):
|
||||
return max(g, input, dim_or_y=other)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::min")
|
||||
# TODO(justinchuby): Support multiple quantized args in output
|
||||
def min(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
|
||||
return symbolic_helper._min_helper(g, self, dim_or_y, keepdim)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::minimum")
|
||||
@symbolic_helper.quantized_args(True, True)
|
||||
def minimum(g: jit_utils.GraphContext, input, other):
|
||||
return min(g, input, dim_or_y=other)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::amax")
|
||||
@symbolic_helper.quantized_args(True)
|
||||
@symbolic_helper.parse_args("v", "is", "i")
|
||||
def amax(g: jit_utils.GraphContext, self, dim, keepdim):
|
||||
axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long))
|
||||
return g.op("ReduceMax", self, axes, keepdims_i=keepdim)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::amin")
|
||||
@symbolic_helper.quantized_args(True)
|
||||
@symbolic_helper.parse_args("v", "is", "i")
|
||||
def amin(g: jit_utils.GraphContext, self, dim, keepdim):
|
||||
axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long))
|
||||
return g.op("ReduceMin", self, axes, keepdims_i=keepdim)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::aminmax")
|
||||
@symbolic_helper.quantized_args(True)
|
||||
@symbolic_helper.parse_args("v", "v", "i")
|
||||
def aminmax(g: jit_utils.GraphContext, self, dim, keepdim):
|
||||
if not symbolic_helper._is_none(dim):
|
||||
dim = symbolic_helper._get_const(dim, "i", "dim")
|
||||
axes = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long))
|
||||
return g.op("ReduceMin", self, axes, keepdims_i=keepdim), g.op(
|
||||
"ReduceMax", self, axes, keepdims_i=keepdim
|
||||
)
|
||||
else:
|
||||
return g.op("ReduceMin", self, keepdims_i=keepdim), g.op(
|
||||
"ReduceMax", self, keepdims_i=keepdim
|
||||
)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::var_mean")
|
||||
def _var_mean(g: jit_utils.GraphContext, input, *args):
|
||||
if len(args) == 1:
|
||||
return symbolic_helper._var_mean_helper(g, input, None, args[0], None)
|
||||
else:
|
||||
return symbolic_helper._var_mean_helper(g, input, *args)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::logsumexp")
|
||||
@symbolic_helper.parse_args("v", "is", "i")
|
||||
def _logsumexp(g: jit_utils.GraphContext, input, dim, keepdim):
|
||||
if dim is None:
|
||||
return g.op("ReduceLogSumExp", input, keepdims_i=0)
|
||||
else:
|
||||
axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long))
|
||||
return g.op("ReduceLogSumExp", input, axes, keepdims_i=keepdim)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::linalg_matrix_norm")
|
||||
@symbolic_helper.parse_args("v", "v", "is", "b", "v")
|
||||
def _linalg_matrix_norm(
|
||||
g: jit_utils.GraphContext,
|
||||
self: torch._C.Value,
|
||||
ord: torch._C.Value,
|
||||
dim: list[int],
|
||||
keepdim: bool,
|
||||
dtype: torch._C.Value,
|
||||
):
|
||||
return opset9.linalg_matrix_norm(g, self, ord, dim, keepdim, dtype)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::embedding_bag")
|
||||
@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i")
|
||||
def embedding_bag(
|
||||
g: jit_utils.GraphContext,
|
||||
embedding_matrix,
|
||||
indices,
|
||||
offsets,
|
||||
scale_grad_by_freq,
|
||||
mode,
|
||||
sparse,
|
||||
per_sample_weights,
|
||||
include_last_offset,
|
||||
padding_idx,
|
||||
):
|
||||
return symbolic_helper._embedding_bag_helper(
|
||||
g,
|
||||
embedding_matrix,
|
||||
indices,
|
||||
offsets,
|
||||
scale_grad_by_freq,
|
||||
mode,
|
||||
sparse,
|
||||
per_sample_weights,
|
||||
include_last_offset,
|
||||
padding_idx,
|
||||
)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::linalg_vector_norm")
|
||||
@symbolic_helper.parse_args("v", "f", "is", "b", "v")
|
||||
def linalg_vector_norm(
|
||||
g: jit_utils.GraphContext,
|
||||
self: torch._C.Value,
|
||||
ord: float,
|
||||
dim: Optional[Sequence[int]],
|
||||
keepdim: bool,
|
||||
dtype: torch._C.Value,
|
||||
):
|
||||
return symbolic_helper._linalg_vector_norm_helper(g, self, ord, dim, keepdim, dtype)
|
@ -0,0 +1,31 @@
|
||||
"""This file exports ONNX ops for opset 19.
|
||||
|
||||
Note [ONNX Operators that are added/updated in opset 19]
|
||||
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-19-of-the-default-onnx-operator-set
|
||||
New operators:
|
||||
AveragePool
|
||||
Cast
|
||||
CastLike
|
||||
Constant
|
||||
DeformConv
|
||||
DequantizeLinear
|
||||
Equal
|
||||
Identity
|
||||
If
|
||||
Loop
|
||||
Pad
|
||||
QuantizeLinear
|
||||
Reshape
|
||||
Resize
|
||||
Scan
|
||||
Shape
|
||||
Size
|
||||
"""
|
||||
|
||||
|
||||
# EDITING THIS FILE? READ THIS FIRST!
|
||||
# see Note [Edit Symbolic Files] in symbolic_helper.py
|
||||
|
||||
__all__: list[str] = []
|
@ -0,0 +1,95 @@
|
||||
# mypy: allow-untyped-defs
|
||||
"""This file exports ONNX ops for opset 20.
|
||||
|
||||
Note [ONNX Operators that are added/updated in opset 20]
|
||||
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-20-of-the-default-onnx-operator-set
|
||||
New operators:
|
||||
AffineGrid
|
||||
ConstantOfShape
|
||||
DFT
|
||||
Gelu
|
||||
GridSample
|
||||
ImageDecoder
|
||||
IsInf
|
||||
IsNaN
|
||||
ReduceMax
|
||||
ReduceMin
|
||||
RegexFullMatch
|
||||
StringConcat
|
||||
StringSplit
|
||||
"""
|
||||
|
||||
import functools
|
||||
|
||||
import torch.nn.functional as F
|
||||
from torch import _C
|
||||
from torch.onnx._internal.torchscript_exporter import (
|
||||
jit_utils,
|
||||
registration,
|
||||
symbolic_helper,
|
||||
)
|
||||
|
||||
|
||||
# EDITING THIS FILE? READ THIS FIRST!
|
||||
# see Note [Edit Symbolic Files] in symbolic_helper.py
|
||||
|
||||
__all__ = ["_grid_sampler", "_affine_grid_generator", "gelu"]
|
||||
|
||||
|
||||
def convert_grid_sample_mode(mode_s):
|
||||
return (
|
||||
"linear" if mode_s == "bilinear" else "cubic" if mode_s == "bicubic" else mode_s
|
||||
)
|
||||
|
||||
|
||||
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=20)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::grid_sampler")
|
||||
@symbolic_helper.parse_args("v", "v", "i", "i", "b")
|
||||
def _grid_sampler(
|
||||
g: jit_utils.GraphContext,
|
||||
input: _C.Value,
|
||||
grid: _C.Value,
|
||||
mode_enum: int,
|
||||
padding_mode_enum: int,
|
||||
align_corners: bool,
|
||||
):
|
||||
mode_s = {v: k for k, v in F.GRID_SAMPLE_INTERPOLATION_MODES.items()}[mode_enum] # type: ignore[call-arg, index]
|
||||
# mode string changes at https://onnx.ai/onnx/operators/text_diff_GridSample_16_20.html
|
||||
mode_s = convert_grid_sample_mode(mode_s)
|
||||
padding_mode_s = {v: k for k, v in F.GRID_SAMPLE_PADDING_MODES.items()}[ # type: ignore[call-arg, index]
|
||||
padding_mode_enum # type: ignore[index]
|
||||
]
|
||||
return g.op(
|
||||
"GridSample",
|
||||
input,
|
||||
grid,
|
||||
align_corners_i=int(align_corners),
|
||||
mode_s=mode_s,
|
||||
padding_mode_s=padding_mode_s,
|
||||
)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::affine_grid_generator")
|
||||
@symbolic_helper.parse_args("v", "v", "b")
|
||||
def _affine_grid_generator(
|
||||
g: jit_utils.GraphContext,
|
||||
theta: _C.Value,
|
||||
size: _C.Value,
|
||||
align_corners: bool,
|
||||
):
|
||||
return g.op(
|
||||
"AffineGrid",
|
||||
theta,
|
||||
size,
|
||||
align_corners_i=int(align_corners),
|
||||
)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::gelu")
|
||||
@symbolic_helper.parse_args("v", "s")
|
||||
def gelu(g: jit_utils.GraphContext, self: _C.Value, approximate: str = "none"):
|
||||
return g.op("Gelu", self, approximate_s=approximate)
|
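A hedged usage sketch for the new opset-20 `Gelu` mapping; the module, file name, and export call are illustrative and not part of this diff, and depending on the installed torch version you may need to explicitly select the TorchScript-based exporter:

```python
import torch


class TanhGelu(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.gelu(x, approximate="tanh")


# When targeting opset 20, this should map onto a single Gelu node with
# approximate="tanh" rather than a decomposed subgraph.
torch.onnx.export(TanhGelu(), (torch.randn(2, 3),), "tanh_gelu.onnx", opset_version=20)
```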
torch/onnx/_internal/torchscript_exporter/symbolic_opset7.py
@ -0,0 +1,71 @@
|
||||
# mypy: allow-untyped-defs
|
||||
"""
|
||||
Note [ONNX operators that are added/updated from opset 7 to opset 8]
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
New operators:
|
||||
Expand
|
||||
|
||||
Updated operators:
|
||||
Min, Max, Sum, Mean: supports multidirectional broadcasting.
|
||||
MaxPool: added optional indices output.
|
||||
Scan
|
||||
"""
|
||||
|
||||
import functools
|
||||
import warnings
|
||||
|
||||
from torch.onnx._internal.torchscript_exporter import (
|
||||
jit_utils,
|
||||
registration,
|
||||
symbolic_helper,
|
||||
symbolic_opset9 as opset9,
|
||||
)
|
||||
|
||||
|
||||
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=7)
|
||||
|
||||
block_listed_operators = (
|
||||
"scan",
|
||||
"expand",
|
||||
"expand_as",
|
||||
"meshgrid",
|
||||
"adaptive_max_pool1d",
|
||||
"adaptive_max_pool2d",
|
||||
"adaptive_max_pool3d",
|
||||
"max_pool1d_with_indices",
|
||||
"max_pool2d_with_indices",
|
||||
"max_pool3d_with_indices",
|
||||
)
|
||||
|
||||
|
||||
# NOTE: max, min, sum, mean: broadcasting is not supported in opset 7.
|
||||
# torch.max (same for torch.min) actually has two interfaces smashed together:
|
||||
# torch.max(x, dim, keepdim) and torch.max(x, y)
|
||||
@_onnx_symbolic("aten::max")
|
||||
def max(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
|
||||
# torch.max(input, other)
|
||||
if keepdim is None and dim_or_y is not None:
|
||||
warnings.warn(
|
||||
"Multidirectional broadcasting is not supported in opset 7. "
|
||||
"This might cause the onnx model to be incorrect, if inputs to max operators "
|
||||
"have different shapes"
|
||||
)
|
||||
return opset9.max(g, self, dim_or_y, keepdim)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::min")
|
||||
def min(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
|
||||
# torch.min(input, other)
|
||||
if keepdim is None and dim_or_y is not None:
|
||||
warnings.warn(
|
||||
"Multidirectional broadcasting is not supported in opset 7. "
|
||||
"This might cause the onnx model to be incorrect, if inputs to min operators "
|
||||
"have different shapes"
|
||||
)
|
||||
return opset9.min(g, self, dim_or_y, keepdim)
|
||||
|
||||
|
||||
for block_listed_op in block_listed_operators:
|
||||
_onnx_symbolic(f"aten::{block_listed_op}")(
|
||||
symbolic_helper._block_list_in_opset(block_listed_op)
|
||||
)
|
torch/onnx/_internal/torchscript_exporter/symbolic_opset8.py
@ -0,0 +1,469 @@
|
||||
# mypy: allow-untyped-defs
|
||||
"""
|
||||
Note [ONNX operators that are added/updated from opset 8 to opset 9]
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
New operators:
|
||||
Compress
|
||||
ConstantOfShape
|
||||
EyeLike
|
||||
MaxUnpool
|
||||
OneHot
|
||||
Sinh
|
||||
Cosh
|
||||
Asinh
|
||||
Acosh
|
||||
Atanh
|
||||
Shrink
|
||||
IsNaN
|
||||
Sign
|
||||
Erf
|
||||
Scatter
|
||||
Where
|
||||
NonZero
|
||||
TfIdfVectorizer
|
||||
MeanVarianceNormalization
|
||||
|
||||
Updated operators:
|
||||
BatchNormalization: removed spatial attribute.
|
||||
Greater, Less, Constant, MatMul, PRelu, Gemm, Flatten: more data types{integers} supported.
|
||||
Cast: more data types{string} supported.
|
||||
Upsample: moved scales from attribute to input.
|
||||
Scan
|
||||
"""
|
||||
|
||||
import functools
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
from torch._C import _onnx as _C_onnx
|
||||
from torch.onnx import errors
|
||||
from torch.onnx._internal.torchscript_exporter import (
|
||||
_type_utils,
|
||||
jit_utils,
|
||||
registration,
|
||||
symbolic_helper,
|
||||
symbolic_opset9 as opset9,
|
||||
)
|
||||
|
||||
|
||||
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=8)
|
||||
|
||||
block_listed_operators = (
|
||||
"nonzero",
|
||||
"where",
|
||||
"scatter",
|
||||
"scatter_add",
|
||||
"erf",
|
||||
"sign",
|
||||
"isnan",
|
||||
"gather",
|
||||
"arange",
|
||||
"masked_fill",
|
||||
"index_fill",
|
||||
"index_copy",
|
||||
"repeat_interleave",
|
||||
"any",
|
||||
"all",
|
||||
)
|
||||
|
||||
for block_listed_op in block_listed_operators:
|
||||
_onnx_symbolic(f"aten::{block_listed_op}")(
|
||||
symbolic_helper._block_list_in_opset(block_listed_op)
|
||||
)
|
||||
|
||||
|
||||
@_onnx_symbolic(
|
||||
"aten::upsample_nearest1d",
|
||||
decorate=[symbolic_helper._apply_params("upsample_nearest1d", 3, "nearest")],
|
||||
)
|
||||
@_onnx_symbolic(
|
||||
"aten::upsample_nearest2d",
|
||||
decorate=[symbolic_helper._apply_params("upsample_nearest2d", 4, "nearest")],
|
||||
)
|
||||
@_onnx_symbolic(
|
||||
"aten::upsample_nearest3d",
|
||||
decorate=[symbolic_helper._apply_params("upsample_nearest3d", 5, "nearest")],
|
||||
)
|
||||
@_onnx_symbolic(
|
||||
"aten::upsample_linear1d",
|
||||
decorate=[symbolic_helper._apply_params("upsample_linear1d", 3, "linear")],
|
||||
)
|
||||
@_onnx_symbolic(
|
||||
"aten::upsample_bilinear2d",
|
||||
decorate=[symbolic_helper._apply_params("upsample_bilinear2d", 4, "linear")],
|
||||
)
|
||||
@_onnx_symbolic(
|
||||
"aten::upsample_trilinear3d",
|
||||
decorate=[symbolic_helper._apply_params("upsample_trilinear3d", 5, "linear")],
|
||||
)
|
||||
def _interpolate(name, dim, interpolate_mode):
|
||||
def symbolic_fn(g, input, output_size, *args):
|
||||
scales, align_corners = symbolic_helper._get_interpolate_attributes(
|
||||
g, interpolate_mode, args
|
||||
)
|
||||
symbolic_helper._interpolate_warning(interpolate_mode)
|
||||
align_corners = symbolic_helper._maybe_get_scalar(align_corners)
|
||||
if align_corners:
|
||||
return symbolic_helper._unimplemented(name, "align_corners == True", input)
|
||||
output_size = symbolic_helper._maybe_get_const(output_size, "is")
|
||||
if symbolic_helper._is_value(output_size):
|
||||
return symbolic_helper._unimplemented(
|
||||
name, "torch._C.Value (output_size) indexing"
|
||||
)
|
||||
if scales is None:
|
||||
scales = [
|
||||
1.0
|
||||
if i < 2
|
||||
else float(output_size[-(dim - i)])
|
||||
/ float(input.type().sizes()[-(dim - i)])
|
||||
for i in range(0, dim)
|
||||
]
|
||||
return g.op("Upsample", input, mode_s=interpolate_mode, scales_f=scales)
|
||||
|
||||
return symbolic_fn
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::__interpolate")
|
||||
def __interpolate(
|
||||
g: jit_utils.GraphContext,
|
||||
input,
|
||||
size,
|
||||
scale_factor,
|
||||
mode,
|
||||
align_corners,
|
||||
recompute_scale_factor,
|
||||
antialias,
|
||||
):
|
||||
align_corners = symbolic_helper._maybe_get_const(align_corners, "b")
|
||||
if not symbolic_helper._is_none(align_corners) and align_corners:
|
||||
return symbolic_helper._unimplemented("interpolate", "align_corners == True")
|
||||
|
||||
if not symbolic_helper._is_none(scale_factor) and symbolic_helper._is_value(
|
||||
scale_factor
|
||||
):
|
||||
return symbolic_helper._unimplemented(
|
||||
"interpolate", "dynamic scales in opset 8"
|
||||
)
|
||||
|
||||
if not symbolic_helper._is_none(size) and symbolic_helper._is_value(size):
|
||||
return symbolic_helper._unimplemented("interpolate", "dynamic size in opset 8")
|
||||
|
||||
scales, mode = symbolic_helper._interpolate_get_scales_and_mode(
|
||||
g, input, size, scale_factor, mode, align_corners
|
||||
)
|
||||
return g.op("Upsample", input, mode_s=mode, scales_f=scales)
|
||||
|
||||
|
||||
# NOTE: We should create a wrapper for this kind of operation, after resolving the shape/type propagation
|
||||
# issue for "cast" operators. Some symbolic functions depend on shape information of input tensor, which
|
||||
# is lost after casting.
|
||||
def _try_cast_integer_to_float(g: jit_utils.GraphContext, *args):
|
||||
floating_scalar_types = {
|
||||
_type_utils.JitScalarType.HALF,
|
||||
_type_utils.JitScalarType.FLOAT,
|
||||
_type_utils.JitScalarType.DOUBLE,
|
||||
}
|
||||
old_type = None
|
||||
# Cast the input tensor to Float if its scalarType is known and is not floating number.
|
||||
# If casting is performed, return the old scalarType, otherwise return None.
|
||||
arg0_type = _type_utils.JitScalarType.from_value(
|
||||
args[0], _type_utils.JitScalarType.UNDEFINED
|
||||
)
|
||||
if arg0_type != _type_utils.JitScalarType.UNDEFINED:
|
||||
old_type = arg0_type
|
||||
if old_type not in floating_scalar_types:
|
||||
old_type = old_type.scalar_name() # type: ignore[assignment]
|
||||
args = tuple(
|
||||
g.op("Cast", arg, to_i=_C_onnx.TensorProtoDataType.FLOAT)
|
||||
for arg in args
|
||||
)
|
||||
else:
|
||||
return (None,) + args
|
||||
else:
|
||||
warnings.warn(
|
||||
"Only floating datatype is supported for these operators: "
|
||||
"{Greater, Less, MatMul, PRelu, Gemm, Flatten}. This might cause "
|
||||
"the onnx model to be incorrect, if inputs have integer datatypes."
|
||||
)
|
||||
return (old_type,) + args
|
||||
|
||||
|
||||
def _cast_to_type(g: jit_utils.GraphContext, input, to_type):
|
||||
if to_type is None:
|
||||
return input
|
||||
return getattr(opset9, f"_cast_{to_type}")(g, input, False)
|
||||
|
||||
|
||||
def _comparison_operator(g: jit_utils.GraphContext, input, other, op_name):
|
||||
other = symbolic_helper._maybe_get_scalar(other)
|
||||
other = symbolic_helper._if_scalar_type_as(other, input)
|
||||
_, input, other = _try_cast_integer_to_float(g, input, other)
|
||||
return g.op(op_name, input, other)
|
||||
|
||||
|
||||
# NOTE: For symbolics {gt, lt, bmm, matmul, prelu, mm, addmm, view, flatten},
|
||||
# integer input type not supported in opset8. Cast to float if possible.
|
||||
@_onnx_symbolic("aten::gt")
|
||||
def gt(g: jit_utils.GraphContext, input, other):
|
||||
return _comparison_operator(g, input, other, "Greater")
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::lt")
|
||||
def lt(g: jit_utils.GraphContext, input, other):
|
||||
return _comparison_operator(g, input, other, "Less")
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::bmm")
|
||||
def bmm(g: jit_utils.GraphContext, self, other):
|
||||
if symbolic_helper._try_get_scalar_type(self):
|
||||
old_type, self, other = _try_cast_integer_to_float(g, self, other)
|
||||
return _cast_to_type(g, g.op("MatMul", self, other), old_type)
|
||||
else:
|
||||
return g.op("MatMul", self, other)
|
||||
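The cast-to-float fallback used by these opset-8 symbolics relies on small integers being exactly representable in float32; a minimal eager-mode sketch of the round trip (values are illustrative, not part of this diff):

```python
import torch

a = torch.randint(0, 5, (2, 3), dtype=torch.int64)
b = torch.randint(0, 5, (3, 4), dtype=torch.int64)

# What the exported graph effectively does: Cast(int -> float) -> MatMul -> Cast(float -> int)
out = (a.float() @ b.float()).to(torch.int64)
print(torch.equal(out, a @ b))  # True while the products stay within float32's exact-integer range
```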
|
||||
|
||||
@_onnx_symbolic("aten::matmul")
|
||||
def matmul(g: jit_utils.GraphContext, self, other):
|
||||
return bmm(g, self, other)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::prelu")
|
||||
def prelu(g: jit_utils.GraphContext, self, weight):
|
||||
self_rank = symbolic_helper._get_tensor_rank(self)
|
||||
weight_sizes = symbolic_helper._get_tensor_sizes(weight)
|
||||
if self_rank is not None and self_rank > 2:
|
||||
weight = g.op("Unsqueeze", weight, axes_i=list(range(1, self_rank - 1)))
|
||||
elif self_rank == 0 and weight_sizes == [1]:
|
||||
# self and weight are both scalar but weight has rank == 1, squeeze weight.
|
||||
weight = symbolic_helper._squeeze_helper(g, weight, [0])
|
||||
if symbolic_helper._try_get_scalar_type(self):
|
||||
old_type, self, weight = _try_cast_integer_to_float(g, self, weight)
|
||||
return _cast_to_type(g, g.op("PRelu", self, weight), old_type)
|
||||
else:
|
||||
return g.op("PRelu", self, weight)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::mm")
|
||||
def mm(g: jit_utils.GraphContext, self, other):
|
||||
# Create a dummy C tensor. Only needed for API purposes; its value is
# irrelevant since beta = 0.
|
||||
scalar_type = symbolic_helper._try_get_scalar_type(self, other)
|
||||
if scalar_type is None:
|
||||
raise errors.SymbolicValueError(
|
||||
"mm can only operate on tensors with known types", self
|
||||
)
|
||||
zero_constant = g.op(
|
||||
"Constant",
|
||||
value_t=torch.tensor([0], dtype=scalar_type.dtype()),
|
||||
)
|
||||
|
||||
if symbolic_helper._try_get_scalar_type(self):
|
||||
old_type, self, other, zero_constant = _try_cast_integer_to_float(
|
||||
g, self, other, zero_constant
|
||||
)
|
||||
return _cast_to_type(
|
||||
g,
|
||||
g.op("Gemm", self, other, zero_constant, beta_f=0.0, alpha_f=1.0),
|
||||
old_type,
|
||||
)
|
||||
return g.op("Gemm", self, other, zero_constant, beta_f=0.0, alpha_f=1.0)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::addmm")
|
||||
@symbolic_helper.parse_args("v", "v", "v", "t", "t")
|
||||
def addmm(g: jit_utils.GraphContext, self, mat1, mat2, beta, alpha):
|
||||
if symbolic_helper._try_get_scalar_type(self):
|
||||
old_type, self, mat1, mat2 = _try_cast_integer_to_float(g, self, mat1, mat2)
|
||||
return _cast_to_type(
|
||||
g,
|
||||
g.op(
|
||||
"Gemm",
|
||||
mat1,
|
||||
mat2,
|
||||
self,
|
||||
beta_f=symbolic_helper._scalar(beta),
|
||||
alpha_f=symbolic_helper._scalar(alpha),
|
||||
),
|
||||
old_type,
|
||||
)
|
||||
else:
|
||||
return g.op(
|
||||
"Gemm",
|
||||
mat1,
|
||||
mat2,
|
||||
self,
|
||||
beta_f=symbolic_helper._scalar(beta),
|
||||
alpha_f=symbolic_helper._scalar(alpha),
|
||||
)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::flatten")
|
||||
def flatten(g: jit_utils.GraphContext, input, start_dim, end_dim):
|
||||
start_dim_i = symbolic_helper._get_const(start_dim, "i", "start_dim")
|
||||
end_dim_i = symbolic_helper._get_const(end_dim, "i", "end_dim")
|
||||
|
||||
dim = input.type().dim()
|
||||
if end_dim_i < 0:
|
||||
end_dim_i = dim + end_dim_i
|
||||
# use ONNX's Flatten operator for cases where the output shape is 2D
|
||||
if start_dim_i == 1 and end_dim_i == dim - 1:
|
||||
if symbolic_helper._try_get_scalar_type(input):
|
||||
old_type, input = _try_cast_integer_to_float(g, input)
|
||||
return _cast_to_type(
|
||||
g, g.op("Flatten", input, axis_i=start_dim_i), old_type
|
||||
)
|
||||
else:
|
||||
return g.op("Flatten", input, axis_i=start_dim_i)
|
||||
if start_dim_i == 0 and end_dim_i == dim - 2:
|
||||
if symbolic_helper._try_get_scalar_type(input):
|
||||
old_type, input = _try_cast_integer_to_float(g, input)
|
||||
return _cast_to_type(
|
||||
g, g.op("Flatten", input, axis_i=end_dim_i + 1), old_type
|
||||
)
|
||||
else:
|
||||
return g.op("Flatten", input, axis_i=end_dim_i + 1)
|
||||
|
||||
return opset9.flatten(g, input, start_dim, end_dim)
|
||||
|
||||
|
||||
def _constant_fill(g: jit_utils.GraphContext, sizes, dtype: int, const_value):
|
||||
if dtype is None:
|
||||
scalar_type = _type_utils.JitScalarType.FLOAT
|
||||
else:
|
||||
scalar_type = _type_utils.JitScalarType(dtype)
|
||||
if not scalar_type.dtype().is_floating_point:
|
||||
result = g.op(
|
||||
"ConstantFill",
|
||||
sizes,
|
||||
dtype_i=_type_utils.JitScalarType.FLOAT.onnx_type(),
|
||||
input_as_shape_i=1,
|
||||
value_f=const_value,
|
||||
)
|
||||
return g.op("Cast", result, to_i=scalar_type.onnx_type())
|
||||
else:
|
||||
return g.op(
|
||||
"ConstantFill",
|
||||
sizes,
|
||||
dtype_i=scalar_type.onnx_type(),
|
||||
input_as_shape_i=1,
|
||||
value_f=const_value,
|
||||
)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::empty")
|
||||
@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
|
||||
def empty(
|
||||
g: jit_utils.GraphContext,
|
||||
sizes,
|
||||
dtype,
|
||||
layout,
|
||||
device,
|
||||
pin_memory=False,
|
||||
memory_format=None,
|
||||
):
|
||||
return zeros(g, sizes, dtype, layout, device, pin_memory)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::empty_like")
|
||||
@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
|
||||
def empty_like(
|
||||
g: jit_utils.GraphContext,
|
||||
input,
|
||||
dtype,
|
||||
layout,
|
||||
device,
|
||||
pin_memory=False,
|
||||
memory_format=None,
|
||||
):
|
||||
return zeros_like(g, input, dtype, layout, device, pin_memory)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::zeros")
|
||||
@symbolic_helper.parse_args("v", "i", "v", "v", "v")
|
||||
def zeros(g: jit_utils.GraphContext, sizes, dtype, layout, device, pin_memory=False):
|
||||
# NOTE: no way to set device and layout in ONNX, so we ignore it
|
||||
return _constant_fill(g, sizes, dtype, 0)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::zeros_like")
|
||||
@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
|
||||
def zeros_like(
|
||||
g: jit_utils.GraphContext,
|
||||
input,
|
||||
dtype,
|
||||
layout,
|
||||
device,
|
||||
pin_memory=False,
|
||||
memory_format=None,
|
||||
):
|
||||
shape = g.op("Shape", input)
|
||||
return _constant_fill(g, shape, dtype, 0)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::ones")
|
||||
@symbolic_helper.parse_args("v", "i", "v", "v", "v")
|
||||
def ones(g: jit_utils.GraphContext, sizes, dtype, layout, device, pin_memory=False):
|
||||
return _constant_fill(g, sizes, dtype, 1)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::ones_like")
|
||||
@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
|
||||
def ones_like(
|
||||
g: jit_utils.GraphContext,
|
||||
input,
|
||||
dtype,
|
||||
layout,
|
||||
device,
|
||||
pin_memory=False,
|
||||
memory_format=None,
|
||||
):
|
||||
shape = g.op("Shape", input)
|
||||
return _constant_fill(g, shape, dtype, 1)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::full")
|
||||
def full(
|
||||
g: jit_utils.GraphContext, sizes, value, dtype, layout, device, pin_memory=False
|
||||
):
|
||||
const_value = symbolic_helper._maybe_get_const(value, "t")
|
||||
if symbolic_helper._is_value(const_value):
|
||||
tmp = zeros(g, sizes, dtype, layout, device)
|
||||
return opset9.add(g, tmp, value, g.op("Constant", value_t=torch.tensor(1)))
|
||||
else:
|
||||
dtype = symbolic_helper._get_const(dtype, "i", "dtype")
|
||||
return _constant_fill(g, sizes, dtype, const_value)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::full_like")
|
||||
@symbolic_helper.parse_args("v", "f", "i", "v", "v", "v", "v")
|
||||
def full_like(
|
||||
g: jit_utils.GraphContext,
|
||||
input,
|
||||
fill_value,
|
||||
dtype,
|
||||
layout,
|
||||
device,
|
||||
pin_memory=False,
|
||||
memory_format=None,
|
||||
):
|
||||
shape = g.op("Shape", input)
|
||||
return _constant_fill(g, shape, dtype, fill_value)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::repeat")
|
||||
def repeat(g: jit_utils.GraphContext, self, repeats):
|
||||
if not symbolic_helper._is_value(repeats):
|
||||
repeats = g.op("Constant", value_t=torch.LongTensor(repeats))
|
||||
if symbolic_helper._is_packed_list(repeats):
|
||||
repeat_size_len = len(symbolic_helper._unpack_list(repeats))
|
||||
else:
|
||||
const_repeats = symbolic_helper._maybe_get_const(repeats, "is")
|
||||
repeat_size_len = len(const_repeats)
|
||||
if self.isCompleteTensor():
|
||||
sizes = self.type().sizes()
|
||||
diff_dims = repeat_size_len - len(sizes)
|
||||
if diff_dims > 0:
|
||||
self = opset9.view(
|
||||
g, self, g.op("Constant", value_t=torch.tensor([1] * diff_dims + sizes))
|
||||
)
|
||||
return g.op("Tile", self, repeats)
|
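As a rough sanity check of the rank-adjust-then-Tile strategy in `repeat` above, the same result can be reproduced with `torch.tile`, whose semantics match the ONNX Tile operator; `repeat_via_tile` is an illustrative helper, not part of the exporter.

```python
import torch

def repeat_via_tile(x, repeats):
    # The symbolic reshapes the input so its rank matches len(repeats),
    # then emits a single Tile node; torch.tile models Tile's semantics.
    diff = len(repeats) - x.dim()
    if diff > 0:
        x = x.reshape(*([1] * diff), *x.shape)
    return torch.tile(x, repeats)

x = torch.arange(6).reshape(2, 3)
assert torch.equal(repeat_via_tile(x, (4, 2, 1)), x.repeat(4, 2, 1))
```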
torch/onnx/_internal/torchscript_exporter/symbolic_opset9.py (new file, 6656 lines): diff suppressed because it is too large
torch/onnx/_internal/torchscript_exporter/utils.py (new file, 1930 lines): diff suppressed because it is too large
torch/onnx/_internal/torchscript_exporter/verification.py (new file, 1863 lines): diff suppressed because it is too large
Three additional file diffs suppressed because they are too large.
@@ -1,464 +1,8 @@
|
||||
# mypy: allow-untyped-defs
|
||||
# mypy: disable-error-code=arg-type
|
||||
"""Backward compatibility module for torch.onnx.symbolic_opset12."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import sys
|
||||
|
||||
import torch
|
||||
from torch._C import _onnx as _C_onnx
|
||||
from torch.onnx import (
|
||||
_type_utils,
|
||||
errors,
|
||||
symbolic_helper,
|
||||
symbolic_opset9 as opset9,
|
||||
utils,
|
||||
)
|
||||
from torch.onnx._internal import jit_utils, registration
|
||||
__all__: list[str] = []
|
||||
|
||||
|
||||
# EDITING THIS FILE? READ THIS FIRST!
|
||||
# see Note [Edit Symbolic Files] in README.md
|
||||
|
||||
# This file exports ONNX ops for opset 12
|
||||
|
||||
__all__ = [
|
||||
"argmax",
|
||||
"argmin",
|
||||
"binary_cross_entropy_with_logits",
|
||||
"celu",
|
||||
"cross_entropy_loss",
|
||||
"dropout",
|
||||
"einsum",
|
||||
"ge",
|
||||
"le",
|
||||
"native_dropout",
|
||||
"nll_loss",
|
||||
"nll_loss2d",
|
||||
"nll_loss_nd",
|
||||
"outer",
|
||||
"pow",
|
||||
"tensordot",
|
||||
"unfold",
|
||||
]
|
||||
|
||||
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=12)
|
||||
|
||||
|
||||
def _einsum_helper(g: jit_utils.GraphContext, equation, tensors):
|
||||
if not tensors:
|
||||
raise RuntimeError("Einsum inputs are empty.")
|
||||
# ONNX does not support bool for Einsum inputs.
|
||||
if symbolic_helper._is_bool(tensors[0]):
|
||||
tensors = [
|
||||
g.op("Cast", tensor, to_i=_C_onnx.TensorProtoDataType.INT64)
|
||||
for tensor in tensors
|
||||
]
|
||||
return g.op(
|
||||
"Cast",
|
||||
g.op("Einsum", *tensors, equation_s=equation),
|
||||
to_i=_C_onnx.TensorProtoDataType.BOOL,
|
||||
)
|
||||
else:
|
||||
return g.op("Einsum", *tensors, equation_s=equation)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::einsum")
|
||||
@symbolic_helper.parse_args("s", "v", "is")
|
||||
def einsum(g: jit_utils.GraphContext, equation, tensor_list, path=None):
|
||||
tensors = symbolic_helper._unpack_list(tensor_list)
|
||||
return _einsum_helper(g, equation, tensors)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::outer")
|
||||
@symbolic_helper.parse_args("v", "v")
|
||||
def outer(g: jit_utils.GraphContext, input, other):
|
||||
# make sure to cast other to self's type
|
||||
if _type_utils.JitScalarType.from_value(
|
||||
other, _type_utils.JitScalarType.UNDEFINED
|
||||
) != _type_utils.JitScalarType.from_value(input):
|
||||
other = g.op(
|
||||
"Cast",
|
||||
other,
|
||||
to_i=_type_utils.JitScalarType.from_value(input).onnx_type(),
|
||||
)
|
||||
return _einsum_helper(g, "i,j->ij", [input, other])
|
||||
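A quick numerical check of the lowering above: `aten::outer` becomes a single Einsum with the `"i,j->ij"` equation (plus a Cast when the operand dtypes differ). The sketch only verifies the equation, not the exported graph.

```python
import torch

a, b = torch.randn(3), torch.randn(4)
# Same contraction the symbolic emits for aten::outer.
assert torch.allclose(torch.outer(a, b), torch.einsum("i,j->ij", a, b))
```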
|
||||
|
||||
def _dropout_returns_masked_input_and_mask(
|
||||
g: jit_utils.GraphContext, input: torch._C.Value, p: float, train: bool
|
||||
) -> tuple[torch._C.Value, torch._C.Value | None]:
|
||||
symbolic_helper.check_training_mode(train, "dropout")
|
||||
# In eval mode, dropout is non-op. That is, if the node's
|
||||
# train param is set to False, dropout just returns its inputs.
|
||||
if not train:
|
||||
return input, None
|
||||
p = g.op("Constant", value_t=torch.tensor(p))
|
||||
t = g.op("Constant", value_t=torch.tensor(train, dtype=torch.bool))
|
||||
r, mask = g.op("Dropout", input, p, t, outputs=2)
|
||||
return r, mask
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::dropout")
|
||||
@symbolic_helper.parse_args("v", "f", "b")
|
||||
def dropout(g: jit_utils.GraphContext, input, p, train):
|
||||
masked, _ = _dropout_returns_masked_input_and_mask(g, input, p, train)
|
||||
return masked
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::native_dropout")
|
||||
@symbolic_helper.parse_args("v", "f", "b")
|
||||
def native_dropout(g: jit_utils.GraphContext, input, p, train):
|
||||
return _dropout_returns_masked_input_and_mask(g, input, p, train)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::nll_loss")
|
||||
def nll_loss(g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index):
|
||||
# none reduction : onnx::Constant[value={0}]
|
||||
# mean reduction : onnx::Constant[value={1}]
|
||||
# sum reduction : onnx::Constant[value={2}]
|
||||
reduction = symbolic_helper._maybe_get_const(reduction, "i")
|
||||
reduction_vals = ["none", "mean", "sum"]
|
||||
reduction = reduction_vals[reduction]
|
||||
|
||||
# in onnx NegativeLogLikelihoodLoss specification, ignore_index is optional without default value.
|
||||
# therefore we need to set ignore_index attribute even if it is not specified (e.g. ignore_index=-100).
|
||||
ignore_index = symbolic_helper._maybe_get_const(ignore_index, "i")
|
||||
if weight.node().mustBeNone():
|
||||
nllloss = g.op(
|
||||
"NegativeLogLikelihoodLoss",
|
||||
self,
|
||||
target,
|
||||
reduction_s=reduction,
|
||||
ignore_index_i=ignore_index,
|
||||
)
|
||||
else:
|
||||
nllloss = g.op(
|
||||
"NegativeLogLikelihoodLoss",
|
||||
self,
|
||||
target,
|
||||
weight,
|
||||
reduction_s=reduction,
|
||||
ignore_index_i=ignore_index,
|
||||
)
|
||||
|
||||
return nllloss
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::nll_loss2d")
|
||||
def nll_loss2d(
|
||||
g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index
|
||||
):
|
||||
return nll_loss(g, self, target, weight, reduction, ignore_index)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::nll_loss_nd")
|
||||
def nll_loss_nd(
|
||||
g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index
|
||||
):
|
||||
return nll_loss(g, self, target, weight, reduction, ignore_index)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::cross_entropy_loss")
|
||||
def cross_entropy_loss(
|
||||
g: jit_utils.GraphContext,
|
||||
self,
|
||||
target,
|
||||
weight,
|
||||
reduction,
|
||||
ignore_index,
|
||||
label_smoothing,
|
||||
):
|
||||
# none reduction : onnx::Constant[value={0}]
|
||||
# mean reduction : onnx::Constant[value={1}]
|
||||
# sum reduction : onnx::Constant[value={2}]
|
||||
reduction = symbolic_helper._maybe_get_const(reduction, "i")
|
||||
reduction_vals = ["none", "mean", "sum"]
|
||||
reduction = reduction_vals[reduction]
|
||||
|
||||
label_smoothing = symbolic_helper._maybe_get_const(label_smoothing, "f")
|
||||
if label_smoothing is not None and label_smoothing > 0.0:
|
||||
raise errors.SymbolicValueError(
|
||||
"Unsupported: ONNX does not support label_smoothing", self
|
||||
)
|
||||
|
||||
# in onnx SoftmaxCrossEntropyLoss specification, ignore_index is optional without default value.
|
||||
# therefore we need to set ignore_index attribute even if it is not specified (e.g. ignore_index=-100).
|
||||
ignore_index = symbolic_helper._maybe_get_const(ignore_index, "i")
|
||||
if weight.node().mustBeNone():
|
||||
celoss = g.op(
|
||||
"SoftmaxCrossEntropyLoss",
|
||||
self,
|
||||
target,
|
||||
reduction_s=reduction,
|
||||
ignore_index_i=ignore_index,
|
||||
)
|
||||
else:
|
||||
celoss = g.op(
|
||||
"SoftmaxCrossEntropyLoss",
|
||||
self,
|
||||
target,
|
||||
weight,
|
||||
reduction_s=reduction,
|
||||
ignore_index_i=ignore_index,
|
||||
)
|
||||
|
||||
return celoss
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::binary_cross_entropy_with_logits")
|
||||
@symbolic_helper.parse_args("v", "v", "v", "v", "i")
|
||||
def binary_cross_entropy_with_logits(
|
||||
g: jit_utils.GraphContext, input, target, weight, pos_weight, reduction
|
||||
):
|
||||
p = g.op("Constant", value_t=torch.tensor([1]))
|
||||
sig_x = opset9.sigmoid(g, input)
|
||||
log_sig_x = opset9.log(g, sig_x)
|
||||
sub_1_x = opset9.sub(g, p, sig_x)
|
||||
sub_1_y = opset9.sub(g, p, target)
|
||||
log_1_x = opset9.log(g, sub_1_x)
|
||||
if pos_weight is None or symbolic_helper._is_none(pos_weight):
|
||||
output = opset9.neg(
|
||||
g,
|
||||
opset9.add(
|
||||
g, opset9.mul(g, target, log_sig_x), opset9.mul(g, sub_1_y, log_1_x)
|
||||
),
|
||||
)
|
||||
else:
|
||||
output = opset9.neg(
|
||||
g,
|
||||
opset9.add(
|
||||
g,
|
||||
opset9.mul(g, opset9.mul(g, target, log_sig_x), pos_weight),
|
||||
opset9.mul(g, sub_1_y, log_1_x),
|
||||
),
|
||||
)
|
||||
|
||||
if weight is not None and not symbolic_helper._is_none(weight):
|
||||
output = opset9.mul(g, weight, output)
|
||||
|
||||
reduction = symbolic_helper._maybe_get_const(reduction, "i")
|
||||
if reduction == 0:
|
||||
return output
|
||||
elif reduction == 1:
|
||||
return g.op("ReduceMean", output, keepdims_i=0)
|
||||
elif reduction == 2:
|
||||
return g.op("ReduceSum", output, keepdims_i=0)
|
||||
else:
|
||||
return symbolic_helper._onnx_unsupported(
|
||||
"binary_cross_entropy_with_logits with reduction other than none, mean, or sum",
|
||||
input,
|
||||
)
|
||||
|
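The ONNX-primitive decomposition built above can be compared against the eager kernel; a minimal sketch with mean reduction, where shapes and tolerance are illustrative.

```python
import torch
import torch.nn.functional as F

x, t = torch.randn(4, 5), torch.rand(4, 5)
w, pw = torch.rand(5), torch.rand(5)

# Same math as the symbolic: -(pos_weight * t * log(sigmoid(x))
#                              + (1 - t) * log(1 - sigmoid(x))),
# scaled elementwise by `weight`, then reduced with a mean.
sig = torch.sigmoid(x)
manual = (w * -(pw * t * torch.log(sig) + (1 - t) * torch.log(1 - sig))).mean()

reference = F.binary_cross_entropy_with_logits(x, t, weight=w, pos_weight=pw)
assert torch.allclose(manual, reference, atol=1e-6)
```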
||||
|
||||
@_onnx_symbolic("aten::celu")
|
||||
def celu(g: jit_utils.GraphContext, self, alpha):
|
||||
alpha = symbolic_helper._maybe_get_const(alpha, "f")
|
||||
# if the input is of type double cast it to float
|
||||
if (
|
||||
_type_utils.JitScalarType.from_value(self, _type_utils.JitScalarType.UNDEFINED)
|
||||
== _type_utils.JitScalarType.DOUBLE
|
||||
):
|
||||
self = g.op("Cast", self, to_i=_C_onnx.TensorProtoDataType.FLOAT)
|
||||
out = g.op("Celu", self, alpha_f=alpha)
|
||||
return g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.DOUBLE)
|
||||
|
||||
return g.op("Celu", self, alpha_f=alpha)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::argmax")
|
||||
@symbolic_helper.parse_args("v", "v", "b")
|
||||
def argmax(
|
||||
g: jit_utils.GraphContext,
|
||||
input: torch._C.Value,
|
||||
dim: torch._C.Value,
|
||||
keepdim: bool,
|
||||
):
|
||||
return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMax")
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::argmin")
|
||||
@symbolic_helper.parse_args("v", "v", "b")
|
||||
def argmin(
|
||||
g: jit_utils.GraphContext,
|
||||
input: torch._C.Value,
|
||||
dim: torch._C.Value,
|
||||
keepdim: bool,
|
||||
):
|
||||
return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMin")
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::pow")
|
||||
def pow(g: jit_utils.GraphContext, self, exponent):
|
||||
return g.op("Pow", self, exponent)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::ge")
|
||||
def ge(g: jit_utils.GraphContext, input, other):
|
||||
return g.op("GreaterOrEqual", input, other)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::le")
|
||||
def le(g: jit_utils.GraphContext, input, other):
|
||||
return g.op("LessOrEqual", input, other)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::unfold")
|
||||
@symbolic_helper.parse_args("v", "i", "v", "v")
|
||||
def unfold(g: jit_utils.GraphContext, input, dimension, size, step):
|
||||
const_size = symbolic_helper._maybe_get_const(size, "i")
|
||||
const_step = symbolic_helper._maybe_get_const(step, "i")
|
||||
if not symbolic_helper._is_value(const_size) and not symbolic_helper._is_value(
|
||||
const_step
|
||||
):
|
||||
return opset9.unfold(g, input, dimension, const_size, const_step)
|
||||
|
||||
sizedim = symbolic_helper._get_tensor_dim_size(input, dimension)
|
||||
if sizedim is not None:
|
||||
low_start = g.op("Constant", value_t=torch.tensor(0))
|
||||
low_end = g.op("Constant", value_t=torch.tensor(sizedim))
|
||||
hi_end = g.op("Constant", value_t=torch.tensor(sizedim + 1))
|
||||
low_indices = g.op("Range", low_start, low_end, step)
|
||||
hi_indices = g.op("Range", size, hi_end, step)
|
||||
|
||||
low_size = symbolic_helper._size_helper(
|
||||
g, low_indices, g.op("Constant", value_t=torch.tensor(0))
|
||||
)
|
||||
hi_size = symbolic_helper._size_helper(
|
||||
g, hi_indices, g.op("Constant", value_t=torch.tensor(0))
|
||||
)
|
||||
|
||||
ndim = symbolic_helper._get_tensor_rank(input)
|
||||
assert ndim is not None
|
||||
perm = list(range(0, ndim))
|
||||
perm.append(perm.pop(dimension))
|
||||
|
||||
unsqueeze_list = []
|
||||
loop_condition = g.op("Constant", value_t=torch.tensor(1))
|
||||
loop_condition = g.op(
|
||||
"Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL
|
||||
)
|
||||
loop_len = g.op("Min", low_size, hi_size)
|
||||
|
||||
loop, (loop_context,), _ = jit_utils.add_op_with_blocks(
|
||||
g, "Loop", loop_len, loop_condition, n_blocks=1
|
||||
)
|
||||
|
||||
loop_block = loop_context.block
|
||||
block_input_iter = utils._add_input_to_block(loop_block)
|
||||
cond = utils._add_input_to_block(loop_block) # noqa: F841
|
||||
|
||||
starts = loop_context.op("Gather", low_indices, block_input_iter)
|
||||
ends = loop_context.op("Gather", hi_indices, block_input_iter)
|
||||
axes = loop_context.op("Constant", value_t=torch.tensor([2]))
|
||||
starts = symbolic_helper._unsqueeze_helper(loop_context, starts, [0])
|
||||
ends = symbolic_helper._unsqueeze_helper(loop_context, ends, [0])
|
||||
stack = loop_context.op("Slice", input, starts, ends, axes)
|
||||
|
||||
unsqueeze = symbolic_helper._unsqueeze_helper(
|
||||
loop_context, loop_context.op("Transpose", stack, perm_i=perm), [dimension]
|
||||
)
|
||||
unsqueeze_list.append(unsqueeze)
|
||||
concat = loop_context.op("Concat", *unsqueeze_list, axis_i=0)
|
||||
|
||||
cond_out = loop_context.op(
|
||||
"Cast", loop_condition, _C_onnx.TensorProtoDataType.BOOL
|
||||
)
|
||||
utils._add_output_to_block(loop_block, cond_out)
|
||||
utils._add_output_to_block(loop_block, concat)
|
||||
|
||||
loop_output = loop.node().output()
|
||||
perm = [0, 1, 2, 3, 4]
|
||||
perm[0], perm[dimension + 1] = perm[dimension + 1], perm[0]
|
||||
transpose = g.op("Transpose", loop_output, perm_i=perm)
|
||||
squeeze = symbolic_helper._squeeze_helper(g, transpose, [0])
|
||||
|
||||
return squeeze
|
||||
|
||||
return symbolic_helper._unimplemented("Unfold", "input size not accessible")
|
||||
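The Loop/Slice/Transpose construction above reproduces `Tensor.unfold`; a plain-PyTorch sketch of the same windowing (the helper name is illustrative) makes the intended semantics easier to see.

```python
import torch

def unfold_reference(x, dimension, size, step):
    # Slide a window of `size` along `dimension` with stride `step` and
    # append the window contents as a new last axis, like Tensor.unfold.
    slices = [
        x.narrow(dimension, start, size).movedim(dimension, -1)
        for start in range(0, x.size(dimension) - size + 1, step)
    ]
    return torch.stack(slices, dim=dimension)

x = torch.arange(24.0).reshape(2, 3, 4)
assert torch.equal(unfold_reference(x, 2, 2, 1), x.unfold(2, 2, 1))
```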
|
||||
|
||||
@_onnx_symbolic("aten::tensordot")
|
||||
@symbolic_helper.parse_args("v", "v", "is", "is", "v")
|
||||
def tensordot(g: jit_utils.GraphContext, input_a, input_b, dims_a, dims_b, out=None):
|
||||
if out is not None:
|
||||
symbolic_helper._unimplemented(
|
||||
"Tensordot", "Out parameter is not supported for tensordot."
|
||||
)
|
||||
|
||||
dim_count_a = symbolic_helper._get_tensor_rank(input_a)
|
||||
if dim_count_a is None:
|
||||
raise errors.SymbolicValueError(
|
||||
"Unsupported: ONNX export of tensordot for tensor(input_a) of unknown rank.",
|
||||
input_a,
|
||||
)
|
||||
|
||||
dim_count_b = symbolic_helper._get_tensor_rank(input_b)
|
||||
if dim_count_b is None:
|
||||
raise errors.SymbolicValueError(
|
||||
"Unsupported: ONNX export of tensordot for tensor(input_b) of unknown rank.",
|
||||
input_b,
|
||||
)
|
||||
|
||||
dims_a = [
|
||||
(dims_a[i] + dim_count_a) if (dims_a[i] < 0) else dims_a[i]
|
||||
for i in range(len(dims_a))
|
||||
]
|
||||
dims_b = [
|
||||
(dims_b[i] + dim_count_b) if (dims_b[i] < 0) else dims_b[i]
|
||||
for i in range(len(dims_b))
|
||||
]
|
||||
|
||||
left_dims_a = [i for i in range(dim_count_a) if (i not in dims_a)]
|
||||
left_dims_b = [i for i in range(dim_count_b) if (i not in dims_b)]
|
||||
|
||||
new_input_a = opset9.permute(g, input_a, left_dims_a + dims_a)
|
||||
new_input_b = opset9.permute(g, input_b, dims_b + left_dims_b)
|
||||
|
||||
input_shape = g.op("Shape", new_input_a)
|
||||
left_sizes_a = symbolic_helper._slice_helper(
|
||||
g, input_shape, axes=[0], starts=[0], ends=[len(left_dims_a)]
|
||||
)
|
||||
shape_sizes = [
|
||||
left_sizes_a,
|
||||
g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
|
||||
]
|
||||
output_a = opset9._reshape_from_tensor(g, new_input_a, shape_sizes)
|
||||
|
||||
input_shape = g.op("Shape", output_a)
|
||||
slices = symbolic_helper._slice_helper(
|
||||
g, input_shape, axes=[0], starts=[-1], ends=[sys.maxsize]
|
||||
)
|
||||
shape_sizes = [
|
||||
g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
|
||||
slices,
|
||||
]
|
||||
output_a = opset9._reshape_from_tensor(g, new_input_a, shape_sizes)
|
||||
|
||||
input_shape = g.op("Shape", new_input_b)
|
||||
left_sizes_b = symbolic_helper._slice_helper(
|
||||
g, input_shape, axes=[0], starts=[len(dims_b)], ends=[sys.maxsize]
|
||||
)
|
||||
slices = symbolic_helper._slice_helper(
|
||||
g, input_shape, axes=[0], starts=[0], ends=[len(dims_b)]
|
||||
)
|
||||
shape_sizes = [
|
||||
slices,
|
||||
g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
|
||||
]
|
||||
output_b = opset9._reshape_from_tensor(g, new_input_b, shape_sizes)
|
||||
|
||||
input_shape = g.op("Shape", output_b)
|
||||
slices = symbolic_helper._slice_helper(
|
||||
g, input_shape, axes=[0], starts=[-1], ends=[sys.maxsize]
|
||||
)
|
||||
shape_sizes = [
|
||||
g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
|
||||
slices,
|
||||
]
|
||||
output_b = opset9._reshape_from_tensor(g, new_input_b, shape_sizes)
|
||||
|
||||
output = einsum(g, "ij,jk->ik", g.op("prim::ListConstruct", *[output_a, output_b]))
|
||||
|
||||
shape_sizes = [left_sizes_a, left_sizes_b]
|
||||
return opset9._reshape_from_tensor(g, output, shape_sizes)
|
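The permute/reshape/Einsum plan spelled out above is the standard reduction of tensordot to one matrix multiply. Here is a hedged plain-PyTorch sketch of the same plan; `tensordot_reference` is illustrative, not exporter code.

```python
import torch
from math import prod

def tensordot_reference(a, b, dims_a, dims_b):
    left_a = [i for i in range(a.dim()) if i not in dims_a]
    left_b = [i for i in range(b.dim()) if i not in dims_b]
    # Move contracted axes of `a` to the back and of `b` to the front,
    # flatten both to 2-D, multiply, then restore the free axes.
    a2 = a.permute(*left_a, *dims_a).reshape(prod(a.size(i) for i in left_a), -1)
    b2 = b.permute(*dims_b, *left_b).reshape(-1, prod(b.size(i) for i in left_b))
    out = a2 @ b2
    return out.reshape(*[a.size(i) for i in left_a], *[b.size(i) for i in left_b])

a, b = torch.randn(2, 3, 4), torch.randn(4, 3, 5)
expected = torch.tensordot(a, b, dims=([1, 2], [1, 0]))
assert torch.allclose(tensordot_reference(a, b, [1, 2], [1, 0]), expected, atol=1e-5)
```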
||||
from torch.onnx._internal.torchscript_exporter.symbolic_opset12 import * # noqa: F401,F403
|
||||
|
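Since the public `torch.onnx.symbolic_opset12` module now only star-imports the relocated implementation, existing imports should keep resolving. A speculative check, assuming the internal module re-exports the same public names:

```python
from torch.onnx import symbolic_opset12 as public_opset12
from torch.onnx._internal.torchscript_exporter import symbolic_opset12 as internal_opset12

# The shim re-binds the same function objects, so old call sites are unaffected.
assert public_opset12.celu is internal_opset12.celu
```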
File diff suppressed because it is too large
@@ -1,291 +1,8 @@
|
||||
# mypy: allow-untyped-defs
|
||||
# mypy: disable-error-code=arg-type
|
||||
"""This file exports ONNX ops for opset 14.
|
||||
"""Backward compatibility module for torch.onnx.symbolic_opset14."""
|
||||
|
||||
Note [ONNX operators that are added/updated in opset 14]
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
New operators:
|
||||
HardSwish, Trilu
|
||||
|
||||
Updated operators:
|
||||
Reshape
|
||||
Add, Sub, Mul, Div
|
||||
GRU, LSTM, RNN
|
||||
BatchNorm, Cumsum, Relu
|
||||
"""
|
||||
|
||||
# EDITING THIS FILE? READ THIS FIRST!
|
||||
# see Note [Edit Symbolic Files] in README.md
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
|
||||
import torch
|
||||
from torch.onnx import _constants, _type_utils, symbolic_helper
|
||||
from torch.onnx._globals import GLOBALS
|
||||
from torch.onnx._internal import jit_utils, registration
|
||||
__all__: list[str] = []
|
||||
|
||||
|
||||
__all__ = [
|
||||
"hardswish",
|
||||
"tril",
|
||||
"triu",
|
||||
"reshape",
|
||||
"batch_norm",
|
||||
"quantized_hardswish",
|
||||
"scaled_dot_product_attention",
|
||||
]
|
||||
|
||||
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=14)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::hardswish")
|
||||
@symbolic_helper.parse_args("v")
|
||||
def hardswish(g: jit_utils.GraphContext, self):
|
||||
return g.op("HardSwish", self)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::tril")
|
||||
def tril(g: jit_utils.GraphContext, self, diagonal, out=None):
|
||||
return g.op("Trilu", self, diagonal, upper_i=0)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::triu")
|
||||
def triu(g: jit_utils.GraphContext, self, diagonal, out=None):
|
||||
return g.op("Trilu", self, diagonal, upper_i=1)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::reshape")
|
||||
@symbolic_helper.quantized_args(True)
|
||||
@symbolic_helper.parse_args("v", "v")
|
||||
def reshape(g: jit_utils.GraphContext, self, shape):
|
||||
# NOTE: Due to bug in ORT https://github.com/microsoft/onnxruntime/issues/10664
|
||||
# Reshape export cannot utilize the new allowzero attribute introduced in opset 14.
|
||||
return symbolic_helper._reshape_helper(g, self, shape, allowzero=0)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::batch_norm")
|
||||
@symbolic_helper.parse_args("v", "v", "v", "v", "v", "i", "f", "f", "i")
|
||||
def batch_norm(
|
||||
g: jit_utils.GraphContext,
|
||||
input,
|
||||
weight,
|
||||
bias,
|
||||
running_mean,
|
||||
running_var,
|
||||
training,
|
||||
momentum,
|
||||
eps,
|
||||
cudnn_enabled,
|
||||
):
|
||||
if (
|
||||
torch.is_autocast_enabled()
|
||||
and not symbolic_helper.args_have_same_dtype(
|
||||
[input, weight, bias, running_mean, running_var]
|
||||
)
|
||||
and GLOBALS.export_onnx_opset_version < 15
|
||||
):
|
||||
return symbolic_helper._onnx_opset_unsupported_detailed(
|
||||
"BatchNormalization",
|
||||
14,
|
||||
15,
|
||||
"All input tensors must have the same `dtype`."
|
||||
" Turn off Autocast or export using opset version 15.",
|
||||
input,
|
||||
)
|
||||
|
||||
symbolic_helper.check_training_mode(training, "batch_norm")
|
||||
weight, bias, running_mean, running_var = symbolic_helper._batchnorm_helper(
|
||||
g, input, weight, bias, running_mean, running_var
|
||||
)
|
||||
out = g.op(
|
||||
"BatchNormalization",
|
||||
input,
|
||||
weight,
|
||||
bias,
|
||||
running_mean,
|
||||
running_var,
|
||||
epsilon_f=eps,
|
||||
momentum_f=1 - momentum,
|
||||
training_mode_i=0 if not training else 1,
|
||||
outputs=1 if not training else 3,
|
||||
)
|
||||
if not training:
|
||||
return out
|
||||
else:
|
||||
res, new_running_mean, new_running_var = out
|
||||
new_running_mean.setType(running_mean.type())
|
||||
new_running_var.setType(running_var.type())
|
||||
return res
|
||||
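One detail worth calling out in the `batch_norm` symbolic above is `momentum_f=1 - momentum`: ONNX BatchNormalization applies its momentum to the old running statistic, while PyTorch applies it to the new batch statistic. A small arithmetic sketch of the two conventions:

```python
import torch

torch_momentum = 0.1
onnx_momentum = 1 - torch_momentum

running, batch = torch.tensor(2.0), torch.tensor(5.0)
# PyTorch: running_new = (1 - m) * running_old + m * batch_stat
torch_update = (1 - torch_momentum) * running + torch_momentum * batch
# ONNX:    running_new = m * running_old + (1 - m) * batch_stat
onnx_update = onnx_momentum * running + (1 - onnx_momentum) * batch
assert torch.allclose(torch_update, onnx_update)
```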
|
||||
|
||||
@_onnx_symbolic("quantized::hardswish")
|
||||
def quantized_hardswish(g: jit_utils.GraphContext, x, op_scale, op_zero_point):
|
||||
x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
|
||||
|
||||
output = hardswish(g, x)
|
||||
|
||||
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
|
||||
|
||||
|
||||
# Ported from
|
||||
# https://github.com/microsoft/onnxscript/blob/6b1b81700b4523f31d8c6d3321e5d8ef5d42b764/onnxscript/function_libs/torch_aten/ops/nn.py#L1504
|
||||
# aten_scaled_dot_product_attention
|
||||
# NOTE: Need op.Trilu
|
||||
@_onnx_symbolic("aten::scaled_dot_product_attention")
|
||||
@symbolic_helper.parse_args("v", "v", "v", "v", "f", "b", "v", "b")
|
||||
def scaled_dot_product_attention(
|
||||
g: jit_utils.GraphContext,
|
||||
query: torch._C.Value,
|
||||
key: torch._C.Value,
|
||||
value: torch._C.Value,
|
||||
attn_mask: torch._C.Value | None = None,
|
||||
dropout_p: float = 0.0,
|
||||
is_causal: bool = False,
|
||||
scale: torch._C.Value | None = None,
|
||||
enable_gqa: bool = False,
|
||||
):
|
||||
assert (not is_causal) or (is_causal and symbolic_helper._is_none(attn_mask)), (
|
||||
"is_causal and attn_mask cannot be set at the same time"
|
||||
)
|
||||
assert not enable_gqa, (
|
||||
"conversion of scaled_dot_product_attention not implemented if enable_gqa is True"
|
||||
)
|
||||
|
||||
if symbolic_helper._is_none(scale):
|
||||
scale = _attention_scale(g, query)
|
||||
|
||||
if is_causal:
|
||||
attn_mask = _causal_attention_mask(g, query, key)
|
||||
|
||||
# Swap the last two axes of key
|
||||
# NOTE: onnx-script has different logic here, because the attribute perms in
|
||||
# transpose needs list of ints
|
||||
key_shape_builtin = symbolic_helper._get_tensor_rank(key)
|
||||
key_transposed_axes = list(range(key_shape_builtin))
|
||||
key_transposed_axes[-1], key_transposed_axes[-2] = (
|
||||
key_transposed_axes[-2],
|
||||
key_transposed_axes[-1],
|
||||
)
|
||||
key_transposed = g.op("Transpose", key, perm_i=key_transposed_axes)
|
||||
|
||||
# https://github.com/pytorch/pytorch/blob/12da0c70378b5be9135c6fda62a9863bce4a4818/aten/src/ATen/native/transformers/attention.cpp#L653
|
||||
# Scale q, k before matmul for stability see https://tinyurl.com/sudb9s96 for math
|
||||
query_scaled = g.op("Mul", query, g.op("Sqrt", scale))
|
||||
key_transposed_scaled = g.op("Mul", key_transposed, g.op("Sqrt", scale))
|
||||
mul_qk = g.op("MatMul", query_scaled, key_transposed_scaled)
|
||||
|
||||
if symbolic_helper._is_none(attn_mask):
|
||||
mul_qk_add = mul_qk
|
||||
attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1)
|
||||
elif (
|
||||
_type_utils.JitScalarType.from_value(attn_mask)
|
||||
== _type_utils.JitScalarType.BOOL
|
||||
):
|
||||
# Turn the Boolean mask to float: attn_mask.masked_fill(not attn_mask, -float('inf'))
|
||||
const_zero = g.op("Constant", value_t=torch.tensor([0.0]))
|
||||
const_neg_inf = g.op("Constant", value_t=torch.tensor([-float("inf")]))
|
||||
attn_mask = g.op("Where", attn_mask, const_zero, const_neg_inf)
|
||||
mul_qk_add = g.op("Add", mul_qk, attn_mask)
|
||||
attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1)
|
||||
# When using scaled dot product attention with a boolean mask, the softmax operation might return NaN values
|
||||
# due to the presence of -inf in an entire row (padding tokens), resulting in 0/0 (NaN) in the softmax output.
|
||||
# This is because there's no safe softmax imp in ONNX, so we need to handle NaN values explicitly to match
|
||||
# the behavior of PyTorch with boolean masks.
|
||||
attn_weight = g.op("Where", g.op("IsNaN", attn_weight), const_zero, attn_weight)
|
||||
elif _type_utils.JitScalarType.from_value(attn_mask) in (
|
||||
_type_utils.JitScalarType.FLOAT,
|
||||
_type_utils.JitScalarType.HALF,
|
||||
_type_utils.JitScalarType.BFLOAT16,
|
||||
):
|
||||
mul_qk_add = g.op("Add", mul_qk, attn_mask)
|
||||
attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported type for attn_mask: {_type_utils.JitScalarType.from_value(attn_mask)}"
|
||||
)
|
||||
|
||||
if dropout_p != 0:
|
||||
attn_weight = g.op(
|
||||
"Dropout",
|
||||
attn_weight,
|
||||
g.op("Constant", value_t=torch.tensor(dropout_p, dtype=torch.float)),
|
||||
)
|
||||
|
||||
return g.op("MatMul", attn_weight, value)
|
||||
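The graph built above is the textbook attention computation, with the scale split as `sqrt(scale)` on both `q` and the transposed `k` for numerical stability. A minimal eager-mode sketch of the same math; shapes and tolerance are illustrative.

```python
import torch
import torch.nn.functional as F

q, k, v = (torch.randn(1, 2, 4, 8) for _ in range(3))

# Default scale := 1 / sqrt(E); apply sqrt(scale) to each operand before matmul.
scale = 1.0 / q.size(-1) ** 0.5
logits = (q * scale**0.5) @ (k.transpose(-2, -1) * scale**0.5)
manual = torch.softmax(logits, dim=-1) @ v

assert torch.allclose(manual, F.scaled_dot_product_attention(q, k, v), atol=1e-5)
```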
|
||||
|
||||
def _attention_scale(
|
||||
g: jit_utils.GraphContext, query: torch._C.Value
|
||||
) -> torch._C.Value:
|
||||
"""Calculate the scale factor for the attention result.
|
||||
|
||||
Args:
|
||||
query: Tensor of shape [..., L, E]
|
||||
|
||||
Returns:
|
||||
Scalar scale factor := 1 / math.sqrt(query.size(-1))
|
||||
"""
|
||||
query_shape = g.op("Shape", query)
|
||||
query_shape_last = g.op(
|
||||
"Slice",
|
||||
query_shape,
|
||||
g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)),
|
||||
g.op(
|
||||
"Constant", value_t=torch.tensor([_constants.INT64_MAX], dtype=torch.int64)
|
||||
),
|
||||
)
|
||||
embedding_size = g.op(
|
||||
"Cast",
|
||||
query_shape_last,
|
||||
to_i=_type_utils.JitScalarType.from_value(query).onnx_type(),
|
||||
)
|
||||
const_one = g.op("Constant", value_t=torch.tensor([1.0], dtype=torch.float))
|
||||
scale = g.op("Div", const_one, g.op("Sqrt", embedding_size))
|
||||
# Add a Cast to convert the scale back to original type
|
||||
scale = g.op(
|
||||
"Cast",
|
||||
scale,
|
||||
to_i=_type_utils.JitScalarType.from_value(query).onnx_type(),
|
||||
)
|
||||
return scale
|
||||
|
||||
|
||||
def _causal_attention_mask(
|
||||
g: jit_utils.GraphContext, query: torch._C.Value, key: torch._C.Value
|
||||
) -> torch._C.Value:
|
||||
"""Create a causal mask for the given query and key tensors.
|
||||
|
||||
Equivalent to::
|
||||
mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
|
||||
attn_mask = torch.zeros(L, S, dtype=torch.float)
|
||||
attn_mask = attn_mask.masked_fill(not mask, -float("inf"))
|
||||
|
||||
Args:
|
||||
query: Tensor of shape [..., L, E]
|
||||
key: Tensor of shape [..., S, E]
|
||||
|
||||
Returns:
|
||||
Tensor of shape [L, S]
|
||||
"""
|
||||
|
||||
query_shape = g.op("Shape", query)
|
||||
key_shape = g.op("Shape", key)
|
||||
|
||||
last_idx = g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))
|
||||
second_last_idx = g.op("Constant", value_t=torch.tensor([-2], dtype=torch.int64))
|
||||
target_length = g.op("Slice", query_shape, second_last_idx, last_idx)
|
||||
source_length = g.op("Slice", key_shape, second_last_idx, last_idx)
|
||||
# attn_mask = torch.ones(L, S) := {
|
||||
size = g.op("Concat", target_length, source_length, axis_i=0)
|
||||
const_one = g.op("Constant", value_t=torch.tensor([1.0]))
|
||||
attn_mask = g.op("Expand", const_one, size)
|
||||
# }
|
||||
attn_mask = g.op("Trilu", attn_mask, upper_i=0)
|
||||
# The causal mask has 0s in the lower triangle and -inf in the upper triangle.
|
||||
const_zero = g.op("Constant", value_t=torch.tensor([0.0]))
|
||||
const_neg_inf = g.op("Constant", value_t=torch.tensor([-float("inf")]))
|
||||
attn_mask = g.op(
|
||||
"Where", g.op("Equal", attn_mask, const_zero), const_neg_inf, const_zero
|
||||
)
|
||||
return attn_mask
|
||||
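The Expand/Trilu/Where sequence above produces the usual additive causal mask: zeros on and below the diagonal, `-inf` strictly above it. A short eager-mode sketch of the same construction:

```python
import torch

L, S = 4, 6
keep = torch.ones(L, S).tril(diagonal=0)            # what Trilu(upper=0) yields
additive = torch.zeros(L, S).masked_fill(keep == 0, float("-inf"))

assert additive[0, 0] == 0 and torch.isinf(additive[0, 1])
```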
from torch.onnx._internal.torchscript_exporter.symbolic_opset14 import * # noqa: F401,F403
|
||||
|
@@ -1,80 +1,8 @@
|
||||
# mypy: allow-untyped-defs
|
||||
"""This file exports ONNX ops for opset 15.
|
||||
"""Backward compatibility module for torch.onnx.symbolic_opset15."""
|
||||
|
||||
Note [ONNX operators that are added/updated in opset 15]
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
https://github.com/onnx/onnx/blob/master/docs/Changelog.md#version-15-of-the-default-onnx-operator-set
|
||||
New operators:
|
||||
Bernoulli
|
||||
CastLike
|
||||
Optional
|
||||
OptionalGetElement
|
||||
OptionalHasElement
|
||||
|
||||
Updated operators:
|
||||
BatchNormalization https://github.com/onnx/onnx/pull/3545
|
||||
Backwards compatible
|
||||
TODO: test coverage for mixed types inputs.
|
||||
Pow https://github.com/onnx/onnx/pull/3412
|
||||
Backwards compatible
|
||||
TODO: bfloat16 support.
|
||||
Shape https://github.com/onnx/onnx/pull/3580
|
||||
Backwards compatible
|
||||
TODO: optional start/end attribute.
|
||||
"""
|
||||
|
||||
# EDITING THIS FILE? READ THIS FIRST!
|
||||
# see Note [Edit Symbolic Files] in README.md
|
||||
|
||||
import functools
|
||||
|
||||
import torch
|
||||
from torch import _C
|
||||
from torch.onnx import symbolic_helper, symbolic_opset9 as opset9
|
||||
from torch.onnx._internal import jit_utils, registration
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=15)
|
||||
__all__: list[str] = []
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::__is_")
|
||||
def aten__is_(g: jit_utils.GraphContext, self, other):
|
||||
if symbolic_helper._is_none(other):
|
||||
if isinstance(self.type(), _C.OptionalType):
|
||||
none = g.op("OptionalHasElement", self)
|
||||
return g.op("Not", none)
|
||||
else:
|
||||
return g.op("Constant", value_t=torch.BoolTensor([0]))
|
||||
return opset9.eq(g, self, other)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::__isnot_")
|
||||
@opset9.wrap_logical_op_with_negation # type: ignore[has-type]
|
||||
def aten__isnot_(g: jit_utils.GraphContext, self, other):
|
||||
return aten__is_(g, self, other)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::bernoulli")
|
||||
def bernoulli(g: jit_utils.GraphContext, input, p=None, generator=None, out=None):
|
||||
if out is not None and not symbolic_helper._is_none(out):
|
||||
symbolic_helper._unimplemented(
|
||||
"Bernoulli", "out parameter is not supported for bernoulli", input
|
||||
)
|
||||
if generator is not None and not symbolic_helper._is_none(generator):
|
||||
symbolic_helper._unimplemented(
|
||||
"Bernoulli", "generator is not supported for bernoulli", input
|
||||
)
|
||||
if p is None or symbolic_helper._is_none(p):
|
||||
return g.op("Bernoulli", input)
|
||||
return opset9.bernoulli(g, input, p, generator, out)
|
||||
|
||||
|
||||
@_onnx_symbolic("prim::unchecked_cast")
|
||||
def prim_unchecked_cast(g: jit_utils.GraphContext, self):
|
||||
# exists to refine the type of the Value
|
||||
# if x is Optional[Tensor], unchecked_cast will cast
|
||||
# x to Tensor, so the rest of the graph knows that x is a Tensor.
|
||||
if isinstance(self.type(), _C.OptionalType):
|
||||
return g.op("OptionalGetElement", self)
|
||||
|
||||
return self
|
||||
from torch.onnx._internal.torchscript_exporter.symbolic_opset15 import * # noqa: F401,F403
|
||||
|
@@ -1,185 +1,8 @@
|
||||
# mypy: allow-untyped-defs
|
||||
"""This file exports ONNX ops for opset 16.
|
||||
"""Backward compatibility module for torch.onnx.symbolic_opset16."""
|
||||
|
||||
Note [ONNX Operators that are added/updated in opset 16]
|
||||
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-16-of-the-default-onnx-operator-set
|
||||
New operators:
|
||||
GridSample https://github.com/onnx/onnx/pull/3557
|
||||
|
||||
Updated operators:
|
||||
Identity
|
||||
If
|
||||
LeakyRelu
|
||||
Loop
|
||||
PRelu
|
||||
RoiAlign
|
||||
Scan
|
||||
ScatterElements
|
||||
ScatterND
|
||||
Where
|
||||
GreaterOrEqual
|
||||
LessOrEqual
|
||||
"""
|
||||
|
||||
# EDITING THIS FILE? READ THIS FIRST!
|
||||
# see Note [Edit Symbolic Files] in README.md
|
||||
|
||||
import functools
|
||||
|
||||
import torch
|
||||
from torch.nn.functional import (
|
||||
GRID_SAMPLE_INTERPOLATION_MODES,
|
||||
GRID_SAMPLE_PADDING_MODES,
|
||||
)
|
||||
from torch.onnx import _type_utils, errors, symbolic_helper, utils
|
||||
from torch.onnx._internal import jit_utils, registration
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=16)
|
||||
__all__: list[str] = []
|
||||
|
||||
|
||||
# note (mkozuki): Why `grid_sampler` instead of `grid_sample`?
|
||||
# Because `torch.nn.functional.grid_sample` calls `torch.grid_sampler`.
|
||||
@_onnx_symbolic("aten::grid_sampler")
|
||||
@symbolic_helper.parse_args("v", "v", "i", "i", "b")
|
||||
def grid_sampler(
|
||||
g: jit_utils.GraphContext,
|
||||
input,
|
||||
grid,
|
||||
mode_enum,
|
||||
padding_mode_enum,
|
||||
align_corners,
|
||||
):
|
||||
# Check the input and grid tensor rank beforehand.
|
||||
if symbolic_helper._get_tensor_rank(input) == 5:
|
||||
return symbolic_helper._onnx_unsupported("GridSample with 5D volumetric input")
|
||||
mode_s = {v: k for k, v in GRID_SAMPLE_INTERPOLATION_MODES.items()}[mode_enum] # type: ignore[call-arg]
|
||||
padding_mode_s = {v: k for k, v in GRID_SAMPLE_PADDING_MODES.items()}[ # type: ignore[call-arg]
|
||||
padding_mode_enum
|
||||
]
|
||||
return g.op(
|
||||
"GridSample",
|
||||
input,
|
||||
grid,
|
||||
align_corners_i=int(align_corners),
|
||||
mode_s=mode_s,
|
||||
padding_mode_s=padding_mode_s,
|
||||
)
|
||||
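For reference, the call pattern that reaches the `grid_sampler` symbolic above is `torch.nn.functional.grid_sample` on a 4-D input with a normalized flow field; the sampling grid below is illustrative.

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 4, 4)
# Flow field in [-1, 1] coordinates, shape (N, H_out, W_out, 2); with
# align_corners=True the four corners map exactly onto the corner pixels.
grid = torch.tensor([[[[-1.0, -1.0], [1.0, -1.0]],
                      [[-1.0,  1.0], [1.0,  1.0]]]])
out = F.grid_sample(x, grid, mode="bilinear", padding_mode="zeros", align_corners=True)
assert torch.allclose(out[0, 0], x[0, 0][[0, 0, 3, 3], [0, 3, 0, 3]].reshape(2, 2))
```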
|
||||
|
||||
@_onnx_symbolic("aten::scatter_add")
|
||||
@symbolic_helper.parse_args("v", "i", "v", "v")
|
||||
def scatter_add(g: jit_utils.GraphContext, self, dim, index, src):
|
||||
src_type = _type_utils.JitScalarType.from_value(
|
||||
src, _type_utils.JitScalarType.UNDEFINED
|
||||
)
|
||||
src_sizes = symbolic_helper._get_tensor_sizes(src)
|
||||
index_sizes = symbolic_helper._get_tensor_sizes(index)
|
||||
|
||||
if len(src_sizes) != len(index_sizes):
|
||||
return symbolic_helper._unimplemented(
|
||||
"scatter_add",
|
||||
f"`index` ({index_sizes}) should have the same dimensionality as `src` ({src_sizes})",
|
||||
)
|
||||
|
||||
# PyTorch only allows index shape <= src shape, so we can only consider
|
||||
# taking index as subset size to src, like PyTorch does. When sizes for src
|
||||
# and index are not matched or there are dynamic axes, we take index shape to
|
||||
# slice src to accommodate.
|
||||
if src_sizes != index_sizes or None in index_sizes:
|
||||
adjusted_shape = g.op("Shape", index)
|
||||
starts = g.op("Constant", value_t=torch.tensor([0] * len(index_sizes)))
|
||||
src = g.op("Slice", src, starts, adjusted_shape)
|
||||
|
||||
src = symbolic_helper._maybe_get_scalar(src)
|
||||
if symbolic_helper._is_value(src):
|
||||
return g.op("ScatterElements", self, index, src, axis_i=dim, reduction_s="add")
|
||||
else:
|
||||
# Check if scalar "src" has same type as self (PyTorch allows different
|
||||
# type for scalar src (but not when src is tensor)). If not, insert Cast node.
|
||||
if _type_utils.JitScalarType.from_value(self) != src_type:
|
||||
src = g.op(
|
||||
"Cast",
|
||||
src,
|
||||
to_i=_type_utils.JitScalarType.from_value(self).onnx_type(),
|
||||
)
|
||||
|
||||
return g.op(
|
||||
"ScatterElements",
|
||||
self,
|
||||
index,
|
||||
src,
|
||||
axis_i=dim,
|
||||
reduction_s="add",
|
||||
)
|
||||
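ScatterElements with `reduction="add"` matches the accumulation rule of `torch.scatter_add`; a small sketch of that rule for `dim=0` (shapes are illustrative).

```python
import torch

dst = torch.zeros(3, 4)
index = torch.tensor([[0, 1, 2, 0]])
src = torch.ones(1, 4)

out = dst.scatter_add(0, index, src)  # out[index[i][j]][j] += src[i][j]

expected = torch.zeros(3, 4)
for i in range(index.size(0)):
    for j in range(index.size(1)):
        expected[index[i, j], j] += src[i, j]
assert torch.equal(out, expected)
```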
|
||||
|
||||
@_onnx_symbolic("aten::scatter_reduce")
|
||||
@symbolic_helper.parse_args("v", "i", "v", "v", "s", "b")
|
||||
def scatter_reduce(
|
||||
g: jit_utils.GraphContext,
|
||||
self: torch._C.Value,
|
||||
dim: int,
|
||||
index: torch._C.Value,
|
||||
src: torch._C.Value,
|
||||
reduce: str,
|
||||
include_self: bool,
|
||||
):
|
||||
if reduce == "mean":
|
||||
raise errors.OnnxExporterError(
|
||||
"ONNX does not support mean reduction for scatter_reduce"
|
||||
)
|
||||
if not include_self:
|
||||
raise errors.OnnxExporterError(
|
||||
"ONNX does not support include_self=False for scatter_reduce"
|
||||
)
|
||||
|
||||
reduce_mode = { # convert torch string name to onnx string name
|
||||
"mean": "none", # 'mean' doesn't support in ONNX 1.14 definition
|
||||
"sum": "add",
|
||||
"prod": "mul",
|
||||
"amin": "min",
|
||||
"amax": "max",
|
||||
}
|
||||
onnx_reduce = reduce_mode[reduce]
|
||||
|
||||
self_rank = g.op("Size", g.op("Shape", self))
|
||||
|
||||
# if self_rank == 0: # assert (index_rank == 0 and rank_src == 0)
|
||||
self_rank_is_zero = g.op(
|
||||
"Equal", self_rank, g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64))
|
||||
)
|
||||
if_op, (if_context, else_context), _ = jit_utils.add_op_with_blocks(
|
||||
g, "If", self_rank_is_zero, n_blocks=2, outputs=3
|
||||
)
|
||||
neg_1 = if_context.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))
|
||||
|
||||
self_reshape = if_context.op("Reshape", self, neg_1)
|
||||
utils._add_output_to_block(if_context.block, self_reshape)
|
||||
index_reshape = if_context.op("Reshape", index, neg_1)
|
||||
utils._add_output_to_block(if_context.block, index_reshape)
|
||||
src_reshape = if_context.op("Reshape", src, neg_1)
|
||||
utils._add_output_to_block(if_context.block, src_reshape)
|
||||
|
||||
self_identity = else_context.op("Identity", self)
|
||||
utils._add_output_to_block(else_context.block, self_identity)
|
||||
index_identitye = else_context.op("Identity", index)
|
||||
utils._add_output_to_block(else_context.block, index_identitye)
|
||||
src_identity = else_context.op("Identity", src)
|
||||
utils._add_output_to_block(else_context.block, src_identity)
|
||||
|
||||
result = g.op("ScatterElements", *if_op, axis_i=dim, reduction_s=onnx_reduce)
|
||||
|
||||
# if self_rank == 0:
|
||||
if_op, (if_context, else_context), _ = jit_utils.add_op_with_blocks(
|
||||
g, "If", self_rank_is_zero, n_blocks=2, outputs=1
|
||||
)
|
||||
result_squeezed = if_context.op("Squeeze", result)
|
||||
utils._add_output_to_block(if_context.block, result_squeezed)
|
||||
result_identity = else_context.op("Identity", result)
|
||||
utils._add_output_to_block(else_context.block, result_identity)
|
||||
result_final = if_op.node().output()
|
||||
|
||||
return result_final
|
||||
from torch.onnx._internal.torchscript_exporter.symbolic_opset16 import * # noqa: F401,F403
|
||||
|
@@ -1,239 +1,8 @@
|
||||
# mypy: allow-untyped-defs
|
||||
# mypy: disable-error-code=arg-type
|
||||
"""This file exports ONNX ops for opset 17.
|
||||
"""Backward compatibility module for torch.onnx.symbolic_opset17."""
|
||||
|
||||
Note [ONNX Operators that are added/updated in opset 17]
|
||||
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-17-of-the-default-onnx-operator-set
|
||||
New operators:
|
||||
BlackmanWindow
|
||||
DFT
|
||||
HammingWindow
|
||||
HannWindow
|
||||
LayerNormalization
|
||||
MelWeightMatrix
|
||||
STFT
|
||||
SequenceMap
|
||||
"""
|
||||
|
||||
import functools
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import _C
|
||||
from torch.onnx import _type_utils, errors, symbolic_helper
|
||||
from torch.onnx._internal import jit_utils, registration
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
# EDITING THIS FILE? READ THIS FIRST!
|
||||
# see Note [Edit Symbolic Files] in README.md
|
||||
__all__: list[str] = []
|
||||
|
||||
__all__ = ["layer_norm", "stft", "quantized_layer_norm"]
|
||||
|
||||
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=17)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::layer_norm")
|
||||
@symbolic_helper.parse_args("v", "is", "v", "v", "f", "none")
|
||||
def layer_norm(
|
||||
g: jit_utils.GraphContext,
|
||||
input: _C.Value,
|
||||
normalized_shape: Sequence[int],
|
||||
weight: _C.Value,
|
||||
bias: _C.Value,
|
||||
eps: float,
|
||||
cudnn_enable: bool,
|
||||
):
|
||||
# normalized_shape: input shape from an expected input of size
|
||||
# axis: The first normalization dimension.
|
||||
# layer_norm normalizes on the last D dimensions,
|
||||
# where D is the size of normalized_shape
|
||||
axis = -len(normalized_shape)
|
||||
scalar_type = _type_utils.JitScalarType.from_value(
|
||||
input, _type_utils.JitScalarType.FLOAT
|
||||
)
|
||||
dtype = scalar_type.dtype()
|
||||
if symbolic_helper._is_none(weight):
|
||||
weight_value = torch.ones(normalized_shape, dtype=dtype)
|
||||
weight = g.op("Constant", value_t=weight_value)
|
||||
if symbolic_helper._is_none(bias):
|
||||
bias_value = torch.zeros(normalized_shape, dtype=dtype)
|
||||
bias = g.op("Constant", value_t=bias_value)
|
||||
return g.op(
|
||||
"LayerNormalization",
|
||||
input,
|
||||
weight,
|
||||
bias,
|
||||
epsilon_f=eps,
|
||||
axis_i=axis,
|
||||
)
|
||||
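The `axis = -len(normalized_shape)` computed above is exactly the `axis` attribute of ONNX LayerNormalization: normalization runs over the trailing axes. A hedged eager-mode check of that reading, with illustrative shapes, eps and tolerance:

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 8)
normalized_shape = (8,)
weight, bias, eps = torch.randn(8), torch.randn(8), 1e-5

axis = -len(normalized_shape)            # normalize over the trailing axes
dims = tuple(range(axis, 0))
mean = x.mean(dim=dims, keepdim=True)
var = x.var(dim=dims, keepdim=True, unbiased=False)
manual = (x - mean) / torch.sqrt(var + eps) * weight + bias

assert torch.allclose(manual, F.layer_norm(x, normalized_shape, weight, bias, eps), atol=1e-6)
```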
|
||||
|
||||
@_onnx_symbolic("quantized::layer_norm")
|
||||
def quantized_layer_norm(
|
||||
g: jit_utils.GraphContext,
|
||||
x,
|
||||
normalized_shape,
|
||||
weight,
|
||||
bias,
|
||||
eps,
|
||||
op_scale,
|
||||
op_zero_point,
|
||||
):
|
||||
x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
|
||||
|
||||
output = layer_norm(g, x, normalized_shape, weight, bias, eps, False)
|
||||
|
||||
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
|
||||
|
||||
|
||||
def _compute_edge_sizes(n_fft, window_size):
|
||||
"""Helper function to compute the sizes of the edges (left and right)
|
||||
of a given window centered within an FFT size."""
|
||||
left = (n_fft - window_size) // 2
|
||||
right = n_fft - left - window_size
|
||||
return left, right
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::stft")
|
||||
@symbolic_helper.parse_args("v", "i", "i", "i", "v", "b", "b", "b", "b")
|
||||
def stft(
|
||||
g: jit_utils.GraphContext,
|
||||
input: _C.Value,
|
||||
n_fft: int,
|
||||
hop_length: Optional[int] = None,
|
||||
win_length: Optional[int] = None,
|
||||
window: Optional[_C.Value] = None,
|
||||
normalized: bool = False,
|
||||
onesided: Optional[bool] = True,
|
||||
return_complex: Optional[bool] = False,
|
||||
align_to_window: Optional[bool] = None,
|
||||
) -> _C.Value:
|
||||
"""Associates `torch.stft` with the `STFT` ONNX operator.
|
||||
Note that torch.stft calls _VF.stft, without centering or padding options.
|
||||
Hence, this function does not contain these two arguments.
|
||||
See torch.stft source code for more info.
|
||||
|
||||
Args:
|
||||
g: Graph to write the ONNX representation into
|
||||
input: Input tensor for the transformation
|
||||
n_fft: FFT size
|
||||
hop_length: Size of the hop. Defaults to `floor(n_fft // 4)`
|
||||
win_length: Size of the analysis window. Defaults to `n_fft`
|
||||
window: Analysis window. Defaults to a window of all ones
|
||||
normalized: Whether to return a normalized STFT
|
||||
onesided: Whether to return only half (+1) of the results, given the
|
||||
symmetry of the STFT
|
||||
return_complex: Whether to return the complex value (Note: Must be
|
||||
`False` or `None`)
|
||||
|
||||
Returns:
|
||||
op: Operator for torch.stft associated with STFT (ONNX)
|
||||
"""
|
||||
# Checks
|
||||
if return_complex:
|
||||
raise errors.SymbolicValueError(
|
||||
msg="STFT does not currently support complex types", value=input
|
||||
)
|
||||
|
||||
if align_to_window is not None:
|
||||
raise errors.SymbolicValueError(
|
||||
msg="STFT does not currently support the align_to_window option",
|
||||
value=input,
|
||||
) # TODO(#145944): add compatibility with align_to_window option.
|
||||
|
||||
# Get STFT sizes
|
||||
frame_step_value = hop_length if hop_length is not None else n_fft // 4
|
||||
frame_step_const = g.op(
|
||||
"Constant", value_t=torch.tensor(frame_step_value, dtype=torch.int64)
|
||||
)
|
||||
frame_length_const = g.op(
|
||||
"Constant", value_t=torch.tensor(n_fft, dtype=torch.int64)
|
||||
)
|
||||
|
||||
# Pre-process input if needed
|
||||
signal = input
|
||||
signal_rank = symbolic_helper._get_tensor_rank(signal)
|
||||
if signal_rank == 1:
|
||||
# Add batch dimension
|
||||
signal = g.op(
|
||||
"Unsqueeze",
|
||||
signal,
|
||||
g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)),
|
||||
)
|
||||
elif signal_rank is None or signal_rank > 2:
|
||||
raise errors.SymbolicValueError(
|
||||
msg="STFT can only take inputs of 1 [signal] or 2 [batch, signal] dimensions. "
|
||||
f"Current rank of signal is {signal_rank}, please reduce it.",
|
||||
value=input,
|
||||
)
|
||||
|
||||
# Get window and make sure it's the same size as `win_length` or `n_fft`
|
||||
n_win = symbolic_helper._get_tensor_dim_size(window, dim=0)
|
||||
if n_win is not None:
|
||||
win_length_default = win_length if win_length else n_fft
|
||||
assert n_win == win_length_default, (
|
||||
"Analysis window size must equal `win_length` or `n_fft`. "
|
||||
f"Please, set `win_length` or `n_fft` to match `window` size ({n_win})",
|
||||
)
|
||||
|
||||
# Center window around zeros if needed (required by ONNX's STFT)
|
||||
if n_win < n_fft:
|
||||
left, right = _compute_edge_sizes(n_fft, n_win)
|
||||
left_win = g.op("Constant", value_t=torch.zeros(left))
|
||||
right_win = g.op("Constant", value_t=torch.zeros(right))
|
||||
window = g.op("Concat", left_win, window, right_win, axis_i=0)
|
||||
|
||||
# Create window, if needed
|
||||
if symbolic_helper._is_none(window):
|
||||
if win_length:
|
||||
if win_length > n_fft:
|
||||
raise errors.SymbolicValueError(
|
||||
msg="The analysis window can't be longer than the size of the FFT. "
|
||||
f"Please set `win_length` ({win_length}) to `n_fft` ({n_fft}) or less.",
|
||||
value=input,
|
||||
)
|
||||
|
||||
# Center window, if needed
|
||||
left, right = _compute_edge_sizes(n_fft, win_length)
|
||||
torch_window = torch.hstack(
|
||||
(torch.zeros(left), torch.ones(win_length), torch.zeros(right))
|
||||
)
|
||||
else:
|
||||
# Rectangle window
|
||||
torch_window = torch.ones(n_fft)
|
||||
assert torch_window.shape[0] == n_fft
|
||||
window = g.op("Constant", value_t=torch_window)
|
||||
window = g.op(
|
||||
"Cast", window, to_i=_type_utils.JitScalarType.from_value(signal).onnx_type()
|
||||
)
|
||||
|
||||
# Run STFT
|
||||
result = g.op(
|
||||
"STFT",
|
||||
signal,
|
||||
frame_step_const,
|
||||
window,
|
||||
frame_length_const,
|
||||
onesided_i=1 if onesided is None or onesided else 0,
|
||||
)
|
||||
|
||||
# Transpose to mimic torch.stft's behavior
|
||||
result = g.op("Transpose", result, perm_i=[0, 2, 1, 3])
|
||||
|
||||
# Remove batch dimension, if needed
|
||||
if signal_rank == 1:
|
||||
result = g.op(
|
||||
"Squeeze",
|
||||
result,
|
||||
g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)),
|
||||
)
|
||||
|
||||
# Normalize, if needed
|
||||
if normalized:
|
||||
sqrt_nfft = torch.sqrt(torch.tensor(n_fft, dtype=signal.type().dtype()))
|
||||
result = g.op("Div", result, g.op("Constant", value_t=sqrt_nfft))
|
||||
|
||||
return result
|
||||
from torch.onnx._internal.torchscript_exporter.symbolic_opset17 import * # noqa: F401,F403
|
||||
|
@@ -1,265 +1,8 @@
|
||||
# mypy: allow-untyped-defs
|
||||
"""This file exports ONNX ops for opset 18.
|
||||
"""Backward compatibility module for torch.onnx.symbolic_opset18."""
|
||||
|
||||
Note [ONNX Operators that are added/updated in opset 18]
|
||||
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-18-of-the-default-onnx-operator-set
|
||||
New operators:
|
||||
BitwiseAnd
|
||||
CenterCropPad
|
||||
Col2Im
|
||||
Mish
|
||||
OptionalGetElement
|
||||
OptionalHasElement
|
||||
Pad
|
||||
Resize
|
||||
ScatterElements
|
||||
ScatterND
|
||||
Split
|
||||
"""
|
||||
|
||||
import functools
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import _C
|
||||
from torch.onnx import _type_utils, symbolic_helper, symbolic_opset9 as opset9
|
||||
from torch.onnx._internal import jit_utils, registration
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
# EDITING THIS FILE? READ THIS FIRST!
|
||||
# see Note [Edit Symbolic Files] in symbolic_helper.py
|
||||
__all__: list[str] = []
|
||||
|
||||
__all__ = [
|
||||
"col2im",
|
||||
]
|
||||
|
||||
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=18)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::__and_")
|
||||
@_onnx_symbolic("aten::bitwise_and")
|
||||
def __and_(g: jit_utils.GraphContext, self, other):
|
||||
# do type promotion (scalars don't seem to apply)
|
||||
args = [self, other]
|
||||
# type promotion doesn't happen with torch.bitwise_and(tensor, scalar)
|
||||
prom_args = [arg for arg in args if symbolic_helper._get_tensor_rank(arg)]
|
||||
if len(prom_args) == 0:
|
||||
prom_args = args
|
||||
promotion_jit_type = symbolic_helper._type_promote_from_values(*prom_args)
|
||||
self = symbolic_helper._maybe_cast_to_type(g, self, promotion_jit_type)
|
||||
other = symbolic_helper._maybe_cast_to_type(g, other, promotion_jit_type)
|
||||
if promotion_jit_type == _type_utils.JitScalarType.BOOL:
|
||||
return g.op("And", self, other)
|
||||
return g.op("BitwiseAnd", self, other)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::col2im")
|
||||
@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is")
|
||||
def col2im(
|
||||
g,
|
||||
input: _C.Value,
|
||||
output_size: _C.Value,
|
||||
kernel_size: _C.Value,
|
||||
dilation: Sequence[int],
|
||||
padding: Sequence[int],
|
||||
stride: Sequence[int],
|
||||
):
|
||||
# convert [i0, i1, ..., in] into [i0, i0, i1, i1, ..., in, in]
|
||||
adjusted_padding: list[int] = []
|
||||
for pad in padding:
|
||||
adjusted_padding.extend(pad for _ in range(2))
|
||||
|
||||
num_dimensional_axis = symbolic_helper._get_tensor_sizes(output_size)[0]
|
||||
if not adjusted_padding:
|
||||
adjusted_padding = [0, 0] * num_dimensional_axis
|
||||
|
||||
if not dilation:
|
||||
dilation = [1] * num_dimensional_axis
|
||||
|
||||
if not stride:
|
||||
stride = [1] * num_dimensional_axis
|
||||
|
||||
return g.op(
|
||||
"Col2Im",
|
||||
input,
|
||||
output_size,
|
||||
kernel_size,
|
||||
dilations_i=dilation,
|
||||
pads_i=adjusted_padding,
|
||||
strides_i=stride,
|
||||
)
|
||||
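In eager mode, `aten::col2im` is what `torch.nn.Fold` dispatches to, so a Fold call is the easiest way to exercise the symbolic above; the sizes below are illustrative.

```python
import torch

# (N, C * prod(kernel_size), L) with L = number of sliding blocks = 3 * 3
cols = torch.randn(1, 3 * 2 * 2, 9)
fold = torch.nn.Fold(output_size=(4, 4), kernel_size=(2, 2))
out = fold(cols)
assert out.shape == (1, 3, 4, 4)
```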
|
||||
|
||||
@_onnx_symbolic(
|
||||
"aten::mean", decorate=[symbolic_helper._apply_params("ReduceMean", "mean")]
|
||||
)
|
||||
@_onnx_symbolic(
|
||||
"aten::prod",
|
||||
decorate=[
|
||||
symbolic_helper._apply_params(
|
||||
"ReduceProd", "prod", allow_multi_dim_support=False
|
||||
)
|
||||
],
|
||||
)
|
||||
def _reduce_with_dtype(onnx_op: str, name: str, allow_multi_dim_support: bool = True):
|
||||
return symbolic_helper._reduce_with_dtype_helper(
|
||||
onnx_op, name, allow_multi_dim_support
|
||||
)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::native_layer_norm")
|
||||
@symbolic_helper.quantized_args(True, False, False, False)
|
||||
@symbolic_helper.parse_args("v", "is", "v", "v", "f")
|
||||
def _native_layer_norm(
|
||||
g: jit_utils.GraphContext,
|
||||
input: _C.Value,
|
||||
normalized_shape: Sequence[int],
|
||||
weight: _C.Value,
|
||||
bias: _C.Value,
|
||||
eps: float,
|
||||
) -> tuple[_C.Value, _C.Value, _C.Value]:
|
||||
return opset9.native_layer_norm(g, input, normalized_shape, weight, bias, eps)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::glu")
|
||||
@symbolic_helper.parse_args("v", "i")
|
||||
def _glu(g: jit_utils.GraphContext, input, dim):
|
||||
dim_size = symbolic_helper._get_tensor_dim_size(input, dim)
|
||||
if dim_size is not None:
|
||||
assert dim_size % 2 == 0
|
||||
|
||||
first, second = g.op("Split", input, axis_i=dim, num_outputs_i=2, outputs=2)
|
||||
return g.op("Mul", first, g.op("Sigmoid", second))
|
||||
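The Split/Sigmoid/Mul lowering above mirrors the definition of GLU; a short eager check of the same decomposition:

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 6)
# Split in two along dim, then first_half * sigmoid(second_half).
first, second = x.split(x.size(1) // 2, dim=1)
assert torch.allclose(first * torch.sigmoid(second), F.glu(x, dim=1))
```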
|
||||
|
||||
@_onnx_symbolic("aten::max")
|
||||
# torch.max (same for torch.min) actually has two interfaces smashed together:
|
||||
# torch.max(x, dim, keepdim) and torch.max(x, y)
|
||||
# TODO(justinchuby): Support multiple quantized args in output
|
||||
def max(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
|
||||
return symbolic_helper._max_helper(g, self, dim_or_y, keepdim)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::maximum")
|
||||
@symbolic_helper.quantized_args(True, True)
|
||||
def maximum(g: jit_utils.GraphContext, input, other):
|
||||
return max(g, input, dim_or_y=other)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::min")
|
||||
# TODO(justinchuby): Support multiple quantized args in output
|
||||
def min(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
|
||||
return symbolic_helper._min_helper(g, self, dim_or_y, keepdim)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::minimum")
|
||||
@symbolic_helper.quantized_args(True, True)
|
||||
def minimum(g: jit_utils.GraphContext, input, other):
|
||||
return min(g, input, dim_or_y=other)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::amax")
|
||||
@symbolic_helper.quantized_args(True)
|
||||
@symbolic_helper.parse_args("v", "is", "i")
|
||||
def amax(g: jit_utils.GraphContext, self, dim, keepdim):
|
||||
axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long))
|
||||
return g.op("ReduceMax", self, axes, keepdims_i=keepdim)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::amin")
|
||||
@symbolic_helper.quantized_args(True)
|
||||
@symbolic_helper.parse_args("v", "is", "i")
|
||||
def amin(g: jit_utils.GraphContext, self, dim, keepdim):
|
||||
axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long))
|
||||
return g.op("ReduceMin", self, axes, keepdims_i=keepdim)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::aminmax")
|
||||
@symbolic_helper.quantized_args(True)
|
||||
@symbolic_helper.parse_args("v", "v", "i")
|
||||
def aminmax(g: jit_utils.GraphContext, self, dim, keepdim):
|
||||
if not symbolic_helper._is_none(dim):
|
||||
dim = symbolic_helper._get_const(dim, "i", "dim")
|
||||
axes = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long))
|
||||
return g.op("ReduceMin", self, axes, keepdims_i=keepdim), g.op(
|
||||
"ReduceMax", self, axes, keepdims_i=keepdim
|
||||
)
|
||||
else:
|
||||
return g.op("ReduceMin", self, keepdims_i=keepdim), g.op(
|
||||
"ReduceMax", self, keepdims_i=keepdim
|
||||
)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::var_mean")
|
||||
def _var_mean(g: jit_utils.GraphContext, input, *args):
|
||||
if len(args) == 1:
|
||||
return symbolic_helper._var_mean_helper(g, input, None, args[0], None)
|
||||
else:
|
||||
return symbolic_helper._var_mean_helper(g, input, *args)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::logsumexp")
|
||||
@symbolic_helper.parse_args("v", "is", "i")
|
||||
def _logsumexp(g: jit_utils.GraphContext, input, dim, keepdim):
|
||||
if dim is None:
|
||||
return g.op("ReduceLogSumExp", input, keepdims_i=0)
|
||||
else:
|
||||
axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long))
|
||||
return g.op("ReduceLogSumExp", input, axes, keepdims_i=keepdim)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::linalg_matrix_norm")
|
||||
@symbolic_helper.parse_args("v", "v", "is", "b", "v")
|
||||
def _linalg_matrix_norm(
|
||||
g: jit_utils.GraphContext,
|
||||
self: torch._C.Value,
|
||||
ord: torch._C.Value,
|
||||
dim: list[int],
|
||||
keepdim: bool,
|
||||
dtype: torch._C.Value,
|
||||
):
|
||||
return opset9.linalg_matrix_norm(g, self, ord, dim, keepdim, dtype)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::embedding_bag")
|
||||
@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i")
|
||||
def embedding_bag(
|
||||
g: jit_utils.GraphContext,
|
||||
embedding_matrix,
|
||||
indices,
|
||||
offsets,
|
||||
scale_grad_by_freq,
|
||||
mode,
|
||||
sparse,
|
||||
per_sample_weights,
|
||||
include_last_offset,
|
||||
padding_idx,
|
||||
):
|
||||
return symbolic_helper._embedding_bag_helper(
|
||||
g,
|
||||
embedding_matrix,
|
||||
indices,
|
||||
offsets,
|
||||
scale_grad_by_freq,
|
||||
mode,
|
||||
sparse,
|
||||
per_sample_weights,
|
||||
include_last_offset,
|
||||
padding_idx,
|
||||
)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::linalg_vector_norm")
|
||||
@symbolic_helper.parse_args("v", "f", "is", "b", "v")
|
||||
def linalg_vector_norm(
|
||||
g: jit_utils.GraphContext,
|
||||
self: torch._C.Value,
|
||||
ord: float,
|
||||
dim: Optional[Sequence[int]],
|
||||
keepdim: bool,
|
||||
dtype: torch._C.Value,
|
||||
):
|
||||
return symbolic_helper._linalg_vector_norm_helper(g, self, ord, dim, keepdim, dtype)
|
||||
from torch.onnx._internal.torchscript_exporter.symbolic_opset18 import * # noqa: F401,F403
|
||||
|
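For context, a minimal sketch (not part of this diff) of what the backward-compatibility shim buys callers: the legacy public module path keeps importing and forwards the symbolic functions that now live under `torch.onnx._internal.torchscript_exporter`. This assumes a PyTorch build that already contains this refactor.

```python
# Illustrative check only; assumes a PyTorch build containing this refactor.
import torch.onnx.symbolic_opset18 as compat
import torch.onnx._internal.torchscript_exporter.symbolic_opset18 as internal

# Names forwarded by the star re-export follow the internal module's __all__ when it
# defines one; report anything that did not make it across the legacy path.
public_names = getattr(
    internal, "__all__", [n for n in dir(internal) if not n.startswith("_")]
)
missing = [name for name in public_names if not hasattr(compat, name)]
print("internal names not visible via the legacy path:", missing)
```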
@ -1,31 +1,8 @@
"""This file exports ONNX ops for opset 19.
"""Backward compatibility module for torch.onnx.symbolic_opset19."""

Note [ONNX Operators that are added/updated in opset 19]
from __future__ import annotations

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-19-of-the-default-onnx-operator-set
New operators:
    AveragePool
    Cast
    CastLike
    Constant
    DeformConv
    DequantizeLinear
    Equal
    Identity
    If
    Loop
    Pad
    QuantizeLinear
    Reshape
    Resize
    Scan
    Shape
    Size
"""


# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py

__all__: list[str] = []

from torch.onnx._internal.torchscript_exporter.symbolic_opset19 import *  # noqa: F401,F403
@ -1,92 +1,8 @@
# mypy: allow-untyped-defs
"""This file exports ONNX ops for opset 20.
"""Backward compatibility module for torch.onnx.symbolic_opset20."""

Note [ONNX Operators that are added/updated in opset 20]

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-20-of-the-default-onnx-operator-set
New operators:
    AffineGrid
    ConstantOfShape
    DFT
    Gelu
    GridSample
    ImageDecoder
    IsInf
    IsNaN
    ReduceMax
    ReduceMin
    RegexFullMatch
    StringConcat
    StringSplit
"""

import functools

import torch.nn.functional as F
from torch import _C
from torch.onnx import symbolic_helper
from torch.onnx._internal import jit_utils, registration
from __future__ import annotations


# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py
__all__: list[str] = []

__all__ = ["_grid_sampler", "_affine_grid_generator", "gelu"]


def convert_grid_sample_mode(mode_s):
    return (
        "linear" if mode_s == "bilinear" else "cubic" if mode_s == "bicubic" else mode_s
    )


_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=20)


@_onnx_symbolic("aten::grid_sampler")
@symbolic_helper.parse_args("v", "v", "i", "i", "b")
def _grid_sampler(
    g: jit_utils.GraphContext,
    input: _C.Value,
    grid: _C.Value,
    mode_enum: int,
    padding_mode_enum: int,
    align_corners: bool,
):
    mode_s = {v: k for k, v in F.GRID_SAMPLE_INTERPOLATION_MODES.items()}[mode_enum]  # type: ignore[call-arg, index]
    # mode string changes at https://onnx.ai/onnx/operators/text_diff_GridSample_16_20.html
    mode_s = convert_grid_sample_mode(mode_s)
    padding_mode_s = {v: k for k, v in F.GRID_SAMPLE_PADDING_MODES.items()}[  # type: ignore[call-arg, index]
        padding_mode_enum  # type: ignore[index]
    ]
    return g.op(
        "GridSample",
        input,
        grid,
        align_corners_i=int(align_corners),
        mode_s=mode_s,
        padding_mode_s=padding_mode_s,
    )


@_onnx_symbolic("aten::affine_grid_generator")
@symbolic_helper.parse_args("v", "v", "b")
def _affine_grid_generator(
    g: jit_utils.GraphContext,
    theta: _C.Value,
    size: _C.Value,
    align_corners: bool,
):
    return g.op(
        "AffineGrid",
        theta,
        size,
        align_corners_i=int(align_corners),
    )


@_onnx_symbolic("aten::gelu")
@symbolic_helper.parse_args("v", "s")
def gelu(g: jit_utils.GraphContext, self: _C.Value, approximate: str = "none"):
    return g.op("Gelu", self, approximate_s=approximate)

from torch.onnx._internal.torchscript_exporter.symbolic_opset20 import *  # noqa: F401,F403
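As an illustrative usage sketch (not part of this diff), exporting a small GELU module at opset 20 with the TorchScript-based exporter path exercises the `aten::gelu` symbolic above, which lowers directly to the ONNX `Gelu` operator. The module, the output file name, and the choice of `approximate="tanh"` are arbitrary examples.

```python
# Sketch only; assumes a PyTorch build with opset 20 support in the TorchScript exporter.
import torch


class TinyGelu(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.gelu(x, approximate="tanh")


torch.onnx.export(
    TinyGelu(),
    (torch.randn(2, 3),),
    "tiny_gelu_opset20.onnx",  # arbitrary output path
    opset_version=20,
    dynamo=False,  # select the TorchScript-based exporter that these symbolics serve
)
```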
@ -1,67 +1,8 @@
# mypy: allow-untyped-defs
"""
Note [ONNX operators that are added/updated from opset 7 to opset 8]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
New operators:
    Expand
"""Backward compatibility module for torch.onnx.symbolic_opset7."""

Updated operators:
    Min, Max, Sum, Mean: supports multidirectional broadcasting.
    MaxPool: added optional indices output.
    Scan
"""

import functools
import warnings

from torch.onnx import symbolic_helper, symbolic_opset9 as opset9
from torch.onnx._internal import jit_utils, registration
from __future__ import annotations


_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=7)
__all__: list[str] = []

block_listed_operators = (
    "scan",
    "expand",
    "expand_as",
    "meshgrid",
    "adaptive_max_pool1d",
    "adaptive_max_pool2d",
    "adaptive_max_pool3d",
    "max_pool1d_with_indices",
    "max_pool2d_with_indices",
    "max_pool3d_with_indices",
)


# NOTE: max, min, sum, mean: broadcasting is not supported in opset 7.
# torch.max (same for torch.min) actually has two interfaces smashed together:
# torch.max(x, dim, keepdim) and torch.max(x, y)
@_onnx_symbolic("aten::max")
def max(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
    # torch.max(input, other)
    if keepdim is None and dim_or_y is not None:
        warnings.warn(
            "Multidirectional broadcasting is not supported in opset 7. "
            "This might cause the onnx model to be incorrect, if inputs to max operators "
            "have different shapes"
        )
    return opset9.max(g, self, dim_or_y, keepdim)


@_onnx_symbolic("aten::min")
def min(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
    # torch.min(input, other)
    if keepdim is None and dim_or_y is not None:
        warnings.warn(
            "Multidirectional broadcasting is not supported in opset 7. "
            "This might cause the onnx model to be incorrect, if inputs to min operators "
            "have different shapes"
        )
    return opset9.min(g, self, dim_or_y, keepdim)


for block_listed_op in block_listed_operators:
    _onnx_symbolic(f"aten::{block_listed_op}")(
        symbolic_helper._block_list_in_opset(block_listed_op)
    )

from torch.onnx._internal.torchscript_exporter.symbolic_opset7 import *  # noqa: F401,F403
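For readers unfamiliar with the block-list registration used above, here is a minimal standalone sketch of the pattern (not part of this diff; the registry dict and helper below are hypothetical stand-ins, not the exporter's actual registration API): unsupported ops get a generated symbolic that fails loudly at export time.

```python
# Standalone sketch of the block-list pattern; names are hypothetical.
def _block_list_in_opset(name: str, opset: int = 7):
    def symbolic_fn(*args, **kwargs):
        raise RuntimeError(
            f"ONNX export failed on {name}, which is not supported in opset {opset}"
        )

    return symbolic_fn


registry: dict[str, object] = {}
for op in ("scan", "expand", "meshgrid"):
    # Registering the stand-in makes unsupported ops raise a clear error at export time.
    registry[f"aten::{op}"] = _block_list_in_opset(op)
```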
@ -1,463 +1,8 @@
# mypy: allow-untyped-defs
"""
Note [ONNX operators that are added/updated from opset 8 to opset 9]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
New operators:
    Compress
    ConstantOfShape
    EyeLike
    MaxUnpool
    OneHot
    Sinh
    Cosh
    Asinh
    Acosh
    Atanh
    Shrink
    IsNaN
    Sign
    Erf
    Scatter
    Where
    NonZero
    TfIdfVectorizer
    MeanVarianceNormalization
"""Backward compatibility module for torch.onnx.symbolic_opset8."""

Updated operators:
    BatchNormalization: removed spatial attribute.
    Greater, Less, Constant, MatMul, PRelu, Gemm, Flatten: more data types{integers} supported.
    Cast: more data types{string} supported.
    Upsample: moved scales from attribute to input.
    Scan
"""

import functools
import warnings

import torch
from torch._C import _onnx as _C_onnx
from torch.onnx import _type_utils, errors, symbolic_helper, symbolic_opset9 as opset9
from torch.onnx._internal import jit_utils, registration
from __future__ import annotations


_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=8)
__all__: list[str] = []

block_listed_operators = (
    "nonzero",
    "where",
    "scatter",
    "scatter_add",
    "erf",
    "sign",
    "isnan",
    "gather",
    "arange",
    "masked_fill",
    "index_fill",
    "index_copy",
    "repeat_interleave",
    "any",
    "all",
)

for block_listed_op in block_listed_operators:
    _onnx_symbolic(f"aten::{block_listed_op}")(
        symbolic_helper._block_list_in_opset(block_listed_op)
    )


@_onnx_symbolic(
    "aten::upsample_nearest1d",
    decorate=[symbolic_helper._apply_params("upsample_nearest1d", 3, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_nearest2d",
    decorate=[symbolic_helper._apply_params("upsample_nearest2d", 4, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_nearest3d",
    decorate=[symbolic_helper._apply_params("upsample_nearest3d", 5, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_linear1d",
    decorate=[symbolic_helper._apply_params("upsample_linear1d", 3, "linear")],
)
@_onnx_symbolic(
    "aten::upsample_bilinear2d",
    decorate=[symbolic_helper._apply_params("upsample_bilinear2d", 4, "linear")],
)
@_onnx_symbolic(
    "aten::upsample_trilinear3d",
    decorate=[symbolic_helper._apply_params("upsample_trilinear3d", 5, "linear")],
)
def _interpolate(name, dim, interpolate_mode):
    def symbolic_fn(g, input, output_size, *args):
        scales, align_corners = symbolic_helper._get_interpolate_attributes(
            g, interpolate_mode, args
        )
        symbolic_helper._interpolate_warning(interpolate_mode)
        align_corners = symbolic_helper._maybe_get_scalar(align_corners)
        if align_corners:
            return symbolic_helper._unimplemented(name, "align_corners == True", input)
        output_size = symbolic_helper._maybe_get_const(output_size, "is")
        if symbolic_helper._is_value(output_size):
            return symbolic_helper._unimplemented(
                name, "torch._C.Value (output_size) indexing"
            )
        if scales is None:
            scales = [
                1.0
                if i < 2
                else float(output_size[-(dim - i)])
                / float(input.type().sizes()[-(dim - i)])
                for i in range(0, dim)
            ]
        return g.op("Upsample", input, mode_s=interpolate_mode, scales_f=scales)

    return symbolic_fn


@_onnx_symbolic("aten::__interpolate")
def __interpolate(
    g: jit_utils.GraphContext,
    input,
    size,
    scale_factor,
    mode,
    align_corners,
    recompute_scale_factor,
    antialias,
):
    align_corners = symbolic_helper._maybe_get_const(align_corners, "b")
    if not symbolic_helper._is_none(align_corners) and align_corners:
        return symbolic_helper._unimplemented("interpolate", "align_corners == True")

    if not symbolic_helper._is_none(scale_factor) and symbolic_helper._is_value(
        scale_factor
    ):
        return symbolic_helper._unimplemented(
            "interpolate", "dynamic scales in opset 8"
        )

    if not symbolic_helper._is_none(size) and symbolic_helper._is_value(size):
        return symbolic_helper._unimplemented("interpolate", "dynamic size in opset 8")

    scales, mode = symbolic_helper._interpolate_get_scales_and_mode(
        g, input, size, scale_factor, mode, align_corners
    )
    return g.op("Upsample", input, mode_s=mode, scales_f=scales)


# NOTE: We should create a wrapper for this kind of operation, after resolving the shape/type propagation
# issue for "cast" operators. Some symbolic functions depend on shape information of input tensor, which
# is lost after casting.
def _try_cast_integer_to_float(g: jit_utils.GraphContext, *args):
    floating_scalar_types = {
        _type_utils.JitScalarType.HALF,
        _type_utils.JitScalarType.FLOAT,
        _type_utils.JitScalarType.DOUBLE,
    }
    old_type = None
    # Cast the input tensor to Float if its scalarType is known and is not floating number.
    # If casting is performed, return the old scalarType, otherwise return None.
    arg0_type = _type_utils.JitScalarType.from_value(
        args[0], _type_utils.JitScalarType.UNDEFINED
    )
    if arg0_type != _type_utils.JitScalarType.UNDEFINED:
        old_type = arg0_type
        if old_type not in floating_scalar_types:
            old_type = old_type.scalar_name()  # type: ignore[assignment]
            args = tuple(
                g.op("Cast", arg, to_i=_C_onnx.TensorProtoDataType.FLOAT)
                for arg in args
            )
        else:
            return (None,) + args
    else:
        warnings.warn(
            "Only floating datatype is supported for these operators: "
            "{Greater, Less, MatMul, PRelu, Gemm, Flatten}. This might cause "
            "the onnx model to be incorrect, if inputs have integer datatypes."
        )
    return (old_type,) + args


def _cast_to_type(g: jit_utils.GraphContext, input, to_type):
    if to_type is None:
        return input
    return getattr(opset9, f"_cast_{to_type}")(g, input, False)


def _comparison_operator(g: jit_utils.GraphContext, input, other, op_name):
    other = symbolic_helper._maybe_get_scalar(other)
    other = symbolic_helper._if_scalar_type_as(other, input)
    _, input, other = _try_cast_integer_to_float(g, input, other)
    return g.op(op_name, input, other)


# NOTE: For symbolics {gt, lt, bmm, matmul, prelu, mm, addmm, view, flatten},
# integer input type not supported in opset8. Cast to float if possible.
@_onnx_symbolic("aten::gt")
def gt(g: jit_utils.GraphContext, input, other):
    return _comparison_operator(g, input, other, "Greater")


@_onnx_symbolic("aten::lt")
def lt(g: jit_utils.GraphContext, input, other):
    return _comparison_operator(g, input, other, "Less")


@_onnx_symbolic("aten::bmm")
def bmm(g: jit_utils.GraphContext, self, other):
    if symbolic_helper._try_get_scalar_type(self):
        old_type, self, other = _try_cast_integer_to_float(g, self, other)
        return _cast_to_type(g, g.op("MatMul", self, other), old_type)
    else:
        return g.op("MatMul", self, other)


@_onnx_symbolic("aten::matmul")
def matmul(g: jit_utils.GraphContext, self, other):
    return bmm(g, self, other)


@_onnx_symbolic("aten::prelu")
def prelu(g: jit_utils.GraphContext, self, weight):
    self_rank = symbolic_helper._get_tensor_rank(self)
    weight_sizes = symbolic_helper._get_tensor_sizes(weight)
    if self_rank is not None and self_rank > 2:
        weight = g.op("Unsqueeze", weight, axes_i=list(range(1, self_rank - 1)))
    elif self_rank == 0 and weight_sizes == [1]:
        # self and weight are both scalar but weight has rank == 1, squeeze weight.
        weight = symbolic_helper._squeeze_helper(g, weight, [0])
    if symbolic_helper._try_get_scalar_type(self):
        old_type, self, weight = _try_cast_integer_to_float(g, self, weight)
        return _cast_to_type(g, g.op("PRelu", self, weight), old_type)
    else:
        return g.op("PRelu", self, weight)


@_onnx_symbolic("aten::mm")
def mm(g: jit_utils.GraphContext, self, other):
    # Create a dummy C tensor. Only needed for API purposes, the value is
    # since beta = 0
    scalar_type = symbolic_helper._try_get_scalar_type(self, other)
    if scalar_type is None:
        raise errors.SymbolicValueError(
            "mm can only operate on tensors with known types", self
        )
    zero_constant = g.op(
        "Constant",
        value_t=torch.tensor([0], dtype=scalar_type.dtype()),
    )

    if symbolic_helper._try_get_scalar_type(self):
        old_type, self, other, zero_constant = _try_cast_integer_to_float(
            g, self, other, zero_constant
        )
        return _cast_to_type(
            g,
            g.op("Gemm", self, other, zero_constant, beta_f=0.0, alpha_f=1.0),
            old_type,
        )
    return g.op("Gemm", self, other, zero_constant, beta_f=0.0, alpha_f=1.0)


@_onnx_symbolic("aten::addmm")
@symbolic_helper.parse_args("v", "v", "v", "t", "t")
def addmm(g: jit_utils.GraphContext, self, mat1, mat2, beta, alpha):
    if symbolic_helper._try_get_scalar_type(self):
        old_type, self, mat1, mat2 = _try_cast_integer_to_float(g, self, mat1, mat2)
        return _cast_to_type(
            g,
            g.op(
                "Gemm",
                mat1,
                mat2,
                self,
                beta_f=symbolic_helper._scalar(beta),
                alpha_f=symbolic_helper._scalar(alpha),
            ),
            old_type,
        )
    else:
        return g.op(
            "Gemm",
            mat1,
            mat2,
            self,
            beta_f=symbolic_helper._scalar(beta),
            alpha_f=symbolic_helper._scalar(alpha),
        )


@_onnx_symbolic("aten::flatten")
def flatten(g: jit_utils.GraphContext, input, start_dim, end_dim):
    start_dim_i = symbolic_helper._get_const(start_dim, "i", "start_dim")
    end_dim_i = symbolic_helper._get_const(end_dim, "i", "end_dim")

    dim = input.type().dim()
    if end_dim_i < 0:
        end_dim_i = dim + end_dim_i
    # use ONNX's Flatten operator for cases where the output shape is 2D
    if start_dim_i == 1 and end_dim_i == dim - 1:
        if symbolic_helper._try_get_scalar_type(input):
            old_type, input = _try_cast_integer_to_float(g, input)
            return _cast_to_type(
                g, g.op("Flatten", input, axis_i=start_dim_i), old_type
            )
        else:
            return g.op("Flatten", input, axis_i=start_dim_i)
    if start_dim_i == 0 and end_dim_i == dim - 2:
        if symbolic_helper._try_get_scalar_type(input):
            old_type, input = _try_cast_integer_to_float(g, input)
            return _cast_to_type(
                g, g.op("Flatten", input, axis_i=end_dim_i + 1), old_type
            )
        else:
            return g.op("Flatten", input, axis_i=end_dim_i + 1)

    return opset9.flatten(g, input, start_dim, end_dim)


def _constant_fill(g: jit_utils.GraphContext, sizes, dtype: int, const_value):
    if dtype is None:
        scalar_type = _type_utils.JitScalarType.FLOAT
    else:
        scalar_type = _type_utils.JitScalarType(dtype)
    if not scalar_type.dtype().is_floating_point:
        result = g.op(
            "ConstantFill",
            sizes,
            dtype_i=_type_utils.JitScalarType.FLOAT.onnx_type(),
            input_as_shape_i=1,
            value_f=const_value,
        )
        return g.op("Cast", result, to_i=scalar_type.onnx_type())
    else:
        return g.op(
            "ConstantFill",
            sizes,
            dtype_i=scalar_type.onnx_type(),
            input_as_shape_i=1,
            value_f=const_value,
        )


@_onnx_symbolic("aten::empty")
@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
def empty(
    g: jit_utils.GraphContext,
    sizes,
    dtype,
    layout,
    device,
    pin_memory=False,
    memory_format=None,
):
    return zeros(g, sizes, dtype, layout, device, pin_memory)


@_onnx_symbolic("aten::empty_like")
@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
def empty_like(
    g: jit_utils.GraphContext,
    input,
    dtype,
    layout,
    device,
    pin_memory=False,
    memory_format=None,
):
    return zeros_like(g, input, dtype, layout, device, pin_memory)


@_onnx_symbolic("aten::zeros")
@symbolic_helper.parse_args("v", "i", "v", "v", "v")
def zeros(g: jit_utils.GraphContext, sizes, dtype, layout, device, pin_memory=False):
    # NOTE: no way to set device and layout in ONNX, so we ignore it
    return _constant_fill(g, sizes, dtype, 0)


@_onnx_symbolic("aten::zeros_like")
@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
def zeros_like(
    g: jit_utils.GraphContext,
    input,
    dtype,
    layout,
    device,
    pin_memory=False,
    memory_format=None,
):
    shape = g.op("Shape", input)
    return _constant_fill(g, shape, dtype, 0)


@_onnx_symbolic("aten::ones")
@symbolic_helper.parse_args("v", "i", "v", "v", "v")
def ones(g: jit_utils.GraphContext, sizes, dtype, layout, device, pin_memory=False):
    return _constant_fill(g, sizes, dtype, 1)


@_onnx_symbolic("aten::ones_like")
@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
def ones_like(
    g: jit_utils.GraphContext,
    input,
    dtype,
    layout,
    device,
    pin_memory=False,
    memory_format=None,
):
    shape = g.op("Shape", input)
    return _constant_fill(g, shape, dtype, 1)


@_onnx_symbolic("aten::full")
def full(
    g: jit_utils.GraphContext, sizes, value, dtype, layout, device, pin_memory=False
):
    const_value = symbolic_helper._maybe_get_const(value, "t")
    if symbolic_helper._is_value(const_value):
        tmp = zeros(g, sizes, dtype, layout, device)
        return opset9.add(g, tmp, value, g.op("Constant", value_t=torch.tensor(1)))
    else:
        dtype = symbolic_helper._get_const(dtype, "i", "dtype")
        return _constant_fill(g, sizes, dtype, const_value)


@_onnx_symbolic("aten::full_like")
@symbolic_helper.parse_args("v", "f", "i", "v", "v", "v", "v")
def full_like(
    g: jit_utils.GraphContext,
    input,
    fill_value,
    dtype,
    layout,
    device,
    pin_memory=False,
    memory_format=None,
):
    shape = g.op("Shape", input)
    return _constant_fill(g, shape, dtype, fill_value)


@_onnx_symbolic("aten::repeat")
def repeat(g: jit_utils.GraphContext, self, repeats):
    if not symbolic_helper._is_value(repeats):
        repeats = g.op("Constant", value_t=torch.LongTensor(repeats))
    if symbolic_helper._is_packed_list(repeats):
        repeat_size_len = len(symbolic_helper._unpack_list(repeats))
    else:
        const_repeats = symbolic_helper._maybe_get_const(repeats, "is")
        repeat_size_len = len(const_repeats)
    if self.isCompleteTensor():
        sizes = self.type().sizes()
        diff_dims = repeat_size_len - len(sizes)
        if diff_dims > 0:
            self = opset9.view(
                g, self, g.op("Constant", value_t=torch.tensor([1] * diff_dims + sizes))
            )
    return g.op("Tile", self, repeats)

from torch.onnx._internal.torchscript_exporter.symbolic_opset8 import *  # noqa: F401,F403
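A standalone sketch (not part of this diff) of the cast-run-cast-back idea behind `_try_cast_integer_to_float` and `_cast_to_type` above: several opset 8 operators such as Gemm, MatMul, PRelu, and Flatten only accept floating-point inputs, so integer inputs are cast to float and the result is cast back. The helper below is illustrative and operates on plain tensors rather than graph values.

```python
# Illustrative only; demonstrates the pattern on eager tensors, not on the exporter graph.
import torch


def run_with_float_fallback(fn, *tensors: torch.Tensor) -> torch.Tensor:
    if tensors[0].is_floating_point():
        return fn(*tensors)
    original_dtype = tensors[0].dtype
    floats = tuple(t.float() for t in tensors)  # cast integer inputs to float
    return fn(*floats).to(original_dtype)  # cast the result back to the old dtype


a = torch.arange(6, dtype=torch.int32).reshape(2, 3)
b = torch.arange(12, dtype=torch.int32).reshape(3, 4)
print(run_with_float_fallback(torch.matmul, a, b))
```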
torch/onnx/utils.py: 1878 lines changed (file diff suppressed because it is too large).