From 32d4582e021241f3310dfe1bf010f424ba9f05f1 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 21 Oct 2024 19:40:58 +0000 Subject: [PATCH] Revert "[BE]: Update Typeguard to TypeIs for better type inference (#133814)" This reverts commit 16caa8c1b3a02e47b5f52d3c2d40d7931cc427dc. Reverted https://github.com/pytorch/pytorch/pull/133814 on behalf of https://github.com/jeanschmidt due to checking if this will solve inductor errors ([comment](https://github.com/pytorch/pytorch/pull/133814#issuecomment-2427565425)) --- .ci/docker/requirements-ci.txt | 2 +- pyproject.toml | 2 +- requirements.txt | 2 +- setup.py | 2 +- torch/__init__.py | 6 +++--- torch/_dynamo/utils.py | 6 +++--- torch/_inductor/pattern_matcher.py | 8 ++++---- torch/_subclasses/fake_tensor.py | 6 +++--- torch/masked/maskedtensor/core.py | 4 ++-- torch/nn/parameter.pyi | 4 ++-- torch/serialization.py | 4 ++-- torch/utils/_python_dispatch.py | 6 +++--- 12 files changed, 26 insertions(+), 26 deletions(-) diff --git a/.ci/docker/requirements-ci.txt b/.ci/docker/requirements-ci.txt index 420c305e725e..17e6e8525e1c 100644 --- a/.ci/docker/requirements-ci.txt +++ b/.ci/docker/requirements-ci.txt @@ -257,7 +257,7 @@ tb-nightly==2.13.0a20230426 #test that import: # needed by torchgen utils -typing-extensions>=4.10.0 +typing-extensions #Description: type hints for python #Pinned versions: #test that import: diff --git a/pyproject.toml b/pyproject.toml index b03e4c3ae051..1e7def7ec492 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ requires = [ "ninja", "pyyaml", "cmake", - "typing-extensions>=4.10.0", + "typing-extensions", "requests", ] # Use legacy backend to import local packages in setup.py diff --git a/requirements.txt b/requirements.txt index 6ce86e87d892..f22947eb2eb7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,7 +11,7 @@ requests # is required until pytorch build not refactored to work for latest setuptools. setuptools<=72.1.0 types-dataclasses -typing-extensions>=4.10.0 +typing-extensions>=4.8.0 sympy==1.13.1 ; python_version >= "3.9" filelock networkx diff --git a/setup.py b/setup.py index 46ed58f36b94..0bd12aacacfc 100644 --- a/setup.py +++ b/setup.py @@ -1159,7 +1159,7 @@ def main(): ) install_requires = [ "filelock", - "typing-extensions>=4.10.0", + "typing-extensions>=4.8.0", 'setuptools ; python_version >= "3.12"', 'sympy==1.13.1 ; python_version >= "3.9"', "networkx", diff --git a/torch/__init__.py b/torch/__init__.py index 995d90763531..fac732ede0db 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -34,7 +34,7 @@ from typing import ( TypeVar as _TypeVar, Union as _Union, ) -from typing_extensions import ParamSpec as _ParamSpec, TypeIs as _TypeIs +from typing_extensions import ParamSpec as _ParamSpec, TypeGuard as _TypeGuard if TYPE_CHECKING: @@ -1008,7 +1008,7 @@ def typename(obj: _Any, /) -> str: return f"{module}.{qualname}" -def is_tensor(obj: _Any, /) -> _TypeIs["torch.Tensor"]: +def is_tensor(obj: _Any, /) -> _TypeGuard["torch.Tensor"]: r"""Returns True if `obj` is a PyTorch tensor. Note that this function is simply doing ``isinstance(obj, Tensor)``. @@ -1028,7 +1028,7 @@ def is_tensor(obj: _Any, /) -> _TypeIs["torch.Tensor"]: return isinstance(obj, torch.Tensor) -def is_storage(obj: _Any, /) -> _TypeIs[_Union["TypedStorage", "UntypedStorage"]]: +def is_storage(obj: _Any, /) -> _TypeGuard[_Union["TypedStorage", "UntypedStorage"]]: r"""Returns True if `obj` is a PyTorch storage object. Args: diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 7e97d854835d..e59953466ece 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -56,7 +56,7 @@ from typing import ( Union, ValuesView, ) -from typing_extensions import Literal, TypeIs +from typing_extensions import Literal, TypeGuard import torch import torch._functorch.config @@ -569,14 +569,14 @@ class ExactWeakKeyDictionary: @overload -def istype(obj: object, allowed_types: Type[T]) -> TypeIs[T]: +def istype(obj: object, allowed_types: Type[T]) -> TypeGuard[T]: ... @overload def istype( obj: object, allowed_types: Tuple[Type[List[T]], Type[Tuple[T, ...]]] -) -> TypeIs[T]: +) -> TypeGuard[T]: ... diff --git a/torch/_inductor/pattern_matcher.py b/torch/_inductor/pattern_matcher.py index dd65522888f6..36e8765759be 100644 --- a/torch/_inductor/pattern_matcher.py +++ b/torch/_inductor/pattern_matcher.py @@ -70,7 +70,7 @@ from typing import ( TypeVar, Union, ) -from typing_extensions import Self, TypeIs +from typing_extensions import Self, TypeGuard import torch import torch._guards @@ -305,10 +305,10 @@ class FailedMatch(RuntimeError): MatchResult = Union[Match, FailedMatch] -def is_match(m: MatchResult) -> TypeIs[Match]: +def is_match(m: MatchResult) -> TypeGuard[Match]: """ - TypeIs cannot act on `self`. Thus this function exists to let mypy - recognize FailedMatch.__bool__ as a TypeIs. + TypeGuards cannot act on `self`. Thus this function exists to let mypy + recognize FailedMatch.__bool__ as a TypeGuard. """ return bool(m) diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index 4e2bf18d3cfc..f59e6242c982 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -32,7 +32,7 @@ from typing import ( TypeVar, Union, ) -from typing_extensions import Self, TypeIs +from typing_extensions import Self, TypeGuard from weakref import ReferenceType import torch @@ -169,7 +169,7 @@ def get_plain_tensors(subclass: Tensor) -> List[Tensor]: return plain_tensors -def is_fake(x: object) -> TypeIs[Tensor]: +def is_fake(x: object) -> TypeGuard[Tensor]: if isinstance(x, FakeTensor): return True if is_traceable_wrapper_subclass(x): @@ -1213,7 +1213,7 @@ class FakeTensorMode(TorchDispatchMode): # In this case, it's insufficient to test only one FakeTensor: you need # to distinguish between our fake tensor and other fake tensors. That's # what this function does. - def is_our_fake(self, t: object) -> TypeIs[FakeTensor]: + def is_our_fake(self, t: object) -> TypeGuard[FakeTensor]: return isinstance(t, FakeTensor) and t.fake_mode is self # If we should avoid device init. This changes the behavior of various APIs: diff --git a/torch/masked/maskedtensor/core.py b/torch/masked/maskedtensor/core.py index d1cc62032593..22f98b7a3182 100644 --- a/torch/masked/maskedtensor/core.py +++ b/torch/masked/maskedtensor/core.py @@ -3,7 +3,7 @@ import warnings from typing import Any -from typing_extensions import TypeIs +from typing_extensions import TypeGuard import torch from torch.overrides import get_default_nowrap_functions @@ -15,7 +15,7 @@ __all__ = [ ] -def is_masked_tensor(obj: Any, /) -> TypeIs["MaskedTensor"]: +def is_masked_tensor(obj: Any, /) -> TypeGuard["MaskedTensor"]: r"""Returns True if the input is a MaskedTensor, else False Args: diff --git a/torch/nn/parameter.pyi b/torch/nn/parameter.pyi index 6b5afa860b86..9c998fb07f2c 100644 --- a/torch/nn/parameter.pyi +++ b/torch/nn/parameter.pyi @@ -1,5 +1,5 @@ # mypy: allow-untyped-defs -from typing_extensions import TypeIs +from typing_extensions import TypeGuard from torch import device, dtype, Tensor @@ -8,7 +8,7 @@ class Parameter(Tensor): def is_lazy( param: Tensor, -) -> TypeIs[UninitializedParameter | UninitializedBuffer]: ... +) -> TypeGuard[UninitializedParameter | UninitializedBuffer]: ... class UninitializedParameter(Tensor): def __init__(self, data: Tensor = ..., requires_grad: bool = ...) -> None: ... diff --git a/torch/serialization.py b/torch/serialization.py index 3fcf9f27deae..362680d86e22 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -28,7 +28,7 @@ from typing import ( Type, Union, ) -from typing_extensions import TypeAlias, TypeIs +from typing_extensions import TypeAlias, TypeGuard # Python 3.10+ import torch import torch._weights_only_unpickler as _weights_only_unpickler @@ -620,7 +620,7 @@ def storage_to_tensor_type(storage): return getattr(module, storage_type.__name__.replace("Storage", "Tensor")) -def _is_path(name_or_buffer) -> TypeIs[Union[str, os.PathLike]]: +def _is_path(name_or_buffer) -> TypeGuard[Union[str, os.PathLike]]: return isinstance(name_or_buffer, (str, os.PathLike)) diff --git a/torch/utils/_python_dispatch.py b/torch/utils/_python_dispatch.py index 04604bc6ec59..bf0853f1fe49 100644 --- a/torch/utils/_python_dispatch.py +++ b/torch/utils/_python_dispatch.py @@ -4,7 +4,7 @@ import contextlib import warnings from dataclasses import dataclass from typing import Any, Dict, List, Optional, Set, Union, Protocol, Tuple, Sequence, overload, Deque, Type -from typing_extensions import TypeIs +from typing_extensions import TypeGuard from collections import deque import torch @@ -365,7 +365,7 @@ class TensorWithFlatten(Protocol): -def is_traceable_wrapper_subclass(t: object) -> TypeIs[TensorWithFlatten]: +def is_traceable_wrapper_subclass(t: object) -> TypeGuard[TensorWithFlatten]: """ Returns whether or not a tensor subclass that implements __torch_dispatch__ is 'traceable' with torch.compile. @@ -402,7 +402,7 @@ def is_traceable_wrapper_subclass(t: object) -> TypeIs[TensorWithFlatten]: and hasattr(t, "__tensor_unflatten__") ) -def is_traceable_wrapper_subclass_type(t: Type) -> TypeIs[Type[TensorWithFlatten]]: +def is_traceable_wrapper_subclass_type(t: Type) -> TypeGuard[Type[TensorWithFlatten]]: """Same as above, but takes a type argument instead of an instance.""" return (issubclass(t, torch.Tensor) and t != torch.Tensor and hasattr(t, "__tensor_flatten__") and hasattr(t, "__tensor_unflatten__"))