pytorch/torch/onnx/symbolic.py
Spandan Tiwari f0f9277c3c Fixing ONNX export of logical ops to have correct output datatype (#15185)
Summary:
Currently the PyTorch ONNX exporter exports the logical ops (`lt`, `gt`, `le`, `ge`, `eq`) with the output type of the corresponding ONNX ops set to `tensor(uint8)`. But the ONNX spec allows only `tensor(bool)`, which is why models that contain these ops fail to load properly.

This issue is captured in https://github.com/pytorch/pytorch/issues/11339. The part of this issue relating to the allowed input types has already been fixed in the ONNX spec by houseroad. This PR fixes the other part, pertaining to the output type.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15185

Differential Revision: D13494873

Pulled By: houseroad

fbshipit-source-id: 069d2f956a5ae9bf0ac2540a32594a31b01adef8
2018-12-20 12:37:27 -08:00


import numbers
import torch
from torch._C import DynamicType, ListType
from torch.nn.modules.utils import _single, _pair, _triple
from torch.nn.utils.rnn import PackedSequence
import warnings
import torch.onnx
# This import monkey-patches graph manipulation methods on Graph, used for the
# ONNX symbolics
import torch.onnx.utils
from collections import Iterable
from functools import partial, wraps
import itertools
# EDITING THIS FILE? READ THIS FIRST!
#
# - This file is ONLY for ATen operators (e.g., operators that show up in the
# trace as aten::blah). If you need to special case a primitive operator,
# look at _run_symbolic_function
# - Parameter ordering does NOT necessarily match what is in VariableType.cpp;
# tensors are always first, then non-tensor arguments.
# - Parameter names must *exactly* match the names in VariableType.cpp, because
# dispatch is done with keyword arguments.
# - Looking for inplace ops? They're detected by the trailing underscore, and
# transparently dispatched to their non-inplace versions in
# '_run_symbolic_function'. See Note [Export inplace]
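#
# As an illustrative sketch only (the op below is hypothetical and not
# registered anywhere), a symbolic following the conventions above looks like:
#
#     @parse_args('v', 't')
#     def my_threshold(g, self, threshold):
#         # tensors first, parameter names matching VariableType.cpp,
#         # ONNX nodes emitted through g.op
#         return g.op("Clip", self, min_f=_scalar(threshold))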
# ---------------------------------------------------------------------
# Helper functions
# ---------------------------------------------------------------------
# Save some builtins as locals, because we'll shadow them below
_sum = sum
def _parse_arg(value, desc):
if desc == 'none':
return value
if desc == 'v' or not _is_value(value):
return value
if value.node().kind() != 'onnx::Constant':
raise RuntimeError("ONNX symbolic expected a constant value in the trace")
tval = value.node()['value']
if desc == 'i':
return int(tval)
elif desc == 'f':
return float(tval)
elif desc == 't':
return tval
elif desc == 'is':
return [int(v) for v in tval]
else:
raise RuntimeError("Casting constants to `{}` is not implemented".format(desc))
def _maybe_get_const(value, desc):
if _is_value(value) and value.node().kind() == 'onnx::Constant':
return _parse_arg(value, desc)
return value
def _maybe_get_scalar(value):
value_t = _maybe_get_const(value, 't')
if isinstance(value_t, torch.Tensor) and value_t.shape == ():
return value_t
return value
def _get_const(value, desc, arg_name):
if _is_value(value) and value.node().kind() != 'onnx::Constant':
raise RuntimeError("ONNX symbolic expected a constant value of the {} argument".format(arg_name))
return _parse_arg(value, desc)
def _unpack_list(list_value):
list_node = list_value.node()
assert list_node.kind() == "prim::ListConstruct"
return list(list_node.inputs())
def parse_args(*arg_descriptors):
def decorator(fn):
def wrapper(g, *args):
assert len(arg_descriptors) == len(args)
args = [_parse_arg(arg, arg_desc) for arg, arg_desc in zip(args, arg_descriptors)]
return fn(g, *args)
# In Python 2 functools.wraps chokes on partially applied functions, so we need this as a workaround
try:
wrapper = wraps(fn)(wrapper)
except Exception:
pass
return wrapper
return decorator
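# A minimal usage sketch of parse_args (mirroring the real decorations below):
# 'v' leaves an argument as a graph Value, while 'i' requires it to come from an
# onnx::Constant node and converts it to a Python int before the body runs:
#
#     @parse_args('v', 'i')
#     def my_flatten(g, input, axis):   # hypothetical example symbolic
#         return g.op("Flatten", input, axis_i=axis)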
def _scalar(x):
"""Convert a scalar tensor into a Python value."""
assert x.numel() == 1
return x.item()
def _if_scalar_type_as(g, self, tensor):
"""
Convert self into the same type of tensor, as necessary.
We only support implicit casting for scalars, so we never
actually need to insert an ONNX cast operator here; just
fix up the scalar.
"""
if isinstance(self, torch._C.Value):
return self
elif tensor.type().kind() == "TensorType" or tensor.type().kind() == "CompleteTensorType":
ty = tensor.type().scalarType().lower()
return getattr(self, ty)()
else:
return self
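# Illustrative example (a sketch, assuming `self` arrived as a 1-element
# LongTensor produced from a Python literal and `x` is a traced float tensor
# Value):
#
#     _if_scalar_type_as(g, torch.tensor(2), x)   # -> torch.tensor(2).float()
#
# i.e. the scalar is converted with .float()/.double()/etc. to match the other
# operand's element type; no ONNX Cast node is inserted.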
def _is_value(x):
return isinstance(x, torch._C.Value)
def _is_tensor_list(x):
return x.type().isSubtypeOf(ListType.ofTensors())
def _unimplemented(op, msg):
warnings.warn("ONNX export failed on " + op + " because " + msg + " not supported")
def _try_get_scalar_type(*args):
for arg in args:
try:
return arg.type().scalarType()
except RuntimeError:
pass
return None
# ---------------------------------------------------------------------
# ONNX operator version
# ---------------------------------------------------------------------
# READ ME BEFORE EDITING _onnx_opset_version:
#
# The variable below controls which ONNX operator set version we are
# targeting. THIS VARIABLE HAS SEMANTIC EFFECT! Say a breaking
# change occurred in version 8. As long as this variable < 8, you can
# export models targeting the old behavior. However, if you bump
# this variable to 8 or later, the breaking change will take effect:
# you MUST adjust any symbolic affected by breaking changes. The ONNX
# spec publishes a *comprehensive* list of BC-breaking changes for every
# operator revision at:
#
# https://github.com/onnx/onnx/blob/master/docs/Changelog.md
#
# Please be sure to go through and check all of our implementations here before
# increasing this number. This includes symbolic definitions NOT in this
# file, so grep for "OpName" (with quotes)
_onnx_opset_version = 9
# ---------------------------------------------------------------------
# Symbolic definitions
# ---------------------------------------------------------------------
# Note [Pointwise by scalar]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
# What happens if you add a tensor with a constant (e.g., x + 2)? There are
# some moving parts to implementing the ONNX translation in this case:
#
# - By the time we get the scalar in a symbolic function here, it is no longer
# a Python long/float, but a PyTorch tensor with numel == 1 (eventually, we
# want it to be a zero dim tensor but this change has not happened yet.)
# However, the type of this scalar is *exactly* what the user wrote in
# Python, which may not match the tensor it is being added to. PyTorch
# will do implicit conversions on scalars; however, ONNX will not, so
# we must do the conversion ourselves. This is what _if_scalar_type_as
# does.
#
# - Dispatch to these functions takes advantage of an outrageous coincidence
# between the tensor and scalar name. When we add two tensors together,
# you get the dispatch:
#
# add(*[self, other], **{"alpha": alpha})
#
# When you add a tensor and a scalar, you get the dispatch:
#
# add(*[self], **{"other": other, "alpha": alpha})
#
# By having the argument name line up with the name of the scalar attribute
# if it exists, we can write a single function for both overloads.
#
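#
# As a concrete (and deliberately rough) sketch: exporting `x + 2` for a float
# tensor x reaches `add` below with `other` holding the constant 2, and the
# resulting ONNX graph is approximately
#
#     %c = onnx::Constant[value=2.]()   # scalar fixed up to x's float type
#     %y = onnx::Add(%x, %c)
#
# whereas adding two traced tensors emits a plain onnx::Add(%x, %y).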
# used to represent "missing" optional inputs
def unused(g):
return g.op("prim::Undefined")
def _shape_as_tensor(g, input):
return g.op('Shape', input)
def _reshape_from_tensor(g, input, shape):
return g.op('Reshape', input, shape)
def add(g, self, other, alpha=None):
# the default alpha arg allows no-alpha add (the aten add overload without alpha)
if alpha and _scalar(_maybe_get_scalar(alpha)) != 1:
return _unimplemented("add", "alpha != 1")
# See Note [Pointwise by scalar]
other = _maybe_get_scalar(other)
return g.op("Add", self, _if_scalar_type_as(g, other, self))
def sub(g, self, other, alpha=None):
# the default alpha arg allows no-alpha sub (the aten sub overload without alpha)
if alpha and _scalar(_maybe_get_scalar(alpha)) != 1:
return _unimplemented("sub", "alpha != 1")
# See Note [Pointwise by scalar]. Note that self or other may be scalars.
other = _maybe_get_scalar(other)
return g.op("Sub", self, _if_scalar_type_as(g, other, self))
def rsub(g, self, other, alpha=None):
return sub(g, other, self, alpha=alpha)
def mul(g, self, other):
# See Note [Pointwise by scalar]
other = _maybe_get_scalar(other)
return g.op("Mul", self, _if_scalar_type_as(g, other, self))
def div(g, self, other):
# See Note [Pointwise by scalar]
other = _maybe_get_scalar(other)
return g.op("Div", self, _if_scalar_type_as(g, other, self))
def reciprocal(g, self):
return g.op("Div", _if_scalar_type_as(g, torch.ones(1), self), self)
@parse_args('v', 'i')
def cat(g, tensor_list, dim):
tensors = _unpack_list(tensor_list)
return g.op("Concat", *tensors, axis_i=dim)
@parse_args('v', 'i')
def stack(g, tensor_list, dim):
unsqueezed = [g.op("Unsqueeze", t, axes_i=[dim]) for t in _unpack_list(tensor_list)]
return g.op("Concat", *unsqueezed, axis_i=dim)
def mm(g, self, other):
# Create a dummy C tensor. Only needed for API purposes; the value is
# irrelevant since beta = 0
ty = _try_get_scalar_type(self, other).lower()
C = g.constant(0, [1], ty)
return g.op("Gemm", self, other, C, beta_f=0.0, alpha_f=1.0)
def bmm(g, self, other):
return g.op("MatMul", self, other)
def matmul(g, self, other):
return g.op("MatMul", self, other)
@parse_args('v', 'v', 'v', 't', 't')
def addmm(g, self, mat1, mat2, beta, alpha):
return g.op("Gemm", mat1, mat2, self, beta_f=_scalar(beta), alpha_f=_scalar(alpha))
def neg(g, self):
return g.op("Neg", self)
def sqrt(g, self):
return g.op("Sqrt", self)
def tanh(g, self):
return g.op("Tanh", self)
def sin(g, self):
return g.op("Sin", self)
def cos(g, self):
return g.op("Cos", self)
def tan(g, self):
return g.op("Tan", self)
def asin(g, self):
return g.op("Asin", self)
def acos(g, self):
return g.op("Acos", self)
def atan(g, self):
return g.op("Atan", self)
def sigmoid(g, self):
return g.op("Sigmoid", self)
def _reduce_op_symbolic(onnx_op_name):
def symbolic(g, self, dim=None, keepdim=None):
if dim is None:
# all-reduce path
return g.op(onnx_op_name, self, keepdims_i=0)
else:
# dim-reduce path
dim, keepdim = _get_const(dim, 'i', 'dim'), _get_const(keepdim, 'i', 'keepdim')
return g.op(onnx_op_name, self, axes_i=[dim], keepdims_i=keepdim)
return symbolic
mean = _reduce_op_symbolic('ReduceMean')
sum = _reduce_op_symbolic('ReduceSum')
prod = _reduce_op_symbolic('ReduceProd')
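# Illustrative mapping for the reduce symbolics above (a sketch):
#
#     torch.sum(x)                      -> ReduceSum(x, keepdims=0)    # all-reduce
#     torch.sum(x, dim=1, keepdim=True) -> ReduceSum(x, axes=[1], keepdims=1)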
@parse_args('v', 'i')
def cumsum(g, input, dim):
return g.op("ATen", input, operator_s="cumsum", dim_i=dim)
def t(g, self):
return g.op("Transpose", self, perm_i=(1, 0))
def expand(g, self, size, implicit):
size = _maybe_get_const(size, 'is')
if not _is_value(size):
size = g.op("Constant", value_t=torch.LongTensor(size))
return g.op("Expand", self, size)
def expand_as(g, self, other):
shape = g.op("Shape", other)
return g.op("Expand", self, shape)
def embedding(g, weight, indices, padding_idx, scale_grad_by_freq, sparse):
return g.op("Gather", weight, indices)
@parse_args('v', 'v', 'v', 'i', 'i', 'i')
def embedding_bag(g,
embedding_matrix,
indices,
offsets,
scale_grad_by_freq,
mode,
sparse):
return g.op("ATen",
embedding_matrix,
indices,
offsets,
operator_s="embedding_bag",
outputs=4,
scale_grad_by_freq_i=scale_grad_by_freq,
mode_i=mode,
sparse_i=sparse)
def size(g, self, dim):
full_shape = g.op("Shape", self)
return select(g, full_shape, g.op("Constant", value_t=torch.tensor([0])), dim)
@parse_args('v', 'i', 'i')
def transpose(g, self, dim0, dim1):
if dim0 == dim1: # micro-optimization
return self
# NB: Transpose in ONNX is actually a Permute
axes = list(range(len(self.type().sizes())))
axes[dim0], axes[dim1] = axes[dim1], axes[dim0]
return g.op("Transpose", self, perm_i=axes)
@parse_args('v', 'is')
def permute(g, self, dims):
if dims == list(range(0, len(dims))):
return self
return g.op("Transpose", self, perm_i=dims)
def view(g, self, size):
size = _maybe_get_const(size, 'is')
if _is_value(size):
shape = size
else:
if self.isTensor():
self_sizes = self.type().sizes()
if self_sizes and len(size) == 2 and self_sizes[0] == size[0]:
return g.op("Flatten", self, axis_i=1)
shape = g.op("Constant", value_t=torch.LongTensor(size))
return g.op("Reshape", self, shape)
def prim_ConstantSplit(g, self, split_size, dim):
size = self.type().sizes()[dim]
splits = [split_size] * (size // split_size)
leftover = size % split_size
if leftover:
splits.append(leftover)
return g.op("Split", self, split_i=splits, axis_i=dim, outputs=len(splits))
# TODO: It would be better to export this as a chunk directly, as this is
# less sensitive to changes in input size.
# TODO: Once we have proper scoping, stop reimplementing chunk, delete this
# method, and use the desugared version
def prim_ConstantChunk(g, self, chunks, dim):
split_size = (self.type().sizes()[dim] + chunks - 1) // chunks
return prim_ConstantSplit(g, self, split_size, dim)
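# Worked example for the two helpers above: splitting a dimension of size 10
# with split_size=4 yields splits=[4, 4, 2] (two full splits plus the leftover),
# and chunking that same dimension into chunks=3 computes
# split_size = (10 + 3 - 1) // 3 = 4 and then reuses the same Split lowering.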
@parse_args('v', 'i', 'v')
def select(g, self, dim, index):
if dim > 1:
# TODO: this is a temporary hack because of the implementation details
# of Gather in caffe2. We need to change this as soon as possible.
# TODO: this breaks if index == -1
index_val = _parse_arg(index, 'i')
slice_node = g.op("Slice", self, axes_i=[dim], starts_i=[index_val], ends_i=[index_val + 1])
return g.op("Squeeze", slice_node, axes_i=[dim])
else:
return g.op("Gather", self, index, axis_i=dim)
def squeeze(g, self, dim=None):
if dim is None:
dims = []
for i, size in enumerate(self.type().sizes()):
if size == 1:
dims.append(i)
else:
dims = [_get_const(dim, 'i', 'dim')]
return g.op("Squeeze", self, axes_i=dims)
def prelu(g, self, weight):
return g.op("PRelu", self, weight)
def relu(g, input):
return g.op("Relu", input)
@parse_args('v', 't', 't')
def threshold(g, self, threshold, value):
# See Note [Export inplace]
if _scalar(threshold) != 0:
return _unimplemented("threshold", "non-zero threshold")
if _scalar(value) != 0:
return _unimplemented("threshold", "non-zero value")
return g.op("Relu", self)
def leaky_relu(g, input, negative_slope, inplace=False):
negative_slope = _get_const(negative_slope, 't', 'negative_slope')
# See Note [Export inplace]
# TODO: Talk to ONNX about unconditional cast of scalar to float
return g.op("LeakyRelu", input, alpha_f=_scalar(negative_slope))
@parse_args('v', 'i')
def glu(g, input, dim):
assert input.type().sizes()[dim] % 2 == 0
first, second = g.op('Split', input, axis_i=dim, outputs=2)
return g.op('Mul', first, g.op('Sigmoid', second))
@parse_args('v', 'i')
def softmax(g, input, dim):
# Softmax does normalization at vector level.
# PyTorch and ONNX use different strategies to split the input tensor into vectors.
# Thus dim and axis have different meanings.
# PyTorch slices the input tensor into vectors along the `dim`-th dimension.
# ONNX reshapes the input into a 2-D tensor, and `axis` indicates where the input is coerced.
# If input is a 2 x 3 tensor:
# input = [[1.0, 1.0, 1.0],
# [1.0, 1.0, 1.0]]
# with dim = 0, the result is:
# result = [[0.5, 0.5, 0.5],
# [0.5, 0.5, 0.5]]
# with axis = 0, the result is:
# result = [[0.167, 0.167, 0.167],
# [0.167, 0.167, 0.167]]
# Their semantics are equivalent only when dim and axis both equal
# ndim - 1 (the last dimension).
if dim < 0:
dim = len(input.type().sizes()) + dim
if len(input.type().sizes()) != dim + 1:
return _unimplemented("dim", "ONNX and PyTorch use different strategies to split the input.")
return g.op('Softmax', input, axis_i=dim)
@parse_args('v', 't', 'v')
def softplus(g, self, beta, threshold):
if beta != 1:
return _unimplemented("softplus", "beta != 1")
return g.op('Softplus', self)
@parse_args('v', 'is', 'is', 'is', 'is', 'i')
def max_pool1d_with_indices(g, input, kernel_size, stride, padding, dilation, ceil_mode):
if ceil_mode:
return _unimplemented("max_pool1d_with_indices", "ceil_mode")
if set(_single(dilation)) != {1}:
return _unimplemented("max_pool1d_with_indices", "dilation")
if stride is None:
stride = kernel_size
r = g.op("MaxPool", input,
kernel_shape_i=_single(kernel_size),
pads_i=_single(padding) * 2,
strides_i=_single(stride))
return r, None
@parse_args('v', 'is', 'is', 'is', 'is', 'i')
def max_pool2d_with_indices(g, input, kernel_size, stride, padding, dilation, ceil_mode):
if ceil_mode:
return _unimplemented("max_pool2d_with_indices", "ceil_mode")
if set(_pair(dilation)) != {1}:
return _unimplemented("max_pool2d_with_indices", "dilation")
if not stride:
stride = kernel_size
r = g.op("MaxPool", input,
kernel_shape_i=_pair(kernel_size),
pads_i=_pair(padding) * 2,
strides_i=_pair(stride))
return r, None
@parse_args('v', 'is', 'is', 'is', 'is', 'i')
def max_pool3d_with_indices(g, input, kernel_size, stride, padding, dilation, ceil_mode):
if ceil_mode:
return _unimplemented("max_pool3d_with_indices", "ceil_mode")
if set(_triple(dilation)) != {1}:
return _unimplemented("max_pool3d_with_indices", "dilation")
if not stride:
stride = kernel_size
r = g.op("MaxPool", input,
kernel_shape_i=_triple(kernel_size),
pads_i=_triple(padding) * 2,
strides_i=_triple(stride))
return r, None
def _avg_pool(name, tuple_fn):
@parse_args('v', 'is', 'is', 'is', 'i', 'i')
def symbolic_fn(g, input, kernel_size, stride, padding, ceil_mode, count_include_pad):
if ceil_mode:
return _unimplemented(name, "ceil_mode")
if not stride:
stride = kernel_size
padding = tuple(tuple_fn(padding))
if count_include_pad:
input = g.op("Pad", input,
pads_i=((0,) * 2 + padding) * 2,
mode_s='constant',
value_f=0.)
padding = (0,) * len(padding)
return g.op("AveragePool", input,
kernel_shape_i=tuple_fn(kernel_size),
strides_i=tuple_fn(stride),
pads_i=padding * 2)
return symbolic_fn
avg_pool1d = _avg_pool('avg_pool1d', _single)
avg_pool2d = _avg_pool('avg_pool2d', _pair)
avg_pool3d = _avg_pool('avg_pool3d', _triple)
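# Illustrative note on the count_include_pad handling above: ONNX pads are laid
# out as all begin-paddings followed by all end-paddings. For avg_pool2d on an
# NCHW input with padding=(1, 2) and count_include_pad=True, the symbolic above
# first emits an explicit constant Pad with
#
#     pads_i = (0, 0, 1, 2, 0, 0, 1, 2)
#
# and then an AveragePool with zero pads, so the padded zeros take part in the
# averaging just as they do in PyTorch.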
@parse_args('v', 'is')
def adaptive_avg_pool2d(g, input, output_size):
assert output_size == [1, 1], "Only output_size=[1, 1] is supported"
return g.op("GlobalAveragePool", input)
@parse_args('v', 'is')
def adaptive_max_pool2d(g, input, output_size):
assert output_size == [1, 1], "Only output_size=[1, 1] is supported"
return g.op("GlobalMaxPool", input), None
@parse_args('v', 'is', 'f')
def constant_pad_nd(g, input, padding, value):
from torch.autograd._functions.utils import prepare_onnx_paddings
mode = "constant"
paddings = prepare_onnx_paddings(len(input.type().sizes()), padding)
return g.op("Pad", input, pads_i=paddings, mode_s=mode, value_f=value)
@parse_args('v', 'is')
def reflection_pad(g, input, padding):
from torch.autograd._functions.utils import prepare_onnx_paddings
mode = "reflect"
paddings = prepare_onnx_paddings(len(input.type().sizes()), padding)
return g.op("Pad", input, pads_i=paddings, mode_s=mode)
@parse_args('v', 'is')
def replication_pad(g, input, padding):
from torch.autograd._functions.utils import prepare_onnx_paddings
mode = "edge"
paddings = prepare_onnx_paddings(len(input.type().sizes()), padding)
return g.op("Pad", input, pads_i=paddings, mode_s=mode)
reflection_pad1d = reflection_pad
reflection_pad2d = reflection_pad
reflection_pad3d = reflection_pad
replication_pad1d = replication_pad
replication_pad2d = replication_pad
replication_pad3d = replication_pad
@parse_args('v', 'is')
def upsample_nearest2d(g, input, output_size):
height_scale = float(output_size[-2]) / input.type().sizes()[-2]
width_scale = float(output_size[-1]) / input.type().sizes()[-1]
scales = g.op("Constant", value_t=torch.tensor([1., 1., height_scale,
width_scale]))
return g.op("Upsample", input, scales,
mode_s="nearest")
@parse_args('v', 'is', 'i')
def upsample_bilinear2d(g, input, output_size, align_corners):
if align_corners:
return _unimplemented("upsample_bilinear2d", "align_corners == True")
height_scale = float(output_size[-2]) / input.type().sizes()[-2]
width_scale = float(output_size[-1]) / input.type().sizes()[-1]
scales = g.op("Constant", value_t=torch.tensor([1., 1., height_scale,
width_scale]))
return g.op("Upsample", input, scales,
mode_s="linear")
def wrap_logical_op_with_cast_to_uint8(func):
def wrap_with_cast(g, input, other):
return g.op("Cast", func(g, input, other), to_i=cast_pytorch_to_onnx['Byte'])
return wrap_with_cast
def wrap_logical_op_with_negation(func):
def wrap_with_not(g, input, other):
return g.op("Not", func(g, input, other))
return wrap_with_not
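# Illustrative sketch of how the two wrappers compose (see `ge` below): ONNX
# comparison ops produce tensor(bool), while the surrounding PyTorch graph
# expects uint8, so `x >= y` is exported roughly as
#
#     %lt  = onnx::Less(%x, %y)
#     %not = onnx::Not(%lt)
#     %out = onnx::Cast[to=UINT8](%not)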
@wrap_logical_op_with_cast_to_uint8
def gt(g, input, other):
return gt_impl(g, input, other)
def gt_impl(g, input, other):
other = _maybe_get_scalar(other)
return g.op("Greater", input, _if_scalar_type_as(g, other, input))
@wrap_logical_op_with_cast_to_uint8
def lt(g, input, other):
return lt_impl(g, input, other)
def lt_impl(g, input, other):
other = _maybe_get_scalar(other)
return g.op("Less", input, _if_scalar_type_as(g, other, input))
@wrap_logical_op_with_cast_to_uint8
@wrap_logical_op_with_negation
def ge(g, input, other):
other = _maybe_get_scalar(other)
return lt_impl(g, input, _if_scalar_type_as(g, other, input))
@wrap_logical_op_with_cast_to_uint8
@wrap_logical_op_with_negation
def le(g, input, other):
other = _maybe_get_scalar(other)
return gt_impl(g, input, _if_scalar_type_as(g, other, input))
def where(g, condition, self, other):
return g.op("ATen", condition, self, other, operator_s="where")
@parse_args('v', 'i')
def log_softmax(g, input, dim=None):
# PyTorch dim and ONNX axis have different meanings.
# See Softmax comment for details.
if dim < 0:
dim = len(input.type().sizes()) + dim
if len(input.type().sizes()) != dim + 1:
return _unimplemented("dim", "ONNX and PyTorch use different strategies to split the input.")
return g.op("LogSoftmax", input, axis_i=dim)
@parse_args('v', 'v', 'v', 'is', 'is', 'is', 'i', 'is', 'i', 'i', 'i', 'i')
def _convolution(g, input, weight, bias, stride, padding, dilation,
transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled):
weight_size = weight.type().sizes()
args = [input, weight]
# ONNX only supports 1D bias
if bias.node().kind() != "prim::Undefined" and len(bias.type().sizes()) == 1:
args.append(bias)
kwargs = {"kernel_shape_i": weight_size[2:],
"strides_i": stride,
# NB: ONNX supports asymmetric padding, whereas PyTorch supports only
# symmetric padding
"pads_i": padding + padding,
"dilations_i": dilation,
"group_i": groups}
if any(o != 0 for o in output_padding):
# ONNX supports both output_shape and output_padding; they are equivalently expressive.
# output_padding is more straightforward, so we use it here.
# output_shape = stride * (input_shape - 1) + output_padding + kernel_shape - padding * 2
assert transposed
assert len(stride) == len(output_padding)
kwargs["output_padding_i"] = output_padding
n = g.op("ConvTranspose" if transposed else "Conv", *args, **kwargs)
if bias.node().kind() != "prim::Undefined" and len(bias.type().sizes()) != 1:
return g.op("Add", n, bias)
else:
return n
@parse_args('v', 'v', 'v', 'v', 'v', 'i', 'f', 'f', 'i')
def batch_norm(g, input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled):
input_sizes = input.type().sizes()
if len(input_sizes) == 2:
# batchnorm1d accepts 2d and 3d arrays, but ONNX only accepts 3d
input = g.op("Unsqueeze", input, axes_i=[2])
if weight is None or weight.node().kind() == "prim::Undefined":
assert len(input_sizes) > 1
weight_value = torch.tensor([1.] * input_sizes[1]).type(
'torch.' + input.type().scalarType() + 'Tensor')
weight = g.op("Constant", value_t=weight_value)
if bias is None or bias.node().kind() == "prim::Undefined":
assert len(input_sizes) > 1
bias_value = torch.tensor([0.] * input_sizes[1]).type(
'torch.' + input.type().scalarType() + 'Tensor')
bias = g.op("Constant", value_t=bias_value)
out = g.op("BatchNormalization", input, weight, bias, running_mean, running_var,
epsilon_f=eps,
momentum_f=1 - momentum,
outputs=1 if not training else 5)
if not training:
if len(input_sizes) == 2:
out = g.op("Squeeze", out, axes_i=[2])
return out
else:
res, new_running_mean, new_running_var, saved_mean, saved_var = out
new_running_mean.setType(running_mean.type())
new_running_var.setType(running_var.type())
saved_mean.setUniqueName("batch_norm_dead_output-" + saved_mean.uniqueName())
saved_var.setUniqueName("batch_norm_dead_output-" + saved_var.uniqueName())
if len(input_sizes) == 2:
res = g.op("Squeeze", res, axes_i=[2])
return res
@parse_args('v', 'v', 'v', 'v', 'v', 'i', 'f', 'f', 'i')
def instance_norm(g, input, weight, bias, running_mean, running_var, use_input_stats, momentum, eps, cudnn_enabled):
input_sizes = input.type().sizes()
if weight is None or weight.node().kind() == "prim::Undefined":
assert len(input_sizes) > 1
weight_value = torch.tensor([1.] * input_sizes[1]).type(
'torch.' + input.type().scalarType() + 'Tensor')
weight = g.op("Constant", value_t=weight_value)
if bias is None or bias.node().kind() == "prim::Undefined":
assert len(input_sizes) > 1
bias_value = torch.tensor([0.] * input_sizes[1]).type(
'torch.' + input.type().scalarType() + 'Tensor')
bias = g.op("Constant", value_t=bias_value)
return g.op("InstanceNormalization", input, weight, bias, epsilon_f=eps)
@parse_args('v', 'i', 'i', 'i')
def unfold(g, input, dimension, size, step):
return g.op("ATen", input, operator_s="unfold", dimension_i=dimension, size_i=size, step_i=step)
@parse_args('v', 'v', 'i')
def _weight_norm(graph, v, g, dim):
return graph.op("ATen", v, g, dim_i=dim, operator_s="_weight_norm")
@parse_args('v', 't', 't', 't')
def elu(g, input, alpha, scale, input_scale):
if scale and scale != 1.:
return _unimplemented("Elu", "scale != 1")
if input_scale and input_scale != 1.:
return _unimplemented("Elu", "input_scale != 1")
# See Note [Export inplace]
return g.op("Elu", input, alpha_f=_scalar(alpha))
def selu(g, input):
return g.op("Selu", input)
@parse_args('v', 'i', 'v')
def index_select(g, self, dim, index):
return g.op("Gather", self, index, axis_i=dim)
def index_put(g, self, indices_list_value, values, accumulate):
indices_list = _unpack_list(indices_list_value)
args = [self] + indices_list + [values, accumulate]
return g.op("ATen", *args, operator_s='index_put')
def type_as(g, self, other):
if self.isTensor() and other.isTensor() and self.type().scalarType() == other.type().scalarType():
return self
if other.isTensor():
other_type_name = other.type().scalarType()
return g.op("Cast", self, to_i=cast_pytorch_to_onnx[other_type_name])
else:
# We don't know the type of other, bail by emitting ATen
return g.op("ATen", self, other, operator_s="type_as")
@parse_args('v', 'is', 'v', 'v', 'f', 'i')
def layer_norm(g, self, normalized_shape, weight, bias, eps, cudnn_enable):
return g.op("ATen", self, weight, bias, normalized_shape_i=normalized_shape,
eps_f=eps, cudnn_enable_i=cudnn_enable, operator_s="layer_norm")
# ignore clone operators that are inserted by PyTorch autograd
def clone(g, input):
return input
def abs(g, self):
return g.op("Abs", self)
def log(g, self):
return g.op("Log", self)
def pow(g, self, exponent):
exponent = _maybe_get_scalar(exponent)
return g.op("Pow", self, _if_scalar_type_as(g, exponent, self))
def clamp(g, self, min, max):
# min or max may be prim::None that we need to dispatch to
# Clip separately, as ONNX does not have None syntax
if min.node().kind() == "prim::None":
return clamp_max(g, self, max)
elif max.node().kind() == "prim::None":
return clamp_min(g, self, min)
else:
min = _parse_arg(min, 'f')
max = _parse_arg(max, 'f')
return g.op("Clip", self, min_f=min, max_f=max)
@parse_args('v', 'f')
def clamp_min(g, self, min):
return g.op("Clip", self, min_f=min)
@parse_args('v', 'f')
def clamp_max(g, self, max):
return g.op("Clip", self, max_f=max)
# torch.max (same for torch.min) actually has two interfaces smashed together:
# torch.max(x, dim, keepdim) and torch.max(x, y)
def max(g, self, dim_or_y=None, keepdim=None):
if dim_or_y is None and keepdim is None:
return g.op("ReduceMax", self, keepdims_i=0)
if keepdim is None:
return g.op("Max", self, dim_or_y)
else:
dim = _get_const(dim_or_y, 'i', 'dim')
keepdim = _get_const(keepdim, 'i', 'keepdim')
# TODO: export it as ReduceMax
return g.op("ATen",
self,
operator_s="max",
dim_i=dim,
keepdim_i=keepdim,
outputs=2)
def min(g, self, dim_or_y=None, keepdim=None):
if dim_or_y is None and keepdim is None:
return g.op("ReduceMin", self, keepdims_i=0)
if keepdim is None:
return g.op("Min", self, dim_or_y)
else:
dim = _get_const(dim_or_y, 'i', 'dim')
keepdim = _get_const(keepdim, 'i', 'keepdim')
# TODO: export it as ReduceMin
return g.op("ATen",
self,
operator_s="min",
dim_i=dim,
keepdim_i=keepdim,
outputs=2)
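# Illustrative mapping of the overloads handled above (a sketch):
#
#     torch.max(x)                       -> ReduceMax(x, keepdims=0)
#     torch.max(x, y)                    -> Max(x, y)
#     torch.max(x, dim=1, keepdim=False) -> ATen fallback op with two outputs
#                                           (values, indices)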
@wrap_logical_op_with_cast_to_uint8
def eq(g, self, other):
return g.op("Equal", self, other)
def ne(g, self, other):
return g.op("Not", eq(g, self, other))
def exp(g, self):
return g.op("Exp", self)
@parse_args('v', 'f', 'i')
def dropout(g, input, p, train):
r, _ = g.op("Dropout", input, ratio_f=p, outputs=2)
return r
def _unsupported_dropout(name):
@parse_args('v', 'f', 'i')
def feature_dropout(g, input, p, train):
# NB: In inference mode, FeatureDropout is exported as an identity op.
from torch.onnx.symbolic import _unimplemented
if train:
return _unimplemented(name, "training mode")
return input
return feature_dropout
feature_dropout = _unsupported_dropout("feature_dropout")
alpha_dropout = _unsupported_dropout("alpha_dropout")
feature_alpha_dropout = _unsupported_dropout("feature_alpha_dropout")
# See Note [Export inplace]
dropout_ = dropout
feature_dropout_ = feature_dropout
alpha_dropout_ = alpha_dropout
feature_alpha_dropout_ = feature_alpha_dropout
@parse_args('v', 't', 'i', 'i')
def norm(g, self, p, dim, keepdim):
if p == 1:
f = _reduce_op_symbolic("ReduceL1")
elif p == 2:
f = _reduce_op_symbolic("ReduceL2")
else:
raise RuntimeError("ONNX export only p-norms with p of 1 or 2")
return f(g, self, dim=dim, keepdim=keepdim)
@parse_args('v', 'v', 'v', 'i')
def conv_tbc(g, input, weight, bias, pad):
return g.op("ATen", input, weight, bias, operator_s="conv_tbc", pad_i=pad)
@parse_args('v', 'i', 'i')
def _unique(g, input, sorted, return_inverse):
return g.op("ATen", input, operator_s="_unique", sorted_i=sorted,
return_inverse_i=return_inverse, outputs=2)
# Metaprogram symbolics for each ATen native specialized cast operator.
# For example, we specify a function named `_cast_uint8_t` that instantiates an
# ONNX Cast node with its `to` attribute set to 'UINT8'
#
# TODO: remove these once we support Types in the JIT IR and we can once again
# use the unified toType operator
cast_pytorch_to_onnx = {
'Byte': torch.onnx.TensorProtoDataType.UINT8,
'Char': torch.onnx.TensorProtoDataType.INT8,
'Double': torch.onnx.TensorProtoDataType.DOUBLE,
'Float': torch.onnx.TensorProtoDataType.FLOAT,
'Half': torch.onnx.TensorProtoDataType.FLOAT16,
'Int': torch.onnx.TensorProtoDataType.INT32,
'Long': torch.onnx.TensorProtoDataType.INT64,
'Short': torch.onnx.TensorProtoDataType.INT16,
}
scalar_name_to_pytorch = {
'uint8_t': 'Byte',
'int8_t': 'Char',
'double': 'Double',
'float': 'Float',
'half': 'Half',
'int': 'Int',
'int64_t': 'Long',
'int16_t': 'Short',
}
def _cast_func_template(to_i, g, input, non_blocking):
return g.op("Cast", input, to_i=to_i)
for k, v in cast_pytorch_to_onnx.items():
name = '_cast_{}'.format(k)
globals()[name] = parse_args('v', 'i')(partial(_cast_func_template, v))
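# Illustrative sketch of what the loop above generates: e.g. a module-level
# `_cast_Long` that is equivalent to
#
#     @parse_args('v', 'i')
#     def _cast_Long(g, input, non_blocking):
#         return g.op("Cast", input, to_i=torch.onnx.TensorProtoDataType.INT64)
#
# _pack_padded_sequence below relies on the generated `_cast_Int` in the same
# way to force int32 lengths.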
scalar_type_to_onnx = [
cast_pytorch_to_onnx["Byte"],
cast_pytorch_to_onnx["Char"],
cast_pytorch_to_onnx["Short"],
cast_pytorch_to_onnx["Int"],
cast_pytorch_to_onnx["Long"],
cast_pytorch_to_onnx["Half"],
cast_pytorch_to_onnx["Float"],
cast_pytorch_to_onnx["Double"],
]
@parse_args('v', 'i', 'v', 'v')
def zeros(g, sizes, dtype, layout, device):
# NOTE: there is no way to set device and layout in ONNX, so we ignore them
return g.op("ConstantFill", sizes, dtype_i=scalar_type_to_onnx[dtype], input_as_shape_i=1, value_f=0)
@parse_args('v', 'i', 'v', 'v')
def zeros_like(g, input, dtype, layout, device):
return g.op("ConstantLike", input, dtype_i=scalar_type_to_onnx[dtype], value_f=0.0)
@parse_args('v', 'i', 'v', 'v')
def ones(g, sizes, dtype, layout, device):
return g.op("ConstantFill", sizes, dtype_i=scalar_type_to_onnx[dtype], input_as_shape_i=1, value_f=1)
@parse_args('v', 'i', 'v', 'v')
def ones_like(g, input, dtype, layout, device):
return g.op("ConstantLike", input, dtype_i=scalar_type_to_onnx[dtype], value_f=1.0)
def full(g, sizes, value, dtype, layout, device):
const_value = _maybe_get_const(value, 't')
if _is_value(const_value):
tmp = zeros(g, sizes, dtype, layout, device)
return add(g, tmp, value, g.op("Constant", value_t=torch.tensor(1)))
else:
dtype = _get_const(dtype, 'i', 'dtype')
return g.op("ConstantFill", sizes, dtype_i=scalar_type_to_onnx[dtype],
input_as_shape_i=1, value_f=const_value)
@parse_args('v', 'f', 'i', 'v', 'v')
def full_like(g, input, fill_value, dtype, layout, device):
return g.op("ConstantLike", input, dtype_i=scalar_type_to_onnx[dtype], value_f=fill_value)
@parse_args('v', 'v', 'v', 'v', 'i')
def slice(g, self, dim, start, end, step):
if step != 1:
_unimplemented("slice", "step != 1")
if start.node().kind() != 'onnx::Constant' or \
end.node().kind() != 'onnx::Constant' or dim.node().kind() != 'onnx::Constant':
start_unsqueezed = g.op("Unsqueeze", start, axes_i=[0])
end_unsqueezed = g.op("Unsqueeze", end, axes_i=[0])
dim_unsqueezed = g.op("Unsqueeze", dim, axes_i=[0])
return g.op("DynamicSlice", self, start_unsqueezed, end_unsqueezed, dim_unsqueezed)
else:
start = _parse_arg(start, 'i')
end = _parse_arg(end, 'i')
dim = _parse_arg(dim, 'i')
return g.op("Slice", self, axes_i=[dim], starts_i=[start], ends_i=[end])
@parse_args('v', 'f', 'f')
def hardtanh(g, self, min_val, max_val):
return g.op("Clip", self, min_f=min_val, max_f=max_val)
def alias(g, self):
return self
@parse_args('v', 'i')
def unsqueeze(g, self, dim):
return g.op("Unsqueeze", self, axes_i=[dim])
@parse_args('v', 'i', 'i', 'i', 'i')
def topk(g, self, k, dim, largest, sorted, out=None):
if out is not None:
_unimplemented("TopK", "Out parameter")
if not largest:
_unimplemented("TopK", "Ascending TopK")
return g.op("TopK", self, k_i=k, axis_i=dim, outputs=2)
def to(g, self, *args):
# ONNX doesn't have a concept of a device, so we ignore device casts
if len(args) == 3:
if args[0].type().isSubtypeOf(ListType.ofInts()):
# aten::to(Tensor, Device, bool, bool)
return self
else:
# aten::to(Tensor, ScalarType, bool, bool)
dtype = _get_const(args[0], 'i', 'dtype')
return g.op("Cast", self, to_i=scalar_type_to_onnx[dtype])
elif len(args) == 4:
# aten::to(Tensor, Device, ScalarType, bool, bool)
dtype = _get_const(args[1], 'i', 'dtype')
return g.op("Cast", self, to_i=scalar_type_to_onnx[dtype])
elif len(args) == 5:
# aten::to(Tensor, ScalarType, Layout, Device, bool, bool) -> Tensor
dtype = _get_const(args[0], 'i', 'dtype')
# Layout and device are ignored
return g.op("Cast", self, to_i=scalar_type_to_onnx[dtype])
else:
raise NotImplementedError("Unknown aten::to signature")
def repeat(g, self, repeats):
if not _is_value(repeats):
repeats = g.op("Constant", value_t=torch.LongTensor(repeats))
const_repeats = _maybe_get_const(repeats, 'is')
if self.isTensor() and not _is_value(const_repeats):
sizes = self.type().sizes()
diff_dims = len(const_repeats) - len(sizes)
if diff_dims > 0:
self = view(g, self, [1] * diff_dims + sizes)
return g.op("Tile", self, repeats)
@parse_args('v', 'i')
def pixel_shuffle(g, self, upscale_factor):
dims = self.type().sizes()
if len(dims) != 4:
return _unimplemented("pixel_shuffle", "only support 4d input")
output_channel = dims[1] // upscale_factor // upscale_factor
after_view = view(g, self, [-1, upscale_factor, upscale_factor,
output_channel, dims[2], dims[3]])
after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 4, 2, 5, 3])
return view(g, after_transpose,
[-1, output_channel, dims[2] * upscale_factor, dims[3] *
upscale_factor])
def _generic_rnn(g, variant, input, initial_states, all_weights, has_biases,
num_layers, dropout, train, bidirectional, batch_first=None, batch_sizes=None):
weights_per_layer = 4 if has_biases else 2
assert len(all_weights) == num_layers * weights_per_layer * (1 + bidirectional)
layer_weights = [all_weights[i:i + weights_per_layer] for i in range(0, len(all_weights), weights_per_layer)]
if batch_first:
return _unimplemented("RNN/GRU/LSTM", "batch_first")
if dropout and train:
return _unimplemented("RNN/GRU/LSTM", "dropout in training mode")
if variant.startswith('RNN'):
nonlinearity = variant[4:].lower()
variant = 'RNN'
w_hh = all_weights[1]
hidden_size = w_hh.type().sizes()[1]
unidirectional = not bidirectional
prev_output = input
h_outs = []
if variant == 'RNN' or variant == 'GRU':
h0 = initial_states
elif variant == 'LSTM':
h0, c0 = initial_states
c_outs = []
sequence_lens = unused(g) if batch_sizes is None else batch_sizes
if variant == 'GRU':
# pytorch is reset, input, hidden
# onnx is input, reset, hidden
reform_permutation = [(1, 2), (0, 1), (2, 3)]
elif variant == 'LSTM':
# pytorch is input, forget, cell, output.
# onnx is input, output, forget, cell.
reform_permutation = [(0, 1), (3, 4), (1, 3)]
def reform_weights(g, w, n, intervals):
slices = [g.op('Slice', w, axes_i=[0], starts_i=[x * n], ends_i=[y * n]) for x, y in intervals]
return g.op('Concat', *slices, axis_i=0)
def transform_weights(layer_index):
if variant == 'RNN':
weight_ih, weight_hh, bias_ih, bias_hh = layer_weights[layer_index]
elif variant == 'GRU' or variant == 'LSTM':
weight_ih, weight_hh, bias_ih, bias_hh = \
[reform_weights(g, w, hidden_size, reform_permutation) for w in layer_weights[layer_index]]
bias_concat = g.op('Concat', bias_ih, bias_hh, axis_i=0)
return tuple(g.op('Unsqueeze', x, axes_i=[0]) for x in (weight_ih, weight_hh, bias_concat))
def retrieve_state(x, start, end):
return x if num_layers == 1 else g.op('Slice', x, axes_i=[0], starts_i=[start], ends_i=[end])
for i in range(num_layers):
if unidirectional:
weight_ih, weight_hh, bias_concat = transform_weights(i)
state_indices = i, i + 1
else:
weight_ih_f, weight_hh_f, bias_f = transform_weights(2 * i)
weight_ih_b, weight_hh_b, bias_b = transform_weights(2 * i + 1)
weight_ih = g.op('Concat', weight_ih_f, weight_ih_b, axis_i=0)
weight_hh = g.op('Concat', weight_hh_f, weight_hh_b, axis_i=0)
bias_concat = g.op('Concat', bias_f, bias_b, axis_i=0)
state_indices = 2 * i, 2 * i + 2
inputs = [prev_output, weight_ih, weight_hh, bias_concat, sequence_lens]
inputs.append(retrieve_state(h0, *state_indices))
if variant == 'LSTM':
inputs.append(retrieve_state(c0, *state_indices))
extra_kwargs = {} if unidirectional else {'direction_s': 'bidirectional'}
if variant == 'RNN':
prev_output, h_out = g.op('RNN', *inputs, outputs=2,
hidden_size_i=hidden_size,
activations_s=[nonlinearity],
**extra_kwargs)
elif variant == 'GRU':
prev_output, h_out = g.op('GRU', *inputs, outputs=2,
hidden_size_i=hidden_size,
linear_before_reset_i=1,
**extra_kwargs)
elif variant == 'LSTM':
prev_output, h_out, c_out = g.op('LSTM', *inputs, outputs=3,
hidden_size_i=hidden_size,
**extra_kwargs)
if bidirectional:
# The ONNX RNN/GRU/LSTM produce an output of dimensions
# seq_len, num_directions, batch, hidden_size
# We have to convert to match pytorch's expected
# seq_len, batch, num_directions * hidden_size
# by first moving num_directions before hidden_size with
# Transpose, and then combining it with hidden_size
# with Reshape.
prev_output = g.op('Transpose', prev_output, perm_i=[0, 2, 1, 3])
prev_output = g.op('Reshape', prev_output, g.op('Constant', value_t=torch.LongTensor([0, 0, -1])))
else:
prev_output = g.op('Squeeze', prev_output, axes_i=[1])
h_outs.append(h_out)
if variant == 'LSTM':
c_outs.append(c_out)
h_outs = h_out if num_layers == 1 else g.op('Concat', *h_outs, axis_i=0)
if variant == 'RNN' or variant == 'GRU':
return prev_output, h_outs
elif variant == 'LSTM':
c_outs = c_out if num_layers == 1 else g.op('Concat', *c_outs, axis_i=0)
return prev_output, h_outs, c_outs
@parse_args('v', 'v', 'v', 'i', 'i', 'f', 'i', 'i', 'i')
def _lstm_full(g, input, hidden_v, weight_v, has_biases, num_layers, dropout, train, bidirectional, batch_first):
hidden, weight = _unpack_list(hidden_v), _unpack_list(weight_v)
return _generic_rnn(g, 'LSTM', input, hidden, weight, has_biases, num_layers,
dropout, train, bidirectional, batch_first)
@parse_args('v', 'v', 'v', 'v', 'i', 'i', 'f', 'i', 'i')
def _lstm_packed(g, input, batch_sizes, hidden_v, weight_v, has_biases, num_layers, dropout, train, bidirectional):
hidden, weight = _unpack_list(hidden_v), _unpack_list(weight_v)
return _generic_rnn(g, 'LSTM', input, hidden, weight, has_biases, num_layers,
dropout, train, bidirectional, batch_sizes=batch_sizes)
def lstm(g, *args):
if _is_tensor_list(args[3]):
return _lstm_packed(g, *args)
else:
return _lstm_full(g, *args)
def _one_hidden_rnn(kind):
@parse_args('v', 'v', 'v', 'i', 'i', 'f', 'i', 'i', 'i')
def _rnn_full(g, input, hidden, weight_v, has_biases, num_layers, dropout, train, bidirectional, batch_first):
weight = _unpack_list(weight_v)
return _generic_rnn(g, kind, input, hidden, weight, has_biases, num_layers,
dropout, train, bidirectional, batch_first)
@parse_args('v', 'v', 'v', 'v', 'i', 'i', 'f', 'i', 'i')
def _rnn_packed(g, input, batch_sizes, hidden, weight_v, has_biases, num_layers, dropout, train, bidirectional):
weight = _unpack_list(weight_v)
return _generic_rnn(g, kind, input, hidden, weight, has_biases, num_layers,
dropout, train, bidirectional, batch_sizes=batch_sizes)
def symbolic(g, *args):
if _is_tensor_list(args[3]):
return _rnn_packed(g, *args)
else:
return _rnn_full(g, *args)
return symbolic
gru = _one_hidden_rnn('GRU')
rnn_tanh = _one_hidden_rnn('RNN_TANH')
rnn_relu = _one_hidden_rnn('RNN_RELU')
@parse_args('v', 'i')
def _dim_arange(g, like, dim):
return g.op('ATen', like, dim_i=dim, operator_s='_dim_arange')
def detach(g, input):
# Erase aten::detach nodes because ONNX is inference only
return input
def contiguous(g, input):
return input
@parse_args('v', 'v', 'i')
def _pack_padded_sequence(g, input, lengths, batch_first):
# There currently is no PackPadded operator in ONNX. We rely on an
# optimization pass to remove this later. It is an error if all
# PackPadded operators cannot be optimized out.
if batch_first:
input = g.op('Transpose', input, perm_i=[1, 0, 2])
if not lengths.type().isSubtypeOf(torch._C.DynamicType.get()):
raise RuntimeError("Lengths must be a Tensor for ONNX export")
# We know it's a TensorType so this check is now safe.
# It's really only necessary because those operators expand to something that
# only works with int32 types in Caffe2...
if lengths.type().scalarType() != 'Int':
lengths = _cast_Int(g, lengths, False)
return g.op("prim::PackPadded", input, lengths, outputs=2)
@parse_args('v', 'v', 'i', 't', 'v')
def _pad_packed_sequence(g, data, batch_sizes, batch_first, padding_value, total_length):
# Ignore total_length as it is not supported in _symbolic_pad_packed_sequence
# It is only useful/used when training with a data_parallel model, so
# it shouldn't be relevant for ONNX anyway
data, lengths = g.op("prim::PadPacked", data, batch_sizes, outputs=2)
if batch_first:
data = g.op('Transpose', data, perm_i=[1, 0, 2])
return data, lengths
def randn(g, *shapes):
shapes_list = list(shapes)
shape = _maybe_get_const(shapes_list[0], "is")
return g.op('RandomNormal', shape_i=shape)
@parse_args('v', 'f', 'f', 'i', 'none')
def rrelu(g, input, lower, upper, training, generator):
p = g.op('RandomUniformLike', input, high_f=upper, low_f=lower)
return g.op('PRelu', input, p)
@parse_args('v')
def log_sigmoid(g, input):
p = g.op('Sigmoid', input)
return g.op('Log', p)