Add symbolic_opset19.py and symbolic_opset20.py to support opset 19/20, extend opset 18 support (#118828)

Start to fix https://github.com/pytorch/pytorch/issues/114801

Co-authored-by: Thiago Crepaldi <thiagofc@microsoft.com>
Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118828
Approved by: https://github.com/thiagocrepaldi

committed by PyTorch MergeBot
parent 34d33df056
commit bbe846f430
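For context, a minimal export sketch (the module and file names here are hypothetical, not part of this change) showing the TorchScript-based exporter targeting the newly selectable opset:

import torch

class TinyModel(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(x)

# Request ONNX opset 20, which this change wires into torch.onnx.export.
torch.onnx.export(TinyModel(), torch.randn(1, 3), "tiny_model.onnx", opset_version=20)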
@ -77,6 +77,8 @@ also be interested in reading our `development wiki <https://github.com/pytorch/
.. py:module:: torch.onnx.symbolic_opset16
.. py:module:: torch.onnx.symbolic_opset17
.. py:module:: torch.onnx.symbolic_opset18
.. py:module:: torch.onnx.symbolic_opset19
.. py:module:: torch.onnx.symbolic_opset20
.. py:module:: torch.onnx.symbolic_opset7
.. py:module:: torch.onnx.symbolic_opset8
.. py:module:: torch.onnx.symbolic_opset9
@ -64,9 +64,8 @@ skipIfQuantizationBackendQNNPack = _skipper(


# skips tests for all versions below min_opset_version.
# if exporting the op is only supported after a specific version,
# add this wrapper to prevent running the test for opset_versions
# smaller than the currently tested opset_version
# smaller than `min_opset_version`.
def skipIfUnsupportedMinOpsetVersion(min_opset_version):
    def skip_dec(func):
        @functools.wraps(func)
@ -83,6 +82,8 @@ def skipIfUnsupportedMinOpsetVersion(min_opset_version):


# skips tests for all versions above max_opset_version.
# add this wrapper to prevent running the test for opset_versions
# higher than `max_opset_version`.
def skipIfUnsupportedMaxOpsetVersion(max_opset_version):
    def skip_dec(func):
        @functools.wraps(func)
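A standalone sketch (not the helper in this diff) of the opset-gating pattern these decorators implement, assuming the test class exposes self.opset_version:

import functools
import unittest

def skip_if_unsupported_min_opset(min_opset_version):
    # Skip the wrapped test when the opset under test is below the operator's minimum.
    def skip_dec(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            if self.opset_version < min_opset_version:
                raise unittest.SkipTest(f"needs opset >= {min_opset_version}")
            return func(self, *args, **kwargs)
        return wrapper
    return skip_dec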
@ -480,43 +480,157 @@ class TestONNXOpset(pytorch_test_common.ExportTestCase):
|
||||
x = torch.randn(20, 16, 50)
|
||||
check_onnx_opsets_operator(MyDynamicModel(), x, ops, opset_versions=[9, 10])
|
||||
|
||||
def test_grid_sample(self):
|
||||
n, c, h_in, w_in, h_out, w_out = 1, 1, 3, 2, 2, 4
|
||||
ops = {16: [{"op_name": "GridSample"}]}
|
||||
|
||||
def test_affine_grid(self):
|
||||
class MyModule(Module):
|
||||
def forward(self, x, grid, mode, padding_mode, align_corers):
|
||||
return torch.nn.functional.grid_sample(
|
||||
x, grid, mode, padding_mode, align_corners
|
||||
def __init__(self, align_corners):
|
||||
super().__init__()
|
||||
self.align_corners = align_corners
|
||||
|
||||
def forward(self, theta, size):
|
||||
return torch.nn.functional.affine_grid(
|
||||
theta, size, align_corners=self.align_corners
|
||||
)
|
||||
|
||||
for mode, padding_mode, align_corners in itertools.product(
|
||||
("bilinear", "nearest", "bicubic"),
|
||||
("zeros", "border", "reflection"),
|
||||
opset_version = 20
|
||||
ops_2d = {
|
||||
opset_version: [
|
||||
{"op_name": "Constant"},
|
||||
{"op_name": "Unsqueeze"},
|
||||
{"op_name": "Constant"},
|
||||
{"op_name": "Unsqueeze"},
|
||||
{"op_name": "Constant"},
|
||||
{"op_name": "Unsqueeze"},
|
||||
{"op_name": "Constant"},
|
||||
{"op_name": "Unsqueeze"},
|
||||
{"op_name": "Concat"},
|
||||
{"op_name": "AffineGrid"},
|
||||
]
|
||||
}
|
||||
|
||||
ops_3d = {
|
||||
opset_version: [
|
||||
{"op_name": "Constant"},
|
||||
{"op_name": "Unsqueeze"},
|
||||
{"op_name": "Constant"},
|
||||
{"op_name": "Unsqueeze"},
|
||||
{"op_name": "Constant"},
|
||||
{"op_name": "Unsqueeze"},
|
||||
{"op_name": "Constant"},
|
||||
{"op_name": "Unsqueeze"},
|
||||
{"op_name": "Constant"},
|
||||
{"op_name": "Unsqueeze"},
|
||||
{"op_name": "Concat"},
|
||||
{"op_name": "AffineGrid"},
|
||||
]
|
||||
}
|
||||
# 2D affine
|
||||
theta_2d = torch.empty(1, 2, 3, dtype=torch.double)
|
||||
size_2d = torch.Size([1, 1, 2, 2])
|
||||
# 3D affine
|
||||
theta_3d = torch.empty(1, 3, 4, dtype=torch.double)
|
||||
size_3d = torch.Size([1, 1, 2, 2, 2])
|
||||
|
||||
for inputs, align_corners in itertools.product(
|
||||
((theta_2d, size_2d, ops_2d), (theta_3d, size_3d, ops_3d)),
|
||||
(True, False),
|
||||
):
|
||||
theta, size, ops = inputs
|
||||
args = (
|
||||
torch.randn(n, c, h_in, w_in), # x
|
||||
torch.randn(n, h_out, w_out, 2), # grid,
|
||||
mode,
|
||||
padding_mode,
|
||||
align_corners,
|
||||
theta,
|
||||
size,
|
||||
)
|
||||
check_onnx_opsets_operator(
|
||||
MyModule(),
|
||||
MyModule(align_corners=align_corners),
|
||||
args,
|
||||
ops,
|
||||
opset_versions=[16],
|
||||
opset_versions=[opset_version],
|
||||
training=torch.onnx.TrainingMode.TRAINING,
|
||||
)
|
||||
check_onnx_opsets_operator(
|
||||
MyModule(),
|
||||
MyModule(align_corners=align_corners),
|
||||
args,
|
||||
ops,
|
||||
opset_versions=[16],
|
||||
opset_versions=[opset_version],
|
||||
training=torch.onnx.TrainingMode.EVAL,
|
||||
)
|
||||
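For reference, a shape-only sketch (independent of the test harness above) of the affine_grid inputs this test exercises; opset 20 maps the call onto the ONNX AffineGrid operator:

import torch

theta_2d = torch.randn(1, 2, 3, dtype=torch.double)
grid_2d = torch.nn.functional.affine_grid(theta_2d, [1, 1, 2, 2], align_corners=False)
print(grid_2d.shape)  # torch.Size([1, 2, 2, 2])

theta_3d = torch.randn(1, 3, 4, dtype=torch.double)
grid_3d = torch.nn.functional.affine_grid(theta_3d, [1, 1, 2, 2, 2], align_corners=False)
print(grid_3d.shape)  # torch.Size([1, 2, 2, 2, 3])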
|
||||
def test_grid_sample(self):
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self, mode, padding_mode, align_corners):
|
||||
super().__init__()
|
||||
self.mode = mode
|
||||
self.padding_mode = padding_mode
|
||||
self.align_corners = align_corners
|
||||
|
||||
def forward(self, x, grid):
|
||||
return torch.nn.functional.grid_sample(
|
||||
x,
|
||||
grid,
|
||||
mode=self.mode,
|
||||
padding_mode=self.padding_mode,
|
||||
align_corners=self.align_corners,
|
||||
)
|
||||
|
||||
for mode, padding_mode, align_corners, opset_version in itertools.product(
|
||||
("bilinear", "nearest", "bicubic"),
|
||||
("zeros", "border", "reflection"),
|
||||
(True, False),
|
||||
(16, 20),
|
||||
):
|
||||
|
||||
def test_eval_and_training(
|
||||
ops, opset_version, mode, padding_mode, align_corners, x_shape, grid
|
||||
):
|
||||
args = (
|
||||
torch.randn(*x_shape), # x
|
||||
torch.randn(grid), # grid,
|
||||
)
|
||||
check_onnx_opsets_operator(
|
||||
MyModule(
|
||||
mode=mode,
|
||||
padding_mode=padding_mode,
|
||||
align_corners=align_corners,
|
||||
),
|
||||
args,
|
||||
ops,
|
||||
opset_versions=[opset_version],
|
||||
training=torch.onnx.TrainingMode.TRAINING,
|
||||
)
|
||||
check_onnx_opsets_operator(
|
||||
MyModule(
|
||||
mode=mode,
|
||||
padding_mode=padding_mode,
|
||||
align_corners=align_corners,
|
||||
),
|
||||
args,
|
||||
ops,
|
||||
opset_versions=[opset_version],
|
||||
training=torch.onnx.TrainingMode.EVAL,
|
||||
)
|
||||
|
||||
ops = {opset_version: [{"op_name": "GridSample"}]}
|
||||
# mode = convert_grid_sample_mode(mode) if opset_version == 20 else mode
|
||||
n, c, d_in, h_in, w_in, d_out, h_out, w_out = 1, 1, 2, 3, 2, 3, 2, 4
|
||||
test_eval_and_training(
|
||||
ops,
|
||||
opset_version,
|
||||
mode,
|
||||
padding_mode,
|
||||
align_corners,
|
||||
(n, c, h_in, w_in),
|
||||
(n, h_out, w_out, 2),
|
||||
)
|
||||
if opset_version == 20 and mode != "bicubic":
|
||||
test_eval_and_training(
|
||||
ops,
|
||||
opset_version,
|
||||
mode,
|
||||
padding_mode,
|
||||
align_corners,
|
||||
(n, c, d_in, h_in, w_in),
|
||||
(n, d_out, h_out, w_out, 3),
|
||||
)
|
||||
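A shape-only sketch (separate from the test above) of the volumetric grid_sample case that becomes exportable at opset >= 20:

import torch

x = torch.randn(1, 1, 2, 3, 2)     # N, C, D_in, H_in, W_in
grid = torch.randn(1, 3, 2, 4, 3)  # N, D_out, H_out, W_out, 3
out = torch.nn.functional.grid_sample(x, grid, mode="bilinear", align_corners=False)
print(out.shape)  # torch.Size([1, 1, 3, 2, 4])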
|
||||
def test_flatten(self):
|
||||
class MyModule(Module):
|
||||
def forward(self, x):
|
||||
|
@ -1604,7 +1604,9 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
    # TODO: ceil_mode is not included in the test, because of
    # https://github.com/microsoft/onnxruntime/issues/16203
    # ORT and PyTorch have different calculations for ceil_mode (the last value).
    @skipIfUnsupportedMinOpsetVersion(19)
    # the issue requires a fix in onnx(21) (https://github.com/onnx/onnx/issues/5711);
    # a fix in ORT is planned. After the fixes are in place, we can add ceil_mode to the test.
    @skipIfUnsupportedMinOpsetVersion(21)
    def test_avgpool_3d_ceil(self):
        model = torch.nn.AvgPool3d(3, 2, ceil_mode=True)
        x = torch.randn(20, 16, 50, 44, 31)
|
||||
@ -13393,6 +13395,113 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
|
||||
self.run_test(Module(False), x, rtol=1e-3, atol=1e-6)
|
||||
self.run_test(Module(True), x, rtol=1e-3, atol=1e-6)
|
||||
|
||||
class AffineGridModule(torch.nn.Module):
|
||||
def __init__(self, align_corners) -> None:
|
||||
super().__init__()
|
||||
self.align_corners = align_corners
|
||||
|
||||
def forward(self, theta, size):
|
||||
return torch.nn.functional.affine_grid(theta, size, self.align_corners)
|
||||
|
||||
@skipIfUnsupportedMinOpsetVersion(20)
|
||||
@skipScriptTest()
|
||||
@common_utils.parametrize(
|
||||
"align_corners",
|
||||
(True, False),
|
||||
)
|
||||
@common_utils.parametrize(
|
||||
"theta_params",
|
||||
(
|
||||
(
|
||||
10,
|
||||
np.array([0.3, -0.5]),
|
||||
np.array([1.5, 0.5]),
|
||||
),
|
||||
(
|
||||
60,
|
||||
np.array([-0.5, -0.5]),
|
||||
np.array([3.0, 5.5]),
|
||||
),
|
||||
),
|
||||
)
|
||||
@common_utils.parametrize(
|
||||
"size",
|
||||
([1, 1, 3, 2], [2, 10, 2, 3]),
|
||||
)
|
||||
def test_affine_grid_2d(self, align_corners, theta_params, size):
|
||||
angle, translation, scale = theta_params
|
||||
theta = np.array([], dtype=np.float32)
|
||||
for _ in range(size[0]):
|
||||
angle_radian = (angle / 180.0) * np.pi
|
||||
theta = np.append(
|
||||
theta,
|
||||
[
|
||||
np.cos(angle_radian) * scale[0],
|
||||
-np.sin(angle_radian),
|
||||
translation[0],
|
||||
np.sin(angle_radian),
|
||||
np.cos(angle_radian) * scale[1],
|
||||
translation[1],
|
||||
],
|
||||
)
|
||||
theta = theta.reshape(size[0], 2, 3)
|
||||
theta = torch.Tensor(theta)
|
||||
self.run_test(TestONNXRuntime.AffineGridModule(align_corners), (theta, size))
|
||||
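A worked sketch (values chosen purely for illustration) of the theta layout the loop above assembles: a rotation angle, per-axis scale, and a translation packed into an (N, 2, 3) matrix, then expanded into a sampling grid:

import numpy as np
import torch

angle = np.deg2rad(10.0)
scale, translation = (1.5, 0.5), (0.3, -0.5)
theta = torch.tensor(
    [[[np.cos(angle) * scale[0], -np.sin(angle), translation[0]],
      [np.sin(angle), np.cos(angle) * scale[1], translation[1]]]],
    dtype=torch.float32,
)
grid = torch.nn.functional.affine_grid(theta, [1, 1, 3, 2], align_corners=True)
print(grid.shape)  # torch.Size([1, 3, 2, 2])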
|
||||
@skipIfUnsupportedMinOpsetVersion(20)
|
||||
@skipScriptTest()
|
||||
@common_utils.parametrize(
|
||||
"align_corners",
|
||||
(True, False),
|
||||
)
|
||||
@common_utils.parametrize(
|
||||
"theta_params",
|
||||
(
|
||||
(
|
||||
[10, 20],
|
||||
np.array([0.3, -0.5, 1.8]),
|
||||
np.array([1.5, 2.0, 0.5]),
|
||||
),
|
||||
(
|
||||
[60, -30],
|
||||
np.array([-0.5, -0.5, 0.3]),
|
||||
np.array([0.3, 3.0, 5.5]),
|
||||
),
|
||||
),
|
||||
)
|
||||
@common_utils.parametrize(
|
||||
"size",
|
||||
([1, 1, 3, 2, 2], [2, 10, 2, 2, 3]),
|
||||
)
|
||||
def test_affine_grid_3d(self, align_corners, theta_params, size):
|
||||
angle, translation, scale = theta_params
|
||||
theta = np.array([], dtype=np.float32)
|
||||
for _ in range(size[0]):
|
||||
angle_radian_x = (angle[0] / 180.0) * np.pi
|
||||
angle_radian_y = (angle[1] / 180.0) * np.pi
|
||||
rot_matrix_x = np.array(
|
||||
[
|
||||
[1, 0, 0],
|
||||
[0, np.cos(angle_radian_x), -np.sin(angle_radian_x)],
|
||||
[0, np.sin(angle_radian_x), np.cos(angle_radian_x)],
|
||||
]
|
||||
)
|
||||
rot_matrix_y = np.array(
|
||||
[
|
||||
[np.cos(angle_radian_y), 0, np.sin(angle_radian_y)],
|
||||
[0, 1, 0],
|
||||
[-np.sin(angle_radian_y), 0, np.cos(angle_radian_y)],
|
||||
]
|
||||
)
|
||||
rot_matrix = np.matmul(rot_matrix_x, rot_matrix_y)
|
||||
rot_matrix = rot_matrix * scale.reshape(3, 1)
|
||||
rot_matrix = np.append(rot_matrix, np.reshape(translation, (3, 1)), axis=1)
|
||||
theta = np.append(theta, rot_matrix.flatten())
|
||||
|
||||
theta = theta.reshape(size[0], 3, 4)
|
||||
theta = torch.Tensor(theta)
|
||||
self.run_test(TestONNXRuntime.AffineGridModule(align_corners), (theta, size))
|
||||
|
||||
@skipIfUnsupportedMinOpsetVersion(16)
|
||||
@common_utils.parametrize(
|
||||
"mode",
|
||||
@ -13408,7 +13517,15 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
|
||||
name_fn=lambda align_corners: str(align_corners),
|
||||
)
|
||||
def test_grid_sample(self, mode, padding_mode, align_corners):
|
||||
n, c, h_in, w_in, h_out, w_out = 1, 1, 3, 2, 2, 4
|
||||
n, c, d_in, h_in, w_in, d_out, h_out, w_out = 1, 1, 2, 3, 2, 3, 2, 4
|
||||
|
||||
atol_rtol = {}
|
||||
if (mode, padding_mode) == ("bicubic", "border"):
|
||||
if align_corners:
|
||||
atol_rtol.update({"atol": 0.3, "rtol": 0.4})
|
||||
else:
|
||||
atol_rtol.update({"atol": 0.02, "rtol": 0.02})
|
||||
input, grid = torch.randn(n, c, h_in, w_in), torch.randn(n, h_out, w_out, 2)
|
||||
|
||||
class GridSampleModule(torch.nn.Module):
|
||||
def __init__(self, mode, padding_mode, align_corners) -> None:
|
||||
@ -13424,13 +13541,6 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
|
||||
input, grid, self.mode, self.padding_mode, self.align_corners
|
||||
)
|
||||
|
||||
atol_rtol = {}
|
||||
if (mode, padding_mode) == ("bicubic", "border"):
|
||||
if align_corners:
|
||||
atol_rtol.update({"atol": 0.3, "rtol": 0.4})
|
||||
else:
|
||||
atol_rtol.update({"atol": 0.02, "rtol": 0.02})
|
||||
input, grid = torch.randn(n, c, h_in, w_in), torch.randn(n, h_out, w_out, 2)
|
||||
self.run_test(
|
||||
GridSampleModule(mode, padding_mode, align_corners),
|
||||
(input, grid),
|
||||
@ -13438,8 +13548,6 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
|
||||
)
|
||||
|
||||
# ONNX Opset 16 GridSample with 5D volumetric input is not supported.
|
||||
d_in = 2
|
||||
d_out = 3
|
||||
volumetric_input_tensor = torch.randn(n, c, d_in, h_in, w_in)
|
||||
volumetric_grid_tensor = torch.randn(n, d_out, h_out, w_out, 3)
|
||||
for mode, padding_mode, align_corners in itertools.product(
|
||||
@ -13457,9 +13565,16 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
|
||||
False,
|
||||
),
|
||||
):
|
||||
with self.assertRaises(
|
||||
torch.onnx.errors.OnnxExporterError,
|
||||
):
|
||||
if self.opset_version < 20:
|
||||
with self.assertRaises(
|
||||
torch.onnx.errors.OnnxExporterError,
|
||||
):
|
||||
self.run_test(
|
||||
GridSampleModule(mode, padding_mode, align_corners),
|
||||
(volumetric_input_tensor, volumetric_grid_tensor),
|
||||
**atol_rtol,
|
||||
)
|
||||
else:
|
||||
self.run_test(
|
||||
GridSampleModule(mode, padding_mode, align_corners),
|
||||
(volumetric_input_tensor, volumetric_grid_tensor),
|
||||
|
@ -256,6 +256,7 @@ class TestUtilityFuns(_BaseTestCase):
|
||||
self.assertNotEqual(node.kind(), "onnx::Cast")
|
||||
self.assertEqual(len(list(graph.nodes())), 2)
|
||||
|
||||
@skipIfUnsupportedMaxOpsetVersion(17)
|
||||
def test_constant_fold_reduceL2(self):
|
||||
class ReduceModule(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
@ -273,6 +274,7 @@ class TestUtilityFuns(_BaseTestCase):
|
||||
for node in graph.nodes():
|
||||
self.assertNotEqual(node.kind(), "onnx::ReduceL2")
|
||||
|
||||
@skipIfUnsupportedMaxOpsetVersion(17)
|
||||
def test_constant_fold_reduceL1(self):
|
||||
class NormModule(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
|
@ -916,11 +916,15 @@ void ProcessReduceNode(Node* n) {
|
||||
size_t rank_0 = input_shape_value_0.value().size();
|
||||
std::vector<::c10::ShapeSymbol> final_shape;
|
||||
std::vector<int64_t> axes_vector(rank_0);
|
||||
if (!n->hasAttributeS("axes")) {
|
||||
std::iota(axes_vector.begin(), axes_vector.end(), 0);
|
||||
} else {
|
||||
if (n->hasAttributeS("axes")) {
|
||||
axes_vector = n->is(attr::axes);
|
||||
} else if (n->inputs().size() > 1) {
|
||||
axes_vector =
|
||||
ConstantValueMap::GetValueInto1DInt64Vector(n->input(1)->debugName());
|
||||
} else {
|
||||
std::iota(axes_vector.begin(), axes_vector.end(), 0);
|
||||
}
|
||||
|
||||
for (auto idx : c10::irange(axes_vector.size())) {
|
||||
if (axes_vector[idx] < 0) {
|
||||
axes_vector[idx] += rank_0;
|
||||
|
@ -91,7 +91,7 @@ namespace onnx_torch = ::torch::onnx;
namespace onnx = ::ONNX_NAMESPACE;

const static int kInvalidOpsetVersion = -1;
const static int kMainOpsetVersion = 19;
const static int kMainOpsetVersion = 20;
// Based on OP_SET_ID_VERSION_MAP in
// https://github.com/onnx/onnx/blob/master/onnx/helper.py.
constexpr static std::array<int64_t, kMainOpsetVersion + 1>
@ -116,6 +116,7 @@ constexpr static std::array<int64_t, kMainOpsetVersion + 1>
    8, // opset 17
    8, // opset 18
    9, // opset 19
    9, // opset 20
};

std::string getNodeStackTraceString(const Node* n) {
|
||||
|
@ -24,6 +24,8 @@ from . import ( # usort:skip. Keep the order instead of sorting lexicographical
    symbolic_opset16,
    symbolic_opset17,
    symbolic_opset18,
    symbolic_opset19,
    symbolic_opset20,
    utils,
)

@ -82,6 +84,8 @@ __all__ = [
    "symbolic_opset16",
    "symbolic_opset17",
    "symbolic_opset18",
    "symbolic_opset19",
    "symbolic_opset20",
    # Enums
    "ExportTypes",
    "OperatorExportTypes",
|
||||
|
@ -4,8 +4,8 @@ ONNX_ARCHIVE_MODEL_PROTO_NAME = "__MODEL_PROTO"

ONNX_BASE_OPSET = 9
ONNX_MIN_OPSET = 7
ONNX_MAX_OPSET = 19
ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET = 17
ONNX_MAX_OPSET = 20
ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET = 20
# ONNX_DEFAULT_OPSET generated by tools/onnx/update_default_opset_version.py
ONNX_DEFAULT_OPSET = 17
ONNX_CONSTANT_FOLDING_MIN_OPSET = 9
|
||||
|
@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
import math
|
||||
import sys
|
||||
import typing
|
||||
import warnings
|
||||
@ -23,27 +24,11 @@ import torch._C._onnx as _C_onnx
|
||||
from torch import _C
|
||||
|
||||
# Monkey-patch graph manipulation methods on Graph, used for the ONNX symbolics
|
||||
from torch.onnx import _constants, _type_utils, errors
|
||||
from torch.onnx import _constants, _type_utils, errors, utils
|
||||
from torch.onnx._globals import GLOBALS
|
||||
from torch.onnx._internal import _beartype, jit_utils
|
||||
from torch.types import Number
|
||||
|
||||
__all__ = [
|
||||
"args_have_same_dtype",
|
||||
"cast_pytorch_to_onnx",
|
||||
"check_training_mode",
|
||||
"dequantize_helper",
|
||||
"is_caffe2_aten_fallback",
|
||||
"is_complex_value",
|
||||
"parse_args",
|
||||
"pytorch_name_to_type",
|
||||
"quantize_helper",
|
||||
"quantized_args",
|
||||
"requantize_bias_helper",
|
||||
"scalar_name_to_pytorch",
|
||||
"scalar_type_to_onnx",
|
||||
"scalar_type_to_pytorch_type",
|
||||
]
|
||||
|
||||
# ---------------------------------------------------------------------------------
|
||||
# Helper functions
|
||||
@ -1714,6 +1699,526 @@ def args_have_same_dtype(args):
|
||||
return has_same_dtype
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _op_with_optional_float_cast(g: jit_utils.GraphContext, op_name, *args, **kwargs):
    """Some PyTorch operators (e.g., Clip/Min/ReLU/Pad) are a superset of ONNX in terms of data types.

    This function maximizes the exportability of PyTorch-ONNX by allowing ONNX-unsupported PyTorch
    operator data types. For example, `Cast<int>(Clip<float>(Cast<float>(INPUT)))` can be used to mimic
    `Clip<int>(INPUT)` (opset version < 12).

    Args:
        g (torch._C.Graph): graph to write the ONNX representation into.
        op_name (str): operator name in ONNX.
        *args (tuple): operands to the operator.
        **kwargs (dict): attributes to the operator along with "opset_before" (optional, None by default)
            indicating the smallest opset version to trigger such casting behavior and "target_float_t"
            (optional, torch.onnx.JitScalarType.FLOAT by default) indicating the data type of the internal operator.

    Returns:
        Optional[torch._C.Value, Tuple[torch._C.Value, ...]]: output(s) of the operator.
    """
|
||||
opset_before = kwargs.pop("opset_before", None)
|
||||
target_float_t = kwargs.pop("target_float_t", _type_utils.JitScalarType.FLOAT)
|
||||
|
||||
inputs = list(args)
|
||||
dtype_0 = _type_utils.JitScalarType.from_value(inputs[0])
|
||||
|
||||
require_cast = not _is_fp(inputs[0]) and (
|
||||
opset_before is None or GLOBALS.export_onnx_opset_version < opset_before
|
||||
)
|
||||
|
||||
if require_cast:
|
||||
for input in inputs:
|
||||
if input.isCompleteTensor():
|
||||
input_scalar_type = _type_utils.JitScalarType.from_value(input)
|
||||
if input_scalar_type != dtype_0:
|
||||
raise errors.SymbolicValueError(
|
||||
f"Inputs of {op_name} must have same dtype."
|
||||
f"Got {dtype_0.scalar_name()} and {input_scalar_type.scalar_name()}",
|
||||
input,
|
||||
)
|
||||
for i, input in enumerate(inputs):
|
||||
if input.isCompleteTensor() and not _is_fp(input):
|
||||
inputs[i] = g.op(
|
||||
"Cast",
|
||||
input,
|
||||
to_i=target_float_t.onnx_type(),
|
||||
)
|
||||
|
||||
self = g.op(op_name, *inputs, **kwargs)
|
||||
|
||||
if require_cast:
|
||||
self = g.op("Cast", self, to_i=dtype_0.onnx_type())
|
||||
|
||||
return self
|
||||
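A tiny eager-mode analogue (hypothetical helper name, for illustration only) of the Cast -> Op -> Cast sandwich described in the docstring above:

import torch

def clip_int_via_float(x: torch.Tensor, lo: float, hi: float) -> torch.Tensor:
    # Clamp an integer tensor by round-tripping through float, then cast back,
    # mirroring Cast<int>(Clip<float>(Cast<float>(x))).
    return torch.clamp(x.to(torch.float32), lo, hi).to(x.dtype)

print(clip_int_via_float(torch.tensor([1, 5, 9]), 2.0, 8.0))  # tensor([2, 5, 8])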
|
||||
|
||||
@_beartype.beartype
|
||||
def _maybe_cast_reduce_op_input(g: jit_utils.GraphContext, self):
|
||||
scalar_type = _type_utils.JitScalarType.from_value(
|
||||
self, _type_utils.JitScalarType.UNDEFINED
|
||||
)
|
||||
if scalar_type != _type_utils.JitScalarType.UNDEFINED:
|
||||
# This check only covers traced modules where dtype is present
|
||||
# pytorch reduce-ops cast all other integral types to int64
|
||||
if not _is_fp(self) and scalar_type != _type_utils.JitScalarType.INT64:
|
||||
self = g.op("Cast", self, to_i=_C_onnx.TensorProtoDataType.INT64)
|
||||
return self
|
||||
|
||||
|
||||
def _apply_params(*args, **kwargs):
|
||||
"""Returns a decorator that calls the decorated (higher-order) function with the given parameters."""
|
||||
|
||||
def _apply(fn):
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return _apply
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _reduce_op_symbolic_helper(onnx_op_name, allow_multi_dim_support=True):
|
||||
@_beartype.beartype
|
||||
def symbolic(g, self, dim=None, keepdim=None):
|
||||
self = _maybe_cast_reduce_op_input(g, self)
|
||||
if dim is None or dim == tuple():
|
||||
# Dim can be 0, which would make `not dim` evaluate to True,
# so we don't test `not dim` here.
|
||||
# all-reduce path
|
||||
return _handle_reduce_dim_none(g, self, onnx_op_name)
|
||||
else:
|
||||
# dim-reduce path
|
||||
keepdim = _get_const(keepdim, "i", "keepdim")
|
||||
if g.opset < 18:
|
||||
desc = "is" if allow_multi_dim_support else "i"
|
||||
dim = _get_const(dim, desc, "dim")
|
||||
dim_list = dim if allow_multi_dim_support else [dim]
|
||||
return g.op(onnx_op_name, self, axes_i=dim_list, keepdims_i=keepdim)
|
||||
else:
|
||||
if _is_value(dim):
|
||||
axes = dim
|
||||
else:
|
||||
if allow_multi_dim_support:
|
||||
axes = g.op(
|
||||
"Constant", value_t=torch.tensor(dim, dtype=torch.long)
|
||||
)
|
||||
else:
|
||||
axes = g.op(
|
||||
"Constant", value_t=torch.tensor([dim], dtype=torch.long)
|
||||
)
|
||||
return g.op(onnx_op_name, self, axes, keepdims_i=keepdim)
|
||||
|
||||
return symbolic
|
||||
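A sketch (using onnx.helper directly, not exporter code) of the two node layouts this helper chooses between: before opset 18 the reduction axes are an attribute, from opset 18 on they are a second tensor input:

from onnx import helper

reduce_pre18 = helper.make_node("ReduceMean", ["x"], ["y"], axes=[1], keepdims=1)
reduce_18plus = helper.make_node("ReduceMean", ["x", "axes"], ["y"], keepdims=1)
print(reduce_pre18)
print(reduce_18plus)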
|
||||
|
||||
@_beartype.beartype
|
||||
def _overload_by_arg_count(fn):
|
||||
@functools.wraps(fn)
|
||||
@_beartype.beartype
|
||||
def wrapper(g, *args):
|
||||
overloads = fn(g, *args)
|
||||
for overload in overloads:
|
||||
arg_descriptors = overload._arg_descriptors
|
||||
if len(arg_descriptors) == len(args):
|
||||
return overload(g, *args)
|
||||
return _unimplemented(f"aten::{fn.__name__}", f"with {len(args)} arguments")
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _reduce_with_dtype_helper(
|
||||
onnx_op: str, name: str, allow_multi_dim_support: bool = True
|
||||
):
|
||||
symbolic = _reduce_op_symbolic_helper(
|
||||
onnx_op, allow_multi_dim_support=allow_multi_dim_support
|
||||
)
|
||||
|
||||
@_overload_by_arg_count
|
||||
def reduce(g, *args, **kwargs):
|
||||
@quantized_args(True)
|
||||
@parse_args("v", "none")
|
||||
def reduce_nodim(g, self, dtype):
|
||||
dtype_onnx = None
|
||||
if dtype.node().kind() == "onnx::Constant":
|
||||
dtype = _get_const(dtype, "i", "dtype")
|
||||
dtype_onnx = _type_utils.JitScalarType(dtype).onnx_type()
|
||||
self = g.op("Cast", self, to_i=dtype_onnx)
|
||||
elif dtype.node().kind() != "prim::Constant":
|
||||
return _unimplemented(name, "dtype", dtype)
|
||||
result = symbolic(g, self)
|
||||
if dtype_onnx is not None:
|
||||
result_dtype_onnx = _type_utils.JitScalarType.from_value(
|
||||
result
|
||||
).onnx_type()
|
||||
if result_dtype_onnx != dtype_onnx:
|
||||
result = g.op("Cast", result, to_i=dtype_onnx)
|
||||
return result
|
||||
|
||||
dim_desc = "is" if allow_multi_dim_support else "i"
|
||||
|
||||
@quantized_args(True)
|
||||
@parse_args("v", dim_desc, "i", "none") # type: ignore[arg-type]
|
||||
def reduce_dim(g, self, dim, keepdim, dtype):
|
||||
dtype_onnx = None
|
||||
if dtype.node().kind() == "onnx::Constant":
|
||||
dtype = _get_const(dtype, "i", "dtype")
|
||||
dtype_onnx = _type_utils.JitScalarType(dtype).onnx_type()
|
||||
self = g.op("Cast", self, to_i=dtype_onnx)
|
||||
elif dtype.node().kind() != "prim::Constant":
|
||||
return _unimplemented(name, "dtype", dtype)
|
||||
result = symbolic(g, self, dim, keepdim)
|
||||
if dtype_onnx is not None:
|
||||
result_dtype_onnx = _type_utils.JitScalarType.from_value(
|
||||
result
|
||||
).onnx_type()
|
||||
if result_dtype_onnx != dtype_onnx:
|
||||
result = g.op("Cast", result, to_i=dtype_onnx)
|
||||
return result
|
||||
|
||||
return reduce_nodim, reduce_dim
|
||||
|
||||
return reduce
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _max_helper(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
|
||||
# torch.max(input)
|
||||
if dim_or_y is None and keepdim is None:
|
||||
return g.op("ReduceMax", self, keepdims_i=0)
|
||||
# torch.max(input, other)
|
||||
if keepdim is None:
|
||||
return _op_with_optional_float_cast(g, "Max", self, dim_or_y, opset_before=12)
|
||||
# torch.max(input, dim, keepdim)
|
||||
else:
|
||||
keepdim = _get_const(keepdim, "i", "keepdim")
|
||||
dim = _get_const(dim_or_y, "i", "dim")
|
||||
if g.opset < 18:
|
||||
max = g.op("ReduceMax", self, axes_i=[dim], keepdims_i=keepdim)
|
||||
else:
|
||||
axes = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long))
|
||||
max = g.op("ReduceMax", self, axes, keepdims_i=keepdim)
|
||||
indices = g.op("ArgMax", self, axis_i=dim, keepdims_i=keepdim)
|
||||
return max, indices
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _min_helper(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
|
||||
# torch.min(input)
|
||||
if dim_or_y is None and keepdim is None:
|
||||
return g.op("ReduceMin", self, keepdims_i=0)
|
||||
# torch.min(input, other)
|
||||
if keepdim is None:
|
||||
return _op_with_optional_float_cast(g, "Min", self, dim_or_y, opset_before=12)
|
||||
# torch.min(input, dim, keepdim)
|
||||
else:
|
||||
keepdim = _get_const(keepdim, "i", "keepdim")
|
||||
dim = _get_const(dim_or_y, "i", "dim")
|
||||
if g.opset < 18:
|
||||
min = g.op("ReduceMin", self, axes_i=[dim], keepdims_i=keepdim)
|
||||
else:
|
||||
axes = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long))
|
||||
min = g.op("ReduceMin", self, axes, keepdims_i=keepdim)
|
||||
indices = g.op("ArgMin", self, axis_i=dim, keepdims_i=keepdim)
|
||||
return min, indices
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _numel_helper(g: jit_utils.GraphContext, self):
|
||||
shape = g.op("Shape", self)
|
||||
return g.op("ReduceProd", shape, keepdims_i=0)
|
||||
|
||||
|
||||
@parse_args("v", "is", "i", "i")
|
||||
@_beartype.beartype
|
||||
def _var_mean_helper(g: jit_utils.GraphContext, input, dim, correction, keepdim):
|
||||
if g.opset < 18:
|
||||
if dim is None:
|
||||
mean = g.op("ReduceMean", input, keepdims_i=0)
|
||||
t_mean = mean
|
||||
num_elements = _numel_helper(g, input)
|
||||
else:
|
||||
mean = g.op("ReduceMean", input, axes_i=dim, keepdims_i=keepdim)
|
||||
t_mean = g.op("ReduceMean", input, axes_i=dim, keepdims_i=1)
|
||||
redudced_dims = g.op("Shape", input)
|
||||
# dim could contain one or multiple dimensions
|
||||
redudced_dims = g.op(
|
||||
"Gather",
|
||||
redudced_dims,
|
||||
g.op("Constant", value_t=torch.tensor(dim)),
|
||||
axis_i=0,
|
||||
)
|
||||
num_elements = g.op("ReduceProd", redudced_dims, keepdims_i=0)
|
||||
sub_v = g.op("Sub", input, t_mean)
|
||||
sqr_sub = g.op("Mul", sub_v, sub_v)
|
||||
keepdim_mean = 0 if dim is None else keepdim
|
||||
var = g.op("ReduceMean", sqr_sub, axes_i=dim, keepdims_i=keepdim_mean)
|
||||
# Correct the bias in the variance calculation by dividing by (N - correction) instead of N
|
||||
if correction is None:
|
||||
correction = 1
|
||||
if correction != 0:
|
||||
num_elements = g.op(
|
||||
"Cast", num_elements, to_i=_C_onnx.TensorProtoDataType.FLOAT
|
||||
)
|
||||
one = g.op("Constant", value_t=torch.tensor(correction, dtype=torch.float))
|
||||
mul = g.op("Mul", var, num_elements)
|
||||
var = g.op("Div", mul, g.op("Sub", num_elements, one))
|
||||
return var, mean
|
||||
else:
|
||||
axes = None
|
||||
if dim is None:
|
||||
mean = g.op("ReduceMean", input, keepdims_i=0)
|
||||
t_mean = mean
|
||||
num_elements = _numel_helper(g, input)
|
||||
else:
|
||||
axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long))
|
||||
mean = g.op("ReduceMean", input, axes, keepdims_i=keepdim)
|
||||
t_mean = g.op("ReduceMean", input, axes, keepdims_i=1)
|
||||
redudced_dims = g.op("Shape", input)
|
||||
# dim could contain one or multiple dimensions
|
||||
redudced_dims = g.op(
|
||||
"Gather",
|
||||
redudced_dims,
|
||||
g.op("Constant", value_t=torch.tensor(dim)),
|
||||
axis_i=0,
|
||||
)
|
||||
num_elements = g.op("ReduceProd", redudced_dims, keepdims_i=0)
|
||||
sub_v = g.op("Sub", input, t_mean)
|
||||
sqr_sub = g.op("Mul", sub_v, sub_v)
|
||||
keepdim_mean = 0 if dim is None else keepdim
|
||||
if axes is None:
|
||||
var = g.op("ReduceMean", sqr_sub, keepdims_i=keepdim_mean)
|
||||
else:
|
||||
var = g.op("ReduceMean", sqr_sub, axes, keepdims_i=keepdim_mean)
|
||||
# Correct the bias in the variance calculation by dividing by (N - correction) instead of N
|
||||
if correction is None:
|
||||
correction = 1
|
||||
if correction != 0:
|
||||
num_elements = g.op(
|
||||
"Cast", num_elements, to_i=_C_onnx.TensorProtoDataType.FLOAT
|
||||
)
|
||||
one = g.op("Constant", value_t=torch.tensor(correction, dtype=torch.float))
|
||||
mul = g.op("Mul", var, num_elements)
|
||||
var = g.op("Div", mul, g.op("Sub", num_elements, one))
|
||||
return var, mean
|
||||
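An eager-mode reference (assuming a PyTorch version where correction is a keyword argument) for the graph built above: the biased mean of squared deviations rescaled by N / (N - correction):

import torch

x = torch.randn(4, 5)
var, mean = torch.var_mean(x, dim=1, correction=1, keepdim=False)
print(var.shape, mean.shape)  # torch.Size([4]) torch.Size([4])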
|
||||
|
||||
@_beartype.beartype
|
||||
def _embedding_bag_helper(
|
||||
g: jit_utils.GraphContext,
|
||||
embedding_matrix,
|
||||
indices,
|
||||
offsets,
|
||||
scale_grad_by_freq,
|
||||
mode,
|
||||
sparse,
|
||||
per_sample_weights,
|
||||
include_last_offset,
|
||||
padding_idx,
|
||||
):
|
||||
if scale_grad_by_freq and GLOBALS.export_training:
|
||||
return _onnx_unsupported(
|
||||
"embedding_bag with scale_grad_by_freq for training mode"
|
||||
)
|
||||
if padding_idx is not None and padding_idx >= 0:
|
||||
raise RuntimeError("embedding_bag with padding_idx")
|
||||
|
||||
loop_condition = g.op("Constant", value_t=torch.tensor(1))
|
||||
loop_condition = g.op("Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL)
|
||||
zero = g.op("Constant", value_t=torch.tensor([0]))
|
||||
|
||||
indices_len = _unsqueeze_helper(
|
||||
g,
|
||||
_size_helper(g, indices, g.op("Constant", value_t=torch.tensor(0))),
|
||||
[0],
|
||||
)
|
||||
if not include_last_offset:
|
||||
offsets = [offsets, indices_len]
|
||||
offsets = g.op("Concat", *offsets, axis_i=0)
|
||||
|
||||
# Offsets holds the starting index position of each bag. So we create a list of the indices slices (determined by
|
||||
# offsets) and gather those indices in indices_row. Then we use this subset of indices to gather from embeddings.
|
||||
# The embeddings output is a loop scan output, so we can avoid creating a sequence and inserting elements in.
|
||||
offsets_starts = _slice_helper(
|
||||
g, offsets, axes=[0], starts=[0], ends=[sys.maxsize], steps=[1]
|
||||
)
|
||||
offsets_ends = _slice_helper(
|
||||
g, offsets, axes=[0], starts=[1], ends=[sys.maxsize], steps=[1]
|
||||
)
|
||||
|
||||
loop_len = _size_helper(g, offsets_ends, g.op("Constant", value_t=torch.tensor(0)))
|
||||
|
||||
loop, (loop_context,), _ = jit_utils.add_op_with_blocks(
|
||||
g, "Loop", loop_len, loop_condition, n_blocks=1
|
||||
)
|
||||
loop_block = loop_context.block
|
||||
|
||||
# FIXME(justinchuby): We need to handle what happens when we call b.op on a node return
|
||||
block_input_iter = utils._add_input_to_block(loop_block)
|
||||
cond = utils._add_input_to_block(loop_block)
|
||||
|
||||
indices_start = loop_context.op(
|
||||
"Gather", offsets_starts, block_input_iter, axis_i=0
|
||||
)
|
||||
indices_end = loop_context.op("Gather", offsets_ends, block_input_iter, axis_i=0)
|
||||
indices_start = _unsqueeze_helper(loop_context, indices_start, [0])
|
||||
indices_end = _unsqueeze_helper(loop_context, indices_end, [0])
|
||||
|
||||
indices_row = loop_context.op("Slice", indices, indices_start, indices_end, zero)
|
||||
embeddings = loop_context.op("Gather", embedding_matrix, indices_row, axis_i=0)
|
||||
if not _is_none(per_sample_weights):
|
||||
per_sample_weights_row = loop_context.op(
|
||||
"Slice", per_sample_weights, indices_start, indices_end, zero
|
||||
)
|
||||
per_sample_weights_row = _unsqueeze_helper(
|
||||
loop_context, per_sample_weights_row, [1]
|
||||
)
|
||||
embeddings = loop_context.op("Mul", embeddings, per_sample_weights_row)
|
||||
if mode == 0:
|
||||
embeddings = _reducesum_helper(
|
||||
loop_context, embeddings, axes_i=[0], keepdims_i=0
|
||||
)
|
||||
elif mode == 1:
|
||||
if loop_context.opset < 18:
|
||||
embeddings = loop_context.op(
|
||||
"ReduceMean", embeddings, axes_i=[0], keepdims_i=0
|
||||
)
|
||||
else:
|
||||
axes = loop_context.op(
|
||||
"Constant", value_t=torch.tensor([0], dtype=torch.long)
|
||||
)
|
||||
embeddings = loop_context.op("ReduceMean", embeddings, axes, keepdims_i=0)
|
||||
else:
|
||||
if loop_context.opset < 18:
|
||||
embeddings = loop_context.op(
|
||||
"ReduceMax", embeddings, axes_i=[0], keepdims_i=0
|
||||
)
|
||||
else:
|
||||
axes = loop_context.op(
|
||||
"Constant", value_t=torch.tensor([0], dtype=torch.long)
|
||||
)
|
||||
embeddings = loop_context.op("ReduceMax", embeddings, axes, keepdims_i=0)
|
||||
|
||||
cond_out = loop_context.op(
|
||||
"Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL
|
||||
)
|
||||
utils._add_output_to_block(loop_block, cond_out)
|
||||
utils._add_output_to_block(loop_block, embeddings)
|
||||
|
||||
# aten::embedding_bag returns a tuple of 4 elements: output, offset2bag, bag_size, max_indices.
|
||||
# But the last three outputs are not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag.
|
||||
return loop.node().output(), None, None, None
|
||||
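An eager-mode reference (shapes illustrative) for the loop built above: each bag gathers its slice of indices from the embedding matrix and reduces it:

import torch

weight = torch.randn(10, 3)
indices = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
offsets = torch.tensor([0, 4])  # two bags: indices[0:4] and indices[4:8]
out = torch.nn.functional.embedding_bag(indices, weight, offsets, mode="sum")
print(out.shape)  # torch.Size([2, 3])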
|
||||
|
||||
@_beartype.beartype
|
||||
def _linalg_vector_norm_helper(
|
||||
g: jit_utils.GraphContext,
|
||||
self: torch._C.Value,
|
||||
ord: float,
|
||||
dim: Optional[Sequence[int]],
|
||||
keepdim: bool,
|
||||
dtype: torch._C.Value,
|
||||
):
|
||||
axes = None
|
||||
# Conditions based on https://pytorch.org/docs/stable/generated/torch.linalg.vector_norm.html
|
||||
if _is_none(dim):
|
||||
self = _reshape_helper(g, self, [-1])
|
||||
keepdim = False
|
||||
elif g.opset >= 18:
|
||||
axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long))
|
||||
|
||||
if ord == math.inf:
|
||||
if g.opset < 18:
|
||||
result = g.op(
|
||||
"ReduceMax", g.op("Abs", self), axes_i=dim, keepdims_i=keepdim
|
||||
)
|
||||
else:
|
||||
if axes is None:
|
||||
result = g.op("ReduceMax", g.op("Abs", self), keepdims_i=keepdim)
|
||||
else:
|
||||
result = g.op("ReduceMax", g.op("Abs", self), axes, keepdims_i=keepdim)
|
||||
elif ord == -math.inf:
|
||||
if g.opset < 18:
|
||||
result = g.op(
|
||||
"ReduceMin", g.op("Abs", self), axes_i=dim, keepdims_i=keepdim
|
||||
)
|
||||
else:
|
||||
if axes is None:
|
||||
result = g.op("ReduceMin", g.op("Abs", self), keepdims_i=keepdim)
|
||||
else:
|
||||
result = g.op("ReduceMin", g.op("Abs", self), axes, keepdims_i=keepdim)
|
||||
elif ord == 0:
|
||||
if g.opset < 11:
|
||||
return _onnx_opset_unsupported_detailed(
|
||||
"linalg_vector_norm", 9, 11, "ord=0 not supported", self
|
||||
)
|
||||
else:
|
||||
if dim is None:
|
||||
self = _reshape_helper(
|
||||
g,
|
||||
self,
|
||||
g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)),
|
||||
)
|
||||
keepdim = False
|
||||
|
||||
cond_op = g.op(
|
||||
"Not",
|
||||
g.op("Equal", self, g.op("Constant", value_t=torch.LongTensor([0]))),
|
||||
)
|
||||
cond_op = g.op(
|
||||
"Cast",
|
||||
cond_op,
|
||||
to_i=_type_utils.JitScalarType.from_value(self).onnx_type(),
|
||||
)
|
||||
return _reducesum_helper(g, cond_op, axes_i=dim, keepdims_i=keepdim)
|
||||
elif ord == 1:
|
||||
if g.opset < 18:
|
||||
result = _reduce_op_symbolic_helper("ReduceL1")(
|
||||
g, self, dim=dim, keepdim=keepdim
|
||||
)
|
||||
else:
|
||||
if axes is None:
|
||||
result = _reduce_op_symbolic_helper("ReduceL1")(
|
||||
g, self, keepdim=keepdim
|
||||
)
|
||||
else:
|
||||
result = _reduce_op_symbolic_helper("ReduceL1")(
|
||||
g, self, axes, keepdim=keepdim
|
||||
)
|
||||
elif ord == 2:
|
||||
if g.opset < 18:
|
||||
result = _reduce_op_symbolic_helper("ReduceL2")(
|
||||
g, self, dim=dim, keepdim=keepdim
|
||||
)
|
||||
else:
|
||||
if axes is None:
|
||||
result = _reduce_op_symbolic_helper("ReduceL2")(
|
||||
g, self, keepdim=keepdim
|
||||
)
|
||||
else:
|
||||
result = _reduce_op_symbolic_helper("ReduceL2")(
|
||||
g, self, axes, keepdim=keepdim
|
||||
)
|
||||
else:
|
||||
ord_op = g.op("Constant", value_t=torch.tensor(ord, dtype=torch.float32))
|
||||
result = _reducesum_helper(
|
||||
g, g.op("Pow", g.op("Abs", self), ord_op), axes_i=dim, keepdims_i=keepdim
|
||||
)
|
||||
result = g.op(
|
||||
"Pow",
|
||||
result,
|
||||
g.op(
|
||||
"Div",
|
||||
g.op("Constant", value_t=torch.tensor(1, dtype=torch.float32)),
|
||||
ord_op,
|
||||
),
|
||||
)
|
||||
|
||||
if not _is_none(dtype):
|
||||
dtype = _get_const(dtype, "i", "dtype")
|
||||
result = g.op("Cast", result, to_i=_type_utils.JitScalarType(dtype).onnx_type()) # type: ignore[arg-type]
|
||||
return result
|
||||
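Eager-mode reference points (illustrative values) for the branches above: ord=inf reduces with ReduceMax over |x|, ord=2 with ReduceL2, and a generic p computes (sum |x|^p) ** (1/p):

import torch

x = torch.tensor([3.0, -4.0])
print(torch.linalg.vector_norm(x, ord=float("inf")))  # tensor(4.)
print(torch.linalg.vector_norm(x, ord=2))             # tensor(5.)
print(torch.linalg.vector_norm(x, ord=3))             # (3**3 + 4**3) ** (1/3), about 4.498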
|
||||
|
||||
# Deprecated. Internally use _type_utils.ScalarType
|
||||
# TODO: remove these once we support Type's in the JIT IR and we can once again
|
||||
# use the unified toType operator
|
||||
|
@ -70,15 +70,6 @@ __all__ = [
|
||||
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=10)
|
||||
|
||||
|
||||
def _apply_params(*args, **kwargs):
|
||||
"""Returns a decorator that calls the decorated (higher-order) function with the given parameters."""
|
||||
|
||||
def _apply(fn):
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return _apply
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::div")
|
||||
@_beartype.beartype
|
||||
def div(g: jit_utils.GraphContext, self, other, *args):
|
||||
@ -276,20 +267,20 @@ def _aten_max_pool_with_indices_onnx(
|
||||
|
||||
@_onnx_symbolic(
|
||||
"aten::max_pool1d",
|
||||
decorate=[_apply_params("max_pool1d", 1, return_indices=False)],
|
||||
decorate=[symbolic_helper._apply_params("max_pool1d", 1, return_indices=False)],
|
||||
)
|
||||
@_onnx_symbolic(
|
||||
"aten::max_pool2d",
|
||||
decorate=[_apply_params("max_pool2d", 2, return_indices=False)],
|
||||
decorate=[symbolic_helper._apply_params("max_pool2d", 2, return_indices=False)],
|
||||
)
|
||||
@_onnx_symbolic(
|
||||
"aten::max_pool3d",
|
||||
decorate=[_apply_params("max_pool3d", 3, return_indices=False)],
|
||||
decorate=[symbolic_helper._apply_params("max_pool3d", 3, return_indices=False)],
|
||||
)
|
||||
@_onnx_symbolic(
|
||||
"aten::max_pool1d_with_indices",
|
||||
decorate=[
|
||||
_apply_params(
|
||||
symbolic_helper._apply_params(
|
||||
"max_pool1d_with_indices",
|
||||
1,
|
||||
return_indices=True,
|
||||
@ -299,7 +290,7 @@ def _aten_max_pool_with_indices_onnx(
|
||||
@_onnx_symbolic(
|
||||
"aten::max_pool2d_with_indices",
|
||||
decorate=[
|
||||
_apply_params(
|
||||
symbolic_helper._apply_params(
|
||||
"max_pool2d_with_indices",
|
||||
2,
|
||||
return_indices=True,
|
||||
@ -309,7 +300,7 @@ def _aten_max_pool_with_indices_onnx(
|
||||
@_onnx_symbolic(
|
||||
"aten::max_pool3d_with_indices",
|
||||
decorate=[
|
||||
_apply_params(
|
||||
symbolic_helper._apply_params(
|
||||
"max_pool3d_with_indices",
|
||||
3,
|
||||
return_indices=True,
|
||||
@ -397,15 +388,15 @@ def _adjust_attributes_of_avg_pool(
|
||||
|
||||
@_onnx_symbolic(
|
||||
"aten::avg_pool1d",
|
||||
decorate=[_apply_params("avg_pool1d", 1)],
|
||||
decorate=[symbolic_helper._apply_params("avg_pool1d", 1)],
|
||||
)
|
||||
@_onnx_symbolic(
|
||||
"aten::avg_pool2d",
|
||||
decorate=[_apply_params("avg_pool2d", 2)],
|
||||
decorate=[symbolic_helper._apply_params("avg_pool2d", 2)],
|
||||
)
|
||||
@_onnx_symbolic(
|
||||
"aten::avg_pool3d",
|
||||
decorate=[_apply_params("avg_pool3d", 3)],
|
||||
decorate=[symbolic_helper._apply_params("avg_pool3d", 3)],
|
||||
)
|
||||
@_beartype.beartype
|
||||
def _avg_pool(name, expand_size):
|
||||
@ -443,27 +434,27 @@ def _avg_pool(name, expand_size):
|
||||
|
||||
@_onnx_symbolic(
|
||||
"aten::upsample_nearest1d",
|
||||
decorate=[_apply_params("upsample_nearest1d", 3, "nearest")],
|
||||
decorate=[symbolic_helper._apply_params("upsample_nearest1d", 3, "nearest")],
|
||||
)
|
||||
@_onnx_symbolic(
|
||||
"aten::upsample_nearest2d",
|
||||
decorate=[_apply_params("upsample_nearest2d", 4, "nearest")],
|
||||
decorate=[symbolic_helper._apply_params("upsample_nearest2d", 4, "nearest")],
|
||||
)
|
||||
@_onnx_symbolic(
|
||||
"aten::upsample_nearest3d",
|
||||
decorate=[_apply_params("upsample_nearest3d", 5, "nearest")],
|
||||
decorate=[symbolic_helper._apply_params("upsample_nearest3d", 5, "nearest")],
|
||||
)
|
||||
@_onnx_symbolic(
|
||||
"aten::upsample_linear1d",
|
||||
decorate=[_apply_params("upsample_linear1d", 3, "linear")],
|
||||
decorate=[symbolic_helper._apply_params("upsample_linear1d", 3, "linear")],
|
||||
)
|
||||
@_onnx_symbolic(
|
||||
"aten::upsample_bilinear2d",
|
||||
decorate=[_apply_params("upsample_bilinear2d", 4, "linear")],
|
||||
decorate=[symbolic_helper._apply_params("upsample_bilinear2d", 4, "linear")],
|
||||
)
|
||||
@_onnx_symbolic(
|
||||
"aten::upsample_trilinear3d",
|
||||
decorate=[_apply_params("upsample_trilinear3d", 5, "linear")],
|
||||
decorate=[symbolic_helper._apply_params("upsample_trilinear3d", 5, "linear")],
|
||||
)
|
||||
@_beartype.beartype
|
||||
def _interpolate(name, dim, interpolate_mode):
|
||||
|
@ -17,7 +17,6 @@ from torch.onnx import (
|
||||
symbolic_opset9 as opset9,
|
||||
utils,
|
||||
)
|
||||
from torch.onnx._globals import GLOBALS
|
||||
from torch.onnx._internal import _beartype, jit_utils, registration
|
||||
|
||||
# EDITING THIS FILE? READ THIS FIRST!
|
||||
@ -86,15 +85,6 @@ __all__ = [
|
||||
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=11)
|
||||
|
||||
|
||||
def _apply_params(*args, **kwargs):
|
||||
"""Returns a decorator that calls the decorated (higher-order) function with the given parameters."""
|
||||
|
||||
def _apply(fn):
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return _apply
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::hardtanh")
|
||||
@symbolic_helper.quantized_args(True)
|
||||
@symbolic_helper.parse_args("v", "f", "f")
|
||||
@ -111,7 +101,7 @@ def hardtanh(g: jit_utils.GraphContext, self: _C.Value, min_val: float, max_val:
|
||||
"Constant",
|
||||
value_t=torch.tensor(max_val, dtype=scalar_type.dtype()),
|
||||
)
|
||||
return opset9._op_with_optional_float_cast(
|
||||
return symbolic_helper._op_with_optional_float_cast(
|
||||
g, "Clip", self, min_val, max_val, opset_before=12
|
||||
)
|
||||
|
||||
@ -146,7 +136,7 @@ def clamp(g: jit_utils.GraphContext, self, min, max):
|
||||
symbolic_helper._get_tensor_rank(min) == 0
|
||||
and symbolic_helper._get_tensor_rank(max) == 0
|
||||
):
|
||||
return opset9._op_with_optional_float_cast(
|
||||
return symbolic_helper._op_with_optional_float_cast(
|
||||
g, "Clip", self, min, max, opset_before=12
|
||||
)
|
||||
else:
|
||||
@ -160,11 +150,13 @@ def clamp_min(g: jit_utils.GraphContext, self, min):
|
||||
min = g.op("Cast", min, to_i=_type_utils.JitScalarType.from_value(self).onnx_type())
|
||||
if symbolic_helper._get_tensor_rank(min) == 0:
|
||||
max = opset9.unused(g)
|
||||
return opset9._op_with_optional_float_cast(
|
||||
return symbolic_helper._op_with_optional_float_cast(
|
||||
g, "Clip", self, min, max, opset_before=12
|
||||
)
|
||||
else:
|
||||
return opset9._op_with_optional_float_cast(g, "Max", self, min, opset_before=12)
|
||||
return symbolic_helper._op_with_optional_float_cast(
|
||||
g, "Max", self, min, opset_before=12
|
||||
)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::clamp_max")
|
||||
@ -174,11 +166,13 @@ def clamp_max(g: jit_utils.GraphContext, self, max):
|
||||
max = g.op("Cast", max, to_i=_type_utils.JitScalarType.from_value(self).onnx_type())
|
||||
if symbolic_helper._get_tensor_rank(max) == 0:
|
||||
min = opset9.unused(g)
|
||||
return opset9._op_with_optional_float_cast(
|
||||
return symbolic_helper._op_with_optional_float_cast(
|
||||
g, "Clip", self, min, max, opset_before=12
|
||||
)
|
||||
else:
|
||||
return opset9._op_with_optional_float_cast(g, "Min", self, max, opset_before=12)
|
||||
return symbolic_helper._op_with_optional_float_cast(
|
||||
g, "Min", self, max, opset_before=12
|
||||
)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::relu6")
|
||||
@ -348,31 +342,31 @@ def pixel_shuffle(g: jit_utils.GraphContext, self, upscale_factor):
|
||||
|
||||
@_onnx_symbolic(
|
||||
"aten::upsample_nearest1d",
|
||||
decorate=[_apply_params("upsample_nearest1d", 3, "nearest")],
|
||||
decorate=[symbolic_helper._apply_params("upsample_nearest1d", 3, "nearest")],
|
||||
)
|
||||
@_onnx_symbolic(
|
||||
"aten::upsample_nearest2d",
|
||||
decorate=[_apply_params("upsample_nearest2d", 4, "nearest")],
|
||||
decorate=[symbolic_helper._apply_params("upsample_nearest2d", 4, "nearest")],
|
||||
)
|
||||
@_onnx_symbolic(
|
||||
"aten::upsample_nearest3d",
|
||||
decorate=[_apply_params("upsample_nearest3d", 5, "nearest")],
|
||||
decorate=[symbolic_helper._apply_params("upsample_nearest3d", 5, "nearest")],
|
||||
)
|
||||
@_onnx_symbolic(
|
||||
"aten::upsample_linear1d",
|
||||
decorate=[_apply_params("upsample_linear1d", 3, "linear")],
|
||||
decorate=[symbolic_helper._apply_params("upsample_linear1d", 3, "linear")],
|
||||
)
|
||||
@_onnx_symbolic(
|
||||
"aten::upsample_bilinear2d",
|
||||
decorate=[_apply_params("upsample_bilinear2d", 4, "linear")],
|
||||
decorate=[symbolic_helper._apply_params("upsample_bilinear2d", 4, "linear")],
|
||||
)
|
||||
@_onnx_symbolic(
|
||||
"aten::upsample_trilinear3d",
|
||||
decorate=[_apply_params("upsample_trilinear3d", 5, "linear")],
|
||||
decorate=[symbolic_helper._apply_params("upsample_trilinear3d", 5, "linear")],
|
||||
)
|
||||
@_onnx_symbolic(
|
||||
"aten::upsample_bicubic2d",
|
||||
decorate=[_apply_params("upsample_bicubic2d", 4, "cubic")],
|
||||
decorate=[symbolic_helper._apply_params("upsample_bicubic2d", 4, "cubic")],
|
||||
)
|
||||
@_beartype.beartype
|
||||
def _interpolate(name: str, dim: int, interpolate_mode: str):
|
||||
@ -1281,26 +1275,7 @@ def linalg_vector_norm(
|
||||
keepdim: bool,
|
||||
dtype,
|
||||
):
|
||||
if ord == 0:
|
||||
if dim is None:
|
||||
self = symbolic_helper._reshape_helper(
|
||||
g, self, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))
|
||||
)
|
||||
keepdim = False
|
||||
|
||||
cond_op = g.op(
|
||||
"Not", g.op("Equal", self, g.op("Constant", value_t=torch.LongTensor([0])))
|
||||
)
|
||||
cond_op = g.op(
|
||||
"Cast",
|
||||
cond_op,
|
||||
to_i=_type_utils.JitScalarType.from_value(self).onnx_type(),
|
||||
)
|
||||
return symbolic_helper._reducesum_helper(
|
||||
g, cond_op, axes_i=dim, keepdims_i=keepdim
|
||||
)
|
||||
else:
|
||||
return opset9.linalg_vector_norm(g, self, ord, dim, keepdim, dtype)
|
||||
return symbolic_helper._linalg_vector_norm_helper(g, self, ord, dim, keepdim, dtype)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::embedding_bag")
|
||||
@ -1318,86 +1293,18 @@ def embedding_bag(
|
||||
include_last_offset,
|
||||
padding_idx,
|
||||
):
|
||||
if scale_grad_by_freq and GLOBALS.export_training:
|
||||
return symbolic_helper._onnx_unsupported(
|
||||
"embedding_bag with scale_grad_by_freq for training mode"
|
||||
)
|
||||
if padding_idx is not None and padding_idx >= 0:
|
||||
raise RuntimeError("embedding_bag with padding_idx")
|
||||
|
||||
loop_condition = g.op("Constant", value_t=torch.tensor(1))
|
||||
loop_condition = g.op("Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL)
|
||||
zero = g.op("Constant", value_t=torch.tensor([0]))
|
||||
|
||||
indices_len = symbolic_helper._unsqueeze_helper(
|
||||
return symbolic_helper._embedding_bag_helper(
|
||||
g,
|
||||
symbolic_helper._size_helper(
|
||||
g, indices, g.op("Constant", value_t=torch.tensor(0))
|
||||
),
|
||||
[0],
|
||||
embedding_matrix,
|
||||
indices,
|
||||
offsets,
|
||||
scale_grad_by_freq,
|
||||
mode,
|
||||
sparse,
|
||||
per_sample_weights,
|
||||
include_last_offset,
|
||||
padding_idx,
|
||||
)
|
||||
if not include_last_offset:
|
||||
offsets = [offsets, indices_len]
|
||||
offsets = g.op("Concat", *offsets, axis_i=0)
|
||||
|
||||
# Offsets holds the starting index position of each bag. So we create a list of the indices slices (determined by
|
||||
# offsets) and gather those indices in indices_row. Then we use this subset of indices to gather from embeddings.
|
||||
# The embeddings output is a loop scan output, so we can avoid creating a sequence and inserting elements in.
|
||||
offsets_starts = symbolic_helper._slice_helper(
|
||||
g, offsets, axes=[0], starts=[0], ends=[sys.maxsize], steps=[1]
|
||||
)
|
||||
offsets_ends = symbolic_helper._slice_helper(
|
||||
g, offsets, axes=[0], starts=[1], ends=[sys.maxsize], steps=[1]
|
||||
)
|
||||
|
||||
loop_len = symbolic_helper._size_helper(
|
||||
g, offsets_ends, g.op("Constant", value_t=torch.tensor(0))
|
||||
)
|
||||
|
||||
loop, (loop_context,), _ = jit_utils.add_op_with_blocks(
|
||||
g, "Loop", loop_len, loop_condition, n_blocks=1
|
||||
)
|
||||
loop_block = loop_context.block
|
||||
|
||||
# FIXME(justinchuby): We need to handle what happens when we call b.op on a node return
|
||||
block_input_iter = utils._add_input_to_block(loop_block)
|
||||
cond = utils._add_input_to_block(loop_block)
|
||||
|
||||
indices_start = loop_context.op(
|
||||
"Gather", offsets_starts, block_input_iter, axis_i=0
|
||||
)
|
||||
indices_end = loop_context.op("Gather", offsets_ends, block_input_iter, axis_i=0)
|
||||
indices_start = symbolic_helper._unsqueeze_helper(loop_context, indices_start, [0])
|
||||
indices_end = symbolic_helper._unsqueeze_helper(loop_context, indices_end, [0])
|
||||
|
||||
indices_row = loop_context.op("Slice", indices, indices_start, indices_end, zero)
|
||||
embeddings = loop_context.op("Gather", embedding_matrix, indices_row, axis_i=0)
|
||||
if not symbolic_helper._is_none(per_sample_weights):
|
||||
per_sample_weights_row = loop_context.op(
|
||||
"Slice", per_sample_weights, indices_start, indices_end, zero
|
||||
)
|
||||
per_sample_weights_row = symbolic_helper._unsqueeze_helper(
|
||||
loop_context, per_sample_weights_row, [1]
|
||||
)
|
||||
embeddings = loop_context.op("Mul", embeddings, per_sample_weights_row)
|
||||
if mode == 0:
|
||||
embeddings = symbolic_helper._reducesum_helper(
|
||||
loop_context, embeddings, axes_i=[0], keepdims_i=0
|
||||
)
|
||||
elif mode == 1:
|
||||
embeddings = loop_context.op("ReduceMean", embeddings, axes_i=[0], keepdims_i=0)
|
||||
else:
|
||||
embeddings = loop_context.op("ReduceMax", embeddings, axes_i=[0], keepdims_i=0)
|
||||
|
||||
cond_out = loop_context.op(
|
||||
"Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL
|
||||
)
|
||||
utils._add_output_to_block(loop_block, cond_out)
|
||||
utils._add_output_to_block(loop_block, embeddings)
|
||||
|
||||
# aten::embedding_bag returns a tuple of 4 elements: output, offset2bag, bag_size, max_indices.
|
||||
# But the last three outputs are not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag.
|
||||
return loop.node().output(), None, None, None
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::embedding_renorm")
|
||||
|
@ -21,15 +21,6 @@ from torch.onnx._internal import _beartype, jit_utils, registration
|
||||
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=13)
|
||||
|
||||
|
||||
def _apply_params(*args, **kwargs):
|
||||
"""Returns a decorator that calls the decorated (higher-order) function with the given parameters."""
|
||||
|
||||
def _apply(fn):
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return _apply
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::softmax")
|
||||
@symbolic_helper.parse_args("v", "i", "none")
|
||||
@_beartype.beartype
|
||||
@ -412,7 +403,7 @@ def fake_quantize_per_tensor_affine(
|
||||
def _reduce_op_symbolic(onnx_op_name):
|
||||
@_beartype.beartype
|
||||
def symbolic(g, self, dim=None, keepdim=None):
|
||||
self = opset9._maybe_cast_reduce_op_input(g, self)
|
||||
self = symbolic_helper._maybe_cast_reduce_op_input(g, self)
|
||||
if dim is None:
|
||||
# all-reduce path
|
||||
return symbolic_helper._handle_reduce_dim_none(g, self, onnx_op_name)
|
||||
@ -425,13 +416,13 @@ def _reduce_op_symbolic(onnx_op_name):
|
||||
|
||||
@_onnx_symbolic(
|
||||
"aten::sum",
|
||||
decorate=[_apply_params("ReduceSum", "sum")],
|
||||
decorate=[symbolic_helper._apply_params("ReduceSum", "sum")],
|
||||
)
|
||||
@_beartype.beartype
|
||||
def _reduce_with_dtype(onnx_op, name):
|
||||
symbolic = _reduce_op_symbolic(onnx_op)
|
||||
|
||||
@opset9.overload_by_arg_count
|
||||
@symbolic_helper._overload_by_arg_count
|
||||
@_beartype.beartype
|
||||
def reduce(g, *args, **kwargs):
|
||||
@symbolic_helper.parse_args("v", "none")
|
||||
|
@ -14,19 +14,23 @@ New operators:
|
||||
Resize
|
||||
ScatterElements
|
||||
ScatterND
|
||||
Split
|
||||
"""
|
||||
|
||||
import functools
|
||||
from typing import Sequence
|
||||
from typing import List, Optional, Sequence, Tuple
|
||||
|
||||
import torch
|
||||
from torch import _C
|
||||
from torch.onnx import symbolic_helper
|
||||
from torch.onnx._internal import _beartype, registration
|
||||
from torch.onnx import symbolic_helper, symbolic_opset9 as opset9
|
||||
from torch.onnx._internal import _beartype, jit_utils, registration
|
||||
|
||||
# EDITING THIS FILE? READ THIS FIRST!
|
||||
# see Note [Edit Symbolic Files] in symbolic_helper.py
|
||||
|
||||
__all__ = ["col2im"]
|
||||
__all__ = [
|
||||
"col2im",
|
||||
]
|
||||
|
||||
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=18)
|
||||
|
||||
@ -68,3 +72,190 @@ def col2im(
|
||||
pads_i=adjusted_padding,
|
||||
strides_i=stride,
|
||||
)
|
||||
|
||||
|
||||
@_onnx_symbolic(
|
||||
"aten::mean", decorate=[symbolic_helper._apply_params("ReduceMean", "mean")]
|
||||
)
|
||||
@_onnx_symbolic(
|
||||
"aten::prod",
|
||||
decorate=[
|
||||
symbolic_helper._apply_params(
|
||||
"ReduceProd", "prod", allow_multi_dim_support=False
|
||||
)
|
||||
],
|
||||
)
|
||||
@_beartype.beartype
|
||||
def _reduce_with_dtype(onnx_op: str, name: str, allow_multi_dim_support: bool = True):
|
||||
return symbolic_helper._reduce_with_dtype_helper(
|
||||
onnx_op, name, allow_multi_dim_support
|
||||
)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::native_layer_norm")
|
||||
@symbolic_helper.quantized_args(True, False, False, False)
|
||||
@symbolic_helper.parse_args("v", "is", "v", "v", "f")
|
||||
@_beartype.beartype
|
||||
def _native_layer_norm(
|
||||
g: jit_utils.GraphContext,
|
||||
input: _C.Value,
|
||||
normalized_shape: Sequence[int],
|
||||
weight: _C.Value,
|
||||
bias: _C.Value,
|
||||
eps: float,
|
||||
) -> Tuple[_C.Value, _C.Value, _C.Value]:
|
||||
return opset9.native_layer_norm(g, input, normalized_shape, weight, bias, eps)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::glu")
|
||||
@symbolic_helper.parse_args("v", "i")
|
||||
@_beartype.beartype
|
||||
def _glu(g: jit_utils.GraphContext, input, dim):
|
||||
dim_size = symbolic_helper._get_tensor_dim_size(input, dim)
|
||||
if dim_size is not None:
|
||||
assert dim_size % 2 == 0
|
||||
|
||||
first, second = g.op("Split", input, axis_i=dim, num_outputs_i=2, outputs=2)
|
||||
return g.op("Mul", first, g.op("Sigmoid", second))
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::max")
|
||||
# torch.max (same for torch.min) actually has two interfaces smashed together:
|
||||
# torch.max(x, dim, keepdim) and torch.max(x, y)
|
||||
# TODO(justinchuby): Support multiple quantized args in output
|
||||
@_beartype.beartype
|
||||
def max(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
|
||||
return symbolic_helper._max_helper(g, self, dim_or_y, keepdim)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::maximum")
|
||||
@symbolic_helper.quantized_args(True, True)
|
||||
@_beartype.beartype
|
||||
def maximum(g: jit_utils.GraphContext, input, other):
|
||||
return max(g, input, dim_or_y=other)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::min")
|
||||
# TODO(justinchuby): Support multiple quantized args in output
|
||||
@_beartype.beartype
|
||||
def min(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
|
||||
return symbolic_helper._min_helper(g, self, dim_or_y, keepdim)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::minimum")
|
||||
@symbolic_helper.quantized_args(True, True)
|
||||
@_beartype.beartype
|
||||
def minimum(g: jit_utils.GraphContext, input, other):
|
||||
return min(g, input, dim_or_y=other)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::amax")
|
||||
@symbolic_helper.quantized_args(True)
|
||||
@symbolic_helper.parse_args("v", "is", "i")
|
||||
@_beartype.beartype
|
||||
def amax(g: jit_utils.GraphContext, self, dim, keepdim):
|
||||
axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long))
|
||||
return g.op("ReduceMax", self, axes, keepdims_i=keepdim)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::amin")
|
||||
@symbolic_helper.quantized_args(True)
|
||||
@symbolic_helper.parse_args("v", "is", "i")
|
||||
@_beartype.beartype
|
||||
def amin(g: jit_utils.GraphContext, self, dim, keepdim):
|
||||
axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long))
|
||||
return g.op("ReduceMin", self, axes, keepdims_i=keepdim)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::aminmax")
|
||||
@symbolic_helper.quantized_args(True)
|
||||
@symbolic_helper.parse_args("v", "v", "i")
|
||||
@_beartype.beartype
|
||||
def aminmax(g: jit_utils.GraphContext, self, dim, keepdim):
|
||||
if not symbolic_helper._is_none(dim):
|
||||
dim = symbolic_helper._get_const(dim, "i", "dim")
|
||||
axes = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long))
|
||||
return g.op("ReduceMin", self, axes, keepdims_i=keepdim), g.op(
|
||||
"ReduceMax", self, axes, keepdims_i=keepdim
|
||||
)
|
||||
else:
|
||||
return g.op("ReduceMin", self, keepdims_i=keepdim), g.op(
|
||||
"ReduceMax", self, keepdims_i=keepdim
|
||||
)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::var_mean")
|
||||
@_beartype.beartype
|
||||
def _var_mean(g: jit_utils.GraphContext, input, *args):
|
||||
if len(args) == 1:
|
||||
return symbolic_helper._var_mean_helper(g, input, None, args[0], None)
|
||||
else:
|
||||
return symbolic_helper._var_mean_helper(g, input, *args)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::logsumexp")
|
||||
@symbolic_helper.parse_args("v", "is", "i")
|
||||
@_beartype.beartype
|
||||
def _logsumexp(g: jit_utils.GraphContext, input, dim, keepdim):
|
||||
if dim is None:
|
||||
return g.op("ReduceLogSumExp", input, keepdims_i=0)
|
||||
else:
|
||||
axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long))
|
||||
return g.op("ReduceLogSumExp", input, axes, keepdims_i=keepdim)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::linalg_matrix_norm")
|
||||
@symbolic_helper.parse_args("v", "v", "is", "b", "v")
|
||||
@_beartype.beartype
|
||||
def _linalg_matrix_norm(
|
||||
g: jit_utils.GraphContext,
|
||||
self: torch._C.Value,
|
||||
ord: torch._C.Value,
|
||||
dim: List[int],
|
||||
keepdim: bool,
|
||||
dtype: torch._C.Value,
|
||||
):
|
||||
return opset9.linalg_matrix_norm(g, self, ord, dim, keepdim, dtype)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::embedding_bag")
|
||||
@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i")
|
||||
@_beartype.beartype
|
||||
def embedding_bag(
|
||||
g: jit_utils.GraphContext,
|
||||
embedding_matrix,
|
||||
indices,
|
||||
offsets,
|
||||
scale_grad_by_freq,
|
||||
mode,
|
||||
sparse,
|
||||
per_sample_weights,
|
||||
include_last_offset,
|
||||
padding_idx,
|
||||
):
|
||||
return symbolic_helper._embedding_bag_helper(
|
||||
g,
|
||||
embedding_matrix,
|
||||
indices,
|
||||
offsets,
|
||||
scale_grad_by_freq,
|
||||
mode,
|
||||
sparse,
|
||||
per_sample_weights,
|
||||
include_last_offset,
|
||||
padding_idx,
|
||||
)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::linalg_vector_norm")
|
||||
@symbolic_helper.parse_args("v", "f", "is", "b", "v")
|
||||
@_beartype.beartype
|
||||
def linalg_vector_norm(
|
||||
g: jit_utils.GraphContext,
|
||||
self: torch._C.Value,
|
||||
ord: float,
|
||||
dim: Optional[Sequence[int]],
|
||||
keepdim: bool,
|
||||
dtype: torch._C.Value,
|
||||
):
|
||||
return symbolic_helper._linalg_vector_norm_helper(g, self, ord, dim, keepdim, dtype)
|
||||
|
32 torch/onnx/symbolic_opset19.py Normal file
@ -0,0 +1,32 @@
"""This file exports ONNX ops for opset 19.

Note [ONNX Operators that are added/updated in opset 19]

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-19-of-the-default-onnx-operator-set
New operators:
    AveragePool
    Cast
    CastLike
    Constant
    DeformConv
    DequantizeLinear
    Equal
    Identity
    If
    Loop
    Pad
    QuantizeLinear
    Reshape
    Resize
    Scan
    Shape
    Size
"""

from typing import List

# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py

__all__: List[str] = []
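Opset 19 only re-versions existing operators (Cast, Pad, Resize, QuantizeLinear, and the rest of the list above), so this module registers no new symbolic functions and __all__ stays empty. A minimal export sketch that simply targets the newly accepted opset (toy model, arbitrary file name):

import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
dummy = torch.randn(1, 4)
# With symbolic_opset19.py present, opset_version=19 is a valid target; ops that
# did not change in 19 fall back to the opset 18 symbolics.
torch.onnx.export(model, dummy, "model_opset19.onnx", opset_version=19)
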
85 torch/onnx/symbolic_opset20.py Normal file
@ -0,0 +1,85 @@
"""This file exports ONNX ops for opset 20.

Note [ONNX Operators that are added/updated in opset 20]

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-20-of-the-default-onnx-operator-set
New operators:
    AffineGrid
    ConstantOfShape
    DFT
    Gelu
    GridSample
    ImageDecoder
    IsInf
    IsNaN
    ReduceMax
    ReduceMin
    RegexFullMatch
    StringConcat
    StringSplit
"""

import functools

import torch.nn.functional as F

from torch import _C
from torch.onnx import symbolic_helper
from torch.onnx._internal import _beartype, jit_utils, registration

# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py

__all__ = ["_grid_sampler", "_affine_grid_generator"]


def convert_grid_sample_mode(mode_s):
    return (
        "linear" if mode_s == "bilinear" else "cubic" if mode_s == "bicubic" else mode_s
    )


_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=20)


@_onnx_symbolic("aten::grid_sampler")
@symbolic_helper.parse_args("v", "v", "i", "i", "b")
@_beartype.beartype
def _grid_sampler(
    g: jit_utils.GraphContext,
    input: _C.Value,
    grid: _C.Value,
    mode_enum: int,
    padding_mode_enum: int,
    align_corners: bool,
):
    mode_s = {v: k for k, v in F.GRID_SAMPLE_INTERPOLATION_MODES.items()}[mode_enum]  # type: ignore[call-arg, index]
    # mode string changes at https://onnx.ai/onnx/operators/text_diff_GridSample_16_20.html
    mode_s = convert_grid_sample_mode(mode_s)
    padding_mode_s = {v: k for k, v in F.GRID_SAMPLE_PADDING_MODES.items()}[padding_mode_enum]  # type: ignore[call-arg, index]
    return g.op(
        "GridSample",
        input,
        grid,
        align_corners_i=int(align_corners),
        mode_s=mode_s,
        padding_mode_s=padding_mode_s,
    )


@_onnx_symbolic("aten::affine_grid_generator")
@symbolic_helper.parse_args("v", "v", "b")
@_beartype.beartype
def _affine_grid_generator(
    g: jit_utils.GraphContext,
    theta: _C.Value,
    size: _C.Value,
    align_corners: bool,
):
    return g.op(
        "AffineGrid",
        theta,
        size,
        align_corners_i=int(align_corners),
    )
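A usage sketch for the two new opset-20 symbolics above: affine_grid lowers to the new AffineGrid operator, and grid_sample maps its PyTorch mode strings through convert_grid_sample_mode ("bilinear" -> "linear", "bicubic" -> "cubic") before emitting GridSample. The module and file name below are placeholders:

import torch
import torch.nn.functional as F

class WarpModel(torch.nn.Module):
    def forward(self, x, theta):
        # aten::affine_grid_generator -> AffineGrid (new in opset 20)
        grid = F.affine_grid(theta, size=list(x.shape), align_corners=False)
        # aten::grid_sampler -> GridSample; "bilinear" is exported as mode_s="linear"
        return F.grid_sample(x, grid, mode="bilinear", padding_mode="zeros", align_corners=False)

x = torch.randn(1, 1, 8, 8)
theta = torch.eye(2, 3).unsqueeze(0)  # identity affine transform
torch.onnx.export(WarpModel(), (x, theta), "warp_opset20.onnx", opset_version=20)
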
@ -64,38 +64,29 @@ for block_listed_op in block_listed_operators:
|
||||
)
|
||||
|
||||
|
||||
def _apply_params(*args, **kwargs):
|
||||
"""Returns a decorator that calls the decorated (higher-order) function with the given parameters."""
|
||||
|
||||
def _apply(fn):
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return _apply
|
||||
|
||||
|
||||
@_onnx_symbolic(
|
||||
"aten::upsample_nearest1d",
|
||||
decorate=[_apply_params("upsample_nearest1d", 3, "nearest")],
|
||||
decorate=[symbolic_helper._apply_params("upsample_nearest1d", 3, "nearest")],
|
||||
)
|
||||
@_onnx_symbolic(
|
||||
"aten::upsample_nearest2d",
|
||||
decorate=[_apply_params("upsample_nearest2d", 4, "nearest")],
|
||||
decorate=[symbolic_helper._apply_params("upsample_nearest2d", 4, "nearest")],
|
||||
)
|
||||
@_onnx_symbolic(
|
||||
"aten::upsample_nearest3d",
|
||||
decorate=[_apply_params("upsample_nearest3d", 5, "nearest")],
|
||||
decorate=[symbolic_helper._apply_params("upsample_nearest3d", 5, "nearest")],
|
||||
)
|
||||
@_onnx_symbolic(
|
||||
"aten::upsample_linear1d",
|
||||
decorate=[_apply_params("upsample_linear1d", 3, "linear")],
|
||||
decorate=[symbolic_helper._apply_params("upsample_linear1d", 3, "linear")],
|
||||
)
|
||||
@_onnx_symbolic(
|
||||
"aten::upsample_bilinear2d",
|
||||
decorate=[_apply_params("upsample_bilinear2d", 4, "linear")],
|
||||
decorate=[symbolic_helper._apply_params("upsample_bilinear2d", 4, "linear")],
|
||||
)
|
||||
@_onnx_symbolic(
|
||||
"aten::upsample_trilinear3d",
|
||||
decorate=[_apply_params("upsample_trilinear3d", 5, "linear")],
|
||||
decorate=[symbolic_helper._apply_params("upsample_trilinear3d", 5, "linear")],
|
||||
)
|
||||
def _interpolate(name, dim, interpolate_mode):
|
||||
def symbolic_fn(g, input, output_size, *args):
|
||||
|
@ -185,7 +185,6 @@ __all__ = [
|
||||
"ones_like",
|
||||
"ones",
|
||||
"onnx_placeholder",
|
||||
"overload_by_arg_count",
|
||||
"pad",
|
||||
"pairwise_distance",
|
||||
"permute",
|
||||
@ -293,15 +292,6 @@ __all__ = [
|
||||
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=9)
|
||||
|
||||
|
||||
def _apply_params(*args, **kwargs):
|
||||
"""Returns a decorator that calls the decorated (higher-order) function with the given parameters."""
|
||||
|
||||
def _apply(fn):
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return _apply
|
||||
|
||||
|
||||
def _export(name: str):
|
||||
"""Exports the function in the current global namespace."""
|
||||
|
||||
@ -774,120 +764,27 @@ def _slice(g: jit_utils.GraphContext, input, axes, starts, ends):
|
||||
return g.op("Slice", input, axes_i=axes, starts_i=starts, ends_i=ends)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _maybe_cast_reduce_op_input(g: jit_utils.GraphContext, self):
|
||||
scalar_type = _type_utils.JitScalarType.from_value(
|
||||
self, _type_utils.JitScalarType.UNDEFINED
|
||||
)
|
||||
if scalar_type != _type_utils.JitScalarType.UNDEFINED:
|
||||
# This check only covers traced modules where dtype is present
|
||||
# pytorch reduce-ops cast all other integral types to int64
|
||||
if (
|
||||
not symbolic_helper._is_fp(self)
|
||||
and scalar_type != _type_utils.JitScalarType.INT64
|
||||
):
|
||||
self = g.op("Cast", self, to_i=_C_onnx.TensorProtoDataType.INT64)
|
||||
return self
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _reduce_op_symbolic(onnx_op_name, allow_multi_dim_support=True):
|
||||
@_beartype.beartype
|
||||
def symbolic(g, self, dim=None, keepdim=None):
|
||||
self = _maybe_cast_reduce_op_input(g, self)
|
||||
if dim is None or dim == tuple():
|
||||
# Dim can be 0, which will cause (not dim) == True. So we don't want to do
|
||||
# (not dim)
|
||||
# all-reduce path
|
||||
return symbolic_helper._handle_reduce_dim_none(g, self, onnx_op_name)
|
||||
else:
|
||||
# dim-reduce path
|
||||
desc = "is" if allow_multi_dim_support else "i"
|
||||
dim, keepdim = symbolic_helper._get_const(
|
||||
dim, desc, "dim"
|
||||
), symbolic_helper._get_const(keepdim, "i", "keepdim")
|
||||
dim_list = dim if allow_multi_dim_support else [dim]
|
||||
return g.op(onnx_op_name, self, axes_i=dim_list, keepdims_i=keepdim)
|
||||
|
||||
return symbolic
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def overload_by_arg_count(fn):
|
||||
@functools.wraps(fn)
|
||||
@_beartype.beartype
|
||||
def wrapper(g, *args):
|
||||
overloads = fn(g, *args)
|
||||
for overload in overloads:
|
||||
arg_descriptors = overload._arg_descriptors
|
||||
if len(arg_descriptors) == len(args):
|
||||
return overload(g, *args)
|
||||
return symbolic_helper._unimplemented(
|
||||
f"aten::{fn.__name__}", f"with {len(args)} arguments"
|
||||
)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::sum", decorate=[_apply_params("ReduceSum", "sum")])
|
||||
@_onnx_symbolic("aten::mean", decorate=[_apply_params("ReduceMean", "mean")])
|
||||
@_onnx_symbolic(
|
||||
"aten::sum", decorate=[symbolic_helper._apply_params("ReduceSum", "sum")]
|
||||
)
|
||||
@_onnx_symbolic(
|
||||
"aten::mean", decorate=[symbolic_helper._apply_params("ReduceMean", "mean")]
|
||||
)
|
||||
# torch.prod does not support multidimensional "dim"
|
||||
@_onnx_symbolic(
|
||||
"aten::prod",
|
||||
decorate=[_apply_params("ReduceProd", "prod", allow_multi_dim_support=False)],
|
||||
decorate=[
|
||||
symbolic_helper._apply_params(
|
||||
"ReduceProd", "prod", allow_multi_dim_support=False
|
||||
)
|
||||
],
|
||||
)
|
||||
@_beartype.beartype
|
||||
def _reduce_with_dtype(onnx_op: str, name: str, allow_multi_dim_support: bool = True):
|
||||
symbolic = _reduce_op_symbolic(
|
||||
onnx_op, allow_multi_dim_support=allow_multi_dim_support
|
||||
return symbolic_helper._reduce_with_dtype_helper(
|
||||
onnx_op, name, allow_multi_dim_support
|
||||
)
|
||||
|
||||
@overload_by_arg_count
|
||||
def reduce(g, *args, **kwargs):
|
||||
@symbolic_helper.quantized_args(True)
|
||||
@symbolic_helper.parse_args("v", "none")
|
||||
def reduce_nodim(g, self, dtype):
|
||||
dtype_onnx = None
|
||||
if dtype.node().kind() == "onnx::Constant":
|
||||
dtype = symbolic_helper._get_const(dtype, "i", "dtype")
|
||||
dtype_onnx = _type_utils.JitScalarType(dtype).onnx_type()
|
||||
self = g.op("Cast", self, to_i=dtype_onnx)
|
||||
elif dtype.node().kind() != "prim::Constant":
|
||||
return symbolic_helper._unimplemented(name, "dtype", dtype)
|
||||
result = symbolic(g, self)
|
||||
if dtype_onnx is not None:
|
||||
result_dtype_onnx = _type_utils.JitScalarType.from_value(
|
||||
result
|
||||
).onnx_type()
|
||||
if result_dtype_onnx != dtype_onnx:
|
||||
result = g.op("Cast", result, to_i=dtype_onnx)
|
||||
return result
|
||||
|
||||
dim_desc = "is" if allow_multi_dim_support else "i"
|
||||
|
||||
@symbolic_helper.quantized_args(True)
|
||||
@symbolic_helper.parse_args("v", dim_desc, "i", "none") # type: ignore[arg-type]
|
||||
def reduce_dim(g, self, dim, keepdim, dtype):
|
||||
dtype_onnx = None
|
||||
if dtype.node().kind() == "onnx::Constant":
|
||||
dtype = symbolic_helper._get_const(dtype, "i", "dtype")
|
||||
dtype_onnx = _type_utils.JitScalarType(dtype).onnx_type()
|
||||
self = g.op("Cast", self, to_i=dtype_onnx)
|
||||
elif dtype.node().kind() != "prim::Constant":
|
||||
return symbolic_helper._unimplemented(name, "dtype", dtype)
|
||||
result = symbolic(g, self, dim, keepdim)
|
||||
if dtype_onnx is not None:
|
||||
result_dtype_onnx = _type_utils.JitScalarType.from_value(
|
||||
result
|
||||
).onnx_type()
|
||||
if result_dtype_onnx != dtype_onnx:
|
||||
result = g.op("Cast", result, to_i=dtype_onnx)
|
||||
return result
|
||||
|
||||
return reduce_nodim, reduce_dim
|
||||
|
||||
return reduce
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::cumsum")
|
||||
@symbolic_helper.parse_args("v", "i", "none")
|
||||
@ -1356,65 +1253,13 @@ def mish(g: jit_utils.GraphContext, input):
|
||||
return g.op("Mul", input, g.op("Tanh", g.op("Softplus", input)))
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _op_with_optional_float_cast(g: jit_utils.GraphContext, op_name, *args, **kwargs):
|
||||
"""Some PyTorch operators (e.g., Clip/Min/ReLU/Pad) are super set of ONNX in terms of data types.
|
||||
This function maximizes the exportability of PyTorch-ONNX by allowing ONNX-unsupported PyTorch
|
||||
operator data type. For example, `Cast<int>(Clip<float>(Cast<float>(INPUT)))` can be used to mimic
|
||||
`Clip<int>(INPUT)` (opset version < 12).
|
||||
|
||||
Args:
|
||||
g (torch._C.Graph): graph to write the ONNX representation into.
|
||||
op_name (str): operator name in ONNX.
|
||||
*args (tuple): operands to the operator.
|
||||
**kwargs (dict): attributes to the operator along with "opset_before" (optional, None by default)
|
||||
indicating the smallest opset version to trigger such casting behavior and "target_float_t"
|
||||
(optional, torch.onnx.JitScalarType.FLOAT by default) indicating the data type of internal operator.
|
||||
|
||||
Returns:
|
||||
Optional[torch._C.Value, Tuple[torch._C.Value, ...]]: output(s) of the operator.
|
||||
"""
|
||||
opset_before = kwargs.pop("opset_before", None)
|
||||
target_float_t = kwargs.pop("target_float_t", _type_utils.JitScalarType.FLOAT)
|
||||
|
||||
inputs = list(args)
|
||||
dtype_0 = _type_utils.JitScalarType.from_value(inputs[0])
|
||||
|
||||
require_cast = not symbolic_helper._is_fp(inputs[0]) and (
|
||||
opset_before is None or GLOBALS.export_onnx_opset_version < opset_before
|
||||
)
|
||||
|
||||
if require_cast:
|
||||
for input in inputs:
|
||||
if input.isCompleteTensor():
|
||||
input_scalar_type = _type_utils.JitScalarType.from_value(input)
|
||||
if input_scalar_type != dtype_0:
|
||||
raise errors.SymbolicValueError(
|
||||
f"Inputs of {op_name} must have same dtype."
|
||||
f"Got {dtype_0.scalar_name()} and {input_scalar_type.scalar_name()}",
|
||||
input,
|
||||
)
|
||||
for i, input in enumerate(inputs):
|
||||
if input.isCompleteTensor() and not symbolic_helper._is_fp(input):
|
||||
inputs[i] = g.op(
|
||||
"Cast",
|
||||
input,
|
||||
to_i=target_float_t.onnx_type(),
|
||||
)
|
||||
|
||||
self = g.op(op_name, *inputs, **kwargs)
|
||||
|
||||
if require_cast:
|
||||
self = g.op("Cast", self, to_i=dtype_0.onnx_type())
|
||||
|
||||
return self
|
||||
|
||||
|
||||
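The function removed here (now symbolic_helper._op_with_optional_float_cast) implements the Cast<int>(Op<float>(Cast<float>(INPUT))) trick described in its docstring for pre-opset-12 Clip/Min/Max/Relu/Pad. A tensor-level sketch of the same round-trip, with a hypothetical helper name, just to make the cast pattern concrete:

import torch

def clip_int_via_float(x: torch.Tensor, min_val: float, max_val: float) -> torch.Tensor:
    # Older Clip only accepts floating-point inputs, so integer tensors are
    # cast to float, clipped, then cast back to their original dtype.
    original_dtype = x.dtype
    return x.to(torch.float32).clamp(min_val, max_val).to(original_dtype)

x = torch.tensor([1, 5, 9], dtype=torch.int32)
print(clip_int_via_float(x, 2.0, 8.0))  # tensor([2, 5, 8], dtype=torch.int32)
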
@_onnx_symbolic("aten::relu")
|
||||
@symbolic_helper.quantized_args(True)
|
||||
@_beartype.beartype
|
||||
def relu(g: jit_utils.GraphContext, input):
|
||||
return _op_with_optional_float_cast(g, "Relu", input, opset_before=14)
|
||||
return symbolic_helper._op_with_optional_float_cast(
|
||||
g, "Relu", input, opset_before=14
|
||||
)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::relu6")
|
||||
@ -1603,7 +1448,7 @@ def get_pool_ceil_padding(input, kernel_size, stride, padding):
|
||||
@_onnx_symbolic(
|
||||
"aten::max_pool1d",
|
||||
decorate=[
|
||||
_apply_params(
|
||||
symbolic_helper._apply_params(
|
||||
"max_pool1d", torch.nn.modules.utils._single, 1, return_indices=False
|
||||
),
|
||||
_export("max_pool1d"),
|
||||
@ -1612,7 +1457,7 @@ def get_pool_ceil_padding(input, kernel_size, stride, padding):
|
||||
@_onnx_symbolic(
|
||||
"aten::max_pool2d",
|
||||
decorate=[
|
||||
_apply_params(
|
||||
symbolic_helper._apply_params(
|
||||
"max_pool2d", torch.nn.modules.utils._pair, 2, return_indices=False
|
||||
),
|
||||
_export("max_pool2d"),
|
||||
@ -1621,7 +1466,7 @@ def get_pool_ceil_padding(input, kernel_size, stride, padding):
|
||||
@_onnx_symbolic(
|
||||
"aten::max_pool3d",
|
||||
decorate=[
|
||||
_apply_params(
|
||||
symbolic_helper._apply_params(
|
||||
"max_pool3d", torch.nn.modules.utils._triple, 3, return_indices=False
|
||||
),
|
||||
_export("max_pool3d"),
|
||||
@ -1716,21 +1561,21 @@ max_pool3d_with_indices = _onnx_symbolic("aten::max_pool3d_with_indices")(
|
||||
@_onnx_symbolic(
|
||||
"aten::avg_pool1d",
|
||||
decorate=[
|
||||
_apply_params("avg_pool1d", torch.nn.modules.utils._single),
|
||||
symbolic_helper._apply_params("avg_pool1d", torch.nn.modules.utils._single),
|
||||
_export("avg_pool1d"),
|
||||
],
|
||||
)
|
||||
@_onnx_symbolic(
|
||||
"aten::avg_pool2d",
|
||||
decorate=[
|
||||
_apply_params("avg_pool2d", torch.nn.modules.utils._pair),
|
||||
symbolic_helper._apply_params("avg_pool2d", torch.nn.modules.utils._pair),
|
||||
_export("avg_pool2d"),
|
||||
],
|
||||
)
|
||||
@_onnx_symbolic(
|
||||
"aten::avg_pool3d",
|
||||
decorate=[
|
||||
_apply_params("avg_pool3d", torch.nn.modules.utils._triple),
|
||||
symbolic_helper._apply_params("avg_pool3d", torch.nn.modules.utils._triple),
|
||||
_export("avg_pool3d"),
|
||||
],
|
||||
)
|
||||
@ -1762,7 +1607,7 @@ def _avg_pool(name, tuple_fn):
|
||||
# this accommodation.
|
||||
# More detail on https://github.com/pytorch/pytorch/issues/57178
|
||||
if count_include_pad:
|
||||
input = _op_with_optional_float_cast(
|
||||
input = symbolic_helper._op_with_optional_float_cast(
|
||||
g,
|
||||
"Pad",
|
||||
input,
|
||||
@ -1794,7 +1639,7 @@ def _avg_pool(name, tuple_fn):
|
||||
@_onnx_symbolic(
|
||||
"aten::adaptive_avg_pool1d",
|
||||
decorate=[
|
||||
_apply_params(
|
||||
symbolic_helper._apply_params(
|
||||
"adaptive_avg_pool1d", "AveragePool", torch.nn.modules.utils._single
|
||||
),
|
||||
_export("adaptive_avg_pool1d"),
|
||||
@ -1803,7 +1648,7 @@ def _avg_pool(name, tuple_fn):
|
||||
@_onnx_symbolic(
|
||||
"aten::adaptive_avg_pool2d",
|
||||
decorate=[
|
||||
_apply_params(
|
||||
symbolic_helper._apply_params(
|
||||
"adaptive_avg_pool2d", "AveragePool", torch.nn.modules.utils._pair
|
||||
),
|
||||
_export("adaptive_avg_pool2d"),
|
||||
@ -1812,7 +1657,7 @@ def _avg_pool(name, tuple_fn):
|
||||
@_onnx_symbolic(
|
||||
"aten::adaptive_avg_pool3d",
|
||||
decorate=[
|
||||
_apply_params(
|
||||
symbolic_helper._apply_params(
|
||||
"adaptive_avg_pool3d", "AveragePool", torch.nn.modules.utils._triple
|
||||
),
|
||||
_export("adaptive_avg_pool3d"),
|
||||
@ -1821,7 +1666,7 @@ def _avg_pool(name, tuple_fn):
|
||||
@_onnx_symbolic(
|
||||
"aten::adaptive_max_pool1d",
|
||||
decorate=[
|
||||
_apply_params(
|
||||
symbolic_helper._apply_params(
|
||||
"adaptive_max_pool1d",
|
||||
"MaxPool",
|
||||
torch.nn.modules.utils._single,
|
||||
@ -1833,7 +1678,7 @@ def _avg_pool(name, tuple_fn):
|
||||
@_onnx_symbolic(
|
||||
"aten::adaptive_max_pool2d",
|
||||
decorate=[
|
||||
_apply_params(
|
||||
symbolic_helper._apply_params(
|
||||
"adaptive_max_pool2d",
|
||||
"MaxPool",
|
||||
torch.nn.modules.utils._pair,
|
||||
@ -1845,7 +1690,7 @@ def _avg_pool(name, tuple_fn):
|
||||
@_onnx_symbolic(
|
||||
"aten::adaptive_max_pool3d",
|
||||
decorate=[
|
||||
_apply_params(
|
||||
symbolic_helper._apply_params(
|
||||
"adaptive_max_pool3d",
|
||||
"MaxPool",
|
||||
torch.nn.modules.utils._triple,
|
||||
@ -1961,7 +1806,7 @@ def constant_pad_nd(g: jit_utils.GraphContext, input, padding, value):
|
||||
|
||||
padding = _convert_padding_node(padding)
|
||||
paddings = _prepare_onnx_paddings(symbolic_helper._get_tensor_rank(input), padding)
|
||||
return _op_with_optional_float_cast(
|
||||
return symbolic_helper._op_with_optional_float_cast(
|
||||
g, "Pad", input, pads_i=paddings, mode_s=mode, value_f=value, opset_before=11
|
||||
)
|
||||
|
||||
@ -2016,7 +1861,7 @@ def reflection_pad(g: jit_utils.GraphContext, input, padding):
|
||||
mode = "reflect"
|
||||
padding = _convert_padding_node(padding)
|
||||
paddings = _prepare_onnx_paddings(symbolic_helper._get_tensor_rank(input), padding)
|
||||
return _op_with_optional_float_cast(
|
||||
return symbolic_helper._op_with_optional_float_cast(
|
||||
g, "Pad", input, pads_i=paddings, mode_s=mode, opset_before=11
|
||||
)
|
||||
|
||||
@ -2029,7 +1874,7 @@ def replication_pad(g: jit_utils.GraphContext, input, padding):
|
||||
mode = "edge"
|
||||
padding = _convert_padding_node(padding)
|
||||
paddings = _prepare_onnx_paddings(symbolic_helper._get_tensor_rank(input), padding)
|
||||
return _op_with_optional_float_cast(
|
||||
return symbolic_helper._op_with_optional_float_cast(
|
||||
g, "Pad", input, pads_i=paddings, mode_s=mode, opset_before=11
|
||||
)
|
||||
|
||||
@ -2059,42 +1904,42 @@ def pad(
|
||||
@_onnx_symbolic(
|
||||
"aten::upsample_nearest1d",
|
||||
decorate=[
|
||||
_apply_params("upsample_nearest1d", 3, "nearest"),
|
||||
symbolic_helper._apply_params("upsample_nearest1d", 3, "nearest"),
|
||||
_export("upsample_nearest1d"),
|
||||
],
|
||||
)
|
||||
@_onnx_symbolic(
|
||||
"aten::upsample_nearest2d",
|
||||
decorate=[
|
||||
_apply_params("upsample_nearest2d", 4, "nearest"),
|
||||
symbolic_helper._apply_params("upsample_nearest2d", 4, "nearest"),
|
||||
_export("upsample_nearest2d"),
|
||||
],
|
||||
)
|
||||
@_onnx_symbolic(
|
||||
"aten::upsample_nearest3d",
|
||||
decorate=[
|
||||
_apply_params("upsample_nearest3d", 5, "nearest"),
|
||||
symbolic_helper._apply_params("upsample_nearest3d", 5, "nearest"),
|
||||
_export("upsample_nearest3d"),
|
||||
],
|
||||
)
|
||||
@_onnx_symbolic(
|
||||
"aten::upsample_linear1d",
|
||||
decorate=[
|
||||
_apply_params("upsample_linear1d", 3, "linear"),
|
||||
symbolic_helper._apply_params("upsample_linear1d", 3, "linear"),
|
||||
_export("upsample_linear1d"),
|
||||
],
|
||||
)
|
||||
@_onnx_symbolic(
|
||||
"aten::upsample_bilinear2d",
|
||||
decorate=[
|
||||
_apply_params("upsample_bilinear2d", 4, "linear"),
|
||||
symbolic_helper._apply_params("upsample_bilinear2d", 4, "linear"),
|
||||
_export("upsample_bilinear2d"),
|
||||
],
|
||||
)
|
||||
@_onnx_symbolic(
|
||||
"aten::upsample_trilinear3d",
|
||||
decorate=[
|
||||
_apply_params("upsample_trilinear3d", 5, "linear"),
|
||||
symbolic_helper._apply_params("upsample_trilinear3d", 5, "linear"),
|
||||
_export("upsample_trilinear3d"),
|
||||
],
|
||||
)
|
||||
@ -2942,7 +2787,15 @@ def native_layer_norm(
    two_cst = symbolic_helper._generate_wrapped_number(g, 2.0)
    eps_cst = symbolic_helper._generate_wrapped_number(g, eps)

    mean = g.op("ReduceMean", input, axes_i=axes)
    if g.opset < 18:
        mean = g.op("ReduceMean", input, axes_i=axes)
    else:
        mean = g.op(
            "ReduceMean",
            input,
            g.op("Constant", value_t=torch.tensor(axes, dtype=torch.long)),
        )

    numerator = sub(g, input, mean)

    # Cast it to eps dtype to avoid precision loss
@ -2957,7 +2810,15 @@ def native_layer_norm(
    )

    # variance = e((x - e(x))^2), and (x - e(x)) is the numerator in the layer_norm formula
    variance = g.op("ReduceMean", pow(g, numerator, two_cst), axes_i=axes)
    if g.opset < 18:
        variance = g.op("ReduceMean", pow(g, numerator, two_cst), axes_i=axes)
    else:
        variance = g.op(
            "ReduceMean",
            pow(g, numerator, two_cst),
            g.op("Constant", value_t=torch.tensor(axes, dtype=torch.long)),
        )

    denominator = sqrt(g, g.op("Add", variance, eps_cst))
    normalized = g.op("Div", numerator, denominator)

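The g.opset branch added above exists because ONNX 18 moved the reduction axes of ReduceMean from a node attribute to a second input. A small sketch of the two node forms using the onnx helper API (assumes the onnx Python package; tensor names are arbitrary):

from onnx import helper

# Pre-opset-18 form: axes carried as an attribute, matching axes_i=axes above.
node_pre18 = helper.make_node(
    "ReduceMean", inputs=["x"], outputs=["mean"], axes=[-1], keepdims=1
)

# Opset >= 18 form: axes supplied as a second input, matching the extra
# g.op("Constant", value_t=torch.tensor(axes, ...)) argument above.
node_18 = helper.make_node(
    "ReduceMean", inputs=["x", "axes"], outputs=["mean"], keepdims=1
)
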
@ -3405,7 +3266,7 @@ def clamp(g: jit_utils.GraphContext, self, min, max):
|
||||
return clamp_min(g, self, min)
|
||||
else:
|
||||
if symbolic_helper._is_constant(min) and symbolic_helper._is_constant(max):
|
||||
return _op_with_optional_float_cast(
|
||||
return symbolic_helper._op_with_optional_float_cast(
|
||||
g,
|
||||
"Clip",
|
||||
self,
|
||||
@ -3422,13 +3283,15 @@ def clamp(g: jit_utils.GraphContext, self, min, max):
|
||||
@_beartype.beartype
|
||||
def clamp_min(g: jit_utils.GraphContext, self, min):
|
||||
if symbolic_helper._is_constant(min):
|
||||
return _op_with_optional_float_cast(
|
||||
return symbolic_helper._op_with_optional_float_cast(
|
||||
g, "Clip", self, min_f=symbolic_helper._parse_arg(min, "f"), opset_before=12
|
||||
)
|
||||
else:
|
||||
dtype = _type_utils.JitScalarType.from_value(self)
|
||||
min = g.op("Cast", min, to_i=dtype.onnx_type())
|
||||
return _op_with_optional_float_cast(g, "Max", self, min, opset_before=12)
|
||||
return symbolic_helper._op_with_optional_float_cast(
|
||||
g, "Max", self, min, opset_before=12
|
||||
)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::clamp_max")
|
||||
@ -3436,13 +3299,15 @@ def clamp_min(g: jit_utils.GraphContext, self, min):
|
||||
@_beartype.beartype
|
||||
def clamp_max(g: jit_utils.GraphContext, self, max):
|
||||
if symbolic_helper._is_constant(max):
|
||||
return _op_with_optional_float_cast(
|
||||
return symbolic_helper._op_with_optional_float_cast(
|
||||
g, "Clip", self, max_f=symbolic_helper._parse_arg(max, "f"), opset_before=12
|
||||
)
|
||||
else:
|
||||
dtype = _type_utils.JitScalarType.from_value(self)
|
||||
max = g.op("Cast", max, to_i=dtype.onnx_type())
|
||||
return _op_with_optional_float_cast(g, "Min", self, max, opset_before=12)
|
||||
return symbolic_helper._op_with_optional_float_cast(
|
||||
g, "Min", self, max, opset_before=12
|
||||
)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::max")
|
||||
@ -3451,19 +3316,7 @@ def clamp_max(g: jit_utils.GraphContext, self, max):
|
||||
# TODO(justinchuby): Support multiple quantized args in output
|
||||
@_beartype.beartype
|
||||
def max(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
|
||||
# torch.max(input)
|
||||
if dim_or_y is None and keepdim is None:
|
||||
return g.op("ReduceMax", self, keepdims_i=0)
|
||||
# torch.max(input, other)
|
||||
if keepdim is None:
|
||||
return _op_with_optional_float_cast(g, "Max", self, dim_or_y, opset_before=12)
|
||||
# torch.max(input, dim, keepdim)
|
||||
else:
|
||||
dim = symbolic_helper._get_const(dim_or_y, "i", "dim")
|
||||
keepdim = symbolic_helper._get_const(keepdim, "i", "keepdim")
|
||||
max = g.op("ReduceMax", self, axes_i=[dim], keepdims_i=keepdim)
|
||||
indices = g.op("ArgMax", self, axis_i=dim, keepdims_i=keepdim)
|
||||
return max, indices
|
||||
return symbolic_helper._max_helper(g, self, dim_or_y, keepdim)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::maximum")
|
||||
@ -3477,19 +3330,7 @@ def maximum(g: jit_utils.GraphContext, input, other):
|
||||
# TODO(justinchuby): Support multiple quantized args in output
|
||||
@_beartype.beartype
|
||||
def min(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
|
||||
# torch.min(input)
|
||||
if dim_or_y is None and keepdim is None:
|
||||
return g.op("ReduceMin", self, keepdims_i=0)
|
||||
# torch.min(input, other)
|
||||
if keepdim is None:
|
||||
return _op_with_optional_float_cast(g, "Min", self, dim_or_y, opset_before=12)
|
||||
# torch.min(input, dim, keepdim)
|
||||
else:
|
||||
dim = symbolic_helper._get_const(dim_or_y, "i", "dim")
|
||||
keepdim = symbolic_helper._get_const(keepdim, "i", "keepdim")
|
||||
min = g.op("ReduceMin", self, axes_i=[dim], keepdims_i=keepdim)
|
||||
indices = g.op("ArgMin", self, axis_i=dim, keepdims_i=keepdim)
|
||||
return min, indices
|
||||
return symbolic_helper._min_helper(g, self, dim_or_y, keepdim)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::minimum")
|
||||
@ -3550,22 +3391,28 @@ def dropout(g: jit_utils.GraphContext, input, p, train):
|
||||
|
||||
|
||||
@_onnx_symbolic(
|
||||
"aten::alpha_dropout_", decorate=[_apply_params("aten::alpha_dropout_")]
|
||||
"aten::alpha_dropout_",
|
||||
decorate=[symbolic_helper._apply_params("aten::alpha_dropout_")],
|
||||
) # See Note [Export inplace]
|
||||
@_onnx_symbolic(
|
||||
"aten::feature_alpha_dropout_",
|
||||
decorate=[_apply_params("aten::feature_alpha_dropout_")],
|
||||
decorate=[symbolic_helper._apply_params("aten::feature_alpha_dropout_")],
|
||||
)
|
||||
@_onnx_symbolic(
|
||||
"aten::feature_dropout_", decorate=[_apply_params("aten::feature_dropout_")]
|
||||
"aten::feature_dropout_",
|
||||
decorate=[symbolic_helper._apply_params("aten::feature_dropout_")],
|
||||
)
|
||||
@_onnx_symbolic(
|
||||
"aten::feature_alpha_dropout",
|
||||
decorate=[_apply_params("aten::feature_alpha_dropout")],
|
||||
decorate=[symbolic_helper._apply_params("aten::feature_alpha_dropout")],
|
||||
)
|
||||
@_onnx_symbolic("aten::alpha_dropout", decorate=[_apply_params("aten::alpha_dropout")])
|
||||
@_onnx_symbolic(
|
||||
"aten::feature_dropout", decorate=[_apply_params("aten::feature_dropout")]
|
||||
"aten::alpha_dropout",
|
||||
decorate=[symbolic_helper._apply_params("aten::alpha_dropout")],
|
||||
)
|
||||
@_onnx_symbolic(
|
||||
"aten::feature_dropout",
|
||||
decorate=[symbolic_helper._apply_params("aten::feature_dropout")],
|
||||
)
|
||||
@_beartype.beartype
|
||||
def _unsupported_dropout(name: str):
|
||||
@ -3585,9 +3432,9 @@ def _unsupported_dropout(name: str):
|
||||
@_beartype.beartype
|
||||
def norm(g: jit_utils.GraphContext, self, p, dim, keepdim, dtype=None):
|
||||
if p == 1:
|
||||
f = _reduce_op_symbolic("ReduceL1")
|
||||
f = symbolic_helper._reduce_op_symbolic_helper("ReduceL1")
|
||||
elif p == 2:
|
||||
f = _reduce_op_symbolic("ReduceL2")
|
||||
f = symbolic_helper._reduce_op_symbolic_helper("ReduceL2")
|
||||
else:
|
||||
raise errors.SymbolicValueError(
|
||||
"ONNX export only p-norms with p of 1 or 2", self
|
||||
@ -4135,7 +3982,7 @@ def slice(g: jit_utils.GraphContext, self, *args):
|
||||
@symbolic_helper.parse_args("v", "f", "f")
|
||||
@_beartype.beartype
|
||||
def hardtanh(g: jit_utils.GraphContext, self: _C.Value, min_val: float, max_val: float):
|
||||
return _op_with_optional_float_cast(
|
||||
return symbolic_helper._op_with_optional_float_cast(
|
||||
g, "Clip", self, min_f=min_val, max_f=max_val, opset_before=12
|
||||
)
|
||||
|
||||
@ -4283,8 +4130,7 @@ def sort(g: jit_utils.GraphContext, self, dim, decending, out=None):
|
||||
@_onnx_symbolic("aten::numel")
|
||||
@_beartype.beartype
|
||||
def numel(g: jit_utils.GraphContext, self):
|
||||
shape = g.op("Shape", self)
|
||||
return g.op("ReduceProd", shape, keepdims_i=0)
|
||||
return symbolic_helper._numel_helper(g, self)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::topk")
|
||||
@ -4974,12 +4820,16 @@ def lstm_cell(g: jit_utils.GraphContext, self, hidden, w_ih, w_hh, b_ih, b_hh):
|
||||
), symbolic_helper._squeeze_helper(g, c_outs, [0])
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::gru", decorate=[_apply_params("GRU"), _export("gru")])
|
||||
@_onnx_symbolic(
|
||||
"aten::rnn_tanh", decorate=[_apply_params("RNN_TANH"), _export("rnn_tanh")]
|
||||
"aten::gru", decorate=[symbolic_helper._apply_params("GRU"), _export("gru")]
|
||||
)
|
||||
@_onnx_symbolic(
|
||||
"aten::rnn_relu", decorate=[_apply_params("RNN_RELU"), _export("rnn_relu")]
|
||||
"aten::rnn_tanh",
|
||||
decorate=[symbolic_helper._apply_params("RNN_TANH"), _export("rnn_tanh")],
|
||||
)
|
||||
@_onnx_symbolic(
|
||||
"aten::rnn_relu",
|
||||
decorate=[symbolic_helper._apply_params("RNN_RELU"), _export("rnn_relu")],
|
||||
)
|
||||
def _one_hidden_rnn(kind: str):
|
||||
@symbolic_helper.parse_args("v", "v", "v", "i", "i", "f", "i", "i", "i")
|
||||
@ -5583,37 +5433,7 @@ def gather(g: jit_utils.GraphContext, self, dim, index, sparse_grad=False):
|
||||
@symbolic_helper.parse_args("v", "is", "i", "i")
|
||||
@_beartype.beartype
|
||||
def _var_mean(g: jit_utils.GraphContext, input, dim, correction, keepdim):
|
||||
if dim is None:
|
||||
mean = g.op("ReduceMean", input, keepdims_i=0)
|
||||
t_mean = mean
|
||||
num_elements = numel(g, input)
|
||||
else:
|
||||
mean = g.op("ReduceMean", input, axes_i=dim, keepdims_i=keepdim)
|
||||
t_mean = g.op("ReduceMean", input, axes_i=dim, keepdims_i=1)
|
||||
redudced_dims = g.op("Shape", input)
|
||||
# dim could contain one or multiple dimensions
|
||||
redudced_dims = g.op(
|
||||
"Gather",
|
||||
redudced_dims,
|
||||
g.op("Constant", value_t=torch.tensor(dim)),
|
||||
axis_i=0,
|
||||
)
|
||||
num_elements = g.op("ReduceProd", redudced_dims, keepdims_i=0)
|
||||
sub_v = g.op("Sub", input, t_mean)
|
||||
sqr_sub = g.op("Mul", sub_v, sub_v)
|
||||
keepdim_mean = 0 if dim is None else keepdim
|
||||
var = g.op("ReduceMean", sqr_sub, axes_i=dim, keepdims_i=keepdim_mean)
|
||||
# Correct bias in calculating variance, by dividing it over (N - correction) instead on N
|
||||
if correction is None:
|
||||
correction = 1
|
||||
if correction != 0:
|
||||
num_elements = g.op(
|
||||
"Cast", num_elements, to_i=_C_onnx.TensorProtoDataType.FLOAT
|
||||
)
|
||||
one = g.op("Constant", value_t=torch.tensor(correction, dtype=torch.float))
|
||||
mul = g.op("Mul", var, num_elements)
|
||||
var = g.op("Div", mul, g.op("Sub", num_elements, one))
|
||||
return var, mean
|
||||
return symbolic_helper._var_mean_helper(g, input, dim, correction, keepdim)
|
||||
|
||||
|
||||
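The inlined graph code deleted here implemented the usual corrected-variance identity, var = E[(x - E[x])^2] * N / (N - correction), before the logic moved into symbolic_helper._var_mean_helper. A quick eager-mode check of that identity (the correction keyword is the current torch.var_mean spelling; correction=1 is the default unbiased estimator):

import torch

x = torch.randn(4, 5)
dim, correction = 1, 1

var, mean = torch.var_mean(x, dim=dim, correction=correction)

n = x.shape[dim]
manual_mean = x.mean(dim=dim)
manual_var = ((x - manual_mean.unsqueeze(dim)) ** 2).mean(dim=dim) * n / (n - correction)

assert torch.allclose(mean, manual_mean)
assert torch.allclose(var, manual_var)
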
@_onnx_symbolic("aten::std")
|
||||
@ -5633,11 +5453,6 @@ def var(g: jit_utils.GraphContext, input, *args):
|
||||
@_onnx_symbolic("aten::var_mean")
|
||||
@_beartype.beartype
|
||||
def var_mean(g: jit_utils.GraphContext, input, *args):
|
||||
# var_mean (and all variance-related functions) has multiple signatures, so need to manually figure
|
||||
# out the correct arguments:
|
||||
# aten::var_mean(Tensor self, bool unbiased)
|
||||
# aten::var_mean(Tensor self, int[1] dim, bool unbiased, bool keepdim=False)
|
||||
# aten::var_mean(Tensor self, int[1]? dim=None, *, int? correction=None, bool keepdim=False)
|
||||
if len(args) == 1:
|
||||
return _var_mean(g, input, None, args[0], None)
|
||||
else:
|
||||
@ -5994,42 +5809,7 @@ def linalg_vector_norm(
|
||||
keepdim: bool,
|
||||
dtype: torch._C.Value,
|
||||
):
|
||||
# Conditions based on https://pytorch.org/docs/stable/generated/torch.linalg.vector_norm.html
|
||||
if symbolic_helper._is_none(dim):
|
||||
self = symbolic_helper._reshape_helper(g, self, [-1])
|
||||
keepdim = False
|
||||
|
||||
if ord == math.inf:
|
||||
result = g.op("ReduceMax", g.op("Abs", self), axes_i=dim, keepdims_i=keepdim)
|
||||
elif ord == -math.inf:
|
||||
result = g.op("ReduceMin", g.op("Abs", self), axes_i=dim, keepdims_i=keepdim)
|
||||
elif ord == 0:
|
||||
return symbolic_helper._onnx_opset_unsupported_detailed(
|
||||
"linalg_vector_norm", 9, 11, "ord=0 not supported", self
|
||||
)
|
||||
elif ord == 1:
|
||||
result = _reduce_op_symbolic("ReduceL1")(g, self, dim=dim, keepdim=keepdim)
|
||||
elif ord == 2:
|
||||
result = _reduce_op_symbolic("ReduceL2")(g, self, dim=dim, keepdim=keepdim)
|
||||
else:
|
||||
ord_op = g.op("Constant", value_t=torch.tensor(ord, dtype=torch.float32))
|
||||
result = symbolic_helper._reducesum_helper(
|
||||
g, g.op("Pow", g.op("Abs", self), ord_op), axes_i=dim, keepdims_i=keepdim
|
||||
)
|
||||
result = g.op(
|
||||
"Pow",
|
||||
result,
|
||||
g.op(
|
||||
"Div",
|
||||
g.op("Constant", value_t=torch.tensor(1, dtype=torch.float32)),
|
||||
ord_op,
|
||||
),
|
||||
)
|
||||
|
||||
if not symbolic_helper._is_none(dtype):
|
||||
dtype = symbolic_helper._get_const(dtype, "i", "dtype")
|
||||
result = g.op("Cast", result, to_i=_type_utils.JitScalarType(dtype).onnx_type()) # type: ignore[arg-type]
|
||||
return result
|
||||
return symbolic_helper._linalg_vector_norm_helper(g, self, ord, dim, keepdim, dtype)
|
||||
|
||||
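The implementation removed here computed the general p-norm as sum(|x|^p)^(1/p), with ReduceMax/ReduceMin of |x| for ord = +/-inf and ReduceL1/ReduceL2 for ord 1 and 2; that logic now lives in symbolic_helper._linalg_vector_norm_helper. An eager-mode sanity check of those branches:

import math
import torch

x = torch.randn(10)
p = 3.0

# General branch: Pow -> ReduceSum -> Pow(1/p)
assert torch.allclose(torch.linalg.vector_norm(x, ord=p), (x.abs() ** p).sum() ** (1.0 / p))

# Infinity branches reduce to max/min of the absolute values.
assert torch.allclose(torch.linalg.vector_norm(x, ord=math.inf), x.abs().max())
assert torch.allclose(torch.linalg.vector_norm(x, ord=-math.inf), x.abs().min())
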
|
||||
@_onnx_symbolic("aten::linalg_matrix_norm")
|
||||
@ -6842,7 +6622,9 @@ def prim_shape(g: jit_utils.GraphContext, self):
|
||||
@_onnx_symbolic("prim::max")
|
||||
@_beartype.beartype
|
||||
def prim_max(g: jit_utils.GraphContext, self, other):
|
||||
return _op_with_optional_float_cast(g, "Max", self, other, opset_before=12)
|
||||
return symbolic_helper._op_with_optional_float_cast(
|
||||
g, "Max", self, other, opset_before=12
|
||||
)
|
||||
|
||||
|
||||
@_onnx_symbolic("prim::min")
|
||||
|