mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-29 19:24:55 +08:00
[BE]: Update Typeguard to TypeIs for better type inference (#133814)
Uses TypeIs instead of TypeGuard for better inference. See https://peps.python.org/pep-0742/ Pull Request resolved: https://github.com/pytorch/pytorch/pull/133814 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
0d4cedaa47
commit
cf60fe53a8
@ -34,7 +34,7 @@ from typing import (
|
|||||||
TypeVar as _TypeVar,
|
TypeVar as _TypeVar,
|
||||||
Union as _Union,
|
Union as _Union,
|
||||||
)
|
)
|
||||||
from typing_extensions import ParamSpec as _ParamSpec, TypeGuard as _TypeGuard
|
from typing_extensions import ParamSpec as _ParamSpec, TypeIs as _TypeIs
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -1000,7 +1000,7 @@ def typename(obj: _Any, /) -> str:
|
|||||||
return f"{module}.{qualname}"
|
return f"{module}.{qualname}"
|
||||||
|
|
||||||
|
|
||||||
def is_tensor(obj: _Any, /) -> _TypeGuard["torch.Tensor"]:
|
def is_tensor(obj: _Any, /) -> _TypeIs["torch.Tensor"]:
|
||||||
r"""Returns True if `obj` is a PyTorch tensor.
|
r"""Returns True if `obj` is a PyTorch tensor.
|
||||||
|
|
||||||
Note that this function is simply doing ``isinstance(obj, Tensor)``.
|
Note that this function is simply doing ``isinstance(obj, Tensor)``.
|
||||||
@ -1020,7 +1020,7 @@ def is_tensor(obj: _Any, /) -> _TypeGuard["torch.Tensor"]:
|
|||||||
return isinstance(obj, torch.Tensor)
|
return isinstance(obj, torch.Tensor)
|
||||||
|
|
||||||
|
|
||||||
def is_storage(obj: _Any, /) -> _TypeGuard[_Union["TypedStorage", "UntypedStorage"]]:
|
def is_storage(obj: _Any, /) -> _TypeIs[_Union["TypedStorage", "UntypedStorage"]]:
|
||||||
r"""Returns True if `obj` is a PyTorch storage object.
|
r"""Returns True if `obj` is a PyTorch storage object.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@ -53,7 +53,7 @@ from typing import (
|
|||||||
Union,
|
Union,
|
||||||
ValuesView,
|
ValuesView,
|
||||||
)
|
)
|
||||||
from typing_extensions import Literal, TypeGuard
|
from typing_extensions import Literal, TypeIs
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch._functorch.config
|
import torch._functorch.config
|
||||||
@ -526,14 +526,14 @@ class ExactWeakKeyDictionary:
|
|||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def istype(obj: object, allowed_types: Type[T]) -> TypeGuard[T]:
|
def istype(obj: object, allowed_types: Type[T]) -> TypeIs[T]:
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def istype(
|
def istype(
|
||||||
obj: object, allowed_types: Tuple[Type[List[T]], Type[Tuple[T, ...]]]
|
obj: object, allowed_types: Tuple[Type[List[T]], Type[Tuple[T, ...]]]
|
||||||
) -> TypeGuard[T]:
|
) -> TypeIs[T]:
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -70,7 +70,7 @@ from typing import (
|
|||||||
TypeVar,
|
TypeVar,
|
||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
from typing_extensions import Self, TypeGuard
|
from typing_extensions import Self, TypeIs
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch._guards
|
import torch._guards
|
||||||
@ -277,10 +277,10 @@ class FailedMatch(RuntimeError):
|
|||||||
MatchResult = Union[Match, FailedMatch]
|
MatchResult = Union[Match, FailedMatch]
|
||||||
|
|
||||||
|
|
||||||
def is_match(m: MatchResult) -> TypeGuard[Match]:
|
def is_match(m: MatchResult) -> TypeIs[Match]:
|
||||||
"""
|
"""
|
||||||
TypeGuards cannot act on `self`. Thus this function exists to let mypy
|
TypeIs cannot act on `self`. Thus this function exists to let mypy
|
||||||
recognize FailedMatch.__bool__ as a TypeGuard.
|
recognize FailedMatch.__bool__ as a TypeIs.
|
||||||
"""
|
"""
|
||||||
return bool(m)
|
return bool(m)
|
||||||
|
|
||||||
|
|||||||
@ -31,7 +31,7 @@ from typing import (
|
|||||||
TypeVar,
|
TypeVar,
|
||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
from typing_extensions import Self, TypeGuard
|
from typing_extensions import Self, TypeIs
|
||||||
from weakref import ReferenceType
|
from weakref import ReferenceType
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -168,7 +168,7 @@ def get_plain_tensors(subclass: Tensor) -> List[Tensor]:
|
|||||||
return plain_tensors
|
return plain_tensors
|
||||||
|
|
||||||
|
|
||||||
def is_fake(x: object) -> TypeGuard[Tensor]:
|
def is_fake(x: object) -> TypeIs[Tensor]:
|
||||||
if isinstance(x, FakeTensor):
|
if isinstance(x, FakeTensor):
|
||||||
return True
|
return True
|
||||||
if is_traceable_wrapper_subclass(x):
|
if is_traceable_wrapper_subclass(x):
|
||||||
@ -1213,7 +1213,7 @@ class FakeTensorMode(TorchDispatchMode):
|
|||||||
# In this case, it's insufficient to test only one FakeTensor: you need
|
# In this case, it's insufficient to test only one FakeTensor: you need
|
||||||
# to distinguish between our fake tensor and other fake tensors. That's
|
# to distinguish between our fake tensor and other fake tensors. That's
|
||||||
# what this function does.
|
# what this function does.
|
||||||
def is_our_fake(self, t: object) -> TypeGuard[FakeTensor]:
|
def is_our_fake(self, t: object) -> TypeIs[FakeTensor]:
|
||||||
return isinstance(t, FakeTensor) and t.fake_mode is self
|
return isinstance(t, FakeTensor) and t.fake_mode is self
|
||||||
|
|
||||||
# If we should avoid device init. This changes the behavior of various APIs:
|
# If we should avoid device init. This changes the behavior of various APIs:
|
||||||
|
|||||||
@ -3,7 +3,7 @@
|
|||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from typing_extensions import TypeGuard
|
from typing_extensions import TypeIs
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.overrides import get_default_nowrap_functions
|
from torch.overrides import get_default_nowrap_functions
|
||||||
@ -15,7 +15,7 @@ __all__ = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def is_masked_tensor(obj: Any, /) -> TypeGuard["MaskedTensor"]:
|
def is_masked_tensor(obj: Any, /) -> TypeIs["MaskedTensor"]:
|
||||||
r"""Returns True if the input is a MaskedTensor, else False
|
r"""Returns True if the input is a MaskedTensor, else False
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
# mypy: allow-untyped-defs
|
# mypy: allow-untyped-defs
|
||||||
from typing_extensions import TypeGuard
|
from typing_extensions import TypeIs
|
||||||
|
|
||||||
from torch import device, dtype, Tensor
|
from torch import device, dtype, Tensor
|
||||||
|
|
||||||
@ -8,7 +8,7 @@ class Parameter(Tensor):
|
|||||||
|
|
||||||
def is_lazy(
|
def is_lazy(
|
||||||
param: Tensor,
|
param: Tensor,
|
||||||
) -> TypeGuard[UninitializedParameter | UninitializedBuffer]: ...
|
) -> TypeIs[UninitializedParameter | UninitializedBuffer]: ...
|
||||||
|
|
||||||
class UninitializedParameter(Tensor):
|
class UninitializedParameter(Tensor):
|
||||||
def __init__(self, data: Tensor = ..., requires_grad: bool = ...) -> None: ...
|
def __init__(self, data: Tensor = ..., requires_grad: bool = ...) -> None: ...
|
||||||
|
|||||||
@ -27,7 +27,7 @@ from typing import (
|
|||||||
Type,
|
Type,
|
||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
from typing_extensions import TypeAlias, TypeGuard # Python 3.10+
|
from typing_extensions import TypeAlias, TypeIs
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch._weights_only_unpickler as _weights_only_unpickler
|
import torch._weights_only_unpickler as _weights_only_unpickler
|
||||||
@ -549,7 +549,7 @@ def storage_to_tensor_type(storage):
|
|||||||
return getattr(module, storage_type.__name__.replace("Storage", "Tensor"))
|
return getattr(module, storage_type.__name__.replace("Storage", "Tensor"))
|
||||||
|
|
||||||
|
|
||||||
def _is_path(name_or_buffer) -> TypeGuard[Union[str, os.PathLike]]:
|
def _is_path(name_or_buffer) -> TypeIs[Union[str, os.PathLike]]:
|
||||||
return isinstance(name_or_buffer, (str, os.PathLike))
|
return isinstance(name_or_buffer, (str, os.PathLike))
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -4,7 +4,7 @@ import contextlib
|
|||||||
import warnings
|
import warnings
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, List, Optional, Set, Union, Protocol, Tuple, Sequence, overload, Deque
|
from typing import Any, Dict, List, Optional, Set, Union, Protocol, Tuple, Sequence, overload, Deque
|
||||||
from typing_extensions import TypeGuard
|
from typing_extensions import TypeIs
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -354,7 +354,7 @@ class TensorWithFlatten(Protocol):
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
def is_traceable_wrapper_subclass(t: object) -> TypeGuard[TensorWithFlatten]:
|
def is_traceable_wrapper_subclass(t: object) -> TypeIs[TensorWithFlatten]:
|
||||||
"""
|
"""
|
||||||
Returns whether or not a tensor subclass that implements __torch_dispatch__
|
Returns whether or not a tensor subclass that implements __torch_dispatch__
|
||||||
is 'traceable' with torch.compile.
|
is 'traceable' with torch.compile.
|
||||||
|
|||||||
@ -17,7 +17,7 @@ from typing import (
|
|||||||
TypeVar,
|
TypeVar,
|
||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
from typing_extensions import TypeGuard
|
from typing_extensions import TypeIs
|
||||||
|
|
||||||
import sympy
|
import sympy
|
||||||
from sympy.logic.boolalg import Boolean as SympyBoolean, BooleanAtom
|
from sympy.logic.boolalg import Boolean as SympyBoolean, BooleanAtom
|
||||||
@ -97,11 +97,11 @@ def sympy_generic_le(lower, upper):
|
|||||||
return not (lower and not upper)
|
return not (lower and not upper)
|
||||||
|
|
||||||
|
|
||||||
def vr_is_bool(vr: ValueRanges[_T]) -> TypeGuard[ValueRanges[SympyBoolean]]:
|
def vr_is_bool(vr: ValueRanges[_T]) -> TypeIs[ValueRanges[SympyBoolean]]:
|
||||||
return vr.is_bool
|
return vr.is_bool
|
||||||
|
|
||||||
|
|
||||||
def vr_is_expr(vr: ValueRanges[_T]) -> TypeGuard[ValueRanges[sympy.Expr]]:
|
def vr_is_expr(vr: ValueRanges[_T]) -> TypeIs[ValueRanges[sympy.Expr]]:
|
||||||
return not vr.is_bool
|
return not vr.is_bool
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user