diff --git a/docs/source/onnx.md b/docs/source/onnx.md index b0ed78dbe69b..73a24b671553 100644 --- a/docs/source/onnx.md +++ b/docs/source/onnx.md @@ -113,7 +113,6 @@ also be interested in reading our [development wiki](https://github.com/pytorch/ .. autofunction:: register_custom_op_symbolic .. autofunction:: unregister_custom_op_symbolic .. autofunction:: select_model_mode_for_export -.. autoclass:: JitScalarType ``` ```{eval-rst} diff --git a/docs/source/onnx_verification.md b/docs/source/onnx_verification.md index cbaad021e960..4036aea8f81a 100644 --- a/docs/source/onnx_verification.md +++ b/docs/source/onnx_verification.md @@ -1,4 +1,5 @@ # torch.onnx.verification + ```{eval-rst} .. automodule:: torch.onnx.verification ``` @@ -11,23 +12,3 @@ .. autoclass:: VerificationInfo :members: ``` - -```{eval-rst} -.. autofunction:: verify -``` - -## Deprecated - -The following classes and functions are deprecated. - - -```{eval-rst} -.. py:class:: check_export_model_diff -.. py:class:: GraphInfo -.. py:class:: GraphInfoPrettyPrinter -.. py:class:: OnnxBackend -.. py:class:: OnnxTestCaseRepro -.. py:class:: VerificationOptions -.. py:function:: find_mismatch -.. py:function:: verify_aten_graph -``` diff --git a/test/onnx/internal/test_registraion.py b/test/onnx/internal/test_registraion.py index e357dbff713a..fcc4cdeedd92 100644 --- a/test/onnx/internal/test_registraion.py +++ b/test/onnx/internal/test_registraion.py @@ -4,7 +4,7 @@ from collections.abc import Sequence from torch.onnx import errors -from torch.onnx._internal import registration +from torch.onnx._internal.torchscript_exporter import registration from torch.testing._internal import common_utils diff --git a/test/onnx/onnx_test_common.py b/test/onnx/onnx_test_common.py index be6cc066e6b9..ab2bfb51bdea 100644 --- a/test/onnx/onnx_test_common.py +++ b/test/onnx/onnx_test_common.py @@ -17,7 +17,8 @@ import pytorch_test_common import torch from torch import export as torch_export -from torch.onnx import _constants, verification +from torch.onnx import _constants +from torch.onnx._internal.torchscript_exporter import verification from torch.testing._internal import common_utils from torch.testing._internal.opinfo import core as opinfo_core from torch.types import Number diff --git a/test/onnx/test_autograd_funs.py b/test/onnx/test_autograd_funs.py index cfeec9553ab7..81c70d7d9877 100644 --- a/test/onnx/test_autograd_funs.py +++ b/test/onnx/test_autograd_funs.py @@ -5,13 +5,11 @@ from onnx_test_common import run_model_test import torch from torch.onnx import OperatorExportTypes -from torch.onnx._globals import GLOBALS -from torch.onnx.utils import _model_to_graph from torch.testing._internal import common_utils class TestAutogradFuns(pytorch_test_common.ExportTestCase): - opset_version = GLOBALS.export_onnx_opset_version + opset_version = 20 keep_initializers_as_inputs = False onnx_shape_inference = True @@ -133,7 +131,7 @@ class TestAutogradFuns(pytorch_test_common.ExportTestCase): input = torch.ones(1, 5) # Test ONNX_FALLTHROUGH_MODE - graph, _, _ = _model_to_graph( + graph, _, _ = torch.onnx.utils._model_to_graph( model, (input,), operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH, @@ -142,7 +140,7 @@ class TestAutogradFuns(pytorch_test_common.ExportTestCase): self.assertEqual(next(iter).kind(), "prim::PythonOp") # Test ATEN_FALLBACK_MODE - graph, _, _ = _model_to_graph( + graph, _, _ = torch.onnx.utils._model_to_graph( model, (input,), operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK, diff --git 
a/test/onnx/test_onnx_opset.py b/test/onnx/test_onnx_opset.py index 29ac8f108c2d..f29062cdd0ca 100644 --- a/test/onnx/test_onnx_opset.py +++ b/test/onnx/test_onnx_opset.py @@ -11,7 +11,7 @@ import torch import torch.onnx from torch.nn import Module from torch.onnx import producer_name, producer_version -from torch.onnx._globals import GLOBALS +from torch.onnx._internal.torchscript_exporter._globals import GLOBALS from torch.testing._internal import common_utils diff --git a/test/onnx/test_onnxscript_no_runtime.py b/test/onnx/test_onnxscript_no_runtime.py index 17e92f0e0117..98c44b115cb2 100644 --- a/test/onnx/test_onnxscript_no_runtime.py +++ b/test/onnx/test_onnxscript_no_runtime.py @@ -10,7 +10,7 @@ import onnxscript from onnxscript.onnx_types import FLOAT import torch -from torch.onnx._internal import jit_utils +from torch.onnx._internal.torchscript_exporter import jit_utils from torch.testing._internal import common_utils diff --git a/test/onnx/test_onnxscript_runtime.py b/test/onnx/test_onnxscript_runtime.py index 23205045e838..dc19971498d9 100644 --- a/test/onnx/test_onnxscript_runtime.py +++ b/test/onnx/test_onnxscript_runtime.py @@ -9,7 +9,7 @@ import onnxscript from onnxscript.onnx_types import FLOAT import torch -from torch.onnx._internal import jit_utils +from torch.onnx._internal.torchscript_exporter import jit_utils from torch.testing._internal import common_utils diff --git a/test/onnx/test_pytorch_jit_onnx.py b/test/onnx/test_pytorch_jit_onnx.py index 68f26aea8b89..bc3c64ab8679 100644 --- a/test/onnx/test_pytorch_jit_onnx.py +++ b/test/onnx/test_pytorch_jit_onnx.py @@ -4,8 +4,11 @@ import pytorch_test_common from pytorch_test_common import skipIfNoCuda import torch -from torch.onnx import verification -from torch.onnx._globals import GLOBALS +from torch.onnx._internal.torchscript_exporter import verification +from torch.onnx._internal.torchscript_exporter._globals import GLOBALS +from torch.onnx._internal.torchscript_exporter.utils import ( + _trigger_symbolic_function_registration, +) from torch.testing._internal import common_utils @@ -20,6 +23,7 @@ def _jit_graph_to_onnx_model(graph, operator_export_type, opset_version): """ GLOBALS.export_onnx_opset_version = opset_version + _trigger_symbolic_function_registration() graph = torch.onnx.utils._optimize_graph( graph, operator_export_type, params_dict={} ) diff --git a/test/onnx/test_pytorch_onnx_no_runtime.py b/test/onnx/test_pytorch_onnx_no_runtime.py index b3a3aa01cf3c..590e985460c2 100644 --- a/test/onnx/test_pytorch_onnx_no_runtime.py +++ b/test/onnx/test_pytorch_onnx_no_runtime.py @@ -22,7 +22,7 @@ import torch import torch.nn.functional as F from torch import Tensor from torch.onnx import symbolic_helper, utils -from torch.onnx._internal import registration +from torch.onnx._internal.torchscript_exporter import registration from torch.testing._internal import common_quantization, common_utils, jit_utils @@ -430,9 +430,8 @@ class TestONNXExport(pytorch_test_common.ExportTestCase): torch.randn(3, 4, requires_grad=True), mocks=[ unittest.mock.patch( - "torch.onnx._internal.registration.registry.get_function_group", + "torch.onnx._internal.torchscript_exporter.registration.registry.get_function_group", side_effect=break_is_registered_op_api, - # wraps=registration.registry.get_function_group ) ], operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK, diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index f99380840679..1e86829c43ba 100644 --- 
a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -41,7 +41,9 @@ from pytorch_test_common import ( import torch from torch import Tensor from torch.nn.utils import rnn as rnn_utils -from torch.onnx import errors, verification +from torch.onnx import errors +from torch.onnx._internal.torchscript_exporter import verification +from torch.onnx._internal.torchscript_exporter._type_utils import JitScalarType from torch.testing._internal import common_utils from torch.testing._internal.common_utils import skipIfNoLapack @@ -13705,7 +13707,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime): input_names=["x"], ) exported = onnx.load_from_string(f.getvalue()) - expected_elem_type = torch.onnx.JitScalarType.from_value(x).onnx_type() + expected_elem_type = JitScalarType.from_value(x).onnx_type() expected_output_type = onnx.helper.make_optional_type_proto( onnx.helper.make_tensor_type_proto(expected_elem_type, (dynamic_axis_name,)) ) diff --git a/test/onnx/test_pytorch_onnx_shape_inference.py b/test/onnx/test_pytorch_onnx_shape_inference.py index 801d84844935..7cdf48937963 100644 --- a/test/onnx/test_pytorch_onnx_shape_inference.py +++ b/test/onnx/test_pytorch_onnx_shape_inference.py @@ -10,8 +10,8 @@ from pytorch_test_common import skipIfUnsupportedMinOpsetVersion import torch from torch.onnx import _constants, utils -from torch.onnx._globals import GLOBALS -from torch.onnx._internal import jit_utils +from torch.onnx._internal.torchscript_exporter import jit_utils +from torch.onnx._internal.torchscript_exporter._globals import GLOBALS from torch.testing._internal import common_utils diff --git a/test/onnx/test_symbolic_helper.py b/test/onnx/test_symbolic_helper.py index b7358fc1ec41..cc7a3a133732 100644 --- a/test/onnx/test_symbolic_helper.py +++ b/test/onnx/test_symbolic_helper.py @@ -3,7 +3,7 @@ import torch from torch.onnx import symbolic_helper -from torch.onnx._globals import GLOBALS +from torch.onnx._internal.torchscript_exporter._globals import GLOBALS from torch.testing._internal import common_utils diff --git a/test/onnx/test_utility_funs.py b/test/onnx/test_utility_funs.py index 387a8985879b..fe3a4b162235 100644 --- a/test/onnx/test_utility_funs.py +++ b/test/onnx/test_utility_funs.py @@ -1,11 +1,9 @@ # Owner(s): ["module: onnx"] import copy -import functools import io import re import warnings -from typing import Callable import onnx @@ -23,7 +21,7 @@ import torch import torch.onnx import torch.utils.cpp_extension from torch.onnx import _constants, OperatorExportTypes, TrainingMode, utils -from torch.onnx._globals import GLOBALS +from torch.onnx._internal.torchscript_exporter._globals import GLOBALS from torch.onnx.symbolic_helper import _unpack_list, parse_args from torch.testing._internal import common_utils from torch.testing._internal.common_utils import skipIfNoLapack @@ -86,86 +84,6 @@ class _BaseTestCase(pytorch_test_common.ExportTestCase): return graph, params_dict, torch_out -@common_utils.instantiate_parametrized_tests -class TestUnconvertibleOps(pytorch_test_common.ExportTestCase): - """Unit tests for the `unconvertible_ops` function.""" - - def setUp(self): - class EinsumModule(torch.nn.Module): - def forward(self, x): - return torch.einsum("ii", x) - - self.einsum_module = EinsumModule() - - def test_it_returns_graph_and_unconvertible_ops_at_lower_opset_version(self): - x = torch.randn(4, 4) - - # Einsum is supported since opset 12. It should be unconvertible at opset 9. 
- graph, unconvertible_ops = utils.unconvertible_ops( - self.einsum_module, (x,), opset_version=9 - ) - nodes = graph.nodes() - self.assertEqual(next(nodes).kind(), "prim::Constant") - self.assertEqual(next(nodes).kind(), "prim::ListConstruct") - self.assertEqual(next(nodes).kind(), "prim::Constant") - self.assertEqual(next(nodes).kind(), "aten::einsum") - self.assertEqual(unconvertible_ops, ["aten::einsum"]) - - @common_utils.parametrize( - "jit_function", - [ - common_utils.subtest( - functools.partial(torch.jit.trace, example_inputs=torch.randn(4, 4)), - name="traced", - ), - common_utils.subtest(torch.jit.script, name="scripted"), - ], - ) - def test_it_returns_unconvertible_ops_at_lower_opset_version_for_jit_module( - self, jit_function: Callable - ): - module = jit_function(self.einsum_module) - x = torch.randn(4, 4) - - # Einsum is supported since opset 12. It should be unconvertible at opset 9. - _, unconvertible_ops = utils.unconvertible_ops(module, (x,), opset_version=9) - self.assertEqual(unconvertible_ops, ["aten::einsum"]) - - @common_utils.parametrize( - "jit_function", - [ - common_utils.subtest(lambda x: x, name="nn_module"), - common_utils.subtest( - functools.partial(torch.jit.trace, example_inputs=torch.randn(4, 4)), - name="traced", - ), - common_utils.subtest(torch.jit.script, name="scripted"), - ], - ) - def test_it_returns_empty_list_when_all_ops_convertible( - self, jit_function: Callable - ): - module = jit_function(self.einsum_module) - x = torch.randn(4, 4) - - # Einsum is supported since opset 12 - _, unconvertible_ops = utils.unconvertible_ops(module, (x,), opset_version=12) - self.assertEqual(unconvertible_ops, []) - - def test_it_returns_empty_list_when_model_contains_supported_inplace_ops(self): - class SkipConnectionModule(torch.nn.Module): - def forward(self, x): - out = x - out += x - out = torch.nn.functional.relu(out, inplace=True) - return out - - module = SkipConnectionModule() - x = torch.randn(4, 4) - _, unconvertible_ops = utils.unconvertible_ops(module, (x,), opset_version=13) - self.assertEqual(unconvertible_ops, []) - - @parameterized.parameterized_class( [ {"opset_version": opset} diff --git a/test/onnx/test_verification.py b/test/onnx/test_verification.py index 4d2b4676d9b1..ac9a3a475376 100644 --- a/test/onnx/test_verification.py +++ b/test/onnx/test_verification.py @@ -13,7 +13,8 @@ import pytorch_test_common from packaging import version import torch -from torch.onnx import _constants, _experimental, verification +from torch.onnx import _constants +from torch.onnx._internal.torchscript_exporter import _experimental, verification from torch.testing._internal import common_utils diff --git a/test/test_utils.py b/test/test_utils.py index 1c515c9dcac2..0314da6e320a 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -21,8 +21,6 @@ import torch.utils.cpp_extension import torch.utils.data from torch._utils import try_import from torch._utils_internal import deprecated -from torch.autograd._functions.utils import check_onnx_broadcast -from torch.onnx.symbolic_opset9 import _prepare_onnx_paddings from torch.testing._internal.common_cuda import TEST_MULTIGPU from torch.testing._internal.common_device_type import ( instantiate_device_type_tests, @@ -790,65 +788,6 @@ class TestCollectEnv(TestCase): self.assertTrue(info_output.count("\n") >= 17) -class TestONNXUtils(TestCase): - def test_prepare_onnx_paddings(self): - sizes = [2, 3, 4] - pad = [1, 2, 3, 4] - paddings = _prepare_onnx_paddings(len(sizes), pad) - self.assertEqual(paddings, [0, 3, 
1, 0, 4, 2]) - - def test_check_onnx_broadcast(self): - def try_check_onnx_broadcast(dims1, dims2, expect_broadcast, expect_fail): - broadcast = True - fail = False - try: - broadcast = check_onnx_broadcast(dims1, dims2) - except ValueError: - fail = True - self.assertEqual(broadcast, expect_broadcast) - self.assertEqual(fail, expect_fail) - - # Case 1, check the case when len(dims1) < len(dims2) and numel(dims2) > 1 - dims1 = [3, 4] - dims2 = [2, 3, 4] - try_check_onnx_broadcast(dims1, dims2, True, True) - - # Case 2, check the case when len(dims1) < len(dims2) and numel(dims2) == 1 - dims1 = [3, 4] - dims2 = [1, 1, 1] - try_check_onnx_broadcast(dims1, dims2, True, False) - - # Case 3, check the case when len(dims1) > len(dims2) and numel(dims2) == 1 - dims1 = [1, 1] - dims2 = [1] - try_check_onnx_broadcast(dims1, dims2, True, False) - - # Case 4, check the case when len(dims1) > len(dims2) and dims1[x:] == dims2 - dims1 = [2, 3, 4] - dims2 = [3, 4] - try_check_onnx_broadcast(dims1, dims2, True, False) - - # Case 5, check the case when len(dims1) > len(dims2), but dims1[x:] != dims2 - dims1 = [2, 3, 4] - dims2 = [1, 4] - try_check_onnx_broadcast(dims1, dims2, True, True) - - # Case 6, check the equal case, no broadcast - dims1 = [3, 4] - dims2 = [3, 4] - try_check_onnx_broadcast(dims1, dims2, False, False) - - # Case 7, check the case when len(dims1) == len(dims2), but dims1 != dims2 - dims1 = [3, 4] - dims2 = [1, 4] - try_check_onnx_broadcast(dims1, dims2, True, True) - - # Case 8, check the case when len(dims1) == len(dims2) and numel(s2) == 1 - dims1 = [3, 4] - dims2 = [1, 1] - try_check_onnx_broadcast(dims1, dims2, True, False) - - class TestHipify(TestCase): def test_import_hipify(self): from torch.utils.hipify import hipify_python # noqa: F401 diff --git a/torch/autograd/_functions/utils.py b/torch/autograd/_functions/utils.py index a3f242920c7e..1e74e21d3cef 100644 --- a/torch/autograd/_functions/utils.py +++ b/torch/autograd/_functions/utils.py @@ -1,6 +1,4 @@ # mypy: allow-untyped-defs -import operator -from functools import reduce def maybe_view(tensor, size, check_same_size=True): @@ -26,38 +24,3 @@ def maybe_unexpand(tensor, old_size, check_same_size=True): for dim in expanded_dims: tensor = tensor.sum(dim, keepdim=True) return tensor - - -# Check whether the op enable broadcasting, and whether it is supported by ONNX. -# If dims1 and dims2 are different, then broadcast is True. -# We always assume the combination of dims1 and dims2 is broadcastable. -# The following types of broadcasting are supported in ONNX: -# 1) Only one element in dims2, such as dims2 = [1, 1] -# 2) dims2 is suffix of dims1, such as dims1 = [2, 3, 4], and dims2 = [3, 4] -# Details can be found here: https://github.com/onnx/onnx/blob/master/docs/Operators.md#Gemm -def check_onnx_broadcast(dims1, dims2): - broadcast = False - supported = True - len1 = len(dims1) - len2 = len(dims2) - - numel2 = reduce(operator.mul, dims2) - if len1 < len2: - broadcast = True - if numel2 != 1: - supported = False - elif len1 > len2: - broadcast = True - if numel2 != 1 and dims1[len1 - len2 :] != dims2: - supported = False - else: - if dims1 != dims2: - broadcast = True - if numel2 != 1: - supported = False - - if not supported: - raise ValueError( - f"Numpy style broadcasting is not supported in ONNX. 
Input dims are: {dims1}, {dims2}" - ) - return broadcast diff --git a/torch/csrc/autograd/python_function.cpp b/torch/csrc/autograd/python_function.cpp index e209b4a3a14b..14591bc1fb4a 100644 --- a/torch/csrc/autograd/python_function.cpp +++ b/torch/csrc/autograd/python_function.cpp @@ -1053,7 +1053,8 @@ void _trace_post_record( } } } - py::object onnx_globals = py::module::import("torch.onnx._globals"); + py::object onnx_globals = + py::module::import("torch.onnx._internal.torchscript_exporter._globals"); py::bool_ is_in_onnx_export = py::module::import("torch.onnx.__init__").attr("is_in_onnx_export"); py::bool_ is_autograd_inlining_enabled = diff --git a/torch/csrc/jit/passes/onnx.cpp b/torch/csrc/jit/passes/onnx.cpp index a0e6babe54b6..942c15125793 100644 --- a/torch/csrc/jit/passes/onnx.cpp +++ b/torch/csrc/jit/passes/onnx.cpp @@ -261,9 +261,10 @@ void NodeToONNX( py::dict& env, py::set& values_in_env) { py::object onnx = py::module::import("torch.onnx"); - py::object onnx_globals = py::module::import("torch.onnx._globals"); - py::object onnx_registration = - py::module::import("torch.onnx._internal.registration"); + py::object onnx_globals = + py::module::import("torch.onnx._internal.torchscript_exporter._globals"); + py::object onnx_registration = py::module::import( + "torch.onnx._internal.torchscript_exporter.registration"); // Setup all the lambda helper functions. diff --git a/torch/onnx/README.md b/torch/onnx/README.md index 7c8596365f27..3878f48d70be 100644 --- a/torch/onnx/README.md +++ b/torch/onnx/README.md @@ -4,92 +4,3 @@ Torch->ONNX converter / exporter. - User-facing docs: https://pytorch.org/docs/main/onnx.html - Developer docs: https://github.com/pytorch/pytorch/wiki/PyTorch-ONNX-exporter - -> Read the following if you are contributing to `torch.onnx` - -## Symbolic functions Opsets - -Opset 9 is the base version. It is selected as the base version because - -1. It is the first opset version supported by PyTorch export. -2. Opset 9 is more robust than previous opset versions. Opset versions like 7/8 have limitations - that certain basic operators cannot be expressed in ONNX. Instead of basing on these limitations, - we chose to handle them as special cases separately. - -Backward support for opset versions beyond opset 7 is not in our roadmap. - -For opset versions other than 9, by default they will inherit the symbolic functions defined in -symbolic_opset9.py. - -To extend support for updated operators in different opset versions on top of opset 9, -simply add the updated symbolic functions in the respective symbolic_opset{version}.py file. -Check out topk in symbolic_opset10.py, and upsample_nearest2d in symbolic_opset8.py for example. - -## Editing Symbolic Files - -- Use the internal `registration.onnx_symbolic` decorator to register a new symbolic function. Search for `def reshape(g, self, shape):` to see an example. -- Parameter names must *exactly* match the names in - aten/src/ATen/native/native_functions.yaml, because - dispatch is done with keyword arguments. -- Looking for inplace ops? They're detected by - `_jit_pass_onnx_remove_inplace_ops_for_onnx`, and - transparently dispatched to their non inplace versions in - "run_symbolic_function". See Note [Export inplace](#export-inplace) - -### A note on Tensor types - -In general, we should avoid depending on the type of Tensor Values contained -within the trace graph. However, this is sometimes unavoidable (due to ONNX -spec requirements, etc). 
The TensorType object has accessors for these properties that return the property if it is statically known and return nullopt otherwise. - -In general, we should prefer to rely on the least specific information possible. -For example, not relying on tensor properties at all is better than relying -on the number of dimensions which is better than relying on -concrete shapes. Doing so will make the export symbolics -more robust to different graphs. - -### Extra context for symbolic functions - -The first argument of a symbolic function is always a `GraphContext` object. - -`GraphContext` contains all methods defined in a `torch.Graph` object and context -for the symbolic function. - -In general, symbolic functions only require inputs and attributes to -the original node. An example of a symbolic function needing context is -`prim::Loop`. It needs access to the sub-block of the original node. - -### Export inplace - -It would be better for us to export inplace annotations, -than to not export them, since it is useful information that can -help the target of an ONNX export export more efficiently. However, -ONNX doesn't currently formalize inplace. Fortunately, it's sound to drop -inplace annotations, but we are losing information this way. - -### Pointwise by scalar - -What happens if you add a tensor with a constant (e.g., x + 2)? There are -some moving parts to implementing the ONNX translation in this case: - -- By the time we get the scalar in a symbolic function here, it is no longer a - Python long/float, but a PyTorch tensor with `numel == 1` (eventually, we want - it to be a zero dim tensor but this change has not happened yet.) However, the - type of this scalar is *exactly* what the user wrote in Python, which may not - match the tensor it is being added to. PyTorch will do implicit conversions on - scalars; however, ONNX will not, so we must do the conversion ourselves. This - is what `symbolic_helper._if_scalar_type_as()` and - `_jit_pass_onnx_scalar_type_analysis` does. - -- Dispatch to these functions takes advantage an outrageous coincidence - between the tensor and scalar name. When we add two tensors together, - you get the dispatch: - - add(*[self, other], **{"alpha": alpha}) - - When you add a tensor and a scalar, you get the dispatch: - - add(*[self], **{"other": other, "alpha": alpha}) - - By having the argument name line up with the name of the scalar attribute - if it exists, we can write a single function for both overloads. 
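To make the "Pointwise by scalar" note above concrete (the same text is re-added verbatim to `torch/onnx/_internal/torchscript_exporter/README.md` later in this diff), here is a minimal illustrative sketch, not the actual opset-9 implementation: the name `symbolic_add` is made up and alpha handling is omitted. It only shows how one body serves both dispatches because the scalar overload passes `other` as a keyword argument with the same name as the tensor argument.

```python
# Illustrative sketch only. The two dispatches described above are
#   add(*[self, other], **{"alpha": alpha})           # tensor + tensor
#   add(*[self], **{"other": other, "alpha": alpha})  # tensor + scalar
# and both land in the same symbolic because the argument names line up.
from torch.onnx import symbolic_helper


def symbolic_add(g, self, other, alpha=None):
    # If `other` started life as a Python scalar it arrives here as a constant
    # tensor with numel == 1; fix up its dtype to match `self` instead of
    # emitting an ONNX Cast (PyTorch casts implicitly, ONNX does not).
    other = symbolic_helper._maybe_get_scalar(other)
    return g.op("Add", self, symbolic_helper._if_scalar_type_as(other, self))
```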
diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py index 7eaa0a5677c4..748ecede13bc 100644 --- a/torch/onnx/__init__.py +++ b/torch/onnx/__init__.py @@ -6,78 +6,43 @@ __all__ = [ # Modules "errors", "ops", - "symbolic_helper", - "utils", - # All opsets - "symbolic_opset7", - "symbolic_opset8", - "symbolic_opset9", - "symbolic_opset10", - "symbolic_opset11", - "symbolic_opset12", - "symbolic_opset13", - "symbolic_opset14", - "symbolic_opset15", - "symbolic_opset16", - "symbolic_opset17", - "symbolic_opset18", - "symbolic_opset19", - "symbolic_opset20", - # Enums - "OperatorExportTypes", - "TrainingMode", - "TensorProtoDataType", - "JitScalarType", # Public functions "export", "is_in_onnx_export", - "select_model_mode_for_export", - "register_custom_op_symbolic", - "unregister_custom_op_symbolic", # Base error "OnnxExporterError", "ONNXProgram", ] from typing import Any, Callable, TYPE_CHECKING -from typing_extensions import deprecated import torch from torch._C import _onnx as _C_onnx -from torch._C._onnx import OperatorExportTypes, TensorProtoDataType, TrainingMode +from torch._C._onnx import ( # Deprecated members that are excluded from __all__ + OperatorExportTypes as OperatorExportTypes, + TensorProtoDataType as TensorProtoDataType, + TrainingMode as TrainingMode, +) +from . import errors, ops from ._internal.exporter._onnx_program import ONNXProgram -from ._type_utils import JitScalarType -from .errors import OnnxExporterError -from .utils import ( +from ._internal.torchscript_exporter import ( # Deprecated members that are excluded from __all__ + symbolic_helper, + symbolic_opset10, + symbolic_opset9, + utils, +) +from ._internal.torchscript_exporter._type_utils import ( + JitScalarType, # Deprecated members that are excluded from __all__ +) +from ._internal.torchscript_exporter.utils import ( # Deprecated members that are excluded from __all__ _run_symbolic_function, _run_symbolic_method, register_custom_op_symbolic, select_model_mode_for_export, unregister_custom_op_symbolic, ) - - -from . import ( # usort: skip. Keep the order instead of sorting lexicographically - errors, - ops, - symbolic_helper, - symbolic_opset7, - symbolic_opset8, - symbolic_opset9, - symbolic_opset10, - symbolic_opset11, - symbolic_opset12, - symbolic_opset13, - symbolic_opset14, - symbolic_opset15, - symbolic_opset16, - symbolic_opset17, - symbolic_opset18, - symbolic_opset19, - symbolic_opset20, - utils, -) +from .errors import OnnxExporterError if TYPE_CHECKING: @@ -85,10 +50,10 @@ if TYPE_CHECKING: from collections.abc import Collection, Mapping, Sequence # Set namespace for exposed private names -JitScalarType.__module__ = "torch.onnx" ONNXProgram.__module__ = "torch.onnx" OnnxExporterError.__module__ = "torch.onnx" +# TODO(justinchuby): Remove these two properties producer_name = "pytorch" producer_version = _C_onnx.PRODUCER_VERSION @@ -385,7 +350,7 @@ def export( else: import warnings - from torch.onnx.utils import export + from ._internal.torchscript_exporter.utils import export warnings.warn( "You are using the legacy TorchScript-based ONNX export. 
Starting in PyTorch 2.9, " @@ -429,7 +394,7 @@ def export( def is_in_onnx_export() -> bool: """Returns whether it is in the middle of ONNX export.""" - from torch.onnx._globals import GLOBALS from torch.onnx._internal.exporter import _flags + from torch.onnx._internal.torchscript_exporter._globals import GLOBALS return GLOBALS.in_onnx_export or _flags._is_onnx_exporting diff --git a/torch/onnx/_internal/torchscript_exporter/README.md b/torch/onnx/_internal/torchscript_exporter/README.md new file mode 100644 index 000000000000..af0ca464beda --- /dev/null +++ b/torch/onnx/_internal/torchscript_exporter/README.md @@ -0,0 +1,91 @@ +# TorchScript Exporter + +> [!NOTE] +> This directory hosts code for the legacy TorchScript-based ONNX exporter. It is *deprecated* since PyTorch 2.9 and should be removed along with TorchScript. + +## Symbolic functions Opsets + +Opset 9 is the base version. It is selected as the base version because + +1. It is the first opset version supported by PyTorch export. +2. Opset 9 is more robust than previous opset versions. Opset versions like 7/8 have limitations + that certain basic operators cannot be expressed in ONNX. Instead of basing on these limitations, + we chose to handle them as special cases separately. + +Backward support for opset versions beyond opset 7 is not in our roadmap. + +For opset versions other than 9, by default they will inherit the symbolic functions defined in +symbolic_opset9.py. + +To extend support for updated operators in different opset versions on top of opset 9, +simply add the updated symbolic functions in the respective symbolic_opset{version}.py file. +Check out topk in symbolic_opset10.py, and upsample_nearest2d in symbolic_opset8.py for example. + +## Editing Symbolic Files + +- Use the internal `registration.onnx_symbolic` decorator to register a new symbolic function. Search for `def reshape(g, self, shape):` to see an example. +- Parameter names must *exactly* match the names in + aten/src/ATen/native/native_functions.yaml, because + dispatch is done with keyword arguments. +- Looking for inplace ops? They're detected by + `_jit_pass_onnx_remove_inplace_ops_for_onnx`, and + transparently dispatched to their non inplace versions in + "run_symbolic_function". See Note [Export inplace](#export-inplace) + +### A note on Tensor types + +In general, we should avoid depending on the type of Tensor Values contained +within the trace graph. However, this is sometimes unavoidable (due to ONNX +spec requirements, etc). The TensorType object has accessors for these properties that return the property if it is statically known and return nullopt otherwise. + +In general, we should prefer to rely on the least specific information possible. +For example, not relying on tensor properties at all is better than relying +on the number of dimensions which is better than relying on +concrete shapes. Doing so will make the export symbolics +more robust to different graphs. + +### Extra context for symbolic functions + +The first argument of a symbolic function is always a `GraphContext` object. + +`GraphContext` contains all methods defined in a `torch.Graph` object and context +for the symbolic function. + +In general, symbolic functions only require inputs and attributes to +the original node. An example of a symbolic function needing context is +`prim::Loop`. It needs access to the sub-block of the original node. 
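As a concrete illustration of the `GraphContext` contract described above: `g` behaves like a `torch.Graph` with export state attached (for example `g.opset`), and symbolic functions emit ONNX nodes through `g.op`. The sketch below is hypothetical usage, assuming the deprecated `register_custom_op_symbolic` re-export that this PR keeps in `torch/onnx/__init__.py`; the function name `asinh_symbolic` is ours, not part of this change.

```python
# Hypothetical usage sketch: wiring a symbolic function into the legacy
# TorchScript exporter via the public registration helper.
from torch.onnx import register_custom_op_symbolic


def asinh_symbolic(g, input):
    # `g` is a GraphContext; Asinh is a standard ONNX operator since opset 9,
    # so the symbolic is a straight pass-through of the single input.
    return g.op("Asinh", input)


# "aten::asinh" follows the domain::op naming convention; the trailing 9 is
# the opset version the symbolic is registered for.
register_custom_op_symbolic("aten::asinh", asinh_symbolic, 9)
```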
+ +### Export inplace + +It would be better for us to export inplace annotations, +than to not export them, since it is useful information that can +help the target of an ONNX export export more efficiently. However, +ONNX doesn't currently formalize inplace. Fortunately, it's sound to drop +inplace annotations, but we are losing information this way. + +### Pointwise by scalar + +What happens if you add a tensor with a constant (e.g., x + 2)? There are +some moving parts to implementing the ONNX translation in this case: + +- By the time we get the scalar in a symbolic function here, it is no longer a + Python long/float, but a PyTorch tensor with `numel == 1` (eventually, we want + it to be a zero dim tensor but this change has not happened yet.) However, the + type of this scalar is *exactly* what the user wrote in Python, which may not + match the tensor it is being added to. PyTorch will do implicit conversions on + scalars; however, ONNX will not, so we must do the conversion ourselves. This + is what `symbolic_helper._if_scalar_type_as()` and + `_jit_pass_onnx_scalar_type_analysis` does. + +- Dispatch to these functions takes advantage an outrageous coincidence + between the tensor and scalar name. When we add two tensors together, + you get the dispatch: + + add(*[self, other], **{"alpha": alpha}) + + When you add a tensor and a scalar, you get the dispatch: + + add(*[self], **{"other": other, "alpha": alpha}) + + By having the argument name line up with the name of the scalar attribute + if it exists, we can write a single function for both overloads. diff --git a/torch/onnx/_internal/torchscript_exporter/__init__.py b/torch/onnx/_internal/torchscript_exporter/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/torch/onnx/_experimental.py b/torch/onnx/_internal/torchscript_exporter/_experimental.py similarity index 100% rename from torch/onnx/_experimental.py rename to torch/onnx/_internal/torchscript_exporter/_experimental.py diff --git a/torch/onnx/_globals.py b/torch/onnx/_internal/torchscript_exporter/_globals.py similarity index 100% rename from torch/onnx/_globals.py rename to torch/onnx/_internal/torchscript_exporter/_globals.py diff --git a/torch/onnx/_type_utils.py b/torch/onnx/_internal/torchscript_exporter/_type_utils.py similarity index 100% rename from torch/onnx/_type_utils.py rename to torch/onnx/_internal/torchscript_exporter/_type_utils.py diff --git a/torch/onnx/_internal/jit_utils.py b/torch/onnx/_internal/torchscript_exporter/jit_utils.py similarity index 97% rename from torch/onnx/_internal/jit_utils.py rename to torch/onnx/_internal/torchscript_exporter/jit_utils.py index f3f82c0db7c2..6c00b6a9c8c4 100644 --- a/torch/onnx/_internal/jit_utils.py +++ b/torch/onnx/_internal/torchscript_exporter/jit_utils.py @@ -1,9 +1,6 @@ # mypy: allow-untyped-defs """Utilities for manipulating the torch.Graph object and the torchscript.""" -# TODO(justinchuby): Move more of the symbolic helper functions here and expose -# them to the user. 
- from __future__ import annotations import dataclasses @@ -14,8 +11,8 @@ from typing import Any import torch from torch import _C -from torch.onnx._globals import GLOBALS -from torch.onnx._internal import registration +from torch.onnx._internal.torchscript_exporter import registration +from torch.onnx._internal.torchscript_exporter._globals import GLOBALS _ATTR_PATTERN = re.compile("^(.+)_(([ifstgz])|(ty))$") @@ -89,7 +86,6 @@ class GraphContext: The value representing the single output of this operator (see the `outputs` keyword argument for multi-return nodes). """ - # FIXME(justinchuby): Add the return type back once we know how to handle mypy return _add_op(self, opname, *raw_args, outputs=outputs, **kwargs) def aten_op(self, operator: str, *args, overload_name: str = "", **kwargs): @@ -211,8 +207,6 @@ def _add_op( The set of operators and the inputs/attributes they take is documented at https://github.com/onnx/onnx/blob/master/docs/Operators.md - This function is monkey-patched onto Graph. - Args: graph_context: The Torch Graph or Block. opname: The ONNX operator name, e.g., `Abs` or `Add`, or an operator qualified @@ -337,7 +331,6 @@ def _add_attribute(node: _C.Node, key: str, value: Any, aten: bool): return getattr(node, f"{kind}_")(name, value) -# TODO: Expose this to user when migrating symbolic helper functions to here. def _is_tensor(x: _C.Value) -> bool: return x.type().isSubtypeOf(_C.TensorType.get()) diff --git a/torch/onnx/_internal/onnx_proto_utils.py b/torch/onnx/_internal/torchscript_exporter/onnx_proto_utils.py similarity index 99% rename from torch/onnx/_internal/onnx_proto_utils.py rename to torch/onnx/_internal/torchscript_exporter/onnx_proto_utils.py index 04ed0f83ef84..c79786cf707d 100644 --- a/torch/onnx/_internal/onnx_proto_utils.py +++ b/torch/onnx/_internal/torchscript_exporter/onnx_proto_utils.py @@ -9,10 +9,9 @@ import shutil from typing import Any, TYPE_CHECKING import torch -import torch.jit._trace import torch.serialization from torch.onnx import errors -from torch.onnx._internal import jit_utils, registration +from torch.onnx._internal.torchscript_exporter import jit_utils, registration if TYPE_CHECKING: diff --git a/torch/onnx/_internal/registration.py b/torch/onnx/_internal/torchscript_exporter/registration.py similarity index 100% rename from torch/onnx/_internal/registration.py rename to torch/onnx/_internal/torchscript_exporter/registration.py diff --git a/torch/onnx/_internal/torchscript_exporter/symbolic_helper.py b/torch/onnx/_internal/torchscript_exporter/symbolic_helper.py new file mode 100644 index 000000000000..a5e85aed01ef --- /dev/null +++ b/torch/onnx/_internal/torchscript_exporter/symbolic_helper.py @@ -0,0 +1,2380 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + + +__all__ = [ + "_apply_params", + "_arange_cast_helper", + "_arange_helper", + "_argmin_argmax_helper", + "_as_list_type", + "_avgpool_helper", + "_batchnorm_helper", + "_block_list_in_opset", + "_embedding_bag_helper", + "_flatten_helper", + "_generate_wrapped_number", + "_get_const", + "_get_dim_for_cross", + "_get_interpolate_attributes", + "_get_tensor_dim_size", + "_get_tensor_rank", + "_get_tensor_sizes", + "_handle_reduce_dim_none", + "_if_scalar_type_as", + "_index_fill_reshape_helper", + "_interpolate_get_scales_and_mode", + "_interpolate_get_scales_if_available", + "_interpolate_get_scales", + "_interpolate_helper", + "_interpolate_size_to_scales", + "_interpolate_warning", + "_is_bool", + "_is_constant", + "_is_fp", + "_is_list", + "_is_none", + 
"_is_onnx_constant", + "_is_packed_list", + "_is_scalar_list", + "_is_split_static", + "_is_tensor_list", + "_is_tensor", + "_is_tuple_construct", + "_is_value", + "_linalg_vector_norm_helper", + "_lt_helper", + "_max_helper", + "_maybe_cast_reduce_op_input", + "_maybe_cast_to_type", + "_maybe_get_const", + "_maybe_get_scalar", + "_min_helper", + "_node_get", + "_numel_helper", + "_onnx_opset_unsupported_detailed", + "_onnx_opset_unsupported", + "_onnx_unsupported", + "_op_with_optional_float_cast", + "_optional_input_placeholder_tensor", + "_overload_by_arg_count", + "_parse_arg", + "_reduce_op_symbolic_helper", + "_reduce_with_dtype_helper", + "_reducesum_helper", + "_repeat_interleave_single_value_repeat_helper", + "_repeat_interleave_split_helper", + "_reshape_helper", + "_scalar", + "_scatter_helper", + "_select_helper", + "_size_helper", + "_slice_helper", + "_sort_helper", + "_squeeze_helper", + "_topk_helper", + "_try_get_scalar_type", + "_type_promote_from_values", + "_unbind_helper", + "_unimplemented", + "_unpack_list", + "_unpack_quantized_tensor", + "_unpack_tuple", + "_unsqueeze_helper", + "_var_mean_helper", + "args_have_same_dtype", + "cast_pytorch_to_onnx", + "check_training_mode", + "dequantize_helper", + "is_complex_value", + "parse_args", + "pytorch_name_to_type", + "quantize_helper", + "quantized_args", + "requantize_bias_helper", + "scalar_name_to_pytorch", + "scalar_type_to_onnx", + "scalar_type_to_pytorch_type", +] + +import functools +import inspect +import math +import sys +import typing +import warnings +from typing import Any, Callable, Literal, NoReturn, TypeVar as _TypeVar +from typing_extensions import Concatenate as _Concatenate, ParamSpec as _ParamSpec + +import torch +import torch._C._onnx as _C_onnx +from torch import _C +from torch.onnx import _constants, errors +from torch.onnx._internal.torchscript_exporter import _type_utils, jit_utils, utils +from torch.onnx._internal.torchscript_exporter._globals import GLOBALS + + +if typing.TYPE_CHECKING: + from collections.abc import Sequence + + from torch.types import Number + +_T = _TypeVar("_T") +_U = _TypeVar("_U") +_P = _ParamSpec("_P") + +# --------------------------------------------------------------------------------- +# Helper functions +# --------------------------------------------------------------------------------- + +_ValueDescriptor = Literal[ + "v", + "i", + "is", + "f", + "fs", + "b", + "s", + "t", + "none", +] + + +def _parse_arg( + value, + desc: _ValueDescriptor, + arg_name: str | None = None, + node_name: str | None = None, +): + if desc == "none": + return value + if desc == "v" or not _is_value(value): + return value + + node = value.node() + if node.mustBeNone(): + return None + if node.kind() == "onnx::Constant": + node_val = _node_get(node, "value") + if desc == "i": + return int(node_val) + elif desc == "f": + return float(node_val) + elif desc == "b": + return bool(node_val) + elif desc == "s": + return str(node_val) + elif desc == "t": + return node_val + elif desc == "is": + return [int(v) for v in node_val] + elif desc == "fs": + return [float(v) for v in node_val] + else: + raise errors.SymbolicValueError( + f"ONNX symbolic does not understand the Constant node '{node}' " + f"specified with descriptor '{desc}'.", + value, + ) + elif node.kind() == "prim::ListConstruct": + if desc == "is": + for v in node.inputs(): + element_node = v.node() + if element_node.kind() != "onnx::Constant": + raise errors.SymbolicValueError( + f"Failed to export a node '{element_node}' " + f"(in list 
node {node}) " + f"because it is not constant. " + f"Please try to make things (e.g. kernel sizes) static if possible.", + value, + ) + return [int(_node_get(v.node(), "value")) for v in value.node().inputs()] + else: + raise errors.SymbolicValueError( + f"ONNX symbolic does not know how to unpack the ListConstruct node that " + f"is not a list of integers: '{node}'", + value, + ) + + if arg_name is None or node_name is None: + raise errors.SymbolicValueError( + f"Expected node type 'onnx::Constant', got '{node.kind()}'.", + value, + ) + + raise errors.SymbolicValueError( + "Expected node type 'onnx::Constant' " + f"for argument '{arg_name}' of node '{node_name}', got '{node.kind()}'.", + value, + ) + + +def _node_get(node: _C.Node, key: str): + """Gets attributes of a node which is polymorphic over return type.""" + assert isinstance(node, _C.Node) + sel = node.kindOf(key) + return getattr(node, sel)(key) + + +def _is_onnx_constant(value: _C.Value): + """Whether a Value is an ONNX constant.""" + return value.node().kind() == "onnx::Constant" + + +def _maybe_get_const( + value: _C.Value | torch.Tensor | Number | Sequence | None, + descriptor: _ValueDescriptor, +): + # NOTE: prim::Constant at this stage usually means something not compatible in ONNX, + # otherwise it'd be converted to onnx::Constant + # TODO(justinchuby): Replace insinstance with _is_value once we figure out mypy + if isinstance(value, _C.Value) and _is_onnx_constant(value): + return _parse_arg(value, descriptor) + return value + + +def _maybe_get_scalar(value): + value_t = _maybe_get_const(value, "t") + if isinstance(value_t, torch.Tensor) and value_t.shape == (): + return value_t + return value + + +def _get_const(value, desc, arg_name): + if not _is_constant(value): + raise errors.SymbolicValueError( + f"ONNX symbolic expected a constant value of the '{arg_name}' argument, " + f"got '{value}'", + value, + ) + return _parse_arg(value, desc) + + +def _unpack_list(list_value: _C.Value) -> list[_C.Value]: + list_node = list_value.node() + if list_node.kind() != "prim::ListConstruct": + raise errors.SymbolicValueError( + f"ONNX symbolic expected node type prim::ListConstruct, got '{list_node}'.", + list_value, + ) + return list(list_node.inputs()) + + +def _unpack_tuple(tuple_value: _C.Value) -> tuple[_C.Value, ...]: + tuple_node = tuple_value.node() + if not _is_tuple_construct(tuple_value): + raise errors.SymbolicValueError( + f"ONNX symbolic expected node type 'prim::TupleConstruct', " + f"got '{tuple_node.kind()}'.", + tuple_value, + ) + return tuple(tuple_node.inputs()) + + +def _unpack_quantized_tensor(tuple_value: _C.Value) -> tuple[_C.Value, ...]: + """Unpacks a quantized tensor into a tuple of tensor and scale/zero_point. + Args: + tuple_value: A tuple of tensor, scale, zero_point, and optionally axis. + Returns: + A tuple of tensor, scale, zero_point, and optionally axis. + """ + tuple_node = tuple_value.node() + # A quantized tensor is represented as tuple of the form (tensor, scale, zero_point, ) + if not _is_tuple_construct(tuple_value): + raise errors.SymbolicValueError( + f"ONNX symbolic expected the output of `{tuple_node}` to be a quantized " + f"tensor. Is this likely due to missing support for quantized " + f"`{tuple_node.kind()}`. 
Please create an issue on {_constants.PYTORCH_GITHUB_ISSUES_URL}", + tuple_value, + ) + unpacked = tuple(tuple_node.inputs()) + assert len(unpacked) == 3 or len(unpacked) == 4 + return unpacked + + +# Check if list_value is output from prim::ListConstruct +# This is usually called before _unpack_list to ensure the list can be unpacked. +def _is_packed_list(list_value: Any) -> bool: + return _is_value(list_value) and list_value.node().kind() == "prim::ListConstruct" + + +def parse_args( + *arg_descriptors: _ValueDescriptor, +) -> Callable[[Callable[_Concatenate[_U, _P], _T]], Callable[_Concatenate[_U, _P], _T]]: + """A decorator which converts args from torch._C.Value to built-in types. + + For example: + + ``` + @parse_args('v', 'i', 'fs') + foo(g, a, b, c): + assert isinstance(a, torch._C.Value) + assert isinstance(b, int) + assert isinstance(c, list) + assert isinstance(c[0], float) + ``` + + Args: + arg_descriptors: list of str, where each element is + a string that specifies the type to convert to. Valid descriptors: + "v": no conversion, keep torch._C.Value. + "i": int + "is": list of int + "f": float + "fs": list of float + "b": bool + "s": str + "t": torch.Tensor + "none": the variable is unused + """ + + def decorator( + fn: Callable[_Concatenate[_U, _P], _T], + ) -> Callable[_Concatenate[_U, _P], _T]: + fn._arg_descriptors = arg_descriptors # type: ignore[attr-defined] + + @functools.wraps(fn) + def wrapper(g: _U, *args: _P.args, **kwargs: _P.kwargs) -> _T: + # some args may be optional, so the length may be smaller + FILE_BUG_MSG = ( + "If you believe this is not due to custom symbolic implementation within your code or " + "an external library, please file an issue at " + "https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml to report this bug." + ) + assert len(arg_descriptors) >= len(args), ( + f"A mismatch between the number of arguments ({len(args)}) and " + f"their descriptors ({len(arg_descriptors)}) was found at symbolic function '{fn.__name__}'. " + f"{FILE_BUG_MSG}" + ) + + try: + sig = inspect.signature(fn) + arg_names = list(sig.parameters.keys())[1:] + fn_name = fn.__name__ + except Exception: + # FIXME(justinchuby): Avoid catching Exception. + # Catch a more specific exception instead. + arg_names = [None] * len(args) # type: ignore[list-item] + fn_name = None + args = [ + _parse_arg(arg, arg_desc, arg_name, fn_name) # type: ignore[method-assign] + for arg, arg_desc, arg_name in zip(args, arg_descriptors, arg_names) + ] + # only support _outputs in kwargs + assert len(kwargs) <= 1, ( + f"Symbolic function {fn.__name__}'s '**kwargs' can contain a single " + f"key/value entry. " + f"{FILE_BUG_MSG}" + ) + + if len(kwargs) == 1: + assert "_outputs" in kwargs, ( + f"Symbolic function {fn.__name__}'s '**kwargs' can only contain " + f"'_outputs' key at '**kwargs'. " + f"{FILE_BUG_MSG}" + ) + return fn(g, *args, **kwargs) + + return wrapper + + return decorator + + +def quantized_args( + *arg_q_descriptors: bool, + scale: float | None = None, + zero_point: int | None = None, + quantize_output: bool = True, +) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: + """A decorator which extends support for quantized version of the base operator. + + Quantization is detected by examining the arguments that are annotated by + `arg_q_descriptors`. + + If quantization is detected, the base operator symbolic function will be wrapped with + argument de-quantization and output quantization. + + Otherwise, only the base symbolic function will be invoked. 
+ + For example: + + ``` + @quantized_args(True, False) + def foo(g, x, y): + return x + y + ``` + + is equivalent to + + ``` + def q_foo(g, x, y): + if is_quantized_tensor(x): + x = dequantize(x) + out = foo(g, x, y) + return quantize(out) + else: + return foo(g, x, y) + ``` + + Args: + arg_q_descriptors: A sequence of bool, where each element represents if the + argument is QTensor for quantized version of this operator. It defaults + to False for unspecified (variable length) arguments. + scale: Quantized output scale. If None, derive from + the first quantized input scale. + zero_point: Quantized output zero point. If None, + derive from the first quantized input zero point. + quantize_output: If True, quantize the output of the base operator. Default is True + """ + + def decorator(fn): + @functools.wraps(fn) + def wrapper(g, *args, **kwargs): + nonlocal scale + nonlocal zero_point + if scale is not None: + _scale = g.op("Constant", value_t=torch.tensor(scale)) + else: + _scale = None + if zero_point is not None: + _zero_point = g.op("Constant", value_t=torch.tensor(zero_point)) + else: + _zero_point = None + + # Support variable length arguments by marking unspecified ones as non-quantized + arg_q_descriptors_extended = arg_q_descriptors + (False,) * ( + len(args) - len(arg_q_descriptors) + ) + descriptor_args = tuple(zip(arg_q_descriptors_extended, args)) + + def _is_arg_quantized(descriptor, arg): + return descriptor and _is_value(arg) and _is_tuple_construct(arg) + + # Run regular symbolic function if none of the argument is QTensor. + is_quantized: list[bool] = [] + for descriptor, arg in descriptor_args: + # ListConstruct + if _is_packed_list(arg): + is_quantized.extend( + _is_arg_quantized(descriptor, arg_input) + for arg_input in arg.node().inputs() + ) + else: + is_quantized.append(_is_arg_quantized(descriptor, arg)) + + if not any(is_quantized): + return fn(g, *args, **kwargs) + + # Dequantize arguments that are quantized + non_quantized_args = [] + for descriptor, arg in descriptor_args: + if _is_arg_quantized(descriptor, arg): + # Quantized arg is a tuple of (value, scale, zero_point) + dequantized_arg, arg_scale, arg_zero_point, _ = dequantize_helper( + g, arg + ) + non_quantized_args.append(dequantized_arg) + # Set scale and zero_point to the first quantized input if not already set + if _scale is None: + _scale = arg_scale + if _zero_point is None: + _zero_point = arg_zero_point + # ListConstruct + elif _is_packed_list(arg): + for arg_input in arg.node().inputs(): + if _is_arg_quantized(descriptor, arg_input): + # Quantized arg is a tuple of (value, scale, zero_point) + ( + dequantized_arg, + arg_scale, + arg_zero_point, + _, + ) = dequantize_helper(g, arg_input) + # Set scale and zero_point to the first quantized input if not already set + if _scale is None: + _scale = arg_scale + if _zero_point is None: + _zero_point = arg_zero_point + arg_input.replaceAllUsesWith(dequantized_arg) + non_quantized_args.append(arg) + else: + # Non-quantized arg + non_quantized_args.append(arg) + # TODO(justinchuby): Only single output is supported for now. We may want to + # support multiple outputs in the future. 
+ output = fn(g, *non_quantized_args, **kwargs) + + assert _scale is not None, "Bug: Scale must be set for quantized operator" + assert _zero_point is not None, ( + "Bug: Zero point must be set for quantized operator" + ) + + if quantize_output: + return quantize_helper(g, output, _scale, _zero_point) + return output + + return wrapper + + return decorator + + +def _scalar(x: Any) -> Number | None: + """Convert a scalar tensor into a Python value.""" + if isinstance(x, torch.Tensor) and x.shape == (): + return x.item() + return None + + +def _if_scalar_type_as(self, tensor): + """ + Convert self into the same type of tensor, as necessary. + We only support implicit casting for scalars, so we never + actually need to insert an ONNX cast operator here; just + fix up the scalar. + """ + if isinstance(self, _C.Value): + return self + + scalar_type = _type_utils.JitScalarType.from_value( + tensor, _type_utils.JitScalarType.UNDEFINED + ) + if scalar_type != _type_utils.JitScalarType.UNDEFINED: + ty = scalar_type.scalar_name().lower() + return getattr(self, ty)() + return self + + +def _is_none(x: Any) -> bool: + return x is None or (x.node().mustBeNone() if isinstance(x, _C.Value) else False) + + +def _is_value(x: Any) -> bool: + return isinstance(x, _C.Value) + + +def _is_constant(value: Any) -> bool: + return not _is_value(value) or value.node().kind() in { + "onnx::Constant", + "prim::Constant", + } + + +def _is_tensor(x: _C.Value) -> bool: + return x.type().isSubtypeOf(_C.TensorType.get()) + + +# Note: _C.JitType is not exposed to Python and cannot be checked in runtime. +def _as_list_type(jit_type: _C.JitType) -> _C.ListType | None: + if isinstance(jit_type, _C.ListType): + return jit_type + return None + + +def _is_list(x: _C.Value) -> bool: + return _as_list_type(x.type()) is not None + + +def _is_tensor_list(x: _C.Value) -> bool: + x_type = _as_list_type(x.type()) + if x_type is None: + return False + return isinstance(x_type.getElementType(), _C.TensorType) + + +def _is_scalar_list(x: _C.Value) -> bool: + """Checks if x is a scalar list, for example: List[float], List[int]. + + Besides checking the type is ListType, we also check if the data type is + a valid ONNX data type. + """ + x_type = _as_list_type(x.type()) + if x_type is None: + return False + scalar_type = _type_utils.JitScalarType.from_value(x) + return scalar_type.onnx_compatible() + + +def _is_tuple_construct(x: _C.Value) -> bool: + return x.node().kind() == "prim::TupleConstruct" + + +def is_complex_value(x: _C.Value) -> bool: + assert _is_value(x) + return _type_utils.JitScalarType.from_value( + x, _type_utils.JitScalarType.UNDEFINED + ) in { + _type_utils.JitScalarType.COMPLEX32, + _type_utils.JitScalarType.COMPLEX64, + _type_utils.JitScalarType.COMPLEX128, + } + + +def _get_tensor_rank(x: _C.Value) -> int | None: + if not _is_tensor(x) or x.type() is None: + return None + x_type = x.type() + x_type = typing.cast(_C.TensorType, x_type) + return x_type.dim() + + +def _get_tensor_sizes(x: _C.Value, allow_nonstatic: bool = True): + if not _is_tensor(x) or x.type() is None: + return None + x_type = x.type() + x_type = typing.cast(_C.TensorType, x_type) + if allow_nonstatic: + # Each individual symbol is returned as None. + # e.g. [1, "a", "b"] -> [1, None, None] + return x_type.varyingSizes() + # returns None, if exists any symbol in sizes. + # e.g. 
[1, "a", "b"] -> None + return x_type.sizes() + + +def _get_tensor_dim_size(x: _C.Value, dim: int) -> int | None: + sizes = _get_tensor_sizes(x) + return sizes[dim] if sizes else None + + +def _get_dim_for_cross(x: _C.Value, dim: int | None): + if dim == -1: + tensor_rank = _get_tensor_rank(x) + assert tensor_rank is not None + return dim + tensor_rank + # If dim is not given, it defaults to the first dimension found with the size 3 + if dim is None: + sizes = _get_tensor_sizes(x) + assert sizes is not None + for index, size in enumerate(sizes): + if size is not None and size == 3: + return index + return dim + + +def _unimplemented(op: str, msg: str, value: _C.Value | None = None) -> None: + # For BC reasons, the behavior for Caffe2 does not raise exception for unimplemented operators + if GLOBALS.operator_export_type == _C_onnx.OperatorExportTypes.ONNX: + _onnx_unsupported(f"{op}, {msg}", value) + + +def _onnx_unsupported(op_name: str, value: _C.Value | None = None) -> NoReturn: + message = ( + f"Unsupported: ONNX export of operator {op_name}. " + f"Please feel free to request support or submit a pull request " + f"on PyTorch GitHub: {_constants.PYTORCH_GITHUB_ISSUES_URL}" + ) + if isinstance(value, _C.Value): + raise errors.SymbolicValueError( + message, + value, + ) + raise errors.OnnxExporterError(message) + + +def _onnx_opset_unsupported( + op_name: str, + current_opset: int, + supported_opset: int, + value: _C.Value | None = None, +) -> NoReturn: + message = ( + f"Unsupported: ONNX export of {op_name} in opset {current_opset}. " + f"Please try opset version {supported_opset}." + ) + if isinstance(value, _C.Value): + raise errors.SymbolicValueError( + message, + value, + ) + raise errors.OnnxExporterError(message) + + +def _onnx_opset_unsupported_detailed( + op_name: str, + current_opset: int, + supported_opset: int, + reason: str, + value: _C.Value | None = None, +) -> NoReturn: + message = ( + f"Unsupported: ONNX export of {op_name} in " + f"opset {current_opset}. {reason}. Please try opset version {supported_opset}." + ) + if isinstance(value, _C.Value): + raise errors.SymbolicValueError( + message, + value, + ) + raise errors.OnnxExporterError(message) + + +def _block_list_in_opset(name: str): + def symbolic_fn(*args, **kwargs): + raise errors.OnnxExporterError( + f"ONNX export failed on {name}, which is not implemented for opset " + f"{GLOBALS.export_onnx_opset_version}. " + "Try exporting with other opset versions." 
+ ) + + return symbolic_fn + + +def _try_get_scalar_type(*args) -> _type_utils.JitScalarType | None: + for arg in args: + scalar_type = _type_utils.JitScalarType.from_value( + arg, _type_utils.JitScalarType.UNDEFINED + ) + if scalar_type != _type_utils.JitScalarType.UNDEFINED: + return scalar_type + return None + + +def _type_promote_from_values(*args) -> _type_utils.JitScalarType: + undef = _type_utils.JitScalarType.UNDEFINED + jit_types = [_try_get_scalar_type(arg) for arg in args] + if len(jit_types) == 0: + return undef + if len(jit_types) == 1: + return jit_types[0] # type: ignore[return-value] + new_dtype = jit_types[0].dtype() # type: ignore[union-attr] + for t in jit_types: + new_dtype = torch.promote_types(new_dtype, t.dtype()) # type: ignore[union-attr] + return _type_utils.JitScalarType.from_dtype(new_dtype) + + +def _maybe_cast_to_type( + g: jit_utils.GraphContext, value, jit_type: _type_utils.JitScalarType +): + if ( + _type_utils.JitScalarType.from_value(value, _type_utils.JitScalarType.UNDEFINED) + != jit_type + ): + return g.op( + "Cast", + value, + to_i=jit_type.onnx_type(), + ) + return value + + +def _select_helper(g: jit_utils.GraphContext, self, dim, index, apply_reshape=True): + index_const = _maybe_get_scalar(index) + index_dim = _get_tensor_rank(index) + if not _is_value(index_const): + # Index is a constant scalar. Make it a size 1 constant tensor. + index = g.op("Constant", value_t=torch.LongTensor([index_const])) + elif index_dim is not None and apply_reshape: + if index_dim == 0: + # Index is a scalar. Reshape it to a size 1 tensor. + index = _reshape_helper( + g, index, g.op("Constant", value_t=torch.LongTensor([1])) + ) + + index_scalar_type = _type_utils.JitScalarType.from_value( + index, _type_utils.JitScalarType.UNDEFINED + ) + if index_scalar_type not in { + _type_utils.JitScalarType.INT64, + _type_utils.JitScalarType.INT, + }: + index = g.op("Cast", index, to_i=_C_onnx.TensorProtoDataType.INT64) + return g.op("Gather", self, index, axis_i=dim) + + +def _slice_helper( + g: jit_utils.GraphContext, + input, + axes, + starts, + ends, + steps=None, +): + if g.opset <= 9: + from torch.onnx._internal.torchscript_exporter.symbolic_opset9 import ( + _slice as _slice9, + ) + + return _slice9(g, input, axes, starts, ends) + else: + from torch.onnx._internal.torchscript_exporter.symbolic_opset10 import ( + _slice as _slice10, + ) + + return _slice10(g, input, axes, starts, ends, steps) + + +def _is_fp(value) -> bool: + return _type_utils.JitScalarType.from_value( + value, _type_utils.JitScalarType.UNDEFINED + ) in { + _type_utils.JitScalarType.FLOAT, + _type_utils.JitScalarType.DOUBLE, + _type_utils.JitScalarType.HALF, + _type_utils.JitScalarType.BFLOAT16, + } + + +def _is_bool(value) -> bool: + return _type_utils.JitScalarType.from_value( + value, _type_utils.JitScalarType.UNDEFINED + ) in {_type_utils.JitScalarType.BOOL} + + +def _generate_wrapped_number(g: jit_utils.GraphContext, scalar): + """Creates a wrapped number based on https://github.com/pytorch/pytorch/issues/9515. + + A Tensor is a considered a "wrapped number" if it is + auto-wrapped from a C++ or Python number type. Integer types are + wrapped as 0-dim int64 tensors and floating-point types are + wrapped as 0-dim double tensors. + + The input to this function is constant value. 
If the data type + is a floating point type, it is converted to a 0-dim double + tensor, else it is converted to a 0-dim tensor of its original type + """ + assert not isinstance(scalar, torch.Tensor) + if isinstance(scalar, float): + return g.op("Constant", value_t=torch.tensor(scalar, dtype=torch.double)) + return g.op("Constant", value_t=torch.tensor(scalar)) + + +def _sort_helper(g: jit_utils.GraphContext, input, dim, descending=True, out=None): + if out is not None: + _unimplemented("Sort", "Out parameter is not supported") + shape_ = g.op("Shape", input) + dim_size_ = g.op( + "Gather", + shape_, + g.op("Constant", value_t=torch.tensor([dim], dtype=torch.int64)), + ) + if g.opset <= 10: + if not descending: + _unimplemented("Sort", "Ascending is not supported") + return g.op("TopK", input, dim_size_, axis_i=dim, outputs=2) + else: + return g.op( + "TopK", input, dim_size_, axis_i=dim, largest_i=descending, outputs=2 + ) + + +def _topk_helper( + g: jit_utils.GraphContext, input, k, dim, largest=True, sorted=False, out=None +): + if out is not None: + _unimplemented("TopK", "Out parameter is not supported") + if not _is_value(k): + k = g.op("Constant", value_t=torch.tensor([k], dtype=torch.int64)) + else: + k = _reshape_helper(g, k, g.op("Constant", value_t=torch.tensor([1]))) + if _try_get_scalar_type(k) != _type_utils.JitScalarType.INT64: + k = g.op("Cast", k, to_i=_C_onnx.TensorProtoDataType.INT64) + if g.opset <= 10: + if not largest: + _unimplemented("TopK", "Ascending is not supported") + return g.op("TopK", input, k, axis_i=dim, outputs=2) + else: + return g.op( + "TopK", input, k, axis_i=dim, largest_i=largest, sorted_i=sorted, outputs=2 + ) + + +def _lt_helper(g: jit_utils.GraphContext, input, other): + if g.opset <= 8: + from torch.onnx._internal.torchscript_exporter.symbolic_opset8 import lt as _lt8 + + return _lt8(g, input, other) + else: + from torch.onnx._internal.torchscript_exporter.symbolic_opset9 import lt as _lt9 + + return _lt9(g, input, other) + + +def _interpolate_warning(interpolate_mode): + onnx_op = ( + "onnx:Resize" if GLOBALS.export_onnx_opset_version >= 10 else "onnx:Upsample" + ) + warnings.warn( + "You are trying to export the model with " + + onnx_op + + " for ONNX opset version " + "" + str(GLOBALS.export_onnx_opset_version) + ". " + "This operator might cause results to not match the expected results by PyTorch.\n" + "ONNX's Upsample/Resize operator did not match Pytorch's Interpolation until opset 11. " + "Attributes to determine how to transform the input were added in onnx:Resize in opset 11 " + "to support Pytorch's behavior (like coordinate_transformation_mode and nearest_mode).\n" + "We recommend using opset 11 and above for models using this operator." 
+    )
+
+
+def _unsqueeze_helper(g: jit_utils.GraphContext, input, axes_i):
+    if len(axes_i) == 0:
+        # Unnecessary unsqueeze if axes length == 0
+        return input
+    elif _is_constant(axes_i[0]):
+        if g.opset >= 13:
+            axes = g.op("Constant", value_t=torch.tensor(axes_i, dtype=torch.long))
+            return g.op("Unsqueeze", input, axes)
+        return g.op("Unsqueeze", input, axes_i=axes_i)
+    # Tensor type
+    if g.opset < 13:
+        raise errors.SymbolicValueError(
+            "Opset version must be >= 13 for Unsqueeze with dynamic axes.", input
+        )
+    return g.op("Unsqueeze", input, axes_i[0])
+
+
+def _squeeze_helper(g: jit_utils.GraphContext, input, axes_i):
+    if _is_constant(axes_i[0]):
+        if g.opset >= 13:
+            axes = g.op("Constant", value_t=torch.tensor(axes_i, dtype=torch.long))
+            return g.op("Squeeze", input, axes)
+        return g.op("Squeeze", input, axes_i=axes_i)
+    # Tensor type
+    if g.opset < 13:
+        raise errors.SymbolicValueError(
+            "Opset version must be >= 13 for Squeeze with dynamic axes.", input
+        )
+    axes_t = axes_i[0]
+    axes_rank = _get_tensor_rank(axes_t)
+    assert axes_rank is not None
+    if axes_rank > 1:
+        raise errors.SymbolicValueError(
+            "For Squeeze with axes as input, the axes rank must be one in the ONNX spec.", input
+        )
+    elif axes_rank == 0:
+        # The axes input is a scalar. Unsqueeze it to a rank 1 tensor.
+        axes_t = _unsqueeze_helper(g, axes_t, [0])
+        return g.op("Squeeze", input, axes_t)
+    return g.op("Squeeze", input, axes_t)
+
+
+def _reducesum_helper(
+    g: jit_utils.GraphContext,
+    input,
+    axes_i=None,
+    keepdims_i=1,
+    noop_with_empty_axes_i=0,
+):
+    keepdims_i = _maybe_get_const(keepdims_i, "i")
+    if g.opset >= 13:
+        if axes_i:
+            if not _is_value(axes_i):
+                axes_i = g.op(
+                    "Constant", value_t=torch.tensor(axes_i, dtype=torch.long)
+                )
+            return g.op(
+                "ReduceSum",
+                input,
+                axes_i,
+                keepdims_i=keepdims_i,
+                noop_with_empty_axes_i=noop_with_empty_axes_i,
+            )
+        return g.op(
+            "ReduceSum",
+            input,
+            keepdims_i=keepdims_i,
+            noop_with_empty_axes_i=noop_with_empty_axes_i,
+        )
+    else:
+        return g.op("ReduceSum", input, axes_i=axes_i, keepdims_i=keepdims_i)
+
+
+def _interpolate_size_to_scales(g: jit_utils.GraphContext, input, output_size, dim):
+    output_size = _maybe_get_const(output_size, "is")
+    if _is_value(output_size):
+        offset = 2
+        offsets = g.op("Constant", value_t=torch.ones(offset, dtype=torch.float32))
+        dividend = g.op("Cast", output_size, to_i=_C_onnx.TensorProtoDataType.FLOAT)
+        divisor = _slice_helper(
+            g, g.op("Shape", input), axes=[0], ends=[sys.maxsize], starts=[offset]
+        )
+        divisor = g.op("Cast", divisor, to_i=_C_onnx.TensorProtoDataType.FLOAT)
+        scale_dims = g.op("Div", dividend, divisor)
+        scales = g.op("Concat", offsets, scale_dims, axis_i=0)
+    else:
+        scales_constant = [
+            1.0
+            if i < 2
+            else float(output_size[-(dim - i)])
+            / float(input.type().sizes()[-(dim - i)])
+            for i in range(0, dim)
+        ]
+        scales = g.op(
+            "Constant", value_t=torch.tensor(scales_constant, dtype=torch.float32)
+        )
+    return scales
+
+
+def _interpolate_get_scales_if_available(g: jit_utils.GraphContext, scales):
+    available_scales = _maybe_get_const(scales[0], "fs") != -1 and not _is_none(
+        scales[0]
+    )
+
+    if not available_scales:
+        return None
+
+    offsets = g.op("Constant", value_t=torch.ones(2, dtype=torch.float32))
+    scales_list = g.op(
+        "Constant", value_t=torch.tensor(_maybe_get_const(scales[0], "fs"))
+    )
+    scales = g.op("Concat", offsets, scales_list, axis_i=0)
+    return scales
+
+
+def _get_interpolate_attributes(g: jit_utils.GraphContext, mode, args):
+    if mode == "nearest":
+        align_corners = None
+        scales = 
args[0:] + else: + align_corners = args[0] + scales = args[1:] + scales = _interpolate_get_scales_if_available(g, scales) + return scales, align_corners + + +def _interpolate_get_scales(g: jit_utils.GraphContext, scale_factor, dim): + offsets = g.op("Constant", value_t=torch.ones(2, dtype=torch.float32)) + scale_factor_rank = _get_tensor_rank(scale_factor) + if isinstance(scale_factor.type(), _C.ListType) or ( + scale_factor_rank is not None and scale_factor_rank > 0 + ): + return g.op("Concat", offsets, scale_factor, axis_i=0) + else: + scale_factor = _unsqueeze_helper(g, scale_factor, [0]) + scale_factor = g.op( + "Cast", scale_factor, to_i=_C_onnx.TensorProtoDataType.FLOAT + ) + scales = [scale_factor for i in range(dim - 2)] + scale_factor = g.op("Concat", offsets, *scales, axis_i=0) + return scale_factor + + +def _interpolate_get_scales_and_mode( + g: jit_utils.GraphContext, input, size, scale_factor, mode, align_corners +): + mode = _maybe_get_const(mode, "s") + if "linear" in mode: + mode = "linear" + if "cubic" in mode: + mode = "cubic" + _interpolate_warning(mode) + + align_corners = _maybe_get_const(align_corners, "b") + if isinstance(align_corners, bool) and align_corners: + return _unimplemented("interpolate", "align_corners == True") + + if not input.type().dim(): + return _unimplemented("interpolate", "missing input shape") + dim = input.type().dim() + + if not _is_none(scale_factor): + scale_factor = _interpolate_get_scales(g, scale_factor, dim) + elif not _is_none(size): + if not _is_packed_list(size): + is_scalar = _maybe_get_const(size, "t").dim() == 0 + if is_scalar: + size = _unsqueeze_helper(g, size, [0]) + size = [size for i in range(dim - 2)] + size = g.op("Concat", *size, axis_i=0) + scale_factor = _interpolate_size_to_scales(g, input, size, dim) + else: + return _unimplemented( + "interpolate", "Both size and scales are None in __interpolate" + ) + return scale_factor, mode + + +def _argmin_argmax_helper( + g: jit_utils.GraphContext, + input: torch._C.Value, + dim: torch._C.Value, + keepdim: bool, + op_name: str, +): + def op_wrapper(input, axis_i, keepdims_i): + if g.opset >= 12: + return g.op( + op_name, + input, + axis_i=axis_i, + keepdims_i=keepdims_i, + select_last_index_i=False, + ) + return g.op(op_name, input, axis_i=axis_i, keepdims_i=keepdims_i) + + if _is_none(dim): + flattened = _reshape_helper( + g, input, g.op("Constant", value_t=torch.tensor([-1])) + ) + output = op_wrapper(flattened, axis_i=0, keepdims_i=False) + if keepdim: + input_shape = g.op("Shape", input) + input_shape_shape = g.op("Shape", input_shape) + new_shape = g.op( + "ConstantOfShape", + input_shape_shape, + value_t=torch.tensor([1], dtype=torch.int64), + ) + output = g.op("Reshape", output, new_shape) + return output + + dim = _parse_arg(dim, "i") + return op_wrapper(input, axis_i=dim, keepdims_i=keepdim) + + +def _interpolate_helper(name, dim, interpolate_mode): + @quantized_args(True, False, False) + def symbolic_fn(g, input, output_size, *args): + scales, align_corners = _get_interpolate_attributes(g, interpolate_mode, args) + align_corners = _maybe_get_scalar(align_corners) + coordinate_transformation_mode = ( + "asymmetric" + if interpolate_mode == "nearest" + else "align_corners" + if align_corners + else "half_pixel" + ) + + if scales is None: + input_size = g.op("Shape", input) + input_size_beg = _slice_helper( + g, input_size, axes=[0], ends=[2], starts=[0] + ) + output_size = g.op( + "Cast", output_size, to_i=_C_onnx.TensorProtoDataType.INT64 + ) + output_size = 
g.op("Concat", input_size_beg, output_size, axis_i=0) + + if g.opset >= 13: + empty_roi = _optional_input_placeholder_tensor(g) + empty_scales = _optional_input_placeholder_tensor(g) + else: + empty_roi = g.op( + "Constant", value_t=torch.tensor([], dtype=torch.float32) + ) + empty_scales = g.op( + "Constant", value_t=torch.tensor([], dtype=torch.float32) + ) + + return g.op( + "Resize", + input, + empty_roi, + empty_scales, + output_size, + coordinate_transformation_mode_s=coordinate_transformation_mode, + cubic_coeff_a_f=-0.75, # only valid when mode="cubic" + mode_s=interpolate_mode, # nearest, linear, or cubic + nearest_mode_s="floor", + ) # only valid when mode="nearest" + else: + if g.opset >= 13: + empty_roi = _optional_input_placeholder_tensor(g) + else: + empty_roi = g.op( + "Constant", value_t=torch.tensor([], dtype=torch.float32) + ) + + return g.op( + "Resize", + input, + empty_roi, + scales, + coordinate_transformation_mode_s=coordinate_transformation_mode, + cubic_coeff_a_f=-0.75, # only valid when mode="cubic" + mode_s=interpolate_mode, # nearest, linear, or cubic + nearest_mode_s="floor", + ) # only valid when mode="nearest" + + return symbolic_fn + + +def __interpolate_helper( + g: jit_utils.GraphContext, + input, + size, + scale_factor, + mode, + align_corners, + recompute_scale_factor, +): + mode = _maybe_get_const(mode, "s") + if "linear" in mode: + mode = "linear" + if "cubic" in mode: + mode = "cubic" + align_corners = _maybe_get_const(align_corners, "b") + align_corners = False if not isinstance(align_corners, bool) else align_corners + coordinate_transformation_mode = ( + "asymmetric" + if mode == "nearest" + else "align_corners" + if align_corners + else "half_pixel" + ) + + if not _is_none(size): + input_size = g.op("Shape", input) + input_size = _slice_helper(g, input_size, axes=[0], ends=[2], starts=[0]) + # in some cases size is not a packed list but size is a scalar + # We need to also verify that (_maybe_get_const(size, "t").dim() == 0) + # but this information is not always available. Try to get the dim, + # and if not assume that it is not a scalar. + try: + is_scalar = not _is_packed_list(size) and ( + _maybe_get_const(size, "t").dim() == 0 + ) + except AttributeError: + is_scalar = not _is_packed_list(size) + if not is_scalar: + warnings.warn( + "Cannot verify if the output_size is a scalar " + "while exporting interpolate. Assuming that it is not a scalar." 
+ ) + + if is_scalar: + rank = _get_tensor_rank(input) + if rank is None: + return _unimplemented( + "interpolate (with a scalar output_size)", + "missing input shape (try giving an array of output_size values)", + ) + size = _unsqueeze_helper(g, size, [0]) + size = [size for i in range(rank - 2)] + size = g.op("Concat", *size, axis_i=0) + size = g.op("Cast", size, to_i=_C_onnx.TensorProtoDataType.INT64) + size = g.op("Concat", input_size, size, axis_i=0) + + if g.opset >= 13: + empty_roi = _optional_input_placeholder_tensor(g) + empty_scales = _optional_input_placeholder_tensor(g) + else: + empty_roi = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32)) + empty_scales = g.op( + "Constant", value_t=torch.tensor([], dtype=torch.float32) + ) + + return g.op( + "Resize", + input, + empty_roi, + empty_scales, + size, + coordinate_transformation_mode_s=coordinate_transformation_mode, + cubic_coeff_a_f=-0.75, # only valid when mode="cubic" + mode_s=mode, # nearest, linear, or cubic + nearest_mode_s="floor", + ) + else: # if not _is_none(scales) + rank = _get_tensor_rank(input) + if rank is None: + return _unimplemented("interpolate (with scales)", "missing input shape") + + if g.opset >= 13: + empty_roi = _optional_input_placeholder_tensor(g) + else: + empty_roi = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32)) + + scales = _interpolate_get_scales(g, scale_factor, rank) + return g.op( + "Resize", + input, + empty_roi, + scales, + coordinate_transformation_mode_s=coordinate_transformation_mode, + cubic_coeff_a_f=-0.75, # only valid when mode="cubic" + mode_s=mode, # nearest, linear, or cubic + nearest_mode_s="floor", + ) # only valid when mode="nearest" + + +def _unbind_helper(g: jit_utils.GraphContext, self, dim, _outputs): + if g.opset < 11: + from torch.onnx._internal.torchscript_exporter.symbolic_opset9 import unbind + elif g.opset <= 12: + from torch.onnx._internal.torchscript_exporter.symbolic_opset11 import ( + unbind, # type: ignore[no-redef] + ) + else: + from torch.onnx._internal.torchscript_exporter.symbolic_opset13 import ( + unbind, # type: ignore[no-redef] + ) + return unbind(g, self, dim, _outputs) + + +def _scatter_helper(g: jit_utils.GraphContext, self, dim, index, src): + if g.opset <= 10: + from torch.onnx._internal.torchscript_exporter.symbolic_opset9 import scatter + else: + # for mypy, scatter was imported two lines above + from torch.onnx._internal.torchscript_exporter.symbolic_opset11 import ( + scatter, # type: ignore[no-redef] + ) + return scatter(g, self, dim, index, src) + + +def _repeat_interleave_split_helper(g: jit_utils.GraphContext, self, reps, dim): + if g.opset <= 12: + split_out = g.op("Split", self, split_i=[1] * reps, axis_i=dim, outputs=reps) + else: + from torch.onnx._internal.torchscript_exporter.symbolic_opset13 import split + + repeats = g.op("Constant", value_t=torch.tensor([1] * reps)) + split_out = split(g, self, repeats, dim, _outputs=reps) + return split_out if reps > 1 else [split_out] + + +def _repeat_interleave_single_value_repeat_helper( + g: jit_utils.GraphContext, self, repeats, dim +): + from torch.onnx._internal.torchscript_exporter.symbolic_opset9 import ( + flatten, + unsqueeze, + ) + + if not _is_tensor(repeats): + repeats = g.op("Constant", value_t=torch.LongTensor(repeats)) + + const_repeats: bool = _is_constant(repeats) + reps = _maybe_get_const(repeats, "t") + + # Convert 'repeats' to 1-d if it is 0-d. 
+ if _get_tensor_rank(repeats) == 0: + repeats = g.op("Reshape", repeats, g.op("Constant", value_t=torch.tensor([1]))) + + # Create a new dim of size 1, then expand it to be 'repeats' long, and finally collapse it. + unsqueezed = unsqueeze(g, self, dim + 1) + + # repeats_per_dim is 1 for all dims except for the new unsqueezed dim, where it has value 'repeats'. + if const_repeats: + # 'Repeats' is a constant, 'repeats_per_dim' can be a constant. + onehot = torch.ones(_get_tensor_rank(unsqueezed), dtype=torch.int64) # type: ignore[arg-type] + onehot[dim + 1] = reps + repeats_per_dim = g.op("Constant", value_t=onehot) + else: + # 'Repeats' is a variable, 'repeats_per_dim' cannot be a constant. + onehot = g.op( + "OneHot", + unsqueeze(g, dim + 1, 0), # indices, must be >= 1-dimensional + g.op( + "Constant", value_t=torch.tensor(_get_tensor_rank(unsqueezed)) + ), # depth + g.op( + "Concat", g.op("Constant", value_t=torch.tensor([1])), repeats, axis_i=0 + ), # on/off values + ) + repeats_per_dim = flatten(g, onehot, 0, 1) + + tiled = g.op("Tile", unsqueezed, repeats_per_dim) + return flatten(g, tiled, dim, dim + 1) + + +def _arange_cast_helper( + g: jit_utils.GraphContext, end, start=None, step=None, dtype=None +) -> tuple[ + _type_utils.JitScalarType, + _C.Value | None, + _C.Value | None, + _C.Value | None, +]: + def _is_all_integral(scalars): + for scalar in scalars: + scalar_type = _type_utils.JitScalarType.from_value( + scalar, _type_utils.JitScalarType.UNDEFINED + ) + if ( + scalar_type != _type_utils.JitScalarType.INT64 + and scalar_type != _type_utils.JitScalarType.UNDEFINED + ): + return False + return True + + # This logic is based on torch.arange docs. If "dtype" is provided, + # infer input types from dtype. If not, then check if any of start, stop, + # or step are floating point, and infer the type from get_default. + # Otherwise, the dtype is inferred to be torch.int64. + if dtype is None or (_is_value(dtype) and _is_none(dtype)): + if _is_all_integral([start, end, step]): + scalar_type = _type_utils.JitScalarType.INT64 + else: + scalar_type = _type_utils.JitScalarType.from_dtype( + torch.get_default_dtype() + ) + else: + assert isinstance(dtype, int) + # TODO(justinchuby): Check if dtype is indeed a int. + scalar_type = _type_utils.JitScalarType(dtype) + + start = g.op("Cast", start, to_i=scalar_type.onnx_type()) if start else None + end = g.op("Cast", end, to_i=scalar_type.onnx_type()) if end else None + step = g.op("Cast", step, to_i=scalar_type.onnx_type()) if step else None + return scalar_type, end, start, step + + +def _arange_helper(g: jit_utils.GraphContext, *args): + if g.opset <= 10: + from torch.onnx._internal.torchscript_exporter.symbolic_opset9 import arange + else: + from torch.onnx._internal.torchscript_exporter.symbolic_opset11 import ( + arange, # type: ignore[no-redef] + ) + return arange(g, *args) + + +def _size_helper(g: jit_utils.GraphContext, self, dim): + full_shape = g.op("Shape", self) + from torch.onnx._internal.torchscript_exporter.symbolic_opset9 import select + + return select(g, full_shape, g.op("Constant", value_t=torch.tensor([0])), dim) + + +def _index_fill_reshape_helper(g: jit_utils.GraphContext, self, dim, index): + # 1. reshape index => [1, ..., 1, dim, 1, ..., 1] + # 2. expand index => [..., dim, ...], same shape as self except for dim. + # 3. expand value as well. + # 4. apply onnx::scatter. 
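+    # Illustrative example (hypothetical shapes): for self of shape [2, 3, 4],
+    # dim = 1 and index of shape [k], index is unsqueezed to [1, k, 1] and then
+    # expanded to [2, k, 4] so that it lines up with self for the scatter.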
+ + from torch.onnx._internal.torchscript_exporter.symbolic_opset9 import expand + + if g.opset <= 10: + from torch.onnx._internal.torchscript_exporter.symbolic_opset9 import scatter + else: + # for mypy, scatter was imported two lines above + from torch.onnx._internal.torchscript_exporter.symbolic_opset11 import ( + scatter, # type: ignore[no-redef] + ) + + if self.type().dim() is None: + return _unimplemented("index_fill", "input rank not accessible") + self_dim = self.type().dim() + dim_value = _parse_arg(dim, "i") + if dim_value < 0: + dim_value += self_dim + unsqueezed_index = _unsqueeze_helper( + g, index, [i for i in range(self_dim) if i != dim_value] + ) + expanded_index_shape = scatter( + g, g.op("Shape", self), 0, _unsqueeze_helper(g, dim, [0]), g.op("Shape", index) + ) + expanded_index = expand(g, unsqueezed_index, expanded_index_shape, None) + return expanded_index_shape, expanded_index + + +# By default, when any value in the 'shape' input is equal to zero +# the corresponding dimension value is copied from the input tensor dynamically. +# allowzero=1 indicates that if any value in the 'shape' input is set to zero, +# the zero value is honored, similar to NumPy. +# allowzero=1 is only supported for opset version >= 14. +def _reshape_helper(g: jit_utils.GraphContext, input, shape, allowzero=0): + shape = _maybe_get_const(shape, "is") + if not _is_value(shape): + shape = g.op("Constant", value_t=torch.LongTensor(shape)) + if g.opset <= 13: + if allowzero == 1: + _onnx_opset_unsupported( + "Reshape with allowzero=1", GLOBALS.export_onnx_opset_version, 14, input + ) + return g.op("Reshape", input, shape) + else: + return g.op("Reshape", input, shape, allowzero_i=allowzero) + + +def _batchnorm_helper( + g: jit_utils.GraphContext, input, weight, bias, running_mean, running_var +): + from torch.onnx._internal.torchscript_exporter.symbolic_opset9 import _var_mean + + batch_size = _get_tensor_dim_size(input, 0) + channel_size = _get_tensor_dim_size(input, 1) + + if weight is None or _is_none(weight): + if channel_size is None: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of batch_norm for unknown channel size.", + input, + ) + weight_value = torch.tensor( + [1.0] * channel_size, + dtype=_type_utils.JitScalarType.from_value(input).dtype(), + ) + weight = g.op("Constant", value_t=weight_value) + if bias is None or _is_none(bias): + if channel_size is None: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of batch_norm for unknown channel size.", + input, + ) + bias_value = torch.tensor( + [0.0] * channel_size, + dtype=_type_utils.JitScalarType.from_value(input).dtype(), + ) + bias = g.op("Constant", value_t=bias_value) + # If track_running_stats is set to False batch statistics are instead used during evaluation time + if ( + running_mean is None + or _is_none(running_mean) + or running_var is None + or _is_none(running_var) + ): + assert batch_size is not None and channel_size is not None + reshape_in = _reshape_helper( + g, + input, + g.op( + "Constant", + value_t=torch.tensor([batch_size, channel_size, -1], dtype=torch.int64), + ), + ) + trans_in = g.op("Transpose", reshape_in, perm_i=[0, 2, 1]) + running_var, running_mean = _var_mean( + g, + trans_in, + g.op("Constant", value_t=torch.tensor([0, 1], dtype=torch.int64)), + False, + False, + ) + return weight, bias, running_mean, running_var + + +def _avgpool_helper( + tuple_fn: Callable[[Any], Sequence[int]], + padding: int | Sequence[int], + kernel_size, + stride, + divisor_override, + name, +) -> 
tuple[int, ...]: + if divisor_override and divisor_override.node().kind() != "prim::Constant": + _unimplemented(name, "divisor_override") + return tuple(tuple_fn(padding)) + + +def check_training_mode(op_train_mode: int, op_name: str) -> None: + """Warns the user if the model's training mode and the export mode do not agree.""" + if GLOBALS.training_mode == _C_onnx.TrainingMode.PRESERVE: + return + + if op_train_mode: + op_mode_enum = _C_onnx.TrainingMode.TRAINING + else: + op_mode_enum = _C_onnx.TrainingMode.EVAL + if op_mode_enum == GLOBALS.training_mode: + # The modes agree. Do nothing + return + + op_mode_text = f"train={bool(op_train_mode)}" + # Setting the model mode could result in op_mode != GLOBALS.training_mode + # if the model is a FuncModule. In this case we warn the user of + # the state and export depending on op_mode + # This is to support use-cases of fixing certain layer weights + # in training. + warnings.warn( + f"ONNX export mode is set to {GLOBALS.training_mode}, but operator '{op_name}' " + f"is set to {op_mode_text}. Exporting with {op_mode_text}." + ) + + +def _flatten_helper(g: jit_utils.GraphContext, input, start_dim, end_dim, dim): + input_size = g.op("Shape", input) + slice1 = _slice_helper(g, input_size, axes=[0], starts=[0], ends=[start_dim]) + slices = [slice1, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long))] + if end_dim < dim - 1: + slice3 = _slice_helper( + g, input_size, axes=[0], starts=[end_dim + 1], ends=[dim] + ) + slices = [ + slice1, + g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)), + slice3, + ] + + final_shape = g.op("Concat", *slices, axis_i=0) + from torch.onnx._internal.torchscript_exporter.symbolic_opset9 import ( + _reshape_from_tensor, + ) + + return _reshape_from_tensor(g, input, final_shape) + + +def _is_split_static(split_size_or_sizes, _outputs): + if _outputs is None: + return False + if ( + _is_value(split_size_or_sizes) + and split_size_or_sizes.node().kind() != "onnx::Constant" + ): + return False + return True + + +def _optional_input_placeholder_tensor(g): + n = g.op("prim::Constant") + n.setType(_C.OptionalType.ofTensor()) + return n + + +def _handle_reduce_dim_none(g: jit_utils.GraphContext, self, op_name): + rank = _get_tensor_rank(self) + if rank is not None and any( + _get_tensor_dim_size(self, i) == 0 for i in range(rank) + ): + # If input tensor is empty, according to ONNX ReduceSum definition, + # set keepdims=1 so that the resulted tensor has the same rank as the input. + return g.op(op_name, self, keepdims_i=1) + return g.op(op_name, self, keepdims_i=0) + + +def dequantize_helper( + g: jit_utils.GraphContext, + qtensor: _C.Value, + qdtype: _C_onnx.TensorProtoDataType | None = None, +) -> tuple[_C.Value, _C.Value, _C.Value, _C.Value | None]: + """Appends to graph `g` ONNX nodes that dequantizes `qtensor` into `tensor`. + + Args: + g: Graph, the ONNX IR graph that is under construction. + qtensor: torch._C.Value, either a tuple of (quantized_tensor, scale, zero_point) + for per tensor quantization, or + (quantized_tensor, scale, zero_point, axis) for per channel quantization, + representing the quantized tensor. + qdtype: torch.onnx.TensorProtoDataType default None, if not None, represents the + data type of quantized tensor. It must be either + torch.onnx.TensorProtoDataType.UINT8 or torch.onnx.TensorProtoDataType.INT8. 
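+
+    Returns:
+        A tuple of (dequantized tensor, scale, zero_point, axis), where axis is
+        None for per tensor quantization.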
+ """ + unpacked_qtensors = _unpack_quantized_tensor(qtensor) + tensor, scale, zero_point = unpacked_qtensors[:3] + axis = unpacked_qtensors[3] if len(unpacked_qtensors) >= 4 else None + axis_i = _get_const(axis, "i", "axis") + input_qdtype = _type_utils.JitScalarType.from_value(tensor) + if qdtype is None: + if input_qdtype is not None: + qdtype = input_qdtype.onnx_type() + else: + qdtype = _C_onnx.TensorProtoDataType.UINT8 + value = g.op("Cast", tensor, to_i=qdtype) + scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT) + zero_point = g.op("Cast", zero_point, to_i=qdtype) + + if axis_i is not None and GLOBALS.export_onnx_opset_version < 13: + _onnx_opset_unsupported_detailed( + "DequantizeLinear", + GLOBALS.export_onnx_opset_version, + 13, + "Attribute axis is not supported.", + qtensor, + ) + + return ( + g.op("DequantizeLinear", value, scale, zero_point, axis_i=axis_i), + scale, + zero_point, + axis, + ) + + +def quantize_helper( + g: jit_utils.GraphContext, + tensor: _C.Value, + scale: _C.Value, + zero_point: _C.Value, + axis: _C.Value | None = None, +) -> _C.Value: + """Appends to graph `g` ONNX nodes that quantizes `tensor` based on `scale`, `zero_point` and `axis`. + + Args: + g: Graph, the ONNX IR graph that is under construction. + tensor: torch._C.Value, representing the tensor to be quantized. + scale: torch._C.Value, quantized scale. + zero_point: torch._C.Value, quantized zero point. + axis: Optional[torch._C.Value] default None, if None, represents per tensor quantization. + Otherwise, represents per channel quantization, along given axis. + + Returns: + A TupleConstruct storing information of the quantized tensor. + """ + if ( + axis is not None + and not _is_none(axis) + and GLOBALS.export_onnx_opset_version < 13 + ): + _onnx_opset_unsupported_detailed( + "QuantizeLinear", + GLOBALS.export_onnx_opset_version, + 13, + "Attribute axis is not supported.", + tensor, + ) + + assert scale is not None + if ( + _type_utils.JitScalarType.from_value(scale, _type_utils.JitScalarType.UNDEFINED) + != _type_utils.JitScalarType.FLOAT + ): + scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT) + + assert zero_point is not None + if _type_utils.JitScalarType.from_value( + zero_point, _type_utils.JitScalarType.UNDEFINED + ) not in { + _type_utils.JitScalarType.UINT8, + _type_utils.JitScalarType.INT8, + }: + zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8) + output = g.op( + "QuantizeLinear", + tensor, + scale, + zero_point, + axis_i=_get_const(axis, "i", "axis"), + ) + args = [output, scale, zero_point] + if axis is not None and not _is_none(axis): + args.append(axis) + return g.op("prim::TupleConstruct", *args) + + +def requantize_bias_helper( + g: jit_utils.GraphContext, bias, input_scale, weight_scale, axis=None +): + """In PyTorch, bias is float and is quantized to int32 implicitly inside the quantized ATen op kernel. + In ONNX we need to make the quantization explicit because operators expect all of their inputs to be quantized. + Since int32 is not a supported output type by ONNX operator `QuantizeLinear`, quantization is exported using + regular operators. 
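+
+    The returned prim::TupleConstruct packs (quantized_bias, bias_scale, bias_zero_point)
+    plus, for per channel quantization, the axis.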
+    """
+    bias_scale = g.op("Mul", weight_scale, input_scale)
+    bias_scale_shape = g.op("Shape", bias_scale)
+    bias_zero_point = g.op(
+        "ConstantOfShape", bias_scale_shape, value_t=torch.tensor([0], dtype=torch.int)
+    )
+    q_bias = g.op(
+        "Cast", g.op("Div", bias, bias_scale), to_i=_C_onnx.TensorProtoDataType.INT32
+    )
+    axis_args = []
+    if axis is not None and not _is_none(axis):
+        axis_args.append(axis)
+    return g.op("prim::TupleConstruct", q_bias, bias_scale, bias_zero_point, *axis_args)
+
+
+def args_have_same_dtype(args):
+    assert args
+    base_dtype = _type_utils.JitScalarType.from_value(args[0])
+    has_same_dtype = all(
+        _type_utils.JitScalarType.from_value(elem) == base_dtype for elem in args
+    )
+    return has_same_dtype
+
+
+def _op_with_optional_float_cast(g: jit_utils.GraphContext, op_name, *args, **kwargs):
+    """Some PyTorch operators (e.g., Clip/Min/ReLU/Pad) support a superset of the data types allowed by ONNX.
+    This function maximizes exportability from PyTorch to ONNX by casting inputs whose data types
+    ONNX does not support. For example, `Cast(Clip(Cast(INPUT)))` can be used to mimic
+    `Clip(INPUT)` (opset version < 12).
+
+    Args:
+        g (torch._C.Graph): graph to write the ONNX representation into.
+        op_name (str): operator name in ONNX.
+        *args (tuple): operands to the operator.
+        **kwargs (dict): attributes to the operator along with "opset_before" (optional, None by default)
+            indicating the smallest opset version to trigger such casting behavior and "target_float_t"
+            (optional, torch.onnx.JitScalarType.FLOAT by default) indicating the data type used by the internal operator.
+
+    Returns:
+        torch._C.Value or Tuple[torch._C.Value, ...]: output(s) of the operator.
+    """
+    opset_before = kwargs.pop("opset_before", None)
+    target_float_t = kwargs.pop("target_float_t", _type_utils.JitScalarType.FLOAT)
+
+    inputs = list(args)
+    dtype_0 = _type_utils.JitScalarType.from_value(inputs[0])
+
+    require_cast = not _is_fp(inputs[0]) and (
+        opset_before is None or GLOBALS.export_onnx_opset_version < opset_before
+    )
+
+    if require_cast:
+        for input in inputs:
+            if input.isCompleteTensor():
+                input_scalar_type = _type_utils.JitScalarType.from_value(input)
+                if input_scalar_type != dtype_0:
+                    raise errors.SymbolicValueError(
+                        f"Inputs of {op_name} must have the same dtype. "
+ f"Got {dtype_0.scalar_name()} and {input_scalar_type.scalar_name()}", + input, + ) + for i, input in enumerate(inputs): + if input.isCompleteTensor() and not _is_fp(input): + inputs[i] = g.op( + "Cast", + input, + to_i=target_float_t.onnx_type(), + ) + + self = g.op(op_name, *inputs, **kwargs) + + if require_cast: + self = g.op("Cast", self, to_i=dtype_0.onnx_type()) + + return self + + +def _maybe_cast_reduce_op_input(g: jit_utils.GraphContext, self): + scalar_type = _type_utils.JitScalarType.from_value( + self, _type_utils.JitScalarType.UNDEFINED + ) + if scalar_type != _type_utils.JitScalarType.UNDEFINED: + # This check only covers traced modules where dtype is present + # pytorch reduce-ops cast all other integral types to int64 + if not _is_fp(self) and scalar_type != _type_utils.JitScalarType.INT64: + self = g.op("Cast", self, to_i=_C_onnx.TensorProtoDataType.INT64) + return self + + +def _apply_params(*args, **kwargs): + """Returns a decorator that calls the decorated (higher-order) function with the given parameters.""" + + def _apply(fn): + return fn(*args, **kwargs) + + return _apply + + +def _reduce_op_symbolic_helper(onnx_op_name, allow_multi_dim_support=True): + def symbolic(g, self, dim=None, keepdim=None): + self = _maybe_cast_reduce_op_input(g, self) + if dim is None or dim == (): + # Dim can be 0, which will cause (not dim) == True. So we don't want to do + # (not dim) + # all-reduce path + return _handle_reduce_dim_none(g, self, onnx_op_name) + else: + # dim-reduce path + keepdim = _get_const(keepdim, "i", "keepdim") + if g.opset < 18: + desc = "is" if allow_multi_dim_support else "i" + dim = _get_const(dim, desc, "dim") + dim_list = dim if allow_multi_dim_support else [dim] + return g.op(onnx_op_name, self, axes_i=dim_list, keepdims_i=keepdim) + else: + if _is_value(dim): + axes = dim + else: + if allow_multi_dim_support: + axes = g.op( + "Constant", value_t=torch.tensor(dim, dtype=torch.long) + ) + else: + axes = g.op( + "Constant", value_t=torch.tensor([dim], dtype=torch.long) + ) + return g.op(onnx_op_name, self, axes, keepdims_i=keepdim) + + return symbolic + + +def _overload_by_arg_count(fn): + @functools.wraps(fn) + def wrapper(g, *args): + overloads = fn(g, *args) + for overload in overloads: + arg_descriptors = overload._arg_descriptors + if len(arg_descriptors) == len(args): + return overload(g, *args) + return _unimplemented(f"aten::{fn.__name__}", f"with {len(args)} arguments") + + return wrapper + + +def _reduce_with_dtype_helper( + onnx_op: str, name: str, allow_multi_dim_support: bool = True +): + symbolic = _reduce_op_symbolic_helper( + onnx_op, allow_multi_dim_support=allow_multi_dim_support + ) + + @_overload_by_arg_count + def reduce(g, *args, **kwargs): + @quantized_args(True) + @parse_args("v", "none") + def reduce_nodim(g, self, dtype): + dtype_onnx = None + if dtype.node().kind() == "onnx::Constant": + dtype = _get_const(dtype, "i", "dtype") + dtype_onnx = _type_utils.JitScalarType(dtype).onnx_type() + self = g.op("Cast", self, to_i=dtype_onnx) + elif dtype.node().kind() != "prim::Constant": + return _unimplemented(name, "dtype", dtype) + result = symbolic(g, self) + if dtype_onnx is not None: + result_dtype_onnx = _type_utils.JitScalarType.from_value( + result + ).onnx_type() + if result_dtype_onnx != dtype_onnx: + result = g.op("Cast", result, to_i=dtype_onnx) + return result + + dim_desc = "is" if allow_multi_dim_support else "i" + + @quantized_args(True) + @parse_args("v", dim_desc, "i", "none") # type: ignore[arg-type] + def reduce_dim(g, 
self, dim, keepdim, dtype): + dtype_onnx = None + if dtype.node().kind() == "onnx::Constant": + dtype = _get_const(dtype, "i", "dtype") + dtype_onnx = _type_utils.JitScalarType(dtype).onnx_type() + self = g.op("Cast", self, to_i=dtype_onnx) + elif dtype.node().kind() != "prim::Constant": + return _unimplemented(name, "dtype", dtype) + result = symbolic(g, self, dim, keepdim) + if dtype_onnx is not None: + result_dtype_onnx = _type_utils.JitScalarType.from_value( + result + ).onnx_type() + if result_dtype_onnx != dtype_onnx: + result = g.op("Cast", result, to_i=dtype_onnx) + return result + + return reduce_nodim, reduce_dim + + return reduce + + +def _max_helper(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None): + # torch.max(input) + if dim_or_y is None and keepdim is None: + return g.op("ReduceMax", self, keepdims_i=0) + # torch.max(input, other) + if keepdim is None: + return _op_with_optional_float_cast(g, "Max", self, dim_or_y, opset_before=12) + # torch.max(input, dim, keepdim) + else: + keepdim = _get_const(keepdim, "i", "keepdim") + dim = _get_const(dim_or_y, "i", "dim") + if g.opset < 18: + max = g.op("ReduceMax", self, axes_i=[dim], keepdims_i=keepdim) + else: + axes = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long)) + max = g.op("ReduceMax", self, axes, keepdims_i=keepdim) + indices = g.op("ArgMax", self, axis_i=dim, keepdims_i=keepdim) + return max, indices + + +def _min_helper(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None): + # torch.min(input) + if dim_or_y is None and keepdim is None: + return g.op("ReduceMin", self, keepdims_i=0) + # torch.min(input, other) + if keepdim is None: + return _op_with_optional_float_cast(g, "Min", self, dim_or_y, opset_before=12) + # torch.min(input, dim, keepdim) + else: + keepdim = _get_const(keepdim, "i", "keepdim") + dim = _get_const(dim_or_y, "i", "dim") + if g.opset < 18: + min = g.op("ReduceMin", self, axes_i=[dim], keepdims_i=keepdim) + else: + axes = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long)) + min = g.op("ReduceMin", self, axes, keepdims_i=keepdim) + indices = g.op("ArgMin", self, axis_i=dim, keepdims_i=keepdim) + return min, indices + + +def _numel_helper(g: jit_utils.GraphContext, self): + shape = g.op("Shape", self) + return g.op("ReduceProd", shape, keepdims_i=0) + + +@parse_args("v", "is", "i", "i") +def _var_mean_helper(g: jit_utils.GraphContext, input, dim, correction, keepdim): + if g.opset < 18: + if dim is None: + mean = g.op("ReduceMean", input, keepdims_i=0) + t_mean = mean + num_elements = _numel_helper(g, input) + else: + mean = g.op("ReduceMean", input, axes_i=dim, keepdims_i=keepdim) + t_mean = g.op("ReduceMean", input, axes_i=dim, keepdims_i=1) + redudced_dims = g.op("Shape", input) + # dim could contain one or multiple dimensions + redudced_dims = g.op( + "Gather", + redudced_dims, + g.op("Constant", value_t=torch.tensor(dim)), + axis_i=0, + ) + num_elements = g.op("ReduceProd", redudced_dims, keepdims_i=0) + sub_v = g.op("Sub", input, t_mean) + sqr_sub = g.op("Mul", sub_v, sub_v) + keepdim_mean = 0 if dim is None else keepdim + var = g.op("ReduceMean", sqr_sub, axes_i=dim, keepdims_i=keepdim_mean) + # Correct bias in calculating variance, by dividing it over (N - correction) instead on N + if correction is None: + correction = 1 + if correction != 0: + num_elements = g.op( + "Cast", num_elements, to_i=_C_onnx.TensorProtoDataType.FLOAT + ) + one = g.op("Constant", value_t=torch.tensor(correction, dtype=torch.float)) + mul = g.op("Mul", var, num_elements) + 
var = g.op("Div", mul, g.op("Sub", num_elements, one)) + return var, mean + else: + axes = None + if dim is None: + mean = g.op("ReduceMean", input, keepdims_i=0) + t_mean = mean + num_elements = _numel_helper(g, input) + else: + axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long)) + mean = g.op("ReduceMean", input, axes, keepdims_i=keepdim) + t_mean = g.op("ReduceMean", input, axes, keepdims_i=1) + redudced_dims = g.op("Shape", input) + # dim could contain one or multiple dimensions + redudced_dims = g.op( + "Gather", + redudced_dims, + g.op("Constant", value_t=torch.tensor(dim)), + axis_i=0, + ) + num_elements = g.op("ReduceProd", redudced_dims, keepdims_i=0) + sub_v = g.op("Sub", input, t_mean) + sqr_sub = g.op("Mul", sub_v, sub_v) + keepdim_mean = 0 if dim is None else keepdim + if axes is None: + var = g.op("ReduceMean", sqr_sub, keepdims_i=keepdim_mean) + else: + var = g.op("ReduceMean", sqr_sub, axes, keepdims_i=keepdim_mean) + # Correct bias in calculating variance, by dividing it over (N - correction) instead on N + if correction is None: + correction = 1 + if correction != 0: + num_elements = g.op( + "Cast", num_elements, to_i=_C_onnx.TensorProtoDataType.FLOAT + ) + one = g.op("Constant", value_t=torch.tensor(correction, dtype=torch.float)) + mul = g.op("Mul", var, num_elements) + var = g.op("Div", mul, g.op("Sub", num_elements, one)) + return var, mean + + +def _embedding_bag_helper( + g: jit_utils.GraphContext, + embedding_matrix, + indices, + offsets, + scale_grad_by_freq, + mode, + sparse, + per_sample_weights, + include_last_offset, + padding_idx, +): + if scale_grad_by_freq and GLOBALS.export_training: + return _onnx_unsupported( + "embedding_bag with scale_grad_by_freq for training mode" + ) + if padding_idx is not None and padding_idx >= 0: + raise RuntimeError("embedding_bag with padding_idx") + + loop_condition = g.op("Constant", value_t=torch.tensor(1)) + loop_condition = g.op("Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL) + zero = g.op("Constant", value_t=torch.tensor([0])) + + indices_len = _unsqueeze_helper( + g, + _size_helper(g, indices, g.op("Constant", value_t=torch.tensor(0))), + [0], + ) + if not include_last_offset: + offsets = [offsets, indices_len] + offsets = g.op("Concat", *offsets, axis_i=0) + + # Offsets holds the starting index position of each bag. So we create a list of the indices slices (determined by + # offsets) and gather those indices in indices_row. Then we use this subset of indices to gather from embeddings. + # The embeddings output is a loop scan output, so we can avoid creating a sequence and inserting elements in. 
+ offsets_starts = _slice_helper( + g, offsets, axes=[0], starts=[0], ends=[sys.maxsize], steps=[1] + ) + offsets_ends = _slice_helper( + g, offsets, axes=[0], starts=[1], ends=[sys.maxsize], steps=[1] + ) + + loop_len = _size_helper(g, offsets_ends, g.op("Constant", value_t=torch.tensor(0))) + + loop, (loop_context,), _ = jit_utils.add_op_with_blocks( + g, "Loop", loop_len, loop_condition, n_blocks=1 + ) + loop_block = loop_context.block + + # FIXME(justinchuby): We need to handle what happens when we call b.op on a node return + block_input_iter = utils._add_input_to_block(loop_block) + utils._add_input_to_block(loop_block) + + indices_start = loop_context.op( + "Gather", offsets_starts, block_input_iter, axis_i=0 + ) + indices_end = loop_context.op("Gather", offsets_ends, block_input_iter, axis_i=0) + indices_start = _unsqueeze_helper(loop_context, indices_start, [0]) + indices_end = _unsqueeze_helper(loop_context, indices_end, [0]) + + indices_row = loop_context.op("Slice", indices, indices_start, indices_end, zero) + embeddings = loop_context.op("Gather", embedding_matrix, indices_row, axis_i=0) + if not _is_none(per_sample_weights): + per_sample_weights_row = loop_context.op( + "Slice", per_sample_weights, indices_start, indices_end, zero + ) + per_sample_weights_row = _unsqueeze_helper( + loop_context, per_sample_weights_row, [1] + ) + embeddings = loop_context.op("Mul", embeddings, per_sample_weights_row) + if mode == 0: + embeddings = _reducesum_helper( + loop_context, embeddings, axes_i=[0], keepdims_i=0 + ) + elif mode == 1: + if loop_context.opset < 18: + embeddings = loop_context.op( + "ReduceMean", embeddings, axes_i=[0], keepdims_i=0 + ) + else: + axes = loop_context.op( + "Constant", value_t=torch.tensor([0], dtype=torch.long) + ) + embeddings = loop_context.op("ReduceMean", embeddings, axes, keepdims_i=0) + else: + if loop_context.opset < 18: + embeddings = loop_context.op( + "ReduceMax", embeddings, axes_i=[0], keepdims_i=0 + ) + else: + axes = loop_context.op( + "Constant", value_t=torch.tensor([0], dtype=torch.long) + ) + embeddings = loop_context.op("ReduceMax", embeddings, axes, keepdims_i=0) + + cond_out = loop_context.op( + "Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL + ) + utils._add_output_to_block(loop_block, cond_out) + utils._add_output_to_block(loop_block, embeddings) + + # aten::embedding_bag returns a tuple of 4 elements: output, offset2bag, bag_size, max_indices. + # But the last three outputs are not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag. 
+ return loop.node().output(), None, None, None + + +def _linalg_vector_norm_helper( + g: jit_utils.GraphContext, + self: torch._C.Value, + ord: float, + dim: Sequence[int] | None, + keepdim: bool, + dtype: torch._C.Value, +): + axes = None + # Conditions based on https://pytorch.org/docs/stable/generated/torch.linalg.vector_norm.html + if _is_none(dim): + self = _reshape_helper(g, self, [-1]) + keepdim = False + elif g.opset >= 18: + axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long)) + + if ord == math.inf: + if g.opset < 18: + result = g.op( + "ReduceMax", g.op("Abs", self), axes_i=dim, keepdims_i=keepdim + ) + else: + if axes is None: + result = g.op("ReduceMax", g.op("Abs", self), keepdims_i=keepdim) + else: + result = g.op("ReduceMax", g.op("Abs", self), axes, keepdims_i=keepdim) + elif ord == -math.inf: + if g.opset < 18: + result = g.op( + "ReduceMin", g.op("Abs", self), axes_i=dim, keepdims_i=keepdim + ) + else: + if axes is None: + result = g.op("ReduceMin", g.op("Abs", self), keepdims_i=keepdim) + else: + result = g.op("ReduceMin", g.op("Abs", self), axes, keepdims_i=keepdim) + elif ord == 0: + if g.opset < 11: + return _onnx_opset_unsupported_detailed( + "linalg_vector_norm", 9, 11, "ord=0 not supported", self + ) + else: + if dim is None: + self = _reshape_helper( + g, + self, + g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)), + ) + keepdim = False + + cond_op = g.op( + "Not", + g.op("Equal", self, g.op("Constant", value_t=torch.LongTensor([0]))), + ) + cond_op = g.op( + "Cast", + cond_op, + to_i=_type_utils.JitScalarType.from_value(self).onnx_type(), + ) + return _reducesum_helper(g, cond_op, axes_i=dim, keepdims_i=keepdim) + elif ord == 1: + if g.opset < 18: + result = _reduce_op_symbolic_helper("ReduceL1")( + g, self, dim=dim, keepdim=keepdim + ) + else: + if axes is None: + result = _reduce_op_symbolic_helper("ReduceL1")( + g, self, keepdim=keepdim + ) + else: + result = _reduce_op_symbolic_helper("ReduceL1")( + g, self, axes, keepdim=keepdim + ) + elif ord == 2: + if g.opset < 18: + result = _reduce_op_symbolic_helper("ReduceL2")( + g, self, dim=dim, keepdim=keepdim + ) + else: + if axes is None: + result = _reduce_op_symbolic_helper("ReduceL2")( + g, self, keepdim=keepdim + ) + else: + result = _reduce_op_symbolic_helper("ReduceL2")( + g, self, axes, keepdim=keepdim + ) + else: + ord_op = g.op("Constant", value_t=torch.tensor(ord, dtype=torch.float32)) + result = _reducesum_helper( + g, g.op("Pow", g.op("Abs", self), ord_op), axes_i=dim, keepdims_i=keepdim + ) + result = g.op( + "Pow", + result, + g.op( + "Div", + g.op("Constant", value_t=torch.tensor(1, dtype=torch.float32)), + ord_op, + ), + ) + + if not _is_none(dtype): + dtype = _get_const(dtype, "i", "dtype") + result = g.op("Cast", result, to_i=_type_utils.JitScalarType(dtype).onnx_type()) # type: ignore[arg-type] + return result + + +# Deprecated. 
Internally use _type_utils.ScalarType +# TODO: remove these once we support Type's in the JIT IR and we can once again +# use the unified toType operator +cast_pytorch_to_onnx = { + "Byte": _C_onnx.TensorProtoDataType.UINT8, + "Char": _C_onnx.TensorProtoDataType.INT8, + "Double": _C_onnx.TensorProtoDataType.DOUBLE, + "Float": _C_onnx.TensorProtoDataType.FLOAT, + "Half": _C_onnx.TensorProtoDataType.FLOAT16, + "Int": _C_onnx.TensorProtoDataType.INT32, + "Long": _C_onnx.TensorProtoDataType.INT64, + "Short": _C_onnx.TensorProtoDataType.INT16, + "Bool": _C_onnx.TensorProtoDataType.BOOL, + "ComplexFloat": _C_onnx.TensorProtoDataType.COMPLEX64, + "ComplexDouble": _C_onnx.TensorProtoDataType.COMPLEX128, + "BFloat16": _C_onnx.TensorProtoDataType.BFLOAT16, + "Undefined": _C_onnx.TensorProtoDataType.UNDEFINED, +} + +# Deprecated. Internally use _type_utils.ScalarType +scalar_name_to_pytorch = { + "uint8_t": "Byte", + "int8_t": "Char", + "double": "Double", + "float": "Float", + "half": "Half", + "int": "Int", + "int64_t": "Long", + "int16_t": "Short", + "bool": "Bool", + "complex64": "ComplexFloat", + "complex128": "ComplexDouble", + "qint8": "QInt8", + "quint8": "QUInt8", + "qint32": "QInt32", + "bfloat16": "BFloat16", +} + + +# Deprecated. Internally use _type_utils.ScalarType +# This indicates each scalar type's corresponding +# torch type. Related source: +# https://github.com/pytorch/pytorch/blob/344defc9733a45fee8d0c4d3f5530f631e823196/c10/core/ScalarType.h +scalar_type_to_pytorch_type = [ + torch.uint8, # 0 + torch.int8, # 1 + torch.short, # 2 + torch.int, # 3 + torch.int64, # 4 + torch.half, # 5 + torch.float, # 6 + torch.double, # 7 + torch.complex32, # 8 + torch.complex64, # 9 + torch.complex128, # 10 + torch.bool, # 11 + torch.qint8, # 12 + torch.quint8, # 13 + torch.qint32, # 14 + torch.bfloat16, # 15 +] + +# Deprecated. Internally use _type_utils.ScalarType +# source of truth is +# https://github.com/pytorch/pytorch/blob/master/torch/csrc/utils/tensor_dtypes.cpp +pytorch_name_to_type = { + "Byte": torch.uint8, + "Char": torch.int8, + "Double": torch.double, + "Float": torch.float, + "Half": torch.half, + "Int": torch.int, + "Long": torch.int64, + "Short": torch.short, + "Bool": torch.bool, + "ComplexFloat": torch.complex64, + "ComplexDouble": torch.complex128, + "QInt8": torch.qint8, + "QUInt8": torch.quint8, + "QInt32": torch.qint32, + "BFloat16": torch.bfloat16, +} + + +# Deprecated. Internally use _type_utils.ScalarType +scalar_type_to_onnx = [ + cast_pytorch_to_onnx["Byte"], # 0 + cast_pytorch_to_onnx["Char"], # 1 + cast_pytorch_to_onnx["Short"], # 2 + cast_pytorch_to_onnx["Int"], # 3 + cast_pytorch_to_onnx["Long"], # 4 + cast_pytorch_to_onnx["Half"], # 5 + cast_pytorch_to_onnx["Float"], # 6 + cast_pytorch_to_onnx["Double"], # 7 + cast_pytorch_to_onnx["Undefined"], # 8 + cast_pytorch_to_onnx["ComplexFloat"], # 9 + cast_pytorch_to_onnx["ComplexDouble"], # 10 + cast_pytorch_to_onnx["Bool"], # 11 + cast_pytorch_to_onnx["Char"], # 12 + cast_pytorch_to_onnx["Byte"], # 13 + cast_pytorch_to_onnx["Int"], # 14 + cast_pytorch_to_onnx["BFloat16"], # 15 +] + +# Global set to store the list of quantized operators in the network. +# This is currently only used in the conversion of quantized ops from PT -> C2 via ONNX. 
+_quantized_ops: set[int] = set() diff --git a/torch/onnx/_internal/torchscript_exporter/symbolic_opset10.py b/torch/onnx/_internal/torchscript_exporter/symbolic_opset10.py new file mode 100644 index 000000000000..6b36396250b4 --- /dev/null +++ b/torch/onnx/_internal/torchscript_exporter/symbolic_opset10.py @@ -0,0 +1,1187 @@ +# mypy: allow-untyped-defs +# mypy: disable-error-code=arg-type +from __future__ import annotations + +import functools +import sys +import warnings +from typing import TYPE_CHECKING + +import torch +import torch._C._onnx as _C_onnx +from torch import _C +from torch.onnx import _constants, errors +from torch.onnx._internal.torchscript_exporter import ( + _type_utils, + jit_utils, + registration, + symbolic_helper, + symbolic_opset9 as opset9, +) +from torch.onnx._internal.torchscript_exporter._globals import GLOBALS + + +if TYPE_CHECKING: + from collections.abc import Sequence + + +# EDITING THIS FILE? READ THIS FIRST! +# see Note [Edit Symbolic Files] in README.md + +# This file exports ONNX ops for opset 10 +# Opset 10 is supported by ONNX release 1.5.0 +# release on 04/24/19 + + +__all__ = [ + "dequantize", + "div", + "embedding_bag", + "fake_quantize_per_tensor_affine", + "flip", + "fmod", + "isfinite", + "isinf", + "nan_to_num", + "quantize_per_tensor", + "quantized_add_relu", + "quantized_add", + "quantized_cat", + "quantized_conv1d_relu", + "quantized_conv2d_relu", + "quantized_conv3d_relu", + "quantized_conv1d", + "quantized_conv2d", + "quantized_conv3d", + "quantized_conv_transpose1d", + "quantized_conv_transpose2d", + "quantized_conv_transpose3d", + "quantized_group_norm", + "quantized_hardswish", + "quantized_instance_norm", + "quantized_layer_norm", + "quantized_leaky_relu", + "quantized_linear", + "quantized_linear_relu", + "quantized_mul", + "quantized_sigmoid", + "slice", + "sort", + "topk", +] + + +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=10) + + +@_onnx_symbolic("aten::div") +def div(g: jit_utils.GraphContext, self, other, *args): + if len(args) == 0: + return opset9.true_divide(g, self, other) + else: + return _div_rounding_mode(g, self, other, *args) + + +@symbolic_helper.parse_args("v", "v", "s") +def _div_rounding_mode(g: jit_utils.GraphContext, self, other, rounding_mode): + if rounding_mode == "floor": + return _floor_divide(g, self, other) + else: + return opset9._div_rounding_mode(g, self, other, rounding_mode) + + +@_onnx_symbolic("aten::_floor_divide") +def _floor_divide(g: jit_utils.GraphContext, self, other): + if symbolic_helper._is_fp(self) or symbolic_helper._is_fp(other): + out = opset9.true_divide(g, self, other) + return g.op("Floor", out) + else: + # Integer division does truncation rounding + div = g.op("Div", self, other) + # Division is negative if: self < 0 != other < 0 + zero = g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64)) + negative = g.op("Xor", g.op("Less", self, zero), g.op("Less", other, zero)) + + # For negative numbers with self % other != 0, subtract 1 to round down instead of up + mod = g.op("Mod", self, other, fmod_i=0) + fixup_mask = g.op("And", negative, g.op("Not", g.op("Equal", mod, zero))) + + one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)) + fixup = g.op("Sub", div, one) + return g.op("Where", fixup_mask, fixup, div) + + +@_onnx_symbolic("aten::sort") +@symbolic_helper.parse_args("v", "i", "i", "none") +def sort(g: jit_utils.GraphContext, self, dim, descending, out=None): + return symbolic_helper._sort_helper(g, self, dim, descending=descending, 
out=out)
+
+
+@_onnx_symbolic("aten::topk")
+@symbolic_helper.parse_args("v", "v", "i", "i", "i", "none")
+def topk(g: jit_utils.GraphContext, self, k, dim, largest, sorted, out=None):
+    return symbolic_helper._topk_helper(
+        g, self, k, dim, largest=largest, sorted=sorted, out=out
+    )
+
+
+def _aten_max_pool_onnx(
+    g: jit_utils.GraphContext,
+    self: _C.Value,
+    kernel_shape: Sequence[int],
+    strides: Sequence[int],
+    pads: Sequence[int],
+    dilations: Sequence[int],
+    ceil_mode: bool,
+    unbatched_rank: int,
+) -> _C.Value:
+    self_rank = g.op("Size", g.op("Shape", self))
+    if self_rank == unbatched_rank:  # C,H,W -> N,C,H,W and N=1
+        self = g.op(
+            "Unsqueeze",
+            self,
+            g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)),
+        )
+
+    pool_result, _ = g.op(
+        "MaxPool",
+        self,
+        outputs=2,
+        ceil_mode_i=ceil_mode,
+        dilations_i=dilations,
+        kernel_shape_i=kernel_shape,
+        pads_i=pads,
+        strides_i=strides,
+    )
+
+    if self_rank == unbatched_rank:
+        pool_result = g.op(
+            "Squeeze",
+            pool_result,
+            g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)),
+        )
+
+    return pool_result
+
+
+# For MaxPool
+def _adjust_attributes_of_max_pool(
+    expand_size: int,
+    kernel_size: Sequence[int] | int,
+    stride: Sequence[int] | int,
+    padding: Sequence[int] | int,
+    dilation: Sequence[int] | int,
+) -> tuple[Sequence[int], Sequence[int], Sequence[int], Sequence[int]]:
+    """Adjust attributes of max_pool to match the ONNX specification."""
+
+    if isinstance(dilation, int):
+        dilation = [dilation] * expand_size
+
+    if isinstance(kernel_size, int):
+        kernel_shape = [kernel_size] * expand_size
+    else:
+        kernel_shape = kernel_size  # type: ignore[assignment]
+
+    if isinstance(padding, int):
+        pads = [padding] * expand_size * 2  # type: ignore[operator, assignment]
+    elif len(padding) == 1:
+        pads = padding * expand_size * 2  # type: ignore[operator, assignment]
+    elif len(padding) == 2:
+        # 2D padding
+        pads = padding * 2  # type: ignore[operator, assignment]
+    elif len(padding) == 3:
+        # 3D padding
+        pads = padding * 2  # type: ignore[operator, assignment]
+    else:
+        # When padding is already given for all dimensions,
+        # we don't need to double it,
+        # e.g. (1, 1, 1, 1, 1, 1)
+        pads = padding  # type: ignore[assignment]
+
+    if isinstance(stride, int):
+        strides = [stride] * expand_size
+    elif not stride:
+        strides = kernel_shape
+    else:
+        strides = stride  # type: ignore[assignment]
+
+    return (kernel_shape, strides, pads, dilation)
+
+
+def _aten_max_pool_with_indices_onnx(
+    g: jit_utils.GraphContext,
+    self: _C.Value,
+    kernel_shape: Sequence[int],
+    strides: Sequence[int],
+    pads: Sequence[int],
+    dilations: Sequence[int],
+    ceil_mode: bool,
+    unbatched_rank: int,
+    n_dims_one: Sequence[int],
+    n_dims_zero: Sequence[int],
+    n_dims_axes: Sequence[int],
+) -> tuple[_C.Value, Sequence[int]]:
+    self_rank = g.op("Size", g.op("Shape", self))
+    if self_rank == unbatched_rank:  # C,H,W -> N,C,H,W and N=1
+        self = g.op(
+            "Unsqueeze",
+            self,
+            g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)),
+        )
+
+    pool_result, indices = g.op(
+        "MaxPool",
+        self,
+        outputs=2,
+        ceil_mode_i=ceil_mode,
+        dilations_i=dilations,
+        kernel_shape_i=kernel_shape,
+        pads_i=pads,
+        strides_i=strides,
+    )
+    _, flatten_indices = g.op(
+        "MaxPool",
+        self,
+        outputs=2,
+        dilations_i=dilations,
+        kernel_shape_i=n_dims_one,
+        strides_i=n_dims_one,
+    )
+
+    ends = g.op("Constant", value_t=torch.tensor(n_dims_one))
+    starts = g.op("Constant", value_t=torch.tensor(n_dims_zero))
+    axes = g.op("Constant", 
value_t=torch.tensor(n_dims_axes)) + + delta = g.op("Slice", flatten_indices, starts, ends, axes) + indices = g.op("Sub", indices, delta) + + if self_rank == unbatched_rank: + pool_result = g.op( + "Squeeze", pool_result, value_t=torch.tensor([0], dtype=torch.int64) + ) + indices = g.op("Squeeze", indices, value_t=torch.tensor([0], dtype=torch.int64)) + + return (pool_result, indices) + + +@_onnx_symbolic( + "aten::max_pool1d", + decorate=[symbolic_helper._apply_params("max_pool1d", 1, return_indices=False)], +) +@_onnx_symbolic( + "aten::max_pool2d", + decorate=[symbolic_helper._apply_params("max_pool2d", 2, return_indices=False)], +) +@_onnx_symbolic( + "aten::max_pool3d", + decorate=[symbolic_helper._apply_params("max_pool3d", 3, return_indices=False)], +) +@_onnx_symbolic( + "aten::max_pool1d_with_indices", + decorate=[ + symbolic_helper._apply_params( + "max_pool1d_with_indices", + 1, + return_indices=True, + ) + ], +) +@_onnx_symbolic( + "aten::max_pool2d_with_indices", + decorate=[ + symbolic_helper._apply_params( + "max_pool2d_with_indices", + 2, + return_indices=True, + ) + ], +) +@_onnx_symbolic( + "aten::max_pool3d_with_indices", + decorate=[ + symbolic_helper._apply_params( + "max_pool3d_with_indices", + 3, + return_indices=True, + ) + ], +) +def _max_pool(name: str, expand_size: int, return_indices: bool): + @symbolic_helper.quantized_args(True, False, False, False, False, False) + @symbolic_helper.parse_args("v", "is", "is", "is", "is", "i") + def symbolic_fn( + g: jit_utils.GraphContext, + input: _C.Value, + kernel_size: Sequence[int], + stride: Sequence[int], + padding: int | Sequence[int], + dilation: Sequence[int], + ceil_mode: bool, + ): + kernel_shape, strides, pads, dilations = _adjust_attributes_of_max_pool( + expand_size, kernel_size, stride, padding, dilation + ) + + if return_indices: + return _aten_max_pool_with_indices_onnx( + g, + input, + kernel_shape, + strides, + pads, + dilations, + ceil_mode, + expand_size + 1, + ([1] * expand_size), + ([0] * expand_size), + ([2 + i for i in range(expand_size)]), + ) + else: + return _aten_max_pool_onnx( + g, + input, + kernel_shape, + strides, + pads, + dilations, + ceil_mode, + expand_size + 1, + ) + + return symbolic_fn + + +# For AvgPool +def _adjust_attributes_of_avg_pool( + expand_size: int, + kernel_size: Sequence[int] | int, + stride: Sequence[int] | int, + padding: Sequence[int] | int, +) -> tuple[Sequence[int], Sequence[int], Sequence[int]]: + """Adjust attributes of avg_pool to match ONNX specification.""" + + if isinstance(kernel_size, int): + kernel_shape = [kernel_size] * expand_size + else: + kernel_shape = kernel_size # type: ignore[assignment] + + if isinstance(padding, int): + pads = [padding] * expand_size * 2 + elif len(padding) == 1: + pads = padding * expand_size * 2 # type: ignore[operator, assignment] + elif len(padding) == 2: + pads = padding * expand_size # type: ignore[operator, assignment] + else: + pads = padding * 2 # type: ignore[operator, assignment] + + if isinstance(stride, int): + strides = [stride] * expand_size + elif not stride: + strides = kernel_shape + else: + strides = stride # type: ignore[assignment] + + return (kernel_shape, strides, pads) + + +@_onnx_symbolic( + "aten::avg_pool1d", + decorate=[symbolic_helper._apply_params("avg_pool1d", 1)], +) +@_onnx_symbolic( + "aten::avg_pool2d", + decorate=[symbolic_helper._apply_params("avg_pool2d", 2)], +) +@_onnx_symbolic( + "aten::avg_pool3d", + decorate=[symbolic_helper._apply_params("avg_pool3d", 3)], +) +def _avg_pool(name, 
expand_size): + @symbolic_helper.quantized_args(True, False, False, False, False, False, False) + @symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none") + def symbolic_fn( + g, + input: _C.Value, + kernel_size: Sequence[int], + stride: Sequence[int], + padding: int | Sequence[int], + ceil_mode: int, + count_include_pad: int, + divisor_override=None, + ): + kernel_shape, strides, pads = _adjust_attributes_of_avg_pool( + expand_size, kernel_size, stride, padding + ) + + result = g.op( + "AveragePool", + input, + ceil_mode_i=ceil_mode, + count_include_pad_i=count_include_pad, + kernel_shape_i=kernel_shape, + pads_i=pads, + strides_i=strides, + ) + + return result + + return symbolic_fn + + +@_onnx_symbolic( + "aten::upsample_nearest1d", + decorate=[symbolic_helper._apply_params("upsample_nearest1d", 3, "nearest")], +) +@_onnx_symbolic( + "aten::upsample_nearest2d", + decorate=[symbolic_helper._apply_params("upsample_nearest2d", 4, "nearest")], +) +@_onnx_symbolic( + "aten::upsample_nearest3d", + decorate=[symbolic_helper._apply_params("upsample_nearest3d", 5, "nearest")], +) +@_onnx_symbolic( + "aten::upsample_linear1d", + decorate=[symbolic_helper._apply_params("upsample_linear1d", 3, "linear")], +) +@_onnx_symbolic( + "aten::upsample_bilinear2d", + decorate=[symbolic_helper._apply_params("upsample_bilinear2d", 4, "linear")], +) +@_onnx_symbolic( + "aten::upsample_trilinear3d", + decorate=[symbolic_helper._apply_params("upsample_trilinear3d", 5, "linear")], +) +def _interpolate(name, dim, interpolate_mode): + @symbolic_helper.quantized_args(True, False, False) + def symbolic_fn(g, input, output_size, *args): + scales, align_corners = symbolic_helper._get_interpolate_attributes( + g, interpolate_mode, args + ) + symbolic_helper._interpolate_warning(interpolate_mode) + align_corners = symbolic_helper._maybe_get_scalar(align_corners) + if align_corners: + return symbolic_helper._unimplemented(name, "align_corners == True", input) + if scales is None: + scales = symbolic_helper._interpolate_size_to_scales( + g, input, output_size, dim + ) + return g.op("Resize", input, scales, mode_s=interpolate_mode) + + return symbolic_fn + + +@_onnx_symbolic("aten::__interpolate") +def __interpolate( + g: jit_utils.GraphContext, + input, + size, + scale_factor, + mode, + align_corners, + recompute_scale_factor, + antialias, +): + scales, mode = symbolic_helper._interpolate_get_scales_and_mode( + g, input, size, scale_factor, mode, align_corners + ) + return g.op("Resize", input, scales, mode_s=mode) + + +def _slice( + g: jit_utils.GraphContext, + input: torch._C.Value, + axes: list | torch.Tensor | torch._C.Value, + starts: list | torch.Tensor | torch._C.Value, + ends: list | torch.Tensor | torch._C.Value, + steps: list | torch.Tensor | torch._C.Value | None = None, +): + def is_none_value(value): + if value is None: + return True + return ( + isinstance(value, torch._C.Value) + and value.node().kind() == "prim::Constant" + and isinstance(value.type(), _C.NoneType) + ) + + def to_slice_input(list_or_value, default_value=None): + # Convert input param into a 1D torch.Value. 
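+ # None (when a default is given) becomes a one-element list; Python lists and
+ # tensors become Constant nodes; rank-0 values are unsqueezed to rank 1; any
+ # higher rank raises a SymbolicValueError.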
+ if is_none_value(list_or_value) and default_value is not None: + list_or_value = [default_value] + + if isinstance(list_or_value, torch.Tensor): + return g.op("Constant", value_t=list_or_value.clone().detach()) + elif isinstance(list_or_value, list): + return g.op("Constant", value_t=torch.tensor(list_or_value)) + + rank = symbolic_helper._get_tensor_rank(list_or_value) + if rank == 0: + return symbolic_helper._unsqueeze_helper(g, list_or_value, [0]) + if rank == 1: + return list_or_value + raise errors.SymbolicValueError( + f"Rank must be 0 or 1, not {rank}", list_or_value + ) + + def get_const_value(list_or_value): + if isinstance(list_or_value, (list, torch.Tensor)): + if len(list_or_value) == 1: + return list_or_value[0] + return None + return symbolic_helper._maybe_get_const(list_or_value, "i") + + # Check if slice is a no-op + if ( + get_const_value(starts) == 0 + and get_const_value(ends) == _constants.INT64_MAX + and (steps is None or get_const_value(steps) == 1) + ): + return input + + axes = to_slice_input(axes) + starts = to_slice_input(starts, default_value=0) + ends = to_slice_input(ends, default_value=_constants.INT64_MAX) + if steps is None: + return g.op("Slice", input, starts, ends, axes) + steps = to_slice_input(steps, default_value=1) + return g.op("Slice", input, starts, ends, axes, steps) + + +@_onnx_symbolic("aten::slice") +def slice(g: jit_utils.GraphContext, self, *args): + if len(args) == 4: + # aten::slice(Tensor self, int dim, int? start=None, int? end=None, int step=1) -> Tensor + dims, start, end, step = args + elif len(args) == 3: + # aten::slice(t[] l, int? start=None, int? end=None, int step=1) -> t[] + start, end, step = args + dims = [0] + else: + raise errors.SymbolicValueError("Unknown aten::slice signature", self) + + return symbolic_helper._slice_helper( + g, + self, + axes=dims, + starts=start, + ends=end, + steps=step, + ) + + +@_onnx_symbolic("aten::flip") +@symbolic_helper.parse_args("v", "is") +def flip(g: jit_utils.GraphContext, input, dims): + return symbolic_helper._slice_helper( + g, + input, + axes=dims, + starts=[-1] * len(dims), + ends=[-_constants.INT64_MAX] * len(dims), + steps=[-1] * len(dims), + ) + + +@_onnx_symbolic("aten::fmod") +def fmod(g: jit_utils.GraphContext, input, other): + return g.op("Mod", input, other, fmod_i=1) + + +@_onnx_symbolic("aten::embedding_bag") +@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i") +def embedding_bag( + g: jit_utils.GraphContext, + embedding_matrix, + indices, + offsets, + scale_grad_by_freq, + mode, + sparse, + per_sample_weights, + include_last_offset, + padding_idx, +): + if scale_grad_by_freq and GLOBALS.export_training: + return symbolic_helper._onnx_unsupported( + "embedding_bag with scale_grad_by_freq for training mode" + ) + if padding_idx is not None and padding_idx >= 0: + raise RuntimeError("embedding_bag with padding_idx") + + warnings.warn( + "Export of embedding_bag with dynamic input/offsets shape is not supported in opset 10. 
" + "Please use opset 11 or higher to export model for dynamic input shape.'" + ) + offsets_dim_0 = symbolic_helper._get_tensor_dim_size(offsets, 0) + if offsets_dim_0 is not None: + if include_last_offset: + offset_len = offsets_dim_0 - 1 + offsets_extended = offsets + else: + offset_len = offsets_dim_0 + offsets_extended = [ + offsets, + g.op("Constant", value_t=torch.tensor([sys.maxsize])), + ] + offsets_extended = g.op("Concat", *offsets_extended, axis_i=0) + list_ = [] + for i in range(offset_len): + start_ = symbolic_helper._unsqueeze_helper( + g, + opset9.select(g, offsets_extended, torch.tensor(0), torch.tensor(i)), + [0], + ) + end_ = symbolic_helper._unsqueeze_helper( + g, + opset9.select( + g, offsets_extended, torch.tensor(0), torch.tensor(i + 1) + ), + [0], + ) + axes_ = g.op("Constant", value_t=torch.tensor([0])) + indices_row = g.op("Slice", indices, start_, end_, axes_) + + embeddings = g.op("Gather", embedding_matrix, indices_row) + if not symbolic_helper._is_none(per_sample_weights): + per_sample_weights_row = g.op( + "Slice", per_sample_weights, start_, end_, axes_ + ) + per_sample_weights_row = symbolic_helper._unsqueeze_helper( + g, per_sample_weights_row, [1] + ) + embeddings = g.op("Mul", embeddings, per_sample_weights_row) + if mode == 0: + embeddings = symbolic_helper._reducesum_helper( + g, embeddings, axes_i=[0], keepdims_i=0 + ) + elif mode == 1: + embeddings = g.op("ReduceMean", embeddings, axes_i=[0], keepdims_i=0) + else: + embeddings = g.op("ReduceMax", embeddings, axes_i=[0], keepdims_i=0) + + embeddings = symbolic_helper._unsqueeze_helper(g, embeddings, [0]) + list_.append(embeddings) + + output = g.op("Concat", *list_, axis_i=0) + # aten::embedding_bag returns a tuple of 4 elements: output, offset2bag, bag_size, max_indices. + # But the last three outputs are not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag. + return output, None, None, None + else: + return symbolic_helper._onnx_unsupported( + "embedding_bag with unknown shape of offsets for opset 10 is not supported. " + "please use opset 11 or higher." + ) + + +@_onnx_symbolic("aten::fake_quantize_per_tensor_affine") +@symbolic_helper.parse_args("v", "v", "v", "i", "i") +def fake_quantize_per_tensor_affine( + g: jit_utils.GraphContext, + inputs, + scale, + zero_point, + quant_min=-128, + quant_max=127, +): + # NOTE: (0, 127) is a special case. PyTorch restricts activations to be in the range (0, 127). + # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422 + if (quant_min, quant_max) == (0, 127): + symbolic_helper._onnx_opset_unsupported_detailed( + "fake_quantize_per_tensor_affine", + 10, + 13, + "Quantize range (0, 127) not supported, requires opset 13 Clip", + inputs, + ) + if (quant_min, quant_max) not in [(0, 255), (-128, 127)]: + raise errors.SymbolicValueError( + f"For (quant_min, quant_max), ONNX allows only (0, 255) and (-128, 127). 
" + f"Got ({quant_min}, {quant_max})", + inputs, + ) + scale = symbolic_helper._maybe_get_scalar(scale) + if scale is None: + symbolic_helper._onnx_opset_unsupported_detailed( + "fake_quantize_per_tensor_affine", + 10, + 13, + "Non-constant scale not supported", + inputs, + ) + scale = scale.float().data # Avoid exporter generating double type + if quant_min == 0: + zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8) + else: + zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8) + return g.op( + "DequantizeLinear", + g.op("QuantizeLinear", inputs, scale, zero_point), + scale, + zero_point, + ) + + +@_onnx_symbolic("aten::isinf") +def isinf(g: jit_utils.GraphContext, input): + return g.op("IsInf", g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.DOUBLE)) + + +@_onnx_symbolic("aten::isfinite") +def isfinite(g: jit_utils.GraphContext, input): + inf_node = isinf(g, input) + nan_node = opset9.isnan(g, input) + return opset9.__not_(g, opset9.__or_(g, inf_node, nan_node)) + + +@_onnx_symbolic("aten::quantize_per_tensor") +def quantize_per_tensor(g: jit_utils.GraphContext, input, scale, zero_point, dtype): + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + # TODO(justinchuby): Extract all the cast ops into a helper function. + zero_point = g.op( + "Cast", zero_point, to_i=_type_utils.JitScalarType(dtype).onnx_type() + ) + scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT) + return symbolic_helper.quantize_helper(g, input, scale, zero_point) + + +@_onnx_symbolic("aten::dequantize") +def dequantize(g: jit_utils.GraphContext, input): + return symbolic_helper.dequantize_helper(g, input)[0] + + +@_onnx_symbolic("aten::nan_to_num") +@symbolic_helper.parse_args("v", "f", "f", "f") +def nan_to_num(g: jit_utils.GraphContext, input, nan, posinf, neginf): + # Cannot create a int type tensor with inf/nan values, so we simply + # return the original tensor + if not symbolic_helper._is_fp(input): + return input + input_dtype = _type_utils.JitScalarType.from_value(input).dtype() + if nan is None: + nan = 0.0 + nan_cond = opset9.isnan(g, input) + nan_result = g.op( + "Where", + nan_cond, + g.op("Constant", value_t=torch.tensor([nan], dtype=input_dtype)), + input, + ) + + # For None values of posinf, neginf we use the greatest/lowest finite + # value representable by input's dtype. + finfo = torch.finfo(input_dtype) + if posinf is None: + posinf = finfo.max + posinf_cond = opset9.logical_and( + g, + isinf(g, nan_result), + opset9.gt(g, nan_result, g.op("Constant", value_t=torch.LongTensor([0]))), + ) + nan_posinf_result = g.op( + "Where", + posinf_cond, + g.op("Constant", value_t=torch.tensor([posinf], dtype=input_dtype)), + nan_result, + ) + + if neginf is None: + neginf = finfo.min + neginf_cond = opset9.logical_and( + g, + isinf(g, nan_posinf_result), + opset9.lt( + g, nan_posinf_result, g.op("Constant", value_t=torch.LongTensor([0])) + ), + ) + return g.op( + "Where", + neginf_cond, + g.op("Constant", value_t=torch.tensor([neginf], dtype=input_dtype)), + nan_posinf_result, + ) + + +# Quantized symbolics --------------------------------------------------------- +# https://github.com/pytorch/pytorch/wiki/PyTorch-ONNX-exporter#quantized-model-export +# Support starts from opset 10 because `DequantizeLinear` and `QuantizeLinear` were +# introduced in opset version 10. 
+@_onnx_symbolic("quantized::linear") +def quantized_linear( + g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.linear(g, input, weight, bias) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::linear_relu") +def quantized_linear_relu( + g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.linear(g, input, weight, bias) + output = opset9.relu(g, output) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::add") +def quantized_add(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point): + x, _, _, _ = symbolic_helper.dequantize_helper(g, x) + y, _, _, _ = symbolic_helper.dequantize_helper(g, y) + + output = opset9.add(g, x, y) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::add_relu") +def quantized_add_relu(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point): + x, _, _, _ = symbolic_helper.dequantize_helper(g, x) + y, _, _, _ = symbolic_helper.dequantize_helper(g, y) + + output = opset9.add(g, x, y) + output = opset9.relu(g, output) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::mul") +def quantized_mul(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point): + x, _, _, _ = symbolic_helper.dequantize_helper(g, x) + y, _, _, _ = symbolic_helper.dequantize_helper(g, y) + + output = opset9.mul(g, x, y) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::hardswish") +def quantized_hardswish(g: jit_utils.GraphContext, x, op_scale, op_zero_point): + x, _, _, _ = symbolic_helper.dequantize_helper(g, x) + + output = opset9.hardswish(g, x) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::sigmoid") +def quantized_sigmoid(g: jit_utils.GraphContext, x, op_scale, op_zero_point): + x, _, _, _ = symbolic_helper.dequantize_helper(g, x) + + output = opset9.sigmoid(g, x) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::leaky_relu") +def quantized_leaky_relu( + g: jit_utils.GraphContext, x, negative_slope, inplace, op_scale, op_zero_point +): + x, _, _, _ = symbolic_helper.dequantize_helper(g, x) + + output = opset9.leaky_relu(g, x, negative_slope, inplace) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::layer_norm") +def quantized_layer_norm( + g: jit_utils.GraphContext, + x, + normalized_shape, + weight, + bias, + eps, + op_scale, + op_zero_point, +): + x, _, _, _ = symbolic_helper.dequantize_helper(g, x) + + output = opset9.layer_norm(g, x, normalized_shape, weight, bias, eps, False) + + return 
symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::group_norm") +def quantized_group_norm( + g: jit_utils.GraphContext, + x, + num_groups, + weight, + bias, + eps, + op_scale, + op_zero_point, +): + x, _, _, _ = symbolic_helper.dequantize_helper(g, x) + + output = opset9.group_norm(g, x, num_groups, weight, bias, eps, False) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::instance_norm") +@symbolic_helper.parse_args("v", "v", "v", "f", "v", "v") +def quantized_instance_norm( + g: jit_utils.GraphContext, + q_input, + weight, + bias, + eps, + op_scale, + op_zero_point, +): + input, _, _, _ = symbolic_helper.dequantize_helper(g, q_input) + + output = opset9.instance_norm( + g, input, weight, bias, None, None, False, 0.0, eps, False + ) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv1d_relu") +def quantized_conv1d_relu( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv1d(g, input, weight, bias, stride, padding, dilation, groups) + output = opset9.relu(g, output) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv2d_relu") +def quantized_conv2d_relu( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv2d(g, input, weight, bias, stride, padding, dilation, groups) + output = opset9.relu(g, output) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv3d_relu") +def quantized_conv3d_relu( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv3d(g, input, weight, bias, stride, padding, dilation, groups) + output = opset9.relu(g, output) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv1d") +def quantized_conv1d( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) + bias, _, _, _ = 
symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv1d(g, input, weight, bias, stride, padding, dilation, groups) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv2d") +def quantized_conv2d( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv2d(g, input, weight, bias, stride, padding, dilation, groups) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv3d") +def quantized_conv3d( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv3d(g, input, weight, bias, stride, padding, dilation, groups) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv_transpose1d") +def quantized_conv_transpose1d( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + output_padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv_transpose2d( + g, input, weight, bias, stride, padding, output_padding, groups, dilation + ) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv_transpose2d") +def quantized_conv_transpose2d( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + output_padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv_transpose2d( + g, input, weight, bias, stride, padding, output_padding, groups, dilation + ) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv_transpose3d") +def quantized_conv_transpose3d( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + output_padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, 
q_bias) + + output = opset9.conv_transpose3d( + g, input, weight, bias, stride, padding, output_padding, groups, dilation + ) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::cat") +@symbolic_helper.parse_args("v", "i", "v", "v") +def quantized_cat( + g: jit_utils.GraphContext, + q_inputs: _C.Value, + dim: int, + op_scale: _C.Value, + op_zero_point: _C.Value, +) -> _C.Value: + unpacked_inputs = symbolic_helper._unpack_list(q_inputs) + dequantized = [ + symbolic_helper.dequantize_helper(g, input)[0] for input in unpacked_inputs + ] + concatenated = g.op("Concat", *dequantized, axis_i=dim) + return symbolic_helper.quantize_helper(g, concatenated, op_scale, op_zero_point) diff --git a/torch/onnx/_internal/torchscript_exporter/symbolic_opset11.py b/torch/onnx/_internal/torchscript_exporter/symbolic_opset11.py new file mode 100644 index 000000000000..f437e2670768 --- /dev/null +++ b/torch/onnx/_internal/torchscript_exporter/symbolic_opset11.py @@ -0,0 +1,1472 @@ +# mypy: allow-untyped-defs +# mypy: disable-error-code=arg-type +"""This file exports ONNX ops for opset 11.""" + +from __future__ import annotations + +import functools +import sys +import warnings +from typing import TYPE_CHECKING + +import torch +from torch import _C +from torch._C import _onnx as _C_onnx +from torch.onnx import errors +from torch.onnx._internal.torchscript_exporter import ( + _type_utils, + jit_utils, + registration, + symbolic_helper, + symbolic_opset10 as opset10, + symbolic_opset9 as opset9, + utils, +) + + +if TYPE_CHECKING: + from collections.abc import Sequence + + +# EDITING THIS FILE? READ THIS FIRST! +# see Note [Edit Symbolic Files] in README.md + +__all__ = [ + "add", + "append", + "arange", + "argsort", + "atleast_1d", + "atleast_2d", + "atleast_3d", + "cat", + "chunk", + "clamp_max", + "clamp_min", + "clamp", + "constant_pad_nd", + "cumsum", + "Delete", + "embedding_bag", + "embedding_renorm", + "flatten", + "gather", + "hardtanh", + "hstack", + "im2col", + "index_fill", + "index", + "index_copy", + "index_put", + "insert", + "linalg_det", + "linalg_vector_norm", + "logdet", + "masked_scatter", + "masked_select", + "mm", + "narrow", + "normal", + "pad", + "pixel_shuffle", + "pop", + "prim_constant_chunk", + "reflection_pad", + "relu6", + "remainder", + "replication_pad", + "round", + "scatter", + "select", + "size", + "sort", + "split_with_sizes", + "split", + "squeeze", + "stack", + "topk", + "unbind", + "unique_dim", + "unsqueeze", + "vstack", +] + +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=11) + + +@_onnx_symbolic("aten::hardtanh") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "f", "f") +def hardtanh(g: jit_utils.GraphContext, self: _C.Value, min_val: float, max_val: float): + scalar_type = _type_utils.JitScalarType.from_value( + self, _type_utils.JitScalarType.FLOAT + ) + min_val = g.op( + "Constant", + value_t=torch.tensor(min_val, dtype=scalar_type.dtype()), + ) + max_val = g.op( + "Constant", + value_t=torch.tensor(max_val, dtype=scalar_type.dtype()), + ) + return symbolic_helper._op_with_optional_float_cast( + g, "Clip", self, min_val, max_val, opset_before=12 + ) + + +@_onnx_symbolic("aten::clamp") +def clamp(g: jit_utils.GraphContext, self, min, max): + def _cast_if_not_none(tensor, dtype): + if tensor is not None and not symbolic_helper._is_none(tensor): + return g.op( + "Cast", + tensor, + to_i=dtype.onnx_type(), + ) + else: + return tensor + + scalar_type = 
_type_utils.JitScalarType.from_value( + self, _type_utils.JitScalarType.UNDEFINED + ) + if scalar_type != _type_utils.JitScalarType.UNDEFINED: + min = _cast_if_not_none(min, scalar_type) + max = _cast_if_not_none(max, scalar_type) + + if symbolic_helper._is_none(min): + return clamp_max(g, self, max) + elif symbolic_helper._is_none(max): + return clamp_min(g, self, min) + else: + if ( + symbolic_helper._get_tensor_rank(min) == 0 + and symbolic_helper._get_tensor_rank(max) == 0 + ): + return symbolic_helper._op_with_optional_float_cast( + g, "Clip", self, min, max, opset_before=12 + ) + else: + return clamp_max(g, clamp_min(g, self, min), max) + + +@_onnx_symbolic("aten::clamp_min") +@symbolic_helper.parse_args("v", "v") +def clamp_min(g: jit_utils.GraphContext, self, min): + min = g.op("Cast", min, to_i=_type_utils.JitScalarType.from_value(self).onnx_type()) + if symbolic_helper._get_tensor_rank(min) == 0: + max = opset9.unused(g) + return symbolic_helper._op_with_optional_float_cast( + g, "Clip", self, min, max, opset_before=12 + ) + else: + return symbolic_helper._op_with_optional_float_cast( + g, "Max", self, min, opset_before=12 + ) + + +@_onnx_symbolic("aten::clamp_max") +@symbolic_helper.parse_args("v", "v") +def clamp_max(g: jit_utils.GraphContext, self, max): + max = g.op("Cast", max, to_i=_type_utils.JitScalarType.from_value(self).onnx_type()) + if symbolic_helper._get_tensor_rank(max) == 0: + min = opset9.unused(g) + return symbolic_helper._op_with_optional_float_cast( + g, "Clip", self, min, max, opset_before=12 + ) + else: + return symbolic_helper._op_with_optional_float_cast( + g, "Min", self, max, opset_before=12 + ) + + +@_onnx_symbolic("aten::relu6") +def relu6(g: jit_utils.GraphContext, input): + scalar_type = _type_utils.JitScalarType.from_value( + input, _type_utils.JitScalarType.FLOAT + ) + min_val = g.op( + "Constant", + value_t=torch.tensor(0, dtype=scalar_type.dtype()), + ) + max_val = g.op( + "Constant", + value_t=torch.tensor(6, dtype=scalar_type.dtype()), + ) + return clamp(g, input, min_val, max_val) + + +@_onnx_symbolic("aten::select") +# Opset 11 gather accepts negative indices +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "i", "v") +def select(g: jit_utils.GraphContext, self, dim, index): + return g.op("Gather", self, index, axis_i=dim) + + +@_onnx_symbolic("aten::index_put") +def index_put( + g: jit_utils.GraphContext, self, indices_list_value, values, accumulate=False +): + if symbolic_helper._is_packed_list(indices_list_value): + indices_list = symbolic_helper._unpack_list(indices_list_value) + else: + indices_list = [indices_list_value] + accumulate = symbolic_helper._parse_arg(accumulate, "b") + + if len(indices_list) == 0: + return values + + if len(indices_list) > 1: + for idx_ in range(len(indices_list)): + if symbolic_helper._is_bool(indices_list[idx_]): + indices_list[idx_] = g.op("NonZero", indices_list[idx_]) + index = indices_list[0] + + for ind in indices_list[1:]: + index = opset9.add(g, index, ind) + broadcast_index_shape = g.op("Shape", index) + indices_list = [ + symbolic_helper._unsqueeze_helper( + g, opset9.expand(g, ind, broadcast_index_shape, None), [-1] + ) + for ind in indices_list + ] + index = g.op("Concat", *indices_list, axis_i=-1) + else: + # Replace index_put node with masked_scatter or masked_fill + # when inputs to the index_put node contains a single boolean input. + # + # index_put -> masked_fill + # * input index contains single tensor of Bool type (e.g.: %24 <- %23). 
+ # * input value contains single element (e.g.: %18). + # + # Torch IR + # %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6) + # %16 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = + # aten::to(%8, %26, %27, %11, %12, %28, %29, %15) + # %18 : Float(requires_grad=0, device=cpu) = prim::Constant[value={1}]() + # %23 : Bool(8, strides=[1], device=cpu) = aten::view(%16, %22) + # %24 : Tensor?[] = prim::ListConstruct(%23) + # %25 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = + # aten::index_put(%mask, %24, %18, %30) + # return (%25) + # + # + # index_put -> masked_scatter + # * input index contains single tensor of Bool type (e.g.: %32 <- %31). + # * input value contains multiple elements (e.g.: %28). + # + # Torch IR + # %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6) + # %28 : Float(8, strides=[1], requires_grad=0, device=cpu) + # = prim::Constant[value= 1 1 1 1 1 1 1 1 [ CPUFloatType{8} ]]() + # %15 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) + # = aten::ne(%mask, %some_const) + # %23 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) + # = aten::to(%15, %34, %35, %18, %19, %36, %37, %22) + # %38 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]() + # %30 : int[] = prim::Constant[value=[-1]]() + # %31 : Bool(8, strides=[1], device=cpu) = aten::view(%23, %30) + # %32 : Tensor?[] = prim::ListConstruct(%31) + # %33 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) + # = aten::index_put(%mask, %32, %28, %38) + # return (%33) + index = indices_list[0] + bool_inp = index + if symbolic_helper._is_bool(bool_inp): + rank = symbolic_helper._get_tensor_rank(values) + if rank is not None and rank == 0: + return opset9.masked_fill(g, self, bool_inp, values) + mask_rank = symbolic_helper._get_tensor_rank(bool_inp) + self_rank = symbolic_helper._get_tensor_rank(self) + if ( + mask_rank is not None + and self_rank is not None + and self_rank > mask_rank + ): + # Unsqueeze 'bool_inp' to be broadcastable to shape of 'self'. 
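+ # Trailing singleton dims are appended (positions mask_rank..self_rank-1) so the
+ # boolean mask broadcasts against 'self' before dispatching to masked_scatter.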
+ bool_inp = symbolic_helper._unsqueeze_helper( + g, bool_inp, list(range(mask_rank, self_rank)) + ) + return masked_scatter(g, self, bool_inp, values) + broadcast_index_shape = g.op("Shape", index) + index = symbolic_helper._unsqueeze_helper(g, index, [-1]) + sub_data_shape = symbolic_helper._slice_helper( + g, g.op("Shape", self), axes=[0], starts=[len(indices_list)], ends=[sys.maxsize] + ) + values_shape = g.op("Concat", broadcast_index_shape, sub_data_shape, axis_i=0) + # Check if values is a singular value and expand accordingly + rank = symbolic_helper._get_tensor_rank(values) + if rank is not None and rank == 0: + values = opset9.expand(g, values, values_shape, None) + values = symbolic_helper._reshape_helper(g, values, values_shape) + + self_scalar_type = _type_utils.JitScalarType.from_value( + self, _type_utils.JitScalarType.UNDEFINED + ) + if self_scalar_type != _type_utils.JitScalarType.UNDEFINED: + values_scalar_type = _type_utils.JitScalarType.from_value( + values, _type_utils.JitScalarType.UNDEFINED + ) + if self_scalar_type != values_scalar_type: + values = g.op("Cast", values, to_i=self_scalar_type.onnx_type()) + elif accumulate: + raise errors.SymbolicValueError("self does not have a valid scalar type.", self) + + if accumulate: + zeros = g.op( + "ConstantOfShape", + g.op("Shape", self), + value_t=torch.tensor([0], dtype=self_scalar_type.dtype()), + ) + result = g.op("ScatterND", zeros, index, values) + result = add(g, self, result) + else: + result = g.op("ScatterND", self, index, values) + + return result + + +@_onnx_symbolic("aten::pixel_shuffle") +@symbolic_helper.parse_args("v", "i") +def pixel_shuffle(g: jit_utils.GraphContext, self, upscale_factor): + rank = symbolic_helper._get_tensor_rank(self) + if rank is not None and rank != 4: + return symbolic_helper._unimplemented("pixel_shuffle", "only support 4d input") + return g.op("DepthToSpace", self, blocksize_i=upscale_factor, mode_s="CRD") + + +@_onnx_symbolic( + "aten::upsample_nearest1d", + decorate=[symbolic_helper._apply_params("upsample_nearest1d", 3, "nearest")], +) +@_onnx_symbolic( + "aten::upsample_nearest2d", + decorate=[symbolic_helper._apply_params("upsample_nearest2d", 4, "nearest")], +) +@_onnx_symbolic( + "aten::upsample_nearest3d", + decorate=[symbolic_helper._apply_params("upsample_nearest3d", 5, "nearest")], +) +@_onnx_symbolic( + "aten::upsample_linear1d", + decorate=[symbolic_helper._apply_params("upsample_linear1d", 3, "linear")], +) +@_onnx_symbolic( + "aten::upsample_bilinear2d", + decorate=[symbolic_helper._apply_params("upsample_bilinear2d", 4, "linear")], +) +@_onnx_symbolic( + "aten::upsample_trilinear3d", + decorate=[symbolic_helper._apply_params("upsample_trilinear3d", 5, "linear")], +) +@_onnx_symbolic( + "aten::upsample_bicubic2d", + decorate=[symbolic_helper._apply_params("upsample_bicubic2d", 4, "cubic")], +) +def _interpolate(name: str, dim: int, interpolate_mode: str): + return symbolic_helper._interpolate_helper(name, dim, interpolate_mode) + + +@_onnx_symbolic("aten::__interpolate") +@symbolic_helper.quantized_args(True, False, False, False, False, False, False) +def __interpolate( + g: jit_utils.GraphContext, + input, + size, + scale_factor, + mode, + align_corners, + recompute_scale_factor, + antialias, +): + return symbolic_helper.__interpolate_helper( + g, input, size, scale_factor, mode, align_corners, recompute_scale_factor + ) + + +@_onnx_symbolic("aten::gather") +@symbolic_helper.parse_args("v", "i", "v", "v") +def gather(g: jit_utils.GraphContext, self, dim, index, 
sparse_grad=False): + if symbolic_helper._maybe_get_const(sparse_grad, "i"): + return symbolic_helper._unimplemented("gather", "sparse_grad == True") + return g.op("GatherElements", self, index, axis_i=dim) + + +@_onnx_symbolic("aten::scatter") +@symbolic_helper.parse_args("v", "i", "v", "v") +def scatter(g: jit_utils.GraphContext, self, dim, index, src): + src_type = _type_utils.JitScalarType.from_value(src) + src = symbolic_helper._maybe_get_scalar(src) + if symbolic_helper._is_value(src): + return g.op("ScatterElements", self, index, src, axis_i=dim) + else: + # Check if scalar "src" has same type as self (PyTorch allows different + # type for scalar src (but not when src is tensor)). If not, insert Cast node. + if _type_utils.JitScalarType.from_value(self) != src_type: + src = g.op( + "Cast", + src, + to_i=_type_utils.JitScalarType.from_value(self).onnx_type(), + ) + return g.op( + "ScatterElements", self, index, opset9.expand_as(g, src, index), axis_i=dim + ) + + +@_onnx_symbolic("aten::cumsum") +@symbolic_helper.parse_args("v", "i", "none") +def cumsum(g: jit_utils.GraphContext, self, dim, dtype=None): + dim_tensor = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.int)) + if dtype and dtype.node().kind() != "prim::Constant": + parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype") + cast = g.op( + "Cast", self, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type() + ) + else: + cast = self + csum = g.op("CumSum", cast, dim_tensor) + return csum + + +@_onnx_symbolic("aten::masked_select") +def masked_select(g: jit_utils.GraphContext, self, mask): + index = opset9.nonzero(g, opset9.expand_as(g, mask, self)) + return g.op("GatherND", self, index) + + +@_onnx_symbolic("aten::masked_scatter") +def masked_scatter(g: jit_utils.GraphContext, self, mask, source): + index = opset9.nonzero(g, opset9.expand_as(g, mask, self)) + # NOTE: source can have more elements than needed. + # It could also have arbitrary shape. + # This is not supported by ONNX::ScatterND, so we need to flatten and slice source tensor. 
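+ # The flattened source is truncated to the number of selected positions
+ # (dim 0 of the nonzero index) before it is handed to ScatterND.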
+ source = symbolic_helper._reshape_helper(g, source, torch.LongTensor([-1])) + source = symbolic_helper._slice_helper( + g, + source, + axes=torch.LongTensor([0]), + starts=torch.LongTensor([0]), + ends=opset9.size(g, index, torch.LongTensor([0])), + ) + return g.op("ScatterND", self, index, source) + + +@_onnx_symbolic("aten::len") +def _len(g: jit_utils.GraphContext, self): + if ( + symbolic_helper._is_tensor_list(self) + or self.node().kind() == "onnx::SplitToSequence" + ): + return g.op("SequenceLength", self) + sz_0 = size(g, self, g.op("Constant", value_t=torch.LongTensor([0]))) + return symbolic_helper._squeeze_helper(g, sz_0, [0]) + + +@_onnx_symbolic("aten::__getitem_") +def __getitem_(g: jit_utils.GraphContext, self, i): + if symbolic_helper._is_tensor_list(self): + # SequenceAt requires that the input be a List of Tensors + return g.op("SequenceAt", self, i) + else: + from torch.onnx._internal.torchscript_exporter.symbolic_opset9 import ( + __getitem_ as getitem, + ) + + return getitem(g, self, i) + + +@_onnx_symbolic("aten::_set_item") +def _set_item(g: jit_utils.GraphContext, tensor_list, i, v): + tensor_list = g.op("SequenceErase", tensor_list, i) + return g.op("SequenceInsert", tensor_list, v, i) + + +@_onnx_symbolic("aten::append") +def append(g: jit_utils.GraphContext, self, tensor): + return g.op("SequenceInsert", self, tensor) + + +@_onnx_symbolic("aten::add") +def add(g: jit_utils.GraphContext, self, other, alpha=None): + if symbolic_helper._is_value(self) and symbolic_helper._is_tensor_list(self): + tensor_list_node = other.node() + if tensor_list_node.kind() != "prim::ListConstruct": + return symbolic_helper._unimplemented( + "add", "does not support adding dynamic tensor list to another" + ) + tensors = symbolic_helper._unpack_list(other) + l = self + for t in tensors: + l = g.op("SequenceInsert", l, t) + return l + + return opset9.add(g, self, other, alpha) + + +@_onnx_symbolic("aten::insert") +def insert(g: jit_utils.GraphContext, self, pos, tensor): + return g.op("SequenceInsert", self, tensor, pos) + + +@_onnx_symbolic("aten::pop") +def pop(g: jit_utils.GraphContext, tensor_list, dim): + return g.op("SequenceErase", tensor_list, dim) + + +@_onnx_symbolic("aten::Delete") +def Delete(g: jit_utils.GraphContext, tensor_list, dim): + return g.op("SequenceErase", tensor_list, dim) + + +@_onnx_symbolic("aten::cat") +@symbolic_helper.quantized_args(True) +def cat(g: jit_utils.GraphContext, tensor_list, dim): + if symbolic_helper._is_packed_list(tensor_list): + return opset9.cat(g, tensor_list, dim) + else: + dim = symbolic_helper._get_const(dim, "i", "dim") + return g.op("ConcatFromSequence", tensor_list, axis_i=dim) + + +@_onnx_symbolic("aten::stack") +def stack(g: jit_utils.GraphContext, tensor_list, dim): + if symbolic_helper._is_packed_list(tensor_list): + return opset9.stack(g, tensor_list, dim) + else: + dim = symbolic_helper._get_const(dim, "i", "dim") + return g.op("ConcatFromSequence", tensor_list, axis_i=dim, new_axis_i=1) + + +@_onnx_symbolic("aten::_unique2") +@symbolic_helper.parse_args("v", "i", "i", "i") +def _unique2(g: jit_utils.GraphContext, self, sorted, return_inverse, return_counts): + u, _indices, inverse_indices, counts = g.op( + "Unique", self, sorted_i=sorted, outputs=4 + ) + return u, inverse_indices, counts + + +@_onnx_symbolic("aten::unique_dim") +@symbolic_helper.parse_args("v", "i", "i", "i", "i") +def unique_dim( + g: jit_utils.GraphContext, self, dim, sorted, return_inverse, return_counts +): + u, _indices, inverse_indices, counts = g.op( 
+ "Unique", self, axis_i=dim, sorted_i=sorted, outputs=4 + ) + return u, inverse_indices, counts + + +@_onnx_symbolic("aten::topk") +@symbolic_helper.parse_args("v", "v", "i", "i", "i", "none") +def topk(g: jit_utils.GraphContext, self, k, dim, largest, sorted, out=None): + return symbolic_helper._topk_helper( + g, self, k, dim, largest=largest, sorted=sorted, out=out + ) + + +@_onnx_symbolic("aten::sort") +@symbolic_helper.parse_args("v", "i", "i", "none") +def sort(g: jit_utils.GraphContext, self, dim, descending, out=None): + return symbolic_helper._sort_helper(g, self, dim, descending=descending, out=out) + + +@_onnx_symbolic("aten::argsort") +@symbolic_helper.parse_args("v", "i", "i", "none") +def argsort(g: jit_utils.GraphContext, self, dim, descending, out=None): + _, indices = symbolic_helper._sort_helper( + g, self, dim, descending=descending, out=out + ) + return indices + + +@_onnx_symbolic("aten::round") +@symbolic_helper.parse_args("v", "i") +def round(g: jit_utils.GraphContext, self, decimals=0): + if not symbolic_helper._is_fp(self): + return self + if decimals == 0: + return g.op("Round", self) + mul = g.op("Mul", self, g.op("Constant", value_t=torch.tensor(pow(10, decimals)))) + round = g.op("Round", mul) + return g.op( + "Mul", round, g.op("Constant", value_t=torch.tensor(pow(10, -1 * decimals))) + ) + + +@_onnx_symbolic("aten::remainder") +def remainder(g: jit_utils.GraphContext, input, other): + if symbolic_helper._is_fp(input) or symbolic_helper._is_fp(other): + return opset9.remainder(g, input, other) + return g.op("Mod", input, other, fmod_i=0) + + +@_onnx_symbolic("aten::split") +@symbolic_helper.parse_args("v", "v", "i", "i") +def split(g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None): + if not symbolic_helper._is_split_static(split_size_or_sizes, _outputs): + split_out = g.op("SplitToSequence", self, split_size_or_sizes, axis_i=dim) + if _outputs is None: + return split_out + # Convert to multiple slice nodes iff number of splits and number of outputs are statically known. 
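+ # Otherwise emit one SequenceAt per requested output on the SplitToSequence result.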
+ if ( + symbolic_helper._is_packed_list(split_size_or_sizes) + and len(symbolic_helper._unpack_list(split_size_or_sizes)) == _outputs + ): + split_sizes = [ + symbolic_helper._unsqueeze_helper(g, v, [0]) + for v in symbolic_helper._unpack_list(split_size_or_sizes) + ] + start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long)) + axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long)) + res = [] + for i in range(_outputs): + end = g.op( + "Add", start, split_sizes[i] + ) # split_sizes is a list of same length as _outputs + res.append(g.op("Slice", self, start, end, axis)) + start = end + return res + return [ + g.op( + "SequenceAt", + split_out, + g.op("Constant", value_t=torch.tensor([i], dtype=torch.long)), + ) + for i in range(_outputs) + ] + else: + return opset9.split(g, self, split_size_or_sizes, dim, _outputs) + + +@_onnx_symbolic("aten::split_with_sizes") +@symbolic_helper.parse_args("v", "v", "i", "i") +def split_with_sizes(g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None): + return split(g, self, split_sizes, dim, _outputs) + + +@_onnx_symbolic("aten::unbind") +@symbolic_helper.parse_args("v", "i", "i") +def unbind(g: jit_utils.GraphContext, self, dim=0, _outputs=None): + if _outputs is None: + return g.op( + "SplitToSequence", + self, + g.op("Constant", value_t=torch.tensor(1, dtype=torch.long)), + axis_i=dim, + keepdims_i=0, + ) + else: + return opset9.unbind(g, self, dim, _outputs) + + +def _prepare_onnx_paddings(g: jit_utils.GraphContext, input, pad): + """Generate paddings in ONNX order based on pad in pytorch. + + Args: + input: the input tensor. + pad: the paddings in pytorch. + The order is dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, ..., dim_m_begin, dim_m_end, + where m is in range [0, n]. + """ + if ( + not symbolic_helper._is_packed_list(pad) + and symbolic_helper._is_list(pad) + and symbolic_helper._is_scalar_list(pad) + ): + pad = g.op("ConcatFromSequence", pad, axis_i=0, new_axis_i=1) + # The desired order of paddings is + # dim_0_begin, dim_1_begin, ... , dim_0_end, ..., dim_n_end. + # n is the dimension of input. + # Assume zero-dimensions in the beginning, pad the "pad" sequence with zeros in the beginning + pad_len = opset9.size(g, pad, g.op("Constant", value_t=torch.tensor([0]))) + # Set extension = [0] * (dim * 2 - len(pad)) + rank = symbolic_helper._get_tensor_rank(input) + if rank is None: + rank = g.op("Size", g.op("Shape", input)) + else: + rank = g.op("Constant", value_t=torch.tensor(rank, dtype=torch.int64)) + extension = g.op( + "Sub", + g.op("Mul", rank, g.op("Constant", value_t=torch.tensor(2, dtype=torch.int64))), + pad_len, + ) + # Concat pad with extension: paddings = [dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, 0, 0, ... 
] + # Currently ONNX only supports int64 type for Pad + pad = g.op("Cast", pad, to_i=_C_onnx.TensorProtoDataType.INT64) + paddings = g.op( + "Concat", + pad, + g.op( + "ConstantOfShape", extension, value_t=torch.tensor([0], dtype=torch.int64) + ), + axis_i=0, + ) + # Reshape and reverse order and collate first beginnings and then ends + # paddings = [[..., 0, dim_n-1_begin, dim_n_begin], + # [..., 0, dim_n-1_end, dim_n_end]] + # Reshape back to 1-D paddings = [..., 0, dim_n - 1_begin, dim_n_begin, ..., 0, dim_n - 1_end, dim_n_end] + paddings = symbolic_helper._reshape_helper( + g, paddings, g.op("Constant", value_t=torch.tensor([-1, 2])) + ) + paddings = g.op("Transpose", opset10.flip(g, paddings, [0]), perm_i=[1, 0]) + paddings = symbolic_helper._reshape_helper( + g, paddings, g.op("Constant", value_t=torch.tensor([-1])) + ) + padding_c = g.op("Cast", paddings, to_i=_C_onnx.TensorProtoDataType.INT64) + return padding_c + + +@_onnx_symbolic("aten::constant_pad_nd") +def constant_pad_nd(g: jit_utils.GraphContext, input, padding, value=None): + mode = "constant" + value = symbolic_helper._maybe_get_scalar(value) + value = symbolic_helper._if_scalar_type_as(value, input) + pad = _prepare_onnx_paddings(g, input, padding) + return g.op("Pad", input, pad, value, mode_s=mode) + + +@_onnx_symbolic("aten::reflection_pad1d") +@_onnx_symbolic("aten::reflection_pad2d") +@_onnx_symbolic("aten::reflection_pad3d") +def reflection_pad(g: jit_utils.GraphContext, input, padding): + mode = "reflect" + paddings = _prepare_onnx_paddings(g, input, padding) + return g.op("Pad", input, paddings, mode_s=mode) + + +@_onnx_symbolic("aten::replication_pad1d") +@_onnx_symbolic("aten::replication_pad2d") +@_onnx_symbolic("aten::replication_pad3d") +def replication_pad(g: jit_utils.GraphContext, input, padding): + mode = "edge" + paddings = _prepare_onnx_paddings(g, input, padding) + return g.op("Pad", input, paddings, mode_s=mode) + + +@_onnx_symbolic("aten::pad") +def pad( + g: jit_utils.GraphContext, + input: _C.Value, + pad: _C.Value, + mode: _C.Value, + value: _C.Value, +): + mode = symbolic_helper._parse_arg(mode, "s") + if mode == "replicate": + return replication_pad(g, input, pad) + elif mode == "reflect": + return reflection_pad(g, input, pad) + elif mode == "constant": + return constant_pad_nd(g, input, pad, value) + elif mode == "circular": + return opset9._pad_circular(g, input, pad) + else: + raise errors.SymbolicValueError(f"Unrecognized padding mode {mode}", input) + + +@_onnx_symbolic("aten::linalg_det") +def linalg_det(g: jit_utils.GraphContext, self): + return g.op("Det", self) + + +@_onnx_symbolic("aten::logdet") +def logdet(g: jit_utils.GraphContext, input): + return opset9.log(g, linalg_det(g, input)) + + +@_onnx_symbolic("aten::arange") +def arange(g: jit_utils.GraphContext, *args): + def _get_arange_dtype(dtype): + dtype = symbolic_helper._maybe_get_const(dtype, "i") + return dtype + + if len(args) == 2 and all(isinstance(val, int) for val in args): + # aten::arange(Scalar start, Scalar end) + dtype = torch.int64 + # Start index. + start = g.op( + "Constant", + value_t=torch.tensor(args[0], dtype=dtype), + ) + # End (exclusive) index. + end = g.op( + "Constant", + value_t=torch.tensor(args[1], dtype=dtype), + ) + # Step size from start to end indexes. 
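+ # aten::arange(start, end) carries no explicit step, so a constant step of 1 is used.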
+ delta_default = g.op( + "Constant", + value_t=torch.tensor(1, dtype=dtype), + ) + return g.op("Range", start, end, delta_default) + elif len(args) == 2 or len(args) == 5: + if len(args) == 2: + # aten::arange(Scalar end, Tensor out) + dtype = None + else: + # aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory) + dtype = _get_arange_dtype(args[1]) + type_, end, start, step = symbolic_helper._arange_cast_helper( + g, end=args[0], dtype=dtype + ) + start_default = g.op( + "Constant", + value_t=torch.tensor(0, dtype=type_.dtype()), + ) + delta_default = g.op( + "Constant", + value_t=torch.tensor(1, dtype=type_.dtype()), + ) + return g.op("Range", start_default, end, delta_default) + elif len(args) == 4 or len(args) == 7: + if len(args) == 4: + # aten::arange(Scalar start, Scalar end, Scalar step, Tensor out) + dtype = None + else: + # aten::arange(Scalar start, Scalar end, Scalar step, ScalarType dtype, Layout, Device, bool pin_memory) + dtype = _get_arange_dtype(args[3]) + _, end, start, step = symbolic_helper._arange_cast_helper( + g, start=args[0], end=args[1], step=args[2], dtype=dtype + ) + return g.op("Range", start, end, step) + elif len(args) == 6: + # aten::arange(Scalar start, Scalar end, ScalarType dtype, Layout, Device, bool pin_memory) + dtype = _get_arange_dtype(args[2]) + type_, end, start, step = symbolic_helper._arange_cast_helper( + g, start=args[0], end=args[1], dtype=dtype + ) + delta_default = g.op( + "Constant", + value_t=torch.tensor(1, dtype=type_.dtype()), + ) + return g.op("Range", start, end, delta_default) + else: + return symbolic_helper._unimplemented( + "aten::arange", f"with {len(args)} arguments" + ) + + +@_onnx_symbolic("aten::_dim_arange") +@symbolic_helper.parse_args("v", "i") +def _dim_arange(g: jit_utils.GraphContext, like, dim): + like_shape = g.op("Shape", like) + stop = g.op( + "Gather", like_shape, g.op("Constant", value_t=torch.tensor(dim)), axis_i=0 + ) + return arange(g, stop, 4, None, None, None) + + +@_onnx_symbolic("aten::size") +@symbolic_helper.quantized_args(True, quantize_output=False) +def size(g: jit_utils.GraphContext, self, dim=None): + if dim is None: + return g.op("Shape", self) + return symbolic_helper._size_helper(g, self, dim) + + +@_onnx_symbolic("aten::squeeze") +def squeeze(g: jit_utils.GraphContext, self, dim=None): + if dim is None: + return g.op("Squeeze", self) + + # dim as a tensor + if not symbolic_helper._is_constant(dim): + return symbolic_helper._squeeze_helper(g, self, [dim]) + + dim = symbolic_helper._get_const(dim, "i", "dim") + + input_rank = symbolic_helper._get_tensor_rank(self) + adjusted_dim = dim + if input_rank is not None and dim < 0: + adjusted_dim += input_rank + dim_size = symbolic_helper._get_tensor_dim_size(self, adjusted_dim) + if (dim < 0 and input_rank is None) or dim_size is None: + # If onnx shape inference is not on, export always as dynamic. + # Because we cannot tell if observed static shape is also static at runtime. + # create "cond" node (condition is shape[i]==1) + dim_constant = g.op("Constant", value_t=torch.tensor([dim])) + size = symbolic_helper._size_helper(g, self, dim_constant) + const_one = g.op("Constant", value_t=torch.ones(1, dtype=torch.int64)) + cond = g.op("Equal", size, const_one) + # create the "If" node and add the "then" and "else" blocks to it. 
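+ # The "then" branch squeezes the dimension; the "else" branch returns the input
+ # unchanged via Identity.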
+ if_op, (if_context, else_context), _ = jit_utils.add_op_with_blocks( + g, "If", cond, n_blocks=2 + ) + squeeze_ = symbolic_helper._squeeze_helper(if_context, self, [dim]) + utils._add_output_to_block(if_context.block, squeeze_) + identity_ = else_context.op("Identity", self) + utils._add_output_to_block(else_context.block, identity_) + return if_op + + # For static input shape + dim = adjusted_dim + if dim_size > 1: + warnings.warn( + "This model contains a squeeze operation on dimension " + + str(dim) + + ". The size of " + + "this dimension in the given input is " + + str(dim_size) + + ". The model will " + + "be exported without the squeeze node. If the model is intended to be used with dynamic " + + "input shapes, please export with dynamic_axes argument." + ) + return self + return symbolic_helper._squeeze_helper(g, self, [dim]) + + +@_onnx_symbolic("aten::unsqueeze") +def unsqueeze(g: jit_utils.GraphContext, self, dim): + if symbolic_helper._is_constant(dim): + dim = symbolic_helper._get_const(dim, "i", "dim") + + return symbolic_helper._unsqueeze_helper(g, self, [dim]) + + +@_onnx_symbolic("aten::mm") +def mm(g: jit_utils.GraphContext, self, other): + return g.op("Gemm", self, other, beta_f=0.0, alpha_f=1.0) + + +@_onnx_symbolic("aten::index") +def index(g: jit_utils.GraphContext, self, index): + if symbolic_helper._is_packed_list(index): + indices = symbolic_helper._unpack_list(index) + else: + indices = [index] + + # Handle single mask index. + if len(indices) == 1: + index = indices[0] + if not symbolic_helper._is_none(index) and ( + symbolic_helper._is_bool(index) + or _type_utils.JitScalarType.from_value(index) + == _type_utils.JitScalarType.UINT8 + ): + index = opset9.nonzero(g, index) + return g.op("GatherND", self, index) + return opset9.index(g, self, index) + + +@_onnx_symbolic("aten::index_fill") +def index_fill(g: jit_utils.GraphContext, self, dim, index, value): + expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper( + g, self, dim, index + ) + value = symbolic_helper._maybe_get_scalar(value) + value = symbolic_helper._if_scalar_type_as(value, self) + expanded_value = opset9.expand(g, value, expanded_index_shape, None) + return scatter(g, self, dim, expanded_index, expanded_value) + + +@_onnx_symbolic("aten::index_copy") +def index_copy(g: jit_utils.GraphContext, self, dim, index, source): + _expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper( + g, self, dim, index + ) + return scatter(g, self, dim, expanded_index, source) + + +@_onnx_symbolic("aten::bitwise_right_shift") +@_onnx_symbolic("aten::__rshift_") +def __rshift_(g: jit_utils.GraphContext, self, other): + # make sure to cast other to self's type + # (when self is long, make sure that other is not float) + if _type_utils.JitScalarType.from_value( + other, _type_utils.JitScalarType.UNDEFINED + ) != _type_utils.JitScalarType.from_value(self): + other = g.op( + "Cast", + other, + to_i=_type_utils.JitScalarType.from_value(self).onnx_type(), + ) + + if ( + _type_utils.JitScalarType.from_value(self, _type_utils.JitScalarType.UNDEFINED) + == _type_utils.JitScalarType.UINT8 + ): + return g.op("BitShift", self, other, direction_s="RIGHT") + + two = g.op("Constant", value_t=torch.tensor(2, dtype=torch.float32)) + # exponent (same type as self) has to be float or double in onnx::Pow + if not symbolic_helper._is_fp(self): + other = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.FLOAT) + two_pow = g.op("Pow", two, other) + two_pow = g.op( + "Cast", + 
two_pow, + to_i=_type_utils.JitScalarType.from_value(self).onnx_type(), + ) + rshift = g.op("Div", self, two_pow) + return rshift + + +@_onnx_symbolic("aten::bitwise_left_shift") +@_onnx_symbolic("aten::__lshift_") +def __lshift_(g: jit_utils.GraphContext, self, other): + # make sure to cast other to self's type + # (when self is long, make sure that other is not float) + if _type_utils.JitScalarType.from_value( + other, _type_utils.JitScalarType.UNDEFINED + ) != _type_utils.JitScalarType.from_value(self): + other = g.op( + "Cast", + other, + to_i=_type_utils.JitScalarType.from_value(self).onnx_type(), + ) + + if ( + _type_utils.JitScalarType.from_value(self, _type_utils.JitScalarType.UNDEFINED) + == _type_utils.JitScalarType.UINT8 + ): + return g.op("BitShift", self, other, direction_s="LEFT") + + two = g.op("Constant", value_t=torch.tensor(2, dtype=torch.float32)) + # exponent (same type as self) has to be float or double in onnx::Pow + if not symbolic_helper._is_fp(self): + other = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.FLOAT) + two_pow = g.op("Pow", two, other) + two_pow = g.op( + "Cast", + two_pow, + to_i=_type_utils.JitScalarType.from_value(self).onnx_type(), + ) + lshift = g.op("Mul", self, two_pow) + return lshift + + +def _get_im2col_indices_along_dim( + g: jit_utils.GraphContext, input_d, kernel_size_d, dilation_d, padding_d, stride_d +): + # Input is always 4-D (N, C, H, W) + # Calculate indices of sliding blocks along spatial dimension + # Slide kernel over input each dim d: + # each dimension d ranges from 0 to input[d]+2xpadding[d]-dilation[d]x(kernel_size[d]-1) + # with steps = stride + + blocks_d = g.op( + "Add", input_d, g.op("Constant", value_t=torch.tensor(padding_d * 2)) + ) + blocks_d = g.op( + "Sub", + blocks_d, + g.op("Constant", value_t=torch.tensor(dilation_d * (kernel_size_d - 1))), + ) + + # Stride kernel over input and find starting indices along dim d + blocks_d_indices = g.op( + "Range", + g.op("Constant", value_t=torch.tensor(0)), + blocks_d, + g.op("Constant", value_t=torch.tensor(stride_d)), + ) + + # Apply dilation on kernel and find its indices along dim d + kernel_grid = torch.arange(0, kernel_size_d * dilation_d, dilation_d) + kernel_grid = g.op("Constant", value_t=kernel_grid.unsqueeze(0)) + + # Broadcast and add kernel staring positions (indices) with + # kernel_grid along dim d, to get block indices along dim d + blocks_d_indices = symbolic_helper._unsqueeze_helper( + g, blocks_d_indices, [0] + ) # Reshape to [1, -1] + kernel_mask = symbolic_helper._reshape_helper( + g, kernel_grid, g.op("Constant", value_t=torch.tensor([-1, 1])) + ) + block_mask = g.op("Add", blocks_d_indices, kernel_mask) + + return block_mask + + +def _get_im2col_padded_input(g: jit_utils.GraphContext, input, padding_h, padding_w): + # Input is always 4-D tensor (N, C, H, W) + # Padding tensor has the following format: (padding_h, padding_w) + # Reshape the padding to follow ONNX format: (dim1_begin, dim2_begin,...,dim1_end, dim2_end,...) 
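+ # e.g. padding_h=1, padding_w=2 on an NCHW input gives
+ # pads = [0, 0, 1, 2, 0, 0, 1, 2]: no padding on N and C, symmetric padding on H and W.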
+ pad = g.op("Constant", value_t=torch.LongTensor([0, 0, padding_h, padding_w] * 2)) + return g.op("Pad", input, pad) + + +def _get_im2col_output_shape(g: jit_utils.GraphContext, input, kernel_h, kernel_w): + batch_dim = size(g, input, g.op("Constant", value_t=torch.tensor(0))) + channel_dim = size(g, input, g.op("Constant", value_t=torch.tensor(1))) + channel_unfolded = g.op( + "Mul", channel_dim, g.op("Constant", value_t=torch.tensor(kernel_h * kernel_w)) + ) + + return g.op( + "Concat", + symbolic_helper._unsqueeze_helper(g, batch_dim, [0]), + symbolic_helper._unsqueeze_helper(g, channel_unfolded, [0]), + g.op("Constant", value_t=torch.tensor([-1])), + axis_i=0, + ) + + +@_onnx_symbolic("aten::im2col") +@symbolic_helper.parse_args("v", "is", "is", "is", "is") +def im2col(g: jit_utils.GraphContext, input, kernel_size, dilation, padding, stride): + # Input is always 4-D tensor (N, C, H, W) + # All other args are int[2] + + input_h = size(g, input, g.op("Constant", value_t=torch.tensor(2))) + input_w = size(g, input, g.op("Constant", value_t=torch.tensor(3))) + + stride_h, stride_w = stride[0], stride[1] + padding_h, padding_w = padding[0], padding[1] + dilation_h, dilation_w = dilation[0], dilation[1] + kernel_h, kernel_w = kernel_size[0], kernel_size[1] + + blocks_row_indices = _get_im2col_indices_along_dim( + g, input_h, kernel_h, dilation_h, padding_h, stride_h + ) + blocks_col_indices = _get_im2col_indices_along_dim( + g, input_w, kernel_w, dilation_w, padding_w, stride_w + ) + + output_shape = _get_im2col_output_shape(g, input, kernel_h, kernel_w) + padded_input = _get_im2col_padded_input(g, input, padding_h, padding_w) + + # For a 4D matrix of size (1, 1, 3, 3) as below with kernel_size=2, stride=1, and dilation=1 + # [[[[1., 2., 3.,], + # [4., 5., 6.,], + # [7., 8., 9.,]]]] + # First gather indices along rows (dim=2) with blocks_row_indices = [[0,1], [1,2]] to get: + # [[[[[1., 2., 3.], + # [4., 5., 6.]], + # [[4., 5., 6.], + # [7., 8., 9.]]]]] + # And then gather along cols (dim=4) with blocks_row_indices = [[0,1], [1,2]] to get: + # [[[[[[1., 2.], + # [4., 5.]], + # [[2., 3.], + # [5., 6]]], + # [[[4., 5.], + # [7., 8.]], + # [[5., 6.], + # [8., 9.]]]]]] + # Transpose dims 3 (depth) and 4 (rows), and then reshape to output shape (1, 1, 4, 4) to get: + # [[[1., 2., 4., 5.], + # [2., 3., 5., 6.], + # [4., 5., 7., 8.], + # [5., 6., 8., 9.]]] + output = g.op("Gather", padded_input, blocks_row_indices, axis_i=2) + output = g.op("Gather", output, blocks_col_indices, axis_i=4) + output = g.op("Transpose", output, perm_i=[0, 1, 2, 4, 3, 5]) + return symbolic_helper._reshape_helper(g, output, output_shape) + + +@_onnx_symbolic("aten::narrow") +def narrow(g: jit_utils.GraphContext, input, dim, start, length): + end = g.op("Add", start, length) + return symbolic_helper._slice_helper(g, input, axes=dim, starts=start, ends=end) + + +@_onnx_symbolic("aten::flatten") +@symbolic_helper.quantized_args(True, False, False) +@symbolic_helper.parse_args("v", "i", "i") +def flatten(g: jit_utils.GraphContext, input, start_dim, end_dim): + dim = symbolic_helper._get_tensor_rank(input) + if dim == 1: + return input + # use ONNX's Flatten operator for cases where the output shape is 2D + if start_dim == 1: + if end_dim == -1 or (dim is not None and end_dim == dim - 1): + return g.op("Flatten", input, axis_i=start_dim) + elif start_dim == 0: + if end_dim == -2 or (dim is not None and end_dim == dim - 2): + return g.op("Flatten", input, axis_i=end_dim + 1) + if dim is None: + return 
symbolic_helper._unimplemented( + "dim", + "ONNX and PyTorch use different strategies to split the input. " + "Input rank must be known at export time.", + ) + # if end_dim is negative add dim + if end_dim < 0: + end_dim = dim + end_dim + + return symbolic_helper._flatten_helper(g, input, start_dim, end_dim, dim) + + +@_onnx_symbolic("aten::linalg_vector_norm") +@symbolic_helper.parse_args("v", "f", "is", "b", "v") +def linalg_vector_norm( + g: jit_utils.GraphContext, + self, + ord, + dim: Sequence[int] | None, + keepdim: bool, + dtype, +): + return symbolic_helper._linalg_vector_norm_helper(g, self, ord, dim, keepdim, dtype) + + +@_onnx_symbolic("aten::embedding_bag") +@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i") +def embedding_bag( + g: jit_utils.GraphContext, + embedding_matrix, + indices, + offsets, + scale_grad_by_freq, + mode, + sparse, + per_sample_weights, + include_last_offset, + padding_idx, +): + return symbolic_helper._embedding_bag_helper( + g, + embedding_matrix, + indices, + offsets, + scale_grad_by_freq, + mode, + sparse, + per_sample_weights, + include_last_offset, + padding_idx, + ) + + +@_onnx_symbolic("aten::embedding_renorm") +@symbolic_helper.parse_args("v", "v", "f", "f") +def embedding_renorm(g: jit_utils.GraphContext, weight, indices, max_norm, norm_type): + unique_indices = g.op("Unique", indices) + partial_weight = g.op("Gather", weight, unique_indices) + norm_i = int(norm_type) + if norm_i == 1: + norm_type = "ReduceL1" + elif norm_i == 2: + norm_type = "ReduceL2" + else: + raise errors.SymbolicValueError( + f"Unsupported: ONNX export of embedding_renorm with norm: {norm_i}. " + "Only 1. and 2. are supported.", + weight, + ) + partial_weight_norm = g.op(norm_type, partial_weight, axes_i=[1], keepdims_i=1) + # https://github.com/pytorch/pytorch/blob/0a07488ed2c47765e337e290bd138c0e6e459cbd/aten/src/ATen/native/Embedding.cpp#L177 + # Add 1e-7 to prevent division by zero. + partial_weight_norm_ = g.op( + "Add", partial_weight_norm, g.op("Constant", value_t=torch.tensor(1e-7)) + ) + max_norm = torch.tensor(max_norm) + scales = g.op("Div", max_norm, partial_weight_norm_) + partial_weight_renorm = g.op("Mul", partial_weight, scales) + partial_weight_renorm = g.op( + "Where", + g.op("Greater", partial_weight_norm, max_norm), + partial_weight_renorm, + partial_weight, + ) + return g.op( + "ScatterND", + weight, + symbolic_helper._unsqueeze_helper(g, unique_indices, [1]), + partial_weight_renorm, + ) + + +@_onnx_symbolic("aten::chunk") +def chunk(g: jit_utils.GraphContext, self, chunks, dim): + # Calculate chunk size for dynamic chunk + dim_size = g.op("Gather", g.op("Shape", self), dim, axis_i=0) + chunk_size_s = g.op( + "Sub", chunks, g.op("Constant", value_t=torch.tensor([1], dtype=torch.long)) + ) + chunk_size = g.op("Div", g.op("Add", dim_size, chunk_size_s), chunks) + # Create splits vector + chunk_vec = [ + opset9.expand(g, chunk_size, chunk_size_s, None), + g.op("Sub", dim_size, g.op("Mul", chunk_size, chunk_size_s)), + ] + chunk_vec = g.op("Concat", *chunk_vec, axis_i=0) + return split(g, self, chunk_vec, dim) + + +@_onnx_symbolic("aten::normal") +def normal( + g: jit_utils.GraphContext, + mean, + std, + sizes=None, + generator=None, + dtype=None, + layout=None, + device=None, + pin_memory=None, +): + # If you can sample from a given distribution with mean 0 and variance 1, then you can easily sample from a + # scale-location transformation of that distribution, which has mean mu and variance sigma's square. 
If x is a sample + # from a mean 0 and variance 1 distribution then + # sigma x+mu + # is a sample with mean mu and variance sigma's square. + if sizes is not None and not symbolic_helper._is_none(sizes): + mean = opset9.expand(g, mean, sizes, None) + result = opset9.mul(g, std, g.op("RandomNormalLike", mean)) + return add(g, result, mean) + + +@_onnx_symbolic("aten::atleast_1d") +def atleast_1d(g: jit_utils.GraphContext, self: torch._C.Value): + # NOTE: If it's 0D, reshape to 1D + + # NOTE: self could be a packed list or a tensor + if symbolic_helper._is_value(self) and symbolic_helper._is_packed_list(self): + tensor_list = symbolic_helper._unpack_list(self) + new_tensor_list = [] + for tensor in tensor_list: + new_tensor = tensor + tensor_rank = symbolic_helper._get_tensor_rank(tensor) + if tensor_rank == 0: + new_tensor = symbolic_helper._reshape_helper( + g, new_tensor, g.op("Constant", value_t=torch.tensor([1])) + ) + new_tensor_list.append(new_tensor) + return g.op("SequenceConstruct", *new_tensor_list) + + tensor_rank = symbolic_helper._get_tensor_rank(self) + if tensor_rank == 0: + self = symbolic_helper._reshape_helper( + g, self, g.op("Constant", value_t=torch.tensor([1])) + ) + return self + + +@_onnx_symbolic("aten::atleast_2d") +def atleast_2d(g: jit_utils.GraphContext, self: torch._C.Value): + # NOTE: If it's 0D, reshape to 2D + # If it's 1D, unsqueeze to 2D + + # NOTE: self could be a packed list or a tensor + if symbolic_helper._is_value(self) and symbolic_helper._is_packed_list(self): + tensor_list = symbolic_helper._unpack_list(self) + new_tensor_list = [] + for tensor in tensor_list: + new_tensor = tensor + tensor_rank = symbolic_helper._get_tensor_rank(tensor) + if tensor_rank == 0: + new_tensor = symbolic_helper._reshape_helper( + g, new_tensor, g.op("Constant", value_t=torch.tensor([1, 1])) + ) + elif tensor_rank == 1: + new_tensor = symbolic_helper._unsqueeze_helper( + g, new_tensor, axes_i=[0] + ) + new_tensor_list.append(new_tensor) + return g.op("SequenceConstruct", *new_tensor_list) + + tensor_rank = symbolic_helper._get_tensor_rank(self) + if tensor_rank == 0: + self = symbolic_helper._reshape_helper( + g, self, g.op("Constant", value_t=torch.tensor([1, 1])) + ) + elif tensor_rank == 1: + self = symbolic_helper._unsqueeze_helper(g, self, axes_i=[0]) + return self + + +@_onnx_symbolic("aten::atleast_3d") +def atleast_3d(g: jit_utils.GraphContext, self: torch._C.Value): + # NOTE: If it's 0D, reshape to 3D + # If it's 1D, unsqueeze to 3D + # If it's 2D, unsqueeze to 3D + + # NOTE: self could be a packed list or a tensor + if symbolic_helper._is_value(self) and symbolic_helper._is_packed_list(self): + tensor_list = symbolic_helper._unpack_list(self) + new_tensor_list = [] + for tensor in tensor_list: + new_tensor = tensor + tensor_rank = symbolic_helper._get_tensor_rank(tensor) + if tensor_rank == 0: + new_tensor = symbolic_helper._reshape_helper( + g, new_tensor, g.op("Constant", value_t=torch.tensor([1, 1, 1])) + ) + elif tensor_rank == 1: + new_tensor = symbolic_helper._unsqueeze_helper( + g, new_tensor, axes_i=[0] + ) + new_tensor = symbolic_helper._unsqueeze_helper( + g, new_tensor, axes_i=[-1] + ) + elif tensor_rank == 2: + new_tensor = symbolic_helper._unsqueeze_helper( + g, new_tensor, axes_i=[-1] + ) + new_tensor_list.append(new_tensor) + return g.op("SequenceConstruct", *new_tensor_list) + + tensor_rank = symbolic_helper._get_tensor_rank(self) + if tensor_rank == 0: + self = symbolic_helper._reshape_helper( + g, self, g.op("Constant", 
value_t=torch.tensor([1, 1, 1])) + ) + elif tensor_rank == 1: + self = symbolic_helper._unsqueeze_helper(g, self, axes_i=[0]) + self = symbolic_helper._unsqueeze_helper(g, self, axes_i=[-1]) + elif tensor_rank == 2: + self = symbolic_helper._unsqueeze_helper(g, self, axes_i=[-1]) + return self + + +@_onnx_symbolic("prim::ConstantChunk") +def prim_constant_chunk(g: jit_utils.GraphContext, self, chunks, dim): + input_shape = g.op("Shape", self) + axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long)) + input_shape_dim = g.op("Gather", input_shape, axis, axis_i=0) + start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long)) + chunk_size = g.op("Constant", value_t=torch.tensor([chunks], dtype=torch.long)) + chunk_size_minus_1 = g.op( + "Constant", value_t=torch.tensor([chunks - 1], dtype=torch.long) + ) + input_shape_dim_shift = g.op("Add", input_shape_dim, chunk_size_minus_1) + chunk_dim = g.op("Div", input_shape_dim_shift, chunk_size) + res = [] + for i in range(chunks): + index = g.op("Constant", value_t=torch.tensor([i + 1], dtype=torch.long)) + end = g.op("Mul", chunk_dim, index) + res.append(g.op("Slice", self, start, end, axis)) + start = end + return res + + +@_onnx_symbolic("aten::hstack") +def hstack(g: jit_utils.GraphContext, tensor_list: _C.Value): + tensor_list = atleast_1d(g, tensor_list) + first_tensor = g.op( + "SequenceAt", + tensor_list, + g.op("Constant", value_t=torch.tensor(0, dtype=torch.long)), + ) + first_tensor_shape = g.op("Shape", first_tensor) + first_tensor_dim = g.op("Size", first_tensor_shape) + + const_one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.long)) + equal_to_one = g.op("Equal", first_tensor_dim, const_one) + + ( + if_op_greater, + (if_context_equal, else_context_equal), + _, + ) = jit_utils.add_op_with_blocks(g, "If", equal_to_one, n_blocks=2, outputs=1) + result_if = if_context_equal.op( + "ConcatFromSequence", tensor_list, axis_i=0, new_axis_i=0 + ) + utils._add_output_to_block(if_context_equal.block, result_if) + result_else = else_context_equal.op( + "ConcatFromSequence", tensor_list, axis_i=1, new_axis_i=0 + ) + utils._add_output_to_block(else_context_equal.block, result_else) + result = if_op_greater.node().output() + + return result + + +@_onnx_symbolic("aten::vstack") +def vstack(g: jit_utils.GraphContext, tensor_list: _C.Value): + tensor_list = atleast_2d(g, tensor_list) + return g.op("ConcatFromSequence", tensor_list, axis_i=0, new_axis_i=0) diff --git a/torch/onnx/_internal/torchscript_exporter/symbolic_opset12.py b/torch/onnx/_internal/torchscript_exporter/symbolic_opset12.py new file mode 100644 index 000000000000..431660409717 --- /dev/null +++ b/torch/onnx/_internal/torchscript_exporter/symbolic_opset12.py @@ -0,0 +1,465 @@ +# mypy: allow-untyped-defs +# mypy: disable-error-code=arg-type +from __future__ import annotations + +import functools +import sys + +import torch +from torch._C import _onnx as _C_onnx +from torch.onnx import errors +from torch.onnx._internal.torchscript_exporter import ( + _type_utils, + jit_utils, + registration, + symbolic_helper, + symbolic_opset9 as opset9, + utils, +) + + +# EDITING THIS FILE? READ THIS FIRST! 
+# see Note [Edit Symbolic Files] in README.md + +# This file exports ONNX ops for opset 12 + +__all__ = [ + "argmax", + "argmin", + "binary_cross_entropy_with_logits", + "celu", + "cross_entropy_loss", + "dropout", + "einsum", + "ge", + "le", + "native_dropout", + "nll_loss", + "nll_loss2d", + "nll_loss_nd", + "outer", + "pow", + "tensordot", + "unfold", +] + +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=12) + + +def _einsum_helper(g: jit_utils.GraphContext, equation, tensors): + if not tensors: + raise RuntimeError("Einsum inputs are empty.") + # ONNX does not support bool for Einsum inputs. + if symbolic_helper._is_bool(tensors[0]): + tensors = [ + g.op("Cast", tensor, to_i=_C_onnx.TensorProtoDataType.INT64) + for tensor in tensors + ] + return g.op( + "Cast", + g.op("Einsum", *tensors, equation_s=equation), + to_i=_C_onnx.TensorProtoDataType.BOOL, + ) + else: + return g.op("Einsum", *tensors, equation_s=equation) + + +@_onnx_symbolic("aten::einsum") +@symbolic_helper.parse_args("s", "v", "is") +def einsum(g: jit_utils.GraphContext, equation, tensor_list, path=None): + tensors = symbolic_helper._unpack_list(tensor_list) + return _einsum_helper(g, equation, tensors) + + +@_onnx_symbolic("aten::outer") +@symbolic_helper.parse_args("v", "v") +def outer(g: jit_utils.GraphContext, input, other): + # make sure to cast other to self's type + if _type_utils.JitScalarType.from_value( + other, _type_utils.JitScalarType.UNDEFINED + ) != _type_utils.JitScalarType.from_value(input): + other = g.op( + "Cast", + other, + to_i=_type_utils.JitScalarType.from_value(input).onnx_type(), + ) + return _einsum_helper(g, "i,j->ij", [input, other]) + + +def _dropout_returns_masked_input_and_mask( + g: jit_utils.GraphContext, input: torch._C.Value, p: float, train: bool +) -> tuple[torch._C.Value, torch._C.Value | None]: + symbolic_helper.check_training_mode(train, "dropout") + # In eval mode, dropout is non-op. That is, if the node's + # train param is set to False, dropout just returns its inputs. + if not train: + return input, None + p = g.op("Constant", value_t=torch.tensor(p)) + t = g.op("Constant", value_t=torch.tensor(train, dtype=torch.bool)) + r, mask = g.op("Dropout", input, p, t, outputs=2) + return r, mask + + +@_onnx_symbolic("aten::dropout") +@symbolic_helper.parse_args("v", "f", "b") +def dropout(g: jit_utils.GraphContext, input, p, train): + masked, _ = _dropout_returns_masked_input_and_mask(g, input, p, train) + return masked + + +@_onnx_symbolic("aten::native_dropout") +@symbolic_helper.parse_args("v", "f", "b") +def native_dropout(g: jit_utils.GraphContext, input, p, train): + return _dropout_returns_masked_input_and_mask(g, input, p, train) + + +@_onnx_symbolic("aten::nll_loss") +def nll_loss(g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index): + # none reduction : onnx::Constant[value={0}] + # mean reduction : onnx::Constant[value={1}] + # sum reduction : onnx::Constant[value={2}] + reduction = symbolic_helper._maybe_get_const(reduction, "i") + reduction_vals = ["none", "mean", "sum"] + reduction = reduction_vals[reduction] + + # in onnx NegativeLogLikelihoodLoss specification, ignore_index is optional without default value. + # therefore we need to set ignore_index attribute even if it is not specified (e.g. ignore_index=-100). 
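+ # NegativeLogLikelihoodLoss is only available from ONNX opset 12, which is why
+ # this symbolic is defined in this file rather than in an earlier opset.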
+ ignore_index = symbolic_helper._maybe_get_const(ignore_index, "i") + if weight.node().mustBeNone(): + nllloss = g.op( + "NegativeLogLikelihoodLoss", + self, + target, + reduction_s=reduction, + ignore_index_i=ignore_index, + ) + else: + nllloss = g.op( + "NegativeLogLikelihoodLoss", + self, + target, + weight, + reduction_s=reduction, + ignore_index_i=ignore_index, + ) + + return nllloss + + +@_onnx_symbolic("aten::nll_loss2d") +def nll_loss2d( + g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index +): + return nll_loss(g, self, target, weight, reduction, ignore_index) + + +@_onnx_symbolic("aten::nll_loss_nd") +def nll_loss_nd( + g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index +): + return nll_loss(g, self, target, weight, reduction, ignore_index) + + +@_onnx_symbolic("aten::cross_entropy_loss") +def cross_entropy_loss( + g: jit_utils.GraphContext, + self, + target, + weight, + reduction, + ignore_index, + label_smoothing, +): + # none reduction : onnx::Constant[value={0}] + # mean reduction : onnx::Constant[value={1}] + # sum reduction : onnx::Constant[value={2}] + reduction = symbolic_helper._maybe_get_const(reduction, "i") + reduction_vals = ["none", "mean", "sum"] + reduction = reduction_vals[reduction] + + label_smoothing = symbolic_helper._maybe_get_const(label_smoothing, "f") + if label_smoothing is not None and label_smoothing > 0.0: + raise errors.SymbolicValueError( + "Unsupported: ONNX does not support label_smoothing", self + ) + + # in onnx SoftmaxCrossEntropyLoss specification, ignore_index is optional without default value. + # therefore we need to set ignore_index attribute even if it is not specified (e.g. ignore_index=-100). + ignore_index = symbolic_helper._maybe_get_const(ignore_index, "i") + if weight.node().mustBeNone(): + celoss = g.op( + "SoftmaxCrossEntropyLoss", + self, + target, + reduction_s=reduction, + ignore_index_i=ignore_index, + ) + else: + celoss = g.op( + "SoftmaxCrossEntropyLoss", + self, + target, + weight, + reduction_s=reduction, + ignore_index_i=ignore_index, + ) + + return celoss + + +@_onnx_symbolic("aten::binary_cross_entropy_with_logits") +@symbolic_helper.parse_args("v", "v", "v", "v", "i") +def binary_cross_entropy_with_logits( + g: jit_utils.GraphContext, input, target, weight, pos_weight, reduction +): + p = g.op("Constant", value_t=torch.tensor([1])) + sig_x = opset9.sigmoid(g, input) + log_sig_x = opset9.log(g, sig_x) + sub_1_x = opset9.sub(g, p, sig_x) + sub_1_y = opset9.sub(g, p, target) + log_1_x = opset9.log(g, sub_1_x) + if pos_weight is None or symbolic_helper._is_none(pos_weight): + output = opset9.neg( + g, + opset9.add( + g, opset9.mul(g, target, log_sig_x), opset9.mul(g, sub_1_y, log_1_x) + ), + ) + else: + output = opset9.neg( + g, + opset9.add( + g, + opset9.mul(g, opset9.mul(g, target, log_sig_x), pos_weight), + opset9.mul(g, sub_1_y, log_1_x), + ), + ) + + if weight is not None and not symbolic_helper._is_none(weight): + output = opset9.mul(g, weight, output) + + reduction = symbolic_helper._maybe_get_const(reduction, "i") + if reduction == 0: + return output + elif reduction == 1: + return g.op("ReduceMean", output, keepdims_i=0) + elif reduction == 2: + return g.op("ReduceSum", output, keepdims_i=0) + else: + return symbolic_helper._onnx_unsupported( + "binary_cross_entropy_with_logits with reduction other than none, mean, or sum", + input, + ) + + +@_onnx_symbolic("aten::celu") +def celu(g: jit_utils.GraphContext, self, alpha): + alpha = symbolic_helper._maybe_get_const(alpha, 
"f") + # if the input is of type double cast it to float + if ( + _type_utils.JitScalarType.from_value(self, _type_utils.JitScalarType.UNDEFINED) + == _type_utils.JitScalarType.DOUBLE + ): + self = g.op("Cast", self, to_i=_C_onnx.TensorProtoDataType.FLOAT) + out = g.op("Celu", self, alpha_f=alpha) + return g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.DOUBLE) + + return g.op("Celu", self, alpha_f=alpha) + + +@_onnx_symbolic("aten::argmax") +@symbolic_helper.parse_args("v", "v", "b") +def argmax( + g: jit_utils.GraphContext, + input: torch._C.Value, + dim: torch._C.Value, + keepdim: bool, +): + return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMax") + + +@_onnx_symbolic("aten::argmin") +@symbolic_helper.parse_args("v", "v", "b") +def argmin( + g: jit_utils.GraphContext, + input: torch._C.Value, + dim: torch._C.Value, + keepdim: bool, +): + return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMin") + + +@_onnx_symbolic("aten::pow") +def pow(g: jit_utils.GraphContext, self, exponent): + return g.op("Pow", self, exponent) + + +@_onnx_symbolic("aten::ge") +def ge(g: jit_utils.GraphContext, input, other): + return g.op("GreaterOrEqual", input, other) + + +@_onnx_symbolic("aten::le") +def le(g: jit_utils.GraphContext, input, other): + return g.op("LessOrEqual", input, other) + + +@_onnx_symbolic("aten::unfold") +@symbolic_helper.parse_args("v", "i", "v", "v") +def unfold(g: jit_utils.GraphContext, input, dimension, size, step): + const_size = symbolic_helper._maybe_get_const(size, "i") + const_step = symbolic_helper._maybe_get_const(step, "i") + if not symbolic_helper._is_value(const_size) and not symbolic_helper._is_value( + const_step + ): + return opset9.unfold(g, input, dimension, const_size, const_step) + + sizedim = symbolic_helper._get_tensor_dim_size(input, dimension) + if sizedim is not None: + low_start = g.op("Constant", value_t=torch.tensor(0)) + low_end = g.op("Constant", value_t=torch.tensor(sizedim)) + hi_end = g.op("Constant", value_t=torch.tensor(sizedim + 1)) + low_indices = g.op("Range", low_start, low_end, step) + hi_indices = g.op("Range", size, hi_end, step) + + low_size = symbolic_helper._size_helper( + g, low_indices, g.op("Constant", value_t=torch.tensor(0)) + ) + hi_size = symbolic_helper._size_helper( + g, hi_indices, g.op("Constant", value_t=torch.tensor(0)) + ) + + ndim = symbolic_helper._get_tensor_rank(input) + assert ndim is not None + perm = list(range(0, ndim)) + perm.append(perm.pop(dimension)) + + unsqueeze_list = [] + loop_condition = g.op("Constant", value_t=torch.tensor(1)) + loop_condition = g.op( + "Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL + ) + loop_len = g.op("Min", low_size, hi_size) + + loop, (loop_context,), _ = jit_utils.add_op_with_blocks( + g, "Loop", loop_len, loop_condition, n_blocks=1 + ) + + loop_block = loop_context.block + block_input_iter = utils._add_input_to_block(loop_block) + cond = utils._add_input_to_block(loop_block) # noqa: F841 + + starts = loop_context.op("Gather", low_indices, block_input_iter) + ends = loop_context.op("Gather", hi_indices, block_input_iter) + axes = loop_context.op("Constant", value_t=torch.tensor([2])) + starts = symbolic_helper._unsqueeze_helper(loop_context, starts, [0]) + ends = symbolic_helper._unsqueeze_helper(loop_context, ends, [0]) + stack = loop_context.op("Slice", input, starts, ends, axes) + + unsqueeze = symbolic_helper._unsqueeze_helper( + loop_context, loop_context.op("Transpose", stack, perm_i=perm), [dimension] + ) + 
unsqueeze_list.append(unsqueeze) + concat = loop_context.op("Concat", *unsqueeze_list, axis_i=0) + + cond_out = loop_context.op( + "Cast", loop_condition, _C_onnx.TensorProtoDataType.BOOL + ) + utils._add_output_to_block(loop_block, cond_out) + utils._add_output_to_block(loop_block, concat) + + loop_output = loop.node().output() + perm = [0, 1, 2, 3, 4] + perm[0], perm[dimension + 1] = perm[dimension + 1], perm[0] + transpose = g.op("Transpose", loop_output, perm_i=perm) + squeeze = symbolic_helper._squeeze_helper(g, transpose, [0]) + + return squeeze + + return symbolic_helper._unimplemented("Unfold", "input size not accessible") + + +@_onnx_symbolic("aten::tensordot") +@symbolic_helper.parse_args("v", "v", "is", "is", "v") +def tensordot(g: jit_utils.GraphContext, input_a, input_b, dims_a, dims_b, out=None): + if out is not None: + symbolic_helper._unimplemented( + "Tensordot", "Out parameter is not supported for tensordot." + ) + + dim_count_a = symbolic_helper._get_tensor_rank(input_a) + if dim_count_a is None: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of tensordot for tensor(input_a) of unknown rank.", + input_a, + ) + + dim_count_b = symbolic_helper._get_tensor_rank(input_b) + if dim_count_b is None: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of tensordot for tensor(input_b) of unknown rank.", + input_b, + ) + + dims_a = [ + (dims_a[i] + dim_count_a) if (dims_a[i] < 0) else dims_a[i] + for i in range(len(dims_a)) + ] + dims_b = [ + (dims_b[i] + dim_count_b) if (dims_b[i] < 0) else dims_b[i] + for i in range(len(dims_b)) + ] + + left_dims_a = [i for i in range(dim_count_a) if (i not in dims_a)] + left_dims_b = [i for i in range(dim_count_b) if (i not in dims_b)] + + new_input_a = opset9.permute(g, input_a, left_dims_a + dims_a) + new_input_b = opset9.permute(g, input_b, dims_b + left_dims_b) + + input_shape = g.op("Shape", new_input_a) + left_sizes_a = symbolic_helper._slice_helper( + g, input_shape, axes=[0], starts=[0], ends=[len(left_dims_a)] + ) + shape_sizes = [ + left_sizes_a, + g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)), + ] + output_a = opset9._reshape_from_tensor(g, new_input_a, shape_sizes) + + input_shape = g.op("Shape", output_a) + slices = symbolic_helper._slice_helper( + g, input_shape, axes=[0], starts=[-1], ends=[sys.maxsize] + ) + shape_sizes = [ + g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)), + slices, + ] + output_a = opset9._reshape_from_tensor(g, new_input_a, shape_sizes) + + input_shape = g.op("Shape", new_input_b) + left_sizes_b = symbolic_helper._slice_helper( + g, input_shape, axes=[0], starts=[len(dims_b)], ends=[sys.maxsize] + ) + slices = symbolic_helper._slice_helper( + g, input_shape, axes=[0], starts=[0], ends=[len(dims_b)] + ) + shape_sizes = [ + slices, + g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)), + ] + output_b = opset9._reshape_from_tensor(g, new_input_b, shape_sizes) + + input_shape = g.op("Shape", output_b) + slices = symbolic_helper._slice_helper( + g, input_shape, axes=[0], starts=[-1], ends=[sys.maxsize] + ) + shape_sizes = [ + g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)), + slices, + ] + output_b = opset9._reshape_from_tensor(g, new_input_b, shape_sizes) + + output = einsum(g, "ij,jk->ik", g.op("prim::ListConstruct", *[output_a, output_b])) + + shape_sizes = [left_sizes_a, left_sizes_b] + return opset9._reshape_from_tensor(g, output, shape_sizes) diff --git a/torch/onnx/_internal/torchscript_exporter/symbolic_opset13.py 
b/torch/onnx/_internal/torchscript_exporter/symbolic_opset13.py new file mode 100644 index 000000000000..e9da6a426f7f --- /dev/null +++ b/torch/onnx/_internal/torchscript_exporter/symbolic_opset13.py @@ -0,0 +1,1113 @@ +# mypy: allow-untyped-defs +# EDITING THIS FILE? READ THIS FIRST! +# see Note [Edit Symbolic Files] in README.md + +# This file exports ONNX ops for opset 13 +import functools + +import torch +import torch._C._onnx as _C_onnx +from torch.onnx import _constants, errors +from torch.onnx._internal.torchscript_exporter import ( + _type_utils, + jit_utils, + registration, + symbolic_helper, + symbolic_opset11 as opset11, + symbolic_opset9 as opset9, + utils, +) + + +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=13) + + +@_onnx_symbolic("aten::softmax") +@symbolic_helper.parse_args("v", "i", "none") +def softmax(g: jit_utils.GraphContext, input, dim, dtype=None): + softmax = g.op("Softmax", input, axis_i=dim) + if dtype and dtype.node().kind() != "prim::Constant": + parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype") + softmax = g.op( + "Cast", softmax, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type() + ) + + return softmax + + +@_onnx_symbolic("aten::log_softmax") +@symbolic_helper.parse_args("v", "i", "none") +def log_softmax(g: jit_utils.GraphContext, input, dim, dtype=None): + return_op = g.op("LogSoftmax", input, axis_i=dim) + if dtype and dtype.node().kind() != "prim::Constant": + parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype") + return_op = g.op( + "Cast", return_op, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type() + ) + return return_op + + +@_onnx_symbolic("aten::frobenius_norm") +@symbolic_helper.parse_args("v", "v", "i") +def frobenius_norm(g: jit_utils.GraphContext, self, dim=None, keepdim=False): + dim_val = symbolic_helper._maybe_get_const(dim, "is") + if not symbolic_helper._is_value(dim_val) and len(dim_val) == 0: + return g.op("ReduceL2", self, keepdims_i=0) + sqr = g.op("Mul", self, self) + sumsqr = symbolic_helper._reducesum_helper(g, sqr, dim, keepdims_i=keepdim) + return g.op("Sqrt", sumsqr) + + +@_onnx_symbolic("aten::split") +@symbolic_helper.parse_args("v", "v", "i", "i") +def split(g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None): + if not symbolic_helper._is_split_static(split_size_or_sizes, _outputs): + split_out = g.op("SplitToSequence", self, split_size_or_sizes, axis_i=dim) + if _outputs is None: + return split_out + # Convert to multiple slice nodes iff number of splits and number of outputs are statically known. 
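+ # Otherwise, fall back to gathering each output with SequenceAt from the
+ # SplitToSequence result (see the list comprehension below).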
+ if ( + symbolic_helper._is_packed_list(split_size_or_sizes) + and len(symbolic_helper._unpack_list(split_size_or_sizes)) == _outputs + ): + split_sizes = [ + symbolic_helper._unsqueeze_helper(g, v, [0]) + for v in symbolic_helper._unpack_list(split_size_or_sizes) + ] + + start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long)) + axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long)) + res = [] + for i in range(_outputs): + end = g.op( + "Add", start, split_sizes[i] + ) # split_sizes is a list of same length as _outputs + res.append(g.op("Slice", self, start, end, axis)) + start = end + return res + return [ + g.op( + "SequenceAt", + split_out, + g.op("Constant", value_t=torch.tensor([i], dtype=torch.long)), + ) + for i in range(_outputs) + ] + + split_val = symbolic_helper._node_get(split_size_or_sizes.node(), "value") + if split_val.dim() > 0: + return g.op("Split", self, split_size_or_sizes, axis_i=dim, outputs=_outputs) + split_size = symbolic_helper._get_const(split_size_or_sizes, "i", "split_size") + + size = symbolic_helper._get_tensor_dim_size(self, dim) + if size is None: + if _outputs is not None: + size = split_size * _outputs + else: + raise errors.SymbolicValueError( + "Unknown dimension size not supported", self + ) + splits = [split_size] * (size // split_size) + leftover = size % split_size + if leftover: + splits.append(leftover) + splits = g.op("Constant", value_t=torch.tensor(splits)) + return g.op("Split", self, splits, axis_i=dim, outputs=_outputs) + + +@_onnx_symbolic("aten::split_with_sizes") +def split_with_sizes(g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None): + return split(g, self, split_sizes, dim, _outputs) + + +@_onnx_symbolic("aten::unsafe_split") +def unsafe_split( + g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None +): + return split(g, self, split_size_or_sizes, dim, _outputs) + + +@_onnx_symbolic("aten::unsafe_split_with_sizes") +def unsafe_split_with_sizes( + g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None +): + return split_with_sizes(g, self, split_sizes, dim, _outputs) + + +@_onnx_symbolic("aten::tensor_split") +@symbolic_helper.parse_args("v", "v", "i", "i") +def tensor_split( + g: jit_utils.GraphContext, self, indices_or_sections, dim, _outputs=None +): + axis = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long)) + axis = opset11.unsqueeze(g, axis, 0) + const_1 = g.op("Constant", value_t=torch.tensor(1, dtype=torch.long)) + + if symbolic_helper._is_split_static(indices_or_sections, _outputs): + split_val = symbolic_helper._node_get(indices_or_sections.node(), "value") + + if split_val.dim() > 0: + start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long)) + res = [] + assert _outputs is not None + for i in range(_outputs - 1): + end = g.op( + "Gather", + indices_or_sections, + g.op("Constant", value_t=torch.tensor([i], dtype=torch.long)), + axis_i=0, + ) + res.append(g.op("Slice", self, start, end, axis)) + start = end + + end = symbolic_helper._size_helper(g, self, axis) + res.append(g.op("Slice", self, start, end, axis)) + return res + + split_size = symbolic_helper._get_const( + indices_or_sections, "i", "indices_or_sections" + ) + + size = symbolic_helper._get_tensor_dim_size(self, dim) + if size is None: + if _outputs is not None: + size = split_size * _outputs + else: + raise errors.SymbolicValueError( + "Unknown dimension size not supported", self + ) + + min_split_size = size // split_size + num_splits_one_extra = size % split_size + 
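+ # e.g. size=10, indices_or_sections=3 -> min_split_size=3 with one extra split of 4,
+ # so the Split sizes below are [4, 3, 3], matching torch.tensor_split.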
+ splits = num_splits_one_extra * [min_split_size + 1] + leftover = (split_size - num_splits_one_extra) * [min_split_size] + + splits = g.op( + "Constant", value_t=torch.tensor(splits + leftover, dtype=torch.long) + ) + return g.op("Split", self, splits, axis_i=dim, outputs=_outputs) + + if ( + symbolic_helper._is_tensor(indices_or_sections) + and symbolic_helper._get_tensor_rank(indices_or_sections) == 1 + ): + loop_len = symbolic_helper._size_helper( + g, indices_or_sections, g.op("Constant", value_t=torch.tensor(0)) + ) + loop_len = opset11.unsqueeze(g, loop_len, 0) + loop_condition = g.op("Cast", const_1, to_i=_C_onnx.TensorProtoDataType.BOOL) + + # To make the first slice in the below loop work, + # we pad a zero to the first position so that it will be the initial start of slice. + padding_0 = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long)) + indices_or_sections = g.op("Concat", padding_0, indices_or_sections, axis_i=0) + + final_splits = g.op("SequenceEmpty") + # Loop inputs + loop, (loop_context,), _ = jit_utils.add_op_with_blocks( + g, "Loop", loop_len, loop_condition, final_splits, outputs=1, n_blocks=1 + ) + + loop_block = loop_context.block + block_input_iter = utils._add_input_to_block(loop_block) + cond = utils._add_input_to_block(loop_block) # noqa: F841 + final_splits = utils._add_input_to_block(loop_block) + + start = loop_context.op( + "Gather", indices_or_sections, block_input_iter, axis_i=0 + ) + end = loop_context.op( + "Gather", + indices_or_sections, + loop_context.op("Add", block_input_iter, const_1), + axis_i=0, + ) + + slice = loop_context.op("Slice", self, start, end, axis) + final_splits = loop_context.op("SequenceInsert", final_splits, slice) + + # Loop outputs + cond_out = loop_context.op("Identity", loop_condition) + utils._add_output_to_block(loop_block, cond_out) + utils._add_output_to_block(loop_block, final_splits) + + loop_out = loop.node().output() + start = g.op( + "Gather", + indices_or_sections, + g.op("Constant", value_t=torch.tensor(-1, dtype=torch.long)), + axis_i=0, + ) + start = opset11.unsqueeze(g, start, 0) + end = symbolic_helper._size_helper(g, self, axis) + + last_slice = g.op("Slice", self, start, end, axis) + + return g.op("SequenceInsert", loop_out, last_slice) + + else: # scalar tensor + dim_size = symbolic_helper._size_helper(g, self, axis) + min_split_size = g.op("Div", dim_size, indices_or_sections) + min_split_size_plus_1 = g.op( + "Add", + min_split_size, + const_1, + ) + num_splits_one_extra = g.op("Mod", dim_size, indices_or_sections) + splits = g.op("Tile", min_split_size_plus_1, num_splits_one_extra) + leftover = g.op( + "Tile", + min_split_size, + g.op( + "Sub", + opset11.unsqueeze(g, indices_or_sections, 0), + num_splits_one_extra, + ), + ) + + splits = g.op("Concat", splits, leftover, axis_i=0) + if _outputs is None: + return g.op("SplitToSequence", self, splits, axis_i=dim) + return g.op("Split", self, splits, axis_i=dim, outputs=_outputs) + + +@_onnx_symbolic("aten::unbind") +@symbolic_helper.parse_args("v", "i", "i") +def unbind(g: jit_utils.GraphContext, self, dim=0, _outputs=None): + if _outputs is None: + return g.op( + "SplitToSequence", + self, + g.op("Constant", value_t=torch.tensor(1, dtype=torch.long)), + axis_i=dim, + keepdims_i=0, + ) + + splits = g.op("Constant", value_t=torch.tensor([1] * _outputs)) + outputs = g.op("Split", self, splits, axis_i=dim, outputs=_outputs) + outputs = [outputs] if _outputs == 1 else outputs + squeezed_outputs = [ + g.op("Squeeze", out, g.op("Constant", 
value_t=torch.tensor([dim]))) + for out in outputs + ] + return squeezed_outputs + + +@_onnx_symbolic("aten::nonzero_numpy") +# Emitted from `torch.nonzero(x, as_tuple=True)` +def nonzero_numpy(g: jit_utils.GraphContext, input, _outputs=None): + return unbind(g, opset9.nonzero(g, input), 1, _outputs=_outputs) + + +@_onnx_symbolic("aten::where") +@symbolic_helper.parse_args("v", "v", "v", "i") +def where(g: jit_utils.GraphContext, condition, self=None, other=None, _outputs=None): + # Assumes that torch.where's first argument takes only Bool and Byte tensors. + if not symbolic_helper._is_bool(condition): + condition = g.op("Cast", condition, to_i=_C_onnx.TensorProtoDataType.BOOL) + if self is None: + condition = opset9.nonzero(g, condition) + return symbolic_helper._unbind_helper( + g, condition, g.op("Constant", value_t=torch.tensor(1)), _outputs + ) + return g.op("Where", condition, self, other) + + +@_onnx_symbolic("aten::fake_quantize_per_channel_affine") +@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i") +def fake_quantize_per_channel_affine( + g: jit_utils.GraphContext, + inputs, + scale, + zero_point, + axis, + quant_min=-128, + quant_max=127, +): + # NOTE: (0, 127) is allowed as special case. PyTorch restricts activations to be in the range (0, 127). + # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422 + if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]: + raise errors.SymbolicValueError( + "For (quant_min, quant_max), ONNX allows only (0, 127), (0, 255) and (-128, 127). " + f"Got ({quant_min}, {quant_max})", + inputs, + ) + # ONNX defines zero_point to be int8 or uint8 + if quant_min == 0: + zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8) + else: + zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8) + quantized = g.op("QuantizeLinear", inputs, scale, zero_point, axis_i=axis) + if (quant_min, quant_max) == (0, 127): + quantized = g.op( + "Clip", + quantized, + opset9.unused(g), + g.op("Constant", value_t=torch.tensor(127, dtype=torch.uint8)), + ) + return g.op("DequantizeLinear", quantized, scale, zero_point, axis_i=axis) + + +@_onnx_symbolic("aten::fake_quantize_per_tensor_affine") +@symbolic_helper.parse_args("v", "v", "v", "i", "i") +def fake_quantize_per_tensor_affine( + g: jit_utils.GraphContext, + inputs, + scale, + zero_point, + quant_min=-128, + quant_max=127, +): + # NOTE: (0, 127) is allowed as special case. PyTorch restricts activations to be in the range (0, 127). + # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422 + if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]: + raise errors.SymbolicValueError( + "For (quant_min, quant_max), ONNX allows only (0, 127), (0, 255) and (-128, 127). 
" + f"Got ({quant_min}, {quant_max})", + inputs, + ) + if quant_min == 0: + zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8) + else: + zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8) + if ( + _type_utils.JitScalarType.from_value(scale, _type_utils.JitScalarType.UNDEFINED) + != _type_utils.JitScalarType.FLOAT + ): + scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT) + quantized = g.op("QuantizeLinear", inputs, scale, zero_point) + if (quant_min, quant_max) == (0, 127): + quantized = g.op( + "Clip", + quantized, + opset9.unused(g), + g.op("Constant", value_t=torch.tensor(127, dtype=torch.uint8)), + ) + return g.op("DequantizeLinear", quantized, scale, zero_point) + + +def _reduce_op_symbolic(onnx_op_name): + def symbolic(g, self, dim=None, keepdim=None): + self = symbolic_helper._maybe_cast_reduce_op_input(g, self) + if dim is None: + # all-reduce path + return symbolic_helper._handle_reduce_dim_none(g, self, onnx_op_name) + else: + keepdim = symbolic_helper._get_const(keepdim, "i", "keepdim") + return g.op(onnx_op_name, self, dim, keepdims_i=keepdim) + + return symbolic + + +@_onnx_symbolic( + "aten::sum", + decorate=[symbolic_helper._apply_params("ReduceSum", "sum")], +) +def _reduce_with_dtype(onnx_op, name): + symbolic = _reduce_op_symbolic(onnx_op) + + @symbolic_helper._overload_by_arg_count + def reduce(g, *args, **kwargs): + @symbolic_helper.parse_args("v", "none") + def reduce_nodim(g, self, dtype): + dtype_onnx = None + if dtype.node().kind() == "onnx::Constant": + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + dtype_onnx = _type_utils.JitScalarType(dtype).onnx_type() + self = g.op("Cast", self, to_i=dtype_onnx) + elif dtype.node().kind() != "prim::Constant": + return symbolic_helper._unimplemented(name, "dtype", dtype) + result = symbolic(g, self) + if dtype_onnx is not None: + result_dtype_onnx = _type_utils.JitScalarType.from_value( + result + ).onnx_type() + if result_dtype_onnx != dtype_onnx: + result = g.op("Cast", result, to_i=dtype_onnx) + return result + + @symbolic_helper.parse_args("v", "v", "i", "none") + def reduce_dim(g, self, dim, keepdim, dtype): + dtype_onnx = None + if dtype.node().kind() == "onnx::Constant": + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + dtype_onnx = _type_utils.JitScalarType(dtype).onnx_type() + self = g.op("Cast", self, to_i=dtype_onnx) + elif dtype.node().kind() != "prim::Constant": + return symbolic_helper._unimplemented(name, "dtype", dtype) + result = symbolic(g, self, dim, keepdim) + if dtype_onnx is not None: + result_dtype_onnx = _type_utils.JitScalarType.from_value( + result + ).onnx_type() + if result_dtype_onnx != dtype_onnx: + result = g.op("Cast", result, to_i=dtype_onnx) + return result + + return reduce_nodim, reduce_dim + + return reduce + + +# Ported from +# https://github.com/microsoft/onnxscript/blob/6b1b81700b4523f31d8c6d3321e5d8ef5d42b764/onnxscript/function_libs/torch_aten/ops/core.py#L6097 +# NOTE: Supporting aten::unflatten before opset13 needs helper function to adjust ONNX op changes in Concat, Slice, ... +@_onnx_symbolic("aten::unflatten") +def unflatten(g: jit_utils.GraphContext, input, dim, unflattened_size): + input_dim = symbolic_helper._get_tensor_rank(input) + if input_dim is None: + return symbolic_helper._unimplemented( + "dim", + "ONNX and PyTorch use different strategies to split the input. 
" + "Input rank must be known at export time.", + ) + + # dim could be negative + input_dim = g.op("Constant", value_t=torch.tensor([input_dim], dtype=torch.int64)) + dim = g.op("Add", input_dim, dim) + dim = g.op("Mod", dim, input_dim) + + input_size = g.op("Shape", input) + + head_start_idx = g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)) + head_end_idx = g.op( + "Reshape", dim, g.op("Constant", value_t=torch.tensor([1], dtype=torch.int64)) + ) + head_part_rank = g.op("Slice", input_size, head_start_idx, head_end_idx) + + dim_plus_one = g.op( + "Add", dim, g.op("Constant", value_t=torch.tensor([1], dtype=torch.int64)) + ) + tail_start_idx = g.op( + "Reshape", + dim_plus_one, + g.op("Constant", value_t=torch.tensor([1], dtype=torch.int64)), + ) + tail_end_idx = g.op( + "Constant", value_t=torch.tensor([_constants.INT64_MAX], dtype=torch.int64) + ) + tail_part_rank = g.op("Slice", input_size, tail_start_idx, tail_end_idx) + + final_shape = g.op( + "Concat", head_part_rank, unflattened_size, tail_part_rank, axis_i=0 + ) + + return symbolic_helper._reshape_helper(g, input, final_shape) + + +@_onnx_symbolic("aten::unsafe_chunk") +@symbolic_helper.parse_args("v", "i", "i", "i") +def unsafe_chunk(g: jit_utils.GraphContext, self, chunks, dim, _outputs=None): + if _outputs is None: + return g.op( + "SplitToSequence", + self, + g.op("Constant", value_t=torch.tensor(1, dtype=torch.long)), + axis_i=dim, + keepdims_i=0, + ) + + size = symbolic_helper._get_tensor_dim_size(self, dim) + if size is None: + return symbolic_helper._unimplemented("unsafe_chunk", "unknown dimension size") + split_size = (size + chunks - 1) // chunks + splits = [split_size] * (size // split_size) + leftover = size % split_size + if leftover: + splits.append(leftover) + + # TODO: So far we don"t have a module using this method. We"ll keep + # this as a constant unless we see a request of dynamics in any + # user's modules. + splits = g.op("Constant", value_t=torch.tensor(splits, dtype=torch.long)) + return g.op("Split", self, splits, axis_i=dim, outputs=_outputs) + + +@_onnx_symbolic("aten::tile") +def tile(g: jit_utils.GraphContext, self, dims): + self_shape = g.op("Shape", self) + self_rank = g.op("Size", self_shape) + dims_rank = g.op("Size", dims) + diff = g.op("Sub", self_rank, dims_rank) + const_zero = g.op("Constant", value_t=torch.tensor([0])) + + # 1. If dims is shorter than self.shape pad dims with 1 + dims_shorter_than_self_shape = g.op("Greater", diff, const_zero) + ( + if_op_greater, + (if_context_greater, else_context_greater), + _, + ) = jit_utils.add_op_with_blocks( + g, "If", dims_shorter_than_self_shape, n_blocks=2, outputs=1 + ) + const_one = if_context_greater.op("Constant", value_t=torch.LongTensor([1])) + diff_1d_greater = if_context_greater.op("Reshape", diff, const_one) + exapnd_ones_greater = if_context_greater.op("Expand", const_one, diff_1d_greater) + dims_ = if_context_greater.op("Concat", exapnd_ones_greater, dims, axis_i=0) + utils._add_output_to_block(if_context_greater.block, dims_) + identity_dim = else_context_greater.op("Identity", dims) + utils._add_output_to_block(else_context_greater.block, identity_dim) + dims_final = if_op_greater.node().output() + + # 2. 
If dims is longer than self.shape pad self.shape with 1 + dims_longer_than_self_shape = g.op("Less", diff, const_zero) + ( + if_op_less, + (if_context_less, else_context_less), + _, + ) = jit_utils.add_op_with_blocks( + g, "If", dims_longer_than_self_shape, n_blocks=2, outputs=1 + ) + const_one = if_context_less.op("Constant", value_t=torch.LongTensor([1])) + diff_1d_less = if_context_less.op( + "Reshape", + if_context_less.op("Abs", diff), + const_one, + ) + exapnd_ones_less = if_context_less.op("Expand", const_one, diff_1d_less) + self_final_shape = if_context_less.op( + "Concat", exapnd_ones_less, self_shape, axis_i=0 + ) + self_ = if_context_less.op("Reshape", self, self_final_shape) + utils._add_output_to_block(if_context_less.block, self_) + identity_self = else_context_less.op("Identity", self) + utils._add_output_to_block(else_context_less.block, identity_self) + self_final = if_op_less.node().output() + + dims_final = g.op("Cast", dims_final, to_i=_C_onnx.TensorProtoDataType.INT64) + return g.op("Tile", self_final, dims_final) + + +@_onnx_symbolic("aten::repeat_interleave") +def repeat_interleave( + g: jit_utils.GraphContext, self, repeats, dim=None, output_size=None +): + repeats_dim = symbolic_helper._get_tensor_rank(repeats) + repeats_sizes = symbolic_helper._get_tensor_sizes(repeats) + input_sizes = symbolic_helper._get_tensor_sizes(self) + if repeats_dim is None: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of repeat_interleave for unknown repeats rank.", + self, + ) + if repeats_sizes is None: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of repeat_interleave for unknown repeats size.", + self, + ) + if input_sizes is None: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of repeat_interleave for unknown input size.", + self, + ) + + final_dim = dim + # if dim is None flatten + # By default, use the flattened input array, and return a flat output array + if symbolic_helper._is_none(dim): + self = symbolic_helper._reshape_helper( + g, self, g.op("Constant", value_t=torch.tensor([-1])) + ) + dim = torch.tensor(0, dtype=torch.int64) + else: + dim = symbolic_helper._maybe_get_scalar(dim) + + # Handle cases where dim is negative + if dim < 0: + dim += len(input_sizes) + + output_sizes = input_sizes.copy() + for idx, input_size in enumerate(input_sizes): + if input_size is None: + output_sizes[idx], input_sizes[idx] = 0, -1 + + # Check if all indices should be repeated the same number of times. 
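+ # i.e. repeats is a scalar (0-d) or a single-element 1-d tensor, e.g.
+ # torch.repeat_interleave(x, 2, dim=0), which repeats every row of x twice.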
+ if repeats_dim == 0 or (repeats_dim == 1 and repeats_sizes[0] == 1): + return symbolic_helper._repeat_interleave_single_value_repeat_helper( + g, self, repeats, dim + ) + + cond_dynamic_repeats = repeats_dim == 1 and repeats_sizes[0] is None + # If input size is dynamic or repeats vector is dynamic + if output_sizes[dim] == 0 or cond_dynamic_repeats: + reps = symbolic_helper._size_helper(g, self, dim) + reps = opset11.unsqueeze(g, reps, 0) + + # Check if repeats is dynamic + # As repeats is dynamic, we use a where node as a substitute for the if statement + # If repests_dim = 1, expand repeats otherwise use original tensor + if cond_dynamic_repeats: + repeat_dim = symbolic_helper._size_helper( + g, repeats, g.op("Constant", value_t=torch.LongTensor([0])) + ) + repeat_cond = g.op( + "Equal", repeat_dim, g.op("Constant", value_t=torch.LongTensor([1])) + ) + repeats = where(g, repeat_cond, g.op("Expand", repeats, reps), repeats) + # There are cases when the repeats are 1-d tensor with multiple repeats, but dim + # provided along one of the dynamic axes provided. A simple example would be + # input.shape -> [1, 1, *] where * represents the dynamic axes, and dim = 2 + # Now, repeat interleaving can be performed in pytorch when the value of * matches + # with the number of elements in repeat, for example if * -> 2, number of repeats + # should be 2 as well. + else: + return opset9.repeat_interleave(g, self, repeats, final_dim) + + reps_like = g.op( + "ConstantOfShape", + g.op("Shape", repeats), + value_t=torch.tensor([1], dtype=torch.long), + ) + r_splits = split(g, repeats, reps_like, 0) + i_splits = split(g, self, reps_like, dim) + + output_sizes[dim], input_sizes[dim] = -1, 1 + + # Create a loop to iterate over each value along the dimension + # and perform individual interleaving using the repeats tensor + # Loop is of the following pattern + # input (trip_count, cond) + # int trip_count = ...; + # bool cond = ...; + # for (int i=0; i < trip_count && cond; ++i) { + # cond = ...; + # } + + # Loop conditions + loop_condition = g.op("Constant", value_t=torch.tensor(1)) + loop_condition = g.op("Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL) + loop_len = reps + + # Create an empty sequence to store final expansions + final_splits = g.op("SequenceEmpty") + + # Loop inputs + loop, (loop_context,), _ = jit_utils.add_op_with_blocks( + g, "Loop", loop_len, loop_condition, final_splits, n_blocks=1 + ) + + loop_block = loop_context.block + block_input_iter = utils._add_input_to_block(loop_block) + cond = utils._add_input_to_block(loop_block) # noqa: F841 + final_splits = utils._add_input_to_block(loop_block) + + r_split = loop_context.op("SequenceAt", r_splits, block_input_iter) + i_split = loop_context.op("SequenceAt", i_splits, block_input_iter) + + i_split = opset11.unsqueeze(loop_context, i_split, dim + 1) + r_concat = [ + loop_context.op("Constant", value_t=torch.LongTensor(input_sizes[: dim + 1])), + r_split, + loop_context.op("Constant", value_t=torch.LongTensor(input_sizes[dim + 1 :])), + ] + r_concat = loop_context.op("Concat", *r_concat, axis_i=0) + i_split = opset9.expand(loop_context, i_split, r_concat, None) + i_split = symbolic_helper._reshape_helper( + loop_context, i_split, g.op("Constant", value_t=torch.LongTensor(output_sizes)) + ) + final_splits = loop_context.op("SequenceInsert", final_splits, i_split) + + # Loop outputs + cond_out = loop_context.op( + "Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL + ) + utils._add_output_to_block(loop_block, cond_out) 
+ utils._add_output_to_block(loop_block, final_splits) + + loop_out = loop.node().output() + loop_out = g.op("ConcatFromSequence", loop_out, axis_i=dim) + return loop_out + + +@_onnx_symbolic("aten::diagonal") +@symbolic_helper.parse_args("v", "i", "i", "i") +def diagonal(g: jit_utils.GraphContext, self, offset, dim1, dim2): + rank = symbolic_helper._get_tensor_rank(self) + # Replace negative indexing when rank is known + if rank is not None: + dim1 = dim1 if dim1 >= 0 else dim1 + rank + dim2 = dim2 if dim2 >= 0 else dim2 + rank + + dim1_size = opset9.size( + g, self, dim=g.op("Constant", value_t=torch.LongTensor([dim1])) + ) + dim2_size = opset9.size( + g, self, dim=g.op("Constant", value_t=torch.LongTensor([dim2])) + ) + # Create appropriate mask + mask_shape = g.op("Concat", dim1_size, dim2_size, axis_i=0) + mask = opset9.zeros(g, mask_shape, None, None, None) + mask = g.op("EyeLike", mask, k_i=offset) + # dim1 and dim2 appended as a dimension at the end of the shape + + if rank is not None: + axes = list(range(rank)) + axes.remove(dim1) + axes.remove(dim2) + self = g.op("Transpose", self, perm_i=axes + [dim1, dim2]) + else: + return symbolic_helper._unimplemented("diagonal", "unknown input rank") + + # Multiply input and mask to calculate values along diagonal + # The mask consists of one values where diagonal values are to be calculated + # For example: + # [[1.1, 1.2, 1.3], * [[1, 0, 0] = [[1.1, 0, 0], + # [2.1, 2.2, 2.3], [0, 1, 0] [0, 2.2, 0], + # [3.1, 3.2, 3.3]] [0, 0, 1]] [0, 0, 3.3]] + result = g.op("Mul", self, mask) + result = symbolic_helper._reducesum_helper(g, result, axes_i=[-1], keepdims_i=0) + + # Calculate gather indices based on offset and dims + # If offset is greater than zero, set offset to zero as this aids in + # calculation of selection window + offset_op = g.op("Constant", value_t=torch.LongTensor([offset])) + if offset >= 0: + diag_size = g.op( + "Max", + g.op("Min", dim1_size, g.op("Sub", dim2_size, offset_op)), + g.op("Constant", value_t=torch.LongTensor([0])), + ) + offset = 0 + else: + diag_size = g.op( + "Max", + g.op("Min", g.op("Add", dim1_size, offset_op), dim2_size), + g.op("Constant", value_t=torch.LongTensor([0])), + ) + diag_size = g.op("Concat", diag_size, axis_i=0) + + # Calculate which diagonal values to select + # For example, in cases with offsets: + # [[0, 1.1, 0] + # [0, 0, 2.2]] + # we need to select the last two columns, so we create a tensor + # with all columns that are to be selected + # So in this example, it is [1, 2] + select_window_ones_fill = opset9.ones(g, diag_size, 4, None, None) + select_window = g.op( + "CumSum", + select_window_ones_fill, + g.op("Constant", value_t=torch.LongTensor([0])), + ) + select_window = g.op( + "Add", + select_window, + g.op("Constant", value_t=torch.LongTensor([abs(offset) - 1])), + ) + + gather_shape = [ + opset9.size(g, result, dim=g.op("Constant", value_t=torch.LongTensor([axis]))) + for axis in list(range(rank))[:-2] + ] + gather_shape.append(diag_size) + gather_shape = g.op("Concat", *gather_shape, axis_i=0) + gather_indices = opset9.zeros(g, gather_shape, 4, None, None) + + # There might be cases where offset value is greater than number of rows/columns + # and might cause the diagonal to overrun and as a result of this, diag_size would be zero. 
+ # For example, if + # offset = 9, dim1_size = 2 (columns), dim2_size = 4 (rows) + # diag_size = max(min(2, (4-9)), 0) = 0, based on calculation above + # Cases with diagonal overrun always result in diag_size = max(0, -ve value) = 0 + # In cases without diagonal overrun, we select the appropriate rows/columns along which we + # are calculating diagonal values. In cases with diagonal overrun, we return a tensor which has + # the dimension of the row/column where overrun occurred as 0-dim, as we are essentially + # returning an empty tensor + overrun_cond = g.op( + "Not", + g.op( + "Equal", + diag_size, + g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64)), + ), + ) + + if_op, (if_context, else_context), _ = jit_utils.add_op_with_blocks( + g, "If", overrun_cond, n_blocks=2 + ) + + gather_indices_if_block = if_context.op("Add", gather_indices, select_window) + gather_indices_if_block = symbolic_helper._unsqueeze_helper( + if_context, gather_indices_if_block, [rank - 1] + ) + final_non_overrun = if_context.op( + "GatherND", result, gather_indices_if_block, batch_dims_i=rank - 2 + ) + final_overrun = opset9.zeros(else_context, gather_shape, 6, None, None) + utils._add_output_to_block(if_context.block, final_non_overrun) + utils._add_output_to_block(else_context.block, final_overrun) + return if_op + + +# Quantized ops + + +@_onnx_symbolic("quantized::linear") +def quantized_linear( + g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper( + g, bias, input_scale, weight_scale, axis + ) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.linear(g, input, weight, bias) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::linear_relu") +def quantized_linear_relu( + g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper( + g, bias, input_scale, weight_scale, axis + ) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.linear(g, input, weight, bias) + output = opset9.relu(g, output) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv1d_relu") +def quantized_conv1d_relu( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper( + g, bias, input_scale, weight_scale, axis + ) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv1d(g, input, weight, bias, stride, padding, dilation, groups) + output = opset9.relu(g, output) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv2d_relu") +def quantized_conv2d_relu( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = 
symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper( + g, bias, input_scale, weight_scale, axis + ) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv2d(g, input, weight, bias, stride, padding, dilation, groups) + output = opset9.relu(g, output) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv3d_relu") +def quantized_conv3d_relu( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper( + g, bias, input_scale, weight_scale, axis + ) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv3d(g, input, weight, bias, stride, padding, dilation, groups) + output = opset9.relu(g, output) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv1d") +def quantized_conv1d( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper( + g, bias, input_scale, weight_scale, axis + ) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv1d(g, input, weight, bias, stride, padding, dilation, groups) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv2d") +def quantized_conv2d( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper( + g, bias, input_scale, weight_scale, axis + ) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv2d(g, input, weight, bias, stride, padding, dilation, groups) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv3d") +def quantized_conv3d( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper( + g, bias, input_scale, weight_scale, axis + ) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv3d(g, input, weight, bias, stride, padding, dilation, groups) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv_transpose1d") +def quantized_conv_transpose1d( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + output_padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = 
symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper( + g, bias, input_scale, weight_scale, axis + ) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv_transpose2d( + g, input, weight, bias, stride, padding, output_padding, groups, dilation + ) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv_transpose2d") +def quantized_conv_transpose2d( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + output_padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper( + g, bias, input_scale, weight_scale, axis + ) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv_transpose2d( + g, input, weight, bias, stride, padding, output_padding, groups, dilation + ) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv_transpose3d") +def quantized_conv_transpose3d( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + output_padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper( + g, bias, input_scale, weight_scale, axis + ) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv_transpose3d( + g, input, weight, bias, stride, padding, output_padding, groups, dilation + ) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) diff --git a/torch/onnx/_internal/torchscript_exporter/symbolic_opset14.py b/torch/onnx/_internal/torchscript_exporter/symbolic_opset14.py new file mode 100644 index 000000000000..5675f362893e --- /dev/null +++ b/torch/onnx/_internal/torchscript_exporter/symbolic_opset14.py @@ -0,0 +1,296 @@ +# mypy: allow-untyped-defs +# mypy: disable-error-code=arg-type +"""This file exports ONNX ops for opset 14. + +Note [ONNX operators that are added/updated in opset 14] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +New operators: + HardSwish, Trilu + +Updated operators: + Reshape + Add, Sub, Mul, Div + GRU, LSTM, RNN + BatchNorm, Cumsum, Relu +""" + +# EDITING THIS FILE? READ THIS FIRST! 
+# see Note [Edit Symbolic Files] in README.md +from __future__ import annotations + +import functools + +import torch +from torch.onnx import _constants +from torch.onnx._internal.torchscript_exporter import ( + _type_utils, + jit_utils, + registration, + symbolic_helper, +) +from torch.onnx._internal.torchscript_exporter._globals import GLOBALS + + +__all__ = [ + "hardswish", + "tril", + "triu", + "reshape", + "batch_norm", + "quantized_hardswish", + "scaled_dot_product_attention", +] + +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=14) + + +@_onnx_symbolic("aten::hardswish") +@symbolic_helper.parse_args("v") +def hardswish(g: jit_utils.GraphContext, self): + return g.op("HardSwish", self) + + +@_onnx_symbolic("aten::tril") +def tril(g: jit_utils.GraphContext, self, diagonal, out=None): + return g.op("Trilu", self, diagonal, upper_i=0) + + +@_onnx_symbolic("aten::triu") +def triu(g: jit_utils.GraphContext, self, diagonal, out=None): + return g.op("Trilu", self, diagonal, upper_i=1) + + +@_onnx_symbolic("aten::reshape") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "v") +def reshape(g: jit_utils.GraphContext, self, shape): + # NOTE: Due to bug in ORT https://github.com/microsoft/onnxruntime/issues/10664 + # Reshape export cannot utilize the new allowzero attribute introduced in opset 14. + return symbolic_helper._reshape_helper(g, self, shape, allowzero=0) + + +@_onnx_symbolic("aten::batch_norm") +@symbolic_helper.parse_args("v", "v", "v", "v", "v", "i", "f", "f", "i") +def batch_norm( + g: jit_utils.GraphContext, + input, + weight, + bias, + running_mean, + running_var, + training, + momentum, + eps, + cudnn_enabled, +): + if ( + torch.is_autocast_enabled() + and not symbolic_helper.args_have_same_dtype( + [input, weight, bias, running_mean, running_var] + ) + and GLOBALS.export_onnx_opset_version < 15 + ): + return symbolic_helper._onnx_opset_unsupported_detailed( + "BatchNormalization", + 14, + 15, + "All input tensors must have the same `dtype`." 
+ " Turn off Autocast or export using opset version 15.", + input, + ) + + symbolic_helper.check_training_mode(training, "batch_norm") + weight, bias, running_mean, running_var = symbolic_helper._batchnorm_helper( + g, input, weight, bias, running_mean, running_var + ) + out = g.op( + "BatchNormalization", + input, + weight, + bias, + running_mean, + running_var, + epsilon_f=eps, + momentum_f=1 - momentum, + training_mode_i=0 if not training else 1, + outputs=1 if not training else 3, + ) + if not training: + return out + else: + res, new_running_mean, new_running_var = out + new_running_mean.setType(running_mean.type()) + new_running_var.setType(running_var.type()) + return res + + +@_onnx_symbolic("quantized::hardswish") +def quantized_hardswish(g: jit_utils.GraphContext, x, op_scale, op_zero_point): + x, _, _, _ = symbolic_helper.dequantize_helper(g, x) + + output = hardswish(g, x) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +# Ported from +# https://github.com/microsoft/onnxscript/blob/6b1b81700b4523f31d8c6d3321e5d8ef5d42b764/onnxscript/function_libs/torch_aten/ops/nn.py#L1504 +# aten_scaled_dot_product_attention +# NOTE: Need op.Trilu +@_onnx_symbolic("aten::scaled_dot_product_attention") +@symbolic_helper.parse_args("v", "v", "v", "v", "f", "b", "v", "b") +def scaled_dot_product_attention( + g: jit_utils.GraphContext, + query: torch._C.Value, + key: torch._C.Value, + value: torch._C.Value, + attn_mask: torch._C.Value | None = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: torch._C.Value | None = None, + enable_gqa: bool = False, +): + assert (not is_causal) or (is_causal and symbolic_helper._is_none(attn_mask)), ( + "is_causal and attn_mask cannot be set at the same time" + ) + assert not enable_gqa, ( + "conversion of scaled_dot_product_attention not implemented if enable_gqa is True" + ) + + if symbolic_helper._is_none(scale): + scale = _attention_scale(g, query) + + if is_causal: + attn_mask = _causal_attention_mask(g, query, key) + + # Swap the last two axes of key + # NOTE: onnx-script has different logic here, because the attribute perms in + # transpose needs list of ints + key_shape_builtin = symbolic_helper._get_tensor_rank(key) + key_transposed_axes = list(range(key_shape_builtin)) + key_transposed_axes[-1], key_transposed_axes[-2] = ( + key_transposed_axes[-2], + key_transposed_axes[-1], + ) + key_transposed = g.op("Transpose", key, perm_i=key_transposed_axes) + + # https://github.com/pytorch/pytorch/blob/12da0c70378b5be9135c6fda62a9863bce4a4818/aten/src/ATen/native/transformers/attention.cpp#L653 + # Scale q, k before matmul for stability see https://tinyurl.com/sudb9s96 for math + query_scaled = g.op("Mul", query, g.op("Sqrt", scale)) + key_transposed_scaled = g.op("Mul", key_transposed, g.op("Sqrt", scale)) + mul_qk = g.op("MatMul", query_scaled, key_transposed_scaled) + + if symbolic_helper._is_none(attn_mask): + mul_qk_add = mul_qk + attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1) + elif ( + _type_utils.JitScalarType.from_value(attn_mask) + == _type_utils.JitScalarType.BOOL + ): + # Turn the Boolean mask to float: attn_mask.masked_fill(not attn_mask, -float('inf')) + const_zero = g.op("Constant", value_t=torch.tensor([0.0])) + const_neg_inf = g.op("Constant", value_t=torch.tensor([-float("inf")])) + attn_mask = g.op("Where", attn_mask, const_zero, const_neg_inf) + mul_qk_add = g.op("Add", mul_qk, attn_mask) + attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1) + # When using scaled dot product attention 
with a boolean mask, the softmax operation might return NaN values + # due to the presence of -inf in an entire row (padding tokens), resulting in 0/0 (NaN) in the softmax output. + # This is because there's no safe softmax imp in ONNX, so we need to handle NaN values explicitly to match + # the behavior of PyTorch with boolean masks. + attn_weight = g.op("Where", g.op("IsNaN", attn_weight), const_zero, attn_weight) + elif _type_utils.JitScalarType.from_value(attn_mask) in ( + _type_utils.JitScalarType.FLOAT, + _type_utils.JitScalarType.HALF, + _type_utils.JitScalarType.BFLOAT16, + ): + mul_qk_add = g.op("Add", mul_qk, attn_mask) + attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1) + else: + raise ValueError( + f"Unsupported type for attn_mask: {_type_utils.JitScalarType.from_value(attn_mask)}" + ) + + if dropout_p != 0: + attn_weight = g.op( + "Dropout", + attn_weight, + g.op("Constant", value_t=torch.tensor(dropout_p, dtype=torch.float)), + ) + + return g.op("MatMul", attn_weight, value) + + +def _attention_scale( + g: jit_utils.GraphContext, query: torch._C.Value +) -> torch._C.Value: + """Calculate the scale factor for the attention result. + + Args: + query: Tensor of shape [..., L, E] + + Returns: + Scalar scale factor := 1 / math.sqrt(query.size(-1)) + """ + query_shape = g.op("Shape", query) + query_shape_last = g.op( + "Slice", + query_shape, + g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)), + g.op( + "Constant", value_t=torch.tensor([_constants.INT64_MAX], dtype=torch.int64) + ), + ) + embedding_size = g.op( + "Cast", + query_shape_last, + to_i=_type_utils.JitScalarType.from_value(query).onnx_type(), + ) + const_one = g.op("Constant", value_t=torch.tensor([1.0], dtype=torch.float)) + scale = g.op("Div", const_one, g.op("Sqrt", embedding_size)) + # Add a Cast to convert the scale back to original type + scale = g.op( + "Cast", + scale, + to_i=_type_utils.JitScalarType.from_value(query).onnx_type(), + ) + return scale + + +def _causal_attention_mask( + g: jit_utils.GraphContext, query: torch._C.Value, key: torch._C.Value +) -> torch._C.Value: + """Create a causal mask for the given query and key tensors. + + Equivalent to:: + mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) + attn_mask = torch.zeros(L, S, dtype=torch.float) + attn_mask = attn_mask.masked_fill(not mask, -float("inf")) + + Args: + query: Tensor of shape [..., L, E] + key: Tensor of shape [..., S, E] + + Returns: + Tensor of shape [L, S] + """ + + query_shape = g.op("Shape", query) + key_shape = g.op("Shape", key) + + last_idx = g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)) + second_last_idx = g.op("Constant", value_t=torch.tensor([-2], dtype=torch.int64)) + target_length = g.op("Slice", query_shape, second_last_idx, last_idx) + source_length = g.op("Slice", key_shape, second_last_idx, last_idx) + # attn_mask = torch.ones(L, S) := { + size = g.op("Concat", target_length, source_length, axis_i=0) + const_one = g.op("Constant", value_t=torch.tensor([1.0])) + attn_mask = g.op("Expand", const_one, size) + # } + attn_mask = g.op("Trilu", attn_mask, upper_i=0) + # The causal mask has 0s in the lower triangle and -inf in the upper triangle. 
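+    # Trilu with upper_i=0 keeps ones in the lower triangle (including the diagonal);
+    # the Where below maps those ones to 0.0 and the zeroed upper triangle to -inf.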
+ const_zero = g.op("Constant", value_t=torch.tensor([0.0])) + const_neg_inf = g.op("Constant", value_t=torch.tensor([-float("inf")])) + attn_mask = g.op( + "Where", g.op("Equal", attn_mask, const_zero), const_neg_inf, const_zero + ) + return attn_mask diff --git a/torch/onnx/_internal/torchscript_exporter/symbolic_opset15.py b/torch/onnx/_internal/torchscript_exporter/symbolic_opset15.py new file mode 100644 index 000000000000..4f86a7f2f862 --- /dev/null +++ b/torch/onnx/_internal/torchscript_exporter/symbolic_opset15.py @@ -0,0 +1,84 @@ +# mypy: allow-untyped-defs +"""This file exports ONNX ops for opset 15. + +Note [ONNX operators that are added/updated in opset 15] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +https://github.com/onnx/onnx/blob/master/docs/Changelog.md#version-15-of-the-default-onnx-operator-set +New operators: + Bernoulli + CastLike + Optional + OptionalGetElement + OptionalHasElement + +Updated operators: + BatchNormalization https://github.com/onnx/onnx/pull/3545 + Backwards compatible + TODO: test coverage for mixed types inputs. + Pow https://github.com/onnx/onnx/pull/3412 + Backwards compatible + TODO: bfloat16 support. + Shape https://github.com/onnx/onnx/pull/3580 + Backwards compatible + TODO: optional start/end attribute. +""" + +# EDITING THIS FILE? READ THIS FIRST! +# see Note [Edit Symbolic Files] in README.md + +import functools + +import torch +from torch import _C +from torch.onnx._internal.torchscript_exporter import ( + jit_utils, + registration, + symbolic_helper, + symbolic_opset9 as opset9, +) + + +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=15) + + +@_onnx_symbolic("aten::__is_") +def aten__is_(g: jit_utils.GraphContext, self, other): + if symbolic_helper._is_none(other): + if isinstance(self.type(), _C.OptionalType): + none = g.op("OptionalHasElement", self) + return g.op("Not", none) + else: + return g.op("Constant", value_t=torch.BoolTensor([0])) + return opset9.eq(g, self, other) + + +@_onnx_symbolic("aten::__isnot_") +@opset9.wrap_logical_op_with_negation # type: ignore[has-type] +def aten__isnot_(g: jit_utils.GraphContext, self, other): + return aten__is_(g, self, other) + + +@_onnx_symbolic("aten::bernoulli") +def bernoulli(g: jit_utils.GraphContext, input, p=None, generator=None, out=None): + if out is not None and not symbolic_helper._is_none(out): + symbolic_helper._unimplemented( + "Bernoulli", "out parameter is not supported for bernoulli", input + ) + if generator is not None and not symbolic_helper._is_none(generator): + symbolic_helper._unimplemented( + "Bernoulli", "generator is not supported for bernoulli", input + ) + if p is None or symbolic_helper._is_none(p): + return g.op("Bernoulli", input) + return opset9.bernoulli(g, input, p, generator, out) + + +@_onnx_symbolic("prim::unchecked_cast") +def prim_unchecked_cast(g: jit_utils.GraphContext, self): + # exists to refine the type of the Value + # if x is Optional[Tensor], unchecked_cast will cast + # x to Tensor, so the rest of the graph knows that x is a Tensor. + if isinstance(self.type(), _C.OptionalType): + return g.op("OptionalGetElement", self) + + return self diff --git a/torch/onnx/_internal/torchscript_exporter/symbolic_opset16.py b/torch/onnx/_internal/torchscript_exporter/symbolic_opset16.py new file mode 100644 index 000000000000..a617270a2a7c --- /dev/null +++ b/torch/onnx/_internal/torchscript_exporter/symbolic_opset16.py @@ -0,0 +1,191 @@ +# mypy: allow-untyped-defs +"""This file exports ONNX ops for opset 16. 
+ +Note [ONNX Operators that are added/updated in opset 16] + +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-16-of-the-default-onnx-operator-set +New operators: + GridSample https://github.com/onnx/onnx/pull/3557 + +Updated operators: + Identity + If + LeakyRelu + Loop + PRelu + RoiAlign + Scan + ScatterElements + ScatterND + Where + GreaterOrEqual + LessOrEqual +""" + +# EDITING THIS FILE? READ THIS FIRST! +# see Note [Edit Symbolic Files] in README.md + +import functools + +import torch +from torch.nn.functional import ( + GRID_SAMPLE_INTERPOLATION_MODES, + GRID_SAMPLE_PADDING_MODES, +) +from torch.onnx import errors +from torch.onnx._internal.torchscript_exporter import ( + _type_utils, + jit_utils, + registration, + symbolic_helper, + utils, +) + + +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=16) + + +# note (mkozuki): Why `grid_sampler` instead of `grid_sample`? +# Because `torch.nn.functional.grid_sample` calls `torch.grid_sampler`. +@_onnx_symbolic("aten::grid_sampler") +@symbolic_helper.parse_args("v", "v", "i", "i", "b") +def grid_sampler( + g: jit_utils.GraphContext, + input, + grid, + mode_enum, + padding_mode_enum, + align_corners, +): + # Check the input and grid tensor rank beforehand. + if symbolic_helper._get_tensor_rank(input) == 5: + return symbolic_helper._onnx_unsupported("GridSample with 5D volumetric input") + mode_s = {v: k for k, v in GRID_SAMPLE_INTERPOLATION_MODES.items()}[mode_enum] # type: ignore[call-arg] + padding_mode_s = {v: k for k, v in GRID_SAMPLE_PADDING_MODES.items()}[ # type: ignore[call-arg] + padding_mode_enum + ] + return g.op( + "GridSample", + input, + grid, + align_corners_i=int(align_corners), + mode_s=mode_s, + padding_mode_s=padding_mode_s, + ) + + +@_onnx_symbolic("aten::scatter_add") +@symbolic_helper.parse_args("v", "i", "v", "v") +def scatter_add(g: jit_utils.GraphContext, self, dim, index, src): + src_type = _type_utils.JitScalarType.from_value( + src, _type_utils.JitScalarType.UNDEFINED + ) + src_sizes = symbolic_helper._get_tensor_sizes(src) + index_sizes = symbolic_helper._get_tensor_sizes(index) + + if len(src_sizes) != len(index_sizes): + return symbolic_helper._unimplemented( + "scatter_add", + f"`index` ({index_sizes}) should have the same dimensionality as `src` ({src_sizes})", + ) + + # PyTorch only allows index shape <= src shape, so we can only consider + # taking index as subset size to src, like PyTorch does. When sizes for src + # and index are not matched or there are dynamic axes, we take index shape to + # slice src to accommodate. + if src_sizes != index_sizes or None in index_sizes: + adjusted_shape = g.op("Shape", index) + starts = g.op("Constant", value_t=torch.tensor([0] * len(index_sizes))) + src = g.op("Slice", src, starts, adjusted_shape) + + src = symbolic_helper._maybe_get_scalar(src) + if symbolic_helper._is_value(src): + return g.op("ScatterElements", self, index, src, axis_i=dim, reduction_s="add") + else: + # Check if scalar "src" has same type as self (PyTorch allows different + # type for scalar src (but not when src is tensor)). If not, insert Cast node. 
+ if _type_utils.JitScalarType.from_value(self) != src_type: + src = g.op( + "Cast", + src, + to_i=_type_utils.JitScalarType.from_value(self).onnx_type(), + ) + + return g.op( + "ScatterElements", + self, + index, + src, + axis_i=dim, + reduction_s="add", + ) + + +@_onnx_symbolic("aten::scatter_reduce") +@symbolic_helper.parse_args("v", "i", "v", "v", "s", "b") +def scatter_reduce( + g: jit_utils.GraphContext, + self: torch._C.Value, + dim: int, + index: torch._C.Value, + src: torch._C.Value, + reduce: str, + include_self: bool, +): + if reduce == "mean": + raise errors.OnnxExporterError( + "ONNX does not support mean reduction for scatter_reduce" + ) + if not include_self: + raise errors.OnnxExporterError( + "ONNX does not support include_self=False for scatter_reduce" + ) + + reduce_mode = { # convert torch string name to onnx string name + "mean": "none", # 'mean' doesn't support in ONNX 1.14 definition + "sum": "add", + "prod": "mul", + "amin": "min", + "amax": "max", + } + onnx_reduce = reduce_mode[reduce] + + self_rank = g.op("Size", g.op("Shape", self)) + + # if self_rank == 0: # assert (index_rank == 0 and rank_src == 0) + self_rank_is_zero = g.op( + "Equal", self_rank, g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64)) + ) + if_op, (if_context, else_context), _ = jit_utils.add_op_with_blocks( + g, "If", self_rank_is_zero, n_blocks=2, outputs=3 + ) + neg_1 = if_context.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)) + + self_reshape = if_context.op("Reshape", self, neg_1) + utils._add_output_to_block(if_context.block, self_reshape) + index_reshape = if_context.op("Reshape", index, neg_1) + utils._add_output_to_block(if_context.block, index_reshape) + src_reshape = if_context.op("Reshape", src, neg_1) + utils._add_output_to_block(if_context.block, src_reshape) + + self_identity = else_context.op("Identity", self) + utils._add_output_to_block(else_context.block, self_identity) + index_identitye = else_context.op("Identity", index) + utils._add_output_to_block(else_context.block, index_identitye) + src_identity = else_context.op("Identity", src) + utils._add_output_to_block(else_context.block, src_identity) + + result = g.op("ScatterElements", *if_op, axis_i=dim, reduction_s=onnx_reduce) + + # if self_rank == 0: + if_op, (if_context, else_context), _ = jit_utils.add_op_with_blocks( + g, "If", self_rank_is_zero, n_blocks=2, outputs=1 + ) + result_squeezed = if_context.op("Squeeze", result) + utils._add_output_to_block(if_context.block, result_squeezed) + result_identity = else_context.op("Identity", result) + utils._add_output_to_block(else_context.block, result_identity) + result_final = if_op.node().output() + + return result_final diff --git a/torch/onnx/_internal/torchscript_exporter/symbolic_opset17.py b/torch/onnx/_internal/torchscript_exporter/symbolic_opset17.py new file mode 100644 index 000000000000..e8ea41e64306 --- /dev/null +++ b/torch/onnx/_internal/torchscript_exporter/symbolic_opset17.py @@ -0,0 +1,244 @@ +# mypy: allow-untyped-defs +# mypy: disable-error-code=arg-type +"""This file exports ONNX ops for opset 17. 
+ +Note [ONNX Operators that are added/updated in opset 17] + +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-17-of-the-default-onnx-operator-set +New operators: + BlackmanWindow + DFT + HammingWindow + HannWindow + LayerNormalization + MelWeightMatrix + STFT + SequenceMap +""" + +import functools +from collections.abc import Sequence +from typing import Optional + +import torch +from torch import _C +from torch.onnx import errors +from torch.onnx._internal.torchscript_exporter import ( + _type_utils, + jit_utils, + registration, + symbolic_helper, +) + + +# EDITING THIS FILE? READ THIS FIRST! +# see Note [Edit Symbolic Files] in README.md + +__all__ = ["layer_norm", "stft", "quantized_layer_norm"] + +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=17) + + +@_onnx_symbolic("aten::layer_norm") +@symbolic_helper.parse_args("v", "is", "v", "v", "f", "none") +def layer_norm( + g: jit_utils.GraphContext, + input: _C.Value, + normalized_shape: Sequence[int], + weight: _C.Value, + bias: _C.Value, + eps: float, + cudnn_enable: bool, +): + # normalized_shape: input shape from an expected input of size + # axis: The first normalization dimension. + # layer_norm normalizes on the last D dimensions, + # where D is the size of normalized_shape + axis = -len(normalized_shape) + scalar_type = _type_utils.JitScalarType.from_value( + input, _type_utils.JitScalarType.FLOAT + ) + dtype = scalar_type.dtype() + if symbolic_helper._is_none(weight): + weight_value = torch.ones(normalized_shape, dtype=dtype) + weight = g.op("Constant", value_t=weight_value) + if symbolic_helper._is_none(bias): + bias_value = torch.zeros(normalized_shape, dtype=dtype) + bias = g.op("Constant", value_t=bias_value) + return g.op( + "LayerNormalization", + input, + weight, + bias, + epsilon_f=eps, + axis_i=axis, + ) + + +@_onnx_symbolic("quantized::layer_norm") +def quantized_layer_norm( + g: jit_utils.GraphContext, + x, + normalized_shape, + weight, + bias, + eps, + op_scale, + op_zero_point, +): + x, _, _, _ = symbolic_helper.dequantize_helper(g, x) + + output = layer_norm(g, x, normalized_shape, weight, bias, eps, False) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +def _compute_edge_sizes(n_fft, window_size): + """Helper function to compute the sizes of the edges (left and right) + of a given window centered within an FFT size.""" + left = (n_fft - window_size) // 2 + right = n_fft - left - window_size + return left, right + + +@_onnx_symbolic("aten::stft") +@symbolic_helper.parse_args("v", "i", "i", "i", "v", "b", "b", "b", "b") +def stft( + g: jit_utils.GraphContext, + input: _C.Value, + n_fft: int, + hop_length: Optional[int] = None, + win_length: Optional[int] = None, + window: Optional[_C.Value] = None, + normalized: bool = False, + onesided: Optional[bool] = True, + return_complex: Optional[bool] = False, + align_to_window: Optional[bool] = None, +) -> _C.Value: + """Associates `torch.stft` with the `STFT` ONNX operator. + Note that torch.stft calls _VF.stft, without centering or padding options. + Hence, this function does not contain these two arguments. + See torch.stft source code for more info. + + Args: + g: Graph to write the ONNX representation into + input: Input tensor for the transformation + n_fft: FFT size + hop_length: Size of the hop. Defaults to `floot(n_fft // 4)` + win_length: Size of the analysis window. Defaults to `n_fft` + window: Analysis window. 
Defaults to a window of all ones + normalized: Whether to return a normalized STFT + onesided: Whether to return only half (+1) of the results, given the + symmetry of the STFT + return_complex: Whether to return the complex value (Note: Must be + `False` or `None`) + + Returns: + op: Operator for torch.stft associated with STFT (ONNX) + """ + # Checks + if return_complex: + raise errors.SymbolicValueError( + msg="STFT does not currently support complex types", value=input + ) + + if align_to_window is not None: + raise errors.SymbolicValueError( + msg="STFT does not currently support the align_to_window option", + value=input, + ) # TODO(#145944): add compatibility with align_to_window option. + + # Get STFT sizes + frame_step_value = hop_length if hop_length is not None else n_fft // 4 + frame_step_const = g.op( + "Constant", value_t=torch.tensor(frame_step_value, dtype=torch.int64) + ) + frame_length_const = g.op( + "Constant", value_t=torch.tensor(n_fft, dtype=torch.int64) + ) + + # Pre-process input if needed + signal = input + signal_rank = symbolic_helper._get_tensor_rank(signal) + if signal_rank == 1: + # Add batch dimension + signal = g.op( + "Unsqueeze", + signal, + g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)), + ) + elif signal_rank is None or signal_rank > 2: + raise errors.SymbolicValueError( + msg="STFT can only take inputs of 1 [signal] or 2 [batch, signal] dimensions. " + f"Current rank of signal is {signal_rank}, please reduce it.", + value=input, + ) + + # Get window and make sure it's the same size as `win_length` or `n_fft` + n_win = symbolic_helper._get_tensor_dim_size(window, dim=0) + if n_win is not None: + win_length_default = win_length if win_length else n_fft + assert n_win == win_length_default, ( + "Analysis window size must equal `win_length` or `n_fft`. " + f"Please, set `win_length` or `n_fft` to match `window` size ({n_win})", + ) + + # Center window around zeros if needed (required by ONNX's STFT) + if n_win < n_fft: + left, right = _compute_edge_sizes(n_fft, n_win) + left_win = g.op("Constant", value_t=torch.zeros(left)) + right_win = g.op("Constant", value_t=torch.zeros(right)) + window = g.op("Concat", left_win, window, right_win, axis_i=0) + + # Create window, if needed + if symbolic_helper._is_none(window): + if win_length: + if win_length > n_fft: + raise errors.SymbolicValueError( + msg="The analysis window can't be longer than the size of the FFT. 
" + f"Please set `win_length` ({win_length}) to `n_fft` ({n_fft}) or less.", + value=input, + ) + + # Center window, if needed + left, right = _compute_edge_sizes(n_fft, win_length) + torch_window = torch.hstack( + (torch.zeros(left), torch.ones(win_length), torch.zeros(right)) + ) + else: + # Rectangle window + torch_window = torch.ones(n_fft) + assert torch_window.shape[0] == n_fft + window = g.op("Constant", value_t=torch_window) + window = g.op( + "Cast", window, to_i=_type_utils.JitScalarType.from_value(signal).onnx_type() + ) + + # Run STFT + result = g.op( + "STFT", + signal, + frame_step_const, + window, + frame_length_const, + onesided_i=1 if onesided is None or onesided else 0, + ) + + # Transpose to mimic torch.stft's behavior + result = g.op("Transpose", result, perm_i=[0, 2, 1, 3]) + + # Remove batch dimension, if needed + if signal_rank == 1: + result = g.op( + "Squeeze", + result, + g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)), + ) + + # Normalize, if needed + if normalized: + sqrt_nfft = torch.sqrt(torch.tensor(n_fft, dtype=signal.type().dtype())) + result = g.op("Div", result, g.op("Constant", value_t=sqrt_nfft)) + + return result diff --git a/torch/onnx/_internal/torchscript_exporter/symbolic_opset18.py b/torch/onnx/_internal/torchscript_exporter/symbolic_opset18.py new file mode 100644 index 000000000000..6a5ac408fb1b --- /dev/null +++ b/torch/onnx/_internal/torchscript_exporter/symbolic_opset18.py @@ -0,0 +1,270 @@ +# mypy: allow-untyped-defs +"""This file exports ONNX ops for opset 18. + +Note [ONNX Operators that are added/updated in opset 18] + +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-18-of-the-default-onnx-operator-set +New operators: + BitwiseAnd + CenterCropPad + Col2Im + Mish + OptionalGetElement + OptionalHasElement + Pad + Resize + ScatterElements + ScatterND + Split +""" + +import functools +from collections.abc import Sequence +from typing import Optional + +import torch +from torch import _C +from torch.onnx._internal.torchscript_exporter import ( + _type_utils, + jit_utils, + registration, + symbolic_helper, + symbolic_opset9 as opset9, +) + + +# EDITING THIS FILE? READ THIS FIRST! 
+# see Note [Edit Symbolic Files] in symbolic_helper.py + +__all__ = [ + "col2im", +] + +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=18) + + +@_onnx_symbolic("aten::__and_") +@_onnx_symbolic("aten::bitwise_and") +def __and_(g: jit_utils.GraphContext, self, other): + # do type promotion (scalars don't seem to apply) + args = [self, other] + # type promotion doesn't happen with torch.bitwise_and(tensor, scalar) + prom_args = [arg for arg in args if symbolic_helper._get_tensor_rank(arg)] + if len(prom_args) == 0: + prom_args = args + promotion_jit_type = symbolic_helper._type_promote_from_values(*prom_args) + self = symbolic_helper._maybe_cast_to_type(g, self, promotion_jit_type) + other = symbolic_helper._maybe_cast_to_type(g, other, promotion_jit_type) + if promotion_jit_type == _type_utils.JitScalarType.BOOL: + return g.op("And", self, other) + return g.op("BitwiseAnd", self, other) + + +@_onnx_symbolic("aten::col2im") +@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is") +def col2im( + g, + input: _C.Value, + output_size: _C.Value, + kernel_size: _C.Value, + dilation: Sequence[int], + padding: Sequence[int], + stride: Sequence[int], +): + # convert [i0, i1, ..., in] into [i0, i0, i1, i1, ..., in, in] + adjusted_padding: list[int] = [] + for pad in padding: + adjusted_padding.extend(pad for _ in range(2)) + + num_dimensional_axis = symbolic_helper._get_tensor_sizes(output_size)[0] + if not adjusted_padding: + adjusted_padding = [0, 0] * num_dimensional_axis + + if not dilation: + dilation = [1] * num_dimensional_axis + + if not stride: + stride = [1] * num_dimensional_axis + + return g.op( + "Col2Im", + input, + output_size, + kernel_size, + dilations_i=dilation, + pads_i=adjusted_padding, + strides_i=stride, + ) + + +@_onnx_symbolic( + "aten::mean", decorate=[symbolic_helper._apply_params("ReduceMean", "mean")] +) +@_onnx_symbolic( + "aten::prod", + decorate=[ + symbolic_helper._apply_params( + "ReduceProd", "prod", allow_multi_dim_support=False + ) + ], +) +def _reduce_with_dtype(onnx_op: str, name: str, allow_multi_dim_support: bool = True): + return symbolic_helper._reduce_with_dtype_helper( + onnx_op, name, allow_multi_dim_support + ) + + +@_onnx_symbolic("aten::native_layer_norm") +@symbolic_helper.quantized_args(True, False, False, False) +@symbolic_helper.parse_args("v", "is", "v", "v", "f") +def _native_layer_norm( + g: jit_utils.GraphContext, + input: _C.Value, + normalized_shape: Sequence[int], + weight: _C.Value, + bias: _C.Value, + eps: float, +) -> tuple[_C.Value, _C.Value, _C.Value]: + return opset9.native_layer_norm(g, input, normalized_shape, weight, bias, eps) + + +@_onnx_symbolic("aten::glu") +@symbolic_helper.parse_args("v", "i") +def _glu(g: jit_utils.GraphContext, input, dim): + dim_size = symbolic_helper._get_tensor_dim_size(input, dim) + if dim_size is not None: + assert dim_size % 2 == 0 + + first, second = g.op("Split", input, axis_i=dim, num_outputs_i=2, outputs=2) + return g.op("Mul", first, g.op("Sigmoid", second)) + + +@_onnx_symbolic("aten::max") +# torch.max (same for torch.min) actually has two interfaces smashed together: +# torch.max(x, dim, keepdim) and torch.max(x, y) +# TODO(justinchuby): Support multiple quantized args in output +def max(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None): + return symbolic_helper._max_helper(g, self, dim_or_y, keepdim) + + +@_onnx_symbolic("aten::maximum") +@symbolic_helper.quantized_args(True, True) +def maximum(g: jit_utils.GraphContext, input, other): + return max(g, 
input, dim_or_y=other) + + +@_onnx_symbolic("aten::min") +# TODO(justinchuby): Support multiple quantized args in output +def min(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None): + return symbolic_helper._min_helper(g, self, dim_or_y, keepdim) + + +@_onnx_symbolic("aten::minimum") +@symbolic_helper.quantized_args(True, True) +def minimum(g: jit_utils.GraphContext, input, other): + return min(g, input, dim_or_y=other) + + +@_onnx_symbolic("aten::amax") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "is", "i") +def amax(g: jit_utils.GraphContext, self, dim, keepdim): + axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long)) + return g.op("ReduceMax", self, axes, keepdims_i=keepdim) + + +@_onnx_symbolic("aten::amin") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "is", "i") +def amin(g: jit_utils.GraphContext, self, dim, keepdim): + axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long)) + return g.op("ReduceMin", self, axes, keepdims_i=keepdim) + + +@_onnx_symbolic("aten::aminmax") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "v", "i") +def aminmax(g: jit_utils.GraphContext, self, dim, keepdim): + if not symbolic_helper._is_none(dim): + dim = symbolic_helper._get_const(dim, "i", "dim") + axes = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long)) + return g.op("ReduceMin", self, axes, keepdims_i=keepdim), g.op( + "ReduceMax", self, axes, keepdims_i=keepdim + ) + else: + return g.op("ReduceMin", self, keepdims_i=keepdim), g.op( + "ReduceMax", self, keepdims_i=keepdim + ) + + +@_onnx_symbolic("aten::var_mean") +def _var_mean(g: jit_utils.GraphContext, input, *args): + if len(args) == 1: + return symbolic_helper._var_mean_helper(g, input, None, args[0], None) + else: + return symbolic_helper._var_mean_helper(g, input, *args) + + +@_onnx_symbolic("aten::logsumexp") +@symbolic_helper.parse_args("v", "is", "i") +def _logsumexp(g: jit_utils.GraphContext, input, dim, keepdim): + if dim is None: + return g.op("ReduceLogSumExp", input, keepdims_i=0) + else: + axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long)) + return g.op("ReduceLogSumExp", input, axes, keepdims_i=keepdim) + + +@_onnx_symbolic("aten::linalg_matrix_norm") +@symbolic_helper.parse_args("v", "v", "is", "b", "v") +def _linalg_matrix_norm( + g: jit_utils.GraphContext, + self: torch._C.Value, + ord: torch._C.Value, + dim: list[int], + keepdim: bool, + dtype: torch._C.Value, +): + return opset9.linalg_matrix_norm(g, self, ord, dim, keepdim, dtype) + + +@_onnx_symbolic("aten::embedding_bag") +@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i") +def embedding_bag( + g: jit_utils.GraphContext, + embedding_matrix, + indices, + offsets, + scale_grad_by_freq, + mode, + sparse, + per_sample_weights, + include_last_offset, + padding_idx, +): + return symbolic_helper._embedding_bag_helper( + g, + embedding_matrix, + indices, + offsets, + scale_grad_by_freq, + mode, + sparse, + per_sample_weights, + include_last_offset, + padding_idx, + ) + + +@_onnx_symbolic("aten::linalg_vector_norm") +@symbolic_helper.parse_args("v", "f", "is", "b", "v") +def linalg_vector_norm( + g: jit_utils.GraphContext, + self: torch._C.Value, + ord: float, + dim: Optional[Sequence[int]], + keepdim: bool, + dtype: torch._C.Value, +): + return symbolic_helper._linalg_vector_norm_helper(g, self, ord, dim, keepdim, dtype) diff --git a/torch/onnx/_internal/torchscript_exporter/symbolic_opset19.py 
b/torch/onnx/_internal/torchscript_exporter/symbolic_opset19.py new file mode 100644 index 000000000000..781bc2d200c7 --- /dev/null +++ b/torch/onnx/_internal/torchscript_exporter/symbolic_opset19.py @@ -0,0 +1,31 @@ +"""This file exports ONNX ops for opset 19. + +Note [ONNX Operators that are added/updated in opset 19] + +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-19-of-the-default-onnx-operator-set +New operators: +AveragePool +Cast +CastLike +Constant +DeformConv +DequantizeLinear +Equal +Identity +If +Loop +Pad +QuantizeLinear +Reshape +Resize +Scan +Shape +Size +""" + + +# EDITING THIS FILE? READ THIS FIRST! +# see Note [Edit Symbolic Files] in symbolic_helper.py + +__all__: list[str] = [] diff --git a/torch/onnx/_internal/torchscript_exporter/symbolic_opset20.py b/torch/onnx/_internal/torchscript_exporter/symbolic_opset20.py new file mode 100644 index 000000000000..8e8ca44a26a4 --- /dev/null +++ b/torch/onnx/_internal/torchscript_exporter/symbolic_opset20.py @@ -0,0 +1,95 @@ +# mypy: allow-untyped-defs +"""This file exports ONNX ops for opset 20. + +Note [ONNX Operators that are added/updated in opset 20] + +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-20-of-the-default-onnx-operator-set +New operators: + AffineGrid + ConstantOfShape + DFT + Gelu + GridSample + ImageDecoder + IsInf + IsNaN + ReduceMax + ReduceMin + RegexFullMatch + StringConcat + StringSplit +""" + +import functools + +import torch.nn.functional as F +from torch import _C +from torch.onnx._internal.torchscript_exporter import ( + jit_utils, + registration, + symbolic_helper, +) + + +# EDITING THIS FILE? READ THIS FIRST! +# see Note [Edit Symbolic Files] in symbolic_helper.py + +__all__ = ["_grid_sampler", "_affine_grid_generator", "gelu"] + + +def convert_grid_sample_mode(mode_s): + return ( + "linear" if mode_s == "bilinear" else "cubic" if mode_s == "bicubic" else mode_s + ) + + +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=20) + + +@_onnx_symbolic("aten::grid_sampler") +@symbolic_helper.parse_args("v", "v", "i", "i", "b") +def _grid_sampler( + g: jit_utils.GraphContext, + input: _C.Value, + grid: _C.Value, + mode_enum: int, + padding_mode_enum: int, + align_corners: bool, +): + mode_s = {v: k for k, v in F.GRID_SAMPLE_INTERPOLATION_MODES.items()}[mode_enum] # type: ignore[call-arg, index] + # mode string changes at https://onnx.ai/onnx/operators/text_diff_GridSample_16_20.html + mode_s = convert_grid_sample_mode(mode_s) + padding_mode_s = {v: k for k, v in F.GRID_SAMPLE_PADDING_MODES.items()}[ # type: ignore[call-arg, index] + padding_mode_enum # type: ignore[index] + ] + return g.op( + "GridSample", + input, + grid, + align_corners_i=int(align_corners), + mode_s=mode_s, + padding_mode_s=padding_mode_s, + ) + + +@_onnx_symbolic("aten::affine_grid_generator") +@symbolic_helper.parse_args("v", "v", "b") +def _affine_grid_generator( + g: jit_utils.GraphContext, + theta: _C.Value, + size: _C.Value, + align_corners: bool, +): + return g.op( + "AffineGrid", + theta, + size, + align_corners_i=int(align_corners), + ) + + +@_onnx_symbolic("aten::gelu") +@symbolic_helper.parse_args("v", "s") +def gelu(g: jit_utils.GraphContext, self: _C.Value, approximate: str = "none"): + return g.op("Gelu", self, approximate_s=approximate) diff --git a/torch/onnx/_internal/torchscript_exporter/symbolic_opset7.py 
b/torch/onnx/_internal/torchscript_exporter/symbolic_opset7.py new file mode 100644 index 000000000000..d11750b1ee8a --- /dev/null +++ b/torch/onnx/_internal/torchscript_exporter/symbolic_opset7.py @@ -0,0 +1,71 @@ +# mypy: allow-untyped-defs +""" +Note [ONNX operators that are added/updated from opset 7 to opset 8] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +New operators: + Expand + +Updated operators: + Min, Max, Sum, Mean: supports multidirectional broadcasting. + MaxPool: added optional indices output. + Scan +""" + +import functools +import warnings + +from torch.onnx._internal.torchscript_exporter import ( + jit_utils, + registration, + symbolic_helper, + symbolic_opset9 as opset9, +) + + +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=7) + +block_listed_operators = ( + "scan", + "expand", + "expand_as", + "meshgrid", + "adaptive_max_pool1d", + "adaptive_max_pool2d", + "adaptive_max_pool3d", + "max_pool1d_with_indices", + "max_pool2d_with_indices", + "max_pool3d_with_indices", +) + + +# NOTE: max, min, sum, mean: broadcasting is not supported in opset 7. +# torch.max (same for torch.min) actually has two interfaces smashed together: +# torch.max(x, dim, keepdim) and torch.max(x, y) +@_onnx_symbolic("aten::max") +def max(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None): + # torch.max(input, other) + if keepdim is None and dim_or_y is not None: + warnings.warn( + "Multidirectional broadcasting is not supported in opset 7. " + "This might cause the onnx model to be incorrect, if inputs to max operators " + "have different shapes" + ) + return opset9.max(g, self, dim_or_y, keepdim) + + +@_onnx_symbolic("aten::min") +def min(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None): + # torch.min(input, other) + if keepdim is None and dim_or_y is not None: + warnings.warn( + "Multidirectional broadcasting is not supported in opset 7. " + "This might cause the onnx model to be incorrect, if inputs to min operators " + "have different shapes" + ) + return opset9.min(g, self, dim_or_y, keepdim) + + +for block_listed_op in block_listed_operators: + _onnx_symbolic(f"aten::{block_listed_op}")( + symbolic_helper._block_list_in_opset(block_listed_op) + ) diff --git a/torch/onnx/_internal/torchscript_exporter/symbolic_opset8.py b/torch/onnx/_internal/torchscript_exporter/symbolic_opset8.py new file mode 100644 index 000000000000..bde072608088 --- /dev/null +++ b/torch/onnx/_internal/torchscript_exporter/symbolic_opset8.py @@ -0,0 +1,469 @@ +# mypy: allow-untyped-defs +""" +Note [ONNX operators that are added/updated from opset 8 to opset 9] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +New operators: + Compress + ConstantOfShape + EyeLike + MaxUnpool + OneHot + Sinh + Cosh + Asinh + Acosh + Atanh + Shrink + IsNaN + Sign + Erf + Scatter + Where + NonZero + TfIdfVectorizer + MeanVarianceNormalization + +Updated operators: + BatchNormalization: removed spatial attribute. + Greater, Less, Constant, MatMul, PRelu, Gemm, Flatten: more data types{integers} supported. + Cast: more data types{string} supported. + Upsample: moved scales from attribute to input. 
+ Scan +""" + +import functools +import warnings + +import torch +from torch._C import _onnx as _C_onnx +from torch.onnx import errors +from torch.onnx._internal.torchscript_exporter import ( + _type_utils, + jit_utils, + registration, + symbolic_helper, + symbolic_opset9 as opset9, +) + + +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=8) + +block_listed_operators = ( + "nonzero", + "where", + "scatter", + "scatter_add", + "erf", + "sign", + "isnan", + "gather", + "arange", + "masked_fill", + "index_fill", + "index_copy", + "repeat_interleave", + "any", + "all", +) + +for block_listed_op in block_listed_operators: + _onnx_symbolic(f"aten::{block_listed_op}")( + symbolic_helper._block_list_in_opset(block_listed_op) + ) + + +@_onnx_symbolic( + "aten::upsample_nearest1d", + decorate=[symbolic_helper._apply_params("upsample_nearest1d", 3, "nearest")], +) +@_onnx_symbolic( + "aten::upsample_nearest2d", + decorate=[symbolic_helper._apply_params("upsample_nearest2d", 4, "nearest")], +) +@_onnx_symbolic( + "aten::upsample_nearest3d", + decorate=[symbolic_helper._apply_params("upsample_nearest3d", 5, "nearest")], +) +@_onnx_symbolic( + "aten::upsample_linear1d", + decorate=[symbolic_helper._apply_params("upsample_linear1d", 3, "linear")], +) +@_onnx_symbolic( + "aten::upsample_bilinear2d", + decorate=[symbolic_helper._apply_params("upsample_bilinear2d", 4, "linear")], +) +@_onnx_symbolic( + "aten::upsample_trilinear3d", + decorate=[symbolic_helper._apply_params("upsample_trilinear3d", 5, "linear")], +) +def _interpolate(name, dim, interpolate_mode): + def symbolic_fn(g, input, output_size, *args): + scales, align_corners = symbolic_helper._get_interpolate_attributes( + g, interpolate_mode, args + ) + symbolic_helper._interpolate_warning(interpolate_mode) + align_corners = symbolic_helper._maybe_get_scalar(align_corners) + if align_corners: + return symbolic_helper._unimplemented(name, "align_corners == True", input) + output_size = symbolic_helper._maybe_get_const(output_size, "is") + if symbolic_helper._is_value(output_size): + return symbolic_helper._unimplemented( + name, "torch._C.Value (output_size) indexing" + ) + if scales is None: + scales = [ + 1.0 + if i < 2 + else float(output_size[-(dim - i)]) + / float(input.type().sizes()[-(dim - i)]) + for i in range(0, dim) + ] + return g.op("Upsample", input, mode_s=interpolate_mode, scales_f=scales) + + return symbolic_fn + + +@_onnx_symbolic("aten::__interpolate") +def __interpolate( + g: jit_utils.GraphContext, + input, + size, + scale_factor, + mode, + align_corners, + recompute_scale_factor, + antialias, +): + align_corners = symbolic_helper._maybe_get_const(align_corners, "b") + if not symbolic_helper._is_none(align_corners) and align_corners: + return symbolic_helper._unimplemented("interpolate", "align_corners == True") + + if not symbolic_helper._is_none(scale_factor) and symbolic_helper._is_value( + scale_factor + ): + return symbolic_helper._unimplemented( + "interpolate", "dynamic scales in opset 8" + ) + + if not symbolic_helper._is_none(size) and symbolic_helper._is_value(size): + return symbolic_helper._unimplemented("interpolate", "dynamic size in opset 8") + + scales, mode = symbolic_helper._interpolate_get_scales_and_mode( + g, input, size, scale_factor, mode, align_corners + ) + return g.op("Upsample", input, mode_s=mode, scales_f=scales) + + +# NOTE: We should create a wrapper for this kind of operation, after resolving the shape/type propagation +# issue for "cast" operators. 
Some symbolic functions depend on shape information of input tensor, which +# is lost after casting. +def _try_cast_integer_to_float(g: jit_utils.GraphContext, *args): + floating_scalar_types = { + _type_utils.JitScalarType.HALF, + _type_utils.JitScalarType.FLOAT, + _type_utils.JitScalarType.DOUBLE, + } + old_type = None + # Cast the input tensor to Float if its scalarType is known and is not floating number. + # If casting is performed, return the old scalarType, otherwise return None. + arg0_type = _type_utils.JitScalarType.from_value( + args[0], _type_utils.JitScalarType.UNDEFINED + ) + if arg0_type != _type_utils.JitScalarType.UNDEFINED: + old_type = arg0_type + if old_type not in floating_scalar_types: + old_type = old_type.scalar_name() # type: ignore[assignment] + args = tuple( + g.op("Cast", arg, to_i=_C_onnx.TensorProtoDataType.FLOAT) + for arg in args + ) + else: + return (None,) + args + else: + warnings.warn( + "Only floating datatype is supported for these operators: " + "{Greater, Less, MatMul, PRelu, Gemm, Flatten}. This might cause " + "the onnx model to be incorrect, if inputs have integer datatypes." + ) + return (old_type,) + args + + +def _cast_to_type(g: jit_utils.GraphContext, input, to_type): + if to_type is None: + return input + return getattr(opset9, f"_cast_{to_type}")(g, input, False) + + +def _comparison_operator(g: jit_utils.GraphContext, input, other, op_name): + other = symbolic_helper._maybe_get_scalar(other) + other = symbolic_helper._if_scalar_type_as(other, input) + _, input, other = _try_cast_integer_to_float(g, input, other) + return g.op(op_name, input, other) + + +# NOTE: For symbolics {gt, lt, bmm, matmul, prelu, mm, addmm, view, flatten}, +# integer input type not supported in opset8. Cast to float if possible. +@_onnx_symbolic("aten::gt") +def gt(g: jit_utils.GraphContext, input, other): + return _comparison_operator(g, input, other, "Greater") + + +@_onnx_symbolic("aten::lt") +def lt(g: jit_utils.GraphContext, input, other): + return _comparison_operator(g, input, other, "Less") + + +@_onnx_symbolic("aten::bmm") +def bmm(g: jit_utils.GraphContext, self, other): + if symbolic_helper._try_get_scalar_type(self): + old_type, self, other = _try_cast_integer_to_float(g, self, other) + return _cast_to_type(g, g.op("MatMul", self, other), old_type) + else: + return g.op("MatMul", self, other) + + +@_onnx_symbolic("aten::matmul") +def matmul(g: jit_utils.GraphContext, self, other): + return bmm(g, self, other) + + +@_onnx_symbolic("aten::prelu") +def prelu(g: jit_utils.GraphContext, self, weight): + self_rank = symbolic_helper._get_tensor_rank(self) + weight_sizes = symbolic_helper._get_tensor_sizes(weight) + if self_rank is not None and self_rank > 2: + weight = g.op("Unsqueeze", weight, axes_i=list(range(1, self_rank - 1))) + elif self_rank == 0 and weight_sizes == [1]: + # self and weight are both scalar but weight has rank == 1, squeeze weight. + weight = symbolic_helper._squeeze_helper(g, weight, [0]) + if symbolic_helper._try_get_scalar_type(self): + old_type, self, weight = _try_cast_integer_to_float(g, self, weight) + return _cast_to_type(g, g.op("PRelu", self, weight), old_type) + else: + return g.op("PRelu", self, weight) + + +@_onnx_symbolic("aten::mm") +def mm(g: jit_utils.GraphContext, self, other): + # Create a dummy C tensor. 
Only needed for API purposes, the value is + # since beta = 0 + scalar_type = symbolic_helper._try_get_scalar_type(self, other) + if scalar_type is None: + raise errors.SymbolicValueError( + "mm can only operate on tensors with known types", self + ) + zero_constant = g.op( + "Constant", + value_t=torch.tensor([0], dtype=scalar_type.dtype()), + ) + + if symbolic_helper._try_get_scalar_type(self): + old_type, self, other, zero_constant = _try_cast_integer_to_float( + g, self, other, zero_constant + ) + return _cast_to_type( + g, + g.op("Gemm", self, other, zero_constant, beta_f=0.0, alpha_f=1.0), + old_type, + ) + return g.op("Gemm", self, other, zero_constant, beta_f=0.0, alpha_f=1.0) + + +@_onnx_symbolic("aten::addmm") +@symbolic_helper.parse_args("v", "v", "v", "t", "t") +def addmm(g: jit_utils.GraphContext, self, mat1, mat2, beta, alpha): + if symbolic_helper._try_get_scalar_type(self): + old_type, self, mat1, mat2 = _try_cast_integer_to_float(g, self, mat1, mat2) + return _cast_to_type( + g, + g.op( + "Gemm", + mat1, + mat2, + self, + beta_f=symbolic_helper._scalar(beta), + alpha_f=symbolic_helper._scalar(alpha), + ), + old_type, + ) + else: + return g.op( + "Gemm", + mat1, + mat2, + self, + beta_f=symbolic_helper._scalar(beta), + alpha_f=symbolic_helper._scalar(alpha), + ) + + +@_onnx_symbolic("aten::flatten") +def flatten(g: jit_utils.GraphContext, input, start_dim, end_dim): + start_dim_i = symbolic_helper._get_const(start_dim, "i", "start_dim") + end_dim_i = symbolic_helper._get_const(end_dim, "i", "end_dim") + + dim = input.type().dim() + if end_dim_i < 0: + end_dim_i = dim + end_dim_i + # use ONNX's Flatten operator for cases where the output shape is 2D + if start_dim_i == 1 and end_dim_i == dim - 1: + if symbolic_helper._try_get_scalar_type(input): + old_type, input = _try_cast_integer_to_float(g, input) + return _cast_to_type( + g, g.op("Flatten", input, axis_i=start_dim_i), old_type + ) + else: + return g.op("Flatten", input, axis_i=start_dim_i) + if start_dim_i == 0 and end_dim_i == dim - 2: + if symbolic_helper._try_get_scalar_type(input): + old_type, input = _try_cast_integer_to_float(g, input) + return _cast_to_type( + g, g.op("Flatten", input, axis_i=end_dim_i + 1), old_type + ) + else: + return g.op("Flatten", input, axis_i=end_dim_i + 1) + + return opset9.flatten(g, input, start_dim, end_dim) + + +def _constant_fill(g: jit_utils.GraphContext, sizes, dtype: int, const_value): + if dtype is None: + scalar_type = _type_utils.JitScalarType.FLOAT + else: + scalar_type = _type_utils.JitScalarType(dtype) + if not scalar_type.dtype().is_floating_point: + result = g.op( + "ConstantFill", + sizes, + dtype_i=_type_utils.JitScalarType.FLOAT.onnx_type(), + input_as_shape_i=1, + value_f=const_value, + ) + return g.op("Cast", result, to_i=scalar_type.onnx_type()) + else: + return g.op( + "ConstantFill", + sizes, + dtype_i=scalar_type.onnx_type(), + input_as_shape_i=1, + value_f=const_value, + ) + + +@_onnx_symbolic("aten::empty") +@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v") +def empty( + g: jit_utils.GraphContext, + sizes, + dtype, + layout, + device, + pin_memory=False, + memory_format=None, +): + return zeros(g, sizes, dtype, layout, device, pin_memory) + + +@_onnx_symbolic("aten::empty_like") +@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v") +def empty_like( + g: jit_utils.GraphContext, + input, + dtype, + layout, + device, + pin_memory=False, + memory_format=None, +): + return zeros_like(g, input, dtype, layout, device, pin_memory) + + 
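# A minimal eager-mode sketch of what _constant_fill above encodes (illustration only,
# not part of the exporter; it takes a torch.dtype instead of the JitScalarType index
# used in this file): fill a tensor of the requested shape with a constant, computing
# in float and casting afterwards when the target dtype is not floating point, which
# mirrors the ConstantFill + Cast pattern emitted for opset 8.
import torch


def _constant_fill_reference(sizes, dtype, const_value):
    dtype = torch.float32 if dtype is None else dtype
    filled = torch.full(tuple(sizes), float(const_value), dtype=torch.float32)
    # Integer / bool targets receive the same values, cast after the float fill.
    return filled if dtype.is_floating_point else filled.to(dtype)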
+@_onnx_symbolic("aten::zeros") +@symbolic_helper.parse_args("v", "i", "v", "v", "v") +def zeros(g: jit_utils.GraphContext, sizes, dtype, layout, device, pin_memory=False): + # NOTE: no way to set device and layout in ONNX, so we ignore it + return _constant_fill(g, sizes, dtype, 0) + + +@_onnx_symbolic("aten::zeros_like") +@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v") +def zeros_like( + g: jit_utils.GraphContext, + input, + dtype, + layout, + device, + pin_memory=False, + memory_format=None, +): + shape = g.op("Shape", input) + return _constant_fill(g, shape, dtype, 0) + + +@_onnx_symbolic("aten::ones") +@symbolic_helper.parse_args("v", "i", "v", "v", "v") +def ones(g: jit_utils.GraphContext, sizes, dtype, layout, device, pin_memory=False): + return _constant_fill(g, sizes, dtype, 1) + + +@_onnx_symbolic("aten::ones_like") +@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v") +def ones_like( + g: jit_utils.GraphContext, + input, + dtype, + layout, + device, + pin_memory=False, + memory_format=None, +): + shape = g.op("Shape", input) + return _constant_fill(g, shape, dtype, 1) + + +@_onnx_symbolic("aten::full") +def full( + g: jit_utils.GraphContext, sizes, value, dtype, layout, device, pin_memory=False +): + const_value = symbolic_helper._maybe_get_const(value, "t") + if symbolic_helper._is_value(const_value): + tmp = zeros(g, sizes, dtype, layout, device) + return opset9.add(g, tmp, value, g.op("Constant", value_t=torch.tensor(1))) + else: + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + return _constant_fill(g, sizes, dtype, const_value) + + +@_onnx_symbolic("aten::full_like") +@symbolic_helper.parse_args("v", "f", "i", "v", "v", "v", "v") +def full_like( + g: jit_utils.GraphContext, + input, + fill_value, + dtype, + layout, + device, + pin_memory=False, + memory_format=None, +): + shape = g.op("Shape", input) + return _constant_fill(g, shape, dtype, fill_value) + + +@_onnx_symbolic("aten::repeat") +def repeat(g: jit_utils.GraphContext, self, repeats): + if not symbolic_helper._is_value(repeats): + repeats = g.op("Constant", value_t=torch.LongTensor(repeats)) + if symbolic_helper._is_packed_list(repeats): + repeat_size_len = len(symbolic_helper._unpack_list(repeats)) + else: + const_repeats = symbolic_helper._maybe_get_const(repeats, "is") + repeat_size_len = len(const_repeats) + if self.isCompleteTensor(): + sizes = self.type().sizes() + diff_dims = repeat_size_len - len(sizes) + if diff_dims > 0: + self = opset9.view( + g, self, g.op("Constant", value_t=torch.tensor([1] * diff_dims + sizes)) + ) + return g.op("Tile", self, repeats) diff --git a/torch/onnx/_internal/torchscript_exporter/symbolic_opset9.py b/torch/onnx/_internal/torchscript_exporter/symbolic_opset9.py new file mode 100644 index 000000000000..596c656777f8 --- /dev/null +++ b/torch/onnx/_internal/torchscript_exporter/symbolic_opset9.py @@ -0,0 +1,6656 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +# mypy: disable-error-code=arg-type +"""This file exports ONNX ops for opset 9. 
+ +Opset 9 is supported by ONNX release 1.4.1 +release on 01/23/19 +""" + +from __future__ import annotations + +import builtins +import functools +import math +import sys +import warnings +from typing import Callable, TYPE_CHECKING +from typing_extensions import deprecated + +import torch +import torch._C._onnx as _C_onnx +import torch.nn.modules.utils +import torch.onnx +from torch import _C +from torch.onnx import _constants, errors +from torch.onnx._internal.torchscript_exporter import ( + _type_utils, + jit_utils, + registration, + symbolic_helper, +) +from torch.onnx._internal.torchscript_exporter._globals import GLOBALS + + +if TYPE_CHECKING: + from collections.abc import Sequence + + from torch.types import Number + +# EDITING THIS FILE? READ THIS FIRST! +# see Note [Edit Symbolic Files] in README.md + +__all__ = [ + "abs", + "acos", + "add", + "addcmul", + "addmm", + "alias", + "amax", + "amin", + "aminmax", + "arange", + "argmax", + "argmin", + "as_strided", + "as_tensor", + "asin", + "atan", + "atan2", + "baddbmm", + "batch_norm", + "bernoulli", + "bitwise_not", + "bitwise_or", + "bmm", + "broadcast_tensors", + "broadcast_to", + "bucketize", + "cat", + "cdist", + "ceil", + "clamp_max", + "clamp_min", + "clamp", + "clone", + "constant_pad_nd", + "contiguous", + "conv_tbc", + "conv_transpose1d", + "conv_transpose2d", + "conv_transpose3d", + "conv1d", + "conv2d", + "conv3d", + "convert_element_type", + "convolution", + "cos", + "cosine_similarity", + "cross", + "cumsum", + "detach", + "dim", + "div", + "dot", + "dropout", + "elu", + "embedding_bag", + "embedding", + "empty_like", + "empty", + "eq", + "erf", + "exp", + "expand_as", + "expand", + "eye", + "fill", + "flatten", + "floor_divide", + "floor", + "floordiv", + "frobenius_norm", + "full_like", + "full", + "gather", + "ge", + "gelu", + "get_pool_ceil_padding", + "glu", + "group_norm", + "gt", + "hann_window", + "hardshrink", + "hardsigmoid", + "hardswish", + "hardtanh", + "index_add", + "index_copy", + "index_fill", + "index_put", + "index_select", + "index", + "instance_norm", + "is_floating_point", + "is_pinned", + "isnan", + "item", + "kl_div", + "layer_norm", + "le", + "leaky_relu", + "lerp", + "lift", + "linalg_cross", + "linalg_matrix_norm", + "linalg_norm", + "linalg_vector_norm", + "linear", + "linspace", + "log_sigmoid", + "log_softmax", + "log", + "log10", + "log1p", + "log2", + "logical_and", + "logical_not", + "logical_or", + "logical_xor", + "logit", + "logsumexp", + "lstm_cell", + "lstm", + "lt", + "masked_fill", + "masked_fill_", + "matmul", + "max_pool1d_with_indices", + "max_pool2d_with_indices", + "max_pool3d_with_indices", + "max", + "maximum", + "meshgrid", + "min", + "minimum", + "mish", + "mm", + "movedim", + "mse_loss", + "mul", + "multinomial", + "mv", + "narrow", + "native_layer_norm", + "ne", + "neg", + "new_empty", + "new_full", + "new_ones", + "new_zeros", + "nonzero_numpy", + "nonzero", + "norm", + "numel", + "numpy_T", + "one_hot", + "ones_like", + "ones", + "onnx_placeholder", + "pad", + "pairwise_distance", + "permute", + "pixel_shuffle", + "pixel_unshuffle", + "pow", + "prelu", + "prim_constant_chunk", + "prim_constant_split", + "prim_constant", + "prim_data", + "prim_device", + "prim_dtype", + "prim_if", + "prim_layout", + "prim_list_construct", + "prim_list_unpack", + "prim_loop", + "prim_max", + "prim_min", + "prim_shape", + "prim_tolist", + "prim_tuple_construct", + "prim_type", + "prim_unchecked_cast", + "prim_uninitialized", + "rand_like", + "rand", + "randint_like", + "randint", + 
"randn_like", + "randn", + "reciprocal", + "reflection_pad", + "relu", + "relu6", + "remainder", + "repeat_interleave", + "repeat", + "replication_pad", + "reshape_as", + "reshape", + "roll", + "rrelu", + "rsqrt", + "rsub", + "scalar_tensor", + "scatter_add", + "scatter", + "select", + "selu", + "sigmoid", + "sign", + "silu", + "sin", + "size", + "slice", + "softmax", + "softplus", + "softshrink", + "sort", + "split_with_sizes", + "split", + "sqrt", + "square", + "squeeze", + "stack", + "std_mean", + "std", + "sub", + "t", + "take", + "tan", + "tanh", + "tanhshrink", + "tensor", + "threshold", + "to", + "topk", + "transpose", + "true_divide", + "type_as", + "unbind", + "unfold", + "unsafe_chunk", + "unsafe_split_with_sizes", + "unsafe_split", + "unsqueeze", + "unsupported_complex_operators", + "noop_complex_operators", + "unused", + "var_mean", + "var", + "view_as", + "view", + "where", + "wrap_logical_op_with_cast_to", + "wrap_logical_op_with_negation", + "zeros_like", + "zeros", + "zero", +] + + +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=9) + + +def _export(name: str): + """Exports the function in the current global namespace.""" + + def wrapper(func): + globals()[name] = func + __all__.append(name) + return func + + return wrapper + + +def unused(g): + """Represents "missing" optional inputs.""" + n = g.op("prim::Constant") + n.setType(_C.OptionalType.ofTensor()) + return n + + +@_onnx_symbolic("aten::_shape_as_tensor") +def _shape_as_tensor(g: jit_utils.GraphContext, input): + return g.op("Shape", input) + + +@_onnx_symbolic("aten::_reshape_from_tensor") +def _reshape_from_tensor(g: jit_utils.GraphContext, input, shape): + if isinstance(shape, list): + shape = g.op("Concat", *shape, axis_i=0) + return reshape(g, input, shape) + + +@_onnx_symbolic("aten::reshape") +@symbolic_helper.quantized_args(True) +def reshape(g: jit_utils.GraphContext, self, shape): + return symbolic_helper._reshape_helper(g, self, shape) + + +@_onnx_symbolic("aten::reshape_as") +@symbolic_helper.quantized_args(True) +def reshape_as(g: jit_utils.GraphContext, self, other): + shape = g.op("Shape", other) + return reshape(g, self, shape) + + +@_onnx_symbolic("aten::add") +def add(g: jit_utils.GraphContext, self, other, alpha=None): + """ + This function takes the add function and returns the corresponding ONNX operator. + + This function is not meant to be called directly by the user. + + Args: + g (GraphContext): The graph context. + self (Tensor): The first operand. + other (Tensor): The second operand. + alpha (float, optional): The scaling factor for the second operand. Defaults to None. + + Returns: + ONNX operator. + """ + if symbolic_helper._is_value(self) and symbolic_helper._is_tensor_list(self): + return symbolic_helper._onnx_opset_unsupported_detailed( + "Add", 9, 11, "Add between list of tensors not supported", self + ) + if alpha and symbolic_helper._scalar(symbolic_helper._maybe_get_scalar(alpha)) != 1: + other = g.op("Mul", other, alpha) + return g.op("Add", self, other) + + +@_onnx_symbolic("aten::sub") +def sub(g: jit_utils.GraphContext, self, other, alpha=None): + """ + Consumes sub function and returns the corresponding ONNX operator. + + This function is not meant to be called directly by the user. + + Args: + g (GraphContext): The graph context. + self (Tensor): The first operand. + other (Tensor): The second operand. + alpha (Optional[Tensor]): A scaling factor to apply to the second operand. + If `alpha` is not provided, it defaults to 1. 
+ + Returns: + ONNX operator + """ + if alpha and symbolic_helper._scalar(symbolic_helper._maybe_get_scalar(alpha)) != 1: + other = g.op("Mul", other, alpha) + return g.op("Sub", self, other) + + +@_onnx_symbolic("aten::rsub") +def rsub(g: jit_utils.GraphContext, self, other, alpha=None): + return sub(g, other, self, alpha=alpha) + + +@_onnx_symbolic("aten::mul") +def mul(g: jit_utils.GraphContext, self, other): + if symbolic_helper._is_bool(self) and symbolic_helper._is_bool(other): + # ONNX Mul doesn't support Boolean, so use And as an equivalent operator. + return g.op("And", self, other) + else: + return g.op("Mul", self, other) + + +@_onnx_symbolic("aten::div") +def div(g: jit_utils.GraphContext, self, other, *args): + if len(args) == 0: + return true_divide(g, self, other) + else: + return _div_rounding_mode(g, self, other, *args) + + +@_onnx_symbolic("aten::addcmul") +@symbolic_helper.parse_args("v", "v", "v", "f") +def addcmul(g: jit_utils.GraphContext, self, tensor1, tensor2, value=1.0): + value_tens = g.op("Constant", value_t=torch.tensor([value])) + return add(g, self, mul(g, mul(g, tensor1, tensor2), value_tens)) + + +@symbolic_helper.parse_args("v", "v", "s") +def _div_rounding_mode(g: jit_utils.GraphContext, self, other, rounding_mode): + if rounding_mode is None: + return true_divide(g, self, other) + elif rounding_mode == "floor": + return _floor_divide(g, self, other) + elif rounding_mode == "trunc": + return _trunc_divide(g, self, other) + else: + raise errors.SymbolicValueError( + f'Unsupported rounding mode: "{rounding_mode}". Expected None, "floor" or "trunc"', + self, + ) + + +def _trunc_divide(g: jit_utils.GraphContext, self, other): + out = g.op("Div", self, other) + # the correct operation is truncate, which is not supported in ONNX, + # we cannot call floor since it will behave differently for negative numbers + # (eg. 
-0.1 should become -0 ) + # - if scalar_type information are not available, assume that + # we need to call floor (treat as float) + out = g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.INT64) + + # Matching PyTorch's behavior: + # - if self is fp the output's type is self's type + # - if self is not fp and other is fp, the output is of type JitScalarType.FLOAT + # - self is not fp and other is not fp, the output's type is self's output type + # - the output type defaults to Float + scalar_type = _type_utils.JitScalarType.from_value( + self, _type_utils.JitScalarType.UNDEFINED + ) + if scalar_type != _type_utils.JitScalarType.UNDEFINED: + if not symbolic_helper._is_fp(self) and symbolic_helper._is_fp(other): + out = g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.FLOAT) + else: + out = g.op( + "Cast", + out, + to_i=scalar_type.onnx_type(), + ) + else: + out = g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.FLOAT) + return out + + +def _floor_divide(g: jit_utils.GraphContext, self, other): + if symbolic_helper._is_fp(self) or symbolic_helper._is_fp(other): + out = true_divide(g, self, other) + return g.op("Floor", out) + else: + # Integer division does truncation rounding + div = g.op("Div", self, other) + # Division is negative if: self < 0 != other < 0 + zero = g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64)) + negative = g.op( + "Xor", + symbolic_helper._lt_helper(g, self, zero), + symbolic_helper._lt_helper(g, other, zero), + ) + + # For negative numbers with self % other != 0, subtract 1 to round down instead of up + mod = g.op("Sub", self, g.op("Mul", div, other)) + fixup_mask = g.op("And", negative, g.op("Not", g.op("Equal", mod, zero))) + + one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)) + fixup = g.op("Mul", fixup_mask, one) + return g.op("Sub", div, fixup) + + +@_onnx_symbolic("aten::floor_divide") +def floor_divide(g: jit_utils.GraphContext, self, other): + # Deprecated behavior, floor_divide actually truncates + return _trunc_divide(g, self, other) + + +@_onnx_symbolic("aten::floordiv") +def floordiv(g: jit_utils.GraphContext, self, other): + return floor_divide(g, self, other) + + +@_onnx_symbolic("aten::true_divide") +def true_divide(g: jit_utils.GraphContext, self, other): + """Division where both inputs are cast to floating types + + If both inputs are floating, performs div as usual + If only one input is a floating type, the other input is cast to its type + If neither input is a floating type, both inputs are cast to the default scalar type + """ + + # Case 1: either values are floating + # Performs div as usual. + # Implicit casting will be handled in scalar type analysis pass. + if symbolic_helper._is_fp(self) or symbolic_helper._is_fp(other): + return g.op("Div", self, other) + + # Case 2: neither is floating + # Casts both inputs to the default scalar type + scalar_type = torch.get_default_dtype() + onnx_scalar_type = _C_onnx.TensorProtoDataType.FLOAT + assert scalar_type is torch.float or scalar_type is torch.double + if torch.get_default_dtype() is torch.double: + onnx_scalar_type = _C_onnx.TensorProtoDataType.DOUBLE + + self = g.op("Cast", self, to_i=onnx_scalar_type) + other = g.op("Cast", other, to_i=onnx_scalar_type) + return g.op("Div", self, other) + + +@_onnx_symbolic("aten::reciprocal") +def reciprocal(g: jit_utils.GraphContext, self): + # torch.reciprocal implicitly casts to float, so we do the same. 
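    # For example, torch.tensor([4]).reciprocal() returns tensor([0.2500]) (a float),
    # so integer inputs are cast to FLOAT before Reciprocal.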
+ if not symbolic_helper._is_fp(self): + self = g.op("Cast", self, to_i=_C_onnx.TensorProtoDataType.FLOAT) + return g.op("Reciprocal", self) + + +@_onnx_symbolic("aten::cat") +@symbolic_helper.parse_args("v", "i") +def cat(g: jit_utils.GraphContext, tensor_list, dim): + """Implement concatenation of pytorch tensors in ONNX along the specified `dim` dimension. + + Parameters: + g (jit_utils.GraphContext): Graph context. + tensor_list (List[torch.Tensor]): List of tensors to concatenate. + dim (int): Dimension along which to concatenate the tensors. + + Returns: + ONNX graph node representing the concatenated tensor. + """ + tensors = symbolic_helper._unpack_list(tensor_list) + # torch.cat ignores empty tensors such as `torch.Tensor([])` + # These needs to be removed as input from ONNX's concat too, otherwise shape inference + # will likely fail due to inputs with different ranks (0 for empty tensor, > 0 for anything else) + nonempty_tensors = [] + for t in tensors: + if symbolic_helper._is_constant(t) and not symbolic_helper._get_tensor_dim_size( + t, 0 + ): + continue + nonempty_tensors.append(t) + assert len(nonempty_tensors) > 0 + assert all( + symbolic_helper._get_tensor_rank(nonempty_tensors[0]) is None + or symbolic_helper._get_tensor_rank(t) is None + or symbolic_helper._get_tensor_rank(t) + == symbolic_helper._get_tensor_rank(nonempty_tensors[0]) + for t in nonempty_tensors + ) + tensor_list.node().removeAllInputs() + for t in nonempty_tensors: + tensor_list.node().addInput(t) + + tensors = symbolic_helper._unpack_list(tensor_list) + return g.op("Concat", *tensors, axis_i=dim) + + +@_onnx_symbolic("aten::stack") +@symbolic_helper.parse_args("v", "i") +def stack(g: jit_utils.GraphContext, tensor_list, dim): + unsqueezed = [ + symbolic_helper._unsqueeze_helper(g, t, [dim]) + for t in symbolic_helper._unpack_list(tensor_list) + ] + return g.op("Concat", *unsqueezed, axis_i=dim) + + +@_onnx_symbolic("aten::list") +def _list(g: jit_utils.GraphContext, self): + return self + + +@_onnx_symbolic("aten::mm") +def mm(g: jit_utils.GraphContext, self, other): + # Create a dummy C tensor. 
Only needed for API purposes, the value is + # since beta = 0 + C = g.op("Constant", value_t=torch.tensor([1])) + return g.op("Gemm", self, other, C, beta_f=0.0, alpha_f=1.0) + + +@_onnx_symbolic("aten::bmm") +def bmm(g: jit_utils.GraphContext, self, other): + return g.op("MatMul", self, other) + + +@_onnx_symbolic("aten::matmul") +def matmul(g: jit_utils.GraphContext, self, other): + return g.op("MatMul", self, other) + + +@_onnx_symbolic("aten::addmm") +@symbolic_helper.parse_args("v", "v", "v", "t", "t") +def addmm(g: jit_utils.GraphContext, self, mat1, mat2, beta, alpha): + scalar_type = None + self_scalar_type = symbolic_helper._try_get_scalar_type(self) + mat1_scalar_type = symbolic_helper._try_get_scalar_type(mat1) + mat2_scalar_type = symbolic_helper._try_get_scalar_type(mat2) + if self_scalar_type is not None: + scalar_type = self_scalar_type + elif mat1_scalar_type is not None: + scalar_type = mat1_scalar_type + elif mat2_scalar_type is not None: + scalar_type = mat2_scalar_type + + mat1_rank = symbolic_helper._get_tensor_rank(mat1) + mat2_rank = symbolic_helper._get_tensor_rank(mat2) + + def is_not_none_nor(v, u): + return v is not None and v != u + + if scalar_type is not None and ( + is_not_none_nor(mat1_rank, 2) or is_not_none_nor(mat2_rank, 2) + ): + res1 = g.op("MatMul", mat1, mat2) + res2 = self + + alpha = symbolic_helper._scalar(alpha) + beta = symbolic_helper._scalar(beta) + + if alpha != 1: + alpha = g.op( + "Constant", value_t=torch.tensor(alpha, dtype=scalar_type.dtype()) + ) + res1 = g.op("Mul", res1, alpha) + if beta != 1: + beta = g.op( + "Constant", + value_t=torch.tensor( + symbolic_helper._scalar(beta), dtype=scalar_type.dtype() + ), + ) + res2 = g.op("Mul", res2, beta) + + return g.op("Add", res1, res2) + + return g.op( + "Gemm", + mat1, + mat2, + self, + beta_f=symbolic_helper._scalar(beta), + alpha_f=symbolic_helper._scalar(alpha), + ) + + +@_onnx_symbolic("aten::neg") +def neg(g: jit_utils.GraphContext, self): + return g.op("Neg", self) + + +@_onnx_symbolic("aten::sqrt") +def sqrt(g: jit_utils.GraphContext, self): + if _type_utils.JitScalarType.from_value( + self, _type_utils.JitScalarType.UNDEFINED + ) in { + _type_utils.JitScalarType.UINT8, + _type_utils.JitScalarType.INT8, + _type_utils.JitScalarType.INT16, + _type_utils.JitScalarType.INT, + _type_utils.JitScalarType.INT64, + }: + # torch converts all int inputs to sqrt to float + self = g.op("Cast", self, to_i=_C_onnx.TensorProtoDataType.FLOAT) + + return g.op("Sqrt", self) + + +@_onnx_symbolic("aten::rsqrt") +def rsqrt(g: jit_utils.GraphContext, self): + return g.op( + "Div", symbolic_helper._if_scalar_type_as(torch.ones(1), self), sqrt(g, self) + ) + + +@_onnx_symbolic("aten::tanh") +# Fixed scale and zero_point, discovered from aten/src/ATen/native/quantized/cpu/qtanh.cpp +@symbolic_helper.quantized_args(True, scale=2.0 / 256.0, zero_point=128) +def tanh(g: jit_utils.GraphContext, self): + return g.op("Tanh", self) + + +@_onnx_symbolic("aten::sin") +def sin(g: jit_utils.GraphContext, self): + return g.op("Sin", self) + + +@_onnx_symbolic("aten::cos") +def cos(g: jit_utils.GraphContext, self): + return g.op("Cos", self) + + +@_onnx_symbolic("aten::tan") +def tan(g: jit_utils.GraphContext, self): + return g.op("Tan", self) + + +@_onnx_symbolic("aten::asin") +def asin(g: jit_utils.GraphContext, self): + return g.op("Asin", self) + + +@_onnx_symbolic("aten::acos") +def acos(g: jit_utils.GraphContext, self): + return g.op("Acos", self) + + +@_onnx_symbolic("aten::atan") +def atan(g: 
jit_utils.GraphContext, self): + return g.op("Atan", self) + + +@_onnx_symbolic("aten::atan2") +def atan2(g: jit_utils.GraphContext, self, other): + # self is y, and other is x on coordinate + slope = g.op("Div", self, other) + atan = g.op("Atan", slope) + const_zero = g.op("Constant", value_t=torch.tensor(0)) + const_pi = g.op("Constant", value_t=torch.tensor(math.pi)) + + condition_second_or_third_quadrant = g.op("Greater", self, const_zero) + second_third_quadrant = g.op( + "Where", + condition_second_or_third_quadrant, + g.op("Add", atan, const_pi), + g.op("Sub", atan, const_pi), + ) + + condition_14_or_23_quadrant = g.op("Less", other, const_zero) + result = g.op("Where", condition_14_or_23_quadrant, second_third_quadrant, atan) + + return result + + +@_onnx_symbolic("aten::sigmoid") +# Fixed scale and zero_point, discovered from aten/src/ATen/native/quantized/cpu/qsigmoid.cpp +@symbolic_helper.quantized_args(True, scale=1.0 / 256.0, zero_point=0) +def sigmoid(g: jit_utils.GraphContext, self): + """Converts the corresponding PyTorch function into ONNX operators. + + It is not meant to be called directly by a user. + + Args: + g (jit_utils.GraphContext): Graph context. + self (Tensor): the input tensor. + Returns: + ONNX operator + """ + return g.op("Sigmoid", self) + + +@_onnx_symbolic("aten::sign") +def sign(g: jit_utils.GraphContext, self): + return g.op("Sign", self) + + +@symbolic_helper.quantized_args(True) +def _slice(g: jit_utils.GraphContext, input, axes, starts, ends): + assert len(starts) == len(ends) + if len(starts) == 1 and starts[0] == 0 and ends[0] == _constants.INT64_MAX: + return input + return g.op("Slice", input, axes_i=axes, starts_i=starts, ends_i=ends) + + +@_onnx_symbolic( + "aten::sum", decorate=[symbolic_helper._apply_params("ReduceSum", "sum")] +) +@_onnx_symbolic( + "aten::mean", decorate=[symbolic_helper._apply_params("ReduceMean", "mean")] +) +# torch.prod does not support multidimensional "dim" +@_onnx_symbolic( + "aten::prod", + decorate=[ + symbolic_helper._apply_params( + "ReduceProd", "prod", allow_multi_dim_support=False + ) + ], +) +def _reduce_with_dtype(onnx_op: str, name: str, allow_multi_dim_support: bool = True): + return symbolic_helper._reduce_with_dtype_helper( + onnx_op, name, allow_multi_dim_support + ) + + +@_onnx_symbolic("aten::cumsum") +@symbolic_helper.parse_args("v", "i", "none") +def cumsum(g: jit_utils.GraphContext, input, dim, dtype): + symbolic_helper._onnx_opset_unsupported("cumsum", 9, 11, input) + + +@_onnx_symbolic("aten::_sample_dirichlet") +def _sample_dirichlet(g: jit_utils.GraphContext, self, generator): + return symbolic_helper._onnx_unsupported("_sample_dirichlet", self) + + +@_onnx_symbolic("aten::_standard_gamma") +def _standard_gamma(g: jit_utils.GraphContext, self, generator): + return symbolic_helper._onnx_unsupported("_standard_gamma", self) + + +@_onnx_symbolic("aten::t") +def t(g: jit_utils.GraphContext, self): + rank = symbolic_helper._get_tensor_rank(self) + if rank is None or rank < 2: + # The transpose of a 1d or 0d tensor is itself. ONNX does not define the behavior + # clearly and onnxruntime fails on these cases. So we add an Identity node to + # mirror the behavior of eager mode. 
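        # (e.g. torch.ones(3).t() returns the same 1-D tensor unchanged)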
+ return g.op("Identity", self) + return g.op("Transpose", self, perm_i=(1, 0)) + + +@_onnx_symbolic("aten::numpy_T") +@symbolic_helper.quantized_args(True) +def numpy_T(g: jit_utils.GraphContext, input): + ndim = symbolic_helper._get_tensor_rank(input) + assert ndim is not None + perm = list(reversed(range(0, ndim))) + return g.op("Transpose", input, perm_i=perm) + + +@_onnx_symbolic("aten::expand") +@symbolic_helper.quantized_args(True) +def expand(g: jit_utils.GraphContext, self, size, implicit): + """Implement the expand function for a pytorch tensor in ONNX according to specified `size`""" + size = symbolic_helper._maybe_get_const(size, "is") + if not symbolic_helper._is_value(size): + size = g.op("Constant", value_t=torch.LongTensor(size)) + elif symbolic_helper._is_packed_list(size): + # Expand with -1 dim value means dim is unchanged. + # Since onnx::expand supports two-way broadcasting, + # -1 dim value can be exported to onnx as 1 + size = symbolic_helper._reshape_helper( + g, stack(g, size, 0), g.op("Constant", value_t=torch.tensor([-1])) + ) + dtype = _type_utils.JitScalarType.INT64 + ones = ones_like(g, size, dtype) + neg_ones = mul(g, ones, g.op("Constant", value_t=torch.tensor(-1))) + size = where(g, g.op("Equal", size, neg_ones), ones, size) + return g.op("Expand", self, size) + + +@_onnx_symbolic("aten::broadcast_to") +@symbolic_helper.quantized_args(True) +def broadcast_to(g: jit_utils.GraphContext, self, size): + size = symbolic_helper._maybe_get_const(size, "is") + if not symbolic_helper._is_value(size): + size = g.op("Constant", value_t=torch.LongTensor(size)) + elif symbolic_helper._is_packed_list(size): + # Expand with -1 dim value means dim is unchanged. + # Since onnx::expand supports two-way broadcasting, + # -1 dim value can be exported to onnx as 1 + size = symbolic_helper._reshape_helper( + g, stack(g, size, 0), g.op("Constant", value_t=torch.tensor([-1])) + ) + dtype = _type_utils.JitScalarType.INT64 + ones = ones_like(g, size, dtype) + neg_ones = mul(g, ones, g.op("Constant", value_t=torch.tensor(-1))) + size = where(g, g.op("Equal", size, neg_ones), ones, size) + return g.op("Expand", self, size) + + +@_onnx_symbolic("aten::expand_as") +@symbolic_helper.quantized_args(True, True) +def expand_as(g: jit_utils.GraphContext, self, other): + self_t = symbolic_helper._maybe_get_const(self, "t") + if isinstance(self_t, torch.Tensor): + orig_type = self_t.dtype + self_t = self_t.to(torch.double) + dims = [] + for d in range(self_t.dim()): + if torch.equal(self_t.mean(d).unsqueeze(d).expand_as(self_t), self_t): + dims.append(d) + self = g.op( + "Constant", value_t=self_t.mean(dims, keepdim=True).to(orig_type) + ) + + shape = g.op("Shape", other) + return g.op("Expand", self, shape) + + +@_onnx_symbolic("aten::embedding") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "v", "i", "b", "v") +def embedding( + g: jit_utils.GraphContext, + weight, + indices, + padding_idx, + scale_grad_by_freq, + sparse, +): + if scale_grad_by_freq and GLOBALS.export_training: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of embedding with scale_grad_by_freq=True " + "for training mode. ONNX does not support scaling the gradients.", + weight, + ) + if padding_idx >= 0 and GLOBALS.export_training: + warnings.warn( + "Warning: ONNX export of embedding with padding_idx >= 0 " + "for training mode. " + "ONNX does not support not updating the embedding vector at padding_idx during training." 
+ ) + + return g.op("Gather", weight, indices) + + +@_onnx_symbolic("aten::embedding_bag") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i") +def embedding_bag( + g: jit_utils.GraphContext, + embedding_matrix, + indices, + offsets, + scale_grad_by_freq, + mode, + sparse, + per_sample_weights, + include_last_offset, + padding_idx, +): + if not symbolic_helper._is_none(per_sample_weights): + return symbolic_helper._onnx_unsupported( + "embedding_bag with per_sample_weights" + ) + + return symbolic_helper._onnx_unsupported("embedding_bag", embedding_matrix) + + +@_onnx_symbolic("aten::size") +@symbolic_helper.quantized_args(True, quantize_output=False) +def size(g: jit_utils.GraphContext, self, dim=None): + if dim is None: + return g.op("Shape", self) + if symbolic_helper._maybe_get_const(dim, "i") < 0: + rank = symbolic_helper._get_tensor_rank(self) + if rank is not None: + dim = symbolic_helper._maybe_get_const(dim, "i") + rank + dim = g.op("Constant", value_t=torch.tensor(dim)) + return symbolic_helper._size_helper(g, self, dim) + + +@_onnx_symbolic("aten::transpose") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "i", "i") +def transpose(g: jit_utils.GraphContext, self, dim0, dim1): + if dim0 == dim1: # micro-optimization + return self + + # NB: Transpose in ONNX is actually a Permute + rank = symbolic_helper._get_tensor_rank(self) + if rank is not None: + axes = list(range(rank)) + axes[dim0], axes[dim1] = axes[dim1], axes[dim0] + return g.op("Transpose", self, perm_i=axes) + else: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of transpose for tensor of unknown rank.", + self, + ) + + +@_onnx_symbolic("aten::permute") +@symbolic_helper.parse_args("v", "is") +def permute(g: jit_utils.GraphContext, self, dims): + if dims == list(range(0, len(dims))): + return self + return g.op("Transpose", self, perm_i=dims) + + +@_onnx_symbolic("aten::view") +@symbolic_helper.quantized_args(True) +def view(g: jit_utils.GraphContext, self, size): + return reshape(g, self, size) + + +@_onnx_symbolic("aten::view_as") +def view_as(g: jit_utils.GraphContext, self, other): + shape = g.op("Shape", other) + return reshape(g, self, shape) + + +@_onnx_symbolic("aten::unsafe_chunk") +@symbolic_helper.parse_args("v", "i", "i", "i") +def unsafe_chunk(g: jit_utils.GraphContext, self, chunks, dim, _outputs=None): + if _outputs is None: + return symbolic_helper._onnx_opset_unsupported_detailed( + "unsafe_chunk", 9, 11, "Dynamic number of outputs not supported", self + ) + size = symbolic_helper._get_tensor_dim_size(self, dim) + if size is None: + return symbolic_helper._unimplemented( + "unsafe_chunk", "unknown dimension size", self + ) + split_size = (size + chunks - 1) // chunks + splits = [split_size] * (size // split_size) + leftover = size % split_size + if leftover: + splits.append(leftover) + return g.op("Split", self, split_i=splits, axis_i=dim, outputs=_outputs) + + +@_onnx_symbolic("aten::split") +@symbolic_helper.parse_args("v", "v", "i", "i") +def split(g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None): + if not symbolic_helper._is_split_static(split_size_or_sizes, _outputs): + return symbolic_helper._onnx_opset_unsupported_detailed( + "split", 9, 11, "Dynamic number of outputs not supported", self + ) + split_val = symbolic_helper._node_get(split_size_or_sizes.node(), "value") + if split_val.dim() > 0: + return split_with_sizes(g, self, split_size_or_sizes, dim, _outputs) + 
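    # Scalar split_size: emit equal chunks of split_size, plus a smaller trailing
    # chunk for any remainder along `dim`.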
split_size = symbolic_helper._get_const(split_size_or_sizes, "i", "split_size") + + size = symbolic_helper._get_tensor_dim_size(self, dim) + if size is None: + if _outputs is not None: + size = split_size * _outputs + else: + return symbolic_helper._onnx_opset_unsupported_detailed( + "split", 9, 11, "Unknown dimension size not supported", self + ) + splits = [split_size] * (size // split_size) + leftover = size % split_size + if leftover: + splits.append(leftover) + return g.op("Split", self, split_i=splits, axis_i=dim, outputs=_outputs) + + +@_onnx_symbolic("aten::unsafe_split") +def unsafe_split( + g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None +): + return split(g, self, split_size_or_sizes, dim, _outputs) + + +@_onnx_symbolic("aten::split_with_sizes") +@symbolic_helper.parse_args("v", "is", "i", "i") +def split_with_sizes(g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None): + if not symbolic_helper._is_split_static(split_sizes, _outputs): + return symbolic_helper._onnx_opset_unsupported_detailed( + "split_with_sizes", 9, 11, "Dynamic number of outputs not supported", self + ) + return g.op("Split", self, split_i=split_sizes, axis_i=dim, outputs=_outputs) + + +@_onnx_symbolic("aten::unsafe_split_with_sizes") +def unsafe_split_with_sizes( + g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None +): + return split_with_sizes(g, self, split_sizes, dim, _outputs) + + +@_onnx_symbolic("aten::unbind") +@symbolic_helper.parse_args("v", "i", "i") +def unbind(g: jit_utils.GraphContext, self, dim=0, _outputs=None): + if _outputs is None: + return symbolic_helper._onnx_opset_unsupported_detailed( + "unbind", 9, 11, "Dynamic number of outputs not supported", self + ) + + outputs = g.op("Split", self, split_i=[1] * _outputs, axis_i=dim, outputs=_outputs) + outputs = [outputs] if _outputs == 1 else outputs + squeezed_outputs = [ + symbolic_helper._squeeze_helper(g, out, [dim]) for out in outputs + ] + return squeezed_outputs + + +@_onnx_symbolic("aten::select") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "i", "v") +def select(g: jit_utils.GraphContext, self, dim, index): + """Implement the select functionality for a pytorch tensor in ONNX. + + Selects elements from the input tensor along the specified `dim` dimension based on the `index` tensor. + """ + index = symbolic_helper._maybe_get_scalar(index) + if (not symbolic_helper._is_value(index)) and (index < 0): + if index == -1: + end_index = _constants.INT64_MAX + else: + end_index = index + 1 + slice_node = symbolic_helper._slice_helper( + g, self, axes=[dim], starts=[index], ends=[end_index] + ) + return symbolic_helper._squeeze_helper(g, slice_node, [dim]) + else: + # FIXME(justinchuby): can index be an int and not a value? + return g.op("Gather", self, index, axis_i=dim) + + +@_onnx_symbolic("aten::square") +def square(g: jit_utils.GraphContext, self): + return g.op("Mul", self, self) + + +@_onnx_symbolic("aten::squeeze") +def squeeze(g: jit_utils.GraphContext, self, dim=None): + if dim is None: + return g.op("Squeeze", self) + + squeeze_dim = symbolic_helper._get_const(dim, "i", "dim") + # Handle negative dims + if squeeze_dim < 0: + rank = symbolic_helper._get_tensor_rank(self) + if rank is not None: + warnings.warn( + "ONNX export squeeze with negative axis " + + str(squeeze_dim) + + " might cause the onnx model to be incorrect. " + + "Negative axis is not supported in ONNX. 
" + + "Axis is converted to " + + str(squeeze_dim + rank) + + " based on input shape at export time. " + + "Passing an tensor of different rank in execution will be incorrect." + ) + squeeze_dim += rank + else: + return symbolic_helper._unimplemented( + "squeeze", "negative axis with unknown input rank", self + ) + + dim_size = symbolic_helper._get_tensor_dim_size(self, squeeze_dim) + if dim_size is None: + warnings.warn( + "This model contains a squeeze operation on dimension " + + str(squeeze_dim) + + " on an input " + + "with unknown shape. Note that if the size of dimension " + + str(squeeze_dim) + + " of the input " + + "is not 1, the ONNX model will return an error. Opset version 11 supports squeezing on " + + "non-singleton dimensions, it is recommended to export this model using opset " + + "version 11 or higher." + ) + return symbolic_helper._squeeze_helper(g, self, axes_i=[squeeze_dim]) + if dim_size > 1: + warnings.warn( + "This model contains a squeeze operation on dimension " + + str(squeeze_dim) + + ". The size of " + + "this dimension in the given input is " + + str(dim_size) + + ". The model will " + + "be exported without the squeeze node. If the model is intended to be used with dynamic " + + "input shapes, please use opset version 11 to " + + "export the model." + ) + return self + + warnings.warn( + "This model contains a squeeze operation on dimension " + + str(squeeze_dim) + + ". If the model is " + + "intended to be used with dynamic input shapes, please use opset version 11 to export the model." + ) + return symbolic_helper._squeeze_helper(g, self, axes_i=[squeeze_dim]) + + +@_onnx_symbolic("aten::prelu") +def prelu(g: jit_utils.GraphContext, self, weight): + self_rank = symbolic_helper._get_tensor_rank(self) + weight_sizes = symbolic_helper._get_tensor_sizes(weight) + weight_rank = len(weight_sizes) + if self_rank is not None: + if self_rank > 2: + # make weight unidirectional broadcastable + weight = symbolic_helper._unsqueeze_helper( + g, weight, list(range(1, self_rank - 1)) + ) + elif self_rank == 0 and weight_sizes == [1]: + # self and weight are both scalar but weight has rank == 1, squeeze weight. 
+ weight = symbolic_helper._squeeze_helper(g, weight, [0]) + weight_rank = 0 + + if self_rank is not None and weight_rank is not None: + assert self_rank >= weight_rank, ( + f"rank(x) should be >= rank(slope) but got {self_rank} < {weight_rank}" + ) + return g.op("PRelu", self, weight) + + +@_onnx_symbolic("aten::silu") +def silu(g: jit_utils.GraphContext, input): + return g.op("Mul", input, g.op("Sigmoid", input)) + + +@_onnx_symbolic("aten::mish") +def mish(g: jit_utils.GraphContext, input): + return g.op("Mul", input, g.op("Tanh", g.op("Softplus", input))) + + +@_onnx_symbolic("aten::relu") +@symbolic_helper.quantized_args(True) +def relu(g: jit_utils.GraphContext, input): + return symbolic_helper._op_with_optional_float_cast( + g, "Relu", input, opset_before=14 + ) + + +@_onnx_symbolic("aten::relu6") +@symbolic_helper.quantized_args(True) +def relu6(g: jit_utils.GraphContext, input): + return clamp(g, input, 0, 6) + + +@_onnx_symbolic("aten::ceil") +def ceil(g: jit_utils.GraphContext, input): + return g.op("Ceil", input) + + +@_onnx_symbolic("aten::floor") +def floor(g: jit_utils.GraphContext, input): + return g.op("Floor", input) + + +@_onnx_symbolic("aten::len") +def _len(g: jit_utils.GraphContext, self): + sz_0 = size(g, self, g.op("Constant", value_t=torch.LongTensor([0]))) + return symbolic_helper._squeeze_helper(g, sz_0, [0]) + + +@_onnx_symbolic("aten::threshold") +@symbolic_helper.parse_args("v", "t", "t") +def threshold(g: jit_utils.GraphContext, self, threshold, value): + # See Note [Export inplace] + if symbolic_helper._scalar(threshold) != 0: + return symbolic_helper._unimplemented("threshold", "non-zero threshold", self) + if symbolic_helper._scalar(value) != 0: + return symbolic_helper._unimplemented("threshold", "non-zero value", self) + return g.op("Relu", self) + + +@_onnx_symbolic("aten::leaky_relu") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "f", "b") +def leaky_relu( + g: jit_utils.GraphContext, + input: _C.Value, + negative_slope: float, + inplace: bool = False, +): + # See Note [Export inplace] + return g.op("LeakyRelu", input, alpha_f=negative_slope) + + +@_onnx_symbolic("aten::glu") +@symbolic_helper.parse_args("v", "i") +def glu(g: jit_utils.GraphContext, input, dim): + dim_size = symbolic_helper._get_tensor_dim_size(input, dim) + if dim_size is not None: + assert dim_size % 2 == 0 + + first, second = g.op("Split", input, axis_i=dim, outputs=2) + return g.op("Mul", first, g.op("Sigmoid", second)) + + +@_onnx_symbolic("aten::softmax") +@symbolic_helper.parse_args("v", "i", "none") +def softmax(g: jit_utils.GraphContext, input, dim, dtype=None): + # Softmax does normalization at vector level. + # PyTorch and ONNX use different strategies to split the input tensor into vectors. + # Thus dim and axis have different meanings. + # PyTorch slices the input tensor into vectors along the `dim`-th dimension. + # ONNX reshapes the input into a 2-D tensor, and `axis` indicates where the input is coerced. + # If input is a 2 x 3 tensor: + # input = [[1.0, 1.0, 1.0], + # [1.0, 1,0, 1,0]] + # with dim = 0, the result is: + # result = [[0.5, 0.5, 0.5], + # [0.5, 0.5, 0.5]] + # with axis = 0, the result is: + # result = [[0.167, 0.167, 0.167], + # [0.167, 0.167, 0.167]] + # So only when dim and axis both equal to ndim - 1 (the last dimension), + # their semantics are equivalent. + # So use softmax when dim and axis both equal to ndim - 1, + # otherwise transpose the input to put the vectors to be normalized to the last dimension. 
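    # For example, softmax(input, dim=0) on the 2 x 3 tensor above is exported as
    # Transpose(perm=[1, 0]) -> Softmax(axis=1) -> Transpose(perm=[1, 0]).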
+ # When input rank is not known at export time we compute softmax using a subgraph + # with other operators + input_dim = symbolic_helper._get_tensor_rank(input) + if input_dim is not None: + # TODO: remove this as onnx opset 11 spec allows negative axes + if dim < 0: + dim = input_dim + dim + + is_transpose_required = input_dim != dim + 1 + + if is_transpose_required: + axes = list(range(input_dim)) + axes[dim], axes[-1] = axes[-1], axes[dim] + input = g.op("Transpose", input, perm_i=axes) + dim = input_dim - 1 + + softmax = g.op("Softmax", input, axis_i=dim) + if dtype and dtype.node().kind() != "prim::Constant": + parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype") + softmax = g.op( + "Cast", + softmax, + to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type(), + ) + + if is_transpose_required: + softmax = g.op("Transpose", softmax, perm_i=axes) # type: ignore[possibly-undefined] + return softmax + + # Apply max normalization. + input = g.op("Sub", input, g.op("ReduceMax", input, axes_i=[dim], keepdims_i=1)) + + exp = g.op("Exp", input) + sum = symbolic_helper._reducesum_helper(g, exp, axes_i=[dim]) + softmax = g.op("Div", exp, sum) + if dtype and dtype.node().kind() != "prim::Constant": + parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype") + softmax = g.op( + "Cast", softmax, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type() + ) + return softmax + + +@_onnx_symbolic("aten::softplus") +def softplus(g: jit_utils.GraphContext, self, beta, threshold): + beta_const = symbolic_helper._maybe_get_const(beta, "f") + if beta_const != 1: + return g.op("Div", g.op("Softplus", g.op("Mul", self, beta)), beta) + return g.op("Softplus", self) + + +@_onnx_symbolic("aten::get_pool_ceil_padding") +def get_pool_ceil_padding(input, kernel_size, stride, padding): + # TODO(justinchuby): Looks like this op is deprecated in torch + sizes = symbolic_helper._get_tensor_sizes(input) + dim = sizes[-len(padding) :] if sizes is not None else None + if dim is None or any(i is None for i in dim): + return symbolic_helper._unimplemented( + "get_pool_ceil_padding", "input size not accessible", input + ) + ceiled_output_dim = [ + int(math.ceil((dim[i] + 2 * padding[i] - kernel_size[i]) / float(stride[i]))) + + 1 + for i in range(0, len(padding)) + ] + # ensure last pooling starts inside + ceiled_output_dim = [ + ( + ceiled_output_dim[i] - 1 + if (((ceiled_output_dim[i] - 1) * stride[i]) >= (dim[i] + padding[i])) + else ceiled_output_dim[i] + ) + for i in range(0, len(ceiled_output_dim)) + ] + padding_ceil = [ + ( + 0 + if (stride[i] == 1) + else ( + kernel_size[i] + - ( + dim[i] + + 2 * padding[i] + - ((ceiled_output_dim[i] - 1) * stride[i] + 1) + ) + ) + ) + for i in range(0, len(padding)) + ] + # ensure padding is not > kernel_size + padding_ceil = [ + ( + ( + int(padding_ceil[i]) + if padding_ceil[i] < kernel_size[i] - 1 + else int(kernel_size[i] - 1) + ) + if ((padding_ceil[i] + 2 * padding[i]) >= (kernel_size[i])) + else int(padding_ceil[i]) + ) + for i in range(0, len(padding_ceil)) + ] + return padding_ceil + + +@_onnx_symbolic( + "aten::max_pool1d", + decorate=[ + symbolic_helper._apply_params( + "max_pool1d", torch.nn.modules.utils._single, 1, return_indices=False + ), + _export("max_pool1d"), + ], +) +@_onnx_symbolic( + "aten::max_pool2d", + decorate=[ + symbolic_helper._apply_params( + "max_pool2d", torch.nn.modules.utils._pair, 2, return_indices=False + ), + _export("max_pool2d"), + ], +) +@_onnx_symbolic( + "aten::max_pool3d", + decorate=[ + symbolic_helper._apply_params( + 
"max_pool3d", torch.nn.modules.utils._triple, 3, return_indices=False + ), + _export("max_pool3d"), + ], +) +def _max_pool(name, tuple_fn, ndims, return_indices): + @symbolic_helper.quantized_args(True, False, False, False, False, False) + @symbolic_helper.parse_args("v", "is", "is", "is", "is", "i") + def symbolic_fn(g, input, kernel_size, stride, padding, dilation, ceil_mode): + if set(tuple_fn(dilation)) != {1}: + return symbolic_helper._unimplemented(name, "dilation", input) + if not stride: + stride = kernel_size + padding = tuple(tuple_fn(padding)) + if ceil_mode: + padding_ceil = get_pool_ceil_padding(input, kernel_size, stride, padding) + padding = padding + tuple(a + b for (a, b) in zip(padding_ceil, padding)) + else: + padding = padding * 2 + kwargs = { + "kernel_shape_i": tuple_fn(kernel_size), + "pads_i": padding, + "strides_i": tuple_fn(stride), + } + # easy but hacky way to get flattened indices values + # to be used to convert the indices values to non-flattened. + # In ONNX the indices are computed as a flatten 1-D tensor, + # so the values in indices are in [0, N x C x D1 x ... x Dn). + # To convert the indices to the same format used by Pytorch, + # we first execute a maxpool with a kernel and stride of 1 on the same input. + # This will result in a tensor of indices in which each index will have it's own value. + # Using this tensor as a reference, we extract the first index of each axis and subtract + # it from each index of this axis in the indices to convert. + # This step will result in a tensor were each dimension has values of indices within + # the dimension it is in. + # For more information : + # https://github.com/pytorch/pytorch/pull/16455#issuecomment-460776407 + if return_indices: + r, indices = g.op("MaxPool", input, outputs=2, **kwargs) + _, flattened_indices = g.op( + "MaxPool", + input, + outputs=2, + kernel_shape_i=[1 for _ in range(ndims)], + strides_i=[1 for _ in range(ndims)], + ) + # convert indices to have non-flattened indices values + s = symbolic_helper._slice_helper( + g, + flattened_indices, + axes=[2 + i for i in range(ndims)], + starts=list(tuple_fn(0)), + ends=list(tuple_fn(1)), + ) + indices = sub(g, indices, s) + return r, indices + else: + r = g.op("MaxPool", input, outputs=1, **kwargs) + return r + + return symbolic_fn + + +max_pool1d_with_indices = _onnx_symbolic("aten::max_pool1d_with_indices")( + _max_pool( + "max_pool1d_with_indices", + torch.nn.modules.utils._single, + 1, + return_indices=True, + ) +) +max_pool2d_with_indices = _onnx_symbolic("aten::max_pool2d_with_indices")( + _max_pool( + "max_pool2d_with_indices", + torch.nn.modules.utils._pair, + 2, + return_indices=True, + ) +) +max_pool3d_with_indices = _onnx_symbolic("aten::max_pool3d_with_indices")( + _max_pool( + "max_pool3d_with_indices", + torch.nn.modules.utils._triple, + 3, + return_indices=True, + ) +) + + +@_onnx_symbolic( + "aten::avg_pool1d", + decorate=[ + symbolic_helper._apply_params("avg_pool1d", torch.nn.modules.utils._single), + _export("avg_pool1d"), + ], +) +@_onnx_symbolic( + "aten::avg_pool2d", + decorate=[ + symbolic_helper._apply_params("avg_pool2d", torch.nn.modules.utils._pair), + _export("avg_pool2d"), + ], +) +@_onnx_symbolic( + "aten::avg_pool3d", + decorate=[ + symbolic_helper._apply_params("avg_pool3d", torch.nn.modules.utils._triple), + _export("avg_pool3d"), + ], +) +def _avg_pool(name, tuple_fn): + @symbolic_helper.quantized_args(True) + @symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none") + def symbolic_fn( + g, + input: 
_C.Value, + kernel_size: Sequence[int], + stride: Sequence[int], + padding: int | Sequence[int], + ceil_mode: int, + count_include_pad: int, + divisor_override=None, + ): + if not stride: + stride = kernel_size + padding = symbolic_helper._avgpool_helper( + tuple_fn, padding, kernel_size, stride, divisor_override, name + ) + assert isinstance(padding, tuple) + adjusted_padding = padding + # Although onnx::AvgPool provides count_include_pad, + # The corner case of Average Pooling with ceil_mode on + # PyTorch allows sliding window go off bound, which leads to + # this accommodation. + # More detail on https://github.com/pytorch/pytorch/issues/57178 + if count_include_pad: + input = symbolic_helper._op_with_optional_float_cast( + g, + "Pad", + input, + pads_i=((0,) * 2 + padding) * 2, + mode_s="constant", + value_f=0.0, + opset_before=11, + ) + adjusted_padding = (0,) * len(padding) + if ceil_mode: + padding_ceil = get_pool_ceil_padding(input, kernel_size, stride, padding) + adjusted_padding = adjusted_padding + tuple( + a + b for (a, b) in zip(padding_ceil, adjusted_padding) + ) + else: + adjusted_padding = adjusted_padding * 2 + output = g.op( + "AveragePool", + input, + kernel_shape_i=tuple_fn(kernel_size), + strides_i=tuple_fn(stride), + pads_i=adjusted_padding, + ) + return output + + return symbolic_fn + + +@_onnx_symbolic( + "aten::adaptive_avg_pool1d", + decorate=[ + symbolic_helper._apply_params( + "adaptive_avg_pool1d", "AveragePool", torch.nn.modules.utils._single + ), + _export("adaptive_avg_pool1d"), + ], +) +@_onnx_symbolic( + "aten::adaptive_avg_pool2d", + decorate=[ + symbolic_helper._apply_params( + "adaptive_avg_pool2d", "AveragePool", torch.nn.modules.utils._pair + ), + _export("adaptive_avg_pool2d"), + ], +) +@_onnx_symbolic( + "aten::adaptive_avg_pool3d", + decorate=[ + symbolic_helper._apply_params( + "adaptive_avg_pool3d", "AveragePool", torch.nn.modules.utils._triple + ), + _export("adaptive_avg_pool3d"), + ], +) +@_onnx_symbolic( + "aten::adaptive_max_pool1d", + decorate=[ + symbolic_helper._apply_params( + "adaptive_max_pool1d", + "MaxPool", + torch.nn.modules.utils._single, + max_pool1d_with_indices, + ), + _export("adaptive_max_pool1d"), + ], +) +@_onnx_symbolic( + "aten::adaptive_max_pool2d", + decorate=[ + symbolic_helper._apply_params( + "adaptive_max_pool2d", + "MaxPool", + torch.nn.modules.utils._pair, + max_pool2d_with_indices, + ), + _export("adaptive_max_pool2d"), + ], +) +@_onnx_symbolic( + "aten::adaptive_max_pool3d", + decorate=[ + symbolic_helper._apply_params( + "adaptive_max_pool3d", + "MaxPool", + torch.nn.modules.utils._triple, + max_pool3d_with_indices, + ), + _export("adaptive_max_pool3d"), + ], +) +def _adaptive_pool(name, type, tuple_fn, fn=None): + @symbolic_helper.quantized_args(True, False) + def symbolic_fn(g, input, output_size): + # _adaptive_pool is supported for cases where output_size is 1 for all dimensions, + # by executing a GlobalPool. + # It is also supported for cases where the output size is a factor of the input size. + # For these cases the stride and kernel size are uniform along all the indices of + # the same dimension, which makes it possible to export it to ONNX. 
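        # For example, adaptive_avg_pool2d with output_size (2, 2) on an 8 x 8 spatial
        # input becomes AveragePool with kernel_shape = strides = (4, 4).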
+ # for MaxPool, GlobalMaxPool does not return indices, + # so we try using max_poolxd_with_indices, and if it is not possible + # (input is not a complete tensor or output size not factor of input size) + # then we call GlobalAveragePool and return None for the indices + output_size_value = output_size + try: + output_size = symbolic_helper._parse_arg(output_size, "is") + except Exception: + # FIXME(justinchuby): Avoid catching Exception. + # Catch a more specific exception instead. + return symbolic_helper._onnx_unsupported( + "adaptive pooling, since output_size is not constant.", input + ) + if output_size == [1] * len(output_size) and type == "AveragePool": + return g.op("GlobalAveragePool", input) + sizes = symbolic_helper._get_tensor_sizes(input) + try: + dim = sizes[2:] + except Exception: + # FIXME(justinchuby): Avoid catching Exception. + # Catch a more specific exception instead. + dim = None + if dim is None or any(i is None for i in dim): + if output_size == [1] * len(output_size): + return g.op("GlobalMaxPool", input), None + return symbolic_helper._unimplemented( + name, "input size not accessible", input + ) + # verify if output size % input size = 0 for all dim + mod = [dim[i] % output_size[i] for i in range(0, len(dim))] + if mod != [0] * len(mod): + if output_size == [1] * len(output_size): + return g.op("GlobalMaxPool", input), None + return symbolic_helper._unimplemented( + name, "output size that are not factor of input size", output_size_value + ) + k = [int(dim[i] / output_size[i]) for i in range(0, len(dim))] + # call max_poolxd_with_indices to get indices in the output + if type == "MaxPool": + return fn(g, input, k, k, (0,) * len(dim), (1,) * len(dim), False) + output = g.op(type, input, kernel_shape_i=tuple_fn(k), strides_i=tuple_fn(k)) + return output + + return symbolic_fn + + +def _prepare_onnx_paddings(dim: int, pad): + """Generate paddings in ONNX order based on pad in pytorch. + Args: + dim: the dimension of the tensor. + pad: the paddings in pytorch. + The order is dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, ... + """ + # The desired order of paddings is + # dim_0_begin, dim_1_begin, ... , dim_0_end, ..., dim_n_end. + # n is the dimension of input. + # assume zero-dimensions in the beginning + paddings = list(pad[:]) + [0] * (dim * 2 - len(pad)) + # reverse order and collate first beginnings and then ends + paddings = paddings[-2::-2] + paddings[-1::-2] + return paddings + + +def _convert_padding_node(input): + padding = symbolic_helper._maybe_get_const(input, "is") + if symbolic_helper._is_value(padding) and symbolic_helper._is_packed_list(padding): + input_list = symbolic_helper._unpack_list(padding) + try: + padding = [ + symbolic_helper._get_const(v, "i", "padding") for v in input_list + ] + except Exception: + # FIXME(justinchuby): Avoid catching Exception. + # Catch a more specific exception instead. + return symbolic_helper._onnx_opset_unsupported_detailed( + "Pad", 9, 11, "The sizes of the padding must be constant", input + ) + return padding + + +@_onnx_symbolic("aten::constant_pad_nd") +def constant_pad_nd(g: jit_utils.GraphContext, input, padding, value): + mode = "constant" + try: + value = symbolic_helper._get_const(value, "f", "value") + except Exception: + # FIXME(justinchuby): Avoid catching Exception. + # Catch a more specific exception instead. 
+ return symbolic_helper._onnx_opset_unsupported_detailed( + "Pad", 9, 11, "The value for the padding must be constant", value + ) + + padding = _convert_padding_node(padding) + paddings = _prepare_onnx_paddings(symbolic_helper._get_tensor_rank(input), padding) + return symbolic_helper._op_with_optional_float_cast( + g, "Pad", input, pads_i=paddings, mode_s=mode, value_f=value, opset_before=11 + ) + + +def _pad_circular(g: jit_utils.GraphContext, input: _C.Value, pad: _C.Value): + padding = _convert_padding_node(pad) + assert len(padding) % 2 == 0 + ndim = len(padding) // 2 + + cur = input + for idx in range(ndim): + pad_r = padding[-(2 * idx + 1)] + pad_l = padding[-(2 * idx + 2)] + tensors = [] + if pad_l > 0: + left = symbolic_helper._slice_helper( + g, cur, axes=[2 + idx], starts=[-(pad_l)], ends=[_constants.INT64_MAX] + ) + tensors.append(left) + + if pad_l < 0 or pad_r < 0: + start = builtins.max(0, -pad_l) + end = -(builtins.max(0, -pad_r)) + middle = symbolic_helper._slice_helper( + g, + cur, + axes=[2 + idx], + starts=[start], + ends=[end], + ) + tensors.append(middle) + else: + tensors.append(cur) + + if pad_r > 0: + right = symbolic_helper._slice_helper( + g, cur, axes=[2 + idx], starts=[0], ends=[pad_r] + ) + tensors.append(right) + + cur = g.op("Concat", *tensors, axis_i=(2 + idx)) + + return cur + + +@_onnx_symbolic("aten::reflection_pad1d") +@_onnx_symbolic("aten::reflection_pad2d") +@_onnx_symbolic("aten::reflection_pad3d") +def reflection_pad(g: jit_utils.GraphContext, input, padding): + mode = "reflect" + padding = _convert_padding_node(padding) + paddings = _prepare_onnx_paddings(symbolic_helper._get_tensor_rank(input), padding) + return symbolic_helper._op_with_optional_float_cast( + g, "Pad", input, pads_i=paddings, mode_s=mode, opset_before=11 + ) + + +@_onnx_symbolic("aten::replication_pad1d") +@_onnx_symbolic("aten::replication_pad2d") +@_onnx_symbolic("aten::replication_pad3d") +def replication_pad(g: jit_utils.GraphContext, input, padding): + mode = "edge" + padding = _convert_padding_node(padding) + paddings = _prepare_onnx_paddings(symbolic_helper._get_tensor_rank(input), padding) + return symbolic_helper._op_with_optional_float_cast( + g, "Pad", input, pads_i=paddings, mode_s=mode, opset_before=11 + ) + + +@_onnx_symbolic("aten::pad") +def pad( + g: jit_utils.GraphContext, + input: _C.Value, + pad: _C.Value, + mode: _C.Value, + value: _C.Value, +): + mode = symbolic_helper._parse_arg(mode, "s") + if mode == "replicate": + return replication_pad(g, input, pad) + elif mode == "reflect": + return reflection_pad(g, input, pad) + elif mode == "constant": + return constant_pad_nd(g, input, pad, value) + elif mode == "circular": + return _pad_circular(g, input, pad) + else: + raise errors.SymbolicValueError(f"Unrecognized padding mode {mode}", input) + + +@_onnx_symbolic( + "aten::upsample_nearest1d", + decorate=[ + symbolic_helper._apply_params("upsample_nearest1d", 3, "nearest"), + _export("upsample_nearest1d"), + ], +) +@_onnx_symbolic( + "aten::upsample_nearest2d", + decorate=[ + symbolic_helper._apply_params("upsample_nearest2d", 4, "nearest"), + _export("upsample_nearest2d"), + ], +) +@_onnx_symbolic( + "aten::upsample_nearest3d", + decorate=[ + symbolic_helper._apply_params("upsample_nearest3d", 5, "nearest"), + _export("upsample_nearest3d"), + ], +) +@_onnx_symbolic( + "aten::upsample_linear1d", + decorate=[ + symbolic_helper._apply_params("upsample_linear1d", 3, "linear"), + _export("upsample_linear1d"), + ], +) +@_onnx_symbolic( + 
"aten::upsample_bilinear2d", + decorate=[ + symbolic_helper._apply_params("upsample_bilinear2d", 4, "linear"), + _export("upsample_bilinear2d"), + ], +) +@_onnx_symbolic( + "aten::upsample_trilinear3d", + decorate=[ + symbolic_helper._apply_params("upsample_trilinear3d", 5, "linear"), + _export("upsample_trilinear3d"), + ], +) +def _interpolate(name: str, dim: int, interpolate_mode: str): + def symbolic_fn(g, input, output_size, *args): + scales, align_corners = symbolic_helper._get_interpolate_attributes( + g, interpolate_mode, args + ) + symbolic_helper._interpolate_warning(interpolate_mode) + align_corners = symbolic_helper._maybe_get_scalar(align_corners) + if align_corners: + return symbolic_helper._unimplemented(name, "align_corners == True", input) + if scales is None: + scales = symbolic_helper._interpolate_size_to_scales( + g, input, output_size, dim + ) + return g.op("Upsample", input, scales, mode_s=interpolate_mode) + + return symbolic_fn + + +@_onnx_symbolic("aten::__interpolate") +def __interpolate( + g: jit_utils.GraphContext, + input, + size, + scale_factor, + mode, + align_corners, + recompute_scale_factor, + antialias, +): + scales, mode = symbolic_helper._interpolate_get_scales_and_mode( + g, input, size, scale_factor, mode, align_corners + ) + return g.op("Upsample", input, scales, mode_s=mode) + + +@_onnx_symbolic("aten::bitwise_not") +def bitwise_not(g: jit_utils.GraphContext, input): + if not symbolic_helper._is_bool(input): + raise errors.SymbolicValueError( + "ONNX export does NOT support exporting bitwise Not " + "for non-boolean input values", + input, + ) + return g.op("Not", input) + + +@_onnx_symbolic("aten::bitwise_or") +def bitwise_or(g, self, other): + if not symbolic_helper._is_bool(self): + raise errors.SymbolicValueError( + "ONNX export does NOT support exporting bitwise OR " + "for non-boolean input values. self: ", + self, + ) + if not symbolic_helper._is_bool(other): + raise errors.SymbolicValueError( + "ONNX export does NOT support exporting bitwise OR " + "for non-boolean input values. other: ", + other, + ) + return g.op("Or", self, other) + + +def wrap_logical_op_with_cast_to(to_type): + def decorator(fn): + @functools.wraps(fn) + def wrap_with_cast(g, input, other): + to_cast_func = globals()[f"_cast_{to_type}"] + return fn(g, to_cast_func(g, input, False), to_cast_func(g, other, False)) + + return wrap_with_cast + + return decorator + + +def wrap_logical_op_with_negation(func: Callable) -> Callable: + @functools.wraps(func) + def wrap_with_not(g, input, other): + return g.op("Not", func(g, input, other)) + + return wrap_with_not + + +@_onnx_symbolic("aten::__not_") +def __not_(g: jit_utils.GraphContext, self): + if not symbolic_helper._is_bool(self): + raise errors.SymbolicValueError( + "ONNX export does NOT support exporting bitwise Not " + "for non-boolean input values", + self, + ) + return g.op("Not", self) + + +@_onnx_symbolic("aten::eq") +@symbolic_helper.quantized_args(True, True) +def eq(g: jit_utils.GraphContext, self, other): + if isinstance(self.type(), _C.DeviceObjType) and isinstance( + other.type(), _C.DeviceObjType + ): + # ONNX doesn't have devices, so consider them all to be equal. + # The no-op check for equality will get constant-folded. 
+ return g.op("Constant", value_t=torch.tensor(True, dtype=torch.bool)) + self_node = self.node() + other_node = other.node() + if self_node.kind() == other_node.kind() == "onnx::Constant": + if self_node.kindOf("value") == other_node.kindOf("value") == "s": + # Exporting strings to ONNX is not supported. + # If both strings are constant, we can compare them directly. + # The no-op check for equality will get constant-folded. + return g.op( + "Constant", + value_t=torch.tensor( + self_node.s("value") == other_node.s("value"), + dtype=torch.bool, + ), + ) + + return g.op("Equal", self, other) + + +@_onnx_symbolic("aten::ne") +@symbolic_helper.quantized_args(True, True) +@wrap_logical_op_with_negation +def ne(g: jit_utils.GraphContext, self, other): + return eq(g, self, other) + + +@_onnx_symbolic("aten::gt") +@symbolic_helper.quantized_args(True, True) +def gt(g: jit_utils.GraphContext, input, other): + return _gt_impl(g, input, other) + + +def _gt_impl(g: jit_utils.GraphContext, input, other): + if symbolic_helper._is_bool(input) and symbolic_helper._is_bool(other): + input = g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT32) + other = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.INT32) + return g.op("Greater", input, other) + + +@_onnx_symbolic("aten::lt") +@symbolic_helper.quantized_args(True, True) +def lt(g: jit_utils.GraphContext, input, other): + return _lt_impl(g, input, other) + + +def _lt_impl(g: jit_utils.GraphContext, input, other): + if symbolic_helper._is_bool(input) and symbolic_helper._is_bool(other): + input = g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT32) + other = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.INT32) + return g.op("Less", input, other) + + +@_onnx_symbolic("aten::ge") +@symbolic_helper.quantized_args(True, True) +@wrap_logical_op_with_negation +def ge(g: jit_utils.GraphContext, input, other): + return _lt_impl(g, input, other) + + +@_onnx_symbolic("aten::le") +@symbolic_helper.quantized_args(True, True) +@wrap_logical_op_with_negation +def le(g: jit_utils.GraphContext, input, other): + return _gt_impl(g, input, other) + + +@_onnx_symbolic("aten::__and_") +def __and_(g: jit_utils.GraphContext, input, other): + if not symbolic_helper._is_bool(input): + raise errors.SymbolicValueError( + "ONNX export does NOT support exporting bitwise AND " + "for non-boolean input values", + input, + ) + if not symbolic_helper._is_bool(other): + raise errors.SymbolicValueError( + "ONNX export does NOT support exporting bitwise AND " + "for non-boolean input values", + other, + ) + return g.op("And", input, other) + + +@_onnx_symbolic("aten::__or_") +def __or_(g: jit_utils.GraphContext, input, other): + if not symbolic_helper._is_bool(input): + raise errors.SymbolicValueError( + "ONNX export does NOT support exporting bitwise OR " + "for non-boolean input values", + input, + ) + if not symbolic_helper._is_bool(other): + raise errors.SymbolicValueError( + "ONNX export does NOT support exporting bitwise OR " + "for non-boolean input values", + other, + ) + return g.op("Or", input, other) + + +@_onnx_symbolic("aten::__xor_") +def __xor_(g: jit_utils.GraphContext, input, other): + if not symbolic_helper._is_bool(input): + raise errors.SymbolicValueError( + "ONNX export does NOT support exporting bitwise XOR " + "for non-boolean input values", + input, + ) + if not symbolic_helper._is_bool(other): + raise errors.SymbolicValueError( + "ONNX export does NOT support exporting bitwise XOR " + "for non-boolean input values", + other, + ) + return 
g.op("Xor", input, other) + + +@_onnx_symbolic("aten::logical_and") +@wrap_logical_op_with_cast_to("Bool") +def logical_and(g: jit_utils.GraphContext, input, other): + return g.op("And", input, other) + + +@_onnx_symbolic("aten::logical_or") +@wrap_logical_op_with_cast_to("Bool") +def logical_or(g: jit_utils.GraphContext, input, other): + return g.op("Or", input, other) + + +@_onnx_symbolic("aten::logical_xor") +@wrap_logical_op_with_cast_to("Bool") +def logical_xor(g: jit_utils.GraphContext, input, other): + return g.op("Xor", input, other) + + +@_onnx_symbolic("aten::logical_not") +def logical_not(g: jit_utils.GraphContext, input): + return g.op("Not", g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.BOOL)) + + +@_onnx_symbolic("aten::__rshift_") +def __rshift_(g: jit_utils.GraphContext, self, other): + # make sure to cast other to self's type + # (when self is long, make sure that other is not float) + self_scalar_type = _type_utils.JitScalarType.from_value(self) + if ( + _type_utils.JitScalarType.from_value(other, _type_utils.JitScalarType.UNDEFINED) + != self_scalar_type + ): + other = g.op( + "Cast", + other, + to_i=self_scalar_type.onnx_type(), + ) + + two = g.op("Constant", value_t=torch.tensor(2, dtype=torch.float32)) + # exponent (same type as self) has to be float or double in onnx::Pow + if not symbolic_helper._is_fp(self): + other = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.FLOAT) + two_pow = g.op("Pow", two, other) + two_pow = g.op( + "Cast", + two_pow, + to_i=self_scalar_type.onnx_type(), + ) + rshift = g.op("Div", self, two_pow) + return rshift + + +@_onnx_symbolic("aten::__lshift_") +def __lshift_(g: jit_utils.GraphContext, self, other): + # make sure to cast other to self's type + # (when self is long, make sure that other is not float) + self_scalar_type = _type_utils.JitScalarType.from_value(self) + if ( + _type_utils.JitScalarType.from_value(other, _type_utils.JitScalarType.UNDEFINED) + != self_scalar_type + ): + other = g.op( + "Cast", + other, + to_i=self_scalar_type.onnx_type(), + ) + + two = g.op("Constant", value_t=torch.tensor(2, dtype=torch.float32)) + # exponent (same type as self) has to be float or double in onnx::Pow + if not symbolic_helper._is_fp(self): + other = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.FLOAT) + two_pow = g.op("Pow", two, other) + two_pow = g.op( + "Cast", + two_pow, + to_i=self_scalar_type.onnx_type(), + ) + lshift = g.op("Mul", self, two_pow) + return lshift + + +@_onnx_symbolic("aten::where") +@symbolic_helper.parse_args("v", "v", "v", "i") +def where(g: jit_utils.GraphContext, condition, self=None, other=None, _outputs=None): + # Assumes that torch.where's first argument takes only Bool and Byte tensors. + if not symbolic_helper._is_bool(condition): + condition = g.op("Cast", condition, to_i=_C_onnx.TensorProtoDataType.BOOL) + if self is None: + condition = nonzero(g, condition) + return symbolic_helper._unbind_helper( + g, condition, g.op("Constant", value_t=torch.tensor(1)), _outputs + ) + return g.op("Where", condition, self, other) + + +@_onnx_symbolic("aten::log_softmax") +@symbolic_helper.parse_args("v", "i", "none") +def log_softmax(g: jit_utils.GraphContext, input, dim, dtype=None): + # PyTorch dim and ONNX axis have different meanings. + # See Softmax comment for details. 
+ # TODO: remove this as onnx opset 11 spec allows negative axes + input_dim = symbolic_helper._get_tensor_rank(input) + if input_dim is None: + return symbolic_helper._unimplemented( + "dim", + "ONNX and PyTorch use different strategies to split the input. " + "Input rank must be known at export time.", + ) + if dim < 0: + dim = input_dim + dim + is_transpose_required = input_dim != dim + 1 + # ONNX only supports log_softmax with dim = -1. Transpose must be added before and after log_softmax to support other cases. + if is_transpose_required: + axes = list(range(input_dim)) + axes[dim], axes[-1] = axes[-1], axes[dim] + input = g.op("Transpose", input, perm_i=axes) + dim = input_dim - 1 + return_op = g.op("LogSoftmax", input, axis_i=dim) + if dtype and dtype.node().kind() != "prim::Constant": + parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype") + return_op = g.op( + "Cast", return_op, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type() + ) + if is_transpose_required: + return_op = g.op("Transpose", return_op, perm_i=axes) # type: ignore[possibly-undefined] + return return_op + + +@_onnx_symbolic("aten::_log_softmax") +@symbolic_helper.parse_args("v", "i", "i") +def _log_softmax(g: jit_utils.GraphContext, input, dim, half_to_float): + if ( + half_to_float + and _type_utils.JitScalarType.from_value( + input, _type_utils.JitScalarType.UNDEFINED + ) + == _type_utils.JitScalarType.HALF + ): + input = g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.FLOAT) + return log_softmax(g, input, dim) + + +@_onnx_symbolic("aten::_convolution") +@symbolic_helper.parse_args( + "v", "v", "v", "is", "is", "is", "i", "is", "i", "i", "i", "i", "i" +) +def _convolution( + g: jit_utils.GraphContext, + input, + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + benchmark, + deterministic, + cudnn_enabled, + allow_tf32=None, +): + weight_size = symbolic_helper._get_tensor_sizes(weight) + try: + kernel_shape = weight_size[2:] + except Exception: + # FIXME(justinchuby): Avoid catching Exception. + # Catch a more specific exception instead. + kernel_shape = None + + if kernel_shape is None or any(i is None for i in kernel_shape): + raise errors.SymbolicValueError( + "Unsupported: ONNX export of convolution for kernel of unknown shape.", + input, + ) + + args = [input, weight] + # ONNX only supports 1D bias + if ( + not symbolic_helper._is_none(bias) + and symbolic_helper._get_tensor_rank(bias) == 1 + ): + args.append(bias) + + kwargs = { + "kernel_shape_i": weight_size[2:], + "strides_i": stride, + # NB: ONNX supports asymmetric padding, whereas PyTorch supports only + # symmetric padding + "pads_i": padding + padding, + "dilations_i": dilation, + "group_i": groups, + } + + if any(o != 0 for o in output_padding): + # ONNX supports both output_shape and output_padding. they are equivalent expressive. + # output_padding is more straightforward, so we use it here. 
+ # output_shape = stride * (input_shape - 1) + output_padding + kernel_shape - padding * 2 + assert transposed + assert len(stride) == len(output_padding) + kwargs["output_padding_i"] = output_padding + + n = g.op("ConvTranspose" if transposed else "Conv", *args, **kwargs) + + if ( + not symbolic_helper._is_none(bias) + and symbolic_helper._get_tensor_rank(bias) != 1 + ): + return g.op("Add", n, bias) + else: + return n + + +@_onnx_symbolic("aten::_convolution_mode") +@symbolic_helper.parse_args( + "v", + "v", + "v", + "is", + "s", + "is", + "i", +) +def _convolution_mode( + g: jit_utils.GraphContext, + input, + weight, + bias, + stride, + padding, + dilation, + groups, +): + weight_size = symbolic_helper._get_tensor_sizes(weight) + try: + kernel_shape = weight_size[2:] + except Exception: + # FIXME(justinchuby): Avoid catching Exception. + # Catch a more specific exception instead. + kernel_shape = None + + if kernel_shape is None or any(i is None for i in kernel_shape): + raise errors.SymbolicValueError( + "Unsupported: ONNX export of convolution for kernel of unknown shape.", + input, + ) + + args = [input, weight] + # ONNX only supports 1D bias + if ( + not symbolic_helper._is_none(bias) + and symbolic_helper._get_tensor_rank(bias) == 1 + ): + args.append(bias) + + if padding == "valid": + padding = "VALID" + elif padding == "same": + padding = "SAME_UPPER" + kwargs = { + "kernel_shape_i": weight_size[2:], + "strides_i": stride, + "auto_pad_s": padding, + "dilations_i": dilation, + "group_i": groups, + } + + n = g.op("Conv", *args, **kwargs) + + if ( + not symbolic_helper._is_none(bias) + and symbolic_helper._get_tensor_rank(bias) != 1 + ): + return g.op("Add", n, bias) + else: + return n + + +@_onnx_symbolic("aten::convolution") +@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "is", "i") +def convolution( + g: jit_utils.GraphContext, + input, + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, +): + return _convolution( + g, + input, + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + None, + None, + None, + None, + ) + + +@_onnx_symbolic("aten::conv1d") +@symbolic_helper.parse_args("v", "v", "v", "is", "v", "is", "i") +def conv1d( + g: jit_utils.GraphContext, input, weight, bias, stride, padding, dilation, groups +): + str_padding = symbolic_helper._parse_arg(padding, "s") + if str_padding in ["valid", "same"]: + return _convolution_mode( + g, + input, + weight, + bias, + stride, + str_padding, + dilation, + groups, + ) + else: + padding = symbolic_helper._parse_arg(padding, "is") + return _convolution( + g, + input, + weight, + bias, + stride, + padding, + dilation, + False, + (), + groups, + None, + None, + None, + None, + ) + + +@_onnx_symbolic("aten::conv2d") +@symbolic_helper.parse_args("v", "v", "v", "is", "v", "is", "i") +def conv2d( + g: jit_utils.GraphContext, input, weight, bias, stride, padding, dilation, groups +): + str_padding = symbolic_helper._parse_arg(padding, "s") + if str_padding in ["valid", "same"]: + return _convolution_mode( + g, + input, + weight, + bias, + stride, + str_padding, + dilation, + groups, + ) + else: + padding = symbolic_helper._parse_arg(padding, "is") + return _convolution( + g, + input, + weight, + bias, + stride, + padding, + dilation, + False, + (), + groups, + None, + None, + None, + None, + ) + + +@_onnx_symbolic("aten::conv3d") +@symbolic_helper.parse_args("v", "v", "v", "is", "v", "is", "i") +def conv3d( + g: jit_utils.GraphContext, 
input, weight, bias, stride, padding, dilation, groups +): + str_padding = symbolic_helper._parse_arg(padding, "s") + if str_padding in ["valid", "same"]: + return _convolution_mode( + g, + input, + weight, + bias, + stride, + str_padding, + dilation, + groups, + ) + else: + padding = symbolic_helper._parse_arg(padding, "is") + return _convolution( + g, + input, + weight, + bias, + stride, + padding, + dilation, + False, + (), + groups, + None, + None, + None, + None, + ) + + +@_onnx_symbolic("aten::conv_transpose1d") +@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "is") +def conv_transpose1d( + g: jit_utils.GraphContext, + input, + weight, + bias, + stride, + padding, + output_padding, + groups, + dilation, +): + return _convolution( + g, + input, + weight, + bias, + stride, + padding, + dilation, + True, + output_padding, + groups, + None, + None, + None, + None, + ) + + +@_onnx_symbolic("aten::conv_transpose2d") +@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "is") +def conv_transpose2d( + g: jit_utils.GraphContext, + input, + weight, + bias, + stride, + padding, + output_padding, + groups, + dilation, +): + return _convolution( + g, + input, + weight, + bias, + stride, + padding, + dilation, + True, + output_padding, + groups, + None, + None, + None, + None, + ) + + +@_onnx_symbolic("aten::conv_transpose3d") +@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "is") +def conv_transpose3d( + g: jit_utils.GraphContext, + input, + weight, + bias, + stride, + padding, + output_padding, + groups, + dilation, +): + return _convolution( + g, + input, + weight, + bias, + stride, + padding, + dilation, + True, + output_padding, + groups, + None, + None, + None, + None, + ) + + +@_onnx_symbolic("aten::batch_norm") +@symbolic_helper.parse_args("v", "v", "v", "v", "v", "i", "f", "f", "i") +def batch_norm( + g: jit_utils.GraphContext, + input, + weight, + bias, + running_mean, + running_var, + training, + momentum, + eps, + cudnn_enabled, +): + symbolic_helper.check_training_mode(training, "batch_norm") + + if ( + torch.is_autocast_enabled() + and not symbolic_helper.args_have_same_dtype( + [input, weight, bias, running_mean, running_var] + ) + and GLOBALS.export_onnx_opset_version < 15 + ): + return symbolic_helper._onnx_opset_unsupported_detailed( + "BatchNormalization", + 9, + 15, + "All input tensors must have the same `dtype`." 
+ " Turn off Autocast or export using opset version 15.", + input, + ) + + weight, bias, running_mean, running_var = symbolic_helper._batchnorm_helper( + g, input, weight, bias, running_mean, running_var + ) + out = g.op( + "BatchNormalization", + input, + weight, + bias, + running_mean, + running_var, + epsilon_f=eps, + momentum_f=1 - momentum, + outputs=1 if not training else 5, + ) + if not training: + return out + else: + res, new_running_mean, new_running_var, saved_mean, saved_var = out + new_running_mean.setType(running_mean.type()) + new_running_var.setType(running_var.type()) + saved_mean.setDebugName("batch_norm_dead_output-" + saved_mean.debugName()) + saved_var.setDebugName("batch_norm_dead_output-" + saved_var.debugName()) + return res + + +@_onnx_symbolic("aten::native_layer_norm") +@symbolic_helper.quantized_args(True, False, False, False) +@symbolic_helper.parse_args("v", "is", "v", "v", "f") +def native_layer_norm( + g: jit_utils.GraphContext, + input: _C.Value, + normalized_shape: Sequence[int], + weight: _C.Value, + bias: _C.Value, + eps: float, +) -> tuple[_C.Value, _C.Value, _C.Value]: + axes = [-i for i in range(len(normalized_shape), 0, -1)] + + two_cst = symbolic_helper._generate_wrapped_number(g, 2.0) + eps_cst = symbolic_helper._generate_wrapped_number(g, eps) + + if g.opset < 18: + mean = g.op("ReduceMean", input, axes_i=axes) + else: + mean = g.op( + "ReduceMean", + input, + g.op("Constant", value_t=torch.tensor(axes, dtype=torch.long)), + ) + + numerator = sub(g, input, mean) + + # Cast it to eps dtype to avoid precision loss + is_type_half = ( + _type_utils.JitScalarType.from_value(numerator) + == _type_utils.JitScalarType.HALF + ) + if is_type_half: + eps_dtype = _type_utils.JitScalarType.from_value(eps_cst) + numerator = g.op( + "Cast", numerator, to_i=_type_utils.JitScalarType(eps_dtype).onnx_type() + ) + + # variance = e((x - e(x))^2), and (x - e(x)) is the numerator in the layer_norm formula + if g.opset < 18: + variance = g.op("ReduceMean", pow(g, numerator, two_cst), axes_i=axes) + else: + variance = g.op( + "ReduceMean", + pow(g, numerator, two_cst), + g.op("Constant", value_t=torch.tensor(axes, dtype=torch.long)), + ) + + denominator = sqrt(g, g.op("Add", variance, eps_cst)) + normalized = g.op("Div", numerator, denominator) + + # Cast back to input type as eps related ops are all done + if is_type_half: + input_dtype = _type_utils.JitScalarType.from_value(input) + normalized = g.op( + "Cast", normalized, to_i=_type_utils.JitScalarType(input_dtype).onnx_type() + ) + + if not (weight is None or symbolic_helper._is_none(weight)): + normalized = mul(g, normalized, weight) + if not (bias is None or symbolic_helper._is_none(bias)): + normalized = add(g, normalized, bias) + + # rdenominator := 1 / sqrt(variance + eps) + # According to aten::native_layer_norm, rdenominator should have the same dtype as input, + # mean and normalized, so we need to Cast it back + if is_type_half: + denominator = g.op( + "Cast", + denominator, + to_i=_type_utils.JitScalarType(input_dtype).onnx_type(), # type: ignore[possibly-undefined] + ) + rdenominator = g.op("Reciprocal", denominator) + else: + rdenominator = reciprocal(g, denominator) + + return normalized, mean, rdenominator + + +@_onnx_symbolic("aten::layer_norm") +@symbolic_helper.quantized_args(True, False, False, False) +@symbolic_helper.parse_args("v", "is", "v", "v", "f", "b") +def layer_norm( + g: jit_utils.GraphContext, + input: _C.Value, + normalized_shape: Sequence[int], + weight: _C.Value, + bias: _C.Value, + 
eps: float, + cudnn_enable: bool, +) -> _C.Value: + normalized, _, _ = native_layer_norm(g, input, normalized_shape, weight, bias, eps) + return normalized + + +@_onnx_symbolic("aten::instance_norm") +@symbolic_helper.parse_args("v", "v", "v", "v", "v", "b", "f", "f", "b") +def instance_norm( + g: jit_utils.GraphContext, + input, + weight, + bias, + running_mean, + running_var, + use_input_stats: bool, + momentum: Number, + eps: Number, + cudnn_enabled: bool, +): + symbolic_helper.check_training_mode(use_input_stats, "instance_norm") + channel_size = symbolic_helper._get_tensor_dim_size(input, 1) + if weight is None or symbolic_helper._is_none(weight): + if channel_size is None: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of instance_norm for unknown channel size.", + input, + ) + weight_value = torch.tensor( + [1.0] * channel_size, + dtype=_type_utils.JitScalarType.from_value(input).dtype(), + ) + weight = g.op("Constant", value_t=weight_value) + if bias is None or symbolic_helper._is_none(bias): + if channel_size is None: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of instance_norm for unknown channel size.", + input, + ) + bias_value = torch.tensor( + [0.0] * channel_size, + dtype=_type_utils.JitScalarType.from_value(input).dtype(), + ) + bias = g.op("Constant", value_t=bias_value) + if ( + running_mean is None + or symbolic_helper._is_none(running_mean) + or running_var is None + or symbolic_helper._is_none(running_var) + ): + return g.op("InstanceNormalization", input, weight, bias, epsilon_f=eps) + else: + input_size = symbolic_helper._get_tensor_sizes(input) + # If input shape is [N, C, H, W], reshape to [1, N * C, H, W] and call batch_norm. + # For more information instance_norm(): + # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Normalization.cpp#L542 + input_size_reshape = input_size.copy() + n = input_size[0] + if n is None: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of instance_norm training for unknown " + "batch size.", + input, + ) + c = input_size[1] + input_size_reshape[0] = 1 + input_size_reshape[1] = n * c + weight_ = repeat( + g, weight, g.op("Constant", value_t=torch.tensor([n], dtype=torch.int64)) + ) + bias_ = repeat( + g, bias, g.op("Constant", value_t=torch.tensor([n], dtype=torch.int64)) + ) + running_mean_ = repeat( + g, + running_mean, + g.op("Constant", value_t=torch.tensor([n], dtype=torch.int64)), + ) + running_var_ = repeat( + g, + running_var, + g.op("Constant", value_t=torch.tensor([n], dtype=torch.int64)), + ) + input_reshaped = g.op( + "Reshape", + input, + g.op("Constant", value_t=torch.LongTensor(input_size_reshape)), + ) + out = batch_norm( + g, + input_reshaped, + weight_, + bias_, + running_mean_, + running_var_, + use_input_stats, + momentum, + eps, + cudnn_enabled, + ) + return view(g, out, g.op("Constant", value_t=torch.tensor(input_size))) + + +@_onnx_symbolic("aten::unfold") +@symbolic_helper.parse_args("v", "i", "i", "i") +def unfold(g: jit_utils.GraphContext, input, dimension, size, step): + sizes = symbolic_helper._get_tensor_sizes(input) + # FIXME(justinchuby): Get rid of the try catch here to improve readability + try: + sizedim = sizes[dimension] + except Exception: + # FIXME(justinchuby): Avoid catching Exception. + # Catch a more specific exception instead. 
+ sizedim = None + if sizedim is not None: + low_indices = range(0, sizedim, step) + hi_indices = range(size, sizedim + 1, step) + stack = [ + symbolic_helper._slice_helper( + g, input, axes=[dimension], starts=[low], ends=[hi] + ) + for low, hi in zip(low_indices, hi_indices) + ] + ndim = len(sizes) + perm = list(range(0, ndim)) + perm.append(perm.pop(dimension)) + unsqueeze = [ + symbolic_helper._unsqueeze_helper( + g, g.op("Transpose", t, perm_i=perm), [dimension] + ) + for t in stack + ] + return g.op("Concat", *unsqueeze, axis_i=dimension) + else: + return symbolic_helper._unimplemented( + "Unfold", "input size not accessible", input + ) + + +@_onnx_symbolic("aten::elu") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "t", "t", "t") +def elu(g: jit_utils.GraphContext, input, alpha, scale, input_scale): + if scale and scale != 1.0: + return symbolic_helper._unimplemented( + "scale", "does not support scale in Elu", scale + ) + if input_scale and input_scale != 1.0: + return symbolic_helper._unimplemented( + "input_scale", "does not support input_scale in Elu", input_scale + ) + # See Note [Export inplace] + return g.op("Elu", input, alpha_f=symbolic_helper._scalar(alpha)) + + +@_onnx_symbolic("aten::selu") +@symbolic_helper.quantized_args(True) +def selu(g: jit_utils.GraphContext, input): + return g.op("Selu", input) + + +@_onnx_symbolic("aten::index_select") +@symbolic_helper.parse_args("v", "i", "v") +def index_select(g: jit_utils.GraphContext, self, dim, index): + # In case of a scalar index, index_select returns a tensor with the same rank as the input. + # To match this behavior in ONNX, we make index a 1D tensor so that the following gather + # also produces a tensor with the same rank as the input. + return symbolic_helper._select_helper(g, self, dim, index) + + +@_onnx_symbolic("aten::index_put") +def index_put(g: jit_utils.GraphContext, self, indices_list_value, values, accumulate): + if symbolic_helper._is_packed_list(indices_list_value): + indices_list = symbolic_helper._unpack_list(indices_list_value) + else: + indices_list = [indices_list_value] + + accumulate = symbolic_helper._parse_arg(accumulate, "b") + + if len(indices_list) == 0: + if accumulate: + return add(g, self, values) + return values + symbolic_helper._onnx_opset_unsupported("index_put", 9, 11, self) + + +@_onnx_symbolic("aten::index_fill") +def index_fill(g: jit_utils.GraphContext, self, dim, index, value): + expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper( + g, self, dim, index + ) + value = symbolic_helper._maybe_get_scalar(value) + value = symbolic_helper._if_scalar_type_as(value, self) + expanded_value = expand(g, value, expanded_index_shape, None) + + return scatter(g, self, dim, expanded_index, expanded_value) + + +@_onnx_symbolic("aten::index_copy") +def index_copy(g: jit_utils.GraphContext, self, dim, index, source): + _expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper( + g, self, dim, index + ) + return scatter(g, self, dim, expanded_index, source) + + +@_onnx_symbolic("aten::bucketize") +@symbolic_helper.parse_args("v", "v", "b", "b") +def bucketize( + g: jit_utils.GraphContext, self, boundaries, out_int32=False, right=False +): + out_type = _C_onnx.TensorProtoDataType.INT64 + if out_int32: + out_type = _C_onnx.TensorProtoDataType.INT32 + # A tensor expanded_boundaries is created such that it + # contains a copy of boundaries for each element of self. 
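+ # e.g., for 1-D boundaries of shape [B] and self of shape [N, M], boundaries is
+ # unsqueezed to [B, 1, 1] and expanded to new_shape = [B, N, M] below; the
+ # final ReduceSum over axis 0 brings the result back to shape [N, M].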
+ new_shape = g.op("Concat", g.op("Shape", boundaries), g.op("Shape", self), axis_i=0) + # Unsqueeze step is performed to respect ONNX's numpy style broadcasting for comparison ops + # https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md + tensor_rank = symbolic_helper._get_tensor_rank(self) + assert tensor_rank is not None + unsqueeze_axes = list(range(1, tensor_rank + 1)) + expanded_boundaries = expand( + g, + symbolic_helper._unsqueeze_helper(g, boundaries, unsqueeze_axes), + new_shape, + None, + ) + # Compare each element of self to boundaries to get a tensor + # with leading 1s and trailing 0s. + # e.g., 4 > [1, 3, 4] = [1, 1, 0] + # The index of the last 1 is the bucket where the element should go. + if right: + cond = ge(g, self, expanded_boundaries) + else: + cond = gt(g, self, expanded_boundaries) + cond_out = g.op("Cast", cond, to_i=out_type) + # Sum to get the number of 1s corresponding to each element, + # which is the same as the bucket index. + # e.g., sum(4 > [1, 3, 4]) = sum([1, 1, 0]) = 2 + return symbolic_helper._reducesum_helper(g, cond_out, axes_i=[0], keepdims_i=0) + + +@_onnx_symbolic("aten::type_as") +def type_as(g: jit_utils.GraphContext, self, other): + self_dtype = symbolic_helper._try_get_scalar_type(self) + other_dtype = symbolic_helper._try_get_scalar_type(other) + if self_dtype == other_dtype and self_dtype is not None: + return self + if other_dtype is not None: + return g.op( + "Cast", + self, + to_i=other_dtype.onnx_type(), + ) + + raise errors.SymbolicValueError( + "Unsupported: ONNX export of type_as for tensor " + "of unknown dtype. Please check if the dtype of the " + "parameter passed to the type_as function is correct.", + other, + ) + + +@_onnx_symbolic("aten::cosine_similarity") +@symbolic_helper.parse_args("v", "v", "i", "f") +def cosine_similarity(g: jit_utils.GraphContext, x1, x2, dim, eps): + cross = symbolic_helper._reducesum_helper( + g, mul(g, x1, x2), axes_i=[dim], keepdims_i=0 + ) + x1_l2 = symbolic_helper._reducesum_helper( + g, mul(g, x1, x1), axes_i=[dim], keepdims_i=0 + ) + x2_l2 = symbolic_helper._reducesum_helper( + g, mul(g, x2, x2), axes_i=[dim], keepdims_i=0 + ) + div_tens = max( + g, sqrt(g, mul(g, x1_l2, x2_l2)), g.op("Constant", value_t=torch.tensor([eps])) + ) + return div(g, cross, div_tens) + + +@_onnx_symbolic("aten::pairwise_distance") +def pairwise_distance(g: jit_utils.GraphContext, input1, input2, p, eps, keepdim): + if not symbolic_helper._is_value(eps): + eps = g.op("Constant", value_t=torch.tensor([eps])) + inv_p = div( + g, + g.op("Constant", value_t=torch.tensor([1], dtype=torch.float)), + add(g, p, eps), + ) + summation = symbolic_helper._reducesum_helper( + g, + pow(g, sub(g, input1, input2), p), + axes_i=[-1], + keepdims_i=symbolic_helper._parse_arg(keepdim, "i"), + ) + return pow(g, summation, inv_p) + + +@_onnx_symbolic("aten::clone") +# ignore clone operators that are inserted by PyTorch autograd +def clone(g: jit_utils.GraphContext, input, unused_memory_format): + return input + + +@_onnx_symbolic("aten::abs") +def abs(g: jit_utils.GraphContext, self): + return g.op("Abs", self) + + +@_onnx_symbolic("aten::log") +def log(g: jit_utils.GraphContext, self): + return g.op("Log", self) + + +@_onnx_symbolic("aten::log1p") +def log1p(g: jit_utils.GraphContext, self): + return log(g, add(g, symbolic_helper._if_scalar_type_as(torch.ones(1), self), self)) + + +@_onnx_symbolic("aten::log10") +def log10(g: jit_utils.GraphContext, self): + _ln10 = 2.30258509299404568401 + return g.op("Div", log(g, self), 
g.op("Constant", value_t=torch.tensor([_ln10]))) + + +@_onnx_symbolic("aten::pow") +def pow(g: jit_utils.GraphContext, self, exponent): + f_dtype = _type_utils.JitScalarType.from_value(self) + if not symbolic_helper._is_fp(self): + f_dtype = _type_utils.JitScalarType.FLOAT + self = g.op("Cast", self, to_i=f_dtype.onnx_type()) + if not symbolic_helper._is_fp(exponent): + exponent = g.op( + "Cast", + exponent, + to_i=f_dtype.onnx_type(), + ) + pow = g.op("Pow", self, exponent) + return pow + + +@_onnx_symbolic("aten::clamp") +def clamp(g: jit_utils.GraphContext, self, min, max): + # min or max may be None that we need to dispatch to + # Clip separately, as ONNX does not have None syntax + if symbolic_helper._is_none(min): + return clamp_max(g, self, max) + elif symbolic_helper._is_none(max): + return clamp_min(g, self, min) + else: + if symbolic_helper._is_constant(min) and symbolic_helper._is_constant(max): + return symbolic_helper._op_with_optional_float_cast( + g, + "Clip", + self, + min_f=symbolic_helper._parse_arg(min, "f"), + max_f=symbolic_helper._parse_arg(max, "f"), + opset_before=12, + ) + else: + return clamp_max(g, clamp_min(g, self, min), max) + + +@_onnx_symbolic("aten::clamp_min") +@symbolic_helper.parse_args("v", "v") +def clamp_min(g: jit_utils.GraphContext, self, min): + if symbolic_helper._is_constant(min): + return symbolic_helper._op_with_optional_float_cast( + g, "Clip", self, min_f=symbolic_helper._parse_arg(min, "f"), opset_before=12 + ) + else: + dtype = _type_utils.JitScalarType.from_value(self) + min = g.op("Cast", min, to_i=dtype.onnx_type()) + return symbolic_helper._op_with_optional_float_cast( + g, "Max", self, min, opset_before=12 + ) + + +@_onnx_symbolic("aten::clamp_max") +@symbolic_helper.parse_args("v", "v") +def clamp_max(g: jit_utils.GraphContext, self, max): + if symbolic_helper._is_constant(max): + return symbolic_helper._op_with_optional_float_cast( + g, "Clip", self, max_f=symbolic_helper._parse_arg(max, "f"), opset_before=12 + ) + else: + dtype = _type_utils.JitScalarType.from_value(self) + max = g.op("Cast", max, to_i=dtype.onnx_type()) + return symbolic_helper._op_with_optional_float_cast( + g, "Min", self, max, opset_before=12 + ) + + +@_onnx_symbolic("aten::max") +# torch.max (same for torch.min) actually has two interfaces smashed together: +# torch.max(x, dim, keepdim) and torch.max(x, y) +# TODO(justinchuby): Support multiple quantized args in output +def max(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None): + return symbolic_helper._max_helper(g, self, dim_or_y, keepdim) + + +@_onnx_symbolic("aten::maximum") +@symbolic_helper.quantized_args(True, True) +def maximum(g: jit_utils.GraphContext, input, other): + return max(g, input, dim_or_y=other) + + +@_onnx_symbolic("aten::min") +# TODO(justinchuby): Support multiple quantized args in output +def min(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None): + return symbolic_helper._min_helper(g, self, dim_or_y, keepdim) + + +@_onnx_symbolic("aten::minimum") +@symbolic_helper.quantized_args(True, True) +def minimum(g: jit_utils.GraphContext, input, other): + return min(g, input, dim_or_y=other) + + +@_onnx_symbolic("aten::amax") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "is", "i") +def amax(g: jit_utils.GraphContext, self, dim, keepdim): + return g.op("ReduceMax", self, axes_i=dim, keepdims_i=keepdim) + + +@_onnx_symbolic("aten::amin") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "is", "i") +def amin(g: 
jit_utils.GraphContext, self, dim, keepdim): + return g.op("ReduceMin", self, axes_i=dim, keepdims_i=keepdim) + + +@_onnx_symbolic("aten::aminmax") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "v", "i") +def aminmax(g: jit_utils.GraphContext, self, dim, keepdim): + reduce_kwargs = {"keepdims_i": keepdim} + if not symbolic_helper._is_none(dim): + dim = symbolic_helper._get_const(dim, "i", "dim") + reduce_kwargs["axes_i"] = [dim] + + return g.op("ReduceMin", self, **reduce_kwargs), g.op( + "ReduceMax", self, **reduce_kwargs + ) + + +@_onnx_symbolic("aten::exp") +def exp(g: jit_utils.GraphContext, self): + return g.op("Exp", self) + + +@_onnx_symbolic("aten::dropout_") +@_onnx_symbolic("aten::dropout") +@symbolic_helper.parse_args("v", "f", "i") +def dropout(g: jit_utils.GraphContext, input, p, train): + symbolic_helper.check_training_mode(train, "dropout") + # if train is False, dropout is no-op + if not train: + return input + r, _ = g.op("Dropout", input, ratio_f=p, outputs=2) + return r + + +@_onnx_symbolic( + "aten::alpha_dropout_", + decorate=[symbolic_helper._apply_params("aten::alpha_dropout_")], +) # See Note [Export inplace] +@_onnx_symbolic( + "aten::feature_alpha_dropout_", + decorate=[symbolic_helper._apply_params("aten::feature_alpha_dropout_")], +) +@_onnx_symbolic( + "aten::feature_dropout_", + decorate=[symbolic_helper._apply_params("aten::feature_dropout_")], +) +@_onnx_symbolic( + "aten::feature_alpha_dropout", + decorate=[symbolic_helper._apply_params("aten::feature_alpha_dropout")], +) +@_onnx_symbolic( + "aten::alpha_dropout", + decorate=[symbolic_helper._apply_params("aten::alpha_dropout")], +) +@_onnx_symbolic( + "aten::feature_dropout", + decorate=[symbolic_helper._apply_params("aten::feature_dropout")], +) +def _unsupported_dropout(name: str): + @symbolic_helper.parse_args("v", "none", "b") + def feature_dropout(g, input, p, train): + # NB: In inference mode, FeatureDropout is exported as an identity op. 
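+ # e.g., torch.nn.functional.feature_alpha_dropout(x, p=0.2, training=False)
+ # exports as a plain pass-through of x; training-mode export is rejected below.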
+ if train: + return symbolic_helper._unimplemented(name, "training mode", input) + return input + + return feature_dropout + + +@_onnx_symbolic("aten::norm") +@symbolic_helper.parse_args("v", "t", "is", "i", "v") +def norm(g: jit_utils.GraphContext, self, p, dim, keepdim, dtype=None): + if p == 1: + f = symbolic_helper._reduce_op_symbolic_helper("ReduceL1") + elif p == 2: + f = symbolic_helper._reduce_op_symbolic_helper("ReduceL2") + else: + raise errors.SymbolicValueError( + "ONNX export only p-norms with p of 1 or 2", self + ) + result = f(g, self, dim=dim, keepdim=keepdim) + if dtype is not None: + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + result = g.op("Cast", result, to_i=_type_utils.JitScalarType(dtype).onnx_type()) + return result + + +@_onnx_symbolic("aten::conv_tbc") +@symbolic_helper.parse_args("v", "v", "v", "i") +def conv_tbc(g: jit_utils.GraphContext, input, weight, bias, pad): + # input must have 3 dimensions, see: + # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/ConvolutionTBC.cpp#L8-L10 + # input = (time, batch, in_channels) + # weight = (kernel_width, in_channels, out_channels) + # bias = (out_channels,) + input = g.op("Transpose", input, perm_i=[1, 2, 0]) + weight = g.op("Transpose", weight, perm_i=[2, 1, 0]) + conv = conv1d(g, input, weight, bias, [1], [pad], [1], 1) + return g.op("Transpose", conv, perm_i=[2, 0, 1]) + + +@_onnx_symbolic("aten::_unique") +@symbolic_helper.parse_args("v", "i", "i") +def _unique(g: jit_utils.GraphContext, input, sorted, return_inverse): + return symbolic_helper._onnx_unsupported("_unique", input) + + +@_onnx_symbolic("aten::_unique2") +@symbolic_helper.parse_args("v", "i", "i", "i") +def _unique2(g: jit_utils.GraphContext, input, sorted, return_inverse, return_counts): + symbolic_helper._onnx_opset_unsupported("_unique2", 9, 11, input) + + +@_onnx_symbolic("aten::_cast_Byte") +@deprecated("Avoid using this function and create a Cast node instead") +def _cast_Byte(g: jit_utils.GraphContext, input, non_blocking): + return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.UINT8) + + +@_onnx_symbolic("aten::_cast_Char") +@deprecated("Avoid using this function and create a Cast node instead") +def _cast_Char(g: jit_utils.GraphContext, input, non_blocking): + return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT8) + + +@_onnx_symbolic("aten::_cast_Short") +@deprecated("Avoid using this function and create a Cast node instead") +def _cast_Short(g: jit_utils.GraphContext, input, non_blocking): + return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT16) + + +@_onnx_symbolic("aten::_cast_Int") +@deprecated("Avoid using this function and create a Cast node instead") +def _cast_Int(g: jit_utils.GraphContext, input, non_blocking): + return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT32) + + +@_onnx_symbolic("aten::_cast_Long") +@deprecated("Avoid using this function and create a Cast node instead") +def _cast_Long(g: jit_utils.GraphContext, input, non_blocking): + return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT64) + + +@_onnx_symbolic("aten::_cast_Half") +@deprecated("Avoid using this function and create a Cast node instead") +def _cast_Half(g: jit_utils.GraphContext, input, non_blocking): + return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.FLOAT16) + + +@_onnx_symbolic("aten::_cast_Float") +@deprecated("Avoid using this function and create a Cast node instead") +def _cast_Float(g: jit_utils.GraphContext, input, non_blocking): + return g.op("Cast", 
input, to_i=_C_onnx.TensorProtoDataType.FLOAT) + + +@_onnx_symbolic("aten::_cast_Double") +@deprecated("Avoid using this function and create a Cast node instead") +def _cast_Double(g: jit_utils.GraphContext, input, non_blocking): + return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.DOUBLE) + + +@_onnx_symbolic("aten::_cast_Bool") +@deprecated("Avoid using this function and create a Cast node instead") +def _cast_Bool(g: jit_utils.GraphContext, input, non_blocking): + return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.BOOL) + + +@_onnx_symbolic("aten::empty") +@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v") +def empty( + g: jit_utils.GraphContext, + sizes, + dtype, + layout, + device, + pin_memory=False, + memory_format=None, +): + return zeros(g, sizes, dtype, layout, device, pin_memory) + + +@_onnx_symbolic("aten::empty_like") +@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v") +def empty_like( + g: jit_utils.GraphContext, + input, + dtype=None, + layout=None, + device=None, + pin_memory=False, + memory_format=None, +): + return zeros_like(g, input, dtype, layout, device, pin_memory) + + +@_onnx_symbolic("aten::new_empty") +def new_empty( + g: jit_utils.GraphContext, self, sizes, dtype, layout, device, pin_memory=False +): + self_dtype = symbolic_helper._try_get_scalar_type(self) + if symbolic_helper._is_none(dtype) and self_dtype is not None: + dtype = self_dtype + return empty(g, sizes, dtype, layout, device, pin_memory) + + +@_onnx_symbolic("aten::scalar_tensor") +def scalar_tensor(g: jit_utils.GraphContext, scalar, dtype, *options): + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + if dtype is None: + dtype = _type_utils.JitScalarType.FLOAT + scalar = g.op("Cast", scalar, to_i=_type_utils.JitScalarType(dtype).onnx_type()) + return scalar + + +@_onnx_symbolic("aten::tensor") +def tensor( + g: jit_utils.GraphContext, data, dtype=None, device=None, requires_grad=False +): + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + if symbolic_helper._is_packed_list(data): + if dtype is None: + dtype = _type_utils.JitScalarType.from_value( + symbolic_helper._unpack_list(data)[0] + ) + input_list = [] + for t in symbolic_helper._unpack_list(data): + shape_reference = g.op("Constant", value_t=torch.LongTensor([1])) + t = symbolic_helper._reshape_helper(g, t, shape_reference) + t = g.op("Cast", t, to_i=_type_utils.JitScalarType(dtype).onnx_type()) + input_list.append(t) + return g.op("Concat", *input_list, axis_i=0) + else: + if dtype is None: + dtype = _type_utils.JitScalarType.from_value(data) + if symbolic_helper._is_list(data) and ( + symbolic_helper._is_tensor_list(data) + or symbolic_helper._is_scalar_list(data) + ): + data = g.op("ConcatFromSequence", data, axis_i=0, new_axis_i=1) + return g.op("Cast", data, to_i=_type_utils.JitScalarType(dtype).onnx_type()) + + +@_onnx_symbolic("aten::as_tensor") +def as_tensor(g: jit_utils.GraphContext, data, dtype=None, device=None): + return tensor(g, data, dtype, device) + + +@_onnx_symbolic("aten::zeros") +@symbolic_helper.parse_args("v", "i", "v", "v", "v") +def zeros(g: jit_utils.GraphContext, sizes, dtype, layout, device, pin_memory=False): + # NOTE: no way to set device, layout and pin_memory in ONNX, so we ignore it + if dtype is None: + scalar_type = _type_utils.JitScalarType.FLOAT + else: + scalar_type = _type_utils.JitScalarType(dtype) + sizes_ = symbolic_helper._maybe_get_const(sizes, "is") + if isinstance(sizes_, list) and len(sizes_) == 0: + sizes = g.op("Constant", 
value_t=torch.tensor([]).to(torch.int64)) + return g.op( + "ConstantOfShape", + sizes, + value_t=torch.tensor([0], dtype=scalar_type.dtype()), + ) + + +@_onnx_symbolic("aten::zeros_like") +@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v") +def zeros_like( + g: jit_utils.GraphContext, + input, + dtype=None, + layout=None, + device=None, + pin_memory=False, + memory_format=None, +): + shape = g.op("Shape", input) + if symbolic_helper._is_none(dtype): + scalar_type = _type_utils.JitScalarType.from_value( + input, _type_utils.JitScalarType.FLOAT + ) + else: + scalar_type = _type_utils.JitScalarType(dtype) + return g.op( + "ConstantOfShape", + shape, + value_t=torch.tensor([0], dtype=scalar_type.dtype()), + ) + + +@_onnx_symbolic("aten::new_zeros") +def new_zeros( + g: jit_utils.GraphContext, self, sizes, dtype, layout, device, pin_memory=False +): + self_dtype = symbolic_helper._try_get_scalar_type(self) + + if symbolic_helper._is_none(dtype) and self_dtype is not None: + dtype = self_dtype + return zeros(g, sizes, dtype, layout, device, pin_memory) + + +@_onnx_symbolic("aten::zero") +def zero(g: jit_utils.GraphContext, self): + self_dtype = symbolic_helper._try_get_scalar_type(self) + return zeros_like(g, self, self_dtype) + + +@_onnx_symbolic("aten::ones") +@symbolic_helper.parse_args("v", "i", "v", "v", "v") +def ones(g: jit_utils.GraphContext, sizes, dtype, layout, device, pin_memory=False): + if dtype is None: + scalar_type = _type_utils.JitScalarType.FLOAT + else: + scalar_type = _type_utils.JitScalarType(dtype) + sizes_ = symbolic_helper._maybe_get_const(sizes, "is") + if isinstance(sizes_, list) and len(sizes_) == 0: + sizes = g.op("Constant", value_t=torch.tensor([]).to(torch.int64)) + return g.op( + "ConstantOfShape", + sizes, + value_t=torch.tensor([1], dtype=scalar_type.dtype()), + ) + + +@_onnx_symbolic("aten::ones_like") +@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v") +def ones_like( + g: jit_utils.GraphContext, + input, + dtype=None, + layout=None, + device=None, + pin_memory=False, + memory_format=None, +): + shape = g.op("Shape", input) + if symbolic_helper._is_none(dtype): + scalar_type = _type_utils.JitScalarType.from_value( + input, _type_utils.JitScalarType.FLOAT + ) + else: + scalar_type = _type_utils.JitScalarType(dtype) + return g.op( + "ConstantOfShape", + shape, + value_t=torch.tensor([1], dtype=scalar_type.dtype()), + ) + + +@_onnx_symbolic("aten::new_ones") +def new_ones( + g: jit_utils.GraphContext, self, sizes, dtype, layout, device, pin_memory=False +): + self_dtype = symbolic_helper._try_get_scalar_type(self) + if symbolic_helper._is_none(dtype) and self_dtype is not None: + dtype = self_dtype + return ones(g, sizes, dtype, layout, device, pin_memory) + + +@_onnx_symbolic("aten::full") +def full( + g: jit_utils.GraphContext, sizes, value, dtype, layout, device, pin_memory=False +): + const_value = symbolic_helper._maybe_get_const(value, "t") + if symbolic_helper._is_value(const_value): + dtype = _type_utils.JitScalarType.FLOAT if dtype is None else dtype + tmp = zeros(g, sizes, dtype, layout, device) + return add(g, tmp, value, g.op("Constant", value_t=torch.tensor(1))) + else: + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + if dtype is None: + scalar_type = _type_utils.JitScalarType.FLOAT + else: + scalar_type = _type_utils.JitScalarType(dtype) + sizes_ = symbolic_helper._maybe_get_const(sizes, "is") + if isinstance(sizes_, list) and len(sizes_) == 0: + sizes = g.op("Constant", value_t=torch.tensor([]).to(torch.int64)) + return 
g.op( + "ConstantOfShape", + sizes, + value_t=const_value.view(1).to(scalar_type.dtype()), + ) + + +@_onnx_symbolic("aten::full_like") +def full_like( + g: jit_utils.GraphContext, + input, + fill_value, + dtype=None, + layout=None, + device=None, + pin_memory=False, + memory_format=None, +): + fill_value = symbolic_helper._maybe_get_const(fill_value, "f") + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + if dtype is None: + scalar_type = _type_utils.JitScalarType.from_value( + input, _type_utils.JitScalarType.FLOAT + ) + else: + scalar_type = _type_utils.JitScalarType(dtype) + if symbolic_helper._is_value(fill_value): + tmp = zeros_like(g, input, dtype, layout, device) + fill_value = g.op("Cast", fill_value, to_i=scalar_type.onnx_type()) + return add(g, tmp, fill_value, g.op("Constant", value_t=torch.tensor(1))) + else: + shape = g.op("Shape", input) + return g.op( + "ConstantOfShape", + shape, + value_t=torch.tensor([fill_value], dtype=scalar_type.dtype()), + ) + + +@_onnx_symbolic("aten::new_full") +def new_full( + g: jit_utils.GraphContext, + self, + size, + fill_value, + dtype, + layout, + device, + pin_memory=False, +): + self_dtype = symbolic_helper._try_get_scalar_type(self) + if symbolic_helper._is_none(dtype) and self_dtype is not None: + dtype = self_dtype + return full(g, size, fill_value, dtype, layout, device, pin_memory) + + +@_onnx_symbolic("aten::eye") +def eye(g: jit_utils.GraphContext, *args): + if len(args) == 5: + # aten::eye(n, dtype, layout, device, pin_memory) + n, dtype, layout, device, _pin_memory = args + dim_size = symbolic_helper._unsqueeze_helper(g, n, [0]) + shape = g.op("Concat", dim_size, dim_size, axis_i=0) + tensor = zeros(g, shape, dtype, layout, device) + return g.op("EyeLike", tensor) + if len(args) == 6: + # aten::eye(n, m, dtype, layout, device, pin_memory) + n, m, dtype, layout, device, _pin_memory = args + shape = g.op( + "Concat", + symbolic_helper._unsqueeze_helper(g, n, [0]), + symbolic_helper._unsqueeze_helper(g, m, [0]), + axis_i=0, + ) + tensor = zeros(g, shape, dtype, layout, device) + return g.op("EyeLike", tensor) + + return symbolic_helper._unimplemented("aten::eye", f"with {len(args)} arguments") + + +@_onnx_symbolic("aten::slice") +def slice(g: jit_utils.GraphContext, self, *args): + if len(args) == 4: + # aten::slice(Tensor self, int dim, int start, int end, int step) -> Tensor + dim, start, end, step = args + step = symbolic_helper._parse_arg(step, "i") + if step != 1: + raise errors.SymbolicValueError("step!=1 is currently not supported", self) + is_start_none = start.node().kind() == "prim::Constant" and isinstance( + start.type(), _C.NoneType + ) + is_end_none = end.node().kind() == "prim::Constant" and isinstance( + end.type(), _C.NoneType + ) + is_start_onnx_const = start.node().kind() == "onnx::Constant" + is_end_onnx_const = end.node().kind() == "onnx::Constant" + if ( + ((not is_start_none) and (not is_start_onnx_const)) + or ((not is_end_none) and (not is_end_onnx_const)) + or dim.node().kind() != "onnx::Constant" + ): + if GLOBALS.operator_export_type == _C_onnx.OperatorExportTypes.ONNX: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of Slice with dynamic inputs. DynamicSlice " + "is a deprecated experimental op. 
Please use statically allocated " + "variables or export to a higher opset version.", + self, + ) + else: + start_unsqueezed = symbolic_helper._unsqueeze_helper(g, start, [0]) + end_unsqueezed = symbolic_helper._unsqueeze_helper(g, end, [0]) + dim_unsqueezed = symbolic_helper._unsqueeze_helper(g, dim, [0]) + return g.op( + "DynamicSlice", + self, + start_unsqueezed, + end_unsqueezed, + dim_unsqueezed, + ) + else: + start = 0 if is_start_none else symbolic_helper._parse_arg(start, "i") + end = ( + _constants.INT64_MAX + if is_end_none + else symbolic_helper._parse_arg(end, "i") + ) + dim = symbolic_helper._parse_arg(dim, "i") + return symbolic_helper._slice_helper( + g, self, axes=[dim], starts=[start], ends=[end] + ) + elif len(args) == 3: + # aten::slice(t[] l, int start, int end, int step) -> t[] + start, end, step = args + dim = 0 + is_start_none = start.node().kind() == "prim::Constant" and isinstance( + start.type(), _C.NoneType + ) + is_end_none = end.node().kind() == "prim::Constant" and isinstance( + end.type(), _C.NoneType + ) + start = 0 if is_start_none else symbolic_helper._parse_arg(start, "i") + end = ( + _constants.INT64_MAX + if is_end_none + else symbolic_helper._parse_arg(end, "i") + ) + return symbolic_helper._slice_helper( + g, self, axes=[dim], starts=[start], ends=[end] + ) + + return symbolic_helper._unimplemented("aten::slice", f"with {len(args)} arguments") + + +@_onnx_symbolic("aten::hardtanh") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "f", "f") +def hardtanh(g: jit_utils.GraphContext, self: _C.Value, min_val: float, max_val: float): + return symbolic_helper._op_with_optional_float_cast( + g, "Clip", self, min_f=min_val, max_f=max_val, opset_before=12 + ) + + +@_onnx_symbolic("aten::hardswish") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v") +def hardswish(g: jit_utils.GraphContext, self): + hs = hardsigmoid(g, self) + return g.op("Mul", self, hs) + + +@_onnx_symbolic("aten::hardsigmoid") +# Fixed scale and zero_point, discovered from aten/src/ATen/native/quantized/cpu/qhardsigmoid.cpp +@symbolic_helper.quantized_args(True, scale=1.0 / 256.0, zero_point=0) +@symbolic_helper.parse_args("v") +def hardsigmoid(g: jit_utils.GraphContext, self): + # Set alpha_f to 1 / 6 to make op equivalent to PyTorch's definition of Hardsigmoid. 
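+ # ONNX HardSigmoid computes max(0, min(1, alpha * x + beta)) with beta defaulting
+ # to 0.5, so alpha = 1/6 yields clamp(x / 6 + 1 / 2, 0, 1), matching torch.nn.Hardsigmoid.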
+ # See https://pytorch.org/docs/stable/generated/torch.nn.Hardsigmoid.html + return g.op("HardSigmoid", self, alpha_f=1 / 6) + + +@_onnx_symbolic("aten::tanhshrink") +@symbolic_helper.parse_args("v") +def tanhshrink(g: jit_utils.GraphContext, self): + return g.op("Sub", self, tanh(g, self)) + + +@_onnx_symbolic("aten::hardshrink") +@symbolic_helper.parse_args("v", "f") +def hardshrink(g: jit_utils.GraphContext, self, lambd): + scalar_type = _type_utils.JitScalarType.from_value( + self, _type_utils.JitScalarType.FLOAT + ) + lambd_op = g.op( + "Constant", + value_t=torch.tensor(lambd, dtype=scalar_type.dtype()), + ) + cond = logical_or(g, gt(g, self, lambd_op), lt(g, self, neg(g, lambd_op))) + return g.op( + "Where", + cond, + self, + g.op( + "Constant", + value_t=torch.tensor(0, dtype=scalar_type.dtype()), + ), + ) + + +@_onnx_symbolic("aten::softshrink") +@symbolic_helper.parse_args("v", "f") +def softshrink(g: jit_utils.GraphContext, self, lambd): + scalar_type = _type_utils.JitScalarType.from_value( + self, _type_utils.JitScalarType.FLOAT + ) + lambd_op = g.op( + "Constant", + value_t=torch.tensor(lambd, dtype=scalar_type.dtype()), + ) + gt_cond = gt(g, self, lambd_op) + gt_out = g.op( + "Where", + gt_cond, + sub(g, self, lambd_op), + g.op( + "Constant", + value_t=torch.tensor(0, dtype=scalar_type.dtype()), + ), + ) + lt_cond = lt(g, self, neg(g, lambd_op)) + lt_out = g.op( + "Where", + lt_cond, + add(g, self, lambd_op), + g.op( + "Constant", + value_t=torch.tensor(0, dtype=scalar_type.dtype()), + ), + ) + return add(g, gt_out, lt_out) + + +@_onnx_symbolic("aten::alias") +def alias(g: jit_utils.GraphContext, self): + return self + + +@_onnx_symbolic("aten::unsqueeze") +@symbolic_helper.parse_args("v", "i") +def unsqueeze(g: jit_utils.GraphContext, self, dim): + """Implement unsqueezing a pytorch tensor in ONNX by inserting a new dimension at the specified `dim`""" + # Handle negative dim + if dim < 0: + rank = symbolic_helper._get_tensor_rank(self) + if rank is not None: + warnings.warn( + "ONNX export unsqueeze with negative axis " + + str(dim) + + " might cause the onnx model to be incorrect. " + + "Negative axis is not supported in ONNX. " + + "Axis is converted to " + + str(dim + rank + 1) + + " based on input shape at export time. " + + "Passing an tensor of different rank in execution will be incorrect." + ) + dim = dim + rank + 1 + else: + return symbolic_helper._unimplemented( + "unsqueeze", "negative axis with unknown input rank", self + ) + + return symbolic_helper._unsqueeze_helper(g, self, axes_i=[dim]) + + +@_onnx_symbolic("aten::sort") +# TODO(justinchuby): Support multiple quantized args in output +@symbolic_helper.parse_args("v", "i", "i", "none") +def sort(g: jit_utils.GraphContext, self, dim, descending, out=None): + if out is not None: + symbolic_helper._unimplemented( + "Sort", "Out parameter is not supported for sort", self + ) + self_sizes = symbolic_helper._get_tensor_sizes(self) + try: + dim_size = self_sizes[dim] + except Exception: + # FIXME(justinchuby): Avoid catching Exception. + # Catch a more specific exception instead. 
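# Editor's sketch (illustrative only, not part of this patch): this opset lowers aten::sort
# to TopK over the full, statically known size of the sorted dimension, which is why the
# dim-size lookup above must succeed. The descending case in eager PyTorch:
import torch

t = torch.randn(3, 5)
sorted_values, _ = torch.sort(t, dim=1, descending=True)
topk_values, _ = torch.topk(t, k=t.size(1), dim=1, largest=True, sorted=True)
assert torch.equal(sorted_values, topk_values)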
+ dim_size = None + + if dim_size is None: + return symbolic_helper._unimplemented("Sort", "input size not accessible", self) + + return g.op("TopK", self, k_i=dim_size, axis_i=dim, outputs=2) + + +@_onnx_symbolic("aten::numel") +def numel(g: jit_utils.GraphContext, self): + return symbolic_helper._numel_helper(g, self) + + +@_onnx_symbolic("aten::topk") +# TODO(justinchuby): Support multiple quantized args in output +@symbolic_helper.parse_args("v", "i", "i", "i", "i", "none") +def topk(g: jit_utils.GraphContext, self, k, dim, largest, sorted, out=None): + if out is not None: + symbolic_helper._unimplemented( + "TopK", "Out parameter is not supported for topk", self + ) + if not largest: + symbolic_helper._unimplemented("TopK", "Ascending TopK is not supported", self) + + return g.op("TopK", self, k_i=k, axis_i=dim, outputs=2) + + +@_onnx_symbolic("prim::convert_element_type") +def convert_element_type(g: jit_utils.GraphContext, self, *args): + dtype = symbolic_helper._get_const(args[0], "i", "dtype") + return g.op("Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type()) + + +@_onnx_symbolic("aten::to") +def to(g: jit_utils.GraphContext, self, *args): + def is_aten_to_device_only(args): + if len(args) == 4: + # aten::to(Tensor, Device, bool, bool, memory_format) + return ( + args[0].node().kind() == "prim::device" + or args[0].type().isSubtypeOf(_C.ListType.ofInts()) + or isinstance(args[0].type(), _C.DeviceObjType) + ) + elif len(args) == 5: + # aten::to(Tensor, Device, ScalarType, bool, bool, memory_format) + # When dtype is None, this is a aten::to(device) call + dtype = symbolic_helper._get_const(args[1], "i", "dtype") + return dtype is None + elif len(args) in (6, 7): + # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, memory_format) -> Tensor + # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, bool, memory_format) -> Tensor + # When dtype is None, this is a aten::to(device) call + dtype = symbolic_helper._get_const(args[0], "i", "dtype") + return dtype is None + return False + + # ONNX doesn't have a concept of a device, so we ignore device-only casts + if is_aten_to_device_only(args): + return self + + if len(args) == 4: + # TestONNXRuntime::test_ones_bool shows args[0] of aten::to() can be onnx::Constant[value=]() + # In this case, the constant value is a tensor not int, + # so symbolic_helper._maybe_get_const(args[0], 'i') would not work. 
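# Editor's sketch (illustrative only, not part of this patch): ONNX graphs have no notion
# of device, so is_aten_to_device_only above lets pure device moves export as an identity,
# while dtype-changing aten::to calls become Cast nodes:
import torch

x = torch.randn(2, 3)
assert torch.equal(x.to("cpu"), x)                 # device-only .to(): values unchanged
assert x.to(torch.float64).dtype == torch.float64  # dtype .to(): exported as a Cast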
+ dtype = args[0] + if ( + symbolic_helper._is_value(args[0]) + and args[0].node().kind() == "onnx::Constant" + ): + tval = symbolic_helper._node_get(args[0].node(), "value") + if isinstance(tval, torch.Tensor): + if len(tval.shape) == 0: + tval = tval.item() + dtype = int(tval) + else: + dtype = tval + + if symbolic_helper._is_value(dtype) or isinstance(dtype, torch.Tensor): + # aten::to(Tensor, Tensor, bool, bool, memory_format) + dtype = _type_utils.JitScalarType.from_value(args[0]) + return g.op( + "Cast", + self, + to_i=dtype.onnx_type(), + ) + else: + # aten::to(Tensor, ScalarType, bool, bool, memory_format) + # memory_format is ignored + return g.op("Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type()) + elif len(args) == 5: + # aten::to(Tensor, Device, ScalarType, bool, bool, memory_format) + dtype = symbolic_helper._get_const(args[1], "i", "dtype") + # memory_format is ignored + return g.op("Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type()) + elif len(args) == 6: + # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, memory_format) -> Tensor + dtype = symbolic_helper._get_const(args[0], "i", "dtype") + # Layout, device and memory_format are ignored + return g.op("Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type()) + elif len(args) == 7: + # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, bool, memory_format) -> Tensor + dtype = symbolic_helper._get_const(args[0], "i", "dtype") + # Layout, device and memory_format are ignored + return g.op("Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type()) + + return symbolic_helper._onnx_unsupported("Unknown aten::to signature", self) + + +@_onnx_symbolic("aten::repeat") +def repeat(g: jit_utils.GraphContext, self, repeats): + dtype = _type_utils.JitScalarType.INT64 + shape_ = ones_like(g, repeats, dtype) + self = g.op("Expand", self, shape_) + return g.op("Tile", self, repeats) + + +@_onnx_symbolic("aten::repeat_interleave") +def repeat_interleave( + g: jit_utils.GraphContext, self, repeats, dim=None, output_size=None +): + repeats_dim = symbolic_helper._get_tensor_rank(repeats) + repeats_sizes = symbolic_helper._get_tensor_sizes(repeats) + input_sizes = symbolic_helper._get_tensor_sizes(self) + if repeats_dim is None: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of repeat_interleave for unknown repeats rank.", + self, + ) + if repeats_sizes is None: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of repeat_interleave for unknown repeats size.", + self, + ) + if input_sizes is None: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of repeat_interleave for unknown input size.", + self, + ) + + # if dim is None flatten + # By default, use the flattened input array, and return a flat output array + if symbolic_helper._is_none(dim): + self = symbolic_helper._reshape_helper( + g, self, g.op("Constant", value_t=torch.tensor([-1])) + ) + dim = torch.tensor(0, dtype=torch.int64) + else: + dim = symbolic_helper._maybe_get_scalar(dim) + + # Handle cases where dim is negative + if dim < 0: + dim += len(input_sizes) + + input_sizes_temp = input_sizes.copy() + for idx, input_size in enumerate(input_sizes): + if input_size is None: + input_sizes[idx], input_sizes_temp[idx] = 0, -1 + + # Cases where repeats is an int or single value tensor + if repeats_dim == 0 or (repeats_dim == 1 and repeats_sizes[0] == 1): + if input_sizes[dim] == 0: + return symbolic_helper._onnx_opset_unsupported_detailed( + "repeat_interleave", + 9, + 13, + "Unsupported along 
dimension with unknown input size", + self, + ) + return symbolic_helper._repeat_interleave_single_value_repeat_helper( + g, self, repeats, dim + ) + + # Cases where repeats is a 1 dim Tensor + elif repeats_dim == 1: + if input_sizes[dim] == 0: + return symbolic_helper._onnx_opset_unsupported_detailed( + "repeat_interleave", + 9, + 13, + "Unsupported along dimension with unknown input size", + self, + ) + if repeats_sizes[0] is None: + return symbolic_helper._onnx_opset_unsupported_detailed( + "repeat_interleave", + 9, + 13, + "Unsupported for cases with dynamic repeats", + self, + ) + assert repeats_sizes[0] == input_sizes[dim], ( + "repeats must have the same size as input along dim" + ) + reps = repeats_sizes[0] + else: + raise errors.SymbolicValueError("repeats must be 0-dim or 1-dim tensor", self) + + final_splits = [] + r_splits = symbolic_helper._repeat_interleave_split_helper(g, repeats, reps, 0) + i_splits = symbolic_helper._repeat_interleave_split_helper(g, self, reps, dim) + input_sizes[dim], input_sizes_temp[dim] = -1, 1 + for idx, r_split in enumerate(r_splits): + i_split = unsqueeze(g, i_splits[idx], dim + 1) + r_concat = [ + g.op("Constant", value_t=torch.LongTensor(input_sizes_temp[: dim + 1])), + r_split, + g.op("Constant", value_t=torch.LongTensor(input_sizes_temp[dim + 1 :])), + ] + r_concat = g.op("Concat", *r_concat, axis_i=0) + i_split = expand(g, i_split, r_concat, None) + i_split = symbolic_helper._reshape_helper( + g, + i_split, + g.op("Constant", value_t=torch.LongTensor(input_sizes)), + allowzero=0, + ) + final_splits.append(i_split) + return g.op("Concat", *final_splits, axis_i=dim) + + +@_onnx_symbolic("aten::pixel_shuffle") +@symbolic_helper.parse_args("v", "i") +def pixel_shuffle(g: jit_utils.GraphContext, self, upscale_factor): + dims = symbolic_helper._get_tensor_sizes(self) + if len(dims) != 4: + return symbolic_helper._unimplemented( + "pixel_shuffle", "only support 4d input", self + ) + if any(i is None for i in dims[1:]): + after_view = symbolic_helper._reshape_helper( + g, + symbolic_helper._unsqueeze_helper(g, self, [2, 3]), + g.op( + "Constant", + value_t=torch.tensor([0, -1, upscale_factor, upscale_factor, 0, 0]), + ), + allowzero=0, + ) + after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 4, 2, 5, 3]) + # For dynamic input shapes, two reshapes are performed + reshape_h = symbolic_helper._reshape_helper( + g, + after_transpose, + g.op("Constant", value_t=torch.tensor([0, 0, -1, 1, 0, 0])), + allowzero=0, + ) + reshape_w = symbolic_helper._reshape_helper( + g, + reshape_h, + g.op("Constant", value_t=torch.tensor([0, 0, 0, 0, -1, 1])), + allowzero=0, + ) + return symbolic_helper._squeeze_helper(g, reshape_w, [3, 5]) + else: + output_channel = dims[1] // upscale_factor // upscale_factor + after_view = symbolic_helper._reshape_helper( + g, + self, + g.op( + "Constant", + value_t=torch.tensor( + [ + -1, + output_channel, + upscale_factor, + upscale_factor, + dims[2], + dims[3], + ] + ), + ), + allowzero=0, + ) + after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 4, 2, 5, 3]) + return symbolic_helper._reshape_helper( + g, + after_transpose, + g.op( + "Constant", + value_t=torch.tensor( + [ + -1, + output_channel, + dims[2] * upscale_factor, + dims[3] * upscale_factor, + ] + ), + ), + allowzero=0, + ) + + +@_onnx_symbolic("aten::pixel_unshuffle") +@symbolic_helper.parse_args("v", "i") +def pixel_unshuffle(g: jit_utils.GraphContext, self, downscale_factor): + dims = symbolic_helper._get_tensor_sizes(self) + if len(dims) != 4: + return 
symbolic_helper._unimplemented( + "pixel_shuffle", "only support 4d input", self + ) + if any(i is None for i in dims[1:]): + # For dynamic input shapes, two reshapes are performed + reshape_h = symbolic_helper._reshape_helper( + g, + symbolic_helper._unsqueeze_helper(g, self, [3]), + g.op("Constant", value_t=torch.tensor([0, 0, -1, downscale_factor, 0])), + allowzero=0, + ) + reshape_w = symbolic_helper._reshape_helper( + g, + reshape_h, + g.op("Constant", value_t=torch.tensor([0, 0, 0, 0, -1, downscale_factor])), + allowzero=0, + ) + after_transpose = g.op("Transpose", reshape_w, perm_i=[0, 1, 3, 5, 2, 4]) + final_reshape = symbolic_helper._reshape_helper( + g, + after_transpose, + g.op("Constant", value_t=torch.tensor([0, -1, 1, 1, 0, 0])), + allowzero=0, + ) + return symbolic_helper._squeeze_helper(g, final_reshape, [2, 3]) + else: + output_channel = dims[1] * downscale_factor * downscale_factor + after_view = symbolic_helper._reshape_helper( + g, + self, + g.op( + "Constant", + value_t=torch.tensor( + [ + -1, + dims[1], + dims[2] // downscale_factor, + downscale_factor, + dims[3] // downscale_factor, + downscale_factor, + ] + ), + ), + allowzero=0, + ) + after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 3, 5, 2, 4]) + return symbolic_helper._reshape_helper( + g, + after_transpose, + g.op( + "Constant", + value_t=torch.tensor( + [ + -1, + output_channel, + dims[2] // downscale_factor, + dims[3] // downscale_factor, + ] + ), + ), + allowzero=0, + ) + + +def _generic_rnn( + g: jit_utils.GraphContext, + variant, + input, + initial_states, + all_weights, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_first=None, + batch_sizes=None, +): + warnings.warn( + "Exporting a model to ONNX with a batch_size other than 1, " + + "with a variable length with " + + variant + + " can cause an error " + + "when running the ONNX model with a different batch size. " + + "Make sure to save the model with a batch size of 1, " + + "or define the initial states (h0/c0) as inputs of the model. 
" + ) + + onnxActivations = [ + "Relu", + "Tanh", + "Sigmoid", + "Affine", + "LeakyRelu", + "ThresholdedRelu", + "ScaledTanh", + "HardSigmoid", + "Elu", + "Softsign", + "Softplus", + ] + variantToOnnxActivationMap = dict( + zip([act_fun.lower() for act_fun in onnxActivations], onnxActivations) + ) + weights_per_layer = 4 if has_biases else 2 + # this means that projections are used inside LSTM, so need to tell user that it's not supported + if variant == "LSTM" and len(all_weights) != num_layers * weights_per_layer * ( + 1 + bidirectional + ): + return symbolic_helper._unimplemented("LSTM", "LSTMs with projections", input) + assert len(all_weights) == num_layers * weights_per_layer * (1 + bidirectional) + layer_weights = [ + all_weights[i : i + weights_per_layer] + for i in range(0, len(all_weights), weights_per_layer) + ] + if batch_first: + # batch, seq, feat -> seq, batch, feat + input = g.op("Transpose", input, perm_i=[1, 0, 2]) + if dropout and train: + return symbolic_helper._unimplemented( + "RNN/GRU/LSTM", "dropout in training mode", input + ) + + if variant.startswith("RNN"): + nonlinearity = variantToOnnxActivationMap[variant[4:].lower()] + variant = "RNN" + + w_hh = all_weights[1] + hidden_size = symbolic_helper._get_tensor_dim_size(w_hh, 1) + if hidden_size is None: + return symbolic_helper._unimplemented( + "RNN/GRU/LSTM", "unknown hidden size", input + ) + + unidirectional = not bidirectional + + prev_output = input + + h_outs = [] + if variant == "RNN" or variant == "GRU": + h0 = initial_states + elif variant == "LSTM": + h0, c0 = initial_states + c_outs = [] + + sequence_lens = unused(g) if batch_sizes is None else batch_sizes + + if variant == "GRU": + # pytorch is reset, input, hidden + # onnx is input, reset, hidden + reform_permutation = [(1, 2), (0, 1), (2, 3)] + elif variant == "LSTM": + # pytorch is input, forget, cell, output. + # onnx is input, output, forget, cell. 
+ reform_permutation = [(0, 1), (3, 4), (1, 3)] + + def reform_weights(g, w, n, intervals): + slices = [ + symbolic_helper._slice_helper(g, w, axes=[0], starts=[x * n], ends=[y * n]) + for x, y in intervals + ] + return g.op("Concat", *slices, axis_i=0) + + def transform_weights_no_bias(layer_index): + weights = layer_weights[layer_index] + if variant == "RNN": + weight_ih, weight_hh = weights + elif variant == "GRU" or variant == "LSTM": + weight_ih, weight_hh = ( + reform_weights(g, w, hidden_size, reform_permutation) for w in weights + ) + return tuple( + symbolic_helper._unsqueeze_helper(g, x, [0]) + for x in (weight_ih, weight_hh) # type: ignore[possibly-undefined] + ) + + def transform_weights(layer_index): + weights = layer_weights[layer_index] + if variant == "RNN": + weight_ih, weight_hh, bias_ih, bias_hh = weights + elif variant == "GRU" or variant == "LSTM": + weight_ih, weight_hh, bias_ih, bias_hh = ( + reform_weights(g, w, hidden_size, reform_permutation) for w in weights + ) + bias_concat = g.op("Concat", bias_ih, bias_hh, axis_i=0) # type: ignore[possibly-undefined] + return tuple( + symbolic_helper._unsqueeze_helper(g, x, [0]) + for x in (weight_ih, weight_hh, bias_concat) # type: ignore[possibly-undefined] + ) + + def retrieve_state(x, start, end): + return ( + x + if num_layers == 1 + else symbolic_helper._slice_helper( + g, x, axes=[0], starts=[start], ends=[end] + ) + ) + + for i in range(num_layers): + if unidirectional: + if weights_per_layer == 4: + weight_ih, weight_hh, bias_concat = transform_weights(i) + else: + weight_ih, weight_hh = transform_weights_no_bias(i) + bias_concat = unused(g) + + state_indices = i, i + 1 + else: + if weights_per_layer == 4: + weight_ih_f, weight_hh_f, bias_f = transform_weights(2 * i) + weight_ih_b, weight_hh_b, bias_b = transform_weights(2 * i + 1) + bias_concat = g.op("Concat", bias_f, bias_b, axis_i=0) + else: + weight_ih_f, weight_hh_f = transform_weights_no_bias(2 * i) + weight_ih_b, weight_hh_b = transform_weights_no_bias(2 * i + 1) + bias_concat = unused(g) + + weight_ih = g.op("Concat", weight_ih_f, weight_ih_b, axis_i=0) + weight_hh = g.op("Concat", weight_hh_f, weight_hh_b, axis_i=0) + + state_indices = 2 * i, 2 * i + 2 + + inputs = [prev_output, weight_ih, weight_hh, bias_concat, sequence_lens] + + inputs.append(retrieve_state(h0, *state_indices)) # type: ignore[possibly-undefined] + if variant == "LSTM": + inputs.append(retrieve_state(c0, *state_indices)) # type: ignore[possibly-undefined] + + extra_kwargs = {} if unidirectional else {"direction_s": "bidirectional"} + if variant == "RNN": + if bidirectional: + activation = [nonlinearity, nonlinearity] # type: ignore[possibly-undefined] + else: + activation = [nonlinearity] # type: ignore[possibly-undefined] + + prev_output, h_out = g.op( + "RNN", + *inputs, + outputs=2, + hidden_size_i=hidden_size, + activations_s=activation, + **extra_kwargs, + ) + elif variant == "GRU": + prev_output, h_out = g.op( + "GRU", + *inputs, + outputs=2, + hidden_size_i=hidden_size, + linear_before_reset_i=1, + **extra_kwargs, + ) + elif variant == "LSTM": + prev_output, h_out, c_out = g.op( + "LSTM", *inputs, outputs=3, hidden_size_i=hidden_size, **extra_kwargs + ) + + if bidirectional: + # The ONNX RNN/GRU/LSTM produce an output of dimensions + # seq_len, num_directions, batch, hidden_size + # We have to convert to match pytorch's expected + # seq_len, batch, num_directions * hidden_size + # by first moving num_directions before hidden_size with + # Transpose, and then combining it with 
hidden_size + # with Reshape. + prev_output = g.op("Transpose", prev_output, perm_i=[0, 2, 1, 3]) + prev_output = symbolic_helper._reshape_helper( + g, + prev_output, + g.op("Constant", value_t=torch.LongTensor([0, 0, -1])), + allowzero=0, + ) + else: + prev_output = symbolic_helper._squeeze_helper(g, prev_output, [1]) + + h_outs.append(h_out) # type: ignore[possibly-undefined] + if variant == "LSTM": + c_outs.append(c_out) # type: ignore[possibly-undefined] + if batch_first: + # seq, batch, num_directions * hidden_size -> batch, seq, num_directions * hidden_size + prev_output = g.op("Transpose", prev_output, perm_i=[1, 0, 2]) + h_outs = h_out if num_layers == 1 else g.op("Concat", *h_outs, axis_i=0) # type: ignore[possibly-undefined] + if variant == "RNN" or variant == "GRU": + return prev_output, h_outs + elif variant == "LSTM": + c_outs = c_out if num_layers == 1 else g.op("Concat", *c_outs, axis_i=0) # type: ignore[possibly-undefined] + return prev_output, h_outs, c_outs + + +@symbolic_helper.parse_args("v", "v", "v", "i", "i", "f", "i", "i", "i") +def _lstm_full( + g: jit_utils.GraphContext, + input, + hidden_v, + weight_v, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_first, +): + hidden, weight = ( + symbolic_helper._unpack_list(hidden_v), + symbolic_helper._unpack_list(weight_v), + ) + return _generic_rnn( + g, + "LSTM", + input, + hidden, + weight, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_first, + ) + + +@symbolic_helper.parse_args("v", "v", "v", "v", "i", "i", "f", "i", "i") +def _lstm_packed( + g: jit_utils.GraphContext, + input, + batch_sizes, + hidden_v, + weight_v, + has_biases, + num_layers, + dropout, + train, + bidirectional, +): + hidden, weight = ( + symbolic_helper._unpack_list(hidden_v), + symbolic_helper._unpack_list(weight_v), + ) + return _generic_rnn( + g, + "LSTM", + input, + hidden, + weight, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_sizes=batch_sizes, + ) + + +@_onnx_symbolic("aten::lstm") +def lstm(g: jit_utils.GraphContext, *args): + if symbolic_helper._is_tensor_list(args[3]): + return _lstm_packed(g, *args) + else: + return _lstm_full(g, *args) + + +@_onnx_symbolic("aten::lstm_cell") +def lstm_cell(g: jit_utils.GraphContext, self, hidden, w_ih, w_hh, b_ih, b_hh): + input = symbolic_helper._unsqueeze_helper(g, self, [0]) + hidden = symbolic_helper._unpack_list(hidden) + hidden = [symbolic_helper._unsqueeze_helper(g, x, [0]) for x in hidden] + weight = ( + (w_ih, w_hh, b_ih, b_hh) if symbolic_helper._is_tensor(b_ih) else (w_ih, w_hh) + ) + has_biases = True if symbolic_helper._is_tensor(b_ih) else False + _, h_outs, c_outs = _generic_rnn( + g, + "LSTM", + input, + hidden, + weight, + has_biases, + num_layers=1, + dropout=0, + train=0, + bidirectional=False, + batch_first=False, + ) + return symbolic_helper._squeeze_helper( + g, h_outs, [0] + ), symbolic_helper._squeeze_helper(g, c_outs, [0]) + + +@_onnx_symbolic( + "aten::gru", decorate=[symbolic_helper._apply_params("GRU"), _export("gru")] +) +@_onnx_symbolic( + "aten::rnn_tanh", + decorate=[symbolic_helper._apply_params("RNN_TANH"), _export("rnn_tanh")], +) +@_onnx_symbolic( + "aten::rnn_relu", + decorate=[symbolic_helper._apply_params("RNN_RELU"), _export("rnn_relu")], +) +def _one_hidden_rnn(kind: str): + @symbolic_helper.parse_args("v", "v", "v", "i", "i", "f", "i", "i", "i") + def _rnn_full( + g, + input, + hidden, + weight_v, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_first, + ): + 
weight = symbolic_helper._unpack_list(weight_v) + return _generic_rnn( + g, + kind, + input, + hidden, + weight, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_first, + ) + + @symbolic_helper.parse_args("v", "v", "v", "v", "i", "i", "f", "i", "i") + def _rnn_packed( + g, + input, + batch_sizes, + hidden, + weight_v, + has_biases, + num_layers, + dropout, + train, + bidirectional, + ): + weight = symbolic_helper._unpack_list(weight_v) + return _generic_rnn( + g, + kind, + input, + hidden, + weight, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_sizes=batch_sizes, + ) + + def symbolic(g, *args): + if symbolic_helper._is_tensor_list(args[3]): + return _rnn_packed(g, *args) + else: + return _rnn_full(g, *args) + + return symbolic + + +@_onnx_symbolic("aten::_dim_arange") +@symbolic_helper.parse_args("v", "i") +def _dim_arange(g: jit_utils.GraphContext, like, dim): + like_shape = g.op("Shape", like) + stop = g.op( + "Gather", like_shape, g.op("Constant", value_t=torch.tensor(dim)), axis_i=0 + ) + # aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory) + return arange(g, stop, 4, None, None, None) + + +@_onnx_symbolic("aten::detach") +def detach(g: jit_utils.GraphContext, input): + # Erase aten::detach nodes because ONNX is inference only + return input + + +@_onnx_symbolic("aten::contiguous") +@symbolic_helper.parse_args("v", "i") +def contiguous(g: jit_utils.GraphContext, input, memory_format): + if memory_format > 2: # allower values are any, preserve and contiguous_format + raise errors.SymbolicValueError( + "onnx memory_format support is not implemented", input + ) + return input + + +@_onnx_symbolic("aten::_pack_padded_sequence") +@symbolic_helper.parse_args("v", "v", "i") +def _pack_padded_sequence(g: jit_utils.GraphContext, input, lengths, batch_first): + # Currently there is no PackPadded operator in ONNX. We rely on an + # optimization pass to remove this later. It is an error if all + # PackPadded operators cannot be optimized out. + if batch_first: + input = g.op("Transpose", input, perm_i=[1, 0, 2]) + if not lengths.type().isSubtypeOf(torch._C.TensorType.get()): + raise errors.SymbolicValueError( + "'lengths' must be a Tensor for ONNX export", input + ) + # We know it's a TensorType so this check is now safe. + # It's really only necessary because those operators expand to something that + # only works with int32 types in Caffe2... 
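# Editor's sketch (illustrative only, not part of this patch): aten::detach and
# aten::contiguous (for the supported memory formats) have no effect on an inference-only
# ONNX graph, which is why the two symbolics above simply return their input:
import torch

x = torch.randn(2, 2, requires_grad=True)
assert torch.equal(x.detach(), x)       # same values, only the autograd edge is dropped
assert torch.equal(x.contiguous(), x)   # already-contiguous input is returned unchanged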
+ if ( + _type_utils.JitScalarType.from_value( + lengths, _type_utils.JitScalarType.UNDEFINED + ) + != _type_utils.JitScalarType.INT + ): + lengths = g.op("Cast", lengths, to_i=_C_onnx.TensorProtoDataType.INT32) + return g.op("prim::PackPadded", input, lengths, outputs=2) + + +@_onnx_symbolic("aten::_pad_packed_sequence") +@symbolic_helper.parse_args("v", "v", "i", "t", "v") +def _pad_packed_sequence( + g: jit_utils.GraphContext, + data, + batch_sizes, + batch_first, + padding_value, + total_length, +): + # Ignore total_length as it is not supported in _symbolic_pad_packed_sequence + # It is only useful/used when training using data_parallel model, so + # It shouldn't be relevant for ONNX anyway + data, lengths = g.op("prim::PadPacked", data, batch_sizes, outputs=2) + if batch_first: + data = g.op("Transpose", data, perm_i=[1, 0, 2]) + return data, lengths + + +@_onnx_symbolic("aten::randint") +def randint(g: jit_utils.GraphContext, low, high, shapes, dtype, *options): + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + low_i = symbolic_helper._get_const(low, "i", "low") + high_i = symbolic_helper._get_const(high, "i", "high") + if dtype is None: + scalar_type = _type_utils.JitScalarType.INT64 + else: + scalar_type = _type_utils.JitScalarType(dtype) + if low_i is None: + raise symbolic_helper._onnx_unsupported("randint", low) + if high_i is None: + raise symbolic_helper._onnx_unsupported("randint", high) + + shape = symbolic_helper._maybe_get_const(shapes, "is") + if symbolic_helper._is_value(shape): + shape_const = g.op( + "ConstantOfShape", + shapes, + value_t=torch.tensor([0], dtype=torch.float), + ) + randn = g.op( + "RandomUniformLike", + shape_const, + low_f=low_i, + high_f=high_i, + ) + else: + randn = g.op( + "RandomUniform", + shape_i=shape, + low_f=low_i, + high_f=high_i, + ) + + # cast to integer type + int_dtype = _type_utils.JitScalarType.INT64 + randint = g.op("Cast", randn, to_i=int_dtype.onnx_type()) + if int_dtype != scalar_type: + randint = g.op("Cast", randint, to_i=scalar_type.onnx_type()) + return randint + + +@_onnx_symbolic("aten::randint_like") +def randint_like(g: jit_utils.GraphContext, self, low, high, dtype, *options): + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + low_i = symbolic_helper._get_const(low, "i", "low") + high_i = symbolic_helper._get_const(high, "i", "high") + if dtype is None: + scalar_type = _type_utils.JitScalarType.INT64 + else: + scalar_type = _type_utils.JitScalarType(dtype) + if low_i is None: + raise symbolic_helper._onnx_unsupported("randint", low) + if high_i is None: + raise symbolic_helper._onnx_unsupported("randint", high) + + randn = g.op( + "RandomUniformLike", + self, + low_f=low_i, + high_f=high_i, + ) + + # cast to integer type + int_dtype = _type_utils.JitScalarType.INT64 + randint = g.op("Cast", randn, to_i=int_dtype.onnx_type()) + if int_dtype != scalar_type: + randint = g.op("Cast", randint, to_i=scalar_type.onnx_type()) + return randint + + +@_onnx_symbolic("aten::randn") +def randn(g: jit_utils.GraphContext, shapes, dtype, *options): + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + if dtype is None: + scalar_type = _type_utils.JitScalarType.FLOAT + else: + scalar_type = _type_utils.JitScalarType(dtype) + shape = symbolic_helper._maybe_get_const(shapes, "is") + if symbolic_helper._is_value(shape): + shape_const = g.op( + "ConstantOfShape", + shapes, + value_t=torch.tensor([0], dtype=torch.float), + ) + return g.op( + "RandomNormalLike", + shape_const, + dtype_i=scalar_type.onnx_type(), + ) + 
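# Editor's sketch (illustrative only, not part of this patch): ONNX has no integer
# RandomUniform, so aten::randint above samples floats in [low, high) and then Casts to
# int64; for these non-negative samples the Cast truncates toward zero. Eager equivalent:
import torch

low, high = 0, 10
u = torch.empty(1000).uniform_(low, high)  # float samples in [low, high)
samples = u.to(torch.int64)                # truncation, mirroring the exported Cast
assert int(samples.min()) >= low and int(samples.max()) < high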
return g.op( + "RandomNormal", + shape_i=shape, + dtype_i=scalar_type.onnx_type(), + ) + + +@_onnx_symbolic("aten::rand") +def rand(g: jit_utils.GraphContext, shapes, dtype, *options): + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + if dtype is None: + scalar_type = _type_utils.JitScalarType.FLOAT + else: + scalar_type = _type_utils.JitScalarType(dtype) + shape = symbolic_helper._maybe_get_const(shapes, "is") + if symbolic_helper._is_value(shape): + shape_const = g.op( + "ConstantOfShape", + shapes, + value_t=torch.tensor([0], dtype=torch.float), + ) + return g.op( + "RandomUniformLike", + shape_const, + dtype_i=scalar_type.onnx_type(), + ) + return g.op( + "RandomUniform", + shape_i=shape, + dtype_i=scalar_type.onnx_type(), + ) + + +@_onnx_symbolic("aten::randn_like") +def randn_like( + g: jit_utils.GraphContext, + self, + dtype, + layout=None, + device=None, + pin_memory=False, + memory_format=None, +): + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + if dtype is None: + scalar_type = _type_utils.JitScalarType.from_value( + self, _type_utils.JitScalarType.FLOAT + ) + else: + scalar_type = _type_utils.JitScalarType(dtype) + return g.op("RandomNormalLike", self, dtype_i=scalar_type.onnx_type()) + + +@_onnx_symbolic("aten::rand_like") +def rand_like( + g: jit_utils.GraphContext, + self, + dtype, + layout=None, + device=None, + pin_memory=False, + memory_format=None, +): + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + if dtype is None: + dtype = _type_utils.JitScalarType.from_value( + self, _type_utils.JitScalarType.FLOAT + ) + return g.op( + "RandomUniformLike", self, dtype_i=_type_utils.JitScalarType(dtype).onnx_type() + ) + + +@_onnx_symbolic("aten::rrelu") +@symbolic_helper.parse_args("v", "f", "f", "i", "none") +def rrelu(g: jit_utils.GraphContext, input, lower, upper, training, generator): + if not training: + slope = (upper + lower) / 2.0 + return g.op("LeakyRelu", input, alpha_f=slope) + p = g.op("RandomUniformLike", input, high_f=upper, low_f=lower) + return g.op("PRelu", input, p) + + +@_onnx_symbolic("aten::bernoulli") +def bernoulli(g: jit_utils.GraphContext, input, p=None, generator=None, out=None): + if out is not None and not symbolic_helper._is_none(out): + symbolic_helper._unimplemented( + "Bernoulli", "out parameter is not supported for bernoulli", input + ) + if generator is not None and not symbolic_helper._is_none(generator): + symbolic_helper._unimplemented( + "Bernoulli", "generator is not supported for bernoulli", input + ) + + dtype = _type_utils.JitScalarType.from_value( + input, _type_utils.JitScalarType.UNDEFINED + ) + if dtype == _type_utils.JitScalarType.UNDEFINED: + return symbolic_helper._unimplemented( + "Bernoulli", "input dtype not accessible", input + ) + + rands = g.op( + "RandomUniformLike", + input, + high_f=1.0, + low_f=0.0, + dtype_i=dtype.onnx_type(), + ) + prob = p if p is not None and not symbolic_helper._is_none(p) else input + output = g.op("Less", rands, prob) + return g.op("Cast", output, to_i=dtype.onnx_type()) + + +@_onnx_symbolic("aten::log_sigmoid") +@symbolic_helper.parse_args("v") +def log_sigmoid(g: jit_utils.GraphContext, input): + p = g.op("Sigmoid", input) + return g.op("Log", p) + + +@_onnx_symbolic("aten::erf") +@symbolic_helper.parse_args("v") +def erf(g: jit_utils.GraphContext, input): + return g.op("Erf", input) + + +@_onnx_symbolic("aten::flatten") +@symbolic_helper.quantized_args(True, False, False) +@symbolic_helper.parse_args("v", "i", "i") +def flatten(g: jit_utils.GraphContext, input, 
start_dim, end_dim): + dim = symbolic_helper._get_tensor_rank(input) + if dim is None: + return symbolic_helper._unimplemented( + "dim", + "ONNX and PyTorch use different strategies to split the input. " + "Input rank must be known at export time.", + input, + ) + + if dim == 0: + return symbolic_helper._reshape_helper(g, input, [1]) + if dim == 1: + return g.op("Identity", input) + # TODO: remove this as onnx opset 11 spec allows negative axes + if end_dim < 0: + end_dim = dim + end_dim + # use ONNX's Flatten operator for cases where the output shape is 2D + if start_dim == 1 and end_dim == dim - 1: + return g.op("Flatten", input, axis_i=start_dim) + if start_dim == 0 and end_dim == dim - 2: + return g.op("Flatten", input, axis_i=end_dim + 1) + + return symbolic_helper._flatten_helper(g, input, start_dim, end_dim, dim) + + +@_onnx_symbolic("aten::nonzero") +@symbolic_helper.parse_args("v") +def nonzero(g: jit_utils.GraphContext, input): + """Emitted from `torch.nonzero(x, as_tuple=False)`""" + return t(g, g.op("NonZero", input)) + + +@_onnx_symbolic("aten::nonzero_numpy") +# Emitted from `torch.nonzero(x, as_tuple=True)` +def nonzero_numpy(g: jit_utils.GraphContext, input, _outputs=None): + return unbind(g, nonzero(g, input), 1, _outputs=_outputs) + + +@_onnx_symbolic("aten::isnan") +@symbolic_helper.parse_args("v") +def isnan(g: jit_utils.GraphContext, input): + output = g.op("IsNaN", input) + return output + + +@_onnx_symbolic("aten::any") +def _any(g: jit_utils.GraphContext, *args): + # aten::any(Tensor self) + if len(args) == 1: + input = args[0] + dim, keepdim = None, 0 + # aten::any(Tensor self, int[]? dim, bool keepdim) + else: + input, dim, keepdim = args + # Can be int list or single int + dim = symbolic_helper._parse_arg(dim, "t") + dim = [int(d) for d in dim.view(-1)] + keepdim = symbolic_helper._parse_arg(keepdim, "i") + input = g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT64) + input_sum = symbolic_helper._reducesum_helper( + g, input, axes_i=dim, keepdims_i=keepdim + ) + return gt(g, input_sum, g.op("Constant", value_t=torch.tensor(0, dtype=torch.long))) + + +@_onnx_symbolic("aten::all") +def _all(g: jit_utils.GraphContext, *args): + input = g.op("Not", args[0]) + # aten::all(Tensor self) + if len(args) == 1: + return g.op("Not", _any(g, input)) + # aten::all(Tensor self, int[]? 
dim, bool keepdim) + else: + return g.op("Not", _any(g, input, args[1], args[2])) + + +@_onnx_symbolic("aten::narrow") +@symbolic_helper.parse_args("v", "i", "i", "i") +def narrow(g: jit_utils.GraphContext, input, dim, start, length): + return symbolic_helper._slice_helper( + g, input, axes=[dim], starts=[start], ends=[start + length] + ) + + +@_onnx_symbolic("aten::argmax") +@symbolic_helper.parse_args("v", "v", "b") +def argmax( + g: jit_utils.GraphContext, + input: torch._C.Value, + dim: torch._C.Value, + keepdim: bool, +): + return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMax") + + +@_onnx_symbolic("aten::argmin") +@symbolic_helper.parse_args("v", "v", "b") +def argmin( + g: jit_utils.GraphContext, + input: torch._C.Value, + dim: torch._C.Value, + keepdim: bool, +): + return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMin") + + +@_onnx_symbolic("aten::scatter") +@symbolic_helper.parse_args("v", "i", "v", "v") +def scatter(g: jit_utils.GraphContext, self, dim, index, src): + src_type = _type_utils.JitScalarType.from_value( + src, _type_utils.JitScalarType.UNDEFINED + ) + src = symbolic_helper._maybe_get_scalar(src) + if symbolic_helper._is_value(src): + return g.op("Scatter", self, index, src, axis_i=dim) + else: + # Check if scalar "src" has same type as self (PyTorch allows different + # type for scalar src (but not when src is tensor)). If not, insert Cast node. + self_scalar_type = _type_utils.JitScalarType.from_value(self) + if self_scalar_type != src_type: + src = g.op("Cast", src, to_i=self_scalar_type.onnx_type()) + return g.op("Scatter", self, index, expand_as(g, src, index), axis_i=dim) + + +@_onnx_symbolic("aten::scatter_add") +@symbolic_helper.parse_args("v", "i", "v", "v") +def scatter_add(g: jit_utils.GraphContext, self, dim, index, src): + scalar_type = symbolic_helper._try_get_scalar_type(self) + if scalar_type is None: + return symbolic_helper._unimplemented( + "scatter_add", "input dtype not accessible", self + ) + sizes = symbolic_helper._get_tensor_sizes(self, allow_nonstatic=False) + if sizes: + to_add = g.op("Constant", value_t=torch.zeros(sizes, dtype=scalar_type.dtype())) + else: + to_add = zeros_like(g, self, scalar_type) + to_add = symbolic_helper._scatter_helper(g, to_add, dim, index, src) + return add(g, self, to_add) + + +@_onnx_symbolic("aten::log2") +def log2(g: jit_utils.GraphContext, self): + _ln2 = 0.693147180559945309 + return g.op("Div", log(g, self), g.op("Constant", value_t=torch.tensor(_ln2))) + + +@_onnx_symbolic("aten::is_floating_point") +def is_floating_point(g: jit_utils.GraphContext, self): + if symbolic_helper._is_fp(self): + return g.op("Constant", value_t=torch.BoolTensor([1])) + return g.op("Constant", value_t=torch.BoolTensor([0])) + + +@_onnx_symbolic("aten::__is_") +def __is_(g: jit_utils.GraphContext, self, other): + if symbolic_helper._is_none(other): + if symbolic_helper._is_none(self): + return g.op("Constant", value_t=torch.BoolTensor([1])) + return g.op("Constant", value_t=torch.BoolTensor([0])) + return eq(g, self, other) + + +@_onnx_symbolic("aten::__isnot_") +@wrap_logical_op_with_negation +def __isnot_(g: jit_utils.GraphContext, self, other): + return __is_(g, self, other) + + +@_onnx_symbolic("aten::one_hot") +def one_hot(g: jit_utils.GraphContext, self, num_classes): + values = g.op("Constant", value_t=torch.LongTensor([0, 1])) + # onnxruntime supports limited type combinations for OneHot. 
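# Editor's sketch (illustrative only, not part of this patch): ONNX has no Log2 operator,
# so aten::log2 above is emitted as Log(x) divided by the constant ln 2. Numerically:
import math

import torch

x = torch.rand(4) + 0.1
assert torch.allclose(torch.log2(x), torch.log(x) / math.log(2.0))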
+ if _type_utils.JitScalarType.from_value( + num_classes, _type_utils.JitScalarType.UNDEFINED + ) in { + _type_utils.JitScalarType.UINT8, + _type_utils.JitScalarType.INT8, + _type_utils.JitScalarType.INT, + _type_utils.JitScalarType.INT16, + }: + num_classes = g.op("Cast", num_classes, to_i=_C_onnx.TensorProtoDataType.INT64) + return g.op("OneHot", self, num_classes, values, axis_i=-1) + + +@_onnx_symbolic("aten::gather") +@symbolic_helper.parse_args("v", "i", "v", "v") +def gather(g: jit_utils.GraphContext, self, dim, index, sparse_grad=False): + if symbolic_helper._maybe_get_const(sparse_grad, "i"): + return symbolic_helper._unimplemented("gather", "sparse_grad == True", self) + # NOTE: This workaround is needed since GatherElement is only supported + # since opset 11, and Gather in ONNX is not the same as torch.gather. + scalar_type = _type_utils.JitScalarType.from_value(self) + values = g.op("Constant", value_t=torch.LongTensor([0, 1])) + depth = size(g, self, g.op("Constant", value_t=torch.LongTensor([dim]))) + index = g.op( + "Cast", + g.op("OneHot", index, depth, values, axis_i=dim), + to_i=scalar_type.onnx_type(), + ) + mul = g.op("Mul", symbolic_helper._unsqueeze_helper(g, self, [dim + 1]), index) + return symbolic_helper._reducesum_helper(g, mul, axes_i=[dim], keepdims_i=0) + + +@symbolic_helper.parse_args("v", "is", "i", "i") +def _var_mean(g: jit_utils.GraphContext, input, dim, correction, keepdim): + return symbolic_helper._var_mean_helper(g, input, dim, correction, keepdim) + + +@_onnx_symbolic("aten::std") +def std(g: jit_utils.GraphContext, input, *args): + var, _ = var_mean(g, input, *args) + return g.op("Sqrt", var) + + +@_onnx_symbolic("aten::var") +def var(g: jit_utils.GraphContext, input, *args): + var, _ = var_mean(g, input, *args) + return var + + +@_onnx_symbolic("aten::var_mean") +def var_mean(g: jit_utils.GraphContext, input, *args): + if len(args) == 1: + return _var_mean(g, input, None, args[0], None) + else: + return _var_mean(g, input, *args) + + +@_onnx_symbolic("aten::std_mean") +def std_mean(g: jit_utils.GraphContext, input, *args): + var, mean = var_mean(g, input, *args) + return g.op("Sqrt", var), mean + + +@_onnx_symbolic("aten::logsumexp") +@symbolic_helper.parse_args("v", "is", "i") +def logsumexp(g: jit_utils.GraphContext, input, dim, keepdim): + return g.op("ReduceLogSumExp", input, axes_i=dim, keepdims_i=keepdim) + + +@_onnx_symbolic("aten::arange") +def arange(g: jit_utils.GraphContext, *args): + def _get_arange_dtype(dtype): + dtype = symbolic_helper._maybe_get_const(dtype, "i") + return dtype + + def _float_step_convert(range_tensor): + if symbolic_helper._is_fp(range_tensor): + range_tensor = g.op( + "Cast", + g.op("Ceil", range_tensor), + to_i=_type_utils.JitScalarType.INT64.onnx_type(), + ) + return range_tensor + + if len(args) == 2 or len(args) == 5: + if len(args) == 2: + # aten::arange(Scalar end, Tensor out) + dtype = None + else: + # aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory) + dtype = _get_arange_dtype(args[1]) + dtype, end, start, step = symbolic_helper._arange_cast_helper( + g, end=args[0], dtype=dtype + ) + end = symbolic_helper._unsqueeze_helper(g, end, [0]) + range_tensor = _float_step_convert(end) + arange_tensor = symbolic_helper._squeeze_helper( + g, nonzero(g, ones(g, range_tensor, dtype, None, None)), [1] + ) + return g.op( + "Cast", arange_tensor, to_i=_type_utils.JitScalarType(dtype).onnx_type() + ) + elif len(args) == 4 or len(args) == 7: + if len(args) == 4: + # aten::arange(Scalar start, 
Scalar end, Scalar step, Tensor out) + dtype = None + else: + # aten::arange(Scalar start, Scalar end, Scalar step, ScalarType dtype, Layout, Device, bool pin_memory) + dtype = _get_arange_dtype(args[3]) + dtype, end, start, step = symbolic_helper._arange_cast_helper( + g, start=args[0], end=args[1], step=args[2], dtype=dtype + ) + step = symbolic_helper._unsqueeze_helper(g, step, [0]) + end = symbolic_helper._unsqueeze_helper(g, end, [0]) + start = symbolic_helper._unsqueeze_helper(g, start, [0]) + range_tensor = _float_step_convert(g.op("Div", g.op("Sub", end, start), step)) + arange_tensor = symbolic_helper._squeeze_helper( + g, nonzero(g, ones(g, range_tensor, None, None, None)), [1] + ) + arange_tensor = g.op("Add", g.op("Mul", arange_tensor, step), start) + return g.op( + "Cast", arange_tensor, to_i=_type_utils.JitScalarType(dtype).onnx_type() + ) + elif len(args) == 6: + # aten::arange(Scalar start, Scalar end, ScalarType dtype, Layout, Device, bool pin_memory) + dtype = _get_arange_dtype(args[2]) + dtype, end, start, step = symbolic_helper._arange_cast_helper( + g, start=args[0], end=args[1], dtype=dtype + ) + end = symbolic_helper._unsqueeze_helper(g, end, [0]) + start = symbolic_helper._unsqueeze_helper(g, start, [0]) + range_tensor = _float_step_convert(g.op("Sub", end, start)) + arange_tensor = g.op( + "Add", + symbolic_helper._squeeze_helper( + g, nonzero(g, ones(g, range_tensor, dtype, *(args[3:]))), [1] + ), + start, + ) + return g.op( + "Cast", arange_tensor, to_i=_type_utils.JitScalarType(dtype).onnx_type() + ) + + return symbolic_helper._unimplemented("aten::arange", f"with {len(args)} arguments") + + +@_onnx_symbolic("aten::linspace") +def linspace( + g: jit_utils.GraphContext, start, end, steps, dtype, layout, device, pin_memory +): + range_tensor = symbolic_helper._arange_helper(g, steps, None) + step = div( + g, + sub(g, end, start), + sub(g, steps, g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))), + ) + return add(g, mul(g, range_tensor, step), start) + + +@_onnx_symbolic("aten::lift") +def lift(g: jit_utils.GraphContext, self): + # at::lift() is a no-op from the perspective of tracing for onnx + return self + + +@_onnx_symbolic("aten::masked_fill") +def masked_fill(g: jit_utils.GraphContext, self, mask, value): + """Implement the masked_fill functionality available for a pytorch tensor in ONNX. + + Fills elements of the input tensor with `value` where `mask` is True. + """ + mask = g.op("Cast", mask, to_i=_C_onnx.TensorProtoDataType.BOOL) + value = symbolic_helper._maybe_get_scalar(value) + return g.op("Where", mask, symbolic_helper._if_scalar_type_as(value, self), self) + + +@_onnx_symbolic("aten::masked_fill_") +def masked_fill_(g: jit_utils.GraphContext, self, mask, value): + return masked_fill(g, self, mask, value) + + +@_onnx_symbolic("aten::index") +def index(g: jit_utils.GraphContext, self, index): + if symbolic_helper._is_packed_list(index): + indices = symbolic_helper._unpack_list(index) + else: + indices = [index] + + def try_mask_to_index(index): + if not symbolic_helper._is_none(index) and ( + _type_utils.JitScalarType.from_value( + index, _type_utils.JitScalarType.UNDEFINED + ) + == _type_utils.JitScalarType.UINT8 + or symbolic_helper._is_bool(index) + ): + if g.opset < 9: + raise errors.SymbolicValueError( + "Exporting masked indices are only supported after ONNX opset 9.", + self, + ) + warnings.warn( + "Exporting aten::index operator with indices of type Byte. " + "Only 1-D indices are supported. 
In any other case, " + "this will produce an incorrect ONNX graph." + ) + index = symbolic_helper._squeeze_helper(g, nonzero(g, index), [1]) + return index + + indices = [try_mask_to_index(idx) for idx in indices] + if len(indices) == 1: + return symbolic_helper._select_helper( + g, self, 0, indices[0], apply_reshape=False + ) + else: + # Multiple tensors as indices. Each tensor could either be + # 1. prim::Constant() + # representing ":" in python indexing. E.g. tensor[:, :] + # 2. prim::Constant[value=...] or tensor output + # representing advanced indexing. E.g. tensor[[0, 1], [2, 0]]. + # For more info on advanced indexing, + # check https://numpy.org/doc/stable/user/basics.indexing.html#advanced-indexing + + # Consider a general case of + # t: [x_1, y_1, y_2, ..., x_m, ..., y_n] + # where t is a tensor of rank m+n, {x_i} are axes where tensor index is provided, and {y_i} are axes for ":". + # Same results can be achieved through transposing t into + # t: [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n] + # and use gatherND. However ONNX does not have gatherND, to use 1d gather we'll need to flatten t + # and process the tensor indices. + # t: [x_1 * x_2 * ... * x_m, y_1 * y_2 * ... * y_n] + # tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)) + # After gather, reshape and transpose back. + adv_idx_indices = [ + i for i, idx in enumerate(indices) if not symbolic_helper._is_none(idx) + ] + + if len(adv_idx_indices) == 0: + return self + elif len(adv_idx_indices) == 1: + return index_select( + g, self, adv_idx_indices[0], indices[adv_idx_indices[0]] + ) + else: + rank = symbolic_helper._get_tensor_rank(self) + if rank is None: + return symbolic_helper._unimplemented( + "aten::index", + "operator of advanced indexing on tensor of unknown rank. ", + self, + ) + # TODO: If indexing is supported natively in ONNX in future opsets, + # update the warning to recommend exporting with higher opset version. + warnings.warn( + "Exporting aten::index operator of advanced indexing in opset " + f"{GLOBALS.export_onnx_opset_version}" + " is achieved by combination of multiple ONNX operators, " + "including Reshape, Transpose, Concat, and Gather. " + "If indices include negative values, the exported graph will produce incorrect results." + ) + adv_idx_count = len(adv_idx_indices) + shape_tensor = _shape_as_tensor(g, self) + dim_tensor_list = [ + g.op( + "Gather", + shape_tensor, + g.op("Constant", value_t=torch.LongTensor([dim])), + axis_i=0, + ) + for dim in range(rank) + ] + + self = g.op( + "Transpose", + self, + perm_i=adv_idx_indices + + [i for i in range(rank) if i not in adv_idx_indices], + ) + self = g.op("Flatten", self, axis_i=adv_idx_count) + + # Note that tensor indices will be broadcasted while accumulating. Thus we get the final subarray shape as well. + cum_adv_index = indices[adv_idx_indices[-1]] + multiplier = dim_tensor_list[adv_idx_indices[-1]] + for i in range(adv_idx_count - 2, -1, -1): + adv_index = g.op("Mul", indices[adv_idx_indices[i]], multiplier) + cum_adv_index = g.op("Add", cum_adv_index, adv_index) + multiplier = g.op( + "Mul", multiplier, dim_tensor_list[adv_idx_indices[i]] + ) + + # perform gather + self = index_select(g, self, 0, cum_adv_index) + + cum_adv_index_shape_tensor = _shape_as_tensor(g, cum_adv_index) + # check if all advanced indices are consecutive. + # Refer to https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing + # to understand how the subarray position is decided. 
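# Editor's sketch (illustrative only, not part of this patch): the comment above describes
# the lowering used here for advanced indexing: flatten the indexed axes and Gather with a
# single linear index. Eager-mode equivalent for t[[0, 2], [1, 3]] on a (3, 4) tensor:
import torch

t = torch.arange(12).reshape(3, 4)
rows = torch.tensor([0, 2])
cols = torch.tensor([1, 3])
linear_index = rows * t.size(1) + cols  # row * (number of columns) + col
assert torch.equal(t[rows, cols], t.reshape(-1)[linear_index])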
+ if adv_idx_indices == list( + range(adv_idx_indices[0], adv_idx_indices[-1] + 1) + ): + # unfold regular index axes + folded_adv_idx_shape_list = [ + g.op("Constant", value_t=torch.LongTensor([-1])) + ] + [ + dim_tensor_list[i] for i in range(rank) if i not in adv_idx_indices + ] + folded_adv_idx_shape = g.op( + "Concat", *folded_adv_idx_shape_list, axis_i=0 + ) + self = symbolic_helper._reshape_helper(g, self, folded_adv_idx_shape) + + # Transpose folded advanced indexed axis to its original location. + adv_idx_permute = ( + list(range(1, adv_idx_indices[0] + 1)) + + [0] + + list(range(adv_idx_indices[0] + 1, rank - adv_idx_count + 1)) + ) + self = g.op("Transpose", self, perm_i=adv_idx_permute) + + # unfold advanced index axes + final_shape_list = ( + [dim_tensor_list[i] for i in range(adv_idx_indices[0])] + + [cum_adv_index_shape_tensor] + + [ + dim_tensor_list[i] + for i in range(adv_idx_indices[0], rank) + if i not in adv_idx_indices + ] + ) + final_shape = g.op("Concat", *final_shape_list, axis_i=0) + else: + final_shape = g.op( + "Concat", + cum_adv_index_shape_tensor, + *[ + dim_tensor_list[i] + for i in range(rank) + if i not in adv_idx_indices + ], + axis_i=0, + ) + + return symbolic_helper._reshape_helper(g, self, final_shape) + + +@_onnx_symbolic("aten::linalg_norm") +@symbolic_helper.parse_args("v", "v", "is", "b", "v") +def linalg_norm( + g: jit_utils.GraphContext, + self: torch._C.Value, + ord: torch._C.Value, + dim: Sequence[int] | None, + keepdim: bool, + dtype: torch._C.Value, +): + # Conditions based on https://pytorch.org/docs/stable/generated/torch.linalg.norm.html + ord_value = None + if dim is None: + if symbolic_helper._is_none(ord): + self = symbolic_helper._reshape_helper(g, self, [-1]) + ord = g.op("Constant", value_t=torch.LongTensor([2])) + self_dim = symbolic_helper._get_tensor_rank(self) + if self_dim is None: + return symbolic_helper._unimplemented( + "dim", "Input rank must be known at export time.", self + ) + if self_dim == 1: + ord_value = symbolic_helper._parse_arg(ord, "f") + else: + dim = [0, 1] + else: + if len(dim) == 1: + if symbolic_helper._is_none(ord): + ord = g.op("Constant", value_t=torch.LongTensor([2])) + ord_value = symbolic_helper._parse_arg(ord, "f") + if ord_value: + return linalg_vector_norm(g, self, ord_value, dim, keepdim, dtype) + return linalg_matrix_norm(g, self, ord, dim, keepdim, dtype) + + +@_onnx_symbolic("aten::linalg_vector_norm") +@symbolic_helper.parse_args("v", "f", "is", "b", "v") +def linalg_vector_norm( + g: jit_utils.GraphContext, + self: torch._C.Value, + ord: float, + dim: Sequence[int] | None, + keepdim: bool, + dtype: torch._C.Value, +): + return symbolic_helper._linalg_vector_norm_helper(g, self, ord, dim, keepdim, dtype) + + +@_onnx_symbolic("aten::linalg_matrix_norm") +@symbolic_helper.parse_args("v", "v", "is", "b", "v") +def linalg_matrix_norm( + g: jit_utils.GraphContext, + self: torch._C.Value, + ord: torch._C.Value, + dim: list[int], + keepdim: bool, + dtype: torch._C.Value, +): + # Conditions based on https://pytorch.org/docs/stable/generated/torch.linalg.matrix_norm.html + ord_value = symbolic_helper._parse_arg(ord, "s") + if ord_value == "fro": + return frobenius_norm(g, self, dim, keepdim) + elif ord_value == "nuc": + return symbolic_helper._unimplemented("linalg.matrix_norm", "ord==nuc", self) + else: + ord_value = symbolic_helper._parse_arg(ord, "f") + if ord_value is None: + return frobenius_norm(g, self, dim, keepdim) + if ord_value == 2 or ord_value == -2: + # ord = 2/-2 unimplemented due to lack 
of operators + # used to calculate singular values + return symbolic_helper._unimplemented("linalg.matrix_norm", "ord==2", self) + # Wrap the dim vector to handle negative dim values + self_dim = symbolic_helper._get_tensor_rank(self) + if self_dim is None: + return symbolic_helper._unimplemented( + "linalg.matrix_norm", "Input rank must be known at export time.", self + ) + # Common implementation for cases with + # ord = 1/-1 and ord = inf/-inf + if dim[0] < 0: + dim[0] += self_dim + if dim[1] < 0: + dim[1] += self_dim + + if ord_value == math.inf or ord_value == -math.inf: + dim[0], dim[1] = dim[1], dim[0] + if dim[1] > dim[0] and not keepdim: + dim[1] -= 1 + sum = symbolic_helper._reducesum_helper( + g, g.op("Abs", self), axes_i=[dim[0]], keepdims_i=keepdim + ) + if ord_value > 0: + result, _indices = max( + g, + sum, + dim_or_y=g.op("Constant", value_t=torch.LongTensor([dim[1]])), + keepdim=keepdim, + ) + else: + result, _indices = min( + g, + sum, + dim_or_y=g.op("Constant", value_t=torch.LongTensor([dim[1]])), + keepdim=keepdim, + ) + return result + + +@_onnx_symbolic("aten::linalg_cross") +@symbolic_helper.parse_args("v", "v", "i") +def linalg_cross(g: jit_utils.GraphContext, input, other, dim=-1): + return cross(g, input, other, dim) + + +@_onnx_symbolic("aten::frobenius_norm") +@symbolic_helper.parse_args("v", "is", "b") +def frobenius_norm(g: jit_utils.GraphContext, self, dim=None, keepdim=False): + sqr = g.op("Mul", self, self) + sumsqr = symbolic_helper._reducesum_helper(g, sqr, axes_i=dim, keepdims_i=keepdim) + return g.op("Sqrt", sumsqr) + + +@_onnx_symbolic("aten::multinomial") +@symbolic_helper.parse_args("v", "i", "b", "v") +def multinomial( + g: jit_utils.GraphContext, input, num_samples, replacement=False, generator=None +): + if generator is not None and not symbolic_helper._is_none(generator): + symbolic_helper._unimplemented( + "Multinomial", "generator is not supported for multinomial", input + ) + if not replacement and num_samples > 1: + symbolic_helper._unimplemented( + "Multinomial", + "replacement=False when num_samples > 1 is not supported for multinomial", + input, + ) + + log_input = log(g, input) + return g.op( + "Multinomial", + log_input, + dtype_i=_C_onnx.TensorProtoDataType.INT64, + sample_size_i=num_samples, + ) + + +@_onnx_symbolic("aten::baddbmm") +def baddbmm(g: jit_utils.GraphContext, self, batch1, batch2, beta, alpha): + scalar_type = _type_utils.JitScalarType.from_value(self) + batch_mul = matmul(g, batch1, batch2) + mul_a = mul( + g, + batch_mul, + g.op("Cast", alpha, to_i=scalar_type.onnx_type()), + ) + mul_b = mul( + g, + self, + g.op("Cast", beta, to_i=scalar_type.onnx_type()), + ) + return add(g, mul_a, mul_b) + + +@_onnx_symbolic("aten::meshgrid") +@symbolic_helper.parse_args("v", "s") +def meshgrid(g: jit_utils.GraphContext, tensor_list, indexing: str | None = None): + if indexing is None: + indexing = "ij" + elif indexing not in {"ij", "xy"}: + raise errors.SymbolicValueError( + f"Unsupported indexing: {indexing}", tensor_list + ) + unpacked_tensor_list = symbolic_helper._unpack_list(tensor_list) + if indexing == "xy": + unpacked_tensor_list[:2] = unpacked_tensor_list[1::-1] + tensors = [ + symbolic_helper._reshape_helper( + g, t, g.op("Constant", value_t=torch.LongTensor([-1])) + ) + for t in unpacked_tensor_list + ] + tensors_shape = [g.op("Shape", t) for t in tensors] + out_shape = g.op("Concat", *tensors_shape, axis_i=0) + out = [] + for i, t in enumerate(tensors): + shape_i = [g.op("Constant", value_t=torch.ones(1, 
dtype=torch.int64))] * len( + tensors + ) + shape_i[i] = tensors_shape[i] + t_reshaped = _reshape_from_tensor(g, t, g.op("Concat", *shape_i, axis_i=0)) + out.append(g.op("Expand", t_reshaped, out_shape)) + if indexing == "xy": + out[0], out[1] = out[1], out[0] + return g.op("prim::ListConstruct", *out) + + +@_onnx_symbolic("aten::remainder") +def remainder(g: jit_utils.GraphContext, input, other): + div = _floor_divide(g, input, other) + quo = g.op("Mul", div, other) + return g.op("Sub", input, quo) + + +@_onnx_symbolic("aten::gelu") +@symbolic_helper.parse_args("v", "s") +def gelu(g: jit_utils.GraphContext, self: torch._C.Value, approximate: str = "none"): + if approximate == "tanh": + kBeta = math.sqrt(2 / math.pi) + kKappa = 0.044715 + + beta = torch.tensor(kBeta, dtype=torch.double) + kappa = torch.tensor(kKappa, dtype=torch.double) + one = torch.tensor(1.0, dtype=torch.double) + half = torch.tensor(0.5, dtype=torch.double) + + self_cube = mul(g, self, mul(g, self, self)) + inner = mul(g, beta, add(g, self, mul(g, kappa, self_cube))) + return mul(g, half, mul(g, self, add(g, one, g.op("Tanh", inner)))) + else: + _sqrt2 = 1.4142135623730951 + erf = g.op("Erf", g.op("Div", self, torch.tensor(_sqrt2, dtype=torch.double))) + erf_plusone = add( + g, erf, g.op("Constant", value_t=torch.tensor(1, dtype=torch.double)) + ) + return mul( + g, + mul(g, self, erf_plusone), + g.op("Constant", value_t=torch.tensor(0.5, dtype=torch.double)), + ) + + +@_onnx_symbolic("aten::group_norm") +@symbolic_helper.quantized_args(True, False, False, False) +@symbolic_helper.parse_args("v", "i", "v", "v", "f", "i") +def group_norm( + g: jit_utils.GraphContext, input, num_groups, weight, bias, eps, cudnn_enabled +): + channel_size = symbolic_helper._get_tensor_dim_size(input, 1) + if channel_size is not None: + assert channel_size % num_groups == 0 + input_rank = symbolic_helper._get_tensor_rank(input) + if input_rank is None: + return symbolic_helper._unimplemented("group_norm", "unknown input rank", input) + # 0 in the shape list keeps dimension value unchanged. + shape = [0, num_groups, -1] + input_reshaped = symbolic_helper._reshape_helper( + g, input, g.op("Constant", value_t=torch.LongTensor(shape)) + ) + + # C is always divisible by num_groups + # Due to shape difference. 
we need to apply weight and bias after + # instance norm computation and reshape + weight_ = g.op( + "Constant", + value_t=torch.tensor( + [1.0] * num_groups, + dtype=_type_utils.JitScalarType.from_value(input).dtype(), + ), + ) + bias_ = g.op( + "Constant", + value_t=torch.tensor( + [0.0] * num_groups, + dtype=_type_utils.JitScalarType.from_value(input).dtype(), + ), + ) + + norm_reshaped = g.op( + "InstanceNormalization", input_reshaped, weight_, bias_, epsilon_f=eps + ) + norm = symbolic_helper._reshape_helper(g, norm_reshaped, g.op("Shape", input)) + + if weight is None or weight.node().mustBeNone(): + weight_value = torch.tensor( + [1.0], dtype=_type_utils.JitScalarType.from_value(input).dtype() + ) + weight = g.op("Constant", value_t=weight_value) + if bias is None or bias.node().mustBeNone(): + bias_value = torch.tensor( + [0.0], dtype=_type_utils.JitScalarType.from_value(input).dtype() + ) + bias = g.op("Constant", value_t=bias_value) + + # Norm has shape [N, C, *] so we reshape weight and bias to [C, *] + axes = list(range(1, input_rank - 1)) + return add( + g, + mul(g, norm, symbolic_helper._unsqueeze_helper(g, weight, axes)), + symbolic_helper._unsqueeze_helper(g, bias, axes), + ) + + +@_onnx_symbolic("aten::_weight_norm") +@symbolic_helper.parse_args("v", "v", "i") +def _weight_norm(g: jit_utils.GraphContext, weight_v, weight_g, dim): + rank = symbolic_helper._get_tensor_rank(weight_v) + if rank is not None: + # W = g * ((v) / ||v||) + # Compute norm_except_dim for l2 norm. dim = None means over all dims + # torch's weight_norm module sets dim = -1 if it's None. + # This conflicts the logic for negative axes to access dims backwards + # TODO: Might need a fix in torch group_norm module + axes = list(range(rank)) + if dim is not None: + if dim < -1: + dim += rank + if dim != -1: + axes.remove(dim) + norm_v = norm(g, weight_v, 2, axes, 1) + div = g.op("Div", weight_v, norm_v) + return g.op("Mul", div, weight_g) + raise errors.SymbolicValueError( + "Unsupported: ONNX export of _weight_norm for tensor of unknown rank.", + weight_v, + ) + + +@_onnx_symbolic("aten::dim") +def dim(g: jit_utils.GraphContext, self): + """Implement the dim functionality available for a pytorch tensor in ONNX""" + # ONNX does not support dim directly in this opset so we can use 2 ops to get the info + shape = g.op("Shape", self) + return g.op("Size", shape) + + +@_onnx_symbolic("aten::__contains_") +def __contains_(g: jit_utils.GraphContext, self, element): + unpacked_list = symbolic_helper._unpack_list(self) + if all( + symbolic_helper._is_constant(x) for x in unpacked_list + ) and symbolic_helper._is_constant(element): + return g.op( + "Constant", + value_t=torch.tensor( + symbolic_helper._node_get(element.node(), "value") + in (symbolic_helper._node_get(x.node(), "value") for x in unpacked_list) + ), + ) + + raise errors.SymbolicValueError( + "Unsupported: ONNX export of __contains__ for non-constant list or element.", + self, + ) + + +@_onnx_symbolic("aten::__getitem_") +def __getitem_(g: jit_utils.GraphContext, self, i): + return select(g, self, g.op("Constant", value_t=torch.tensor([0])), i) + + +@_onnx_symbolic("aten::item") +def item(g: jit_utils.GraphContext, self): + return self + + +@_onnx_symbolic("aten::take") +def take(g: jit_utils.GraphContext, self, index): + self_flattened = symbolic_helper._reshape_helper( + g, self, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)) + ) + out = index_select(g, self_flattened, 0, index) + out = reshape_as(g, out, index) + return out + + +def 
_kl_div_log_target_impl(g: jit_utils.GraphContext, input, target): + diff_ = sub(g, target, input) + exp_ = exp(g, target) + output = mul(g, exp_, diff_) + return output + + +def _kl_div_non_log_target_impl(g: jit_utils.GraphContext, input, target): + log_ = log(g, target) + diff_ = sub(g, log_, input) + output_pos = mul(g, target, diff_) + zeros_ = zeros_like(g, output_pos) + mask_ = gt(g, target, g.op("Constant", value_t=torch.tensor(0))) + output = where(g, mask_, output_pos, zeros_) + return output + + +@_onnx_symbolic("aten::kl_div") +@symbolic_helper.parse_args("v", "v", "i", "b") +def kl_div(g: jit_utils.GraphContext, input, target, reduction, log_target): + if log_target: + output = _kl_div_log_target_impl(g, input, target) + else: + output = _kl_div_non_log_target_impl(g, input, target) + + if reduction == 0: + return output + elif reduction == 1: + return g.op("ReduceMean", output, keepdims_i=0) + elif reduction == 2: + return symbolic_helper._reducesum_helper(g, output, keepdims_i=0) + else: + return symbolic_helper._onnx_unsupported( + "kl_div with reduction other than none, mean, or sum.", input + ) + + +@_onnx_symbolic("aten::mse_loss") +@symbolic_helper.parse_args("v", "v", "i") +def mse_loss(g: jit_utils.GraphContext, input, target, reduction): + output = mul(g, sub(g, input, target), sub(g, input, target)) + if reduction == 0: + return output + elif reduction == 1: + return g.op("ReduceMean", output, keepdims_i=0) + elif reduction == 2: + return symbolic_helper._reducesum_helper(g, output, keepdims_i=0) + else: + return symbolic_helper._onnx_unsupported( + "mse_loss with reduction other than none, mean, or sum.", input + ) + + +@_onnx_symbolic("aten::as_strided") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "v", "is", "i") +def as_strided(g: jit_utils.GraphContext, self, sizes, strides, offset=None): + sizes = symbolic_helper._maybe_get_const(sizes, "is") + rank = len(strides) + self_1d = symbolic_helper._reshape_helper( + g, self, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)) + ) + ind: torch.Tensor | None + if not symbolic_helper._is_value(sizes): + ind = torch.tensor([0], dtype=torch.long) + for i, (size, stride) in enumerate(zip(sizes, strides)): + r_size = [1] * rank + r_size[i] = -1 + ind = ind + torch.arange(size).view(r_size) * stride + if offset: + ind = ind + offset + return g.op("Gather", self_1d, g.op("Constant", value_t=ind)) + else: + ind = None + for i, stride in enumerate(strides): + r_size = [1] * rank + r_size[i] = -1 + size = select( + g, + sizes, + g.op("Constant", value_t=torch.tensor([0])), + g.op("Constant", value_t=torch.tensor(i)), + ) + tmp_ind = symbolic_helper._reshape_helper( + g, + arange(g, size, 4, None, None, None), + g.op("Constant", value_t=torch.tensor(r_size)), + ) + tmp_ind = g.op( + "Mul", tmp_ind, g.op("Constant", value_t=torch.tensor([stride])) + ) + if ind is None: + ind = tmp_ind + else: + ind = g.op("Add", ind, tmp_ind) + if offset: + ind = g.op("Add", ind, g.op("Constant", torch.tensor([offset]))) + return g.op("Gather", self_1d, ind) + + +@_onnx_symbolic("aten::__derive_index") +def __derive_index(g: jit_utils.GraphContext, index, start, step): + return g.op("Add", start, g.op("Mul", index, step)) + + +@_onnx_symbolic("aten::__range_length") +# Source code for aten op can be found here: pytorch/torch/csrc/jit/runtime/register_prim_ops.cpp +# if (step > 0 && lo < hi) { +# push(stack, 1 + (hi - 1 - lo) / step); +# } else if (step < 0 && lo > hi) { +# push(stack, 1 + (lo - 1 - hi) / (0 
- step)); +# } else { +# push(stack, 0); +# } +def __range_length(g: jit_utils.GraphContext, lo, hi, step): + sub = g.op("Sub", hi, lo) + div = g.op("Ceil", true_divide(g, sub, step)) + return g.op("Cast", div, to_i=_C_onnx.TensorProtoDataType.INT64) + + +@_onnx_symbolic("aten::linear") +def linear(g: jit_utils.GraphContext, input, weight, bias): + rank = symbolic_helper._get_tensor_rank(input) + weight = t(g, weight) + if rank == 2 and not bias.node().mustBeNone(): + alpha = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)) + beta = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)) + output = addmm(g, bias, input, weight, alpha, beta) + else: + output = matmul(g, input, weight) + if not bias.node().mustBeNone(): + output = add(g, bias, output) + + return output + + +@_onnx_symbolic("aten::hann_window") +@symbolic_helper.parse_args("v", "b", "i", "v", "v", "v", "v") +def hann_window( + g: jit_utils.GraphContext, + window_length, + periodic=True, + dtype: int | None = None, + layout=None, + device=None, + pin_memory=None, + requires_grad=False, +): + if dtype is None: + dtype_ = torch.get_default_dtype() + if not dtype_ or not dtype_.is_floating_point: + dtype_ = torch.float + scalar_type = _type_utils.JitScalarType.from_dtype(dtype_) + else: + scalar_type = _type_utils.JitScalarType(dtype) + + n_array = arange(g, window_length, 4, None, None, None) + output = g.op("Cast", n_array, to_i=_C_onnx.TensorProtoDataType.FLOAT) + output = mul( + g, g.op("Constant", value_t=torch.tensor(math.pi, dtype=torch.float)), output + ) + + if periodic is False: + window_length = sub( + g, window_length, g.op("Constant", value_t=torch.tensor(1, dtype=torch.int)) + ) + output = div(g, output, window_length) + output = g.op( + "Cast", + square(g, sin(g, output)), + to_i=scalar_type.onnx_type(), + ) + + return output + + +@_onnx_symbolic("aten::mv") +def mv(g: jit_utils.GraphContext, self, vec): + return matmul(g, self, vec) + + +@_onnx_symbolic("aten::dot") +def dot(g: jit_utils.GraphContext, self, other): + return matmul(g, self, other) + + +@_onnx_symbolic("aten::movedim") +@symbolic_helper.parse_args("v", "t", "t") +def movedim(g: jit_utils.GraphContext, self, source, destination): + # This is a pythonic implementation mostly taken from aten/src/ATen/native/TensorShape.cpp::movedim + source = source.view(-1) + destination = destination.view(-1) + + assert source.size() == destination.size() + + if (source == destination).all(): + return self + + self_rank = symbolic_helper._get_tensor_rank(self) + assert self_rank is not None + + perm = list(range(self_rank)) + + src_dims = perm.copy() + dst_dims = perm.copy() + + for src, dst in zip(source.tolist(), destination.tolist()): + perm[dst] = src + src_dims[src] = -1 + dst_dims[dst] = -1 + + src_dims = [dim for dim in src_dims if dim != -1] + dst_dims = [dim for dim in dst_dims if dim != -1] + + for src, dst in zip(src_dims, dst_dims): + perm[dst] = src + + return g.op("Transpose", self, perm_i=perm) + + +@_onnx_symbolic("aten::fill") +@symbolic_helper.parse_args("v", "v") +def fill(g: jit_utils.GraphContext, self, value): + scalar_type = _type_utils.JitScalarType.from_value( + self, _type_utils.JitScalarType.FLOAT + ) + return full_like(g, self, value, scalar_type) + + +@_onnx_symbolic("aten::index_add") +def index_add(g: jit_utils.GraphContext, self, dim, index, other, alpha=None): + warnings.warn( + "Warning: ONNX export does not support duplicated values in 'index' field, " + + "this will cause the ONNX model to be incorrect." 
+ ) + + # ONNX does not support "alpha" argument, unlike aten index_add + # See: https://github.com/pytorch/pytorch/pull/65993#issuecomment-953151102 for more context + if alpha and symbolic_helper._scalar(symbolic_helper._maybe_get_scalar(alpha)) != 1: + return symbolic_helper._unimplemented("index_add", "alpha != 1", self) + + dim = symbolic_helper._maybe_get_const(dim, "i") + if dim is None: + raise errors.SymbolicValueError( + "ONNX export does NOT support exporting 'index_add_()' function with " + "unknown 'dim' value.", + self, + ) + + self_dim_rank = symbolic_helper._get_tensor_rank(self) + other_dim_rank = symbolic_helper._get_tensor_rank(other) + + if self_dim_rank is None or other_dim_rank is None: + raise errors.SymbolicValueError( + "ONNX export does NOT support exporting 'index_add_()' function while " + "the rank of self tensor or tensor to be added is unknown.", + self, + ) + + if other_dim_rank != self_dim_rank: + delta = self_dim_rank - other_dim_rank + for i in range(delta): + other = symbolic_helper._unsqueeze_helper( + g, other, [symbolic_helper._get_tensor_rank(other)] + ) + + other_dim_size = symbolic_helper._get_tensor_dim_size(other, dim) + self_dim_size = symbolic_helper._get_tensor_dim_size(self, dim) + + if (other_dim_size is not None) and (self_dim_size is not None): + if other_dim_size > self_dim_size: + raise errors.SymbolicValueError( + "ONNX export does not support exporting 'index_add_()' function with " + "duplicated values in 'index' parameter yet.", + self, + ) + + # Construct a new shape. It's almost as same as self except the size of the 'dim' + # dimension is 1, so that we can expand other dimensions as expected. + new_shape_axes = list(range(self_dim_rank)) + new_shape_starts = [0 for i in range(self_dim_rank)] + new_shape_ends = [sys.maxsize if (i != dim) else 1 for i in range(self_dim_rank)] + + new_shape = symbolic_helper._slice_helper( + g, self, axes=new_shape_axes, starts=new_shape_starts, ends=new_shape_ends + ) + other = expand_as(g, other, new_shape) + + for i in range(dim): + index = symbolic_helper._unsqueeze_helper(g, index, [0]) + + for i in range(self_dim_rank - dim - 1): + index = symbolic_helper._unsqueeze_helper( + g, index, [symbolic_helper._get_tensor_rank(index)] + ) + + return scatter_add(g, self, dim, expand_as(g, index, other), other) + + +@_onnx_symbolic("aten::roll") +@symbolic_helper.parse_args("v", "is", "is") +def roll(g: jit_utils.GraphContext, self, shifts, dims): + assert len(shifts) == len(dims) + + result = self + for i in range(len(shifts)): + shapes = [] + shape = symbolic_helper._slice_helper( + g, result, axes=[dims[i]], starts=[-shifts[i]], ends=[sys.maxsize] + ) + shapes.append(shape) + shape = symbolic_helper._slice_helper( + g, result, axes=[dims[i]], starts=[0], ends=[-shifts[i]] + ) + shapes.append(shape) + result = g.op("Concat", *shapes, axis_i=dims[i]) + + return result + + +@_onnx_symbolic("aten::cross") +@symbolic_helper.parse_args("v", "v", "i") +def cross(g: jit_utils.GraphContext, input, other, dim=None): + dim = symbolic_helper._get_dim_for_cross(input, dim) + # If we have two tensors such that + # A = [a, b, c], B = [d, e, f], we permute the tensor such that we have + # After first roll, + # A' = [b, c, a], B' = [f, d, e], so that we calculate (b*f, c*d, a*e) + roll_x_1 = roll(g, input, [2], [dim]) + roll_y_1 = roll(g, other, [1], [dim]) + # After second roll, + # A' = [c, a, b], B' = [e, f, d], so that we calculate (c*e, a*f, b*d) + roll_x_2 = roll(g, input, [1], [dim]) + roll_y_2 = roll(g, other, 
[2], [dim]) + # cross product is calculated as + # result = [(b*f - c*e), (c*d - a*f), (a*e - b*d)] + return sub(g, mul(g, roll_x_1, roll_y_1), mul(g, roll_x_2, roll_y_2)) + + +@_onnx_symbolic("aten::cdist") +def cdist( + g: jit_utils.GraphContext, + x1, + x2, + p=2.0, + compute_mode="use_mm_for_euclid_dist_if_necessary", +): + # X1.shape = (B * P * D), X2.shape = (B * R * D) + # In order to respect numpy style broadcasting as demonstrated in + # https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md + # we unsqueeze both input tensors + row_size_x1 = symbolic_helper._get_tensor_dim_size(x1, -2) + row_size_x2 = symbolic_helper._get_tensor_dim_size(x2, -2) + assert row_size_x1 is not None + assert row_size_x2 is not None + p_float = symbolic_helper._parse_arg(p, "f") + compute_mode = symbolic_helper._parse_arg(compute_mode, "i") + if p_float == 2.0 and ( + compute_mode == 1 + or (compute_mode is None and row_size_x1 >= 25 and row_size_x2 >= 25) + ): + return _euclidean_dist(g, x1, x2) + rank = symbolic_helper._get_tensor_rank(x1) + assert rank is not None + broadcasted_x1 = symbolic_helper._unsqueeze_helper(g, x1, [rank - 1]) + broadcasted_x2 = symbolic_helper._unsqueeze_helper(g, x2, [rank - 2]) + return pairwise_distance( + g, broadcasted_x1, broadcasted_x2, p, eps=1e-06, keepdim=False + ) + + +def _euclidean_dist(g: jit_utils.GraphContext, x1, x2): + # X1.shape = (B * P * D), X2.shape = (B * R * D) + # using matrix multiplication to accelerate the calculation of + # the euclidean distance + rank = symbolic_helper._get_tensor_rank(x1) + assert rank is not None + x1_norm = symbolic_helper._reducesum_helper( + g, + pow(g, x1, symbolic_helper._generate_wrapped_number(g, 2.0)), + axes_i=[-1], + keepdims_i=True, + ) + x1_pad = ones_like(g, x1_norm) + x2_norm = symbolic_helper._reducesum_helper( + g, + pow(g, x2, symbolic_helper._generate_wrapped_number(g, 2.0)), + axes_i=[-1], + keepdims_i=True, + ) + x2_pad = ones_like(g, x2_norm) + x1_ = g.op( + "Concat", + *[ + mul(g, symbolic_helper._generate_wrapped_number(g, -2.0), x1), + x1_norm, + x1_pad, + ], + axis_i=-1, + ) + x2_ = g.op("Concat", *[x2, x2_pad, x2_norm], axis_i=-1) + result = matmul(g, x1_, transpose(g, x2_, -2, -1)) + dtype = _type_utils.JitScalarType.from_value(result) + min = g.op( + "Cast", symbolic_helper._generate_wrapped_number(g, 0.0), to_i=dtype.onnx_type() + ) + result = symbolic_helper._op_with_optional_float_cast( + g, "Max", result, min, opset_before=12 + ) + result = sqrt(g, result) + return result + + +@_onnx_symbolic("aten::lerp") +def lerp(g: jit_utils.GraphContext, self, end, weight): + # Conditional for better numeric. This has been discussed in + # https://github.com/pytorch/pytorch/pull/18871 + diff = g.op("Sub", end, self) + return where( + g, + g.op("Less", weight, g.op("Constant", value_t=torch.tensor(0.5))), + g.op("Add", self, g.op("Mul", weight, diff)), + g.op( + "Sub", + end, + g.op( + "Mul", + diff, + g.op("Sub", g.op("Constant", value_t=torch.tensor(1.0)), weight), + ), + ), + ) + + +@_onnx_symbolic("aten::broadcast_tensors") +def broadcast_tensors(g: jit_utils.GraphContext, self): + all_tensors = symbolic_helper._unpack_list(self) + t_with_final_shape = zeros_like(g, all_tensors[0]) + + # Add operator supports multidirectional broadcasting. So we leverage this function + # to infer the final shape generated by the broadcast. 
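+ # The zeros tensor acts purely as a shape accumulator: adding every input to it
+ # produces a tensor whose shape is the common broadcast shape, and only that
+ # shape is consumed by the expand_as calls below.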
+ for t in all_tensors: + t_with_final_shape = add(g, t_with_final_shape, t) + + t_list = [expand_as(g, t, t_with_final_shape) for t in all_tensors] + return g.op("prim::ListConstruct", *t_list) + + +@_onnx_symbolic("aten::is_pinned") +def is_pinned(g: jit_utils.GraphContext, self, device=None): + # Unused by ONNX. + return None + + +@_onnx_symbolic("prim::ConstantSplit") +def prim_constant_split(g: jit_utils.GraphContext, self, split_size, dim): + size = symbolic_helper._get_tensor_dim_size(self, dim) + if size is None: + return symbolic_helper._unimplemented( + "prim::ConstantSplit", "unknown dimension size", self + ) + splits = [split_size] * (size // split_size) + leftover = size % split_size + if leftover: + splits.append(leftover) + return g.op("Split", self, split_i=splits, axis_i=dim, outputs=len(splits)) + + +# TODO: It would be better to export this as a chunk directly, as this is +# less sensitive to changes in input size. +# TODO: Once we have proper scoping, stop reimplementing chunk, delete this +# method, and use the desugared version +@_onnx_symbolic("prim::ConstantChunk") +def prim_constant_chunk(g: jit_utils.GraphContext, self, chunks, dim): + dim_size = symbolic_helper._get_tensor_dim_size(self, dim) + if dim_size is None: + return symbolic_helper._unimplemented( + "prim::ConstantChunk", "unknown dimension size", self + ) + split_size = (dim_size + chunks - 1) // chunks + return prim_constant_split(g, self, split_size, dim) + + +@_onnx_symbolic("prim::shape") +def prim_shape(g: jit_utils.GraphContext, self): + return g.op("Shape", self) + + +@_onnx_symbolic("prim::max") +def prim_max(g: jit_utils.GraphContext, self, other): + return symbolic_helper._op_with_optional_float_cast( + g, "Max", self, other, opset_before=12 + ) + + +@_onnx_symbolic("prim::min") +def prim_min(g: jit_utils.GraphContext, self, other=None): + if not other: + if symbolic_helper._is_packed_list(self): + self = stack(g, self, g.op("Constant", value_t=torch.tensor([0]))) + return min(g, self) + return min(g, self, other) + + +@_onnx_symbolic("prim::data") +def prim_data(g: jit_utils.GraphContext, self): + return self + + +@_onnx_symbolic("prim::layout") +def prim_layout(g: jit_utils.GraphContext, self): + # Always return 'torch.strided'. Other layout types are not supported by JIT 'TensorType'. + # Layout class defined in 'c10/core/Layout.h'. 
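+ # 0 is the integer value of torch.strided (Layout::Strided) in that enum.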
+ return g.op("Constant", value_t=torch.tensor(0)) + + +@_onnx_symbolic("prim::ListConstruct") +def prim_list_construct(g: jit_utils.GraphContext, *inputs, **kwargs): + return None + + +@_onnx_symbolic("prim::ListUnpack") +def prim_list_unpack( + g: jit_utils.GraphContext, *inputs, **kwargs +) -> list[_C.Value] | None: + if len(inputs) == 1 and inputs[0].node().kind() == "prim::ListConstruct": + # Cancel the previous node if it is ListConstruct by returning its inputs + # TODO(justinchuby): Use a public method in the helper module + return symbolic_helper._unpack_list(inputs[0]) + + return None + + +@_onnx_symbolic("prim::TupleConstruct") +def prim_tuple_construct(g: jit_utils.GraphContext, *inputs, **kwargs): + return None + + +@_onnx_symbolic("prim::Uninitialized") +def prim_uninitialized(g: jit_utils.GraphContext, *inputs, **kwargs): + return None + + +# exists to refine the type of the Value +# if x is an optional Tensor, unchecked_cast will cast +# x to Tensor, so the rest of the graph knows that x is a Tensor +# this doesn't do anything in runtime and is a noop in ONNX +@_onnx_symbolic("prim::unchecked_cast") +def prim_unchecked_cast(g: jit_utils.GraphContext, self): + return self + + +@_onnx_symbolic("prim::dtype") +def prim_dtype(g: jit_utils.GraphContext, self): + scalar_type = symbolic_helper._try_get_scalar_type(self) + if scalar_type is None: + scalar_type = _type_utils.JitScalarType.FLOAT + # This node records a torch dtype as int + return g.op("Constant", value_t=torch.tensor(scalar_type)) + + +@_onnx_symbolic("prim::tolist") +def prim_tolist(g: jit_utils.GraphContext, input, dim_val, elem_ty_val): + """tolist is currently supported only for 1D input tensors. + + dim_val and elem_ty_val represent dimension and type annotations + that need to match dimension and type of the input tensor. + """ + dim = symbolic_helper._maybe_get_const(dim_val, "i") + if dim > 1: + return symbolic_helper._unimplemented("prim::tolist", "dim_val > 1", input) + return input + + +# ----------------------------------------------------------------------------- +# Symbolic functions that need extra context +# ----------------------------------------------------------------------------- +@_onnx_symbolic("prim::device") +def prim_device(g: jit_utils.GraphContext, *inputs, **kwargs) -> None: + output_type = g.original_node.output().type() + if isinstance(output_type, _C.DeviceObjType): + return None + + return symbolic_helper._unimplemented( + "prim::device", + f"output type should be 'DeviceObjType', not '{output_type.kind()}'", + g.original_node.output(), + ) + + +@_onnx_symbolic("prim::Loop") +def prim_loop(g: jit_utils.GraphContext, *inputs, **attrs) -> list[_C.Value]: + node = g.original_node + env = g.env + values_in_env = g.values_in_env + params_dict = g.params_dict + + operator_export_type = GLOBALS.operator_export_type + opset_version = GLOBALS.export_onnx_opset_version + + old_blocks = tuple(node.blocks()) + _new_op_outputs, new_block_contexts, new_node = jit_utils.add_op_with_blocks( + g, "Loop", *inputs, outputs=node.outputsSize(), n_blocks=len(old_blocks) + ) + + for old_block, new_block_context in zip(old_blocks, new_block_contexts): + # Copy input metadata to subblock + # + # prim::Loop(iter, cond, input_1, ..., input_n) + # block0(iter, input_1, ..., input_n) + # + # For `Loop` node, copy metadata for `iter`, `input_1`, ..., `input_n`. 
+ for i, b_in in enumerate(old_block.inputs()): + if i == 0 and i < len(inputs): + b_in.setType(inputs[i].type()) + # For optional block inputs, they may switch between None not-None inside + # the loop body, so if the loop input is not optional, the block input may + # still need to be optional. + if ( + i > 0 + and (i + 1) < len(inputs) + and not isinstance(b_in.type(), _C.OptionalType) + ): + b_in.setType(inputs[i + 1].type()) + torch._C._jit_pass_onnx_block( + old_block, + new_block_context.block, + operator_export_type, + env, + values_in_env, + False, + ) + fixed_outputs = torch._C._jit_pass_fixup_onnx_controlflow_node( + new_node, opset_version + ) + # Run shape type inference for Loop after subblock is converted. + if GLOBALS.onnx_shape_inference: + torch._C._jit_pass_onnx_node_shape_type_inference( + new_node, params_dict, opset_version + ) + return fixed_outputs + + +@_onnx_symbolic("prim::If") +def prim_if(g: jit_utils.GraphContext, *inputs, **attrs) -> list[_C.Value]: + n = g.original_node + block = g.block + env = g.env + values_in_env = g.values_in_env + params_dict = g.params_dict + + operator_export_type = GLOBALS.operator_export_type + opset_version = GLOBALS.export_onnx_opset_version + + static_if = inputs[0].node().kind() == "onnx::Constant" + if static_if: + # Fold static if + # + # The torch IR + # graph(%embedding_matrix.1 : Float(10, 15, strides=[15, 1], requires_grad=0, device=cpu), + # %input.1 : Long(6, strides=[1], requires_grad=0, device=cpu), ... + # %65 : Bool(requires_grad=0, device=cpu) = prim::Constant[value={0}]() + # %21 : Long(device=cpu) = aten::eq(%20, %64) + # %22 : Long(device=cpu) = prim::If(%21) + # block0(): + # %23 : Long(device=cpu) = aten::is_floating_point(%input.1) + # -> (%23) + # block1(): + # -> (%65) + # %input.53 : Tensor, %weight : Tensor = prim::If(%22) + # block0(): + # -> (%embedding_matrix.1, %input.1) + # block1(): + # -> (%input.1, %embedding_matrix.1) + # %26 : int[] = aten::size(%input.53) + # + # The converted ONNX graph + # %10 : Bool(device=cpu) = onnx::Constant[value={0}]() + # %14 : Bool(device=cpu) = onnx::Equal(%13, %8) + # %15 : Bool(requires_grad=0, device=cpu) = onnx::Constant[value={0}]() + # %16 : Long(1, strides=[1], device=cpu) = onnx::Shape(%input.1) + input_flag = symbolic_helper._node_get(inputs[0].node(), "value").tolist() + const_value = ( + all(input_flag) if isinstance(input_flag, list) else bool(input_flag) + ) + block_idx = 0 if const_value else 1 + current_b = list(n.blocks())[block_idx] + env = torch._C._jit_pass_onnx_block( + current_b, + block, + operator_export_type, + env, + values_in_env, + True, + ) + if_output_list = list(n.outputs()) + current_b_list = list(current_b.outputs()) + + final_b_list = [] + for idx in range(len(if_output_list)): + if current_b_list[idx] not in env: + raise errors.SymbolicValueError( + f"The sub block ATen output {current_b_list[idx]} is not in env.", + current_b_list[idx], + ) # type:ignore[operator] + onnx_b = env[current_b_list[idx]] + final_b_list.append(onnx_b) + return final_b_list + else: + old_blocks = tuple(n.blocks()) + _new_op_outputs, new_block_contexts, new_node = jit_utils.add_op_with_blocks( + g, "If", *inputs, outputs=n.outputsSize(), n_blocks=len(old_blocks) + ) + + for old_block, new_block_context in zip(old_blocks, new_block_contexts): + torch._C._jit_pass_onnx_block( + old_block, + new_block_context.block, + operator_export_type, + env, + values_in_env, + False, + ) + fixed_outputs = torch._C._jit_pass_fixup_onnx_controlflow_node( + new_node, 
opset_version + ) + # Run shape type inference for If after subblock is converted. + if GLOBALS.onnx_shape_inference: + torch._C._jit_pass_onnx_node_shape_type_inference( + new_node, params_dict, opset_version + ) + return fixed_outputs + + +@_onnx_symbolic("prim::Constant") +def prim_constant(g: jit_utils.GraphContext, *inputs, **attrs): + node = g.original_node + + if node.mustBeNone(): + return None + # This must go before checking for string values, because some device constants + # have string values, but we want to keep them as unconverted Device types so + # that eq() can work on them. + if isinstance(node.output().type(), _C.DeviceObjType): + return None + if node.kindOf("value") == "t": + return g.op("Constant", value_t=symbolic_helper._node_get(node, "value")) + if node.kindOf("value") == "s": + return g.op("Constant", value_s=symbolic_helper._node_get(node, "value")) + if node.output().type().isSubtypeOf( + _C.ListType.ofInts() + ) or node.output().type().isSubtypeOf(_C.ListType.ofFloats()): + return g.op( + "Constant", value_t=torch.tensor(symbolic_helper._node_get(node, "value")) + ) + if node.output().type().isSubtypeOf(_C.ListType.ofStrings()): + str_constants = [ + g.op("Constant", value_s=s) + for s in symbolic_helper._node_get(node, "value") + ] + return g.op("prim::ListConstruct", *str_constants) + + raise errors.SymbolicValueError( + f"Unsupported prim::Constant kind: '{node.kindOf('value')}'. " + f"Please send a bug report at {_constants.PYTORCH_GITHUB_ISSUES_URL}.", + node.output(), + ) + + +@_onnx_symbolic("prim::type") +def prim_type(g: jit_utils.GraphContext, device_value: _C.Value, *args, **kwargs): + if device_value.node().kind() == "prim::device": + device = jit_utils.get_device_from_value(device_value.node().input()) + if device is not None: + return g.op("Constant", value_s=str(device)) + + return symbolic_helper._unimplemented( + "prim::type", + "Device type cannot be statically determined.", + device_value, + ) + + +@_onnx_symbolic("onnx::Placeholder") +def onnx_placeholder(g: jit_utils.GraphContext, *inputs, **attrs): + node = g.original_node + block = g.block + env = g.env + values_in_env = g.values_in_env + + return torch._C._jit_onnx_convert_pattern_from_subblock( + block, node, env, values_in_env + ) + + +@_onnx_symbolic("aten::resolve_conj") +@_onnx_symbolic("aten::resolve_neg") +def noop_complex_operators(g: jit_utils.GraphContext, input: _C.Value): + # ONNX does not have operators to *directly* manipulate real/imaginary components + # However, a few torch APIs (e.g. .tolist()) use complex operations when input is real, + # which results in failures due to missing operators for complex numbers + + # `aten::resolve_conj` and `aten::resolve_neg` can safely be implemented as no-op + return input + + +@_onnx_symbolic("aten::_conj") +@_onnx_symbolic("aten::conj_physical") +def unsupported_complex_operators(g: jit_utils.GraphContext, input: _C.Value): + # ONNX does not have operators to *directly* manipulate real/imaginary components + # However, a few torch APIs (e.g. 
.tolist()) use complex operations when input is real, + # which results in failures due to missing operators for complex numbers + + # While `aten::_conj` and `aten::conj_physical` raise exception when input is complex + if symbolic_helper.is_complex_value(input): + # FIXME(justinchuby): report correct name for symbolic being executed + return symbolic_helper._onnx_unsupported( + "aten::_conj, aten::conj_physical", + input, + ) + + # they can safely be implemented as no-op for real numbers only + return noop_complex_operators(g, input) + + +@_onnx_symbolic("aten::logit") +def logit(g: jit_utils.GraphContext, self: torch._C.Value, eps: torch._C.Value): + one = g.op("Constant", value_t=torch.tensor(1.0)) + + if not symbolic_helper._is_none(eps): + eps = g.op( + "Cast", eps, to_i=_type_utils.JitScalarType.from_value(self).onnx_type() + ) + one_sub_eps = g.op("Sub", one, eps) + self_less_equal_one_sub_eps = g.op("Greater", one_sub_eps, self) + temporary_self = g.op("Where", self_less_equal_one_sub_eps, self, one_sub_eps) + + temporary_self_less_eps = g.op("Less", temporary_self, eps) + z = g.op("Where", temporary_self_less_eps, eps, temporary_self) + else: + z = self + + sub = g.op("Sub", one, z) + div = g.op("Div", z, sub) + return g.op("Log", div) diff --git a/torch/onnx/_internal/torchscript_exporter/utils.py b/torch/onnx/_internal/torchscript_exporter/utils.py new file mode 100644 index 000000000000..2a7339c27e08 --- /dev/null +++ b/torch/onnx/_internal/torchscript_exporter/utils.py @@ -0,0 +1,1930 @@ +# mypy: allow-untyped-defs +"""Functions to export models into the ONNX IR format. + +These models can be loaded with the ONNX library and then +converted to models which run on other deep learning frameworks. +""" + +from __future__ import annotations + + +__all__ = [ + "select_model_mode_for_export", + "disable_apex_o2_state_dict_hook", + "setup_onnx_logging", + "exporter_context", + "export", + "model_signature", + "warn_on_static_input_change", + "unpack_quantized_tensor", + "unconvertible_ops", + "register_custom_op_symbolic", + "unregister_custom_op_symbolic", + "_add_block", + "_add_input_to_block", + "_add_output_to_block", + "_apply_friendly_debug_names", + "_check_flatten_did_not_remove", + "_create_jit_graph", + "_decide_add_node_names", + "_decide_constant_folding", + "_decide_input_format", + "_decide_keep_init_as_input", + "_export", + "_get_aten_op_overload_name", + "_get_example_outputs", + "_get_module_attributes", + "_get_named_param_dict", + "_get_param_count_list", + "_is_constant_tensor_list", + "_model_to_graph", + "_optimize_graph", + "_pre_trace_quant_model", + "_reset_trace_module_map", + "_resolve_args_by_export_type", + "_run_symbolic_function", + "_run_symbolic_method", + "_set_input_and_output_names", + "_setup_trace_module_map", + "_should_aten_fallback", + "_signature", + "_split_tensor_list_constants", + "_trace_and_get_graph_from_model", + "_trace", + "_trigger_symbolic_function_registration", + "_validate_dynamic_axes", + "_verify_custom_op_name", +] + +import contextlib +import copy +import inspect +import re +import typing +import warnings +from typing import Any, Callable, cast +from typing_extensions import deprecated + +import torch +import torch._C._onnx as _C_onnx +import torch.jit._trace +from torch import _C +from torch.onnx import _constants, errors +from torch.onnx._internal.torchscript_exporter import ( + jit_utils, + onnx_proto_utils, + registration, + symbolic_helper, +) +from torch.onnx._internal.torchscript_exporter._globals import GLOBALS + 
+ +if typing.TYPE_CHECKING: + from collections.abc import Collection, Mapping, Sequence + + +# TODO(justinchuby): Remove dependency to this global variable from constant_fold.cpp +# Skip check due to cannot import IValue from torch._C +_params_dict = {} # type: ignore[var-annotated] + + +@deprecated("Please set training mode before exporting the model", category=None) +@contextlib.contextmanager +def select_model_mode_for_export(model, mode: _C_onnx.TrainingMode): + """A context manager to temporarily set the training mode of ``model`` + to ``mode``, resetting it when we exit the with-block. + + .. deprecated:: 2.7 + Please set training mode before exporting the model. + + Args: + model: Same type and meaning as ``model`` arg to :func:`export`. + mode: Same type and meaning as ``training`` arg to :func:`export`. + """ + if not isinstance(mode, _C_onnx.TrainingMode): + raise TypeError( + f"'mode' should be a torch.onnx.TrainingMode enum, but got '{type(mode)}'." + ) + originally_training: bool = False + + if hasattr(model, "training"): + originally_training = model.training + + # ONNX opset 12 has better support for training amenable models, with updated + # versions of the dropout and batch_norm operators + if mode == _C_onnx.TrainingMode.TRAINING or ( + mode == _C_onnx.TrainingMode.PRESERVE and originally_training + ): + GLOBALS.export_training = True + if GLOBALS.export_onnx_opset_version < 12: + warnings.warn( + "You are exporting the model in training mode with onnx opset " + f"version {GLOBALS.export_onnx_opset_version}. " + "Opset versions lower than opset 12 will not be able to export " + "nodes such as Dropout and BatchNorm correctly." + ) + else: + GLOBALS.export_training = False + + GLOBALS.training_mode = mode + if mode == _C_onnx.TrainingMode.TRAINING: + model.train(True) + elif mode == _C_onnx.TrainingMode.EVAL: + model.train(False) + # else mode == _C_onnx.TrainingMode.PRESERVE, do nothing + + try: + yield + finally: + if hasattr(model, "training") and not mode == _C_onnx.TrainingMode.PRESERVE: + model.train(originally_training) + + +@deprecated( + "Please remove usage of this function. Copy its logic if it is required in user code", + category=None, +) +@contextlib.contextmanager +def disable_apex_o2_state_dict_hook(model: torch.nn.Module | torch.jit.ScriptFunction): + """A context manager to temporarily disable the Apex O2 hook that returns. + + .. deprecated:: 2.7 + Please remove usage of this function. + """ + # Apex O2 hook state_dict to return fp16 weights as fp32. + # Exporter cannot identify them as same tensors. + # Since this hook is only used by optimizer, it is safe to + # remove this hook while exporting. + if not isinstance(model, torch.jit.ScriptFunction): + model_hooks = {} # type: ignore[var-annotated] + for module in model.modules(): + for key, hook in module._state_dict_hooks.items(): + if type(hook).__name__ == "O2StateDictHook": + if module not in model_hooks: + model_hooks[module] = {} + model_hooks[module][key] = hook + if module in model_hooks: + for key in model_hooks[module]: + module._state_dict_hooks.pop(key) + try: + yield + finally: + # Add the hooks back + for module, m_map in model_hooks.items(): + for key, hook in m_map.items(): + module._state_dict_hooks[key] = hook + else: + try: + yield + finally: + pass + + +@deprecated("The feature will be removed. Please remove usage of this function") +@contextlib.contextmanager +def setup_onnx_logging(verbose: bool): + """A context manager to temporarily set the ONNX logging verbosity. + + .. 
deprecated:: 2.7 + Please remove usage of this function. + """ + is_originally_enabled = _C._jit_is_onnx_log_enabled + if is_originally_enabled or verbose: # type: ignore[truthy-function] + _C._jit_set_onnx_log_enabled(True) + try: + yield + finally: + if not is_originally_enabled: # type: ignore[truthy-function] + _C._jit_set_onnx_log_enabled(False) + + +@deprecated( + "The feature will be removed. Please remove usage of this function " + "and implement equivalent logic if needed", + category=None, +) +@contextlib.contextmanager +def exporter_context(model, mode: _C_onnx.TrainingMode, verbose: bool): + """A context manager to temporarily set the training mode of ``model`` + to ``mode``, disable the Apex O2 hook, and set the ONNX logging verbosity. + + .. deprecated:: 2.7 + Please set training mode before exporting the model. + """ + with ( + select_model_mode_for_export(model, mode) as mode_ctx, + disable_apex_o2_state_dict_hook(model) as apex_ctx, + setup_onnx_logging(verbose) as log_ctx, + ): + yield (mode_ctx, apex_ctx, log_ctx) + + +def export( + model: torch.nn.Module | torch.jit.ScriptModule | torch.jit.ScriptFunction, + args: tuple[Any, ...] | torch.Tensor, + f: str, + *, + kwargs: dict[str, Any] | None = None, + export_params: bool = True, + verbose: bool = False, + training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL, + input_names: Sequence[str] | None = None, + output_names: Sequence[str] | None = None, + operator_export_type: _C_onnx.OperatorExportTypes = _C_onnx.OperatorExportTypes.ONNX, + opset_version: int | None = None, + do_constant_folding: bool = True, + dynamic_axes: Mapping[str, Mapping[int, str]] + | Mapping[str, Sequence[int]] + | None = None, + keep_initializers_as_inputs: bool | None = None, + custom_opsets: Mapping[str, int] | None = None, + export_modules_as_functions: bool | Collection[type[torch.nn.Module]] = False, + autograd_inlining: bool = True, +) -> None: + r"""Exports a model into ONNX format. + + If ``model`` is not a :class:`torch.jit.ScriptModule` nor a + :class:`torch.jit.ScriptFunction`, this runs + ``model`` once in order to convert it to a TorchScript graph to be exported + (the equivalent of :func:`torch.jit.trace`). Thus this has the same limited support + for dynamic control flow as :func:`torch.jit.trace`. + + Args: + model: The model to be exported. + args: + + args can be structured either as: + + 1. ONLY A TUPLE OF ARGUMENTS:: + + args = (x, y, z) + + The tuple should contain model inputs such that ``model(*args)`` is a valid + invocation of the model. Any non-Tensor arguments will be hard-coded into the + exported model; any Tensor arguments will become inputs of the exported model, + in the order they occur in the tuple. + + 2. A TENSOR:: + + args = torch.Tensor([1]) + + This is equivalent to a 1-ary tuple of that Tensor. + + 3. A TUPLE OF ARGUMENTS ENDING WITH A DICTIONARY OF NAMED ARGUMENTS:: + + args = (x, {"y": input_y, "z": input_z}) + + All but the last element of the tuple will be passed as non-keyword arguments, + and named arguments will be set from the last element. If a named argument is + not present in the dictionary, it is assigned the default value, or None if a + default value is not provided. + + .. warning:: + This behavior will be deprecated in a future release. Please use the + kwargs argument instead. + + .. note:: + If a dictionary is the last element of the args tuple, it will be + interpreted as containing named arguments. 
In order to pass a dict as the + last non-keyword arg, provide an empty dict as the last element of the args + tuple. For example, instead of:: + + torch.onnx.export( + model, + ( + x, + # WRONG: will be interpreted as named arguments + {y: z}, + ), + "test.onnx.pb", + ) + + Write:: + + torch.onnx.export(model, (x, {y: z}, {}), "test.onnx.pb") + + f: Path to the output ONNX model file. E.g. "model.onnx". + kwargs: Named arguments to the model. + export_params: If True, all parameters will + be exported. Set this to False if you want to export an untrained model. + In this case, the exported model will first take all of its parameters + as arguments, with the ordering as specified by ``model.state_dict().values()`` + verbose: if True, prints a description of the + model being exported to stdout. In addition, the final ONNX graph will include the + field ``doc_string``` from the exported model which mentions the source code locations + for ``model``. If True, ONNX exporter logging will be turned on. + training: + * ``TrainingMode.EVAL``: export the model in inference mode. + * ``TrainingMode.PRESERVE``: export the model in inference mode if model.training is + False and in training mode if model.training is True. + * ``TrainingMode.TRAINING``: export the model in training mode. Disables optimizations + which might interfere with training. + input_names (list of str, default empty list): names to assign to the + input nodes of the graph, in order. + output_names (list of str, default empty list): names to assign to the + output nodes of the graph, in order. + operator_export_type (enum, default OperatorExportTypes.ONNX): + + .. warning:: + This option will be deprecated in a future release. Future exported + graphs will always use the default opset domain. + + * ``OperatorExportTypes.ONNX``: Export all ops as regular ONNX ops + (in the default opset domain). + * ``OperatorExportTypes.ONNX_FALLTHROUGH``: Try to convert all ops + to standard ONNX ops in the default opset domain. If unable to do so + (e.g. because support has not been added to convert a particular torch op to ONNX), + fall back to exporting the op into a custom opset domain without conversion. Applies + to `custom ops `_ + as well as ATen ops. For the exported model to be usable, the runtime must support + these non-standard ops. + * ``OperatorExportTypes.ONNX_ATEN``: All ATen ops (in the TorchScript namespace "aten") + are exported as ATen ops (in opset domain "org.pytorch.aten"). + `ATen `_ is PyTorch's built-in tensor library, so + this instructs the runtime to use PyTorch's implementation of these ops. + + .. warning:: + + Models exported this way are probably runnable only by Caffe2. + + This may be useful if the numeric differences in implementations of operators are + causing large differences in behavior between PyTorch and Caffe2 (which is more + common on untrained models). + + * ``OperatorExportTypes.ONNX_ATEN_FALLBACK``: Try to export each ATen op + (in the TorchScript namespace "aten") as a regular ONNX op. If we are unable to do so + (e.g. because support has not been added to convert a particular torch op to ONNX), + fall back to exporting an ATen op. See documentation on OperatorExportTypes.ONNX_ATEN for + context. 
+ For example:: + + graph(%0 : Float): + %3 : int = prim::Constant[value=0]() + # conversion unsupported + %4 : Float = aten::triu(%0, %3) + # conversion supported + %5 : Float = aten::mul(%4, %0) + return (%5) + + Assuming ``aten::triu`` is not supported in ONNX, this will be exported as:: + + graph(%0 : Float): + %1 : Long() = onnx::Constant[value={0}]() + # not converted + %2 : Float = aten::ATen[operator="triu"](%0, %1) + # converted + %3 : Float = onnx::Mul(%2, %0) + return (%3) + + .. warning:: + + Models exported this way are probably runnable only by Caffe2. + + opset_version (int, default 18): The version of the + `default (ai.onnx) opset `_ + to target. Must be >= 7. + do_constant_folding: Apply the constant-folding optimization. + Constant-folding will replace some of the ops that have all constant inputs + with pre-computed constant nodes. + dynamic_axes: + + By default the exported model will have the shapes of all input and output tensors + set to exactly match those given in ``args``. To specify axes of tensors as + dynamic (i.e. known only at run-time), set ``dynamic_axes`` to a dict with schema: + + * KEY (str): an input or output name. Each name must also be provided in ``input_names`` or + ``output_names``. + * VALUE (dict or list): If a dict, keys are axis indices and values are axis names. If a + list, each element is an axis index. + + For example:: + + class SumModule(torch.nn.Module): + def forward(self, x): + return torch.sum(x, dim=1) + + + torch.onnx.export( + SumModule(), + (torch.ones(2, 2),), + "onnx.pb", + input_names=["x"], + output_names=["sum"], + ) + + Produces:: + + input { + name: "x" + ... + shape { + dim { + dim_value: 2 # axis 0 + } + dim { + dim_value: 2 # axis 1 + ... + output { + name: "sum" + ... + shape { + dim { + dim_value: 2 # axis 0 + ... + + While:: + + torch.onnx.export( + SumModule(), + (torch.ones(2, 2),), + "onnx.pb", + input_names=["x"], + output_names=["sum"], + dynamic_axes={ + # dict value: manually named axes + "x": {0: "my_custom_axis_name"}, + # list value: automatic names + "sum": [0], + }, + ) + + Produces:: + + input { + name: "x" + ... + shape { + dim { + dim_param: "my_custom_axis_name" # axis 0 + } + dim { + dim_value: 2 # axis 1 + ... + output { + name: "sum" + ... + shape { + dim { + dim_param: "sum_dynamic_axes_1" # axis 0 + ... + + keep_initializers_as_inputs: If True, all the + initializers (typically corresponding to parameters) in the + exported graph will also be added as inputs to the graph. If False, + then initializers are not added as inputs to the graph, and only + the non-parameter inputs are added as inputs. + This may allow for better optimizations (e.g. constant folding) by + backends/runtimes. + + If True, `deduplicate_initializers` pass will not be executed. This means + initializers with duplicated values will not be deduplicated and + will be treated as distinct inputs to the graph. This allows different + input initializers to be supplied at the runtime following export. + + If ``opset_version < 9``, initializers MUST be part of graph + inputs and this argument will be ignored and the behavior will be + equivalent to setting this argument to True. + + custom_opsets (dict[str, int], default empty dict): A dict with schema: + + * KEY (str): opset domain name + * VALUE (int): opset version + + If a custom opset is referenced by ``model`` but not mentioned in this dictionary, + the opset version is set to 1. Only custom opset domain name and version should be + indicated through this argument. 
+ + export_modules_as_functions: Flag to enable + exporting all ``nn.Module`` forward calls as local functions in ONNX. Or a set to indicate the + particular types of modules to export as local functions in ONNX. + This feature requires ``opset_version`` >= 15, otherwise the export will fail. This is because + ``opset_version`` < 15 implies IR version < 8, which means no local function support. + Module variables will be exported as function attributes. There are two categories of function + attributes. + + 1. Annotated attributes: class variables that have type annotations via + `PEP 526-style `_ + will be exported as attributes. + Annotated attributes are not used inside the subgraph of ONNX local function because + they are not created by PyTorch JIT tracing, but they may be used by consumers + to determine whether or not to replace the function with a particular fused kernel. + + 2. Inferred attributes: variables that are used by operators inside the module. Attribute names + will have prefix "inferred::". This is to differentiate from predefined attributes retrieved from + python module annotations. Inferred attributes are used inside the subgraph of ONNX local function. + + * ``False`` (default): export ``nn.Module`` forward calls as fine grained nodes. + * ``True``: export all ``nn.Module`` forward calls as local function nodes. + * Set of type of nn.Module: export ``nn.Module`` forward calls as local function nodes, + only if the type of the ``nn.Module`` is found in the set. + + autograd_inlining: Flag used to control whether to inline autograd functions. + Refer to https://github.com/pytorch/pytorch/pull/74765 for more details. + + Raises: + :class:`torch.onnx.errors.CheckerError`: If the ONNX checker detects an invalid ONNX graph. + :class:`torch.onnx.errors.UnsupportedOperatorError`: If the ONNX graph cannot be exported because it + uses an operator that is not supported by the exporter. + :class:`torch.onnx.errors.OnnxExporterError`: Other errors that can occur during export. + All errors are subclasses of :class:`errors.OnnxExporterError`. + """ + if operator_export_type != _C_onnx.OperatorExportTypes.ONNX: + warnings.warn( + "Setting `operator_export_type` to something other than default is deprecated. " + "The option will be removed in a future release.", + category=DeprecationWarning, + ) + if training == _C_onnx.TrainingMode.TRAINING: + warnings.warn( + "Setting `training` to something other than default is deprecated. " + "The option will be removed in a future release. 
Please set the training mode " + "before exporting the model.", + category=DeprecationWarning, + ) + + args = (args,) if isinstance(args, torch.Tensor) else args + if kwargs is not None: + args = args + (kwargs,) + + _export( + model, + args, + f, + export_params, + verbose, + training, + input_names, + output_names, + operator_export_type=operator_export_type, + opset_version=opset_version, + do_constant_folding=do_constant_folding, + dynamic_axes=dynamic_axes, + keep_initializers_as_inputs=keep_initializers_as_inputs, + custom_opsets=custom_opsets, + export_modules_as_functions=export_modules_as_functions, + autograd_inlining=autograd_inlining, + ) + + return None + + +def _is_constant_tensor_list(node): + if node.kind() != "prim::Constant": + return False + output_type = node.output().type() + if output_type.isSubtypeOf(_C.ListType.ofTensors()): + return True + if output_type.isSubtypeOf(_C.ListType(_C.OptionalType.ofTensor())): + return True + + +# ONNX can't handle constants that are lists of tensors, which can +# get generated in constant prop. So we split them back into prim::ListConstructs + + +def _split_tensor_list_constants(g, block): + for node in block.nodes(): + for subblock in node.blocks(): + _split_tensor_list_constants(g, subblock) + if _is_constant_tensor_list(node): + inputs = [] + for val in node.output().toIValue(): + input = g.insertConstant(val) + input.node().moveBefore(node) + input.node().copyMetadata(node) + inputs.append(input) + + lc = ( + g.create("prim::ListConstruct", inputs) + .insertBefore(node) + .output() + .setType(_C.ListType.ofTensors()) + ) + lc.node().copyMetadata(node) + node.output().replaceAllUsesWith(lc) + + +def _optimize_graph( + graph: _C.Graph, + operator_export_type: _C_onnx.OperatorExportTypes, + _disable_torch_constant_prop: bool = False, + fixed_batch_size: bool = False, + params_dict=None, + dynamic_axes=None, + input_names=None, + module=None, +): + if params_dict is None: + params_dict = {} + + # Inline everything + _C._jit_pass_inline(graph) + + # Remove fork/wait nodes + _C._jit_pass_inline_fork_wait(graph) + _C._jit_pass_lint(graph) + if GLOBALS.autograd_inlining: + _C._jit_pass_onnx_autograd_function_process(graph) + _C._jit_pass_lower_all_tuples(graph) + + # we now record some ops like ones/zeros + # into a trace where we previously recorded constants. + # use constant prop to maintain our current level of onnx support + # without implementing symbolics for all of them + if _disable_torch_constant_prop is False: + _C._jit_pass_constant_propagation(graph) + + _split_tensor_list_constants(graph, graph) + # run dce to eliminate dead parts of the graph that might have been + # left behind by things like symbolic_override + _C._jit_pass_dce(graph) + _C._jit_pass_lint(graph) + + # CSE should improve perf when Autocast is used with disabled cache + # Autocast is disabled due to a limitation on tracer as described at https://github.com/pytorch/pytorch/issues/84092 + # Must run before _C._jit_pass_erase_number_types to prevent type substitution + if _C._jit_pass_cse(graph): + _C._jit_pass_onnx_lint(graph) + + _C._jit_pass_canonicalize_graph_fuser_ops(graph) + _C._jit_pass_lint(graph) + _C._jit_pass_peephole(graph, True) + _C._jit_pass_fuse_addmm(graph) + _C._jit_pass_lint(graph) + + _C._jit_pass_peephole(graph, True) + _C._jit_pass_lower_all_tuples(graph) + # in _jit_pass_onnx, symbolic functions are called for each node for conversion. + # However, there are nodes that cannot be converted without additional context. 
+ # For example, the number of outputs from split (and whether it is static or dynamic) is unknown + # until the point where it is unpacked by listUnpack node. + # This pass does a preprocess, and prepares the nodes such that enough context can be received + # by the symbolic function. + _C._jit_pass_onnx_remove_inplace_ops_for_onnx(graph, module) + _C._jit_pass_onnx_preprocess(graph) + + # onnx does not support tuples, so try to remove them + _C._jit_pass_lint(graph) + + # onnx only supports tensors, but 1 / 2 = 0.5 and tensor(1) / tensor(2) = 0 + _C._jit_pass_prepare_division_for_onnx(graph) + + _C._jit_pass_onnx_remove_print(graph) + _C._jit_pass_onnx_preprocess_caffe2(graph) + + symbolic_helper._quantized_ops.clear() + # Unpack quantized weights for conv and linear ops and insert into graph. + _C._jit_pass_onnx_unpack_quantized_weights(graph, params_dict) + # onnx only supports tensors, so we turn all out number types into tensors + _C._jit_pass_erase_number_types(graph) + if GLOBALS.onnx_shape_inference: + input_names = [] if input_names is None else input_names + dynamic_axes = {} if dynamic_axes is None else dynamic_axes + _C._jit_pass_onnx_set_dynamic_input_shape(graph, dynamic_axes, input_names) + _C._jit_pass_onnx_lint(graph) + + graph = _C._jit_pass_onnx(graph, operator_export_type) + _C._jit_pass_onnx_lint(graph) + _C._jit_pass_lint(graph) + + _C._jit_pass_onnx_scalar_type_analysis( + graph, True, GLOBALS.export_onnx_opset_version + ) + _C._jit_pass_lint(graph) + + _C._jit_pass_onnx_peephole( + graph, GLOBALS.export_onnx_opset_version, fixed_batch_size + ) + _C._jit_pass_lint(graph) + + # graph is not a valid jit graph anymore because types have been replaced + # (e.g. int with Tensor), so it now contains operators that don't actually + # exist. We can't run normal dead code elimination because it'd fail trying + # to look up if an operator has side effects, but we can run a dead code + # elimination variant that doesn't need to look up if an op has side effects. + _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph) + _C._jit_pass_lint(graph) + graph = _C._jit_pass_canonicalize(graph) + _C._jit_pass_lint(graph) + if GLOBALS.onnx_shape_inference: + try: + _C._jit_pass_onnx_graph_shape_type_inference( + graph, params_dict, GLOBALS.export_onnx_opset_version + ) + except RuntimeError: + # NOTE: shape type inference error should not stop the export process + # https://github.com/pytorch/pytorch/issues/132205 + pass + + return graph + + +def warn_on_static_input_change(input_states): + """Warns that changes to input dictionaries and strings won't take effect in the traced ONNX graph. + + We accept dictionaries and strings as ONNX inputs, but they should be only for + configuration use. we detect here if these inputs are modified, and if so we warn + the user that the changes won't take effect in the traced ONNX graph. + """ + for input, traced_input in zip(input_states[0], input_states[1]): + if isinstance(input, dict): + if list(input.keys()) != list(traced_input.keys()): + warning = ( + "We detected that you are modifying a dictionary that is an input to your " + "model. " + "Note that dictionaries are allowed as inputs in ONNX but they should be " + "handled with care. " + "Usages of dictionaries is not recommended, and should not be used except " + "for configuration use. " + "Also note that the order and values of the keys must remain the same. 
" + ) + warnings.warn(warning) + elif isinstance(input, str): + if input != traced_input: + warning = ( + "The model seems to have string inputs/outputs. " + "Note that strings will not appear as inputs/outputs of the ONNX graph. " + ) + warnings.warn(warning) + + +def _resolve_args_by_export_type(arg_name, arg_value, operator_export_type): + """Resolves the arguments that are ignored when export_type != operator_export_type.ONNX.""" + return arg_value + + +def _decide_keep_init_as_input( + keep_initializers_as_inputs: bool | None, + operator_export_type: _C_onnx.OperatorExportTypes, + opset_version: int, +): + """Decides whether the initializers in the graph should be listed as ONNX graph inputs. + + This method encapsulates the logic to decide whether the initializers in the graph + should be listed as ONNX graph inputs (i.e., whether to choose ONNX IR v3 or v4). + If keep_initializers_as_inputs is not specified (None), then we decide whether to keep + initializers as graph inputs (val_keep_init_as_ip) based on export type. If export type + is ONNX, then do not keep initializers as input (val_keep_init_as_ip=False). For all other + export types keep initializers as input (val_keep_init_as_ip=True). + If keep_initializers_as_inputs is specified, then respect it. Unless opset version <= 8, + in which case it must be ignored because for opset version <= 8, all initializers MUST be + part of graph input (only ONNX IR v3 is allowed), i.e. val_keep_init_as_ip=True. + + Special handling is needed for opset version 8 or lower, because irrespective + of user input for keep_initializers_as_inputs, the graph must follow ONNX IR v3 + semantics, i.e. all initializers must be listed as ONNX graph input. + """ + + if opset_version < 9: + if keep_initializers_as_inputs is False: + warnings.warn( + "Setting 'keep_initializers_as_inputs=False' for opset version" + "8 or lower would lead to an invalid ONNX graph. Therefore, " + "'keep_initializers_as_inputs=False' is ignored during export." + "Exported model will have initializers as graph inputs (compliant " + " to ONNX IR v3)." + ) + return True # i.e. True == initializers are part of graph input (ONNX IR v3) + val_keep_init_as_ip = ( + True if keep_initializers_as_inputs is None else keep_initializers_as_inputs + ) + if ( + keep_initializers_as_inputs is None + and operator_export_type is _C_onnx.OperatorExportTypes.ONNX + ): + val_keep_init_as_ip = False + return val_keep_init_as_ip + + +def _decide_add_node_names(add_node_names, operator_export_type): + return _resolve_args_by_export_type( + "add_node_names", add_node_names, operator_export_type + ) + + +def _decide_constant_folding(do_constant_folding, operator_export_type, training): + do_constant_folding = _resolve_args_by_export_type( + "do_constant_folding", do_constant_folding, operator_export_type + ) + if do_constant_folding and ( + training is not None and training is not _C_onnx.TrainingMode.EVAL + ): + warnings.warn( + "It is recommended that constant folding be turned off ('do_constant_folding=False') " + "when exporting the model in training-amenable mode, i.e. with 'training=TrainingMode.TRAIN' " + "or 'training=TrainingMode.PRESERVE' (when model is in training mode). Otherwise, some " + "learnable model parameters may not translate correctly in the exported ONNX model " + "because constant folding mutates model parameters. Please consider " + "turning off constant folding or setting the training=TrainingMode.EVAL." 
+ ) + return do_constant_folding + + +def _signature(model) -> inspect.Signature: + should_be_callable = getattr(model, "forward", model) + if callable(should_be_callable): + return inspect.signature(should_be_callable) + raise ValueError("model has no forward method and is not callable") + + +def _decide_input_format(model, args): + try: + sig = _signature(model) + except ValueError as e: + warnings.warn(f"{e}, skipping _decide_input_format") + return args + try: + ordered_list_keys = list(sig.parameters.keys()) + if ordered_list_keys[0] == "self": + ordered_list_keys = ordered_list_keys[1:] + args_dict: dict = {} + if isinstance(args, list): + args_list = args + elif isinstance(args, tuple): + args_list = list(args) + else: + args_list = [args] + if isinstance(args_list[-1], dict): + args_dict = args_list[-1] + args_list = args_list[:-1] + n_nonkeyword = len(args_list) + for optional_arg in ordered_list_keys[n_nonkeyword:]: + if optional_arg in args_dict: + args_list.append(args_dict[optional_arg]) + # Check if this arg has a default value + else: + param = sig.parameters[optional_arg] + if param.default != param.empty: + args_list.append(param.default) + args = args_list if isinstance(args, list) else tuple(args_list) + # Cases of models with no input args + except IndexError: + warnings.warn("No input args, skipping _decide_input_format") + except Exception as e: + warnings.warn(f"Skipping _decide_input_format\n {e.args[0]}") + return args + + +def _trace(func, args, operator_export_type, return_outs=False): + # Special case for common case of passing a single Tensor + if isinstance(args, torch.Tensor): + args = (args,) + + trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph( + func, + args, + strict=False, + _force_outplace=False, + _return_inputs_states=True, + ) + warn_on_static_input_change(inputs_states) + + trace_graph = _optimize_graph(trace_graph, operator_export_type, params_dict={}) + if return_outs: + return trace_graph, torch_out + return trace_graph + + +def _trace_and_get_graph_from_model(model, args): + # A basic sanity check: make sure the state_dict keys are the same + # before and after running the model. Fail fast! + orig_state_dict_keys = torch.jit._unique_state_dict(model).keys() + + # Disable Autocast cache because it replaces kernel's weight and bias + # by (undesired) constants. + # No perf impact for when there are reused weights since https://github.com/pytorch/pytorch/pull/85665 + prev_autocast_cache_enabled = torch.is_autocast_cache_enabled() + torch.set_autocast_cache_enabled(False) + trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph( + model, + args, + strict=False, + _force_outplace=False, + _return_inputs_states=True, + ) + torch.set_autocast_cache_enabled(prev_autocast_cache_enabled) + + warn_on_static_input_change(inputs_states) + + if orig_state_dict_keys != torch.jit._unique_state_dict(model).keys(): + raise RuntimeError( + "state_dict changed after running the tracer; " + "something weird is happening in your model!" + ) + + return trace_graph, torch_out + + +def _get_param_count_list(method_graph, args_params): + param_count_list = [] + for input_, arg_params_ in zip(method_graph.inputs(), args_params): + if "PackedParams" in str(input_.type()): + in_vars, _ = torch.jit._flatten(arg_params_) + param_count_list.append(len(in_vars)) + else: + param_count_list.append(arg_params_ is not None) + + return param_count_list + + +def _check_flatten_did_not_remove(original, jit_flattened): + """torch.jit._flatten removes None. 
Check if it did so in this case.""" + + def flatten(x): + if isinstance(x, (list, tuple)): + for inner in x: + yield from flatten(inner) + elif isinstance(x, dict): + for inner in x.values(): + yield from flatten(inner) + else: + yield x + + flattened_with_none = list(flatten(original)) + num_none = len(flattened_with_none) - len(jit_flattened) + assert num_none >= 0 + if num_none: + raise ValueError( + f"args contained {num_none} None's after flattening. " + "When exporting a ScriptModule or ScriptFunction, no args may " + "be None because that breaks type propagation." + ) + + +def _create_jit_graph( + model: torch.nn.Module | torch.jit.ScriptFunction, args: Sequence[Any] +) -> tuple[_C.Graph, list[_C.IValue], Any | None, _C.ScriptModule | None]: + if isinstance(model, (torch.jit.ScriptFunction, torch.jit.ScriptModule)): + flattened_args = tuple(torch.jit._flatten(tuple(args))[0]) + _check_flatten_did_not_remove(args, flattened_args) + torch_out = None + + if isinstance(model, torch.jit.ScriptModule): + try: + graph = model.forward.graph # type: ignore[attr-defined] + except AttributeError as e: + raise RuntimeError("'forward' method must be a script method") from e + _C._jit_pass_onnx_function_substitution(graph) + freezed_module = _C._freeze_module( + cast(_C.ScriptModule, model._c), preserveParameters=True + ) + module, params = _C._jit_onnx_list_model_parameters(freezed_module) + method_graph = module._get_method("forward").graph + args_params = tuple(args) + tuple(params) + param_count_list = _get_param_count_list(method_graph, args_params) + in_vars, _ = torch.jit._flatten(args_params) + graph = _C._propagate_and_assign_input_shapes( + method_graph, tuple(in_vars), param_count_list, False, False + ) + return graph, params, torch_out, module + + # torch.jit.ScriptFunction + params = [] + graph = model.graph + _C._jit_pass_onnx_function_substitution(graph) + param_count_list = _get_param_count_list(graph, args) + graph = _C._propagate_and_assign_input_shapes( + graph, flattened_args, param_count_list, False, False + ) + return graph, params, torch_out, None + + graph, torch_out = _trace_and_get_graph_from_model(model, args) + _C._jit_pass_onnx_lint(graph) + state_dict = torch.jit._unique_state_dict(model) + params = list(state_dict.values()) + graph_inputs = list(graph.inputs()) + user_input_num = len(graph_inputs) - len(state_dict) + param_names = list(state_dict.keys()) + for i, inp in enumerate(graph_inputs): + if i >= user_input_num: + inp.setDebugName(param_names[i - user_input_num]) + _C._jit_pass_onnx_function_substitution(graph) + return graph, params, torch_out, None + + +def _get_named_param_dict(graph, params): + input_and_param_names = [val.debugName() for val in graph.inputs()] + param_names = input_and_param_names[len(input_and_param_names) - len(params) :] + _params_dict = dict(zip(param_names, params)) + return _params_dict + + +def _get_example_outputs(model, args): + input_args = copy.deepcopy(args) + input_kwargs = {} + if input_args and isinstance(input_args[-1], dict): + input_kwargs = input_args[-1] + input_args = input_args[:-1] + + example_outputs = model(*input_args, **input_kwargs) + if isinstance(example_outputs, list): + example_outputs = [example_outputs] + elif not isinstance(example_outputs, tuple): + example_outputs = (example_outputs,) + + return example_outputs + + +_qtype_vtype_map = { + torch.quint8: torch.uint8, + torch.qint8: torch.int8, + torch.qint32: torch.int32, + torch.quint4x2: torch.int8, +} + + +def unpack_quantized_tensor(value, 
cast_onnx_accepted=True): + if isinstance(value, torch.Tensor) and value.dtype in _qtype_vtype_map: + q_value_dequantize = value.dequantize() + q_scale = ( + torch.tensor(value.q_scale(), dtype=torch.double) + if cast_onnx_accepted + else torch.tensor(value.q_scale(), dtype=torch.float32) + ) + q_zero_point = ( + torch.tensor(value.q_zero_point(), dtype=torch.int64) + if cast_onnx_accepted + else torch.tensor(value.q_zero_point(), dtype=_qtype_vtype_map[value.dtype]) + ) + q_value = q_value_dequantize / q_scale + q_zero_point + q_value = q_value.to(dtype=_qtype_vtype_map[value.dtype]) + return q_value, q_scale, q_zero_point + else: + return (value,) + + +def _pre_trace_quant_model(model, args): + r"""Returns `torch.jit.trace(model, args)` if model is quantized. Otherwise do nothing and return + original model. + + This is due to https://github.com/pytorch/pytorch/issues/75761. + """ + if any( + hasattr(m, "_packed_params") for m in getattr(model, "modules", list)() + ) or any(getattr(arg, "is_quantized", False) for arg in args): + return torch.jit.trace(model, args) + return model + + +def _model_to_graph( + model, + args, + verbose=False, + input_names=None, + output_names=None, + operator_export_type=_C_onnx.OperatorExportTypes.ONNX, + do_constant_folding=True, + _disable_torch_constant_prop=False, + fixed_batch_size=False, + training=_C_onnx.TrainingMode.EVAL, + dynamic_axes=None, +) -> tuple[ + _C.Graph, + dict[str, torch.Tensor], + torch.Tensor + | tuple[torch.Tensor, ...] + | list[torch.Tensor] + | dict[str, torch.Tensor] + | Any + | None, +]: + """Converts model into an ONNX graph. + + Returns: + graph: A TorchScript IR Graph with ONNX nodes. + params_dict: Dict from input param name to param value. + torch_out: The output tensors resulting from the trace of ``model``. + If ``model`` is a :class:`torch.jit.ScriptModule` or :class:`torch.jit.ScriptFunction`, + this will be None, since we are not doing any tracing. + """ + # TODO: can we simplify this to always return a tuple of Tensor or None? + + # Special case for common case of passing a single Tensor + if isinstance(args, (torch.Tensor, int, float, bool)): + args = (args,) + + model = _pre_trace_quant_model(model, args) + graph, params, torch_out, module = _create_jit_graph(model, args) + params_dict = _get_named_param_dict(graph, params) + + try: + graph = _optimize_graph( + graph, + operator_export_type, + _disable_torch_constant_prop=_disable_torch_constant_prop, + fixed_batch_size=fixed_batch_size, + params_dict=params_dict, + dynamic_axes=dynamic_axes, + input_names=input_names, + module=module, + ) + except Exception: + _C._jit_onnx_log("Torch IR graph at exception: ", graph) + raise + + is_script = isinstance(model, (torch.jit.ScriptFunction, torch.jit.ScriptModule)) + if is_script: + example_outputs = _get_example_outputs(model, args) + example_outputs_final = () + for example_output in example_outputs: + example_outputs_final += unpack_quantized_tensor(example_output) + out_vars, desc = torch.jit._flatten(example_outputs_final) + _C._jit_pass_onnx_assign_output_shape( + graph, + out_vars, + desc, + GLOBALS.onnx_shape_inference, + is_script, + GLOBALS.export_onnx_opset_version, + ) + + # NB: ONNX requires complete information about output types, which might be + # erased by some optimizations, so we need to set it explicitly again. 
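+    # For traced (non-script) models, the outputs recorded during tracing
+    # (torch_out) are flattened and used to assign output shapes below, unless
+    # any output is quantized (quantized outputs unpack to multiple ONNX values).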
+    else:
+        if not isinstance(torch_out, (list, tuple)):
+            output_wrapped = [torch_out]
+        else:
+            output_wrapped = torch_out  # type: ignore[assignment]
+
+        output_tensors, out_desc = torch.jit._flatten(tuple(output_wrapped))
+        # assign_output_shape pass is not compatible with quantized outputs.
+        # Quantized outputs are flattened to 3 values in ONNX, while packed as
+        # single value in PyTorch.
+        if not any(getattr(out, "is_quantized", False) for out in output_tensors):
+            _C._jit_pass_onnx_assign_output_shape(
+                graph,
+                output_tensors,
+                out_desc,
+                GLOBALS.onnx_shape_inference,
+                is_script,
+                GLOBALS.export_onnx_opset_version,
+            )
+
+    _set_input_and_output_names(graph, input_names, output_names)
+    params_dict = _get_named_param_dict(graph, params)
+
+    if (
+        do_constant_folding
+        and GLOBALS.export_onnx_opset_version
+        >= _constants.ONNX_CONSTANT_FOLDING_MIN_OPSET
+    ):
+        if training is None or training == _C_onnx.TrainingMode.EVAL:
+            params_dict = _C._jit_pass_onnx_eval_peephole(graph, params_dict)
+
+        params_dict = _C._jit_pass_onnx_constant_fold(
+            graph, params_dict, GLOBALS.export_onnx_opset_version
+        )
+        _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
+
+    if GLOBALS.onnx_shape_inference:
+        try:
+            _C._jit_pass_onnx_graph_shape_type_inference(
+                graph, params_dict, GLOBALS.export_onnx_opset_version
+            )
+        except RuntimeError:
+            # NOTE: shape type inference error should not stop the export process
+            # https://github.com/pytorch/pytorch/issues/132205
+            pass
+
+    params_dict = _C._jit_pass_onnx_eliminate_unused_items(graph, params_dict)
+
+    # For ONNX opset < 9, constants only have three data types: float16, float, double.
+    # This pass transforms constants of other data types to float/double plus a Cast operator.
+    if GLOBALS.export_onnx_opset_version < 9:
+        _C._jit_pass_onnx_cast_all_constant_to_floating(graph)
+
+    params_dict = _C._jit_pass_filter_non_tensor_arguments(params_dict)
+    _C._jit_decay_packed_param_input_types(graph)
+
+    # If outputs lack a proper name and are identified only by their unique id,
+    # give them a legible name for debugging purposes.
+    _apply_friendly_debug_names(graph, params_dict)
+
+    return graph, params_dict, torch_out
+
+
+@deprecated(
+    "Unconvertible ops are not definitive. Please remove usage of this function"
+)
+def unconvertible_ops(
+    model,
+    args,
+    training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL,
+    opset_version: int | None = None,
+) -> tuple[_C.Graph, list[str]]:
+    """Returns an approximated list of all ops that are not yet supported by :mod:`torch.onnx`.
+
+    .. deprecated:: 2.5
+        Unconvertible ops are not definitive. Please remove usage of this function.
+
+    The list is approximated because some ops may be removed during the conversion
+    process and don't need to be converted. Some other ops may have partial support
+    that will fail conversion with particular inputs. Please open a GitHub issue
+    for op support requests.
+
+    Args:
+        model: Same as the `model` parameter in :func:`torch.onnx.export`.
+        args: Same as the `args` parameter in :func:`torch.onnx.export`.
+        training: Same as the `training` parameter in :func:`torch.onnx.export`.
+        opset_version: Same as the `opset_version` parameter in :func:`torch.onnx.export`.
+
+    Returns:
+        The JIT graph and a list of unconvertible ops in the format of "domain::op".
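+
+    Example (illustrative; ``model`` and ``example_inputs`` are placeholders for
+    an actual module and its inputs)::
+
+        graph, unsupported = unconvertible_ops(model, example_inputs, opset_version=17)
+        print(unsupported)  # e.g. ["aten::some_unsupported_op"]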
+ """ + + opset_version = opset_version or _constants.ONNX_DEFAULT_OPSET + GLOBALS.export_onnx_opset_version = opset_version + + try: + with exporter_context(model, training, verbose=False): + # Create a mostly clean JIT graph that contains the plain aten and + # other ops we can check with the symbolic registry. + # NOTE: We don't want to actually convert any ops to ONNX or run any + # symbolic functions because there is a higher chance that a pass + # fails or an unconvertible op messes up the graph during ONNX conversion. + # This way we can always generate a list just by looking at the names + # of the ops in the graph. + args = _decide_input_format(model, args) + model = _pre_trace_quant_model(model, args) + graph, _, _, module = _create_jit_graph(model, args) + _C._jit_pass_inline(graph) + _C._jit_pass_onnx_remove_inplace_ops_for_onnx(graph, module) + _C._jit_pass_erase_number_types(graph) + _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph) + except Exception as e: + raise errors.OnnxExporterError( + "Failed to discover unconvertible ops because of errors during the JIT graph " + "generation process." + ) from e + + unsupported_ops = [] + for node in graph.nodes(): + domain_op = node.kind() + if domain_op.startswith(("onnx::", "prim::")): + # We consider onnx and prim ops as supported ops, even though some "prim" + # ops are not implemented as symbolic functions, because they may be + # eliminated in the conversion passes. Users may still see errors caused + # by prim ops even though they don't show up in the list. + continue + if not registration.registry.is_registered_op( + domain_op.rstrip("_"), opset_version + ): + # We consider all registered ops supported, even though some of them are + # only partially supported, because there is not yet a good way to check + # if an op is fully supported. + # TODO(justinchuby): Create a way to check if an op is fully supported. + unsupported_ops.append(domain_op) + return graph, unsupported_ops + + +def _setup_trace_module_map( + model: torch.nn.Module | torch.jit.ScriptModule, + export_modules_as_functions: bool | Collection[type[torch.nn.Module]], +) -> set[str]: + def __register_attribute_hook(): + attr_name = "_onnx_attrs" + + def _track_module_attributes_forward_pre_hook(module, input): + setattr(module, attr_name, _get_module_attributes(module)) + + def _track_module_attributes_forward_hook(module, input, output): + tracing_state = _C._get_tracing_state() + if not tracing_state: + return + + graph = tracing_state.graph() + onnx_attrs = {} + if hasattr(module, attr_name): + onnx_attrs = getattr(module, attr_name) + delattr(module, attr_name) + + _C._jit_pass_onnx_track_scope_attributes(graph, onnx_attrs) + + for m in model.modules(): + m.register_forward_hook(_track_module_attributes_forward_hook) + m.register_forward_pre_hook(_track_module_attributes_forward_pre_hook) + + def _unqualified_variable_name(qualified_name: str) -> str: + """ + Parse qualified variable name and return the unqualified version. + + Pure numeric atoms are considered inadequate, so this function will look past them, + and start from the first non-numeric atom. 
+ + Example: + >>> _unqualified_variable_name("__main__.Foo.bar") + 'bar' + >>> _unqualified_variable_name("__main__.Foo.bar.0") + 'bar.0' + """ + name_atoms = qualified_name.split(".") + for i, atom in reversed(list(enumerate(name_atoms))): + if not atom.isnumeric(): + return ".".join(name_atoms[i:]) + return qualified_name + + trace_module_map = { + _m: torch._C._jit_onnx_create_full_scope_name( + torch.typename(type(_m)), _unqualified_variable_name(_n) + ) + for _n, _m in model.named_modules() + } + torch.jit._trace._trace_module_map = trace_module_map + if isinstance(export_modules_as_functions, bool) and export_modules_as_functions: + module_typenames = {torch.typename(type(module)) for module in trace_module_map} + elif isinstance(export_modules_as_functions, set) and export_modules_as_functions: + + def _find_typename(v): + if isinstance(v, type): + return torch.typename(v) + else: + raise RuntimeError( + "Only type of the `nn.Module` should be " + "passed in the set for argument `export_modules_as_functions`. " + f"Got `{type(v).__name__}`." + ) + + module_typenames = {_find_typename(v) for v in export_modules_as_functions} + else: + module_typenames = set() + + if module_typenames: + __register_attribute_hook() + + return module_typenames + + +def _reset_trace_module_map(): + torch.jit._trace._trace_module_map = None + _C._jit_pass_onnx_clear_scope_records() + + +def _get_module_attributes(module): + annotations = typing.get_type_hints(type(module)) + base_m_annotations = typing.get_type_hints(torch.nn.Module) + [annotations.pop(k, None) for k in base_m_annotations] + # Check whether module attributes can be accessed. Some classes + # define attributes but don't provide access to them in their + # constructor. + # + # For example, torch.nn.Embedding has the `freeze` variable and its + # type specified in the class but the attribute is not created in the + # constructor. In other words, there is no `self.freeze = ` + # in the constructor. 
+ # + # Reference: https://github.com/pytorch/pytorch/blob/92de1d322223fb5584e384971b32c46b93bc2f4b/torch/nn/modules/sparse.py#L120 + attrs = {} + for k in annotations: + try: + attrs[k] = getattr(module, k) + except AttributeError: + _C._jit_onnx_log(f"Skipping module attribute '{k}'") + continue + return attrs + + +def _trigger_symbolic_function_registration(): + """Trigger the registration of symbolic functions for all supported opsets.""" + + from torch.onnx._internal.torchscript_exporter import ( # noqa: F401 + symbolic_opset10, + symbolic_opset11, + symbolic_opset12, + symbolic_opset13, + symbolic_opset14, + symbolic_opset15, + symbolic_opset16, + symbolic_opset17, + symbolic_opset18, + symbolic_opset19, + symbolic_opset20, + symbolic_opset7, + symbolic_opset8, + symbolic_opset9, + ) + + +def _export( + model, + args, + f, + export_params=True, + verbose=False, + training=_C_onnx.TrainingMode.EVAL, + input_names=None, + output_names=None, + operator_export_type=_C_onnx.OperatorExportTypes.ONNX, + export_type=None, + opset_version=None, + do_constant_folding=True, + dynamic_axes=None, + keep_initializers_as_inputs=None, + fixed_batch_size=False, + custom_opsets=None, + add_node_names=True, + onnx_shape_inference=True, + export_modules_as_functions: Any = False, + autograd_inlining=True, +): + assert GLOBALS.in_onnx_export is False + + _trigger_symbolic_function_registration() + + if isinstance(model, torch.nn.DataParallel): + raise ValueError( + "torch.nn.DataParallel is not supported by ONNX " + "exporter, please use 'attribute' module to " + "unwrap model from torch.nn.DataParallel. Try " + "torch.onnx.export(model.module, ...)" + ) + + GLOBALS.onnx_shape_inference = onnx_shape_inference + + if opset_version is None: + opset_version = _constants.ONNX_DEFAULT_OPSET + + if opset_version > _constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET: + warnings.warn( + f"Exporting to ONNX opset version {opset_version} is not supported. " + f"by 'torch.onnx.export()'. " + f"The highest opset version supported is {_constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET}. " + f"To use a newer opset version, consider 'torch.onnx.export(..., dynamo=True)'. ", + category=errors.OnnxExporterWarning, + ) + + if export_modules_as_functions and opset_version < 15: + raise ValueError( + "`export_modules_as_functions` is not supported for `opset_version` < 15." + "This is because `opset_version` < 15 implies IR version < 8, which means " + "no local function support. " + ) + if not operator_export_type: + operator_export_type = _C_onnx.OperatorExportTypes.ONNX + + # By default, training=TrainingMode.EVAL, + # which is good because running a model in training mode could result in + # internal buffers getting updated, dropout getting applied, etc. + # If you really know what you're doing, you can turn + # training=TrainingMode.TRAINING or training=TrainingMode.PRESERVE, + # (to preserve whatever the original training mode was.) 
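+
+    # Record the requested opset version and operator export type in GLOBALS so
+    # that the graph passes and symbolic functions invoked below can read them.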
+ GLOBALS.export_onnx_opset_version = opset_version + GLOBALS.operator_export_type = operator_export_type + + try: + GLOBALS.in_onnx_export = True + _autograd_inlining_previous = GLOBALS.autograd_inlining + GLOBALS.autograd_inlining = autograd_inlining + + module_typenames_to_export_as_functions: set[str] = set() + if isinstance(model, (torch.nn.Module, torch.jit.ScriptModule)): + module_typenames_to_export_as_functions = _setup_trace_module_map( + model, export_modules_as_functions + ) + + with exporter_context(model, training, verbose): + val_keep_init_as_ip = _decide_keep_init_as_input( + keep_initializers_as_inputs, + operator_export_type, + opset_version, + ) + val_add_node_names = _decide_add_node_names( + add_node_names, operator_export_type + ) + val_do_constant_folding = _decide_constant_folding( + do_constant_folding, operator_export_type, training + ) + # Normally f can be a file-like object, but for large models, the external data format requires a + # valid `model_file_location`. Code in export.cpp will enforce this. + if isinstance(f, str): + model_file_location = f + else: + model_file_location = "" + args = _decide_input_format(model, args) + if dynamic_axes is None: + dynamic_axes = {} + _validate_dynamic_axes(dynamic_axes, model, input_names, output_names) + + graph, params_dict, torch_out = _model_to_graph( + model, + args, + verbose, + input_names, + output_names, + operator_export_type, + val_do_constant_folding, + fixed_batch_size=fixed_batch_size, + training=training, + dynamic_axes=dynamic_axes, + ) + + if custom_opsets is None: + custom_opsets = {} + + _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph) + node_attr_to_name = {} # type: ignore[var-annotated] + if module_typenames_to_export_as_functions: + # NOTE: cannot call DCE after this pass. DCE will remove function definition nodes. + node_attr_to_name = _C._jit_pass_onnx_function_extraction( + graph, + module_typenames_to_export_as_functions, + list(params_dict.keys()), + ) + + if keep_initializers_as_inputs is not True: + params_dict = _C._jit_pass_onnx_deduplicate_initializers( # type: ignore[assignment] + graph, + params_dict, # type: ignore[arg-type] + getattr(model, "training", False), # type: ignore[arg-type] + ) + _C._jit_pass_onnx_assign_scoped_names_for_node_and_value(graph) + defer_weight_export = False + if export_params: + ( + proto, + export_map, + _val_use_external_data_format, + _node_names, + ) = graph._export_onnx( # type: ignore[attr-defined] + params_dict, + opset_version, + dynamic_axes, + defer_weight_export, + operator_export_type, + not verbose, + val_keep_init_as_ip, + custom_opsets, + val_add_node_names, + model_file_location, + node_attr_to_name, + ) + else: + ( + proto, + export_map, + _, + _, + ) = graph._export_onnx( # type: ignore[attr-defined] + {}, + opset_version, + dynamic_axes, + defer_weight_export, + operator_export_type, + not verbose, + val_keep_init_as_ip, + custom_opsets, + val_add_node_names, + model_file_location, + node_attr_to_name, + ) + # insert function_proto into model_proto. 
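+            # (ONNX FunctionProtos produced by onnxscript-based symbolic
+            # implementations, if any, are merged into the serialized model here.)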
+ proto = onnx_proto_utils._add_onnxscript_fn( + proto, + custom_opsets, + ) + if verbose: + _C._jit_onnx_log("Exported graph: ", graph) + onnx_proto_utils._export_file(proto, f, export_map) + finally: + assert GLOBALS.in_onnx_export + GLOBALS.in_onnx_export = False + GLOBALS.autograd_inlining = _autograd_inlining_previous + _reset_trace_module_map() + + return torch_out + + +def _apply_friendly_debug_names(graph, params): + for n in graph.nodes(): + for v in n.inputs(): + old_name = v.debugName() + if old_name != str(v.unique()): + continue + new_name = f"{n.kind()}_{v.unique()}" + v.setDebugName(new_name) + if old_name in params: + params[new_name] = params.pop(old_name) + + +def _set_input_and_output_names(graph, input_names, output_names): + def set_names(node_list, name_list, descriptor): + if name_list is None: + return + if len(name_list) > len(node_list): + raise RuntimeError( + f"number of {descriptor} names provided ({len(name_list)}) " + f"exceeded number of {descriptor}s ({len(node_list)})" + ) + + # Mark if the output node DebugName is set before. + output_node_set = set() + for i, (name, node) in enumerate(zip(name_list, node_list)): + # Duplicated output node, insert onnx::Identity to avoid setting the same DebugName after setDebugName(). + if descriptor == "output": + if node in output_node_set: + identity_node = graph.create("onnx::Identity") + identity_node.insertAfter(node.node()) + identity_node.addInput(node) + identity_node.output().setType(node.type()) + graph.return_node().replaceInput(i, identity_node.output()) + node = identity_node.output() + output_node_set.add(node) + + if node.debugName() != name: + node.setDebugName(name) + + set_names(list(graph.inputs()), input_names, "input") + set_names(list(graph.outputs()), output_names, "output") + + +def _run_symbolic_method(g, op_name, symbolic_fn, args): + r""" + This trampoline function gets invoked for every symbolic method + call from C++. + """ + try: + graph_context = jit_utils.GraphContext( + graph=g, + block=g.block(), + opset=GLOBALS.export_onnx_opset_version, + original_node=None, # type: ignore[arg-type] + params_dict=_params_dict, + env={}, + values_in_env=set(), + new_nodes=[], + ) + return symbolic_fn(graph_context, *args) + except TypeError as e: + # Handle the specific case where we didn't successfully dispatch + # to symbolic_fn. Otherwise, the backtrace will have the clues + # you need. 
+ e.args = (f"{e.args[0]} (occurred when translating {op_name})",) + raise + + +def _add_block(node: _C.Node) -> _C.Block: + return node.addBlock() + + +def _add_input_to_block(block: _C.Block): + return block.addInputToBlock() # type: ignore[attr-defined] + + +def _add_output_to_block(block: _C.Block, value: _C.Value) -> int: + return block.registerOutput(value) + + +def _should_aten_fallback( + name: str, opset_version: int, operator_export_type: _C_onnx.OperatorExportTypes +): + # For all builds, if domain=="aten" and operator_export_type==ONNX_ATEN, + # an aten::ATen operator is created regardless of symbolics existence + + is_exportable_aten_op = registration.registry.is_registered_op(name, opset_version) + is_onnx_aten_export = operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN + is_aten_fallback_export = ( + operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK + ) + + if not name.startswith("aten::"): + return False + + if is_onnx_aten_export or (is_aten_fallback_export and not is_exportable_aten_op): + return True + + return False + + +def _get_aten_op_overload_name(n: _C.Node) -> str: + # Returns `overload_name` attribute to ATen ops on non-Caffe2 builds + schema = n.schema() + if not schema.startswith("aten::"): + return "" + return _C.parse_schema(schema).overload_name + + +def _run_symbolic_function( + graph: _C.Graph, + block: _C.Block, + node: _C.Node, + inputs: Any, + env: dict[_C.Value, _C.Value], + values_in_env: set[_C.Value], + new_nodes: list[_C.Node], + operator_export_type=_C_onnx.OperatorExportTypes.ONNX, +) -> _C.Value | Sequence[_C.Value | None] | None: + """Runs a symbolic function. + + The function is used in C++ to export the node to ONNX. + + Returns: + A single or a tuple of Values. + None when the node gets cloned as is into the new graph. + """ + + opset_version = GLOBALS.export_onnx_opset_version + + # See Note [Export inplace] + node_kind = node.kind() + if node_kind.endswith("_"): + # Treat relu_ -> relu; add_ -> add etc. + ns_op_name = node_kind[:-1] + else: + ns_op_name = node_kind + + namespace, op_name = jit_utils.parse_node_kind(ns_op_name) + + graph_context = jit_utils.GraphContext( + graph=graph, + block=block, + opset=opset_version, + original_node=node, + params_dict=_params_dict, + env=env, + values_in_env=values_in_env, + new_nodes=new_nodes, + ) + + # Direct ATen export requested + if _should_aten_fallback(ns_op_name, opset_version, operator_export_type): + attrs = { + k + "_" + node.kindOf(k)[0]: symbolic_helper._node_get(node, k) + for k in node.attributeNames() + } + outputs = node.outputsSize() + attrs["outputs"] = outputs + return graph_context.aten_op( + op_name, + *inputs, + overload_name=_get_aten_op_overload_name(node), + **attrs, + ) + + try: + domain = namespace + symbolic_function_name = f"{domain}::{op_name}" + + symbolic_function_group = registration.registry.get_function_group( + symbolic_function_name + ) + if symbolic_function_group is not None: + symbolic_fn = symbolic_function_group.get(opset_version) + if symbolic_fn is not None: + # TODO Wrap almost identical attrs assignment or comment the difference. 
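+                # A registered symbolic function receives attributes under their
+                # plain names; the ATen fallback paths below instead suffix each
+                # name with its kind character (e.g. "i", "f", "s"), as expected
+                # by graph_context.aten_op.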
+ attrs = { + k: symbolic_helper._node_get(node, k) for k in node.attributeNames() + } + return symbolic_fn(graph_context, *inputs, **attrs) + + attrs = { + k + "_" + node.kindOf(k)[0]: symbolic_helper._node_get(node, k) + for k in node.attributeNames() + } + if namespace == "onnx": + # Clone node to trigger ONNX shape inference + return graph_context.op( + op_name, *inputs, **attrs, outputs=node.outputsSize() + ) # type: ignore[attr-defined] + + raise errors.UnsupportedOperatorError( + symbolic_function_name, + opset_version, + symbolic_function_group.get_min_supported() + if symbolic_function_group + else None, + ) + + except RuntimeError: + if operator_export_type == _C_onnx.OperatorExportTypes.ONNX_FALLTHROUGH: + return None + elif operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK: + # Emit ATen op for non-Caffe2 builds when `operator_export_type==ONNX_ATEN_FALLBACK` + attrs = { + k + "_" + node.kindOf(k)[0]: symbolic_helper._node_get(node, k) + for k in node.attributeNames() + } + return graph_context.aten_op( + op_name, + *inputs, + overload_name=_get_aten_op_overload_name(node), + **attrs, + ) + raise + except TypeError as e: + # Handle the specific case where we didn't successfully dispatch. + # Otherwise, the backtrace will have the clues you need. + e.args = (f"{e.args[0]} \n(Occurred when translating {op_name}).",) + raise + + +def _verify_custom_op_name(symbolic_name: str): + if not re.match(r"^[a-zA-Z0-9-_]+::[a-zA-Z-_]+[a-zA-Z0-9-_]*$", symbolic_name): + raise errors.OnnxExporterError( + f"Failed to register operator {symbolic_name}. " + "The symbolic name must match the format domain::name, " + "and should start with a letter and contain only " + "alphanumerical characters" + ) + + ns, _ = jit_utils.parse_node_kind(symbolic_name) + if ns == "onnx": + raise ValueError( + f"Failed to register operator {symbolic_name}. {ns} domain cannot be modified." + ) + + +def register_custom_op_symbolic( + symbolic_name: str, + symbolic_fn: Callable, + opset_version: int, +): + """Registers a symbolic function for a custom operator. + + When the user registers symbolic for custom/contrib ops, + it is highly recommended to add shape inference for that operator via setType API, + otherwise the exported graph may have incorrect shape inference in some extreme cases. + An example of setType is `test_aten_embedding_2` in `test_operators.py`. + + See "Custom Operators" in the module documentation for an example usage. + + Args: + symbolic_name (str): The name of the custom operator in "::" + format. + symbolic_fn (Callable): A function that takes in the ONNX graph and + the input arguments to the current operator, and returns new + operator nodes to add to the graph. + opset_version (int): The ONNX opset version in which to register. + """ + if symbolic_name.startswith("::"): + symbolic_name = f"aten{symbolic_name}" + + _verify_custom_op_name(symbolic_name) + + registration.custom_onnx_symbolic(symbolic_name, opset_version)(symbolic_fn) + + +def unregister_custom_op_symbolic(symbolic_name: str, opset_version: int): + """Unregisters ``symbolic_name``. + + See "Custom Operators" in the module documentation for an example usage. + + Args: + symbolic_name (str): The name of the custom operator in "::" + format. + opset_version (int): The ONNX opset version in which to unregister. 
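+
+    Example (illustrative; assumes ``"custom_ops::gelu"`` was registered for
+    opset 9 earlier in the program)::
+
+        torch.onnx.unregister_custom_op_symbolic("custom_ops::gelu", 9)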
+    """
+    if symbolic_name.startswith("::"):
+        symbolic_name = f"aten{symbolic_name}"
+
+    _verify_custom_op_name(symbolic_name)
+
+    registration.registry.unregister(symbolic_name, opset_version)
+
+
+def _validate_dynamic_axes(dynamic_axes, model, input_names, output_names):
+    """Ensures the dynamic_axes argument follows the expected format."""
+    if len(dynamic_axes) == 0:
+        return
+
+    if hasattr(model, "graph"):
+        # Extract the set of valid input/output names to be used for dynamic_axes.
+        if (input_names is None) or len(input_names) == 0:
+            input_names = [x.debugName() for x in model.graph.inputs()]
+        if (output_names is None) or len(output_names) == 0:
+            output_names = [y.debugName() for y in model.graph.outputs()]
+
+    valid_names = set((input_names or []) + (output_names or []))
+
+    # If dynamic axes are provided as a list rather than a dictionary, they should
+    # first get converted to a dictionary in the expected format. If desired axis names
+    # are not provided for dynamic axes, automatic names shall be generated for the
+    # provided dynamic axes of the specified input/output.
+    for key, value in dynamic_axes.items():
+        if key not in valid_names:
+            warnings.warn(
+                f"Provided key {key} for dynamic axes is not a valid input/output name"
+            )
+        if isinstance(value, list):
+            warnings.warn(
+                "No names were found for the specified dynamic axes of the provided input. "
+                f"Automatically generated names will be applied to each dynamic axis of input {key}."
+            )
+
+            value_dict = {}
+            for i, x in enumerate(value):
+                if not isinstance(x, int):
+                    raise ValueError(
+                        "The type of axis index is expected to be an integer"
+                    )
+                if x in value_dict:
+                    warnings.warn(
+                        f"Duplicate dynamic axis index {x} was provided for input {key}."
+                    )
+                else:
+                    value_dict[x] = str(key) + "_dynamic_axes_" + str(i + 1)
+            dynamic_axes[key] = value_dict
+
+
+def model_signature(model: torch.nn.Module | Callable) -> inspect.Signature:
+    return inspect.signature(
+        model.forward if isinstance(model, torch.nn.Module) else model
+    )
diff --git a/torch/onnx/_internal/torchscript_exporter/verification.py b/torch/onnx/_internal/torchscript_exporter/verification.py
new file mode 100644
index 000000000000..9cea8763b817
--- /dev/null
+++ b/torch/onnx/_internal/torchscript_exporter/verification.py
@@ -0,0 +1,1863 @@
+# mypy: allow-untyped-defs
+"""The ONNX verification module provides a set of tools to verify the correctness of ONNX models."""
+
+from __future__ import annotations
+
+from torch.onnx._internal.torchscript_exporter import _experimental
+
+
+__all__ = [
+    "OnnxBackend",
+    "VerificationOptions",
+    "verify",
+    "check_export_model_diff",
+    "GraphInfo",
+    "GraphInfoPrettyPrinter",
+    "OnnxTestCaseRepro",
+    "find_mismatch",
+    "verify_aten_graph",
+]
+
+import contextlib
+import copy
+import dataclasses
+import datetime
+import difflib
+import enum
+import functools
+import io
+import itertools
+import os
+import tempfile
+import typing_extensions
+import warnings
+from collections.abc import Collection, Mapping, Sequence
+from typing import Any, Callable, Union
+
+import numpy as np
+import numpy.typing as npt
+
+import torch
+import torch._C._onnx as _C_onnx
+from torch import _C
+from torch.onnx import _constants
+from torch.onnx._internal.torchscript_exporter import onnx_proto_utils, utils
+from torch.onnx._internal.torchscript_exporter._globals import GLOBALS
+from torch.types import Number
+
+
+# Everything below is deprecated ##############################################
+
+_ORT_PROVIDERS = ("CPUExecutionProvider",)
+
+_NumericType = 
Union[Number, torch.Tensor, np.ndarray] +_ModelType = Union[torch.nn.Module, torch.jit.ScriptModule] +_InputArgsType = Union[torch.Tensor, tuple[Any, ...]] +_InputKwargsType = Mapping[str, Any] +_OutputsType = Union[Sequence[_NumericType], Sequence] + + +class OnnxBackend(enum.Enum): + """Enum class for ONNX backend used for export verification. + + .. deprecated:: 2.7 + Consider using ``torch.onnx.export(..., dynamo=True)`` and use the returned + ``ONNXProgram`` to test the ONNX model. + """ + + REFERENCE = "ONNXReferenceEvaluator" + ONNX_RUNTIME_CPU = "CPUExecutionProvider" + ONNX_RUNTIME_CUDA = "CUDAExecutionProvider" + + +@dataclasses.dataclass +class VerificationOptions: + """Options for ONNX export verification. + + .. deprecated:: 2.7 + Consider using ``torch.onnx.export(..., dynamo=True)`` and use the returned + ``ONNXProgram`` to test the ONNX model. + + Attributes: + flatten: If True, unpack nested list/tuple/dict inputs into a flattened list of + Tensors for ONNX. Set this to False if nested structures are to be preserved + for ONNX, which is usually the case with exporting ScriptModules. Default True. + ignore_none: Whether to ignore None type in torch output, which is usually the + case with tracing. Set this to False, if torch output should keep None type, + which is usually the case with exporting ScriptModules. Default to True. + check_shape: Whether to check the shapes between PyTorch and ONNX Runtime outputs + are exactly the same. Set this to False to allow output shape broadcasting. + Default to True. + check_dtype: Whether to check the dtypes between PyTorch and ONNX Runtime outputs + are consistent. Default to True. + backend: ONNX backend for verification. Default to OnnxBackend.ONNX_RUNTIME_CPU. + rtol: relative tolerance in comparison between ONNX and PyTorch outputs. + atol: absolute tolerance in comparison between ONNX and PyTorch outputs. + remained_onnx_input_idx: If provided, only the specified inputs will be passed + to the ONNX model. Supply a list when there are unused inputs in the model. + Since unused inputs will be removed in the exported ONNX model, supplying + all inputs will cause an error on unexpected inputs. This parameter tells + the verifier which inputs to pass into the ONNX model. + acceptable_error_percentage: acceptable percentage of element mismatches in comparison. + It should be a float of value between 0.0 and 1.0. 
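+
+    Example (illustrative)::
+
+        options = VerificationOptions(rtol=1e-2, atol=1e-5, check_shape=False)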
+ """ + + flatten: bool = True + ignore_none: bool = True + check_shape: bool = True + check_dtype: bool = True + backend: OnnxBackend = OnnxBackend.ONNX_RUNTIME_CPU + rtol: float = 1e-3 + atol: float = 1e-7 + remained_onnx_input_idx: Sequence[int] | None = None + acceptable_error_percentage: float | None = None + + +def _flatten_tuples(elem): + flattened = [] + for t in elem: + if isinstance(t, tuple): + flattened.extend(_flatten_tuples(t)) + else: + flattened.append(t) + return flattened + + +# TODO(justinchuby): Add type checking by narrowing down the return type when input is None +def _to_numpy(elem) -> list | npt.NDArray: + if isinstance(elem, torch.Tensor): + if elem.requires_grad: + return elem.detach().cpu().numpy() + else: + return elem.cpu().numpy() + elif isinstance(elem, (list, tuple)): + return [_to_numpy(inp) for inp in elem] + elif isinstance(elem, (bool, int, float)): + return np.array(elem) + elif isinstance(elem, dict): + flattened = [] + for k in elem: + flattened.extend([_to_numpy(k), _to_numpy(elem[k])]) + return flattened + return elem + + +def _inline_flatten_list(inputs, res_list) -> list: + for i in inputs: + res_list.append(i) if not isinstance( + i, (list, tuple) + ) else _inline_flatten_list(i, res_list) + return res_list + + +def _unpack_to_numpy(values, cast_onnx_accepted=True) -> list: + value_unpacked = [] + for value in values: + value_unpacked.extend( + utils.unpack_quantized_tensor(value, cast_onnx_accepted=cast_onnx_accepted) + ) + return [_to_numpy(v) for v in value_unpacked] + + +def _run_onnx(onnx_session, inputs) -> _OutputsType: + kw_inputs = {} + if inputs and isinstance(inputs[-1], dict): + kw_inputs = inputs[-1] + inputs = inputs[:-1] + inputs = _unpack_to_numpy(_flatten_tuples(inputs)) + ort_inputs = {} + for input_name, input in kw_inputs.items(): + ort_inputs[input_name] = _to_numpy(input) + inputs = _to_numpy(inputs) + if hasattr(onnx_session, "get_inputs"): + # onnxruntime.InferenceSession + input_names = [i.name for i in onnx_session.get_inputs()] + elif hasattr(onnx_session, "input_names"): + # onnx.reference.ReferenceEvaluator + input_names = onnx_session.input_names + else: + raise ValueError(f"Unknown ONNX backend type: {type(onnx_session)}.") + + for i, input in enumerate(inputs): + if i == len(input_names) or input_names[i] in ort_inputs: + raise ValueError( + f"got too many positional inputs. inputs: {inputs}. kw_inputs: {kw_inputs}. " + f"input names: {input_names}." + ) + ort_inputs[input_names[i]] = input + onnx_outs = onnx_session.run(None, ort_inputs) + return onnx_outs + + +def _ort_session( + model: str | io.BytesIO, ort_providers: Sequence[str] = _ORT_PROVIDERS +): + try: + import onnxruntime # type: ignore[import] + except ImportError as e: + raise ImportError("onnxruntime is required for export verification.") from e + + if ort_providers is None: + ort_providers = _ORT_PROVIDERS + + session_options = onnxruntime.SessionOptions() + # suppress ort warnings. + # 0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2. 
+ session_options.log_severity_level = 3 + ort_session = onnxruntime.InferenceSession( + model if isinstance(model, str) else model.getvalue(), + session_options, + providers=ort_providers, + ) + return ort_session + + +def _onnx_reference_evaluator_session(model: str | io.BytesIO): + try: + import onnx + from onnx import reference as onnx_reference # type: ignore[attr-defined] + except ImportError as exc: + raise ImportError("onnx >= 1.13 is required for reference evaluator.") from exc + + proto = ( + onnx.load(model) # type: ignore[attr-defined] + if isinstance(model, str) + else onnx.load_model_from_string(model.getvalue()) # type: ignore[attr-defined] + ) + onnx_session = onnx_reference.ReferenceEvaluator(proto) + return onnx_session + + +def _onnx_backend_session(model: str | io.BytesIO, backend: OnnxBackend): + if backend == OnnxBackend.REFERENCE: + onnx_session = _onnx_reference_evaluator_session(model) + elif backend in {OnnxBackend.ONNX_RUNTIME_CPU, OnnxBackend.ONNX_RUNTIME_CUDA}: + onnx_session = _ort_session(model, (backend.value,)) + else: + raise ValueError(f"Unsupported backend: {backend}") + return onnx_session + + +def _compare_onnx_pytorch_outputs_in_np( + onnx_outs: _OutputsType, + pt_outs: _OutputsType, + options: VerificationOptions, +): + assert len(onnx_outs) == len(pt_outs), ( + f"Number of outputs differ ONNX runtime: ({len(onnx_outs)}) PyTorch: ({len(pt_outs)})" + ) + acceptable_error_percentage = options.acceptable_error_percentage + if acceptable_error_percentage and ( + acceptable_error_percentage > 1.0 or acceptable_error_percentage < 0.0 + ): + raise ValueError( + "If set, acceptable_error_percentage should be between 0.0 and 1.0" + ) + + for ort_out, pt_out in zip(onnx_outs, pt_outs): + try: + # TODO: Remove `check_shape` option once every shape inconsistent issue is addressed. + if not options.check_shape: + # Allow different but broadcastable output shapes. + ort_out, pt_out = np.broadcast_arrays(ort_out, pt_out) + torch.testing.assert_close( + ort_out, + pt_out, + rtol=options.rtol, + atol=options.atol, + check_dtype=options.check_dtype, + equal_nan=True, + ) + except AssertionError as e: + if acceptable_error_percentage: + error_percentage = 1 - np.sum( + np.isclose(ort_out, pt_out, rtol=options.rtol, atol=options.atol) + ) / np.prod(ort_out.shape) + if error_percentage <= acceptable_error_percentage: + warnings.warn( + f"Suppressed AssertionError:\n{e}.\n" + f"Error percentage {error_percentage} " + f"within acceptable range {acceptable_error_percentage}." + ) + continue + if ort_out.dtype == np.uint8 or ort_out.dtype == np.int8: + warnings.warn("ONNX output is quantized") + if pt_out.dtype == np.uint8 or pt_out.dtype == np.int8: + warnings.warn("PyTorch output is quantized") + raise + + +def _compare_onnx_pytorch_outputs( + onnx_outs: _OutputsType, + pt_outs: Any, + options: VerificationOptions, +): + """ + Compare ONNX and PyTorch outputs. + + Args: + onnx_outs: outputs from ONNX backend. + pt_outs: outputs from PyTorch. + options: options for verification. + + Raises: + AssertionError: if outputs from ONNX model and PyTorch model are not + equal up to specified precision. + ValueError: if arguments provided are invalid. 
+ """ + if options.ignore_none: + # torch.jit._flatten filters None type + pt_outs, _ = torch.jit._flatten(pt_outs) + else: + pt_outs = _inline_flatten_list([pt_outs], []) + pt_outs_np = _unpack_to_numpy(pt_outs, cast_onnx_accepted=False) + onnx_outs = _inline_flatten_list(onnx_outs, []) + _compare_onnx_pytorch_outputs_in_np(onnx_outs, pt_outs_np, options) + + +def _prepare_input_for_pytorch(args, kwargs): + """Prepare input for PyTorch model execution. + + Any future changes/formatting to the input before dispatching to the PyTorch + model should be made in this function. + + Args: + args: positional arguments for PyTorch model forward method. + kwargs: keyword arguments for PyTorch model forward method. + + Returns: + args: positional arguments for PyTorch model forward method. + kwargs: keyword arguments for PyTorch model forward method. + """ + if isinstance(args, (torch.Tensor, dict)): + args = (args,) + # In-place operators will update input tensor data as well. + # Thus inputs are replicated before every forward call. + args = copy.deepcopy(args) + if kwargs: + kwargs = copy.deepcopy(kwargs) + else: + kwargs = {} + return args, kwargs + + +def _prepare_input_for_export(args, kwargs): + """Prepare input for ONNX model export. + + Any future changes/formatting to the input before dispatching to the + :func:`torch.onnx.export` api should be made in this function. + + Args: + args: positional arguments for PyTorch model forward method. + kwargs: keyword arguments for PyTorch model forward method. + + Returns: + onnx_inputs: positional arguments for ONNX model export, as `args` in + :func:`torch.onnx.export`. + """ + args, kwargs = _prepare_input_for_pytorch(args, kwargs) + if not kwargs and len(args) > 0 and isinstance(args[-1], dict): + onnx_inputs = args + ({},) + elif kwargs: + onnx_inputs = args + (kwargs,) + else: + onnx_inputs = args + return onnx_inputs + + +def _prepare_input_for_onnx( + args, kwargs, remained_onnx_input_idx: Sequence[int] | None, flatten: bool +): + """Prepare input for ONNX model execution in ONNX backend. + + Any future changes/formatting to the input before dispatching to the ONNX backend + run should be made in this function. + + Args: + args: positional arguments for PyTorch model forward method. + kwargs: keyword arguments for PyTorch model forward method. + remained_onnx_input_idx: indices of inputs to be used for ONNX model execution. + flatten: whether to flatten the input before dispatching to the ONNX model execution. + + Returns: + onnx_inputs: positional arguments for ONNX model execution in ONNX backend. + """ + onnx_inputs = _prepare_input_for_export(args, kwargs) + if flatten: + onnx_inputs, _ = torch.jit._flatten(onnx_inputs) + elif onnx_inputs and onnx_inputs[-1] == {}: + # Handle empty kwargs (normally removed by flatten). + onnx_inputs = onnx_inputs[:-1] + if remained_onnx_input_idx is not None: + return [onnx_inputs[i] for i in remained_onnx_input_idx] + else: + return onnx_inputs + + +def _try_clone_model(model): + """Used for preserving original model in case forward mutates model states.""" + try: + return copy.deepcopy(model) + except Exception: + warnings.warn( + "Failed to clone model. Model state might be mutated during verification." 
+ ) + return model + + +def _compare_onnx_pytorch_model( + pt_model: _ModelType, + onnx_model_f: str | io.BytesIO, + input_args: _InputArgsType, + input_kwargs: _InputKwargsType | None, + additional_test_inputs: Sequence[_InputArgsType] | None, + options: VerificationOptions, +): + """Compare outputs from ONNX model runs with outputs from PyTorch model runs. + + Args: + pt_model: PyTorch model. + onnx_model_f: ONNX model file path or file-like object. + input_args: positional arguments for PyTorch model forward method. + input_kwargs: keyword arguments for PyTorch model forward method. + additional_test_inputs: additional positional arguments for PyTorch model + forward method. + options: options for verification. + + Raises: + AssertionError: if outputs from ONNX model and PyTorch model are not + equal up to specified precision. + """ + onnx_session = _onnx_backend_session(onnx_model_f, options.backend) + + def compare_onnx_pytorch_model_with_input(input_args, input_kwargs): + pt_args, pt_kwargs = _prepare_input_for_pytorch(input_args, input_kwargs) + # TODO: remove this and treat mutating model separately. See #77679 + pt_model_copy = _try_clone_model(pt_model) + pt_outs = pt_model_copy(*pt_args, **pt_kwargs) + + onnx_inputs = _prepare_input_for_onnx( + input_args, input_kwargs, options.remained_onnx_input_idx, options.flatten + ) + + onnx_outs = _run_onnx(onnx_session, onnx_inputs) + + _compare_onnx_pytorch_outputs( + onnx_outs=onnx_outs, + pt_outs=pt_outs, + options=options, + ) + + compare_onnx_pytorch_model_with_input(input_args, input_kwargs) + + if additional_test_inputs: + for test_input_args in additional_test_inputs: + compare_onnx_pytorch_model_with_input(test_input_args, {}) + + +class _GraphDiff: + """A class to represent the difference between two graphs.""" + + def __init__(self, graph_a: _C.Graph, graph_b: _C.Graph): + """Construct a _GraphDiff object. + + Args: + graph_a (_C.Graph): First graph to compare. + graph_b (_C.Graph): Second graph to compare. + """ + self.graph_a = graph_a + self.graph_b = graph_b + + def __str__(self): + """See function :func:`diff_report`.""" + return self.diff_report() + + def _indent(self, lines: str) -> str: + return "\n".join(["\t" + line for line in lines.splitlines()]) + + def diff_report(self) -> str: + """Return a string representation of the graph difference. + + The report shows the first pair of nodes that diverges. It also shows the source + location of the pair of nodes. + + Returns: + graph_diff_report (str): A string representation of the graph difference. 
+ """ + graph_a = self.graph_a + graph_b = self.graph_b + + graph_a_str = str(graph_a) + graph_b_str = str(graph_b) + + if graph_a_str == graph_b_str: + return "" + + graph_diff = difflib.ndiff( + graph_a_str.splitlines(True), graph_b_str.splitlines(True) + ) + graph_diff_report = ["Graph diff:", self._indent("".join(graph_diff))] + + for node_a, node_b in itertools.zip_longest(graph_a.nodes(), graph_b.nodes()): + if str(node_a) != str(node_b): + graph_diff_report.append("First diverging operator:") + node_diff = difflib.ndiff( + str(node_a).splitlines(True), str(node_b).splitlines(True) + ) + source_printout = ["node diff:", self._indent("".join(node_diff))] + + stack_a = node_a.sourceRange() if node_a else None + if stack_a: + source_printout.extend( + ["Former source location:", self._indent(str(stack_a))] + ) + stack_b = node_b.sourceRange() if node_b else None + if stack_b: + source_printout.extend( + ["Latter source location:", self._indent(str(stack_b))] + ) + + graph_diff_report.extend(source_printout) + + break + + return "\n".join(graph_diff_report) + + +def _check_graph_diff( + model: torch.nn.Module | torch.jit.ScriptModule, + test_input_groups: Sequence[tuple[tuple[Any, ...], Mapping[str, Any]]], + export_options: _experimental.ExportOptions, + model_to_graph_func: Callable[ + [ + torch.nn.Module, + tuple[Any, ...], + Mapping[str, Any], + _experimental.ExportOptions, + ], + _C.Graph, + ], +) -> str: + """Check if graph produced by `model_to_graph_func` is the same across `test_input_groups`. + + Args: + model: See :func:`check_export_model_diff`. + test_input_groups: See :func:`check_export_model_diff`. + export_options: See :func:`check_export_model_diff`. + model_to_graph_func: A function to convert a PyTorch model to a JIT IR graph. + + Returns: + graph_diff_report (str): A string representation of the graph difference. + """ + if len(test_input_groups) < 2: + raise ValueError("Need at least two groups of test inputs to compare.") + + ref_jit_graph = None + for args, kwargs in test_input_groups: + jit_graph = model_to_graph_func(model, args, kwargs, export_options) + if ref_jit_graph is None: + ref_jit_graph = jit_graph + continue + + graph_diff_report = _GraphDiff(ref_jit_graph, jit_graph).diff_report() + if graph_diff_report: + return graph_diff_report + return "" + + +def _traced_graph_from_model( + model: torch.nn.Module | torch.jit.ScriptModule, + args: tuple[Any, ...], + kwargs: Mapping[str, Any], + export_options: _experimental.ExportOptions, +) -> _C.Graph: + """As part of the ONNX export steps, create a traced JIT graph from a PyTorch model. + + Args: + model: See :func:`check_export_model_diff`. + args: See :func:`check_export_model_diff`. + kwargs: See :func:`check_export_model_diff`. + export_options: See :func:`check_export_model_diff`. + + Returns: + jit_graph (_C.Graph): A traced JIT graph. + """ + training = export_options.training + verbose = export_options.verbose + + with utils.exporter_context(model, training, verbose): + export_inputs = _prepare_input_for_export(args, kwargs) + model = utils._pre_trace_quant_model(model, export_inputs) + jit_graph, _, _, _ = utils._create_jit_graph(model, export_inputs) + return jit_graph + + +def _onnx_graph_from_model( + model: torch.nn.Module | torch.jit.ScriptModule, + args: tuple[Any, ...], + kwargs: Mapping[str, Any], + export_options: _experimental.ExportOptions, +) -> _C.Graph: + """As part of the ONNX export steps, export an ONNX JIT graph from a PyTorch model. 
+ + Args: + model: See :func:`check_export_model_diff`. + args: See :func:`check_export_model_diff`. + kwargs: See :func:`check_export_model_diff`. + export_options: See :func:`check_export_model_diff`. + + Returns: + onnx_graph (_C.Graph): An ONNX JIT graph. + """ + # TODO: refactor utils.py to remove duplicated code of context setup. See #78834 + opset_version = export_options.opset_version + operator_export_type = export_options.operator_export_type + export_modules_as_functions = export_options.export_modules_as_functions + training = export_options.training + verbose = export_options.verbose + dynamic_axes = export_options.dynamic_axes + input_names = export_options.input_names + output_names = export_options.output_names + + if opset_version is None: + opset_version = _constants.ONNX_DEFAULT_OPSET + + utils._setup_trace_module_map(model, export_modules_as_functions) + + if not operator_export_type: + operator_export_type = _C_onnx.OperatorExportTypes.ONNX + + GLOBALS.export_onnx_opset_version = opset_version + GLOBALS.operator_export_type = operator_export_type + + with utils.exporter_context(model, training, verbose): + do_constant_folding = utils._decide_constant_folding( + export_options.do_constant_folding, operator_export_type, training + ) + + if dynamic_axes is None: + dynamic_axes = {} + utils._validate_dynamic_axes(dynamic_axes, model, input_names, output_names) + + export_inputs = _prepare_input_for_export(args, kwargs) + export_inputs = utils._decide_input_format(model, export_inputs) + onnx_graph, _, _ = utils._model_to_graph( + model, + export_inputs, + verbose, + input_names, + output_names, + operator_export_type, + do_constant_folding, + training=training, + dynamic_axes=dynamic_axes, + ) + + return onnx_graph + + +def _onnx_graph_from_aten_graph( + graph: torch.Graph, + export_options: _experimental.ExportOptions, + params_dict: dict[str, Any] | None = None, +) -> tuple[torch.Graph, dict[str, Any]]: + if params_dict is None: + params_dict = {} + operator_export_type = export_options.operator_export_type + dynamic_axes = export_options.dynamic_axes or {} + input_names = export_options.input_names + training = export_options.training + do_constant_folding = export_options.do_constant_folding + opset_version = export_options.opset_version or _constants.ONNX_DEFAULT_OPSET + + GLOBALS.export_onnx_opset_version = opset_version + GLOBALS.operator_export_type = operator_export_type + + do_constant_folding = utils._decide_constant_folding( + do_constant_folding, operator_export_type, training + ) + + # TODO: Below is doing aten graph to onnx. It should be abstracted as a + # function in torch/onnx/utils.py. + graph = graph.copy() + graph = utils._optimize_graph( + graph, + operator_export_type, + params_dict=params_dict, + dynamic_axes=dynamic_axes, + input_names=input_names, + ) + + if training is None or training == _C_onnx.TrainingMode.EVAL: + params_dict = torch._C._jit_pass_onnx_eval_peephole(graph, params_dict) + + if ( + do_constant_folding + and opset_version >= _constants.ONNX_CONSTANT_FOLDING_MIN_OPSET + ): + params_dict = _C._jit_pass_onnx_constant_fold(graph, params_dict, opset_version) + _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph) + + if GLOBALS.onnx_shape_inference: + _C._jit_pass_onnx_graph_shape_type_inference(graph, params_dict, opset_version) + + params_dict = _C._jit_pass_onnx_eliminate_unused_items(graph, params_dict) + + # For ONNX opset < 9, constants only have three data types: float16, float, double. 
+    # In this pass, transform constants of other data types to float/double + cast operator.
+    if opset_version < 9:
+        _C._jit_pass_onnx_cast_all_constant_to_floating(graph)
+
+    params_dict = _C._jit_pass_filter_non_tensor_arguments(params_dict)
+    _C._jit_decay_packed_param_input_types(graph)
+
+    _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
+
+    if export_options.verbose:
+        print("ONNX graph: ", graph)
+
+    return graph, params_dict
+
+
+def _onnx_proto_from_onnx_graph(
+    onnx_graph: torch.Graph,
+    export_options: _experimental.ExportOptions,
+    params_dict: dict[str, Any],
+) -> tuple[bytes, Mapping[str, bytes]]:
+    opset_version = export_options.opset_version or _constants.ONNX_DEFAULT_OPSET
+    dynamic_axes = export_options.dynamic_axes or {}
+    operator_export_type = export_options.operator_export_type
+    val_keep_init_as_ip = utils._decide_keep_init_as_input(
+        export_options.keep_initializers_as_inputs,
+        operator_export_type,
+        opset_version,
+    )
+    val_add_node_names = utils._decide_add_node_names(True, operator_export_type)
+    custom_opsets = export_options.custom_opsets or {}
+
+    proto, export_map, _, _ = onnx_graph._export_onnx(  # type: ignore[attr-defined]
+        params_dict,
+        opset_version,
+        dynamic_axes,
+        False,
+        operator_export_type,
+        not export_options.verbose,
+        val_keep_init_as_ip,
+        custom_opsets,
+        val_add_node_names,
+        "",
+        {},
+    )
+
+    return proto, export_map
+
+
+def check_export_model_diff(
+    model: torch.nn.Module | torch.jit.ScriptModule,
+    test_input_groups: Sequence[tuple[tuple[Any, ...], Mapping[str, Any]]],
+    export_options: _experimental.ExportOptions | None = None,
+) -> str:
+    """Verify exported model discrepancy between different groups of inputs.
+
+    A graph is exported for each group of inputs. The exported graphs are then compared
+    to each other, and the discrepancy at the first pair of diverging nodes is reported.
+    This function first checks the JIT graph. If no discrepancy is found, it then checks
+    the ONNX graph.
+
+    Unless otherwise specified, the JIT/ONNX graph is expected to be the same regardless
+    of the inputs used for exporting. A discrepancy implies that the exported graph is
+    not accurate for other groups of inputs, which typically results in runtime errors
+    or mismatched outputs.
+
+    Args:
+        model (torch.nn.Module or torch.jit.ScriptModule): The model to be exported.
+        test_input_groups (Sequence[Tuple[Tuple[Any, ...], Mapping[str, Any]]]): A sequence
+            of input groups to be used to export the model. Each input group is a pair of
+            (args, kwargs).
+        export_options (_experimental.ExportOptions, optional): An _experimental.ExportOptions
+            object that controls the export behavior.
+
+    Returns:
+        str: A string containing the diff of the exported models.
+    """
+    export_options = (
+        _experimental.ExportOptions() if export_options is None else export_options
+    )
+
+    jit_diff_report = _check_graph_diff(
+        model, test_input_groups, export_options, _traced_graph_from_model
+    )
+    if jit_diff_report:
+        return jit_diff_report
+
+    return _check_graph_diff(
+        model, test_input_groups, export_options, _onnx_graph_from_model
+    )
+
+
+@typing_extensions.deprecated(
+    "torch.onnx.verification.* is deprecated. 
Consider using torch.onnx.export(..., dynamo=True) " + "and use ONNXProgram to test the ONNX model", + category=None, +) +def verify( + model: _ModelType, + input_args: _InputArgsType, + input_kwargs: _InputKwargsType | None = None, + do_constant_folding: bool = True, + dynamic_axes: Mapping[str, Mapping[int, str] | Mapping[str, Sequence[int]]] + | None = None, + input_names: Sequence[str] | None = None, + output_names: Sequence[str] | None = None, + training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL, + opset_version: int | None = None, + keep_initializers_as_inputs: bool = True, + verbose: bool = False, + fixed_batch_size: bool = False, + use_external_data: bool = False, + additional_test_inputs: Sequence[_InputArgsType] | None = None, + options: VerificationOptions | None = None, +): + """Verify model export to ONNX against original PyTorch model. + + .. deprecated:: 2.7 + Consider using ``torch.onnx.export(..., dynamo=True)`` and use the returned + ``ONNXProgram`` to test the ONNX model. + + Args: + model: See :func:`torch.onnx.export`. + input_args: See :func:`torch.onnx.export`. + input_kwargs: See :func:`torch.onnx.export`. + do_constant_folding: See :func:`torch.onnx.export`. + dynamic_axes: See :func:`torch.onnx.export`. + input_names: See :func:`torch.onnx.export`. + output_names: See :func:`torch.onnx.export`. + training: See :func:`torch.onnx.export`. + opset_version: See :func:`torch.onnx.export`. + keep_initializers_as_inputs: See :func:`torch.onnx.export`. + verbose: See :func:`torch.onnx.export`. + fixed_batch_size: Legacy argument, used only by rnn test cases. + use_external_data: Explicitly specify whether to export the model with external data. + additional_test_inputs: List of tuples. Each tuple is a group of + input arguments to test. Currently only ``*args`` are supported. + options: A VerificationOptions object that controls the verification behavior. + + Raises: + AssertionError: if outputs from ONNX model and PyTorch model are not + equal up to specified precision. + ValueError: if arguments provided are invalid. + """ + if options is None: + options = VerificationOptions() + + if training == torch.onnx.TrainingMode.TRAINING: + model.train() + elif training == torch.onnx.TrainingMode.EVAL: + model.eval() + with torch.no_grad(), contextlib.ExitStack() as stack: + model_f: str | io.BytesIO = io.BytesIO() + if use_external_data: + tmpdir_path = stack.enter_context(tempfile.TemporaryDirectory()) + model_f = os.path.join(tmpdir_path, "model.onnx") + + inputs_for_export = _prepare_input_for_export(input_args, input_kwargs) + + # TODO(#77679): remove this and treat mutating model separately. + model_copy = _try_clone_model(model) + utils._export( + model, + inputs_for_export, + model_f, + opset_version=opset_version, + do_constant_folding=do_constant_folding, + keep_initializers_as_inputs=keep_initializers_as_inputs, + dynamic_axes=dynamic_axes, + input_names=input_names, + output_names=output_names, + fixed_batch_size=fixed_batch_size, + training=training, + verbose=verbose, + ) + + _compare_onnx_pytorch_model( + pt_model=model_copy, + onnx_model_f=model_f, + input_args=input_args, + input_kwargs=input_kwargs, + additional_test_inputs=additional_test_inputs, + options=options, + ) + + +@typing_extensions.deprecated( + "torch.onnx.verification.* is deprecated. 
Consider using torch.onnx.export(..., dynamo=True) " + "and use ONNXProgram to test the ONNX model" +) +def verify_aten_graph( + graph: torch.Graph, + input_args: tuple[Any, ...], + export_options: _experimental.ExportOptions, + params_dict: dict[str, Any] | None = None, + verification_options: VerificationOptions | None = None, +) -> tuple[AssertionError | None, torch.Graph, _OutputsType, _OutputsType]: + """Verify aten graph export to ONNX against original PyTorch model. + + .. deprecated:: 2.7 + Consider using ``torch.onnx.export(..., dynamo=True)`` and use the returned + ``ONNXProgram`` to test the ONNX model. + """ + if verification_options is None: + verification_options = VerificationOptions() + if params_dict is None: + params_dict = {} + + original_jit_graph = graph + graph = graph.copy() + + # Execute aten graph and get reference torch jit outputs. + graph_inputs = list(graph.inputs()) + jit_inputs = tuple([arg for arg in input_args if arg is not None]) + weights = [params_dict[v.debugName()] for v in graph_inputs[len(jit_inputs) :]] + assert all(w is not None for w in weights) + # TODO: Only copy the argument if mutation is detected in Graph. + jit_inputs = copy.deepcopy(jit_inputs) + jit_input_and_parameters = jit_inputs + tuple(weights) + jit_outs = torch._C._jit_interpret_graph(graph, jit_input_and_parameters) # type: ignore[attr-defined] + if not isinstance(jit_outs, (list, tuple)): + jit_outs = [jit_outs] + + # Convert aten graph to onnx graph. + graph, onnx_params_dict = _onnx_graph_from_aten_graph( + graph, export_options, params_dict + ) + + proto, export_map = _onnx_proto_from_onnx_graph( + graph, export_options, onnx_params_dict + ) + model_f: str | io.BytesIO = io.BytesIO() + onnx_proto_utils._export_file(proto, model_f, export_map) + + # NOTE: Verification is unstable. Try catch to emit information for debugging. + try: + # NOTE: Input might be dce'ed, so we need to remove those from the input args. 
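+        # (Inputs eliminated as dead code during the ONNX conversion no longer appear
+        # in `graph.inputs()`; drop the corresponding runtime arguments so that the
+        # remaining positions still line up.)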
+ new_input_names = {v.debugName() for v in graph.inputs()} + new_input_args = [] + for v, arg in zip(original_jit_graph.inputs(), input_args): + if v.debugName() in new_input_names: + new_input_args.append(arg) + input_args = tuple(new_input_args) + + onnx_inputs = _prepare_input_for_onnx( + input_args, + {}, + verification_options.remained_onnx_input_idx, + verification_options.flatten, + ) + + onnx_session = _onnx_backend_session(model_f, verification_options.backend) + onnx_outs = _run_onnx(onnx_session, onnx_inputs) + del onnx_session # To free device memory + + try: + _compare_onnx_pytorch_outputs( + onnx_outs=onnx_outs, + pt_outs=jit_outs, + options=verification_options, + ) + except AssertionError as e: + return e, graph, jit_outs, onnx_outs + + return None, graph, jit_outs, onnx_outs + + except Exception as e: + print("Unexpected error during verification.") + print("jit graph: ", original_jit_graph) + print("onnx graph: ", graph) + raise e + + +class GraphInfoPrettyPrinter: + graph_info: GraphInfo | None + upper_printer: GraphInfoPrettyPrinter | None + lower_printer: GraphInfoPrettyPrinter | None + + graph_str_lambdas: Mapping[int, str] + connector_str_lambdas: Mapping[int, str] + children_str_lambdas: Mapping[int, str] + + def __init__(self, graph_info: GraphInfo | None): + self.graph_info = graph_info + if ( + graph_info is not None + and graph_info.upper_graph_info is not None + and graph_info.lower_graph_info is not None + ): + self.upper_printer = GraphInfoPrettyPrinter(graph_info.upper_graph_info) + self.lower_printer = GraphInfoPrettyPrinter(graph_info.lower_graph_info) + else: + self.upper_printer = None + self.lower_printer = None + + def _total_rows(self) -> int: + if self.graph_info is None: + return 1 + if self.upper_printer and self.lower_printer: + return ( + self.upper_printer._total_rows() + self.lower_printer._total_rows() + 1 + ) + return 2 # Two lines: node count + id. + + def _node_count_segment_str(self) -> str: + if self.graph_info is None: + return "..." 
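+        # Segment format: "<essential node count> <check mark or X> (<op kind>)", where
+        # the op kind is only shown for mismatching subgraphs with a single essential node.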
+ node_count = self.graph_info.essential_node_count() + has_mismatch = self.graph_info.has_mismatch() + error_node_kind = ( + f"({self.graph_info.essential_node_kinds().pop()})" + if node_count == 1 and has_mismatch + else "" + ) + + return f"{node_count} {'X' if has_mismatch else chr(0x2713)} {error_node_kind}" + + def _graph_id_segment_str(self) -> str: + if self.graph_info is None: + return "" + return f"id: {self.graph_info.id}" + + def _max_segment_columns(self) -> int: + return max( + map(len, (self._node_count_segment_str(), self._graph_id_segment_str())) + ) + + def _graph_segment_str_at_line(self, line: int) -> str: + """Get the string representation of the graph segment at the given line.""" + if line == 0: + result_str = self._node_count_segment_str() + result_str += " " * (self._max_segment_columns() - len(result_str)) + return result_str + if line == 1: + result_str = self._graph_id_segment_str() + result_str += " " * (self._max_segment_columns() - len(result_str)) + return result_str + if 0 <= line < self._total_rows(): + return " " * self._max_segment_columns() + return "" + + def _connector_segment_str_at_line(self, line: int) -> str: + """Get the connector segment string at the given line.""" + if self.upper_printer is None and self.lower_printer is None: + return "" + upper_total_rows = self.upper_printer._total_rows() if self.upper_printer else 1 + lower_total_rows = self.lower_printer._total_rows() if self.lower_printer else 1 + if line == 0: + return " __" + elif line < upper_total_rows + 1: + return " | " + elif line == upper_total_rows + 1: + return " |__" + elif line < upper_total_rows + lower_total_rows + 1: + return " " + return "" + + def _children_str_at_line(self, line: int) -> str: + """Get the string representation of the children at the given line. + + Recursively calls `_str_at_line` on children nodes. + """ + if self.upper_printer is None and self.lower_printer is None: + return "" + upper_total_rows = self.upper_printer._total_rows() if self.upper_printer else 1 + lower_total_rows = self.lower_printer._total_rows() if self.lower_printer else 1 + if 0 <= line < upper_total_rows: + return ( + self.upper_printer._str_at_line(line) if self.upper_printer else "..." + ) + elif upper_total_rows < line < upper_total_rows + lower_total_rows + 1: + return ( + self.lower_printer._str_at_line(line - upper_total_rows - 1) + if self.lower_printer + else "..." + ) + return "" + + def _str_at_line(self, line: int) -> str: + """Get the string representation of the graph at the given line.""" + return ( + self._graph_segment_str_at_line(line) + + self._connector_segment_str_at_line(line) + + self._children_str_at_line(line) + ) + + def pretty_print(self): + if self.graph_info is None: + print(None) + return + # Print tree. + print(" Tree: ".center(80, "=")) + total_rows = self._total_rows() + for line in range(total_rows): + print(self._str_at_line(line).rstrip()) + if self.graph_info.has_mismatch(): + # Summarize leaf subgraphs with mismatch. + print(" Mismatch leaf subgraphs: ".center(80, "=")) + print( + [ + graph_info.id + for graph_info in self.graph_info.all_mismatch_leaf_graph_info() + ] + ) + # Summarize node kinds with mismatch. 
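+            # Count, per op kind, how many mismatching leaf subgraphs consist of a
+            # single essential node of that kind.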
+ mismatch_node_kinds: dict[str, int] = {} + for graph_info in self.graph_info.all_mismatch_leaf_graph_info(): + node_kinds = graph_info.essential_node_kinds() + if len(node_kinds) == 1: + node_kind = node_kinds.pop() + mismatch_node_kinds[node_kind] = ( + mismatch_node_kinds.get(node_kind, 0) + 1 + ) + print(" Mismatch node kinds: ".center(80, "=")) + print(mismatch_node_kinds) + else: + print(" No mismatch found. ".center(80, "=")) + + +class OnnxTestCaseRepro: + def __init__(self, repro_dir): + self.repro_dir = repro_dir + self.proto, self.inputs, self.outputs = onnx_proto_utils.load_test_case( + repro_dir + ) + + @classmethod + def create_test_case_repro( + cls, proto: bytes, inputs, outputs, dir: str, name: str | None = None + ): + """Create a repro under "{dir}/test_{name}" for an ONNX test case. + + The test case contains the model and the inputs/outputs data. The directory + structure is as follows: + + dir + \u251c\u2500\u2500 test_ + \u2502 \u251c\u2500\u2500 model.onnx + \u2502 \u2514\u2500\u2500 test_data_set_0 + \u2502 \u251c\u2500\u2500 input_0.pb + \u2502 \u251c\u2500\u2500 input_1.pb + \u2502 \u251c\u2500\u2500 output_0.pb + \u2502 \u2514\u2500\u2500 output_1.pb + + Args: + proto: ONNX model proto. + inputs: Inputs to the model. + outputs: Outputs of the model. + dir: Directory to save the repro. + name: Name of the test case. If not specified, a name based on current time + will be generated. + Returns: + Path to the repro. + """ + if name is None: + name = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f") + return onnx_proto_utils.export_as_test_case( + proto, + _to_numpy(inputs), + _to_numpy(outputs), + name, + dir, + ) + + def validate(self, options: VerificationOptions): + """Run the ONNX test case with options.backend, and compare with the expected outputs. + + Args: + options: Options for validation. + + Raise: + AssertionError: if outputs from options.backend and expected outputs are not + equal up to specified precision. + """ + onnx_session = _onnx_backend_session(io.BytesIO(self.proto), options.backend) + run_outputs = onnx_session.run(None, self.inputs) + if hasattr(onnx_session, "get_outputs"): + output_names = [o.name for o in onnx_session.get_outputs()] + elif hasattr(onnx_session, "output_names"): + output_names = onnx_session.output_names + else: + raise ValueError(f"Unknown onnx session type: {type(onnx_session)}") + expected_outs = [self.outputs[name] for name in output_names] + _compare_onnx_pytorch_outputs_in_np(run_outputs, expected_outs, options) + + +@typing_extensions.deprecated( + "torch.onnx.verification.* is deprecated. Consider using torch.onnx.export(..., dynamo=True) " + "and use ONNXProgram to test the ONNX model" +) +@dataclasses.dataclass +class GraphInfo: + """GraphInfo contains validation information of a TorchScript graph and its converted ONNX graph. + + .. deprecated:: 2.7 + Consider using ``torch.onnx.export(..., dynamo=True)`` and use the returned + ``ONNXProgram`` to test the ONNX model. + """ + + graph: torch.Graph + input_args: tuple[Any, ...] 
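+    # Model parameters (weights and buffers) keyed by graph input debug name.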
+ params_dict: dict[str, Any] + export_options: _experimental.ExportOptions = dataclasses.field( + default_factory=_experimental.ExportOptions + ) + mismatch_error: AssertionError | None = dataclasses.field(default=None, init=False) + pt_outs: Sequence[_NumericType] | None = dataclasses.field(default=None, init=False) + upper_graph_info: GraphInfo | None = dataclasses.field(default=None, init=False) + lower_graph_info: GraphInfo | None = dataclasses.field(default=None, init=False) + id: str = dataclasses.field(default="") + _onnx_graph: torch.Graph | None = dataclasses.field(init=False, default=None) + + _EXCLUDED_NODE_KINDS: frozenset[str] = frozenset( + {"prim::Constant", "prim::ListConstruct", "aten::ScalarImplicit"} + ) + + def clear(self): + """Clear states and results of previous verification.""" + self.mismatch_error = None + self.pt_outs = None + self._onnx_graph = None + self.upper_graph_info = None + self.lower_graph_info = None + + def pretty_print_tree(self): + """Pretty print `GraphInfo` tree. + + Each node represents a subgraph, showing the number of nodes in the subgraph and + a check mark if the subgraph has output mismatch between torch and ONNX. + + The id of the subgraph is shown under the node. The `GraphInfo` object for any + subgraph can be retrieved by calling `graph_info.find_partition(id)`. + + Example:: + + ==================================== Tree: ===================================== + 5 X __2 X __1 \u2713 + id: | id: 0 | id: 00 + | | + | |__1 X (aten::relu) + | id: 01 + | + |__3 X __1 \u2713 + id: 1 | id: 10 + | + |__2 X __1 X (aten::relu) + id: 11 | id: 110 + | + |__1 \u2713 + id: 111 + =========================== Mismatch leaf subgraphs: =========================== + ['01', '110'] + ============================= Mismatch node kinds: ============================= + {'aten::relu': 2} + + """ + GraphInfoPrettyPrinter(self).pretty_print() + + def pretty_print_mismatch(self, graph: bool = False): + """Pretty print details of the mismatch between torch and ONNX. + + Args: + graph: If True, print the ATen JIT graph and ONNX graph. + """ + print(f" Mismatch info for graph partition {self.id}: ".center(80, "=")) + if graph: + print(" ATen JIT graph ".center(80, "=")) + # TODO: A more compact graph printer. + # * Drop stride, grad, device information. + # * Show source location on a separate line. 
+ print(self.graph) + if self._onnx_graph is not None: + print(" ONNX graph ".center(80, "=")) + print(self._onnx_graph) + if self.has_mismatch(): + print(" Mismatch error ".center(80, "=")) + print(self.mismatch_error) + else: + print(" No mismatch ".center(80, "=")) + + def has_mismatch(self) -> bool: + """Return True if the subgraph has output mismatch between torch and ONNX.""" + return self.mismatch_error is not None + + def essential_node_count(self) -> int: + """Return the number of nodes in the subgraph excluding those in `_EXCLUDED_NODE_KINDS`.""" + return sum( + 1 for n in self.graph.nodes() if n.kind() not in self._EXCLUDED_NODE_KINDS + ) + + def essential_node_kinds(self) -> set[str]: + """Return the set of node kinds in the subgraph excluding those in `_EXCLUDED_NODE_KINDS`.""" + return { + n.kind() + for n in self.graph.nodes() + if n.kind() not in self._EXCLUDED_NODE_KINDS + } + + def all_mismatch_leaf_graph_info(self) -> list[GraphInfo]: + """Return a list of all leaf `GraphInfo` objects that have mismatch.""" + if not self.has_mismatch(): + return [] + + no_mismatch_children = ( + self.upper_graph_info is None or not self.upper_graph_info.has_mismatch() + ) and ( + self.lower_graph_info is None or not self.lower_graph_info.has_mismatch() + ) + + if no_mismatch_children: + return [self] + + results = [] + if self.upper_graph_info is not None: + results += self.upper_graph_info.all_mismatch_leaf_graph_info() + if self.lower_graph_info is not None: + results += self.lower_graph_info.all_mismatch_leaf_graph_info() + + return results + + def find_partition(self, id: str) -> GraphInfo | None: + """Find the `GraphInfo` object with the given id.""" + if id == self.id: + return self + current_length = len(self.id) + if len(id) > current_length: + if id[current_length] == "0" and self.upper_graph_info is not None: + return self.upper_graph_info.find_partition(id) + elif id[current_length] == "1" and self.lower_graph_info is not None: + return self.lower_graph_info.find_partition(id) + return None + + def export_repro( + self, repro_dir: str | None = None, name: str | None = None + ) -> str: + """Export the subgraph to ONNX along with the input/output data for repro. + + The repro directory will contain the following files:: + + dir + \u251c\u2500\u2500 test_ + \u2502 \u251c\u2500\u2500 model.onnx + \u2502 \u2514\u2500\u2500 test_data_set_0 + \u2502 \u251c\u2500\u2500 input_0.pb + \u2502 \u251c\u2500\u2500 input_1.pb + \u2502 \u251c\u2500\u2500 output_0.pb + \u2502 \u2514\u2500\u2500 output_1.pb + + Args: + repro_dir: The directory to export the repro files to. Defaults to current + working directory if None. + name: An optional name for the test case folder: "test_{name}". + + Returns: + The path to the exported repro directory. + """ + + if repro_dir is None: + repro_dir = os.getcwd() + repro_dir = os.path.join(repro_dir, "onnx_debug") + + onnx_graph, onnx_params_dict = _onnx_graph_from_aten_graph( + self.graph, self.export_options, self.params_dict + ) + + proto, _ = _onnx_proto_from_onnx_graph( + onnx_graph, self.export_options, onnx_params_dict + ) + return OnnxTestCaseRepro.create_test_case_repro( + proto, self.input_args, self.pt_outs, repro_dir, name + ) + + def _graph_partition_pivot(self) -> int: + """Find the pivot index to partition the graph. + + The pivot is the node that splits the graph into two parts. Each part should + have the similar amount of nodes, excluding non essential ops, defined in + `_EXCLUDED_NODE_KINDS`, such as `prim::Constant`. 
+ If the graph has an odd number of nodes, the upper part will have one more node. + If the graph does not have any node that can be partitioned, return -1. + + Returns: + The index of the pivot node. + """ + included_node_indices = [ + i + for i, n in enumerate(self.graph.nodes()) + if n.kind() not in self._EXCLUDED_NODE_KINDS + ] + half_idx = len(included_node_indices) // 2 - 1 + if half_idx >= 0 and len(included_node_indices) > half_idx: + return included_node_indices[half_idx] + 1 + return -1 + + def _partition_upper_graph(self) -> torch.Graph: + pivot = self._graph_partition_pivot() + if pivot == -1: + return torch.Graph() + graph = self.graph.copy() # Copy to not mutate parent graph. + original_outputs = list(graph.outputs()) + + def _process_bridge_value_for_upper( + new_outputs: list[torch.Value], bridge_value: torch.Value + ) -> torch.Value: + # Add bridge values as upper graph outputs. + new_outputs.append(bridge_value) + return bridge_value + + new_outputs: list[torch.Value] = [] + process_bridge_value_for_upper = functools.partial( + _process_bridge_value_for_upper, new_outputs + ) + _, dropped_nodes, complete_upper_nodes_set, _ = self._partition_nodes( + graph, pivot, process_bridge_value_for_upper + ) + + for _ in enumerate(original_outputs): + graph.eraseOutput(0) + for output in new_outputs: + graph.registerOutput(output) + + for node in reversed(dropped_nodes): + node.destroy() + + for i, input in reversed(list(enumerate(list(graph.inputs())))): + if ( + not _has_uses_by_nodes(input, complete_upper_nodes_set) + and input not in new_outputs + ): + try: + graph.eraseInput(i) + except RuntimeError as e: + print(input, graph) + raise e + + return graph + + def _partition_lower_graph(self) -> torch.Graph: + pivot = self._graph_partition_pivot() + if pivot == -1: + return torch.Graph() + graph = self.graph.copy() # Copy to not mutate parent graph. + original_outputs = list(graph.outputs()) + original_inputs = list(graph.inputs()) + + def _process_bridge_value_for_lower( + graph: torch.Graph, bridge_value: torch.Value + ) -> torch.Value: + # Add bridge values as lower graph inputs. 
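+            # Each bridge value becomes a fresh graph input; all of its uses are
+            # redirected to the new input, which copies over the value's metadata.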
+            new_input = graph.addInput()
+            bridge_value.replaceAllUsesWith(new_input)
+            new_input.copyMetadata(bridge_value)
+            return new_input
+
+        process_bridge_value_for_lower = functools.partial(
+            _process_bridge_value_for_lower, graph
+        )
+
+        upper_nodes, lower_nodes, _, complete_lower_nodes_set = self._partition_nodes(
+            graph, pivot, process_bridge_value_for_lower
+        )
+
+        new_outputs = [
+            output for output in original_outputs if _produced_by(output, lower_nodes)
+        ]
+        for _ in enumerate(original_outputs):
+            graph.eraseOutput(0)
+        for output in new_outputs:
+            graph.registerOutput(output)
+
+        for input in original_inputs:
+            if _has_uses_by_nodes(input, complete_lower_nodes_set):
+                new_input = graph.addInput()
+                input.replaceAllUsesWith(new_input)
+                new_input.copyMetadata(input)
+
+        for node in reversed(upper_nodes):
+            if node not in complete_lower_nodes_set:
+                try:
+                    node.destroy()
+                except RuntimeError as e:
+                    print(node, graph)
+                    raise e
+
+        for _ in original_inputs:
+            graph.eraseInput(0)
+
+        return graph
+
+    def _partition_node(
+        self,
+        node: torch.Node,
+        complete_upper_nodes_set: set[torch.Node],
+        complete_lower_nodes_set: set[torch.Node],
+        original_graph_outputs: set[torch.Value],
+        covered_bridge_values: set[torch.Value],
+        process_bridge_value: Callable[[torch.Value], torch.Value],
+    ):
+        if node in complete_lower_nodes_set:
+            return
+
+        if (
+            _node_has_uses_by(node, complete_lower_nodes_set)
+            and node.kind() in self._EXCLUDED_NODE_KINDS
+        ):
+            complete_lower_nodes_set.update(_all_nodes([node]))
+            for input in node.inputs():
+                if input in covered_bridge_values:
+                    continue
+                self._partition_node(
+                    input.node(),
+                    complete_upper_nodes_set,
+                    complete_lower_nodes_set,
+                    original_graph_outputs,
+                    covered_bridge_values,
+                    process_bridge_value,
+                )
+        else:
+            for output in node.outputs():
+                if output in covered_bridge_values:
+                    continue
+                if (
+                    _has_uses_by_nodes(output, complete_lower_nodes_set)
+                    or output in original_graph_outputs
+                ):
+                    covered_bridge_values.add(process_bridge_value(output))
+
+    def _partition_nodes(
+        self,
+        graph: torch.Graph,
+        pivot: int,
+        process_bridge_value: Callable[[torch.Value], torch.Value],
+    ) -> tuple[list[torch.Node], list[torch.Node], set[torch.Node], set[torch.Node]]:
+        nodes = list(graph.nodes())
+        upper_nodes = nodes[:pivot]
+        lower_nodes = nodes[pivot:]
+        # `upper_nodes` and `complete_upper_nodes_set` differ in that the latter
+        # recursively contains nodes in the subblocks of `upper_nodes`.
+        # The same applies to `lower_nodes` and `complete_lower_nodes_set`.
+        # In addition, `complete_lower_nodes_set` will include nodes that are
+        # determined to be copied from `upper_nodes` to `lower_nodes`.
+        complete_upper_nodes_set = _all_nodes(upper_nodes)
+        complete_lower_nodes_set = _all_nodes(lower_nodes)
+        original_graph_outputs = set(graph.outputs())
+        # Bridge values are values produced by the upper graph and consumed by the
+        # lower graph. These values need to become upper graph outputs and lower
+        # graph inputs to bridge the interaction.
+        # Start with all graph inputs marked as covered. If any graph input is
+        # needed by the lower graph, it is simply kept as a lower graph input later.
+ covered_bridge_values = set(graph.inputs()) + for node in upper_nodes: + self._partition_node( + node, + complete_upper_nodes_set, + complete_lower_nodes_set, + original_graph_outputs, + covered_bridge_values, + process_bridge_value, + ) + return ( + upper_nodes, + lower_nodes, + complete_upper_nodes_set, + complete_lower_nodes_set, + ) + + def _bridge_kwargs(self): + pt_outs = self.pt_outs + graph_outputs = list(self.graph.outputs()) + assert pt_outs is not None + assert len(graph_outputs) == len(pt_outs), ( + f"{len(graph_outputs)} vs {len(pt_outs)}\nGraph: {self.graph}" + ) + return {v.debugName(): o for v, o in zip(graph_outputs, pt_outs)} + + def _args_and_params_for_partition_graph( + self, + graph: torch.Graph, + bridge_kwargs: Mapping[str, _NumericType | Sequence[_NumericType]], + full_kwargs: Mapping[str, torch.Tensor], + full_params: Mapping[str, torch.Tensor], + ): + input_names = [input.debugName() for input in graph.inputs()] + args = tuple(bridge_kwargs[k] for k in input_names if k in bridge_kwargs) + args += tuple(full_kwargs[k] for k in input_names if k in full_kwargs) + params = {k: full_params[k] for k in input_names if k in full_params} + assert len(args) + len(params) == len(input_names), ( + f"{len(args)} + {len(params)} vs {len(input_names)}: {input_names}" + ) + return args, params + + def verify_export( + self, options: VerificationOptions + ) -> tuple[AssertionError | None, torch.Graph, _OutputsType, _OutputsType]: + """ + Verify the export from TorchScript IR graph to ONNX. + + Export the TorchScript IR graph to ONNX, with the inputs, parameters and export + options recorded in this object. Then verify the exported ONNX graph against + the original TorchScript IR graph under the provided verification options. + + Args: + options: The verification options. + + Returns: + error: The AssertionError raised during the verification. Returns None if no + error is raised. + onnx_graph: The exported ONNX graph in TorchScript IR format. + onnx_outs: The outputs from running exported ONNX model under the onnx + backend in `options`. + pt_outs: The outputs from running the TorchScript IR graph. + """ + return verify_aten_graph( + self.graph, + input_args=self.input_args, + params_dict=self.params_dict, + export_options=self.export_options, + verification_options=options, + ) + + def find_mismatch( + self, + options: VerificationOptions | None = None, + ): + """ + Find all mismatches between the TorchScript IR graph and the exported onnx model. + + Binary searches the model graph to find the minimal subgraph that exhibits the + mismatch. A `GraphInfo` object is created for each subgraph, recording the test + inputs and export options, as well as the validation results. + + Args: + options: The verification options. + """ + self.clear() + + if options is None: + options = VerificationOptions() + + if self.export_options.verbose: + print(self.graph) + + if len(list(self.graph.outputs())) == 0: + return + + assert len(self.input_args) + len(self.params_dict) == len( + list(self.graph.inputs()) + ), ( + f"Number of graph inputs({len(list(self.graph.inputs()))}) does not match " + f"the provided tensor arguments({len(self.input_args)} + {len(self.params_dict)})." + ) + + self.mismatch_error, self._onnx_graph, self.pt_outs, _ = self.verify_export( + options + ) + + if self.mismatch_error is None: + # No mismatch found in graph. + return + + if self.essential_node_count() <= 1: + # Reached leaf node, no more partitioning. 
+ return + + full_kwargs = { + k.debugName(): v for k, v in zip(self.graph.inputs(), self.input_args) + } + full_params = self.params_dict + + upper_graph = self._partition_upper_graph() + upper_args, upper_params = self._args_and_params_for_partition_graph( + upper_graph, {}, full_kwargs, full_params + ) + self.upper_graph_info = GraphInfo( + upper_graph, + upper_args, + upper_params, + self.export_options, + id=self.id + "0", + ) + + self.upper_graph_info.find_mismatch(options) + + bridge_kwargs = self.upper_graph_info._bridge_kwargs() + lower_graph = self._partition_lower_graph() + lower_args, lower_params = self._args_and_params_for_partition_graph( + lower_graph, bridge_kwargs, full_kwargs, full_params + ) + self.lower_graph_info = GraphInfo( + lower_graph, + lower_args, + lower_params, + self.export_options, + id=self.id + "1", + ) + + self.lower_graph_info.find_mismatch(options) + + +def _all_nodes(nodes: Collection[torch.Node]) -> set[torch.Node]: + all_nodes = set(nodes) + for n in nodes: + for b in n.blocks(): + all_nodes.update(_all_nodes(list(b.nodes()))) + return all_nodes + + +def _has_uses_by_nodes(value: torch.Value, nodes: Collection[torch.Node]) -> bool: + return any(use.user in nodes for use in value.uses()) + + +def _node_has_uses_by(node: torch.Node, nodes: Collection[torch.Node]) -> bool: + for output in node.outputs(): + if _has_uses_by_nodes(output, nodes): + return True + return False + + +def _produced_by(value: torch.Value, nodes: Collection[torch.Node]) -> bool: + return value.node() in nodes + + +@typing_extensions.deprecated( + "torch.onnx.verification.* is deprecated. Consider using torch.onnx.export(..., dynamo=True) " + "and use ONNXProgram to test the ONNX model" +) +def find_mismatch( + model: torch.nn.Module | torch.jit.ScriptModule, + input_args: tuple[Any, ...], + do_constant_folding: bool = True, + training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL, + opset_version: int | None = None, + keep_initializers_as_inputs: bool = True, + verbose: bool = False, + options: VerificationOptions | None = None, +) -> GraphInfo: + r"""Find all mismatches between the original model and the exported model. + + .. deprecated:: 2.7 + Consider using ``torch.onnx.export(..., dynamo=True)`` and use the returned + ``ONNXProgram`` to test the ONNX model. + + Experimental. The API is subject to change. + + This tool helps debug the mismatch between the original PyTorch model and exported + ONNX model. It binary searches the model graph to find the minimal subgraph that + exhibits the mismatch. + + Args: + model: The model to be exported. + input_args: The input arguments to the model. + do_constant_folding: Same as `do_constant_folding` in :func:`torch.onnx.export`. + training: Same as `training` in :func:`torch.onnx.export`. + opset_version: Same as `opset_version` in :func:`torch.onnx.export`. + keep_initializers_as_inputs: Same as `keep_initializers_as_inputs` in :func:`torch.onnx.export`. + verbose: Same as `verbose` in :func:`torch.onnx.export`. + options: The options for the mismatch verification. + + Returns: + A GraphInfo object that contains the mismatch information. + + Example:: + + >>> import torch + >>> import torch.onnx.verification + >>> torch.manual_seed(0) + >>> opset_version = 15 + >>> # Define a custom symbolic function for aten::relu. + >>> # The custom symbolic function is incorrect, which will result in mismatches. + >>> def incorrect_relu_symbolic_function(g, self): + ... return self + >>> torch.onnx.register_custom_op_symbolic( + ... 
"aten::relu", + ... incorrect_relu_symbolic_function, + ... opset_version=opset_version, + ... ) + >>> class Model(torch.nn.Module): + ... def __init__(self) -> None: + ... super().__init__() + ... self.layers = torch.nn.Sequential( + ... torch.nn.Linear(3, 4), + ... torch.nn.ReLU(), + ... torch.nn.Linear(4, 5), + ... torch.nn.ReLU(), + ... torch.nn.Linear(5, 6), + ... ) + ... def forward(self, x): + ... return self.layers(x) + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX) + >>> graph_info = torch.onnx.verification.find_mismatch( + ... Model(), + ... (torch.randn(2, 3),), + ... opset_version=opset_version, + ... ) + ===================== Mismatch info for graph partition : ====================== + ================================ Mismatch error ================================ + Tensor-likes are not close! + Mismatched elements: 12 / 12 (100.0%) + Greatest absolute difference: 0.2328854203224182 at index (1, 2) (up to 1e-07 allowed) + Greatest relative difference: 0.699536174352349 at index (1, 3) (up to 0.001 allowed) + ==================================== Tree: ===================================== + 5 X __2 X __1 \u2713 + id: | id: 0 | id: 00 + | | + | |__1 X (aten::relu) + | id: 01 + | + |__3 X __1 \u2713 + id: 1 | id: 10 + | + |__2 X __1 X (aten::relu) + id: 11 | id: 110 + | + |__1 \u2713 + id: 111 + =========================== Mismatch leaf subgraphs: =========================== + ['01', '110'] + ============================= Mismatch node kinds: ============================= + {'aten::relu': 2} + + """ + if options is None: + options = VerificationOptions() + if opset_version is None: + opset_version = _constants.ONNX_DEFAULT_OPSET + """From aten graph, do binary search on graph partition to find operator export discrepancy.""" + # TODO: Copied from utils.py `export` until `_optimize_graph`. 
+ if training == torch.onnx.TrainingMode.TRAINING: + model.train() + elif training == torch.onnx.TrainingMode.EVAL: + model.eval() + with torch.no_grad(): + inputs_for_export = _prepare_input_for_export(input_args, {}) + args = utils._decide_input_format(model, inputs_for_export) + + model = utils._pre_trace_quant_model(model, args) + graph, params, _torch_out, _module = utils._create_jit_graph(model, args) + params_dict = utils._get_named_param_dict(graph, params) + + utils._apply_friendly_debug_names(graph, params_dict) + + graph_info = GraphInfo( + graph, + input_args, + params_dict, + _experimental.ExportOptions( + do_constant_folding=do_constant_folding, + training=training, + opset_version=opset_version, + keep_initializers_as_inputs=keep_initializers_as_inputs, + verbose=verbose, + ), + ) + graph_info.find_mismatch(options) + graph_info.pretty_print_mismatch() + graph_info.pretty_print_tree() + + return graph_info diff --git a/torch/onnx/symbolic_helper.py b/torch/onnx/symbolic_helper.py index dc6312e5f7a3..76b50a8eb3f7 100644 --- a/torch/onnx/symbolic_helper.py +++ b/torch/onnx/symbolic_helper.py @@ -1,2267 +1,8 @@ -# mypy: allow-untyped-defs +"""Backward compatibility module for torch.onnx.symbolic_helper.""" + from __future__ import annotations -import functools -import inspect -import math -import sys -import typing -import warnings -from typing import Any, Callable, Literal, NoReturn, TypeVar as _TypeVar -from typing_extensions import Concatenate as _Concatenate, ParamSpec as _ParamSpec -import torch -import torch._C._onnx as _C_onnx -from torch import _C +__all__: list[str] = [] -# Monkey-patch graph manipulation methods on Graph, used for the ONNX symbolics -from torch.onnx import _constants, _type_utils, errors, utils -from torch.onnx._globals import GLOBALS -from torch.onnx._internal import jit_utils - - -if typing.TYPE_CHECKING: - from collections.abc import Sequence - - from torch.types import Number - -_T = _TypeVar("_T") -_U = _TypeVar("_U") -_P = _ParamSpec("_P") - -# --------------------------------------------------------------------------------- -# Helper functions -# --------------------------------------------------------------------------------- - -_ValueDescriptor = Literal[ - "v", - "i", - "is", - "f", - "fs", - "b", - "s", - "t", - "none", -] - - -def _parse_arg( - value, - desc: _ValueDescriptor, - arg_name: str | None = None, - node_name: str | None = None, -): - if desc == "none": - return value - if desc == "v" or not _is_value(value): - return value - - node = value.node() - if node.mustBeNone(): - return None - if node.kind() == "onnx::Constant": - node_val = _node_get(node, "value") - if desc == "i": - return int(node_val) - elif desc == "f": - return float(node_val) - elif desc == "b": - return bool(node_val) - elif desc == "s": - return str(node_val) - elif desc == "t": - return node_val - elif desc == "is": - return [int(v) for v in node_val] - elif desc == "fs": - return [float(v) for v in node_val] - else: - raise errors.SymbolicValueError( - f"ONNX symbolic does not understand the Constant node '{node}' " - f"specified with descriptor '{desc}'.", - value, - ) - elif node.kind() == "prim::ListConstruct": - if desc == "is": - for v in node.inputs(): - element_node = v.node() - if element_node.kind() != "onnx::Constant": - raise errors.SymbolicValueError( - f"Failed to export a node '{element_node}' " - f"(in list node {node}) " - f"because it is not constant. " - f"Please try to make things (e.g. 
kernel sizes) static if possible.", - value, - ) - return [int(_node_get(v.node(), "value")) for v in value.node().inputs()] - else: - raise errors.SymbolicValueError( - f"ONNX symbolic does not know how to unpack the ListConstruct node that " - f"is not a list of integers: '{node}'", - value, - ) - - if arg_name is None or node_name is None: - raise errors.SymbolicValueError( - f"Expected node type 'onnx::Constant', got '{node.kind()}'.", - value, - ) - - raise errors.SymbolicValueError( - "Expected node type 'onnx::Constant' " - f"for argument '{arg_name}' of node '{node_name}', got '{node.kind()}'.", - value, - ) - - -def _node_get(node: _C.Node, key: str): - """Gets attributes of a node which is polymorphic over return type.""" - assert isinstance(node, _C.Node) - sel = node.kindOf(key) - return getattr(node, sel)(key) - - -def _is_onnx_constant(value: _C.Value): - """Whether a Value is an ONNX constant.""" - return value.node().kind() == "onnx::Constant" - - -def _maybe_get_const( - value: _C.Value | torch.Tensor | Number | Sequence | None, - descriptor: _ValueDescriptor, -): - # NOTE: prim::Constant at this stage usually means something not compatible in ONNX, - # otherwise it'd be converted to onnx::Constant - # TODO(justinchuby): Replace insinstance with _is_value once we figure out mypy - if isinstance(value, _C.Value) and _is_onnx_constant(value): - return _parse_arg(value, descriptor) - return value - - -def _maybe_get_scalar(value): - value_t = _maybe_get_const(value, "t") - if isinstance(value_t, torch.Tensor) and value_t.shape == (): - return value_t - return value - - -def _get_const(value, desc, arg_name): - if not _is_constant(value): - raise errors.SymbolicValueError( - f"ONNX symbolic expected a constant value of the '{arg_name}' argument, " - f"got '{value}'", - value, - ) - return _parse_arg(value, desc) - - -def _unpack_list(list_value: _C.Value) -> list[_C.Value]: - list_node = list_value.node() - if list_node.kind() != "prim::ListConstruct": - raise errors.SymbolicValueError( - f"ONNX symbolic expected node type prim::ListConstruct, got '{list_node}'.", - list_value, - ) - return list(list_node.inputs()) - - -def _unpack_tuple(tuple_value: _C.Value) -> tuple[_C.Value, ...]: - tuple_node = tuple_value.node() - if not _is_tuple_construct(tuple_value): - raise errors.SymbolicValueError( - f"ONNX symbolic expected node type 'prim::TupleConstruct', " - f"got '{tuple_node.kind()}'.", - tuple_value, - ) - return tuple(tuple_node.inputs()) - - -def _unpack_quantized_tensor(tuple_value: _C.Value) -> tuple[_C.Value, ...]: - """Unpacks a quantized tensor into a tuple of tensor and scale/zero_point. - Args: - tuple_value: A tuple of tensor, scale, zero_point, and optionally axis. - Returns: - A tuple of tensor, scale, zero_point, and optionally axis. - """ - tuple_node = tuple_value.node() - # A quantized tensor is represented as tuple of the form (tensor, scale, zero_point, ) - if not _is_tuple_construct(tuple_value): - raise errors.SymbolicValueError( - f"ONNX symbolic expected the output of `{tuple_node}` to be a quantized " - f"tensor. Is this likely due to missing support for quantized " - f"`{tuple_node.kind()}`. Please create an issue on {_constants.PYTORCH_GITHUB_ISSUES_URL}", - tuple_value, - ) - unpacked = tuple(tuple_node.inputs()) - assert len(unpacked) == 3 or len(unpacked) == 4 - return unpacked - - -# Check if list_value is output from prim::ListConstruct -# This is usually called before _unpack_list to ensure the list can be unpacked. 
-def _is_packed_list(list_value: Any) -> bool: - return _is_value(list_value) and list_value.node().kind() == "prim::ListConstruct" - - -def parse_args( - *arg_descriptors: _ValueDescriptor, -) -> Callable[[Callable[_Concatenate[_U, _P], _T]], Callable[_Concatenate[_U, _P], _T]]: - """A decorator which converts args from torch._C.Value to built-in types. - - For example: - - ``` - @parse_args('v', 'i', 'fs') - foo(g, a, b, c): - assert isinstance(a, torch._C.Value) - assert isinstance(b, int) - assert isinstance(c, list) - assert isinstance(c[0], float) - ``` - - Args: - arg_descriptors: list of str, where each element is - a string that specifies the type to convert to. Valid descriptors: - "v": no conversion, keep torch._C.Value. - "i": int - "is": list of int - "f": float - "fs": list of float - "b": bool - "s": str - "t": torch.Tensor - "none": the variable is unused - """ - - def decorator( - fn: Callable[_Concatenate[_U, _P], _T], - ) -> Callable[_Concatenate[_U, _P], _T]: - fn._arg_descriptors = arg_descriptors # type: ignore[attr-defined] - - @functools.wraps(fn) - def wrapper(g: _U, *args: _P.args, **kwargs: _P.kwargs) -> _T: - # some args may be optional, so the length may be smaller - FILE_BUG_MSG = ( - "If you believe this is not due to custom symbolic implementation within your code or " - "an external library, please file an issue at " - "https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml to report this bug." - ) - assert len(arg_descriptors) >= len(args), ( - f"A mismatch between the number of arguments ({len(args)}) and " - f"their descriptors ({len(arg_descriptors)}) was found at symbolic function '{fn.__name__}'. " - f"{FILE_BUG_MSG}" - ) - - try: - sig = inspect.signature(fn) - arg_names = list(sig.parameters.keys())[1:] - fn_name = fn.__name__ - except Exception: - # FIXME(justinchuby): Avoid catching Exception. - # Catch a more specific exception instead. - arg_names = [None] * len(args) # type: ignore[list-item] - fn_name = None - args = [ - _parse_arg(arg, arg_desc, arg_name, fn_name) # type: ignore[method-assign] - for arg, arg_desc, arg_name in zip(args, arg_descriptors, arg_names) - ] - # only support _outputs in kwargs - assert len(kwargs) <= 1, ( - f"Symbolic function {fn.__name__}'s '**kwargs' can contain a single " - f"key/value entry. " - f"{FILE_BUG_MSG}" - ) - - if len(kwargs) == 1: - assert "_outputs" in kwargs, ( - f"Symbolic function {fn.__name__}'s '**kwargs' can only contain " - f"'_outputs' key at '**kwargs'. " - f"{FILE_BUG_MSG}" - ) - return fn(g, *args, **kwargs) - - return wrapper - - return decorator - - -def quantized_args( - *arg_q_descriptors: bool, - scale: float | None = None, - zero_point: int | None = None, - quantize_output: bool = True, -) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: - """A decorator which extends support for quantized version of the base operator. - - Quantization is detected by examining the arguments that are annotated by - `arg_q_descriptors`. - - If quantization is detected, the base operator symbolic function will be wrapped with - argument de-quantization and output quantization. - - Otherwise, only the base symbolic function will be invoked. 
- - For example: - - ``` - @quantized_args(True, False) - def foo(g, x, y): - return x + y - ``` - - is equivalent to - - ``` - def q_foo(g, x, y): - if is_quantized_tensor(x): - x = dequantize(x) - out = foo(g, x, y) - return quantize(out) - else: - return foo(g, x, y) - ``` - - Args: - arg_q_descriptors: A sequence of bool, where each element represents if the - argument is QTensor for quantized version of this operator. It defaults - to False for unspecified (variable length) arguments. - scale: Quantized output scale. If None, derive from - the first quantized input scale. - zero_point: Quantized output zero point. If None, - derive from the first quantized input zero point. - quantize_output: If True, quantize the output of the base operator. Default is True - """ - - def decorator(fn): - @functools.wraps(fn) - def wrapper(g, *args, **kwargs): - nonlocal scale - nonlocal zero_point - if scale is not None: - _scale = g.op("Constant", value_t=torch.tensor(scale)) - else: - _scale = None - if zero_point is not None: - _zero_point = g.op("Constant", value_t=torch.tensor(zero_point)) - else: - _zero_point = None - - # Support variable length arguments by marking unspecified ones as non-quantized - arg_q_descriptors_extended = arg_q_descriptors + (False,) * ( - len(args) - len(arg_q_descriptors) - ) - descriptor_args = tuple(zip(arg_q_descriptors_extended, args)) - - def _is_arg_quantized(descriptor, arg): - return descriptor and _is_value(arg) and _is_tuple_construct(arg) - - # Run regular symbolic function if none of the argument is QTensor. - is_quantized: list[bool] = [] - for descriptor, arg in descriptor_args: - # ListConstruct - if _is_packed_list(arg): - is_quantized.extend( - _is_arg_quantized(descriptor, arg_input) - for arg_input in arg.node().inputs() - ) - else: - is_quantized.append(_is_arg_quantized(descriptor, arg)) - - if not any(is_quantized): - return fn(g, *args, **kwargs) - - # Dequantize arguments that are quantized - non_quantized_args = [] - for descriptor, arg in descriptor_args: - if _is_arg_quantized(descriptor, arg): - # Quantized arg is a tuple of (value, scale, zero_point) - dequantized_arg, arg_scale, arg_zero_point, _ = dequantize_helper( - g, arg - ) - non_quantized_args.append(dequantized_arg) - # Set scale and zero_point to the first quantized input if not already set - if _scale is None: - _scale = arg_scale - if _zero_point is None: - _zero_point = arg_zero_point - # ListConstruct - elif _is_packed_list(arg): - for arg_input in arg.node().inputs(): - if _is_arg_quantized(descriptor, arg_input): - # Quantized arg is a tuple of (value, scale, zero_point) - ( - dequantized_arg, - arg_scale, - arg_zero_point, - _, - ) = dequantize_helper(g, arg_input) - # Set scale and zero_point to the first quantized input if not already set - if _scale is None: - _scale = arg_scale - if _zero_point is None: - _zero_point = arg_zero_point - arg_input.replaceAllUsesWith(dequantized_arg) - non_quantized_args.append(arg) - else: - # Non-quantized arg - non_quantized_args.append(arg) - # TODO(justinchuby): Only single output is supported for now. We may want to - # support multiple outputs in the future. 
- output = fn(g, *non_quantized_args, **kwargs) - - assert _scale is not None, "Bug: Scale must be set for quantized operator" - assert _zero_point is not None, ( - "Bug: Zero point must be set for quantized operator" - ) - - if quantize_output: - return quantize_helper(g, output, _scale, _zero_point) - return output - - return wrapper - - return decorator - - -def _scalar(x: Any) -> Number | None: - """Convert a scalar tensor into a Python value.""" - if isinstance(x, torch.Tensor) and x.shape == (): - return x.item() - return None - - -def _if_scalar_type_as(self, tensor): - """ - Convert self into the same type of tensor, as necessary. - We only support implicit casting for scalars, so we never - actually need to insert an ONNX cast operator here; just - fix up the scalar. - """ - if isinstance(self, _C.Value): - return self - - scalar_type = _type_utils.JitScalarType.from_value( - tensor, _type_utils.JitScalarType.UNDEFINED - ) - if scalar_type != _type_utils.JitScalarType.UNDEFINED: - ty = scalar_type.scalar_name().lower() - return getattr(self, ty)() - return self - - -def _is_none(x: Any) -> bool: - return x is None or (x.node().mustBeNone() if isinstance(x, _C.Value) else False) - - -def _is_value(x: Any) -> bool: - return isinstance(x, _C.Value) - - -def _is_constant(value: Any) -> bool: - return not _is_value(value) or value.node().kind() in { - "onnx::Constant", - "prim::Constant", - } - - -def _is_tensor(x: _C.Value) -> bool: - return x.type().isSubtypeOf(_C.TensorType.get()) - - -# Note: _C.JitType is not exposed to Python and cannot be checked in runtime. -def _as_list_type(jit_type: _C.JitType) -> _C.ListType | None: - if isinstance(jit_type, _C.ListType): - return jit_type - return None - - -def _is_list(x: _C.Value) -> bool: - return _as_list_type(x.type()) is not None - - -def _is_tensor_list(x: _C.Value) -> bool: - x_type = _as_list_type(x.type()) - if x_type is None: - return False - return isinstance(x_type.getElementType(), _C.TensorType) - - -def _is_scalar_list(x: _C.Value) -> bool: - """Checks if x is a scalar list, for example: List[float], List[int]. - - Besides checking the type is ListType, we also check if the data type is - a valid ONNX data type. - """ - x_type = _as_list_type(x.type()) - if x_type is None: - return False - scalar_type = _type_utils.JitScalarType.from_value(x) - return scalar_type.onnx_compatible() - - -def _is_tuple_construct(x: _C.Value) -> bool: - return x.node().kind() == "prim::TupleConstruct" - - -def is_complex_value(x: _C.Value) -> bool: - assert _is_value(x) - return _type_utils.JitScalarType.from_value( - x, _type_utils.JitScalarType.UNDEFINED - ) in { - _type_utils.JitScalarType.COMPLEX32, - _type_utils.JitScalarType.COMPLEX64, - _type_utils.JitScalarType.COMPLEX128, - } - - -def _get_tensor_rank(x: _C.Value) -> int | None: - if not _is_tensor(x) or x.type() is None: - return None - x_type = x.type() - x_type = typing.cast(_C.TensorType, x_type) - return x_type.dim() - - -def _get_tensor_sizes(x: _C.Value, allow_nonstatic: bool = True): - if not _is_tensor(x) or x.type() is None: - return None - x_type = x.type() - x_type = typing.cast(_C.TensorType, x_type) - if allow_nonstatic: - # Each individual symbol is returned as None. - # e.g. [1, "a", "b"] -> [1, None, None] - return x_type.varyingSizes() - # returns None, if exists any symbol in sizes. - # e.g. 
[1, "a", "b"] -> None - return x_type.sizes() - - -def _get_tensor_dim_size(x: _C.Value, dim: int) -> int | None: - sizes = _get_tensor_sizes(x) - return sizes[dim] if sizes else None - - -def _get_dim_for_cross(x: _C.Value, dim: int | None): - if dim == -1: - tensor_rank = _get_tensor_rank(x) - assert tensor_rank is not None - return dim + tensor_rank - # If dim is not given, it defaults to the first dimension found with the size 3 - if dim is None: - sizes = _get_tensor_sizes(x) - assert sizes is not None - for index, size in enumerate(sizes): - if size is not None and size == 3: - return index - return dim - - -def _unimplemented(op: str, msg: str, value: _C.Value | None = None) -> None: - # For BC reasons, the behavior for Caffe2 does not raise exception for unimplemented operators - if GLOBALS.operator_export_type == _C_onnx.OperatorExportTypes.ONNX: - _onnx_unsupported(f"{op}, {msg}", value) - - -def _onnx_unsupported(op_name: str, value: _C.Value | None = None) -> NoReturn: - message = ( - f"Unsupported: ONNX export of operator {op_name}. " - f"Please feel free to request support or submit a pull request " - f"on PyTorch GitHub: {_constants.PYTORCH_GITHUB_ISSUES_URL}" - ) - if isinstance(value, _C.Value): - raise errors.SymbolicValueError( - message, - value, - ) - raise errors.OnnxExporterError(message) - - -def _onnx_opset_unsupported( - op_name: str, - current_opset: int, - supported_opset: int, - value: _C.Value | None = None, -) -> NoReturn: - message = ( - f"Unsupported: ONNX export of {op_name} in opset {current_opset}. " - f"Please try opset version {supported_opset}." - ) - if isinstance(value, _C.Value): - raise errors.SymbolicValueError( - message, - value, - ) - raise errors.OnnxExporterError(message) - - -def _onnx_opset_unsupported_detailed( - op_name: str, - current_opset: int, - supported_opset: int, - reason: str, - value: _C.Value | None = None, -) -> NoReturn: - message = ( - f"Unsupported: ONNX export of {op_name} in " - f"opset {current_opset}. {reason}. Please try opset version {supported_opset}." - ) - if isinstance(value, _C.Value): - raise errors.SymbolicValueError( - message, - value, - ) - raise errors.OnnxExporterError(message) - - -def _block_list_in_opset(name: str): - def symbolic_fn(*args, **kwargs): - raise errors.OnnxExporterError( - f"ONNX export failed on {name}, which is not implemented for opset " - f"{GLOBALS.export_onnx_opset_version}. " - "Try exporting with other opset versions." 
- ) - - return symbolic_fn - - -def _try_get_scalar_type(*args) -> _type_utils.JitScalarType | None: - for arg in args: - scalar_type = _type_utils.JitScalarType.from_value( - arg, _type_utils.JitScalarType.UNDEFINED - ) - if scalar_type != _type_utils.JitScalarType.UNDEFINED: - return scalar_type - return None - - -def _type_promote_from_values(*args) -> _type_utils.JitScalarType: - undef = _type_utils.JitScalarType.UNDEFINED - jit_types = [_try_get_scalar_type(arg) for arg in args] - if len(jit_types) == 0: - return undef - if len(jit_types) == 1: - return jit_types[0] # type: ignore[return-value] - new_dtype = jit_types[0].dtype() # type: ignore[union-attr] - for t in jit_types: - new_dtype = torch.promote_types(new_dtype, t.dtype()) # type: ignore[union-attr] - return _type_utils.JitScalarType.from_dtype(new_dtype) - - -def _maybe_cast_to_type( - g: jit_utils.GraphContext, value, jit_type: _type_utils.JitScalarType -): - if ( - _type_utils.JitScalarType.from_value(value, _type_utils.JitScalarType.UNDEFINED) - != jit_type - ): - return g.op( - "Cast", - value, - to_i=jit_type.onnx_type(), - ) - return value - - -def _select_helper(g: jit_utils.GraphContext, self, dim, index, apply_reshape=True): - index_const = _maybe_get_scalar(index) - index_dim = _get_tensor_rank(index) - if not _is_value(index_const): - # Index is a constant scalar. Make it a size 1 constant tensor. - index = g.op("Constant", value_t=torch.LongTensor([index_const])) - elif index_dim is not None and apply_reshape: - if index_dim == 0: - # Index is a scalar. Reshape it to a size 1 tensor. - index = _reshape_helper( - g, index, g.op("Constant", value_t=torch.LongTensor([1])) - ) - - index_scalar_type = _type_utils.JitScalarType.from_value( - index, _type_utils.JitScalarType.UNDEFINED - ) - if index_scalar_type not in { - _type_utils.JitScalarType.INT64, - _type_utils.JitScalarType.INT, - }: - index = g.op("Cast", index, to_i=_C_onnx.TensorProtoDataType.INT64) - return g.op("Gather", self, index, axis_i=dim) - - -def _slice_helper( - g: jit_utils.GraphContext, - input, - axes, - starts, - ends, - steps=None, -): - if g.opset <= 9: - from torch.onnx.symbolic_opset9 import _slice as _slice9 - - return _slice9(g, input, axes, starts, ends) - else: - from torch.onnx.symbolic_opset10 import _slice as _slice10 - - return _slice10(g, input, axes, starts, ends, steps) - - -def _is_fp(value) -> bool: - return _type_utils.JitScalarType.from_value( - value, _type_utils.JitScalarType.UNDEFINED - ) in { - _type_utils.JitScalarType.FLOAT, - _type_utils.JitScalarType.DOUBLE, - _type_utils.JitScalarType.HALF, - _type_utils.JitScalarType.BFLOAT16, - } - - -def _is_bool(value) -> bool: - return _type_utils.JitScalarType.from_value( - value, _type_utils.JitScalarType.UNDEFINED - ) in {_type_utils.JitScalarType.BOOL} - - -def _generate_wrapped_number(g: jit_utils.GraphContext, scalar): - """Creates a wrapped number based on https://github.com/pytorch/pytorch/issues/9515. - - A Tensor is a considered a "wrapped number" if it is - auto-wrapped from a C++ or Python number type. Integer types are - wrapped as 0-dim int64 tensors and floating-point types are - wrapped as 0-dim double tensors. - - The input to this function is constant value. 
If the data type - is a floating point type, it is converted to a 0-dim double - tensor, else it is converted to a 0-dim tensor of its original type - """ - assert not isinstance(scalar, torch.Tensor) - if isinstance(scalar, float): - return g.op("Constant", value_t=torch.tensor(scalar, dtype=torch.double)) - return g.op("Constant", value_t=torch.tensor(scalar)) - - -def _sort_helper(g: jit_utils.GraphContext, input, dim, descending=True, out=None): - if out is not None: - _unimplemented("Sort", "Out parameter is not supported") - shape_ = g.op("Shape", input) - dim_size_ = g.op( - "Gather", - shape_, - g.op("Constant", value_t=torch.tensor([dim], dtype=torch.int64)), - ) - if g.opset <= 10: - if not descending: - _unimplemented("Sort", "Ascending is not supported") - return g.op("TopK", input, dim_size_, axis_i=dim, outputs=2) - else: - return g.op( - "TopK", input, dim_size_, axis_i=dim, largest_i=descending, outputs=2 - ) - - -def _topk_helper( - g: jit_utils.GraphContext, input, k, dim, largest=True, sorted=False, out=None -): - if out is not None: - _unimplemented("TopK", "Out parameter is not supported") - if not _is_value(k): - k = g.op("Constant", value_t=torch.tensor([k], dtype=torch.int64)) - else: - k = _reshape_helper(g, k, g.op("Constant", value_t=torch.tensor([1]))) - if _try_get_scalar_type(k) != _type_utils.JitScalarType.INT64: - k = g.op("Cast", k, to_i=_C_onnx.TensorProtoDataType.INT64) - if g.opset <= 10: - if not largest: - _unimplemented("TopK", "Ascending is not supported") - return g.op("TopK", input, k, axis_i=dim, outputs=2) - else: - return g.op( - "TopK", input, k, axis_i=dim, largest_i=largest, sorted_i=sorted, outputs=2 - ) - - -def _lt_helper(g: jit_utils.GraphContext, input, other): - if g.opset <= 8: - from torch.onnx.symbolic_opset8 import lt as _lt8 - - return _lt8(g, input, other) - else: - from torch.onnx.symbolic_opset9 import lt as _lt9 - - return _lt9(g, input, other) - - -def _interpolate_warning(interpolate_mode): - onnx_op = ( - "onnx:Resize" if GLOBALS.export_onnx_opset_version >= 10 else "onnx:Upsample" - ) - warnings.warn( - "You are trying to export the model with " - + onnx_op - + " for ONNX opset version " - "" + str(GLOBALS.export_onnx_opset_version) + ". " - "This operator might cause results to not match the expected results by PyTorch.\n" - "ONNX's Upsample/Resize operator did not match Pytorch's Interpolation until opset 11. " - "Attributes to determine how to transform the input were added in onnx:Resize in opset 11 " - "to support Pytorch's behavior (like coordinate_transformation_mode and nearest_mode).\n" - "We recommend using opset 11 and above for models using this operator." 
- ) - - -def _unsqueeze_helper(g: jit_utils.GraphContext, input, axes_i): - if len(axes_i) == 0: - # unnecessary unsqueeze if axes length==0 - return input - elif _is_constant(axes_i[0]): - if g.opset >= 13: - axes = g.op("Constant", value_t=torch.tensor(axes_i, dtype=torch.long)) - return g.op("Unsqueeze", input, axes) - return g.op("Unsqueeze", input, axes_i=axes_i) - # Tensor type - if g.opset < 13: - raise errors.SymbolicValueError( - "Opset version must be >= 13 for Unsqueeze with dynamic axes.", input - ) - return g.op("Unsqueeze", input, axes_i[0]) - - -def _squeeze_helper(g: jit_utils.GraphContext, input, axes_i): - if _is_constant(axes_i[0]): - if g.opset >= 13: - axes = g.op("Constant", value_t=torch.tensor(axes_i, dtype=torch.long)) - return g.op("Squeeze", input, axes) - return g.op("Squeeze", input, axes_i=axes_i) - # Tensor type - if g.opset < 13: - raise errors.SymbolicValueError( - "Opset version must be >= 13 for Squeeze with dynamic axes.", input - ) - axes_t = axes_i[0] - axes_rank = _get_tensor_rank(axes_t) - assert axes_rank is not None - if axes_rank > 1: - raise errors.SymbolicValueError( - "For Squeeze axses as input, the axes rank must be one in ONNX spec.", input - ) - elif axes_rank == 0: - # The axes is a scalar. Unsqueeze it to a rank 1 tensor. - axes_t = _unsqueeze_helper(g, axes_t, [0]) - return g.op("Squeeze", input, axes_t) - return g.op("Squeeze", input, axes_t) - - -def _reducesum_helper( - g: jit_utils.GraphContext, - input, - axes_i=None, - keepdims_i=1, - noop_with_empty_axes_i=0, -): - keepdims_i = _maybe_get_const(keepdims_i, "i") - if g.opset >= 13: - if axes_i: - if not _is_value(axes_i): - axes_i = g.op( - "Constant", value_t=torch.tensor(axes_i, dtype=torch.long) - ) - return g.op( - "ReduceSum", - input, - axes_i, - keepdims_i=keepdims_i, - noop_with_empty_axes_i=noop_with_empty_axes_i, - ) - return g.op( - "ReduceSum", - input, - keepdims_i=keepdims_i, - noop_with_empty_axes_i=noop_with_empty_axes_i, - ) - else: - return g.op("ReduceSum", input, axes_i=axes_i, keepdims_i=keepdims_i) - - -def _interpolate_size_to_scales(g: jit_utils.GraphContext, input, output_size, dim): - output_size = _maybe_get_const(output_size, "is") - if _is_value(output_size): - offset = 2 - offsets = g.op("Constant", value_t=torch.ones(offset, dtype=torch.float32)) - dividend = g.op("Cast", output_size, to_i=_C_onnx.TensorProtoDataType.FLOAT) - divisor = _slice_helper( - g, g.op("Shape", input), axes=[0], ends=[sys.maxsize], starts=[offset] - ) - divisor = g.op("Cast", divisor, to_i=_C_onnx.TensorProtoDataType.FLOAT) - scale_dims = g.op("Div", dividend, divisor) - scales = g.op("Concat", offsets, scale_dims, axis_i=0) - else: - scales_constant = [ - 1.0 - if i < 2 - else float(output_size[-(dim - i)]) - / float(input.type().sizes()[-(dim - i)]) - for i in range(0, dim) - ] - scales = g.op( - "Constant", value_t=torch.tensor(scales_constant, dtype=torch.float32) - ) - return scales - - -def _interpolate_get_scales_if_available(g: jit_utils.GraphContext, scales): - available_scales = _maybe_get_const(scales[0], "fs") != -1 and not _is_none( - scales[0] - ) - - if not available_scales: - return None - - offsets = g.op("Constant", value_t=torch.ones(2, dtype=torch.float32)) - scales_list = g.op( - "Constant", value_t=torch.tensor(_maybe_get_const(scales[0], "fs")) - ) - scales = g.op("Concat", offsets, scales_list, axis_i=0) - return scales - - -def _get_interpolate_attributes(g: jit_utils.GraphContext, mode, args): - if mode == "nearest": - align_corners = None - scales = 
args[0:] - else: - align_corners = args[0] - scales = args[1:] - scales = _interpolate_get_scales_if_available(g, scales) - return scales, align_corners - - -def _interpolate_get_scales(g: jit_utils.GraphContext, scale_factor, dim): - offsets = g.op("Constant", value_t=torch.ones(2, dtype=torch.float32)) - scale_factor_rank = _get_tensor_rank(scale_factor) - if isinstance(scale_factor.type(), _C.ListType) or ( - scale_factor_rank is not None and scale_factor_rank > 0 - ): - return g.op("Concat", offsets, scale_factor, axis_i=0) - else: - scale_factor = _unsqueeze_helper(g, scale_factor, [0]) - scale_factor = g.op( - "Cast", scale_factor, to_i=_C_onnx.TensorProtoDataType.FLOAT - ) - scales = [scale_factor for i in range(dim - 2)] - scale_factor = g.op("Concat", offsets, *scales, axis_i=0) - return scale_factor - - -def _interpolate_get_scales_and_mode( - g: jit_utils.GraphContext, input, size, scale_factor, mode, align_corners -): - mode = _maybe_get_const(mode, "s") - if "linear" in mode: - mode = "linear" - if "cubic" in mode: - mode = "cubic" - _interpolate_warning(mode) - - align_corners = _maybe_get_const(align_corners, "b") - if isinstance(align_corners, bool) and align_corners: - return _unimplemented("interpolate", "align_corners == True") - - if not input.type().dim(): - return _unimplemented("interpolate", "missing input shape") - dim = input.type().dim() - - if not _is_none(scale_factor): - scale_factor = _interpolate_get_scales(g, scale_factor, dim) - elif not _is_none(size): - if not _is_packed_list(size): - is_scalar = _maybe_get_const(size, "t").dim() == 0 - if is_scalar: - size = _unsqueeze_helper(g, size, [0]) - size = [size for i in range(dim - 2)] - size = g.op("Concat", *size, axis_i=0) - scale_factor = _interpolate_size_to_scales(g, input, size, dim) - else: - return _unimplemented( - "interpolate", "Both size and scales are None in __interpolate" - ) - return scale_factor, mode - - -def _argmin_argmax_helper( - g: jit_utils.GraphContext, - input: torch._C.Value, - dim: torch._C.Value, - keepdim: bool, - op_name: str, -): - def op_wrapper(input, axis_i, keepdims_i): - if g.opset >= 12: - return g.op( - op_name, - input, - axis_i=axis_i, - keepdims_i=keepdims_i, - select_last_index_i=False, - ) - return g.op(op_name, input, axis_i=axis_i, keepdims_i=keepdims_i) - - if _is_none(dim): - flattened = _reshape_helper( - g, input, g.op("Constant", value_t=torch.tensor([-1])) - ) - output = op_wrapper(flattened, axis_i=0, keepdims_i=False) - if keepdim: - input_shape = g.op("Shape", input) - input_shape_shape = g.op("Shape", input_shape) - new_shape = g.op( - "ConstantOfShape", - input_shape_shape, - value_t=torch.tensor([1], dtype=torch.int64), - ) - output = g.op("Reshape", output, new_shape) - return output - - dim = _parse_arg(dim, "i") - return op_wrapper(input, axis_i=dim, keepdims_i=keepdim) - - -def _interpolate_helper(name, dim, interpolate_mode): - @quantized_args(True, False, False) - def symbolic_fn(g, input, output_size, *args): - scales, align_corners = _get_interpolate_attributes(g, interpolate_mode, args) - align_corners = _maybe_get_scalar(align_corners) - coordinate_transformation_mode = ( - "asymmetric" - if interpolate_mode == "nearest" - else "align_corners" - if align_corners - else "half_pixel" - ) - - if scales is None: - input_size = g.op("Shape", input) - input_size_beg = _slice_helper( - g, input_size, axes=[0], ends=[2], starts=[0] - ) - output_size = g.op( - "Cast", output_size, to_i=_C_onnx.TensorProtoDataType.INT64 - ) - output_size = 
g.op("Concat", input_size_beg, output_size, axis_i=0) - - if g.opset >= 13: - empty_roi = _optional_input_placeholder_tensor(g) - empty_scales = _optional_input_placeholder_tensor(g) - else: - empty_roi = g.op( - "Constant", value_t=torch.tensor([], dtype=torch.float32) - ) - empty_scales = g.op( - "Constant", value_t=torch.tensor([], dtype=torch.float32) - ) - - return g.op( - "Resize", - input, - empty_roi, - empty_scales, - output_size, - coordinate_transformation_mode_s=coordinate_transformation_mode, - cubic_coeff_a_f=-0.75, # only valid when mode="cubic" - mode_s=interpolate_mode, # nearest, linear, or cubic - nearest_mode_s="floor", - ) # only valid when mode="nearest" - else: - if g.opset >= 13: - empty_roi = _optional_input_placeholder_tensor(g) - else: - empty_roi = g.op( - "Constant", value_t=torch.tensor([], dtype=torch.float32) - ) - - return g.op( - "Resize", - input, - empty_roi, - scales, - coordinate_transformation_mode_s=coordinate_transformation_mode, - cubic_coeff_a_f=-0.75, # only valid when mode="cubic" - mode_s=interpolate_mode, # nearest, linear, or cubic - nearest_mode_s="floor", - ) # only valid when mode="nearest" - - return symbolic_fn - - -def __interpolate_helper( - g: jit_utils.GraphContext, - input, - size, - scale_factor, - mode, - align_corners, - recompute_scale_factor, -): - mode = _maybe_get_const(mode, "s") - if "linear" in mode: - mode = "linear" - if "cubic" in mode: - mode = "cubic" - align_corners = _maybe_get_const(align_corners, "b") - align_corners = False if not isinstance(align_corners, bool) else align_corners - coordinate_transformation_mode = ( - "asymmetric" - if mode == "nearest" - else "align_corners" - if align_corners - else "half_pixel" - ) - - if not _is_none(size): - input_size = g.op("Shape", input) - input_size = _slice_helper(g, input_size, axes=[0], ends=[2], starts=[0]) - # in some cases size is not a packed list but size is a scalar - # We need to also verify that (_maybe_get_const(size, "t").dim() == 0) - # but this information is not always available. Try to get the dim, - # and if not assume that it is not a scalar. - try: - is_scalar = not _is_packed_list(size) and ( - _maybe_get_const(size, "t").dim() == 0 - ) - except AttributeError: - is_scalar = not _is_packed_list(size) - if not is_scalar: - warnings.warn( - "Cannot verify if the output_size is a scalar " - "while exporting interpolate. Assuming that it is not a scalar." 
- ) - - if is_scalar: - rank = _get_tensor_rank(input) - if rank is None: - return _unimplemented( - "interpolate (with a scalar output_size)", - "missing input shape (try giving an array of output_size values)", - ) - size = _unsqueeze_helper(g, size, [0]) - size = [size for i in range(rank - 2)] - size = g.op("Concat", *size, axis_i=0) - size = g.op("Cast", size, to_i=_C_onnx.TensorProtoDataType.INT64) - size = g.op("Concat", input_size, size, axis_i=0) - - if g.opset >= 13: - empty_roi = _optional_input_placeholder_tensor(g) - empty_scales = _optional_input_placeholder_tensor(g) - else: - empty_roi = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32)) - empty_scales = g.op( - "Constant", value_t=torch.tensor([], dtype=torch.float32) - ) - - return g.op( - "Resize", - input, - empty_roi, - empty_scales, - size, - coordinate_transformation_mode_s=coordinate_transformation_mode, - cubic_coeff_a_f=-0.75, # only valid when mode="cubic" - mode_s=mode, # nearest, linear, or cubic - nearest_mode_s="floor", - ) - else: # if not _is_none(scales) - rank = _get_tensor_rank(input) - if rank is None: - return _unimplemented("interpolate (with scales)", "missing input shape") - - if g.opset >= 13: - empty_roi = _optional_input_placeholder_tensor(g) - else: - empty_roi = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32)) - - scales = _interpolate_get_scales(g, scale_factor, rank) - return g.op( - "Resize", - input, - empty_roi, - scales, - coordinate_transformation_mode_s=coordinate_transformation_mode, - cubic_coeff_a_f=-0.75, # only valid when mode="cubic" - mode_s=mode, # nearest, linear, or cubic - nearest_mode_s="floor", - ) # only valid when mode="nearest" - - -def _unbind_helper(g: jit_utils.GraphContext, self, dim, _outputs): - if g.opset < 11: - from torch.onnx.symbolic_opset9 import unbind - elif g.opset <= 12: - from torch.onnx.symbolic_opset11 import unbind # type: ignore[no-redef] - else: - from torch.onnx.symbolic_opset13 import unbind # type: ignore[no-redef] - return unbind(g, self, dim, _outputs) - - -def _scatter_helper(g: jit_utils.GraphContext, self, dim, index, src): - if g.opset <= 10: - from torch.onnx.symbolic_opset9 import scatter - else: - # for mypy, scatter was imported two lines above - from torch.onnx.symbolic_opset11 import scatter # type: ignore[no-redef] - return scatter(g, self, dim, index, src) - - -def _repeat_interleave_split_helper(g: jit_utils.GraphContext, self, reps, dim): - if g.opset <= 12: - split_out = g.op("Split", self, split_i=[1] * reps, axis_i=dim, outputs=reps) - else: - from torch.onnx.symbolic_opset13 import split - - repeats = g.op("Constant", value_t=torch.tensor([1] * reps)) - split_out = split(g, self, repeats, dim, _outputs=reps) - return split_out if reps > 1 else [split_out] - - -def _repeat_interleave_single_value_repeat_helper( - g: jit_utils.GraphContext, self, repeats, dim -): - from torch.onnx.symbolic_opset9 import flatten, unsqueeze - - if not _is_tensor(repeats): - repeats = g.op("Constant", value_t=torch.LongTensor(repeats)) - - const_repeats: bool = _is_constant(repeats) - reps = _maybe_get_const(repeats, "t") - - # Convert 'repeats' to 1-d if it is 0-d. - if _get_tensor_rank(repeats) == 0: - repeats = g.op("Reshape", repeats, g.op("Constant", value_t=torch.tensor([1]))) - - # Create a new dim of size 1, then expand it to be 'repeats' long, and finally collapse it. 
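# A minimal eager-mode sketch of the unsqueeze -> tile -> flatten trick the
# comment above describes; the names (x, repeats, dim) are illustrative, and the
# exporter emits the same pattern as ONNX nodes rather than tensor methods.
import torch

def repeat_interleave_sketch(x: torch.Tensor, repeats: int, dim: int) -> torch.Tensor:
    unsqueezed = x.unsqueeze(dim + 1)             # new dim of size 1
    reps = [1] * unsqueezed.dim()
    reps[dim + 1] = repeats                       # tile only the new dim
    return unsqueezed.repeat(*reps).flatten(dim, dim + 1)  # collapse it again

# torch.equal(repeat_interleave_sketch(torch.arange(3), 2, 0),
#             torch.repeat_interleave(torch.arange(3), 2))  # -> True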
- unsqueezed = unsqueeze(g, self, dim + 1) - - # repeats_per_dim is 1 for all dims except for the new unsqueezed dim, where it has value 'repeats'. - if const_repeats: - # 'Repeats' is a constant, 'repeats_per_dim' can be a constant. - onehot = torch.ones(_get_tensor_rank(unsqueezed), dtype=torch.int64) # type: ignore[arg-type] - onehot[dim + 1] = reps - repeats_per_dim = g.op("Constant", value_t=onehot) - else: - # 'Repeats' is a variable, 'repeats_per_dim' cannot be a constant. - onehot = g.op( - "OneHot", - unsqueeze(g, dim + 1, 0), # indices, must be >= 1-dimensional - g.op( - "Constant", value_t=torch.tensor(_get_tensor_rank(unsqueezed)) - ), # depth - g.op( - "Concat", g.op("Constant", value_t=torch.tensor([1])), repeats, axis_i=0 - ), # on/off values - ) - repeats_per_dim = flatten(g, onehot, 0, 1) - - tiled = g.op("Tile", unsqueezed, repeats_per_dim) - return flatten(g, tiled, dim, dim + 1) - - -def _arange_cast_helper( - g: jit_utils.GraphContext, end, start=None, step=None, dtype=None -) -> tuple[ - _type_utils.JitScalarType, - _C.Value | None, - _C.Value | None, - _C.Value | None, -]: - def _is_all_integral(scalars): - for scalar in scalars: - scalar_type = _type_utils.JitScalarType.from_value( - scalar, _type_utils.JitScalarType.UNDEFINED - ) - if ( - scalar_type != _type_utils.JitScalarType.INT64 - and scalar_type != _type_utils.JitScalarType.UNDEFINED - ): - return False - return True - - # This logic is based on torch.arange docs. If "dtype" is provided, - # infer input types from dtype. If not, then check if any of start, stop, - # or step are floating point, and infer the type from get_default. - # Otherwise, the dtype is inferred to be torch.int64. - if dtype is None or (_is_value(dtype) and _is_none(dtype)): - if _is_all_integral([start, end, step]): - scalar_type = _type_utils.JitScalarType.INT64 - else: - scalar_type = _type_utils.JitScalarType.from_dtype( - torch.get_default_dtype() - ) - else: - assert isinstance(dtype, int) - # TODO(justinchuby): Check if dtype is indeed a int. - scalar_type = _type_utils.JitScalarType(dtype) - - start = g.op("Cast", start, to_i=scalar_type.onnx_type()) if start else None - end = g.op("Cast", end, to_i=scalar_type.onnx_type()) if end else None - step = g.op("Cast", step, to_i=scalar_type.onnx_type()) if step else None - return scalar_type, end, start, step - - -def _arange_helper(g: jit_utils.GraphContext, *args): - if g.opset <= 10: - from torch.onnx.symbolic_opset9 import arange - else: - from torch.onnx.symbolic_opset11 import arange # type: ignore[no-redef] - return arange(g, *args) - - -def _size_helper(g: jit_utils.GraphContext, self, dim): - full_shape = g.op("Shape", self) - from torch.onnx.symbolic_opset9 import select - - return select(g, full_shape, g.op("Constant", value_t=torch.tensor([0])), dim) - - -def _index_fill_reshape_helper(g: jit_utils.GraphContext, self, dim, index): - # 1. reshape index => [1, ..., 1, dim, 1, ..., 1] - # 2. expand index => [..., dim, ...], same shape as self except for dim. - # 3. expand value as well. - # 4. apply onnx::scatter. 
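# A toy eager-mode illustration of steps 1-2 above (x, index, dim are
# illustrative values); the helper derives the expanded shape with Shape and
# Scatter nodes instead of computing it directly as done here.
import torch

x = torch.zeros(3, 4)
index = torch.tensor([0, 2])
dim = 1
# step 1: keep `index` only along `dim`, size-1 everywhere else -> shape (1, 2)
unsqueezed = index.view(*[1 if d != dim else -1 for d in range(x.dim())])
# step 2: expand to x's shape, except `dim` keeps len(index) -> shape (3, 2)
expanded = unsqueezed.expand(3, 2)
# steps 3-4: the caller expands the fill value the same way and scatters it into x.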
- - from torch.onnx.symbolic_opset9 import expand - - if g.opset <= 10: - from torch.onnx.symbolic_opset9 import scatter - else: - # for mypy, scatter was imported two lines above - from torch.onnx.symbolic_opset11 import scatter # type: ignore[no-redef] - - if self.type().dim() is None: - return _unimplemented("index_fill", "input rank not accessible") - self_dim = self.type().dim() - dim_value = _parse_arg(dim, "i") - if dim_value < 0: - dim_value += self_dim - unsqueezed_index = _unsqueeze_helper( - g, index, [i for i in range(self_dim) if i != dim_value] - ) - expanded_index_shape = scatter( - g, g.op("Shape", self), 0, _unsqueeze_helper(g, dim, [0]), g.op("Shape", index) - ) - expanded_index = expand(g, unsqueezed_index, expanded_index_shape, None) - return expanded_index_shape, expanded_index - - -# By default, when any value in the 'shape' input is equal to zero -# the corresponding dimension value is copied from the input tensor dynamically. -# allowzero=1 indicates that if any value in the 'shape' input is set to zero, -# the zero value is honored, similar to NumPy. -# allowzero=1 is only supported for opset version >= 14. -def _reshape_helper(g: jit_utils.GraphContext, input, shape, allowzero=0): - shape = _maybe_get_const(shape, "is") - if not _is_value(shape): - shape = g.op("Constant", value_t=torch.LongTensor(shape)) - if g.opset <= 13: - if allowzero == 1: - _onnx_opset_unsupported( - "Reshape with allowzero=1", GLOBALS.export_onnx_opset_version, 14, input - ) - return g.op("Reshape", input, shape) - else: - return g.op("Reshape", input, shape, allowzero_i=allowzero) - - -def _batchnorm_helper( - g: jit_utils.GraphContext, input, weight, bias, running_mean, running_var -): - from torch.onnx.symbolic_opset9 import _var_mean - - batch_size = _get_tensor_dim_size(input, 0) - channel_size = _get_tensor_dim_size(input, 1) - - if weight is None or _is_none(weight): - if channel_size is None: - raise errors.SymbolicValueError( - "Unsupported: ONNX export of batch_norm for unknown channel size.", - input, - ) - weight_value = torch.tensor( - [1.0] * channel_size, - dtype=_type_utils.JitScalarType.from_value(input).dtype(), - ) - weight = g.op("Constant", value_t=weight_value) - if bias is None or _is_none(bias): - if channel_size is None: - raise errors.SymbolicValueError( - "Unsupported: ONNX export of batch_norm for unknown channel size.", - input, - ) - bias_value = torch.tensor( - [0.0] * channel_size, - dtype=_type_utils.JitScalarType.from_value(input).dtype(), - ) - bias = g.op("Constant", value_t=bias_value) - # If track_running_stats is set to False batch statistics are instead used during evaluation time - if ( - running_mean is None - or _is_none(running_mean) - or running_var is None - or _is_none(running_var) - ): - assert batch_size is not None and channel_size is not None - reshape_in = _reshape_helper( - g, - input, - g.op( - "Constant", - value_t=torch.tensor([batch_size, channel_size, -1], dtype=torch.int64), - ), - ) - trans_in = g.op("Transpose", reshape_in, perm_i=[0, 2, 1]) - running_var, running_mean = _var_mean( - g, - trans_in, - g.op("Constant", value_t=torch.tensor([0, 1], dtype=torch.int64)), - False, - False, - ) - return weight, bias, running_mean, running_var - - -def _avgpool_helper( - tuple_fn: Callable[[Any], Sequence[int]], - padding: int | Sequence[int], - kernel_size, - stride, - divisor_override, - name, -) -> tuple[int, ...]: - if divisor_override and divisor_override.node().kind() != "prim::Constant": - _unimplemented(name, "divisor_override") 
- return tuple(tuple_fn(padding)) - - -def check_training_mode(op_train_mode: int, op_name: str) -> None: - """Warns the user if the model's training mode and the export mode do not agree.""" - if GLOBALS.training_mode == _C_onnx.TrainingMode.PRESERVE: - return - - if op_train_mode: - op_mode_enum = _C_onnx.TrainingMode.TRAINING - else: - op_mode_enum = _C_onnx.TrainingMode.EVAL - if op_mode_enum == GLOBALS.training_mode: - # The modes agree. Do nothing - return - - op_mode_text = f"train={bool(op_train_mode)}" - # Setting the model mode could result in op_mode != GLOBALS.training_mode - # if the model is a FuncModule. In this case we warn the user of - # the state and export depending on op_mode - # This is to support use-cases of fixing certain layer weights - # in training. - warnings.warn( - f"ONNX export mode is set to {GLOBALS.training_mode}, but operator '{op_name}' " - f"is set to {op_mode_text}. Exporting with {op_mode_text}." - ) - - -def _flatten_helper(g: jit_utils.GraphContext, input, start_dim, end_dim, dim): - input_size = g.op("Shape", input) - slice1 = _slice_helper(g, input_size, axes=[0], starts=[0], ends=[start_dim]) - slices = [slice1, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long))] - if end_dim < dim - 1: - slice3 = _slice_helper( - g, input_size, axes=[0], starts=[end_dim + 1], ends=[dim] - ) - slices = [ - slice1, - g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)), - slice3, - ] - - final_shape = g.op("Concat", *slices, axis_i=0) - from torch.onnx.symbolic_opset9 import _reshape_from_tensor - - return _reshape_from_tensor(g, input, final_shape) - - -def _is_split_static(split_size_or_sizes, _outputs): - if _outputs is None: - return False - if ( - _is_value(split_size_or_sizes) - and split_size_or_sizes.node().kind() != "onnx::Constant" - ): - return False - return True - - -def _optional_input_placeholder_tensor(g): - n = g.op("prim::Constant") - n.setType(_C.OptionalType.ofTensor()) - return n - - -def _handle_reduce_dim_none(g: jit_utils.GraphContext, self, op_name): - rank = _get_tensor_rank(self) - if rank is not None and any( - _get_tensor_dim_size(self, i) == 0 for i in range(rank) - ): - # If input tensor is empty, according to ONNX ReduceSum definition, - # set keepdims=1 so that the resulted tensor has the same rank as the input. - return g.op(op_name, self, keepdims_i=1) - return g.op(op_name, self, keepdims_i=0) - - -def dequantize_helper( - g: jit_utils.GraphContext, - qtensor: _C.Value, - qdtype: _C_onnx.TensorProtoDataType | None = None, -) -> tuple[_C.Value, _C.Value, _C.Value, _C.Value | None]: - """Appends to graph `g` ONNX nodes that dequantizes `qtensor` into `tensor`. - - Args: - g: Graph, the ONNX IR graph that is under construction. - qtensor: torch._C.Value, either a tuple of (quantized_tensor, scale, zero_point) - for per tensor quantization, or - (quantized_tensor, scale, zero_point, axis) for per channel quantization, - representing the quantized tensor. - qdtype: torch.onnx.TensorProtoDataType default None, if not None, represents the - data type of quantized tensor. It must be either - torch.onnx.TensorProtoDataType.UINT8 or torch.onnx.TensorProtoDataType.INT8. 
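# A small sketch of the numeric relation the emitted DequantizeLinear node
# encodes, using an eager quantized tensor with toy scale/zero_point values.
import torch

q = torch.quantize_per_tensor(
    torch.tensor([0.0, 0.5, 1.0]), scale=0.1, zero_point=10, dtype=torch.quint8
)
# int_repr() is the raw uint8 payload; DequantizeLinear computes
# (int_value - zero_point) * scale, matching eager dequantize().
assert torch.allclose(q.dequantize(), (q.int_repr().float() - 10) * 0.1)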
- """ - unpacked_qtensors = _unpack_quantized_tensor(qtensor) - tensor, scale, zero_point = unpacked_qtensors[:3] - axis = unpacked_qtensors[3] if len(unpacked_qtensors) >= 4 else None - axis_i = _get_const(axis, "i", "axis") - input_qdtype = _type_utils.JitScalarType.from_value(tensor) - if qdtype is None: - if input_qdtype is not None: - qdtype = input_qdtype.onnx_type() - else: - qdtype = _C_onnx.TensorProtoDataType.UINT8 - value = g.op("Cast", tensor, to_i=qdtype) - scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT) - zero_point = g.op("Cast", zero_point, to_i=qdtype) - - if axis_i is not None and GLOBALS.export_onnx_opset_version < 13: - _onnx_opset_unsupported_detailed( - "DequantizeLinear", - GLOBALS.export_onnx_opset_version, - 13, - "Attribute axis is not supported.", - qtensor, - ) - - return ( - g.op("DequantizeLinear", value, scale, zero_point, axis_i=axis_i), - scale, - zero_point, - axis, - ) - - -def quantize_helper( - g: jit_utils.GraphContext, - tensor: _C.Value, - scale: _C.Value, - zero_point: _C.Value, - axis: _C.Value | None = None, -) -> _C.Value: - """Appends to graph `g` ONNX nodes that quantizes `tensor` based on `scale`, `zero_point` and `axis`. - - Args: - g: Graph, the ONNX IR graph that is under construction. - tensor: torch._C.Value, representing the tensor to be quantized. - scale: torch._C.Value, quantized scale. - zero_point: torch._C.Value, quantized zero point. - axis: Optional[torch._C.Value] default None, if None, represents per tensor quantization. - Otherwise, represents per channel quantization, along given axis. - - Returns: - A TupleConstruct storing information of the quantized tensor. - """ - if ( - axis is not None - and not _is_none(axis) - and GLOBALS.export_onnx_opset_version < 13 - ): - _onnx_opset_unsupported_detailed( - "QuantizeLinear", - GLOBALS.export_onnx_opset_version, - 13, - "Attribute axis is not supported.", - tensor, - ) - - assert scale is not None - if ( - _type_utils.JitScalarType.from_value(scale, _type_utils.JitScalarType.UNDEFINED) - != _type_utils.JitScalarType.FLOAT - ): - scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT) - - assert zero_point is not None - if _type_utils.JitScalarType.from_value( - zero_point, _type_utils.JitScalarType.UNDEFINED - ) not in { - _type_utils.JitScalarType.UINT8, - _type_utils.JitScalarType.INT8, - }: - zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8) - output = g.op( - "QuantizeLinear", - tensor, - scale, - zero_point, - axis_i=_get_const(axis, "i", "axis"), - ) - args = [output, scale, zero_point] - if axis is not None and not _is_none(axis): - args.append(axis) - return g.op("prim::TupleConstruct", *args) - - -def requantize_bias_helper( - g: jit_utils.GraphContext, bias, input_scale, weight_scale, axis=None -): - """In PyTorch, bias is float and is quantized to int32 implicitly inside the quantized ATen op kernel. - In ONNX we need to make the quantization explicit because operators expect all of their inputs to be quantized. - Since int32 is not a supported output type by ONNX operator `QuantizeLinear`, quantization is exported using - regular operators. 
- """ - bias_scale = g.op("Mul", weight_scale, input_scale) - bias_scale_shape = g.op("Shape", bias_scale) - bias_zero_point = g.op( - "ConstantOfShape", bias_scale_shape, value_t=torch.tensor([0], dtype=torch.int) - ) - q_bias = g.op( - "Cast", g.op("Div", bias, bias_scale), to_i=_C_onnx.TensorProtoDataType.INT32 - ) - axis_args = [] - if axis is not None and not _is_none(axis): - axis_args.append(axis) - return g.op("prim::TupleConstruct", q_bias, bias_scale, bias_zero_point, *axis_args) - - -def args_have_same_dtype(args): - assert args - base_dtype = _type_utils.JitScalarType.from_value(args[0]) - has_same_dtype = all( - _type_utils.JitScalarType.from_value(elem) == base_dtype for elem in args - ) - return has_same_dtype - - -def _op_with_optional_float_cast(g: jit_utils.GraphContext, op_name, *args, **kwargs): - """Some PyTorch operators (e.g., Clip/Min/ReLU/Pad) are super set of ONNX in terms of data types. - This function maximizes the exportability of PyTorch-ONNX by allowing ONNX-unsupported PyTorch - operator data type. For example, `Cast(Clip(Cast(INPUT)))` can be used to mimic - `Clip(INPUT)` (opset version < 12). - - Args: - g (torch._C.Graph): graph to write the ONNX representation into. - op_name (str): operator name in ONNX. - *args (tuple): operands to the operator. - **kwargs (dict): attributes to the operator along with "opset_before" (optional, None by default) - indicating the smallest opset version to trigger such casting behavior and "target_float_t" - (optional, torch.onnx.JitScalarType.FLOAT by default) indicating the data type of internal operator. - - Returns: - Optional[torch._C.Value, Tuple[torch._C.Value, ...]]: output(s) of the operator. - """ - opset_before = kwargs.pop("opset_before", None) - target_float_t = kwargs.pop("target_float_t", _type_utils.JitScalarType.FLOAT) - - inputs = list(args) - dtype_0 = _type_utils.JitScalarType.from_value(inputs[0]) - - require_cast = not _is_fp(inputs[0]) and ( - opset_before is None or GLOBALS.export_onnx_opset_version < opset_before - ) - - if require_cast: - for input in inputs: - if input.isCompleteTensor(): - input_scalar_type = _type_utils.JitScalarType.from_value(input) - if input_scalar_type != dtype_0: - raise errors.SymbolicValueError( - f"Inputs of {op_name} must have same dtype." 
- f"Got {dtype_0.scalar_name()} and {input_scalar_type.scalar_name()}", - input, - ) - for i, input in enumerate(inputs): - if input.isCompleteTensor() and not _is_fp(input): - inputs[i] = g.op( - "Cast", - input, - to_i=target_float_t.onnx_type(), - ) - - self = g.op(op_name, *inputs, **kwargs) - - if require_cast: - self = g.op("Cast", self, to_i=dtype_0.onnx_type()) - - return self - - -def _maybe_cast_reduce_op_input(g: jit_utils.GraphContext, self): - scalar_type = _type_utils.JitScalarType.from_value( - self, _type_utils.JitScalarType.UNDEFINED - ) - if scalar_type != _type_utils.JitScalarType.UNDEFINED: - # This check only covers traced modules where dtype is present - # pytorch reduce-ops cast all other integral types to int64 - if not _is_fp(self) and scalar_type != _type_utils.JitScalarType.INT64: - self = g.op("Cast", self, to_i=_C_onnx.TensorProtoDataType.INT64) - return self - - -def _apply_params(*args, **kwargs): - """Returns a decorator that calls the decorated (higher-order) function with the given parameters.""" - - def _apply(fn): - return fn(*args, **kwargs) - - return _apply - - -def _reduce_op_symbolic_helper(onnx_op_name, allow_multi_dim_support=True): - def symbolic(g, self, dim=None, keepdim=None): - self = _maybe_cast_reduce_op_input(g, self) - if dim is None or dim == (): - # Dim can be 0, which will cause (not dim) == True. So we don't want to do - # (not dim) - # all-reduce path - return _handle_reduce_dim_none(g, self, onnx_op_name) - else: - # dim-reduce path - keepdim = _get_const(keepdim, "i", "keepdim") - if g.opset < 18: - desc = "is" if allow_multi_dim_support else "i" - dim = _get_const(dim, desc, "dim") - dim_list = dim if allow_multi_dim_support else [dim] - return g.op(onnx_op_name, self, axes_i=dim_list, keepdims_i=keepdim) - else: - if _is_value(dim): - axes = dim - else: - if allow_multi_dim_support: - axes = g.op( - "Constant", value_t=torch.tensor(dim, dtype=torch.long) - ) - else: - axes = g.op( - "Constant", value_t=torch.tensor([dim], dtype=torch.long) - ) - return g.op(onnx_op_name, self, axes, keepdims_i=keepdim) - - return symbolic - - -def _overload_by_arg_count(fn): - @functools.wraps(fn) - def wrapper(g, *args): - overloads = fn(g, *args) - for overload in overloads: - arg_descriptors = overload._arg_descriptors - if len(arg_descriptors) == len(args): - return overload(g, *args) - return _unimplemented(f"aten::{fn.__name__}", f"with {len(args)} arguments") - - return wrapper - - -def _reduce_with_dtype_helper( - onnx_op: str, name: str, allow_multi_dim_support: bool = True -): - symbolic = _reduce_op_symbolic_helper( - onnx_op, allow_multi_dim_support=allow_multi_dim_support - ) - - @_overload_by_arg_count - def reduce(g, *args, **kwargs): - @quantized_args(True) - @parse_args("v", "none") - def reduce_nodim(g, self, dtype): - dtype_onnx = None - if dtype.node().kind() == "onnx::Constant": - dtype = _get_const(dtype, "i", "dtype") - dtype_onnx = _type_utils.JitScalarType(dtype).onnx_type() - self = g.op("Cast", self, to_i=dtype_onnx) - elif dtype.node().kind() != "prim::Constant": - return _unimplemented(name, "dtype", dtype) - result = symbolic(g, self) - if dtype_onnx is not None: - result_dtype_onnx = _type_utils.JitScalarType.from_value( - result - ).onnx_type() - if result_dtype_onnx != dtype_onnx: - result = g.op("Cast", result, to_i=dtype_onnx) - return result - - dim_desc = "is" if allow_multi_dim_support else "i" - - @quantized_args(True) - @parse_args("v", dim_desc, "i", "none") # type: ignore[arg-type] - def reduce_dim(g, 
self, dim, keepdim, dtype): - dtype_onnx = None - if dtype.node().kind() == "onnx::Constant": - dtype = _get_const(dtype, "i", "dtype") - dtype_onnx = _type_utils.JitScalarType(dtype).onnx_type() - self = g.op("Cast", self, to_i=dtype_onnx) - elif dtype.node().kind() != "prim::Constant": - return _unimplemented(name, "dtype", dtype) - result = symbolic(g, self, dim, keepdim) - if dtype_onnx is not None: - result_dtype_onnx = _type_utils.JitScalarType.from_value( - result - ).onnx_type() - if result_dtype_onnx != dtype_onnx: - result = g.op("Cast", result, to_i=dtype_onnx) - return result - - return reduce_nodim, reduce_dim - - return reduce - - -def _max_helper(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None): - # torch.max(input) - if dim_or_y is None and keepdim is None: - return g.op("ReduceMax", self, keepdims_i=0) - # torch.max(input, other) - if keepdim is None: - return _op_with_optional_float_cast(g, "Max", self, dim_or_y, opset_before=12) - # torch.max(input, dim, keepdim) - else: - keepdim = _get_const(keepdim, "i", "keepdim") - dim = _get_const(dim_or_y, "i", "dim") - if g.opset < 18: - max = g.op("ReduceMax", self, axes_i=[dim], keepdims_i=keepdim) - else: - axes = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long)) - max = g.op("ReduceMax", self, axes, keepdims_i=keepdim) - indices = g.op("ArgMax", self, axis_i=dim, keepdims_i=keepdim) - return max, indices - - -def _min_helper(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None): - # torch.min(input) - if dim_or_y is None and keepdim is None: - return g.op("ReduceMin", self, keepdims_i=0) - # torch.min(input, other) - if keepdim is None: - return _op_with_optional_float_cast(g, "Min", self, dim_or_y, opset_before=12) - # torch.min(input, dim, keepdim) - else: - keepdim = _get_const(keepdim, "i", "keepdim") - dim = _get_const(dim_or_y, "i", "dim") - if g.opset < 18: - min = g.op("ReduceMin", self, axes_i=[dim], keepdims_i=keepdim) - else: - axes = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long)) - min = g.op("ReduceMin", self, axes, keepdims_i=keepdim) - indices = g.op("ArgMin", self, axis_i=dim, keepdims_i=keepdim) - return min, indices - - -def _numel_helper(g: jit_utils.GraphContext, self): - shape = g.op("Shape", self) - return g.op("ReduceProd", shape, keepdims_i=0) - - -@parse_args("v", "is", "i", "i") -def _var_mean_helper(g: jit_utils.GraphContext, input, dim, correction, keepdim): - if g.opset < 18: - if dim is None: - mean = g.op("ReduceMean", input, keepdims_i=0) - t_mean = mean - num_elements = _numel_helper(g, input) - else: - mean = g.op("ReduceMean", input, axes_i=dim, keepdims_i=keepdim) - t_mean = g.op("ReduceMean", input, axes_i=dim, keepdims_i=1) - redudced_dims = g.op("Shape", input) - # dim could contain one or multiple dimensions - redudced_dims = g.op( - "Gather", - redudced_dims, - g.op("Constant", value_t=torch.tensor(dim)), - axis_i=0, - ) - num_elements = g.op("ReduceProd", redudced_dims, keepdims_i=0) - sub_v = g.op("Sub", input, t_mean) - sqr_sub = g.op("Mul", sub_v, sub_v) - keepdim_mean = 0 if dim is None else keepdim - var = g.op("ReduceMean", sqr_sub, axes_i=dim, keepdims_i=keepdim_mean) - # Correct bias in calculating variance, by dividing it over (N - correction) instead on N - if correction is None: - correction = 1 - if correction != 0: - num_elements = g.op( - "Cast", num_elements, to_i=_C_onnx.TensorProtoDataType.FLOAT - ) - one = g.op("Constant", value_t=torch.tensor(correction, dtype=torch.float)) - mul = g.op("Mul", var, num_elements) - 
var = g.op("Div", mul, g.op("Sub", num_elements, one)) - return var, mean - else: - axes = None - if dim is None: - mean = g.op("ReduceMean", input, keepdims_i=0) - t_mean = mean - num_elements = _numel_helper(g, input) - else: - axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long)) - mean = g.op("ReduceMean", input, axes, keepdims_i=keepdim) - t_mean = g.op("ReduceMean", input, axes, keepdims_i=1) - redudced_dims = g.op("Shape", input) - # dim could contain one or multiple dimensions - redudced_dims = g.op( - "Gather", - redudced_dims, - g.op("Constant", value_t=torch.tensor(dim)), - axis_i=0, - ) - num_elements = g.op("ReduceProd", redudced_dims, keepdims_i=0) - sub_v = g.op("Sub", input, t_mean) - sqr_sub = g.op("Mul", sub_v, sub_v) - keepdim_mean = 0 if dim is None else keepdim - if axes is None: - var = g.op("ReduceMean", sqr_sub, keepdims_i=keepdim_mean) - else: - var = g.op("ReduceMean", sqr_sub, axes, keepdims_i=keepdim_mean) - # Correct bias in calculating variance, by dividing it over (N - correction) instead on N - if correction is None: - correction = 1 - if correction != 0: - num_elements = g.op( - "Cast", num_elements, to_i=_C_onnx.TensorProtoDataType.FLOAT - ) - one = g.op("Constant", value_t=torch.tensor(correction, dtype=torch.float)) - mul = g.op("Mul", var, num_elements) - var = g.op("Div", mul, g.op("Sub", num_elements, one)) - return var, mean - - -def _embedding_bag_helper( - g: jit_utils.GraphContext, - embedding_matrix, - indices, - offsets, - scale_grad_by_freq, - mode, - sparse, - per_sample_weights, - include_last_offset, - padding_idx, -): - if scale_grad_by_freq and GLOBALS.export_training: - return _onnx_unsupported( - "embedding_bag with scale_grad_by_freq for training mode" - ) - if padding_idx is not None and padding_idx >= 0: - raise RuntimeError("embedding_bag with padding_idx") - - loop_condition = g.op("Constant", value_t=torch.tensor(1)) - loop_condition = g.op("Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL) - zero = g.op("Constant", value_t=torch.tensor([0])) - - indices_len = _unsqueeze_helper( - g, - _size_helper(g, indices, g.op("Constant", value_t=torch.tensor(0))), - [0], - ) - if not include_last_offset: - offsets = [offsets, indices_len] - offsets = g.op("Concat", *offsets, axis_i=0) - - # Offsets holds the starting index position of each bag. So we create a list of the indices slices (determined by - # offsets) and gather those indices in indices_row. Then we use this subset of indices to gather from embeddings. - # The embeddings output is a loop scan output, so we can avoid creating a sequence and inserting elements in. 
- offsets_starts = _slice_helper( - g, offsets, axes=[0], starts=[0], ends=[sys.maxsize], steps=[1] - ) - offsets_ends = _slice_helper( - g, offsets, axes=[0], starts=[1], ends=[sys.maxsize], steps=[1] - ) - - loop_len = _size_helper(g, offsets_ends, g.op("Constant", value_t=torch.tensor(0))) - - loop, (loop_context,), _ = jit_utils.add_op_with_blocks( - g, "Loop", loop_len, loop_condition, n_blocks=1 - ) - loop_block = loop_context.block - - # FIXME(justinchuby): We need to handle what happens when we call b.op on a node return - block_input_iter = utils._add_input_to_block(loop_block) - utils._add_input_to_block(loop_block) - - indices_start = loop_context.op( - "Gather", offsets_starts, block_input_iter, axis_i=0 - ) - indices_end = loop_context.op("Gather", offsets_ends, block_input_iter, axis_i=0) - indices_start = _unsqueeze_helper(loop_context, indices_start, [0]) - indices_end = _unsqueeze_helper(loop_context, indices_end, [0]) - - indices_row = loop_context.op("Slice", indices, indices_start, indices_end, zero) - embeddings = loop_context.op("Gather", embedding_matrix, indices_row, axis_i=0) - if not _is_none(per_sample_weights): - per_sample_weights_row = loop_context.op( - "Slice", per_sample_weights, indices_start, indices_end, zero - ) - per_sample_weights_row = _unsqueeze_helper( - loop_context, per_sample_weights_row, [1] - ) - embeddings = loop_context.op("Mul", embeddings, per_sample_weights_row) - if mode == 0: - embeddings = _reducesum_helper( - loop_context, embeddings, axes_i=[0], keepdims_i=0 - ) - elif mode == 1: - if loop_context.opset < 18: - embeddings = loop_context.op( - "ReduceMean", embeddings, axes_i=[0], keepdims_i=0 - ) - else: - axes = loop_context.op( - "Constant", value_t=torch.tensor([0], dtype=torch.long) - ) - embeddings = loop_context.op("ReduceMean", embeddings, axes, keepdims_i=0) - else: - if loop_context.opset < 18: - embeddings = loop_context.op( - "ReduceMax", embeddings, axes_i=[0], keepdims_i=0 - ) - else: - axes = loop_context.op( - "Constant", value_t=torch.tensor([0], dtype=torch.long) - ) - embeddings = loop_context.op("ReduceMax", embeddings, axes, keepdims_i=0) - - cond_out = loop_context.op( - "Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL - ) - utils._add_output_to_block(loop_block, cond_out) - utils._add_output_to_block(loop_block, embeddings) - - # aten::embedding_bag returns a tuple of 4 elements: output, offset2bag, bag_size, max_indices. - # But the last three outputs are not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag. 
- return loop.node().output(), None, None, None - - -def _linalg_vector_norm_helper( - g: jit_utils.GraphContext, - self: torch._C.Value, - ord: float, - dim: Sequence[int] | None, - keepdim: bool, - dtype: torch._C.Value, -): - axes = None - # Conditions based on https://pytorch.org/docs/stable/generated/torch.linalg.vector_norm.html - if _is_none(dim): - self = _reshape_helper(g, self, [-1]) - keepdim = False - elif g.opset >= 18: - axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long)) - - if ord == math.inf: - if g.opset < 18: - result = g.op( - "ReduceMax", g.op("Abs", self), axes_i=dim, keepdims_i=keepdim - ) - else: - if axes is None: - result = g.op("ReduceMax", g.op("Abs", self), keepdims_i=keepdim) - else: - result = g.op("ReduceMax", g.op("Abs", self), axes, keepdims_i=keepdim) - elif ord == -math.inf: - if g.opset < 18: - result = g.op( - "ReduceMin", g.op("Abs", self), axes_i=dim, keepdims_i=keepdim - ) - else: - if axes is None: - result = g.op("ReduceMin", g.op("Abs", self), keepdims_i=keepdim) - else: - result = g.op("ReduceMin", g.op("Abs", self), axes, keepdims_i=keepdim) - elif ord == 0: - if g.opset < 11: - return _onnx_opset_unsupported_detailed( - "linalg_vector_norm", 9, 11, "ord=0 not supported", self - ) - else: - if dim is None: - self = _reshape_helper( - g, - self, - g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)), - ) - keepdim = False - - cond_op = g.op( - "Not", - g.op("Equal", self, g.op("Constant", value_t=torch.LongTensor([0]))), - ) - cond_op = g.op( - "Cast", - cond_op, - to_i=_type_utils.JitScalarType.from_value(self).onnx_type(), - ) - return _reducesum_helper(g, cond_op, axes_i=dim, keepdims_i=keepdim) - elif ord == 1: - if g.opset < 18: - result = _reduce_op_symbolic_helper("ReduceL1")( - g, self, dim=dim, keepdim=keepdim - ) - else: - if axes is None: - result = _reduce_op_symbolic_helper("ReduceL1")( - g, self, keepdim=keepdim - ) - else: - result = _reduce_op_symbolic_helper("ReduceL1")( - g, self, axes, keepdim=keepdim - ) - elif ord == 2: - if g.opset < 18: - result = _reduce_op_symbolic_helper("ReduceL2")( - g, self, dim=dim, keepdim=keepdim - ) - else: - if axes is None: - result = _reduce_op_symbolic_helper("ReduceL2")( - g, self, keepdim=keepdim - ) - else: - result = _reduce_op_symbolic_helper("ReduceL2")( - g, self, axes, keepdim=keepdim - ) - else: - ord_op = g.op("Constant", value_t=torch.tensor(ord, dtype=torch.float32)) - result = _reducesum_helper( - g, g.op("Pow", g.op("Abs", self), ord_op), axes_i=dim, keepdims_i=keepdim - ) - result = g.op( - "Pow", - result, - g.op( - "Div", - g.op("Constant", value_t=torch.tensor(1, dtype=torch.float32)), - ord_op, - ), - ) - - if not _is_none(dtype): - dtype = _get_const(dtype, "i", "dtype") - result = g.op("Cast", result, to_i=_type_utils.JitScalarType(dtype).onnx_type()) # type: ignore[arg-type] - return result - - -# Deprecated. 
Internally use _type_utils.ScalarType -# TODO: remove these once we support Type's in the JIT IR and we can once again -# use the unified toType operator -cast_pytorch_to_onnx = { - "Byte": _C_onnx.TensorProtoDataType.UINT8, - "Char": _C_onnx.TensorProtoDataType.INT8, - "Double": _C_onnx.TensorProtoDataType.DOUBLE, - "Float": _C_onnx.TensorProtoDataType.FLOAT, - "Half": _C_onnx.TensorProtoDataType.FLOAT16, - "Int": _C_onnx.TensorProtoDataType.INT32, - "Long": _C_onnx.TensorProtoDataType.INT64, - "Short": _C_onnx.TensorProtoDataType.INT16, - "Bool": _C_onnx.TensorProtoDataType.BOOL, - "ComplexFloat": _C_onnx.TensorProtoDataType.COMPLEX64, - "ComplexDouble": _C_onnx.TensorProtoDataType.COMPLEX128, - "BFloat16": _C_onnx.TensorProtoDataType.BFLOAT16, - "Undefined": _C_onnx.TensorProtoDataType.UNDEFINED, -} - -# Deprecated. Internally use _type_utils.ScalarType -scalar_name_to_pytorch = { - "uint8_t": "Byte", - "int8_t": "Char", - "double": "Double", - "float": "Float", - "half": "Half", - "int": "Int", - "int64_t": "Long", - "int16_t": "Short", - "bool": "Bool", - "complex64": "ComplexFloat", - "complex128": "ComplexDouble", - "qint8": "QInt8", - "quint8": "QUInt8", - "qint32": "QInt32", - "bfloat16": "BFloat16", -} - - -# Deprecated. Internally use _type_utils.ScalarType -# This indicates each scalar type's corresponding -# torch type. Related source: -# https://github.com/pytorch/pytorch/blob/344defc9733a45fee8d0c4d3f5530f631e823196/c10/core/ScalarType.h -scalar_type_to_pytorch_type = [ - torch.uint8, # 0 - torch.int8, # 1 - torch.short, # 2 - torch.int, # 3 - torch.int64, # 4 - torch.half, # 5 - torch.float, # 6 - torch.double, # 7 - torch.complex32, # 8 - torch.complex64, # 9 - torch.complex128, # 10 - torch.bool, # 11 - torch.qint8, # 12 - torch.quint8, # 13 - torch.qint32, # 14 - torch.bfloat16, # 15 -] - -# Deprecated. Internally use _type_utils.ScalarType -# source of truth is -# https://github.com/pytorch/pytorch/blob/master/torch/csrc/utils/tensor_dtypes.cpp -pytorch_name_to_type = { - "Byte": torch.uint8, - "Char": torch.int8, - "Double": torch.double, - "Float": torch.float, - "Half": torch.half, - "Int": torch.int, - "Long": torch.int64, - "Short": torch.short, - "Bool": torch.bool, - "ComplexFloat": torch.complex64, - "ComplexDouble": torch.complex128, - "QInt8": torch.qint8, - "QUInt8": torch.quint8, - "QInt32": torch.qint32, - "BFloat16": torch.bfloat16, -} - - -# Deprecated. Internally use _type_utils.ScalarType -scalar_type_to_onnx = [ - cast_pytorch_to_onnx["Byte"], # 0 - cast_pytorch_to_onnx["Char"], # 1 - cast_pytorch_to_onnx["Short"], # 2 - cast_pytorch_to_onnx["Int"], # 3 - cast_pytorch_to_onnx["Long"], # 4 - cast_pytorch_to_onnx["Half"], # 5 - cast_pytorch_to_onnx["Float"], # 6 - cast_pytorch_to_onnx["Double"], # 7 - cast_pytorch_to_onnx["Undefined"], # 8 - cast_pytorch_to_onnx["ComplexFloat"], # 9 - cast_pytorch_to_onnx["ComplexDouble"], # 10 - cast_pytorch_to_onnx["Bool"], # 11 - cast_pytorch_to_onnx["Char"], # 12 - cast_pytorch_to_onnx["Byte"], # 13 - cast_pytorch_to_onnx["Int"], # 14 - cast_pytorch_to_onnx["BFloat16"], # 15 -] - -# Global set to store the list of quantized operators in the network. -# This is currently only used in the conversion of quantized ops from PT -> C2 via ONNX. 
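# The deprecated tables above are superseded by JitScalarType, as their comments
# note. A minimal usage sketch, assuming the pre-refactor import path for
# _type_utils (that module may also be relocated by this change):
import torch
import torch._C._onnx as _C_onnx
from torch.onnx._type_utils import JitScalarType

assert JitScalarType.from_dtype(torch.half).onnx_type() == _C_onnx.TensorProtoDataType.FLOAT16
assert JitScalarType.from_dtype(torch.bfloat16).dtype() == torch.bfloat16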
-_quantized_ops: set[int] = set() +from torch.onnx._internal.torchscript_exporter.symbolic_helper import * # noqa: F401,F403 diff --git a/torch/onnx/symbolic_opset10.py b/torch/onnx/symbolic_opset10.py index 0b8e2478ce33..9bda69b81ab6 100644 --- a/torch/onnx/symbolic_opset10.py +++ b/torch/onnx/symbolic_opset10.py @@ -1,1190 +1,11 @@ -# mypy: allow-untyped-defs -# mypy: disable-error-code=arg-type +"""Backward compatibility module for torch.onnx.symbolic_opset10.""" + from __future__ import annotations -import functools -import sys -import warnings -from typing import TYPE_CHECKING -import torch -import torch._C._onnx as _C_onnx -import torch.onnx -from torch import _C +__all__: list[str] = [] -# Monkey-patch graph manipulation methods on Graph, used for the ONNX symbolics -from torch.onnx import ( - _constants, - _type_utils, - errors, - symbolic_helper, - symbolic_opset9 as opset9, +from torch.onnx._internal.torchscript_exporter.symbolic_opset10 import * # noqa: F401,F403 +from torch.onnx._internal.torchscript_exporter.symbolic_opset10 import ( # noqa: F401 + _slice, ) -from torch.onnx._globals import GLOBALS -from torch.onnx._internal import jit_utils, registration - - -if TYPE_CHECKING: - from collections.abc import Sequence - - -# EDITING THIS FILE? READ THIS FIRST! -# see Note [Edit Symbolic Files] in README.md - -# This file exports ONNX ops for opset 10 -# Opset 10 is supported by ONNX release 1.5.0 -# release on 04/24/19 - - -__all__ = [ - "dequantize", - "div", - "embedding_bag", - "fake_quantize_per_tensor_affine", - "flip", - "fmod", - "isfinite", - "isinf", - "nan_to_num", - "quantize_per_tensor", - "quantized_add_relu", - "quantized_add", - "quantized_cat", - "quantized_conv1d_relu", - "quantized_conv2d_relu", - "quantized_conv3d_relu", - "quantized_conv1d", - "quantized_conv2d", - "quantized_conv3d", - "quantized_conv_transpose1d", - "quantized_conv_transpose2d", - "quantized_conv_transpose3d", - "quantized_group_norm", - "quantized_hardswish", - "quantized_instance_norm", - "quantized_layer_norm", - "quantized_leaky_relu", - "quantized_linear", - "quantized_linear_relu", - "quantized_mul", - "quantized_sigmoid", - "slice", - "sort", - "topk", -] - - -_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=10) - - -@_onnx_symbolic("aten::div") -def div(g: jit_utils.GraphContext, self, other, *args): - if len(args) == 0: - return opset9.true_divide(g, self, other) - else: - return _div_rounding_mode(g, self, other, *args) - - -@symbolic_helper.parse_args("v", "v", "s") -def _div_rounding_mode(g: jit_utils.GraphContext, self, other, rounding_mode): - if rounding_mode == "floor": - return _floor_divide(g, self, other) - else: - return opset9._div_rounding_mode(g, self, other, rounding_mode) - - -@_onnx_symbolic("aten::_floor_divide") -def _floor_divide(g: jit_utils.GraphContext, self, other): - if symbolic_helper._is_fp(self) or symbolic_helper._is_fp(other): - out = opset9.true_divide(g, self, other) - return g.op("Floor", out) - else: - # Integer division does truncation rounding - div = g.op("Div", self, other) - # Division is negative if: self < 0 != other < 0 - zero = g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64)) - negative = g.op("Xor", g.op("Less", self, zero), g.op("Less", other, zero)) - - # For negative numbers with self % other != 0, subtract 1 to round down instead of up - mod = g.op("Mod", self, other, fmod_i=0) - fixup_mask = g.op("And", negative, g.op("Not", g.op("Equal", mod, zero))) - - one = g.op("Constant", value_t=torch.tensor(1, 
dtype=torch.int64)) - fixup = g.op("Sub", div, one) - return g.op("Where", fixup_mask, fixup, div) - - -@_onnx_symbolic("aten::sort") -@symbolic_helper.parse_args("v", "i", "i", "none") -def sort(g: jit_utils.GraphContext, self, dim, descending, out=None): - return symbolic_helper._sort_helper(g, self, dim, descending=descending, out=out) - - -@_onnx_symbolic("aten::topk") -@symbolic_helper.parse_args("v", "v", "i", "i", "i", "none") -def topk(g: jit_utils.GraphContext, self, k, dim, largest, sorted, out=None): - return symbolic_helper._topk_helper( - g, self, k, dim, largest=largest, sorted=sorted, out=out - ) - - -def _aten_max_pool_onnx( - g: jit_utils.GraphContext, - self: _C.Value, - kernel_shape: Sequence[int], - strides: Sequence[int], - pads: Sequence[int], - dilations: Sequence[int], - ceil_mode: bool, - unbatched_rank: int, -) -> _C.Value: - self_rank = g.op("Size", g.op("Shape", self)) - if self_rank == unbatched_rank: # C,H,W -> N,C,H,W and N=1 - self = g.op( - "Unsqueeze", - self, - g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)), - ) - - pool_result, _ = g.op( - "MaxPool", - self, - outputs=2, - ceil_mode_i=ceil_mode, - dilations_i=dilations, - kernel_shape_i=kernel_shape, - pads_i=pads, - strides_i=strides, - ) - - if self_rank == unbatched_rank: - pool_result = g.op( - "Squeeze", - pool_result, - g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)), - ) - - return pool_result - - -# For MaxPool -def _adjust_attributes_of_max_pool( - expand_size: int, - kernel_size: Sequence[int] | int, - stride: Sequence[int] | int, - padding: Sequence[int] | int, - dilation: Sequence[int] | int, -) -> tuple[Sequence[int], Sequence[int], Sequence[int], Sequence[int]]: - """Adjust attributes of avg_pool to match ONNX specification.""" - - if isinstance(dilation, int): - dilation = [dilation] * expand_size - - if isinstance(kernel_size, int): - kernel_shape = [kernel_size] * expand_size - else: - kernel_shape = kernel_size # type: ignore[assignment] - - if isinstance(padding, int): - pads = [padding] * expand_size * 2 # type: ignore[operator, assignment] - elif len(padding) == 1: - pads = padding * expand_size * 2 # type: ignore[operator, assignment] - elif len(padding) == 2: - # 2D padding - pads = padding * 2 # type: ignore[operator, assignment] - elif len(padding) == 3: - # 3D padding - pads = padding * 2 # type: ignore[operator, assignment] - else: - # When padding is already done for all dimensions, - # we don't need to double it - # eg: (1, 1, 1, 1, 1, 1) - pads = padding # type: ignore[assignment] - - if isinstance(stride, int): - strides = [stride] * expand_size - elif not stride: - strides = kernel_shape - else: - strides = stride # type: ignore[assignment] - - return (kernel_shape, strides, pads, dilation) - - -def _aten_max_pool_with_indices_onnx( - g: jit_utils.GraphContext, - self: _C.Value, - kernel_shape: Sequence[int], - strides: Sequence[int], - pads: Sequence[int], - dilations: Sequence[int], - ceil_mode: bool, - unbatched_rank: int, - n_dims_one: Sequence[int], - n_dims_zero: Sequence[int], - n_dims_axes: Sequence[int], -) -> tuple[_C.Value, Sequence[int]]: - self_rank = g.op("Size", g.op("Shape", self)) - if self_rank == unbatched_rank: # C,H,W -> N,C,H,W and N=1 - self = g.op( - "Unsqueeze", - self, - g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)), - ) - - pool_result, indices = g.op( - "MaxPool", - self, - outputs=2, - ceil_mode_i=ceil_mode, - dilations_i=dilations, - kernel_shape_i=kernel_shape, - pads_i=pads, - 
strides_i=strides, - ) - _, flatten_indices = g.op( - "MaxPool", - self, - outputs=2, - dilations_i=dilations, - kernel_shape_i=n_dims_one, - strides_i=n_dims_one, - ) - - ends = g.op("Constant", value_t=torch.tensor(n_dims_one)) - starts = g.op("Constant", value_t=torch.tensor(n_dims_zero)) - axes = g.op("Constant", value_t=torch.tensor(n_dims_axes)) - - delta = g.op("Slice", flatten_indices, starts, ends, axes) - indices = g.op("Sub", indices, delta) - - if self_rank == unbatched_rank: - pool_result = g.op( - "Squeeze", pool_result, value_t=torch.tensor([0], dtype=torch.int64) - ) - indices = g.op("Squeeze", indices, value_t=torch.tensor([0], dtype=torch.int64)) - - return (pool_result, indices) - - -@_onnx_symbolic( - "aten::max_pool1d", - decorate=[symbolic_helper._apply_params("max_pool1d", 1, return_indices=False)], -) -@_onnx_symbolic( - "aten::max_pool2d", - decorate=[symbolic_helper._apply_params("max_pool2d", 2, return_indices=False)], -) -@_onnx_symbolic( - "aten::max_pool3d", - decorate=[symbolic_helper._apply_params("max_pool3d", 3, return_indices=False)], -) -@_onnx_symbolic( - "aten::max_pool1d_with_indices", - decorate=[ - symbolic_helper._apply_params( - "max_pool1d_with_indices", - 1, - return_indices=True, - ) - ], -) -@_onnx_symbolic( - "aten::max_pool2d_with_indices", - decorate=[ - symbolic_helper._apply_params( - "max_pool2d_with_indices", - 2, - return_indices=True, - ) - ], -) -@_onnx_symbolic( - "aten::max_pool3d_with_indices", - decorate=[ - symbolic_helper._apply_params( - "max_pool3d_with_indices", - 3, - return_indices=True, - ) - ], -) -def _max_pool(name: str, expand_size: int, return_indices: bool): - @symbolic_helper.quantized_args(True, False, False, False, False, False) - @symbolic_helper.parse_args("v", "is", "is", "is", "is", "i") - def symbolic_fn( - g: jit_utils.GraphContext, - input: _C.Value, - kernel_size: Sequence[int], - stride: Sequence[int], - padding: int | Sequence[int], - dilation: Sequence[int], - ceil_mode: bool, - ): - kernel_shape, strides, pads, dilations = _adjust_attributes_of_max_pool( - expand_size, kernel_size, stride, padding, dilation - ) - - if return_indices: - return _aten_max_pool_with_indices_onnx( - g, - input, - kernel_shape, - strides, - pads, - dilations, - ceil_mode, - expand_size + 1, - ([1] * expand_size), - ([0] * expand_size), - ([2 + i for i in range(expand_size)]), - ) - else: - return _aten_max_pool_onnx( - g, - input, - kernel_shape, - strides, - pads, - dilations, - ceil_mode, - expand_size + 1, - ) - - return symbolic_fn - - -# For AvgPool -def _adjust_attributes_of_avg_pool( - expand_size: int, - kernel_size: Sequence[int] | int, - stride: Sequence[int] | int, - padding: Sequence[int] | int, -) -> tuple[Sequence[int], Sequence[int], Sequence[int]]: - """Adjust attributes of avg_pool to match ONNX specification.""" - - if isinstance(kernel_size, int): - kernel_shape = [kernel_size] * expand_size - else: - kernel_shape = kernel_size # type: ignore[assignment] - - if isinstance(padding, int): - pads = [padding] * expand_size * 2 - elif len(padding) == 1: - pads = padding * expand_size * 2 # type: ignore[operator, assignment] - elif len(padding) == 2: - pads = padding * expand_size # type: ignore[operator, assignment] - else: - pads = padding * 2 # type: ignore[operator, assignment] - - if isinstance(stride, int): - strides = [stride] * expand_size - elif not stride: - strides = kernel_shape - else: - strides = stride # type: ignore[assignment] - - return (kernel_shape, strides, pads) - - -@_onnx_symbolic( - 
"aten::avg_pool1d", - decorate=[symbolic_helper._apply_params("avg_pool1d", 1)], -) -@_onnx_symbolic( - "aten::avg_pool2d", - decorate=[symbolic_helper._apply_params("avg_pool2d", 2)], -) -@_onnx_symbolic( - "aten::avg_pool3d", - decorate=[symbolic_helper._apply_params("avg_pool3d", 3)], -) -def _avg_pool(name, expand_size): - @symbolic_helper.quantized_args(True, False, False, False, False, False, False) - @symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none") - def symbolic_fn( - g, - input: _C.Value, - kernel_size: Sequence[int], - stride: Sequence[int], - padding: int | Sequence[int], - ceil_mode: int, - count_include_pad: int, - divisor_override=None, - ): - kernel_shape, strides, pads = _adjust_attributes_of_avg_pool( - expand_size, kernel_size, stride, padding - ) - - result = g.op( - "AveragePool", - input, - ceil_mode_i=ceil_mode, - count_include_pad_i=count_include_pad, - kernel_shape_i=kernel_shape, - pads_i=pads, - strides_i=strides, - ) - - return result - - return symbolic_fn - - -@_onnx_symbolic( - "aten::upsample_nearest1d", - decorate=[symbolic_helper._apply_params("upsample_nearest1d", 3, "nearest")], -) -@_onnx_symbolic( - "aten::upsample_nearest2d", - decorate=[symbolic_helper._apply_params("upsample_nearest2d", 4, "nearest")], -) -@_onnx_symbolic( - "aten::upsample_nearest3d", - decorate=[symbolic_helper._apply_params("upsample_nearest3d", 5, "nearest")], -) -@_onnx_symbolic( - "aten::upsample_linear1d", - decorate=[symbolic_helper._apply_params("upsample_linear1d", 3, "linear")], -) -@_onnx_symbolic( - "aten::upsample_bilinear2d", - decorate=[symbolic_helper._apply_params("upsample_bilinear2d", 4, "linear")], -) -@_onnx_symbolic( - "aten::upsample_trilinear3d", - decorate=[symbolic_helper._apply_params("upsample_trilinear3d", 5, "linear")], -) -def _interpolate(name, dim, interpolate_mode): - @symbolic_helper.quantized_args(True, False, False) - def symbolic_fn(g, input, output_size, *args): - scales, align_corners = symbolic_helper._get_interpolate_attributes( - g, interpolate_mode, args - ) - symbolic_helper._interpolate_warning(interpolate_mode) - align_corners = symbolic_helper._maybe_get_scalar(align_corners) - if align_corners: - return symbolic_helper._unimplemented(name, "align_corners == True", input) - if scales is None: - scales = symbolic_helper._interpolate_size_to_scales( - g, input, output_size, dim - ) - return g.op("Resize", input, scales, mode_s=interpolate_mode) - - return symbolic_fn - - -@_onnx_symbolic("aten::__interpolate") -def __interpolate( - g: jit_utils.GraphContext, - input, - size, - scale_factor, - mode, - align_corners, - recompute_scale_factor, - antialias, -): - scales, mode = symbolic_helper._interpolate_get_scales_and_mode( - g, input, size, scale_factor, mode, align_corners - ) - return g.op("Resize", input, scales, mode_s=mode) - - -def _slice( - g: jit_utils.GraphContext, - input: torch._C.Value, - axes: list | torch.Tensor | torch._C.Value, - starts: list | torch.Tensor | torch._C.Value, - ends: list | torch.Tensor | torch._C.Value, - steps: list | torch.Tensor | torch._C.Value | None = None, -): - def is_none_value(value): - if value is None: - return True - return ( - isinstance(value, torch._C.Value) - and value.node().kind() == "prim::Constant" - and isinstance(value.type(), _C.NoneType) - ) - - def to_slice_input(list_or_value, default_value=None): - # Convert input param into a 1D torch.Value. 
- if is_none_value(list_or_value) and default_value is not None: - list_or_value = [default_value] - - if isinstance(list_or_value, torch.Tensor): - return g.op("Constant", value_t=list_or_value.clone().detach()) - elif isinstance(list_or_value, list): - return g.op("Constant", value_t=torch.tensor(list_or_value)) - - rank = symbolic_helper._get_tensor_rank(list_or_value) - if rank == 0: - return symbolic_helper._unsqueeze_helper(g, list_or_value, [0]) - if rank == 1: - return list_or_value - raise errors.SymbolicValueError( - f"Rank must be 0 or 1, not {rank}", list_or_value - ) - - def get_const_value(list_or_value): - if isinstance(list_or_value, (list, torch.Tensor)): - if len(list_or_value) == 1: - return list_or_value[0] - return None - return symbolic_helper._maybe_get_const(list_or_value, "i") - - # Check if slice is a no-op - if ( - get_const_value(starts) == 0 - and get_const_value(ends) == _constants.INT64_MAX - and (steps is None or get_const_value(steps) == 1) - ): - return input - - axes = to_slice_input(axes) - starts = to_slice_input(starts, default_value=0) - ends = to_slice_input(ends, default_value=_constants.INT64_MAX) - if steps is None: - return g.op("Slice", input, starts, ends, axes) - steps = to_slice_input(steps, default_value=1) - return g.op("Slice", input, starts, ends, axes, steps) - - -@_onnx_symbolic("aten::slice") -def slice(g: jit_utils.GraphContext, self, *args): - if len(args) == 4: - # aten::slice(Tensor self, int dim, int? start=None, int? end=None, int step=1) -> Tensor - dims, start, end, step = args - elif len(args) == 3: - # aten::slice(t[] l, int? start=None, int? end=None, int step=1) -> t[] - start, end, step = args - dims = [0] - else: - raise errors.SymbolicValueError("Unknown aten::slice signature", self) - - return symbolic_helper._slice_helper( - g, - self, - axes=dims, - starts=start, - ends=end, - steps=step, - ) - - -@_onnx_symbolic("aten::flip") -@symbolic_helper.parse_args("v", "is") -def flip(g: jit_utils.GraphContext, input, dims): - return symbolic_helper._slice_helper( - g, - input, - axes=dims, - starts=[-1] * len(dims), - ends=[-_constants.INT64_MAX] * len(dims), - steps=[-1] * len(dims), - ) - - -@_onnx_symbolic("aten::fmod") -def fmod(g: jit_utils.GraphContext, input, other): - return g.op("Mod", input, other, fmod_i=1) - - -@_onnx_symbolic("aten::embedding_bag") -@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i") -def embedding_bag( - g: jit_utils.GraphContext, - embedding_matrix, - indices, - offsets, - scale_grad_by_freq, - mode, - sparse, - per_sample_weights, - include_last_offset, - padding_idx, -): - if scale_grad_by_freq and GLOBALS.export_training: - return symbolic_helper._onnx_unsupported( - "embedding_bag with scale_grad_by_freq for training mode" - ) - if padding_idx is not None and padding_idx >= 0: - raise RuntimeError("embedding_bag with padding_idx") - - warnings.warn( - "Export of embedding_bag with dynamic input/offsets shape is not supported in opset 10. 
" - "Please use opset 11 or higher to export model for dynamic input shape.'" - ) - offsets_dim_0 = symbolic_helper._get_tensor_dim_size(offsets, 0) - if offsets_dim_0 is not None: - if include_last_offset: - offset_len = offsets_dim_0 - 1 - offsets_extended = offsets - else: - offset_len = offsets_dim_0 - offsets_extended = [ - offsets, - g.op("Constant", value_t=torch.tensor([sys.maxsize])), - ] - offsets_extended = g.op("Concat", *offsets_extended, axis_i=0) - list_ = [] - for i in range(offset_len): - start_ = symbolic_helper._unsqueeze_helper( - g, - opset9.select(g, offsets_extended, torch.tensor(0), torch.tensor(i)), - [0], - ) - end_ = symbolic_helper._unsqueeze_helper( - g, - opset9.select( - g, offsets_extended, torch.tensor(0), torch.tensor(i + 1) - ), - [0], - ) - axes_ = g.op("Constant", value_t=torch.tensor([0])) - indices_row = g.op("Slice", indices, start_, end_, axes_) - - embeddings = g.op("Gather", embedding_matrix, indices_row) - if not symbolic_helper._is_none(per_sample_weights): - per_sample_weights_row = g.op( - "Slice", per_sample_weights, start_, end_, axes_ - ) - per_sample_weights_row = symbolic_helper._unsqueeze_helper( - g, per_sample_weights_row, [1] - ) - embeddings = g.op("Mul", embeddings, per_sample_weights_row) - if mode == 0: - embeddings = symbolic_helper._reducesum_helper( - g, embeddings, axes_i=[0], keepdims_i=0 - ) - elif mode == 1: - embeddings = g.op("ReduceMean", embeddings, axes_i=[0], keepdims_i=0) - else: - embeddings = g.op("ReduceMax", embeddings, axes_i=[0], keepdims_i=0) - - embeddings = symbolic_helper._unsqueeze_helper(g, embeddings, [0]) - list_.append(embeddings) - - output = g.op("Concat", *list_, axis_i=0) - # aten::embedding_bag returns a tuple of 4 elements: output, offset2bag, bag_size, max_indices. - # But the last three outputs are not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag. - return output, None, None, None - else: - return symbolic_helper._onnx_unsupported( - "embedding_bag with unknown shape of offsets for opset 10 is not supported. " - "please use opset 11 or higher." - ) - - -@_onnx_symbolic("aten::fake_quantize_per_tensor_affine") -@symbolic_helper.parse_args("v", "v", "v", "i", "i") -def fake_quantize_per_tensor_affine( - g: jit_utils.GraphContext, - inputs, - scale, - zero_point, - quant_min=-128, - quant_max=127, -): - # NOTE: (0, 127) is a special case. PyTorch restricts activations to be in the range (0, 127). - # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422 - if (quant_min, quant_max) == (0, 127): - symbolic_helper._onnx_opset_unsupported_detailed( - "fake_quantize_per_tensor_affine", - 10, - 13, - "Quantize range (0, 127) not supported, requires opset 13 Clip", - inputs, - ) - if (quant_min, quant_max) not in [(0, 255), (-128, 127)]: - raise errors.SymbolicValueError( - f"For (quant_min, quant_max), ONNX allows only (0, 255) and (-128, 127). 
" - f"Got ({quant_min}, {quant_max})", - inputs, - ) - scale = symbolic_helper._maybe_get_scalar(scale) - if scale is None: - symbolic_helper._onnx_opset_unsupported_detailed( - "fake_quantize_per_tensor_affine", - 10, - 13, - "Non-constant scale not supported", - inputs, - ) - scale = scale.float().data # Avoid exporter generating double type - if quant_min == 0: - zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8) - else: - zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8) - return g.op( - "DequantizeLinear", - g.op("QuantizeLinear", inputs, scale, zero_point), - scale, - zero_point, - ) - - -@_onnx_symbolic("aten::isinf") -def isinf(g: jit_utils.GraphContext, input): - return g.op("IsInf", g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.DOUBLE)) - - -@_onnx_symbolic("aten::isfinite") -def isfinite(g: jit_utils.GraphContext, input): - inf_node = isinf(g, input) - nan_node = opset9.isnan(g, input) - return opset9.__not_(g, opset9.__or_(g, inf_node, nan_node)) - - -@_onnx_symbolic("aten::quantize_per_tensor") -def quantize_per_tensor(g: jit_utils.GraphContext, input, scale, zero_point, dtype): - dtype = symbolic_helper._get_const(dtype, "i", "dtype") - # TODO(justinchuby): Extract all the cast ops into a helper function. - zero_point = g.op( - "Cast", zero_point, to_i=_type_utils.JitScalarType(dtype).onnx_type() - ) - scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT) - return symbolic_helper.quantize_helper(g, input, scale, zero_point) - - -@_onnx_symbolic("aten::dequantize") -def dequantize(g: jit_utils.GraphContext, input): - return symbolic_helper.dequantize_helper(g, input)[0] - - -@_onnx_symbolic("aten::nan_to_num") -@symbolic_helper.parse_args("v", "f", "f", "f") -def nan_to_num(g: jit_utils.GraphContext, input, nan, posinf, neginf): - # Cannot create a int type tensor with inf/nan values, so we simply - # return the original tensor - if not symbolic_helper._is_fp(input): - return input - input_dtype = _type_utils.JitScalarType.from_value(input).dtype() - if nan is None: - nan = 0.0 - nan_cond = opset9.isnan(g, input) - nan_result = g.op( - "Where", - nan_cond, - g.op("Constant", value_t=torch.tensor([nan], dtype=input_dtype)), - input, - ) - - # For None values of posinf, neginf we use the greatest/lowest finite - # value representable by input's dtype. - finfo = torch.finfo(input_dtype) - if posinf is None: - posinf = finfo.max - posinf_cond = opset9.logical_and( - g, - isinf(g, nan_result), - opset9.gt(g, nan_result, g.op("Constant", value_t=torch.LongTensor([0]))), - ) - nan_posinf_result = g.op( - "Where", - posinf_cond, - g.op("Constant", value_t=torch.tensor([posinf], dtype=input_dtype)), - nan_result, - ) - - if neginf is None: - neginf = finfo.min - neginf_cond = opset9.logical_and( - g, - isinf(g, nan_posinf_result), - opset9.lt( - g, nan_posinf_result, g.op("Constant", value_t=torch.LongTensor([0])) - ), - ) - return g.op( - "Where", - neginf_cond, - g.op("Constant", value_t=torch.tensor([neginf], dtype=input_dtype)), - nan_posinf_result, - ) - - -# Quantized symbolics --------------------------------------------------------- -# https://github.com/pytorch/pytorch/wiki/PyTorch-ONNX-exporter#quantized-model-export -# Support starts from opset 10 because `DequantizeLinear` and `QuantizeLinear` were -# introduced in opset version 10. 
-@_onnx_symbolic("quantized::linear") -def quantized_linear( - g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point -): - input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) - weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) - q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) - bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) - - output = opset9.linear(g, input, weight, bias) - - return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) - - -@_onnx_symbolic("quantized::linear_relu") -def quantized_linear_relu( - g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point -): - input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) - weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) - q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) - bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) - - output = opset9.linear(g, input, weight, bias) - output = opset9.relu(g, output) - - return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) - - -@_onnx_symbolic("quantized::add") -def quantized_add(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point): - x, _, _, _ = symbolic_helper.dequantize_helper(g, x) - y, _, _, _ = symbolic_helper.dequantize_helper(g, y) - - output = opset9.add(g, x, y) - - return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) - - -@_onnx_symbolic("quantized::add_relu") -def quantized_add_relu(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point): - x, _, _, _ = symbolic_helper.dequantize_helper(g, x) - y, _, _, _ = symbolic_helper.dequantize_helper(g, y) - - output = opset9.add(g, x, y) - output = opset9.relu(g, output) - - return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) - - -@_onnx_symbolic("quantized::mul") -def quantized_mul(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point): - x, _, _, _ = symbolic_helper.dequantize_helper(g, x) - y, _, _, _ = symbolic_helper.dequantize_helper(g, y) - - output = opset9.mul(g, x, y) - - return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) - - -@_onnx_symbolic("quantized::hardswish") -def quantized_hardswish(g: jit_utils.GraphContext, x, op_scale, op_zero_point): - x, _, _, _ = symbolic_helper.dequantize_helper(g, x) - - output = opset9.hardswish(g, x) - - return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) - - -@_onnx_symbolic("quantized::sigmoid") -def quantized_sigmoid(g: jit_utils.GraphContext, x, op_scale, op_zero_point): - x, _, _, _ = symbolic_helper.dequantize_helper(g, x) - - output = opset9.sigmoid(g, x) - - return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) - - -@_onnx_symbolic("quantized::leaky_relu") -def quantized_leaky_relu( - g: jit_utils.GraphContext, x, negative_slope, inplace, op_scale, op_zero_point -): - x, _, _, _ = symbolic_helper.dequantize_helper(g, x) - - output = opset9.leaky_relu(g, x, negative_slope, inplace) - - return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) - - -@_onnx_symbolic("quantized::layer_norm") -def quantized_layer_norm( - g: jit_utils.GraphContext, - x, - normalized_shape, - weight, - bias, - eps, - op_scale, - op_zero_point, -): - x, _, _, _ = symbolic_helper.dequantize_helper(g, x) - - output = opset9.layer_norm(g, x, normalized_shape, weight, bias, eps, False) - - return 
symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) - - -@_onnx_symbolic("quantized::group_norm") -def quantized_group_norm( - g: jit_utils.GraphContext, - x, - num_groups, - weight, - bias, - eps, - op_scale, - op_zero_point, -): - x, _, _, _ = symbolic_helper.dequantize_helper(g, x) - - output = opset9.group_norm(g, x, num_groups, weight, bias, eps, False) - - return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) - - -@_onnx_symbolic("quantized::instance_norm") -@symbolic_helper.parse_args("v", "v", "v", "f", "v", "v") -def quantized_instance_norm( - g: jit_utils.GraphContext, - q_input, - weight, - bias, - eps, - op_scale, - op_zero_point, -): - input, _, _, _ = symbolic_helper.dequantize_helper(g, q_input) - - output = opset9.instance_norm( - g, input, weight, bias, None, None, False, 0.0, eps, False - ) - - return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) - - -@_onnx_symbolic("quantized::conv1d_relu") -def quantized_conv1d_relu( - g: jit_utils.GraphContext, - q_input, - q_weight, - bias, - stride, - padding, - dilation, - groups, - op_scale, - op_zero_point, -): - input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) - weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) - q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) - bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) - - output = opset9.conv1d(g, input, weight, bias, stride, padding, dilation, groups) - output = opset9.relu(g, output) - - return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) - - -@_onnx_symbolic("quantized::conv2d_relu") -def quantized_conv2d_relu( - g: jit_utils.GraphContext, - q_input, - q_weight, - bias, - stride, - padding, - dilation, - groups, - op_scale, - op_zero_point, -): - input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) - weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) - q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) - bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) - - output = opset9.conv2d(g, input, weight, bias, stride, padding, dilation, groups) - output = opset9.relu(g, output) - - return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) - - -@_onnx_symbolic("quantized::conv3d_relu") -def quantized_conv3d_relu( - g: jit_utils.GraphContext, - q_input, - q_weight, - bias, - stride, - padding, - dilation, - groups, - op_scale, - op_zero_point, -): - input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) - weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) - q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) - bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) - - output = opset9.conv3d(g, input, weight, bias, stride, padding, dilation, groups) - output = opset9.relu(g, output) - - return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) - - -@_onnx_symbolic("quantized::conv1d") -def quantized_conv1d( - g: jit_utils.GraphContext, - q_input, - q_weight, - bias, - stride, - padding, - dilation, - groups, - op_scale, - op_zero_point, -): - input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) - weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) - q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) - bias, _, _, _ = 
symbolic_helper.dequantize_helper(g, q_bias) - - output = opset9.conv1d(g, input, weight, bias, stride, padding, dilation, groups) - - return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) - - -@_onnx_symbolic("quantized::conv2d") -def quantized_conv2d( - g: jit_utils.GraphContext, - q_input, - q_weight, - bias, - stride, - padding, - dilation, - groups, - op_scale, - op_zero_point, -): - input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) - weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) - q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) - bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) - - output = opset9.conv2d(g, input, weight, bias, stride, padding, dilation, groups) - - return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) - - -@_onnx_symbolic("quantized::conv3d") -def quantized_conv3d( - g: jit_utils.GraphContext, - q_input, - q_weight, - bias, - stride, - padding, - dilation, - groups, - op_scale, - op_zero_point, -): - input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) - weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) - q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) - bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) - - output = opset9.conv3d(g, input, weight, bias, stride, padding, dilation, groups) - - return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) - - -@_onnx_symbolic("quantized::conv_transpose1d") -def quantized_conv_transpose1d( - g: jit_utils.GraphContext, - q_input, - q_weight, - bias, - stride, - padding, - output_padding, - dilation, - groups, - op_scale, - op_zero_point, -): - input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) - weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) - q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) - bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) - - output = opset9.conv_transpose2d( - g, input, weight, bias, stride, padding, output_padding, groups, dilation - ) - - return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) - - -@_onnx_symbolic("quantized::conv_transpose2d") -def quantized_conv_transpose2d( - g: jit_utils.GraphContext, - q_input, - q_weight, - bias, - stride, - padding, - output_padding, - dilation, - groups, - op_scale, - op_zero_point, -): - input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) - weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) - q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) - bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) - - output = opset9.conv_transpose2d( - g, input, weight, bias, stride, padding, output_padding, groups, dilation - ) - - return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) - - -@_onnx_symbolic("quantized::conv_transpose3d") -def quantized_conv_transpose3d( - g: jit_utils.GraphContext, - q_input, - q_weight, - bias, - stride, - padding, - output_padding, - dilation, - groups, - op_scale, - op_zero_point, -): - input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) - weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) - q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) - bias, _, _, _ = symbolic_helper.dequantize_helper(g, 
q_bias) - - output = opset9.conv_transpose3d( - g, input, weight, bias, stride, padding, output_padding, groups, dilation - ) - - return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) - - -@_onnx_symbolic("quantized::cat") -@symbolic_helper.parse_args("v", "i", "v", "v") -def quantized_cat( - g: jit_utils.GraphContext, - q_inputs: _C.Value, - dim: int, - op_scale: _C.Value, - op_zero_point: _C.Value, -) -> _C.Value: - unpacked_inputs = symbolic_helper._unpack_list(q_inputs) - dequantized = [ - symbolic_helper.dequantize_helper(g, input)[0] for input in unpacked_inputs - ] - concatenated = g.op("Concat", *dequantized, axis_i=dim) - return symbolic_helper.quantize_helper(g, concatenated, op_scale, op_zero_point) diff --git a/torch/onnx/symbolic_opset11.py b/torch/onnx/symbolic_opset11.py index 47ed56bcfeac..276ef7209bf6 100644 --- a/torch/onnx/symbolic_opset11.py +++ b/torch/onnx/symbolic_opset11.py @@ -1,1469 +1,8 @@ -# mypy: allow-untyped-defs -# mypy: disable-error-code=arg-type -"""This file exports ONNX ops for opset 11.""" +"""Backward compatibility module for torch.onnx.symbolic_opset11.""" from __future__ import annotations -import functools -import sys -import warnings -from typing import TYPE_CHECKING -import torch -from torch import _C -from torch._C import _onnx as _C_onnx -from torch.onnx import ( - _type_utils, - errors, - symbolic_helper, - symbolic_opset10 as opset10, - symbolic_opset9 as opset9, - utils, -) -from torch.onnx._internal import jit_utils, registration +__all__: list[str] = [] - -if TYPE_CHECKING: - from collections.abc import Sequence - - -# EDITING THIS FILE? READ THIS FIRST! -# see Note [Edit Symbolic Files] in README.md - -__all__ = [ - "add", - "append", - "arange", - "argsort", - "atleast_1d", - "atleast_2d", - "atleast_3d", - "cat", - "chunk", - "clamp_max", - "clamp_min", - "clamp", - "constant_pad_nd", - "cumsum", - "Delete", - "embedding_bag", - "embedding_renorm", - "flatten", - "gather", - "hardtanh", - "hstack", - "im2col", - "index_fill", - "index", - "index_copy", - "index_put", - "insert", - "linalg_det", - "linalg_vector_norm", - "logdet", - "masked_scatter", - "masked_select", - "mm", - "narrow", - "normal", - "pad", - "pixel_shuffle", - "pop", - "prim_constant_chunk", - "reflection_pad", - "relu6", - "remainder", - "replication_pad", - "round", - "scatter", - "select", - "size", - "sort", - "split_with_sizes", - "split", - "squeeze", - "stack", - "topk", - "unbind", - "unique_dim", - "unsqueeze", - "vstack", -] - -_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=11) - - -@_onnx_symbolic("aten::hardtanh") -@symbolic_helper.quantized_args(True) -@symbolic_helper.parse_args("v", "f", "f") -def hardtanh(g: jit_utils.GraphContext, self: _C.Value, min_val: float, max_val: float): - scalar_type = _type_utils.JitScalarType.from_value( - self, _type_utils.JitScalarType.FLOAT - ) - min_val = g.op( - "Constant", - value_t=torch.tensor(min_val, dtype=scalar_type.dtype()), - ) - max_val = g.op( - "Constant", - value_t=torch.tensor(max_val, dtype=scalar_type.dtype()), - ) - return symbolic_helper._op_with_optional_float_cast( - g, "Clip", self, min_val, max_val, opset_before=12 - ) - - -@_onnx_symbolic("aten::clamp") -def clamp(g: jit_utils.GraphContext, self, min, max): - def _cast_if_not_none(tensor, dtype): - if tensor is not None and not symbolic_helper._is_none(tensor): - return g.op( - "Cast", - tensor, - to_i=dtype.onnx_type(), - ) - else: - return tensor - - scalar_type = _type_utils.JitScalarType.from_value( 
- self, _type_utils.JitScalarType.UNDEFINED - ) - if scalar_type != _type_utils.JitScalarType.UNDEFINED: - min = _cast_if_not_none(min, scalar_type) - max = _cast_if_not_none(max, scalar_type) - - if symbolic_helper._is_none(min): - return clamp_max(g, self, max) - elif symbolic_helper._is_none(max): - return clamp_min(g, self, min) - else: - if ( - symbolic_helper._get_tensor_rank(min) == 0 - and symbolic_helper._get_tensor_rank(max) == 0 - ): - return symbolic_helper._op_with_optional_float_cast( - g, "Clip", self, min, max, opset_before=12 - ) - else: - return clamp_max(g, clamp_min(g, self, min), max) - - -@_onnx_symbolic("aten::clamp_min") -@symbolic_helper.parse_args("v", "v") -def clamp_min(g: jit_utils.GraphContext, self, min): - min = g.op("Cast", min, to_i=_type_utils.JitScalarType.from_value(self).onnx_type()) - if symbolic_helper._get_tensor_rank(min) == 0: - max = opset9.unused(g) - return symbolic_helper._op_with_optional_float_cast( - g, "Clip", self, min, max, opset_before=12 - ) - else: - return symbolic_helper._op_with_optional_float_cast( - g, "Max", self, min, opset_before=12 - ) - - -@_onnx_symbolic("aten::clamp_max") -@symbolic_helper.parse_args("v", "v") -def clamp_max(g: jit_utils.GraphContext, self, max): - max = g.op("Cast", max, to_i=_type_utils.JitScalarType.from_value(self).onnx_type()) - if symbolic_helper._get_tensor_rank(max) == 0: - min = opset9.unused(g) - return symbolic_helper._op_with_optional_float_cast( - g, "Clip", self, min, max, opset_before=12 - ) - else: - return symbolic_helper._op_with_optional_float_cast( - g, "Min", self, max, opset_before=12 - ) - - -@_onnx_symbolic("aten::relu6") -def relu6(g: jit_utils.GraphContext, input): - scalar_type = _type_utils.JitScalarType.from_value( - input, _type_utils.JitScalarType.FLOAT - ) - min_val = g.op( - "Constant", - value_t=torch.tensor(0, dtype=scalar_type.dtype()), - ) - max_val = g.op( - "Constant", - value_t=torch.tensor(6, dtype=scalar_type.dtype()), - ) - return clamp(g, input, min_val, max_val) - - -@_onnx_symbolic("aten::select") -# Opset 11 gather accepts negative indices -@symbolic_helper.quantized_args(True) -@symbolic_helper.parse_args("v", "i", "v") -def select(g: jit_utils.GraphContext, self, dim, index): - return g.op("Gather", self, index, axis_i=dim) - - -@_onnx_symbolic("aten::index_put") -def index_put( - g: jit_utils.GraphContext, self, indices_list_value, values, accumulate=False -): - if symbolic_helper._is_packed_list(indices_list_value): - indices_list = symbolic_helper._unpack_list(indices_list_value) - else: - indices_list = [indices_list_value] - accumulate = symbolic_helper._parse_arg(accumulate, "b") - - if len(indices_list) == 0: - return values - - if len(indices_list) > 1: - for idx_ in range(len(indices_list)): - if symbolic_helper._is_bool(indices_list[idx_]): - indices_list[idx_] = g.op("NonZero", indices_list[idx_]) - index = indices_list[0] - - for ind in indices_list[1:]: - index = opset9.add(g, index, ind) - broadcast_index_shape = g.op("Shape", index) - indices_list = [ - symbolic_helper._unsqueeze_helper( - g, opset9.expand(g, ind, broadcast_index_shape, None), [-1] - ) - for ind in indices_list - ] - index = g.op("Concat", *indices_list, axis_i=-1) - else: - # Replace index_put node with masked_scatter or masked_fill - # when inputs to the index_put node contains a single boolean input. - # - # index_put -> masked_fill - # * input index contains single tensor of Bool type (e.g.: %24 <- %23). - # * input value contains single element (e.g.: %18). 
- # - # Torch IR - # %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6) - # %16 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = - # aten::to(%8, %26, %27, %11, %12, %28, %29, %15) - # %18 : Float(requires_grad=0, device=cpu) = prim::Constant[value={1}]() - # %23 : Bool(8, strides=[1], device=cpu) = aten::view(%16, %22) - # %24 : Tensor?[] = prim::ListConstruct(%23) - # %25 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = - # aten::index_put(%mask, %24, %18, %30) - # return (%25) - # - # - # index_put -> masked_scatter - # * input index contains single tensor of Bool type (e.g.: %32 <- %31). - # * input value contains multiple elements (e.g.: %28). - # - # Torch IR - # %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6) - # %28 : Float(8, strides=[1], requires_grad=0, device=cpu) - # = prim::Constant[value= 1 1 1 1 1 1 1 1 [ CPUFloatType{8} ]]() - # %15 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) - # = aten::ne(%mask, %some_const) - # %23 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) - # = aten::to(%15, %34, %35, %18, %19, %36, %37, %22) - # %38 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]() - # %30 : int[] = prim::Constant[value=[-1]]() - # %31 : Bool(8, strides=[1], device=cpu) = aten::view(%23, %30) - # %32 : Tensor?[] = prim::ListConstruct(%31) - # %33 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) - # = aten::index_put(%mask, %32, %28, %38) - # return (%33) - index = indices_list[0] - bool_inp = index - if symbolic_helper._is_bool(bool_inp): - rank = symbolic_helper._get_tensor_rank(values) - if rank is not None and rank == 0: - return opset9.masked_fill(g, self, bool_inp, values) - mask_rank = symbolic_helper._get_tensor_rank(bool_inp) - self_rank = symbolic_helper._get_tensor_rank(self) - if ( - mask_rank is not None - and self_rank is not None - and self_rank > mask_rank - ): - # Unsqueeze 'bool_inp' to be broadcastable to shape of 'self'. 
- bool_inp = symbolic_helper._unsqueeze_helper( - g, bool_inp, list(range(mask_rank, self_rank)) - ) - return masked_scatter(g, self, bool_inp, values) - broadcast_index_shape = g.op("Shape", index) - index = symbolic_helper._unsqueeze_helper(g, index, [-1]) - sub_data_shape = symbolic_helper._slice_helper( - g, g.op("Shape", self), axes=[0], starts=[len(indices_list)], ends=[sys.maxsize] - ) - values_shape = g.op("Concat", broadcast_index_shape, sub_data_shape, axis_i=0) - # Check if values is a singular value and expand accordingly - rank = symbolic_helper._get_tensor_rank(values) - if rank is not None and rank == 0: - values = opset9.expand(g, values, values_shape, None) - values = symbolic_helper._reshape_helper(g, values, values_shape) - - self_scalar_type = _type_utils.JitScalarType.from_value( - self, _type_utils.JitScalarType.UNDEFINED - ) - if self_scalar_type != _type_utils.JitScalarType.UNDEFINED: - values_scalar_type = _type_utils.JitScalarType.from_value( - values, _type_utils.JitScalarType.UNDEFINED - ) - if self_scalar_type != values_scalar_type: - values = g.op("Cast", values, to_i=self_scalar_type.onnx_type()) - elif accumulate: - raise errors.SymbolicValueError("self does not have a valid scalar type.", self) - - if accumulate: - zeros = g.op( - "ConstantOfShape", - g.op("Shape", self), - value_t=torch.tensor([0], dtype=self_scalar_type.dtype()), - ) - result = g.op("ScatterND", zeros, index, values) - result = add(g, self, result) - else: - result = g.op("ScatterND", self, index, values) - - return result - - -@_onnx_symbolic("aten::pixel_shuffle") -@symbolic_helper.parse_args("v", "i") -def pixel_shuffle(g: jit_utils.GraphContext, self, upscale_factor): - rank = symbolic_helper._get_tensor_rank(self) - if rank is not None and rank != 4: - return symbolic_helper._unimplemented("pixel_shuffle", "only support 4d input") - return g.op("DepthToSpace", self, blocksize_i=upscale_factor, mode_s="CRD") - - -@_onnx_symbolic( - "aten::upsample_nearest1d", - decorate=[symbolic_helper._apply_params("upsample_nearest1d", 3, "nearest")], -) -@_onnx_symbolic( - "aten::upsample_nearest2d", - decorate=[symbolic_helper._apply_params("upsample_nearest2d", 4, "nearest")], -) -@_onnx_symbolic( - "aten::upsample_nearest3d", - decorate=[symbolic_helper._apply_params("upsample_nearest3d", 5, "nearest")], -) -@_onnx_symbolic( - "aten::upsample_linear1d", - decorate=[symbolic_helper._apply_params("upsample_linear1d", 3, "linear")], -) -@_onnx_symbolic( - "aten::upsample_bilinear2d", - decorate=[symbolic_helper._apply_params("upsample_bilinear2d", 4, "linear")], -) -@_onnx_symbolic( - "aten::upsample_trilinear3d", - decorate=[symbolic_helper._apply_params("upsample_trilinear3d", 5, "linear")], -) -@_onnx_symbolic( - "aten::upsample_bicubic2d", - decorate=[symbolic_helper._apply_params("upsample_bicubic2d", 4, "cubic")], -) -def _interpolate(name: str, dim: int, interpolate_mode: str): - return symbolic_helper._interpolate_helper(name, dim, interpolate_mode) - - -@_onnx_symbolic("aten::__interpolate") -@symbolic_helper.quantized_args(True, False, False, False, False, False, False) -def __interpolate( - g: jit_utils.GraphContext, - input, - size, - scale_factor, - mode, - align_corners, - recompute_scale_factor, - antialias, -): - return symbolic_helper.__interpolate_helper( - g, input, size, scale_factor, mode, align_corners, recompute_scale_factor - ) - - -@_onnx_symbolic("aten::gather") -@symbolic_helper.parse_args("v", "i", "v", "v") -def gather(g: jit_utils.GraphContext, self, dim, index, 
sparse_grad=False): - if symbolic_helper._maybe_get_const(sparse_grad, "i"): - return symbolic_helper._unimplemented("gather", "sparse_grad == True") - return g.op("GatherElements", self, index, axis_i=dim) - - -@_onnx_symbolic("aten::scatter") -@symbolic_helper.parse_args("v", "i", "v", "v") -def scatter(g: jit_utils.GraphContext, self, dim, index, src): - src_type = _type_utils.JitScalarType.from_value(src) - src = symbolic_helper._maybe_get_scalar(src) - if symbolic_helper._is_value(src): - return g.op("ScatterElements", self, index, src, axis_i=dim) - else: - # Check if scalar "src" has same type as self (PyTorch allows different - # type for scalar src (but not when src is tensor)). If not, insert Cast node. - if _type_utils.JitScalarType.from_value(self) != src_type: - src = g.op( - "Cast", - src, - to_i=_type_utils.JitScalarType.from_value(self).onnx_type(), - ) - return g.op( - "ScatterElements", self, index, opset9.expand_as(g, src, index), axis_i=dim - ) - - -@_onnx_symbolic("aten::cumsum") -@symbolic_helper.parse_args("v", "i", "none") -def cumsum(g: jit_utils.GraphContext, self, dim, dtype=None): - dim_tensor = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.int)) - if dtype and dtype.node().kind() != "prim::Constant": - parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype") - cast = g.op( - "Cast", self, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type() - ) - else: - cast = self - csum = g.op("CumSum", cast, dim_tensor) - return csum - - -@_onnx_symbolic("aten::masked_select") -def masked_select(g: jit_utils.GraphContext, self, mask): - index = opset9.nonzero(g, opset9.expand_as(g, mask, self)) - return g.op("GatherND", self, index) - - -@_onnx_symbolic("aten::masked_scatter") -def masked_scatter(g: jit_utils.GraphContext, self, mask, source): - index = opset9.nonzero(g, opset9.expand_as(g, mask, self)) - # NOTE: source can have more elements than needed. - # It could also have arbitrary shape. - # This is not supported by ONNX::ScatterND, so we need to flatten and slice source tensor. 
- source = symbolic_helper._reshape_helper(g, source, torch.LongTensor([-1])) - source = symbolic_helper._slice_helper( - g, - source, - axes=torch.LongTensor([0]), - starts=torch.LongTensor([0]), - ends=opset9.size(g, index, torch.LongTensor([0])), - ) - return g.op("ScatterND", self, index, source) - - -@_onnx_symbolic("aten::len") -def _len(g: jit_utils.GraphContext, self): - if ( - symbolic_helper._is_tensor_list(self) - or self.node().kind() == "onnx::SplitToSequence" - ): - return g.op("SequenceLength", self) - sz_0 = size(g, self, g.op("Constant", value_t=torch.LongTensor([0]))) - return symbolic_helper._squeeze_helper(g, sz_0, [0]) - - -@_onnx_symbolic("aten::__getitem_") -def __getitem_(g: jit_utils.GraphContext, self, i): - if symbolic_helper._is_tensor_list(self): - # SequenceAt requires that the input be a List of Tensors - return g.op("SequenceAt", self, i) - else: - from torch.onnx.symbolic_opset9 import __getitem_ as getitem - - return getitem(g, self, i) - - -@_onnx_symbolic("aten::_set_item") -def _set_item(g: jit_utils.GraphContext, tensor_list, i, v): - tensor_list = g.op("SequenceErase", tensor_list, i) - return g.op("SequenceInsert", tensor_list, v, i) - - -@_onnx_symbolic("aten::append") -def append(g: jit_utils.GraphContext, self, tensor): - return g.op("SequenceInsert", self, tensor) - - -@_onnx_symbolic("aten::add") -def add(g: jit_utils.GraphContext, self, other, alpha=None): - if symbolic_helper._is_value(self) and symbolic_helper._is_tensor_list(self): - tensor_list_node = other.node() - if tensor_list_node.kind() != "prim::ListConstruct": - return symbolic_helper._unimplemented( - "add", "does not support adding dynamic tensor list to another" - ) - tensors = symbolic_helper._unpack_list(other) - l = self - for t in tensors: - l = g.op("SequenceInsert", l, t) - return l - - return opset9.add(g, self, other, alpha) - - -@_onnx_symbolic("aten::insert") -def insert(g: jit_utils.GraphContext, self, pos, tensor): - return g.op("SequenceInsert", self, tensor, pos) - - -@_onnx_symbolic("aten::pop") -def pop(g: jit_utils.GraphContext, tensor_list, dim): - return g.op("SequenceErase", tensor_list, dim) - - -@_onnx_symbolic("aten::Delete") -def Delete(g: jit_utils.GraphContext, tensor_list, dim): - return g.op("SequenceErase", tensor_list, dim) - - -@_onnx_symbolic("aten::cat") -@symbolic_helper.quantized_args(True) -def cat(g: jit_utils.GraphContext, tensor_list, dim): - if symbolic_helper._is_packed_list(tensor_list): - return opset9.cat(g, tensor_list, dim) - else: - dim = symbolic_helper._get_const(dim, "i", "dim") - return g.op("ConcatFromSequence", tensor_list, axis_i=dim) - - -@_onnx_symbolic("aten::stack") -def stack(g: jit_utils.GraphContext, tensor_list, dim): - if symbolic_helper._is_packed_list(tensor_list): - return opset9.stack(g, tensor_list, dim) - else: - dim = symbolic_helper._get_const(dim, "i", "dim") - return g.op("ConcatFromSequence", tensor_list, axis_i=dim, new_axis_i=1) - - -@_onnx_symbolic("aten::_unique2") -@symbolic_helper.parse_args("v", "i", "i", "i") -def _unique2(g: jit_utils.GraphContext, self, sorted, return_inverse, return_counts): - u, _indices, inverse_indices, counts = g.op( - "Unique", self, sorted_i=sorted, outputs=4 - ) - return u, inverse_indices, counts - - -@_onnx_symbolic("aten::unique_dim") -@symbolic_helper.parse_args("v", "i", "i", "i", "i") -def unique_dim( - g: jit_utils.GraphContext, self, dim, sorted, return_inverse, return_counts -): - u, _indices, inverse_indices, counts = g.op( - "Unique", self, axis_i=dim, 
sorted_i=sorted, outputs=4 - ) - return u, inverse_indices, counts - - -@_onnx_symbolic("aten::topk") -@symbolic_helper.parse_args("v", "v", "i", "i", "i", "none") -def topk(g: jit_utils.GraphContext, self, k, dim, largest, sorted, out=None): - return symbolic_helper._topk_helper( - g, self, k, dim, largest=largest, sorted=sorted, out=out - ) - - -@_onnx_symbolic("aten::sort") -@symbolic_helper.parse_args("v", "i", "i", "none") -def sort(g: jit_utils.GraphContext, self, dim, descending, out=None): - return symbolic_helper._sort_helper(g, self, dim, descending=descending, out=out) - - -@_onnx_symbolic("aten::argsort") -@symbolic_helper.parse_args("v", "i", "i", "none") -def argsort(g: jit_utils.GraphContext, self, dim, descending, out=None): - _, indices = symbolic_helper._sort_helper( - g, self, dim, descending=descending, out=out - ) - return indices - - -@_onnx_symbolic("aten::round") -@symbolic_helper.parse_args("v", "i") -def round(g: jit_utils.GraphContext, self, decimals=0): - if not symbolic_helper._is_fp(self): - return self - if decimals == 0: - return g.op("Round", self) - mul = g.op("Mul", self, g.op("Constant", value_t=torch.tensor(pow(10, decimals)))) - round = g.op("Round", mul) - return g.op( - "Mul", round, g.op("Constant", value_t=torch.tensor(pow(10, -1 * decimals))) - ) - - -@_onnx_symbolic("aten::remainder") -def remainder(g: jit_utils.GraphContext, input, other): - if symbolic_helper._is_fp(input) or symbolic_helper._is_fp(other): - return opset9.remainder(g, input, other) - return g.op("Mod", input, other, fmod_i=0) - - -@_onnx_symbolic("aten::split") -@symbolic_helper.parse_args("v", "v", "i", "i") -def split(g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None): - if not symbolic_helper._is_split_static(split_size_or_sizes, _outputs): - split_out = g.op("SplitToSequence", self, split_size_or_sizes, axis_i=dim) - if _outputs is None: - return split_out - # Convert to multiple slice nodes iff number of splits and number of outputs are statically known. 
- if ( - symbolic_helper._is_packed_list(split_size_or_sizes) - and len(symbolic_helper._unpack_list(split_size_or_sizes)) == _outputs - ): - split_sizes = [ - symbolic_helper._unsqueeze_helper(g, v, [0]) - for v in symbolic_helper._unpack_list(split_size_or_sizes) - ] - start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long)) - axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long)) - res = [] - for i in range(_outputs): - end = g.op( - "Add", start, split_sizes[i] - ) # split_sizes is a list of same length as _outputs - res.append(g.op("Slice", self, start, end, axis)) - start = end - return res - return [ - g.op( - "SequenceAt", - split_out, - g.op("Constant", value_t=torch.tensor([i], dtype=torch.long)), - ) - for i in range(_outputs) - ] - else: - return opset9.split(g, self, split_size_or_sizes, dim, _outputs) - - -@_onnx_symbolic("aten::split_with_sizes") -@symbolic_helper.parse_args("v", "v", "i", "i") -def split_with_sizes(g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None): - return split(g, self, split_sizes, dim, _outputs) - - -@_onnx_symbolic("aten::unbind") -@symbolic_helper.parse_args("v", "i", "i") -def unbind(g: jit_utils.GraphContext, self, dim=0, _outputs=None): - if _outputs is None: - return g.op( - "SplitToSequence", - self, - g.op("Constant", value_t=torch.tensor(1, dtype=torch.long)), - axis_i=dim, - keepdims_i=0, - ) - else: - return opset9.unbind(g, self, dim, _outputs) - - -def _prepare_onnx_paddings(g: jit_utils.GraphContext, input, pad): - """Generate paddings in ONNX order based on pad in pytorch. - - Args: - input: the input tensor. - pad: the paddings in pytorch. - The order is dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, ..., dim_m_begin, dim_m_end, - where m is in range [0, n]. - """ - if ( - not symbolic_helper._is_packed_list(pad) - and symbolic_helper._is_list(pad) - and symbolic_helper._is_scalar_list(pad) - ): - pad = g.op("ConcatFromSequence", pad, axis_i=0, new_axis_i=1) - # The desired order of paddings is - # dim_0_begin, dim_1_begin, ... , dim_0_end, ..., dim_n_end. - # n is the dimension of input. - # Assume zero-dimensions in the beginning, pad the "pad" sequence with zeros in the beginning - pad_len = opset9.size(g, pad, g.op("Constant", value_t=torch.tensor([0]))) - # Set extension = [0] * (dim * 2 - len(pad)) - rank = symbolic_helper._get_tensor_rank(input) - if rank is None: - rank = g.op("Size", g.op("Shape", input)) - else: - rank = g.op("Constant", value_t=torch.tensor(rank, dtype=torch.int64)) - extension = g.op( - "Sub", - g.op("Mul", rank, g.op("Constant", value_t=torch.tensor(2, dtype=torch.int64))), - pad_len, - ) - # Concat pad with extension: paddings = [dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, 0, 0, ... 
] - # Currently ONNX only supports int64 type for Pad - pad = g.op("Cast", pad, to_i=_C_onnx.TensorProtoDataType.INT64) - paddings = g.op( - "Concat", - pad, - g.op( - "ConstantOfShape", extension, value_t=torch.tensor([0], dtype=torch.int64) - ), - axis_i=0, - ) - # Reshape and reverse order and collate first beginnings and then ends - # paddings = [[..., 0, dim_n-1_begin, dim_n_begin], - # [..., 0, dim_n-1_end, dim_n_end]] - # Reshape back to 1-D paddings = [..., 0, dim_n - 1_begin, dim_n_begin, ..., 0, dim_n - 1_end, dim_n_end] - paddings = symbolic_helper._reshape_helper( - g, paddings, g.op("Constant", value_t=torch.tensor([-1, 2])) - ) - paddings = g.op("Transpose", opset10.flip(g, paddings, [0]), perm_i=[1, 0]) - paddings = symbolic_helper._reshape_helper( - g, paddings, g.op("Constant", value_t=torch.tensor([-1])) - ) - padding_c = g.op("Cast", paddings, to_i=_C_onnx.TensorProtoDataType.INT64) - return padding_c - - -@_onnx_symbolic("aten::constant_pad_nd") -def constant_pad_nd(g: jit_utils.GraphContext, input, padding, value=None): - mode = "constant" - value = symbolic_helper._maybe_get_scalar(value) - value = symbolic_helper._if_scalar_type_as(value, input) - pad = _prepare_onnx_paddings(g, input, padding) - return g.op("Pad", input, pad, value, mode_s=mode) - - -@_onnx_symbolic("aten::reflection_pad1d") -@_onnx_symbolic("aten::reflection_pad2d") -@_onnx_symbolic("aten::reflection_pad3d") -def reflection_pad(g: jit_utils.GraphContext, input, padding): - mode = "reflect" - paddings = _prepare_onnx_paddings(g, input, padding) - return g.op("Pad", input, paddings, mode_s=mode) - - -@_onnx_symbolic("aten::replication_pad1d") -@_onnx_symbolic("aten::replication_pad2d") -@_onnx_symbolic("aten::replication_pad3d") -def replication_pad(g: jit_utils.GraphContext, input, padding): - mode = "edge" - paddings = _prepare_onnx_paddings(g, input, padding) - return g.op("Pad", input, paddings, mode_s=mode) - - -@_onnx_symbolic("aten::pad") -def pad( - g: jit_utils.GraphContext, - input: _C.Value, - pad: _C.Value, - mode: _C.Value, - value: _C.Value, -): - mode = symbolic_helper._parse_arg(mode, "s") - if mode == "replicate": - return replication_pad(g, input, pad) - elif mode == "reflect": - return reflection_pad(g, input, pad) - elif mode == "constant": - return constant_pad_nd(g, input, pad, value) - elif mode == "circular": - return opset9._pad_circular(g, input, pad) - else: - raise errors.SymbolicValueError(f"Unrecognized padding mode {mode}", input) - - -@_onnx_symbolic("aten::linalg_det") -def linalg_det(g: jit_utils.GraphContext, self): - return g.op("Det", self) - - -@_onnx_symbolic("aten::logdet") -def logdet(g: jit_utils.GraphContext, input): - return opset9.log(g, linalg_det(g, input)) - - -@_onnx_symbolic("aten::arange") -def arange(g: jit_utils.GraphContext, *args): - def _get_arange_dtype(dtype): - dtype = symbolic_helper._maybe_get_const(dtype, "i") - return dtype - - if len(args) == 2 and all(isinstance(val, int) for val in args): - # aten::arange(Scalar start, Scalar end) - dtype = torch.int64 - # Start index. - start = g.op( - "Constant", - value_t=torch.tensor(args[0], dtype=dtype), - ) - # End (exclusive) index. - end = g.op( - "Constant", - value_t=torch.tensor(args[1], dtype=dtype), - ) - # Step size from start to end indexes. 
- delta_default = g.op( - "Constant", - value_t=torch.tensor(1, dtype=dtype), - ) - return g.op("Range", start, end, delta_default) - elif len(args) == 2 or len(args) == 5: - if len(args) == 2: - # aten::arange(Scalar end, Tensor out) - dtype = None - else: - # aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory) - dtype = _get_arange_dtype(args[1]) - type_, end, start, step = symbolic_helper._arange_cast_helper( - g, end=args[0], dtype=dtype - ) - start_default = g.op( - "Constant", - value_t=torch.tensor(0, dtype=type_.dtype()), - ) - delta_default = g.op( - "Constant", - value_t=torch.tensor(1, dtype=type_.dtype()), - ) - return g.op("Range", start_default, end, delta_default) - elif len(args) == 4 or len(args) == 7: - if len(args) == 4: - # aten::arange(Scalar start, Scalar end, Scalar step, Tensor out) - dtype = None - else: - # aten::arange(Scalar start, Scalar end, Scalar step, ScalarType dtype, Layout, Device, bool pin_memory) - dtype = _get_arange_dtype(args[3]) - _, end, start, step = symbolic_helper._arange_cast_helper( - g, start=args[0], end=args[1], step=args[2], dtype=dtype - ) - return g.op("Range", start, end, step) - elif len(args) == 6: - # aten::arange(Scalar start, Scalar end, ScalarType dtype, Layout, Device, bool pin_memory) - dtype = _get_arange_dtype(args[2]) - type_, end, start, step = symbolic_helper._arange_cast_helper( - g, start=args[0], end=args[1], dtype=dtype - ) - delta_default = g.op( - "Constant", - value_t=torch.tensor(1, dtype=type_.dtype()), - ) - return g.op("Range", start, end, delta_default) - else: - return symbolic_helper._unimplemented( - "aten::arange", f"with {len(args)} arguments" - ) - - -@_onnx_symbolic("aten::_dim_arange") -@symbolic_helper.parse_args("v", "i") -def _dim_arange(g: jit_utils.GraphContext, like, dim): - like_shape = g.op("Shape", like) - stop = g.op( - "Gather", like_shape, g.op("Constant", value_t=torch.tensor(dim)), axis_i=0 - ) - return arange(g, stop, 4, None, None, None) - - -@_onnx_symbolic("aten::size") -@symbolic_helper.quantized_args(True, quantize_output=False) -def size(g: jit_utils.GraphContext, self, dim=None): - if dim is None: - return g.op("Shape", self) - return symbolic_helper._size_helper(g, self, dim) - - -@_onnx_symbolic("aten::squeeze") -def squeeze(g: jit_utils.GraphContext, self, dim=None): - if dim is None: - return g.op("Squeeze", self) - - # dim as a tensor - if not symbolic_helper._is_constant(dim): - return symbolic_helper._squeeze_helper(g, self, [dim]) - - dim = symbolic_helper._get_const(dim, "i", "dim") - - input_rank = symbolic_helper._get_tensor_rank(self) - adjusted_dim = dim - if input_rank is not None and dim < 0: - adjusted_dim += input_rank - dim_size = symbolic_helper._get_tensor_dim_size(self, adjusted_dim) - if (dim < 0 and input_rank is None) or dim_size is None: - # If onnx shape inference is not on, export always as dynamic. - # Because we cannot tell if observed static shape is also static at runtime. - # create "cond" node (condition is shape[i]==1) - dim_constant = g.op("Constant", value_t=torch.tensor([dim])) - size = symbolic_helper._size_helper(g, self, dim_constant) - const_one = g.op("Constant", value_t=torch.ones(1, dtype=torch.int64)) - cond = g.op("Equal", size, const_one) - # create the "If" node and add the "then" and "else" blocks to it. 
- if_op, (if_context, else_context), _ = jit_utils.add_op_with_blocks( - g, "If", cond, n_blocks=2 - ) - squeeze_ = symbolic_helper._squeeze_helper(if_context, self, [dim]) - utils._add_output_to_block(if_context.block, squeeze_) - identity_ = else_context.op("Identity", self) - utils._add_output_to_block(else_context.block, identity_) - return if_op - - # For static input shape - dim = adjusted_dim - if dim_size > 1: - warnings.warn( - "This model contains a squeeze operation on dimension " - + str(dim) - + ". The size of " - + "this dimension in the given input is " - + str(dim_size) - + ". The model will " - + "be exported without the squeeze node. If the model is intended to be used with dynamic " - + "input shapes, please export with dynamic_axes argument." - ) - return self - return symbolic_helper._squeeze_helper(g, self, [dim]) - - -@_onnx_symbolic("aten::unsqueeze") -def unsqueeze(g: jit_utils.GraphContext, self, dim): - if symbolic_helper._is_constant(dim): - dim = symbolic_helper._get_const(dim, "i", "dim") - - return symbolic_helper._unsqueeze_helper(g, self, [dim]) - - -@_onnx_symbolic("aten::mm") -def mm(g: jit_utils.GraphContext, self, other): - return g.op("Gemm", self, other, beta_f=0.0, alpha_f=1.0) - - -@_onnx_symbolic("aten::index") -def index(g: jit_utils.GraphContext, self, index): - if symbolic_helper._is_packed_list(index): - indices = symbolic_helper._unpack_list(index) - else: - indices = [index] - - # Handle single mask index. - if len(indices) == 1: - index = indices[0] - if not symbolic_helper._is_none(index) and ( - symbolic_helper._is_bool(index) - or _type_utils.JitScalarType.from_value(index) - == _type_utils.JitScalarType.UINT8 - ): - index = opset9.nonzero(g, index) - return g.op("GatherND", self, index) - return opset9.index(g, self, index) - - -@_onnx_symbolic("aten::index_fill") -def index_fill(g: jit_utils.GraphContext, self, dim, index, value): - expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper( - g, self, dim, index - ) - value = symbolic_helper._maybe_get_scalar(value) - value = symbolic_helper._if_scalar_type_as(value, self) - expanded_value = opset9.expand(g, value, expanded_index_shape, None) - return scatter(g, self, dim, expanded_index, expanded_value) - - -@_onnx_symbolic("aten::index_copy") -def index_copy(g: jit_utils.GraphContext, self, dim, index, source): - _expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper( - g, self, dim, index - ) - return scatter(g, self, dim, expanded_index, source) - - -@_onnx_symbolic("aten::bitwise_right_shift") -@_onnx_symbolic("aten::__rshift_") -def __rshift_(g: jit_utils.GraphContext, self, other): - # make sure to cast other to self's type - # (when self is long, make sure that other is not float) - if _type_utils.JitScalarType.from_value( - other, _type_utils.JitScalarType.UNDEFINED - ) != _type_utils.JitScalarType.from_value(self): - other = g.op( - "Cast", - other, - to_i=_type_utils.JitScalarType.from_value(self).onnx_type(), - ) - - if ( - _type_utils.JitScalarType.from_value(self, _type_utils.JitScalarType.UNDEFINED) - == _type_utils.JitScalarType.UINT8 - ): - return g.op("BitShift", self, other, direction_s="RIGHT") - - two = g.op("Constant", value_t=torch.tensor(2, dtype=torch.float32)) - # exponent (same type as self) has to be float or double in onnx::Pow - if not symbolic_helper._is_fp(self): - other = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.FLOAT) - two_pow = g.op("Pow", two, other) - two_pow = g.op( - "Cast", - 
two_pow, - to_i=_type_utils.JitScalarType.from_value(self).onnx_type(), - ) - rshift = g.op("Div", self, two_pow) - return rshift - - -@_onnx_symbolic("aten::bitwise_left_shift") -@_onnx_symbolic("aten::__lshift_") -def __lshift_(g: jit_utils.GraphContext, self, other): - # make sure to cast other to self's type - # (when self is long, make sure that other is not float) - if _type_utils.JitScalarType.from_value( - other, _type_utils.JitScalarType.UNDEFINED - ) != _type_utils.JitScalarType.from_value(self): - other = g.op( - "Cast", - other, - to_i=_type_utils.JitScalarType.from_value(self).onnx_type(), - ) - - if ( - _type_utils.JitScalarType.from_value(self, _type_utils.JitScalarType.UNDEFINED) - == _type_utils.JitScalarType.UINT8 - ): - return g.op("BitShift", self, other, direction_s="LEFT") - - two = g.op("Constant", value_t=torch.tensor(2, dtype=torch.float32)) - # exponent (same type as self) has to be float or double in onnx::Pow - if not symbolic_helper._is_fp(self): - other = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.FLOAT) - two_pow = g.op("Pow", two, other) - two_pow = g.op( - "Cast", - two_pow, - to_i=_type_utils.JitScalarType.from_value(self).onnx_type(), - ) - lshift = g.op("Mul", self, two_pow) - return lshift - - -def _get_im2col_indices_along_dim( - g: jit_utils.GraphContext, input_d, kernel_size_d, dilation_d, padding_d, stride_d -): - # Input is always 4-D (N, C, H, W) - # Calculate indices of sliding blocks along spatial dimension - # Slide kernel over input each dim d: - # each dimension d ranges from 0 to input[d]+2xpadding[d]-dilation[d]x(kernel_size[d]-1) - # with steps = stride - - blocks_d = g.op( - "Add", input_d, g.op("Constant", value_t=torch.tensor(padding_d * 2)) - ) - blocks_d = g.op( - "Sub", - blocks_d, - g.op("Constant", value_t=torch.tensor(dilation_d * (kernel_size_d - 1))), - ) - - # Stride kernel over input and find starting indices along dim d - blocks_d_indices = g.op( - "Range", - g.op("Constant", value_t=torch.tensor(0)), - blocks_d, - g.op("Constant", value_t=torch.tensor(stride_d)), - ) - - # Apply dilation on kernel and find its indices along dim d - kernel_grid = torch.arange(0, kernel_size_d * dilation_d, dilation_d) - kernel_grid = g.op("Constant", value_t=kernel_grid.unsqueeze(0)) - - # Broadcast and add kernel staring positions (indices) with - # kernel_grid along dim d, to get block indices along dim d - blocks_d_indices = symbolic_helper._unsqueeze_helper( - g, blocks_d_indices, [0] - ) # Reshape to [1, -1] - kernel_mask = symbolic_helper._reshape_helper( - g, kernel_grid, g.op("Constant", value_t=torch.tensor([-1, 1])) - ) - block_mask = g.op("Add", blocks_d_indices, kernel_mask) - - return block_mask - - -def _get_im2col_padded_input(g: jit_utils.GraphContext, input, padding_h, padding_w): - # Input is always 4-D tensor (N, C, H, W) - # Padding tensor has the following format: (padding_h, padding_w) - # Reshape the padding to follow ONNX format: (dim1_begin, dim2_begin,...,dim1_end, dim2_end,...) 
- pad = g.op("Constant", value_t=torch.LongTensor([0, 0, padding_h, padding_w] * 2)) - return g.op("Pad", input, pad) - - -def _get_im2col_output_shape(g: jit_utils.GraphContext, input, kernel_h, kernel_w): - batch_dim = size(g, input, g.op("Constant", value_t=torch.tensor(0))) - channel_dim = size(g, input, g.op("Constant", value_t=torch.tensor(1))) - channel_unfolded = g.op( - "Mul", channel_dim, g.op("Constant", value_t=torch.tensor(kernel_h * kernel_w)) - ) - - return g.op( - "Concat", - symbolic_helper._unsqueeze_helper(g, batch_dim, [0]), - symbolic_helper._unsqueeze_helper(g, channel_unfolded, [0]), - g.op("Constant", value_t=torch.tensor([-1])), - axis_i=0, - ) - - -@_onnx_symbolic("aten::im2col") -@symbolic_helper.parse_args("v", "is", "is", "is", "is") -def im2col(g: jit_utils.GraphContext, input, kernel_size, dilation, padding, stride): - # Input is always 4-D tensor (N, C, H, W) - # All other args are int[2] - - input_h = size(g, input, g.op("Constant", value_t=torch.tensor(2))) - input_w = size(g, input, g.op("Constant", value_t=torch.tensor(3))) - - stride_h, stride_w = stride[0], stride[1] - padding_h, padding_w = padding[0], padding[1] - dilation_h, dilation_w = dilation[0], dilation[1] - kernel_h, kernel_w = kernel_size[0], kernel_size[1] - - blocks_row_indices = _get_im2col_indices_along_dim( - g, input_h, kernel_h, dilation_h, padding_h, stride_h - ) - blocks_col_indices = _get_im2col_indices_along_dim( - g, input_w, kernel_w, dilation_w, padding_w, stride_w - ) - - output_shape = _get_im2col_output_shape(g, input, kernel_h, kernel_w) - padded_input = _get_im2col_padded_input(g, input, padding_h, padding_w) - - # For a 4D matrix of size (1, 1, 3, 3) as below with kernel_size=2, stride=1, and dilation=1 - # [[[[1., 2., 3.,], - # [4., 5., 6.,], - # [7., 8., 9.,]]]] - # First gather indices along rows (dim=2) with blocks_row_indices = [[0,1], [1,2]] to get: - # [[[[[1., 2., 3.], - # [4., 5., 6.]], - # [[4., 5., 6.], - # [7., 8., 9.]]]]] - # And then gather along cols (dim=4) with blocks_row_indices = [[0,1], [1,2]] to get: - # [[[[[[1., 2.], - # [4., 5.]], - # [[2., 3.], - # [5., 6]]], - # [[[4., 5.], - # [7., 8.]], - # [[5., 6.], - # [8., 9.]]]]]] - # Transpose dims 3 (depth) and 4 (rows), and then reshape to output shape (1, 1, 4, 4) to get: - # [[[1., 2., 4., 5.], - # [2., 3., 5., 6.], - # [4., 5., 7., 8.], - # [5., 6., 8., 9.]]] - output = g.op("Gather", padded_input, blocks_row_indices, axis_i=2) - output = g.op("Gather", output, blocks_col_indices, axis_i=4) - output = g.op("Transpose", output, perm_i=[0, 1, 2, 4, 3, 5]) - return symbolic_helper._reshape_helper(g, output, output_shape) - - -@_onnx_symbolic("aten::narrow") -def narrow(g: jit_utils.GraphContext, input, dim, start, length): - end = g.op("Add", start, length) - return symbolic_helper._slice_helper(g, input, axes=dim, starts=start, ends=end) - - -@_onnx_symbolic("aten::flatten") -@symbolic_helper.quantized_args(True, False, False) -@symbolic_helper.parse_args("v", "i", "i") -def flatten(g: jit_utils.GraphContext, input, start_dim, end_dim): - dim = symbolic_helper._get_tensor_rank(input) - if dim == 1: - return input - # use ONNX's Flatten operator for cases where the output shape is 2D - if start_dim == 1: - if end_dim == -1 or (dim is not None and end_dim == dim - 1): - return g.op("Flatten", input, axis_i=start_dim) - elif start_dim == 0: - if end_dim == -2 or (dim is not None and end_dim == dim - 2): - return g.op("Flatten", input, axis_i=end_dim + 1) - if dim is None: - return 
symbolic_helper._unimplemented( - "dim", - "ONNX and PyTorch use different strategies to split the input. " - "Input rank must be known at export time.", - ) - # if end_dim is negative add dim - if end_dim < 0: - end_dim = dim + end_dim - - return symbolic_helper._flatten_helper(g, input, start_dim, end_dim, dim) - - -@_onnx_symbolic("aten::linalg_vector_norm") -@symbolic_helper.parse_args("v", "f", "is", "b", "v") -def linalg_vector_norm( - g: jit_utils.GraphContext, - self, - ord, - dim: Sequence[int] | None, - keepdim: bool, - dtype, -): - return symbolic_helper._linalg_vector_norm_helper(g, self, ord, dim, keepdim, dtype) - - -@_onnx_symbolic("aten::embedding_bag") -@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i") -def embedding_bag( - g: jit_utils.GraphContext, - embedding_matrix, - indices, - offsets, - scale_grad_by_freq, - mode, - sparse, - per_sample_weights, - include_last_offset, - padding_idx, -): - return symbolic_helper._embedding_bag_helper( - g, - embedding_matrix, - indices, - offsets, - scale_grad_by_freq, - mode, - sparse, - per_sample_weights, - include_last_offset, - padding_idx, - ) - - -@_onnx_symbolic("aten::embedding_renorm") -@symbolic_helper.parse_args("v", "v", "f", "f") -def embedding_renorm(g: jit_utils.GraphContext, weight, indices, max_norm, norm_type): - unique_indices = g.op("Unique", indices) - partial_weight = g.op("Gather", weight, unique_indices) - norm_i = int(norm_type) - if norm_i == 1: - norm_type = "ReduceL1" - elif norm_i == 2: - norm_type = "ReduceL2" - else: - raise errors.SymbolicValueError( - f"Unsupported: ONNX export of embedding_renorm with norm: {norm_i}. " - "Only 1. and 2. are supported.", - weight, - ) - partial_weight_norm = g.op(norm_type, partial_weight, axes_i=[1], keepdims_i=1) - # https://github.com/pytorch/pytorch/blob/0a07488ed2c47765e337e290bd138c0e6e459cbd/aten/src/ATen/native/Embedding.cpp#L177 - # Add 1e-7 to prevent division by zero. - partial_weight_norm_ = g.op( - "Add", partial_weight_norm, g.op("Constant", value_t=torch.tensor(1e-7)) - ) - max_norm = torch.tensor(max_norm) - scales = g.op("Div", max_norm, partial_weight_norm_) - partial_weight_renorm = g.op("Mul", partial_weight, scales) - partial_weight_renorm = g.op( - "Where", - g.op("Greater", partial_weight_norm, max_norm), - partial_weight_renorm, - partial_weight, - ) - return g.op( - "ScatterND", - weight, - symbolic_helper._unsqueeze_helper(g, unique_indices, [1]), - partial_weight_renorm, - ) - - -@_onnx_symbolic("aten::chunk") -def chunk(g: jit_utils.GraphContext, self, chunks, dim): - # Calculate chunk size for dynamic chunk - dim_size = g.op("Gather", g.op("Shape", self), dim, axis_i=0) - chunk_size_s = g.op( - "Sub", chunks, g.op("Constant", value_t=torch.tensor([1], dtype=torch.long)) - ) - chunk_size = g.op("Div", g.op("Add", dim_size, chunk_size_s), chunks) - # Create splits vector - chunk_vec = [ - opset9.expand(g, chunk_size, chunk_size_s, None), - g.op("Sub", dim_size, g.op("Mul", chunk_size, chunk_size_s)), - ] - chunk_vec = g.op("Concat", *chunk_vec, axis_i=0) - return split(g, self, chunk_vec, dim) - - -@_onnx_symbolic("aten::normal") -def normal( - g: jit_utils.GraphContext, - mean, - std, - sizes=None, - generator=None, - dtype=None, - layout=None, - device=None, - pin_memory=None, -): - # If you can sample from a given distribution with mean 0 and variance 1, then you can easily sample from a - # scale-location transformation of that distribution, which has mean mu and variance sigma's square. 
If x is a sample - # from a mean 0 and variance 1 distribution then - # sigma x+mu - # is a sample with mean mu and variance sigma's square. - if sizes is not None and not symbolic_helper._is_none(sizes): - mean = opset9.expand(g, mean, sizes, None) - result = opset9.mul(g, std, g.op("RandomNormalLike", mean)) - return add(g, result, mean) - - -@_onnx_symbolic("aten::atleast_1d") -def atleast_1d(g: jit_utils.GraphContext, self: torch._C.Value): - # NOTE: If it's 0D, reshape to 1D - - # NOTE: self could be a packed list or a tensor - if symbolic_helper._is_value(self) and symbolic_helper._is_packed_list(self): - tensor_list = symbolic_helper._unpack_list(self) - new_tensor_list = [] - for tensor in tensor_list: - new_tensor = tensor - tensor_rank = symbolic_helper._get_tensor_rank(tensor) - if tensor_rank == 0: - new_tensor = symbolic_helper._reshape_helper( - g, new_tensor, g.op("Constant", value_t=torch.tensor([1])) - ) - new_tensor_list.append(new_tensor) - return g.op("SequenceConstruct", *new_tensor_list) - - tensor_rank = symbolic_helper._get_tensor_rank(self) - if tensor_rank == 0: - self = symbolic_helper._reshape_helper( - g, self, g.op("Constant", value_t=torch.tensor([1])) - ) - return self - - -@_onnx_symbolic("aten::atleast_2d") -def atleast_2d(g: jit_utils.GraphContext, self: torch._C.Value): - # NOTE: If it's 0D, reshape to 2D - # If it's 1D, unsqueeze to 2D - - # NOTE: self could be a packed list or a tensor - if symbolic_helper._is_value(self) and symbolic_helper._is_packed_list(self): - tensor_list = symbolic_helper._unpack_list(self) - new_tensor_list = [] - for tensor in tensor_list: - new_tensor = tensor - tensor_rank = symbolic_helper._get_tensor_rank(tensor) - if tensor_rank == 0: - new_tensor = symbolic_helper._reshape_helper( - g, new_tensor, g.op("Constant", value_t=torch.tensor([1, 1])) - ) - elif tensor_rank == 1: - new_tensor = symbolic_helper._unsqueeze_helper( - g, new_tensor, axes_i=[0] - ) - new_tensor_list.append(new_tensor) - return g.op("SequenceConstruct", *new_tensor_list) - - tensor_rank = symbolic_helper._get_tensor_rank(self) - if tensor_rank == 0: - self = symbolic_helper._reshape_helper( - g, self, g.op("Constant", value_t=torch.tensor([1, 1])) - ) - elif tensor_rank == 1: - self = symbolic_helper._unsqueeze_helper(g, self, axes_i=[0]) - return self - - -@_onnx_symbolic("aten::atleast_3d") -def atleast_3d(g: jit_utils.GraphContext, self: torch._C.Value): - # NOTE: If it's 0D, reshape to 3D - # If it's 1D, unsqueeze to 3D - # If it's 2D, unsqueeze to 3D - - # NOTE: self could be a packed list or a tensor - if symbolic_helper._is_value(self) and symbolic_helper._is_packed_list(self): - tensor_list = symbolic_helper._unpack_list(self) - new_tensor_list = [] - for tensor in tensor_list: - new_tensor = tensor - tensor_rank = symbolic_helper._get_tensor_rank(tensor) - if tensor_rank == 0: - new_tensor = symbolic_helper._reshape_helper( - g, new_tensor, g.op("Constant", value_t=torch.tensor([1, 1, 1])) - ) - elif tensor_rank == 1: - new_tensor = symbolic_helper._unsqueeze_helper( - g, new_tensor, axes_i=[0] - ) - new_tensor = symbolic_helper._unsqueeze_helper( - g, new_tensor, axes_i=[-1] - ) - elif tensor_rank == 2: - new_tensor = symbolic_helper._unsqueeze_helper( - g, new_tensor, axes_i=[-1] - ) - new_tensor_list.append(new_tensor) - return g.op("SequenceConstruct", *new_tensor_list) - - tensor_rank = symbolic_helper._get_tensor_rank(self) - if tensor_rank == 0: - self = symbolic_helper._reshape_helper( - g, self, g.op("Constant", 
value_t=torch.tensor([1, 1, 1])) - ) - elif tensor_rank == 1: - self = symbolic_helper._unsqueeze_helper(g, self, axes_i=[0]) - self = symbolic_helper._unsqueeze_helper(g, self, axes_i=[-1]) - elif tensor_rank == 2: - self = symbolic_helper._unsqueeze_helper(g, self, axes_i=[-1]) - return self - - -@_onnx_symbolic("prim::ConstantChunk") -def prim_constant_chunk(g: jit_utils.GraphContext, self, chunks, dim): - input_shape = g.op("Shape", self) - axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long)) - input_shape_dim = g.op("Gather", input_shape, axis, axis_i=0) - start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long)) - chunk_size = g.op("Constant", value_t=torch.tensor([chunks], dtype=torch.long)) - chunk_size_minus_1 = g.op( - "Constant", value_t=torch.tensor([chunks - 1], dtype=torch.long) - ) - input_shape_dim_shift = g.op("Add", input_shape_dim, chunk_size_minus_1) - chunk_dim = g.op("Div", input_shape_dim_shift, chunk_size) - res = [] - for i in range(chunks): - index = g.op("Constant", value_t=torch.tensor([i + 1], dtype=torch.long)) - end = g.op("Mul", chunk_dim, index) - res.append(g.op("Slice", self, start, end, axis)) - start = end - return res - - -@_onnx_symbolic("aten::hstack") -def hstack(g: jit_utils.GraphContext, tensor_list: _C.Value): - tensor_list = atleast_1d(g, tensor_list) - first_tensor = g.op( - "SequenceAt", - tensor_list, - g.op("Constant", value_t=torch.tensor(0, dtype=torch.long)), - ) - first_tensor_shape = g.op("Shape", first_tensor) - first_tensor_dim = g.op("Size", first_tensor_shape) - - const_one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.long)) - equal_to_one = g.op("Equal", first_tensor_dim, const_one) - - ( - if_op_greater, - (if_context_equal, else_context_equal), - _, - ) = jit_utils.add_op_with_blocks(g, "If", equal_to_one, n_blocks=2, outputs=1) - result_if = if_context_equal.op( - "ConcatFromSequence", tensor_list, axis_i=0, new_axis_i=0 - ) - utils._add_output_to_block(if_context_equal.block, result_if) - result_else = else_context_equal.op( - "ConcatFromSequence", tensor_list, axis_i=1, new_axis_i=0 - ) - utils._add_output_to_block(else_context_equal.block, result_else) - result = if_op_greater.node().output() - - return result - - -@_onnx_symbolic("aten::vstack") -def vstack(g: jit_utils.GraphContext, tensor_list: _C.Value): - tensor_list = atleast_2d(g, tensor_list) - return g.op("ConcatFromSequence", tensor_list, axis_i=0, new_axis_i=0) +from torch.onnx._internal.torchscript_exporter.symbolic_opset11 import * # noqa: F401,F403 diff --git a/torch/onnx/symbolic_opset12.py b/torch/onnx/symbolic_opset12.py index 21489fbb7972..63e137734e8a 100644 --- a/torch/onnx/symbolic_opset12.py +++ b/torch/onnx/symbolic_opset12.py @@ -1,464 +1,8 @@ -# mypy: allow-untyped-defs -# mypy: disable-error-code=arg-type +"""Backward compatibility module for torch.onnx.symbolic_opset12.""" + from __future__ import annotations -import functools -import sys -import torch -from torch._C import _onnx as _C_onnx -from torch.onnx import ( - _type_utils, - errors, - symbolic_helper, - symbolic_opset9 as opset9, - utils, -) -from torch.onnx._internal import jit_utils, registration +__all__: list[str] = [] - -# EDITING THIS FILE? READ THIS FIRST! 
-# see Note [Edit Symbolic Files] in README.md - -# This file exports ONNX ops for opset 12 - -__all__ = [ - "argmax", - "argmin", - "binary_cross_entropy_with_logits", - "celu", - "cross_entropy_loss", - "dropout", - "einsum", - "ge", - "le", - "native_dropout", - "nll_loss", - "nll_loss2d", - "nll_loss_nd", - "outer", - "pow", - "tensordot", - "unfold", -] - -_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=12) - - -def _einsum_helper(g: jit_utils.GraphContext, equation, tensors): - if not tensors: - raise RuntimeError("Einsum inputs are empty.") - # ONNX does not support bool for Einsum inputs. - if symbolic_helper._is_bool(tensors[0]): - tensors = [ - g.op("Cast", tensor, to_i=_C_onnx.TensorProtoDataType.INT64) - for tensor in tensors - ] - return g.op( - "Cast", - g.op("Einsum", *tensors, equation_s=equation), - to_i=_C_onnx.TensorProtoDataType.BOOL, - ) - else: - return g.op("Einsum", *tensors, equation_s=equation) - - -@_onnx_symbolic("aten::einsum") -@symbolic_helper.parse_args("s", "v", "is") -def einsum(g: jit_utils.GraphContext, equation, tensor_list, path=None): - tensors = symbolic_helper._unpack_list(tensor_list) - return _einsum_helper(g, equation, tensors) - - -@_onnx_symbolic("aten::outer") -@symbolic_helper.parse_args("v", "v") -def outer(g: jit_utils.GraphContext, input, other): - # make sure to cast other to self's type - if _type_utils.JitScalarType.from_value( - other, _type_utils.JitScalarType.UNDEFINED - ) != _type_utils.JitScalarType.from_value(input): - other = g.op( - "Cast", - other, - to_i=_type_utils.JitScalarType.from_value(input).onnx_type(), - ) - return _einsum_helper(g, "i,j->ij", [input, other]) - - -def _dropout_returns_masked_input_and_mask( - g: jit_utils.GraphContext, input: torch._C.Value, p: float, train: bool -) -> tuple[torch._C.Value, torch._C.Value | None]: - symbolic_helper.check_training_mode(train, "dropout") - # In eval mode, dropout is non-op. That is, if the node's - # train param is set to False, dropout just returns its inputs. - if not train: - return input, None - p = g.op("Constant", value_t=torch.tensor(p)) - t = g.op("Constant", value_t=torch.tensor(train, dtype=torch.bool)) - r, mask = g.op("Dropout", input, p, t, outputs=2) - return r, mask - - -@_onnx_symbolic("aten::dropout") -@symbolic_helper.parse_args("v", "f", "b") -def dropout(g: jit_utils.GraphContext, input, p, train): - masked, _ = _dropout_returns_masked_input_and_mask(g, input, p, train) - return masked - - -@_onnx_symbolic("aten::native_dropout") -@symbolic_helper.parse_args("v", "f", "b") -def native_dropout(g: jit_utils.GraphContext, input, p, train): - return _dropout_returns_masked_input_and_mask(g, input, p, train) - - -@_onnx_symbolic("aten::nll_loss") -def nll_loss(g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index): - # none reduction : onnx::Constant[value={0}] - # mean reduction : onnx::Constant[value={1}] - # sum reduction : onnx::Constant[value={2}] - reduction = symbolic_helper._maybe_get_const(reduction, "i") - reduction_vals = ["none", "mean", "sum"] - reduction = reduction_vals[reduction] - - # in onnx NegativeLogLikelihoodLoss specification, ignore_index is optional without default value. - # therefore we need to set ignore_index attribute even if it is not specified (e.g. ignore_index=-100). 
- ignore_index = symbolic_helper._maybe_get_const(ignore_index, "i") - if weight.node().mustBeNone(): - nllloss = g.op( - "NegativeLogLikelihoodLoss", - self, - target, - reduction_s=reduction, - ignore_index_i=ignore_index, - ) - else: - nllloss = g.op( - "NegativeLogLikelihoodLoss", - self, - target, - weight, - reduction_s=reduction, - ignore_index_i=ignore_index, - ) - - return nllloss - - -@_onnx_symbolic("aten::nll_loss2d") -def nll_loss2d( - g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index -): - return nll_loss(g, self, target, weight, reduction, ignore_index) - - -@_onnx_symbolic("aten::nll_loss_nd") -def nll_loss_nd( - g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index -): - return nll_loss(g, self, target, weight, reduction, ignore_index) - - -@_onnx_symbolic("aten::cross_entropy_loss") -def cross_entropy_loss( - g: jit_utils.GraphContext, - self, - target, - weight, - reduction, - ignore_index, - label_smoothing, -): - # none reduction : onnx::Constant[value={0}] - # mean reduction : onnx::Constant[value={1}] - # sum reduction : onnx::Constant[value={2}] - reduction = symbolic_helper._maybe_get_const(reduction, "i") - reduction_vals = ["none", "mean", "sum"] - reduction = reduction_vals[reduction] - - label_smoothing = symbolic_helper._maybe_get_const(label_smoothing, "f") - if label_smoothing is not None and label_smoothing > 0.0: - raise errors.SymbolicValueError( - "Unsupported: ONNX does not support label_smoothing", self - ) - - # in onnx SoftmaxCrossEntropyLoss specification, ignore_index is optional without default value. - # therefore we need to set ignore_index attribute even if it is not specified (e.g. ignore_index=-100). - ignore_index = symbolic_helper._maybe_get_const(ignore_index, "i") - if weight.node().mustBeNone(): - celoss = g.op( - "SoftmaxCrossEntropyLoss", - self, - target, - reduction_s=reduction, - ignore_index_i=ignore_index, - ) - else: - celoss = g.op( - "SoftmaxCrossEntropyLoss", - self, - target, - weight, - reduction_s=reduction, - ignore_index_i=ignore_index, - ) - - return celoss - - -@_onnx_symbolic("aten::binary_cross_entropy_with_logits") -@symbolic_helper.parse_args("v", "v", "v", "v", "i") -def binary_cross_entropy_with_logits( - g: jit_utils.GraphContext, input, target, weight, pos_weight, reduction -): - p = g.op("Constant", value_t=torch.tensor([1])) - sig_x = opset9.sigmoid(g, input) - log_sig_x = opset9.log(g, sig_x) - sub_1_x = opset9.sub(g, p, sig_x) - sub_1_y = opset9.sub(g, p, target) - log_1_x = opset9.log(g, sub_1_x) - if pos_weight is None or symbolic_helper._is_none(pos_weight): - output = opset9.neg( - g, - opset9.add( - g, opset9.mul(g, target, log_sig_x), opset9.mul(g, sub_1_y, log_1_x) - ), - ) - else: - output = opset9.neg( - g, - opset9.add( - g, - opset9.mul(g, opset9.mul(g, target, log_sig_x), pos_weight), - opset9.mul(g, sub_1_y, log_1_x), - ), - ) - - if weight is not None and not symbolic_helper._is_none(weight): - output = opset9.mul(g, weight, output) - - reduction = symbolic_helper._maybe_get_const(reduction, "i") - if reduction == 0: - return output - elif reduction == 1: - return g.op("ReduceMean", output, keepdims_i=0) - elif reduction == 2: - return g.op("ReduceSum", output, keepdims_i=0) - else: - return symbolic_helper._onnx_unsupported( - "binary_cross_entropy_with_logits with reduction other than none, mean, or sum", - input, - ) - - -@_onnx_symbolic("aten::celu") -def celu(g: jit_utils.GraphContext, self, alpha): - alpha = symbolic_helper._maybe_get_const(alpha, 
"f") - # if the input is of type double cast it to float - if ( - _type_utils.JitScalarType.from_value(self, _type_utils.JitScalarType.UNDEFINED) - == _type_utils.JitScalarType.DOUBLE - ): - self = g.op("Cast", self, to_i=_C_onnx.TensorProtoDataType.FLOAT) - out = g.op("Celu", self, alpha_f=alpha) - return g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.DOUBLE) - - return g.op("Celu", self, alpha_f=alpha) - - -@_onnx_symbolic("aten::argmax") -@symbolic_helper.parse_args("v", "v", "b") -def argmax( - g: jit_utils.GraphContext, - input: torch._C.Value, - dim: torch._C.Value, - keepdim: bool, -): - return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMax") - - -@_onnx_symbolic("aten::argmin") -@symbolic_helper.parse_args("v", "v", "b") -def argmin( - g: jit_utils.GraphContext, - input: torch._C.Value, - dim: torch._C.Value, - keepdim: bool, -): - return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMin") - - -@_onnx_symbolic("aten::pow") -def pow(g: jit_utils.GraphContext, self, exponent): - return g.op("Pow", self, exponent) - - -@_onnx_symbolic("aten::ge") -def ge(g: jit_utils.GraphContext, input, other): - return g.op("GreaterOrEqual", input, other) - - -@_onnx_symbolic("aten::le") -def le(g: jit_utils.GraphContext, input, other): - return g.op("LessOrEqual", input, other) - - -@_onnx_symbolic("aten::unfold") -@symbolic_helper.parse_args("v", "i", "v", "v") -def unfold(g: jit_utils.GraphContext, input, dimension, size, step): - const_size = symbolic_helper._maybe_get_const(size, "i") - const_step = symbolic_helper._maybe_get_const(step, "i") - if not symbolic_helper._is_value(const_size) and not symbolic_helper._is_value( - const_step - ): - return opset9.unfold(g, input, dimension, const_size, const_step) - - sizedim = symbolic_helper._get_tensor_dim_size(input, dimension) - if sizedim is not None: - low_start = g.op("Constant", value_t=torch.tensor(0)) - low_end = g.op("Constant", value_t=torch.tensor(sizedim)) - hi_end = g.op("Constant", value_t=torch.tensor(sizedim + 1)) - low_indices = g.op("Range", low_start, low_end, step) - hi_indices = g.op("Range", size, hi_end, step) - - low_size = symbolic_helper._size_helper( - g, low_indices, g.op("Constant", value_t=torch.tensor(0)) - ) - hi_size = symbolic_helper._size_helper( - g, hi_indices, g.op("Constant", value_t=torch.tensor(0)) - ) - - ndim = symbolic_helper._get_tensor_rank(input) - assert ndim is not None - perm = list(range(0, ndim)) - perm.append(perm.pop(dimension)) - - unsqueeze_list = [] - loop_condition = g.op("Constant", value_t=torch.tensor(1)) - loop_condition = g.op( - "Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL - ) - loop_len = g.op("Min", low_size, hi_size) - - loop, (loop_context,), _ = jit_utils.add_op_with_blocks( - g, "Loop", loop_len, loop_condition, n_blocks=1 - ) - - loop_block = loop_context.block - block_input_iter = utils._add_input_to_block(loop_block) - cond = utils._add_input_to_block(loop_block) # noqa: F841 - - starts = loop_context.op("Gather", low_indices, block_input_iter) - ends = loop_context.op("Gather", hi_indices, block_input_iter) - axes = loop_context.op("Constant", value_t=torch.tensor([2])) - starts = symbolic_helper._unsqueeze_helper(loop_context, starts, [0]) - ends = symbolic_helper._unsqueeze_helper(loop_context, ends, [0]) - stack = loop_context.op("Slice", input, starts, ends, axes) - - unsqueeze = symbolic_helper._unsqueeze_helper( - loop_context, loop_context.op("Transpose", stack, perm_i=perm), [dimension] - ) - 
unsqueeze_list.append(unsqueeze) - concat = loop_context.op("Concat", *unsqueeze_list, axis_i=0) - - cond_out = loop_context.op( - "Cast", loop_condition, _C_onnx.TensorProtoDataType.BOOL - ) - utils._add_output_to_block(loop_block, cond_out) - utils._add_output_to_block(loop_block, concat) - - loop_output = loop.node().output() - perm = [0, 1, 2, 3, 4] - perm[0], perm[dimension + 1] = perm[dimension + 1], perm[0] - transpose = g.op("Transpose", loop_output, perm_i=perm) - squeeze = symbolic_helper._squeeze_helper(g, transpose, [0]) - - return squeeze - - return symbolic_helper._unimplemented("Unfold", "input size not accessible") - - -@_onnx_symbolic("aten::tensordot") -@symbolic_helper.parse_args("v", "v", "is", "is", "v") -def tensordot(g: jit_utils.GraphContext, input_a, input_b, dims_a, dims_b, out=None): - if out is not None: - symbolic_helper._unimplemented( - "Tensordot", "Out parameter is not supported for tensordot." - ) - - dim_count_a = symbolic_helper._get_tensor_rank(input_a) - if dim_count_a is None: - raise errors.SymbolicValueError( - "Unsupported: ONNX export of tensordot for tensor(input_a) of unknown rank.", - input_a, - ) - - dim_count_b = symbolic_helper._get_tensor_rank(input_b) - if dim_count_b is None: - raise errors.SymbolicValueError( - "Unsupported: ONNX export of tensordot for tensor(input_b) of unknown rank.", - input_b, - ) - - dims_a = [ - (dims_a[i] + dim_count_a) if (dims_a[i] < 0) else dims_a[i] - for i in range(len(dims_a)) - ] - dims_b = [ - (dims_b[i] + dim_count_b) if (dims_b[i] < 0) else dims_b[i] - for i in range(len(dims_b)) - ] - - left_dims_a = [i for i in range(dim_count_a) if (i not in dims_a)] - left_dims_b = [i for i in range(dim_count_b) if (i not in dims_b)] - - new_input_a = opset9.permute(g, input_a, left_dims_a + dims_a) - new_input_b = opset9.permute(g, input_b, dims_b + left_dims_b) - - input_shape = g.op("Shape", new_input_a) - left_sizes_a = symbolic_helper._slice_helper( - g, input_shape, axes=[0], starts=[0], ends=[len(left_dims_a)] - ) - shape_sizes = [ - left_sizes_a, - g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)), - ] - output_a = opset9._reshape_from_tensor(g, new_input_a, shape_sizes) - - input_shape = g.op("Shape", output_a) - slices = symbolic_helper._slice_helper( - g, input_shape, axes=[0], starts=[-1], ends=[sys.maxsize] - ) - shape_sizes = [ - g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)), - slices, - ] - output_a = opset9._reshape_from_tensor(g, new_input_a, shape_sizes) - - input_shape = g.op("Shape", new_input_b) - left_sizes_b = symbolic_helper._slice_helper( - g, input_shape, axes=[0], starts=[len(dims_b)], ends=[sys.maxsize] - ) - slices = symbolic_helper._slice_helper( - g, input_shape, axes=[0], starts=[0], ends=[len(dims_b)] - ) - shape_sizes = [ - slices, - g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)), - ] - output_b = opset9._reshape_from_tensor(g, new_input_b, shape_sizes) - - input_shape = g.op("Shape", output_b) - slices = symbolic_helper._slice_helper( - g, input_shape, axes=[0], starts=[-1], ends=[sys.maxsize] - ) - shape_sizes = [ - g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)), - slices, - ] - output_b = opset9._reshape_from_tensor(g, new_input_b, shape_sizes) - - output = einsum(g, "ij,jk->ik", g.op("prim::ListConstruct", *[output_a, output_b])) - - shape_sizes = [left_sizes_a, left_sizes_b] - return opset9._reshape_from_tensor(g, output, shape_sizes) +from torch.onnx._internal.torchscript_exporter.symbolic_opset12 import * # noqa: 
F401,F403 diff --git a/torch/onnx/symbolic_opset13.py b/torch/onnx/symbolic_opset13.py index aa40c5578042..18aff9295be8 100644 --- a/torch/onnx/symbolic_opset13.py +++ b/torch/onnx/symbolic_opset13.py @@ -1,1113 +1,8 @@ -# mypy: allow-untyped-defs -# EDITING THIS FILE? READ THIS FIRST! -# see Note [Edit Symbolic Files] in README.md +"""Backward compatibility module for torch.onnx.symbolic_opset13.""" -# This file exports ONNX ops for opset 13 -import functools +from __future__ import annotations -import torch -import torch._C._onnx as _C_onnx -from torch.onnx import ( - _constants, - _type_utils, - errors, - symbolic_helper, - symbolic_opset11 as opset11, - symbolic_opset9 as opset9, - utils, -) -from torch.onnx._internal import jit_utils, registration +__all__: list[str] = [] -_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=13) - - -@_onnx_symbolic("aten::softmax") -@symbolic_helper.parse_args("v", "i", "none") -def softmax(g: jit_utils.GraphContext, input, dim, dtype=None): - softmax = g.op("Softmax", input, axis_i=dim) - if dtype and dtype.node().kind() != "prim::Constant": - parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype") - softmax = g.op( - "Cast", softmax, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type() - ) - - return softmax - - -@_onnx_symbolic("aten::log_softmax") -@symbolic_helper.parse_args("v", "i", "none") -def log_softmax(g: jit_utils.GraphContext, input, dim, dtype=None): - return_op = g.op("LogSoftmax", input, axis_i=dim) - if dtype and dtype.node().kind() != "prim::Constant": - parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype") - return_op = g.op( - "Cast", return_op, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type() - ) - return return_op - - -@_onnx_symbolic("aten::frobenius_norm") -@symbolic_helper.parse_args("v", "v", "i") -def frobenius_norm(g: jit_utils.GraphContext, self, dim=None, keepdim=False): - dim_val = symbolic_helper._maybe_get_const(dim, "is") - if not symbolic_helper._is_value(dim_val) and len(dim_val) == 0: - return g.op("ReduceL2", self, keepdims_i=0) - sqr = g.op("Mul", self, self) - sumsqr = symbolic_helper._reducesum_helper(g, sqr, dim, keepdims_i=keepdim) - return g.op("Sqrt", sumsqr) - - -@_onnx_symbolic("aten::split") -@symbolic_helper.parse_args("v", "v", "i", "i") -def split(g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None): - if not symbolic_helper._is_split_static(split_size_or_sizes, _outputs): - split_out = g.op("SplitToSequence", self, split_size_or_sizes, axis_i=dim) - if _outputs is None: - return split_out - # Convert to multiple slice nodes iff number of splits and number of outputs are statically known. 
- if ( - symbolic_helper._is_packed_list(split_size_or_sizes) - and len(symbolic_helper._unpack_list(split_size_or_sizes)) == _outputs - ): - split_sizes = [ - symbolic_helper._unsqueeze_helper(g, v, [0]) - for v in symbolic_helper._unpack_list(split_size_or_sizes) - ] - - start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long)) - axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long)) - res = [] - for i in range(_outputs): - end = g.op( - "Add", start, split_sizes[i] - ) # split_sizes is a list of same length as _outputs - res.append(g.op("Slice", self, start, end, axis)) - start = end - return res - return [ - g.op( - "SequenceAt", - split_out, - g.op("Constant", value_t=torch.tensor([i], dtype=torch.long)), - ) - for i in range(_outputs) - ] - - split_val = symbolic_helper._node_get(split_size_or_sizes.node(), "value") - if split_val.dim() > 0: - return g.op("Split", self, split_size_or_sizes, axis_i=dim, outputs=_outputs) - split_size = symbolic_helper._get_const(split_size_or_sizes, "i", "split_size") - - size = symbolic_helper._get_tensor_dim_size(self, dim) - if size is None: - if _outputs is not None: - size = split_size * _outputs - else: - raise errors.SymbolicValueError( - "Unknown dimension size not supported", self - ) - splits = [split_size] * (size // split_size) - leftover = size % split_size - if leftover: - splits.append(leftover) - splits = g.op("Constant", value_t=torch.tensor(splits)) - return g.op("Split", self, splits, axis_i=dim, outputs=_outputs) - - -@_onnx_symbolic("aten::split_with_sizes") -def split_with_sizes(g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None): - return split(g, self, split_sizes, dim, _outputs) - - -@_onnx_symbolic("aten::unsafe_split") -def unsafe_split( - g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None -): - return split(g, self, split_size_or_sizes, dim, _outputs) - - -@_onnx_symbolic("aten::unsafe_split_with_sizes") -def unsafe_split_with_sizes( - g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None -): - return split_with_sizes(g, self, split_sizes, dim, _outputs) - - -@_onnx_symbolic("aten::tensor_split") -@symbolic_helper.parse_args("v", "v", "i", "i") -def tensor_split( - g: jit_utils.GraphContext, self, indices_or_sections, dim, _outputs=None -): - axis = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long)) - axis = opset11.unsqueeze(g, axis, 0) - const_1 = g.op("Constant", value_t=torch.tensor(1, dtype=torch.long)) - - if symbolic_helper._is_split_static(indices_or_sections, _outputs): - split_val = symbolic_helper._node_get(indices_or_sections.node(), "value") - - if split_val.dim() > 0: - start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long)) - res = [] - assert _outputs is not None - for i in range(_outputs - 1): - end = g.op( - "Gather", - indices_or_sections, - g.op("Constant", value_t=torch.tensor([i], dtype=torch.long)), - axis_i=0, - ) - res.append(g.op("Slice", self, start, end, axis)) - start = end - - end = symbolic_helper._size_helper(g, self, axis) - res.append(g.op("Slice", self, start, end, axis)) - return res - - split_size = symbolic_helper._get_const( - indices_or_sections, "i", "indices_or_sections" - ) - - size = symbolic_helper._get_tensor_dim_size(self, dim) - if size is None: - if _outputs is not None: - size = split_size * _outputs - else: - raise errors.SymbolicValueError( - "Unknown dimension size not supported", self - ) - - min_split_size = size // split_size - num_splits_one_extra = size % split_size - 
- splits = num_splits_one_extra * [min_split_size + 1] - leftover = (split_size - num_splits_one_extra) * [min_split_size] - - splits = g.op( - "Constant", value_t=torch.tensor(splits + leftover, dtype=torch.long) - ) - return g.op("Split", self, splits, axis_i=dim, outputs=_outputs) - - if ( - symbolic_helper._is_tensor(indices_or_sections) - and symbolic_helper._get_tensor_rank(indices_or_sections) == 1 - ): - loop_len = symbolic_helper._size_helper( - g, indices_or_sections, g.op("Constant", value_t=torch.tensor(0)) - ) - loop_len = opset11.unsqueeze(g, loop_len, 0) - loop_condition = g.op("Cast", const_1, to_i=_C_onnx.TensorProtoDataType.BOOL) - - # To make the first slice in the below loop work, - # we pad a zero to the first position so that it will be the initial start of slice. - padding_0 = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long)) - indices_or_sections = g.op("Concat", padding_0, indices_or_sections, axis_i=0) - - final_splits = g.op("SequenceEmpty") - # Loop inputs - loop, (loop_context,), _ = jit_utils.add_op_with_blocks( - g, "Loop", loop_len, loop_condition, final_splits, outputs=1, n_blocks=1 - ) - - loop_block = loop_context.block - block_input_iter = utils._add_input_to_block(loop_block) - cond = utils._add_input_to_block(loop_block) # noqa: F841 - final_splits = utils._add_input_to_block(loop_block) - - start = loop_context.op( - "Gather", indices_or_sections, block_input_iter, axis_i=0 - ) - end = loop_context.op( - "Gather", - indices_or_sections, - loop_context.op("Add", block_input_iter, const_1), - axis_i=0, - ) - - slice = loop_context.op("Slice", self, start, end, axis) - final_splits = loop_context.op("SequenceInsert", final_splits, slice) - - # Loop outputs - cond_out = loop_context.op("Identity", loop_condition) - utils._add_output_to_block(loop_block, cond_out) - utils._add_output_to_block(loop_block, final_splits) - - loop_out = loop.node().output() - start = g.op( - "Gather", - indices_or_sections, - g.op("Constant", value_t=torch.tensor(-1, dtype=torch.long)), - axis_i=0, - ) - start = opset11.unsqueeze(g, start, 0) - end = symbolic_helper._size_helper(g, self, axis) - - last_slice = g.op("Slice", self, start, end, axis) - - return g.op("SequenceInsert", loop_out, last_slice) - - else: # scalar tensor - dim_size = symbolic_helper._size_helper(g, self, axis) - min_split_size = g.op("Div", dim_size, indices_or_sections) - min_split_size_plus_1 = g.op( - "Add", - min_split_size, - const_1, - ) - num_splits_one_extra = g.op("Mod", dim_size, indices_or_sections) - splits = g.op("Tile", min_split_size_plus_1, num_splits_one_extra) - leftover = g.op( - "Tile", - min_split_size, - g.op( - "Sub", - opset11.unsqueeze(g, indices_or_sections, 0), - num_splits_one_extra, - ), - ) - - splits = g.op("Concat", splits, leftover, axis_i=0) - if _outputs is None: - return g.op("SplitToSequence", self, splits, axis_i=dim) - return g.op("Split", self, splits, axis_i=dim, outputs=_outputs) - - -@_onnx_symbolic("aten::unbind") -@symbolic_helper.parse_args("v", "i", "i") -def unbind(g: jit_utils.GraphContext, self, dim=0, _outputs=None): - if _outputs is None: - return g.op( - "SplitToSequence", - self, - g.op("Constant", value_t=torch.tensor(1, dtype=torch.long)), - axis_i=dim, - keepdims_i=0, - ) - - splits = g.op("Constant", value_t=torch.tensor([1] * _outputs)) - outputs = g.op("Split", self, splits, axis_i=dim, outputs=_outputs) - outputs = [outputs] if _outputs == 1 else outputs - squeezed_outputs = [ - g.op("Squeeze", out, g.op("Constant", 
value_t=torch.tensor([dim]))) - for out in outputs - ] - return squeezed_outputs - - -@_onnx_symbolic("aten::nonzero_numpy") -# Emitted from `torch.nonzero(x, as_tuple=True)` -def nonzero_numpy(g: jit_utils.GraphContext, input, _outputs=None): - return unbind(g, opset9.nonzero(g, input), 1, _outputs=_outputs) - - -@_onnx_symbolic("aten::where") -@symbolic_helper.parse_args("v", "v", "v", "i") -def where(g: jit_utils.GraphContext, condition, self=None, other=None, _outputs=None): - # Assumes that torch.where's first argument takes only Bool and Byte tensors. - if not symbolic_helper._is_bool(condition): - condition = g.op("Cast", condition, to_i=_C_onnx.TensorProtoDataType.BOOL) - if self is None: - condition = opset9.nonzero(g, condition) - return symbolic_helper._unbind_helper( - g, condition, g.op("Constant", value_t=torch.tensor(1)), _outputs - ) - return g.op("Where", condition, self, other) - - -@_onnx_symbolic("aten::fake_quantize_per_channel_affine") -@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i") -def fake_quantize_per_channel_affine( - g: jit_utils.GraphContext, - inputs, - scale, - zero_point, - axis, - quant_min=-128, - quant_max=127, -): - # NOTE: (0, 127) is allowed as special case. PyTorch restricts activations to be in the range (0, 127). - # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422 - if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]: - raise errors.SymbolicValueError( - "For (quant_min, quant_max), ONNX allows only (0, 127), (0, 255) and (-128, 127). " - f"Got ({quant_min}, {quant_max})", - inputs, - ) - # ONNX defines zero_point to be int8 or uint8 - if quant_min == 0: - zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8) - else: - zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8) - quantized = g.op("QuantizeLinear", inputs, scale, zero_point, axis_i=axis) - if (quant_min, quant_max) == (0, 127): - quantized = g.op( - "Clip", - quantized, - opset9.unused(g), - g.op("Constant", value_t=torch.tensor(127, dtype=torch.uint8)), - ) - return g.op("DequantizeLinear", quantized, scale, zero_point, axis_i=axis) - - -@_onnx_symbolic("aten::fake_quantize_per_tensor_affine") -@symbolic_helper.parse_args("v", "v", "v", "i", "i") -def fake_quantize_per_tensor_affine( - g: jit_utils.GraphContext, - inputs, - scale, - zero_point, - quant_min=-128, - quant_max=127, -): - # NOTE: (0, 127) is allowed as special case. PyTorch restricts activations to be in the range (0, 127). - # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422 - if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]: - raise errors.SymbolicValueError( - "For (quant_min, quant_max), ONNX allows only (0, 127), (0, 255) and (-128, 127). 
" - f"Got ({quant_min}, {quant_max})", - inputs, - ) - if quant_min == 0: - zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8) - else: - zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8) - if ( - _type_utils.JitScalarType.from_value(scale, _type_utils.JitScalarType.UNDEFINED) - != _type_utils.JitScalarType.FLOAT - ): - scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT) - quantized = g.op("QuantizeLinear", inputs, scale, zero_point) - if (quant_min, quant_max) == (0, 127): - quantized = g.op( - "Clip", - quantized, - opset9.unused(g), - g.op("Constant", value_t=torch.tensor(127, dtype=torch.uint8)), - ) - return g.op("DequantizeLinear", quantized, scale, zero_point) - - -def _reduce_op_symbolic(onnx_op_name): - def symbolic(g, self, dim=None, keepdim=None): - self = symbolic_helper._maybe_cast_reduce_op_input(g, self) - if dim is None: - # all-reduce path - return symbolic_helper._handle_reduce_dim_none(g, self, onnx_op_name) - else: - keepdim = symbolic_helper._get_const(keepdim, "i", "keepdim") - return g.op(onnx_op_name, self, dim, keepdims_i=keepdim) - - return symbolic - - -@_onnx_symbolic( - "aten::sum", - decorate=[symbolic_helper._apply_params("ReduceSum", "sum")], -) -def _reduce_with_dtype(onnx_op, name): - symbolic = _reduce_op_symbolic(onnx_op) - - @symbolic_helper._overload_by_arg_count - def reduce(g, *args, **kwargs): - @symbolic_helper.parse_args("v", "none") - def reduce_nodim(g, self, dtype): - dtype_onnx = None - if dtype.node().kind() == "onnx::Constant": - dtype = symbolic_helper._get_const(dtype, "i", "dtype") - dtype_onnx = _type_utils.JitScalarType(dtype).onnx_type() - self = g.op("Cast", self, to_i=dtype_onnx) - elif dtype.node().kind() != "prim::Constant": - return symbolic_helper._unimplemented(name, "dtype", dtype) - result = symbolic(g, self) - if dtype_onnx is not None: - result_dtype_onnx = _type_utils.JitScalarType.from_value( - result - ).onnx_type() - if result_dtype_onnx != dtype_onnx: - result = g.op("Cast", result, to_i=dtype_onnx) - return result - - @symbolic_helper.parse_args("v", "v", "i", "none") - def reduce_dim(g, self, dim, keepdim, dtype): - dtype_onnx = None - if dtype.node().kind() == "onnx::Constant": - dtype = symbolic_helper._get_const(dtype, "i", "dtype") - dtype_onnx = _type_utils.JitScalarType(dtype).onnx_type() - self = g.op("Cast", self, to_i=dtype_onnx) - elif dtype.node().kind() != "prim::Constant": - return symbolic_helper._unimplemented(name, "dtype", dtype) - result = symbolic(g, self, dim, keepdim) - if dtype_onnx is not None: - result_dtype_onnx = _type_utils.JitScalarType.from_value( - result - ).onnx_type() - if result_dtype_onnx != dtype_onnx: - result = g.op("Cast", result, to_i=dtype_onnx) - return result - - return reduce_nodim, reduce_dim - - return reduce - - -# Ported from -# https://github.com/microsoft/onnxscript/blob/6b1b81700b4523f31d8c6d3321e5d8ef5d42b764/onnxscript/function_libs/torch_aten/ops/core.py#L6097 -# NOTE: Supporting aten::unflatten before opset13 needs helper function to adjust ONNX op changes in Concat, Slice, ... -@_onnx_symbolic("aten::unflatten") -def unflatten(g: jit_utils.GraphContext, input, dim, unflattened_size): - input_dim = symbolic_helper._get_tensor_rank(input) - if input_dim is None: - return symbolic_helper._unimplemented( - "dim", - "ONNX and PyTorch use different strategies to split the input. 
" - "Input rank must be known at export time.", - ) - - # dim could be negative - input_dim = g.op("Constant", value_t=torch.tensor([input_dim], dtype=torch.int64)) - dim = g.op("Add", input_dim, dim) - dim = g.op("Mod", dim, input_dim) - - input_size = g.op("Shape", input) - - head_start_idx = g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)) - head_end_idx = g.op( - "Reshape", dim, g.op("Constant", value_t=torch.tensor([1], dtype=torch.int64)) - ) - head_part_rank = g.op("Slice", input_size, head_start_idx, head_end_idx) - - dim_plus_one = g.op( - "Add", dim, g.op("Constant", value_t=torch.tensor([1], dtype=torch.int64)) - ) - tail_start_idx = g.op( - "Reshape", - dim_plus_one, - g.op("Constant", value_t=torch.tensor([1], dtype=torch.int64)), - ) - tail_end_idx = g.op( - "Constant", value_t=torch.tensor([_constants.INT64_MAX], dtype=torch.int64) - ) - tail_part_rank = g.op("Slice", input_size, tail_start_idx, tail_end_idx) - - final_shape = g.op( - "Concat", head_part_rank, unflattened_size, tail_part_rank, axis_i=0 - ) - - return symbolic_helper._reshape_helper(g, input, final_shape) - - -@_onnx_symbolic("aten::unsafe_chunk") -@symbolic_helper.parse_args("v", "i", "i", "i") -def unsafe_chunk(g: jit_utils.GraphContext, self, chunks, dim, _outputs=None): - if _outputs is None: - return g.op( - "SplitToSequence", - self, - g.op("Constant", value_t=torch.tensor(1, dtype=torch.long)), - axis_i=dim, - keepdims_i=0, - ) - - size = symbolic_helper._get_tensor_dim_size(self, dim) - if size is None: - return symbolic_helper._unimplemented("unsafe_chunk", "unknown dimension size") - split_size = (size + chunks - 1) // chunks - splits = [split_size] * (size // split_size) - leftover = size % split_size - if leftover: - splits.append(leftover) - - # TODO: So far we don"t have a module using this method. We"ll keep - # this as a constant unless we see a request of dynamics in any - # user's modules. - splits = g.op("Constant", value_t=torch.tensor(splits, dtype=torch.long)) - return g.op("Split", self, splits, axis_i=dim, outputs=_outputs) - - -@_onnx_symbolic("aten::tile") -def tile(g: jit_utils.GraphContext, self, dims): - self_shape = g.op("Shape", self) - self_rank = g.op("Size", self_shape) - dims_rank = g.op("Size", dims) - diff = g.op("Sub", self_rank, dims_rank) - const_zero = g.op("Constant", value_t=torch.tensor([0])) - - # 1. If dims is shorter than self.shape pad dims with 1 - dims_shorter_than_self_shape = g.op("Greater", diff, const_zero) - ( - if_op_greater, - (if_context_greater, else_context_greater), - _, - ) = jit_utils.add_op_with_blocks( - g, "If", dims_shorter_than_self_shape, n_blocks=2, outputs=1 - ) - const_one = if_context_greater.op("Constant", value_t=torch.LongTensor([1])) - diff_1d_greater = if_context_greater.op("Reshape", diff, const_one) - exapnd_ones_greater = if_context_greater.op("Expand", const_one, diff_1d_greater) - dims_ = if_context_greater.op("Concat", exapnd_ones_greater, dims, axis_i=0) - utils._add_output_to_block(if_context_greater.block, dims_) - identity_dim = else_context_greater.op("Identity", dims) - utils._add_output_to_block(else_context_greater.block, identity_dim) - dims_final = if_op_greater.node().output() - - # 2. 
If dims is longer than self.shape pad self.shape with 1 - dims_longer_than_self_shape = g.op("Less", diff, const_zero) - ( - if_op_less, - (if_context_less, else_context_less), - _, - ) = jit_utils.add_op_with_blocks( - g, "If", dims_longer_than_self_shape, n_blocks=2, outputs=1 - ) - const_one = if_context_less.op("Constant", value_t=torch.LongTensor([1])) - diff_1d_less = if_context_less.op( - "Reshape", - if_context_less.op("Abs", diff), - const_one, - ) - exapnd_ones_less = if_context_less.op("Expand", const_one, diff_1d_less) - self_final_shape = if_context_less.op( - "Concat", exapnd_ones_less, self_shape, axis_i=0 - ) - self_ = if_context_less.op("Reshape", self, self_final_shape) - utils._add_output_to_block(if_context_less.block, self_) - identity_self = else_context_less.op("Identity", self) - utils._add_output_to_block(else_context_less.block, identity_self) - self_final = if_op_less.node().output() - - dims_final = g.op("Cast", dims_final, to_i=_C_onnx.TensorProtoDataType.INT64) - return g.op("Tile", self_final, dims_final) - - -@_onnx_symbolic("aten::repeat_interleave") -def repeat_interleave( - g: jit_utils.GraphContext, self, repeats, dim=None, output_size=None -): - repeats_dim = symbolic_helper._get_tensor_rank(repeats) - repeats_sizes = symbolic_helper._get_tensor_sizes(repeats) - input_sizes = symbolic_helper._get_tensor_sizes(self) - if repeats_dim is None: - raise errors.SymbolicValueError( - "Unsupported: ONNX export of repeat_interleave for unknown repeats rank.", - self, - ) - if repeats_sizes is None: - raise errors.SymbolicValueError( - "Unsupported: ONNX export of repeat_interleave for unknown repeats size.", - self, - ) - if input_sizes is None: - raise errors.SymbolicValueError( - "Unsupported: ONNX export of repeat_interleave for unknown input size.", - self, - ) - - final_dim = dim - # if dim is None flatten - # By default, use the flattened input array, and return a flat output array - if symbolic_helper._is_none(dim): - self = symbolic_helper._reshape_helper( - g, self, g.op("Constant", value_t=torch.tensor([-1])) - ) - dim = torch.tensor(0, dtype=torch.int64) - else: - dim = symbolic_helper._maybe_get_scalar(dim) - - # Handle cases where dim is negative - if dim < 0: - dim += len(input_sizes) - - output_sizes = input_sizes.copy() - for idx, input_size in enumerate(input_sizes): - if input_size is None: - output_sizes[idx], input_sizes[idx] = 0, -1 - - # Check if all indices should be repeated the same number of times. 
- if repeats_dim == 0 or (repeats_dim == 1 and repeats_sizes[0] == 1): - return symbolic_helper._repeat_interleave_single_value_repeat_helper( - g, self, repeats, dim - ) - - cond_dynamic_repeats = repeats_dim == 1 and repeats_sizes[0] is None - # If input size is dynamic or repeats vector is dynamic - if output_sizes[dim] == 0 or cond_dynamic_repeats: - reps = symbolic_helper._size_helper(g, self, dim) - reps = opset11.unsqueeze(g, reps, 0) - - # Check if repeats is dynamic - # As repeats is dynamic, we use a where node as a substitute for the if statement - # If repests_dim = 1, expand repeats otherwise use original tensor - if cond_dynamic_repeats: - repeat_dim = symbolic_helper._size_helper( - g, repeats, g.op("Constant", value_t=torch.LongTensor([0])) - ) - repeat_cond = g.op( - "Equal", repeat_dim, g.op("Constant", value_t=torch.LongTensor([1])) - ) - repeats = where(g, repeat_cond, g.op("Expand", repeats, reps), repeats) - # There are cases when the repeats are 1-d tensor with multiple repeats, but dim - # provided along one of the dynamic axes provided. A simple example would be - # input.shape -> [1, 1, *] where * represents the dynamic axes, and dim = 2 - # Now, repeat interleaving can be performed in pytorch when the value of * matches - # with the number of elements in repeat, for example if * -> 2, number of repeats - # should be 2 as well. - else: - return opset9.repeat_interleave(g, self, repeats, final_dim) - - reps_like = g.op( - "ConstantOfShape", - g.op("Shape", repeats), - value_t=torch.tensor([1], dtype=torch.long), - ) - r_splits = split(g, repeats, reps_like, 0) - i_splits = split(g, self, reps_like, dim) - - output_sizes[dim], input_sizes[dim] = -1, 1 - - # Create a loop to iterate over each value along the dimension - # and perform individual interleaving using the repeats tensor - # Loop is of the following pattern - # input (trip_count, cond) - # int trip_count = ...; - # bool cond = ...; - # for (int i=0; i < trip_count && cond; ++i) { - # cond = ...; - # } - - # Loop conditions - loop_condition = g.op("Constant", value_t=torch.tensor(1)) - loop_condition = g.op("Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL) - loop_len = reps - - # Create an empty sequence to store final expansions - final_splits = g.op("SequenceEmpty") - - # Loop inputs - loop, (loop_context,), _ = jit_utils.add_op_with_blocks( - g, "Loop", loop_len, loop_condition, final_splits, n_blocks=1 - ) - - loop_block = loop_context.block - block_input_iter = utils._add_input_to_block(loop_block) - cond = utils._add_input_to_block(loop_block) # noqa: F841 - final_splits = utils._add_input_to_block(loop_block) - - r_split = loop_context.op("SequenceAt", r_splits, block_input_iter) - i_split = loop_context.op("SequenceAt", i_splits, block_input_iter) - - i_split = opset11.unsqueeze(loop_context, i_split, dim + 1) - r_concat = [ - loop_context.op("Constant", value_t=torch.LongTensor(input_sizes[: dim + 1])), - r_split, - loop_context.op("Constant", value_t=torch.LongTensor(input_sizes[dim + 1 :])), - ] - r_concat = loop_context.op("Concat", *r_concat, axis_i=0) - i_split = opset9.expand(loop_context, i_split, r_concat, None) - i_split = symbolic_helper._reshape_helper( - loop_context, i_split, g.op("Constant", value_t=torch.LongTensor(output_sizes)) - ) - final_splits = loop_context.op("SequenceInsert", final_splits, i_split) - - # Loop outputs - cond_out = loop_context.op( - "Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL - ) - utils._add_output_to_block(loop_block, cond_out) 
- utils._add_output_to_block(loop_block, final_splits) - - loop_out = loop.node().output() - loop_out = g.op("ConcatFromSequence", loop_out, axis_i=dim) - return loop_out - - -@_onnx_symbolic("aten::diagonal") -@symbolic_helper.parse_args("v", "i", "i", "i") -def diagonal(g: jit_utils.GraphContext, self, offset, dim1, dim2): - rank = symbolic_helper._get_tensor_rank(self) - # Replace negative indexing when rank is known - if rank is not None: - dim1 = dim1 if dim1 >= 0 else dim1 + rank - dim2 = dim2 if dim2 >= 0 else dim2 + rank - - dim1_size = opset9.size( - g, self, dim=g.op("Constant", value_t=torch.LongTensor([dim1])) - ) - dim2_size = opset9.size( - g, self, dim=g.op("Constant", value_t=torch.LongTensor([dim2])) - ) - # Create appropriate mask - mask_shape = g.op("Concat", dim1_size, dim2_size, axis_i=0) - mask = opset9.zeros(g, mask_shape, None, None, None) - mask = g.op("EyeLike", mask, k_i=offset) - # dim1 and dim2 appended as a dimension at the end of the shape - - if rank is not None: - axes = list(range(rank)) - axes.remove(dim1) - axes.remove(dim2) - self = g.op("Transpose", self, perm_i=axes + [dim1, dim2]) - else: - return symbolic_helper._unimplemented("diagonal", "unknown input rank") - - # Multiply input and mask to calculate values along diagonal - # The mask consists of one values where diagonal values are to be calculated - # For example: - # [[1.1, 1.2, 1.3], * [[1, 0, 0] = [[1.1, 0, 0], - # [2.1, 2.2, 2.3], [0, 1, 0] [0, 2.2, 0], - # [3.1, 3.2, 3.3]] [0, 0, 1]] [0, 0, 3.3]] - result = g.op("Mul", self, mask) - result = symbolic_helper._reducesum_helper(g, result, axes_i=[-1], keepdims_i=0) - - # Calculate gather indices based on offset and dims - # If offset is greater than zero, set offset to zero as this aids in - # calculation of selection window - offset_op = g.op("Constant", value_t=torch.LongTensor([offset])) - if offset >= 0: - diag_size = g.op( - "Max", - g.op("Min", dim1_size, g.op("Sub", dim2_size, offset_op)), - g.op("Constant", value_t=torch.LongTensor([0])), - ) - offset = 0 - else: - diag_size = g.op( - "Max", - g.op("Min", g.op("Add", dim1_size, offset_op), dim2_size), - g.op("Constant", value_t=torch.LongTensor([0])), - ) - diag_size = g.op("Concat", diag_size, axis_i=0) - - # Calculate which diagonal values to select - # For example, in cases with offsets: - # [[0, 1.1, 0] - # [0, 0, 2.2]] - # we need to select the last two columns, so we create a tensor - # with all columns that are to be selected - # So in this example, it is [1, 2] - select_window_ones_fill = opset9.ones(g, diag_size, 4, None, None) - select_window = g.op( - "CumSum", - select_window_ones_fill, - g.op("Constant", value_t=torch.LongTensor([0])), - ) - select_window = g.op( - "Add", - select_window, - g.op("Constant", value_t=torch.LongTensor([abs(offset) - 1])), - ) - - gather_shape = [ - opset9.size(g, result, dim=g.op("Constant", value_t=torch.LongTensor([axis]))) - for axis in list(range(rank))[:-2] - ] - gather_shape.append(diag_size) - gather_shape = g.op("Concat", *gather_shape, axis_i=0) - gather_indices = opset9.zeros(g, gather_shape, 4, None, None) - - # There might be cases where offset value is greater than number of rows/columns - # and might cause the diagonal to overrun and as a result of this, diag_size would be zero. 
- # For example, if - # offset = 9, dim1_size = 2 (columns), dim2_size = 4 (rows) - # diag_size = max(min(2, (4-9)), 0) = 0, based on calculation above - # Cases with diagonal overrun always result in diag_size = max(0, -ve value) = 0 - # In cases without diagonal overrun, we select the appropriate rows/columns along which we - # are calculating diagonal values. In cases with diagonal overrun, we return a tensor which has - # the dimension of the row/column where overrun occurred as 0-dim, as we are essentially - # returning an empty tensor - overrun_cond = g.op( - "Not", - g.op( - "Equal", - diag_size, - g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64)), - ), - ) - - if_op, (if_context, else_context), _ = jit_utils.add_op_with_blocks( - g, "If", overrun_cond, n_blocks=2 - ) - - gather_indices_if_block = if_context.op("Add", gather_indices, select_window) - gather_indices_if_block = symbolic_helper._unsqueeze_helper( - if_context, gather_indices_if_block, [rank - 1] - ) - final_non_overrun = if_context.op( - "GatherND", result, gather_indices_if_block, batch_dims_i=rank - 2 - ) - final_overrun = opset9.zeros(else_context, gather_shape, 6, None, None) - utils._add_output_to_block(if_context.block, final_non_overrun) - utils._add_output_to_block(else_context.block, final_overrun) - return if_op - - -# Quantized ops - - -@_onnx_symbolic("quantized::linear") -def quantized_linear( - g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point -): - input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) - weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) - q_bias = symbolic_helper.requantize_bias_helper( - g, bias, input_scale, weight_scale, axis - ) - bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) - - output = opset9.linear(g, input, weight, bias) - - return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) - - -@_onnx_symbolic("quantized::linear_relu") -def quantized_linear_relu( - g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point -): - input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) - weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) - q_bias = symbolic_helper.requantize_bias_helper( - g, bias, input_scale, weight_scale, axis - ) - bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) - - output = opset9.linear(g, input, weight, bias) - output = opset9.relu(g, output) - - return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) - - -@_onnx_symbolic("quantized::conv1d_relu") -def quantized_conv1d_relu( - g: jit_utils.GraphContext, - q_input, - q_weight, - bias, - stride, - padding, - dilation, - groups, - op_scale, - op_zero_point, -): - input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) - weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) - q_bias = symbolic_helper.requantize_bias_helper( - g, bias, input_scale, weight_scale, axis - ) - bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) - - output = opset9.conv1d(g, input, weight, bias, stride, padding, dilation, groups) - output = opset9.relu(g, output) - - return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) - - -@_onnx_symbolic("quantized::conv2d_relu") -def quantized_conv2d_relu( - g: jit_utils.GraphContext, - q_input, - q_weight, - bias, - stride, - padding, - dilation, - groups, - op_scale, - op_zero_point, -): - input, input_scale, _, _ = 
symbolic_helper.dequantize_helper(g, q_input) - weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) - q_bias = symbolic_helper.requantize_bias_helper( - g, bias, input_scale, weight_scale, axis - ) - bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) - - output = opset9.conv2d(g, input, weight, bias, stride, padding, dilation, groups) - output = opset9.relu(g, output) - - return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) - - -@_onnx_symbolic("quantized::conv3d_relu") -def quantized_conv3d_relu( - g: jit_utils.GraphContext, - q_input, - q_weight, - bias, - stride, - padding, - dilation, - groups, - op_scale, - op_zero_point, -): - input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) - weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) - q_bias = symbolic_helper.requantize_bias_helper( - g, bias, input_scale, weight_scale, axis - ) - bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) - - output = opset9.conv3d(g, input, weight, bias, stride, padding, dilation, groups) - output = opset9.relu(g, output) - - return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) - - -@_onnx_symbolic("quantized::conv1d") -def quantized_conv1d( - g: jit_utils.GraphContext, - q_input, - q_weight, - bias, - stride, - padding, - dilation, - groups, - op_scale, - op_zero_point, -): - input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) - weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) - q_bias = symbolic_helper.requantize_bias_helper( - g, bias, input_scale, weight_scale, axis - ) - bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) - - output = opset9.conv1d(g, input, weight, bias, stride, padding, dilation, groups) - - return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) - - -@_onnx_symbolic("quantized::conv2d") -def quantized_conv2d( - g: jit_utils.GraphContext, - q_input, - q_weight, - bias, - stride, - padding, - dilation, - groups, - op_scale, - op_zero_point, -): - input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) - weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) - q_bias = symbolic_helper.requantize_bias_helper( - g, bias, input_scale, weight_scale, axis - ) - bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) - - output = opset9.conv2d(g, input, weight, bias, stride, padding, dilation, groups) - - return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) - - -@_onnx_symbolic("quantized::conv3d") -def quantized_conv3d( - g: jit_utils.GraphContext, - q_input, - q_weight, - bias, - stride, - padding, - dilation, - groups, - op_scale, - op_zero_point, -): - input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) - weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) - q_bias = symbolic_helper.requantize_bias_helper( - g, bias, input_scale, weight_scale, axis - ) - bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) - - output = opset9.conv3d(g, input, weight, bias, stride, padding, dilation, groups) - - return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) - - -@_onnx_symbolic("quantized::conv_transpose1d") -def quantized_conv_transpose1d( - g: jit_utils.GraphContext, - q_input, - q_weight, - bias, - stride, - padding, - output_padding, - dilation, - groups, - op_scale, - op_zero_point, -): - input, input_scale, _, _ = 
symbolic_helper.dequantize_helper(g, q_input) - weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) - q_bias = symbolic_helper.requantize_bias_helper( - g, bias, input_scale, weight_scale, axis - ) - bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) - - output = opset9.conv_transpose2d( - g, input, weight, bias, stride, padding, output_padding, groups, dilation - ) - - return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) - - -@_onnx_symbolic("quantized::conv_transpose2d") -def quantized_conv_transpose2d( - g: jit_utils.GraphContext, - q_input, - q_weight, - bias, - stride, - padding, - output_padding, - dilation, - groups, - op_scale, - op_zero_point, -): - input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) - weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) - q_bias = symbolic_helper.requantize_bias_helper( - g, bias, input_scale, weight_scale, axis - ) - bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) - - output = opset9.conv_transpose2d( - g, input, weight, bias, stride, padding, output_padding, groups, dilation - ) - - return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) - - -@_onnx_symbolic("quantized::conv_transpose3d") -def quantized_conv_transpose3d( - g: jit_utils.GraphContext, - q_input, - q_weight, - bias, - stride, - padding, - output_padding, - dilation, - groups, - op_scale, - op_zero_point, -): - input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) - weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) - q_bias = symbolic_helper.requantize_bias_helper( - g, bias, input_scale, weight_scale, axis - ) - bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) - - output = opset9.conv_transpose3d( - g, input, weight, bias, stride, padding, output_padding, groups, dilation - ) - - return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) +from torch.onnx._internal.torchscript_exporter.symbolic_opset13 import * # noqa: F401,F403 diff --git a/torch/onnx/symbolic_opset14.py b/torch/onnx/symbolic_opset14.py index 80743c6a4912..367aa9eb0832 100644 --- a/torch/onnx/symbolic_opset14.py +++ b/torch/onnx/symbolic_opset14.py @@ -1,291 +1,8 @@ -# mypy: allow-untyped-defs -# mypy: disable-error-code=arg-type -"""This file exports ONNX ops for opset 14. +"""Backward compatibility module for torch.onnx.symbolic_opset14.""" -Note [ONNX operators that are added/updated in opset 14] -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -New operators: - HardSwish, Trilu - -Updated operators: - Reshape - Add, Sub, Mul, Div - GRU, LSTM, RNN - BatchNorm, Cumsum, Relu -""" - -# EDITING THIS FILE? READ THIS FIRST! 
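
The `quantized::linear` / `quantized::conv*` symbolics removed above all follow the same lowering: dequantize the inputs, run the float aten op, and re-quantize the result (the real symbolics additionally re-quantize the bias via `requantize_bias_helper`). A minimal eager-mode sketch of that pattern, for illustration only; the scales, zero points, and shapes below are made up and not part of the patch:

```python
# Illustrative sketch of the dequantize -> float op -> requantize pattern
# used by the removed quantized::linear symbolic. Not part of the patch.
import torch
import torch.nn.functional as F

def quantized_linear_reference(q_input, q_weight, bias, out_scale, out_zero_point):
    x = q_input.dequantize()          # DequantizeLinear on the activation
    w = q_weight.dequantize()         # DequantizeLinear on the weight
    y = F.linear(x, w, bias)          # the float aten op (opset9.linear)
    # QuantizeLinear back to the requested output scale / zero point.
    return torch.quantize_per_tensor(y, out_scale, out_zero_point, torch.quint8)

q_x = torch.quantize_per_tensor(torch.randn(2, 4), scale=0.1, zero_point=0, dtype=torch.quint8)
q_w = torch.quantize_per_tensor(torch.randn(3, 4), scale=0.05, zero_point=0, dtype=torch.qint8)
print(quantized_linear_reference(q_x, q_w, torch.randn(3), 0.2, 0))
```
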
-# see Note [Edit Symbolic Files] in README.md from __future__ import annotations -import functools -import torch -from torch.onnx import _constants, _type_utils, symbolic_helper -from torch.onnx._globals import GLOBALS -from torch.onnx._internal import jit_utils, registration +__all__: list[str] = [] - -__all__ = [ - "hardswish", - "tril", - "triu", - "reshape", - "batch_norm", - "quantized_hardswish", - "scaled_dot_product_attention", -] - -_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=14) - - -@_onnx_symbolic("aten::hardswish") -@symbolic_helper.parse_args("v") -def hardswish(g: jit_utils.GraphContext, self): - return g.op("HardSwish", self) - - -@_onnx_symbolic("aten::tril") -def tril(g: jit_utils.GraphContext, self, diagonal, out=None): - return g.op("Trilu", self, diagonal, upper_i=0) - - -@_onnx_symbolic("aten::triu") -def triu(g: jit_utils.GraphContext, self, diagonal, out=None): - return g.op("Trilu", self, diagonal, upper_i=1) - - -@_onnx_symbolic("aten::reshape") -@symbolic_helper.quantized_args(True) -@symbolic_helper.parse_args("v", "v") -def reshape(g: jit_utils.GraphContext, self, shape): - # NOTE: Due to bug in ORT https://github.com/microsoft/onnxruntime/issues/10664 - # Reshape export cannot utilize the new allowzero attribute introduced in opset 14. - return symbolic_helper._reshape_helper(g, self, shape, allowzero=0) - - -@_onnx_symbolic("aten::batch_norm") -@symbolic_helper.parse_args("v", "v", "v", "v", "v", "i", "f", "f", "i") -def batch_norm( - g: jit_utils.GraphContext, - input, - weight, - bias, - running_mean, - running_var, - training, - momentum, - eps, - cudnn_enabled, -): - if ( - torch.is_autocast_enabled() - and not symbolic_helper.args_have_same_dtype( - [input, weight, bias, running_mean, running_var] - ) - and GLOBALS.export_onnx_opset_version < 15 - ): - return symbolic_helper._onnx_opset_unsupported_detailed( - "BatchNormalization", - 14, - 15, - "All input tensors must have the same `dtype`." 
- " Turn off Autocast or export using opset version 15.", - input, - ) - - symbolic_helper.check_training_mode(training, "batch_norm") - weight, bias, running_mean, running_var = symbolic_helper._batchnorm_helper( - g, input, weight, bias, running_mean, running_var - ) - out = g.op( - "BatchNormalization", - input, - weight, - bias, - running_mean, - running_var, - epsilon_f=eps, - momentum_f=1 - momentum, - training_mode_i=0 if not training else 1, - outputs=1 if not training else 3, - ) - if not training: - return out - else: - res, new_running_mean, new_running_var = out - new_running_mean.setType(running_mean.type()) - new_running_var.setType(running_var.type()) - return res - - -@_onnx_symbolic("quantized::hardswish") -def quantized_hardswish(g: jit_utils.GraphContext, x, op_scale, op_zero_point): - x, _, _, _ = symbolic_helper.dequantize_helper(g, x) - - output = hardswish(g, x) - - return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) - - -# Ported from -# https://github.com/microsoft/onnxscript/blob/6b1b81700b4523f31d8c6d3321e5d8ef5d42b764/onnxscript/function_libs/torch_aten/ops/nn.py#L1504 -# aten_scaled_dot_product_attention -# NOTE: Need op.Trilu -@_onnx_symbolic("aten::scaled_dot_product_attention") -@symbolic_helper.parse_args("v", "v", "v", "v", "f", "b", "v", "b") -def scaled_dot_product_attention( - g: jit_utils.GraphContext, - query: torch._C.Value, - key: torch._C.Value, - value: torch._C.Value, - attn_mask: torch._C.Value | None = None, - dropout_p: float = 0.0, - is_causal: bool = False, - scale: torch._C.Value | None = None, - enable_gqa: bool = False, -): - assert (not is_causal) or (is_causal and symbolic_helper._is_none(attn_mask)), ( - "is_causal and attn_mask cannot be set at the same time" - ) - assert not enable_gqa, ( - "conversion of scaled_dot_product_attention not implemented if enable_gqa is True" - ) - - if symbolic_helper._is_none(scale): - scale = _attention_scale(g, query) - - if is_causal: - attn_mask = _causal_attention_mask(g, query, key) - - # Swap the last two axes of key - # NOTE: onnx-script has different logic here, because the attribute perms in - # transpose needs list of ints - key_shape_builtin = symbolic_helper._get_tensor_rank(key) - key_transposed_axes = list(range(key_shape_builtin)) - key_transposed_axes[-1], key_transposed_axes[-2] = ( - key_transposed_axes[-2], - key_transposed_axes[-1], - ) - key_transposed = g.op("Transpose", key, perm_i=key_transposed_axes) - - # https://github.com/pytorch/pytorch/blob/12da0c70378b5be9135c6fda62a9863bce4a4818/aten/src/ATen/native/transformers/attention.cpp#L653 - # Scale q, k before matmul for stability see https://tinyurl.com/sudb9s96 for math - query_scaled = g.op("Mul", query, g.op("Sqrt", scale)) - key_transposed_scaled = g.op("Mul", key_transposed, g.op("Sqrt", scale)) - mul_qk = g.op("MatMul", query_scaled, key_transposed_scaled) - - if symbolic_helper._is_none(attn_mask): - mul_qk_add = mul_qk - attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1) - elif ( - _type_utils.JitScalarType.from_value(attn_mask) - == _type_utils.JitScalarType.BOOL - ): - # Turn the Boolean mask to float: attn_mask.masked_fill(not attn_mask, -float('inf')) - const_zero = g.op("Constant", value_t=torch.tensor([0.0])) - const_neg_inf = g.op("Constant", value_t=torch.tensor([-float("inf")])) - attn_mask = g.op("Where", attn_mask, const_zero, const_neg_inf) - mul_qk_add = g.op("Add", mul_qk, attn_mask) - attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1) - # When using scaled dot product attention 
with a boolean mask, the softmax operation might return NaN values - # due to the presence of -inf in an entire row (padding tokens), resulting in 0/0 (NaN) in the softmax output. - # This is because there's no safe softmax imp in ONNX, so we need to handle NaN values explicitly to match - # the behavior of PyTorch with boolean masks. - attn_weight = g.op("Where", g.op("IsNaN", attn_weight), const_zero, attn_weight) - elif _type_utils.JitScalarType.from_value(attn_mask) in ( - _type_utils.JitScalarType.FLOAT, - _type_utils.JitScalarType.HALF, - _type_utils.JitScalarType.BFLOAT16, - ): - mul_qk_add = g.op("Add", mul_qk, attn_mask) - attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1) - else: - raise ValueError( - f"Unsupported type for attn_mask: {_type_utils.JitScalarType.from_value(attn_mask)}" - ) - - if dropout_p != 0: - attn_weight = g.op( - "Dropout", - attn_weight, - g.op("Constant", value_t=torch.tensor(dropout_p, dtype=torch.float)), - ) - - return g.op("MatMul", attn_weight, value) - - -def _attention_scale( - g: jit_utils.GraphContext, query: torch._C.Value -) -> torch._C.Value: - """Calculate the scale factor for the attention result. - - Args: - query: Tensor of shape [..., L, E] - - Returns: - Scalar scale factor := 1 / math.sqrt(query.size(-1)) - """ - query_shape = g.op("Shape", query) - query_shape_last = g.op( - "Slice", - query_shape, - g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)), - g.op( - "Constant", value_t=torch.tensor([_constants.INT64_MAX], dtype=torch.int64) - ), - ) - embedding_size = g.op( - "Cast", - query_shape_last, - to_i=_type_utils.JitScalarType.from_value(query).onnx_type(), - ) - const_one = g.op("Constant", value_t=torch.tensor([1.0], dtype=torch.float)) - scale = g.op("Div", const_one, g.op("Sqrt", embedding_size)) - # Add a Cast to convert the scale back to original type - scale = g.op( - "Cast", - scale, - to_i=_type_utils.JitScalarType.from_value(query).onnx_type(), - ) - return scale - - -def _causal_attention_mask( - g: jit_utils.GraphContext, query: torch._C.Value, key: torch._C.Value -) -> torch._C.Value: - """Create a causal mask for the given query and key tensors. - - Equivalent to:: - mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) - attn_mask = torch.zeros(L, S, dtype=torch.float) - attn_mask = attn_mask.masked_fill(not mask, -float("inf")) - - Args: - query: Tensor of shape [..., L, E] - key: Tensor of shape [..., S, E] - - Returns: - Tensor of shape [L, S] - """ - - query_shape = g.op("Shape", query) - key_shape = g.op("Shape", key) - - last_idx = g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)) - second_last_idx = g.op("Constant", value_t=torch.tensor([-2], dtype=torch.int64)) - target_length = g.op("Slice", query_shape, second_last_idx, last_idx) - source_length = g.op("Slice", key_shape, second_last_idx, last_idx) - # attn_mask = torch.ones(L, S) := { - size = g.op("Concat", target_length, source_length, axis_i=0) - const_one = g.op("Constant", value_t=torch.tensor([1.0])) - attn_mask = g.op("Expand", const_one, size) - # } - attn_mask = g.op("Trilu", attn_mask, upper_i=0) - # The causal mask has 0s in the lower triangle and -inf in the upper triangle. 
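
The docstring of the removed `_causal_attention_mask` states an eager-mode equivalent; written out runnably (reading its `not mask` as `~mask`, with made-up sizes `L`, `S` purely for illustration):

```python
# Illustrative only: the mask the symbolic above assembles with ONNX
# Shape/Slice/Expand/Trilu/Where ops, reproduced in eager PyTorch.
import torch

def causal_attention_mask_reference(L: int = 4, S: int = 5) -> torch.Tensor:
    mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
    attn_mask = torch.zeros(L, S, dtype=torch.float)
    # 0.0 where attention is allowed, -inf strictly above the diagonal.
    return attn_mask.masked_fill(~mask, float("-inf"))

print(causal_attention_mask_reference())
```
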
- const_zero = g.op("Constant", value_t=torch.tensor([0.0])) - const_neg_inf = g.op("Constant", value_t=torch.tensor([-float("inf")])) - attn_mask = g.op( - "Where", g.op("Equal", attn_mask, const_zero), const_neg_inf, const_zero - ) - return attn_mask +from torch.onnx._internal.torchscript_exporter.symbolic_opset14 import * # noqa: F401,F403 diff --git a/torch/onnx/symbolic_opset15.py b/torch/onnx/symbolic_opset15.py index 08f8dcbf5a22..e04e3b045212 100644 --- a/torch/onnx/symbolic_opset15.py +++ b/torch/onnx/symbolic_opset15.py @@ -1,80 +1,8 @@ -# mypy: allow-untyped-defs -"""This file exports ONNX ops for opset 15. +"""Backward compatibility module for torch.onnx.symbolic_opset15.""" -Note [ONNX operators that are added/updated in opset 15] -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -https://github.com/onnx/onnx/blob/master/docs/Changelog.md#version-15-of-the-default-onnx-operator-set -New operators: - Bernoulli - CastLike - Optional - OptionalGetElement - OptionalHasElement - -Updated operators: - BatchNormalization https://github.com/onnx/onnx/pull/3545 - Backwards compatible - TODO: test coverage for mixed types inputs. - Pow https://github.com/onnx/onnx/pull/3412 - Backwards compatible - TODO: bfloat16 support. - Shape https://github.com/onnx/onnx/pull/3580 - Backwards compatible - TODO: optional start/end attribute. -""" - -# EDITING THIS FILE? READ THIS FIRST! -# see Note [Edit Symbolic Files] in README.md - -import functools - -import torch -from torch import _C -from torch.onnx import symbolic_helper, symbolic_opset9 as opset9 -from torch.onnx._internal import jit_utils, registration +from __future__ import annotations -_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=15) +__all__: list[str] = [] - -@_onnx_symbolic("aten::__is_") -def aten__is_(g: jit_utils.GraphContext, self, other): - if symbolic_helper._is_none(other): - if isinstance(self.type(), _C.OptionalType): - none = g.op("OptionalHasElement", self) - return g.op("Not", none) - else: - return g.op("Constant", value_t=torch.BoolTensor([0])) - return opset9.eq(g, self, other) - - -@_onnx_symbolic("aten::__isnot_") -@opset9.wrap_logical_op_with_negation # type: ignore[has-type] -def aten__isnot_(g: jit_utils.GraphContext, self, other): - return aten__is_(g, self, other) - - -@_onnx_symbolic("aten::bernoulli") -def bernoulli(g: jit_utils.GraphContext, input, p=None, generator=None, out=None): - if out is not None and not symbolic_helper._is_none(out): - symbolic_helper._unimplemented( - "Bernoulli", "out parameter is not supported for bernoulli", input - ) - if generator is not None and not symbolic_helper._is_none(generator): - symbolic_helper._unimplemented( - "Bernoulli", "generator is not supported for bernoulli", input - ) - if p is None or symbolic_helper._is_none(p): - return g.op("Bernoulli", input) - return opset9.bernoulli(g, input, p, generator, out) - - -@_onnx_symbolic("prim::unchecked_cast") -def prim_unchecked_cast(g: jit_utils.GraphContext, self): - # exists to refine the type of the Value - # if x is Optional[Tensor], unchecked_cast will cast - # x to Tensor, so the rest of the graph knows that x is a Tensor. 
- if isinstance(self.type(), _C.OptionalType): - return g.op("OptionalGetElement", self) - - return self +from torch.onnx._internal.torchscript_exporter.symbolic_opset15 import * # noqa: F401,F403 diff --git a/torch/onnx/symbolic_opset16.py b/torch/onnx/symbolic_opset16.py index d4a7baa78c2d..9a248bb0f26c 100644 --- a/torch/onnx/symbolic_opset16.py +++ b/torch/onnx/symbolic_opset16.py @@ -1,185 +1,8 @@ -# mypy: allow-untyped-defs -"""This file exports ONNX ops for opset 16. +"""Backward compatibility module for torch.onnx.symbolic_opset16.""" -Note [ONNX Operators that are added/updated in opset 16] - -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-16-of-the-default-onnx-operator-set -New operators: - GridSample https://github.com/onnx/onnx/pull/3557 - -Updated operators: - Identity - If - LeakyRelu - Loop - PRelu - RoiAlign - Scan - ScatterElements - ScatterND - Where - GreaterOrEqual - LessOrEqual -""" - -# EDITING THIS FILE? READ THIS FIRST! -# see Note [Edit Symbolic Files] in README.md - -import functools - -import torch -from torch.nn.functional import ( - GRID_SAMPLE_INTERPOLATION_MODES, - GRID_SAMPLE_PADDING_MODES, -) -from torch.onnx import _type_utils, errors, symbolic_helper, utils -from torch.onnx._internal import jit_utils, registration +from __future__ import annotations -_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=16) +__all__: list[str] = [] - -# note (mkozuki): Why `grid_sampler` instead of `grid_sample`? -# Because `torch.nn.functional.grid_sample` calls `torch.grid_sampler`. -@_onnx_symbolic("aten::grid_sampler") -@symbolic_helper.parse_args("v", "v", "i", "i", "b") -def grid_sampler( - g: jit_utils.GraphContext, - input, - grid, - mode_enum, - padding_mode_enum, - align_corners, -): - # Check the input and grid tensor rank beforehand. - if symbolic_helper._get_tensor_rank(input) == 5: - return symbolic_helper._onnx_unsupported("GridSample with 5D volumetric input") - mode_s = {v: k for k, v in GRID_SAMPLE_INTERPOLATION_MODES.items()}[mode_enum] # type: ignore[call-arg] - padding_mode_s = {v: k for k, v in GRID_SAMPLE_PADDING_MODES.items()}[ # type: ignore[call-arg] - padding_mode_enum - ] - return g.op( - "GridSample", - input, - grid, - align_corners_i=int(align_corners), - mode_s=mode_s, - padding_mode_s=padding_mode_s, - ) - - -@_onnx_symbolic("aten::scatter_add") -@symbolic_helper.parse_args("v", "i", "v", "v") -def scatter_add(g: jit_utils.GraphContext, self, dim, index, src): - src_type = _type_utils.JitScalarType.from_value( - src, _type_utils.JitScalarType.UNDEFINED - ) - src_sizes = symbolic_helper._get_tensor_sizes(src) - index_sizes = symbolic_helper._get_tensor_sizes(index) - - if len(src_sizes) != len(index_sizes): - return symbolic_helper._unimplemented( - "scatter_add", - f"`index` ({index_sizes}) should have the same dimensionality as `src` ({src_sizes})", - ) - - # PyTorch only allows index shape <= src shape, so we can only consider - # taking index as subset size to src, like PyTorch does. When sizes for src - # and index are not matched or there are dynamic axes, we take index shape to - # slice src to accommodate. 
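
For the removed opset 16 `scatter_add` symbolic, ONNX `ScatterElements` with `reduction="add"` is expected to reproduce the eager semantics of `aten::scatter_add`; a small made-up example of those semantics, for illustration only:

```python
# Illustrative only: the eager behavior that ScatterElements(reduction="add")
# should match for the exported aten::scatter_add.
import torch

self_ = torch.zeros(3, 4)
index = torch.tensor([[0, 1, 2, 0]])
src = torch.ones(1, 4)

# out[index[i][j]][j] += src[i][j] for dim=0
print(self_.scatter_add(0, index, src))
```
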
- if src_sizes != index_sizes or None in index_sizes: - adjusted_shape = g.op("Shape", index) - starts = g.op("Constant", value_t=torch.tensor([0] * len(index_sizes))) - src = g.op("Slice", src, starts, adjusted_shape) - - src = symbolic_helper._maybe_get_scalar(src) - if symbolic_helper._is_value(src): - return g.op("ScatterElements", self, index, src, axis_i=dim, reduction_s="add") - else: - # Check if scalar "src" has same type as self (PyTorch allows different - # type for scalar src (but not when src is tensor)). If not, insert Cast node. - if _type_utils.JitScalarType.from_value(self) != src_type: - src = g.op( - "Cast", - src, - to_i=_type_utils.JitScalarType.from_value(self).onnx_type(), - ) - - return g.op( - "ScatterElements", - self, - index, - src, - axis_i=dim, - reduction_s="add", - ) - - -@_onnx_symbolic("aten::scatter_reduce") -@symbolic_helper.parse_args("v", "i", "v", "v", "s", "b") -def scatter_reduce( - g: jit_utils.GraphContext, - self: torch._C.Value, - dim: int, - index: torch._C.Value, - src: torch._C.Value, - reduce: str, - include_self: bool, -): - if reduce == "mean": - raise errors.OnnxExporterError( - "ONNX does not support mean reduction for scatter_reduce" - ) - if not include_self: - raise errors.OnnxExporterError( - "ONNX does not support include_self=False for scatter_reduce" - ) - - reduce_mode = { # convert torch string name to onnx string name - "mean": "none", # 'mean' doesn't support in ONNX 1.14 definition - "sum": "add", - "prod": "mul", - "amin": "min", - "amax": "max", - } - onnx_reduce = reduce_mode[reduce] - - self_rank = g.op("Size", g.op("Shape", self)) - - # if self_rank == 0: # assert (index_rank == 0 and rank_src == 0) - self_rank_is_zero = g.op( - "Equal", self_rank, g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64)) - ) - if_op, (if_context, else_context), _ = jit_utils.add_op_with_blocks( - g, "If", self_rank_is_zero, n_blocks=2, outputs=3 - ) - neg_1 = if_context.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)) - - self_reshape = if_context.op("Reshape", self, neg_1) - utils._add_output_to_block(if_context.block, self_reshape) - index_reshape = if_context.op("Reshape", index, neg_1) - utils._add_output_to_block(if_context.block, index_reshape) - src_reshape = if_context.op("Reshape", src, neg_1) - utils._add_output_to_block(if_context.block, src_reshape) - - self_identity = else_context.op("Identity", self) - utils._add_output_to_block(else_context.block, self_identity) - index_identitye = else_context.op("Identity", index) - utils._add_output_to_block(else_context.block, index_identitye) - src_identity = else_context.op("Identity", src) - utils._add_output_to_block(else_context.block, src_identity) - - result = g.op("ScatterElements", *if_op, axis_i=dim, reduction_s=onnx_reduce) - - # if self_rank == 0: - if_op, (if_context, else_context), _ = jit_utils.add_op_with_blocks( - g, "If", self_rank_is_zero, n_blocks=2, outputs=1 - ) - result_squeezed = if_context.op("Squeeze", result) - utils._add_output_to_block(if_context.block, result_squeezed) - result_identity = else_context.op("Identity", result) - utils._add_output_to_block(else_context.block, result_identity) - result_final = if_op.node().output() - - return result_final +from torch.onnx._internal.torchscript_exporter.symbolic_opset16 import * # noqa: F401,F403 diff --git a/torch/onnx/symbolic_opset17.py b/torch/onnx/symbolic_opset17.py index bcf80058fe2a..800acd446b5d 100644 --- a/torch/onnx/symbolic_opset17.py +++ b/torch/onnx/symbolic_opset17.py @@ -1,239 
+1,8 @@ -# mypy: allow-untyped-defs -# mypy: disable-error-code=arg-type -"""This file exports ONNX ops for opset 17. +"""Backward compatibility module for torch.onnx.symbolic_opset17.""" -Note [ONNX Operators that are added/updated in opset 17] - -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-17-of-the-default-onnx-operator-set -New operators: - BlackmanWindow - DFT - HammingWindow - HannWindow - LayerNormalization - MelWeightMatrix - STFT - SequenceMap -""" - -import functools -from collections.abc import Sequence -from typing import Optional - -import torch -from torch import _C -from torch.onnx import _type_utils, errors, symbolic_helper -from torch.onnx._internal import jit_utils, registration +from __future__ import annotations -# EDITING THIS FILE? READ THIS FIRST! -# see Note [Edit Symbolic Files] in README.md +__all__: list[str] = [] -__all__ = ["layer_norm", "stft", "quantized_layer_norm"] - -_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=17) - - -@_onnx_symbolic("aten::layer_norm") -@symbolic_helper.parse_args("v", "is", "v", "v", "f", "none") -def layer_norm( - g: jit_utils.GraphContext, - input: _C.Value, - normalized_shape: Sequence[int], - weight: _C.Value, - bias: _C.Value, - eps: float, - cudnn_enable: bool, -): - # normalized_shape: input shape from an expected input of size - # axis: The first normalization dimension. - # layer_norm normalizes on the last D dimensions, - # where D is the size of normalized_shape - axis = -len(normalized_shape) - scalar_type = _type_utils.JitScalarType.from_value( - input, _type_utils.JitScalarType.FLOAT - ) - dtype = scalar_type.dtype() - if symbolic_helper._is_none(weight): - weight_value = torch.ones(normalized_shape, dtype=dtype) - weight = g.op("Constant", value_t=weight_value) - if symbolic_helper._is_none(bias): - bias_value = torch.zeros(normalized_shape, dtype=dtype) - bias = g.op("Constant", value_t=bias_value) - return g.op( - "LayerNormalization", - input, - weight, - bias, - epsilon_f=eps, - axis_i=axis, - ) - - -@_onnx_symbolic("quantized::layer_norm") -def quantized_layer_norm( - g: jit_utils.GraphContext, - x, - normalized_shape, - weight, - bias, - eps, - op_scale, - op_zero_point, -): - x, _, _, _ = symbolic_helper.dequantize_helper(g, x) - - output = layer_norm(g, x, normalized_shape, weight, bias, eps, False) - - return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) - - -def _compute_edge_sizes(n_fft, window_size): - """Helper function to compute the sizes of the edges (left and right) - of a given window centered within an FFT size.""" - left = (n_fft - window_size) // 2 - right = n_fft - left - window_size - return left, right - - -@_onnx_symbolic("aten::stft") -@symbolic_helper.parse_args("v", "i", "i", "i", "v", "b", "b", "b", "b") -def stft( - g: jit_utils.GraphContext, - input: _C.Value, - n_fft: int, - hop_length: Optional[int] = None, - win_length: Optional[int] = None, - window: Optional[_C.Value] = None, - normalized: bool = False, - onesided: Optional[bool] = True, - return_complex: Optional[bool] = False, - align_to_window: Optional[bool] = None, -) -> _C.Value: - """Associates `torch.stft` with the `STFT` ONNX operator. - Note that torch.stft calls _VF.stft, without centering or padding options. - Hence, this function does not contain these two arguments. - See torch.stft source code for more info. 
- - Args: - g: Graph to write the ONNX representation into - input: Input tensor for the transformation - n_fft: FFT size - hop_length: Size of the hop. Defaults to `floot(n_fft // 4)` - win_length: Size of the analysis window. Defaults to `n_fft` - window: Analysis window. Defaults to a window of all ones - normalized: Whether to return a normalized STFT - onesided: Whether to return only half (+1) of the results, given the - symmetry of the STFT - return_complex: Whether to return the complex value (Note: Must be - `False` or `None`) - - Returns: - op: Operator for torch.stft associated with STFT (ONNX) - """ - # Checks - if return_complex: - raise errors.SymbolicValueError( - msg="STFT does not currently support complex types", value=input - ) - - if align_to_window is not None: - raise errors.SymbolicValueError( - msg="STFT does not currently support the align_to_window option", - value=input, - ) # TODO(#145944): add compatibility with align_to_window option. - - # Get STFT sizes - frame_step_value = hop_length if hop_length is not None else n_fft // 4 - frame_step_const = g.op( - "Constant", value_t=torch.tensor(frame_step_value, dtype=torch.int64) - ) - frame_length_const = g.op( - "Constant", value_t=torch.tensor(n_fft, dtype=torch.int64) - ) - - # Pre-process input if needed - signal = input - signal_rank = symbolic_helper._get_tensor_rank(signal) - if signal_rank == 1: - # Add batch dimension - signal = g.op( - "Unsqueeze", - signal, - g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)), - ) - elif signal_rank is None or signal_rank > 2: - raise errors.SymbolicValueError( - msg="STFT can only take inputs of 1 [signal] or 2 [batch, signal] dimensions. " - f"Current rank of signal is {signal_rank}, please reduce it.", - value=input, - ) - - # Get window and make sure it's the same size as `win_length` or `n_fft` - n_win = symbolic_helper._get_tensor_dim_size(window, dim=0) - if n_win is not None: - win_length_default = win_length if win_length else n_fft - assert n_win == win_length_default, ( - "Analysis window size must equal `win_length` or `n_fft`. " - f"Please, set `win_length` or `n_fft` to match `window` size ({n_win})", - ) - - # Center window around zeros if needed (required by ONNX's STFT) - if n_win < n_fft: - left, right = _compute_edge_sizes(n_fft, n_win) - left_win = g.op("Constant", value_t=torch.zeros(left)) - right_win = g.op("Constant", value_t=torch.zeros(right)) - window = g.op("Concat", left_win, window, right_win, axis_i=0) - - # Create window, if needed - if symbolic_helper._is_none(window): - if win_length: - if win_length > n_fft: - raise errors.SymbolicValueError( - msg="The analysis window can't be longer than the size of the FFT. 
" - f"Please set `win_length` ({win_length}) to `n_fft` ({n_fft}) or less.", - value=input, - ) - - # Center window, if needed - left, right = _compute_edge_sizes(n_fft, win_length) - torch_window = torch.hstack( - (torch.zeros(left), torch.ones(win_length), torch.zeros(right)) - ) - else: - # Rectangle window - torch_window = torch.ones(n_fft) - assert torch_window.shape[0] == n_fft - window = g.op("Constant", value_t=torch_window) - window = g.op( - "Cast", window, to_i=_type_utils.JitScalarType.from_value(signal).onnx_type() - ) - - # Run STFT - result = g.op( - "STFT", - signal, - frame_step_const, - window, - frame_length_const, - onesided_i=1 if onesided is None or onesided else 0, - ) - - # Transpose to mimic torch.stft's behavior - result = g.op("Transpose", result, perm_i=[0, 2, 1, 3]) - - # Remove batch dimension, if needed - if signal_rank == 1: - result = g.op( - "Squeeze", - result, - g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)), - ) - - # Normalize, if needed - if normalized: - sqrt_nfft = torch.sqrt(torch.tensor(n_fft, dtype=signal.type().dtype())) - result = g.op("Div", result, g.op("Constant", value_t=sqrt_nfft)) - - return result +from torch.onnx._internal.torchscript_exporter.symbolic_opset17 import * # noqa: F401,F403 diff --git a/torch/onnx/symbolic_opset18.py b/torch/onnx/symbolic_opset18.py index 76f5d4df6ec2..cc07a60f018d 100644 --- a/torch/onnx/symbolic_opset18.py +++ b/torch/onnx/symbolic_opset18.py @@ -1,265 +1,8 @@ -# mypy: allow-untyped-defs -"""This file exports ONNX ops for opset 18. +"""Backward compatibility module for torch.onnx.symbolic_opset18.""" -Note [ONNX Operators that are added/updated in opset 18] - -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-18-of-the-default-onnx-operator-set -New operators: - BitwiseAnd - CenterCropPad - Col2Im - Mish - OptionalGetElement - OptionalHasElement - Pad - Resize - ScatterElements - ScatterND - Split -""" - -import functools -from collections.abc import Sequence -from typing import Optional - -import torch -from torch import _C -from torch.onnx import _type_utils, symbolic_helper, symbolic_opset9 as opset9 -from torch.onnx._internal import jit_utils, registration +from __future__ import annotations -# EDITING THIS FILE? READ THIS FIRST! 
-# see Note [Edit Symbolic Files] in symbolic_helper.py +__all__: list[str] = [] -__all__ = [ - "col2im", -] - -_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=18) - - -@_onnx_symbolic("aten::__and_") -@_onnx_symbolic("aten::bitwise_and") -def __and_(g: jit_utils.GraphContext, self, other): - # do type promotion (scalars don't seem to apply) - args = [self, other] - # type promotion doesn't happen with torch.bitwise_and(tensor, scalar) - prom_args = [arg for arg in args if symbolic_helper._get_tensor_rank(arg)] - if len(prom_args) == 0: - prom_args = args - promotion_jit_type = symbolic_helper._type_promote_from_values(*prom_args) - self = symbolic_helper._maybe_cast_to_type(g, self, promotion_jit_type) - other = symbolic_helper._maybe_cast_to_type(g, other, promotion_jit_type) - if promotion_jit_type == _type_utils.JitScalarType.BOOL: - return g.op("And", self, other) - return g.op("BitwiseAnd", self, other) - - -@_onnx_symbolic("aten::col2im") -@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is") -def col2im( - g, - input: _C.Value, - output_size: _C.Value, - kernel_size: _C.Value, - dilation: Sequence[int], - padding: Sequence[int], - stride: Sequence[int], -): - # convert [i0, i1, ..., in] into [i0, i0, i1, i1, ..., in, in] - adjusted_padding: list[int] = [] - for pad in padding: - adjusted_padding.extend(pad for _ in range(2)) - - num_dimensional_axis = symbolic_helper._get_tensor_sizes(output_size)[0] - if not adjusted_padding: - adjusted_padding = [0, 0] * num_dimensional_axis - - if not dilation: - dilation = [1] * num_dimensional_axis - - if not stride: - stride = [1] * num_dimensional_axis - - return g.op( - "Col2Im", - input, - output_size, - kernel_size, - dilations_i=dilation, - pads_i=adjusted_padding, - strides_i=stride, - ) - - -@_onnx_symbolic( - "aten::mean", decorate=[symbolic_helper._apply_params("ReduceMean", "mean")] -) -@_onnx_symbolic( - "aten::prod", - decorate=[ - symbolic_helper._apply_params( - "ReduceProd", "prod", allow_multi_dim_support=False - ) - ], -) -def _reduce_with_dtype(onnx_op: str, name: str, allow_multi_dim_support: bool = True): - return symbolic_helper._reduce_with_dtype_helper( - onnx_op, name, allow_multi_dim_support - ) - - -@_onnx_symbolic("aten::native_layer_norm") -@symbolic_helper.quantized_args(True, False, False, False) -@symbolic_helper.parse_args("v", "is", "v", "v", "f") -def _native_layer_norm( - g: jit_utils.GraphContext, - input: _C.Value, - normalized_shape: Sequence[int], - weight: _C.Value, - bias: _C.Value, - eps: float, -) -> tuple[_C.Value, _C.Value, _C.Value]: - return opset9.native_layer_norm(g, input, normalized_shape, weight, bias, eps) - - -@_onnx_symbolic("aten::glu") -@symbolic_helper.parse_args("v", "i") -def _glu(g: jit_utils.GraphContext, input, dim): - dim_size = symbolic_helper._get_tensor_dim_size(input, dim) - if dim_size is not None: - assert dim_size % 2 == 0 - - first, second = g.op("Split", input, axis_i=dim, num_outputs_i=2, outputs=2) - return g.op("Mul", first, g.op("Sigmoid", second)) - - -@_onnx_symbolic("aten::max") -# torch.max (same for torch.min) actually has two interfaces smashed together: -# torch.max(x, dim, keepdim) and torch.max(x, y) -# TODO(justinchuby): Support multiple quantized args in output -def max(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None): - return symbolic_helper._max_helper(g, self, dim_or_y, keepdim) - - -@_onnx_symbolic("aten::maximum") -@symbolic_helper.quantized_args(True, True) -def maximum(g: jit_utils.GraphContext, input, 
other): - return max(g, input, dim_or_y=other) - - -@_onnx_symbolic("aten::min") -# TODO(justinchuby): Support multiple quantized args in output -def min(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None): - return symbolic_helper._min_helper(g, self, dim_or_y, keepdim) - - -@_onnx_symbolic("aten::minimum") -@symbolic_helper.quantized_args(True, True) -def minimum(g: jit_utils.GraphContext, input, other): - return min(g, input, dim_or_y=other) - - -@_onnx_symbolic("aten::amax") -@symbolic_helper.quantized_args(True) -@symbolic_helper.parse_args("v", "is", "i") -def amax(g: jit_utils.GraphContext, self, dim, keepdim): - axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long)) - return g.op("ReduceMax", self, axes, keepdims_i=keepdim) - - -@_onnx_symbolic("aten::amin") -@symbolic_helper.quantized_args(True) -@symbolic_helper.parse_args("v", "is", "i") -def amin(g: jit_utils.GraphContext, self, dim, keepdim): - axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long)) - return g.op("ReduceMin", self, axes, keepdims_i=keepdim) - - -@_onnx_symbolic("aten::aminmax") -@symbolic_helper.quantized_args(True) -@symbolic_helper.parse_args("v", "v", "i") -def aminmax(g: jit_utils.GraphContext, self, dim, keepdim): - if not symbolic_helper._is_none(dim): - dim = symbolic_helper._get_const(dim, "i", "dim") - axes = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long)) - return g.op("ReduceMin", self, axes, keepdims_i=keepdim), g.op( - "ReduceMax", self, axes, keepdims_i=keepdim - ) - else: - return g.op("ReduceMin", self, keepdims_i=keepdim), g.op( - "ReduceMax", self, keepdims_i=keepdim - ) - - -@_onnx_symbolic("aten::var_mean") -def _var_mean(g: jit_utils.GraphContext, input, *args): - if len(args) == 1: - return symbolic_helper._var_mean_helper(g, input, None, args[0], None) - else: - return symbolic_helper._var_mean_helper(g, input, *args) - - -@_onnx_symbolic("aten::logsumexp") -@symbolic_helper.parse_args("v", "is", "i") -def _logsumexp(g: jit_utils.GraphContext, input, dim, keepdim): - if dim is None: - return g.op("ReduceLogSumExp", input, keepdims_i=0) - else: - axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long)) - return g.op("ReduceLogSumExp", input, axes, keepdims_i=keepdim) - - -@_onnx_symbolic("aten::linalg_matrix_norm") -@symbolic_helper.parse_args("v", "v", "is", "b", "v") -def _linalg_matrix_norm( - g: jit_utils.GraphContext, - self: torch._C.Value, - ord: torch._C.Value, - dim: list[int], - keepdim: bool, - dtype: torch._C.Value, -): - return opset9.linalg_matrix_norm(g, self, ord, dim, keepdim, dtype) - - -@_onnx_symbolic("aten::embedding_bag") -@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i") -def embedding_bag( - g: jit_utils.GraphContext, - embedding_matrix, - indices, - offsets, - scale_grad_by_freq, - mode, - sparse, - per_sample_weights, - include_last_offset, - padding_idx, -): - return symbolic_helper._embedding_bag_helper( - g, - embedding_matrix, - indices, - offsets, - scale_grad_by_freq, - mode, - sparse, - per_sample_weights, - include_last_offset, - padding_idx, - ) - - -@_onnx_symbolic("aten::linalg_vector_norm") -@symbolic_helper.parse_args("v", "f", "is", "b", "v") -def linalg_vector_norm( - g: jit_utils.GraphContext, - self: torch._C.Value, - ord: float, - dim: Optional[Sequence[int]], - keepdim: bool, - dtype: torch._C.Value, -): - return symbolic_helper._linalg_vector_norm_helper(g, self, ord, dim, keepdim, dtype) +from torch.onnx._internal.torchscript_exporter.symbolic_opset18 
import * # noqa: F401,F403 diff --git a/torch/onnx/symbolic_opset19.py b/torch/onnx/symbolic_opset19.py index 781bc2d200c7..4f7a54fc1dd3 100644 --- a/torch/onnx/symbolic_opset19.py +++ b/torch/onnx/symbolic_opset19.py @@ -1,31 +1,8 @@ -"""This file exports ONNX ops for opset 19. +"""Backward compatibility module for torch.onnx.symbolic_opset19.""" -Note [ONNX Operators that are added/updated in opset 19] +from __future__ import annotations -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-19-of-the-default-onnx-operator-set -New operators: -AveragePool -Cast -CastLike -Constant -DeformConv -DequantizeLinear -Equal -Identity -If -Loop -Pad -QuantizeLinear -Reshape -Resize -Scan -Shape -Size -""" - - -# EDITING THIS FILE? READ THIS FIRST! -# see Note [Edit Symbolic Files] in symbolic_helper.py __all__: list[str] = [] + +from torch.onnx._internal.torchscript_exporter.symbolic_opset19 import * # noqa: F401,F403 diff --git a/torch/onnx/symbolic_opset20.py b/torch/onnx/symbolic_opset20.py index d96f770ca11e..56635a781161 100644 --- a/torch/onnx/symbolic_opset20.py +++ b/torch/onnx/symbolic_opset20.py @@ -1,92 +1,8 @@ -# mypy: allow-untyped-defs -"""This file exports ONNX ops for opset 20. +"""Backward compatibility module for torch.onnx.symbolic_opset20.""" -Note [ONNX Operators that are added/updated in opset 20] - -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-20-of-the-default-onnx-operator-set -New operators: - AffineGrid - ConstantOfShape - DFT - Gelu - GridSample - ImageDecoder - IsInf - IsNaN - ReduceMax - ReduceMin - RegexFullMatch - StringConcat - StringSplit -""" - -import functools - -import torch.nn.functional as F -from torch import _C -from torch.onnx import symbolic_helper -from torch.onnx._internal import jit_utils, registration +from __future__ import annotations -# EDITING THIS FILE? READ THIS FIRST! 
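
Each public `torch.onnx.symbolic_opset*` module is reduced to the same shim: an empty `__all__` plus a star import from `torch.onnx._internal.torchscript_exporter`. Assuming a build that contains this refactor and that the relocated modules keep their public names (e.g. `col2im` for opset 18), old import paths should keep resolving; a hypothetical check:

```python
# Hypothetical check, assuming the relocated opset 18 module still exports col2im.
import torch.onnx.symbolic_opset18 as compat
from torch.onnx._internal.torchscript_exporter import symbolic_opset18 as impl

# The compat module is only a star-import shim, so the old attribute path
# resolves to the relocated implementation.
assert compat.col2im is impl.col2im
```
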
-# see Note [Edit Symbolic Files] in symbolic_helper.py +__all__: list[str] = [] -__all__ = ["_grid_sampler", "_affine_grid_generator", "gelu"] - - -def convert_grid_sample_mode(mode_s): - return ( - "linear" if mode_s == "bilinear" else "cubic" if mode_s == "bicubic" else mode_s - ) - - -_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=20) - - -@_onnx_symbolic("aten::grid_sampler") -@symbolic_helper.parse_args("v", "v", "i", "i", "b") -def _grid_sampler( - g: jit_utils.GraphContext, - input: _C.Value, - grid: _C.Value, - mode_enum: int, - padding_mode_enum: int, - align_corners: bool, -): - mode_s = {v: k for k, v in F.GRID_SAMPLE_INTERPOLATION_MODES.items()}[mode_enum] # type: ignore[call-arg, index] - # mode string changes at https://onnx.ai/onnx/operators/text_diff_GridSample_16_20.html - mode_s = convert_grid_sample_mode(mode_s) - padding_mode_s = {v: k for k, v in F.GRID_SAMPLE_PADDING_MODES.items()}[ # type: ignore[call-arg, index] - padding_mode_enum # type: ignore[index] - ] - return g.op( - "GridSample", - input, - grid, - align_corners_i=int(align_corners), - mode_s=mode_s, - padding_mode_s=padding_mode_s, - ) - - -@_onnx_symbolic("aten::affine_grid_generator") -@symbolic_helper.parse_args("v", "v", "b") -def _affine_grid_generator( - g: jit_utils.GraphContext, - theta: _C.Value, - size: _C.Value, - align_corners: bool, -): - return g.op( - "AffineGrid", - theta, - size, - align_corners_i=int(align_corners), - ) - - -@_onnx_symbolic("aten::gelu") -@symbolic_helper.parse_args("v", "s") -def gelu(g: jit_utils.GraphContext, self: _C.Value, approximate: str = "none"): - return g.op("Gelu", self, approximate_s=approximate) +from torch.onnx._internal.torchscript_exporter.symbolic_opset20 import * # noqa: F401,F403 diff --git a/torch/onnx/symbolic_opset7.py b/torch/onnx/symbolic_opset7.py index c647ead4e297..c11e769677ec 100644 --- a/torch/onnx/symbolic_opset7.py +++ b/torch/onnx/symbolic_opset7.py @@ -1,67 +1,8 @@ -# mypy: allow-untyped-defs -""" -Note [ONNX operators that are added/updated from opset 7 to opset 8] -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -New operators: - Expand +"""Backward compatibility module for torch.onnx.symbolic_opset7.""" -Updated operators: - Min, Max, Sum, Mean: supports multidirectional broadcasting. - MaxPool: added optional indices output. - Scan -""" - -import functools -import warnings - -from torch.onnx import symbolic_helper, symbolic_opset9 as opset9 -from torch.onnx._internal import jit_utils, registration +from __future__ import annotations -_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=7) +__all__: list[str] = [] -block_listed_operators = ( - "scan", - "expand", - "expand_as", - "meshgrid", - "adaptive_max_pool1d", - "adaptive_max_pool2d", - "adaptive_max_pool3d", - "max_pool1d_with_indices", - "max_pool2d_with_indices", - "max_pool3d_with_indices", -) - - -# NOTE: max, min, sum, mean: broadcasting is not supported in opset 7. -# torch.max (same for torch.min) actually has two interfaces smashed together: -# torch.max(x, dim, keepdim) and torch.max(x, y) -@_onnx_symbolic("aten::max") -def max(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None): - # torch.max(input, other) - if keepdim is None and dim_or_y is not None: - warnings.warn( - "Multidirectional broadcasting is not supported in opset 7. 
" - "This might cause the onnx model to be incorrect, if inputs to max operators " - "have different shapes" - ) - return opset9.max(g, self, dim_or_y, keepdim) - - -@_onnx_symbolic("aten::min") -def min(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None): - # torch.min(input, other) - if keepdim is None and dim_or_y is not None: - warnings.warn( - "Multidirectional broadcasting is not supported in opset 7. " - "This might cause the onnx model to be incorrect, if inputs to min operators " - "have different shapes" - ) - return opset9.min(g, self, dim_or_y, keepdim) - - -for block_listed_op in block_listed_operators: - _onnx_symbolic(f"aten::{block_listed_op}")( - symbolic_helper._block_list_in_opset(block_listed_op) - ) +from torch.onnx._internal.torchscript_exporter.symbolic_opset7 import * # noqa: F401,F403 diff --git a/torch/onnx/symbolic_opset8.py b/torch/onnx/symbolic_opset8.py index 41abf46be2a0..0e4411649f3e 100644 --- a/torch/onnx/symbolic_opset8.py +++ b/torch/onnx/symbolic_opset8.py @@ -1,463 +1,8 @@ -# mypy: allow-untyped-defs -""" -Note [ONNX operators that are added/updated from opset 8 to opset 9] -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -New operators: - Compress - ConstantOfShape - EyeLike - MaxUnpool - OneHot - Sinh - Cosh - Asinh - Acosh - Atanh - Shrink - IsNaN - Sign - Erf - Scatter - Where - NonZero - TfIdfVectorizer - MeanVarianceNormalization +"""Backward compatibility module for torch.onnx.symbolic_opset8.""" -Updated operators: - BatchNormalization: removed spatial attribute. - Greater, Less, Constant, MatMul, PRelu, Gemm, Flatten: more data types{integers} supported. - Cast: more data types{string} supported. - Upsample: moved scales from attribute to input. - Scan -""" - -import functools -import warnings - -import torch -from torch._C import _onnx as _C_onnx -from torch.onnx import _type_utils, errors, symbolic_helper, symbolic_opset9 as opset9 -from torch.onnx._internal import jit_utils, registration +from __future__ import annotations -_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=8) +__all__: list[str] = [] -block_listed_operators = ( - "nonzero", - "where", - "scatter", - "scatter_add", - "erf", - "sign", - "isnan", - "gather", - "arange", - "masked_fill", - "index_fill", - "index_copy", - "repeat_interleave", - "any", - "all", -) - -for block_listed_op in block_listed_operators: - _onnx_symbolic(f"aten::{block_listed_op}")( - symbolic_helper._block_list_in_opset(block_listed_op) - ) - - -@_onnx_symbolic( - "aten::upsample_nearest1d", - decorate=[symbolic_helper._apply_params("upsample_nearest1d", 3, "nearest")], -) -@_onnx_symbolic( - "aten::upsample_nearest2d", - decorate=[symbolic_helper._apply_params("upsample_nearest2d", 4, "nearest")], -) -@_onnx_symbolic( - "aten::upsample_nearest3d", - decorate=[symbolic_helper._apply_params("upsample_nearest3d", 5, "nearest")], -) -@_onnx_symbolic( - "aten::upsample_linear1d", - decorate=[symbolic_helper._apply_params("upsample_linear1d", 3, "linear")], -) -@_onnx_symbolic( - "aten::upsample_bilinear2d", - decorate=[symbolic_helper._apply_params("upsample_bilinear2d", 4, "linear")], -) -@_onnx_symbolic( - "aten::upsample_trilinear3d", - decorate=[symbolic_helper._apply_params("upsample_trilinear3d", 5, "linear")], -) -def _interpolate(name, dim, interpolate_mode): - def symbolic_fn(g, input, output_size, *args): - scales, align_corners = symbolic_helper._get_interpolate_attributes( - g, interpolate_mode, args - ) - 
symbolic_helper._interpolate_warning(interpolate_mode) - align_corners = symbolic_helper._maybe_get_scalar(align_corners) - if align_corners: - return symbolic_helper._unimplemented(name, "align_corners == True", input) - output_size = symbolic_helper._maybe_get_const(output_size, "is") - if symbolic_helper._is_value(output_size): - return symbolic_helper._unimplemented( - name, "torch._C.Value (output_size) indexing" - ) - if scales is None: - scales = [ - 1.0 - if i < 2 - else float(output_size[-(dim - i)]) - / float(input.type().sizes()[-(dim - i)]) - for i in range(0, dim) - ] - return g.op("Upsample", input, mode_s=interpolate_mode, scales_f=scales) - - return symbolic_fn - - -@_onnx_symbolic("aten::__interpolate") -def __interpolate( - g: jit_utils.GraphContext, - input, - size, - scale_factor, - mode, - align_corners, - recompute_scale_factor, - antialias, -): - align_corners = symbolic_helper._maybe_get_const(align_corners, "b") - if not symbolic_helper._is_none(align_corners) and align_corners: - return symbolic_helper._unimplemented("interpolate", "align_corners == True") - - if not symbolic_helper._is_none(scale_factor) and symbolic_helper._is_value( - scale_factor - ): - return symbolic_helper._unimplemented( - "interpolate", "dynamic scales in opset 8" - ) - - if not symbolic_helper._is_none(size) and symbolic_helper._is_value(size): - return symbolic_helper._unimplemented("interpolate", "dynamic size in opset 8") - - scales, mode = symbolic_helper._interpolate_get_scales_and_mode( - g, input, size, scale_factor, mode, align_corners - ) - return g.op("Upsample", input, mode_s=mode, scales_f=scales) - - -# NOTE: We should create a wrapper for this kind of operation, after resolving the shape/type propagation -# issue for "cast" operators. Some symbolic functions depend on shape information of input tensor, which -# is lost after casting. -def _try_cast_integer_to_float(g: jit_utils.GraphContext, *args): - floating_scalar_types = { - _type_utils.JitScalarType.HALF, - _type_utils.JitScalarType.FLOAT, - _type_utils.JitScalarType.DOUBLE, - } - old_type = None - # Cast the input tensor to Float if its scalarType is known and is not floating number. - # If casting is performed, return the old scalarType, otherwise return None. - arg0_type = _type_utils.JitScalarType.from_value( - args[0], _type_utils.JitScalarType.UNDEFINED - ) - if arg0_type != _type_utils.JitScalarType.UNDEFINED: - old_type = arg0_type - if old_type not in floating_scalar_types: - old_type = old_type.scalar_name() # type: ignore[assignment] - args = tuple( - g.op("Cast", arg, to_i=_C_onnx.TensorProtoDataType.FLOAT) - for arg in args - ) - else: - return (None,) + args - else: - warnings.warn( - "Only floating datatype is supported for these operators: " - "{Greater, Less, MatMul, PRelu, Gemm, Flatten}. This might cause " - "the onnx model to be incorrect, if inputs have integer datatypes." - ) - return (old_type,) + args - - -def _cast_to_type(g: jit_utils.GraphContext, input, to_type): - if to_type is None: - return input - return getattr(opset9, f"_cast_{to_type}")(g, input, False) - - -def _comparison_operator(g: jit_utils.GraphContext, input, other, op_name): - other = symbolic_helper._maybe_get_scalar(other) - other = symbolic_helper._if_scalar_type_as(other, input) - _, input, other = _try_cast_integer_to_float(g, input, other) - return g.op(op_name, input, other) - - -# NOTE: For symbolics {gt, lt, bmm, matmul, prelu, mm, addmm, view, flatten}, -# integer input type not supported in opset8. 
Cast to float if possible. -@_onnx_symbolic("aten::gt") -def gt(g: jit_utils.GraphContext, input, other): - return _comparison_operator(g, input, other, "Greater") - - -@_onnx_symbolic("aten::lt") -def lt(g: jit_utils.GraphContext, input, other): - return _comparison_operator(g, input, other, "Less") - - -@_onnx_symbolic("aten::bmm") -def bmm(g: jit_utils.GraphContext, self, other): - if symbolic_helper._try_get_scalar_type(self): - old_type, self, other = _try_cast_integer_to_float(g, self, other) - return _cast_to_type(g, g.op("MatMul", self, other), old_type) - else: - return g.op("MatMul", self, other) - - -@_onnx_symbolic("aten::matmul") -def matmul(g: jit_utils.GraphContext, self, other): - return bmm(g, self, other) - - -@_onnx_symbolic("aten::prelu") -def prelu(g: jit_utils.GraphContext, self, weight): - self_rank = symbolic_helper._get_tensor_rank(self) - weight_sizes = symbolic_helper._get_tensor_sizes(weight) - if self_rank is not None and self_rank > 2: - weight = g.op("Unsqueeze", weight, axes_i=list(range(1, self_rank - 1))) - elif self_rank == 0 and weight_sizes == [1]: - # self and weight are both scalar but weight has rank == 1, squeeze weight. - weight = symbolic_helper._squeeze_helper(g, weight, [0]) - if symbolic_helper._try_get_scalar_type(self): - old_type, self, weight = _try_cast_integer_to_float(g, self, weight) - return _cast_to_type(g, g.op("PRelu", self, weight), old_type) - else: - return g.op("PRelu", self, weight) - - -@_onnx_symbolic("aten::mm") -def mm(g: jit_utils.GraphContext, self, other): - # Create a dummy C tensor. Only needed for API purposes, the value is - # since beta = 0 - scalar_type = symbolic_helper._try_get_scalar_type(self, other) - if scalar_type is None: - raise errors.SymbolicValueError( - "mm can only operate on tensors with known types", self - ) - zero_constant = g.op( - "Constant", - value_t=torch.tensor([0], dtype=scalar_type.dtype()), - ) - - if symbolic_helper._try_get_scalar_type(self): - old_type, self, other, zero_constant = _try_cast_integer_to_float( - g, self, other, zero_constant - ) - return _cast_to_type( - g, - g.op("Gemm", self, other, zero_constant, beta_f=0.0, alpha_f=1.0), - old_type, - ) - return g.op("Gemm", self, other, zero_constant, beta_f=0.0, alpha_f=1.0) - - -@_onnx_symbolic("aten::addmm") -@symbolic_helper.parse_args("v", "v", "v", "t", "t") -def addmm(g: jit_utils.GraphContext, self, mat1, mat2, beta, alpha): - if symbolic_helper._try_get_scalar_type(self): - old_type, self, mat1, mat2 = _try_cast_integer_to_float(g, self, mat1, mat2) - return _cast_to_type( - g, - g.op( - "Gemm", - mat1, - mat2, - self, - beta_f=symbolic_helper._scalar(beta), - alpha_f=symbolic_helper._scalar(alpha), - ), - old_type, - ) - else: - return g.op( - "Gemm", - mat1, - mat2, - self, - beta_f=symbolic_helper._scalar(beta), - alpha_f=symbolic_helper._scalar(alpha), - ) - - -@_onnx_symbolic("aten::flatten") -def flatten(g: jit_utils.GraphContext, input, start_dim, end_dim): - start_dim_i = symbolic_helper._get_const(start_dim, "i", "start_dim") - end_dim_i = symbolic_helper._get_const(end_dim, "i", "end_dim") - - dim = input.type().dim() - if end_dim_i < 0: - end_dim_i = dim + end_dim_i - # use ONNX's Flatten operator for cases where the output shape is 2D - if start_dim_i == 1 and end_dim_i == dim - 1: - if symbolic_helper._try_get_scalar_type(input): - old_type, input = _try_cast_integer_to_float(g, input) - return _cast_to_type( - g, g.op("Flatten", input, axis_i=start_dim_i), old_type - ) - else: - return g.op("Flatten", input, 
axis_i=start_dim_i) - if start_dim_i == 0 and end_dim_i == dim - 2: - if symbolic_helper._try_get_scalar_type(input): - old_type, input = _try_cast_integer_to_float(g, input) - return _cast_to_type( - g, g.op("Flatten", input, axis_i=end_dim_i + 1), old_type - ) - else: - return g.op("Flatten", input, axis_i=end_dim_i + 1) - - return opset9.flatten(g, input, start_dim, end_dim) - - -def _constant_fill(g: jit_utils.GraphContext, sizes, dtype: int, const_value): - if dtype is None: - scalar_type = _type_utils.JitScalarType.FLOAT - else: - scalar_type = _type_utils.JitScalarType(dtype) - if not scalar_type.dtype().is_floating_point: - result = g.op( - "ConstantFill", - sizes, - dtype_i=_type_utils.JitScalarType.FLOAT.onnx_type(), - input_as_shape_i=1, - value_f=const_value, - ) - return g.op("Cast", result, to_i=scalar_type.onnx_type()) - else: - return g.op( - "ConstantFill", - sizes, - dtype_i=scalar_type.onnx_type(), - input_as_shape_i=1, - value_f=const_value, - ) - - -@_onnx_symbolic("aten::empty") -@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v") -def empty( - g: jit_utils.GraphContext, - sizes, - dtype, - layout, - device, - pin_memory=False, - memory_format=None, -): - return zeros(g, sizes, dtype, layout, device, pin_memory) - - -@_onnx_symbolic("aten::empty_like") -@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v") -def empty_like( - g: jit_utils.GraphContext, - input, - dtype, - layout, - device, - pin_memory=False, - memory_format=None, -): - return zeros_like(g, input, dtype, layout, device, pin_memory) - - -@_onnx_symbolic("aten::zeros") -@symbolic_helper.parse_args("v", "i", "v", "v", "v") -def zeros(g: jit_utils.GraphContext, sizes, dtype, layout, device, pin_memory=False): - # NOTE: no way to set device and layout in ONNX, so we ignore it - return _constant_fill(g, sizes, dtype, 0) - - -@_onnx_symbolic("aten::zeros_like") -@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v") -def zeros_like( - g: jit_utils.GraphContext, - input, - dtype, - layout, - device, - pin_memory=False, - memory_format=None, -): - shape = g.op("Shape", input) - return _constant_fill(g, shape, dtype, 0) - - -@_onnx_symbolic("aten::ones") -@symbolic_helper.parse_args("v", "i", "v", "v", "v") -def ones(g: jit_utils.GraphContext, sizes, dtype, layout, device, pin_memory=False): - return _constant_fill(g, sizes, dtype, 1) - - -@_onnx_symbolic("aten::ones_like") -@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v") -def ones_like( - g: jit_utils.GraphContext, - input, - dtype, - layout, - device, - pin_memory=False, - memory_format=None, -): - shape = g.op("Shape", input) - return _constant_fill(g, shape, dtype, 1) - - -@_onnx_symbolic("aten::full") -def full( - g: jit_utils.GraphContext, sizes, value, dtype, layout, device, pin_memory=False -): - const_value = symbolic_helper._maybe_get_const(value, "t") - if symbolic_helper._is_value(const_value): - tmp = zeros(g, sizes, dtype, layout, device) - return opset9.add(g, tmp, value, g.op("Constant", value_t=torch.tensor(1))) - else: - dtype = symbolic_helper._get_const(dtype, "i", "dtype") - return _constant_fill(g, sizes, dtype, const_value) - - -@_onnx_symbolic("aten::full_like") -@symbolic_helper.parse_args("v", "f", "i", "v", "v", "v", "v") -def full_like( - g: jit_utils.GraphContext, - input, - fill_value, - dtype, - layout, - device, - pin_memory=False, - memory_format=None, -): - shape = g.op("Shape", input) - return _constant_fill(g, shape, dtype, fill_value) - - -@_onnx_symbolic("aten::repeat") -def repeat(g: 
jit_utils.GraphContext, self, repeats): - if not symbolic_helper._is_value(repeats): - repeats = g.op("Constant", value_t=torch.LongTensor(repeats)) - if symbolic_helper._is_packed_list(repeats): - repeat_size_len = len(symbolic_helper._unpack_list(repeats)) - else: - const_repeats = symbolic_helper._maybe_get_const(repeats, "is") - repeat_size_len = len(const_repeats) - if self.isCompleteTensor(): - sizes = self.type().sizes() - diff_dims = repeat_size_len - len(sizes) - if diff_dims > 0: - self = opset9.view( - g, self, g.op("Constant", value_t=torch.tensor([1] * diff_dims + sizes)) - ) - return g.op("Tile", self, repeats) +from torch.onnx._internal.torchscript_exporter.symbolic_opset8 import * # noqa: F401,F403 diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py index af56a8751459..bd0f4795340a 100644 --- a/torch/onnx/symbolic_opset9.py +++ b/torch/onnx/symbolic_opset9.py @@ -1,6653 +1,14 @@ -# mypy: allow-untyped-decorators -# mypy: allow-untyped-defs -# mypy: disable-error-code=arg-type -"""This file exports ONNX ops for opset 9. - -Opset 9 is supported by ONNX release 1.4.1 -release on 01/23/19 -""" +"""Backward compatibility module for torch.onnx.symbolic_opset9.""" from __future__ import annotations -import builtins -import functools -import math -import sys -import warnings -from typing import Callable, TYPE_CHECKING -from typing_extensions import deprecated -import torch -import torch._C._onnx as _C_onnx -import torch.nn.modules.utils -import torch.onnx -from torch import _C +__all__: list[str] = [] -# Monkey-patch graph manipulation methods on Graph, used for the ONNX symbolics -from torch.onnx import _constants, _type_utils, errors, symbolic_helper -from torch.onnx._globals import GLOBALS -from torch.onnx._internal import jit_utils, registration - - -if TYPE_CHECKING: - from collections.abc import Sequence - - from torch.types import Number - -# EDITING THIS FILE? READ THIS FIRST! 
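The two hunks here reduce `torch/onnx/symbolic_opset8.py` and `torch/onnx/symbolic_opset9.py` to thin re-export shims over `torch.onnx._internal.torchscript_exporter`. A minimal sketch of the compatibility this preserves, assuming a torch build that already contains this refactor:

```python
# Sketch only; assumes the refactored layout introduced by this diff.
import torch.onnx.symbolic_opset9 as opset9

# The shim body is a star re-export of
# torch.onnx._internal.torchscript_exporter.symbolic_opset9, so attribute
# lookups through the old public path keep resolving.
print(opset9.relu)             # symbolic function, now defined in the internal module
print(opset9.relu.__module__)  # expected to point at torchscript_exporter.symbolic_opset9
```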
-# see Note [Edit Symbolic Files] in README.md - -__all__ = [ - "abs", - "acos", - "add", - "addcmul", - "addmm", - "alias", - "amax", - "amin", - "aminmax", - "arange", - "argmax", - "argmin", - "as_strided", - "as_tensor", - "asin", - "atan", - "atan2", - "baddbmm", - "batch_norm", - "bernoulli", - "bitwise_not", - "bitwise_or", - "bmm", - "broadcast_tensors", - "broadcast_to", - "bucketize", - "cat", - "cdist", - "ceil", - "clamp_max", - "clamp_min", - "clamp", - "clone", - "constant_pad_nd", - "contiguous", - "conv_tbc", - "conv_transpose1d", - "conv_transpose2d", - "conv_transpose3d", - "conv1d", - "conv2d", - "conv3d", - "convert_element_type", - "convolution", - "cos", - "cosine_similarity", - "cross", - "cumsum", - "detach", - "dim", - "div", - "dot", - "dropout", - "elu", - "embedding_bag", - "embedding", - "empty_like", - "empty", - "eq", - "erf", - "exp", - "expand_as", - "expand", - "eye", - "fill", - "flatten", - "floor_divide", - "floor", - "floordiv", - "frobenius_norm", - "full_like", - "full", - "gather", - "ge", - "gelu", - "get_pool_ceil_padding", - "glu", - "group_norm", - "gt", - "hann_window", - "hardshrink", - "hardsigmoid", - "hardswish", - "hardtanh", - "index_add", - "index_copy", - "index_fill", - "index_put", - "index_select", - "index", - "instance_norm", - "is_floating_point", - "is_pinned", - "isnan", - "item", - "kl_div", - "layer_norm", - "le", - "leaky_relu", - "lerp", - "lift", - "linalg_cross", - "linalg_matrix_norm", - "linalg_norm", - "linalg_vector_norm", - "linear", - "linspace", - "log_sigmoid", - "log_softmax", - "log", - "log10", - "log1p", - "log2", - "logical_and", - "logical_not", - "logical_or", - "logical_xor", - "logit", - "logsumexp", - "lstm_cell", - "lstm", - "lt", - "masked_fill", - "masked_fill_", - "matmul", - "max_pool1d_with_indices", - "max_pool2d_with_indices", - "max_pool3d_with_indices", - "max", - "maximum", - "meshgrid", - "min", - "minimum", - "mish", - "mm", - "movedim", - "mse_loss", - "mul", - "multinomial", - "mv", - "narrow", - "native_layer_norm", - "ne", - "neg", - "new_empty", - "new_full", - "new_ones", - "new_zeros", - "nonzero_numpy", - "nonzero", - "norm", - "numel", - "numpy_T", - "one_hot", - "ones_like", - "ones", - "onnx_placeholder", - "pad", - "pairwise_distance", - "permute", - "pixel_shuffle", - "pixel_unshuffle", - "pow", - "prelu", - "prim_constant_chunk", - "prim_constant_split", - "prim_constant", - "prim_data", - "prim_device", - "prim_dtype", - "prim_if", - "prim_layout", - "prim_list_construct", - "prim_list_unpack", - "prim_loop", - "prim_max", - "prim_min", - "prim_shape", - "prim_tolist", - "prim_tuple_construct", - "prim_type", - "prim_unchecked_cast", - "prim_uninitialized", - "rand_like", - "rand", - "randint_like", - "randint", - "randn_like", - "randn", - "reciprocal", - "reflection_pad", - "relu", - "relu6", - "remainder", - "repeat_interleave", - "repeat", - "replication_pad", - "reshape_as", - "reshape", - "roll", - "rrelu", - "rsqrt", - "rsub", - "scalar_tensor", - "scatter_add", - "scatter", - "select", - "selu", - "sigmoid", - "sign", - "silu", - "sin", - "size", - "slice", - "softmax", - "softplus", - "softshrink", - "sort", - "split_with_sizes", - "split", - "sqrt", - "square", - "squeeze", - "stack", - "std_mean", - "std", - "sub", - "t", - "take", - "tan", - "tanh", - "tanhshrink", - "tensor", - "threshold", - "to", - "topk", - "transpose", - "true_divide", - "type_as", - "unbind", - "unfold", - "unsafe_chunk", - "unsafe_split_with_sizes", - "unsafe_split", - "unsqueeze", - 
"unsupported_complex_operators", - "noop_complex_operators", - "unused", - "var_mean", - "var", - "view_as", - "view", - "where", - "wrap_logical_op_with_cast_to", - "wrap_logical_op_with_negation", - "zeros_like", - "zeros", - "zero", -] - - -_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=9) - - -def _export(name: str): - """Exports the function in the current global namespace.""" - - def wrapper(func): - globals()[name] = func - __all__.append(name) - return func - - return wrapper - - -def unused(g): - """Represents "missing" optional inputs.""" - n = g.op("prim::Constant") - n.setType(_C.OptionalType.ofTensor()) - return n - - -@_onnx_symbolic("aten::_shape_as_tensor") -def _shape_as_tensor(g: jit_utils.GraphContext, input): - return g.op("Shape", input) - - -@_onnx_symbolic("aten::_reshape_from_tensor") -def _reshape_from_tensor(g: jit_utils.GraphContext, input, shape): - if isinstance(shape, list): - shape = g.op("Concat", *shape, axis_i=0) - return reshape(g, input, shape) - - -@_onnx_symbolic("aten::reshape") -@symbolic_helper.quantized_args(True) -def reshape(g: jit_utils.GraphContext, self, shape): - return symbolic_helper._reshape_helper(g, self, shape) - - -@_onnx_symbolic("aten::reshape_as") -@symbolic_helper.quantized_args(True) -def reshape_as(g: jit_utils.GraphContext, self, other): - shape = g.op("Shape", other) - return reshape(g, self, shape) - - -@_onnx_symbolic("aten::add") -def add(g: jit_utils.GraphContext, self, other, alpha=None): - """ - This function takes the add function and returns the corresponding ONNX operator. - - This function is not meant to be called directly by the user. - - Args: - g (GraphContext): The graph context. - self (Tensor): The first operand. - other (Tensor): The second operand. - alpha (float, optional): The scaling factor for the second operand. Defaults to None. - - Returns: - ONNX operator. - """ - if symbolic_helper._is_value(self) and symbolic_helper._is_tensor_list(self): - return symbolic_helper._onnx_opset_unsupported_detailed( - "Add", 9, 11, "Add between list of tensors not supported", self - ) - if alpha and symbolic_helper._scalar(symbolic_helper._maybe_get_scalar(alpha)) != 1: - other = g.op("Mul", other, alpha) - return g.op("Add", self, other) - - -@_onnx_symbolic("aten::sub") -def sub(g: jit_utils.GraphContext, self, other, alpha=None): - """ - Consumes sub function and returns the corresponding ONNX operator. - - This function is not meant to be called directly by the user. - - Args: - g (GraphContext): The graph context. - self (Tensor): The first operand. - other (Tensor): The second operand. - alpha (Optional[Tensor]): A scaling factor to apply to the second operand. - If `alpha` is not provided, it defaults to 1. - - Returns: - ONNX operator - """ - if alpha and symbolic_helper._scalar(symbolic_helper._maybe_get_scalar(alpha)) != 1: - other = g.op("Mul", other, alpha) - return g.op("Sub", self, other) - - -@_onnx_symbolic("aten::rsub") -def rsub(g: jit_utils.GraphContext, self, other, alpha=None): - return sub(g, other, self, alpha=alpha) - - -@_onnx_symbolic("aten::mul") -def mul(g: jit_utils.GraphContext, self, other): - if symbolic_helper._is_bool(self) and symbolic_helper._is_bool(other): - # ONNX Mul doesn't support Boolean, so use And as an equivalent operator. 
- return g.op("And", self, other) - else: - return g.op("Mul", self, other) - - -@_onnx_symbolic("aten::div") -def div(g: jit_utils.GraphContext, self, other, *args): - if len(args) == 0: - return true_divide(g, self, other) - else: - return _div_rounding_mode(g, self, other, *args) - - -@_onnx_symbolic("aten::addcmul") -@symbolic_helper.parse_args("v", "v", "v", "f") -def addcmul(g: jit_utils.GraphContext, self, tensor1, tensor2, value=1.0): - value_tens = g.op("Constant", value_t=torch.tensor([value])) - return add(g, self, mul(g, mul(g, tensor1, tensor2), value_tens)) - - -@symbolic_helper.parse_args("v", "v", "s") -def _div_rounding_mode(g: jit_utils.GraphContext, self, other, rounding_mode): - if rounding_mode is None: - return true_divide(g, self, other) - elif rounding_mode == "floor": - return _floor_divide(g, self, other) - elif rounding_mode == "trunc": - return _trunc_divide(g, self, other) - else: - raise errors.SymbolicValueError( - f'Unsupported rounding mode: "{rounding_mode}". Expected None, "floor" or "trunc"', - self, - ) - - -def _trunc_divide(g: jit_utils.GraphContext, self, other): - out = g.op("Div", self, other) - # the correct operation is truncate, which is not supported in ONNX, - # we cannot call floor since it will behave differently for negative numbers - # (eg. -0.1 should become -0 ) - # - if scalar_type information are not available, assume that - # we need to call floor (treat as float) - out = g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.INT64) - - # Matching PyTorch's behavior: - # - if self is fp the output's type is self's type - # - if self is not fp and other is fp, the output is of type JitScalarType.FLOAT - # - self is not fp and other is not fp, the output's type is self's output type - # - the output type defaults to Float - scalar_type = _type_utils.JitScalarType.from_value( - self, _type_utils.JitScalarType.UNDEFINED - ) - if scalar_type != _type_utils.JitScalarType.UNDEFINED: - if not symbolic_helper._is_fp(self) and symbolic_helper._is_fp(other): - out = g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.FLOAT) - else: - out = g.op( - "Cast", - out, - to_i=scalar_type.onnx_type(), - ) - else: - out = g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.FLOAT) - return out - - -def _floor_divide(g: jit_utils.GraphContext, self, other): - if symbolic_helper._is_fp(self) or symbolic_helper._is_fp(other): - out = true_divide(g, self, other) - return g.op("Floor", out) - else: - # Integer division does truncation rounding - div = g.op("Div", self, other) - # Division is negative if: self < 0 != other < 0 - zero = g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64)) - negative = g.op( - "Xor", - symbolic_helper._lt_helper(g, self, zero), - symbolic_helper._lt_helper(g, other, zero), - ) - - # For negative numbers with self % other != 0, subtract 1 to round down instead of up - mod = g.op("Sub", self, g.op("Mul", div, other)) - fixup_mask = g.op("And", negative, g.op("Not", g.op("Equal", mod, zero))) - - one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)) - fixup = g.op("Mul", fixup_mask, one) - return g.op("Sub", div, fixup) - - -@_onnx_symbolic("aten::floor_divide") -def floor_divide(g: jit_utils.GraphContext, self, other): - # Deprecated behavior, floor_divide actually truncates - return _trunc_divide(g, self, other) - - -@_onnx_symbolic("aten::floordiv") -def floordiv(g: jit_utils.GraphContext, self, other): - return floor_divide(g, self, other) - - -@_onnx_symbolic("aten::true_divide") -def true_divide(g: 
jit_utils.GraphContext, self, other): - """Division where both inputs are cast to floating types - - If both inputs are floating, performs div as usual - If only one input is a floating type, the other input is cast to its type - If neither input is a floating type, both inputs are cast to the default scalar type - """ - - # Case 1: either values are floating - # Performs div as usual. - # Implicit casting will be handled in scalar type analysis pass. - if symbolic_helper._is_fp(self) or symbolic_helper._is_fp(other): - return g.op("Div", self, other) - - # Case 2: neither is floating - # Casts both inputs to the default scalar type - scalar_type = torch.get_default_dtype() - onnx_scalar_type = _C_onnx.TensorProtoDataType.FLOAT - assert scalar_type is torch.float or scalar_type is torch.double - if torch.get_default_dtype() is torch.double: - onnx_scalar_type = _C_onnx.TensorProtoDataType.DOUBLE - - self = g.op("Cast", self, to_i=onnx_scalar_type) - other = g.op("Cast", other, to_i=onnx_scalar_type) - return g.op("Div", self, other) - - -@_onnx_symbolic("aten::reciprocal") -def reciprocal(g: jit_utils.GraphContext, self): - # torch.reciprocal implicitly casts to float, so we do the same. - if not symbolic_helper._is_fp(self): - self = g.op("Cast", self, to_i=_C_onnx.TensorProtoDataType.FLOAT) - return g.op("Reciprocal", self) - - -@_onnx_symbolic("aten::cat") -@symbolic_helper.parse_args("v", "i") -def cat(g: jit_utils.GraphContext, tensor_list, dim): - """Implement concatenation of pytorch tensors in ONNX along the specified `dim` dimension. - - Parameters: - g (jit_utils.GraphContext): Graph context. - tensor_list (List[torch.Tensor]): List of tensors to concatenate. - dim (int): Dimension along which to concatenate the tensors. - - Returns: - ONNX graph node representing the concatenated tensor. - """ - tensors = symbolic_helper._unpack_list(tensor_list) - # torch.cat ignores empty tensors such as `torch.Tensor([])` - # These needs to be removed as input from ONNX's concat too, otherwise shape inference - # will likely fail due to inputs with different ranks (0 for empty tensor, > 0 for anything else) - nonempty_tensors = [] - for t in tensors: - if symbolic_helper._is_constant(t) and not symbolic_helper._get_tensor_dim_size( - t, 0 - ): - continue - nonempty_tensors.append(t) - assert len(nonempty_tensors) > 0 - assert all( - symbolic_helper._get_tensor_rank(nonempty_tensors[0]) is None - or symbolic_helper._get_tensor_rank(t) is None - or symbolic_helper._get_tensor_rank(t) - == symbolic_helper._get_tensor_rank(nonempty_tensors[0]) - for t in nonempty_tensors - ) - tensor_list.node().removeAllInputs() - for t in nonempty_tensors: - tensor_list.node().addInput(t) - - tensors = symbolic_helper._unpack_list(tensor_list) - return g.op("Concat", *tensors, axis_i=dim) - - -@_onnx_symbolic("aten::stack") -@symbolic_helper.parse_args("v", "i") -def stack(g: jit_utils.GraphContext, tensor_list, dim): - unsqueezed = [ - symbolic_helper._unsqueeze_helper(g, t, [dim]) - for t in symbolic_helper._unpack_list(tensor_list) - ] - return g.op("Concat", *unsqueezed, axis_i=dim) - - -@_onnx_symbolic("aten::list") -def _list(g: jit_utils.GraphContext, self): - return self - - -@_onnx_symbolic("aten::mm") -def mm(g: jit_utils.GraphContext, self, other): - # Create a dummy C tensor. 
Only needed for API purposes, the value is - # since beta = 0 - C = g.op("Constant", value_t=torch.tensor([1])) - return g.op("Gemm", self, other, C, beta_f=0.0, alpha_f=1.0) - - -@_onnx_symbolic("aten::bmm") -def bmm(g: jit_utils.GraphContext, self, other): - return g.op("MatMul", self, other) - - -@_onnx_symbolic("aten::matmul") -def matmul(g: jit_utils.GraphContext, self, other): - return g.op("MatMul", self, other) - - -@_onnx_symbolic("aten::addmm") -@symbolic_helper.parse_args("v", "v", "v", "t", "t") -def addmm(g: jit_utils.GraphContext, self, mat1, mat2, beta, alpha): - scalar_type = None - self_scalar_type = symbolic_helper._try_get_scalar_type(self) - mat1_scalar_type = symbolic_helper._try_get_scalar_type(mat1) - mat2_scalar_type = symbolic_helper._try_get_scalar_type(mat2) - if self_scalar_type is not None: - scalar_type = self_scalar_type - elif mat1_scalar_type is not None: - scalar_type = mat1_scalar_type - elif mat2_scalar_type is not None: - scalar_type = mat2_scalar_type - - mat1_rank = symbolic_helper._get_tensor_rank(mat1) - mat2_rank = symbolic_helper._get_tensor_rank(mat2) - - def is_not_none_nor(v, u): - return v is not None and v != u - - if scalar_type is not None and ( - is_not_none_nor(mat1_rank, 2) or is_not_none_nor(mat2_rank, 2) - ): - res1 = g.op("MatMul", mat1, mat2) - res2 = self - - alpha = symbolic_helper._scalar(alpha) - beta = symbolic_helper._scalar(beta) - - if alpha != 1: - alpha = g.op( - "Constant", value_t=torch.tensor(alpha, dtype=scalar_type.dtype()) - ) - res1 = g.op("Mul", res1, alpha) - if beta != 1: - beta = g.op( - "Constant", - value_t=torch.tensor( - symbolic_helper._scalar(beta), dtype=scalar_type.dtype() - ), - ) - res2 = g.op("Mul", res2, beta) - - return g.op("Add", res1, res2) - - return g.op( - "Gemm", - mat1, - mat2, - self, - beta_f=symbolic_helper._scalar(beta), - alpha_f=symbolic_helper._scalar(alpha), - ) - - -@_onnx_symbolic("aten::neg") -def neg(g: jit_utils.GraphContext, self): - return g.op("Neg", self) - - -@_onnx_symbolic("aten::sqrt") -def sqrt(g: jit_utils.GraphContext, self): - if _type_utils.JitScalarType.from_value( - self, _type_utils.JitScalarType.UNDEFINED - ) in { - _type_utils.JitScalarType.UINT8, - _type_utils.JitScalarType.INT8, - _type_utils.JitScalarType.INT16, - _type_utils.JitScalarType.INT, - _type_utils.JitScalarType.INT64, - }: - # torch converts all int inputs to sqrt to float - self = g.op("Cast", self, to_i=_C_onnx.TensorProtoDataType.FLOAT) - - return g.op("Sqrt", self) - - -@_onnx_symbolic("aten::rsqrt") -def rsqrt(g: jit_utils.GraphContext, self): - return g.op( - "Div", symbolic_helper._if_scalar_type_as(torch.ones(1), self), sqrt(g, self) - ) - - -@_onnx_symbolic("aten::tanh") -# Fixed scale and zero_point, discovered from aten/src/ATen/native/quantized/cpu/qtanh.cpp -@symbolic_helper.quantized_args(True, scale=2.0 / 256.0, zero_point=128) -def tanh(g: jit_utils.GraphContext, self): - return g.op("Tanh", self) - - -@_onnx_symbolic("aten::sin") -def sin(g: jit_utils.GraphContext, self): - return g.op("Sin", self) - - -@_onnx_symbolic("aten::cos") -def cos(g: jit_utils.GraphContext, self): - return g.op("Cos", self) - - -@_onnx_symbolic("aten::tan") -def tan(g: jit_utils.GraphContext, self): - return g.op("Tan", self) - - -@_onnx_symbolic("aten::asin") -def asin(g: jit_utils.GraphContext, self): - return g.op("Asin", self) - - -@_onnx_symbolic("aten::acos") -def acos(g: jit_utils.GraphContext, self): - return g.op("Acos", self) - - -@_onnx_symbolic("aten::atan") -def atan(g: 
jit_utils.GraphContext, self): - return g.op("Atan", self) - - -@_onnx_symbolic("aten::atan2") -def atan2(g: jit_utils.GraphContext, self, other): - # self is y, and other is x on coordinate - slope = g.op("Div", self, other) - atan = g.op("Atan", slope) - const_zero = g.op("Constant", value_t=torch.tensor(0)) - const_pi = g.op("Constant", value_t=torch.tensor(math.pi)) - - condition_second_or_third_quadrant = g.op("Greater", self, const_zero) - second_third_quadrant = g.op( - "Where", - condition_second_or_third_quadrant, - g.op("Add", atan, const_pi), - g.op("Sub", atan, const_pi), - ) - - condition_14_or_23_quadrant = g.op("Less", other, const_zero) - result = g.op("Where", condition_14_or_23_quadrant, second_third_quadrant, atan) - - return result - - -@_onnx_symbolic("aten::sigmoid") -# Fixed scale and zero_point, discovered from aten/src/ATen/native/quantized/cpu/qsigmoid.cpp -@symbolic_helper.quantized_args(True, scale=1.0 / 256.0, zero_point=0) -def sigmoid(g: jit_utils.GraphContext, self): - """Converts the corresponding PyTorch function into ONNX operators. - - It is not meant to be called directly by a user. - - Args: - g (jit_utils.GraphContext): Graph context. - self (Tensor): the input tensor. - Returns: - ONNX operator - """ - return g.op("Sigmoid", self) - - -@_onnx_symbolic("aten::sign") -def sign(g: jit_utils.GraphContext, self): - return g.op("Sign", self) - - -@symbolic_helper.quantized_args(True) -def _slice(g: jit_utils.GraphContext, input, axes, starts, ends): - assert len(starts) == len(ends) - if len(starts) == 1 and starts[0] == 0 and ends[0] == _constants.INT64_MAX: - return input - return g.op("Slice", input, axes_i=axes, starts_i=starts, ends_i=ends) - - -@_onnx_symbolic( - "aten::sum", decorate=[symbolic_helper._apply_params("ReduceSum", "sum")] +from torch.onnx._internal.torchscript_exporter.symbolic_opset9 import * # noqa: F401,F403 +from torch.onnx._internal.torchscript_exporter.symbolic_opset9 import ( # noqa: F401 + _prepare_onnx_paddings, + _reshape_from_tensor, + _slice, + _var_mean, ) -@_onnx_symbolic( - "aten::mean", decorate=[symbolic_helper._apply_params("ReduceMean", "mean")] -) -# torch.prod does not support multidimensional "dim" -@_onnx_symbolic( - "aten::prod", - decorate=[ - symbolic_helper._apply_params( - "ReduceProd", "prod", allow_multi_dim_support=False - ) - ], -) -def _reduce_with_dtype(onnx_op: str, name: str, allow_multi_dim_support: bool = True): - return symbolic_helper._reduce_with_dtype_helper( - onnx_op, name, allow_multi_dim_support - ) - - -@_onnx_symbolic("aten::cumsum") -@symbolic_helper.parse_args("v", "i", "none") -def cumsum(g: jit_utils.GraphContext, input, dim, dtype): - symbolic_helper._onnx_opset_unsupported("cumsum", 9, 11, input) - - -@_onnx_symbolic("aten::_sample_dirichlet") -def _sample_dirichlet(g: jit_utils.GraphContext, self, generator): - return symbolic_helper._onnx_unsupported("_sample_dirichlet", self) - - -@_onnx_symbolic("aten::_standard_gamma") -def _standard_gamma(g: jit_utils.GraphContext, self, generator): - return symbolic_helper._onnx_unsupported("_standard_gamma", self) - - -@_onnx_symbolic("aten::t") -def t(g: jit_utils.GraphContext, self): - rank = symbolic_helper._get_tensor_rank(self) - if rank is None or rank < 2: - # The transpose of a 1d or 0d tensor is itself. ONNX does not define the behavior - # clearly and onnxruntime fails on these cases. So we add an Identity node to - # mirror the behavior of eager mode. 
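The removed `aten::t` symbolic special-cases rank < 2 because ONNX `Transpose` is not well defined there; a quick eager-mode illustration of the behavior it mirrors:

```python
import torch

v = torch.arange(3)
s = torch.tensor(5)
# torch.t is a no-op on 0-d and 1-d tensors, which is why the symbolic emits
# Identity (rather than Transpose, which onnxruntime rejects for rank < 2).
assert torch.equal(v.t(), v)
assert torch.equal(s.t(), s)
```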
- return g.op("Identity", self) - return g.op("Transpose", self, perm_i=(1, 0)) - - -@_onnx_symbolic("aten::numpy_T") -@symbolic_helper.quantized_args(True) -def numpy_T(g: jit_utils.GraphContext, input): - ndim = symbolic_helper._get_tensor_rank(input) - assert ndim is not None - perm = list(reversed(range(0, ndim))) - return g.op("Transpose", input, perm_i=perm) - - -@_onnx_symbolic("aten::expand") -@symbolic_helper.quantized_args(True) -def expand(g: jit_utils.GraphContext, self, size, implicit): - """Implement the expand function for a pytorch tensor in ONNX according to specified `size`""" - size = symbolic_helper._maybe_get_const(size, "is") - if not symbolic_helper._is_value(size): - size = g.op("Constant", value_t=torch.LongTensor(size)) - elif symbolic_helper._is_packed_list(size): - # Expand with -1 dim value means dim is unchanged. - # Since onnx::expand supports two-way broadcasting, - # -1 dim value can be exported to onnx as 1 - size = symbolic_helper._reshape_helper( - g, stack(g, size, 0), g.op("Constant", value_t=torch.tensor([-1])) - ) - dtype = _type_utils.JitScalarType.INT64 - ones = ones_like(g, size, dtype) - neg_ones = mul(g, ones, g.op("Constant", value_t=torch.tensor(-1))) - size = where(g, g.op("Equal", size, neg_ones), ones, size) - return g.op("Expand", self, size) - - -@_onnx_symbolic("aten::broadcast_to") -@symbolic_helper.quantized_args(True) -def broadcast_to(g: jit_utils.GraphContext, self, size): - size = symbolic_helper._maybe_get_const(size, "is") - if not symbolic_helper._is_value(size): - size = g.op("Constant", value_t=torch.LongTensor(size)) - elif symbolic_helper._is_packed_list(size): - # Expand with -1 dim value means dim is unchanged. - # Since onnx::expand supports two-way broadcasting, - # -1 dim value can be exported to onnx as 1 - size = symbolic_helper._reshape_helper( - g, stack(g, size, 0), g.op("Constant", value_t=torch.tensor([-1])) - ) - dtype = _type_utils.JitScalarType.INT64 - ones = ones_like(g, size, dtype) - neg_ones = mul(g, ones, g.op("Constant", value_t=torch.tensor(-1))) - size = where(g, g.op("Equal", size, neg_ones), ones, size) - return g.op("Expand", self, size) - - -@_onnx_symbolic("aten::expand_as") -@symbolic_helper.quantized_args(True, True) -def expand_as(g: jit_utils.GraphContext, self, other): - self_t = symbolic_helper._maybe_get_const(self, "t") - if isinstance(self_t, torch.Tensor): - orig_type = self_t.dtype - self_t = self_t.to(torch.double) - dims = [] - for d in range(self_t.dim()): - if torch.equal(self_t.mean(d).unsqueeze(d).expand_as(self_t), self_t): - dims.append(d) - self = g.op( - "Constant", value_t=self_t.mean(dims, keepdim=True).to(orig_type) - ) - - shape = g.op("Shape", other) - return g.op("Expand", self, shape) - - -@_onnx_symbolic("aten::embedding") -@symbolic_helper.quantized_args(True) -@symbolic_helper.parse_args("v", "v", "i", "b", "v") -def embedding( - g: jit_utils.GraphContext, - weight, - indices, - padding_idx, - scale_grad_by_freq, - sparse, -): - if scale_grad_by_freq and GLOBALS.export_training: - raise errors.SymbolicValueError( - "Unsupported: ONNX export of embedding with scale_grad_by_freq=True " - "for training mode. ONNX does not support scaling the gradients.", - weight, - ) - if padding_idx >= 0 and GLOBALS.export_training: - warnings.warn( - "Warning: ONNX export of embedding with padding_idx >= 0 " - "for training mode. " - "ONNX does not support not updating the embedding vector at padding_idx during training." 
- ) - - return g.op("Gather", weight, indices) - - -@_onnx_symbolic("aten::embedding_bag") -@symbolic_helper.quantized_args(True) -@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i") -def embedding_bag( - g: jit_utils.GraphContext, - embedding_matrix, - indices, - offsets, - scale_grad_by_freq, - mode, - sparse, - per_sample_weights, - include_last_offset, - padding_idx, -): - if not symbolic_helper._is_none(per_sample_weights): - return symbolic_helper._onnx_unsupported( - "embedding_bag with per_sample_weights" - ) - - return symbolic_helper._onnx_unsupported("embedding_bag", embedding_matrix) - - -@_onnx_symbolic("aten::size") -@symbolic_helper.quantized_args(True, quantize_output=False) -def size(g: jit_utils.GraphContext, self, dim=None): - if dim is None: - return g.op("Shape", self) - if symbolic_helper._maybe_get_const(dim, "i") < 0: - rank = symbolic_helper._get_tensor_rank(self) - if rank is not None: - dim = symbolic_helper._maybe_get_const(dim, "i") + rank - dim = g.op("Constant", value_t=torch.tensor(dim)) - return symbolic_helper._size_helper(g, self, dim) - - -@_onnx_symbolic("aten::transpose") -@symbolic_helper.quantized_args(True) -@symbolic_helper.parse_args("v", "i", "i") -def transpose(g: jit_utils.GraphContext, self, dim0, dim1): - if dim0 == dim1: # micro-optimization - return self - - # NB: Transpose in ONNX is actually a Permute - rank = symbolic_helper._get_tensor_rank(self) - if rank is not None: - axes = list(range(rank)) - axes[dim0], axes[dim1] = axes[dim1], axes[dim0] - return g.op("Transpose", self, perm_i=axes) - else: - raise errors.SymbolicValueError( - "Unsupported: ONNX export of transpose for tensor of unknown rank.", - self, - ) - - -@_onnx_symbolic("aten::permute") -@symbolic_helper.parse_args("v", "is") -def permute(g: jit_utils.GraphContext, self, dims): - if dims == list(range(0, len(dims))): - return self - return g.op("Transpose", self, perm_i=dims) - - -@_onnx_symbolic("aten::view") -@symbolic_helper.quantized_args(True) -def view(g: jit_utils.GraphContext, self, size): - return reshape(g, self, size) - - -@_onnx_symbolic("aten::view_as") -def view_as(g: jit_utils.GraphContext, self, other): - shape = g.op("Shape", other) - return reshape(g, self, shape) - - -@_onnx_symbolic("aten::unsafe_chunk") -@symbolic_helper.parse_args("v", "i", "i", "i") -def unsafe_chunk(g: jit_utils.GraphContext, self, chunks, dim, _outputs=None): - if _outputs is None: - return symbolic_helper._onnx_opset_unsupported_detailed( - "unsafe_chunk", 9, 11, "Dynamic number of outputs not supported", self - ) - size = symbolic_helper._get_tensor_dim_size(self, dim) - if size is None: - return symbolic_helper._unimplemented( - "unsafe_chunk", "unknown dimension size", self - ) - split_size = (size + chunks - 1) // chunks - splits = [split_size] * (size // split_size) - leftover = size % split_size - if leftover: - splits.append(leftover) - return g.op("Split", self, split_i=splits, axis_i=dim, outputs=_outputs) - - -@_onnx_symbolic("aten::split") -@symbolic_helper.parse_args("v", "v", "i", "i") -def split(g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None): - if not symbolic_helper._is_split_static(split_size_or_sizes, _outputs): - return symbolic_helper._onnx_opset_unsupported_detailed( - "split", 9, 11, "Dynamic number of outputs not supported", self - ) - split_val = symbolic_helper._node_get(split_size_or_sizes.node(), "value") - if split_val.dim() > 0: - return split_with_sizes(g, self, split_size_or_sizes, dim, _outputs) - 
split_size = symbolic_helper._get_const(split_size_or_sizes, "i", "split_size") - - size = symbolic_helper._get_tensor_dim_size(self, dim) - if size is None: - if _outputs is not None: - size = split_size * _outputs - else: - return symbolic_helper._onnx_opset_unsupported_detailed( - "split", 9, 11, "Unknown dimension size not supported", self - ) - splits = [split_size] * (size // split_size) - leftover = size % split_size - if leftover: - splits.append(leftover) - return g.op("Split", self, split_i=splits, axis_i=dim, outputs=_outputs) - - -@_onnx_symbolic("aten::unsafe_split") -def unsafe_split( - g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None -): - return split(g, self, split_size_or_sizes, dim, _outputs) - - -@_onnx_symbolic("aten::split_with_sizes") -@symbolic_helper.parse_args("v", "is", "i", "i") -def split_with_sizes(g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None): - if not symbolic_helper._is_split_static(split_sizes, _outputs): - return symbolic_helper._onnx_opset_unsupported_detailed( - "split_with_sizes", 9, 11, "Dynamic number of outputs not supported", self - ) - return g.op("Split", self, split_i=split_sizes, axis_i=dim, outputs=_outputs) - - -@_onnx_symbolic("aten::unsafe_split_with_sizes") -def unsafe_split_with_sizes( - g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None -): - return split_with_sizes(g, self, split_sizes, dim, _outputs) - - -@_onnx_symbolic("aten::unbind") -@symbolic_helper.parse_args("v", "i", "i") -def unbind(g: jit_utils.GraphContext, self, dim=0, _outputs=None): - if _outputs is None: - return symbolic_helper._onnx_opset_unsupported_detailed( - "unbind", 9, 11, "Dynamic number of outputs not supported", self - ) - - outputs = g.op("Split", self, split_i=[1] * _outputs, axis_i=dim, outputs=_outputs) - outputs = [outputs] if _outputs == 1 else outputs - squeezed_outputs = [ - symbolic_helper._squeeze_helper(g, out, [dim]) for out in outputs - ] - return squeezed_outputs - - -@_onnx_symbolic("aten::select") -@symbolic_helper.quantized_args(True) -@symbolic_helper.parse_args("v", "i", "v") -def select(g: jit_utils.GraphContext, self, dim, index): - """Implement the select functionality for a pytorch tensor in ONNX. - - Selects elements from the input tensor along the specified `dim` dimension based on the `index` tensor. - """ - index = symbolic_helper._maybe_get_scalar(index) - if (not symbolic_helper._is_value(index)) and (index < 0): - if index == -1: - end_index = _constants.INT64_MAX - else: - end_index = index + 1 - slice_node = symbolic_helper._slice_helper( - g, self, axes=[dim], starts=[index], ends=[end_index] - ) - return symbolic_helper._squeeze_helper(g, slice_node, [dim]) - else: - # FIXME(justinchuby): can index be an int and not a value? - return g.op("Gather", self, index, axis_i=dim) - - -@_onnx_symbolic("aten::square") -def square(g: jit_utils.GraphContext, self): - return g.op("Mul", self, self) - - -@_onnx_symbolic("aten::squeeze") -def squeeze(g: jit_utils.GraphContext, self, dim=None): - if dim is None: - return g.op("Squeeze", self) - - squeeze_dim = symbolic_helper._get_const(dim, "i", "dim") - # Handle negative dims - if squeeze_dim < 0: - rank = symbolic_helper._get_tensor_rank(self) - if rank is not None: - warnings.warn( - "ONNX export squeeze with negative axis " - + str(squeeze_dim) - + " might cause the onnx model to be incorrect. " - + "Negative axis is not supported in ONNX. 
" - + "Axis is converted to " - + str(squeeze_dim + rank) - + " based on input shape at export time. " - + "Passing an tensor of different rank in execution will be incorrect." - ) - squeeze_dim += rank - else: - return symbolic_helper._unimplemented( - "squeeze", "negative axis with unknown input rank", self - ) - - dim_size = symbolic_helper._get_tensor_dim_size(self, squeeze_dim) - if dim_size is None: - warnings.warn( - "This model contains a squeeze operation on dimension " - + str(squeeze_dim) - + " on an input " - + "with unknown shape. Note that if the size of dimension " - + str(squeeze_dim) - + " of the input " - + "is not 1, the ONNX model will return an error. Opset version 11 supports squeezing on " - + "non-singleton dimensions, it is recommended to export this model using opset " - + "version 11 or higher." - ) - return symbolic_helper._squeeze_helper(g, self, axes_i=[squeeze_dim]) - if dim_size > 1: - warnings.warn( - "This model contains a squeeze operation on dimension " - + str(squeeze_dim) - + ". The size of " - + "this dimension in the given input is " - + str(dim_size) - + ". The model will " - + "be exported without the squeeze node. If the model is intended to be used with dynamic " - + "input shapes, please use opset version 11 to " - + "export the model." - ) - return self - - warnings.warn( - "This model contains a squeeze operation on dimension " - + str(squeeze_dim) - + ". If the model is " - + "intended to be used with dynamic input shapes, please use opset version 11 to export the model." - ) - return symbolic_helper._squeeze_helper(g, self, axes_i=[squeeze_dim]) - - -@_onnx_symbolic("aten::prelu") -def prelu(g: jit_utils.GraphContext, self, weight): - self_rank = symbolic_helper._get_tensor_rank(self) - weight_sizes = symbolic_helper._get_tensor_sizes(weight) - weight_rank = len(weight_sizes) - if self_rank is not None: - if self_rank > 2: - # make weight unidirectional broadcastable - weight = symbolic_helper._unsqueeze_helper( - g, weight, list(range(1, self_rank - 1)) - ) - elif self_rank == 0 and weight_sizes == [1]: - # self and weight are both scalar but weight has rank == 1, squeeze weight. 
- weight = symbolic_helper._squeeze_helper(g, weight, [0]) - weight_rank = 0 - - if self_rank is not None and weight_rank is not None: - assert self_rank >= weight_rank, ( - f"rank(x) should be >= rank(slope) but got {self_rank} < {weight_rank}" - ) - return g.op("PRelu", self, weight) - - -@_onnx_symbolic("aten::silu") -def silu(g: jit_utils.GraphContext, input): - return g.op("Mul", input, g.op("Sigmoid", input)) - - -@_onnx_symbolic("aten::mish") -def mish(g: jit_utils.GraphContext, input): - return g.op("Mul", input, g.op("Tanh", g.op("Softplus", input))) - - -@_onnx_symbolic("aten::relu") -@symbolic_helper.quantized_args(True) -def relu(g: jit_utils.GraphContext, input): - return symbolic_helper._op_with_optional_float_cast( - g, "Relu", input, opset_before=14 - ) - - -@_onnx_symbolic("aten::relu6") -@symbolic_helper.quantized_args(True) -def relu6(g: jit_utils.GraphContext, input): - return clamp(g, input, 0, 6) - - -@_onnx_symbolic("aten::ceil") -def ceil(g: jit_utils.GraphContext, input): - return g.op("Ceil", input) - - -@_onnx_symbolic("aten::floor") -def floor(g: jit_utils.GraphContext, input): - return g.op("Floor", input) - - -@_onnx_symbolic("aten::len") -def _len(g: jit_utils.GraphContext, self): - sz_0 = size(g, self, g.op("Constant", value_t=torch.LongTensor([0]))) - return symbolic_helper._squeeze_helper(g, sz_0, [0]) - - -@_onnx_symbolic("aten::threshold") -@symbolic_helper.parse_args("v", "t", "t") -def threshold(g: jit_utils.GraphContext, self, threshold, value): - # See Note [Export inplace] - if symbolic_helper._scalar(threshold) != 0: - return symbolic_helper._unimplemented("threshold", "non-zero threshold", self) - if symbolic_helper._scalar(value) != 0: - return symbolic_helper._unimplemented("threshold", "non-zero value", self) - return g.op("Relu", self) - - -@_onnx_symbolic("aten::leaky_relu") -@symbolic_helper.quantized_args(True) -@symbolic_helper.parse_args("v", "f", "b") -def leaky_relu( - g: jit_utils.GraphContext, - input: _C.Value, - negative_slope: float, - inplace: bool = False, -): - # See Note [Export inplace] - return g.op("LeakyRelu", input, alpha_f=negative_slope) - - -@_onnx_symbolic("aten::glu") -@symbolic_helper.parse_args("v", "i") -def glu(g: jit_utils.GraphContext, input, dim): - dim_size = symbolic_helper._get_tensor_dim_size(input, dim) - if dim_size is not None: - assert dim_size % 2 == 0 - - first, second = g.op("Split", input, axis_i=dim, outputs=2) - return g.op("Mul", first, g.op("Sigmoid", second)) - - -@_onnx_symbolic("aten::softmax") -@symbolic_helper.parse_args("v", "i", "none") -def softmax(g: jit_utils.GraphContext, input, dim, dtype=None): - # Softmax does normalization at vector level. - # PyTorch and ONNX use different strategies to split the input tensor into vectors. - # Thus dim and axis have different meanings. - # PyTorch slices the input tensor into vectors along the `dim`-th dimension. - # ONNX reshapes the input into a 2-D tensor, and `axis` indicates where the input is coerced. - # If input is a 2 x 3 tensor: - # input = [[1.0, 1.0, 1.0], - # [1.0, 1,0, 1,0]] - # with dim = 0, the result is: - # result = [[0.5, 0.5, 0.5], - # [0.5, 0.5, 0.5]] - # with axis = 0, the result is: - # result = [[0.167, 0.167, 0.167], - # [0.167, 0.167, 0.167]] - # So only when dim and axis both equal to ndim - 1 (the last dimension), - # their semantics are equivalent. - # So use softmax when dim and axis both equal to ndim - 1, - # otherwise transpose the input to put the vectors to be normalized to the last dimension. 
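A quick eager-mode check of the numbers quoted in the removed softmax comment; the reshape merely emulates the pre-opset-13 ONNX coercion and is not part of the exporter code:

```python
import torch

x = torch.ones(2, 3)
# PyTorch dim=0 normalizes each column: every entry becomes 0.5.
print(torch.softmax(x, dim=0))
# ONNX Softmax with axis=0 first coerces the input to shape [1, 6],
# so every entry becomes 1/6 (~0.167), as described above.
print(torch.softmax(x.reshape(1, -1), dim=1).reshape(2, 3))
```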
- # When input rank is not known at export time we compute softmax using a subgraph - # with other operators - input_dim = symbolic_helper._get_tensor_rank(input) - if input_dim is not None: - # TODO: remove this as onnx opset 11 spec allows negative axes - if dim < 0: - dim = input_dim + dim - - is_transpose_required = input_dim != dim + 1 - - if is_transpose_required: - axes = list(range(input_dim)) - axes[dim], axes[-1] = axes[-1], axes[dim] - input = g.op("Transpose", input, perm_i=axes) - dim = input_dim - 1 - - softmax = g.op("Softmax", input, axis_i=dim) - if dtype and dtype.node().kind() != "prim::Constant": - parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype") - softmax = g.op( - "Cast", - softmax, - to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type(), - ) - - if is_transpose_required: - softmax = g.op("Transpose", softmax, perm_i=axes) # type: ignore[possibly-undefined] - return softmax - - # Apply max normalization. - input = g.op("Sub", input, g.op("ReduceMax", input, axes_i=[dim], keepdims_i=1)) - - exp = g.op("Exp", input) - sum = symbolic_helper._reducesum_helper(g, exp, axes_i=[dim]) - softmax = g.op("Div", exp, sum) - if dtype and dtype.node().kind() != "prim::Constant": - parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype") - softmax = g.op( - "Cast", softmax, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type() - ) - return softmax - - -@_onnx_symbolic("aten::softplus") -def softplus(g: jit_utils.GraphContext, self, beta, threshold): - beta_const = symbolic_helper._maybe_get_const(beta, "f") - if beta_const != 1: - return g.op("Div", g.op("Softplus", g.op("Mul", self, beta)), beta) - return g.op("Softplus", self) - - -@_onnx_symbolic("aten::get_pool_ceil_padding") -def get_pool_ceil_padding(input, kernel_size, stride, padding): - # TODO(justinchuby): Looks like this op is deprecated in torch - sizes = symbolic_helper._get_tensor_sizes(input) - dim = sizes[-len(padding) :] if sizes is not None else None - if dim is None or any(i is None for i in dim): - return symbolic_helper._unimplemented( - "get_pool_ceil_padding", "input size not accessible", input - ) - ceiled_output_dim = [ - int(math.ceil((dim[i] + 2 * padding[i] - kernel_size[i]) / float(stride[i]))) - + 1 - for i in range(0, len(padding)) - ] - # ensure last pooling starts inside - ceiled_output_dim = [ - ( - ceiled_output_dim[i] - 1 - if (((ceiled_output_dim[i] - 1) * stride[i]) >= (dim[i] + padding[i])) - else ceiled_output_dim[i] - ) - for i in range(0, len(ceiled_output_dim)) - ] - padding_ceil = [ - ( - 0 - if (stride[i] == 1) - else ( - kernel_size[i] - - ( - dim[i] - + 2 * padding[i] - - ((ceiled_output_dim[i] - 1) * stride[i] + 1) - ) - ) - ) - for i in range(0, len(padding)) - ] - # ensure padding is not > kernel_size - padding_ceil = [ - ( - ( - int(padding_ceil[i]) - if padding_ceil[i] < kernel_size[i] - 1 - else int(kernel_size[i] - 1) - ) - if ((padding_ceil[i] + 2 * padding[i]) >= (kernel_size[i])) - else int(padding_ceil[i]) - ) - for i in range(0, len(padding_ceil)) - ] - return padding_ceil - - -@_onnx_symbolic( - "aten::max_pool1d", - decorate=[ - symbolic_helper._apply_params( - "max_pool1d", torch.nn.modules.utils._single, 1, return_indices=False - ), - _export("max_pool1d"), - ], -) -@_onnx_symbolic( - "aten::max_pool2d", - decorate=[ - symbolic_helper._apply_params( - "max_pool2d", torch.nn.modules.utils._pair, 2, return_indices=False - ), - _export("max_pool2d"), - ], -) -@_onnx_symbolic( - "aten::max_pool3d", - decorate=[ - symbolic_helper._apply_params( - 
"max_pool3d", torch.nn.modules.utils._triple, 3, return_indices=False - ), - _export("max_pool3d"), - ], -) -def _max_pool(name, tuple_fn, ndims, return_indices): - @symbolic_helper.quantized_args(True, False, False, False, False, False) - @symbolic_helper.parse_args("v", "is", "is", "is", "is", "i") - def symbolic_fn(g, input, kernel_size, stride, padding, dilation, ceil_mode): - if set(tuple_fn(dilation)) != {1}: - return symbolic_helper._unimplemented(name, "dilation", input) - if not stride: - stride = kernel_size - padding = tuple(tuple_fn(padding)) - if ceil_mode: - padding_ceil = get_pool_ceil_padding(input, kernel_size, stride, padding) - padding = padding + tuple(a + b for (a, b) in zip(padding_ceil, padding)) - else: - padding = padding * 2 - kwargs = { - "kernel_shape_i": tuple_fn(kernel_size), - "pads_i": padding, - "strides_i": tuple_fn(stride), - } - # easy but hacky way to get flattened indices values - # to be used to convert the indices values to non-flattened. - # In ONNX the indices are computed as a flatten 1-D tensor, - # so the values in indices are in [0, N x C x D1 x ... x Dn). - # To convert the indices to the same format used by Pytorch, - # we first execute a maxpool with a kernel and stride of 1 on the same input. - # This will result in a tensor of indices in which each index will have it's own value. - # Using this tensor as a reference, we extract the first index of each axis and subtract - # it from each index of this axis in the indices to convert. - # This step will result in a tensor were each dimension has values of indices within - # the dimension it is in. - # For more information : - # https://github.com/pytorch/pytorch/pull/16455#issuecomment-460776407 - if return_indices: - r, indices = g.op("MaxPool", input, outputs=2, **kwargs) - _, flattened_indices = g.op( - "MaxPool", - input, - outputs=2, - kernel_shape_i=[1 for _ in range(ndims)], - strides_i=[1 for _ in range(ndims)], - ) - # convert indices to have non-flattened indices values - s = symbolic_helper._slice_helper( - g, - flattened_indices, - axes=[2 + i for i in range(ndims)], - starts=list(tuple_fn(0)), - ends=list(tuple_fn(1)), - ) - indices = sub(g, indices, s) - return r, indices - else: - r = g.op("MaxPool", input, outputs=1, **kwargs) - return r - - return symbolic_fn - - -max_pool1d_with_indices = _onnx_symbolic("aten::max_pool1d_with_indices")( - _max_pool( - "max_pool1d_with_indices", - torch.nn.modules.utils._single, - 1, - return_indices=True, - ) -) -max_pool2d_with_indices = _onnx_symbolic("aten::max_pool2d_with_indices")( - _max_pool( - "max_pool2d_with_indices", - torch.nn.modules.utils._pair, - 2, - return_indices=True, - ) -) -max_pool3d_with_indices = _onnx_symbolic("aten::max_pool3d_with_indices")( - _max_pool( - "max_pool3d_with_indices", - torch.nn.modules.utils._triple, - 3, - return_indices=True, - ) -) - - -@_onnx_symbolic( - "aten::avg_pool1d", - decorate=[ - symbolic_helper._apply_params("avg_pool1d", torch.nn.modules.utils._single), - _export("avg_pool1d"), - ], -) -@_onnx_symbolic( - "aten::avg_pool2d", - decorate=[ - symbolic_helper._apply_params("avg_pool2d", torch.nn.modules.utils._pair), - _export("avg_pool2d"), - ], -) -@_onnx_symbolic( - "aten::avg_pool3d", - decorate=[ - symbolic_helper._apply_params("avg_pool3d", torch.nn.modules.utils._triple), - _export("avg_pool3d"), - ], -) -def _avg_pool(name, tuple_fn): - @symbolic_helper.quantized_args(True) - @symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none") - def symbolic_fn( - g, - input: 
_C.Value, - kernel_size: Sequence[int], - stride: Sequence[int], - padding: int | Sequence[int], - ceil_mode: int, - count_include_pad: int, - divisor_override=None, - ): - if not stride: - stride = kernel_size - padding = symbolic_helper._avgpool_helper( - tuple_fn, padding, kernel_size, stride, divisor_override, name - ) - assert isinstance(padding, tuple) - adjusted_padding = padding - # Although onnx::AvgPool provides count_include_pad, - # The corner case of Average Pooling with ceil_mode on - # PyTorch allows sliding window go off bound, which leads to - # this accommodation. - # More detail on https://github.com/pytorch/pytorch/issues/57178 - if count_include_pad: - input = symbolic_helper._op_with_optional_float_cast( - g, - "Pad", - input, - pads_i=((0,) * 2 + padding) * 2, - mode_s="constant", - value_f=0.0, - opset_before=11, - ) - adjusted_padding = (0,) * len(padding) - if ceil_mode: - padding_ceil = get_pool_ceil_padding(input, kernel_size, stride, padding) - adjusted_padding = adjusted_padding + tuple( - a + b for (a, b) in zip(padding_ceil, adjusted_padding) - ) - else: - adjusted_padding = adjusted_padding * 2 - output = g.op( - "AveragePool", - input, - kernel_shape_i=tuple_fn(kernel_size), - strides_i=tuple_fn(stride), - pads_i=adjusted_padding, - ) - return output - - return symbolic_fn - - -@_onnx_symbolic( - "aten::adaptive_avg_pool1d", - decorate=[ - symbolic_helper._apply_params( - "adaptive_avg_pool1d", "AveragePool", torch.nn.modules.utils._single - ), - _export("adaptive_avg_pool1d"), - ], -) -@_onnx_symbolic( - "aten::adaptive_avg_pool2d", - decorate=[ - symbolic_helper._apply_params( - "adaptive_avg_pool2d", "AveragePool", torch.nn.modules.utils._pair - ), - _export("adaptive_avg_pool2d"), - ], -) -@_onnx_symbolic( - "aten::adaptive_avg_pool3d", - decorate=[ - symbolic_helper._apply_params( - "adaptive_avg_pool3d", "AveragePool", torch.nn.modules.utils._triple - ), - _export("adaptive_avg_pool3d"), - ], -) -@_onnx_symbolic( - "aten::adaptive_max_pool1d", - decorate=[ - symbolic_helper._apply_params( - "adaptive_max_pool1d", - "MaxPool", - torch.nn.modules.utils._single, - max_pool1d_with_indices, - ), - _export("adaptive_max_pool1d"), - ], -) -@_onnx_symbolic( - "aten::adaptive_max_pool2d", - decorate=[ - symbolic_helper._apply_params( - "adaptive_max_pool2d", - "MaxPool", - torch.nn.modules.utils._pair, - max_pool2d_with_indices, - ), - _export("adaptive_max_pool2d"), - ], -) -@_onnx_symbolic( - "aten::adaptive_max_pool3d", - decorate=[ - symbolic_helper._apply_params( - "adaptive_max_pool3d", - "MaxPool", - torch.nn.modules.utils._triple, - max_pool3d_with_indices, - ), - _export("adaptive_max_pool3d"), - ], -) -def _adaptive_pool(name, type, tuple_fn, fn=None): - @symbolic_helper.quantized_args(True, False) - def symbolic_fn(g, input, output_size): - # _adaptive_pool is supported for cases where output_size is 1 for all dimensions, - # by executing a GlobalPool. - # It is also supported for cases where the output size is a factor of the input size. - # For these cases the stride and kernel size are uniform along all the indices of - # the same dimension, which makes it possible to export it to ONNX. 
- # for MaxPool, GlobalMaxPool does not return indices, - # so we try using max_poolxd_with_indices, and if it is not possible - # (input is not a complete tensor or output size not factor of input size) - # then we call GlobalAveragePool and return None for the indices - output_size_value = output_size - try: - output_size = symbolic_helper._parse_arg(output_size, "is") - except Exception: - # FIXME(justinchuby): Avoid catching Exception. - # Catch a more specific exception instead. - return symbolic_helper._onnx_unsupported( - "adaptive pooling, since output_size is not constant.", input - ) - if output_size == [1] * len(output_size) and type == "AveragePool": - return g.op("GlobalAveragePool", input) - sizes = symbolic_helper._get_tensor_sizes(input) - try: - dim = sizes[2:] - except Exception: - # FIXME(justinchuby): Avoid catching Exception. - # Catch a more specific exception instead. - dim = None - if dim is None or any(i is None for i in dim): - if output_size == [1] * len(output_size): - return g.op("GlobalMaxPool", input), None - return symbolic_helper._unimplemented( - name, "input size not accessible", input - ) - # verify if output size % input size = 0 for all dim - mod = [dim[i] % output_size[i] for i in range(0, len(dim))] - if mod != [0] * len(mod): - if output_size == [1] * len(output_size): - return g.op("GlobalMaxPool", input), None - return symbolic_helper._unimplemented( - name, "output size that are not factor of input size", output_size_value - ) - k = [int(dim[i] / output_size[i]) for i in range(0, len(dim))] - # call max_poolxd_with_indices to get indices in the output - if type == "MaxPool": - return fn(g, input, k, k, (0,) * len(dim), (1,) * len(dim), False) - output = g.op(type, input, kernel_shape_i=tuple_fn(k), strides_i=tuple_fn(k)) - return output - - return symbolic_fn - - -def _prepare_onnx_paddings(dim: int, pad): - """Generate paddings in ONNX order based on pad in pytorch. - Args: - dim: the dimension of the tensor. - pad: the paddings in pytorch. - The order is dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, ... - """ - # The desired order of paddings is - # dim_0_begin, dim_1_begin, ... , dim_0_end, ..., dim_n_end. - # n is the dimension of input. - # assume zero-dimensions in the beginning - paddings = list(pad[:]) + [0] * (dim * 2 - len(pad)) - # reverse order and collate first beginnings and then ends - paddings = paddings[-2::-2] + paddings[-1::-2] - return paddings - - -def _convert_padding_node(input): - padding = symbolic_helper._maybe_get_const(input, "is") - if symbolic_helper._is_value(padding) and symbolic_helper._is_packed_list(padding): - input_list = symbolic_helper._unpack_list(padding) - try: - padding = [ - symbolic_helper._get_const(v, "i", "padding") for v in input_list - ] - except Exception: - # FIXME(justinchuby): Avoid catching Exception. - # Catch a more specific exception instead. - return symbolic_helper._onnx_opset_unsupported_detailed( - "Pad", 9, 11, "The sizes of the padding must be constant", input - ) - return padding - - -@_onnx_symbolic("aten::constant_pad_nd") -def constant_pad_nd(g: jit_utils.GraphContext, input, padding, value): - mode = "constant" - try: - value = symbolic_helper._get_const(value, "f", "value") - except Exception: - # FIXME(justinchuby): Avoid catching Exception. - # Catch a more specific exception instead. 
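A small worked example of the `_prepare_onnx_paddings` reordering defined above, for an assumed 4-d input (values chosen for illustration only):

```python
# PyTorch pads are ordered (last_begin, last_end, next_begin, next_end, ...);
# ONNX Pad expects (dim0_begin, ..., dimN_begin, dim0_end, ..., dimN_end).
pad = [1, 2, 3, 4]   # pad the last dim by (1, 2) and the second-to-last by (3, 4)
dim = 4
paddings = list(pad) + [0] * (dim * 2 - len(pad))  # [1, 2, 3, 4, 0, 0, 0, 0]
paddings = paddings[-2::-2] + paddings[-1::-2]     # begins first, then ends
assert paddings == [0, 0, 3, 1, 0, 0, 4, 2]
```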
- return symbolic_helper._onnx_opset_unsupported_detailed( - "Pad", 9, 11, "The value for the padding must be constant", value - ) - - padding = _convert_padding_node(padding) - paddings = _prepare_onnx_paddings(symbolic_helper._get_tensor_rank(input), padding) - return symbolic_helper._op_with_optional_float_cast( - g, "Pad", input, pads_i=paddings, mode_s=mode, value_f=value, opset_before=11 - ) - - -def _pad_circular(g: jit_utils.GraphContext, input: _C.Value, pad: _C.Value): - padding = _convert_padding_node(pad) - assert len(padding) % 2 == 0 - ndim = len(padding) // 2 - - cur = input - for idx in range(ndim): - pad_r = padding[-(2 * idx + 1)] - pad_l = padding[-(2 * idx + 2)] - tensors = [] - if pad_l > 0: - left = symbolic_helper._slice_helper( - g, cur, axes=[2 + idx], starts=[-(pad_l)], ends=[_constants.INT64_MAX] - ) - tensors.append(left) - - if pad_l < 0 or pad_r < 0: - start = builtins.max(0, -pad_l) - end = -(builtins.max(0, -pad_r)) - middle = symbolic_helper._slice_helper( - g, - cur, - axes=[2 + idx], - starts=[start], - ends=[end], - ) - tensors.append(middle) - else: - tensors.append(cur) - - if pad_r > 0: - right = symbolic_helper._slice_helper( - g, cur, axes=[2 + idx], starts=[0], ends=[pad_r] - ) - tensors.append(right) - - cur = g.op("Concat", *tensors, axis_i=(2 + idx)) - - return cur - - -@_onnx_symbolic("aten::reflection_pad1d") -@_onnx_symbolic("aten::reflection_pad2d") -@_onnx_symbolic("aten::reflection_pad3d") -def reflection_pad(g: jit_utils.GraphContext, input, padding): - mode = "reflect" - padding = _convert_padding_node(padding) - paddings = _prepare_onnx_paddings(symbolic_helper._get_tensor_rank(input), padding) - return symbolic_helper._op_with_optional_float_cast( - g, "Pad", input, pads_i=paddings, mode_s=mode, opset_before=11 - ) - - -@_onnx_symbolic("aten::replication_pad1d") -@_onnx_symbolic("aten::replication_pad2d") -@_onnx_symbolic("aten::replication_pad3d") -def replication_pad(g: jit_utils.GraphContext, input, padding): - mode = "edge" - padding = _convert_padding_node(padding) - paddings = _prepare_onnx_paddings(symbolic_helper._get_tensor_rank(input), padding) - return symbolic_helper._op_with_optional_float_cast( - g, "Pad", input, pads_i=paddings, mode_s=mode, opset_before=11 - ) - - -@_onnx_symbolic("aten::pad") -def pad( - g: jit_utils.GraphContext, - input: _C.Value, - pad: _C.Value, - mode: _C.Value, - value: _C.Value, -): - mode = symbolic_helper._parse_arg(mode, "s") - if mode == "replicate": - return replication_pad(g, input, pad) - elif mode == "reflect": - return reflection_pad(g, input, pad) - elif mode == "constant": - return constant_pad_nd(g, input, pad, value) - elif mode == "circular": - return _pad_circular(g, input, pad) - else: - raise errors.SymbolicValueError(f"Unrecognized padding mode {mode}", input) - - -@_onnx_symbolic( - "aten::upsample_nearest1d", - decorate=[ - symbolic_helper._apply_params("upsample_nearest1d", 3, "nearest"), - _export("upsample_nearest1d"), - ], -) -@_onnx_symbolic( - "aten::upsample_nearest2d", - decorate=[ - symbolic_helper._apply_params("upsample_nearest2d", 4, "nearest"), - _export("upsample_nearest2d"), - ], -) -@_onnx_symbolic( - "aten::upsample_nearest3d", - decorate=[ - symbolic_helper._apply_params("upsample_nearest3d", 5, "nearest"), - _export("upsample_nearest3d"), - ], -) -@_onnx_symbolic( - "aten::upsample_linear1d", - decorate=[ - symbolic_helper._apply_params("upsample_linear1d", 3, "linear"), - _export("upsample_linear1d"), - ], -) -@_onnx_symbolic( - 
"aten::upsample_bilinear2d", - decorate=[ - symbolic_helper._apply_params("upsample_bilinear2d", 4, "linear"), - _export("upsample_bilinear2d"), - ], -) -@_onnx_symbolic( - "aten::upsample_trilinear3d", - decorate=[ - symbolic_helper._apply_params("upsample_trilinear3d", 5, "linear"), - _export("upsample_trilinear3d"), - ], -) -def _interpolate(name: str, dim: int, interpolate_mode: str): - def symbolic_fn(g, input, output_size, *args): - scales, align_corners = symbolic_helper._get_interpolate_attributes( - g, interpolate_mode, args - ) - symbolic_helper._interpolate_warning(interpolate_mode) - align_corners = symbolic_helper._maybe_get_scalar(align_corners) - if align_corners: - return symbolic_helper._unimplemented(name, "align_corners == True", input) - if scales is None: - scales = symbolic_helper._interpolate_size_to_scales( - g, input, output_size, dim - ) - return g.op("Upsample", input, scales, mode_s=interpolate_mode) - - return symbolic_fn - - -@_onnx_symbolic("aten::__interpolate") -def __interpolate( - g: jit_utils.GraphContext, - input, - size, - scale_factor, - mode, - align_corners, - recompute_scale_factor, - antialias, -): - scales, mode = symbolic_helper._interpolate_get_scales_and_mode( - g, input, size, scale_factor, mode, align_corners - ) - return g.op("Upsample", input, scales, mode_s=mode) - - -@_onnx_symbolic("aten::bitwise_not") -def bitwise_not(g: jit_utils.GraphContext, input): - if not symbolic_helper._is_bool(input): - raise errors.SymbolicValueError( - "ONNX export does NOT support exporting bitwise Not " - "for non-boolean input values", - input, - ) - return g.op("Not", input) - - -@_onnx_symbolic("aten::bitwise_or") -def bitwise_or(g, self, other): - if not symbolic_helper._is_bool(self): - raise errors.SymbolicValueError( - "ONNX export does NOT support exporting bitwise OR " - "for non-boolean input values. self: ", - self, - ) - if not symbolic_helper._is_bool(other): - raise errors.SymbolicValueError( - "ONNX export does NOT support exporting bitwise OR " - "for non-boolean input values. other: ", - other, - ) - return g.op("Or", self, other) - - -def wrap_logical_op_with_cast_to(to_type): - def decorator(fn): - @functools.wraps(fn) - def wrap_with_cast(g, input, other): - to_cast_func = globals()[f"_cast_{to_type}"] - return fn(g, to_cast_func(g, input, False), to_cast_func(g, other, False)) - - return wrap_with_cast - - return decorator - - -def wrap_logical_op_with_negation(func: Callable) -> Callable: - @functools.wraps(func) - def wrap_with_not(g, input, other): - return g.op("Not", func(g, input, other)) - - return wrap_with_not - - -@_onnx_symbolic("aten::__not_") -def __not_(g: jit_utils.GraphContext, self): - if not symbolic_helper._is_bool(self): - raise errors.SymbolicValueError( - "ONNX export does NOT support exporting bitwise Not " - "for non-boolean input values", - self, - ) - return g.op("Not", self) - - -@_onnx_symbolic("aten::eq") -@symbolic_helper.quantized_args(True, True) -def eq(g: jit_utils.GraphContext, self, other): - if isinstance(self.type(), _C.DeviceObjType) and isinstance( - other.type(), _C.DeviceObjType - ): - # ONNX doesn't have devices, so consider them all to be equal. - # The no-op check for equality will get constant-folded. 
- return g.op("Constant", value_t=torch.tensor(True, dtype=torch.bool)) - self_node = self.node() - other_node = other.node() - if self_node.kind() == other_node.kind() == "onnx::Constant": - if self_node.kindOf("value") == other_node.kindOf("value") == "s": - # Exporting strings to ONNX is not supported. - # If both strings are constant, we can compare them directly. - # The no-op check for equality will get constant-folded. - return g.op( - "Constant", - value_t=torch.tensor( - self_node.s("value") == other_node.s("value"), - dtype=torch.bool, - ), - ) - - return g.op("Equal", self, other) - - -@_onnx_symbolic("aten::ne") -@symbolic_helper.quantized_args(True, True) -@wrap_logical_op_with_negation -def ne(g: jit_utils.GraphContext, self, other): - return eq(g, self, other) - - -@_onnx_symbolic("aten::gt") -@symbolic_helper.quantized_args(True, True) -def gt(g: jit_utils.GraphContext, input, other): - return _gt_impl(g, input, other) - - -def _gt_impl(g: jit_utils.GraphContext, input, other): - if symbolic_helper._is_bool(input) and symbolic_helper._is_bool(other): - input = g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT32) - other = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.INT32) - return g.op("Greater", input, other) - - -@_onnx_symbolic("aten::lt") -@symbolic_helper.quantized_args(True, True) -def lt(g: jit_utils.GraphContext, input, other): - return _lt_impl(g, input, other) - - -def _lt_impl(g: jit_utils.GraphContext, input, other): - if symbolic_helper._is_bool(input) and symbolic_helper._is_bool(other): - input = g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT32) - other = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.INT32) - return g.op("Less", input, other) - - -@_onnx_symbolic("aten::ge") -@symbolic_helper.quantized_args(True, True) -@wrap_logical_op_with_negation -def ge(g: jit_utils.GraphContext, input, other): - return _lt_impl(g, input, other) - - -@_onnx_symbolic("aten::le") -@symbolic_helper.quantized_args(True, True) -@wrap_logical_op_with_negation -def le(g: jit_utils.GraphContext, input, other): - return _gt_impl(g, input, other) - - -@_onnx_symbolic("aten::__and_") -def __and_(g: jit_utils.GraphContext, input, other): - if not symbolic_helper._is_bool(input): - raise errors.SymbolicValueError( - "ONNX export does NOT support exporting bitwise AND " - "for non-boolean input values", - input, - ) - if not symbolic_helper._is_bool(other): - raise errors.SymbolicValueError( - "ONNX export does NOT support exporting bitwise AND " - "for non-boolean input values", - other, - ) - return g.op("And", input, other) - - -@_onnx_symbolic("aten::__or_") -def __or_(g: jit_utils.GraphContext, input, other): - if not symbolic_helper._is_bool(input): - raise errors.SymbolicValueError( - "ONNX export does NOT support exporting bitwise OR " - "for non-boolean input values", - input, - ) - if not symbolic_helper._is_bool(other): - raise errors.SymbolicValueError( - "ONNX export does NOT support exporting bitwise OR " - "for non-boolean input values", - other, - ) - return g.op("Or", input, other) - - -@_onnx_symbolic("aten::__xor_") -def __xor_(g: jit_utils.GraphContext, input, other): - if not symbolic_helper._is_bool(input): - raise errors.SymbolicValueError( - "ONNX export does NOT support exporting bitwise XOR " - "for non-boolean input values", - input, - ) - if not symbolic_helper._is_bool(other): - raise errors.SymbolicValueError( - "ONNX export does NOT support exporting bitwise XOR " - "for non-boolean input values", - other, - ) - return 
g.op("Xor", input, other) - - -@_onnx_symbolic("aten::logical_and") -@wrap_logical_op_with_cast_to("Bool") -def logical_and(g: jit_utils.GraphContext, input, other): - return g.op("And", input, other) - - -@_onnx_symbolic("aten::logical_or") -@wrap_logical_op_with_cast_to("Bool") -def logical_or(g: jit_utils.GraphContext, input, other): - return g.op("Or", input, other) - - -@_onnx_symbolic("aten::logical_xor") -@wrap_logical_op_with_cast_to("Bool") -def logical_xor(g: jit_utils.GraphContext, input, other): - return g.op("Xor", input, other) - - -@_onnx_symbolic("aten::logical_not") -def logical_not(g: jit_utils.GraphContext, input): - return g.op("Not", g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.BOOL)) - - -@_onnx_symbolic("aten::__rshift_") -def __rshift_(g: jit_utils.GraphContext, self, other): - # make sure to cast other to self's type - # (when self is long, make sure that other is not float) - self_scalar_type = _type_utils.JitScalarType.from_value(self) - if ( - _type_utils.JitScalarType.from_value(other, _type_utils.JitScalarType.UNDEFINED) - != self_scalar_type - ): - other = g.op( - "Cast", - other, - to_i=self_scalar_type.onnx_type(), - ) - - two = g.op("Constant", value_t=torch.tensor(2, dtype=torch.float32)) - # exponent (same type as self) has to be float or double in onnx::Pow - if not symbolic_helper._is_fp(self): - other = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.FLOAT) - two_pow = g.op("Pow", two, other) - two_pow = g.op( - "Cast", - two_pow, - to_i=self_scalar_type.onnx_type(), - ) - rshift = g.op("Div", self, two_pow) - return rshift - - -@_onnx_symbolic("aten::__lshift_") -def __lshift_(g: jit_utils.GraphContext, self, other): - # make sure to cast other to self's type - # (when self is long, make sure that other is not float) - self_scalar_type = _type_utils.JitScalarType.from_value(self) - if ( - _type_utils.JitScalarType.from_value(other, _type_utils.JitScalarType.UNDEFINED) - != self_scalar_type - ): - other = g.op( - "Cast", - other, - to_i=self_scalar_type.onnx_type(), - ) - - two = g.op("Constant", value_t=torch.tensor(2, dtype=torch.float32)) - # exponent (same type as self) has to be float or double in onnx::Pow - if not symbolic_helper._is_fp(self): - other = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.FLOAT) - two_pow = g.op("Pow", two, other) - two_pow = g.op( - "Cast", - two_pow, - to_i=self_scalar_type.onnx_type(), - ) - lshift = g.op("Mul", self, two_pow) - return lshift - - -@_onnx_symbolic("aten::where") -@symbolic_helper.parse_args("v", "v", "v", "i") -def where(g: jit_utils.GraphContext, condition, self=None, other=None, _outputs=None): - # Assumes that torch.where's first argument takes only Bool and Byte tensors. - if not symbolic_helper._is_bool(condition): - condition = g.op("Cast", condition, to_i=_C_onnx.TensorProtoDataType.BOOL) - if self is None: - condition = nonzero(g, condition) - return symbolic_helper._unbind_helper( - g, condition, g.op("Constant", value_t=torch.tensor(1)), _outputs - ) - return g.op("Where", condition, self, other) - - -@_onnx_symbolic("aten::log_softmax") -@symbolic_helper.parse_args("v", "i", "none") -def log_softmax(g: jit_utils.GraphContext, input, dim, dtype=None): - # PyTorch dim and ONNX axis have different meanings. - # See Softmax comment for details. 
- # TODO: remove this as onnx opset 11 spec allows negative axes - input_dim = symbolic_helper._get_tensor_rank(input) - if input_dim is None: - return symbolic_helper._unimplemented( - "dim", - "ONNX and PyTorch use different strategies to split the input. " - "Input rank must be known at export time.", - ) - if dim < 0: - dim = input_dim + dim - is_transpose_required = input_dim != dim + 1 - # ONNX only supports log_softmax with dim = -1. Transpose must be added before and after log_softmax to support other cases. - if is_transpose_required: - axes = list(range(input_dim)) - axes[dim], axes[-1] = axes[-1], axes[dim] - input = g.op("Transpose", input, perm_i=axes) - dim = input_dim - 1 - return_op = g.op("LogSoftmax", input, axis_i=dim) - if dtype and dtype.node().kind() != "prim::Constant": - parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype") - return_op = g.op( - "Cast", return_op, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type() - ) - if is_transpose_required: - return_op = g.op("Transpose", return_op, perm_i=axes) # type: ignore[possibly-undefined] - return return_op - - -@_onnx_symbolic("aten::_log_softmax") -@symbolic_helper.parse_args("v", "i", "i") -def _log_softmax(g: jit_utils.GraphContext, input, dim, half_to_float): - if ( - half_to_float - and _type_utils.JitScalarType.from_value( - input, _type_utils.JitScalarType.UNDEFINED - ) - == _type_utils.JitScalarType.HALF - ): - input = g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.FLOAT) - return log_softmax(g, input, dim) - - -@_onnx_symbolic("aten::_convolution") -@symbolic_helper.parse_args( - "v", "v", "v", "is", "is", "is", "i", "is", "i", "i", "i", "i", "i" -) -def _convolution( - g: jit_utils.GraphContext, - input, - weight, - bias, - stride, - padding, - dilation, - transposed, - output_padding, - groups, - benchmark, - deterministic, - cudnn_enabled, - allow_tf32=None, -): - weight_size = symbolic_helper._get_tensor_sizes(weight) - try: - kernel_shape = weight_size[2:] - except Exception: - # FIXME(justinchuby): Avoid catching Exception. - # Catch a more specific exception instead. - kernel_shape = None - - if kernel_shape is None or any(i is None for i in kernel_shape): - raise errors.SymbolicValueError( - "Unsupported: ONNX export of convolution for kernel of unknown shape.", - input, - ) - - args = [input, weight] - # ONNX only supports 1D bias - if ( - not symbolic_helper._is_none(bias) - and symbolic_helper._get_tensor_rank(bias) == 1 - ): - args.append(bias) - - kwargs = { - "kernel_shape_i": weight_size[2:], - "strides_i": stride, - # NB: ONNX supports asymmetric padding, whereas PyTorch supports only - # symmetric padding - "pads_i": padding + padding, - "dilations_i": dilation, - "group_i": groups, - } - - if any(o != 0 for o in output_padding): - # ONNX supports both output_shape and output_padding. they are equivalent expressive. - # output_padding is more straightforward, so we use it here. 
- # output_shape = stride * (input_shape - 1) + output_padding + kernel_shape - padding * 2 - assert transposed - assert len(stride) == len(output_padding) - kwargs["output_padding_i"] = output_padding - - n = g.op("ConvTranspose" if transposed else "Conv", *args, **kwargs) - - if ( - not symbolic_helper._is_none(bias) - and symbolic_helper._get_tensor_rank(bias) != 1 - ): - return g.op("Add", n, bias) - else: - return n - - -@_onnx_symbolic("aten::_convolution_mode") -@symbolic_helper.parse_args( - "v", - "v", - "v", - "is", - "s", - "is", - "i", -) -def _convolution_mode( - g: jit_utils.GraphContext, - input, - weight, - bias, - stride, - padding, - dilation, - groups, -): - weight_size = symbolic_helper._get_tensor_sizes(weight) - try: - kernel_shape = weight_size[2:] - except Exception: - # FIXME(justinchuby): Avoid catching Exception. - # Catch a more specific exception instead. - kernel_shape = None - - if kernel_shape is None or any(i is None for i in kernel_shape): - raise errors.SymbolicValueError( - "Unsupported: ONNX export of convolution for kernel of unknown shape.", - input, - ) - - args = [input, weight] - # ONNX only supports 1D bias - if ( - not symbolic_helper._is_none(bias) - and symbolic_helper._get_tensor_rank(bias) == 1 - ): - args.append(bias) - - if padding == "valid": - padding = "VALID" - elif padding == "same": - padding = "SAME_UPPER" - kwargs = { - "kernel_shape_i": weight_size[2:], - "strides_i": stride, - "auto_pad_s": padding, - "dilations_i": dilation, - "group_i": groups, - } - - n = g.op("Conv", *args, **kwargs) - - if ( - not symbolic_helper._is_none(bias) - and symbolic_helper._get_tensor_rank(bias) != 1 - ): - return g.op("Add", n, bias) - else: - return n - - -@_onnx_symbolic("aten::convolution") -@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "is", "i") -def convolution( - g: jit_utils.GraphContext, - input, - weight, - bias, - stride, - padding, - dilation, - transposed, - output_padding, - groups, -): - return _convolution( - g, - input, - weight, - bias, - stride, - padding, - dilation, - transposed, - output_padding, - groups, - None, - None, - None, - None, - ) - - -@_onnx_symbolic("aten::conv1d") -@symbolic_helper.parse_args("v", "v", "v", "is", "v", "is", "i") -def conv1d( - g: jit_utils.GraphContext, input, weight, bias, stride, padding, dilation, groups -): - str_padding = symbolic_helper._parse_arg(padding, "s") - if str_padding in ["valid", "same"]: - return _convolution_mode( - g, - input, - weight, - bias, - stride, - str_padding, - dilation, - groups, - ) - else: - padding = symbolic_helper._parse_arg(padding, "is") - return _convolution( - g, - input, - weight, - bias, - stride, - padding, - dilation, - False, - (), - groups, - None, - None, - None, - None, - ) - - -@_onnx_symbolic("aten::conv2d") -@symbolic_helper.parse_args("v", "v", "v", "is", "v", "is", "i") -def conv2d( - g: jit_utils.GraphContext, input, weight, bias, stride, padding, dilation, groups -): - str_padding = symbolic_helper._parse_arg(padding, "s") - if str_padding in ["valid", "same"]: - return _convolution_mode( - g, - input, - weight, - bias, - stride, - str_padding, - dilation, - groups, - ) - else: - padding = symbolic_helper._parse_arg(padding, "is") - return _convolution( - g, - input, - weight, - bias, - stride, - padding, - dilation, - False, - (), - groups, - None, - None, - None, - None, - ) - - -@_onnx_symbolic("aten::conv3d") -@symbolic_helper.parse_args("v", "v", "v", "is", "v", "is", "i") -def conv3d( - g: jit_utils.GraphContext, 
input, weight, bias, stride, padding, dilation, groups -): - str_padding = symbolic_helper._parse_arg(padding, "s") - if str_padding in ["valid", "same"]: - return _convolution_mode( - g, - input, - weight, - bias, - stride, - str_padding, - dilation, - groups, - ) - else: - padding = symbolic_helper._parse_arg(padding, "is") - return _convolution( - g, - input, - weight, - bias, - stride, - padding, - dilation, - False, - (), - groups, - None, - None, - None, - None, - ) - - -@_onnx_symbolic("aten::conv_transpose1d") -@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "is") -def conv_transpose1d( - g: jit_utils.GraphContext, - input, - weight, - bias, - stride, - padding, - output_padding, - groups, - dilation, -): - return _convolution( - g, - input, - weight, - bias, - stride, - padding, - dilation, - True, - output_padding, - groups, - None, - None, - None, - None, - ) - - -@_onnx_symbolic("aten::conv_transpose2d") -@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "is") -def conv_transpose2d( - g: jit_utils.GraphContext, - input, - weight, - bias, - stride, - padding, - output_padding, - groups, - dilation, -): - return _convolution( - g, - input, - weight, - bias, - stride, - padding, - dilation, - True, - output_padding, - groups, - None, - None, - None, - None, - ) - - -@_onnx_symbolic("aten::conv_transpose3d") -@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "is") -def conv_transpose3d( - g: jit_utils.GraphContext, - input, - weight, - bias, - stride, - padding, - output_padding, - groups, - dilation, -): - return _convolution( - g, - input, - weight, - bias, - stride, - padding, - dilation, - True, - output_padding, - groups, - None, - None, - None, - None, - ) - - -@_onnx_symbolic("aten::batch_norm") -@symbolic_helper.parse_args("v", "v", "v", "v", "v", "i", "f", "f", "i") -def batch_norm( - g: jit_utils.GraphContext, - input, - weight, - bias, - running_mean, - running_var, - training, - momentum, - eps, - cudnn_enabled, -): - symbolic_helper.check_training_mode(training, "batch_norm") - - if ( - torch.is_autocast_enabled() - and not symbolic_helper.args_have_same_dtype( - [input, weight, bias, running_mean, running_var] - ) - and GLOBALS.export_onnx_opset_version < 15 - ): - return symbolic_helper._onnx_opset_unsupported_detailed( - "BatchNormalization", - 9, - 15, - "All input tensors must have the same `dtype`." 
- " Turn off Autocast or export using opset version 15.", - input, - ) - - weight, bias, running_mean, running_var = symbolic_helper._batchnorm_helper( - g, input, weight, bias, running_mean, running_var - ) - out = g.op( - "BatchNormalization", - input, - weight, - bias, - running_mean, - running_var, - epsilon_f=eps, - momentum_f=1 - momentum, - outputs=1 if not training else 5, - ) - if not training: - return out - else: - res, new_running_mean, new_running_var, saved_mean, saved_var = out - new_running_mean.setType(running_mean.type()) - new_running_var.setType(running_var.type()) - saved_mean.setDebugName("batch_norm_dead_output-" + saved_mean.debugName()) - saved_var.setDebugName("batch_norm_dead_output-" + saved_var.debugName()) - return res - - -@_onnx_symbolic("aten::native_layer_norm") -@symbolic_helper.quantized_args(True, False, False, False) -@symbolic_helper.parse_args("v", "is", "v", "v", "f") -def native_layer_norm( - g: jit_utils.GraphContext, - input: _C.Value, - normalized_shape: Sequence[int], - weight: _C.Value, - bias: _C.Value, - eps: float, -) -> tuple[_C.Value, _C.Value, _C.Value]: - axes = [-i for i in range(len(normalized_shape), 0, -1)] - - two_cst = symbolic_helper._generate_wrapped_number(g, 2.0) - eps_cst = symbolic_helper._generate_wrapped_number(g, eps) - - if g.opset < 18: - mean = g.op("ReduceMean", input, axes_i=axes) - else: - mean = g.op( - "ReduceMean", - input, - g.op("Constant", value_t=torch.tensor(axes, dtype=torch.long)), - ) - - numerator = sub(g, input, mean) - - # Cast it to eps dtype to avoid precision loss - is_type_half = ( - _type_utils.JitScalarType.from_value(numerator) - == _type_utils.JitScalarType.HALF - ) - if is_type_half: - eps_dtype = _type_utils.JitScalarType.from_value(eps_cst) - numerator = g.op( - "Cast", numerator, to_i=_type_utils.JitScalarType(eps_dtype).onnx_type() - ) - - # variance = e((x - e(x))^2), and (x - e(x)) is the numerator in the layer_norm formula - if g.opset < 18: - variance = g.op("ReduceMean", pow(g, numerator, two_cst), axes_i=axes) - else: - variance = g.op( - "ReduceMean", - pow(g, numerator, two_cst), - g.op("Constant", value_t=torch.tensor(axes, dtype=torch.long)), - ) - - denominator = sqrt(g, g.op("Add", variance, eps_cst)) - normalized = g.op("Div", numerator, denominator) - - # Cast back to input type as eps related ops are all done - if is_type_half: - input_dtype = _type_utils.JitScalarType.from_value(input) - normalized = g.op( - "Cast", normalized, to_i=_type_utils.JitScalarType(input_dtype).onnx_type() - ) - - if not (weight is None or symbolic_helper._is_none(weight)): - normalized = mul(g, normalized, weight) - if not (bias is None or symbolic_helper._is_none(bias)): - normalized = add(g, normalized, bias) - - # rdenominator := 1 / sqrt(variance + eps) - # According to aten::native_layer_norm, rdenominator should have the same dtype as input, - # mean and normalized, so we need to Cast it back - if is_type_half: - denominator = g.op( - "Cast", - denominator, - to_i=_type_utils.JitScalarType(input_dtype).onnx_type(), # type: ignore[possibly-undefined] - ) - rdenominator = g.op("Reciprocal", denominator) - else: - rdenominator = reciprocal(g, denominator) - - return normalized, mean, rdenominator - - -@_onnx_symbolic("aten::layer_norm") -@symbolic_helper.quantized_args(True, False, False, False) -@symbolic_helper.parse_args("v", "is", "v", "v", "f", "b") -def layer_norm( - g: jit_utils.GraphContext, - input: _C.Value, - normalized_shape: Sequence[int], - weight: _C.Value, - bias: _C.Value, - 
eps: float, - cudnn_enable: bool, -) -> _C.Value: - normalized, _, _ = native_layer_norm(g, input, normalized_shape, weight, bias, eps) - return normalized - - -@_onnx_symbolic("aten::instance_norm") -@symbolic_helper.parse_args("v", "v", "v", "v", "v", "b", "f", "f", "b") -def instance_norm( - g: jit_utils.GraphContext, - input, - weight, - bias, - running_mean, - running_var, - use_input_stats: bool, - momentum: Number, - eps: Number, - cudnn_enabled: bool, -): - symbolic_helper.check_training_mode(use_input_stats, "instance_norm") - channel_size = symbolic_helper._get_tensor_dim_size(input, 1) - if weight is None or symbolic_helper._is_none(weight): - if channel_size is None: - raise errors.SymbolicValueError( - "Unsupported: ONNX export of instance_norm for unknown channel size.", - input, - ) - weight_value = torch.tensor( - [1.0] * channel_size, - dtype=_type_utils.JitScalarType.from_value(input).dtype(), - ) - weight = g.op("Constant", value_t=weight_value) - if bias is None or symbolic_helper._is_none(bias): - if channel_size is None: - raise errors.SymbolicValueError( - "Unsupported: ONNX export of instance_norm for unknown channel size.", - input, - ) - bias_value = torch.tensor( - [0.0] * channel_size, - dtype=_type_utils.JitScalarType.from_value(input).dtype(), - ) - bias = g.op("Constant", value_t=bias_value) - if ( - running_mean is None - or symbolic_helper._is_none(running_mean) - or running_var is None - or symbolic_helper._is_none(running_var) - ): - return g.op("InstanceNormalization", input, weight, bias, epsilon_f=eps) - else: - input_size = symbolic_helper._get_tensor_sizes(input) - # If input shape is [N, C, H, W], reshape to [1, N * C, H, W] and call batch_norm. - # For more information instance_norm(): - # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Normalization.cpp#L542 - input_size_reshape = input_size.copy() - n = input_size[0] - if n is None: - raise errors.SymbolicValueError( - "Unsupported: ONNX export of instance_norm training for unknown " - "batch size.", - input, - ) - c = input_size[1] - input_size_reshape[0] = 1 - input_size_reshape[1] = n * c - weight_ = repeat( - g, weight, g.op("Constant", value_t=torch.tensor([n], dtype=torch.int64)) - ) - bias_ = repeat( - g, bias, g.op("Constant", value_t=torch.tensor([n], dtype=torch.int64)) - ) - running_mean_ = repeat( - g, - running_mean, - g.op("Constant", value_t=torch.tensor([n], dtype=torch.int64)), - ) - running_var_ = repeat( - g, - running_var, - g.op("Constant", value_t=torch.tensor([n], dtype=torch.int64)), - ) - input_reshaped = g.op( - "Reshape", - input, - g.op("Constant", value_t=torch.LongTensor(input_size_reshape)), - ) - out = batch_norm( - g, - input_reshaped, - weight_, - bias_, - running_mean_, - running_var_, - use_input_stats, - momentum, - eps, - cudnn_enabled, - ) - return view(g, out, g.op("Constant", value_t=torch.tensor(input_size))) - - -@_onnx_symbolic("aten::unfold") -@symbolic_helper.parse_args("v", "i", "i", "i") -def unfold(g: jit_utils.GraphContext, input, dimension, size, step): - sizes = symbolic_helper._get_tensor_sizes(input) - # FIXME(justinchuby): Get rid of the try catch here to improve readability - try: - sizedim = sizes[dimension] - except Exception: - # FIXME(justinchuby): Avoid catching Exception. - # Catch a more specific exception instead. 
- sizedim = None - if sizedim is not None: - low_indices = range(0, sizedim, step) - hi_indices = range(size, sizedim + 1, step) - stack = [ - symbolic_helper._slice_helper( - g, input, axes=[dimension], starts=[low], ends=[hi] - ) - for low, hi in zip(low_indices, hi_indices) - ] - ndim = len(sizes) - perm = list(range(0, ndim)) - perm.append(perm.pop(dimension)) - unsqueeze = [ - symbolic_helper._unsqueeze_helper( - g, g.op("Transpose", t, perm_i=perm), [dimension] - ) - for t in stack - ] - return g.op("Concat", *unsqueeze, axis_i=dimension) - else: - return symbolic_helper._unimplemented( - "Unfold", "input size not accessible", input - ) - - -@_onnx_symbolic("aten::elu") -@symbolic_helper.quantized_args(True) -@symbolic_helper.parse_args("v", "t", "t", "t") -def elu(g: jit_utils.GraphContext, input, alpha, scale, input_scale): - if scale and scale != 1.0: - return symbolic_helper._unimplemented( - "scale", "does not support scale in Elu", scale - ) - if input_scale and input_scale != 1.0: - return symbolic_helper._unimplemented( - "input_scale", "does not support input_scale in Elu", input_scale - ) - # See Note [Export inplace] - return g.op("Elu", input, alpha_f=symbolic_helper._scalar(alpha)) - - -@_onnx_symbolic("aten::selu") -@symbolic_helper.quantized_args(True) -def selu(g: jit_utils.GraphContext, input): - return g.op("Selu", input) - - -@_onnx_symbolic("aten::index_select") -@symbolic_helper.parse_args("v", "i", "v") -def index_select(g: jit_utils.GraphContext, self, dim, index): - # In case of a scalar index, index_select returns a tensor with the same rank as the input. - # To match this behavior in ONNX, we make index a 1D tensor so that the following gather - # also produces a tensor with the same rank as the input. - return symbolic_helper._select_helper(g, self, dim, index) - - -@_onnx_symbolic("aten::index_put") -def index_put(g: jit_utils.GraphContext, self, indices_list_value, values, accumulate): - if symbolic_helper._is_packed_list(indices_list_value): - indices_list = symbolic_helper._unpack_list(indices_list_value) - else: - indices_list = [indices_list_value] - - accumulate = symbolic_helper._parse_arg(accumulate, "b") - - if len(indices_list) == 0: - if accumulate: - return add(g, self, values) - return values - symbolic_helper._onnx_opset_unsupported("index_put", 9, 11, self) - - -@_onnx_symbolic("aten::index_fill") -def index_fill(g: jit_utils.GraphContext, self, dim, index, value): - expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper( - g, self, dim, index - ) - value = symbolic_helper._maybe_get_scalar(value) - value = symbolic_helper._if_scalar_type_as(value, self) - expanded_value = expand(g, value, expanded_index_shape, None) - - return scatter(g, self, dim, expanded_index, expanded_value) - - -@_onnx_symbolic("aten::index_copy") -def index_copy(g: jit_utils.GraphContext, self, dim, index, source): - _expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper( - g, self, dim, index - ) - return scatter(g, self, dim, expanded_index, source) - - -@_onnx_symbolic("aten::bucketize") -@symbolic_helper.parse_args("v", "v", "b", "b") -def bucketize( - g: jit_utils.GraphContext, self, boundaries, out_int32=False, right=False -): - out_type = _C_onnx.TensorProtoDataType.INT64 - if out_int32: - out_type = _C_onnx.TensorProtoDataType.INT32 - # A tensor expanded_boundaries is created such that it - # contains a copy of boundaries for each element of self. 
- new_shape = g.op("Concat", g.op("Shape", boundaries), g.op("Shape", self), axis_i=0) - # Unsqueeze step is performed to respect ONNX's numpy style broadcasting for comparison ops - # https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md - tensor_rank = symbolic_helper._get_tensor_rank(self) - assert tensor_rank is not None - unsqueeze_axes = list(range(1, tensor_rank + 1)) - expanded_boundaries = expand( - g, - symbolic_helper._unsqueeze_helper(g, boundaries, unsqueeze_axes), - new_shape, - None, - ) - # Compare each element of self to boundaries to get a tensor - # with leading 1s and trailing 0s. - # e.g., 4 > [1, 3, 4] = [1, 1, 0] - # The index of the last 1 is the bucket where the element should go. - if right: - cond = ge(g, self, expanded_boundaries) - else: - cond = gt(g, self, expanded_boundaries) - cond_out = g.op("Cast", cond, to_i=out_type) - # Sum to get the number of 1s corresponding to each element, - # which is the same as the bucket index. - # e.g., sum(4 > [1, 3, 4]) = sum([1, 1, 0]) = 2 - return symbolic_helper._reducesum_helper(g, cond_out, axes_i=[0], keepdims_i=0) - - -@_onnx_symbolic("aten::type_as") -def type_as(g: jit_utils.GraphContext, self, other): - self_dtype = symbolic_helper._try_get_scalar_type(self) - other_dtype = symbolic_helper._try_get_scalar_type(other) - if self_dtype == other_dtype and self_dtype is not None: - return self - if other_dtype is not None: - return g.op( - "Cast", - self, - to_i=other_dtype.onnx_type(), - ) - - raise errors.SymbolicValueError( - "Unsupported: ONNX export of type_as for tensor " - "of unknown dtype. Please check if the dtype of the " - "parameter passed to the type_as function is correct.", - other, - ) - - -@_onnx_symbolic("aten::cosine_similarity") -@symbolic_helper.parse_args("v", "v", "i", "f") -def cosine_similarity(g: jit_utils.GraphContext, x1, x2, dim, eps): - cross = symbolic_helper._reducesum_helper( - g, mul(g, x1, x2), axes_i=[dim], keepdims_i=0 - ) - x1_l2 = symbolic_helper._reducesum_helper( - g, mul(g, x1, x1), axes_i=[dim], keepdims_i=0 - ) - x2_l2 = symbolic_helper._reducesum_helper( - g, mul(g, x2, x2), axes_i=[dim], keepdims_i=0 - ) - div_tens = max( - g, sqrt(g, mul(g, x1_l2, x2_l2)), g.op("Constant", value_t=torch.tensor([eps])) - ) - return div(g, cross, div_tens) - - -@_onnx_symbolic("aten::pairwise_distance") -def pairwise_distance(g: jit_utils.GraphContext, input1, input2, p, eps, keepdim): - if not symbolic_helper._is_value(eps): - eps = g.op("Constant", value_t=torch.tensor([eps])) - inv_p = div( - g, - g.op("Constant", value_t=torch.tensor([1], dtype=torch.float)), - add(g, p, eps), - ) - summation = symbolic_helper._reducesum_helper( - g, - pow(g, sub(g, input1, input2), p), - axes_i=[-1], - keepdims_i=symbolic_helper._parse_arg(keepdim, "i"), - ) - return pow(g, summation, inv_p) - - -@_onnx_symbolic("aten::clone") -# ignore clone operators that are inserted by PyTorch autograd -def clone(g: jit_utils.GraphContext, input, unused_memory_format): - return input - - -@_onnx_symbolic("aten::abs") -def abs(g: jit_utils.GraphContext, self): - return g.op("Abs", self) - - -@_onnx_symbolic("aten::log") -def log(g: jit_utils.GraphContext, self): - return g.op("Log", self) - - -@_onnx_symbolic("aten::log1p") -def log1p(g: jit_utils.GraphContext, self): - return log(g, add(g, symbolic_helper._if_scalar_type_as(torch.ones(1), self), self)) - - -@_onnx_symbolic("aten::log10") -def log10(g: jit_utils.GraphContext, self): - _ln10 = 2.30258509299404568401 - return g.op("Div", log(g, self), 
g.op("Constant", value_t=torch.tensor([_ln10]))) - - -@_onnx_symbolic("aten::pow") -def pow(g: jit_utils.GraphContext, self, exponent): - f_dtype = _type_utils.JitScalarType.from_value(self) - if not symbolic_helper._is_fp(self): - f_dtype = _type_utils.JitScalarType.FLOAT - self = g.op("Cast", self, to_i=f_dtype.onnx_type()) - if not symbolic_helper._is_fp(exponent): - exponent = g.op( - "Cast", - exponent, - to_i=f_dtype.onnx_type(), - ) - pow = g.op("Pow", self, exponent) - return pow - - -@_onnx_symbolic("aten::clamp") -def clamp(g: jit_utils.GraphContext, self, min, max): - # min or max may be None that we need to dispatch to - # Clip separately, as ONNX does not have None syntax - if symbolic_helper._is_none(min): - return clamp_max(g, self, max) - elif symbolic_helper._is_none(max): - return clamp_min(g, self, min) - else: - if symbolic_helper._is_constant(min) and symbolic_helper._is_constant(max): - return symbolic_helper._op_with_optional_float_cast( - g, - "Clip", - self, - min_f=symbolic_helper._parse_arg(min, "f"), - max_f=symbolic_helper._parse_arg(max, "f"), - opset_before=12, - ) - else: - return clamp_max(g, clamp_min(g, self, min), max) - - -@_onnx_symbolic("aten::clamp_min") -@symbolic_helper.parse_args("v", "v") -def clamp_min(g: jit_utils.GraphContext, self, min): - if symbolic_helper._is_constant(min): - return symbolic_helper._op_with_optional_float_cast( - g, "Clip", self, min_f=symbolic_helper._parse_arg(min, "f"), opset_before=12 - ) - else: - dtype = _type_utils.JitScalarType.from_value(self) - min = g.op("Cast", min, to_i=dtype.onnx_type()) - return symbolic_helper._op_with_optional_float_cast( - g, "Max", self, min, opset_before=12 - ) - - -@_onnx_symbolic("aten::clamp_max") -@symbolic_helper.parse_args("v", "v") -def clamp_max(g: jit_utils.GraphContext, self, max): - if symbolic_helper._is_constant(max): - return symbolic_helper._op_with_optional_float_cast( - g, "Clip", self, max_f=symbolic_helper._parse_arg(max, "f"), opset_before=12 - ) - else: - dtype = _type_utils.JitScalarType.from_value(self) - max = g.op("Cast", max, to_i=dtype.onnx_type()) - return symbolic_helper._op_with_optional_float_cast( - g, "Min", self, max, opset_before=12 - ) - - -@_onnx_symbolic("aten::max") -# torch.max (same for torch.min) actually has two interfaces smashed together: -# torch.max(x, dim, keepdim) and torch.max(x, y) -# TODO(justinchuby): Support multiple quantized args in output -def max(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None): - return symbolic_helper._max_helper(g, self, dim_or_y, keepdim) - - -@_onnx_symbolic("aten::maximum") -@symbolic_helper.quantized_args(True, True) -def maximum(g: jit_utils.GraphContext, input, other): - return max(g, input, dim_or_y=other) - - -@_onnx_symbolic("aten::min") -# TODO(justinchuby): Support multiple quantized args in output -def min(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None): - return symbolic_helper._min_helper(g, self, dim_or_y, keepdim) - - -@_onnx_symbolic("aten::minimum") -@symbolic_helper.quantized_args(True, True) -def minimum(g: jit_utils.GraphContext, input, other): - return min(g, input, dim_or_y=other) - - -@_onnx_symbolic("aten::amax") -@symbolic_helper.quantized_args(True) -@symbolic_helper.parse_args("v", "is", "i") -def amax(g: jit_utils.GraphContext, self, dim, keepdim): - return g.op("ReduceMax", self, axes_i=dim, keepdims_i=keepdim) - - -@_onnx_symbolic("aten::amin") -@symbolic_helper.quantized_args(True) -@symbolic_helper.parse_args("v", "is", "i") -def amin(g: 
jit_utils.GraphContext, self, dim, keepdim): - return g.op("ReduceMin", self, axes_i=dim, keepdims_i=keepdim) - - -@_onnx_symbolic("aten::aminmax") -@symbolic_helper.quantized_args(True) -@symbolic_helper.parse_args("v", "v", "i") -def aminmax(g: jit_utils.GraphContext, self, dim, keepdim): - reduce_kwargs = {"keepdims_i": keepdim} - if not symbolic_helper._is_none(dim): - dim = symbolic_helper._get_const(dim, "i", "dim") - reduce_kwargs["axes_i"] = [dim] - - return g.op("ReduceMin", self, **reduce_kwargs), g.op( - "ReduceMax", self, **reduce_kwargs - ) - - -@_onnx_symbolic("aten::exp") -def exp(g: jit_utils.GraphContext, self): - return g.op("Exp", self) - - -@_onnx_symbolic("aten::dropout_") -@_onnx_symbolic("aten::dropout") -@symbolic_helper.parse_args("v", "f", "i") -def dropout(g: jit_utils.GraphContext, input, p, train): - symbolic_helper.check_training_mode(train, "dropout") - # if train is False, dropout is no-op - if not train: - return input - r, _ = g.op("Dropout", input, ratio_f=p, outputs=2) - return r - - -@_onnx_symbolic( - "aten::alpha_dropout_", - decorate=[symbolic_helper._apply_params("aten::alpha_dropout_")], -) # See Note [Export inplace] -@_onnx_symbolic( - "aten::feature_alpha_dropout_", - decorate=[symbolic_helper._apply_params("aten::feature_alpha_dropout_")], -) -@_onnx_symbolic( - "aten::feature_dropout_", - decorate=[symbolic_helper._apply_params("aten::feature_dropout_")], -) -@_onnx_symbolic( - "aten::feature_alpha_dropout", - decorate=[symbolic_helper._apply_params("aten::feature_alpha_dropout")], -) -@_onnx_symbolic( - "aten::alpha_dropout", - decorate=[symbolic_helper._apply_params("aten::alpha_dropout")], -) -@_onnx_symbolic( - "aten::feature_dropout", - decorate=[symbolic_helper._apply_params("aten::feature_dropout")], -) -def _unsupported_dropout(name: str): - @symbolic_helper.parse_args("v", "none", "b") - def feature_dropout(g, input, p, train): - # NB: In inference mode, FeatureDropout is exported as an identity op. 
- if train: - return symbolic_helper._unimplemented(name, "training mode", input) - return input - - return feature_dropout - - -@_onnx_symbolic("aten::norm") -@symbolic_helper.parse_args("v", "t", "is", "i", "v") -def norm(g: jit_utils.GraphContext, self, p, dim, keepdim, dtype=None): - if p == 1: - f = symbolic_helper._reduce_op_symbolic_helper("ReduceL1") - elif p == 2: - f = symbolic_helper._reduce_op_symbolic_helper("ReduceL2") - else: - raise errors.SymbolicValueError( - "ONNX export only p-norms with p of 1 or 2", self - ) - result = f(g, self, dim=dim, keepdim=keepdim) - if dtype is not None: - dtype = symbolic_helper._get_const(dtype, "i", "dtype") - result = g.op("Cast", result, to_i=_type_utils.JitScalarType(dtype).onnx_type()) - return result - - -@_onnx_symbolic("aten::conv_tbc") -@symbolic_helper.parse_args("v", "v", "v", "i") -def conv_tbc(g: jit_utils.GraphContext, input, weight, bias, pad): - # input must have 3 dimensions, see: - # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/ConvolutionTBC.cpp#L8-L10 - # input = (time, batch, in_channels) - # weight = (kernel_width, in_channels, out_channels) - # bias = (out_channels,) - input = g.op("Transpose", input, perm_i=[1, 2, 0]) - weight = g.op("Transpose", weight, perm_i=[2, 1, 0]) - conv = conv1d(g, input, weight, bias, [1], [pad], [1], 1) - return g.op("Transpose", conv, perm_i=[2, 0, 1]) - - -@_onnx_symbolic("aten::_unique") -@symbolic_helper.parse_args("v", "i", "i") -def _unique(g: jit_utils.GraphContext, input, sorted, return_inverse): - return symbolic_helper._onnx_unsupported("_unique", input) - - -@_onnx_symbolic("aten::_unique2") -@symbolic_helper.parse_args("v", "i", "i", "i") -def _unique2(g: jit_utils.GraphContext, input, sorted, return_inverse, return_counts): - symbolic_helper._onnx_opset_unsupported("_unique2", 9, 11, input) - - -@_onnx_symbolic("aten::_cast_Byte") -@deprecated("Avoid using this function and create a Cast node instead") -def _cast_Byte(g: jit_utils.GraphContext, input, non_blocking): - return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.UINT8) - - -@_onnx_symbolic("aten::_cast_Char") -@deprecated("Avoid using this function and create a Cast node instead") -def _cast_Char(g: jit_utils.GraphContext, input, non_blocking): - return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT8) - - -@_onnx_symbolic("aten::_cast_Short") -@deprecated("Avoid using this function and create a Cast node instead") -def _cast_Short(g: jit_utils.GraphContext, input, non_blocking): - return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT16) - - -@_onnx_symbolic("aten::_cast_Int") -@deprecated("Avoid using this function and create a Cast node instead") -def _cast_Int(g: jit_utils.GraphContext, input, non_blocking): - return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT32) - - -@_onnx_symbolic("aten::_cast_Long") -@deprecated("Avoid using this function and create a Cast node instead") -def _cast_Long(g: jit_utils.GraphContext, input, non_blocking): - return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT64) - - -@_onnx_symbolic("aten::_cast_Half") -@deprecated("Avoid using this function and create a Cast node instead") -def _cast_Half(g: jit_utils.GraphContext, input, non_blocking): - return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.FLOAT16) - - -@_onnx_symbolic("aten::_cast_Float") -@deprecated("Avoid using this function and create a Cast node instead") -def _cast_Float(g: jit_utils.GraphContext, input, non_blocking): - return g.op("Cast", 
input, to_i=_C_onnx.TensorProtoDataType.FLOAT) - - -@_onnx_symbolic("aten::_cast_Double") -@deprecated("Avoid using this function and create a Cast node instead") -def _cast_Double(g: jit_utils.GraphContext, input, non_blocking): - return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.DOUBLE) - - -@_onnx_symbolic("aten::_cast_Bool") -@deprecated("Avoid using this function and create a Cast node instead") -def _cast_Bool(g: jit_utils.GraphContext, input, non_blocking): - return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.BOOL) - - -@_onnx_symbolic("aten::empty") -@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v") -def empty( - g: jit_utils.GraphContext, - sizes, - dtype, - layout, - device, - pin_memory=False, - memory_format=None, -): - return zeros(g, sizes, dtype, layout, device, pin_memory) - - -@_onnx_symbolic("aten::empty_like") -@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v") -def empty_like( - g: jit_utils.GraphContext, - input, - dtype=None, - layout=None, - device=None, - pin_memory=False, - memory_format=None, -): - return zeros_like(g, input, dtype, layout, device, pin_memory) - - -@_onnx_symbolic("aten::new_empty") -def new_empty( - g: jit_utils.GraphContext, self, sizes, dtype, layout, device, pin_memory=False -): - self_dtype = symbolic_helper._try_get_scalar_type(self) - if symbolic_helper._is_none(dtype) and self_dtype is not None: - dtype = self_dtype - return empty(g, sizes, dtype, layout, device, pin_memory) - - -@_onnx_symbolic("aten::scalar_tensor") -def scalar_tensor(g: jit_utils.GraphContext, scalar, dtype, *options): - dtype = symbolic_helper._get_const(dtype, "i", "dtype") - if dtype is None: - dtype = _type_utils.JitScalarType.FLOAT - scalar = g.op("Cast", scalar, to_i=_type_utils.JitScalarType(dtype).onnx_type()) - return scalar - - -@_onnx_symbolic("aten::tensor") -def tensor( - g: jit_utils.GraphContext, data, dtype=None, device=None, requires_grad=False -): - dtype = symbolic_helper._get_const(dtype, "i", "dtype") - if symbolic_helper._is_packed_list(data): - if dtype is None: - dtype = _type_utils.JitScalarType.from_value( - symbolic_helper._unpack_list(data)[0] - ) - input_list = [] - for t in symbolic_helper._unpack_list(data): - shape_reference = g.op("Constant", value_t=torch.LongTensor([1])) - t = symbolic_helper._reshape_helper(g, t, shape_reference) - t = g.op("Cast", t, to_i=_type_utils.JitScalarType(dtype).onnx_type()) - input_list.append(t) - return g.op("Concat", *input_list, axis_i=0) - else: - if dtype is None: - dtype = _type_utils.JitScalarType.from_value(data) - if symbolic_helper._is_list(data) and ( - symbolic_helper._is_tensor_list(data) - or symbolic_helper._is_scalar_list(data) - ): - data = g.op("ConcatFromSequence", data, axis_i=0, new_axis_i=1) - return g.op("Cast", data, to_i=_type_utils.JitScalarType(dtype).onnx_type()) - - -@_onnx_symbolic("aten::as_tensor") -def as_tensor(g: jit_utils.GraphContext, data, dtype=None, device=None): - return tensor(g, data, dtype, device) - - -@_onnx_symbolic("aten::zeros") -@symbolic_helper.parse_args("v", "i", "v", "v", "v") -def zeros(g: jit_utils.GraphContext, sizes, dtype, layout, device, pin_memory=False): - # NOTE: no way to set device, layout and pin_memory in ONNX, so we ignore it - if dtype is None: - scalar_type = _type_utils.JitScalarType.FLOAT - else: - scalar_type = _type_utils.JitScalarType(dtype) - sizes_ = symbolic_helper._maybe_get_const(sizes, "is") - if isinstance(sizes_, list) and len(sizes_) == 0: - sizes = g.op("Constant", 
value_t=torch.tensor([]).to(torch.int64)) - return g.op( - "ConstantOfShape", - sizes, - value_t=torch.tensor([0], dtype=scalar_type.dtype()), - ) - - -@_onnx_symbolic("aten::zeros_like") -@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v") -def zeros_like( - g: jit_utils.GraphContext, - input, - dtype=None, - layout=None, - device=None, - pin_memory=False, - memory_format=None, -): - shape = g.op("Shape", input) - if symbolic_helper._is_none(dtype): - scalar_type = _type_utils.JitScalarType.from_value( - input, _type_utils.JitScalarType.FLOAT - ) - else: - scalar_type = _type_utils.JitScalarType(dtype) - return g.op( - "ConstantOfShape", - shape, - value_t=torch.tensor([0], dtype=scalar_type.dtype()), - ) - - -@_onnx_symbolic("aten::new_zeros") -def new_zeros( - g: jit_utils.GraphContext, self, sizes, dtype, layout, device, pin_memory=False -): - self_dtype = symbolic_helper._try_get_scalar_type(self) - - if symbolic_helper._is_none(dtype) and self_dtype is not None: - dtype = self_dtype - return zeros(g, sizes, dtype, layout, device, pin_memory) - - -@_onnx_symbolic("aten::zero") -def zero(g: jit_utils.GraphContext, self): - self_dtype = symbolic_helper._try_get_scalar_type(self) - return zeros_like(g, self, self_dtype) - - -@_onnx_symbolic("aten::ones") -@symbolic_helper.parse_args("v", "i", "v", "v", "v") -def ones(g: jit_utils.GraphContext, sizes, dtype, layout, device, pin_memory=False): - if dtype is None: - scalar_type = _type_utils.JitScalarType.FLOAT - else: - scalar_type = _type_utils.JitScalarType(dtype) - sizes_ = symbolic_helper._maybe_get_const(sizes, "is") - if isinstance(sizes_, list) and len(sizes_) == 0: - sizes = g.op("Constant", value_t=torch.tensor([]).to(torch.int64)) - return g.op( - "ConstantOfShape", - sizes, - value_t=torch.tensor([1], dtype=scalar_type.dtype()), - ) - - -@_onnx_symbolic("aten::ones_like") -@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v") -def ones_like( - g: jit_utils.GraphContext, - input, - dtype=None, - layout=None, - device=None, - pin_memory=False, - memory_format=None, -): - shape = g.op("Shape", input) - if symbolic_helper._is_none(dtype): - scalar_type = _type_utils.JitScalarType.from_value( - input, _type_utils.JitScalarType.FLOAT - ) - else: - scalar_type = _type_utils.JitScalarType(dtype) - return g.op( - "ConstantOfShape", - shape, - value_t=torch.tensor([1], dtype=scalar_type.dtype()), - ) - - -@_onnx_symbolic("aten::new_ones") -def new_ones( - g: jit_utils.GraphContext, self, sizes, dtype, layout, device, pin_memory=False -): - self_dtype = symbolic_helper._try_get_scalar_type(self) - if symbolic_helper._is_none(dtype) and self_dtype is not None: - dtype = self_dtype - return ones(g, sizes, dtype, layout, device, pin_memory) - - -@_onnx_symbolic("aten::full") -def full( - g: jit_utils.GraphContext, sizes, value, dtype, layout, device, pin_memory=False -): - const_value = symbolic_helper._maybe_get_const(value, "t") - if symbolic_helper._is_value(const_value): - dtype = _type_utils.JitScalarType.FLOAT if dtype is None else dtype - tmp = zeros(g, sizes, dtype, layout, device) - return add(g, tmp, value, g.op("Constant", value_t=torch.tensor(1))) - else: - dtype = symbolic_helper._get_const(dtype, "i", "dtype") - if dtype is None: - scalar_type = _type_utils.JitScalarType.FLOAT - else: - scalar_type = _type_utils.JitScalarType(dtype) - sizes_ = symbolic_helper._maybe_get_const(sizes, "is") - if isinstance(sizes_, list) and len(sizes_) == 0: - sizes = g.op("Constant", value_t=torch.tensor([]).to(torch.int64)) - return 
g.op( - "ConstantOfShape", - sizes, - value_t=const_value.view(1).to(scalar_type.dtype()), - ) - - -@_onnx_symbolic("aten::full_like") -def full_like( - g: jit_utils.GraphContext, - input, - fill_value, - dtype=None, - layout=None, - device=None, - pin_memory=False, - memory_format=None, -): - fill_value = symbolic_helper._maybe_get_const(fill_value, "f") - dtype = symbolic_helper._get_const(dtype, "i", "dtype") - if dtype is None: - scalar_type = _type_utils.JitScalarType.from_value( - input, _type_utils.JitScalarType.FLOAT - ) - else: - scalar_type = _type_utils.JitScalarType(dtype) - if symbolic_helper._is_value(fill_value): - tmp = zeros_like(g, input, dtype, layout, device) - fill_value = g.op("Cast", fill_value, to_i=scalar_type.onnx_type()) - return add(g, tmp, fill_value, g.op("Constant", value_t=torch.tensor(1))) - else: - shape = g.op("Shape", input) - return g.op( - "ConstantOfShape", - shape, - value_t=torch.tensor([fill_value], dtype=scalar_type.dtype()), - ) - - -@_onnx_symbolic("aten::new_full") -def new_full( - g: jit_utils.GraphContext, - self, - size, - fill_value, - dtype, - layout, - device, - pin_memory=False, -): - self_dtype = symbolic_helper._try_get_scalar_type(self) - if symbolic_helper._is_none(dtype) and self_dtype is not None: - dtype = self_dtype - return full(g, size, fill_value, dtype, layout, device, pin_memory) - - -@_onnx_symbolic("aten::eye") -def eye(g: jit_utils.GraphContext, *args): - if len(args) == 5: - # aten::eye(n, dtype, layout, device, pin_memory) - n, dtype, layout, device, _pin_memory = args - dim_size = symbolic_helper._unsqueeze_helper(g, n, [0]) - shape = g.op("Concat", dim_size, dim_size, axis_i=0) - tensor = zeros(g, shape, dtype, layout, device) - return g.op("EyeLike", tensor) - if len(args) == 6: - # aten::eye(n, m, dtype, layout, device, pin_memory) - n, m, dtype, layout, device, _pin_memory = args - shape = g.op( - "Concat", - symbolic_helper._unsqueeze_helper(g, n, [0]), - symbolic_helper._unsqueeze_helper(g, m, [0]), - axis_i=0, - ) - tensor = zeros(g, shape, dtype, layout, device) - return g.op("EyeLike", tensor) - - return symbolic_helper._unimplemented("aten::eye", f"with {len(args)} arguments") - - -@_onnx_symbolic("aten::slice") -def slice(g: jit_utils.GraphContext, self, *args): - if len(args) == 4: - # aten::slice(Tensor self, int dim, int start, int end, int step) -> Tensor - dim, start, end, step = args - step = symbolic_helper._parse_arg(step, "i") - if step != 1: - raise errors.SymbolicValueError("step!=1 is currently not supported", self) - is_start_none = start.node().kind() == "prim::Constant" and isinstance( - start.type(), _C.NoneType - ) - is_end_none = end.node().kind() == "prim::Constant" and isinstance( - end.type(), _C.NoneType - ) - is_start_onnx_const = start.node().kind() == "onnx::Constant" - is_end_onnx_const = end.node().kind() == "onnx::Constant" - if ( - ((not is_start_none) and (not is_start_onnx_const)) - or ((not is_end_none) and (not is_end_onnx_const)) - or dim.node().kind() != "onnx::Constant" - ): - if GLOBALS.operator_export_type == _C_onnx.OperatorExportTypes.ONNX: - raise errors.SymbolicValueError( - "Unsupported: ONNX export of Slice with dynamic inputs. DynamicSlice " - "is a deprecated experimental op. 
Please use statically allocated " - "variables or export to a higher opset version.", - self, - ) - else: - start_unsqueezed = symbolic_helper._unsqueeze_helper(g, start, [0]) - end_unsqueezed = symbolic_helper._unsqueeze_helper(g, end, [0]) - dim_unsqueezed = symbolic_helper._unsqueeze_helper(g, dim, [0]) - return g.op( - "DynamicSlice", - self, - start_unsqueezed, - end_unsqueezed, - dim_unsqueezed, - ) - else: - start = 0 if is_start_none else symbolic_helper._parse_arg(start, "i") - end = ( - _constants.INT64_MAX - if is_end_none - else symbolic_helper._parse_arg(end, "i") - ) - dim = symbolic_helper._parse_arg(dim, "i") - return symbolic_helper._slice_helper( - g, self, axes=[dim], starts=[start], ends=[end] - ) - elif len(args) == 3: - # aten::slice(t[] l, int start, int end, int step) -> t[] - start, end, step = args - dim = 0 - is_start_none = start.node().kind() == "prim::Constant" and isinstance( - start.type(), _C.NoneType - ) - is_end_none = end.node().kind() == "prim::Constant" and isinstance( - end.type(), _C.NoneType - ) - start = 0 if is_start_none else symbolic_helper._parse_arg(start, "i") - end = ( - _constants.INT64_MAX - if is_end_none - else symbolic_helper._parse_arg(end, "i") - ) - return symbolic_helper._slice_helper( - g, self, axes=[dim], starts=[start], ends=[end] - ) - - return symbolic_helper._unimplemented("aten::slice", f"with {len(args)} arguments") - - -@_onnx_symbolic("aten::hardtanh") -@symbolic_helper.quantized_args(True) -@symbolic_helper.parse_args("v", "f", "f") -def hardtanh(g: jit_utils.GraphContext, self: _C.Value, min_val: float, max_val: float): - return symbolic_helper._op_with_optional_float_cast( - g, "Clip", self, min_f=min_val, max_f=max_val, opset_before=12 - ) - - -@_onnx_symbolic("aten::hardswish") -@symbolic_helper.quantized_args(True) -@symbolic_helper.parse_args("v") -def hardswish(g: jit_utils.GraphContext, self): - hs = hardsigmoid(g, self) - return g.op("Mul", self, hs) - - -@_onnx_symbolic("aten::hardsigmoid") -# Fixed scale and zero_point, discovered from aten/src/ATen/native/quantized/cpu/qhardsigmoid.cpp -@symbolic_helper.quantized_args(True, scale=1.0 / 256.0, zero_point=0) -@symbolic_helper.parse_args("v") -def hardsigmoid(g: jit_utils.GraphContext, self): - # Set alpha_f to 1 / 6 to make op equivalent to PyTorch's definition of Hardsigmoid. 
- # See https://pytorch.org/docs/stable/generated/torch.nn.Hardsigmoid.html - return g.op("HardSigmoid", self, alpha_f=1 / 6) - - -@_onnx_symbolic("aten::tanhshrink") -@symbolic_helper.parse_args("v") -def tanhshrink(g: jit_utils.GraphContext, self): - return g.op("Sub", self, tanh(g, self)) - - -@_onnx_symbolic("aten::hardshrink") -@symbolic_helper.parse_args("v", "f") -def hardshrink(g: jit_utils.GraphContext, self, lambd): - scalar_type = _type_utils.JitScalarType.from_value( - self, _type_utils.JitScalarType.FLOAT - ) - lambd_op = g.op( - "Constant", - value_t=torch.tensor(lambd, dtype=scalar_type.dtype()), - ) - cond = logical_or(g, gt(g, self, lambd_op), lt(g, self, neg(g, lambd_op))) - return g.op( - "Where", - cond, - self, - g.op( - "Constant", - value_t=torch.tensor(0, dtype=scalar_type.dtype()), - ), - ) - - -@_onnx_symbolic("aten::softshrink") -@symbolic_helper.parse_args("v", "f") -def softshrink(g: jit_utils.GraphContext, self, lambd): - scalar_type = _type_utils.JitScalarType.from_value( - self, _type_utils.JitScalarType.FLOAT - ) - lambd_op = g.op( - "Constant", - value_t=torch.tensor(lambd, dtype=scalar_type.dtype()), - ) - gt_cond = gt(g, self, lambd_op) - gt_out = g.op( - "Where", - gt_cond, - sub(g, self, lambd_op), - g.op( - "Constant", - value_t=torch.tensor(0, dtype=scalar_type.dtype()), - ), - ) - lt_cond = lt(g, self, neg(g, lambd_op)) - lt_out = g.op( - "Where", - lt_cond, - add(g, self, lambd_op), - g.op( - "Constant", - value_t=torch.tensor(0, dtype=scalar_type.dtype()), - ), - ) - return add(g, gt_out, lt_out) - - -@_onnx_symbolic("aten::alias") -def alias(g: jit_utils.GraphContext, self): - return self - - -@_onnx_symbolic("aten::unsqueeze") -@symbolic_helper.parse_args("v", "i") -def unsqueeze(g: jit_utils.GraphContext, self, dim): - """Implement unsqueezing a pytorch tensor in ONNX by inserting a new dimension at the specified `dim`""" - # Handle negative dim - if dim < 0: - rank = symbolic_helper._get_tensor_rank(self) - if rank is not None: - warnings.warn( - "ONNX export unsqueeze with negative axis " - + str(dim) - + " might cause the onnx model to be incorrect. " - + "Negative axis is not supported in ONNX. " - + "Axis is converted to " - + str(dim + rank + 1) - + " based on input shape at export time. " - + "Passing an tensor of different rank in execution will be incorrect." - ) - dim = dim + rank + 1 - else: - return symbolic_helper._unimplemented( - "unsqueeze", "negative axis with unknown input rank", self - ) - - return symbolic_helper._unsqueeze_helper(g, self, axes_i=[dim]) - - -@_onnx_symbolic("aten::sort") -# TODO(justinchuby): Support multiple quantized args in output -@symbolic_helper.parse_args("v", "i", "i", "none") -def sort(g: jit_utils.GraphContext, self, dim, descending, out=None): - if out is not None: - symbolic_helper._unimplemented( - "Sort", "Out parameter is not supported for sort", self - ) - self_sizes = symbolic_helper._get_tensor_sizes(self) - try: - dim_size = self_sizes[dim] - except Exception: - # FIXME(justinchuby): Avoid catching Exception. - # Catch a more specific exception instead. 
- dim_size = None - - if dim_size is None: - return symbolic_helper._unimplemented("Sort", "input size not accessible", self) - - return g.op("TopK", self, k_i=dim_size, axis_i=dim, outputs=2) - - -@_onnx_symbolic("aten::numel") -def numel(g: jit_utils.GraphContext, self): - return symbolic_helper._numel_helper(g, self) - - -@_onnx_symbolic("aten::topk") -# TODO(justinchuby): Support multiple quantized args in output -@symbolic_helper.parse_args("v", "i", "i", "i", "i", "none") -def topk(g: jit_utils.GraphContext, self, k, dim, largest, sorted, out=None): - if out is not None: - symbolic_helper._unimplemented( - "TopK", "Out parameter is not supported for topk", self - ) - if not largest: - symbolic_helper._unimplemented("TopK", "Ascending TopK is not supported", self) - - return g.op("TopK", self, k_i=k, axis_i=dim, outputs=2) - - -@_onnx_symbolic("prim::convert_element_type") -def convert_element_type(g: jit_utils.GraphContext, self, *args): - dtype = symbolic_helper._get_const(args[0], "i", "dtype") - return g.op("Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type()) - - -@_onnx_symbolic("aten::to") -def to(g: jit_utils.GraphContext, self, *args): - def is_aten_to_device_only(args): - if len(args) == 4: - # aten::to(Tensor, Device, bool, bool, memory_format) - return ( - args[0].node().kind() == "prim::device" - or args[0].type().isSubtypeOf(_C.ListType.ofInts()) - or isinstance(args[0].type(), _C.DeviceObjType) - ) - elif len(args) == 5: - # aten::to(Tensor, Device, ScalarType, bool, bool, memory_format) - # When dtype is None, this is a aten::to(device) call - dtype = symbolic_helper._get_const(args[1], "i", "dtype") - return dtype is None - elif len(args) in (6, 7): - # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, memory_format) -> Tensor - # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, bool, memory_format) -> Tensor - # When dtype is None, this is a aten::to(device) call - dtype = symbolic_helper._get_const(args[0], "i", "dtype") - return dtype is None - return False - - # ONNX doesn't have a concept of a device, so we ignore device-only casts - if is_aten_to_device_only(args): - return self - - if len(args) == 4: - # TestONNXRuntime::test_ones_bool shows args[0] of aten::to() can be onnx::Constant[value=]() - # In this case, the constant value is a tensor not int, - # so symbolic_helper._maybe_get_const(args[0], 'i') would not work. 
- dtype = args[0] - if ( - symbolic_helper._is_value(args[0]) - and args[0].node().kind() == "onnx::Constant" - ): - tval = symbolic_helper._node_get(args[0].node(), "value") - if isinstance(tval, torch.Tensor): - if len(tval.shape) == 0: - tval = tval.item() - dtype = int(tval) - else: - dtype = tval - - if symbolic_helper._is_value(dtype) or isinstance(dtype, torch.Tensor): - # aten::to(Tensor, Tensor, bool, bool, memory_format) - dtype = _type_utils.JitScalarType.from_value(args[0]) - return g.op( - "Cast", - self, - to_i=dtype.onnx_type(), - ) - else: - # aten::to(Tensor, ScalarType, bool, bool, memory_format) - # memory_format is ignored - return g.op("Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type()) - elif len(args) == 5: - # aten::to(Tensor, Device, ScalarType, bool, bool, memory_format) - dtype = symbolic_helper._get_const(args[1], "i", "dtype") - # memory_format is ignored - return g.op("Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type()) - elif len(args) == 6: - # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, memory_format) -> Tensor - dtype = symbolic_helper._get_const(args[0], "i", "dtype") - # Layout, device and memory_format are ignored - return g.op("Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type()) - elif len(args) == 7: - # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, bool, memory_format) -> Tensor - dtype = symbolic_helper._get_const(args[0], "i", "dtype") - # Layout, device and memory_format are ignored - return g.op("Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type()) - - return symbolic_helper._onnx_unsupported("Unknown aten::to signature", self) - - -@_onnx_symbolic("aten::repeat") -def repeat(g: jit_utils.GraphContext, self, repeats): - dtype = _type_utils.JitScalarType.INT64 - shape_ = ones_like(g, repeats, dtype) - self = g.op("Expand", self, shape_) - return g.op("Tile", self, repeats) - - -@_onnx_symbolic("aten::repeat_interleave") -def repeat_interleave( - g: jit_utils.GraphContext, self, repeats, dim=None, output_size=None -): - repeats_dim = symbolic_helper._get_tensor_rank(repeats) - repeats_sizes = symbolic_helper._get_tensor_sizes(repeats) - input_sizes = symbolic_helper._get_tensor_sizes(self) - if repeats_dim is None: - raise errors.SymbolicValueError( - "Unsupported: ONNX export of repeat_interleave for unknown repeats rank.", - self, - ) - if repeats_sizes is None: - raise errors.SymbolicValueError( - "Unsupported: ONNX export of repeat_interleave for unknown repeats size.", - self, - ) - if input_sizes is None: - raise errors.SymbolicValueError( - "Unsupported: ONNX export of repeat_interleave for unknown input size.", - self, - ) - - # if dim is None flatten - # By default, use the flattened input array, and return a flat output array - if symbolic_helper._is_none(dim): - self = symbolic_helper._reshape_helper( - g, self, g.op("Constant", value_t=torch.tensor([-1])) - ) - dim = torch.tensor(0, dtype=torch.int64) - else: - dim = symbolic_helper._maybe_get_scalar(dim) - - # Handle cases where dim is negative - if dim < 0: - dim += len(input_sizes) - - input_sizes_temp = input_sizes.copy() - for idx, input_size in enumerate(input_sizes): - if input_size is None: - input_sizes[idx], input_sizes_temp[idx] = 0, -1 - - # Cases where repeats is an int or single value tensor - if repeats_dim == 0 or (repeats_dim == 1 and repeats_sizes[0] == 1): - if input_sizes[dim] == 0: - return symbolic_helper._onnx_opset_unsupported_detailed( - "repeat_interleave", - 9, - 13, - "Unsupported along 
dimension with unknown input size", - self, - ) - return symbolic_helper._repeat_interleave_single_value_repeat_helper( - g, self, repeats, dim - ) - - # Cases where repeats is a 1 dim Tensor - elif repeats_dim == 1: - if input_sizes[dim] == 0: - return symbolic_helper._onnx_opset_unsupported_detailed( - "repeat_interleave", - 9, - 13, - "Unsupported along dimension with unknown input size", - self, - ) - if repeats_sizes[0] is None: - return symbolic_helper._onnx_opset_unsupported_detailed( - "repeat_interleave", - 9, - 13, - "Unsupported for cases with dynamic repeats", - self, - ) - assert repeats_sizes[0] == input_sizes[dim], ( - "repeats must have the same size as input along dim" - ) - reps = repeats_sizes[0] - else: - raise errors.SymbolicValueError("repeats must be 0-dim or 1-dim tensor", self) - - final_splits = [] - r_splits = symbolic_helper._repeat_interleave_split_helper(g, repeats, reps, 0) - i_splits = symbolic_helper._repeat_interleave_split_helper(g, self, reps, dim) - input_sizes[dim], input_sizes_temp[dim] = -1, 1 - for idx, r_split in enumerate(r_splits): - i_split = unsqueeze(g, i_splits[idx], dim + 1) - r_concat = [ - g.op("Constant", value_t=torch.LongTensor(input_sizes_temp[: dim + 1])), - r_split, - g.op("Constant", value_t=torch.LongTensor(input_sizes_temp[dim + 1 :])), - ] - r_concat = g.op("Concat", *r_concat, axis_i=0) - i_split = expand(g, i_split, r_concat, None) - i_split = symbolic_helper._reshape_helper( - g, - i_split, - g.op("Constant", value_t=torch.LongTensor(input_sizes)), - allowzero=0, - ) - final_splits.append(i_split) - return g.op("Concat", *final_splits, axis_i=dim) - - -@_onnx_symbolic("aten::pixel_shuffle") -@symbolic_helper.parse_args("v", "i") -def pixel_shuffle(g: jit_utils.GraphContext, self, upscale_factor): - dims = symbolic_helper._get_tensor_sizes(self) - if len(dims) != 4: - return symbolic_helper._unimplemented( - "pixel_shuffle", "only support 4d input", self - ) - if any(i is None for i in dims[1:]): - after_view = symbolic_helper._reshape_helper( - g, - symbolic_helper._unsqueeze_helper(g, self, [2, 3]), - g.op( - "Constant", - value_t=torch.tensor([0, -1, upscale_factor, upscale_factor, 0, 0]), - ), - allowzero=0, - ) - after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 4, 2, 5, 3]) - # For dynamic input shapes, two reshapes are performed - reshape_h = symbolic_helper._reshape_helper( - g, - after_transpose, - g.op("Constant", value_t=torch.tensor([0, 0, -1, 1, 0, 0])), - allowzero=0, - ) - reshape_w = symbolic_helper._reshape_helper( - g, - reshape_h, - g.op("Constant", value_t=torch.tensor([0, 0, 0, 0, -1, 1])), - allowzero=0, - ) - return symbolic_helper._squeeze_helper(g, reshape_w, [3, 5]) - else: - output_channel = dims[1] // upscale_factor // upscale_factor - after_view = symbolic_helper._reshape_helper( - g, - self, - g.op( - "Constant", - value_t=torch.tensor( - [ - -1, - output_channel, - upscale_factor, - upscale_factor, - dims[2], - dims[3], - ] - ), - ), - allowzero=0, - ) - after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 4, 2, 5, 3]) - return symbolic_helper._reshape_helper( - g, - after_transpose, - g.op( - "Constant", - value_t=torch.tensor( - [ - -1, - output_channel, - dims[2] * upscale_factor, - dims[3] * upscale_factor, - ] - ), - ), - allowzero=0, - ) - - -@_onnx_symbolic("aten::pixel_unshuffle") -@symbolic_helper.parse_args("v", "i") -def pixel_unshuffle(g: jit_utils.GraphContext, self, downscale_factor): - dims = symbolic_helper._get_tensor_sizes(self) - if len(dims) != 4: - return 
symbolic_helper._unimplemented( - "pixel_shuffle", "only support 4d input", self - ) - if any(i is None for i in dims[1:]): - # For dynamic input shapes, two reshapes are performed - reshape_h = symbolic_helper._reshape_helper( - g, - symbolic_helper._unsqueeze_helper(g, self, [3]), - g.op("Constant", value_t=torch.tensor([0, 0, -1, downscale_factor, 0])), - allowzero=0, - ) - reshape_w = symbolic_helper._reshape_helper( - g, - reshape_h, - g.op("Constant", value_t=torch.tensor([0, 0, 0, 0, -1, downscale_factor])), - allowzero=0, - ) - after_transpose = g.op("Transpose", reshape_w, perm_i=[0, 1, 3, 5, 2, 4]) - final_reshape = symbolic_helper._reshape_helper( - g, - after_transpose, - g.op("Constant", value_t=torch.tensor([0, -1, 1, 1, 0, 0])), - allowzero=0, - ) - return symbolic_helper._squeeze_helper(g, final_reshape, [2, 3]) - else: - output_channel = dims[1] * downscale_factor * downscale_factor - after_view = symbolic_helper._reshape_helper( - g, - self, - g.op( - "Constant", - value_t=torch.tensor( - [ - -1, - dims[1], - dims[2] // downscale_factor, - downscale_factor, - dims[3] // downscale_factor, - downscale_factor, - ] - ), - ), - allowzero=0, - ) - after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 3, 5, 2, 4]) - return symbolic_helper._reshape_helper( - g, - after_transpose, - g.op( - "Constant", - value_t=torch.tensor( - [ - -1, - output_channel, - dims[2] // downscale_factor, - dims[3] // downscale_factor, - ] - ), - ), - allowzero=0, - ) - - -def _generic_rnn( - g: jit_utils.GraphContext, - variant, - input, - initial_states, - all_weights, - has_biases, - num_layers, - dropout, - train, - bidirectional, - batch_first=None, - batch_sizes=None, -): - warnings.warn( - "Exporting a model to ONNX with a batch_size other than 1, " - + "with a variable length with " - + variant - + " can cause an error " - + "when running the ONNX model with a different batch size. " - + "Make sure to save the model with a batch size of 1, " - + "or define the initial states (h0/c0) as inputs of the model. 
" - ) - - onnxActivations = [ - "Relu", - "Tanh", - "Sigmoid", - "Affine", - "LeakyRelu", - "ThresholdedRelu", - "ScaledTanh", - "HardSigmoid", - "Elu", - "Softsign", - "Softplus", - ] - variantToOnnxActivationMap = dict( - zip([act_fun.lower() for act_fun in onnxActivations], onnxActivations) - ) - weights_per_layer = 4 if has_biases else 2 - # this means that projections are used inside LSTM, so need to tell user that it's not supported - if variant == "LSTM" and len(all_weights) != num_layers * weights_per_layer * ( - 1 + bidirectional - ): - return symbolic_helper._unimplemented("LSTM", "LSTMs with projections", input) - assert len(all_weights) == num_layers * weights_per_layer * (1 + bidirectional) - layer_weights = [ - all_weights[i : i + weights_per_layer] - for i in range(0, len(all_weights), weights_per_layer) - ] - if batch_first: - # batch, seq, feat -> seq, batch, feat - input = g.op("Transpose", input, perm_i=[1, 0, 2]) - if dropout and train: - return symbolic_helper._unimplemented( - "RNN/GRU/LSTM", "dropout in training mode", input - ) - - if variant.startswith("RNN"): - nonlinearity = variantToOnnxActivationMap[variant[4:].lower()] - variant = "RNN" - - w_hh = all_weights[1] - hidden_size = symbolic_helper._get_tensor_dim_size(w_hh, 1) - if hidden_size is None: - return symbolic_helper._unimplemented( - "RNN/GRU/LSTM", "unknown hidden size", input - ) - - unidirectional = not bidirectional - - prev_output = input - - h_outs = [] - if variant == "RNN" or variant == "GRU": - h0 = initial_states - elif variant == "LSTM": - h0, c0 = initial_states - c_outs = [] - - sequence_lens = unused(g) if batch_sizes is None else batch_sizes - - if variant == "GRU": - # pytorch is reset, input, hidden - # onnx is input, reset, hidden - reform_permutation = [(1, 2), (0, 1), (2, 3)] - elif variant == "LSTM": - # pytorch is input, forget, cell, output. - # onnx is input, output, forget, cell. 
- reform_permutation = [(0, 1), (3, 4), (1, 3)] - - def reform_weights(g, w, n, intervals): - slices = [ - symbolic_helper._slice_helper(g, w, axes=[0], starts=[x * n], ends=[y * n]) - for x, y in intervals - ] - return g.op("Concat", *slices, axis_i=0) - - def transform_weights_no_bias(layer_index): - weights = layer_weights[layer_index] - if variant == "RNN": - weight_ih, weight_hh = weights - elif variant == "GRU" or variant == "LSTM": - weight_ih, weight_hh = ( - reform_weights(g, w, hidden_size, reform_permutation) for w in weights - ) - return tuple( - symbolic_helper._unsqueeze_helper(g, x, [0]) - for x in (weight_ih, weight_hh) # type: ignore[possibly-undefined] - ) - - def transform_weights(layer_index): - weights = layer_weights[layer_index] - if variant == "RNN": - weight_ih, weight_hh, bias_ih, bias_hh = weights - elif variant == "GRU" or variant == "LSTM": - weight_ih, weight_hh, bias_ih, bias_hh = ( - reform_weights(g, w, hidden_size, reform_permutation) for w in weights - ) - bias_concat = g.op("Concat", bias_ih, bias_hh, axis_i=0) # type: ignore[possibly-undefined] - return tuple( - symbolic_helper._unsqueeze_helper(g, x, [0]) - for x in (weight_ih, weight_hh, bias_concat) # type: ignore[possibly-undefined] - ) - - def retrieve_state(x, start, end): - return ( - x - if num_layers == 1 - else symbolic_helper._slice_helper( - g, x, axes=[0], starts=[start], ends=[end] - ) - ) - - for i in range(num_layers): - if unidirectional: - if weights_per_layer == 4: - weight_ih, weight_hh, bias_concat = transform_weights(i) - else: - weight_ih, weight_hh = transform_weights_no_bias(i) - bias_concat = unused(g) - - state_indices = i, i + 1 - else: - if weights_per_layer == 4: - weight_ih_f, weight_hh_f, bias_f = transform_weights(2 * i) - weight_ih_b, weight_hh_b, bias_b = transform_weights(2 * i + 1) - bias_concat = g.op("Concat", bias_f, bias_b, axis_i=0) - else: - weight_ih_f, weight_hh_f = transform_weights_no_bias(2 * i) - weight_ih_b, weight_hh_b = transform_weights_no_bias(2 * i + 1) - bias_concat = unused(g) - - weight_ih = g.op("Concat", weight_ih_f, weight_ih_b, axis_i=0) - weight_hh = g.op("Concat", weight_hh_f, weight_hh_b, axis_i=0) - - state_indices = 2 * i, 2 * i + 2 - - inputs = [prev_output, weight_ih, weight_hh, bias_concat, sequence_lens] - - inputs.append(retrieve_state(h0, *state_indices)) # type: ignore[possibly-undefined] - if variant == "LSTM": - inputs.append(retrieve_state(c0, *state_indices)) # type: ignore[possibly-undefined] - - extra_kwargs = {} if unidirectional else {"direction_s": "bidirectional"} - if variant == "RNN": - if bidirectional: - activation = [nonlinearity, nonlinearity] # type: ignore[possibly-undefined] - else: - activation = [nonlinearity] # type: ignore[possibly-undefined] - - prev_output, h_out = g.op( - "RNN", - *inputs, - outputs=2, - hidden_size_i=hidden_size, - activations_s=activation, - **extra_kwargs, - ) - elif variant == "GRU": - prev_output, h_out = g.op( - "GRU", - *inputs, - outputs=2, - hidden_size_i=hidden_size, - linear_before_reset_i=1, - **extra_kwargs, - ) - elif variant == "LSTM": - prev_output, h_out, c_out = g.op( - "LSTM", *inputs, outputs=3, hidden_size_i=hidden_size, **extra_kwargs - ) - - if bidirectional: - # The ONNX RNN/GRU/LSTM produce an output of dimensions - # seq_len, num_directions, batch, hidden_size - # We have to convert to match pytorch's expected - # seq_len, batch, num_directions * hidden_size - # by first moving num_directions before hidden_size with - # Transpose, and then combining it with 
hidden_size - # with Reshape. - prev_output = g.op("Transpose", prev_output, perm_i=[0, 2, 1, 3]) - prev_output = symbolic_helper._reshape_helper( - g, - prev_output, - g.op("Constant", value_t=torch.LongTensor([0, 0, -1])), - allowzero=0, - ) - else: - prev_output = symbolic_helper._squeeze_helper(g, prev_output, [1]) - - h_outs.append(h_out) # type: ignore[possibly-undefined] - if variant == "LSTM": - c_outs.append(c_out) # type: ignore[possibly-undefined] - if batch_first: - # seq, batch, num_directions * hidden_size -> batch, seq, num_directions * hidden_size - prev_output = g.op("Transpose", prev_output, perm_i=[1, 0, 2]) - h_outs = h_out if num_layers == 1 else g.op("Concat", *h_outs, axis_i=0) # type: ignore[possibly-undefined] - if variant == "RNN" or variant == "GRU": - return prev_output, h_outs - elif variant == "LSTM": - c_outs = c_out if num_layers == 1 else g.op("Concat", *c_outs, axis_i=0) # type: ignore[possibly-undefined] - return prev_output, h_outs, c_outs - - -@symbolic_helper.parse_args("v", "v", "v", "i", "i", "f", "i", "i", "i") -def _lstm_full( - g: jit_utils.GraphContext, - input, - hidden_v, - weight_v, - has_biases, - num_layers, - dropout, - train, - bidirectional, - batch_first, -): - hidden, weight = ( - symbolic_helper._unpack_list(hidden_v), - symbolic_helper._unpack_list(weight_v), - ) - return _generic_rnn( - g, - "LSTM", - input, - hidden, - weight, - has_biases, - num_layers, - dropout, - train, - bidirectional, - batch_first, - ) - - -@symbolic_helper.parse_args("v", "v", "v", "v", "i", "i", "f", "i", "i") -def _lstm_packed( - g: jit_utils.GraphContext, - input, - batch_sizes, - hidden_v, - weight_v, - has_biases, - num_layers, - dropout, - train, - bidirectional, -): - hidden, weight = ( - symbolic_helper._unpack_list(hidden_v), - symbolic_helper._unpack_list(weight_v), - ) - return _generic_rnn( - g, - "LSTM", - input, - hidden, - weight, - has_biases, - num_layers, - dropout, - train, - bidirectional, - batch_sizes=batch_sizes, - ) - - -@_onnx_symbolic("aten::lstm") -def lstm(g: jit_utils.GraphContext, *args): - if symbolic_helper._is_tensor_list(args[3]): - return _lstm_packed(g, *args) - else: - return _lstm_full(g, *args) - - -@_onnx_symbolic("aten::lstm_cell") -def lstm_cell(g: jit_utils.GraphContext, self, hidden, w_ih, w_hh, b_ih, b_hh): - input = symbolic_helper._unsqueeze_helper(g, self, [0]) - hidden = symbolic_helper._unpack_list(hidden) - hidden = [symbolic_helper._unsqueeze_helper(g, x, [0]) for x in hidden] - weight = ( - (w_ih, w_hh, b_ih, b_hh) if symbolic_helper._is_tensor(b_ih) else (w_ih, w_hh) - ) - has_biases = True if symbolic_helper._is_tensor(b_ih) else False - _, h_outs, c_outs = _generic_rnn( - g, - "LSTM", - input, - hidden, - weight, - has_biases, - num_layers=1, - dropout=0, - train=0, - bidirectional=False, - batch_first=False, - ) - return symbolic_helper._squeeze_helper( - g, h_outs, [0] - ), symbolic_helper._squeeze_helper(g, c_outs, [0]) - - -@_onnx_symbolic( - "aten::gru", decorate=[symbolic_helper._apply_params("GRU"), _export("gru")] -) -@_onnx_symbolic( - "aten::rnn_tanh", - decorate=[symbolic_helper._apply_params("RNN_TANH"), _export("rnn_tanh")], -) -@_onnx_symbolic( - "aten::rnn_relu", - decorate=[symbolic_helper._apply_params("RNN_RELU"), _export("rnn_relu")], -) -def _one_hidden_rnn(kind: str): - @symbolic_helper.parse_args("v", "v", "v", "i", "i", "f", "i", "i", "i") - def _rnn_full( - g, - input, - hidden, - weight_v, - has_biases, - num_layers, - dropout, - train, - bidirectional, - batch_first, - ): - 
weight = symbolic_helper._unpack_list(weight_v) - return _generic_rnn( - g, - kind, - input, - hidden, - weight, - has_biases, - num_layers, - dropout, - train, - bidirectional, - batch_first, - ) - - @symbolic_helper.parse_args("v", "v", "v", "v", "i", "i", "f", "i", "i") - def _rnn_packed( - g, - input, - batch_sizes, - hidden, - weight_v, - has_biases, - num_layers, - dropout, - train, - bidirectional, - ): - weight = symbolic_helper._unpack_list(weight_v) - return _generic_rnn( - g, - kind, - input, - hidden, - weight, - has_biases, - num_layers, - dropout, - train, - bidirectional, - batch_sizes=batch_sizes, - ) - - def symbolic(g, *args): - if symbolic_helper._is_tensor_list(args[3]): - return _rnn_packed(g, *args) - else: - return _rnn_full(g, *args) - - return symbolic - - -@_onnx_symbolic("aten::_dim_arange") -@symbolic_helper.parse_args("v", "i") -def _dim_arange(g: jit_utils.GraphContext, like, dim): - like_shape = g.op("Shape", like) - stop = g.op( - "Gather", like_shape, g.op("Constant", value_t=torch.tensor(dim)), axis_i=0 - ) - # aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory) - return arange(g, stop, 4, None, None, None) - - -@_onnx_symbolic("aten::detach") -def detach(g: jit_utils.GraphContext, input): - # Erase aten::detach nodes because ONNX is inference only - return input - - -@_onnx_symbolic("aten::contiguous") -@symbolic_helper.parse_args("v", "i") -def contiguous(g: jit_utils.GraphContext, input, memory_format): - if memory_format > 2: # allower values are any, preserve and contiguous_format - raise errors.SymbolicValueError( - "onnx memory_format support is not implemented", input - ) - return input - - -@_onnx_symbolic("aten::_pack_padded_sequence") -@symbolic_helper.parse_args("v", "v", "i") -def _pack_padded_sequence(g: jit_utils.GraphContext, input, lengths, batch_first): - # Currently there is no PackPadded operator in ONNX. We rely on an - # optimization pass to remove this later. It is an error if all - # PackPadded operators cannot be optimized out. - if batch_first: - input = g.op("Transpose", input, perm_i=[1, 0, 2]) - if not lengths.type().isSubtypeOf(torch._C.TensorType.get()): - raise errors.SymbolicValueError( - "'lengths' must be a Tensor for ONNX export", input - ) - # We know it's a TensorType so this check is now safe. - # It's really only necessary because those operators expand to something that - # only works with int32 types in Caffe2... 
- if ( - _type_utils.JitScalarType.from_value( - lengths, _type_utils.JitScalarType.UNDEFINED - ) - != _type_utils.JitScalarType.INT - ): - lengths = g.op("Cast", lengths, to_i=_C_onnx.TensorProtoDataType.INT32) - return g.op("prim::PackPadded", input, lengths, outputs=2) - - -@_onnx_symbolic("aten::_pad_packed_sequence") -@symbolic_helper.parse_args("v", "v", "i", "t", "v") -def _pad_packed_sequence( - g: jit_utils.GraphContext, - data, - batch_sizes, - batch_first, - padding_value, - total_length, -): - # Ignore total_length as it is not supported in _symbolic_pad_packed_sequence - # It is only useful/used when training using data_parallel model, so - # It shouldn't be relevant for ONNX anyway - data, lengths = g.op("prim::PadPacked", data, batch_sizes, outputs=2) - if batch_first: - data = g.op("Transpose", data, perm_i=[1, 0, 2]) - return data, lengths - - -@_onnx_symbolic("aten::randint") -def randint(g: jit_utils.GraphContext, low, high, shapes, dtype, *options): - dtype = symbolic_helper._get_const(dtype, "i", "dtype") - low_i = symbolic_helper._get_const(low, "i", "low") - high_i = symbolic_helper._get_const(high, "i", "high") - if dtype is None: - scalar_type = _type_utils.JitScalarType.INT64 - else: - scalar_type = _type_utils.JitScalarType(dtype) - if low_i is None: - raise symbolic_helper._onnx_unsupported("randint", low) - if high_i is None: - raise symbolic_helper._onnx_unsupported("randint", high) - - shape = symbolic_helper._maybe_get_const(shapes, "is") - if symbolic_helper._is_value(shape): - shape_const = g.op( - "ConstantOfShape", - shapes, - value_t=torch.tensor([0], dtype=torch.float), - ) - randn = g.op( - "RandomUniformLike", - shape_const, - low_f=low_i, - high_f=high_i, - ) - else: - randn = g.op( - "RandomUniform", - shape_i=shape, - low_f=low_i, - high_f=high_i, - ) - - # cast to integer type - int_dtype = _type_utils.JitScalarType.INT64 - randint = g.op("Cast", randn, to_i=int_dtype.onnx_type()) - if int_dtype != scalar_type: - randint = g.op("Cast", randint, to_i=scalar_type.onnx_type()) - return randint - - -@_onnx_symbolic("aten::randint_like") -def randint_like(g: jit_utils.GraphContext, self, low, high, dtype, *options): - dtype = symbolic_helper._get_const(dtype, "i", "dtype") - low_i = symbolic_helper._get_const(low, "i", "low") - high_i = symbolic_helper._get_const(high, "i", "high") - if dtype is None: - scalar_type = _type_utils.JitScalarType.INT64 - else: - scalar_type = _type_utils.JitScalarType(dtype) - if low_i is None: - raise symbolic_helper._onnx_unsupported("randint", low) - if high_i is None: - raise symbolic_helper._onnx_unsupported("randint", high) - - randn = g.op( - "RandomUniformLike", - self, - low_f=low_i, - high_f=high_i, - ) - - # cast to integer type - int_dtype = _type_utils.JitScalarType.INT64 - randint = g.op("Cast", randn, to_i=int_dtype.onnx_type()) - if int_dtype != scalar_type: - randint = g.op("Cast", randint, to_i=scalar_type.onnx_type()) - return randint - - -@_onnx_symbolic("aten::randn") -def randn(g: jit_utils.GraphContext, shapes, dtype, *options): - dtype = symbolic_helper._get_const(dtype, "i", "dtype") - if dtype is None: - scalar_type = _type_utils.JitScalarType.FLOAT - else: - scalar_type = _type_utils.JitScalarType(dtype) - shape = symbolic_helper._maybe_get_const(shapes, "is") - if symbolic_helper._is_value(shape): - shape_const = g.op( - "ConstantOfShape", - shapes, - value_t=torch.tensor([0], dtype=torch.float), - ) - return g.op( - "RandomNormalLike", - shape_const, - dtype_i=scalar_type.onnx_type(), - ) - 
return g.op( - "RandomNormal", - shape_i=shape, - dtype_i=scalar_type.onnx_type(), - ) - - -@_onnx_symbolic("aten::rand") -def rand(g: jit_utils.GraphContext, shapes, dtype, *options): - dtype = symbolic_helper._get_const(dtype, "i", "dtype") - if dtype is None: - scalar_type = _type_utils.JitScalarType.FLOAT - else: - scalar_type = _type_utils.JitScalarType(dtype) - shape = symbolic_helper._maybe_get_const(shapes, "is") - if symbolic_helper._is_value(shape): - shape_const = g.op( - "ConstantOfShape", - shapes, - value_t=torch.tensor([0], dtype=torch.float), - ) - return g.op( - "RandomUniformLike", - shape_const, - dtype_i=scalar_type.onnx_type(), - ) - return g.op( - "RandomUniform", - shape_i=shape, - dtype_i=scalar_type.onnx_type(), - ) - - -@_onnx_symbolic("aten::randn_like") -def randn_like( - g: jit_utils.GraphContext, - self, - dtype, - layout=None, - device=None, - pin_memory=False, - memory_format=None, -): - dtype = symbolic_helper._get_const(dtype, "i", "dtype") - if dtype is None: - scalar_type = _type_utils.JitScalarType.from_value( - self, _type_utils.JitScalarType.FLOAT - ) - else: - scalar_type = _type_utils.JitScalarType(dtype) - return g.op("RandomNormalLike", self, dtype_i=scalar_type.onnx_type()) - - -@_onnx_symbolic("aten::rand_like") -def rand_like( - g: jit_utils.GraphContext, - self, - dtype, - layout=None, - device=None, - pin_memory=False, - memory_format=None, -): - dtype = symbolic_helper._get_const(dtype, "i", "dtype") - if dtype is None: - dtype = _type_utils.JitScalarType.from_value( - self, _type_utils.JitScalarType.FLOAT - ) - return g.op( - "RandomUniformLike", self, dtype_i=_type_utils.JitScalarType(dtype).onnx_type() - ) - - -@_onnx_symbolic("aten::rrelu") -@symbolic_helper.parse_args("v", "f", "f", "i", "none") -def rrelu(g: jit_utils.GraphContext, input, lower, upper, training, generator): - if not training: - slope = (upper + lower) / 2.0 - return g.op("LeakyRelu", input, alpha_f=slope) - p = g.op("RandomUniformLike", input, high_f=upper, low_f=lower) - return g.op("PRelu", input, p) - - -@_onnx_symbolic("aten::bernoulli") -def bernoulli(g: jit_utils.GraphContext, input, p=None, generator=None, out=None): - if out is not None and not symbolic_helper._is_none(out): - symbolic_helper._unimplemented( - "Bernoulli", "out parameter is not supported for bernoulli", input - ) - if generator is not None and not symbolic_helper._is_none(generator): - symbolic_helper._unimplemented( - "Bernoulli", "generator is not supported for bernoulli", input - ) - - dtype = _type_utils.JitScalarType.from_value( - input, _type_utils.JitScalarType.UNDEFINED - ) - if dtype == _type_utils.JitScalarType.UNDEFINED: - return symbolic_helper._unimplemented( - "Bernoulli", "input dtype not accessible", input - ) - - rands = g.op( - "RandomUniformLike", - input, - high_f=1.0, - low_f=0.0, - dtype_i=dtype.onnx_type(), - ) - prob = p if p is not None and not symbolic_helper._is_none(p) else input - output = g.op("Less", rands, prob) - return g.op("Cast", output, to_i=dtype.onnx_type()) - - -@_onnx_symbolic("aten::log_sigmoid") -@symbolic_helper.parse_args("v") -def log_sigmoid(g: jit_utils.GraphContext, input): - p = g.op("Sigmoid", input) - return g.op("Log", p) - - -@_onnx_symbolic("aten::erf") -@symbolic_helper.parse_args("v") -def erf(g: jit_utils.GraphContext, input): - return g.op("Erf", input) - - -@_onnx_symbolic("aten::flatten") -@symbolic_helper.quantized_args(True, False, False) -@symbolic_helper.parse_args("v", "i", "i") -def flatten(g: jit_utils.GraphContext, input, 
start_dim, end_dim): - dim = symbolic_helper._get_tensor_rank(input) - if dim is None: - return symbolic_helper._unimplemented( - "dim", - "ONNX and PyTorch use different strategies to split the input. " - "Input rank must be known at export time.", - input, - ) - - if dim == 0: - return symbolic_helper._reshape_helper(g, input, [1]) - if dim == 1: - return g.op("Identity", input) - # TODO: remove this as onnx opset 11 spec allows negative axes - if end_dim < 0: - end_dim = dim + end_dim - # use ONNX's Flatten operator for cases where the output shape is 2D - if start_dim == 1 and end_dim == dim - 1: - return g.op("Flatten", input, axis_i=start_dim) - if start_dim == 0 and end_dim == dim - 2: - return g.op("Flatten", input, axis_i=end_dim + 1) - - return symbolic_helper._flatten_helper(g, input, start_dim, end_dim, dim) - - -@_onnx_symbolic("aten::nonzero") -@symbolic_helper.parse_args("v") -def nonzero(g: jit_utils.GraphContext, input): - """Emitted from `torch.nonzero(x, as_tuple=False)`""" - return t(g, g.op("NonZero", input)) - - -@_onnx_symbolic("aten::nonzero_numpy") -# Emitted from `torch.nonzero(x, as_tuple=True)` -def nonzero_numpy(g: jit_utils.GraphContext, input, _outputs=None): - return unbind(g, nonzero(g, input), 1, _outputs=_outputs) - - -@_onnx_symbolic("aten::isnan") -@symbolic_helper.parse_args("v") -def isnan(g: jit_utils.GraphContext, input): - output = g.op("IsNaN", input) - return output - - -@_onnx_symbolic("aten::any") -def _any(g: jit_utils.GraphContext, *args): - # aten::any(Tensor self) - if len(args) == 1: - input = args[0] - dim, keepdim = None, 0 - # aten::any(Tensor self, int[]? dim, bool keepdim) - else: - input, dim, keepdim = args - # Can be int list or single int - dim = symbolic_helper._parse_arg(dim, "t") - dim = [int(d) for d in dim.view(-1)] - keepdim = symbolic_helper._parse_arg(keepdim, "i") - input = g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT64) - input_sum = symbolic_helper._reducesum_helper( - g, input, axes_i=dim, keepdims_i=keepdim - ) - return gt(g, input_sum, g.op("Constant", value_t=torch.tensor(0, dtype=torch.long))) - - -@_onnx_symbolic("aten::all") -def _all(g: jit_utils.GraphContext, *args): - input = g.op("Not", args[0]) - # aten::all(Tensor self) - if len(args) == 1: - return g.op("Not", _any(g, input)) - # aten::all(Tensor self, int[]? 
dim, bool keepdim) - else: - return g.op("Not", _any(g, input, args[1], args[2])) - - -@_onnx_symbolic("aten::narrow") -@symbolic_helper.parse_args("v", "i", "i", "i") -def narrow(g: jit_utils.GraphContext, input, dim, start, length): - return symbolic_helper._slice_helper( - g, input, axes=[dim], starts=[start], ends=[start + length] - ) - - -@_onnx_symbolic("aten::argmax") -@symbolic_helper.parse_args("v", "v", "b") -def argmax( - g: jit_utils.GraphContext, - input: torch._C.Value, - dim: torch._C.Value, - keepdim: bool, -): - return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMax") - - -@_onnx_symbolic("aten::argmin") -@symbolic_helper.parse_args("v", "v", "b") -def argmin( - g: jit_utils.GraphContext, - input: torch._C.Value, - dim: torch._C.Value, - keepdim: bool, -): - return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMin") - - -@_onnx_symbolic("aten::scatter") -@symbolic_helper.parse_args("v", "i", "v", "v") -def scatter(g: jit_utils.GraphContext, self, dim, index, src): - src_type = _type_utils.JitScalarType.from_value( - src, _type_utils.JitScalarType.UNDEFINED - ) - src = symbolic_helper._maybe_get_scalar(src) - if symbolic_helper._is_value(src): - return g.op("Scatter", self, index, src, axis_i=dim) - else: - # Check if scalar "src" has same type as self (PyTorch allows different - # type for scalar src (but not when src is tensor)). If not, insert Cast node. - self_scalar_type = _type_utils.JitScalarType.from_value(self) - if self_scalar_type != src_type: - src = g.op("Cast", src, to_i=self_scalar_type.onnx_type()) - return g.op("Scatter", self, index, expand_as(g, src, index), axis_i=dim) - - -@_onnx_symbolic("aten::scatter_add") -@symbolic_helper.parse_args("v", "i", "v", "v") -def scatter_add(g: jit_utils.GraphContext, self, dim, index, src): - scalar_type = symbolic_helper._try_get_scalar_type(self) - if scalar_type is None: - return symbolic_helper._unimplemented( - "scatter_add", "input dtype not accessible", self - ) - sizes = symbolic_helper._get_tensor_sizes(self, allow_nonstatic=False) - if sizes: - to_add = g.op("Constant", value_t=torch.zeros(sizes, dtype=scalar_type.dtype())) - else: - to_add = zeros_like(g, self, scalar_type) - to_add = symbolic_helper._scatter_helper(g, to_add, dim, index, src) - return add(g, self, to_add) - - -@_onnx_symbolic("aten::log2") -def log2(g: jit_utils.GraphContext, self): - _ln2 = 0.693147180559945309 - return g.op("Div", log(g, self), g.op("Constant", value_t=torch.tensor(_ln2))) - - -@_onnx_symbolic("aten::is_floating_point") -def is_floating_point(g: jit_utils.GraphContext, self): - if symbolic_helper._is_fp(self): - return g.op("Constant", value_t=torch.BoolTensor([1])) - return g.op("Constant", value_t=torch.BoolTensor([0])) - - -@_onnx_symbolic("aten::__is_") -def __is_(g: jit_utils.GraphContext, self, other): - if symbolic_helper._is_none(other): - if symbolic_helper._is_none(self): - return g.op("Constant", value_t=torch.BoolTensor([1])) - return g.op("Constant", value_t=torch.BoolTensor([0])) - return eq(g, self, other) - - -@_onnx_symbolic("aten::__isnot_") -@wrap_logical_op_with_negation -def __isnot_(g: jit_utils.GraphContext, self, other): - return __is_(g, self, other) - - -@_onnx_symbolic("aten::one_hot") -def one_hot(g: jit_utils.GraphContext, self, num_classes): - values = g.op("Constant", value_t=torch.LongTensor([0, 1])) - # onnxruntime supports limited type combinations for OneHot. 
- if _type_utils.JitScalarType.from_value( - num_classes, _type_utils.JitScalarType.UNDEFINED - ) in { - _type_utils.JitScalarType.UINT8, - _type_utils.JitScalarType.INT8, - _type_utils.JitScalarType.INT, - _type_utils.JitScalarType.INT16, - }: - num_classes = g.op("Cast", num_classes, to_i=_C_onnx.TensorProtoDataType.INT64) - return g.op("OneHot", self, num_classes, values, axis_i=-1) - - -@_onnx_symbolic("aten::gather") -@symbolic_helper.parse_args("v", "i", "v", "v") -def gather(g: jit_utils.GraphContext, self, dim, index, sparse_grad=False): - if symbolic_helper._maybe_get_const(sparse_grad, "i"): - return symbolic_helper._unimplemented("gather", "sparse_grad == True", self) - # NOTE: This workaround is needed since GatherElement is only supported - # since opset 11, and Gather in ONNX is not the same as torch.gather. - scalar_type = _type_utils.JitScalarType.from_value(self) - values = g.op("Constant", value_t=torch.LongTensor([0, 1])) - depth = size(g, self, g.op("Constant", value_t=torch.LongTensor([dim]))) - index = g.op( - "Cast", - g.op("OneHot", index, depth, values, axis_i=dim), - to_i=scalar_type.onnx_type(), - ) - mul = g.op("Mul", symbolic_helper._unsqueeze_helper(g, self, [dim + 1]), index) - return symbolic_helper._reducesum_helper(g, mul, axes_i=[dim], keepdims_i=0) - - -@symbolic_helper.parse_args("v", "is", "i", "i") -def _var_mean(g: jit_utils.GraphContext, input, dim, correction, keepdim): - return symbolic_helper._var_mean_helper(g, input, dim, correction, keepdim) - - -@_onnx_symbolic("aten::std") -def std(g: jit_utils.GraphContext, input, *args): - var, _ = var_mean(g, input, *args) - return g.op("Sqrt", var) - - -@_onnx_symbolic("aten::var") -def var(g: jit_utils.GraphContext, input, *args): - var, _ = var_mean(g, input, *args) - return var - - -@_onnx_symbolic("aten::var_mean") -def var_mean(g: jit_utils.GraphContext, input, *args): - if len(args) == 1: - return _var_mean(g, input, None, args[0], None) - else: - return _var_mean(g, input, *args) - - -@_onnx_symbolic("aten::std_mean") -def std_mean(g: jit_utils.GraphContext, input, *args): - var, mean = var_mean(g, input, *args) - return g.op("Sqrt", var), mean - - -@_onnx_symbolic("aten::logsumexp") -@symbolic_helper.parse_args("v", "is", "i") -def logsumexp(g: jit_utils.GraphContext, input, dim, keepdim): - return g.op("ReduceLogSumExp", input, axes_i=dim, keepdims_i=keepdim) - - -@_onnx_symbolic("aten::arange") -def arange(g: jit_utils.GraphContext, *args): - def _get_arange_dtype(dtype): - dtype = symbolic_helper._maybe_get_const(dtype, "i") - return dtype - - def _float_step_convert(range_tensor): - if symbolic_helper._is_fp(range_tensor): - range_tensor = g.op( - "Cast", - g.op("Ceil", range_tensor), - to_i=_type_utils.JitScalarType.INT64.onnx_type(), - ) - return range_tensor - - if len(args) == 2 or len(args) == 5: - if len(args) == 2: - # aten::arange(Scalar end, Tensor out) - dtype = None - else: - # aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory) - dtype = _get_arange_dtype(args[1]) - dtype, end, start, step = symbolic_helper._arange_cast_helper( - g, end=args[0], dtype=dtype - ) - end = symbolic_helper._unsqueeze_helper(g, end, [0]) - range_tensor = _float_step_convert(end) - arange_tensor = symbolic_helper._squeeze_helper( - g, nonzero(g, ones(g, range_tensor, dtype, None, None)), [1] - ) - return g.op( - "Cast", arange_tensor, to_i=_type_utils.JitScalarType(dtype).onnx_type() - ) - elif len(args) == 4 or len(args) == 7: - if len(args) == 4: - # aten::arange(Scalar start, 
Scalar end, Scalar step, Tensor out) - dtype = None - else: - # aten::arange(Scalar start, Scalar end, Scalar step, ScalarType dtype, Layout, Device, bool pin_memory) - dtype = _get_arange_dtype(args[3]) - dtype, end, start, step = symbolic_helper._arange_cast_helper( - g, start=args[0], end=args[1], step=args[2], dtype=dtype - ) - step = symbolic_helper._unsqueeze_helper(g, step, [0]) - end = symbolic_helper._unsqueeze_helper(g, end, [0]) - start = symbolic_helper._unsqueeze_helper(g, start, [0]) - range_tensor = _float_step_convert(g.op("Div", g.op("Sub", end, start), step)) - arange_tensor = symbolic_helper._squeeze_helper( - g, nonzero(g, ones(g, range_tensor, None, None, None)), [1] - ) - arange_tensor = g.op("Add", g.op("Mul", arange_tensor, step), start) - return g.op( - "Cast", arange_tensor, to_i=_type_utils.JitScalarType(dtype).onnx_type() - ) - elif len(args) == 6: - # aten::arange(Scalar start, Scalar end, ScalarType dtype, Layout, Device, bool pin_memory) - dtype = _get_arange_dtype(args[2]) - dtype, end, start, step = symbolic_helper._arange_cast_helper( - g, start=args[0], end=args[1], dtype=dtype - ) - end = symbolic_helper._unsqueeze_helper(g, end, [0]) - start = symbolic_helper._unsqueeze_helper(g, start, [0]) - range_tensor = _float_step_convert(g.op("Sub", end, start)) - arange_tensor = g.op( - "Add", - symbolic_helper._squeeze_helper( - g, nonzero(g, ones(g, range_tensor, dtype, *(args[3:]))), [1] - ), - start, - ) - return g.op( - "Cast", arange_tensor, to_i=_type_utils.JitScalarType(dtype).onnx_type() - ) - - return symbolic_helper._unimplemented("aten::arange", f"with {len(args)} arguments") - - -@_onnx_symbolic("aten::linspace") -def linspace( - g: jit_utils.GraphContext, start, end, steps, dtype, layout, device, pin_memory -): - range_tensor = symbolic_helper._arange_helper(g, steps, None) - step = div( - g, - sub(g, end, start), - sub(g, steps, g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))), - ) - return add(g, mul(g, range_tensor, step), start) - - -@_onnx_symbolic("aten::lift") -def lift(g: jit_utils.GraphContext, self): - # at::lift() is a no-op from the perspective of tracing for onnx - return self - - -@_onnx_symbolic("aten::masked_fill") -def masked_fill(g: jit_utils.GraphContext, self, mask, value): - """Implement the masked_fill functionality available for a pytorch tensor in ONNX. - - Fills elements of the input tensor with `value` where `mask` is True. - """ - mask = g.op("Cast", mask, to_i=_C_onnx.TensorProtoDataType.BOOL) - value = symbolic_helper._maybe_get_scalar(value) - return g.op("Where", mask, symbolic_helper._if_scalar_type_as(value, self), self) - - -@_onnx_symbolic("aten::masked_fill_") -def masked_fill_(g: jit_utils.GraphContext, self, mask, value): - return masked_fill(g, self, mask, value) - - -@_onnx_symbolic("aten::index") -def index(g: jit_utils.GraphContext, self, index): - if symbolic_helper._is_packed_list(index): - indices = symbolic_helper._unpack_list(index) - else: - indices = [index] - - def try_mask_to_index(index): - if not symbolic_helper._is_none(index) and ( - _type_utils.JitScalarType.from_value( - index, _type_utils.JitScalarType.UNDEFINED - ) - == _type_utils.JitScalarType.UINT8 - or symbolic_helper._is_bool(index) - ): - if g.opset < 9: - raise errors.SymbolicValueError( - "Exporting masked indices are only supported after ONNX opset 9.", - self, - ) - warnings.warn( - "Exporting aten::index operator with indices of type Byte. " - "Only 1-D indices are supported. 
In any other case, " - "this will produce an incorrect ONNX graph." - ) - index = symbolic_helper._squeeze_helper(g, nonzero(g, index), [1]) - return index - - indices = [try_mask_to_index(idx) for idx in indices] - if len(indices) == 1: - return symbolic_helper._select_helper( - g, self, 0, indices[0], apply_reshape=False - ) - else: - # Multiple tensors as indices. Each tensor could either be - # 1. prim::Constant() - # representing ":" in python indexing. E.g. tensor[:, :] - # 2. prim::Constant[value=...] or tensor output - # representing advanced indexing. E.g. tensor[[0, 1], [2, 0]]. - # For more info on advanced indexing, - # check https://numpy.org/doc/stable/user/basics.indexing.html#advanced-indexing - - # Consider a general case of - # t: [x_1, y_1, y_2, ..., x_m, ..., y_n] - # where t is a tensor of rank m+n, {x_i} are axes where tensor index is provided, and {y_i} are axes for ":". - # Same results can be achieved through transposing t into - # t: [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n] - # and use gatherND. However ONNX does not have gatherND, to use 1d gather we'll need to flatten t - # and process the tensor indices. - # t: [x_1 * x_2 * ... * x_m, y_1 * y_2 * ... * y_n] - # tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)) - # After gather, reshape and transpose back. - adv_idx_indices = [ - i for i, idx in enumerate(indices) if not symbolic_helper._is_none(idx) - ] - - if len(adv_idx_indices) == 0: - return self - elif len(adv_idx_indices) == 1: - return index_select( - g, self, adv_idx_indices[0], indices[adv_idx_indices[0]] - ) - else: - rank = symbolic_helper._get_tensor_rank(self) - if rank is None: - return symbolic_helper._unimplemented( - "aten::index", - "operator of advanced indexing on tensor of unknown rank. ", - self, - ) - # TODO: If indexing is supported natively in ONNX in future opsets, - # update the warning to recommend exporting with higher opset version. - warnings.warn( - "Exporting aten::index operator of advanced indexing in opset " - f"{GLOBALS.export_onnx_opset_version}" - " is achieved by combination of multiple ONNX operators, " - "including Reshape, Transpose, Concat, and Gather. " - "If indices include negative values, the exported graph will produce incorrect results." - ) - adv_idx_count = len(adv_idx_indices) - shape_tensor = _shape_as_tensor(g, self) - dim_tensor_list = [ - g.op( - "Gather", - shape_tensor, - g.op("Constant", value_t=torch.LongTensor([dim])), - axis_i=0, - ) - for dim in range(rank) - ] - - self = g.op( - "Transpose", - self, - perm_i=adv_idx_indices - + [i for i in range(rank) if i not in adv_idx_indices], - ) - self = g.op("Flatten", self, axis_i=adv_idx_count) - - # Note that tensor indices will be broadcasted while accumulating. Thus we get the final subarray shape as well. - cum_adv_index = indices[adv_idx_indices[-1]] - multiplier = dim_tensor_list[adv_idx_indices[-1]] - for i in range(adv_idx_count - 2, -1, -1): - adv_index = g.op("Mul", indices[adv_idx_indices[i]], multiplier) - cum_adv_index = g.op("Add", cum_adv_index, adv_index) - multiplier = g.op( - "Mul", multiplier, dim_tensor_list[adv_idx_indices[i]] - ) - - # perform gather - self = index_select(g, self, 0, cum_adv_index) - - cum_adv_index_shape_tensor = _shape_as_tensor(g, cum_adv_index) - # check if all advanced indices are consecutive. - # Refer to https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing - # to understand how the subarray position is decided. 
- if adv_idx_indices == list( - range(adv_idx_indices[0], adv_idx_indices[-1] + 1) - ): - # unfold regular index axes - folded_adv_idx_shape_list = [ - g.op("Constant", value_t=torch.LongTensor([-1])) - ] + [ - dim_tensor_list[i] for i in range(rank) if i not in adv_idx_indices - ] - folded_adv_idx_shape = g.op( - "Concat", *folded_adv_idx_shape_list, axis_i=0 - ) - self = symbolic_helper._reshape_helper(g, self, folded_adv_idx_shape) - - # Transpose folded advanced indexed axis to its original location. - adv_idx_permute = ( - list(range(1, adv_idx_indices[0] + 1)) - + [0] - + list(range(adv_idx_indices[0] + 1, rank - adv_idx_count + 1)) - ) - self = g.op("Transpose", self, perm_i=adv_idx_permute) - - # unfold advanced index axes - final_shape_list = ( - [dim_tensor_list[i] for i in range(adv_idx_indices[0])] - + [cum_adv_index_shape_tensor] - + [ - dim_tensor_list[i] - for i in range(adv_idx_indices[0], rank) - if i not in adv_idx_indices - ] - ) - final_shape = g.op("Concat", *final_shape_list, axis_i=0) - else: - final_shape = g.op( - "Concat", - cum_adv_index_shape_tensor, - *[ - dim_tensor_list[i] - for i in range(rank) - if i not in adv_idx_indices - ], - axis_i=0, - ) - - return symbolic_helper._reshape_helper(g, self, final_shape) - - -@_onnx_symbolic("aten::linalg_norm") -@symbolic_helper.parse_args("v", "v", "is", "b", "v") -def linalg_norm( - g: jit_utils.GraphContext, - self: torch._C.Value, - ord: torch._C.Value, - dim: Sequence[int] | None, - keepdim: bool, - dtype: torch._C.Value, -): - # Conditions based on https://pytorch.org/docs/stable/generated/torch.linalg.norm.html - ord_value = None - if dim is None: - if symbolic_helper._is_none(ord): - self = symbolic_helper._reshape_helper(g, self, [-1]) - ord = g.op("Constant", value_t=torch.LongTensor([2])) - self_dim = symbolic_helper._get_tensor_rank(self) - if self_dim is None: - return symbolic_helper._unimplemented( - "dim", "Input rank must be known at export time.", self - ) - if self_dim == 1: - ord_value = symbolic_helper._parse_arg(ord, "f") - else: - dim = [0, 1] - else: - if len(dim) == 1: - if symbolic_helper._is_none(ord): - ord = g.op("Constant", value_t=torch.LongTensor([2])) - ord_value = symbolic_helper._parse_arg(ord, "f") - if ord_value: - return linalg_vector_norm(g, self, ord_value, dim, keepdim, dtype) - return linalg_matrix_norm(g, self, ord, dim, keepdim, dtype) - - -@_onnx_symbolic("aten::linalg_vector_norm") -@symbolic_helper.parse_args("v", "f", "is", "b", "v") -def linalg_vector_norm( - g: jit_utils.GraphContext, - self: torch._C.Value, - ord: float, - dim: Sequence[int] | None, - keepdim: bool, - dtype: torch._C.Value, -): - return symbolic_helper._linalg_vector_norm_helper(g, self, ord, dim, keepdim, dtype) - - -@_onnx_symbolic("aten::linalg_matrix_norm") -@symbolic_helper.parse_args("v", "v", "is", "b", "v") -def linalg_matrix_norm( - g: jit_utils.GraphContext, - self: torch._C.Value, - ord: torch._C.Value, - dim: list[int], - keepdim: bool, - dtype: torch._C.Value, -): - # Conditions based on https://pytorch.org/docs/stable/generated/torch.linalg.matrix_norm.html - ord_value = symbolic_helper._parse_arg(ord, "s") - if ord_value == "fro": - return frobenius_norm(g, self, dim, keepdim) - elif ord_value == "nuc": - return symbolic_helper._unimplemented("linalg.matrix_norm", "ord==nuc", self) - else: - ord_value = symbolic_helper._parse_arg(ord, "f") - if ord_value is None: - return frobenius_norm(g, self, dim, keepdim) - if ord_value == 2 or ord_value == -2: - # ord = 2/-2 unimplemented due to lack 
of operators - # used to calculate singular values - return symbolic_helper._unimplemented("linalg.matrix_norm", "ord==2", self) - # Wrap the dim vector to handle negative dim values - self_dim = symbolic_helper._get_tensor_rank(self) - if self_dim is None: - return symbolic_helper._unimplemented( - "linalg.matrix_norm", "Input rank must be known at export time.", self - ) - # Common implementation for cases with - # ord = 1/-1 and ord = inf/-inf - if dim[0] < 0: - dim[0] += self_dim - if dim[1] < 0: - dim[1] += self_dim - - if ord_value == math.inf or ord_value == -math.inf: - dim[0], dim[1] = dim[1], dim[0] - if dim[1] > dim[0] and not keepdim: - dim[1] -= 1 - sum = symbolic_helper._reducesum_helper( - g, g.op("Abs", self), axes_i=[dim[0]], keepdims_i=keepdim - ) - if ord_value > 0: - result, _indices = max( - g, - sum, - dim_or_y=g.op("Constant", value_t=torch.LongTensor([dim[1]])), - keepdim=keepdim, - ) - else: - result, _indices = min( - g, - sum, - dim_or_y=g.op("Constant", value_t=torch.LongTensor([dim[1]])), - keepdim=keepdim, - ) - return result - - -@_onnx_symbolic("aten::linalg_cross") -@symbolic_helper.parse_args("v", "v", "i") -def linalg_cross(g: jit_utils.GraphContext, input, other, dim=-1): - return cross(g, input, other, dim) - - -@_onnx_symbolic("aten::frobenius_norm") -@symbolic_helper.parse_args("v", "is", "b") -def frobenius_norm(g: jit_utils.GraphContext, self, dim=None, keepdim=False): - sqr = g.op("Mul", self, self) - sumsqr = symbolic_helper._reducesum_helper(g, sqr, axes_i=dim, keepdims_i=keepdim) - return g.op("Sqrt", sumsqr) - - -@_onnx_symbolic("aten::multinomial") -@symbolic_helper.parse_args("v", "i", "b", "v") -def multinomial( - g: jit_utils.GraphContext, input, num_samples, replacement=False, generator=None -): - if generator is not None and not symbolic_helper._is_none(generator): - symbolic_helper._unimplemented( - "Multinomial", "generator is not supported for multinomial", input - ) - if not replacement and num_samples > 1: - symbolic_helper._unimplemented( - "Multinomial", - "replacement=False when num_samples > 1 is not supported for multinomial", - input, - ) - - log_input = log(g, input) - return g.op( - "Multinomial", - log_input, - dtype_i=_C_onnx.TensorProtoDataType.INT64, - sample_size_i=num_samples, - ) - - -@_onnx_symbolic("aten::baddbmm") -def baddbmm(g: jit_utils.GraphContext, self, batch1, batch2, beta, alpha): - scalar_type = _type_utils.JitScalarType.from_value(self) - batch_mul = matmul(g, batch1, batch2) - mul_a = mul( - g, - batch_mul, - g.op("Cast", alpha, to_i=scalar_type.onnx_type()), - ) - mul_b = mul( - g, - self, - g.op("Cast", beta, to_i=scalar_type.onnx_type()), - ) - return add(g, mul_a, mul_b) - - -@_onnx_symbolic("aten::meshgrid") -@symbolic_helper.parse_args("v", "s") -def meshgrid(g: jit_utils.GraphContext, tensor_list, indexing: str | None = None): - if indexing is None: - indexing = "ij" - elif indexing not in {"ij", "xy"}: - raise errors.SymbolicValueError( - f"Unsupported indexing: {indexing}", tensor_list - ) - unpacked_tensor_list = symbolic_helper._unpack_list(tensor_list) - if indexing == "xy": - unpacked_tensor_list[:2] = unpacked_tensor_list[1::-1] - tensors = [ - symbolic_helper._reshape_helper( - g, t, g.op("Constant", value_t=torch.LongTensor([-1])) - ) - for t in unpacked_tensor_list - ] - tensors_shape = [g.op("Shape", t) for t in tensors] - out_shape = g.op("Concat", *tensors_shape, axis_i=0) - out = [] - for i, t in enumerate(tensors): - shape_i = [g.op("Constant", value_t=torch.ones(1, 
dtype=torch.int64))] * len( - tensors - ) - shape_i[i] = tensors_shape[i] - t_reshaped = _reshape_from_tensor(g, t, g.op("Concat", *shape_i, axis_i=0)) - out.append(g.op("Expand", t_reshaped, out_shape)) - if indexing == "xy": - out[0], out[1] = out[1], out[0] - return g.op("prim::ListConstruct", *out) - - -@_onnx_symbolic("aten::remainder") -def remainder(g: jit_utils.GraphContext, input, other): - div = _floor_divide(g, input, other) - quo = g.op("Mul", div, other) - return g.op("Sub", input, quo) - - -@_onnx_symbolic("aten::gelu") -@symbolic_helper.parse_args("v", "s") -def gelu(g: jit_utils.GraphContext, self: torch._C.Value, approximate: str = "none"): - if approximate == "tanh": - kBeta = math.sqrt(2 / math.pi) - kKappa = 0.044715 - - beta = torch.tensor(kBeta, dtype=torch.double) - kappa = torch.tensor(kKappa, dtype=torch.double) - one = torch.tensor(1.0, dtype=torch.double) - half = torch.tensor(0.5, dtype=torch.double) - - self_cube = mul(g, self, mul(g, self, self)) - inner = mul(g, beta, add(g, self, mul(g, kappa, self_cube))) - return mul(g, half, mul(g, self, add(g, one, g.op("Tanh", inner)))) - else: - _sqrt2 = 1.4142135623730951 - erf = g.op("Erf", g.op("Div", self, torch.tensor(_sqrt2, dtype=torch.double))) - erf_plusone = add( - g, erf, g.op("Constant", value_t=torch.tensor(1, dtype=torch.double)) - ) - return mul( - g, - mul(g, self, erf_plusone), - g.op("Constant", value_t=torch.tensor(0.5, dtype=torch.double)), - ) - - -@_onnx_symbolic("aten::group_norm") -@symbolic_helper.quantized_args(True, False, False, False) -@symbolic_helper.parse_args("v", "i", "v", "v", "f", "i") -def group_norm( - g: jit_utils.GraphContext, input, num_groups, weight, bias, eps, cudnn_enabled -): - channel_size = symbolic_helper._get_tensor_dim_size(input, 1) - if channel_size is not None: - assert channel_size % num_groups == 0 - input_rank = symbolic_helper._get_tensor_rank(input) - if input_rank is None: - return symbolic_helper._unimplemented("group_norm", "unknown input rank", input) - # 0 in the shape list keeps dimension value unchanged. - shape = [0, num_groups, -1] - input_reshaped = symbolic_helper._reshape_helper( - g, input, g.op("Constant", value_t=torch.LongTensor(shape)) - ) - - # C is always divisible by num_groups - # Due to shape difference. 
we need to apply weight and bias after - # instance norm computation and reshape - weight_ = g.op( - "Constant", - value_t=torch.tensor( - [1.0] * num_groups, - dtype=_type_utils.JitScalarType.from_value(input).dtype(), - ), - ) - bias_ = g.op( - "Constant", - value_t=torch.tensor( - [0.0] * num_groups, - dtype=_type_utils.JitScalarType.from_value(input).dtype(), - ), - ) - - norm_reshaped = g.op( - "InstanceNormalization", input_reshaped, weight_, bias_, epsilon_f=eps - ) - norm = symbolic_helper._reshape_helper(g, norm_reshaped, g.op("Shape", input)) - - if weight is None or weight.node().mustBeNone(): - weight_value = torch.tensor( - [1.0], dtype=_type_utils.JitScalarType.from_value(input).dtype() - ) - weight = g.op("Constant", value_t=weight_value) - if bias is None or bias.node().mustBeNone(): - bias_value = torch.tensor( - [0.0], dtype=_type_utils.JitScalarType.from_value(input).dtype() - ) - bias = g.op("Constant", value_t=bias_value) - - # Norm has shape [N, C, *] so we reshape weight and bias to [C, *] - axes = list(range(1, input_rank - 1)) - return add( - g, - mul(g, norm, symbolic_helper._unsqueeze_helper(g, weight, axes)), - symbolic_helper._unsqueeze_helper(g, bias, axes), - ) - - -@_onnx_symbolic("aten::_weight_norm") -@symbolic_helper.parse_args("v", "v", "i") -def _weight_norm(g: jit_utils.GraphContext, weight_v, weight_g, dim): - rank = symbolic_helper._get_tensor_rank(weight_v) - if rank is not None: - # W = g * ((v) / ||v||) - # Compute norm_except_dim for l2 norm. dim = None means over all dims - # torch's weight_norm module sets dim = -1 if it's None. - # This conflicts the logic for negative axes to access dims backwards - # TODO: Might need a fix in torch group_norm module - axes = list(range(rank)) - if dim is not None: - if dim < -1: - dim += rank - if dim != -1: - axes.remove(dim) - norm_v = norm(g, weight_v, 2, axes, 1) - div = g.op("Div", weight_v, norm_v) - return g.op("Mul", div, weight_g) - raise errors.SymbolicValueError( - "Unsupported: ONNX export of _weight_norm for tensor of unknown rank.", - weight_v, - ) - - -@_onnx_symbolic("aten::dim") -def dim(g: jit_utils.GraphContext, self): - """Implement the dim functionality available for a pytorch tensor in ONNX""" - # ONNX does not support dim directly in this opset so we can use 2 ops to get the info - shape = g.op("Shape", self) - return g.op("Size", shape) - - -@_onnx_symbolic("aten::__contains_") -def __contains_(g: jit_utils.GraphContext, self, element): - unpacked_list = symbolic_helper._unpack_list(self) - if all( - symbolic_helper._is_constant(x) for x in unpacked_list - ) and symbolic_helper._is_constant(element): - return g.op( - "Constant", - value_t=torch.tensor( - symbolic_helper._node_get(element.node(), "value") - in (symbolic_helper._node_get(x.node(), "value") for x in unpacked_list) - ), - ) - - raise errors.SymbolicValueError( - "Unsupported: ONNX export of __contains__ for non-constant list or element.", - self, - ) - - -@_onnx_symbolic("aten::__getitem_") -def __getitem_(g: jit_utils.GraphContext, self, i): - return select(g, self, g.op("Constant", value_t=torch.tensor([0])), i) - - -@_onnx_symbolic("aten::item") -def item(g: jit_utils.GraphContext, self): - return self - - -@_onnx_symbolic("aten::take") -def take(g: jit_utils.GraphContext, self, index): - self_flattened = symbolic_helper._reshape_helper( - g, self, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)) - ) - out = index_select(g, self_flattened, 0, index) - out = reshape_as(g, out, index) - return out - - -def 
_kl_div_log_target_impl(g: jit_utils.GraphContext, input, target): - diff_ = sub(g, target, input) - exp_ = exp(g, target) - output = mul(g, exp_, diff_) - return output - - -def _kl_div_non_log_target_impl(g: jit_utils.GraphContext, input, target): - log_ = log(g, target) - diff_ = sub(g, log_, input) - output_pos = mul(g, target, diff_) - zeros_ = zeros_like(g, output_pos) - mask_ = gt(g, target, g.op("Constant", value_t=torch.tensor(0))) - output = where(g, mask_, output_pos, zeros_) - return output - - -@_onnx_symbolic("aten::kl_div") -@symbolic_helper.parse_args("v", "v", "i", "b") -def kl_div(g: jit_utils.GraphContext, input, target, reduction, log_target): - if log_target: - output = _kl_div_log_target_impl(g, input, target) - else: - output = _kl_div_non_log_target_impl(g, input, target) - - if reduction == 0: - return output - elif reduction == 1: - return g.op("ReduceMean", output, keepdims_i=0) - elif reduction == 2: - return symbolic_helper._reducesum_helper(g, output, keepdims_i=0) - else: - return symbolic_helper._onnx_unsupported( - "kl_div with reduction other than none, mean, or sum.", input - ) - - -@_onnx_symbolic("aten::mse_loss") -@symbolic_helper.parse_args("v", "v", "i") -def mse_loss(g: jit_utils.GraphContext, input, target, reduction): - output = mul(g, sub(g, input, target), sub(g, input, target)) - if reduction == 0: - return output - elif reduction == 1: - return g.op("ReduceMean", output, keepdims_i=0) - elif reduction == 2: - return symbolic_helper._reducesum_helper(g, output, keepdims_i=0) - else: - return symbolic_helper._onnx_unsupported( - "mse_loss with reduction other than none, mean, or sum.", input - ) - - -@_onnx_symbolic("aten::as_strided") -@symbolic_helper.quantized_args(True) -@symbolic_helper.parse_args("v", "v", "is", "i") -def as_strided(g: jit_utils.GraphContext, self, sizes, strides, offset=None): - sizes = symbolic_helper._maybe_get_const(sizes, "is") - rank = len(strides) - self_1d = symbolic_helper._reshape_helper( - g, self, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)) - ) - ind: torch.Tensor | None - if not symbolic_helper._is_value(sizes): - ind = torch.tensor([0], dtype=torch.long) - for i, (size, stride) in enumerate(zip(sizes, strides)): - r_size = [1] * rank - r_size[i] = -1 - ind = ind + torch.arange(size).view(r_size) * stride - if offset: - ind = ind + offset - return g.op("Gather", self_1d, g.op("Constant", value_t=ind)) - else: - ind = None - for i, stride in enumerate(strides): - r_size = [1] * rank - r_size[i] = -1 - size = select( - g, - sizes, - g.op("Constant", value_t=torch.tensor([0])), - g.op("Constant", value_t=torch.tensor(i)), - ) - tmp_ind = symbolic_helper._reshape_helper( - g, - arange(g, size, 4, None, None, None), - g.op("Constant", value_t=torch.tensor(r_size)), - ) - tmp_ind = g.op( - "Mul", tmp_ind, g.op("Constant", value_t=torch.tensor([stride])) - ) - if ind is None: - ind = tmp_ind - else: - ind = g.op("Add", ind, tmp_ind) - if offset: - ind = g.op("Add", ind, g.op("Constant", torch.tensor([offset]))) - return g.op("Gather", self_1d, ind) - - -@_onnx_symbolic("aten::__derive_index") -def __derive_index(g: jit_utils.GraphContext, index, start, step): - return g.op("Add", start, g.op("Mul", index, step)) - - -@_onnx_symbolic("aten::__range_length") -# Source code for aten op can be found here: pytorch/torch/csrc/jit/runtime/register_prim_ops.cpp -# if (step > 0 && lo < hi) { -# push(stack, 1 + (hi - 1 - lo) / step); -# } else if (step < 0 && lo > hi) { -# push(stack, 1 + (lo - 1 - hi) / (0 
- step)); -# } else { -# push(stack, 0); -# } -def __range_length(g: jit_utils.GraphContext, lo, hi, step): - sub = g.op("Sub", hi, lo) - div = g.op("Ceil", true_divide(g, sub, step)) - return g.op("Cast", div, to_i=_C_onnx.TensorProtoDataType.INT64) - - -@_onnx_symbolic("aten::linear") -def linear(g: jit_utils.GraphContext, input, weight, bias): - rank = symbolic_helper._get_tensor_rank(input) - weight = t(g, weight) - if rank == 2 and not bias.node().mustBeNone(): - alpha = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)) - beta = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)) - output = addmm(g, bias, input, weight, alpha, beta) - else: - output = matmul(g, input, weight) - if not bias.node().mustBeNone(): - output = add(g, bias, output) - - return output - - -@_onnx_symbolic("aten::hann_window") -@symbolic_helper.parse_args("v", "b", "i", "v", "v", "v", "v") -def hann_window( - g: jit_utils.GraphContext, - window_length, - periodic=True, - dtype: int | None = None, - layout=None, - device=None, - pin_memory=None, - requires_grad=False, -): - if dtype is None: - dtype_ = torch.get_default_dtype() - if not dtype_ or not dtype_.is_floating_point: - dtype_ = torch.float - scalar_type = _type_utils.JitScalarType.from_dtype(dtype_) - else: - scalar_type = _type_utils.JitScalarType(dtype) - - n_array = arange(g, window_length, 4, None, None, None) - output = g.op("Cast", n_array, to_i=_C_onnx.TensorProtoDataType.FLOAT) - output = mul( - g, g.op("Constant", value_t=torch.tensor(math.pi, dtype=torch.float)), output - ) - - if periodic is False: - window_length = sub( - g, window_length, g.op("Constant", value_t=torch.tensor(1, dtype=torch.int)) - ) - output = div(g, output, window_length) - output = g.op( - "Cast", - square(g, sin(g, output)), - to_i=scalar_type.onnx_type(), - ) - - return output - - -@_onnx_symbolic("aten::mv") -def mv(g: jit_utils.GraphContext, self, vec): - return matmul(g, self, vec) - - -@_onnx_symbolic("aten::dot") -def dot(g: jit_utils.GraphContext, self, other): - return matmul(g, self, other) - - -@_onnx_symbolic("aten::movedim") -@symbolic_helper.parse_args("v", "t", "t") -def movedim(g: jit_utils.GraphContext, self, source, destination): - # This is a pythonic implementation mostly taken from aten/src/ATen/native/TensorShape.cpp::movedim - source = source.view(-1) - destination = destination.view(-1) - - assert source.size() == destination.size() - - if (source == destination).all(): - return self - - self_rank = symbolic_helper._get_tensor_rank(self) - assert self_rank is not None - - perm = list(range(self_rank)) - - src_dims = perm.copy() - dst_dims = perm.copy() - - for src, dst in zip(source.tolist(), destination.tolist()): - perm[dst] = src - src_dims[src] = -1 - dst_dims[dst] = -1 - - src_dims = [dim for dim in src_dims if dim != -1] - dst_dims = [dim for dim in dst_dims if dim != -1] - - for src, dst in zip(src_dims, dst_dims): - perm[dst] = src - - return g.op("Transpose", self, perm_i=perm) - - -@_onnx_symbolic("aten::fill") -@symbolic_helper.parse_args("v", "v") -def fill(g: jit_utils.GraphContext, self, value): - scalar_type = _type_utils.JitScalarType.from_value( - self, _type_utils.JitScalarType.FLOAT - ) - return full_like(g, self, value, scalar_type) - - -@_onnx_symbolic("aten::index_add") -def index_add(g: jit_utils.GraphContext, self, dim, index, other, alpha=None): - warnings.warn( - "Warning: ONNX export does not support duplicated values in 'index' field, " - + "this will cause the ONNX model to be incorrect." 
- ) - - # ONNX does not support "alpha" argument, unlike aten index_add - # See: https://github.com/pytorch/pytorch/pull/65993#issuecomment-953151102 for more context - if alpha and symbolic_helper._scalar(symbolic_helper._maybe_get_scalar(alpha)) != 1: - return symbolic_helper._unimplemented("index_add", "alpha != 1", self) - - dim = symbolic_helper._maybe_get_const(dim, "i") - if dim is None: - raise errors.SymbolicValueError( - "ONNX export does NOT support exporting 'index_add_()' function with " - "unknown 'dim' value.", - self, - ) - - self_dim_rank = symbolic_helper._get_tensor_rank(self) - other_dim_rank = symbolic_helper._get_tensor_rank(other) - - if self_dim_rank is None or other_dim_rank is None: - raise errors.SymbolicValueError( - "ONNX export does NOT support exporting 'index_add_()' function while " - "the rank of self tensor or tensor to be added is unknown.", - self, - ) - - if other_dim_rank != self_dim_rank: - delta = self_dim_rank - other_dim_rank - for i in range(delta): - other = symbolic_helper._unsqueeze_helper( - g, other, [symbolic_helper._get_tensor_rank(other)] - ) - - other_dim_size = symbolic_helper._get_tensor_dim_size(other, dim) - self_dim_size = symbolic_helper._get_tensor_dim_size(self, dim) - - if (other_dim_size is not None) and (self_dim_size is not None): - if other_dim_size > self_dim_size: - raise errors.SymbolicValueError( - "ONNX export does not support exporting 'index_add_()' function with " - "duplicated values in 'index' parameter yet.", - self, - ) - - # Construct a new shape. It's almost as same as self except the size of the 'dim' - # dimension is 1, so that we can expand other dimensions as expected. - new_shape_axes = list(range(self_dim_rank)) - new_shape_starts = [0 for i in range(self_dim_rank)] - new_shape_ends = [sys.maxsize if (i != dim) else 1 for i in range(self_dim_rank)] - - new_shape = symbolic_helper._slice_helper( - g, self, axes=new_shape_axes, starts=new_shape_starts, ends=new_shape_ends - ) - other = expand_as(g, other, new_shape) - - for i in range(dim): - index = symbolic_helper._unsqueeze_helper(g, index, [0]) - - for i in range(self_dim_rank - dim - 1): - index = symbolic_helper._unsqueeze_helper( - g, index, [symbolic_helper._get_tensor_rank(index)] - ) - - return scatter_add(g, self, dim, expand_as(g, index, other), other) - - -@_onnx_symbolic("aten::roll") -@symbolic_helper.parse_args("v", "is", "is") -def roll(g: jit_utils.GraphContext, self, shifts, dims): - assert len(shifts) == len(dims) - - result = self - for i in range(len(shifts)): - shapes = [] - shape = symbolic_helper._slice_helper( - g, result, axes=[dims[i]], starts=[-shifts[i]], ends=[sys.maxsize] - ) - shapes.append(shape) - shape = symbolic_helper._slice_helper( - g, result, axes=[dims[i]], starts=[0], ends=[-shifts[i]] - ) - shapes.append(shape) - result = g.op("Concat", *shapes, axis_i=dims[i]) - - return result - - -@_onnx_symbolic("aten::cross") -@symbolic_helper.parse_args("v", "v", "i") -def cross(g: jit_utils.GraphContext, input, other, dim=None): - dim = symbolic_helper._get_dim_for_cross(input, dim) - # If we have two tensors such that - # A = [a, b, c], B = [d, e, f], we permute the tensor such that we have - # After first roll, - # A' = [b, c, a], B' = [f, d, e], so that we calculate (b*f, c*d, a*e) - roll_x_1 = roll(g, input, [2], [dim]) - roll_y_1 = roll(g, other, [1], [dim]) - # After second roll, - # A' = [c, a, b], B' = [e, f, d], so that we calculate (c*e, a*f, b*d) - roll_x_2 = roll(g, input, [1], [dim]) - roll_y_2 = roll(g, other, 
[2], [dim]) - # cross product is calculated as - # result = [(b*f - c*e), (c*d - a*f), (a*e - b*d)] - return sub(g, mul(g, roll_x_1, roll_y_1), mul(g, roll_x_2, roll_y_2)) - - -@_onnx_symbolic("aten::cdist") -def cdist( - g: jit_utils.GraphContext, - x1, - x2, - p=2.0, - compute_mode="use_mm_for_euclid_dist_if_necessary", -): - # X1.shape = (B * P * D), X2.shape = (B * R * D) - # In order to respect numpy style broadcasting as demonstrated in - # https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md - # we unsqueeze both input tensors - row_size_x1 = symbolic_helper._get_tensor_dim_size(x1, -2) - row_size_x2 = symbolic_helper._get_tensor_dim_size(x2, -2) - assert row_size_x1 is not None - assert row_size_x2 is not None - p_float = symbolic_helper._parse_arg(p, "f") - compute_mode = symbolic_helper._parse_arg(compute_mode, "i") - if p_float == 2.0 and ( - compute_mode == 1 - or (compute_mode is None and row_size_x1 >= 25 and row_size_x2 >= 25) - ): - return _euclidean_dist(g, x1, x2) - rank = symbolic_helper._get_tensor_rank(x1) - assert rank is not None - broadcasted_x1 = symbolic_helper._unsqueeze_helper(g, x1, [rank - 1]) - broadcasted_x2 = symbolic_helper._unsqueeze_helper(g, x2, [rank - 2]) - return pairwise_distance( - g, broadcasted_x1, broadcasted_x2, p, eps=1e-06, keepdim=False - ) - - -def _euclidean_dist(g: jit_utils.GraphContext, x1, x2): - # X1.shape = (B * P * D), X2.shape = (B * R * D) - # using matrix multiplication to accelerate the calculation of - # the euclidean distance - rank = symbolic_helper._get_tensor_rank(x1) - assert rank is not None - x1_norm = symbolic_helper._reducesum_helper( - g, - pow(g, x1, symbolic_helper._generate_wrapped_number(g, 2.0)), - axes_i=[-1], - keepdims_i=True, - ) - x1_pad = ones_like(g, x1_norm) - x2_norm = symbolic_helper._reducesum_helper( - g, - pow(g, x2, symbolic_helper._generate_wrapped_number(g, 2.0)), - axes_i=[-1], - keepdims_i=True, - ) - x2_pad = ones_like(g, x2_norm) - x1_ = g.op( - "Concat", - *[ - mul(g, symbolic_helper._generate_wrapped_number(g, -2.0), x1), - x1_norm, - x1_pad, - ], - axis_i=-1, - ) - x2_ = g.op("Concat", *[x2, x2_pad, x2_norm], axis_i=-1) - result = matmul(g, x1_, transpose(g, x2_, -2, -1)) - dtype = _type_utils.JitScalarType.from_value(result) - min = g.op( - "Cast", symbolic_helper._generate_wrapped_number(g, 0.0), to_i=dtype.onnx_type() - ) - result = symbolic_helper._op_with_optional_float_cast( - g, "Max", result, min, opset_before=12 - ) - result = sqrt(g, result) - return result - - -@_onnx_symbolic("aten::lerp") -def lerp(g: jit_utils.GraphContext, self, end, weight): - # Conditional for better numeric. This has been discussed in - # https://github.com/pytorch/pytorch/pull/18871 - diff = g.op("Sub", end, self) - return where( - g, - g.op("Less", weight, g.op("Constant", value_t=torch.tensor(0.5))), - g.op("Add", self, g.op("Mul", weight, diff)), - g.op( - "Sub", - end, - g.op( - "Mul", - diff, - g.op("Sub", g.op("Constant", value_t=torch.tensor(1.0)), weight), - ), - ), - ) - - -@_onnx_symbolic("aten::broadcast_tensors") -def broadcast_tensors(g: jit_utils.GraphContext, self): - all_tensors = symbolic_helper._unpack_list(self) - t_with_final_shape = zeros_like(g, all_tensors[0]) - - # Add operator supports multidirectional broadcasting. So we leverage this function - # to infer the final shape generated by the broadcast. 
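The removed `broadcast_tensors` symbolic relies on the comment above: adding every input to a zeros accumulator lets ONNX `Add`'s multidirectional broadcasting produce the final shape implicitly. A minimal eager-mode sketch of the same idea (shapes and the helper name are illustrative only):

```python
import torch

# Adding each tensor to a zeros accumulator lets broadcasting compute the
# final shape implicitly, mirroring the Add-based trick described above.
def broadcast_all(tensors):
    acc = torch.zeros_like(tensors[0])
    for t in tensors:
        acc = acc + t  # multidirectional broadcasting happens here
    return [t.expand_as(acc) for t in tensors]

a = torch.randn(3, 1)
b = torch.randn(1, 4)
out = broadcast_all([a, b])
assert [tuple(t.shape) for t in out] == [(3, 4), (3, 4)]
```

In eager mode this matches what `torch.broadcast_tensors(a, b)` returns; the symbolic cannot call that directly, so it re-derives the shape through `Add`.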
- for t in all_tensors: - t_with_final_shape = add(g, t_with_final_shape, t) - - t_list = [expand_as(g, t, t_with_final_shape) for t in all_tensors] - return g.op("prim::ListConstruct", *t_list) - - -@_onnx_symbolic("aten::is_pinned") -def is_pinned(g: jit_utils.GraphContext, self, device=None): - # Unused by ONNX. - return None - - -@_onnx_symbolic("prim::ConstantSplit") -def prim_constant_split(g: jit_utils.GraphContext, self, split_size, dim): - size = symbolic_helper._get_tensor_dim_size(self, dim) - if size is None: - return symbolic_helper._unimplemented( - "prim::ConstantSplit", "unknown dimension size", self - ) - splits = [split_size] * (size // split_size) - leftover = size % split_size - if leftover: - splits.append(leftover) - return g.op("Split", self, split_i=splits, axis_i=dim, outputs=len(splits)) - - -# TODO: It would be better to export this as a chunk directly, as this is -# less sensitive to changes in input size. -# TODO: Once we have proper scoping, stop reimplementing chunk, delete this -# method, and use the desugared version -@_onnx_symbolic("prim::ConstantChunk") -def prim_constant_chunk(g: jit_utils.GraphContext, self, chunks, dim): - dim_size = symbolic_helper._get_tensor_dim_size(self, dim) - if dim_size is None: - return symbolic_helper._unimplemented( - "prim::ConstantChunk", "unknown dimension size", self - ) - split_size = (dim_size + chunks - 1) // chunks - return prim_constant_split(g, self, split_size, dim) - - -@_onnx_symbolic("prim::shape") -def prim_shape(g: jit_utils.GraphContext, self): - return g.op("Shape", self) - - -@_onnx_symbolic("prim::max") -def prim_max(g: jit_utils.GraphContext, self, other): - return symbolic_helper._op_with_optional_float_cast( - g, "Max", self, other, opset_before=12 - ) - - -@_onnx_symbolic("prim::min") -def prim_min(g: jit_utils.GraphContext, self, other=None): - if not other: - if symbolic_helper._is_packed_list(self): - self = stack(g, self, g.op("Constant", value_t=torch.tensor([0]))) - return min(g, self) - return min(g, self, other) - - -@_onnx_symbolic("prim::data") -def prim_data(g: jit_utils.GraphContext, self): - return self - - -@_onnx_symbolic("prim::layout") -def prim_layout(g: jit_utils.GraphContext, self): - # Always return 'torch.strided'. Other layout types are not supported by JIT 'TensorType'. - # Layout class defined in 'c10/core/Layout.h'. 
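The `prim::ConstantChunk` handler above lowers chunking into a `Split` with explicit sizes: ceiling division gives the per-chunk size, with a possible smaller leftover piece. A standalone sketch of that size computation (the helper name is made up), checked against eager `torch.chunk`/`torch.split`:

```python
import torch

# How prim::ConstantChunk maps chunks -> explicit Split sizes:
# ceil division for the regular pieces, plus an optional leftover.
def chunk_as_split_sizes(dim_size: int, chunks: int) -> list[int]:
    split_size = (dim_size + chunks - 1) // chunks  # ceil(dim_size / chunks)
    sizes = [split_size] * (dim_size // split_size)
    leftover = dim_size % split_size
    if leftover:
        sizes.append(leftover)
    return sizes

x = torch.arange(10)
sizes = chunk_as_split_sizes(x.shape[0], 3)                 # [4, 4, 2]
assert [t.shape[0] for t in torch.chunk(x, 3)] == sizes
assert [t.shape[0] for t in torch.split(x, sizes)] == sizes
```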
- return g.op("Constant", value_t=torch.tensor(0)) - - -@_onnx_symbolic("prim::ListConstruct") -def prim_list_construct(g: jit_utils.GraphContext, *inputs, **kwargs): - return None - - -@_onnx_symbolic("prim::ListUnpack") -def prim_list_unpack( - g: jit_utils.GraphContext, *inputs, **kwargs -) -> list[_C.Value] | None: - if len(inputs) == 1 and inputs[0].node().kind() == "prim::ListConstruct": - # Cancel the previous node if it is ListConstruct by returning its inputs - # TODO(justinchuby): Use a public method in the helper module - return symbolic_helper._unpack_list(inputs[0]) - - return None - - -@_onnx_symbolic("prim::TupleConstruct") -def prim_tuple_construct(g: jit_utils.GraphContext, *inputs, **kwargs): - return None - - -@_onnx_symbolic("prim::Uninitialized") -def prim_uninitialized(g: jit_utils.GraphContext, *inputs, **kwargs): - return None - - -# exists to refine the type of the Value -# if x is an optional Tensor, unchecked_cast will cast -# x to Tensor, so the rest of the graph knows that x is a Tensor -# this doesn't do anything in runtime and is a noop in ONNX -@_onnx_symbolic("prim::unchecked_cast") -def prim_unchecked_cast(g: jit_utils.GraphContext, self): - return self - - -@_onnx_symbolic("prim::dtype") -def prim_dtype(g: jit_utils.GraphContext, self): - scalar_type = symbolic_helper._try_get_scalar_type(self) - if scalar_type is None: - scalar_type = _type_utils.JitScalarType.FLOAT - # This node records a torch dtype as int - return g.op("Constant", value_t=torch.tensor(scalar_type)) - - -@_onnx_symbolic("prim::tolist") -def prim_tolist(g: jit_utils.GraphContext, input, dim_val, elem_ty_val): - """tolist is currently supported only for 1D input tensors. - - dim_val and elem_ty_val represent dimension and type annotations - that need to match dimension and type of the input tensor. - """ - dim = symbolic_helper._maybe_get_const(dim_val, "i") - if dim > 1: - return symbolic_helper._unimplemented("prim::tolist", "dim_val > 1", input) - return input - - -# ----------------------------------------------------------------------------- -# Symbolic functions that need extra context -# ----------------------------------------------------------------------------- -@_onnx_symbolic("prim::device") -def prim_device(g: jit_utils.GraphContext, *inputs, **kwargs) -> None: - output_type = g.original_node.output().type() - if isinstance(output_type, _C.DeviceObjType): - return None - - return symbolic_helper._unimplemented( - "prim::device", - f"output type should be 'DeviceObjType', not '{output_type.kind()}'", - g.original_node.output(), - ) - - -@_onnx_symbolic("prim::Loop") -def prim_loop(g: jit_utils.GraphContext, *inputs, **attrs) -> list[_C.Value]: - node = g.original_node - env = g.env - values_in_env = g.values_in_env - params_dict = g.params_dict - - operator_export_type = GLOBALS.operator_export_type - opset_version = GLOBALS.export_onnx_opset_version - - old_blocks = tuple(node.blocks()) - _new_op_outputs, new_block_contexts, new_node = jit_utils.add_op_with_blocks( - g, "Loop", *inputs, outputs=node.outputsSize(), n_blocks=len(old_blocks) - ) - - for old_block, new_block_context in zip(old_blocks, new_block_contexts): - # Copy input metadata to subblock - # - # prim::Loop(iter, cond, input_1, ..., input_n) - # block0(iter, input_1, ..., input_n) - # - # For `Loop` node, copy metadata for `iter`, `input_1`, ..., `input_n`. 
- for i, b_in in enumerate(old_block.inputs()): - if i == 0 and i < len(inputs): - b_in.setType(inputs[i].type()) - # For optional block inputs, they may switch between None not-None inside - # the loop body, so if the loop input is not optional, the block input may - # still need to be optional. - if ( - i > 0 - and (i + 1) < len(inputs) - and not isinstance(b_in.type(), _C.OptionalType) - ): - b_in.setType(inputs[i + 1].type()) - torch._C._jit_pass_onnx_block( - old_block, - new_block_context.block, - operator_export_type, - env, - values_in_env, - False, - ) - fixed_outputs = torch._C._jit_pass_fixup_onnx_controlflow_node( - new_node, opset_version - ) - # Run shape type inference for Loop after subblock is converted. - if GLOBALS.onnx_shape_inference: - torch._C._jit_pass_onnx_node_shape_type_inference( - new_node, params_dict, opset_version - ) - return fixed_outputs - - -@_onnx_symbolic("prim::If") -def prim_if(g: jit_utils.GraphContext, *inputs, **attrs) -> list[_C.Value]: - n = g.original_node - block = g.block - env = g.env - values_in_env = g.values_in_env - params_dict = g.params_dict - - operator_export_type = GLOBALS.operator_export_type - opset_version = GLOBALS.export_onnx_opset_version - - static_if = inputs[0].node().kind() == "onnx::Constant" - if static_if: - # Fold static if - # - # The torch IR - # graph(%embedding_matrix.1 : Float(10, 15, strides=[15, 1], requires_grad=0, device=cpu), - # %input.1 : Long(6, strides=[1], requires_grad=0, device=cpu), ... - # %65 : Bool(requires_grad=0, device=cpu) = prim::Constant[value={0}]() - # %21 : Long(device=cpu) = aten::eq(%20, %64) - # %22 : Long(device=cpu) = prim::If(%21) - # block0(): - # %23 : Long(device=cpu) = aten::is_floating_point(%input.1) - # -> (%23) - # block1(): - # -> (%65) - # %input.53 : Tensor, %weight : Tensor = prim::If(%22) - # block0(): - # -> (%embedding_matrix.1, %input.1) - # block1(): - # -> (%input.1, %embedding_matrix.1) - # %26 : int[] = aten::size(%input.53) - # - # The converted ONNX graph - # %10 : Bool(device=cpu) = onnx::Constant[value={0}]() - # %14 : Bool(device=cpu) = onnx::Equal(%13, %8) - # %15 : Bool(requires_grad=0, device=cpu) = onnx::Constant[value={0}]() - # %16 : Long(1, strides=[1], device=cpu) = onnx::Shape(%input.1) - input_flag = symbolic_helper._node_get(inputs[0].node(), "value").tolist() - const_value = ( - all(input_flag) if isinstance(input_flag, list) else bool(input_flag) - ) - block_idx = 0 if const_value else 1 - current_b = list(n.blocks())[block_idx] - env = torch._C._jit_pass_onnx_block( - current_b, - block, - operator_export_type, - env, - values_in_env, - True, - ) - if_output_list = list(n.outputs()) - current_b_list = list(current_b.outputs()) - - final_b_list = [] - for idx in range(len(if_output_list)): - if current_b_list[idx] not in env: - raise errors.SymbolicValueError( - f"The sub block ATen output {current_b_list[idx]} is not in env.", - current_b_list[idx], - ) # type:ignore[operator] - onnx_b = env[current_b_list[idx]] - final_b_list.append(onnx_b) - return final_b_list - else: - old_blocks = tuple(n.blocks()) - _new_op_outputs, new_block_contexts, new_node = jit_utils.add_op_with_blocks( - g, "If", *inputs, outputs=n.outputsSize(), n_blocks=len(old_blocks) - ) - - for old_block, new_block_context in zip(old_blocks, new_block_contexts): - torch._C._jit_pass_onnx_block( - old_block, - new_block_context.block, - operator_export_type, - env, - values_in_env, - False, - ) - fixed_outputs = torch._C._jit_pass_fixup_onnx_controlflow_node( - new_node, 
opset_version - ) - # Run shape type inference for If after subblock is converted. - if GLOBALS.onnx_shape_inference: - torch._C._jit_pass_onnx_node_shape_type_inference( - new_node, params_dict, opset_version - ) - return fixed_outputs - - -@_onnx_symbolic("prim::Constant") -def prim_constant(g: jit_utils.GraphContext, *inputs, **attrs): - node = g.original_node - - if node.mustBeNone(): - return None - # This must go before checking for string values, because some device constants - # have string values, but we want to keep them as unconverted Device types so - # that eq() can work on them. - if isinstance(node.output().type(), _C.DeviceObjType): - return None - if node.kindOf("value") == "t": - return g.op("Constant", value_t=symbolic_helper._node_get(node, "value")) - if node.kindOf("value") == "s": - return g.op("Constant", value_s=symbolic_helper._node_get(node, "value")) - if node.output().type().isSubtypeOf( - _C.ListType.ofInts() - ) or node.output().type().isSubtypeOf(_C.ListType.ofFloats()): - return g.op( - "Constant", value_t=torch.tensor(symbolic_helper._node_get(node, "value")) - ) - if node.output().type().isSubtypeOf(_C.ListType.ofStrings()): - str_constants = [ - g.op("Constant", value_s=s) - for s in symbolic_helper._node_get(node, "value") - ] - return g.op("prim::ListConstruct", *str_constants) - - raise errors.SymbolicValueError( - f"Unsupported prim::Constant kind: '{node.kindOf('value')}'. " - f"Please send a bug report at {_constants.PYTORCH_GITHUB_ISSUES_URL}.", - node.output(), - ) - - -@_onnx_symbolic("prim::type") -def prim_type(g: jit_utils.GraphContext, device_value: _C.Value, *args, **kwargs): - if device_value.node().kind() == "prim::device": - device = jit_utils.get_device_from_value(device_value.node().input()) - if device is not None: - return g.op("Constant", value_s=str(device)) - - return symbolic_helper._unimplemented( - "prim::type", - "Device type cannot be statically determined.", - device_value, - ) - - -@_onnx_symbolic("onnx::Placeholder") -def onnx_placeholder(g: jit_utils.GraphContext, *inputs, **attrs): - node = g.original_node - block = g.block - env = g.env - values_in_env = g.values_in_env - - return torch._C._jit_onnx_convert_pattern_from_subblock( - block, node, env, values_in_env - ) - - -@_onnx_symbolic("aten::resolve_conj") -@_onnx_symbolic("aten::resolve_neg") -def noop_complex_operators(g: jit_utils.GraphContext, input: _C.Value): - # ONNX does not have operators to *directly* manipulate real/imaginary components - # However, a few torch APIs (e.g. .tolist()) use complex operations when input is real, - # which results in failures due to missing operators for complex numbers - - # `aten::resolve_conj` and `aten::resolve_neg` can safely be implemented as no-op - return input - - -@_onnx_symbolic("aten::_conj") -@_onnx_symbolic("aten::conj_physical") -def unsupported_complex_operators(g: jit_utils.GraphContext, input: _C.Value): - # ONNX does not have operators to *directly* manipulate real/imaginary components - # However, a few torch APIs (e.g. 
.tolist()) use complex operations when input is real, - # which results in failures due to missing operators for complex numbers - - # While `aten::_conj` and `aten::conj_physical` raise exception when input is complex - if symbolic_helper.is_complex_value(input): - # FIXME(justinchuby): report correct name for symbolic being executed - return symbolic_helper._onnx_unsupported( - "aten::_conj, aten::conj_physical", - input, - ) - - # they can safely be implemented as no-op for real numbers only - return noop_complex_operators(g, input) - - -@_onnx_symbolic("aten::logit") -def logit(g: jit_utils.GraphContext, self: torch._C.Value, eps: torch._C.Value): - one = g.op("Constant", value_t=torch.tensor(1.0)) - - if not symbolic_helper._is_none(eps): - eps = g.op( - "Cast", eps, to_i=_type_utils.JitScalarType.from_value(self).onnx_type() - ) - one_sub_eps = g.op("Sub", one, eps) - self_less_equal_one_sub_eps = g.op("Greater", one_sub_eps, self) - temporary_self = g.op("Where", self_less_equal_one_sub_eps, self, one_sub_eps) - - temporary_self_less_eps = g.op("Less", temporary_self, eps) - z = g.op("Where", temporary_self_less_eps, eps, temporary_self) - else: - z = self - - sub = g.op("Sub", one, z) - div = g.op("Div", z, sub) - return g.op("Log", div) diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index ec08090a595f..6b1d752bb04e 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -1,1880 +1,8 @@ -# mypy: allow-untyped-defs -"""Functions to export models into the ONNX IR format. - -These models can be loaded with the ONNX library and then -converted to models which run on other deep learning frameworks. -""" +"""Backward compatibility module for torch.onnx.utils.""" from __future__ import annotations -import contextlib -import copy -import inspect -import re -import typing -import warnings -from typing import Any, Callable, cast -from typing_extensions import deprecated -import torch -import torch._C._onnx as _C_onnx -import torch.jit._trace -import torch.serialization -from torch import _C -from torch.onnx import _constants, errors, symbolic_helper # noqa: F401 -from torch.onnx._globals import GLOBALS -from torch.onnx._internal import jit_utils, onnx_proto_utils, registration +__all__: list[str] = [] - -if typing.TYPE_CHECKING: - from collections.abc import Collection, Mapping, Sequence - - -__all__ = [ - "select_model_mode_for_export", - "disable_apex_o2_state_dict_hook", - "setup_onnx_logging", - "exporter_context", - "export", - "model_signature", - "warn_on_static_input_change", - "unpack_quantized_tensor", - "unconvertible_ops", - "register_custom_op_symbolic", - "unregister_custom_op_symbolic", -] - - -# TODO(justinchuby): Remove dependency to this global variable from constant_fold.cpp -# Skip check due to cannot import IValue from torch._C -_params_dict = {} # type: ignore[var-annotated] - - -@deprecated("Please set training mode before exporting the model", category=None) -@contextlib.contextmanager -def select_model_mode_for_export(model, mode: _C_onnx.TrainingMode): - """A context manager to temporarily set the training mode of ``model`` - to ``mode``, resetting it when we exit the with-block. - - .. deprecated:: 2.7 - Please set training mode before exporting the model. - - Args: - model: Same type and meaning as ``model`` arg to :func:`export`. - mode: Same type and meaning as ``training`` arg to :func:`export`. - """ - if not isinstance(mode, _C_onnx.TrainingMode): - raise TypeError( - f"'mode' should be a torch.onnx.TrainingMode enum, but got '{type(mode)}'." 
- ) - originally_training: bool = False - - if hasattr(model, "training"): - originally_training = model.training - - # ONNX opset 12 has better support for training amenable models, with updated - # versions of the dropout and batch_norm operators - if mode == _C_onnx.TrainingMode.TRAINING or ( - mode == _C_onnx.TrainingMode.PRESERVE and originally_training - ): - GLOBALS.export_training = True - if GLOBALS.export_onnx_opset_version < 12: - warnings.warn( - "You are exporting the model in training mode with onnx opset " - f"version {GLOBALS.export_onnx_opset_version}. " - "Opset versions lower than opset 12 will not be able to export " - "nodes such as Dropout and BatchNorm correctly." - ) - else: - GLOBALS.export_training = False - - GLOBALS.training_mode = mode - if mode == _C_onnx.TrainingMode.TRAINING: - model.train(True) - elif mode == _C_onnx.TrainingMode.EVAL: - model.train(False) - # else mode == _C_onnx.TrainingMode.PRESERVE, do nothing - - try: - yield - finally: - if hasattr(model, "training") and not mode == _C_onnx.TrainingMode.PRESERVE: - model.train(originally_training) - - -@deprecated( - "Please remove usage of this function. Copy its logic if it is required in user code", - category=None, -) -@contextlib.contextmanager -def disable_apex_o2_state_dict_hook(model: torch.nn.Module | torch.jit.ScriptFunction): - """A context manager to temporarily disable the Apex O2 hook that returns. - - .. deprecated:: 2.7 - Please remove usage of this function. - """ - # Apex O2 hook state_dict to return fp16 weights as fp32. - # Exporter cannot identify them as same tensors. - # Since this hook is only used by optimizer, it is safe to - # remove this hook while exporting. - if not isinstance(model, torch.jit.ScriptFunction): - model_hooks = {} # type: ignore[var-annotated] - for module in model.modules(): - for key, hook in module._state_dict_hooks.items(): - if type(hook).__name__ == "O2StateDictHook": - if module not in model_hooks: - model_hooks[module] = {} - model_hooks[module][key] = hook - if module in model_hooks: - for key in model_hooks[module]: - module._state_dict_hooks.pop(key) - try: - yield - finally: - # Add the hooks back - for module, m_map in model_hooks.items(): - for key, hook in m_map.items(): - module._state_dict_hooks[key] = hook - else: - try: - yield - finally: - pass - - -@deprecated("The feature will be removed. Please remove usage of this function") -@contextlib.contextmanager -def setup_onnx_logging(verbose: bool): - """A context manager to temporarily set the ONNX logging verbosity. - - .. deprecated:: 2.7 - Please remove usage of this function. - """ - is_originally_enabled = _C._jit_is_onnx_log_enabled - if is_originally_enabled or verbose: # type: ignore[truthy-function] - _C._jit_set_onnx_log_enabled(True) - try: - yield - finally: - if not is_originally_enabled: # type: ignore[truthy-function] - _C._jit_set_onnx_log_enabled(False) - - -@deprecated( - "The feature will be removed. Please remove usage of this function " - "and implement equivalent logic if needed", - category=None, -) -@contextlib.contextmanager -def exporter_context(model, mode: _C_onnx.TrainingMode, verbose: bool): - """A context manager to temporarily set the training mode of ``model`` - to ``mode``, disable the Apex O2 hook, and set the ONNX logging verbosity. - - .. deprecated:: 2.7 - Please set training mode before exporting the model. 
- """ - with ( - select_model_mode_for_export(model, mode) as mode_ctx, - disable_apex_o2_state_dict_hook(model) as apex_ctx, - setup_onnx_logging(verbose) as log_ctx, - ): - yield (mode_ctx, apex_ctx, log_ctx) - - -def _get_torch_export_args( - args: tuple[Any, ...], - kwargs: dict[str, Any] | None, -) -> tuple[tuple[Any, ...], dict[str, Any] | None]: - """Obtain the arguments for torch.onnx.export from the model and the input arguments.""" - if not kwargs and args and isinstance(args[-1], dict): - kwargs = args[-1] - args = args[:-1] - return args, kwargs - - -def export( - model: torch.nn.Module | torch.jit.ScriptModule | torch.jit.ScriptFunction, - args: tuple[Any, ...] | torch.Tensor, - f: str, - *, - kwargs: dict[str, Any] | None = None, - export_params: bool = True, - verbose: bool = False, - training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL, - input_names: Sequence[str] | None = None, - output_names: Sequence[str] | None = None, - operator_export_type: _C_onnx.OperatorExportTypes = _C_onnx.OperatorExportTypes.ONNX, - opset_version: int | None = None, - do_constant_folding: bool = True, - dynamic_axes: Mapping[str, Mapping[int, str]] - | Mapping[str, Sequence[int]] - | None = None, - keep_initializers_as_inputs: bool | None = None, - custom_opsets: Mapping[str, int] | None = None, - export_modules_as_functions: bool | Collection[type[torch.nn.Module]] = False, - autograd_inlining: bool = True, -) -> None: - r"""Exports a model into ONNX format. - - If ``model`` is not a :class:`torch.jit.ScriptModule` nor a - :class:`torch.jit.ScriptFunction`, this runs - ``model`` once in order to convert it to a TorchScript graph to be exported - (the equivalent of :func:`torch.jit.trace`). Thus this has the same limited support - for dynamic control flow as :func:`torch.jit.trace`. - - Args: - model: The model to be exported. - args: - - args can be structured either as: - - 1. ONLY A TUPLE OF ARGUMENTS:: - - args = (x, y, z) - - The tuple should contain model inputs such that ``model(*args)`` is a valid - invocation of the model. Any non-Tensor arguments will be hard-coded into the - exported model; any Tensor arguments will become inputs of the exported model, - in the order they occur in the tuple. - - 2. A TENSOR:: - - args = torch.Tensor([1]) - - This is equivalent to a 1-ary tuple of that Tensor. - - 3. A TUPLE OF ARGUMENTS ENDING WITH A DICTIONARY OF NAMED ARGUMENTS:: - - args = (x, {"y": input_y, "z": input_z}) - - All but the last element of the tuple will be passed as non-keyword arguments, - and named arguments will be set from the last element. If a named argument is - not present in the dictionary, it is assigned the default value, or None if a - default value is not provided. - - .. warning:: - This behavior will be deprecated in a future release. Please use the - kwargs argument instead. - - .. note:: - If a dictionary is the last element of the args tuple, it will be - interpreted as containing named arguments. In order to pass a dict as the - last non-keyword arg, provide an empty dict as the last element of the args - tuple. For example, instead of:: - - torch.onnx.export( - model, - ( - x, - # WRONG: will be interpreted as named arguments - {y: z}, - ), - "test.onnx.pb", - ) - - Write:: - - torch.onnx.export(model, (x, {y: z}, {}), "test.onnx.pb") - - f: Path to the output ONNX model file. E.g. "model.onnx". - kwargs: Named arguments to the model. - export_params: If True, all parameters will - be exported. Set this to False if you want to export an untrained model. 
- In this case, the exported model will first take all of its parameters - as arguments, with the ordering as specified by ``model.state_dict().values()`` - verbose: if True, prints a description of the - model being exported to stdout. In addition, the final ONNX graph will include the - field ``doc_string``` from the exported model which mentions the source code locations - for ``model``. If True, ONNX exporter logging will be turned on. - training: - * ``TrainingMode.EVAL``: export the model in inference mode. - * ``TrainingMode.PRESERVE``: export the model in inference mode if model.training is - False and in training mode if model.training is True. - * ``TrainingMode.TRAINING``: export the model in training mode. Disables optimizations - which might interfere with training. - input_names (list of str, default empty list): names to assign to the - input nodes of the graph, in order. - output_names (list of str, default empty list): names to assign to the - output nodes of the graph, in order. - operator_export_type (enum, default OperatorExportTypes.ONNX): - - .. warning:: - This option will be deprecated in a future release. Future exported - graphs will always use the default opset domain. - - * ``OperatorExportTypes.ONNX``: Export all ops as regular ONNX ops - (in the default opset domain). - * ``OperatorExportTypes.ONNX_FALLTHROUGH``: Try to convert all ops - to standard ONNX ops in the default opset domain. If unable to do so - (e.g. because support has not been added to convert a particular torch op to ONNX), - fall back to exporting the op into a custom opset domain without conversion. Applies - to `custom ops `_ - as well as ATen ops. For the exported model to be usable, the runtime must support - these non-standard ops. - * ``OperatorExportTypes.ONNX_ATEN``: All ATen ops (in the TorchScript namespace "aten") - are exported as ATen ops (in opset domain "org.pytorch.aten"). - `ATen `_ is PyTorch's built-in tensor library, so - this instructs the runtime to use PyTorch's implementation of these ops. - - .. warning:: - - Models exported this way are probably runnable only by Caffe2. - - This may be useful if the numeric differences in implementations of operators are - causing large differences in behavior between PyTorch and Caffe2 (which is more - common on untrained models). - - * ``OperatorExportTypes.ONNX_ATEN_FALLBACK``: Try to export each ATen op - (in the TorchScript namespace "aten") as a regular ONNX op. If we are unable to do so - (e.g. because support has not been added to convert a particular torch op to ONNX), - fall back to exporting an ATen op. See documentation on OperatorExportTypes.ONNX_ATEN for - context. - For example:: - - graph(%0 : Float): - %3 : int = prim::Constant[value=0]() - # conversion unsupported - %4 : Float = aten::triu(%0, %3) - # conversion supported - %5 : Float = aten::mul(%4, %0) - return (%5) - - Assuming ``aten::triu`` is not supported in ONNX, this will be exported as:: - - graph(%0 : Float): - %1 : Long() = onnx::Constant[value={0}]() - # not converted - %2 : Float = aten::ATen[operator="triu"](%0, %1) - # converted - %3 : Float = onnx::Mul(%2, %0) - return (%3) - - .. warning:: - - Models exported this way are probably runnable only by Caffe2. - - opset_version (int, default 18): The version of the - `default (ai.onnx) opset `_ - to target. Must be >= 7. - do_constant_folding: Apply the constant-folding optimization. - Constant-folding will replace some of the ops that have all constant inputs - with pre-computed constant nodes. 
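Putting the options documented above together, a minimal export call might look like the sketch below. The output path is a placeholder, the module mirrors the `SumModule` used in the `dynamic_axes` illustration further down, and setting the module to eval mode up front follows the deprecation guidance earlier in this patch rather than relying on the `training` argument.

```python
import torch

class SumModule(torch.nn.Module):
    def forward(self, x):
        return torch.sum(x, dim=1)

model = SumModule()
model.eval()  # set the mode explicitly instead of passing `training=...`

torch.onnx.export(
    model,
    (torch.ones(2, 2),),
    "sum.onnx",                # placeholder output path
    opset_version=18,          # documented default; must be >= 7
    do_constant_folding=True,  # pre-compute ops whose inputs are all constant
    input_names=["x"],
    output_names=["sum"],
)
```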
- dynamic_axes: - - By default the exported model will have the shapes of all input and output tensors - set to exactly match those given in ``args``. To specify axes of tensors as - dynamic (i.e. known only at run-time), set ``dynamic_axes`` to a dict with schema: - - * KEY (str): an input or output name. Each name must also be provided in ``input_names`` or - ``output_names``. - * VALUE (dict or list): If a dict, keys are axis indices and values are axis names. If a - list, each element is an axis index. - - For example:: - - class SumModule(torch.nn.Module): - def forward(self, x): - return torch.sum(x, dim=1) - - - torch.onnx.export( - SumModule(), - (torch.ones(2, 2),), - "onnx.pb", - input_names=["x"], - output_names=["sum"], - ) - - Produces:: - - input { - name: "x" - ... - shape { - dim { - dim_value: 2 # axis 0 - } - dim { - dim_value: 2 # axis 1 - ... - output { - name: "sum" - ... - shape { - dim { - dim_value: 2 # axis 0 - ... - - While:: - - torch.onnx.export( - SumModule(), - (torch.ones(2, 2),), - "onnx.pb", - input_names=["x"], - output_names=["sum"], - dynamic_axes={ - # dict value: manually named axes - "x": {0: "my_custom_axis_name"}, - # list value: automatic names - "sum": [0], - }, - ) - - Produces:: - - input { - name: "x" - ... - shape { - dim { - dim_param: "my_custom_axis_name" # axis 0 - } - dim { - dim_value: 2 # axis 1 - ... - output { - name: "sum" - ... - shape { - dim { - dim_param: "sum_dynamic_axes_1" # axis 0 - ... - - keep_initializers_as_inputs: If True, all the - initializers (typically corresponding to parameters) in the - exported graph will also be added as inputs to the graph. If False, - then initializers are not added as inputs to the graph, and only - the non-parameter inputs are added as inputs. - This may allow for better optimizations (e.g. constant folding) by - backends/runtimes. - - If True, `deduplicate_initializers` pass will not be executed. This means - initializers with duplicated values will not be deduplicated and - will be treated as distinct inputs to the graph. This allows different - input initializers to be supplied at the runtime following export. - - If ``opset_version < 9``, initializers MUST be part of graph - inputs and this argument will be ignored and the behavior will be - equivalent to setting this argument to True. - - custom_opsets (dict[str, int], default empty dict): A dict with schema: - - * KEY (str): opset domain name - * VALUE (int): opset version - - If a custom opset is referenced by ``model`` but not mentioned in this dictionary, - the opset version is set to 1. Only custom opset domain name and version should be - indicated through this argument. - - export_modules_as_functions: Flag to enable - exporting all ``nn.Module`` forward calls as local functions in ONNX. Or a set to indicate the - particular types of modules to export as local functions in ONNX. - This feature requires ``opset_version`` >= 15, otherwise the export will fail. This is because - ``opset_version`` < 15 implies IR version < 8, which means no local function support. - Module variables will be exported as function attributes. There are two categories of function - attributes. - - 1. Annotated attributes: class variables that have type annotations via - `PEP 526-style `_ - will be exported as attributes. 
- Annotated attributes are not used inside the subgraph of ONNX local function because - they are not created by PyTorch JIT tracing, but they may be used by consumers - to determine whether or not to replace the function with a particular fused kernel. - - 2. Inferred attributes: variables that are used by operators inside the module. Attribute names - will have prefix "inferred::". This is to differentiate from predefined attributes retrieved from - python module annotations. Inferred attributes are used inside the subgraph of ONNX local function. - - * ``False`` (default): export ``nn.Module`` forward calls as fine grained nodes. - * ``True``: export all ``nn.Module`` forward calls as local function nodes. - * Set of type of nn.Module: export ``nn.Module`` forward calls as local function nodes, - only if the type of the ``nn.Module`` is found in the set. - - autograd_inlining: Flag used to control whether to inline autograd functions. - Refer to https://github.com/pytorch/pytorch/pull/74765 for more details. - - Raises: - :class:`torch.onnx.errors.CheckerError`: If the ONNX checker detects an invalid ONNX graph. - :class:`torch.onnx.errors.UnsupportedOperatorError`: If the ONNX graph cannot be exported because it - uses an operator that is not supported by the exporter. - :class:`torch.onnx.errors.OnnxExporterError`: Other errors that can occur during export. - All errors are subclasses of :class:`errors.OnnxExporterError`. - """ - if operator_export_type != _C_onnx.OperatorExportTypes.ONNX: - warnings.warn( - "Setting `operator_export_type` to something other than default is deprecated. " - "The option will be removed in a future release.", - category=DeprecationWarning, - ) - if training == _C_onnx.TrainingMode.TRAINING: - warnings.warn( - "Setting `training` to something other than default is deprecated. " - "The option will be removed in a future release. Please set the training mode " - "before exporting the model.", - category=DeprecationWarning, - ) - - args = (args,) if isinstance(args, torch.Tensor) else args - if kwargs is not None: - args = args + (kwargs,) - - _export( - model, - args, - f, - export_params, - verbose, - training, - input_names, - output_names, - operator_export_type=operator_export_type, - opset_version=opset_version, - do_constant_folding=do_constant_folding, - dynamic_axes=dynamic_axes, - keep_initializers_as_inputs=keep_initializers_as_inputs, - custom_opsets=custom_opsets, - export_modules_as_functions=export_modules_as_functions, - autograd_inlining=autograd_inlining, - ) - - return None - - -def _is_constant_tensor_list(node): - if node.kind() != "prim::Constant": - return False - output_type = node.output().type() - if output_type.isSubtypeOf(_C.ListType.ofTensors()): - return True - if output_type.isSubtypeOf(_C.ListType(_C.OptionalType.ofTensor())): - return True - - -# ONNX can't handle constants that are lists of tensors, which can -# get generated in constant prop. 
So we split them back into prim::ListConstructs - - -def _split_tensor_list_constants(g, block): - for node in block.nodes(): - for subblock in node.blocks(): - _split_tensor_list_constants(g, subblock) - if _is_constant_tensor_list(node): - inputs = [] - for val in node.output().toIValue(): - input = g.insertConstant(val) - input.node().moveBefore(node) - input.node().copyMetadata(node) - inputs.append(input) - - lc = ( - g.create("prim::ListConstruct", inputs) - .insertBefore(node) - .output() - .setType(_C.ListType.ofTensors()) - ) - lc.node().copyMetadata(node) - node.output().replaceAllUsesWith(lc) - - -def _optimize_graph( - graph: _C.Graph, - operator_export_type: _C_onnx.OperatorExportTypes, - _disable_torch_constant_prop: bool = False, - fixed_batch_size: bool = False, - params_dict=None, - dynamic_axes=None, - input_names=None, - module=None, -): - if params_dict is None: - params_dict = {} - - # Inline everything - _C._jit_pass_inline(graph) - - # Remove fork/wait nodes - _C._jit_pass_inline_fork_wait(graph) - _C._jit_pass_lint(graph) - if GLOBALS.autograd_inlining: - _C._jit_pass_onnx_autograd_function_process(graph) - _C._jit_pass_lower_all_tuples(graph) - - # we now record some ops like ones/zeros - # into a trace where we previously recorded constants. - # use constant prop to maintain our current level of onnx support - # without implementing symbolics for all of them - if _disable_torch_constant_prop is False: - _C._jit_pass_constant_propagation(graph) - - _split_tensor_list_constants(graph, graph) - # run dce to eliminate dead parts of the graph that might have been - # left behind by things like symbolic_override - _C._jit_pass_dce(graph) - _C._jit_pass_lint(graph) - - # CSE should improve perf when Autocast is used with disabled cache - # Autocast is disabled due to a limitation on tracer as described at https://github.com/pytorch/pytorch/issues/84092 - # Must run before _C._jit_pass_erase_number_types to prevent type substitution - if _C._jit_pass_cse(graph): - _C._jit_pass_onnx_lint(graph) - - _C._jit_pass_canonicalize_graph_fuser_ops(graph) - _C._jit_pass_lint(graph) - _C._jit_pass_peephole(graph, True) - _C._jit_pass_fuse_addmm(graph) - _C._jit_pass_lint(graph) - - _C._jit_pass_peephole(graph, True) - _C._jit_pass_lower_all_tuples(graph) - # in _jit_pass_onnx, symbolic functions are called for each node for conversion. - # However, there are nodes that cannot be converted without additional context. - # For example, the number of outputs from split (and whether it is static or dynamic) is unknown - # until the point where it is unpacked by listUnpack node. - # This pass does a preprocess, and prepares the nodes such that enough context can be received - # by the symbolic function. - _C._jit_pass_onnx_remove_inplace_ops_for_onnx(graph, module) - _C._jit_pass_onnx_preprocess(graph) - - # onnx does not support tuples, so try to remove them - _C._jit_pass_lint(graph) - - # onnx only supports tensors, but 1 / 2 = 0.5 and tensor(1) / tensor(2) = 0 - _C._jit_pass_prepare_division_for_onnx(graph) - - _C._jit_pass_onnx_remove_print(graph) - _C._jit_pass_onnx_preprocess_caffe2(graph) - - symbolic_helper._quantized_ops.clear() - # Unpack quantized weights for conv and linear ops and insert into graph. 
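The pass invoked just below unpacks quantized conv/linear weights into the graph. Numerically, a quantized tensor is fully described by its integer representation plus a scale and zero point, which is what such unpacking recovers; a small eager-mode sketch using public PyTorch quantization APIs:

```python
import torch

# A quantized tensor decomposes into (int_repr, scale, zero_point);
# dequantize() is equivalent to (int_repr - zero_point) * scale.
w = torch.randn(4, 3)
qw = torch.quantize_per_tensor(w, scale=0.1, zero_point=2, dtype=torch.qint8)

recovered = (qw.int_repr().float() - qw.q_zero_point()) * qw.q_scale()
assert torch.allclose(recovered, qw.dequantize())
```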
- _C._jit_pass_onnx_unpack_quantized_weights(graph, params_dict) - # onnx only supports tensors, so we turn all out number types into tensors - _C._jit_pass_erase_number_types(graph) - if GLOBALS.onnx_shape_inference: - input_names = [] if input_names is None else input_names - dynamic_axes = {} if dynamic_axes is None else dynamic_axes - _C._jit_pass_onnx_set_dynamic_input_shape(graph, dynamic_axes, input_names) - _C._jit_pass_onnx_lint(graph) - - graph = _C._jit_pass_onnx(graph, operator_export_type) - _C._jit_pass_onnx_lint(graph) - _C._jit_pass_lint(graph) - - _C._jit_pass_onnx_scalar_type_analysis( - graph, True, GLOBALS.export_onnx_opset_version - ) - _C._jit_pass_lint(graph) - - _C._jit_pass_onnx_peephole( - graph, GLOBALS.export_onnx_opset_version, fixed_batch_size - ) - _C._jit_pass_lint(graph) - - # graph is not a valid jit graph anymore because types have been replaced - # (e.g. int with Tensor), so it now contains operators that don't actually - # exist. We can't run normal dead code elimination because it'd fail trying - # to look up if an operator has side effects, but we can run a dead code - # elimination variant that doesn't need to look up if an op has side effects. - _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph) - _C._jit_pass_lint(graph) - graph = _C._jit_pass_canonicalize(graph) - _C._jit_pass_lint(graph) - if GLOBALS.onnx_shape_inference: - try: - _C._jit_pass_onnx_graph_shape_type_inference( - graph, params_dict, GLOBALS.export_onnx_opset_version - ) - except RuntimeError: - # NOTE: shape type inference error should not stop the export process - # https://github.com/pytorch/pytorch/issues/132205 - pass - - return graph - - -def warn_on_static_input_change(input_states): - """Warns that changes to input dictionaries and strings won't take effect in the traced ONNX graph. - - We accept dictionaries and strings as ONNX inputs, but they should be only for - configuration use. we detect here if these inputs are modified, and if so we warn - the user that the changes won't take effect in the traced ONNX graph. - """ - for input, traced_input in zip(input_states[0], input_states[1]): - if isinstance(input, dict): - if list(input.keys()) != list(traced_input.keys()): - warning = ( - "We detected that you are modifying a dictionary that is an input to your " - "model. " - "Note that dictionaries are allowed as inputs in ONNX but they should be " - "handled with care. " - "Usages of dictionaries is not recommended, and should not be used except " - "for configuration use. " - "Also note that the order and values of the keys must remain the same. " - ) - warnings.warn(warning) - elif isinstance(input, str): - if input != traced_input: - warning = ( - "The model seems to have string inputs/outputs. " - "Note that strings will not appear as inputs/outputs of the ONNX graph. " - ) - warnings.warn(warning) - - -def _resolve_args_by_export_type(arg_name, arg_value, operator_export_type): - """Resolves the arguments that are ignored when export_type != operator_export_type.ONNX.""" - return arg_value - - -def _decide_keep_init_as_input( - keep_initializers_as_inputs: bool | None, - operator_export_type: _C_onnx.OperatorExportTypes, - opset_version: int, -): - """Decides whether the initializers in the graph should be listed as ONNX graph inputs. - - This method encapsulates the logic to decide whether the initializers in the graph - should be listed as ONNX graph inputs (i.e., whether to choose ONNX IR v3 or v4). 
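The rules spelled out in this docstring reduce to a small decision table. A standalone paraphrase under the same assumptions (this is not the exporter's own code; the function name is made up):

```python
from __future__ import annotations

# Decision table described above: opset < 9 forces initializers to be graph
# inputs (ONNX IR v3); otherwise an unspecified keep_initializers_as_inputs
# defaults to False for plain ONNX export and True for other export types.
def keep_init_as_input(keep: bool | None, export_type_is_onnx: bool, opset: int) -> bool:
    if opset < 9:
        return True
    if keep is None:
        return not export_type_is_onnx
    return keep

assert keep_init_as_input(None, True, 18) is False
assert keep_init_as_input(False, True, 8) is True
```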
- If keep_initializers_as_inputs is not specified (None), then we decide whether to keep - initializers as graph inputs (val_keep_init_as_ip) based on export type. If export type - is ONNX, then do not keep initializers as input (val_keep_init_as_ip=False). For all other - export types keep initializers as input (val_keep_init_as_ip=True). - If keep_initializers_as_inputs is specified, then respect it. Unless opset version <= 8, - in which case it must be ignored because for opset version <= 8, all initializers MUST be - part of graph input (only ONNX IR v3 is allowed), i.e. val_keep_init_as_ip=True. - - Special handling is needed for opset version 8 or lower, because irrespective - of user input for keep_initializers_as_inputs, the graph must follow ONNX IR v3 - semantics, i.e. all initializers must be listed as ONNX graph input. - """ - - if opset_version < 9: - if keep_initializers_as_inputs is False: - warnings.warn( - "Setting 'keep_initializers_as_inputs=False' for opset version" - "8 or lower would lead to an invalid ONNX graph. Therefore, " - "'keep_initializers_as_inputs=False' is ignored during export." - "Exported model will have initializers as graph inputs (compliant " - " to ONNX IR v3)." - ) - return True # i.e. True == initializers are part of graph input (ONNX IR v3) - val_keep_init_as_ip = ( - True if keep_initializers_as_inputs is None else keep_initializers_as_inputs - ) - if ( - keep_initializers_as_inputs is None - and operator_export_type is _C_onnx.OperatorExportTypes.ONNX - ): - val_keep_init_as_ip = False - return val_keep_init_as_ip - - -def _decide_add_node_names(add_node_names, operator_export_type): - return _resolve_args_by_export_type( - "add_node_names", add_node_names, operator_export_type - ) - - -def _decide_constant_folding(do_constant_folding, operator_export_type, training): - do_constant_folding = _resolve_args_by_export_type( - "do_constant_folding", do_constant_folding, operator_export_type - ) - if do_constant_folding and ( - training is not None and training is not _C_onnx.TrainingMode.EVAL - ): - warnings.warn( - "It is recommended that constant folding be turned off ('do_constant_folding=False') " - "when exporting the model in training-amenable mode, i.e. with 'training=TrainingMode.TRAIN' " - "or 'training=TrainingMode.PRESERVE' (when model is in training mode). Otherwise, some " - "learnable model parameters may not translate correctly in the exported ONNX model " - "because constant folding mutates model parameters. Please consider " - "turning off constant folding or setting the training=TrainingMode.EVAL." 
- ) - return do_constant_folding - - -def _signature(model) -> inspect.Signature: - should_be_callable = getattr(model, "forward", model) - if callable(should_be_callable): - return inspect.signature(should_be_callable) - raise ValueError("model has no forward method and is not callable") - - -def _decide_input_format(model, args): - try: - sig = _signature(model) - except ValueError as e: - warnings.warn(f"{e}, skipping _decide_input_format") - return args - try: - ordered_list_keys = list(sig.parameters.keys()) - if ordered_list_keys[0] == "self": - ordered_list_keys = ordered_list_keys[1:] - args_dict: dict = {} - if isinstance(args, list): - args_list = args - elif isinstance(args, tuple): - args_list = list(args) - else: - args_list = [args] - if isinstance(args_list[-1], dict): - args_dict = args_list[-1] - args_list = args_list[:-1] - n_nonkeyword = len(args_list) - for optional_arg in ordered_list_keys[n_nonkeyword:]: - if optional_arg in args_dict: - args_list.append(args_dict[optional_arg]) - # Check if this arg has a default value - else: - param = sig.parameters[optional_arg] - if param.default != param.empty: - args_list.append(param.default) - args = args_list if isinstance(args, list) else tuple(args_list) - # Cases of models with no input args - except IndexError: - warnings.warn("No input args, skipping _decide_input_format") - except Exception as e: - warnings.warn(f"Skipping _decide_input_format\n {e.args[0]}") - return args - - -def _trace(func, args, operator_export_type, return_outs=False): - # Special case for common case of passing a single Tensor - if isinstance(args, torch.Tensor): - args = (args,) - - trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph( - func, - args, - strict=False, - _force_outplace=False, - _return_inputs_states=True, - ) - warn_on_static_input_change(inputs_states) - - trace_graph = _optimize_graph(trace_graph, operator_export_type, params_dict={}) - if return_outs: - return trace_graph, torch_out - return trace_graph - - -def _trace_and_get_graph_from_model(model, args): - # A basic sanity check: make sure the state_dict keys are the same - # before and after running the model. Fail fast! - orig_state_dict_keys = torch.jit._unique_state_dict(model).keys() - - # Disable Autocast cache because it replaces kernel's weight and bias - # by (undesired) constants. - # No perf impact for when there are reused weights since https://github.com/pytorch/pytorch/pull/85665 - prev_autocast_cache_enabled = torch.is_autocast_cache_enabled() - torch.set_autocast_cache_enabled(False) - trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph( - model, - args, - strict=False, - _force_outplace=False, - _return_inputs_states=True, - ) - torch.set_autocast_cache_enabled(prev_autocast_cache_enabled) - - warn_on_static_input_change(inputs_states) - - if orig_state_dict_keys != torch.jit._unique_state_dict(model).keys(): - raise RuntimeError( - "state_dict changed after running the tracer; " - "something weird is happening in your model!" - ) - - return trace_graph, torch_out - - -def _get_param_count_list(method_graph, args_params): - param_count_list = [] - for input_, arg_params_ in zip(method_graph.inputs(), args_params): - if "PackedParams" in str(input_.type()): - in_vars, _ = torch.jit._flatten(arg_params_) - param_count_list.append(len(in_vars)) - else: - param_count_list.append(arg_params_ is not None) - - return param_count_list - - -def _check_flatten_did_not_remove(original, jit_flattened): - """torch.jit._flatten removes None. 
Check if it did so in this case.""" - - def flatten(x): - if isinstance(x, (list, tuple)): - for inner in x: - yield from flatten(inner) - elif isinstance(x, dict): - for inner in x.values(): - yield from flatten(inner) - else: - yield x - - flattened_with_none = list(flatten(original)) - num_none = len(flattened_with_none) - len(jit_flattened) - assert num_none >= 0 - if num_none: - raise ValueError( - f"args contained {num_none} None's after flattening. " - "When exporting a ScriptModule or ScriptFunction, no args may " - "be None because that breaks type propagation." - ) - - -def _create_jit_graph( - model: torch.nn.Module | torch.jit.ScriptFunction, args: Sequence[Any] -) -> tuple[_C.Graph, list[_C.IValue], Any | None, _C.ScriptModule | None]: - if isinstance(model, (torch.jit.ScriptFunction, torch.jit.ScriptModule)): - flattened_args = tuple(torch.jit._flatten(tuple(args))[0]) - _check_flatten_did_not_remove(args, flattened_args) - torch_out = None - - if isinstance(model, torch.jit.ScriptModule): - try: - graph = model.forward.graph # type: ignore[attr-defined] - except AttributeError as e: - raise RuntimeError("'forward' method must be a script method") from e - _C._jit_pass_onnx_function_substitution(graph) - freezed_module = _C._freeze_module( - cast(_C.ScriptModule, model._c), preserveParameters=True - ) - module, params = _C._jit_onnx_list_model_parameters(freezed_module) - method_graph = module._get_method("forward").graph - args_params = tuple(args) + tuple(params) - param_count_list = _get_param_count_list(method_graph, args_params) - in_vars, _ = torch.jit._flatten(args_params) - graph = _C._propagate_and_assign_input_shapes( - method_graph, tuple(in_vars), param_count_list, False, False - ) - return graph, params, torch_out, module - - # torch.jit.ScriptFunction - params = [] - graph = model.graph - _C._jit_pass_onnx_function_substitution(graph) - param_count_list = _get_param_count_list(graph, args) - graph = _C._propagate_and_assign_input_shapes( - graph, flattened_args, param_count_list, False, False - ) - return graph, params, torch_out, None - - graph, torch_out = _trace_and_get_graph_from_model(model, args) - _C._jit_pass_onnx_lint(graph) - state_dict = torch.jit._unique_state_dict(model) - params = list(state_dict.values()) - graph_inputs = list(graph.inputs()) - user_input_num = len(graph_inputs) - len(state_dict) - param_names = list(state_dict.keys()) - for i, inp in enumerate(graph_inputs): - if i >= user_input_num: - inp.setDebugName(param_names[i - user_input_num]) - _C._jit_pass_onnx_function_substitution(graph) - return graph, params, torch_out, None - - -def _get_named_param_dict(graph, params): - input_and_param_names = [val.debugName() for val in graph.inputs()] - param_names = input_and_param_names[len(input_and_param_names) - len(params) :] - _params_dict = dict(zip(param_names, params)) - return _params_dict - - -def _get_example_outputs(model, args): - input_args = copy.deepcopy(args) - input_kwargs = {} - if input_args and isinstance(input_args[-1], dict): - input_kwargs = input_args[-1] - input_args = input_args[:-1] - - example_outputs = model(*input_args, **input_kwargs) - if isinstance(example_outputs, list): - example_outputs = [example_outputs] - elif not isinstance(example_outputs, tuple): - example_outputs = (example_outputs,) - - return example_outputs - - -_qtype_vtype_map = { - torch.quint8: torch.uint8, - torch.qint8: torch.int8, - torch.qint32: torch.int32, - torch.quint4x2: torch.int8, -} - - -def unpack_quantized_tensor(value, 
cast_onnx_accepted=True): - if isinstance(value, torch.Tensor) and value.dtype in _qtype_vtype_map: - q_value_dequantize = value.dequantize() - q_scale = ( - torch.tensor(value.q_scale(), dtype=torch.double) - if cast_onnx_accepted - else torch.tensor(value.q_scale(), dtype=torch.float32) - ) - q_zero_point = ( - torch.tensor(value.q_zero_point(), dtype=torch.int64) - if cast_onnx_accepted - else torch.tensor(value.q_zero_point(), dtype=_qtype_vtype_map[value.dtype]) - ) - q_value = q_value_dequantize / q_scale + q_zero_point - q_value = q_value.to(dtype=_qtype_vtype_map[value.dtype]) - return q_value, q_scale, q_zero_point - else: - return (value,) - - -def _pre_trace_quant_model(model, args): - r"""Returns `torch.jit.trace(model, args)` if model is quantized. Otherwise do nothing and return - original model. - - This is due to https://github.com/pytorch/pytorch/issues/75761. - """ - if any( - hasattr(m, "_packed_params") for m in getattr(model, "modules", list)() - ) or any(getattr(arg, "is_quantized", False) for arg in args): - return torch.jit.trace(model, args) - return model - - -def _model_to_graph( - model, - args, - verbose=False, - input_names=None, - output_names=None, - operator_export_type=_C_onnx.OperatorExportTypes.ONNX, - do_constant_folding=True, - _disable_torch_constant_prop=False, - fixed_batch_size=False, - training=_C_onnx.TrainingMode.EVAL, - dynamic_axes=None, -) -> tuple[ - _C.Graph, - dict[str, torch.Tensor], - torch.Tensor - | tuple[torch.Tensor, ...] - | list[torch.Tensor] - | dict[str, torch.Tensor] - | Any - | None, -]: - """Converts model into an ONNX graph. - - Returns: - graph: A TorchScript IR Graph with ONNX nodes. - params_dict: Dict from input param name to param value. - torch_out: The output tensors resulting from the trace of ``model``. - If ``model`` is a :class:`torch.jit.ScriptModule` or :class:`torch.jit.ScriptFunction`, - this will be None, since we are not doing any tracing. - """ - # TODO: can we simplify this to always return a tuple of Tensor or None? - - # Special case for common case of passing a single Tensor - if isinstance(args, (torch.Tensor, int, float, bool)): - args = (args,) - - model = _pre_trace_quant_model(model, args) - graph, params, torch_out, module = _create_jit_graph(model, args) - params_dict = _get_named_param_dict(graph, params) - - try: - graph = _optimize_graph( - graph, - operator_export_type, - _disable_torch_constant_prop=_disable_torch_constant_prop, - fixed_batch_size=fixed_batch_size, - params_dict=params_dict, - dynamic_axes=dynamic_axes, - input_names=input_names, - module=module, - ) - except Exception: - _C._jit_onnx_log("Torch IR graph at exception: ", graph) - raise - - is_script = isinstance(model, (torch.jit.ScriptFunction, torch.jit.ScriptModule)) - if is_script: - example_outputs = _get_example_outputs(model, args) - example_outputs_final = () - for example_output in example_outputs: - example_outputs_final += unpack_quantized_tensor(example_output) - out_vars, desc = torch.jit._flatten(example_outputs_final) - _C._jit_pass_onnx_assign_output_shape( - graph, - out_vars, - desc, - GLOBALS.onnx_shape_inference, - is_script, - GLOBALS.export_onnx_opset_version, - ) - - # NB: ONNX requires complete information about output types, which might be - # erased by some optimizations, so we need to set it explicitly again. 
- else: - if not isinstance(torch_out, (list, tuple)): - output_wrapped = [torch_out] - else: - output_wrapped = torch_out # type: ignore[assignment] - - output_tensors, out_desc = torch.jit._flatten(tuple(output_wrapped)) - # assign_output_shape pass is not compatible with quantized outputs. - # Quantized outputs are flattened to 3 values in ONNX, while packed as - # single value in PyTorch. - if not any(getattr(out, "is_quantized", False) for out in output_tensors): - _C._jit_pass_onnx_assign_output_shape( - graph, - output_tensors, - out_desc, - GLOBALS.onnx_shape_inference, - is_script, - GLOBALS.export_onnx_opset_version, - ) - - _set_input_and_output_names(graph, input_names, output_names) - params_dict = _get_named_param_dict(graph, params) - - if ( - do_constant_folding - and GLOBALS.export_onnx_opset_version - >= _constants.ONNX_CONSTANT_FOLDING_MIN_OPSET - ): - if training is None or training == _C_onnx.TrainingMode.EVAL: - params_dict = _C._jit_pass_onnx_eval_peephole(graph, params_dict) - - params_dict = _C._jit_pass_onnx_constant_fold( - graph, params_dict, GLOBALS.export_onnx_opset_version - ) - _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph) - - if GLOBALS.onnx_shape_inference: - try: - _C._jit_pass_onnx_graph_shape_type_inference( - graph, params_dict, GLOBALS.export_onnx_opset_version - ) - except RuntimeError: - # NOTE: shape type inference error should not stop the export process - # https://github.com/pytorch/pytorch/issues/132205 - pass - - params_dict = _C._jit_pass_onnx_eliminate_unused_items(graph, params_dict) - - # For ONNX opset < 9, constants only have three data types: float16, float, double. - # In this pass transform constants of other data types to float/double + cast operator. - if GLOBALS.export_onnx_opset_version < 9: - _C._jit_pass_onnx_cast_all_constant_to_floating(graph) - - params_dict = _C._jit_pass_filter_non_tensor_arguments(params_dict) - _C._jit_decay_packed_param_input_types(graph) - - # If output names lack a proper name and are identified only by their unique - # give them a legible name for debugging purposes - _apply_friendly_debug_names(graph, params_dict) - - return graph, params_dict, torch_out - - -@deprecated( - "Unconvertible ops are not definitive. Please remove usage of this function" -) -def unconvertible_ops( - model, - args, - training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL, - opset_version: int | None = None, -) -> tuple[_C.Graph, list[str]]: - """Returns an approximated list of all ops that are yet supported by :mod:`torch.onnx`. - - .. deprecated:: 2.5 - Unconvertible ops are not definitive. Please remove usage of this function. - - The list is approximated because some ops may be removed during the conversion - process and don't need to be converted. Some other ops may have partial support - that will fail conversion with particular inputs. Please open a Github Issue - for op support requests. - - Args: - model: Same as the `model` parameter in :func:`torch.onnx.export`. - args: Same as the `args` parameter in :func:`torch.onnx.export`. - training: Same as the `training` parameter in :func:`torch.onnx.export`. - opset_version: Same as the `opset_version` parameter in :func:`torch.onnx.export`. - - Returns: - The JIT graph and a list of unconvertible ops in the format of "domain::op". 
- """ - - opset_version = opset_version or _constants.ONNX_DEFAULT_OPSET - GLOBALS.export_onnx_opset_version = opset_version - - try: - with exporter_context(model, training, verbose=False): - # Create a mostly clean JIT graph that contains the plain aten and - # other ops we can check with the symbolic registry. - # NOTE: We don't want to actually convert any ops to ONNX or run any - # symbolic functions because there is a higher chance that a pass - # fails or an unconvertible op messes up the graph during ONNX conversion. - # This way we can always generate a list just by looking at the names - # of the ops in the graph. - args = _decide_input_format(model, args) - model = _pre_trace_quant_model(model, args) - graph, _, _, module = _create_jit_graph(model, args) - _C._jit_pass_inline(graph) - _C._jit_pass_onnx_remove_inplace_ops_for_onnx(graph, module) - _C._jit_pass_erase_number_types(graph) - _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph) - except Exception as e: - raise errors.OnnxExporterError( - "Failed to discover unconvertible ops because of errors during the JIT graph " - "generation process." - ) from e - - unsupported_ops = [] - for node in graph.nodes(): - domain_op = node.kind() - if domain_op.startswith(("onnx::", "prim::")): - # We consider onnx and prim ops as supported ops, even though some "prim" - # ops are not implemented as symbolic functions, because they may be - # eliminated in the conversion passes. Users may still see errors caused - # by prim ops even though they don't show up in the list. - continue - if not registration.registry.is_registered_op( - domain_op.rstrip("_"), opset_version - ): - # We consider all registered ops supported, even though some of them are - # only partially supported, because there is not yet a good way to check - # if an op is fully supported. - # TODO(justinchuby): Create a way to check if an op is fully supported. - unsupported_ops.append(domain_op) - return graph, unsupported_ops - - -def _setup_trace_module_map( - model: torch.nn.Module | torch.jit.ScriptModule, - export_modules_as_functions: bool | Collection[type[torch.nn.Module]], -) -> set[str]: - def __register_attribute_hook(): - attr_name = "_onnx_attrs" - - def _track_module_attributes_forward_pre_hook(module, input): - setattr(module, attr_name, _get_module_attributes(module)) - - def _track_module_attributes_forward_hook(module, input, output): - tracing_state = _C._get_tracing_state() - if not tracing_state: - return - - graph = tracing_state.graph() - onnx_attrs = {} - if hasattr(module, attr_name): - onnx_attrs = getattr(module, attr_name) - delattr(module, attr_name) - - _C._jit_pass_onnx_track_scope_attributes(graph, onnx_attrs) - - for m in model.modules(): - m.register_forward_hook(_track_module_attributes_forward_hook) - m.register_forward_pre_hook(_track_module_attributes_forward_pre_hook) - - def _unqualified_variable_name(qualified_name: str) -> str: - """ - Parse qualified variable name and return the unqualified version. - - Pure numeric atoms are considered inadequate, so this function will look past them, - and start from the first non-numeric atom. 
- - Example: - >>> _unqualified_variable_name("__main__.Foo.bar") - 'bar' - >>> _unqualified_variable_name("__main__.Foo.bar.0") - 'bar.0' - """ - name_atoms = qualified_name.split(".") - for i, atom in reversed(list(enumerate(name_atoms))): - if not atom.isnumeric(): - return ".".join(name_atoms[i:]) - return qualified_name - - trace_module_map = { - _m: torch._C._jit_onnx_create_full_scope_name( - torch.typename(type(_m)), _unqualified_variable_name(_n) - ) - for _n, _m in model.named_modules() - } - torch.jit._trace._trace_module_map = trace_module_map - if isinstance(export_modules_as_functions, bool) and export_modules_as_functions: - module_typenames = {torch.typename(type(module)) for module in trace_module_map} - elif isinstance(export_modules_as_functions, set) and export_modules_as_functions: - - def _find_typename(v): - if isinstance(v, type): - return torch.typename(v) - else: - raise RuntimeError( - "Only type of the `nn.Module` should be " - "passed in the set for argument `export_modules_as_functions`. " - f"Got `{type(v).__name__}`." - ) - - module_typenames = {_find_typename(v) for v in export_modules_as_functions} - else: - module_typenames = set() - - if module_typenames: - __register_attribute_hook() - - return module_typenames - - -def _reset_trace_module_map(): - torch.jit._trace._trace_module_map = None - _C._jit_pass_onnx_clear_scope_records() - - -def _get_module_attributes(module): - annotations = typing.get_type_hints(type(module)) - base_m_annotations = typing.get_type_hints(torch.nn.Module) - [annotations.pop(k, None) for k in base_m_annotations] - # Check whether module attributes can be accessed. Some classes - # define attributes but don't provide access to them in their - # constructor. - # - # For example, torch.nn.Embedding has the `freeze` variable and its - # type specified in the class but the attribute is not created in the - # constructor. In other words, there is no `self.freeze = ` - # in the constructor. - # - # Reference: https://github.com/pytorch/pytorch/blob/92de1d322223fb5584e384971b32c46b93bc2f4b/torch/nn/modules/sparse.py#L120 - attrs = {} - for k in annotations: - try: - attrs[k] = getattr(module, k) - except AttributeError: - _C._jit_onnx_log(f"Skipping module attribute '{k}'") - continue - return attrs - - -def _export( - model, - args, - f, - export_params=True, - verbose=False, - training=_C_onnx.TrainingMode.EVAL, - input_names=None, - output_names=None, - operator_export_type=_C_onnx.OperatorExportTypes.ONNX, - export_type=None, - opset_version=None, - do_constant_folding=True, - dynamic_axes=None, - keep_initializers_as_inputs=None, - fixed_batch_size=False, - custom_opsets=None, - add_node_names=True, - onnx_shape_inference=True, - export_modules_as_functions: Any = False, - autograd_inlining=True, -): - assert GLOBALS.in_onnx_export is False - - if isinstance(model, torch.nn.DataParallel): - raise ValueError( - "torch.nn.DataParallel is not supported by ONNX " - "exporter, please use 'attribute' module to " - "unwrap model from torch.nn.DataParallel. Try " - "torch.onnx.export(model.module, ...)" - ) - - GLOBALS.onnx_shape_inference = onnx_shape_inference - - if opset_version is None: - opset_version = _constants.ONNX_DEFAULT_OPSET - - if opset_version > _constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET: - warnings.warn( - f"Exporting to ONNX opset version {opset_version} is not supported. " - f"by 'torch.onnx.export()'. " - f"The highest opset version supported is {_constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET}. 
" - f"To use a newer opset version, consider 'torch.onnx.export(..., dynamo=True)'. ", - category=errors.OnnxExporterWarning, - ) - - if export_modules_as_functions and opset_version < 15: - raise ValueError( - "`export_modules_as_functions` is not supported for `opset_version` < 15." - "This is because `opset_version` < 15 implies IR version < 8, which means " - "no local function support. " - ) - if not operator_export_type: - operator_export_type = _C_onnx.OperatorExportTypes.ONNX - - # By default, training=TrainingMode.EVAL, - # which is good because running a model in training mode could result in - # internal buffers getting updated, dropout getting applied, etc. - # If you really know what you're doing, you can turn - # training=TrainingMode.TRAINING or training=TrainingMode.PRESERVE, - # (to preserve whatever the original training mode was.) - GLOBALS.export_onnx_opset_version = opset_version - GLOBALS.operator_export_type = operator_export_type - - try: - GLOBALS.in_onnx_export = True - _autograd_inlining_previous = GLOBALS.autograd_inlining - GLOBALS.autograd_inlining = autograd_inlining - - module_typenames_to_export_as_functions: set[str] = set() - if isinstance(model, (torch.nn.Module, torch.jit.ScriptModule)): - module_typenames_to_export_as_functions = _setup_trace_module_map( - model, export_modules_as_functions - ) - - with exporter_context(model, training, verbose): - val_keep_init_as_ip = _decide_keep_init_as_input( - keep_initializers_as_inputs, - operator_export_type, - opset_version, - ) - val_add_node_names = _decide_add_node_names( - add_node_names, operator_export_type - ) - val_do_constant_folding = _decide_constant_folding( - do_constant_folding, operator_export_type, training - ) - # Normally f can be a file-like object, but for large models, the external data format requires a - # valid `model_file_location`. Code in export.cpp will enforce this. - if isinstance(f, str): - model_file_location = f - else: - model_file_location = "" - args = _decide_input_format(model, args) - if dynamic_axes is None: - dynamic_axes = {} - _validate_dynamic_axes(dynamic_axes, model, input_names, output_names) - - graph, params_dict, torch_out = _model_to_graph( - model, - args, - verbose, - input_names, - output_names, - operator_export_type, - val_do_constant_folding, - fixed_batch_size=fixed_batch_size, - training=training, - dynamic_axes=dynamic_axes, - ) - - if custom_opsets is None: - custom_opsets = {} - - _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph) - node_attr_to_name = {} # type: ignore[var-annotated] - if module_typenames_to_export_as_functions: - # NOTE: cannot call DCE after this pass. DCE will remove function definition nodes. 
- node_attr_to_name = _C._jit_pass_onnx_function_extraction( - graph, - module_typenames_to_export_as_functions, - list(params_dict.keys()), - ) - - if keep_initializers_as_inputs is not True: - params_dict = _C._jit_pass_onnx_deduplicate_initializers( # type: ignore[assignment] - graph, - params_dict, # type: ignore[arg-type] - getattr(model, "training", False), # type: ignore[arg-type] - ) - _C._jit_pass_onnx_assign_scoped_names_for_node_and_value(graph) - defer_weight_export = False - if export_params: - ( - proto, - export_map, - _val_use_external_data_format, - _node_names, - ) = graph._export_onnx( # type: ignore[attr-defined] - params_dict, - opset_version, - dynamic_axes, - defer_weight_export, - operator_export_type, - not verbose, - val_keep_init_as_ip, - custom_opsets, - val_add_node_names, - model_file_location, - node_attr_to_name, - ) - else: - ( - proto, - export_map, - _, - _, - ) = graph._export_onnx( # type: ignore[attr-defined] - {}, - opset_version, - dynamic_axes, - defer_weight_export, - operator_export_type, - not verbose, - val_keep_init_as_ip, - custom_opsets, - val_add_node_names, - model_file_location, - node_attr_to_name, - ) - # insert function_proto into model_proto. - proto = onnx_proto_utils._add_onnxscript_fn( - proto, - custom_opsets, - ) - if verbose: - _C._jit_onnx_log("Exported graph: ", graph) - onnx_proto_utils._export_file(proto, f, export_map) - finally: - assert GLOBALS.in_onnx_export - GLOBALS.in_onnx_export = False - GLOBALS.autograd_inlining = _autograd_inlining_previous - _reset_trace_module_map() - - return torch_out - - -def _apply_friendly_debug_names(graph, params): - for n in graph.nodes(): - for v in n.inputs(): - old_name = v.debugName() - if old_name != str(v.unique()): - continue - new_name = f"{n.kind()}_{v.unique()}" - v.setDebugName(new_name) - if old_name in params: - params[new_name] = params.pop(old_name) - - -def _set_input_and_output_names(graph, input_names, output_names): - def set_names(node_list, name_list, descriptor): - if name_list is None: - return - if len(name_list) > len(node_list): - raise RuntimeError( - f"number of {descriptor} names provided ({len(name_list)}) " - f"exceeded number of {descriptor}s ({len(node_list)})" - ) - - # Mark if the output node DebugName is set before. - output_node_set = set() - for i, (name, node) in enumerate(zip(name_list, node_list)): - # Duplicated output node, insert onnx::Identity to avoid setting the same DebugName after setDebugName(). - if descriptor == "output": - if node in output_node_set: - identity_node = graph.create("onnx::Identity") - identity_node.insertAfter(node.node()) - identity_node.addInput(node) - identity_node.output().setType(node.type()) - graph.return_node().replaceInput(i, identity_node.output()) - node = identity_node.output() - output_node_set.add(node) - - if node.debugName() != name: - node.setDebugName(name) - - set_names(list(graph.inputs()), input_names, "input") - set_names(list(graph.outputs()), output_names, "output") - - -def _run_symbolic_method(g, op_name, symbolic_fn, args): - r""" - This trampoline function gets invoked for every symbolic method - call from C++. 
- """ - try: - graph_context = jit_utils.GraphContext( - graph=g, - block=g.block(), - opset=GLOBALS.export_onnx_opset_version, - original_node=None, # type: ignore[arg-type] - params_dict=_params_dict, - env={}, - values_in_env=set(), - new_nodes=[], - ) - return symbolic_fn(graph_context, *args) - except TypeError as e: - # Handle the specific case where we didn't successfully dispatch - # to symbolic_fn. Otherwise, the backtrace will have the clues - # you need. - e.args = (f"{e.args[0]} (occurred when translating {op_name})",) - raise - - -def _add_block(node: _C.Node) -> _C.Block: - return node.addBlock() - - -def _add_input_to_block(block: _C.Block): - return block.addInputToBlock() # type: ignore[attr-defined] - - -def _add_output_to_block(block: _C.Block, value: _C.Value) -> int: - return block.registerOutput(value) - - -def _should_aten_fallback( - name: str, opset_version: int, operator_export_type: _C_onnx.OperatorExportTypes -): - # For all builds, if domain=="aten" and operator_export_type==ONNX_ATEN, - # an aten::ATen operator is created regardless of symbolics existence - - is_exportable_aten_op = registration.registry.is_registered_op(name, opset_version) - is_onnx_aten_export = operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN - is_aten_fallback_export = ( - operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK - ) - - if not name.startswith("aten::"): - return False - - if is_onnx_aten_export or (is_aten_fallback_export and not is_exportable_aten_op): - return True - - return False - - -def _get_aten_op_overload_name(n: _C.Node) -> str: - # Returns `overload_name` attribute to ATen ops on non-Caffe2 builds - schema = n.schema() - if not schema.startswith("aten::"): - return "" - return _C.parse_schema(schema).overload_name - - -def _run_symbolic_function( - graph: _C.Graph, - block: _C.Block, - node: _C.Node, - inputs: Any, - env: dict[_C.Value, _C.Value], - values_in_env: set[_C.Value], - new_nodes: list[_C.Node], - operator_export_type=_C_onnx.OperatorExportTypes.ONNX, -) -> _C.Value | Sequence[_C.Value | None] | None: - """Runs a symbolic function. - - The function is used in C++ to export the node to ONNX. - - Returns: - A single or a tuple of Values. - None when the node gets cloned as is into the new graph. - """ - - opset_version = GLOBALS.export_onnx_opset_version - - # See Note [Export inplace] - node_kind = node.kind() - if node_kind.endswith("_"): - # Treat relu_ -> relu; add_ -> add etc. 
- ns_op_name = node_kind[:-1] - else: - ns_op_name = node_kind - - namespace, op_name = jit_utils.parse_node_kind(ns_op_name) - - graph_context = jit_utils.GraphContext( - graph=graph, - block=block, - opset=opset_version, - original_node=node, - params_dict=_params_dict, - env=env, - values_in_env=values_in_env, - new_nodes=new_nodes, - ) - - # Direct ATen export requested - if _should_aten_fallback(ns_op_name, opset_version, operator_export_type): - attrs = { - k + "_" + node.kindOf(k)[0]: symbolic_helper._node_get(node, k) - for k in node.attributeNames() - } - outputs = node.outputsSize() - attrs["outputs"] = outputs - return graph_context.aten_op( - op_name, - *inputs, - overload_name=_get_aten_op_overload_name(node), - **attrs, - ) - - try: - domain = namespace - symbolic_function_name = f"{domain}::{op_name}" - - symbolic_function_group = registration.registry.get_function_group( - symbolic_function_name - ) - if symbolic_function_group is not None: - symbolic_fn = symbolic_function_group.get(opset_version) - if symbolic_fn is not None: - # TODO Wrap almost identical attrs assignment or comment the difference. - attrs = { - k: symbolic_helper._node_get(node, k) for k in node.attributeNames() - } - return symbolic_fn(graph_context, *inputs, **attrs) - - attrs = { - k + "_" + node.kindOf(k)[0]: symbolic_helper._node_get(node, k) - for k in node.attributeNames() - } - if namespace == "onnx": - # Clone node to trigger ONNX shape inference - return graph_context.op( - op_name, *inputs, **attrs, outputs=node.outputsSize() - ) # type: ignore[attr-defined] - - raise errors.UnsupportedOperatorError( - symbolic_function_name, - opset_version, - symbolic_function_group.get_min_supported() - if symbolic_function_group - else None, - ) - - except RuntimeError: - if operator_export_type == _C_onnx.OperatorExportTypes.ONNX_FALLTHROUGH: - return None - elif operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK: - # Emit ATen op for non-Caffe2 builds when `operator_export_type==ONNX_ATEN_FALLBACK` - attrs = { - k + "_" + node.kindOf(k)[0]: symbolic_helper._node_get(node, k) - for k in node.attributeNames() - } - return graph_context.aten_op( - op_name, - *inputs, - overload_name=_get_aten_op_overload_name(node), - **attrs, - ) - raise - except TypeError as e: - # Handle the specific case where we didn't successfully dispatch. - # Otherwise, the backtrace will have the clues you need. - e.args = (f"{e.args[0]} \n(Occurred when translating {op_name}).",) - raise - - -def _verify_custom_op_name(symbolic_name: str): - if not re.match(r"^[a-zA-Z0-9-_]+::[a-zA-Z-_]+[a-zA-Z0-9-_]*$", symbolic_name): - raise errors.OnnxExporterError( - f"Failed to register operator {symbolic_name}. " - "The symbolic name must match the format domain::name, " - "and should start with a letter and contain only " - "alphanumerical characters" - ) - - ns, _ = jit_utils.parse_node_kind(symbolic_name) - if ns == "onnx": - raise ValueError( - f"Failed to register operator {symbolic_name}. {ns} domain cannot be modified." - ) - - -def register_custom_op_symbolic( - symbolic_name: str, - symbolic_fn: Callable, - opset_version: int, -): - """Registers a symbolic function for a custom operator. - - When the user registers symbolic for custom/contrib ops, - it is highly recommended to add shape inference for that operator via setType API, - otherwise the exported graph may have incorrect shape inference in some extreme cases. - An example of setType is `test_aten_embedding_2` in `test_operators.py`. 
- - See "Custom Operators" in the module documentation for an example usage. - - Args: - symbolic_name (str): The name of the custom operator in "::" - format. - symbolic_fn (Callable): A function that takes in the ONNX graph and - the input arguments to the current operator, and returns new - operator nodes to add to the graph. - opset_version (int): The ONNX opset version in which to register. - """ - if symbolic_name.startswith("::"): - symbolic_name = f"aten{symbolic_name}" - - _verify_custom_op_name(symbolic_name) - - registration.custom_onnx_symbolic(symbolic_name, opset_version)(symbolic_fn) - - -def unregister_custom_op_symbolic(symbolic_name: str, opset_version: int): - """Unregisters ``symbolic_name``. - - See "Custom Operators" in the module documentation for an example usage. - - Args: - symbolic_name (str): The name of the custom operator in "::" - format. - opset_version (int): The ONNX opset version in which to unregister. - """ - if symbolic_name.startswith("::"): - symbolic_name = f"aten{symbolic_name}" - - _verify_custom_op_name(symbolic_name) - - registration.registry.unregister(symbolic_name, opset_version) - - -def _validate_dynamic_axes(dynamic_axes, model, input_names, output_names): - """Ensures dynamic axes argument is follows the expected format.""" - if len(dynamic_axes) == 0: - return - - if hasattr(model, "graph"): - # Extracting set of valid input/output names that shall be used for dynamic_axes - if (input_names is None) or len(input_names) == 0: - input_names = [x.debugName() for x in model.graph.inputs()] - if (output_names is None) or len(output_names) == 0: - output_names = [y.debugName() for y in model.graph.outputs()] - - valid_names = set((input_names or []) + (output_names or [])) - - # If dynamic axes are provided as a list rather than dictionary, they should - # first get converted to a dictionary in expected format. If desired axes names - # are not provided for dynamic axes, automatic names shall be generated for - # provided dynamic axes of specified input/output - for key, value in dynamic_axes.items(): - if key not in valid_names: - warnings.warn( - f"Provided key {key} for dynamic axes is not a valid input/output name" - ) - if isinstance(value, list): - warnings.warn( - "No names were found for specified dynamic axes of provided input." - f"Automatically generated names will be applied to each dynamic axes of input {key}" - ) - - value_dict = {} - for i, x in enumerate(value): - if not isinstance(x, int): - raise ValueError( - "The type of axis index is expected to be an integer" - ) - if x in value_dict: - warnings.warn( - f"Duplicate dynamic axis index {x} was provided for input {key}." 
- ) - else: - value_dict[x] = str(key) + "_dynamic_axes_" + str(i + 1) - dynamic_axes[key] = value_dict - - -def model_signature(model: torch.nn.Module | Callable) -> inspect.Signature: - return inspect.signature( - model.forward if isinstance(model, torch.nn.Module) else model - ) +from torch.onnx._internal.torchscript_exporter.utils import * # noqa: F401,F403 diff --git a/torch/onnx/verification.py b/torch/onnx/verification.py index bc98fedae086..70d901acb47a 100644 --- a/torch/onnx/verification.py +++ b/torch/onnx/verification.py @@ -1,1872 +1,12 @@ -# mypy: allow-untyped-defs -"""The ONNX verification module provides a set of tools to verify the correctness of ONNX models.""" +"""A set of tools to verify the correctness of ONNX models.""" -from __future__ import annotations +__all__ = ["VerificationInfo", "verify_onnx_program"] - -__all__ = [ - "OnnxBackend", - "VerificationOptions", - "verify", - "check_export_model_diff", - "VerificationInfo", - "verify_onnx_program", - "GraphInfo", - "GraphInfoPrettyPrinter", - "OnnxTestCaseRepro", - "find_mismatch", - "verify_aten_graph", -] - -import contextlib -import copy -import dataclasses -import datetime -import difflib -import enum -import functools -import io -import itertools -import os -import tempfile -import typing_extensions -import warnings -from collections.abc import Collection, Mapping, Sequence -from typing import Any, Callable, Union - -import numpy as np -import numpy.typing as npt - -import torch -import torch._C._onnx as _C_onnx -from torch import _C -from torch.onnx import _constants, _experimental, utils -from torch.onnx._globals import GLOBALS -from torch.onnx._internal import onnx_proto_utils from torch.onnx._internal.exporter._verification import ( VerificationInfo, verify_onnx_program, ) -from torch.types import Number -# TODO: Update deprecation messages to recommend the new classes - VerificationInfo.__module__ = "torch.onnx.verification" verify_onnx_program.__module__ = "torch.onnx.verification" - -# Everything below are deprecated ############################################## - -_ORT_PROVIDERS = ("CPUExecutionProvider",) - -_NumericType = Union[Number, torch.Tensor, np.ndarray] -_ModelType = Union[torch.nn.Module, torch.jit.ScriptModule] -_InputArgsType = Union[torch.Tensor, tuple[Any, ...]] -_InputKwargsType = Mapping[str, Any] -_OutputsType = Union[Sequence[_NumericType], Sequence] - - -class OnnxBackend(enum.Enum): - """Enum class for ONNX backend used for export verification. - - .. deprecated:: 2.7 - Consider using ``torch.onnx.export(..., dynamo=True)`` and use the returned - ``ONNXProgram`` to test the ONNX model. - """ - - REFERENCE = "ONNXReferenceEvaluator" - ONNX_RUNTIME_CPU = "CPUExecutionProvider" - ONNX_RUNTIME_CUDA = "CUDAExecutionProvider" - - -@dataclasses.dataclass -class VerificationOptions: - """Options for ONNX export verification. - - .. deprecated:: 2.7 - Consider using ``torch.onnx.export(..., dynamo=True)`` and use the returned - ``ONNXProgram`` to test the ONNX model. - - Attributes: - flatten: If True, unpack nested list/tuple/dict inputs into a flattened list of - Tensors for ONNX. Set this to False if nested structures are to be preserved - for ONNX, which is usually the case with exporting ScriptModules. Default True. - ignore_none: Whether to ignore None type in torch output, which is usually the - case with tracing. Set this to False, if torch output should keep None type, - which is usually the case with exporting ScriptModules. Default to True. 
- check_shape: Whether to check the shapes between PyTorch and ONNX Runtime outputs - are exactly the same. Set this to False to allow output shape broadcasting. - Default to True. - check_dtype: Whether to check the dtypes between PyTorch and ONNX Runtime outputs - are consistent. Default to True. - backend: ONNX backend for verification. Default to OnnxBackend.ONNX_RUNTIME_CPU. - rtol: relative tolerance in comparison between ONNX and PyTorch outputs. - atol: absolute tolerance in comparison between ONNX and PyTorch outputs. - remained_onnx_input_idx: If provided, only the specified inputs will be passed - to the ONNX model. Supply a list when there are unused inputs in the model. - Since unused inputs will be removed in the exported ONNX model, supplying - all inputs will cause an error on unexpected inputs. This parameter tells - the verifier which inputs to pass into the ONNX model. - acceptable_error_percentage: acceptable percentage of element mismatches in comparison. - It should be a float of value between 0.0 and 1.0. - """ - - flatten: bool = True - ignore_none: bool = True - check_shape: bool = True - check_dtype: bool = True - backend: OnnxBackend = OnnxBackend.ONNX_RUNTIME_CPU - rtol: float = 1e-3 - atol: float = 1e-7 - remained_onnx_input_idx: Sequence[int] | None = None - acceptable_error_percentage: float | None = None - - -def _flatten_tuples(elem): - flattened = [] - for t in elem: - if isinstance(t, tuple): - flattened.extend(_flatten_tuples(t)) - else: - flattened.append(t) - return flattened - - -# TODO(justinchuby): Add type checking by narrowing down the return type when input is None -def _to_numpy(elem) -> list | npt.NDArray: - if isinstance(elem, torch.Tensor): - if elem.requires_grad: - return elem.detach().cpu().numpy() - else: - return elem.cpu().numpy() - elif isinstance(elem, (list, tuple)): - return [_to_numpy(inp) for inp in elem] - elif isinstance(elem, (bool, int, float)): - return np.array(elem) - elif isinstance(elem, dict): - flattened = [] - for k in elem: - flattened.extend([_to_numpy(k), _to_numpy(elem[k])]) - return flattened - return elem - - -def _inline_flatten_list(inputs, res_list) -> list: - for i in inputs: - res_list.append(i) if not isinstance( - i, (list, tuple) - ) else _inline_flatten_list(i, res_list) - return res_list - - -def _unpack_to_numpy(values, cast_onnx_accepted=True) -> list: - value_unpacked = [] - for value in values: - value_unpacked.extend( - utils.unpack_quantized_tensor(value, cast_onnx_accepted=cast_onnx_accepted) - ) - return [_to_numpy(v) for v in value_unpacked] - - -def _run_onnx(onnx_session, inputs) -> _OutputsType: - kw_inputs = {} - if inputs and isinstance(inputs[-1], dict): - kw_inputs = inputs[-1] - inputs = inputs[:-1] - inputs = _unpack_to_numpy(_flatten_tuples(inputs)) - ort_inputs = {} - for input_name, input in kw_inputs.items(): - ort_inputs[input_name] = _to_numpy(input) - inputs = _to_numpy(inputs) - if hasattr(onnx_session, "get_inputs"): - # onnxruntime.InferenceSession - input_names = [i.name for i in onnx_session.get_inputs()] - elif hasattr(onnx_session, "input_names"): - # onnx.reference.ReferenceEvaluator - input_names = onnx_session.input_names - else: - raise ValueError(f"Unknown ONNX backend type: {type(onnx_session)}.") - - for i, input in enumerate(inputs): - if i == len(input_names) or input_names[i] in ort_inputs: - raise ValueError( - f"got too many positional inputs. inputs: {inputs}. kw_inputs: {kw_inputs}. " - f"input names: {input_names}." 
- ) - ort_inputs[input_names[i]] = input - onnx_outs = onnx_session.run(None, ort_inputs) - return onnx_outs - - -def _ort_session( - model: str | io.BytesIO, ort_providers: Sequence[str] = _ORT_PROVIDERS -): - try: - import onnxruntime # type: ignore[import] - except ImportError as e: - raise ImportError("onnxruntime is required for export verification.") from e - - if ort_providers is None: - ort_providers = _ORT_PROVIDERS - - session_options = onnxruntime.SessionOptions() - # suppress ort warnings. - # 0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2. - session_options.log_severity_level = 3 - ort_session = onnxruntime.InferenceSession( - model if isinstance(model, str) else model.getvalue(), - session_options, - providers=ort_providers, - ) - return ort_session - - -def _onnx_reference_evaluator_session(model: str | io.BytesIO): - try: - import onnx - from onnx import reference as onnx_reference # type: ignore[attr-defined] - except ImportError as exc: - raise ImportError("onnx >= 1.13 is required for reference evaluator.") from exc - - proto = ( - onnx.load(model) # type: ignore[attr-defined] - if isinstance(model, str) - else onnx.load_model_from_string(model.getvalue()) # type: ignore[attr-defined] - ) - onnx_session = onnx_reference.ReferenceEvaluator(proto) - return onnx_session - - -def _onnx_backend_session(model: str | io.BytesIO, backend: OnnxBackend): - if backend == OnnxBackend.REFERENCE: - onnx_session = _onnx_reference_evaluator_session(model) - elif backend in {OnnxBackend.ONNX_RUNTIME_CPU, OnnxBackend.ONNX_RUNTIME_CUDA}: - onnx_session = _ort_session(model, (backend.value,)) - else: - raise ValueError(f"Unsupported backend: {backend}") - return onnx_session - - -def _compare_onnx_pytorch_outputs_in_np( - onnx_outs: _OutputsType, - pt_outs: _OutputsType, - options: VerificationOptions, -): - assert len(onnx_outs) == len(pt_outs), ( - f"Number of outputs differ ONNX runtime: ({len(onnx_outs)}) PyTorch: ({len(pt_outs)})" - ) - acceptable_error_percentage = options.acceptable_error_percentage - if acceptable_error_percentage and ( - acceptable_error_percentage > 1.0 or acceptable_error_percentage < 0.0 - ): - raise ValueError( - "If set, acceptable_error_percentage should be between 0.0 and 1.0" - ) - - for ort_out, pt_out in zip(onnx_outs, pt_outs): - try: - # TODO: Remove `check_shape` option once every shape inconsistent issue is addressed. - if not options.check_shape: - # Allow different but broadcastable output shapes. - ort_out, pt_out = np.broadcast_arrays(ort_out, pt_out) - torch.testing.assert_close( - ort_out, - pt_out, - rtol=options.rtol, - atol=options.atol, - check_dtype=options.check_dtype, - equal_nan=True, - ) - except AssertionError as e: - if acceptable_error_percentage: - error_percentage = 1 - np.sum( - np.isclose(ort_out, pt_out, rtol=options.rtol, atol=options.atol) - ) / np.prod(ort_out.shape) - if error_percentage <= acceptable_error_percentage: - warnings.warn( - f"Suppressed AssertionError:\n{e}.\n" - f"Error percentage {error_percentage} " - f"within acceptable range {acceptable_error_percentage}." - ) - continue - if ort_out.dtype == np.uint8 or ort_out.dtype == np.int8: - warnings.warn("ONNX output is quantized") - if pt_out.dtype == np.uint8 or pt_out.dtype == np.int8: - warnings.warn("PyTorch output is quantized") - raise - - -def _compare_onnx_pytorch_outputs( - onnx_outs: _OutputsType, - pt_outs: Any, - options: VerificationOptions, -): - """ - Compare ONNX and PyTorch outputs. - - Args: - onnx_outs: outputs from ONNX backend. 
- pt_outs: outputs from PyTorch. - options: options for verification. - - Raises: - AssertionError: if outputs from ONNX model and PyTorch model are not - equal up to specified precision. - ValueError: if arguments provided are invalid. - """ - if options.ignore_none: - # torch.jit._flatten filters None type - pt_outs, _ = torch.jit._flatten(pt_outs) - else: - pt_outs = _inline_flatten_list([pt_outs], []) - pt_outs_np = _unpack_to_numpy(pt_outs, cast_onnx_accepted=False) - onnx_outs = _inline_flatten_list(onnx_outs, []) - _compare_onnx_pytorch_outputs_in_np(onnx_outs, pt_outs_np, options) - - -def _prepare_input_for_pytorch(args, kwargs): - """Prepare input for PyTorch model execution. - - Any future changes/formatting to the input before dispatching to the PyTorch - model should be made in this function. - - Args: - args: positional arguments for PyTorch model forward method. - kwargs: keyword arguments for PyTorch model forward method. - - Returns: - args: positional arguments for PyTorch model forward method. - kwargs: keyword arguments for PyTorch model forward method. - """ - if isinstance(args, (torch.Tensor, dict)): - args = (args,) - # In-place operators will update input tensor data as well. - # Thus inputs are replicated before every forward call. - args = copy.deepcopy(args) - if kwargs: - kwargs = copy.deepcopy(kwargs) - else: - kwargs = {} - return args, kwargs - - -def _prepare_input_for_export(args, kwargs): - """Prepare input for ONNX model export. - - Any future changes/formatting to the input before dispatching to the - :func:`torch.onnx.export` api should be made in this function. - - Args: - args: positional arguments for PyTorch model forward method. - kwargs: keyword arguments for PyTorch model forward method. - - Returns: - onnx_inputs: positional arguments for ONNX model export, as `args` in - :func:`torch.onnx.export`. - """ - args, kwargs = _prepare_input_for_pytorch(args, kwargs) - if not kwargs and len(args) > 0 and isinstance(args[-1], dict): - onnx_inputs = args + ({},) - elif kwargs: - onnx_inputs = args + (kwargs,) - else: - onnx_inputs = args - return onnx_inputs - - -def _prepare_input_for_onnx( - args, kwargs, remained_onnx_input_idx: Sequence[int] | None, flatten: bool -): - """Prepare input for ONNX model execution in ONNX backend. - - Any future changes/formatting to the input before dispatching to the ONNX backend - run should be made in this function. - - Args: - args: positional arguments for PyTorch model forward method. - kwargs: keyword arguments for PyTorch model forward method. - remained_onnx_input_idx: indices of inputs to be used for ONNX model execution. - flatten: whether to flatten the input before dispatching to the ONNX model execution. - - Returns: - onnx_inputs: positional arguments for ONNX model execution in ONNX backend. - """ - onnx_inputs = _prepare_input_for_export(args, kwargs) - if flatten: - onnx_inputs, _ = torch.jit._flatten(onnx_inputs) - elif onnx_inputs and onnx_inputs[-1] == {}: - # Handle empty kwargs (normally removed by flatten). - onnx_inputs = onnx_inputs[:-1] - if remained_onnx_input_idx is not None: - return [onnx_inputs[i] for i in remained_onnx_input_idx] - else: - return onnx_inputs - - -def _try_clone_model(model): - """Used for preserving original model in case forward mutates model states.""" - try: - return copy.deepcopy(model) - except Exception: - warnings.warn( - "Failed to clone model. Model state might be mutated during verification." 
- ) - return model - - -def _compare_onnx_pytorch_model( - pt_model: _ModelType, - onnx_model_f: str | io.BytesIO, - input_args: _InputArgsType, - input_kwargs: _InputKwargsType | None, - additional_test_inputs: Sequence[_InputArgsType] | None, - options: VerificationOptions, -): - """Compare outputs from ONNX model runs with outputs from PyTorch model runs. - - Args: - pt_model: PyTorch model. - onnx_model_f: ONNX model file path or file-like object. - input_args: positional arguments for PyTorch model forward method. - input_kwargs: keyword arguments for PyTorch model forward method. - additional_test_inputs: additional positional arguments for PyTorch model - forward method. - options: options for verification. - - Raises: - AssertionError: if outputs from ONNX model and PyTorch model are not - equal up to specified precision. - """ - onnx_session = _onnx_backend_session(onnx_model_f, options.backend) - - def compare_onnx_pytorch_model_with_input(input_args, input_kwargs): - pt_args, pt_kwargs = _prepare_input_for_pytorch(input_args, input_kwargs) - # TODO: remove this and treat mutating model separately. See #77679 - pt_model_copy = _try_clone_model(pt_model) - pt_outs = pt_model_copy(*pt_args, **pt_kwargs) - - onnx_inputs = _prepare_input_for_onnx( - input_args, input_kwargs, options.remained_onnx_input_idx, options.flatten - ) - - onnx_outs = _run_onnx(onnx_session, onnx_inputs) - - _compare_onnx_pytorch_outputs( - onnx_outs=onnx_outs, - pt_outs=pt_outs, - options=options, - ) - - compare_onnx_pytorch_model_with_input(input_args, input_kwargs) - - if additional_test_inputs: - for test_input_args in additional_test_inputs: - compare_onnx_pytorch_model_with_input(test_input_args, {}) - - -class _GraphDiff: - """A class to represent the difference between two graphs.""" - - def __init__(self, graph_a: _C.Graph, graph_b: _C.Graph): - """Construct a _GraphDiff object. - - Args: - graph_a (_C.Graph): First graph to compare. - graph_b (_C.Graph): Second graph to compare. - """ - self.graph_a = graph_a - self.graph_b = graph_b - - def __str__(self): - """See function :func:`diff_report`.""" - return self.diff_report() - - def _indent(self, lines: str) -> str: - return "\n".join(["\t" + line for line in lines.splitlines()]) - - def diff_report(self) -> str: - """Return a string representation of the graph difference. - - The report shows the first pair of nodes that diverges. It also shows the source - location of the pair of nodes. - - Returns: - graph_diff_report (str): A string representation of the graph difference. 
- """ - graph_a = self.graph_a - graph_b = self.graph_b - - graph_a_str = str(graph_a) - graph_b_str = str(graph_b) - - if graph_a_str == graph_b_str: - return "" - - graph_diff = difflib.ndiff( - graph_a_str.splitlines(True), graph_b_str.splitlines(True) - ) - graph_diff_report = ["Graph diff:", self._indent("".join(graph_diff))] - - for node_a, node_b in itertools.zip_longest(graph_a.nodes(), graph_b.nodes()): - if str(node_a) != str(node_b): - graph_diff_report.append("First diverging operator:") - node_diff = difflib.ndiff( - str(node_a).splitlines(True), str(node_b).splitlines(True) - ) - source_printout = ["node diff:", self._indent("".join(node_diff))] - - stack_a = node_a.sourceRange() if node_a else None - if stack_a: - source_printout.extend( - ["Former source location:", self._indent(str(stack_a))] - ) - stack_b = node_b.sourceRange() if node_b else None - if stack_b: - source_printout.extend( - ["Latter source location:", self._indent(str(stack_b))] - ) - - graph_diff_report.extend(source_printout) - - break - - return "\n".join(graph_diff_report) - - -def _check_graph_diff( - model: torch.nn.Module | torch.jit.ScriptModule, - test_input_groups: Sequence[tuple[tuple[Any, ...], Mapping[str, Any]]], - export_options: _experimental.ExportOptions, - model_to_graph_func: Callable[ - [ - torch.nn.Module, - tuple[Any, ...], - Mapping[str, Any], - _experimental.ExportOptions, - ], - _C.Graph, - ], -) -> str: - """Check if graph produced by `model_to_graph_func` is the same across `test_input_groups`. - - Args: - model: See :func:`check_export_model_diff`. - test_input_groups: See :func:`check_export_model_diff`. - export_options: See :func:`check_export_model_diff`. - model_to_graph_func: A function to convert a PyTorch model to a JIT IR graph. - - Returns: - graph_diff_report (str): A string representation of the graph difference. - """ - if len(test_input_groups) < 2: - raise ValueError("Need at least two groups of test inputs to compare.") - - ref_jit_graph = None - for args, kwargs in test_input_groups: - jit_graph = model_to_graph_func(model, args, kwargs, export_options) - if ref_jit_graph is None: - ref_jit_graph = jit_graph - continue - - graph_diff_report = _GraphDiff(ref_jit_graph, jit_graph).diff_report() - if graph_diff_report: - return graph_diff_report - return "" - - -def _traced_graph_from_model( - model: torch.nn.Module | torch.jit.ScriptModule, - args: tuple[Any, ...], - kwargs: Mapping[str, Any], - export_options: _experimental.ExportOptions, -) -> _C.Graph: - """As part of the ONNX export steps, create a traced JIT graph from a PyTorch model. - - Args: - model: See :func:`check_export_model_diff`. - args: See :func:`check_export_model_diff`. - kwargs: See :func:`check_export_model_diff`. - export_options: See :func:`check_export_model_diff`. - - Returns: - jit_graph (_C.Graph): A traced JIT graph. - """ - training = export_options.training - verbose = export_options.verbose - - with utils.exporter_context(model, training, verbose): - export_inputs = _prepare_input_for_export(args, kwargs) - model = utils._pre_trace_quant_model(model, export_inputs) - jit_graph, _, _, _ = utils._create_jit_graph(model, export_inputs) - return jit_graph - - -def _onnx_graph_from_model( - model: torch.nn.Module | torch.jit.ScriptModule, - args: tuple[Any, ...], - kwargs: Mapping[str, Any], - export_options: _experimental.ExportOptions, -) -> _C.Graph: - """As part of the ONNX export steps, export an ONNX JIT graph from a PyTorch model. 
- - Args: - model: See :func:`check_export_model_diff`. - args: See :func:`check_export_model_diff`. - kwargs: See :func:`check_export_model_diff`. - export_options: See :func:`check_export_model_diff`. - - Returns: - onnx_graph (_C.Graph): An ONNX JIT graph. - """ - # TODO: refactor utils.py to remove duplicated code of context setup. See #78834 - opset_version = export_options.opset_version - operator_export_type = export_options.operator_export_type - export_modules_as_functions = export_options.export_modules_as_functions - training = export_options.training - verbose = export_options.verbose - dynamic_axes = export_options.dynamic_axes - input_names = export_options.input_names - output_names = export_options.output_names - - if opset_version is None: - opset_version = _constants.ONNX_DEFAULT_OPSET - - utils._setup_trace_module_map(model, export_modules_as_functions) - - if not operator_export_type: - operator_export_type = _C_onnx.OperatorExportTypes.ONNX - - GLOBALS.export_onnx_opset_version = opset_version - GLOBALS.operator_export_type = operator_export_type - - with utils.exporter_context(model, training, verbose): - do_constant_folding = utils._decide_constant_folding( - export_options.do_constant_folding, operator_export_type, training - ) - - if dynamic_axes is None: - dynamic_axes = {} - utils._validate_dynamic_axes(dynamic_axes, model, input_names, output_names) - - export_inputs = _prepare_input_for_export(args, kwargs) - export_inputs = utils._decide_input_format(model, export_inputs) - onnx_graph, _, _ = utils._model_to_graph( - model, - export_inputs, - verbose, - input_names, - output_names, - operator_export_type, - do_constant_folding, - training=training, - dynamic_axes=dynamic_axes, - ) - - return onnx_graph - - -def _onnx_graph_from_aten_graph( - graph: torch.Graph, - export_options: _experimental.ExportOptions, - params_dict: dict[str, Any] | None = None, -) -> tuple[torch.Graph, dict[str, Any]]: - if params_dict is None: - params_dict = {} - operator_export_type = export_options.operator_export_type - dynamic_axes = export_options.dynamic_axes or {} - input_names = export_options.input_names - training = export_options.training - do_constant_folding = export_options.do_constant_folding - opset_version = export_options.opset_version or _constants.ONNX_DEFAULT_OPSET - - GLOBALS.export_onnx_opset_version = opset_version - GLOBALS.operator_export_type = operator_export_type - - do_constant_folding = utils._decide_constant_folding( - do_constant_folding, operator_export_type, training - ) - - # TODO: Below is doing aten graph to onnx. It should be abstracted as a - # function in torch/onnx/utils.py. - graph = graph.copy() - graph = utils._optimize_graph( - graph, - operator_export_type, - params_dict=params_dict, - dynamic_axes=dynamic_axes, - input_names=input_names, - ) - - if training is None or training == _C_onnx.TrainingMode.EVAL: - params_dict = torch._C._jit_pass_onnx_eval_peephole(graph, params_dict) - - if ( - do_constant_folding - and opset_version >= _constants.ONNX_CONSTANT_FOLDING_MIN_OPSET - ): - params_dict = _C._jit_pass_onnx_constant_fold(graph, params_dict, opset_version) - _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph) - - if GLOBALS.onnx_shape_inference: - _C._jit_pass_onnx_graph_shape_type_inference(graph, params_dict, opset_version) - - params_dict = _C._jit_pass_onnx_eliminate_unused_items(graph, params_dict) - - # For ONNX opset < 9, constants only have three data types: float16, float, double. 
- # In this pass transform constants of other data types to float/double + cast operator. - if opset_version < 9: - _C._jit_pass_onnx_cast_all_constant_to_floating(graph) - - params_dict = _C._jit_pass_filter_non_tensor_arguments(params_dict) - _C._jit_decay_packed_param_input_types(graph) - - _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph) - - if export_options.verbose: - print("ONNX graph: ", graph) - - return graph, params_dict - - -def _onnx_proto_from_onnx_graph( - onnx_graph: torch.Graph, - export_options: _experimental.ExportOptions, - params_dict: dict[str, Any], -) -> tuple[bytes, Mapping[str, bytes]]: - opset_version = export_options.opset_version or _constants.ONNX_DEFAULT_OPSET - dynamic_axes = export_options.dynamic_axes or {} - operator_export_type = export_options.operator_export_type - val_keep_init_as_ip = utils._decide_keep_init_as_input( - export_options.keep_initializers_as_inputs, - operator_export_type, - opset_version, - ) - val_add_node_names = utils._decide_add_node_names(True, operator_export_type) - custom_opsets = export_options.custom_opsets or {} - - proto, export_map, _, _ = onnx_graph._export_onnx( # type: ignore[attr-defined] - params_dict, - opset_version, - dynamic_axes, - False, - operator_export_type, - not export_options.verbose, - val_keep_init_as_ip, - custom_opsets, - val_add_node_names, - "", - {}, - ) - - return proto, export_map - - -def check_export_model_diff( - model: torch.nn.Module | torch.jit.ScriptModule, - test_input_groups: Sequence[tuple[tuple[Any, ...], Mapping[str, Any]]], - export_options: _experimental.ExportOptions | None = None, -) -> str: - """Verify exported model discrepancy between different groups of inputs. - - A graph is exported for each group of inputs. The exported graphs are then compared - to each other, and discrepancies of first pair of nodes are reported. This function - first checks the jit graph. If no discrepancies were found, it then checks the onnx - graph. - - Unless otherwise specified, the jit/ONNX graph is expected to be the same, regardless - of the inputs used for exporting. A discrepancy implies the graph exported is - not accurate when run on other groups of inputs, which will typically results in - runtime errors or mismatching output. - - Args: - model (torch.nn.Module or torch.jit.ScriptModule): The model to be exported. - test_input_groups (Sequence[Tuple[Tuple[Any, ...], Mapping[str, Any]]]): A sequence - of input groups to be used to export the model. Each input group is a pair of - (args, kwargs). - export_options (_experimental.ExportOptions, optional): An _experimental.ExportOptions - object that controls the export behavior. - - Returns: - str: A string containing the diff of the exported models. - """ - export_options = ( - _experimental.ExportOptions() if export_options is None else export_options - ) - - jit_diff_report = _check_graph_diff( - model, test_input_groups, export_options, _traced_graph_from_model - ) - if jit_diff_report: - return jit_diff_report - - return _check_graph_diff( - model, test_input_groups, export_options, _onnx_graph_from_model - ) - - -@typing_extensions.deprecated( - "torch.onnx.verification.* is deprecated. 
Consider using torch.onnx.export(..., dynamo=True) " - "and use ONNXProgram to test the ONNX model", - category=None, -) -def verify( - model: _ModelType, - input_args: _InputArgsType, - input_kwargs: _InputKwargsType | None = None, - do_constant_folding: bool = True, - dynamic_axes: Mapping[str, Mapping[int, str] | Mapping[str, Sequence[int]]] - | None = None, - input_names: Sequence[str] | None = None, - output_names: Sequence[str] | None = None, - training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL, - opset_version: int | None = None, - keep_initializers_as_inputs: bool = True, - verbose: bool = False, - fixed_batch_size: bool = False, - use_external_data: bool = False, - additional_test_inputs: Sequence[_InputArgsType] | None = None, - options: VerificationOptions | None = None, -): - """Verify model export to ONNX against original PyTorch model. - - .. deprecated:: 2.7 - Consider using ``torch.onnx.export(..., dynamo=True)`` and use the returned - ``ONNXProgram`` to test the ONNX model. - - Args: - model: See :func:`torch.onnx.export`. - input_args: See :func:`torch.onnx.export`. - input_kwargs: See :func:`torch.onnx.export`. - do_constant_folding: See :func:`torch.onnx.export`. - dynamic_axes: See :func:`torch.onnx.export`. - input_names: See :func:`torch.onnx.export`. - output_names: See :func:`torch.onnx.export`. - training: See :func:`torch.onnx.export`. - opset_version: See :func:`torch.onnx.export`. - keep_initializers_as_inputs: See :func:`torch.onnx.export`. - verbose: See :func:`torch.onnx.export`. - fixed_batch_size: Legacy argument, used only by rnn test cases. - use_external_data: Explicitly specify whether to export the model with external data. - additional_test_inputs: List of tuples. Each tuple is a group of - input arguments to test. Currently only ``*args`` are supported. - options: A VerificationOptions object that controls the verification behavior. - - Raises: - AssertionError: if outputs from ONNX model and PyTorch model are not - equal up to specified precision. - ValueError: if arguments provided are invalid. - """ - if options is None: - options = VerificationOptions() - - if training == torch.onnx.TrainingMode.TRAINING: - model.train() - elif training == torch.onnx.TrainingMode.EVAL: - model.eval() - with torch.no_grad(), contextlib.ExitStack() as stack: - model_f: str | io.BytesIO = io.BytesIO() - if use_external_data: - tmpdir_path = stack.enter_context(tempfile.TemporaryDirectory()) - model_f = os.path.join(tmpdir_path, "model.onnx") - - inputs_for_export = _prepare_input_for_export(input_args, input_kwargs) - - # TODO(#77679): remove this and treat mutating model separately. - model_copy = _try_clone_model(model) - utils._export( - model, - inputs_for_export, - model_f, - opset_version=opset_version, - do_constant_folding=do_constant_folding, - keep_initializers_as_inputs=keep_initializers_as_inputs, - dynamic_axes=dynamic_axes, - input_names=input_names, - output_names=output_names, - fixed_batch_size=fixed_batch_size, - training=training, - verbose=verbose, - ) - - _compare_onnx_pytorch_model( - pt_model=model_copy, - onnx_model_f=model_f, - input_args=input_args, - input_kwargs=input_kwargs, - additional_test_inputs=additional_test_inputs, - options=options, - ) - - -@typing_extensions.deprecated( - "torch.onnx.verification.* is deprecated. 
Consider using torch.onnx.export(..., dynamo=True) " - "and use ONNXProgram to test the ONNX model" -) -def verify_aten_graph( - graph: torch.Graph, - input_args: tuple[Any, ...], - export_options: _experimental.ExportOptions, - params_dict: dict[str, Any] | None = None, - verification_options: VerificationOptions | None = None, -) -> tuple[AssertionError | None, torch.Graph, _OutputsType, _OutputsType]: - """Verify aten graph export to ONNX against original PyTorch model. - - .. deprecated:: 2.7 - Consider using ``torch.onnx.export(..., dynamo=True)`` and use the returned - ``ONNXProgram`` to test the ONNX model. - """ - if verification_options is None: - verification_options = VerificationOptions() - if params_dict is None: - params_dict = {} - - original_jit_graph = graph - graph = graph.copy() - - # Execute aten graph and get reference torch jit outputs. - graph_inputs = list(graph.inputs()) - jit_inputs = tuple([arg for arg in input_args if arg is not None]) - weights = [params_dict[v.debugName()] for v in graph_inputs[len(jit_inputs) :]] - assert all(w is not None for w in weights) - # TODO: Only copy the argument if mutation is detected in Graph. - jit_inputs = copy.deepcopy(jit_inputs) - jit_input_and_parameters = jit_inputs + tuple(weights) - jit_outs = torch._C._jit_interpret_graph(graph, jit_input_and_parameters) # type: ignore[attr-defined] - if not isinstance(jit_outs, (list, tuple)): - jit_outs = [jit_outs] - - # Convert aten graph to onnx graph. - graph, onnx_params_dict = _onnx_graph_from_aten_graph( - graph, export_options, params_dict - ) - - proto, export_map = _onnx_proto_from_onnx_graph( - graph, export_options, onnx_params_dict - ) - model_f: str | io.BytesIO = io.BytesIO() - onnx_proto_utils._export_file(proto, model_f, export_map) - - # NOTE: Verification is unstable. Try catch to emit information for debugging. - try: - # NOTE: Input might be dce'ed, so we need to remove those from the input args. 
- new_input_names = {v.debugName() for v in graph.inputs()} - new_input_args = [] - for v, arg in zip(original_jit_graph.inputs(), input_args): - if v.debugName() in new_input_names: - new_input_args.append(arg) - input_args = tuple(new_input_args) - - onnx_inputs = _prepare_input_for_onnx( - input_args, - {}, - verification_options.remained_onnx_input_idx, - verification_options.flatten, - ) - - onnx_session = _onnx_backend_session(model_f, verification_options.backend) - onnx_outs = _run_onnx(onnx_session, onnx_inputs) - del onnx_session # To free device memory - - try: - _compare_onnx_pytorch_outputs( - onnx_outs=onnx_outs, - pt_outs=jit_outs, - options=verification_options, - ) - except AssertionError as e: - return e, graph, jit_outs, onnx_outs - - return None, graph, jit_outs, onnx_outs - - except Exception as e: - print("Unexpected error during verification.") - print("jit graph: ", original_jit_graph) - print("onnx graph: ", graph) - raise e - - -class GraphInfoPrettyPrinter: - graph_info: GraphInfo | None - upper_printer: GraphInfoPrettyPrinter | None - lower_printer: GraphInfoPrettyPrinter | None - - graph_str_lambdas: Mapping[int, str] - connector_str_lambdas: Mapping[int, str] - children_str_lambdas: Mapping[int, str] - - def __init__(self, graph_info: GraphInfo | None): - self.graph_info = graph_info - if ( - graph_info is not None - and graph_info.upper_graph_info is not None - and graph_info.lower_graph_info is not None - ): - self.upper_printer = GraphInfoPrettyPrinter(graph_info.upper_graph_info) - self.lower_printer = GraphInfoPrettyPrinter(graph_info.lower_graph_info) - else: - self.upper_printer = None - self.lower_printer = None - - def _total_rows(self) -> int: - if self.graph_info is None: - return 1 - if self.upper_printer and self.lower_printer: - return ( - self.upper_printer._total_rows() + self.lower_printer._total_rows() + 1 - ) - return 2 # Two lines: node count + id. - - def _node_count_segment_str(self) -> str: - if self.graph_info is None: - return "..." 
- node_count = self.graph_info.essential_node_count() - has_mismatch = self.graph_info.has_mismatch() - error_node_kind = ( - f"({self.graph_info.essential_node_kinds().pop()})" - if node_count == 1 and has_mismatch - else "" - ) - - return f"{node_count} {'X' if has_mismatch else chr(0x2713)} {error_node_kind}" - - def _graph_id_segment_str(self) -> str: - if self.graph_info is None: - return "" - return f"id: {self.graph_info.id}" - - def _max_segment_columns(self) -> int: - return max( - map(len, (self._node_count_segment_str(), self._graph_id_segment_str())) - ) - - def _graph_segment_str_at_line(self, line: int) -> str: - """Get the string representation of the graph segment at the given line.""" - if line == 0: - result_str = self._node_count_segment_str() - result_str += " " * (self._max_segment_columns() - len(result_str)) - return result_str - if line == 1: - result_str = self._graph_id_segment_str() - result_str += " " * (self._max_segment_columns() - len(result_str)) - return result_str - if 0 <= line < self._total_rows(): - return " " * self._max_segment_columns() - return "" - - def _connector_segment_str_at_line(self, line: int) -> str: - """Get the connector segment string at the given line.""" - if self.upper_printer is None and self.lower_printer is None: - return "" - upper_total_rows = self.upper_printer._total_rows() if self.upper_printer else 1 - lower_total_rows = self.lower_printer._total_rows() if self.lower_printer else 1 - if line == 0: - return " __" - elif line < upper_total_rows + 1: - return " | " - elif line == upper_total_rows + 1: - return " |__" - elif line < upper_total_rows + lower_total_rows + 1: - return " " - return "" - - def _children_str_at_line(self, line: int) -> str: - """Get the string representation of the children at the given line. - - Recursively calls `_str_at_line` on children nodes. - """ - if self.upper_printer is None and self.lower_printer is None: - return "" - upper_total_rows = self.upper_printer._total_rows() if self.upper_printer else 1 - lower_total_rows = self.lower_printer._total_rows() if self.lower_printer else 1 - if 0 <= line < upper_total_rows: - return ( - self.upper_printer._str_at_line(line) if self.upper_printer else "..." - ) - elif upper_total_rows < line < upper_total_rows + lower_total_rows + 1: - return ( - self.lower_printer._str_at_line(line - upper_total_rows - 1) - if self.lower_printer - else "..." - ) - return "" - - def _str_at_line(self, line: int) -> str: - """Get the string representation of the graph at the given line.""" - return ( - self._graph_segment_str_at_line(line) - + self._connector_segment_str_at_line(line) - + self._children_str_at_line(line) - ) - - def pretty_print(self): - if self.graph_info is None: - print(None) - return - # Print tree. - print(" Tree: ".center(80, "=")) - total_rows = self._total_rows() - for line in range(total_rows): - print(self._str_at_line(line).rstrip()) - if self.graph_info.has_mismatch(): - # Summarize leaf subgraphs with mismatch. - print(" Mismatch leaf subgraphs: ".center(80, "=")) - print( - [ - graph_info.id - for graph_info in self.graph_info.all_mismatch_leaf_graph_info() - ] - ) - # Summarize node kinds with mismatch. 
- mismatch_node_kinds: dict[str, int] = {} - for graph_info in self.graph_info.all_mismatch_leaf_graph_info(): - node_kinds = graph_info.essential_node_kinds() - if len(node_kinds) == 1: - node_kind = node_kinds.pop() - mismatch_node_kinds[node_kind] = ( - mismatch_node_kinds.get(node_kind, 0) + 1 - ) - print(" Mismatch node kinds: ".center(80, "=")) - print(mismatch_node_kinds) - else: - print(" No mismatch found. ".center(80, "=")) - - -class OnnxTestCaseRepro: - def __init__(self, repro_dir): - self.repro_dir = repro_dir - self.proto, self.inputs, self.outputs = onnx_proto_utils.load_test_case( - repro_dir - ) - - @classmethod - def create_test_case_repro( - cls, proto: bytes, inputs, outputs, dir: str, name: str | None = None - ): - """Create a repro under "{dir}/test_{name}" for an ONNX test case. - - The test case contains the model and the inputs/outputs data. The directory - structure is as follows: - - dir - \u251c\u2500\u2500 test_ - \u2502 \u251c\u2500\u2500 model.onnx - \u2502 \u2514\u2500\u2500 test_data_set_0 - \u2502 \u251c\u2500\u2500 input_0.pb - \u2502 \u251c\u2500\u2500 input_1.pb - \u2502 \u251c\u2500\u2500 output_0.pb - \u2502 \u2514\u2500\u2500 output_1.pb - - Args: - proto: ONNX model proto. - inputs: Inputs to the model. - outputs: Outputs of the model. - dir: Directory to save the repro. - name: Name of the test case. If not specified, a name based on current time - will be generated. - Returns: - Path to the repro. - """ - if name is None: - name = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f") - return onnx_proto_utils.export_as_test_case( - proto, - _to_numpy(inputs), - _to_numpy(outputs), - name, - dir, - ) - - def validate(self, options: VerificationOptions): - """Run the ONNX test case with options.backend, and compare with the expected outputs. - - Args: - options: Options for validation. - - Raise: - AssertionError: if outputs from options.backend and expected outputs are not - equal up to specified precision. - """ - onnx_session = _onnx_backend_session(io.BytesIO(self.proto), options.backend) - run_outputs = onnx_session.run(None, self.inputs) - if hasattr(onnx_session, "get_outputs"): - output_names = [o.name for o in onnx_session.get_outputs()] - elif hasattr(onnx_session, "output_names"): - output_names = onnx_session.output_names - else: - raise ValueError(f"Unknown onnx session type: {type(onnx_session)}") - expected_outs = [self.outputs[name] for name in output_names] - _compare_onnx_pytorch_outputs_in_np(run_outputs, expected_outs, options) - - -@typing_extensions.deprecated( - "torch.onnx.verification.* is deprecated. Consider using torch.onnx.export(..., dynamo=True) " - "and use ONNXProgram to test the ONNX model" -) -@dataclasses.dataclass -class GraphInfo: - """GraphInfo contains validation information of a TorchScript graph and its converted ONNX graph. - - .. deprecated:: 2.7 - Consider using ``torch.onnx.export(..., dynamo=True)`` and use the returned - ``ONNXProgram`` to test the ONNX model. - """ - - graph: torch.Graph - input_args: tuple[Any, ...] 
- params_dict: dict[str, Any] - export_options: _experimental.ExportOptions = dataclasses.field( - default_factory=_experimental.ExportOptions - ) - mismatch_error: AssertionError | None = dataclasses.field(default=None, init=False) - pt_outs: Sequence[_NumericType] | None = dataclasses.field(default=None, init=False) - upper_graph_info: GraphInfo | None = dataclasses.field(default=None, init=False) - lower_graph_info: GraphInfo | None = dataclasses.field(default=None, init=False) - id: str = dataclasses.field(default="") - _onnx_graph: torch.Graph | None = dataclasses.field(init=False, default=None) - - _EXCLUDED_NODE_KINDS: frozenset[str] = frozenset( - {"prim::Constant", "prim::ListConstruct", "aten::ScalarImplicit"} - ) - - def clear(self): - """Clear states and results of previous verification.""" - self.mismatch_error = None - self.pt_outs = None - self._onnx_graph = None - self.upper_graph_info = None - self.lower_graph_info = None - - def pretty_print_tree(self): - """Pretty print `GraphInfo` tree. - - Each node represents a subgraph, showing the number of nodes in the subgraph and - a check mark if the subgraph has output mismatch between torch and ONNX. - - The id of the subgraph is shown under the node. The `GraphInfo` object for any - subgraph can be retrieved by calling `graph_info.find_partition(id)`. - - Example:: - - ==================================== Tree: ===================================== - 5 X __2 X __1 \u2713 - id: | id: 0 | id: 00 - | | - | |__1 X (aten::relu) - | id: 01 - | - |__3 X __1 \u2713 - id: 1 | id: 10 - | - |__2 X __1 X (aten::relu) - id: 11 | id: 110 - | - |__1 \u2713 - id: 111 - =========================== Mismatch leaf subgraphs: =========================== - ['01', '110'] - ============================= Mismatch node kinds: ============================= - {'aten::relu': 2} - - """ - GraphInfoPrettyPrinter(self).pretty_print() - - def pretty_print_mismatch(self, graph: bool = False): - """Pretty print details of the mismatch between torch and ONNX. - - Args: - graph: If True, print the ATen JIT graph and ONNX graph. - """ - print(f" Mismatch info for graph partition {self.id}: ".center(80, "=")) - if graph: - print(" ATen JIT graph ".center(80, "=")) - # TODO: A more compact graph printer. - # * Drop stride, grad, device information. - # * Show source location on a separate line. 
- print(self.graph) - if self._onnx_graph is not None: - print(" ONNX graph ".center(80, "=")) - print(self._onnx_graph) - if self.has_mismatch(): - print(" Mismatch error ".center(80, "=")) - print(self.mismatch_error) - else: - print(" No mismatch ".center(80, "=")) - - def has_mismatch(self) -> bool: - """Return True if the subgraph has output mismatch between torch and ONNX.""" - return self.mismatch_error is not None - - def essential_node_count(self) -> int: - """Return the number of nodes in the subgraph excluding those in `_EXCLUDED_NODE_KINDS`.""" - return sum( - 1 for n in self.graph.nodes() if n.kind() not in self._EXCLUDED_NODE_KINDS - ) - - def essential_node_kinds(self) -> set[str]: - """Return the set of node kinds in the subgraph excluding those in `_EXCLUDED_NODE_KINDS`.""" - return { - n.kind() - for n in self.graph.nodes() - if n.kind() not in self._EXCLUDED_NODE_KINDS - } - - def all_mismatch_leaf_graph_info(self) -> list[GraphInfo]: - """Return a list of all leaf `GraphInfo` objects that have mismatch.""" - if not self.has_mismatch(): - return [] - - no_mismatch_children = ( - self.upper_graph_info is None or not self.upper_graph_info.has_mismatch() - ) and ( - self.lower_graph_info is None or not self.lower_graph_info.has_mismatch() - ) - - if no_mismatch_children: - return [self] - - results = [] - if self.upper_graph_info is not None: - results += self.upper_graph_info.all_mismatch_leaf_graph_info() - if self.lower_graph_info is not None: - results += self.lower_graph_info.all_mismatch_leaf_graph_info() - - return results - - def find_partition(self, id: str) -> GraphInfo | None: - """Find the `GraphInfo` object with the given id.""" - if id == self.id: - return self - current_length = len(self.id) - if len(id) > current_length: - if id[current_length] == "0" and self.upper_graph_info is not None: - return self.upper_graph_info.find_partition(id) - elif id[current_length] == "1" and self.lower_graph_info is not None: - return self.lower_graph_info.find_partition(id) - return None - - def export_repro( - self, repro_dir: str | None = None, name: str | None = None - ) -> str: - """Export the subgraph to ONNX along with the input/output data for repro. - - The repro directory will contain the following files:: - - dir - \u251c\u2500\u2500 test_ - \u2502 \u251c\u2500\u2500 model.onnx - \u2502 \u2514\u2500\u2500 test_data_set_0 - \u2502 \u251c\u2500\u2500 input_0.pb - \u2502 \u251c\u2500\u2500 input_1.pb - \u2502 \u251c\u2500\u2500 output_0.pb - \u2502 \u2514\u2500\u2500 output_1.pb - - Args: - repro_dir: The directory to export the repro files to. Defaults to current - working directory if None. - name: An optional name for the test case folder: "test_{name}". - - Returns: - The path to the exported repro directory. - """ - - if repro_dir is None: - repro_dir = os.getcwd() - repro_dir = os.path.join(repro_dir, "onnx_debug") - - onnx_graph, onnx_params_dict = _onnx_graph_from_aten_graph( - self.graph, self.export_options, self.params_dict - ) - - proto, _ = _onnx_proto_from_onnx_graph( - onnx_graph, self.export_options, onnx_params_dict - ) - return OnnxTestCaseRepro.create_test_case_repro( - proto, self.input_args, self.pt_outs, repro_dir, name - ) - - def _graph_partition_pivot(self) -> int: - """Find the pivot index to partition the graph. - - The pivot is the node that splits the graph into two parts. Each part should - have the similar amount of nodes, excluding non essential ops, defined in - `_EXCLUDED_NODE_KINDS`, such as `prim::Constant`. 
- If the graph has an odd number of nodes, the upper part will have one more node. - If the graph does not have any node that can be partitioned, return -1. - - Returns: - The index of the pivot node. - """ - included_node_indices = [ - i - for i, n in enumerate(self.graph.nodes()) - if n.kind() not in self._EXCLUDED_NODE_KINDS - ] - half_idx = len(included_node_indices) // 2 - 1 - if half_idx >= 0 and len(included_node_indices) > half_idx: - return included_node_indices[half_idx] + 1 - return -1 - - def _partition_upper_graph(self) -> torch.Graph: - pivot = self._graph_partition_pivot() - if pivot == -1: - return torch.Graph() - graph = self.graph.copy() # Copy to not mutate parent graph. - original_outputs = list(graph.outputs()) - - def _process_bridge_value_for_upper( - new_outputs: list[torch.Value], bridge_value: torch.Value - ) -> torch.Value: - # Add bridge values as upper graph outputs. - new_outputs.append(bridge_value) - return bridge_value - - new_outputs: list[torch.Value] = [] - process_bridge_value_for_upper = functools.partial( - _process_bridge_value_for_upper, new_outputs - ) - _, dropped_nodes, complete_upper_nodes_set, _ = self._partition_nodes( - graph, pivot, process_bridge_value_for_upper - ) - - for _ in enumerate(original_outputs): - graph.eraseOutput(0) - for output in new_outputs: - graph.registerOutput(output) - - for node in reversed(dropped_nodes): - node.destroy() - - for i, input in reversed(list(enumerate(list(graph.inputs())))): - if ( - not _has_uses_by_nodes(input, complete_upper_nodes_set) - and input not in new_outputs - ): - try: - graph.eraseInput(i) - except RuntimeError as e: - print(input, graph) - raise e - - return graph - - def _partition_lower_graph(self) -> torch.Graph: - pivot = self._graph_partition_pivot() - if pivot == -1: - return torch.Graph() - graph = self.graph.copy() # Copy to not mutate parent graph. - original_outputs = list(graph.outputs()) - original_inputs = list(graph.inputs()) - - def _process_bridge_value_for_lower( - graph: torch.Graph, bridge_value: torch.Value - ) -> torch.Value: - # Add bridge values as lower graph inputs. 
-            new_input = graph.addInput()
-            bridge_value.replaceAllUsesWith(new_input)
-            new_input.copyMetadata(bridge_value)
-            return new_input
-
-        process_bridge_value_for_lower = functools.partial(
-            _process_bridge_value_for_lower, graph
-        )
-
-        upper_nodes, lower_nodes, _, complete_lower_nodes_set = self._partition_nodes(
-            graph, pivot, process_bridge_value_for_lower
-        )
-
-        new_outputs = [
-            output for output in original_outputs if _produced_by(output, lower_nodes)
-        ]
-        for _ in enumerate(original_outputs):
-            graph.eraseOutput(0)
-        for output in new_outputs:
-            graph.registerOutput(output)
-
-        for input in original_inputs:
-            if _has_uses_by_nodes(input, complete_lower_nodes_set):
-                new_input = graph.addInput()
-                input.replaceAllUsesWith(new_input)
-                new_input.copyMetadata(input)
-
-        for node in reversed(upper_nodes):
-            if node not in complete_lower_nodes_set:
-                try:
-                    node.destroy()
-                except RuntimeError as e:
-                    print(node, graph)
-                    raise e
-
-        for _ in original_inputs:
-            graph.eraseInput(0)
-
-        return graph
-
-    def _partition_node(
-        self,
-        node: torch.Node,
-        complete_upper_nodes_set: set[torch.Node],
-        complete_lower_nodes_set: set[torch.Node],
-        original_graph_outputs: set[torch.Value],
-        covered_bridge_values: set[torch.Value],
-        process_bridge_value: Callable[[torch.Value], torch.Value],
-    ):
-        if node in complete_lower_nodes_set:
-            return
-
-        if (
-            _node_has_uses_by(node, complete_lower_nodes_set)
-            and node.kind() in self._EXCLUDED_NODE_KINDS
-        ):
-            complete_lower_nodes_set.update(_all_nodes([node]))
-            for input in node.inputs():
-                if input in covered_bridge_values:
-                    continue
-                self._partition_node(
-                    input.node(),
-                    complete_upper_nodes_set,
-                    complete_lower_nodes_set,
-                    original_graph_outputs,
-                    covered_bridge_values,
-                    process_bridge_value,
-                )
-        else:
-            for output in node.outputs():
-                if output in covered_bridge_values:
-                    continue
-                if (
-                    _has_uses_by_nodes(output, complete_lower_nodes_set)
-                    or output in original_graph_outputs
-                ):
-                    covered_bridge_values.add(process_bridge_value(output))
-
-    def _partition_nodes(
-        self,
-        graph: torch.Graph,
-        pivot: int,
-        process_bridge_value: Callable[[torch.Value], torch.Value],
-    ) -> tuple[list[torch.Node], list[torch.Node], set[torch.Node], set[torch.Node]]:
-        nodes = list(graph.nodes())
-        upper_nodes = nodes[:pivot]
-        lower_nodes = nodes[pivot:]
-        # `upper_nodes` and `complete_upper_nodes_set` differ in that the latter
-        # recursively contains nodes in subblocks of `upper_nodes`.
-        # The same applies to `lower_nodes` and `complete_lower_nodes_set`.
-        # In addition, `complete_lower_nodes_set` will include nodes that
-        # are determined to be copied from `upper_nodes` to `lower_nodes`.
-        complete_upper_nodes_set = _all_nodes(upper_nodes)
-        complete_lower_nodes_set = _all_nodes(lower_nodes)
-        original_graph_outputs = set(graph.outputs())
-        # Bridge values are values produced from upper graph, and consumed
-        # by lower graph. These values need to become upper graph outputs
-        # and lower graph inputs, to bridge the interaction.
-        # Start with all graph inputs marked as covered. If any graph input is
-        # needed by lower graph, just keep it in lower graph inputs later.
- covered_bridge_values = set(graph.inputs()) - for node in upper_nodes: - self._partition_node( - node, - complete_upper_nodes_set, - complete_lower_nodes_set, - original_graph_outputs, - covered_bridge_values, - process_bridge_value, - ) - return ( - upper_nodes, - lower_nodes, - complete_upper_nodes_set, - complete_lower_nodes_set, - ) - - def _bridge_kwargs(self): - pt_outs = self.pt_outs - graph_outputs = list(self.graph.outputs()) - assert pt_outs is not None - assert len(graph_outputs) == len(pt_outs), ( - f"{len(graph_outputs)} vs {len(pt_outs)}\nGraph: {self.graph}" - ) - return {v.debugName(): o for v, o in zip(graph_outputs, pt_outs)} - - def _args_and_params_for_partition_graph( - self, - graph: torch.Graph, - bridge_kwargs: Mapping[str, _NumericType | Sequence[_NumericType]], - full_kwargs: Mapping[str, torch.Tensor], - full_params: Mapping[str, torch.Tensor], - ): - input_names = [input.debugName() for input in graph.inputs()] - args = tuple(bridge_kwargs[k] for k in input_names if k in bridge_kwargs) - args += tuple(full_kwargs[k] for k in input_names if k in full_kwargs) - params = {k: full_params[k] for k in input_names if k in full_params} - assert len(args) + len(params) == len(input_names), ( - f"{len(args)} + {len(params)} vs {len(input_names)}: {input_names}" - ) - return args, params - - def verify_export( - self, options: VerificationOptions - ) -> tuple[AssertionError | None, torch.Graph, _OutputsType, _OutputsType]: - """ - Verify the export from TorchScript IR graph to ONNX. - - Export the TorchScript IR graph to ONNX, with the inputs, parameters and export - options recorded in this object. Then verify the exported ONNX graph against - the original TorchScript IR graph under the provided verification options. - - Args: - options: The verification options. - - Returns: - error: The AssertionError raised during the verification. Returns None if no - error is raised. - onnx_graph: The exported ONNX graph in TorchScript IR format. - onnx_outs: The outputs from running exported ONNX model under the onnx - backend in `options`. - pt_outs: The outputs from running the TorchScript IR graph. - """ - return verify_aten_graph( - self.graph, - input_args=self.input_args, - params_dict=self.params_dict, - export_options=self.export_options, - verification_options=options, - ) - - def find_mismatch( - self, - options: VerificationOptions | None = None, - ): - """ - Find all mismatches between the TorchScript IR graph and the exported onnx model. - - Binary searches the model graph to find the minimal subgraph that exhibits the - mismatch. A `GraphInfo` object is created for each subgraph, recording the test - inputs and export options, as well as the validation results. - - Args: - options: The verification options. - """ - self.clear() - - if options is None: - options = VerificationOptions() - - if self.export_options.verbose: - print(self.graph) - - if len(list(self.graph.outputs())) == 0: - return - - assert len(self.input_args) + len(self.params_dict) == len( - list(self.graph.inputs()) - ), ( - f"Number of graph inputs({len(list(self.graph.inputs()))}) does not match " - f"the provided tensor arguments({len(self.input_args)} + {len(self.params_dict)})." - ) - - self.mismatch_error, self._onnx_graph, self.pt_outs, _ = self.verify_export( - options - ) - - if self.mismatch_error is None: - # No mismatch found in graph. - return - - if self.essential_node_count() <= 1: - # Reached leaf node, no more partitioning. 
- return - - full_kwargs = { - k.debugName(): v for k, v in zip(self.graph.inputs(), self.input_args) - } - full_params = self.params_dict - - upper_graph = self._partition_upper_graph() - upper_args, upper_params = self._args_and_params_for_partition_graph( - upper_graph, {}, full_kwargs, full_params - ) - self.upper_graph_info = GraphInfo( - upper_graph, - upper_args, - upper_params, - self.export_options, - id=self.id + "0", - ) - - self.upper_graph_info.find_mismatch(options) - - bridge_kwargs = self.upper_graph_info._bridge_kwargs() - lower_graph = self._partition_lower_graph() - lower_args, lower_params = self._args_and_params_for_partition_graph( - lower_graph, bridge_kwargs, full_kwargs, full_params - ) - self.lower_graph_info = GraphInfo( - lower_graph, - lower_args, - lower_params, - self.export_options, - id=self.id + "1", - ) - - self.lower_graph_info.find_mismatch(options) - - -def _all_nodes(nodes: Collection[torch.Node]) -> set[torch.Node]: - all_nodes = set(nodes) - for n in nodes: - for b in n.blocks(): - all_nodes.update(_all_nodes(list(b.nodes()))) - return all_nodes - - -def _has_uses_by_nodes(value: torch.Value, nodes: Collection[torch.Node]) -> bool: - return any(use.user in nodes for use in value.uses()) - - -def _node_has_uses_by(node: torch.Node, nodes: Collection[torch.Node]) -> bool: - for output in node.outputs(): - if _has_uses_by_nodes(output, nodes): - return True - return False - - -def _produced_by(value: torch.Value, nodes: Collection[torch.Node]) -> bool: - return value.node() in nodes - - -@typing_extensions.deprecated( - "torch.onnx.verification.* is deprecated. Consider using torch.onnx.export(..., dynamo=True) " - "and use ONNXProgram to test the ONNX model" -) -def find_mismatch( - model: torch.nn.Module | torch.jit.ScriptModule, - input_args: tuple[Any, ...], - do_constant_folding: bool = True, - training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL, - opset_version: int | None = None, - keep_initializers_as_inputs: bool = True, - verbose: bool = False, - options: VerificationOptions | None = None, -) -> GraphInfo: - r"""Find all mismatches between the original model and the exported model. - - .. deprecated:: 2.7 - Consider using ``torch.onnx.export(..., dynamo=True)`` and use the returned - ``ONNXProgram`` to test the ONNX model. - - Experimental. The API is subject to change. - - This tool helps debug the mismatch between the original PyTorch model and exported - ONNX model. It binary searches the model graph to find the minimal subgraph that - exhibits the mismatch. - - Args: - model: The model to be exported. - input_args: The input arguments to the model. - do_constant_folding: Same as `do_constant_folding` in :func:`torch.onnx.export`. - training: Same as `training` in :func:`torch.onnx.export`. - opset_version: Same as `opset_version` in :func:`torch.onnx.export`. - keep_initializers_as_inputs: Same as `keep_initializers_as_inputs` in :func:`torch.onnx.export`. - verbose: Same as `verbose` in :func:`torch.onnx.export`. - options: The options for the mismatch verification. - - Returns: - A GraphInfo object that contains the mismatch information. - - Example:: - - >>> import torch - >>> import torch.onnx.verification - >>> torch.manual_seed(0) - >>> opset_version = 15 - >>> # Define a custom symbolic function for aten::relu. - >>> # The custom symbolic function is incorrect, which will result in mismatches. - >>> def incorrect_relu_symbolic_function(g, self): - ... return self - >>> torch.onnx.register_custom_op_symbolic( - ... 
"aten::relu", - ... incorrect_relu_symbolic_function, - ... opset_version=opset_version, - ... ) - >>> class Model(torch.nn.Module): - ... def __init__(self) -> None: - ... super().__init__() - ... self.layers = torch.nn.Sequential( - ... torch.nn.Linear(3, 4), - ... torch.nn.ReLU(), - ... torch.nn.Linear(4, 5), - ... torch.nn.ReLU(), - ... torch.nn.Linear(5, 6), - ... ) - ... def forward(self, x): - ... return self.layers(x) - >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX) - >>> graph_info = torch.onnx.verification.find_mismatch( - ... Model(), - ... (torch.randn(2, 3),), - ... opset_version=opset_version, - ... ) - ===================== Mismatch info for graph partition : ====================== - ================================ Mismatch error ================================ - Tensor-likes are not close! - Mismatched elements: 12 / 12 (100.0%) - Greatest absolute difference: 0.2328854203224182 at index (1, 2) (up to 1e-07 allowed) - Greatest relative difference: 0.699536174352349 at index (1, 3) (up to 0.001 allowed) - ==================================== Tree: ===================================== - 5 X __2 X __1 \u2713 - id: | id: 0 | id: 00 - | | - | |__1 X (aten::relu) - | id: 01 - | - |__3 X __1 \u2713 - id: 1 | id: 10 - | - |__2 X __1 X (aten::relu) - id: 11 | id: 110 - | - |__1 \u2713 - id: 111 - =========================== Mismatch leaf subgraphs: =========================== - ['01', '110'] - ============================= Mismatch node kinds: ============================= - {'aten::relu': 2} - - """ - if options is None: - options = VerificationOptions() - if opset_version is None: - opset_version = _constants.ONNX_DEFAULT_OPSET - """From aten graph, do binary search on graph partition to find operator export discrepancy.""" - # TODO: Copied from utils.py `export` until `_optimize_graph`. - if training == torch.onnx.TrainingMode.TRAINING: - model.train() - elif training == torch.onnx.TrainingMode.EVAL: - model.eval() - with torch.no_grad(): - inputs_for_export = _prepare_input_for_export(input_args, {}) - args = utils._decide_input_format(model, inputs_for_export) - - model = utils._pre_trace_quant_model(model, args) - graph, params, _torch_out, _module = utils._create_jit_graph(model, args) - params_dict = utils._get_named_param_dict(graph, params) - - utils._apply_friendly_debug_names(graph, params_dict) - - graph_info = GraphInfo( - graph, - input_args, - params_dict, - _experimental.ExportOptions( - do_constant_folding=do_constant_folding, - training=training, - opset_version=opset_version, - keep_initializers_as_inputs=keep_initializers_as_inputs, - verbose=verbose, - ), - ) - graph_info.find_mismatch(options) - graph_info.pretty_print_mismatch() - graph_info.pretty_print_tree() - - return graph_info