mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Improve typing in torch/types.py (#145237)
Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/145237 Approved by: https://github.com/XuehaiPan, https://github.com/albanD Co-authored-by: Xuehai Pan <XuehaiPan@pku.edu.cn>
This commit is contained in:
committed by
PyTorch MergeBot
parent
8e46d0f595
commit
23eb0a3201
@ -1032,7 +1032,6 @@
|
||||
"List",
|
||||
"Number",
|
||||
"Sequence",
|
||||
"Tuple",
|
||||
"Union"
|
||||
],
|
||||
"torch.utils.benchmark.utils.compare": [
|
||||
|
@ -1,16 +1,13 @@
|
||||
from enum import Enum
|
||||
|
||||
from torch.types import _bool
|
||||
from enum import IntEnum
|
||||
|
||||
# Defined in torch/csrc/cuda/shared/cudnn.cpp
|
||||
is_cuda: _bool
|
||||
is_cuda: bool
|
||||
|
||||
def getRuntimeVersion() -> tuple[int, int, int]: ...
|
||||
def getCompileVersion() -> tuple[int, int, int]: ...
|
||||
def getVersionInt() -> int: ...
|
||||
|
||||
class RNNMode(int, Enum):
|
||||
value: int
|
||||
class RNNMode(IntEnum):
|
||||
rnn_relu = ...
|
||||
rnn_tanh = ...
|
||||
lstm = ...
|
||||
|
@ -1,5 +1,3 @@
|
||||
# mypy: allow-untyped-defs
|
||||
|
||||
# In some cases, these basic types are shadowed by corresponding
|
||||
# top-level values. The underscore variants let us refer to these
|
||||
# types. See https://github.com/python/mypy/issues/4146 for why these
|
||||
@ -15,7 +13,7 @@ from builtins import ( # noqa: F401
|
||||
)
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, IO, TYPE_CHECKING, Union
|
||||
from typing_extensions import TypeAlias
|
||||
from typing_extensions import Self, TypeAlias
|
||||
|
||||
# `as` imports have better static analysis support than assignment `ExposedType: TypeAlias = HiddenType`
|
||||
from torch import ( # noqa: F401
|
||||
@ -59,7 +57,7 @@ FloatLikeType: TypeAlias = Union[float, SymFloat]
|
||||
# bool or SymBool
|
||||
BoolLikeType: TypeAlias = Union[bool, SymBool]
|
||||
|
||||
py_sym_types = (SymInt, SymFloat, SymBool)
|
||||
py_sym_types = (SymInt, SymFloat, SymBool) # left un-annotated intentionally
|
||||
PySymType: TypeAlias = Union[SymInt, SymFloat, SymBool]
|
||||
|
||||
# Meta-type for "numeric" things; matches our docs
|
||||
@ -83,10 +81,10 @@ class Storage:
|
||||
dtype: _dtype
|
||||
_torch_load_uninitialized: bool
|
||||
|
||||
def __deepcopy__(self, memo: dict[int, Any]) -> "Storage":
|
||||
def __deepcopy__(self, memo: dict[int, Any]) -> Self:
|
||||
raise NotImplementedError
|
||||
|
||||
def _new_shared(self, size: int) -> "Storage":
|
||||
def _new_shared(self, size: int) -> Self:
|
||||
raise NotImplementedError
|
||||
|
||||
def _write_file(
|
||||
@ -104,13 +102,13 @@ class Storage:
|
||||
def is_shared(self) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
def share_memory_(self) -> "Storage":
|
||||
def share_memory_(self) -> Self:
|
||||
raise NotImplementedError
|
||||
|
||||
def nbytes(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
def cpu(self) -> "Storage":
|
||||
def cpu(self) -> Self:
|
||||
raise NotImplementedError
|
||||
|
||||
def data_ptr(self) -> int:
|
||||
@ -121,12 +119,12 @@ class Storage:
|
||||
filename: str,
|
||||
shared: bool = False,
|
||||
nbytes: int = 0,
|
||||
) -> "Storage":
|
||||
) -> Self:
|
||||
raise NotImplementedError
|
||||
|
||||
def _new_with_file(
|
||||
self,
|
||||
f: Any,
|
||||
element_size: int,
|
||||
) -> "Storage":
|
||||
) -> Self:
|
||||
raise NotImplementedError
|
||||
|
Reference in New Issue
Block a user