Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 13:44:15 +08:00)
Fixes #142306

This PR includes typing improvements and refactoring for the following files:
- __init__.py
- decorators.py
- _ops.py

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144047
Approved by: https://github.com/XuehaiPan, https://github.com/Skylion007

Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
Co-authored-by: Xuehai Pan <XuehaiPan@pku.edu.cn>
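For context, the central pattern these annotations rely on is `ParamSpec`: it lets a decorator declare that it returns a callable with exactly the parameters and return type of the function it wraps. A minimal sketch of the idea (the `log_calls` and `add` names are illustrative, not from this PR):

```python
from typing import Callable, TypeVar
from typing_extensions import ParamSpec

_P = ParamSpec("_P")
_T = TypeVar("_T")

def log_calls(fn: Callable[_P, _T]) -> Callable[_P, _T]:
    # The wrapper keeps fn's exact parameter list and return type, so
    # type checkers still validate call sites against the original signature.
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
        print(f"calling {fn.__name__}")
        return fn(*args, **kwargs)

    return wrapper

@log_calls
def add(x: int, y: int) -> int:
    return x + y

print(add(1, 2))    # OK -> 3
# add("1", 2)       # rejected by a type checker: x must be int
```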
Committed by: PyTorch MergeBot
Parent: 9225f149eb
Commit: d4609af1ca
torch/_ops.py:
@@ -6,24 +6,36 @@ import importlib
 import inspect
 import sys
 import types
-from typing import Any, Callable, Dict, List, Set, Type, TypeVar, Union
-from typing_extensions import ParamSpec
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Set,
+    Type,
+    TYPE_CHECKING,
+    TypeVar,
+    Union,
+)
+from typing_extensions import Concatenate, ParamSpec
 
 import torch
 import torch.utils._pytree as pytree
 from torch import _utils_internal
 from torch._C import _dispatch_is_included_in_alias as is_included_in_alias, DispatchKey
-from torch._functorch.pyfunctorch import dispatch_functorch
+from torch._functorch.pyfunctorch import dispatch_functorch, TransformType
 from torch.utils._python_dispatch import TorchDispatchMode
 
 
+if TYPE_CHECKING:
+    from torch._subclasses.functional_tensor import BaseFunctionalizeAPI
+
+
 _T = TypeVar("_T")
 _P = ParamSpec("_P")
 
 
-_F = TypeVar("_F", bound=Callable[..., Any])
-
-
 # Query `hasattr` only once.
 _SET_GLOBAL_FLAGS = hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags")
 
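The `if TYPE_CHECKING:` block introduced above makes `BaseFunctionalizeAPI` available to annotations without importing it at runtime, the standard way to avoid circular or costly imports. A generic sketch of the pattern (`heavy_module` and `SomeType` are hypothetical):

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by static type checkers; never executed at runtime,
    # so it cannot create a circular or expensive import.
    from heavy_module import SomeType  # hypothetical module

def consume(x: "SomeType") -> None:
    # The quoted annotation is fine: the name need not exist at runtime.
    print(x)
```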
@@ -112,11 +124,11 @@ class OperatorBase:
         k: Union[
             Type[TorchDispatchMode],
             Type[torch.Tensor],
-            torch._C._functorch.TransformType,
-            torch._C.DispatchKey,
+            TransformType,
+            DispatchKey,
         ],
     ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
-        def inner(fn: _F) -> _F:
+        def inner(fn: Callable[_P, _T]) -> Callable[_P, _T]:
             if inspect.isclass(k) and (
                 issubclass(k, TorchDispatchMode) or issubclass(k, torch.Tensor)
             ):
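Here `inner` moves from the loose `_F` (a `TypeVar` bound to `Callable[..., Any]`, where `...` erases parameter types) to `Callable[_P, _T]`, matching the `Callable[[Callable[_P, _T]], Callable[_P, _T]]` return type that `py_impl` declares. A toy sketch of that decorator-factory shape, assuming a hypothetical `Registry` class in place of `OperatorBase`:

```python
from typing import Any, Callable, Dict, TypeVar
from typing_extensions import ParamSpec

_P = ParamSpec("_P")
_T = TypeVar("_T")

class Registry:
    """Hypothetical stand-in for the py_impl decorator-factory shape."""

    def __init__(self) -> None:
        self.kernels: Dict[str, Callable[..., Any]] = {}

    def register(self, key: str) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
        # Typing inner with the same _P/_T keeps it consistent with the
        # declared return type and preserves the decorated signature.
        def inner(fn: Callable[_P, _T]) -> Callable[_P, _T]:
            self.kernels[key] = fn
            return fn

        return inner

reg = Registry()

@reg.register("add")
def add(x: int, y: int) -> int:
    return x + y

assert add(2, 3) == 5  # signature survives registration
```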
@@ -126,7 +138,7 @@ class OperatorBase:
                 self._dispatch_cache.clear()
                 return fn
 
-            if isinstance(k, torch._C._functorch.TransformType):
+            if isinstance(k, TransformType):
                 assert k not in self.functorch_table
                 self.functorch_table[k] = fn
                 return fn
@@ -134,7 +146,7 @@ class OperatorBase:
             assert isinstance(k, DispatchKey)
             assert (
                 k != DispatchKey.Python
-            ), "Please register a mode for the torch._C.DispatchKey.Python key instead."
+            ), "Please register a mode for the DispatchKey.Python key instead."
 
             if k in self.py_kernels:
                 raise RuntimeError(
@@ -157,31 +169,34 @@ class OperatorBase:
     #     with ctx.redispatch_to_next():
     #         out = ctx.functionalize(inner_f)(*args_unwrapped)
    #         return ctx.wrap_tensors(out)
-    def py_functionalize_impl(self, fn: _F) -> _F:
+    def py_functionalize_impl(
+        self, fn: Callable[Concatenate["BaseFunctionalizeAPI", _P], _T]
+    ) -> Callable[Concatenate["BaseFunctionalizeAPI", _P], _T]:
         from torch._subclasses.functional_tensor import (
-            CppFunctionalizeAPI as _CppFunctionalizeAPI,
-            FunctorchFunctionalizeAPI as _FunctorchFunctionalizeAPI,
-            PythonFunctionalizeAPI as _PythonFunctionalizeAPI,
+            CppFunctionalizeAPI,
+            FunctionalTensorMode,
+            FunctorchFunctionalizeAPI,
+            PythonFunctionalizeAPI,
         )
 
         # Construct our three flavors of functionalization,
         # each of which have slightly different wrap/unwrap/redispatch policies
-        def functionalize_dk_fn(*args, **kwargs):
-            return fn(_CppFunctionalizeAPI(), *args, **kwargs)
+        def functionalize_dk_fn(*args: _P.args, **kwargs: _P.kwargs) -> _T:
+            return fn(CppFunctionalizeAPI(), *args, **kwargs)
 
-        def functionalize_dispatch_mode_fn(mode, *args, **kwargs):
-            return fn(_PythonFunctionalizeAPI(mode), *args, **kwargs)
+        def functionalize_dispatch_mode_fn(
+            mode: Optional[FunctionalTensorMode], *args: _P.args, **kwargs: _P.kwargs
+        ) -> _T:
+            return fn(PythonFunctionalizeAPI(mode), *args, **kwargs)
 
-        def functionalize_functorch_fn(interpreter, *args, **kwargs):
-            return fn(_FunctorchFunctionalizeAPI(interpreter), *args, **kwargs)
+        def functionalize_functorch_fn(
+            interpreter, *args: _P.args, **kwargs: _P.kwargs
+        ) -> _T:
+            return fn(FunctorchFunctionalizeAPI(interpreter), *args, **kwargs)
 
         self.py_impl(DispatchKey.Functionalize)(functionalize_dk_fn)
-        self.py_impl(torch._subclasses.functional_tensor.FunctionalTensorMode)(
-            functionalize_dispatch_mode_fn
-        )
-        self.py_impl(torch._C._functorch.TransformType.Functionalize)(
-            functionalize_functorch_fn
-        )
+        self.py_impl(FunctionalTensorMode)(functionalize_dispatch_mode_fn)
+        self.py_impl(TransformType.Functionalize)(functionalize_functorch_fn)
 
         return fn
 
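`Concatenate["BaseFunctionalizeAPI", _P]` reads as: `fn` takes a functionalize-API object first, followed by whatever parameters the eventual caller passes. The wrappers then inject that first argument themselves, so their own signatures are plain `_P`. A minimal sketch of the `Concatenate` pattern (`Ctx` and `with_ctx` are illustrative names):

```python
from typing import Callable, TypeVar
from typing_extensions import Concatenate, ParamSpec

_P = ParamSpec("_P")
_T = TypeVar("_T")

class Ctx:
    """Illustrative stand-in for BaseFunctionalizeAPI."""

def with_ctx(fn: Callable[Concatenate[Ctx, _P], _T]) -> Callable[_P, _T]:
    # The returned callable exposes fn's parameters minus the leading Ctx,
    # which the wrapper constructs and injects itself.
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
        return fn(Ctx(), *args, **kwargs)

    return wrapper

@with_ctx
def scale(ctx: Ctx, x: float, factor: float) -> float:
    return x * factor

print(scale(2.0, factor=3.0))  # Ctx is supplied automatically -> 6.0
```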
@@ -294,7 +309,7 @@ class HigherOrderOperator(OperatorBase, abc.ABC):
         k: Union[
             Type[TorchDispatchMode],
             Type[torch.Tensor],
-            torch._C._functorch.TransformType,
+            TransformType,
             DispatchKey,
         ],
     ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
@@ -408,7 +423,7 @@ class HigherOrderOperator(OperatorBase, abc.ABC):
             curr_mode = _get_current_dispatch_mode_pre_dispatch()
             assert (
                 curr_mode is not None
-            ), "Illegal invocation of dispatch on torch._C.DispatchKey.PreDispatch without a mode."
+            ), "Illegal invocation of dispatch on DispatchKey.PreDispatch without a mode."
             assert (
                 type(curr_mode) in self.python_key_table
             ), f"Current active mode {curr_mode} not registered"
@@ -817,7 +832,7 @@ class OpOverload(OperatorBase):
             curr_mode = type(_get_current_dispatch_mode())
             assert (
                 curr_mode is not None
-            ), "Illegal invocation of dispatch on torch._C.DispatchKey.Python without a mode."
+            ), "Illegal invocation of dispatch on DispatchKey.Python without a mode."
 
             if curr_mode not in self.python_key_table:
                 if isinstance(self, TorchBindOpOverload):