BE: Type previously untyped decorators (#154515)

Summary: Cloned #153726 from Skylion007 and fixed internal typing issues.

Test Plan: Unit tests pass

Differential Revision: D75477355

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154515
Approved by: https://github.com/Skylion007
This commit is contained in:
Aaron Orenstein
2025-05-29 00:36:32 +00:00
committed by PyTorch MergeBot
parent ba0a91b3ea
commit 946a4c2bdc
10 changed files with 58 additions and 45 deletions

View File

@ -31,8 +31,10 @@ from typing import ( # noqa: UP035, F401 # (Dict, List, Tuple) imported by tor
List,
Optional,
Tuple,
TypeVar,
Union,
)
from typing_extensions import ParamSpec
import torch
@ -47,6 +49,9 @@ from torch._sources import fake_range, get_source_lines_and_file, parse_def
from torch.futures import Future
_P = ParamSpec("_P")
_R = TypeVar("_R")
IS_PY310_PLUS: Final[bool] = sys.version_info >= (3, 10)
BuiltinUnionType: Union[type, tuple[type, ...]]
@ -665,7 +670,7 @@ class FunctionModifiers:
_DROP = "_drop (function is fully ignored, declaration can be unscriptable)"
def export(fn):
def export(fn: Callable[_P, _R]) -> Callable[_P, _R]:
"""
This decorator indicates that a method on an ``nn.Module`` is used as an entry point into a
:class:`ScriptModule` and should be compiled.
@ -707,11 +712,11 @@ def export(fn):
# any compiled methods and wasn't decorated with `@torch.jit.export`
m = torch.jit.script(MyModule())
"""
fn._torchscript_modifier = FunctionModifiers.EXPORT
fn._torchscript_modifier = FunctionModifiers.EXPORT # type:ignore[attr-defined]
return fn
def unused(fn):
def unused(fn: Callable[_P, _R]) -> Callable[_P, _R]:
"""
This decorator indicates to the compiler that a function or method should
be ignored and replaced with the raising of an exception. This allows you
@ -764,7 +769,7 @@ def unused(fn):
return prop
fn._torchscript_modifier = FunctionModifiers.UNUSED
fn._torchscript_modifier = FunctionModifiers.UNUSED # type: ignore[attr-defined]
return fn
@ -882,13 +887,13 @@ def ignore(drop=False, **kwargs):
return decorator
def _drop(fn):
fn._torchscript_modifier = FunctionModifiers._DROP
def _drop(fn: Callable[_P, _R]) -> Callable[_P, _R]:
fn._torchscript_modifier = FunctionModifiers._DROP # type: ignore[attr-defined]
return fn
def _copy_to_script_wrapper(fn):
fn._torchscript_modifier = FunctionModifiers.COPY_TO_SCRIPT_WRAPPER
def _copy_to_script_wrapper(fn: Callable[_P, _R]) -> Callable[_P, _R]:
fn._torchscript_modifier = FunctionModifiers.COPY_TO_SCRIPT_WRAPPER # type: ignore[attr-defined]
return fn