Revert "[BE]: Update Typeguard to TypeIs for better type inference (#133814)"

This reverts commit 16caa8c1b3a02e47b5f52d3c2d40d7931cc427dc.

Reverted https://github.com/pytorch/pytorch/pull/133814 on behalf of https://github.com/jeanschmidt due to checking if this will solve inductor errors ([comment](https://github.com/pytorch/pytorch/pull/133814#issuecomment-2427565425))
This commit is contained in:
PyTorch MergeBot
2024-10-21 19:40:58 +00:00
parent ff2f751bfb
commit 32d4582e02
12 changed files with 26 additions and 26 deletions

View File

@ -4,7 +4,7 @@ import contextlib
import warnings
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set, Union, Protocol, Tuple, Sequence, overload, Deque, Type
from typing_extensions import TypeIs
from typing_extensions import TypeGuard
from collections import deque
import torch
@ -365,7 +365,7 @@ class TensorWithFlatten(Protocol):
def is_traceable_wrapper_subclass(t: object) -> TypeIs[TensorWithFlatten]:
def is_traceable_wrapper_subclass(t: object) -> TypeGuard[TensorWithFlatten]:
"""
Returns whether or not a tensor subclass that implements __torch_dispatch__
is 'traceable' with torch.compile.
@ -402,7 +402,7 @@ def is_traceable_wrapper_subclass(t: object) -> TypeIs[TensorWithFlatten]:
and hasattr(t, "__tensor_unflatten__")
)
def is_traceable_wrapper_subclass_type(t: Type) -> TypeIs[Type[TensorWithFlatten]]:
def is_traceable_wrapper_subclass_type(t: Type) -> TypeGuard[Type[TensorWithFlatten]]:
"""Same as above, but takes a type argument instead of an instance."""
return (issubclass(t, torch.Tensor) and t != torch.Tensor
and hasattr(t, "__tensor_flatten__") and hasattr(t, "__tensor_unflatten__"))