mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[BE] typing for decorators - library (#138969)
Test Plan: unit tests Differential Revision: D62302678 Pull Request resolved: https://github.com/pytorch/pytorch/pull/138969 Approved by: https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
c7a9599100
commit
d782e46a36
123
torch/library.py
123
torch/library.py
@ -6,8 +6,22 @@ import re
|
||||
import sys
|
||||
import traceback
|
||||
import weakref
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
|
||||
from typing_extensions import deprecated
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
overload,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
TYPE_CHECKING,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
from typing_extensions import deprecated, ParamSpec
|
||||
|
||||
import torch
|
||||
import torch._library as _library
|
||||
@ -38,6 +52,9 @@ __all__ = [
|
||||
"infer_schema",
|
||||
]
|
||||
|
||||
_T = TypeVar("_T")
|
||||
_P = ParamSpec("_P")
|
||||
|
||||
# Set containing the combination of (namespace, operator, DispatchKey) for which a new kernel has been registered
|
||||
# The keys in the set are of the form `namespace + "/" + op_name + "/" + dispatch_key`.
|
||||
# This set is maintained to ensure that two libraries don't try to override the exact same functionality to avoid
|
||||
@ -529,8 +546,43 @@ def _(lib: Library, schema, alias_analysis=""):
|
||||
return wrap
|
||||
|
||||
|
||||
@overload
|
||||
def impl(
|
||||
qualname: str,
|
||||
types: Union[str, Sequence[str]],
|
||||
func: Literal[None] = None,
|
||||
*,
|
||||
lib: Optional[Library] = None,
|
||||
) -> Callable[[Callable[..., object]], None]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def impl(
|
||||
qualname: str,
|
||||
types: Union[str, Sequence[str]],
|
||||
func: Callable[..., object],
|
||||
*,
|
||||
lib: Optional[Library] = None,
|
||||
) -> None: ...
|
||||
|
||||
|
||||
# Deprecated BC API
|
||||
@overload
|
||||
def impl(
|
||||
lib: Library,
|
||||
name: str,
|
||||
dispatch_key: str = "",
|
||||
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: ...
|
||||
|
||||
|
||||
@functools.singledispatch
|
||||
def impl(qualname, types, func=None, *, lib=None):
|
||||
def impl(
|
||||
qualname: str,
|
||||
types: Union[str, Sequence[str]],
|
||||
func: Optional[Callable[_P, _T]] = None,
|
||||
*,
|
||||
lib: Optional[Library] = None,
|
||||
) -> object:
|
||||
"""Register an implementation for a device type for this operator.
|
||||
|
||||
You may pass "default" for ``types`` to register this implementation as the
|
||||
@ -591,7 +643,52 @@ def impl(qualname, types, func=None, *, lib=None):
|
||||
return _impl(qualname, types, func, lib=lib, disable_dynamo=False)
|
||||
|
||||
|
||||
def _impl(qualname, types, func=None, *, lib=None, disable_dynamo=False):
|
||||
if not TYPE_CHECKING:
|
||||
|
||||
@impl.register
|
||||
def _(
|
||||
lib: Library, name: str, dispatch_key: str = ""
|
||||
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
|
||||
"""Legacy torch.library.impl API. Kept around for BC"""
|
||||
|
||||
def wrap(f: Callable[_P, _T]) -> Callable[_P, _T]:
|
||||
lib.impl(name, f, dispatch_key)
|
||||
return f
|
||||
|
||||
return wrap
|
||||
|
||||
|
||||
@overload
|
||||
def _impl(
|
||||
qualname: str,
|
||||
types: Union[str, Sequence[str]],
|
||||
func: Literal[None] = None,
|
||||
*,
|
||||
lib: Optional[Library] = None,
|
||||
disable_dynamo: bool = False,
|
||||
) -> Callable[[Callable[..., object]], None]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def _impl(
|
||||
qualname: str,
|
||||
types: Union[str, Sequence[str]],
|
||||
func: Callable[..., object],
|
||||
*,
|
||||
lib: Optional[Library] = None,
|
||||
disable_dynamo: bool = False,
|
||||
) -> None: ...
|
||||
|
||||
|
||||
def _impl(
|
||||
qualname: str,
|
||||
types: Union[str, Sequence[str]],
|
||||
func: Optional[Callable[..., object]] = None,
|
||||
*,
|
||||
lib: Optional[Library] = None,
|
||||
disable_dynamo: bool = False,
|
||||
) -> Optional[Callable[[Callable[..., object]], None]]:
|
||||
# See impl()
|
||||
if isinstance(types, str):
|
||||
types = (types,)
|
||||
keys = set({})
|
||||
@ -608,7 +705,7 @@ def _impl(qualname, types, func=None, *, lib=None, disable_dynamo=False):
|
||||
else:
|
||||
keys.add(_device_type_to_key(typ))
|
||||
|
||||
def register(func):
|
||||
def register_(func: Callable[..., object]) -> None:
|
||||
namespace, _ = torch._library.utils.parse_namespace(qualname)
|
||||
|
||||
if lib is None:
|
||||
@ -629,9 +726,10 @@ def _impl(qualname, types, func=None, *, lib=None, disable_dynamo=False):
|
||||
use_lib.impl(qualname, func, key)
|
||||
|
||||
if func is None:
|
||||
return register
|
||||
return register_
|
||||
else:
|
||||
register(func)
|
||||
register_(func)
|
||||
return None
|
||||
|
||||
|
||||
def _device_type_to_key(device_type: str) -> str:
|
||||
@ -644,17 +742,6 @@ def _device_type_to_key(device_type: str) -> str:
|
||||
return torch._C._dispatch_key_for_device(device_type)
|
||||
|
||||
|
||||
@impl.register
|
||||
def _(lib: Library, name, dispatch_key=""):
|
||||
"""Legacy torch.library.impl API. Kept around for BC"""
|
||||
|
||||
def wrap(f):
|
||||
lib.impl(name, f, dispatch_key)
|
||||
return f
|
||||
|
||||
return wrap
|
||||
|
||||
|
||||
@deprecated(
|
||||
"`torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that "
|
||||
"instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.",
|
||||
|
Reference in New Issue
Block a user