mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[BE] typing for decorators - jit/_decompositions (#131566)
See #131429 Pull Request resolved: https://github.com/pytorch/pytorch/pull/131566 Approved by: https://github.com/oulgen, https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
2b83e4f8d7
commit
44fdf24967
@ -3,7 +3,8 @@ import inspect
|
||||
from collections import defaultdict
|
||||
from functools import wraps
|
||||
from itertools import chain
|
||||
from typing import Callable, Dict, List, Sequence, Union
|
||||
from typing import Callable, Dict, List, Sequence, TypeVar, Union
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
import torch
|
||||
import torch.library
|
||||
@ -20,6 +21,8 @@ __all__ = [
|
||||
"core_aten_decompositions",
|
||||
]
|
||||
|
||||
_T = TypeVar("_T")
|
||||
_P = ParamSpec("_P")
|
||||
|
||||
# TODO: relax key type here; torch registrations should be possible to; but
|
||||
# right now this type is accurate
|
||||
@ -145,7 +148,7 @@ def _convert_out_params(f):
|
||||
|
||||
def register_decomposition(
|
||||
aten_op, registry=None, *, type="post_autograd", unsafe=False
|
||||
):
|
||||
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
|
||||
"""
|
||||
A decorator to register a function as a decomposition to the Python
|
||||
decomposition table. Use it like this::
|
||||
@ -171,7 +174,7 @@ def register_decomposition(
|
||||
|
||||
assert type in {"post_autograd", "pre_autograd", "meta"}
|
||||
|
||||
def decomposition_decorator(fn: Callable) -> Callable:
|
||||
def decomposition_decorator(fn: Callable[_P, _T]) -> Callable[_P, _T]:
|
||||
orig_fn = fn
|
||||
if not unsafe:
|
||||
fn = _convert_out_params(fn)
|
||||
|
@ -657,7 +657,7 @@ noop_registry: Dict[Any, Any] = {}
|
||||
def register_noop_decomp(targets, nop_arg=0):
|
||||
def register_fun(cond):
|
||||
register_decomposition(targets, registry=noop_registry, unsafe=True)(
|
||||
(cond, nop_arg)
|
||||
(cond, nop_arg) # type: ignore[arg-type]
|
||||
)
|
||||
return cond
|
||||
|
||||
|
@ -516,7 +516,7 @@ def _make_inplace(fn):
|
||||
|
||||
inplace_name = f"{fn.__name__}_"
|
||||
_fn.__name__ = inplace_name
|
||||
_fn = register_decomposition(getattr(aten, inplace_name))(_fn)
|
||||
_fn = register_decomposition(getattr(aten, inplace_name))(_fn) # type: ignore[assignment]
|
||||
|
||||
# We access the __all__ attribute of the module where fn is defined
|
||||
# There may be a cleaner way of doing this...
|
||||
@ -993,7 +993,7 @@ def view_as_complex(self: TensorLikeType) -> TensorLikeType:
|
||||
)
|
||||
dims = old_strides[:-1]
|
||||
torch._check(
|
||||
py_all(stride % 2 == 0 for stride in dims),
|
||||
builtins.all(stride % 2 == 0 for stride in dims),
|
||||
lambda: "Tensor must have a stride divisible by 2 for all but last dimension",
|
||||
)
|
||||
torch._check(
|
||||
@ -2168,7 +2168,7 @@ def _reduction(
|
||||
dims = (dims,) # type: ignore[assignment]
|
||||
dims = utils.reduction_dims(a.shape, dims)
|
||||
if not has_identity:
|
||||
valid_shape = a.ndim == 0 or py_all(a.shape[i] for i in dims)
|
||||
valid_shape = a.ndim == 0 or builtins.all(a.shape[i] for i in dims)
|
||||
if not valid_shape:
|
||||
raise RuntimeError(
|
||||
"reducing over zero-size dimension for reduction operation without identity"
|
||||
@ -2224,10 +2224,6 @@ def _make_copy_from_view(fn):
|
||||
return _fn
|
||||
|
||||
|
||||
# Saves Python all
|
||||
py_all = all
|
||||
|
||||
|
||||
@register_decomposition(aten.all)
|
||||
@out_wrapper()
|
||||
def all(
|
||||
@ -2243,10 +2239,6 @@ def all(
|
||||
return result
|
||||
|
||||
|
||||
# Saves Python any
|
||||
py_any = any
|
||||
|
||||
|
||||
@register_decomposition(aten.any)
|
||||
@out_wrapper()
|
||||
def any(
|
||||
@ -5074,7 +5066,7 @@ def linspace(
|
||||
)
|
||||
end = _maybe_convert_to_dtype(end, torch.float64)
|
||||
|
||||
if py_any(isinstance(arg, complex) for arg in (start, end, steps)):
|
||||
if builtins.any(isinstance(arg, complex) for arg in (start, end, steps)):
|
||||
default_complex_dtype = utils.corresponding_complex_dtype(
|
||||
torch.get_default_dtype()
|
||||
)
|
||||
@ -5173,7 +5165,7 @@ def logspace(
|
||||
)
|
||||
end = _maybe_convert_to_dtype(end, dtype)
|
||||
|
||||
if py_any(isinstance(arg, complex) for arg in (start, end, steps)):
|
||||
if builtins.any(isinstance(arg, complex) for arg in (start, end, steps)):
|
||||
default_complex_dtype = utils.corresponding_complex_dtype(
|
||||
torch.get_default_dtype()
|
||||
)
|
||||
@ -5208,7 +5200,7 @@ def meshgrid(*tensors: TensorLikeType, indexing: str):
|
||||
pass
|
||||
|
||||
|
||||
@register_decomposition(aten.meshgrid)
|
||||
@register_decomposition(aten.meshgrid) # type: ignore[misc]
|
||||
def meshgrid(
|
||||
*tensors: Union[TensorLikeType, List[TensorLikeType], Tuple[TensorLikeType]],
|
||||
indexing: str,
|
||||
@ -5221,7 +5213,7 @@ def meshgrid(
|
||||
tensors = tuple(tensors[0])
|
||||
|
||||
torch._check(
|
||||
py_all(isinstance(a, TensorLike) for a in tensors),
|
||||
builtins.all(isinstance(a, TensorLike) for a in tensors),
|
||||
lambda: "meshgrid expects its inputs to be tensors",
|
||||
)
|
||||
|
||||
|
@ -440,7 +440,7 @@ def _compile(
|
||||
gm = make_fx(
|
||||
partial(stateless_func, func),
|
||||
tracing_mode=tracing_mode,
|
||||
decomposition_table=SPMD_DECOMP_TABLE,
|
||||
decomposition_table=SPMD_DECOMP_TABLE, # type: ignore[arg-type]
|
||||
_allow_non_fake_inputs=False,
|
||||
)(params, buffers, named_states, args, kwargs)
|
||||
|
||||
|
@ -1,4 +1,3 @@
|
||||
# mypy: allow-untyped-decorators
|
||||
# mypy: allow-untyped-defs
|
||||
import torch
|
||||
from torch import Tensor
|
||||
@ -6,13 +5,17 @@ from torch import Tensor
|
||||
aten = torch.ops.aten
|
||||
import inspect
|
||||
import warnings
|
||||
from typing import Dict, List, Optional, Set
|
||||
from typing import Callable, Dict, List, Optional, Set, TypeVar
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from torch.types import Number
|
||||
|
||||
decomposition_table: Dict[str, torch.jit.ScriptFunction] = {}
|
||||
function_name_set: Set[str] = set()
|
||||
|
||||
_T = TypeVar("_T")
|
||||
_P = ParamSpec("_P")
|
||||
|
||||
|
||||
def check_decomposition_has_type_annotations(f):
|
||||
inspect_empty = inspect._empty # type: ignore[attr-defined]
|
||||
@ -58,8 +61,11 @@ def signatures_match(decomposition_sig, torch_op_sig):
|
||||
return decomposition_sig.return_annotation == torch_op_sig.return_annotation
|
||||
|
||||
|
||||
def register_decomposition(aten_op, registry=None):
|
||||
def decomposition_decorator(f):
|
||||
def register_decomposition(
|
||||
aten_op: torch._ops.OpOverload,
|
||||
registry: Optional[Dict[str, torch.jit.ScriptFunction]] = None,
|
||||
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
|
||||
def decomposition_decorator(f: Callable[_P, _T]) -> Callable[_P, _T]:
|
||||
nonlocal registry
|
||||
if registry is None:
|
||||
registry = decomposition_table
|
||||
|
Reference in New Issue
Block a user