Improve typing in torch/types.py (#145237)

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145237
Approved by: https://github.com/XuehaiPan, https://github.com/albanD

Co-authored-by: Xuehai Pan <XuehaiPan@pku.edu.cn>
This commit is contained in:
cyyever
2025-01-28 05:29:12 +00:00
committed by PyTorch MergeBot
parent 8e46d0f595
commit 23eb0a3201
3 changed files with 11 additions and 17 deletions

View File

@ -1032,7 +1032,6 @@
"List",
"Number",
"Sequence",
"Tuple",
"Union"
],
"torch.utils.benchmark.utils.compare": [

View File

@ -1,16 +1,13 @@
from enum import Enum
from torch.types import _bool
from enum import IntEnum
# Defined in torch/csrc/cuda/shared/cudnn.cpp
is_cuda: _bool
is_cuda: bool
def getRuntimeVersion() -> tuple[int, int, int]: ...
def getCompileVersion() -> tuple[int, int, int]: ...
def getVersionInt() -> int: ...
class RNNMode(int, Enum):
value: int
class RNNMode(IntEnum):
rnn_relu = ...
rnn_tanh = ...
lstm = ...

View File

@ -1,5 +1,3 @@
# mypy: allow-untyped-defs
# In some cases, these basic types are shadowed by corresponding
# top-level values. The underscore variants let us refer to these
# types. See https://github.com/python/mypy/issues/4146 for why these
@ -15,7 +13,7 @@ from builtins import ( # noqa: F401
)
from collections.abc import Sequence
from typing import Any, IO, TYPE_CHECKING, Union
from typing_extensions import TypeAlias
from typing_extensions import Self, TypeAlias
# `as` imports have better static analysis support than assignment `ExposedType: TypeAlias = HiddenType`
from torch import ( # noqa: F401
@ -59,7 +57,7 @@ FloatLikeType: TypeAlias = Union[float, SymFloat]
# bool or SymBool
BoolLikeType: TypeAlias = Union[bool, SymBool]
py_sym_types = (SymInt, SymFloat, SymBool)
py_sym_types = (SymInt, SymFloat, SymBool) # left un-annotated intentionally
PySymType: TypeAlias = Union[SymInt, SymFloat, SymBool]
# Meta-type for "numeric" things; matches our docs
@ -83,10 +81,10 @@ class Storage:
dtype: _dtype
_torch_load_uninitialized: bool
def __deepcopy__(self, memo: dict[int, Any]) -> "Storage":
def __deepcopy__(self, memo: dict[int, Any]) -> Self:
raise NotImplementedError
def _new_shared(self, size: int) -> "Storage":
def _new_shared(self, size: int) -> Self:
raise NotImplementedError
def _write_file(
@ -104,13 +102,13 @@ class Storage:
def is_shared(self) -> bool:
raise NotImplementedError
def share_memory_(self) -> "Storage":
def share_memory_(self) -> Self:
raise NotImplementedError
def nbytes(self) -> int:
raise NotImplementedError
def cpu(self) -> "Storage":
def cpu(self) -> Self:
raise NotImplementedError
def data_ptr(self) -> int:
@ -121,12 +119,12 @@ class Storage:
filename: str,
shared: bool = False,
nbytes: int = 0,
) -> "Storage":
) -> Self:
raise NotImplementedError
def _new_with_file(
self,
f: Any,
element_size: int,
) -> "Storage":
) -> Self:
raise NotImplementedError