[ONNX] New symbolic function registry (#84382)

## Summary

This change introduces a new registry for ONNX symbolic functions. The `SymbolicRegistry` class in `torch.onnx._internal.registration` replaces the dictionary and the various helper functions defined in `torch.onnx.symbolic_registry`.

The new registry (usage sketch below):

- Provides faster lookups by storing each function only under the opset version in which it is defined
- Is easier to manage and interact with thanks to its class design
- Lays the foundation for the more flexible registration process detailed in #83787
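
A minimal sketch of the new API (`my_relu_symbolic` is a hypothetical symbolic function; the registry calls are the ones added in `torch/onnx/_internal/registration.py` below):

```python
from torch.onnx._internal import registration

def my_relu_symbolic(g, self):
    # Hypothetical symbolic function, for illustration only.
    return g.op("Relu", self)

# Override the built-in aten::relu symbolic at opset 9.
registration.registry.register("aten::relu", 9, my_relu_symbolic, custom=True)

# Lookups dispatch to the closest suitable registered opset version:
# a request for opset 13 resolves to the opset-9 registration here.
group = registration.registry.get_function_group("aten::relu")
assert group is not None
assert group.get(13) is not None
assert registration.registry.is_registered_op("aten::relu", 13)

# Removing the override restores the built-in function.
registration.registry.unregister("aten::relu", 9)
```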

Implementation changes:

- **Breaking**: Remove `torch.onnx.symbolic_registry`
- `register_custom_op_symbolic` and `unregister_custom_op_symbolic` in `torch.onnx.utils` keep their APIs for backward compatibility (see the example after this list)
- Update `_onnx_supported_ops.py`, used for doc generation, to include quantized ops
- Update the code that registers Python ops in `torch/csrc/jit/passes/onnx.cpp`
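
A sketch of the unchanged public API (`mylib::foo` and its symbolic function are hypothetical):

```python
import torch.onnx

def foo_symbolic(g, input):
    # Illustrative only: lower the hypothetical custom op to an existing ONNX op.
    return g.op("Identity", input)

# Same signatures as before this change.
torch.onnx.register_custom_op_symbolic("mylib::foo", foo_symbolic, 9)
torch.onnx.unregister_custom_op_symbolic("mylib::foo", 9)
```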

## Profiling results

Execution time decreases by 0.1 seconds, and time spent in `_run_symbolic_function` drops by 34%. Measured on the AlexNet example from the public documentation (repro sketch below).
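
The call trees below are pyinstrument-style output. A repro sketch, assuming pyinstrument was the profiler and using the AlexNet example from the export docs:

```python
import torch
import torchvision
from pyinstrument import Profiler

# AlexNet example from the public ONNX export documentation.
model = torchvision.models.alexnet(pretrained=True).eval()
dummy_input = torch.randn(10, 3, 224, 224)

profiler = Profiler()
profiler.start()
torch.onnx.export(model, dummy_input, "alexnet.onnx")
profiler.stop()
print(profiler.output_text())
```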

### After
```
   └─ 1.641 export  <@beartype(torch.onnx.utils.export) at 0x7f19be17f790>:1
      └─ 1.641 export  torch/onnx/utils.py:185
         └─ 1.640 _export  torch/onnx/utils.py:1331
            ├─ 0.889 _model_to_graph  torch/onnx/utils.py:1005
            │  ├─ 0.478 _optimize_graph  torch/onnx/utils.py:535
            │  │  ├─ 0.214 PyCapsule._jit_pass_onnx_graph_shape_type_inference  <built-in>:0
            │  │  │     [2 frames hidden]  <built-in>
            │  │  ├─ 0.190 _run_symbolic_function  torch/onnx/utils.py:1670
            │  │  │  └─ 0.145 Constant  torch/onnx/symbolic_opset9.py:5782
            │  │  │     └─ 0.139 _graph_op  torch/onnx/_patch_torch.py:18
            │  │  │        └─ 0.134 PyCapsule._jit_pass_onnx_node_shape_type_inference  <built-in>:0
            │  │  │              [2 frames hidden]  <built-in>
            │  │  └─ 0.033 [self]
```

### Before
![image](https://user-images.githubusercontent.com/11205048/188032302-688d881e-860d-4046-bdba-90da54233576.png)

### Startup time

The startup process takes 0.03 seconds. The calls to `inspect` will be eliminated when we switch to decorator-based registration in #84448.

![image](https://user-images.githubusercontent.com/11205048/188208910-250f0434-475d-4872-9abc-781535519305.png)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84382
Approved by: https://github.com/AllenTiTaiWang, https://github.com/BowenBao
Commit: cd7e6d4ad1 (parent 735154354b)
Author: Justin Chu
Committed by: PyTorch MergeBot, 2022-09-16 17:30:24 +00:00
11 changed files with 616 additions and 323 deletions


@ -0,0 +1,168 @@
# Owner(s): ["module: onnx"]
"""Unit tests for the internal registration wrapper module."""
from typing import Sequence
from torch.onnx._internal import registration
from torch.testing._internal import common_utils
@common_utils.instantiate_parametrized_tests
class TestGlobalHelpers(common_utils.TestCase):
@common_utils.parametrize(
"available_opsets, target, expected",
[
((7, 8, 9, 10, 11), 16, 11),
((7, 8, 9, 10, 11), 11, 11),
((7, 8, 9, 10, 11), 10, 10),
((7, 8, 9, 10, 11), 9, 9),
((7, 8, 9, 10, 11), 8, 8),
((7, 8, 9, 10, 11), 7, 7),
((9, 10, 16), 16, 16),
((9, 10, 16), 15, 10),
((9, 10, 16), 10, 10),
((9, 10, 16), 9, 9),
((9, 10, 16), 8, 9),
((9, 10, 16), 7, 9),
((7, 9, 10, 16), 16, 16),
((7, 9, 10, 16), 10, 10),
((7, 9, 10, 16), 9, 9),
((7, 9, 10, 16), 8, 9),
((7, 9, 10, 16), 7, 7),
([17], 16, None), # New op added in 17
([9], 9, 9),
([9], 8, 9),
([], 16, None),
([], 9, None),
([], 8, None),
([8], 16, None), # Ops lower than 9 are not supported by versions >= 9
],
)
def test_dispatch_opset_version_returns_correct_version(
self, available_opsets: Sequence[int], target: int, expected: int
):
actual = registration._dispatch_opset_version(target, available_opsets)
self.assertEqual(actual, expected)
class TestOverrideDict(common_utils.TestCase):
def setUp(self):
self.override_dict: registration.OverrideDict[
str, int
] = registration.OverrideDict()
def test_get_item_returns_base_value_when_no_override(self):
self.override_dict.set_base("a", 42)
self.override_dict.set_base("b", 0)
self.assertEqual(self.override_dict["a"], 42)
self.assertEqual(self.override_dict["b"], 0)
self.assertEqual(len(self.override_dict), 2)
def test_get_item_returns_overridden_value_when_override(self):
self.override_dict.set_base("a", 42)
self.override_dict.set_base("b", 0)
self.override_dict.override("a", 100)
self.override_dict.override("c", 1)
self.assertEqual(self.override_dict["a"], 100)
self.assertEqual(self.override_dict["b"], 0)
self.assertEqual(self.override_dict["c"], 1)
self.assertEqual(len(self.override_dict), 3)
def test_get_item_raises_key_error_when_not_found(self):
self.override_dict.set_base("a", 42)
with self.assertRaises(KeyError):
self.override_dict["nonexistent_key"]
def test_get_returns_overridden_value_when_override(self):
self.override_dict.set_base("a", 42)
self.override_dict.set_base("b", 0)
self.override_dict.override("a", 100)
self.override_dict.override("c", 1)
self.assertEqual(self.override_dict.get("a"), 100)
self.assertEqual(self.override_dict.get("b"), 0)
self.assertEqual(self.override_dict.get("c"), 1)
self.assertEqual(len(self.override_dict), 3)
def test_get_returns_none_when_not_found(self):
self.override_dict.set_base("a", 42)
self.assertEqual(self.override_dict.get("nonexistent_key"), None)
def test_in_base_returns_true_for_base_value(self):
self.override_dict.set_base("a", 42)
self.override_dict.set_base("b", 0)
self.override_dict.override("a", 100)
self.override_dict.override("c", 1)
self.assertIn("a", self.override_dict)
self.assertIn("b", self.override_dict)
self.assertIn("c", self.override_dict)
self.assertTrue(self.override_dict.in_base("a"))
self.assertTrue(self.override_dict.in_base("b"))
self.assertFalse(self.override_dict.in_base("c"))
self.assertFalse(self.override_dict.in_base("nonexistent_key"))
def test_overridden_returns_true_for_overridden_value(self):
self.override_dict.set_base("a", 42)
self.override_dict.set_base("b", 0)
self.override_dict.override("a", 100)
self.override_dict.override("c", 1)
self.assertTrue(self.override_dict.overridden("a"))
self.assertFalse(self.override_dict.overridden("b"))
self.assertTrue(self.override_dict.overridden("c"))
self.assertFalse(self.override_dict.overridden("nonexistent_key"))
def test_remove_override_removes_overridden_value(self):
self.override_dict.set_base("a", 42)
self.override_dict.set_base("b", 0)
self.override_dict.override("a", 100)
self.override_dict.override("c", 1)
self.assertEqual(self.override_dict["a"], 100)
self.assertEqual(self.override_dict["c"], 1)
self.override_dict.remove_override("a")
self.override_dict.remove_override("c")
self.assertEqual(self.override_dict["a"], 42)
self.assertEqual(self.override_dict.get("c"), None)
self.assertFalse(self.override_dict.overridden("a"))
self.assertFalse(self.override_dict.overridden("c"))
def test_remove_override_removes_overridden_key(self):
self.override_dict.override("a", 100)
self.assertEqual(self.override_dict["a"], 100)
self.assertEqual(len(self.override_dict), 1)
self.override_dict.remove_override("a")
self.assertEqual(len(self.override_dict), 0)
self.assertNotIn("a", self.override_dict)
def test_overridden_key_precedes_base_key_regardless_of_insert_order(self):
self.override_dict.set_base("a", 42)
self.override_dict.override("a", 100)
self.override_dict.set_base("a", 0)
self.assertEqual(self.override_dict["a"], 100)
self.assertEqual(len(self.override_dict), 1)
def test_bool_is_true_when_not_empty(self):
if self.override_dict:
self.fail("OverrideDict should be false when empty")
self.override_dict.override("a", 1)
if not self.override_dict:
self.fail("OverrideDict should be true when not empty")
self.override_dict.set_base("a", 42)
if not self.override_dict:
self.fail("OverrideDict should be true when not empty")
self.override_dict.remove_override("a")
if not self.override_dict:
self.fail("OverrideDict should be true when not empty")
if __name__ == "__main__":
common_utils.run_tests()


@ -6,6 +6,7 @@ import contextlib
import io
import itertools
import unittest
import unittest.mock
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
import onnx
@ -14,14 +15,15 @@ import onnx.numpy_helper
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.onnx import symbolic_helper, symbolic_registry, utils
from torch.onnx import symbolic_helper, utils
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import registration
from torch.testing._internal import common_utils
def export_to_onnx(
model: Union[torch.nn.Module, torch.jit.ScriptFunction],
input: Tuple[torch.Tensor],
input: Union[torch.Tensor, Tuple[torch.Tensor]],
custom_ops: Optional[
Iterable[
Union[contextlib.AbstractContextManager, contextlib.ContextDecorator],
@ -436,14 +438,11 @@ class TestONNXExport(common_utils.TestCase):
def forward(self, x):
return torch.clamp(x, min=-0.5, max=0.5)
def break_is_registered_op_api(opname, domain, version):
fake_missing_symbolics = ("clamp",)
if opname in fake_missing_symbolics:
return False
return (
(domain, version) in symbolic_registry._registry
and opname in symbolic_registry._registry[(domain, version)]
)
def break_is_registered_op_api(name):
fake_missing_symbolics = {"aten::clamp"}
if name in fake_missing_symbolics:
return None
return registration.registry.get_function_group(name)
# Force missing symbolic for well-known op using a mock
onnx_model = export_to_onnx(
@ -451,7 +450,7 @@ class TestONNXExport(common_utils.TestCase):
torch.randn(3, 4, requires_grad=True),
mocks=[
unittest.mock.patch(
"torch.onnx.symbolic_registry.is_registered_op",
"torch.onnx._internal.registration.registry.get_function_group",
side_effect=break_is_registered_op_api,
)
],


@ -243,7 +243,8 @@ void NodeToONNX(
std::unordered_map<Value*, Value*>& env) {
py::object onnx = py::module::import("torch.onnx");
py::object onnx_globals = py::module::import("torch.onnx._globals");
py::object onnx_registry = py::module::import("torch.onnx.symbolic_registry");
py::object onnx_registration =
py::module::import("torch.onnx._internal.registration");
// Setup all the lambda helper functions.
@ -452,10 +453,13 @@ void NodeToONNX(
py::object opset_version =
onnx_globals.attr("GLOBALS").attr("export_onnx_opset_version");
py::object is_registered_op = onnx_registry.attr("is_registered_op")(
"PythonOp", "prim", opset_version);
if (!py::hasattr(pyobj, "symbolic") &&
(!PyObject_IsTrue(is_registered_op.ptr()))) {
// NOTE(justinchuby): Call the internal registry to register the symbolic
// method defined in the module.
bool is_registered_op =
onnx_registration.attr("registry")
.attr("is_registered_op")("prim::PythonOp", opset_version)
.cast<bool>();
if (!py::hasattr(pyobj, "symbolic") && !is_registered_op) {
// Inline the subgraph within the prim::PythonOp unless
// either of these conditions are satisfied
// 1. The torch.autograd.Function class of this node object has `symbolic`
@ -514,8 +518,16 @@ void NodeToONNX(
// Call the symbolic function
// Use a little trampoline function so we can give good error messages
// upon argument mismatch
onnx_registry.attr("register_op")(
op->name(), pyobj.attr("symbolic"), "", opset_version);
// Register as a custom operator
// TODO: Find a more elegant way to do this without having to touch
// internal Python modules.
// TODO(justinchuby): Define a namespace for these Python Ops.
onnx_registration.attr("registry")
.attr("register")(
"::" + op->name(),
opset_version,
pyobj.attr("symbolic"),
/* custom */ true);
py::object raw_output = onnx.attr("_run_symbolic_method")(
new_block->owningGraph(),
op->name(),
@ -524,7 +536,7 @@ void NodeToONNX(
processSymbolicOutput(op->name(), op, raw_output);
} else {
TORCH_INTERNAL_ASSERT(PyObject_IsTrue(is_registered_op.ptr()));
TORCH_INTERNAL_ASSERT(is_registered_op);
Node* n = static_cast<Node*>(op);
n->s_(attr::name, op->name());
// Call symbolic function
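
In Python terms, the registration performed by the pass above is roughly the following sketch (the op name and symbolic body are placeholders for what the pass reads off the `prim::PythonOp` node):

```python
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import registration

def python_op_symbolic(g, *inputs):
    # Placeholder for pyobj.attr("symbolic") on the autograd.Function.
    raise NotImplementedError

registration.registry.register(
    "::MyFunctionOp",                   # "::" + op->name(); no namespace yet, see TODO above
    GLOBALS.export_onnx_opset_version,  # the current export opset
    python_op_symbolic,
    custom=True,
)
```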


@ -9,6 +9,7 @@ from torch._C._onnx import (
TensorProtoDataType,
TrainingMode,
)
from torch.onnx._internal import registration as _registration
from . import ( # usort:skip. Keep the order instead of sorting lexicographically
_deprecation,
@ -25,7 +26,6 @@ from . import ( # usort:skip. Keep the order instead of sorting lexicographical
symbolic_opset14,
symbolic_opset15,
symbolic_opset16,
symbolic_registry,
utils,
)
from ._exporter_states import ExportTypes, SymbolicContext
@ -46,7 +46,6 @@ from .utils import (
__all__ = [
# Modules
"symbolic_helper",
"symbolic_registry",
"utils",
"errors",
# All opsets
@ -134,3 +133,6 @@ def log(*args) -> None:
character appended to the end, and flushed to output stream.
"""
_C._jit_onnx_log(*args)
_registration.discover_and_register_all_symbolic_opsets()


@ -2,6 +2,7 @@
ONNX_ARCHIVE_MODEL_PROTO_NAME = "__MODEL_PROTO"
ONNX_BASE_OPSET = 9
ONNX_MIN_OPSET = 7
ONNX_MAX_OPSET = 17
# ONNX_DEFAULT_OPSET generated by tools/onnx/update_default_opset_version.py


@ -0,0 +1,300 @@
"""Module for handling symbolic function registration."""
import importlib
import inspect
import warnings
from typing import Callable, Collection, Dict, Generic, Optional, Set, TypeVar
from torch.onnx import _constants, errors
OpsetVersion = int
def _dispatch_opset_version(
target: OpsetVersion, registered_opsets: Collection[OpsetVersion]
) -> Optional[OpsetVersion]:
"""Finds the registered opset given a target opset version and the available opsets.
Args:
target: The target opset version.
available_opsets: The available opsets.
Returns:
The registered opset version.
"""
if not registered_opsets:
return None
registered_versions = sorted(registered_opsets)
# Linear search for the opset version, which is fine since the number of opset
# versions is small.
# Always round toward opset 9 (ONNX_BASE_OPSET).
# Count down until opset 9 is reached.
for version in reversed(registered_versions):
if _constants.ONNX_BASE_OPSET <= version <= target:
return version
for version in registered_versions:
# Count back up until _constants.ONNX_BASE_OPSET
if target <= version <= _constants.ONNX_BASE_OPSET:
return version
assert (
not registered_versions
or _constants.ONNX_BASE_OPSET <= target < registered_versions[0]
or registered_versions[-1] < _constants.ONNX_BASE_OPSET <= target
)
return None
_K = TypeVar("_K")
_V = TypeVar("_V")
class OverrideDict(Generic[_K, _V], Collection[_K]):
"""A dictionary that merges built-in and custom symbolic functions.
It supports overriding and un-overriding built-in symbolic functions with custom
ones.
"""
def __init__(self):
self._base: Dict[_K, _V] = {}
self._overrides: Dict[_K, _V] = {}
self._merged: Dict[_K, _V] = {}
def override(self, key: _K, value: _V) -> None:
"""Overrides a base key-value with a new pair."""
self._overrides[key] = value
self._merged[key] = value
def remove_override(self, key: _K) -> None:
"""Un-overrides a key-value pair."""
self._overrides.pop(key, None) # type: ignore[arg-type]
self._merged.pop(key, None) # type: ignore[arg-type]
if key in self._base:
self._merged[key] = self._base[key]
def overridden(self, key: _K) -> bool:
"""Checks if a key-value pair is overridden."""
return key in self._overrides
def in_base(self, key: _K) -> bool:
"""Checks if a key is in the base dictionary."""
return key in self._base
def __getitem__(self, key: _K) -> _V:
return self._merged[key]
def set_base(self, key: _K, value: _V) -> None:
self._base[key] = value
if key not in self._overrides:
self._merged[key] = value
def get(self, key: _K, default: Optional[_V] = None):
return self._merged.get(key, default)
def __contains__(self, key: object) -> bool:
return key in self._merged
def __iter__(self):
return iter(self._merged)
def __len__(self) -> int:
return len(self._merged)
def __repr__(self) -> str:
return f"OverrideDict(base={self._base}, overrides={self._overrides})"
def __bool__(self) -> bool:
return bool(self._merged)
class _SymbolicFunctionGroup:
"""Different versions of symbolic functions registered to the same name.
O(number of registered versions of an op) search is performed to find the most
recent version of the op.
The registration is delayed until the op is used, to improve startup time.
Function overloads with different arguments are not allowed.
Custom op overrides are supported.
"""
def __init__(self, name: str) -> None:
self._name = name
# A dictionary of functions, keyed by the opset version.
self._functions: OverrideDict[OpsetVersion, Callable] = OverrideDict()
def __repr__(self) -> str:
return f"_SymbolicFunctionGroup({self._name}, registered={self._functions})"
def __getitem__(self, key: OpsetVersion) -> Callable:
result = self.get(key)
if result is None:
raise KeyError(key)
return result
# TODO(justinchuby): Add @functools.lru_cache(maxsize=None) if lookup time becomes
# a problem.
def get(self, opset: OpsetVersion) -> Optional[Callable]:
"""Find the most recent version of the function."""
version = _dispatch_opset_version(opset, self._functions)
if version is None:
return None
return self._functions[version]
def add(self, func: Callable, opset: OpsetVersion) -> None:
"""Adds a symbolic function.
Args:
func: The function to add.
opset: The opset version of the function to add.
"""
if self._functions.in_base(opset):
warnings.warn(
f"Symbolic function '{self._name}' already registered for opset {opset}. "
f"Replacing the existing function with new function. This is unexpected. "
f"Please report it on {_constants.PYTORCH_GITHUB_ISSUES_URL}.",
errors.OnnxExporterWarning,
)
self._functions.set_base(opset, func)
def add_custom(self, func: Callable, opset: OpsetVersion) -> None:
"""Adds a custom symbolic function.
Args:
func: The symbolic function to register.
opset: The corresponding opset version.
"""
self._functions.override(opset, func)
def remove_custom(self, opset: OpsetVersion) -> None:
"""Removes a custom symbolic function.
Args:
opset: The opset version of the custom function to remove.
"""
if not self._functions.overridden(opset):
warnings.warn(
f"No custom function registered for '{self._name}' opset {opset}"
)
return
self._functions.remove_override(opset)
def get_min_supported(self) -> OpsetVersion:
"""Returns the lowest built-in opset version supported by the function."""
return min(self._functions)
class SymbolicRegistry:
"""Registry for symbolic functions.
The registry maintains a mapping from qualified names to symbolic functions.
It is used to register new symbolic functions and to dispatch calls to
the appropriate function.
"""
def __init__(self) -> None:
self._registry: Dict[str, _SymbolicFunctionGroup] = {}
def register(
self, name: str, opset: OpsetVersion, func: Callable, custom: bool = False
) -> None:
"""Registers a symbolic function.
Args:
name: The qualified name of the function to register. In the form of 'domain::op'.
E.g. 'aten::add'.
opset: The opset version of the function to register.
func: The symbolic function to register.
custom: Whether the function is a custom function that overrides existing ones.
"""
if "::" not in name:
raise ValueError(
f"The name must be in the form of 'domain::op', not '{name}'"
)
symbolic_functions = self._registry.setdefault(
name, _SymbolicFunctionGroup(name)
)
if custom:
symbolic_functions.add_custom(func, opset)
else:
symbolic_functions.add(func, opset)
def unregister(self, name: str, opset: OpsetVersion) -> None:
"""Unregisters a symbolic function.
Args:
name: The qualified name of the function to unregister.
opset: The opset version of the function to unregister.
"""
if name not in self._registry:
return
self._registry[name].remove_custom(opset)
def get_function_group(self, name: str) -> Optional[_SymbolicFunctionGroup]:
"""Returns the function group for the given name."""
return self._registry.get(name)
def is_registered_op(self, name: str, version: int) -> bool:
"""Returns whether the given op is registered for the given opset version."""
functions = self.get_function_group(name)
if functions is None:
return False
return functions.get(version) is not None
def all_functions(self) -> Set[str]:
"""Returns the set of all registered function names."""
return set(self._registry)
def discover_and_register_all_symbolic_opsets() -> None:
"""Discover all symbolic functions.
Opset 9 is the base version. It is selected as the base version because
1. It is the first opset version supported by PyTorch export.
2. opset 9 is more robust than previous opset versions. Opset versions like 7/8 have limitations
that certain basic operators cannot be expressed in ONNX. Instead of basing on these limitations,
we chose to handle them as special cases separately.
Backward support for opset versions beyond opset 7 is not in our roadmap.
For opset versions other than 9, by default they will inherit the symbolic functions defined in
symbolic_opset9.py.
To extend support for updated operators in different opset versions on top of opset 9,
simply add the updated symbolic functions in the respective symbolic_opset{version}.py file.
Checkout topk in symbolic_opset10.py, and upsample_nearest2d in symbolic_opset8.py for example.
"""
for opset in range(_constants.ONNX_MIN_OPSET, _constants.ONNX_MAX_OPSET + 1):
module = importlib.import_module(f"torch.onnx.symbolic_opset{opset}")
_register_module(module, opset)
def _register_module(module, opset: OpsetVersion) -> None:
"""Registers all functions in the given module.
Args:
module: The module to register.
opset: The opset version to register.
"""
global registry
members = inspect.getmembers(module)
for name, obj in members:
if isinstance(obj, type) and hasattr(obj, "domain"):
# Symbolic functions in domains other than aten
ops = inspect.getmembers(obj, predicate=inspect.isfunction)
for op in ops:
registry.register(f"{obj.domain}::{op[0]}", opset, op[1]) # type: ignore[attr-defined]
elif inspect.isfunction(obj):
if name in {"_len", "_list", "_any", "_all"}:
name = name[1:]
registry.register(f"aten::{name}", opset, obj)
# The registry for all symbolic functions.
registry = SymbolicRegistry()
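
To make the dispatch rule concrete, the expected results below mirror the parametrized unit tests earlier in this commit:

```python
from torch.onnx._internal import registration

# At or above opset 9: pick the closest registered version at or below the target.
assert registration._dispatch_opset_version(16, (7, 8, 9, 10, 11)) == 11
assert registration._dispatch_opset_version(15, (9, 10, 16)) == 10

# Below opset 9: count back up toward 9 when no exact match exists.
assert registration._dispatch_opset_version(8, (9, 10, 16)) == 9
assert registration._dispatch_opset_version(7, (7, 9, 10, 16)) == 7

# No suitable version: ops added later, or registered only below opset 9.
assert registration._dispatch_opset_version(16, (17,)) is None
assert registration._dispatch_opset_version(16, (8,)) is None
```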


@ -2,10 +2,8 @@ import inspect
from typing import Dict, List, Union
from torch import _C
from torch.onnx import _constants, symbolic_registry
for v in range(_constants.ONNX_MIN_OPSET, _constants.ONNX_MAX_OPSET + 1):
symbolic_registry.register_version("", v)
from torch.onnx import _constants
from torch.onnx._internal import registration
class _TorchSchema:
@ -74,18 +72,27 @@ def _symbolic_argument_count(func):
return params
def _all_symbolics_schemas():
symbolics_schemas: Dict[str, _TorchSchema] = {}
def _all_symbolics_schemas() -> Dict[str, _TorchSchema]:
symbolics_schemas = {}
for name in registration.registry.all_functions():
func_group = registration.registry.get_function_group(name)
assert func_group is not None
symbolics_schema = _TorchSchema(name)
func = func_group.get(_constants.ONNX_MAX_OPSET)
if func is not None:
symbolics_schema.arguments = _symbolic_argument_count(func)
symbolics_schema.opsets = list(
range(func_group.get_min_supported(), _constants.ONNX_MAX_OPSET + 1)
)
else:
# Only support opset < 9
func = func_group.get(7)
symbolics_schema.arguments = _symbolic_argument_count(func)
symbolics_schema.opsets = list(range(7, _constants.ONNX_BASE_OPSET))
symbolics_schemas[name] = symbolics_schema
for domain, version in symbolic_registry._registry:
for opname, sym_func in symbolic_registry._registry[(domain, version)].items():
symbolics_schema = _TorchSchema("aten::" + opname)
symbolics_schema.arguments = _symbolic_argument_count(sym_func)
if opname in symbolics_schemas:
symbolics_schemas[opname].opsets.append(version)
else:
symbolics_schema.opsets = [version]
symbolics_schemas[opname] = symbolics_schema
return symbolics_schemas
@ -97,7 +104,7 @@ def onnx_supported_ops():
onnx_supported = []
for schema in aten_schemas:
if schema in torch_schemas:
opname = schema.name[6:] # without "aten::" prefix
opname = schema.name
opsets = symbolic_schemas[opname].opsets
if schema not in supported_ops:
supported_ops.append(symbolic_schemas[opname])


@ -1,38 +1,38 @@
import importlib
import inspect
from torch.onnx import symbolic_helper, symbolic_opset9 as opset9, symbolic_registry
from torch.onnx import symbolic_helper, symbolic_opset9 as opset9
from torch.onnx._internal import registration
def register_quantized_ops(domain: str, version: int):
# Register all the non-quantized ops
symbolic_registry.register_version("", version)
# Register all quantized ops
module = importlib.import_module("torch.onnx.symbolic_caffe2")
symbolic_registry._symbolic_versions["caffe2"] = module
quant_version_ops = inspect.getmembers(
symbolic_registry._symbolic_versions["caffe2"]
)
for op in quant_version_ops:
if inspect.isfunction(op[1]) and not symbolic_registry.is_registered_op(
op[0], domain, version
quant_version_ops = inspect.getmembers(module)
aten_q_ops = {
"relu",
"_empty_affine_quantized",
"dequantize",
"quantize_per_tensor",
"upsample_nearest2d",
"avg_pool2d",
"reshape",
"slice",
"cat",
"max_pool2d",
"sigmoid",
}
for op, func in quant_version_ops:
name = f"{domain}::{op}"
if inspect.isfunction(func) and not registration.registry.is_registered_op(
name, version
):
aten_q_ops = [
"relu",
"_empty_affine_quantized",
"dequantize",
"quantize_per_tensor",
"upsample_nearest2d",
"avg_pool2d",
"reshape",
"slice",
"cat",
"max_pool2d",
"sigmoid",
]
if op[0] in aten_q_ops:
symbolic_registry.register_op(op[0], op[1], "", version)
symbolic_registry.register_op(op[0], op[1], domain, version)
if op in aten_q_ops:
# Override the builtin aten ops
registration.registry.register(
f"aten::{op}", version, func, custom=True
)
registration.registry.register(name, version, func)
def _permute_helper(g, input, axes):


@ -23,6 +23,8 @@ from torch.onnx import symbolic_helper
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py
__all__ = ["layer_norm"]
@symbolic_helper.parse_args("v", "is", "v", "v", "f", "none")
def layer_norm(


@ -1,167 +0,0 @@
import importlib
import inspect
import warnings
from typing import Any, Callable, Dict, Tuple, Union
from torch import _C
from torch.onnx import _constants, errors
__all__ = [
"get_op_supported_version",
"get_ops_in_version",
"get_registered_op",
"is_registered_op",
"is_registered_version",
"register_op",
"register_ops_helper",
"register_ops_in_version",
"register_version",
"unregister_op",
]
_SymbolicFunction = Callable[..., Union[_C.Value, Tuple[_C.Value]]]
"""
The symbolic registry "_registry" is a dictionary that maps operators
(for a specific domain and opset version) to their symbolic functions.
An operator is defined by its domain, opset version, and opname.
The keys are tuples (domain, version), (where domain is a string, and version is an int),
and the operator's name (string).
The map's entries are as follows : _registry[(domain, version)][op_name] = op_symbolic
"""
_registry: Dict[
Tuple[str, int],
Dict[str, _SymbolicFunction],
] = {}
_symbolic_versions: Dict[Union[int, str], Any] = {}
def _import_symbolic_opsets():
for opset_version in range(
_constants.ONNX_MIN_OPSET, _constants.ONNX_MAX_OPSET + 1
):
module = importlib.import_module(f"torch.onnx.symbolic_opset{opset_version}")
global _symbolic_versions
_symbolic_versions[opset_version] = module
def register_version(domain: str, version: int):
if not is_registered_version(domain, version):
global _registry
_registry[(domain, version)] = {}
register_ops_in_version(domain, version)
def register_ops_helper(domain: str, version: int, iter_version: int):
for domain, op_name, op_func in get_ops_in_version(iter_version):
if not is_registered_op(op_name, domain, version):
register_op(op_name, op_func, domain, version)
def register_ops_in_version(domain: str, version: int):
"""Iterates through the symbolic functions of the specified opset version, and the
previous opset versions for operators supported in previous versions.
Opset 9 is the base version. It is selected as the base version because
1. It is the first opset version supported by PyTorch export.
2. opset 9 is more robust than previous opset versions. Opset versions like 7/8 have limitations
that certain basic operators cannot be expressed in ONNX. Instead of basing on these limitations,
we chose to handle them as special cases separately.
Backward support for opset versions beyond opset 7 is not in our roadmap.
For opset versions other than 9, by default they will inherit the symbolic functions defined in
symbolic_opset9.py.
To extend support for updated operators in different opset versions on top of opset 9,
simply add the updated symbolic functions in the respective symbolic_opset{version}.py file.
Checkout topk in symbolic_opset10.py, and upsample_nearest2d in symbolic_opset8.py for example.
"""
iter_version = version
while iter_version != 9:
register_ops_helper(domain, version, iter_version)
if iter_version > 9:
iter_version = iter_version - 1
else:
iter_version = iter_version + 1
register_ops_helper(domain, version, 9)
def get_ops_in_version(version: int):
if not _symbolic_versions:
_import_symbolic_opsets()
members = inspect.getmembers(_symbolic_versions[version])
domain_opname_ops = []
for obj in members:
if isinstance(obj[1], type) and hasattr(obj[1], "domain"):
ops = inspect.getmembers(obj[1], predicate=inspect.isfunction)
for op in ops:
domain_opname_ops.append((obj[1].domain, op[0], op[1])) # type: ignore[attr-defined]
elif inspect.isfunction(obj[1]):
if obj[0] == "_len":
obj = ("len", obj[1])
if obj[0] == "_list":
obj = ("list", obj[1])
if obj[0] == "_any":
obj = ("any", obj[1])
if obj[0] == "_all":
obj = ("all", obj[1])
domain_opname_ops.append(("", obj[0], obj[1]))
return domain_opname_ops
def is_registered_version(domain: str, version: int):
global _registry
return (domain, version) in _registry
def register_op(opname, op, domain, version):
if domain is None or version is None:
warnings.warn(
"ONNX export failed. The ONNX domain and/or version to register are None."
)
global _registry
if not is_registered_version(domain, version):
_registry[(domain, version)] = {}
_registry[(domain, version)][opname] = op
def is_registered_op(opname: str, domain: str, version: int):
if domain is None or version is None:
warnings.warn("ONNX export failed. The ONNX domain and/or version are None.")
global _registry
return (domain, version) in _registry and opname in _registry[(domain, version)]
def unregister_op(opname: str, domain: str, version: int):
global _registry
if is_registered_op(opname, domain, version):
del _registry[(domain, version)][opname]
if not _registry[(domain, version)]:
del _registry[(domain, version)]
else:
warnings.warn("The opname " + opname + " is not registered.")
def get_op_supported_version(opname: str, domain: str, version: int):
iter_version = version
while iter_version <= _constants.ONNX_MAX_OPSET:
ops = [(op[0], op[1]) for op in get_ops_in_version(iter_version)]
if (domain, opname) in ops:
return iter_version
iter_version += 1
return None
def get_registered_op(opname: str, domain: str, version: int) -> _SymbolicFunction:
if domain is None or version is None:
warnings.warn("ONNX export failed. The ONNX domain and/or version are None.")
global _registry
if not is_registered_op(opname, domain, version):
raise errors.UnsupportedOperatorError(
domain, opname, version, get_op_supported_version(opname, domain, version)
)
return _registry[(domain, version)][opname]


@ -43,10 +43,9 @@ from torch.onnx import ( # noqa: F401
errors,
symbolic_caffe2,
symbolic_helper,
symbolic_registry,
)
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import _beartype
from torch.onnx._internal import _beartype, registration
__all__ = [
"is_in_onnx_export",
@ -59,7 +58,6 @@ __all__ = [
"unpack_quantized_tensor",
"export_to_pretty_string",
"unconvertible_ops",
"get_ns_op_name_from_custom_op",
"register_custom_op_symbolic",
"unregister_custom_op_symbolic",
]
@ -1279,7 +1277,7 @@ def unconvertible_ops(
operator_export_type=_C_onnx.OperatorExportTypes.ONNX_FALLTHROUGH,
)
unsupported_ops = list()
supported_namespaces = ("onnx", "prim", "quantized")
supported_namespaces = {"onnx", "prim", "quantized"}
for node in graph.nodes():
if node.kind().split(":")[0] not in supported_namespaces:
unsupported_ops.append(node.kind())
@ -1690,37 +1688,10 @@ def _add_output_to_block(block: _C.Block, value: _C.Value):
@_beartype.beartype
def _find_symbolic_in_registry(
domain: str,
op_name: str,
opset_version: int,
operator_export_type: _C_onnx.OperatorExportTypes,
) -> Optional[Callable]:
"""Looks up for the symbolic function in the registry.
Args:
domain: The domain of the symbolic function.
op_name: The name of the op.
opset_version: Currect opset used.
operator_export_type: An enum in _C_onnx.OperatorExportTypes.
Returns:
The symbolic function if found, None otherwise.
"""
if not symbolic_registry.is_registered_op(op_name, domain, opset_version):
if operator_export_type == _C_onnx.OperatorExportTypes.ONNX_FALLTHROUGH:
# Use the original node directly
return None
return symbolic_registry.get_registered_op(op_name, domain, opset_version)
@_beartype.beartype
def _should_aten_fallback(ns, op_name, opset_version, operator_export_type):
is_exportable_aten_op = symbolic_registry.is_registered_op(
op_name, "", opset_version
)
def _should_aten_fallback(
name: str, opset_version: int, operator_export_type: _C_onnx.OperatorExportTypes
):
is_exportable_aten_op = registration.registry.is_registered_op(name, opset_version)
is_onnx_aten_export = operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN
is_aten_fallback_export = (
operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
@ -1787,64 +1758,59 @@ def _run_symbolic_function(
namespace, op_name = ns_op_name.split("::")
try:
symbolic_registry.register_version("", opset_version)
# Caffe2-specific: Quantized op symbolics are registered for opset 9 only.
if symbolic_helper.is_caffe2_aten_fallback() and opset_version == 9:
symbolic_caffe2.register_quantized_ops("caffe2", opset_version)
if namespace == "aten":
domain = ""
elif namespace == "quantized" and symbolic_helper.is_caffe2_aten_fallback():
if namespace == "quantized" and symbolic_helper.is_caffe2_aten_fallback():
domain = "caffe2"
else:
domain = namespace
symbolic_function_name = f"{domain}::{op_name}"
if symbolic_registry.is_registered_op(op_name, domain, opset_version):
symbolic_fn = _find_symbolic_in_registry(
domain, op_name, opset_version, operator_export_type
)
assert symbolic_fn is not None
symbolic_function_group = registration.registry.get_function_group(
symbolic_function_name
)
if symbolic_function_group is not None:
symbolic_fn = symbolic_function_group.get(opset_version)
if symbolic_fn is not None:
attrs = {k: symbolic_helper._node_get(n, k) for k in n.attributeNames()}
if _need_symbolic_context(symbolic_fn):
# TODO(justinchuby): Refactor how we check for the need of the symbolic context
ctx = _exporter_states.SymbolicContext(_params_dict, env, n, block)
return symbolic_fn(ctx, g, *inputs, **attrs)
# PythonOp symbolics need access to the node to resolve the name conflict;
# this is inconsistent with regular op symbolics.
if op_name == "PythonOp":
inputs = (n, *inputs)
return symbolic_fn(g, *inputs, **attrs)
attrs = {k: symbolic_helper._node_get(n, k) for k in n.attributeNames()}
if _need_symbolic_context(symbolic_fn):
ctx = _exporter_states.SymbolicContext(_params_dict, env, n, block)
return symbolic_fn(ctx, g, *inputs, **attrs)
# PythonOp symbolics need access to the node to resolve the name conflict;
# this is inconsistent with regular op symbolics.
if op_name == "PythonOp":
inputs = (n, *inputs)
return symbolic_fn(g, *inputs, **attrs)
elif namespace == "onnx":
attrs = {
k + "_" + n.kindOf(k)[0]: symbolic_helper._node_get(n, k)
for k in n.attributeNames()
}
if namespace == "onnx":
# Clone node to trigger ONNX shape inference
attrs = {
k + "_" + n.kindOf(k)[0]: symbolic_helper._node_get(n, k)
for k in n.attributeNames()
}
return g.op(op_name, *inputs, **attrs, outputs=n.outputsSize()) # type: ignore[attr-defined]
elif _should_aten_fallback(
namespace, op_name, opset_version, operator_export_type
):
if _should_aten_fallback(ns_op_name, opset_version, operator_export_type):
# Direct ATen export requested
attrs = {
k + "_" + n.kindOf(k)[0]: symbolic_helper._node_get(n, k)
for k in n.attributeNames()
}
outputs = n.outputsSize()
attrs["outputs"] = outputs
# `overload_name` is set for non-Caffe2 builds only
return g.at( # type: ignore[attr-defined]
op_name, *inputs, overload_name=_get_aten_op_overload_name(n), **attrs
)
else:
raise errors.UnsupportedOperatorError(
domain,
op_name,
opset_version,
symbolic_registry.get_op_supported_version(
op_name, domain, opset_version
),
)
raise errors.UnsupportedOperatorError(
domain,
op_name,
opset_version,
symbolic_function_group.get_min_supported()
if symbolic_function_group
else None,
)
except RuntimeError:
if operator_export_type == _C_onnx.OperatorExportTypes.ONNX_FALLTHROUGH:
return None
@ -1869,31 +1835,26 @@ def _run_symbolic_function(
@_beartype.beartype
def get_ns_op_name_from_custom_op(symbolic_name):
if not bool(
re.match(r"^[a-zA-Z0-9-_]*::[a-zA-Z-_]+[a-zA-Z0-9-_]*$", symbolic_name)
):
raise ValueError(
f"Failed to register operator {symbolic_name}."
"The symbolic name must match the format Domain::Name, "
def _verify_custom_op_name(symbolic_name: str):
if not re.match(r"^[a-zA-Z0-9-_]+::[a-zA-Z-_]+[a-zA-Z0-9-_]*$", symbolic_name):
raise errors.OnnxExporterError(
f"Failed to register operator {symbolic_name}. "
"The symbolic name must match the format domain::name, "
"and should start with a letter and contain only "
"alphanumerical characters"
)
ns, op_name = symbolic_name.split("::")
ns, _ = symbolic_name.split("::")
if ns == "onnx":
raise ValueError(
f"Failed to register operator {symbolic_name}. {ns} domain cannot be modified."
)
if ns == "aten":
ns = ""
return ns, op_name
@_beartype.beartype
def register_custom_op_symbolic(symbolic_name, symbolic_fn, opset_version):
def register_custom_op_symbolic(
symbolic_name: str, symbolic_fn: Callable, opset_version: int
):
"""Registers a symbolic function for a custom operator.
When the user registers symbolic for custom/contrib ops,
@ -1911,11 +1872,16 @@ def register_custom_op_symbolic(symbolic_name, symbolic_fn, opset_version):
operator nodes to add to the graph.
opset_version (int): The ONNX opset version in which to register.
"""
ns, op_name = get_ns_op_name_from_custom_op(symbolic_name)
if symbolic_name.startswith("::"):
symbolic_name = f"aten{symbolic_name}"
_verify_custom_op_name(symbolic_name)
for version in range(_constants.ONNX_MIN_OPSET, _constants.ONNX_MAX_OPSET + 1):
if version >= opset_version:
symbolic_registry.register_op(op_name, symbolic_fn, ns, version)
registration.registry.register(
symbolic_name, version, symbolic_fn, custom=True
)
@_beartype.beartype
@ -1929,11 +1895,14 @@ def unregister_custom_op_symbolic(symbolic_name: str, opset_version: int):
format.
opset_version (int): The ONNX opset version in which to unregister.
"""
ns, op_name = get_ns_op_name_from_custom_op(symbolic_name)
if symbolic_name.startswith("::"):
symbolic_name = f"aten{symbolic_name}"
_verify_custom_op_name(symbolic_name)
for version in range(_constants.ONNX_MIN_OPSET, _constants.ONNX_MAX_OPSET + 1):
if version >= opset_version:
symbolic_registry.unregister_op(op_name, ns, version)
registration.registry.unregister(symbolic_name, version)
@_beartype.beartype