mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-11 22:34:53 +08:00
[BE] typing for decorators - jit/_decompositions (#131566)
See #131429 Pull Request resolved: https://github.com/pytorch/pytorch/pull/131566 Approved by: https://github.com/oulgen, https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
2b83e4f8d7
commit
44fdf24967
@ -1,4 +1,3 @@
|
||||
# mypy: allow-untyped-decorators
|
||||
# mypy: allow-untyped-defs
|
||||
import torch
|
||||
from torch import Tensor
|
||||
@ -6,13 +5,17 @@ from torch import Tensor
|
||||
aten = torch.ops.aten
|
||||
import inspect
|
||||
import warnings
|
||||
from typing import Dict, List, Optional, Set
|
||||
from typing import Callable, Dict, List, Optional, Set, TypeVar
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from torch.types import Number
|
||||
|
||||
decomposition_table: Dict[str, torch.jit.ScriptFunction] = {}
|
||||
function_name_set: Set[str] = set()
|
||||
|
||||
_T = TypeVar("_T")
|
||||
_P = ParamSpec("_P")
|
||||
|
||||
|
||||
def check_decomposition_has_type_annotations(f):
|
||||
inspect_empty = inspect._empty # type: ignore[attr-defined]
|
||||
@ -58,8 +61,11 @@ def signatures_match(decomposition_sig, torch_op_sig):
|
||||
return decomposition_sig.return_annotation == torch_op_sig.return_annotation
|
||||
|
||||
|
||||
def register_decomposition(aten_op, registry=None):
|
||||
def decomposition_decorator(f):
|
||||
def register_decomposition(
|
||||
aten_op: torch._ops.OpOverload,
|
||||
registry: Optional[Dict[str, torch.jit.ScriptFunction]] = None,
|
||||
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
|
||||
def decomposition_decorator(f: Callable[_P, _T]) -> Callable[_P, _T]:
|
||||
nonlocal registry
|
||||
if registry is None:
|
||||
registry = decomposition_table
|
||||
|
||||
Reference in New Issue
Block a user