[BE]: Update Typeguard to TypeIs for better type inference (#133814)

Uses TypeIs instead of TypeGuard for better inference. See https://peps.python.org/pep-0742/

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133814
Approved by: https://github.com/ezyang
This commit is contained in:
Aaron Gokaslan
2024-08-18 19:10:14 +00:00
committed by PyTorch MergeBot
parent 0d4cedaa47
commit cf60fe53a8
9 changed files with 24 additions and 24 deletions

View File

@ -34,7 +34,7 @@ from typing import (
TypeVar as _TypeVar, TypeVar as _TypeVar,
Union as _Union, Union as _Union,
) )
from typing_extensions import ParamSpec as _ParamSpec, TypeGuard as _TypeGuard from typing_extensions import ParamSpec as _ParamSpec, TypeIs as _TypeIs
if TYPE_CHECKING: if TYPE_CHECKING:
@ -1000,7 +1000,7 @@ def typename(obj: _Any, /) -> str:
return f"{module}.{qualname}" return f"{module}.{qualname}"
def is_tensor(obj: _Any, /) -> _TypeGuard["torch.Tensor"]: def is_tensor(obj: _Any, /) -> _TypeIs["torch.Tensor"]:
r"""Returns True if `obj` is a PyTorch tensor. r"""Returns True if `obj` is a PyTorch tensor.
Note that this function is simply doing ``isinstance(obj, Tensor)``. Note that this function is simply doing ``isinstance(obj, Tensor)``.
@ -1020,7 +1020,7 @@ def is_tensor(obj: _Any, /) -> _TypeGuard["torch.Tensor"]:
return isinstance(obj, torch.Tensor) return isinstance(obj, torch.Tensor)
def is_storage(obj: _Any, /) -> _TypeGuard[_Union["TypedStorage", "UntypedStorage"]]: def is_storage(obj: _Any, /) -> _TypeIs[_Union["TypedStorage", "UntypedStorage"]]:
r"""Returns True if `obj` is a PyTorch storage object. r"""Returns True if `obj` is a PyTorch storage object.
Args: Args:

View File

@ -53,7 +53,7 @@ from typing import (
Union, Union,
ValuesView, ValuesView,
) )
from typing_extensions import Literal, TypeGuard from typing_extensions import Literal, TypeIs
import torch import torch
import torch._functorch.config import torch._functorch.config
@ -526,14 +526,14 @@ class ExactWeakKeyDictionary:
@overload @overload
def istype(obj: object, allowed_types: Type[T]) -> TypeGuard[T]: def istype(obj: object, allowed_types: Type[T]) -> TypeIs[T]:
... ...
@overload @overload
def istype( def istype(
obj: object, allowed_types: Tuple[Type[List[T]], Type[Tuple[T, ...]]] obj: object, allowed_types: Tuple[Type[List[T]], Type[Tuple[T, ...]]]
) -> TypeGuard[T]: ) -> TypeIs[T]:
... ...

View File

@ -70,7 +70,7 @@ from typing import (
TypeVar, TypeVar,
Union, Union,
) )
from typing_extensions import Self, TypeGuard from typing_extensions import Self, TypeIs
import torch import torch
import torch._guards import torch._guards
@ -277,10 +277,10 @@ class FailedMatch(RuntimeError):
MatchResult = Union[Match, FailedMatch] MatchResult = Union[Match, FailedMatch]
def is_match(m: MatchResult) -> TypeGuard[Match]: def is_match(m: MatchResult) -> TypeIs[Match]:
""" """
TypeGuards cannot act on `self`. Thus this function exists to let mypy TypeIs cannot act on `self`. Thus this function exists to let mypy
recognize FailedMatch.__bool__ as a TypeGuard. recognize FailedMatch.__bool__ as a TypeIs.
""" """
return bool(m) return bool(m)

View File

@ -31,7 +31,7 @@ from typing import (
TypeVar, TypeVar,
Union, Union,
) )
from typing_extensions import Self, TypeGuard from typing_extensions import Self, TypeIs
from weakref import ReferenceType from weakref import ReferenceType
import torch import torch
@ -168,7 +168,7 @@ def get_plain_tensors(subclass: Tensor) -> List[Tensor]:
return plain_tensors return plain_tensors
def is_fake(x: object) -> TypeGuard[Tensor]: def is_fake(x: object) -> TypeIs[Tensor]:
if isinstance(x, FakeTensor): if isinstance(x, FakeTensor):
return True return True
if is_traceable_wrapper_subclass(x): if is_traceable_wrapper_subclass(x):
@ -1213,7 +1213,7 @@ class FakeTensorMode(TorchDispatchMode):
# In this case, it's insufficient to test only one FakeTensor: you need # In this case, it's insufficient to test only one FakeTensor: you need
# to distinguish between our fake tensor and other fake tensors. That's # to distinguish between our fake tensor and other fake tensors. That's
# what this function does. # what this function does.
def is_our_fake(self, t: object) -> TypeGuard[FakeTensor]: def is_our_fake(self, t: object) -> TypeIs[FakeTensor]:
return isinstance(t, FakeTensor) and t.fake_mode is self return isinstance(t, FakeTensor) and t.fake_mode is self
# If we should avoid device init. This changes the behavior of various APIs: # If we should avoid device init. This changes the behavior of various APIs:

View File

@ -3,7 +3,7 @@
import warnings import warnings
from typing import Any from typing import Any
from typing_extensions import TypeGuard from typing_extensions import TypeIs
import torch import torch
from torch.overrides import get_default_nowrap_functions from torch.overrides import get_default_nowrap_functions
@ -15,7 +15,7 @@ __all__ = [
] ]
def is_masked_tensor(obj: Any, /) -> TypeGuard["MaskedTensor"]: def is_masked_tensor(obj: Any, /) -> TypeIs["MaskedTensor"]:
r"""Returns True if the input is a MaskedTensor, else False r"""Returns True if the input is a MaskedTensor, else False
Args: Args:

View File

@ -1,5 +1,5 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
from typing_extensions import TypeGuard from typing_extensions import TypeIs
from torch import device, dtype, Tensor from torch import device, dtype, Tensor
@ -8,7 +8,7 @@ class Parameter(Tensor):
def is_lazy( def is_lazy(
param: Tensor, param: Tensor,
) -> TypeGuard[UninitializedParameter | UninitializedBuffer]: ... ) -> TypeIs[UninitializedParameter | UninitializedBuffer]: ...
class UninitializedParameter(Tensor): class UninitializedParameter(Tensor):
def __init__(self, data: Tensor = ..., requires_grad: bool = ...) -> None: ... def __init__(self, data: Tensor = ..., requires_grad: bool = ...) -> None: ...

View File

@ -27,7 +27,7 @@ from typing import (
Type, Type,
Union, Union,
) )
from typing_extensions import TypeAlias, TypeGuard # Python 3.10+ from typing_extensions import TypeAlias, TypeIs
import torch import torch
import torch._weights_only_unpickler as _weights_only_unpickler import torch._weights_only_unpickler as _weights_only_unpickler
@ -549,7 +549,7 @@ def storage_to_tensor_type(storage):
return getattr(module, storage_type.__name__.replace("Storage", "Tensor")) return getattr(module, storage_type.__name__.replace("Storage", "Tensor"))
def _is_path(name_or_buffer) -> TypeGuard[Union[str, os.PathLike]]: def _is_path(name_or_buffer) -> TypeIs[Union[str, os.PathLike]]:
return isinstance(name_or_buffer, (str, os.PathLike)) return isinstance(name_or_buffer, (str, os.PathLike))

View File

@ -4,7 +4,7 @@ import contextlib
import warnings import warnings
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set, Union, Protocol, Tuple, Sequence, overload, Deque from typing import Any, Dict, List, Optional, Set, Union, Protocol, Tuple, Sequence, overload, Deque
from typing_extensions import TypeGuard from typing_extensions import TypeIs
from collections import deque from collections import deque
import torch import torch
@ -354,7 +354,7 @@ class TensorWithFlatten(Protocol):
def is_traceable_wrapper_subclass(t: object) -> TypeGuard[TensorWithFlatten]: def is_traceable_wrapper_subclass(t: object) -> TypeIs[TensorWithFlatten]:
""" """
Returns whether or not a tensor subclass that implements __torch_dispatch__ Returns whether or not a tensor subclass that implements __torch_dispatch__
is 'traceable' with torch.compile. is 'traceable' with torch.compile.

View File

@ -17,7 +17,7 @@ from typing import (
TypeVar, TypeVar,
Union, Union,
) )
from typing_extensions import TypeGuard from typing_extensions import TypeIs
import sympy import sympy
from sympy.logic.boolalg import Boolean as SympyBoolean, BooleanAtom from sympy.logic.boolalg import Boolean as SympyBoolean, BooleanAtom
@ -97,11 +97,11 @@ def sympy_generic_le(lower, upper):
return not (lower and not upper) return not (lower and not upper)
def vr_is_bool(vr: ValueRanges[_T]) -> TypeGuard[ValueRanges[SympyBoolean]]: def vr_is_bool(vr: ValueRanges[_T]) -> TypeIs[ValueRanges[SympyBoolean]]:
return vr.is_bool return vr.is_bool
def vr_is_expr(vr: ValueRanges[_T]) -> TypeGuard[ValueRanges[sympy.Expr]]: def vr_is_expr(vr: ValueRanges[_T]) -> TypeIs[ValueRanges[sympy.Expr]]:
return not vr.is_bool return not vr.is_bool