PEP585 update - mostly toplevels (#145178)

See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145178
Approved by: https://github.com/bobrenjc93
Authored by Aaron Orenstein on 2025-01-21 13:42:12 -08:00; committed by PyTorch MergeBot.
parent 1ce533867f
commit f2cfe8b59f
39 changed files with 356 additions and 386 deletions
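
The change is mechanical throughout: container generics move from the typing aliases to the builtin types standardized by PEP 585 (usable on Python 3.9+), and abstract containers such as Sequence and Iterator move from typing to collections.abc. A minimal before/after sketch of the pattern (illustrative code, not taken from the changed files):

    # Before: container generics spelled via typing aliases
    from typing import Dict, List, Tuple

    def group(pairs: List[Tuple[str, int]]) -> Dict[str, List[int]]:
        out: Dict[str, List[int]] = {}
        for key, val in pairs:
            out.setdefault(key, []).append(val)
        return out

    # After (PEP 585): builtin types subscripted directly,
    # no typing import needed for plain containers
    def group(pairs: list[tuple[str, int]]) -> dict[str, list[int]]:
        out: dict[str, list[int]] = {}
        for key, val in pairs:
            out.setdefault(key, []).append(val)
        return out

Optional, Union, and Callable are left on typing in this diff; only the deprecated container aliases (Dict, List, Tuple, Set, Type, DefaultDict, Sequence, Iterator, ...) are rewritten.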

View File

@ -1,12 +1,12 @@
from enum import Enum
from torch.types import _bool, Tuple
from torch.types import _bool
# Defined in torch/csrc/cuda/shared/cudnn.cpp
is_cuda: _bool
def getRuntimeVersion() -> Tuple[int, int, int]: ...
def getCompileVersion() -> Tuple[int, int, int]: ...
def getRuntimeVersion() -> tuple[int, int, int]: ...
def getCompileVersion() -> tuple[int, int, int]: ...
def getVersionInt() -> int: ...
class RNNMode(int, Enum):

View File

@ -24,13 +24,9 @@ import threading
from typing import (
Any as _Any,
Callable as _Callable,
Dict as _Dict,
get_origin as _get_origin,
Optional as _Optional,
overload as _overload,
Set as _Set,
Tuple as _Tuple,
Type as _Type,
TYPE_CHECKING,
TypeVar as _TypeVar,
Union as _Union,
@ -337,7 +333,7 @@ def _load_global_deps() -> None:
except OSError as err:
# Can only happen for wheel with cuda libs as PYPI deps
# As PyTorch is not purelib, but nvidia-*-cu12 is
cuda_libs: _Dict[str, str] = {
cuda_libs: dict[str, str] = {
"cublas": "libcublas.so.*[0-9]",
"cudnn": "libcudnn.so.*[0-9]",
"cuda_nvrtc": "libnvrtc.so.*[0-9]",
@ -586,7 +582,7 @@ class SymInt:
# https://github.com/arogozhnikov/einops/blob/6181e1e95dc58c00a3143c1726da1c6ee0463164/einops/einops.py#L237
# return hash(builtins.int(self))
def as_integer_ratio(self) -> _Tuple["SymInt", builtins.int]:
def as_integer_ratio(self) -> tuple["SymInt", builtins.int]:
"""Represent this int as an exact integer ratio"""
return self, 1
@ -698,7 +694,7 @@ class SymFloat:
"""Return True if the float is an integer."""
raise TypeError("type stub not overridden")
def as_integer_ratio(self) -> _Tuple[builtins.int, builtins.int]:
def as_integer_ratio(self) -> tuple[builtins.int, builtins.int]:
"""Represent this float as an exact integer ratio"""
return builtins.float(self).as_integer_ratio()
@ -857,22 +853,22 @@ def sym_max(a, b):
assert isinstance(a, all_types), type(a)
assert isinstance(b, all_types), type(b)
if isinstance(a, float_types) or isinstance(b, float_types):
return builtins.float(builtins.max(a, b))
return builtins.float(builtins.max(a, b)) # type: ignore[call-overload]
else:
return builtins.max(a, b)
return builtins.max(a, b) # type: ignore[call-overload]
def __all_and_float_types() -> _Tuple[_Tuple[_Type, ...], _Tuple[_Type, ...]]:
def __all_and_float_types() -> tuple[tuple[type, ...], tuple[type, ...]]:
try:
import numpy as np
all_types: _Tuple[_Type, ...] = (
all_types: tuple[type, ...] = (
np.integer,
np.floating,
builtins.int,
builtins.float,
)
float_types: _Tuple[_Type, ...] = (np.floating, builtins.float)
float_types: tuple[type, ...] = (np.floating, builtins.float)
except ModuleNotFoundError:
all_types = (builtins.int, builtins.float)
float_types = (builtins.float,)
@ -894,9 +890,9 @@ def sym_min(a, b):
assert isinstance(a, all_types), type(a)
assert isinstance(b, all_types), type(b)
if isinstance(a, float_types) or isinstance(b, float_types):
return builtins.float(builtins.min(a, b))
return builtins.float(builtins.min(a, b)) # type: ignore[call-overload]
else:
return builtins.min(a, b)
return builtins.min(a, b) # type: ignore[call-overload]
def sym_sum(args):
@ -1204,7 +1200,7 @@ def set_default_device(
_GLOBAL_DEVICE_CONTEXT.device_context = device_context
def set_default_tensor_type(t: _Union[_Type["torch.Tensor"], str], /) -> None:
def set_default_tensor_type(t: _Union[type["torch.Tensor"], str], /) -> None:
r"""
.. warning::
@ -2007,7 +2003,7 @@ class QUInt2x4Storage(_LegacyStorage):
return torch.quint2x4
_storage_classes: _Set[_Type[_Union[TypedStorage, UntypedStorage]]] = {
_storage_classes: set[type[_Union[TypedStorage, UntypedStorage]]] = {
UntypedStorage,
DoubleStorage,
FloatStorage,
@ -2030,7 +2026,7 @@ _storage_classes: _Set[_Type[_Union[TypedStorage, UntypedStorage]]] = {
}
# The _tensor_classes set is initialized by the call to initialize_python_bindings.
_tensor_classes: _Set[_Type["torch.Tensor"]] = set()
_tensor_classes: set[type["torch.Tensor"]] = set()
# If you edit these imports, please update torch/__init__.py.in as well
from torch import amp as amp, random as random, serialization as serialization
@ -2282,7 +2278,7 @@ class _TorchCompileInductorWrapper:
def __init__(self, mode, options, dynamic):
from torch._inductor.compiler_bisector import CompilerBisector
self.config: _Dict[str, _Any] = {}
self.config: dict[str, _Any] = {}
self.dynamic = dynamic
self.apply_mode(mode)
self.apply_options(options)
@ -2309,13 +2305,13 @@ class _TorchCompileInductorWrapper:
self.apply_options(list_mode_options(mode, self.dynamic))
def apply_options(self, options: _Optional[_Dict[str, _Any]]):
def apply_options(self, options: _Optional[dict[str, _Any]]):
if not options:
return
from torch._inductor import config
current_config: _Dict[str, _Any] = config.get_config_copy()
current_config: dict[str, _Any] = config.get_config_copy()
for key, val in options.items():
attr_name = key.replace("-", "_")
@ -2403,7 +2399,7 @@ def compile(
dynamic: _Optional[builtins.bool] = None,
backend: _Union[str, _Callable] = "inductor",
mode: _Union[str, None] = None,
options: _Optional[_Dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
options: _Optional[dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
disable: builtins.bool = False,
) -> _Callable[_InputT, _RetT]: ...
@ -2416,7 +2412,7 @@ def compile(
dynamic: _Optional[builtins.bool] = None,
backend: _Union[str, _Callable] = "inductor",
mode: _Union[str, None] = None,
options: _Optional[_Dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
options: _Optional[dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
disable: builtins.bool = False,
) -> _Callable[[_Callable[_InputT, _RetT]], _Callable[_InputT, _RetT]]: ...
@ -2428,7 +2424,7 @@ def compile(
dynamic: _Optional[builtins.bool] = None,
backend: _Union[str, _Callable] = "inductor",
mode: _Union[str, None] = None,
options: _Optional[_Dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
options: _Optional[dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
disable: builtins.bool = False,
) -> _Union[
_Callable[[_Callable[_InputT, _RetT]], _Callable[_InputT, _RetT]],
@ -2624,7 +2620,7 @@ if not _running_with_deploy():
class _TritonLibrary:
lib = torch.library.Library("triton", "DEF")
ops_table: _Dict[_Tuple[str, str], _Callable] = {}
ops_table: dict[tuple[str, str], _Callable] = {}
@classmethod
def registerOp(cls, op_key, full_schema, op_impl, dispatch_key):

View File

@ -108,7 +108,7 @@ def custom_op(
# An example usage is FakeTensor: FakeTensor checks if a specific operator
# has an implementation registered via the CustomOp API.
# Indexed by qualname (e.g. aten::foo)
global_registry: typing.Dict[str, "CustomOp"] = {}
global_registry: dict[str, "CustomOp"] = {}
class CustomOp:
@ -136,7 +136,7 @@ class CustomOp:
self.__name__ = None # mypy requires this
# NB: Some of these impls are registered as kernels to DispatchKeys.
# Modifying the _impls dict directly won't do anything in that case.
self._impls: typing.Dict[str, typing.Optional[FuncAndLocation]] = {}
self._impls: dict[str, typing.Optional[FuncAndLocation]] = {}
# See NOTE [CustomOp autograd kernel indirection]
self._registered_autograd_kernel_indirection = False
@ -476,7 +476,7 @@ def validate_schema(schema: FunctionSchema) -> None:
)
def parse_qualname(qualname: str) -> typing.Tuple[str, str]:
def parse_qualname(qualname: str) -> tuple[str, str]:
names = qualname.split("::", 1)
if len(names) != 2:
raise ValueError(f"Expected there to be a namespace in {qualname}, i.e. The "

View File

@ -1,8 +1,8 @@
# mypy: allow-untyped-defs
import itertools
import unittest.mock
from collections.abc import Iterator
from contextlib import contextmanager
from typing import Iterator
import torch
import torch._C

View File

@ -17,13 +17,9 @@ from dataclasses import dataclass
from typing import (
Any,
Callable,
Dict,
Generic,
List,
NamedTuple,
Optional,
Set,
Tuple,
TYPE_CHECKING,
TypeVar,
Union,
@ -260,8 +256,8 @@ class Guard:
create_fn: Callable[[GuardBuilderBase, Guard], None]
# Export only. These values are written to at time of guard check_fn creation.
guard_types: Optional[List[str]] = None
code_list: Optional[List[str]] = None
guard_types: Optional[list[str]] = None
code_list: Optional[list[str]] = None
obj_weakref: Optional[object] = None
guarded_class_weakref: Optional[type] = None
@ -448,8 +444,8 @@ overlapping with any other input, overlapping_sources represent tensors that eit
@dataclasses.dataclass
class StorageOverlap(GuardEnvExpr):
overlapping_sources: List[Source]
non_overlapping_sources: List[Source]
overlapping_sources: list[Source]
non_overlapping_sources: list[Source]
"""
@ -478,7 +474,7 @@ class GuardsCheckpointState:
The GuardCheckpointState - it is the T of Checkpointable[T] for GuardsContext
"""
dynamo_guards: Set[Guard] = set()
dynamo_guards: set[Guard] = set()
def __init__(self, dynamo_guards):
self.dynamo_guards = dynamo_guards
@ -500,7 +496,7 @@ class GuardsCheckpointState:
class ModuleContextCheckpointState:
nn_modules: Dict[str, torch.nn.Module] = {}
nn_modules: dict[str, torch.nn.Module] = {}
def __init__(self, nn_modules):
self.nn_modules = nn_modules
@ -523,7 +519,7 @@ class ModuleContextCheckpointState:
class ModuleContext(Checkpointable[ModuleContextCheckpointState]):
def __init__(self) -> None:
self.nn_modules: Dict[str, Any] = {}
self.nn_modules: dict[str, Any] = {}
def copy_graphstate(self):
return ModuleContextCheckpointState(dict(self.nn_modules))
@ -534,7 +530,7 @@ class ModuleContext(Checkpointable[ModuleContextCheckpointState]):
class GlobalContextCheckpointState:
global_state: Dict[str, Tuple[Callable, ...]] = {}
global_state: dict[str, tuple[Callable, ...]] = {}
def __init__(self, global_states):
self.global_state = global_states
@ -572,7 +568,7 @@ class GlobalContext(Checkpointable[GlobalContextCheckpointState]):
}
def __init__(self) -> None:
self.global_state: Dict[str, Tuple[Callable, ...]] = {}
self.global_state: dict[str, tuple[Callable, ...]] = {}
def copy_graphstate(self):
return GlobalContextCheckpointState(dict(self.global_state))
@ -628,7 +624,7 @@ class GuardsSet:
guard.user_stack = TracingContext.extract_stack()
self.inner.add(guard)
def update(self, *others: Set[Guard]):
def update(self, *others: set[Guard]):
for o in others:
for g in o:
self.add(g, skip=1)
@ -641,7 +637,7 @@ class GuardsSet:
class GuardsContext(Checkpointable[GuardsCheckpointState]):
def __init__(self) -> None:
self.dynamo_guards: GuardsSet = GuardsSet()
self.aotautograd_guards: List[GuardEnvExpr] = []
self.aotautograd_guards: list[GuardEnvExpr] = []
def copy_graphstate(self):
return GuardsCheckpointState(set(self.dynamo_guards.inner))
@ -674,9 +670,9 @@ class HopSubgraphCache:
class InvokeSubgraphCache(HopSubgraphCache):
def __init__(self) -> None:
self.autograd_cache: Dict[str, Callable] = {}
self.proxy_dispatch_cache: Dict[str, Callable] = {}
self.dynamo_identifiers: Dict[str, str] = {}
self.autograd_cache: dict[str, Callable] = {}
self.proxy_dispatch_cache: dict[str, Callable] = {}
self.dynamo_identifiers: dict[str, str] = {}
def add_dynamo_identifier(self, cache_key: str, identifier: str):
self.dynamo_identifiers[cache_key] = identifier
@ -748,7 +744,7 @@ class CompileContext:
self.compile_id: Optional[CompileId] = compile_id
self.attempt = 0
# Verbose ShapeEnv guards produced.
self.shape_env_guards: List[str] = []
self.shape_env_guards: list[str] = []
@staticmethod
def current_compile_id():
@ -816,7 +812,7 @@ class TracingContext:
# careful not to accidentally induce guards on the SymInt if
# you ever do change this in aot_autograd.py; you should check
# on permutations preferentially.)
self.output_strides: Optional[List[Optional[Tuple[int, ...]]]] = None
self.output_strides: Optional[list[Optional[tuple[int, ...]]]] = None
# When this is True, whenever we encounter an int in Dynamo tracing,
# we will (1) force unspec it and (2) force it as a size-like unbacked
# integer. This is currently used when processing certain lists of

View File

@ -20,7 +20,7 @@ import types
import typing
import warnings
import weakref
from typing import (
from typing import ( # noqa: F401 # (Dict, List, Tuple) imported by torch.jit.annotations
Any,
Callable,
Dict,
@ -31,7 +31,6 @@ from typing import (
List,
Optional,
Tuple,
Type,
Union,
)
@ -51,7 +50,7 @@ from torch.futures import Future
IS_PY39_PLUS: Final[bool] = sys.version_info >= (3, 9)
IS_PY310_PLUS: Final[bool] = sys.version_info >= (3, 10)
BuiltinUnionType: Union[Type, Tuple[Type, ...]]
BuiltinUnionType: Union[type, tuple[type, ...]]
if sys.version_info >= (3, 10):
# NOTE: IS_PY310_PLUS doesn't work with mypy.
# cf. https://mypy.readthedocs.io/en/stable/common_issues.html#python-version-and-system-platform-checks
@ -59,7 +58,7 @@ if sys.version_info >= (3, 10):
else:
BuiltinUnionType = () # trick: this makes isinstance short circuit.
LockType: Type
LockType: type
try:
import _thread
@ -71,7 +70,7 @@ except ImportError:
# Wrapper functions that can call either of 2 functions depending on a boolean
# argument
boolean_dispatched: "weakref.WeakKeyDictionary[Callable, Dict[str, Callable]]" = (
boolean_dispatched: "weakref.WeakKeyDictionary[Callable, dict[str, Callable]]" = (
weakref.WeakKeyDictionary()
) # noqa: T484
@ -225,7 +224,7 @@ def createResolutionCallbackFromEnv(lookup_base):
else:
return getattr(module, qualified_name)
def parseNestedExpr(expr, module) -> Tuple[Any, int]:
def parseNestedExpr(expr, module) -> tuple[Any, int]:
i = 0
while i < len(expr) and expr[i] not in (",", "[", "]"):
i += 1
@ -425,7 +424,7 @@ def can_compile_class(cls) -> bool:
return all(has_code)
def get_callable_argument_names(fn) -> List[str]:
def get_callable_argument_names(fn) -> list[str]:
"""
Gets names of all POSITIONAL_OR_KEYWORD arguments for callable `fn`.
Returns an empty list when other types of arguments are present.
@ -957,7 +956,7 @@ def copy_torchscript_modifier(orig, new) -> None:
# so that they can be imported in nn/functional.py without an import cycle
# qualified_name => list[overload_functions]
_overloaded_fns: Dict[str, List[Callable]] = {} # noqa: T484
_overloaded_fns: dict[str, list[Callable]] = {} # noqa: T484
_OVERLOAD_EXAMPLE = """
@ -1042,7 +1041,7 @@ def _clear_fn_overloads(qual_name) -> None:
del _overloaded_fns[qual_name]
def get_class_name_lineno(method) -> Tuple[str, int]:
def get_class_name_lineno(method) -> tuple[str, int]:
current_frame = inspect.currentframe()
# one for the get_class_name call, one for _overload_method call
@ -1068,11 +1067,11 @@ def get_class_name_lineno(method) -> Tuple[str, int]:
# when modules of the same name are in the same file
# qualified_name => class name => list[overload_functions]
_overloaded_methods: Dict[str, Dict[str, List[Callable]]] = {} # noqa: T484
_overloaded_methods: dict[str, dict[str, list[Callable]]] = {} # noqa: T484
# (qualified_name, class name) => class_fileno
_overloaded_method_class_fileno: Dict[Tuple[str, str], int] = {}
_overloaded_method_class_fileno: dict[tuple[str, str], int] = {}
def _overload_method(func):
@ -1324,8 +1323,8 @@ def _get_named_tuple_properties(
def _create_named_tuple(
t,
unqual_name: str,
field_names: List[str],
defaults: Tuple[Any, ...],
field_names: list[str],
defaults: tuple[Any, ...],
):
TupleType = collections.namedtuple(unqual_name, field_names, defaults=defaults) # type: ignore[call-arg, no-redef, misc]
return TupleType(*t)
@ -1487,7 +1486,7 @@ def _isinstance(obj, target_type) -> bool:
class _TensorExtractor(pickle.Pickler):
def __init__(self, *args, tensors: List[torch.Tensor], **kwargs):
def __init__(self, *args, tensors: list[torch.Tensor], **kwargs):
super().__init__(*args, **kwargs)
self.tensors = tensors
@ -1523,7 +1522,7 @@ def _extract_tensors(obj):
It extracts the tensors contained in the given object, through pickling.
"""
tensors: List[torch.Tensor] = []
tensors: list[torch.Tensor] = []
extractor = _TensorExtractor(io.BytesIO(), protocol=-1, tensors=tensors)
extractor.dump(obj)
return tensors

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
"""Various linear algebra utility methods for internal use."""
from typing import Optional, Tuple
from typing import Optional
import torch
from torch import Tensor
@ -57,7 +57,7 @@ def basis(A):
return torch.linalg.qr(A).Q
def symeig(A: Tensor, largest: Optional[bool] = False) -> Tuple[Tensor, Tensor]:
def symeig(A: Tensor, largest: Optional[bool] = False) -> tuple[Tensor, Tensor]:
"""Return eigenpairs of A with specified ordering."""
if largest is None:
largest = False
@ -79,7 +79,7 @@ def matrix_rank(input, tol=None, symmetric=False, *, out=None) -> Tensor:
)
def solve(input: Tensor, A: Tensor, *, out=None) -> Tuple[Tensor, Tensor]:
def solve(input: Tensor, A: Tensor, *, out=None) -> tuple[Tensor, Tensor]:
raise RuntimeError(
"This function was deprecated since version 1.9 and is now removed. "
"`torch.solve` is deprecated in favor of `torch.linalg.solve`. "
@ -91,7 +91,7 @@ def solve(input: Tensor, A: Tensor, *, out=None) -> Tuple[Tensor, Tensor]:
)
def lstsq(input: Tensor, A: Tensor, *, out=None) -> Tuple[Tensor, Tensor]:
def lstsq(input: Tensor, A: Tensor, *, out=None) -> tuple[Tensor, Tensor]:
raise RuntimeError(
"This function was deprecated since version 1.9 and is now removed. "
"`torch.lstsq` is deprecated in favor of `torch.linalg.lstsq`.\n"
@ -114,7 +114,7 @@ def _symeig(
upper=True,
*,
out=None,
) -> Tuple[Tensor, Tensor]:
) -> tuple[Tensor, Tensor]:
raise RuntimeError(
"This function was deprecated since version 1.9 and is now removed. "
"The default behavior has changed from using the upper triangular portion of the matrix by default "
@ -135,7 +135,7 @@ def eig(
*,
e=None,
v=None,
) -> Tuple[Tensor, Tensor]:
) -> tuple[Tensor, Tensor]:
raise RuntimeError(
"This function was deprecated since version 1.9 and is now removed. "
"`torch.linalg.eig` returns complex tensors of dtype `cfloat` or `cdouble` rather than real tensors "

View File

@ -3,7 +3,7 @@
# Author: Pearu Peterson
# Created: February 2020
from typing import Dict, Optional, Tuple
from typing import Optional
import torch
from torch import _linalg_utils as _utils, Tensor
@ -268,10 +268,10 @@ class LOBPCGAutogradFunction(torch.autograd.Function):
largest: Optional[bool] = None,
method: Optional[str] = None,
tracker: None = None,
ortho_iparams: Optional[Dict[str, int]] = None,
ortho_fparams: Optional[Dict[str, float]] = None,
ortho_bparams: Optional[Dict[str, bool]] = None,
) -> Tuple[Tensor, Tensor]:
ortho_iparams: Optional[dict[str, int]] = None,
ortho_fparams: Optional[dict[str, float]] = None,
ortho_bparams: Optional[dict[str, bool]] = None,
) -> tuple[Tensor, Tensor]:
# makes sure that input is contiguous for efficiency.
# Note: autograd does not support dense gradients for sparse input yet.
A = A.contiguous() if (not A.is_sparse) else A
@ -354,10 +354,10 @@ def lobpcg(
largest: Optional[bool] = None,
method: Optional[str] = None,
tracker: None = None,
ortho_iparams: Optional[Dict[str, int]] = None,
ortho_fparams: Optional[Dict[str, float]] = None,
ortho_bparams: Optional[Dict[str, bool]] = None,
) -> Tuple[Tensor, Tensor]:
ortho_iparams: Optional[dict[str, int]] = None,
ortho_fparams: Optional[dict[str, float]] = None,
ortho_bparams: Optional[dict[str, bool]] = None,
) -> tuple[Tensor, Tensor]:
"""Find the k largest (or smallest) eigenvalues and the corresponding
eigenvectors of a symmetric positive definite generalized
eigenvalue problem using matrix-free LOBPCG methods.
@ -591,10 +591,10 @@ def _lobpcg(
largest: Optional[bool] = None,
method: Optional[str] = None,
tracker: None = None,
ortho_iparams: Optional[Dict[str, int]] = None,
ortho_fparams: Optional[Dict[str, float]] = None,
ortho_bparams: Optional[Dict[str, bool]] = None,
) -> Tuple[Tensor, Tensor]:
ortho_iparams: Optional[dict[str, int]] = None,
ortho_fparams: Optional[dict[str, float]] = None,
ortho_bparams: Optional[dict[str, bool]] = None,
) -> tuple[Tensor, Tensor]:
# A must be square:
assert A.shape[-2] == A.shape[-1], A.shape
if B is not None:
@ -697,9 +697,9 @@ class LOBPCG:
B: Optional[Tensor],
X: Tensor,
iK: Optional[Tensor],
iparams: Dict[str, int],
fparams: Dict[str, float],
bparams: Dict[str, bool],
iparams: dict[str, int],
fparams: dict[str, float],
bparams: dict[str, bool],
method: str,
tracker: None,
) -> None:
@ -720,10 +720,10 @@ class LOBPCG:
self.E = torch.zeros((n,), dtype=X.dtype, device=X.device)
self.R = torch.zeros((m, n), dtype=X.dtype, device=X.device)
self.S = torch.zeros((m, 3 * n), dtype=X.dtype, device=X.device)
self.tvars: Dict[str, Tensor] = {}
self.ivars: Dict[str, int] = {"istep": 0}
self.fvars: Dict[str, float] = {"_": 0.0}
self.bvars: Dict[str, bool] = {"_": False}
self.tvars: dict[str, Tensor] = {}
self.ivars: dict[str, int] = {"istep": 0}
self.fvars: dict[str, float] = {"_": 0.0}
self.bvars: dict[str, bool] = {"_": False}
def __str__(self):
lines = ["LOPBCG:"]

View File

@ -14,7 +14,7 @@ import tempfile
import time
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from typing import Any, Callable, Optional, Union
from weakref import WeakSet
import torch._logging.structured
@ -53,37 +53,37 @@ class LogRegistry:
# Note: this only contains loggers registered
# from register_log
# e.g. "dynamo" -> "torch._dynamo"
log_alias_to_log_qnames: Dict[str, List[str]] = field(default_factory=dict)
log_alias_to_log_qnames: dict[str, list[str]] = field(default_factory=dict)
# artifact logger qualified names,
# this is populated lazily, as calls to getArtifactLogger
# currently formatted as <module>.__<artifact_name>
# e.g. "torch._dynamo.convert_frame.__guards"
artifact_log_qnames: Set[str] = field(default_factory=set)
artifact_log_qnames: set[str] = field(default_factory=set)
# child logs of registered logs if specified via open
# registration by the user (ie placing "torch._dynamo.output_graph" in the env var)
# these need to be tracked so their levels can be reset properly
# e.g. "torch._dynamo.output_graph"
child_log_qnames: Set[str] = field(default_factory=set)
child_log_qnames: set[str] = field(default_factory=set)
# artifact names, populated by register_artifact
# e.g. "guards"
artifact_names: Set[str] = field(default_factory=set)
artifact_names: set[str] = field(default_factory=set)
# Artifacts that should be visible by default in the error message
visible_artifacts: Set[str] = field(default_factory=set)
visible_artifacts: set[str] = field(default_factory=set)
# A short description of each artifact
artifact_descriptions: Dict[str, str] = field(default_factory=dict)
artifact_descriptions: dict[str, str] = field(default_factory=dict)
# artifacts which are not displayed unless explicitly named in the
# settings. Ex. output_code is NOT displayed even if the inductor
# log level is set to DEBUG. It must be explicitly named in the settings
off_by_default_artifact_names: Set[str] = field(default_factory=set)
off_by_default_artifact_names: set[str] = field(default_factory=set)
# logging format string for artifacts
artifact_log_formatters: Dict[str, logging.Formatter] = field(default_factory=dict)
artifact_log_formatters: dict[str, logging.Formatter] = field(default_factory=dict)
def is_artifact(self, name):
return name in self.artifact_names
@ -92,7 +92,7 @@ class LogRegistry:
return alias in self.log_alias_to_log_qnames
# register a log with an alias
def register_log(self, alias, log_qnames: Union[str, List[str]]):
def register_log(self, alias, log_qnames: Union[str, list[str]]):
if isinstance(log_qnames, str):
log_qnames = [log_qnames]
self.log_alias_to_log_qnames[alias] = log_qnames
@ -124,7 +124,7 @@ class LogRegistry:
self.child_log_qnames.add(log_qname)
# flattens all the qnames together (TODO: consider memoizing?)
def get_log_qnames(self) -> Set[str]:
def get_log_qnames(self) -> set[str]:
return {
qname
for qnames in self.log_alias_to_log_qnames.values()
@ -144,10 +144,10 @@ class LogRegistry:
@dataclass
class LogState:
# qualified log names -> currently set log level
log_qname_to_level: Dict[str, str] = field(default_factory=dict)
log_qname_to_level: dict[str, str] = field(default_factory=dict)
# the set of currently enabled artifacts
artifact_names: Set[str] = field(default_factory=set)
artifact_names: set[str] = field(default_factory=set)
def enable_artifact(self, artifact_name):
self.artifact_names.add(artifact_name)
@ -235,7 +235,7 @@ def set_logs(
fusion: bool = False,
overlap: bool = False,
export: Optional[int] = None,
modules: Optional[Dict[str, Union[int, bool]]] = None,
modules: Optional[dict[str, Union[int, bool]]] = None,
cudagraphs: bool = False,
sym_node: bool = False,
compiled_autograd: bool = False,
@ -1105,7 +1105,7 @@ class LazyString:
# Logs the time it takes to do structured logging by frame/compile id
# key is always {frame_id}_{frame_compile_id}
structured_logging_overhead: Dict[str, float] = defaultdict(float)
structured_logging_overhead: dict[str, float] = defaultdict(float)
def add_structured_logging_overhead(time_spent: float) -> None:
@ -1157,7 +1157,7 @@ def trace_structured(
name: str,
# NB: metadata expected to be dict so adding more info is forward compatible
# Tuple[str, int] is a special case for string interning
metadata_fn: Callable[[], Union[Dict[str, Any], Tuple[str, int]]] = dict,
metadata_fn: Callable[[], Union[dict[str, Any], tuple[str, int]]] = dict,
*,
payload_fn: Callable[[], Optional[Union[str, object]]] = lambda: None,
suppress_context: bool = False,
@ -1189,7 +1189,7 @@ def trace_structured(
# are handlers instead of checking the log level
if trace_log.handlers:
start_time = time.time_ns()
record: Dict[str, object] = {}
record: dict[str, object] = {}
record[name] = metadata_fn()
if not suppress_context:
# TODO: Actually, the rank probably should just be emitted once at
@ -1256,7 +1256,7 @@ def dtrace_structured(
name: str,
# NB: metadata expected to be dict so adding more info is forward compatible
# Tuple[str, int] is a special case for string interning
metadata_fn: Callable[[], Union[Dict[str, Any], Tuple[str, int]]] = dict,
metadata_fn: Callable[[], Union[dict[str, Any], tuple[str, int]]] = dict,
*,
payload_fn: Callable[[], Optional[Union[str, object]]] = lambda: None,
suppress_context: bool = False,

View File

@ -1,4 +1,4 @@
from typing import Callable, List, Union
from typing import Callable, Union
from typing_extensions import TypeAlias
@ -8,7 +8,7 @@ try:
)
except ImportError:
TAtom: TypeAlias = Union[int, float, bool, str]
TField: TypeAlias = Union[TAtom, List[TAtom]]
TField: TypeAlias = Union[TAtom, list[TAtom]]
TLazyField: TypeAlias = Union[TField, Callable[[], TField]]
def make_scribe_logger(name: str, thrift_src: str) -> Callable[..., None]:

View File

@ -3,15 +3,16 @@ Utilities for converting data types into structured JSON for dumping.
"""
import traceback
from typing import Any, Dict, List, Sequence, Set
from collections.abc import Sequence
from typing import Any
import torch._logging._internal
INTERN_TABLE: Dict[str, int] = {}
INTERN_TABLE: dict[str, int] = {}
DUMPED_FILES: Set[str] = set()
DUMPED_FILES: set[str] = set()
def intern_string(s: str) -> int:
@ -42,7 +43,7 @@ def dump_file(filename: str) -> None:
)
def from_traceback(tb: Sequence[traceback.FrameSummary]) -> List[Dict[str, Any]]:
def from_traceback(tb: Sequence[traceback.FrameSummary]) -> list[dict[str, Any]]:
# dict naming convention here coincides with
# python/combined_traceback.cpp
r = [

View File

@ -2,7 +2,7 @@
__all__ = ["svd_lowrank", "pca_lowrank"]
from typing import Optional, Tuple
from typing import Optional
import torch
from torch import _linalg_utils as _utils, Tensor
@ -88,7 +88,7 @@ def svd_lowrank(
q: Optional[int] = 6,
niter: Optional[int] = 2,
M: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor]:
) -> tuple[Tensor, Tensor, Tensor]:
r"""Return the singular value decomposition ``(U, S, V)`` of a matrix,
batches of matrices, or a sparse matrix :math:`A` such that
:math:`A \approx U \operatorname{diag}(S) V^{\text{H}}`. In case :math:`M` is given, then
@ -152,7 +152,7 @@ def _svd_lowrank(
q: Optional[int] = 6,
niter: Optional[int] = 2,
M: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor]:
) -> tuple[Tensor, Tensor, Tensor]:
# Algorithm 5.1 in Halko et al., 2009
q = 6 if q is None else q
@ -186,7 +186,7 @@ def pca_lowrank(
q: Optional[int] = None,
center: bool = True,
niter: int = 2,
) -> Tuple[Tensor, Tensor, Tensor]:
) -> tuple[Tensor, Tensor, Tensor]:
r"""Performs linear Principal Component Analysis (PCA) on a low-rank
matrix, batches of such matrices, or sparse matrix.

View File

@ -1,8 +1,9 @@
# mypy: allow-untyped-defs
import math
from collections.abc import Sequence
from enum import Enum
from functools import wraps
from typing import Callable, List, Optional, Sequence, Tuple, TypeVar, Union
from typing import Callable, Optional, TypeVar, Union
from typing_extensions import ParamSpec
import torch
@ -1054,7 +1055,7 @@ def linalg_ldl_factor_ex_meta(
*,
hermitian: bool = False,
check_errors: bool = False,
) -> Tuple[Tensor, Tensor, Tensor]:
) -> tuple[Tensor, Tensor, Tensor]:
squareCheckInputs(self, "torch.linalg.ldl_factor_ex")
checkFloatingOrComplex(self, "torch.linalg.ldl_factor_ex")
LD = torch.empty_strided(
@ -1114,7 +1115,7 @@ def linalg_ldl_solve_meta(
@register_meta([aten.linalg_lu.default, aten.linalg_lu.out])
@out_wrapper("P", "L", "U")
def linalg_lu_meta(A: Tensor, *, pivot: bool = True) -> Tuple[Tensor, Tensor, Tensor]:
def linalg_lu_meta(A: Tensor, *, pivot: bool = True) -> tuple[Tensor, Tensor, Tensor]:
torch._check(
A.ndim >= 2,
lambda: f"linalg.lu: Expected tensor with 2 or more dimensions. Got size: {A.shape} instead",
@ -1147,7 +1148,7 @@ def linalg_lu_factor_ex_meta(
*,
pivot: bool = True,
check_errors: bool = False,
) -> Tuple[Tensor, Tensor, Tensor]:
) -> tuple[Tensor, Tensor, Tensor]:
torch._check(
A.ndim >= 2,
lambda: f"torch.lu_factor: Expected tensor with 2 or more dimensions. Got size: {A.shape} instead",
@ -1240,7 +1241,7 @@ def lu_unpack_meta(
pivots: Tensor,
unpack_data: bool = True,
unpack_pivots: bool = True,
) -> Tuple[Tensor, Tensor, Tensor]:
) -> tuple[Tensor, Tensor, Tensor]:
torch._check(
LU.ndim >= 2,
lambda: f"torch.lu_unpack: Expected tensor with 2 or more dimensions. Got size: {LU.shape} instead",
@ -1275,7 +1276,7 @@ def lu_unpack_meta(
# parse the "mode" param in linalg_qr: return a tuple of bools (compute_q, reduced)
def _parse_qr_mode(mode: str) -> Tuple[bool, bool]:
def _parse_qr_mode(mode: str) -> tuple[bool, bool]:
if mode == "reduced":
compute_q = True
reduced = True
@ -1298,7 +1299,7 @@ def _parse_qr_mode(mode: str) -> Tuple[bool, bool]:
@register_meta([aten.linalg_qr.default, aten.linalg_qr.out])
@out_wrapper("Q", "R")
def linalg_qr_meta(A: Tensor, mode: str = "reduced") -> Tuple[Tensor, Tensor]:
def linalg_qr_meta(A: Tensor, mode: str = "reduced") -> tuple[Tensor, Tensor]:
checkIsMatrix(A, "linalg.qr")
checkFloatingOrComplex(A, "linalg.qr")
@ -1326,7 +1327,7 @@ def linalg_qr_meta(A: Tensor, mode: str = "reduced") -> Tuple[Tensor, Tensor]:
@register_meta([aten._linalg_slogdet.default, aten._linalg_slogdet.sign])
@out_wrapper("sign", "logabsdet", "LU", "pivots")
def _linalg_slogdet(A: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
def _linalg_slogdet(A: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]:
squareCheckInputs(A, "linalg.slogdet")
checkFloatingOrComplex(A, "linalg.slogdet", False)
shape = A.shape
@ -1385,7 +1386,7 @@ def _linalg_svd_meta(
def _linalg_broadcast_batch_dims(
arg1: Tensor,
arg2: Tensor,
) -> Tuple[List[int], List[int]]:
) -> tuple[list[int], list[int]]:
# broadcast the batch dimensions of arg1 and arg2.
arg1_batch_sizes = arg1.shape[:-2]
arg2_batch_sizes = arg2.shape[:-2]
@ -1403,7 +1404,7 @@ def _linalg_broadcast_batch_dims_name(
arg1: Tensor,
arg2: Tensor,
name: Optional[str],
) -> Tuple[Tensor, Tensor]:
) -> tuple[Tensor, Tensor]:
# If there's no name we assume we don't want to check the errors
if name:
linearSolveCheckInputs(arg1, arg2, name)
@ -1438,7 +1439,7 @@ def _linalg_solve_ex(
LU: Optional[Tensor] = None,
pivots: Optional[Tensor] = None,
info: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
) -> tuple[Tensor, Tensor, Tensor, Tensor]:
checkFloatingOrComplex(A, "linalg.solve")
torch._check(
A.dtype == B.dtype,
@ -1520,7 +1521,7 @@ def triangular_solve_meta(
upper: bool = True,
transpose: bool = False,
unitriangular: bool = False,
) -> Tuple[Tensor, Tensor]:
) -> tuple[Tensor, Tensor]:
torch._check(
self.ndim >= 2,
lambda: (
@ -2159,12 +2160,12 @@ def device_hint(tensor) -> "str":
def calc_conv_nd_return_shape(
input_tensor: torch.Tensor,
weight: torch.Tensor,
stride: Union[List[int], int],
padding: Union[List[int], int],
dilation: Union[List[int], int],
stride: Union[list[int], int],
padding: Union[list[int], int],
dilation: Union[list[int], int],
is_transposed: bool,
groups: int,
output_padding: Optional[Union[List[int], int]] = None,
output_padding: Optional[Union[list[int], int]] = None,
):
def _formula(ln: int, p: int, d: int, k: int, s: int) -> int:
"""
@ -2227,7 +2228,7 @@ def calc_conv_nd_return_shape(
elif len(dilation) == 1:
dilation = [dilation[0]] * len(dims)
output_padding_list: Optional[List[int]] = None
output_padding_list: Optional[list[int]] = None
if output_padding:
if isinstance(output_padding, IntLike):
output_padding_list = [output_padding] * len(dims)
@ -2310,11 +2311,11 @@ def meta_conv(
input_tensor: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
stride: List[int],
padding: List[int],
dilation: List[int],
stride: list[int],
padding: list[int],
dilation: list[int],
is_transposed: bool,
output_padding: List[int],
output_padding: list[int],
groups: int,
):
def pick_memory_format():
@ -3176,7 +3177,7 @@ def meta_index_Tensor(self, indices):
torch._check(bool(indices), lambda: "at least one index must be provided")
# aten::index is the internal advanced indexing implementation
# checkIndexTensorTypes and expandTensors
result: List[Optional[Tensor]] = []
result: list[Optional[Tensor]] = []
for i, index in enumerate(indices):
if index is not None:
torch._check(
@ -3257,9 +3258,9 @@ def meta_index_Tensor(self, indices):
# to put the input and indices in a form so that TensorIterator can
# take them. If we write a ref for this, probably that logic should
# get implemented
before_shape: List[int] = []
after_shape: List[int] = []
replacement_shape: List[int] = []
before_shape: list[int] = []
after_shape: list[int] = []
replacement_shape: list[int] = []
for dim, index in enumerate(indices):
if index is None:
if replacement_shape:
@ -3379,7 +3380,7 @@ def meta__fused_adam_(
):
for l in [self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps]:
torch._check(
isinstance(l, List),
isinstance(l, list),
lambda: f"exponent must be a tensor list but got {type(l)}",
)
@ -3405,7 +3406,7 @@ def meta__fused_adam(
):
for l in [self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps]:
torch._check(
isinstance(l, List),
isinstance(l, list),
lambda: f"exponent must be a tensor list but got {type(l)}",
)
@ -5636,7 +5637,7 @@ def meta__scaled_dot_product_efficient_backward(
philox_seed: Tensor,
philox_offset: Tensor,
dropout_p: float,
grad_input_mask: List[bool],
grad_input_mask: list[bool],
is_causal: bool = False,
scale: Optional[float] = None,
):
@ -6887,8 +6888,8 @@ def meta_local_scalar_dense(self: Tensor):
@register_meta(aten._jagged_to_padded_dense_forward.default)
def meta__jagged_to_padded_dense_forward(
values: Tensor,
offsets: List[Tensor],
max_lengths: List[int],
offsets: list[Tensor],
max_lengths: list[int],
padding_value: float = 0.0,
):
# only one jagged dim is supported for now

View File

@ -6,18 +6,7 @@ import importlib
import inspect
import sys
import types
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Set,
Type,
TYPE_CHECKING,
TypeVar,
Union,
)
from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union
from typing_extensions import Concatenate, ParamSpec
import torch
@ -79,7 +68,7 @@ class OperatorBase:
# for use with OpOverload; cache lookup is done entirely from C++
# for speed.
# TODO: The cache is NOT currently used by HigherOrderOperator, but it should!
self._dispatch_cache: Dict[
self._dispatch_cache: dict[
DispatchKey, Union[DispatchKey, Callable[..., Any]]
] = {}
@ -90,7 +79,7 @@ class OperatorBase:
# in case you need something unusual, and don't want to clobber
# the existing registrations using the Python operator registration
# API.
self.py_kernels: Dict[DispatchKey, Callable[..., Any]] = {}
self.py_kernels: dict[DispatchKey, Callable[..., Any]] = {}
# This table allows you to override the behavior of a particular
# operator for a particular TorchDispatchMode. In practice,
@ -98,8 +87,8 @@ class OperatorBase:
# thought of as an open world extension of dispatch keys, so it
# makes sense that you should be able to register them, the same
# way you can register dispatch keys.
self.python_key_table: Dict[
Union[Type[TorchDispatchMode], Type[torch.Tensor]], Callable[..., Any]
self.python_key_table: dict[
type[Union[TorchDispatchMode, torch.Tensor]], Callable[..., Any]
] = {}
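
One spot above is more than a textual swap: Dict[Union[Type[TorchDispatchMode], Type[torch.Tensor]], ...] becomes dict[type[Union[TorchDispatchMode, torch.Tensor]], ...]. Per the typing spec, type[Union[A, B]] is equivalent to Union[type[A], type[B]], so the key type is unchanged. A minimal illustration with hypothetical classes A and B:

    from typing import Union

    class A: ...
    class B: ...

    # type[Union[A, B]] and Union[type[A], type[B]] describe the same set
    # of values: the class objects A and B themselves.
    def make(cls: type[Union[A, B]]) -> Union[A, B]:
        return cls()

    make(A)
    make(B)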
# This table allows you to override the behavior of functorch
@ -122,8 +111,8 @@ class OperatorBase:
def py_impl(
self,
k: Union[
Type[TorchDispatchMode],
Type[torch.Tensor],
type[TorchDispatchMode],
type[torch.Tensor],
TransformType,
DispatchKey,
],
@ -258,7 +247,7 @@ def resolve_key(op: OperatorBase, k: DispatchKey): # type: ignore[valid-type]
raise NotImplementedError(f"could not find kernel for {op} at dispatch key {k}")
_higher_order_ops: Dict[str, "HigherOrderOperator"] = {}
_higher_order_ops: dict[str, "HigherOrderOperator"] = {}
_HIGHER_ORDER_OP_DEFAULT_FALLTHROUGH_DISPATCH_KEYS = [
DispatchKey.PythonDispatcher, # type: ignore[attr-defined]
@ -307,8 +296,8 @@ class HigherOrderOperator(OperatorBase, abc.ABC):
def py_impl(
self,
k: Union[
Type[TorchDispatchMode],
Type[torch.Tensor],
type[TorchDispatchMode],
type[torch.Tensor],
TransformType,
DispatchKey,
],
@ -668,7 +657,7 @@ def mode_stack_state_for_pre_dispatch():
return _mode_stack_state_for_pre_dispatch
cached_ops: Set["OpOverload"] = set()
cached_ops: set["OpOverload"] = set()
def add_cached_op(op_overload):
@ -930,7 +919,7 @@ class OpOverload(OperatorBase):
# TorchBindOpOverload will skip C++ dispatcher and purely dispatched in python
# when its inputs contain FakeScriptObject in a similar way as higher order ops.
class TorchBindOpOverload(OpOverload):
def _fallthrough_keys(self) -> List[DispatchKey]:
def _fallthrough_keys(self) -> list[DispatchKey]:
# TODO: we should be calling the fallback for these, but a fallthrough is almost close
# enough to the fallback in most cases that we care about.
_DEFAULT_FALLTHROUGH_KEYS = [

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import operator
import typing
import warnings
from collections.abc import Sequence
from contextlib import nullcontext
from enum import Enum
from functools import reduce
@ -15,7 +16,6 @@ from typing import (
NamedTuple,
Optional,
overload,
Sequence,
Tuple,
Type,
TYPE_CHECKING,
@ -51,12 +51,12 @@ if TYPE_CHECKING:
_IntLikeT = TypeVar("_IntLikeT", bound=_WorksWithInt)
ShapeType: TypeAlias = Union[torch.Size, List[int], Tuple[int, ...]]
StrideType: TypeAlias = Union[List[int], Tuple[int, ...]]
DimsType: TypeAlias = Union[int, List[int], Tuple[int, ...]]
DimsSequenceType: TypeAlias = Union[List[int], Tuple[int, ...]]
ShapeType: TypeAlias = Union[torch.Size, list[int], tuple[int, ...]]
StrideType: TypeAlias = Union[list[int], tuple[int, ...]]
DimsType: TypeAlias = Union[int, list[int], tuple[int, ...]]
DimsSequenceType: TypeAlias = Union[list[int], tuple[int, ...]]
# TODO: Type[torch.SymInt], Type[torch.SymFloat]
NumberTypeType: TypeAlias = Union[Type[bool], Type[int], Type[float], Type[complex]]
NumberTypeType: TypeAlias = Union[type[bool], type[int], type[float], type[complex]]
# TODO: This needs a lot more type annotations
# NumberType = Union[bool, int, float, complex, torch.SymInt, torch.SymFloat]
NumberType: TypeAlias = Union[bool, int, float, complex]
@ -107,7 +107,7 @@ torch_function_passthrough = {
TensorLikeType = torch.Tensor
TensorLike = torch.Tensor
TensorSequenceType: TypeAlias = Union[List[TensorLikeType], Tuple[TensorLikeType, ...]]
TensorSequenceType: TypeAlias = Union[list[TensorLikeType], tuple[TensorLikeType, ...]]
TensorOrNumberLikeType: TypeAlias = Union[TensorLikeType, NumberType]
CustomOutParamAnnotation = "__custom_out_param__"
@ -224,7 +224,7 @@ def _check_strides_helper(
only_cuda=True,
significant_only=True,
allow_rhs_unbacked=False,
) -> Tuple[bool, Optional[int]]:
) -> tuple[bool, Optional[int]]:
# NOTE: only on CUDA because CPU elementwise strides are incorrect in PyTorch
# See https://github.com/pytorch/pytorch/issues/77553
# Only compares strides that are "meaningful" -- strides for dimensions with length > 1
@ -245,7 +245,7 @@ def _check_strides_helper(
def check_significant_strides(
a: TensorLikeType, b: TensorLikeType, *, only_cuda=True, allow_rhs_unbacked=False
) -> Tuple[bool, Optional[int]]:
) -> tuple[bool, Optional[int]]:
return _check_strides_helper(
a,
b,
@ -257,7 +257,7 @@ def check_significant_strides(
def check_all_strides(
a: TensorLikeType, b: TensorLikeType, *, only_cuda=True
) -> Tuple[bool, Optional[int]]:
) -> tuple[bool, Optional[int]]:
return _check_strides_helper(a, b, only_cuda=only_cuda, significant_only=False)
@ -454,7 +454,7 @@ def is_non_overlapping_and_dense(a: Tensor) -> bool:
# short-circuit, which can cause different strides.
def compute_elementwise_output_logical_to_physical_perm(
*tensors, _skip_checks=False
) -> List[int]:
) -> list[int]:
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
if not _skip_checks and len(tensors) == 0:
@ -549,7 +549,7 @@ def compute_elementwise_output_logical_to_physical_perm(
return list(reversed(perm))
def compute_elementwise_output_strides(*tensors) -> Tuple[int, ...]:
def compute_elementwise_output_strides(*tensors) -> tuple[int, ...]:
"""
Computes the output strides for elementwise operations.
"""
@ -708,7 +708,7 @@ def canonicalize_dim(rank: int, idx: int, wrap_scalar: bool = True) -> int:
@overload
def canonicalize_dims(
rank: int, indices: Sequence[int], wrap_scalar: bool = True
) -> Tuple[int, ...]:
) -> tuple[int, ...]:
pass
@ -854,20 +854,20 @@ def extract_shape(*args, allow_cpu_scalar_tensors: bool) -> Optional[ShapeType]:
# Extracts dimensions that might be passed either as a list/tuple or as varargs.
# A typical case is Tensor.permute .
def extract_dims_from_varargs(
dims: Union[DimsSequenceType, Tuple[DimsSequenceType, ...]]
dims: Union[DimsSequenceType, tuple[DimsSequenceType, ...]]
) -> DimsSequenceType:
if dims and isinstance(dims[0], Sequence):
assert len(dims) == 1
dims = cast(Tuple[DimsSequenceType], dims)
dims = cast(tuple[DimsSequenceType], dims)
return dims[0]
else:
return cast(DimsSequenceType, dims)
def extract_shape_from_varargs(
shape: Union[ShapeType, Tuple[ShapeType]],
shape: Union[ShapeType, tuple[ShapeType]],
validate=True,
) -> Tuple[int, ...]:
) -> tuple[int, ...]:
"""
Returns a shape from varargs.
@ -895,7 +895,7 @@ def extract_shape_from_varargs(
return shape # type: ignore[return-value]
def infer_size_shapes(a: ShapeType, b: ShapeType) -> Tuple[int, ...]:
def infer_size_shapes(a: ShapeType, b: ShapeType) -> tuple[int, ...]:
ndim = max(len(a), len(b))
expandedSizes = [0] * ndim
@ -920,7 +920,7 @@ def infer_size_shapes(a: ShapeType, b: ShapeType) -> Tuple[int, ...]:
return tuple(expandedSizes)
def infer_size(shape: ShapeType, numel: int) -> Tuple[int, ...]:
def infer_size(shape: ShapeType, numel: int) -> tuple[int, ...]:
"""
Infers the size of a dim with size -1, if it exists.
Also checks that new shape is compatible with the number of elements.
@ -1390,7 +1390,7 @@ class RETURN_TYPE(Enum):
# TODO: when NumberType contains the sym types, can simplify this
def number_type(
x: Union[NumberType, torch.SymInt, torch.SymFloat, torch.SymBool]
) -> Type:
) -> type:
if isinstance(x, torch.SymInt):
return int
elif isinstance(x, torch.SymFloat):
@ -1401,7 +1401,7 @@ def number_type(
return type(x)
def expr_type(x: sympy.Basic) -> Type:
def expr_type(x: sympy.Basic) -> type:
import sympy
if x.kind is sympy.core.kind.BooleanKind:
@ -1417,7 +1417,7 @@ def expr_type(x: sympy.Basic) -> Type:
def elementwise_dtypes(
*_args,
type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND,
) -> Tuple[torch.dtype, torch.dtype]:
) -> tuple[torch.dtype, torch.dtype]:
"""
Computes the computation and result dtypes for elementwise type promotion
on the given arguments and with the given elementwise type promotion kind.
@ -1601,7 +1601,7 @@ def reduction_dtypes(
arg,
output_dtype_kind: REDUCTION_OUTPUT_TYPE_KIND,
dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.dtype, Optional[torch.dtype]]:
) -> tuple[torch.dtype, Optional[torch.dtype]]:
# even though some reductions, like amin or amax, don't strictly require type promotion,
# all the math ops (including comparisons) are still defined only for a computation type,
# so promotion will still happen. We are doing it explicitly here
@ -1628,7 +1628,7 @@ def reduction_dtypes(
# batched_matrix_contiguous_strides and contiguous_strides
def make_contiguous_strides_for(
shape: ShapeType, row_major: bool = True
) -> Tuple[Union[_IntLikeT, int], ...]:
) -> tuple[Union[_IntLikeT, int], ...]:
"""
Returns the strides of a contiguous tensor if row_major
If row_major=True, it returns the strides of a contiguous batch of Fortran-contiguous matrices
@ -1662,14 +1662,14 @@ def make_contiguous_strides_for(
def make_channels_last_1d_strides_for(
shape: Sequence[_IntLikeT],
) -> Tuple[Union[_IntLikeT, int], ...]:
) -> tuple[Union[_IntLikeT, int], ...]:
torch._check(
len(shape) == 3,
lambda: "Only tensors of rank 3 can use the channels_last_1d memory format",
)
multiplier: Union[_IntLikeT, int] = 1
strides: List[Union[_IntLikeT, int]] = [0] * 3
strides: list[Union[_IntLikeT, int]] = [0] * 3
for idx in (1, -1, 0):
# NOTE: intentional divergence from make_contiguous_strides_for
# This is consistent with eager
@ -1681,7 +1681,7 @@ def make_channels_last_1d_strides_for(
def make_channels_last_2d_strides_for(
shape: Sequence[_IntLikeT],
) -> Tuple[Union[_IntLikeT, int], ...]:
) -> tuple[Union[_IntLikeT, int], ...]:
# TODO: maybe inform the user of channels_last_3d if rank of the tensor is 5?
torch._check(
len(shape) == 4,
@ -1689,7 +1689,7 @@ def make_channels_last_2d_strides_for(
)
multiplier: Union[_IntLikeT, int] = 1
strides: List[Union[_IntLikeT, int]] = [0] * 4
strides: list[Union[_IntLikeT, int]] = [0] * 4
for idx in (1, -1, -2, 0):
# NOTE: intentional divergence from make_contiguous_strides_for
# This is consistent with eager
@ -1701,14 +1701,14 @@ def make_channels_last_2d_strides_for(
def make_channels_last_3d_strides_for(
shape: Sequence[_IntLikeT],
) -> Tuple[Union[_IntLikeT, int], ...]:
) -> tuple[Union[_IntLikeT, int], ...]:
torch._check(
len(shape) == 5,
lambda: "Only tensors of rank 5 can use the channels_last_3d memory format",
)
multiplier: Union[_IntLikeT, int] = 1
strides: List[Union[_IntLikeT, int]] = [0] * 5
strides: list[Union[_IntLikeT, int]] = [0] * 5
for idx in (1, -1, -2, -3, 0):
# NOTE: intentional divergence from make_contiguous_strides_for
# This is consistent with eager
@ -1720,7 +1720,7 @@ def make_channels_last_3d_strides_for(
def make_channels_last_strides_for(
shape: Sequence[_IntLikeT],
) -> Tuple[Union[_IntLikeT, int], ...]:
) -> tuple[Union[_IntLikeT, int], ...]:
ndim = len(shape) if isinstance(shape, Sequence) else 1
if ndim == 3:
return make_channels_last_1d_strides_for(shape)
@ -1736,7 +1736,7 @@ def make_channels_last_strides_for(
def compute_reduction_output_shape(
shape: ShapeType, dimensions: Sequence
) -> Tuple[int, ...]:
) -> tuple[int, ...]:
for idx in dimensions:
validate_idx(len(shape), idx)
@ -1755,7 +1755,7 @@ def validate_no_repeating_dims(dims: Sequence):
raise RuntimeError("duplicate value in the list of dims")
def reduction_dims(shape: ShapeType, dims: Optional[Sequence]) -> Tuple[int, ...]:
def reduction_dims(shape: ShapeType, dims: Optional[Sequence]) -> tuple[int, ...]:
if dims is None:
return tuple(range(len(shape)))
dims = tuple(canonicalize_dim(len(shape), idx) for idx in dims)
@ -1848,7 +1848,7 @@ def check_in_bounds_for_storage(
category=FutureWarning,
)
def check(
b: bool, s: Callable[[], str], exc_type: Type[Exception] = RuntimeError
b: bool, s: Callable[[], str], exc_type: type[Exception] = RuntimeError
) -> None:
"""
Helper function for raising an error_type (default: RuntimeError) if a boolean condition fails.

View File

@ -1,18 +1,11 @@
# mypy: allow-untyped-defs
import inspect
import types
import warnings
from collections.abc import Sequence
from functools import wraps
from types import GenericAlias
from typing import (
Callable,
List,
NamedTuple,
Optional,
overload,
Sequence,
Tuple,
TypeVar,
)
from typing import Callable, NamedTuple, Optional, overload, TypeVar
from typing_extensions import ParamSpec
import torch
@ -272,7 +265,9 @@ def out_wrapper(
bc_out_type = (
TensorLikeType
if is_tensor
else Tuple[tuple(TensorLikeType for _ in range(len(out_names)))]
else types.GenericAlias(
tuple, tuple(TensorLikeType for _ in range(len(out_names)))
)
)
return_type = (
TensorLikeType
@ -316,7 +311,7 @@ def out_wrapper(
)
or (
fn.__name__ == "unbind"
and isinstance(result, (List, tuple)) # type: ignore[arg-type]
and isinstance(result, (list, tuple)) # type: ignore[arg-type]
)
)
# unbind_copy is a special case: see https://github.com/pytorch/pytorch/issues/130829
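
The bc_out_type change above is the one place a subscripted alias could not be swapped in place: the argument tuple is built at runtime, so the parametrized builtin generic is constructed explicitly via types.GenericAlias. A small sketch of what that constructor produces (Python 3.9+; names below are illustrative):

    import types

    # tuple[int, str] is itself a types.GenericAlias instance...
    assert isinstance(tuple[int, str], types.GenericAlias)

    # ...so one can be built from a dynamically assembled argument tuple,
    # as the diff above does for the out= parameter types:
    out_names = ["out0", "out1", "out2"]
    alias = types.GenericAlias(tuple, tuple(int for _ in out_names))
    assert alias == tuple[int, int, int]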

View File

@ -3,7 +3,7 @@ import ast
import functools
import inspect
from textwrap import dedent
from typing import Any, List, NamedTuple, Optional, Tuple
from typing import Any, NamedTuple, Optional
from torch._C import ErrorReport
from torch._C._jit_tree_views import SourceRangeFactory
@ -12,7 +12,7 @@ from torch._C._jit_tree_views import SourceRangeFactory
def get_source_lines_and_file(
obj: Any,
error_msg: Optional[str] = None,
) -> Tuple[List[str], int, Optional[str]]:
) -> tuple[list[str], int, Optional[str]]:
"""
Wrapper around inspect.getsourcelines and inspect.getsourcefile.
@ -35,7 +35,7 @@ def get_source_lines_and_file(
return sourcelines, file_lineno, filename
def normalize_source_lines(sourcelines: List[str]) -> List[str]:
def normalize_source_lines(sourcelines: list[str]) -> list[str]:
"""
This helper function accepts a list of source lines. It finds the
indentation level of the function definition (`def`), then it indents
@ -100,7 +100,7 @@ class SourceContext(SourceRangeFactory):
self.funcname = funcname
@functools.lru_cache(maxsize=None)
@functools.cache
def make_source_context(*args):
return SourceContext(*args)
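
The decorator swap above is behavior-preserving: functools.cache, added in Python 3.9, is defined as lru_cache(maxsize=None), i.e. an unbounded memoizing cache. A small usage sketch:

    import functools

    @functools.cache  # same as @functools.lru_cache(maxsize=None)
    def fib(n: int) -> int:
        return n if n < 2 else fib(n - 1) + fib(n - 2)

    assert fib(30) == 832040
    assert fib.cache_info().hits > 0  # memoized recursive calls were reused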

View File

@ -6,7 +6,7 @@ import warnings
from collections import OrderedDict
from copy import deepcopy
from numbers import Number
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Optional, Union
import torch
import torch._C as _C
@ -173,8 +173,8 @@ class Tensor(torch._C.TensorBase):
if self.is_quantized:
# quantizer_params can be different type based on torch attribute
quantizer_params: Union[
Tuple[torch.qscheme, float, int],
Tuple[torch.qscheme, Tensor, Tensor, int],
tuple[torch.qscheme, float, int],
tuple[torch.qscheme, Tensor, Tensor, int],
]
if self.qscheme() == torch.per_tensor_affine:
quantizer_params = (
@ -317,7 +317,7 @@ class Tensor(torch._C.TensorBase):
# See Note [Don't serialize hooks]
warn_if_has_hooks(self)
backward_hooks: Dict[Any, Any] = OrderedDict()
backward_hooks: dict[Any, Any] = OrderedDict()
skip_data = torch.serialization._serialization_tls.skip_data
materialize_fake_tensors = (
@ -386,7 +386,7 @@ class Tensor(torch._C.TensorBase):
)
# quantizer_params can be different type based on torch attribute
quantizer_params: Union[
Tuple[torch.qscheme, float, int], Tuple[Any, Tensor, Tensor, int]
tuple[torch.qscheme, float, int], tuple[Any, Tensor, Tensor, int]
]
if self.qscheme() == torch.per_tensor_affine:
quantizer_params = (
@ -750,7 +750,7 @@ class Tensor(torch._C.TensorBase):
"post accumulate grad hooks cannot be registered on non-leaf tensors"
)
if self._post_accumulate_grad_hooks is None:
self._post_accumulate_grad_hooks: Dict[Any, Any] = OrderedDict()
self._post_accumulate_grad_hooks: dict[Any, Any] = OrderedDict()
from torch.utils.hooks import RemovableHandle
@ -1493,7 +1493,7 @@ class Tensor(torch._C.TensorBase):
return self.to_sparse()
def dim_order(
self, *, ambiguity_check: Union[bool, List[torch.memory_format]] = False
self, *, ambiguity_check: Union[bool, list[torch.memory_format]] = False
):
"""
dim_order(ambiguity_check=False) -> tuple
@ -1725,7 +1725,7 @@ class Tensor(torch._C.TensorBase):
return xla_dlpack.to_dlpack(self)
return torch.to_dlpack(self)
def __dlpack_device__(self) -> Tuple[enum.IntEnum, int]:
def __dlpack_device__(self) -> tuple[enum.IntEnum, int]:
if has_torch_function_unary(self):
return handle_torch_function(Tensor.__dlpack_device__, (self,), self)

View File

@ -3,7 +3,7 @@ import contextlib
import dataclasses
import math
import textwrap
from typing import Any, Dict, Optional
from typing import Any, Optional
import torch
from torch import inf
@ -95,7 +95,7 @@ def set_printoptions(
PRINT_OPTS.sci_mode = sci_mode
def get_printoptions() -> Dict[str, Any]:
def get_printoptions() -> dict[str, Any]:
r"""Gets the current options for printing, as a dictionary that
can be passed as ``**kwargs`` to set_printoptions().
"""

View File

@ -2,7 +2,6 @@
"""Adds docstrings to functions defined in the torch._C module."""
import re
from typing import Dict
import torch._C
from torch._C import _add_docstr as add_docstr
@ -171,7 +170,7 @@ rocm_fp16_notes = {
:ref:`different precision<fp16_on_mi200>` for backward."""
}
reproducibility_notes: Dict[str, str] = {
reproducibility_notes: dict[str, str] = {
"forward_reproducibility_note": """This operation may behave nondeterministically when given tensors on \
a CUDA device. See :doc:`/notes/randomness` for more information.""",
"backward_reproducibility_note": """This operation may produce nondeterministic gradients when given tensors on \

View File

@ -6,7 +6,7 @@ import sys
import traceback
import warnings
from collections import defaultdict
from typing import Any, Callable, DefaultDict, Generic, List, Optional, TYPE_CHECKING
from typing import Any, Callable, Generic, Optional, TYPE_CHECKING
from typing_extensions import deprecated, ParamSpec
import torch
@ -245,7 +245,7 @@ def _rebuild_tensor_v3(
return t
_sparse_tensors_to_validate: List["torch.Tensor"] = []
_sparse_tensors_to_validate: list["torch.Tensor"] = []
# In _legacy_load() in serialization.py we unpickle storages after the sparse
@ -635,7 +635,7 @@ def _take_tensors(tensors, size_limit):
Blocks of tensors of same type and within size_limit. The yielded
tensors are only ordered as the original sequence within its types.
"""
buf_dict: DefaultDict[str, List] = defaultdict(lambda: [[], 0])
buf_dict: defaultdict[str, list] = defaultdict(lambda: [[], 0])
for tensor in tensors:
t = tensor.type()
if tensor.is_sparse:
@ -674,7 +674,7 @@ def render_call(fn, args, kwargs):
if str_fn is None:
str_fn = str(fn)
str_args: List[str] = []
str_args: list[str] = []
with torch._tensor_str.printoptions(threshold=0, edgeitems=0):
str_args.extend(repr(a) for a in args)
str_args.extend(f"{k}={repr(v)}" for k, v in kwargs.items())
@ -986,7 +986,7 @@ class _LazySeedTracker:
# update seed to be latest
self.call_order = [self.manual_seed_all_cb, self.manual_seed_cb]
def get_calls(self) -> List:
def get_calls(self) -> list:
return self.call_order
@ -997,7 +997,7 @@ P = ParamSpec("P")
class CallbackRegistry(Generic[P]):
def __init__(self, name: str):
self.name = name
self.callback_list: List[Callable[P, None]] = []
self.callback_list: list[Callable[P, None]] = []
def add_callback(self, cb: Callable[P, None]) -> None:
self.callback_list.append(cb)

View File

@ -4,7 +4,7 @@ import logging
import os
import sys
import tempfile
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
from typing import Any, Callable, Optional, TypeVar
from typing_extensions import ParamSpec
import torch
@ -116,7 +116,7 @@ def compile_time_strobelight_meta(
#
# Killswitch is at
# https://www.internalfb.com/intern/justknobs/?name=pytorch%2Fsignpost#event
def signpost_event(category: str, name: str, parameters: Dict[str, Any]):
def signpost_event(category: str, name: str, parameters: dict[str, Any]):
log.info("%s %s: %r", category, name, parameters)
@ -231,7 +231,7 @@ def max_clock_rate():
return 1100
def get_mast_job_name_version() -> Optional[Tuple[str, int]]:
def get_mast_job_name_version() -> Optional[tuple[str, int]]:
return None
@ -256,8 +256,8 @@ def maybe_upload_prof_stats_to_manifold(profile_path: str) -> Optional[str]:
def log_chromium_event_internal(
event: Dict[str, Any],
stack: List[str],
event: dict[str, Any],
stack: list[str],
logger_uuid: str,
start_time_ns: int,
):
@ -265,6 +265,6 @@ def log_chromium_event_internal(
def record_chromium_event_internal(
event: Dict[str, Any],
event: dict[str, Any],
):
return None

View File

@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
import functools
from typing import Any, Callable, List, Optional, Tuple, Union
from typing import Any, Callable, Optional, Union
from typing_extensions import deprecated
import torch
@ -8,14 +8,14 @@ from torch import Tensor
from torch.utils._pytree import _broadcast_to_and_flatten, tree_flatten, tree_unflatten
in_dims_t = Union[int, Tuple]
out_dims_t = Union[int, Tuple[int, ...]]
in_dims_t = Union[int, tuple]
out_dims_t = Union[int, tuple[int, ...]]
# Checks that all args-to-be-batched have the same batch dim size
def _validate_and_get_batch_size(
flat_in_dims: List[Optional[int]],
flat_args: List,
flat_in_dims: list[Optional[int]],
flat_args: list,
) -> int:
batch_sizes = [
arg.size(in_dim)
@ -30,7 +30,7 @@ def _validate_and_get_batch_size(
return batch_sizes[0]
def _num_outputs(batched_outputs: Union[Tensor, Tuple[Tensor, ...]]) -> int:
def _num_outputs(batched_outputs: Union[Tensor, tuple[Tensor, ...]]) -> int:
if isinstance(batched_outputs, tuple):
return len(batched_outputs)
return 1
@ -42,7 +42,7 @@ def _as_tuple(
value: Any,
num_elements: int,
error_message_lambda: Callable[[], str],
) -> Tuple:
) -> tuple:
if not isinstance(value, tuple):
return (value,) * num_elements
if len(value) != num_elements:
@ -54,10 +54,10 @@ def _as_tuple(
# Returns the (potentially) batched arguments and the batch_size.
def _create_batched_inputs(
in_dims: in_dims_t,
args: Tuple,
args: tuple,
vmap_level: int,
func: Callable,
) -> Tuple[Tuple, int]:
) -> tuple[tuple, int]:
if not isinstance(in_dims, int) and not isinstance(in_dims, tuple):
raise ValueError(
f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
@ -114,13 +114,13 @@ def _create_batched_inputs(
# Undoes the batching (and any batch dimensions) associated with the `vmap_level`.
def _unwrap_batched(
batched_outputs: Union[Tensor, Tuple[Tensor, ...]],
batched_outputs: Union[Tensor, tuple[Tensor, ...]],
out_dims: out_dims_t,
vmap_level: int,
batch_size: int,
func: Callable,
allow_none_pass_through: bool = False,
) -> Tuple:
) -> tuple:
num_outputs = _num_outputs(batched_outputs)
out_dims_as_tuple = _as_tuple(
out_dims,
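These helpers implement the in_dims/out_dims convention that vmap exposes; a small sketch via the public torch.vmap, which shares that convention:

import torch

x = torch.randn(3, 5)  # batched along dim 0
w = torch.randn(5)     # unbatched
# in_dims=(0, None): map over x's dim 0, broadcast w to every call
out = torch.vmap(lambda xi, wi: xi @ wi, in_dims=(0, None))(x, w)
print(out.shape)  # torch.Size([3])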

View File

@ -68,7 +68,7 @@ from pickle import (
)
from struct import unpack
from sys import maxsize
from typing import Any, Callable, Dict, List, Set, Tuple, Union
from typing import Any, Callable, Union
import torch
from torch._utils import IMPORT_MAPPING, NAME_MAPPING
@ -83,15 +83,15 @@ _blocklisted_modules = [
"nt",
]
_marked_safe_globals_set: Set[Union[Callable, Tuple[Callable, str]]] = set()
_marked_safe_globals_set: set[Union[Callable, tuple[Callable, str]]] = set()
def _add_safe_globals(safe_globals: List[Union[Callable, Tuple[Callable, str]]]):
def _add_safe_globals(safe_globals: list[Union[Callable, tuple[Callable, str]]]):
global _marked_safe_globals_set
_marked_safe_globals_set = _marked_safe_globals_set.union(set(safe_globals))
def _get_safe_globals() -> List[Union[Callable, Tuple[Callable, str]]]:
def _get_safe_globals() -> list[Union[Callable, tuple[Callable, str]]]:
global _marked_safe_globals_set
return list(_marked_safe_globals_set)
@ -102,14 +102,14 @@ def _clear_safe_globals():
def _remove_safe_globals(
globals_to_remove: List[Union[Callable, Tuple[Callable, str]]],
globals_to_remove: list[Union[Callable, tuple[Callable, str]]],
):
global _marked_safe_globals_set
_marked_safe_globals_set = _marked_safe_globals_set - set(globals_to_remove)
class _safe_globals:
def __init__(self, safe_globals: List[Union[Callable, Tuple[Callable, str]]]):
def __init__(self, safe_globals: list[Union[Callable, tuple[Callable, str]]]):
self.safe_globals = safe_globals
def __enter__(self):
@ -127,7 +127,7 @@ class _safe_globals:
# the dynamic additions to safe_globals would not be picked up by
# _get_allowed_globals due to the lru_cache
def _get_user_allowed_globals():
rc: Dict[str, Any] = {}
rc: dict[str, Any] = {}
for f in _marked_safe_globals_set:
if isinstance(f, tuple):
if len(f) != 2:
@ -171,7 +171,7 @@ def _tensor_rebuild_functions():
# Unpickling machinery
@_functools.lru_cache(maxsize=1)
def _get_allowed_globals():
rc: Dict[str, Any] = {
rc: dict[str, Any] = {
"collections.OrderedDict": OrderedDict,
"collections.Counter": Counter,
"torch.nn.parameter.Parameter": torch.nn.Parameter,
@ -221,7 +221,7 @@ def _get_allowed_globals():
return rc
def _read_global_instruction(readline: Callable) -> Tuple[str, str]:
def _read_global_instruction(readline: Callable) -> tuple[str, str]:
module = readline()[:-1].decode("utf-8")
name = readline()[:-1].decode("utf-8")
# Patch since torch.save default protocol is 2
@ -233,7 +233,7 @@ def _read_global_instruction(readline: Callable) -> Tuple[str, str]:
return module, name
def get_globals_in_pkl(file) -> Set[str]:
def get_globals_in_pkl(file) -> set[str]:
globals_in_checkpoint = set()
read = file.read
readline = file.readline
@ -302,7 +302,7 @@ class Unpickler:
self.encoding = encoding
self.readline = file.readline
self.read = file.read
self.memo: Dict[int, Any] = {}
self.memo: dict[int, Any] = {}
self.proto: int = -1
def load(self):
@ -311,7 +311,7 @@ class Unpickler:
Return the reconstituted object hierarchy specified in the file.
"""
self.metastack = []
self.stack: List[Any] = []
self.stack: list[Any] = []
self.append = self.stack.append
read = self.read
while True:
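A hedged sketch of the allow-list flow these helpers implement, via the public wrapper torch.serialization.add_safe_globals (the tuple[Callable, str] form pairs a callable with the name it was pickled under):

import torch

class MyMeta:  # stand-in for a custom class found in a checkpoint
    pass

torch.serialization.add_safe_globals([MyMeta])
# torch.load(path, weights_only=True) may now reconstruct MyMeta instances.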

View File

@ -5,11 +5,15 @@ import inspect
import warnings
from collections import abc, defaultdict
from enum import Enum
from typing import Any, cast, Dict, Iterable, List, Optional, overload, Tuple, Union
from typing import Any, cast, Optional, overload, TYPE_CHECKING, Union
import torch
if TYPE_CHECKING:
from collections.abc import Iterable
__all__ = ["OptState", "GradScaler"]
@ -21,7 +25,7 @@ class _MultiDeviceReplicator:
def __init__(self, master_tensor: torch.Tensor) -> None:
self.master = master_tensor
self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}
self._per_device_tensors: dict[torch.device, torch.Tensor] = {}
def get(self, device: torch.device) -> torch.Tensor:
retval = self._per_device_tensors.get(device, None)
@ -42,7 +46,7 @@ class OptState(Enum):
STEPPED = 2
def _refresh_per_optimizer_state() -> Dict[str, Any]:
def _refresh_per_optimizer_state() -> dict[str, Any]:
return {"stage": OptState.READY, "found_inf_per_device": {}}
@ -147,13 +151,13 @@ class GradScaler:
self._init_growth_tracker = 0
# self._growth_tracker will be lazily initialized during the first call to scale()
self._growth_tracker: Optional[torch.Tensor] = None
self._per_optimizer_states: Dict[int, Dict[str, Any]] = defaultdict(
self._per_optimizer_states: dict[int, dict[str, Any]] = defaultdict(
_refresh_per_optimizer_state
)
def _check_scale_growth_tracker(
self, funcname: str
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration."
assert self._scale is not None, (
f"Attempted {funcname} but _scale is None. " + fix
@ -175,11 +179,11 @@ class GradScaler:
...
@overload
def scale(self, outputs: List[torch.Tensor]) -> List[torch.Tensor]:
def scale(self, outputs: list[torch.Tensor]) -> list[torch.Tensor]:
...
@overload
def scale(self, outputs: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
def scale(self, outputs: tuple[torch.Tensor, ...]) -> tuple[torch.Tensor, ...]:
...
@overload
@ -210,7 +214,7 @@ class GradScaler:
return outputs * self._scale.to(device=outputs.device, non_blocking=True)
# Invoke the more complex machinery only if we're treating multiple outputs.
stash: List[
stash: list[
_MultiDeviceReplicator
] = [] # holds a reference that can be overwritten by apply_scale
@ -237,7 +241,7 @@ class GradScaler:
inv_scale: torch.Tensor,
found_inf: torch.Tensor,
allow_fp16: bool,
) -> Dict[torch.device, torch.Tensor]:
) -> dict[torch.device, torch.Tensor]:
per_device_inv_scale = _MultiDeviceReplicator(inv_scale)
per_device_found_inf = _MultiDeviceReplicator(found_inf)
@ -247,8 +251,8 @@ class GradScaler:
# https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
# Google says mypy struggles with defaultdict type annotations.
per_device_and_dtype_grads: Dict[
torch.device, Dict[torch.dtype, List[torch.Tensor]]
per_device_and_dtype_grads: dict[
torch.device, dict[torch.dtype, list[torch.Tensor]]
] = defaultdict(lambda: defaultdict(list))
with torch.no_grad():
for group in optimizer.param_groups:
@ -343,7 +347,7 @@ class GradScaler:
def _maybe_opt_step(
self,
optimizer: torch.optim.Optimizer,
optimizer_state: Dict[str, Any],
optimizer_state: dict[str, Any],
*args: Any,
**kwargs: Any,
) -> Optional[float]:
@ -596,7 +600,7 @@ class GradScaler:
r"""Return a bool indicating whether this instance is enabled."""
return self._enabled
def state_dict(self) -> Dict[str, Any]:
def state_dict(self) -> dict[str, Any]:
r"""Return the state of the scaler as a :class:`dict`.
It contains five entries:
@ -623,7 +627,7 @@ class GradScaler:
}
return {}
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
r"""Load the scaler state.
If this instance is disabled, :meth:`load_state_dict` is a no-op.
@ -650,7 +654,7 @@ class GradScaler:
if self._growth_tracker is not None:
self._growth_tracker.fill_(state_dict["_growth_tracker"])
def __getstate__(self) -> Dict[str, Any]:
def __getstate__(self) -> dict[str, Any]:
state = self.__dict__.copy()
if self._enabled:
assert len(self._per_optimizer_states) == 0, (
@ -666,10 +670,10 @@ class GradScaler:
state["_growth_tracker"] = None
return state
def __setstate__(self, state: Dict[str, Any]) -> None:
def __setstate__(self, state: dict[str, Any]) -> None:
self.__dict__.update(state)
def _check_inf_per_device(self, optimizer: torch.optim.Optimizer) -> Dict[str, Any]:
def _check_inf_per_device(self, optimizer: torch.optim.Optimizer) -> dict[str, Any]:
_scale, _ = self._check_scale_growth_tracker("_check_inf_per_device")
dummy_inv_scale = torch.full((), 1.0, dtype=torch.float32, device=_scale.device)
@ -681,5 +685,5 @@ class GradScaler:
return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]
def _found_inf_per_device(self, optimizer: torch.optim.Optimizer) -> Dict[str, Any]:
def _found_inf_per_device(self, optimizer: torch.optim.Optimizer) -> dict[str, Any]:
return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]
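The class is easiest to read against the standard amp training loop; the tiny model and synthetic batches below are stand-ins, and a CUDA device is assumed:

import torch

model = torch.nn.Linear(10, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.amp.GradScaler("cuda")

for _ in range(3):  # stand-in for a real data loader
    inputs = torch.randn(8, 10, device="cuda")
    targets = torch.randn(8, 1, device="cuda")
    optimizer.zero_grad()
    with torch.autocast("cuda"):
        loss = torch.nn.functional.mse_loss(model(inputs), targets)
    scaler.scale(loss).backward()  # the overloaded scale() above
    scaler.step(optimizer)         # skipped internally if infs were found
    scaler.update()                # grows/shrinks the scale factor

ckpt = scaler.state_dict()  # the five-entry dict documented above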

View File

@ -330,7 +330,7 @@ def backward(
if is_tensor_like(tensors) or isinstance(tensors, graph.GradientEdge):
tensors = cast(
Union[Tuple[torch.Tensor], Tuple[graph.GradientEdge]], (tensors,)
Union[tuple[torch.Tensor], tuple[graph.GradientEdge]], (tensors,)
)
else:
tensors = tuple(tensors)

View File

@ -9,7 +9,6 @@ from typing import (
Any,
Callable,
cast,
Deque,
Literal,
NamedTuple,
Optional,
@ -764,7 +763,7 @@ def _register_logging_hooks_on_whole_graph(
if not roots:
return
seen: set[Node] = set()
q: Deque[Node] = deque()
q: deque[Node] = deque()
for node in roots:
if node is not None:
seen.add(node)

View File

@ -2,7 +2,6 @@
import time
from collections import defaultdict
from functools import partial
from typing import DefaultDict
import torch
@ -115,7 +114,7 @@ def visualize_rec(graph, value_map, name_prefix, pb_graph, executors_it=None):
for out, val in zip(subgraph.outputs(), node.outputs()):
value_map[val.unique()] = rec_value_map[out.unique()]
op_id_counter: DefaultDict[str, int] = defaultdict(int)
op_id_counter: defaultdict[str, int] = defaultdict(int)
def name_for(node):
kind = node.kind()[node.kind().index("::") + 2 :]

View File

@ -1,7 +1,8 @@
# mypy: allow-untyped-defs
import itertools
import operator
from typing import Any, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
from collections.abc import Sequence
from typing import Any, Optional, TYPE_CHECKING, Union
import torch
import torch.nn.functional as F
@ -154,9 +155,9 @@ def broadcast_shapes(*shapes):
def split(
tensor: Tensor,
split_size_or_sections: Union[int, List[int]],
split_size_or_sections: Union[int, list[int]],
dim: int = 0,
) -> Tuple[Tensor, ...]:
) -> tuple[Tensor, ...]:
r"""Splits the tensor into chunks. Each chunk is a view of the original tensor.
If :attr:`split_size_or_sections` is an integer type, then :attr:`tensor` will
@ -421,13 +422,13 @@ def einsum(*args: Any) -> Tensor:
if TYPE_CHECKING:
# The JIT doesn't understand Union, so only add type annotation for mypy
def meshgrid(
*tensors: Union[Tensor, List[Tensor]], indexing: Optional[str] = None
) -> Tuple[Tensor, ...]:
*tensors: Union[Tensor, list[Tensor]], indexing: Optional[str] = None
) -> tuple[Tensor, ...]:
return _meshgrid(*tensors, indexing=indexing)
else:
def meshgrid(*tensors, indexing: Optional[str] = None) -> Tuple[Tensor, ...]:
def meshgrid(*tensors, indexing: Optional[str] = None) -> tuple[Tensor, ...]:
r"""Creates grids of coordinates specified by the 1D inputs in `attr`:tensors.
This is helpful when you want to visualize data over some
@ -807,7 +808,7 @@ if TYPE_CHECKING:
# done by the caller of the _impl function
_unique_impl_out = Any
else:
_unique_impl_out = Tuple[Tensor, Tensor, Tensor]
_unique_impl_out = tuple[Tensor, Tensor, Tensor]
def _unique_impl(
@ -817,7 +818,7 @@ def _unique_impl(
return_counts: bool = False,
dim: Optional[int] = None,
) -> _unique_impl_out:
r"""unique(input, sorted=True, return_inverse=False, return_counts=False, dim=None) -> Tuple[Tensor, Tensor, Tensor]
r"""unique(input, sorted=True, return_inverse=False, return_counts=False, dim=None) -> tuple[Tensor, Tensor, Tensor]
Returns the unique elements of the input tensor.
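A quick look at the tuple[Tensor, Tensor, Tensor] contract via the public torch.unique:

import torch

t = torch.tensor([1, 3, 2, 3])
values, inverse, counts = torch.unique(
    t, sorted=True, return_inverse=True, return_counts=True
)
# values=[1, 2, 3]; inverse=[0, 2, 1, 2] maps t back onto values; counts=[1, 1, 2]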
@ -1056,7 +1057,7 @@ def _return_counts(
return_counts=False,
dim=None,
):
# type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
# type: (Tensor, bool, bool, bool, Optional[int]) -> tuple[Tensor, Tensor]
if has_torch_function_unary(input):
return _unique_impl(input, sorted, return_inverse, return_counts, dim)
@ -1088,7 +1089,7 @@ def _return_inverse(
return_counts=False,
dim=None,
):
# type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
# type: (Tensor, bool, bool, bool, Optional[int]) -> tuple[Tensor, Tensor]
if has_torch_function_unary(input):
return _unique_impl(input, sorted, return_inverse, return_counts, dim)
@ -1140,7 +1141,7 @@ def _consecutive_return_counts(
return_counts=False,
dim=None,
):
# type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
# type: (Tensor, bool, bool, Optional[int]) -> tuple[Tensor, Tensor]
if has_torch_function_unary(input):
return _unique_consecutive_impl(input, return_inverse, return_counts, dim)
@ -1172,7 +1173,7 @@ def _consecutive_return_inverse(
return_counts=False,
dim=None,
):
# type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
# type: (Tensor, bool, bool, Optional[int]) -> tuple[Tensor, Tensor]
if has_torch_function_unary(input):
return _unique_consecutive_impl(input, return_inverse, return_counts, dim)
@ -1236,7 +1237,7 @@ else:
def tensordot( # noqa: F811
a,
b,
dims: Tuple[List[int], List[int]],
dims: tuple[list[int], list[int]],
out: Optional[torch.Tensor] = None,
):
pass
@ -1245,7 +1246,7 @@ else:
def tensordot( # noqa: F811
a,
b,
dims: List[List[int]],
dims: list[list[int]],
out: Optional[torch.Tensor] = None,
):
pass
@ -1322,13 +1323,13 @@ def tensordot( # noqa: F811
if not isinstance(dims, (tuple, list, torch.Tensor, int, torch.SymInt)):
raise RuntimeError(
"tensordot expects dims to be int or "
+ "Tuple[List[int], List[int]] or "
+ "List[List[int]] containing two lists, but got "
+ "tuple[list[int], list[int]] or "
+ "list[list[int]] containing two lists, but got "
+ f"dims={dims}"
)
dims_a: List[int] = []
dims_b: List[int] = []
dims_a: list[int] = []
dims_b: list[int] = []
if isinstance(dims, (tuple, list)):
dims_a, dims_b = dims
@ -1337,8 +1338,8 @@ def tensordot( # noqa: F811
num_elements = dims.numel()
if num_elements > 1:
assert dims.size()[0] == 2
dims_a = torch.jit.annotate(List[int], dims[0].tolist())
dims_b = torch.jit.annotate(List[int], dims[1].tolist())
dims_a = torch.jit.annotate(list[int], dims[0].tolist())
dims_b = torch.jit.annotate(list[int], dims[1].tolist())
else:
dims_val = int(dims.item())
if dims_val < 0:
@ -1896,7 +1897,7 @@ def norm( # noqa: F811
def unravel_index(
indices: Tensor,
shape: Union[int, Sequence[int], torch.Size],
) -> Tuple[Tensor, ...]:
) -> tuple[Tensor, ...]:
r"""Converts a tensor of flat indices into a tuple of coordinate tensors that
index into an arbitrary tensor of the specified shape.
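For example:

import torch

coords = torch.unravel_index(torch.tensor([5, 7]), (3, 4))
# (tensor([1, 1]), tensor([1, 3])), since 5 = 1*4 + 1 and 7 = 1*4 + 3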
@ -2041,7 +2042,7 @@ def chain_matmul(*matrices, out=None):
def _lu_impl(A, pivot=True, get_infos=False, out=None):
# type: (Tensor, bool, bool, Any) -> Tuple[Tensor, Tensor, Tensor]
# type: (Tensor, bool, bool, Any) -> tuple[Tensor, Tensor, Tensor]
r"""Computes the LU factorization of a matrix or batches of matrices
:attr:`A`. Returns a tuple containing the LU factorization and
pivots of :attr:`A`. Pivoting is done if :attr:`pivot` is set to
@ -2143,7 +2144,7 @@ def _lu_impl(A, pivot=True, get_infos=False, out=None):
if TYPE_CHECKING:
_ListOrSeq = Sequence[Tensor]
else:
_ListOrSeq = List[Tensor]
_ListOrSeq = list[Tensor]
def _check_list_size(out_len: int, get_infos: bool, out: _ListOrSeq) -> None:
@ -2159,7 +2160,7 @@ def _check_list_size(out_len: int, get_infos: bool, out: _ListOrSeq) -> None:
def _lu_with_infos(A, pivot=True, get_infos=False, out=None):
# type: (Tensor, bool, bool, Optional[Tuple[Tensor, Tensor, Tensor]]) -> Tuple[Tensor, Tensor, Tensor]
# type: (Tensor, bool, bool, Optional[tuple[Tensor, Tensor, Tensor]]) -> tuple[Tensor, Tensor, Tensor]
if has_torch_function_unary(A):
return handle_torch_function(
lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out
@ -2175,7 +2176,7 @@ def _lu_with_infos(A, pivot=True, get_infos=False, out=None):
def _lu_no_infos(A, pivot=True, get_infos=False, out=None):
# type: (Tensor, bool, bool, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor]
# type: (Tensor, bool, bool, Optional[tuple[Tensor, Tensor]]) -> tuple[Tensor, Tensor]
# need to check for torch_function here so that we exit if
if has_torch_function_unary(A):
return handle_torch_function(

View File

@ -27,7 +27,7 @@ class Future(torch._C.Future, Generic[T], metaclass=_PyFutureMeta):
"""
def __init__(
self, *, devices: Optional[List[Union[int, str, torch.device]]] = None
self, *, devices: Optional[list[Union[int, str, torch.device]]] = None
):
r"""
Create an empty unset ``Future``. If the future is intended to hold
@ -278,7 +278,7 @@ class Future(torch._C.Future, Generic[T], metaclass=_PyFutureMeta):
self.set_result(result) # type: ignore[arg-type]
def collect_all(futures: List[Future]) -> Future[List[Future]]:
def collect_all(futures: list[Future]) -> Future[list[Future]]:
r"""
Collects the provided :class:`~torch.futures.Future` objects into a single
combined :class:`~torch.futures.Future` that is completed when all of the
@ -305,12 +305,12 @@ def collect_all(futures: List[Future]) -> Future[List[Future]]:
fut1 result = 1
"""
return cast(
Future[List[Future]],
torch._C._collect_all(cast(List[torch._C.Future], futures)),
Future[list[Future]],
torch._C._collect_all(cast(list[torch._C.Future], futures)),
)
def wait_all(futures: List[Future]) -> List:
def wait_all(futures: list[Future]) -> list:
r"""
Waits for all provided futures to be complete, and returns
the list of completed values. If any of the futures encounters an error,
@ -327,5 +327,5 @@ def wait_all(futures: List[Future]) -> List:
"""
return [
fut.wait()
for fut in torch._C._collect_all(cast(List[torch._C.Future], futures)).wait()
for fut in torch._C._collect_all(cast(list[torch._C.Future], futures)).wait()
]
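wait_all and collect_all in miniature, with manually completed futures:

import torch

fut0: torch.futures.Future[int] = torch.futures.Future()
fut1: torch.futures.Future[int] = torch.futures.Future()
fut0.set_result(0)
fut1.set_result(1)
print(torch.futures.wait_all([fut0, fut1]))  # [0, 1]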

View File

@ -12,7 +12,7 @@ import uuid
import warnings
import zipfile
from pathlib import Path
from typing import Any, Dict, Optional, Union
from typing import Any, Optional, Union
from typing_extensions import deprecated
from urllib.error import HTTPError, URLError
from urllib.parse import urlparse # noqa: F401
@ -784,7 +784,7 @@ def _legacy_zip_load(
model_dir: str,
map_location: MAP_LOCATION,
weights_only: bool,
) -> Dict[str, Any]:
) -> dict[str, Any]:
# Note: extractall() overwrites existing files by default. No need to clean up beforehand.
# We deliberately don't handle tarfile here since our legacy serialization format was in tar.
# E.g. resnet18-5c106cde.pth which is widely used.
@ -808,7 +808,7 @@ def load_state_dict_from_url(
check_hash: bool = False,
file_name: Optional[str] = None,
weights_only: bool = False,
) -> Dict[str, Any]:
) -> dict[str, Any]:
r"""Loads the Torch serialized object at the given URL.
If the downloaded file is a zip file, it will be automatically
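A typical call, using the resnet18-5c106cde.pth checkpoint the legacy-zip comment above mentions (the download.pytorch.org URL is assumed here):

import torch

state = torch.hub.load_state_dict_from_url(
    "https://download.pytorch.org/models/resnet18-5c106cde.pth",
    progress=True,
)
print(len(state))  # a dict[str, Any] of parameter and buffer tensors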

View File

@ -7,22 +7,22 @@
# a type attached and restored via `restore_type_tag` below. The legacy
# functions should stick around for backwards-compatibility.
from typing import List, Union
from typing import Union
def build_intlist(data: List[int]) -> List[int]:
def build_intlist(data: list[int]) -> list[int]:
return data
def build_tensorlist(data: List[object]) -> List[object]:
def build_tensorlist(data: list[object]) -> list[object]:
return data
def build_doublelist(data: List[float]) -> List[float]:
def build_doublelist(data: list[float]) -> list[float]:
return data
def build_boollist(data: List[bool]) -> List[bool]:
def build_boollist(data: list[bool]) -> list[bool]:
return data

View File

@ -6,17 +6,13 @@ import re
import sys
import traceback
import weakref
from collections.abc import Sequence
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Optional,
overload,
Sequence,
Set,
Tuple,
TYPE_CHECKING,
TypeVar,
Union,
@ -59,8 +55,8 @@ _P = ParamSpec("_P")
# The keys in the set are of the form `namespace + "/" + op_name + "/" + dispatch_key`.
# This set is maintained to ensure that two libraries don't try to override the exact same functionality, which
# would risk calling into kernels not intended to be called.
_impls: Set[str] = set()
_defs: Set[str] = set()
_impls: set[str] = set()
_defs: set[str] = set()
# prim is reserved by TorchScript interpreter
_reserved_namespaces = ["prim"]
@ -111,9 +107,9 @@ class Library:
kind, ns, dispatch_key, filename, lineno
)
self.ns = ns
self._op_defs: Set[str] = set()
self._op_impls: Set[str] = set()
self._registration_handles: List[torch._library.utils.RegistrationHandle] = []
self._op_defs: set[str] = set()
self._op_impls: set[str] = set()
self._registration_handles: list[torch._library.utils.RegistrationHandle] = []
self.kind = kind
self.dispatch_key = dispatch_key
# Use a finalizer to setup the "destructor" instead of __del__.
@ -459,7 +455,7 @@ def _scoped_library(*args, **kwargs):
lib._destroy()
_keep_alive: List[Library] = []
_keep_alive: list[Library] = []
NAMELESS_SCHEMA = re.compile(r"\(.*\) -> .*")
@ -1362,12 +1358,12 @@ _OPCHECK_DEFAULT_UTILS = (
def opcheck(
op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket, CustomOpDef],
args: Tuple[Any, ...],
kwargs: Optional[Dict[str, Any]] = None,
args: tuple[Any, ...],
kwargs: Optional[dict[str, Any]] = None,
*,
test_utils: Union[str, Sequence[str]] = _OPCHECK_DEFAULT_UTILS,
raise_exception: bool = True,
) -> Dict[str, str]:
) -> dict[str, str]:
"""Given an operator and some sample arguments, tests if the operator is
registered correctly.
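A hedged opcheck invocation; mylib::my_sin is a hypothetical custom op assumed to have been registered elsewhere:

import torch

results = torch.library.opcheck(
    torch.ops.mylib.my_sin,  # hypothetical custom op
    args=(torch.randn(3),),
    kwargs=None,
)
print(results)  # dict[str, str]: each test_util mapped to "SUCCESS", etc.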

View File

@ -27,8 +27,9 @@ import contextlib
import functools
import types
import warnings
from collections.abc import Iterable
from functools import wraps
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type
from typing import Any, Callable, Optional
import torch
from torch._C import (
@ -95,7 +96,7 @@ def _disable_user_warnings(
@functools.lru_cache(None)
@_disable_user_warnings
def get_ignored_functions() -> Set[Callable]:
def get_ignored_functions() -> set[Callable]:
"""
Return public functions that cannot be overridden by ``__torch_function__``.
@ -374,7 +375,7 @@ def get_ignored_functions() -> Set[Callable]:
@functools.lru_cache(None)
def get_default_nowrap_functions() -> Set[Callable]:
def get_default_nowrap_functions() -> set[Callable]:
"""
Return public functions that do not wrap in a subclass when invoked by
the default ``Tensor.__torch_function__`` that preserves subclasses. Typically,
@ -401,7 +402,7 @@ def get_default_nowrap_functions() -> Set[Callable]:
@functools.lru_cache(None)
@_disable_user_warnings
def get_testing_overrides() -> Dict[Callable, Callable]:
def get_testing_overrides() -> dict[Callable, Callable]:
"""Return a dict containing dummy overrides for all overridable functions
Returns
@ -427,7 +428,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
# function signatures for native kernels that can be consumed by inspect.
# See Issue #28233.
Tensor = torch.Tensor
ret: Dict[Callable, Callable] = {
ret: dict[Callable, Callable] = {
torch.abs: lambda input, out=None: -1,
torch.absolute: lambda input, out=None: -1,
torch.adaptive_avg_pool1d: lambda input, output_size: -1,
@ -1592,8 +1593,8 @@ def wrap_torch_function(dispatcher: Callable):
def _get_overloaded_args(
relevant_args: Iterable[Any],
get_type_fn: Optional[Callable[[Any], Type]] = None,
) -> List[Any]:
get_type_fn: Optional[Callable[[Any], type]] = None,
) -> list[Any]:
"""Returns a list of arguments on which to call __torch_function__.
Checks arguments in relevant_args for __torch_function__ implementations,
@ -1634,8 +1635,8 @@ def _get_overloaded_args(
if not torch._C._is_torch_function_enabled():
return []
# Runtime is O(num_arguments * num_unique_types)
overloaded_types: Set[Type] = set()
overloaded_args: List[Any] = []
overloaded_types: set[type] = set()
overloaded_args: list[Any] = []
for arg in relevant_args:
arg_type = get_type_fn(arg)
# We only collect arguments if they have a unique type, which ensures
@ -1807,7 +1808,7 @@ has_torch_function_variadic = _add_docstr(
@functools.lru_cache(None)
def _get_overridable_functions() -> (
Tuple[Dict[Any, List[Callable]], Dict[Callable, str]]
tuple[dict[Any, list[Callable]], dict[Callable, str]]
):
overridable_funcs = collections.defaultdict(list)
index = {}
@ -1893,7 +1894,7 @@ def _get_overridable_functions() -> (
@_disable_user_warnings
def get_overridable_functions() -> Dict[Any, List[Callable]]:
def get_overridable_functions() -> dict[Any, list[Callable]]:
"""List functions that are overridable via __torch_function__
Returns
@ -1927,7 +1928,7 @@ def resolve_name(f):
@functools.lru_cache(None)
def _get_tensor_methods() -> Set[Callable]:
def _get_tensor_methods() -> set[Callable]:
"""Returns a set of the overridable methods on ``torch.Tensor``"""
overridable_funcs = get_overridable_functions()
methods = set(overridable_funcs[torch.Tensor])
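Poking at the override surface these functions expose:

import torch
import torch.overrides as ovr

funcs = ovr.get_overridable_functions()          # dict[Any, list[Callable]]
print(len(funcs[torch.Tensor]))                  # overridable Tensor methods
print(torch.add in ovr.get_testing_overrides())  # True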

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
import contextlib
import warnings
from typing import Generator
from collections.abc import Generator
import torch
from torch._C import default_generator

View File

@ -1,6 +1,7 @@
# mypy: allow-untyped-defs
from collections.abc import Iterable
from math import sqrt
from typing import Callable, Iterable, Optional, TypeVar
from typing import Callable, Optional, TypeVar
import torch
from torch import Tensor

View File

@ -8,16 +8,7 @@ import functools
import io
import threading
import warnings
from typing import (
Any,
cast,
Dict as _Dict,
Optional as _Optional,
Type,
TYPE_CHECKING,
TypeVar,
Union,
)
from typing import Any, cast, Optional as _Optional, TYPE_CHECKING, TypeVar, Union
from typing_extensions import Self
import torch
@ -42,7 +33,7 @@ except ModuleNotFoundError:
_share_memory_lock = threading.Lock()
_share_memory_map: _Dict[int, threading.RLock] = {}
_share_memory_map: dict[int, threading.RLock] = {}
T = TypeVar("T", bound="Union[_StorageBase, TypedStorage]")
@ -136,35 +127,35 @@ class _StorageBase:
raise NotImplementedError
@classmethod
def _new_using_filename_cpu(cls: Type[T], size: _int) -> T:
def _new_using_filename_cpu(cls, size: _int) -> Self:
raise NotImplementedError
@classmethod
def _new_using_fd_cpu(cls: Type[T], size: _int) -> T:
def _new_using_fd_cpu(cls, size: _int) -> Self:
raise NotImplementedError
@classmethod
def from_buffer(cls: Type[T], *args, **kwargs) -> T:
def from_buffer(cls, *args, **kwargs) -> Self:
raise NotImplementedError
@classmethod
def _new_shared_filename_cpu(
cls: Type[T],
cls,
manager,
obj,
size,
*,
device=None,
dtype=None,
) -> T:
) -> Self:
raise NotImplementedError
@classmethod
def _release_ipc_counter_cuda(cls: Type[T], *args, **kwargs) -> T:
def _release_ipc_counter_cuda(cls, *args, **kwargs) -> Self:
raise NotImplementedError
@classmethod
def _new_with_weak_ptr(cls: Type[T], *args, **kwargs) -> T:
def _new_with_weak_ptr(cls, *args, **kwargs) -> Self:
raise NotImplementedError
def _shared_decref(self) -> Union[_StorageBase, TypedStorage]:
@ -192,7 +183,7 @@ class _StorageBase:
raise NotImplementedError
@classmethod
def _new_shared_cuda(cls: Type[T], *args, **kwargs) -> T:
def _new_shared_cuda(cls, *args, **kwargs) -> Self:
raise NotImplementedError
def _shared_incref(self, *args, **kwargs):
@ -535,7 +526,7 @@ def _load_from_bytes(b):
return torch.load(io.BytesIO(b), weights_only=False)
@functools.lru_cache(maxsize=None)
@functools.cache
def _new_dtypes():
# These are dtypes serialized as UntypedStorage unlike those in
# _dtype_to_storage_type_map
@ -556,7 +547,7 @@ def _new_dtypes():
}
@functools.lru_cache(maxsize=None)
@functools.cache
def _dtype_to_storage_type_map():
# NOTE: We should no longer add dtypes to this map. This map
# is only used for BC/FC with older PyTorch versions. Going forward,
@ -584,7 +575,7 @@ def _dtype_to_storage_type_map():
}
@functools.lru_cache(maxsize=None)
@functools.cache
def _storage_type_to_dtype_map():
dtype_map = {val: key for key, val in _dtype_to_storage_type_map().items()}
return dtype_map
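The decorator swap in this file is behavior-preserving: functools.cache (Python 3.9+) is defined as lru_cache(maxsize=None). A quick check:

import functools

@functools.cache
def _expensive_map() -> dict[str, int]:
    print("computed once")
    return {"float32": 4}

_expensive_map()
_expensive_map()  # "computed once" prints a single time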

View File

@ -1,4 +1,5 @@
from typing import Any, Iterable
from collections.abc import Iterable
from typing import Any
from torch._vendor.packaging.version import InvalidVersion, Version
from torch.version import __version__ as internal_version

View File

@ -12,7 +12,8 @@ from builtins import ( # noqa: F401
int as _int,
str as _str,
)
from typing import Any, Dict, List, Sequence, Tuple, TYPE_CHECKING, Union
from collections.abc import Sequence
from typing import Any, TYPE_CHECKING, Union
from typing_extensions import TypeAlias
# `as` imports have better static analysis support than assignment `ExposedType: TypeAlias = HiddenType`
@ -46,7 +47,7 @@ _TensorOrTensorsOrGradEdge: TypeAlias = Union[ # noqa: PYI047
Sequence["GradientEdge"],
]
_size: TypeAlias = Union[Size, List[int], Tuple[int, ...]] # noqa: PYI042,PYI047
_size: TypeAlias = Union[Size, list[int], tuple[int, ...]] # noqa: PYI042,PYI047
_symsize: TypeAlias = Union[Size, Sequence[Union[int, SymInt]]] # noqa: PYI042,PYI047
_dispatchkey: TypeAlias = Union[str, DispatchKey] # noqa: PYI042,PYI047
@ -76,7 +77,7 @@ class Storage:
dtype: _dtype
_torch_load_uninitialized: bool
def __deepcopy__(self, memo: Dict[int, Any]) -> "Storage":
def __deepcopy__(self, memo: dict[int, Any]) -> "Storage":
raise NotImplementedError
def _new_shared(self, size: int) -> "Storage":