Files
pytorch/torch/_library/simple_registry.py
Yuanyuan Chen a43c4c3972 [5/N] Apply ruff UP035 rule (#164423)
Continued code migration to enable ruff `UP035`. Most changes are about moving `Callable` from `typing` to `from collections.abc`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164423
Approved by: https://github.com/ezyang
2025-10-02 07:31:11 +00:00

89 lines
2.8 KiB
Python

from collections.abc import Callable
from typing import Any, Optional
from .fake_impl import FakeImplHolder
from .utils import RegistrationHandle
# Public names of this module, as seen by ``from ... import *``.
__all__ = ["SimpleLibraryRegistry", "SimpleOperatorEntry", "singleton"]
class SimpleLibraryRegistry:
    """Bookkeeping for the "simple" torch.library APIs.

    These APIs sit one level above the raw PyTorch DispatchKey registration
    APIs. Currently they cover:
    - fake impl

    Their registrations are not stored in the PyTorch dispatcher's table,
    because they do not necessarily correspond to a DispatchKey. The fake
    impl, for instance, is a plain Python function that FakeTensor calls.
    This class keeps track of them instead.

    Conceptually it is a map from a fully qualified operator name (overload
    included) to that overload's SimpleOperatorEntry.
    """

    def __init__(self) -> None:
        # fully qualified operator-overload name -> SimpleOperatorEntry
        self._data: dict[str, SimpleOperatorEntry] = {}

    def find(self, qualname: str) -> "SimpleOperatorEntry":
        """Return the entry for ``qualname``, creating it on first lookup.

        Args:
            qualname: fully qualified operator name, including the overload.

        Returns:
            The SimpleOperatorEntry stored for ``qualname``. A new one is
            created and stored when none exists yet, so this never returns
            None.
        """
        existing = self._data.get(qualname)
        if existing is not None:
            return existing
        created = SimpleOperatorEntry(qualname)
        self._data[qualname] = created
        return created
# Shared module-level registry instance. It is exported via ``__all__`` and
# consulted by ``find_torch_dispatch_rule``.
singleton: SimpleLibraryRegistry = SimpleLibraryRegistry()
class SimpleOperatorEntry:
    """Per-overload record: exactly one exists for each operator overload.

    Every attribute is a "holder" object into which kernels or rules for
    this overload can be registered.
    """

    def __init__(self, qualname: str) -> None:
        fake_holder = FakeImplHolder(qualname)
        dispatch_rules = GenericTorchDispatchRuleHolder(qualname)
        # Fully qualified name of the operator overload this entry describes.
        self.qualname: str = qualname
        # Holder for the fake implementation (the Python function FakeTensor invokes).
        self.fake_impl: FakeImplHolder = fake_holder
        # Holder for per-class ``__torch_dispatch__`` rules of this overload.
        self.torch_dispatch_rules: GenericTorchDispatchRuleHolder = dispatch_rules

    # Kept only for backward compatibility; slated for removal.
    @property
    def abstract_impl(self) -> FakeImplHolder:
        """Old name for ``fake_impl``; returns the very same holder object."""
        return self.fake_impl
class GenericTorchDispatchRuleHolder:
    """Holds the ``__torch_dispatch__`` rules of one operator overload.

    Maps a class to the rule callable registered for it. At most one rule
    may be registered per class at any time.
    """

    def __init__(self, qualname: str) -> None:
        # class -> rule callable registered for that class
        self._data: dict[type, Callable[..., Any]] = {}
        # Qualified operator-overload name; only used in error messages here.
        self.qualname: str = qualname

    def register(
        self, torch_dispatch_class: type, func: Callable[..., Any]
    ) -> RegistrationHandle:
        """Register ``func`` as the rule for ``torch_dispatch_class``.

        Args:
            torch_dispatch_class: class the rule is registered for.
            func: the rule callable.

        Returns:
            A RegistrationHandle wrapping a callback that removes this
            registration again.

        Raises:
            RuntimeError: if a rule is already registered for
                ``torch_dispatch_class``.
        """
        # Test membership rather than the truthiness of the stored callable:
        # a callable object may define ``__bool__``/``__len__`` and be falsy,
        # which would let a duplicate registration silently overwrite it.
        if torch_dispatch_class in self._data:
            raise RuntimeError(
                f"{torch_dispatch_class} already has a `__torch_dispatch__` rule registered for {self.qualname}"
            )
        self._data[torch_dispatch_class] = func

        def deregister() -> None:
            # Remove only the rule registered by *this* call. A stale callback
            # (invoked again after the slot was freed and possibly refilled
            # by a newer registration) must not delete someone else's rule
            # or raise KeyError.
            if self._data.get(torch_dispatch_class) is func:
                del self._data[torch_dispatch_class]

        return RegistrationHandle(deregister)

    def find(self, torch_dispatch_class: type) -> Optional[Callable[..., Any]]:
        """Return the rule registered for ``torch_dispatch_class``, or None."""
        return self._data.get(torch_dispatch_class, None)
def find_torch_dispatch_rule(
    op: Any, torch_dispatch_class: type
) -> Optional[Callable[..., Any]]:
    """Look up the ``__torch_dispatch__`` rule for ``op`` and a class.

    Args:
        op: operator object; its ``__qualname__`` is used as the key into
            the module-level ``singleton`` registry.
        torch_dispatch_class: class whose rule is requested.

    Returns:
        The registered rule callable, or None if none is registered.

    Note:
        ``singleton.find`` creates and stores an empty entry for ``op`` when
        none exists yet, so this lookup may add an entry to the registry.
    """
    entry = singleton.find(op.__qualname__)
    rules = entry.torch_dispatch_rules
    return rules.find(torch_dispatch_class)