Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[ONNX] Fix type annotations and enable type checking for all apis (#84091)
Enable runtime type checking for all torch.onnx public APIs, symbolic functions, and most helpers (minus two whose types cannot be checked at runtime, because `_C.JitType` is not exposed to Python) by adding the beartype decorator. Fix type annotations to make the unit tests green.

Profile: export `torchvision.models.alexnet(pretrained=True)`

```
with runtime type checking:    21.314 / 10 passes
without runtime type checking: 20.797 / 10 passes
+ 2.48%
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84091
Approved by: https://github.com/BowenBao, https://github.com/thiagocrepaldi
Committed by: PyTorch MergeBot
Parent: 2a332afbf4
Commit: 388368b699
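For readers unfamiliar with beartype, the sketch below shows the mechanism this PR relies on: decorating an annotated function makes wrongly typed arguments fail at call time instead of passing silently. It is a minimal illustration using the public `beartype` package directly; `_scalar_example` is a hypothetical toy helper, not one of the functions touched by this PR, and the exact exception subclass raised can differ between beartype versions.

```python
from typing import Optional, Union

from beartype import beartype
from beartype.roar import BeartypeException  # base class for beartype's runtime errors


@beartype
def _scalar_example(x: Union[int, float], name: Optional[str] = None) -> float:
    """Toy helper whose annotations beartype validates on every call."""
    return float(x)


print(_scalar_example(3))  # OK: int satisfies Union[int, float]

try:
    _scalar_example("3")  # str violates the annotation
except BeartypeException as err:
    print(f"rejected at call time: {type(err).__name__}")
```

The ~2.5% export-time overhead quoted in the profile above is the cost of these per-call checks; `torch.onnx` routes the decorator through the internal `torch.onnx._internal._beartype` wrapper seen in the hunks below so it can be applied consistently across the exporter.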
@@ -3583,7 +3583,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):

x = torch.arange(1.0, 6.0, requires_grad=True)
k = torch.tensor(3)
self.run_test(MyModuleDynamic(), [x, k])
self.run_test(MyModuleDynamic(), (x, k))

@skipScriptTest() # Python builtin apply of FunctionMeta object is currently not supported in Torchscript.
@skipIfUnsupportedMinOpsetVersion(11) # Clip op min is an input since opset 11.
@@ -7405,23 +7405,28 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
x = torch.randn(2, 2, 4, 4)
self.run_test(model, x)

# Dynamic padding is added in opset 11
@skipIfUnsupportedMinOpsetVersion(11)
def test_pad_types(self):
@common_utils.parametrize(
"pad",
[
common_utils.subtest([2, 4], name="scalar_list"),
common_utils.subtest(
[
torch.tensor(2, dtype=torch.int64),
torch.tensor(4, dtype=torch.int64),
],
name="scalar_tensor_list",
),
],
)
@skipIfUnsupportedMinOpsetVersion(11) # Dynamic padding is added in opset 11
def test_pad_types(self, pad):
# Test for different pad integer types
class Pad(torch.nn.Module):
def forward(self, x, pad: List[int]):
return torch.nn.functional.pad(x, pad)

x = torch.randn(2, 2, 4, 4)
y = pad = [2, 4]
self.run_test(Pad(), (x, y))

y = pad = [
torch.tensor(2, dtype=torch.int64),
torch.tensor(4, dtype=torch.int64),
]
self.run_test(Pad(), (x, y))
self.run_test(Pad(), (x, pad))

@skipIfUnsupportedMaxOpsetVersion(10)
@skipScriptTest() # TODO: the logic in symbolic_opset9 doesn't handle script
@@ -10,15 +10,17 @@ from torch._C import _onnx as _C_onnx
# Import utils to get _params_dict because it is a global that is accessed by c++ code
from torch.onnx import _deprecation, utils
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import _beartype

_ATTR_PATTERN = re.compile("^(.+)_(([ifstgz])|(ty))$")


# TODO(#78694): Refactor the patching process to make it more transparent to users.
@_beartype.beartype
def _graph_op(
g: _C.Graph,
opname: str,
*raw_args: _C.Value,
*raw_args: Union[torch.Tensor, _C.Value],
outputs: int = 1,
**kwargs,
) -> Union[_C.Value, Tuple[_C.Value, ...]]:
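The `_ATTR_PATTERN` regex introduced above is what `_add_attribute` (further down in this diff) uses to split a keyword argument such as `axis_i` or `value_t` into an ONNX attribute name and a one-letter type code. A small standalone sketch of that behaviour, using the pattern verbatim; the kwarg names here are illustrative only:

```python
import re

# Copied verbatim from the hunk above; the trailing suffix encodes the
# attribute type (e.g. i=int, f=float, s=string, t=tensor, g=graph).
_ATTR_PATTERN = re.compile("^(.+)_(([ifstgz])|(ty))$")

for kwarg in ("axis_i", "value_t", "mode_s", "unsuffixed"):
    match = _ATTR_PATTERN.match(kwarg)
    if match is None:
        print(f"{kwarg!r}: no recognised type suffix")
    else:
        print(f"{kwarg!r}: attribute {match.group(1)!r}, type code {match.group(2)!r}")
```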
@@ -76,6 +78,7 @@ def _graph_op(
return tuple(n.outputs())


@_beartype.beartype
def _const_if_tensor(g: _C.Graph, arg):
if arg is None:
return arg
@@ -85,6 +88,7 @@ def _const_if_tensor(g: _C.Graph, arg):


# Generate an ONNX ATen op node.
@_beartype.beartype
def _aten_op(g: _C.Graph, operator: str, *args, overload_name: str = "", **kwargs):
return _graph_op(
g,
@@ -96,7 +100,8 @@ def _aten_op(g: _C.Graph, operator: str, *args, overload_name: str = "", **kwargs):
)


def _block_op(b: _C.Block, opname: str, *args, **kwargs):
@_beartype.beartype
def _block_op(b: _C.Block, opname: str, *args: _C.Value, **kwargs):
if "::" in opname:
aten = False
ns_opname = opname
@@ -115,8 +120,9 @@ def _block_op(b: _C.Block, opname: str, *args, **kwargs):
return outputs


@_beartype.beartype
def _new_node(
g: _C.Graph, namespace: str, op: str, outputs: int, *args, **kwargs
g: _C.Graph, namespace: str, op: str, outputs: int, *args: _C.Value, **kwargs
) -> _C.Node:
"""Creates a new node in the graph.

@@ -138,6 +144,7 @@ def _new_node(
return node


@_beartype.beartype
def _is_onnx_list(value):
return (
not isinstance(value, torch._six.string_classes)
@@ -146,19 +153,22 @@ def _is_onnx_list(value):
)


@_beartype.beartype
def _scalar(x: torch.Tensor):
"""Convert a scalar tensor into a Python value."""
assert x.numel() == 1
return x[0]


def _is_caffe2_aten_fallback():
@_beartype.beartype
def _is_caffe2_aten_fallback() -> bool:
return (
GLOBALS.operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
and _C_onnx._CAFFE2_ATEN_FALLBACK
)


@_beartype.beartype
def _add_attribute(node: _C.Node, key: str, value: Any, aten: bool):
r"""Initializes the right attribute based on type of value."""
m = _ATTR_PATTERN.match(key)
@@ -188,6 +198,7 @@ def _add_attribute(node: _C.Node, key: str, value: Any, aten: bool):
@_deprecation.deprecated(
"1.13", "1.14", "Use 'g.op()' to create a constant node instead."
)
@_beartype.beartype
def _graph_constant(
g,
value,
@@ -8,6 +8,7 @@ from typing_extensions import Literal

import torch
from torch._C import _onnx as _C_onnx
from torch.onnx._internal import _beartype

ScalarName = Literal[
"Byte",
@@ -80,6 +81,7 @@ class JitScalarType(enum.IntEnum):
UNDEFINED = enum.auto() # 16

@classmethod
@_beartype.beartype
def from_name(
cls, name: Union[ScalarName, TorchName, Optional[str]]
) -> JitScalarType:
@@ -104,30 +106,36 @@ class JitScalarType(enum.IntEnum):
raise ValueError(f"Unknown torch or scalar type: '{name}'")

@classmethod
@_beartype.beartype
def from_dtype(cls, dtype: torch.dtype) -> JitScalarType:
"""Convert a torch dtype to ScalarType."""
if dtype not in _DTYPE_TO_SCALAR_TYPE:
raise ValueError(f"Unknown dtype: {dtype}")
return _DTYPE_TO_SCALAR_TYPE[dtype]

@_beartype.beartype
def scalar_name(self) -> ScalarName:
"""Convert a ScalarType to a JIT scalar type name."""
return _SCALAR_TYPE_TO_NAME[self]

@_beartype.beartype
def torch_name(self) -> TorchName:
"""Convert a ScalarType to a torch type name."""
return _SCALAR_TYPE_TO_TORCH_NAME[self]

@_beartype.beartype
def dtype(self) -> torch.dtype:
"""Convert a ScalarType to a torch dtype."""
return _SCALAR_TYPE_TO_DTYPE[self]

@_beartype.beartype
def onnx_type(self) -> _C_onnx.TensorProtoDataType:
"""Convert a ScalarType to an ONNX data type."""
if self not in _SCALAR_TYPE_TO_ONNX:
raise ValueError(f"Scalar type {self} cannot be converted to ONNX")
return _SCALAR_TYPE_TO_ONNX[self]

@_beartype.beartype
def onnx_compatible(self) -> bool:
"""Return whether this ScalarType is compatible with ONNX."""
return (
@@ -137,11 +145,13 @@ class JitScalarType(enum.IntEnum):
)


@_beartype.beartype
def valid_scalar_name(scalar_name: Union[ScalarName, str]) -> bool:
"""Return whether the given scalar name is a valid JIT scalar type name."""
return scalar_name in _SCALAR_NAME_TO_TYPE


@_beartype.beartype
def valid_torch_name(torch_name: Union[TorchName, str]) -> bool:
"""Return whether the given torch name is a valid torch type name."""
return torch_name in _TORCH_NAME_TO_SCALAR_TYPE
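The hunks above show the conversion helpers this PR annotates on `JitScalarType`. A small usage sketch of those same methods follows; note that `torch.onnx._type_utils` is an internal module, so this mirrors the diff rather than documenting a stable public API, and the commented values are what the mappings in this file imply for `torch.float32`.

```python
import torch
from torch.onnx import _type_utils

scalar_type = _type_utils.JitScalarType.from_dtype(torch.float32)

print(scalar_type.scalar_name())      # JIT scalar type name, e.g. "Float"
print(scalar_type.torch_name())       # torch type name, e.g. "float"
print(scalar_type.dtype())            # torch.float32
print(scalar_type.onnx_type())        # the corresponding TensorProtoDataType
print(scalar_type.onnx_compatible())  # True for float32
```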
@@ -17,6 +17,7 @@ from torch import _C
from torch.onnx import _constants, _patch_torch, _type_utils, errors # noqa: F401
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import _beartype
from torch.types import Number

# Note [Edit Symbolic Files]
# EDITING THIS FILE AND SYMBOLIC_OPSET<VERSION> FILES? READ THIS FIRST!
@@ -32,6 +33,8 @@ from torch.onnx._internal import _beartype
# `_jit_pass_onnx_remove_inplace_ops_for_onnx`, and
# transparently dispatched to their non inplace versions in
# "run_symbolic_function". See Note [Export inplace]
# - REQUIRED: Annotate new symbolic functions with type annotations and decorate with
# _beartype.beartype to enable runtime type checking.
#
# ----------------------------------------------------------------------------------
# A note on Tensor types
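The "REQUIRED" note added above prescribes the convention this PR rolls out. A hedged sketch of what a new symbolic function looks like under that convention follows; `toy_relu` is a hypothetical function written for illustration, not one added by this PR, and it relies on `g.op` being the patched `Graph.op` helper whose annotations appear in the earlier hunks.

```python
from torch import _C
from torch.onnx._internal import _beartype


@_beartype.beartype
def toy_relu(g: _C.Graph, self: _C.Value) -> _C.Value:
    """Toy symbolic function: fully annotated, runtime-checked, emits one Relu node."""
    return g.op("Relu", self)
```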
@ -94,6 +97,7 @@ _ValueDescriptor = Literal[
|
||||
]
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _parse_arg(
|
||||
value,
|
||||
desc: _ValueDescriptor,
|
||||
@ -163,6 +167,7 @@ def _parse_arg(
|
||||
)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _node_get(node: _C.Node, key: str):
|
||||
"""Gets attributes of a node which is polymorphic over return type."""
|
||||
assert isinstance(node, _C.Node)
|
||||
@@ -170,19 +175,26 @@ def _node_get(node: _C.Node, key: str):
return getattr(node, sel)(key)


@_beartype.beartype
def _is_onnx_constant(value: _C.Value):
"""Whether a Value is an ONNX constant."""
return value.node().kind() == "onnx::Constant"


def _maybe_get_const(value: _C.Value, descriptor: _ValueDescriptor):
@_beartype.beartype
def _maybe_get_const(
value: Optional[Union[_C.Value, torch.Tensor, Number, Sequence]],
descriptor: _ValueDescriptor,
):
# NOTE: prim::Constant at this stage usually means something not compatible in ONNX,
# otherwise it'd be converted to onnx::Constant
if _is_value(value) and _is_onnx_constant(value):
# TODO(justinchuby): Replace insinstance with _is_value once we figure out mypy
if isinstance(value, _C.Value) and _is_onnx_constant(value):
return _parse_arg(value, descriptor)
return value


@_beartype.beartype
def _maybe_get_scalar(value):
value_t = _maybe_get_const(value, "t")
if isinstance(value_t, torch.Tensor) and value_t.shape == ():
@ -190,6 +202,7 @@ def _maybe_get_scalar(value):
|
||||
return value
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _get_const(value, desc, arg_name):
|
||||
if not _is_constant(value):
|
||||
raise errors.SymbolicValueError(
|
||||
@ -200,6 +213,7 @@ def _get_const(value, desc, arg_name):
|
||||
return _parse_arg(value, desc)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _unpack_list(list_value: _C.Value) -> List[_C.Value]:
|
||||
list_node = list_value.node()
|
||||
if list_node.kind() != "prim::ListConstruct":
|
||||
@ -211,6 +225,7 @@ def _unpack_list(list_value: _C.Value) -> List[_C.Value]:
|
||||
return list(list_node.inputs())
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _unpack_tuple(tuple_value: _C.Value) -> Tuple[_C.Value, ...]:
|
||||
tuple_node = tuple_value.node()
|
||||
if not _is_tuple_construct(tuple_value):
|
||||
@ -222,6 +237,7 @@ def _unpack_tuple(tuple_value: _C.Value) -> Tuple[_C.Value, ...]:
|
||||
return tuple(tuple_node.inputs())
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _unpack_quantized_tensor(tuple_value: _C.Value) -> Tuple[_C.Value, ...]:
|
||||
"""Unpacks a quantized tensor into a tuple of tensor and scale/zero_point.
|
||||
Args:
|
||||
@@ -245,10 +261,12 @@ def _unpack_quantized_tensor(tuple_value: _C.Value) -> Tuple[_C.Value, ...]:

# Check if list_value is output from prim::ListConstruct
# This is usually called before _unpack_list to ensure the list can be unpacked.
@_beartype.beartype
def _is_packed_list(list_value: _C.Value) -> bool:
return _is_value(list_value) and list_value.node().kind() == "prim::ListConstruct"


@_beartype.beartype
def parse_args(*arg_descriptors: _ValueDescriptor):
"""A decorator which converts args from torch._C.Value to built-in types.

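Since `parse_args` sits above many of the beartype-decorated symbolic functions in this diff, a brief usage sketch may help: the descriptor string per argument tells the decorator how to convert incoming `torch._C.Value` objects before the annotated function runs (roughly, "v" stays a Value, "i" becomes an int, "is" a list of ints). `toy_squeeze` below is a hypothetical example, not part of this PR.

```python
from typing import List

from torch import _C
from torch.onnx import symbolic_helper
from torch.onnx._internal import _beartype


@symbolic_helper.parse_args("v", "is")
@_beartype.beartype
def toy_squeeze(g: _C.Graph, self: _C.Value, axes: List[int]) -> _C.Value:
    """Toy symbolic function: `axes` arrives already converted to a Python int list,
    so the beartype check on the inner function sees the parsed types."""
    return g.op("Squeeze", self, axes_i=axes)
```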
@ -326,6 +344,7 @@ def parse_args(*arg_descriptors: _ValueDescriptor):
|
||||
return decorator
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def quantized_args(
|
||||
*arg_q_descriptors: bool,
|
||||
scale: Optional[float] = None,
|
||||
@@ -430,13 +449,15 @@ def quantized_args(
return decorator


def _scalar(x: torch.Tensor):
@_beartype.beartype
def _scalar(x: Any) -> Optional[Number]:
"""Convert a scalar tensor into a Python value."""
if isinstance(x, torch.Tensor) and x.shape == ():
return x.item()
return None


@_beartype.beartype
def _if_scalar_type_as(g: _C.Graph, self, tensor):
"""
Convert self into the same type of tensor, as necessary.
@ -455,14 +476,17 @@ def _if_scalar_type_as(g: _C.Graph, self, tensor):
|
||||
return self
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _is_none(x: _C.Value) -> bool:
|
||||
return x.node().mustBeNone()
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _is_value(x: Any) -> bool:
|
||||
return isinstance(x, _C.Value)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _is_constant(value: Any) -> bool:
|
||||
return not _is_value(value) or value.node().kind() in {
|
||||
"onnx::Constant",
|
||||
@@ -470,20 +494,24 @@ def _is_constant(value: Any) -> bool:
}


@_beartype.beartype
def _is_tensor(x: _C.Value) -> bool:
return x.type().isSubtypeOf(_C.TensorType.get())


# Note: _C.JitType is not exposed to Python and cannot be checked in runtime.
def _as_list_type(jit_type: _C.JitType) -> Optional[_C.ListType]:
if isinstance(jit_type, _C.ListType):
return jit_type
return None


@_beartype.beartype
def _is_list(x: _C.Value) -> bool:
return _as_list_type(x.type()) is not None


@_beartype.beartype
def _is_tensor_list(x: _C.Value) -> bool:
x_type = _as_list_type(x.type())
if x_type is None:
@ -491,6 +519,7 @@ def _is_tensor_list(x: _C.Value) -> bool:
|
||||
return isinstance(x_type.getElementType(), _C.TensorType)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _is_scalar_list(x: _C.Value) -> bool:
|
||||
"""Checks if x is a scalar list, for example: List[float], List[int].
|
||||
|
||||
@ -507,10 +536,12 @@ def _is_scalar_list(x: _C.Value) -> bool:
|
||||
)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _is_tuple_construct(x: _C.Value) -> bool:
|
||||
return x.node().kind() == "prim::TupleConstruct"
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def is_caffe2_aten_fallback() -> bool:
|
||||
return (
|
||||
GLOBALS.operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
|
||||
@ -548,6 +579,7 @@ def _get_tensor_dim_size(x: _C.Value, dim: int) -> Optional[int]:
|
||||
return sizes[dim] if sizes else None
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _get_dim_for_cross(x: _C.Value, dim: Optional[int]):
|
||||
if dim == -1:
|
||||
tensor_rank = _get_tensor_rank(x)
|
||||
@ -563,6 +595,7 @@ def _get_dim_for_cross(x: _C.Value, dim: Optional[int]):
|
||||
return dim
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _unimplemented(op: str, msg: str, value: Optional[_C.Value] = None) -> None:
|
||||
# For BC reasons, the behavior for Caffe2 does not raise exception for unimplemented operators
|
||||
if _C_onnx._CAFFE2_ATEN_FALLBACK:
|
||||
@ -571,6 +604,7 @@ def _unimplemented(op: str, msg: str, value: Optional[_C.Value] = None) -> None:
|
||||
_onnx_unsupported(f"{op}, {msg}", value)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _onnx_unsupported(op_name: str, value: Optional[_C.Value] = None) -> NoReturn:
|
||||
message = (
|
||||
f"Unsupported: ONNX export of operator {op_name}. "
|
||||
@ -585,6 +619,7 @@ def _onnx_unsupported(op_name: str, value: Optional[_C.Value] = None) -> NoRetur
|
||||
raise errors.OnnxExporterError(message)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _onnx_opset_unsupported(
|
||||
op_name: str,
|
||||
current_opset: int,
|
||||
@ -603,6 +638,7 @@ def _onnx_opset_unsupported(
|
||||
raise errors.OnnxExporterError(message)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _onnx_opset_unsupported_detailed(
|
||||
op_name: str,
|
||||
current_opset: int,
|
||||
@ -622,6 +658,7 @@ def _onnx_opset_unsupported_detailed(
|
||||
raise errors.OnnxExporterError(message)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _block_list_in_opset(name: str):
|
||||
def symbolic_fn(*args, **kwargs):
|
||||
raise errors.OnnxExporterError(
|
||||
@ -633,6 +670,7 @@ def _block_list_in_opset(name: str):
|
||||
return symbolic_fn
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _try_get_scalar_type(*args) -> Optional[str]:
|
||||
for arg in args:
|
||||
try:
|
||||
@ -642,6 +680,7 @@ def _try_get_scalar_type(*args) -> Optional[str]:
|
||||
return None
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _select_helper(g, self, dim, index, apply_reshape=True):
|
||||
index_const = _maybe_get_scalar(index)
|
||||
index_dim = _get_tensor_rank(index)
|
||||
@ -661,6 +700,7 @@ def _select_helper(g, self, dim, index, apply_reshape=True):
|
||||
return g.op("Gather", self, index, axis_i=dim)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _slice_helper(g, input, axes, starts, ends, steps=None, dynamic_slice=False):
|
||||
if GLOBALS.export_onnx_opset_version <= 9:
|
||||
from torch.onnx.symbolic_opset9 import _slice as _slice9
|
||||
@ -672,6 +712,7 @@ def _slice_helper(g, input, axes, starts, ends, steps=None, dynamic_slice=False)
|
||||
return _slice10(g, input, axes, starts, ends, steps, dynamic_slice)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _is_in_type_group(value, scalar_types: Set[_type_utils.JitScalarType]) -> bool:
|
||||
"""Helper function for determining if a value is in a scalar type group."""
|
||||
if value is None:
|
||||
@ -696,6 +737,7 @@ def _is_in_type_group(value, scalar_types: Set[_type_utils.JitScalarType]) -> bo
|
||||
return False
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _is_fp(value) -> bool:
|
||||
return _is_in_type_group(
|
||||
value,
|
||||
@ -708,10 +750,12 @@ def _is_fp(value) -> bool:
|
||||
)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _is_bool(value) -> bool:
|
||||
return _is_in_type_group(value, {_type_utils.JitScalarType.BOOL})
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _generate_wrapped_number(g, scalar):
|
||||
"""Creates a wrapped number based on https://github.com/pytorch/pytorch/issues/9515.
|
||||
|
||||
@ -730,6 +774,7 @@ def _generate_wrapped_number(g, scalar):
|
||||
return g.op("Constant", value_t=torch.tensor(scalar))
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _sort_helper(g, input, dim, decending=True, out=None):
|
||||
if out is not None:
|
||||
_unimplemented("Sort", "Out parameter is not supported")
|
||||
@ -749,6 +794,7 @@ def _sort_helper(g, input, dim, decending=True, out=None):
|
||||
)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _topk_helper(g, input, k, dim, largest=True, sorted=False, out=None):
|
||||
if out is not None:
|
||||
_unimplemented("TopK", "Out parameter is not supported")
|
||||
@ -768,6 +814,7 @@ def _topk_helper(g, input, k, dim, largest=True, sorted=False, out=None):
|
||||
)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _lt_helper(g, input, other):
|
||||
if GLOBALS.export_onnx_opset_version <= 8:
|
||||
from torch.onnx.symbolic_opset8 import lt as _lt8
|
||||
@ -779,6 +826,7 @@ def _lt_helper(g, input, other):
|
||||
return _lt9(g, input, other)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _interpolate_warning(interpolate_mode):
|
||||
onnx_op = (
|
||||
"onnx:Resize" if GLOBALS.export_onnx_opset_version >= 10 else "onnx:Upsample"
|
||||
@ -796,6 +844,7 @@ def _interpolate_warning(interpolate_mode):
|
||||
)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _unsqueeze_helper(g, input, axes_i):
|
||||
if _is_constant(axes_i[0]):
|
||||
if GLOBALS.export_onnx_opset_version >= 13:
|
||||
@ -810,6 +859,7 @@ def _unsqueeze_helper(g, input, axes_i):
|
||||
return g.op("Unsqueeze", input, axes_i[0])
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _squeeze_helper(g, input, axes_i):
|
||||
if _is_constant(axes_i[0]):
|
||||
if GLOBALS.export_onnx_opset_version >= 13:
|
||||
@ -835,6 +885,7 @@ def _squeeze_helper(g, input, axes_i):
|
||||
return g.op("Squeeze", input, axes_t)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _reducesum_helper(g, input, axes_i=None, keepdims_i=1, noop_with_empty_axes_i=0):
|
||||
keepdims_i = _maybe_get_const(keepdims_i, "i")
|
||||
if GLOBALS.export_onnx_opset_version >= 13:
|
||||
@ -860,6 +911,7 @@ def _reducesum_helper(g, input, axes_i=None, keepdims_i=1, noop_with_empty_axes_
|
||||
return g.op("ReduceSum", input, axes_i=axes_i, keepdims_i=keepdims_i)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _interpolate_size_to_scales(g, input, output_size, dim):
|
||||
output_size = _maybe_get_const(output_size, "is")
|
||||
if _is_value(output_size):
|
||||
@ -886,6 +938,7 @@ def _interpolate_size_to_scales(g, input, output_size, dim):
|
||||
return scales
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _interpolate_get_scales_if_available(g, scales):
|
||||
available_scales = _maybe_get_const(scales[0], "fs") != -1 and not _is_none(
|
||||
scales[0]
|
||||
@ -902,6 +955,7 @@ def _interpolate_get_scales_if_available(g, scales):
|
||||
return scales
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _get_interpolate_attributes(g, mode, args):
|
||||
if mode == "nearest":
|
||||
align_corners = None
|
||||
@ -913,6 +967,7 @@ def _get_interpolate_attributes(g, mode, args):
|
||||
return scales, align_corners
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _interpolate_get_scales(g, scale_factor, dim):
|
||||
offsets = g.op("Constant", value_t=torch.ones(2, dtype=torch.float32))
|
||||
scale_factor_rank = _get_tensor_rank(scale_factor)
|
||||
@ -930,6 +985,7 @@ def _interpolate_get_scales(g, scale_factor, dim):
|
||||
return scale_factor
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _interpolate_get_scales_and_mode(g, input, size, scale_factor, mode, align_corners):
|
||||
mode = _maybe_get_const(mode, "s")
|
||||
if "linear" in mode:
|
||||
@ -963,8 +1019,9 @@ def _interpolate_get_scales_and_mode(g, input, size, scale_factor, mode, align_c
|
||||
return scale_factor, mode
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _argmin_argmax_helper(
|
||||
g, input: torch._C.Value, dim: torch._C.Value, keepdim: int, op_name: str
|
||||
g, input: torch._C.Value, dim: torch._C.Value, keepdim: bool, op_name: str
|
||||
):
|
||||
def op_wrapper(input, axis_i, keepdims_i):
|
||||
if GLOBALS.export_onnx_opset_version >= 12:
|
||||
@ -997,6 +1054,7 @@ def _argmin_argmax_helper(
|
||||
return op_wrapper(input, axis_i=dim, keepdims_i=keepdim)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _interpolate_helper(name, dim, interpolate_mode):
|
||||
@quantized_args(True, False, False)
|
||||
def symbolic_fn(g, input, output_size, *args):
|
||||
@ -1064,6 +1122,7 @@ def _interpolate_helper(name, dim, interpolate_mode):
|
||||
return symbolic_fn
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def __interpolate_helper(
|
||||
g, input, size, scale_factor, mode, align_corners, recompute_scale_factor
|
||||
):
|
||||
@ -1157,6 +1216,7 @@ def __interpolate_helper(
|
||||
) # only valid when mode="nearest"
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _unbind_helper(g, self, dim, _outputs):
|
||||
if GLOBALS.export_onnx_opset_version < 11:
|
||||
from torch.onnx.symbolic_opset9 import unbind
|
||||
@ -1167,6 +1227,7 @@ def _unbind_helper(g, self, dim, _outputs):
|
||||
return unbind(g, self, dim, _outputs)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _scatter_helper(g, self, dim, index, src):
|
||||
if GLOBALS.export_onnx_opset_version <= 10:
|
||||
from torch.onnx.symbolic_opset9 import scatter
|
||||
@ -1176,6 +1237,7 @@ def _scatter_helper(g, self, dim, index, src):
|
||||
return scatter(g, self, dim, index, src)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _repeat_interleave_split_helper(g, self, reps, dim):
|
||||
if GLOBALS.export_onnx_opset_version <= 12:
|
||||
split_out = g.op("Split", self, split_i=[1] * reps, axis_i=dim, outputs=reps)
|
||||
@ -1187,6 +1249,7 @@ def _repeat_interleave_split_helper(g, self, reps, dim):
|
||||
return split_out if reps > 1 else [split_out]
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _arange_cast_helper(
|
||||
g, end, start=None, step=None, dtype=None
|
||||
) -> Tuple[
|
||||
@ -1228,6 +1291,7 @@ def _arange_cast_helper(
|
||||
return scalar_type, end, start, step
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _arange_helper(g, *args):
|
||||
if GLOBALS.export_onnx_opset_version <= 10:
|
||||
from torch.onnx.symbolic_opset9 import arange
|
||||
@ -1236,6 +1300,7 @@ def _arange_helper(g, *args):
|
||||
return arange(g, *args)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _size_helper(g, self, dim):
|
||||
full_shape = g.op("Shape", self)
|
||||
from torch.onnx.symbolic_opset9 import select
|
||||
@ -1243,6 +1308,7 @@ def _size_helper(g, self, dim):
|
||||
return select(g, full_shape, g.op("Constant", value_t=torch.tensor([0])), dim)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _index_fill_reshape_helper(g, self, dim, index):
|
||||
# 1. reshape index => [1, ..., 1, dim, 1, ..., 1]
|
||||
# 2. expand index => [..., dim, ...], same shape as self except for dim.
|
||||
@ -1276,6 +1342,7 @@ def _index_fill_reshape_helper(g, self, dim, index):
|
||||
# allowzero=1 indicates that if any value in the 'shape' input is set to zero,
|
||||
# the zero value is honored, similar to NumPy.
|
||||
# allowzero=1 is only supported for opset version >= 14.
|
||||
@_beartype.beartype
|
||||
def _reshape_helper(g, input, shape, allowzero=0):
|
||||
shape = _maybe_get_const(shape, "is")
|
||||
if not _is_value(shape):
|
||||
@ -1290,6 +1357,7 @@ def _reshape_helper(g, input, shape, allowzero=0):
|
||||
return g.op("Reshape", input, shape, allowzero_i=allowzero)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _batchnorm_helper(g, input, weight, bias, running_mean, running_var):
|
||||
from torch.onnx.symbolic_opset9 import _var_mean
|
||||
|
||||
@ -1349,6 +1417,7 @@ def _batchnorm_helper(g, input, weight, bias, running_mean, running_var):
|
||||
return weight, bias, running_mean, running_var
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _avgpool_helper(
|
||||
tuple_fn: Callable[[Any], Sequence[int]],
|
||||
padding: Union[int, Sequence[int]],
|
||||
@ -1362,6 +1431,7 @@ def _avgpool_helper(
|
||||
return tuple(tuple_fn(padding))
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def check_training_mode(op_train_mode: int, op_name: str) -> None:
|
||||
"""Warns the user if the model's training mode and the export mode do not agree."""
|
||||
if GLOBALS.training_mode == _C_onnx.TrainingMode.PRESERVE:
|
||||
@ -1387,6 +1457,7 @@ def check_training_mode(op_train_mode: int, op_name: str) -> None:
|
||||
)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _flatten_helper(g, input, start_dim, end_dim, dim):
|
||||
input_size = g.op("Shape", input)
|
||||
slice1 = _slice_helper(g, input_size, axes=[0], starts=[0], ends=[start_dim])
|
||||
@ -1407,6 +1478,7 @@ def _flatten_helper(g, input, start_dim, end_dim, dim):
|
||||
return _reshape_from_tensor(g, input, final_shape)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _is_split_static(split_size_or_sizes, _outputs):
|
||||
if _outputs is None:
|
||||
return False
|
||||
@ -1418,12 +1490,14 @@ def _is_split_static(split_size_or_sizes, _outputs):
|
||||
return True
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _optional_input_placeholder_tensor(g):
|
||||
n = g.op("prim::Constant")
|
||||
n.setType(_C.OptionalType.ofTensor())
|
||||
return n
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _handle_reduce_dim_none(g, self, op_name):
|
||||
rank = _get_tensor_rank(self)
|
||||
if rank is not None and any(
|
||||
@ -1435,10 +1509,11 @@ def _handle_reduce_dim_none(g, self, op_name):
|
||||
return g.op(op_name, self, keepdims_i=0)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def dequantize_helper(
|
||||
g,
|
||||
qtensor: _C.Value,
|
||||
qdtype: Optional[torch.onnx.TensorProtoDataType] = None,
|
||||
qdtype: Optional[_C_onnx.TensorProtoDataType] = None,
|
||||
) -> Tuple[_C.Value, _C.Value, _C.Value, Optional[_C.Value]]:
|
||||
"""Appends to graph `g` ONNX nodes that dequantizes `qtensor` into `tensor`.
|
||||
|
||||
@ -1485,6 +1560,7 @@ def dequantize_helper(
|
||||
)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def quantize_helper(
|
||||
g,
|
||||
tensor: _C.Value,
|
||||
@ -1538,6 +1614,7 @@ def quantize_helper(
|
||||
return g.op("prim::TupleConstruct", *args)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def requantize_bias_helper(g, bias, input_scale, weight_scale, axis=None):
|
||||
"""In PyTorch, bias is float and is quantized to int32 implicitly inside the quantized ATen op kernel.
|
||||
In ONNX we need to make the quantization explicit because operators expect all of their inputs to be quantized.
|
||||
@ -1558,6 +1635,7 @@ def requantize_bias_helper(g, bias, input_scale, weight_scale, axis=None):
|
||||
return g.op("prim::TupleConstruct", q_bias, bias_scale, bias_zero_point, *axis_args)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def args_have_same_dtype(args):
|
||||
assert args
|
||||
base_dtype = args[0].type().scalarType()
|
||||
|
@ -16,6 +16,7 @@ from torch.onnx import ( # noqa: F401
|
||||
symbolic_opset9 as opset9,
|
||||
)
|
||||
from torch.onnx._globals import GLOBALS
|
||||
from torch.onnx._internal import _beartype
|
||||
|
||||
# EDITING THIS FILE? READ THIS FIRST!
|
||||
# see Note [Edit Symbolic Files] in symbolic_helper.py
|
||||
@ -58,6 +59,7 @@ __all__ = [
|
||||
]
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def div(g, self, other, *args):
|
||||
if len(args) == 0:
|
||||
return opset9.true_divide(g, self, other)
|
||||
@ -66,6 +68,7 @@ def div(g, self, other, *args):
|
||||
|
||||
|
||||
@symbolic_helper.parse_args("v", "v", "s")
|
||||
@_beartype.beartype
|
||||
def _div_rounding_mode(g, self, other, rounding_mode):
|
||||
if rounding_mode == "floor":
|
||||
return _floor_divide(g, self, other)
|
||||
@ -73,6 +76,7 @@ def _div_rounding_mode(g, self, other, rounding_mode):
|
||||
return opset9._div_rounding_mode(g, self, other, rounding_mode)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _floor_divide(g, self, other):
|
||||
if symbolic_helper._is_fp(self) or symbolic_helper._is_fp(other):
|
||||
out = opset9.true_divide(g, self, other)
|
||||
@ -94,20 +98,24 @@ def _floor_divide(g, self, other):
|
||||
|
||||
|
||||
@symbolic_helper.parse_args("v", "i", "i", "none")
|
||||
@_beartype.beartype
|
||||
def sort(g, self, dim, decending, out=None):
|
||||
return symbolic_helper._sort_helper(g, self, dim, decending=decending, out=out)
|
||||
|
||||
|
||||
@symbolic_helper.parse_args("v", "v", "i", "i", "i", "none")
|
||||
@_beartype.beartype
|
||||
def topk(g, self, k, dim, largest, sorted, out=None):
|
||||
return symbolic_helper._topk_helper(
|
||||
g, self, k, dim, largest=largest, sorted=sorted, out=out
|
||||
)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _max_pool(name, tuple_fn, ndims, return_indices):
|
||||
@symbolic_helper.quantized_args(True, False, False, False, False, False)
|
||||
@symbolic_helper.parse_args("v", "is", "is", "is", "is", "i")
|
||||
@_beartype.beartype
|
||||
def symbolic_fn(g, input, kernel_size, stride, padding, dilation, ceil_mode):
|
||||
if not stride:
|
||||
stride = kernel_size
|
||||
@ -178,9 +186,11 @@ max_pool3d_with_indices = _max_pool(
|
||||
)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _avg_pool(name, tuple_fn):
|
||||
@symbolic_helper.quantized_args(True, False, False, False, False, False, False)
|
||||
@symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none")
|
||||
@_beartype.beartype
|
||||
def symbolic_fn(
|
||||
g,
|
||||
input: _C.Value,
|
||||
@ -196,6 +206,7 @@ def _avg_pool(name, tuple_fn):
|
||||
padding = symbolic_helper._avgpool_helper(
|
||||
tuple_fn, padding, kernel_size, stride, divisor_override, name
|
||||
)
|
||||
assert isinstance(padding, tuple)
|
||||
if count_include_pad:
|
||||
input = opset9.op_with_optional_float_cast(
|
||||
g,
|
||||
@ -225,8 +236,10 @@ avg_pool2d = _avg_pool("avg_pool2d", torch.nn.modules.utils._pair)
|
||||
avg_pool3d = _avg_pool("avg_pool3d", torch.nn.modules.utils._triple)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _interpolate(name, dim, interpolate_mode):
|
||||
@symbolic_helper.quantized_args(True, False, False)
|
||||
@_beartype.beartype
|
||||
def symbolic_fn(g, input, output_size, *args):
|
||||
scales, align_corners = symbolic_helper._get_interpolate_attributes(
|
||||
g, interpolate_mode, args
|
||||
@ -252,6 +265,7 @@ upsample_bilinear2d = _interpolate("upsample_bilinear2d", 4, "linear")
|
||||
upsample_trilinear3d = _interpolate("upsample_trilinear3d", 5, "linear")
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def __interpolate(
|
||||
g, input, size, scale_factor, mode, align_corners, recompute_scale_factor, antialias
|
||||
):
|
||||
@ -261,6 +275,7 @@ def __interpolate(
|
||||
return g.op("Resize", input, scales, mode_s=mode)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _slice(g, input, axes, starts, ends, steps=None, dynamic_slice=False):
|
||||
if dynamic_slice:
|
||||
starts = symbolic_helper._unsqueeze_helper(g, starts, [0])
|
||||
@ -288,6 +303,7 @@ def _slice(g, input, axes, starts, ends, steps=None, dynamic_slice=False):
|
||||
return g.op("Slice", input, starts, ends, axes, steps)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def slice(g, self, *args):
|
||||
if len(args) == 4:
|
||||
# aten::slice(Tensor self, int dim, int? start=None, int? end=None, int step=1) -> Tensor
|
||||
@ -336,6 +352,7 @@ def slice(g, self, *args):
|
||||
|
||||
|
||||
@symbolic_helper.parse_args("v", "is")
|
||||
@_beartype.beartype
|
||||
def flip(g, input, dims):
|
||||
return symbolic_helper._slice_helper(
|
||||
g,
|
||||
@ -347,11 +364,13 @@ def flip(g, input, dims):
|
||||
)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def fmod(g, input, other):
|
||||
return g.op("Mod", input, other, fmod_i=1)
|
||||
|
||||
|
||||
@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i")
|
||||
@_beartype.beartype
|
||||
def embedding_bag(
|
||||
g,
|
||||
embedding_matrix,
|
||||
@ -437,6 +456,7 @@ def embedding_bag(
|
||||
|
||||
|
||||
@symbolic_helper.parse_args("v", "v", "v", "i", "i")
|
||||
@_beartype.beartype
|
||||
def fake_quantize_per_tensor_affine(
|
||||
g, inputs, scale, zero_point, quant_min=-128, quant_max=127
|
||||
):
|
||||
@ -478,10 +498,12 @@ def fake_quantize_per_tensor_affine(
|
||||
)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def isinf(g, input):
|
||||
return g.op("IsInf", opset9._cast_Double(g, input, False)) # type: ignore[attr-defined]
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def isfinite(g, input):
|
||||
from torch.onnx.symbolic_opset9 import __not_, __or_
|
||||
|
||||
@ -490,6 +512,7 @@ def isfinite(g, input):
|
||||
return __not_(g, __or_(g, inf_node, nan_node))
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def quantize_per_tensor(g, input, scale, zero_point, dtype):
|
||||
dtype = symbolic_helper._get_const(dtype, "i", "dtype")
|
||||
# TODO(justinchuby): Extract all the cast ops into a helper function.
|
||||
@ -500,11 +523,13 @@ def quantize_per_tensor(g, input, scale, zero_point, dtype):
|
||||
return symbolic_helper.quantize_helper(g, input, scale, zero_point)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def dequantize(g, input):
|
||||
return symbolic_helper.dequantize_helper(g, input)[0]
|
||||
|
||||
|
||||
@symbolic_helper.parse_args("v", "f", "f", "f")
|
||||
@_beartype.beartype
|
||||
def nan_to_num(g, input, nan, posinf, neginf):
|
||||
# Cannot create a int type tensor with inf/nan values, so we simply
|
||||
# return the original tensor
|
||||
@ -566,6 +591,7 @@ class Quantized:
|
||||
domain = "quantized"
|
||||
|
||||
@staticmethod
|
||||
@_beartype.beartype
|
||||
def linear(g, q_input, q_weight, bias, op_scale, op_zero_point):
|
||||
input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
|
||||
weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
|
||||
@ -579,6 +605,7 @@ class Quantized:
|
||||
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
|
||||
|
||||
@staticmethod
|
||||
@_beartype.beartype
|
||||
def add(g, x, y, op_scale, op_zero_point):
|
||||
x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
|
||||
y, _, _, _ = symbolic_helper.dequantize_helper(g, y)
|
||||
@ -588,6 +615,7 @@ class Quantized:
|
||||
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
|
||||
|
||||
@staticmethod
|
||||
@_beartype.beartype
|
||||
def add_relu(g, x, y, op_scale, op_zero_point):
|
||||
x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
|
||||
y, _, _, _ = symbolic_helper.dequantize_helper(g, y)
|
||||
@ -598,6 +626,7 @@ class Quantized:
|
||||
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
|
||||
|
||||
@staticmethod
|
||||
@_beartype.beartype
|
||||
def mul(g, x, y, op_scale, op_zero_point):
|
||||
x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
|
||||
y, _, _, _ = symbolic_helper.dequantize_helper(g, y)
|
||||
@ -607,6 +636,7 @@ class Quantized:
|
||||
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
|
||||
|
||||
@staticmethod
|
||||
@_beartype.beartype
|
||||
def hardswish(g, x, op_scale, op_zero_point):
|
||||
x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
|
||||
|
||||
@ -615,6 +645,7 @@ class Quantized:
|
||||
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
|
||||
|
||||
@staticmethod
|
||||
@_beartype.beartype
|
||||
def sigmoid(g, x, op_scale, op_zero_point):
|
||||
x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
|
||||
|
||||
@ -623,6 +654,7 @@ class Quantized:
|
||||
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
|
||||
|
||||
@staticmethod
|
||||
@_beartype.beartype
|
||||
def leaky_relu(g, x, negative_slope, inplace, op_scale, op_zero_point):
|
||||
x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
|
||||
|
||||
@ -631,6 +663,7 @@ class Quantized:
|
||||
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
|
||||
|
||||
@staticmethod
|
||||
@_beartype.beartype
|
||||
def layer_norm(g, x, normalized_shape, weight, bias, eps, op_scale, op_zero_point):
|
||||
x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
|
||||
|
||||
@ -639,6 +672,7 @@ class Quantized:
|
||||
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
|
||||
|
||||
@staticmethod
|
||||
@_beartype.beartype
|
||||
def group_norm(g, x, num_groups, weight, bias, eps, op_scale, op_zero_point):
|
||||
x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
|
||||
|
||||
@ -648,6 +682,7 @@ class Quantized:
|
||||
|
||||
@staticmethod
|
||||
@symbolic_helper.parse_args("v", "v", "v", "f", "v", "v")
|
||||
@_beartype.beartype
|
||||
def instance_norm(
|
||||
g,
|
||||
q_input,
|
||||
@ -660,12 +695,13 @@ class Quantized:
|
||||
input, _, _, _ = symbolic_helper.dequantize_helper(g, q_input)
|
||||
|
||||
output = opset9.instance_norm(
|
||||
g, input, weight, bias, None, None, False, 0, eps, False
|
||||
g, input, weight, bias, None, None, False, 0.0, eps, False
|
||||
)
|
||||
|
||||
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
|
||||
|
||||
@staticmethod
|
||||
@_beartype.beartype
|
||||
def conv2d_relu(
|
||||
g,
|
||||
q_input,
|
||||
@ -693,6 +729,7 @@ class Quantized:
|
||||
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
|
||||
|
||||
@staticmethod
|
||||
@_beartype.beartype
|
||||
def conv2d(
|
||||
g,
|
||||
q_input,
|
||||
@ -720,6 +757,7 @@ class Quantized:
|
||||
|
||||
@staticmethod
|
||||
@symbolic_helper.parse_args("v", "i", "v", "v")
|
||||
@_beartype.beartype
|
||||
def cat(
|
||||
g,
|
||||
q_inputs: _C.Value,
|
||||
|
@@ -2,7 +2,7 @@

import sys
import warnings
from typing import Tuple, Union
from typing import Optional, Sequence, Union

import torch
from torch import _C
@ -16,6 +16,7 @@ from torch.onnx import (
|
||||
utils,
|
||||
)
|
||||
from torch.onnx._globals import GLOBALS
|
||||
from torch.onnx._internal import _beartype
|
||||
|
||||
# EDITING THIS FILE? READ THIS FIRST!
|
||||
# see Note [Edit Symbolic Files] in symbolic_helper.py
|
||||
@ -94,6 +95,7 @@ __all__ = [
|
||||
|
||||
@symbolic_helper.quantized_args(True)
|
||||
@symbolic_helper.parse_args("v", "f", "f")
|
||||
@_beartype.beartype
|
||||
def hardtanh(g, self: _C.Value, min_val: float, max_val: float):
|
||||
dtype = self.type().scalarType()
|
||||
if dtype is None:
|
||||
@ -113,9 +115,11 @@ def hardtanh(g, self: _C.Value, min_val: float, max_val: float):
|
||||
)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def clamp(g, self, min, max):
|
||||
dtype = self.type().scalarType()
|
||||
|
||||
@_beartype.beartype
|
||||
def _cast_if_not_none(tensor, dtype):
|
||||
if tensor is not None and not symbolic_helper._is_none(tensor):
|
||||
return g.op(
|
||||
@ -147,6 +151,7 @@ def clamp(g, self, min, max):
|
||||
|
||||
|
||||
@symbolic_helper.parse_args("v", "v")
|
||||
@_beartype.beartype
|
||||
def clamp_min(g, self, min):
|
||||
dtype = self.type().scalarType()
|
||||
min = g.op("Cast", min, to_i=_type_utils.JitScalarType.from_name(dtype).onnx_type())
|
||||
@ -160,6 +165,7 @@ def clamp_min(g, self, min):
|
||||
|
||||
|
||||
@symbolic_helper.parse_args("v", "v")
|
||||
@_beartype.beartype
|
||||
def clamp_max(g, self, max):
|
||||
dtype = self.type().scalarType()
|
||||
max = g.op("Cast", max, to_i=_type_utils.JitScalarType.from_name(dtype).onnx_type())
|
||||
@ -172,6 +178,7 @@ def clamp_max(g, self, max):
|
||||
return opset9.op_with_optional_float_cast(g, "Min", self, max, opset_before=12)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def relu6(g, input):
|
||||
relu_ = opset9.op_with_optional_float_cast(g, "Relu", input, opset_before=14)
|
||||
dtype = input.type().scalarType()
|
||||
@ -193,10 +200,12 @@ def relu6(g, input):
|
||||
# Opset 11 gather accepts negative indices
|
||||
@symbolic_helper.quantized_args(True)
|
||||
@symbolic_helper.parse_args("v", "i", "v")
|
||||
@_beartype.beartype
|
||||
def select(g, self, dim, index):
|
||||
return g.op("Gather", self, index, axis_i=dim)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def index_put(g, self, indices_list_value, values, accumulate=False):
|
||||
if symbolic_helper._is_packed_list(indices_list_value):
|
||||
indices_list = symbolic_helper._unpack_list(indices_list_value)
|
||||
@ -307,6 +316,7 @@ def index_put(g, self, indices_list_value, values, accumulate=False):
|
||||
|
||||
|
||||
@symbolic_helper.parse_args("v", "i")
|
||||
@_beartype.beartype
|
||||
def pixel_shuffle(g, self, upscale_factor):
|
||||
rank = symbolic_helper._get_tensor_rank(self)
|
||||
if rank is not None and rank != 4:
|
||||
@ -314,6 +324,7 @@ def pixel_shuffle(g, self, upscale_factor):
|
||||
return g.op("DepthToSpace", self, blocksize_i=upscale_factor, mode_s="CRD")
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _interpolate(name, dim, interpolate_mode):
|
||||
return symbolic_helper._interpolate_helper(name, dim, interpolate_mode)
|
||||
|
||||
@ -336,6 +347,7 @@ upsample_bicubic2d.__module__ = "torch.onnx.symbolic_opset11"
|
||||
|
||||
|
||||
@symbolic_helper.quantized_args(True, False, False, False, False, False, False)
|
||||
@_beartype.beartype
|
||||
def __interpolate(
|
||||
g, input, size, scale_factor, mode, align_corners, recompute_scale_factor, antialias
|
||||
):
|
||||
@ -345,6 +357,7 @@ def __interpolate(
|
||||
|
||||
|
||||
@symbolic_helper.parse_args("v", "i", "v", "v")
|
||||
@_beartype.beartype
|
||||
def gather(g, self, dim, index, sparse_grad=False):
|
||||
if symbolic_helper._maybe_get_const(sparse_grad, "i"):
|
||||
return symbolic_helper._unimplemented("gather", "sparse_grad == True")
|
||||
@ -354,6 +367,7 @@ def gather(g, self, dim, index, sparse_grad=False):
|
||||
|
||||
|
||||
@symbolic_helper.parse_args("v", "i", "v", "v")
|
||||
@_beartype.beartype
|
||||
def scatter(g, self, dim, index, src):
|
||||
if symbolic_helper.is_caffe2_aten_fallback():
|
||||
return g.at("scatter", self, dim, index, src, overload_name="src")
|
||||
@ -378,6 +392,7 @@ def scatter(g, self, dim, index, src):
|
||||
|
||||
|
||||
@symbolic_helper.parse_args("v", "i", "none")
|
||||
@_beartype.beartype
|
||||
def cumsum(g, self, dim, dtype=None):
|
||||
dim_tensor = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.int))
|
||||
if dtype and dtype.node().kind() != "prim::Constant":
|
||||
@ -391,11 +406,13 @@ def cumsum(g, self, dim, dtype=None):
|
||||
return csum
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def masked_select(g, self, mask):
|
||||
index = opset9.nonzero(g, opset9.expand_as(g, mask, self))
|
||||
return g.op("GatherND", self, index)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def masked_scatter(g, self, mask, source):
|
||||
index = opset9.nonzero(g, opset9.expand_as(g, mask, self))
|
||||
# NOTE: source can have more elements than needed.
|
||||
@ -413,6 +430,7 @@ def masked_scatter(g, self, mask, source):
|
||||
return g.op("ScatterND", self, index, source)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _len(g, self):
|
||||
if (
|
||||
symbolic_helper._is_tensor_list(self)
|
||||
@ -423,6 +441,7 @@ def _len(g, self):
|
||||
return symbolic_helper._squeeze_helper(g, sz_0, [0])
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def __getitem_(g, self, i):
|
||||
if symbolic_helper._is_tensor_list(self):
|
||||
# SequenceAt requires that the input be a List of Tensors
|
||||
@ -433,15 +452,18 @@ def __getitem_(g, self, i):
|
||||
return getitem(g, self, i)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _set_item(g, tensor_list, i, v):
|
||||
tensor_list = g.op("SequenceErase", tensor_list, i)
|
||||
return g.op("SequenceInsert", tensor_list, v, i)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def append(g, self, tensor):
|
||||
return g.op("SequenceInsert", self, tensor)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def add(g, self, other, alpha=None):
|
||||
if symbolic_helper._is_value(self) and symbolic_helper._is_tensor_list(self):
|
||||
tensor_list_node = other.node()
|
||||
@ -458,18 +480,22 @@ def add(g, self, other, alpha=None):
|
||||
return opset9.add(g, self, other, alpha)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def insert(g, self, pos, tensor):
|
||||
return g.op("SequenceInsert", self, tensor, pos)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def pop(g, tensor_list, dim):
|
||||
return g.op("SequenceErase", tensor_list, dim)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def Delete(g, tensor_list, dim):
|
||||
return g.op("SequenceErase", tensor_list, dim)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def cat(g, tensor_list, dim):
|
||||
if symbolic_helper._is_packed_list(tensor_list):
|
||||
return opset9.cat(g, tensor_list, dim)
|
||||
@ -478,6 +504,7 @@ def cat(g, tensor_list, dim):
|
||||
return g.op("ConcatFromSequence", tensor_list, axis_i=dim)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def stack(g, tensor_list, dim):
|
||||
if symbolic_helper._is_packed_list(tensor_list):
|
||||
return opset9.stack(g, tensor_list, dim)
|
||||
@ -487,6 +514,7 @@ def stack(g, tensor_list, dim):
|
||||
|
||||
|
||||
@symbolic_helper.parse_args("v", "i", "i", "i")
|
||||
@_beartype.beartype
|
||||
def _unique2(g, self, sorted, return_inverse, return_counts):
|
||||
u, indices, inverse_indices, counts = g.op(
|
||||
"Unique", self, sorted_i=sorted, outputs=4
|
||||
@@ -494,15 +522,17 @@ def _unique2(g, self, sorted, return_inverse, return_counts):
return u, inverse_indices, counts


@_beartype.beartype
def _avg_pool(name, tuple_fn):
@symbolic_helper.quantized_args(True, False, False, False, False, False, False)
@symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none")
@_beartype.beartype
def symbolic_fn(
g,
input: _C.Value,
kernel_size: Tuple[int, ...],
stride: Tuple[int, ...],
padding: Union[int, Tuple[int, ...]],
kernel_size: Sequence[int],
stride: Sequence[int],
padding: Union[int, Sequence[int]],
ceil_mode: int,
count_include_pad: int,
divisor_override=None,
@ -510,6 +540,7 @@ def _avg_pool(name, tuple_fn):
|
||||
padding = symbolic_helper._avgpool_helper(
|
||||
tuple_fn, padding, kernel_size, stride, divisor_override, name
|
||||
)
|
||||
assert isinstance(padding, tuple)
|
||||
if not stride:
|
||||
stride = kernel_size
|
||||
if count_include_pad:
|
||||
@ -539,6 +570,7 @@ avg_pool3d = _avg_pool("avg_pool3d", torch.nn.modules.utils._triple)
|
||||
|
||||
|
||||
@symbolic_helper.parse_args("v", "i", "i", "i", "i")
|
||||
@_beartype.beartype
|
||||
def unique_dim(g, self, dim, sorted, return_inverse, return_counts):
|
||||
u, indices, inverse_indices, counts = g.op(
|
||||
"Unique", self, axis_i=dim, sorted_i=sorted, outputs=4
|
||||
@ -547,6 +579,7 @@ def unique_dim(g, self, dim, sorted, return_inverse, return_counts):
|
||||
|
||||
|
||||
@symbolic_helper.parse_args("v", "v", "i", "i", "i", "none")
|
||||
@_beartype.beartype
|
||||
def topk(g, self, k, dim, largest, sorted, out=None):
|
||||
return symbolic_helper._topk_helper(
|
||||
g, self, k, dim, largest=largest, sorted=sorted, out=out
|
||||
@ -554,11 +587,13 @@ def topk(g, self, k, dim, largest, sorted, out=None):
|
||||
|
||||
|
||||
@symbolic_helper.parse_args("v", "i", "i", "none")
|
||||
@_beartype.beartype
|
||||
def sort(g, self, dim, decending, out=None):
|
||||
return symbolic_helper._sort_helper(g, self, dim, decending=decending, out=out)
|
||||
|
||||
|
||||
@symbolic_helper.parse_args("v", "i", "i", "none")
|
||||
@_beartype.beartype
|
||||
def argsort(g, self, dim, decending, out=None):
|
||||
_, indices = symbolic_helper._sort_helper(
|
||||
g, self, dim, decending=decending, out=out
|
||||
@ -566,10 +601,12 @@ def argsort(g, self, dim, decending, out=None):
|
||||
return indices
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def round(g, self):
|
||||
return g.op("Round", self)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def remainder(g, input, other):
|
||||
if symbolic_helper._is_fp(input) or symbolic_helper._is_fp(other):
|
||||
return opset9.remainder(g, input, other)
|
||||
@ -577,6 +614,7 @@ def remainder(g, input, other):
|
||||
|
||||
|
||||
@symbolic_helper.parse_args("v", "v", "i", "i")
|
||||
@_beartype.beartype
|
||||
def split(g, self, split_size_or_sizes, dim, _outputs=None):
|
||||
if not symbolic_helper._is_split_static(split_size_or_sizes, _outputs):
|
||||
split_out = g.op("SplitToSequence", self, split_size_or_sizes, axis_i=dim)
|
||||
@ -614,11 +652,13 @@ def split(g, self, split_size_or_sizes, dim, _outputs=None):
|
||||
|
||||
|
||||
@symbolic_helper.parse_args("v", "v", "i", "i")
|
||||
@_beartype.beartype
|
||||
def split_with_sizes(g, self, split_sizes, dim, _outputs=None):
|
||||
return split(g, self, split_sizes, dim, _outputs)
|
||||
|
||||
|
||||
@symbolic_helper.parse_args("v", "i", "i")
|
||||
@_beartype.beartype
|
||||
def unbind(g, self, dim=0, _outputs=None):
|
||||
if _outputs is None:
|
||||
return g.op(
|
||||
@ -638,6 +678,7 @@ def unbind(g, self, dim=0, _outputs=None):
|
||||
# pad: the paddings in pytorch.
|
||||
# The order is dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, ..., dim_m_begin, dim_m_end,
|
||||
# where m is in range [0, n].
|
||||
@_beartype.beartype
|
||||
def _prepare_onnx_paddings(g, input, pad):
|
||||
if (
|
||||
not symbolic_helper._is_packed_list(pad)
|
||||
@ -687,6 +728,7 @@ def _prepare_onnx_paddings(g, input, pad):
|
||||
return padding_c
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def constant_pad_nd(g, input, padding, value=None):
|
||||
mode = "constant"
|
||||
value = symbolic_helper._maybe_get_scalar(value)
|
||||
@ -695,12 +737,14 @@ def constant_pad_nd(g, input, padding, value=None):
|
||||
return g.op("Pad", input, pad, value, mode_s=mode)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def reflection_pad(g, input, padding):
|
||||
mode = "reflect"
|
||||
paddings = _prepare_onnx_paddings(g, input, padding)
|
||||
return g.op("Pad", input, paddings, mode_s=mode)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def replication_pad(g, input, padding):
|
||||
mode = "edge"
|
||||
paddings = _prepare_onnx_paddings(g, input, padding)
|
||||
@ -715,6 +759,7 @@ replication_pad2d = replication_pad
|
||||
replication_pad3d = replication_pad
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def pad(g, input, pad, mode, value):
|
||||
mode = symbolic_helper._parse_arg(mode, "s")
|
||||
if mode == "replicate":
|
||||
@ -729,15 +774,19 @@ def pad(g, input, pad, mode, value):
|
||||
raise errors.SymbolicValueError(f"Unrecognized padding mode {mode}", input)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def linalg_det(g, self):
|
||||
return g.op("Det", self)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def logdet(g, input):
|
||||
return opset9.log(g, linalg_det(g, input))
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def arange(g, *args):
|
||||
@_beartype.beartype
|
||||
def _get_arange_dtype(dtype):
|
||||
dtype = symbolic_helper._maybe_get_const(dtype, "i")
|
||||
return dtype
|
||||
@ -790,6 +839,7 @@ def arange(g, *args):
|
||||
|
||||
|
||||
@symbolic_helper.parse_args("v", "i")
|
||||
@_beartype.beartype
|
||||
def _dim_arange(g, like, dim):
|
||||
like_shape = g.op("Shape", like)
|
||||
stop = g.op(
|
||||
@ -800,12 +850,14 @@ def _dim_arange(g, like, dim):
|
||||
return arange(g, stop, 4, None, None, None)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def size(g, self, dim=None):
|
||||
if dim is None:
|
||||
return g.op("Shape", self)
|
||||
return symbolic_helper._size_helper(g, self, dim)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def squeeze(g, self, dim=None):
|
||||
if dim is None:
|
||||
return g.op("Squeeze", self)
|
||||
@ -857,6 +909,7 @@ def squeeze(g, self, dim=None):
|
||||
return symbolic_helper._squeeze_helper(g, self, [dim])
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def unsqueeze(g, self, dim):
|
||||
if symbolic_helper._is_constant(dim):
|
||||
dim = symbolic_helper._get_const(dim, "i", "dim")
|
||||
@ -864,10 +917,12 @@ def unsqueeze(g, self, dim):
|
||||
return symbolic_helper._unsqueeze_helper(g, self, [dim])
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def mm(g, self, other):
|
||||
return g.op("Gemm", self, other, beta_f=0.0, alpha_f=1.0)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def index(g, self, index):
|
||||
if symbolic_helper.is_caffe2_aten_fallback():
|
||||
return g.at("index", self, index, overload_name="Tensor")
|
||||
@ -888,6 +943,7 @@ def index(g, self, index):
|
||||
return opset9.index(g, self, index)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def index_fill(g, self, dim, index, value):
|
||||
dim_value = symbolic_helper._parse_arg(dim, "i")
|
||||
if symbolic_helper.is_caffe2_aten_fallback():
|
||||
@ -909,6 +965,7 @@ def index_fill(g, self, dim, index, value):
|
||||
return scatter(g, self, dim, expanded_index, expanded_value)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def index_copy(g, self, dim, index, source):
|
||||
dim_value = symbolic_helper._parse_arg(dim, "i")
|
||||
if symbolic_helper.is_caffe2_aten_fallback():
|
||||
@ -919,6 +976,7 @@ def index_copy(g, self, dim, index, source):
|
||||
return scatter(g, self, dim, expanded_index, source)
|
||||
|
||||
|
||||
@_beartype.beartype
def __rshift_(g, self, other):
# make sure to cast other to self's type
# (when self is long, make sure that other is not float)

@ -948,6 +1006,7 @@ def __rshift_(g, self, other):
return rshift

@_beartype.beartype
def __lshift_(g, self, other):
# make sure to cast other to self's type
# (when self is long, make sure that other is not float)

@ -977,6 +1036,7 @@ def __lshift_(g, self, other):
return lshift

@_beartype.beartype
def _get_im2col_indices_along_dim(
g, input_d, kernel_size_d, dilation_d, padding_d, stride_d
):

@ -1020,6 +1080,7 @@ def _get_im2col_indices_along_dim(
return block_mask

@_beartype.beartype
def _get_im2col_padded_input(g, input, padding_h, padding_w):
# Input is always 4-D tensor (N, C, H, W)
# Padding tensor has the following format: (padding_h, padding_w)

@ -1028,6 +1089,7 @@ def _get_im2col_padded_input(g, input, padding_h, padding_w):
return g.op("Pad", input, pad)

@_beartype.beartype
def _get_im2col_output_shape(g, input, kernel_h, kernel_w):
batch_dim = size(g, input, g.op("Constant", value_t=torch.tensor(0)))
channel_dim = size(g, input, g.op("Constant", value_t=torch.tensor(1)))

@ -1045,6 +1107,7 @@ def _get_im2col_output_shape(g, input, kernel_h, kernel_w):

@symbolic_helper.parse_args("v", "is", "is", "is", "is")
@_beartype.beartype
def im2col(g, input, kernel_size, dilation, padding, stride):
# Input is always 4-D tensor (N, C, H, W)
# All other args are int[2]

@ -1096,6 +1159,7 @@ def im2col(g, input, kernel_size, dilation, padding, stride):
return symbolic_helper._reshape_helper(g, output, output_shape)

@_beartype.beartype
def narrow(g, input, dim, start, length):
end = g.op("Add", start, length)
return symbolic_helper._slice_helper(

@ -1105,6 +1169,7 @@ def narrow(g, input, dim, start, length):

@symbolic_helper.quantized_args(True, False, False)
@symbolic_helper.parse_args("v", "i", "i")
@_beartype.beartype
def flatten(g, input, start_dim, end_dim):
dim = symbolic_helper._get_tensor_rank(input)
if dim == 1:

@ -1129,14 +1194,17 @@ def flatten(g, input, start_dim, end_dim):
return symbolic_helper._flatten_helper(g, input, start_dim, end_dim, dim)

@symbolic_helper.parse_args("v", "f", "is", "i", "v")
def linalg_vector_norm(g, self, ord, dim, keepdim, dtype):
@symbolic_helper.parse_args("v", "f", "is", "b", "v")
@_beartype.beartype
def linalg_vector_norm(
g, self, ord, dim: Optional[Sequence[int]], keepdim: bool, dtype
):
if ord == 0:
if dim is None:
self = symbolic_helper._reshape_helper(
g, self, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))
)
keepdim = 0
keepdim = False

cond_op = g.op(
"Not", g.op("Equal", self, g.op("Constant", value_t=torch.LongTensor([0])))

@ -1156,6 +1224,7 @@ def linalg_vector_norm(g, self, ord, dim, keepdim, dtype):
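The `linalg_vector_norm` hunk shows the pattern repeated throughout this PR: the `parse_args` descriptor for `keepdim` moves from `"i"` to `"b"` so the value handed to the function matches the new `bool` annotation that the runtime checker verifies. A toy sketch of how the descriptors line up with annotations (illustrative only, not code from the PR):

```
# The one-letter descriptors tell parse_args how to unwrap the incoming JIT
# values before the body runs: "v" raw value, "i" int, "f" float, "b" bool,
# "s" str, "is" list of ints, "none" optional constant.
from torch.onnx import symbolic_helper

@symbolic_helper.parse_args("v", "f", "is", "b", "v")
def toy_vector_norm(g, self, ord: float, dim, keepdim: bool, dtype):
    # With "b", `keepdim` arrives as a real Python bool, so a bool annotation
    # (and a runtime type checker such as beartype) accepts it.
    return g.op("ReduceL2", self, axes_i=dim, keepdims_i=int(keepdim))
```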
@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i")
@_beartype.beartype
def embedding_bag(
g,
embedding_matrix,

@ -1243,6 +1312,7 @@ def embedding_bag(

@symbolic_helper.parse_args("v", "v", "f", "f")
@_beartype.beartype
def embedding_renorm(g, weight, indices, max_norm, norm_type):
unique_indices = g.op("Unique", indices)
partial_weight = g.op("Gather", weight, unique_indices)

@ -1280,6 +1350,7 @@ def embedding_renorm(g, weight, indices, max_norm, norm_type):
)

@_beartype.beartype
def chunk(g, self, chunks, dim):
# Calculate chunk size for dynamic chunk
dim_size = g.op("Gather", g.op("Shape", self), dim, axis_i=0)

@ -1296,6 +1367,7 @@ def chunk(g, self, chunks, dim):
return split(g, self, chunk_vec, dim)

@_beartype.beartype
def normal(
g,
mean,

@ -1322,6 +1394,7 @@ class Prim:
domain = "prim"

@staticmethod
@_beartype.beartype
def ConstantChunk(g, self, chunks, dim):
input_shape = g.op("Shape", self)
axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long))

@ -10,6 +10,7 @@ from torch.onnx import (
symbolic_opset9 as opset9,
utils,
)
from torch.onnx._internal import _beartype

# EDITING THIS FILE? READ THIS FIRST!

@ -38,6 +39,7 @@ __all__ = [
]

@_beartype.beartype
def _einsum_helper(g, equation, tensors):
if not tensors:
raise RuntimeError("Einsum inputs are empty.")

@ -57,12 +59,14 @@ def _einsum_helper(g, equation, tensors):

@symbolic_helper.parse_args("s", "v")
@_beartype.beartype
def einsum(g, equation, tensor_list):
tensors = symbolic_helper._unpack_list(tensor_list)
return _einsum_helper(g, equation, tensors)
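`torch.einsum` lowers through this symbolic, which is available from opset 12. A quick export check (illustrative usage, not part of the PR):

```
import io
import torch

class Outer(torch.nn.Module):
    def forward(self, x, y):
        # outer product via einsum, handled by the opset 12 symbolic above
        return torch.einsum("i,j->ij", x, y)

buf = io.BytesIO()
torch.onnx.export(Outer(), (torch.randn(3), torch.randn(4)), buf, opset_version=12)
```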
@symbolic_helper.parse_args("v", "v")
@_beartype.beartype
def outer(g, input, other):
# make sure to cast other to self's type
if other.type().scalarType() != input.type().scalarType():

@ -76,6 +80,7 @@ def outer(g, input, other):
return _einsum_helper(g, "i,j->ij", [input, other])

@_beartype.beartype
def _dropout_returns_masked_input_and_mask(
g, input: torch._C.Value, p: float, train: bool
) -> Tuple[torch._C.Value, Optional[torch._C.Value]]:

@ -90,17 +95,20 @@ def _dropout_returns_masked_input_and_mask(
return r, mask

@symbolic_helper.parse_args("v", "f", "i")
@symbolic_helper.parse_args("v", "f", "b")
@_beartype.beartype
def dropout(g, input, p, train):
masked, _ = _dropout_returns_masked_input_and_mask(g, input, p, train)
return masked

@symbolic_helper.parse_args("v", "f", "i")
@symbolic_helper.parse_args("v", "f", "b")
@_beartype.beartype
def native_dropout(g, input, p, train):
return _dropout_returns_masked_input_and_mask(g, input, p, train)

@_beartype.beartype
def nll_loss(g, self, target, weight, reduction, ignore_index):
# none reduction : onnx::Constant[value={0}]
# mean reduction : onnx::Constant[value={1}]

@ -133,14 +141,17 @@ def nll_loss(g, self, target, weight, reduction, ignore_index):
return nllloss

@_beartype.beartype
def nll_loss2d(g, self, target, weight, reduction, ignore_index):
return nll_loss(g, self, target, weight, reduction, ignore_index)

@_beartype.beartype
def nll_loss_nd(g, self, target, weight, reduction, ignore_index):
return nll_loss(g, self, target, weight, reduction, ignore_index)

@_beartype.beartype
def cross_entropy_loss(
g, self, target, weight, reduction, ignore_index, label_smoothing
):

@ -182,6 +193,7 @@ def cross_entropy_loss(

@symbolic_helper.parse_args("v", "v", "v", "v", "i")
@_beartype.beartype
def binary_cross_entropy_with_logits(g, input, target, weight, pos_weight, reduction):
p = g.op("Constant", value_t=torch.tensor([1]))
sig_x = opset9.sigmoid(g, input)

@ -223,6 +235,7 @@ def binary_cross_entropy_with_logits(g, input, target, weight, pos_weight, reduc
)

@_beartype.beartype
def celu(g, self, alpha):
alpha = symbolic_helper._maybe_get_const(alpha, "f")
# if the input is of type double cast it to float

@ -234,29 +247,35 @@ def celu(g, self, alpha):
return g.op("Celu", self, alpha_f=alpha)

@symbolic_helper.parse_args("v", "v", "i")
def argmax(g, input: torch._C.Value, dim: torch._C.Value, keepdim: int):
@symbolic_helper.parse_args("v", "v", "b")
@_beartype.beartype
def argmax(g, input: torch._C.Value, dim: torch._C.Value, keepdim: bool):
return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMax")

@symbolic_helper.parse_args("v", "v", "i")
def argmin(g, input: torch._C.Value, dim: torch._C.Value, keepdim: int):
@symbolic_helper.parse_args("v", "v", "b")
@_beartype.beartype
def argmin(g, input: torch._C.Value, dim: torch._C.Value, keepdim: bool):
return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMin")
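As with `argmax`/`argmin` above, the annotations have to be tightened because the runtime checker is strict about `bool` versus `int`. A minimal demonstration with the public `beartype` package (assumed to behave like the internal `torch.onnx._internal._beartype` wrapper, and only an illustration; `beartype` must be installed):

```
from beartype import beartype

@beartype
def needs_bool(keepdim: bool) -> bool:
    return keepdim

needs_bool(True)   # accepted
try:
    needs_bool(1)  # rejected: isinstance(1, bool) is False, despite bool subclassing int
except Exception as exc:  # BeartypeCallHintParamViolation in recent releases
    print("rejected:", type(exc).__name__)
```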
@_beartype.beartype
def pow(g, self, exponent):
return g.op("Pow", self, exponent)

@_beartype.beartype
def ge(g, input, other):
return g.op("GreaterOrEqual", input, other)

@_beartype.beartype
def le(g, input, other):
return g.op("LessOrEqual", input, other)

@symbolic_helper.parse_args("v", "i", "v", "v")
@_beartype.beartype
def unfold(g, input, dimension, size, step):
const_size = symbolic_helper._maybe_get_const(size, "i")
const_step = symbolic_helper._maybe_get_const(step, "i")

@ -326,6 +345,7 @@ def unfold(g, input, dimension, size, step):

@symbolic_helper.parse_args("v", "v", "is", "is", "v")
@_beartype.beartype
def tensordot(g, input_a, input_b, dims_a, dims_b, out=None):
if out is not None:
symbolic_helper._unimplemented(

@ -12,9 +12,11 @@ from torch.onnx import (
symbolic_opset9 as opset9,
utils,
)
from torch.onnx._internal import _beartype

@symbolic_helper.parse_args("v", "i", "none")
@_beartype.beartype
def softmax(g, input, dim, dtype=None):
softmax = g.op("Softmax", input, axis_i=dim)
if dtype and dtype.node().kind() != "prim::Constant":

@ -27,6 +29,7 @@ def softmax(g, input, dim, dtype=None):

@symbolic_helper.parse_args("v", "i", "none")
@_beartype.beartype
def log_softmax(g, input, dim, dtype=None):
return_op = g.op("LogSoftmax", input, axis_i=dim)
if dtype and dtype.node().kind() != "prim::Constant":

@ -38,6 +41,7 @@ def log_softmax(g, input, dim, dtype=None):

@symbolic_helper.parse_args("v", "v", "i")
@_beartype.beartype
def frobenius_norm(g, self, dim=None, keepdim=False):
dim_val = symbolic_helper._maybe_get_const(dim, "is")
if not symbolic_helper._is_value(dim_val) and len(dim_val) == 0:

@ -48,6 +52,7 @@ def frobenius_norm(g, self, dim=None, keepdim=False):

@symbolic_helper.parse_args("v", "v", "i", "i")
@_beartype.beartype
def split(g, self, split_size_or_sizes, dim, _outputs=None):
if not symbolic_helper._is_split_static(split_size_or_sizes, _outputs):
split_out = g.op("SplitToSequence", self, split_size_or_sizes, axis_i=dim)

@ -103,19 +108,23 @@ def split(g, self, split_size_or_sizes, dim, _outputs=None):
return g.op("Split", self, splits, axis_i=dim, outputs=_outputs)

@_beartype.beartype
def split_with_sizes(g, self, split_sizes, dim, _outputs=None):
return split(g, self, split_sizes, dim, _outputs)

@_beartype.beartype
def unsafe_split(g, self, split_size_or_sizes, dim, _outputs=None):
return split(g, self, split_size_or_sizes, dim, _outputs)

@_beartype.beartype
def unsafe_split_with_sizes(g, self, split_sizes, dim, _outputs=None):
return split_with_sizes(g, self, split_sizes, dim, _outputs)

@symbolic_helper.parse_args("v", "v", "i", "i")
@_beartype.beartype
def tensor_split(g, self, indices_or_sections, dim, _outputs=None):
axis = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long))
axis = opset11.unsqueeze(g, axis, 0)

@ -247,6 +256,7 @@ def tensor_split(g, self, indices_or_sections, dim, _outputs=None):

@symbolic_helper.parse_args("v", "i", "i")
@_beartype.beartype
def unbind(g, self, dim=0, _outputs=None):
if _outputs is None:
return g.op(

@ -268,11 +278,13 @@ def unbind(g, self, dim=0, _outputs=None):

# Emitted from `torch.nonzero(x, as_tuple=True)`
@_beartype.beartype
def nonzero_numpy(g, input, _outputs=None):
return unbind(g, opset9.nonzero(g, input), 1, _outputs=_outputs)
@symbolic_helper.parse_args("v", "v", "v", "i")
@_beartype.beartype
def where(g, condition, self=None, other=None, _outputs=None):
# Assumes that torch.where's first argument takes only Bool and Byte tensors.
if not symbolic_helper._is_bool(condition):

@ -286,6 +298,7 @@ def where(g, condition, self=None, other=None, _outputs=None):

@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i")
@_beartype.beartype
def fake_quantize_per_channel_affine(
g, inputs, scale, zero_point, axis, quant_min=-128, quant_max=127
):

@ -314,6 +327,7 @@ def fake_quantize_per_channel_affine(

@symbolic_helper.parse_args("v", "v", "v", "i", "i")
@_beartype.beartype
def fake_quantize_per_tensor_affine(
g, inputs, scale, zero_point, quant_min=-128, quant_max=127
):

@ -342,7 +356,9 @@ def fake_quantize_per_tensor_affine(
return g.op("DequantizeLinear", quantized, scale, zero_point)

@_beartype.beartype
def _reduce_op_symbolic(onnx_op_name):
@_beartype.beartype
def symbolic(g, self, dim=None, keepdim=None):
self = opset9._maybe_cast_reduce_op_input(g, self)
if dim is None:

@ -355,12 +371,15 @@ def _reduce_op_symbolic(onnx_op_name):
return symbolic

@_beartype.beartype
def _reduce_with_dtype(onnx_op, name):
symbolic = _reduce_op_symbolic(onnx_op)

@opset9.overload_by_arg_count
@_beartype.beartype
def reduce(g, *args, **kwargs):
@symbolic_helper.parse_args("v", "none")
@_beartype.beartype
def reduce_nodim(g, self, dtype):
if dtype.node().kind() == "onnx::Constant":
dtype = symbolic_helper._get_const(dtype, "i", "dtype")

@ -372,6 +391,7 @@ def _reduce_with_dtype(onnx_op, name):
return symbolic(g, self)

@symbolic_helper.parse_args("v", "v", "i", "none")
@_beartype.beartype
def reduce_dim(g, self, dim, keepdim, dtype):
if dtype.node().kind() == "onnx::Constant":
dtype = symbolic_helper._get_const(dtype, "i", "dtype")

@ -392,6 +412,7 @@ sum = _reduce_with_dtype("ReduceSum", "sum")

@symbolic_helper.parse_args("v", "i", "i", "i")
@_beartype.beartype
def unsafe_chunk(g, self, chunks, dim, _outputs=None):
if _outputs is None:
return g.op(

@ -418,6 +439,7 @@ def unsafe_chunk(g, self, chunks, dim, _outputs=None):
return g.op("Split", self, splits, axis_i=dim, outputs=_outputs)

@_beartype.beartype
def repeat_interleave(g, self, repeats, dim=None, output_size=None):
input = self
final_dim = dim

@ -551,6 +573,7 @@ def repeat_interleave(g, self, repeats, dim=None, output_size=None):

@symbolic_helper.parse_args("v", "i", "i", "i")
@_beartype.beartype
def diagonal(g, self, offset, dim1, dim2):
dim1_size = opset9.size(
g, self, dim=g.op("Constant", value_t=torch.LongTensor([dim1]))

@ -674,6 +697,7 @@ class Quantized:
domain = "quantized"

@staticmethod
@_beartype.beartype
def linear(g, q_input, q_weight, bias, op_scale, op_zero_point):
input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight)

@ -687,6 +711,7 @@ class Quantized:
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)

@staticmethod
@_beartype.beartype
def conv2d(
g,
q_input,

@ -713,6 +738,7 @@ class Quantized:
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)

@staticmethod
@_beartype.beartype
def conv2d_relu(
g,
q_input,

@ -18,26 +18,31 @@ Updated operators:
import torch
from torch.onnx import symbolic_helper
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import _beartype

@symbolic_helper.parse_args("v")
@_beartype.beartype
def hardswish(g, self):
return g.op("HardSwish", self)

@symbolic_helper.parse_args("v", "i")
@_beartype.beartype
def tril(g, self, diagonal, out=None):
k = g.op("Constant", value_t=torch.tensor(diagonal, dtype=torch.int64))
return g.op("Trilu", self, k, upper_i=0)

@symbolic_helper.parse_args("v", "i")
@_beartype.beartype
def triu(g, self, diagonal, out=None):
k = g.op("Constant", value_t=torch.tensor(diagonal, dtype=torch.int64))
return g.op("Trilu", self, k, upper_i=1)
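`torch.tril` and `torch.triu` map onto the ONNX `Trilu` operator, so exporting them needs `opset_version` 14 or later. An illustrative usage check (not from the PR):

```
import io
import torch

class LowerTri(torch.nn.Module):
    def forward(self, x):
        # lowered to a Trilu node with upper=0 by the symbolic above
        return torch.tril(x, diagonal=1)

buf = io.BytesIO()
torch.onnx.export(LowerTri(), (torch.randn(4, 4),), buf, opset_version=14)
```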
@symbolic_helper.parse_args("v", "v")
@_beartype.beartype
def reshape(g, self, shape):
# NOTE: Due to bug in ORT https://github.com/microsoft/onnxruntime/issues/10664
# Reshape export cannot utilize the new allowzero attribute introduced in opset 14.

@ -45,6 +50,7 @@ def reshape(g, self, shape):

@symbolic_helper.parse_args("v", "v", "v", "v", "v", "i", "f", "f", "i")
@_beartype.beartype
def batch_norm(
g,
input,

@ -107,6 +113,7 @@ class Quantized:
domain = "quantized"

@staticmethod
@_beartype.beartype
def hardswish(g, x, op_scale, op_zero_point):
x, _, _, _ = symbolic_helper.dequantize_helper(g, x)

@ -28,8 +28,10 @@ Updated operators:
import torch
from torch import _C
from torch.onnx import symbolic_helper, symbolic_opset9 as opset9
from torch.onnx._internal import _beartype

@_beartype.beartype
def __is_(g, self, other):
if symbolic_helper._is_none(other):
if isinstance(self.type(), _C.OptionalType):

@ -41,6 +43,7 @@ def __is_(g, self, other):

@opset9.wrap_logical_op_with_negation
@_beartype.beartype
def __isnot_(g, self, other):
return __is_(g, self, other)

@ -49,6 +52,7 @@ class Prim:
domain = "prim"

@staticmethod
@_beartype.beartype
def unchecked_cast(g, self):
# exists to refine the type of the Value
# if x is Optional[Tensor], unchecked_cast will cast

@ -30,11 +30,13 @@ from torch.nn.functional import (
GRID_SAMPLE_PADDING_MODES,
)
from torch.onnx import _type_utils, symbolic_helper
from torch.onnx._internal import _beartype

# note (mkozuki): Why `grid_sampler` instead of `grid_sample`?
# Because `torch.nn.functional.grid_sample` calls `torch.grid_sampler`.
@symbolic_helper.parse_args("v", "v", "i", "i", "b")
@_beartype.beartype
def grid_sampler(g, input, grid, mode_enum, padding_mode_enum, align_corners):
mode_s = {v: k for k, v in GRID_SAMPLE_INTERPOLATION_MODES.items()}[mode_enum]  # type: ignore[call-arg]
padding_mode_s = {v: k for k, v in GRID_SAMPLE_PADDING_MODES.items()}[padding_mode_enum]  # type: ignore[call-arg]

@ -49,6 +51,7 @@ def grid_sampler(g, input, grid, mode_enum, padding_mode_enum, align_corners):
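`torch.nn.functional.grid_sample` reaches this `grid_sampler` symbolic, and ONNX only gained a `GridSample` operator in opset 16, so exports need that opset or newer. An illustrative usage sketch (not from the PR):

```
import io
import torch
import torch.nn.functional as F

class Sampler(torch.nn.Module):
    def forward(self, x, grid):
        return F.grid_sample(x, grid, mode="bilinear", align_corners=False)

x = torch.randn(1, 1, 4, 4)
grid = torch.rand(1, 8, 8, 2) * 2 - 1  # normalized sampling locations in [-1, 1]
buf = io.BytesIO()
torch.onnx.export(Sampler(), (x, grid), buf, opset_version=16)
```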
@symbolic_helper.parse_args("v", "i", "v", "v")
@_beartype.beartype
def scatter_add(g, self, dim, index, src):
if symbolic_helper.is_caffe2_aten_fallback():
return g.at("scatter", self, dim, index, src, overload_name="src")
File diff suppressed because it is too large
@ -77,6 +77,7 @@ _params_dict = {}  # type: ignore[var-annotated]

@contextlib.contextmanager
@_beartype.beartype
def select_model_mode_for_export(model, mode: _C_onnx.TrainingMode):
r"""A context manager to temporarily set the training mode of ``model``
to ``mode``, resetting it when we exit the with-block.

@ -128,6 +129,7 @@ def select_model_mode_for_export(model, mode: _C_onnx.TrainingMode):

@contextlib.contextmanager
@_beartype.beartype
def disable_apex_o2_state_dict_hook(
model: Union[torch.nn.Module, torch.jit.ScriptFunction]
):

@ -161,7 +163,8 @@ def disable_apex_o2_state_dict_hook(

@contextlib.contextmanager
def setup_onnx_logging(verbose):
@_beartype.beartype
def setup_onnx_logging(verbose: bool):
is_originally_enabled = torch.onnx.is_onnx_log_enabled()
if is_originally_enabled or verbose:
torch.onnx.enable_log()

@ -173,7 +176,8 @@ def setup_onnx_logging(verbose):

@contextlib.contextmanager
def exporter_context(model, mode, verbose):
@_beartype.beartype
def exporter_context(model, mode: _C_onnx.TrainingMode, verbose: bool):
with select_model_mode_for_export(
model, mode
) as mode_ctx, disable_apex_o2_state_dict_hook(

@ -498,6 +502,7 @@ def export(
)
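Note the decorator order in `setup_onnx_logging` and `exporter_context`: `@contextlib.contextmanager` stays outermost and the type-checking decorator is applied to the generator function underneath it, so arguments are validated before the context manager object is created. A minimal sketch with the public `beartype` package standing in for the internal wrapper (illustrative only):

```
import contextlib
from beartype import beartype

@contextlib.contextmanager
@beartype
def verbose_logging(verbose: bool):
    # argument is type-checked before the with-block body runs
    print("enter, verbose =", verbose)
    try:
        yield
    finally:
        print("exit")

with verbose_logging(True):
    pass
```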
@_beartype.beartype
def _is_constant_tensor_list(node):
if node.kind() != "prim::Constant":
return False

@ -512,6 +517,7 @@ def _is_constant_tensor_list(node):
# get generated in constant prop. So we split them back into prim::ListConstructs

@_beartype.beartype
def _split_tensor_list_constants(g, block):
for node in block.nodes():
for subblock in node.blocks():

@ -534,6 +540,7 @@ def _split_tensor_list_constants(g, block):
node.output().replaceAllUsesWith(lc)

@_beartype.beartype
def _optimize_graph(
graph: _C.Graph,
operator_export_type: _C_onnx.OperatorExportTypes,

@ -655,6 +662,7 @@ def _optimize_graph(
return graph

@_beartype.beartype
def warn_on_static_input_change(input_states):
"""Warns that changes to input dictionaries and strings won't take effect in the traced ONNX graph.

@ -684,6 +692,7 @@ def warn_on_static_input_change(input_states):
warnings.warn(warning)

@_beartype.beartype
def _resolve_args_by_export_type(arg_name, arg_value, operator_export_type):
"""Resolves the arguments that are ignored when export_type != operator_export_type.ONNX."""
if (

@ -700,6 +709,7 @@ def _resolve_args_by_export_type(arg_name, arg_value, operator_export_type):
return arg_value

@_beartype.beartype
def _decide_keep_init_as_input(
keep_initializers_as_inputs: Optional[bool],
operator_export_type: _C_onnx.OperatorExportTypes,

@ -743,12 +753,14 @@ def _decide_keep_init_as_input(
return val_keep_init_as_ip

@_beartype.beartype
def _decide_add_node_names(add_node_names, operator_export_type):
return _resolve_args_by_export_type(
"add_node_names", add_node_names, operator_export_type
)

@_beartype.beartype
def _decide_constant_folding(do_constant_folding, operator_export_type, training):
do_constant_folding = _resolve_args_by_export_type(
"do_constant_folding", do_constant_folding, operator_export_type

@ -767,6 +779,7 @@ def _decide_constant_folding(do_constant_folding, operator_export_type, training
return do_constant_folding

@_beartype.beartype
def _signature(model) -> inspect.Signature:
should_be_callable = getattr(model, "forward", model)
if callable(should_be_callable):

@ -774,6 +787,7 @@ def _signature(model) -> inspect.Signature:
raise ValueError("model has no forward method and is not callable")

@_beartype.beartype
def _decide_input_format(model, args):
try:
sig = _signature(model)

@ -813,6 +827,7 @@ def _decide_input_format(model, args):
return args
@_beartype.beartype
def _trace(func, args, operator_export_type, return_outs=False):
# Special case for common case of passing a single Tensor
if isinstance(args, torch.Tensor):

@ -829,6 +844,7 @@ def _trace(func, args, operator_export_type, return_outs=False):
return trace_graph

@_beartype.beartype
def _trace_and_get_graph_from_model(model, args):
# A basic sanity check: make sure the state_dict keys are the same
# before and after running the model. Fail fast!

@ -857,6 +873,7 @@ def _trace_and_get_graph_from_model(model, args):
return trace_graph, torch_out

@_beartype.beartype
def _get_param_count_list(method_graph, args_params):
param_count_list = []
for input_, arg_params_ in zip(method_graph.inputs(), args_params):

@ -869,9 +886,11 @@ def _get_param_count_list(method_graph, args_params):
return param_count_list

@_beartype.beartype
def _check_flatten_did_not_remove(original, jit_flattened):
"""torch.jit._flatten removes None. Check if it did so in this case."""

@_beartype.beartype
def flatten(x):
if isinstance(x, (list, tuple)):
for inner in x:

@ -949,6 +968,7 @@ def _create_jit_graph(
return graph, params, torch_out, None

@_beartype.beartype
def _get_named_param_dict(graph, params):
input_and_param_names = [val.debugName() for val in graph.inputs()]
param_names = input_and_param_names[len(input_and_param_names) - len(params) :]

@ -956,6 +976,7 @@ def _get_named_param_dict(graph, params):
return _params_dict

@_beartype.beartype
def _get_example_outputs(model, args):
input_args = copy.deepcopy(args)
input_kwargs = {}

@ -980,6 +1001,7 @@ _qtype_vtype_map = {
}

@_beartype.beartype
def unpack_quantized_tensor(value, cast_onnx_accepted=True):
if isinstance(value, torch.Tensor) and value.dtype in _qtype_vtype_map:
q_value_dequantize = value.dequantize()

@ -1000,6 +1022,7 @@ def unpack_quantized_tensor(value, cast_onnx_accepted=True):
return (value,)

@_beartype.beartype
def _pre_trace_quant_model(model, args):
r"""Returns `torch.jit.trace(model, args)` if model is quantized. Otherwise do nothing and return
original model.

@ -1013,6 +1036,7 @@ def _pre_trace_quant_model(model, args):
return model

@_beartype.beartype
def _model_to_graph(
model,
args,

@ -1028,7 +1052,15 @@ def _model_to_graph(
) -> Tuple[
_C.Graph,
Dict[str, torch.Tensor],
Optional[Union[torch.Tensor, Tuple[torch.Tensor], List[torch.Tensor]]],
Optional[
Union[
torch.Tensor,
Tuple[torch.Tensor, ...],
List[torch.Tensor],
Dict[str, torch.Tensor],
Any,  # Can be nested tuples etc.
]
],
]:
"""Converts model into an ONNX graph.

@ -1133,6 +1165,7 @@ def _model_to_graph(
return graph, params_dict, torch_out
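The widened third element of `_model_to_graph`'s return annotation reflects that the traced `torch_out` mirrors whatever the model's `forward` returns, which can be an arbitrary nested container. A hedged sketch of the same union, written standalone (illustrative names only):

```
from typing import Any, Dict, List, Optional, Tuple, Union
import torch

# Mirrors the annotation added in the diff above; Any covers nested tuples etc.
TorchOut = Optional[
    Union[
        torch.Tensor,
        Tuple[torch.Tensor, ...],
        List[torch.Tensor],
        Dict[str, torch.Tensor],
        Any,
    ]
]

example: TorchOut = {"logits": torch.zeros(1, 10)}  # dict-returning models are valid
```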
@_beartype.beartype
def export_to_pretty_string(
model,
args,

@ -1209,6 +1242,7 @@ def export_to_pretty_string(
)

@_beartype.beartype
def unconvertible_ops(
model, args, training=_C_onnx.TrainingMode.EVAL, opset_version=None
):

@ -1249,6 +1283,7 @@ def unconvertible_ops(
return graph, unsupported_ops

@_beartype.beartype
def _setup_trace_module_map(
model: Union[torch.nn.Module, torch.jit.ScriptModule],
export_modules_as_functions: Union[bool, Collection[Type[torch.nn.Module]]],

@ -1326,11 +1361,13 @@ def _setup_trace_module_map(
return module_typenames

@_beartype.beartype
def _reset_trace_module_map():
torch.jit._trace._trace_module_map = None
_C._jit_pass_onnx_clear_scope_records()

@_beartype.beartype
def _get_module_attributes(module):
annotations = typing.get_type_hints(type(module))

@ -1339,6 +1376,7 @@ def _get_module_attributes(module):
return {k: getattr(module, k) for k in annotations}

@_beartype.beartype
def _export(
model,
args,

@ -1561,6 +1599,7 @@ def _export(
return torch_out

@_beartype.beartype
def _apply_friendly_debug_names(graph, params):
for n in graph.nodes():
for v in n.inputs():

@ -1573,7 +1612,9 @@ def _apply_friendly_debug_names(graph, params):
params[new_name] = params.pop(old_name)

@_beartype.beartype
def _set_input_and_output_names(graph, input_names, output_names):
@_beartype.beartype
def set_names(node_list, name_list, descriptor):
if name_list is None:
return

@ -1604,6 +1645,7 @@ def _set_input_and_output_names(graph, input_names, output_names):
set_names(list(graph.outputs()), output_names, "output")

@_beartype.beartype
def _run_symbolic_method(g, op_name, symbolic_fn, args):
r"""
This trampoline function gets invoked for every symbolic method

@ -1619,14 +1661,17 @@ def _run_symbolic_method(g, op_name, symbolic_fn, args):
raise

@_beartype.beartype
def _add_block(node: _C.Node):
return node.addBlock()  # type: ignore[attr-defined]

@_beartype.beartype
def _add_input_to_block(block: _C.Block):
return block.addInputToBlock()  # type: ignore[attr-defined]

@_beartype.beartype
def _add_output_to_block(block: _C.Block, value: _C.Value):
new_output = block.registerOutput(value)  # type: ignore[attr-defined]
return new_output

@ -1641,6 +1686,7 @@ def _add_output_to_block(block: _C.Block, value: _C.Value):
# inplace annotations, but we are losing information this way.
@_beartype.beartype
def _find_symbolic_in_registry(
domain: str,
op_name: str,

@ -1666,6 +1712,7 @@ def _find_symbolic_in_registry(
return symbolic_registry.get_registered_op(op_name, domain, opset_version)

@_beartype.beartype
def _should_aten_fallback(ns, op_name, opset_version, operator_export_type):
is_exportable_aten_op = symbolic_registry.is_registered_op(

@ -1680,6 +1727,7 @@ def _should_aten_fallback(ns, op_name, opset_version, operator_export_type):
)

@_beartype.beartype
def _need_symbolic_context(symbolic_fn) -> bool:
"""Checks if the first argument to symbolic_fn is annotated as type `torch.onnx.SymbolicContext`."""
params = tuple(inspect.signature(symbolic_fn).parameters.values())

@ -1695,6 +1743,7 @@ def _need_symbolic_context(symbolic_fn) -> bool:
return issubclass(param_type, _exporter_states.SymbolicContext)

@_beartype.beartype
def _get_aten_op_overload_name(n: _C.Node) -> str:
# Returns `overload_name` attribute to ATen ops on non-Caffe2 builds

@ -1704,6 +1753,7 @@ def _get_aten_op_overload_name(n: _C.Node) -> str:
return _C.parse_schema(schema).overload_name

@_beartype.beartype
def _run_symbolic_function(
g: _C.Graph,
block: _C.Block,

@ -1711,7 +1761,7 @@ def _run_symbolic_function(
inputs: Any,
env: Dict[_C.Value, _C.Value],
operator_export_type=_C_onnx.OperatorExportTypes.ONNX,
) -> Optional[Union[_C.Value, Tuple[_C.Value, ...]]]:
) -> Optional[Union[_C.Value, Sequence[Optional[_C.Value]]]]:
"""Runs a symbolic function.

The function is used in C++ to export the node to ONNX.

@ -1815,6 +1865,7 @@ def _run_symbolic_function(
raise
@_beartype.beartype
def get_ns_op_name_from_custom_op(symbolic_name):
if not bool(
re.match(r"^[a-zA-Z0-9-_]*::[a-zA-Z-_]+[a-zA-Z0-9-_]*$", symbolic_name)

@ -1838,6 +1889,7 @@ def get_ns_op_name_from_custom_op(symbolic_name):
return ns, op_name

@_beartype.beartype
def register_custom_op_symbolic(symbolic_name, symbolic_fn, opset_version):
"""Registers a symbolic function for a custom operator.

@ -1865,6 +1917,7 @@ def register_custom_op_symbolic(symbolic_name, symbolic_fn, opset_version):
symbolic_registry.register_op(op_name, symbolic_fn, ns, version)
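`register_custom_op_symbolic` is one of the public APIs now decorated for runtime type checking. Typical usage looks like the sketch below (illustrative; the `mydomain::my_relu` name is hypothetical and must match however the custom op was registered on the C++/TorchScript side):

```
import torch

def my_relu_symbolic(g, input):
    # map the custom op onto a plain ONNX Relu node
    return g.op("Relu", input)

torch.onnx.register_custom_op_symbolic("mydomain::my_relu", my_relu_symbolic, 9)
```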
@_beartype.beartype
def unregister_custom_op_symbolic(symbolic_name: str, opset_version: int):
"""Unregisters ``symbolic_name``.

@ -1884,6 +1937,7 @@ def unregister_custom_op_symbolic(symbolic_name: str, opset_version: int):
symbolic_registry.unregister_op(op_name, ns, version)

@_beartype.beartype
def _validate_dynamic_axes(dynamic_axes, model, input_names, output_names):
"""Ensures dynamic axes argument is follows the expected format."""
if len(dynamic_axes) == 0:

@ -13,7 +13,7 @@ import itertools
import os
import tempfile
import warnings
from typing import Any, Callable, Mapping, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, Union

import numpy as np

@ -22,10 +22,15 @@ import torch._C._onnx as _C_onnx
from torch import _C
from torch.onnx import _constants, _experimental, utils
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import _beartype
from torch.types import Number

_ORT_PROVIDERS = ("CPUExecutionProvider",)

_NumericType = Union[Number, torch.Tensor, np.ndarray]

@_beartype.beartype
def _flatten_tuples(elem):
flattened = []
for t in elem:

@ -36,7 +41,8 @@ def _flatten_tuples(elem):
return flattened

def _to_numpy(elem):
# TODO(justinchuby): Add type checking by narrowing down the return type when input is None
def _to_numpy(elem) -> Union[list, np.ndarray]:
if isinstance(elem, torch.Tensor):
if elem.requires_grad:
return elem.detach().cpu().numpy()

@ -49,12 +55,13 @@ def _to_numpy(elem):
elif isinstance(elem, dict):
flattened = []
for k in elem:
flattened += [_to_numpy(k)] + [_to_numpy(elem[k])]
flattened.extend([_to_numpy(k), _to_numpy(elem[k])])
return flattened
return elem
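The `extend` rewrite in `_to_numpy`'s dict branch is behavior-preserving: keys and values are still interleaved into one flat list. A tiny standalone sketch of that behavior (illustrative, not the real helper):

```
def flatten_dict_like(elem: dict) -> list:
    flattened = []
    for k in elem:
        flattened.extend([k, elem[k]])  # same result as += [k] + [elem[k]]
    return flattened

assert flatten_dict_like({"a": 1, "b": 2}) == ["a", 1, "b", 2]
```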
def _inline_flatten_list(inputs, res_list):
@_beartype.beartype
def _inline_flatten_list(inputs, res_list) -> list:
for i in inputs:
res_list.append(i) if not isinstance(
i, (list, tuple)

@ -62,7 +69,8 @@ def _inline_flatten_list(inputs, res_list):
return res_list

def _unpack_to_numpy(values, cast_onnx_accepted=True):
@_beartype.beartype
def _unpack_to_numpy(values, cast_onnx_accepted=True) -> list:
value_unpacked = []
for value in values:
value_unpacked.extend(

@ -71,6 +79,7 @@ def _unpack_to_numpy(values, cast_onnx_accepted=True):
return [_to_numpy(v) for v in value_unpacked]

@_beartype.beartype
def _run_ort(ort_session, inputs):
kw_inputs = {}
if inputs and isinstance(inputs[-1], dict):

@ -92,6 +101,7 @@ def _run_ort(ort_session, inputs):
return ort_outs

@_beartype.beartype
def _ort_session(
model: Union[str, io.BytesIO], ort_providers: Sequence[str] = _ORT_PROVIDERS
):

@ -115,9 +125,10 @@ def _ort_session(
return ort_session

@_beartype.beartype
def _compare_ort_pytorch_outputs(
ort_outs: Sequence[np.ndarray],
pt_outs: Sequence[torch.Tensor],
ort_outs: Union[Sequence[_NumericType], Sequence],
pt_outs: Optional[Union[_NumericType, Sequence[_NumericType], Sequence, Dict]],
rtol: float,
atol: float,
check_shape: bool,

@ -191,6 +202,7 @@ def _compare_ort_pytorch_outputs(
raise
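The loosened annotations on `_compare_ort_pytorch_outputs` acknowledge that PyTorch outputs can be scalars, tensors, nested sequences, or dicts, while ONNX Runtime returns NumPy arrays. A rough sketch of the kind of per-output comparison the helper performs (not the exact code):

```
import numpy as np
import torch

def compare_one(ort_out: np.ndarray, pt_out: torch.Tensor, rtol: float, atol: float):
    # PyTorch output converted to NumPy, then checked within rtol/atol
    np.testing.assert_allclose(
        ort_out, pt_out.detach().cpu().numpy(), rtol=rtol, atol=atol
    )

compare_one(np.ones((2, 2)), torch.ones(2, 2), rtol=1e-3, atol=1e-7)
```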
@_beartype.beartype
def _prepare_input_for_pytorch(args, kwargs):
"""Prepare input for PyTorch model execution.

@ -217,6 +229,7 @@ def _prepare_input_for_pytorch(args, kwargs):
return args, kwargs

@_beartype.beartype
def _prepare_input_for_export(args, kwargs):
"""Prepare input for ONNX model export.

@ -241,6 +254,7 @@ def _prepare_input_for_export(args, kwargs):
return onnx_inputs

@_beartype.beartype
def _prepare_input_for_ort(args, kwargs, remained_onnx_input_idx, flatten):
"""Prepare input for ONNX model execution in ONNX Runtime.

@ -266,6 +280,7 @@ def _prepare_input_for_ort(args, kwargs, remained_onnx_input_idx, flatten):
return onnx_inputs

@_beartype.beartype
def _try_clone_model(model):
"""Used for preserving original model in case forward mutates model states."""
try:

@ -277,6 +292,7 @@ def _try_clone_model(model):
return model

@_beartype.beartype
def _compare_ort_pytorch_model(
model,
ort_session,

@ -301,6 +317,7 @@ def _compare_ort_pytorch_model(
equal up to specified precision.
"""

@_beartype.beartype
def compare_ort_pytorch_model_with_input(input_args, input_kwargs):
pt_args, pt_kwargs = _prepare_input_for_pytorch(input_args, input_kwargs)
# TODO: remove this and treat mutating model separately. See #77679
@ -333,6 +350,7 @@ def _compare_ort_pytorch_model(

class _GraphDiff:
"""A class to represent the difference between two graphs."""

@_beartype.beartype
def __init__(self, graph_a: _C.Graph, graph_b: _C.Graph):
"""Construct a _GraphDiff object.

@ -343,13 +361,16 @@ class _GraphDiff:
self.graph_a = graph_a
self.graph_b = graph_b

@_beartype.beartype
def __str__(self):
"""See function :func:`diff_report`."""
return self.diff_report()

@_beartype.beartype
def _indent(self, lines: str) -> str:
return "\n".join(["\t" + line for line in lines.splitlines()])

@_beartype.beartype
def diff_report(self) -> str:
"""Return a string representation of the graph difference.

@ -399,6 +420,7 @@ class _GraphDiff:
return "\n".join(graph_diff_report)

@_beartype.beartype
def _check_graph_diff(
model: Union[torch.nn.Module, torch.jit.ScriptModule],
test_input_groups: Sequence[Tuple[Tuple[Any, ...], Mapping[str, Any]]],

@ -440,6 +462,7 @@ def _check_graph_diff(
return ""

@_beartype.beartype
def _traced_graph_from_model(
model: Union[torch.nn.Module, torch.jit.ScriptModule],
args: Tuple[Any, ...],

@ -467,6 +490,7 @@ def _traced_graph_from_model(
return jit_graph

@_beartype.beartype
def _onnx_graph_from_model(
model: Union[torch.nn.Module, torch.jit.ScriptModule],
args: Tuple[Any, ...],

@ -534,6 +558,7 @@ def _onnx_graph_from_model(
return onnx_graph

@_beartype.beartype
def check_export_model_diff(
model: Union[torch.nn.Module, torch.jit.ScriptModule],
test_input_groups: Sequence[Tuple[Tuple[Any, ...], Mapping[str, Any]]],

@ -577,9 +602,10 @@ def check_export_model_diff(
)

@_beartype.beartype
def verify(
model: Union[torch.nn.Module, torch.jit.ScriptModule],
input_args: Tuple[Any, ...],
input_args: Union[torch.Tensor, Tuple[Any, ...]],
input_kwargs: Optional[Mapping[str, Any]] = None,
do_constant_folding: bool = True,
dynamic_axes: Optional[

@ -593,7 +619,9 @@ def verify(
verbose: bool = False,
fixed_batch_size: bool = False,
use_external_data: bool = False,
additional_test_inputs: Optional[Sequence[Tuple[Any, ...]]] = None,
additional_test_inputs: Optional[
Sequence[Union[torch.Tensor, Tuple[Any, ...]]]
] = None,
remained_onnx_input_idx: Optional[Sequence[int]] = None,
flatten: bool = True,
ignore_none: bool = True,
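With the widened `input_args` annotation, a bare tensor is now a valid value rather than only a tuple of positional inputs. An illustrative call (assumes the `torch.onnx.verification` module path and that `onnxruntime` is installed):

```
import torch
from torch.onnx import verification

model = torch.nn.Linear(3, 2)
verification.verify(model, torch.randn(1, 3))  # compares PyTorch against ONNX Runtime
```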