mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
PEP585 update - mostly toplevels (#145178)
See #145101 for details. Pull Request resolved: https://github.com/pytorch/pytorch/pull/145178 Approved by: https://github.com/bobrenjc93
This commit is contained in:
committed by
PyTorch MergeBot
parent
1ce533867f
commit
f2cfe8b59f
@ -24,13 +24,9 @@ import threading
|
||||
from typing import (
|
||||
Any as _Any,
|
||||
Callable as _Callable,
|
||||
Dict as _Dict,
|
||||
get_origin as _get_origin,
|
||||
Optional as _Optional,
|
||||
overload as _overload,
|
||||
Set as _Set,
|
||||
Tuple as _Tuple,
|
||||
Type as _Type,
|
||||
TYPE_CHECKING,
|
||||
TypeVar as _TypeVar,
|
||||
Union as _Union,
|
||||
@ -337,7 +333,7 @@ def _load_global_deps() -> None:
|
||||
except OSError as err:
|
||||
# Can only happen for wheel with cuda libs as PYPI deps
|
||||
# As PyTorch is not purelib, but nvidia-*-cu12 is
|
||||
cuda_libs: _Dict[str, str] = {
|
||||
cuda_libs: dict[str, str] = {
|
||||
"cublas": "libcublas.so.*[0-9]",
|
||||
"cudnn": "libcudnn.so.*[0-9]",
|
||||
"cuda_nvrtc": "libnvrtc.so.*[0-9]",
|
||||
@ -586,7 +582,7 @@ class SymInt:
|
||||
# https://github.com/arogozhnikov/einops/blob/6181e1e95dc58c00a3143c1726da1c6ee0463164/einops/einops.py#L237
|
||||
# return hash(builtins.int(self))
|
||||
|
||||
def as_integer_ratio(self) -> _Tuple["SymInt", builtins.int]:
|
||||
def as_integer_ratio(self) -> tuple["SymInt", builtins.int]:
|
||||
"""Represent this int as an exact integer ratio"""
|
||||
return self, 1
|
||||
|
||||
@ -698,7 +694,7 @@ class SymFloat:
|
||||
"""Return True if the float is an integer."""
|
||||
raise TypeError("type stub not overridden")
|
||||
|
||||
def as_integer_ratio(self) -> _Tuple[builtins.int, builtins.int]:
|
||||
def as_integer_ratio(self) -> tuple[builtins.int, builtins.int]:
|
||||
"""Represent this float as an exact integer ratio"""
|
||||
return builtins.float(self).as_integer_ratio()
|
||||
|
||||
@ -857,22 +853,22 @@ def sym_max(a, b):
|
||||
assert isinstance(a, all_types), type(a)
|
||||
assert isinstance(b, all_types), type(b)
|
||||
if isinstance(a, float_types) or isinstance(b, float_types):
|
||||
return builtins.float(builtins.max(a, b))
|
||||
return builtins.float(builtins.max(a, b)) # type: ignore[call-overload]
|
||||
else:
|
||||
return builtins.max(a, b)
|
||||
return builtins.max(a, b) # type: ignore[call-overload]
|
||||
|
||||
|
||||
def __all_and_float_types() -> _Tuple[_Tuple[_Type, ...], _Tuple[_Type, ...]]:
|
||||
def __all_and_float_types() -> tuple[tuple[type, ...], tuple[type, ...]]:
|
||||
try:
|
||||
import numpy as np
|
||||
|
||||
all_types: _Tuple[_Type, ...] = (
|
||||
all_types: tuple[type, ...] = (
|
||||
np.integer,
|
||||
np.floating,
|
||||
builtins.int,
|
||||
builtins.float,
|
||||
)
|
||||
float_types: _Tuple[_Type, ...] = (np.floating, builtins.float)
|
||||
float_types: tuple[type, ...] = (np.floating, builtins.float)
|
||||
except ModuleNotFoundError:
|
||||
all_types = (builtins.int, builtins.float)
|
||||
float_types = (builtins.float,)
|
||||
@ -894,9 +890,9 @@ def sym_min(a, b):
|
||||
assert isinstance(a, all_types), type(a)
|
||||
assert isinstance(b, all_types), type(b)
|
||||
if isinstance(a, float_types) or isinstance(b, float_types):
|
||||
return builtins.float(builtins.min(a, b))
|
||||
return builtins.float(builtins.min(a, b)) # type: ignore[call-overload]
|
||||
else:
|
||||
return builtins.min(a, b)
|
||||
return builtins.min(a, b) # type: ignore[call-overload]
|
||||
|
||||
|
||||
def sym_sum(args):
|
||||
@ -1204,7 +1200,7 @@ def set_default_device(
|
||||
_GLOBAL_DEVICE_CONTEXT.device_context = device_context
|
||||
|
||||
|
||||
def set_default_tensor_type(t: _Union[_Type["torch.Tensor"], str], /) -> None:
|
||||
def set_default_tensor_type(t: _Union[type["torch.Tensor"], str], /) -> None:
|
||||
r"""
|
||||
.. warning::
|
||||
|
||||
@ -2007,7 +2003,7 @@ class QUInt2x4Storage(_LegacyStorage):
|
||||
return torch.quint2x4
|
||||
|
||||
|
||||
_storage_classes: _Set[_Type[_Union[TypedStorage, UntypedStorage]]] = {
|
||||
_storage_classes: set[type[_Union[TypedStorage, UntypedStorage]]] = {
|
||||
UntypedStorage,
|
||||
DoubleStorage,
|
||||
FloatStorage,
|
||||
@ -2030,7 +2026,7 @@ _storage_classes: _Set[_Type[_Union[TypedStorage, UntypedStorage]]] = {
|
||||
}
|
||||
|
||||
# The _tensor_classes set is initialized by the call to initialize_python_bindings.
|
||||
_tensor_classes: _Set[_Type["torch.Tensor"]] = set()
|
||||
_tensor_classes: set[type["torch.Tensor"]] = set()
|
||||
|
||||
# If you edit these imports, please update torch/__init__.py.in as well
|
||||
from torch import amp as amp, random as random, serialization as serialization
|
||||
@ -2282,7 +2278,7 @@ class _TorchCompileInductorWrapper:
|
||||
def __init__(self, mode, options, dynamic):
|
||||
from torch._inductor.compiler_bisector import CompilerBisector
|
||||
|
||||
self.config: _Dict[str, _Any] = {}
|
||||
self.config: dict[str, _Any] = {}
|
||||
self.dynamic = dynamic
|
||||
self.apply_mode(mode)
|
||||
self.apply_options(options)
|
||||
@ -2309,13 +2305,13 @@ class _TorchCompileInductorWrapper:
|
||||
|
||||
self.apply_options(list_mode_options(mode, self.dynamic))
|
||||
|
||||
def apply_options(self, options: _Optional[_Dict[str, _Any]]):
|
||||
def apply_options(self, options: _Optional[dict[str, _Any]]):
|
||||
if not options:
|
||||
return
|
||||
|
||||
from torch._inductor import config
|
||||
|
||||
current_config: _Dict[str, _Any] = config.get_config_copy()
|
||||
current_config: dict[str, _Any] = config.get_config_copy()
|
||||
|
||||
for key, val in options.items():
|
||||
attr_name = key.replace("-", "_")
|
||||
@ -2403,7 +2399,7 @@ def compile(
|
||||
dynamic: _Optional[builtins.bool] = None,
|
||||
backend: _Union[str, _Callable] = "inductor",
|
||||
mode: _Union[str, None] = None,
|
||||
options: _Optional[_Dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
|
||||
options: _Optional[dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
|
||||
disable: builtins.bool = False,
|
||||
) -> _Callable[_InputT, _RetT]: ...
|
||||
|
||||
@ -2416,7 +2412,7 @@ def compile(
|
||||
dynamic: _Optional[builtins.bool] = None,
|
||||
backend: _Union[str, _Callable] = "inductor",
|
||||
mode: _Union[str, None] = None,
|
||||
options: _Optional[_Dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
|
||||
options: _Optional[dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
|
||||
disable: builtins.bool = False,
|
||||
) -> _Callable[[_Callable[_InputT, _RetT]], _Callable[_InputT, _RetT]]: ...
|
||||
|
||||
@ -2428,7 +2424,7 @@ def compile(
|
||||
dynamic: _Optional[builtins.bool] = None,
|
||||
backend: _Union[str, _Callable] = "inductor",
|
||||
mode: _Union[str, None] = None,
|
||||
options: _Optional[_Dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
|
||||
options: _Optional[dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
|
||||
disable: builtins.bool = False,
|
||||
) -> _Union[
|
||||
_Callable[[_Callable[_InputT, _RetT]], _Callable[_InputT, _RetT]],
|
||||
@ -2624,7 +2620,7 @@ if not _running_with_deploy():
|
||||
|
||||
class _TritonLibrary:
|
||||
lib = torch.library.Library("triton", "DEF")
|
||||
ops_table: _Dict[_Tuple[str, str], _Callable] = {}
|
||||
ops_table: dict[tuple[str, str], _Callable] = {}
|
||||
|
||||
@classmethod
|
||||
def registerOp(cls, op_key, full_schema, op_impl, dispatch_key):
|
||||
|
Reference in New Issue
Block a user