Summary: Currently PyTorch ONNX exporter exports the logical ops (`lt`, `gt`, `le`, `ge`, `eq`) with output type in corresponding ONNX ops as type `tensor(uint8)`. But ONNX spec allows for only `tensor(bool)`, which is why models that have these ops fail to load properly. This issue is captured in https://github.com/pytorch/pytorch/issues/11339. Part of this issue, relating to the allowed input types, has been fixed in ONNX spec by houseroad. This PR fixes the other part pertaining to output type. Pull Request resolved: https://github.com/pytorch/pytorch/pull/15185 Differential Revision: D13494873 Pulled By: houseroad fbshipit-source-id: 069d2f956a5ae9bf0ac2540a32594a31b01adef8
1407 lines
47 KiB
Python
import numbers

import torch
from torch._C import DynamicType, ListType
from torch.nn.modules.utils import _single, _pair, _triple
from torch.nn.utils.rnn import PackedSequence
import warnings

import torch.onnx
# This import monkey-patches graph manipulation methods on Graph, used for the
# ONNX symbolics
import torch.onnx.utils

from collections import Iterable
from functools import partial, wraps
import itertools

# EDITING THIS FILE? READ THIS FIRST!
#
# - This file is ONLY for ATen operators (e.g., operators that show up in the
#   trace as aten::blah). If you need to special case a primitive operator,
#   look at _run_symbolic_function
# - Parameter ordering does NOT necessarily match what is in VariableType.cpp;
#   tensors are always first, then non-tensor arguments.
# - Parameter names must *exactly* match the names in VariableType.cpp, because
#   dispatch is done with keyword arguments.
# - Looking for inplace ops? They're detected by the trailing underscore, and
#   transparently dispatched to their non inplace versions in
#   'run_symbolic_function'. See Note [Export inplace]
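#
# As a purely illustrative sketch (the operator name "Foo" and its attribute
# below are hypothetical, not something defined in this file), a symbolic for
# an op aten::foo(Tensor self, int dim) would look like:
#
#     @parse_args('v', 'i')
#     def foo(g, self, dim):
#         return g.op("Foo", self, axis_i=dim)
#
# where 'v' passes the traced Value through untouched and 'i' extracts a
# Python int from a constant node (see _parse_arg below).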

# ---------------------------------------------------------------------
# Helper functions
# ---------------------------------------------------------------------

# Save some builtins as locals, because we'll shadow them below
_sum = sum


def _parse_arg(value, desc):
    if desc == 'none':
        return value
    if desc == 'v' or not _is_value(value):
        return value
    if value.node().kind() != 'onnx::Constant':
        raise RuntimeError("ONNX symbolic expected a constant value in the trace")
    tval = value.node()['value']
    if desc == 'i':
        return int(tval)
    elif desc == 'f':
        return float(tval)
    elif desc == 't':
        return tval
    elif desc == 'is':
        return [int(v) for v in tval]
    else:
        raise RuntimeError("Casting constants to `{}` is not implemented".format(desc))
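

# Argument descriptor codes accepted by _parse_arg (and hence by the
# parse_args decorator below):
#   'none' - return the value untouched, with no constant check
#   'v'    - keep the torch._C.Value as-is
#   'i'    - Python int pulled out of an onnx::Constant node
#   'f'    - Python float pulled out of an onnx::Constant node
#   't'    - the constant's tensor value
#   'is'   - list of Python ints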


def _maybe_get_const(value, desc):
    if _is_value(value) and value.node().kind() == 'onnx::Constant':
        return _parse_arg(value, desc)
    return value


def _maybe_get_scalar(value):
    value_t = _maybe_get_const(value, 't')
    if isinstance(value_t, torch.Tensor) and value_t.shape == ():
        return value_t
    return value


def _get_const(value, desc, arg_name):
    if _is_value(value) and value.node().kind() != 'onnx::Constant':
        raise RuntimeError("ONNX symbolic expected a constant value of the {} argument".format(arg_name))
    return _parse_arg(value, desc)


def _unpack_list(list_value):
    list_node = list_value.node()
    assert list_node.kind() == "prim::ListConstruct"
    return list(list_node.inputs())


def parse_args(*arg_descriptors):
    def decorator(fn):
        def wrapper(g, *args):
            assert len(arg_descriptors) == len(args)
            args = [_parse_arg(arg, arg_desc) for arg, arg_desc in zip(args, arg_descriptors)]
            return fn(g, *args)
        # In Python 2 functools.wraps chokes on partially applied functions, so we need this as a workaround
        try:
            wrapper = wraps(fn)(wrapper)
        except Exception:
            pass
        return wrapper
    return decorator


def _scalar(x):
    """Convert a scalar tensor into a Python value."""
    assert x.numel() == 1
    return x.item()


def _if_scalar_type_as(g, self, tensor):
    """
    Convert self into the same type of tensor, as necessary.

    We only support implicit casting for scalars, so we never
    actually need to insert an ONNX cast operator here; just
    fix up the scalar.
    """
    if isinstance(self, torch._C.Value):
        return self
    elif tensor.type().kind() == "TensorType" or tensor.type().kind() == "CompleteTensorType":
        ty = tensor.type().scalarType().lower()
        return getattr(self, ty)()
    else:
        return self


def _is_value(x):
    return isinstance(x, torch._C.Value)


def _is_tensor_list(x):
    return x.type().isSubtypeOf(ListType.ofTensors())


def _unimplemented(op, msg):
    warnings.warn("ONNX export failed on " + op + " because " + msg + " not supported")


def _try_get_scalar_type(*args):
    for arg in args:
        try:
            return arg.type().scalarType()
        except RuntimeError:
            pass
    return None


# ---------------------------------------------------------------------
# ONNX operator version
# ---------------------------------------------------------------------

# READ ME BEFORE EDITING _onnx_opset_version:
#
# The variable below controls which ONNX operator set version we are
# targeting. THIS VARIABLE HAS SEMANTIC EFFECT! Say a breaking
# change occurred in version 8. As long as this variable < 8, you can
# export models targeting the old behavior. However, if you bump
# this variable to 8 or later, the breaking change will take effect:
# you MUST adjust any symbolic affected by breaking changes. The ONNX
# spec publishes a *comprehensive* list of BC-breaking changes for every
# operator revision at:
#
#   https://github.com/onnx/onnx/blob/master/docs/Changelog.md
#
# Please be sure to go through and check all of our implementations here before
# increasing this number. This includes symbolic definitions NOT in this
# file, so grep for "OpName" (with quotes)

_onnx_opset_version = 9


# ---------------------------------------------------------------------
# Symbolic definitions
# ---------------------------------------------------------------------


# Note [Pointwise by scalar]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
# What happens if you add a tensor with a constant (e.g., x + 2)? There are
# some moving parts to implementing the ONNX translation in this case:
#
# - By the time we get the scalar in a symbolic function here, it is no longer
#   a Python long/float, but a PyTorch tensor with numel == 1 (eventually, we
#   want it to be a zero dim tensor but this change has not happened yet.)
#   However, the type of this scalar is *exactly* what the user wrote in
#   Python, which may not match the tensor it is being added to. PyTorch
#   will do implicit conversions on scalars; however, ONNX will not, so
#   we must do the conversion ourselves. This is what _if_scalar_type_as
#   does.
#
# - Dispatch to these functions takes advantage of an outrageous coincidence
#   between the tensor and scalar name. When we add two tensors together,
#   you get the dispatch:
#
#     add(*[self, other], **{"alpha": alpha})
#
#   When you add a tensor and a scalar, you get the dispatch:
#
#     add(*[self], **{"other": other, "alpha": alpha})
#
#   By having the argument name line up with the name of the scalar attribute
#   if it exists, we can write a single function for both overloads.
#

# used to represent "missing" optional inputs
def unused(g):
    return g.op("prim::Undefined")


def _shape_as_tensor(g, input):
    return g.op('Shape', input)
|
|
|
|
|
|
def _reshape_from_tensor(g, input, shape):
|
|
return g.op('Reshape', input, shape)
|
|
|
|
|
|
def add(g, self, other, alpha=None):
|
|
# default alpha arg is to allow no-alpha add (aten add st overload no alpha)
|
|
if alpha and _scalar(_maybe_get_scalar(alpha)) != 1:
|
|
return _unimplemented("add", "alpha != 1")
|
|
# See Note [Pointwise by scalar]
|
|
other = _maybe_get_scalar(other)
|
|
return g.op("Add", self, _if_scalar_type_as(g, other, self))
|
|
|
|
|
|
def sub(g, self, other, alpha=None):
|
|
# default alpha arg is to allow no-alpha sub (aten sub st overload no alpha)
|
|
if alpha and _scalar(_maybe_get_scalar(alpha)) != 1:
|
|
return _unimplemented("sub", "alpha != 1")
|
|
# See Note [Pointwise by scalar]. Note that self or other may be scalars.
|
|
other = _maybe_get_scalar(other)
|
|
return g.op("Sub", self, _if_scalar_type_as(g, other, self))
|
|
|
|
|
|
def rsub(g, self, other, alpha=None):
|
|
return sub(g, other, self, alpha=alpha)
|
|
|
|
|
|
def mul(g, self, other):
|
|
# See Note [Pointwise by scalar]
|
|
other = _maybe_get_scalar(other)
|
|
return g.op("Mul", self, _if_scalar_type_as(g, other, self))
|
|
|
|
|
|
def div(g, self, other):
|
|
# See Note [Pointwise by scalar]
|
|
other = _maybe_get_scalar(other)
|
|
return g.op("Div", self, _if_scalar_type_as(g, other, self))
|
|
|
|
|
|
def reciprocal(g, self):
|
|
return g.op("Div", _if_scalar_type_as(g, torch.ones(1), self), self)
|
|
|
|
|
|
@parse_args('v', 'i')
|
|
def cat(g, tensor_list, dim):
|
|
tensors = _unpack_list(tensor_list)
|
|
return g.op("Concat", *tensors, axis_i=dim)
|
|
|
|
|
|
@parse_args('v', 'i')
|
|
def stack(g, tensor_list, dim):
|
|
unsqueezed = [g.op("Unsqueeze", t, axes_i=[dim]) for t in _unpack_list(tensor_list)]
|
|
return g.op("Concat", *unsqueezed, axis_i=dim)
|
|
|
|
|
|
def mm(g, self, other):
|
|
    # Create a dummy C tensor. Only needed for API purposes: since beta = 0,
    # its value is never used.
|
|
ty = _try_get_scalar_type(self, other).lower()
|
|
C = g.constant(0, [1], ty)
|
|
return g.op("Gemm", self, other, C, beta_f=0.0, alpha_f=1.0)
|
|
|
|
|
|
def bmm(g, self, other):
|
|
return g.op("MatMul", self, other)
|
|
|
|
|
|
def matmul(g, self, other):
|
|
return g.op("MatMul", self, other)
|
|
|
|
|
|
@parse_args('v', 'v', 'v', 't', 't')
|
|
def addmm(g, self, mat1, mat2, beta, alpha):
|
|
return g.op("Gemm", mat1, mat2, self, beta_f=_scalar(beta), alpha_f=_scalar(alpha))
|
|
|
|
|
|
def neg(g, self):
|
|
return g.op("Neg", self)
|
|
|
|
|
|
def sqrt(g, self):
|
|
return g.op("Sqrt", self)
|
|
|
|
|
|
def tanh(g, self):
|
|
return g.op("Tanh", self)
|
|
|
|
|
|
def sin(g, self):
|
|
return g.op("Sin", self)
|
|
|
|
|
|
def cos(g, self):
|
|
return g.op("Cos", self)
|
|
|
|
|
|
def tan(g, self):
|
|
return g.op("Tan", self)
|
|
|
|
|
|
def asin(g, self):
|
|
return g.op("Asin", self)
|
|
|
|
|
|
def acos(g, self):
|
|
return g.op("Acos", self)
|
|
|
|
|
|
def atan(g, self):
|
|
return g.op("Atan", self)
|
|
|
|
|
|
def sigmoid(g, self):
|
|
return g.op("Sigmoid", self)
|
|
|
|
|
|
def _reduce_op_symbolic(onnx_op_name):
|
|
def symbolic(g, self, dim=None, keepdim=None):
|
|
if dim is None:
|
|
# all-reduce path
|
|
return g.op(onnx_op_name, self, keepdims_i=0)
|
|
else:
|
|
# dim-reduce path
|
|
dim, keepdim = _get_const(dim, 'i', 'dim'), _get_const(keepdim, 'i', 'keepdim')
|
|
return g.op(onnx_op_name, self, axes_i=[dim], keepdims_i=keepdim)
|
|
return symbolic
|
|
|
|
mean = _reduce_op_symbolic('ReduceMean')
|
|
sum = _reduce_op_symbolic('ReduceSum')
|
|
prod = _reduce_op_symbolic('ReduceProd')
|
|
|
|
|
|
@parse_args('v', 'i')
|
|
def cumsum(g, input, dim):
|
|
return g.op("ATen", input, operator_s="cumsum", dim_i=dim)
|
|
|
|
|
|
def t(g, self):
|
|
return g.op("Transpose", self, perm_i=(1, 0))
|
|
|
|
|
|
def expand(g, self, size, implicit):
|
|
size = _maybe_get_const(size, 'is')
|
|
if not _is_value(size):
|
|
size = g.op("Constant", value_t=torch.LongTensor(size))
|
|
return g.op("Expand", self, size)
|
|
|
|
|
|
def expand_as(g, self, other):
|
|
shape = g.op("Shape", other)
|
|
return g.op("Expand", self, shape)
|
|
|
|
|
|
def embedding(g, weight, indices, padding_idx, scale_grad_by_freq, sparse):
|
|
return g.op("Gather", weight, indices)
|
|
|
|
|
|
@parse_args('v', 'v', 'v', 'i', 'i', 'i')
|
|
def embedding_bag(g,
|
|
embedding_matrix,
|
|
indices,
|
|
offsets,
|
|
scale_grad_by_freq,
|
|
mode,
|
|
sparse):
|
|
return g.op("ATen",
|
|
embedding_matrix,
|
|
indices,
|
|
offsets,
|
|
operator_s="embedding_bag",
|
|
outputs=4,
|
|
scale_grad_by_freq_i=scale_grad_by_freq,
|
|
mode_i=mode,
|
|
sparse_i=sparse)
|
|
|
|
|
|
def size(g, self, dim):
|
|
full_shape = g.op("Shape", self)
|
|
return select(g, full_shape, g.op("Constant", value_t=torch.tensor([0])), dim)
|
|
|
|
|
|
@parse_args('v', 'i', 'i')
|
|
def transpose(g, self, dim0, dim1):
|
|
if dim0 == dim1: # micro-optimization
|
|
return self
|
|
|
|
# NB: Transpose in ONNX is actually a Permute
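    # e.g. for a 4-d input, transpose(dim0=0, dim1=2) becomes perm_i=[2, 1, 0, 3]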
|
|
axes = list(range(len(self.type().sizes())))
|
|
axes[dim0], axes[dim1] = axes[dim1], axes[dim0]
|
|
return g.op("Transpose", self, perm_i=axes)
|
|
|
|
|
|
@parse_args('v', 'is')
|
|
def permute(g, self, dims):
|
|
if dims == list(range(0, len(dims))):
|
|
return self
|
|
return g.op("Transpose", self, perm_i=dims)
|
|
|
|
|
|
def view(g, self, size):
|
|
size = _maybe_get_const(size, 'is')
|
|
if _is_value(size):
|
|
shape = size
|
|
else:
|
|
if self.isTensor():
|
|
self_sizes = self.type().sizes()
|
|
if self_sizes and len(size) == 2 and self_sizes[0] == size[0]:
|
|
return g.op("Flatten", self, axis_i=1)
|
|
shape = g.op("Constant", value_t=torch.LongTensor(size))
|
|
return g.op("Reshape", self, shape)
|
|
|
|
|
|
def prim_ConstantSplit(g, self, split_size, dim):
|
|
size = self.type().sizes()[dim]
|
|
splits = [split_size] * (size // split_size)
|
|
leftover = size % split_size
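    # e.g. a dimension of size 10 with split_size 3 gives splits == [3, 3, 3]
    # plus a leftover chunk of 1 appended below, i.e. [3, 3, 3, 1]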
|
|
if leftover:
|
|
splits.append(leftover)
|
|
return g.op("Split", self, split_i=splits, axis_i=dim, outputs=len(splits))
|
|
|
|
|
|
# TODO: It would be better to export this as a chunk directly, as this is
|
|
# less sensitive to changes in input size.
|
|
# TODO: Once we have proper scoping, stop reimplementing chunk, delete this
|
|
# method, and use the desugared version
|
|
def prim_ConstantChunk(g, self, chunks, dim):
|
|
split_size = (self.type().sizes()[dim] + chunks - 1) // chunks
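    # e.g. chunking a dimension of size 10 into 3 chunks gives split_size == 4,
    # which prim_ConstantSplit then turns into splits of [4, 4, 2]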
|
|
return prim_ConstantSplit(g, self, split_size, dim)
|
|
|
|
|
|
@parse_args('v', 'i', 'v')
|
|
def select(g, self, dim, index):
|
|
if dim > 1:
|
|
# TODO: this is a temporary hack because of the implementation details
|
|
# of Gather in caffe2. We need to change this as soon as possible.
|
|
# TODO: this breaks if index == -1
|
|
index_val = _parse_arg(index, 'i')
|
|
slice_node = g.op("Slice", self, axes_i=[dim], starts_i=[index_val], ends_i=[index_val + 1])
|
|
return g.op("Squeeze", slice_node, axes_i=[dim])
|
|
else:
|
|
return g.op("Gather", self, index, axis_i=dim)
|
|
|
|
|
|
def squeeze(g, self, dim=None):
|
|
if dim is None:
|
|
dims = []
|
|
for i, size in enumerate(self.type().sizes()):
|
|
if size == 1:
|
|
dims.append(i)
|
|
else:
|
|
dims = [_get_const(dim, 'i', 'dim')]
|
|
return g.op("Squeeze", self, axes_i=dims)
|
|
|
|
|
|
def prelu(g, self, weight):
|
|
return g.op("PRelu", self, weight)
|
|
|
|
|
|
def relu(g, input):
|
|
return g.op("Relu", input)
|
|
|
|
|
|
@parse_args('v', 't', 't')
|
|
def threshold(g, self, threshold, value):
|
|
# See Note [Export inplace]
|
|
if _scalar(threshold) != 0:
|
|
return _unimplemented("threshold", "non-zero threshold")
|
|
if _scalar(value) != 0:
|
|
return _unimplemented("threshold", "non-zero value")
|
|
return g.op("Relu", self)
|
|
|
|
|
|
def leaky_relu(g, input, negative_slope, inplace=False):
|
|
negative_slope = _get_const(negative_slope, 't', 'negative_slope')
|
|
# See Note [Export inplace]
|
|
# TODO: Talk to ONNX about unconditional cast of scalar to float
|
|
return g.op("LeakyRelu", input, alpha_f=_scalar(negative_slope))
|
|
|
|
|
|
@parse_args('v', 'i')
|
|
def glu(g, input, dim):
|
|
assert input.type().sizes()[dim] % 2 == 0
|
|
|
|
first, second = g.op('Split', input, axis_i=dim, outputs=2)
|
|
return g.op('Mul', first, g.op('Sigmoid', second))
|
|
|
|
|
|
@parse_args('v', 'i')
|
|
def softmax(g, input, dim):
|
|
# Softmax does normalization at vector level.
|
|
# PyTorch and ONNX use different strategies to split the input tensor into vectors.
|
|
# Thus dim and axis have different meanings.
|
|
# PyTorch slices the input tensor into vectors along the `dim`-th dimension.
|
|
# ONNX reshapes the input into a 2-D tensor, and `axis` indicates where the input is coerced.
|
|
# If input is a 2 x 3 tensor:
|
|
# input = [[1.0, 1.0, 1.0],
|
|
# [1.0, 1,0, 1,0]]
|
|
# with dim = 0, the result is:
|
|
# result = [[0.5, 0.5, 0.5],
|
|
# [0.5, 0.5, 0.5]]
|
|
# with axis = 0, the result is:
|
|
# result = [[0.167, 0.167, 0.167],
|
|
# [0.167, 0.167, 0.167]]
|
|
# So only when dim and axis both equal to ndim - 1 (the last dimension),
|
|
# their semantics are equivalent.
|
|
if dim < 0:
|
|
dim = len(input.type().sizes()) + dim
|
|
if len(input.type().sizes()) != dim + 1:
|
|
return _unimplemented("dim", "ONNX and PyTorch use different strategies to split the input.")
|
|
return g.op('Softmax', input, axis_i=dim)
|
|
|
|
|
|
@parse_args('v', 't', 'v')
|
|
def softplus(g, self, beta, threshold):
|
|
if beta != 1:
|
|
return _unimplemented("beta", "has to be 1")
|
|
return g.op('Softplus', self)
|
|
|
|
|
|
@parse_args('v', 'is', 'is', 'is', 'is', 'i')
|
|
def max_pool1d_with_indices(g, input, kernel_size, stride, padding, dilation, ceil_mode):
|
|
if ceil_mode:
|
|
return _unimplemented("max_pool1d_with_indices", "ceil_mode")
|
|
if set(_single(dilation)) != {1}:
|
|
return _unimplemented("max_pool1d_with_indices", "dilation")
|
|
if stride is None:
|
|
stride = kernel_size
|
|
r = g.op("MaxPool", input,
|
|
kernel_shape_i=_single(kernel_size),
|
|
pads_i=_single(padding) * 2,
|
|
strides_i=_single(stride))
|
|
return r, None
|
|
|
|
|
|
@parse_args('v', 'is', 'is', 'is', 'is', 'i')
|
|
def max_pool2d_with_indices(g, input, kernel_size, stride, padding, dilation, ceil_mode):
|
|
if ceil_mode:
|
|
return _unimplemented("max_pool2d_with_indices", "ceil_mode")
|
|
if set(_pair(dilation)) != {1}:
|
|
return _unimplemented("max_pool2d_with_indices", "dilation")
|
|
if not stride:
|
|
stride = kernel_size
|
|
r = g.op("MaxPool", input,
|
|
kernel_shape_i=_pair(kernel_size),
|
|
pads_i=_pair(padding) * 2,
|
|
strides_i=_pair(stride))
|
|
return r, None
|
|
|
|
|
|
@parse_args('v', 'is', 'is', 'is', 'is', 'i')
|
|
def max_pool3d_with_indices(g, input, kernel_size, stride, padding, dilation, ceil_mode):
|
|
if ceil_mode:
|
|
return _unimplemented("max_pool3d_with_indices", "ceil_mode")
|
|
if set(_triple(dilation)) != {1}:
|
|
return _unimplemented("max_pool3d_with_indices", "dilation")
|
|
if not stride:
|
|
stride = kernel_size
|
|
r = g.op("MaxPool", input,
|
|
kernel_shape_i=_triple(kernel_size),
|
|
pads_i=_triple(padding) * 2,
|
|
strides_i=_triple(stride))
|
|
return r, None
|
|
|
|
|
|
def _avg_pool(name, tuple_fn):
|
|
@parse_args('v', 'is', 'is', 'is', 'i', 'i')
|
|
def symbolic_fn(g, input, kernel_size, stride, padding, ceil_mode, count_include_pad):
|
|
if ceil_mode:
|
|
return _unimplemented("avg_pool2d", "ceil_mode")
|
|
if not stride:
|
|
stride = kernel_size
|
|
|
|
padding = tuple(tuple_fn(padding))
|
|
if count_include_pad:
|
|
input = g.op("Pad", input,
|
|
pads_i=((0,) * 2 + padding) * 2,
|
|
mode_s='constant',
|
|
value_f=0.)
|
|
padding = (0,) * len(padding)
|
|
|
|
return g.op("AveragePool", input,
|
|
kernel_shape_i=tuple_fn(kernel_size),
|
|
strides_i=tuple_fn(stride),
|
|
pads_i=padding * 2)
|
|
return symbolic_fn
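
# In the pooling symbolic above, ONNX lays pads out as
# [x1_begin, x2_begin, ..., x1_end, x2_end, ...]; repeating the symmetric
# PyTorch padding (padding * 2) produces that layout, and the explicit Pad op
# used when count_include_pad is set prefixes zeros for the batch and channel
# dimensions.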
|
|
|
|
|
|
avg_pool1d = _avg_pool('avg_pool1d', _single)
|
|
avg_pool2d = _avg_pool('avg_pool2d', _pair)
|
|
avg_pool3d = _avg_pool('avg_pool3d', _triple)
|
|
|
|
|
|
@parse_args('v', 'is')
|
|
def adaptive_avg_pool2d(g, input, output_size):
|
|
assert output_size == [1, 1], "Only output_size=[1, 1] is supported"
|
|
return g.op("GlobalAveragePool", input)
|
|
|
|
|
|
@parse_args('v', 'is')
|
|
def adaptive_max_pool2d(g, input, output_size):
|
|
assert output_size == [1, 1], "Only output_size=[1, 1] is supported"
|
|
return g.op("GlobalMaxPool", input), None
|
|
|
|
|
|
@parse_args('v', 'is', 'f')
|
|
def constant_pad_nd(g, input, padding, value):
|
|
from torch.autograd._functions.utils import prepare_onnx_paddings
|
|
mode = "constant"
|
|
paddings = prepare_onnx_paddings(len(input.type().sizes()), padding)
|
|
return g.op("Pad", input, pads_i=paddings, mode_s=mode, value_f=value)
|
|
|
|
|
|
@parse_args('v', 'is')
|
|
def reflection_pad(g, input, padding):
|
|
from torch.autograd._functions.utils import prepare_onnx_paddings
|
|
mode = "reflect"
|
|
paddings = prepare_onnx_paddings(len(input.type().sizes()), padding)
|
|
return g.op("Pad", input, pads_i=paddings, mode_s=mode)
|
|
|
|
|
|
@parse_args('v', 'is')
|
|
def replication_pad(g, input, padding):
|
|
from torch.autograd._functions.utils import prepare_onnx_paddings
|
|
mode = "edge"
|
|
paddings = prepare_onnx_paddings(len(input.type().sizes()), padding)
|
|
return g.op("Pad", input, pads_i=paddings, mode_s=mode)
|
|
|
|
|
|
reflection_pad1d = reflection_pad
|
|
reflection_pad2d = reflection_pad
|
|
reflection_pad3d = reflection_pad
|
|
replication_pad1d = replication_pad
|
|
replication_pad2d = replication_pad
|
|
replication_pad3d = replication_pad
|
|
|
|
|
|
@parse_args('v', 'is')
|
|
def upsample_nearest2d(g, input, output_size):
|
|
height_scale = float(output_size[-2]) / input.type().sizes()[-2]
|
|
width_scale = float(output_size[-1]) / input.type().sizes()[-1]
|
|
scales = g.op("Constant", value_t=torch.tensor([1., 1., height_scale,
|
|
width_scale]))
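    # e.g. a (N, C, 32, 32) input with output_size (64, 64) yields
    # scales == [1., 1., 2., 2.]; batch and channel dims are never rescaled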
|
|
|
|
return g.op("Upsample", input, scales,
|
|
mode_s="nearest")
|
|
|
|
|
|
@parse_args('v', 'is', 'i')
|
|
def upsample_bilinear2d(g, input, output_size, align_corners):
|
|
if align_corners:
|
|
return _unimplemented("upsample_bilinear2d", "align_corners == True")
|
|
height_scale = float(output_size[-2]) / input.type().sizes()[-2]
|
|
width_scale = float(output_size[-1]) / input.type().sizes()[-1]
|
|
scales = g.op("Constant", value_t=torch.tensor([1., 1., height_scale,
|
|
width_scale]))
|
|
return g.op("Upsample", input, scales,
|
|
mode_s="linear")
|
|
|
|
|
|
def wrap_logical_op_with_cast_to_uint8(func):
|
|
def wrap_with_cast(g, input, other):
|
|
return g.op("Cast", func(g, input, other), to_i=cast_pytorch_to_onnx['Byte'])
|
|
return wrap_with_cast
|
|
|
|
|
|
def wrap_logical_op_with_negation(func):
|
|
def wrap_with_not(g, input, other):
|
|
return g.op("Not", func(g, input, other))
|
|
return wrap_with_not
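
# At this opset, ONNX has no GreaterOrEqual/LessOrEqual operators, so ge/le
# below are exported as Not(Less)/Not(Greater); wrap_logical_op_with_cast_to_uint8
# then casts the boolean result to match the Byte tensors these comparisons
# produce in PyTorch.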
|
|
|
|
|
|
@wrap_logical_op_with_cast_to_uint8
|
|
def gt(g, input, other):
|
|
return gt_impl(g, input, other)
|
|
|
|
|
|
def gt_impl(g, input, other):
|
|
other = _maybe_get_scalar(other)
|
|
return g.op("Greater", input, _if_scalar_type_as(g, other, input))
|
|
|
|
|
|
@wrap_logical_op_with_cast_to_uint8
|
|
def lt(g, input, other):
|
|
return lt_impl(g, input, other)
|
|
|
|
|
|
def lt_impl(g, input, other):
|
|
other = _maybe_get_scalar(other)
|
|
return g.op("Less", input, _if_scalar_type_as(g, other, input))
|
|
|
|
|
|
@wrap_logical_op_with_cast_to_uint8
|
|
@wrap_logical_op_with_negation
|
|
def ge(g, input, other):
|
|
other = _maybe_get_scalar(other)
|
|
return lt_impl(g, input, _if_scalar_type_as(g, other, input))
|
|
|
|
|
|
@wrap_logical_op_with_cast_to_uint8
|
|
@wrap_logical_op_with_negation
|
|
def le(g, input, other):
|
|
other = _maybe_get_scalar(other)
|
|
return gt_impl(g, input, _if_scalar_type_as(g, other, input))
|
|
|
|
|
|
def where(g, condition, self, other):
|
|
return g.op("ATen", condition, self, other, operator_s="where")
|
|
|
|
|
|
@parse_args('v', 'i')
|
|
def log_softmax(g, input, dim=None):
|
|
# PyTorch dim and ONNX axis have different meanings.
|
|
# See Softmax comment for details.
|
|
if dim < 0:
|
|
dim = len(input.type().sizes()) + dim
|
|
if len(input.type().sizes()) != dim + 1:
|
|
return _unimplemented("dim", "ONNX and PyTorch use different strategies to split the input.")
|
|
return g.op("LogSoftmax", input, axis_i=dim)
|
|
|
|
|
|
@parse_args('v', 'v', 'v', 'is', 'is', 'is', 'i', 'is', 'i', 'i', 'i', 'i')
|
|
def _convolution(g, input, weight, bias, stride, padding, dilation,
|
|
transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled):
|
|
weight_size = weight.type().sizes()
|
|
|
|
args = [input, weight]
|
|
# ONNX only supports 1D bias
|
|
if bias.node().kind() != "prim::Undefined" and len(bias.type().sizes()) == 1:
|
|
args.append(bias)
|
|
|
|
kwargs = {"kernel_shape_i": weight_size[2:],
|
|
"strides_i": stride,
|
|
# NB: ONNX supports asymmetric padding, whereas PyTorch supports only
|
|
# symmetric padding
|
|
"pads_i": padding + padding,
|
|
"dilations_i": dilation,
|
|
"group_i": groups}
|
|
|
|
if any(o != 0 for o in output_padding):
|
|
        # ONNX supports both output_shape and output_padding; they are equivalently expressive.
|
|
# output_padding is more straightforward, so we use it here.
|
|
# output_shape = stride * (input_shape - 1) + output_padding + kernel_shape - padding * 2
|
|
assert transposed
|
|
assert len(stride) == len(output_padding)
|
|
kwargs["output_padding_i"] = output_padding
|
|
|
|
n = g.op("ConvTranspose" if transposed else "Conv", *args, **kwargs)
|
|
|
|
if bias.node().kind() != "prim::Undefined" and len(bias.type().sizes()) != 1:
|
|
return g.op("Add", n, bias)
|
|
else:
|
|
return n
|
|
|
|
|
|
@parse_args('v', 'v', 'v', 'v', 'v', 'i', 'f', 'f', 'i')
|
|
def batch_norm(g, input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled):
|
|
input_sizes = input.type().sizes()
|
|
if len(input_sizes) == 2:
|
|
# batchnorm1d accepts 2d and 3d array, but ONNX only accepts 3d
|
|
input = g.op("Unsqueeze", input, axes_i=[2])
|
|
|
|
if weight is None or weight.node().kind() == "prim::Undefined":
|
|
assert len(input_sizes) > 1
|
|
weight_value = torch.tensor([1.] * input_sizes[1]).type(
|
|
'torch.' + input.type().scalarType() + 'Tensor')
|
|
weight = g.op("Constant", value_t=weight_value)
|
|
if bias is None or bias.node().kind() == "prim::Undefined":
|
|
assert len(input_sizes) > 1
|
|
bias_value = torch.tensor([0.] * input_sizes[1]).type(
|
|
'torch.' + input.type().scalarType() + 'Tensor')
|
|
bias = g.op("Constant", value_t=bias_value)
|
|
out = g.op("BatchNormalization", input, weight, bias, running_mean, running_var,
|
|
epsilon_f=eps,
|
|
momentum_f=1 - momentum,
|
|
outputs=1 if not training else 5)
|
|
if not training:
|
|
if len(input_sizes) == 2:
|
|
out = g.op("Squeeze", out, axes_i=[2])
|
|
return out
|
|
else:
|
|
res, new_running_mean, new_running_var, saved_mean, saved_var = out
|
|
new_running_mean.setType(running_mean.type())
|
|
new_running_var.setType(running_var.type())
|
|
saved_mean.setUniqueName("batch_norm_dead_output-" + saved_mean.uniqueName())
|
|
saved_var.setUniqueName("batch_norm_dead_output-" + saved_var.uniqueName())
|
|
if len(input_sizes) == 2:
|
|
res = g.op("Squeeze", res, axes_i=[2])
|
|
return res
|
|
|
|
|
|
@parse_args('v', 'v', 'v', 'v', 'v', 'i', 'f', 'f', 'i')
|
|
def instance_norm(g, input, weight, bias, running_mean, running_var, use_input_stats, momentum, eps, cudnn_enabled):
|
|
input_sizes = input.type().sizes()
|
|
if weight is None or weight.node().kind() == "prim::Undefined":
|
|
assert len(input_sizes) > 1
|
|
weight_value = torch.tensor([1.] * input_sizes[1]).type(
|
|
'torch.' + input.type().scalarType() + 'Tensor')
|
|
weight = g.op("Constant", value_t=weight_value)
|
|
if bias is None or bias.node().kind() == "prim::Undefined":
|
|
assert len(input_sizes) > 1
|
|
bias_value = torch.tensor([0.] * input_sizes[1]).type(
|
|
'torch.' + input.type().scalarType() + 'Tensor')
|
|
bias = g.op("Constant", value_t=bias_value)
|
|
return g.op("InstanceNormalization", input, weight, bias, epsilon_f=eps)
|
|
|
|
|
|
@parse_args('v', 'i', 'i', 'i')
|
|
def unfold(g, input, dimension, size, step):
|
|
return g.op("ATen", input, operator_s="unfold", dimension_i=dimension, size_i=size, step_i=step)
|
|
|
|
|
|
@parse_args('v', 'v', 'i')
|
|
def _weight_norm(graph, v, g, dim):
|
|
return graph.op("ATen", v, g, dim_i=dim, operator_s="_weight_norm")
|
|
|
|
|
|
@parse_args('v', 't', 't', 't')
|
|
def elu(g, input, alpha, scale, input_scale):
|
|
if scale and scale != 1.:
|
|
return _unimplemented("scale", "does not support scale in Elu")
|
|
if input_scale and input_scale != 1.:
|
|
return _unimplemented("input_scale", "does not support input_scale in Elu")
|
|
# See Note [Export inplace]
|
|
return g.op("Elu", input, alpha_f=_scalar(alpha))
|
|
|
|
|
|
def selu(g, input):
|
|
return g.op("Selu", input)
|
|
|
|
|
|
@parse_args('v', 'i', 'v')
|
|
def index_select(g, self, dim, index):
|
|
return g.op("Gather", self, index, axis_i=dim)
|
|
|
|
|
|
def index_put(g, self, indices_list_value, values, accumulate):
|
|
indices_list = _unpack_list(indices_list_value)
|
|
args = [self] + indices_list + [values, accumulate]
|
|
return g.op("ATen", *args, operator_s='index_put')
|
|
|
|
|
|
def type_as(g, self, other):
|
|
if self.isTensor() and other.isTensor() and self.type().scalarType() == other.type().scalarType():
|
|
return self
|
|
|
|
if other.isTensor():
|
|
other_type_name = other.type().scalarType()
|
|
return g.op("Cast", self, to_i=cast_pytorch_to_onnx[other_type_name])
|
|
else:
|
|
# We don't know the type of other, bail by emitting ATen
|
|
return g.op("ATen", self, other, operator_s="type_as")
|
|
|
|
|
|
@parse_args('v', 'is', 'v', 'v', 'f', 'i')
|
|
def layer_norm(g, self, normalized_shape, weight, bias, eps, cudnn_enable):
|
|
return g.op("ATen", self, weight, bias, normalized_shape_i=normalized_shape,
|
|
eps_f=eps, cudnn_enable_i=cudnn_enable, operator_s="layer_norm")
|
|
|
|
|
|
# ignore clone operators that are inserted by PyTorch autograd
|
|
def clone(g, input):
|
|
return input
|
|
|
|
|
|
def abs(g, self):
|
|
return g.op("Abs", self)
|
|
|
|
|
|
def log(g, self):
|
|
return g.op("Log", self)
|
|
|
|
|
|
def pow(g, self, exponent):
|
|
exponent = _maybe_get_scalar(exponent)
|
|
return g.op("Pow", self, _if_scalar_type_as(g, exponent, self))
|
|
|
|
|
|
def clamp(g, self, min, max):
|
|
# min or max may be prim::None that we need to dispatch to
|
|
# Clip separately, as ONNX does not have None syntax
|
|
if min.node().kind() == "prim::None":
|
|
return clamp_max(g, self, max)
|
|
elif max.node().kind() == "prim::None":
|
|
return clamp_min(g, self, min)
|
|
else:
|
|
min = _parse_arg(min, 'f')
|
|
max = _parse_arg(max, 'f')
|
|
return g.op("Clip", self, min_f=min, max_f=max)
|
|
|
|
|
|
@parse_args('v', 'f')
|
|
def clamp_min(g, self, min):
|
|
return g.op("Clip", self, min_f=min)
|
|
|
|
|
|
@parse_args('v', 'f')
|
|
def clamp_max(g, self, max):
|
|
return g.op("Clip", self, max_f=max)
|
|
|
|
|
|
# torch.max (same for torch.min) actually has two interfaces smashed together:
|
|
# torch.max(x, dim, keepdim) and torch.max(x, y)
|
|
def max(g, self, dim_or_y=None, keepdim=None):
|
|
if dim_or_y is None and keepdim is None:
|
|
return g.op("ReduceMax", self, keepdims_i=0)
|
|
if keepdim is None:
|
|
return g.op("Max", self, dim_or_y)
|
|
else:
|
|
dim = _get_const(dim_or_y, 'i', 'dim')
|
|
keepdim = _get_const(keepdim, 'i', 'keepdim')
|
|
# TODO: export it as ReduceMax
|
|
return g.op("ATen",
|
|
self,
|
|
operator_s="max",
|
|
dim_i=dim,
|
|
keepdim_i=keepdim,
|
|
outputs=2)
|
|
|
|
|
|
def min(g, self, dim_or_y=None, keepdim=None):
|
|
if dim_or_y is None and keepdim is None:
|
|
return g.op("ReduceMin", self, keepdims_i=0)
|
|
if keepdim is None:
|
|
return g.op("Min", self, dim_or_y)
|
|
else:
|
|
dim = _get_const(dim_or_y, 'i', 'dim')
|
|
keepdim = _get_const(keepdim, 'i', 'keepdim')
|
|
# TODO: export it as ReduceMax
|
|
return g.op("ATen",
|
|
self,
|
|
operator_s="min",
|
|
dim_i=dim,
|
|
keepdim_i=keepdim,
|
|
outputs=2)
|
|
|
|
|
|
@wrap_logical_op_with_cast_to_uint8
|
|
def eq(g, self, other):
|
|
return g.op("Equal", self, other)
|
|
|
|
|
|
def ne(g, self, other):
|
|
return g.op("Not", eq(g, self, other))
|
|
|
|
|
|
def exp(g, self):
|
|
return g.op("Exp", self)
|
|
|
|
|
|
@parse_args('v', 'f', 'i')
|
|
def dropout(g, input, p, train):
|
|
r, _ = g.op("Dropout", input, ratio_f=p, outputs=2)
|
|
return r
|
|
|
|
|
|
def _unsupported_dropout(name):
|
|
@parse_args('v', 'f', 'i')
|
|
def feature_dropout(g, input, p, train):
|
|
# NB: In inference mode, FeatureDropout is exported as an identity op.
|
|
from torch.onnx.symbolic import _unimplemented
|
|
if train:
|
|
return _unimplemented(name, "training mode")
|
|
return input
|
|
return feature_dropout
|
|
|
|
|
|
feature_dropout = _unsupported_dropout("feature_dropout")
|
|
alpha_dropout = _unsupported_dropout("alpha_dropout")
|
|
feature_alpha_dropout = _unsupported_dropout("feature_alpha_dropout")
|
|
|
|
# See Note [Export inplace]
|
|
dropout_ = dropout
|
|
feature_dropout_ = feature_dropout
|
|
alpha_dropout_ = alpha_dropout
|
|
feature_alpha_dropout_ = feature_alpha_dropout
|
|
|
|
|
|
@parse_args('v', 't', 'i', 'i')
|
|
def norm(g, self, p, dim, keepdim):
|
|
if p == 1:
|
|
f = _reduce_op_symbolic("ReduceL1")
|
|
elif p == 2:
|
|
f = _reduce_op_symbolic("ReduceL2")
|
|
else:
|
|
        raise RuntimeError("ONNX export only supports p-norms with p of 1 or 2")
|
|
return f(g, self, dim=dim, keepdim=keepdim)
|
|
|
|
|
|
@parse_args('v', 'v', 'v', 'i')
|
|
def conv_tbc(g, input, weight, bias, pad):
|
|
return g.op("ATen", input, weight, bias, operator_s="conv_tbc", pad_i=pad)
|
|
|
|
|
|
@parse_args('v', 'i', 'i')
|
|
def _unique(g, input, sorted, return_inverse):
|
|
return g.op("ATen", input, operator_s="_unique", sorted_i=sorted,
|
|
return_inverse_i=return_inverse, outputs=2)
|
|
|
|
|
|
# Metaprogram symbolics for each ATen native specialized cast operator.
|
|
# For e.g. we specify a function named `_cast_uint8_t` that instantiates an
|
|
# ONNX cast node with `to` attribute 'UINT8'
|
|
#
|
|
# TODO: remove these once we support Type's in the JIT IR and we can once again
|
|
# use the unified toType operator
|
|
cast_pytorch_to_onnx = {
|
|
'Byte': torch.onnx.TensorProtoDataType.UINT8,
|
|
'Char': torch.onnx.TensorProtoDataType.INT8,
|
|
'Double': torch.onnx.TensorProtoDataType.DOUBLE,
|
|
'Float': torch.onnx.TensorProtoDataType.FLOAT,
|
|
'Half': torch.onnx.TensorProtoDataType.FLOAT16,
|
|
'Int': torch.onnx.TensorProtoDataType.INT32,
|
|
'Long': torch.onnx.TensorProtoDataType.INT64,
|
|
'Short': torch.onnx.TensorProtoDataType.INT16,
|
|
}
|
|
|
|
scalar_name_to_pytorch = {
|
|
'uint8_t': 'Byte',
|
|
'int8_t': 'Char',
|
|
'double': 'Double',
|
|
'float': 'Float',
|
|
'half': 'Half',
|
|
'int': 'Int',
|
|
'int64_t': 'Long',
|
|
'int16_t': 'Short',
|
|
}
|
|
|
|
|
|
def _cast_func_template(to_i, g, input, non_blocking):
|
|
return g.op("Cast", input, to_i=to_i)
|
|
|
|
|
|
for k, v in cast_pytorch_to_onnx.items():
|
|
name = '_cast_{}'.format(k)
|
|
globals()[name] = parse_args('v', 'i')(partial(_cast_func_template, v))
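
# The loop above registers module-level symbolics _cast_Byte, _cast_Char,
# _cast_Double, _cast_Float, _cast_Half, _cast_Int, _cast_Long and _cast_Short,
# each of which emits a single ONNX Cast to the corresponding element type.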
|
|
|
|
|
|
scalar_type_to_onnx = [
|
|
cast_pytorch_to_onnx["Byte"],
|
|
cast_pytorch_to_onnx["Char"],
|
|
cast_pytorch_to_onnx["Short"],
|
|
cast_pytorch_to_onnx["Int"],
|
|
cast_pytorch_to_onnx["Long"],
|
|
cast_pytorch_to_onnx["Half"],
|
|
cast_pytorch_to_onnx["Float"],
|
|
cast_pytorch_to_onnx["Double"],
|
|
]
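
# scalar_type_to_onnx is indexed by the integer dtype codes handed to the
# symbolics below (zeros, ones, full, to, ...), so the order above must stay
# in sync with the at::ScalarType ordering: Byte, Char, Short, Int, Long,
# Half, Float, Double.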
|
|
|
|
|
|
@parse_args('v', 'i', 'v', 'v')
|
|
def zeros(g, sizes, dtype, layout, device):
|
|
# NOTE: no way to set device and layout in ONNX, so we ignore it
|
|
return g.op("ConstantFill", sizes, dtype_i=scalar_type_to_onnx[dtype], input_as_shape_i=1, value_f=0)
|
|
|
|
|
|
@parse_args('v', 'i', 'v', 'v')
|
|
def zeros_like(g, input, dtype, layout, device):
|
|
return g.op("ConstantLike", input, dtype_i=scalar_type_to_onnx[dtype], value_f=0.0)
|
|
|
|
|
|
@parse_args('v', 'i', 'v', 'v')
|
|
def ones(g, sizes, dtype, layout, device):
|
|
return g.op("ConstantFill", sizes, dtype_i=scalar_type_to_onnx[dtype], input_as_shape_i=1, value_f=1)
|
|
|
|
|
|
@parse_args('v', 'i', 'v', 'v')
|
|
def ones_like(g, input, dtype, layout, device):
|
|
return g.op("ConstantLike", input, dtype_i=scalar_type_to_onnx[dtype], value_f=1.0)
|
|
|
|
|
|
def full(g, sizes, value, dtype, layout, device):
|
|
const_value = _maybe_get_const(value, 't')
|
|
if _is_value(const_value):
|
|
tmp = zeros(sizes, dtype, layout, device)
|
|
return add(tmp, value, g.op("Constant", value_t=torch.tensor(1)))
|
|
else:
|
|
dtype = _get_const(dtype, 'i', 'dtype')
|
|
return g.op("ConstantFill", sizes, dtype_i=scalar_type_to_onnx[dtype],
|
|
input_as_shape_i=1, value_f=const_value)
|
|
|
|
|
|
@parse_args('v', 'f', 'i', 'v', 'v')
|
|
def full_like(g, input, fill_value, dtype, layout, device):
|
|
return g.op("ConstantLike", input, dtype_i=scalar_type_to_onnx[dtype], value_f=fill_value)
|
|
|
|
|
|
@parse_args('v', 'v', 'v', 'v', 'i')
|
|
def slice(g, self, dim, start, end, step):
|
|
if step != 1:
|
|
_unimplemented("slice", "step!=1 is currently not supported")
|
|
if start.node().kind() != 'onnx::Constant' or \
|
|
end.node().kind() != 'onnx::Constant' or dim.node().kind() != 'onnx::Constant':
|
|
start_unsqueezed = g.op("Unsqueeze", start, axes_i=[0])
|
|
end_unsqueezed = g.op("Unsqueeze", end, axes_i=[0])
|
|
dim_unsqueezed = g.op("Unsqueeze", dim, axes_i=[0])
|
|
return g.op("DynamicSlice", self, start_unsqueezed, end_unsqueezed, dim_unsqueezed)
|
|
else:
|
|
start = _parse_arg(start, 'i')
|
|
end = _parse_arg(end, 'i')
|
|
dim = _parse_arg(dim, 'i')
|
|
return g.op("Slice", self, axes_i=[dim], starts_i=[start], ends_i=[end])
|
|
|
|
|
|
@parse_args('v', 'f', 'f')
|
|
def hardtanh(g, self, min_val, max_val):
|
|
return g.op("Clip", self, min_f=min_val, max_f=max_val)
|
|
|
|
|
|
def alias(g, self):
|
|
return self
|
|
|
|
|
|
@parse_args('v', 'i')
|
|
def unsqueeze(g, self, dim):
|
|
return g.op("Unsqueeze", self, axes_i=[dim])
|
|
|
|
|
|
@parse_args('v', 'i', 'i', 'i', 'i')
|
|
def topk(g, self, k, dim, largest, sorted, out=None):
|
|
if out is not None:
|
|
_unimplemented("TopK", "Out parameter is not supported for topk")
|
|
if not largest:
|
|
_unimplemented("TopK", "Ascending TopK is not supported")
|
|
|
|
return g.op("TopK", self, k_i=k, axis_i=dim, outputs=2)
|
|
|
|
|
|
def to(g, self, *args):
|
|
# ONNX doesn't have a concept of a device, so we ignore device casts
|
|
if len(args) == 3:
|
|
if args[0].type().isSubtypeOf(ListType.ofInts()):
|
|
# aten::to(Tensor, Device, bool, bool)
|
|
return self
|
|
else:
|
|
# aten::to(Tensor, ScalarType, bool, bool)
|
|
dtype = _get_const(args[0], 'i', 'dtype')
|
|
return g.op("Cast", self, to_i=scalar_type_to_onnx[dtype])
|
|
elif len(args) == 4:
|
|
# aten::to(Tensor, Device, ScalarType, bool, bool)
|
|
dtype = _get_const(args[1], 'i', 'dtype')
|
|
return g.op("Cast", self, to_i=scalar_type_to_onnx[dtype])
|
|
elif len(args) == 5:
|
|
# aten::to(Tensor, ScalarType, Layout, Device, bool, bool) -> Tensor
|
|
dtype = _get_const(args[0], 'i', 'dtype')
|
|
# Layout and device are ignored
|
|
return g.op("Cast", self, to_i=scalar_type_to_onnx[dtype])
|
|
else:
|
|
raise NotImplementedError("Unknown aten::to signature")
|
|
|
|
|
|
def repeat(g, self, repeats):
|
|
if not _is_value(repeats):
|
|
repeats = g.op("Constant", value_t=torch.LongTensor(repeats))
|
|
const_repeats = _maybe_get_const(repeats, 'is')
|
|
|
|
if self.isTensor() and not _is_value(const_repeats):
|
|
sizes = self.type().sizes()
|
|
diff_dims = len(const_repeats) - len(sizes)
|
|
if diff_dims > 0:
|
|
self = view(g, self, [1] * diff_dims + sizes)
|
|
return g.op("Tile", self, repeats)
|
|
|
|
|
|
@parse_args('v', 'i')
|
|
def pixel_shuffle(g, self, upscale_factor):
|
|
dims = self.type().sizes()
|
|
if len(dims) != 4:
|
|
return _unimplemented("pixel_shuffle", "only support 4d input")
|
|
output_channel = dims[1] // upscale_factor // upscale_factor
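    # overall this maps (N, C * r * r, H, W) to (N, C, H * r, W * r) where
    # r == upscale_factor, matching the semantics of torch.nn.functional.pixel_shuffle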
|
|
after_view = view(g, self, [-1, upscale_factor, upscale_factor,
|
|
output_channel, dims[2], dims[3]])
|
|
after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 4, 2, 5, 3])
|
|
return view(g, after_transpose,
|
|
[-1, output_channel, dims[2] * upscale_factor, dims[3] *
|
|
upscale_factor])
|
|
|
|
|
|
def _generic_rnn(g, variant, input, initial_states, all_weights, has_biases,
|
|
num_layers, dropout, train, bidirectional, batch_first=None, batch_sizes=None):
|
|
weights_per_layer = 4 if has_biases else 2
|
|
assert len(all_weights) == num_layers * weights_per_layer * (1 + bidirectional)
|
|
layer_weights = [all_weights[i:i + weights_per_layer] for i in range(0, len(all_weights), weights_per_layer)]
|
|
if batch_first:
|
|
return _unimplemented("RNN/GRU/LSTM", "batch_first")
|
|
if dropout and train:
|
|
return _unimplemented("RNN/GRU/LSTM", "dropout in training mode")
|
|
|
|
if variant.startswith('RNN'):
|
|
nonlinearity = variant[4:].lower()
|
|
variant = 'RNN'
|
|
|
|
w_hh = all_weights[1]
|
|
hidden_size = w_hh.type().sizes()[1]
|
|
|
|
unidirectional = not bidirectional
|
|
|
|
prev_output = input
|
|
|
|
h_outs = []
|
|
if variant == 'RNN' or variant == 'GRU':
|
|
h0 = initial_states
|
|
elif variant == 'LSTM':
|
|
h0, c0 = initial_states
|
|
c_outs = []
|
|
|
|
sequence_lens = unused(g) if batch_sizes is None else batch_sizes
|
|
|
|
if variant == 'GRU':
|
|
# pytorch is reset, input, hidden
|
|
# onnx is input, reset, hidden
|
|
reform_permutation = [(1, 2), (0, 1), (2, 3)]
|
|
elif variant == 'LSTM':
|
|
# pytorch is input, forget, cell, output.
|
|
# onnx is input, output, forget, cell.
|
|
reform_permutation = [(0, 1), (3, 4), (1, 3)]
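    # each (x, y) interval above selects rows [x * hidden_size, y * hidden_size)
    # of the stacked gate weights; reform_weights below concatenates those
    # slices to reorder the gates into the layout ONNX expects.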
|
|
|
|
def reform_weights(g, w, n, intervals):
|
|
slices = [g.op('Slice', w, axes_i=[0], starts_i=[x * n], ends_i=[y * n]) for x, y in intervals]
|
|
return g.op('Concat', *slices, axis_i=0)
|
|
|
|
def transform_weights(layer_index):
|
|
if variant == 'RNN':
|
|
weight_ih, weight_hh, bias_ih, bias_hh = layer_weights[layer_index]
|
|
elif variant == 'GRU' or variant == 'LSTM':
|
|
weight_ih, weight_hh, bias_ih, bias_hh = \
|
|
[reform_weights(g, w, hidden_size, reform_permutation) for w in layer_weights[layer_index]]
|
|
bias_concat = g.op('Concat', bias_ih, bias_hh, axis_i=0)
|
|
|
|
return tuple(g.op('Unsqueeze', x, axes_i=[0]) for x in (weight_ih, weight_hh, bias_concat))
|
|
|
|
def retrieve_state(x, start, end):
|
|
return x if num_layers == 1 else g.op('Slice', x, axes_i=[0], starts_i=[start], ends_i=[end])
|
|
|
|
for i in range(num_layers):
|
|
if unidirectional:
|
|
weight_ih, weight_hh, bias_concat = transform_weights(i)
|
|
state_indices = i, i + 1
|
|
else:
|
|
weight_ih_f, weight_hh_f, bias_f = transform_weights(2 * i)
|
|
weight_ih_b, weight_hh_b, bias_b = transform_weights(2 * i + 1)
|
|
|
|
weight_ih = g.op('Concat', weight_ih_f, weight_ih_b, axis_i=0)
|
|
weight_hh = g.op('Concat', weight_hh_f, weight_hh_b, axis_i=0)
|
|
bias_concat = g.op('Concat', bias_f, bias_b, axis_i=0)
|
|
|
|
state_indices = 2 * i, 2 * i + 2
|
|
|
|
inputs = [prev_output, weight_ih, weight_hh, bias_concat, sequence_lens]
|
|
|
|
inputs.append(retrieve_state(h0, *state_indices))
|
|
if variant == 'LSTM':
|
|
inputs.append(retrieve_state(c0, *state_indices))
|
|
|
|
extra_kwargs = {} if unidirectional else {'direction_s': 'bidirectional'}
|
|
if variant == 'RNN':
|
|
prev_output, h_out = g.op('RNN', *inputs, outputs=2,
|
|
hidden_size_i=hidden_size,
|
|
activations_s=[nonlinearity],
|
|
**extra_kwargs)
|
|
elif variant == 'GRU':
|
|
prev_output, h_out = g.op('GRU', *inputs, outputs=2,
|
|
hidden_size_i=hidden_size,
|
|
linear_before_reset_i=1,
|
|
**extra_kwargs)
|
|
elif variant == 'LSTM':
|
|
prev_output, h_out, c_out = g.op('LSTM', *inputs, outputs=3,
|
|
hidden_size_i=hidden_size,
|
|
**extra_kwargs)
|
|
|
|
if bidirectional:
|
|
# The ONNX RNN/GRU/LSTM produce an output of dimensions
|
|
# seq_len, num_directions, batch, hidden_size
|
|
# We have to convert to match pytorch's expected
|
|
# seq_len, batch, num_directions * hidden_size
|
|
# by first moving num_directions before hidden_size with
|
|
# Transpose, and then combining it with hidden_size
|
|
# with Reshape.
|
|
prev_output = g.op('Transpose', prev_output, perm_i=[0, 2, 1, 3])
|
|
prev_output = g.op('Reshape', prev_output, g.op('Constant', value_t=torch.LongTensor([0, 0, -1])))
|
|
else:
|
|
prev_output = g.op('Squeeze', prev_output, axes_i=[1])
|
|
|
|
h_outs.append(h_out)
|
|
if variant == 'LSTM':
|
|
c_outs.append(c_out)
|
|
h_outs = h_out if num_layers == 1 else g.op('Concat', *h_outs, axis_i=0)
|
|
if variant == 'RNN' or variant == 'GRU':
|
|
return prev_output, h_outs
|
|
elif variant == 'LSTM':
|
|
c_outs = c_out if num_layers == 1 else g.op('Concat', *c_outs, axis_i=0)
|
|
return prev_output, h_outs, c_outs
|
|
|
|
|
|
@parse_args('v', 'v', 'v', 'i', 'i', 'f', 'i', 'i', 'i')
|
|
def _lstm_full(g, input, hidden_v, weight_v, has_biases, num_layers, dropout, train, bidirectional, batch_first):
|
|
hidden, weight = _unpack_list(hidden_v), _unpack_list(weight_v)
|
|
return _generic_rnn(g, 'LSTM', input, hidden, weight, has_biases, num_layers,
|
|
dropout, train, bidirectional, batch_first)
|
|
|
|
|
|
@parse_args('v', 'v', 'v', 'v', 'i', 'i', 'f', 'i', 'i')
|
|
def _lstm_packed(g, input, batch_sizes, hidden_v, weight_v, has_biases, num_layers, dropout, train, bidirectional):
|
|
hidden, weight = _unpack_list(hidden_v), _unpack_list(weight_v)
|
|
return _generic_rnn(g, 'LSTM', input, hidden, weight, has_biases, num_layers,
|
|
dropout, train, bidirectional, batch_sizes=batch_sizes)
|
|
|
|
|
|
def lstm(g, *args):
|
|
if _is_tensor_list(args[3]):
|
|
return _lstm_packed(g, *args)
|
|
else:
|
|
return _lstm_full(g, *args)
|
|
|
|
|
|
def _one_hidden_rnn(kind):
|
|
@parse_args('v', 'v', 'v', 'i', 'i', 'f', 'i', 'i', 'i')
|
|
def _rnn_full(g, input, hidden, weight_v, has_biases, num_layers, dropout, train, bidirectional, batch_first):
|
|
weight = _unpack_list(weight_v)
|
|
return _generic_rnn(g, kind, input, hidden, weight, has_biases, num_layers,
|
|
dropout, train, bidirectional, batch_first)
|
|
|
|
@parse_args('v', 'v', 'v', 'v', 'i', 'i', 'f', 'i', 'i')
|
|
def _rnn_packed(g, input, batch_sizes, hidden, weight_v, has_biases, num_layers, dropout, train, bidirectional):
|
|
weight = _unpack_list(weight_v)
|
|
return _generic_rnn(g, kind, input, hidden, weight, has_biases, num_layers,
|
|
dropout, train, bidirectional, batch_sizes=batch_sizes)
|
|
|
|
def symbolic(g, *args):
|
|
if _is_tensor_list(args[3]):
|
|
return _rnn_packed(g, *args)
|
|
else:
|
|
return _rnn_full(g, *args)
|
|
|
|
return symbolic
|
|
|
|
|
|
gru = _one_hidden_rnn('GRU')
|
|
rnn_tanh = _one_hidden_rnn('RNN_TANH')
|
|
rnn_relu = _one_hidden_rnn('RNN_RELU')
|
|
|
|
|
|
@parse_args('v', 'i')
|
|
def _dim_arange(g, like, dim):
|
|
return g.op('ATen', like, dim_i=dim, operator_s='_dim_arange')
|
|
|
|
|
|
def detach(g, input):
|
|
# Erase aten::detach nodes because ONNX is inference only
|
|
return input
|
|
|
|
|
|
def contiguous(g, input):
|
|
return input
|
|
|
|
|
|
@parse_args('v', 'v', 'i')
|
|
def _pack_padded_sequence(g, input, lengths, batch_first):
|
|
# There currently is no PackPadded operator in ONNX. We rely on an
|
|
# optimization pass to remove this later. It is an error if all
|
|
# PackPadded operators cannot be optimized out.
|
|
if batch_first:
|
|
input = g.op('Transpose', input, perm_i=[1, 0, 2])
|
|
if not lengths.type().isSubtypeOf(torch._C.DynamicType.get()):
|
|
raise RuntimeError("Lengths must be a Tensor for ONNX export")
|
|
# We know it's a TensorType so this check is now safe.
|
|
    # It's really only necessary because those operators expand to something that
|
|
# only works with int32 types in Caffe2...
|
|
if lengths.type().scalarType() != 'Int':
|
|
lengths = _cast_Int(g, lengths, False)
|
|
return g.op("prim::PackPadded", input, lengths, outputs=2)
|
|
|
|
|
|
@parse_args('v', 'v', 'i', 't', 'v')
|
|
def _pad_packed_sequence(g, data, batch_sizes, batch_first, padding_value, total_length):
|
|
# Ignore total_length as it is not supported in _symbolic_pad_packed_sequence
|
|
# It is only useful/used when training using data_parallel model, so
|
|
# It shouldn't be relevant for ONNX anyway
|
|
data, lengths = g.op("prim::PadPacked", data, batch_sizes, outputs=2)
|
|
if batch_first:
|
|
data = g.op('Transpose', data, perm_i=[1, 0, 2])
|
|
return data, lengths
|
|
|
|
|
|
def randn(g, *shapes):
|
|
shapes_list = list(shapes)
|
|
shape = _maybe_get_const(shapes_list[0], "is")
|
|
return g.op('RandomNormal', shape_i=shape)
|
|
|
|
|
|
@parse_args('v', 'f', 'f', 'i', 'none')
|
|
def rrelu(g, input, lower, upper, training, generator):
|
|
p = g.op('RandomUniformLike', input, high_f=upper, low_f=lower)
|
|
return g.op('PRelu', input, p)
|
|
|
|
|
|
@parse_args('v')
|
|
def log_sigmoid(g, input):
|
|
p = g.op('Sigmoid', input)
|
|
return g.op('Log', p)
|