[BE] typing for decorators - jit/_decompositions (#131566)

See #131429
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131566
Approved by: https://github.com/oulgen, https://github.com/zou3519
Aaron Orenstein, 2024-07-24 09:54:26 -07:00 (committed by PyTorch MergeBot)
parent 2b83e4f8d7
commit 44fdf24967
5 changed files with 25 additions and 24 deletions

View File

@@ -3,7 +3,8 @@ import inspect
from collections import defaultdict
from functools import wraps
from itertools import chain
-from typing import Callable, Dict, List, Sequence, Union
+from typing import Callable, Dict, List, Sequence, TypeVar, Union
+from typing_extensions import ParamSpec
import torch
import torch.library
@@ -20,6 +21,8 @@ __all__ = [
"core_aten_decompositions",
]
+_T = TypeVar("_T")
+_P = ParamSpec("_P")
# TODO: relax key type here; torch registrations should be possible to; but
# right now this type is accurate
@@ -145,7 +148,7 @@ def _convert_out_params(f):
def register_decomposition(
aten_op, registry=None, *, type="post_autograd", unsafe=False
-):
+) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
"""
A decorator to register a function as a decomposition to the Python
decomposition table. Use it like this::
@@ -171,7 +174,7 @@ def register_decomposition(
assert type in {"post_autograd", "pre_autograd", "meta"}
-def decomposition_decorator(fn: Callable) -> Callable:
+def decomposition_decorator(fn: Callable[_P, _T]) -> Callable[_P, _T]:
orig_fn = fn
if not unsafe:
fn = _convert_out_params(fn)
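
The changes in this file follow the standard ParamSpec/TypeVar recipe for typing a decorator factory: the factory returns Callable[[Callable[_P, _T]], Callable[_P, _T]], so the decorated decomposition keeps its exact signature for type checkers instead of collapsing to a bare Callable. A minimal self-contained sketch of that recipe (the register/_registry names are illustrative, not from this diff):

from typing import Callable, Dict, TypeVar
from typing_extensions import ParamSpec

_T = TypeVar("_T")
_P = ParamSpec("_P")

_registry: Dict[str, Callable[..., object]] = {}

def register(name: str) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
    # Decorator factory: taking and returning Callable[_P, _T] preserves the
    # wrapped function's parameter list and return type for type checkers.
    def decorator(fn: Callable[_P, _T]) -> Callable[_P, _T]:
        _registry[name] = fn
        return fn
    return decorator

@register("add_one")
def add_one(x: int) -> int:
    return x + 1

# mypy now sees add_one with its original (x: int) -> int signature
# rather than as an untyped Callable.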

View File

@@ -657,7 +657,7 @@ noop_registry: Dict[Any, Any] = {}
def register_noop_decomp(targets, nop_arg=0):
def register_fun(cond):
register_decomposition(targets, registry=noop_registry, unsafe=True)(
-(cond, nop_arg)
+(cond, nop_arg)  # type: ignore[arg-type]
)
return cond
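
Context for the type: ignore[arg-type] above: register_noop_decomp deliberately passes the (cond, nop_arg) tuple to the decorator returned by register_decomposition (the no-op registry stores that pair rather than a decomposed function). This still works at runtime, but now that the returned decorator is typed to accept a Callable[_P, _T], mypy flags the call with an arg-type error along these lines (message text approximate):

# error: Argument 1 has incompatible type "tuple[...]"; expected "Callable[...]"  [arg-type]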

View File

@@ -516,7 +516,7 @@ def _make_inplace(fn):
inplace_name = f"{fn.__name__}_"
_fn.__name__ = inplace_name
-_fn = register_decomposition(getattr(aten, inplace_name))(_fn)
+_fn = register_decomposition(getattr(aten, inplace_name))(_fn)  # type: ignore[assignment]
# We access the __all__ attribute of the module where fn is defined
# There may be a cleaner way of doing this...
@@ -993,7 +993,7 @@ def view_as_complex(self: TensorLikeType) -> TensorLikeType:
)
dims = old_strides[:-1]
torch._check(
-py_all(stride % 2 == 0 for stride in dims),
+builtins.all(stride % 2 == 0 for stride in dims),
lambda: "Tensor must have a stride divisible by 2 for all but last dimension",
)
torch._check(
@@ -2168,7 +2168,7 @@ def _reduction(
dims = (dims,) # type: ignore[assignment]
dims = utils.reduction_dims(a.shape, dims)
if not has_identity:
-valid_shape = a.ndim == 0 or py_all(a.shape[i] for i in dims)
+valid_shape = a.ndim == 0 or builtins.all(a.shape[i] for i in dims)
if not valid_shape:
raise RuntimeError(
"reducing over zero-size dimension for reduction operation without identity"
@@ -2224,10 +2224,6 @@ def _make_copy_from_view(fn):
return _fn
-# Saves Python all
-py_all = all
@register_decomposition(aten.all)
@out_wrapper()
def all(
@@ -2243,10 +2239,6 @@ def all(
return result
-# Saves Python any
-py_any = any
@register_decomposition(aten.any)
@out_wrapper()
def any(
@@ -5074,7 +5066,7 @@ def linspace(
)
end = _maybe_convert_to_dtype(end, torch.float64)
-if py_any(isinstance(arg, complex) for arg in (start, end, steps)):
+if builtins.any(isinstance(arg, complex) for arg in (start, end, steps)):
default_complex_dtype = utils.corresponding_complex_dtype(
torch.get_default_dtype()
)
@@ -5173,7 +5165,7 @@ def logspace(
)
end = _maybe_convert_to_dtype(end, dtype)
-if py_any(isinstance(arg, complex) for arg in (start, end, steps)):
+if builtins.any(isinstance(arg, complex) for arg in (start, end, steps)):
default_complex_dtype = utils.corresponding_complex_dtype(
torch.get_default_dtype()
)
@@ -5208,7 +5200,7 @@ def meshgrid(*tensors: TensorLikeType, indexing: str):
pass
-@register_decomposition(aten.meshgrid)
+@register_decomposition(aten.meshgrid)  # type: ignore[misc]
def meshgrid(
*tensors: Union[TensorLikeType, List[TensorLikeType], Tuple[TensorLikeType]],
indexing: str,
@@ -5221,7 +5213,7 @@ def meshgrid(
tensors = tuple(tensors[0])
torch._check(
-py_all(isinstance(a, TensorLike) for a in tensors),
+builtins.all(isinstance(a, TensorLike) for a in tensors),
lambda: "meshgrid expects its inputs to be tensors",
)
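
The py_all/py_any to builtins.all/builtins.any changes in this file exist because the module defines decompositions literally named all and any (for aten.all and aten.any), which shadow the Python built-ins. The old code saved aliases (py_all = all, py_any = any) before the shadowing took effect; this PR drops the aliases and reaches the built-ins through the builtins module instead. A minimal sketch of the shadowing pattern, with illustrative names:

import builtins

def all(xs):  # stand-in for the aten.all decomposition; shadows the built-in in this module
    return builtins.all(bool(x) for x in xs)

def every_positive(xs):
    # builtins.all still reaches the original built-in despite the shadowing above
    return builtins.all(x > 0 for x in xs)

print(every_positive([1, 2, 3]))  # True
print(all([1, 0, 3]))             # False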

View File

@@ -440,7 +440,7 @@ def _compile(
gm = make_fx(
partial(stateless_func, func),
tracing_mode=tracing_mode,
-decomposition_table=SPMD_DECOMP_TABLE,
+decomposition_table=SPMD_DECOMP_TABLE,  # type: ignore[arg-type]
_allow_non_fake_inputs=False,
)(params, buffers, named_states, args, kwargs)

View File

@@ -1,4 +1,3 @@
-# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import torch
from torch import Tensor
@@ -6,13 +5,17 @@ from torch import Tensor
aten = torch.ops.aten
import inspect
import warnings
-from typing import Dict, List, Optional, Set
+from typing import Callable, Dict, List, Optional, Set, TypeVar
+from typing_extensions import ParamSpec
from torch.types import Number
decomposition_table: Dict[str, torch.jit.ScriptFunction] = {}
function_name_set: Set[str] = set()
+_T = TypeVar("_T")
+_P = ParamSpec("_P")
def check_decomposition_has_type_annotations(f):
inspect_empty = inspect._empty # type: ignore[attr-defined]
@@ -58,8 +61,11 @@ def signatures_match(decomposition_sig, torch_op_sig):
return decomposition_sig.return_annotation == torch_op_sig.return_annotation
-def register_decomposition(aten_op, registry=None):
-def decomposition_decorator(f):
+def register_decomposition(
+aten_op: torch._ops.OpOverload,
+registry: Optional[Dict[str, torch.jit.ScriptFunction]] = None,
+) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
+def decomposition_decorator(f: Callable[_P, _T]) -> Callable[_P, _T]:
nonlocal registry
if registry is None:
registry = decomposition_table
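
With the typed signature above, the jit register_decomposition now states what it accepts (an OpOverload plus an optional ScriptFunction registry) and, via Callable[_P, _T], lets the decorated decomposition keep its own signature for type checkers. A hypothetical usage sketch in the context of this module (the op and body are illustrative; as the surrounding helpers suggest, the real decorator also checks the decomposition's annotations against the op's signature):

@register_decomposition(aten.relu.default)
def relu_decomposition(self: Tensor) -> Tensor:
    # Decompositions here must carry full type annotations so the signature check passes.
    return torch.where(self > 0, self, torch.zeros_like(self))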