Revert "[ONNX] Remove beartype usage (#130484)"

This reverts commit 1794c35912025aa44b0d70f67ff664b4f7bd1014.

Reverted https://github.com/pytorch/pytorch/pull/130484 on behalf of https://github.com/clee2000 because the test_sympy_utils failure is real: https://github.com/pytorch/pytorch/actions/runs/9961499559/job/27523758780 (1794c35912). Is Dr. CI matching against commits in the current commit? ([comment](https://github.com/pytorch/pytorch/pull/130484#issuecomment-2231575577))
This commit is contained in:
PyTorch MergeBot
2024-07-16 18:41:51 +00:00
parent 09b1b113f5
commit 0851de5b16
52 changed files with 1265 additions and 107 deletions

View File

@ -10,6 +10,7 @@ retry () {
# A bunch of custom pip dependencies for ONNX
pip_install \
beartype==0.15.0 \
filelock==3.9.0 \
flatbuffers==2.0 \
mock==5.0.1 \

View File

@ -3,16 +3,18 @@ import io
import os
import onnx
from beartype import roar
import torch
from torch.onnx import dynamo_export, ExportOptions, ONNXProgram
from torch.onnx._internal import exporter
from torch.onnx._internal import exporter, io_adapter
from torch.onnx._internal.exporter import (
LargeProtobufONNXProgramSerializer,
ONNXProgramSerializer,
ProtobufONNXProgramSerializer,
ResolvedExportOptions,
)
from torch.onnx._internal.fx import diagnostics
from torch.testing._internal import common_utils
@ -47,6 +49,15 @@ class _LargeModel(torch.nn.Module):
class TestExportOptionsAPI(common_utils.TestCase):
def test_raise_on_invalid_argument_type(self):
expected_exception_type = roar.BeartypeException
with self.assertRaises(expected_exception_type):
ExportOptions(dynamic_shapes=2) # type: ignore[arg-type]
with self.assertRaises(expected_exception_type):
ExportOptions(diagnostic_options="DEBUG") # type: ignore[arg-type]
with self.assertRaises(expected_exception_type):
ResolvedExportOptions(options=12) # type: ignore[arg-type]
def test_dynamic_shapes_default(self):
options = ResolvedExportOptions(ExportOptions())
self.assertFalse(options.dynamic_shapes)
@ -109,6 +120,7 @@ class TestDynamoExportAPI(common_utils.TestCase):
# NOTE: Inheritance from `ONNXProgramSerializer` is not required,
# because `ONNXProgramSerializer` is a Protocol class and
# `beartype` will not complain.
class CustomSerializer:
def serialize(
self, onnx_program: ONNXProgram, destination: io.BufferedIOBase
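The note above hinges on structural typing; a minimal self-contained sketch (the names here are illustrative, not from this diff):

from typing import Protocol, runtime_checkable

@runtime_checkable
class _Serializer(Protocol):
    def serialize(self, onnx_program, destination) -> None: ...

class PlainSerializer:  # deliberately does not inherit from _Serializer
    def serialize(self, onnx_program, destination) -> None:
        destination.write(b"")

# Structural match: the method shape alone satisfies the Protocol.
assert isinstance(PlainSerializer(), _Serializer)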
@ -184,8 +196,27 @@ class TestDynamoExportAPI(common_utils.TestCase):
),
)
def test_raise_on_invalid_save_argument_type(self):
with self.assertRaises(roar.BeartypeException):
ONNXProgram(torch.nn.Linear(2, 3)) # type: ignore[arg-type]
onnx_program = ONNXProgram(
onnx.ModelProto(),
io_adapter.InputAdapter(),
io_adapter.OutputAdapter(),
diagnostics.DiagnosticContext("test", "1.0"),
fake_context=None,
)
with self.assertRaises(roar.BeartypeException):
onnx_program.save(None) # type: ignore[arg-type]
onnx_program.model_proto
class TestProtobufONNXProgramSerializerAPI(common_utils.TestCase):
def test_raise_on_invalid_argument_type(self):
with self.assertRaises(roar.BeartypeException):
serializer = ProtobufONNXProgramSerializer()
serializer.serialize(None, None) # type: ignore[arg-type]
def test_serialize_raises_when_model_greater_than_2gb(self):
onnx_program = torch.onnx.dynamo_export(_LargeModel(), torch.randn(1))
serializer = ProtobufONNXProgramSerializer()

View File

@ -0,0 +1,84 @@
# Owner(s): ["module: onnx"]
"""Unit tests for the internal beartype wrapper module."""
import unittest
from torch.onnx._internal import _beartype
from torch.testing._internal import common_utils
def beartype_installed():
try:
import beartype # noqa: F401
except ImportError:
return False
return True
def skip_if_beartype_not_installed(test_case):
return unittest.skipIf(not beartype_installed(), "beartype is not installed")(
test_case
)
def func_with_type_hint(x: int) -> int:
return x
def func_with_incorrect_type_hint(x: int) -> str:
return x # type: ignore[return-value]
@common_utils.instantiate_parametrized_tests
class TestBeartype(common_utils.TestCase):
def test_create_beartype_decorator_returns_no_op_decorator_when_disabled(self):
decorator = _beartype._create_beartype_decorator(
_beartype.RuntimeTypeCheckState.DISABLED,
)
decorated = decorator(func_with_incorrect_type_hint)
decorated("string_input") # type: ignore[arg-type]
@skip_if_beartype_not_installed
def test_create_beartype_decorator_warns_when_warnings(self):
decorator = _beartype._create_beartype_decorator(
_beartype.RuntimeTypeCheckState.WARNINGS,
)
decorated = decorator(func_with_incorrect_type_hint)
with self.assertWarns(_beartype.CallHintViolationWarning):
decorated("string_input") # type: ignore[arg-type]
@common_utils.parametrize("arg", [1, "string_input"])
@skip_if_beartype_not_installed
def test_create_beartype_decorator_errors_when_errors(self, arg):
import beartype
decorator = _beartype._create_beartype_decorator(
_beartype.RuntimeTypeCheckState.ERRORS,
)
decorated = decorator(func_with_incorrect_type_hint)
with self.assertRaises(beartype.roar.BeartypeCallHintViolation):
decorated(arg)
@skip_if_beartype_not_installed
def test_create_beartype_decorator_warning_calls_function_once(self):
call_count = 0
def func_with_incorrect_type_hint_and_side_effect(x: int) -> str:
nonlocal call_count
call_count += 1
return x # type: ignore[return-value]
decorator = _beartype._create_beartype_decorator(
_beartype.RuntimeTypeCheckState.WARNINGS,
)
decorated = decorator(func_with_incorrect_type_hint_and_side_effect)
decorated("string_input") # type: ignore[arg-type]
self.assertEqual(call_count, 1)
decorated(1)
# The return value violates the type hint, but the function is called
# only once.
self.assertEqual(call_count, 2)
if __name__ == "__main__":
common_utils.run_tests()

View File

@ -34,6 +34,7 @@ import pytorch_test_common
import torch
from torch import export as torch_export
from torch.onnx import _constants, verification
from torch.onnx._internal import _beartype
from torch.onnx._internal.fx import diagnostics
from torch.testing._internal import common_utils
from torch.testing._internal.opinfo import core as opinfo_core
@ -205,6 +206,7 @@ class _TestONNXRuntime(pytorch_test_common.ExportTestCase):
if not is_model_script and not self.is_script:
_run_test(model, tracing_remained_onnx_input_idx)
@_beartype.beartype
def run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
self,
model: _ModelType,
@ -358,6 +360,7 @@ class _TestONNXRuntime(pytorch_test_common.ExportTestCase):
)
@_beartype.beartype
def run_ort(
onnx_model: Union[str, torch.onnx.ONNXProgram],
pytorch_inputs: Sequence[_InputArgsType],
@ -403,6 +406,7 @@ def run_ort(
return session.run(None, ort_input)
@_beartype.beartype
def _try_clone_model(model: _ModelType) -> _ModelType:
"""Used for preserving original model in case forward mutates model states."""
try:
@ -414,12 +418,14 @@ def _try_clone_model(model: _ModelType) -> _ModelType:
return model
@_beartype.beartype
def _try_clone_inputs(input_args, input_kwargs):
ref_input_args = copy.deepcopy(input_args)
ref_input_kwargs = copy.deepcopy(input_kwargs)
return ref_input_args, ref_input_kwargs
@_beartype.beartype
def _compare_pytorch_onnx_with_ort(
onnx_program: torch.onnx.ONNXProgram,
model: _ModelType,

View File

@ -22,7 +22,7 @@ import torch.onnx
from torch import nn
from torch._subclasses import fake_tensor
from torch.onnx._internal import exporter
from torch.onnx._internal import _beartype, exporter
from torch.onnx._internal.fx import (
diagnostics,
fx_symbolic_graph_extractor,
@ -721,6 +721,7 @@ class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
CustomModule(), (torch.randn(1, 2, 3),)
)
@_beartype.beartype
def _test_fx_symbolic_tracer_large_scale_exporter(
self,
model_name: str,
@ -954,6 +955,7 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
super().setUp()
self.ort_version = onnxruntime.__version__
@_beartype.beartype
def _test_fake_tensor_mode_exporter(
self,
model_name: str,

View File

@ -35,6 +35,10 @@ Checkout topk in symbolic_opset10.py, and upsample_nearest2d in symbolic_opset8.
`_jit_pass_onnx_remove_inplace_ops_for_onnx`, and
transparently dispatched to their non-inplace versions in
"run_symbolic_function". See Note [Export inplace](#export-inplace)
- Required: Annotate new symbolic functions with type annotations and decorate
them with `@_beartype.beartype` to enable runtime type checking.
`@_beartype.beartype` should typically be the decorator closest to the function
to ensure proper type checking; see the sketch below.
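A hedged sketch of that ordering; `_log_calls` and `_scalar_example` below are hypothetical names used only for illustration:

import functools
from torch.onnx._internal import _beartype

def _log_calls(fn):  # hypothetical outer decorator, for illustration only
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)
    return wrapper

@_log_calls            # other decorators stack above...
@_beartype.beartype    # ...keeping beartype closest to the function
def _scalar_example(x: int) -> int:
    return x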
### A note on Tensor types

View File

@ -0,0 +1,132 @@
# mypy: allow-untyped-defs
"""An internal wrapper for the beartype library.
The module returns a no-op decorator when the beartype library is not installed.
"""
import enum
import functools
import os
import traceback
import typing
import warnings
from types import ModuleType
try:
import beartype as _beartype_lib # type: ignore[import]
from beartype import roar as _roar # type: ignore[import]
# Beartype warns when we import from typing because the types are deprecated
# in Python 3.9. But there will be a long time until we can move to using
# the native container types for type annotations (when 3.9 is the lowest
# supported version). So we silence the warning.
warnings.filterwarnings(
"ignore",
category=_roar.BeartypeDecorHintPep585DeprecationWarning,
)
if _beartype_lib.__version__ == "0.16.0":
# beartype 0.16.0 has a bug that causes it to crash when used with
# PyTorch. See https://github.com/beartype/beartype/issues/282
warnings.warn("beartype 0.16.0 is not supported. Please upgrade to 0.16.1+.")
_beartype_lib = None # type: ignore[assignment]
except ImportError:
_beartype_lib = None # type: ignore[assignment]
except Exception as e:
# Warn errors that are not import errors (unexpected).
warnings.warn(f"{e}")
_beartype_lib = None # type: ignore[assignment]
@enum.unique
class RuntimeTypeCheckState(enum.Enum):
"""Runtime type check state."""
# Runtime type checking is disabled.
DISABLED = enum.auto()
# Runtime type checking is enabled but warnings are shown only.
WARNINGS = enum.auto()
# Runtime type checking is enabled.
ERRORS = enum.auto()
class CallHintViolationWarning(UserWarning):
"""Warning raised when a type hint is violated during a function call."""
pass
def _no_op_decorator(func):
return func
def _create_beartype_decorator(
runtime_check_state: RuntimeTypeCheckState,
):
# beartype needs to be imported outside of the function and aliased because
# this module overwrites the name "beartype".
if runtime_check_state == RuntimeTypeCheckState.DISABLED:
return _no_op_decorator
if _beartype_lib is None:
# If the beartype library is not installed, return a no-op decorator
return _no_op_decorator
assert isinstance(_beartype_lib, ModuleType)
if runtime_check_state == RuntimeTypeCheckState.ERRORS:
# Enable runtime type checking which errors on any type hint violation.
return _beartype_lib.beartype
# Warnings only
def beartype(func):
"""Warn on type hint violation."""
if "return" in func.__annotations__:
# Remove the return type from the func function's
# annotations so that the beartype decorator does not complain
# about the return type.
return_type = func.__annotations__["return"]
del func.__annotations__["return"]
beartyped = _beartype_lib.beartype(func)
# Restore the return type to the func function's annotations
func.__annotations__["return"] = return_type
else:
beartyped = _beartype_lib.beartype(func)
@functools.wraps(func)
def _coerce_beartype_exceptions_to_warnings(*args, **kwargs):
try:
return beartyped(*args, **kwargs)
except _roar.BeartypeCallHintParamViolation:
# Fall back to the original function if the beartype hint is violated.
warnings.warn(
traceback.format_exc(),
category=CallHintViolationWarning,
stacklevel=2,
)
return func(*args, **kwargs) # noqa: B012
return _coerce_beartype_exceptions_to_warnings
return beartype
if typing.TYPE_CHECKING:
# This is a hack to make mypy play nicely with the beartype decorator.
def beartype(func):
return func
else:
_TORCH_ONNX_EXPERIMENTAL_RUNTIME_TYPE_CHECK = os.getenv(
"TORCH_ONNX_EXPERIMENTAL_RUNTIME_TYPE_CHECK"
)
if _TORCH_ONNX_EXPERIMENTAL_RUNTIME_TYPE_CHECK == "ERRORS":
_runtime_type_check_state = RuntimeTypeCheckState.ERRORS
elif _TORCH_ONNX_EXPERIMENTAL_RUNTIME_TYPE_CHECK == "DISABLED":
_runtime_type_check_state = RuntimeTypeCheckState.DISABLED
else:
_runtime_type_check_state = RuntimeTypeCheckState.WARNINGS
beartype = _create_beartype_decorator(_runtime_type_check_state)
# Make sure that the beartype decorator is enabled whichever path we took.
assert beartype is not None
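A minimal usage sketch for the module above, assuming beartype is installed. Note that the environment variable is read once at import time, so it must be set before the module is imported:

import os
os.environ["TORCH_ONNX_EXPERIMENTAL_RUNTIME_TYPE_CHECK"] = "ERRORS"

from torch.onnx._internal import _beartype  # reads the variable on import

@_beartype.beartype
def add(x: int, y: int) -> int:
    return x + y

add(1, 2)    # OK
# add("1", 2) would raise beartype.roar.BeartypeCallHintParamViolation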

View File

@ -4,7 +4,8 @@ from __future__ import annotations
import contextlib
import gzip
from typing import List, Optional, TYPE_CHECKING
from collections.abc import Generator
from typing import List, Optional
import torch
@ -13,9 +14,6 @@ from torch.onnx._internal.diagnostics.infra import formatter, sarif
from torch.onnx._internal.diagnostics.infra.sarif import version as sarif_version
from torch.utils import cpp_backtrace
if TYPE_CHECKING:
from collections.abc import Generator
def _cpp_call_stack(frames_to_skip: int = 0, frames_to_log: int = 32) -> infra.Stack:
"""Returns the current C++ call stack.
@ -26,6 +24,7 @@ def _cpp_call_stack(frames_to_skip: int = 0, frames_to_log: int = 32) -> infra.S
r"frame #[0-9]+: (?P<frame_info>.*)". More info at `c10/util/Backtrace.cpp`.
"""
# NOTE: Cannot use `@_beartype.beartype`. It somehow erases the cpp stack frame info.
frames = cpp_backtrace.get_cpp_backtrace(frames_to_skip, frames_to_log).split("\n")
frame_messages = []
for frame in frames:
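A self-contained sketch of the frame format the docstring describes; the sample frame string is illustrative:

import re

_FRAME_RE = re.compile(r"frame #[0-9]+: (?P<frame_info>.*)")
sample = "frame #3: c10::get_backtrace() + 0x2e (0x7f... in libc10.so)"
match = _FRAME_RE.match(sample)
assert match is not None
print(match.group("frame_info"))  # "c10::get_backtrace() + 0x2e (0x7f... in libc10.so)"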
@ -71,6 +70,9 @@ class TorchScriptOnnxExportDiagnostic(infra.Diagnostic):
def record_cpp_call_stack(self, frames_to_skip: int) -> infra.Stack:
"""Records the current C++ call stack in the diagnostic."""
# NOTE: Cannot use `@_beartype.beartype`. It somehow erases the cpp stack frame info.
# No need to skip this function because python frame is not recorded
# in cpp call stack.
stack = _cpp_call_stack(frames_to_skip=frames_to_skip)
stack.message = "C++ call stack"
self.with_stack(stack)
@ -198,6 +200,7 @@ def diagnose(
This is a wrapper around `context.log` that uses the global diagnostic
context.
"""
# NOTE: Cannot use `@_beartype.beartype`. It somehow erases the cpp stack frame info.
diagnostic = TorchScriptOnnxExportDiagnostic(
rule, level, message, frames_to_skip=frames_to_skip, **kwargs
)

View File

@ -6,6 +6,7 @@ import logging
import traceback
from typing import Any, Callable, Dict, Optional, Tuple, Type
from torch.onnx._internal import _beartype
from torch.onnx._internal.diagnostics import infra
from torch.onnx._internal.diagnostics.infra import formatter, utils
@ -13,10 +14,12 @@ from torch.onnx._internal.diagnostics.infra import formatter, utils
MessageFormatterType = Callable[..., str]
@_beartype.beartype
def format_message_in_text(fn: Callable, *args: Any, **kwargs: Any) -> str:
return f"{formatter.display_name(fn)}. "
@_beartype.beartype
def format_exception_in_markdown(exception: Exception) -> str:
msg_list = ["### Exception log", "```"]
msg_list.extend(
@ -26,6 +29,7 @@ def format_exception_in_markdown(exception: Exception) -> str:
return "\n".join(msg_list)
@_beartype.beartype
def format_function_signature_in_markdown(
fn: Callable,
args: Tuple[Any, ...],
@ -42,6 +46,7 @@ def format_function_signature_in_markdown(
return "\n".join(msg_list)
@_beartype.beartype
def format_return_values_in_markdown(
return_values: Any,
format_argument: Callable[[Any], str] = formatter.format_argument,
@ -54,6 +59,7 @@ ModifierCallableType = Callable[
]
@_beartype.beartype
def diagnose_call(
rule: infra.Rule,
*,

View File

@ -7,7 +7,7 @@ import traceback
from typing import Any, Callable, Dict, List, Optional, Union
from torch._logging import LazyString
from torch.onnx._internal import _beartype
from torch.onnx._internal.diagnostics.infra import sarif
@ -35,6 +35,7 @@ def lazy_format_exception(exception: Exception) -> LazyString:
)
@_beartype.beartype
def snake_case_to_camel_case(s: str) -> str:
splits = s.split("_")
if len(splits) <= 1:
@ -42,14 +43,17 @@ def snake_case_to_camel_case(s: str) -> str:
return "".join([splits[0], *map(str.capitalize, splits[1:])])
@_beartype.beartype
def camel_case_to_snake_case(s: str) -> str:
return re.sub(r"([A-Z])", r"_\1", s).lower()
@_beartype.beartype
def kebab_case_to_snake_case(s: str) -> str:
return s.replace("-", "_")
@_beartype.beartype
def _convert_key(
object: Union[Dict[str, Any], Any], convert: Callable[[str], str]
) -> Union[Dict[str, Any], Any]:
@ -88,16 +92,19 @@ def _convert_key(
return new_dict
@_beartype.beartype
def sarif_to_json(attr_cls_obj: _SarifClass, indent: Optional[str] = " ") -> str:
dict = dataclasses.asdict(attr_cls_obj)
dict = _convert_key(dict, snake_case_to_camel_case)
return json.dumps(dict, indent=indent, separators=(",", ":"))
@_beartype.beartype
def format_argument(obj: Any) -> str:
return f"{type(obj)}"
@_beartype.beartype
def display_name(fn: Callable) -> str:
if hasattr(fn, "__qualname__"):
return fn.__qualname__
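Expected behavior of the case-conversion helpers above, as a quick sketch (assuming this file is torch/onnx/_internal/diagnostics/infra/formatter.py, which the imports suggest):

from torch.onnx._internal.diagnostics.infra import formatter

assert formatter.snake_case_to_camel_case("runtime_type_check") == "runtimeTypeCheck"
assert formatter.camel_case_to_snake_case("runtimeTypeCheck") == "runtime_type_check"
assert formatter.kebab_case_to_snake_case("runtime-type-check") == "runtime_type_check"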

View File

@ -6,9 +6,11 @@ import inspect
import traceback
from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple
from torch.onnx._internal import _beartype
from torch.onnx._internal.diagnostics.infra import _infra, formatter
@_beartype.beartype
def python_frame(frame: traceback.FrameSummary) -> _infra.StackFrame:
"""Returns a StackFrame for the given traceback.FrameSummary."""
snippet = frame.line
@ -24,13 +26,14 @@ def python_frame(frame: traceback.FrameSummary) -> _infra.StackFrame:
)
@_beartype.beartype
def python_call_stack(frames_to_skip: int = 0, frames_to_log: int = 16) -> _infra.Stack:
"""Returns the current Python call stack."""
if frames_to_skip < 0:
raise ValueError("frames_to_skip must be non-negative")
if frames_to_log < 0:
raise ValueError("frames_to_log must be non-negative")
frames_to_skip += 1 # Skip this function.
frames_to_skip += 2 # Skip this function and beartype.
stack = _infra.Stack()
# Frames are returned in order of oldest to newest.
frames = traceback.extract_stack(limit=frames_to_skip + frames_to_log)
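The frame arithmetic above (skipping this function plus the beartype wrapper) follows a standard pattern; a standalone sketch with a hypothetical helper name:

import traceback

def _current_frames(frames_to_skip: int = 0, frames_to_log: int = 16):
    frames_to_skip += 1  # skip this helper itself
    # extract_stack returns frames oldest-to-newest; drop the newest
    # `frames_to_skip` frames and keep up to `frames_to_log` of the rest.
    frames = traceback.extract_stack(limit=frames_to_skip + frames_to_log)
    return frames[: len(frames) - frames_to_skip]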
@ -51,6 +54,7 @@ def _function_source_info(fn: Callable) -> Tuple[Sequence[str], int, Optional[st
return source_lines, lineno, inspect.getsourcefile(fn)
@_beartype.beartype
def function_location(fn: Callable) -> _infra.Location:
"""Returns a Location for the given function."""
source_lines, lineno, uri = _function_source_info(fn)
@ -63,6 +67,7 @@ def function_location(fn: Callable) -> _infra.Location:
)
@_beartype.beartype
def function_state(
fn: Callable, args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Mapping[str, Any]:

View File

@ -7,6 +7,7 @@ import abc
import contextlib
import dataclasses
import io
import logging
import os
@ -39,7 +40,7 @@ import torch.export as torch_export
import torch.utils._pytree as pytree
from torch._subclasses import fake_tensor
from torch.onnx._internal import io_adapter
from torch.onnx._internal import _beartype, io_adapter
from torch.onnx._internal.diagnostics import infra
from torch.onnx._internal.fx import (
decomposition_table,
@ -52,8 +53,6 @@ from torch.onnx._internal.fx import (
# 'import torch.onnx' continues to work without having 'onnx' installed. We fully
# 'import onnx' inside of dynamo_export (by way of _assert_dependencies).
if TYPE_CHECKING:
import io
import onnx
import onnxruntime # type: ignore[import]
import onnxscript # type: ignore[import]
@ -62,6 +61,15 @@ if TYPE_CHECKING:
)
from torch.onnx._internal.fx import diagnostics
else:
try:
# beartype needs this import for runtime type checking.
# It cannot normally be imported at top level due to
# https://github.com/pytorch/pytorch/issues/103764
from torch.onnx._internal.fx import diagnostics
except ImportError:
# The error will be handled elsewhere when the exporter is used.
pass
_DEFAULT_OPSET_VERSION: Final[int] = 18
"""The default ONNX opset version the exporter will use if one is not specified explicitly
@ -170,6 +178,7 @@ class OnnxRegistry:
)
self._register(internal_name_instance, symbolic_function)
@_beartype.beartype
def _register(
self,
internal_qualified_name: registration.OpName,
@ -183,9 +192,10 @@ class OnnxRegistry:
"""
self._registry[internal_qualified_name].append(symbolic_function)
@_beartype.beartype
def register_op(
self,
function: Union[onnxscript.OnnxFunction, onnxscript.TracedOnnxFunction],
function: Union["onnxscript.OnnxFunction", "onnxscript.TracedOnnxFunction"],
namespace: str,
op_name: str,
overload: Optional[str] = None,
@ -215,6 +225,7 @@ class OnnxRegistry:
)
self._register(internal_name_instance, symbolic_function)
@_beartype.beartype
def get_op_functions(
self, namespace: str, op_name: str, overload: Optional[str] = None
) -> Optional[List[registration.ONNXFunction]]:
@ -237,6 +248,7 @@ class OnnxRegistry:
)
return self._registry.get(internal_name_instance)
@_beartype.beartype
def is_registered_op(
self, namespace: str, op_name: str, overload: Optional[str] = None
) -> bool:
@ -256,6 +268,7 @@ class OnnxRegistry:
)
return functions is not None
@_beartype.beartype
def _all_registered_ops(self) -> Set[str]:
"""Returns the set of all registered function names."""
return {
@ -297,6 +310,7 @@ class ExportOptions:
onnx_registry: Optional[OnnxRegistry] = None
"""The ONNX registry used to register ATen operators to ONNX functions."""
@_beartype.beartype
def __init__(
self,
*,
@ -342,9 +356,10 @@ class ResolvedExportOptions(ExportOptions):
"""The diagnostics context for the export. Responsible for recording diagnostics,
logging diagnostics, and generating the SARIF log."""
@_beartype.beartype
def __init__(
self,
options: Union[ExportOptions, ResolvedExportOptions],
options: Union[ExportOptions, "ResolvedExportOptions"],
model: Optional[Union[torch.nn.Module, Callable, torch_export.ExportedProgram]] = None, # type: ignore[name-defined]
):
from torch.onnx._internal.fx import ( # TODO: Prevent circular dep
@ -375,6 +390,7 @@ class ResolvedExportOptions(ExportOptions):
else:
T = TypeVar("T")
@_beartype.beartype
def resolve(value: Optional[T], fallback: Union[T, Callable[[], T]]) -> T:
if value is not None:
return value
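A behavior sketch of the local `resolve` helper above; the callable-fallback branch is assumed from the signature, and this is a standalone reimplementation for illustration only:

from typing import Callable, Optional, TypeVar, Union

T = TypeVar("T")

def resolve(value: Optional[T], fallback: Union[T, Callable[[], T]]) -> T:
    if value is not None:
        return value
    return fallback() if callable(fallback) else fallback

assert resolve(3, 7) == 3                # explicit value wins
assert resolve(None, 7) == 7             # plain fallback
assert resolve(None, lambda: 7) == 7     # lazy fallback is called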
@ -392,7 +408,7 @@ class ResolvedExportOptions(ExportOptions):
else:
self.fx_tracer = dynamo_graph_extractor.DynamoExport()
self.fake_context = resolve(options.fake_context, None) # type: ignore[arg-type]
self.fake_context = resolve(options.fake_context, None)
self.diagnostic_context = diagnostics.DiagnosticContext(
"torch.onnx.dynamo_export",
torch.__version__,
@ -400,8 +416,10 @@ class ResolvedExportOptions(ExportOptions):
)
self.onnx_registry = resolve(options.onnx_registry, OnnxRegistry())
self.decomposition_table = decomposition_table.create_onnx_friendly_decomposition_table( # type: ignore[assignment]
self.onnx_registry
self.decomposition_table = (
decomposition_table.create_onnx_friendly_decomposition_table(
self.onnx_registry
)
)
from torch.onnx._internal.fx import onnxfunction_dispatcher
@ -544,6 +562,7 @@ class ONNXProgramSerializer(Protocol):
class ProtobufONNXProgramSerializer:
"""Serializes ONNX graph as Protobuf."""
@_beartype.beartype
def serialize(
self, onnx_program: ONNXProgram, destination: io.BufferedIOBase
) -> None:
@ -565,6 +584,7 @@ class LargeProtobufONNXProgramSerializer:
def __init__(self, destination_path: str):
self._destination_path = destination_path
@_beartype.beartype
def serialize(
self, onnx_program: ONNXProgram, destination: io.BufferedIOBase
) -> None:
@ -593,7 +613,7 @@ class ONNXRuntimeOptions:
execution_provider_options: ONNX Runtime execution provider options.
"""
session_options: Optional[Sequence[onnxruntime.SessionOptions]] = None
session_options: Optional[Sequence["onnxruntime.SessionOptions"]] = None
"""ONNX Runtime session options."""
execution_providers: Optional[
@ -604,10 +624,11 @@ class ONNXRuntimeOptions:
execution_provider_options: Optional[Sequence[Dict[Any, Any]]] = None
"""ONNX Runtime execution provider options."""
@_beartype.beartype
def __init__(
self,
*,
session_options: Optional[Sequence[onnxruntime.SessionOptions]] = None,
session_options: Optional[Sequence["onnxruntime.SessionOptions"]] = None,
execution_providers: Optional[
Sequence[Union[str, Tuple[str, Dict[Any, Any]]]]
] = None,
@ -642,6 +663,7 @@ class ONNXProgram:
Optional[Union[torch.nn.Module, Callable, torch_export.ExportedProgram]]
]
@_beartype.beartype
def __init__(
self,
model_proto: onnx.ModelProto, # type: ignore[name-defined]
@ -728,7 +750,7 @@ class ONNXProgram:
ort_session = onnxruntime.InferenceSession(onnx_model, providers=providers)
onnxruntime_input = {
k.name: v.numpy(force=True) # type: ignore[union-attr]
k.name: v.numpy(force=True)
for k, v in zip(ort_session.get_inputs(), onnx_input)
}
@ -849,6 +871,7 @@ class ONNXProgram:
return self._fake_context
@_beartype.beartype
def adapt_torch_inputs_to_onnx(
self,
*model_args,
@ -917,10 +940,11 @@ class ONNXProgram:
assert (
model_with_state_dict is not None
), "model_with_state_dict must be specified."
return self._input_adapter.apply( # type: ignore[return-value]
return self._input_adapter.apply(
*model_args, model=model_with_state_dict, **model_kwargs
)
@_beartype.beartype
def adapt_torch_outputs_to_onnx(
self,
model_outputs: Any,
@ -979,8 +1003,9 @@ class ONNXProgram:
assert (
model_with_state_dict is not None
), "model_with_state_dict must be specified."
return self._output_adapter.apply(model_outputs, model=model_with_state_dict) # type: ignore[return-value]
return self._output_adapter.apply(model_outputs, model=model_with_state_dict)
@_beartype.beartype
def save(
self,
destination: Union[str, io.BufferedIOBase],
@ -1069,6 +1094,7 @@ class ONNXProgram:
"External tensor data will be saved alongside the model on disk."
) from exc
@_beartype.beartype
def save_diagnostics(self, destination: str) -> None:
"""Saves the export diagnostics as a SARIF log to the specified destination path.
@ -1110,7 +1136,8 @@ class ONNXProgram:
# https://github.com/pytorch/pytorch/issues/103764
import onnx
return cls(
# TODO: Should we populate ONNXProgram with more info, such _model_torch for easier debug?
return ONNXProgram(
onnx.ModelProto(), # type: ignore[attr-defined]
io_adapter.InputAdapter(),
io_adapter.OutputAdapter(),
@ -1168,6 +1195,7 @@ class FXGraphExtractor(abc.ABC):
class Exporter:
@_beartype.beartype
def __init__(
self,
options: ResolvedExportOptions,
@ -1345,6 +1373,7 @@ class InvalidExportOptionsError(RuntimeError):
pass
@_beartype.beartype
def _assert_dependencies(export_options: ResolvedExportOptions):
opset_version = export_options.onnx_registry.opset_version
@ -1386,6 +1415,7 @@ def _assert_dependencies(export_options: ResolvedExportOptions):
raise missing_opset("onnxscript")
@_beartype.beartype
def dynamo_export(
model: Union[torch.nn.Module, Callable, torch_export.ExportedProgram], # type: ignore[name-defined]
/,
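A minimal usage sketch of the public entry point shown above (the file path is illustrative):

import torch

class M(torch.nn.Module):
    def forward(self, x):
        return x + 1

onnx_program = torch.onnx.dynamo_export(M(), torch.randn(2))
onnx_program.save("model.onnx")  # str destination; see ONNXProgram.save above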

View File

@ -17,7 +17,7 @@ import torch
import torch.fx
from torch._subclasses import fake_tensor
from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode
from torch.onnx._internal import _beartype
from torch.onnx._internal.fx import diagnostics, onnxfunction_dispatcher
@ -128,6 +128,7 @@ def _unified_diff(a: str, b: str) -> str:
return diff
@_beartype.beartype
def _transform_diagnose_call_message_formatter(
run: Callable,
self: Transform,
@ -311,6 +312,7 @@ class AnalysisResult(abc.ABC): # noqa: B024
class Analysis(abc.ABC):
@_beartype.beartype
def __init__(
self,
diagnostic_context: diagnostics.DiagnosticContext,

View File

@ -9,9 +9,14 @@ import torch
import torch._ops
import torch.fx
from torch.onnx._internal import _beartype
from torch.onnx._internal.fx import registration
# NOTE: OnnxRegistry annotation: beartype checks types at runtime,
# so it cannot resolve names imported only under TYPE_CHECKING
@_beartype.beartype
def _create_onnx_supports_op_overload_table(
registry,
) -> Set[Union[torch._ops.OperatorBase, Callable]]:
@ -71,6 +76,7 @@ def _create_onnx_supports_op_overload_table(
return table
@_beartype.beartype
def create_onnx_friendly_decomposition_table(
registry,
) -> Dict[torch._ops.OperatorBase, Callable]:

View File

@ -26,7 +26,7 @@ import torch._dynamo
import torch.export as torch_export
import torch.fx
import torch.onnx
from torch.onnx._internal import exporter, io_adapter
from torch.onnx._internal import _beartype, exporter, io_adapter
from torch.utils import _pytree as pytree
@ -53,6 +53,7 @@ class _PyTreeExtensionContext:
for class_type in self._extensions:
pytree.SUPPORTED_NODES.pop(class_type)
@_beartype.beartype
def register_pytree_node(
self,
class_type: Type,
@ -82,11 +83,13 @@ class _PyTreeExtensionContext:
except ImportError as e:
return
@_beartype.beartype
def model_output_flatten(
output: modeling_outputs.ModelOutput,
) -> Tuple[List[Any], pytree.Context]:
return list(output.values()), (type(output), list(output.keys()))
@_beartype.beartype
def model_output_unflatten(
values: List[Any], context: pytree.Context
) -> modeling_outputs.ModelOutput:
@ -105,7 +108,7 @@ class _PyTreeExtensionContext:
for _, class_type in named_model_output_classes:
self.register_pytree_node(
class_type, model_output_flatten, model_output_unflatten # type: ignore[arg-type ]
class_type, model_output_flatten, model_output_unflatten
)
@ -229,6 +232,7 @@ class DynamoExport(exporter.FXGraphExtractor):
return self.pre_export_passes(options, model, graph_module, updated_model_args) # type: ignore[return-value]
@_beartype.beartype
def pre_export_passes(
self,
options: exporter.ResolvedExportOptions,

View File

@ -16,7 +16,7 @@ from onnxscript.function_libs.torch_lib import ( # type: ignore[import]
import torch
import torch.fx
from torch.onnx import _type_utils as jit_type_utils
from torch.onnx._internal import _beartype
from torch.onnx._internal.fx import (
_pass,
diagnostics,
@ -27,6 +27,7 @@ from torch.onnx._internal.fx import (
from torch.utils import _pytree
@_beartype.beartype
def _fx_node_to_onnx_message_formatter(
fn: Callable,
self,
@ -37,6 +38,7 @@ def _fx_node_to_onnx_message_formatter(
return f"FX Node: {node.op}:{node.target}[name={node.name}]. "
@_beartype.beartype
def _fx_graph_to_onnx_message_formatter(
fn: Callable,
self,
@ -85,6 +87,7 @@ def _location_from_fx_stack_trace(
return None
@_beartype.beartype
def _retrieve_or_adapt_input_to_graph_set(
fx_node_arg: fx_type_utils.Argument,
fx_name_to_onnxscript_value: Dict[
@ -194,7 +197,7 @@ def _retrieve_or_adapt_input_to_graph_set(
)
return sequence_elements
if isinstance(onnx_tensor, torch.dtype):
onnx_tensor = int( # type: ignore[call-overload]
onnx_tensor = int(
jit_type_utils.JitScalarType.from_dtype(onnx_tensor).onnx_type()
)
# NOTE: if device is specified in kwargs (not consumed), it's free to ignored. But
@ -226,11 +229,12 @@ def filter_incompatible_and_dtype_convert_kwargs(kwargs):
# default case.
continue
else:
value = int(jit_type_utils.JitScalarType.from_dtype(value).onnx_type()) # type: ignore[call-overload]
value = int(jit_type_utils.JitScalarType.from_dtype(value).onnx_type())
filtered[key] = value
return filtered
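A quick sketch of the dtype-to-ONNX conversion used in both hunks above:

import torch
from torch.onnx import _type_utils as jit_type_utils

onnx_type = int(jit_type_utils.JitScalarType.from_dtype(torch.float32).onnx_type())
print(onnx_type)  # 1, i.e. TensorProto.FLOAT in ONNX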
@_beartype.beartype
def _fill_tensor_shape_type(
onnxscript_values: Union[
onnxscript_graph_building.TorchScriptTensor,
@ -309,6 +313,7 @@ def _fill_tensor_shape_type(
onnxscript_value.name = name
@_beartype.beartype
def _fill_in_default_kwargs(
node: torch.fx.Node,
) -> Tuple[List[fx_type_utils.Argument], Dict[str, fx_type_utils.Argument]]:
@ -343,6 +348,7 @@ def _fill_in_default_kwargs(
return complete_args, complete_kwargs
@_beartype.beartype
def _wrap_fx_args_as_onnxscript_args(
complete_args: List[fx_type_utils.Argument],
complete_kwargs: Dict[str, fx_type_utils.Argument],
@ -405,6 +411,7 @@ class FxOnnxInterpreter:
# DO NOT add other class-level attributes.
self.diagnostic_context = diagnostic_context
@_beartype.beartype
@diagnostics.diagnose_call(
diagnostics.rules.fx_node_to_onnx,
diagnostic_message_formatter=_fx_node_to_onnx_message_formatter,
@ -486,6 +493,7 @@ class FxOnnxInterpreter:
else:
raise RuntimeError(f"Found node type not defined in torch.fx: {node.op}")
@_beartype.beartype
@diagnostics.diagnose_call(
diagnostics.rules.fx_graph_to_onnx,
diagnostic_message_formatter=_fx_graph_to_onnx_message_formatter,
@ -581,6 +589,7 @@ class FxOnnxInterpreter:
return onnxscript_graph
@_beartype.beartype
def placeholder(
self,
node: torch.fx.Node,
@ -636,6 +645,7 @@ class FxOnnxInterpreter:
fx_name_to_onnxscript_value[node.name] = output
@_beartype.beartype
def call_function(
self,
node: torch.fx.Node,
@ -720,6 +730,7 @@ class FxOnnxInterpreter:
)
fx_name_to_onnxscript_value[node.name] = output
@_beartype.beartype
def output(
self,
node: torch.fx.Node,
@ -746,10 +757,12 @@ class FxOnnxInterpreter:
onnx_tensor_or_tensor_tuple = fx_name_to_onnxscript_value[arg.name]
onnxscript_graph.register_outputs(onnx_tensor_or_tensor_tuple)
@_beartype.beartype
def call_method(self, node: torch.fx.Node):
# TODO(wechi): Support call_method.
raise RuntimeError("call_method is not supported yet.")
@_beartype.beartype
def call_module(
self,
node: torch.fx.Node,
@ -827,6 +840,7 @@ class FxOnnxInterpreter:
# Skip op_level_validation for call_module. Subgraph nodes are validated individually.
@_beartype.beartype
def get_attr(
self,
node: torch.fx.Node,

View File

@ -10,7 +10,7 @@ import torch.fx
import torch.onnx
import torch.onnx._internal.fx.passes as passes
from torch.onnx._internal import exporter, io_adapter
from torch.onnx._internal import _beartype, exporter, io_adapter
# Functions directly wrapped to produce torch.fx.Proxy so that symbolic
# data can flow through those functions. Python functions (e.g., `torch.arange`)
@ -37,6 +37,7 @@ class ModuleExpansionTracer(torch.fx._symbolic_trace.Tracer):
exporter.
"""
@_beartype.beartype
def is_leaf_module(
self, module: torch.nn.Module, module_qualified_name: str
) -> bool:
@ -45,6 +46,7 @@ class ModuleExpansionTracer(torch.fx._symbolic_trace.Tracer):
# torch.fx._symbolic_trace.Tracer.call_module.
return False
@_beartype.beartype
def to_bool(self, obj: torch.fx.Proxy) -> bool:
# FIXME: This is a hack to tracing through if-else Python blocks.
# It may generate incorrect ONNX graphs if the if-else block
@ -80,6 +82,7 @@ def _wrap_for_symbolic_trace(target: Callable) -> Tuple[Callable, Callable]:
return wrapper, target
@_beartype.beartype
def _module_expansion_symbolic_trace(
root: Union[torch.nn.Module, Callable[..., Any]],
concrete_args: Optional[Dict[str, Any]] = None,
@ -157,6 +160,7 @@ class FXSymbolicTracer(exporter.FXGraphExtractor):
# TODO: plumb ``concrete_args`` to symbolic_trace call at ``generate_fx``
self.concrete_args = concrete_args
@_beartype.beartype
def _trace_into_fx_graph_via_fx_symbolic_trace(
self, model, model_args, model_kwargs
) -> torch.fx.GraphModule:
@ -234,6 +238,7 @@ class FXSymbolicTracer(exporter.FXGraphExtractor):
return self.pre_export_passes(options, model, graph_module, updated_model_args) # type: ignore[return-value]
@_beartype.beartype
def pre_export_passes(
self,
options: exporter.ResolvedExportOptions,

View File

@ -22,7 +22,7 @@ from typing import (
import torch
import torch._ops
import torch.fx
from torch.onnx._internal import _beartype
from torch.onnx._internal.fx import (
diagnostics,
registration,
@ -31,13 +31,17 @@ from torch.onnx._internal.fx import (
if TYPE_CHECKING:
import onnxscript # type: ignore[import]
from onnxscript.function_libs.torch_lib import ( # type: ignore[import]
graph_building as onnxscript_graph_building,
)
from torch.onnx import OnnxRegistry
# For beartype
from onnxscript.function_libs.torch_lib import ( # type: ignore[import]
graph_building as onnxscript_graph_building,
)
@_beartype.beartype
def _find_opschema_matched_symbolic_function_disagnostic_message_formatter(
fn: Callable,
self,
@ -54,6 +58,7 @@ def _find_opschema_matched_symbolic_function_disagnostic_message_formatter(
return f"FX Node: {node.target}. \n" f"{all_function_overload_names}"
@_beartype.beartype
def _find_operator_overloads_in_onnx_registry_disagnostic_message_formatter(
fn: Callable,
self,
@ -92,7 +97,7 @@ class OnnxFunctionDispatcher:
def __init__(
self,
onnx_registry: OnnxRegistry,
onnx_registry: "OnnxRegistry",
diagnostic_context: diagnostics.DiagnosticContext,
):
"""Initialize the ONNX Function dispatcher.
@ -104,6 +109,7 @@ class OnnxFunctionDispatcher:
self.onnx_registry = onnx_registry
self.diagnostic_context = diagnostic_context
@_beartype.beartype
def dispatch(
self,
node: torch.fx.Node,
@ -114,7 +120,7 @@ class OnnxFunctionDispatcher:
],
onnx_kwargs: Dict[str, fx_type_utils.Argument],
diagnostic_context: diagnostics.DiagnosticContext,
) -> Union[onnxscript.OnnxFunction, onnxscript.TracedOnnxFunction]:
) -> Union["onnxscript.OnnxFunction", "onnxscript.TracedOnnxFunction"]:
"""Dispatches an ONNX function based on the given FX node, arguments, and keyword arguments.
Args:
node: The TorchFX node to dispatch the function for.
@ -142,6 +148,7 @@ class OnnxFunctionDispatcher:
diagnostic_context,
)
@_beartype.beartype
def _filter_or_keep_complex(
self,
node,
@ -189,6 +196,7 @@ class OnnxFunctionDispatcher:
raise diagnostics.RuntimeErrorWithDiagnostic(diagnostic)
return default_and_custom_functions
@_beartype.beartype
@diagnostics.diagnose_call(
diagnostics.rules.find_opschema_matched_symbolic_function,
diagnostic_message_formatter=_find_opschema_matched_symbolic_function_disagnostic_message_formatter,
@ -276,6 +284,7 @@ class OnnxFunctionDispatcher:
)
return symbolic_function_list[0].onnx_function
@_beartype.beartype
def _get_aten_name(
self, node: torch.fx.Node, diagnostic_context: diagnostics.DiagnosticContext
) -> registration.OpName:
@ -341,6 +350,7 @@ class OnnxFunctionDispatcher:
diagnostic_context.log(diagnostic)
raise diagnostics.RuntimeErrorWithDiagnostic(diagnostic)
@_beartype.beartype
@diagnostics.diagnose_call(
diagnostics.rules.find_operator_overloads_in_onnx_registry,
diagnostic_message_formatter=_find_operator_overloads_in_onnx_registry_disagnostic_message_formatter,
@ -542,6 +552,7 @@ class _OnnxSchemaChecker:
"""
return self._matching_score
@_beartype.beartype
def perfect_match_inputs(
self,
diagnostic: diagnostics.Diagnostic,
@ -676,6 +687,7 @@ class _OnnxSchemaChecker:
diagnostic.info("match score: %d", self.match_score)
return is_perfect_match
@_beartype.beartype
def _match_onnx_attribute_type(
self,
attribute_name: str,
@ -701,6 +713,7 @@ class _OnnxSchemaChecker:
return False
return True
@_beartype.beartype
def _record_matching_score(
self,
inputs: Sequence[
@ -755,10 +768,10 @@ class _OnnxSchemaChecker:
# NOTE: Referenced from onnxscript internal function.
# Importing this function makes the code less robust, as it is not a public API.
@_beartype.beartype
def _separate_input_attributes_from_arguments(
self,
param_schemas: Sequence[onnxscript.values.ParamSchema],
param_schemas: Sequence["onnxscript.values.ParamSchema"],
args: Sequence[
Optional[
Union[fx_type_utils.TensorLike, str, int, float, bool, list, complex]
@ -838,6 +851,7 @@ class _OnnxSchemaChecker:
return onnx_inputs, onnx_attributes
@_beartype.beartype
def _is_arg_with_complex_dtype(arg: fx_type_utils.Argument) -> bool:
"""Check if the node has complex dtype recursively."""
if (
@ -853,6 +867,7 @@ def _is_arg_with_complex_dtype(arg: fx_type_utils.Argument) -> bool:
return False
@_beartype.beartype
def _find_onnx_data_type(
torch_input: Optional[
Union[fx_type_utils.TensorLike, str, int, float, bool, list, tuple, complex]

View File

@ -13,7 +13,7 @@ import torch
import torch.fx
from torch.fx.experimental import symbolic_shapes
from torch.onnx import _constants, _type_utils as jit_type_utils
from torch.onnx._internal import _beartype
from torch.onnx._internal.fx import (
diagnostics,
fx_onnx_interpreter,
@ -22,6 +22,7 @@ from torch.onnx._internal.fx import (
from torch.utils import _pytree
@_beartype.beartype
def _op_level_debug_message_formatter(
fn: Callable,
self,
@ -36,6 +37,7 @@ def _op_level_debug_message_formatter(
)
@_beartype.beartype
@diagnostics.diagnose_call(
diagnostics.rules.op_level_debugging,
diagnostic_message_formatter=_op_level_debug_message_formatter,
@ -172,6 +174,7 @@ def validate_op_between_ort_torch(
diagnostic.level = diagnostics.levels.WARNING
@_beartype.beartype
def _convert_symint_to_int_in_shape(shape: torch.Size) -> torch.Size:
"""Convert SymInt to int in shape
@ -198,6 +201,7 @@ def _convert_symint_to_int_in_shape(shape: torch.Size) -> torch.Size:
return torch.Size(list_int_shape)
@_beartype.beartype
def generate_random_tensors(shape: torch.Size, dtype: torch.dtype):
shape = _convert_symint_to_int_in_shape(shape)
@ -234,6 +238,7 @@ def generate_random_tensors(shape: torch.Size, dtype: torch.dtype):
return torch.randn(shape, dtype=dtype)
@_beartype.beartype
def _fx_args_to_torch_args(
fx_args: List[fx_type_utils.Argument], fx_graph_module: torch.fx.GraphModule
) -> List[fx_type_utils.Argument]:
@ -263,7 +268,7 @@ def _fx_args_to_torch_args(
f"{type(fake_tensor)}."
)
elif isinstance(arg, Sequence):
wrapped_args.append(_fx_args_to_torch_args(arg, fx_graph_module)) # type: ignore[arg-type]
wrapped_args.append(_fx_args_to_torch_args(arg, fx_graph_module))
elif isinstance(arg, (int, float, torch.dtype)) or arg is None:
wrapped_args.append(arg)
elif isinstance(arg, torch.device):
@ -276,6 +281,7 @@ def _fx_args_to_torch_args(
return wrapped_args
@_beartype.beartype
def _wrap_fx_args_as_torch_args(
fx_args: List[fx_type_utils.Argument],
fx_kwargs: Dict[str, fx_type_utils.Argument],
@ -291,6 +297,7 @@ def _wrap_fx_args_as_torch_args(
# NOTE: Referenced from onnxscript internal function: _tag_arguments_with_param_schemas.
@_beartype.beartype
def _convert_torch_args_to_onnxfunction_args(
param_schemas: Sequence[onnxscript.values.ParamSchema],
args: List[fx_type_utils.Argument],
@ -351,6 +358,7 @@ def _convert_torch_args_to_onnxfunction_args(
return tagged_args, tagged_kwargs
@_beartype.beartype
def _convert_tensor_to_numpy(input: fx_type_utils.Argument) -> Any:
try:
import numpy as np
@ -365,7 +373,7 @@ def _convert_tensor_to_numpy(input: fx_type_utils.Argument) -> Any:
input = torch.view_as_real(input.resolve_conj())
return input.detach().cpu().numpy()
if isinstance(input, torch.dtype):
return int(jit_type_utils.JitScalarType.from_dtype(input).onnx_type()) # type: ignore[union-attr,call-overload]
return int(jit_type_utils.JitScalarType.from_dtype(input).onnx_type()) # type: ignore[union-attr]
if isinstance(input, (tuple, list)):
if len(input) == 0:
return np.array((), dtype=np.int64)

View File

@ -13,8 +13,10 @@ from typing import Callable, Dict, Optional, Tuple
import torch.fx
import torch.fx.traceback as fx_traceback
from torch.onnx._internal import _beartype
@_beartype.beartype
def wrap_graph_module_for_node_meta_preservation(
graph_module: torch.fx.GraphModule,
) -> Callable:
@ -40,6 +42,7 @@ def _get_node_base_name(node_name: str) -> Tuple[str, Optional[int]]:
return node_name, None
@_beartype.beartype
def set_node_name(
node: torch.fx.Node,
new_name: str,
@ -78,6 +81,7 @@ def set_node_name(
name_to_node_cache[new_name] = node
@_beartype.beartype
def replace_placeholder_name_and_target(
module: torch.fx.GraphModule, reference_module: torch.fx.GraphModule
):

View File

@ -11,6 +11,7 @@ import torch.fx
from torch._dispatch import python as python_dispatch
from torch._subclasses import fake_tensor
from torch.fx.experimental import proxy_tensor
from torch.onnx._internal import _beartype
from torch.onnx._internal.fx import _pass, diagnostics
from torch.onnx._internal.fx.passes import _utils
@ -29,6 +30,7 @@ class Decompose(_pass.Transform):
self.enable_dynamic_axes = enable_dynamic_axes
self.allow_fake_constant = allow_fake_constant
@_beartype.beartype
def _run(self, *args, **kwargs) -> torch.fx.GraphModule:
assert not kwargs, "kwargs is not supported in Decompose."

View File

@ -11,6 +11,7 @@ import torch.func
import torch.fx
from torch._subclasses import fake_tensor
from torch.fx.experimental import proxy_tensor
from torch.onnx._internal import _beartype
from torch.onnx._internal.fx import _pass, diagnostics
from torch.onnx._internal.fx.passes import _utils
from torch.utils import _pytree as pytree
@ -61,6 +62,7 @@ class Functionalize(_pass.Transform):
which are not needed for ONNX inference.
"""
@_beartype.beartype
def __init__(
self,
diagnostic_context: diagnostics.DiagnosticContext,
@ -97,6 +99,7 @@ class Functionalize(_pass.Transform):
return wrapped
@_beartype.beartype
def _run(self, *args) -> torch.fx.GraphModule:
# To preserve stack trace info after `make_fx`.
module = _utils.wrap_graph_module_for_node_meta_preservation(self.module)
@ -142,6 +145,7 @@ class RemoveInputMutation(_pass.Transform):
for inference. They could be useful for training.
"""
@_beartype.beartype
def _run(self, *args) -> torch.fx.GraphModule:
for node in reversed(self.module.graph.nodes):
if (

View File

@ -23,6 +23,7 @@ from typing import (
import torch
import torch.fx
from torch.onnx._internal import _beartype
from torch.onnx._internal.fx import _pass, diagnostics
from torch.utils import _pytree as pytree
@ -58,6 +59,7 @@ class _ModuleMeta:
_module_name: Final[str] # type: ignore[misc]
_raw_meta: Final[Tuple[Any, Any]] # type: ignore[misc]
@_beartype.beartype
def __init__(
self,
module_name: str,
@ -211,6 +213,7 @@ class _ModuleStackMeta:
_module_stack: Final[List[_ModuleMeta]] # type: ignore[misc]
@_beartype.beartype
def __init__(
self,
nn_module_stack_meta: Optional[
@ -230,7 +233,7 @@ class _ModuleStackMeta:
if is_exported_program:
is_exported_program = False
continue
self.push(_ModuleMeta.from_raw_meta(item)) # type: ignore[arg-type]
self.push(_ModuleMeta.from_raw_meta(item))
def __len__(self) -> int:
return len(self._module_stack)
@ -259,6 +262,7 @@ class _ModuleStackMeta:
return _ModuleMeta.create_root()
return self._module_stack[-1]
@_beartype.beartype
def is_superset_of(
self,
module_stack: _ModuleStackMeta,
@ -302,6 +306,7 @@ class _ModuleStackMeta:
"""Pushes a module meta to the stack."""
self._module_stack.append(module_meta)
@_beartype.beartype
def __eq__(self, __value: object) -> bool:
if not isinstance(__value, _ModuleStackMeta):
return False
@ -838,6 +843,7 @@ class Modularize(_pass.Transform):
"""
@_beartype.beartype
def __init__(
self,
diagnostic_context: diagnostics.DiagnosticContext,
@ -848,6 +854,7 @@ class Modularize(_pass.Transform):
self.module = module
self.is_exported_program = is_exported_program
@_beartype.beartype
def _run(self) -> torch.fx.GraphModule:
# DCE to remove unused nodes.
# If a submodule is unused, it is hard to analyze which nodes constitutes the submodule

View File

@ -4,7 +4,7 @@ from __future__ import annotations
from typing import Dict, List, Sequence, Tuple, Union
import torch
from torch.onnx._internal import _beartype
from torch.onnx._internal.fx import _pass, diagnostics
@ -30,6 +30,7 @@ class RestoreParameterAndBufferNames(_pass.Transform):
super().__init__(diagnostic_context, fx_module)
self.original_nn_module = original_nn_module
@_beartype.beartype
def _rename_param_and_buffer(
self,
diagnostic: diagnostics.Diagnostic,

View File

@ -27,6 +27,10 @@ from torch._refs.nn import functional as _functional_refs
from torch._subclasses import fake_tensor
from torch.fx.experimental import proxy_tensor
# Imported to resolve beartype issue when type checking node.Argument.
from torch.fx.node import Node # noqa: F401
from torch.onnx._internal import _beartype
from torch.onnx._internal.fx import _pass, diagnostics, type_utils as fx_type_utils
from torch.utils import _python_dispatch, _pytree
@ -1225,6 +1229,7 @@ class TypePromotionTable:
for rule in _EXTRA_TYPE_PROMOTION_RULE_SET:
self.add_rule(rule)
@_beartype.beartype
def add_rule(self, rule: TypePromotionRule) -> None:
"""Add a type promotion rule for a python op in a torch.ops module.
@ -1239,6 +1244,7 @@ class TypePromotionTable:
raise ValueError(f"Invalid type promotion rule: {rule}")
self._rule_table[f"{rule.namespace}.{rule.op_name}"] = rule
@_beartype.beartype
def get_rule(
self, py_op: torch._ops.OpOverloadPacket
) -> Optional[TypePromotionRule]:
@ -1246,6 +1252,7 @@ class TypePromotionTable:
return self._rule_table.get(str(py_op), None)
@_beartype.beartype
def get_type_promotion_rule(
diagnostic: diagnostics.Diagnostic,
node: torch.fx.Node,
@ -1298,6 +1305,7 @@ class _OpTraceDispatchMode(_python_dispatch.TorchDispatchMode):
return func(*args, **kwargs)
@_beartype.beartype
def find_compatible_op_overload(
op: torch._ops.OpOverloadPacket, args: tuple, kwargs: dict
) -> torch._ops.OpOverload:
@ -1383,6 +1391,7 @@ class _TypePromotionInterpreter(torch.fx.Interpreter):
node.meta["val"] = proxy_tensor.extract_val(out)
return out
@_beartype.beartype
def _create_node(
self,
graph: torch.fx.Graph,
@ -1404,6 +1413,7 @@ class _TypePromotionInterpreter(torch.fx.Interpreter):
self._run_node_and_set_meta(node)
return node
@_beartype.beartype
def _rerun_node_after_type_promotion(
self,
diagnostic: diagnostics.Diagnostic,
@ -1469,6 +1479,7 @@ class _TypePromotionInterpreter(torch.fx.Interpreter):
else:
raise RuntimeError(f"Unexpected node output type: {type(node_val)}.")
@_beartype.beartype
def _maybe_promote_arg(
self,
diagnostic: diagnostics.Diagnostic,
@ -1574,6 +1585,7 @@ class _TypePromotionInterpreter(torch.fx.Interpreter):
raise NotImplementedError(f"Unknown fx arg type: {type(fx_arg)}")
@_beartype.beartype
def _maybe_promote_node(
self,
diagnostic: diagnostics.Diagnostic,
@ -1708,6 +1720,7 @@ class InsertTypePromotion(_pass.Transform):
fake_args.append(meta_value)
return fake_args
@_beartype.beartype
def _run(self, *args, **kwargs) -> torch.fx.GraphModule:
assert not args, (
"`InsertTypePromotion` deduces symbolic fake arguments from the graph. "

View File

@ -6,6 +6,7 @@ from typing import List, Optional, Tuple
import torch
import torch.fx
from torch.onnx._internal import _beartype
from torch.onnx._internal.fx import _pass
@ -17,6 +18,7 @@ class MovePlaceholderToFront(_pass.Transform):
nodes.
"""
@_beartype.beartype
def _run(self, *args, **kwargs) -> torch.fx.GraphModule:
graph_module = self.module
graph = graph_module.graph
@ -51,6 +53,7 @@ class ReplaceGetAttrWithPlaceholder(_pass.Transform):
), "Must run ReplaceGetAttrWithPlaceholder first"
return self._replaced_attrs
@_beartype.beartype
def _run(self, *args, **kwargs) -> torch.fx.GraphModule:
graph_module = self.module
graph = graph_module.graph

View File

@ -1,13 +1,11 @@
# mypy: allow-untyped-defs
import copy
import functools
from typing import List, TYPE_CHECKING, Union
import io
from typing import List, Union
import torch
if TYPE_CHECKING:
import io
# TODO: Remove after https://github.com/huggingface/safetensors/pull/318
@functools.lru_cache(None)

View File

@ -7,6 +7,7 @@ import types
from typing import Optional, TYPE_CHECKING, Union
import torch._ops
from torch.onnx._internal import _beartype
# We can only import onnx from this module in a type-checking context to ensure that
# 'import torch.onnx' continues to work without having 'onnx' installed. We fully
@ -26,7 +27,7 @@ class ONNXFunction:
"""
onnx_function: Union[onnxscript.OnnxFunction, onnxscript.TracedOnnxFunction]
onnx_function: Union["onnxscript.OnnxFunction", "onnxscript.TracedOnnxFunction"]
op_full_name: str
is_custom: bool = False
is_complex: bool = False
@ -41,6 +42,7 @@ class OpName:
overload: str
@classmethod
@_beartype.beartype
def from_name_parts(
cls, namespace: str, op_name: str, overload: Optional[str] = None
) -> OpName:
@ -51,6 +53,7 @@ class OpName:
return cls(namespace, op_name, overload)
@classmethod
@_beartype.beartype
def from_qualified_name(cls, qualified_name: str) -> OpName:
"""When the name is <namespace>::<op_name>[.<overload>]"""
namespace, opname_overload = qualified_name.split("::")
@ -59,10 +62,12 @@ class OpName:
return cls(namespace, op_name, overload)
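A sketch of the qualified-name format parsed above; the overload value for names without a `.` part is assumed from context:

from torch.onnx._internal.fx import registration

name = registration.OpName.from_qualified_name("aten::add.Tensor")
assert (name.namespace, name.op_name, name.overload) == ("aten", "add", "Tensor")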
@classmethod
@_beartype.beartype
def from_op_overload(cls, op_overload: torch._ops.OpOverload) -> OpName:
return cls.from_qualified_name(op_overload.name())
@classmethod
@_beartype.beartype
def from_builtin_function(
cls, builtin_function: types.BuiltinFunctionType
) -> OpName:
@ -81,5 +86,6 @@ class OpName:
module = builtin_function.__module__ # _operators or math
return cls.from_qualified_name(module + "::" + op)
@_beartype.beartype
def qualified_name(self) -> str:
return f"{self.namespace}::{self.op_name}.{self.overload}"

View File

@ -8,7 +8,7 @@ from typing import Optional, Tuple, TYPE_CHECKING, Union
import torch
from torch.onnx import _type_utils as jit_type_utils
from torch.onnx._internal import _beartype
if TYPE_CHECKING:
import onnx
@ -16,12 +16,13 @@ if TYPE_CHECKING:
log = logging.getLogger(__name__)
@_beartype.beartype
def _create_tensor_proto_with_external_data(
tensor: torch.Tensor,
name: str,
location: str,
basepath: str,
dtype_override: Optional[onnx.TypeProto] = None, # type: ignore[name-defined]
dtype_override: Optional["onnx.TypeProto"] = None, # type: ignore[name-defined]
) -> onnx.TensorProto: # type: ignore[name-defined]
"""Create a TensorProto with external data from a PyTorch tensor.
The external data is saved to os.path.join(basepath, location).
@ -113,6 +114,7 @@ def _convert_safetensors_to_torch_format(safetensors_file):
# TODO: generalize to allow more checkpoints formats (torch or gguf)
@_beartype.beartype
def save_model_with_external_data(
basepath: str,
model_location: str,

View File

@ -11,7 +11,7 @@ from typing import Any, Callable, Mapping, Optional, Sequence, TYPE_CHECKING, Un
import torch._dynamo
import torch.fx
import torch.onnx
from torch.onnx._internal import exporter, io_adapter
from torch.onnx._internal import _beartype, exporter, io_adapter
from torch.onnx._internal.diagnostics import infra
if TYPE_CHECKING:
@ -35,7 +35,7 @@ class TorchExport(exporter.FXGraphExtractor):
def generate_fx(
self,
options: exporter.ResolvedExportOptions,
model: ExportedProgram, # type: ignore[override]
model: "ExportedProgram", # type: ignore[override]
model_args: Sequence[Any],
model_kwargs: Mapping[str, Any],
) -> torch.fx.GraphModule:
@ -96,6 +96,7 @@ class TorchExport(exporter.FXGraphExtractor):
# Export FX graph to ONNX ModelProto.
return self.pre_export_passes(options, model, model.graph_module, updated_model_args) # type: ignore[return-value]
@_beartype.beartype
def pre_export_passes(
self,
options: exporter.ResolvedExportOptions,

View File

@ -1,6 +1,8 @@
# mypy: allow-untyped-defs
from __future__ import annotations
import inspect
from typing import (
Any,
Callable,
@ -11,18 +13,15 @@ from typing import (
runtime_checkable,
Sequence,
Tuple,
TYPE_CHECKING,
Union,
)
import torch
import torch.export as torch_export
from torch.onnx._internal import _beartype
from torch.utils import _pytree as pytree
if TYPE_CHECKING:
import inspect
# TODO(bowbao): Add diagnostics for IO adapters.
@ -56,6 +55,7 @@ class InputAdapter:
def __init__(self, steps: Optional[List[InputAdaptStep]] = None):
self._steps = steps or []
@_beartype.beartype
def append_step(self, step: InputAdaptStep) -> None:
"""Appends a step to the input adapt steps.
@ -64,6 +64,7 @@ class InputAdapter:
"""
self._steps.append(step)
@_beartype.beartype
def apply(
self,
*model_args,
@ -71,7 +72,7 @@ class InputAdapter:
Union[torch.nn.Module, Callable, torch_export.ExportedProgram]
] = None,
**model_kwargs,
) -> Sequence[Union[int, float, bool, str, torch.Tensor, torch.dtype, None]]:
) -> Sequence[Union[int, float, bool, str, "torch.Tensor", torch.dtype, None]]:
"""Converts the PyTorch model inputs to exported ONNX model inputs format.
Args:
@ -118,6 +119,7 @@ class OutputAdapter:
def __init__(self, steps: Optional[List[OutputAdaptStep]] = None):
self._steps = steps or []
@_beartype.beartype
def append_step(self, step: OutputAdaptStep) -> None:
"""Appends a step to the output format steps.
@ -126,13 +128,14 @@ class OutputAdapter:
"""
self._steps.append(step)
@_beartype.beartype
def apply(
self,
model_outputs: Any,
model: Optional[
Union[torch.nn.Module, Callable, torch_export.ExportedProgram]
] = None,
) -> Sequence[Union[torch.Tensor, int, float, bool, str]]:
) -> Sequence[Union["torch.Tensor", int, float, bool, str]]:
"""Converts the PyTorch model outputs to exported ONNX model outputs format.
Args:
@ -260,7 +263,7 @@ class MergeKwargsIntoArgsInputStep(InputAdaptStep):
class LiftParametersAndBuffersIntoArgsInputStep(InputAdaptStep):
"""Append parameters and buffers to model's positional argument list."""
def __init__(self, inputs: Tuple[torch.Tensor, ...]) -> None:
def __init__(self, inputs: Tuple["torch.Tensor", ...]) -> None:
self.inputs = inputs
def apply(

View File

@ -13,7 +13,7 @@ from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Tuple, Un
import torch
from torch import _C
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import registration
from torch.onnx._internal import _beartype, registration
_ATTR_PATTERN = re.compile("^(.+)_(([ifstgz])|(ty))$")
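A sketch of what `_ATTR_PATTERN` matches: `g.op(...)` keyword arguments carry a one-letter type suffix (e.g. `_i` for int, `_f` for float, `_s` for string):

import re

_ATTR_PATTERN = re.compile("^(.+)_(([ifstgz])|(ty))$")
m = _ATTR_PATTERN.match("alpha_f")
assert m is not None
assert m.group(1) == "alpha" and m.group(2) == "f"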
@ -43,7 +43,7 @@ class GraphContext:
block: _C.Block
opset: int
original_node: _C.Node
params_dict: Dict[str, _C.IValue]
params_dict: Dict[str, "_C.IValue"]
env: Dict[_C.Value, _C.Value]
values_in_env: Set[_C.Value]
new_nodes: List[_C.Node] = dataclasses.field(default_factory=list)
@ -53,6 +53,7 @@ class GraphContext:
def __getattr__(self, name: str) -> Any:
return getattr(self.graph, name)
@_beartype.beartype
def op(
self,
opname: str,
@ -90,6 +91,7 @@ class GraphContext:
# FIXME(justinchuby): Add the return type back once we know how to handle mypy
return _add_op(self, opname, *raw_args, outputs=outputs, **kwargs)
@_beartype.beartype
def aten_op(self, operator: str, *args, overload_name: str = "", **kwargs):
"""Generates an ONNX ATen op node.
@ -107,6 +109,7 @@ class GraphContext:
# We are probably going to remove this only after the fx exporter is established.
at = aten_op
@_beartype.beartype
def onnxscript_op(
self,
onnx_fn,
@ -150,6 +153,7 @@ class GraphContext:
return _add_op(self, symbolic_name, *raw_args, outputs=outputs, **kwargs)
@_beartype.beartype
def add_op_with_blocks(
graph_context: GraphContext,
opname: str,
@ -197,6 +201,7 @@ def add_op_with_blocks(
return output_values, tuple(new_contexts), node
@_beartype.beartype
def _add_op(
graph_context: GraphContext,
opname: str,
@ -260,6 +265,7 @@ def _add_op(
return tuple(node.outputs())
@_beartype.beartype
def _const_if_tensor(graph_context: GraphContext, arg):
if arg is None:
return arg
@ -308,18 +314,21 @@ def _create_node(
return node
@_beartype.beartype
def _is_onnx_list(value):
return isinstance(value, Iterable) and not isinstance(
value, (str, bytes, torch.Tensor)
)
@_beartype.beartype
def _scalar(x: torch.Tensor):
"""Convert a scalar tensor into a Python value."""
assert x.numel() == 1
return x[0]
@_beartype.beartype
def _add_attribute(node: _C.Node, key: str, value: Any, aten: bool):
r"""Initializes the right attribute based on type of value."""
m = _ATTR_PATTERN.match(key)
@ -336,10 +345,12 @@ def _add_attribute(node: _C.Node, key: str, value: Any, aten: bool):
# TODO: Expose this to user when migrating symbolic helper functions to here.
@_beartype.beartype
def _is_tensor(x: _C.Value) -> bool:
return x.type().isSubtypeOf(_C.TensorType.get())
@_beartype.beartype
def get_device_from_value(value: _C.Value) -> Optional[torch.device]:
if not _is_tensor(value):
return None
@ -347,6 +358,7 @@ def get_device_from_value(value: _C.Value) -> Optional[torch.device]:
return tensor_type.device()
@_beartype.beartype
def parse_node_kind(kind: str) -> Tuple[str, str]:
"""Parse node kind into domain and Op name."""
if "::" not in kind:
@ -357,16 +369,19 @@ def parse_node_kind(kind: str) -> Tuple[str, str]:
return domain, opname
@_beartype.beartype
def is_aten(domain: str) -> bool:
"""Check if the domain is official."""
return domain == "aten"
@_beartype.beartype
def is_prim(domain: str) -> bool:
"""Check if the domain is official."""
return domain == "prim"
@_beartype.beartype
def is_onnx(domain: str) -> bool:
"""Check if the domain is official."""
return domain == "onnx"
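`GraphContext.op` is what symbolic functions call to emit ONNX nodes; attribute keyword arguments carry a type suffix (`_i` for int, `_f` for float, `_t` for tensor). A hedged sketch of a GELU-style composition built node by node (the function name and decomposition are illustrative, not from this diff):

import torch

def gelu_symbolic(g, self):  # hypothetical symbolic function
    # 0.5 * x * (1 + erf(x / sqrt(2)))
    sqrt2 = g.op("Constant", value_t=torch.tensor(2.0).sqrt())
    erf = g.op("Erf", g.op("Div", self, sqrt2))
    one = g.op("Constant", value_t=torch.tensor(1.0))
    half = g.op("Constant", value_t=torch.tensor(0.5))
    return g.op("Mul", half, g.op("Mul", self, g.op("Add", one, erf)))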

View File

@ -14,9 +14,10 @@ import torch
import torch.jit._trace
import torch.serialization
from torch.onnx import _constants, _exporter_states, errors
from torch.onnx._internal import jit_utils, registration
from torch.onnx._internal import _beartype, jit_utils, registration
@_beartype.beartype
def export_as_test_case(
model_bytes: bytes, inputs_data, outputs_data, name: str, dir: str
) -> str:
@ -72,6 +73,7 @@ def export_as_test_case(
return test_case_dir
@_beartype.beartype
def load_test_case(dir: str) -> Tuple[bytes, Any, Any]:
"""Load a self contained ONNX test case from a directory.
@ -122,6 +124,7 @@ def load_test_case(dir: str) -> Tuple[bytes, Any, Any]:
return model_bytes, inputs, outputs
@_beartype.beartype
def export_data(data, value_info_proto, f: str) -> None:
"""Export data to ONNX protobuf format.
@ -160,6 +163,7 @@ def export_data(data, value_info_proto, f: str) -> None:
)
@_beartype.beartype
def _export_file(
model_bytes: bytes,
f: Union[io.BytesIO, str],
@ -206,6 +210,7 @@ def _export_file(
raise ValueError("Unknown export type")
@_beartype.beartype
def _add_onnxscript_fn(
model_bytes: bytes,
custom_opsets: Mapping[str, int],
@ -238,6 +243,7 @@ def _add_onnxscript_fn(
return model_bytes
@_beartype.beartype
def _find_onnxscript_op(
graph_proto,
included_node_func: Set[str],

View File
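A hedged round-trip of the two helpers above (module path assumed from this diff; the data arguments are assumed to be sequences of numpy arrays matching the model's inputs and outputs):

import io
import torch
from torch.onnx._internal.onnx_proto_utils import (  # path assumed
    export_as_test_case,
    load_test_case,
)

model, x = torch.nn.Linear(2, 2), torch.randn(1, 2)
buf = io.BytesIO()
torch.onnx.export(model, (x,), buf)
# Write the model plus serialized input/output data under a test dir.
test_dir = export_as_test_case(
    buf.getvalue(), [x.numpy()], [model(x).detach().numpy()], "linear", "."
)
model_bytes, inputs, outputs = load_test_case(test_dir)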

@ -15,7 +15,6 @@ from typing import (
Sequence,
Set,
Tuple,
TYPE_CHECKING,
Union,
)
from typing_extensions import TypeAlias
@ -32,11 +31,9 @@ from torch.fx.passes.operator_support import OperatorSupport
from torch.fx.passes.tools_common import CALLABLE_NODE_OPS
from torch.utils import _pytree
if TYPE_CHECKING:
import onnx
try:
# Use try-except to initialize package-dependent global variables.
import onnx
import onnxruntime # type: ignore[import]
from onnxruntime.capi import _pybind_state as ORTC # type: ignore[import]
@ -555,9 +552,9 @@ class OrtExecutionInfoPerSession:
self.output_value_infos: Tuple[onnx.ValueInfoProto, ...] = output_value_infos # type: ignore[name-defined]
# For the ONNX model stored in self.session, self.input_devices[i] is the
# i-th positional input's device.
self.input_devices: Tuple[ORTC.OrtDevice, ...] = input_devices
self.input_devices: Tuple["ORTC.OrtDevice", ...] = input_devices
# Similar to self.input_devices, but for outputs.
self.output_devices: Tuple[ORTC.OrtDevice, ...] = output_devices
self.output_devices: Tuple["ORTC.OrtDevice", ...] = output_devices
# This is the outputs of executing the original torch.fx.GraphModule with example inputs
# (i.e., args passed into OrtBackend._ort_acclerated_call).
self.example_outputs: Union[

View File

@ -15,7 +15,7 @@ from typing import (
)
from torch.onnx import _constants, errors
from torch.onnx._internal import _beartype
OpsetVersion = int
@ -265,6 +265,7 @@ class SymbolicRegistry:
return set(self._registry)
@_beartype.beartype
def onnx_symbolic(
name: str,
opset: Union[OpsetVersion, Sequence[OpsetVersion]],
@ -313,6 +314,7 @@ def onnx_symbolic(
return wrapper
@_beartype.beartype
def custom_onnx_symbolic(
name: str,
opset: Union[OpsetVersion, Sequence[OpsetVersion]],

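`custom_onnx_symbolic` registers a user override for an op; a hedged sketch, assuming the decorator-factory usage implied by the signature above (the relu6 lowering itself is illustrative, and the `min_f`/`max_f` spelling follows the float-attribute suffix convention for opset-9 Clip):

from torch.onnx._internal import registration

@registration.custom_onnx_symbolic("aten::relu6", opset=9)
def relu6(g, self):
    # Hypothetical lowering: clamp to [0, 6].
    return g.op("Clip", self, min_f=0.0, max_f=6.0)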
View File

@ -9,6 +9,7 @@ from typing import Dict, Literal, Optional, Union
import torch
from torch._C import _onnx as _C_onnx
from torch.onnx import errors
from torch.onnx._internal import _beartype
if typing.TYPE_CHECKING:
# Hack to help mypy to recognize torch._C.Value
@ -105,6 +106,7 @@ class JitScalarType(enum.IntEnum):
UNDEFINED = enum.auto() # 20
@classmethod
@_beartype.beartype
def _from_name(
cls, name: Union[ScalarName, TorchName, Optional[str]]
) -> JitScalarType:
@ -134,6 +136,7 @@ class JitScalarType(enum.IntEnum):
raise errors.OnnxExporterError(f"Unknown torch or scalar type: '{name}'")
@classmethod
@_beartype.beartype
def from_dtype(cls, dtype: Optional[torch.dtype]) -> JitScalarType:
"""Convert a torch dtype to JitScalarType.
@ -156,6 +159,7 @@ class JitScalarType(enum.IntEnum):
return _DTYPE_TO_SCALAR_TYPE[dtype]
@classmethod
@_beartype.beartype
def from_onnx_type(
cls, onnx_type: Optional[Union[int, _C_onnx.TensorProtoDataType]]
) -> JitScalarType:
@ -175,6 +179,7 @@ class JitScalarType(enum.IntEnum):
return _ONNX_TO_SCALAR_TYPE[typing.cast(_C_onnx.TensorProtoDataType, onnx_type)]
@classmethod
@_beartype.beartype
def from_value(
cls, value: Union[None, torch._C.Value, torch.Tensor], default=None
) -> JitScalarType:
@ -242,18 +247,22 @@ class JitScalarType(enum.IntEnum):
value,
)
@_beartype.beartype
def scalar_name(self) -> ScalarName:
"""Convert a JitScalarType to a JIT scalar type name."""
return _SCALAR_TYPE_TO_NAME[self]
@_beartype.beartype
def torch_name(self) -> TorchName:
"""Convert a JitScalarType to a torch type name."""
return _SCALAR_TYPE_TO_TORCH_NAME[self]
@_beartype.beartype
def dtype(self) -> torch.dtype:
"""Convert a JitScalarType to a torch dtype."""
return _SCALAR_TYPE_TO_DTYPE[self]
@_beartype.beartype
def onnx_type(self) -> _C_onnx.TensorProtoDataType:
"""Convert a JitScalarType to an ONNX data type."""
if self not in _SCALAR_TYPE_TO_ONNX:
@ -262,6 +271,7 @@ class JitScalarType(enum.IntEnum):
)
return _SCALAR_TYPE_TO_ONNX[self]
@_beartype.beartype
def onnx_compatible(self) -> bool:
"""Return whether this JitScalarType is compatible with ONNX."""
return (
@ -271,11 +281,13 @@ class JitScalarType(enum.IntEnum):
)
@_beartype.beartype
def valid_scalar_name(scalar_name: Union[ScalarName, str]) -> bool:
"""Return whether the given scalar name is a valid JIT scalar type name."""
return scalar_name in _SCALAR_NAME_TO_TYPE
@_beartype.beartype
def valid_torch_name(torch_name: Union[TorchName, str]) -> bool:
"""Return whether the given torch name is a valid torch type name."""
return torch_name in _TORCH_NAME_TO_SCALAR_TYPE
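A hedged tour of the `JitScalarType` conversions restored above, using only methods shown in this diff:

import torch
from torch.onnx._type_utils import JitScalarType

st = JitScalarType.from_dtype(torch.float16)
assert st.dtype() == torch.float16
print(st.scalar_name())      # JIT scalar name, e.g. "Half"
print(st.onnx_type())        # matching _C_onnx.TensorProtoDataType
print(st.onnx_compatible())  # True: fp16 maps to an ONNX type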

View File

@ -1,5 +1,4 @@
# mypy: allow-untyped-defs
# mypy: disable-error-code=arg-type
import importlib
import inspect

View File

@ -27,10 +27,8 @@ from torch import _C
# Monkey-patch graph manipulation methods on Graph, used for the ONNX symbolics
from torch.onnx import _constants, _type_utils, errors, utils
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import jit_utils
if typing.TYPE_CHECKING:
from torch.types import Number
from torch.onnx._internal import _beartype, jit_utils
from torch.types import Number
# ---------------------------------------------------------------------------------
@ -50,6 +48,7 @@ _ValueDescriptor = Literal[
]
@_beartype.beartype
def _parse_arg(
value,
desc: _ValueDescriptor,
@ -119,6 +118,7 @@ def _parse_arg(
)
@_beartype.beartype
def _node_get(node: _C.Node, key: str):
"""Gets attributes of a node which is polymorphic over return type."""
assert isinstance(node, _C.Node)
@ -126,11 +126,13 @@ def _node_get(node: _C.Node, key: str):
return getattr(node, sel)(key)
@_beartype.beartype
def _is_onnx_constant(value: _C.Value):
"""Whether a Value is an ONNX constant."""
return value.node().kind() == "onnx::Constant"
@_beartype.beartype
def _maybe_get_const(
value: Optional[Union[_C.Value, torch.Tensor, Number, Sequence]],
descriptor: _ValueDescriptor,
@ -143,6 +145,7 @@ def _maybe_get_const(
return value
@_beartype.beartype
def _maybe_get_scalar(value):
value_t = _maybe_get_const(value, "t")
if isinstance(value_t, torch.Tensor) and value_t.shape == ():
@ -150,6 +153,7 @@ def _maybe_get_scalar(value):
return value
@_beartype.beartype
def _get_const(value, desc, arg_name):
if not _is_constant(value):
raise errors.SymbolicValueError(
@ -160,6 +164,7 @@ def _get_const(value, desc, arg_name):
return _parse_arg(value, desc)
@_beartype.beartype
def _unpack_list(list_value: _C.Value) -> List[_C.Value]:
list_node = list_value.node()
if list_node.kind() != "prim::ListConstruct":
@ -171,6 +176,7 @@ def _unpack_list(list_value: _C.Value) -> List[_C.Value]:
return list(list_node.inputs())
@_beartype.beartype
def _unpack_tuple(tuple_value: _C.Value) -> Tuple[_C.Value, ...]:
tuple_node = tuple_value.node()
if not _is_tuple_construct(tuple_value):
@ -182,6 +188,7 @@ def _unpack_tuple(tuple_value: _C.Value) -> Tuple[_C.Value, ...]:
return tuple(tuple_node.inputs())
@_beartype.beartype
def _unpack_quantized_tensor(tuple_value: _C.Value) -> Tuple[_C.Value, ...]:
"""Unpacks a quantized tensor into a tuple of tensor and scale/zero_point.
Args:
@ -205,10 +212,12 @@ def _unpack_quantized_tensor(tuple_value: _C.Value) -> Tuple[_C.Value, ...]:
# Check if list_value is output from prim::ListConstruct
# This is usually called before _unpack_list to ensure the list can be unpacked.
@_beartype.beartype
def _is_packed_list(list_value: Any) -> bool:
return _is_value(list_value) and list_value.node().kind() == "prim::ListConstruct"
@_beartype.beartype
def parse_args(*arg_descriptors: _ValueDescriptor):
"""A decorator which converts args from torch._C.Value to built-in types.
@ -287,6 +296,7 @@ def parse_args(*arg_descriptors: _ValueDescriptor):
return decorator
@_beartype.beartype
def quantized_args(
*arg_q_descriptors: bool,
scale: Optional[float] = None,
@ -423,6 +433,7 @@ def quantized_args(
return decorator
@_beartype.beartype
def _scalar(x: Any) -> Optional[Number]:
"""Convert a scalar tensor into a Python value."""
if isinstance(x, torch.Tensor) and x.shape == ():
@ -430,6 +441,7 @@ def _scalar(x: Any) -> Optional[Number]:
return None
@_beartype.beartype
def _if_scalar_type_as(self, tensor):
"""
Convert self into the same type as tensor, as necessary.
@ -449,14 +461,17 @@ def _if_scalar_type_as(self, tensor):
return self
@_beartype.beartype
def _is_none(x: Any) -> bool:
return x is None or (x.node().mustBeNone() if isinstance(x, _C.Value) else False)
@_beartype.beartype
def _is_value(x: Any) -> bool:
return isinstance(x, _C.Value)
@_beartype.beartype
def _is_constant(value: Any) -> bool:
return not _is_value(value) or value.node().kind() in {
"onnx::Constant",
@ -464,6 +479,7 @@ def _is_constant(value: Any) -> bool:
}
@_beartype.beartype
def _is_tensor(x: _C.Value) -> bool:
return x.type().isSubtypeOf(_C.TensorType.get())
@ -475,10 +491,12 @@ def _as_list_type(jit_type: _C.JitType) -> Optional[_C.ListType]:
return None
@_beartype.beartype
def _is_list(x: _C.Value) -> bool:
return _as_list_type(x.type()) is not None
@_beartype.beartype
def _is_tensor_list(x: _C.Value) -> bool:
x_type = _as_list_type(x.type())
if x_type is None:
@ -486,6 +504,7 @@ def _is_tensor_list(x: _C.Value) -> bool:
return isinstance(x_type.getElementType(), _C.TensorType)
@_beartype.beartype
def _is_scalar_list(x: _C.Value) -> bool:
"""Checks if x is a scalar list, for example: List[float], List[int].
@ -499,10 +518,12 @@ def _is_scalar_list(x: _C.Value) -> bool:
return scalar_type.onnx_compatible()
@_beartype.beartype
def _is_tuple_construct(x: _C.Value) -> bool:
return x.node().kind() == "prim::TupleConstruct"
@_beartype.beartype
def is_complex_value(x: _C.Value) -> bool:
assert _is_value(x)
return _type_utils.JitScalarType.from_value(
@ -514,6 +535,7 @@ def is_complex_value(x: _C.Value) -> bool:
}
@_beartype.beartype
def _get_tensor_rank(x: _C.Value) -> Optional[int]:
if not _is_tensor(x) or x.type() is None:
return None
@ -522,6 +544,7 @@ def _get_tensor_rank(x: _C.Value) -> Optional[int]:
return x_type.dim()
@_beartype.beartype
def _get_tensor_sizes(x: _C.Value, allow_nonstatic: bool = True):
if not _is_tensor(x) or x.type() is None:
return None
@ -536,11 +559,13 @@ def _get_tensor_sizes(x: _C.Value, allow_nonstatic: bool = True):
return x_type.sizes()
@_beartype.beartype
def _get_tensor_dim_size(x: _C.Value, dim: int) -> Optional[int]:
sizes = _get_tensor_sizes(x)
return sizes[dim] if sizes else None
@_beartype.beartype
def _get_dim_for_cross(x: _C.Value, dim: Optional[int]):
if dim == -1:
tensor_rank = _get_tensor_rank(x)
@ -556,12 +581,14 @@ def _get_dim_for_cross(x: _C.Value, dim: Optional[int]):
return dim
@_beartype.beartype
def _unimplemented(op: str, msg: str, value: Optional[_C.Value] = None) -> None:
# For BC reasons, the Caffe2 behavior does not raise an exception for unimplemented operators
if GLOBALS.operator_export_type == _C_onnx.OperatorExportTypes.ONNX:
_onnx_unsupported(f"{op}, {msg}", value)
@_beartype.beartype
def _onnx_unsupported(op_name: str, value: Optional[_C.Value] = None) -> NoReturn:
message = (
f"Unsupported: ONNX export of operator {op_name}. "
@ -576,6 +603,7 @@ def _onnx_unsupported(op_name: str, value: Optional[_C.Value] = None) -> NoRetur
raise errors.OnnxExporterError(message)
@_beartype.beartype
def _onnx_opset_unsupported(
op_name: str,
current_opset: int,
@ -594,6 +622,7 @@ def _onnx_opset_unsupported(
raise errors.OnnxExporterError(message)
@_beartype.beartype
def _onnx_opset_unsupported_detailed(
op_name: str,
current_opset: int,
@ -613,6 +642,7 @@ def _onnx_opset_unsupported_detailed(
raise errors.OnnxExporterError(message)
@_beartype.beartype
def _block_list_in_opset(name: str):
def symbolic_fn(*args, **kwargs):
raise errors.OnnxExporterError(
@ -624,6 +654,7 @@ def _block_list_in_opset(name: str):
return symbolic_fn
@_beartype.beartype
def _try_get_scalar_type(*args) -> Optional[_type_utils.JitScalarType]:
for arg in args:
scalar_type = _type_utils.JitScalarType.from_value(
@ -634,19 +665,21 @@ def _try_get_scalar_type(*args) -> Optional[_type_utils.JitScalarType]:
return None
@_beartype.beartype
def _type_promote_from_values(*args) -> _type_utils.JitScalarType:
undef = _type_utils.JitScalarType.UNDEFINED
jit_types = [_try_get_scalar_type(arg) for arg in args]
if len(jit_types) == 0:
return undef
if len(jit_types) == 1:
return jit_types[0] # type: ignore[return-value]
new_dtype = jit_types[0].dtype() # type: ignore[union-attr]
return jit_types[0]
new_dtype = jit_types[0].dtype()
for t in jit_types:
new_dtype = torch.promote_types(new_dtype, t.dtype()) # type: ignore[union-attr]
new_dtype = torch.promote_types(new_dtype, t.dtype())
return _type_utils.JitScalarType.from_dtype(new_dtype)
@_beartype.beartype
def _maybe_cast_to_type(
g: jit_utils.GraphContext, value, jit_type: _type_utils.JitScalarType
):
@ -662,6 +695,7 @@ def _maybe_cast_to_type(
return value
@_beartype.beartype
def _select_helper(g: jit_utils.GraphContext, self, dim, index, apply_reshape=True):
index_const = _maybe_get_scalar(index)
index_dim = _get_tensor_rank(index)
@ -686,6 +720,7 @@ def _select_helper(g: jit_utils.GraphContext, self, dim, index, apply_reshape=Tr
return g.op("Gather", self, index, axis_i=dim)
@_beartype.beartype
def _slice_helper(
g: jit_utils.GraphContext,
input,
@ -704,6 +739,7 @@ def _slice_helper(
return _slice10(g, input, axes, starts, ends, steps)
@_beartype.beartype
def _is_fp(value) -> bool:
return _type_utils.JitScalarType.from_value(
value, _type_utils.JitScalarType.UNDEFINED
@ -715,12 +751,14 @@ def _is_fp(value) -> bool:
}
@_beartype.beartype
def _is_bool(value) -> bool:
return _type_utils.JitScalarType.from_value(
value, _type_utils.JitScalarType.UNDEFINED
) in {_type_utils.JitScalarType.BOOL}
@_beartype.beartype
def _generate_wrapped_number(g: jit_utils.GraphContext, scalar):
"""Creates a wrapped number based on https://github.com/pytorch/pytorch/issues/9515.
@ -739,6 +777,7 @@ def _generate_wrapped_number(g: jit_utils.GraphContext, scalar):
return g.op("Constant", value_t=torch.tensor(scalar))
@_beartype.beartype
def _sort_helper(g: jit_utils.GraphContext, input, dim, decending=True, out=None):
if out is not None:
_unimplemented("Sort", "Out parameter is not supported")
@ -758,6 +797,7 @@ def _sort_helper(g: jit_utils.GraphContext, input, dim, decending=True, out=None
)
@_beartype.beartype
def _topk_helper(
g: jit_utils.GraphContext, input, k, dim, largest=True, sorted=False, out=None
):
@ -779,6 +819,7 @@ def _topk_helper(
)
@_beartype.beartype
def _lt_helper(g: jit_utils.GraphContext, input, other):
if g.opset <= 8:
from torch.onnx.symbolic_opset8 import lt as _lt8
@ -790,6 +831,7 @@ def _lt_helper(g: jit_utils.GraphContext, input, other):
return _lt9(g, input, other)
@_beartype.beartype
def _interpolate_warning(interpolate_mode):
onnx_op = (
"onnx:Resize" if GLOBALS.export_onnx_opset_version >= 10 else "onnx:Upsample"
@ -807,6 +849,7 @@ def _interpolate_warning(interpolate_mode):
)
@_beartype.beartype
def _unsqueeze_helper(g: jit_utils.GraphContext, input, axes_i):
if _is_constant(axes_i[0]):
if g.opset >= 13:
@ -821,6 +864,7 @@ def _unsqueeze_helper(g: jit_utils.GraphContext, input, axes_i):
return g.op("Unsqueeze", input, axes_i[0])
@_beartype.beartype
def _squeeze_helper(g: jit_utils.GraphContext, input, axes_i):
if _is_constant(axes_i[0]):
if g.opset >= 13:
@ -846,6 +890,7 @@ def _squeeze_helper(g: jit_utils.GraphContext, input, axes_i):
return g.op("Squeeze", input, axes_t)
@_beartype.beartype
def _reducesum_helper(
g: jit_utils.GraphContext,
input,
@ -877,6 +922,7 @@ def _reducesum_helper(
return g.op("ReduceSum", input, axes_i=axes_i, keepdims_i=keepdims_i)
@_beartype.beartype
def _interpolate_size_to_scales(g: jit_utils.GraphContext, input, output_size, dim):
output_size = _maybe_get_const(output_size, "is")
if _is_value(output_size):
@ -903,6 +949,7 @@ def _interpolate_size_to_scales(g: jit_utils.GraphContext, input, output_size, d
return scales
@_beartype.beartype
def _interpolate_get_scales_if_available(g: jit_utils.GraphContext, scales):
available_scales = _maybe_get_const(scales[0], "fs") != -1 and not _is_none(
scales[0]
@ -919,6 +966,7 @@ def _interpolate_get_scales_if_available(g: jit_utils.GraphContext, scales):
return scales
@_beartype.beartype
def _get_interpolate_attributes(g: jit_utils.GraphContext, mode, args):
if mode == "nearest":
align_corners = None
@ -930,6 +978,7 @@ def _get_interpolate_attributes(g: jit_utils.GraphContext, mode, args):
return scales, align_corners
@_beartype.beartype
def _interpolate_get_scales(g: jit_utils.GraphContext, scale_factor, dim):
offsets = g.op("Constant", value_t=torch.ones(2, dtype=torch.float32))
scale_factor_rank = _get_tensor_rank(scale_factor)
@ -947,6 +996,7 @@ def _interpolate_get_scales(g: jit_utils.GraphContext, scale_factor, dim):
return scale_factor
@_beartype.beartype
def _interpolate_get_scales_and_mode(
g: jit_utils.GraphContext, input, size, scale_factor, mode, align_corners
):
@ -982,6 +1032,7 @@ def _interpolate_get_scales_and_mode(
return scale_factor, mode
@_beartype.beartype
def _argmin_argmax_helper(
g: jit_utils.GraphContext,
input: torch._C.Value,
@ -1020,6 +1071,7 @@ def _argmin_argmax_helper(
return op_wrapper(input, axis_i=dim, keepdims_i=keepdim)
@_beartype.beartype
def _interpolate_helper(name, dim, interpolate_mode):
@quantized_args(True, False, False)
def symbolic_fn(g, input, output_size, *args):
@ -1087,6 +1139,7 @@ def _interpolate_helper(name, dim, interpolate_mode):
return symbolic_fn
@_beartype.beartype
def __interpolate_helper(
g: jit_utils.GraphContext,
input,
@ -1186,6 +1239,7 @@ def __interpolate_helper(
) # only valid when mode="nearest"
@_beartype.beartype
def _unbind_helper(g: jit_utils.GraphContext, self, dim, _outputs):
if g.opset < 11:
from torch.onnx.symbolic_opset9 import unbind
@ -1196,6 +1250,7 @@ def _unbind_helper(g: jit_utils.GraphContext, self, dim, _outputs):
return unbind(g, self, dim, _outputs)
@_beartype.beartype
def _scatter_helper(g: jit_utils.GraphContext, self, dim, index, src):
if g.opset <= 10:
from torch.onnx.symbolic_opset9 import scatter
@ -1205,6 +1260,7 @@ def _scatter_helper(g: jit_utils.GraphContext, self, dim, index, src):
return scatter(g, self, dim, index, src)
@_beartype.beartype
def _repeat_interleave_split_helper(g: jit_utils.GraphContext, self, reps, dim):
if g.opset <= 12:
split_out = g.op("Split", self, split_i=[1] * reps, axis_i=dim, outputs=reps)
@ -1216,6 +1272,7 @@ def _repeat_interleave_split_helper(g: jit_utils.GraphContext, self, reps, dim):
return split_out if reps > 1 else [split_out]
@_beartype.beartype
def _repeat_interleave_single_value_repeat_helper(
g: jit_utils.GraphContext, self, repeats, dim
):
@ -1237,7 +1294,7 @@ def _repeat_interleave_single_value_repeat_helper(
# repeats_per_dim is 1 for all dims except for the new unsqueezed dim, where it has value 'repeats'.
if const_repeats:
# 'Repeats' is a constant, 'repeats_per_dim' can be a constant.
onehot = torch.ones(_get_tensor_rank(unsqueezed), dtype=torch.int64) # type: ignore[arg-type]
onehot = torch.ones(_get_tensor_rank(unsqueezed), dtype=torch.int64)
onehot[dim + 1] = reps
repeats_per_dim = g.op("Constant", value_t=onehot)
else:
@ -1258,6 +1315,7 @@ def _repeat_interleave_single_value_repeat_helper(
return flatten(g, tiled, dim, dim + 1)
@_beartype.beartype
def _arange_cast_helper(
g: jit_utils.GraphContext, end, start=None, step=None, dtype=None
) -> Tuple[
@ -1300,6 +1358,7 @@ def _arange_cast_helper(
return scalar_type, end, start, step
@_beartype.beartype
def _arange_helper(g: jit_utils.GraphContext, *args):
if g.opset <= 10:
from torch.onnx.symbolic_opset9 import arange
@ -1308,6 +1367,7 @@ def _arange_helper(g: jit_utils.GraphContext, *args):
return arange(g, *args)
@_beartype.beartype
def _size_helper(g: jit_utils.GraphContext, self, dim):
full_shape = g.op("Shape", self)
from torch.onnx.symbolic_opset9 import select
@ -1315,6 +1375,7 @@ def _size_helper(g: jit_utils.GraphContext, self, dim):
return select(g, full_shape, g.op("Constant", value_t=torch.tensor([0])), dim)
@_beartype.beartype
def _index_fill_reshape_helper(g: jit_utils.GraphContext, self, dim, index):
# 1. reshape index => [1, ..., 1, dim, 1, ..., 1]
# 2. expand index => [..., dim, ...], same shape as self except for dim.
@ -1350,6 +1411,7 @@ def _index_fill_reshape_helper(g: jit_utils.GraphContext, self, dim, index):
# allowzero=1 indicates that if any value in the 'shape' input is set to zero,
# the zero value is honored, similar to NumPy.
# allowzero=1 is only supported for opset version >= 14.
@_beartype.beartype
def _reshape_helper(g: jit_utils.GraphContext, input, shape, allowzero=0):
shape = _maybe_get_const(shape, "is")
if not _is_value(shape):
@ -1364,6 +1426,7 @@ def _reshape_helper(g: jit_utils.GraphContext, input, shape, allowzero=0):
return g.op("Reshape", input, shape, allowzero_i=allowzero)
@_beartype.beartype
def _batchnorm_helper(
g: jit_utils.GraphContext, input, weight, bias, running_mean, running_var
):
@ -1421,6 +1484,7 @@ def _batchnorm_helper(
return weight, bias, running_mean, running_var
@_beartype.beartype
def _avgpool_helper(
tuple_fn: Callable[[Any], Sequence[int]],
padding: Union[int, Sequence[int]],
@ -1434,6 +1498,7 @@ def _avgpool_helper(
return tuple(tuple_fn(padding))
@_beartype.beartype
def check_training_mode(op_train_mode: int, op_name: str) -> None:
"""Warns the user if the model's training mode and the export mode do not agree."""
if GLOBALS.training_mode == _C_onnx.TrainingMode.PRESERVE:
@ -1459,6 +1524,7 @@ def check_training_mode(op_train_mode: int, op_name: str) -> None:
)
@_beartype.beartype
def _flatten_helper(g: jit_utils.GraphContext, input, start_dim, end_dim, dim):
input_size = g.op("Shape", input)
slice1 = _slice_helper(g, input_size, axes=[0], starts=[0], ends=[start_dim])
@ -1479,6 +1545,7 @@ def _flatten_helper(g: jit_utils.GraphContext, input, start_dim, end_dim, dim):
return _reshape_from_tensor(g, input, final_shape)
@_beartype.beartype
def _is_split_static(split_size_or_sizes, _outputs):
if _outputs is None:
return False
@ -1490,12 +1557,14 @@ def _is_split_static(split_size_or_sizes, _outputs):
return True
@_beartype.beartype
def _optional_input_placeholder_tensor(g):
n = g.op("prim::Constant")
n.setType(_C.OptionalType.ofTensor())
return n
@_beartype.beartype
def _handle_reduce_dim_none(g: jit_utils.GraphContext, self, op_name):
rank = _get_tensor_rank(self)
if rank is not None and any(
@ -1507,6 +1576,7 @@ def _handle_reduce_dim_none(g: jit_utils.GraphContext, self, op_name):
return g.op(op_name, self, keepdims_i=0)
@_beartype.beartype
def dequantize_helper(
g: jit_utils.GraphContext,
qtensor: _C.Value,
@ -1555,6 +1625,7 @@ def dequantize_helper(
)
@_beartype.beartype
def quantize_helper(
g: jit_utils.GraphContext,
tensor: _C.Value,
@ -1616,6 +1687,7 @@ def quantize_helper(
return g.op("prim::TupleConstruct", *args)
@_beartype.beartype
def requantize_bias_helper(
g: jit_utils.GraphContext, bias, input_scale, weight_scale, axis=None
):
@ -1638,6 +1710,7 @@ def requantize_bias_helper(
return g.op("prim::TupleConstruct", q_bias, bias_scale, bias_zero_point, *axis_args)
@_beartype.beartype
def args_have_same_dtype(args):
assert args
base_dtype = _type_utils.JitScalarType.from_value(args[0])
@ -1647,6 +1720,7 @@ def args_have_same_dtype(args):
return has_same_dtype
@_beartype.beartype
def _op_with_optional_float_cast(g: jit_utils.GraphContext, op_name, *args, **kwargs):
"""Some PyTorch operators (e.g., Clip/Min/ReLU/Pad) are super set of ONNX in terms of data types.
This function maximizes the exportability of PyTorch-ONNX by allowing ONNX-unsupported PyTorch
@ -1700,6 +1774,7 @@ def _op_with_optional_float_cast(g: jit_utils.GraphContext, op_name, *args, **kw
return self
@_beartype.beartype
def _maybe_cast_reduce_op_input(g: jit_utils.GraphContext, self):
scalar_type = _type_utils.JitScalarType.from_value(
self, _type_utils.JitScalarType.UNDEFINED
@ -1721,7 +1796,9 @@ def _apply_params(*args, **kwargs):
return _apply
@_beartype.beartype
def _reduce_op_symbolic_helper(onnx_op_name, allow_multi_dim_support=True):
@_beartype.beartype
def symbolic(g, self, dim=None, keepdim=None):
self = _maybe_cast_reduce_op_input(g, self)
if dim is None or dim == ():
@ -1754,8 +1831,10 @@ def _reduce_op_symbolic_helper(onnx_op_name, allow_multi_dim_support=True):
return symbolic
@_beartype.beartype
def _overload_by_arg_count(fn):
@functools.wraps(fn)
@_beartype.beartype
def wrapper(g, *args):
overloads = fn(g, *args)
for overload in overloads:
@ -1767,6 +1846,7 @@ def _overload_by_arg_count(fn):
return wrapper
@_beartype.beartype
def _reduce_with_dtype_helper(
onnx_op: str, name: str, allow_multi_dim_support: bool = True
):
@ -1821,6 +1901,7 @@ def _reduce_with_dtype_helper(
return reduce
@_beartype.beartype
def _max_helper(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
# torch.max(input)
if dim_or_y is None and keepdim is None:
@ -1841,6 +1922,7 @@ def _max_helper(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
return max, indices
@_beartype.beartype
def _min_helper(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
# torch.min(input)
if dim_or_y is None and keepdim is None:
@ -1861,12 +1943,14 @@ def _min_helper(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
return min, indices
@_beartype.beartype
def _numel_helper(g: jit_utils.GraphContext, self):
shape = g.op("Shape", self)
return g.op("ReduceProd", shape, keepdims_i=0)
@parse_args("v", "is", "i", "i")
@_beartype.beartype
def _var_mean_helper(g: jit_utils.GraphContext, input, dim, correction, keepdim):
if g.opset < 18:
if dim is None:
@ -1939,6 +2023,7 @@ def _var_mean_helper(g: jit_utils.GraphContext, input, dim, correction, keepdim)
return var, mean
@_beartype.beartype
def _embedding_bag_helper(
g: jit_utils.GraphContext,
embedding_matrix,
@ -2045,6 +2130,7 @@ def _embedding_bag_helper(
return loop.node().output(), None, None, None
@_beartype.beartype
def _linalg_vector_norm_helper(
g: jit_utils.GraphContext,
self: torch._C.Value,

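`_reduce_op_symbolic_helper` above manufactures the shared symbolic for each ONNX reduce op. A hedged sketch of deriving a concrete reduction from the factory, matching the `(g, self, dim=None, keepdim=None)` shape shown in this diff:

from torch.onnx import symbolic_helper

# Hypothetical use: one call yields a symbolic that handles both
# torch.max(x) and torch.max(x, dim, keepdim).
_reduce_max = symbolic_helper._reduce_op_symbolic_helper("ReduceMax")
# _reduce_max(g, self)               -> ReduceMax over all dimensions
# _reduce_max(g, self, dim, keepdim) -> ReduceMax over `dim` only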
View File

@ -1,5 +1,4 @@
# mypy: allow-untyped-defs
# mypy: disable-error-code=arg-type
from __future__ import annotations
import functools
@ -21,7 +20,7 @@ from torch.onnx import (
symbolic_opset9 as opset9,
)
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import jit_utils, registration
from torch.onnx._internal import _beartype, jit_utils, registration
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md
@ -73,6 +72,7 @@ _onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=10)
@_onnx_symbolic("aten::div")
@_beartype.beartype
def div(g: jit_utils.GraphContext, self, other, *args):
if len(args) == 0:
return opset9.true_divide(g, self, other)
@ -81,6 +81,7 @@ def div(g: jit_utils.GraphContext, self, other, *args):
@symbolic_helper.parse_args("v", "v", "s")
@_beartype.beartype
def _div_rounding_mode(g: jit_utils.GraphContext, self, other, rounding_mode):
if rounding_mode == "floor":
return _floor_divide(g, self, other)
@ -89,6 +90,7 @@ def _div_rounding_mode(g: jit_utils.GraphContext, self, other, rounding_mode):
@_onnx_symbolic("aten::_floor_divide")
@_beartype.beartype
def _floor_divide(g: jit_utils.GraphContext, self, other):
if symbolic_helper._is_fp(self) or symbolic_helper._is_fp(other):
out = opset9.true_divide(g, self, other)
@ -111,12 +113,14 @@ def _floor_divide(g: jit_utils.GraphContext, self, other):
@_onnx_symbolic("aten::sort")
@symbolic_helper.parse_args("v", "i", "i", "none")
@_beartype.beartype
def sort(g: jit_utils.GraphContext, self, dim, decending, out=None):
return symbolic_helper._sort_helper(g, self, dim, decending=decending, out=out)
@_onnx_symbolic("aten::topk")
@symbolic_helper.parse_args("v", "v", "i", "i", "i", "none")
@_beartype.beartype
def topk(g: jit_utils.GraphContext, self, k, dim, largest, sorted, out=None):
return symbolic_helper._topk_helper(
g, self, k, dim, largest=largest, sorted=sorted, out=out
@ -304,6 +308,7 @@ def _aten_max_pool_with_indices_onnx(
)
],
)
@_beartype.beartype
def _max_pool(name: str, expand_size: int, return_indices: bool):
@symbolic_helper.quantized_args(True, False, False, False, False, False)
@symbolic_helper.parse_args("v", "is", "is", "is", "is", "i")
@ -394,9 +399,11 @@ def _adjust_attributes_of_avg_pool(
"aten::avg_pool3d",
decorate=[symbolic_helper._apply_params("avg_pool3d", 3)],
)
@_beartype.beartype
def _avg_pool(name, expand_size):
@symbolic_helper.quantized_args(True, False, False, False, False, False, False)
@symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none")
@_beartype.beartype
def symbolic_fn(
g,
input: _C.Value,
@ -450,8 +457,10 @@ def _avg_pool(name, expand_size):
"aten::upsample_trilinear3d",
decorate=[symbolic_helper._apply_params("upsample_trilinear3d", 5, "linear")],
)
@_beartype.beartype
def _interpolate(name, dim, interpolate_mode):
@symbolic_helper.quantized_args(True, False, False)
@_beartype.beartype
def symbolic_fn(g, input, output_size, *args):
scales, align_corners = symbolic_helper._get_interpolate_attributes(
g, interpolate_mode, args
@ -470,6 +479,7 @@ def _interpolate(name, dim, interpolate_mode):
@_onnx_symbolic("aten::__interpolate")
@_beartype.beartype
def __interpolate(
g: jit_utils.GraphContext,
input,
@ -486,6 +496,7 @@ def __interpolate(
return g.op("Resize", input, scales, mode_s=mode)
@_beartype.beartype
def _slice(
g: jit_utils.GraphContext,
input: torch._C.Value,
@ -545,6 +556,7 @@ def _slice(
@_onnx_symbolic("aten::slice")
@_beartype.beartype
def slice(g: jit_utils.GraphContext, self, *args):
if len(args) == 4:
# aten::slice(Tensor self, int dim, int? start=None, int? end=None, int step=1) -> Tensor
@ -568,6 +580,7 @@ def slice(g: jit_utils.GraphContext, self, *args):
@_onnx_symbolic("aten::flip")
@symbolic_helper.parse_args("v", "is")
@_beartype.beartype
def flip(g: jit_utils.GraphContext, input, dims):
return symbolic_helper._slice_helper(
g,
@ -580,12 +593,14 @@ def flip(g: jit_utils.GraphContext, input, dims):
@_onnx_symbolic("aten::fmod")
@_beartype.beartype
def fmod(g: jit_utils.GraphContext, input, other):
return g.op("Mod", input, other, fmod_i=1)
@_onnx_symbolic("aten::embedding_bag")
@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i")
@_beartype.beartype
def embedding_bag(
g: jit_utils.GraphContext,
embedding_matrix,
@ -672,6 +687,7 @@ def embedding_bag(
@_onnx_symbolic("aten::fake_quantize_per_tensor_affine")
@symbolic_helper.parse_args("v", "v", "v", "i", "i")
@_beartype.beartype
def fake_quantize_per_tensor_affine(
g: jit_utils.GraphContext,
inputs,
@ -719,11 +735,13 @@ def fake_quantize_per_tensor_affine(
@_onnx_symbolic("aten::isinf")
@_beartype.beartype
def isinf(g: jit_utils.GraphContext, input):
return g.op("IsInf", g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.DOUBLE))
@_onnx_symbolic("aten::isfinite")
@_beartype.beartype
def isfinite(g: jit_utils.GraphContext, input):
inf_node = isinf(g, input)
nan_node = opset9.isnan(g, input)
@ -731,6 +749,7 @@ def isfinite(g: jit_utils.GraphContext, input):
@_onnx_symbolic("aten::quantize_per_tensor")
@_beartype.beartype
def quantize_per_tensor(g: jit_utils.GraphContext, input, scale, zero_point, dtype):
dtype = symbolic_helper._get_const(dtype, "i", "dtype")
# TODO(justinchuby): Extract all the cast ops into a helper function.
@ -742,12 +761,14 @@ def quantize_per_tensor(g: jit_utils.GraphContext, input, scale, zero_point, dty
@_onnx_symbolic("aten::dequantize")
@_beartype.beartype
def dequantize(g: jit_utils.GraphContext, input):
return symbolic_helper.dequantize_helper(g, input)[0]
@_onnx_symbolic("aten::nan_to_num")
@symbolic_helper.parse_args("v", "f", "f", "f")
@_beartype.beartype
def nan_to_num(g: jit_utils.GraphContext, input, nan, posinf, neginf):
# Cannot create an int-typed tensor with inf/nan values, so we simply
# return the original tensor
@ -803,6 +824,7 @@ def nan_to_num(g: jit_utils.GraphContext, input, nan, posinf, neginf):
# Support starts from opset 10 because `DequantizeLinear` and `QuantizeLinear` were
# introduced in opset version 10.
@_onnx_symbolic("quantized::linear")
@_beartype.beartype
def quantized_linear(
g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point
):
@ -817,6 +839,7 @@ def quantized_linear(
@_onnx_symbolic("quantized::linear_relu")
@_beartype.beartype
def quantized_linear_relu(
g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point
):
@ -832,6 +855,7 @@ def quantized_linear_relu(
@_onnx_symbolic("quantized::add")
@_beartype.beartype
def quantized_add(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point):
x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
y, _, _, _ = symbolic_helper.dequantize_helper(g, y)
@ -842,6 +866,7 @@ def quantized_add(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point):
@_onnx_symbolic("quantized::add_relu")
@_beartype.beartype
def quantized_add_relu(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point):
x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
y, _, _, _ = symbolic_helper.dequantize_helper(g, y)
@ -853,6 +878,7 @@ def quantized_add_relu(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point)
@_onnx_symbolic("quantized::mul")
@_beartype.beartype
def quantized_mul(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point):
x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
y, _, _, _ = symbolic_helper.dequantize_helper(g, y)
@ -863,6 +889,7 @@ def quantized_mul(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point):
@_onnx_symbolic("quantized::hardswish")
@_beartype.beartype
def quantized_hardswish(g: jit_utils.GraphContext, x, op_scale, op_zero_point):
x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
@ -872,6 +899,7 @@ def quantized_hardswish(g: jit_utils.GraphContext, x, op_scale, op_zero_point):
@_onnx_symbolic("quantized::sigmoid")
@_beartype.beartype
def quantized_sigmoid(g: jit_utils.GraphContext, x, op_scale, op_zero_point):
x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
@ -881,6 +909,7 @@ def quantized_sigmoid(g: jit_utils.GraphContext, x, op_scale, op_zero_point):
@_onnx_symbolic("quantized::leaky_relu")
@_beartype.beartype
def quantized_leaky_relu(
g: jit_utils.GraphContext, x, negative_slope, inplace, op_scale, op_zero_point
):
@ -892,6 +921,7 @@ def quantized_leaky_relu(
@_onnx_symbolic("quantized::layer_norm")
@_beartype.beartype
def quantized_layer_norm(
g: jit_utils.GraphContext,
x,
@ -910,6 +940,7 @@ def quantized_layer_norm(
@_onnx_symbolic("quantized::group_norm")
@_beartype.beartype
def quantized_group_norm(
g: jit_utils.GraphContext,
x,
@ -929,6 +960,7 @@ def quantized_group_norm(
@_onnx_symbolic("quantized::instance_norm")
@symbolic_helper.parse_args("v", "v", "v", "f", "v", "v")
@_beartype.beartype
def quantized_instance_norm(
g: jit_utils.GraphContext,
q_input,
@ -948,6 +980,7 @@ def quantized_instance_norm(
@_onnx_symbolic("quantized::conv1d_relu")
@_beartype.beartype
def quantized_conv1d_relu(
g: jit_utils.GraphContext,
q_input,
@ -972,6 +1005,7 @@ def quantized_conv1d_relu(
@_onnx_symbolic("quantized::conv2d_relu")
@_beartype.beartype
def quantized_conv2d_relu(
g: jit_utils.GraphContext,
q_input,
@ -996,6 +1030,7 @@ def quantized_conv2d_relu(
@_onnx_symbolic("quantized::conv3d_relu")
@_beartype.beartype
def quantized_conv3d_relu(
g: jit_utils.GraphContext,
q_input,
@ -1020,6 +1055,7 @@ def quantized_conv3d_relu(
@_onnx_symbolic("quantized::conv1d")
@_beartype.beartype
def quantized_conv1d(
g: jit_utils.GraphContext,
q_input,
@ -1043,6 +1079,7 @@ def quantized_conv1d(
@_onnx_symbolic("quantized::conv2d")
@_beartype.beartype
def quantized_conv2d(
g: jit_utils.GraphContext,
q_input,
@ -1066,6 +1103,7 @@ def quantized_conv2d(
@_onnx_symbolic("quantized::conv3d")
@_beartype.beartype
def quantized_conv3d(
g: jit_utils.GraphContext,
q_input,
@ -1089,6 +1127,7 @@ def quantized_conv3d(
@_onnx_symbolic("quantized::conv_transpose1d")
@_beartype.beartype
def quantized_conv_transpose1d(
g: jit_utils.GraphContext,
q_input,
@ -1115,6 +1154,7 @@ def quantized_conv_transpose1d(
@_onnx_symbolic("quantized::conv_transpose2d")
@_beartype.beartype
def quantized_conv_transpose2d(
g: jit_utils.GraphContext,
q_input,
@ -1141,6 +1181,7 @@ def quantized_conv_transpose2d(
@_onnx_symbolic("quantized::conv_transpose3d")
@_beartype.beartype
def quantized_conv_transpose3d(
g: jit_utils.GraphContext,
q_input,
@ -1168,6 +1209,7 @@ def quantized_conv_transpose3d(
@_onnx_symbolic("quantized::cat")
@symbolic_helper.parse_args("v", "i", "v", "v")
@_beartype.beartype
def quantized_cat(
g: jit_utils.GraphContext,
q_inputs: _C.Value,

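Every quantized:: symbolic above follows the same dequantize, float-domain op, requantize shape. A hedged mirror for a hypothetical quantized relu (not an op registered in this diff):

from torch.onnx import symbolic_helper
from torch.onnx import symbolic_opset9 as opset9

def quantized_relu(g, x, op_scale, op_zero_point):  # hypothetical op
    # Unpack (tensor, scale, zero_point), compute in float, repack.
    x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
    output = opset9.relu(g, x)
    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)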
View File

@ -1,5 +1,4 @@
# mypy: allow-untyped-defs
# mypy: disable-error-code=arg-type
"""This file exports ONNX ops for opset 11."""
from __future__ import annotations
@ -19,7 +18,7 @@ from torch.onnx import (
symbolic_opset9 as opset9,
utils,
)
from torch.onnx._internal import jit_utils, registration
from torch.onnx._internal import _beartype, jit_utils, registration
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md
@ -90,6 +89,7 @@ _onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=11)
@_onnx_symbolic("aten::hardtanh")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "f", "f")
@_beartype.beartype
def hardtanh(g: jit_utils.GraphContext, self: _C.Value, min_val: float, max_val: float):
scalar_type = _type_utils.JitScalarType.from_value(
self, _type_utils.JitScalarType.FLOAT
@ -108,7 +108,9 @@ def hardtanh(g: jit_utils.GraphContext, self: _C.Value, min_val: float, max_val:
@_onnx_symbolic("aten::clamp")
@_beartype.beartype
def clamp(g: jit_utils.GraphContext, self, min, max):
@_beartype.beartype
def _cast_if_not_none(tensor, dtype):
if tensor is not None and not symbolic_helper._is_none(tensor):
return g.op(
@ -144,6 +146,7 @@ def clamp(g: jit_utils.GraphContext, self, min, max):
@_onnx_symbolic("aten::clamp_min")
@symbolic_helper.parse_args("v", "v")
@_beartype.beartype
def clamp_min(g: jit_utils.GraphContext, self, min):
min = g.op("Cast", min, to_i=_type_utils.JitScalarType.from_value(self).onnx_type())
if symbolic_helper._get_tensor_rank(min) == 0:
@ -159,6 +162,7 @@ def clamp_min(g: jit_utils.GraphContext, self, min):
@_onnx_symbolic("aten::clamp_max")
@symbolic_helper.parse_args("v", "v")
@_beartype.beartype
def clamp_max(g: jit_utils.GraphContext, self, max):
max = g.op("Cast", max, to_i=_type_utils.JitScalarType.from_value(self).onnx_type())
if symbolic_helper._get_tensor_rank(max) == 0:
@ -173,6 +177,7 @@ def clamp_max(g: jit_utils.GraphContext, self, max):
@_onnx_symbolic("aten::relu6")
@_beartype.beartype
def relu6(g: jit_utils.GraphContext, input):
scalar_type = _type_utils.JitScalarType.from_value(
input, _type_utils.JitScalarType.FLOAT
@ -192,11 +197,13 @@ def relu6(g: jit_utils.GraphContext, input):
# Opset 11 gather accepts negative indices
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "i", "v")
@_beartype.beartype
def select(g: jit_utils.GraphContext, self, dim, index):
return g.op("Gather", self, index, axis_i=dim)
@_onnx_symbolic("aten::index_put")
@_beartype.beartype
def index_put(
g: jit_utils.GraphContext, self, indices_list_value, values, accumulate=False
):
@ -322,6 +329,7 @@ def index_put(
@_onnx_symbolic("aten::pixel_shuffle")
@symbolic_helper.parse_args("v", "i")
@_beartype.beartype
def pixel_shuffle(g: jit_utils.GraphContext, self, upscale_factor):
rank = symbolic_helper._get_tensor_rank(self)
if rank is not None and rank != 4:
@ -357,12 +365,14 @@ def pixel_shuffle(g: jit_utils.GraphContext, self, upscale_factor):
"aten::upsample_bicubic2d",
decorate=[symbolic_helper._apply_params("upsample_bicubic2d", 4, "cubic")],
)
@_beartype.beartype
def _interpolate(name: str, dim: int, interpolate_mode: str):
return symbolic_helper._interpolate_helper(name, dim, interpolate_mode)
@_onnx_symbolic("aten::__interpolate")
@symbolic_helper.quantized_args(True, False, False, False, False, False, False)
@_beartype.beartype
def __interpolate(
g: jit_utils.GraphContext,
input,
@ -380,6 +390,7 @@ def __interpolate(
@_onnx_symbolic("aten::gather")
@symbolic_helper.parse_args("v", "i", "v", "v")
@_beartype.beartype
def gather(g: jit_utils.GraphContext, self, dim, index, sparse_grad=False):
if symbolic_helper._maybe_get_const(sparse_grad, "i"):
return symbolic_helper._unimplemented("gather", "sparse_grad == True")
@ -388,6 +399,7 @@ def gather(g: jit_utils.GraphContext, self, dim, index, sparse_grad=False):
@_onnx_symbolic("aten::scatter")
@symbolic_helper.parse_args("v", "i", "v", "v")
@_beartype.beartype
def scatter(g: jit_utils.GraphContext, self, dim, index, src):
src_type = _type_utils.JitScalarType.from_value(src)
src = symbolic_helper._maybe_get_scalar(src)
@ -409,6 +421,7 @@ def scatter(g: jit_utils.GraphContext, self, dim, index, src):
@_onnx_symbolic("aten::cumsum")
@symbolic_helper.parse_args("v", "i", "none")
@_beartype.beartype
def cumsum(g: jit_utils.GraphContext, self, dim, dtype=None):
dim_tensor = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.int))
if dtype and dtype.node().kind() != "prim::Constant":
@ -423,12 +436,14 @@ def cumsum(g: jit_utils.GraphContext, self, dim, dtype=None):
@_onnx_symbolic("aten::masked_select")
@_beartype.beartype
def masked_select(g: jit_utils.GraphContext, self, mask):
index = opset9.nonzero(g, opset9.expand_as(g, mask, self))
return g.op("GatherND", self, index)
@_onnx_symbolic("aten::masked_scatter")
@_beartype.beartype
def masked_scatter(g: jit_utils.GraphContext, self, mask, source):
index = opset9.nonzero(g, opset9.expand_as(g, mask, self))
# NOTE: source can have more elements than needed.
@ -446,6 +461,7 @@ def masked_scatter(g: jit_utils.GraphContext, self, mask, source):
@_onnx_symbolic("aten::len")
@_beartype.beartype
def _len(g: jit_utils.GraphContext, self):
if (
symbolic_helper._is_tensor_list(self)
@ -457,6 +473,7 @@ def _len(g: jit_utils.GraphContext, self):
@_onnx_symbolic("aten::__getitem_")
@_beartype.beartype
def __getitem_(g: jit_utils.GraphContext, self, i):
if symbolic_helper._is_tensor_list(self):
# SequenceAt requires that the input be a List of Tensors
@ -468,17 +485,20 @@ def __getitem_(g: jit_utils.GraphContext, self, i):
@_onnx_symbolic("aten::_set_item")
@_beartype.beartype
def _set_item(g: jit_utils.GraphContext, tensor_list, i, v):
tensor_list = g.op("SequenceErase", tensor_list, i)
return g.op("SequenceInsert", tensor_list, v, i)
@_onnx_symbolic("aten::append")
@_beartype.beartype
def append(g: jit_utils.GraphContext, self, tensor):
return g.op("SequenceInsert", self, tensor)
@_onnx_symbolic("aten::add")
@_beartype.beartype
def add(g: jit_utils.GraphContext, self, other, alpha=None):
if symbolic_helper._is_value(self) and symbolic_helper._is_tensor_list(self):
tensor_list_node = other.node()
@ -496,22 +516,26 @@ def add(g: jit_utils.GraphContext, self, other, alpha=None):
@_onnx_symbolic("aten::insert")
@_beartype.beartype
def insert(g: jit_utils.GraphContext, self, pos, tensor):
return g.op("SequenceInsert", self, tensor, pos)
@_onnx_symbolic("aten::pop")
@_beartype.beartype
def pop(g: jit_utils.GraphContext, tensor_list, dim):
return g.op("SequenceErase", tensor_list, dim)
@_onnx_symbolic("aten::Delete")
@_beartype.beartype
def Delete(g: jit_utils.GraphContext, tensor_list, dim):
return g.op("SequenceErase", tensor_list, dim)
@_onnx_symbolic("aten::cat")
@symbolic_helper.quantized_args(True)
@_beartype.beartype
def cat(g: jit_utils.GraphContext, tensor_list, dim):
if symbolic_helper._is_packed_list(tensor_list):
return opset9.cat(g, tensor_list, dim)
@ -521,6 +545,7 @@ def cat(g: jit_utils.GraphContext, tensor_list, dim):
@_onnx_symbolic("aten::stack")
@_beartype.beartype
def stack(g: jit_utils.GraphContext, tensor_list, dim):
if symbolic_helper._is_packed_list(tensor_list):
return opset9.stack(g, tensor_list, dim)
@ -531,6 +556,7 @@ def stack(g: jit_utils.GraphContext, tensor_list, dim):
@_onnx_symbolic("aten::_unique2")
@symbolic_helper.parse_args("v", "i", "i", "i")
@_beartype.beartype
def _unique2(g: jit_utils.GraphContext, self, sorted, return_inverse, return_counts):
u, indices, inverse_indices, counts = g.op(
"Unique", self, sorted_i=sorted, outputs=4
@ -540,6 +566,7 @@ def _unique2(g: jit_utils.GraphContext, self, sorted, return_inverse, return_cou
@_onnx_symbolic("aten::unique_dim")
@symbolic_helper.parse_args("v", "i", "i", "i", "i")
@_beartype.beartype
def unique_dim(
g: jit_utils.GraphContext, self, dim, sorted, return_inverse, return_counts
):
@ -551,6 +578,7 @@ def unique_dim(
@_onnx_symbolic("aten::topk")
@symbolic_helper.parse_args("v", "v", "i", "i", "i", "none")
@_beartype.beartype
def topk(g: jit_utils.GraphContext, self, k, dim, largest, sorted, out=None):
return symbolic_helper._topk_helper(
g, self, k, dim, largest=largest, sorted=sorted, out=out
@ -559,12 +587,14 @@ def topk(g: jit_utils.GraphContext, self, k, dim, largest, sorted, out=None):
@_onnx_symbolic("aten::sort")
@symbolic_helper.parse_args("v", "i", "i", "none")
@_beartype.beartype
def sort(g: jit_utils.GraphContext, self, dim, decending, out=None):
return symbolic_helper._sort_helper(g, self, dim, decending=decending, out=out)
@_onnx_symbolic("aten::argsort")
@symbolic_helper.parse_args("v", "i", "i", "none")
@_beartype.beartype
def argsort(g: jit_utils.GraphContext, self, dim, decending, out=None):
_, indices = symbolic_helper._sort_helper(
g, self, dim, decending=decending, out=out
@ -574,6 +604,7 @@ def argsort(g: jit_utils.GraphContext, self, dim, decending, out=None):
@_onnx_symbolic("aten::round")
@symbolic_helper.parse_args("v", "i")
@_beartype.beartype
def round(g: jit_utils.GraphContext, self, decimals=0):
if not symbolic_helper._is_fp(self):
return self
@ -587,6 +618,7 @@ def round(g: jit_utils.GraphContext, self, decimals=0):
@_onnx_symbolic("aten::remainder")
@_beartype.beartype
def remainder(g: jit_utils.GraphContext, input, other):
if symbolic_helper._is_fp(input) or symbolic_helper._is_fp(other):
return opset9.remainder(g, input, other)
@ -595,6 +627,7 @@ def remainder(g: jit_utils.GraphContext, input, other):
@_onnx_symbolic("aten::split")
@symbolic_helper.parse_args("v", "v", "i", "i")
@_beartype.beartype
def split(g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None):
if not symbolic_helper._is_split_static(split_size_or_sizes, _outputs):
split_out = g.op("SplitToSequence", self, split_size_or_sizes, axis_i=dim)
@ -633,12 +666,14 @@ def split(g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=No
@_onnx_symbolic("aten::split_with_sizes")
@symbolic_helper.parse_args("v", "v", "i", "i")
@_beartype.beartype
def split_with_sizes(g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None):
return split(g, self, split_sizes, dim, _outputs)
@_onnx_symbolic("aten::unbind")
@symbolic_helper.parse_args("v", "i", "i")
@_beartype.beartype
def unbind(g: jit_utils.GraphContext, self, dim=0, _outputs=None):
if _outputs is None:
return g.op(
@ -652,6 +687,7 @@ def unbind(g: jit_utils.GraphContext, self, dim=0, _outputs=None):
return opset9.unbind(g, self, dim, _outputs)
@_beartype.beartype
def _prepare_onnx_paddings(g: jit_utils.GraphContext, input, pad):
"""Generate paddings in ONNX order based on pad in pytorch.
@ -710,6 +746,7 @@ def _prepare_onnx_paddings(g: jit_utils.GraphContext, input, pad):
@_onnx_symbolic("aten::constant_pad_nd")
@_beartype.beartype
def constant_pad_nd(g: jit_utils.GraphContext, input, padding, value=None):
mode = "constant"
value = symbolic_helper._maybe_get_scalar(value)
@ -721,6 +758,7 @@ def constant_pad_nd(g: jit_utils.GraphContext, input, padding, value=None):
@_onnx_symbolic("aten::reflection_pad1d")
@_onnx_symbolic("aten::reflection_pad2d")
@_onnx_symbolic("aten::reflection_pad3d")
@_beartype.beartype
def reflection_pad(g: jit_utils.GraphContext, input, padding):
mode = "reflect"
paddings = _prepare_onnx_paddings(g, input, padding)
@ -730,6 +768,7 @@ def reflection_pad(g: jit_utils.GraphContext, input, padding):
@_onnx_symbolic("aten::replication_pad1d")
@_onnx_symbolic("aten::replication_pad2d")
@_onnx_symbolic("aten::replication_pad3d")
@_beartype.beartype
def replication_pad(g: jit_utils.GraphContext, input, padding):
mode = "edge"
paddings = _prepare_onnx_paddings(g, input, padding)
@ -737,6 +776,7 @@ def replication_pad(g: jit_utils.GraphContext, input, padding):
@_onnx_symbolic("aten::pad")
@_beartype.beartype
def pad(
g: jit_utils.GraphContext,
input: _C.Value,
@ -758,16 +798,19 @@ def pad(
@_onnx_symbolic("aten::linalg_det")
@_beartype.beartype
def linalg_det(g: jit_utils.GraphContext, self):
return g.op("Det", self)
@_onnx_symbolic("aten::logdet")
@_beartype.beartype
def logdet(g: jit_utils.GraphContext, input):
return opset9.log(g, linalg_det(g, input))
@_onnx_symbolic("aten::arange")
@_beartype.beartype
def arange(g: jit_utils.GraphContext, *args):
def _get_arange_dtype(dtype):
dtype = symbolic_helper._maybe_get_const(dtype, "i")
@ -841,6 +884,7 @@ def arange(g: jit_utils.GraphContext, *args):
@_onnx_symbolic("aten::_dim_arange")
@symbolic_helper.parse_args("v", "i")
@_beartype.beartype
def _dim_arange(g: jit_utils.GraphContext, like, dim):
like_shape = g.op("Shape", like)
stop = g.op(
@ -851,6 +895,7 @@ def _dim_arange(g: jit_utils.GraphContext, like, dim):
@_onnx_symbolic("aten::size")
@symbolic_helper.quantized_args(True, quantize_output=False)
@_beartype.beartype
def size(g: jit_utils.GraphContext, self, dim=None):
if dim is None:
return g.op("Shape", self)
@ -858,6 +903,7 @@ def size(g: jit_utils.GraphContext, self, dim=None):
@_onnx_symbolic("aten::squeeze")
@_beartype.beartype
def squeeze(g: jit_utils.GraphContext, self, dim=None):
if dim is None:
return g.op("Squeeze", self)
@ -909,6 +955,7 @@ def squeeze(g: jit_utils.GraphContext, self, dim=None):
@_onnx_symbolic("aten::unsqueeze")
@_beartype.beartype
def unsqueeze(g: jit_utils.GraphContext, self, dim):
if symbolic_helper._is_constant(dim):
dim = symbolic_helper._get_const(dim, "i", "dim")
@ -917,11 +964,13 @@ def unsqueeze(g: jit_utils.GraphContext, self, dim):
@_onnx_symbolic("aten::mm")
@_beartype.beartype
def mm(g: jit_utils.GraphContext, self, other):
return g.op("Gemm", self, other, beta_f=0.0, alpha_f=1.0)
@_onnx_symbolic("aten::index")
@_beartype.beartype
def index(g: jit_utils.GraphContext, self, index):
if symbolic_helper._is_packed_list(index):
indices = symbolic_helper._unpack_list(index)
@ -942,6 +991,7 @@ def index(g: jit_utils.GraphContext, self, index):
@_onnx_symbolic("aten::index_fill")
@_beartype.beartype
def index_fill(g: jit_utils.GraphContext, self, dim, index, value):
dim_value = symbolic_helper._parse_arg(dim, "i")
expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper(
@ -954,6 +1004,7 @@ def index_fill(g: jit_utils.GraphContext, self, dim, index, value):
@_onnx_symbolic("aten::index_copy")
@_beartype.beartype
def index_copy(g: jit_utils.GraphContext, self, dim, index, source):
dim_value = symbolic_helper._parse_arg(dim, "i")
expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper(
@ -964,6 +1015,7 @@ def index_copy(g: jit_utils.GraphContext, self, dim, index, source):
@_onnx_symbolic("aten::bitwise_right_shift")
@_onnx_symbolic("aten::__rshift_")
@_beartype.beartype
def __rshift_(g: jit_utils.GraphContext, self, other):
# make sure to cast other to self's type
# (when self is long, make sure that other is not float)
@ -998,6 +1050,7 @@ def __rshift_(g: jit_utils.GraphContext, self, other):
@_onnx_symbolic("aten::bitwise_left_shift")
@_onnx_symbolic("aten::__lshift_")
@_beartype.beartype
def __lshift_(g: jit_utils.GraphContext, self, other):
# make sure to cast other to self's type
# (when self is long, make sure that other is not float)
@ -1030,6 +1083,7 @@ def __lshift_(g: jit_utils.GraphContext, self, other):
return lshift
@_beartype.beartype
def _get_im2col_indices_along_dim(
g: jit_utils.GraphContext, input_d, kernel_size_d, dilation_d, padding_d, stride_d
):
@ -1073,6 +1127,7 @@ def _get_im2col_indices_along_dim(
return block_mask
@_beartype.beartype
def _get_im2col_padded_input(g: jit_utils.GraphContext, input, padding_h, padding_w):
# Input is always 4-D tensor (N, C, H, W)
# Padding tensor has the following format: (padding_h, padding_w)
@ -1081,6 +1136,7 @@ def _get_im2col_padded_input(g: jit_utils.GraphContext, input, padding_h, paddin
return g.op("Pad", input, pad)
@_beartype.beartype
def _get_im2col_output_shape(g: jit_utils.GraphContext, input, kernel_h, kernel_w):
batch_dim = size(g, input, g.op("Constant", value_t=torch.tensor(0)))
channel_dim = size(g, input, g.op("Constant", value_t=torch.tensor(1)))
@ -1099,6 +1155,7 @@ def _get_im2col_output_shape(g: jit_utils.GraphContext, input, kernel_h, kernel_
@_onnx_symbolic("aten::im2col")
@symbolic_helper.parse_args("v", "is", "is", "is", "is")
@_beartype.beartype
def im2col(g: jit_utils.GraphContext, input, kernel_size, dilation, padding, stride):
# Input is always 4-D tensor (N, C, H, W)
# All other args are int[2]
@ -1151,6 +1208,7 @@ def im2col(g: jit_utils.GraphContext, input, kernel_size, dilation, padding, str
@_onnx_symbolic("aten::narrow")
@_beartype.beartype
def narrow(g: jit_utils.GraphContext, input, dim, start, length):
end = g.op("Add", start, length)
return symbolic_helper._slice_helper(g, input, axes=dim, starts=start, ends=end)
@ -1159,6 +1217,7 @@ def narrow(g: jit_utils.GraphContext, input, dim, start, length):
@_onnx_symbolic("aten::flatten")
@symbolic_helper.quantized_args(True, False, False)
@symbolic_helper.parse_args("v", "i", "i")
@_beartype.beartype
def flatten(g: jit_utils.GraphContext, input, start_dim, end_dim):
dim = symbolic_helper._get_tensor_rank(input)
if dim == 1:
@ -1185,6 +1244,7 @@ def flatten(g: jit_utils.GraphContext, input, start_dim, end_dim):
@_onnx_symbolic("aten::linalg_vector_norm")
@symbolic_helper.parse_args("v", "f", "is", "b", "v")
@_beartype.beartype
def linalg_vector_norm(
g: jit_utils.GraphContext,
self,
@ -1198,6 +1258,7 @@ def linalg_vector_norm(
@_onnx_symbolic("aten::embedding_bag")
@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i")
@_beartype.beartype
def embedding_bag(
g: jit_utils.GraphContext,
embedding_matrix,
@ -1226,6 +1287,7 @@ def embedding_bag(
@_onnx_symbolic("aten::embedding_renorm")
@symbolic_helper.parse_args("v", "v", "f", "f")
@_beartype.beartype
def embedding_renorm(g: jit_utils.GraphContext, weight, indices, max_norm, norm_type):
unique_indices = g.op("Unique", indices)
partial_weight = g.op("Gather", weight, unique_indices)
@ -1264,6 +1326,7 @@ def embedding_renorm(g: jit_utils.GraphContext, weight, indices, max_norm, norm_
@_onnx_symbolic("aten::chunk")
@_beartype.beartype
def chunk(g: jit_utils.GraphContext, self, chunks, dim):
# Calculate chunk size for dynamic chunk
dim_size = g.op("Gather", g.op("Shape", self), dim, axis_i=0)
@ -1281,6 +1344,7 @@ def chunk(g: jit_utils.GraphContext, self, chunks, dim):
@_onnx_symbolic("aten::normal")
@_beartype.beartype
def normal(
g: jit_utils.GraphContext,
mean,
@ -1304,6 +1368,7 @@ def normal(
@_onnx_symbolic("aten::atleast_1d")
@_beartype.beartype
def atleast_1d(g: jit_utils.GraphContext, self: torch._C.Value):
# NOTE: If it's 0D, reshape to 1D
@ -1330,6 +1395,7 @@ def atleast_1d(g: jit_utils.GraphContext, self: torch._C.Value):
@_onnx_symbolic("aten::atleast_2d")
@_beartype.beartype
def atleast_2d(g: jit_utils.GraphContext, self: torch._C.Value):
# NOTE: If it's 0D, reshape to 2D
# If it's 1D, unsqueeze to 2D
@ -1363,6 +1429,7 @@ def atleast_2d(g: jit_utils.GraphContext, self: torch._C.Value):
@_onnx_symbolic("aten::atleast_3d")
@_beartype.beartype
def atleast_3d(g: jit_utils.GraphContext, self: torch._C.Value):
# NOTE: If it's 0D, reshape to 3D
# If it's 1D, unsqueeze to 3D
@ -1407,6 +1474,7 @@ def atleast_3d(g: jit_utils.GraphContext, self: torch._C.Value):
@_onnx_symbolic("prim::ConstantChunk")
@_beartype.beartype
def prim_constant_chunk(g: jit_utils.GraphContext, self, chunks, dim):
input_shape = g.op("Shape", self)
axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long))
@ -1428,6 +1496,7 @@ def prim_constant_chunk(g: jit_utils.GraphContext, self, chunks, dim):
@_onnx_symbolic("aten::hstack")
@_beartype.beartype
def hstack(g: jit_utils.GraphContext, tensor_list: _C.Value):
tensor_list = atleast_1d(g, tensor_list)
first_tensor = g.op(
@ -1460,6 +1529,7 @@ def hstack(g: jit_utils.GraphContext, tensor_list: _C.Value):
@_onnx_symbolic("aten::vstack")
@_beartype.beartype
def vstack(g: jit_utils.GraphContext, tensor_list: _C.Value):
tensor_list = atleast_2d(g, tensor_list)
return g.op("ConcatFromSequence", tensor_list, axis_i=0, new_axis_i=0)

View File

@ -1,5 +1,4 @@
# mypy: allow-untyped-defs
# mypy: disable-error-code=arg-type
from __future__ import annotations
import functools
@ -15,7 +14,7 @@ from torch.onnx import (
symbolic_opset9 as opset9,
utils,
)
from torch.onnx._internal import jit_utils, registration
from torch.onnx._internal import _beartype, jit_utils, registration
# EDITING THIS FILE? READ THIS FIRST!
@ -46,6 +45,7 @@ __all__ = [
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=12)
@_beartype.beartype
def _einsum_helper(g: jit_utils.GraphContext, equation, tensors):
if not tensors:
raise RuntimeError("Einsum inputs are empty.")
@ -66,6 +66,7 @@ def _einsum_helper(g: jit_utils.GraphContext, equation, tensors):
@_onnx_symbolic("aten::einsum")
@symbolic_helper.parse_args("s", "v", "is")
@_beartype.beartype
def einsum(g: jit_utils.GraphContext, equation, tensor_list, path=None):
tensors = symbolic_helper._unpack_list(tensor_list)
return _einsum_helper(g, equation, tensors)
@ -73,6 +74,7 @@ def einsum(g: jit_utils.GraphContext, equation, tensor_list, path=None):
@_onnx_symbolic("aten::outer")
@symbolic_helper.parse_args("v", "v")
@_beartype.beartype
def outer(g: jit_utils.GraphContext, input, other):
# make sure to cast other to self's type
if _type_utils.JitScalarType.from_value(
@ -86,6 +88,7 @@ def outer(g: jit_utils.GraphContext, input, other):
return _einsum_helper(g, "i,j->ij", [input, other])
@_beartype.beartype
def _dropout_returns_masked_input_and_mask(
g: jit_utils.GraphContext, input: torch._C.Value, p: float, train: bool
) -> Tuple[torch._C.Value, Optional[torch._C.Value]]:
@ -102,6 +105,7 @@ def _dropout_returns_masked_input_and_mask(
@_onnx_symbolic("aten::dropout")
@symbolic_helper.parse_args("v", "f", "b")
@_beartype.beartype
def dropout(g: jit_utils.GraphContext, input, p, train):
masked, _ = _dropout_returns_masked_input_and_mask(g, input, p, train)
return masked
@ -109,11 +113,13 @@ def dropout(g: jit_utils.GraphContext, input, p, train):
@_onnx_symbolic("aten::native_dropout")
@symbolic_helper.parse_args("v", "f", "b")
@_beartype.beartype
def native_dropout(g: jit_utils.GraphContext, input, p, train):
return _dropout_returns_masked_input_and_mask(g, input, p, train)
@_onnx_symbolic("aten::nll_loss")
@_beartype.beartype
def nll_loss(g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index):
# none reduction : onnx::Constant[value={0}]
# mean reduction : onnx::Constant[value={1}]
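
A sketch of the integer encoding the comments above describe (the sum case is not visible in this hunk; 2 is assumed per the ATen Reduction enum):

# ATen passes nll_loss's reduction argument as a constant integer:
reduction_map = {0: "none", 1: "mean", 2: "sum"}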
@ -147,6 +153,7 @@ def nll_loss(g: jit_utils.GraphContext, self, target, weight, reduction, ignore_
@_onnx_symbolic("aten::nll_loss2d")
@_beartype.beartype
def nll_loss2d(
g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index
):
@ -154,6 +161,7 @@ def nll_loss2d(
@_onnx_symbolic("aten::nll_loss_nd")
@_beartype.beartype
def nll_loss_nd(
g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index
):
@ -161,6 +169,7 @@ def nll_loss_nd(
@_onnx_symbolic("aten::cross_entropy_loss")
@_beartype.beartype
def cross_entropy_loss(
g: jit_utils.GraphContext,
self,
@ -209,6 +218,7 @@ def cross_entropy_loss(
@_onnx_symbolic("aten::binary_cross_entropy_with_logits")
@symbolic_helper.parse_args("v", "v", "v", "v", "i")
@_beartype.beartype
def binary_cross_entropy_with_logits(
g: jit_utils.GraphContext, input, target, weight, pos_weight, reduction
):
@ -253,6 +263,7 @@ def binary_cross_entropy_with_logits(
@_onnx_symbolic("aten::celu")
@_beartype.beartype
def celu(g: jit_utils.GraphContext, self, alpha):
alpha = symbolic_helper._maybe_get_const(alpha, "f")
# if the input is of type double cast it to float
@ -269,6 +280,7 @@ def celu(g: jit_utils.GraphContext, self, alpha):
@_onnx_symbolic("aten::argmax")
@symbolic_helper.parse_args("v", "v", "b")
@_beartype.beartype
def argmax(
g: jit_utils.GraphContext,
input: torch._C.Value,
@ -280,6 +292,7 @@ def argmax(
@_onnx_symbolic("aten::argmin")
@symbolic_helper.parse_args("v", "v", "b")
@_beartype.beartype
def argmin(
g: jit_utils.GraphContext,
input: torch._C.Value,
@ -290,22 +303,26 @@ def argmin(
@_onnx_symbolic("aten::pow")
@_beartype.beartype
def pow(g: jit_utils.GraphContext, self, exponent):
return g.op("Pow", self, exponent)
@_onnx_symbolic("aten::ge")
@_beartype.beartype
def ge(g: jit_utils.GraphContext, input, other):
return g.op("GreaterOrEqual", input, other)
@_onnx_symbolic("aten::le")
@_beartype.beartype
def le(g: jit_utils.GraphContext, input, other):
return g.op("LessOrEqual", input, other)
@_onnx_symbolic("aten::unfold")
@symbolic_helper.parse_args("v", "i", "v", "v")
@_beartype.beartype
def unfold(g: jit_utils.GraphContext, input, dimension, size, step):
const_size = symbolic_helper._maybe_get_const(size, "i")
const_step = symbolic_helper._maybe_get_const(step, "i")
@ -382,6 +399,7 @@ def unfold(g: jit_utils.GraphContext, input, dimension, size, step):
@_onnx_symbolic("aten::tensordot")
@symbolic_helper.parse_args("v", "v", "is", "is", "v")
@_beartype.beartype
def tensordot(g: jit_utils.GraphContext, input_a, input_b, dims_a, dims_b, out=None):
if out is not None:
symbolic_helper._unimplemented(

View File

@ -16,7 +16,7 @@ from torch.onnx import (
symbolic_opset9 as opset9,
utils,
)
from torch.onnx._internal import jit_utils, registration
from torch.onnx._internal import _beartype, jit_utils, registration
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=13)
@ -24,6 +24,7 @@ _onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=13)
@_onnx_symbolic("aten::softmax")
@symbolic_helper.parse_args("v", "i", "none")
@_beartype.beartype
def softmax(g: jit_utils.GraphContext, input, dim, dtype=None):
softmax = g.op("Softmax", input, axis_i=dim)
if dtype and dtype.node().kind() != "prim::Constant":
@ -37,6 +38,7 @@ def softmax(g: jit_utils.GraphContext, input, dim, dtype=None):
@_onnx_symbolic("aten::log_softmax")
@symbolic_helper.parse_args("v", "i", "none")
@_beartype.beartype
def log_softmax(g: jit_utils.GraphContext, input, dim, dtype=None):
return_op = g.op("LogSoftmax", input, axis_i=dim)
if dtype and dtype.node().kind() != "prim::Constant":
@ -49,6 +51,7 @@ def log_softmax(g: jit_utils.GraphContext, input, dim, dtype=None):
@_onnx_symbolic("aten::frobenius_norm")
@symbolic_helper.parse_args("v", "v", "i")
@_beartype.beartype
def frobenius_norm(g: jit_utils.GraphContext, self, dim=None, keepdim=False):
dim_val = symbolic_helper._maybe_get_const(dim, "is")
if not symbolic_helper._is_value(dim_val) and len(dim_val) == 0:
@ -60,6 +63,7 @@ def frobenius_norm(g: jit_utils.GraphContext, self, dim=None, keepdim=False):
@_onnx_symbolic("aten::split")
@symbolic_helper.parse_args("v", "v", "i", "i")
@_beartype.beartype
def split(g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None):
if not symbolic_helper._is_split_static(split_size_or_sizes, _outputs):
split_out = g.op("SplitToSequence", self, split_size_or_sizes, axis_i=dim)
@ -116,11 +120,13 @@ def split(g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=No
@_onnx_symbolic("aten::split_with_sizes")
@_beartype.beartype
def split_with_sizes(g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None):
return split(g, self, split_sizes, dim, _outputs)
@_onnx_symbolic("aten::unsafe_split")
@_beartype.beartype
def unsafe_split(
g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None
):
@ -128,6 +134,7 @@ def unsafe_split(
@_onnx_symbolic("aten::unsafe_split_with_sizes")
@_beartype.beartype
def unsafe_split_with_sizes(
g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None
):
@ -136,6 +143,7 @@ def unsafe_split_with_sizes(
@_onnx_symbolic("aten::tensor_split")
@symbolic_helper.parse_args("v", "v", "i", "i")
@_beartype.beartype
def tensor_split(
g: jit_utils.GraphContext, self, indices_or_sections, dim, _outputs=None
):
@ -274,6 +282,7 @@ def tensor_split(
@_onnx_symbolic("aten::unbind")
@symbolic_helper.parse_args("v", "i", "i")
@_beartype.beartype
def unbind(g: jit_utils.GraphContext, self, dim=0, _outputs=None):
if _outputs is None:
return g.op(
@ -296,12 +305,14 @@ def unbind(g: jit_utils.GraphContext, self, dim=0, _outputs=None):
@_onnx_symbolic("aten::nonzero_numpy")
# Emitted from `torch.nonzero(x, as_tuple=True)`
@_beartype.beartype
def nonzero_numpy(g: jit_utils.GraphContext, input, _outputs=None):
return unbind(g, opset9.nonzero(g, input), 1, _outputs=_outputs)
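
A minimal example of the call the comment above refers to (illustrative values):

import torch

x = torch.tensor([[0, 5], [7, 0]])
rows, cols = torch.nonzero(x, as_tuple=True)  # traces to aten::nonzero_numpy
# rows == tensor([0, 1]); cols == tensor([1, 0])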
@_onnx_symbolic("aten::where")
@symbolic_helper.parse_args("v", "v", "v", "i")
@_beartype.beartype
def where(g: jit_utils.GraphContext, condition, self=None, other=None, _outputs=None):
# Assumes that torch.where's first argument takes only Bool and Byte tensors.
if not symbolic_helper._is_bool(condition):
@ -316,6 +327,7 @@ def where(g: jit_utils.GraphContext, condition, self=None, other=None, _outputs=
@_onnx_symbolic("aten::fake_quantize_per_channel_affine")
@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i")
@_beartype.beartype
def fake_quantize_per_channel_affine(
g: jit_utils.GraphContext,
inputs,
@ -351,6 +363,7 @@ def fake_quantize_per_channel_affine(
@_onnx_symbolic("aten::fake_quantize_per_tensor_affine")
@symbolic_helper.parse_args("v", "v", "v", "i", "i")
@_beartype.beartype
def fake_quantize_per_tensor_affine(
g: jit_utils.GraphContext,
inputs,
@ -387,7 +400,9 @@ def fake_quantize_per_tensor_affine(
return g.op("DequantizeLinear", quantized, scale, zero_point)
@_beartype.beartype
def _reduce_op_symbolic(onnx_op_name):
@_beartype.beartype
def symbolic(g, self, dim=None, keepdim=None):
self = symbolic_helper._maybe_cast_reduce_op_input(g, self)
if dim is None:
@ -404,12 +419,15 @@ def _reduce_op_symbolic(onnx_op_name):
"aten::sum",
decorate=[symbolic_helper._apply_params("ReduceSum", "sum")],
)
@_beartype.beartype
def _reduce_with_dtype(onnx_op, name):
symbolic = _reduce_op_symbolic(onnx_op)
@symbolic_helper._overload_by_arg_count
@_beartype.beartype
def reduce(g, *args, **kwargs):
@symbolic_helper.parse_args("v", "none")
@_beartype.beartype
def reduce_nodim(g, self, dtype):
dtype_onnx = None
if dtype.node().kind() == "onnx::Constant":
@ -428,6 +446,7 @@ def _reduce_with_dtype(onnx_op, name):
return result
@symbolic_helper.parse_args("v", "v", "i", "none")
@_beartype.beartype
def reduce_dim(g, self, dim, keepdim, dtype):
dtype_onnx = None
if dtype.node().kind() == "onnx::Constant":
@ -454,6 +473,7 @@ def _reduce_with_dtype(onnx_op, name):
# https://github.com/microsoft/onnxscript/blob/6b1b81700b4523f31d8c6d3321e5d8ef5d42b764/onnxscript/function_libs/torch_aten/ops/core.py#L6097
# NOTE: Supporting aten::unflatten before opset13 needs a helper function to adjust for ONNX op changes in Concat, Slice, ...
@_onnx_symbolic("aten::unflatten")
@_beartype.beartype
def unflatten(g: jit_utils.GraphContext, input, dim, unflattened_size):
input_dim = symbolic_helper._get_tensor_rank(input)
if input_dim is None:
@ -498,6 +518,7 @@ def unflatten(g: jit_utils.GraphContext, input, dim, unflattened_size):
@_onnx_symbolic("aten::unsafe_chunk")
@symbolic_helper.parse_args("v", "i", "i", "i")
@_beartype.beartype
def unsafe_chunk(g: jit_utils.GraphContext, self, chunks, dim, _outputs=None):
if _outputs is None:
return g.op(
@ -525,6 +546,7 @@ def unsafe_chunk(g: jit_utils.GraphContext, self, chunks, dim, _outputs=None):
@_onnx_symbolic("aten::tile")
@_beartype.beartype
def tile(g: jit_utils.GraphContext, self, dims):
self_shape = g.op("Shape", self)
self_rank = g.op("Size", self_shape)
@ -580,6 +602,7 @@ def tile(g: jit_utils.GraphContext, self, dims):
@_onnx_symbolic("aten::repeat_interleave")
@_beartype.beartype
def repeat_interleave(
g: jit_utils.GraphContext, self, repeats, dim=None, output_size=None
):
@ -722,6 +745,7 @@ def repeat_interleave(
@_onnx_symbolic("aten::diagonal")
@symbolic_helper.parse_args("v", "i", "i", "i")
@_beartype.beartype
def diagonal(g: jit_utils.GraphContext, self, offset, dim1, dim2):
rank = symbolic_helper._get_tensor_rank(self)
# Replace negative indexing when rank is known
@ -844,6 +868,7 @@ def diagonal(g: jit_utils.GraphContext, self, offset, dim1, dim2):
@_onnx_symbolic("quantized::linear")
@_beartype.beartype
def quantized_linear(
g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point
):
@ -860,6 +885,7 @@ def quantized_linear(
@_onnx_symbolic("quantized::linear_relu")
@_beartype.beartype
def quantized_linear_relu(
g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point
):
@ -877,6 +903,7 @@ def quantized_linear_relu(
@_onnx_symbolic("quantized::conv1d_relu")
@_beartype.beartype
def quantized_conv1d_relu(
g: jit_utils.GraphContext,
q_input,
@ -903,6 +930,7 @@ def quantized_conv1d_relu(
@_onnx_symbolic("quantized::conv2d_relu")
@_beartype.beartype
def quantized_conv2d_relu(
g: jit_utils.GraphContext,
q_input,
@ -929,6 +957,7 @@ def quantized_conv2d_relu(
@_onnx_symbolic("quantized::conv3d_relu")
@_beartype.beartype
def quantized_conv3d_relu(
g: jit_utils.GraphContext,
q_input,
@ -955,6 +984,7 @@ def quantized_conv3d_relu(
@_onnx_symbolic("quantized::conv1d")
@_beartype.beartype
def quantized_conv1d(
g: jit_utils.GraphContext,
q_input,
@ -980,6 +1010,7 @@ def quantized_conv1d(
@_onnx_symbolic("quantized::conv2d")
@_beartype.beartype
def quantized_conv2d(
g: jit_utils.GraphContext,
q_input,
@ -1005,6 +1036,7 @@ def quantized_conv2d(
@_onnx_symbolic("quantized::conv3d")
@_beartype.beartype
def quantized_conv3d(
g: jit_utils.GraphContext,
q_input,
@ -1030,6 +1062,7 @@ def quantized_conv3d(
@_onnx_symbolic("quantized::conv_transpose1d")
@_beartype.beartype
def quantized_conv_transpose1d(
g: jit_utils.GraphContext,
q_input,
@ -1058,6 +1091,7 @@ def quantized_conv_transpose1d(
@_onnx_symbolic("quantized::conv_transpose2d")
@_beartype.beartype
def quantized_conv_transpose2d(
g: jit_utils.GraphContext,
q_input,
@ -1086,6 +1120,7 @@ def quantized_conv_transpose2d(
@_onnx_symbolic("quantized::conv_transpose3d")
@_beartype.beartype
def quantized_conv_transpose3d(
g: jit_utils.GraphContext,
q_input,

View File

@ -1,5 +1,4 @@
# mypy: allow-untyped-defs
# mypy: disable-error-code=arg-type
"""This file exports ONNX ops for opset 14.
Note [ONNX operators that are added/updated in opset 14]
@ -24,7 +23,7 @@ from typing import Optional
import torch
from torch.onnx import _constants, _type_utils, symbolic_helper
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import jit_utils, registration
from torch.onnx._internal import _beartype, jit_utils, registration
__all__ = [
"hardswish",
@ -41,16 +40,19 @@ _onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=14)
@_onnx_symbolic("aten::hardswish")
@symbolic_helper.parse_args("v")
@_beartype.beartype
def hardswish(g: jit_utils.GraphContext, self):
return g.op("HardSwish", self)
@_onnx_symbolic("aten::tril")
@_beartype.beartype
def tril(g: jit_utils.GraphContext, self, diagonal, out=None):
return g.op("Trilu", self, diagonal, upper_i=0)
@_onnx_symbolic("aten::triu")
@_beartype.beartype
def triu(g: jit_utils.GraphContext, self, diagonal, out=None):
return g.op("Trilu", self, diagonal, upper_i=1)
@ -58,6 +60,7 @@ def triu(g: jit_utils.GraphContext, self, diagonal, out=None):
@_onnx_symbolic("aten::reshape")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "v")
@_beartype.beartype
def reshape(g: jit_utils.GraphContext, self, shape):
# NOTE: Due to a bug in ORT https://github.com/microsoft/onnxruntime/issues/10664,
# Reshape export cannot utilize the new allowzero attribute introduced in opset 14.
@ -66,6 +69,7 @@ def reshape(g: jit_utils.GraphContext, self, shape):
@_onnx_symbolic("aten::batch_norm")
@symbolic_helper.parse_args("v", "v", "v", "v", "v", "i", "f", "f", "i")
@_beartype.beartype
def batch_norm(
g: jit_utils.GraphContext,
input,
@ -120,6 +124,7 @@ def batch_norm(
@_onnx_symbolic("quantized::hardswish")
@_beartype.beartype
def quantized_hardswish(g: jit_utils.GraphContext, x, op_scale, op_zero_point):
x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
@ -134,6 +139,7 @@ def quantized_hardswish(g: jit_utils.GraphContext, x, op_scale, op_zero_point):
# NOTE: Need op.Trilu
@_onnx_symbolic("aten::scaled_dot_product_attention")
@symbolic_helper.parse_args("v", "v", "v", "v", "f", "b", "v")
@_beartype.beartype
def scaled_dot_product_attention(
g: jit_utils.GraphContext,
query: torch._C.Value,
@ -206,6 +212,7 @@ def scaled_dot_product_attention(
return g.op("MatMul", attn_weight, value)
@_beartype.beartype
def _attention_scale(
g: jit_utils.GraphContext, query: torch._C.Value
) -> torch._C.Value:
@ -242,6 +249,7 @@ def _attention_scale(
return scale
@_beartype.beartype
def _causal_attention_mask(
g: jit_utils.GraphContext, query: torch._C.Value, key: torch._C.Value
) -> torch._C.Value:

View File

@ -31,12 +31,13 @@ import functools
import torch
from torch import _C
from torch.onnx import symbolic_helper, symbolic_opset9 as opset9
from torch.onnx._internal import jit_utils, registration
from torch.onnx._internal import _beartype, jit_utils, registration
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=15)
@_onnx_symbolic("aten::__is_")
@_beartype.beartype
def aten__is_(g: jit_utils.GraphContext, self, other):
if symbolic_helper._is_none(other):
if isinstance(self.type(), _C.OptionalType):
@ -49,11 +50,13 @@ def aten__is_(g: jit_utils.GraphContext, self, other):
@_onnx_symbolic("aten::__isnot_")
@opset9.wrap_logical_op_with_negation # type: ignore[has-type]
@_beartype.beartype
def aten__isnot_(g: jit_utils.GraphContext, self, other):
return aten__is_(g, self, other)
@_onnx_symbolic("aten::bernoulli")
@_beartype.beartype
def bernoulli(g: jit_utils.GraphContext, input, p=None, generator=None, out=None):
if out is not None and not symbolic_helper._is_none(out):
symbolic_helper._unimplemented(
@ -69,6 +72,7 @@ def bernoulli(g: jit_utils.GraphContext, input, p=None, generator=None, out=None
@_onnx_symbolic("prim::unchecked_cast")
@_beartype.beartype
def prim_unchecked_cast(g: jit_utils.GraphContext, self):
# exists to refine the type of the Value
# if x is Optional[Tensor], unchecked_cast will cast

View File

@ -34,7 +34,7 @@ from torch.nn.functional import (
GRID_SAMPLE_PADDING_MODES,
)
from torch.onnx import _type_utils, errors, symbolic_helper, utils
from torch.onnx._internal import jit_utils, registration
from torch.onnx._internal import _beartype, jit_utils, registration
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=16)
@ -43,6 +43,7 @@ _onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=16)
# Because `torch.nn.functional.grid_sample` calls `torch.grid_sampler`.
@_onnx_symbolic("aten::grid_sampler")
@symbolic_helper.parse_args("v", "v", "i", "i", "b")
@_beartype.beartype
def grid_sampler(
g: jit_utils.GraphContext,
input,
@ -68,6 +69,7 @@ def grid_sampler(
@_onnx_symbolic("aten::scatter_add")
@symbolic_helper.parse_args("v", "i", "v", "v")
@_beartype.beartype
def scatter_add(g: jit_utils.GraphContext, self, dim, index, src):
src_type = _type_utils.JitScalarType.from_value(
src, _type_utils.JitScalarType.UNDEFINED
@ -115,6 +117,7 @@ def scatter_add(g: jit_utils.GraphContext, self, dim, index, src):
@_onnx_symbolic("aten::scatter_reduce")
@symbolic_helper.parse_args("v", "i", "v", "v", "s", "b")
@_beartype.beartype
def scatter_reduce(
g: jit_utils.GraphContext,
self: torch._C.Value,

View File

@ -1,5 +1,4 @@
# mypy: allow-untyped-defs
# mypy: disable-error-code=arg-type
"""This file exports ONNX ops for opset 17.
Note [ONNX Operators that are added/updated in opset 17]
@ -23,7 +22,7 @@ from typing import Optional, Sequence
import torch
from torch import _C
from torch.onnx import _type_utils, errors, symbolic_helper
from torch.onnx._internal import jit_utils, registration
from torch.onnx._internal import _beartype, jit_utils, registration
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md
@ -97,6 +96,7 @@ def _compute_edge_sizes(n_fft, window_size):
@_onnx_symbolic("aten::stft")
@symbolic_helper.parse_args("v", "i", "i", "i", "v", "b", "b", "b")
@_beartype.beartype
def stft(
g: jit_utils.GraphContext,
input: _C.Value,
@ -154,7 +154,7 @@ def stft(
signal,
g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)),
)
elif signal_rank is None or signal_rank > 2:
elif signal_rank > 2:
raise errors.SymbolicValueError(
msg="STFT can only take inputs of 1 [signal] or 2 [batch, signal] dimensions. "
f"Current rank of signal is {signal_rank}, please reduce it.",

View File

@ -25,7 +25,7 @@ from typing import List, Optional, Sequence, Tuple
import torch
from torch import _C
from torch.onnx import _type_utils, symbolic_helper, symbolic_opset9 as opset9
from torch.onnx._internal import jit_utils, registration
from torch.onnx._internal import _beartype, jit_utils, registration
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py
@ -39,6 +39,7 @@ _onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=18)
@_onnx_symbolic("aten::__and_")
@_onnx_symbolic("aten::bitwise_and")
@_beartype.beartype
def __and_(g: jit_utils.GraphContext, self, other):
# do type promotion (scalars don't seem to apply)
args = [self, other]
@ -56,6 +57,7 @@ def __and_(g: jit_utils.GraphContext, self, other):
@_onnx_symbolic("aten::col2im")
@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is")
@_beartype.beartype
def col2im(
g,
input: _C.Value,
@ -103,6 +105,7 @@ def col2im(
)
],
)
@_beartype.beartype
def _reduce_with_dtype(onnx_op: str, name: str, allow_multi_dim_support: bool = True):
return symbolic_helper._reduce_with_dtype_helper(
onnx_op, name, allow_multi_dim_support
@ -112,6 +115,7 @@ def _reduce_with_dtype(onnx_op: str, name: str, allow_multi_dim_support: bool =
@_onnx_symbolic("aten::native_layer_norm")
@symbolic_helper.quantized_args(True, False, False, False)
@symbolic_helper.parse_args("v", "is", "v", "v", "f")
@_beartype.beartype
def _native_layer_norm(
g: jit_utils.GraphContext,
input: _C.Value,
@ -125,6 +129,7 @@ def _native_layer_norm(
@_onnx_symbolic("aten::glu")
@symbolic_helper.parse_args("v", "i")
@_beartype.beartype
def _glu(g: jit_utils.GraphContext, input, dim):
dim_size = symbolic_helper._get_tensor_dim_size(input, dim)
if dim_size is not None:
@ -138,24 +143,28 @@ def _glu(g: jit_utils.GraphContext, input, dim):
# torch.max (same for torch.min) actually has two interfaces smashed together:
# torch.max(x, dim, keepdim) and torch.max(x, y)
# TODO(justinchuby): Support multiple quantized args in output
@_beartype.beartype
def max(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
return symbolic_helper._max_helper(g, self, dim_or_y, keepdim)
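
The two call forms the comment above refers to, in eager mode:

import torch

x = torch.randn(3, 4)
values, indices = torch.max(x, dim=1, keepdim=True)  # reduction overload
elementwise = torch.max(x, torch.zeros_like(x))      # binary elementwise overload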
@_onnx_symbolic("aten::maximum")
@symbolic_helper.quantized_args(True, True)
@_beartype.beartype
def maximum(g: jit_utils.GraphContext, input, other):
return max(g, input, dim_or_y=other)
@_onnx_symbolic("aten::min")
# TODO(justinchuby): Support multiple quantized args in output
@_beartype.beartype
def min(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
return symbolic_helper._min_helper(g, self, dim_or_y, keepdim)
@_onnx_symbolic("aten::minimum")
@symbolic_helper.quantized_args(True, True)
@_beartype.beartype
def minimum(g: jit_utils.GraphContext, input, other):
return min(g, input, dim_or_y=other)
@ -163,6 +172,7 @@ def minimum(g: jit_utils.GraphContext, input, other):
@_onnx_symbolic("aten::amax")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "is", "i")
@_beartype.beartype
def amax(g: jit_utils.GraphContext, self, dim, keepdim):
axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long))
return g.op("ReduceMax", self, axes, keepdims_i=keepdim)
@ -171,6 +181,7 @@ def amax(g: jit_utils.GraphContext, self, dim, keepdim):
@_onnx_symbolic("aten::amin")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "is", "i")
@_beartype.beartype
def amin(g: jit_utils.GraphContext, self, dim, keepdim):
axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long))
return g.op("ReduceMin", self, axes, keepdims_i=keepdim)
@ -179,6 +190,7 @@ def amin(g: jit_utils.GraphContext, self, dim, keepdim):
@_onnx_symbolic("aten::aminmax")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "v", "i")
@_beartype.beartype
def aminmax(g: jit_utils.GraphContext, self, dim, keepdim):
if not symbolic_helper._is_none(dim):
dim = symbolic_helper._get_const(dim, "i", "dim")
@ -193,6 +205,7 @@ def aminmax(g: jit_utils.GraphContext, self, dim, keepdim):
@_onnx_symbolic("aten::var_mean")
@_beartype.beartype
def _var_mean(g: jit_utils.GraphContext, input, *args):
if len(args) == 1:
return symbolic_helper._var_mean_helper(g, input, None, args[0], None)
@ -202,6 +215,7 @@ def _var_mean(g: jit_utils.GraphContext, input, *args):
@_onnx_symbolic("aten::logsumexp")
@symbolic_helper.parse_args("v", "is", "i")
@_beartype.beartype
def _logsumexp(g: jit_utils.GraphContext, input, dim, keepdim):
if dim is None:
return g.op("ReduceLogSumExp", input, keepdims_i=0)
@ -212,6 +226,7 @@ def _logsumexp(g: jit_utils.GraphContext, input, dim, keepdim):
@_onnx_symbolic("aten::linalg_matrix_norm")
@symbolic_helper.parse_args("v", "v", "is", "b", "v")
@_beartype.beartype
def _linalg_matrix_norm(
g: jit_utils.GraphContext,
self: torch._C.Value,
@ -225,6 +240,7 @@ def _linalg_matrix_norm(
@_onnx_symbolic("aten::embedding_bag")
@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i")
@_beartype.beartype
def embedding_bag(
g: jit_utils.GraphContext,
embedding_matrix,
@ -253,6 +269,7 @@ def embedding_bag(
@_onnx_symbolic("aten::linalg_vector_norm")
@symbolic_helper.parse_args("v", "f", "is", "b", "v")
@_beartype.beartype
def linalg_vector_norm(
g: jit_utils.GraphContext,
self: torch._C.Value,

View File

@ -27,7 +27,7 @@ import torch.nn.functional as F
from torch import _C
from torch.onnx import symbolic_helper
from torch.onnx._internal import jit_utils, registration
from torch.onnx._internal import _beartype, jit_utils, registration
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py
@ -46,6 +46,7 @@ _onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=20)
@_onnx_symbolic("aten::grid_sampler")
@symbolic_helper.parse_args("v", "v", "i", "i", "b")
@_beartype.beartype
def _grid_sampler(
g: jit_utils.GraphContext,
input: _C.Value,
@ -70,6 +71,7 @@ def _grid_sampler(
@_onnx_symbolic("aten::affine_grid_generator")
@symbolic_helper.parse_args("v", "v", "b")
@_beartype.beartype
def _affine_grid_generator(
g: jit_utils.GraphContext,
theta: _C.Value,
@ -86,5 +88,6 @@ def _affine_grid_generator(
@_onnx_symbolic("aten::gelu")
@symbolic_helper.parse_args("v", "s")
@_beartype.beartype
def gelu(g: jit_utils.GraphContext, self: _C.Value, approximate: str = "none"):
return g.op("Gelu", self, approximate_s=approximate)

View File

@ -165,7 +165,7 @@ def _try_cast_integer_to_float(g: jit_utils.GraphContext, *args):
if arg0_type != _type_utils.JitScalarType.UNDEFINED:
old_type = arg0_type
if old_type not in floating_scalar_types:
old_type = old_type.scalar_name() # type: ignore[assignment]
old_type = old_type.scalar_name()
args = tuple(
g.op("Cast", arg, to_i=_C_onnx.TensorProtoDataType.FLOAT)
for arg in args

File diff suppressed because it is too large

View File

@ -9,6 +9,7 @@ from __future__ import annotations
import contextlib
import copy
import inspect
import io
import re
import typing
import warnings
@ -37,13 +38,17 @@ from torch.onnx import ( # noqa: F401
_constants,
_exporter_states,
errors,
symbolic_caffe2,
symbolic_helper,
)
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import diagnostics, jit_utils, onnx_proto_utils, registration
if typing.TYPE_CHECKING:
import io
from torch.onnx._internal import (
_beartype,
diagnostics,
jit_utils,
onnx_proto_utils,
registration,
)
__all__ = [
"is_in_onnx_export",
@ -73,6 +78,7 @@ _params_dict = {} # type: ignore[var-annotated]
@contextlib.contextmanager
@_beartype.beartype
def select_model_mode_for_export(model, mode: _C_onnx.TrainingMode):
r"""A context manager to temporarily set the training mode of ``model``
to ``mode``, resetting it when we exit the with-block.
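
A minimal usage sketch (assuming a plain nn.Module):

import torch
from torch.onnx.utils import select_model_mode_for_export

model = torch.nn.Linear(2, 3)
model.train()
with select_model_mode_for_export(model, torch.onnx.TrainingMode.EVAL):
    assert not model.training  # temporarily in eval mode for export
assert model.training  # original mode restored on exit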
@ -121,6 +127,7 @@ def select_model_mode_for_export(model, mode: _C_onnx.TrainingMode):
@contextlib.contextmanager
@_beartype.beartype
def disable_apex_o2_state_dict_hook(
model: Union[torch.nn.Module, torch.jit.ScriptFunction]
):
@ -154,6 +161,7 @@ def disable_apex_o2_state_dict_hook(
@contextlib.contextmanager
@_beartype.beartype
def setup_onnx_logging(verbose: bool):
is_originally_enabled = torch.onnx.is_onnx_log_enabled()
if is_originally_enabled or verbose:
@ -166,6 +174,7 @@ def setup_onnx_logging(verbose: bool):
@contextlib.contextmanager
@_beartype.beartype
def exporter_context(model, mode: _C_onnx.TrainingMode, verbose: bool):
with select_model_mode_for_export(
model, mode
@ -559,6 +568,7 @@ def export(
return None
@_beartype.beartype
def _is_constant_tensor_list(node):
if node.kind() != "prim::Constant":
return False
@ -573,6 +583,7 @@ def _is_constant_tensor_list(node):
# get generated in constant prop. So we split them back into prim::ListConstructs
@_beartype.beartype
def _split_tensor_list_constants(g, block):
for node in block.nodes():
for subblock in node.blocks():
@ -595,6 +606,7 @@ def _split_tensor_list_constants(g, block):
node.output().replaceAllUsesWith(lc)
@_beartype.beartype
def _optimize_graph(
graph: _C.Graph,
operator_export_type: _C_onnx.OperatorExportTypes,
@ -705,6 +717,7 @@ def _optimize_graph(
return graph
@_beartype.beartype
def warn_on_static_input_change(input_states):
"""Warns that changes to input dictionaries and strings won't take effect in the traced ONNX graph.
@ -734,11 +747,13 @@ def warn_on_static_input_change(input_states):
warnings.warn(warning)
@_beartype.beartype
def _resolve_args_by_export_type(arg_name, arg_value, operator_export_type):
"""Resolves the arguments that are ignored when export_type != operator_export_type.ONNX."""
return arg_value
@_beartype.beartype
def _decide_keep_init_as_input(
keep_initializers_as_inputs: Optional[bool],
operator_export_type: _C_onnx.OperatorExportTypes,
@ -782,12 +797,14 @@ def _decide_keep_init_as_input(
return val_keep_init_as_ip
@_beartype.beartype
def _decide_add_node_names(add_node_names, operator_export_type):
return _resolve_args_by_export_type(
"add_node_names", add_node_names, operator_export_type
)
@_beartype.beartype
def _decide_constant_folding(do_constant_folding, operator_export_type, training):
do_constant_folding = _resolve_args_by_export_type(
"do_constant_folding", do_constant_folding, operator_export_type
@ -806,6 +823,7 @@ def _decide_constant_folding(do_constant_folding, operator_export_type, training
return do_constant_folding
@_beartype.beartype
def _signature(model) -> inspect.Signature:
should_be_callable = getattr(model, "forward", model)
if callable(should_be_callable):
@ -813,6 +831,7 @@ def _signature(model) -> inspect.Signature:
raise ValueError("model has no forward method and is not callable")
@_beartype.beartype
def _decide_input_format(model, args):
try:
sig = _signature(model)
@ -851,6 +870,7 @@ def _decide_input_format(model, args):
return args
@_beartype.beartype
def _from_dynamic_axes_to_dynamic_shapes(
model,
dynamic_axes: Optional[
@ -909,6 +929,7 @@ def _from_dynamic_axes_to_dynamic_shapes(
return dynamic_shapes
@_beartype.beartype
def _trace(func, args, operator_export_type, return_outs=False):
# Special case for common case of passing a single Tensor
if isinstance(args, torch.Tensor):
@ -929,6 +950,7 @@ def _trace(func, args, operator_export_type, return_outs=False):
return trace_graph
@_beartype.beartype
def _trace_and_get_graph_from_model(model, args):
# A basic sanity check: make sure the state_dict keys are the same
# before and after running the model. Fail fast!
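
A simplified sketch of that check (a hypothetical reduction of what the function does; the real code traces the model at this point):

import torch

model = torch.nn.Linear(2, 3)
args = (torch.randn(1, 2),)
keys_before = set(model.state_dict().keys())
_ = model(*args)  # stand-in for the tracing step
assert set(model.state_dict().keys()) == keys_before, "state_dict keys changed"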
@ -959,6 +981,7 @@ def _trace_and_get_graph_from_model(model, args):
return trace_graph, torch_out
@_beartype.beartype
def _get_param_count_list(method_graph, args_params):
param_count_list = []
for input_, arg_params_ in zip(method_graph.inputs(), args_params):
@ -971,9 +994,11 @@ def _get_param_count_list(method_graph, args_params):
return param_count_list
@_beartype.beartype
def _check_flatten_did_not_remove(original, jit_flattened):
"""torch.jit._flatten removes None. Check if it did so in this case."""
@_beartype.beartype
def flatten(x):
if isinstance(x, (list, tuple)):
for inner in x:
@ -1046,6 +1071,7 @@ def _create_jit_graph(
return graph, params, torch_out, None
@_beartype.beartype
def _get_named_param_dict(graph, params):
input_and_param_names = [val.debugName() for val in graph.inputs()]
param_names = input_and_param_names[len(input_and_param_names) - len(params) :]
@ -1053,6 +1079,7 @@ def _get_named_param_dict(graph, params):
return _params_dict
@_beartype.beartype
def _get_example_outputs(model, args):
input_args = copy.deepcopy(args)
input_kwargs = {}
@ -1077,6 +1104,7 @@ _qtype_vtype_map = {
}
@_beartype.beartype
def unpack_quantized_tensor(value, cast_onnx_accepted=True):
if isinstance(value, torch.Tensor) and value.dtype in _qtype_vtype_map:
q_value_dequantize = value.dequantize()
@ -1097,6 +1125,7 @@ def unpack_quantized_tensor(value, cast_onnx_accepted=True):
return (value,)
@_beartype.beartype
def _pre_trace_quant_model(model, args):
r"""Returns `torch.jit.trace(model, args)` if model is quantized. Otherwise do nothing and return
original model.
@ -1110,6 +1139,7 @@ def _pre_trace_quant_model(model, args):
return model
@_beartype.beartype
def _model_to_graph(
model,
args,
@ -1245,6 +1275,7 @@ def _model_to_graph(
return graph, params_dict, torch_out
@_beartype.beartype
@torch._disable_dynamo
def export_to_pretty_string(
model,
@ -1322,6 +1353,7 @@ def export_to_pretty_string(
)
@_beartype.beartype
def unconvertible_ops(
model,
args,
@ -1390,6 +1422,7 @@ def unconvertible_ops(
return graph, unsupported_ops
@_beartype.beartype
def _setup_trace_module_map(
model: Union[torch.nn.Module, torch.jit.ScriptModule],
export_modules_as_functions: Union[bool, Collection[Type[torch.nn.Module]]],
@ -1467,11 +1500,13 @@ def _setup_trace_module_map(
return module_typenames
@_beartype.beartype
def _reset_trace_module_map():
torch.jit._trace._trace_module_map = None
_C._jit_pass_onnx_clear_scope_records()
@_beartype.beartype
def _get_module_attributes(module):
annotations = typing.get_type_hints(type(module))
base_m_annotations = typing.get_type_hints(torch.nn.Module)
@ -1496,6 +1531,7 @@ def _get_module_attributes(module):
return attrs
@_beartype.beartype
def _export(
model,
args,
@ -1704,6 +1740,7 @@ def _export(
return torch_out
@_beartype.beartype
def _apply_friendly_debug_names(graph, params):
for n in graph.nodes():
for v in n.inputs():
@ -1716,7 +1753,9 @@ def _apply_friendly_debug_names(graph, params):
params[new_name] = params.pop(old_name)
@_beartype.beartype
def _set_input_and_output_names(graph, input_names, output_names):
@_beartype.beartype
def set_names(node_list, name_list, descriptor):
if name_list is None:
return
@ -1747,6 +1786,7 @@ def _set_input_and_output_names(graph, input_names, output_names):
set_names(list(graph.outputs()), output_names, "output")
@_beartype.beartype
def _run_symbolic_method(g, op_name, symbolic_fn, args):
r"""
This trampoline function gets invoked for every symbolic method
@ -1772,18 +1812,22 @@ def _run_symbolic_method(g, op_name, symbolic_fn, args):
raise
@_beartype.beartype
def _add_block(node: _C.Node) -> _C.Block:
return node.addBlock()
@_beartype.beartype
def _add_input_to_block(block: _C.Block):
return block.addInputToBlock() # type: ignore[attr-defined]
@_beartype.beartype
def _add_output_to_block(block: _C.Block, value: _C.Value) -> int:
return block.registerOutput(value)
@_beartype.beartype
def _should_aten_fallback(
name: str, opset_version: int, operator_export_type: _C_onnx.OperatorExportTypes
):
@ -1805,6 +1849,7 @@ def _should_aten_fallback(
return False
@_beartype.beartype
def _need_symbolic_context(symbolic_fn: Callable) -> bool:
"""Checks if the first argument to symbolic_fn is annotated as type `torch.onnx.SymbolicContext`."""
params = tuple(inspect.signature(symbolic_fn).parameters.values())
@ -1820,6 +1865,7 @@ def _need_symbolic_context(symbolic_fn: Callable) -> bool:
return issubclass(param_type, _exporter_states.SymbolicContext)
@_beartype.beartype
def _symbolic_context_handler(symbolic_fn: Callable) -> Callable:
"""Decorator that provides the symbolic context to the symbolic function if needed."""
if _need_symbolic_context(symbolic_fn):
@ -1844,6 +1890,7 @@ def _symbolic_context_handler(symbolic_fn: Callable) -> Callable:
return symbolic_fn
@_beartype.beartype
def _get_aten_op_overload_name(n: _C.Node) -> str:
# Returns `overload_name` attribute to ATen ops on non-Caffe2 builds
schema = n.schema()
@ -1852,6 +1899,7 @@ def _get_aten_op_overload_name(n: _C.Node) -> str:
return _C.parse_schema(schema).overload_name
@_beartype.beartype
def _run_symbolic_function(
graph: _C.Graph,
block: _C.Block,
@ -1964,6 +2012,7 @@ def _run_symbolic_function(
raise
@_beartype.beartype
def _verify_custom_op_name(symbolic_name: str):
if not re.match(r"^[a-zA-Z0-9-_]+::[a-zA-Z-_]+[a-zA-Z0-9-_]*$", symbolic_name):
raise errors.OnnxExporterError(
@ -1980,6 +2029,7 @@ def _verify_custom_op_name(symbolic_name: str):
)
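
Examples of names the pattern above accepts and rejects (derived directly from the regex shown):

import re

_PATTERN = r"^[a-zA-Z0-9-_]+::[a-zA-Z-_]+[a-zA-Z0-9-_]*$"
assert re.match(_PATTERN, "custom_ops::my_op")       # namespace::op_name
assert re.match(_PATTERN, "no_namespace") is None    # missing '::' separator
assert re.match(_PATTERN, "org.domain::op") is None  # '.' is not in the allowed set
assert re.match(_PATTERN, "ns::1op") is None         # op may not start with a digit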
@_beartype.beartype
def register_custom_op_symbolic(
symbolic_name: str,
symbolic_fn: Callable,
@ -2016,6 +2066,7 @@ def register_custom_op_symbolic(
)(symbolic_fn)
@_beartype.beartype
def unregister_custom_op_symbolic(symbolic_name: str, opset_version: int):
"""Unregisters ``symbolic_name``.
@ -2034,6 +2085,7 @@ def unregister_custom_op_symbolic(symbolic_name: str, opset_version: int):
registration.registry.unregister(symbolic_name, opset_version)
@_beartype.beartype
def _validate_dynamic_axes(dynamic_axes, model, input_names, output_names):
"""Ensures dynamic axes argument is follows the expected format."""
if len(dynamic_axes) == 0:

View File

@ -40,7 +40,7 @@ import torch._C._onnx as _C_onnx
from torch import _C
from torch.onnx import _constants, _experimental, _exporter_states, utils
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import onnx_proto_utils
from torch.onnx._internal import _beartype, onnx_proto_utils
from torch.types import Number
_ORT_PROVIDERS = ("CPUExecutionProvider",)
@ -99,6 +99,7 @@ class VerificationOptions:
acceptable_error_percentage: Optional[float] = None
@_beartype.beartype
def _flatten_tuples(elem):
flattened = []
for t in elem:
@ -128,6 +129,7 @@ def _to_numpy(elem) -> Union[list, np.ndarray]:
return elem
@_beartype.beartype
def _inline_flatten_list(inputs, res_list) -> list:
for i in inputs:
res_list.append(i) if not isinstance(
@ -136,6 +138,7 @@ def _inline_flatten_list(inputs, res_list) -> list:
return res_list
@_beartype.beartype
def _unpack_to_numpy(values, cast_onnx_accepted=True) -> list:
value_unpacked = []
for value in values:
@ -145,6 +148,7 @@ def _unpack_to_numpy(values, cast_onnx_accepted=True) -> list:
return [_to_numpy(v) for v in value_unpacked]
@_beartype.beartype
def _run_onnx(onnx_session, inputs) -> _OutputsType:
kw_inputs = {}
if inputs and isinstance(inputs[-1], dict):
@ -175,6 +179,7 @@ def _run_onnx(onnx_session, inputs) -> _OutputsType:
return onnx_outs
@_beartype.beartype
def _ort_session(
model: Union[str, io.BytesIO], ort_providers: Sequence[str] = _ORT_PROVIDERS
):
@ -198,6 +203,7 @@ def _ort_session(
return ort_session
@_beartype.beartype
def _onnx_reference_evaluator_session(model: Union[str, io.BytesIO]):
try:
import onnx
@ -214,6 +220,7 @@ def _onnx_reference_evaluator_session(model: Union[str, io.BytesIO]):
return onnx_session
@_beartype.beartype
def _onnx_backend_session(model: Union[str, io.BytesIO], backend: OnnxBackend):
if backend == OnnxBackend.REFERENCE:
onnx_session = _onnx_reference_evaluator_session(model)
@ -224,6 +231,7 @@ def _onnx_backend_session(model: Union[str, io.BytesIO], backend: OnnxBackend):
return onnx_session
@_beartype.beartype
def _compare_onnx_pytorch_outputs_in_np(
onnx_outs: _OutputsType,
pt_outs: _OutputsType,
@ -273,6 +281,7 @@ def _compare_onnx_pytorch_outputs_in_np(
raise
@_beartype.beartype
def _compare_onnx_pytorch_outputs(
onnx_outs: _OutputsType,
pt_outs: Any,
@ -301,6 +310,7 @@ def _compare_onnx_pytorch_outputs(
_compare_onnx_pytorch_outputs_in_np(onnx_outs, pt_outs_np, options)
@_beartype.beartype
def _prepare_input_for_pytorch(args, kwargs):
"""Prepare input for PyTorch model execution.
@ -327,6 +337,7 @@ def _prepare_input_for_pytorch(args, kwargs):
return args, kwargs
@_beartype.beartype
def _prepare_input_for_export(args, kwargs):
"""Prepare input for ONNX model export.
@ -351,6 +362,7 @@ def _prepare_input_for_export(args, kwargs):
return onnx_inputs
@_beartype.beartype
def _prepare_input_for_onnx(
args, kwargs, remained_onnx_input_idx: Optional[Sequence[int]], flatten: bool
):
@ -380,6 +392,7 @@ def _prepare_input_for_onnx(
return onnx_inputs
@_beartype.beartype
def _try_clone_model(model):
"""Used for preserving original model in case forward mutates model states."""
try:
@ -391,6 +404,7 @@ def _try_clone_model(model):
return model
@_beartype.beartype
def _compare_onnx_pytorch_model(
pt_model: _ModelType,
onnx_model_f: Union[str, io.BytesIO],
@ -416,6 +430,7 @@ def _compare_onnx_pytorch_model(
"""
onnx_session = _onnx_backend_session(onnx_model_f, options.backend)
@_beartype.beartype
def compare_onnx_pytorch_model_with_input(input_args, input_kwargs):
pt_args, pt_kwargs = _prepare_input_for_pytorch(input_args, input_kwargs)
# TODO: remove this and treat mutating model separately. See #77679
@ -444,6 +459,7 @@ def _compare_onnx_pytorch_model(
class _GraphDiff:
"""A class to represent the difference between two graphs."""
@_beartype.beartype
def __init__(self, graph_a: _C.Graph, graph_b: _C.Graph):
"""Construct a _GraphDiff object.
@ -454,13 +470,16 @@ class _GraphDiff:
self.graph_a = graph_a
self.graph_b = graph_b
@_beartype.beartype
def __str__(self):
"""See function :func:`diff_report`."""
return self.diff_report()
@_beartype.beartype
def _indent(self, lines: str) -> str:
return "\n".join(["\t" + line for line in lines.splitlines()])
@_beartype.beartype
def diff_report(self) -> str:
"""Return a string representation of the graph difference.
@ -510,6 +529,7 @@ class _GraphDiff:
return "\n".join(graph_diff_report)
@_beartype.beartype
def _check_graph_diff(
model: Union[torch.nn.Module, torch.jit.ScriptModule],
test_input_groups: Sequence[Tuple[Tuple[Any, ...], Mapping[str, Any]]],
@ -551,6 +571,7 @@ def _check_graph_diff(
return ""
@_beartype.beartype
def _traced_graph_from_model(
model: Union[torch.nn.Module, torch.jit.ScriptModule],
args: Tuple[Any, ...],
@ -578,6 +599,7 @@ def _traced_graph_from_model(
return jit_graph
@_beartype.beartype
def _onnx_graph_from_model(
model: Union[torch.nn.Module, torch.jit.ScriptModule],
args: Tuple[Any, ...],
@ -642,6 +664,7 @@ def _onnx_graph_from_model(
return onnx_graph
@_beartype.beartype
def _onnx_graph_from_aten_graph(
graph: torch.Graph,
export_options: _experimental.ExportOptions,
@ -705,6 +728,7 @@ def _onnx_graph_from_aten_graph(
return graph, params_dict
@_beartype.beartype
def _onnx_proto_from_onnx_graph(
onnx_graph: torch.Graph,
export_options: _experimental.ExportOptions,
@ -738,6 +762,7 @@ def _onnx_proto_from_onnx_graph(
return proto, export_map
@_beartype.beartype
def check_export_model_diff(
model: Union[torch.nn.Module, torch.jit.ScriptModule],
test_input_groups: Sequence[Tuple[Tuple[Any, ...], Mapping[str, Any]]],
@ -781,6 +806,7 @@ def check_export_model_diff(
)
@_beartype.beartype
def verify(
model: _ModelType,
input_args: _InputArgsType,
@ -869,6 +895,7 @@ def verify(
)
@_beartype.beartype
def verify_aten_graph(
graph: torch.Graph,
input_args: Tuple[Any, ...],
@ -969,6 +996,7 @@ class GraphInfoPrettyPrinter:
self.upper_printer = None
self.lower_printer = None
@_beartype.beartype
def _total_rows(self) -> int:
if self.graph_info is None:
return 1
@ -978,6 +1006,7 @@ class GraphInfoPrettyPrinter:
)
return 2 # Two lines: node count + id.
@_beartype.beartype
def _node_count_segment_str(self) -> str:
if self.graph_info is None:
return "..."
@ -991,16 +1020,19 @@ class GraphInfoPrettyPrinter:
return f"{node_count} {'X' if has_mismatch else chr(0x2713)} {error_node_kind}"
@_beartype.beartype
def _graph_id_segment_str(self) -> str:
if self.graph_info is None:
return ""
return f"id: {self.graph_info.id}"
@_beartype.beartype
def _max_segment_columns(self) -> int:
return max(
map(len, (self._node_count_segment_str(), self._graph_id_segment_str()))
)
@_beartype.beartype
def _graph_segment_str_at_line(self, line: int) -> str:
"""Get the string representation of the graph segment at the given line."""
if line == 0:
@ -1015,6 +1047,7 @@ class GraphInfoPrettyPrinter:
return " " * self._max_segment_columns()
return ""
@_beartype.beartype
def _connector_segment_str_at_line(self, line: int) -> str:
"""Get the connector segment string at the given line."""
if self.upper_printer is None and self.lower_printer is None:
@ -1031,6 +1064,7 @@ class GraphInfoPrettyPrinter:
return " "
return ""
@_beartype.beartype
def _children_str_at_line(self, line: int) -> str:
"""Get the string representation of the children at the given line.
@ -1052,6 +1086,7 @@ class GraphInfoPrettyPrinter:
)
return ""
@_beartype.beartype
def _str_at_line(self, line: int) -> str:
"""Get the string representation of the graph at the given line."""
return (
@ -1101,6 +1136,7 @@ class OnnxTestCaseRepro:
)
@classmethod
@_beartype.beartype
def create_test_case_repro(
cls, proto: bytes, inputs, outputs, dir: str, name: Optional[str] = None
):
@ -1138,6 +1174,7 @@ class OnnxTestCaseRepro:
dir,
)
@_beartype.beartype
def validate(self, options: VerificationOptions):
"""Run the ONNX test case with options.backend, and compare with the expected outputs.
@ -1249,16 +1286,19 @@ class GraphInfo:
else:
print(" No mismatch ".center(80, "="))
@_beartype.beartype
def has_mismatch(self) -> bool:
"""Return True if the subgraph has output mismatch between torch and ONNX."""
return self.mismatch_error is not None
@_beartype.beartype
def essential_node_count(self) -> int:
"""Return the number of nodes in the subgraph excluding those in `_EXCLUDED_NODE_KINDS`."""
return sum(
1 for n in self.graph.nodes() if n.kind() not in self._EXCLUDED_NODE_KINDS
)
@_beartype.beartype
def essential_node_kinds(self) -> Set[str]:
"""Return the set of node kinds in the subgraph excluding those in `_EXCLUDED_NODE_KINDS`."""
return {
@ -1267,7 +1307,8 @@ class GraphInfo:
if n.kind() not in self._EXCLUDED_NODE_KINDS
}
def all_mismatch_leaf_graph_info(self) -> List[GraphInfo]:
@_beartype.beartype
def all_mismatch_leaf_graph_info(self) -> List["GraphInfo"]:
"""Return a list of all leaf `GraphInfo` objects that have mismatch."""
if not self.has_mismatch():
return []
@ -1289,7 +1330,8 @@ class GraphInfo:
return results
def find_partition(self, id: str) -> Optional[GraphInfo]:
@_beartype.beartype
def find_partition(self, id: str) -> Optional["GraphInfo"]:
"""Find the `GraphInfo` object with the given id."""
if id == self.id:
return self
@ -1301,6 +1343,7 @@ class GraphInfo:
return self.lower_graph_info.find_partition(id)
return None
@_beartype.beartype
def export_repro(
self, repro_dir: Optional[str] = None, name: Optional[str] = None
) -> str:
@ -1341,6 +1384,7 @@ class GraphInfo:
proto, self.input_args, self.pt_outs, repro_dir, name
)
@_beartype.beartype
def _graph_partition_pivot(self) -> int:
"""Find the pivot index to partition the graph.
@ -1363,6 +1407,7 @@ class GraphInfo:
return included_node_indices[half_idx] + 1
return -1
@_beartype.beartype
def _partition_upper_graph(self) -> torch.Graph:
pivot = self._graph_partition_pivot()
if pivot == -1:
@ -1406,6 +1451,7 @@ class GraphInfo:
return graph
@_beartype.beartype
def _partition_lower_graph(self) -> torch.Graph:
pivot = self._graph_partition_pivot()
if pivot == -1:
@ -1460,6 +1506,7 @@ class GraphInfo:
return graph
@_beartype.beartype
def _partition_node(
self,
node: torch.Node,
@ -1498,6 +1545,7 @@ class GraphInfo:
):
covered_bridge_values.add(process_bridge_value(output))
@_beartype.beartype
def _partition_nodes(
self,
graph: torch.Graph,
@ -1537,6 +1585,7 @@ class GraphInfo:
complete_lower_nodes_set,
)
@_beartype.beartype
def _bridge_kwargs(self):
pt_outs = self.pt_outs
graph_outputs = list(self.graph.outputs())
@ -1546,6 +1595,7 @@ class GraphInfo:
), f"{len(graph_outputs)} vs {len(pt_outs)}\nGraph: {self.graph}"
return {v.debugName(): o for v, o in zip(graph_outputs, pt_outs)}
@_beartype.beartype
def _args_and_params_for_partition_graph(
self,
graph: torch.Graph,
@ -1562,6 +1612,7 @@ class GraphInfo:
), f"{len(args)} + {len(params)} vs {len(input_names)}: {input_names}"
return args, params
@_beartype.beartype
def verify_export(
self, options: VerificationOptions
) -> Tuple[Optional[AssertionError], torch.Graph, _OutputsType, _OutputsType]:
@ -1591,6 +1642,7 @@ class GraphInfo:
verification_options=options,
)
@_beartype.beartype
def find_mismatch(
self,
options: Optional[VerificationOptions] = None,
@ -1670,6 +1722,7 @@ class GraphInfo:
self.lower_graph_info.find_mismatch(options)
@_beartype.beartype
def _all_nodes(nodes: Collection[torch.Node]) -> Set[torch.Node]:
all_nodes = set(nodes)
for n in nodes:
@ -1678,10 +1731,12 @@ def _all_nodes(nodes: Collection[torch.Node]) -> Set[torch.Node]:
return all_nodes
@_beartype.beartype
def _has_uses_by_nodes(value: torch.Value, nodes: Collection[torch.Node]) -> bool:
return any(use.user in nodes for use in value.uses())
@_beartype.beartype
def _node_has_uses_by(node: torch.Node, nodes: Collection[torch.Node]) -> bool:
for output in node.outputs():
if _has_uses_by_nodes(output, nodes):
@ -1689,10 +1744,12 @@ def _node_has_uses_by(node: torch.Node, nodes: Collection[torch.Node]) -> bool:
return False
@_beartype.beartype
def _produced_by(value: torch.Value, nodes: Collection[torch.Node]) -> bool:
return value.node() in nodes
@_beartype.beartype
def find_mismatch(
model: Union[torch.nn.Module, torch.jit.ScriptModule],
input_args: Tuple[Any, ...],