mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "[ONNX] Remove beartype usage (#130484)"
This reverts commit 1794c35912025aa44b0d70f67ff664b4f7bd1014.
Reverted https://github.com/pytorch/pytorch/pull/130484 on behalf of https://github.com/clee2000 due to test_sympy_utils failure is real https://github.com/pytorch/pytorch/actions/runs/9961499559/job/27523758780 1794c35912
. Dr CI is matching with commits in current commit? ([comment](https://github.com/pytorch/pytorch/pull/130484#issuecomment-2231575577))
This commit is contained in:
@ -10,6 +10,7 @@ retry () {
|
||||
|
||||
# A bunch of custom pip dependencies for ONNX
|
||||
pip_install \
|
||||
beartype==0.15.0 \
|
||||
filelock==3.9.0 \
|
||||
flatbuffers==2.0 \
|
||||
mock==5.0.1 \
|
||||
|
@ -3,16 +3,18 @@ import io
|
||||
import os
|
||||
|
||||
import onnx
|
||||
from beartype import roar
|
||||
|
||||
import torch
|
||||
from torch.onnx import dynamo_export, ExportOptions, ONNXProgram
|
||||
from torch.onnx._internal import exporter
|
||||
from torch.onnx._internal import exporter, io_adapter
|
||||
from torch.onnx._internal.exporter import (
|
||||
LargeProtobufONNXProgramSerializer,
|
||||
ONNXProgramSerializer,
|
||||
ProtobufONNXProgramSerializer,
|
||||
ResolvedExportOptions,
|
||||
)
|
||||
from torch.onnx._internal.fx import diagnostics
|
||||
|
||||
from torch.testing._internal import common_utils
|
||||
|
||||
@ -47,6 +49,15 @@ class _LargeModel(torch.nn.Module):
|
||||
|
||||
|
||||
class TestExportOptionsAPI(common_utils.TestCase):
|
||||
def test_raise_on_invalid_argument_type(self):
|
||||
expected_exception_type = roar.BeartypeException
|
||||
with self.assertRaises(expected_exception_type):
|
||||
ExportOptions(dynamic_shapes=2) # type: ignore[arg-type]
|
||||
with self.assertRaises(expected_exception_type):
|
||||
ExportOptions(diagnostic_options="DEBUG") # type: ignore[arg-type]
|
||||
with self.assertRaises(expected_exception_type):
|
||||
ResolvedExportOptions(options=12) # type: ignore[arg-type]
|
||||
|
||||
def test_dynamic_shapes_default(self):
|
||||
options = ResolvedExportOptions(ExportOptions())
|
||||
self.assertFalse(options.dynamic_shapes)
|
||||
@ -109,6 +120,7 @@ class TestDynamoExportAPI(common_utils.TestCase):
|
||||
|
||||
# NOTE: Inheritance from `ONNXProgramSerializer` is not required.
|
||||
# Because `ONNXProgramSerializer` is a Protocol class.
|
||||
# `beartype` will not complain.
|
||||
class CustomSerializer:
|
||||
def serialize(
|
||||
self, onnx_program: ONNXProgram, destination: io.BufferedIOBase
|
||||
@ -184,8 +196,27 @@ class TestDynamoExportAPI(common_utils.TestCase):
|
||||
),
|
||||
)
|
||||
|
||||
def test_raise_on_invalid_save_argument_type(self):
|
||||
with self.assertRaises(roar.BeartypeException):
|
||||
ONNXProgram(torch.nn.Linear(2, 3)) # type: ignore[arg-type]
|
||||
onnx_program = ONNXProgram(
|
||||
onnx.ModelProto(),
|
||||
io_adapter.InputAdapter(),
|
||||
io_adapter.OutputAdapter(),
|
||||
diagnostics.DiagnosticContext("test", "1.0"),
|
||||
fake_context=None,
|
||||
)
|
||||
with self.assertRaises(roar.BeartypeException):
|
||||
onnx_program.save(None) # type: ignore[arg-type]
|
||||
onnx_program.model_proto
|
||||
|
||||
|
||||
class TestProtobufONNXProgramSerializerAPI(common_utils.TestCase):
|
||||
def test_raise_on_invalid_argument_type(self):
|
||||
with self.assertRaises(roar.BeartypeException):
|
||||
serializer = ProtobufONNXProgramSerializer()
|
||||
serializer.serialize(None, None) # type: ignore[arg-type]
|
||||
|
||||
def test_serialize_raises_when_model_greater_than_2gb(self):
|
||||
onnx_program = torch.onnx.dynamo_export(_LargeModel(), torch.randn(1))
|
||||
serializer = ProtobufONNXProgramSerializer()
|
||||
|
84
test/onnx/internal/test_beartype.py
Normal file
84
test/onnx/internal/test_beartype.py
Normal file
@ -0,0 +1,84 @@
|
||||
# Owner(s): ["module: onnx"]
|
||||
"""Unit tests for the internal beartype wrapper module."""
|
||||
|
||||
import unittest
|
||||
|
||||
from torch.onnx._internal import _beartype
|
||||
from torch.testing._internal import common_utils
|
||||
|
||||
|
||||
def beartype_installed():
|
||||
try:
|
||||
import beartype # noqa: F401
|
||||
except ImportError:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def skip_if_beartype_not_installed(test_case):
|
||||
return unittest.skipIf(not beartype_installed(), "beartype is not installed")(
|
||||
test_case
|
||||
)
|
||||
|
||||
|
||||
def func_with_type_hint(x: int) -> int:
|
||||
return x
|
||||
|
||||
|
||||
def func_with_incorrect_type_hint(x: int) -> str:
|
||||
return x # type: ignore[return-value]
|
||||
|
||||
|
||||
@common_utils.instantiate_parametrized_tests
|
||||
class TestBeartype(common_utils.TestCase):
|
||||
def test_create_beartype_decorator_returns_no_op_decorator_when_disabled(self):
|
||||
decorator = _beartype._create_beartype_decorator(
|
||||
_beartype.RuntimeTypeCheckState.DISABLED,
|
||||
)
|
||||
decorated = decorator(func_with_incorrect_type_hint)
|
||||
decorated("string_input") # type: ignore[arg-type]
|
||||
|
||||
@skip_if_beartype_not_installed
|
||||
def test_create_beartype_decorator_warns_when_warnings(self):
|
||||
decorator = _beartype._create_beartype_decorator(
|
||||
_beartype.RuntimeTypeCheckState.WARNINGS,
|
||||
)
|
||||
decorated = decorator(func_with_incorrect_type_hint)
|
||||
with self.assertWarns(_beartype.CallHintViolationWarning):
|
||||
decorated("string_input") # type: ignore[arg-type]
|
||||
|
||||
@common_utils.parametrize("arg", [1, "string_input"])
|
||||
@skip_if_beartype_not_installed
|
||||
def test_create_beartype_decorator_errors_when_errors(self, arg):
|
||||
import beartype
|
||||
|
||||
decorator = _beartype._create_beartype_decorator(
|
||||
_beartype.RuntimeTypeCheckState.ERRORS,
|
||||
)
|
||||
decorated = decorator(func_with_incorrect_type_hint)
|
||||
with self.assertRaises(beartype.roar.BeartypeCallHintViolation):
|
||||
decorated(arg)
|
||||
|
||||
@skip_if_beartype_not_installed
|
||||
def test_create_beartype_decorator_warning_calls_function_once(self):
|
||||
call_count = 0
|
||||
|
||||
def func_with_incorrect_type_hint_and_side_effect(x: int) -> str:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return x # type: ignore[return-value]
|
||||
|
||||
decorator = _beartype._create_beartype_decorator(
|
||||
_beartype.RuntimeTypeCheckState.WARNINGS,
|
||||
)
|
||||
decorated = decorator(func_with_incorrect_type_hint_and_side_effect)
|
||||
decorated("string_input") # type: ignore[arg-type]
|
||||
self.assertEqual(call_count, 1)
|
||||
decorated(1)
|
||||
# The return value violates the type hint, but the function is called
|
||||
# only once.
|
||||
self.assertEqual(call_count, 2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
common_utils.run_tests()
|
@ -34,6 +34,7 @@ import pytorch_test_common
|
||||
import torch
|
||||
from torch import export as torch_export
|
||||
from torch.onnx import _constants, verification
|
||||
from torch.onnx._internal import _beartype
|
||||
from torch.onnx._internal.fx import diagnostics
|
||||
from torch.testing._internal import common_utils
|
||||
from torch.testing._internal.opinfo import core as opinfo_core
|
||||
@ -205,6 +206,7 @@ class _TestONNXRuntime(pytorch_test_common.ExportTestCase):
|
||||
if not is_model_script and not self.is_script:
|
||||
_run_test(model, tracing_remained_onnx_input_idx)
|
||||
|
||||
@_beartype.beartype
|
||||
def run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
|
||||
self,
|
||||
model: _ModelType,
|
||||
@ -358,6 +360,7 @@ class _TestONNXRuntime(pytorch_test_common.ExportTestCase):
|
||||
)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def run_ort(
|
||||
onnx_model: Union[str, torch.onnx.ONNXProgram],
|
||||
pytorch_inputs: Sequence[_InputArgsType],
|
||||
@ -403,6 +406,7 @@ def run_ort(
|
||||
return session.run(None, ort_input)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _try_clone_model(model: _ModelType) -> _ModelType:
|
||||
"""Used for preserving original model in case forward mutates model states."""
|
||||
try:
|
||||
@ -414,12 +418,14 @@ def _try_clone_model(model: _ModelType) -> _ModelType:
|
||||
return model
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _try_clone_inputs(input_args, input_kwargs):
|
||||
ref_input_args = copy.deepcopy(input_args)
|
||||
ref_input_kwargs = copy.deepcopy(input_kwargs)
|
||||
return ref_input_args, ref_input_kwargs
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _compare_pytorch_onnx_with_ort(
|
||||
onnx_program: torch.onnx.ONNXProgram,
|
||||
model: _ModelType,
|
||||
|
@ -22,7 +22,7 @@ import torch.onnx
|
||||
from torch import nn
|
||||
|
||||
from torch._subclasses import fake_tensor
|
||||
from torch.onnx._internal import exporter
|
||||
from torch.onnx._internal import _beartype, exporter
|
||||
from torch.onnx._internal.fx import (
|
||||
diagnostics,
|
||||
fx_symbolic_graph_extractor,
|
||||
@ -721,6 +721,7 @@ class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
|
||||
CustomModule(), (torch.randn(1, 2, 3),)
|
||||
)
|
||||
|
||||
@_beartype.beartype
|
||||
def _test_fx_symbolic_tracer_large_scale_exporter(
|
||||
self,
|
||||
model_name: str,
|
||||
@ -954,6 +955,7 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
|
||||
super().setUp()
|
||||
self.ort_version = onnxruntime.__version__
|
||||
|
||||
@_beartype.beartype
|
||||
def _test_fake_tensor_mode_exporter(
|
||||
self,
|
||||
model_name: str,
|
||||
|
@ -35,6 +35,10 @@ Checkout topk in symbolic_opset10.py, and upsample_nearest2d in symbolic_opset8.
|
||||
`_jit_pass_onnx_remove_inplace_ops_for_onnx`, and
|
||||
transparently dispatched to their non inplace versions in
|
||||
"run_symbolic_function". See Note [Export inplace](#export-inplace)
|
||||
- Required: Annotate new symbolic functions with type annotations and decorate
|
||||
with `@_beartype.beartype` to enable runtime type checking.
|
||||
`@_beartype.beartype` should typically be the closest to the function to
|
||||
ensure proper type checking.
|
||||
|
||||
### A note on Tensor types
|
||||
|
||||
|
132
torch/onnx/_internal/_beartype.py
Normal file
132
torch/onnx/_internal/_beartype.py
Normal file
@ -0,0 +1,132 @@
|
||||
# mypy: allow-untyped-defs
|
||||
"""An internal wrapper for the beartype library.
|
||||
|
||||
The module returns a no-op decorator when the beartype library is not installed.
|
||||
"""
|
||||
import enum
|
||||
import functools
|
||||
import os
|
||||
import traceback
|
||||
import typing
|
||||
import warnings
|
||||
from types import ModuleType
|
||||
|
||||
try:
|
||||
import beartype as _beartype_lib # type: ignore[import]
|
||||
from beartype import roar as _roar # type: ignore[import]
|
||||
|
||||
# Beartype warns when we import from typing because the types are deprecated
|
||||
# in Python 3.9. But there will be a long time until we can move to using
|
||||
# the native container types for type annotations (when 3.9 is the lowest
|
||||
# supported version). So we silence the warning.
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
category=_roar.BeartypeDecorHintPep585DeprecationWarning,
|
||||
)
|
||||
|
||||
if _beartype_lib.__version__ == "0.16.0":
|
||||
# beartype 0.16.0 has a bug that causes it to crash when used with
|
||||
# PyTorch. See https://github.com/beartype/beartype/issues/282
|
||||
warnings.warn("beartype 0.16.0 is not supported. Please upgrade to 0.16.1+.")
|
||||
_beartype_lib = None # type: ignore[assignment]
|
||||
except ImportError:
|
||||
_beartype_lib = None # type: ignore[assignment]
|
||||
except Exception as e:
|
||||
# Warn errors that are not import errors (unexpected).
|
||||
warnings.warn(f"{e}")
|
||||
_beartype_lib = None # type: ignore[assignment]
|
||||
|
||||
|
||||
@enum.unique
|
||||
class RuntimeTypeCheckState(enum.Enum):
|
||||
"""Runtime type check state."""
|
||||
|
||||
# Runtime type checking is disabled.
|
||||
DISABLED = enum.auto()
|
||||
# Runtime type checking is enabled but warnings are shown only.
|
||||
WARNINGS = enum.auto()
|
||||
# Runtime type checking is enabled.
|
||||
ERRORS = enum.auto()
|
||||
|
||||
|
||||
class CallHintViolationWarning(UserWarning):
|
||||
"""Warning raised when a type hint is violated during a function call."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def _no_op_decorator(func):
|
||||
return func
|
||||
|
||||
|
||||
def _create_beartype_decorator(
|
||||
runtime_check_state: RuntimeTypeCheckState,
|
||||
):
|
||||
# beartype needs to be imported outside of the function and aliased because
|
||||
# this module overwrites the name "beartype".
|
||||
|
||||
if runtime_check_state == RuntimeTypeCheckState.DISABLED:
|
||||
return _no_op_decorator
|
||||
if _beartype_lib is None:
|
||||
# If the beartype library is not installed, return a no-op decorator
|
||||
return _no_op_decorator
|
||||
|
||||
assert isinstance(_beartype_lib, ModuleType)
|
||||
|
||||
if runtime_check_state == RuntimeTypeCheckState.ERRORS:
|
||||
# Enable runtime type checking which errors on any type hint violation.
|
||||
return _beartype_lib.beartype
|
||||
|
||||
# Warnings only
|
||||
def beartype(func):
|
||||
"""Warn on type hint violation."""
|
||||
|
||||
if "return" in func.__annotations__:
|
||||
# Remove the return type from the func function's
|
||||
# annotations so that the beartype decorator does not complain
|
||||
# about the return type.
|
||||
return_type = func.__annotations__["return"]
|
||||
del func.__annotations__["return"]
|
||||
beartyped = _beartype_lib.beartype(func)
|
||||
# Restore the return type to the func function's annotations
|
||||
func.__annotations__["return"] = return_type
|
||||
else:
|
||||
beartyped = _beartype_lib.beartype(func)
|
||||
|
||||
@functools.wraps(func)
|
||||
def _coerce_beartype_exceptions_to_warnings(*args, **kwargs):
|
||||
try:
|
||||
return beartyped(*args, **kwargs)
|
||||
except _roar.BeartypeCallHintParamViolation:
|
||||
# Fall back to the original function if the beartype hint is violated.
|
||||
warnings.warn(
|
||||
traceback.format_exc(),
|
||||
category=CallHintViolationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
return func(*args, **kwargs) # noqa: B012
|
||||
|
||||
return _coerce_beartype_exceptions_to_warnings
|
||||
|
||||
return beartype
|
||||
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
# This is a hack to make mypy play nicely with the beartype decorator.
|
||||
def beartype(func):
|
||||
return func
|
||||
|
||||
else:
|
||||
_TORCH_ONNX_EXPERIMENTAL_RUNTIME_TYPE_CHECK = os.getenv(
|
||||
"TORCH_ONNX_EXPERIMENTAL_RUNTIME_TYPE_CHECK"
|
||||
)
|
||||
if _TORCH_ONNX_EXPERIMENTAL_RUNTIME_TYPE_CHECK == "ERRORS":
|
||||
_runtime_type_check_state = RuntimeTypeCheckState.ERRORS
|
||||
elif _TORCH_ONNX_EXPERIMENTAL_RUNTIME_TYPE_CHECK == "DISABLED":
|
||||
_runtime_type_check_state = RuntimeTypeCheckState.DISABLED
|
||||
else:
|
||||
_runtime_type_check_state = RuntimeTypeCheckState.WARNINGS
|
||||
beartype = _create_beartype_decorator(_runtime_type_check_state)
|
||||
# Make sure that the beartype decorator is enabled whichever path we took.
|
||||
assert beartype is not None
|
@ -4,7 +4,8 @@ from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import gzip
|
||||
from typing import List, Optional, TYPE_CHECKING
|
||||
from collections.abc import Generator
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@ -13,9 +14,6 @@ from torch.onnx._internal.diagnostics.infra import formatter, sarif
|
||||
from torch.onnx._internal.diagnostics.infra.sarif import version as sarif_version
|
||||
from torch.utils import cpp_backtrace
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Generator
|
||||
|
||||
|
||||
def _cpp_call_stack(frames_to_skip: int = 0, frames_to_log: int = 32) -> infra.Stack:
|
||||
"""Returns the current C++ call stack.
|
||||
@ -26,6 +24,7 @@ def _cpp_call_stack(frames_to_skip: int = 0, frames_to_log: int = 32) -> infra.S
|
||||
r"frame #[0-9]+: (?P<frame_info>.*)". More info at `c10/util/Backtrace.cpp`.
|
||||
|
||||
"""
|
||||
# NOTE: Cannot use `@_beartype.beartype`. It somehow erases the cpp stack frame info.
|
||||
frames = cpp_backtrace.get_cpp_backtrace(frames_to_skip, frames_to_log).split("\n")
|
||||
frame_messages = []
|
||||
for frame in frames:
|
||||
@ -71,6 +70,9 @@ class TorchScriptOnnxExportDiagnostic(infra.Diagnostic):
|
||||
|
||||
def record_cpp_call_stack(self, frames_to_skip: int) -> infra.Stack:
|
||||
"""Records the current C++ call stack in the diagnostic."""
|
||||
# NOTE: Cannot use `@_beartype.beartype`. It somehow erases the cpp stack frame info.
|
||||
# No need to skip this function because python frame is not recorded
|
||||
# in cpp call stack.
|
||||
stack = _cpp_call_stack(frames_to_skip=frames_to_skip)
|
||||
stack.message = "C++ call stack"
|
||||
self.with_stack(stack)
|
||||
@ -198,6 +200,7 @@ def diagnose(
|
||||
This is a wrapper around `context.log` that uses the global diagnostic
|
||||
context.
|
||||
"""
|
||||
# NOTE: Cannot use `@_beartype.beartype`. It somehow erases the cpp stack frame info.
|
||||
diagnostic = TorchScriptOnnxExportDiagnostic(
|
||||
rule, level, message, frames_to_skip=frames_to_skip, **kwargs
|
||||
)
|
||||
|
@ -6,6 +6,7 @@ import logging
|
||||
import traceback
|
||||
from typing import Any, Callable, Dict, Optional, Tuple, Type
|
||||
|
||||
from torch.onnx._internal import _beartype
|
||||
from torch.onnx._internal.diagnostics import infra
|
||||
from torch.onnx._internal.diagnostics.infra import formatter, utils
|
||||
|
||||
@ -13,10 +14,12 @@ from torch.onnx._internal.diagnostics.infra import formatter, utils
|
||||
MessageFormatterType = Callable[..., str]
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def format_message_in_text(fn: Callable, *args: Any, **kwargs: Any) -> str:
|
||||
return f"{formatter.display_name(fn)}. "
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def format_exception_in_markdown(exception: Exception) -> str:
|
||||
msg_list = ["### Exception log", "```"]
|
||||
msg_list.extend(
|
||||
@ -26,6 +29,7 @@ def format_exception_in_markdown(exception: Exception) -> str:
|
||||
return "\n".join(msg_list)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def format_function_signature_in_markdown(
|
||||
fn: Callable,
|
||||
args: Tuple[Any, ...],
|
||||
@ -42,6 +46,7 @@ def format_function_signature_in_markdown(
|
||||
return "\n".join(msg_list)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def format_return_values_in_markdown(
|
||||
return_values: Any,
|
||||
format_argument: Callable[[Any], str] = formatter.format_argument,
|
||||
@ -54,6 +59,7 @@ ModifierCallableType = Callable[
|
||||
]
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def diagnose_call(
|
||||
rule: infra.Rule,
|
||||
*,
|
||||
|
@ -7,7 +7,7 @@ import traceback
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
from torch._logging import LazyString
|
||||
|
||||
from torch.onnx._internal import _beartype
|
||||
from torch.onnx._internal.diagnostics.infra import sarif
|
||||
|
||||
|
||||
@ -35,6 +35,7 @@ def lazy_format_exception(exception: Exception) -> LazyString:
|
||||
)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def snake_case_to_camel_case(s: str) -> str:
|
||||
splits = s.split("_")
|
||||
if len(splits) <= 1:
|
||||
@ -42,14 +43,17 @@ def snake_case_to_camel_case(s: str) -> str:
|
||||
return "".join([splits[0], *map(str.capitalize, splits[1:])])
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def camel_case_to_snake_case(s: str) -> str:
|
||||
return re.sub(r"([A-Z])", r"_\1", s).lower()
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def kebab_case_to_snake_case(s: str) -> str:
|
||||
return s.replace("-", "_")
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _convert_key(
|
||||
object: Union[Dict[str, Any], Any], convert: Callable[[str], str]
|
||||
) -> Union[Dict[str, Any], Any]:
|
||||
@ -88,16 +92,19 @@ def _convert_key(
|
||||
return new_dict
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def sarif_to_json(attr_cls_obj: _SarifClass, indent: Optional[str] = " ") -> str:
|
||||
dict = dataclasses.asdict(attr_cls_obj)
|
||||
dict = _convert_key(dict, snake_case_to_camel_case)
|
||||
return json.dumps(dict, indent=indent, separators=(",", ":"))
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def format_argument(obj: Any) -> str:
|
||||
return f"{type(obj)}"
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def display_name(fn: Callable) -> str:
|
||||
if hasattr(fn, "__qualname__"):
|
||||
return fn.__qualname__
|
||||
|
@ -6,9 +6,11 @@ import inspect
|
||||
import traceback
|
||||
from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple
|
||||
|
||||
from torch.onnx._internal import _beartype
|
||||
from torch.onnx._internal.diagnostics.infra import _infra, formatter
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def python_frame(frame: traceback.FrameSummary) -> _infra.StackFrame:
|
||||
"""Returns a StackFrame for the given traceback.FrameSummary."""
|
||||
snippet = frame.line
|
||||
@ -24,13 +26,14 @@ def python_frame(frame: traceback.FrameSummary) -> _infra.StackFrame:
|
||||
)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def python_call_stack(frames_to_skip: int = 0, frames_to_log: int = 16) -> _infra.Stack:
|
||||
"""Returns the current Python call stack."""
|
||||
if frames_to_skip < 0:
|
||||
raise ValueError("frames_to_skip must be non-negative")
|
||||
if frames_to_log < 0:
|
||||
raise ValueError("frames_to_log must be non-negative")
|
||||
frames_to_skip += 1 # Skip this function.
|
||||
frames_to_skip += 2 # Skip this function and beartype.
|
||||
stack = _infra.Stack()
|
||||
# Frames are returned in order of oldest to newest.
|
||||
frames = traceback.extract_stack(limit=frames_to_skip + frames_to_log)
|
||||
@ -51,6 +54,7 @@ def _function_source_info(fn: Callable) -> Tuple[Sequence[str], int, Optional[st
|
||||
return source_lines, lineno, inspect.getsourcefile(fn)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def function_location(fn: Callable) -> _infra.Location:
|
||||
"""Returns a Location for the given function."""
|
||||
source_lines, lineno, uri = _function_source_info(fn)
|
||||
@ -63,6 +67,7 @@ def function_location(fn: Callable) -> _infra.Location:
|
||||
)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def function_state(
|
||||
fn: Callable, args: Tuple[Any, ...], kwargs: Dict[str, Any]
|
||||
) -> Mapping[str, Any]:
|
||||
|
@ -7,6 +7,7 @@ import abc
|
||||
|
||||
import contextlib
|
||||
import dataclasses
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
|
||||
@ -39,7 +40,7 @@ import torch.export as torch_export
|
||||
import torch.utils._pytree as pytree
|
||||
from torch._subclasses import fake_tensor
|
||||
|
||||
from torch.onnx._internal import io_adapter
|
||||
from torch.onnx._internal import _beartype, io_adapter
|
||||
from torch.onnx._internal.diagnostics import infra
|
||||
from torch.onnx._internal.fx import (
|
||||
decomposition_table,
|
||||
@ -52,8 +53,6 @@ from torch.onnx._internal.fx import (
|
||||
# 'import torch.onnx' continues to work without having 'onnx' installed. We fully
|
||||
# 'import onnx' inside of dynamo_export (by way of _assert_dependencies).
|
||||
if TYPE_CHECKING:
|
||||
import io
|
||||
|
||||
import onnx
|
||||
import onnxruntime # type: ignore[import]
|
||||
import onnxscript # type: ignore[import]
|
||||
@ -62,6 +61,15 @@ if TYPE_CHECKING:
|
||||
)
|
||||
|
||||
from torch.onnx._internal.fx import diagnostics
|
||||
else:
|
||||
try:
|
||||
# beartype needs this import due to runtime type checking.
|
||||
# This cannot be normally imported at top level due to
|
||||
# https://github.com/pytorch/pytorch/issues/103764
|
||||
from torch.onnx._internal.fx import diagnostics
|
||||
except ImportError:
|
||||
# The error will be handled elsewhere when the exporter is used.
|
||||
pass
|
||||
|
||||
_DEFAULT_OPSET_VERSION: Final[int] = 18
|
||||
"""The default ONNX opset version the exporter will use if one is not specified explicitly
|
||||
@ -170,6 +178,7 @@ class OnnxRegistry:
|
||||
)
|
||||
self._register(internal_name_instance, symbolic_function)
|
||||
|
||||
@_beartype.beartype
|
||||
def _register(
|
||||
self,
|
||||
internal_qualified_name: registration.OpName,
|
||||
@ -183,9 +192,10 @@ class OnnxRegistry:
|
||||
"""
|
||||
self._registry[internal_qualified_name].append(symbolic_function)
|
||||
|
||||
@_beartype.beartype
|
||||
def register_op(
|
||||
self,
|
||||
function: Union[onnxscript.OnnxFunction, onnxscript.TracedOnnxFunction],
|
||||
function: Union["onnxscript.OnnxFunction", "onnxscript.TracedOnnxFunction"],
|
||||
namespace: str,
|
||||
op_name: str,
|
||||
overload: Optional[str] = None,
|
||||
@ -215,6 +225,7 @@ class OnnxRegistry:
|
||||
)
|
||||
self._register(internal_name_instance, symbolic_function)
|
||||
|
||||
@_beartype.beartype
|
||||
def get_op_functions(
|
||||
self, namespace: str, op_name: str, overload: Optional[str] = None
|
||||
) -> Optional[List[registration.ONNXFunction]]:
|
||||
@ -237,6 +248,7 @@ class OnnxRegistry:
|
||||
)
|
||||
return self._registry.get(internal_name_instance)
|
||||
|
||||
@_beartype.beartype
|
||||
def is_registered_op(
|
||||
self, namespace: str, op_name: str, overload: Optional[str] = None
|
||||
) -> bool:
|
||||
@ -256,6 +268,7 @@ class OnnxRegistry:
|
||||
)
|
||||
return functions is not None
|
||||
|
||||
@_beartype.beartype
|
||||
def _all_registered_ops(self) -> Set[str]:
|
||||
"""Returns the set of all registered function names."""
|
||||
return {
|
||||
@ -297,6 +310,7 @@ class ExportOptions:
|
||||
onnx_registry: Optional[OnnxRegistry] = None
|
||||
"""The ONNX registry used to register ATen operators to ONNX functions."""
|
||||
|
||||
@_beartype.beartype
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
@ -342,9 +356,10 @@ class ResolvedExportOptions(ExportOptions):
|
||||
"""The diagnostics context for the export. Responsible for recording diagnostics,
|
||||
logging diagnostics, and generating the SARIF log."""
|
||||
|
||||
@_beartype.beartype
|
||||
def __init__(
|
||||
self,
|
||||
options: Union[ExportOptions, ResolvedExportOptions],
|
||||
options: Union[ExportOptions, "ResolvedExportOptions"],
|
||||
model: Optional[Union[torch.nn.Module, Callable, torch_export.ExportedProgram]] = None, # type: ignore[name-defined]
|
||||
):
|
||||
from torch.onnx._internal.fx import ( # TODO: Prevent circular dep
|
||||
@ -375,6 +390,7 @@ class ResolvedExportOptions(ExportOptions):
|
||||
else:
|
||||
T = TypeVar("T")
|
||||
|
||||
@_beartype.beartype
|
||||
def resolve(value: Optional[T], fallback: Union[T, Callable[[], T]]) -> T:
|
||||
if value is not None:
|
||||
return value
|
||||
@ -392,7 +408,7 @@ class ResolvedExportOptions(ExportOptions):
|
||||
else:
|
||||
self.fx_tracer = dynamo_graph_extractor.DynamoExport()
|
||||
|
||||
self.fake_context = resolve(options.fake_context, None) # type: ignore[arg-type]
|
||||
self.fake_context = resolve(options.fake_context, None)
|
||||
self.diagnostic_context = diagnostics.DiagnosticContext(
|
||||
"torch.onnx.dynamo_export",
|
||||
torch.__version__,
|
||||
@ -400,8 +416,10 @@ class ResolvedExportOptions(ExportOptions):
|
||||
)
|
||||
|
||||
self.onnx_registry = resolve(options.onnx_registry, OnnxRegistry())
|
||||
self.decomposition_table = decomposition_table.create_onnx_friendly_decomposition_table( # type: ignore[assignment]
|
||||
self.onnx_registry
|
||||
self.decomposition_table = (
|
||||
decomposition_table.create_onnx_friendly_decomposition_table(
|
||||
self.onnx_registry
|
||||
)
|
||||
)
|
||||
|
||||
from torch.onnx._internal.fx import onnxfunction_dispatcher
|
||||
@ -544,6 +562,7 @@ class ONNXProgramSerializer(Protocol):
|
||||
class ProtobufONNXProgramSerializer:
|
||||
"""Serializes ONNX graph as Protobuf."""
|
||||
|
||||
@_beartype.beartype
|
||||
def serialize(
|
||||
self, onnx_program: ONNXProgram, destination: io.BufferedIOBase
|
||||
) -> None:
|
||||
@ -565,6 +584,7 @@ class LargeProtobufONNXProgramSerializer:
|
||||
def __init__(self, destination_path: str):
|
||||
self._destination_path = destination_path
|
||||
|
||||
@_beartype.beartype
|
||||
def serialize(
|
||||
self, onnx_program: ONNXProgram, destination: io.BufferedIOBase
|
||||
) -> None:
|
||||
@ -593,7 +613,7 @@ class ONNXRuntimeOptions:
|
||||
execution_provider_options: ONNX Runtime execution provider options.
|
||||
"""
|
||||
|
||||
session_options: Optional[Sequence[onnxruntime.SessionOptions]] = None
|
||||
session_options: Optional[Sequence["onnxruntime.SessionOptions"]] = None
|
||||
"""ONNX Runtime session options."""
|
||||
|
||||
execution_providers: Optional[
|
||||
@ -604,10 +624,11 @@ class ONNXRuntimeOptions:
|
||||
execution_provider_options: Optional[Sequence[Dict[Any, Any]]] = None
|
||||
"""ONNX Runtime execution provider options."""
|
||||
|
||||
@_beartype.beartype
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
session_options: Optional[Sequence[onnxruntime.SessionOptions]] = None,
|
||||
session_options: Optional[Sequence["onnxruntime.SessionOptions"]] = None,
|
||||
execution_providers: Optional[
|
||||
Sequence[Union[str, Tuple[str, Dict[Any, Any]]]]
|
||||
] = None,
|
||||
@ -642,6 +663,7 @@ class ONNXProgram:
|
||||
Optional[Union[torch.nn.Module, Callable, torch_export.ExportedProgram]]
|
||||
]
|
||||
|
||||
@_beartype.beartype
|
||||
def __init__(
|
||||
self,
|
||||
model_proto: onnx.ModelProto, # type: ignore[name-defined]
|
||||
@ -728,7 +750,7 @@ class ONNXProgram:
|
||||
ort_session = onnxruntime.InferenceSession(onnx_model, providers=providers)
|
||||
|
||||
onnxruntime_input = {
|
||||
k.name: v.numpy(force=True) # type: ignore[union-attr]
|
||||
k.name: v.numpy(force=True)
|
||||
for k, v in zip(ort_session.get_inputs(), onnx_input)
|
||||
}
|
||||
|
||||
@ -849,6 +871,7 @@ class ONNXProgram:
|
||||
|
||||
return self._fake_context
|
||||
|
||||
@_beartype.beartype
|
||||
def adapt_torch_inputs_to_onnx(
|
||||
self,
|
||||
*model_args,
|
||||
@ -917,10 +940,11 @@ class ONNXProgram:
|
||||
assert (
|
||||
model_with_state_dict is not None
|
||||
), "model_with_state_dict must be specified."
|
||||
return self._input_adapter.apply( # type: ignore[return-value]
|
||||
return self._input_adapter.apply(
|
||||
*model_args, model=model_with_state_dict, **model_kwargs
|
||||
)
|
||||
|
||||
@_beartype.beartype
|
||||
def adapt_torch_outputs_to_onnx(
|
||||
self,
|
||||
model_outputs: Any,
|
||||
@ -979,8 +1003,9 @@ class ONNXProgram:
|
||||
assert (
|
||||
model_with_state_dict is not None
|
||||
), "model_with_state_dict must be specified."
|
||||
return self._output_adapter.apply(model_outputs, model=model_with_state_dict) # type: ignore[return-value]
|
||||
return self._output_adapter.apply(model_outputs, model=model_with_state_dict)
|
||||
|
||||
@_beartype.beartype
|
||||
def save(
|
||||
self,
|
||||
destination: Union[str, io.BufferedIOBase],
|
||||
@ -1069,6 +1094,7 @@ class ONNXProgram:
|
||||
"External tensor data will be saved alongside the model on disk."
|
||||
) from exc
|
||||
|
||||
@_beartype.beartype
|
||||
def save_diagnostics(self, destination: str) -> None:
|
||||
"""Saves the export diagnostics as a SARIF log to the specified destination path.
|
||||
|
||||
@ -1110,7 +1136,8 @@ class ONNXProgram:
|
||||
# https://github.com/pytorch/pytorch/issues/103764
|
||||
import onnx
|
||||
|
||||
return cls(
|
||||
# TODO: Should we populate ONNXProgram with more info, such _model_torch for easier debug?
|
||||
return ONNXProgram(
|
||||
onnx.ModelProto(), # type: ignore[attr-defined]
|
||||
io_adapter.InputAdapter(),
|
||||
io_adapter.OutputAdapter(),
|
||||
@ -1168,6 +1195,7 @@ class FXGraphExtractor(abc.ABC):
|
||||
|
||||
|
||||
class Exporter:
|
||||
@_beartype.beartype
|
||||
def __init__(
|
||||
self,
|
||||
options: ResolvedExportOptions,
|
||||
@ -1345,6 +1373,7 @@ class InvalidExportOptionsError(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _assert_dependencies(export_options: ResolvedExportOptions):
|
||||
opset_version = export_options.onnx_registry.opset_version
|
||||
|
||||
@ -1386,6 +1415,7 @@ def _assert_dependencies(export_options: ResolvedExportOptions):
|
||||
raise missing_opset("onnxscript")
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def dynamo_export(
|
||||
model: Union[torch.nn.Module, Callable, torch_export.ExportedProgram], # type: ignore[name-defined]
|
||||
/,
|
||||
|
@ -17,7 +17,7 @@ import torch
|
||||
import torch.fx
|
||||
from torch._subclasses import fake_tensor
|
||||
from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode
|
||||
|
||||
from torch.onnx._internal import _beartype
|
||||
from torch.onnx._internal.fx import diagnostics, onnxfunction_dispatcher
|
||||
|
||||
|
||||
@ -128,6 +128,7 @@ def _unified_diff(a: str, b: str) -> str:
|
||||
return diff
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _transform_diagnose_call_message_formatter(
|
||||
run: Callable,
|
||||
self: Transform,
|
||||
@ -311,6 +312,7 @@ class AnalysisResult(abc.ABC): # noqa: B024
|
||||
|
||||
|
||||
class Analysis(abc.ABC):
|
||||
@_beartype.beartype
|
||||
def __init__(
|
||||
self,
|
||||
diagnostic_context: diagnostics.DiagnosticContext,
|
||||
|
@ -9,9 +9,14 @@ import torch
|
||||
import torch._ops
|
||||
import torch.fx
|
||||
|
||||
from torch.onnx._internal import _beartype
|
||||
|
||||
from torch.onnx._internal.fx import registration
|
||||
|
||||
|
||||
# NOTE: OnnxRegistry annotation: beartype is a runtime type checker for python3,
|
||||
# so it doesn't work with TYPE_CHECKING
|
||||
@_beartype.beartype
|
||||
def _create_onnx_supports_op_overload_table(
|
||||
registry,
|
||||
) -> Set[Union[torch._ops.OperatorBase, Callable]]:
|
||||
@ -71,6 +76,7 @@ def _create_onnx_supports_op_overload_table(
|
||||
return table
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def create_onnx_friendly_decomposition_table(
|
||||
registry,
|
||||
) -> Dict[torch._ops.OperatorBase, Callable]:
|
||||
|
@ -26,7 +26,7 @@ import torch._dynamo
|
||||
import torch.export as torch_export
|
||||
import torch.fx
|
||||
import torch.onnx
|
||||
from torch.onnx._internal import exporter, io_adapter
|
||||
from torch.onnx._internal import _beartype, exporter, io_adapter
|
||||
from torch.utils import _pytree as pytree
|
||||
|
||||
|
||||
@ -53,6 +53,7 @@ class _PyTreeExtensionContext:
|
||||
for class_type in self._extensions:
|
||||
pytree.SUPPORTED_NODES.pop(class_type)
|
||||
|
||||
@_beartype.beartype
|
||||
def register_pytree_node(
|
||||
self,
|
||||
class_type: Type,
|
||||
@ -82,11 +83,13 @@ class _PyTreeExtensionContext:
|
||||
except ImportError as e:
|
||||
return
|
||||
|
||||
@_beartype.beartype
|
||||
def model_output_flatten(
|
||||
output: modeling_outputs.ModelOutput,
|
||||
) -> Tuple[List[Any], pytree.Context]:
|
||||
return list(output.values()), (type(output), list(output.keys()))
|
||||
|
||||
@_beartype.beartype
|
||||
def model_output_unflatten(
|
||||
values: List[Any], context: pytree.Context
|
||||
) -> modeling_outputs.ModelOutput:
|
||||
@ -105,7 +108,7 @@ class _PyTreeExtensionContext:
|
||||
|
||||
for _, class_type in named_model_output_classes:
|
||||
self.register_pytree_node(
|
||||
class_type, model_output_flatten, model_output_unflatten # type: ignore[arg-type ]
|
||||
class_type, model_output_flatten, model_output_unflatten
|
||||
)
|
||||
|
||||
|
||||
@ -229,6 +232,7 @@ class DynamoExport(exporter.FXGraphExtractor):
|
||||
|
||||
return self.pre_export_passes(options, model, graph_module, updated_model_args) # type: ignore[return-value]
|
||||
|
||||
@_beartype.beartype
|
||||
def pre_export_passes(
|
||||
self,
|
||||
options: exporter.ResolvedExportOptions,
|
||||
|
@ -16,7 +16,7 @@ from onnxscript.function_libs.torch_lib import ( # type: ignore[import]
|
||||
import torch
|
||||
import torch.fx
|
||||
from torch.onnx import _type_utils as jit_type_utils
|
||||
|
||||
from torch.onnx._internal import _beartype
|
||||
from torch.onnx._internal.fx import (
|
||||
_pass,
|
||||
diagnostics,
|
||||
@ -27,6 +27,7 @@ from torch.onnx._internal.fx import (
|
||||
from torch.utils import _pytree
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _fx_node_to_onnx_message_formatter(
|
||||
fn: Callable,
|
||||
self,
|
||||
@ -37,6 +38,7 @@ def _fx_node_to_onnx_message_formatter(
|
||||
return f"FX Node: {node.op}:{node.target}[name={node.name}]. "
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _fx_graph_to_onnx_message_formatter(
|
||||
fn: Callable,
|
||||
self,
|
||||
@ -85,6 +87,7 @@ def _location_from_fx_stack_trace(
|
||||
return None
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _retrieve_or_adapt_input_to_graph_set(
|
||||
fx_node_arg: fx_type_utils.Argument,
|
||||
fx_name_to_onnxscript_value: Dict[
|
||||
@ -194,7 +197,7 @@ def _retrieve_or_adapt_input_to_graph_set(
|
||||
)
|
||||
return sequence_elements
|
||||
if isinstance(onnx_tensor, torch.dtype):
|
||||
onnx_tensor = int( # type: ignore[call-overload]
|
||||
onnx_tensor = int(
|
||||
jit_type_utils.JitScalarType.from_dtype(onnx_tensor).onnx_type()
|
||||
)
|
||||
# NOTE: if device is specified in kwargs (not consumed), it's free to ignored. But
|
||||
@ -226,11 +229,12 @@ def filter_incompatible_and_dtype_convert_kwargs(kwargs):
|
||||
# default case.
|
||||
continue
|
||||
else:
|
||||
value = int(jit_type_utils.JitScalarType.from_dtype(value).onnx_type()) # type: ignore[call-overload]
|
||||
value = int(jit_type_utils.JitScalarType.from_dtype(value).onnx_type())
|
||||
filtered[key] = value
|
||||
return filtered
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _fill_tensor_shape_type(
|
||||
onnxscript_values: Union[
|
||||
onnxscript_graph_building.TorchScriptTensor,
|
||||
@ -309,6 +313,7 @@ def _fill_tensor_shape_type(
|
||||
onnxscript_value.name = name
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _fill_in_default_kwargs(
|
||||
node: torch.fx.Node,
|
||||
) -> Tuple[List[fx_type_utils.Argument], Dict[str, fx_type_utils.Argument]]:
|
||||
@ -343,6 +348,7 @@ def _fill_in_default_kwargs(
|
||||
return complete_args, complete_kwargs
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _wrap_fx_args_as_onnxscript_args(
|
||||
complete_args: List[fx_type_utils.Argument],
|
||||
complete_kwargs: Dict[str, fx_type_utils.Argument],
|
||||
@ -405,6 +411,7 @@ class FxOnnxInterpreter:
|
||||
# DO NOT add other class-level attributes.
|
||||
self.diagnostic_context = diagnostic_context
|
||||
|
||||
@_beartype.beartype
|
||||
@diagnostics.diagnose_call(
|
||||
diagnostics.rules.fx_node_to_onnx,
|
||||
diagnostic_message_formatter=_fx_node_to_onnx_message_formatter,
|
||||
@ -486,6 +493,7 @@ class FxOnnxInterpreter:
|
||||
else:
|
||||
raise RuntimeError(f"Found node type not defined in torch.fx: {node.op}")
|
||||
|
||||
@_beartype.beartype
|
||||
@diagnostics.diagnose_call(
|
||||
diagnostics.rules.fx_graph_to_onnx,
|
||||
diagnostic_message_formatter=_fx_graph_to_onnx_message_formatter,
|
||||
@ -581,6 +589,7 @@ class FxOnnxInterpreter:
|
||||
|
||||
return onnxscript_graph
|
||||
|
||||
@_beartype.beartype
|
||||
def placeholder(
|
||||
self,
|
||||
node: torch.fx.Node,
|
||||
@ -636,6 +645,7 @@ class FxOnnxInterpreter:
|
||||
|
||||
fx_name_to_onnxscript_value[node.name] = output
|
||||
|
||||
@_beartype.beartype
|
||||
def call_function(
|
||||
self,
|
||||
node: torch.fx.Node,
|
||||
@ -720,6 +730,7 @@ class FxOnnxInterpreter:
|
||||
)
|
||||
fx_name_to_onnxscript_value[node.name] = output
|
||||
|
||||
@_beartype.beartype
|
||||
def output(
|
||||
self,
|
||||
node: torch.fx.Node,
|
||||
@ -746,10 +757,12 @@ class FxOnnxInterpreter:
|
||||
onnx_tensor_or_tensor_tuple = fx_name_to_onnxscript_value[arg.name]
|
||||
onnxscript_graph.register_outputs(onnx_tensor_or_tensor_tuple)
|
||||
|
||||
@_beartype.beartype
|
||||
def call_method(self, node: torch.fx.Node):
|
||||
# TODO(wechi): Support call_method.
|
||||
raise RuntimeError("call_method is not supported yet.")
|
||||
|
||||
@_beartype.beartype
|
||||
def call_module(
|
||||
self,
|
||||
node: torch.fx.Node,
|
||||
@ -827,6 +840,7 @@ class FxOnnxInterpreter:
|
||||
|
||||
# Skip op_level_validation for call_module. Subgraph nodes are validated individually.
|
||||
|
||||
@_beartype.beartype
|
||||
def get_attr(
|
||||
self,
|
||||
node: torch.fx.Node,
|
||||
|
@ -10,7 +10,7 @@ import torch.fx
|
||||
import torch.onnx
|
||||
|
||||
import torch.onnx._internal.fx.passes as passes
|
||||
from torch.onnx._internal import exporter, io_adapter
|
||||
from torch.onnx._internal import _beartype, exporter, io_adapter
|
||||
|
||||
# Functions directly wrapped to produce torch.fx.Proxy so that symbolic
|
||||
# data can flow through those functions. Python functions (e.g., `torch.arange`)
|
||||
@ -37,6 +37,7 @@ class ModuleExpansionTracer(torch.fx._symbolic_trace.Tracer):
|
||||
exporter.
|
||||
"""
|
||||
|
||||
@_beartype.beartype
|
||||
def is_leaf_module(
|
||||
self, module: torch.nn.Module, module_qualified_name: str
|
||||
) -> bool:
|
||||
@ -45,6 +46,7 @@ class ModuleExpansionTracer(torch.fx._symbolic_trace.Tracer):
|
||||
# torch.fx._symbolic_trace.Tracer.call_module.
|
||||
return False
|
||||
|
||||
@_beartype.beartype
|
||||
def to_bool(self, obj: torch.fx.Proxy) -> bool:
|
||||
# FIXME: This is a hack to tracing through if-else Python blocks.
|
||||
# It may generate incorrect ONNX graphs if the if-else block
|
||||
@ -80,6 +82,7 @@ def _wrap_for_symbolic_trace(target: Callable) -> Tuple[Callable, Callable]:
|
||||
return wrapper, target
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _module_expansion_symbolic_trace(
|
||||
root: Union[torch.nn.Module, Callable[..., Any]],
|
||||
concrete_args: Optional[Dict[str, Any]] = None,
|
||||
@ -157,6 +160,7 @@ class FXSymbolicTracer(exporter.FXGraphExtractor):
|
||||
# TODO: plumb ``concrete_args`` to symbolic_trace call at ``generate_fx``
|
||||
self.concrete_args = concrete_args
|
||||
|
||||
@_beartype.beartype
|
||||
def _trace_into_fx_graph_via_fx_symbolic_trace(
|
||||
self, model, model_args, model_kwargs
|
||||
) -> torch.fx.GraphModule:
|
||||
@ -234,6 +238,7 @@ class FXSymbolicTracer(exporter.FXGraphExtractor):
|
||||
|
||||
return self.pre_export_passes(options, model, graph_module, updated_model_args) # type: ignore[return-value]
|
||||
|
||||
@_beartype.beartype
|
||||
def pre_export_passes(
|
||||
self,
|
||||
options: exporter.ResolvedExportOptions,
|
||||
|
@ -22,7 +22,7 @@ from typing import (
|
||||
import torch
|
||||
import torch._ops
|
||||
import torch.fx
|
||||
|
||||
from torch.onnx._internal import _beartype
|
||||
from torch.onnx._internal.fx import (
|
||||
diagnostics,
|
||||
registration,
|
||||
@ -31,13 +31,17 @@ from torch.onnx._internal.fx import (
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import onnxscript # type: ignore[import]
|
||||
from onnxscript.function_libs.torch_lib import ( # type: ignore[import]
|
||||
graph_building as onnxscript_graph_building,
|
||||
)
|
||||
|
||||
from torch.onnx import OnnxRegistry
|
||||
|
||||
|
||||
# For beartype
|
||||
from onnxscript.function_libs.torch_lib import ( # type: ignore[import]
|
||||
graph_building as onnxscript_graph_building,
|
||||
)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _find_opschema_matched_symbolic_function_disagnostic_message_formatter(
|
||||
fn: Callable,
|
||||
self,
|
||||
@ -54,6 +58,7 @@ def _find_opschema_matched_symbolic_function_disagnostic_message_formatter(
|
||||
return f"FX Node: {node.target}. \n" f"{all_function_overload_names}"
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _find_operator_overloads_in_onnx_registry_disagnostic_message_formatter(
|
||||
fn: Callable,
|
||||
self,
|
||||
@ -92,7 +97,7 @@ class OnnxFunctionDispatcher:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
onnx_registry: OnnxRegistry,
|
||||
onnx_registry: "OnnxRegistry",
|
||||
diagnostic_context: diagnostics.DiagnosticContext,
|
||||
):
|
||||
"""Initialize the ONNX Function dispatcher.
|
||||
@ -104,6 +109,7 @@ class OnnxFunctionDispatcher:
|
||||
self.onnx_registry = onnx_registry
|
||||
self.diagnostic_context = diagnostic_context
|
||||
|
||||
@_beartype.beartype
|
||||
def dispatch(
|
||||
self,
|
||||
node: torch.fx.Node,
|
||||
@ -114,7 +120,7 @@ class OnnxFunctionDispatcher:
|
||||
],
|
||||
onnx_kwargs: Dict[str, fx_type_utils.Argument],
|
||||
diagnostic_context: diagnostics.DiagnosticContext,
|
||||
) -> Union[onnxscript.OnnxFunction, onnxscript.TracedOnnxFunction]:
|
||||
) -> Union["onnxscript.OnnxFunction", "onnxscript.TracedOnnxFunction"]:
|
||||
"""Dispatches an ONNX function based on the given FX node, arguments, and keyword arguments.
|
||||
Args:
|
||||
node: The TorchFX node to dispatch the function for.
|
||||
@ -142,6 +148,7 @@ class OnnxFunctionDispatcher:
|
||||
diagnostic_context,
|
||||
)
|
||||
|
||||
@_beartype.beartype
|
||||
def _filter_or_keep_complex(
|
||||
self,
|
||||
node,
|
||||
@ -189,6 +196,7 @@ class OnnxFunctionDispatcher:
|
||||
raise diagnostics.RuntimeErrorWithDiagnostic(diagnostic)
|
||||
return default_and_custom_functions
|
||||
|
||||
@_beartype.beartype
|
||||
@diagnostics.diagnose_call(
|
||||
diagnostics.rules.find_opschema_matched_symbolic_function,
|
||||
diagnostic_message_formatter=_find_opschema_matched_symbolic_function_disagnostic_message_formatter,
|
||||
@ -276,6 +284,7 @@ class OnnxFunctionDispatcher:
|
||||
)
|
||||
return symbolic_function_list[0].onnx_function
|
||||
|
||||
@_beartype.beartype
|
||||
def _get_aten_name(
|
||||
self, node: torch.fx.Node, diagnostic_context: diagnostics.DiagnosticContext
|
||||
) -> registration.OpName:
|
||||
@ -341,6 +350,7 @@ class OnnxFunctionDispatcher:
|
||||
diagnostic_context.log(diagnostic)
|
||||
raise diagnostics.RuntimeErrorWithDiagnostic(diagnostic)
|
||||
|
||||
@_beartype.beartype
|
||||
@diagnostics.diagnose_call(
|
||||
diagnostics.rules.find_operator_overloads_in_onnx_registry,
|
||||
diagnostic_message_formatter=_find_operator_overloads_in_onnx_registry_disagnostic_message_formatter,
|
||||
@ -542,6 +552,7 @@ class _OnnxSchemaChecker:
|
||||
"""
|
||||
return self._matching_score
|
||||
|
||||
@_beartype.beartype
|
||||
def perfect_match_inputs(
|
||||
self,
|
||||
diagnostic: diagnostics.Diagnostic,
|
||||
@ -676,6 +687,7 @@ class _OnnxSchemaChecker:
|
||||
diagnostic.info("match score: %d", self.match_score)
|
||||
return is_perfect_match
|
||||
|
||||
@_beartype.beartype
|
||||
def _match_onnx_attribute_type(
|
||||
self,
|
||||
attribute_name: str,
|
||||
@ -701,6 +713,7 @@ class _OnnxSchemaChecker:
|
||||
return False
|
||||
return True
|
||||
|
||||
@_beartype.beartype
|
||||
def _record_matching_score(
|
||||
self,
|
||||
inputs: Sequence[
|
||||
@ -755,10 +768,10 @@ class _OnnxSchemaChecker:
|
||||
|
||||
# NOTE: Referenced from onnxscript internal function.
|
||||
# Importing this function makes the code less robust, as it is not a public API.
|
||||
|
||||
@_beartype.beartype
|
||||
def _separate_input_attributes_from_arguments(
|
||||
self,
|
||||
param_schemas: Sequence[onnxscript.values.ParamSchema],
|
||||
param_schemas: Sequence["onnxscript.values.ParamSchema"],
|
||||
args: Sequence[
|
||||
Optional[
|
||||
Union[fx_type_utils.TensorLike, str, int, float, bool, list, complex]
|
||||
@ -838,6 +851,7 @@ class _OnnxSchemaChecker:
|
||||
return onnx_inputs, onnx_attributes
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _is_arg_with_complex_dtype(arg: fx_type_utils.Argument) -> bool:
|
||||
"""Check if the node has complex dtype recursively."""
|
||||
if (
|
||||
@ -853,6 +867,7 @@ def _is_arg_with_complex_dtype(arg: fx_type_utils.Argument) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _find_onnx_data_type(
|
||||
torch_input: Optional[
|
||||
Union[fx_type_utils.TensorLike, str, int, float, bool, list, tuple, complex]
|
||||
|
@ -13,7 +13,7 @@ import torch
|
||||
import torch.fx
|
||||
from torch.fx.experimental import symbolic_shapes
|
||||
from torch.onnx import _constants, _type_utils as jit_type_utils
|
||||
|
||||
from torch.onnx._internal import _beartype
|
||||
from torch.onnx._internal.fx import (
|
||||
diagnostics,
|
||||
fx_onnx_interpreter,
|
||||
@ -22,6 +22,7 @@ from torch.onnx._internal.fx import (
|
||||
from torch.utils import _pytree
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _op_level_debug_message_formatter(
|
||||
fn: Callable,
|
||||
self,
|
||||
@ -36,6 +37,7 @@ def _op_level_debug_message_formatter(
|
||||
)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
@diagnostics.diagnose_call(
|
||||
diagnostics.rules.op_level_debugging,
|
||||
diagnostic_message_formatter=_op_level_debug_message_formatter,
|
||||
@ -172,6 +174,7 @@ def validate_op_between_ort_torch(
|
||||
diagnostic.level = diagnostics.levels.WARNING
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _convert_symint_to_int_in_shape(shape: torch.Size) -> torch.Size:
|
||||
"""Convert SymInt to int in shape
|
||||
|
||||
@ -198,6 +201,7 @@ def _convert_symint_to_int_in_shape(shape: torch.Size) -> torch.Size:
|
||||
return torch.Size(list_int_shape)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def generate_random_tensors(shape: torch.Size, dtype: torch.dtype):
|
||||
shape = _convert_symint_to_int_in_shape(shape)
|
||||
|
||||
@ -234,6 +238,7 @@ def generate_random_tensors(shape: torch.Size, dtype: torch.dtype):
|
||||
return torch.randn(shape, dtype=dtype)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _fx_args_to_torch_args(
|
||||
fx_args: List[fx_type_utils.Argument], fx_graph_module: torch.fx.GraphModule
|
||||
) -> List[fx_type_utils.Argument]:
|
||||
@ -263,7 +268,7 @@ def _fx_args_to_torch_args(
|
||||
f"{type(fake_tensor)}."
|
||||
)
|
||||
elif isinstance(arg, Sequence):
|
||||
wrapped_args.append(_fx_args_to_torch_args(arg, fx_graph_module)) # type: ignore[arg-type]
|
||||
wrapped_args.append(_fx_args_to_torch_args(arg, fx_graph_module))
|
||||
elif isinstance(arg, (int, float, torch.dtype)) or arg is None:
|
||||
wrapped_args.append(arg)
|
||||
elif isinstance(arg, torch.device):
|
||||
@ -276,6 +281,7 @@ def _fx_args_to_torch_args(
|
||||
return wrapped_args
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _wrap_fx_args_as_torch_args(
|
||||
fx_args: List[fx_type_utils.Argument],
|
||||
fx_kwargs: Dict[str, fx_type_utils.Argument],
|
||||
@ -291,6 +297,7 @@ def _wrap_fx_args_as_torch_args(
|
||||
|
||||
|
||||
# NOTE: Referenced from onnxscript internal function: _tag_arguments_with_param_schemas.
|
||||
@_beartype.beartype
|
||||
def _convert_torch_args_to_onnxfunction_args(
|
||||
param_schemas: Sequence[onnxscript.values.ParamSchema],
|
||||
args: List[fx_type_utils.Argument],
|
||||
@ -351,6 +358,7 @@ def _convert_torch_args_to_onnxfunction_args(
|
||||
return tagged_args, tagged_kwargs
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _convert_tensor_to_numpy(input: fx_type_utils.Argument) -> Any:
|
||||
try:
|
||||
import numpy as np
|
||||
@ -365,7 +373,7 @@ def _convert_tensor_to_numpy(input: fx_type_utils.Argument) -> Any:
|
||||
input = torch.view_as_real(input.resolve_conj())
|
||||
return input.detach().cpu().numpy()
|
||||
if isinstance(input, torch.dtype):
|
||||
return int(jit_type_utils.JitScalarType.from_dtype(input).onnx_type()) # type: ignore[union-attr,call-overload]
|
||||
return int(jit_type_utils.JitScalarType.from_dtype(input).onnx_type()) # type: ignore[union-attr]
|
||||
if isinstance(input, (tuple, list)):
|
||||
if len(input) == 0:
|
||||
return np.array((), dtype=np.int64)
|
||||
|
@ -13,8 +13,10 @@ from typing import Callable, Dict, Optional, Tuple
|
||||
|
||||
import torch.fx
|
||||
import torch.fx.traceback as fx_traceback
|
||||
from torch.onnx._internal import _beartype
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def wrap_graph_module_for_node_meta_preservation(
|
||||
graph_module: torch.fx.GraphModule,
|
||||
) -> Callable:
|
||||
@ -40,6 +42,7 @@ def _get_node_base_name(node_name: str) -> Tuple[str, Optional[int]]:
|
||||
return node_name, None
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def set_node_name(
|
||||
node: torch.fx.Node,
|
||||
new_name: str,
|
||||
@ -78,6 +81,7 @@ def set_node_name(
|
||||
name_to_node_cache[new_name] = node
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def replace_placeholder_name_and_target(
|
||||
module: torch.fx.GraphModule, reference_module: torch.fx.GraphModule
|
||||
):
|
||||
|
@ -11,6 +11,7 @@ import torch.fx
|
||||
from torch._dispatch import python as python_dispatch
|
||||
from torch._subclasses import fake_tensor
|
||||
from torch.fx.experimental import proxy_tensor
|
||||
from torch.onnx._internal import _beartype
|
||||
from torch.onnx._internal.fx import _pass, diagnostics
|
||||
from torch.onnx._internal.fx.passes import _utils
|
||||
|
||||
@ -29,6 +30,7 @@ class Decompose(_pass.Transform):
|
||||
self.enable_dynamic_axes = enable_dynamic_axes
|
||||
self.allow_fake_constant = allow_fake_constant
|
||||
|
||||
@_beartype.beartype
|
||||
def _run(self, *args, **kwargs) -> torch.fx.GraphModule:
|
||||
assert not kwargs, "kwargs is not supported in Decompose."
|
||||
|
||||
|
@ -11,6 +11,7 @@ import torch.func
|
||||
import torch.fx
|
||||
from torch._subclasses import fake_tensor
|
||||
from torch.fx.experimental import proxy_tensor
|
||||
from torch.onnx._internal import _beartype
|
||||
from torch.onnx._internal.fx import _pass, diagnostics
|
||||
from torch.onnx._internal.fx.passes import _utils
|
||||
from torch.utils import _pytree as pytree
|
||||
@ -61,6 +62,7 @@ class Functionalize(_pass.Transform):
|
||||
which are not needed for ONNX inference.
|
||||
"""
|
||||
|
||||
@_beartype.beartype
|
||||
def __init__(
|
||||
self,
|
||||
diagnostic_context: diagnostics.DiagnosticContext,
|
||||
@ -97,6 +99,7 @@ class Functionalize(_pass.Transform):
|
||||
|
||||
return wrapped
|
||||
|
||||
@_beartype.beartype
|
||||
def _run(self, *args) -> torch.fx.GraphModule:
|
||||
# To preserve stack trace info after `make_fx`.
|
||||
module = _utils.wrap_graph_module_for_node_meta_preservation(self.module)
|
||||
@ -142,6 +145,7 @@ class RemoveInputMutation(_pass.Transform):
|
||||
for inference. They could be useful for training.
|
||||
"""
|
||||
|
||||
@_beartype.beartype
|
||||
def _run(self, *args) -> torch.fx.GraphModule:
|
||||
for node in reversed(self.module.graph.nodes):
|
||||
if (
|
||||
|
@ -23,6 +23,7 @@ from typing import (
|
||||
|
||||
import torch
|
||||
import torch.fx
|
||||
from torch.onnx._internal import _beartype
|
||||
|
||||
from torch.onnx._internal.fx import _pass, diagnostics
|
||||
from torch.utils import _pytree as pytree
|
||||
@ -58,6 +59,7 @@ class _ModuleMeta:
|
||||
_module_name: Final[str] # type: ignore[misc]
|
||||
_raw_meta: Final[Tuple[Any, Any]] # type: ignore[misc]
|
||||
|
||||
@_beartype.beartype
|
||||
def __init__(
|
||||
self,
|
||||
module_name: str,
|
||||
@ -211,6 +213,7 @@ class _ModuleStackMeta:
|
||||
|
||||
_module_stack: Final[List[_ModuleMeta]] # type: ignore[misc]
|
||||
|
||||
@_beartype.beartype
|
||||
def __init__(
|
||||
self,
|
||||
nn_module_stack_meta: Optional[
|
||||
@ -230,7 +233,7 @@ class _ModuleStackMeta:
|
||||
if is_exported_program:
|
||||
is_exported_program = False
|
||||
continue
|
||||
self.push(_ModuleMeta.from_raw_meta(item)) # type: ignore[arg-type]
|
||||
self.push(_ModuleMeta.from_raw_meta(item))
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._module_stack)
|
||||
@ -259,6 +262,7 @@ class _ModuleStackMeta:
|
||||
return _ModuleMeta.create_root()
|
||||
return self._module_stack[-1]
|
||||
|
||||
@_beartype.beartype
|
||||
def is_superset_of(
|
||||
self,
|
||||
module_stack: _ModuleStackMeta,
|
||||
@ -302,6 +306,7 @@ class _ModuleStackMeta:
|
||||
"""Pushes a module meta to the stack."""
|
||||
self._module_stack.append(module_meta)
|
||||
|
||||
@_beartype.beartype
|
||||
def __eq__(self, __value: object) -> bool:
|
||||
if not isinstance(__value, _ModuleStackMeta):
|
||||
return False
|
||||
@ -838,6 +843,7 @@ class Modularize(_pass.Transform):
|
||||
|
||||
"""
|
||||
|
||||
@_beartype.beartype
|
||||
def __init__(
|
||||
self,
|
||||
diagnostic_context: diagnostics.DiagnosticContext,
|
||||
@ -848,6 +854,7 @@ class Modularize(_pass.Transform):
|
||||
self.module = module
|
||||
self.is_exported_program = is_exported_program
|
||||
|
||||
@_beartype.beartype
|
||||
def _run(self) -> torch.fx.GraphModule:
|
||||
# DCE to remove unused nodes.
|
||||
# If a submodule is unused, it is hard to analyze which nodes constitutes the submodule
|
||||
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
from typing import Dict, List, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from torch.onnx._internal import _beartype
|
||||
from torch.onnx._internal.fx import _pass, diagnostics
|
||||
|
||||
|
||||
@ -30,6 +30,7 @@ class RestoreParameterAndBufferNames(_pass.Transform):
|
||||
super().__init__(diagnostic_context, fx_module)
|
||||
self.original_nn_module = original_nn_module
|
||||
|
||||
@_beartype.beartype
|
||||
def _rename_param_and_buffer(
|
||||
self,
|
||||
diagnostic: diagnostics.Diagnostic,
|
||||
|
@ -27,6 +27,10 @@ from torch._refs.nn import functional as _functional_refs
|
||||
from torch._subclasses import fake_tensor
|
||||
from torch.fx.experimental import proxy_tensor
|
||||
|
||||
# Imported to resolve beartype issue when type checking node.Argument.
|
||||
from torch.fx.node import Node # noqa: F401
|
||||
|
||||
from torch.onnx._internal import _beartype
|
||||
from torch.onnx._internal.fx import _pass, diagnostics, type_utils as fx_type_utils
|
||||
from torch.utils import _python_dispatch, _pytree
|
||||
|
||||
@ -1225,6 +1229,7 @@ class TypePromotionTable:
|
||||
for rule in _EXTRA_TYPE_PROMOTION_RULE_SET:
|
||||
self.add_rule(rule)
|
||||
|
||||
@_beartype.beartype
|
||||
def add_rule(self, rule: TypePromotionRule) -> None:
|
||||
"""Add a type promotion rule for a python op in a torch.ops module.
|
||||
|
||||
@ -1239,6 +1244,7 @@ class TypePromotionTable:
|
||||
raise ValueError(f"Invalid type promotion rule: {rule}")
|
||||
self._rule_table[f"{rule.namespace}.{rule.op_name}"] = rule
|
||||
|
||||
@_beartype.beartype
|
||||
def get_rule(
|
||||
self, py_op: torch._ops.OpOverloadPacket
|
||||
) -> Optional[TypePromotionRule]:
|
||||
@ -1246,6 +1252,7 @@ class TypePromotionTable:
|
||||
return self._rule_table.get(str(py_op), None)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def get_type_promotion_rule(
|
||||
diagnostic: diagnostics.Diagnostic,
|
||||
node: torch.fx.Node,
|
||||
@ -1298,6 +1305,7 @@ class _OpTraceDispatchMode(_python_dispatch.TorchDispatchMode):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def find_compatible_op_overload(
|
||||
op: torch._ops.OpOverloadPacket, args: tuple, kwargs: dict
|
||||
) -> torch._ops.OpOverload:
|
||||
@ -1383,6 +1391,7 @@ class _TypePromotionInterpreter(torch.fx.Interpreter):
|
||||
node.meta["val"] = proxy_tensor.extract_val(out)
|
||||
return out
|
||||
|
||||
@_beartype.beartype
|
||||
def _create_node(
|
||||
self,
|
||||
graph: torch.fx.Graph,
|
||||
@ -1404,6 +1413,7 @@ class _TypePromotionInterpreter(torch.fx.Interpreter):
|
||||
self._run_node_and_set_meta(node)
|
||||
return node
|
||||
|
||||
@_beartype.beartype
|
||||
def _rerun_node_after_type_promotion(
|
||||
self,
|
||||
diagnostic: diagnostics.Diagnostic,
|
||||
@ -1469,6 +1479,7 @@ class _TypePromotionInterpreter(torch.fx.Interpreter):
|
||||
else:
|
||||
raise RuntimeError(f"Unexpected node output type: {type(node_val)}.")
|
||||
|
||||
@_beartype.beartype
|
||||
def _maybe_promote_arg(
|
||||
self,
|
||||
diagnostic: diagnostics.Diagnostic,
|
||||
@ -1574,6 +1585,7 @@ class _TypePromotionInterpreter(torch.fx.Interpreter):
|
||||
|
||||
raise NotImplementedError(f"Unknown fx arg type: {type(fx_arg)}")
|
||||
|
||||
@_beartype.beartype
|
||||
def _maybe_promote_node(
|
||||
self,
|
||||
diagnostic: diagnostics.Diagnostic,
|
||||
@ -1708,6 +1720,7 @@ class InsertTypePromotion(_pass.Transform):
|
||||
fake_args.append(meta_value)
|
||||
return fake_args
|
||||
|
||||
@_beartype.beartype
|
||||
def _run(self, *args, **kwargs) -> torch.fx.GraphModule:
|
||||
assert not args, (
|
||||
"`InsertTypePromotion` deduces symbolic fake arguments from the graph. "
|
||||
|
@ -6,6 +6,7 @@ from typing import List, Optional, Tuple
|
||||
import torch
|
||||
import torch.fx
|
||||
|
||||
from torch.onnx._internal import _beartype
|
||||
from torch.onnx._internal.fx import _pass
|
||||
|
||||
|
||||
@ -17,6 +18,7 @@ class MovePlaceholderToFront(_pass.Transform):
|
||||
nodes.
|
||||
"""
|
||||
|
||||
@_beartype.beartype
|
||||
def _run(self, *args, **kwargs) -> torch.fx.GraphModule:
|
||||
graph_module = self.module
|
||||
graph = graph_module.graph
|
||||
@ -51,6 +53,7 @@ class ReplaceGetAttrWithPlaceholder(_pass.Transform):
|
||||
), "Must run ReplaceGetAttrWithPlaceholder first"
|
||||
return self._replaced_attrs
|
||||
|
||||
@_beartype.beartype
|
||||
def _run(self, *args, **kwargs) -> torch.fx.GraphModule:
|
||||
graph_module = self.module
|
||||
graph = graph_module.graph
|
||||
|
@ -1,13 +1,11 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import copy
|
||||
import functools
|
||||
from typing import List, TYPE_CHECKING, Union
|
||||
import io
|
||||
from typing import List, Union
|
||||
|
||||
import torch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import io
|
||||
|
||||
|
||||
# TODO: Remove after https://github.com/huggingface/safetensors/pull/318
|
||||
@functools.lru_cache(None)
|
||||
|
@ -7,6 +7,7 @@ import types
|
||||
from typing import Optional, TYPE_CHECKING, Union
|
||||
|
||||
import torch._ops
|
||||
from torch.onnx._internal import _beartype
|
||||
|
||||
# We can only import onnx from this module in a type-checking context to ensure that
|
||||
# 'import torch.onnx' continues to work without having 'onnx' installed. We fully
|
||||
@ -26,7 +27,7 @@ class ONNXFunction:
|
||||
|
||||
"""
|
||||
|
||||
onnx_function: Union[onnxscript.OnnxFunction, onnxscript.TracedOnnxFunction]
|
||||
onnx_function: Union["onnxscript.OnnxFunction", "onnxscript.TracedOnnxFunction"]
|
||||
op_full_name: str
|
||||
is_custom: bool = False
|
||||
is_complex: bool = False
|
||||
@ -41,6 +42,7 @@ class OpName:
|
||||
overload: str
|
||||
|
||||
@classmethod
|
||||
@_beartype.beartype
|
||||
def from_name_parts(
|
||||
cls, namespace: str, op_name: str, overload: Optional[str] = None
|
||||
) -> OpName:
|
||||
@ -51,6 +53,7 @@ class OpName:
|
||||
return cls(namespace, op_name, overload)
|
||||
|
||||
@classmethod
|
||||
@_beartype.beartype
|
||||
def from_qualified_name(cls, qualified_name: str) -> OpName:
|
||||
"""When the name is <namespace>::<op_name>[.<overload>]"""
|
||||
namespace, opname_overload = qualified_name.split("::")
|
||||
@ -59,10 +62,12 @@ class OpName:
|
||||
return cls(namespace, op_name, overload)
|
||||
|
||||
@classmethod
|
||||
@_beartype.beartype
|
||||
def from_op_overload(cls, op_overload: torch._ops.OpOverload) -> OpName:
|
||||
return cls.from_qualified_name(op_overload.name())
|
||||
|
||||
@classmethod
|
||||
@_beartype.beartype
|
||||
def from_builtin_function(
|
||||
cls, builtin_function: types.BuiltinFunctionType
|
||||
) -> OpName:
|
||||
@ -81,5 +86,6 @@ class OpName:
|
||||
module = builtin_function.__module__ # _operators or math
|
||||
return cls.from_qualified_name(module + "::" + op)
|
||||
|
||||
@_beartype.beartype
|
||||
def qualified_name(self) -> str:
|
||||
return f"{self.namespace}::{self.op_name}.{self.overload}"
|
||||
|
@ -8,7 +8,7 @@ from typing import Optional, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
from torch.onnx import _type_utils as jit_type_utils
|
||||
|
||||
from torch.onnx._internal import _beartype
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import onnx
|
||||
@ -16,12 +16,13 @@ if TYPE_CHECKING:
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _create_tensor_proto_with_external_data(
|
||||
tensor: torch.Tensor,
|
||||
name: str,
|
||||
location: str,
|
||||
basepath: str,
|
||||
dtype_override: Optional[onnx.TypeProto] = None, # type: ignore[name-defined]
|
||||
dtype_override: Optional["onnx.TypeProto"] = None, # type: ignore[name-defined]
|
||||
) -> onnx.TensorProto: # type: ignore[name-defined]
|
||||
"""Create a TensorProto with external data from a PyTorch tensor.
|
||||
The external data is saved to os.path.join(basepath, location).
|
||||
@ -113,6 +114,7 @@ def _convert_safetensors_to_torch_format(safetensors_file):
|
||||
|
||||
|
||||
# TODO: generalize to allow more checkpoints formats (torch or gguf)
|
||||
@_beartype.beartype
|
||||
def save_model_with_external_data(
|
||||
basepath: str,
|
||||
model_location: str,
|
||||
|
@ -11,7 +11,7 @@ from typing import Any, Callable, Mapping, Optional, Sequence, TYPE_CHECKING, Un
|
||||
import torch._dynamo
|
||||
import torch.fx
|
||||
import torch.onnx
|
||||
from torch.onnx._internal import exporter, io_adapter
|
||||
from torch.onnx._internal import _beartype, exporter, io_adapter
|
||||
from torch.onnx._internal.diagnostics import infra
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -35,7 +35,7 @@ class TorchExport(exporter.FXGraphExtractor):
|
||||
def generate_fx(
|
||||
self,
|
||||
options: exporter.ResolvedExportOptions,
|
||||
model: ExportedProgram, # type: ignore[override]
|
||||
model: "ExportedProgram", # type: ignore[override]
|
||||
model_args: Sequence[Any],
|
||||
model_kwargs: Mapping[str, Any],
|
||||
) -> torch.fx.GraphModule:
|
||||
@ -96,6 +96,7 @@ class TorchExport(exporter.FXGraphExtractor):
|
||||
# Export FX graph to ONNX ModelProto.
|
||||
return self.pre_export_passes(options, model, model.graph_module, updated_model_args) # type: ignore[return-value]
|
||||
|
||||
@_beartype.beartype
|
||||
def pre_export_passes(
|
||||
self,
|
||||
options: exporter.ResolvedExportOptions,
|
||||
|
@ -1,6 +1,8 @@
|
||||
# mypy: allow-untyped-defs
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
@ -11,18 +13,15 @@ from typing import (
|
||||
runtime_checkable,
|
||||
Sequence,
|
||||
Tuple,
|
||||
TYPE_CHECKING,
|
||||
Union,
|
||||
)
|
||||
|
||||
import torch
|
||||
import torch.export as torch_export
|
||||
|
||||
from torch.onnx._internal import _beartype
|
||||
from torch.utils import _pytree as pytree
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import inspect
|
||||
|
||||
# TODO(bowbao): Add diagnostics for IO adapters.
|
||||
|
||||
|
||||
@ -56,6 +55,7 @@ class InputAdapter:
|
||||
def __init__(self, steps: Optional[List[InputAdaptStep]] = None):
|
||||
self._steps = steps or []
|
||||
|
||||
@_beartype.beartype
|
||||
def append_step(self, step: InputAdaptStep) -> None:
|
||||
"""Appends a step to the input adapt steps.
|
||||
|
||||
@ -64,6 +64,7 @@ class InputAdapter:
|
||||
"""
|
||||
self._steps.append(step)
|
||||
|
||||
@_beartype.beartype
|
||||
def apply(
|
||||
self,
|
||||
*model_args,
|
||||
@ -71,7 +72,7 @@ class InputAdapter:
|
||||
Union[torch.nn.Module, Callable, torch_export.ExportedProgram]
|
||||
] = None,
|
||||
**model_kwargs,
|
||||
) -> Sequence[Union[int, float, bool, str, torch.Tensor, torch.dtype, None]]:
|
||||
) -> Sequence[Union[int, float, bool, str, "torch.Tensor", torch.dtype, None]]:
|
||||
"""Converts the PyTorch model inputs to exported ONNX model inputs format.
|
||||
|
||||
Args:
|
||||
@ -118,6 +119,7 @@ class OutputAdapter:
|
||||
def __init__(self, steps: Optional[List[OutputAdaptStep]] = None):
|
||||
self._steps = steps or []
|
||||
|
||||
@_beartype.beartype
|
||||
def append_step(self, step: OutputAdaptStep) -> None:
|
||||
"""Appends a step to the output format steps.
|
||||
|
||||
@ -126,13 +128,14 @@ class OutputAdapter:
|
||||
"""
|
||||
self._steps.append(step)
|
||||
|
||||
@_beartype.beartype
|
||||
def apply(
|
||||
self,
|
||||
model_outputs: Any,
|
||||
model: Optional[
|
||||
Union[torch.nn.Module, Callable, torch_export.ExportedProgram]
|
||||
] = None,
|
||||
) -> Sequence[Union[torch.Tensor, int, float, bool, str]]:
|
||||
) -> Sequence[Union["torch.Tensor", int, float, bool, str]]:
|
||||
"""Converts the PyTorch model outputs to exported ONNX model outputs format.
|
||||
|
||||
Args:
|
||||
@ -260,7 +263,7 @@ class MergeKwargsIntoArgsInputStep(InputAdaptStep):
|
||||
class LiftParametersAndBuffersIntoArgsInputStep(InputAdaptStep):
|
||||
"""Append parameters and buffers to model's positional argument list."""
|
||||
|
||||
def __init__(self, inputs: Tuple[torch.Tensor, ...]) -> None:
|
||||
def __init__(self, inputs: Tuple["torch.Tensor", ...]) -> None:
|
||||
self.inputs = inputs
|
||||
|
||||
def apply(
|
||||
|
@ -13,7 +13,7 @@ from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Tuple, Un
|
||||
import torch
|
||||
from torch import _C
|
||||
from torch.onnx._globals import GLOBALS
|
||||
from torch.onnx._internal import registration
|
||||
from torch.onnx._internal import _beartype, registration
|
||||
|
||||
|
||||
_ATTR_PATTERN = re.compile("^(.+)_(([ifstgz])|(ty))$")
|
||||
@ -43,7 +43,7 @@ class GraphContext:
|
||||
block: _C.Block
|
||||
opset: int
|
||||
original_node: _C.Node
|
||||
params_dict: Dict[str, _C.IValue]
|
||||
params_dict: Dict[str, "_C.IValue"]
|
||||
env: Dict[_C.Value, _C.Value]
|
||||
values_in_env: Set[_C.Value]
|
||||
new_nodes: List[_C.Node] = dataclasses.field(default_factory=list)
|
||||
@ -53,6 +53,7 @@ class GraphContext:
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
return getattr(self.graph, name)
|
||||
|
||||
@_beartype.beartype
|
||||
def op(
|
||||
self,
|
||||
opname: str,
|
||||
@ -90,6 +91,7 @@ class GraphContext:
|
||||
# FIXME(justinchuby): Add the return type back once we know how to handle mypy
|
||||
return _add_op(self, opname, *raw_args, outputs=outputs, **kwargs)
|
||||
|
||||
@_beartype.beartype
|
||||
def aten_op(self, operator: str, *args, overload_name: str = "", **kwargs):
|
||||
"""Generates an ONNX ATen op node.
|
||||
|
||||
@ -107,6 +109,7 @@ class GraphContext:
|
||||
# We are probably going to remove this only after the fx exporter is established.
|
||||
at = aten_op
|
||||
|
||||
@_beartype.beartype
|
||||
def onnxscript_op(
|
||||
self,
|
||||
onnx_fn,
|
||||
@ -150,6 +153,7 @@ class GraphContext:
|
||||
return _add_op(self, symbolic_name, *raw_args, outputs=outputs, **kwargs)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def add_op_with_blocks(
|
||||
graph_context: GraphContext,
|
||||
opname: str,
|
||||
@ -197,6 +201,7 @@ def add_op_with_blocks(
|
||||
return output_values, tuple(new_contexts), node
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _add_op(
|
||||
graph_context: GraphContext,
|
||||
opname: str,
|
||||
@ -260,6 +265,7 @@ def _add_op(
|
||||
return tuple(node.outputs())
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _const_if_tensor(graph_context: GraphContext, arg):
|
||||
if arg is None:
|
||||
return arg
|
||||
@ -308,18 +314,21 @@ def _create_node(
|
||||
return node
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _is_onnx_list(value):
|
||||
return isinstance(value, Iterable) and not isinstance(
|
||||
value, (str, bytes, torch.Tensor)
|
||||
)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _scalar(x: torch.Tensor):
|
||||
"""Convert a scalar tensor into a Python value."""
|
||||
assert x.numel() == 1
|
||||
return x[0]
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _add_attribute(node: _C.Node, key: str, value: Any, aten: bool):
|
||||
r"""Initializes the right attribute based on type of value."""
|
||||
m = _ATTR_PATTERN.match(key)
|
||||
@ -336,10 +345,12 @@ def _add_attribute(node: _C.Node, key: str, value: Any, aten: bool):
|
||||
|
||||
|
||||
# TODO: Expose this to user when migrating symbolic helper functions to here.
|
||||
@_beartype.beartype
|
||||
def _is_tensor(x: _C.Value) -> bool:
|
||||
return x.type().isSubtypeOf(_C.TensorType.get())
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def get_device_from_value(value: _C.Value) -> Optional[torch.device]:
|
||||
if not _is_tensor(value):
|
||||
return None
|
||||
@ -347,6 +358,7 @@ def get_device_from_value(value: _C.Value) -> Optional[torch.device]:
|
||||
return tensor_type.device()
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def parse_node_kind(kind: str) -> Tuple[str, str]:
|
||||
"""Parse node kind into domain and Op name."""
|
||||
if "::" not in kind:
|
||||
@ -357,16 +369,19 @@ def parse_node_kind(kind: str) -> Tuple[str, str]:
|
||||
return domain, opname
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def is_aten(domain: str) -> bool:
|
||||
"""Check if the domain is official."""
|
||||
return domain == "aten"
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def is_prim(domain: str) -> bool:
|
||||
"""Check if the domain is official."""
|
||||
return domain == "prim"
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def is_onnx(domain: str) -> bool:
|
||||
"""Check if the domain is official."""
|
||||
return domain == "onnx"
|
||||
|
@ -14,9 +14,10 @@ import torch
|
||||
import torch.jit._trace
|
||||
import torch.serialization
|
||||
from torch.onnx import _constants, _exporter_states, errors
|
||||
from torch.onnx._internal import jit_utils, registration
|
||||
from torch.onnx._internal import _beartype, jit_utils, registration
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def export_as_test_case(
|
||||
model_bytes: bytes, inputs_data, outputs_data, name: str, dir: str
|
||||
) -> str:
|
||||
@ -72,6 +73,7 @@ def export_as_test_case(
|
||||
return test_case_dir
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def load_test_case(dir: str) -> Tuple[bytes, Any, Any]:
|
||||
"""Load a self contained ONNX test case from a directory.
|
||||
|
||||
@ -122,6 +124,7 @@ def load_test_case(dir: str) -> Tuple[bytes, Any, Any]:
|
||||
return model_bytes, inputs, outputs
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def export_data(data, value_info_proto, f: str) -> None:
|
||||
"""Export data to ONNX protobuf format.
|
||||
|
||||
@ -160,6 +163,7 @@ def export_data(data, value_info_proto, f: str) -> None:
|
||||
)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _export_file(
|
||||
model_bytes: bytes,
|
||||
f: Union[io.BytesIO, str],
|
||||
@ -206,6 +210,7 @@ def _export_file(
|
||||
raise ValueError("Unknown export type")
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _add_onnxscript_fn(
|
||||
model_bytes: bytes,
|
||||
custom_opsets: Mapping[str, int],
|
||||
@ -238,6 +243,7 @@ def _add_onnxscript_fn(
|
||||
return model_bytes
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _find_onnxscript_op(
|
||||
graph_proto,
|
||||
included_node_func: Set[str],
|
||||
|
@ -15,7 +15,6 @@ from typing import (
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
TYPE_CHECKING,
|
||||
Union,
|
||||
)
|
||||
from typing_extensions import TypeAlias
|
||||
@ -32,11 +31,9 @@ from torch.fx.passes.operator_support import OperatorSupport
|
||||
from torch.fx.passes.tools_common import CALLABLE_NODE_OPS
|
||||
from torch.utils import _pytree
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import onnx
|
||||
|
||||
try:
|
||||
# Use try-except to initialize package-dependent global variables.
|
||||
import onnx
|
||||
import onnxruntime # type: ignore[import]
|
||||
from onnxruntime.capi import _pybind_state as ORTC # type: ignore[import]
|
||||
|
||||
@ -555,9 +552,9 @@ class OrtExecutionInfoPerSession:
|
||||
self.output_value_infos: Tuple[onnx.ValueInfoProto, ...] = output_value_infos # type: ignore[name-defined]
|
||||
# For the ONNX model stored in self.session, self.input_devices[i] is the
|
||||
# i-th positional input's device.
|
||||
self.input_devices: Tuple[ORTC.OrtDevice, ...] = input_devices
|
||||
self.input_devices: Tuple["ORTC.OrtDevice", ...] = input_devices
|
||||
# Similar to self.input_devices, but for outputs.
|
||||
self.output_devices: Tuple[ORTC.OrtDevice, ...] = output_devices
|
||||
self.output_devices: Tuple["ORTC.OrtDevice", ...] = output_devices
|
||||
# This is the outputs of executing the original torch.fx.GraphModule with example inputs
|
||||
# (i.e., args passed into OrtBackend._ort_acclerated_call).
|
||||
self.example_outputs: Union[
|
||||
|
@ -15,7 +15,7 @@ from typing import (
|
||||
)
|
||||
|
||||
from torch.onnx import _constants, errors
|
||||
|
||||
from torch.onnx._internal import _beartype
|
||||
|
||||
OpsetVersion = int
|
||||
|
||||
@ -265,6 +265,7 @@ class SymbolicRegistry:
|
||||
return set(self._registry)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def onnx_symbolic(
|
||||
name: str,
|
||||
opset: Union[OpsetVersion, Sequence[OpsetVersion]],
|
||||
@ -313,6 +314,7 @@ def onnx_symbolic(
|
||||
return wrapper
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def custom_onnx_symbolic(
|
||||
name: str,
|
||||
opset: Union[OpsetVersion, Sequence[OpsetVersion]],
|
||||
|
@ -9,6 +9,7 @@ from typing import Dict, Literal, Optional, Union
|
||||
import torch
|
||||
from torch._C import _onnx as _C_onnx
|
||||
from torch.onnx import errors
|
||||
from torch.onnx._internal import _beartype
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
# Hack to help mypy to recognize torch._C.Value
|
||||
@ -105,6 +106,7 @@ class JitScalarType(enum.IntEnum):
|
||||
UNDEFINED = enum.auto() # 20
|
||||
|
||||
@classmethod
|
||||
@_beartype.beartype
|
||||
def _from_name(
|
||||
cls, name: Union[ScalarName, TorchName, Optional[str]]
|
||||
) -> JitScalarType:
|
||||
@ -134,6 +136,7 @@ class JitScalarType(enum.IntEnum):
|
||||
raise errors.OnnxExporterError(f"Unknown torch or scalar type: '{name}'")
|
||||
|
||||
@classmethod
|
||||
@_beartype.beartype
|
||||
def from_dtype(cls, dtype: Optional[torch.dtype]) -> JitScalarType:
|
||||
"""Convert a torch dtype to JitScalarType.
|
||||
|
||||
@ -156,6 +159,7 @@ class JitScalarType(enum.IntEnum):
|
||||
return _DTYPE_TO_SCALAR_TYPE[dtype]
|
||||
|
||||
@classmethod
|
||||
@_beartype.beartype
|
||||
def from_onnx_type(
|
||||
cls, onnx_type: Optional[Union[int, _C_onnx.TensorProtoDataType]]
|
||||
) -> JitScalarType:
|
||||
@ -175,6 +179,7 @@ class JitScalarType(enum.IntEnum):
|
||||
return _ONNX_TO_SCALAR_TYPE[typing.cast(_C_onnx.TensorProtoDataType, onnx_type)]
|
||||
|
||||
@classmethod
|
||||
@_beartype.beartype
|
||||
def from_value(
|
||||
cls, value: Union[None, torch._C.Value, torch.Tensor], default=None
|
||||
) -> JitScalarType:
|
||||
@ -242,18 +247,22 @@ class JitScalarType(enum.IntEnum):
|
||||
value,
|
||||
)
|
||||
|
||||
@_beartype.beartype
|
||||
def scalar_name(self) -> ScalarName:
|
||||
"""Convert a JitScalarType to a JIT scalar type name."""
|
||||
return _SCALAR_TYPE_TO_NAME[self]
|
||||
|
||||
@_beartype.beartype
|
||||
def torch_name(self) -> TorchName:
|
||||
"""Convert a JitScalarType to a torch type name."""
|
||||
return _SCALAR_TYPE_TO_TORCH_NAME[self]
|
||||
|
||||
@_beartype.beartype
|
||||
def dtype(self) -> torch.dtype:
|
||||
"""Convert a JitScalarType to a torch dtype."""
|
||||
return _SCALAR_TYPE_TO_DTYPE[self]
|
||||
|
||||
@_beartype.beartype
|
||||
def onnx_type(self) -> _C_onnx.TensorProtoDataType:
|
||||
"""Convert a JitScalarType to an ONNX data type."""
|
||||
if self not in _SCALAR_TYPE_TO_ONNX:
|
||||
@ -262,6 +271,7 @@ class JitScalarType(enum.IntEnum):
|
||||
)
|
||||
return _SCALAR_TYPE_TO_ONNX[self]
|
||||
|
||||
@_beartype.beartype
|
||||
def onnx_compatible(self) -> bool:
|
||||
"""Return whether this JitScalarType is compatible with ONNX."""
|
||||
return (
|
||||
@ -271,11 +281,13 @@ class JitScalarType(enum.IntEnum):
|
||||
)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def valid_scalar_name(scalar_name: Union[ScalarName, str]) -> bool:
|
||||
"""Return whether the given scalar name is a valid JIT scalar type name."""
|
||||
return scalar_name in _SCALAR_NAME_TO_TYPE
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def valid_torch_name(torch_name: Union[TorchName, str]) -> bool:
|
||||
"""Return whether the given torch name is a valid torch type name."""
|
||||
return torch_name in _TORCH_NAME_TO_SCALAR_TYPE
|
||||
|
@ -1,5 +1,4 @@
|
||||
# mypy: allow-untyped-defs
|
||||
# mypy: disable-error-code=arg-type
|
||||
import importlib
|
||||
import inspect
|
||||
|
||||
|
@ -27,10 +27,8 @@ from torch import _C
|
||||
# Monkey-patch graph manipulation methods on Graph, used for the ONNX symbolics
|
||||
from torch.onnx import _constants, _type_utils, errors, utils
|
||||
from torch.onnx._globals import GLOBALS
|
||||
from torch.onnx._internal import jit_utils
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from torch.types import Number
|
||||
from torch.onnx._internal import _beartype, jit_utils
|
||||
from torch.types import Number
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------------
|
||||
@ -50,6 +48,7 @@ _ValueDescriptor = Literal[
|
||||
]
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _parse_arg(
|
||||
value,
|
||||
desc: _ValueDescriptor,
|
||||
@ -119,6 +118,7 @@ def _parse_arg(
|
||||
)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _node_get(node: _C.Node, key: str):
|
||||
"""Gets attributes of a node which is polymorphic over return type."""
|
||||
assert isinstance(node, _C.Node)
|
||||
@ -126,11 +126,13 @@ def _node_get(node: _C.Node, key: str):
|
||||
return getattr(node, sel)(key)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _is_onnx_constant(value: _C.Value):
|
||||
"""Whether a Value is an ONNX constant."""
|
||||
return value.node().kind() == "onnx::Constant"
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _maybe_get_const(
|
||||
value: Optional[Union[_C.Value, torch.Tensor, Number, Sequence]],
|
||||
descriptor: _ValueDescriptor,
|
||||
@ -143,6 +145,7 @@ def _maybe_get_const(
|
||||
return value
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _maybe_get_scalar(value):
|
||||
value_t = _maybe_get_const(value, "t")
|
||||
if isinstance(value_t, torch.Tensor) and value_t.shape == ():
|
||||
@ -150,6 +153,7 @@ def _maybe_get_scalar(value):
|
||||
return value
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _get_const(value, desc, arg_name):
|
||||
if not _is_constant(value):
|
||||
raise errors.SymbolicValueError(
|
||||
@ -160,6 +164,7 @@ def _get_const(value, desc, arg_name):
|
||||
return _parse_arg(value, desc)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _unpack_list(list_value: _C.Value) -> List[_C.Value]:
|
||||
list_node = list_value.node()
|
||||
if list_node.kind() != "prim::ListConstruct":
|
||||
@ -171,6 +176,7 @@ def _unpack_list(list_value: _C.Value) -> List[_C.Value]:
|
||||
return list(list_node.inputs())
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _unpack_tuple(tuple_value: _C.Value) -> Tuple[_C.Value, ...]:
|
||||
tuple_node = tuple_value.node()
|
||||
if not _is_tuple_construct(tuple_value):
|
||||
@ -182,6 +188,7 @@ def _unpack_tuple(tuple_value: _C.Value) -> Tuple[_C.Value, ...]:
|
||||
return tuple(tuple_node.inputs())
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _unpack_quantized_tensor(tuple_value: _C.Value) -> Tuple[_C.Value, ...]:
|
||||
"""Unpacks a quantized tensor into a tuple of tensor and scale/zero_point.
|
||||
Args:
|
||||
@ -205,10 +212,12 @@ def _unpack_quantized_tensor(tuple_value: _C.Value) -> Tuple[_C.Value, ...]:
|
||||
|
||||
# Check if list_value is output from prim::ListConstruct
|
||||
# This is usually called before _unpack_list to ensure the list can be unpacked.
|
||||
@_beartype.beartype
|
||||
def _is_packed_list(list_value: Any) -> bool:
|
||||
return _is_value(list_value) and list_value.node().kind() == "prim::ListConstruct"
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def parse_args(*arg_descriptors: _ValueDescriptor):
|
||||
"""A decorator which converts args from torch._C.Value to built-in types.
|
||||
|
||||
@ -287,6 +296,7 @@ def parse_args(*arg_descriptors: _ValueDescriptor):
|
||||
return decorator
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def quantized_args(
|
||||
*arg_q_descriptors: bool,
|
||||
scale: Optional[float] = None,
|
||||
@ -423,6 +433,7 @@ def quantized_args(
|
||||
return decorator
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _scalar(x: Any) -> Optional[Number]:
|
||||
"""Convert a scalar tensor into a Python value."""
|
||||
if isinstance(x, torch.Tensor) and x.shape == ():
|
||||
@ -430,6 +441,7 @@ def _scalar(x: Any) -> Optional[Number]:
|
||||
return None
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _if_scalar_type_as(self, tensor):
|
||||
"""
|
||||
Convert self into the same type of tensor, as necessary.
|
||||
@ -449,14 +461,17 @@ def _if_scalar_type_as(self, tensor):
|
||||
return self
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _is_none(x: Any) -> bool:
|
||||
return x is None or (x.node().mustBeNone() if isinstance(x, _C.Value) else False)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _is_value(x: Any) -> bool:
|
||||
return isinstance(x, _C.Value)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _is_constant(value: Any) -> bool:
|
||||
return not _is_value(value) or value.node().kind() in {
|
||||
"onnx::Constant",
|
||||
@ -464,6 +479,7 @@ def _is_constant(value: Any) -> bool:
|
||||
}
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _is_tensor(x: _C.Value) -> bool:
|
||||
return x.type().isSubtypeOf(_C.TensorType.get())
|
||||
|
||||
@ -475,10 +491,12 @@ def _as_list_type(jit_type: _C.JitType) -> Optional[_C.ListType]:
|
||||
return None
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _is_list(x: _C.Value) -> bool:
|
||||
return _as_list_type(x.type()) is not None
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _is_tensor_list(x: _C.Value) -> bool:
|
||||
x_type = _as_list_type(x.type())
|
||||
if x_type is None:
|
||||
@ -486,6 +504,7 @@ def _is_tensor_list(x: _C.Value) -> bool:
|
||||
return isinstance(x_type.getElementType(), _C.TensorType)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _is_scalar_list(x: _C.Value) -> bool:
|
||||
"""Checks if x is a scalar list, for example: List[float], List[int].
|
||||
|
||||
@ -499,10 +518,12 @@ def _is_scalar_list(x: _C.Value) -> bool:
|
||||
return scalar_type.onnx_compatible()
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _is_tuple_construct(x: _C.Value) -> bool:
|
||||
return x.node().kind() == "prim::TupleConstruct"
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def is_complex_value(x: _C.Value) -> bool:
|
||||
assert _is_value(x)
|
||||
return _type_utils.JitScalarType.from_value(
|
||||
@ -514,6 +535,7 @@ def is_complex_value(x: _C.Value) -> bool:
|
||||
}
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _get_tensor_rank(x: _C.Value) -> Optional[int]:
|
||||
if not _is_tensor(x) or x.type() is None:
|
||||
return None
|
||||
@ -522,6 +544,7 @@ def _get_tensor_rank(x: _C.Value) -> Optional[int]:
|
||||
return x_type.dim()
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _get_tensor_sizes(x: _C.Value, allow_nonstatic: bool = True):
|
||||
if not _is_tensor(x) or x.type() is None:
|
||||
return None
|
||||
@ -536,11 +559,13 @@ def _get_tensor_sizes(x: _C.Value, allow_nonstatic: bool = True):
|
||||
return x_type.sizes()
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _get_tensor_dim_size(x: _C.Value, dim: int) -> Optional[int]:
|
||||
sizes = _get_tensor_sizes(x)
|
||||
return sizes[dim] if sizes else None
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _get_dim_for_cross(x: _C.Value, dim: Optional[int]):
|
||||
if dim == -1:
|
||||
tensor_rank = _get_tensor_rank(x)
|
||||
@ -556,12 +581,14 @@ def _get_dim_for_cross(x: _C.Value, dim: Optional[int]):
|
||||
return dim
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _unimplemented(op: str, msg: str, value: Optional[_C.Value] = None) -> None:
|
||||
# For BC reasons, the behavior for Caffe2 does not raise exception for unimplemented operators
|
||||
if GLOBALS.operator_export_type == _C_onnx.OperatorExportTypes.ONNX:
|
||||
_onnx_unsupported(f"{op}, {msg}", value)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _onnx_unsupported(op_name: str, value: Optional[_C.Value] = None) -> NoReturn:
|
||||
message = (
|
||||
f"Unsupported: ONNX export of operator {op_name}. "
|
||||
@ -576,6 +603,7 @@ def _onnx_unsupported(op_name: str, value: Optional[_C.Value] = None) -> NoRetur
|
||||
raise errors.OnnxExporterError(message)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _onnx_opset_unsupported(
|
||||
op_name: str,
|
||||
current_opset: int,
|
||||
@ -594,6 +622,7 @@ def _onnx_opset_unsupported(
|
||||
raise errors.OnnxExporterError(message)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _onnx_opset_unsupported_detailed(
|
||||
op_name: str,
|
||||
current_opset: int,
|
||||
@ -613,6 +642,7 @@ def _onnx_opset_unsupported_detailed(
|
||||
raise errors.OnnxExporterError(message)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _block_list_in_opset(name: str):
|
||||
def symbolic_fn(*args, **kwargs):
|
||||
raise errors.OnnxExporterError(
|
||||
@ -624,6 +654,7 @@ def _block_list_in_opset(name: str):
|
||||
return symbolic_fn
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _try_get_scalar_type(*args) -> Optional[_type_utils.JitScalarType]:
|
||||
for arg in args:
|
||||
scalar_type = _type_utils.JitScalarType.from_value(
|
||||
@ -634,19 +665,21 @@ def _try_get_scalar_type(*args) -> Optional[_type_utils.JitScalarType]:
|
||||
return None
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _type_promote_from_values(*args) -> _type_utils.JitScalarType:
|
||||
undef = _type_utils.JitScalarType.UNDEFINED
|
||||
jit_types = [_try_get_scalar_type(arg) for arg in args]
|
||||
if len(jit_types) == 0:
|
||||
return undef
|
||||
if len(jit_types) == 1:
|
||||
return jit_types[0] # type: ignore[return-value]
|
||||
new_dtype = jit_types[0].dtype() # type: ignore[union-attr]
|
||||
return jit_types[0]
|
||||
new_dtype = jit_types[0].dtype()
|
||||
for t in jit_types:
|
||||
new_dtype = torch.promote_types(new_dtype, t.dtype()) # type: ignore[union-attr]
|
||||
new_dtype = torch.promote_types(new_dtype, t.dtype())
|
||||
return _type_utils.JitScalarType.from_dtype(new_dtype)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _maybe_cast_to_type(
|
||||
g: jit_utils.GraphContext, value, jit_type: _type_utils.JitScalarType
|
||||
):
|
||||
@ -662,6 +695,7 @@ def _maybe_cast_to_type(
|
||||
return value
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _select_helper(g: jit_utils.GraphContext, self, dim, index, apply_reshape=True):
|
||||
index_const = _maybe_get_scalar(index)
|
||||
index_dim = _get_tensor_rank(index)
|
||||
@ -686,6 +720,7 @@ def _select_helper(g: jit_utils.GraphContext, self, dim, index, apply_reshape=Tr
|
||||
return g.op("Gather", self, index, axis_i=dim)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _slice_helper(
|
||||
g: jit_utils.GraphContext,
|
||||
input,
|
||||
@ -704,6 +739,7 @@ def _slice_helper(
|
||||
return _slice10(g, input, axes, starts, ends, steps)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _is_fp(value) -> bool:
|
||||
return _type_utils.JitScalarType.from_value(
|
||||
value, _type_utils.JitScalarType.UNDEFINED
|
||||
@ -715,12 +751,14 @@ def _is_fp(value) -> bool:
|
||||
}
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _is_bool(value) -> bool:
|
||||
return _type_utils.JitScalarType.from_value(
|
||||
value, _type_utils.JitScalarType.UNDEFINED
|
||||
) in {_type_utils.JitScalarType.BOOL}
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _generate_wrapped_number(g: jit_utils.GraphContext, scalar):
|
||||
"""Creates a wrapped number based on https://github.com/pytorch/pytorch/issues/9515.
|
||||
|
||||
@ -739,6 +777,7 @@ def _generate_wrapped_number(g: jit_utils.GraphContext, scalar):
|
||||
return g.op("Constant", value_t=torch.tensor(scalar))
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _sort_helper(g: jit_utils.GraphContext, input, dim, decending=True, out=None):
|
||||
if out is not None:
|
||||
_unimplemented("Sort", "Out parameter is not supported")
|
||||
@ -758,6 +797,7 @@ def _sort_helper(g: jit_utils.GraphContext, input, dim, decending=True, out=None
|
||||
)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _topk_helper(
|
||||
g: jit_utils.GraphContext, input, k, dim, largest=True, sorted=False, out=None
|
||||
):
|
||||
@ -779,6 +819,7 @@ def _topk_helper(
|
||||
)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _lt_helper(g: jit_utils.GraphContext, input, other):
|
||||
if g.opset <= 8:
|
||||
from torch.onnx.symbolic_opset8 import lt as _lt8
|
||||
@ -790,6 +831,7 @@ def _lt_helper(g: jit_utils.GraphContext, input, other):
|
||||
return _lt9(g, input, other)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _interpolate_warning(interpolate_mode):
|
||||
onnx_op = (
|
||||
"onnx:Resize" if GLOBALS.export_onnx_opset_version >= 10 else "onnx:Upsample"
|
||||
@ -807,6 +849,7 @@ def _interpolate_warning(interpolate_mode):
|
||||
)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _unsqueeze_helper(g: jit_utils.GraphContext, input, axes_i):
|
||||
if _is_constant(axes_i[0]):
|
||||
if g.opset >= 13:
|
||||
@ -821,6 +864,7 @@ def _unsqueeze_helper(g: jit_utils.GraphContext, input, axes_i):
|
||||
return g.op("Unsqueeze", input, axes_i[0])
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _squeeze_helper(g: jit_utils.GraphContext, input, axes_i):
|
||||
if _is_constant(axes_i[0]):
|
||||
if g.opset >= 13:
|
||||
@ -846,6 +890,7 @@ def _squeeze_helper(g: jit_utils.GraphContext, input, axes_i):
|
||||
return g.op("Squeeze", input, axes_t)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _reducesum_helper(
|
||||
g: jit_utils.GraphContext,
|
||||
input,
|
||||
@ -877,6 +922,7 @@ def _reducesum_helper(
|
||||
return g.op("ReduceSum", input, axes_i=axes_i, keepdims_i=keepdims_i)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _interpolate_size_to_scales(g: jit_utils.GraphContext, input, output_size, dim):
|
||||
output_size = _maybe_get_const(output_size, "is")
|
||||
if _is_value(output_size):
|
||||
@ -903,6 +949,7 @@ def _interpolate_size_to_scales(g: jit_utils.GraphContext, input, output_size, d
|
||||
return scales
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _interpolate_get_scales_if_available(g: jit_utils.GraphContext, scales):
|
||||
available_scales = _maybe_get_const(scales[0], "fs") != -1 and not _is_none(
|
||||
scales[0]
|
||||
@ -919,6 +966,7 @@ def _interpolate_get_scales_if_available(g: jit_utils.GraphContext, scales):
|
||||
return scales
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _get_interpolate_attributes(g: jit_utils.GraphContext, mode, args):
|
||||
if mode == "nearest":
|
||||
align_corners = None
|
||||
@ -930,6 +978,7 @@ def _get_interpolate_attributes(g: jit_utils.GraphContext, mode, args):
|
||||
return scales, align_corners
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _interpolate_get_scales(g: jit_utils.GraphContext, scale_factor, dim):
|
||||
offsets = g.op("Constant", value_t=torch.ones(2, dtype=torch.float32))
|
||||
scale_factor_rank = _get_tensor_rank(scale_factor)
|
||||
@ -947,6 +996,7 @@ def _interpolate_get_scales(g: jit_utils.GraphContext, scale_factor, dim):
|
||||
return scale_factor
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _interpolate_get_scales_and_mode(
|
||||
g: jit_utils.GraphContext, input, size, scale_factor, mode, align_corners
|
||||
):
|
||||
@ -982,6 +1032,7 @@ def _interpolate_get_scales_and_mode(
|
||||
return scale_factor, mode
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _argmin_argmax_helper(
|
||||
g: jit_utils.GraphContext,
|
||||
input: torch._C.Value,
|
||||
@ -1020,6 +1071,7 @@ def _argmin_argmax_helper(
|
||||
return op_wrapper(input, axis_i=dim, keepdims_i=keepdim)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _interpolate_helper(name, dim, interpolate_mode):
|
||||
@quantized_args(True, False, False)
|
||||
def symbolic_fn(g, input, output_size, *args):
|
||||
@ -1087,6 +1139,7 @@ def _interpolate_helper(name, dim, interpolate_mode):
|
||||
return symbolic_fn
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def __interpolate_helper(
|
||||
g: jit_utils.GraphContext,
|
||||
input,
|
||||
@ -1186,6 +1239,7 @@ def __interpolate_helper(
|
||||
) # only valid when mode="nearest"
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _unbind_helper(g: jit_utils.GraphContext, self, dim, _outputs):
|
||||
if g.opset < 11:
|
||||
from torch.onnx.symbolic_opset9 import unbind
|
||||
@ -1196,6 +1250,7 @@ def _unbind_helper(g: jit_utils.GraphContext, self, dim, _outputs):
|
||||
return unbind(g, self, dim, _outputs)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _scatter_helper(g: jit_utils.GraphContext, self, dim, index, src):
|
||||
if g.opset <= 10:
|
||||
from torch.onnx.symbolic_opset9 import scatter
|
||||
@ -1205,6 +1260,7 @@ def _scatter_helper(g: jit_utils.GraphContext, self, dim, index, src):
|
||||
return scatter(g, self, dim, index, src)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _repeat_interleave_split_helper(g: jit_utils.GraphContext, self, reps, dim):
|
||||
if g.opset <= 12:
|
||||
split_out = g.op("Split", self, split_i=[1] * reps, axis_i=dim, outputs=reps)
|
||||
@ -1216,6 +1272,7 @@ def _repeat_interleave_split_helper(g: jit_utils.GraphContext, self, reps, dim):
|
||||
return split_out if reps > 1 else [split_out]
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _repeat_interleave_single_value_repeat_helper(
|
||||
g: jit_utils.GraphContext, self, repeats, dim
|
||||
):
|
||||
@ -1237,7 +1294,7 @@ def _repeat_interleave_single_value_repeat_helper(
|
||||
# repeats_per_dim is 1 for all dims except for the new unsqueezed dim, where it has value 'repeats'.
|
||||
if const_repeats:
|
||||
# 'Repeats' is a constant, 'repeats_per_dim' can be a constant.
|
||||
onehot = torch.ones(_get_tensor_rank(unsqueezed), dtype=torch.int64) # type: ignore[arg-type]
|
||||
onehot = torch.ones(_get_tensor_rank(unsqueezed), dtype=torch.int64)
|
||||
onehot[dim + 1] = reps
|
||||
repeats_per_dim = g.op("Constant", value_t=onehot)
|
||||
else:
|
||||
@ -1258,6 +1315,7 @@ def _repeat_interleave_single_value_repeat_helper(
|
||||
return flatten(g, tiled, dim, dim + 1)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _arange_cast_helper(
|
||||
g: jit_utils.GraphContext, end, start=None, step=None, dtype=None
|
||||
) -> Tuple[
|
||||
@ -1300,6 +1358,7 @@ def _arange_cast_helper(
|
||||
return scalar_type, end, start, step
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _arange_helper(g: jit_utils.GraphContext, *args):
|
||||
if g.opset <= 10:
|
||||
from torch.onnx.symbolic_opset9 import arange
|
||||
@ -1308,6 +1367,7 @@ def _arange_helper(g: jit_utils.GraphContext, *args):
|
||||
return arange(g, *args)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _size_helper(g: jit_utils.GraphContext, self, dim):
|
||||
full_shape = g.op("Shape", self)
|
||||
from torch.onnx.symbolic_opset9 import select
|
||||
@ -1315,6 +1375,7 @@ def _size_helper(g: jit_utils.GraphContext, self, dim):
|
||||
return select(g, full_shape, g.op("Constant", value_t=torch.tensor([0])), dim)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _index_fill_reshape_helper(g: jit_utils.GraphContext, self, dim, index):
|
||||
# 1. reshape index => [1, ..., 1, dim, 1, ..., 1]
|
||||
# 2. expand index => [..., dim, ...], same shape as self except for dim.
|
||||
@ -1350,6 +1411,7 @@ def _index_fill_reshape_helper(g: jit_utils.GraphContext, self, dim, index):
|
||||
# allowzero=1 indicates that if any value in the 'shape' input is set to zero,
|
||||
# the zero value is honored, similar to NumPy.
|
||||
# allowzero=1 is only supported for opset version >= 14.
|
||||
@_beartype.beartype
|
||||
def _reshape_helper(g: jit_utils.GraphContext, input, shape, allowzero=0):
|
||||
shape = _maybe_get_const(shape, "is")
|
||||
if not _is_value(shape):
|
||||
@ -1364,6 +1426,7 @@ def _reshape_helper(g: jit_utils.GraphContext, input, shape, allowzero=0):
|
||||
return g.op("Reshape", input, shape, allowzero_i=allowzero)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _batchnorm_helper(
|
||||
g: jit_utils.GraphContext, input, weight, bias, running_mean, running_var
|
||||
):
|
||||
@ -1421,6 +1484,7 @@ def _batchnorm_helper(
|
||||
return weight, bias, running_mean, running_var
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _avgpool_helper(
|
||||
tuple_fn: Callable[[Any], Sequence[int]],
|
||||
padding: Union[int, Sequence[int]],
|
||||
@ -1434,6 +1498,7 @@ def _avgpool_helper(
|
||||
return tuple(tuple_fn(padding))
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def check_training_mode(op_train_mode: int, op_name: str) -> None:
|
||||
"""Warns the user if the model's training mode and the export mode do not agree."""
|
||||
if GLOBALS.training_mode == _C_onnx.TrainingMode.PRESERVE:
|
||||
@ -1459,6 +1524,7 @@ def check_training_mode(op_train_mode: int, op_name: str) -> None:
|
||||
)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _flatten_helper(g: jit_utils.GraphContext, input, start_dim, end_dim, dim):
|
||||
input_size = g.op("Shape", input)
|
||||
slice1 = _slice_helper(g, input_size, axes=[0], starts=[0], ends=[start_dim])
|
||||
@ -1479,6 +1545,7 @@ def _flatten_helper(g: jit_utils.GraphContext, input, start_dim, end_dim, dim):
|
||||
return _reshape_from_tensor(g, input, final_shape)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _is_split_static(split_size_or_sizes, _outputs):
|
||||
if _outputs is None:
|
||||
return False
|
||||
@ -1490,12 +1557,14 @@ def _is_split_static(split_size_or_sizes, _outputs):
|
||||
return True
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _optional_input_placeholder_tensor(g):
|
||||
n = g.op("prim::Constant")
|
||||
n.setType(_C.OptionalType.ofTensor())
|
||||
return n
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _handle_reduce_dim_none(g: jit_utils.GraphContext, self, op_name):
|
||||
rank = _get_tensor_rank(self)
|
||||
if rank is not None and any(
|
||||
@ -1507,6 +1576,7 @@ def _handle_reduce_dim_none(g: jit_utils.GraphContext, self, op_name):
|
||||
return g.op(op_name, self, keepdims_i=0)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def dequantize_helper(
|
||||
g: jit_utils.GraphContext,
|
||||
qtensor: _C.Value,
|
||||
@ -1555,6 +1625,7 @@ def dequantize_helper(
|
||||
)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def quantize_helper(
|
||||
g: jit_utils.GraphContext,
|
||||
tensor: _C.Value,
|
||||
@ -1616,6 +1687,7 @@ def quantize_helper(
|
||||
return g.op("prim::TupleConstruct", *args)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def requantize_bias_helper(
|
||||
g: jit_utils.GraphContext, bias, input_scale, weight_scale, axis=None
|
||||
):
|
||||
@ -1638,6 +1710,7 @@ def requantize_bias_helper(
|
||||
return g.op("prim::TupleConstruct", q_bias, bias_scale, bias_zero_point, *axis_args)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def args_have_same_dtype(args):
|
||||
assert args
|
||||
base_dtype = _type_utils.JitScalarType.from_value(args[0])
|
||||
@ -1647,6 +1720,7 @@ def args_have_same_dtype(args):
|
||||
return has_same_dtype
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _op_with_optional_float_cast(g: jit_utils.GraphContext, op_name, *args, **kwargs):
|
||||
"""Some PyTorch operators (e.g., Clip/Min/ReLU/Pad) are super set of ONNX in terms of data types.
|
||||
This function maximizes the exportability of PyTorch-ONNX by allowing ONNX-unsupported PyTorch
|
||||
@ -1700,6 +1774,7 @@ def _op_with_optional_float_cast(g: jit_utils.GraphContext, op_name, *args, **kw
|
||||
return self
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _maybe_cast_reduce_op_input(g: jit_utils.GraphContext, self):
|
||||
scalar_type = _type_utils.JitScalarType.from_value(
|
||||
self, _type_utils.JitScalarType.UNDEFINED
|
||||
@ -1721,7 +1796,9 @@ def _apply_params(*args, **kwargs):
|
||||
return _apply
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _reduce_op_symbolic_helper(onnx_op_name, allow_multi_dim_support=True):
|
||||
@_beartype.beartype
|
||||
def symbolic(g, self, dim=None, keepdim=None):
|
||||
self = _maybe_cast_reduce_op_input(g, self)
|
||||
if dim is None or dim == ():
|
||||
@ -1754,8 +1831,10 @@ def _reduce_op_symbolic_helper(onnx_op_name, allow_multi_dim_support=True):
|
||||
return symbolic
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _overload_by_arg_count(fn):
|
||||
@functools.wraps(fn)
|
||||
@_beartype.beartype
|
||||
def wrapper(g, *args):
|
||||
overloads = fn(g, *args)
|
||||
for overload in overloads:
|
||||
@ -1767,6 +1846,7 @@ def _overload_by_arg_count(fn):
|
||||
return wrapper
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _reduce_with_dtype_helper(
|
||||
onnx_op: str, name: str, allow_multi_dim_support: bool = True
|
||||
):
|
||||
@ -1821,6 +1901,7 @@ def _reduce_with_dtype_helper(
|
||||
return reduce
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _max_helper(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
|
||||
# torch.max(input)
|
||||
if dim_or_y is None and keepdim is None:
|
||||
@ -1841,6 +1922,7 @@ def _max_helper(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
|
||||
return max, indices
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _min_helper(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
|
||||
# torch.min(input)
|
||||
if dim_or_y is None and keepdim is None:
|
||||
@ -1861,12 +1943,14 @@ def _min_helper(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
|
||||
return min, indices
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _numel_helper(g: jit_utils.GraphContext, self):
|
||||
shape = g.op("Shape", self)
|
||||
return g.op("ReduceProd", shape, keepdims_i=0)
|
||||
|
||||
|
||||
@parse_args("v", "is", "i", "i")
|
||||
@_beartype.beartype
|
||||
def _var_mean_helper(g: jit_utils.GraphContext, input, dim, correction, keepdim):
|
||||
if g.opset < 18:
|
||||
if dim is None:
|
||||
@ -1939,6 +2023,7 @@ def _var_mean_helper(g: jit_utils.GraphContext, input, dim, correction, keepdim)
|
||||
return var, mean
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _embedding_bag_helper(
|
||||
g: jit_utils.GraphContext,
|
||||
embedding_matrix,
|
||||
@ -2045,6 +2130,7 @@ def _embedding_bag_helper(
|
||||
return loop.node().output(), None, None, None
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _linalg_vector_norm_helper(
|
||||
g: jit_utils.GraphContext,
|
||||
self: torch._C.Value,
|
||||
|
@ -1,5 +1,4 @@
|
||||
# mypy: allow-untyped-defs
|
||||
# mypy: disable-error-code=arg-type
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
@ -21,7 +20,7 @@ from torch.onnx import (
|
||||
symbolic_opset9 as opset9,
|
||||
)
|
||||
from torch.onnx._globals import GLOBALS
|
||||
from torch.onnx._internal import jit_utils, registration
|
||||
from torch.onnx._internal import _beartype, jit_utils, registration
|
||||
|
||||
# EDITING THIS FILE? READ THIS FIRST!
|
||||
# see Note [Edit Symbolic Files] in README.md
|
||||
@ -73,6 +72,7 @@ _onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=10)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::div")
|
||||
@_beartype.beartype
|
||||
def div(g: jit_utils.GraphContext, self, other, *args):
|
||||
if len(args) == 0:
|
||||
return opset9.true_divide(g, self, other)
|
||||
@ -81,6 +81,7 @@ def div(g: jit_utils.GraphContext, self, other, *args):
|
||||
|
||||
|
||||
@symbolic_helper.parse_args("v", "v", "s")
|
||||
@_beartype.beartype
|
||||
def _div_rounding_mode(g: jit_utils.GraphContext, self, other, rounding_mode):
|
||||
if rounding_mode == "floor":
|
||||
return _floor_divide(g, self, other)
|
||||
@ -89,6 +90,7 @@ def _div_rounding_mode(g: jit_utils.GraphContext, self, other, rounding_mode):
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::_floor_divide")
|
||||
@_beartype.beartype
|
||||
def _floor_divide(g: jit_utils.GraphContext, self, other):
|
||||
if symbolic_helper._is_fp(self) or symbolic_helper._is_fp(other):
|
||||
out = opset9.true_divide(g, self, other)
|
||||
@ -111,12 +113,14 @@ def _floor_divide(g: jit_utils.GraphContext, self, other):
|
||||
|
||||
@_onnx_symbolic("aten::sort")
|
||||
@symbolic_helper.parse_args("v", "i", "i", "none")
|
||||
@_beartype.beartype
|
||||
def sort(g: jit_utils.GraphContext, self, dim, decending, out=None):
|
||||
return symbolic_helper._sort_helper(g, self, dim, decending=decending, out=out)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::topk")
|
||||
@symbolic_helper.parse_args("v", "v", "i", "i", "i", "none")
|
||||
@_beartype.beartype
|
||||
def topk(g: jit_utils.GraphContext, self, k, dim, largest, sorted, out=None):
|
||||
return symbolic_helper._topk_helper(
|
||||
g, self, k, dim, largest=largest, sorted=sorted, out=out
|
||||
@ -304,6 +308,7 @@ def _aten_max_pool_with_indices_onnx(
|
||||
)
|
||||
],
|
||||
)
|
||||
@_beartype.beartype
|
||||
def _max_pool(name: str, expand_size: int, return_indices: bool):
|
||||
@symbolic_helper.quantized_args(True, False, False, False, False, False)
|
||||
@symbolic_helper.parse_args("v", "is", "is", "is", "is", "i")
|
||||
@ -394,9 +399,11 @@ def _adjust_attributes_of_avg_pool(
|
||||
"aten::avg_pool3d",
|
||||
decorate=[symbolic_helper._apply_params("avg_pool3d", 3)],
|
||||
)
|
||||
@_beartype.beartype
|
||||
def _avg_pool(name, expand_size):
|
||||
@symbolic_helper.quantized_args(True, False, False, False, False, False, False)
|
||||
@symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none")
|
||||
@_beartype.beartype
|
||||
def symbolic_fn(
|
||||
g,
|
||||
input: _C.Value,
|
||||
@ -450,8 +457,10 @@ def _avg_pool(name, expand_size):
|
||||
"aten::upsample_trilinear3d",
|
||||
decorate=[symbolic_helper._apply_params("upsample_trilinear3d", 5, "linear")],
|
||||
)
|
||||
@_beartype.beartype
|
||||
def _interpolate(name, dim, interpolate_mode):
|
||||
@symbolic_helper.quantized_args(True, False, False)
|
||||
@_beartype.beartype
|
||||
def symbolic_fn(g, input, output_size, *args):
|
||||
scales, align_corners = symbolic_helper._get_interpolate_attributes(
|
||||
g, interpolate_mode, args
|
||||
@ -470,6 +479,7 @@ def _interpolate(name, dim, interpolate_mode):
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::__interpolate")
|
||||
@_beartype.beartype
|
||||
def __interpolate(
|
||||
g: jit_utils.GraphContext,
|
||||
input,
|
||||
@ -486,6 +496,7 @@ def __interpolate(
|
||||
return g.op("Resize", input, scales, mode_s=mode)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _slice(
|
||||
g: jit_utils.GraphContext,
|
||||
input: torch._C.Value,
|
||||
@ -545,6 +556,7 @@ def _slice(
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::slice")
|
||||
@_beartype.beartype
|
||||
def slice(g: jit_utils.GraphContext, self, *args):
|
||||
if len(args) == 4:
|
||||
# aten::slice(Tensor self, int dim, int? start=None, int? end=None, int step=1) -> Tensor
|
||||
@ -568,6 +580,7 @@ def slice(g: jit_utils.GraphContext, self, *args):
|
||||
|
||||
@_onnx_symbolic("aten::flip")
|
||||
@symbolic_helper.parse_args("v", "is")
|
||||
@_beartype.beartype
|
||||
def flip(g: jit_utils.GraphContext, input, dims):
|
||||
return symbolic_helper._slice_helper(
|
||||
g,
|
||||
@ -580,12 +593,14 @@ def flip(g: jit_utils.GraphContext, input, dims):
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::fmod")
|
||||
@_beartype.beartype
|
||||
def fmod(g: jit_utils.GraphContext, input, other):
|
||||
return g.op("Mod", input, other, fmod_i=1)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::embedding_bag")
|
||||
@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i")
|
||||
@_beartype.beartype
|
||||
def embedding_bag(
|
||||
g: jit_utils.GraphContext,
|
||||
embedding_matrix,
|
||||
@ -672,6 +687,7 @@ def embedding_bag(
|
||||
|
||||
@_onnx_symbolic("aten::fake_quantize_per_tensor_affine")
|
||||
@symbolic_helper.parse_args("v", "v", "v", "i", "i")
|
||||
@_beartype.beartype
|
||||
def fake_quantize_per_tensor_affine(
|
||||
g: jit_utils.GraphContext,
|
||||
inputs,
|
||||
@ -719,11 +735,13 @@ def fake_quantize_per_tensor_affine(
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::isinf")
|
||||
@_beartype.beartype
|
||||
def isinf(g: jit_utils.GraphContext, input):
|
||||
return g.op("IsInf", g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.DOUBLE))
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::isfinite")
|
||||
@_beartype.beartype
|
||||
def isfinite(g: jit_utils.GraphContext, input):
|
||||
inf_node = isinf(g, input)
|
||||
nan_node = opset9.isnan(g, input)
|
||||
@ -731,6 +749,7 @@ def isfinite(g: jit_utils.GraphContext, input):
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::quantize_per_tensor")
|
||||
@_beartype.beartype
|
||||
def quantize_per_tensor(g: jit_utils.GraphContext, input, scale, zero_point, dtype):
|
||||
dtype = symbolic_helper._get_const(dtype, "i", "dtype")
|
||||
# TODO(justinchuby): Extract all the cast ops into a helper function.
|
||||
@ -742,12 +761,14 @@ def quantize_per_tensor(g: jit_utils.GraphContext, input, scale, zero_point, dty
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::dequantize")
|
||||
@_beartype.beartype
|
||||
def dequantize(g: jit_utils.GraphContext, input):
|
||||
return symbolic_helper.dequantize_helper(g, input)[0]
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::nan_to_num")
|
||||
@symbolic_helper.parse_args("v", "f", "f", "f")
|
||||
@_beartype.beartype
|
||||
def nan_to_num(g: jit_utils.GraphContext, input, nan, posinf, neginf):
|
||||
# Cannot create a int type tensor with inf/nan values, so we simply
|
||||
# return the original tensor
|
||||
@ -803,6 +824,7 @@ def nan_to_num(g: jit_utils.GraphContext, input, nan, posinf, neginf):
|
||||
# Support starts from opset 10 because `DequantizeLinear` and `QuantizeLinear` were
|
||||
# introduced in opset version 10.
|
||||
@_onnx_symbolic("quantized::linear")
|
||||
@_beartype.beartype
|
||||
def quantized_linear(
|
||||
g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point
|
||||
):
|
||||
@ -817,6 +839,7 @@ def quantized_linear(
|
||||
|
||||
|
||||
@_onnx_symbolic("quantized::linear_relu")
|
||||
@_beartype.beartype
|
||||
def quantized_linear_relu(
|
||||
g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point
|
||||
):
|
||||
@ -832,6 +855,7 @@ def quantized_linear_relu(
|
||||
|
||||
|
||||
@_onnx_symbolic("quantized::add")
|
||||
@_beartype.beartype
|
||||
def quantized_add(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point):
|
||||
x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
|
||||
y, _, _, _ = symbolic_helper.dequantize_helper(g, y)
|
||||
@ -842,6 +866,7 @@ def quantized_add(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point):
|
||||
|
||||
|
||||
@_onnx_symbolic("quantized::add_relu")
|
||||
@_beartype.beartype
|
||||
def quantized_add_relu(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point):
|
||||
x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
|
||||
y, _, _, _ = symbolic_helper.dequantize_helper(g, y)
|
||||
@ -853,6 +878,7 @@ def quantized_add_relu(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point)
|
||||
|
||||
|
||||
@_onnx_symbolic("quantized::mul")
|
||||
@_beartype.beartype
|
||||
def quantized_mul(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point):
|
||||
x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
|
||||
y, _, _, _ = symbolic_helper.dequantize_helper(g, y)
|
||||
@ -863,6 +889,7 @@ def quantized_mul(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point):
|
||||
|
||||
|
||||
@_onnx_symbolic("quantized::hardswish")
|
||||
@_beartype.beartype
|
||||
def quantized_hardswish(g: jit_utils.GraphContext, x, op_scale, op_zero_point):
|
||||
x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
|
||||
|
||||
@ -872,6 +899,7 @@ def quantized_hardswish(g: jit_utils.GraphContext, x, op_scale, op_zero_point):
|
||||
|
||||
|
||||
@_onnx_symbolic("quantized::sigmoid")
|
||||
@_beartype.beartype
|
||||
def quantized_sigmoid(g: jit_utils.GraphContext, x, op_scale, op_zero_point):
|
||||
x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
|
||||
|
||||
@ -881,6 +909,7 @@ def quantized_sigmoid(g: jit_utils.GraphContext, x, op_scale, op_zero_point):
|
||||
|
||||
|
||||
@_onnx_symbolic("quantized::leaky_relu")
|
||||
@_beartype.beartype
|
||||
def quantized_leaky_relu(
|
||||
g: jit_utils.GraphContext, x, negative_slope, inplace, op_scale, op_zero_point
|
||||
):
|
||||
@ -892,6 +921,7 @@ def quantized_leaky_relu(
|
||||
|
||||
|
||||
@_onnx_symbolic("quantized::layer_norm")
|
||||
@_beartype.beartype
|
||||
def quantized_layer_norm(
|
||||
g: jit_utils.GraphContext,
|
||||
x,
|
||||
@ -910,6 +940,7 @@ def quantized_layer_norm(
|
||||
|
||||
|
||||
@_onnx_symbolic("quantized::group_norm")
|
||||
@_beartype.beartype
|
||||
def quantized_group_norm(
|
||||
g: jit_utils.GraphContext,
|
||||
x,
|
||||
@ -929,6 +960,7 @@ def quantized_group_norm(
|
||||
|
||||
@_onnx_symbolic("quantized::instance_norm")
|
||||
@symbolic_helper.parse_args("v", "v", "v", "f", "v", "v")
|
||||
@_beartype.beartype
|
||||
def quantized_instance_norm(
|
||||
g: jit_utils.GraphContext,
|
||||
q_input,
|
||||
@ -948,6 +980,7 @@ def quantized_instance_norm(
|
||||
|
||||
|
||||
@_onnx_symbolic("quantized::conv1d_relu")
|
||||
@_beartype.beartype
|
||||
def quantized_conv1d_relu(
|
||||
g: jit_utils.GraphContext,
|
||||
q_input,
|
||||
@ -972,6 +1005,7 @@ def quantized_conv1d_relu(
|
||||
|
||||
|
||||
@_onnx_symbolic("quantized::conv2d_relu")
|
||||
@_beartype.beartype
|
||||
def quantized_conv2d_relu(
|
||||
g: jit_utils.GraphContext,
|
||||
q_input,
|
||||
@ -996,6 +1030,7 @@ def quantized_conv2d_relu(
|
||||
|
||||
|
||||
@_onnx_symbolic("quantized::conv3d_relu")
|
||||
@_beartype.beartype
|
||||
def quantized_conv3d_relu(
|
||||
g: jit_utils.GraphContext,
|
||||
q_input,
|
||||
@ -1020,6 +1055,7 @@ def quantized_conv3d_relu(
|
||||
|
||||
|
||||
@_onnx_symbolic("quantized::conv1d")
|
||||
@_beartype.beartype
|
||||
def quantized_conv1d(
|
||||
g: jit_utils.GraphContext,
|
||||
q_input,
|
||||
@ -1043,6 +1079,7 @@ def quantized_conv1d(
|
||||
|
||||
|
||||
@_onnx_symbolic("quantized::conv2d")
|
||||
@_beartype.beartype
|
||||
def quantized_conv2d(
|
||||
g: jit_utils.GraphContext,
|
||||
q_input,
|
||||
@ -1066,6 +1103,7 @@ def quantized_conv2d(
|
||||
|
||||
|
||||
@_onnx_symbolic("quantized::conv3d")
|
||||
@_beartype.beartype
|
||||
def quantized_conv3d(
|
||||
g: jit_utils.GraphContext,
|
||||
q_input,
|
||||
@ -1089,6 +1127,7 @@ def quantized_conv3d(
|
||||
|
||||
|
||||
@_onnx_symbolic("quantized::conv_transpose1d")
|
||||
@_beartype.beartype
|
||||
def quantized_conv_transpose1d(
|
||||
g: jit_utils.GraphContext,
|
||||
q_input,
|
||||
@ -1115,6 +1154,7 @@ def quantized_conv_transpose1d(
|
||||
|
||||
|
||||
@_onnx_symbolic("quantized::conv_transpose2d")
|
||||
@_beartype.beartype
|
||||
def quantized_conv_transpose2d(
|
||||
g: jit_utils.GraphContext,
|
||||
q_input,
|
||||
@ -1141,6 +1181,7 @@ def quantized_conv_transpose2d(
|
||||
|
||||
|
||||
@_onnx_symbolic("quantized::conv_transpose3d")
|
||||
@_beartype.beartype
|
||||
def quantized_conv_transpose3d(
|
||||
g: jit_utils.GraphContext,
|
||||
q_input,
|
||||
@ -1168,6 +1209,7 @@ def quantized_conv_transpose3d(
|
||||
|
||||
@_onnx_symbolic("quantized::cat")
|
||||
@symbolic_helper.parse_args("v", "i", "v", "v")
|
||||
@_beartype.beartype
|
||||
def quantized_cat(
|
||||
g: jit_utils.GraphContext,
|
||||
q_inputs: _C.Value,
|
||||
|
@ -1,5 +1,4 @@
|
||||
# mypy: allow-untyped-defs
|
||||
# mypy: disable-error-code=arg-type
|
||||
"""This file exports ONNX ops for opset 11."""
|
||||
from __future__ import annotations
|
||||
|
||||
@ -19,7 +18,7 @@ from torch.onnx import (
|
||||
symbolic_opset9 as opset9,
|
||||
utils,
|
||||
)
|
||||
from torch.onnx._internal import jit_utils, registration
|
||||
from torch.onnx._internal import _beartype, jit_utils, registration
|
||||
|
||||
# EDITING THIS FILE? READ THIS FIRST!
|
||||
# see Note [Edit Symbolic Files] in README.md
|
||||
@ -90,6 +89,7 @@ _onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=11)
|
||||
@_onnx_symbolic("aten::hardtanh")
|
||||
@symbolic_helper.quantized_args(True)
|
||||
@symbolic_helper.parse_args("v", "f", "f")
|
||||
@_beartype.beartype
|
||||
def hardtanh(g: jit_utils.GraphContext, self: _C.Value, min_val: float, max_val: float):
|
||||
scalar_type = _type_utils.JitScalarType.from_value(
|
||||
self, _type_utils.JitScalarType.FLOAT
|
||||
@ -108,7 +108,9 @@ def hardtanh(g: jit_utils.GraphContext, self: _C.Value, min_val: float, max_val:
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::clamp")
|
||||
@_beartype.beartype
|
||||
def clamp(g: jit_utils.GraphContext, self, min, max):
|
||||
@_beartype.beartype
|
||||
def _cast_if_not_none(tensor, dtype):
|
||||
if tensor is not None and not symbolic_helper._is_none(tensor):
|
||||
return g.op(
|
||||
@ -144,6 +146,7 @@ def clamp(g: jit_utils.GraphContext, self, min, max):
|
||||
|
||||
@_onnx_symbolic("aten::clamp_min")
|
||||
@symbolic_helper.parse_args("v", "v")
|
||||
@_beartype.beartype
|
||||
def clamp_min(g: jit_utils.GraphContext, self, min):
|
||||
min = g.op("Cast", min, to_i=_type_utils.JitScalarType.from_value(self).onnx_type())
|
||||
if symbolic_helper._get_tensor_rank(min) == 0:
|
||||
@ -159,6 +162,7 @@ def clamp_min(g: jit_utils.GraphContext, self, min):
|
||||
|
||||
@_onnx_symbolic("aten::clamp_max")
|
||||
@symbolic_helper.parse_args("v", "v")
|
||||
@_beartype.beartype
|
||||
def clamp_max(g: jit_utils.GraphContext, self, max):
|
||||
max = g.op("Cast", max, to_i=_type_utils.JitScalarType.from_value(self).onnx_type())
|
||||
if symbolic_helper._get_tensor_rank(max) == 0:
|
||||
@ -173,6 +177,7 @@ def clamp_max(g: jit_utils.GraphContext, self, max):
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::relu6")
|
||||
@_beartype.beartype
|
||||
def relu6(g: jit_utils.GraphContext, input):
|
||||
scalar_type = _type_utils.JitScalarType.from_value(
|
||||
input, _type_utils.JitScalarType.FLOAT
|
||||
@ -192,11 +197,13 @@ def relu6(g: jit_utils.GraphContext, input):
|
||||
# Opset 11 gather accepts negative indices
|
||||
@symbolic_helper.quantized_args(True)
|
||||
@symbolic_helper.parse_args("v", "i", "v")
|
||||
@_beartype.beartype
|
||||
def select(g: jit_utils.GraphContext, self, dim, index):
|
||||
return g.op("Gather", self, index, axis_i=dim)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::index_put")
|
||||
@_beartype.beartype
|
||||
def index_put(
|
||||
g: jit_utils.GraphContext, self, indices_list_value, values, accumulate=False
|
||||
):
|
||||
@ -322,6 +329,7 @@ def index_put(
|
||||
|
||||
@_onnx_symbolic("aten::pixel_shuffle")
|
||||
@symbolic_helper.parse_args("v", "i")
|
||||
@_beartype.beartype
|
||||
def pixel_shuffle(g: jit_utils.GraphContext, self, upscale_factor):
|
||||
rank = symbolic_helper._get_tensor_rank(self)
|
||||
if rank is not None and rank != 4:
|
||||
@ -357,12 +365,14 @@ def pixel_shuffle(g: jit_utils.GraphContext, self, upscale_factor):
|
||||
"aten::upsample_bicubic2d",
|
||||
decorate=[symbolic_helper._apply_params("upsample_bicubic2d", 4, "cubic")],
|
||||
)
|
||||
@_beartype.beartype
|
||||
def _interpolate(name: str, dim: int, interpolate_mode: str):
|
||||
return symbolic_helper._interpolate_helper(name, dim, interpolate_mode)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::__interpolate")
|
||||
@symbolic_helper.quantized_args(True, False, False, False, False, False, False)
|
||||
@_beartype.beartype
|
||||
def __interpolate(
|
||||
g: jit_utils.GraphContext,
|
||||
input,
|
||||
@ -380,6 +390,7 @@ def __interpolate(
|
||||
|
||||
@_onnx_symbolic("aten::gather")
|
||||
@symbolic_helper.parse_args("v", "i", "v", "v")
|
||||
@_beartype.beartype
|
||||
def gather(g: jit_utils.GraphContext, self, dim, index, sparse_grad=False):
|
||||
if symbolic_helper._maybe_get_const(sparse_grad, "i"):
|
||||
return symbolic_helper._unimplemented("gather", "sparse_grad == True")
|
||||
@ -388,6 +399,7 @@ def gather(g: jit_utils.GraphContext, self, dim, index, sparse_grad=False):
|
||||
|
||||
@_onnx_symbolic("aten::scatter")
|
||||
@symbolic_helper.parse_args("v", "i", "v", "v")
|
||||
@_beartype.beartype
|
||||
def scatter(g: jit_utils.GraphContext, self, dim, index, src):
|
||||
src_type = _type_utils.JitScalarType.from_value(src)
|
||||
src = symbolic_helper._maybe_get_scalar(src)
|
||||
@ -409,6 +421,7 @@ def scatter(g: jit_utils.GraphContext, self, dim, index, src):
|
||||
|
||||
@_onnx_symbolic("aten::cumsum")
|
||||
@symbolic_helper.parse_args("v", "i", "none")
|
||||
@_beartype.beartype
|
||||
def cumsum(g: jit_utils.GraphContext, self, dim, dtype=None):
|
||||
dim_tensor = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.int))
|
||||
if dtype and dtype.node().kind() != "prim::Constant":
|
||||
@ -423,12 +436,14 @@ def cumsum(g: jit_utils.GraphContext, self, dim, dtype=None):
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::masked_select")
|
||||
@_beartype.beartype
|
||||
def masked_select(g: jit_utils.GraphContext, self, mask):
|
||||
index = opset9.nonzero(g, opset9.expand_as(g, mask, self))
|
||||
return g.op("GatherND", self, index)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::masked_scatter")
|
||||
@_beartype.beartype
|
||||
def masked_scatter(g: jit_utils.GraphContext, self, mask, source):
|
||||
index = opset9.nonzero(g, opset9.expand_as(g, mask, self))
|
||||
# NOTE: source can have more elements than needed.
|
||||
@ -446,6 +461,7 @@ def masked_scatter(g: jit_utils.GraphContext, self, mask, source):
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::len")
|
||||
@_beartype.beartype
|
||||
def _len(g: jit_utils.GraphContext, self):
|
||||
if (
|
||||
symbolic_helper._is_tensor_list(self)
|
||||
@ -457,6 +473,7 @@ def _len(g: jit_utils.GraphContext, self):
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::__getitem_")
|
||||
@_beartype.beartype
|
||||
def __getitem_(g: jit_utils.GraphContext, self, i):
|
||||
if symbolic_helper._is_tensor_list(self):
|
||||
# SequenceAt requires that the input be a List of Tensors
|
||||
@ -468,17 +485,20 @@ def __getitem_(g: jit_utils.GraphContext, self, i):
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::_set_item")
|
||||
@_beartype.beartype
|
||||
def _set_item(g: jit_utils.GraphContext, tensor_list, i, v):
|
||||
tensor_list = g.op("SequenceErase", tensor_list, i)
|
||||
return g.op("SequenceInsert", tensor_list, v, i)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::append")
|
||||
@_beartype.beartype
|
||||
def append(g: jit_utils.GraphContext, self, tensor):
|
||||
return g.op("SequenceInsert", self, tensor)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::add")
|
||||
@_beartype.beartype
|
||||
def add(g: jit_utils.GraphContext, self, other, alpha=None):
|
||||
if symbolic_helper._is_value(self) and symbolic_helper._is_tensor_list(self):
|
||||
tensor_list_node = other.node()
|
||||
@ -496,22 +516,26 @@ def add(g: jit_utils.GraphContext, self, other, alpha=None):
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::insert")
|
||||
@_beartype.beartype
|
||||
def insert(g: jit_utils.GraphContext, self, pos, tensor):
|
||||
return g.op("SequenceInsert", self, tensor, pos)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::pop")
|
||||
@_beartype.beartype
|
||||
def pop(g: jit_utils.GraphContext, tensor_list, dim):
|
||||
return g.op("SequenceErase", tensor_list, dim)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::Delete")
|
||||
@_beartype.beartype
|
||||
def Delete(g: jit_utils.GraphContext, tensor_list, dim):
|
||||
return g.op("SequenceErase", tensor_list, dim)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::cat")
|
||||
@symbolic_helper.quantized_args(True)
|
||||
@_beartype.beartype
|
||||
def cat(g: jit_utils.GraphContext, tensor_list, dim):
|
||||
if symbolic_helper._is_packed_list(tensor_list):
|
||||
return opset9.cat(g, tensor_list, dim)
|
||||
@ -521,6 +545,7 @@ def cat(g: jit_utils.GraphContext, tensor_list, dim):
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::stack")
|
||||
@_beartype.beartype
|
||||
def stack(g: jit_utils.GraphContext, tensor_list, dim):
|
||||
if symbolic_helper._is_packed_list(tensor_list):
|
||||
return opset9.stack(g, tensor_list, dim)
|
||||
@ -531,6 +556,7 @@ def stack(g: jit_utils.GraphContext, tensor_list, dim):
|
||||
|
||||
@_onnx_symbolic("aten::_unique2")
|
||||
@symbolic_helper.parse_args("v", "i", "i", "i")
|
||||
@_beartype.beartype
|
||||
def _unique2(g: jit_utils.GraphContext, self, sorted, return_inverse, return_counts):
|
||||
u, indices, inverse_indices, counts = g.op(
|
||||
"Unique", self, sorted_i=sorted, outputs=4
|
||||
@ -540,6 +566,7 @@ def _unique2(g: jit_utils.GraphContext, self, sorted, return_inverse, return_cou
|
||||
|
||||
@_onnx_symbolic("aten::unique_dim")
|
||||
@symbolic_helper.parse_args("v", "i", "i", "i", "i")
|
||||
@_beartype.beartype
|
||||
def unique_dim(
|
||||
g: jit_utils.GraphContext, self, dim, sorted, return_inverse, return_counts
|
||||
):
|
||||
@ -551,6 +578,7 @@ def unique_dim(
|
||||
|
||||
@_onnx_symbolic("aten::topk")
|
||||
@symbolic_helper.parse_args("v", "v", "i", "i", "i", "none")
|
||||
@_beartype.beartype
|
||||
def topk(g: jit_utils.GraphContext, self, k, dim, largest, sorted, out=None):
|
||||
return symbolic_helper._topk_helper(
|
||||
g, self, k, dim, largest=largest, sorted=sorted, out=out
|
||||
@ -559,12 +587,14 @@ def topk(g: jit_utils.GraphContext, self, k, dim, largest, sorted, out=None):
|
||||
|
||||
@_onnx_symbolic("aten::sort")
|
||||
@symbolic_helper.parse_args("v", "i", "i", "none")
|
||||
@_beartype.beartype
|
||||
def sort(g: jit_utils.GraphContext, self, dim, decending, out=None):
|
||||
return symbolic_helper._sort_helper(g, self, dim, decending=decending, out=out)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::argsort")
|
||||
@symbolic_helper.parse_args("v", "i", "i", "none")
|
||||
@_beartype.beartype
|
||||
def argsort(g: jit_utils.GraphContext, self, dim, decending, out=None):
|
||||
_, indices = symbolic_helper._sort_helper(
|
||||
g, self, dim, decending=decending, out=out
|
||||
@ -574,6 +604,7 @@ def argsort(g: jit_utils.GraphContext, self, dim, decending, out=None):
|
||||
|
||||
@_onnx_symbolic("aten::round")
|
||||
@symbolic_helper.parse_args("v", "i")
|
||||
@_beartype.beartype
|
||||
def round(g: jit_utils.GraphContext, self, decimals=0):
|
||||
if not symbolic_helper._is_fp(self):
|
||||
return self
|
||||
@ -587,6 +618,7 @@ def round(g: jit_utils.GraphContext, self, decimals=0):
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::remainder")
|
||||
@_beartype.beartype
|
||||
def remainder(g: jit_utils.GraphContext, input, other):
|
||||
if symbolic_helper._is_fp(input) or symbolic_helper._is_fp(other):
|
||||
return opset9.remainder(g, input, other)
|
||||
@ -595,6 +627,7 @@ def remainder(g: jit_utils.GraphContext, input, other):
|
||||
|
||||
@_onnx_symbolic("aten::split")
|
||||
@symbolic_helper.parse_args("v", "v", "i", "i")
|
||||
@_beartype.beartype
|
||||
def split(g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None):
|
||||
if not symbolic_helper._is_split_static(split_size_or_sizes, _outputs):
|
||||
split_out = g.op("SplitToSequence", self, split_size_or_sizes, axis_i=dim)
|
||||
@ -633,12 +666,14 @@ def split(g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=No
|
||||
|
||||
@_onnx_symbolic("aten::split_with_sizes")
|
||||
@symbolic_helper.parse_args("v", "v", "i", "i")
|
||||
@_beartype.beartype
|
||||
def split_with_sizes(g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None):
|
||||
return split(g, self, split_sizes, dim, _outputs)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::unbind")
|
||||
@symbolic_helper.parse_args("v", "i", "i")
|
||||
@_beartype.beartype
|
||||
def unbind(g: jit_utils.GraphContext, self, dim=0, _outputs=None):
|
||||
if _outputs is None:
|
||||
return g.op(
|
||||
@ -652,6 +687,7 @@ def unbind(g: jit_utils.GraphContext, self, dim=0, _outputs=None):
|
||||
return opset9.unbind(g, self, dim, _outputs)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _prepare_onnx_paddings(g: jit_utils.GraphContext, input, pad):
|
||||
"""Generate paddings in ONNX order based on pad in pytorch.
|
||||
|
||||
@ -710,6 +746,7 @@ def _prepare_onnx_paddings(g: jit_utils.GraphContext, input, pad):
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::constant_pad_nd")
|
||||
@_beartype.beartype
|
||||
def constant_pad_nd(g: jit_utils.GraphContext, input, padding, value=None):
|
||||
mode = "constant"
|
||||
value = symbolic_helper._maybe_get_scalar(value)
|
||||
@ -721,6 +758,7 @@ def constant_pad_nd(g: jit_utils.GraphContext, input, padding, value=None):
|
||||
@_onnx_symbolic("aten::reflection_pad1d")
|
||||
@_onnx_symbolic("aten::reflection_pad2d")
|
||||
@_onnx_symbolic("aten::reflection_pad3d")
|
||||
@_beartype.beartype
|
||||
def reflection_pad(g: jit_utils.GraphContext, input, padding):
|
||||
mode = "reflect"
|
||||
paddings = _prepare_onnx_paddings(g, input, padding)
|
||||
@ -730,6 +768,7 @@ def reflection_pad(g: jit_utils.GraphContext, input, padding):
|
||||
@_onnx_symbolic("aten::replication_pad1d")
|
||||
@_onnx_symbolic("aten::replication_pad2d")
|
||||
@_onnx_symbolic("aten::replication_pad3d")
|
||||
@_beartype.beartype
|
||||
def replication_pad(g: jit_utils.GraphContext, input, padding):
|
||||
mode = "edge"
|
||||
paddings = _prepare_onnx_paddings(g, input, padding)
|
||||
@ -737,6 +776,7 @@ def replication_pad(g: jit_utils.GraphContext, input, padding):
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::pad")
|
||||
@_beartype.beartype
|
||||
def pad(
|
||||
g: jit_utils.GraphContext,
|
||||
input: _C.Value,
|
||||
@ -758,16 +798,19 @@ def pad(
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::linalg_det")
|
||||
@_beartype.beartype
|
||||
def linalg_det(g: jit_utils.GraphContext, self):
|
||||
return g.op("Det", self)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::logdet")
|
||||
@_beartype.beartype
|
||||
def logdet(g: jit_utils.GraphContext, input):
|
||||
return opset9.log(g, linalg_det(g, input))
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::arange")
|
||||
@_beartype.beartype
|
||||
def arange(g: jit_utils.GraphContext, *args):
|
||||
def _get_arange_dtype(dtype):
|
||||
dtype = symbolic_helper._maybe_get_const(dtype, "i")
|
||||
@ -841,6 +884,7 @@ def arange(g: jit_utils.GraphContext, *args):
|
||||
|
||||
@_onnx_symbolic("aten::_dim_arange")
|
||||
@symbolic_helper.parse_args("v", "i")
|
||||
@_beartype.beartype
|
||||
def _dim_arange(g: jit_utils.GraphContext, like, dim):
|
||||
like_shape = g.op("Shape", like)
|
||||
stop = g.op(
|
||||
@ -851,6 +895,7 @@ def _dim_arange(g: jit_utils.GraphContext, like, dim):
|
||||
|
||||
@_onnx_symbolic("aten::size")
|
||||
@symbolic_helper.quantized_args(True, quantize_output=False)
|
||||
@_beartype.beartype
|
||||
def size(g: jit_utils.GraphContext, self, dim=None):
|
||||
if dim is None:
|
||||
return g.op("Shape", self)
|
||||
@ -858,6 +903,7 @@ def size(g: jit_utils.GraphContext, self, dim=None):
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::squeeze")
|
||||
@_beartype.beartype
|
||||
def squeeze(g: jit_utils.GraphContext, self, dim=None):
|
||||
if dim is None:
|
||||
return g.op("Squeeze", self)
|
||||
@ -909,6 +955,7 @@ def squeeze(g: jit_utils.GraphContext, self, dim=None):
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::unsqueeze")
|
||||
@_beartype.beartype
|
||||
def unsqueeze(g: jit_utils.GraphContext, self, dim):
|
||||
if symbolic_helper._is_constant(dim):
|
||||
dim = symbolic_helper._get_const(dim, "i", "dim")
|
||||
@ -917,11 +964,13 @@ def unsqueeze(g: jit_utils.GraphContext, self, dim):
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::mm")
|
||||
@_beartype.beartype
|
||||
def mm(g: jit_utils.GraphContext, self, other):
|
||||
return g.op("Gemm", self, other, beta_f=0.0, alpha_f=1.0)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::index")
|
||||
@_beartype.beartype
|
||||
def index(g: jit_utils.GraphContext, self, index):
|
||||
if symbolic_helper._is_packed_list(index):
|
||||
indices = symbolic_helper._unpack_list(index)
|
||||
@ -942,6 +991,7 @@ def index(g: jit_utils.GraphContext, self, index):
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::index_fill")
|
||||
@_beartype.beartype
|
||||
def index_fill(g: jit_utils.GraphContext, self, dim, index, value):
|
||||
dim_value = symbolic_helper._parse_arg(dim, "i")
|
||||
expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper(
|
||||
@ -954,6 +1004,7 @@ def index_fill(g: jit_utils.GraphContext, self, dim, index, value):
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::index_copy")
|
||||
@_beartype.beartype
|
||||
def index_copy(g: jit_utils.GraphContext, self, dim, index, source):
|
||||
dim_value = symbolic_helper._parse_arg(dim, "i")
|
||||
expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper(
|
||||
@ -964,6 +1015,7 @@ def index_copy(g: jit_utils.GraphContext, self, dim, index, source):
|
||||
|
||||
@_onnx_symbolic("aten::bitwise_right_shift")
|
||||
@_onnx_symbolic("aten::__rshift_")
|
||||
@_beartype.beartype
|
||||
def __rshift_(g: jit_utils.GraphContext, self, other):
|
||||
# make sure to cast other to self's type
|
||||
# (when self is long, make sure that other is not float)
|
||||
@ -998,6 +1050,7 @@ def __rshift_(g: jit_utils.GraphContext, self, other):
|
||||
|
||||
@_onnx_symbolic("aten::bitwise_left_shift")
|
||||
@_onnx_symbolic("aten::__lshift_")
|
||||
@_beartype.beartype
|
||||
def __lshift_(g: jit_utils.GraphContext, self, other):
|
||||
# make sure to cast other to self's type
|
||||
# (when self is long, make sure that other is not float)
|
||||
@ -1030,6 +1083,7 @@ def __lshift_(g: jit_utils.GraphContext, self, other):
|
||||
return lshift
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _get_im2col_indices_along_dim(
|
||||
g: jit_utils.GraphContext, input_d, kernel_size_d, dilation_d, padding_d, stride_d
|
||||
):
|
||||
@ -1073,6 +1127,7 @@ def _get_im2col_indices_along_dim(
|
||||
return block_mask
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _get_im2col_padded_input(g: jit_utils.GraphContext, input, padding_h, padding_w):
|
||||
# Input is always 4-D tensor (N, C, H, W)
|
||||
# Padding tensor has the following format: (padding_h, padding_w)
|
||||
@ -1081,6 +1136,7 @@ def _get_im2col_padded_input(g: jit_utils.GraphContext, input, padding_h, paddin
|
||||
return g.op("Pad", input, pad)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _get_im2col_output_shape(g: jit_utils.GraphContext, input, kernel_h, kernel_w):
|
||||
batch_dim = size(g, input, g.op("Constant", value_t=torch.tensor(0)))
|
||||
channel_dim = size(g, input, g.op("Constant", value_t=torch.tensor(1)))
|
||||
@ -1099,6 +1155,7 @@ def _get_im2col_output_shape(g: jit_utils.GraphContext, input, kernel_h, kernel_
|
||||
|
||||
@_onnx_symbolic("aten::im2col")
|
||||
@symbolic_helper.parse_args("v", "is", "is", "is", "is")
|
||||
@_beartype.beartype
|
||||
def im2col(g: jit_utils.GraphContext, input, kernel_size, dilation, padding, stride):
|
||||
# Input is always 4-D tensor (N, C, H, W)
|
||||
# All other args are int[2]
|
||||
@ -1151,6 +1208,7 @@ def im2col(g: jit_utils.GraphContext, input, kernel_size, dilation, padding, str
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::narrow")
|
||||
@_beartype.beartype
|
||||
def narrow(g: jit_utils.GraphContext, input, dim, start, length):
|
||||
end = g.op("Add", start, length)
|
||||
return symbolic_helper._slice_helper(g, input, axes=dim, starts=start, ends=end)
|
||||
@ -1159,6 +1217,7 @@ def narrow(g: jit_utils.GraphContext, input, dim, start, length):
|
||||
@_onnx_symbolic("aten::flatten")
|
||||
@symbolic_helper.quantized_args(True, False, False)
|
||||
@symbolic_helper.parse_args("v", "i", "i")
|
||||
@_beartype.beartype
|
||||
def flatten(g: jit_utils.GraphContext, input, start_dim, end_dim):
|
||||
dim = symbolic_helper._get_tensor_rank(input)
|
||||
if dim == 1:
|
||||
@ -1185,6 +1244,7 @@ def flatten(g: jit_utils.GraphContext, input, start_dim, end_dim):
|
||||
|
||||
@_onnx_symbolic("aten::linalg_vector_norm")
|
||||
@symbolic_helper.parse_args("v", "f", "is", "b", "v")
|
||||
@_beartype.beartype
|
||||
def linalg_vector_norm(
|
||||
g: jit_utils.GraphContext,
|
||||
self,
|
||||
@ -1198,6 +1258,7 @@ def linalg_vector_norm(
|
||||
|
||||
@_onnx_symbolic("aten::embedding_bag")
|
||||
@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i")
|
||||
@_beartype.beartype
|
||||
def embedding_bag(
|
||||
g: jit_utils.GraphContext,
|
||||
embedding_matrix,
|
||||
@ -1226,6 +1287,7 @@ def embedding_bag(
|
||||
|
||||
@_onnx_symbolic("aten::embedding_renorm")
|
||||
@symbolic_helper.parse_args("v", "v", "f", "f")
|
||||
@_beartype.beartype
|
||||
def embedding_renorm(g: jit_utils.GraphContext, weight, indices, max_norm, norm_type):
|
||||
unique_indices = g.op("Unique", indices)
|
||||
partial_weight = g.op("Gather", weight, unique_indices)
|
||||
@ -1264,6 +1326,7 @@ def embedding_renorm(g: jit_utils.GraphContext, weight, indices, max_norm, norm_
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::chunk")
|
||||
@_beartype.beartype
|
||||
def chunk(g: jit_utils.GraphContext, self, chunks, dim):
|
||||
# Calculate chunk size for dynamic chunk
|
||||
dim_size = g.op("Gather", g.op("Shape", self), dim, axis_i=0)
|
||||
@ -1281,6 +1344,7 @@ def chunk(g: jit_utils.GraphContext, self, chunks, dim):
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::normal")
|
||||
@_beartype.beartype
|
||||
def normal(
|
||||
g: jit_utils.GraphContext,
|
||||
mean,
|
||||
@ -1304,6 +1368,7 @@ def normal(
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::atleast_1d")
|
||||
@_beartype.beartype
|
||||
def atleast_1d(g: jit_utils.GraphContext, self: torch._C.Value):
|
||||
# NOTE: If it's 0D, reshape to 1D
|
||||
|
||||
@ -1330,6 +1395,7 @@ def atleast_1d(g: jit_utils.GraphContext, self: torch._C.Value):
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::atleast_2d")
|
||||
@_beartype.beartype
|
||||
def atleast_2d(g: jit_utils.GraphContext, self: torch._C.Value):
|
||||
# NOTE: If it's 0D, reshape to 2D
|
||||
# If it's 1D, unsqueeze to 2D
|
||||
@ -1363,6 +1429,7 @@ def atleast_2d(g: jit_utils.GraphContext, self: torch._C.Value):
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::atleast_3d")
|
||||
@_beartype.beartype
|
||||
def atleast_3d(g: jit_utils.GraphContext, self: torch._C.Value):
|
||||
# NOTE: If it's 0D, reshape to 3D
|
||||
# If it's 1D, unsqueeze to 3D
|
||||
@ -1407,6 +1474,7 @@ def atleast_3d(g: jit_utils.GraphContext, self: torch._C.Value):
|
||||
|
||||
|
||||
@_onnx_symbolic("prim::ConstantChunk")
|
||||
@_beartype.beartype
|
||||
def prim_constant_chunk(g: jit_utils.GraphContext, self, chunks, dim):
|
||||
input_shape = g.op("Shape", self)
|
||||
axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long))
|
||||
@ -1428,6 +1496,7 @@ def prim_constant_chunk(g: jit_utils.GraphContext, self, chunks, dim):
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::hstack")
|
||||
@_beartype.beartype
|
||||
def hstack(g: jit_utils.GraphContext, tensor_list: _C.Value):
|
||||
tensor_list = atleast_1d(g, tensor_list)
|
||||
first_tensor = g.op(
|
||||
@ -1460,6 +1529,7 @@ def hstack(g: jit_utils.GraphContext, tensor_list: _C.Value):
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::vstack")
|
||||
@_beartype.beartype
|
||||
def vstack(g: jit_utils.GraphContext, tensor_list: _C.Value):
|
||||
tensor_list = atleast_2d(g, tensor_list)
|
||||
return g.op("ConcatFromSequence", tensor_list, axis_i=0, new_axis_i=0)
|
||||
|
@ -1,5 +1,4 @@
|
||||
# mypy: allow-untyped-defs
|
||||
# mypy: disable-error-code=arg-type
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
@ -15,7 +14,7 @@ from torch.onnx import (
|
||||
symbolic_opset9 as opset9,
|
||||
utils,
|
||||
)
|
||||
from torch.onnx._internal import jit_utils, registration
|
||||
from torch.onnx._internal import _beartype, jit_utils, registration
|
||||
|
||||
|
||||
# EDITING THIS FILE? READ THIS FIRST!
|
||||
@ -46,6 +45,7 @@ __all__ = [
|
||||
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=12)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _einsum_helper(g: jit_utils.GraphContext, equation, tensors):
|
||||
if not tensors:
|
||||
raise RuntimeError("Einsum inputs are empty.")
|
||||
@ -66,6 +66,7 @@ def _einsum_helper(g: jit_utils.GraphContext, equation, tensors):
|
||||
|
||||
@_onnx_symbolic("aten::einsum")
|
||||
@symbolic_helper.parse_args("s", "v", "is")
|
||||
@_beartype.beartype
|
||||
def einsum(g: jit_utils.GraphContext, equation, tensor_list, path=None):
|
||||
tensors = symbolic_helper._unpack_list(tensor_list)
|
||||
return _einsum_helper(g, equation, tensors)
|
||||
@ -73,6 +74,7 @@ def einsum(g: jit_utils.GraphContext, equation, tensor_list, path=None):
|
||||
|
||||
@_onnx_symbolic("aten::outer")
|
||||
@symbolic_helper.parse_args("v", "v")
|
||||
@_beartype.beartype
|
||||
def outer(g: jit_utils.GraphContext, input, other):
|
||||
# make sure to cast other to self's type
|
||||
if _type_utils.JitScalarType.from_value(
|
||||
@ -86,6 +88,7 @@ def outer(g: jit_utils.GraphContext, input, other):
|
||||
return _einsum_helper(g, "i,j->ij", [input, other])
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _dropout_returns_masked_input_and_mask(
|
||||
g: jit_utils.GraphContext, input: torch._C.Value, p: float, train: bool
|
||||
) -> Tuple[torch._C.Value, Optional[torch._C.Value]]:
|
||||
@ -102,6 +105,7 @@ def _dropout_returns_masked_input_and_mask(
|
||||
|
||||
@_onnx_symbolic("aten::dropout")
|
||||
@symbolic_helper.parse_args("v", "f", "b")
|
||||
@_beartype.beartype
|
||||
def dropout(g: jit_utils.GraphContext, input, p, train):
|
||||
masked, _ = _dropout_returns_masked_input_and_mask(g, input, p, train)
|
||||
return masked
|
||||
@ -109,11 +113,13 @@ def dropout(g: jit_utils.GraphContext, input, p, train):
|
||||
|
||||
@_onnx_symbolic("aten::native_dropout")
|
||||
@symbolic_helper.parse_args("v", "f", "b")
|
||||
@_beartype.beartype
|
||||
def native_dropout(g: jit_utils.GraphContext, input, p, train):
|
||||
return _dropout_returns_masked_input_and_mask(g, input, p, train)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::nll_loss")
|
||||
@_beartype.beartype
|
||||
def nll_loss(g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index):
|
||||
# none reduction : onnx::Constant[value={0}]
|
||||
# mean reduction : onnx::Constant[value={1}]
|
||||
@ -147,6 +153,7 @@ def nll_loss(g: jit_utils.GraphContext, self, target, weight, reduction, ignore_
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::nll_loss2d")
|
||||
@_beartype.beartype
|
||||
def nll_loss2d(
|
||||
g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index
|
||||
):
|
||||
@ -154,6 +161,7 @@ def nll_loss2d(
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::nll_loss_nd")
|
||||
@_beartype.beartype
|
||||
def nll_loss_nd(
|
||||
g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index
|
||||
):
|
||||
@ -161,6 +169,7 @@ def nll_loss_nd(
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::cross_entropy_loss")
|
||||
@_beartype.beartype
|
||||
def cross_entropy_loss(
|
||||
g: jit_utils.GraphContext,
|
||||
self,
|
||||
@ -209,6 +218,7 @@ def cross_entropy_loss(
|
||||
|
||||
@_onnx_symbolic("aten::binary_cross_entropy_with_logits")
|
||||
@symbolic_helper.parse_args("v", "v", "v", "v", "i")
|
||||
@_beartype.beartype
|
||||
def binary_cross_entropy_with_logits(
|
||||
g: jit_utils.GraphContext, input, target, weight, pos_weight, reduction
|
||||
):
|
||||
@ -253,6 +263,7 @@ def binary_cross_entropy_with_logits(
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::celu")
|
||||
@_beartype.beartype
|
||||
def celu(g: jit_utils.GraphContext, self, alpha):
|
||||
alpha = symbolic_helper._maybe_get_const(alpha, "f")
|
||||
# if the input is of type double cast it to float
|
||||
@ -269,6 +280,7 @@ def celu(g: jit_utils.GraphContext, self, alpha):
|
||||
|
||||
@_onnx_symbolic("aten::argmax")
|
||||
@symbolic_helper.parse_args("v", "v", "b")
|
||||
@_beartype.beartype
|
||||
def argmax(
|
||||
g: jit_utils.GraphContext,
|
||||
input: torch._C.Value,
|
||||
@ -280,6 +292,7 @@ def argmax(
|
||||
|
||||
@_onnx_symbolic("aten::argmin")
|
||||
@symbolic_helper.parse_args("v", "v", "b")
|
||||
@_beartype.beartype
|
||||
def argmin(
|
||||
g: jit_utils.GraphContext,
|
||||
input: torch._C.Value,
|
||||
@ -290,22 +303,26 @@ def argmin(
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::pow")
|
||||
@_beartype.beartype
|
||||
def pow(g: jit_utils.GraphContext, self, exponent):
|
||||
return g.op("Pow", self, exponent)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::ge")
|
||||
@_beartype.beartype
|
||||
def ge(g: jit_utils.GraphContext, input, other):
|
||||
return g.op("GreaterOrEqual", input, other)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::le")
|
||||
@_beartype.beartype
|
||||
def le(g: jit_utils.GraphContext, input, other):
|
||||
return g.op("LessOrEqual", input, other)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::unfold")
|
||||
@symbolic_helper.parse_args("v", "i", "v", "v")
|
||||
@_beartype.beartype
|
||||
def unfold(g: jit_utils.GraphContext, input, dimension, size, step):
|
||||
const_size = symbolic_helper._maybe_get_const(size, "i")
|
||||
const_step = symbolic_helper._maybe_get_const(step, "i")
|
||||
@ -382,6 +399,7 @@ def unfold(g: jit_utils.GraphContext, input, dimension, size, step):
|
||||
|
||||
@_onnx_symbolic("aten::tensordot")
|
||||
@symbolic_helper.parse_args("v", "v", "is", "is", "v")
|
||||
@_beartype.beartype
|
||||
def tensordot(g: jit_utils.GraphContext, input_a, input_b, dims_a, dims_b, out=None):
|
||||
if out is not None:
|
||||
symbolic_helper._unimplemented(
|
||||
|
@ -16,7 +16,7 @@ from torch.onnx import (
|
||||
symbolic_opset9 as opset9,
|
||||
utils,
|
||||
)
|
||||
from torch.onnx._internal import jit_utils, registration
|
||||
from torch.onnx._internal import _beartype, jit_utils, registration
|
||||
|
||||
|
||||
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=13)
|
||||
@ -24,6 +24,7 @@ _onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=13)
|
||||
|
||||
@_onnx_symbolic("aten::softmax")
|
||||
@symbolic_helper.parse_args("v", "i", "none")
|
||||
@_beartype.beartype
|
||||
def softmax(g: jit_utils.GraphContext, input, dim, dtype=None):
|
||||
softmax = g.op("Softmax", input, axis_i=dim)
|
||||
if dtype and dtype.node().kind() != "prim::Constant":
|
||||
@ -37,6 +38,7 @@ def softmax(g: jit_utils.GraphContext, input, dim, dtype=None):
|
||||
|
||||
@_onnx_symbolic("aten::log_softmax")
|
||||
@symbolic_helper.parse_args("v", "i", "none")
|
||||
@_beartype.beartype
|
||||
def log_softmax(g: jit_utils.GraphContext, input, dim, dtype=None):
|
||||
return_op = g.op("LogSoftmax", input, axis_i=dim)
|
||||
if dtype and dtype.node().kind() != "prim::Constant":
|
||||
@ -49,6 +51,7 @@ def log_softmax(g: jit_utils.GraphContext, input, dim, dtype=None):
|
||||
|
||||
@_onnx_symbolic("aten::frobenius_norm")
|
||||
@symbolic_helper.parse_args("v", "v", "i")
|
||||
@_beartype.beartype
|
||||
def frobenius_norm(g: jit_utils.GraphContext, self, dim=None, keepdim=False):
|
||||
dim_val = symbolic_helper._maybe_get_const(dim, "is")
|
||||
if not symbolic_helper._is_value(dim_val) and len(dim_val) == 0:
|
||||
@ -60,6 +63,7 @@ def frobenius_norm(g: jit_utils.GraphContext, self, dim=None, keepdim=False):
|
||||
|
||||
@_onnx_symbolic("aten::split")
|
||||
@symbolic_helper.parse_args("v", "v", "i", "i")
|
||||
@_beartype.beartype
|
||||
def split(g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None):
|
||||
if not symbolic_helper._is_split_static(split_size_or_sizes, _outputs):
|
||||
split_out = g.op("SplitToSequence", self, split_size_or_sizes, axis_i=dim)
|
||||
@ -116,11 +120,13 @@ def split(g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=No
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::split_with_sizes")
|
||||
@_beartype.beartype
|
||||
def split_with_sizes(g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None):
|
||||
return split(g, self, split_sizes, dim, _outputs)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::unsafe_split")
|
||||
@_beartype.beartype
|
||||
def unsafe_split(
|
||||
g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None
|
||||
):
|
||||
@ -128,6 +134,7 @@ def unsafe_split(
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::unsafe_split_with_sizes")
|
||||
@_beartype.beartype
|
||||
def unsafe_split_with_sizes(
|
||||
g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None
|
||||
):
|
||||
@ -136,6 +143,7 @@ def unsafe_split_with_sizes(
|
||||
|
||||
@_onnx_symbolic("aten::tensor_split")
|
||||
@symbolic_helper.parse_args("v", "v", "i", "i")
|
||||
@_beartype.beartype
|
||||
def tensor_split(
|
||||
g: jit_utils.GraphContext, self, indices_or_sections, dim, _outputs=None
|
||||
):
|
||||
@ -274,6 +282,7 @@ def tensor_split(
|
||||
|
||||
@_onnx_symbolic("aten::unbind")
|
||||
@symbolic_helper.parse_args("v", "i", "i")
|
||||
@_beartype.beartype
|
||||
def unbind(g: jit_utils.GraphContext, self, dim=0, _outputs=None):
|
||||
if _outputs is None:
|
||||
return g.op(
|
||||
@ -296,12 +305,14 @@ def unbind(g: jit_utils.GraphContext, self, dim=0, _outputs=None):
|
||||
|
||||
@_onnx_symbolic("aten::nonzero_numpy")
|
||||
# Emitted from `torch.nonzero(x, as_tuple=True)`
|
||||
@_beartype.beartype
|
||||
def nonzero_numpy(g: jit_utils.GraphContext, input, _outputs=None):
|
||||
return unbind(g, opset9.nonzero(g, input), 1, _outputs=_outputs)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::where")
|
||||
@symbolic_helper.parse_args("v", "v", "v", "i")
|
||||
@_beartype.beartype
|
||||
def where(g: jit_utils.GraphContext, condition, self=None, other=None, _outputs=None):
|
||||
# Assumes that torch.where's first argument takes only Bool and Byte tensors.
|
||||
if not symbolic_helper._is_bool(condition):
|
||||
@ -316,6 +327,7 @@ def where(g: jit_utils.GraphContext, condition, self=None, other=None, _outputs=
|
||||
|
||||
@_onnx_symbolic("aten::fake_quantize_per_channel_affine")
|
||||
@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i")
|
||||
@_beartype.beartype
|
||||
def fake_quantize_per_channel_affine(
|
||||
g: jit_utils.GraphContext,
|
||||
inputs,
|
||||
@ -351,6 +363,7 @@ def fake_quantize_per_channel_affine(
|
||||
|
||||
@_onnx_symbolic("aten::fake_quantize_per_tensor_affine")
|
||||
@symbolic_helper.parse_args("v", "v", "v", "i", "i")
|
||||
@_beartype.beartype
|
||||
def fake_quantize_per_tensor_affine(
|
||||
g: jit_utils.GraphContext,
|
||||
inputs,
|
||||
@ -387,7 +400,9 @@ def fake_quantize_per_tensor_affine(
|
||||
return g.op("DequantizeLinear", quantized, scale, zero_point)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _reduce_op_symbolic(onnx_op_name):
|
||||
@_beartype.beartype
|
||||
def symbolic(g, self, dim=None, keepdim=None):
|
||||
self = symbolic_helper._maybe_cast_reduce_op_input(g, self)
|
||||
if dim is None:
|
||||
@ -404,12 +419,15 @@ def _reduce_op_symbolic(onnx_op_name):
|
||||
"aten::sum",
|
||||
decorate=[symbolic_helper._apply_params("ReduceSum", "sum")],
|
||||
)
|
||||
@_beartype.beartype
|
||||
def _reduce_with_dtype(onnx_op, name):
|
||||
symbolic = _reduce_op_symbolic(onnx_op)
|
||||
|
||||
@symbolic_helper._overload_by_arg_count
|
||||
@_beartype.beartype
|
||||
def reduce(g, *args, **kwargs):
|
||||
@symbolic_helper.parse_args("v", "none")
|
||||
@_beartype.beartype
|
||||
def reduce_nodim(g, self, dtype):
|
||||
dtype_onnx = None
|
||||
if dtype.node().kind() == "onnx::Constant":
|
||||
@ -428,6 +446,7 @@ def _reduce_with_dtype(onnx_op, name):
|
||||
return result
|
||||
|
||||
@symbolic_helper.parse_args("v", "v", "i", "none")
|
||||
@_beartype.beartype
|
||||
def reduce_dim(g, self, dim, keepdim, dtype):
|
||||
dtype_onnx = None
|
||||
if dtype.node().kind() == "onnx::Constant":
|
||||
@ -454,6 +473,7 @@ def _reduce_with_dtype(onnx_op, name):
|
||||
# https://github.com/microsoft/onnxscript/blob/6b1b81700b4523f31d8c6d3321e5d8ef5d42b764/onnxscript/function_libs/torch_aten/ops/core.py#L6097
|
||||
# NOTE: Supporting aten::unflatten before opset13 needs helper function to adjust ONNX op changes in Concat, Slice, ...
|
||||
@_onnx_symbolic("aten::unflatten")
|
||||
@_beartype.beartype
|
||||
def unflatten(g: jit_utils.GraphContext, input, dim, unflattened_size):
|
||||
input_dim = symbolic_helper._get_tensor_rank(input)
|
||||
if input_dim is None:
|
||||
@ -498,6 +518,7 @@ def unflatten(g: jit_utils.GraphContext, input, dim, unflattened_size):
|
||||
|
||||
@_onnx_symbolic("aten::unsafe_chunk")
|
||||
@symbolic_helper.parse_args("v", "i", "i", "i")
|
||||
@_beartype.beartype
|
||||
def unsafe_chunk(g: jit_utils.GraphContext, self, chunks, dim, _outputs=None):
|
||||
if _outputs is None:
|
||||
return g.op(
|
||||
@ -525,6 +546,7 @@ def unsafe_chunk(g: jit_utils.GraphContext, self, chunks, dim, _outputs=None):
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::tile")
|
||||
@_beartype.beartype
|
||||
def tile(g: jit_utils.GraphContext, self, dims):
|
||||
self_shape = g.op("Shape", self)
|
||||
self_rank = g.op("Size", self_shape)
|
||||
@ -580,6 +602,7 @@ def tile(g: jit_utils.GraphContext, self, dims):
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::repeat_interleave")
|
||||
@_beartype.beartype
|
||||
def repeat_interleave(
|
||||
g: jit_utils.GraphContext, self, repeats, dim=None, output_size=None
|
||||
):
|
||||
@ -722,6 +745,7 @@ def repeat_interleave(
|
||||
|
||||
@_onnx_symbolic("aten::diagonal")
|
||||
@symbolic_helper.parse_args("v", "i", "i", "i")
|
||||
@_beartype.beartype
|
||||
def diagonal(g: jit_utils.GraphContext, self, offset, dim1, dim2):
|
||||
rank = symbolic_helper._get_tensor_rank(self)
|
||||
# Replace negative indexing when rank is known
|
||||
@ -844,6 +868,7 @@ def diagonal(g: jit_utils.GraphContext, self, offset, dim1, dim2):
|
||||
|
||||
|
||||
@_onnx_symbolic("quantized::linear")
|
||||
@_beartype.beartype
|
||||
def quantized_linear(
|
||||
g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point
|
||||
):
|
||||
@ -860,6 +885,7 @@ def quantized_linear(
|
||||
|
||||
|
||||
@_onnx_symbolic("quantized::linear_relu")
|
||||
@_beartype.beartype
|
||||
def quantized_linear_relu(
|
||||
g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point
|
||||
):
|
||||
@ -877,6 +903,7 @@ def quantized_linear_relu(
|
||||
|
||||
|
||||
@_onnx_symbolic("quantized::conv1d_relu")
|
||||
@_beartype.beartype
|
||||
def quantized_conv1d_relu(
|
||||
g: jit_utils.GraphContext,
|
||||
q_input,
|
||||
@ -903,6 +930,7 @@ def quantized_conv1d_relu(
|
||||
|
||||
|
||||
@_onnx_symbolic("quantized::conv2d_relu")
|
||||
@_beartype.beartype
|
||||
def quantized_conv2d_relu(
|
||||
g: jit_utils.GraphContext,
|
||||
q_input,
|
||||
@ -929,6 +957,7 @@ def quantized_conv2d_relu(
|
||||
|
||||
|
||||
@_onnx_symbolic("quantized::conv3d_relu")
|
||||
@_beartype.beartype
|
||||
def quantized_conv3d_relu(
|
||||
g: jit_utils.GraphContext,
|
||||
q_input,
|
||||
@ -955,6 +984,7 @@ def quantized_conv3d_relu(
|
||||
|
||||
|
||||
@_onnx_symbolic("quantized::conv1d")
|
||||
@_beartype.beartype
|
||||
def quantized_conv1d(
|
||||
g: jit_utils.GraphContext,
|
||||
q_input,
|
||||
@ -980,6 +1010,7 @@ def quantized_conv1d(
|
||||
|
||||
|
||||
@_onnx_symbolic("quantized::conv2d")
|
||||
@_beartype.beartype
|
||||
def quantized_conv2d(
|
||||
g: jit_utils.GraphContext,
|
||||
q_input,
|
||||
@ -1005,6 +1036,7 @@ def quantized_conv2d(
|
||||
|
||||
|
||||
@_onnx_symbolic("quantized::conv3d")
|
||||
@_beartype.beartype
|
||||
def quantized_conv3d(
|
||||
g: jit_utils.GraphContext,
|
||||
q_input,
|
||||
@ -1030,6 +1062,7 @@ def quantized_conv3d(
|
||||
|
||||
|
||||
@_onnx_symbolic("quantized::conv_transpose1d")
|
||||
@_beartype.beartype
|
||||
def quantized_conv_transpose1d(
|
||||
g: jit_utils.GraphContext,
|
||||
q_input,
|
||||
@ -1058,6 +1091,7 @@ def quantized_conv_transpose1d(
|
||||
|
||||
|
||||
@_onnx_symbolic("quantized::conv_transpose2d")
|
||||
@_beartype.beartype
|
||||
def quantized_conv_transpose2d(
|
||||
g: jit_utils.GraphContext,
|
||||
q_input,
|
||||
@ -1086,6 +1120,7 @@ def quantized_conv_transpose2d(
|
||||
|
||||
|
||||
@_onnx_symbolic("quantized::conv_transpose3d")
|
||||
@_beartype.beartype
|
||||
def quantized_conv_transpose3d(
|
||||
g: jit_utils.GraphContext,
|
||||
q_input,
|
||||
|
@ -1,5 +1,4 @@
|
||||
# mypy: allow-untyped-defs
|
||||
# mypy: disable-error-code=arg-type
|
||||
"""This file exports ONNX ops for opset 14.
|
||||
|
||||
Note [ONNX operators that are added/updated in opset 14]
|
||||
@ -24,7 +23,7 @@ from typing import Optional
|
||||
import torch
|
||||
from torch.onnx import _constants, _type_utils, symbolic_helper
|
||||
from torch.onnx._globals import GLOBALS
|
||||
from torch.onnx._internal import jit_utils, registration
|
||||
from torch.onnx._internal import _beartype, jit_utils, registration
|
||||
|
||||
__all__ = [
|
||||
"hardswish",
|
||||
@ -41,16 +40,19 @@ _onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=14)
|
||||
|
||||
@_onnx_symbolic("aten::hardswish")
|
||||
@symbolic_helper.parse_args("v")
|
||||
@_beartype.beartype
|
||||
def hardswish(g: jit_utils.GraphContext, self):
|
||||
return g.op("HardSwish", self)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::tril")
|
||||
@_beartype.beartype
|
||||
def tril(g: jit_utils.GraphContext, self, diagonal, out=None):
|
||||
return g.op("Trilu", self, diagonal, upper_i=0)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::triu")
|
||||
@_beartype.beartype
|
||||
def triu(g: jit_utils.GraphContext, self, diagonal, out=None):
|
||||
return g.op("Trilu", self, diagonal, upper_i=1)
|
||||
|
||||
@ -58,6 +60,7 @@ def triu(g: jit_utils.GraphContext, self, diagonal, out=None):
|
||||
@_onnx_symbolic("aten::reshape")
|
||||
@symbolic_helper.quantized_args(True)
|
||||
@symbolic_helper.parse_args("v", "v")
|
||||
@_beartype.beartype
|
||||
def reshape(g: jit_utils.GraphContext, self, shape):
|
||||
# NOTE: Due to bug in ORT https://github.com/microsoft/onnxruntime/issues/10664
|
||||
# Reshape export cannot utilize the new allowzero attribute introduced in opset 14.
|
||||
@ -66,6 +69,7 @@ def reshape(g: jit_utils.GraphContext, self, shape):
|
||||
|
||||
@_onnx_symbolic("aten::batch_norm")
|
||||
@symbolic_helper.parse_args("v", "v", "v", "v", "v", "i", "f", "f", "i")
|
||||
@_beartype.beartype
|
||||
def batch_norm(
|
||||
g: jit_utils.GraphContext,
|
||||
input,
|
||||
@ -120,6 +124,7 @@ def batch_norm(
|
||||
|
||||
|
||||
@_onnx_symbolic("quantized::hardswish")
|
||||
@_beartype.beartype
|
||||
def quantized_hardswish(g: jit_utils.GraphContext, x, op_scale, op_zero_point):
|
||||
x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
|
||||
|
||||
@ -134,6 +139,7 @@ def quantized_hardswish(g: jit_utils.GraphContext, x, op_scale, op_zero_point):
|
||||
# NOTE: Need op.Trilu
|
||||
@_onnx_symbolic("aten::scaled_dot_product_attention")
|
||||
@symbolic_helper.parse_args("v", "v", "v", "v", "f", "b", "v")
|
||||
@_beartype.beartype
|
||||
def scaled_dot_product_attention(
|
||||
g: jit_utils.GraphContext,
|
||||
query: torch._C.Value,
|
||||
@ -206,6 +212,7 @@ def scaled_dot_product_attention(
|
||||
return g.op("MatMul", attn_weight, value)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _attention_scale(
|
||||
g: jit_utils.GraphContext, query: torch._C.Value
|
||||
) -> torch._C.Value:
|
||||
@ -242,6 +249,7 @@ def _attention_scale(
|
||||
return scale
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _causal_attention_mask(
|
||||
g: jit_utils.GraphContext, query: torch._C.Value, key: torch._C.Value
|
||||
) -> torch._C.Value:
|
||||
|
@ -31,12 +31,13 @@ import functools
|
||||
import torch
|
||||
from torch import _C
|
||||
from torch.onnx import symbolic_helper, symbolic_opset9 as opset9
|
||||
from torch.onnx._internal import jit_utils, registration
|
||||
from torch.onnx._internal import _beartype, jit_utils, registration
|
||||
|
||||
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=15)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::__is_")
|
||||
@_beartype.beartype
|
||||
def aten__is_(g: jit_utils.GraphContext, self, other):
|
||||
if symbolic_helper._is_none(other):
|
||||
if isinstance(self.type(), _C.OptionalType):
|
||||
@ -49,11 +50,13 @@ def aten__is_(g: jit_utils.GraphContext, self, other):
|
||||
|
||||
@_onnx_symbolic("aten::__isnot_")
|
||||
@opset9.wrap_logical_op_with_negation # type: ignore[has-type]
|
||||
@_beartype.beartype
|
||||
def aten__isnot_(g: jit_utils.GraphContext, self, other):
|
||||
return aten__is_(g, self, other)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::bernoulli")
|
||||
@_beartype.beartype
|
||||
def bernoulli(g: jit_utils.GraphContext, input, p=None, generator=None, out=None):
|
||||
if out is not None and not symbolic_helper._is_none(out):
|
||||
symbolic_helper._unimplemented(
|
||||
@ -69,6 +72,7 @@ def bernoulli(g: jit_utils.GraphContext, input, p=None, generator=None, out=None
|
||||
|
||||
|
||||
@_onnx_symbolic("prim::unchecked_cast")
|
||||
@_beartype.beartype
|
||||
def prim_unchecked_cast(g: jit_utils.GraphContext, self):
|
||||
# exists to refine the type of the Value
|
||||
# if x is Optional[Tensor], unchecked_cast will cast
|
||||
|
@ -34,7 +34,7 @@ from torch.nn.functional import (
|
||||
GRID_SAMPLE_PADDING_MODES,
|
||||
)
|
||||
from torch.onnx import _type_utils, errors, symbolic_helper, utils
|
||||
from torch.onnx._internal import jit_utils, registration
|
||||
from torch.onnx._internal import _beartype, jit_utils, registration
|
||||
|
||||
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=16)
|
||||
|
||||
@ -43,6 +43,7 @@ _onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=16)
|
||||
# Because `torch.nn.functional.grid_sample` calls `torch.grid_sampler`.
|
||||
@_onnx_symbolic("aten::grid_sampler")
|
||||
@symbolic_helper.parse_args("v", "v", "i", "i", "b")
|
||||
@_beartype.beartype
|
||||
def grid_sampler(
|
||||
g: jit_utils.GraphContext,
|
||||
input,
|
||||
@ -68,6 +69,7 @@ def grid_sampler(
|
||||
|
||||
@_onnx_symbolic("aten::scatter_add")
|
||||
@symbolic_helper.parse_args("v", "i", "v", "v")
|
||||
@_beartype.beartype
|
||||
def scatter_add(g: jit_utils.GraphContext, self, dim, index, src):
|
||||
src_type = _type_utils.JitScalarType.from_value(
|
||||
src, _type_utils.JitScalarType.UNDEFINED
|
||||
@ -115,6 +117,7 @@ def scatter_add(g: jit_utils.GraphContext, self, dim, index, src):
|
||||
|
||||
@_onnx_symbolic("aten::scatter_reduce")
|
||||
@symbolic_helper.parse_args("v", "i", "v", "v", "s", "b")
|
||||
@_beartype.beartype
|
||||
def scatter_reduce(
|
||||
g: jit_utils.GraphContext,
|
||||
self: torch._C.Value,
|
||||
|
@ -1,5 +1,4 @@
|
||||
# mypy: allow-untyped-defs
|
||||
# mypy: disable-error-code=arg-type
|
||||
"""This file exports ONNX ops for opset 17.
|
||||
|
||||
Note [ONNX Operators that are added/updated in opset 17]
|
||||
@ -23,7 +22,7 @@ from typing import Optional, Sequence
|
||||
import torch
|
||||
from torch import _C
|
||||
from torch.onnx import _type_utils, errors, symbolic_helper
|
||||
from torch.onnx._internal import jit_utils, registration
|
||||
from torch.onnx._internal import _beartype, jit_utils, registration
|
||||
|
||||
# EDITING THIS FILE? READ THIS FIRST!
|
||||
# see Note [Edit Symbolic Files] in README.md
|
||||
@ -97,6 +96,7 @@ def _compute_edge_sizes(n_fft, window_size):
|
||||
|
||||
@_onnx_symbolic("aten::stft")
|
||||
@symbolic_helper.parse_args("v", "i", "i", "i", "v", "b", "b", "b")
|
||||
@_beartype.beartype
|
||||
def stft(
|
||||
g: jit_utils.GraphContext,
|
||||
input: _C.Value,
|
||||
@ -154,7 +154,7 @@ def stft(
|
||||
signal,
|
||||
g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)),
|
||||
)
|
||||
elif signal_rank is None or signal_rank > 2:
|
||||
elif signal_rank > 2:
|
||||
raise errors.SymbolicValueError(
|
||||
msg="STFT can only take inputs of 1 [signal] or 2 [batch, signal] dimensions. "
|
||||
f"Current rank of signal is {signal_rank}, please reduce it.",
|
||||
|
@ -25,7 +25,7 @@ from typing import List, Optional, Sequence, Tuple
|
||||
import torch
|
||||
from torch import _C
|
||||
from torch.onnx import _type_utils, symbolic_helper, symbolic_opset9 as opset9
|
||||
from torch.onnx._internal import jit_utils, registration
|
||||
from torch.onnx._internal import _beartype, jit_utils, registration
|
||||
|
||||
# EDITING THIS FILE? READ THIS FIRST!
|
||||
# see Note [Edit Symbolic Files] in symbolic_helper.py
|
||||
@ -39,6 +39,7 @@ _onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=18)
|
||||
|
||||
@_onnx_symbolic("aten::__and_")
|
||||
@_onnx_symbolic("aten::bitwise_and")
|
||||
@_beartype.beartype
|
||||
def __and_(g: jit_utils.GraphContext, self, other):
|
||||
# do type promotion (scalars don't seem to apply)
|
||||
args = [self, other]
|
||||
@ -56,6 +57,7 @@ def __and_(g: jit_utils.GraphContext, self, other):
|
||||
|
||||
@_onnx_symbolic("aten::col2im")
|
||||
@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is")
|
||||
@_beartype.beartype
|
||||
def col2im(
|
||||
g,
|
||||
input: _C.Value,
|
||||
@ -103,6 +105,7 @@ def col2im(
|
||||
)
|
||||
],
|
||||
)
|
||||
@_beartype.beartype
|
||||
def _reduce_with_dtype(onnx_op: str, name: str, allow_multi_dim_support: bool = True):
|
||||
return symbolic_helper._reduce_with_dtype_helper(
|
||||
onnx_op, name, allow_multi_dim_support
|
||||
@ -112,6 +115,7 @@ def _reduce_with_dtype(onnx_op: str, name: str, allow_multi_dim_support: bool =
|
||||
@_onnx_symbolic("aten::native_layer_norm")
|
||||
@symbolic_helper.quantized_args(True, False, False, False)
|
||||
@symbolic_helper.parse_args("v", "is", "v", "v", "f")
|
||||
@_beartype.beartype
|
||||
def _native_layer_norm(
|
||||
g: jit_utils.GraphContext,
|
||||
input: _C.Value,
|
||||
@ -125,6 +129,7 @@ def _native_layer_norm(
|
||||
|
||||
@_onnx_symbolic("aten::glu")
|
||||
@symbolic_helper.parse_args("v", "i")
|
||||
@_beartype.beartype
|
||||
def _glu(g: jit_utils.GraphContext, input, dim):
|
||||
dim_size = symbolic_helper._get_tensor_dim_size(input, dim)
|
||||
if dim_size is not None:
|
||||
@ -138,24 +143,28 @@ def _glu(g: jit_utils.GraphContext, input, dim):
|
||||
# torch.max (same for torch.min) actually has two interfaces smashed together:
|
||||
# torch.max(x, dim, keepdim) and torch.max(x, y)
|
||||
# TODO(justinchuby): Support multiple quantized args in output
|
||||
@_beartype.beartype
|
||||
def max(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
|
||||
return symbolic_helper._max_helper(g, self, dim_or_y, keepdim)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::maximum")
|
||||
@symbolic_helper.quantized_args(True, True)
|
||||
@_beartype.beartype
|
||||
def maximum(g: jit_utils.GraphContext, input, other):
|
||||
return max(g, input, dim_or_y=other)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::min")
|
||||
# TODO(justinchuby): Support multiple quantized args in output
|
||||
@_beartype.beartype
|
||||
def min(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
|
||||
return symbolic_helper._min_helper(g, self, dim_or_y, keepdim)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::minimum")
|
||||
@symbolic_helper.quantized_args(True, True)
|
||||
@_beartype.beartype
|
||||
def minimum(g: jit_utils.GraphContext, input, other):
|
||||
return min(g, input, dim_or_y=other)
|
||||
|
||||
@ -163,6 +172,7 @@ def minimum(g: jit_utils.GraphContext, input, other):
|
||||
@_onnx_symbolic("aten::amax")
|
||||
@symbolic_helper.quantized_args(True)
|
||||
@symbolic_helper.parse_args("v", "is", "i")
|
||||
@_beartype.beartype
|
||||
def amax(g: jit_utils.GraphContext, self, dim, keepdim):
|
||||
axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long))
|
||||
return g.op("ReduceMax", self, axes, keepdims_i=keepdim)
|
||||
@ -171,6 +181,7 @@ def amax(g: jit_utils.GraphContext, self, dim, keepdim):
|
||||
@_onnx_symbolic("aten::amin")
|
||||
@symbolic_helper.quantized_args(True)
|
||||
@symbolic_helper.parse_args("v", "is", "i")
|
||||
@_beartype.beartype
|
||||
def amin(g: jit_utils.GraphContext, self, dim, keepdim):
|
||||
axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long))
|
||||
return g.op("ReduceMin", self, axes, keepdims_i=keepdim)
|
||||
@ -179,6 +190,7 @@ def amin(g: jit_utils.GraphContext, self, dim, keepdim):
|
||||
@_onnx_symbolic("aten::aminmax")
|
||||
@symbolic_helper.quantized_args(True)
|
||||
@symbolic_helper.parse_args("v", "v", "i")
|
||||
@_beartype.beartype
|
||||
def aminmax(g: jit_utils.GraphContext, self, dim, keepdim):
|
||||
if not symbolic_helper._is_none(dim):
|
||||
dim = symbolic_helper._get_const(dim, "i", "dim")
|
||||
@ -193,6 +205,7 @@ def aminmax(g: jit_utils.GraphContext, self, dim, keepdim):
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::var_mean")
|
||||
@_beartype.beartype
|
||||
def _var_mean(g: jit_utils.GraphContext, input, *args):
|
||||
if len(args) == 1:
|
||||
return symbolic_helper._var_mean_helper(g, input, None, args[0], None)
|
||||
@ -202,6 +215,7 @@ def _var_mean(g: jit_utils.GraphContext, input, *args):
|
||||
|
||||
@_onnx_symbolic("aten::logsumexp")
|
||||
@symbolic_helper.parse_args("v", "is", "i")
|
||||
@_beartype.beartype
|
||||
def _logsumexp(g: jit_utils.GraphContext, input, dim, keepdim):
|
||||
if dim is None:
|
||||
return g.op("ReduceLogSumExp", input, keepdims_i=0)
|
||||
@ -212,6 +226,7 @@ def _logsumexp(g: jit_utils.GraphContext, input, dim, keepdim):
|
||||
|
||||
@_onnx_symbolic("aten::linalg_matrix_norm")
|
||||
@symbolic_helper.parse_args("v", "v", "is", "b", "v")
|
||||
@_beartype.beartype
|
||||
def _linalg_matrix_norm(
|
||||
g: jit_utils.GraphContext,
|
||||
self: torch._C.Value,
|
||||
@ -225,6 +240,7 @@ def _linalg_matrix_norm(
|
||||
|
||||
@_onnx_symbolic("aten::embedding_bag")
|
||||
@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i")
|
||||
@_beartype.beartype
|
||||
def embedding_bag(
|
||||
g: jit_utils.GraphContext,
|
||||
embedding_matrix,
|
||||
@ -253,6 +269,7 @@ def embedding_bag(
|
||||
|
||||
@_onnx_symbolic("aten::linalg_vector_norm")
|
||||
@symbolic_helper.parse_args("v", "f", "is", "b", "v")
|
||||
@_beartype.beartype
|
||||
def linalg_vector_norm(
|
||||
g: jit_utils.GraphContext,
|
||||
self: torch._C.Value,
|
||||
|
@ -27,7 +27,7 @@ import torch.nn.functional as F
|
||||
|
||||
from torch import _C
|
||||
from torch.onnx import symbolic_helper
|
||||
from torch.onnx._internal import jit_utils, registration
|
||||
from torch.onnx._internal import _beartype, jit_utils, registration
|
||||
|
||||
# EDITING THIS FILE? READ THIS FIRST!
|
||||
# see Note [Edit Symbolic Files] in symbolic_helper.py
|
||||
@ -46,6 +46,7 @@ _onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=20)
|
||||
|
||||
@_onnx_symbolic("aten::grid_sampler")
|
||||
@symbolic_helper.parse_args("v", "v", "i", "i", "b")
|
||||
@_beartype.beartype
|
||||
def _grid_sampler(
|
||||
g: jit_utils.GraphContext,
|
||||
input: _C.Value,
|
||||
@ -70,6 +71,7 @@ def _grid_sampler(
|
||||
|
||||
@_onnx_symbolic("aten::affine_grid_generator")
|
||||
@symbolic_helper.parse_args("v", "v", "b")
|
||||
@_beartype.beartype
|
||||
def _affine_grid_generator(
|
||||
g: jit_utils.GraphContext,
|
||||
theta: _C.Value,
|
||||
@ -86,5 +88,6 @@ def _affine_grid_generator(
|
||||
|
||||
@_onnx_symbolic("aten::gelu")
|
||||
@symbolic_helper.parse_args("v", "s")
|
||||
@_beartype.beartype
|
||||
def gelu(g: jit_utils.GraphContext, self: _C.Value, approximate: str = "none"):
|
||||
return g.op("Gelu", self, approximate_s=approximate)
|
||||
|
@ -165,7 +165,7 @@ def _try_cast_integer_to_float(g: jit_utils.GraphContext, *args):
|
||||
if arg0_type != _type_utils.JitScalarType.UNDEFINED:
|
||||
old_type = arg0_type
|
||||
if old_type not in floating_scalar_types:
|
||||
old_type = old_type.scalar_name() # type: ignore[assignment]
|
||||
old_type = old_type.scalar_name()
|
||||
args = tuple(
|
||||
g.op("Cast", arg, to_i=_C_onnx.TensorProtoDataType.FLOAT)
|
||||
for arg in args
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -9,6 +9,7 @@ from __future__ import annotations
|
||||
import contextlib
|
||||
import copy
|
||||
import inspect
|
||||
import io
|
||||
import re
|
||||
import typing
|
||||
import warnings
|
||||
@ -37,13 +38,17 @@ from torch.onnx import ( # noqa: F401
|
||||
_constants,
|
||||
_exporter_states,
|
||||
errors,
|
||||
symbolic_caffe2,
|
||||
symbolic_helper,
|
||||
)
|
||||
from torch.onnx._globals import GLOBALS
|
||||
from torch.onnx._internal import diagnostics, jit_utils, onnx_proto_utils, registration
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
import io
|
||||
from torch.onnx._internal import (
|
||||
_beartype,
|
||||
diagnostics,
|
||||
jit_utils,
|
||||
onnx_proto_utils,
|
||||
registration,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"is_in_onnx_export",
|
||||
@ -73,6 +78,7 @@ _params_dict = {} # type: ignore[var-annotated]
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
@_beartype.beartype
|
||||
def select_model_mode_for_export(model, mode: _C_onnx.TrainingMode):
|
||||
r"""A context manager to temporarily set the training mode of ``model``
|
||||
to ``mode``, resetting it when we exit the with-block.
|
||||
@ -121,6 +127,7 @@ def select_model_mode_for_export(model, mode: _C_onnx.TrainingMode):
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
@_beartype.beartype
|
||||
def disable_apex_o2_state_dict_hook(
|
||||
model: Union[torch.nn.Module, torch.jit.ScriptFunction]
|
||||
):
|
||||
@ -154,6 +161,7 @@ def disable_apex_o2_state_dict_hook(
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
@_beartype.beartype
|
||||
def setup_onnx_logging(verbose: bool):
|
||||
is_originally_enabled = torch.onnx.is_onnx_log_enabled()
|
||||
if is_originally_enabled or verbose:
|
||||
@ -166,6 +174,7 @@ def setup_onnx_logging(verbose: bool):
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
@_beartype.beartype
|
||||
def exporter_context(model, mode: _C_onnx.TrainingMode, verbose: bool):
|
||||
with select_model_mode_for_export(
|
||||
model, mode
|
||||
@ -559,6 +568,7 @@ def export(
|
||||
return None
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _is_constant_tensor_list(node):
|
||||
if node.kind() != "prim::Constant":
|
||||
return False
|
||||
@ -573,6 +583,7 @@ def _is_constant_tensor_list(node):
|
||||
# get generated in constant prop. So we split them back into prim::ListConstructs
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _split_tensor_list_constants(g, block):
|
||||
for node in block.nodes():
|
||||
for subblock in node.blocks():
|
||||
@ -595,6 +606,7 @@ def _split_tensor_list_constants(g, block):
|
||||
node.output().replaceAllUsesWith(lc)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _optimize_graph(
|
||||
graph: _C.Graph,
|
||||
operator_export_type: _C_onnx.OperatorExportTypes,
|
||||
@ -705,6 +717,7 @@ def _optimize_graph(
|
||||
return graph
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def warn_on_static_input_change(input_states):
|
||||
"""Warns that changes to input dictionaries and strings won't take effect in the traced ONNX graph.
|
||||
|
||||
@ -734,11 +747,13 @@ def warn_on_static_input_change(input_states):
|
||||
warnings.warn(warning)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _resolve_args_by_export_type(arg_name, arg_value, operator_export_type):
|
||||
"""Resolves the arguments that are ignored when export_type != operator_export_type.ONNX."""
|
||||
return arg_value
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _decide_keep_init_as_input(
|
||||
keep_initializers_as_inputs: Optional[bool],
|
||||
operator_export_type: _C_onnx.OperatorExportTypes,
|
||||
@ -782,12 +797,14 @@ def _decide_keep_init_as_input(
|
||||
return val_keep_init_as_ip
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _decide_add_node_names(add_node_names, operator_export_type):
|
||||
return _resolve_args_by_export_type(
|
||||
"add_node_names", add_node_names, operator_export_type
|
||||
)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _decide_constant_folding(do_constant_folding, operator_export_type, training):
|
||||
do_constant_folding = _resolve_args_by_export_type(
|
||||
"do_constant_folding", do_constant_folding, operator_export_type
|
||||
@ -806,6 +823,7 @@ def _decide_constant_folding(do_constant_folding, operator_export_type, training
|
||||
return do_constant_folding
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _signature(model) -> inspect.Signature:
|
||||
should_be_callable = getattr(model, "forward", model)
|
||||
if callable(should_be_callable):
|
||||
@ -813,6 +831,7 @@ def _signature(model) -> inspect.Signature:
|
||||
raise ValueError("model has no forward method and is not callable")
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _decide_input_format(model, args):
|
||||
try:
|
||||
sig = _signature(model)
|
||||
@ -851,6 +870,7 @@ def _decide_input_format(model, args):
|
||||
return args
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _from_dynamic_axes_to_dynamic_shapes(
|
||||
model,
|
||||
dynamic_axes: Optional[
|
||||
@ -909,6 +929,7 @@ def _from_dynamic_axes_to_dynamic_shapes(
|
||||
return dynamic_shapes
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _trace(func, args, operator_export_type, return_outs=False):
|
||||
# Special case for common case of passing a single Tensor
|
||||
if isinstance(args, torch.Tensor):
|
||||
@ -929,6 +950,7 @@ def _trace(func, args, operator_export_type, return_outs=False):
|
||||
return trace_graph
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _trace_and_get_graph_from_model(model, args):
|
||||
# A basic sanity check: make sure the state_dict keys are the same
|
||||
# before and after running the model. Fail fast!
|
||||
@ -959,6 +981,7 @@ def _trace_and_get_graph_from_model(model, args):
|
||||
return trace_graph, torch_out
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _get_param_count_list(method_graph, args_params):
|
||||
param_count_list = []
|
||||
for input_, arg_params_ in zip(method_graph.inputs(), args_params):
|
||||
@ -971,9 +994,11 @@ def _get_param_count_list(method_graph, args_params):
|
||||
return param_count_list
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _check_flatten_did_not_remove(original, jit_flattened):
|
||||
"""torch.jit._flatten removes None. Check if it did so in this case."""
|
||||
|
||||
@_beartype.beartype
|
||||
def flatten(x):
|
||||
if isinstance(x, (list, tuple)):
|
||||
for inner in x:
|
||||
@ -1046,6 +1071,7 @@ def _create_jit_graph(
|
||||
return graph, params, torch_out, None
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _get_named_param_dict(graph, params):
|
||||
input_and_param_names = [val.debugName() for val in graph.inputs()]
|
||||
param_names = input_and_param_names[len(input_and_param_names) - len(params) :]
|
||||
@ -1053,6 +1079,7 @@ def _get_named_param_dict(graph, params):
|
||||
return _params_dict
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _get_example_outputs(model, args):
|
||||
input_args = copy.deepcopy(args)
|
||||
input_kwargs = {}
|
||||
@ -1077,6 +1104,7 @@ _qtype_vtype_map = {
|
||||
}
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def unpack_quantized_tensor(value, cast_onnx_accepted=True):
|
||||
if isinstance(value, torch.Tensor) and value.dtype in _qtype_vtype_map:
|
||||
q_value_dequantize = value.dequantize()
|
||||
@ -1097,6 +1125,7 @@ def unpack_quantized_tensor(value, cast_onnx_accepted=True):
|
||||
return (value,)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _pre_trace_quant_model(model, args):
|
||||
r"""Returns `torch.jit.trace(model, args)` if model is quantized. Otherwise do nothing and return
|
||||
original model.
|
||||
@ -1110,6 +1139,7 @@ def _pre_trace_quant_model(model, args):
|
||||
return model
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _model_to_graph(
|
||||
model,
|
||||
args,
|
||||
@ -1245,6 +1275,7 @@ def _model_to_graph(
|
||||
return graph, params_dict, torch_out
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
@torch._disable_dynamo
|
||||
def export_to_pretty_string(
|
||||
model,
|
||||
@ -1322,6 +1353,7 @@ def export_to_pretty_string(
|
||||
)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def unconvertible_ops(
|
||||
model,
|
||||
args,
|
||||
@ -1390,6 +1422,7 @@ def unconvertible_ops(
|
||||
return graph, unsupported_ops
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _setup_trace_module_map(
|
||||
model: Union[torch.nn.Module, torch.jit.ScriptModule],
|
||||
export_modules_as_functions: Union[bool, Collection[Type[torch.nn.Module]]],
|
||||
@ -1467,11 +1500,13 @@ def _setup_trace_module_map(
|
||||
return module_typenames
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _reset_trace_module_map():
|
||||
torch.jit._trace._trace_module_map = None
|
||||
_C._jit_pass_onnx_clear_scope_records()
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _get_module_attributes(module):
|
||||
annotations = typing.get_type_hints(type(module))
|
||||
base_m_annotations = typing.get_type_hints(torch.nn.Module)
|
||||
@ -1496,6 +1531,7 @@ def _get_module_attributes(module):
|
||||
return attrs
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _export(
|
||||
model,
|
||||
args,
|
||||
@ -1704,6 +1740,7 @@ def _export(
|
||||
return torch_out
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _apply_friendly_debug_names(graph, params):
|
||||
for n in graph.nodes():
|
||||
for v in n.inputs():
|
||||
@ -1716,7 +1753,9 @@ def _apply_friendly_debug_names(graph, params):
|
||||
params[new_name] = params.pop(old_name)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _set_input_and_output_names(graph, input_names, output_names):
|
||||
@_beartype.beartype
|
||||
def set_names(node_list, name_list, descriptor):
|
||||
if name_list is None:
|
||||
return
|
||||
@ -1747,6 +1786,7 @@ def _set_input_and_output_names(graph, input_names, output_names):
|
||||
set_names(list(graph.outputs()), output_names, "output")
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _run_symbolic_method(g, op_name, symbolic_fn, args):
|
||||
r"""
|
||||
This trampoline function gets invoked for every symbolic method
|
||||
@ -1772,18 +1812,22 @@ def _run_symbolic_method(g, op_name, symbolic_fn, args):
|
||||
raise
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _add_block(node: _C.Node) -> _C.Block:
|
||||
return node.addBlock()
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _add_input_to_block(block: _C.Block):
|
||||
return block.addInputToBlock() # type: ignore[attr-defined]
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _add_output_to_block(block: _C.Block, value: _C.Value) -> int:
|
||||
return block.registerOutput(value)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _should_aten_fallback(
|
||||
name: str, opset_version: int, operator_export_type: _C_onnx.OperatorExportTypes
|
||||
):
|
||||
@ -1805,6 +1849,7 @@ def _should_aten_fallback(
|
||||
return False
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _need_symbolic_context(symbolic_fn: Callable) -> bool:
|
||||
"""Checks if the first argument to symbolic_fn is annotated as type `torch.onnx.SymbolicContext`."""
|
||||
params = tuple(inspect.signature(symbolic_fn).parameters.values())
|
||||
@ -1820,6 +1865,7 @@ def _need_symbolic_context(symbolic_fn: Callable) -> bool:
|
||||
return issubclass(param_type, _exporter_states.SymbolicContext)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _symbolic_context_handler(symbolic_fn: Callable) -> Callable:
|
||||
"""Decorator that provides the symbolic context to the symbolic function if needed."""
|
||||
if _need_symbolic_context(symbolic_fn):
|
||||
@ -1844,6 +1890,7 @@ def _symbolic_context_handler(symbolic_fn: Callable) -> Callable:
|
||||
return symbolic_fn
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _get_aten_op_overload_name(n: _C.Node) -> str:
|
||||
# Returns `overload_name` attribute to ATen ops on non-Caffe2 builds
|
||||
schema = n.schema()
|
||||
@ -1852,6 +1899,7 @@ def _get_aten_op_overload_name(n: _C.Node) -> str:
|
||||
return _C.parse_schema(schema).overload_name
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _run_symbolic_function(
|
||||
graph: _C.Graph,
|
||||
block: _C.Block,
|
||||
@ -1964,6 +2012,7 @@ def _run_symbolic_function(
|
||||
raise
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _verify_custom_op_name(symbolic_name: str):
|
||||
if not re.match(r"^[a-zA-Z0-9-_]+::[a-zA-Z-_]+[a-zA-Z0-9-_]*$", symbolic_name):
|
||||
raise errors.OnnxExporterError(
|
||||
@ -1980,6 +2029,7 @@ def _verify_custom_op_name(symbolic_name: str):
|
||||
)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def register_custom_op_symbolic(
|
||||
symbolic_name: str,
|
||||
symbolic_fn: Callable,
|
||||
@ -2016,6 +2066,7 @@ def register_custom_op_symbolic(
|
||||
)(symbolic_fn)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def unregister_custom_op_symbolic(symbolic_name: str, opset_version: int):
|
||||
"""Unregisters ``symbolic_name``.
|
||||
|
||||
@ -2034,6 +2085,7 @@ def unregister_custom_op_symbolic(symbolic_name: str, opset_version: int):
|
||||
registration.registry.unregister(symbolic_name, opset_version)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _validate_dynamic_axes(dynamic_axes, model, input_names, output_names):
|
||||
"""Ensures dynamic axes argument is follows the expected format."""
|
||||
if len(dynamic_axes) == 0:
|
||||
|
@ -40,7 +40,7 @@ import torch._C._onnx as _C_onnx
|
||||
from torch import _C
|
||||
from torch.onnx import _constants, _experimental, _exporter_states, utils
|
||||
from torch.onnx._globals import GLOBALS
|
||||
from torch.onnx._internal import onnx_proto_utils
|
||||
from torch.onnx._internal import _beartype, onnx_proto_utils
|
||||
from torch.types import Number
|
||||
|
||||
_ORT_PROVIDERS = ("CPUExecutionProvider",)
|
||||
@ -99,6 +99,7 @@ class VerificationOptions:
|
||||
acceptable_error_percentage: Optional[float] = None
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _flatten_tuples(elem):
|
||||
flattened = []
|
||||
for t in elem:
|
||||
@ -128,6 +129,7 @@ def _to_numpy(elem) -> Union[list, np.ndarray]:
|
||||
return elem
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _inline_flatten_list(inputs, res_list) -> list:
|
||||
for i in inputs:
|
||||
res_list.append(i) if not isinstance(
|
||||
@ -136,6 +138,7 @@ def _inline_flatten_list(inputs, res_list) -> list:
|
||||
return res_list
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _unpack_to_numpy(values, cast_onnx_accepted=True) -> list:
|
||||
value_unpacked = []
|
||||
for value in values:
|
||||
@ -145,6 +148,7 @@ def _unpack_to_numpy(values, cast_onnx_accepted=True) -> list:
|
||||
return [_to_numpy(v) for v in value_unpacked]
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _run_onnx(onnx_session, inputs) -> _OutputsType:
|
||||
kw_inputs = {}
|
||||
if inputs and isinstance(inputs[-1], dict):
|
||||
@ -175,6 +179,7 @@ def _run_onnx(onnx_session, inputs) -> _OutputsType:
|
||||
return onnx_outs
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _ort_session(
|
||||
model: Union[str, io.BytesIO], ort_providers: Sequence[str] = _ORT_PROVIDERS
|
||||
):
|
||||
@ -198,6 +203,7 @@ def _ort_session(
|
||||
return ort_session
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _onnx_reference_evaluator_session(model: Union[str, io.BytesIO]):
|
||||
try:
|
||||
import onnx
|
||||
@ -214,6 +220,7 @@ def _onnx_reference_evaluator_session(model: Union[str, io.BytesIO]):
|
||||
return onnx_session
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _onnx_backend_session(model: Union[str, io.BytesIO], backend: OnnxBackend):
|
||||
if backend == OnnxBackend.REFERENCE:
|
||||
onnx_session = _onnx_reference_evaluator_session(model)
|
||||
@ -224,6 +231,7 @@ def _onnx_backend_session(model: Union[str, io.BytesIO], backend: OnnxBackend):
|
||||
return onnx_session
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _compare_onnx_pytorch_outputs_in_np(
|
||||
onnx_outs: _OutputsType,
|
||||
pt_outs: _OutputsType,
|
||||
@ -273,6 +281,7 @@ def _compare_onnx_pytorch_outputs_in_np(
|
||||
raise
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _compare_onnx_pytorch_outputs(
|
||||
onnx_outs: _OutputsType,
|
||||
pt_outs: Any,
|
||||
@ -301,6 +310,7 @@ def _compare_onnx_pytorch_outputs(
|
||||
_compare_onnx_pytorch_outputs_in_np(onnx_outs, pt_outs_np, options)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _prepare_input_for_pytorch(args, kwargs):
|
||||
"""Prepare input for PyTorch model execution.
|
||||
|
||||
@ -327,6 +337,7 @@ def _prepare_input_for_pytorch(args, kwargs):
|
||||
return args, kwargs
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _prepare_input_for_export(args, kwargs):
|
||||
"""Prepare input for ONNX model export.
|
||||
|
||||
@ -351,6 +362,7 @@ def _prepare_input_for_export(args, kwargs):
|
||||
return onnx_inputs
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _prepare_input_for_onnx(
|
||||
args, kwargs, remained_onnx_input_idx: Optional[Sequence[int]], flatten: bool
|
||||
):
|
||||
@ -380,6 +392,7 @@ def _prepare_input_for_onnx(
|
||||
return onnx_inputs
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _try_clone_model(model):
|
||||
"""Used for preserving original model in case forward mutates model states."""
|
||||
try:
|
||||
@ -391,6 +404,7 @@ def _try_clone_model(model):
|
||||
return model
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _compare_onnx_pytorch_model(
|
||||
pt_model: _ModelType,
|
||||
onnx_model_f: Union[str, io.BytesIO],
|
||||
@ -416,6 +430,7 @@ def _compare_onnx_pytorch_model(
|
||||
"""
|
||||
onnx_session = _onnx_backend_session(onnx_model_f, options.backend)
|
||||
|
||||
@_beartype.beartype
|
||||
def compare_onnx_pytorch_model_with_input(input_args, input_kwargs):
|
||||
pt_args, pt_kwargs = _prepare_input_for_pytorch(input_args, input_kwargs)
|
||||
# TODO: remove this and treat mutating model separately. See #77679
|
||||
@ -444,6 +459,7 @@ def _compare_onnx_pytorch_model(
|
||||
class _GraphDiff:
|
||||
"""A class to represent the difference between two graphs."""
|
||||
|
||||
@_beartype.beartype
|
||||
def __init__(self, graph_a: _C.Graph, graph_b: _C.Graph):
|
||||
"""Construct a _GraphDiff object.
|
||||
|
||||
@ -454,13 +470,16 @@ class _GraphDiff:
|
||||
self.graph_a = graph_a
|
||||
self.graph_b = graph_b
|
||||
|
||||
@_beartype.beartype
|
||||
def __str__(self):
|
||||
"""See function :func:`diff_report`."""
|
||||
return self.diff_report()
|
||||
|
||||
@_beartype.beartype
|
||||
def _indent(self, lines: str) -> str:
|
||||
return "\n".join(["\t" + line for line in lines.splitlines()])
|
||||
|
||||
@_beartype.beartype
|
||||
def diff_report(self) -> str:
|
||||
"""Return a string representation of the graph difference.
|
||||
|
||||
@ -510,6 +529,7 @@ class _GraphDiff:
|
||||
return "\n".join(graph_diff_report)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _check_graph_diff(
|
||||
model: Union[torch.nn.Module, torch.jit.ScriptModule],
|
||||
test_input_groups: Sequence[Tuple[Tuple[Any, ...], Mapping[str, Any]]],
|
||||
@ -551,6 +571,7 @@ def _check_graph_diff(
|
||||
return ""
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _traced_graph_from_model(
|
||||
model: Union[torch.nn.Module, torch.jit.ScriptModule],
|
||||
args: Tuple[Any, ...],
|
||||
@ -578,6 +599,7 @@ def _traced_graph_from_model(
|
||||
return jit_graph
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _onnx_graph_from_model(
|
||||
model: Union[torch.nn.Module, torch.jit.ScriptModule],
|
||||
args: Tuple[Any, ...],
|
||||
@ -642,6 +664,7 @@ def _onnx_graph_from_model(
|
||||
return onnx_graph
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _onnx_graph_from_aten_graph(
|
||||
graph: torch.Graph,
|
||||
export_options: _experimental.ExportOptions,
|
||||
@ -705,6 +728,7 @@ def _onnx_graph_from_aten_graph(
|
||||
return graph, params_dict
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _onnx_proto_from_onnx_graph(
|
||||
onnx_graph: torch.Graph,
|
||||
export_options: _experimental.ExportOptions,
|
||||
@ -738,6 +762,7 @@ def _onnx_proto_from_onnx_graph(
|
||||
return proto, export_map
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def check_export_model_diff(
|
||||
model: Union[torch.nn.Module, torch.jit.ScriptModule],
|
||||
test_input_groups: Sequence[Tuple[Tuple[Any, ...], Mapping[str, Any]]],
|
||||
@ -781,6 +806,7 @@ def check_export_model_diff(
|
||||
)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def verify(
|
||||
model: _ModelType,
|
||||
input_args: _InputArgsType,
|
||||
@ -869,6 +895,7 @@ def verify(
|
||||
)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def verify_aten_graph(
|
||||
graph: torch.Graph,
|
||||
input_args: Tuple[Any, ...],
|
||||
@ -969,6 +996,7 @@ class GraphInfoPrettyPrinter:
|
||||
self.upper_printer = None
|
||||
self.lower_printer = None
|
||||
|
||||
@_beartype.beartype
|
||||
def _total_rows(self) -> int:
|
||||
if self.graph_info is None:
|
||||
return 1
|
||||
@ -978,6 +1006,7 @@ class GraphInfoPrettyPrinter:
|
||||
)
|
||||
return 2 # Two lines: node count + id.
|
||||
|
||||
@_beartype.beartype
|
||||
def _node_count_segment_str(self) -> str:
|
||||
if self.graph_info is None:
|
||||
return "..."
|
||||
@ -991,16 +1020,19 @@ class GraphInfoPrettyPrinter:
|
||||
|
||||
return f"{node_count} {'X' if has_mismatch else chr(0x2713)} {error_node_kind}"
|
||||
|
||||
@_beartype.beartype
|
||||
def _graph_id_segment_str(self) -> str:
|
||||
if self.graph_info is None:
|
||||
return ""
|
||||
return f"id: {self.graph_info.id}"
|
||||
|
||||
@_beartype.beartype
|
||||
def _max_segment_columns(self) -> int:
|
||||
return max(
|
||||
map(len, (self._node_count_segment_str(), self._graph_id_segment_str()))
|
||||
)
|
||||
|
||||
@_beartype.beartype
|
||||
def _graph_segment_str_at_line(self, line: int) -> str:
|
||||
"""Get the string representation of the graph segment at the given line."""
|
||||
if line == 0:
|
||||
@ -1015,6 +1047,7 @@ class GraphInfoPrettyPrinter:
|
||||
return " " * self._max_segment_columns()
|
||||
return ""
|
||||
|
||||
@_beartype.beartype
|
||||
def _connector_segment_str_at_line(self, line: int) -> str:
|
||||
"""Get the connector segment string at the given line."""
|
||||
if self.upper_printer is None and self.lower_printer is None:
|
||||
@ -1031,6 +1064,7 @@ class GraphInfoPrettyPrinter:
|
||||
return " "
|
||||
return ""
|
||||
|
||||
@_beartype.beartype
|
||||
def _children_str_at_line(self, line: int) -> str:
|
||||
"""Get the string representation of the children at the given line.
|
||||
|
||||
@ -1052,6 +1086,7 @@ class GraphInfoPrettyPrinter:
|
||||
)
|
||||
return ""
|
||||
|
||||
@_beartype.beartype
|
||||
def _str_at_line(self, line: int) -> str:
|
||||
"""Get the string representation of the graph at the given line."""
|
||||
return (
|
||||
@ -1101,6 +1136,7 @@ class OnnxTestCaseRepro:
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@_beartype.beartype
|
||||
def create_test_case_repro(
|
||||
cls, proto: bytes, inputs, outputs, dir: str, name: Optional[str] = None
|
||||
):
|
||||
@ -1138,6 +1174,7 @@ class OnnxTestCaseRepro:
|
||||
dir,
|
||||
)
|
||||
|
||||
@_beartype.beartype
|
||||
def validate(self, options: VerificationOptions):
|
||||
"""Run the ONNX test case with options.backend, and compare with the expected outputs.
|
||||
|
||||
@ -1249,16 +1286,19 @@ class GraphInfo:
|
||||
else:
|
||||
print(" No mismatch ".center(80, "="))
|
||||
|
||||
@_beartype.beartype
|
||||
def has_mismatch(self) -> bool:
|
||||
"""Return True if the subgraph has output mismatch between torch and ONNX."""
|
||||
return self.mismatch_error is not None
|
||||
|
||||
@_beartype.beartype
|
||||
def essential_node_count(self) -> int:
|
||||
"""Return the number of nodes in the subgraph excluding those in `_EXCLUDED_NODE_KINDS`."""
|
||||
return sum(
|
||||
1 for n in self.graph.nodes() if n.kind() not in self._EXCLUDED_NODE_KINDS
|
||||
)
|
||||
|
||||
@_beartype.beartype
|
||||
def essential_node_kinds(self) -> Set[str]:
|
||||
"""Return the set of node kinds in the subgraph excluding those in `_EXCLUDED_NODE_KINDS`."""
|
||||
return {
|
||||
@ -1267,7 +1307,8 @@ class GraphInfo:
|
||||
if n.kind() not in self._EXCLUDED_NODE_KINDS
|
||||
}
|
||||
|
||||
def all_mismatch_leaf_graph_info(self) -> List[GraphInfo]:
|
||||
@_beartype.beartype
|
||||
def all_mismatch_leaf_graph_info(self) -> List["GraphInfo"]:
|
||||
"""Return a list of all leaf `GraphInfo` objects that have mismatch."""
|
||||
if not self.has_mismatch():
|
||||
return []
|
||||
@ -1289,7 +1330,8 @@ class GraphInfo:
|
||||
|
||||
return results
|
||||
|
||||
def find_partition(self, id: str) -> Optional[GraphInfo]:
|
||||
@_beartype.beartype
|
||||
def find_partition(self, id: str) -> Optional["GraphInfo"]:
|
||||
"""Find the `GraphInfo` object with the given id."""
|
||||
if id == self.id:
|
||||
return self
|
||||
@ -1301,6 +1343,7 @@ class GraphInfo:
|
||||
return self.lower_graph_info.find_partition(id)
|
||||
return None
|
||||
|
||||
@_beartype.beartype
|
||||
def export_repro(
|
||||
self, repro_dir: Optional[str] = None, name: Optional[str] = None
|
||||
) -> str:
|
||||
@ -1341,6 +1384,7 @@ class GraphInfo:
|
||||
proto, self.input_args, self.pt_outs, repro_dir, name
|
||||
)
|
||||
|
||||
@_beartype.beartype
|
||||
def _graph_partition_pivot(self) -> int:
|
||||
"""Find the pivot index to partition the graph.
|
||||
|
||||
@ -1363,6 +1407,7 @@ class GraphInfo:
|
||||
return included_node_indices[half_idx] + 1
|
||||
return -1
|
||||
|
||||
@_beartype.beartype
|
||||
def _partition_upper_graph(self) -> torch.Graph:
|
||||
pivot = self._graph_partition_pivot()
|
||||
if pivot == -1:
|
||||
@ -1406,6 +1451,7 @@ class GraphInfo:
|
||||
|
||||
return graph
|
||||
|
||||
@_beartype.beartype
|
||||
def _partition_lower_graph(self) -> torch.Graph:
|
||||
pivot = self._graph_partition_pivot()
|
||||
if pivot == -1:
|
||||
@ -1460,6 +1506,7 @@ class GraphInfo:
|
||||
|
||||
return graph
|
||||
|
||||
@_beartype.beartype
|
||||
def _partition_node(
|
||||
self,
|
||||
node: torch.Node,
|
||||
@ -1498,6 +1545,7 @@ class GraphInfo:
|
||||
):
|
||||
covered_bridge_values.add(process_bridge_value(output))
|
||||
|
||||
@_beartype.beartype
|
||||
def _partition_nodes(
|
||||
self,
|
||||
graph: torch.Graph,
|
||||
@ -1537,6 +1585,7 @@ class GraphInfo:
|
||||
complete_lower_nodes_set,
|
||||
)
|
||||
|
||||
@_beartype.beartype
|
||||
def _bridge_kwargs(self):
|
||||
pt_outs = self.pt_outs
|
||||
graph_outputs = list(self.graph.outputs())
|
||||
@ -1546,6 +1595,7 @@ class GraphInfo:
|
||||
), f"{len(graph_outputs)} vs {len(pt_outs)}\nGraph: {self.graph}"
|
||||
return {v.debugName(): o for v, o in zip(graph_outputs, pt_outs)}
|
||||
|
||||
@_beartype.beartype
|
||||
def _args_and_params_for_partition_graph(
|
||||
self,
|
||||
graph: torch.Graph,
|
||||
@ -1562,6 +1612,7 @@ class GraphInfo:
|
||||
), f"{len(args)} + {len(params)} vs {len(input_names)}: {input_names}"
|
||||
return args, params
|
||||
|
||||
@_beartype.beartype
|
||||
def verify_export(
|
||||
self, options: VerificationOptions
|
||||
) -> Tuple[Optional[AssertionError], torch.Graph, _OutputsType, _OutputsType]:
|
||||
@ -1591,6 +1642,7 @@ class GraphInfo:
|
||||
verification_options=options,
|
||||
)
|
||||
|
||||
@_beartype.beartype
|
||||
def find_mismatch(
|
||||
self,
|
||||
options: Optional[VerificationOptions] = None,
|
||||
@ -1670,6 +1722,7 @@ class GraphInfo:
|
||||
self.lower_graph_info.find_mismatch(options)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _all_nodes(nodes: Collection[torch.Node]) -> Set[torch.Node]:
|
||||
all_nodes = set(nodes)
|
||||
for n in nodes:
|
||||
@ -1678,10 +1731,12 @@ def _all_nodes(nodes: Collection[torch.Node]) -> Set[torch.Node]:
|
||||
return all_nodes
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _has_uses_by_nodes(value: torch.Value, nodes: Collection[torch.Node]) -> bool:
|
||||
return any(use.user in nodes for use in value.uses())
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _node_has_uses_by(node: torch.Node, nodes: Collection[torch.Node]) -> bool:
|
||||
for output in node.outputs():
|
||||
if _has_uses_by_nodes(output, nodes):
|
||||
@ -1689,10 +1744,12 @@ def _node_has_uses_by(node: torch.Node, nodes: Collection[torch.Node]) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _produced_by(value: torch.Value, nodes: Collection[torch.Node]) -> bool:
|
||||
return value.node() in nodes
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def find_mismatch(
|
||||
model: Union[torch.nn.Module, torch.jit.ScriptModule],
|
||||
input_args: Tuple[Any, ...],
|
||||
|
Reference in New Issue
Block a user