[ONNX] Fix type annotations and enable type checking for all apis (#84091)

Enable runtime type checking for all torch.onnx public APIs, symbolic functions, and most helpers (except two whose argument types cannot be checked at runtime, because `_C.JitType` is not exposed to Python) by adding the beartype decorator. Fix type annotations to make the unit tests green.
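
For illustration, the decoration pattern applied across the symbolic opset files looks roughly like the sketch below. `celu_like` is a hypothetical symbolic function used only as an example (it is not part of this PR), and the sketch assumes a torch build that already contains `torch.onnx._internal._beartype`:

```python
from torch import _C
from torch.onnx import symbolic_helper
from torch.onnx._internal import _beartype


# "v" keeps the first op argument as a _C.Value, "f" converts the second to float.
@symbolic_helper.parse_args("v", "f")
@_beartype.beartype  # runtime type check of the annotated signature
def celu_like(g: _C.Graph, self: _C.Value, alpha: float) -> _C.Value:
    # Placeholder body: emit a single ONNX node for the hypothetical op.
    return g.op("Celu", self, alpha_f=alpha)
```

Because decorators apply bottom-up, `parse_args` converts the raw `_C.Value` arguments first and beartype then validates the converted values against the annotations, which is the order used throughout the changed files.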

Profile:

export `torchvision.models.alexnet(pretrained=True)`

```
with runtime type checking: 21.314 / 10 passes
without runtime type checking: 20.797 / 10 passes

+ 2.48%
```
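
For context, a rough way to reproduce a measurement like this is sketched below; the use of `timeit`, the in-memory export target, and the input shape are assumptions about the methodology, not details taken from the PR:

```python
import io
import timeit

import torch
import torchvision

model = torchvision.models.alexnet(pretrained=True).eval()
dummy_input = torch.randn(1, 3, 224, 224)

def export_once() -> None:
    # Export to an in-memory buffer so disk I/O does not dominate the timing.
    torch.onnx.export(model, dummy_input, io.BytesIO())

# Time 10 export passes, mirroring the "10 passes" figure above.
elapsed = timeit.timeit(export_once, number=10)
print(f"{elapsed:.3f} / 10 passes")
```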
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84091
Approved by: https://github.com/BowenBao, https://github.com/thiagocrepaldi
Justin Chu
2022-09-02 23:19:03 +00:00
committed by PyTorch MergeBot
parent 2a332afbf4
commit 388368b699
14 changed files with 752 additions and 71 deletions

View File

@ -3583,7 +3583,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
x = torch.arange(1.0, 6.0, requires_grad=True)
k = torch.tensor(3)
self.run_test(MyModuleDynamic(), [x, k])
self.run_test(MyModuleDynamic(), (x, k))
@skipScriptTest() # Python builtin apply of FunctionMeta object is currently not supported in Torchscript.
@skipIfUnsupportedMinOpsetVersion(11) # Clip op min is an input since opset 11.
@ -7405,23 +7405,28 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
x = torch.randn(2, 2, 4, 4)
self.run_test(model, x)
# Dynamic padding is added in opset 11
@skipIfUnsupportedMinOpsetVersion(11)
def test_pad_types(self):
@common_utils.parametrize(
"pad",
[
common_utils.subtest([2, 4], name="scalar_list"),
common_utils.subtest(
[
torch.tensor(2, dtype=torch.int64),
torch.tensor(4, dtype=torch.int64),
],
name="scalar_tensor_list",
),
],
)
@skipIfUnsupportedMinOpsetVersion(11) # Dynamic padding is added in opset 11
def test_pad_types(self, pad):
# Test for different pad integer types
class Pad(torch.nn.Module):
def forward(self, x, pad: List[int]):
return torch.nn.functional.pad(x, pad)
x = torch.randn(2, 2, 4, 4)
y = pad = [2, 4]
self.run_test(Pad(), (x, y))
y = pad = [
torch.tensor(2, dtype=torch.int64),
torch.tensor(4, dtype=torch.int64),
]
self.run_test(Pad(), (x, y))
self.run_test(Pad(), (x, pad))
@skipIfUnsupportedMaxOpsetVersion(10)
@skipScriptTest() # TODO: the logic in symbolic_opset9 doesn't handle script

View File

@ -10,15 +10,17 @@ from torch._C import _onnx as _C_onnx
# Import utils to get _params_dict because it is a global that is accessed by c++ code
from torch.onnx import _deprecation, utils
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import _beartype
_ATTR_PATTERN = re.compile("^(.+)_(([ifstgz])|(ty))$")
# TODO(#78694): Refactor the patching process to make it more transparent to users.
@_beartype.beartype
def _graph_op(
g: _C.Graph,
opname: str,
*raw_args: _C.Value,
*raw_args: Union[torch.Tensor, _C.Value],
outputs: int = 1,
**kwargs,
) -> Union[_C.Value, Tuple[_C.Value, ...]]:
@ -76,6 +78,7 @@ def _graph_op(
return tuple(n.outputs())
@_beartype.beartype
def _const_if_tensor(g: _C.Graph, arg):
if arg is None:
return arg
@ -85,6 +88,7 @@ def _const_if_tensor(g: _C.Graph, arg):
# Generate an ONNX ATen op node.
@_beartype.beartype
def _aten_op(g: _C.Graph, operator: str, *args, overload_name: str = "", **kwargs):
return _graph_op(
g,
@ -96,7 +100,8 @@ def _aten_op(g: _C.Graph, operator: str, *args, overload_name: str = "", **kwarg
)
def _block_op(b: _C.Block, opname: str, *args, **kwargs):
@_beartype.beartype
def _block_op(b: _C.Block, opname: str, *args: _C.Value, **kwargs):
if "::" in opname:
aten = False
ns_opname = opname
@ -115,8 +120,9 @@ def _block_op(b: _C.Block, opname: str, *args, **kwargs):
return outputs
@_beartype.beartype
def _new_node(
g: _C.Graph, namespace: str, op: str, outputs: int, *args, **kwargs
g: _C.Graph, namespace: str, op: str, outputs: int, *args: _C.Value, **kwargs
) -> _C.Node:
"""Creates a new node in the graph.
@ -138,6 +144,7 @@ def _new_node(
return node
@_beartype.beartype
def _is_onnx_list(value):
return (
not isinstance(value, torch._six.string_classes)
@ -146,19 +153,22 @@ def _is_onnx_list(value):
)
@_beartype.beartype
def _scalar(x: torch.Tensor):
"""Convert a scalar tensor into a Python value."""
assert x.numel() == 1
return x[0]
def _is_caffe2_aten_fallback():
@_beartype.beartype
def _is_caffe2_aten_fallback() -> bool:
return (
GLOBALS.operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
and _C_onnx._CAFFE2_ATEN_FALLBACK
)
@_beartype.beartype
def _add_attribute(node: _C.Node, key: str, value: Any, aten: bool):
r"""Initializes the right attribute based on type of value."""
m = _ATTR_PATTERN.match(key)
@ -188,6 +198,7 @@ def _add_attribute(node: _C.Node, key: str, value: Any, aten: bool):
@_deprecation.deprecated(
"1.13", "1.14", "Use 'g.op()' to create a constant node instead."
)
@_beartype.beartype
def _graph_constant(
g,
value,

View File

@ -8,6 +8,7 @@ from typing_extensions import Literal
import torch
from torch._C import _onnx as _C_onnx
from torch.onnx._internal import _beartype
ScalarName = Literal[
"Byte",
@ -80,6 +81,7 @@ class JitScalarType(enum.IntEnum):
UNDEFINED = enum.auto() # 16
@classmethod
@_beartype.beartype
def from_name(
cls, name: Union[ScalarName, TorchName, Optional[str]]
) -> JitScalarType:
@ -104,30 +106,36 @@ class JitScalarType(enum.IntEnum):
raise ValueError(f"Unknown torch or scalar type: '{name}'")
@classmethod
@_beartype.beartype
def from_dtype(cls, dtype: torch.dtype) -> JitScalarType:
"""Convert a torch dtype to ScalarType."""
if dtype not in _DTYPE_TO_SCALAR_TYPE:
raise ValueError(f"Unknown dtype: {dtype}")
return _DTYPE_TO_SCALAR_TYPE[dtype]
@_beartype.beartype
def scalar_name(self) -> ScalarName:
"""Convert a ScalarType to a JIT scalar type name."""
return _SCALAR_TYPE_TO_NAME[self]
@_beartype.beartype
def torch_name(self) -> TorchName:
"""Convert a ScalarType to a torch type name."""
return _SCALAR_TYPE_TO_TORCH_NAME[self]
@_beartype.beartype
def dtype(self) -> torch.dtype:
"""Convert a ScalarType to a torch dtype."""
return _SCALAR_TYPE_TO_DTYPE[self]
@_beartype.beartype
def onnx_type(self) -> _C_onnx.TensorProtoDataType:
"""Convert a ScalarType to an ONNX data type."""
if self not in _SCALAR_TYPE_TO_ONNX:
raise ValueError(f"Scalar type {self} cannot be converted to ONNX")
return _SCALAR_TYPE_TO_ONNX[self]
@_beartype.beartype
def onnx_compatible(self) -> bool:
"""Return whether this ScalarType is compatible with ONNX."""
return (
@ -137,11 +145,13 @@ class JitScalarType(enum.IntEnum):
)
@_beartype.beartype
def valid_scalar_name(scalar_name: Union[ScalarName, str]) -> bool:
"""Return whether the given scalar name is a valid JIT scalar type name."""
return scalar_name in _SCALAR_NAME_TO_TYPE
@_beartype.beartype
def valid_torch_name(torch_name: Union[TorchName, str]) -> bool:
"""Return whether the given torch name is a valid torch type name."""
return torch_name in _TORCH_NAME_TO_SCALAR_TYPE

View File

@ -17,6 +17,7 @@ from torch import _C
from torch.onnx import _constants, _patch_torch, _type_utils, errors # noqa: F401
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import _beartype
from torch.types import Number
# Note [Edit Symbolic Files]
# EDITING THIS FILE AND SYMBOLIC_OPSET<VERSION> FILES? READ THIS FIRST!
@ -32,6 +33,8 @@ from torch.onnx._internal import _beartype
# `_jit_pass_onnx_remove_inplace_ops_for_onnx`, and
# transparently dispatched to their non inplace versions in
# "run_symbolic_function". See Note [Export inplace]
# - REQUIRED: Annotate new symbolic functions with type annotations and decorate with
# _beartype.beartype to enable runtime type checking.
#
# ----------------------------------------------------------------------------------
# A note on Tensor types
@ -94,6 +97,7 @@ _ValueDescriptor = Literal[
]
@_beartype.beartype
def _parse_arg(
value,
desc: _ValueDescriptor,
@ -163,6 +167,7 @@ def _parse_arg(
)
@_beartype.beartype
def _node_get(node: _C.Node, key: str):
"""Gets attributes of a node which is polymorphic over return type."""
assert isinstance(node, _C.Node)
@ -170,19 +175,26 @@ def _node_get(node: _C.Node, key: str):
return getattr(node, sel)(key)
@_beartype.beartype
def _is_onnx_constant(value: _C.Value):
"""Whether a Value is an ONNX constant."""
return value.node().kind() == "onnx::Constant"
def _maybe_get_const(value: _C.Value, descriptor: _ValueDescriptor):
@_beartype.beartype
def _maybe_get_const(
value: Optional[Union[_C.Value, torch.Tensor, Number, Sequence]],
descriptor: _ValueDescriptor,
):
# NOTE: prim::Constant at this stage usually means something not compatible in ONNX,
# otherwise it'd be converted to onnx::Constant
if _is_value(value) and _is_onnx_constant(value):
# TODO(justinchuby): Replace isinstance with _is_value once we figure out mypy
if isinstance(value, _C.Value) and _is_onnx_constant(value):
return _parse_arg(value, descriptor)
return value
@_beartype.beartype
def _maybe_get_scalar(value):
value_t = _maybe_get_const(value, "t")
if isinstance(value_t, torch.Tensor) and value_t.shape == ():
@ -190,6 +202,7 @@ def _maybe_get_scalar(value):
return value
@_beartype.beartype
def _get_const(value, desc, arg_name):
if not _is_constant(value):
raise errors.SymbolicValueError(
@ -200,6 +213,7 @@ def _get_const(value, desc, arg_name):
return _parse_arg(value, desc)
@_beartype.beartype
def _unpack_list(list_value: _C.Value) -> List[_C.Value]:
list_node = list_value.node()
if list_node.kind() != "prim::ListConstruct":
@ -211,6 +225,7 @@ def _unpack_list(list_value: _C.Value) -> List[_C.Value]:
return list(list_node.inputs())
@_beartype.beartype
def _unpack_tuple(tuple_value: _C.Value) -> Tuple[_C.Value, ...]:
tuple_node = tuple_value.node()
if not _is_tuple_construct(tuple_value):
@ -222,6 +237,7 @@ def _unpack_tuple(tuple_value: _C.Value) -> Tuple[_C.Value, ...]:
return tuple(tuple_node.inputs())
@_beartype.beartype
def _unpack_quantized_tensor(tuple_value: _C.Value) -> Tuple[_C.Value, ...]:
"""Unpacks a quantized tensor into a tuple of tensor and scale/zero_point.
Args:
@ -245,10 +261,12 @@ def _unpack_quantized_tensor(tuple_value: _C.Value) -> Tuple[_C.Value, ...]:
# Check if list_value is output from prim::ListConstruct
# This is usually called before _unpack_list to ensure the list can be unpacked.
@_beartype.beartype
def _is_packed_list(list_value: _C.Value) -> bool:
return _is_value(list_value) and list_value.node().kind() == "prim::ListConstruct"
@_beartype.beartype
def parse_args(*arg_descriptors: _ValueDescriptor):
"""A decorator which converts args from torch._C.Value to built-in types.
@ -326,6 +344,7 @@ def parse_args(*arg_descriptors: _ValueDescriptor):
return decorator
@_beartype.beartype
def quantized_args(
*arg_q_descriptors: bool,
scale: Optional[float] = None,
@ -430,13 +449,15 @@ def quantized_args(
return decorator
def _scalar(x: torch.Tensor):
@_beartype.beartype
def _scalar(x: Any) -> Optional[Number]:
"""Convert a scalar tensor into a Python value."""
if isinstance(x, torch.Tensor) and x.shape == ():
return x.item()
return None
@_beartype.beartype
def _if_scalar_type_as(g: _C.Graph, self, tensor):
"""
Convert self into the same type of tensor, as necessary.
@ -455,14 +476,17 @@ def _if_scalar_type_as(g: _C.Graph, self, tensor):
return self
@_beartype.beartype
def _is_none(x: _C.Value) -> bool:
return x.node().mustBeNone()
@_beartype.beartype
def _is_value(x: Any) -> bool:
return isinstance(x, _C.Value)
@_beartype.beartype
def _is_constant(value: Any) -> bool:
return not _is_value(value) or value.node().kind() in {
"onnx::Constant",
@ -470,20 +494,24 @@ def _is_constant(value: Any) -> bool:
}
@_beartype.beartype
def _is_tensor(x: _C.Value) -> bool:
return x.type().isSubtypeOf(_C.TensorType.get())
# Note: _C.JitType is not exposed to Python and cannot be checked in runtime.
def _as_list_type(jit_type: _C.JitType) -> Optional[_C.ListType]:
if isinstance(jit_type, _C.ListType):
return jit_type
return None
@_beartype.beartype
def _is_list(x: _C.Value) -> bool:
return _as_list_type(x.type()) is not None
@_beartype.beartype
def _is_tensor_list(x: _C.Value) -> bool:
x_type = _as_list_type(x.type())
if x_type is None:
@ -491,6 +519,7 @@ def _is_tensor_list(x: _C.Value) -> bool:
return isinstance(x_type.getElementType(), _C.TensorType)
@_beartype.beartype
def _is_scalar_list(x: _C.Value) -> bool:
"""Checks if x is a scalar list, for example: List[float], List[int].
@ -507,10 +536,12 @@ def _is_scalar_list(x: _C.Value) -> bool:
)
@_beartype.beartype
def _is_tuple_construct(x: _C.Value) -> bool:
return x.node().kind() == "prim::TupleConstruct"
@_beartype.beartype
def is_caffe2_aten_fallback() -> bool:
return (
GLOBALS.operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
@ -548,6 +579,7 @@ def _get_tensor_dim_size(x: _C.Value, dim: int) -> Optional[int]:
return sizes[dim] if sizes else None
@_beartype.beartype
def _get_dim_for_cross(x: _C.Value, dim: Optional[int]):
if dim == -1:
tensor_rank = _get_tensor_rank(x)
@ -563,6 +595,7 @@ def _get_dim_for_cross(x: _C.Value, dim: Optional[int]):
return dim
@_beartype.beartype
def _unimplemented(op: str, msg: str, value: Optional[_C.Value] = None) -> None:
# For BC reasons, the behavior for Caffe2 does not raise exception for unimplemented operators
if _C_onnx._CAFFE2_ATEN_FALLBACK:
@ -571,6 +604,7 @@ def _unimplemented(op: str, msg: str, value: Optional[_C.Value] = None) -> None:
_onnx_unsupported(f"{op}, {msg}", value)
@_beartype.beartype
def _onnx_unsupported(op_name: str, value: Optional[_C.Value] = None) -> NoReturn:
message = (
f"Unsupported: ONNX export of operator {op_name}. "
@ -585,6 +619,7 @@ def _onnx_unsupported(op_name: str, value: Optional[_C.Value] = None) -> NoRetur
raise errors.OnnxExporterError(message)
@_beartype.beartype
def _onnx_opset_unsupported(
op_name: str,
current_opset: int,
@ -603,6 +638,7 @@ def _onnx_opset_unsupported(
raise errors.OnnxExporterError(message)
@_beartype.beartype
def _onnx_opset_unsupported_detailed(
op_name: str,
current_opset: int,
@ -622,6 +658,7 @@ def _onnx_opset_unsupported_detailed(
raise errors.OnnxExporterError(message)
@_beartype.beartype
def _block_list_in_opset(name: str):
def symbolic_fn(*args, **kwargs):
raise errors.OnnxExporterError(
@ -633,6 +670,7 @@ def _block_list_in_opset(name: str):
return symbolic_fn
@_beartype.beartype
def _try_get_scalar_type(*args) -> Optional[str]:
for arg in args:
try:
@ -642,6 +680,7 @@ def _try_get_scalar_type(*args) -> Optional[str]:
return None
@_beartype.beartype
def _select_helper(g, self, dim, index, apply_reshape=True):
index_const = _maybe_get_scalar(index)
index_dim = _get_tensor_rank(index)
@ -661,6 +700,7 @@ def _select_helper(g, self, dim, index, apply_reshape=True):
return g.op("Gather", self, index, axis_i=dim)
@_beartype.beartype
def _slice_helper(g, input, axes, starts, ends, steps=None, dynamic_slice=False):
if GLOBALS.export_onnx_opset_version <= 9:
from torch.onnx.symbolic_opset9 import _slice as _slice9
@ -672,6 +712,7 @@ def _slice_helper(g, input, axes, starts, ends, steps=None, dynamic_slice=False)
return _slice10(g, input, axes, starts, ends, steps, dynamic_slice)
@_beartype.beartype
def _is_in_type_group(value, scalar_types: Set[_type_utils.JitScalarType]) -> bool:
"""Helper function for determining if a value is in a scalar type group."""
if value is None:
@ -696,6 +737,7 @@ def _is_in_type_group(value, scalar_types: Set[_type_utils.JitScalarType]) -> bo
return False
@_beartype.beartype
def _is_fp(value) -> bool:
return _is_in_type_group(
value,
@ -708,10 +750,12 @@ def _is_fp(value) -> bool:
)
@_beartype.beartype
def _is_bool(value) -> bool:
return _is_in_type_group(value, {_type_utils.JitScalarType.BOOL})
@_beartype.beartype
def _generate_wrapped_number(g, scalar):
"""Creates a wrapped number based on https://github.com/pytorch/pytorch/issues/9515.
@ -730,6 +774,7 @@ def _generate_wrapped_number(g, scalar):
return g.op("Constant", value_t=torch.tensor(scalar))
@_beartype.beartype
def _sort_helper(g, input, dim, decending=True, out=None):
if out is not None:
_unimplemented("Sort", "Out parameter is not supported")
@ -749,6 +794,7 @@ def _sort_helper(g, input, dim, decending=True, out=None):
)
@_beartype.beartype
def _topk_helper(g, input, k, dim, largest=True, sorted=False, out=None):
if out is not None:
_unimplemented("TopK", "Out parameter is not supported")
@ -768,6 +814,7 @@ def _topk_helper(g, input, k, dim, largest=True, sorted=False, out=None):
)
@_beartype.beartype
def _lt_helper(g, input, other):
if GLOBALS.export_onnx_opset_version <= 8:
from torch.onnx.symbolic_opset8 import lt as _lt8
@ -779,6 +826,7 @@ def _lt_helper(g, input, other):
return _lt9(g, input, other)
@_beartype.beartype
def _interpolate_warning(interpolate_mode):
onnx_op = (
"onnx:Resize" if GLOBALS.export_onnx_opset_version >= 10 else "onnx:Upsample"
@ -796,6 +844,7 @@ def _interpolate_warning(interpolate_mode):
)
@_beartype.beartype
def _unsqueeze_helper(g, input, axes_i):
if _is_constant(axes_i[0]):
if GLOBALS.export_onnx_opset_version >= 13:
@ -810,6 +859,7 @@ def _unsqueeze_helper(g, input, axes_i):
return g.op("Unsqueeze", input, axes_i[0])
@_beartype.beartype
def _squeeze_helper(g, input, axes_i):
if _is_constant(axes_i[0]):
if GLOBALS.export_onnx_opset_version >= 13:
@ -835,6 +885,7 @@ def _squeeze_helper(g, input, axes_i):
return g.op("Squeeze", input, axes_t)
@_beartype.beartype
def _reducesum_helper(g, input, axes_i=None, keepdims_i=1, noop_with_empty_axes_i=0):
keepdims_i = _maybe_get_const(keepdims_i, "i")
if GLOBALS.export_onnx_opset_version >= 13:
@ -860,6 +911,7 @@ def _reducesum_helper(g, input, axes_i=None, keepdims_i=1, noop_with_empty_axes_
return g.op("ReduceSum", input, axes_i=axes_i, keepdims_i=keepdims_i)
@_beartype.beartype
def _interpolate_size_to_scales(g, input, output_size, dim):
output_size = _maybe_get_const(output_size, "is")
if _is_value(output_size):
@ -886,6 +938,7 @@ def _interpolate_size_to_scales(g, input, output_size, dim):
return scales
@_beartype.beartype
def _interpolate_get_scales_if_available(g, scales):
available_scales = _maybe_get_const(scales[0], "fs") != -1 and not _is_none(
scales[0]
@ -902,6 +955,7 @@ def _interpolate_get_scales_if_available(g, scales):
return scales
@_beartype.beartype
def _get_interpolate_attributes(g, mode, args):
if mode == "nearest":
align_corners = None
@ -913,6 +967,7 @@ def _get_interpolate_attributes(g, mode, args):
return scales, align_corners
@_beartype.beartype
def _interpolate_get_scales(g, scale_factor, dim):
offsets = g.op("Constant", value_t=torch.ones(2, dtype=torch.float32))
scale_factor_rank = _get_tensor_rank(scale_factor)
@ -930,6 +985,7 @@ def _interpolate_get_scales(g, scale_factor, dim):
return scale_factor
@_beartype.beartype
def _interpolate_get_scales_and_mode(g, input, size, scale_factor, mode, align_corners):
mode = _maybe_get_const(mode, "s")
if "linear" in mode:
@ -963,8 +1019,9 @@ def _interpolate_get_scales_and_mode(g, input, size, scale_factor, mode, align_c
return scale_factor, mode
@_beartype.beartype
def _argmin_argmax_helper(
g, input: torch._C.Value, dim: torch._C.Value, keepdim: int, op_name: str
g, input: torch._C.Value, dim: torch._C.Value, keepdim: bool, op_name: str
):
def op_wrapper(input, axis_i, keepdims_i):
if GLOBALS.export_onnx_opset_version >= 12:
@ -997,6 +1054,7 @@ def _argmin_argmax_helper(
return op_wrapper(input, axis_i=dim, keepdims_i=keepdim)
@_beartype.beartype
def _interpolate_helper(name, dim, interpolate_mode):
@quantized_args(True, False, False)
def symbolic_fn(g, input, output_size, *args):
@ -1064,6 +1122,7 @@ def _interpolate_helper(name, dim, interpolate_mode):
return symbolic_fn
@_beartype.beartype
def __interpolate_helper(
g, input, size, scale_factor, mode, align_corners, recompute_scale_factor
):
@ -1157,6 +1216,7 @@ def __interpolate_helper(
) # only valid when mode="nearest"
@_beartype.beartype
def _unbind_helper(g, self, dim, _outputs):
if GLOBALS.export_onnx_opset_version < 11:
from torch.onnx.symbolic_opset9 import unbind
@ -1167,6 +1227,7 @@ def _unbind_helper(g, self, dim, _outputs):
return unbind(g, self, dim, _outputs)
@_beartype.beartype
def _scatter_helper(g, self, dim, index, src):
if GLOBALS.export_onnx_opset_version <= 10:
from torch.onnx.symbolic_opset9 import scatter
@ -1176,6 +1237,7 @@ def _scatter_helper(g, self, dim, index, src):
return scatter(g, self, dim, index, src)
@_beartype.beartype
def _repeat_interleave_split_helper(g, self, reps, dim):
if GLOBALS.export_onnx_opset_version <= 12:
split_out = g.op("Split", self, split_i=[1] * reps, axis_i=dim, outputs=reps)
@ -1187,6 +1249,7 @@ def _repeat_interleave_split_helper(g, self, reps, dim):
return split_out if reps > 1 else [split_out]
@_beartype.beartype
def _arange_cast_helper(
g, end, start=None, step=None, dtype=None
) -> Tuple[
@ -1228,6 +1291,7 @@ def _arange_cast_helper(
return scalar_type, end, start, step
@_beartype.beartype
def _arange_helper(g, *args):
if GLOBALS.export_onnx_opset_version <= 10:
from torch.onnx.symbolic_opset9 import arange
@ -1236,6 +1300,7 @@ def _arange_helper(g, *args):
return arange(g, *args)
@_beartype.beartype
def _size_helper(g, self, dim):
full_shape = g.op("Shape", self)
from torch.onnx.symbolic_opset9 import select
@ -1243,6 +1308,7 @@ def _size_helper(g, self, dim):
return select(g, full_shape, g.op("Constant", value_t=torch.tensor([0])), dim)
@_beartype.beartype
def _index_fill_reshape_helper(g, self, dim, index):
# 1. reshape index => [1, ..., 1, dim, 1, ..., 1]
# 2. expand index => [..., dim, ...], same shape as self except for dim.
@ -1276,6 +1342,7 @@ def _index_fill_reshape_helper(g, self, dim, index):
# allowzero=1 indicates that if any value in the 'shape' input is set to zero,
# the zero value is honored, similar to NumPy.
# allowzero=1 is only supported for opset version >= 14.
@_beartype.beartype
def _reshape_helper(g, input, shape, allowzero=0):
shape = _maybe_get_const(shape, "is")
if not _is_value(shape):
@ -1290,6 +1357,7 @@ def _reshape_helper(g, input, shape, allowzero=0):
return g.op("Reshape", input, shape, allowzero_i=allowzero)
@_beartype.beartype
def _batchnorm_helper(g, input, weight, bias, running_mean, running_var):
from torch.onnx.symbolic_opset9 import _var_mean
@ -1349,6 +1417,7 @@ def _batchnorm_helper(g, input, weight, bias, running_mean, running_var):
return weight, bias, running_mean, running_var
@_beartype.beartype
def _avgpool_helper(
tuple_fn: Callable[[Any], Sequence[int]],
padding: Union[int, Sequence[int]],
@ -1362,6 +1431,7 @@ def _avgpool_helper(
return tuple(tuple_fn(padding))
@_beartype.beartype
def check_training_mode(op_train_mode: int, op_name: str) -> None:
"""Warns the user if the model's training mode and the export mode do not agree."""
if GLOBALS.training_mode == _C_onnx.TrainingMode.PRESERVE:
@ -1387,6 +1457,7 @@ def check_training_mode(op_train_mode: int, op_name: str) -> None:
)
@_beartype.beartype
def _flatten_helper(g, input, start_dim, end_dim, dim):
input_size = g.op("Shape", input)
slice1 = _slice_helper(g, input_size, axes=[0], starts=[0], ends=[start_dim])
@ -1407,6 +1478,7 @@ def _flatten_helper(g, input, start_dim, end_dim, dim):
return _reshape_from_tensor(g, input, final_shape)
@_beartype.beartype
def _is_split_static(split_size_or_sizes, _outputs):
if _outputs is None:
return False
@ -1418,12 +1490,14 @@ def _is_split_static(split_size_or_sizes, _outputs):
return True
@_beartype.beartype
def _optional_input_placeholder_tensor(g):
n = g.op("prim::Constant")
n.setType(_C.OptionalType.ofTensor())
return n
@_beartype.beartype
def _handle_reduce_dim_none(g, self, op_name):
rank = _get_tensor_rank(self)
if rank is not None and any(
@ -1435,10 +1509,11 @@ def _handle_reduce_dim_none(g, self, op_name):
return g.op(op_name, self, keepdims_i=0)
@_beartype.beartype
def dequantize_helper(
g,
qtensor: _C.Value,
qdtype: Optional[torch.onnx.TensorProtoDataType] = None,
qdtype: Optional[_C_onnx.TensorProtoDataType] = None,
) -> Tuple[_C.Value, _C.Value, _C.Value, Optional[_C.Value]]:
"""Appends to graph `g` ONNX nodes that dequantizes `qtensor` into `tensor`.
@ -1485,6 +1560,7 @@ def dequantize_helper(
)
@_beartype.beartype
def quantize_helper(
g,
tensor: _C.Value,
@ -1538,6 +1614,7 @@ def quantize_helper(
return g.op("prim::TupleConstruct", *args)
@_beartype.beartype
def requantize_bias_helper(g, bias, input_scale, weight_scale, axis=None):
"""In PyTorch, bias is float and is quantized to int32 implicitly inside the quantized ATen op kernel.
In ONNX we need to make the quantization explicit because operators expect all of their inputs to be quantized.
@ -1558,6 +1635,7 @@ def requantize_bias_helper(g, bias, input_scale, weight_scale, axis=None):
return g.op("prim::TupleConstruct", q_bias, bias_scale, bias_zero_point, *axis_args)
@_beartype.beartype
def args_have_same_dtype(args):
assert args
base_dtype = args[0].type().scalarType()

View File

@ -16,6 +16,7 @@ from torch.onnx import ( # noqa: F401
symbolic_opset9 as opset9,
)
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import _beartype
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py
@ -58,6 +59,7 @@ __all__ = [
]
@_beartype.beartype
def div(g, self, other, *args):
if len(args) == 0:
return opset9.true_divide(g, self, other)
@ -66,6 +68,7 @@ def div(g, self, other, *args):
@symbolic_helper.parse_args("v", "v", "s")
@_beartype.beartype
def _div_rounding_mode(g, self, other, rounding_mode):
if rounding_mode == "floor":
return _floor_divide(g, self, other)
@ -73,6 +76,7 @@ def _div_rounding_mode(g, self, other, rounding_mode):
return opset9._div_rounding_mode(g, self, other, rounding_mode)
@_beartype.beartype
def _floor_divide(g, self, other):
if symbolic_helper._is_fp(self) or symbolic_helper._is_fp(other):
out = opset9.true_divide(g, self, other)
@ -94,20 +98,24 @@ def _floor_divide(g, self, other):
@symbolic_helper.parse_args("v", "i", "i", "none")
@_beartype.beartype
def sort(g, self, dim, decending, out=None):
return symbolic_helper._sort_helper(g, self, dim, decending=decending, out=out)
@symbolic_helper.parse_args("v", "v", "i", "i", "i", "none")
@_beartype.beartype
def topk(g, self, k, dim, largest, sorted, out=None):
return symbolic_helper._topk_helper(
g, self, k, dim, largest=largest, sorted=sorted, out=out
)
@_beartype.beartype
def _max_pool(name, tuple_fn, ndims, return_indices):
@symbolic_helper.quantized_args(True, False, False, False, False, False)
@symbolic_helper.parse_args("v", "is", "is", "is", "is", "i")
@_beartype.beartype
def symbolic_fn(g, input, kernel_size, stride, padding, dilation, ceil_mode):
if not stride:
stride = kernel_size
@ -178,9 +186,11 @@ max_pool3d_with_indices = _max_pool(
)
@_beartype.beartype
def _avg_pool(name, tuple_fn):
@symbolic_helper.quantized_args(True, False, False, False, False, False, False)
@symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none")
@_beartype.beartype
def symbolic_fn(
g,
input: _C.Value,
@ -196,6 +206,7 @@ def _avg_pool(name, tuple_fn):
padding = symbolic_helper._avgpool_helper(
tuple_fn, padding, kernel_size, stride, divisor_override, name
)
assert isinstance(padding, tuple)
if count_include_pad:
input = opset9.op_with_optional_float_cast(
g,
@ -225,8 +236,10 @@ avg_pool2d = _avg_pool("avg_pool2d", torch.nn.modules.utils._pair)
avg_pool3d = _avg_pool("avg_pool3d", torch.nn.modules.utils._triple)
@_beartype.beartype
def _interpolate(name, dim, interpolate_mode):
@symbolic_helper.quantized_args(True, False, False)
@_beartype.beartype
def symbolic_fn(g, input, output_size, *args):
scales, align_corners = symbolic_helper._get_interpolate_attributes(
g, interpolate_mode, args
@ -252,6 +265,7 @@ upsample_bilinear2d = _interpolate("upsample_bilinear2d", 4, "linear")
upsample_trilinear3d = _interpolate("upsample_trilinear3d", 5, "linear")
@_beartype.beartype
def __interpolate(
g, input, size, scale_factor, mode, align_corners, recompute_scale_factor, antialias
):
@ -261,6 +275,7 @@ def __interpolate(
return g.op("Resize", input, scales, mode_s=mode)
@_beartype.beartype
def _slice(g, input, axes, starts, ends, steps=None, dynamic_slice=False):
if dynamic_slice:
starts = symbolic_helper._unsqueeze_helper(g, starts, [0])
@ -288,6 +303,7 @@ def _slice(g, input, axes, starts, ends, steps=None, dynamic_slice=False):
return g.op("Slice", input, starts, ends, axes, steps)
@_beartype.beartype
def slice(g, self, *args):
if len(args) == 4:
# aten::slice(Tensor self, int dim, int? start=None, int? end=None, int step=1) -> Tensor
@ -336,6 +352,7 @@ def slice(g, self, *args):
@symbolic_helper.parse_args("v", "is")
@_beartype.beartype
def flip(g, input, dims):
return symbolic_helper._slice_helper(
g,
@ -347,11 +364,13 @@ def flip(g, input, dims):
)
@_beartype.beartype
def fmod(g, input, other):
return g.op("Mod", input, other, fmod_i=1)
@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i")
@_beartype.beartype
def embedding_bag(
g,
embedding_matrix,
@ -437,6 +456,7 @@ def embedding_bag(
@symbolic_helper.parse_args("v", "v", "v", "i", "i")
@_beartype.beartype
def fake_quantize_per_tensor_affine(
g, inputs, scale, zero_point, quant_min=-128, quant_max=127
):
@ -478,10 +498,12 @@ def fake_quantize_per_tensor_affine(
)
@_beartype.beartype
def isinf(g, input):
return g.op("IsInf", opset9._cast_Double(g, input, False)) # type: ignore[attr-defined]
@_beartype.beartype
def isfinite(g, input):
from torch.onnx.symbolic_opset9 import __not_, __or_
@ -490,6 +512,7 @@ def isfinite(g, input):
return __not_(g, __or_(g, inf_node, nan_node))
@_beartype.beartype
def quantize_per_tensor(g, input, scale, zero_point, dtype):
dtype = symbolic_helper._get_const(dtype, "i", "dtype")
# TODO(justinchuby): Extract all the cast ops into a helper function.
@ -500,11 +523,13 @@ def quantize_per_tensor(g, input, scale, zero_point, dtype):
return symbolic_helper.quantize_helper(g, input, scale, zero_point)
@_beartype.beartype
def dequantize(g, input):
return symbolic_helper.dequantize_helper(g, input)[0]
@symbolic_helper.parse_args("v", "f", "f", "f")
@_beartype.beartype
def nan_to_num(g, input, nan, posinf, neginf):
# Cannot create a int type tensor with inf/nan values, so we simply
# return the original tensor
@ -566,6 +591,7 @@ class Quantized:
domain = "quantized"
@staticmethod
@_beartype.beartype
def linear(g, q_input, q_weight, bias, op_scale, op_zero_point):
input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
@ -579,6 +605,7 @@ class Quantized:
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
@staticmethod
@_beartype.beartype
def add(g, x, y, op_scale, op_zero_point):
x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
y, _, _, _ = symbolic_helper.dequantize_helper(g, y)
@ -588,6 +615,7 @@ class Quantized:
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
@staticmethod
@_beartype.beartype
def add_relu(g, x, y, op_scale, op_zero_point):
x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
y, _, _, _ = symbolic_helper.dequantize_helper(g, y)
@ -598,6 +626,7 @@ class Quantized:
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
@staticmethod
@_beartype.beartype
def mul(g, x, y, op_scale, op_zero_point):
x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
y, _, _, _ = symbolic_helper.dequantize_helper(g, y)
@ -607,6 +636,7 @@ class Quantized:
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
@staticmethod
@_beartype.beartype
def hardswish(g, x, op_scale, op_zero_point):
x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
@ -615,6 +645,7 @@ class Quantized:
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
@staticmethod
@_beartype.beartype
def sigmoid(g, x, op_scale, op_zero_point):
x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
@ -623,6 +654,7 @@ class Quantized:
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
@staticmethod
@_beartype.beartype
def leaky_relu(g, x, negative_slope, inplace, op_scale, op_zero_point):
x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
@ -631,6 +663,7 @@ class Quantized:
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
@staticmethod
@_beartype.beartype
def layer_norm(g, x, normalized_shape, weight, bias, eps, op_scale, op_zero_point):
x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
@ -639,6 +672,7 @@ class Quantized:
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
@staticmethod
@_beartype.beartype
def group_norm(g, x, num_groups, weight, bias, eps, op_scale, op_zero_point):
x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
@ -648,6 +682,7 @@ class Quantized:
@staticmethod
@symbolic_helper.parse_args("v", "v", "v", "f", "v", "v")
@_beartype.beartype
def instance_norm(
g,
q_input,
@ -660,12 +695,13 @@ class Quantized:
input, _, _, _ = symbolic_helper.dequantize_helper(g, q_input)
output = opset9.instance_norm(
g, input, weight, bias, None, None, False, 0, eps, False
g, input, weight, bias, None, None, False, 0.0, eps, False
)
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
@staticmethod
@_beartype.beartype
def conv2d_relu(
g,
q_input,
@ -693,6 +729,7 @@ class Quantized:
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
@staticmethod
@_beartype.beartype
def conv2d(
g,
q_input,
@ -720,6 +757,7 @@ class Quantized:
@staticmethod
@symbolic_helper.parse_args("v", "i", "v", "v")
@_beartype.beartype
def cat(
g,
q_inputs: _C.Value,

View File

@ -2,7 +2,7 @@
import sys
import warnings
from typing import Tuple, Union
from typing import Optional, Sequence, Union
import torch
from torch import _C
@ -16,6 +16,7 @@ from torch.onnx import (
utils,
)
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import _beartype
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py
@ -94,6 +95,7 @@ __all__ = [
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "f", "f")
@_beartype.beartype
def hardtanh(g, self: _C.Value, min_val: float, max_val: float):
dtype = self.type().scalarType()
if dtype is None:
@ -113,9 +115,11 @@ def hardtanh(g, self: _C.Value, min_val: float, max_val: float):
)
@_beartype.beartype
def clamp(g, self, min, max):
dtype = self.type().scalarType()
@_beartype.beartype
def _cast_if_not_none(tensor, dtype):
if tensor is not None and not symbolic_helper._is_none(tensor):
return g.op(
@ -147,6 +151,7 @@ def clamp(g, self, min, max):
@symbolic_helper.parse_args("v", "v")
@_beartype.beartype
def clamp_min(g, self, min):
dtype = self.type().scalarType()
min = g.op("Cast", min, to_i=_type_utils.JitScalarType.from_name(dtype).onnx_type())
@ -160,6 +165,7 @@ def clamp_min(g, self, min):
@symbolic_helper.parse_args("v", "v")
@_beartype.beartype
def clamp_max(g, self, max):
dtype = self.type().scalarType()
max = g.op("Cast", max, to_i=_type_utils.JitScalarType.from_name(dtype).onnx_type())
@ -172,6 +178,7 @@ def clamp_max(g, self, max):
return opset9.op_with_optional_float_cast(g, "Min", self, max, opset_before=12)
@_beartype.beartype
def relu6(g, input):
relu_ = opset9.op_with_optional_float_cast(g, "Relu", input, opset_before=14)
dtype = input.type().scalarType()
@ -193,10 +200,12 @@ def relu6(g, input):
# Opset 11 gather accepts negative indices
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "i", "v")
@_beartype.beartype
def select(g, self, dim, index):
return g.op("Gather", self, index, axis_i=dim)
@_beartype.beartype
def index_put(g, self, indices_list_value, values, accumulate=False):
if symbolic_helper._is_packed_list(indices_list_value):
indices_list = symbolic_helper._unpack_list(indices_list_value)
@ -307,6 +316,7 @@ def index_put(g, self, indices_list_value, values, accumulate=False):
@symbolic_helper.parse_args("v", "i")
@_beartype.beartype
def pixel_shuffle(g, self, upscale_factor):
rank = symbolic_helper._get_tensor_rank(self)
if rank is not None and rank != 4:
@ -314,6 +324,7 @@ def pixel_shuffle(g, self, upscale_factor):
return g.op("DepthToSpace", self, blocksize_i=upscale_factor, mode_s="CRD")
@_beartype.beartype
def _interpolate(name, dim, interpolate_mode):
return symbolic_helper._interpolate_helper(name, dim, interpolate_mode)
@ -336,6 +347,7 @@ upsample_bicubic2d.__module__ = "torch.onnx.symbolic_opset11"
@symbolic_helper.quantized_args(True, False, False, False, False, False, False)
@_beartype.beartype
def __interpolate(
g, input, size, scale_factor, mode, align_corners, recompute_scale_factor, antialias
):
@ -345,6 +357,7 @@ def __interpolate(
@symbolic_helper.parse_args("v", "i", "v", "v")
@_beartype.beartype
def gather(g, self, dim, index, sparse_grad=False):
if symbolic_helper._maybe_get_const(sparse_grad, "i"):
return symbolic_helper._unimplemented("gather", "sparse_grad == True")
@ -354,6 +367,7 @@ def gather(g, self, dim, index, sparse_grad=False):
@symbolic_helper.parse_args("v", "i", "v", "v")
@_beartype.beartype
def scatter(g, self, dim, index, src):
if symbolic_helper.is_caffe2_aten_fallback():
return g.at("scatter", self, dim, index, src, overload_name="src")
@ -378,6 +392,7 @@ def scatter(g, self, dim, index, src):
@symbolic_helper.parse_args("v", "i", "none")
@_beartype.beartype
def cumsum(g, self, dim, dtype=None):
dim_tensor = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.int))
if dtype and dtype.node().kind() != "prim::Constant":
@ -391,11 +406,13 @@ def cumsum(g, self, dim, dtype=None):
return csum
@_beartype.beartype
def masked_select(g, self, mask):
index = opset9.nonzero(g, opset9.expand_as(g, mask, self))
return g.op("GatherND", self, index)
@_beartype.beartype
def masked_scatter(g, self, mask, source):
index = opset9.nonzero(g, opset9.expand_as(g, mask, self))
# NOTE: source can have more elements than needed.
@ -413,6 +430,7 @@ def masked_scatter(g, self, mask, source):
return g.op("ScatterND", self, index, source)
@_beartype.beartype
def _len(g, self):
if (
symbolic_helper._is_tensor_list(self)
@ -423,6 +441,7 @@ def _len(g, self):
return symbolic_helper._squeeze_helper(g, sz_0, [0])
@_beartype.beartype
def __getitem_(g, self, i):
if symbolic_helper._is_tensor_list(self):
# SequenceAt requires that the input be a List of Tensors
@ -433,15 +452,18 @@ def __getitem_(g, self, i):
return getitem(g, self, i)
@_beartype.beartype
def _set_item(g, tensor_list, i, v):
tensor_list = g.op("SequenceErase", tensor_list, i)
return g.op("SequenceInsert", tensor_list, v, i)
@_beartype.beartype
def append(g, self, tensor):
return g.op("SequenceInsert", self, tensor)
@_beartype.beartype
def add(g, self, other, alpha=None):
if symbolic_helper._is_value(self) and symbolic_helper._is_tensor_list(self):
tensor_list_node = other.node()
@ -458,18 +480,22 @@ def add(g, self, other, alpha=None):
return opset9.add(g, self, other, alpha)
@_beartype.beartype
def insert(g, self, pos, tensor):
return g.op("SequenceInsert", self, tensor, pos)
@_beartype.beartype
def pop(g, tensor_list, dim):
return g.op("SequenceErase", tensor_list, dim)
@_beartype.beartype
def Delete(g, tensor_list, dim):
return g.op("SequenceErase", tensor_list, dim)
@_beartype.beartype
def cat(g, tensor_list, dim):
if symbolic_helper._is_packed_list(tensor_list):
return opset9.cat(g, tensor_list, dim)
@ -478,6 +504,7 @@ def cat(g, tensor_list, dim):
return g.op("ConcatFromSequence", tensor_list, axis_i=dim)
@_beartype.beartype
def stack(g, tensor_list, dim):
if symbolic_helper._is_packed_list(tensor_list):
return opset9.stack(g, tensor_list, dim)
@ -487,6 +514,7 @@ def stack(g, tensor_list, dim):
@symbolic_helper.parse_args("v", "i", "i", "i")
@_beartype.beartype
def _unique2(g, self, sorted, return_inverse, return_counts):
u, indices, inverse_indices, counts = g.op(
"Unique", self, sorted_i=sorted, outputs=4
@ -494,15 +522,17 @@ def _unique2(g, self, sorted, return_inverse, return_counts):
return u, inverse_indices, counts
@_beartype.beartype
def _avg_pool(name, tuple_fn):
@symbolic_helper.quantized_args(True, False, False, False, False, False, False)
@symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none")
@_beartype.beartype
def symbolic_fn(
g,
input: _C.Value,
kernel_size: Tuple[int, ...],
stride: Tuple[int, ...],
padding: Union[int, Tuple[int, ...]],
kernel_size: Sequence[int],
stride: Sequence[int],
padding: Union[int, Sequence[int]],
ceil_mode: int,
count_include_pad: int,
divisor_override=None,
@ -510,6 +540,7 @@ def _avg_pool(name, tuple_fn):
padding = symbolic_helper._avgpool_helper(
tuple_fn, padding, kernel_size, stride, divisor_override, name
)
assert isinstance(padding, tuple)
if not stride:
stride = kernel_size
if count_include_pad:
@ -539,6 +570,7 @@ avg_pool3d = _avg_pool("avg_pool3d", torch.nn.modules.utils._triple)
@symbolic_helper.parse_args("v", "i", "i", "i", "i")
@_beartype.beartype
def unique_dim(g, self, dim, sorted, return_inverse, return_counts):
u, indices, inverse_indices, counts = g.op(
"Unique", self, axis_i=dim, sorted_i=sorted, outputs=4
@ -547,6 +579,7 @@ def unique_dim(g, self, dim, sorted, return_inverse, return_counts):
@symbolic_helper.parse_args("v", "v", "i", "i", "i", "none")
@_beartype.beartype
def topk(g, self, k, dim, largest, sorted, out=None):
return symbolic_helper._topk_helper(
g, self, k, dim, largest=largest, sorted=sorted, out=out
@ -554,11 +587,13 @@ def topk(g, self, k, dim, largest, sorted, out=None):
@symbolic_helper.parse_args("v", "i", "i", "none")
@_beartype.beartype
def sort(g, self, dim, decending, out=None):
return symbolic_helper._sort_helper(g, self, dim, decending=decending, out=out)
@symbolic_helper.parse_args("v", "i", "i", "none")
@_beartype.beartype
def argsort(g, self, dim, decending, out=None):
_, indices = symbolic_helper._sort_helper(
g, self, dim, decending=decending, out=out
@ -566,10 +601,12 @@ def argsort(g, self, dim, decending, out=None):
return indices
@_beartype.beartype
def round(g, self):
return g.op("Round", self)
@_beartype.beartype
def remainder(g, input, other):
if symbolic_helper._is_fp(input) or symbolic_helper._is_fp(other):
return opset9.remainder(g, input, other)
@ -577,6 +614,7 @@ def remainder(g, input, other):
@symbolic_helper.parse_args("v", "v", "i", "i")
@_beartype.beartype
def split(g, self, split_size_or_sizes, dim, _outputs=None):
if not symbolic_helper._is_split_static(split_size_or_sizes, _outputs):
split_out = g.op("SplitToSequence", self, split_size_or_sizes, axis_i=dim)
@ -614,11 +652,13 @@ def split(g, self, split_size_or_sizes, dim, _outputs=None):
@symbolic_helper.parse_args("v", "v", "i", "i")
@_beartype.beartype
def split_with_sizes(g, self, split_sizes, dim, _outputs=None):
return split(g, self, split_sizes, dim, _outputs)
@symbolic_helper.parse_args("v", "i", "i")
@_beartype.beartype
def unbind(g, self, dim=0, _outputs=None):
if _outputs is None:
return g.op(
@ -638,6 +678,7 @@ def unbind(g, self, dim=0, _outputs=None):
# pad: the paddings in pytorch.
# The order is dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, ..., dim_m_begin, dim_m_end,
# where m is in range [0, n].
@_beartype.beartype
def _prepare_onnx_paddings(g, input, pad):
if (
not symbolic_helper._is_packed_list(pad)
@ -687,6 +728,7 @@ def _prepare_onnx_paddings(g, input, pad):
return padding_c
@_beartype.beartype
def constant_pad_nd(g, input, padding, value=None):
mode = "constant"
value = symbolic_helper._maybe_get_scalar(value)
@ -695,12 +737,14 @@ def constant_pad_nd(g, input, padding, value=None):
return g.op("Pad", input, pad, value, mode_s=mode)
@_beartype.beartype
def reflection_pad(g, input, padding):
mode = "reflect"
paddings = _prepare_onnx_paddings(g, input, padding)
return g.op("Pad", input, paddings, mode_s=mode)
@_beartype.beartype
def replication_pad(g, input, padding):
mode = "edge"
paddings = _prepare_onnx_paddings(g, input, padding)
@ -715,6 +759,7 @@ replication_pad2d = replication_pad
replication_pad3d = replication_pad
@_beartype.beartype
def pad(g, input, pad, mode, value):
mode = symbolic_helper._parse_arg(mode, "s")
if mode == "replicate":
@ -729,15 +774,19 @@ def pad(g, input, pad, mode, value):
raise errors.SymbolicValueError(f"Unrecognized padding mode {mode}", input)
@_beartype.beartype
def linalg_det(g, self):
return g.op("Det", self)
@_beartype.beartype
def logdet(g, input):
return opset9.log(g, linalg_det(g, input))
@_beartype.beartype
def arange(g, *args):
@_beartype.beartype
def _get_arange_dtype(dtype):
dtype = symbolic_helper._maybe_get_const(dtype, "i")
return dtype
@ -790,6 +839,7 @@ def arange(g, *args):
@symbolic_helper.parse_args("v", "i")
@_beartype.beartype
def _dim_arange(g, like, dim):
like_shape = g.op("Shape", like)
stop = g.op(
@ -800,12 +850,14 @@ def _dim_arange(g, like, dim):
return arange(g, stop, 4, None, None, None)
@_beartype.beartype
def size(g, self, dim=None):
if dim is None:
return g.op("Shape", self)
return symbolic_helper._size_helper(g, self, dim)
@_beartype.beartype
def squeeze(g, self, dim=None):
if dim is None:
return g.op("Squeeze", self)
@ -857,6 +909,7 @@ def squeeze(g, self, dim=None):
return symbolic_helper._squeeze_helper(g, self, [dim])
@_beartype.beartype
def unsqueeze(g, self, dim):
if symbolic_helper._is_constant(dim):
dim = symbolic_helper._get_const(dim, "i", "dim")
@ -864,10 +917,12 @@ def unsqueeze(g, self, dim):
return symbolic_helper._unsqueeze_helper(g, self, [dim])
@_beartype.beartype
def mm(g, self, other):
return g.op("Gemm", self, other, beta_f=0.0, alpha_f=1.0)
@_beartype.beartype
def index(g, self, index):
if symbolic_helper.is_caffe2_aten_fallback():
return g.at("index", self, index, overload_name="Tensor")
@ -888,6 +943,7 @@ def index(g, self, index):
return opset9.index(g, self, index)
@_beartype.beartype
def index_fill(g, self, dim, index, value):
dim_value = symbolic_helper._parse_arg(dim, "i")
if symbolic_helper.is_caffe2_aten_fallback():
@ -909,6 +965,7 @@ def index_fill(g, self, dim, index, value):
return scatter(g, self, dim, expanded_index, expanded_value)
@_beartype.beartype
def index_copy(g, self, dim, index, source):
dim_value = symbolic_helper._parse_arg(dim, "i")
if symbolic_helper.is_caffe2_aten_fallback():
@ -919,6 +976,7 @@ def index_copy(g, self, dim, index, source):
return scatter(g, self, dim, expanded_index, source)
@_beartype.beartype
def __rshift_(g, self, other):
# make sure to cast other to self's type
# (when self is long, make sure that other is not float)
@ -948,6 +1006,7 @@ def __rshift_(g, self, other):
return rshift
@_beartype.beartype
def __lshift_(g, self, other):
# make sure to cast other to self's type
# (when self is long, make sure that other is not float)
@ -977,6 +1036,7 @@ def __lshift_(g, self, other):
return lshift
@_beartype.beartype
def _get_im2col_indices_along_dim(
g, input_d, kernel_size_d, dilation_d, padding_d, stride_d
):
@ -1020,6 +1080,7 @@ def _get_im2col_indices_along_dim(
return block_mask
@_beartype.beartype
def _get_im2col_padded_input(g, input, padding_h, padding_w):
# Input is always 4-D tensor (N, C, H, W)
# Padding tensor has the following format: (padding_h, padding_w)
@ -1028,6 +1089,7 @@ def _get_im2col_padded_input(g, input, padding_h, padding_w):
return g.op("Pad", input, pad)
@_beartype.beartype
def _get_im2col_output_shape(g, input, kernel_h, kernel_w):
batch_dim = size(g, input, g.op("Constant", value_t=torch.tensor(0)))
channel_dim = size(g, input, g.op("Constant", value_t=torch.tensor(1)))
@ -1045,6 +1107,7 @@ def _get_im2col_output_shape(g, input, kernel_h, kernel_w):
@symbolic_helper.parse_args("v", "is", "is", "is", "is")
@_beartype.beartype
def im2col(g, input, kernel_size, dilation, padding, stride):
# Input is always 4-D tensor (N, C, H, W)
# All other args are int[2]
@ -1096,6 +1159,7 @@ def im2col(g, input, kernel_size, dilation, padding, stride):
return symbolic_helper._reshape_helper(g, output, output_shape)
@_beartype.beartype
def narrow(g, input, dim, start, length):
end = g.op("Add", start, length)
return symbolic_helper._slice_helper(
@ -1105,6 +1169,7 @@ def narrow(g, input, dim, start, length):
@symbolic_helper.quantized_args(True, False, False)
@symbolic_helper.parse_args("v", "i", "i")
@_beartype.beartype
def flatten(g, input, start_dim, end_dim):
dim = symbolic_helper._get_tensor_rank(input)
if dim == 1:
@ -1129,14 +1194,17 @@ def flatten(g, input, start_dim, end_dim):
return symbolic_helper._flatten_helper(g, input, start_dim, end_dim, dim)
@symbolic_helper.parse_args("v", "f", "is", "i", "v")
def linalg_vector_norm(g, self, ord, dim, keepdim, dtype):
@symbolic_helper.parse_args("v", "f", "is", "b", "v")
@_beartype.beartype
def linalg_vector_norm(
g, self, ord, dim: Optional[Sequence[int]], keepdim: bool, dtype
):
if ord == 0:
if dim is None:
self = symbolic_helper._reshape_helper(
g, self, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))
)
keepdim = 0
keepdim = False
cond_op = g.op(
"Not", g.op("Equal", self, g.op("Constant", value_t=torch.LongTensor([0])))
@ -1156,6 +1224,7 @@ def linalg_vector_norm(g, self, ord, dim, keepdim, dtype):
@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i")
@_beartype.beartype
def embedding_bag(
g,
embedding_matrix,
@ -1243,6 +1312,7 @@ def embedding_bag(
@symbolic_helper.parse_args("v", "v", "f", "f")
@_beartype.beartype
def embedding_renorm(g, weight, indices, max_norm, norm_type):
unique_indices = g.op("Unique", indices)
partial_weight = g.op("Gather", weight, unique_indices)
@ -1280,6 +1350,7 @@ def embedding_renorm(g, weight, indices, max_norm, norm_type):
)
@_beartype.beartype
def chunk(g, self, chunks, dim):
# Calculate chunk size for dynamic chunk
dim_size = g.op("Gather", g.op("Shape", self), dim, axis_i=0)
@ -1296,6 +1367,7 @@ def chunk(g, self, chunks, dim):
return split(g, self, chunk_vec, dim)
@_beartype.beartype
def normal(
g,
mean,
@ -1322,6 +1394,7 @@ class Prim:
domain = "prim"
@staticmethod
@_beartype.beartype
def ConstantChunk(g, self, chunks, dim):
input_shape = g.op("Shape", self)
axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long))

View File

@ -10,6 +10,7 @@ from torch.onnx import (
symbolic_opset9 as opset9,
utils,
)
from torch.onnx._internal import _beartype
# EDITING THIS FILE? READ THIS FIRST!
@ -38,6 +39,7 @@ __all__ = [
]
@_beartype.beartype
def _einsum_helper(g, equation, tensors):
if not tensors:
raise RuntimeError("Einsum inputs are empty.")
@ -57,12 +59,14 @@ def _einsum_helper(g, equation, tensors):
@symbolic_helper.parse_args("s", "v")
@_beartype.beartype
def einsum(g, equation, tensor_list):
tensors = symbolic_helper._unpack_list(tensor_list)
return _einsum_helper(g, equation, tensors)
@symbolic_helper.parse_args("v", "v")
@_beartype.beartype
def outer(g, input, other):
# make sure to cast other to self's type
if other.type().scalarType() != input.type().scalarType():
@ -76,6 +80,7 @@ def outer(g, input, other):
return _einsum_helper(g, "i,j->ij", [input, other])
@_beartype.beartype
def _dropout_returns_masked_input_and_mask(
g, input: torch._C.Value, p: float, train: bool
) -> Tuple[torch._C.Value, Optional[torch._C.Value]]:
@ -90,17 +95,20 @@ def _dropout_returns_masked_input_and_mask(
return r, mask
@symbolic_helper.parse_args("v", "f", "i")
@symbolic_helper.parse_args("v", "f", "b")
@_beartype.beartype
def dropout(g, input, p, train):
masked, _ = _dropout_returns_masked_input_and_mask(g, input, p, train)
return masked
@symbolic_helper.parse_args("v", "f", "i")
@symbolic_helper.parse_args("v", "f", "b")
@_beartype.beartype
def native_dropout(g, input, p, train):
return _dropout_returns_masked_input_and_mask(g, input, p, train)
@_beartype.beartype
def nll_loss(g, self, target, weight, reduction, ignore_index):
# none reduction : onnx::Constant[value={0}]
# mean reduction : onnx::Constant[value={1}]
@ -133,14 +141,17 @@ def nll_loss(g, self, target, weight, reduction, ignore_index):
return nllloss
@_beartype.beartype
def nll_loss2d(g, self, target, weight, reduction, ignore_index):
return nll_loss(g, self, target, weight, reduction, ignore_index)
@_beartype.beartype
def nll_loss_nd(g, self, target, weight, reduction, ignore_index):
return nll_loss(g, self, target, weight, reduction, ignore_index)
@_beartype.beartype
def cross_entropy_loss(
g, self, target, weight, reduction, ignore_index, label_smoothing
):
@ -182,6 +193,7 @@ def cross_entropy_loss(
@symbolic_helper.parse_args("v", "v", "v", "v", "i")
@_beartype.beartype
def binary_cross_entropy_with_logits(g, input, target, weight, pos_weight, reduction):
p = g.op("Constant", value_t=torch.tensor([1]))
sig_x = opset9.sigmoid(g, input)
@ -223,6 +235,7 @@ def binary_cross_entropy_with_logits(g, input, target, weight, pos_weight, reduc
)
@_beartype.beartype
def celu(g, self, alpha):
alpha = symbolic_helper._maybe_get_const(alpha, "f")
# if the input is of type double cast it to float
@ -234,29 +247,35 @@ def celu(g, self, alpha):
return g.op("Celu", self, alpha_f=alpha)
@symbolic_helper.parse_args("v", "v", "i")
def argmax(g, input: torch._C.Value, dim: torch._C.Value, keepdim: int):
@symbolic_helper.parse_args("v", "v", "b")
@_beartype.beartype
def argmax(g, input: torch._C.Value, dim: torch._C.Value, keepdim: bool):
return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMax")
@symbolic_helper.parse_args("v", "v", "i")
def argmin(g, input: torch._C.Value, dim: torch._C.Value, keepdim: int):
@symbolic_helper.parse_args("v", "v", "b")
@_beartype.beartype
def argmin(g, input: torch._C.Value, dim: torch._C.Value, keepdim: bool):
return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMin")
@_beartype.beartype
def pow(g, self, exponent):
return g.op("Pow", self, exponent)
@_beartype.beartype
def ge(g, input, other):
return g.op("GreaterOrEqual", input, other)
@_beartype.beartype
def le(g, input, other):
return g.op("LessOrEqual", input, other)
@symbolic_helper.parse_args("v", "i", "v", "v")
@_beartype.beartype
def unfold(g, input, dimension, size, step):
const_size = symbolic_helper._maybe_get_const(size, "i")
const_step = symbolic_helper._maybe_get_const(step, "i")
@ -326,6 +345,7 @@ def unfold(g, input, dimension, size, step):
@symbolic_helper.parse_args("v", "v", "is", "is", "v")
@_beartype.beartype
def tensordot(g, input_a, input_b, dims_a, dims_b, out=None):
if out is not None:
symbolic_helper._unimplemented(

View File

@ -12,9 +12,11 @@ from torch.onnx import (
symbolic_opset9 as opset9,
utils,
)
from torch.onnx._internal import _beartype
@symbolic_helper.parse_args("v", "i", "none")
@_beartype.beartype
def softmax(g, input, dim, dtype=None):
softmax = g.op("Softmax", input, axis_i=dim)
if dtype and dtype.node().kind() != "prim::Constant":
@ -27,6 +29,7 @@ def softmax(g, input, dim, dtype=None):
@symbolic_helper.parse_args("v", "i", "none")
@_beartype.beartype
def log_softmax(g, input, dim, dtype=None):
return_op = g.op("LogSoftmax", input, axis_i=dim)
if dtype and dtype.node().kind() != "prim::Constant":
@ -38,6 +41,7 @@ def log_softmax(g, input, dim, dtype=None):
@symbolic_helper.parse_args("v", "v", "i")
@_beartype.beartype
def frobenius_norm(g, self, dim=None, keepdim=False):
dim_val = symbolic_helper._maybe_get_const(dim, "is")
if not symbolic_helper._is_value(dim_val) and len(dim_val) == 0:
@ -48,6 +52,7 @@ def frobenius_norm(g, self, dim=None, keepdim=False):
@symbolic_helper.parse_args("v", "v", "i", "i")
@_beartype.beartype
def split(g, self, split_size_or_sizes, dim, _outputs=None):
if not symbolic_helper._is_split_static(split_size_or_sizes, _outputs):
split_out = g.op("SplitToSequence", self, split_size_or_sizes, axis_i=dim)
@ -103,19 +108,23 @@ def split(g, self, split_size_or_sizes, dim, _outputs=None):
return g.op("Split", self, splits, axis_i=dim, outputs=_outputs)
@_beartype.beartype
def split_with_sizes(g, self, split_sizes, dim, _outputs=None):
return split(g, self, split_sizes, dim, _outputs)
@_beartype.beartype
def unsafe_split(g, self, split_size_or_sizes, dim, _outputs=None):
return split(g, self, split_size_or_sizes, dim, _outputs)
@_beartype.beartype
def unsafe_split_with_sizes(g, self, split_sizes, dim, _outputs=None):
return split_with_sizes(g, self, split_sizes, dim, _outputs)
@symbolic_helper.parse_args("v", "v", "i", "i")
@_beartype.beartype
def tensor_split(g, self, indices_or_sections, dim, _outputs=None):
axis = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long))
axis = opset11.unsqueeze(g, axis, 0)
@ -247,6 +256,7 @@ def tensor_split(g, self, indices_or_sections, dim, _outputs=None):
@symbolic_helper.parse_args("v", "i", "i")
@_beartype.beartype
def unbind(g, self, dim=0, _outputs=None):
if _outputs is None:
return g.op(
@ -268,11 +278,13 @@ def unbind(g, self, dim=0, _outputs=None):
# Emitted from `torch.nonzero(x, as_tuple=True)`
@_beartype.beartype
def nonzero_numpy(g, input, _outputs=None):
return unbind(g, opset9.nonzero(g, input), 1, _outputs=_outputs)
@symbolic_helper.parse_args("v", "v", "v", "i")
@_beartype.beartype
def where(g, condition, self=None, other=None, _outputs=None):
# Assumes that torch.where's first argument takes only Bool and Byte tensors.
if not symbolic_helper._is_bool(condition):
@ -286,6 +298,7 @@ def where(g, condition, self=None, other=None, _outputs=None):
@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i")
@_beartype.beartype
def fake_quantize_per_channel_affine(
g, inputs, scale, zero_point, axis, quant_min=-128, quant_max=127
):
@ -314,6 +327,7 @@ def fake_quantize_per_channel_affine(
@symbolic_helper.parse_args("v", "v", "v", "i", "i")
@_beartype.beartype
def fake_quantize_per_tensor_affine(
g, inputs, scale, zero_point, quant_min=-128, quant_max=127
):
@ -342,7 +356,9 @@ def fake_quantize_per_tensor_affine(
return g.op("DequantizeLinear", quantized, scale, zero_point)
@_beartype.beartype
def _reduce_op_symbolic(onnx_op_name):
@_beartype.beartype
def symbolic(g, self, dim=None, keepdim=None):
self = opset9._maybe_cast_reduce_op_input(g, self)
if dim is None:
@ -355,12 +371,15 @@ def _reduce_op_symbolic(onnx_op_name):
return symbolic
@_beartype.beartype
def _reduce_with_dtype(onnx_op, name):
symbolic = _reduce_op_symbolic(onnx_op)
@opset9.overload_by_arg_count
@_beartype.beartype
def reduce(g, *args, **kwargs):
@symbolic_helper.parse_args("v", "none")
@_beartype.beartype
def reduce_nodim(g, self, dtype):
if dtype.node().kind() == "onnx::Constant":
dtype = symbolic_helper._get_const(dtype, "i", "dtype")
@ -372,6 +391,7 @@ def _reduce_with_dtype(onnx_op, name):
return symbolic(g, self)
@symbolic_helper.parse_args("v", "v", "i", "none")
@_beartype.beartype
def reduce_dim(g, self, dim, keepdim, dtype):
if dtype.node().kind() == "onnx::Constant":
dtype = symbolic_helper._get_const(dtype, "i", "dtype")
@ -392,6 +412,7 @@ sum = _reduce_with_dtype("ReduceSum", "sum")
@symbolic_helper.parse_args("v", "i", "i", "i")
@_beartype.beartype
def unsafe_chunk(g, self, chunks, dim, _outputs=None):
if _outputs is None:
return g.op(
@ -418,6 +439,7 @@ def unsafe_chunk(g, self, chunks, dim, _outputs=None):
return g.op("Split", self, splits, axis_i=dim, outputs=_outputs)
@_beartype.beartype
def repeat_interleave(g, self, repeats, dim=None, output_size=None):
input = self
final_dim = dim
@ -551,6 +573,7 @@ def repeat_interleave(g, self, repeats, dim=None, output_size=None):
@symbolic_helper.parse_args("v", "i", "i", "i")
@_beartype.beartype
def diagonal(g, self, offset, dim1, dim2):
dim1_size = opset9.size(
g, self, dim=g.op("Constant", value_t=torch.LongTensor([dim1]))
@ -674,6 +697,7 @@ class Quantized:
domain = "quantized"
@staticmethod
@_beartype.beartype
def linear(g, q_input, q_weight, bias, op_scale, op_zero_point):
input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight)
@ -687,6 +711,7 @@ class Quantized:
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
@staticmethod
@_beartype.beartype
def conv2d(
g,
q_input,
@ -713,6 +738,7 @@ class Quantized:
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
@staticmethod
@_beartype.beartype
def conv2d_relu(
g,
q_input,

View File

@ -18,26 +18,31 @@ Updated operators:
import torch
from torch.onnx import symbolic_helper
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import _beartype
@symbolic_helper.parse_args("v")
@_beartype.beartype
def hardswish(g, self):
return g.op("HardSwish", self)
@symbolic_helper.parse_args("v", "i")
@_beartype.beartype
def tril(g, self, diagonal, out=None):
k = g.op("Constant", value_t=torch.tensor(diagonal, dtype=torch.int64))
return g.op("Trilu", self, k, upper_i=0)
@symbolic_helper.parse_args("v", "i")
@_beartype.beartype
def triu(g, self, diagonal, out=None):
k = g.op("Constant", value_t=torch.tensor(diagonal, dtype=torch.int64))
return g.op("Trilu", self, k, upper_i=1)
@symbolic_helper.parse_args("v", "v")
@_beartype.beartype
def reshape(g, self, shape):
# NOTE: Due to bug in ORT https://github.com/microsoft/onnxruntime/issues/10664
# Reshape export cannot utilize the new allowzero attribute introduced in opset 14.
@ -45,6 +50,7 @@ def reshape(g, self, shape):
@symbolic_helper.parse_args("v", "v", "v", "v", "v", "i", "f", "f", "i")
@_beartype.beartype
def batch_norm(
g,
input,
@ -107,6 +113,7 @@ class Quantized:
domain = "quantized"
@staticmethod
@_beartype.beartype
def hardswish(g, x, op_scale, op_zero_point):
x, _, _, _ = symbolic_helper.dequantize_helper(g, x)

View File

@ -28,8 +28,10 @@ Updated operators:
import torch
from torch import _C
from torch.onnx import symbolic_helper, symbolic_opset9 as opset9
from torch.onnx._internal import _beartype
@_beartype.beartype
def __is_(g, self, other):
if symbolic_helper._is_none(other):
if isinstance(self.type(), _C.OptionalType):
@ -41,6 +43,7 @@ def __is_(g, self, other):
@opset9.wrap_logical_op_with_negation
@_beartype.beartype
def __isnot_(g, self, other):
return __is_(g, self, other)
@ -49,6 +52,7 @@ class Prim:
domain = "prim"
@staticmethod
@_beartype.beartype
def unchecked_cast(g, self):
# exists to refine the type of the Value
# if x is Optional[Tensor], unchecked_cast will cast

View File

@ -30,11 +30,13 @@ from torch.nn.functional import (
GRID_SAMPLE_PADDING_MODES,
)
from torch.onnx import _type_utils, symbolic_helper
from torch.onnx._internal import _beartype
# note (mkozuki): Why `grid_sampler` instead of `grid_sample`?
# Because `torch.nn.functional.grid_sample` calls `torch.grid_sampler`.
@symbolic_helper.parse_args("v", "v", "i", "i", "b")
@_beartype.beartype
def grid_sampler(g, input, grid, mode_enum, padding_mode_enum, align_corners):
mode_s = {v: k for k, v in GRID_SAMPLE_INTERPOLATION_MODES.items()}[mode_enum] # type: ignore[call-arg]
padding_mode_s = {v: k for k, v in GRID_SAMPLE_PADDING_MODES.items()}[padding_mode_enum] # type: ignore[call-arg]
@ -49,6 +51,7 @@ def grid_sampler(g, input, grid, mode_enum, padding_mode_enum, align_corners):
@symbolic_helper.parse_args("v", "i", "v", "v")
@_beartype.beartype
def scatter_add(g, self, dim, index, src):
if symbolic_helper.is_caffe2_aten_fallback():
return g.at("scatter", self, dim, index, src, overload_name="src")

File diff suppressed because it is too large

View File

@ -77,6 +77,7 @@ _params_dict = {} # type: ignore[var-annotated]
@contextlib.contextmanager
@_beartype.beartype
def select_model_mode_for_export(model, mode: _C_onnx.TrainingMode):
r"""A context manager to temporarily set the training mode of ``model``
to ``mode``, resetting it when we exit the with-block.
@ -128,6 +129,7 @@ def select_model_mode_for_export(model, mode: _C_onnx.TrainingMode):
@contextlib.contextmanager
@_beartype.beartype
def disable_apex_o2_state_dict_hook(
model: Union[torch.nn.Module, torch.jit.ScriptFunction]
):
@ -161,7 +163,8 @@ def disable_apex_o2_state_dict_hook(
@contextlib.contextmanager
def setup_onnx_logging(verbose):
@_beartype.beartype
def setup_onnx_logging(verbose: bool):
is_originally_enabled = torch.onnx.is_onnx_log_enabled()
if is_originally_enabled or verbose:
torch.onnx.enable_log()
@ -173,7 +176,8 @@ def setup_onnx_logging(verbose):
@contextlib.contextmanager
def exporter_context(model, mode, verbose):
@_beartype.beartype
def exporter_context(model, mode: _C_onnx.TrainingMode, verbose: bool):
with select_model_mode_for_export(
model, mode
) as mode_ctx, disable_apex_o2_state_dict_hook(
@ -498,6 +502,7 @@ def export(
)
@_beartype.beartype
def _is_constant_tensor_list(node):
if node.kind() != "prim::Constant":
return False
@ -512,6 +517,7 @@ def _is_constant_tensor_list(node):
# get generated in constant prop. So we split them back into prim::ListConstructs
@_beartype.beartype
def _split_tensor_list_constants(g, block):
for node in block.nodes():
for subblock in node.blocks():
@ -534,6 +540,7 @@ def _split_tensor_list_constants(g, block):
node.output().replaceAllUsesWith(lc)
@_beartype.beartype
def _optimize_graph(
graph: _C.Graph,
operator_export_type: _C_onnx.OperatorExportTypes,
@ -655,6 +662,7 @@ def _optimize_graph(
return graph
@_beartype.beartype
def warn_on_static_input_change(input_states):
"""Warns that changes to input dictionaries and strings won't take effect in the traced ONNX graph.
@ -684,6 +692,7 @@ def warn_on_static_input_change(input_states):
warnings.warn(warning)
@_beartype.beartype
def _resolve_args_by_export_type(arg_name, arg_value, operator_export_type):
"""Resolves the arguments that are ignored when export_type != operator_export_type.ONNX."""
if (
@ -700,6 +709,7 @@ def _resolve_args_by_export_type(arg_name, arg_value, operator_export_type):
return arg_value
@_beartype.beartype
def _decide_keep_init_as_input(
keep_initializers_as_inputs: Optional[bool],
operator_export_type: _C_onnx.OperatorExportTypes,
@ -743,12 +753,14 @@ def _decide_keep_init_as_input(
return val_keep_init_as_ip
@_beartype.beartype
def _decide_add_node_names(add_node_names, operator_export_type):
return _resolve_args_by_export_type(
"add_node_names", add_node_names, operator_export_type
)
@_beartype.beartype
def _decide_constant_folding(do_constant_folding, operator_export_type, training):
do_constant_folding = _resolve_args_by_export_type(
"do_constant_folding", do_constant_folding, operator_export_type
@ -767,6 +779,7 @@ def _decide_constant_folding(do_constant_folding, operator_export_type, training
return do_constant_folding
@_beartype.beartype
def _signature(model) -> inspect.Signature:
should_be_callable = getattr(model, "forward", model)
if callable(should_be_callable):
@ -774,6 +787,7 @@ def _signature(model) -> inspect.Signature:
raise ValueError("model has no forward method and is not callable")
@_beartype.beartype
def _decide_input_format(model, args):
try:
sig = _signature(model)
@ -813,6 +827,7 @@ def _decide_input_format(model, args):
return args
@_beartype.beartype
def _trace(func, args, operator_export_type, return_outs=False):
# Special case for common case of passing a single Tensor
if isinstance(args, torch.Tensor):
@ -829,6 +844,7 @@ def _trace(func, args, operator_export_type, return_outs=False):
return trace_graph
@_beartype.beartype
def _trace_and_get_graph_from_model(model, args):
# A basic sanity check: make sure the state_dict keys are the same
# before and after running the model. Fail fast!
@ -857,6 +873,7 @@ def _trace_and_get_graph_from_model(model, args):
return trace_graph, torch_out
@_beartype.beartype
def _get_param_count_list(method_graph, args_params):
param_count_list = []
for input_, arg_params_ in zip(method_graph.inputs(), args_params):
@ -869,9 +886,11 @@ def _get_param_count_list(method_graph, args_params):
return param_count_list
@_beartype.beartype
def _check_flatten_did_not_remove(original, jit_flattened):
"""torch.jit._flatten removes None. Check if it did so in this case."""
@_beartype.beartype
def flatten(x):
if isinstance(x, (list, tuple)):
for inner in x:
@ -949,6 +968,7 @@ def _create_jit_graph(
return graph, params, torch_out, None
@_beartype.beartype
def _get_named_param_dict(graph, params):
input_and_param_names = [val.debugName() for val in graph.inputs()]
param_names = input_and_param_names[len(input_and_param_names) - len(params) :]
@ -956,6 +976,7 @@ def _get_named_param_dict(graph, params):
return _params_dict
@_beartype.beartype
def _get_example_outputs(model, args):
input_args = copy.deepcopy(args)
input_kwargs = {}
@ -980,6 +1001,7 @@ _qtype_vtype_map = {
}
@_beartype.beartype
def unpack_quantized_tensor(value, cast_onnx_accepted=True):
if isinstance(value, torch.Tensor) and value.dtype in _qtype_vtype_map:
q_value_dequantize = value.dequantize()
@ -1000,6 +1022,7 @@ def unpack_quantized_tensor(value, cast_onnx_accepted=True):
return (value,)
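As a hedged illustration of what `unpack_quantized_tensor` deals with, the snippet below builds a per-tensor quantized tensor and pulls out the pieces an ONNX graph needs; the exact dtypes the helper casts to are an assumption here.

```python
# Illustrative unpacking of a per-tensor quantized tensor; dtype choices are assumptions.
import torch

q = torch.quantize_per_tensor(torch.randn(4), scale=0.1, zero_point=0, dtype=torch.quint8)
dequantized = q.dequantize()
scale = torch.tensor(q.q_scale(), dtype=torch.float64)
zero_point = torch.tensor(q.q_zero_point(), dtype=torch.int64)
```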
@_beartype.beartype
def _pre_trace_quant_model(model, args):
r"""Returns `torch.jit.trace(model, args)` if model is quantized. Otherwise do nothing and return
original model.
@ -1013,6 +1036,7 @@ def _pre_trace_quant_model(model, args):
return model
@_beartype.beartype
def _model_to_graph(
model,
args,
@ -1028,7 +1052,15 @@ def _model_to_graph(
) -> Tuple[
_C.Graph,
Dict[str, torch.Tensor],
Optional[Union[torch.Tensor, Tuple[torch.Tensor], List[torch.Tensor]]],
Optional[
Union[
torch.Tensor,
Tuple[torch.Tensor, ...],
List[torch.Tensor],
Dict[str, torch.Tensor],
Any, # Can be nested tuples etc.
]
],
]:
"""Converts model into an ONNX graph.
@ -1133,6 +1165,7 @@ def _model_to_graph(
return graph, params_dict, torch_out
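The widened `torch_out` annotation above reflects that traced models can return arbitrarily nested structures rather than a single tensor. A minimal illustration with a hypothetical module:

```python
# Hypothetical module returning a dict with a nested tuple -- the kind of output
# the widened torch_out annotation needs to admit.
import torch


class MultiOutput(torch.nn.Module):
    def forward(self, x):
        return {"mean": x.mean(), "parts": (x + 1, x - 1)}


outputs = MultiOutput()(torch.randn(2, 3))
# outputs is a dict containing a tensor and a nested tuple of tensors.
```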
@_beartype.beartype
def export_to_pretty_string(
model,
args,
@ -1209,6 +1242,7 @@ def export_to_pretty_string(
)
@_beartype.beartype
def unconvertible_ops(
model, args, training=_C_onnx.TrainingMode.EVAL, opset_version=None
):
@ -1249,6 +1283,7 @@ def unconvertible_ops(
return graph, unsupported_ops
@_beartype.beartype
def _setup_trace_module_map(
model: Union[torch.nn.Module, torch.jit.ScriptModule],
export_modules_as_functions: Union[bool, Collection[Type[torch.nn.Module]]],
@ -1326,11 +1361,13 @@ def _setup_trace_module_map(
return module_typenames
@_beartype.beartype
def _reset_trace_module_map():
torch.jit._trace._trace_module_map = None
_C._jit_pass_onnx_clear_scope_records()
@_beartype.beartype
def _get_module_attributes(module):
annotations = typing.get_type_hints(type(module))
@ -1339,6 +1376,7 @@ def _get_module_attributes(module):
return {k: getattr(module, k) for k in annotations}
@_beartype.beartype
def _export(
model,
args,
@ -1561,6 +1599,7 @@ def _export(
return torch_out
@_beartype.beartype
def _apply_friendly_debug_names(graph, params):
for n in graph.nodes():
for v in n.inputs():
@ -1573,7 +1612,9 @@ def _apply_friendly_debug_names(graph, params):
params[new_name] = params.pop(old_name)
@_beartype.beartype
def _set_input_and_output_names(graph, input_names, output_names):
@_beartype.beartype
def set_names(node_list, name_list, descriptor):
if name_list is None:
return
@ -1604,6 +1645,7 @@ def _set_input_and_output_names(graph, input_names, output_names):
set_names(list(graph.outputs()), output_names, "output")
@_beartype.beartype
def _run_symbolic_method(g, op_name, symbolic_fn, args):
r"""
This trampoline function gets invoked for every symbolic method
@ -1619,14 +1661,17 @@ def _run_symbolic_method(g, op_name, symbolic_fn, args):
raise
@_beartype.beartype
def _add_block(node: _C.Node):
return node.addBlock() # type: ignore[attr-defined]
@_beartype.beartype
def _add_input_to_block(block: _C.Block):
return block.addInputToBlock() # type: ignore[attr-defined]
@_beartype.beartype
def _add_output_to_block(block: _C.Block, value: _C.Value):
new_output = block.registerOutput(value) # type: ignore[attr-defined]
return new_output
@ -1641,6 +1686,7 @@ def _add_output_to_block(block: _C.Block, value: _C.Value):
# inplace annotations, but we are losing information this way.
@_beartype.beartype
def _find_symbolic_in_registry(
domain: str,
op_name: str,
@ -1666,6 +1712,7 @@ def _find_symbolic_in_registry(
return symbolic_registry.get_registered_op(op_name, domain, opset_version)
@_beartype.beartype
def _should_aten_fallback(ns, op_name, opset_version, operator_export_type):
is_exportable_aten_op = symbolic_registry.is_registered_op(
@ -1680,6 +1727,7 @@ def _should_aten_fallback(ns, op_name, opset_version, operator_export_type):
)
@_beartype.beartype
def _need_symbolic_context(symbolic_fn) -> bool:
"""Checks if the first argument to symbolic_fn is annotated as type `torch.onnx.SymbolicContext`."""
params = tuple(inspect.signature(symbolic_fn).parameters.values())
@ -1695,6 +1743,7 @@ def _need_symbolic_context(symbolic_fn) -> bool:
return issubclass(param_type, _exporter_states.SymbolicContext)
@_beartype.beartype
def _get_aten_op_overload_name(n: _C.Node) -> str:
# Returns `overload_name` attribute of ATen ops on non-Caffe2 builds
@ -1704,6 +1753,7 @@ def _get_aten_op_overload_name(n: _C.Node) -> str:
return _C.parse_schema(schema).overload_name
@_beartype.beartype
def _run_symbolic_function(
g: _C.Graph,
block: _C.Block,
@ -1711,7 +1761,7 @@ def _run_symbolic_function(
inputs: Any,
env: Dict[_C.Value, _C.Value],
operator_export_type=_C_onnx.OperatorExportTypes.ONNX,
) -> Optional[Union[_C.Value, Tuple[_C.Value, ...]]]:
) -> Optional[Union[_C.Value, Sequence[Optional[_C.Value]]]]:
"""Runs a symbolic function.
The function is used in C++ to export the node to ONNX.
@ -1815,6 +1865,7 @@ def _run_symbolic_function(
raise
@_beartype.beartype
def get_ns_op_name_from_custom_op(symbolic_name):
if not bool(
re.match(r"^[a-zA-Z0-9-_]*::[a-zA-Z-_]+[a-zA-Z0-9-_]*$", symbolic_name)
@ -1838,6 +1889,7 @@ def get_ns_op_name_from_custom_op(symbolic_name):
return ns, op_name
@_beartype.beartype
def register_custom_op_symbolic(symbolic_name, symbolic_fn, opset_version):
"""Registers a symbolic function for a custom operator.
@ -1865,6 +1917,7 @@ def register_custom_op_symbolic(symbolic_name, symbolic_fn, opset_version):
symbolic_registry.register_op(op_name, symbolic_fn, ns, version)
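A hedged usage sketch of the public `register_custom_op_symbolic` API decorated above; the domain and op names are illustrative and only need to match the `namespace::op` pattern validated earlier.

```python
# Illustrative names; a real registration would target an op that actually
# appears in the traced graph.
import torch.onnx


def asinh_symbolic(g, self):
    return g.op("Asinh", self)


torch.onnx.register_custom_op_symbolic("custom_ops::asinh", asinh_symbolic, 9)
```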
@_beartype.beartype
def unregister_custom_op_symbolic(symbolic_name: str, opset_version: int):
"""Unregisters ``symbolic_name``.
@ -1884,6 +1937,7 @@ def unregister_custom_op_symbolic(symbolic_name: str, opset_version: int):
symbolic_registry.unregister_op(op_name, ns, version)
@_beartype.beartype
def _validate_dynamic_axes(dynamic_axes, model, input_names, output_names):
"""Ensures dynamic axes argument is follows the expected format."""
if len(dynamic_axes) == 0:

View File

@ -13,7 +13,7 @@ import itertools
import os
import tempfile
import warnings
from typing import Any, Callable, Mapping, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, Union
import numpy as np
@ -22,10 +22,15 @@ import torch._C._onnx as _C_onnx
from torch import _C
from torch.onnx import _constants, _experimental, utils
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import _beartype
from torch.types import Number
_ORT_PROVIDERS = ("CPUExecutionProvider",)
_NumericType = Union[Number, torch.Tensor, np.ndarray]
@_beartype.beartype
def _flatten_tuples(elem):
flattened = []
for t in elem:
@ -36,7 +41,8 @@ def _flatten_tuples(elem):
return flattened
def _to_numpy(elem):
# TODO(justinchuby): Add type checking by narrowing down the return type when input is None
def _to_numpy(elem) -> Union[list, np.ndarray]:
if isinstance(elem, torch.Tensor):
if elem.requires_grad:
return elem.detach().cpu().numpy()
@ -49,12 +55,13 @@ def _to_numpy(elem):
elif isinstance(elem, dict):
flattened = []
for k in elem:
flattened += [_to_numpy(k)] + [_to_numpy(elem[k])]
flattened.extend([_to_numpy(k), _to_numpy(elem[k])])
return flattened
return elem
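A self-contained sketch of the dict branch in `_to_numpy` above, simplified to show why `list.extend` with a two-element list is equivalent to the previous `+=`:

```python
# Simplified sketch of the dict branch: keys and values are converted in pairs.
import torch


def to_numpy_sketch(elem):
    if isinstance(elem, torch.Tensor):
        return elem.detach().cpu().numpy()
    if isinstance(elem, dict):
        flattened = []
        for key in elem:
            flattened.extend([to_numpy_sketch(key), to_numpy_sketch(elem[key])])
        return flattened
    return elem


print(to_numpy_sketch({"a": torch.ones(2)}))  # ['a', array([1., 1.], dtype=float32)]
```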
def _inline_flatten_list(inputs, res_list):
@_beartype.beartype
def _inline_flatten_list(inputs, res_list) -> list:
for i in inputs:
res_list.append(i) if not isinstance(
i, (list, tuple)
@ -62,7 +69,8 @@ def _inline_flatten_list(inputs, res_list):
return res_list
def _unpack_to_numpy(values, cast_onnx_accepted=True):
@_beartype.beartype
def _unpack_to_numpy(values, cast_onnx_accepted=True) -> list:
value_unpacked = []
for value in values:
value_unpacked.extend(
@ -71,6 +79,7 @@ def _unpack_to_numpy(values, cast_onnx_accepted=True):
return [_to_numpy(v) for v in value_unpacked]
@_beartype.beartype
def _run_ort(ort_session, inputs):
kw_inputs = {}
if inputs and isinstance(inputs[-1], dict):
@ -92,6 +101,7 @@ def _run_ort(ort_session, inputs):
return ort_outs
@_beartype.beartype
def _ort_session(
model: Union[str, io.BytesIO], ort_providers: Sequence[str] = _ORT_PROVIDERS
):
@ -115,9 +125,10 @@ def _ort_session(
return ort_session
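For reference, a hedged sketch of the kind of session construction `_ort_session` wraps; it requires the optional `onnxruntime` package, and the session options shown are assumptions rather than the exact ones used by the helper.

```python
# Assumed setup; the real helper may configure the session differently.
import onnxruntime


def make_ort_session(model_bytes: bytes):
    session_options = onnxruntime.SessionOptions()
    session_options.graph_optimization_level = (
        onnxruntime.GraphOptimizationLevel.ORT_ENABLE_BASIC
    )
    return onnxruntime.InferenceSession(
        model_bytes, session_options, providers=["CPUExecutionProvider"]
    )
```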
@_beartype.beartype
def _compare_ort_pytorch_outputs(
ort_outs: Sequence[np.ndarray],
pt_outs: Sequence[torch.Tensor],
ort_outs: Union[Sequence[_NumericType], Sequence],
pt_outs: Optional[Union[_NumericType, Sequence[_NumericType], Sequence, Dict]],
rtol: float,
atol: float,
check_shape: bool,
@ -191,6 +202,7 @@ def _compare_ort_pytorch_outputs(
raise
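A hedged sketch of the element-wise check `_compare_ort_pytorch_outputs` ultimately performs on each ONNX Runtime / PyTorch output pair; the real helper also validates shapes and recurses into nested outputs.

```python
# Simplified pairwise check; shape validation and nested-output handling are omitted.
import numpy as np
import torch


def compare_one_output(ort_out: np.ndarray, pt_out: torch.Tensor, rtol: float, atol: float):
    np.testing.assert_allclose(
        ort_out, pt_out.detach().cpu().numpy(), rtol=rtol, atol=atol
    )
```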
@_beartype.beartype
def _prepare_input_for_pytorch(args, kwargs):
"""Prepare input for PyTorch model execution.
@ -217,6 +229,7 @@ def _prepare_input_for_pytorch(args, kwargs):
return args, kwargs
@_beartype.beartype
def _prepare_input_for_export(args, kwargs):
"""Prepare input for ONNX model export.
@ -241,6 +254,7 @@ def _prepare_input_for_export(args, kwargs):
return onnx_inputs
@_beartype.beartype
def _prepare_input_for_ort(args, kwargs, remained_onnx_input_idx, flatten):
"""Prepare input for ONNX model execution in ONNX Runtime.
@ -266,6 +280,7 @@ def _prepare_input_for_ort(args, kwargs, remained_onnx_input_idx, flatten):
return onnx_inputs
@_beartype.beartype
def _try_clone_model(model):
"""Used for preserving original model in case forward mutates model states."""
try:
@ -277,6 +292,7 @@ def _try_clone_model(model):
return model
@_beartype.beartype
def _compare_ort_pytorch_model(
model,
ort_session,
@ -301,6 +317,7 @@ def _compare_ort_pytorch_model(
equal up to specified precision.
"""
@_beartype.beartype
def compare_ort_pytorch_model_with_input(input_args, input_kwargs):
pt_args, pt_kwargs = _prepare_input_for_pytorch(input_args, input_kwargs)
# TODO: remove this and treat mutating model separately. See #77679
@ -333,6 +350,7 @@ def _compare_ort_pytorch_model(
class _GraphDiff:
"""A class to represent the difference between two graphs."""
@_beartype.beartype
def __init__(self, graph_a: _C.Graph, graph_b: _C.Graph):
"""Construct a _GraphDiff object.
@ -343,13 +361,16 @@ class _GraphDiff:
self.graph_a = graph_a
self.graph_b = graph_b
@_beartype.beartype
def __str__(self):
"""See function :func:`diff_report`."""
return self.diff_report()
@_beartype.beartype
def _indent(self, lines: str) -> str:
return "\n".join(["\t" + line for line in lines.splitlines()])
@_beartype.beartype
def diff_report(self) -> str:
"""Return a string representation of the graph difference.
@ -399,6 +420,7 @@ class _GraphDiff:
return "\n".join(graph_diff_report)
@_beartype.beartype
def _check_graph_diff(
model: Union[torch.nn.Module, torch.jit.ScriptModule],
test_input_groups: Sequence[Tuple[Tuple[Any, ...], Mapping[str, Any]]],
@ -440,6 +462,7 @@ def _check_graph_diff(
return ""
@_beartype.beartype
def _traced_graph_from_model(
model: Union[torch.nn.Module, torch.jit.ScriptModule],
args: Tuple[Any, ...],
@ -467,6 +490,7 @@ def _traced_graph_from_model(
return jit_graph
@_beartype.beartype
def _onnx_graph_from_model(
model: Union[torch.nn.Module, torch.jit.ScriptModule],
args: Tuple[Any, ...],
@ -534,6 +558,7 @@ def _onnx_graph_from_model(
return onnx_graph
@_beartype.beartype
def check_export_model_diff(
model: Union[torch.nn.Module, torch.jit.ScriptModule],
test_input_groups: Sequence[Tuple[Tuple[Any, ...], Mapping[str, Any]]],
@ -577,9 +602,10 @@ def check_export_model_diff(
)
@_beartype.beartype
def verify(
model: Union[torch.nn.Module, torch.jit.ScriptModule],
input_args: Tuple[Any, ...],
input_args: Union[torch.Tensor, Tuple[Any, ...]],
input_kwargs: Optional[Mapping[str, Any]] = None,
do_constant_folding: bool = True,
dynamic_axes: Optional[
@ -593,7 +619,9 @@ def verify(
verbose: bool = False,
fixed_batch_size: bool = False,
use_external_data: bool = False,
additional_test_inputs: Optional[Sequence[Tuple[Any, ...]]] = None,
additional_test_inputs: Optional[
Sequence[Union[torch.Tensor, Tuple[Any, ...]]]
] = None,
remained_onnx_input_idx: Optional[Sequence[int]] = None,
flatten: bool = True,
ignore_none: bool = True,