mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
BE: Type previously untyped decorators (#154515)
Summary: Cloned #153726 from Skylion007 and fixed internal typing issues. Test Plan: Unit tests pass Differential Revision: D75477355 Pull Request resolved: https://github.com/pytorch/pytorch/pull/154515 Approved by: https://github.com/Skylion007
This commit is contained in:
committed by
PyTorch MergeBot
parent
ba0a91b3ea
commit
946a4c2bdc
@ -31,8 +31,10 @@ from typing import ( # noqa: UP035, F401 # (Dict, List, Tuple) imported by tor
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
import torch
|
||||
|
||||
@ -47,6 +49,9 @@ from torch._sources import fake_range, get_source_lines_and_file, parse_def
|
||||
from torch.futures import Future
|
||||
|
||||
|
||||
_P = ParamSpec("_P")
|
||||
_R = TypeVar("_R")
|
||||
|
||||
IS_PY310_PLUS: Final[bool] = sys.version_info >= (3, 10)
|
||||
|
||||
BuiltinUnionType: Union[type, tuple[type, ...]]
|
||||
@ -665,7 +670,7 @@ class FunctionModifiers:
|
||||
_DROP = "_drop (function is fully ignored, declaration can be unscriptable)"
|
||||
|
||||
|
||||
def export(fn):
|
||||
def export(fn: Callable[_P, _R]) -> Callable[_P, _R]:
|
||||
"""
|
||||
This decorator indicates that a method on an ``nn.Module`` is used as an entry point into a
|
||||
:class:`ScriptModule` and should be compiled.
|
||||
@ -707,11 +712,11 @@ def export(fn):
|
||||
# any compiled methods and wasn't decorated with `@torch.jit.export`
|
||||
m = torch.jit.script(MyModule())
|
||||
"""
|
||||
fn._torchscript_modifier = FunctionModifiers.EXPORT
|
||||
fn._torchscript_modifier = FunctionModifiers.EXPORT # type:ignore[attr-defined]
|
||||
return fn
|
||||
|
||||
|
||||
def unused(fn):
|
||||
def unused(fn: Callable[_P, _R]) -> Callable[_P, _R]:
|
||||
"""
|
||||
This decorator indicates to the compiler that a function or method should
|
||||
be ignored and replaced with the raising of an exception. This allows you
|
||||
@ -764,7 +769,7 @@ def unused(fn):
|
||||
|
||||
return prop
|
||||
|
||||
fn._torchscript_modifier = FunctionModifiers.UNUSED
|
||||
fn._torchscript_modifier = FunctionModifiers.UNUSED # type: ignore[attr-defined]
|
||||
return fn
|
||||
|
||||
|
||||
@ -882,13 +887,13 @@ def ignore(drop=False, **kwargs):
|
||||
return decorator
|
||||
|
||||
|
||||
def _drop(fn):
|
||||
fn._torchscript_modifier = FunctionModifiers._DROP
|
||||
def _drop(fn: Callable[_P, _R]) -> Callable[_P, _R]:
|
||||
fn._torchscript_modifier = FunctionModifiers._DROP # type: ignore[attr-defined]
|
||||
return fn
|
||||
|
||||
|
||||
def _copy_to_script_wrapper(fn):
|
||||
fn._torchscript_modifier = FunctionModifiers.COPY_TO_SCRIPT_WRAPPER
|
||||
def _copy_to_script_wrapper(fn: Callable[_P, _R]) -> Callable[_P, _R]:
|
||||
fn._torchscript_modifier = FunctionModifiers.COPY_TO_SCRIPT_WRAPPER # type: ignore[attr-defined]
|
||||
return fn
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user