[BE] typing for decorators - jit/_decompositions (#131566)

See #131429
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131566
Approved by: https://github.com/oulgen, https://github.com/zou3519
This commit is contained in:
Aaron Orenstein
2024-07-24 09:54:26 -07:00
committed by PyTorch MergeBot
parent 2b83e4f8d7
commit 44fdf24967
5 changed files with 25 additions and 24 deletions

View File

@ -1,4 +1,3 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import torch
from torch import Tensor
@ -6,13 +5,17 @@ from torch import Tensor
aten = torch.ops.aten
import inspect
import warnings
from typing import Dict, List, Optional, Set
from typing import Callable, Dict, List, Optional, Set, TypeVar
from typing_extensions import ParamSpec
from torch.types import Number
decomposition_table: Dict[str, torch.jit.ScriptFunction] = {}
function_name_set: Set[str] = set()
_T = TypeVar("_T")
_P = ParamSpec("_P")
def check_decomposition_has_type_annotations(f):
inspect_empty = inspect._empty # type: ignore[attr-defined]
@ -58,8 +61,11 @@ def signatures_match(decomposition_sig, torch_op_sig):
return decomposition_sig.return_annotation == torch_op_sig.return_annotation
def register_decomposition(aten_op, registry=None):
def decomposition_decorator(f):
def register_decomposition(
aten_op: torch._ops.OpOverload,
registry: Optional[Dict[str, torch.jit.ScriptFunction]] = None,
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
def decomposition_decorator(f: Callable[_P, _T]) -> Callable[_P, _T]:
nonlocal registry
if registry is None:
registry = decomposition_table