Files
pytorch/torch/onnx/symbolic_opset9.py
Aaron Orenstein 5a0068cc69 [BE] mypy: disallow untyped decorators (#131428)
Untyped decorators strip the types from the functions they decorate, so even if the underlying function is fully typed, its callers get no benefit from its type annotations.

Step 1 - Enable the error and override in all the offending files.

#131429

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131428
Approved by: https://github.com/justinchuby, https://github.com/oulgen
2024-07-23 21:50:55 +00:00

6634 lines
219 KiB
Python

# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
# mypy: disable-error-code=arg-type
"""This file exports ONNX ops for opset 9.
Opset 9 is supported by ONNX release 1.4.1,
released on 01/23/19.
"""
from __future__ import annotations
import builtins
import functools
import math
import sys
import warnings
from typing import Callable, Sequence, TYPE_CHECKING
import torch
import torch._C._onnx as _C_onnx
import torch.nn.modules.utils
import torch.onnx
from torch import _C
# Monkey-patch graph manipulation methods on Graph, used for the ONNX symbolics
from torch.onnx import _constants, _deprecation, _type_utils, errors, symbolic_helper
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import jit_utils, registration
if TYPE_CHECKING:
from torch.types import Number
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md
# Public symbolic functions of this module. The list is not final:
# `_export` (defined below) appends further names to it at import time.
__all__ = [
    "abs",
    "acos",
    "add",
    "addcmul",
    "addmm",
    "alias",
    "amax",
    "amin",
    "aminmax",
    "arange",
    "argmax",
    "argmin",
    "as_strided",
    "as_tensor",
    "asin",
    "atan",
    "atan2",
    "baddbmm",
    "batch_norm",
    "bernoulli",
    "bitwise_not",
    "bitwise_or",
    "bmm",
    "broadcast_tensors",
    "broadcast_to",
    "bucketize",
    "cat",
    "cdist",
    "ceil",
    "clamp_max",
    "clamp_min",
    "clamp",
    "clone",
    "constant_pad_nd",
    "contiguous",
    "conv_tbc",
    "conv_transpose1d",
    "conv_transpose2d",
    "conv_transpose3d",
    "conv1d",
    "conv2d",
    "conv3d",
    "convert_element_type",
    "convolution",
    "cos",
    "cosine_similarity",
    "cross",
    "cumsum",
    "detach",
    "dim",
    "div",
    "dot",
    "dropout",
    "elu",
    "embedding_bag",
    "embedding",
    "empty_like",
    "empty",
    "eq",
    "erf",
    "exp",
    "expand_as",
    "expand",
    "eye",
    "fill",
    "flatten",
    "floor_divide",
    "floor",
    "floordiv",
    "frobenius_norm",
    "full_like",
    "full",
    "gather",
    "ge",
    "gelu",
    "get_pool_ceil_padding",
    "glu",
    "group_norm",
    "gt",
    "hann_window",
    "hardshrink",
    "hardsigmoid",
    "hardswish",
    "hardtanh",
    "index_add",
    "index_copy",
    "index_fill",
    "index_put",
    "index_select",
    "index",
    "instance_norm",
    "is_floating_point",
    "is_pinned",
    "isnan",
    "item",
    "kl_div",
    "layer_norm",
    "le",
    "leaky_relu",
    "lerp",
    "lift",
    "linalg_cross",
    "linalg_matrix_norm",
    "linalg_norm",
    "linalg_vector_norm",
    "linear",
    "linspace",
    "log_sigmoid",
    "log_softmax",
    "log",
    "log10",
    "log1p",
    "log2",
    "logical_and",
    "logical_not",
    "logical_or",
    "logical_xor",
    "logit",
    "logsumexp",
    "lstm_cell",
    "lstm",
    "lt",
    "masked_fill",
    "masked_fill_",
    "matmul",
    "max_pool1d_with_indices",
    "max_pool2d_with_indices",
    "max_pool3d_with_indices",
    "max",
    "maximum",
    "meshgrid",
    "min",
    "minimum",
    "mish",
    "mm",
    "movedim",
    "mse_loss",
    "mul",
    "multinomial",
    "mv",
    "narrow",
    "native_layer_norm",
    "ne",
    "neg",
    "new_empty",
    "new_full",
    "new_ones",
    "new_zeros",
    "nonzero_numpy",
    "nonzero",
    "norm",
    "numel",
    "numpy_T",
    "one_hot",
    "ones_like",
    "ones",
    "onnx_placeholder",
    "pad",
    "pairwise_distance",
    "permute",
    "pixel_shuffle",
    "pixel_unshuffle",
    "pow",
    "prelu",
    "prim_constant_chunk",
    "prim_constant_split",
    "prim_constant",
    "prim_data",
    "prim_device",
    "prim_dtype",
    "prim_if",
    "prim_layout",
    "prim_list_construct",
    "prim_list_unpack",
    "prim_loop",
    "prim_max",
    "prim_min",
    "prim_shape",
    "prim_tolist",
    "prim_tuple_construct",
    "prim_type",
    "prim_unchecked_cast",
    "prim_uninitialized",
    "rand_like",
    "rand",
    "randint_like",
    "randint",
    "randn_like",
    "randn",
    "reciprocal",
    "reflection_pad",
    "relu",
    "relu6",
    "remainder",
    "repeat_interleave",
    "repeat",
    "replication_pad",
    "reshape_as",
    "reshape",
    "roll",
    "rrelu",
    "rsqrt",
    "rsub",
    "scalar_tensor",
    "scatter_add",
    "scatter",
    "select",
    "selu",
    "sigmoid",
    "sign",
    "silu",
    "sin",
    "size",
    "slice",
    "softmax",
    "softplus",
    "softshrink",
    "sort",
    "split_with_sizes",
    "split",
    "sqrt",
    "square",
    "squeeze",
    "stack",
    "std_mean",
    "std",
    "sub",
    "t",
    "take",
    "tan",
    "tanh",
    "tanhshrink",
    "tensor",
    "threshold",
    "to",
    "topk",
    "transpose",
    "true_divide",
    "type_as",
    "unbind",
    "unfold",
    "unsafe_chunk",
    "unsafe_split_with_sizes",
    "unsafe_split",
    "unsqueeze",
    "unsupported_complex_operators",
    "noop_complex_operators",
    "unused",
    "var_mean",
    "var",
    "view_as",
    "view",
    "where",
    "wrap_logical_op_with_cast_to",
    "wrap_logical_op_with_negation",
    "zeros_like",
    "zeros",
    "zero",
]
# Decorator factory used throughout this file: `registration.onnx_symbolic`
# with `opset=9` pre-bound, so `@_onnx_symbolic("aten::op")` registers the
# decorated function as the opset-9 symbolic for that op name.
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=9)
def _export(name: str):
    """Return a decorator that publishes the decorated function as ``name``.

    The function is bound into this module's global namespace under ``name``
    and ``name`` is appended to ``__all__``; the function itself is returned
    unchanged.
    """

    def register(fn):
        module_namespace = globals()
        module_namespace[name] = fn
        __all__.append(name)
        return fn

    return register
def unused(g):
    """Return a node standing in for a "missing" optional input.

    Emits an empty ``prim::Constant`` and types it as an optional tensor.
    """
    placeholder = g.op("prim::Constant")
    optional_tensor_type = _C.OptionalType.ofTensor()
    placeholder.setType(optional_tensor_type)
    return placeholder
@_onnx_symbolic("aten::_shape_as_tensor")
def _shape_as_tensor(g: jit_utils.GraphContext, input):
    """Export ``aten::_shape_as_tensor`` as an ONNX ``Shape`` node on ``input``."""
    shape_node = g.op("Shape", input)
    return shape_node
@_onnx_symbolic("aten::_reshape_from_tensor")
def _reshape_from_tensor(g: jit_utils.GraphContext, input, shape):
    """Reshape ``input`` to ``shape``.

    If ``shape`` arrives as a Python list of values, they are concatenated
    along axis 0 into a single shape tensor first.
    """
    target_shape = (
        g.op("Concat", *shape, axis_i=0) if isinstance(shape, list) else shape
    )
    return reshape(g, input, target_shape)
@_onnx_symbolic("aten::reshape")
@symbolic_helper.quantized_args(True)
def reshape(g: jit_utils.GraphContext, self, shape):
    """Export ``aten::reshape`` by delegating to ``symbolic_helper._reshape_helper``."""
    reshaped = symbolic_helper._reshape_helper(g, self, shape)
    return reshaped
@_onnx_symbolic("aten::reshape_as")
@symbolic_helper.quantized_args(True)
def reshape_as(g: jit_utils.GraphContext, self, other):
    """Reshape ``self`` to the runtime shape of ``other``."""
    return reshape(g, self, g.op("Shape", other))
@_onnx_symbolic("aten::add")
def add(g: jit_utils.GraphContext, self, other, alpha=None):
    """Export ``aten::add`` (``self + alpha * other``) as ONNX ``Add``.

    Not meant to be called directly by users.

    Args:
        g (GraphContext): The graph context.
        self (Tensor): The first operand.
        other (Tensor): The second operand.
        alpha (float, optional): Multiplier for ``other``. When given and its
            scalar value is not 1, ``other`` is first multiplied by it.

    Returns:
        The ONNX ``Add`` output (or the unsupported-op result for tensor lists).
    """
    # Adding tensor lists needs opset 11.
    is_list_add = symbolic_helper._is_value(self) and symbolic_helper._is_tensor_list(
        self
    )
    if is_list_add:
        return symbolic_helper._onnx_opset_unsupported_detailed(
            "Add", 9, 11, "Add between list of tensors not supported", self
        )

    if alpha:
        alpha_value = symbolic_helper._scalar(symbolic_helper._maybe_get_scalar(alpha))
        if alpha_value != 1:
            other = g.op("Mul", other, alpha)
    return g.op("Add", self, other)
@_onnx_symbolic("aten::sub")
def sub(g: jit_utils.GraphContext, self, other, alpha=None):
    """Export ``aten::sub`` (``self - alpha * other``) as ONNX ``Sub``.

    Not meant to be called directly by users.

    Args:
        g (GraphContext): The graph context.
        self (Tensor): The minuend.
        other (Tensor): The subtrahend.
        alpha (Optional[Tensor]): Multiplier for ``other``; when omitted it is
            treated as 1 and no ``Mul`` is emitted.

    Returns:
        The ONNX ``Sub`` output.
    """
    if alpha:
        alpha_value = symbolic_helper._scalar(symbolic_helper._maybe_get_scalar(alpha))
        if alpha_value != 1:
            other = g.op("Mul", other, alpha)
    return g.op("Sub", self, other)
@_onnx_symbolic("aten::rsub")
def rsub(g: jit_utils.GraphContext, self, other, alpha=None):
    """Export ``aten::rsub``: ``other - alpha * self`` (``sub`` with swapped operands)."""
    minuend, subtrahend = other, self
    return sub(g, minuend, subtrahend, alpha=alpha)
@_onnx_symbolic("aten::mul")
def mul(g: jit_utils.GraphContext, self, other):
    """Export ``aten::mul``; boolean-by-boolean products become ONNX ``And``."""
    both_bool = symbolic_helper._is_bool(self) and symbolic_helper._is_bool(other)
    # ONNX Mul doesn't support Boolean, so use And as an equivalent operator.
    op_name = "And" if both_bool else "Mul"
    return g.op(op_name, self, other)
@_onnx_symbolic("aten::div")
def div(g: jit_utils.GraphContext, self, other, *args):
    """Export ``aten::div``.

    Without extra arguments this is true division; any extra positional
    arguments (the rounding mode) are forwarded to ``_div_rounding_mode``.
    """
    if not args:
        return true_divide(g, self, other)
    return _div_rounding_mode(g, self, other, *args)
@_onnx_symbolic("aten::addcmul")
@symbolic_helper.parse_args("v", "v", "v", "f")
def addcmul(g: jit_utils.GraphContext, self, tensor1, tensor2, value=1.0):
    """Export ``aten::addcmul``: ``self + value * tensor1 * tensor2``."""
    scale = g.op("Constant", value_t=torch.tensor([value]))
    elementwise_product = mul(g, tensor1, tensor2)
    scaled_product = mul(g, elementwise_product, scale)
    return add(g, self, scaled_product)
@symbolic_helper.parse_args("v", "v", "s")
def _div_rounding_mode(g: jit_utils.GraphContext, self, other, rounding_mode):
    """Dispatch ``aten::div`` on its ``rounding_mode`` argument.

    ``None`` means true division, ``"floor"`` floor division and ``"trunc"``
    truncating division. Any other mode raises ``errors.SymbolicValueError``.
    """
    if rounding_mode is None:
        return true_divide(g, self, other)
    dividers = {"floor": _floor_divide, "trunc": _trunc_divide}
    divider = dividers.get(rounding_mode)
    if divider is None:
        raise errors.SymbolicValueError(
            f'Unsupported rounding mode: "{rounding_mode}". Expected None, "floor" or "trunc"',
            self,
        )
    return divider(g, self, other)
def _trunc_divide(g: jit_utils.GraphContext, self, other):
    """Divide ``self`` by ``other``, rounding the quotient toward zero.

    ONNX has no truncation op, and ``Floor`` would round negative quotients
    the wrong way (e.g. -0.1 should become -0, not -1). The quotient is
    therefore cast to INT64 and then cast to the dtype PyTorch would produce:

    * ``self`` dtype unknown                       -> FLOAT
    * ``self`` not floating, ``other`` floating    -> FLOAT
    * otherwise                                    -> ``self``'s dtype
    """
    quotient = g.op("Div", self, other)
    # The INT64 cast is what drops the fractional part of the quotient.
    truncated = g.op("Cast", quotient, to_i=_C_onnx.TensorProtoDataType.INT64)

    self_type = _type_utils.JitScalarType.from_value(
        self, _type_utils.JitScalarType.UNDEFINED
    )
    if self_type == _type_utils.JitScalarType.UNDEFINED:
        return g.op("Cast", truncated, to_i=_C_onnx.TensorProtoDataType.FLOAT)
    if not symbolic_helper._is_fp(self) and symbolic_helper._is_fp(other):
        return g.op("Cast", truncated, to_i=_C_onnx.TensorProtoDataType.FLOAT)
    return g.op("Cast", truncated, to_i=self_type.onnx_type())
def _floor_divide(g: jit_utils.GraphContext, self, other):
    """Divide ``self`` by ``other``, rounding the quotient down (toward -inf).

    If either operand is floating point, this is ``Floor(true_divide(...))``.
    Otherwise ONNX integer ``Div`` is used and its result is corrected by 1
    wherever the exact quotient is negative and not an integer.
    """
    if symbolic_helper._is_fp(self) or symbolic_helper._is_fp(other):
        out = true_divide(g, self, other)
        return g.op("Floor", out)
    else:
        # Integer division does truncation rounding
        div = g.op("Div", self, other)
        # Division is negative if: self < 0 != other < 0
        zero = g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64))
        negative = g.op(
            "Xor",
            symbolic_helper._lt_helper(g, self, zero),
            symbolic_helper._lt_helper(g, other, zero),
        )
        # For negative numbers with self % other != 0, subtract 1 to round down instead of up
        # (mod is the remainder left over by the truncating division above).
        mod = g.op("Sub", self, g.op("Mul", div, other))
        fixup_mask = g.op("And", negative, g.op("Not", g.op("Equal", mod, zero)))
        one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
        # NOTE(review): multiplies a BOOL mask by an INT64 constant to get 0/1;
        # presumably relies on a later type-promotion pass to insert the Cast
        # that ONNX Mul needs — confirm.
        fixup = g.op("Mul", fixup_mask, one)
        return g.op("Sub", div, fixup)
@_onnx_symbolic("aten::floor_divide")
def floor_divide(g: jit_utils.GraphContext, self, other):
    """Export ``aten::floor_divide``.

    Kept for the deprecated behavior, in which ``floor_divide`` actually
    truncates, so this delegates to ``_trunc_divide``.
    """
    truncated_quotient = _trunc_divide(g, self, other)
    return truncated_quotient
@_onnx_symbolic("aten::floordiv")
def floordiv(g: jit_utils.GraphContext, self, other):
    """Export ``aten::floordiv`` as an alias of ``floor_divide``."""
    quotient = floor_divide(g, self, other)
    return quotient
@_onnx_symbolic("aten::true_divide")
def true_divide(g: jit_utils.GraphContext, self, other):
    """Divide after making sure the operands are floating point.

    * If either operand is floating, a plain ``Div`` is emitted.
    * If neither is, both are cast to the ONNX equivalent of
      ``torch.get_default_dtype()`` (FLOAT, or DOUBLE when the default is
      ``torch.double``) before dividing.
    """
    # Case 1: at least one operand is floating. Any remaining implicit casting
    # is handled by the scalar type analysis pass.
    if symbolic_helper._is_fp(self) or symbolic_helper._is_fp(other):
        return g.op("Div", self, other)

    # Case 2: neither operand is floating; cast both to the default dtype.
    default_dtype = torch.get_default_dtype()
    assert default_dtype is torch.float or default_dtype is torch.double
    target_type = (
        _C_onnx.TensorProtoDataType.DOUBLE
        if default_dtype is torch.double
        else _C_onnx.TensorProtoDataType.FLOAT
    )
    numerator = g.op("Cast", self, to_i=target_type)
    denominator = g.op("Cast", other, to_i=target_type)
    return g.op("Div", numerator, denominator)
@_onnx_symbolic("aten::reciprocal")
def reciprocal(g: jit_utils.GraphContext, self):
    """Export ``aten::reciprocal``; non-floating inputs are cast to FLOAT first."""
    # torch.reciprocal implicitly casts to float, so we do the same.
    operand = (
        self
        if symbolic_helper._is_fp(self)
        else g.op("Cast", self, to_i=_C_onnx.TensorProtoDataType.FLOAT)
    )
    return g.op("Reciprocal", operand)
@_onnx_symbolic("aten::cat")
@symbolic_helper.parse_args("v", "i")
def cat(g: jit_utils.GraphContext, tensor_list, dim):
    """Export ``aten::cat`` as an ONNX ``Concat`` along axis ``dim``.

    Parameters:
        g (jit_utils.GraphContext): Graph context.
        tensor_list: List value whose producing node's inputs are the tensors
            to concatenate.
        dim (int): Axis along which to concatenate.

    Returns:
        The ONNX ``Concat`` output.
    """

    # torch.cat ignores empty tensors such as `torch.Tensor([])`. They are
    # removed from the Concat inputs too; otherwise shape inference will
    # likely fail on the rank mismatch (0 for the empty tensor, > 0 for the
    # rest). A constant whose dim-0 size is 0 or unknown counts as empty.
    def _is_empty_constant(value):
        return symbolic_helper._is_constant(
            value
        ) and not symbolic_helper._get_tensor_dim_size(value, 0)

    kept = [
        value
        for value in symbolic_helper._unpack_list(tensor_list)
        if not _is_empty_constant(value)
    ]
    assert len(kept) > 0
    assert all(
        symbolic_helper._get_tensor_rank(kept[0]) is None
        or symbolic_helper._get_tensor_rank(value) is None
        or symbolic_helper._get_tensor_rank(value)
        == symbolic_helper._get_tensor_rank(kept[0])
        for value in kept
    )

    # Rewire the list node so that it only carries the kept tensors.
    list_node = tensor_list.node()
    list_node.removeAllInputs()
    for value in kept:
        list_node.addInput(value)
    return g.op("Concat", *symbolic_helper._unpack_list(tensor_list), axis_i=dim)
@_onnx_symbolic("aten::stack")
@symbolic_helper.parse_args("v", "i")
def stack(g: jit_utils.GraphContext, tensor_list, dim):
    """Export ``aten::stack``: unsqueeze each tensor at ``dim``, then ``Concat`` on ``dim``."""
    expanded = []
    for item in symbolic_helper._unpack_list(tensor_list):
        expanded.append(symbolic_helper._unsqueeze_helper(g, item, [dim]))
    return g.op("Concat", *expanded, axis_i=dim)
@_onnx_symbolic("aten::list")
def _list(g: jit_utils.GraphContext, self):
    """Export ``aten::list`` as a no-op; the input value is passed through unchanged."""
    passthrough = self
    return passthrough
@_onnx_symbolic("aten::mm")
def mm(g: jit_utils.GraphContext, self, other):
    """Export ``aten::mm`` as an ONNX ``Gemm`` with ``alpha=1`` and ``beta=0``."""
    # Gemm takes a C operand, which is needed here only to satisfy the API:
    # with beta = 0 its value is irrelevant, so a dummy constant is used.
    dummy_c = g.op("Constant", value_t=torch.tensor([1]))
    return g.op("Gemm", self, other, dummy_c, beta_f=0.0, alpha_f=1.0)
@_onnx_symbolic("aten::bmm")
def bmm(g: jit_utils.GraphContext, self, other):
    """Export ``aten::bmm`` (batched matrix product) as ONNX ``MatMul``."""
    product = g.op("MatMul", self, other)
    return product


@_onnx_symbolic("aten::matmul")
def matmul(g: jit_utils.GraphContext, self, other):
    """Export ``aten::matmul`` as ONNX ``MatMul``."""
    product = g.op("MatMul", self, other)
    return product
@_onnx_symbolic("aten::addmm")
@symbolic_helper.parse_args("v", "v", "v", "t", "t")
def addmm(g: jit_utils.GraphContext, self, mat1, mat2, beta, alpha):
    """Export ``aten::addmm``: ``beta * self + alpha * (mat1 @ mat2)``.

    When a scalar type is known and either matrix operand is known not to be
    rank 2, the expression is decomposed into ``MatMul`` / ``Mul`` / ``Add``
    (ONNX ``Gemm`` only accepts 2-D ``A``/``B``). Otherwise a single ``Gemm``
    node is emitted.

    Args:
        g: Graph context.
        self: The tensor added to the product.
        mat1, mat2: Matrix-product operands.
        beta, alpha: Scalar tensors (parsed with ``"t"``) scaling ``self`` and
            the product respectively.
    """
    # The dtype of the first operand with a known scalar type is used for the
    # alpha/beta constants in the decomposed path.
    scalar_type = None
    for operand in (self, mat1, mat2):
        operand_type = symbolic_helper._try_get_scalar_type(operand)
        if operand_type is not None:
            scalar_type = operand_type
            break

    mat1_rank = symbolic_helper._get_tensor_rank(mat1)
    mat2_rank = symbolic_helper._get_tensor_rank(mat2)

    def is_not_none_nor(v, u):
        return v is not None and v != u

    if scalar_type is not None and (
        is_not_none_nor(mat1_rank, 2) or is_not_none_nor(mat2_rank, 2)
    ):
        res1 = g.op("MatMul", mat1, mat2)
        res2 = self
        alpha = symbolic_helper._scalar(alpha)
        beta = symbolic_helper._scalar(beta)
        if alpha != 1:
            alpha = g.op(
                "Constant", value_t=torch.tensor(alpha, dtype=scalar_type.dtype())
            )
            res1 = g.op("Mul", res1, alpha)
        if beta != 1:
            # `beta` already holds the value extracted by `_scalar` above; use
            # it directly (as the alpha branch does) instead of passing the
            # converted value through `_scalar` a second time.
            beta = g.op(
                "Constant", value_t=torch.tensor(beta, dtype=scalar_type.dtype())
            )
            res2 = g.op("Mul", res2, beta)
        return g.op("Add", res1, res2)

    return g.op(
        "Gemm",
        mat1,
        mat2,
        self,
        beta_f=symbolic_helper._scalar(beta),
        alpha_f=symbolic_helper._scalar(alpha),
    )
@_onnx_symbolic("aten::neg")
def neg(g: jit_utils.GraphContext, self):
    """Export ``aten::neg`` as ONNX ``Neg``."""
    negated = g.op("Neg", self)
    return negated
@_onnx_symbolic("aten::sqrt")
def sqrt(g: jit_utils.GraphContext, self):
    """Export ``aten::sqrt``; integer inputs are cast to FLOAT before ``Sqrt``."""
    input_type = _type_utils.JitScalarType.from_value(
        self, _type_utils.JitScalarType.UNDEFINED
    )
    integer_types = {
        _type_utils.JitScalarType.UINT8,
        _type_utils.JitScalarType.INT8,
        _type_utils.JitScalarType.INT16,
        _type_utils.JitScalarType.INT,
        _type_utils.JitScalarType.INT64,
    }
    if input_type in integer_types:
        # torch converts all int inputs to sqrt to float
        self = g.op("Cast", self, to_i=_C_onnx.TensorProtoDataType.FLOAT)
    return g.op("Sqrt", self)
@_onnx_symbolic("aten::rsqrt")
def rsqrt(g: jit_utils.GraphContext, self):
    """Export ``aten::rsqrt`` as ``1 / sqrt(self)``."""
    one = symbolic_helper._if_scalar_type_as(torch.ones(1), self)
    root = sqrt(g, self)
    return g.op("Div", one, root)
@_onnx_symbolic("aten::tanh")
# Fixed scale and zero_point, discovered from aten/src/ATen/native/quantized/cpu/qtanh.cpp
@symbolic_helper.quantized_args(True, scale=2.0 / 256.0, zero_point=128)
def tanh(g: jit_utils.GraphContext, self):
    """Export ``aten::tanh`` as ONNX ``Tanh``."""
    activated = g.op("Tanh", self)
    return activated
@_onnx_symbolic("aten::sin")
def sin(g: jit_utils.GraphContext, self):
    """Export ``aten::sin`` as ONNX ``Sin``."""
    result = g.op("Sin", self)
    return result


@_onnx_symbolic("aten::cos")
def cos(g: jit_utils.GraphContext, self):
    """Export ``aten::cos`` as ONNX ``Cos``."""
    result = g.op("Cos", self)
    return result


@_onnx_symbolic("aten::tan")
def tan(g: jit_utils.GraphContext, self):
    """Export ``aten::tan`` as ONNX ``Tan``."""
    result = g.op("Tan", self)
    return result


@_onnx_symbolic("aten::asin")
def asin(g: jit_utils.GraphContext, self):
    """Export ``aten::asin`` as ONNX ``Asin``."""
    result = g.op("Asin", self)
    return result


@_onnx_symbolic("aten::acos")
def acos(g: jit_utils.GraphContext, self):
    """Export ``aten::acos`` as ONNX ``Acos``."""
    result = g.op("Acos", self)
    return result


@_onnx_symbolic("aten::atan")
def atan(g: jit_utils.GraphContext, self):
    """Export ``aten::atan`` as ONNX ``Atan``."""
    result = g.op("Atan", self)
    return result
@_onnx_symbolic("aten::atan2")
def atan2(g: jit_utils.GraphContext, self, other):
    """Export ``aten::atan2(self, other)``, i.e. atan2(y=self, x=other).

    Opset 9 has no Atan2 op, so the result is built from ``Atan(y / x)`` plus
    a +/- pi correction wherever ``x < 0``.
    """
    # self is y, and other is x on coordinate
    slope = g.op("Div", self, other)
    atan = g.op("Atan", slope)
    const_zero = g.op("Constant", value_t=torch.tensor(0))
    const_pi = g.op("Constant", value_t=torch.tensor(math.pi))

    # For x < 0, shift the principal Atan value into the 2nd quadrant
    # (y >= 0: add pi) or the 3rd quadrant (y < 0: subtract pi). Testing
    # y >= 0 instead of y > 0 makes atan2(0, x < 0) return pi, like
    # torch.atan2, rather than -pi. (The sign of a -0.0 y is not
    # distinguished.) GreaterOrEqual is unavailable in opset 9, so
    # Not(Less(...)) is used.
    y_non_negative = g.op("Not", g.op("Less", self, const_zero))
    second_third_quadrant = g.op(
        "Where",
        y_non_negative,
        g.op("Add", atan, const_pi),
        g.op("Sub", atan, const_pi),
    )

    x_negative = g.op("Less", other, const_zero)
    result = g.op("Where", x_negative, second_third_quadrant, atan)
    return result
@_onnx_symbolic("aten::sigmoid")
# Fixed scale and zero_point, discovered from aten/src/ATen/native/quantized/cpu/qsigmoid.cpp
@symbolic_helper.quantized_args(True, scale=1.0 / 256.0, zero_point=0)
def sigmoid(g: jit_utils.GraphContext, self):
    """Export ``aten::sigmoid`` as an ONNX ``Sigmoid`` node.

    Not meant to be called directly by users.

    Args:
        g (jit_utils.GraphContext): Graph context.
        self (Tensor): The input tensor.

    Returns:
        The ONNX ``Sigmoid`` output.
    """
    activated = g.op("Sigmoid", self)
    return activated
@_onnx_symbolic("aten::sign")
def sign(g: jit_utils.GraphContext, self):
    """Export ``aten::sign`` as ONNX ``Sign``."""
    signs = g.op("Sign", self)
    return signs
@symbolic_helper.quantized_args(True)
def _slice(g: jit_utils.GraphContext, input, axes, starts, ends):
    """Emit an opset-9 ``Slice`` (attribute form).

    A single slice of ``[0, INT64_MAX)`` selects everything, so ``input`` is
    returned unchanged in that case.
    """
    assert len(starts) == len(ends)
    selects_everything = (
        len(starts) == 1 and starts[0] == 0 and ends[0] == _constants.INT64_MAX
    )
    if selects_everything:
        return input
    return g.op("Slice", input, axes_i=axes, starts_i=starts, ends_i=ends)
@_onnx_symbolic(
    "aten::sum", decorate=[symbolic_helper._apply_params("ReduceSum", "sum")]
)
@_onnx_symbolic(
    "aten::mean", decorate=[symbolic_helper._apply_params("ReduceMean", "mean")]
)
# torch.prod does not support multidimensional "dim"
@_onnx_symbolic(
    "aten::prod",
    decorate=[
        symbolic_helper._apply_params(
            "ReduceProd", "prod", allow_multi_dim_support=False
        )
    ],
)
def _reduce_with_dtype(onnx_op: str, name: str, allow_multi_dim_support: bool = True):
    """Build the symbolic for a dtype-aware reduction (sum / mean / prod).

    Registered for three ATen ops; ``_apply_params`` presumably pre-binds
    ``onnx_op`` / ``name`` (and ``allow_multi_dim_support=False`` for prod).
    The work is delegated to ``symbolic_helper._reduce_with_dtype_helper``.
    """
    reduction_symbolic = symbolic_helper._reduce_with_dtype_helper(
        onnx_op, name, allow_multi_dim_support
    )
    return reduction_symbolic
@_onnx_symbolic("aten::cumsum")
@symbolic_helper.parse_args("v", "i", "none")
def cumsum(g: jit_utils.GraphContext, input, dim, dtype):
    """Report ``aten::cumsum`` as unsupported before opset 11."""
    return symbolic_helper._onnx_opset_unsupported("cumsum", 9, 11, input)
@_onnx_symbolic("aten::_sample_dirichlet")
def _sample_dirichlet(g: jit_utils.GraphContext, self, generator):
    """Report ``aten::_sample_dirichlet`` as unsupported for ONNX export."""
    op_name = "_sample_dirichlet"
    return symbolic_helper._onnx_unsupported(op_name, self)


@_onnx_symbolic("aten::_standard_gamma")
def _standard_gamma(g: jit_utils.GraphContext, self, generator):
    """Report ``aten::_standard_gamma`` as unsupported for ONNX export."""
    op_name = "_standard_gamma"
    return symbolic_helper._onnx_unsupported(op_name, self)
@_onnx_symbolic("aten::t")
def t(g: jit_utils.GraphContext, self):
    """Export ``aten::t``: swap the two axes of a 2-D tensor.

    Inputs whose rank is unknown or below 2 go through an ``Identity`` node.
    """
    rank = symbolic_helper._get_tensor_rank(self)
    if rank is not None and rank >= 2:
        return g.op("Transpose", self, perm_i=(1, 0))
    # The transpose of a 0-d or 1-d tensor is itself. ONNX does not define
    # this case clearly and onnxruntime fails on it, so an Identity node
    # mirrors eager-mode behavior.
    return g.op("Identity", self)
@_onnx_symbolic("aten::numpy_T")
@symbolic_helper.quantized_args(True)
def numpy_T(g: jit_utils.GraphContext, input):
    """Export ``aten::numpy_T``: reverse the order of all axes. The rank must be known."""
    rank = symbolic_helper._get_tensor_rank(input)
    assert rank is not None
    reversed_axes = list(range(rank - 1, -1, -1))
    return g.op("Transpose", input, perm_i=reversed_axes)
@_onnx_symbolic("aten::expand")
@symbolic_helper.quantized_args(True)
def expand(g: jit_utils.GraphContext, self, size, implicit):
    """Export ``aten::expand`` as ONNX ``Expand`` to the requested ``size``.

    A constant ``size`` becomes an INT64 ``Constant``. A packed list of
    values is stacked into a 1-D tensor, with every -1 entry replaced by 1.
    """
    target = symbolic_helper._maybe_get_const(size, "is")
    if not symbolic_helper._is_value(target):
        target = g.op("Constant", value_t=torch.LongTensor(target))
    elif symbolic_helper._is_packed_list(target):
        # A -1 entry means "keep this dimension". onnx::Expand broadcasts in
        # both directions, so exporting it as 1 gives the same result.
        stacked = stack(g, target, 0)
        flatten_shape = g.op("Constant", value_t=torch.tensor([-1]))
        target = symbolic_helper._reshape_helper(g, stacked, flatten_shape)
        ones = ones_like(g, target, _type_utils.JitScalarType.INT64)
        minus_ones = mul(g, ones, g.op("Constant", value_t=torch.tensor(-1)))
        is_keep_dim = g.op("Equal", target, minus_ones)
        target = where(g, is_keep_dim, ones, target)
    return g.op("Expand", self, target)
@_onnx_symbolic("aten::broadcast_to")
@symbolic_helper.quantized_args(True)
def broadcast_to(g: jit_utils.GraphContext, self, size):
    """Export ``aten::broadcast_to`` as ONNX ``Expand`` (same lowering as ``expand``)."""
    shape = symbolic_helper._maybe_get_const(size, "is")
    if not symbolic_helper._is_value(shape):
        return g.op("Expand", self, g.op("Constant", value_t=torch.LongTensor(shape)))
    if symbolic_helper._is_packed_list(shape):
        # A -1 entry means "keep this dimension". onnx::Expand broadcasts in
        # both directions, so it can be exported as 1.
        stacked = stack(g, shape, 0)
        shape = symbolic_helper._reshape_helper(
            g, stacked, g.op("Constant", value_t=torch.tensor([-1]))
        )
        int64_ones = ones_like(g, shape, _type_utils.JitScalarType.INT64)
        int64_minus_ones = mul(
            g, int64_ones, g.op("Constant", value_t=torch.tensor(-1))
        )
        shape = where(g, g.op("Equal", shape, int64_minus_ones), int64_ones, shape)
    return g.op("Expand", self, shape)
@_onnx_symbolic("aten::expand_as")
@symbolic_helper.quantized_args(True, True)
def expand_as(g: jit_utils.GraphContext, self, other):
    """Expand ``self`` to the runtime shape of ``other``.

    If ``self`` is a constant tensor, it is first shrunk along every axis on
    which it is uniform. A value is uniform along an axis when its mean,
    broadcast back, equals the original. Because ``Expand`` restores those
    axes, the emitted constant is smaller.
    """
    const_self = symbolic_helper._maybe_get_const(self, "t")
    if isinstance(const_self, torch.Tensor):
        original_dtype = const_self.dtype
        as_double = const_self.to(torch.double)
        uniform_dims = [
            d
            for d in range(as_double.dim())
            if torch.equal(as_double.mean(d).unsqueeze(d).expand_as(as_double), as_double)
        ]
        compact = as_double.mean(uniform_dims, keepdim=True).to(original_dtype)
        self = g.op("Constant", value_t=compact)
    target_shape = g.op("Shape", other)
    return g.op("Expand", self, target_shape)
@_onnx_symbolic("aten::embedding")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "v", "i", "b", "v")
def embedding(
    g: jit_utils.GraphContext,
    weight,
    indices,
    padding_idx,
    scale_grad_by_freq,
    sparse,
):
    """Export ``aten::embedding`` as a ``Gather`` of rows of ``weight``.

    In training-mode export, ``scale_grad_by_freq=True`` is rejected. A
    non-negative ``padding_idx`` only triggers a warning, because gradient
    behavior cannot be expressed in ONNX.
    """
    in_training_export = GLOBALS.export_training
    if scale_grad_by_freq and in_training_export:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of embedding with scale_grad_by_freq=True "
            "for training mode. ONNX does not support scaling the gradients.",
            weight,
        )
    if padding_idx >= 0 and in_training_export:
        warnings.warn(
            "Warning: ONNX export of embedding with padding_idx >= 0 "
            "for training mode. "
            "ONNX does not support not updating the embedding vector at padding_idx during training."
        )
    return g.op("Gather", weight, indices)
@_onnx_symbolic("aten::embedding_bag")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i")
def embedding_bag(
    g: jit_utils.GraphContext,
    embedding_matrix,
    indices,
    offsets,
    scale_grad_by_freq,
    mode,
    sparse,
    per_sample_weights,
    include_last_offset,
    padding_idx,
):
    """Reject ``aten::embedding_bag``; it is not exportable at opset 9.

    Calls with ``per_sample_weights`` get a more specific message. Both
    paths now pass ``embedding_matrix`` as the context value, so they report
    the same way. Previously the ``per_sample_weights`` path omitted it.
    """
    if not symbolic_helper._is_none(per_sample_weights):
        return symbolic_helper._onnx_unsupported(
            "embedding_bag with per_sample_weights", embedding_matrix
        )
    return symbolic_helper._onnx_unsupported("embedding_bag", embedding_matrix)
@_onnx_symbolic("aten::size")
@symbolic_helper.quantized_args(True, quantize_output=False)
def size(g: jit_utils.GraphContext, self, dim=None):
    """Export ``aten::size``.

    Without ``dim``, the full shape is returned through ONNX ``Shape``. A
    constant negative ``dim`` on an input of known rank is first normalized
    to its non-negative equivalent. The result comes from
    ``symbolic_helper._size_helper``.
    """
    if dim is None:
        return g.op("Shape", self)
    dim_const = symbolic_helper._maybe_get_const(dim, "i")
    # `_maybe_get_const` returns the graph Value unchanged when `dim` is not a
    # constant; comparing that Value with 0 would raise a TypeError, so only
    # real integers are normalized.
    if not symbolic_helper._is_value(dim_const) and dim_const < 0:
        rank = symbolic_helper._get_tensor_rank(self)
        if rank is not None:
            dim = g.op("Constant", value_t=torch.tensor(dim_const + rank))
    return symbolic_helper._size_helper(g, self, dim)
@_onnx_symbolic("aten::transpose")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "i", "i")
def transpose(g: jit_utils.GraphContext, self, dim0, dim1):
    """Export ``aten::transpose`` (swap ``dim0`` and ``dim1``) as ONNX ``Transpose``.

    ONNX ``Transpose`` is a full permutation, so the input rank must be
    known. Otherwise ``errors.SymbolicValueError`` is raised.
    """
    if dim0 == dim1:  # micro-optimization: nothing to swap
        return self
    rank = symbolic_helper._get_tensor_rank(self)
    if rank is None:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of transpose for tensor of unknown rank.",
            self,
        )
    permutation = list(range(rank))
    permutation[dim0], permutation[dim1] = permutation[dim1], permutation[dim0]
    return g.op("Transpose", self, perm_i=permutation)
@_onnx_symbolic("aten::permute")
@symbolic_helper.parse_args("v", "is")
def permute(g: jit_utils.GraphContext, self, dims):
    """Export ``aten::permute``; the identity permutation returns ``self`` unchanged."""
    is_identity = all(axis == position for position, axis in enumerate(dims))
    if is_identity:
        return self
    return g.op("Transpose", self, perm_i=dims)
@_onnx_symbolic("aten::view")
@symbolic_helper.quantized_args(True)
def view(g: jit_utils.GraphContext, self, size):
    """Export ``aten::view`` as a reshape to ``size``."""
    reshaped = reshape(g, self, size)
    return reshaped


@_onnx_symbolic("aten::view_as")
def view_as(g: jit_utils.GraphContext, self, other):
    """Export ``aten::view_as`` as a reshape to the runtime shape of ``other``."""
    return reshape(g, self, g.op("Shape", other))
@_onnx_symbolic("aten::unsafe_chunk")
@symbolic_helper.parse_args("v", "i", "i", "i")
def unsafe_chunk(g: jit_utils.GraphContext, self, chunks, dim, _outputs=None):
    """Export ``aten::unsafe_chunk`` as a static ONNX ``Split``.

    Each piece has length ``ceil(size / chunks)``; any remainder becomes a
    final, shorter piece. Both the output count and the size along ``dim``
    must be known at export time.
    """
    if _outputs is None:
        return symbolic_helper._onnx_opset_unsupported_detailed(
            "unsafe_chunk", 9, 11, "Dynamic number of outputs not supported", self
        )
    dim_size = symbolic_helper._get_tensor_dim_size(self, dim)
    if dim_size is None:
        return symbolic_helper._unimplemented(
            "unsafe_chunk", "unknown dimension size", self
        )
    piece_len = (dim_size + chunks - 1) // chunks  # ceiling division
    full_pieces, remainder = divmod(dim_size, piece_len)
    piece_sizes = [piece_len] * full_pieces
    if remainder:
        piece_sizes.append(remainder)
    return g.op("Split", self, split_i=piece_sizes, axis_i=dim, outputs=_outputs)
@_onnx_symbolic("aten::split")
@symbolic_helper.parse_args("v", "v", "i", "i")
def split(g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None):
    """Export ``aten::split`` as a static ONNX ``Split``.

    A list of sizes is handed to ``split_with_sizes``. For a single split
    size, pieces of that length are produced plus a shorter trailing
    remainder. If the size along ``dim`` is unknown, it is inferred as
    ``split_size * _outputs``.
    """
    if not symbolic_helper._is_split_static(split_size_or_sizes, _outputs):
        return symbolic_helper._onnx_opset_unsupported_detailed(
            "split", 9, 11, "Dynamic number of outputs not supported", self
        )
    split_val = symbolic_helper._node_get(split_size_or_sizes.node(), "value")
    if split_val.dim() > 0:
        return split_with_sizes(g, self, split_size_or_sizes, dim, _outputs)

    piece_len = symbolic_helper._get_const(split_size_or_sizes, "i", "split_size")
    dim_size = symbolic_helper._get_tensor_dim_size(self, dim)
    if dim_size is None:
        if _outputs is None:
            return symbolic_helper._onnx_opset_unsupported_detailed(
                "split", 9, 11, "Unknown dimension size not supported", self
            )
        dim_size = piece_len * _outputs

    full_pieces, remainder = divmod(dim_size, piece_len)
    piece_sizes = [piece_len] * full_pieces
    if remainder:
        piece_sizes.append(remainder)
    return g.op("Split", self, split_i=piece_sizes, axis_i=dim, outputs=_outputs)
@_onnx_symbolic("aten::unsafe_split")
def unsafe_split(
    g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None
):
    """Export ``aten::unsafe_split`` exactly like ``split``."""
    pieces = split(g, self, split_size_or_sizes, dim, _outputs)
    return pieces
@_onnx_symbolic("aten::split_with_sizes")
@symbolic_helper.parse_args("v", "is", "i", "i")
def split_with_sizes(g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None):
    """Export ``aten::split_with_sizes`` as an ONNX ``Split`` with explicit sizes.

    Only a statically known number of outputs is supported at this opset.
    """
    if symbolic_helper._is_split_static(split_sizes, _outputs):
        return g.op("Split", self, split_i=split_sizes, axis_i=dim, outputs=_outputs)
    return symbolic_helper._onnx_opset_unsupported_detailed(
        "split_with_sizes", 9, 11, "Dynamic number of outputs not supported", self
    )
@_onnx_symbolic("aten::unsafe_split_with_sizes")
def unsafe_split_with_sizes(
    g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None
):
    """Export ``aten::unsafe_split_with_sizes`` exactly like ``split_with_sizes``."""
    pieces = split_with_sizes(g, self, split_sizes, dim, _outputs)
    return pieces
@_onnx_symbolic("aten::unbind")
@symbolic_helper.parse_args("v", "i", "i")
def unbind(g: jit_utils.GraphContext, self, dim=0, _outputs=None):
    """Export ``aten::unbind``: split ``dim`` into size-1 slices and squeeze each.

    Returns a list with one value per output. The number of outputs must be
    known at export time.
    """
    if _outputs is None:
        return symbolic_helper._onnx_opset_unsupported_detailed(
            "unbind", 9, 11, "Dynamic number of outputs not supported", self
        )
    slices = g.op("Split", self, split_i=[1] * _outputs, axis_i=dim, outputs=_outputs)
    if _outputs == 1:
        # With a single output the result is wrapped so it can be iterated.
        slices = [slices]
    return [symbolic_helper._squeeze_helper(g, piece, [dim]) for piece in slices]
@_onnx_symbolic("aten::select")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "i", "v")
def select(g: jit_utils.GraphContext, self, dim, index):
    """Export ``aten::select``: pick the slice at ``index`` along axis ``dim``.

    A constant negative ``index`` is lowered to ``Slice`` + ``Squeeze``.
    Any other index, dynamic or non-negative, uses ``Gather``.
    """
    index = symbolic_helper._maybe_get_scalar(index)
    if symbolic_helper._is_value(index) or not (index < 0):
        # FIXME(justinchuby): can index be an int and not a value?
        return g.op("Gather", self, index, axis_i=dim)
    # Negative constant: slice [index, index + 1); -1 slices to the end.
    end_index = _constants.INT64_MAX if index == -1 else index + 1
    sliced = symbolic_helper._slice_helper(
        g, self, axes=[dim], starts=[index], ends=[end_index]
    )
    return symbolic_helper._squeeze_helper(g, sliced, [dim])
@_onnx_symbolic("aten::square")
def square(g: jit_utils.GraphContext, self):
    """Export ``aten::square`` as ``self * self``."""
    squared = g.op("Mul", self, self)
    return squared
@_onnx_symbolic("aten::squeeze")
def squeeze(g: jit_utils.GraphContext, self, dim=None):
    """Export ``aten::squeeze``.

    With no ``dim``, a plain ``Squeeze`` is emitted. For a given ``dim``:

    * a negative axis is resolved against the rank known at export time
      (with a warning), or reported unimplemented if the rank is unknown;
    * an axis of unknown size is squeezed (with a warning);
    * an axis of size > 1 leaves the input unchanged (with a warning);
    * otherwise the axis is squeezed (with a warning about dynamic shapes).
    """
    if dim is None:
        return g.op("Squeeze", self)

    axis = symbolic_helper._get_const(dim, "i", "dim")

    # Handle negative dims: ONNX (at this opset) needs a non-negative axis.
    if axis < 0:
        rank = symbolic_helper._get_tensor_rank(self)
        if rank is None:
            return symbolic_helper._unimplemented(
                "squeeze", "negative axis with unknown input rank", self
            )
        warnings.warn(
            f"ONNX export squeeze with negative axis {axis} might cause the onnx "
            "model to be incorrect. Negative axis is not supported in ONNX. "
            f"Axis is converted to {axis + rank} based on input shape at export "
            "time. Passing an tensor of different rank in execution will be "
            "incorrect."
        )
        axis += rank

    axis_size = symbolic_helper._get_tensor_dim_size(self, axis)
    if axis_size is None:
        warnings.warn(
            f"This model contains a squeeze operation on dimension {axis} on an "
            f"input with unknown shape. Note that if the size of dimension {axis} "
            "of the input is not 1, the ONNX model will return an error. Opset "
            "version 11 supports squeezing on non-singleton dimensions, it is "
            "recommended to export this model using opset version 11 or higher."
        )
        return symbolic_helper._squeeze_helper(g, self, axes_i=[axis])

    if axis_size > 1:
        warnings.warn(
            f"This model contains a squeeze operation on dimension {axis}. The "
            f"size of this dimension in the given input is {axis_size}. The model "
            "will be exported without the squeeze node. If the model is intended "
            "to be used with dynamic input shapes, please use opset version 11 "
            "to export the model."
        )
        return self

    warnings.warn(
        f"This model contains a squeeze operation on dimension {axis}. If the "
        "model is intended to be used with dynamic input shapes, please use "
        "opset version 11 to export the model."
    )
    return symbolic_helper._squeeze_helper(g, self, axes_i=[axis])
@_onnx_symbolic("aten::prelu")
def prelu(g: jit_utils.GraphContext, self, weight):
    """Export ``aten::prelu`` as ONNX ``PRelu``.

    ``weight`` is adjusted to broadcast against ``self``:

    * for inputs of rank > 2 it is unsqueezed at axes ``1 .. rank-2``;
    * for a 0-d input with a ``[1]``-shaped weight it is squeezed to 0-d.

    When both ranks are known, ``rank(self) >= rank(weight)`` is asserted.
    """
    self_rank = symbolic_helper._get_tensor_rank(self)
    weight_sizes = symbolic_helper._get_tensor_sizes(weight)
    # `_get_tensor_sizes` can return None when the weight's shape is unknown;
    # previously `len(None)` raised a TypeError here. Leaving the rank as
    # None skips the rank check below instead.
    weight_rank = len(weight_sizes) if weight_sizes is not None else None
    if self_rank is not None:
        if self_rank > 2:
            # make weight unidirectional broadcastable
            weight = symbolic_helper._unsqueeze_helper(
                g, weight, list(range(1, self_rank - 1))
            )
        elif self_rank == 0 and weight_sizes == [1]:
            # self and weight are both scalar but weight has rank == 1, squeeze weight.
            weight = symbolic_helper._squeeze_helper(g, weight, [0])
            weight_rank = 0

    if self_rank is not None and weight_rank is not None:
        assert (
            self_rank >= weight_rank
        ), f"rank(x) should be >= rank(slope) but got {self_rank} < {weight_rank}"
    return g.op("PRelu", self, weight)
@_onnx_symbolic("aten::silu")
def silu(g: jit_utils.GraphContext, input):
    """Export ``aten::silu`` as ``input * Sigmoid(input)``."""
    gate = g.op("Sigmoid", input)
    return g.op("Mul", input, gate)
@_onnx_symbolic("aten::mish")
def mish(g: jit_utils.GraphContext, input):
    """Export ``aten::mish`` as ``input * Tanh(Softplus(input))``."""
    smoothed = g.op("Softplus", input)
    squashed = g.op("Tanh", smoothed)
    return g.op("Mul", input, squashed)
@_onnx_symbolic("aten::relu")
@symbolic_helper.quantized_args(True)
def relu(g: jit_utils.GraphContext, input):
    """Export ``aten::relu`` as ONNX ``Relu``.

    The op is emitted through ``_op_with_optional_float_cast`` with
    ``opset_before=14``; presumably this inserts float casts for input types
    that ``Relu`` before opset 14 does not accept.
    """
    return symbolic_helper._op_with_optional_float_cast(
        g,
        "Relu",
        input,
        opset_before=14,
    )
@_onnx_symbolic("aten::relu6")
@symbolic_helper.quantized_args(True)
def relu6(g: jit_utils.GraphContext, input):
    """Export ``aten::relu6`` by clamping the input to ``[0, 6]``."""
    lower_bound, upper_bound = 0, 6
    return clamp(g, input, lower_bound, upper_bound)
@_onnx_symbolic("aten::ceil")
def ceil(g: jit_utils.GraphContext, input):
    """Export ``aten::ceil`` as the elementwise ONNX ``Ceil`` op."""
    onnx_op_type = "Ceil"
    return g.op(onnx_op_type, input)
@_onnx_symbolic("aten::floor")
def floor(g: jit_utils.GraphContext, input):
    """Export ``aten::floor`` as the elementwise ONNX ``Floor`` op."""
    onnx_op_type = "Floor"
    return g.op(onnx_op_type, input)
@_onnx_symbolic("aten::len")
def _len(g: jit_utils.GraphContext, self):
    """Export ``aten::len`` as the size of dimension 0, squeezed to a scalar."""
    first_axis = g.op("Constant", value_t=torch.LongTensor([0]))
    leading_size = size(g, self, first_axis)
    return symbolic_helper._squeeze_helper(g, leading_size, [0])
@_onnx_symbolic("aten::threshold")
@symbolic_helper.parse_args("v", "t", "t")
def threshold(g: jit_utils.GraphContext, self, threshold, value):
    """Export ``aten::threshold`` for the ``threshold == 0, value == 0`` case.

    That configuration is exported as ONNX ``Relu``. Any non-zero threshold or
    replacement value is reported through ``_unimplemented``; the threshold
    is checked first.
    """
    # See Note [Export inplace]
    requirements = (
        (threshold, "non-zero threshold"),
        (value, "non-zero value"),
    )
    for scalar_arg, reason in requirements:
        if symbolic_helper._scalar(scalar_arg) != 0:
            return symbolic_helper._unimplemented("threshold", reason, self)
    return g.op("Relu", self)
@_onnx_symbolic("aten::leaky_relu")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "f", "b")
def leaky_relu(
    g: jit_utils.GraphContext,
    input: _C.Value,
    negative_slope: float,
    inplace: bool = False,
):
    """Export ``aten::leaky_relu`` as ONNX ``LeakyRelu`` with ``alpha = negative_slope``.

    ``inplace`` is accepted for signature compatibility and not used.
    """
    # See Note [Export inplace]
    attributes = {"alpha_f": negative_slope}
    return g.op("LeakyRelu", input, **attributes)
@_onnx_symbolic("aten::glu")
@symbolic_helper.parse_args("v", "i")
def glu(g: jit_utils.GraphContext, input, dim):
    """Export ``aten::glu``: split ``input`` in two along ``dim`` and return ``a * Sigmoid(b)``.

    When the size along ``dim`` is known at export time it is asserted to be
    even.
    """
    size_along_dim = symbolic_helper._get_tensor_dim_size(input, dim)
    if size_along_dim is not None:
        assert size_along_dim % 2 == 0
    halves = g.op("Split", input, axis_i=dim, outputs=2)
    value_half, gate_half = halves
    return g.op("Mul", value_half, g.op("Sigmoid", gate_half))
@_onnx_symbolic("aten::softmax")
@symbolic_helper.parse_args("v", "i", "none")
def softmax(g: jit_utils.GraphContext, input, dim, dtype=None):
    """Export ``aten::softmax``.

    Args:
        g: Graph context used to emit ONNX nodes.
        input: Input tensor value.
        dim: Dimension along which to normalize (may be negative).
        dtype: Optional dtype value. When present and its node is not a
            ``prim::Constant``, the result is ``Cast`` to that dtype.

    Returns:
        The softmax output value.
    """
    # Softmax does normalization at vector level.
    # PyTorch and ONNX use different strategies to split the input tensor into vectors.
    # Thus dim and axis have different meanings.
    # PyTorch slices the input tensor into vectors along the `dim`-th dimension.
    # ONNX reshapes the input into a 2-D tensor, and `axis` indicates where the input is coerced.
    # If input is a 2 x 3 tensor:
    # input = [[1.0, 1.0, 1.0],
    #          [1.0, 1.0, 1.0]]
    # with dim = 0, the result is:
    # result = [[0.5, 0.5, 0.5],
    #           [0.5, 0.5, 0.5]]
    # with axis = 0, the result is:
    # result = [[0.167, 0.167, 0.167],
    #           [0.167, 0.167, 0.167]]
    # So only when dim and axis both equal to ndim - 1 (the last dimension),
    # their semantics are equivalent.
    # So use softmax when dim and axis both equal to ndim - 1,
    # otherwise transpose the input to put the vectors to be normalized to the last dimension.
    # When input rank is not known at export time we compute softmax using a subgraph
    # with other operators
    input_dim = symbolic_helper._get_tensor_rank(input)
    if input_dim is not None:
        # TODO: remove this as onnx opset 11 spec allows negative axes
        if dim < 0:
            dim = input_dim + dim
        is_transpose_required = input_dim != dim + 1
        if is_transpose_required:
            # Swap `dim` with the last axis; the same permutation undoes itself below.
            axes = list(range(input_dim))
            axes[dim], axes[-1] = axes[-1], axes[dim]
            input = g.op("Transpose", input, perm_i=axes)
            dim = input_dim - 1
        softmax = g.op("Softmax", input, axis_i=dim)
        if dtype and dtype.node().kind() != "prim::Constant":
            parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype")
            softmax = g.op(
                "Cast",
                softmax,
                to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type(),
            )
        if is_transpose_required:
            softmax = g.op("Transpose", softmax, perm_i=axes)  # type: ignore[possibly-undefined]
        return softmax
    # Unknown rank: compute exp(x - max) / sum(exp(x - max)) along `dim` with
    # primitive ops instead of ONNX Softmax.
    # Apply max normalization.
    input = g.op("Sub", input, g.op("ReduceMax", input, axes_i=[dim], keepdims_i=1))
    exp = g.op("Exp", input)
    sum = symbolic_helper._reducesum_helper(g, exp, axes_i=[dim])
    softmax = g.op("Div", exp, sum)
    if dtype and dtype.node().kind() != "prim::Constant":
        parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype")
        softmax = g.op(
            "Cast", softmax, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type()
        )
    return softmax
@_onnx_symbolic("aten::softplus")
def softplus(g: jit_utils.GraphContext, self, beta, threshold):
    """Export ``aten::softplus`` as ``Softplus(beta * x) / beta``.

    When ``beta`` is the constant 1 a bare ``Softplus`` is emitted.
    ``threshold`` is not used by this export.
    """
    beta_value = symbolic_helper._maybe_get_const(beta, "f")
    if beta_value == 1:
        return g.op("Softplus", self)
    scaled_input = g.op("Mul", self, beta)
    return g.op("Div", g.op("Softplus", scaled_input), beta)
@_onnx_symbolic("aten::get_pool_ceil_padding")
def get_pool_ceil_padding(input, kernel_size, stride, padding):
    """Compute extra end-side padding that emulates ``ceil_mode=True`` pooling.

    The max/avg pool symbolics in this file add the returned values to the
    end-side pads, so that a pool without ceil mode yields the ceil-mode
    output size.

    NOTE(review): registered via ``_onnx_symbolic`` but takes no ``g``
    argument; in this file it is only called directly by the pooling
    symbolics.

    Args:
        input: Input tensor value; sizes of its trailing ``len(padding)`` dims
            must be known at export time.
        kernel_size: Per-spatial-dim kernel sizes.
        stride: Per-spatial-dim strides.
        padding: Per-spatial-dim (symmetric) paddings.

    Returns:
        A list of ints (one per spatial dim), or the result of
        ``_unimplemented`` when the spatial sizes are not accessible.
    """
    # TODO(justinchuby): Looks like this op is deprecated in torch
    sizes = symbolic_helper._get_tensor_sizes(input)
    # Spatial sizes are the trailing len(padding) dims.
    dim = sizes[-len(padding) :] if sizes is not None else None
    if dim is None or any(i is None for i in dim):
        return symbolic_helper._unimplemented(
            "get_pool_ceil_padding", "input size not accessible", input
        )
    # Output size per dim when rounding up: ceil((in + 2p - k) / s) + 1.
    ceiled_output_dim = [
        int(math.ceil((dim[i] + 2 * padding[i] - kernel_size[i]) / float(stride[i])))
        + 1
        for i in range(0, len(padding))
    ]
    # ensure last pooling starts inside
    # (drop the last window if it would start at or beyond input + left pad)
    ceiled_output_dim = [
        (
            ceiled_output_dim[i] - 1
            if (((ceiled_output_dim[i] - 1) * stride[i]) >= (dim[i] + padding[i]))
            else ceiled_output_dim[i]
        )
        for i in range(0, len(ceiled_output_dim))
    ]
    # Extra end padding so the last window fits: kernel minus the input span
    # left after the start of the last window. Stride 1 needs none.
    padding_ceil = [
        (
            0
            if (stride[i] == 1)
            else (
                kernel_size[i]
                - (
                    dim[i]
                    + 2 * padding[i]
                    - ((ceiled_output_dim[i] - 1) * stride[i] + 1)
                )
            )
        )
        for i in range(0, len(padding))
    ]
    # ensure padding is not > kernel_size
    # (clamped to kernel_size - 1, only when total padding reaches the kernel size)
    padding_ceil = [
        (
            (
                int(padding_ceil[i])
                if padding_ceil[i] < kernel_size[i] - 1
                else int(kernel_size[i] - 1)
            )
            if ((padding_ceil[i] + 2 * padding[i]) >= (kernel_size[i]))
            else int(padding_ceil[i])
        )
        for i in range(0, len(padding_ceil))
    ]
    return padding_ceil
@_onnx_symbolic(
    "aten::max_pool1d",
    decorate=[
        symbolic_helper._apply_params(
            "max_pool1d", torch.nn.modules.utils._single, 1, return_indices=False
        ),
        _export("max_pool1d"),
    ],
)
@_onnx_symbolic(
    "aten::max_pool2d",
    decorate=[
        symbolic_helper._apply_params(
            "max_pool2d", torch.nn.modules.utils._pair, 2, return_indices=False
        ),
        _export("max_pool2d"),
    ],
)
@_onnx_symbolic(
    "aten::max_pool3d",
    decorate=[
        symbolic_helper._apply_params(
            "max_pool3d", torch.nn.modules.utils._triple, 3, return_indices=False
        ),
        _export("max_pool3d"),
    ],
)
def _max_pool(name, tuple_fn, ndims, return_indices):
    """Build a symbolic for ``aten::max_pool{1,2,3}d`` (optionally with indices).

    Args:
        name: Operator name used in "unimplemented" diagnostics.
        tuple_fn: Expands an int or sequence to one value per spatial dim
            (``_single`` / ``_pair`` / ``_triple``).
        ndims: Number of spatial dimensions.
        return_indices: Whether the symbolic also returns indices converted to
            PyTorch's per-sample layout.

    Returns:
        ``symbolic_fn(g, input, kernel_size, stride, padding, dilation,
        ceil_mode)`` that emits ONNX ``MaxPool``.
    """
    @symbolic_helper.quantized_args(True, False, False, False, False, False)
    @symbolic_helper.parse_args("v", "is", "is", "is", "is", "i")
    def symbolic_fn(g, input, kernel_size, stride, padding, dilation, ceil_mode):
        # Only unit dilation is exported.
        if set(tuple_fn(dilation)) != {1}:
            return symbolic_helper._unimplemented(name, "dilation", input)
        # An empty stride falls back to the kernel size.
        if not stride:
            stride = kernel_size
        padding = tuple(tuple_fn(padding))
        # ONNX `pads` holds all begin values followed by all end values; with
        # ceil_mode the end side gets the extra padding from get_pool_ceil_padding.
        if ceil_mode:
            padding_ceil = get_pool_ceil_padding(input, kernel_size, stride, padding)
            padding = padding + tuple(a + b for (a, b) in zip(padding_ceil, padding))
        else:
            padding = padding * 2
        kwargs = {
            "kernel_shape_i": tuple_fn(kernel_size),
            "pads_i": padding,
            "strides_i": tuple_fn(stride),
        }
        # easy but hacky way to get flattened indices values
        # to be used to convert the indices values to non-flattened.
        # In ONNX the indices are computed as a flatten 1-D tensor,
        # so the values in indices are in [0, N x C x D1 x ... x Dn).
        # To convert the indices to the same format used by Pytorch,
        # we first execute a maxpool with a kernel and stride of 1 on the same input.
        # This will result in a tensor of indices in which each index will have its own value.
        # Using this tensor as a reference, we extract the first index of each axis and subtract
        # it from each index of this axis in the indices to convert.
        # This step will result in a tensor where each dimension has values of indices within
        # the dimension it is in.
        # For more information :
        # https://github.com/pytorch/pytorch/pull/16455#issuecomment-460776407
        if return_indices:
            r, indices = g.op("MaxPool", input, outputs=2, **kwargs)
            _, flattened_indices = g.op(
                "MaxPool",
                input,
                outputs=2,
                kernel_shape_i=[1 for _ in range(ndims)],
                strides_i=[1 for _ in range(ndims)],
            )
            # convert indices to have non-flattened indices values
            s = symbolic_helper._slice_helper(
                g,
                flattened_indices,
                axes=[2 + i for i in range(ndims)],
                starts=list(tuple_fn(0)),
                ends=list(tuple_fn(1)),
            )
            indices = sub(g, indices, s)
            return r, indices
        else:
            r = g.op("MaxPool", input, outputs=1, **kwargs)
            return r
    return symbolic_fn
# Symbolics for ``aten::max_pool{1,2,3}d_with_indices``: built by ``_max_pool``
# with ``return_indices=True`` so they return ``(values, indices)``. These
# module-level names are also handed to the adaptive max-pool symbolics as the
# function that produces indices.
max_pool1d_with_indices = _onnx_symbolic("aten::max_pool1d_with_indices")(
    _max_pool(
        "max_pool1d_with_indices",
        torch.nn.modules.utils._single,
        1,
        return_indices=True,
    )
)
max_pool2d_with_indices = _onnx_symbolic("aten::max_pool2d_with_indices")(
    _max_pool(
        "max_pool2d_with_indices",
        torch.nn.modules.utils._pair,
        2,
        return_indices=True,
    )
)
max_pool3d_with_indices = _onnx_symbolic("aten::max_pool3d_with_indices")(
    _max_pool(
        "max_pool3d_with_indices",
        torch.nn.modules.utils._triple,
        3,
        return_indices=True,
    )
)
@_onnx_symbolic(
    "aten::avg_pool1d",
    decorate=[
        symbolic_helper._apply_params("avg_pool1d", torch.nn.modules.utils._single),
        _export("avg_pool1d"),
    ],
)
@_onnx_symbolic(
    "aten::avg_pool2d",
    decorate=[
        symbolic_helper._apply_params("avg_pool2d", torch.nn.modules.utils._pair),
        _export("avg_pool2d"),
    ],
)
@_onnx_symbolic(
    "aten::avg_pool3d",
    decorate=[
        symbolic_helper._apply_params("avg_pool3d", torch.nn.modules.utils._triple),
        _export("avg_pool3d"),
    ],
)
def _avg_pool(name, tuple_fn):
    """Build a symbolic for ``aten::avg_pool{1,2,3}d`` emitting ONNX ``AveragePool``.

    Args:
        name: Operator name forwarded to ``_avgpool_helper`` (presumably for
            diagnostics).
        tuple_fn: Expands an int or sequence to one value per spatial dim
            (``_single`` / ``_pair`` / ``_triple``).

    Returns:
        The symbolic function for the given dimensionality.
    """
    @symbolic_helper.quantized_args(True)
    @symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none")
    def symbolic_fn(
        g,
        input: _C.Value,
        kernel_size: Sequence[int],
        stride: Sequence[int],
        padding: int | Sequence[int],
        ceil_mode: int,
        count_include_pad: int,
        divisor_override=None,
    ):
        # An empty stride falls back to the kernel size.
        if not stride:
            stride = kernel_size
        # Normalizes `padding` to a per-dim tuple (asserted below); presumably
        # also validates divisor_override.
        padding = symbolic_helper._avgpool_helper(
            tuple_fn, padding, kernel_size, stride, divisor_override, name
        )
        assert isinstance(padding, tuple)
        adjusted_padding = padding
        # Although onnx::AvgPool provides count_include_pad,
        # The corner case of Average Pooling with ceil_mode on
        # PyTorch allows sliding window go off bound, which leads to
        # this accommodation.
        # More detail on https://github.com/pytorch/pytorch/issues/57178
        if count_include_pad:
            # Zero-pad the spatial dims explicitly (batch/channel dims get 0),
            # so the padded zeros are counted, then pool without padding.
            input = symbolic_helper._op_with_optional_float_cast(
                g,
                "Pad",
                input,
                pads_i=((0,) * 2 + padding) * 2,
                mode_s="constant",
                value_f=0.0,
                opset_before=11,
            )
            adjusted_padding = (0,) * len(padding)
        # ONNX `pads` holds all begin values followed by all end values.
        if ceil_mode:
            padding_ceil = get_pool_ceil_padding(input, kernel_size, stride, padding)
            adjusted_padding = adjusted_padding + tuple(
                a + b for (a, b) in zip(padding_ceil, adjusted_padding)
            )
        else:
            adjusted_padding = adjusted_padding * 2
        output = g.op(
            "AveragePool",
            input,
            kernel_shape_i=tuple_fn(kernel_size),
            strides_i=tuple_fn(stride),
            pads_i=adjusted_padding,
        )
        return output
    return symbolic_fn
@_onnx_symbolic(
    "aten::adaptive_avg_pool1d",
    decorate=[
        symbolic_helper._apply_params(
            "adaptive_avg_pool1d", "AveragePool", torch.nn.modules.utils._single
        ),
        _export("adaptive_avg_pool1d"),
    ],
)
@_onnx_symbolic(
    "aten::adaptive_avg_pool2d",
    decorate=[
        symbolic_helper._apply_params(
            "adaptive_avg_pool2d", "AveragePool", torch.nn.modules.utils._pair
        ),
        _export("adaptive_avg_pool2d"),
    ],
)
@_onnx_symbolic(
    "aten::adaptive_avg_pool3d",
    decorate=[
        symbolic_helper._apply_params(
            "adaptive_avg_pool3d", "AveragePool", torch.nn.modules.utils._triple
        ),
        _export("adaptive_avg_pool3d"),
    ],
)
@_onnx_symbolic(
    "aten::adaptive_max_pool1d",
    decorate=[
        symbolic_helper._apply_params(
            "adaptive_max_pool1d",
            "MaxPool",
            torch.nn.modules.utils._single,
            max_pool1d_with_indices,
        ),
        _export("adaptive_max_pool1d"),
    ],
)
@_onnx_symbolic(
    "aten::adaptive_max_pool2d",
    decorate=[
        symbolic_helper._apply_params(
            "adaptive_max_pool2d",
            "MaxPool",
            torch.nn.modules.utils._pair,
            max_pool2d_with_indices,
        ),
        _export("adaptive_max_pool2d"),
    ],
)
@_onnx_symbolic(
    "aten::adaptive_max_pool3d",
    decorate=[
        symbolic_helper._apply_params(
            "adaptive_max_pool3d",
            "MaxPool",
            torch.nn.modules.utils._triple,
            max_pool3d_with_indices,
        ),
        _export("adaptive_max_pool3d"),
    ],
)
def _adaptive_pool(name, type, tuple_fn, fn=None):
    """Build a symbolic for ``aten::adaptive_{avg,max}_pool{1,2,3}d``.

    Export is supported when either every output size is 1 (emitted as a
    global pool) or every spatial input size is a multiple of the matching
    output size. In the latter case the window is uniform per dimension, so
    ``kernel == stride == input_size // output_size``.

    Args:
        name: Operator name used in "unimplemented" diagnostics.
        type: ONNX pooling op, ``"AveragePool"`` or ``"MaxPool"``.
        tuple_fn: Expands an int or sequence to one value per spatial dim
            (``_single`` / ``_pair`` / ``_triple``).
        fn: For ``"MaxPool"``, the ``max_pool*d_with_indices`` symbolic used
            so that indices are produced as well.

    Returns:
        ``symbolic_fn(g, input, output_size)``.
    """

    @symbolic_helper.quantized_args(True, False)
    def symbolic_fn(g, input, output_size):
        # For MaxPool, GlobalMaxPool does not return indices, so the
        # max_pool*d_with_indices symbolic (`fn`) is preferred. When that is not
        # possible (input shape not fully known, or output size not a factor
        # of the input size) but every output size is 1, GlobalMaxPool is
        # emitted and None is returned for the indices.
        output_size_value = output_size
        try:
            output_size = symbolic_helper._parse_arg(output_size, "is")
        except Exception:
            # FIXME(justinchuby): Avoid catching Exception.
            # Catch a more specific exception instead.
            return symbolic_helper._onnx_unsupported(
                "adaptive pooling, since output_size is not constant.", input
            )
        is_global = output_size == [1] * len(output_size)
        if is_global and type == "AveragePool":
            return g.op("GlobalAveragePool", input)
        sizes = symbolic_helper._get_tensor_sizes(input)
        # Spatial sizes follow the first two (batch, channel) dims. `sizes` is
        # None when the input shape is unknown at export time.
        dim = sizes[2:] if sizes is not None else None
        if dim is None or any(i is None for i in dim):
            if is_global:
                return g.op("GlobalMaxPool", input), None
            return symbolic_helper._unimplemented(
                name, "input size not accessible", input
            )
        # verify that input size % output size == 0 for all dims
        mod = [dim[i] % output_size[i] for i in range(len(dim))]
        if mod != [0] * len(mod):
            if is_global:
                return g.op("GlobalMaxPool", input), None
            return symbolic_helper._unimplemented(
                name, "output size that are not factor of input size", output_size_value
            )
        # Exact division is guaranteed by the check above.
        k = [dim[i] // output_size[i] for i in range(len(dim))]
        # call max_poolxd_with_indices to get indices in the output
        if type == "MaxPool":
            return fn(g, input, k, k, (0,) * len(dim), (1,) * len(dim), False)
        output = g.op(type, input, kernel_shape_i=tuple_fn(k), strides_i=tuple_fn(k))
        return output

    return symbolic_fn
def _prepare_onnx_paddings(dim: int, pad):
"""Generate paddings in ONNX order based on pad in pytorch.
Args:
dim: the dimension of the tensor.
pad: the paddings in pytorch.
The order is dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, ...
"""
# The desired order of paddings is
# dim_0_begin, dim_1_begin, ... , dim_0_end, ..., dim_n_end.
# n is the dimension of input.
# assume zero-dimensions in the beginning
paddings = list(pad[:]) + [0] * (dim * 2 - len(pad))
# reverse order and collate first beginnings and then ends
paddings = paddings[-2::-2] + paddings[-1::-2]
return paddings
def _convert_padding_node(input):
    """Resolve a padding argument to a Python list of ints where possible.

    A constant padding is returned as-is from ``_maybe_get_const``. A packed
    list value is unpacked and each element must be a constant int;
    otherwise ``_onnx_opset_unsupported_detailed`` is invoked (constant
    padding sizes are required before opset 11).
    """
    padding = symbolic_helper._maybe_get_const(input, "is")
    is_packed_value = symbolic_helper._is_value(
        padding
    ) and symbolic_helper._is_packed_list(padding)
    if not is_packed_value:
        return padding
    elements = symbolic_helper._unpack_list(padding)
    try:
        return [symbolic_helper._get_const(v, "i", "padding") for v in elements]
    except Exception:
        # FIXME(justinchuby): Avoid catching Exception.
        # Catch a more specific exception instead.
        return symbolic_helper._onnx_opset_unsupported_detailed(
            "Pad", 9, 11, "The sizes of the padding must be constant", input
        )
@_onnx_symbolic("aten::constant_pad_nd")
def constant_pad_nd(g: jit_utils.GraphContext, input, padding, value):
    """Export ``aten::constant_pad_nd`` as ONNX ``Pad`` in ``constant`` mode.

    Both the fill ``value`` and the padding sizes must be constants at export
    time; otherwise ``_onnx_opset_unsupported_detailed`` is invoked.
    """
    try:
        fill_value = symbolic_helper._get_const(value, "f", "value")
    except Exception:
        # FIXME(justinchuby): Avoid catching Exception.
        # Catch a more specific exception instead.
        return symbolic_helper._onnx_opset_unsupported_detailed(
            "Pad", 9, 11, "The value for the padding must be constant", value
        )
    torch_pads = _convert_padding_node(padding)
    rank = symbolic_helper._get_tensor_rank(input)
    onnx_pads = _prepare_onnx_paddings(rank, torch_pads)
    return symbolic_helper._op_with_optional_float_cast(
        g,
        "Pad",
        input,
        pads_i=onnx_pads,
        mode_s="constant",
        value_f=fill_value,
        opset_before=11,
    )
def _pad_circular(g: jit_utils.GraphContext, input: _C.Value, pad: _C.Value):
    """Export circular padding by slicing and concatenating along each axis.

    ``pad`` is in PyTorch order (last dimension first, a ``(begin, end)`` pair
    per dimension). The ``len(pad) // 2`` padded axes are taken to start at
    axis 2, i.e. the input is assumed to have two leading non-padded dims
    (presumably batch and channel).

    Negative values crop that side. The wrapped-around pieces are taken from
    the *cropped* tensor, matching ATen's circular padding: per axis the
    output is ``cat(middle[-pad_l:], middle, middle[:pad_r])`` where
    ``middle`` is the input with negative pads cropped away.

    Args:
        g: Graph context used to emit ONNX nodes.
        input: Tensor value to pad.
        pad: Padding sizes; must resolve to constant ints.

    Returns:
        The padded tensor value.
    """
    padding = _convert_padding_node(pad)
    assert len(padding) % 2 == 0
    ndim = len(padding) // 2
    cur = input
    for idx in range(ndim):
        axis = 2 + idx
        # Pairs are listed last-dimension-first, so axis 2 + idx reads its
        # pair counting from the end of the list.
        pad_r = padding[-(2 * idx + 1)]
        pad_l = padding[-(2 * idx + 2)]

        # Crop negative pads first.
        if pad_l < 0 or pad_r < 0:
            start = builtins.max(0, -pad_l)
            # A slice end of 0 would select nothing, so "no right crop" must
            # slice through to the end of the axis.
            end = pad_r if pad_r < 0 else _constants.INT64_MAX
            middle = symbolic_helper._slice_helper(
                g, cur, axes=[axis], starts=[start], ends=[end]
            )
        else:
            middle = cur

        tensors = []
        if pad_l > 0:
            left = symbolic_helper._slice_helper(
                g, middle, axes=[axis], starts=[-pad_l], ends=[_constants.INT64_MAX]
            )
            tensors.append(left)
        tensors.append(middle)
        if pad_r > 0:
            right = symbolic_helper._slice_helper(
                g, middle, axes=[axis], starts=[0], ends=[pad_r]
            )
            tensors.append(right)
        cur = g.op("Concat", *tensors, axis_i=axis)
    return cur
@_onnx_symbolic("aten::reflection_pad1d")
@_onnx_symbolic("aten::reflection_pad2d")
@_onnx_symbolic("aten::reflection_pad3d")
def reflection_pad(g: jit_utils.GraphContext, input, padding):
    """Export ``aten::reflection_pad{1,2,3}d`` as ONNX ``Pad`` with ``mode="reflect"``."""
    torch_pads = _convert_padding_node(padding)
    onnx_pads = _prepare_onnx_paddings(
        symbolic_helper._get_tensor_rank(input), torch_pads
    )
    return symbolic_helper._op_with_optional_float_cast(
        g,
        "Pad",
        input,
        pads_i=onnx_pads,
        mode_s="reflect",
        opset_before=11,
    )
@_onnx_symbolic("aten::replication_pad1d")
@_onnx_symbolic("aten::replication_pad2d")
@_onnx_symbolic("aten::replication_pad3d")
def replication_pad(g: jit_utils.GraphContext, input, padding):
    """Export ``aten::replication_pad{1,2,3}d`` as ONNX ``Pad`` with ``mode="edge"``."""
    torch_pads = _convert_padding_node(padding)
    onnx_pads = _prepare_onnx_paddings(
        symbolic_helper._get_tensor_rank(input), torch_pads
    )
    return symbolic_helper._op_with_optional_float_cast(
        g,
        "Pad",
        input,
        pads_i=onnx_pads,
        mode_s="edge",
        opset_before=11,
    )
@_onnx_symbolic("aten::pad")
def pad(
    g: jit_utils.GraphContext,
    input: _C.Value,
    pad: _C.Value,
    mode: _C.Value,
    value: _C.Value,
):
    """Export ``aten::pad`` by dispatching on the constant ``mode`` string.

    Supported modes: ``replicate``, ``reflect``, ``constant`` and
    ``circular``.

    Raises:
        errors.SymbolicValueError: For any other mode.
    """
    mode = symbolic_helper._parse_arg(mode, "s")
    if mode == "constant":
        # Only the constant mode consumes the fill value.
        return constant_pad_nd(g, input, pad, value)
    handlers = {
        "replicate": replication_pad,
        "reflect": reflection_pad,
        "circular": _pad_circular,
    }
    handler = handlers.get(mode)
    if handler is None:
        raise errors.SymbolicValueError(f"Unrecognized padding mode {mode}", input)
    return handler(g, input, pad)
@_onnx_symbolic(
    "aten::upsample_nearest1d",
    decorate=[
        symbolic_helper._apply_params("upsample_nearest1d", 3, "nearest"),
        _export("upsample_nearest1d"),
    ],
)
@_onnx_symbolic(
    "aten::upsample_nearest2d",
    decorate=[
        symbolic_helper._apply_params("upsample_nearest2d", 4, "nearest"),
        _export("upsample_nearest2d"),
    ],
)
@_onnx_symbolic(
    "aten::upsample_nearest3d",
    decorate=[
        symbolic_helper._apply_params("upsample_nearest3d", 5, "nearest"),
        _export("upsample_nearest3d"),
    ],
)
@_onnx_symbolic(
    "aten::upsample_linear1d",
    decorate=[
        symbolic_helper._apply_params("upsample_linear1d", 3, "linear"),
        _export("upsample_linear1d"),
    ],
)
@_onnx_symbolic(
    "aten::upsample_bilinear2d",
    decorate=[
        symbolic_helper._apply_params("upsample_bilinear2d", 4, "linear"),
        _export("upsample_bilinear2d"),
    ],
)
@_onnx_symbolic(
    "aten::upsample_trilinear3d",
    decorate=[
        symbolic_helper._apply_params("upsample_trilinear3d", 5, "linear"),
        _export("upsample_trilinear3d"),
    ],
)
def _interpolate(name: str, dim: int, interpolate_mode: str):
    """Build a symbolic for ``aten::upsample_*`` using ONNX ``Upsample``.

    Args:
        name: Operator name used in "unimplemented" diagnostics.
        dim: Input rank (3, 4 or 5), used when deriving scales from sizes.
        interpolate_mode: ``"nearest"`` or ``"linear"``.

    Returns:
        ``symbolic_fn(g, input, output_size, *args)``. ``align_corners=True``
        is reported as unimplemented.
    """

    def symbolic_fn(g, input, output_size, *args):
        scales, align_corners = symbolic_helper._get_interpolate_attributes(
            g, interpolate_mode, args
        )
        symbolic_helper._interpolate_warning(interpolate_mode)
        if symbolic_helper._maybe_get_scalar(align_corners):
            return symbolic_helper._unimplemented(name, "align_corners == True", input)
        # Without explicit scales, derive them from the requested output size.
        if scales is None:
            scales = symbolic_helper._interpolate_size_to_scales(
                g, input, output_size, dim
            )
        return g.op("Upsample", input, scales, mode_s=interpolate_mode)

    return symbolic_fn
@_onnx_symbolic("aten::__interpolate")
def __interpolate(
    g: jit_utils.GraphContext,
    input,
    size,
    scale_factor,
    mode,
    align_corners,
    recompute_scale_factor,
    antialias,
):
    """Export ``aten::__interpolate`` as ONNX ``Upsample``.

    Scales and mode are resolved by ``_interpolate_get_scales_and_mode``;
    ``recompute_scale_factor`` and ``antialias`` are not used here.
    """
    resolved_scales, resolved_mode = symbolic_helper._interpolate_get_scales_and_mode(
        g, input, size, scale_factor, mode, align_corners
    )
    return g.op("Upsample", input, resolved_scales, mode_s=resolved_mode)
@_onnx_symbolic("aten::bitwise_not")
def bitwise_not(g: jit_utils.GraphContext, input):
    """Export ``aten::bitwise_not`` as ONNX ``Not`` (boolean inputs only).

    Raises:
        errors.SymbolicValueError: If ``input`` is not boolean.
    """
    if symbolic_helper._is_bool(input):
        return g.op("Not", input)
    raise errors.SymbolicValueError(
        "ONNX export does NOT support exporting bitwise Not "
        "for non-boolean input values",
        input,
    )
@_onnx_symbolic("aten::bitwise_or")
def bitwise_or(g, self, other):
    """Export ``aten::bitwise_or`` as ONNX ``Or`` (boolean operands only).

    Raises:
        errors.SymbolicValueError: If ``self`` (checked first) or ``other`` is
            not boolean.
    """
    for label, operand in (("self", self), ("other", other)):
        if not symbolic_helper._is_bool(operand):
            raise errors.SymbolicValueError(
                "ONNX export does NOT support exporting bitwise OR "
                f"for non-boolean input values. {label}: ",
                operand,
            )
    return g.op("Or", self, other)
def wrap_logical_op_with_cast_to(to_type):
    """Return a decorator that casts both operands of a binary symbolic.

    The cast helper is the module-level function named ``_cast_<to_type>``.
    It is looked up in this module's globals each time the wrapped symbolic
    runs, not at decoration time. Each operand is cast with ``False`` as the
    third argument before the wrapped function is called.
    """
    helper_name = f"_cast_{to_type}"

    def decorator(fn):
        @functools.wraps(fn)
        def wrap_with_cast(g, input, other):
            cast = globals()[helper_name]
            lhs = cast(g, input, False)
            rhs = cast(g, other, False)
            return fn(g, lhs, rhs)

        return wrap_with_cast

    return decorator
def wrap_logical_op_with_negation(func: Callable) -> Callable:
    """Wrap a binary symbolic so its result is passed through ONNX ``Not``."""

    @functools.wraps(func)
    def wrap_with_not(g, input, other):
        inner_result = func(g, input, other)
        return g.op("Not", inner_result)

    return wrap_with_not
@_onnx_symbolic("aten::__not_")
def __not_(g: jit_utils.GraphContext, self):
    """Export ``aten::__not_`` as ONNX ``Not`` (boolean inputs only).

    Raises:
        errors.SymbolicValueError: If ``self`` is not boolean.
    """
    if symbolic_helper._is_bool(self):
        return g.op("Not", self)
    raise errors.SymbolicValueError(
        "ONNX export does NOT support exporting bitwise Not "
        "for non-boolean input values",
        self,
    )
@_onnx_symbolic("aten::eq")
@symbolic_helper.quantized_args(True, True)
def eq(g: jit_utils.GraphContext, self, other):
    """Export ``aten::eq`` as ONNX ``Equal``, folding two special cases.

    * Two device values compare equal (ONNX has no devices), emitted as a
      constant ``True``.
    * Two string ``onnx::Constant`` nodes are compared at export time, since
      strings cannot be exported, and emitted as a constant bool.
    """
    if isinstance(self.type(), _C.DeviceObjType) and isinstance(
        other.type(), _C.DeviceObjType
    ):
        # The no-op check for equality will get constant-folded.
        return g.op("Constant", value_t=torch.tensor(True, dtype=torch.bool))
    lhs_node = self.node()
    rhs_node = other.node()
    both_string_constants = (
        lhs_node.kind() == rhs_node.kind() == "onnx::Constant"
        and lhs_node.kindOf("value") == rhs_node.kindOf("value") == "s"
    )
    if both_string_constants:
        strings_match = lhs_node.s("value") == rhs_node.s("value")
        return g.op(
            "Constant",
            value_t=torch.tensor(strings_match, dtype=torch.bool),
        )
    return g.op("Equal", self, other)
@_onnx_symbolic("aten::ne")
@symbolic_helper.quantized_args(True, True)
@wrap_logical_op_with_negation
def ne(g: jit_utils.GraphContext, self, other):
    """Export ``aten::ne`` as ``Not(eq(self, other))``; the decorator adds the ``Not``."""
    equality = eq(g, self, other)
    return equality
@_onnx_symbolic("aten::gt")
@symbolic_helper.quantized_args(True, True)
def gt(g: jit_utils.GraphContext, input, other):
    """Export ``aten::gt`` (``input > other``) via the shared ``_gt_impl`` helper."""
    greater = _gt_impl(g, input, other)
    return greater
def _gt_impl(g: jit_utils.GraphContext, input, other):
    """Emit ONNX ``Greater(input, other)``.

    When both operands are boolean they are first cast to INT32 (presumably
    because ``Greater`` in the targeted opset does not accept bool).
    """
    if symbolic_helper._is_bool(input) and symbolic_helper._is_bool(other):
        int32 = _C_onnx.TensorProtoDataType.INT32
        input, other = (g.op("Cast", operand, to_i=int32) for operand in (input, other))
    return g.op("Greater", input, other)
@_onnx_symbolic("aten::lt")
@symbolic_helper.quantized_args(True, True)
def lt(g: jit_utils.GraphContext, input, other):
    """Export ``aten::lt`` (``input < other``) via the shared ``_lt_impl`` helper."""
    less = _lt_impl(g, input, other)
    return less
def _lt_impl(g: jit_utils.GraphContext, input, other):
    """Emit ONNX ``Less(input, other)``.

    When both operands are boolean they are first cast to INT32 (presumably
    because ``Less`` in the targeted opset does not accept bool).
    """
    if symbolic_helper._is_bool(input) and symbolic_helper._is_bool(other):
        int32 = _C_onnx.TensorProtoDataType.INT32
        input, other = (g.op("Cast", operand, to_i=int32) for operand in (input, other))
    return g.op("Less", input, other)
@_onnx_symbolic("aten::ge")
@symbolic_helper.quantized_args(True, True)
@wrap_logical_op_with_negation
def ge(g: jit_utils.GraphContext, input, other):
    """Export ``aten::ge`` as ``Not(input < other)``; the decorator adds the ``Not``.

    NOTE(review): this differs from ``>=`` when an operand is NaN — confirm
    acceptable.
    """
    less = _lt_impl(g, input, other)
    return less
@_onnx_symbolic("aten::le")
@symbolic_helper.quantized_args(True, True)
@wrap_logical_op_with_negation
def le(g: jit_utils.GraphContext, input, other):
    """Export ``aten::le`` as ``Not(input > other)``; the decorator adds the ``Not``.

    NOTE(review): this differs from ``<=`` when an operand is NaN — confirm
    acceptable.
    """
    greater = _gt_impl(g, input, other)
    return greater
@_onnx_symbolic("aten::__and_")
def __and_(g: jit_utils.GraphContext, input, other):
    """Export ``aten::__and_`` as ONNX ``And`` (boolean operands only).

    Raises:
        errors.SymbolicValueError: If ``input`` (checked first) or ``other``
            is not boolean.
    """
    for operand in (input, other):
        if not symbolic_helper._is_bool(operand):
            raise errors.SymbolicValueError(
                "ONNX export does NOT support exporting bitwise AND "
                "for non-boolean input values",
                operand,
            )
    return g.op("And", input, other)
@_onnx_symbolic("aten::__or_")
def __or_(g: jit_utils.GraphContext, input, other):
    """Export ``aten::__or_`` as ONNX ``Or`` (boolean operands only).

    Raises:
        errors.SymbolicValueError: If ``input`` (checked first) or ``other``
            is not boolean.
    """
    for operand in (input, other):
        if not symbolic_helper._is_bool(operand):
            raise errors.SymbolicValueError(
                "ONNX export does NOT support exporting bitwise OR "
                "for non-boolean input values",
                operand,
            )
    return g.op("Or", input, other)
@_onnx_symbolic("aten::__xor_")
def __xor_(g: jit_utils.GraphContext, input, other):
    """Export ``aten::__xor_`` as ONNX ``Xor`` (boolean operands only).

    Raises:
        errors.SymbolicValueError: If ``input`` (checked first) or ``other``
            is not boolean.
    """
    for operand in (input, other):
        if not symbolic_helper._is_bool(operand):
            raise errors.SymbolicValueError(
                "ONNX export does NOT support exporting bitwise XOR "
                "for non-boolean input values",
                operand,
            )
    return g.op("Xor", input, other)
@_onnx_symbolic("aten::logical_and")
@wrap_logical_op_with_cast_to("Bool")
def logical_and(g: jit_utils.GraphContext, input, other):
    """Export ``aten::logical_and`` as ONNX ``And``; the decorator casts operands to bool."""
    onnx_op_type = "And"
    return g.op(onnx_op_type, input, other)
@_onnx_symbolic("aten::logical_or")
@wrap_logical_op_with_cast_to("Bool")
def logical_or(g: jit_utils.GraphContext, input, other):
    """Export ``aten::logical_or`` as ONNX ``Or``; the decorator casts operands to bool."""
    onnx_op_type = "Or"
    return g.op(onnx_op_type, input, other)
@_onnx_symbolic("aten::logical_xor")
@wrap_logical_op_with_cast_to("Bool")
def logical_xor(g: jit_utils.GraphContext, input, other):
    """Export ``aten::logical_xor`` as ONNX ``Xor``; the decorator casts operands to bool."""
    onnx_op_type = "Xor"
    return g.op(onnx_op_type, input, other)
@_onnx_symbolic("aten::logical_not")
def logical_not(g: jit_utils.GraphContext, input):
    """Export ``aten::logical_not`` as ``Not(Cast(input, BOOL))``."""
    as_bool = g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.BOOL)
    return g.op("Not", as_bool)
@_onnx_symbolic("aten::__rshift_")
def __rshift_(g: jit_utils.GraphContext, self, other):
    """Export ``aten::__rshift_`` as ``self / 2**other``.

    ``other`` is cast to ``self``'s scalar type. ``2**other`` is computed with
    ``Pow`` on a float32 base (with ``other`` cast to FLOAT unless ``self`` is
    floating point), then cast back to ``self``'s type before the ``Div``.

    NOTE(review): for negative integer inputs, integer ``Div`` may round
    toward zero, unlike an arithmetic right shift — confirm.
    """
    target_type = _type_utils.JitScalarType.from_value(self)
    other_type = _type_utils.JitScalarType.from_value(
        other, _type_utils.JitScalarType.UNDEFINED
    )
    # (when self is long, make sure that other is not float)
    if other_type != target_type:
        other = g.op("Cast", other, to_i=target_type.onnx_type())
    base = g.op("Constant", value_t=torch.tensor(2, dtype=torch.float32))
    # exponent (same type as self) has to be float or double in onnx::Pow
    if not symbolic_helper._is_fp(self):
        other = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.FLOAT)
    divisor = g.op(
        "Cast",
        g.op("Pow", base, other),
        to_i=target_type.onnx_type(),
    )
    return g.op("Div", self, divisor)
@_onnx_symbolic("aten::__lshift_")
def __lshift_(g: jit_utils.GraphContext, self, other):
    """Export ``aten::__lshift_`` as ``self * 2**other``.

    ``other`` is cast to ``self``'s scalar type. ``2**other`` is computed with
    ``Pow`` on a float32 base (with ``other`` cast to FLOAT unless ``self`` is
    floating point), then cast back to ``self``'s type before the ``Mul``.
    """
    target_type = _type_utils.JitScalarType.from_value(self)
    other_type = _type_utils.JitScalarType.from_value(
        other, _type_utils.JitScalarType.UNDEFINED
    )
    # (when self is long, make sure that other is not float)
    if other_type != target_type:
        other = g.op("Cast", other, to_i=target_type.onnx_type())
    base = g.op("Constant", value_t=torch.tensor(2, dtype=torch.float32))
    # exponent (same type as self) has to be float or double in onnx::Pow
    if not symbolic_helper._is_fp(self):
        other = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.FLOAT)
    multiplier = g.op(
        "Cast",
        g.op("Pow", base, other),
        to_i=target_type.onnx_type(),
    )
    return g.op("Mul", self, multiplier)
@_onnx_symbolic("aten::where")
@symbolic_helper.parse_args("v", "v", "v", "i")
def where(g: jit_utils.GraphContext, condition, self=None, other=None, _outputs=None):
    """Export ``aten::where``.

    The condition is cast to BOOL unless it already is boolean. With
    ``self`` given, emits ONNX ``Where(condition, self, other)``. Without it,
    returns the ``nonzero`` indices of the condition unbound along axis 1
    into ``_outputs`` values.
    """
    # Assumes that torch.where's first argument takes only Bool and Byte tensors.
    if not symbolic_helper._is_bool(condition):
        condition = g.op("Cast", condition, to_i=_C_onnx.TensorProtoDataType.BOOL)
    if self is not None:
        return g.op("Where", condition, self, other)
    nonzero_indices = nonzero(g, condition)
    return symbolic_helper._unbind_helper(
        g, nonzero_indices, g.op("Constant", value_t=torch.tensor(1)), _outputs
    )
@_onnx_symbolic("aten::log_softmax")
@symbolic_helper.parse_args("v", "i", "none")
def log_softmax(g: jit_utils.GraphContext, input, dim, dtype=None):
    """Export ``aten::log_softmax`` using ONNX ``LogSoftmax`` on the last axis.

    Args:
        g: Graph context used to emit ONNX nodes.
        input: Input tensor value; its rank must be known at export time.
        dim: Dimension along which to normalize (may be negative).
        dtype: Optional dtype value. When present and its node is not a
            ``prim::Constant``, the result is ``Cast`` to that dtype.

    Returns:
        The log-softmax output value, or the result of ``_unimplemented``
        when the input rank is unknown.
    """
    # PyTorch dim and ONNX axis have different meanings.
    # See Softmax comment for details.
    # TODO: remove this as onnx opset 11 spec allows negative axes
    input_dim = symbolic_helper._get_tensor_rank(input)
    if input_dim is None:
        # Report the real operator name (previously reported as "dim") and
        # attach the offending value for diagnostics.
        return symbolic_helper._unimplemented(
            "log_softmax",
            "ONNX and PyTorch use different strategies to split the input. "
            "Input rank must be known at export time.",
            input,
        )
    if dim < 0:
        dim = input_dim + dim
    is_transpose_required = input_dim != dim + 1
    # ONNX only supports log_softmax with dim = -1. Transpose must be added before and after log_softmax to support other cases.
    if is_transpose_required:
        # Swap `dim` with the last axis; the same permutation undoes itself below.
        axes = list(range(input_dim))
        axes[dim], axes[-1] = axes[-1], axes[dim]
        input = g.op("Transpose", input, perm_i=axes)
        dim = input_dim - 1
    return_op = g.op("LogSoftmax", input, axis_i=dim)
    if dtype and dtype.node().kind() != "prim::Constant":
        parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype")
        return_op = g.op(
            "Cast", return_op, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type()
        )
    if is_transpose_required:
        return_op = g.op("Transpose", return_op, perm_i=axes)  # type: ignore[possibly-undefined]
    return return_op
@_onnx_symbolic("aten::_log_softmax")
@symbolic_helper.parse_args("v", "i", "i")
def _log_softmax(g: jit_utils.GraphContext, input, dim, half_to_float):
    """Export ``aten::_log_softmax`` by delegating to ``log_softmax``.

    When ``half_to_float`` is set and the input is HALF, the input is first
    cast to FLOAT.
    """
    if half_to_float:
        source_type = _type_utils.JitScalarType.from_value(
            input, _type_utils.JitScalarType.UNDEFINED
        )
        if source_type == _type_utils.JitScalarType.HALF:
            input = g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.FLOAT)
    return log_softmax(g, input, dim)
@_onnx_symbolic("aten::_convolution")
@symbolic_helper.parse_args(
    "v", "v", "v", "is", "is", "is", "i", "is", "i", "i", "i", "i", "i"
)
def _convolution(
    g: jit_utils.GraphContext,
    input,
    weight,
    bias,
    stride,
    padding,
    dilation,
    transposed,
    output_padding,
    groups,
    benchmark,
    deterministic,
    cudnn_enabled,
    allow_tf32=None,
):
    """Export ``aten::_convolution`` as ONNX ``Conv`` or ``ConvTranspose``.

    Args:
        g: Graph context used to emit ONNX nodes.
        input: Input tensor value.
        weight: Weight tensor value. Its shape must be known at export time;
            dims ``2:`` become ``kernel_shape``.
        bias: Bias value or a None value. A rank-1 bias is passed to the conv
            node; any other bias is applied afterwards with ``Add``.
        stride, padding, dilation, output_padding: Per-spatial-dim int lists.
        transposed: Emit ``ConvTranspose`` instead of ``Conv`` when truthy.
        groups: Number of groups.
        benchmark, deterministic, cudnn_enabled, allow_tf32: Not used by the
            export.

    Returns:
        The convolution output value (with the bias added when it could not
        be passed to the conv node).

    Raises:
        errors.SymbolicValueError: If the kernel shape is not fully known.
    """
    weight_size = symbolic_helper._get_tensor_sizes(weight)
    # `_get_tensor_sizes` returns None when the weight's shape is unknown.
    kernel_shape = weight_size[2:] if weight_size is not None else None
    if kernel_shape is None or any(i is None for i in kernel_shape):
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of convolution for kernel of unknown shape.",
            input,
        )
    has_bias = not symbolic_helper._is_none(bias)
    # ONNX only supports 1D bias as a conv input; other biases are added below.
    bias_is_1d = has_bias and symbolic_helper._get_tensor_rank(bias) == 1
    args = [input, weight]
    if bias_is_1d:
        args.append(bias)
    kwargs = {
        "kernel_shape_i": kernel_shape,
        "strides_i": stride,
        # NB: ONNX supports asymmetric padding, whereas PyTorch supports only
        # symmetric padding
        "pads_i": padding + padding,
        "dilations_i": dilation,
        "group_i": groups,
    }
    if any(o != 0 for o in output_padding):
        # ONNX supports both output_shape and output_padding. they are equivalent expressive.
        # output_padding is more straightforward, so we use it here.
        # output_shape = stride * (input_shape - 1) + output_padding + kernel_shape - padding * 2
        assert transposed
        assert len(stride) == len(output_padding)
        kwargs["output_padding_i"] = output_padding
    n = g.op("ConvTranspose" if transposed else "Conv", *args, **kwargs)
    if has_bias and not bias_is_1d:
        return g.op("Add", n, bias)
    return n
@_onnx_symbolic("aten::_convolution_mode")
@symbolic_helper.parse_args(
    "v",
    "v",
    "v",
    "is",
    "s",
    "is",
    "i",
)
def _convolution_mode(
    g: jit_utils.GraphContext,
    input,
    weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
):
    """Export ``aten::_convolution_mode`` (string padding) as an ONNX ``Conv``.

    ``padding`` is a string mode. ``"valid"`` maps to ONNX ``auto_pad="VALID"``
    and ``"same"`` maps to ``auto_pad="SAME_UPPER"``. Any other string is
    passed through to ``auto_pad`` unchanged.

    A rank-1 ``bias`` is passed as the node's third input. Any other non-None
    bias is applied afterwards with a separate ``Add``.

    Raises:
        errors.SymbolicValueError: if the kernel shape cannot be determined.
    """
    weight_size = symbolic_helper._get_tensor_sizes(weight)
    # `_get_tensor_sizes` may return None (no known shape). Check for that
    # explicitly instead of catching a blanket Exception around the slice.
    kernel_shape = None if weight_size is None else weight_size[2:]
    if kernel_shape is None or any(i is None for i in kernel_shape):
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of convolution for kernel of unknown shape.",
            input,
        )

    args = [input, weight]
    # ONNX only supports 1D bias
    bias_is_1d = (
        not symbolic_helper._is_none(bias)
        and symbolic_helper._get_tensor_rank(bias) == 1
    )
    if bias_is_1d:
        args.append(bias)

    auto_pad = {"valid": "VALID", "same": "SAME_UPPER"}.get(padding, padding)
    kwargs = {
        "kernel_shape_i": kernel_shape,
        "strides_i": stride,
        "auto_pad_s": auto_pad,
        "dilations_i": dilation,
        "group_i": groups,
    }
    n = g.op("Conv", *args, **kwargs)
    # A bias that could not be fed to the conv node directly is added afterwards.
    if not symbolic_helper._is_none(bias) and not bias_is_1d:
        return g.op("Add", n, bias)
    return n
@_onnx_symbolic("aten::convolution")
@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "is", "i")
def convolution(
    g: jit_utils.GraphContext,
    input,
    weight,
    bias,
    stride,
    padding,
    dilation,
    transposed,
    output_padding,
    groups,
):
    """Export ``aten::convolution`` by forwarding to ``_convolution``.

    The four trailing backend flags of ``_convolution`` (benchmark,
    deterministic, cudnn_enabled, allow_tf32) have no ONNX meaning and are
    passed as None.
    """
    unused_backend_flags = (None, None, None, None)
    return _convolution(
        g,
        input,
        weight,
        bias,
        stride,
        padding,
        dilation,
        transposed,
        output_padding,
        groups,
        *unused_backend_flags,
    )
@_onnx_symbolic("aten::conv1d")
@symbolic_helper.parse_args("v", "v", "v", "is", "v", "is", "i")
def conv1d(
    g: jit_utils.GraphContext, input, weight, bias, stride, padding, dilation, groups
):
    """Export ``aten::conv1d``.

    ``padding`` is left unparsed by the decorator because it may be either a
    string mode (``"valid"``/``"same"``), which goes to ``_convolution_mode``,
    or an explicit int list, which goes to ``_convolution``.
    """
    padding_mode = symbolic_helper._parse_arg(padding, "s")
    if padding_mode not in ("valid", "same"):
        explicit_padding = symbolic_helper._parse_arg(padding, "is")
        return _convolution(
            g,
            input,
            weight,
            bias,
            stride,
            explicit_padding,
            dilation,
            False,  # transposed
            (),  # output_padding
            groups,
            None,  # benchmark
            None,  # deterministic
            None,  # cudnn_enabled
            None,  # allow_tf32
        )
    return _convolution_mode(
        g, input, weight, bias, stride, padding_mode, dilation, groups
    )
@_onnx_symbolic("aten::conv2d")
@symbolic_helper.parse_args("v", "v", "v", "is", "v", "is", "i")
def conv2d(
    g: jit_utils.GraphContext, input, weight, bias, stride, padding, dilation, groups
):
    """Export ``aten::conv2d``.

    ``padding`` is left unparsed by the decorator because it may be either a
    string mode (``"valid"``/``"same"``), which goes to ``_convolution_mode``,
    or an explicit int list, which goes to ``_convolution``.
    """
    padding_mode = symbolic_helper._parse_arg(padding, "s")
    if padding_mode not in ("valid", "same"):
        explicit_padding = symbolic_helper._parse_arg(padding, "is")
        return _convolution(
            g,
            input,
            weight,
            bias,
            stride,
            explicit_padding,
            dilation,
            False,  # transposed
            (),  # output_padding
            groups,
            None,  # benchmark
            None,  # deterministic
            None,  # cudnn_enabled
            None,  # allow_tf32
        )
    return _convolution_mode(
        g, input, weight, bias, stride, padding_mode, dilation, groups
    )
@_onnx_symbolic("aten::conv3d")
@symbolic_helper.parse_args("v", "v", "v", "is", "v", "is", "i")
def conv3d(
    g: jit_utils.GraphContext, input, weight, bias, stride, padding, dilation, groups
):
    """Export ``aten::conv3d``.

    ``padding`` is left unparsed by the decorator because it may be either a
    string mode (``"valid"``/``"same"``), which goes to ``_convolution_mode``,
    or an explicit int list, which goes to ``_convolution``.
    """
    padding_mode = symbolic_helper._parse_arg(padding, "s")
    if padding_mode not in ("valid", "same"):
        explicit_padding = symbolic_helper._parse_arg(padding, "is")
        return _convolution(
            g,
            input,
            weight,
            bias,
            stride,
            explicit_padding,
            dilation,
            False,  # transposed
            (),  # output_padding
            groups,
            None,  # benchmark
            None,  # deterministic
            None,  # cudnn_enabled
            None,  # allow_tf32
        )
    return _convolution_mode(
        g, input, weight, bias, stride, padding_mode, dilation, groups
    )
@_onnx_symbolic("aten::conv_transpose1d")
@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "is")
def conv_transpose1d(
    g: jit_utils.GraphContext,
    input,
    weight,
    bias,
    stride,
    padding,
    output_padding,
    groups,
    dilation,
):
    """Export ``aten::conv_transpose1d`` via ``_convolution`` with ``transposed=True``.

    aten lists ``output_padding, groups, dilation``. ``_convolution`` expects
    ``dilation`` before the transposed flag, so the arguments are reordered here.
    """
    unused_backend_flags = (None, None, None, None)
    return _convolution(
        g,
        input,
        weight,
        bias,
        stride,
        padding,
        dilation,
        True,
        output_padding,
        groups,
        *unused_backend_flags,
    )
@_onnx_symbolic("aten::conv_transpose2d")
@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "is")
def conv_transpose2d(
    g: jit_utils.GraphContext,
    input,
    weight,
    bias,
    stride,
    padding,
    output_padding,
    groups,
    dilation,
):
    """Export ``aten::conv_transpose2d`` via ``_convolution`` with ``transposed=True``.

    aten lists ``output_padding, groups, dilation``. ``_convolution`` expects
    ``dilation`` before the transposed flag, so the arguments are reordered here.
    """
    unused_backend_flags = (None, None, None, None)
    return _convolution(
        g,
        input,
        weight,
        bias,
        stride,
        padding,
        dilation,
        True,
        output_padding,
        groups,
        *unused_backend_flags,
    )
@_onnx_symbolic("aten::conv_transpose3d")
@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "is")
def conv_transpose3d(
    g: jit_utils.GraphContext,
    input,
    weight,
    bias,
    stride,
    padding,
    output_padding,
    groups,
    dilation,
):
    """Export ``aten::conv_transpose3d`` via ``_convolution`` with ``transposed=True``.

    aten lists ``output_padding, groups, dilation``. ``_convolution`` expects
    ``dilation`` before the transposed flag, so the arguments are reordered here.
    """
    unused_backend_flags = (None, None, None, None)
    return _convolution(
        g,
        input,
        weight,
        bias,
        stride,
        padding,
        dilation,
        True,
        output_padding,
        groups,
        *unused_backend_flags,
    )
@_onnx_symbolic("aten::batch_norm")
@symbolic_helper.parse_args("v", "v", "v", "v", "v", "i", "f", "f", "i")
def batch_norm(
    g: jit_utils.GraphContext,
    input,
    weight,
    bias,
    running_mean,
    running_var,
    training,
    momentum,
    eps,
    cudnn_enabled,
):
    """Export ``aten::batch_norm`` as an ONNX ``BatchNormalization`` node.

    In inference mode the node has one output, which is returned directly.
    In training mode the node is built with five outputs. The updated running
    mean/var outputs receive the types of the incoming ``running_mean`` /
    ``running_var``. The saved mean/var outputs are renamed with a
    ``batch_norm_dead_output-`` prefix. Only the normalized result is returned.

    ``cudnn_enabled`` is not used.
    """
    symbolic_helper.check_training_mode(training, "batch_norm")
    # Under autocast, inputs whose dtypes differ are reported as unsupported
    # for opsets below 15 instead of being exported.
    if (
        torch.is_autocast_enabled()
        and not symbolic_helper.args_have_same_dtype(
            [input, weight, bias, running_mean, running_var]
        )
        and GLOBALS.export_onnx_opset_version < 15
    ):
        return symbolic_helper._onnx_opset_unsupported_detailed(
            "BatchNormalization",
            9,
            15,
            "All input tensors must have the same `dtype`."
            " Turn off Autocast or export using opset version 15.",
            input,
        )
    # The helper may fill in or adjust weight/bias/running stats; see
    # symbolic_helper._batchnorm_helper for the exact rules.
    weight, bias, running_mean, running_var = symbolic_helper._batchnorm_helper(
        g, input, weight, bias, running_mean, running_var
    )
    out = g.op(
        "BatchNormalization",
        input,
        weight,
        bias,
        running_mean,
        running_var,
        epsilon_f=eps,
        # NOTE(review): ONNX's momentum presumably weights the old running
        # value, the complement of PyTorch's convention -- hence 1 - momentum.
        momentum_f=1 - momentum,
        outputs=1 if not training else 5,
    )
    if not training:
        return out
    else:
        res, new_running_mean, new_running_var, saved_mean, saved_var = out
        new_running_mean.setType(running_mean.type())
        new_running_var.setType(running_var.type())
        # Mark the unused saved statistics so they are recognizable in the graph.
        saved_mean.setDebugName("batch_norm_dead_output-" + saved_mean.debugName())
        saved_var.setDebugName("batch_norm_dead_output-" + saved_var.debugName())
        return res
@_onnx_symbolic("aten::native_layer_norm")
@symbolic_helper.quantized_args(True, False, False, False)
@symbolic_helper.parse_args("v", "is", "v", "v", "f")
def native_layer_norm(
    g: jit_utils.GraphContext,
    input: _C.Value,
    normalized_shape: Sequence[int],
    weight: _C.Value,
    bias: _C.Value,
    eps: float,
) -> tuple[_C.Value, _C.Value, _C.Value]:
    """Export ``aten::native_layer_norm`` as a decomposition into primitive ops.

    Normalizes ``input`` over its last ``len(normalized_shape)`` dimensions as
    ``(x - mean) / sqrt(var + eps)``, then multiplies by ``weight`` and adds
    ``bias`` when they are present.

    For float16 inputs, the centred values are cast to the dtype of the ``eps``
    constant before the variance/denominator are computed, and the results
    are cast back to the input dtype afterwards.

    Returns:
        ``(normalized, mean, rdenominator)``, where ``rdenominator`` is
        ``1 / sqrt(var + eps)``.
    """
    # Reduce over the trailing dims, e.g. len(normalized_shape) == 2 -> [-2, -1].
    axes = [-i for i in range(len(normalized_shape), 0, -1)]
    two_cst = symbolic_helper._generate_wrapped_number(g, 2.0)
    eps_cst = symbolic_helper._generate_wrapped_number(g, eps)
    # Before opset 18 the reduction axes are an attribute; from 18 on they are
    # passed as a constant input tensor.
    if g.opset < 18:
        mean = g.op("ReduceMean", input, axes_i=axes)
    else:
        mean = g.op(
            "ReduceMean",
            input,
            g.op("Constant", value_t=torch.tensor(axes, dtype=torch.long)),
        )

    numerator = sub(g, input, mean)

    # Cast it to eps dtype to avoid precision loss
    is_type_half = (
        _type_utils.JitScalarType.from_value(numerator)
        == _type_utils.JitScalarType.HALF
    )
    if is_type_half:
        eps_dtype = _type_utils.JitScalarType.from_value(eps_cst)
        numerator = g.op(
            "Cast", numerator, to_i=_type_utils.JitScalarType(eps_dtype).onnx_type()
        )

    # variance = e((x - e(x))^2), and (x - e(x)) is the numerator in the layer_norm formula
    if g.opset < 18:
        variance = g.op("ReduceMean", pow(g, numerator, two_cst), axes_i=axes)
    else:
        variance = g.op(
            "ReduceMean",
            pow(g, numerator, two_cst),
            g.op("Constant", value_t=torch.tensor(axes, dtype=torch.long)),
        )

    denominator = sqrt(g, g.op("Add", variance, eps_cst))
    normalized = g.op("Div", numerator, denominator)
    # Cast back to input type as eps related ops are all done
    if is_type_half:
        input_dtype = _type_utils.JitScalarType.from_value(input)
        normalized = g.op(
            "Cast", normalized, to_i=_type_utils.JitScalarType(input_dtype).onnx_type()
        )

    # Affine parameters are optional; skip them when absent.
    if not (weight is None or symbolic_helper._is_none(weight)):
        normalized = mul(g, normalized, weight)
    if not (bias is None or symbolic_helper._is_none(bias)):
        normalized = add(g, normalized, bias)

    # rdenominator := 1 / sqrt(variance + eps)
    # According to aten::native_layer_norm, rdenominator should have the same dtype as input,
    # mean and normalized, so we need to Cast it back
    # (input_dtype is only bound in the is_type_half branch above, which is
    # the only path that reads it here).
    if is_type_half:
        denominator = g.op(
            "Cast", denominator, to_i=_type_utils.JitScalarType(input_dtype).onnx_type()  # type: ignore[possibly-undefined]
        )
        rdenominator = g.op("Reciprocal", denominator)
    else:
        rdenominator = reciprocal(g, denominator)

    return normalized, mean, rdenominator
@_onnx_symbolic("aten::layer_norm")
@symbolic_helper.quantized_args(True, False, False, False)
@symbolic_helper.parse_args("v", "is", "v", "v", "f", "b")
def layer_norm(
    g: jit_utils.GraphContext,
    input: _C.Value,
    normalized_shape: Sequence[int],
    weight: _C.Value,
    bias: _C.Value,
    eps: float,
    cudnn_enable: bool,
) -> _C.Value:
    """Export ``aten::layer_norm`` as the first output of ``native_layer_norm``.

    The mean and reciprocal-std outputs are discarded; ``cudnn_enable`` is
    not used.
    """
    layer_norm_outputs = native_layer_norm(
        g, input, normalized_shape, weight, bias, eps
    )
    return layer_norm_outputs[0]
@_onnx_symbolic("aten::instance_norm")
@symbolic_helper.parse_args("v", "v", "v", "v", "v", "b", "f", "f", "b")
def instance_norm(
    g: jit_utils.GraphContext,
    input,
    weight,
    bias,
    running_mean,
    running_var,
    use_input_stats: bool,
    momentum: Number,
    eps: Number,
    cudnn_enabled: bool,
):
    """Export ``aten::instance_norm``.

    A missing ``weight`` / ``bias`` is replaced by a constant of ones / zeros
    whose length is the channel count (dim 1 of ``input``) and whose dtype is
    the input's dtype.

    * If ``running_mean`` or ``running_var`` is missing, a single ONNX
      ``InstanceNormalization`` node is emitted.
    * Otherwise an input of shape ``[N, C, ...]`` is reshaped to
      ``[1, N * C, ...]``, the affine parameters and running stats are tiled
      ``N`` times, ``batch_norm`` is applied, and the result is viewed back to
      the original shape. This path needs a fully static input shape.

    Raises:
        errors.SymbolicValueError: if a required size (channel count, batch
            size, or any other input dimension on the running-stats path)
            is unknown.
    """
    symbolic_helper.check_training_mode(use_input_stats, "instance_norm")
    channel_size = symbolic_helper._get_tensor_dim_size(input, 1)
    if weight is None or symbolic_helper._is_none(weight):
        if channel_size is None:
            raise errors.SymbolicValueError(
                "Unsupported: ONNX export of instance_norm for unknown channel size.",
                input,
            )
        weight_value = torch.tensor(
            [1.0] * channel_size,
            dtype=_type_utils.JitScalarType.from_value(input).dtype(),
        )
        weight = g.op("Constant", value_t=weight_value)
    if bias is None or symbolic_helper._is_none(bias):
        if channel_size is None:
            raise errors.SymbolicValueError(
                "Unsupported: ONNX export of instance_norm for unknown channel size.",
                input,
            )
        bias_value = torch.tensor(
            [0.0] * channel_size,
            dtype=_type_utils.JitScalarType.from_value(input).dtype(),
        )
        bias = g.op("Constant", value_t=bias_value)
    if (
        running_mean is None
        or symbolic_helper._is_none(running_mean)
        or running_var is None
        or symbolic_helper._is_none(running_var)
    ):
        return g.op("InstanceNormalization", input, weight, bias, epsilon_f=eps)

    input_size = symbolic_helper._get_tensor_sizes(input)
    # The reshape/view constants below are built from the static shape, so
    # every dimension must be known. Previously an unknown shape or dim
    # crashed with an opaque AttributeError/TypeError.
    if input_size is None:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of instance_norm training for unknown "
            "input shape.",
            input,
        )
    # If input shape is [N, C, H, W], reshape to [1, N * C, H, W] and call batch_norm.
    # For more information instance_norm():
    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Normalization.cpp#L542
    n = input_size[0]
    if n is None:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of instance_norm training for unknown "
            "batch size.",
            input,
        )
    c = input_size[1]
    if c is None or any(d is None for d in input_size[2:]):
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of instance_norm training for unknown "
            "channel or spatial size.",
            input,
        )
    input_size_reshape = input_size.copy()
    input_size_reshape[0] = 1
    input_size_reshape[1] = n * c
    weight_ = repeat(
        g, weight, g.op("Constant", value_t=torch.tensor([n], dtype=torch.int64))
    )
    bias_ = repeat(
        g, bias, g.op("Constant", value_t=torch.tensor([n], dtype=torch.int64))
    )
    running_mean_ = repeat(
        g,
        running_mean,
        g.op("Constant", value_t=torch.tensor([n], dtype=torch.int64)),
    )
    running_var_ = repeat(
        g,
        running_var,
        g.op("Constant", value_t=torch.tensor([n], dtype=torch.int64)),
    )
    input_reshaped = g.op(
        "Reshape",
        input,
        g.op("Constant", value_t=torch.LongTensor(input_size_reshape)),
    )
    out = batch_norm(
        g,
        input_reshaped,
        weight_,
        bias_,
        running_mean_,
        running_var_,
        use_input_stats,
        momentum,
        eps,
        cudnn_enabled,
    )
    return view(g, out, g.op("Constant", value_t=torch.tensor(input_size)))
@_onnx_symbolic("aten::unfold")
@symbolic_helper.parse_args("v", "i", "i", "i")
def unfold(g: jit_utils.GraphContext, input, dimension, size, step):
    """Export ``aten::unfold`` using the statically known size of ``dimension``.

    Each window ``[low, low + size)`` (with ``low`` advancing by ``step``) is
    sliced out along ``dimension``. The window is transposed so that this
    dimension becomes last, a new axis is inserted at ``dimension``, and all
    windows are concatenated along that new axis.

    If the input's sizes are unavailable, or ``dimension`` is out of range or
    of unknown size, the export is reported as unimplemented.
    """
    sizes = symbolic_helper._get_tensor_sizes(input)
    sizedim = None
    if sizes is not None and -len(sizes) <= dimension < len(sizes):
        # Normalize a negative dim. The Slice axis, the Unsqueeze axis and the
        # Concat axis below must name the same position in tensors of
        # different rank, which only holds for a non-negative index.
        dimension = dimension % len(sizes)
        sizedim = sizes[dimension]
    if sizedim is None:
        return symbolic_helper._unimplemented(
            "Unfold", "input size not accessible", input
        )

    low_indices = range(0, sizedim, step)
    hi_indices = range(size, sizedim + 1, step)
    stack = [
        symbolic_helper._slice_helper(
            g, input, axes=[dimension], starts=[low], ends=[hi]
        )
        for low, hi in zip(low_indices, hi_indices)
    ]
    ndim = len(sizes)
    # Move the unfolded dimension to the end of each window.
    perm = list(range(0, ndim))
    perm.append(perm.pop(dimension))
    unsqueeze = [
        symbolic_helper._unsqueeze_helper(
            g, g.op("Transpose", t, perm_i=perm), [dimension]
        )
        for t in stack
    ]
    return g.op("Concat", *unsqueeze, axis_i=dimension)
@_onnx_symbolic("aten::elu")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "t", "t", "t")
def elu(g: jit_utils.GraphContext, input, alpha, scale, input_scale):
    """Export ``aten::elu`` as an ONNX ``Elu`` node.

    ONNX ``Elu`` only has ``alpha``, so any ``scale`` or ``input_scale`` other
    than 1.0 is reported as unimplemented.
    """
    # Compare with `is not None` rather than truthiness: a scale of 0 is falsy
    # and would otherwise be silently dropped, exporting a wrong result.
    if scale is not None and scale != 1.0:
        return symbolic_helper._unimplemented(
            "scale", "does not support scale in Elu", scale
        )
    if input_scale is not None and input_scale != 1.0:
        return symbolic_helper._unimplemented(
            "input_scale", "does not support input_scale in Elu", input_scale
        )
    # See Note [Export inplace]
    return g.op("Elu", input, alpha_f=symbolic_helper._scalar(alpha))
@_onnx_symbolic("aten::selu")
@symbolic_helper.quantized_args(True)
def selu(g: jit_utils.GraphContext, input):
    """Export ``aten::selu`` as an ONNX ``Selu`` node with its default attributes."""
    return g.op("Selu", input)
@_onnx_symbolic("aten::index_select")
@symbolic_helper.parse_args("v", "i", "v")
def index_select(g: jit_utils.GraphContext, self, dim, index):
    """Export ``aten::index_select`` by delegating to ``symbolic_helper._select_helper``."""
    # In case of a scalar index, index_select returns a tensor with the same rank as the input.
    # To match this behavior in ONNX, the index needs to be treated as a 1D tensor so that the
    # gather also produces a tensor with the same rank as the input.
    # NOTE(review): that handling presumably lives inside _select_helper; nothing here reshapes
    # `index` -- confirm against the helper.
    return symbolic_helper._select_helper(g, self, dim, index)
@_onnx_symbolic("aten::index_put")
def index_put(g: jit_utils.GraphContext, self, indices_list_value, values, accumulate):
    """Export ``aten::index_put`` for the only case opset 9 can express.

    With an empty index list, the result is ``values`` (or ``self + values``
    when ``accumulate`` is set). Any real indexing is reported as unsupported
    before opset 11.
    """
    if not symbolic_helper._is_packed_list(indices_list_value):
        indices_list = [indices_list_value]
    else:
        indices_list = symbolic_helper._unpack_list(indices_list_value)
    accumulate = symbolic_helper._parse_arg(accumulate, "b")

    if indices_list:
        symbolic_helper._onnx_opset_unsupported("index_put", 9, 11, self)
        return None
    return add(g, self, values) if accumulate else values
@_onnx_symbolic("aten::index_fill")
def index_fill(g: jit_utils.GraphContext, self, dim, index, value):
    """Export ``aten::index_fill`` as a ``scatter`` of ``value`` at ``index`` along ``dim``.

    ``value`` is expanded to the shape computed by
    ``_index_fill_reshape_helper`` before scattering.
    """
    # The parsed value is unused; the call is kept from the original
    # implementation (presumably it rejects a non-constant `dim`).
    symbolic_helper._parse_arg(dim, "i")
    index_shape, index_expanded = symbolic_helper._index_fill_reshape_helper(
        g, self, dim, index
    )
    fill = symbolic_helper._if_scalar_type_as(
        symbolic_helper._maybe_get_scalar(value), self
    )
    fill_expanded = expand(g, fill, index_shape, None)
    return scatter(g, self, dim, index_expanded, fill_expanded)
@_onnx_symbolic("aten::index_copy")
def index_copy(g: jit_utils.GraphContext, self, dim, index, source):
    """Export ``aten::index_copy`` as a ``scatter`` of ``source`` at the expanded ``index``."""
    # The parsed value is unused; the call is kept from the original
    # implementation (presumably it rejects a non-constant `dim`).
    symbolic_helper._parse_arg(dim, "i")
    _, index_expanded = symbolic_helper._index_fill_reshape_helper(
        g, self, dim, index
    )
    return scatter(g, self, dim, index_expanded, source)
@_onnx_symbolic("aten::bucketize")
@symbolic_helper.parse_args("v", "v", "b", "b")
def bucketize(
    g: jit_utils.GraphContext, self, boundaries, out_int32=False, right=False
):
    """Export ``aten::bucketize`` by comparing every element against every boundary.

    The bucket index of an element is the number of boundaries it exceeds
    (``>`` by default, ``>=`` when ``right`` is set). The output is INT64, or
    INT32 when ``out_int32`` is set.

    Raises:
        errors.SymbolicValueError: if the rank of ``self`` is unknown.
    """
    out_type = _C_onnx.TensorProtoDataType.INT64
    if out_int32:
        out_type = _C_onnx.TensorProtoDataType.INT32
    # A tensor expanded_boundaries is created such that it
    # contains a copy of boundaries for each element of self.
    new_shape = g.op("Concat", g.op("Shape", boundaries), g.op("Shape", self), axis_i=0)
    # Unsqueeze step is performed to respect ONNX's numpy style broadcasting for comparison ops
    # https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md
    tensor_rank = symbolic_helper._get_tensor_rank(self)
    # Raise explicitly instead of `assert`: asserts are stripped under -O and
    # would turn this into an opaque TypeError further down.
    if tensor_rank is None:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of bucketize for input of unknown rank.",
            self,
        )
    unsqueeze_axes = list(range(1, tensor_rank + 1))
    expanded_boundaries = expand(
        g,
        symbolic_helper._unsqueeze_helper(g, boundaries, unsqueeze_axes),
        new_shape,
        None,
    )
    # Compare each element of self to boundaries to get a tensor
    # with leading 1s and trailing 0s.
    # e.g., 4 > [1, 3, 4] = [1, 1, 0]
    # The index of the last 1 is the bucket where the element should go.
    if right:
        cond = ge(g, self, expanded_boundaries)
    else:
        cond = gt(g, self, expanded_boundaries)
    cond_out = g.op("Cast", cond, to_i=out_type)
    # Sum to get the number of 1s corresponding to each element,
    # which is the same as the bucket index.
    # e.g., sum(4 > [1, 3, 4]) = sum([1, 1, 0]) = 2
    return symbolic_helper._reducesum_helper(g, cond_out, axes_i=[0], keepdims_i=0)
@_onnx_symbolic("aten::type_as")
def type_as(g: jit_utils.GraphContext, self, other):
    """Export ``aten::type_as``: cast ``self`` to ``other``'s dtype.

    ``self`` is returned unchanged when the dtypes already match.

    Raises:
        errors.SymbolicValueError: if ``other``'s dtype is unknown.
    """
    source_type = symbolic_helper._try_get_scalar_type(self)
    target_type = symbolic_helper._try_get_scalar_type(other)

    if target_type is None:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of type_as for tensor "
            "of unknown dtype. Please check if the dtype of the "
            "parameter passed to the type_as function is correct.",
            other,
        )
    if source_type == target_type:
        return self
    return g.op("Cast", self, to_i=target_type.onnx_type())
@_onnx_symbolic("aten::cosine_similarity")
@symbolic_helper.parse_args("v", "v", "i", "f")
def cosine_similarity(g: jit_utils.GraphContext, x1, x2, dim, eps):
    """Export ``aten::cosine_similarity`` as ``<x1, x2> / max(sqrt(|x1|^2 * |x2|^2), eps)``.

    All sums reduce over ``dim`` without keeping it.
    """

    def _sum_of_products(a, b):
        # Elementwise product reduced over `dim`.
        return symbolic_helper._reducesum_helper(
            g, mul(g, a, b), axes_i=[dim], keepdims_i=0
        )

    dot = _sum_of_products(x1, x2)
    x1_squared_norm = _sum_of_products(x1, x1)
    x2_squared_norm = _sum_of_products(x2, x2)
    norm_product = sqrt(g, mul(g, x1_squared_norm, x2_squared_norm))
    # NOTE(review): eps is a 1-element constant in torch's default dtype;
    # confirm inputs of other dtypes are reconciled by later type passes.
    eps_floor = g.op("Constant", value_t=torch.tensor([eps]))
    denominator = max(g, norm_product, eps_floor)
    return div(g, dot, denominator)
@_onnx_symbolic("aten::pairwise_distance")
def pairwise_distance(g: jit_utils.GraphContext, input1, input2, p, eps, keepdim):
    """Export ``aten::pairwise_distance`` as ``||input1 - input2 + eps||_p`` over the last dim.

    Computed as ``(sum(|input1 - input2 + eps| ** p)) ** (1 / p)``. The
    previous export omitted the absolute value, which gave wrong results for
    odd ``p`` such as ``p=1``. It also added ``eps`` to the exponent instead
    of to the difference.

    Args:
        p: norm degree (a graph value).
        eps: small constant added to the difference; a Python number is
            turned into a constant.
        keepdim: whether the reduced last dimension is kept.
    """
    if not symbolic_helper._is_value(eps):
        eps = g.op("Constant", value_t=torch.tensor([eps]))
    inv_p = div(
        g,
        g.op("Constant", value_t=torch.tensor([1], dtype=torch.float)),
        p,
    )
    diff = add(g, sub(g, input1, input2), eps)
    summation = symbolic_helper._reducesum_helper(
        g,
        pow(g, abs(g, diff), p),
        axes_i=[-1],
        keepdims_i=symbolic_helper._parse_arg(keepdim, "i"),
    )
    return pow(g, summation, inv_p)
@_onnx_symbolic("aten::clone")
# ignore clone operators that are inserted by PyTorch autograd
def clone(g: jit_utils.GraphContext, input, unused_memory_format):
    """Export ``aten::clone`` as identity: ``input`` is returned and no node is added."""
    return input
@_onnx_symbolic("aten::abs")
def abs(g: jit_utils.GraphContext, self):
    """Export ``aten::abs`` as an ONNX ``Abs`` node (shadows the builtin in this module)."""
    return g.op("Abs", self)
@_onnx_symbolic("aten::log")
def log(g: jit_utils.GraphContext, self):
    """Export ``aten::log`` (natural logarithm) as an ONNX ``Log`` node."""
    return g.op("Log", self)
@_onnx_symbolic("aten::log1p")
def log1p(g: jit_utils.GraphContext, self):
    """Export ``aten::log1p`` as ``log(1 + self)``.

    The constant ``1`` is converted to ``self``'s scalar type when that type
    is known.
    """
    one = symbolic_helper._if_scalar_type_as(torch.ones(1), self)
    shifted = add(g, one, self)
    return log(g, shifted)
@_onnx_symbolic("aten::log10")
def log10(g: jit_utils.GraphContext, self):
    """Export ``aten::log10`` as ``log(self) / ln(10)``."""
    ln_ten = 2.30258509299404568401
    natural_log = log(g, self)
    divisor = g.op("Constant", value_t=torch.tensor([ln_ten]))
    return g.op("Div", natural_log, divisor)
@_onnx_symbolic("aten::pow")
def pow(g: jit_utils.GraphContext, self, exponent):
    """Export ``aten::pow`` as an ONNX ``Pow`` node.

    A non-floating-point base is cast to FLOAT. A non-floating-point exponent
    is cast to the (possibly updated) dtype of the base.
    """
    target_type = _type_utils.JitScalarType.from_value(self)
    if not symbolic_helper._is_fp(self):
        target_type = _type_utils.JitScalarType.FLOAT
        self = g.op("Cast", self, to_i=target_type.onnx_type())
    if not symbolic_helper._is_fp(exponent):
        exponent = g.op("Cast", exponent, to_i=target_type.onnx_type())
    return g.op("Pow", self, exponent)
@_onnx_symbolic("aten::clamp")
def clamp(g: jit_utils.GraphContext, self, min, max):
    """Export ``aten::clamp``.

    ONNX has no "None bound" syntax, so a missing bound is dispatched to the
    one-sided ``clamp_max`` / ``clamp_min``. With two constant bounds a single
    ``Clip`` is emitted. Otherwise the two one-sided clamps are chained.
    """
    if symbolic_helper._is_none(min):
        return clamp_max(g, self, max)
    if symbolic_helper._is_none(max):
        return clamp_min(g, self, min)

    both_constant = symbolic_helper._is_constant(min) and symbolic_helper._is_constant(
        max
    )
    if not both_constant:
        return clamp_max(g, clamp_min(g, self, min), max)
    return symbolic_helper._op_with_optional_float_cast(
        g,
        "Clip",
        self,
        min_f=symbolic_helper._parse_arg(min, "f"),
        max_f=symbolic_helper._parse_arg(max, "f"),
        opset_before=12,
    )
@_onnx_symbolic("aten::clamp_min")
@symbolic_helper.parse_args("v", "v")
def clamp_min(g: jit_utils.GraphContext, self, min):
    """Export ``aten::clamp_min``.

    A constant bound becomes a ``Clip`` with ``min`` set. A runtime bound is
    cast to ``self``'s dtype and applied with an elementwise ``Max``.
    """
    if not symbolic_helper._is_constant(min):
        self_type = _type_utils.JitScalarType.from_value(self)
        lower = g.op("Cast", min, to_i=self_type.onnx_type())
        return symbolic_helper._op_with_optional_float_cast(
            g, "Max", self, lower, opset_before=12
        )
    lower_value = symbolic_helper._parse_arg(min, "f")
    return symbolic_helper._op_with_optional_float_cast(
        g, "Clip", self, min_f=lower_value, opset_before=12
    )
@_onnx_symbolic("aten::clamp_max")
@symbolic_helper.parse_args("v", "v")
def clamp_max(g: jit_utils.GraphContext, self, max):
    """Export ``aten::clamp_max``.

    A constant bound becomes a ``Clip`` with ``max`` set. A runtime bound is
    cast to ``self``'s dtype and applied with an elementwise ``Min``.
    """
    if not symbolic_helper._is_constant(max):
        self_type = _type_utils.JitScalarType.from_value(self)
        upper = g.op("Cast", max, to_i=self_type.onnx_type())
        return symbolic_helper._op_with_optional_float_cast(
            g, "Min", self, upper, opset_before=12
        )
    upper_value = symbolic_helper._parse_arg(max, "f")
    return symbolic_helper._op_with_optional_float_cast(
        g, "Clip", self, max_f=upper_value, opset_before=12
    )
@_onnx_symbolic("aten::max")
# torch.max (same for torch.min) actually has two interfaces smashed together:
# torch.max(x, dim, keepdim) and torch.max(x, y)
# TODO(justinchuby): Support multiple quantized args in output
def max(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
    """Export ``aten::max``; ``symbolic_helper._max_helper`` picks the overload.

    ``dim_or_y`` is either a reduction dim or a second tensor, depending on
    which ``torch.max`` interface was traced. This function shadows the
    builtin ``max`` within this module.
    """
    return symbolic_helper._max_helper(g, self, dim_or_y, keepdim)
@_onnx_symbolic("aten::maximum")
@symbolic_helper.quantized_args(True, True)
def maximum(g: jit_utils.GraphContext, input, other):
    """Export ``aten::maximum`` as the two-tensor form of this module's ``max``."""
    # `other` fills max's `dim_or_y` slot; keepdim stays None, which selects
    # the elementwise (two-tensor) interface.
    return max(g, input, other)
@_onnx_symbolic("aten::min")
# TODO(justinchuby): Support multiple quantized args in output
def min(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
    """Export ``aten::min``; ``symbolic_helper._min_helper`` picks the overload.

    As with ``max``, ``dim_or_y`` is either a reduction dim or a second
    tensor. This function shadows the builtin ``min`` within this module.
    """
    return symbolic_helper._min_helper(g, self, dim_or_y, keepdim)
@_onnx_symbolic("aten::minimum")
@symbolic_helper.quantized_args(True, True)
def minimum(g: jit_utils.GraphContext, input, other):
    """Export ``aten::minimum`` as the two-tensor form of this module's ``min``."""
    # `other` fills min's `dim_or_y` slot; keepdim stays None, which selects
    # the elementwise (two-tensor) interface.
    return min(g, input, other)
@_onnx_symbolic("aten::amax")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "is", "i")
def amax(g: jit_utils.GraphContext, self, dim, keepdim):
    """Export ``aten::amax`` as ``ReduceMax`` over the int list ``dim`` (axes as an attribute)."""
    return g.op("ReduceMax", self, axes_i=dim, keepdims_i=keepdim)
@_onnx_symbolic("aten::amin")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "is", "i")
def amin(g: jit_utils.GraphContext, self, dim, keepdim):
    """Export ``aten::amin`` as ``ReduceMin`` over the int list ``dim`` (axes as an attribute)."""
    return g.op("ReduceMin", self, axes_i=dim, keepdims_i=keepdim)
@_onnx_symbolic("aten::aminmax")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "v", "i")
def aminmax(g: jit_utils.GraphContext, self, dim, keepdim):
    """Export ``aten::aminmax`` as a ``(ReduceMin, ReduceMax)`` pair.

    Both reductions share their attributes. When ``dim`` is None no ``axes``
    attribute is set; otherwise ``dim`` must be a constant int.
    """
    shared_attrs = {"keepdims_i": keepdim}
    if not symbolic_helper._is_none(dim):
        shared_attrs["axes_i"] = [symbolic_helper._get_const(dim, "i", "dim")]
    smallest = g.op("ReduceMin", self, **shared_attrs)
    largest = g.op("ReduceMax", self, **shared_attrs)
    return smallest, largest
@_onnx_symbolic("aten::exp")
def exp(g: jit_utils.GraphContext, self):
    """Export ``aten::exp`` as an ONNX ``Exp`` node."""
    return g.op("Exp", self)
@_onnx_symbolic("aten::dropout_")
@_onnx_symbolic("aten::dropout")
@symbolic_helper.parse_args("v", "f", "i")
def dropout(g: jit_utils.GraphContext, input, p, train):
    """Export ``aten::dropout`` / ``aten::dropout_``.

    In inference mode dropout is the identity and ``input`` is returned. In
    training mode an ONNX ``Dropout`` node with ratio ``p`` is emitted. Of its
    two outputs only the first (the data) is returned; the second is discarded.
    """
    symbolic_helper.check_training_mode(train, "dropout")
    if train:
        return g.op("Dropout", input, ratio_f=p, outputs=2)[0]
    return input
@_onnx_symbolic(
    "aten::alpha_dropout_",
    decorate=[symbolic_helper._apply_params("aten::alpha_dropout_")],
)  # See Note [Export inplace]
@_onnx_symbolic(
    "aten::feature_alpha_dropout_",
    decorate=[symbolic_helper._apply_params("aten::feature_alpha_dropout_")],
)
@_onnx_symbolic(
    "aten::feature_dropout_",
    decorate=[symbolic_helper._apply_params("aten::feature_dropout_")],
)
@_onnx_symbolic(
    "aten::feature_alpha_dropout",
    decorate=[symbolic_helper._apply_params("aten::feature_alpha_dropout")],
)
@_onnx_symbolic(
    "aten::alpha_dropout",
    decorate=[symbolic_helper._apply_params("aten::alpha_dropout")],
)
@_onnx_symbolic(
    "aten::feature_dropout",
    decorate=[symbolic_helper._apply_params("aten::feature_dropout")],
)
def _unsupported_dropout(name: str):
    """Build a symbolic for a dropout variant that is only exportable in inference mode.

    ``name`` is the aten op name; each registration above passes it in via
    ``_apply_params``. The returned symbolic returns its input unchanged when
    ``train`` is false and reports ``name`` as unimplemented in training mode.
    """

    @symbolic_helper.parse_args("v", "none", "b")
    def feature_dropout(g, input, p, train):
        # NB: In inference mode, FeatureDropout is exported as an identity op.
        if not train:
            return input
        return symbolic_helper._unimplemented(name, "training mode", input)

    return feature_dropout
@_onnx_symbolic("aten::norm")
@symbolic_helper.parse_args("v", "t", "is", "i", "v")
def norm(g: jit_utils.GraphContext, self, p, dim, keepdim, dtype=None):
    """Export ``aten::norm`` for p = 1 (``ReduceL1``) or p = 2 (``ReduceL2``).

    A missing ``p`` (aten's ``Scalar? p``) selects the 2-norm, as in ATen.
    When a ``dtype`` is supplied, the result is cast to it.

    Raises:
        errors.SymbolicValueError: for any other ``p``.
    """
    # aten::norm treats an omitted p as the 2-norm; previously this case was
    # rejected as an unsupported p.
    if p is None:
        p = 2
    if p == 1:
        f = symbolic_helper._reduce_op_symbolic_helper("ReduceL1")
    elif p == 2:
        f = symbolic_helper._reduce_op_symbolic_helper("ReduceL2")
    else:
        raise errors.SymbolicValueError(
            "ONNX export only p-norms with p of 1 or 2", self
        )
    result = f(g, self, dim=dim, keepdim=keepdim)
    # `dtype` is left unparsed ("v"), so it may be a graph None constant rather
    # than Python None; `_is_none` covers both. A None constant previously
    # crashed inside JitScalarType(None).
    if not symbolic_helper._is_none(dtype):
        dtype = symbolic_helper._get_const(dtype, "i", "dtype")
        result = g.op("Cast", result, to_i=_type_utils.JitScalarType(dtype).onnx_type())
    return result
@_onnx_symbolic("aten::conv_tbc")
@symbolic_helper.parse_args("v", "v", "v", "i")
def conv_tbc(g: jit_utils.GraphContext, input, weight, bias, pad):
    """Export ``aten::conv_tbc`` (time-batch-channel convolution) through ``conv1d``.

    Layouts (input must be 3-D, see
    https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/ConvolutionTBC.cpp#L8-L10):
        input  = (time, batch, in_channels)
        weight = (kernel_width, in_channels, out_channels)
        bias   = (out_channels,)

    The input is permuted to (batch, in_channels, time) and the weight to
    (out_channels, in_channels, kernel_width). ``conv1d`` runs with stride 1,
    padding ``pad``, dilation 1 and a single group, and its output is
    permuted back to (time, batch, channels).
    """
    batch_channel_time = g.op("Transpose", input, perm_i=[1, 2, 0])
    out_in_kernel = g.op("Transpose", weight, perm_i=[2, 1, 0])
    convolved = conv1d(g, batch_channel_time, out_in_kernel, bias, [1], [pad], [1], 1)
    return g.op("Transpose", convolved, perm_i=[2, 0, 1])
@_onnx_symbolic("aten::_unique")
@symbolic_helper.parse_args("v", "i", "i")
def _unique(g: jit_utils.GraphContext, input, sorted, return_inverse):
    """Report ``aten::_unique`` as unsupported via ``symbolic_helper._onnx_unsupported``."""
    return symbolic_helper._onnx_unsupported("_unique", input)
@_onnx_symbolic("aten::_unique2")
@symbolic_helper.parse_args("v", "i", "i", "i")
def _unique2(g: jit_utils.GraphContext, input, sorted, return_inverse, return_counts):
    """Report ``aten::_unique2`` as unsupported in opset 9 (the helper is told opset 11 is required)."""
    symbolic_helper._onnx_opset_unsupported("_unique2", 9, 11, input)
@_onnx_symbolic("aten::_cast_Byte")
@_deprecation.deprecated(
    "2.0",
    "the future",
    "Avoid using this function and create a Cast node instead",
)
def _cast_Byte(g: jit_utils.GraphContext, input, non_blocking):
    """Cast ``input`` to ONNX UINT8. ``non_blocking`` is ignored."""
    return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.UINT8)
@_onnx_symbolic("aten::_cast_Char")
@_deprecation.deprecated(
    "2.0",
    "the future",
    "Avoid using this function and create a Cast node instead",
)
def _cast_Char(g: jit_utils.GraphContext, input, non_blocking):
    """Cast ``input`` to ONNX INT8. ``non_blocking`` is ignored."""
    return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT8)
@_onnx_symbolic("aten::_cast_Short")
@_deprecation.deprecated(
    "2.0",
    "the future",
    "Avoid using this function and create a Cast node instead",
)
def _cast_Short(g: jit_utils.GraphContext, input, non_blocking):
    """Cast ``input`` to ONNX INT16. ``non_blocking`` is ignored."""
    return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT16)
@_onnx_symbolic("aten::_cast_Int")
@_deprecation.deprecated(
    "2.0",
    "the future",
    "Avoid using this function and create a Cast node instead",
)
def _cast_Int(g: jit_utils.GraphContext, input, non_blocking):
    """Cast ``input`` to ONNX INT32. ``non_blocking`` is ignored."""
    return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT32)
@_onnx_symbolic("aten::_cast_Long")
@_deprecation.deprecated(
    "2.0",
    "the future",
    "Avoid using this function and create a Cast node instead",
)
def _cast_Long(g: jit_utils.GraphContext, input, non_blocking):
    """Cast ``input`` to ONNX INT64. ``non_blocking`` is ignored."""
    return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT64)
@_onnx_symbolic("aten::_cast_Half")
@_deprecation.deprecated(
    "2.0",
    "the future",
    "Avoid using this function and create a Cast node instead",
)
def _cast_Half(g: jit_utils.GraphContext, input, non_blocking):
    """Cast ``input`` to ONNX FLOAT16. ``non_blocking`` is ignored."""
    return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
@_onnx_symbolic("aten::_cast_Float")
@_deprecation.deprecated(
    "2.0",
    "the future",
    "Avoid using this function and create a Cast node instead",
)
def _cast_Float(g: jit_utils.GraphContext, input, non_blocking):
    """Cast ``input`` to ONNX FLOAT. ``non_blocking`` is ignored."""
    return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.FLOAT)
@_onnx_symbolic("aten::_cast_Double")
@_deprecation.deprecated(
    "2.0",
    "the future",
    "Avoid using this function and create a Cast node instead",
)
def _cast_Double(g: jit_utils.GraphContext, input, non_blocking):
    """Cast ``input`` to ONNX DOUBLE. ``non_blocking`` is ignored."""
    return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.DOUBLE)
@_onnx_symbolic("aten::_cast_Bool")
@_deprecation.deprecated(
    "2.0",
    "the future",
    "Avoid using this function and create a Cast node instead",
)
def _cast_Bool(g: jit_utils.GraphContext, input, non_blocking):
    """Cast ``input`` to ONNX BOOL. ``non_blocking`` is ignored."""
    return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.BOOL)
@_onnx_symbolic("aten::empty")
@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
def empty(
    g: jit_utils.GraphContext,
    sizes,
    dtype,
    layout,
    device,
    pin_memory=False,
    memory_format=None,
):
    """Export ``aten::empty`` as ``zeros``: uninitialized memory has no ONNX form.

    ``memory_format`` is not forwarded.
    """
    return zeros(g, sizes, dtype, layout, device, pin_memory)
@_onnx_symbolic("aten::empty_like")
@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
def empty_like(
    g: jit_utils.GraphContext,
    input,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=False,
    memory_format=None,
):
    """Export ``aten::empty_like`` as ``zeros_like``.

    ``memory_format`` is not forwarded.
    """
    return zeros_like(g, input, dtype, layout, device, pin_memory)
@_onnx_symbolic("aten::new_empty")
def new_empty(
    g: jit_utils.GraphContext, self, sizes, dtype, layout, device, pin_memory=False
):
    """Export ``aten::new_empty`` via ``empty``.

    When no ``dtype`` is given, ``self``'s scalar type is used if it is known.
    """
    inferred_dtype = symbolic_helper._try_get_scalar_type(self)
    use_inferred = symbolic_helper._is_none(dtype) and inferred_dtype is not None
    chosen_dtype = inferred_dtype if use_inferred else dtype
    return empty(g, sizes, chosen_dtype, layout, device, pin_memory)
@_onnx_symbolic("aten::scalar_tensor")
def scalar_tensor(g: jit_utils.GraphContext, scalar, dtype, *options):
    """Export ``aten::scalar_tensor`` as a ``Cast`` of ``scalar`` to ``dtype``.

    ``dtype`` defaults to FLOAT; the remaining ``options`` are ignored.
    """
    dtype = symbolic_helper._get_const(dtype, "i", "dtype")
    if dtype is not None:
        target_type = _type_utils.JitScalarType(dtype)
    else:
        target_type = _type_utils.JitScalarType.FLOAT
    return g.op("Cast", scalar, to_i=target_type.onnx_type())
@_onnx_symbolic("aten::tensor")
def tensor(
    g: jit_utils.GraphContext, data, dtype=None, device=None, requires_grad=False
):
    """Export ``aten::tensor`` by casting ``data`` to the requested dtype.

    * A packed list (``prim::ListConstruct``) is processed element-wise. Each
      element is reshaped to ``[1]`` and cast, and the pieces are concatenated
      on axis 0. Without an explicit ``dtype``, the first element's dtype is used.
    * Otherwise, a list of tensors or scalars is first stacked with
      ``ConcatFromSequence`` (``axis=0``, ``new_axis=1``) and then cast.
      Without an explicit ``dtype``, the dtype of ``data`` itself is used.

    ``device`` and ``requires_grad`` are ignored.
    """
    dtype = symbolic_helper._get_const(dtype, "i", "dtype")

    if not symbolic_helper._is_packed_list(data):
        if dtype is None:
            dtype = _type_utils.JitScalarType.from_value(data)
        is_stackable_list = symbolic_helper._is_list(data) and (
            symbolic_helper._is_tensor_list(data)
            or symbolic_helper._is_scalar_list(data)
        )
        if is_stackable_list:
            data = g.op("ConcatFromSequence", data, axis_i=0, new_axis_i=1)
        return g.op("Cast", data, to_i=_type_utils.JitScalarType(dtype).onnx_type())

    elements = symbolic_helper._unpack_list(data)
    if dtype is None:
        dtype = _type_utils.JitScalarType.from_value(elements[0])

    def _reshape_and_cast(element):
        # Give every element shape [1] so they can be joined with Concat.
        one_element_shape = g.op("Constant", value_t=torch.LongTensor([1]))
        element = symbolic_helper._reshape_helper(g, element, one_element_shape)
        return g.op(
            "Cast", element, to_i=_type_utils.JitScalarType(dtype).onnx_type()
        )

    pieces = [_reshape_and_cast(element) for element in elements]
    return g.op("Concat", *pieces, axis_i=0)
@_onnx_symbolic("aten::as_tensor")
def as_tensor(g: jit_utils.GraphContext, data, dtype=None, device=None):
    """Export ``aten::as_tensor`` exactly like ``aten::tensor`` (delegates to ``tensor``)."""
    return tensor(g, data, dtype, device)
@_onnx_symbolic("aten::zeros")
@symbolic_helper.parse_args("v", "i", "v", "v", "v")
def zeros(g: jit_utils.GraphContext, sizes, dtype, layout, device, pin_memory=False):
    """Export ``aten::zeros`` as ``ConstantOfShape`` filled with 0.

    ``dtype`` defaults to FLOAT. ONNX has no way to express ``layout``,
    ``device`` or ``pin_memory``, so they are ignored.
    """
    scalar_type = (
        _type_utils.JitScalarType.FLOAT
        if dtype is None
        else _type_utils.JitScalarType(dtype)
    )
    constant_sizes = symbolic_helper._maybe_get_const(sizes, "is")
    if isinstance(constant_sizes, list) and not constant_sizes:
        # Replace an empty size list with an explicit empty int64 shape tensor.
        sizes = g.op("Constant", value_t=torch.tensor([]).to(torch.int64))
    zero_fill = torch.tensor([0], dtype=scalar_type.dtype())
    return g.op("ConstantOfShape", sizes, value_t=zero_fill)
@_onnx_symbolic("aten::zeros_like")
@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
def zeros_like(
    g: jit_utils.GraphContext,
    input,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=False,
    memory_format=None,
):
    """Export ``aten::zeros_like`` as ``ConstantOfShape(Shape(input))`` filled with 0.

    Without ``dtype``, the input's dtype is used (FLOAT if it is unknown).
    """
    input_shape = g.op("Shape", input)
    if not symbolic_helper._is_none(dtype):
        scalar_type = _type_utils.JitScalarType(dtype)
    else:
        scalar_type = _type_utils.JitScalarType.from_value(
            input, _type_utils.JitScalarType.FLOAT
        )
    zero_fill = torch.tensor([0], dtype=scalar_type.dtype())
    return g.op("ConstantOfShape", input_shape, value_t=zero_fill)
@_onnx_symbolic("aten::new_zeros")
def new_zeros(
    g: jit_utils.GraphContext, self, sizes, dtype, layout, device, pin_memory=False
):
    """Export ``aten::new_zeros`` via ``zeros``.

    When no ``dtype`` is given, ``self``'s scalar type is used if it is known.
    """
    inferred_dtype = symbolic_helper._try_get_scalar_type(self)
    use_inferred = symbolic_helper._is_none(dtype) and inferred_dtype is not None
    chosen_dtype = inferred_dtype if use_inferred else dtype
    return zeros(g, sizes, chosen_dtype, layout, device, pin_memory)
@_onnx_symbolic("aten::zero")
def zero(g: jit_utils.GraphContext, self):
    """Export ``aten::zero`` as ``zeros_like(self)`` in ``self``'s scalar type (if known)."""
    return zeros_like(g, self, symbolic_helper._try_get_scalar_type(self))
@_onnx_symbolic("aten::ones")
@symbolic_helper.parse_args("v", "i", "v", "v", "v")
def ones(g: jit_utils.GraphContext, sizes, dtype, layout, device, pin_memory=False):
    """Export ``aten::ones`` as ``ConstantOfShape`` filled with 1.

    ``dtype`` defaults to FLOAT; ``layout``, ``device`` and ``pin_memory``
    are ignored.
    """
    scalar_type = (
        _type_utils.JitScalarType.FLOAT
        if dtype is None
        else _type_utils.JitScalarType(dtype)
    )
    constant_sizes = symbolic_helper._maybe_get_const(sizes, "is")
    if isinstance(constant_sizes, list) and not constant_sizes:
        # Replace an empty size list with an explicit empty int64 shape tensor.
        sizes = g.op("Constant", value_t=torch.tensor([]).to(torch.int64))
    one_fill = torch.tensor([1], dtype=scalar_type.dtype())
    return g.op("ConstantOfShape", sizes, value_t=one_fill)
@_onnx_symbolic("aten::ones_like")
@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
def ones_like(
    g: jit_utils.GraphContext,
    input,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=False,
    memory_format=None,
):
    """Export ``aten::ones_like`` as ``ConstantOfShape(Shape(input))`` filled with 1.

    When ``dtype`` is None the element type is taken from ``input`` (FLOAT if
    it is unknown). ``layout``, ``device``, ``pin_memory`` and ``memory_format``
    are accepted for schema compatibility but have no ONNX counterpart.
    """
    input_shape = g.op("Shape", input)
    if not symbolic_helper._is_none(dtype):
        target_type = _type_utils.JitScalarType(dtype)
    else:
        target_type = _type_utils.JitScalarType.from_value(
            input, _type_utils.JitScalarType.FLOAT
        )
    one_value = torch.tensor([1], dtype=target_type.dtype())
    return g.op("ConstantOfShape", input_shape, value_t=one_value)
@_onnx_symbolic("aten::new_ones")
def new_ones(
    g: jit_utils.GraphContext, self, sizes, dtype, layout, device, pin_memory=False
):
    """Export ``aten::new_ones`` by delegating to :func:`ones`.

    If ``dtype`` is None and ``self``'s scalar type is known, that scalar type
    is used instead so the result follows ``self``'s dtype.
    """
    inherited = symbolic_helper._try_get_scalar_type(self)
    use_inherited = inherited is not None and symbolic_helper._is_none(dtype)
    return ones(
        g, sizes, inherited if use_inherited else dtype, layout, device, pin_memory
    )
@_onnx_symbolic("aten::full")
def full(
    g: jit_utils.GraphContext, sizes, value, dtype, layout, device, pin_memory=False
):
    """Export ``aten::full``: a tensor of shape ``sizes`` filled with ``value``.

    A constant ``value`` is baked into a single ``ConstantOfShape`` node whose
    element type is ``dtype`` (FLOAT when None). A dynamic ``value`` is added
    onto a zero tensor built by :func:`zeros`.
    """
    fill = symbolic_helper._maybe_get_const(value, "t")
    if not symbolic_helper._is_value(fill):
        # Static fill value: emit ConstantOfShape directly.
        dtype = symbolic_helper._get_const(dtype, "i", "dtype")
        scalar_type = (
            _type_utils.JitScalarType.FLOAT
            if dtype is None
            else _type_utils.JitScalarType(dtype)
        )
        static_sizes = symbolic_helper._maybe_get_const(sizes, "is")
        if isinstance(static_sizes, list) and not static_sizes:
            sizes = g.op("Constant", value_t=torch.tensor([]).to(torch.int64))
        return g.op(
            "ConstantOfShape",
            sizes,
            value_t=fill.view(1).to(scalar_type.dtype()),
        )
    # Dynamic fill value: zeros(sizes) + value.
    if dtype is None:
        dtype = _type_utils.JitScalarType.FLOAT
    zeros_tensor = zeros(g, sizes, dtype, layout, device)
    one = g.op("Constant", value_t=torch.tensor(1))
    return add(g, zeros_tensor, value, one)
@_onnx_symbolic("aten::full_like")
def full_like(
    g: jit_utils.GraphContext,
    input,
    fill_value,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=False,
    memory_format=None,
):
    """Export ``aten::full_like``: a tensor shaped like ``input`` filled with ``fill_value``.

    The element type is ``dtype`` if given, otherwise ``input``'s scalar type
    (FLOAT when unknown). A constant fill value is baked into
    ``ConstantOfShape``. A dynamic one is cast to the target type and added
    to a zero tensor of the same shape.
    """
    fill_value = symbolic_helper._maybe_get_const(fill_value, "f")
    dtype = symbolic_helper._get_const(dtype, "i", "dtype")
    target = (
        _type_utils.JitScalarType.from_value(input, _type_utils.JitScalarType.FLOAT)
        if dtype is None
        else _type_utils.JitScalarType(dtype)
    )
    if not symbolic_helper._is_value(fill_value):
        return g.op(
            "ConstantOfShape",
            g.op("Shape", input),
            value_t=torch.tensor([fill_value], dtype=target.dtype()),
        )
    zeros_tensor = zeros_like(g, input, dtype, layout, device)
    cast_fill = g.op("Cast", fill_value, to_i=target.onnx_type())
    return add(g, zeros_tensor, cast_fill, g.op("Constant", value_t=torch.tensor(1)))
@_onnx_symbolic("aten::new_full")
def new_full(
    g: jit_utils.GraphContext,
    self,
    size,
    fill_value,
    dtype,
    layout,
    device,
    pin_memory=False,
):
    """Export ``aten::new_full`` by delegating to :func:`full`.

    If ``dtype`` is None and ``self``'s scalar type is known, that scalar type
    is used instead so the result follows ``self``'s dtype.
    """
    inherited = symbolic_helper._try_get_scalar_type(self)
    use_inherited = inherited is not None and symbolic_helper._is_none(dtype)
    chosen_dtype = inherited if use_inherited else dtype
    return full(g, size, fill_value, chosen_dtype, layout, device, pin_memory)
@_onnx_symbolic("aten::eye")
def eye(g: jit_utils.GraphContext, *args):
    """Export ``aten::eye`` as ``EyeLike`` applied to a zero matrix.

    Overloads are distinguished by argument count:

    * 5 args: ``aten::eye(n, dtype, layout, device, pin_memory)`` -> n x n
    * 6 args: ``aten::eye(n, m, dtype, layout, device, pin_memory)`` -> n x m

    The zero matrix comes from :func:`zeros`. Its runtime shape is the
    concatenation of ``n`` (and ``m``), each unsqueezed to one element. Other
    arities are reported via ``_unimplemented``.
    """
    if len(args) == 5:
        n, dtype, layout, device, _pin_memory = args
        rows = symbolic_helper._unsqueeze_helper(g, n, [0])
        cols = rows  # square matrix: reuse the same unsqueezed node
    elif len(args) == 6:
        n, m, dtype, layout, device, _pin_memory = args
        rows = symbolic_helper._unsqueeze_helper(g, n, [0])
        cols = symbolic_helper._unsqueeze_helper(g, m, [0])
    else:
        return symbolic_helper._unimplemented(
            "aten::eye", f"with {len(args)} arguments"
        )
    shape = g.op("Concat", rows, cols, axis_i=0)
    return g.op("EyeLike", zeros(g, shape, dtype, layout, device))
@_onnx_symbolic("aten::slice")
def slice(g: jit_utils.GraphContext, self, *args):
    """Export ``aten::slice`` for opset 9; only a step of 1 can be expressed.

    Two overloads are handled, distinguished by argument count:

    * 4 args: ``aten::slice(Tensor self, int dim, int start, int end, int step)``
    * 3 args: ``aten::slice(t[] l, int start, int end, int step)`` (axis 0)

    A ``None`` constant ``start``/``end`` means 0 / ``INT64_MAX``. In the tensor
    overload, non-constant bounds or ``dim`` can only be exported as the
    deprecated ``DynamicSlice`` op, and only outside
    ``OperatorExportTypes.ONNX``. Other arities are reported via
    ``_unimplemented``.

    Raises:
        errors.SymbolicValueError: for a constant step other than 1, or for
            dynamic bounds under ``OperatorExportTypes.ONNX``.
    """
    if len(args) == 4:
        # aten::slice(Tensor self, int dim, int start, int end, int step) -> Tensor
        dim, start, end, step = args
        step = symbolic_helper._parse_arg(step, "i")
        if step != 1:
            raise errors.SymbolicValueError("step!=1 is currently not supported", self)
        is_start_none = start.node().kind() == "prim::Constant" and isinstance(
            start.type(), _C.NoneType
        )
        is_end_none = end.node().kind() == "prim::Constant" and isinstance(
            end.type(), _C.NoneType
        )
        is_start_onnx_const = start.node().kind() == "onnx::Constant"
        is_end_onnx_const = end.node().kind() == "onnx::Constant"
        if (
            ((not is_start_none) and (not is_start_onnx_const))
            or ((not is_end_none) and (not is_end_onnx_const))
            or dim.node().kind() != "onnx::Constant"
        ):
            if GLOBALS.operator_export_type == _C_onnx.OperatorExportTypes.ONNX:
                raise errors.SymbolicValueError(
                    "Unsupported: ONNX export of Slice with dynamic inputs. DynamicSlice "
                    "is a deprecated experimental op. Please use statically allocated "
                    "variables or export to a higher opset version.",
                    self,
                )
            else:
                start_unsqueezed = symbolic_helper._unsqueeze_helper(g, start, [0])
                end_unsqueezed = symbolic_helper._unsqueeze_helper(g, end, [0])
                dim_unsqueezed = symbolic_helper._unsqueeze_helper(g, dim, [0])
                return g.op(
                    "DynamicSlice",
                    self,
                    start_unsqueezed,
                    end_unsqueezed,
                    dim_unsqueezed,
                )
        else:
            start = 0 if is_start_none else symbolic_helper._parse_arg(start, "i")
            end = (
                _constants.INT64_MAX
                if is_end_none
                else symbolic_helper._parse_arg(end, "i")
            )
            dim = symbolic_helper._parse_arg(dim, "i")
            return symbolic_helper._slice_helper(
                g, self, axes=[dim], starts=[start], ends=[end]
            )
    elif len(args) == 3:
        # aten::slice(t[] l, int start, int end, int step) -> t[]
        start, end, step = args
        dim = 0
        # Opset-9 Slice has no step input, so a non-unit step has to be
        # rejected here (as in the tensor overload) rather than silently
        # exported as step 1. A step that cannot be resolved to an int
        # constant is left unchecked.
        step_const = symbolic_helper._maybe_get_const(step, "i")
        if isinstance(step_const, int) and step_const != 1:
            raise errors.SymbolicValueError("step!=1 is currently not supported", self)
        is_start_none = start.node().kind() == "prim::Constant" and isinstance(
            start.type(), _C.NoneType
        )
        is_end_none = end.node().kind() == "prim::Constant" and isinstance(
            end.type(), _C.NoneType
        )
        start = 0 if is_start_none else symbolic_helper._parse_arg(start, "i")
        end = (
            _constants.INT64_MAX
            if is_end_none
            else symbolic_helper._parse_arg(end, "i")
        )
        return symbolic_helper._slice_helper(
            g, self, axes=[dim], starts=[start], ends=[end]
        )
    return symbolic_helper._unimplemented("aten::slice", f"with {len(args)} arguments")
@_onnx_symbolic("aten::hardtanh")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "f", "f")
def hardtanh(g: jit_utils.GraphContext, self: _C.Value, min_val: float, max_val: float):
    """Export ``aten::hardtanh`` as ONNX ``Clip(self, min_val, max_val)``.

    The bounds are passed as ``Clip`` attributes. ``opset_before=12`` is
    forwarded to ``_op_with_optional_float_cast``, which decides whether the
    input needs a float cast around the op.
    """
    return symbolic_helper._op_with_optional_float_cast(
        g,
        "Clip",
        self,
        opset_before=12,
        min_f=min_val,
        max_f=max_val,
    )
@_onnx_symbolic("aten::hardswish")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v")
def hardswish(g: jit_utils.GraphContext, self):
    """Export ``aten::hardswish`` as ``self * hardsigmoid(self)``."""
    gate = hardsigmoid(g, self)
    gated = g.op("Mul", self, gate)
    return gated
@_onnx_symbolic("aten::hardsigmoid")
# Fixed scale and zero_point, discovered from aten/src/ATen/native/quantized/cpu/qhardsigmoid.cpp
@symbolic_helper.quantized_args(True, scale=1.0 / 256.0, zero_point=0)
@symbolic_helper.parse_args("v")
def hardsigmoid(g: jit_utils.GraphContext, self):
    """Export ``aten::hardsigmoid`` via ONNX ``HardSigmoid``.

    ONNX computes ``max(0, min(1, alpha * x + beta))`` with ``beta`` defaulting
    to 0.5. Choosing ``alpha = 1/6`` reproduces PyTorch's definition, see
    https://pytorch.org/docs/stable/generated/torch.nn.Hardsigmoid.html
    """
    one_sixth = 1 / 6
    return g.op("HardSigmoid", self, alpha_f=one_sixth)
@_onnx_symbolic("aten::tanhshrink")
@symbolic_helper.parse_args("v")
def tanhshrink(g: jit_utils.GraphContext, self):
    """Export ``aten::tanhshrink`` as ``self - tanh(self)``."""
    tanh_self = tanh(g, self)
    return g.op("Sub", self, tanh_self)
@_onnx_symbolic("aten::hardshrink")
@symbolic_helper.parse_args("v", "f")
def hardshrink(g: jit_utils.GraphContext, self, lambd):
    """Export ``aten::hardshrink``: keep ``x`` where ``x > lambd`` or ``x < -lambd``, else 0.

    Implemented as ``Where((x > lambd) or (x < -lambd), x, 0)``. The
    constants are created in ``self``'s dtype (FLOAT if unknown).
    """
    torch_dtype = _type_utils.JitScalarType.from_value(
        self, _type_utils.JitScalarType.FLOAT
    ).dtype()
    threshold = g.op("Constant", value_t=torch.tensor(lambd, dtype=torch_dtype))
    outside_band = logical_or(
        g, gt(g, self, threshold), lt(g, self, neg(g, threshold))
    )
    zero = g.op("Constant", value_t=torch.tensor(0, dtype=torch_dtype))
    return g.op("Where", outside_band, self, zero)
@_onnx_symbolic("aten::softshrink")
@symbolic_helper.parse_args("v", "f")
def softshrink(g: jit_utils.GraphContext, self, lambd):
    """Export ``aten::softshrink``.

    Computes ``Where(x > lambd, x - lambd, 0) + Where(x < -lambd, x + lambd, 0)``.
    The constants are created in ``self``'s dtype (FLOAT if unknown).
    """
    torch_dtype = _type_utils.JitScalarType.from_value(
        self, _type_utils.JitScalarType.FLOAT
    ).dtype()

    def zero():
        # A fresh 0 constant each time, in the input's dtype.
        return g.op("Constant", value_t=torch.tensor(0, dtype=torch_dtype))

    threshold = g.op("Constant", value_t=torch.tensor(lambd, dtype=torch_dtype))
    upper_part = g.op(
        "Where", gt(g, self, threshold), sub(g, self, threshold), zero()
    )
    lower_part = g.op(
        "Where", lt(g, self, neg(g, threshold)), add(g, self, threshold), zero()
    )
    return add(g, upper_part, lower_part)
@_onnx_symbolic("aten::alias")
def alias(g: jit_utils.GraphContext, self):
    """Export ``aten::alias`` as a no-op by returning the input value unchanged."""
    aliased = self
    return aliased
@_onnx_symbolic("aten::unsqueeze")
@symbolic_helper.parse_args("v", "i")
def unsqueeze(g: jit_utils.GraphContext, self, dim):
    """Insert a size-1 axis into ``self`` at position ``dim``.

    A non-negative ``dim`` is forwarded as-is. A negative ``dim`` is rewritten
    to ``dim + rank + 1`` using the rank known at export time, with a warning
    because the exported model then assumes that rank. If the rank is unknown,
    the export is reported via ``_unimplemented``.
    """
    if dim >= 0:
        return symbolic_helper._unsqueeze_helper(g, self, axes_i=[dim])
    rank = symbolic_helper._get_tensor_rank(self)
    if rank is None:
        return symbolic_helper._unimplemented(
            "unsqueeze", "negative axis with unknown input rank", self
        )
    positive_dim = dim + rank + 1
    warnings.warn(
        f"ONNX export unsqueeze with negative axis {dim}"
        " might cause the onnx model to be incorrect. "
        "Negative axis is not supported in ONNX. "
        f"Axis is converted to {positive_dim}"
        " based on input shape at export time. "
        "Passing an tensor of different rank in execution will be incorrect."
    )
    return symbolic_helper._unsqueeze_helper(g, self, axes_i=[positive_dim])
@_onnx_symbolic("aten::sort")
# TODO(justinchuby): Support multiple quantized args in output
@symbolic_helper.parse_args("v", "i", "i", "none")
def sort(g: jit_utils.GraphContext, self, dim, decending, out=None):
    """Export ``aten::sort`` as a full-length ``TopK`` along ``dim``.

    Opset-9 ``TopK`` has no ``largest``/``sorted`` attributes. It always
    returns the largest ``k`` elements in descending order, so only
    descending sorts can be expressed here. Ascending sorts are reported via
    ``_unimplemented`` instead of being exported with reversed results.

    Returns:
        The ``(values, indices)`` outputs of ``TopK``, or the result of
        ``symbolic_helper._unimplemented`` when the sort cannot be exported.
    """
    if out is not None:
        return symbolic_helper._unimplemented(
            "Sort", "Out parameter is not supported for sort", self
        )
    if not decending:
        return symbolic_helper._unimplemented(
            "Sort", "Ascending sort is not supported", self
        )
    self_sizes = symbolic_helper._get_tensor_sizes(self)
    # TopK needs a static k, so the size along ``dim`` must be known. Check
    # rank and bounds explicitly instead of catching a broad Exception.
    dim_size = None
    if self_sizes is not None and -len(self_sizes) <= dim < len(self_sizes):
        dim_size = self_sizes[dim]
    if dim_size is None:
        return symbolic_helper._unimplemented("Sort", "input size not accessible", self)
    return g.op("TopK", self, k_i=dim_size, axis_i=dim, outputs=2)
@_onnx_symbolic("aten::numel")
def numel(g: jit_utils.GraphContext, self):
    """Export ``aten::numel`` by delegating to ``symbolic_helper._numel_helper``."""
    element_count = symbolic_helper._numel_helper(g, self)
    return element_count
@_onnx_symbolic("aten::topk")
# TODO(justinchuby): Support multiple quantized args in output
@symbolic_helper.parse_args("v", "i", "i", "i", "i", "none")
def topk(g: jit_utils.GraphContext, self, k, dim, largest, sorted, out=None):
    """Export ``aten::topk`` as opset-9 ``TopK`` (largest elements only).

    Opset-9 ``TopK`` has no ``largest`` attribute, so ``largest=False`` cannot
    be expressed. The ``_unimplemented`` result is returned so that no
    (wrong) descending ``TopK`` is emitted when it does not raise.
    ``sorted`` is ignored.

    Returns:
        The ``(values, indices)`` outputs of ``TopK``, or the result of
        ``symbolic_helper._unimplemented``.
    """
    if out is not None:
        return symbolic_helper._unimplemented(
            "TopK", "Out parameter is not supported for topk", self
        )
    if not largest:
        return symbolic_helper._unimplemented(
            "TopK", "Ascending TopK is not supported", self
        )
    return g.op("TopK", self, k_i=k, axis_i=dim, outputs=2)
@_onnx_symbolic("prim::convert_element_type")
def convert_element_type(g: jit_utils.GraphContext, self, *args):
    """Export ``prim::convert_element_type`` as an ONNX ``Cast``.

    Only ``args[0]``, the target scalar type (which must be a constant), is
    used. Any further arguments are ignored.
    """
    target_type = _type_utils.JitScalarType(
        symbolic_helper._get_const(args[0], "i", "dtype")
    )
    return g.op("Cast", self, to_i=target_type.onnx_type())
@_onnx_symbolic("aten::to")
def to(g: jit_utils.GraphContext, self, *args):
    """Export the ``aten::to`` overloads as an ONNX ``Cast`` (or a no-op).

    The overload is identified by ``len(args)``, i.e. the arguments after
    ``self``:

    * 4: ``to(Tensor, Device | ScalarType | Tensor, bool, bool, memory_format)``
    * 5: ``to(Tensor, Device, ScalarType, bool, bool, memory_format)``
    * 6: ``to(Tensor, ScalarType, Layout, Device, bool, bool, memory_format)``
    * 7: ``to(Tensor, ScalarType, Layout, Device, bool, bool, bool, memory_format)``

    Device-only conversions return ``self`` unchanged. Layout, device and
    memory_format are otherwise ignored. Unknown arities are reported via
    ``symbolic_helper._onnx_unsupported``.
    """

    def is_aten_to_device_only(args):
        # Return True when this aten::to call requests no dtype change.
        if len(args) == 4:
            # aten::to(Tensor, Device, bool, bool, memory_format)
            # args[0] is treated as a device when it comes from prim::device,
            # is an int list, or is typed as a Device object.
            return (
                args[0].node().kind() == "prim::device"
                or args[0].type().isSubtypeOf(_C.ListType.ofInts())
                or isinstance(args[0].type(), _C.DeviceObjType)
            )
        elif len(args) == 5:
            # aten::to(Tensor, Device, ScalarType, bool, bool, memory_format)
            # When dtype is None, this is a aten::to(device) call
            dtype = symbolic_helper._get_const(args[1], "i", "dtype")
            return dtype is None
        elif len(args) in (6, 7):
            # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, memory_format) -> Tensor
            # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, bool, memory_format) -> Tensor
            # When dtype is None, this is a aten::to(device) call
            dtype = symbolic_helper._get_const(args[0], "i", "dtype")
            return dtype is None
        return False

    # ONNX doesn't have a concept of a device, so we ignore device-only casts
    if is_aten_to_device_only(args):
        return self

    if len(args) == 4:
        # TestONNXRuntime::test_ones_bool shows args[0] of aten::to() can be onnx::Constant[value=<Tensor>]()
        # In this case, the constant value is a tensor not int,
        # so symbolic_helper._maybe_get_const(args[0], 'i') would not work.
        dtype = args[0]
        if (
            symbolic_helper._is_value(args[0])
            and args[0].node().kind() == "onnx::Constant"
        ):
            tval = symbolic_helper._node_get(args[0].node(), "value")
            if isinstance(tval, torch.Tensor):
                # A 0-d tensor constant holds a ScalarType code; anything
                # else stays a tensor and is handled as the "other" overload.
                if len(tval.shape) == 0:
                    tval = tval.item()
                    dtype = int(tval)
                else:
                    dtype = tval

        if symbolic_helper._is_value(dtype) or isinstance(dtype, torch.Tensor):
            # aten::to(Tensor, Tensor, bool, bool, memory_format)
            # Cast to the scalar type of the "other" tensor args[0].
            dtype = _type_utils.JitScalarType.from_value(args[0])
            return g.op(
                "Cast",
                self,
                to_i=dtype.onnx_type(),
            )
        else:
            # aten::to(Tensor, ScalarType, bool, bool, memory_format)
            # memory_format is ignored
            return g.op("Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type())
    elif len(args) == 5:
        # aten::to(Tensor, Device, ScalarType, bool, bool, memory_format)
        dtype = symbolic_helper._get_const(args[1], "i", "dtype")
        # memory_format is ignored
        return g.op("Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type())
    elif len(args) == 6:
        # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, memory_format) -> Tensor
        dtype = symbolic_helper._get_const(args[0], "i", "dtype")
        # Layout, device and memory_format are ignored
        return g.op("Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type())
    elif len(args) == 7:
        # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, bool, memory_format) -> Tensor
        dtype = symbolic_helper._get_const(args[0], "i", "dtype")
        # Layout, device and memory_format are ignored
        return g.op("Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type())
    return symbolic_helper._onnx_unsupported("Unknown aten::to signature", self)
@_onnx_symbolic("aten::repeat")
def repeat(g: jit_utils.GraphContext, self, repeats):
    """Export ``aten::repeat`` as ``Tile(Expand(self, ones_like(repeats)), repeats)``.

    ``self`` is first expanded against an all-ones int64 tensor shaped like
    ``repeats``. Assuming ``repeats`` is a 1-D tensor of length k (TODO
    confirm), this broadcasts ``self`` up to rank k, so ``Tile`` receives one
    repeat count per dimension.
    """
    ones_shape = ones_like(g, repeats, _type_utils.JitScalarType.INT64)
    expanded = g.op("Expand", self, ones_shape)
    return g.op("Tile", expanded, repeats)
@_onnx_symbolic("aten::repeat_interleave")
def repeat_interleave(
    g: jit_utils.GraphContext, self, repeats, dim=None, output_size=None
):
    """Export ``aten::repeat_interleave`` for opset 9.

    Two forms of ``repeats`` are supported:

    * A 0-d or single-element tensor. This is delegated to
      ``_repeat_interleave_single_value_repeat_helper``.
    * A 1-D tensor whose length is static and equals ``self``'s size along
      ``dim``. ``repeats`` and ``self`` go through
      ``_repeat_interleave_split_helper`` with ``reps`` (that static length).
      Each input piece is unsqueezed at ``dim + 1``, expanded by its repeat
      count, reshaped back, and all pieces are concatenated along ``dim``.

    If ``dim`` is None, ``self`` is flattened first and ``dim`` becomes 0.
    ``output_size`` is not used.

    Raises:
        errors.SymbolicValueError: if the rank or size of ``repeats`` or the
            sizes of ``self`` are unknown, or if ``repeats`` has rank > 1.
    """
    repeats_dim = symbolic_helper._get_tensor_rank(repeats)
    repeats_sizes = symbolic_helper._get_tensor_sizes(repeats)
    input_sizes = symbolic_helper._get_tensor_sizes(self)
    if repeats_dim is None:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of repeat_interleave for unknown repeats rank.",
            self,
        )
    if repeats_sizes is None:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of repeat_interleave for unknown repeats size.",
            self,
        )
    if input_sizes is None:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of repeat_interleave for unknown input size.",
            self,
        )

    # if dim is None flatten
    # By default, use the flattened input array, and return a flat output array
    # NOTE(review): after flattening, input_sizes still holds the un-flattened
    # sizes, so input_sizes[dim] below is the original first dimension rather
    # than the flattened length. For rank > 1 inputs with 1-D repeats this
    # looks wrong; verify.
    if symbolic_helper._is_none(dim):
        self = symbolic_helper._reshape_helper(
            g, self, g.op("Constant", value_t=torch.tensor([-1]))
        )
        dim = torch.tensor(0, dtype=torch.int64)
    else:
        dim = symbolic_helper._maybe_get_scalar(dim)

    # Handle cases where dim is negative
    if dim < 0:
        dim += len(input_sizes)

    # Unknown (None) sizes become 0 in input_sizes, which the checks below
    # treat as "unknown", and -1 in input_sizes_temp, which builds the expand
    # shapes.
    input_sizes_temp = input_sizes.copy()
    for idx, input_size in enumerate(input_sizes):
        if input_size is None:
            input_sizes[idx], input_sizes_temp[idx] = 0, -1

    # Cases where repeats is an int or single value tensor
    if repeats_dim == 0 or (repeats_dim == 1 and repeats_sizes[0] == 1):
        if input_sizes[dim] == 0:
            return symbolic_helper._onnx_opset_unsupported_detailed(
                "repeat_interleave",
                9,
                13,
                "Unsupported along dimension with unknown input size",
                self,
            )
        return symbolic_helper._repeat_interleave_single_value_repeat_helper(
            g, self, repeats, dim
        )

    # Cases where repeats is a 1 dim Tensor
    elif repeats_dim == 1:
        if input_sizes[dim] == 0:
            return symbolic_helper._onnx_opset_unsupported_detailed(
                "repeat_interleave",
                9,
                13,
                "Unsupported along dimension with unknown input size",
                self,
            )
        if repeats_sizes[0] is None:
            return symbolic_helper._onnx_opset_unsupported_detailed(
                "repeat_interleave",
                9,
                13,
                "Unsupported for cases with dynamic repeats",
                self,
            )
        # NOTE(review): this is an assert, so the check is skipped under
        # `python -O`.
        assert (
            repeats_sizes[0] == input_sizes[dim]
        ), "repeats must have the same size as input along dim"
        reps = repeats_sizes[0]
    else:
        raise errors.SymbolicValueError("repeats must be 0-dim or 1-dim tensor", self)

    final_splits = []
    r_splits = symbolic_helper._repeat_interleave_split_helper(g, repeats, reps, 0)
    i_splits = symbolic_helper._repeat_interleave_split_helper(g, self, reps, dim)
    # Along dim: the final reshape uses -1 so it absorbs the repeated length,
    # and the expand shape keeps 1 there (the repeat count is inserted at
    # dim + 1).
    input_sizes[dim], input_sizes_temp[dim] = -1, 1
    for idx, r_split in enumerate(r_splits):
        i_split = unsqueeze(g, i_splits[idx], dim + 1)
        # Expand shape: input_sizes_temp[:dim + 1] + [r_split] + input_sizes_temp[dim + 1:].
        r_concat = [
            g.op("Constant", value_t=torch.LongTensor(input_sizes_temp[: dim + 1])),
            r_split,
            g.op("Constant", value_t=torch.LongTensor(input_sizes_temp[dim + 1 :])),
        ]
        r_concat = g.op("Concat", *r_concat, axis_i=0)
        i_split = expand(g, i_split, r_concat, None)
        i_split = symbolic_helper._reshape_helper(
            g,
            i_split,
            g.op("Constant", value_t=torch.LongTensor(input_sizes)),
            allowzero=0,
        )
        final_splits.append(i_split)
    return g.op("Concat", *final_splits, axis_i=dim)
@_onnx_symbolic("aten::pixel_shuffle")
@symbolic_helper.parse_args("v", "i")
def pixel_shuffle(g: jit_utils.GraphContext, self, upscale_factor):
    """Export ``aten::pixel_shuffle`` for 4-D inputs.

    Rearranges ``(N, C*r*r, H, W)`` into ``(N, C, H*r, W*r)`` with
    ``r = upscale_factor`` using Reshape/Transpose.

    * All non-batch sizes static: one 6-D reshape, a transpose and a final
      reshape are emitted.
    * Otherwise: the input is unsqueezed to 6-D and reshaped using 0/-1
      wildcards, then squeezed back to 4-D.

    Returns the result of ``symbolic_helper._unimplemented`` when the input
    rank is unknown or not 4.
    """
    dims = symbolic_helper._get_tensor_sizes(self)
    # _get_tensor_sizes may return None (e.g. unknown rank). Guard before
    # len() so that case is reported instead of raising a TypeError.
    if dims is None or len(dims) != 4:
        return symbolic_helper._unimplemented(
            "pixel_shuffle", "only support 4d input", self
        )
    if any(i is None for i in dims[1:]):
        after_view = symbolic_helper._reshape_helper(
            g,
            symbolic_helper._unsqueeze_helper(g, self, [2, 3]),
            g.op(
                "Constant",
                value_t=torch.tensor([0, -1, upscale_factor, upscale_factor, 0, 0]),
            ),
            allowzero=0,
        )
        after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 4, 2, 5, 3])
        # For dynamic input shapes, two reshapes are performed
        reshape_h = symbolic_helper._reshape_helper(
            g,
            after_transpose,
            g.op("Constant", value_t=torch.tensor([0, 0, -1, 1, 0, 0])),
            allowzero=0,
        )
        reshape_w = symbolic_helper._reshape_helper(
            g,
            reshape_h,
            g.op("Constant", value_t=torch.tensor([0, 0, 0, 0, -1, 1])),
            allowzero=0,
        )
        return symbolic_helper._squeeze_helper(g, reshape_w, [3, 5])
    else:
        output_channel = dims[1] // upscale_factor // upscale_factor
        after_view = symbolic_helper._reshape_helper(
            g,
            self,
            g.op(
                "Constant",
                value_t=torch.tensor(
                    [
                        -1,
                        output_channel,
                        upscale_factor,
                        upscale_factor,
                        dims[2],
                        dims[3],
                    ]
                ),
            ),
            allowzero=0,
        )
        after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 4, 2, 5, 3])
        return symbolic_helper._reshape_helper(
            g,
            after_transpose,
            g.op(
                "Constant",
                value_t=torch.tensor(
                    [
                        -1,
                        output_channel,
                        dims[2] * upscale_factor,
                        dims[3] * upscale_factor,
                    ]
                ),
            ),
            allowzero=0,
        )
@_onnx_symbolic("aten::pixel_unshuffle")
@symbolic_helper.parse_args("v", "i")
def pixel_unshuffle(g: jit_utils.GraphContext, self, downscale_factor):
    """Export ``aten::pixel_unshuffle`` for 4-D inputs.

    Rearranges ``(N, C, H*r, W*r)`` into ``(N, C*r*r, H, W)`` with
    ``r = downscale_factor`` using Reshape/Transpose.

    * All non-batch sizes static: one 6-D reshape, a transpose and a final
      reshape are emitted.
    * Otherwise: reshapes with 0/-1 wildcards are used and the result is
      squeezed back to 4-D.

    Returns the result of ``symbolic_helper._unimplemented`` when the input
    rank is unknown or not 4.
    """
    dims = symbolic_helper._get_tensor_sizes(self)
    # _get_tensor_sizes may return None (e.g. unknown rank). Guard before
    # len() so that case is reported instead of raising a TypeError.
    if dims is None or len(dims) != 4:
        return symbolic_helper._unimplemented(
            "pixel_unshuffle", "only support 4d input", self
        )
    if any(i is None for i in dims[1:]):
        # For dynamic input shapes, two reshapes are performed
        reshape_h = symbolic_helper._reshape_helper(
            g,
            symbolic_helper._unsqueeze_helper(g, self, [3]),
            g.op("Constant", value_t=torch.tensor([0, 0, -1, downscale_factor, 0])),
            allowzero=0,
        )
        reshape_w = symbolic_helper._reshape_helper(
            g,
            reshape_h,
            g.op("Constant", value_t=torch.tensor([0, 0, 0, 0, -1, downscale_factor])),
            allowzero=0,
        )
        after_transpose = g.op("Transpose", reshape_w, perm_i=[0, 1, 3, 5, 2, 4])
        final_reshape = symbolic_helper._reshape_helper(
            g,
            after_transpose,
            g.op("Constant", value_t=torch.tensor([0, -1, 1, 1, 0, 0])),
            allowzero=0,
        )
        return symbolic_helper._squeeze_helper(g, final_reshape, [2, 3])
    else:
        output_channel = dims[1] * downscale_factor * downscale_factor
        after_view = symbolic_helper._reshape_helper(
            g,
            self,
            g.op(
                "Constant",
                value_t=torch.tensor(
                    [
                        -1,
                        dims[1],
                        dims[2] // downscale_factor,
                        downscale_factor,
                        dims[3] // downscale_factor,
                        downscale_factor,
                    ]
                ),
            ),
            allowzero=0,
        )
        after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 3, 5, 2, 4])
        return symbolic_helper._reshape_helper(
            g,
            after_transpose,
            g.op(
                "Constant",
                value_t=torch.tensor(
                    [
                        -1,
                        output_channel,
                        dims[2] // downscale_factor,
                        dims[3] // downscale_factor,
                    ]
                ),
            ),
            allowzero=0,
        )
def _generic_rnn(
    g: jit_utils.GraphContext,
    variant,
    input,
    initial_states,
    all_weights,
    has_biases,
    num_layers,
    dropout,
    train,
    bidirectional,
    batch_first=None,
    batch_sizes=None,
):
    """Lower a (possibly stacked, possibly bidirectional) RNN/GRU/LSTM to ONNX.

    One ONNX ``RNN``/``GRU``/``LSTM`` node is emitted per layer, each
    layer's output feeding the next. The per-layer final states are then
    concatenated.

    Args:
        g: Graph context.
        variant: ``"LSTM"``, ``"GRU"``, or ``"RNN_<ACTIVATION>"`` (e.g.
            ``"RNN_TANH"``). The activation suffix is mapped to an ONNX
            activation name.
        input: Input sequence. When ``batch_first`` is truthy it is
            transposed from (batch, seq, feat) to (seq, batch, feat).
        initial_states: ``h0`` for RNN/GRU, or an ``(h0, c0)`` pair for LSTM.
        all_weights: Flat list of weights, grouped per layer and direction.
            A group has 4 entries ``(w_ih, w_hh, b_ih, b_hh)`` when
            ``has_biases`` is set, otherwise 2.
        has_biases: Whether ``all_weights`` includes the bias tensors.
        num_layers: Number of stacked layers.
        dropout: Dropout probability. Non-zero dropout with ``train`` set is
            reported as unimplemented.
        train: Training-mode flag.
        bidirectional: Emit bidirectional ONNX nodes.
        batch_first: Whether ``input`` and the returned output are batch-major.
        batch_sizes: When given (packed sequences), used as the ONNX
            ``sequence_lens`` input. Otherwise that input is ``unused(g)``.

    Returns:
        ``(output, h_n)`` for RNN/GRU or ``(output, h_n, c_n)`` for LSTM, or
        the result of ``symbolic_helper._unimplemented`` for unsupported
        configurations.
    """
    warnings.warn(
        "Exporting a model to ONNX with a batch_size other than 1, "
        + "with a variable length with "
        + variant
        + " can cause an error "
        + "when running the ONNX model with a different batch size. "
        + "Make sure to save the model with a batch size of 1, "
        + "or define the initial states (h0/c0) as inputs of the model. "
    )

    onnxActivations = [
        "Relu",
        "Tanh",
        "Sigmoid",
        "Affine",
        "LeakyRelu",
        "ThresholdedRelu",
        "ScaledTanh",
        "HardSigmoid",
        "Elu",
        "Softsign",
        "Softplus",
    ]
    # Lower-case name (e.g. "tanh", as found in the variant suffix) -> ONNX name.
    variantToOnnxActivationMap = dict(
        zip([act_fun.lower() for act_fun in onnxActivations], onnxActivations)
    )
    weights_per_layer = 4 if has_biases else 2
    # this means that projections are used inside LSTM, so need to tell user that it's not supported
    if variant == "LSTM" and len(all_weights) != num_layers * weights_per_layer * (
        1 + bidirectional
    ):
        return symbolic_helper._unimplemented("LSTM", "LSTMs with projections", input)
    assert len(all_weights) == num_layers * weights_per_layer * (1 + bidirectional)
    # One group of weights per (layer, direction), in all_weights order.
    layer_weights = [
        all_weights[i : i + weights_per_layer]
        for i in range(0, len(all_weights), weights_per_layer)
    ]
    if batch_first:
        # batch, seq, feat -> seq, batch, feat
        input = g.op("Transpose", input, perm_i=[1, 0, 2])
    if dropout and train:
        return symbolic_helper._unimplemented(
            "RNN/GRU/LSTM", "dropout in training mode", input
        )

    if variant.startswith("RNN"):
        # "RNN_TANH" -> "tanh" -> "Tanh"
        nonlinearity = variantToOnnxActivationMap[variant[4:].lower()]
        variant = "RNN"

    # The hidden size is read from dimension 1 of the first layer's w_hh.
    w_hh = all_weights[1]
    hidden_size = symbolic_helper._get_tensor_dim_size(w_hh, 1)
    if hidden_size is None:
        return symbolic_helper._unimplemented(
            "RNN/GRU/LSTM", "unknown hidden size", input
        )

    unidirectional = not bidirectional

    prev_output = input

    h_outs = []
    if variant == "RNN" or variant == "GRU":
        h0 = initial_states
    elif variant == "LSTM":
        h0, c0 = initial_states
        c_outs = []

    sequence_lens = unused(g) if batch_sizes is None else batch_sizes

    # Gate blocks as [start, end) intervals in units of hidden_size, listed
    # in ONNX gate order.
    if variant == "GRU":
        # pytorch is reset, input, hidden
        # onnx is input, reset, hidden
        reform_permutation = [(1, 2), (0, 1), (2, 3)]
    elif variant == "LSTM":
        # pytorch is input, forget, cell, output.
        # onnx is input, output, forget, cell.
        reform_permutation = [(0, 1), (3, 4), (1, 3)]

    def reform_weights(g, w, n, intervals):
        # Slice w along axis 0 into the [x*n, y*n) blocks given by `intervals`
        # and concatenate them in that order.
        slices = [
            symbolic_helper._slice_helper(g, w, axes=[0], starts=[x * n], ends=[y * n])
            for x, y in intervals
        ]
        return g.op("Concat", *slices, axis_i=0)

    def transform_weights_no_bias(layer_index):
        # Return (W, R) for one (layer, direction) group, gate-reordered for
        # GRU/LSTM and unsqueezed with a leading axis of size 1.
        weights = layer_weights[layer_index]
        if variant == "RNN":
            weight_ih, weight_hh = weights
        elif variant == "GRU" or variant == "LSTM":
            weight_ih, weight_hh = (
                reform_weights(g, w, hidden_size, reform_permutation) for w in weights
            )
        return tuple(
            symbolic_helper._unsqueeze_helper(g, x, [0]) for x in (weight_ih, weight_hh)  # type: ignore[possibly-undefined]
        )

    def transform_weights(layer_index):
        # Return (W, R, B) for one (layer, direction) group. B is the
        # concatenation [b_ih, b_hh]. All are unsqueezed with a leading axis.
        weights = layer_weights[layer_index]
        if variant == "RNN":
            weight_ih, weight_hh, bias_ih, bias_hh = weights
        elif variant == "GRU" or variant == "LSTM":
            weight_ih, weight_hh, bias_ih, bias_hh = (
                reform_weights(g, w, hidden_size, reform_permutation) for w in weights
            )
        bias_concat = g.op("Concat", bias_ih, bias_hh, axis_i=0)  # type: ignore[possibly-undefined]
        return tuple(
            symbolic_helper._unsqueeze_helper(g, x, [0])
            for x in (weight_ih, weight_hh, bias_concat)  # type: ignore[possibly-undefined]
        )

    def retrieve_state(x, start, end):
        # Slice this layer's initial state(s) along axis 0. With a single
        # layer the whole tensor is used as-is.
        return (
            x
            if num_layers == 1
            else symbolic_helper._slice_helper(
                g, x, axes=[0], starts=[start], ends=[end]
            )
        )

    for i in range(num_layers):
        if unidirectional:
            if weights_per_layer == 4:
                weight_ih, weight_hh, bias_concat = transform_weights(i)
            else:
                weight_ih, weight_hh = transform_weights_no_bias(i)
                bias_concat = unused(g)

            state_indices = i, i + 1
        else:
            # Forward and backward weights for layer i are at 2*i and 2*i+1
            # and are stacked along the leading (direction) axis.
            if weights_per_layer == 4:
                weight_ih_f, weight_hh_f, bias_f = transform_weights(2 * i)
                weight_ih_b, weight_hh_b, bias_b = transform_weights(2 * i + 1)
                bias_concat = g.op("Concat", bias_f, bias_b, axis_i=0)
            else:
                weight_ih_f, weight_hh_f = transform_weights_no_bias(2 * i)
                weight_ih_b, weight_hh_b = transform_weights_no_bias(2 * i + 1)
                bias_concat = unused(g)

            weight_ih = g.op("Concat", weight_ih_f, weight_ih_b, axis_i=0)
            weight_hh = g.op("Concat", weight_hh_f, weight_hh_b, axis_i=0)

            state_indices = 2 * i, 2 * i + 2

        # Input order: X, W, R, B, sequence_lens, initial_h[, initial_c].
        inputs = [prev_output, weight_ih, weight_hh, bias_concat, sequence_lens]

        inputs.append(retrieve_state(h0, *state_indices))  # type: ignore[possibly-undefined]
        if variant == "LSTM":
            inputs.append(retrieve_state(c0, *state_indices))  # type: ignore[possibly-undefined]

        extra_kwargs = {} if unidirectional else {"direction_s": "bidirectional"}
        if variant == "RNN":
            if bidirectional:
                activation = [nonlinearity, nonlinearity]  # type: ignore[possibly-undefined]
            else:
                activation = [nonlinearity]  # type: ignore[possibly-undefined]

            prev_output, h_out = g.op(
                "RNN",
                *inputs,
                outputs=2,
                hidden_size_i=hidden_size,
                activations_s=activation,
                **extra_kwargs,
            )
        elif variant == "GRU":
            # NOTE(review): linear_before_reset=1 presumably matches PyTorch's
            # GRU formulation (reset gate applied after the hidden linear
            # transform); confirm.
            prev_output, h_out = g.op(
                "GRU",
                *inputs,
                outputs=2,
                hidden_size_i=hidden_size,
                linear_before_reset_i=1,
                **extra_kwargs,
            )
        elif variant == "LSTM":
            prev_output, h_out, c_out = g.op(
                "LSTM", *inputs, outputs=3, hidden_size_i=hidden_size, **extra_kwargs
            )

        if bidirectional:
            # The ONNX RNN/GRU/LSTM produce an output of dimensions
            #   seq_len, num_directions, batch, hidden_size
            # We have to convert to match pytorch's expected
            #   seq_len, batch, num_directions * hidden_size
            # by first moving num_directions before hidden_size with
            # Transpose, and then combining it with hidden_size
            # with Reshape.
            prev_output = g.op("Transpose", prev_output, perm_i=[0, 2, 1, 3])
            prev_output = symbolic_helper._reshape_helper(
                g,
                prev_output,
                g.op("Constant", value_t=torch.LongTensor([0, 0, -1])),
                allowzero=0,
            )
        else:
            # Unidirectional: drop the num_directions axis (axis 1).
            prev_output = symbolic_helper._squeeze_helper(g, prev_output, [1])

        h_outs.append(h_out)  # type: ignore[possibly-undefined]
        if variant == "LSTM":
            c_outs.append(c_out)  # type: ignore[possibly-undefined]
    if batch_first:
        # seq, batch, num_directions * hidden_size -> batch, seq, num_directions * hidden_size
        prev_output = g.op("Transpose", prev_output, perm_i=[1, 0, 2])
    # Stack the per-layer final states along axis 0.
    h_outs = h_out if num_layers == 1 else g.op("Concat", *h_outs, axis_i=0)  # type: ignore[possibly-undefined]
    if variant == "RNN" or variant == "GRU":
        return prev_output, h_outs
    elif variant == "LSTM":
        c_outs = c_out if num_layers == 1 else g.op("Concat", *c_outs, axis_i=0)  # type: ignore[possibly-undefined]
        return prev_output, h_outs, c_outs
@symbolic_helper.parse_args("v", "v", "v", "i", "i", "f", "i", "i", "i")
def _lstm_full(
    g: jit_utils.GraphContext,
    input,
    hidden_v,
    weight_v,
    has_biases,
    num_layers,
    dropout,
    train,
    bidirectional,
    batch_first,
):
    """Export the padded (non-packed) ``aten::lstm`` overload.

    ``hidden_v`` and ``weight_v`` are list values. They are unpacked into
    Python lists and forwarded to :func:`_generic_rnn` with variant ``"LSTM"``.
    """
    return _generic_rnn(
        g,
        "LSTM",
        input,
        initial_states=symbolic_helper._unpack_list(hidden_v),
        all_weights=symbolic_helper._unpack_list(weight_v),
        has_biases=has_biases,
        num_layers=num_layers,
        dropout=dropout,
        train=train,
        bidirectional=bidirectional,
        batch_first=batch_first,
    )
@symbolic_helper.parse_args("v", "v", "v", "v", "i", "i", "f", "i", "i")
def _lstm_packed(
    g: jit_utils.GraphContext,
    input,
    batch_sizes,
    hidden_v,
    weight_v,
    has_biases,
    num_layers,
    dropout,
    train,
    bidirectional,
):
    """Export the packed-sequence ``aten::lstm`` overload.

    ``hidden_v`` and ``weight_v`` are unpacked into Python lists. Everything
    is forwarded to :func:`_generic_rnn` with variant ``"LSTM"``, with
    ``batch_sizes`` passed through.
    """
    return _generic_rnn(
        g,
        "LSTM",
        input,
        initial_states=symbolic_helper._unpack_list(hidden_v),
        all_weights=symbolic_helper._unpack_list(weight_v),
        has_biases=has_biases,
        num_layers=num_layers,
        dropout=dropout,
        train=train,
        bidirectional=bidirectional,
        batch_sizes=batch_sizes,
    )
@_onnx_symbolic("aten::lstm")
def lstm(g: jit_utils.GraphContext, *args):
    """Dispatch ``aten::lstm`` to the packed or padded implementation.

    The two overloads are told apart by ``args[3]``. It is the weight list in
    the packed form (see ``_lstm_packed``'s parameters) and the
    ``has_biases`` flag in the padded one.
    """
    is_packed = symbolic_helper._is_tensor_list(args[3])
    handler = _lstm_packed if is_packed else _lstm_full
    return handler(g, *args)
@_onnx_symbolic("aten::lstm_cell")
def lstm_cell(g: jit_utils.GraphContext, self, hidden, w_ih, w_hh, b_ih, b_hh):
    """Export one ``aten::lstm_cell`` step as a single-layer, single-step LSTM.

    The input and each hidden state get a leading axis of size 1 before
    calling :func:`_generic_rnn`. The resulting ``h`` and ``c`` are squeezed
    on axis 0 again. Biases are used only when ``b_ih`` is a tensor.
    """
    seq_input = symbolic_helper._unsqueeze_helper(g, self, [0])
    states = [
        symbolic_helper._unsqueeze_helper(g, state, [0])
        for state in symbolic_helper._unpack_list(hidden)
    ]
    has_biases = symbolic_helper._is_tensor(b_ih)
    weight = (w_ih, w_hh, b_ih, b_hh) if has_biases else (w_ih, w_hh)
    _, h_outs, c_outs = _generic_rnn(
        g,
        "LSTM",
        seq_input,
        states,
        weight,
        has_biases,
        num_layers=1,
        dropout=0,
        train=0,
        bidirectional=False,
        batch_first=False,
    )
    h_next = symbolic_helper._squeeze_helper(g, h_outs, [0])
    c_next = symbolic_helper._squeeze_helper(g, c_outs, [0])
    return h_next, c_next
@_onnx_symbolic(
    "aten::gru", decorate=[symbolic_helper._apply_params("GRU"), _export("gru")]
)
@_onnx_symbolic(
    "aten::rnn_tanh",
    decorate=[symbolic_helper._apply_params("RNN_TANH"), _export("rnn_tanh")],
)
@_onnx_symbolic(
    "aten::rnn_relu",
    decorate=[symbolic_helper._apply_params("RNN_RELU"), _export("rnn_relu")],
)
def _one_hidden_rnn(kind: str):
    """Build the symbolic for an RNN variant with a single hidden state.

    ``kind`` is the :func:`_generic_rnn` variant name: ``"GRU"``,
    ``"RNN_TANH"`` or ``"RNN_RELU"``, as bound by the decorators on this
    function. The returned callable accepts either the padded or the packed
    ``aten`` overload. Unlike LSTM, the hidden state is forwarded as-is
    rather than unpacked from a list.
    """

    @symbolic_helper.parse_args("v", "v", "v", "i", "i", "f", "i", "i", "i")
    def _rnn_full(
        g,
        input,
        hidden,
        weight_v,
        has_biases,
        num_layers,
        dropout,
        train,
        bidirectional,
        batch_first,
    ):
        # Padded overload: only the weights arrive as a list value.
        return _generic_rnn(
            g,
            kind,
            input,
            initial_states=hidden,
            all_weights=symbolic_helper._unpack_list(weight_v),
            has_biases=has_biases,
            num_layers=num_layers,
            dropout=dropout,
            train=train,
            bidirectional=bidirectional,
            batch_first=batch_first,
        )

    @symbolic_helper.parse_args("v", "v", "v", "v", "i", "i", "f", "i", "i")
    def _rnn_packed(
        g,
        input,
        batch_sizes,
        hidden,
        weight_v,
        has_biases,
        num_layers,
        dropout,
        train,
        bidirectional,
    ):
        # Packed overload: batch_sizes is passed through to _generic_rnn.
        return _generic_rnn(
            g,
            kind,
            input,
            initial_states=hidden,
            all_weights=symbolic_helper._unpack_list(weight_v),
            has_biases=has_biases,
            num_layers=num_layers,
            dropout=dropout,
            train=train,
            bidirectional=bidirectional,
            batch_sizes=batch_sizes,
        )

    def symbolic(g, *args):
        # args[3] is the weight list in the packed overload and the
        # has_biases flag in the padded one.
        handler = _rnn_packed if symbolic_helper._is_tensor_list(args[3]) else _rnn_full
        return handler(g, *args)

    return symbolic
@_onnx_symbolic("aten::_dim_arange")
@symbolic_helper.parse_args("v", "i")
def _dim_arange(g: jit_utils.GraphContext, like, dim):
    """Export ``aten::_dim_arange`` as ``arange(like.size(dim))``.

    The length is read at runtime by gathering entry ``dim`` of ``Shape(like)``.
    """
    shape_of_like = g.op("Shape", like)
    index = g.op("Constant", value_t=torch.tensor(dim))
    length = g.op("Gather", shape_of_like, index, axis_i=0)
    # Positional arguments follow
    # aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory).
    # Dtype code 4 is presumably int64 (c10 ScalarType::Long); confirm.
    int64_code = 4
    return arange(g, length, int64_code, None, None, None)
@_onnx_symbolic("aten::detach")
def detach(g: jit_utils.GraphContext, input):
    """Export ``aten::detach`` as identity.

    ONNX export is treated as inference-only, so the node is erased by
    returning ``input`` directly.
    """
    passthrough = input
    return passthrough
@_onnx_symbolic("aten::contiguous")
@symbolic_helper.parse_args("v", "i")
def contiguous(g: jit_utils.GraphContext, input, memory_format):
    """Export ``aten::contiguous`` as identity, since ONNX has no memory formats.

    Raises:
        errors.SymbolicValueError: if the ``memory_format`` code is greater than 2.
    """
    # NOTE(review): assumes the c10 MemoryFormat encoding (0 = contiguous,
    # 1 = preserve, 2 = channels_last, 3 = channels_last_3d); confirm. Under
    # that encoding codes 0-2 pass through and only larger codes raise.
    if memory_format <= 2:
        return input
    raise errors.SymbolicValueError(
        "onnx memory_format support is not implemented", input
    )
@_onnx_symbolic("aten::_pack_padded_sequence")
@symbolic_helper.parse_args("v", "v", "i")
def _pack_padded_sequence(g: jit_utils.GraphContext, input, lengths, batch_first):
    """Emit a ``prim::PackPadded`` placeholder for ``aten::_pack_padded_sequence``.

    ONNX has no PackPadded operator. This node is expected to be removed by a
    later optimization pass, and it is an error if it cannot be. When
    ``batch_first`` is set, ``input`` is transposed with perm [1, 0, 2]
    first. ``lengths`` is cast to int32 unless it already is.

    Raises:
        errors.SymbolicValueError: if ``lengths`` is not a Tensor.
    """
    if batch_first:
        input = g.op("Transpose", input, perm_i=[1, 0, 2])
    if not lengths.type().isSubtypeOf(torch._C.TensorType.get()):
        raise errors.SymbolicValueError(
            "'lengths' must be a Tensor for ONNX export", input
        )
    # lengths is known to be a tensor here, so its scalar type can be read.
    # int32 is required because these operators' expansion only works with
    # int32 types in Caffe2.
    lengths_type = _type_utils.JitScalarType.from_value(
        lengths, _type_utils.JitScalarType.UNDEFINED
    )
    if lengths_type != _type_utils.JitScalarType.INT:
        lengths = g.op("Cast", lengths, to_i=_C_onnx.TensorProtoDataType.INT32)
    return g.op("prim::PackPadded", input, lengths, outputs=2)
@_onnx_symbolic("aten::_pad_packed_sequence")
@symbolic_helper.parse_args("v", "v", "i", "t", "v")
def _pad_packed_sequence(
    g: jit_utils.GraphContext,
    data,
    batch_sizes,
    batch_first,
    padding_value,
    total_length,
):
    """Emit a ``prim::PadPacked`` placeholder for ``aten::_pad_packed_sequence``.

    Returns ``(data, lengths)``. ``data`` is transposed with perm [1, 0, 2]
    when ``batch_first`` is set. ``padding_value`` and ``total_length`` are
    not used; ``total_length`` only matters for data-parallel training.
    """
    padded, lengths = g.op("prim::PadPacked", data, batch_sizes, outputs=2)
    if not batch_first:
        return padded, lengths
    return g.op("Transpose", padded, perm_i=[1, 0, 2]), lengths
@_onnx_symbolic("aten::randint")
def randint(g: jit_utils.GraphContext, low, high, shapes, dtype, *options):
    """Export ``aten::randint`` as a uniform float sample converted to integers.

    ``dtype``, ``low`` and ``high`` must be export-time constants; ``dtype``
    defaults to INT64. A static ``shapes`` uses RandomUniform directly. A
    dynamic one samples "like" a float ConstantOfShape template. The sample is
    cast to INT64, then to ``dtype`` if that differs. Extra ``options`` are
    ignored.
    """
    dtype = symbolic_helper._get_const(dtype, "i", "dtype")
    low_i = symbolic_helper._get_const(low, "i", "low")
    high_i = symbolic_helper._get_const(high, "i", "high")
    scalar_type = (
        _type_utils.JitScalarType.INT64
        if dtype is None
        else _type_utils.JitScalarType(dtype)
    )
    if low_i is None:
        raise symbolic_helper._onnx_unsupported("randint", low)
    if high_i is None:
        raise symbolic_helper._onnx_unsupported("randint", high)

    bounds = {"low_f": low_i, "high_f": high_i}
    shape = symbolic_helper._maybe_get_const(shapes, "is")
    if symbolic_helper._is_value(shape):
        # Shape is only known at runtime: sample "like" a template of that shape.
        template = g.op(
            "ConstantOfShape",
            shapes,
            value_t=torch.tensor([0], dtype=torch.float),
        )
        samples = g.op("RandomUniformLike", template, **bounds)
    else:
        samples = g.op("RandomUniform", shape_i=shape, **bounds)

    # cast to integer type: int64 first, then the requested dtype if different
    int64_type = _type_utils.JitScalarType.INT64
    result = g.op("Cast", samples, to_i=int64_type.onnx_type())
    if scalar_type != int64_type:
        result = g.op("Cast", result, to_i=scalar_type.onnx_type())
    return result
@_onnx_symbolic("aten::randint_like")
def randint_like(g: jit_utils.GraphContext, self, low, high, dtype, *options):
    """Export ``aten::randint_like`` as RandomUniformLike over ``self``, cast to ints.

    ``dtype``, ``low`` and ``high`` must be export-time constants; ``dtype``
    defaults to INT64. The sample is cast to INT64, then to ``dtype`` if that
    differs. Extra ``options`` are ignored.
    """
    dtype = symbolic_helper._get_const(dtype, "i", "dtype")
    low_i = symbolic_helper._get_const(low, "i", "low")
    high_i = symbolic_helper._get_const(high, "i", "high")
    scalar_type = (
        _type_utils.JitScalarType.INT64
        if dtype is None
        else _type_utils.JitScalarType(dtype)
    )
    if low_i is None:
        raise symbolic_helper._onnx_unsupported("randint", low)
    if high_i is None:
        raise symbolic_helper._onnx_unsupported("randint", high)

    samples = g.op("RandomUniformLike", self, low_f=low_i, high_f=high_i)

    # cast to integer type: int64 first, then the requested dtype if different
    int64_type = _type_utils.JitScalarType.INT64
    result = g.op("Cast", samples, to_i=int64_type.onnx_type())
    if scalar_type != int64_type:
        result = g.op("Cast", result, to_i=scalar_type.onnx_type())
    return result
@_onnx_symbolic("aten::randn")
def randn(g: jit_utils.GraphContext, shapes, dtype, *options):
    """Export ``aten::randn`` as RandomNormal or RandomNormalLike.

    ``dtype`` must be an export-time constant (FLOAT when None). A static
    ``shapes`` uses RandomNormal directly. A dynamic one samples "like" a
    ConstantOfShape template. Extra ``options`` are ignored.
    """
    dtype = symbolic_helper._get_const(dtype, "i", "dtype")
    scalar_type = (
        _type_utils.JitScalarType.FLOAT
        if dtype is None
        else _type_utils.JitScalarType(dtype)
    )
    shape = symbolic_helper._maybe_get_const(shapes, "is")
    if not symbolic_helper._is_value(shape):
        return g.op("RandomNormal", shape_i=shape, dtype_i=scalar_type.onnx_type())
    template = g.op(
        "ConstantOfShape",
        shapes,
        value_t=torch.tensor([0], dtype=torch.float),
    )
    return g.op("RandomNormalLike", template, dtype_i=scalar_type.onnx_type())
@_onnx_symbolic("aten::rand")
def rand(g: jit_utils.GraphContext, shapes, dtype, *options):
    """Export ``aten::rand`` as RandomUniform or RandomUniformLike.

    ``dtype`` must be an export-time constant (FLOAT when None). A static
    ``shapes`` uses RandomUniform directly. A dynamic one samples "like" a
    ConstantOfShape template. Extra ``options`` are ignored.
    """
    dtype = symbolic_helper._get_const(dtype, "i", "dtype")
    scalar_type = (
        _type_utils.JitScalarType.FLOAT
        if dtype is None
        else _type_utils.JitScalarType(dtype)
    )
    shape = symbolic_helper._maybe_get_const(shapes, "is")
    if not symbolic_helper._is_value(shape):
        return g.op("RandomUniform", shape_i=shape, dtype_i=scalar_type.onnx_type())
    template = g.op(
        "ConstantOfShape",
        shapes,
        value_t=torch.tensor([0], dtype=torch.float),
    )
    return g.op("RandomUniformLike", template, dtype_i=scalar_type.onnx_type())
@_onnx_symbolic("aten::randn_like")
def randn_like(
    g: jit_utils.GraphContext,
    self,
    dtype,
    layout=None,
    device=None,
    pin_memory=False,
    memory_format=None,
):
    """Export ``aten::randn_like`` as RandomNormalLike over ``self``.

    Uses the constant ``dtype`` when given; otherwise ``self``'s scalar type,
    falling back to FLOAT when that is unknown. ``layout``, ``device``,
    ``pin_memory`` and ``memory_format`` are ignored.
    """
    requested = symbolic_helper._get_const(dtype, "i", "dtype")
    if requested is not None:
        scalar_type = _type_utils.JitScalarType(requested)
    else:
        scalar_type = _type_utils.JitScalarType.from_value(
            self, _type_utils.JitScalarType.FLOAT
        )
    return g.op("RandomNormalLike", self, dtype_i=scalar_type.onnx_type())
@_onnx_symbolic("aten::rand_like")
def rand_like(
    g: jit_utils.GraphContext,
    self,
    dtype,
    layout=None,
    device=None,
    pin_memory=False,
    memory_format=None,
):
    """Export ``aten::rand_like`` as RandomUniformLike over ``self``.

    Uses the constant ``dtype`` when given; otherwise ``self``'s scalar type,
    falling back to FLOAT when that is unknown. ``layout``, ``device``,
    ``pin_memory`` and ``memory_format`` are ignored.
    """
    requested = symbolic_helper._get_const(dtype, "i", "dtype")
    if requested is None:
        scalar_type = _type_utils.JitScalarType.from_value(
            self, _type_utils.JitScalarType.FLOAT
        )
    else:
        scalar_type = _type_utils.JitScalarType(requested)
    return g.op("RandomUniformLike", self, dtype_i=scalar_type.onnx_type())
@_onnx_symbolic("aten::rrelu")
@symbolic_helper.parse_args("v", "f", "f", "i", "none")
def rrelu(g: jit_utils.GraphContext, input, lower, upper, training, generator):
    """Export randomized leaky ReLU.

    In training mode, slopes are sampled with RandomUniformLike over ``input``
    between ``lower`` and ``upper`` and applied with PRelu. Otherwise the fixed
    slope ``(upper + lower) / 2`` is used with LeakyRelu. ``generator`` is
    ignored.
    """
    if training:
        slopes = g.op("RandomUniformLike", input, high_f=upper, low_f=lower)
        return g.op("PRelu", input, slopes)
    midpoint = (upper + lower) / 2.0
    return g.op("LeakyRelu", input, alpha_f=midpoint)
@_onnx_symbolic("aten::bernoulli")
def bernoulli(g: jit_utils.GraphContext, input, p=None, generator=None, out=None):
    """Export ``aten::bernoulli`` by comparing uniform samples with probabilities.

    Draws samples in ``[0, 1]`` with RandomUniformLike shaped like ``input``
    (in ``input``'s dtype) and returns ``Cast(samples < prob, dtype(input))``.
    ``prob`` is ``p`` when provided and ``input`` otherwise.

    Unsupported cases are a non-None ``out``, a non-None ``generator`` and an
    unknown input dtype. In every such case the result of ``_unimplemented`` is
    returned, so no graph with silently different semantics is built past that
    point. Previously ``out``/``generator`` fell through in non-strict export
    modes.
    """
    if out is not None and not symbolic_helper._is_none(out):
        return symbolic_helper._unimplemented(
            "Bernoulli", "out parameter is not supported for bernoulli", input
        )
    if generator is not None and not symbolic_helper._is_none(generator):
        return symbolic_helper._unimplemented(
            "Bernoulli", "generator is not supported for bernoulli", input
        )

    dtype = _type_utils.JitScalarType.from_value(
        input, _type_utils.JitScalarType.UNDEFINED
    )
    if dtype == _type_utils.JitScalarType.UNDEFINED:
        return symbolic_helper._unimplemented(
            "Bernoulli", "input dtype not accessible", input
        )

    rands = g.op(
        "RandomUniformLike",
        input,
        high_f=1.0,
        low_f=0.0,
        dtype_i=dtype.onnx_type(),
    )
    prob = p if p is not None and not symbolic_helper._is_none(p) else input
    output = g.op("Less", rands, prob)
    return g.op("Cast", output, to_i=dtype.onnx_type())
@_onnx_symbolic("aten::log_sigmoid")
@symbolic_helper.parse_args("v")
def log_sigmoid(g: jit_utils.GraphContext, input):
    """Export ``aten::log_sigmoid`` as ``Log(Sigmoid(input))``."""
    return g.op("Log", g.op("Sigmoid", input))
@_onnx_symbolic("aten::erf")
@symbolic_helper.parse_args("v")
def erf(g: jit_utils.GraphContext, input):
    """Export ``aten::erf`` as the ONNX Erf operator."""
    erf_out = g.op("Erf", input)
    return erf_out
@_onnx_symbolic("aten::flatten")
@symbolic_helper.quantized_args(True, False, False)
@symbolic_helper.parse_args("v", "i", "i")
def flatten(g: jit_utils.GraphContext, input, start_dim, end_dim):
    """Export ``aten::flatten`` (collapse axes ``start_dim..end_dim`` into one).

    The input rank must be known at export time. Rank-0 inputs are reshaped to
    ``[1]`` and rank-1 inputs pass through Identity. When the result is 2-D,
    the native ONNX Flatten is used; otherwise the Reshape-based
    ``_flatten_helper`` is used.

    Both ``start_dim`` and ``end_dim`` are normalized to non-negative axes. This
    lets negative arguments (e.g. ``flatten(x, -3)`` on a 4-D tensor) reach the
    Flatten fast paths too.
    """
    dim = symbolic_helper._get_tensor_rank(input)
    if dim is None:
        return symbolic_helper._unimplemented(
            "dim",
            "ONNX and PyTorch use different strategies to split the input. "
            "Input rank must be known at export time.",
            input,
        )

    if dim == 0:
        return symbolic_helper._reshape_helper(g, input, [1])
    if dim == 1:
        return g.op("Identity", input)

    # TODO: remove this as onnx opset 11 spec allows negative axes
    if start_dim < 0:
        start_dim = dim + start_dim
    if end_dim < 0:
        end_dim = dim + end_dim

    # use ONNX's Flatten operator for cases where the output shape is 2D
    if start_dim == 1 and end_dim == dim - 1:
        return g.op("Flatten", input, axis_i=start_dim)
    if start_dim == 0 and end_dim == dim - 2:
        return g.op("Flatten", input, axis_i=end_dim + 1)

    return symbolic_helper._flatten_helper(g, input, start_dim, end_dim, dim)
@_onnx_symbolic("aten::nonzero")
@symbolic_helper.parse_args("v")
def nonzero(g: jit_utils.GraphContext, input):
    """Emitted from `torch.nonzero(x, as_tuple=False)`.

    ONNX NonZero lays indices out one row per input axis. The result is
    transposed to PyTorch's one-row-per-element layout.
    """
    per_axis_indices = g.op("NonZero", input)
    return t(g, per_axis_indices)
@_onnx_symbolic("aten::nonzero_numpy")
def nonzero_numpy(g: jit_utils.GraphContext, input, _outputs=None):
    """Emitted from `torch.nonzero(x, as_tuple=True)`.

    Builds the ``nonzero`` index matrix and unbinds it along axis 1, giving one
    index tensor per input dimension.
    """
    index_matrix = nonzero(g, input)
    return unbind(g, index_matrix, 1, _outputs=_outputs)
@_onnx_symbolic("aten::isnan")
@symbolic_helper.parse_args("v")
def isnan(g: jit_utils.GraphContext, input):
    """Export ``aten::isnan`` as the ONNX IsNaN operator."""
    return g.op("IsNaN", input)
@_onnx_symbolic("aten::any")
def _any(g: jit_utils.GraphContext, *args):
    """Export ``aten::any`` as ``ReduceSum(Cast(input, int64)) > 0``.

    Overloads:
      * ``aten::any(Tensor self)``: reduce over every element.
      * ``aten::any(Tensor self, int[]? dim, bool keepdim)``: ``dim`` may be a
        single int or an int list; both must be constants.
    """
    if len(args) == 1:
        input = args[0]
        axes, keepdims = None, 0
    else:
        input, raw_dim, raw_keepdim = args
        dim_tensor = symbolic_helper._parse_arg(raw_dim, "t")
        # view(-1) turns a scalar dim or a list of dims into a flat sequence.
        axes = [int(d) for d in dim_tensor.view(-1)]
        keepdims = symbolic_helper._parse_arg(raw_keepdim, "i")

    as_int64 = g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT64)
    total = symbolic_helper._reducesum_helper(
        g, as_int64, axes_i=axes, keepdims_i=keepdims
    )
    zero = g.op("Constant", value_t=torch.tensor(0, dtype=torch.long))
    return gt(g, total, zero)
@_onnx_symbolic("aten::all")
def _all(g: jit_utils.GraphContext, *args):
    """Export ``aten::all`` via De Morgan: ``all(x) == not any(not x)``.

    Supports ``aten::all(Tensor self)`` and
    ``aten::all(Tensor self, int[]? dim, bool keepdim)``.
    """
    negated = g.op("Not", args[0])
    if len(args) == 1:
        any_negated = _any(g, negated)
    else:
        any_negated = _any(g, negated, args[1], args[2])
    return g.op("Not", any_negated)
@_onnx_symbolic("aten::narrow")
@symbolic_helper.parse_args("v", "i", "i", "i")
def narrow(g: jit_utils.GraphContext, input, dim, start, length):
    """Export ``aten::narrow`` as a slice ``[start, start + length)`` along ``dim``."""
    stop = start + length
    return symbolic_helper._slice_helper(
        g, input, axes=[dim], starts=[start], ends=[stop]
    )
@_onnx_symbolic("aten::argmax")
@symbolic_helper.parse_args("v", "v", "b")
def argmax(
    g: jit_utils.GraphContext,
    input: torch._C.Value,
    dim: torch._C.Value,
    keepdim: bool,
):
    """Export ``aten::argmax`` through the shared ArgMin/ArgMax helper."""
    onnx_op = "ArgMax"
    return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, onnx_op)
@_onnx_symbolic("aten::argmin")
@symbolic_helper.parse_args("v", "v", "b")
def argmin(
    g: jit_utils.GraphContext,
    input: torch._C.Value,
    dim: torch._C.Value,
    keepdim: bool,
):
    """Export ``aten::argmin`` through the shared ArgMin/ArgMax helper."""
    onnx_op = "ArgMin"
    return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, onnx_op)
@_onnx_symbolic("aten::scatter")
@symbolic_helper.parse_args("v", "i", "v", "v")
def scatter(g: jit_utils.GraphContext, self, dim, index, src):
    """Export ``aten::scatter`` with the ONNX Scatter operator.

    A tensor ``src`` is scattered as is. A scalar ``src`` is first cast to
    ``self``'s scalar type when it differs, then expanded to ``index``'s shape.
    PyTorch allows a different type for a scalar src, but not for a tensor src.
    """
    src_type = _type_utils.JitScalarType.from_value(
        src, _type_utils.JitScalarType.UNDEFINED
    )
    src = symbolic_helper._maybe_get_scalar(src)
    if not symbolic_helper._is_value(src):
        self_type = _type_utils.JitScalarType.from_value(self)
        if self_type != src_type:
            src = g.op("Cast", src, to_i=self_type.onnx_type())
        src = expand_as(g, src, index)
    return g.op("Scatter", self, index, src, axis_i=dim)
@_onnx_symbolic("aten::scatter_add")
@symbolic_helper.parse_args("v", "i", "v", "v")
def scatter_add(g: jit_utils.GraphContext, self, dim, index, src):
    """Export ``aten::scatter_add`` as ``self + scatter(zeros, dim, index, src)``.

    Requires a known input dtype. The zero tensor is a Constant when
    ``_get_tensor_sizes(self, allow_nonstatic=False)`` yields a non-empty size
    list, and ``zeros_like`` otherwise.

    NOTE(review): if ``_scatter_helper`` is a plain (non-accumulating)
    scatter, duplicate entries in ``index`` are not summed as in
    ``torch.scatter_add``. Confirm.
    """
    scalar_type = symbolic_helper._try_get_scalar_type(self)
    if scalar_type is None:
        return symbolic_helper._unimplemented(
            "scatter_add", "input dtype not accessible", self
        )
    static_sizes = symbolic_helper._get_tensor_sizes(self, allow_nonstatic=False)
    if not static_sizes:
        zeros = zeros_like(g, self, scalar_type)
    else:
        zeros = g.op(
            "Constant", value_t=torch.zeros(static_sizes, dtype=scalar_type.dtype())
        )
    scattered = symbolic_helper._scatter_helper(g, zeros, dim, index, src)
    return add(g, self, scattered)
@_onnx_symbolic("aten::log2")
def log2(g: jit_utils.GraphContext, self):
    """Export ``aten::log2`` via change of base: ``log2(x) = ln(x) / ln(2)``."""
    ln_two = 0.693147180559945309
    natural_log = log(g, self)
    divisor = g.op("Constant", value_t=torch.tensor(ln_two))
    return g.op("Div", natural_log, divisor)
@_onnx_symbolic("aten::is_floating_point")
def is_floating_point(g: jit_utils.GraphContext, self):
    """Export ``aten::is_floating_point`` as a constant one-element bool tensor.

    The answer is decided at export time from ``self``'s known scalar type.
    """
    flag = 1 if symbolic_helper._is_fp(self) else 0
    return g.op("Constant", value_t=torch.BoolTensor([flag]))
@_onnx_symbolic("aten::__is_")
def __is_(g: jit_utils.GraphContext, self, other):
    """Export ``aten::__is__``.

    When ``other`` is None, the result is a constant bool tensor that is true
    only if ``self`` is also None. Otherwise this falls back to ``eq``.
    """
    if not symbolic_helper._is_none(other):
        return eq(g, self, other)
    both_none = symbolic_helper._is_none(self)
    return g.op("Constant", value_t=torch.BoolTensor([1 if both_none else 0]))
@_onnx_symbolic("aten::__isnot_")
@wrap_logical_op_with_negation
def __isnot_(g: jit_utils.GraphContext, self, other):
    """Export ``aten::__isnot__`` as the negation of ``__is_``.

    The body computes ``__is_``; ``wrap_logical_op_with_negation`` applies
    the negation.
    """
    identity_check = __is_(g, self, other)
    return identity_check
@_onnx_symbolic("aten::one_hot")
def one_hot(g: jit_utils.GraphContext, self, num_classes):
    """Export ``aten::one_hot`` with ONNX OneHot on the last axis.

    The (off, on) values are ``[0, 1]`` as int64. Narrow integer
    ``num_classes`` are cast to int64 because onnxruntime supports only limited
    type combinations for OneHot.
    """
    off_on = g.op("Constant", value_t=torch.LongTensor([0, 1]))
    narrow_int_types = {
        _type_utils.JitScalarType.UINT8,
        _type_utils.JitScalarType.INT8,
        _type_utils.JitScalarType.INT,
        _type_utils.JitScalarType.INT16,
    }
    depth_type = _type_utils.JitScalarType.from_value(
        num_classes, _type_utils.JitScalarType.UNDEFINED
    )
    if depth_type in narrow_int_types:
        num_classes = g.op("Cast", num_classes, to_i=_C_onnx.TensorProtoDataType.INT64)
    return g.op("OneHot", self, num_classes, off_on, axis_i=-1)
@_onnx_symbolic("aten::gather")
@symbolic_helper.parse_args("v", "i", "v", "v")
def gather(g: jit_utils.GraphContext, self, dim, index, sparse_grad=False):
    """Export ``aten::gather`` without GatherElements.

    NOTE: This workaround is needed because GatherElements only exists from
    opset 11, and ONNX Gather differs from torch.gather. The export works in
    four steps:

    1. One-hot encode ``index`` on axis ``dim`` (depth ``self.size(dim)``).
    2. Cast the result to ``self``'s type.
    3. Multiply by ``self`` unsqueezed at ``dim + 1``.
    4. Sum over ``dim``.

    A negative ``dim`` is normalized by the input rank first. Otherwise
    ``dim + 1`` would point at the wrong axis (e.g. 0 for dim=-1) and
    misalign with the OneHot axis. ``sparse_grad=True`` is unsupported.
    """
    if symbolic_helper._maybe_get_const(sparse_grad, "i"):
        return symbolic_helper._unimplemented("gather", "sparse_grad == True", self)

    if dim < 0:
        rank = symbolic_helper._get_tensor_rank(self)
        if rank is None:
            return symbolic_helper._unimplemented(
                "gather", "negative dim with unknown input rank", self
            )
        dim += rank

    scalar_type = _type_utils.JitScalarType.from_value(self)
    values = g.op("Constant", value_t=torch.LongTensor([0, 1]))
    depth = size(g, self, g.op("Constant", value_t=torch.LongTensor([dim])))
    index = g.op(
        "Cast",
        g.op("OneHot", index, depth, values, axis_i=dim),
        to_i=scalar_type.onnx_type(),
    )
    # The new unit axis at dim + 1 lines self's `dim` axis up with the one-hot
    # depth axis (which OneHot inserts at position `dim`).
    mul = g.op("Mul", symbolic_helper._unsqueeze_helper(g, self, [dim + 1]), index)
    return symbolic_helper._reducesum_helper(g, mul, axes_i=[dim], keepdims_i=0)
@symbolic_helper.parse_args("v", "is", "i", "i")
def _var_mean(g: jit_utils.GraphContext, input, dim, correction, keepdim):
    """Parse constant ``dim``/``correction``/``keepdim`` and defer to the shared helper.

    Callers in this file unpack the helper's result as a ``(var, mean)`` pair.
    """
    result = symbolic_helper._var_mean_helper(g, input, dim, correction, keepdim)
    return result
@_onnx_symbolic("aten::std")
def std(g: jit_utils.GraphContext, input, *args):
    """Export ``aten::std`` as ``Sqrt`` of the variance computed by ``var_mean``."""
    variance, _mean = var_mean(g, input, *args)
    return g.op("Sqrt", variance)
@_onnx_symbolic("aten::var")
def var(g: jit_utils.GraphContext, input, *args):
    """Export ``aten::var`` as the variance half of ``var_mean``."""
    variance, _mean = var_mean(g, input, *args)
    return variance
@_onnx_symbolic("aten::var_mean")
def var_mean(g: jit_utils.GraphContext, input, *args):
    """Export ``aten::var_mean`` (also used by std, var and std_mean).

    A single extra argument is passed in the ``correction`` slot, with no
    ``dim`` and no ``keepdim``. Otherwise ``args`` are forwarded as
    ``(dim, correction, keepdim)``.
    """
    if len(args) != 1:
        return _var_mean(g, input, *args)
    correction = args[0]
    return _var_mean(g, input, None, correction, None)
@_onnx_symbolic("aten::std_mean")
def std_mean(g: jit_utils.GraphContext, input, *args):
    """Export ``aten::std_mean`` as ``(Sqrt(var), mean)`` using ``var_mean``."""
    variance, mean = var_mean(g, input, *args)
    std_dev = g.op("Sqrt", variance)
    return std_dev, mean
@_onnx_symbolic("aten::logsumexp")
@symbolic_helper.parse_args("v", "is", "i")
def logsumexp(g: jit_utils.GraphContext, input, dim, keepdim):
    """Export ``aten::logsumexp`` as ReduceLogSumExp over the constant ``dim`` list."""
    return g.op(
        "ReduceLogSumExp",
        input,
        axes_i=dim,
        keepdims_i=keepdim,
    )
@_onnx_symbolic("aten::arange")
def arange(g: jit_utils.GraphContext, *args):
    """Export ``aten::arange`` without a Range operator.

    The overload is chosen by ``len(args)`` (see the schema comment in each
    branch). Every branch:

    1. Computes an element count.
    2. Builds ``ones(count)`` and takes ``nonzero`` of it.
    3. Squeezes axis 1 to get ``[0, 1, ..., count - 1]``.
    4. Applies ``step``/``start`` where present.
    5. Casts the result to ``dtype``.

    Any other argument count is reported via ``_unimplemented``.
    """
    def _get_arange_dtype(dtype):
        # dtype arrives as a graph value; fold it to a Python constant if possible.
        dtype = symbolic_helper._maybe_get_const(dtype, "i")
        return dtype
    def _float_step_convert(range_tensor):
        # The element count must be an integer: round a floating count up
        # (Ceil) and cast it to int64.
        if symbolic_helper._is_fp(range_tensor):
            range_tensor = g.op(
                "Cast",
                g.op("Ceil", range_tensor),
                to_i=_type_utils.JitScalarType.INT64.onnx_type(),
            )
        return range_tensor
    if len(args) == 2 or len(args) == 5:
        # count = end; values = 0 .. end - 1 (implicit start 0, step 1).
        if len(args) == 2:
            # aten::arange(Scalar end, Tensor out)
            dtype = None
        else:
            # aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
            dtype = _get_arange_dtype(args[1])
        dtype, end, start, step = symbolic_helper._arange_cast_helper(
            g, end=args[0], dtype=dtype
        )
        end = symbolic_helper._unsqueeze_helper(g, end, [0])
        range_tensor = _float_step_convert(end)
        arange_tensor = symbolic_helper._squeeze_helper(
            g, nonzero(g, ones(g, range_tensor, dtype, None, None)), [1]
        )
        return g.op(
            "Cast", arange_tensor, to_i=_type_utils.JitScalarType(dtype).onnx_type()
        )
    elif len(args) == 4 or len(args) == 7:
        # count = (end - start) / step; values = index * step + start.
        if len(args) == 4:
            # aten::arange(Scalar start, Scalar end, Scalar step, Tensor out)
            dtype = None
        else:
            # aten::arange(Scalar start, Scalar end, Scalar step, ScalarType dtype, Layout, Device, bool pin_memory)
            dtype = _get_arange_dtype(args[3])
        dtype, end, start, step = symbolic_helper._arange_cast_helper(
            g, start=args[0], end=args[1], step=args[2], dtype=dtype
        )
        step = symbolic_helper._unsqueeze_helper(g, step, [0])
        end = symbolic_helper._unsqueeze_helper(g, end, [0])
        start = symbolic_helper._unsqueeze_helper(g, start, [0])
        range_tensor = _float_step_convert(g.op("Div", g.op("Sub", end, start), step))
        arange_tensor = symbolic_helper._squeeze_helper(
            g, nonzero(g, ones(g, range_tensor, None, None, None)), [1]
        )
        arange_tensor = g.op("Add", g.op("Mul", arange_tensor, step), start)
        return g.op(
            "Cast", arange_tensor, to_i=_type_utils.JitScalarType(dtype).onnx_type()
        )
    elif len(args) == 6:
        # count = end - start; values = index + start (implicit step 1).
        # aten::arange(Scalar start, Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
        dtype = _get_arange_dtype(args[2])
        dtype, end, start, step = symbolic_helper._arange_cast_helper(
            g, start=args[0], end=args[1], dtype=dtype
        )
        end = symbolic_helper._unsqueeze_helper(g, end, [0])
        start = symbolic_helper._unsqueeze_helper(g, start, [0])
        range_tensor = _float_step_convert(g.op("Sub", end, start))
        arange_tensor = g.op(
            "Add",
            symbolic_helper._squeeze_helper(
                g, nonzero(g, ones(g, range_tensor, dtype, *(args[3:]))), [1]
            ),
            start,
        )
        return g.op(
            "Cast", arange_tensor, to_i=_type_utils.JitScalarType(dtype).onnx_type()
        )
    return symbolic_helper._unimplemented("aten::arange", f"with {len(args)} arguments")
@_onnx_symbolic("aten::linspace")
def linspace(
    g: jit_utils.GraphContext, start, end, steps, dtype, layout, device, pin_memory
):
    """Export ``aten::linspace`` as ``start + arange(steps) * step``.

    ``step = (end - start) / (steps - 1)``. The divisor is clamped to 1 when
    ``steps - 1 <= 0``. Without the clamp, ``steps == 1`` divides by zero and
    the single output becomes ``0 * inf`` (or ``0 / 0``), i.e. NaN, instead of
    ``start``. With the clamp the index tensor ``[0]`` yields exactly
    ``start``. For ``steps == 0`` the index tensor is empty either way.

    ``dtype``, ``layout``, ``device`` and ``pin_memory`` are not used here.
    """
    range_tensor = symbolic_helper._arange_helper(g, steps, None)
    one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
    zero = g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64))
    intervals = sub(g, steps, one)
    # Where/Greater instead of Max: opset-9 Max only accepts floating types,
    # while steps is an int64 value here.
    safe_intervals = g.op(
        "Where", g.op("Greater", intervals, zero), intervals, one
    )
    step = div(g, sub(g, end, start), safe_intervals)
    return add(g, mul(g, range_tensor, step), start)
@_onnx_symbolic("aten::lift")
def lift(g: jit_utils.GraphContext, self):
    """Export ``aten::lift`` as a no-op.

    From the perspective of tracing for ONNX, at::lift() does nothing, so the
    input value is returned as is.
    """
    return self
@_onnx_symbolic("aten::masked_fill")
def masked_fill(g: jit_utils.GraphContext, self, mask, value):
    """Implement the masked_fill functionality available for a pytorch tensor in ONNX.

    Fills elements of the input tensor with `value` where `mask` is True,
    emitted as ``Where(Cast(mask, bool), value, self)``. A constant ``value``
    is folded to a scalar and passed through ``_if_scalar_type_as`` against
    ``self``.
    """
    bool_mask = g.op("Cast", mask, to_i=_C_onnx.TensorProtoDataType.BOOL)
    scalar_or_value = symbolic_helper._maybe_get_scalar(value)
    fill = symbolic_helper._if_scalar_type_as(scalar_or_value, self)
    return g.op("Where", bool_mask, fill, self)
@_onnx_symbolic("aten::masked_fill_")
def masked_fill_(g: jit_utils.GraphContext, self, mask, value):
    """In-place ``aten::masked_fill_``; exported exactly like ``masked_fill``."""
    filled = masked_fill(g, self, mask, value)
    return filled
@_onnx_symbolic("aten::index")
def index(g: jit_utils.GraphContext, self, index):
    """Export ``aten::index`` (PyTorch advanced indexing) without GatherND.

    ``index`` is either a packed list of per-axis indices or a single index.
    Bool/uint8 mask indices are first converted to integer positions with
    NonZero (only 1-D masks are correct; a warning is emitted). The cases are:

    * exactly one index: a select along axis 0;
    * no non-None index: ``self`` is returned unchanged;
    * one non-None index: ``index_select`` along that axis;
    * several: the steps are
        1. transpose the indexed axes to the front;
        2. flatten them into one axis;
        3. gather with a combined linear index;
        4. reshape/transpose back (see the inline comment below).
      This needs a known input rank. Negative index values produce incorrect
      results (a warning is emitted).
    """
    if symbolic_helper._is_packed_list(index):
        indices = symbolic_helper._unpack_list(index)
    else:
        indices = [index]
    def try_mask_to_index(index):
        # Turn a uint8/bool mask into integer positions via nonzero + squeeze;
        # every other index (including None) is returned unchanged.
        if not symbolic_helper._is_none(index) and (
            _type_utils.JitScalarType.from_value(
                index, _type_utils.JitScalarType.UNDEFINED
            )
            == _type_utils.JitScalarType.UINT8
            or symbolic_helper._is_bool(index)
        ):
            if g.opset < 9:
                raise errors.SymbolicValueError(
                    "Exporting masked indices are only supported after ONNX opset 9.",
                    self,
                )
            warnings.warn(
                "Exporting aten::index operator with indices of type Byte. "
                "Only 1-D indices are supported. In any other case, "
                "this will produce an incorrect ONNX graph."
            )
            index = symbolic_helper._squeeze_helper(g, nonzero(g, index), [1])
        return index
    indices = [try_mask_to_index(idx) for idx in indices]
    if len(indices) == 1:
        return symbolic_helper._select_helper(
            g, self, 0, indices[0], apply_reshape=False
        )
    else:
        # Multiple tensors as indices. Each tensor could either be
        #   1. prim::Constant()
        #           representing ":" in python indexing. E.g. tensor[:, :]
        #   2. prim::Constant[value=...] or tensor output
        #           representing advanced indexing. E.g. tensor[[0, 1], [2, 0]].
        # For more info on advanced indexing,
        # check https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing
        # Consider a general case of
        #       t: [x_1, y_1, y_2, ..., x_m, ..., y_n]
        # where t is a tensor of rank m+n, {x_i} are axes where tensor index is provided, and {y_i} are axes for ":".
        # Same results can be achieved through transposing t into
        #       t: [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n]
        # and use gatherND. However ONNX does not have gatherND, to use 1d gather we'll need to flatten t
        # and process the tensor indices.
        #       t: [x_1 * x_2 * ... * x_m, y_1 * y_2 * ... * y_n]
        #       tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j))
        # After gather, reshape and transpose back.
        adv_idx_indices = [
            i for i, idx in enumerate(indices) if not symbolic_helper._is_none(idx)
        ]
        if len(adv_idx_indices) == 0:
            return self
        elif len(adv_idx_indices) == 1:
            return index_select(
                g, self, adv_idx_indices[0], indices[adv_idx_indices[0]]
            )
        else:
            rank = symbolic_helper._get_tensor_rank(self)
            if rank is None:
                return symbolic_helper._unimplemented(
                    "aten::index",
                    "operator of advanced indexing on tensor of unknown rank. "
                    "Try turning on shape inference during export: "
                    "torch.onnx._export(..., onnx_shape_inference=True).",
                    self,
                )
            # TODO: If indexing is supported natively in ONNX in future opsets,
            #       update the warning to recommend exporting with higher opset version.
            warnings.warn(
                "Exporting aten::index operator of advanced indexing in opset "
                f"{GLOBALS.export_onnx_opset_version}"
                " is achieved by combination of multiple ONNX operators, "
                "including Reshape, Transpose, Concat, and Gather. "
                "If indices include negative values, the exported graph will produce incorrect results."
            )
            adv_idx_count = len(adv_idx_indices)
            shape_tensor = _shape_as_tensor(g, self)
            # One 1-element shape tensor per input axis, read from Shape(self).
            dim_tensor_list = [
                g.op(
                    "Gather",
                    shape_tensor,
                    g.op("Constant", value_t=torch.LongTensor([dim])),
                    axis_i=0,
                )
                for dim in range(rank)
            ]
            # Move the advanced-indexed axes to the front, then fold them into one.
            self = g.op(
                "Transpose",
                self,
                perm_i=adv_idx_indices
                + [i for i in range(rank) if i not in adv_idx_indices],
            )
            self = g.op("Flatten", self, axis_i=adv_idx_count)
            # Note that tensor indices will be broadcasted while accumulating. Thus we get the final subarray shape as well.
            cum_adv_index = indices[adv_idx_indices[-1]]
            multiplier = dim_tensor_list[adv_idx_indices[-1]]
            for i in range(adv_idx_count - 2, -1, -1):
                adv_index = g.op("Mul", indices[adv_idx_indices[i]], multiplier)
                cum_adv_index = g.op("Add", cum_adv_index, adv_index)
                multiplier = g.op(
                    "Mul", multiplier, dim_tensor_list[adv_idx_indices[i]]
                )
            # perform gather
            self = index_select(g, self, 0, cum_adv_index)
            cum_adv_index_shape_tensor = _shape_as_tensor(g, cum_adv_index)
            # check if all advanced indices are consecutive.
            # Refer to https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#combining-advanced-and-basic-indexing
            # to understand how the subarray position is decided.
            if adv_idx_indices == list(
                range(adv_idx_indices[0], adv_idx_indices[-1] + 1)
            ):
                # unfold regular index axes
                folded_adv_idx_shape_list = [
                    g.op("Constant", value_t=torch.LongTensor([-1]))
                ] + [
                    dim_tensor_list[i] for i in range(rank) if i not in adv_idx_indices
                ]
                folded_adv_idx_shape = g.op(
                    "Concat", *folded_adv_idx_shape_list, axis_i=0
                )
                self = symbolic_helper._reshape_helper(g, self, folded_adv_idx_shape)
                # Transpose folded advanced indexed axis to its original location.
                adv_idx_permute = (
                    list(range(1, adv_idx_indices[0] + 1))
                    + [0]
                    + list(range(adv_idx_indices[0] + 1, rank - adv_idx_count + 1))
                )
                self = g.op("Transpose", self, perm_i=adv_idx_permute)
                # unfold advanced index axes
                final_shape_list = (
                    [dim_tensor_list[i] for i in range(adv_idx_indices[0])]
                    + [cum_adv_index_shape_tensor]
                    + [
                        dim_tensor_list[i]
                        for i in range(adv_idx_indices[0], rank)
                        if i not in adv_idx_indices
                    ]
                )
                final_shape = g.op("Concat", *final_shape_list, axis_i=0)
            else:
                # Non-consecutive advanced indices: the broadcast index shape
                # leads, followed by the remaining ":" axes in order.
                final_shape = g.op(
                    "Concat",
                    cum_adv_index_shape_tensor,
                    *[
                        dim_tensor_list[i]
                        for i in range(rank)
                        if i not in adv_idx_indices
                    ],
                    axis_i=0,
                )
            return symbolic_helper._reshape_helper(g, self, final_shape)
@_onnx_symbolic("aten::linalg_norm")
@symbolic_helper.parse_args("v", "v", "is", "b", "v")
def linalg_norm(
    g: jit_utils.GraphContext,
    self: torch._C.Value,
    ord: torch._C.Value,
    dim: Sequence[int] | None,
    keepdim: bool,
    dtype: torch._C.Value,
):
    """Dispatch ``aten::linalg_norm`` to the vector- or matrix-norm export.

    Conditions based on https://pytorch.org/docs/stable/generated/torch.linalg.norm.html:

    * ``dim`` and ``ord`` both None: flatten ``self`` and take the vector 2-norm.
    * ``dim`` None: a 1-D input gets a vector norm; otherwise a matrix norm
      over axes ``(0, 1)``.
    * ``dim`` of length 1: vector norm, with ``ord`` defaulting to 2.
    * otherwise: matrix norm.

    ``ord_value`` is only set on the vector-norm paths, so it is tested against
    None. A truthiness test would send ``ord=0`` (count of non-zeros, a valid
    vector norm) to the matrix-norm path, which then indexes ``dim[1]`` on a
    one-element ``dim``.
    """
    ord_value = None
    if dim is None:
        if symbolic_helper._is_none(ord):
            self = symbolic_helper._reshape_helper(g, self, [-1])
            ord = g.op("Constant", value_t=torch.LongTensor([2]))
        self_dim = symbolic_helper._get_tensor_rank(self)
        if self_dim is None:
            return symbolic_helper._unimplemented(
                "dim", "Input rank must be known at export time.", self
            )
        if self_dim == 1:
            ord_value = symbolic_helper._parse_arg(ord, "f")
        else:
            dim = [0, 1]
    else:
        if len(dim) == 1:
            if symbolic_helper._is_none(ord):
                ord = g.op("Constant", value_t=torch.LongTensor([2]))
            ord_value = symbolic_helper._parse_arg(ord, "f")
    if ord_value is not None:
        return linalg_vector_norm(g, self, ord_value, dim, keepdim, dtype)
    return linalg_matrix_norm(g, self, ord, dim, keepdim, dtype)
@_onnx_symbolic("aten::linalg_vector_norm")
@symbolic_helper.parse_args("v", "f", "is", "b", "v")
def linalg_vector_norm(
    g: jit_utils.GraphContext,
    self: torch._C.Value,
    ord: float,
    dim: Sequence[int] | None,
    keepdim: bool,
    dtype: torch._C.Value,
):
    """Export ``aten::linalg_vector_norm`` through the shared symbolic helper.

    ``ord`` is a constant float and ``dim`` a constant int list or None; all
    arguments are forwarded unchanged.
    """
    helper = symbolic_helper._linalg_vector_norm_helper
    return helper(g, self, ord, dim, keepdim, dtype)
@_onnx_symbolic("aten::linalg_matrix_norm")
@symbolic_helper.parse_args("v", "v", "is", "b", "v")
def linalg_matrix_norm(
    g: jit_utils.GraphContext,
    self: torch._C.Value,
    ord: torch._C.Value,
    dim: list[int],
    keepdim: bool,
    dtype: torch._C.Value,
):
    """Export ``aten::linalg_matrix_norm``.

    Conditions based on https://pytorch.org/docs/stable/generated/torch.linalg.matrix_norm.html:

    * ``"fro"``, or a numeric ``ord`` that parses to None: Frobenius norm.
    * ``"nuc"`` and ``+-2``: unsupported, because no operators are available
      to compute singular values.
    * ``+-1`` and ``+-inf``: sum ``|self|`` over one axis, then take the max
      (positive ``ord``) or min (negative ``ord``) over the other. For infinite
      ``ord`` the roles of the two axes are swapped.

    Only ``dim[0]`` and ``dim[1]`` are used. They are normalized into local
    variables, so the caller's ``dim`` list is never modified.
    """
    ord_value = symbolic_helper._parse_arg(ord, "s")
    if ord_value == "fro":
        return frobenius_norm(g, self, dim, keepdim)
    elif ord_value == "nuc":
        return symbolic_helper._unimplemented("linalg.matrix_norm", "ord==nuc", self)
    else:
        ord_value = symbolic_helper._parse_arg(ord, "f")
        if ord_value is None:
            return frobenius_norm(g, self, dim, keepdim)
        if ord_value == 2 or ord_value == -2:
            # ord = 2/-2 unimplemented due to lack of operators
            # used to calculate singular values
            return symbolic_helper._unimplemented("linalg.matrix_norm", "ord==2", self)
        # Wrap the dim vector to handle negative dim values
        self_dim = symbolic_helper._get_tensor_rank(self)
        if self_dim is None:
            return symbolic_helper._unimplemented(
                "linalg.matrix_norm", "Input rank must be known at export time.", self
            )
        # Common implementation for cases with
        # ord = 1/-1 and ord = inf/-inf
        sum_axis, reduce_axis = dim[0], dim[1]
        if sum_axis < 0:
            sum_axis += self_dim
        if reduce_axis < 0:
            reduce_axis += self_dim
        if ord_value == math.inf or ord_value == -math.inf:
            sum_axis, reduce_axis = reduce_axis, sum_axis
        if reduce_axis > sum_axis and not keepdim:
            # The first reduction drops sum_axis, shifting later axes left.
            reduce_axis -= 1
        abs_sum = symbolic_helper._reducesum_helper(
            g, g.op("Abs", self), axes_i=[sum_axis], keepdims_i=keepdim
        )
        # `max`/`min` here are this module's aten::max/aten::min symbolics.
        extreme = max if ord_value > 0 else min
        result, _indices = extreme(
            g,
            abs_sum,
            dim_or_y=g.op("Constant", value_t=torch.LongTensor([reduce_axis])),
            keepdim=keepdim,
        )
        return result
@_onnx_symbolic("aten::linalg_cross")
@symbolic_helper.parse_args("v", "v", "i")
def linalg_cross(g: jit_utils.GraphContext, input, other, dim=-1):
    """Export ``aten::linalg_cross`` by delegating to ``cross`` along ``dim``."""
    product = cross(g, input, other, dim)
    return product
@_onnx_symbolic("aten::frobenius_norm")
@symbolic_helper.parse_args("v", "is", "b")
def frobenius_norm(g: jit_utils.GraphContext, self, dim=None, keepdim=False):
    """Export ``aten::frobenius_norm`` as ``Sqrt(ReduceSum(self * self))``.

    ``dim`` is passed straight through as the reduction axes (None included).
    """
    squares = g.op("Mul", self, self)
    sum_of_squares = symbolic_helper._reducesum_helper(
        g, squares, axes_i=dim, keepdims_i=keepdim
    )
    return g.op("Sqrt", sum_of_squares)
@_onnx_symbolic("aten::multinomial")
@symbolic_helper.parse_args("v", "i", "b", "v")
def multinomial(
    g: jit_utils.GraphContext, input, num_samples, replacement=False, generator=None
):
    """Export ``aten::multinomial`` with the ONNX Multinomial operator.

    ``log(input)`` is fed to Multinomial, and int64 sample indices are
    requested with ``sample_size = num_samples``.

    Unsupported: a custom ``generator``, and sampling without replacement when
    ``num_samples > 1``. In those cases the result of ``_unimplemented`` is
    returned. This avoids emitting a graph with different semantics in
    non-strict export modes, e.g. silently sampling *with* replacement.
    """
    if generator is not None and not symbolic_helper._is_none(generator):
        return symbolic_helper._unimplemented(
            "Multinomial", "generator is not supported for multinomial", input
        )
    if not replacement and num_samples > 1:
        return symbolic_helper._unimplemented(
            "Multinomial",
            "replacement=False when num_samples > 1 is not supported for multinomial",
            input,
        )

    log_input = log(g, input)
    return g.op(
        "Multinomial",
        log_input,
        dtype_i=_C_onnx.TensorProtoDataType.INT64,
        sample_size_i=num_samples,
    )
@_onnx_symbolic("aten::baddbmm")
def baddbmm(g: jit_utils.GraphContext, self, batch1, batch2, beta, alpha):
    """Export ``aten::baddbmm`` as ``alpha * (batch1 @ batch2) + beta * self``.

    ``alpha`` and ``beta`` are cast to ``self``'s scalar type before scaling.
    """
    self_type = _type_utils.JitScalarType.from_value(self)
    product = matmul(g, batch1, batch2)
    scaled_product = mul(
        g, product, g.op("Cast", alpha, to_i=self_type.onnx_type())
    )
    scaled_self = mul(g, self, g.op("Cast", beta, to_i=self_type.onnx_type()))
    return add(g, scaled_product, scaled_self)
@_onnx_symbolic("aten::meshgrid")
@symbolic_helper.parse_args("v", "s")
def meshgrid(g: jit_utils.GraphContext, tensor_list, indexing: str | None = None):
    """Export ``aten::meshgrid`` with Reshape + Expand.

    Each input is flattened to 1-D. It is then reshaped to rank
    ``len(inputs)``, with its own length on its own axis and 1 everywhere else,
    and expanded to the full grid shape. ``"xy"`` indexing swaps the first two
    inputs before and the first two outputs after. ``indexing`` defaults to
    ``"ij"``.

    Raises:
        errors.SymbolicValueError: for an indexing other than "ij" or "xy".
    """
    if indexing is None:
        indexing = "ij"
    elif indexing not in {"ij", "xy"}:
        raise errors.SymbolicValueError(
            f"Unsupported indexing: {indexing}", tensor_list
        )
    swap_xy = indexing == "xy"

    inputs = symbolic_helper._unpack_list(tensor_list)
    if swap_xy:
        inputs[:2] = inputs[1::-1]

    flat_inputs = [
        symbolic_helper._reshape_helper(
            g, tensor, g.op("Constant", value_t=torch.LongTensor([-1]))
        )
        for tensor in inputs
    ]
    lengths = [g.op("Shape", tensor) for tensor in flat_inputs]
    grid_shape = g.op("Concat", *lengths, axis_i=0)

    grids = []
    for axis, tensor in enumerate(flat_inputs):
        unit = g.op("Constant", value_t=torch.ones(1, dtype=torch.int64))
        axis_shape = [unit] * len(flat_inputs)
        axis_shape[axis] = lengths[axis]
        shaped = _reshape_from_tensor(g, tensor, g.op("Concat", *axis_shape, axis_i=0))
        grids.append(g.op("Expand", shaped, grid_shape))

    if swap_xy:
        grids[0], grids[1] = grids[1], grids[0]
    return g.op("prim::ListConstruct", *grids)
@_onnx_symbolic("aten::remainder")
def remainder(g: jit_utils.GraphContext, input, other):
    """Export ``aten::remainder`` as ``input - _floor_divide(input, other) * other``."""
    quotient = _floor_divide(g, input, other)
    return g.op("Sub", input, g.op("Mul", quotient, other))
@_onnx_symbolic("aten::gelu")
@symbolic_helper.parse_args("v", "s")
def gelu(g: jit_utils.GraphContext, self: torch._C.Value, approximate: str = "none"):
    """Export ``aten::gelu``.

    * ``approximate == "tanh"``:
      ``0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))``
    * otherwise (exact): ``x * (1 + erf(x / sqrt(2))) * 0.5``

    All constants are float64 tensors.
    """
    if approximate != "tanh":
        sqrt_two = 1.4142135623730951
        erf_term = g.op(
            "Erf", g.op("Div", self, torch.tensor(sqrt_two, dtype=torch.double))
        )
        one_plus_erf = add(
            g, erf_term, g.op("Constant", value_t=torch.tensor(1, dtype=torch.double))
        )
        return mul(
            g,
            mul(g, self, one_plus_erf),
            g.op("Constant", value_t=torch.tensor(0.5, dtype=torch.double)),
        )

    beta = torch.tensor(math.sqrt(2 / math.pi), dtype=torch.double)
    kappa = torch.tensor(0.044715, dtype=torch.double)
    one = torch.tensor(1.0, dtype=torch.double)
    half = torch.tensor(0.5, dtype=torch.double)
    x_cubed = mul(g, self, mul(g, self, self))
    tanh_arg = mul(g, beta, add(g, self, mul(g, kappa, x_cubed)))
    return mul(g, half, mul(g, self, add(g, one, g.op("Tanh", tanh_arg))))
@_onnx_symbolic("aten::group_norm")
@symbolic_helper.quantized_args(True, False, False, False)
@symbolic_helper.parse_args("v", "i", "v", "v", "f", "i")
def group_norm(
    g: jit_utils.GraphContext, input, num_groups, weight, bias, eps, cudnn_enabled
):
    """Export ``aten::group_norm`` via InstanceNormalization.

    The export proceeds as follows:

    1. Reshape the input to ``[N, num_groups, -1]``.
    2. Normalize with InstanceNormalization using unit scale and zero shift.
    3. Reshape back to the input's shape.
    4. Scale by ``weight`` and shift by ``bias``. These default to ``[1.0]``
       and ``[0.0]`` when absent and are unsqueezed to broadcast against
       ``[N, C, *]``.

    Raises:
        errors.SymbolicValueError: if the statically known channel count is
            not divisible by ``num_groups``. This replaces an ``assert``, which
            would be skipped under ``python -O``.
    """
    channel_size = symbolic_helper._get_tensor_dim_size(input, 1)
    if channel_size is not None and channel_size % num_groups != 0:
        raise errors.SymbolicValueError(
            f"group_norm: number of channels ({channel_size}) must be divisible "
            f"by num_groups ({num_groups})",
            input,
        )
    input_rank = symbolic_helper._get_tensor_rank(input)
    if input_rank is None:
        return symbolic_helper._unimplemented("group_norm", "unknown input rank", input)

    # 0 in the shape list keeps dimension value unchanged.
    shape = [0, num_groups, -1]
    input_reshaped = symbolic_helper._reshape_helper(
        g, input, g.op("Constant", value_t=torch.LongTensor(shape))
    )
    input_dtype = _type_utils.JitScalarType.from_value(input).dtype()

    # C is always divisible by num_groups.
    # Due to the shape difference, weight and bias are applied after the
    # instance norm computation and reshape; the norm itself is un-scaled.
    weight_ = g.op(
        "Constant",
        value_t=torch.tensor([1.0] * num_groups, dtype=input_dtype),
    )
    bias_ = g.op(
        "Constant",
        value_t=torch.tensor([0.0] * num_groups, dtype=input_dtype),
    )
    norm_reshaped = g.op(
        "InstanceNormalization", input_reshaped, weight_, bias_, epsilon_f=eps
    )
    norm = symbolic_helper._reshape_helper(g, norm_reshaped, g.op("Shape", input))

    if weight is None or weight.node().mustBeNone():
        weight = g.op("Constant", value_t=torch.tensor([1.0], dtype=input_dtype))
    if bias is None or bias.node().mustBeNone():
        bias = g.op("Constant", value_t=torch.tensor([0.0], dtype=input_dtype))

    # Norm has shape [N, C, *] so we reshape weight and bias to [C, *]
    axes = list(range(1, input_rank - 1))
    return add(
        g,
        mul(g, norm, symbolic_helper._unsqueeze_helper(g, weight, axes)),
        symbolic_helper._unsqueeze_helper(g, bias, axes),
    )
@_onnx_symbolic("aten::_weight_norm")
@symbolic_helper.parse_args("v", "v", "i")
def _weight_norm(g: jit_utils.GraphContext, weight_v, weight_g, dim):
    """Export ``aten::_weight_norm``: ``W = g * (v / ||v||)``.

    The L2 norm (norm_except_dim) is taken over every axis of ``v`` except
    ``dim``. The trailing ``1`` passed to ``norm`` is presumably keepdim;
    confirm. ``dim`` of None means all axes. torch's weight_norm module sets
    dim = -1 when it is None, which conflicts with -1 as a negative axis, so
    -1 is also treated as "all axes".
    TODO: Might need a fix in torch group_norm module

    Raises:
        errors.SymbolicValueError: if the rank of ``weight_v`` is unknown.
    """
    rank = symbolic_helper._get_tensor_rank(weight_v)
    if rank is None:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of _weight_norm for tensor of unknown rank.",
            weight_v,
        )
    axes = list(range(rank))
    if dim is not None:
        if dim < -1:
            dim += rank
        if dim != -1:
            axes.remove(dim)
    norm_v = norm(g, weight_v, 2, axes, 1)
    direction = g.op("Div", weight_v, norm_v)
    return g.op("Mul", direction, weight_g)
@_onnx_symbolic("aten::dim")
def dim(g: jit_utils.GraphContext, self):
    """Return the rank of ``self`` as a graph value.

    Opset 9 has no direct rank operator, so the rank is computed as the
    element count of the shape vector: ``Size(Shape(self))``.
    """
    shape_vector = g.op("Shape", self)
    rank = g.op("Size", shape_vector)
    return rank
@_onnx_symbolic("aten::__contains_")
def __contains_(g: jit_utils.GraphContext, self, element):
    """Constant-fold ``element in self`` for a list of constants.

    The membership test is evaluated at export time and emitted as a boolean
    ``Constant``.

    Raises:
        SymbolicValueError: if any list item or ``element`` is not a constant.
    """
    items = symbolic_helper._unpack_list(self)
    is_foldable = all(
        symbolic_helper._is_constant(item) for item in items
    ) and symbolic_helper._is_constant(element)
    if not is_foldable:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of __contains__ for non-constant list or element.",
            self,
        )
    needle = symbolic_helper._node_get(element.node(), "value")
    haystack = (symbolic_helper._node_get(item.node(), "value") for item in items)
    return g.op("Constant", value_t=torch.tensor(needle in haystack))
@_onnx_symbolic("aten::__getitem_")
def __getitem_(g: jit_utils.GraphContext, self, i):
    """Pick element ``i`` of ``self`` along dimension 0 via ``select``."""
    leading_dim = g.op("Constant", value_t=torch.tensor([0]))
    return select(g, self, leading_dim, i)
@_onnx_symbolic("aten::item")
def item(g: jit_utils.GraphContext, self):
    """Export ``aten::item`` as a pass-through: ``self`` is forwarded unchanged."""
    scalar_value = self
    return scalar_value
@_onnx_symbolic("aten::take")
def take(g: jit_utils.GraphContext, self, index):
    """Index ``self`` as if it were 1-D and shape the result like ``index``.

    ``self`` is flattened with a ``[-1]`` reshape, elements are picked with
    ``index_select`` along axis 0, and the result is reshaped to ``index``.
    """
    flat_shape = g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))
    flat_self = symbolic_helper._reshape_helper(g, self, flat_shape)
    picked = index_select(g, flat_self, 0, index)
    return reshape_as(g, picked, index)
def _kl_div_log_target_impl(g: jit_utils.GraphContext, input, target):
    """Pointwise KL divergence when ``target`` holds log-probabilities.

    Computes ``exp(target) * (target - input)``.
    """
    log_ratio = sub(g, target, input)
    return mul(g, exp(g, target), log_ratio)
def _kl_div_non_log_target_impl(g: jit_utils.GraphContext, input, target):
    """Pointwise KL divergence when ``target`` holds probabilities.

    Computes ``target * (log(target) - input)`` where ``target > 0`` and ``0``
    everywhere else, so entries with non-positive targets (where ``log`` is
    not finite) contribute nothing.
    """
    log_target = log(g, target)
    positive_branch = mul(g, target, sub(g, log_target, input))
    zeros = zeros_like(g, positive_branch)
    is_positive = gt(g, target, g.op("Constant", value_t=torch.tensor(0)))
    return where(g, is_positive, positive_branch, zeros)
@_onnx_symbolic("aten::kl_div")
@symbolic_helper.parse_args("v", "v", "i", "b")
def kl_div(g: jit_utils.GraphContext, input, target, reduction, log_target):
    """Export ``aten::kl_div``.

    Args:
        log_target: whether ``target`` is given as log-probabilities.
        reduction: 0 = none, 1 = mean, 2 = sum; other values are unsupported.
    """
    pointwise = (
        _kl_div_log_target_impl if log_target else _kl_div_non_log_target_impl
    )
    output = pointwise(g, input, target)
    if reduction == 0:
        return output
    if reduction == 1:
        return g.op("ReduceMean", output, keepdims_i=0)
    if reduction == 2:
        return symbolic_helper._reducesum_helper(g, output, keepdims_i=0)
    return symbolic_helper._onnx_unsupported(
        "kl_div with reduction other than none, mean, or sum.", input
    )
@_onnx_symbolic("aten::mse_loss")
@symbolic_helper.parse_args("v", "v", "i")
def mse_loss(g: jit_utils.GraphContext, input, target, reduction):
    """Export ``aten::mse_loss``: the squared error ``(input - target) ** 2``.

    Args:
        reduction: 0 = none (elementwise), 1 = mean, 2 = sum; any other
            value is reported as unsupported.
    """
    # Emit the difference once and square it, instead of building two
    # identical Sub nodes.
    diff = sub(g, input, target)
    output = mul(g, diff, diff)
    if reduction == 0:
        return output
    if reduction == 1:
        return g.op("ReduceMean", output, keepdims_i=0)
    if reduction == 2:
        return symbolic_helper._reducesum_helper(g, output, keepdims_i=0)
    return symbolic_helper._onnx_unsupported(
        "mse_loss with reduction other than none, mean, or sum.", input
    )
@_onnx_symbolic("aten::as_strided")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "v", "is", "i")
def as_strided(g: jit_utils.GraphContext, self, sizes, strides, offset=None):
    """Export ``aten::as_strided`` as a ``Gather`` on the flattened input.

    Output element ``(i_0, ..., i_{k-1})`` reads flat position
    ``offset + sum_j i_j * strides[j]`` of ``self``. If ``sizes`` is constant,
    the index tensor is precomputed with torch. Otherwise it is assembled in
    the graph from per-axis ``arange`` ranges that broadcast together.
    """
    sizes = symbolic_helper._maybe_get_const(sizes, "is")
    rank = len(strides)
    self_1d = symbolic_helper._reshape_helper(
        g, self, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))
    )
    if not symbolic_helper._is_value(sizes):
        # Static sizes: compute the flat indices eagerly.
        const_ind = torch.tensor([0], dtype=torch.long)
        for i, (size, stride) in enumerate(zip(sizes, strides)):
            # Lay axis i's range along axis i only, so the running sum
            # broadcasts to the full output shape.
            r_size = [1] * rank
            r_size[i] = -1
            const_ind = const_ind + torch.arange(size).view(r_size) * stride
        if offset:
            const_ind = const_ind + offset
        return g.op("Gather", self_1d, g.op("Constant", value_t=const_ind))

    # Dynamic sizes: build the same index arithmetic out of graph ops.
    ind: _C.Value | None = None
    for i, stride in enumerate(strides):
        r_size = [1] * rank
        r_size[i] = -1
        size = select(
            g,
            sizes,
            g.op("Constant", value_t=torch.tensor([0])),
            g.op("Constant", value_t=torch.tensor(i)),
        )
        tmp_ind = symbolic_helper._reshape_helper(
            g,
            arange(g, size, 4, None, None, None),
            g.op("Constant", value_t=torch.tensor(r_size)),
        )
        tmp_ind = g.op(
            "Mul", tmp_ind, g.op("Constant", value_t=torch.tensor([stride]))
        )
        ind = tmp_ind if ind is None else g.op("Add", ind, tmp_ind)
    if offset:
        # The tensor must go in the `value_t` attribute. Passed positionally,
        # it would become an *input* of the Constant node, which is invalid ONNX.
        ind = g.op("Add", ind, g.op("Constant", value_t=torch.tensor([offset])))
    return g.op("Gather", self_1d, ind)
@_onnx_symbolic("aten::__derive_index")
def __derive_index(g: jit_utils.GraphContext, index, start, step):
    """Compute ``start + index * step`` in the graph."""
    offset_from_start = g.op("Mul", index, step)
    return g.op("Add", start, offset_from_start)
@_onnx_symbolic("aten::__range_length")
def __range_length(g: jit_utils.GraphContext, lo, hi, step):
    """Number of elements in ``range(lo, hi, step)``.

    Mirrors the ATen op in
    pytorch/torch/csrc/jit/runtime/register_prim_ops.cpp::

        if (step > 0 && lo < hi) {
          push(stack, 1 + (hi - 1 - lo) / step);
        } else if (step < 0 && lo > hi) {
          push(stack, 1 + (lo - 1 - hi) / (0 - step));
        } else {
          push(stack, 0);
        }

    For integers, both non-empty cases equal ``ceil((hi - lo) / step)``. In
    the empty case that quotient is ``<= 0``. The length is therefore
    ``max(0, ceil((hi - lo) / step))``, cast to int64.
    """
    span = g.op("Sub", hi, lo)
    step_count = g.op("Ceil", true_divide(g, span, step))
    # Clamp empty ranges (e.g. lo > hi with a positive step) to 0 instead of
    # returning a negative length, matching the ATen `else` branch. Ceil only
    # accepts floating inputs, so its output is floating and valid for Relu,
    # which computes max(x, 0).
    step_count = g.op("Relu", step_count)
    return g.op("Cast", step_count, to_i=_C_onnx.TensorProtoDataType.INT64)
@_onnx_symbolic("aten::linear")
def linear(g: jit_utils.GraphContext, input, weight, bias):
    """Export ``aten::linear``: ``input @ weight.T (+ bias)``.

    A rank-2 input with a bias goes through ``addmm`` with alpha = beta = 1.
    Every other case uses ``matmul`` followed by an optional ``add``.
    """
    input_rank = symbolic_helper._get_tensor_rank(input)
    weight_t = t(g, weight)
    has_bias = not bias.node().mustBeNone()
    if input_rank == 2 and has_bias:
        alpha = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
        beta = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
        return addmm(g, bias, input, weight_t, alpha, beta)
    product = matmul(g, input, weight_t)
    return add(g, bias, product) if has_bias else product
@_onnx_symbolic("aten::hann_window")
@symbolic_helper.parse_args("v", "b", "i", "v", "v", "v", "v")
def hann_window(
    g: jit_utils.GraphContext,
    window_length,
    periodic=True,
    dtype: int | None = None,
    layout=None,
    device=None,
    pin_memory=None,
    requires_grad=False,
):
    """Export ``aten::hann_window`` as ``sin(pi * n / N) ** 2``.

    ``n`` runs over ``arange(window_length)``. ``N`` is ``window_length`` when
    ``periodic`` is true and ``window_length - 1`` otherwise. The result is
    cast to ``dtype``. If ``dtype`` is omitted, the default torch dtype is
    used, or float32 when the default is not floating point. ``layout``,
    ``device``, ``pin_memory`` and ``requires_grad`` are ignored.
    """
    if dtype is not None:
        scalar_type = _type_utils.JitScalarType(dtype)
    else:
        default_dtype = torch.get_default_dtype()
        if not default_dtype or not default_dtype.is_floating_point:
            default_dtype = torch.float
        scalar_type = _type_utils.JitScalarType.from_dtype(default_dtype)

    positions = g.op(
        "Cast",
        arange(g, window_length, 4, None, None, None),
        to_i=_C_onnx.TensorProtoDataType.FLOAT,
    )
    pi = g.op("Constant", value_t=torch.tensor(math.pi, dtype=torch.float))
    phase = mul(g, pi, positions)
    denominator = window_length
    if periodic is False:
        one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int))
        denominator = sub(g, window_length, one)
    phase = div(g, phase, denominator)
    return g.op("Cast", square(g, sin(g, phase)), to_i=scalar_type.onnx_type())
@_onnx_symbolic("aten::mv")
def mv(g: jit_utils.GraphContext, self, vec):
    """Matrix-vector product, delegated to the ``matmul`` symbolic."""
    product = matmul(g, self, vec)
    return product
@_onnx_symbolic("aten::dot")
def dot(g: jit_utils.GraphContext, self, other):
    """Vector dot product, delegated to the ``matmul`` symbolic."""
    product = matmul(g, self, other)
    return product
@_onnx_symbolic("aten::movedim")
@symbolic_helper.parse_args("v", "t", "t")
def movedim(g: jit_utils.GraphContext, self, source, destination):
    """Export ``aten::movedim`` as a single ONNX ``Transpose``.

    Port of aten/src/ATen/native/TensorShape.cpp::movedim. Each listed source
    dim is placed at its destination. The remaining dims fill the free
    destination slots in their original relative order. Negative dims are
    wrapped by the tensor rank, as torch does, because ONNX ``Transpose``
    needs non-negative ``perm`` entries.
    """
    source = source.view(-1)
    destination = destination.view(-1)
    assert source.size() == destination.size()
    if (source == destination).all():
        return self

    self_rank = symbolic_helper._get_tensor_rank(self)
    assert self_rank is not None

    perm = list(range(self_rank))
    src_dims = perm.copy()
    dst_dims = perm.copy()
    for src, dst in zip(source.tolist(), destination.tolist()):
        # Wrap negative dims so perm only holds axes in [0, rank).
        if src < 0:
            src += self_rank
        if dst < 0:
            dst += self_rank
        perm[dst] = src
        # -1 marks dims already consumed by an explicit move.
        src_dims[src] = -1
        dst_dims[dst] = -1

    src_rest = [d for d in src_dims if d != -1]
    dst_rest = [d for d in dst_dims if d != -1]
    for src, dst in zip(src_rest, dst_rest):
        perm[dst] = src

    return g.op("Transpose", self, perm_i=perm)
@_onnx_symbolic("aten::fill")
@symbolic_helper.parse_args("v", "v")
def fill(g: jit_utils.GraphContext, self, value):
    """Export ``aten::fill`` as ``full_like(self, value)``.

    The result keeps ``self``'s scalar type, falling back to FLOAT when that
    type is unknown.
    """
    return full_like(
        g,
        self,
        value,
        _type_utils.JitScalarType.from_value(self, _type_utils.JitScalarType.FLOAT),
    )
@_onnx_symbolic("aten::index_add")
def index_add(g: jit_utils.GraphContext, self, dim, index, other, alpha=None):
    """Export ``aten::index_add`` as a ``scatter_add`` along ``dim``.

    ``other`` is padded with trailing singleton axes up to ``self``'s rank.
    It is then expanded to ``self``'s shape, with ``dim`` sliced to length 1.
    ``index`` is unsqueezed so its entries run along ``dim``, then expanded to
    ``other``'s shape, so ``scatter_add`` gets one index per element of
    ``other``.

    Limitations: duplicated values in ``index`` produce an incorrect model,
    and ``alpha`` other than 1 is unimplemented.

    Raises:
        SymbolicValueError: if ``dim`` is not a constant, if the rank of
            ``self`` or ``other`` is unknown, or if ``other`` is larger than
            ``self`` along ``dim``.
    """
    warnings.warn(
        "Warning: ONNX export does not support duplicated values in 'index' field, "
        + "this will cause the ONNX model to be incorrect."
    )

    # ONNX does not support "alpha" argument, unlike aten index_add
    # See: https://github.com/pytorch/pytorch/pull/65993#issuecomment-953151102 for more context
    if alpha and symbolic_helper._scalar(symbolic_helper._maybe_get_scalar(alpha)) != 1:
        return symbolic_helper._unimplemented("index_add", "alpha != 1", self)

    dim = symbolic_helper._maybe_get_const(dim, "i")
    # _maybe_get_const returns the graph Value itself when `dim` is not a
    # constant; treat that the same as an unknown dim.
    if dim is None or symbolic_helper._is_value(dim):
        raise errors.SymbolicValueError(
            "ONNX export does NOT support exporting 'index_add_()' function with "
            "unknown 'dim' value.",
            self,
        )

    self_dim_rank = symbolic_helper._get_tensor_rank(self)
    other_dim_rank = symbolic_helper._get_tensor_rank(other)

    if self_dim_rank is None or other_dim_rank is None:
        raise errors.SymbolicValueError(
            "ONNX export does NOT support exporting 'index_add_()' function while "
            "the rank of self tensor or tensor to be added is unknown.",
            self,
        )

    # Wrap a negative dim, as torch.index_add allows. All axis arithmetic
    # below assumes 0 <= dim < rank.
    if dim < 0:
        dim += self_dim_rank

    # Pad `other` with trailing singleton axes up to self's rank. The axis is
    # derived from the known starting rank rather than re-reading the rank of
    # each freshly created Unsqueeze output, which may not be inferred yet.
    for i in range(self_dim_rank - other_dim_rank):
        other = symbolic_helper._unsqueeze_helper(g, other, [other_dim_rank + i])

    other_dim_size = symbolic_helper._get_tensor_dim_size(other, dim)
    self_dim_size = symbolic_helper._get_tensor_dim_size(self, dim)

    if (other_dim_size is not None) and (self_dim_size is not None):
        if other_dim_size > self_dim_size:
            raise errors.SymbolicValueError(
                "ONNX export does not support exporting 'index_add_()' function with "
                "duplicated values in 'index' parameter yet.",
                self,
            )

    # Construct a new shape. It's almost as same as self except the size of the 'dim'
    # dimension is 1, so that we can expand other dimensions as expected.
    new_shape_axes = list(range(self_dim_rank))
    new_shape_starts = [0] * self_dim_rank
    new_shape_ends = [sys.maxsize if (i != dim) else 1 for i in range(self_dim_rank)]

    new_shape = symbolic_helper._slice_helper(
        g, self, axes=new_shape_axes, starts=new_shape_starts, ends=new_shape_ends
    )
    other = expand_as(g, other, new_shape)

    # Place index's entries on axis `dim`: `dim` leading singleton axes, then
    # trailing singleton axes up to self's rank.
    for _ in range(dim):
        index = symbolic_helper._unsqueeze_helper(g, index, [0])

    for _ in range(self_dim_rank - dim - 1):
        index = symbolic_helper._unsqueeze_helper(
            g, index, [symbolic_helper._get_tensor_rank(index)]
        )

    return scatter_add(g, self, dim, expand_as(g, index, other), other)
@_onnx_symbolic("aten::roll")
@symbolic_helper.parse_args("v", "is", "is")
def roll(g: jit_utils.GraphContext, self, shifts, dims):
    """Export ``aten::roll`` by slicing and concatenating along each dim.

    Rolling by ``k`` along an axis is ``concat(x[-k:], x[:-k])``. Plain
    slicing clamps out-of-range bounds, so a shift with ``|k| >= length``
    would silently leave the axis unchanged. When the axis length is
    statically known, ``k`` is therefore reduced modulo that length so it
    wraps around as in ``torch.roll``.
    """
    assert len(shifts) == len(dims)
    result = self
    for shift, axis in zip(shifts, dims):
        # Rolling never changes shape, so sizes can be read from `self`.
        dim_size = symbolic_helper._get_tensor_dim_size(self, axis)
        if dim_size:
            shift %= dim_size
        tail = symbolic_helper._slice_helper(
            g, result, axes=[axis], starts=[-shift], ends=[sys.maxsize]
        )
        head = symbolic_helper._slice_helper(
            g, result, axes=[axis], starts=[0], ends=[-shift]
        )
        result = g.op("Concat", tail, head, axis_i=axis)
    return result
@_onnx_symbolic("aten::cross")
@symbolic_helper.parse_args("v", "v", "i")
def cross(g: jit_utils.GraphContext, input, other, dim=None):
    """Export ``aten::cross`` using rolls along ``dim``.

    For A = [a, b, c] and B = [d, e, f] the cross product
    [b*f - c*e, c*d - a*f, a*e - b*d] equals
    roll(A, 2) * roll(B, 1) - roll(A, 1) * roll(B, 2).
    """
    axis = symbolic_helper._get_dim_for_cross(input, dim)
    # roll(A, 2) = [b, c, a], roll(B, 1) = [f, d, e] -> (b*f, c*d, a*e)
    a_by_2 = roll(g, input, [2], [axis])
    b_by_1 = roll(g, other, [1], [axis])
    # roll(A, 1) = [c, a, b], roll(B, 2) = [e, f, d] -> (c*e, a*f, b*d)
    a_by_1 = roll(g, input, [1], [axis])
    b_by_2 = roll(g, other, [2], [axis])
    return sub(g, mul(g, a_by_2, b_by_1), mul(g, a_by_1, b_by_2))
@_onnx_symbolic("aten::cdist")
def cdist(
    g: jit_utils.GraphContext,
    x1,
    x2,
    p=2.0,
    compute_mode="use_mm_for_euclid_dist_if_necessary",
):
    """Export ``aten::cdist`` via broadcasting and ``pairwise_distance``.

    With ``x1`` of shape (B, P, D) and ``x2`` of shape (B, R, D), ``x1`` gets a
    new axis at ``rank - 1`` and ``x2`` one at ``rank - 2``. The two then
    broadcast numpy-style against each other (see
    https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md).
    ``compute_mode`` is ignored.
    """
    rank = symbolic_helper._get_tensor_rank(x1)
    assert rank is not None
    x1_expanded = symbolic_helper._unsqueeze_helper(g, x1, [rank - 1])
    x2_expanded = symbolic_helper._unsqueeze_helper(g, x2, [rank - 2])
    return pairwise_distance(
        g, x1_expanded, x2_expanded, p, eps=1e-06, keepdim=False
    )
@_onnx_symbolic("aten::lerp")
def lerp(g: jit_utils.GraphContext, self, end, weight):
    """Export ``aten::lerp``: ``self + weight * (end - self)``.

    For better numerics (see https://github.com/pytorch/pytorch/pull/18871),
    the formula is anchored at ``self`` when ``weight < 0.5``. Otherwise it is
    anchored at ``end`` as ``end - (end - self) * (1 - weight)``.
    """
    diff = g.op("Sub", end, self)
    half = g.op("Constant", value_t=torch.tensor(0.5))
    weight_is_small = g.op("Less", weight, half)
    from_start = g.op("Add", self, g.op("Mul", weight, diff))
    one = g.op("Constant", value_t=torch.tensor(1.0))
    one_minus_weight = g.op("Sub", one, weight)
    from_end = g.op("Sub", end, g.op("Mul", diff, one_minus_weight))
    return where(g, weight_is_small, from_start, from_end)
@_onnx_symbolic("aten::broadcast_tensors")
def broadcast_tensors(g: jit_utils.GraphContext, self):
    """Expand every tensor in the list ``self`` to their common broadcast shape.

    The common shape is obtained by adding all tensors onto a zeros tensor.
    ``add`` broadcasts multidirectionally, so the sum carries the final shape.
    """
    tensors = symbolic_helper._unpack_list(self)
    shape_carrier = zeros_like(g, tensors[0])
    for tensor in tensors:
        shape_carrier = add(g, shape_carrier, tensor)
    expanded = [expand_as(g, tensor, shape_carrier) for tensor in tensors]
    return g.op("prim::ListConstruct", *expanded)
@_onnx_symbolic("aten::is_pinned")
def is_pinned(g: jit_utils.GraphContext, self, device=None):
    """Export ``aten::is_pinned``.

    Pinned memory is unused by ONNX, so no value is produced (returns ``None``).
    """
    return None
@_onnx_symbolic("prim::ConstantSplit")
def prim_constant_split(g: jit_utils.GraphContext, self, split_size, dim):
    """Split ``self`` along ``dim`` into pieces of ``split_size``.

    A final shorter piece holds the remainder when the dim size is not a
    multiple of ``split_size``. The dim size must be statically known.
    """
    size = symbolic_helper._get_tensor_dim_size(self, dim)
    if size is None:
        return symbolic_helper._unimplemented(
            "prim::ConstantSplit", "unknown dimension size", self
        )
    full_pieces, remainder = divmod(size, split_size)
    piece_sizes = [split_size] * full_pieces
    if remainder:
        piece_sizes.append(remainder)
    return g.op(
        "Split", self, split_i=piece_sizes, axis_i=dim, outputs=len(piece_sizes)
    )
# TODO: It would be better to export this as a chunk directly, as this is
# less sensitive to changes in input size.
# TODO: Once we have proper scoping, stop reimplementing chunk, delete this
# method, and use the desugared version
@_onnx_symbolic("prim::ConstantChunk")
def prim_constant_chunk(g: jit_utils.GraphContext, self, chunks, dim):
    """Split ``self`` along ``dim`` into ``chunks`` pieces.

    Each piece gets ``ceil(dim_size / chunks)`` elements, and the work is
    delegated to ``prim_constant_split``. The dim size must be statically
    known.
    """
    dim_size = symbolic_helper._get_tensor_dim_size(self, dim)
    if dim_size is None:
        return symbolic_helper._unimplemented(
            "prim::ConstantChunk", "unknown dimension size", self
        )
    piece_size = (dim_size + chunks - 1) // chunks
    return prim_constant_split(g, self, piece_size, dim)
@_onnx_symbolic("prim::shape")
def prim_shape(g: jit_utils.GraphContext, self):
    """Export ``prim::shape`` as ONNX ``Shape``."""
    shape = g.op("Shape", self)
    return shape
@_onnx_symbolic("prim::max")
def prim_max(g: jit_utils.GraphContext, self, other):
    """Elementwise maximum of ``self`` and ``other`` via ONNX ``Max``.

    Delegates to ``_op_with_optional_float_cast`` with ``opset_before=12``.
    NOTE(review): presumably this routes inputs through float for opsets
    below 12; confirm in symbolic_helper.
    """
    max_op = "Max"
    return symbolic_helper._op_with_optional_float_cast(
        g, max_op, self, other, opset_before=12
    )
@_onnx_symbolic("prim::min")
def prim_min(g: jit_utils.GraphContext, self, other=None):
    """Export ``prim::min``.

    With ``other`` this is the binary minimum. Without it, ``self`` is
    reduced. A packed list is first stacked along dim 0 into a single tensor.
    ``min`` below is presumably this module's ``aten::min`` symbolic rather
    than the builtin (it receives ``g``).
    """
    if other:
        return min(g, self, other)
    if symbolic_helper._is_packed_list(self):
        self = stack(g, self, g.op("Constant", value_t=torch.tensor([0])))
    return min(g, self)
@_onnx_symbolic("prim::data")
def prim_data(g: jit_utils.GraphContext, self):
    """Export ``prim::data`` as a pass-through: ``self`` is returned as-is."""
    data = self
    return data
@_onnx_symbolic("prim::layout")
def prim_layout(g: jit_utils.GraphContext, self):
    """Report the layout of ``self`` as the constant 0 (``torch.strided``).

    JIT ``TensorType`` supports no other layout. The enum is defined in
    ``c10/core/Layout.h``.
    """
    strided_code = torch.tensor(0)
    return g.op("Constant", value_t=strided_code)
@_onnx_symbolic("prim::ListConstruct")
def prim_list_construct(g: jit_utils.GraphContext, *inputs, **kwargs):
    """Symbolic for ``prim::ListConstruct``; it produces nothing (``None``)."""
    return None
@_onnx_symbolic("prim::ListUnpack")
def prim_list_unpack(
    g: jit_utils.GraphContext, *inputs, **kwargs
) -> list[_C.Value] | None:
    """Cancel ``ListUnpack(ListConstruct(xs))`` by returning ``xs`` directly.

    Returns ``None`` unless there is exactly one input and it is produced by
    a ``prim::ListConstruct`` node.
    """
    if len(inputs) != 1:
        return None
    (packed,) = inputs
    if packed.node().kind() != "prim::ListConstruct":
        return None
    # TODO(justinchuby): Use a public method in the helper module
    return symbolic_helper._unpack_list(packed)
@_onnx_symbolic("prim::TupleConstruct")
def prim_tuple_construct(g: jit_utils.GraphContext, *inputs, **kwargs):
    """Symbolic for ``prim::TupleConstruct``; it produces nothing (``None``)."""
    return None
@_onnx_symbolic("prim::Uninitialized")
def prim_uninitialized(g: jit_utils.GraphContext, *inputs, **kwargs):
    """Symbolic for ``prim::Uninitialized``; it produces nothing (``None``)."""
    return None
@_onnx_symbolic("prim::unchecked_cast")
def prim_unchecked_cast(g: jit_utils.GraphContext, self):
    """Export ``prim::unchecked_cast`` as a no-op.

    The op only refines a Value's type (e.g. Optional[Tensor] -> Tensor) so
    that the rest of the graph knows it is a Tensor. It does nothing at
    runtime, so ``self`` is forwarded unchanged.
    """
    refined = self
    return refined
@_onnx_symbolic("prim::dtype")
def prim_dtype(g: jit_utils.GraphContext, self):
    """Emit ``self``'s scalar type as an integer ``Constant``.

    The dtype is recorded as its integer code. FLOAT is used when the type
    of ``self`` is unknown.
    """
    detected = symbolic_helper._try_get_scalar_type(self)
    # Explicit None test: the enum's 0-valued member must not fall back to FLOAT.
    dtype_code = _type_utils.JitScalarType.FLOAT if detected is None else detected
    return g.op("Constant", value_t=torch.tensor(dtype_code))
@_onnx_symbolic("prim::tolist")
def prim_tolist(g: jit_utils.GraphContext, input, dim_val, elem_ty_val):
    """Export ``prim::tolist`` as a no-op; only 1-D inputs are supported.

    ``dim_val`` and ``elem_ty_val`` are the dimension and element-type
    annotations, which must match the input tensor. Only ``dim_val`` is
    checked here; values above 1 are reported as unimplemented.
    """
    if symbolic_helper._maybe_get_const(dim_val, "i") > 1:
        return symbolic_helper._unimplemented("prim::tolist", "dim_val > 1", input)
    return input
# -----------------------------------------------------------------------------
# Symbolic functions that need extra context
# -----------------------------------------------------------------------------
@_onnx_symbolic("prim::device")
def prim_device(g: jit_utils.GraphContext, *inputs, **kwargs) -> None:
    """Accept ``prim::device`` when its output is a Device object (returns None).

    Any other output type is reported through ``_unimplemented``.
    """
    device_out = g.original_node.output()
    out_type = device_out.type()
    if not isinstance(out_type, _C.DeviceObjType):
        return symbolic_helper._unimplemented(
            "prim::device",
            f"output type should be 'DeviceObjType', not '{out_type.kind()}'",
            device_out,
        )
    return None
@_onnx_symbolic("prim::Loop")
def prim_loop(g: jit_utils.GraphContext, *inputs, **attrs) -> list[_C.Value]:
    """Convert a TorchScript ``prim::Loop`` into an ONNX ``Loop`` node.

    A new ``Loop`` node is created with the same number of outputs and blocks
    as the original node. Each original sub-block first has its input types
    refreshed from the node inputs, then is converted into the matching new
    block. Finally the control-flow fixup pass runs on the new node, followed
    by shape/type inference when ``GLOBALS.onnx_shape_inference`` is set.

    Returns:
        The values returned by ``_jit_pass_fixup_onnx_controlflow_node``.
    """
    node = g.original_node
    env = g.env
    values_in_env = g.values_in_env
    params_dict = g.params_dict
    operator_export_type = GLOBALS.operator_export_type
    opset_version = GLOBALS.export_onnx_opset_version
    old_blocks = tuple(node.blocks())
    # Create the ONNX Loop with empty blocks to be filled from the old ones.
    new_op_outputs, new_block_contexts, new_node = jit_utils.add_op_with_blocks(
        g, "Loop", *inputs, outputs=node.outputsSize(), n_blocks=len(old_blocks)
    )
    for old_block, new_block_context in zip(old_blocks, new_block_contexts):
        # Copy input metadata to subblock
        #
        #   prim::Loop(iter, cond, input_1, ..., input_n)
        #     block0(iter, input_1, ..., input_n)
        #
        # For `Loop` node, copy metadata for `iter`, `input_1`, ..., `input_n`.
        for i, b_in in enumerate(old_block.inputs()):
            # Block input 0 (`iter`) lines up with node input 0.
            if i == 0 and i < len(inputs):
                b_in.setType(inputs[i].type())
            # Block input i > 0 lines up with node input i + 1, because node
            # input 1 (`cond`) has no block counterpart.
            # For optional block inputs, they may switch between None not-None inside
            # the loop body, so if the loop input is not optional, the block input may
            # still need to be optional.
            if (
                i > 0
                and (i + 1) < len(inputs)
                and not isinstance(b_in.type(), _C.OptionalType)
            ):
                b_in.setType(inputs[i + 1].type())
        # Convert the TorchScript block into the new ONNX block.
        # NOTE(review): the last argument is False here but True in prim_if's
        # static path, where the current block is the target; presumably it
        # controls inlining into the enclosing block. Confirm in the C++ pass.
        torch._C._jit_pass_onnx_block(
            old_block,
            new_block_context.block,
            operator_export_type,
            env,
            values_in_env,
            False,
        )
    fixed_outputs = torch._C._jit_pass_fixup_onnx_controlflow_node(
        new_node, opset_version
    )
    # Run shape type inference for Loop after subblock is converted.
    if GLOBALS.onnx_shape_inference:
        torch._C._jit_pass_onnx_node_shape_type_inference(
            new_node, params_dict, opset_version
        )
    return fixed_outputs
@_onnx_symbolic("prim::If")
def prim_if(g: jit_utils.GraphContext, *inputs, **attrs) -> list[_C.Value]:
    """Convert a TorchScript ``prim::If`` into ONNX.

    If the condition ``inputs[0]`` is produced by an ``onnx::Constant``, the
    branch is resolved at export time. Only the taken block is converted,
    directly into the current block, and its mapped outputs are returned.

    Otherwise an ONNX ``If`` node is emitted with both blocks converted. It is
    fixed up by ``_jit_pass_fixup_onnx_controlflow_node`` and, when
    ``GLOBALS.onnx_shape_inference`` is set, run through shape/type inference.

    Raises:
        SymbolicValueError: in the static case, if an output of the taken
            block is missing from ``env`` after conversion.
    """
    n = g.original_node
    block = g.block
    env = g.env
    values_in_env = g.values_in_env
    params_dict = g.params_dict
    operator_export_type = GLOBALS.operator_export_type
    opset_version = GLOBALS.export_onnx_opset_version
    static_if = inputs[0].node().kind() == "onnx::Constant"
    if static_if:
        # Fold static if
        #
        # The torch IR
        # graph(%embedding_matrix.1 : Float(10, 15, strides=[15, 1], requires_grad=0, device=cpu),
        #    %input.1 : Long(6, strides=[1], requires_grad=0, device=cpu), ...
        # %65 : Bool(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        # %21 : Long(device=cpu) = aten::eq(%20, %64)
        # %22 : Long(device=cpu) = prim::If(%21)
        #     block0():
        #     %23 : Long(device=cpu) = aten::is_floating_point(%input.1)
        #     -> (%23)
        #     block1():
        #     -> (%65)
        # %input.53 : Tensor, %weight : Tensor = prim::If(%22)
        #     block0():
        #     -> (%embedding_matrix.1, %input.1)
        #     block1():
        #     -> (%input.1, %embedding_matrix.1)
        # %26 : int[] = aten::size(%input.53)
        #
        # The converted ONNX graph
        # %10 : Bool(device=cpu) = onnx::Constant[value={0}]()
        # %14 : Bool(device=cpu) = onnx::Equal(%13, %8)
        # %15 : Bool(requires_grad=0, device=cpu) = onnx::Constant[value={0}]()
        # %16 : Long(1, strides=[1], device=cpu) = onnx::Shape(%input.1)
        input_flag = symbolic_helper._node_get(inputs[0].node(), "value").tolist()
        # A multi-element condition counts as true only if every element is true.
        const_value = (
            all(input_flag) if isinstance(input_flag, list) else bool(input_flag)
        )
        # Block 0 is taken when the condition is true, block 1 otherwise.
        block_idx = 0 if const_value else 1
        current_b = list(n.blocks())[block_idx]
        # Convert the taken block straight into the current block; the returned
        # env maps its TorchScript values to the converted ONNX values.
        env = torch._C._jit_pass_onnx_block(
            current_b,
            block,
            operator_export_type,
            env,
            values_in_env,
            True,
        )
        if_output_list = list(n.outputs())
        current_b_list = list(current_b.outputs())
        # The taken block's (converted) outputs replace the If node's outputs.
        final_b_list = []
        for idx in range(len(if_output_list)):
            if current_b_list[idx] not in env:
                raise errors.SymbolicValueError(
                    f"The sub block ATen output {current_b_list[idx]} is not in env.",
                    current_b_list[idx],
                )  # type:ignore[operator]
            onnx_b = env[current_b_list[idx]]
            final_b_list.append(onnx_b)
        return final_b_list
    else:
        # Dynamic condition: emit an ONNX If and convert every branch block.
        old_blocks = tuple(n.blocks())
        new_op_outputs, new_block_contexts, new_node = jit_utils.add_op_with_blocks(
            g, "If", *inputs, outputs=n.outputsSize(), n_blocks=len(old_blocks)
        )
        for old_block, new_block_context in zip(old_blocks, new_block_contexts):
            torch._C._jit_pass_onnx_block(
                old_block,
                new_block_context.block,
                operator_export_type,
                env,
                values_in_env,
                False,
            )
        fixed_outputs = torch._C._jit_pass_fixup_onnx_controlflow_node(
            new_node, opset_version
        )
        # Run shape type inference for If after subblock is converted.
        if GLOBALS.onnx_shape_inference:
            torch._C._jit_pass_onnx_node_shape_type_inference(
                new_node, params_dict, opset_version
            )
        return fixed_outputs
@_onnx_symbolic("prim::Constant")
def prim_constant(g: jit_utils.GraphContext, *inputs, **attrs):
    """Convert a TorchScript ``prim::Constant`` into ONNX.

    - None constants and Device constants produce no value (``None``).
    - Tensor (``t``) and string (``s``) values become an ONNX ``Constant``.
    - int / float lists become a tensor ``Constant``.
    - string lists become a ``prim::ListConstruct`` of string ``Constant``s.

    Raises:
        SymbolicValueError: for any other kind of constant.
    """
    node = g.original_node

    if node.mustBeNone():
        return None
    output_type = node.output().type()
    # The Device check must come before the string check: some device
    # constants carry string values, but they are kept as unconverted Device
    # types so that eq() can work on them.
    if isinstance(output_type, _C.DeviceObjType):
        return None

    value_kind = node.kindOf("value")
    if value_kind == "t":
        return g.op("Constant", value_t=symbolic_helper._node_get(node, "value"))
    if value_kind == "s":
        return g.op("Constant", value_s=symbolic_helper._node_get(node, "value"))

    is_numeric_list = output_type.isSubtypeOf(
        _C.ListType.ofInts()
    ) or output_type.isSubtypeOf(_C.ListType.ofFloats())
    if is_numeric_list:
        numbers = symbolic_helper._node_get(node, "value")
        return g.op("Constant", value_t=torch.tensor(numbers))

    if output_type.isSubtypeOf(_C.ListType.ofStrings()):
        string_nodes = [
            g.op("Constant", value_s=text)
            for text in symbolic_helper._node_get(node, "value")
        ]
        return g.op("prim::ListConstruct", *string_nodes)

    raise errors.SymbolicValueError(
        f"Unsupported prim::Constant kind: '{value_kind}'. "
        f"Please send a bug report at {_constants.PYTORCH_GITHUB_ISSUES_URL}.",
        node.output(),
    )
@_onnx_symbolic("prim::type")
def prim_type(g: jit_utils.GraphContext, device_value: _C.Value, *args, **kwargs):
    """Fold ``prim::type`` of a device into a string ``Constant``.

    Only works when ``device_value`` comes from a ``prim::device`` node whose
    device can be determined statically. Otherwise the op is reported as
    unimplemented.
    """
    producer = device_value.node()
    device = None
    if producer.kind() == "prim::device":
        device = jit_utils.get_device_from_value(producer.input())
    if device is None:
        return symbolic_helper._unimplemented(
            "prim::type",
            "Device type cannot be statically determined.",
            device_value,
        )
    return g.op("Constant", value_s=str(device))
@_onnx_symbolic("onnx::Placeholder")
def onnx_placeholder(g: jit_utils.GraphContext, *inputs, **attrs):
    """Convert an ``onnx::Placeholder`` node from its sub-block pattern.

    Delegates to ``torch._C._jit_onnx_convert_pattern_from_subblock`` with
    the current block, the original node and the export environment maps.
    """
    return torch._C._jit_onnx_convert_pattern_from_subblock(
        g.block, g.original_node, g.env, g.values_in_env
    )
@_onnx_symbolic("aten::resolve_conj")
@_onnx_symbolic("aten::resolve_neg")
def noop_complex_operators(g: jit_utils.GraphContext, input: _C.Value):
    """Export ``aten::resolve_conj`` / ``aten::resolve_neg`` as no-ops.

    ONNX has no operators that directly manipulate real/imaginary parts. Some
    torch APIs (e.g. ``.tolist()``) still emit these ops for real inputs,
    which would otherwise fail for lack of complex operators. Forwarding
    ``input`` unchanged is safe for them.
    """
    resolved = input
    return resolved
@_onnx_symbolic("aten::_conj")
@_onnx_symbolic("aten::conj_physical")
def unsupported_complex_operators(g: jit_utils.GraphContext, input: _C.Value):
    """Export ``aten::_conj`` / ``aten::conj_physical`` for real inputs only.

    ONNX cannot manipulate real/imaginary parts directly, but some torch APIs
    (e.g. ``.tolist()``) emit these ops even for real inputs. For real values
    they are no-ops. Complex inputs are reported as unsupported.
    """
    if not symbolic_helper.is_complex_value(input):
        # Real numbers only: conjugation is the identity.
        return noop_complex_operators(g, input)
    # FIXME(justinchuby): report correct name for symbolic being executed
    return symbolic_helper._onnx_unsupported(
        "aten::_conj, aten::conj_physical",
        input,
    )
@_onnx_symbolic("aten::logit")
def logit(g: jit_utils.GraphContext, self: torch._C.Value, eps: torch._C.Value):
    """Export ``aten::logit``: ``log(z / (1 - z))``.

    When ``eps`` is given, it is first cast to ``self``'s type. ``z`` is then
    ``self`` clamped from above to ``1 - eps`` and afterwards from below to
    ``eps``. Without ``eps``, ``z`` is ``self``.
    """
    one = g.op("Constant", value_t=torch.tensor(1.0))
    z = self
    if not symbolic_helper._is_none(eps):
        eps = g.op(
            "Cast", eps, to_i=_type_utils.JitScalarType.from_value(self).onnx_type()
        )
        upper = g.op("Sub", one, eps)
        # Upper clamp: keep self where self < 1 - eps, else use 1 - eps.
        below_upper = g.op("Greater", upper, self)
        z = g.op("Where", below_upper, self, upper)
        # Lower clamp: replace values below eps with eps.
        below_lower = g.op("Less", z, eps)
        z = g.op("Where", below_lower, eps, z)
    one_minus_z = g.op("Sub", one, z)
    odds = g.op("Div", z, one_minus_z)
    return g.op("Log", odds)