mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[ONNX] New symbolic function registry (#84382)
## Summary The change brings the new registry for symbolic functions in ONNX. The `SymbolicRegistry` class in `torch.onnx._internal.registration` replaces the dictionary and various functions defined in `torch.onnx.symbolic_registry`. The new registry - Has faster lookup by storing only functions in the opset version they are defined in - Is easier to manage and interact with due to its class design - Builds the foundation for the more flexible registration process detailed in #83787 Implementation changes - **Breaking**: Remove `torch.onnx.symbolic_registry` - `register_custom_op_symbolic` and `unregister_custom_op_symbolic` in utils maintain their api for compatibility - Update _onnx_supported_ops.py for doc generation to include quantized ops. - Update code to register python ops in `torch/csrc/jit/passes/onnx.cpp` ## Profiling results -0.1 seconds in execution time. -34% time spent in `_run_symbolic_function`. Tested on the alexnet example in public doc. ### After ``` └─ 1.641 export <@beartype(torch.onnx.utils.export) at 0x7f19be17f790>:1 └─ 1.641 export torch/onnx/utils.py:185 └─ 1.640 _export torch/onnx/utils.py:1331 ├─ 0.889 _model_to_graph torch/onnx/utils.py:1005 │ ├─ 0.478 _optimize_graph torch/onnx/utils.py:535 │ │ ├─ 0.214 PyCapsule._jit_pass_onnx_graph_shape_type_inference <built-in>:0 │ │ │ [2 frames hidden] <built-in> │ │ ├─ 0.190 _run_symbolic_function torch/onnx/utils.py:1670 │ │ │ └─ 0.145 Constant torch/onnx/symbolic_opset9.py:5782 │ │ │ └─ 0.139 _graph_op torch/onnx/_patch_torch.py:18 │ │ │ └─ 0.134 PyCapsule._jit_pass_onnx_node_shape_type_inference <built-in>:0 │ │ │ [2 frames hidden] <built-in> │ │ └─ 0.033 [self] ``` ### Before  ### Start up time The startup process takes 0.03 seconds. Calls to `inspect` will be eliminated when we switch to using decorators for registration in #84448  Pull Request resolved: https://github.com/pytorch/pytorch/pull/84382 Approved by: https://github.com/AllenTiTaiWang, https://github.com/BowenBao
This commit is contained in:
committed by
PyTorch MergeBot
parent
735154354b
commit
cd7e6d4ad1
168
test/onnx/internal/test_registraion.py
Normal file
168
test/onnx/internal/test_registraion.py
Normal file
@ -0,0 +1,168 @@
|
||||
# Owner(s): ["module: onnx"]
|
||||
"""Unit tests for the internal registration wrapper module."""
|
||||
|
||||
from typing import Sequence
|
||||
|
||||
from torch.onnx._internal import registration
|
||||
from torch.testing._internal import common_utils
|
||||
|
||||
|
||||
@common_utils.instantiate_parametrized_tests
|
||||
class TestGlobalHelpers(common_utils.TestCase):
|
||||
@common_utils.parametrize(
|
||||
"available_opsets, target, expected",
|
||||
[
|
||||
((7, 8, 9, 10, 11), 16, 11),
|
||||
((7, 8, 9, 10, 11), 11, 11),
|
||||
((7, 8, 9, 10, 11), 10, 10),
|
||||
((7, 8, 9, 10, 11), 9, 9),
|
||||
((7, 8, 9, 10, 11), 8, 8),
|
||||
((7, 8, 9, 10, 11), 7, 7),
|
||||
((9, 10, 16), 16, 16),
|
||||
((9, 10, 16), 15, 10),
|
||||
((9, 10, 16), 10, 10),
|
||||
((9, 10, 16), 9, 9),
|
||||
((9, 10, 16), 8, 9),
|
||||
((9, 10, 16), 7, 9),
|
||||
((7, 9, 10, 16), 16, 16),
|
||||
((7, 9, 10, 16), 10, 10),
|
||||
((7, 9, 10, 16), 9, 9),
|
||||
((7, 9, 10, 16), 8, 9),
|
||||
((7, 9, 10, 16), 7, 7),
|
||||
([17], 16, None), # New op added in 17
|
||||
([9], 9, 9),
|
||||
([9], 8, 9),
|
||||
([], 16, None),
|
||||
([], 9, None),
|
||||
([], 8, None),
|
||||
([8], 16, None), # Ops lower than 9 are not supported by versions >= 9
|
||||
],
|
||||
)
|
||||
def test_dispatch_opset_version_returns_correct_version(
|
||||
self, available_opsets: Sequence[int], target: int, expected: int
|
||||
):
|
||||
actual = registration._dispatch_opset_version(target, available_opsets)
|
||||
self.assertEqual(actual, expected)
|
||||
|
||||
|
||||
class TestOverrideDict(common_utils.TestCase):
|
||||
def setUp(self):
|
||||
self.override_dict: registration.OverrideDict[
|
||||
str, int
|
||||
] = registration.OverrideDict()
|
||||
|
||||
def test_get_item_returns_base_value_when_no_override(self):
|
||||
self.override_dict.set_base("a", 42)
|
||||
self.override_dict.set_base("b", 0)
|
||||
|
||||
self.assertEqual(self.override_dict["a"], 42)
|
||||
self.assertEqual(self.override_dict["b"], 0)
|
||||
self.assertEqual(len(self.override_dict), 2)
|
||||
|
||||
def test_get_item_returns_overridden_value_when_override(self):
|
||||
self.override_dict.set_base("a", 42)
|
||||
self.override_dict.set_base("b", 0)
|
||||
self.override_dict.override("a", 100)
|
||||
self.override_dict.override("c", 1)
|
||||
|
||||
self.assertEqual(self.override_dict["a"], 100)
|
||||
self.assertEqual(self.override_dict["b"], 0)
|
||||
self.assertEqual(self.override_dict["c"], 1)
|
||||
self.assertEqual(len(self.override_dict), 3)
|
||||
|
||||
def test_get_item_raises_key_error_when_not_found(self):
|
||||
self.override_dict.set_base("a", 42)
|
||||
|
||||
with self.assertRaises(KeyError):
|
||||
self.override_dict["nonexistent_key"]
|
||||
|
||||
def test_get_returns_overridden_value_when_override(self):
|
||||
self.override_dict.set_base("a", 42)
|
||||
self.override_dict.set_base("b", 0)
|
||||
self.override_dict.override("a", 100)
|
||||
self.override_dict.override("c", 1)
|
||||
|
||||
self.assertEqual(self.override_dict.get("a"), 100)
|
||||
self.assertEqual(self.override_dict.get("b"), 0)
|
||||
self.assertEqual(self.override_dict.get("c"), 1)
|
||||
self.assertEqual(len(self.override_dict), 3)
|
||||
|
||||
def test_get_returns_none_when_not_found(self):
|
||||
self.override_dict.set_base("a", 42)
|
||||
|
||||
self.assertEqual(self.override_dict.get("nonexistent_key"), None)
|
||||
|
||||
def test_in_base_returns_true_for_base_value(self):
|
||||
self.override_dict.set_base("a", 42)
|
||||
self.override_dict.set_base("b", 0)
|
||||
self.override_dict.override("a", 100)
|
||||
self.override_dict.override("c", 1)
|
||||
|
||||
self.assertIn("a", self.override_dict)
|
||||
self.assertIn("b", self.override_dict)
|
||||
self.assertIn("c", self.override_dict)
|
||||
|
||||
self.assertTrue(self.override_dict.in_base("a"))
|
||||
self.assertTrue(self.override_dict.in_base("b"))
|
||||
self.assertFalse(self.override_dict.in_base("c"))
|
||||
self.assertFalse(self.override_dict.in_base("nonexistent_key"))
|
||||
|
||||
def test_overridden_returns_true_for_overridden_value(self):
|
||||
self.override_dict.set_base("a", 42)
|
||||
self.override_dict.set_base("b", 0)
|
||||
self.override_dict.override("a", 100)
|
||||
self.override_dict.override("c", 1)
|
||||
|
||||
self.assertTrue(self.override_dict.overridden("a"))
|
||||
self.assertFalse(self.override_dict.overridden("b"))
|
||||
self.assertTrue(self.override_dict.overridden("c"))
|
||||
self.assertFalse(self.override_dict.overridden("nonexistent_key"))
|
||||
|
||||
def test_remove_override_removes_overridden_value(self):
|
||||
self.override_dict.set_base("a", 42)
|
||||
self.override_dict.set_base("b", 0)
|
||||
self.override_dict.override("a", 100)
|
||||
self.override_dict.override("c", 1)
|
||||
|
||||
self.assertEqual(self.override_dict["a"], 100)
|
||||
self.assertEqual(self.override_dict["c"], 1)
|
||||
|
||||
self.override_dict.remove_override("a")
|
||||
self.override_dict.remove_override("c")
|
||||
self.assertEqual(self.override_dict["a"], 42)
|
||||
self.assertEqual(self.override_dict.get("c"), None)
|
||||
self.assertFalse(self.override_dict.overridden("a"))
|
||||
self.assertFalse(self.override_dict.overridden("c"))
|
||||
|
||||
def test_remove_override_removes_overridden_key(self):
|
||||
self.override_dict.override("a", 100)
|
||||
self.assertEqual(self.override_dict["a"], 100)
|
||||
self.assertEqual(len(self.override_dict), 1)
|
||||
self.override_dict.remove_override("a")
|
||||
self.assertEqual(len(self.override_dict), 0)
|
||||
self.assertNotIn("a", self.override_dict)
|
||||
|
||||
def test_overriden_key_precededs_base_key_regardless_of_insert_order(self):
|
||||
self.override_dict.set_base("a", 42)
|
||||
self.override_dict.override("a", 100)
|
||||
self.override_dict.set_base("a", 0)
|
||||
|
||||
self.assertEqual(self.override_dict["a"], 100)
|
||||
self.assertEqual(len(self.override_dict), 1)
|
||||
|
||||
def test_bool_is_true_when_not_empty(self):
|
||||
if self.override_dict:
|
||||
self.fail("OverrideDict should be false when empty")
|
||||
self.override_dict.override("a", 1)
|
||||
if not self.override_dict:
|
||||
self.fail("OverrideDict should be true when not empty")
|
||||
self.override_dict.set_base("a", 42)
|
||||
if not self.override_dict:
|
||||
self.fail("OverrideDict should be true when not empty")
|
||||
self.override_dict.remove_override("a")
|
||||
if not self.override_dict:
|
||||
self.fail("OverrideDict should be true when not empty")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
common_utils.run_tests()
|
@ -6,6 +6,7 @@ import contextlib
|
||||
import io
|
||||
import itertools
|
||||
import unittest
|
||||
import unittest.mock
|
||||
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import onnx
|
||||
@ -14,14 +15,15 @@ import onnx.numpy_helper
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
from torch.onnx import symbolic_helper, symbolic_registry, utils
|
||||
from torch.onnx import symbolic_helper, utils
|
||||
from torch.onnx._globals import GLOBALS
|
||||
from torch.onnx._internal import registration
|
||||
from torch.testing._internal import common_utils
|
||||
|
||||
|
||||
def export_to_onnx(
|
||||
model: Union[torch.nn.Module, torch.jit.ScriptFunction],
|
||||
input: Tuple[torch.Tensor],
|
||||
input: Union[torch.Tensor, Tuple[torch.Tensor]],
|
||||
custom_ops: Optional[
|
||||
Iterable[
|
||||
Union[contextlib.AbstractContextManager, contextlib.ContextDecorator],
|
||||
@ -436,14 +438,11 @@ class TestONNXExport(common_utils.TestCase):
|
||||
def forward(self, x):
|
||||
return torch.clamp(x, min=-0.5, max=0.5)
|
||||
|
||||
def break_is_registered_op_api(opname, domain, version):
|
||||
fake_missing_symbolics = ("clamp",)
|
||||
if opname in fake_missing_symbolics:
|
||||
return False
|
||||
return (
|
||||
(domain, version) in symbolic_registry._registry
|
||||
and opname in symbolic_registry._registry[(domain, version)]
|
||||
)
|
||||
def break_is_registered_op_api(name):
|
||||
fake_missing_symbolics = {"aten::clamp"}
|
||||
if name in fake_missing_symbolics:
|
||||
return None
|
||||
return registration.registry.get_function_group(name)
|
||||
|
||||
# Force missing symbolic for well-known op using a mock
|
||||
onnx_model = export_to_onnx(
|
||||
@ -451,7 +450,7 @@ class TestONNXExport(common_utils.TestCase):
|
||||
torch.randn(3, 4, requires_grad=True),
|
||||
mocks=[
|
||||
unittest.mock.patch(
|
||||
"torch.onnx.symbolic_registry.is_registered_op",
|
||||
"torch.onnx._internal.registration.registry.get_function_group",
|
||||
side_effect=break_is_registered_op_api,
|
||||
)
|
||||
],
|
||||
|
@ -243,7 +243,8 @@ void NodeToONNX(
|
||||
std::unordered_map<Value*, Value*>& env) {
|
||||
py::object onnx = py::module::import("torch.onnx");
|
||||
py::object onnx_globals = py::module::import("torch.onnx._globals");
|
||||
py::object onnx_registry = py::module::import("torch.onnx.symbolic_registry");
|
||||
py::object onnx_registration =
|
||||
py::module::import("torch.onnx._internal.registration");
|
||||
|
||||
// Setup all the lambda helper functions.
|
||||
|
||||
@ -452,10 +453,13 @@ void NodeToONNX(
|
||||
|
||||
py::object opset_version =
|
||||
onnx_globals.attr("GLOBALS").attr("export_onnx_opset_version");
|
||||
py::object is_registered_op = onnx_registry.attr("is_registered_op")(
|
||||
"PythonOp", "prim", opset_version);
|
||||
if (!py::hasattr(pyobj, "symbolic") &&
|
||||
(!PyObject_IsTrue(is_registered_op.ptr()))) {
|
||||
// NOTE(justinchuby): Call the internal registry to register the symbolic
|
||||
// method defined in the module.
|
||||
bool is_registered_op =
|
||||
onnx_registration.attr("registry")
|
||||
.attr("is_registered_op")("prim::PythonOp", opset_version)
|
||||
.cast<bool>();
|
||||
if (!py::hasattr(pyobj, "symbolic") && !is_registered_op) {
|
||||
// Inline the subgraph within the prim::PythonOp unless
|
||||
// either of these conditions are satisfied
|
||||
// 1. The torch.autograd.Function class of this node object has `symbolic`
|
||||
@ -514,8 +518,16 @@ void NodeToONNX(
|
||||
// Call the symbolic function
|
||||
// Use a little trampoline function so we can give good error messages
|
||||
// upon argument mismatch
|
||||
onnx_registry.attr("register_op")(
|
||||
op->name(), pyobj.attr("symbolic"), "", opset_version);
|
||||
// Register as a custom operator
|
||||
// TODO: Find a more elegant way to do this without having to touch
|
||||
// internal Python modules.
|
||||
// TODO(justinchuby): Define a namespace for these Python Ops.
|
||||
onnx_registration.attr("registry")
|
||||
.attr("register")(
|
||||
"::" + op->name(),
|
||||
opset_version,
|
||||
pyobj.attr("symbolic"),
|
||||
/* custom */ true);
|
||||
py::object raw_output = onnx.attr("_run_symbolic_method")(
|
||||
new_block->owningGraph(),
|
||||
op->name(),
|
||||
@ -524,7 +536,7 @@ void NodeToONNX(
|
||||
|
||||
processSymbolicOutput(op->name(), op, raw_output);
|
||||
} else {
|
||||
TORCH_INTERNAL_ASSERT(PyObject_IsTrue(is_registered_op.ptr()));
|
||||
TORCH_INTERNAL_ASSERT(is_registered_op);
|
||||
Node* n = static_cast<Node*>(op);
|
||||
n->s_(attr::name, op->name());
|
||||
// Call symbolic function
|
||||
|
@ -9,6 +9,7 @@ from torch._C._onnx import (
|
||||
TensorProtoDataType,
|
||||
TrainingMode,
|
||||
)
|
||||
from torch.onnx._internal import registration as _registration
|
||||
|
||||
from . import ( # usort:skip. Keep the order instead of sorting lexicographically
|
||||
_deprecation,
|
||||
@ -25,7 +26,6 @@ from . import ( # usort:skip. Keep the order instead of sorting lexicographical
|
||||
symbolic_opset14,
|
||||
symbolic_opset15,
|
||||
symbolic_opset16,
|
||||
symbolic_registry,
|
||||
utils,
|
||||
)
|
||||
from ._exporter_states import ExportTypes, SymbolicContext
|
||||
@ -46,7 +46,6 @@ from .utils import (
|
||||
__all__ = [
|
||||
# Modules
|
||||
"symbolic_helper",
|
||||
"symbolic_registry",
|
||||
"utils",
|
||||
"errors",
|
||||
# All opsets
|
||||
@ -134,3 +133,6 @@ def log(*args) -> None:
|
||||
character appended to the end, and flushed to output stream.
|
||||
"""
|
||||
_C._jit_onnx_log(*args)
|
||||
|
||||
|
||||
_registration.discover_and_register_all_symbolic_opsets()
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
ONNX_ARCHIVE_MODEL_PROTO_NAME = "__MODEL_PROTO"
|
||||
|
||||
ONNX_BASE_OPSET = 9
|
||||
ONNX_MIN_OPSET = 7
|
||||
ONNX_MAX_OPSET = 17
|
||||
# ONNX_DEFAULT_OPSET generated by tools/onnx/update_default_opset_version.py
|
||||
|
300
torch/onnx/_internal/registration.py
Normal file
300
torch/onnx/_internal/registration.py
Normal file
@ -0,0 +1,300 @@
|
||||
"""Module for handling symbolic function registration."""
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
import warnings
|
||||
from typing import Callable, Collection, Dict, Generic, Optional, Set, TypeVar
|
||||
|
||||
from torch.onnx import _constants, errors
|
||||
|
||||
OpsetVersion = int
|
||||
|
||||
|
||||
def _dispatch_opset_version(
|
||||
target: OpsetVersion, registered_opsets: Collection[OpsetVersion]
|
||||
) -> Optional[OpsetVersion]:
|
||||
"""Finds the registered opset given a target opset version and the available opsets.
|
||||
|
||||
Args:
|
||||
target: The target opset version.
|
||||
available_opsets: The available opsets.
|
||||
|
||||
Returns:
|
||||
The registered opset version.
|
||||
"""
|
||||
if not registered_opsets:
|
||||
return None
|
||||
registered_versions = sorted(registered_opsets)
|
||||
# Linear search for the opset version, which is fine since the number of opset
|
||||
# versions is small.
|
||||
|
||||
# Always round toward opset 9 (ONNX_BASE_OPSET).
|
||||
# Count down until opset 9 is reached.
|
||||
for version in reversed(registered_versions):
|
||||
if _constants.ONNX_BASE_OPSET <= version <= target:
|
||||
return version
|
||||
|
||||
for version in registered_versions:
|
||||
# Count back up until _constants.ONNX_BASE_OPSET
|
||||
if target <= version <= _constants.ONNX_BASE_OPSET:
|
||||
return version
|
||||
|
||||
assert (
|
||||
not registered_versions
|
||||
or _constants.ONNX_BASE_OPSET <= target < registered_versions[0]
|
||||
or registered_versions[-1] < _constants.ONNX_BASE_OPSET <= target
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
_K = TypeVar("_K")
|
||||
_V = TypeVar("_V")
|
||||
|
||||
|
||||
class OverrideDict(Generic[_K, _V], Collection[_K]):
|
||||
"""A dictionary that merges built-in and custom symbolic functions.
|
||||
|
||||
It supports overriding and un-overriding built-in symbolic functions with custom
|
||||
ones.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._base: Dict[_K, _V] = {}
|
||||
self._overrides: Dict[_K, _V] = {}
|
||||
self._merged: Dict[_K, _V] = {}
|
||||
|
||||
def override(self, key: _K, value: _V) -> None:
|
||||
"""Overrides a base key-value with a new pair."""
|
||||
self._overrides[key] = value
|
||||
self._merged[key] = value
|
||||
|
||||
def remove_override(self, key: _K) -> None:
|
||||
"""Un-overrides a key-value pair."""
|
||||
self._overrides.pop(key, None) # type: ignore[arg-type]
|
||||
self._merged.pop(key, None) # type: ignore[arg-type]
|
||||
if key in self._base:
|
||||
self._merged[key] = self._base[key]
|
||||
|
||||
def overridden(self, key: _K) -> bool:
|
||||
"""Checks if a key-value pair is overridden."""
|
||||
return key in self._overrides
|
||||
|
||||
def in_base(self, key: _K) -> bool:
|
||||
"""Checks if a key is in the base dictionary."""
|
||||
return key in self._base
|
||||
|
||||
def __getitem__(self, key: _K) -> _V:
|
||||
return self._merged[key]
|
||||
|
||||
def set_base(self, key: _K, value: _V) -> None:
|
||||
self._base[key] = value
|
||||
if key not in self._overrides:
|
||||
self._merged[key] = value
|
||||
|
||||
def get(self, key: _K, default: Optional[_V] = None):
|
||||
return self._merged.get(key, default)
|
||||
|
||||
def __contains__(self, key: object) -> bool:
|
||||
return key in self._merged
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._merged)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._merged)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"OverrideDict(base={self._base}, overrides={self._overrides})"
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return bool(self._merged)
|
||||
|
||||
|
||||
class _SymbolicFunctionGroup:
|
||||
"""Different versions of symbolic functions registered to the same name.
|
||||
|
||||
O(number of registered versions of an op) search is performed to find the most
|
||||
recent version of the op.
|
||||
|
||||
The registration is delayed until op is used to improve startup time.
|
||||
|
||||
Function overloads with different arguments are not allowed.
|
||||
Custom op overrides are supported.
|
||||
"""
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
self._name = name
|
||||
# A dictionary of functions, keyed by the opset version.
|
||||
self._functions: OverrideDict[OpsetVersion, Callable] = OverrideDict()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"_SymbolicFunctionGroup({self._name}, registered={self._functions})"
|
||||
|
||||
def __getitem__(self, key: OpsetVersion) -> Callable:
|
||||
result = self.get(key)
|
||||
if result is None:
|
||||
raise KeyError(key)
|
||||
return result
|
||||
|
||||
# TODO(justinchuby): Add @functools.lru_cache(maxsize=None) if lookup time becomes
|
||||
# a problem.
|
||||
def get(self, opset: OpsetVersion) -> Optional[Callable]:
|
||||
"""Find the most recent version of the function."""
|
||||
version = _dispatch_opset_version(opset, self._functions)
|
||||
if version is None:
|
||||
return None
|
||||
|
||||
return self._functions[version]
|
||||
|
||||
def add(self, func: Callable, opset: OpsetVersion) -> None:
|
||||
"""Adds a symbolic function.
|
||||
|
||||
Args:
|
||||
func: The function to add.
|
||||
opset: The opset version of the function to add.
|
||||
"""
|
||||
if self._functions.in_base(opset):
|
||||
warnings.warn(
|
||||
f"Symbolic function '{self._name}' already registered for opset {opset}. "
|
||||
f"Replacing the existing function with new function. This is unexpected. "
|
||||
f"Please report it on {_constants.PYTORCH_GITHUB_ISSUES_URL}.",
|
||||
errors.OnnxExporterWarning,
|
||||
)
|
||||
self._functions.set_base(opset, func)
|
||||
|
||||
def add_custom(self, func: Callable, opset: OpsetVersion) -> None:
|
||||
"""Adds a custom symbolic function.
|
||||
|
||||
Args:
|
||||
func: The symbolic function to register.
|
||||
opset: The corresponding opset version.
|
||||
"""
|
||||
self._functions.override(opset, func)
|
||||
|
||||
def remove_custom(self, opset: OpsetVersion) -> None:
|
||||
"""Removes a custom symbolic function.
|
||||
|
||||
Args:
|
||||
opset: The opset version of the custom function to remove.
|
||||
"""
|
||||
if not self._functions.overridden(opset):
|
||||
warnings.warn(
|
||||
f"No custom function registered for '{self._name}' opset {opset}"
|
||||
)
|
||||
return
|
||||
self._functions.remove_override(opset)
|
||||
|
||||
def get_min_supported(self) -> OpsetVersion:
|
||||
"""Returns the lowest built-in opset version supported by the function."""
|
||||
return min(self._functions)
|
||||
|
||||
|
||||
class SymbolicRegistry:
|
||||
"""Registry for symbolic functions.
|
||||
|
||||
The registry maintains a mapping from qualified names to symbolic functions.
|
||||
It is used to register new symbolic functions and to dispatch calls to
|
||||
the appropriate function.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._registry: Dict[str, _SymbolicFunctionGroup] = {}
|
||||
|
||||
def register(
|
||||
self, name: str, opset: OpsetVersion, func: Callable, custom: bool = False
|
||||
) -> None:
|
||||
"""Registers a symbolic function.
|
||||
|
||||
Args:
|
||||
name: The qualified name of the function to register. In the form of 'domain::op'.
|
||||
E.g. 'aten::add'.
|
||||
opset: The opset version of the function to register.
|
||||
func: The symbolic function to register.
|
||||
custom: Whether the function is a custom function that overrides existing ones.
|
||||
"""
|
||||
if "::" not in name:
|
||||
raise ValueError(
|
||||
f"The name must be in the form of 'domain::op', not '{name}'"
|
||||
)
|
||||
symbolic_functions = self._registry.setdefault(
|
||||
name, _SymbolicFunctionGroup(name)
|
||||
)
|
||||
if custom:
|
||||
symbolic_functions.add_custom(func, opset)
|
||||
else:
|
||||
symbolic_functions.add(func, opset)
|
||||
|
||||
def unregister(self, name: str, opset: OpsetVersion) -> None:
|
||||
"""Unregisters a symbolic function.
|
||||
|
||||
Args:
|
||||
name: The qualified name of the function to unregister.
|
||||
opset: The opset version of the function to unregister.
|
||||
"""
|
||||
if name not in self._registry:
|
||||
return
|
||||
self._registry[name].remove_custom(opset)
|
||||
|
||||
def get_function_group(self, name: str) -> Optional[_SymbolicFunctionGroup]:
|
||||
"""Returns the function group for the given name."""
|
||||
return self._registry.get(name)
|
||||
|
||||
def is_registered_op(self, name: str, version: int) -> bool:
|
||||
"""Returns whether the given op is registered for the given opset version."""
|
||||
functions = self.get_function_group(name)
|
||||
if functions is None:
|
||||
return False
|
||||
return functions.get(version) is not None
|
||||
|
||||
def all_functions(self) -> Set[str]:
|
||||
"""Returns the set of all registered function names."""
|
||||
return set(self._registry)
|
||||
|
||||
|
||||
def discover_and_register_all_symbolic_opsets() -> None:
|
||||
"""Discover all symbolic functions.
|
||||
|
||||
Opset 9 is the base version. It is selected as the base version because
|
||||
1. It is the first opset version supported by PyTorch export.
|
||||
2. opset 9 is more robust than previous opset versions. Opset versions like 7/8 have limitations
|
||||
that certain basic operators cannot be expressed in ONNX. Instead of basing on these limitations,
|
||||
we chose to handle them as special cases separately.
|
||||
|
||||
Backward support for opset versions beyond opset 7 is not in our roadmap.
|
||||
|
||||
For opset versions other than 9, by default they will inherit the symbolic functions defined in
|
||||
symbolic_opset9.py.
|
||||
|
||||
To extend support for updated operators in different opset versions on top of opset 9,
|
||||
simply add the updated symbolic functions in the respective symbolic_opset{version}.py file.
|
||||
Checkout topk in symbolic_opset10.py, and upsample_nearest2d in symbolic_opset8.py for example.
|
||||
"""
|
||||
for opset in range(_constants.ONNX_MIN_OPSET, _constants.ONNX_MAX_OPSET + 1):
|
||||
module = importlib.import_module(f"torch.onnx.symbolic_opset{opset}")
|
||||
_register_module(module, opset)
|
||||
|
||||
|
||||
def _register_module(module, opset: OpsetVersion) -> None:
|
||||
"""Registers all functions in the given module.
|
||||
|
||||
Args:
|
||||
module: The module to register.
|
||||
opset: The opset version to register.
|
||||
"""
|
||||
global registry
|
||||
members = inspect.getmembers(module)
|
||||
for name, obj in members:
|
||||
if isinstance(obj, type) and hasattr(obj, "domain"):
|
||||
# Symbolic functions in domains other than aten
|
||||
ops = inspect.getmembers(obj, predicate=inspect.isfunction)
|
||||
for op in ops:
|
||||
registry.register(f"{obj.domain}::{op[0]}", opset, op[1]) # type: ignore[attr-defined]
|
||||
|
||||
elif inspect.isfunction(obj):
|
||||
if name in {"_len", "_list", "_any", "_all"}:
|
||||
name = name[1:]
|
||||
registry.register(f"aten::{name}", opset, obj)
|
||||
|
||||
|
||||
# The registry for all symbolic functions.
|
||||
registry = SymbolicRegistry()
|
@ -2,10 +2,8 @@ import inspect
|
||||
from typing import Dict, List, Union
|
||||
|
||||
from torch import _C
|
||||
from torch.onnx import _constants, symbolic_registry
|
||||
|
||||
for v in range(_constants.ONNX_MIN_OPSET, _constants.ONNX_MAX_OPSET + 1):
|
||||
symbolic_registry.register_version("", v)
|
||||
from torch.onnx import _constants
|
||||
from torch.onnx._internal import registration
|
||||
|
||||
|
||||
class _TorchSchema:
|
||||
@ -74,18 +72,27 @@ def _symbolic_argument_count(func):
|
||||
return params
|
||||
|
||||
|
||||
def _all_symbolics_schemas():
|
||||
symbolics_schemas: Dict[str, _TorchSchema] = {}
|
||||
def _all_symbolics_schemas() -> Dict[str, _TorchSchema]:
|
||||
symbolics_schemas = {}
|
||||
|
||||
for name in registration.registry.all_functions():
|
||||
func_group = registration.registry.get_function_group(name)
|
||||
assert func_group is not None
|
||||
symbolics_schema = _TorchSchema(name)
|
||||
func = func_group.get(_constants.ONNX_MAX_OPSET)
|
||||
if func is not None:
|
||||
symbolics_schema.arguments = _symbolic_argument_count(func)
|
||||
symbolics_schema.opsets = list(
|
||||
range(func_group.get_min_supported(), _constants.ONNX_MAX_OPSET + 1)
|
||||
)
|
||||
else:
|
||||
# Only support opset < 9
|
||||
func = func_group.get(7)
|
||||
symbolics_schema.arguments = _symbolic_argument_count(func)
|
||||
symbolics_schema.opsets = list(range(7, _constants.ONNX_BASE_OPSET))
|
||||
|
||||
symbolics_schemas[name] = symbolics_schema
|
||||
|
||||
for domain, version in symbolic_registry._registry:
|
||||
for opname, sym_func in symbolic_registry._registry[(domain, version)].items():
|
||||
symbolics_schema = _TorchSchema("aten::" + opname)
|
||||
symbolics_schema.arguments = _symbolic_argument_count(sym_func)
|
||||
if opname in symbolics_schemas:
|
||||
symbolics_schemas[opname].opsets.append(version)
|
||||
else:
|
||||
symbolics_schema.opsets = [version]
|
||||
symbolics_schemas[opname] = symbolics_schema
|
||||
return symbolics_schemas
|
||||
|
||||
|
||||
@ -97,7 +104,7 @@ def onnx_supported_ops():
|
||||
onnx_supported = []
|
||||
for schema in aten_schemas:
|
||||
if schema in torch_schemas:
|
||||
opname = schema.name[6:] # without "aten::" prefix
|
||||
opname = schema.name
|
||||
opsets = symbolic_schemas[opname].opsets
|
||||
if schema not in supported_ops:
|
||||
supported_ops.append(symbolic_schemas[opname])
|
||||
|
@ -1,38 +1,38 @@
|
||||
import importlib
|
||||
import inspect
|
||||
|
||||
from torch.onnx import symbolic_helper, symbolic_opset9 as opset9, symbolic_registry
|
||||
from torch.onnx import symbolic_helper, symbolic_opset9 as opset9
|
||||
from torch.onnx._internal import registration
|
||||
|
||||
|
||||
def register_quantized_ops(domain: str, version: int):
|
||||
# Register all the non-quantized ops
|
||||
symbolic_registry.register_version("", version)
|
||||
# Register all quantized ops
|
||||
module = importlib.import_module("torch.onnx.symbolic_caffe2")
|
||||
symbolic_registry._symbolic_versions["caffe2"] = module
|
||||
quant_version_ops = inspect.getmembers(
|
||||
symbolic_registry._symbolic_versions["caffe2"]
|
||||
)
|
||||
for op in quant_version_ops:
|
||||
if inspect.isfunction(op[1]) and not symbolic_registry.is_registered_op(
|
||||
op[0], domain, version
|
||||
quant_version_ops = inspect.getmembers(module)
|
||||
aten_q_ops = {
|
||||
"relu",
|
||||
"_empty_affine_quantized",
|
||||
"dequantize",
|
||||
"quantize_per_tensor",
|
||||
"upsample_nearest2d",
|
||||
"avg_pool2d",
|
||||
"reshape",
|
||||
"slice",
|
||||
"cat",
|
||||
"max_pool2d",
|
||||
"sigmoid",
|
||||
}
|
||||
for op, func in quant_version_ops:
|
||||
name = f"{domain}::{op}"
|
||||
if inspect.isfunction(func) and not registration.registry.is_registered_op(
|
||||
name, version
|
||||
):
|
||||
aten_q_ops = [
|
||||
"relu",
|
||||
"_empty_affine_quantized",
|
||||
"dequantize",
|
||||
"quantize_per_tensor",
|
||||
"upsample_nearest2d",
|
||||
"avg_pool2d",
|
||||
"reshape",
|
||||
"slice",
|
||||
"cat",
|
||||
"max_pool2d",
|
||||
"sigmoid",
|
||||
]
|
||||
if op[0] in aten_q_ops:
|
||||
symbolic_registry.register_op(op[0], op[1], "", version)
|
||||
symbolic_registry.register_op(op[0], op[1], domain, version)
|
||||
if op in aten_q_ops:
|
||||
# Override the builtin aten ops
|
||||
registration.registry.register(
|
||||
f"aten::{op}", version, func, custom=True
|
||||
)
|
||||
registration.registry.register(name, version, func)
|
||||
|
||||
|
||||
def _permute_helper(g, input, axes):
|
||||
|
@ -23,6 +23,8 @@ from torch.onnx import symbolic_helper
|
||||
# EDITING THIS FILE? READ THIS FIRST!
|
||||
# see Note [Edit Symbolic Files] in symbolic_helper.py
|
||||
|
||||
__all__ = ["layer_norm"]
|
||||
|
||||
|
||||
@symbolic_helper.parse_args("v", "is", "v", "v", "f", "none")
|
||||
def layer_norm(
|
||||
|
@ -1,167 +0,0 @@
|
||||
import importlib
|
||||
import inspect
|
||||
import warnings
|
||||
from typing import Any, Callable, Dict, Tuple, Union
|
||||
|
||||
from torch import _C
|
||||
from torch.onnx import _constants, errors
|
||||
|
||||
__all__ = [
|
||||
"get_op_supported_version",
|
||||
"get_ops_in_version",
|
||||
"get_registered_op",
|
||||
"is_registered_op",
|
||||
"is_registered_version",
|
||||
"register_op",
|
||||
"register_ops_helper",
|
||||
"register_ops_in_version",
|
||||
"register_version",
|
||||
"unregister_op",
|
||||
]
|
||||
|
||||
_SymbolicFunction = Callable[..., Union[_C.Value, Tuple[_C.Value]]]
|
||||
|
||||
"""
|
||||
The symbolic registry "_registry" is a dictionary that maps operators
|
||||
(for a specific domain and opset version) to their symbolic functions.
|
||||
An operator is defined by its domain, opset version, and opname.
|
||||
The keys are tuples (domain, version), (where domain is a string, and version is an int),
|
||||
and the operator's name (string).
|
||||
The map's entries are as follows : _registry[(domain, version)][op_name] = op_symbolic
|
||||
"""
|
||||
_registry: Dict[
|
||||
Tuple[str, int],
|
||||
Dict[str, _SymbolicFunction],
|
||||
] = {}
|
||||
|
||||
_symbolic_versions: Dict[Union[int, str], Any] = {}
|
||||
|
||||
|
||||
def _import_symbolic_opsets():
|
||||
for opset_version in range(
|
||||
_constants.ONNX_MIN_OPSET, _constants.ONNX_MAX_OPSET + 1
|
||||
):
|
||||
module = importlib.import_module(f"torch.onnx.symbolic_opset{opset_version}")
|
||||
global _symbolic_versions
|
||||
_symbolic_versions[opset_version] = module
|
||||
|
||||
|
||||
def register_version(domain: str, version: int):
|
||||
if not is_registered_version(domain, version):
|
||||
global _registry
|
||||
_registry[(domain, version)] = {}
|
||||
register_ops_in_version(domain, version)
|
||||
|
||||
|
||||
def register_ops_helper(domain: str, version: int, iter_version: int):
|
||||
for domain, op_name, op_func in get_ops_in_version(iter_version):
|
||||
if not is_registered_op(op_name, domain, version):
|
||||
register_op(op_name, op_func, domain, version)
|
||||
|
||||
|
||||
def register_ops_in_version(domain: str, version: int):
|
||||
"""Iterates through the symbolic functions of the specified opset version, and the
|
||||
previous opset versions for operators supported in previous versions.
|
||||
|
||||
Opset 9 is the base version. It is selected as the base version because
|
||||
1. It is the first opset version supported by PyTorch export.
|
||||
2. opset 9 is more robust than previous opset versions. Opset versions like 7/8 have limitations
|
||||
that certain basic operators cannot be expressed in ONNX. Instead of basing on these limitations,
|
||||
we chose to handle them as special cases separately.
|
||||
|
||||
Backward support for opset versions beyond opset 7 is not in our roadmap.
|
||||
|
||||
For opset versions other than 9, by default they will inherit the symbolic functions defined in
|
||||
symbolic_opset9.py.
|
||||
|
||||
To extend support for updated operators in different opset versions on top of opset 9,
|
||||
simply add the updated symbolic functions in the respective symbolic_opset{version}.py file.
|
||||
Checkout topk in symbolic_opset10.py, and upsample_nearest2d in symbolic_opset8.py for example.
|
||||
"""
|
||||
iter_version = version
|
||||
while iter_version != 9:
|
||||
register_ops_helper(domain, version, iter_version)
|
||||
if iter_version > 9:
|
||||
iter_version = iter_version - 1
|
||||
else:
|
||||
iter_version = iter_version + 1
|
||||
|
||||
register_ops_helper(domain, version, 9)
|
||||
|
||||
|
||||
def get_ops_in_version(version: int):
|
||||
if not _symbolic_versions:
|
||||
_import_symbolic_opsets()
|
||||
members = inspect.getmembers(_symbolic_versions[version])
|
||||
domain_opname_ops = []
|
||||
for obj in members:
|
||||
if isinstance(obj[1], type) and hasattr(obj[1], "domain"):
|
||||
ops = inspect.getmembers(obj[1], predicate=inspect.isfunction)
|
||||
for op in ops:
|
||||
domain_opname_ops.append((obj[1].domain, op[0], op[1])) # type: ignore[attr-defined]
|
||||
|
||||
elif inspect.isfunction(obj[1]):
|
||||
if obj[0] == "_len":
|
||||
obj = ("len", obj[1])
|
||||
if obj[0] == "_list":
|
||||
obj = ("list", obj[1])
|
||||
if obj[0] == "_any":
|
||||
obj = ("any", obj[1])
|
||||
if obj[0] == "_all":
|
||||
obj = ("all", obj[1])
|
||||
domain_opname_ops.append(("", obj[0], obj[1]))
|
||||
return domain_opname_ops
|
||||
|
||||
|
||||
def is_registered_version(domain: str, version: int):
|
||||
global _registry
|
||||
return (domain, version) in _registry
|
||||
|
||||
|
||||
def register_op(opname, op, domain, version):
|
||||
if domain is None or version is None:
|
||||
warnings.warn(
|
||||
"ONNX export failed. The ONNX domain and/or version to register are None."
|
||||
)
|
||||
global _registry
|
||||
if not is_registered_version(domain, version):
|
||||
_registry[(domain, version)] = {}
|
||||
_registry[(domain, version)][opname] = op
|
||||
|
||||
|
||||
def is_registered_op(opname: str, domain: str, version: int):
|
||||
if domain is None or version is None:
|
||||
warnings.warn("ONNX export failed. The ONNX domain and/or version are None.")
|
||||
global _registry
|
||||
return (domain, version) in _registry and opname in _registry[(domain, version)]
|
||||
|
||||
|
||||
def unregister_op(opname: str, domain: str, version: int):
|
||||
global _registry
|
||||
if is_registered_op(opname, domain, version):
|
||||
del _registry[(domain, version)][opname]
|
||||
if not _registry[(domain, version)]:
|
||||
del _registry[(domain, version)]
|
||||
else:
|
||||
warnings.warn("The opname " + opname + " is not registered.")
|
||||
|
||||
|
||||
def get_op_supported_version(opname: str, domain: str, version: int):
|
||||
iter_version = version
|
||||
while iter_version <= _constants.ONNX_MAX_OPSET:
|
||||
ops = [(op[0], op[1]) for op in get_ops_in_version(iter_version)]
|
||||
if (domain, opname) in ops:
|
||||
return iter_version
|
||||
iter_version += 1
|
||||
return None
|
||||
|
||||
|
||||
def get_registered_op(opname: str, domain: str, version: int) -> _SymbolicFunction:
|
||||
if domain is None or version is None:
|
||||
warnings.warn("ONNX export failed. The ONNX domain and/or version are None.")
|
||||
global _registry
|
||||
if not is_registered_op(opname, domain, version):
|
||||
raise errors.UnsupportedOperatorError(
|
||||
domain, opname, version, get_op_supported_version(opname, domain, version)
|
||||
)
|
||||
return _registry[(domain, version)][opname]
|
@ -43,10 +43,9 @@ from torch.onnx import ( # noqa: F401
|
||||
errors,
|
||||
symbolic_caffe2,
|
||||
symbolic_helper,
|
||||
symbolic_registry,
|
||||
)
|
||||
from torch.onnx._globals import GLOBALS
|
||||
from torch.onnx._internal import _beartype
|
||||
from torch.onnx._internal import _beartype, registration
|
||||
|
||||
__all__ = [
|
||||
"is_in_onnx_export",
|
||||
@ -59,7 +58,6 @@ __all__ = [
|
||||
"unpack_quantized_tensor",
|
||||
"export_to_pretty_string",
|
||||
"unconvertible_ops",
|
||||
"get_ns_op_name_from_custom_op",
|
||||
"register_custom_op_symbolic",
|
||||
"unregister_custom_op_symbolic",
|
||||
]
|
||||
@ -1279,7 +1277,7 @@ def unconvertible_ops(
|
||||
operator_export_type=_C_onnx.OperatorExportTypes.ONNX_FALLTHROUGH,
|
||||
)
|
||||
unsupported_ops = list()
|
||||
supported_namespaces = ("onnx", "prim", "quantized")
|
||||
supported_namespaces = {"onnx", "prim", "quantized"}
|
||||
for node in graph.nodes():
|
||||
if node.kind().split(":")[0] not in supported_namespaces:
|
||||
unsupported_ops.append(node.kind())
|
||||
@ -1690,37 +1688,10 @@ def _add_output_to_block(block: _C.Block, value: _C.Value):
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _find_symbolic_in_registry(
|
||||
domain: str,
|
||||
op_name: str,
|
||||
opset_version: int,
|
||||
operator_export_type: _C_onnx.OperatorExportTypes,
|
||||
) -> Optional[Callable]:
|
||||
"""Looks up for the symbolic function in the registry.
|
||||
|
||||
Args:
|
||||
domain: The domain of the symbolic function.
|
||||
op_name: The name of the op.
|
||||
opset_version: Currect opset used.
|
||||
operator_export_type: An enum in _C_onnx.OperatorExportTypes.
|
||||
|
||||
Returns:
|
||||
The symbolic function if found, None otherwise.
|
||||
"""
|
||||
|
||||
if not symbolic_registry.is_registered_op(op_name, domain, opset_version):
|
||||
if operator_export_type == _C_onnx.OperatorExportTypes.ONNX_FALLTHROUGH:
|
||||
# Use the original node directly
|
||||
return None
|
||||
return symbolic_registry.get_registered_op(op_name, domain, opset_version)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _should_aten_fallback(ns, op_name, opset_version, operator_export_type):
|
||||
|
||||
is_exportable_aten_op = symbolic_registry.is_registered_op(
|
||||
op_name, "", opset_version
|
||||
)
|
||||
def _should_aten_fallback(
|
||||
name: str, opset_version: int, operator_export_type: _C_onnx.OperatorExportTypes
|
||||
):
|
||||
is_exportable_aten_op = registration.registry.is_registered_op(name, opset_version)
|
||||
is_onnx_aten_export = operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN
|
||||
is_aten_fallback_export = (
|
||||
operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
|
||||
@ -1787,64 +1758,59 @@ def _run_symbolic_function(
|
||||
namespace, op_name = ns_op_name.split("::")
|
||||
|
||||
try:
|
||||
symbolic_registry.register_version("", opset_version)
|
||||
|
||||
# Caffe2-specific: Quantized op symbolics are registered for opset 9 only.
|
||||
if symbolic_helper.is_caffe2_aten_fallback() and opset_version == 9:
|
||||
symbolic_caffe2.register_quantized_ops("caffe2", opset_version)
|
||||
|
||||
if namespace == "aten":
|
||||
domain = ""
|
||||
elif namespace == "quantized" and symbolic_helper.is_caffe2_aten_fallback():
|
||||
if namespace == "quantized" and symbolic_helper.is_caffe2_aten_fallback():
|
||||
domain = "caffe2"
|
||||
else:
|
||||
domain = namespace
|
||||
symbolic_function_name = f"{domain}::{op_name}"
|
||||
|
||||
if symbolic_registry.is_registered_op(op_name, domain, opset_version):
|
||||
symbolic_fn = _find_symbolic_in_registry(
|
||||
domain, op_name, opset_version, operator_export_type
|
||||
)
|
||||
assert symbolic_fn is not None
|
||||
symbolic_function_group = registration.registry.get_function_group(
|
||||
symbolic_function_name
|
||||
)
|
||||
if symbolic_function_group is not None:
|
||||
symbolic_fn = symbolic_function_group.get(opset_version)
|
||||
if symbolic_fn is not None:
|
||||
attrs = {k: symbolic_helper._node_get(n, k) for k in n.attributeNames()}
|
||||
if _need_symbolic_context(symbolic_fn):
|
||||
# TODO(justinchuby): Refactor how we check for the need of the symbolic context
|
||||
ctx = _exporter_states.SymbolicContext(_params_dict, env, n, block)
|
||||
return symbolic_fn(ctx, g, *inputs, **attrs)
|
||||
# PythonOp symbolic need access to the node to resolve the name conflict,
|
||||
# this is inconsistent with regular op symbolic.
|
||||
if op_name == "PythonOp":
|
||||
inputs = (n, *inputs)
|
||||
return symbolic_fn(g, *inputs, **attrs)
|
||||
|
||||
attrs = {k: symbolic_helper._node_get(n, k) for k in n.attributeNames()}
|
||||
if _need_symbolic_context(symbolic_fn):
|
||||
ctx = _exporter_states.SymbolicContext(_params_dict, env, n, block)
|
||||
return symbolic_fn(ctx, g, *inputs, **attrs)
|
||||
# PythonOp symbolic need access to the node to resolve the name conflict,
|
||||
# this is inconsistent with regular op symbolic.
|
||||
if op_name == "PythonOp":
|
||||
inputs = (n, *inputs)
|
||||
return symbolic_fn(g, *inputs, **attrs)
|
||||
elif namespace == "onnx":
|
||||
attrs = {
|
||||
k + "_" + n.kindOf(k)[0]: symbolic_helper._node_get(n, k)
|
||||
for k in n.attributeNames()
|
||||
}
|
||||
if namespace == "onnx":
|
||||
# Clone node to trigger ONNX shape inference
|
||||
attrs = {
|
||||
k + "_" + n.kindOf(k)[0]: symbolic_helper._node_get(n, k)
|
||||
for k in n.attributeNames()
|
||||
}
|
||||
return g.op(op_name, *inputs, **attrs, outputs=n.outputsSize()) # type: ignore[attr-defined]
|
||||
elif _should_aten_fallback(
|
||||
namespace, op_name, opset_version, operator_export_type
|
||||
):
|
||||
|
||||
if _should_aten_fallback(ns_op_name, opset_version, operator_export_type):
|
||||
# Direct ATen export requested
|
||||
attrs = {
|
||||
k + "_" + n.kindOf(k)[0]: symbolic_helper._node_get(n, k)
|
||||
for k in n.attributeNames()
|
||||
}
|
||||
outputs = n.outputsSize()
|
||||
attrs["outputs"] = outputs
|
||||
# `overload_name` is set for non-Caffe2 builds only
|
||||
return g.at( # type: ignore[attr-defined]
|
||||
op_name, *inputs, overload_name=_get_aten_op_overload_name(n), **attrs
|
||||
)
|
||||
else:
|
||||
raise errors.UnsupportedOperatorError(
|
||||
domain,
|
||||
op_name,
|
||||
opset_version,
|
||||
symbolic_registry.get_op_supported_version(
|
||||
op_name, domain, opset_version
|
||||
),
|
||||
)
|
||||
|
||||
raise errors.UnsupportedOperatorError(
|
||||
domain,
|
||||
op_name,
|
||||
opset_version,
|
||||
symbolic_function_group.get_min_supported()
|
||||
if symbolic_function_group
|
||||
else None,
|
||||
)
|
||||
|
||||
except RuntimeError:
|
||||
if operator_export_type == _C_onnx.OperatorExportTypes.ONNX_FALLTHROUGH:
|
||||
return None
|
||||
@ -1869,31 +1835,26 @@ def _run_symbolic_function(
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def get_ns_op_name_from_custom_op(symbolic_name):
|
||||
if not bool(
|
||||
re.match(r"^[a-zA-Z0-9-_]*::[a-zA-Z-_]+[a-zA-Z0-9-_]*$", symbolic_name)
|
||||
):
|
||||
raise ValueError(
|
||||
f"Failed to register operator {symbolic_name}."
|
||||
"The symbolic name must match the format Domain::Name, "
|
||||
def _verify_custom_op_name(symbolic_name: str):
|
||||
if not re.match(r"^[a-zA-Z0-9-_]+::[a-zA-Z-_]+[a-zA-Z0-9-_]*$", symbolic_name):
|
||||
raise errors.OnnxExporterError(
|
||||
f"Failed to register operator {symbolic_name}. "
|
||||
"The symbolic name must match the format domain::name, "
|
||||
"and should start with a letter and contain only "
|
||||
"alphanumerical characters"
|
||||
)
|
||||
|
||||
ns, op_name = symbolic_name.split("::")
|
||||
ns, _ = symbolic_name.split("::")
|
||||
if ns == "onnx":
|
||||
raise ValueError(
|
||||
f"Failed to register operator {symbolic_name}. {ns} domain cannot be modified."
|
||||
)
|
||||
|
||||
if ns == "aten":
|
||||
ns = ""
|
||||
|
||||
return ns, op_name
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def register_custom_op_symbolic(symbolic_name, symbolic_fn, opset_version):
|
||||
def register_custom_op_symbolic(
|
||||
symbolic_name: str, symbolic_fn: Callable, opset_version: int
|
||||
):
|
||||
"""Registers a symbolic function for a custom operator.
|
||||
|
||||
When the user registers symbolic for custom/contrib ops,
|
||||
@ -1911,11 +1872,16 @@ def register_custom_op_symbolic(symbolic_name, symbolic_fn, opset_version):
|
||||
operator nodes to add to the graph.
|
||||
opset_version (int): The ONNX opset version in which to register.
|
||||
"""
|
||||
ns, op_name = get_ns_op_name_from_custom_op(symbolic_name)
|
||||
if symbolic_name.startswith("::"):
|
||||
symbolic_name = f"aten{symbolic_name}"
|
||||
|
||||
_verify_custom_op_name(symbolic_name)
|
||||
|
||||
for version in range(_constants.ONNX_MIN_OPSET, _constants.ONNX_MAX_OPSET + 1):
|
||||
if version >= opset_version:
|
||||
symbolic_registry.register_op(op_name, symbolic_fn, ns, version)
|
||||
registration.registry.register(
|
||||
symbolic_name, version, symbolic_fn, custom=True
|
||||
)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
@ -1929,11 +1895,14 @@ def unregister_custom_op_symbolic(symbolic_name: str, opset_version: int):
|
||||
format.
|
||||
opset_version (int): The ONNX opset version in which to unregister.
|
||||
"""
|
||||
ns, op_name = get_ns_op_name_from_custom_op(symbolic_name)
|
||||
if symbolic_name.startswith("::"):
|
||||
symbolic_name = f"aten{symbolic_name}"
|
||||
|
||||
_verify_custom_op_name(symbolic_name)
|
||||
|
||||
for version in range(_constants.ONNX_MIN_OPSET, _constants.ONNX_MAX_OPSET + 1):
|
||||
if version >= opset_version:
|
||||
symbolic_registry.unregister_op(op_name, ns, version)
|
||||
registration.registry.unregister(symbolic_name, version)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
|
Reference in New Issue
Block a user