Add symbolic_opset19.py and symbolic_opset20.py to support opset 19/20, extend opset 18 support (#118828)

Start to fix https://github.com/pytorch/pytorch/issues/114801
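With this change the TorchScript exporter can target opset 20, which provides a native AffineGrid operator and an extended GridSample. A minimal usage sketch (module and file names here are illustrative, not part of the change):

import torch

class AffineGridModel(torch.nn.Module):
    def forward(self, theta):
        # Traces to aten::affine_grid_generator, which opset 20 maps to ONNX AffineGrid.
        return torch.nn.functional.affine_grid(theta, [1, 1, 3, 2], align_corners=False)

torch.onnx.export(
    AffineGridModel(), (torch.randn(1, 2, 3),), "affine_grid.onnx", opset_version=20
)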

Co-authored-by: Thiago Crepaldi <thiagofc@microsoft.com>
Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118828
Approved by: https://github.com/thiagocrepaldi
liqunfu
2024-03-22 18:01:28 +00:00
committed by PyTorch MergeBot
parent 34d33df056
commit bbe846f430
18 changed files with 1265 additions and 547 deletions

View File

@ -77,6 +77,8 @@ also be interested in reading our `development wiki <https://github.com/pytorch/
.. py:module:: torch.onnx.symbolic_opset16
.. py:module:: torch.onnx.symbolic_opset17
.. py:module:: torch.onnx.symbolic_opset18
.. py:module:: torch.onnx.symbolic_opset19
.. py:module:: torch.onnx.symbolic_opset20
.. py:module:: torch.onnx.symbolic_opset7
.. py:module:: torch.onnx.symbolic_opset8
.. py:module:: torch.onnx.symbolic_opset9

View File

@ -64,9 +64,8 @@ skipIfQuantizationBackendQNNPack = _skipper(
# skips tests for all versions below min_opset_version.
# if exporting the op is only supported after a specific version,
# add this wrapper to prevent running the test for opset_versions
# smaller than the currently tested opset_version
# smaller than `min_opset_version`.
def skipIfUnsupportedMinOpsetVersion(min_opset_version):
def skip_dec(func):
@functools.wraps(func)
@ -83,6 +82,8 @@ def skipIfUnsupportedMinOpsetVersion(min_opset_version):
# skips tests for all versions above max_opset_version.
# add this wrapper to prevent running the test for opset_versions
# higher than `max_opset_version`.
def skipIfUnsupportedMaxOpsetVersion(max_opset_version):
def skip_dec(func):
@functools.wraps(func)
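A minimal sketch of how such a min-opset guard typically works, assuming the test class exposes self.opset_version (skip_if_below_min_opset and its body are assumptions, not the test suite's actual code):

import functools
import unittest

def skip_if_below_min_opset(min_opset_version):
    def skip_dec(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            # Skip when the opset under test is older than the op's minimum supported opset.
            if self.opset_version < min_opset_version:
                raise unittest.SkipTest(
                    f"opset_version {self.opset_version} < required {min_opset_version}"
                )
            return func(self, *args, **kwargs)
        return wrapper
    return skip_dec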

View File

@ -480,43 +480,157 @@ class TestONNXOpset(pytorch_test_common.ExportTestCase):
x = torch.randn(20, 16, 50)
check_onnx_opsets_operator(MyDynamicModel(), x, ops, opset_versions=[9, 10])
def test_grid_sample(self):
n, c, h_in, w_in, h_out, w_out = 1, 1, 3, 2, 2, 4
ops = {16: [{"op_name": "GridSample"}]}
def test_affine_grid(self):
class MyModule(Module):
def forward(self, x, grid, mode, padding_mode, align_corers):
return torch.nn.functional.grid_sample(
x, grid, mode, padding_mode, align_corners
def __init__(self, align_corners):
super().__init__()
self.align_corners = align_corners
def forward(self, theta, size):
return torch.nn.functional.affine_grid(
theta, size, align_corners=self.align_corners
)
for mode, padding_mode, align_corners in itertools.product(
("bilinear", "nearest", "bicubic"),
("zeros", "border", "reflection"),
opset_version = 20
ops_2d = {
opset_version: [
{"op_name": "Constant"},
{"op_name": "Unsqueeze"},
{"op_name": "Constant"},
{"op_name": "Unsqueeze"},
{"op_name": "Constant"},
{"op_name": "Unsqueeze"},
{"op_name": "Constant"},
{"op_name": "Unsqueeze"},
{"op_name": "Concat"},
{"op_name": "AffineGrid"},
]
}
ops_3d = {
opset_version: [
{"op_name": "Constant"},
{"op_name": "Unsqueeze"},
{"op_name": "Constant"},
{"op_name": "Unsqueeze"},
{"op_name": "Constant"},
{"op_name": "Unsqueeze"},
{"op_name": "Constant"},
{"op_name": "Unsqueeze"},
{"op_name": "Constant"},
{"op_name": "Unsqueeze"},
{"op_name": "Concat"},
{"op_name": "AffineGrid"},
]
}
# 2D affine
theta_2d = torch.empty(1, 2, 3, dtype=torch.double)
size_2d = torch.Size([1, 1, 2, 2])
# 3D affine
theta_3d = torch.empty(1, 3, 4, dtype=torch.double)
size_3d = torch.Size([1, 1, 2, 2, 2])
for inputs, align_corners in itertools.product(
((theta_2d, size_2d, ops_2d), (theta_3d, size_3d, ops_3d)),
(True, False),
):
theta, size, ops = inputs
args = (
torch.randn(n, c, h_in, w_in), # x
torch.randn(n, h_out, w_out, 2), # grid,
mode,
padding_mode,
align_corners,
theta,
size,
)
check_onnx_opsets_operator(
MyModule(),
MyModule(align_corners=align_corners),
args,
ops,
opset_versions=[16],
opset_versions=[opset_version],
training=torch.onnx.TrainingMode.TRAINING,
)
check_onnx_opsets_operator(
MyModule(),
MyModule(align_corners=align_corners),
args,
ops,
opset_versions=[16],
opset_versions=[opset_version],
training=torch.onnx.TrainingMode.EVAL,
)
def test_grid_sample(self):
class MyModule(torch.nn.Module):
def __init__(self, mode, padding_mode, align_corners):
super().__init__()
self.mode = mode
self.padding_mode = padding_mode
self.align_corners = align_corners
def forward(self, x, grid):
return torch.nn.functional.grid_sample(
x,
grid,
mode=self.mode,
padding_mode=self.padding_mode,
align_corners=self.align_corners,
)
for mode, padding_mode, align_corners, opset_version in itertools.product(
("bilinear", "nearest", "bicubic"),
("zeros", "border", "reflection"),
(True, False),
(16, 20),
):
def test_eval_and_training(
ops, opset_version, mode, padding_mode, align_corners, x_shape, grid
):
args = (
torch.randn(*x_shape), # x
torch.randn(grid), # grid,
)
check_onnx_opsets_operator(
MyModule(
mode=mode,
padding_mode=padding_mode,
align_corners=align_corners,
),
args,
ops,
opset_versions=[opset_version],
training=torch.onnx.TrainingMode.TRAINING,
)
check_onnx_opsets_operator(
MyModule(
mode=mode,
padding_mode=padding_mode,
align_corners=align_corners,
),
args,
ops,
opset_versions=[opset_version],
training=torch.onnx.TrainingMode.EVAL,
)
ops = {opset_version: [{"op_name": "GridSample"}]}
# mode = convert_grid_sample_mode(mode) if opset_version == 20 else mode
n, c, d_in, h_in, w_in, d_out, h_out, w_out = 1, 1, 2, 3, 2, 3, 2, 4
test_eval_and_training(
ops,
opset_version,
mode,
padding_mode,
align_corners,
(n, c, h_in, w_in),
(n, h_out, w_out, 2),
)
if opset_version == 20 and mode != "bicubic":
test_eval_and_training(
ops,
opset_version,
mode,
padding_mode,
align_corners,
(n, c, d_in, h_in, w_in),
(n, d_out, h_out, w_out, 3),
)
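Opset 20 GridSample accepts 5-D (volumetric) inputs for every mode except bicubic, which is what the new branch above exercises. A minimal export sketch with the same shapes as the test (the output file name is illustrative):

import torch

class VolumetricGridSample(torch.nn.Module):
    def forward(self, x, grid):
        return torch.nn.functional.grid_sample(
            x, grid, mode="bilinear", padding_mode="zeros", align_corners=False
        )

x = torch.randn(1, 1, 2, 3, 2)     # (n, c, d_in, h_in, w_in)
grid = torch.randn(1, 3, 2, 4, 3)  # (n, d_out, h_out, w_out, 3)
# Below opset 20 this export raises OnnxExporterError; at opset 20 it succeeds.
torch.onnx.export(VolumetricGridSample(), (x, grid), "grid_sample_3d.onnx", opset_version=20)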
def test_flatten(self):
class MyModule(Module):
def forward(self, x):

View File

@ -1604,7 +1604,9 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
# TODO: ceil_mode is not included in the test, because of
# https://github.com/microsoft/onnxruntime/issues/16203
# ORT and PyTorch compute ceil_mode differently (the difference shows up in the last value).
@skipIfUnsupportedMinOpsetVersion(19)
# The issue requires a fix in ONNX opset 21 (https://github.com/onnx/onnx/issues/5711);
# a fix in ORT is also planned. Once both fixes are in place, we can add ceil_mode to the test.
@skipIfUnsupportedMinOpsetVersion(21)
def test_avgpool_3d_ceil(self):
model = torch.nn.AvgPool3d(3, 2, ceil_mode=True)
x = torch.randn(20, 16, 50, 44, 31)
@ -13393,6 +13395,113 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
self.run_test(Module(False), x, rtol=1e-3, atol=1e-6)
self.run_test(Module(True), x, rtol=1e-3, atol=1e-6)
class AffineGridModule(torch.nn.Module):
def __init__(self, align_corners) -> None:
super().__init__()
self.align_corners = align_corners
def forward(self, theta, size):
return torch.nn.functional.affine_grid(theta, size, self.align_corners)
@skipIfUnsupportedMinOpsetVersion(20)
@skipScriptTest()
@common_utils.parametrize(
"align_corners",
(True, False),
)
@common_utils.parametrize(
"theta_params",
(
(
10,
np.array([0.3, -0.5]),
np.array([1.5, 0.5]),
),
(
60,
np.array([-0.5, -0.5]),
np.array([3.0, 5.5]),
),
),
)
@common_utils.parametrize(
"size",
([1, 1, 3, 2], [2, 10, 2, 3]),
)
def test_affine_grid_2d(self, align_corners, theta_params, size):
angle, translation, scale = theta_params
theta = np.array([], dtype=np.float32)
for _ in range(size[0]):
angle_radian = (angle / 180.0) * np.pi
theta = np.append(
theta,
[
np.cos(angle_radian) * scale[0],
-np.sin(angle_radian),
translation[0],
np.sin(angle_radian),
np.cos(angle_radian) * scale[1],
translation[1],
],
)
theta = theta.reshape(size[0], 2, 3)
theta = torch.Tensor(theta)
self.run_test(TestONNXRuntime.AffineGridModule(align_corners), (theta, size))
@skipIfUnsupportedMinOpsetVersion(20)
@skipScriptTest()
@common_utils.parametrize(
"align_corners",
(True, False),
)
@common_utils.parametrize(
"theta_params",
(
(
[10, 20],
np.array([0.3, -0.5, 1.8]),
np.array([1.5, 2.0, 0.5]),
),
(
[60, -30],
np.array([-0.5, -0.5, 0.3]),
np.array([0.3, 3.0, 5.5]),
),
),
)
@common_utils.parametrize(
"size",
([1, 1, 3, 2, 2], [2, 10, 2, 2, 3]),
)
def test_affine_grid_3d(self, align_corners, theta_params, size):
angle, translation, scale = theta_params
theta = np.array([], dtype=np.float32)
for _ in range(size[0]):
angle_radian_x = (angle[0] / 180.0) * np.pi
angle_radian_y = (angle[1] / 180.0) * np.pi
rot_matrix_x = np.array(
[
[1, 0, 0],
[0, np.cos(angle_radian_x), -np.sin(angle_radian_x)],
[0, np.sin(angle_radian_x), np.cos(angle_radian_x)],
]
)
rot_matrix_y = np.array(
[
[np.cos(angle_radian_y), 0, np.sin(angle_radian_y)],
[0, 1, 0],
[-np.sin(angle_radian_y), 0, np.cos(angle_radian_y)],
]
)
rot_matrix = np.matmul(rot_matrix_x, rot_matrix_y)
rot_matrix = rot_matrix * scale.reshape(3, 1)
rot_matrix = np.append(rot_matrix, np.reshape(translation, (3, 1)), axis=1)
theta = np.append(theta, rot_matrix.flatten())
theta = theta.reshape(size[0], 3, 4)
theta = torch.Tensor(theta)
self.run_test(TestONNXRuntime.AffineGridModule(align_corners), (theta, size))
@skipIfUnsupportedMinOpsetVersion(16)
@common_utils.parametrize(
"mode",
@ -13408,7 +13517,15 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
name_fn=lambda align_corners: str(align_corners),
)
def test_grid_sample(self, mode, padding_mode, align_corners):
n, c, h_in, w_in, h_out, w_out = 1, 1, 3, 2, 2, 4
n, c, d_in, h_in, w_in, d_out, h_out, w_out = 1, 1, 2, 3, 2, 3, 2, 4
atol_rtol = {}
if (mode, padding_mode) == ("bicubic", "border"):
if align_corners:
atol_rtol.update({"atol": 0.3, "rtol": 0.4})
else:
atol_rtol.update({"atol": 0.02, "rtol": 0.02})
input, grid = torch.randn(n, c, h_in, w_in), torch.randn(n, h_out, w_out, 2)
class GridSampleModule(torch.nn.Module):
def __init__(self, mode, padding_mode, align_corners) -> None:
@ -13424,13 +13541,6 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
input, grid, self.mode, self.padding_mode, self.align_corners
)
atol_rtol = {}
if (mode, padding_mode) == ("bicubic", "border"):
if align_corners:
atol_rtol.update({"atol": 0.3, "rtol": 0.4})
else:
atol_rtol.update({"atol": 0.02, "rtol": 0.02})
input, grid = torch.randn(n, c, h_in, w_in), torch.randn(n, h_out, w_out, 2)
self.run_test(
GridSampleModule(mode, padding_mode, align_corners),
(input, grid),
@ -13438,8 +13548,6 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
)
# ONNX Opset 16 GridSample with 5D volumetric input is not supported.
d_in = 2
d_out = 3
volumetric_input_tensor = torch.randn(n, c, d_in, h_in, w_in)
volumetric_grid_tensor = torch.randn(n, d_out, h_out, w_out, 3)
for mode, padding_mode, align_corners in itertools.product(
@ -13457,9 +13565,16 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
False,
),
):
with self.assertRaises(
torch.onnx.errors.OnnxExporterError,
):
if self.opset_version < 20:
with self.assertRaises(
torch.onnx.errors.OnnxExporterError,
):
self.run_test(
GridSampleModule(mode, padding_mode, align_corners),
(volumetric_input_tensor, volumetric_grid_tensor),
**atol_rtol,
)
else:
self.run_test(
GridSampleModule(mode, padding_mode, align_corners),
(volumetric_input_tensor, volumetric_grid_tensor),

View File

@ -256,6 +256,7 @@ class TestUtilityFuns(_BaseTestCase):
self.assertNotEqual(node.kind(), "onnx::Cast")
self.assertEqual(len(list(graph.nodes())), 2)
@skipIfUnsupportedMaxOpsetVersion(17)
def test_constant_fold_reduceL2(self):
class ReduceModule(torch.nn.Module):
def forward(self, x):
@ -273,6 +274,7 @@ class TestUtilityFuns(_BaseTestCase):
for node in graph.nodes():
self.assertNotEqual(node.kind(), "onnx::ReduceL2")
@skipIfUnsupportedMaxOpsetVersion(17)
def test_constant_fold_reduceL1(self):
class NormModule(torch.nn.Module):
def forward(self, x):

View File

@ -916,11 +916,15 @@ void ProcessReduceNode(Node* n) {
size_t rank_0 = input_shape_value_0.value().size();
std::vector<::c10::ShapeSymbol> final_shape;
std::vector<int64_t> axes_vector(rank_0);
if (!n->hasAttributeS("axes")) {
std::iota(axes_vector.begin(), axes_vector.end(), 0);
} else {
if (n->hasAttributeS("axes")) {
axes_vector = n->is(attr::axes);
} else if (n->inputs().size() > 1) {
axes_vector =
ConstantValueMap::GetValueInto1DInt64Vector(n->input(1)->debugName());
} else {
std::iota(axes_vector.begin(), axes_vector.end(), 0);
}
for (auto idx : c10::irange(axes_vector.size())) {
if (axes_vector[idx] < 0) {
axes_vector[idx] += rank_0;

View File

@ -91,7 +91,7 @@ namespace onnx_torch = ::torch::onnx;
namespace onnx = ::ONNX_NAMESPACE;
const static int kInvalidOpsetVersion = -1;
const static int kMainOpsetVersion = 19;
const static int kMainOpsetVersion = 20;
// Based on OP_SET_ID_VERSION_MAP in
// https://github.com/onnx/onnx/blob/master/onnx/helper.py.
constexpr static std::array<int64_t, kMainOpsetVersion + 1>
@ -116,6 +116,7 @@ constexpr static std::array<int64_t, kMainOpsetVersion + 1>
8, // opset 17
8, // opset 18
9, // opset 19
9, // opset 20
};
std::string getNodeStackTraceString(const Node* n) {

View File

@ -24,6 +24,8 @@ from . import ( # usort:skip. Keep the order instead of sorting lexicographical
symbolic_opset16,
symbolic_opset17,
symbolic_opset18,
symbolic_opset19,
symbolic_opset20,
utils,
)
@ -82,6 +84,8 @@ __all__ = [
"symbolic_opset16",
"symbolic_opset17",
"symbolic_opset18",
"symbolic_opset19",
"symbolic_opset20",
# Enums
"ExportTypes",
"OperatorExportTypes",

View File

@ -4,8 +4,8 @@ ONNX_ARCHIVE_MODEL_PROTO_NAME = "__MODEL_PROTO"
ONNX_BASE_OPSET = 9
ONNX_MIN_OPSET = 7
ONNX_MAX_OPSET = 19
ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET = 17
ONNX_MAX_OPSET = 20
ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET = 20
# ONNX_DEFAULT_OPSET generated by tools/onnx/update_default_opset_version.py
ONNX_DEFAULT_OPSET = 17
ONNX_CONSTANT_FOLDING_MIN_OPSET = 9

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import functools
import inspect
import math
import sys
import typing
import warnings
@ -23,27 +24,11 @@ import torch._C._onnx as _C_onnx
from torch import _C
# Monkey-patch graph manipulation methods on Graph, used for the ONNX symbolics
from torch.onnx import _constants, _type_utils, errors
from torch.onnx import _constants, _type_utils, errors, utils
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import _beartype, jit_utils
from torch.types import Number
__all__ = [
"args_have_same_dtype",
"cast_pytorch_to_onnx",
"check_training_mode",
"dequantize_helper",
"is_caffe2_aten_fallback",
"is_complex_value",
"parse_args",
"pytorch_name_to_type",
"quantize_helper",
"quantized_args",
"requantize_bias_helper",
"scalar_name_to_pytorch",
"scalar_type_to_onnx",
"scalar_type_to_pytorch_type",
]
# ---------------------------------------------------------------------------------
# Helper functions
@ -1714,6 +1699,526 @@ def args_have_same_dtype(args):
return has_same_dtype
@_beartype.beartype
def _op_with_optional_float_cast(g: jit_utils.GraphContext, op_name, *args, **kwargs):
"""Some PyTorch operators (e.g., Clip/Min/ReLU/Pad) are super set of ONNX in terms of data types.
This function maximizes the exportability of PyTorch-ONNX by allowing ONNX-unsupported PyTorch
operator data type. For example, `Cast<int>(Clip<float>(Cast<float>(INPUT)))` can be used to mimic
`Clip<int>(INPUT)` (opset version < 12).
Args:
g (torch._C.Graph): graph to write the ONNX representation into.
op_name (str): operator name in ONNX.
*args (tuple): operands to the operator.
**kwargs (dict): attributes of the operator, plus two optional entries: "opset_before" (None by default),
the opset version below which the casting behavior is applied, and "target_float_t"
(torch.onnx.JitScalarType.FLOAT by default), the data type used for the internal operator.
Returns:
Optional[torch._C.Value, Tuple[torch._C.Value, ...]]: output(s) of the operator.
"""
opset_before = kwargs.pop("opset_before", None)
target_float_t = kwargs.pop("target_float_t", _type_utils.JitScalarType.FLOAT)
inputs = list(args)
dtype_0 = _type_utils.JitScalarType.from_value(inputs[0])
require_cast = not _is_fp(inputs[0]) and (
opset_before is None or GLOBALS.export_onnx_opset_version < opset_before
)
if require_cast:
for input in inputs:
if input.isCompleteTensor():
input_scalar_type = _type_utils.JitScalarType.from_value(input)
if input_scalar_type != dtype_0:
raise errors.SymbolicValueError(
f"Inputs of {op_name} must have same dtype."
f"Got {dtype_0.scalar_name()} and {input_scalar_type.scalar_name()}",
input,
)
for i, input in enumerate(inputs):
if input.isCompleteTensor() and not _is_fp(input):
inputs[i] = g.op(
"Cast",
input,
to_i=target_float_t.onnx_type(),
)
self = g.op(op_name, *inputs, **kwargs)
if require_cast:
self = g.op("Cast", self, to_i=dtype_0.onnx_type())
return self
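As a concrete illustration of the docstring above, exporting an integer clamp below opset 12 should produce the Cast -> Clip -> Cast pattern (a sketch; the output file name is illustrative and the exact node layout may vary):

import torch

class ClampInt(torch.nn.Module):
    def forward(self, x):
        return torch.clamp(x, 0, 6)

# ONNX Clip only gained non-float support in opset 12, so the earlier symbolics cast around it.
torch.onnx.export(
    ClampInt(), (torch.arange(10, dtype=torch.int32),), "clamp_int.onnx", opset_version=11
)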
@_beartype.beartype
def _maybe_cast_reduce_op_input(g: jit_utils.GraphContext, self):
scalar_type = _type_utils.JitScalarType.from_value(
self, _type_utils.JitScalarType.UNDEFINED
)
if scalar_type != _type_utils.JitScalarType.UNDEFINED:
# This check only covers traced modules where dtype is present
# pytorch reduce-ops cast all other integral types to int64
if not _is_fp(self) and scalar_type != _type_utils.JitScalarType.INT64:
self = g.op("Cast", self, to_i=_C_onnx.TensorProtoDataType.INT64)
return self
def _apply_params(*args, **kwargs):
"""Returns a decorator that calls the decorated (higher-order) function with the given parameters."""
def _apply(fn):
return fn(*args, **kwargs)
return _apply
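_apply_params simply pre-binds arguments for a higher-order symbolic builder such as _reduce_with_dtype_helper; a self-contained sketch of the same decorator pattern (the demo names below are hypothetical):

def _apply_params_demo(*args, **kwargs):
    def _apply(fn):
        return fn(*args, **kwargs)
    return _apply

@_apply_params_demo("ReduceSum", "sum")
def make_reducer(onnx_op, name):
    # The decorator replaces make_reducer with the value returned by this call.
    return f"{name} -> {onnx_op}"

print(make_reducer)  # prints "sum -> ReduceSum"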
@_beartype.beartype
def _reduce_op_symbolic_helper(onnx_op_name, allow_multi_dim_support=True):
@_beartype.beartype
def symbolic(g, self, dim=None, keepdim=None):
self = _maybe_cast_reduce_op_input(g, self)
if dim is None or dim == tuple():
# Dim can be 0, which will cause (not dim) == True. So we don't want to do
# (not dim)
# all-reduce path
return _handle_reduce_dim_none(g, self, onnx_op_name)
else:
# dim-reduce path
keepdim = _get_const(keepdim, "i", "keepdim")
if g.opset < 18:
desc = "is" if allow_multi_dim_support else "i"
dim = _get_const(dim, desc, "dim")
dim_list = dim if allow_multi_dim_support else [dim]
return g.op(onnx_op_name, self, axes_i=dim_list, keepdims_i=keepdim)
else:
if _is_value(dim):
axes = dim
else:
if allow_multi_dim_support:
axes = g.op(
"Constant", value_t=torch.tensor(dim, dtype=torch.long)
)
else:
axes = g.op(
"Constant", value_t=torch.tensor([dim], dtype=torch.long)
)
return g.op(onnx_op_name, self, axes, keepdims_i=keepdim)
return symbolic
@_beartype.beartype
def _overload_by_arg_count(fn):
@functools.wraps(fn)
@_beartype.beartype
def wrapper(g, *args):
overloads = fn(g, *args)
for overload in overloads:
arg_descriptors = overload._arg_descriptors
if len(arg_descriptors) == len(args):
return overload(g, *args)
return _unimplemented(f"aten::{fn.__name__}", f"with {len(args)} arguments")
return wrapper
@_beartype.beartype
def _reduce_with_dtype_helper(
onnx_op: str, name: str, allow_multi_dim_support: bool = True
):
symbolic = _reduce_op_symbolic_helper(
onnx_op, allow_multi_dim_support=allow_multi_dim_support
)
@_overload_by_arg_count
def reduce(g, *args, **kwargs):
@quantized_args(True)
@parse_args("v", "none")
def reduce_nodim(g, self, dtype):
dtype_onnx = None
if dtype.node().kind() == "onnx::Constant":
dtype = _get_const(dtype, "i", "dtype")
dtype_onnx = _type_utils.JitScalarType(dtype).onnx_type()
self = g.op("Cast", self, to_i=dtype_onnx)
elif dtype.node().kind() != "prim::Constant":
return _unimplemented(name, "dtype", dtype)
result = symbolic(g, self)
if dtype_onnx is not None:
result_dtype_onnx = _type_utils.JitScalarType.from_value(
result
).onnx_type()
if result_dtype_onnx != dtype_onnx:
result = g.op("Cast", result, to_i=dtype_onnx)
return result
dim_desc = "is" if allow_multi_dim_support else "i"
@quantized_args(True)
@parse_args("v", dim_desc, "i", "none") # type: ignore[arg-type]
def reduce_dim(g, self, dim, keepdim, dtype):
dtype_onnx = None
if dtype.node().kind() == "onnx::Constant":
dtype = _get_const(dtype, "i", "dtype")
dtype_onnx = _type_utils.JitScalarType(dtype).onnx_type()
self = g.op("Cast", self, to_i=dtype_onnx)
elif dtype.node().kind() != "prim::Constant":
return _unimplemented(name, "dtype", dtype)
result = symbolic(g, self, dim, keepdim)
if dtype_onnx is not None:
result_dtype_onnx = _type_utils.JitScalarType.from_value(
result
).onnx_type()
if result_dtype_onnx != dtype_onnx:
result = g.op("Cast", result, to_i=dtype_onnx)
return result
return reduce_nodim, reduce_dim
return reduce
@_beartype.beartype
def _max_helper(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
# torch.max(input)
if dim_or_y is None and keepdim is None:
return g.op("ReduceMax", self, keepdims_i=0)
# torch.max(input, other)
if keepdim is None:
return _op_with_optional_float_cast(g, "Max", self, dim_or_y, opset_before=12)
# torch.max(input, dim, keepdim)
else:
keepdim = _get_const(keepdim, "i", "keepdim")
dim = _get_const(dim_or_y, "i", "dim")
if g.opset < 18:
max = g.op("ReduceMax", self, axes_i=[dim], keepdims_i=keepdim)
else:
axes = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long))
max = g.op("ReduceMax", self, axes, keepdims_i=keepdim)
indices = g.op("ArgMax", self, axis_i=dim, keepdims_i=keepdim)
return max, indices
@_beartype.beartype
def _min_helper(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
# torch.min(input)
if dim_or_y is None and keepdim is None:
return g.op("ReduceMin", self, keepdims_i=0)
# torch.min(input, other)
if keepdim is None:
return _op_with_optional_float_cast(g, "Min", self, dim_or_y, opset_before=12)
# torch.min(input, dim, keepdim)
else:
keepdim = _get_const(keepdim, "i", "keepdim")
dim = _get_const(dim_or_y, "i", "dim")
if g.opset < 18:
min = g.op("ReduceMin", self, axes_i=[dim], keepdims_i=keepdim)
else:
axes = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long))
min = g.op("ReduceMin", self, axes, keepdims_i=keepdim)
indices = g.op("ArgMin", self, axis_i=dim, keepdims_i=keepdim)
return min, indices
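The branches above correspond to the three call forms of torch.max and torch.min; a short eager-mode sketch of the overloads and the ONNX patterns they map to:

import torch

x = torch.arange(6.0).reshape(2, 3)
torch.max(x)                        # all-reduce          -> ReduceMax with keepdims=0
torch.max(x, torch.ones_like(x))    # elementwise maximum -> Max (with optional float cast)
torch.max(x, dim=1, keepdim=True)   # reduce over a dim   -> ReduceMax plus ArgMax for indices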
@_beartype.beartype
def _numel_helper(g: jit_utils.GraphContext, self):
shape = g.op("Shape", self)
return g.op("ReduceProd", shape, keepdims_i=0)
@parse_args("v", "is", "i", "i")
@_beartype.beartype
def _var_mean_helper(g: jit_utils.GraphContext, input, dim, correction, keepdim):
if g.opset < 18:
if dim is None:
mean = g.op("ReduceMean", input, keepdims_i=0)
t_mean = mean
num_elements = _numel_helper(g, input)
else:
mean = g.op("ReduceMean", input, axes_i=dim, keepdims_i=keepdim)
t_mean = g.op("ReduceMean", input, axes_i=dim, keepdims_i=1)
redudced_dims = g.op("Shape", input)
# dim could contain one or multiple dimensions
redudced_dims = g.op(
"Gather",
redudced_dims,
g.op("Constant", value_t=torch.tensor(dim)),
axis_i=0,
)
num_elements = g.op("ReduceProd", redudced_dims, keepdims_i=0)
sub_v = g.op("Sub", input, t_mean)
sqr_sub = g.op("Mul", sub_v, sub_v)
keepdim_mean = 0 if dim is None else keepdim
var = g.op("ReduceMean", sqr_sub, axes_i=dim, keepdims_i=keepdim_mean)
# Correct the bias in the variance calculation by dividing by (N - correction) instead of N
if correction is None:
correction = 1
if correction != 0:
num_elements = g.op(
"Cast", num_elements, to_i=_C_onnx.TensorProtoDataType.FLOAT
)
one = g.op("Constant", value_t=torch.tensor(correction, dtype=torch.float))
mul = g.op("Mul", var, num_elements)
var = g.op("Div", mul, g.op("Sub", num_elements, one))
return var, mean
else:
axes = None
if dim is None:
mean = g.op("ReduceMean", input, keepdims_i=0)
t_mean = mean
num_elements = _numel_helper(g, input)
else:
axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long))
mean = g.op("ReduceMean", input, axes, keepdims_i=keepdim)
t_mean = g.op("ReduceMean", input, axes, keepdims_i=1)
redudced_dims = g.op("Shape", input)
# dim could contain one or multiple dimensions
redudced_dims = g.op(
"Gather",
redudced_dims,
g.op("Constant", value_t=torch.tensor(dim)),
axis_i=0,
)
num_elements = g.op("ReduceProd", redudced_dims, keepdims_i=0)
sub_v = g.op("Sub", input, t_mean)
sqr_sub = g.op("Mul", sub_v, sub_v)
keepdim_mean = 0 if dim is None else keepdim
if axes is None:
var = g.op("ReduceMean", sqr_sub, keepdims_i=keepdim_mean)
else:
var = g.op("ReduceMean", sqr_sub, axes, keepdims_i=keepdim_mean)
# Correct the bias in the variance calculation by dividing by (N - correction) instead of N
if correction is None:
correction = 1
if correction != 0:
num_elements = g.op(
"Cast", num_elements, to_i=_C_onnx.TensorProtoDataType.FLOAT
)
one = g.op("Constant", value_t=torch.tensor(correction, dtype=torch.float))
mul = g.op("Mul", var, num_elements)
var = g.op("Div", mul, g.op("Sub", num_elements, one))
return var, mean
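A quick numerical check of the correction applied in both branches above (the biased variance is rescaled by N / (N - correction)):

import torch

x = torch.randn(5, 7)
n = x.numel()
biased_var = ((x - x.mean()) ** 2).mean()   # ReduceMean of squared deviations
corrected = biased_var * n / (n - 1)        # divide by (N - correction), with correction=1
assert torch.allclose(corrected, torch.var(x, unbiased=True))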
@_beartype.beartype
def _embedding_bag_helper(
g: jit_utils.GraphContext,
embedding_matrix,
indices,
offsets,
scale_grad_by_freq,
mode,
sparse,
per_sample_weights,
include_last_offset,
padding_idx,
):
if scale_grad_by_freq and GLOBALS.export_training:
return _onnx_unsupported(
"embedding_bag with scale_grad_by_freq for training mode"
)
if padding_idx is not None and padding_idx >= 0:
raise RuntimeError("embedding_bag with padding_idx")
loop_condition = g.op("Constant", value_t=torch.tensor(1))
loop_condition = g.op("Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL)
zero = g.op("Constant", value_t=torch.tensor([0]))
indices_len = _unsqueeze_helper(
g,
_size_helper(g, indices, g.op("Constant", value_t=torch.tensor(0))),
[0],
)
if not include_last_offset:
offsets = [offsets, indices_len]
offsets = g.op("Concat", *offsets, axis_i=0)
# Offsets holds the starting index position of each bag. So we create a list of the indices slices (determined by
# offsets) and gather those indices in indices_row. Then we use this subset of indices to gather from embeddings.
# The embeddings output is a loop scan output, so we can avoid creating a sequence and inserting elements in.
offsets_starts = _slice_helper(
g, offsets, axes=[0], starts=[0], ends=[sys.maxsize], steps=[1]
)
offsets_ends = _slice_helper(
g, offsets, axes=[0], starts=[1], ends=[sys.maxsize], steps=[1]
)
loop_len = _size_helper(g, offsets_ends, g.op("Constant", value_t=torch.tensor(0)))
loop, (loop_context,), _ = jit_utils.add_op_with_blocks(
g, "Loop", loop_len, loop_condition, n_blocks=1
)
loop_block = loop_context.block
# FIXME(justinchuby): We need to handle what happens when we call b.op on a node return
block_input_iter = utils._add_input_to_block(loop_block)
cond = utils._add_input_to_block(loop_block)
indices_start = loop_context.op(
"Gather", offsets_starts, block_input_iter, axis_i=0
)
indices_end = loop_context.op("Gather", offsets_ends, block_input_iter, axis_i=0)
indices_start = _unsqueeze_helper(loop_context, indices_start, [0])
indices_end = _unsqueeze_helper(loop_context, indices_end, [0])
indices_row = loop_context.op("Slice", indices, indices_start, indices_end, zero)
embeddings = loop_context.op("Gather", embedding_matrix, indices_row, axis_i=0)
if not _is_none(per_sample_weights):
per_sample_weights_row = loop_context.op(
"Slice", per_sample_weights, indices_start, indices_end, zero
)
per_sample_weights_row = _unsqueeze_helper(
loop_context, per_sample_weights_row, [1]
)
embeddings = loop_context.op("Mul", embeddings, per_sample_weights_row)
if mode == 0:
embeddings = _reducesum_helper(
loop_context, embeddings, axes_i=[0], keepdims_i=0
)
elif mode == 1:
if loop_context.opset < 18:
embeddings = loop_context.op(
"ReduceMean", embeddings, axes_i=[0], keepdims_i=0
)
else:
axes = loop_context.op(
"Constant", value_t=torch.tensor([0], dtype=torch.long)
)
embeddings = loop_context.op("ReduceMean", embeddings, axes, keepdims_i=0)
else:
if loop_context.opset < 18:
embeddings = loop_context.op(
"ReduceMax", embeddings, axes_i=[0], keepdims_i=0
)
else:
axes = loop_context.op(
"Constant", value_t=torch.tensor([0], dtype=torch.long)
)
embeddings = loop_context.op("ReduceMax", embeddings, axes, keepdims_i=0)
cond_out = loop_context.op(
"Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL
)
utils._add_output_to_block(loop_block, cond_out)
utils._add_output_to_block(loop_block, embeddings)
# aten::embedding_bag returns a tuple of 4 elements: output, offset2bag, bag_size, max_indices.
# But the last three outputs are not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag.
return loop.node().output(), None, None, None
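To make the comment about offsets concrete, a small eager-mode sketch of how bag boundaries are derived when include_last_offset is False (plain tensors, outside the exporter):

import torch

indices = torch.tensor([9, 3, 4, 7, 1])
offsets = torch.tensor([0, 2])  # two bags
ends = torch.cat([offsets[1:], torch.tensor([indices.numel()])])
bags = [indices[s:e] for s, e in zip(offsets.tolist(), ends.tolist())]
# bags == [tensor([9, 3]), tensor([4, 7, 1])]; each bag is then gathered from the embedding matrix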
@_beartype.beartype
def _linalg_vector_norm_helper(
g: jit_utils.GraphContext,
self: torch._C.Value,
ord: float,
dim: Optional[Sequence[int]],
keepdim: bool,
dtype: torch._C.Value,
):
axes = None
# Conditions based on https://pytorch.org/docs/stable/generated/torch.linalg.vector_norm.html
if _is_none(dim):
self = _reshape_helper(g, self, [-1])
keepdim = False
elif g.opset >= 18:
axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long))
if ord == math.inf:
if g.opset < 18:
result = g.op(
"ReduceMax", g.op("Abs", self), axes_i=dim, keepdims_i=keepdim
)
else:
if axes is None:
result = g.op("ReduceMax", g.op("Abs", self), keepdims_i=keepdim)
else:
result = g.op("ReduceMax", g.op("Abs", self), axes, keepdims_i=keepdim)
elif ord == -math.inf:
if g.opset < 18:
result = g.op(
"ReduceMin", g.op("Abs", self), axes_i=dim, keepdims_i=keepdim
)
else:
if axes is None:
result = g.op("ReduceMin", g.op("Abs", self), keepdims_i=keepdim)
else:
result = g.op("ReduceMin", g.op("Abs", self), axes, keepdims_i=keepdim)
elif ord == 0:
if g.opset < 11:
return _onnx_opset_unsupported_detailed(
"linalg_vector_norm", 9, 11, "ord=0 not supported", self
)
else:
if dim is None:
self = _reshape_helper(
g,
self,
g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)),
)
keepdim = False
cond_op = g.op(
"Not",
g.op("Equal", self, g.op("Constant", value_t=torch.LongTensor([0]))),
)
cond_op = g.op(
"Cast",
cond_op,
to_i=_type_utils.JitScalarType.from_value(self).onnx_type(),
)
return _reducesum_helper(g, cond_op, axes_i=dim, keepdims_i=keepdim)
elif ord == 1:
if g.opset < 18:
result = _reduce_op_symbolic_helper("ReduceL1")(
g, self, dim=dim, keepdim=keepdim
)
else:
if axes is None:
result = _reduce_op_symbolic_helper("ReduceL1")(
g, self, keepdim=keepdim
)
else:
result = _reduce_op_symbolic_helper("ReduceL1")(
g, self, axes, keepdim=keepdim
)
elif ord == 2:
if g.opset < 18:
result = _reduce_op_symbolic_helper("ReduceL2")(
g, self, dim=dim, keepdim=keepdim
)
else:
if axes is None:
result = _reduce_op_symbolic_helper("ReduceL2")(
g, self, keepdim=keepdim
)
else:
result = _reduce_op_symbolic_helper("ReduceL2")(
g, self, axes, keepdim=keepdim
)
else:
ord_op = g.op("Constant", value_t=torch.tensor(ord, dtype=torch.float32))
result = _reducesum_helper(
g, g.op("Pow", g.op("Abs", self), ord_op), axes_i=dim, keepdims_i=keepdim
)
result = g.op(
"Pow",
result,
g.op(
"Div",
g.op("Constant", value_t=torch.tensor(1, dtype=torch.float32)),
ord_op,
),
)
if not _is_none(dtype):
dtype = _get_const(dtype, "i", "dtype")
result = g.op("Cast", result, to_i=_type_utils.JitScalarType(dtype).onnx_type()) # type: ignore[arg-type]
return result
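A brief eager-mode reminder of the ord cases handled above, matching the torch.linalg.vector_norm documentation referenced in the comment:

import torch

x = torch.tensor([[3.0, -4.0], [0.0, 2.0]])
torch.linalg.vector_norm(x, ord=float("inf"))  # max(|x|)           -> ReduceMax over Abs
torch.linalg.vector_norm(x, ord=0)             # count of non-zeros -> ReduceSum of (x != 0)
torch.linalg.vector_norm(x, ord=1)             # sum(|x|)           -> ReduceL1
torch.linalg.vector_norm(x, ord=2)             # sqrt(sum(x**2))    -> ReduceL2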
# Deprecated. Internally use _type_utils.ScalarType
# TODO: remove these once we support Type's in the JIT IR and we can once again
# use the unified toType operator

View File

@ -70,15 +70,6 @@ __all__ = [
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=10)
def _apply_params(*args, **kwargs):
"""Returns a decorator that calls the decorated (higher-order) function with the given parameters."""
def _apply(fn):
return fn(*args, **kwargs)
return _apply
@_onnx_symbolic("aten::div")
@_beartype.beartype
def div(g: jit_utils.GraphContext, self, other, *args):
@ -276,20 +267,20 @@ def _aten_max_pool_with_indices_onnx(
@_onnx_symbolic(
"aten::max_pool1d",
decorate=[_apply_params("max_pool1d", 1, return_indices=False)],
decorate=[symbolic_helper._apply_params("max_pool1d", 1, return_indices=False)],
)
@_onnx_symbolic(
"aten::max_pool2d",
decorate=[_apply_params("max_pool2d", 2, return_indices=False)],
decorate=[symbolic_helper._apply_params("max_pool2d", 2, return_indices=False)],
)
@_onnx_symbolic(
"aten::max_pool3d",
decorate=[_apply_params("max_pool3d", 3, return_indices=False)],
decorate=[symbolic_helper._apply_params("max_pool3d", 3, return_indices=False)],
)
@_onnx_symbolic(
"aten::max_pool1d_with_indices",
decorate=[
_apply_params(
symbolic_helper._apply_params(
"max_pool1d_with_indices",
1,
return_indices=True,
@ -299,7 +290,7 @@ def _aten_max_pool_with_indices_onnx(
@_onnx_symbolic(
"aten::max_pool2d_with_indices",
decorate=[
_apply_params(
symbolic_helper._apply_params(
"max_pool2d_with_indices",
2,
return_indices=True,
@ -309,7 +300,7 @@ def _aten_max_pool_with_indices_onnx(
@_onnx_symbolic(
"aten::max_pool3d_with_indices",
decorate=[
_apply_params(
symbolic_helper._apply_params(
"max_pool3d_with_indices",
3,
return_indices=True,
@ -397,15 +388,15 @@ def _adjust_attributes_of_avg_pool(
@_onnx_symbolic(
"aten::avg_pool1d",
decorate=[_apply_params("avg_pool1d", 1)],
decorate=[symbolic_helper._apply_params("avg_pool1d", 1)],
)
@_onnx_symbolic(
"aten::avg_pool2d",
decorate=[_apply_params("avg_pool2d", 2)],
decorate=[symbolic_helper._apply_params("avg_pool2d", 2)],
)
@_onnx_symbolic(
"aten::avg_pool3d",
decorate=[_apply_params("avg_pool3d", 3)],
decorate=[symbolic_helper._apply_params("avg_pool3d", 3)],
)
@_beartype.beartype
def _avg_pool(name, expand_size):
@ -443,27 +434,27 @@ def _avg_pool(name, expand_size):
@_onnx_symbolic(
"aten::upsample_nearest1d",
decorate=[_apply_params("upsample_nearest1d", 3, "nearest")],
decorate=[symbolic_helper._apply_params("upsample_nearest1d", 3, "nearest")],
)
@_onnx_symbolic(
"aten::upsample_nearest2d",
decorate=[_apply_params("upsample_nearest2d", 4, "nearest")],
decorate=[symbolic_helper._apply_params("upsample_nearest2d", 4, "nearest")],
)
@_onnx_symbolic(
"aten::upsample_nearest3d",
decorate=[_apply_params("upsample_nearest3d", 5, "nearest")],
decorate=[symbolic_helper._apply_params("upsample_nearest3d", 5, "nearest")],
)
@_onnx_symbolic(
"aten::upsample_linear1d",
decorate=[_apply_params("upsample_linear1d", 3, "linear")],
decorate=[symbolic_helper._apply_params("upsample_linear1d", 3, "linear")],
)
@_onnx_symbolic(
"aten::upsample_bilinear2d",
decorate=[_apply_params("upsample_bilinear2d", 4, "linear")],
decorate=[symbolic_helper._apply_params("upsample_bilinear2d", 4, "linear")],
)
@_onnx_symbolic(
"aten::upsample_trilinear3d",
decorate=[_apply_params("upsample_trilinear3d", 5, "linear")],
decorate=[symbolic_helper._apply_params("upsample_trilinear3d", 5, "linear")],
)
@_beartype.beartype
def _interpolate(name, dim, interpolate_mode):

View File

@ -17,7 +17,6 @@ from torch.onnx import (
symbolic_opset9 as opset9,
utils,
)
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import _beartype, jit_utils, registration
# EDITING THIS FILE? READ THIS FIRST!
@ -86,15 +85,6 @@ __all__ = [
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=11)
def _apply_params(*args, **kwargs):
"""Returns a decorator that calls the decorated (higher-order) function with the given parameters."""
def _apply(fn):
return fn(*args, **kwargs)
return _apply
@_onnx_symbolic("aten::hardtanh")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "f", "f")
@ -111,7 +101,7 @@ def hardtanh(g: jit_utils.GraphContext, self: _C.Value, min_val: float, max_val:
"Constant",
value_t=torch.tensor(max_val, dtype=scalar_type.dtype()),
)
return opset9._op_with_optional_float_cast(
return symbolic_helper._op_with_optional_float_cast(
g, "Clip", self, min_val, max_val, opset_before=12
)
@ -146,7 +136,7 @@ def clamp(g: jit_utils.GraphContext, self, min, max):
symbolic_helper._get_tensor_rank(min) == 0
and symbolic_helper._get_tensor_rank(max) == 0
):
return opset9._op_with_optional_float_cast(
return symbolic_helper._op_with_optional_float_cast(
g, "Clip", self, min, max, opset_before=12
)
else:
@ -160,11 +150,13 @@ def clamp_min(g: jit_utils.GraphContext, self, min):
min = g.op("Cast", min, to_i=_type_utils.JitScalarType.from_value(self).onnx_type())
if symbolic_helper._get_tensor_rank(min) == 0:
max = opset9.unused(g)
return opset9._op_with_optional_float_cast(
return symbolic_helper._op_with_optional_float_cast(
g, "Clip", self, min, max, opset_before=12
)
else:
return opset9._op_with_optional_float_cast(g, "Max", self, min, opset_before=12)
return symbolic_helper._op_with_optional_float_cast(
g, "Max", self, min, opset_before=12
)
@_onnx_symbolic("aten::clamp_max")
@ -174,11 +166,13 @@ def clamp_max(g: jit_utils.GraphContext, self, max):
max = g.op("Cast", max, to_i=_type_utils.JitScalarType.from_value(self).onnx_type())
if symbolic_helper._get_tensor_rank(max) == 0:
min = opset9.unused(g)
return opset9._op_with_optional_float_cast(
return symbolic_helper._op_with_optional_float_cast(
g, "Clip", self, min, max, opset_before=12
)
else:
return opset9._op_with_optional_float_cast(g, "Min", self, max, opset_before=12)
return symbolic_helper._op_with_optional_float_cast(
g, "Min", self, max, opset_before=12
)
@_onnx_symbolic("aten::relu6")
@ -348,31 +342,31 @@ def pixel_shuffle(g: jit_utils.GraphContext, self, upscale_factor):
@_onnx_symbolic(
"aten::upsample_nearest1d",
decorate=[_apply_params("upsample_nearest1d", 3, "nearest")],
decorate=[symbolic_helper._apply_params("upsample_nearest1d", 3, "nearest")],
)
@_onnx_symbolic(
"aten::upsample_nearest2d",
decorate=[_apply_params("upsample_nearest2d", 4, "nearest")],
decorate=[symbolic_helper._apply_params("upsample_nearest2d", 4, "nearest")],
)
@_onnx_symbolic(
"aten::upsample_nearest3d",
decorate=[_apply_params("upsample_nearest3d", 5, "nearest")],
decorate=[symbolic_helper._apply_params("upsample_nearest3d", 5, "nearest")],
)
@_onnx_symbolic(
"aten::upsample_linear1d",
decorate=[_apply_params("upsample_linear1d", 3, "linear")],
decorate=[symbolic_helper._apply_params("upsample_linear1d", 3, "linear")],
)
@_onnx_symbolic(
"aten::upsample_bilinear2d",
decorate=[_apply_params("upsample_bilinear2d", 4, "linear")],
decorate=[symbolic_helper._apply_params("upsample_bilinear2d", 4, "linear")],
)
@_onnx_symbolic(
"aten::upsample_trilinear3d",
decorate=[_apply_params("upsample_trilinear3d", 5, "linear")],
decorate=[symbolic_helper._apply_params("upsample_trilinear3d", 5, "linear")],
)
@_onnx_symbolic(
"aten::upsample_bicubic2d",
decorate=[_apply_params("upsample_bicubic2d", 4, "cubic")],
decorate=[symbolic_helper._apply_params("upsample_bicubic2d", 4, "cubic")],
)
@_beartype.beartype
def _interpolate(name: str, dim: int, interpolate_mode: str):
@ -1281,26 +1275,7 @@ def linalg_vector_norm(
keepdim: bool,
dtype,
):
if ord == 0:
if dim is None:
self = symbolic_helper._reshape_helper(
g, self, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))
)
keepdim = False
cond_op = g.op(
"Not", g.op("Equal", self, g.op("Constant", value_t=torch.LongTensor([0])))
)
cond_op = g.op(
"Cast",
cond_op,
to_i=_type_utils.JitScalarType.from_value(self).onnx_type(),
)
return symbolic_helper._reducesum_helper(
g, cond_op, axes_i=dim, keepdims_i=keepdim
)
else:
return opset9.linalg_vector_norm(g, self, ord, dim, keepdim, dtype)
return symbolic_helper._linalg_vector_norm_helper(g, self, ord, dim, keepdim, dtype)
@_onnx_symbolic("aten::embedding_bag")
@ -1318,86 +1293,18 @@ def embedding_bag(
include_last_offset,
padding_idx,
):
if scale_grad_by_freq and GLOBALS.export_training:
return symbolic_helper._onnx_unsupported(
"embedding_bag with scale_grad_by_freq for training mode"
)
if padding_idx is not None and padding_idx >= 0:
raise RuntimeError("embedding_bag with padding_idx")
loop_condition = g.op("Constant", value_t=torch.tensor(1))
loop_condition = g.op("Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL)
zero = g.op("Constant", value_t=torch.tensor([0]))
indices_len = symbolic_helper._unsqueeze_helper(
return symbolic_helper._embedding_bag_helper(
g,
symbolic_helper._size_helper(
g, indices, g.op("Constant", value_t=torch.tensor(0))
),
[0],
embedding_matrix,
indices,
offsets,
scale_grad_by_freq,
mode,
sparse,
per_sample_weights,
include_last_offset,
padding_idx,
)
if not include_last_offset:
offsets = [offsets, indices_len]
offsets = g.op("Concat", *offsets, axis_i=0)
# Offsets holds the starting index position of each bag. So we create a list of the indices slices (determined by
# offsets) and gather those indices in indices_row. Then we use this subset of indices to gather from embeddings.
# The embeddings output is a loop scan output, so we can avoid creating a sequence and inserting elements in.
offsets_starts = symbolic_helper._slice_helper(
g, offsets, axes=[0], starts=[0], ends=[sys.maxsize], steps=[1]
)
offsets_ends = symbolic_helper._slice_helper(
g, offsets, axes=[0], starts=[1], ends=[sys.maxsize], steps=[1]
)
loop_len = symbolic_helper._size_helper(
g, offsets_ends, g.op("Constant", value_t=torch.tensor(0))
)
loop, (loop_context,), _ = jit_utils.add_op_with_blocks(
g, "Loop", loop_len, loop_condition, n_blocks=1
)
loop_block = loop_context.block
# FIXME(justinchuby): We need to handle what happens when we call b.op on a node return
block_input_iter = utils._add_input_to_block(loop_block)
cond = utils._add_input_to_block(loop_block)
indices_start = loop_context.op(
"Gather", offsets_starts, block_input_iter, axis_i=0
)
indices_end = loop_context.op("Gather", offsets_ends, block_input_iter, axis_i=0)
indices_start = symbolic_helper._unsqueeze_helper(loop_context, indices_start, [0])
indices_end = symbolic_helper._unsqueeze_helper(loop_context, indices_end, [0])
indices_row = loop_context.op("Slice", indices, indices_start, indices_end, zero)
embeddings = loop_context.op("Gather", embedding_matrix, indices_row, axis_i=0)
if not symbolic_helper._is_none(per_sample_weights):
per_sample_weights_row = loop_context.op(
"Slice", per_sample_weights, indices_start, indices_end, zero
)
per_sample_weights_row = symbolic_helper._unsqueeze_helper(
loop_context, per_sample_weights_row, [1]
)
embeddings = loop_context.op("Mul", embeddings, per_sample_weights_row)
if mode == 0:
embeddings = symbolic_helper._reducesum_helper(
loop_context, embeddings, axes_i=[0], keepdims_i=0
)
elif mode == 1:
embeddings = loop_context.op("ReduceMean", embeddings, axes_i=[0], keepdims_i=0)
else:
embeddings = loop_context.op("ReduceMax", embeddings, axes_i=[0], keepdims_i=0)
cond_out = loop_context.op(
"Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL
)
utils._add_output_to_block(loop_block, cond_out)
utils._add_output_to_block(loop_block, embeddings)
# aten::embedding_bag returns a tuple of 4 elements: output, offset2bag, bag_size, max_indices.
# But the last three outputs are not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag.
return loop.node().output(), None, None, None
@_onnx_symbolic("aten::embedding_renorm")

View File

@ -21,15 +21,6 @@ from torch.onnx._internal import _beartype, jit_utils, registration
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=13)
def _apply_params(*args, **kwargs):
"""Returns a decorator that calls the decorated (higher-order) function with the given parameters."""
def _apply(fn):
return fn(*args, **kwargs)
return _apply
@_onnx_symbolic("aten::softmax")
@symbolic_helper.parse_args("v", "i", "none")
@_beartype.beartype
@ -412,7 +403,7 @@ def fake_quantize_per_tensor_affine(
def _reduce_op_symbolic(onnx_op_name):
@_beartype.beartype
def symbolic(g, self, dim=None, keepdim=None):
self = opset9._maybe_cast_reduce_op_input(g, self)
self = symbolic_helper._maybe_cast_reduce_op_input(g, self)
if dim is None:
# all-reduce path
return symbolic_helper._handle_reduce_dim_none(g, self, onnx_op_name)
@ -425,13 +416,13 @@ def _reduce_op_symbolic(onnx_op_name):
@_onnx_symbolic(
"aten::sum",
decorate=[_apply_params("ReduceSum", "sum")],
decorate=[symbolic_helper._apply_params("ReduceSum", "sum")],
)
@_beartype.beartype
def _reduce_with_dtype(onnx_op, name):
symbolic = _reduce_op_symbolic(onnx_op)
@opset9.overload_by_arg_count
@symbolic_helper._overload_by_arg_count
@_beartype.beartype
def reduce(g, *args, **kwargs):
@symbolic_helper.parse_args("v", "none")

View File

@ -14,19 +14,23 @@ New operators:
Resize
ScatterElements
ScatterND
Split
"""
import functools
from typing import Sequence
from typing import List, Optional, Sequence, Tuple
import torch
from torch import _C
from torch.onnx import symbolic_helper
from torch.onnx._internal import _beartype, registration
from torch.onnx import symbolic_helper, symbolic_opset9 as opset9
from torch.onnx._internal import _beartype, jit_utils, registration
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py
__all__ = ["col2im"]
__all__ = [
"col2im",
]
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=18)
@ -68,3 +72,190 @@ def col2im(
pads_i=adjusted_padding,
strides_i=stride,
)
@_onnx_symbolic(
"aten::mean", decorate=[symbolic_helper._apply_params("ReduceMean", "mean")]
)
@_onnx_symbolic(
"aten::prod",
decorate=[
symbolic_helper._apply_params(
"ReduceProd", "prod", allow_multi_dim_support=False
)
],
)
@_beartype.beartype
def _reduce_with_dtype(onnx_op: str, name: str, allow_multi_dim_support: bool = True):
return symbolic_helper._reduce_with_dtype_helper(
onnx_op, name, allow_multi_dim_support
)
@_onnx_symbolic("aten::native_layer_norm")
@symbolic_helper.quantized_args(True, False, False, False)
@symbolic_helper.parse_args("v", "is", "v", "v", "f")
@_beartype.beartype
def _native_layer_norm(
g: jit_utils.GraphContext,
input: _C.Value,
normalized_shape: Sequence[int],
weight: _C.Value,
bias: _C.Value,
eps: float,
) -> Tuple[_C.Value, _C.Value, _C.Value]:
return opset9.native_layer_norm(g, input, normalized_shape, weight, bias, eps)
@_onnx_symbolic("aten::glu")
@symbolic_helper.parse_args("v", "i")
@_beartype.beartype
def _glu(g: jit_utils.GraphContext, input, dim):
dim_size = symbolic_helper._get_tensor_dim_size(input, dim)
if dim_size is not None:
assert dim_size % 2 == 0
first, second = g.op("Split", input, axis_i=dim, num_outputs_i=2, outputs=2)
return g.op("Mul", first, g.op("Sigmoid", second))
@_onnx_symbolic("aten::max")
# torch.max (same for torch.min) actually has two interfaces smashed together:
# torch.max(x, dim, keepdim) and torch.max(x, y)
# TODO(justinchuby): Support multiple quantized args in output
@_beartype.beartype
def max(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
return symbolic_helper._max_helper(g, self, dim_or_y, keepdim)
@_onnx_symbolic("aten::maximum")
@symbolic_helper.quantized_args(True, True)
@_beartype.beartype
def maximum(g: jit_utils.GraphContext, input, other):
return max(g, input, dim_or_y=other)
@_onnx_symbolic("aten::min")
# TODO(justinchuby): Support multiple quantized args in output
@_beartype.beartype
def min(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
return symbolic_helper._min_helper(g, self, dim_or_y, keepdim)
@_onnx_symbolic("aten::minimum")
@symbolic_helper.quantized_args(True, True)
@_beartype.beartype
def minimum(g: jit_utils.GraphContext, input, other):
return min(g, input, dim_or_y=other)
@_onnx_symbolic("aten::amax")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "is", "i")
@_beartype.beartype
def amax(g: jit_utils.GraphContext, self, dim, keepdim):
axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long))
return g.op("ReduceMax", self, axes, keepdims_i=keepdim)
@_onnx_symbolic("aten::amin")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "is", "i")
@_beartype.beartype
def amin(g: jit_utils.GraphContext, self, dim, keepdim):
axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long))
return g.op("ReduceMin", self, axes, keepdims_i=keepdim)
@_onnx_symbolic("aten::aminmax")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "v", "i")
@_beartype.beartype
def aminmax(g: jit_utils.GraphContext, self, dim, keepdim):
if not symbolic_helper._is_none(dim):
dim = symbolic_helper._get_const(dim, "i", "dim")
axes = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long))
return g.op("ReduceMin", self, axes, keepdims_i=keepdim), g.op(
"ReduceMax", self, axes, keepdims_i=keepdim
)
else:
return g.op("ReduceMin", self, keepdims_i=keepdim), g.op(
"ReduceMax", self, keepdims_i=keepdim
)
@_onnx_symbolic("aten::var_mean")
@_beartype.beartype
def _var_mean(g: jit_utils.GraphContext, input, *args):
if len(args) == 1:
return symbolic_helper._var_mean_helper(g, input, None, args[0], None)
else:
return symbolic_helper._var_mean_helper(g, input, *args)
@_onnx_symbolic("aten::logsumexp")
@symbolic_helper.parse_args("v", "is", "i")
@_beartype.beartype
def _logsumexp(g: jit_utils.GraphContext, input, dim, keepdim):
if dim is None:
return g.op("ReduceLogSumExp", input, keepdims_i=0)
else:
axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long))
return g.op("ReduceLogSumExp", input, axes, keepdims_i=keepdim)
@_onnx_symbolic("aten::linalg_matrix_norm")
@symbolic_helper.parse_args("v", "v", "is", "b", "v")
@_beartype.beartype
def _linalg_matrix_norm(
g: jit_utils.GraphContext,
self: torch._C.Value,
ord: torch._C.Value,
dim: List[int],
keepdim: bool,
dtype: torch._C.Value,
):
return opset9.linalg_matrix_norm(g, self, ord, dim, keepdim, dtype)
@_onnx_symbolic("aten::embedding_bag")
@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i")
@_beartype.beartype
def embedding_bag(
g: jit_utils.GraphContext,
embedding_matrix,
indices,
offsets,
scale_grad_by_freq,
mode,
sparse,
per_sample_weights,
include_last_offset,
padding_idx,
):
return symbolic_helper._embedding_bag_helper(
g,
embedding_matrix,
indices,
offsets,
scale_grad_by_freq,
mode,
sparse,
per_sample_weights,
include_last_offset,
padding_idx,
)
@_onnx_symbolic("aten::linalg_vector_norm")
@symbolic_helper.parse_args("v", "f", "is", "b", "v")
@_beartype.beartype
def linalg_vector_norm(
g: jit_utils.GraphContext,
self: torch._C.Value,
ord: float,
dim: Optional[Sequence[int]],
keepdim: bool,
dtype: torch._C.Value,
):
return symbolic_helper._linalg_vector_norm_helper(g, self, ord, dim, keepdim, dtype)

View File

@ -0,0 +1,32 @@
"""This file exports ONNX ops for opset 19.
Note [ONNX Operators that are added/updated in opset 19]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-19-of-the-default-onnx-operator-set
New operators:
AveragePool
Cast
CastLike
Constant
DeformConv
DequantizeLinear
Equal
Identity
If
Loop
Pad
QuantizeLinear
Reshape
Resize
Scan
Shape
Size
"""
from typing import List
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py
__all__: List[str] = []

View File

@ -0,0 +1,85 @@
"""This file exports ONNX ops for opset 20.
Note [ONNX Operators that are added/updated in opset 20]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-20-of-the-default-onnx-operator-set
New operators:
AffineGrid
ConstantOfShape
DFT
Gelu
GridSample
ImageDecoder
IsInf
IsNaN
ReduceMax
ReduceMin
RegexFullMatch
StringConcat
StringSplit
"""
import functools
import torch.nn.functional as F
from torch import _C
from torch.onnx import symbolic_helper
from torch.onnx._internal import _beartype, jit_utils, registration
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py
__all__ = ["_grid_sampler", "_affine_grid_generator"]
def convert_grid_sample_mode(mode_s):
return (
"linear" if mode_s == "bilinear" else "cubic" if mode_s == "bicubic" else mode_s
)
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=20)
@_onnx_symbolic("aten::grid_sampler")
@symbolic_helper.parse_args("v", "v", "i", "i", "b")
@_beartype.beartype
def _grid_sampler(
g: jit_utils.GraphContext,
input: _C.Value,
grid: _C.Value,
mode_enum: int,
padding_mode_enum: int,
align_corners: bool,
):
mode_s = {v: k for k, v in F.GRID_SAMPLE_INTERPOLATION_MODES.items()}[mode_enum] # type: ignore[call-arg, index]
# mode string changes at https://onnx.ai/onnx/operators/text_diff_GridSample_16_20.html
mode_s = convert_grid_sample_mode(mode_s)
padding_mode_s = {v: k for k, v in F.GRID_SAMPLE_PADDING_MODES.items()}[padding_mode_enum] # type: ignore[call-arg, index]
return g.op(
"GridSample",
input,
grid,
align_corners_i=int(align_corners),
mode_s=mode_s,
padding_mode_s=padding_mode_s,
)
@_onnx_symbolic("aten::affine_grid_generator")
@symbolic_helper.parse_args("v", "v", "b")
@_beartype.beartype
def _affine_grid_generator(
g: jit_utils.GraphContext,
theta: _C.Value,
size: _C.Value,
align_corners: bool,
):
return g.op(
"AffineGrid",
theta,
size,
align_corners_i=int(align_corners),
)
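A quick check of the opset-16 to opset-20 GridSample mode renaming implemented by convert_grid_sample_mode above (importing the private torch.onnx.symbolic_opset20 module is an assumption and may change between releases):

from torch.onnx import symbolic_opset20

assert symbolic_opset20.convert_grid_sample_mode("bilinear") == "linear"
assert symbolic_opset20.convert_grid_sample_mode("bicubic") == "cubic"
assert symbolic_opset20.convert_grid_sample_mode("nearest") == "nearest"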

View File

@ -64,38 +64,29 @@ for block_listed_op in block_listed_operators:
)
def _apply_params(*args, **kwargs):
"""Returns a decorator that calls the decorated (higher-order) function with the given parameters."""
def _apply(fn):
return fn(*args, **kwargs)
return _apply
@_onnx_symbolic(
"aten::upsample_nearest1d",
decorate=[_apply_params("upsample_nearest1d", 3, "nearest")],
decorate=[symbolic_helper._apply_params("upsample_nearest1d", 3, "nearest")],
)
@_onnx_symbolic(
"aten::upsample_nearest2d",
decorate=[_apply_params("upsample_nearest2d", 4, "nearest")],
decorate=[symbolic_helper._apply_params("upsample_nearest2d", 4, "nearest")],
)
@_onnx_symbolic(
"aten::upsample_nearest3d",
decorate=[_apply_params("upsample_nearest3d", 5, "nearest")],
decorate=[symbolic_helper._apply_params("upsample_nearest3d", 5, "nearest")],
)
@_onnx_symbolic(
"aten::upsample_linear1d",
decorate=[_apply_params("upsample_linear1d", 3, "linear")],
decorate=[symbolic_helper._apply_params("upsample_linear1d", 3, "linear")],
)
@_onnx_symbolic(
"aten::upsample_bilinear2d",
decorate=[_apply_params("upsample_bilinear2d", 4, "linear")],
decorate=[symbolic_helper._apply_params("upsample_bilinear2d", 4, "linear")],
)
@_onnx_symbolic(
"aten::upsample_trilinear3d",
decorate=[_apply_params("upsample_trilinear3d", 5, "linear")],
decorate=[symbolic_helper._apply_params("upsample_trilinear3d", 5, "linear")],
)
def _interpolate(name, dim, interpolate_mode):
def symbolic_fn(g, input, output_size, *args):

View File

@ -185,7 +185,6 @@ __all__ = [
"ones_like",
"ones",
"onnx_placeholder",
"overload_by_arg_count",
"pad",
"pairwise_distance",
"permute",
@ -293,15 +292,6 @@ __all__ = [
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=9)
def _apply_params(*args, **kwargs):
"""Returns a decorator that calls the decorated (higher-order) function with the given parameters."""
def _apply(fn):
return fn(*args, **kwargs)
return _apply
def _export(name: str):
"""Exports the function in the current global namespace."""
@ -774,120 +764,27 @@ def _slice(g: jit_utils.GraphContext, input, axes, starts, ends):
return g.op("Slice", input, axes_i=axes, starts_i=starts, ends_i=ends)
@_beartype.beartype
def _maybe_cast_reduce_op_input(g: jit_utils.GraphContext, self):
scalar_type = _type_utils.JitScalarType.from_value(
self, _type_utils.JitScalarType.UNDEFINED
)
if scalar_type != _type_utils.JitScalarType.UNDEFINED:
# This check only covers traced modules where dtype is present
# pytorch reduce-ops cast all other integral types to int64
if (
not symbolic_helper._is_fp(self)
and scalar_type != _type_utils.JitScalarType.INT64
):
self = g.op("Cast", self, to_i=_C_onnx.TensorProtoDataType.INT64)
return self
@_beartype.beartype
def _reduce_op_symbolic(onnx_op_name, allow_multi_dim_support=True):
@_beartype.beartype
def symbolic(g, self, dim=None, keepdim=None):
self = _maybe_cast_reduce_op_input(g, self)
if dim is None or dim == tuple():
# Dim can be 0, which will cause (not dim) == True. So we don't want to do
# (not dim)
# all-reduce path
return symbolic_helper._handle_reduce_dim_none(g, self, onnx_op_name)
else:
# dim-reduce path
desc = "is" if allow_multi_dim_support else "i"
dim, keepdim = symbolic_helper._get_const(
dim, desc, "dim"
), symbolic_helper._get_const(keepdim, "i", "keepdim")
dim_list = dim if allow_multi_dim_support else [dim]
return g.op(onnx_op_name, self, axes_i=dim_list, keepdims_i=keepdim)
return symbolic
@_beartype.beartype
def overload_by_arg_count(fn):
@functools.wraps(fn)
@_beartype.beartype
def wrapper(g, *args):
overloads = fn(g, *args)
for overload in overloads:
arg_descriptors = overload._arg_descriptors
if len(arg_descriptors) == len(args):
return overload(g, *args)
return symbolic_helper._unimplemented(
f"aten::{fn.__name__}", f"with {len(args)} arguments"
)
return wrapper
@_onnx_symbolic("aten::sum", decorate=[_apply_params("ReduceSum", "sum")])
@_onnx_symbolic("aten::mean", decorate=[_apply_params("ReduceMean", "mean")])
@_onnx_symbolic(
"aten::sum", decorate=[symbolic_helper._apply_params("ReduceSum", "sum")]
)
@_onnx_symbolic(
"aten::mean", decorate=[symbolic_helper._apply_params("ReduceMean", "mean")]
)
# torch.prod does not support multidimensional "dim"
@_onnx_symbolic(
"aten::prod",
decorate=[_apply_params("ReduceProd", "prod", allow_multi_dim_support=False)],
decorate=[
symbolic_helper._apply_params(
"ReduceProd", "prod", allow_multi_dim_support=False
)
],
)
@_beartype.beartype
def _reduce_with_dtype(onnx_op: str, name: str, allow_multi_dim_support: bool = True):
symbolic = _reduce_op_symbolic(
onnx_op, allow_multi_dim_support=allow_multi_dim_support
return symbolic_helper._reduce_with_dtype_helper(
onnx_op, name, allow_multi_dim_support
)
@overload_by_arg_count
def reduce(g, *args, **kwargs):
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "none")
def reduce_nodim(g, self, dtype):
dtype_onnx = None
if dtype.node().kind() == "onnx::Constant":
dtype = symbolic_helper._get_const(dtype, "i", "dtype")
dtype_onnx = _type_utils.JitScalarType(dtype).onnx_type()
self = g.op("Cast", self, to_i=dtype_onnx)
elif dtype.node().kind() != "prim::Constant":
return symbolic_helper._unimplemented(name, "dtype", dtype)
result = symbolic(g, self)
if dtype_onnx is not None:
result_dtype_onnx = _type_utils.JitScalarType.from_value(
result
).onnx_type()
if result_dtype_onnx != dtype_onnx:
result = g.op("Cast", result, to_i=dtype_onnx)
return result
dim_desc = "is" if allow_multi_dim_support else "i"
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", dim_desc, "i", "none") # type: ignore[arg-type]
def reduce_dim(g, self, dim, keepdim, dtype):
dtype_onnx = None
if dtype.node().kind() == "onnx::Constant":
dtype = symbolic_helper._get_const(dtype, "i", "dtype")
dtype_onnx = _type_utils.JitScalarType(dtype).onnx_type()
self = g.op("Cast", self, to_i=dtype_onnx)
elif dtype.node().kind() != "prim::Constant":
return symbolic_helper._unimplemented(name, "dtype", dtype)
result = symbolic(g, self, dim, keepdim)
if dtype_onnx is not None:
result_dtype_onnx = _type_utils.JitScalarType.from_value(
result
).onnx_type()
if result_dtype_onnx != dtype_onnx:
result = g.op("Cast", result, to_i=dtype_onnx)
return result
return reduce_nodim, reduce_dim
return reduce
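# Illustrative sketch (not part of this patch): the eager-mode analogue of the dtype
# plumbing above -- when a dtype is supplied, the input is cast before the reduction and
# the result carries that dtype.
import torch

x = torch.tensor([1, 2, 3], dtype=torch.int32)
assert torch.sum(x, dtype=torch.float64).dtype == torch.float64
assert torch.mean(x, dim=0, dtype=torch.float32).dtype == torch.float32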
@_onnx_symbolic("aten::cumsum")
@symbolic_helper.parse_args("v", "i", "none")
@@ -1356,65 +1253,13 @@ def mish(g: jit_utils.GraphContext, input):
return g.op("Mul", input, g.op("Tanh", g.op("Softplus", input)))
@_beartype.beartype
def _op_with_optional_float_cast(g: jit_utils.GraphContext, op_name, *args, **kwargs):
"""Some PyTorch operators (e.g., Clip/Min/ReLU/Pad) are super set of ONNX in terms of data types.
This function maximizes the exportability of PyTorch-ONNX by allowing ONNX-unsupported PyTorch
operator data type. For example, `Cast<int>(Clip<float>(Cast<float>(INPUT)))` can be used to mimic
`Clip<int>(INPUT)` (opset version < 12).
Args:
g (torch._C.Graph): graph to write the ONNX representation into.
op_name (str): operator name in ONNX.
*args (tuple): operands to the operator.
**kwargs (dict): attributes to the operator along with "opset_before" (optional, None by default)
indicating the smallest opset version to trigger such casting behavior and "target_float_t"
(optional, torch.onnx.JitScalarType.FLOAT by default) indicating the data type of internal operator.
Returns:
Optional[torch._C.Value, Tuple[torch._C.Value, ...]]: output(s) of the operator.
"""
opset_before = kwargs.pop("opset_before", None)
target_float_t = kwargs.pop("target_float_t", _type_utils.JitScalarType.FLOAT)
inputs = list(args)
dtype_0 = _type_utils.JitScalarType.from_value(inputs[0])
require_cast = not symbolic_helper._is_fp(inputs[0]) and (
opset_before is None or GLOBALS.export_onnx_opset_version < opset_before
)
if require_cast:
for input in inputs:
if input.isCompleteTensor():
input_scalar_type = _type_utils.JitScalarType.from_value(input)
if input_scalar_type != dtype_0:
raise errors.SymbolicValueError(
f"Inputs of {op_name} must have same dtype."
f"Got {dtype_0.scalar_name()} and {input_scalar_type.scalar_name()}",
input,
)
for i, input in enumerate(inputs):
if input.isCompleteTensor() and not symbolic_helper._is_fp(input):
inputs[i] = g.op(
"Cast",
input,
to_i=target_float_t.onnx_type(),
)
self = g.op(op_name, *inputs, **kwargs)
if require_cast:
self = g.op("Cast", self, to_i=dtype_0.onnx_type())
return self
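# Illustrative sketch (not part of this patch): the cast-around-op trick from the
# docstring above, written as eager PyTorch. An op that only accepts float at older
# opsets is emulated for int inputs as Cast<int>(Op<float>(Cast<float>(input))).
# `clip_int_via_float` is a hypothetical name.
import torch


def clip_int_via_float(x: torch.Tensor, lo: float, hi: float) -> torch.Tensor:
    return torch.clamp(x.to(torch.float32), lo, hi).to(x.dtype)


print(clip_int_via_float(torch.tensor([1, 5, 9]), 2.0, 6.0))  # tensor([2, 5, 6])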
@_onnx_symbolic("aten::relu")
@symbolic_helper.quantized_args(True)
@_beartype.beartype
def relu(g: jit_utils.GraphContext, input):
return _op_with_optional_float_cast(g, "Relu", input, opset_before=14)
return symbolic_helper._op_with_optional_float_cast(
g, "Relu", input, opset_before=14
)
@_onnx_symbolic("aten::relu6")
@@ -1603,7 +1448,7 @@ def get_pool_ceil_padding(input, kernel_size, stride, padding):
@_onnx_symbolic(
"aten::max_pool1d",
decorate=[
_apply_params(
symbolic_helper._apply_params(
"max_pool1d", torch.nn.modules.utils._single, 1, return_indices=False
),
_export("max_pool1d"),
@@ -1612,7 +1457,7 @@ def get_pool_ceil_padding(input, kernel_size, stride, padding):
@_onnx_symbolic(
"aten::max_pool2d",
decorate=[
_apply_params(
symbolic_helper._apply_params(
"max_pool2d", torch.nn.modules.utils._pair, 2, return_indices=False
),
_export("max_pool2d"),
@@ -1621,7 +1466,7 @@ def get_pool_ceil_padding(input, kernel_size, stride, padding):
@_onnx_symbolic(
"aten::max_pool3d",
decorate=[
_apply_params(
symbolic_helper._apply_params(
"max_pool3d", torch.nn.modules.utils._triple, 3, return_indices=False
),
_export("max_pool3d"),
@@ -1716,21 +1561,21 @@ max_pool3d_with_indices = _onnx_symbolic("aten::max_pool3d_with_indices")(
@_onnx_symbolic(
"aten::avg_pool1d",
decorate=[
_apply_params("avg_pool1d", torch.nn.modules.utils._single),
symbolic_helper._apply_params("avg_pool1d", torch.nn.modules.utils._single),
_export("avg_pool1d"),
],
)
@_onnx_symbolic(
"aten::avg_pool2d",
decorate=[
_apply_params("avg_pool2d", torch.nn.modules.utils._pair),
symbolic_helper._apply_params("avg_pool2d", torch.nn.modules.utils._pair),
_export("avg_pool2d"),
],
)
@_onnx_symbolic(
"aten::avg_pool3d",
decorate=[
_apply_params("avg_pool3d", torch.nn.modules.utils._triple),
symbolic_helper._apply_params("avg_pool3d", torch.nn.modules.utils._triple),
_export("avg_pool3d"),
],
)
@@ -1762,7 +1607,7 @@ def _avg_pool(name, tuple_fn):
# this accommodation.
# More detail on https://github.com/pytorch/pytorch/issues/57178
if count_include_pad:
input = _op_with_optional_float_cast(
input = symbolic_helper._op_with_optional_float_cast(
g,
"Pad",
input,
@@ -1794,7 +1639,7 @@ def _avg_pool(name, tuple_fn):
@_onnx_symbolic(
"aten::adaptive_avg_pool1d",
decorate=[
_apply_params(
symbolic_helper._apply_params(
"adaptive_avg_pool1d", "AveragePool", torch.nn.modules.utils._single
),
_export("adaptive_avg_pool1d"),
@@ -1803,7 +1648,7 @@ def _avg_pool(name, tuple_fn):
@_onnx_symbolic(
"aten::adaptive_avg_pool2d",
decorate=[
_apply_params(
symbolic_helper._apply_params(
"adaptive_avg_pool2d", "AveragePool", torch.nn.modules.utils._pair
),
_export("adaptive_avg_pool2d"),
@@ -1812,7 +1657,7 @@ def _avg_pool(name, tuple_fn):
@_onnx_symbolic(
"aten::adaptive_avg_pool3d",
decorate=[
_apply_params(
symbolic_helper._apply_params(
"adaptive_avg_pool3d", "AveragePool", torch.nn.modules.utils._triple
),
_export("adaptive_avg_pool3d"),
@@ -1821,7 +1666,7 @@ def _avg_pool(name, tuple_fn):
@_onnx_symbolic(
"aten::adaptive_max_pool1d",
decorate=[
_apply_params(
symbolic_helper._apply_params(
"adaptive_max_pool1d",
"MaxPool",
torch.nn.modules.utils._single,
@@ -1833,7 +1678,7 @@ def _avg_pool(name, tuple_fn):
@_onnx_symbolic(
"aten::adaptive_max_pool2d",
decorate=[
_apply_params(
symbolic_helper._apply_params(
"adaptive_max_pool2d",
"MaxPool",
torch.nn.modules.utils._pair,
@@ -1845,7 +1690,7 @@ def _avg_pool(name, tuple_fn):
@_onnx_symbolic(
"aten::adaptive_max_pool3d",
decorate=[
_apply_params(
symbolic_helper._apply_params(
"adaptive_max_pool3d",
"MaxPool",
torch.nn.modules.utils._triple,
@@ -1961,7 +1806,7 @@ def constant_pad_nd(g: jit_utils.GraphContext, input, padding, value):
padding = _convert_padding_node(padding)
paddings = _prepare_onnx_paddings(symbolic_helper._get_tensor_rank(input), padding)
return _op_with_optional_float_cast(
return symbolic_helper._op_with_optional_float_cast(
g, "Pad", input, pads_i=paddings, mode_s=mode, value_f=value, opset_before=11
)
@@ -2016,7 +1861,7 @@ def reflection_pad(g: jit_utils.GraphContext, input, padding):
mode = "reflect"
padding = _convert_padding_node(padding)
paddings = _prepare_onnx_paddings(symbolic_helper._get_tensor_rank(input), padding)
return _op_with_optional_float_cast(
return symbolic_helper._op_with_optional_float_cast(
g, "Pad", input, pads_i=paddings, mode_s=mode, opset_before=11
)
@@ -2029,7 +1874,7 @@ def replication_pad(g: jit_utils.GraphContext, input, padding):
mode = "edge"
padding = _convert_padding_node(padding)
paddings = _prepare_onnx_paddings(symbolic_helper._get_tensor_rank(input), padding)
return _op_with_optional_float_cast(
return symbolic_helper._op_with_optional_float_cast(
g, "Pad", input, pads_i=paddings, mode_s=mode, opset_before=11
)
@@ -2059,42 +1904,42 @@ def pad(
@_onnx_symbolic(
"aten::upsample_nearest1d",
decorate=[
_apply_params("upsample_nearest1d", 3, "nearest"),
symbolic_helper._apply_params("upsample_nearest1d", 3, "nearest"),
_export("upsample_nearest1d"),
],
)
@_onnx_symbolic(
"aten::upsample_nearest2d",
decorate=[
_apply_params("upsample_nearest2d", 4, "nearest"),
symbolic_helper._apply_params("upsample_nearest2d", 4, "nearest"),
_export("upsample_nearest2d"),
],
)
@_onnx_symbolic(
"aten::upsample_nearest3d",
decorate=[
_apply_params("upsample_nearest3d", 5, "nearest"),
symbolic_helper._apply_params("upsample_nearest3d", 5, "nearest"),
_export("upsample_nearest3d"),
],
)
@_onnx_symbolic(
"aten::upsample_linear1d",
decorate=[
_apply_params("upsample_linear1d", 3, "linear"),
symbolic_helper._apply_params("upsample_linear1d", 3, "linear"),
_export("upsample_linear1d"),
],
)
@_onnx_symbolic(
"aten::upsample_bilinear2d",
decorate=[
_apply_params("upsample_bilinear2d", 4, "linear"),
symbolic_helper._apply_params("upsample_bilinear2d", 4, "linear"),
_export("upsample_bilinear2d"),
],
)
@_onnx_symbolic(
"aten::upsample_trilinear3d",
decorate=[
_apply_params("upsample_trilinear3d", 5, "linear"),
symbolic_helper._apply_params("upsample_trilinear3d", 5, "linear"),
_export("upsample_trilinear3d"),
],
)
@@ -2942,7 +2787,15 @@ def native_layer_norm(
two_cst = symbolic_helper._generate_wrapped_number(g, 2.0)
eps_cst = symbolic_helper._generate_wrapped_number(g, eps)
mean = g.op("ReduceMean", input, axes_i=axes)
if g.opset < 18:
mean = g.op("ReduceMean", input, axes_i=axes)
else:
mean = g.op(
"ReduceMean",
input,
g.op("Constant", value_t=torch.tensor(axes, dtype=torch.long)),
)
numerator = sub(g, input, mean)
# Cast it to eps dtype to avoid precision loss
@@ -2957,7 +2810,15 @@ def native_layer_norm(
)
# variance = e((x - e(x))^2), and (x - e(x)) is the numerator in the layer_norm formula
variance = g.op("ReduceMean", pow(g, numerator, two_cst), axes_i=axes)
if g.opset < 18:
variance = g.op("ReduceMean", pow(g, numerator, two_cst), axes_i=axes)
else:
variance = g.op(
"ReduceMean",
pow(g, numerator, two_cst),
g.op("Constant", value_t=torch.tensor(axes, dtype=torch.long)),
)
denominator = sqrt(g, g.op("Add", variance, eps_cst))
normalized = g.op("Div", numerator, denominator)
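# Illustrative sketch (not part of this patch): the schema change behind the opset-18
# branches above -- ReduceMean's axes moved from a node attribute to a second input.
# Built with onnx.helper purely for illustration; tensor and value names are made up.
from onnx import TensorProto, helper

# opset < 18: axes passed as an attribute
reduce_pre18 = helper.make_node("ReduceMean", inputs=["x"], outputs=["mean"], axes=[-1])

# opset >= 18: axes passed as an int64 input tensor (emitted as a Constant in the graph code)
axes = helper.make_tensor("axes", TensorProto.INT64, dims=[1], vals=[-1])
reduce_18 = helper.make_node("ReduceMean", inputs=["x", "axes"], outputs=["mean"])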
@@ -3405,7 +3266,7 @@ def clamp(g: jit_utils.GraphContext, self, min, max):
return clamp_min(g, self, min)
else:
if symbolic_helper._is_constant(min) and symbolic_helper._is_constant(max):
return _op_with_optional_float_cast(
return symbolic_helper._op_with_optional_float_cast(
g,
"Clip",
self,
@@ -3422,13 +3283,15 @@ def clamp(g: jit_utils.GraphContext, self, min, max):
@_beartype.beartype
def clamp_min(g: jit_utils.GraphContext, self, min):
if symbolic_helper._is_constant(min):
return _op_with_optional_float_cast(
return symbolic_helper._op_with_optional_float_cast(
g, "Clip", self, min_f=symbolic_helper._parse_arg(min, "f"), opset_before=12
)
else:
dtype = _type_utils.JitScalarType.from_value(self)
min = g.op("Cast", min, to_i=dtype.onnx_type())
return _op_with_optional_float_cast(g, "Max", self, min, opset_before=12)
return symbolic_helper._op_with_optional_float_cast(
g, "Max", self, min, opset_before=12
)
@_onnx_symbolic("aten::clamp_max")
@@ -3436,13 +3299,15 @@ def clamp_min(g: jit_utils.GraphContext, self, min):
@_beartype.beartype
def clamp_max(g: jit_utils.GraphContext, self, max):
if symbolic_helper._is_constant(max):
return _op_with_optional_float_cast(
return symbolic_helper._op_with_optional_float_cast(
g, "Clip", self, max_f=symbolic_helper._parse_arg(max, "f"), opset_before=12
)
else:
dtype = _type_utils.JitScalarType.from_value(self)
max = g.op("Cast", max, to_i=dtype.onnx_type())
return _op_with_optional_float_cast(g, "Min", self, max, opset_before=12)
return symbolic_helper._op_with_optional_float_cast(
g, "Min", self, max, opset_before=12
)
@_onnx_symbolic("aten::max")
@@ -3451,19 +3316,7 @@ def clamp_max(g: jit_utils.GraphContext, self, max):
# TODO(justinchuby): Support multiple quantized args in output
@_beartype.beartype
def max(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
# torch.max(input)
if dim_or_y is None and keepdim is None:
return g.op("ReduceMax", self, keepdims_i=0)
# torch.max(input, other)
if keepdim is None:
return _op_with_optional_float_cast(g, "Max", self, dim_or_y, opset_before=12)
# torch.max(input, dim, keepdim)
else:
dim = symbolic_helper._get_const(dim_or_y, "i", "dim")
keepdim = symbolic_helper._get_const(keepdim, "i", "keepdim")
max = g.op("ReduceMax", self, axes_i=[dim], keepdims_i=keepdim)
indices = g.op("ArgMax", self, axis_i=dim, keepdims_i=keepdim)
return max, indices
return symbolic_helper._max_helper(g, self, dim_or_y, keepdim)
@_onnx_symbolic("aten::maximum")
@@ -3477,19 +3330,7 @@ def maximum(g: jit_utils.GraphContext, input, other):
# TODO(justinchuby): Support multiple quantized args in output
@_beartype.beartype
def min(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
# torch.min(input)
if dim_or_y is None and keepdim is None:
return g.op("ReduceMin", self, keepdims_i=0)
# torch.min(input, other)
if keepdim is None:
return _op_with_optional_float_cast(g, "Min", self, dim_or_y, opset_before=12)
# torch.min(input, dim, keepdim)
else:
dim = symbolic_helper._get_const(dim_or_y, "i", "dim")
keepdim = symbolic_helper._get_const(keepdim, "i", "keepdim")
min = g.op("ReduceMin", self, axes_i=[dim], keepdims_i=keepdim)
indices = g.op("ArgMin", self, axis_i=dim, keepdims_i=keepdim)
return min, indices
return symbolic_helper._min_helper(g, self, dim_or_y, keepdim)
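# Illustrative sketch (not part of this patch): the three aten::max / aten::min call
# shapes the helpers above have to dispatch on.
import torch

t = torch.tensor([[1.0, 5.0], [3.0, 2.0]])
full = torch.max(t)                                   # ReduceMax over all elements
elementwise = torch.max(t, torch.ones_like(t))        # elementwise Max against `other`
values, indices = torch.max(t, dim=1, keepdim=True)   # ReduceMax + ArgMax along `dim`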
@_onnx_symbolic("aten::minimum")
@@ -3550,22 +3391,28 @@ def dropout(g: jit_utils.GraphContext, input, p, train):
@_onnx_symbolic(
"aten::alpha_dropout_", decorate=[_apply_params("aten::alpha_dropout_")]
"aten::alpha_dropout_",
decorate=[symbolic_helper._apply_params("aten::alpha_dropout_")],
) # See Note [Export inplace]
@_onnx_symbolic(
"aten::feature_alpha_dropout_",
decorate=[_apply_params("aten::feature_alpha_dropout_")],
decorate=[symbolic_helper._apply_params("aten::feature_alpha_dropout_")],
)
@_onnx_symbolic(
"aten::feature_dropout_", decorate=[_apply_params("aten::feature_dropout_")]
"aten::feature_dropout_",
decorate=[symbolic_helper._apply_params("aten::feature_dropout_")],
)
@_onnx_symbolic(
"aten::feature_alpha_dropout",
decorate=[_apply_params("aten::feature_alpha_dropout")],
decorate=[symbolic_helper._apply_params("aten::feature_alpha_dropout")],
)
@_onnx_symbolic("aten::alpha_dropout", decorate=[_apply_params("aten::alpha_dropout")])
@_onnx_symbolic(
"aten::feature_dropout", decorate=[_apply_params("aten::feature_dropout")]
"aten::alpha_dropout",
decorate=[symbolic_helper._apply_params("aten::alpha_dropout")],
)
@_onnx_symbolic(
"aten::feature_dropout",
decorate=[symbolic_helper._apply_params("aten::feature_dropout")],
)
@_beartype.beartype
def _unsupported_dropout(name: str):
@@ -3585,9 +3432,9 @@ def _unsupported_dropout(name: str):
@_beartype.beartype
def norm(g: jit_utils.GraphContext, self, p, dim, keepdim, dtype=None):
if p == 1:
f = _reduce_op_symbolic("ReduceL1")
f = symbolic_helper._reduce_op_symbolic_helper("ReduceL1")
elif p == 2:
f = _reduce_op_symbolic("ReduceL2")
f = symbolic_helper._reduce_op_symbolic_helper("ReduceL2")
else:
raise errors.SymbolicValueError(
"ONNX export only p-norms with p of 1 or 2", self
@@ -4135,7 +3982,7 @@ def slice(g: jit_utils.GraphContext, self, *args):
@symbolic_helper.parse_args("v", "f", "f")
@_beartype.beartype
def hardtanh(g: jit_utils.GraphContext, self: _C.Value, min_val: float, max_val: float):
return _op_with_optional_float_cast(
return symbolic_helper._op_with_optional_float_cast(
g, "Clip", self, min_f=min_val, max_f=max_val, opset_before=12
)
@@ -4283,8 +4130,7 @@ def sort(g: jit_utils.GraphContext, self, dim, decending, out=None):
@_onnx_symbolic("aten::numel")
@_beartype.beartype
def numel(g: jit_utils.GraphContext, self):
shape = g.op("Shape", self)
return g.op("ReduceProd", shape, keepdims_i=0)
return symbolic_helper._numel_helper(g, self)
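# Illustrative sketch (not part of this patch): what aten::numel lowers to -- a
# ReduceProd over the Shape of the input, i.e. the product of all dimension sizes.
import torch

x = torch.zeros(2, 3, 4)
assert x.numel() == int(torch.prod(torch.tensor(x.shape))) == 24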
@_onnx_symbolic("aten::topk")
@ -4974,12 +4820,16 @@ def lstm_cell(g: jit_utils.GraphContext, self, hidden, w_ih, w_hh, b_ih, b_hh):
), symbolic_helper._squeeze_helper(g, c_outs, [0])
@_onnx_symbolic("aten::gru", decorate=[_apply_params("GRU"), _export("gru")])
@_onnx_symbolic(
"aten::rnn_tanh", decorate=[_apply_params("RNN_TANH"), _export("rnn_tanh")]
"aten::gru", decorate=[symbolic_helper._apply_params("GRU"), _export("gru")]
)
@_onnx_symbolic(
"aten::rnn_relu", decorate=[_apply_params("RNN_RELU"), _export("rnn_relu")]
"aten::rnn_tanh",
decorate=[symbolic_helper._apply_params("RNN_TANH"), _export("rnn_tanh")],
)
@_onnx_symbolic(
"aten::rnn_relu",
decorate=[symbolic_helper._apply_params("RNN_RELU"), _export("rnn_relu")],
)
def _one_hidden_rnn(kind: str):
@symbolic_helper.parse_args("v", "v", "v", "i", "i", "f", "i", "i", "i")
@@ -5583,37 +5433,7 @@ def gather(g: jit_utils.GraphContext, self, dim, index, sparse_grad=False):
@symbolic_helper.parse_args("v", "is", "i", "i")
@_beartype.beartype
def _var_mean(g: jit_utils.GraphContext, input, dim, correction, keepdim):
if dim is None:
mean = g.op("ReduceMean", input, keepdims_i=0)
t_mean = mean
num_elements = numel(g, input)
else:
mean = g.op("ReduceMean", input, axes_i=dim, keepdims_i=keepdim)
t_mean = g.op("ReduceMean", input, axes_i=dim, keepdims_i=1)
redudced_dims = g.op("Shape", input)
# dim could contain one or multiple dimensions
redudced_dims = g.op(
"Gather",
redudced_dims,
g.op("Constant", value_t=torch.tensor(dim)),
axis_i=0,
)
num_elements = g.op("ReduceProd", redudced_dims, keepdims_i=0)
sub_v = g.op("Sub", input, t_mean)
sqr_sub = g.op("Mul", sub_v, sub_v)
keepdim_mean = 0 if dim is None else keepdim
var = g.op("ReduceMean", sqr_sub, axes_i=dim, keepdims_i=keepdim_mean)
# Correct bias in calculating variance, by dividing it over (N - correction) instead on N
if correction is None:
correction = 1
if correction != 0:
num_elements = g.op(
"Cast", num_elements, to_i=_C_onnx.TensorProtoDataType.FLOAT
)
one = g.op("Constant", value_t=torch.tensor(correction, dtype=torch.float))
mul = g.op("Mul", var, num_elements)
var = g.op("Div", mul, g.op("Sub", num_elements, one))
return var, mean
return symbolic_helper._var_mean_helper(g, input, dim, correction, keepdim)
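# Illustrative sketch (not part of this patch): the bias correction the removed graph
# code spelled out, checked in eager mode -- var = E[(x - E[x])^2] * N / (N - correction).
import torch

x = torch.tensor([1.0, 2.0, 3.0, 4.0])
n = x.numel()
biased = ((x - x.mean()) ** 2).mean()
var, mean = torch.var_mean(x, correction=1)   # correction defaults to 1 (Bessel's correction)
assert torch.allclose(var, biased * n / (n - 1))
assert torch.allclose(mean, x.mean())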
@_onnx_symbolic("aten::std")
@@ -5633,11 +5453,6 @@ def var(g: jit_utils.GraphContext, input, *args):
@_onnx_symbolic("aten::var_mean")
@_beartype.beartype
def var_mean(g: jit_utils.GraphContext, input, *args):
# var_mean (and all variance-related functions) has multiple signatures, so need to manually figure
# out the correct arguments:
# aten::var_mean(Tensor self, bool unbiased)
# aten::var_mean(Tensor self, int[1] dim, bool unbiased, bool keepdim=False)
# aten::var_mean(Tensor self, int[1]? dim=None, *, int? correction=None, bool keepdim=False)
if len(args) == 1:
return _var_mean(g, input, None, args[0], None)
else:
@@ -5994,42 +5809,7 @@ def linalg_vector_norm(
keepdim: bool,
dtype: torch._C.Value,
):
# Conditions based on https://pytorch.org/docs/stable/generated/torch.linalg.vector_norm.html
if symbolic_helper._is_none(dim):
self = symbolic_helper._reshape_helper(g, self, [-1])
keepdim = False
if ord == math.inf:
result = g.op("ReduceMax", g.op("Abs", self), axes_i=dim, keepdims_i=keepdim)
elif ord == -math.inf:
result = g.op("ReduceMin", g.op("Abs", self), axes_i=dim, keepdims_i=keepdim)
elif ord == 0:
return symbolic_helper._onnx_opset_unsupported_detailed(
"linalg_vector_norm", 9, 11, "ord=0 not supported", self
)
elif ord == 1:
result = _reduce_op_symbolic("ReduceL1")(g, self, dim=dim, keepdim=keepdim)
elif ord == 2:
result = _reduce_op_symbolic("ReduceL2")(g, self, dim=dim, keepdim=keepdim)
else:
ord_op = g.op("Constant", value_t=torch.tensor(ord, dtype=torch.float32))
result = symbolic_helper._reducesum_helper(
g, g.op("Pow", g.op("Abs", self), ord_op), axes_i=dim, keepdims_i=keepdim
)
result = g.op(
"Pow",
result,
g.op(
"Div",
g.op("Constant", value_t=torch.tensor(1, dtype=torch.float32)),
ord_op,
),
)
if not symbolic_helper._is_none(dtype):
dtype = symbolic_helper._get_const(dtype, "i", "dtype")
result = g.op("Cast", result, to_i=_type_utils.JitScalarType(dtype).onnx_type()) # type: ignore[arg-type]
return result
return symbolic_helper._linalg_vector_norm_helper(g, self, ord, dim, keepdim, dtype)
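# Illustrative sketch (not part of this patch): the ord handling the removed code
# implemented, checked against eager PyTorch -- inf/-inf reduce max/min of |x|, 1 and 2
# map to ReduceL1/ReduceL2, and any other ord becomes sum(|x|**ord)**(1/ord).
import math

import torch

x = torch.tensor([3.0, -4.0])
assert torch.isclose(torch.linalg.vector_norm(x, 2), torch.tensor(5.0))
assert torch.isclose(torch.linalg.vector_norm(x, math.inf), x.abs().max())
assert torch.isclose(torch.linalg.vector_norm(x, 1), x.abs().sum())
assert torch.isclose(torch.linalg.vector_norm(x, 3), (x.abs() ** 3).sum() ** (1.0 / 3))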
@_onnx_symbolic("aten::linalg_matrix_norm")
@@ -6842,7 +6622,9 @@ def prim_shape(g: jit_utils.GraphContext, self):
@_onnx_symbolic("prim::max")
@_beartype.beartype
def prim_max(g: jit_utils.GraphContext, self, other):
return _op_with_optional_float_cast(g, "Max", self, other, opset_before=12)
return symbolic_helper._op_with_optional_float_cast(
g, "Max", self, other, opset_before=12
)
@_onnx_symbolic("prim::min")