Propagate callable parameter types using ParamSpec (#142306) (#144047)

Fixes #142306

This PR includes typing improvements and refactoring for the following files:
- __init__.py
- decorators.py
- _ops.py

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144047
Approved by: https://github.com/XuehaiPan, https://github.com/Skylion007

Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
Co-authored-by: Xuehai Pan <XuehaiPan@pku.edu.cn>
This commit is contained in:
yijun-lee
2025-01-06 16:16:16 +00:00
committed by PyTorch MergeBot
parent 9225f149eb
commit d4609af1ca
4 changed files with 59 additions and 40 deletions

View File

@ -6,24 +6,36 @@ import importlib
import inspect
import sys
import types
from typing import Any, Callable, Dict, List, Set, Type, TypeVar, Union
from typing_extensions import ParamSpec
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Set,
Type,
TYPE_CHECKING,
TypeVar,
Union,
)
from typing_extensions import Concatenate, ParamSpec
import torch
import torch.utils._pytree as pytree
from torch import _utils_internal
from torch._C import _dispatch_is_included_in_alias as is_included_in_alias, DispatchKey
from torch._functorch.pyfunctorch import dispatch_functorch
from torch._functorch.pyfunctorch import dispatch_functorch, TransformType
from torch.utils._python_dispatch import TorchDispatchMode
if TYPE_CHECKING:
from torch._subclasses.functional_tensor import BaseFunctionalizeAPI
_T = TypeVar("_T")
_P = ParamSpec("_P")
_F = TypeVar("_F", bound=Callable[..., Any])
# Query `hasattr` only once.
_SET_GLOBAL_FLAGS = hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags")
@ -112,11 +124,11 @@ class OperatorBase:
k: Union[
Type[TorchDispatchMode],
Type[torch.Tensor],
torch._C._functorch.TransformType,
torch._C.DispatchKey,
TransformType,
DispatchKey,
],
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
def inner(fn: _F) -> _F:
def inner(fn: Callable[_P, _T]) -> Callable[_P, _T]:
if inspect.isclass(k) and (
issubclass(k, TorchDispatchMode) or issubclass(k, torch.Tensor)
):
@ -126,7 +138,7 @@ class OperatorBase:
self._dispatch_cache.clear()
return fn
if isinstance(k, torch._C._functorch.TransformType):
if isinstance(k, TransformType):
assert k not in self.functorch_table
self.functorch_table[k] = fn
return fn
@ -134,7 +146,7 @@ class OperatorBase:
assert isinstance(k, DispatchKey)
assert (
k != DispatchKey.Python
), "Please register a mode for the torch._C.DispatchKey.Python key instead."
), "Please register a mode for the DispatchKey.Python key instead."
if k in self.py_kernels:
raise RuntimeError(
@ -157,31 +169,34 @@ class OperatorBase:
# with ctx.redispatch_to_next():
# out = ctx.functionalize(inner_f)(*args_unwrapped)
# return ctx.wrap_tensors(out)
def py_functionalize_impl(self, fn: _F) -> _F:
def py_functionalize_impl(
self, fn: Callable[Concatenate["BaseFunctionalizeAPI", _P], _T]
) -> Callable[Concatenate["BaseFunctionalizeAPI", _P], _T]:
from torch._subclasses.functional_tensor import (
CppFunctionalizeAPI as _CppFunctionalizeAPI,
FunctorchFunctionalizeAPI as _FunctorchFunctionalizeAPI,
PythonFunctionalizeAPI as _PythonFunctionalizeAPI,
CppFunctionalizeAPI,
FunctionalTensorMode,
FunctorchFunctionalizeAPI,
PythonFunctionalizeAPI,
)
# Construct our three flavors of functionalization,
# each of which have slightly different wrap/unwrap/redispatch policies
def functionalize_dk_fn(*args, **kwargs):
return fn(_CppFunctionalizeAPI(), *args, **kwargs)
def functionalize_dk_fn(*args: _P.args, **kwargs: _P.kwargs) -> _T:
return fn(CppFunctionalizeAPI(), *args, **kwargs)
def functionalize_dispatch_mode_fn(mode, *args, **kwargs):
return fn(_PythonFunctionalizeAPI(mode), *args, **kwargs)
def functionalize_dispatch_mode_fn(
mode: Optional[FunctionalTensorMode], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
return fn(PythonFunctionalizeAPI(mode), *args, **kwargs)
def functionalize_functorch_fn(interpreter, *args, **kwargs):
return fn(_FunctorchFunctionalizeAPI(interpreter), *args, **kwargs)
def functionalize_functorch_fn(
interpreter, *args: _P.args, **kwargs: _P.kwargs
) -> _T:
return fn(FunctorchFunctionalizeAPI(interpreter), *args, **kwargs)
self.py_impl(DispatchKey.Functionalize)(functionalize_dk_fn)
self.py_impl(torch._subclasses.functional_tensor.FunctionalTensorMode)(
functionalize_dispatch_mode_fn
)
self.py_impl(torch._C._functorch.TransformType.Functionalize)(
functionalize_functorch_fn
)
self.py_impl(FunctionalTensorMode)(functionalize_dispatch_mode_fn)
self.py_impl(TransformType.Functionalize)(functionalize_functorch_fn)
return fn
@ -294,7 +309,7 @@ class HigherOrderOperator(OperatorBase, abc.ABC):
k: Union[
Type[TorchDispatchMode],
Type[torch.Tensor],
torch._C._functorch.TransformType,
TransformType,
DispatchKey,
],
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
@ -408,7 +423,7 @@ class HigherOrderOperator(OperatorBase, abc.ABC):
curr_mode = _get_current_dispatch_mode_pre_dispatch()
assert (
curr_mode is not None
), "Illegal invocation of dispatch on torch._C.DispatchKey.PreDispatch without a mode."
), "Illegal invocation of dispatch on DispatchKey.PreDispatch without a mode."
assert (
type(curr_mode) in self.python_key_table
), f"Current active mode {curr_mode} not registered"
@ -817,7 +832,7 @@ class OpOverload(OperatorBase):
curr_mode = type(_get_current_dispatch_mode())
assert (
curr_mode is not None
), "Illegal invocation of dispatch on torch._C.DispatchKey.Python without a mode."
), "Illegal invocation of dispatch on DispatchKey.Python without a mode."
if curr_mode not in self.python_key_table:
if isinstance(self, TorchBindOpOverload):