PEP585 update - mostly toplevels (#145178)

See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145178
Approved by: https://github.com/bobrenjc93
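
For context, PEP 585 (Python 3.9+) lets the builtin container types be used directly as generics, so the aliased typing imports (`_Dict`, `_Tuple`, `_Set`, `_Type`) can be dropped in favor of `dict`, `tuple`, `set`, and `type`, while `Optional`, `Union`, `Callable`, and `Any` are still imported from `typing` in this diff. A minimal sketch of the before/after pattern (the `example_fn` name is illustrative only, not from this PR):

from typing import Dict as _Dict, Optional as _Optional, Tuple as _Tuple

# Before: builtin containers spelled via typing aliases
def example_fn(opts: _Optional[_Dict[str, int]]) -> _Tuple[int, ...]:
    return tuple((opts or {}).values())

from typing import Optional as _Optional

# After (PEP 585): builtin generics; Optional/Union/Callable remain in typing
def example_fn(opts: _Optional[dict[str, int]]) -> tuple[int, ...]:
    return tuple((opts or {}).values())
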
This commit is contained in:
Aaron Orenstein
2025-01-21 13:42:12 -08:00
committed by PyTorch MergeBot
parent 1ce533867f
commit f2cfe8b59f
39 changed files with 356 additions and 386 deletions

@@ -24,13 +24,9 @@ import threading
from typing import (
Any as _Any,
Callable as _Callable,
-Dict as _Dict,
get_origin as _get_origin,
Optional as _Optional,
overload as _overload,
-Set as _Set,
-Tuple as _Tuple,
-Type as _Type,
TYPE_CHECKING,
TypeVar as _TypeVar,
Union as _Union,
@@ -337,7 +333,7 @@ def _load_global_deps() -> None:
except OSError as err:
# Can only happen for wheel with cuda libs as PYPI deps
# As PyTorch is not purelib, but nvidia-*-cu12 is
-cuda_libs: _Dict[str, str] = {
+cuda_libs: dict[str, str] = {
"cublas": "libcublas.so.*[0-9]",
"cudnn": "libcudnn.so.*[0-9]",
"cuda_nvrtc": "libnvrtc.so.*[0-9]",
@@ -586,7 +582,7 @@ class SymInt:
# https://github.com/arogozhnikov/einops/blob/6181e1e95dc58c00a3143c1726da1c6ee0463164/einops/einops.py#L237
# return hash(builtins.int(self))
-def as_integer_ratio(self) -> _Tuple["SymInt", builtins.int]:
+def as_integer_ratio(self) -> tuple["SymInt", builtins.int]:
"""Represent this int as an exact integer ratio"""
return self, 1
@@ -698,7 +694,7 @@ class SymFloat:
"""Return True if the float is an integer."""
raise TypeError("type stub not overridden")
-def as_integer_ratio(self) -> _Tuple[builtins.int, builtins.int]:
+def as_integer_ratio(self) -> tuple[builtins.int, builtins.int]:
"""Represent this float as an exact integer ratio"""
return builtins.float(self).as_integer_ratio()
@@ -857,22 +853,22 @@ def sym_max(a, b):
assert isinstance(a, all_types), type(a)
assert isinstance(b, all_types), type(b)
if isinstance(a, float_types) or isinstance(b, float_types):
-return builtins.float(builtins.max(a, b))
+return builtins.float(builtins.max(a, b)) # type: ignore[call-overload]
else:
-return builtins.max(a, b)
+return builtins.max(a, b) # type: ignore[call-overload]
-def __all_and_float_types() -> _Tuple[_Tuple[_Type, ...], _Tuple[_Type, ...]]:
+def __all_and_float_types() -> tuple[tuple[type, ...], tuple[type, ...]]:
try:
import numpy as np
-all_types: _Tuple[_Type, ...] = (
+all_types: tuple[type, ...] = (
np.integer,
np.floating,
builtins.int,
builtins.float,
)
-float_types: _Tuple[_Type, ...] = (np.floating, builtins.float)
+float_types: tuple[type, ...] = (np.floating, builtins.float)
except ModuleNotFoundError:
all_types = (builtins.int, builtins.float)
float_types = (builtins.float,)
@@ -894,9 +890,9 @@ def sym_min(a, b):
assert isinstance(a, all_types), type(a)
assert isinstance(b, all_types), type(b)
if isinstance(a, float_types) or isinstance(b, float_types):
-return builtins.float(builtins.min(a, b))
+return builtins.float(builtins.min(a, b)) # type: ignore[call-overload]
else:
-return builtins.min(a, b)
+return builtins.min(a, b) # type: ignore[call-overload]
def sym_sum(args):
@@ -1204,7 +1200,7 @@ def set_default_device(
_GLOBAL_DEVICE_CONTEXT.device_context = device_context
-def set_default_tensor_type(t: _Union[_Type["torch.Tensor"], str], /) -> None:
+def set_default_tensor_type(t: _Union[type["torch.Tensor"], str], /) -> None:
r"""
.. warning::
@@ -2007,7 +2003,7 @@ class QUInt2x4Storage(_LegacyStorage):
return torch.quint2x4
-_storage_classes: _Set[_Type[_Union[TypedStorage, UntypedStorage]]] = {
+_storage_classes: set[type[_Union[TypedStorage, UntypedStorage]]] = {
UntypedStorage,
DoubleStorage,
FloatStorage,
@@ -2030,7 +2026,7 @@ _storage_classes: _Set[_Type[_Union[TypedStorage, UntypedStorage]]] = {
}
# The _tensor_classes set is initialized by the call to initialize_python_bindings.
-_tensor_classes: _Set[_Type["torch.Tensor"]] = set()
+_tensor_classes: set[type["torch.Tensor"]] = set()
# If you edit these imports, please update torch/__init__.py.in as well
from torch import amp as amp, random as random, serialization as serialization
@@ -2282,7 +2278,7 @@ class _TorchCompileInductorWrapper:
def __init__(self, mode, options, dynamic):
from torch._inductor.compiler_bisector import CompilerBisector
-self.config: _Dict[str, _Any] = {}
+self.config: dict[str, _Any] = {}
self.dynamic = dynamic
self.apply_mode(mode)
self.apply_options(options)
@@ -2309,13 +2305,13 @@ class _TorchCompileInductorWrapper:
self.apply_options(list_mode_options(mode, self.dynamic))
-def apply_options(self, options: _Optional[_Dict[str, _Any]]):
+def apply_options(self, options: _Optional[dict[str, _Any]]):
if not options:
return
from torch._inductor import config
-current_config: _Dict[str, _Any] = config.get_config_copy()
+current_config: dict[str, _Any] = config.get_config_copy()
for key, val in options.items():
attr_name = key.replace("-", "_")
@@ -2403,7 +2399,7 @@ def compile(
dynamic: _Optional[builtins.bool] = None,
backend: _Union[str, _Callable] = "inductor",
mode: _Union[str, None] = None,
-options: _Optional[_Dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
+options: _Optional[dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
disable: builtins.bool = False,
) -> _Callable[_InputT, _RetT]: ...
@@ -2416,7 +2412,7 @@ def compile(
dynamic: _Optional[builtins.bool] = None,
backend: _Union[str, _Callable] = "inductor",
mode: _Union[str, None] = None,
-options: _Optional[_Dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
+options: _Optional[dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
disable: builtins.bool = False,
) -> _Callable[[_Callable[_InputT, _RetT]], _Callable[_InputT, _RetT]]: ...
@@ -2428,7 +2424,7 @@ def compile(
dynamic: _Optional[builtins.bool] = None,
backend: _Union[str, _Callable] = "inductor",
mode: _Union[str, None] = None,
-options: _Optional[_Dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
+options: _Optional[dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
disable: builtins.bool = False,
) -> _Union[
_Callable[[_Callable[_InputT, _RetT]], _Callable[_InputT, _RetT]],
@@ -2624,7 +2620,7 @@ if not _running_with_deploy():
class _TritonLibrary:
lib = torch.library.Library("triton", "DEF")
-ops_table: _Dict[_Tuple[str, str], _Callable] = {}
+ops_table: dict[tuple[str, str], _Callable] = {}
@classmethod
def registerOp(cls, op_key, full_schema, op_impl, dispatch_key):