# pytorch/torch/onnx/symbolic_opset10.py
import functools
import sys
import warnings
from typing import Callable
import torch
import torch._C._onnx as _C_onnx
import torch.onnx
from torch import _C
# Monkey-patch graph manipulation methods on Graph, used for the ONNX symbolics
from torch.onnx import (
_constants,
_type_utils,
errors,
symbolic_helper,
symbolic_opset9 as opset9,
)
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import _beartype, jit_utils, registration
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md
# This file exports ONNX ops for opset 10
# Opset 10 is supported by ONNX release 1.5.0,
# released on 04/24/19
__all__ = [
"dequantize",
"div",
"embedding_bag",
"fake_quantize_per_tensor_affine",
"flip",
"fmod",
"isfinite",
"isinf",
"nan_to_num",
"quantize_per_tensor",
"quantized_add_relu",
"quantized_add",
"quantized_cat",
"quantized_conv1d_relu",
"quantized_conv2d_relu",
"quantized_conv2d",
"quantized_group_norm",
"quantized_hardswish",
"quantized_instance_norm",
"quantized_layer_norm",
"quantized_leaky_relu",
"quantized_linear",
"quantized_mul",
"quantized_sigmoid",
"slice",
"sort",
"topk",
]
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=10)
def _apply_params(*args, **kwargs):
"""Returns a decorator that calls the decorated (higher-order) function with the given parameters."""
def _apply(fn):
return fn(*args, **kwargs)
return _apply
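# For example, _apply_params("max_pool1d", torch.nn.modules.utils._single, 1,
# return_indices=False) returns a decorator that immediately calls the decorated
# factory (e.g. _max_pool below) with those arguments, so what gets registered
# is the symbolic function the factory returns.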
@_onnx_symbolic("aten::div")
@_beartype.beartype
def div(g: jit_utils.GraphContext, self, other, *args):
if len(args) == 0:
return opset9.true_divide(g, self, other)
else:
return _div_rounding_mode(g, self, other, *args)
@symbolic_helper.parse_args("v", "v", "s")
@_beartype.beartype
def _div_rounding_mode(g: jit_utils.GraphContext, self, other, rounding_mode):
if rounding_mode == "floor":
return _floor_divide(g, self, other)
else:
return opset9._div_rounding_mode(g, self, other, rounding_mode)
@_onnx_symbolic("aten::_floor_divide")
@_beartype.beartype
def _floor_divide(g: jit_utils.GraphContext, self, other):
if symbolic_helper._is_fp(self) or symbolic_helper._is_fp(other):
out = opset9.true_divide(g, self, other)
return g.op("Floor", out)
else:
# Integer division does truncation rounding
div = g.op("Div", self, other)
# The result is negative when exactly one operand is negative: (self < 0) != (other < 0)
zero = g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64))
negative = g.op("Xor", g.op("Less", self, zero), g.op("Less", other, zero))
# For negative numbers with self % other != 0, subtract 1 to round down instead of up
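# Worked example (plain arithmetic, not exporter-specific): for self = -7, other = 2,
# Div truncates -7 / 2 to -3; the signs differ and Mod(-7, 2) = 1 != 0, so 1 is
# subtracted, giving floor(-7 / 2) = -4.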
mod = g.op("Mod", self, other, fmod_i=0)
fixup_mask = g.op("And", negative, g.op("Not", g.op("Equal", mod, zero)))
one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
fixup = g.op("Sub", div, one)
return g.op("Where", fixup_mask, fixup, div)
@_onnx_symbolic("aten::sort")
@symbolic_helper.parse_args("v", "i", "i", "none")
@_beartype.beartype
def sort(g: jit_utils.GraphContext, self, dim, decending, out=None):
return symbolic_helper._sort_helper(g, self, dim, decending=decending, out=out)
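# NOTE: the misspelled "decending" is kept because it must match the keyword
# argument of symbolic_helper._sort_helper.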
@_onnx_symbolic("aten::topk")
@symbolic_helper.parse_args("v", "v", "i", "i", "i", "none")
@_beartype.beartype
def topk(g: jit_utils.GraphContext, self, k, dim, largest, sorted, out=None):
return symbolic_helper._topk_helper(
g, self, k, dim, largest=largest, sorted=sorted, out=out
)
@_onnx_symbolic(
"aten::max_pool1d",
decorate=[
_apply_params(
"max_pool1d", torch.nn.modules.utils._single, 1, return_indices=False
)
],
)
@_onnx_symbolic(
"aten::max_pool2d",
decorate=[
_apply_params(
"max_pool2d", torch.nn.modules.utils._pair, 2, return_indices=False
)
],
)
@_onnx_symbolic(
"aten::max_pool3d",
decorate=[
_apply_params(
"max_pool3d", torch.nn.modules.utils._triple, 3, return_indices=False
)
],
)
@_onnx_symbolic(
"aten::max_pool1d_with_indices",
decorate=[
_apply_params(
"max_pool1d_with_indices",
torch.nn.modules.utils._single,
1,
return_indices=True,
)
],
)
@_onnx_symbolic(
"aten::max_pool2d_with_indices",
decorate=[
_apply_params(
"max_pool2d_with_indices",
torch.nn.modules.utils._pair,
2,
return_indices=True,
)
],
)
@_onnx_symbolic(
"aten::max_pool3d_with_indices",
decorate=[
_apply_params(
"max_pool3d_with_indices",
torch.nn.modules.utils._triple,
3,
return_indices=True,
)
],
)
@_beartype.beartype
def _max_pool(name: str, tuple_fn: Callable, ndims: int, return_indices: bool):
@symbolic_helper.quantized_args(True, False, False, False, False, False)
@symbolic_helper.parse_args("v", "is", "is", "is", "is", "i")
def symbolic_fn(g, input, kernel_size, stride, padding, dilation, ceil_mode):
if not stride:
stride = kernel_size
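# ONNX expects "pads" as [d1_begin, ..., dn_begin, d1_end, ..., dn_end], so
# PyTorch's symmetric per-dimension padding is repeated below (tuple_fn(padding) * 2).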
kwargs = {
"kernel_shape_i": tuple_fn(kernel_size),
"pads_i": tuple_fn(padding) * 2,
"strides_i": tuple_fn(stride),
"ceil_mode_i": ceil_mode,
}
if set(tuple_fn(dilation)) != {1}:
kwargs["dilations_i"] = tuple_fn(dilation)
# Easy but hacky way to get flattened indices values
# so they can be converted to the non-flattened form.
# In ONNX the indices are computed as a flattened 1-D tensor,
# so the values in indices are in [0, N x C x D1 x ... x Dn).
# To convert the indices to the format used by PyTorch,
# we first run a max pool with a kernel and stride of 1 on the same input.
# This produces a tensor of indices in which each position holds its own flat index.
# Using this tensor as a reference, we extract the first index of each axis and
# subtract it from each index of that axis in the indices to convert.
# This step yields a tensor in which each dimension holds indices relative to
# the dimension it is in.
# For more information:
# https://github.com/pytorch/pytorch/pull/16455#issuecomment-460776407
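# Illustrative example (shapes chosen for exposition): for an input of shape
# (1, 2, 2, 2), the kernel-1 pool assigns each position its own flat index,
# so channel 0 holds [[0, 1], [2, 3]] and channel 1 holds [[4, 5], [6, 7]].
# The slice extracts the first flat index of each (n, c) plane (here [0] and [4]),
# and subtracting it from the real pool's indices makes them relative to the
# plane, matching PyTorch's [0, H*W) convention.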
if return_indices:
r, indices = g.op("MaxPool", input, outputs=2, **kwargs)
_, flattened_indices = g.op(
"MaxPool",
input,
outputs=2,
kernel_shape_i=[1 for _ in range(ndims)],
strides_i=[1 for _ in range(ndims)],
)
# convert indices to have non-flattened indices values
s = symbolic_helper._slice_helper(
g,
flattened_indices,
axes=[2 + i for i in range(ndims)],
starts=tuple_fn(0),
ends=tuple_fn(1),
)
indices = opset9.sub(g, indices, s)
return r, indices
else:
r = g.op("MaxPool", input, outputs=1, **kwargs)
return r
return symbolic_fn
@_onnx_symbolic(
"aten::avg_pool1d",
decorate=[_apply_params("avg_pool1d", torch.nn.modules.utils._single)],
)
@_onnx_symbolic(
"aten::avg_pool2d",
decorate=[_apply_params("avg_pool2d", torch.nn.modules.utils._pair)],
)
@_onnx_symbolic(
"aten::avg_pool3d",
decorate=[_apply_params("avg_pool3d", torch.nn.modules.utils._triple)],
)
@_beartype.beartype
def _avg_pool(name, tuple_fn):
# Although onnx::AveragePool provides count_include_pad and ceil_mode,
# PyTorch's average pooling with ceil_mode allows the sliding window to go
# out of bounds, a corner case that requires this accommodation.
# More detail at https://github.com/pytorch/pytorch/issues/57178
return opset9._avg_pool(name, tuple_fn)
@_onnx_symbolic(
"aten::upsample_nearest1d",
decorate=[_apply_params("upsample_nearest1d", 3, "nearest")],
)
@_onnx_symbolic(
"aten::upsample_nearest2d",
decorate=[_apply_params("upsample_nearest2d", 4, "nearest")],
)
@_onnx_symbolic(
"aten::upsample_nearest3d",
decorate=[_apply_params("upsample_nearest3d", 5, "nearest")],
)
@_onnx_symbolic(
"aten::upsample_linear1d",
decorate=[_apply_params("upsample_linear1d", 3, "linear")],
)
@_onnx_symbolic(
"aten::upsample_bilinear2d",
decorate=[_apply_params("upsample_bilinear2d", 4, "linear")],
)
@_onnx_symbolic(
"aten::upsample_trilinear3d",
decorate=[_apply_params("upsample_trilinear3d", 5, "linear")],
)
@_beartype.beartype
def _interpolate(name, dim, interpolate_mode):
@symbolic_helper.quantized_args(True, False, False)
@_beartype.beartype
def symbolic_fn(g, input, output_size, *args):
scales, align_corners = symbolic_helper._get_interpolate_attributes(
g, interpolate_mode, args
)
symbolic_helper._interpolate_warning(interpolate_mode)
align_corners = symbolic_helper._maybe_get_scalar(align_corners)
if align_corners:
return symbolic_helper._unimplemented(name, "align_corners == True", input)
if scales is None:
scales = symbolic_helper._interpolate_size_to_scales(
g, input, output_size, dim
)
return g.op("Resize", input, scales, mode_s=interpolate_mode)
return symbolic_fn
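# Rough example of the scales computation above: for a (1, 3, 32, 32) input
# upsampled to output_size (64, 64), _interpolate_size_to_scales yields scales
# of [1.0, 1.0, 2.0, 2.0], which Resize then applies per axis.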
@_onnx_symbolic("aten::__interpolate")
@_beartype.beartype
def __interpolate(
g: jit_utils.GraphContext,
input,
size,
scale_factor,
mode,
align_corners,
recompute_scale_factor,
antialias,
):
scales, mode = symbolic_helper._interpolate_get_scales_and_mode(
g, input, size, scale_factor, mode, align_corners
)
return g.op("Resize", input, scales, mode_s=mode)
@_beartype.beartype
def _slice(
g: jit_utils.GraphContext,
input,
axes,
starts,
ends,
steps=None,
dynamic_slice=False,
):
if dynamic_slice:
starts = symbolic_helper._unsqueeze_helper(g, starts, [0])
ends = symbolic_helper._unsqueeze_helper(g, ends, [0])
if isinstance(axes, int):
axes = g.op("Constant", value_t=torch.tensor(axes))
axes = symbolic_helper._unsqueeze_helper(g, axes, [0])
else:
assert len(starts) == len(ends)
assert len(starts) == len(axes)
assert steps is None or len(starts) == len(steps)
if (
len(starts) == 1
and starts[0] == 0
and ends[0] == _constants.INT64_MAX
and (steps is None or (len(steps) == 1 and steps[0] == 1))
):
return input
if ends[0] > _constants.INT64_MAX:
ends[0] = _constants.INT64_MAX
axes = g.op("Constant", value_t=torch.tensor(axes))
starts = g.op("Constant", value_t=torch.tensor(starts))
ends = g.op("Constant", value_t=torch.tensor(ends))
if steps is None:
return g.op("Slice", input, starts, ends, axes)
steps = g.op("Constant", value_t=torch.tensor(steps))
return g.op("Slice", input, starts, ends, axes, steps)
@_onnx_symbolic("aten::slice")
@_beartype.beartype
def slice(g: jit_utils.GraphContext, self, *args):
if len(args) == 4:
# aten::slice(Tensor self, int dim, int? start=None, int? end=None, int step=1) -> Tensor
dim, start, end, step = args
elif len(args) == 3:
# aten::slice(t[] l, int? start=None, int? end=None, int step=1) -> t[]
start, end, step = args
dim = 0
else:
raise errors.SymbolicValueError("Unknown aten::slice signature", self)
is_start_none = start.node().kind() == "prim::Constant" and isinstance(
start.type(), _C.NoneType
)
is_end_none = end.node().kind() == "prim::Constant" and isinstance(
end.type(), _C.NoneType
)
is_start_onnx_const = start.node().kind() == "onnx::Constant"
is_end_onnx_const = end.node().kind() == "onnx::Constant"
step = symbolic_helper._parse_arg(step, "i")
if (
(not is_start_none and not is_start_onnx_const)
or (not isinstance(end, int) and not is_end_none and not is_end_onnx_const)
or (not isinstance(dim, int) and dim.node().kind() != "onnx::Constant")
):
dynamic_slice = True
if is_start_none:
start = g.op("Constant", value_t=torch.tensor(0))
if is_end_none:
end = g.op("Constant", value_t=torch.tensor(_constants.INT64_MAX))
else:
start = [0 if is_start_none else symbolic_helper._parse_arg(start, "i")]
end = [
_constants.INT64_MAX
if is_end_none
else symbolic_helper._parse_arg(end, "i")
]
dim = [symbolic_helper._parse_arg(dim, "i")]
dynamic_slice = False
return symbolic_helper._slice_helper(
g,
self,
axes=dim,
starts=start,
ends=end,
steps=[step],
dynamic_slice=dynamic_slice,
)
@_onnx_symbolic("aten::flip")
@symbolic_helper.parse_args("v", "is")
@_beartype.beartype
def flip(g: jit_utils.GraphContext, input, dims):
return symbolic_helper._slice_helper(
g,
input,
axes=dims,
starts=[-1] * len(dims),
ends=[-_constants.INT64_MAX] * len(dims),
steps=[-1] * len(dims),
)
@_onnx_symbolic("aten::fmod")
@_beartype.beartype
def fmod(g: jit_utils.GraphContext, input, other):
return g.op("Mod", input, other, fmod_i=1)
@_onnx_symbolic("aten::embedding_bag")
@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i")
@_beartype.beartype
def embedding_bag(
g: jit_utils.GraphContext,
embedding_matrix,
indices,
offsets,
scale_grad_by_freq,
mode,
sparse,
per_sample_weights,
include_last_offset,
padding_idx,
):
if scale_grad_by_freq and GLOBALS.export_training:
return symbolic_helper._onnx_unsupported(
"embedding_bag with scale_grad_by_freq for training mode"
)
if padding_idx is not None and padding_idx >= 0:
raise RuntimeError("embedding_bag with padding_idx is not supported for ONNX export")
warnings.warn(
"Export of embedding_bag with dynamic input/offsets shape is not supported in opset 10. "
"Please use opset 11 or higher to export model for dynamic input shape."
)
offsets_dim_0 = symbolic_helper._get_tensor_dim_size(offsets, 0)
if offsets_dim_0 is not None:
if include_last_offset:
offset_len = offsets_dim_0 - 1
offsets_extended = offsets
else:
offset_len = offsets_dim_0
offsets_extended = [
offsets,
g.op("Constant", value_t=torch.tensor([sys.maxsize])),
]
offsets_extended = g.op("Concat", *offsets_extended, axis_i=0)
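# For example, with include_last_offset == False and offsets == [0, 3],
# offsets_extended becomes [0, 3, sys.maxsize], so the loop below gathers
# the bags indices[0:3] and indices[3:].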
list_ = []
for i in range(offset_len):
start_ = symbolic_helper._unsqueeze_helper(
g,
opset9.select(g, offsets_extended, torch.tensor(0), torch.tensor(i)),
[0],
)
end_ = symbolic_helper._unsqueeze_helper(
g,
opset9.select(
g, offsets_extended, torch.tensor(0), torch.tensor(i + 1)
),
[0],
)
axes_ = g.op("Constant", value_t=torch.tensor([0]))
indices_row = g.op("Slice", indices, start_, end_, axes_)
embeddings = g.op("Gather", embedding_matrix, indices_row)
if not symbolic_helper._is_none(per_sample_weights):
per_sample_weights_row = g.op(
"Slice", per_sample_weights, start_, end_, axes_
)
per_sample_weights_row = symbolic_helper._unsqueeze_helper(
g, per_sample_weights_row, [1]
)
embeddings = g.op("Mul", embeddings, per_sample_weights_row)
if mode == 0:
embeddings = symbolic_helper._reducesum_helper(
g, embeddings, axes_i=[0], keepdims_i=0
)
elif mode == 1:
embeddings = g.op("ReduceMean", embeddings, axes_i=[0], keepdims_i=0)
else:
embeddings = g.op("ReduceMax", embeddings, axes_i=[0], keepdims_i=0)
embeddings = symbolic_helper._unsqueeze_helper(g, embeddings, [0])
list_.append(embeddings)
output = g.op("Concat", *list_, axis_i=0)
# aten::embedding_bag returns a tuple of 4 elements: output, offset2bag, bag_size, max_indices.
# But the last three outputs are not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag.
return output, None, None, None
else:
return symbolic_helper._onnx_unsupported(
"embedding_bag with unknown shape of offsets for opset 10 is not supported. "
"Please use opset 11 or higher."
)
@_onnx_symbolic("aten::fake_quantize_per_tensor_affine")
@symbolic_helper.parse_args("v", "v", "v", "i", "i")
@_beartype.beartype
def fake_quantize_per_tensor_affine(
g: jit_utils.GraphContext,
inputs,
scale,
zero_point,
quant_min=-128,
quant_max=127,
):
# NOTE: (0, 127) is a special case. PyTorch restricts activations to be in the range (0, 127).
# https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
if (quant_min, quant_max) == (0, 127):
symbolic_helper._onnx_opset_unsupported_detailed(
"fake_quantize_per_tensor_affine",
10,
13,
"Quantize range (0, 127) not supported, requires opset 13 Clip",
inputs,
)
if (quant_min, quant_max) not in [(0, 255), (-128, 127)]:
raise errors.SymbolicValueError(
f"For (quant_min, quant_max), ONNX allows only (0, 255) and (-128, 127). "
f"Got ({quant_min}, {quant_max})",
inputs,
)
scale = symbolic_helper._maybe_get_scalar(scale)
if scale is None:
symbolic_helper._onnx_opset_unsupported_detailed(
"fake_quantize_per_tensor_affine",
10,
13,
"Non-constant scale not supported",
inputs,
)
scale = scale.float().data # Avoid exporter generating double type
if quant_min == 0:
zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8)
else:
zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8)
return g.op(
"DequantizeLinear",
g.op("QuantizeLinear", inputs, scale, zero_point),
scale,
zero_point,
)
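# The QuantizeLinear/DequantizeLinear pair above emulates fake quantization:
# q = saturate(round(x / scale) + zero_point), then x' = (q - zero_point) * scale.
# For example, x = 0.53 with scale = 0.1 and zero_point = 0 round-trips to 0.5.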
@_onnx_symbolic("aten::isinf")
@_beartype.beartype
def isinf(g: jit_utils.GraphContext, input):
return g.op("IsInf", g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.DOUBLE))
@_onnx_symbolic("aten::isfinite")
@_beartype.beartype
def isfinite(g: jit_utils.GraphContext, input):
inf_node = isinf(g, input)
nan_node = opset9.isnan(g, input)
return opset9.__not_(g, opset9.__or_(g, inf_node, nan_node))
@_onnx_symbolic("aten::quantize_per_tensor")
@_beartype.beartype
def quantize_per_tensor(g: jit_utils.GraphContext, input, scale, zero_point, dtype):
dtype = symbolic_helper._get_const(dtype, "i", "dtype")
# TODO(justinchuby): Extract all the cast ops into a helper function.
zero_point = g.op(
"Cast", zero_point, to_i=_type_utils.JitScalarType(dtype).onnx_type()
)
scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT)
return symbolic_helper.quantize_helper(g, input, scale, zero_point)
@_onnx_symbolic("aten::dequantize")
@_beartype.beartype
def dequantize(g: jit_utils.GraphContext, input):
return symbolic_helper.dequantize_helper(g, input)[0]
@_onnx_symbolic("aten::nan_to_num")
@symbolic_helper.parse_args("v", "f", "f", "f")
@_beartype.beartype
def nan_to_num(g: jit_utils.GraphContext, input, nan, posinf, neginf):
# Cannot create an int-type tensor with inf/nan values, so we simply
# return the original tensor
if not symbolic_helper._is_fp(input):
return input
input_dtype = _type_utils.JitScalarType.from_value(input).dtype()
if nan is None:
nan = 0.0
nan_cond = opset9.isnan(g, input)
nan_result = g.op(
"Where",
nan_cond,
g.op("Constant", value_t=torch.tensor([nan], dtype=input_dtype)),
input,
)
# For None values of posinf/neginf we use the greatest/lowest finite
# value representable by the input's dtype.
finfo = torch.finfo(input_dtype)
if posinf is None:
posinf = finfo.max
posinf_cond = opset9.logical_and(
g,
isinf(g, nan_result),
opset9.gt(g, nan_result, g.op("Constant", value_t=torch.LongTensor([0]))),
)
nan_posinf_result = g.op(
"Where",
posinf_cond,
g.op("Constant", value_t=torch.tensor([posinf], dtype=input_dtype)),
nan_result,
)
if neginf is None:
neginf = finfo.min
neginf_cond = opset9.logical_and(
g,
isinf(g, nan_posinf_result),
opset9.lt(
g, nan_posinf_result, g.op("Constant", value_t=torch.LongTensor([0]))
),
)
return g.op(
"Where",
neginf_cond,
g.op("Constant", value_t=torch.tensor([neginf], dtype=input_dtype)),
nan_posinf_result,
)
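# For example, nan_to_num on a float32 tensor [nan, inf, -inf] with default
# arguments yields [0.0, finfo.max, finfo.min],
# i.e. [0.0, 3.4028235e38, -3.4028235e38].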
# Quantized symbolics ---------------------------------------------------------
# https://github.com/pytorch/pytorch/wiki/PyTorch-ONNX-exporter#quantized-model-export
# Support starts from opset 10 because `DequantizeLinear` and `QuantizeLinear` were
# introduced in opset version 10.
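# Every symbolic below follows the same dequantize -> float op -> requantize
# pattern. For instance, quantized::add exports roughly as:
#   x_fp = DequantizeLinear(x, x_scale, x_zero_point)
#   y_fp = DequantizeLinear(y, y_scale, y_zero_point)
#   out = QuantizeLinear(Add(x_fp, y_fp), op_scale, op_zero_point)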
@_onnx_symbolic("quantized::linear")
@_beartype.beartype
def quantized_linear(
g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point
):
input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale)
bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)
output = opset9.linear(g, input, weight, bias)
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
@_onnx_symbolic("quantized::add")
@_beartype.beartype
def quantized_add(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point):
x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
y, _, _, _ = symbolic_helper.dequantize_helper(g, y)
output = opset9.add(g, x, y)
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
@_onnx_symbolic("quantized::add_relu")
@_beartype.beartype
def quantized_add_relu(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point):
x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
y, _, _, _ = symbolic_helper.dequantize_helper(g, y)
output = opset9.add(g, x, y)
output = opset9.relu(g, output)
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
@_onnx_symbolic("quantized::mul")
@_beartype.beartype
def quantized_mul(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point):
x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
y, _, _, _ = symbolic_helper.dequantize_helper(g, y)
output = opset9.mul(g, x, y)
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
@_onnx_symbolic("quantized::hardswish")
@_beartype.beartype
def quantized_hardswish(g: jit_utils.GraphContext, x, op_scale, op_zero_point):
x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
output = opset9.hardswish(g, x)
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
@_onnx_symbolic("quantized::sigmoid")
@_beartype.beartype
def quantized_sigmoid(g: jit_utils.GraphContext, x, op_scale, op_zero_point):
x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
output = opset9.sigmoid(g, x)
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
@_onnx_symbolic("quantized::leaky_relu")
@_beartype.beartype
def quantized_leaky_relu(
g: jit_utils.GraphContext, x, negative_slope, inplace, op_scale, op_zero_point
):
x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
output = opset9.leaky_relu(g, x, negative_slope, inplace)
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
@_onnx_symbolic("quantized::layer_norm")
@_beartype.beartype
def quantized_layer_norm(
g: jit_utils.GraphContext,
x,
normalized_shape,
weight,
bias,
eps,
op_scale,
op_zero_point,
):
x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
output = opset9.layer_norm(g, x, normalized_shape, weight, bias, eps, False)
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
@_onnx_symbolic("quantized::group_norm")
@_beartype.beartype
def quantized_group_norm(
g: jit_utils.GraphContext,
x,
num_groups,
weight,
bias,
eps,
op_scale,
op_zero_point,
):
x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
output = opset9.group_norm(g, x, num_groups, weight, bias, eps, False)
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
@_onnx_symbolic("quantized::instance_norm")
@symbolic_helper.parse_args("v", "v", "v", "f", "v", "v")
@_beartype.beartype
def quantized_instance_norm(
g: jit_utils.GraphContext,
q_input,
weight,
bias,
eps,
op_scale,
op_zero_point,
):
input, _, _, _ = symbolic_helper.dequantize_helper(g, q_input)
output = opset9.instance_norm(
g, input, weight, bias, None, None, False, 0.0, eps, False
)
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
@_onnx_symbolic("quantized::conv1d_relu")
@_beartype.beartype
def quantized_conv1d_relu(
g: jit_utils.GraphContext,
q_input,
q_weight,
bias,
stride,
padding,
dilation,
groups,
op_scale,
op_zero_point,
):
input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale)
bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)
output = opset9.conv1d(g, input, weight, bias, stride, padding, dilation, groups)
output = opset9.relu(g, output)
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
@_onnx_symbolic("quantized::conv2d_relu")
@_beartype.beartype
def quantized_conv2d_relu(
g: jit_utils.GraphContext,
q_input,
q_weight,
bias,
stride,
padding,
dilation,
groups,
op_scale,
op_zero_point,
):
input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale)
bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)
output = opset9.conv2d(g, input, weight, bias, stride, padding, dilation, groups)
output = opset9.relu(g, output)
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
@_onnx_symbolic("quantized::conv2d")
@_beartype.beartype
def quantized_conv2d(
g: jit_utils.GraphContext,
q_input,
q_weight,
bias,
stride,
padding,
dilation,
groups,
op_scale,
op_zero_point,
):
input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale)
bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)
output = opset9.conv2d(g, input, weight, bias, stride, padding, dilation, groups)
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
@_onnx_symbolic("quantized::cat")
@symbolic_helper.parse_args("v", "i", "v", "v")
@_beartype.beartype
def quantized_cat(
g: jit_utils.GraphContext,
q_inputs: _C.Value,
dim: int,
op_scale: _C.Value,
op_zero_point: _C.Value,
) -> _C.Value:
unpacked_inputs = symbolic_helper._unpack_list(q_inputs)
dequantized = [
symbolic_helper.dequantize_helper(g, input)[0] for input in unpacked_inputs
]
concatenated = g.op("Concat", *dequantized, axis_i=dim)
return symbolic_helper.quantize_helper(g, concatenated, op_scale, op_zero_point)