mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Enable runtime type checking for all torch.onnx public apis, symbolic functions and most helpers (minus two that does not have a checkable type: `_.JitType` does not exist) by adding the beartype decorator. Fix type annotations to makes unit tests green. Profile: export `torchvision.models.alexnet(pretrained=True)` ``` with runtime type checking: 21.314 / 10 passes without runtime type checking: 20.797 / 10 passes + 2.48% ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/84091 Approved by: https://github.com/BowenBao, https://github.com/thiagocrepaldi
275 lines
8.5 KiB
Python
275 lines
8.5 KiB
Python
"""Importing this patches torch._C classes to add ONNX conveniences."""
|
|
import numbers
|
|
import re
|
|
from typing import Any, Iterable, Tuple, Union
|
|
|
|
import torch
|
|
from torch import _C
|
|
from torch._C import _onnx as _C_onnx
|
|
|
|
# Import utils to get _params_dict because it is a global that is accessed by c++ code
|
|
from torch.onnx import _deprecation, utils
|
|
from torch.onnx._globals import GLOBALS
|
|
from torch.onnx._internal import _beartype
|
|
|
|
# Matches attribute names of the form "<name>_<typecode>": group(1) is the
# bare attribute name and group(2) the type code, one of i, f, s, t, g, z,
# or "ty" (see _add_attribute, which dispatches on the code).
_ATTR_PATTERN = re.compile("^(.+)_(([ifstgz])|(ty))$")
|
|
|
|
|
|
# TODO(#78694): Refactor the patching process to make it more transparent to users.
|
|
@_beartype.beartype
|
|
def _graph_op(
    g: _C.Graph,
    opname: str,
    *raw_args: Union[torch.Tensor, _C.Value],
    outputs: int = 1,
    **kwargs,
) -> Union[_C.Value, Tuple[_C.Value, ...]]:
    r"""Creates an ONNX operator "opname" with inputs "raw_args" and attributes "kwargs".

    This function is monkey-patched onto Graph as ``Graph.op``. The set of
    operators and the inputs/attributes they take is documented at
    https://github.com/onnx/onnx/blob/master/docs/Operators.md

    Args:
        g: The Torch graph.
        opname: The ONNX operator name, e.g., `Abs` or `Add`, or an operator
            qualified with a namespace, e.g., `aten::add`.
        raw_args: The inputs to the operator; usually provided as arguments to
            the `symbolic` definition. Plain tensors are wrapped in Constant
            nodes; `_C.Value`s are passed through.
        outputs: The number of outputs this operator returns. By default an
            operator is assumed to return a single output. If `outputs` is
            greater than one, this function returns a tuple of output `Node`s,
            one per output, in positional order.
        kwargs: The attributes of the ONNX operator, whose keys are named
            according to the convention `alpha_f`: the `alpha` attribute with
            type `f`. Valid type specifiers are `f` (float), `i` (int),
            `s` (string) or `t` (Tensor). A list value pluralizes the
            specifier (e.g. `dims_i` for a `dims` attribute that takes a list
            of integers). None-valued attributes are silently dropped.

    Returns:
        The node representing the single output of this operator (see the
        `outputs` keyword argument for multi-return nodes).
    """
    # Dropping None attributes is a client-side convenience: callers may pass
    # optional attributes through unconditionally.
    attrs = {key: value for key, value in kwargs.items() if value is not None}

    # Lift plain tensors into Constant nodes so every input is a graph Value.
    inputs = [_const_if_tensor(g, raw) for raw in raw_args]

    if "::" in opname:
        namespace, op = opname.split("::")
    else:
        namespace, op = "onnx", opname

    node = g.insertNode(_new_node(g, namespace, op, outputs, *inputs, **attrs))

    if GLOBALS.onnx_shape_inference:
        _C._jit_pass_onnx_node_shape_type_inference(
            node, utils._params_dict, GLOBALS.export_onnx_opset_version
        )

    return node.output() if outputs == 1 else tuple(node.outputs())
|
|
|
|
|
|
@_beartype.beartype
|
|
def _const_if_tensor(g: _C.Graph, arg):
    """Wraps `arg` in a Constant node unless it is already a graph Value or None."""
    # None and existing graph Values pass through untouched.
    if arg is None or isinstance(arg, _C.Value):
        return arg
    # Anything else (e.g. a torch.Tensor) becomes a Constant carrying the
    # value in the "z"-typed attribute.
    return _graph_op(g, "Constant", value_z=arg)
|
|
|
|
|
|
# Generate an ONNX ATen op node.
|
|
@_beartype.beartype
|
|
def _aten_op(g: _C.Graph, operator: str, *args, overload_name: str = "", **kwargs):
    """Generates an ONNX ATen op node (`aten::ATen`) for the given ATen operator.

    The ATen operator name and overload are recorded as the string attributes
    `operator_s` and `overload_name_s`; remaining inputs/attributes are
    forwarded to `_graph_op` unchanged.
    """
    return _graph_op(
        g,
        "aten::ATen",
        *args,
        operator_s=operator,
        overload_name_s=overload_name,
        **kwargs,
    )
|
|
|
|
|
|
@_beartype.beartype
|
|
def _block_op(b: _C.Block, opname: str, *args: _C.Value, **kwargs):
    """Creates a node named `opname` inside block `b`.

    An unqualified `opname` is placed in the "onnx" namespace, or "aten" when
    the keyword argument `aten=True` is supplied (the `aten` key is consumed).
    Remaining kwargs become node attributes; an `inplace` kwarg is ignored.
    Returns the single output Value, or a tuple of Values for multi-output nodes.
    """
    if "::" in opname:
        # Already namespace-qualified; `aten` munging does not apply.
        aten = False
        qualified_name = opname
    else:
        aten = kwargs.pop("aten", False)
        qualified_name = f"{'aten' if aten else 'onnx'}::{opname}"
    node = b.addNode(qualified_name, args)
    for attr_name, attr_value in sorted(kwargs.items()):
        if attr_name == "inplace":
            continue
        _add_attribute(node, attr_name, attr_value, aten=aten)
    results = tuple(node.outputs())
    if len(results) == 1:
        return node.output()
    return results
|
|
|
|
|
|
@_beartype.beartype
|
|
def _new_node(
    g: _C.Graph, namespace: str, op: str, outputs: int, *args: _C.Value, **kwargs
) -> _C.Node:
    """Creates a new node in the graph.

    Args:
        g: The graph to create the operator on.
        namespace: The namespace of the operator. E.g., "aten", "onnx".
        op: The name of the operator to create.
        outputs: The number of the outputs of the node.
        args: The input Values of the node.
        kwargs: Attributes to set on the node; the special keys `aten`
            (forwarded to `_add_attribute`) and `inplace` (ignored) are not
            set as attributes.

    Returns:
        The new node.
    """
    is_aten = kwargs.pop("aten", False)
    node = g.create(f"{namespace}::{op}", args, outputs)
    # Sorted for a deterministic attribute order; `inplace` is a no-op key.
    attrs = {k: v for k, v in sorted(kwargs.items()) if k != "inplace"}
    for attr_name, attr_value in attrs.items():
        _add_attribute(node, attr_name, attr_value, aten=is_aten)
    return node
|
|
|
|
|
|
@_beartype.beartype
|
|
def _is_onnx_list(value):
|
|
return (
|
|
not isinstance(value, torch._six.string_classes)
|
|
and not isinstance(value, torch.Tensor)
|
|
and isinstance(value, Iterable)
|
|
)
|
|
|
|
|
|
@_beartype.beartype
|
|
def _scalar(x: torch.Tensor):
|
|
"""Convert a scalar tensor into a Python value."""
|
|
assert x.numel() == 1
|
|
return x[0]
|
|
|
|
|
|
@_beartype.beartype
|
|
def _is_caffe2_aten_fallback() -> bool:
|
|
return (
|
|
GLOBALS.operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
|
|
and _C_onnx._CAFFE2_ATEN_FALLBACK
|
|
)
|
|
|
|
|
|
@_beartype.beartype
|
|
def _add_attribute(node: _C.Node, key: str, value: Any, aten: bool):
    r"""Initializes the right attribute based on type of value.

    `key` must carry a type-code suffix matching `_ATTR_PATTERN`
    (e.g. "alpha_f"); the code selects which typed setter (`f_`, `i_`, ...)
    is invoked on `node`.
    """
    matched = _ATTR_PATTERN.match(key)
    if matched is None:
        raise ValueError(
            f"Invalid attribute specifier '{key}' names "
            " must be suffixed with type, e.g. 'dim_i' or 'dims_i'"
        )
    name, kind = matched.group(1), matched.group(2)
    if _is_onnx_list(value):
        # List values use the pluralized setter, e.g. "i" -> "is".
        kind += "s"

    if aten and _is_caffe2_aten_fallback():
        if isinstance(value, torch.Tensor):
            # Caffe2 proto does not support tensor attribute.
            if value.numel() > 1:
                raise ValueError("Should not pass tensor attribute")
            value = _scalar(value)
        # Force the kind down to float or int for the Caffe2 fallback path.
        kind = "f" if isinstance(value, float) else "i"
    return getattr(node, f"{kind}_")(name, value)
|
|
|
|
|
|
# TODO(#76254): Remove the deprecated function.
|
|
@_deprecation.deprecated(
|
|
"1.13", "1.14", "Use 'g.op()' to create a constant node instead."
|
|
)
|
|
@_beartype.beartype
|
|
def _graph_constant(
|
|
g,
|
|
value,
|
|
dims,
|
|
type_: str,
|
|
*args,
|
|
**kwargs,
|
|
):
|
|
"""This helper function can create either constant tensor or constant scalar.
|
|
|
|
If dims is None or 0 or [0], generate a 0-d tensor (scalar).
|
|
"""
|
|
assert isinstance(value, numbers.Number)
|
|
assert type_ is not None
|
|
isscalar = False
|
|
if dims is None or dims == 0 or set(dims) == {0}:
|
|
dims = [1]
|
|
isscalar = True
|
|
type_ = type_.lower()
|
|
tensor: Union[
|
|
torch.CharTensor,
|
|
torch.ShortTensor,
|
|
torch.IntTensor,
|
|
torch.LongTensor,
|
|
torch.HalfTensor,
|
|
torch.FloatTensor,
|
|
torch.DoubleTensor,
|
|
]
|
|
if type_ == "char":
|
|
tensor = torch.CharTensor(*dims)
|
|
elif type_ == "short":
|
|
tensor = torch.ShortTensor(*dims)
|
|
elif type_ == "int":
|
|
tensor = torch.IntTensor(*dims)
|
|
elif type_ == "long":
|
|
tensor = torch.LongTensor(*dims)
|
|
elif type_ == "half":
|
|
tensor = torch.HalfTensor(*dims)
|
|
elif type_ == "float":
|
|
tensor = torch.FloatTensor(*dims)
|
|
elif type_ == "double":
|
|
tensor = torch.DoubleTensor(*dims)
|
|
else:
|
|
raise ValueError(
|
|
"Unknown type, type should be one of the following strings: "
|
|
"char, short, int, long, half, float, double"
|
|
)
|
|
tensor.fill_(value) # type: ignore[call-overload]
|
|
if isscalar:
|
|
return g.op("Constant", *args, value_z=tensor, **kwargs)
|
|
return g.op("Constant", *args, value_t=tensor, **kwargs)
|
|
|
|
|
|
# TODO(#76254): Remove the deprecated function.
|
|
@_deprecation.deprecated(
|
|
"1.13",
|
|
"1.14",
|
|
"Internally use '_node_get' in symbolic_helper instead.",
|
|
)
|
|
def _node_getitem(self, k):
|
|
"""Gets attributes of a node which is polymorphic over return type.
|
|
|
|
This is monkey-patched onto Node.
|
|
"""
|
|
sel = self.kindOf(k)
|
|
return getattr(self, sel)(k)
|
|
|
|
|
|
# Monkey-patch the convenience methods onto the C++ JIT classes so symbolic
# functions can write e.g. `g.op(...)`, `g.at(...)`, `block.op(...)`,
# `g.constant(...)` and `node["attr"]` (see each helper's docstring).
torch._C.Graph.op = _graph_op  # type: ignore[attr-defined]
torch._C.Graph.at = _aten_op  # type: ignore[attr-defined]
torch._C.Block.op = _block_op  # type: ignore[attr-defined]
torch._C.Graph.constant = _graph_constant  # type: ignore[attr-defined]
torch._C.Node.__getitem__ = _node_getitem  # type: ignore[attr-defined, misc, assignment]
|