Mirror of https://github.com/pytorch/pytorch.git
Synced 2025-10-20 12:54:11 +08:00
Revert "[BE]: Update Typeguard to TypeIs for better type inference (#133814)"
This reverts commit 16caa8c1b3a02e47b5f52d3c2d40d7931cc427dc. Reverted https://github.com/pytorch/pytorch/pull/133814 on behalf of https://github.com/jeanschmidt due to checking if this will solve inductor errors ([comment](https://github.com/pytorch/pytorch/pull/133814#issuecomment-2427565425))
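For context, a minimal sketch (not part of this commit) of what the revert trades away: TypeGuard (PEP 647) narrows only the positive branch, while TypeIs (PEP 742, typing_extensions >= 4.10.0) also narrows the negative branch, which is the "better type inference" the reverted PR was after. The function names below are illustrative only.

from typing import Union
from typing_extensions import TypeGuard, TypeIs  # TypeIs needs typing_extensions >= 4.10.0

def is_str_guard(x: Union[str, int]) -> TypeGuard[str]:
    return isinstance(x, str)

def is_str_is(x: Union[str, int]) -> TypeIs[str]:
    return isinstance(x, str)

def demo(x: Union[str, int]) -> None:
    if is_str_guard(x):
        pass  # checker sees x: str
    else:
        pass  # TypeGuard: x stays Union[str, int] here
    if is_str_is(x):
        pass  # checker sees x: str
    else:
        pass  # TypeIs: x is narrowed to int here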
@@ -257,7 +257,7 @@ tb-nightly==2.13.0a20230426
 #test that import:
 
 # needed by torchgen utils
-typing-extensions>=4.10.0
+typing-extensions
 #Description: type hints for python
 #Pinned versions:
 #test that import:
@@ -7,7 +7,7 @@ requires = [
     "ninja",
     "pyyaml",
     "cmake",
-    "typing-extensions>=4.10.0",
+    "typing-extensions",
     "requests",
 ]
 # Use legacy backend to import local packages in setup.py
@@ -11,7 +11,7 @@ requests
 # is required until pytorch build not refactored to work for latest setuptools.
 setuptools<=72.1.0
 types-dataclasses
-typing-extensions>=4.10.0
+typing-extensions>=4.8.0
 sympy==1.13.1 ; python_version >= "3.9"
 filelock
 networkx
setup.py
@@ -1159,7 +1159,7 @@ def main():
     )
     install_requires = [
         "filelock",
-        "typing-extensions>=4.10.0",
+        "typing-extensions>=4.8.0",
         'setuptools ; python_version >= "3.12"',
         'sympy==1.13.1 ; python_version >= "3.9"',
         "networkx",
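The pin changes above follow from availability: TypeIs only exists in typing_extensions 4.10.0 and later, whereas TypeGuard has shipped for years (and lives in the stdlib typing module since Python 3.10), so reverting to TypeGuard lets the floor drop back to 4.8.0 or disappear entirely. A hedged sketch, not from the repo, for checking what a given environment provides:

from importlib.metadata import version
import typing_extensions

print(version("typing_extensions"))             # installed version, e.g. "4.9.0"
print(hasattr(typing_extensions, "TypeGuard"))  # True on any recent release
print(hasattr(typing_extensions, "TypeIs"))     # True only on 4.10.0+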
@@ -34,7 +34,7 @@ from typing import (
     TypeVar as _TypeVar,
     Union as _Union,
 )
-from typing_extensions import ParamSpec as _ParamSpec, TypeIs as _TypeIs
+from typing_extensions import ParamSpec as _ParamSpec, TypeGuard as _TypeGuard
 
 
 if TYPE_CHECKING:
@@ -1008,7 +1008,7 @@ def typename(obj: _Any, /) -> str:
     return f"{module}.{qualname}"
 
 
-def is_tensor(obj: _Any, /) -> _TypeIs["torch.Tensor"]:
+def is_tensor(obj: _Any, /) -> _TypeGuard["torch.Tensor"]:
     r"""Returns True if `obj` is a PyTorch tensor.
 
     Note that this function is simply doing ``isinstance(obj, Tensor)``.
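Either return type gives callers isinstance-style narrowing in the true branch. A small usage sketch (assumes torch is installed; the function is the public torch.is_tensor):

from typing import Union
import torch

def describe(obj: Union[torch.Tensor, float]) -> str:
    if torch.is_tensor(obj):
        # Narrowed: obj is torch.Tensor here under TypeGuard and TypeIs alike.
        return f"tensor of shape {tuple(obj.shape)}"
    # Under TypeGuard, obj is NOT narrowed to float here; TypeIs would do that.
    return f"plain value {obj}"

print(describe(torch.ones(2, 3)))  # tensor of shape (2, 3)
print(describe(1.5))               # plain value 1.5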
@@ -1028,7 +1028,7 @@ def is_tensor(obj: _Any, /) -> _TypeIs["torch.Tensor"]:
     return isinstance(obj, torch.Tensor)
 
 
-def is_storage(obj: _Any, /) -> _TypeIs[_Union["TypedStorage", "UntypedStorage"]]:
+def is_storage(obj: _Any, /) -> _TypeGuard[_Union["TypedStorage", "UntypedStorage"]]:
     r"""Returns True if `obj` is a PyTorch storage object.
 
     Args:
@@ -56,7 +56,7 @@ from typing import (
     Union,
     ValuesView,
 )
-from typing_extensions import Literal, TypeIs
+from typing_extensions import Literal, TypeGuard
 
 import torch
 import torch._functorch.config
@@ -569,14 +569,14 @@ class ExactWeakKeyDictionary:
 
 
 @overload
-def istype(obj: object, allowed_types: Type[T]) -> TypeIs[T]:
+def istype(obj: object, allowed_types: Type[T]) -> TypeGuard[T]:
     ...
 
 
 @overload
 def istype(
     obj: object, allowed_types: Tuple[Type[List[T]], Type[Tuple[T, ...]]]
-) -> TypeIs[T]:
+) -> TypeGuard[T]:
     ...
 
 
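The overloads above keep narrowing working for dynamo's exact-type check. A self-contained sketch of the same pattern (simplified; the real istype lives in torch._dynamo.utils and also accepts tuples of types):

from typing import Type, TypeVar
from typing_extensions import TypeGuard

T = TypeVar("T")

def istype(obj: object, allowed_type: Type[T]) -> TypeGuard[T]:
    # Exact type match: unlike isinstance(), subclasses do not pass.
    return type(obj) is allowed_type

val: object = 3
if istype(val, int):
    print(val + 1)        # val narrowed from object to int
print(istype(True, int))  # False: bool is a subclass, not exactly int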
@@ -70,7 +70,7 @@ from typing import (
     TypeVar,
     Union,
 )
-from typing_extensions import Self, TypeIs
+from typing_extensions import Self, TypeGuard
 
 import torch
 import torch._guards
@@ -305,10 +305,10 @@ class FailedMatch(RuntimeError):
 MatchResult = Union[Match, FailedMatch]
 
 
-def is_match(m: MatchResult) -> TypeIs[Match]:
+def is_match(m: MatchResult) -> TypeGuard[Match]:
     """
-    TypeIs cannot act on `self`. Thus this function exists to let mypy
-    recognize FailedMatch.__bool__ as a TypeIs.
+    TypeGuards cannot act on `self`. Thus this function exists to let mypy
+    recognize FailedMatch.__bool__ as a TypeGuard.
     """
     return bool(m)
 
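The docstring captures the real limitation: neither TypeGuard nor TypeIs can narrow `self` in __bool__, so a free function carries the annotation. A trimmed-down sketch of the classes in this hunk, omitting everything unrelated to narrowing:

from typing import Union
from typing_extensions import TypeGuard

class Match:
    def __bool__(self) -> bool:
        return True

class FailedMatch(RuntimeError):
    def __bool__(self) -> bool:
        return False

MatchResult = Union[Match, FailedMatch]

def is_match(m: MatchResult) -> TypeGuard[Match]:
    return bool(m)

m: MatchResult = FailedMatch("no match")
if is_match(m):
    pass  # checker treats m as Match here
else:
    print("no match:", m)  # under TypeGuard, m is still MatchResult here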
@@ -32,7 +32,7 @@ from typing import (
     TypeVar,
     Union,
 )
-from typing_extensions import Self, TypeIs
+from typing_extensions import Self, TypeGuard
 from weakref import ReferenceType
 
 import torch
@@ -169,7 +169,7 @@ def get_plain_tensors(subclass: Tensor) -> List[Tensor]:
     return plain_tensors
 
 
-def is_fake(x: object) -> TypeIs[Tensor]:
+def is_fake(x: object) -> TypeGuard[Tensor]:
     if isinstance(x, FakeTensor):
         return True
     if is_traceable_wrapper_subclass(x):
@@ -1213,7 +1213,7 @@ class FakeTensorMode(TorchDispatchMode):
     # In this case, it's insufficient to test only one FakeTensor: you need
     # to distinguish between our fake tensor and other fake tensors. That's
     # what this function does.
-    def is_our_fake(self, t: object) -> TypeIs[FakeTensor]:
+    def is_our_fake(self, t: object) -> TypeGuard[FakeTensor]:
         return isinstance(t, FakeTensor) and t.fake_mode is self
 
     # If we should avoid device init. This changes the behavior of various APIs:
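A hedged usage sketch for is_our_fake (an internal API, subject to change): tensors created while a FakeTensorMode is active are fake, and the method additionally checks that they belong to this particular mode instance.

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

mode = FakeTensorMode()
with mode:
    t = torch.empty(2, 2)  # a FakeTensor: allocated without real storage

print(mode.is_our_fake(t))              # True: fake, and owned by `mode`
print(mode.is_our_fake(torch.ones(1)))  # False: a real tensor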
@@ -3,7 +3,7 @@
 
 import warnings
 from typing import Any
-from typing_extensions import TypeIs
+from typing_extensions import TypeGuard
 
 import torch
 from torch.overrides import get_default_nowrap_functions
@@ -15,7 +15,7 @@ __all__ = [
 ]
 
 
-def is_masked_tensor(obj: Any, /) -> TypeIs["MaskedTensor"]:
+def is_masked_tensor(obj: Any, /) -> TypeGuard["MaskedTensor"]:
     r"""Returns True if the input is a MaskedTensor, else False
 
     Args:
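A usage sketch (MaskedTensor is a prototype feature, so import paths may shift across releases):

import torch
from torch.masked import masked_tensor, is_masked_tensor

data = torch.tensor([1.0, 2.0, 3.0])
mask = torch.tensor([True, False, True])
mt = masked_tensor(data, mask)  # may emit a prototype-feature warning

print(is_masked_tensor(mt))    # True: narrows to MaskedTensor for checkers
print(is_masked_tensor(data))  # False: a plain tensor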
@@ -1,5 +1,5 @@
 # mypy: allow-untyped-defs
-from typing_extensions import TypeIs
+from typing_extensions import TypeGuard
 
 from torch import device, dtype, Tensor
 
@@ -8,7 +8,7 @@ class Parameter(Tensor):
 
 def is_lazy(
     param: Tensor,
-) -> TypeIs[UninitializedParameter | UninitializedBuffer]: ...
+) -> TypeGuard[UninitializedParameter | UninitializedBuffer]: ...
 
 class UninitializedParameter(Tensor):
     def __init__(self, data: Tensor = ..., requires_grad: bool = ...) -> None: ...
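A usage sketch for is_lazy (assumes torch is installed): lazy modules hold UninitializedParameter objects until the first forward pass infers their shapes.

import torch
from torch.nn.parameter import is_lazy

layer = torch.nn.LazyLinear(4)  # in_features inferred on first call
print(is_lazy(layer.weight))    # True: weight is an UninitializedParameter

layer(torch.randn(2, 3))        # first forward materializes parameters
print(is_lazy(layer.weight))    # False: now a regular Parameter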
@@ -28,7 +28,7 @@ from typing import (
     Type,
     Union,
 )
-from typing_extensions import TypeAlias, TypeIs
+from typing_extensions import TypeAlias, TypeGuard  # Python 3.10+
 
 import torch
 import torch._weights_only_unpickler as _weights_only_unpickler
@@ -620,7 +620,7 @@ def storage_to_tensor_type(storage):
     return getattr(module, storage_type.__name__.replace("Storage", "Tensor"))
 
 
-def _is_path(name_or_buffer) -> TypeIs[Union[str, os.PathLike]]:
+def _is_path(name_or_buffer) -> TypeGuard[Union[str, os.PathLike]]:
     return isinstance(name_or_buffer, (str, os.PathLike))
 
 
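This guard distinguishes the two kinds of targets torch.save and torch.load accept: filesystem paths versus file-like objects. A standalone sketch of the same check (the underscore-prefixed original is private to torch.serialization):

import io
import os
from pathlib import Path

def _is_path(name_or_buffer: object) -> bool:  # simplified: TypeGuard omitted
    return isinstance(name_or_buffer, (str, os.PathLike))

print(_is_path("model.pt"))        # True
print(_is_path(Path("model.pt")))  # True: Path implements os.PathLike
print(_is_path(io.BytesIO()))      # False: treated as a buffer instead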
@@ -4,7 +4,7 @@ import contextlib
 import warnings
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Set, Union, Protocol, Tuple, Sequence, overload, Deque, Type
-from typing_extensions import TypeIs
+from typing_extensions import TypeGuard
 from collections import deque
 
 import torch
@@ -365,7 +365,7 @@ class TensorWithFlatten(Protocol):
 
 
 
-def is_traceable_wrapper_subclass(t: object) -> TypeIs[TensorWithFlatten]:
+def is_traceable_wrapper_subclass(t: object) -> TypeGuard[TensorWithFlatten]:
     """
     Returns whether or not a tensor subclass that implements __torch_dispatch__
     is 'traceable' with torch.compile.
@@ -402,7 +402,7 @@ def is_traceable_wrapper_subclass(t: object) -> TypeIs[TensorWithFlatten]:
         and hasattr(t, "__tensor_unflatten__")
     )
 
-def is_traceable_wrapper_subclass_type(t: Type) -> TypeIs[Type[TensorWithFlatten]]:
+def is_traceable_wrapper_subclass_type(t: Type) -> TypeGuard[Type[TensorWithFlatten]]:
    """Same as above, but takes a type argument instead of an instance."""
    return (issubclass(t, torch.Tensor) and t != torch.Tensor
            and hasattr(t, "__tensor_flatten__") and hasattr(t, "__tensor_unflatten__"))
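A hedged sketch of what the type-level guard looks for (internal API; assumes a build where is_traceable_wrapper_subclass_type exists, as in the hunk above). The Wrapper class is illustrative only:

import torch
from torch.utils._python_dispatch import is_traceable_wrapper_subclass_type

class Wrapper(torch.Tensor):
    # The two hooks the guard checks for; real subclasses implement them so
    # torch.compile can decompose the wrapper into plain tensors and rebuild it.
    def __tensor_flatten__(self):
        return ["_data"], None

    @staticmethod
    def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
        raise NotImplementedError("sketch only")

print(is_traceable_wrapper_subclass_type(Wrapper))       # True
print(is_traceable_wrapper_subclass_type(torch.Tensor))  # False: base class is excluded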