[BE] typing for decorators - library (#138969)

Test Plan: unit tests

Differential Revision: D62302678

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138969
Approved by: https://github.com/zou3519
This commit is contained in:
Aaron Orenstein
2025-01-15 17:08:55 +00:00
committed by PyTorch MergeBot
parent c7a9599100
commit d782e46a36
14 changed files with 126 additions and 49 deletions

View File

@ -6,8 +6,22 @@ import re
import sys
import traceback
import weakref
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
from typing_extensions import deprecated
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Optional,
overload,
Sequence,
Set,
Tuple,
TYPE_CHECKING,
TypeVar,
Union,
)
from typing_extensions import deprecated, ParamSpec
import torch
import torch._library as _library
@ -38,6 +52,9 @@ __all__ = [
"infer_schema",
]
_T = TypeVar("_T")
_P = ParamSpec("_P")
# Set containing the combination of (namespace, operator, DispatchKey) for which a new kernel has been registered
# The keys in the set are of the form `namespace + "/" + op_name + "/" + dispatch_key`.
# This set is maintained to ensure that two libraries don't try to override the exact same functionality to avoid
@ -529,8 +546,43 @@ def _(lib: Library, schema, alias_analysis=""):
return wrap
@overload
def impl(
qualname: str,
types: Union[str, Sequence[str]],
func: Literal[None] = None,
*,
lib: Optional[Library] = None,
) -> Callable[[Callable[..., object]], None]: ...
@overload
def impl(
qualname: str,
types: Union[str, Sequence[str]],
func: Callable[..., object],
*,
lib: Optional[Library] = None,
) -> None: ...
# Deprecated BC API
@overload
def impl(
lib: Library,
name: str,
dispatch_key: str = "",
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: ...
@functools.singledispatch
def impl(qualname, types, func=None, *, lib=None):
def impl(
qualname: str,
types: Union[str, Sequence[str]],
func: Optional[Callable[_P, _T]] = None,
*,
lib: Optional[Library] = None,
) -> object:
"""Register an implementation for a device type for this operator.
You may pass "default" for ``types`` to register this implementation as the
@ -591,7 +643,52 @@ def impl(qualname, types, func=None, *, lib=None):
return _impl(qualname, types, func, lib=lib, disable_dynamo=False)
def _impl(qualname, types, func=None, *, lib=None, disable_dynamo=False):
if not TYPE_CHECKING:
@impl.register
def _(
lib: Library, name: str, dispatch_key: str = ""
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
"""Legacy torch.library.impl API. Kept around for BC"""
def wrap(f: Callable[_P, _T]) -> Callable[_P, _T]:
lib.impl(name, f, dispatch_key)
return f
return wrap
@overload
def _impl(
qualname: str,
types: Union[str, Sequence[str]],
func: Literal[None] = None,
*,
lib: Optional[Library] = None,
disable_dynamo: bool = False,
) -> Callable[[Callable[..., object]], None]: ...
@overload
def _impl(
qualname: str,
types: Union[str, Sequence[str]],
func: Callable[..., object],
*,
lib: Optional[Library] = None,
disable_dynamo: bool = False,
) -> None: ...
def _impl(
qualname: str,
types: Union[str, Sequence[str]],
func: Optional[Callable[..., object]] = None,
*,
lib: Optional[Library] = None,
disable_dynamo: bool = False,
) -> Optional[Callable[[Callable[..., object]], None]]:
# See impl()
if isinstance(types, str):
types = (types,)
keys = set({})
@ -608,7 +705,7 @@ def _impl(qualname, types, func=None, *, lib=None, disable_dynamo=False):
else:
keys.add(_device_type_to_key(typ))
def register(func):
def register_(func: Callable[..., object]) -> None:
namespace, _ = torch._library.utils.parse_namespace(qualname)
if lib is None:
@ -629,9 +726,10 @@ def _impl(qualname, types, func=None, *, lib=None, disable_dynamo=False):
use_lib.impl(qualname, func, key)
if func is None:
return register
return register_
else:
register(func)
register_(func)
return None
def _device_type_to_key(device_type: str) -> str:
@ -644,17 +742,6 @@ def _device_type_to_key(device_type: str) -> str:
return torch._C._dispatch_key_for_device(device_type)
@impl.register
def _(lib: Library, name, dispatch_key=""):
"""Legacy torch.library.impl API. Kept around for BC"""
def wrap(f):
lib.impl(name, f, dispatch_key)
return f
return wrap
@deprecated(
"`torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that "
"instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.",