mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
PEP585 update - mostly toplevels (#145178)
See #145101 for details. Pull Request resolved: https://github.com/pytorch/pytorch/pull/145178 Approved by: https://github.com/bobrenjc93
This commit is contained in:
committed by
PyTorch MergeBot
parent
1ce533867f
commit
f2cfe8b59f
@ -1,12 +1,12 @@
|
||||
from enum import Enum
|
||||
|
||||
from torch.types import _bool, Tuple
|
||||
from torch.types import _bool
|
||||
|
||||
# Defined in torch/csrc/cuda/shared/cudnn.cpp
|
||||
is_cuda: _bool
|
||||
|
||||
def getRuntimeVersion() -> Tuple[int, int, int]: ...
|
||||
def getCompileVersion() -> Tuple[int, int, int]: ...
|
||||
def getRuntimeVersion() -> tuple[int, int, int]: ...
|
||||
def getCompileVersion() -> tuple[int, int, int]: ...
|
||||
def getVersionInt() -> int: ...
|
||||
|
||||
class RNNMode(int, Enum):
|
||||
|
@ -24,13 +24,9 @@ import threading
|
||||
from typing import (
|
||||
Any as _Any,
|
||||
Callable as _Callable,
|
||||
Dict as _Dict,
|
||||
get_origin as _get_origin,
|
||||
Optional as _Optional,
|
||||
overload as _overload,
|
||||
Set as _Set,
|
||||
Tuple as _Tuple,
|
||||
Type as _Type,
|
||||
TYPE_CHECKING,
|
||||
TypeVar as _TypeVar,
|
||||
Union as _Union,
|
||||
@ -337,7 +333,7 @@ def _load_global_deps() -> None:
|
||||
except OSError as err:
|
||||
# Can only happen for wheel with cuda libs as PYPI deps
|
||||
# As PyTorch is not purelib, but nvidia-*-cu12 is
|
||||
cuda_libs: _Dict[str, str] = {
|
||||
cuda_libs: dict[str, str] = {
|
||||
"cublas": "libcublas.so.*[0-9]",
|
||||
"cudnn": "libcudnn.so.*[0-9]",
|
||||
"cuda_nvrtc": "libnvrtc.so.*[0-9]",
|
||||
@ -586,7 +582,7 @@ class SymInt:
|
||||
# https://github.com/arogozhnikov/einops/blob/6181e1e95dc58c00a3143c1726da1c6ee0463164/einops/einops.py#L237
|
||||
# return hash(builtins.int(self))
|
||||
|
||||
def as_integer_ratio(self) -> _Tuple["SymInt", builtins.int]:
|
||||
def as_integer_ratio(self) -> tuple["SymInt", builtins.int]:
|
||||
"""Represent this int as an exact integer ratio"""
|
||||
return self, 1
|
||||
|
||||
@ -698,7 +694,7 @@ class SymFloat:
|
||||
"""Return True if the float is an integer."""
|
||||
raise TypeError("type stub not overridden")
|
||||
|
||||
def as_integer_ratio(self) -> _Tuple[builtins.int, builtins.int]:
|
||||
def as_integer_ratio(self) -> tuple[builtins.int, builtins.int]:
|
||||
"""Represent this float as an exact integer ratio"""
|
||||
return builtins.float(self).as_integer_ratio()
|
||||
|
||||
@ -857,22 +853,22 @@ def sym_max(a, b):
|
||||
assert isinstance(a, all_types), type(a)
|
||||
assert isinstance(b, all_types), type(b)
|
||||
if isinstance(a, float_types) or isinstance(b, float_types):
|
||||
return builtins.float(builtins.max(a, b))
|
||||
return builtins.float(builtins.max(a, b)) # type: ignore[call-overload]
|
||||
else:
|
||||
return builtins.max(a, b)
|
||||
return builtins.max(a, b) # type: ignore[call-overload]
|
||||
|
||||
|
||||
def __all_and_float_types() -> _Tuple[_Tuple[_Type, ...], _Tuple[_Type, ...]]:
|
||||
def __all_and_float_types() -> tuple[tuple[type, ...], tuple[type, ...]]:
|
||||
try:
|
||||
import numpy as np
|
||||
|
||||
all_types: _Tuple[_Type, ...] = (
|
||||
all_types: tuple[type, ...] = (
|
||||
np.integer,
|
||||
np.floating,
|
||||
builtins.int,
|
||||
builtins.float,
|
||||
)
|
||||
float_types: _Tuple[_Type, ...] = (np.floating, builtins.float)
|
||||
float_types: tuple[type, ...] = (np.floating, builtins.float)
|
||||
except ModuleNotFoundError:
|
||||
all_types = (builtins.int, builtins.float)
|
||||
float_types = (builtins.float,)
|
||||
@ -894,9 +890,9 @@ def sym_min(a, b):
|
||||
assert isinstance(a, all_types), type(a)
|
||||
assert isinstance(b, all_types), type(b)
|
||||
if isinstance(a, float_types) or isinstance(b, float_types):
|
||||
return builtins.float(builtins.min(a, b))
|
||||
return builtins.float(builtins.min(a, b)) # type: ignore[call-overload]
|
||||
else:
|
||||
return builtins.min(a, b)
|
||||
return builtins.min(a, b) # type: ignore[call-overload]
|
||||
|
||||
|
||||
def sym_sum(args):
|
||||
@ -1204,7 +1200,7 @@ def set_default_device(
|
||||
_GLOBAL_DEVICE_CONTEXT.device_context = device_context
|
||||
|
||||
|
||||
def set_default_tensor_type(t: _Union[_Type["torch.Tensor"], str], /) -> None:
|
||||
def set_default_tensor_type(t: _Union[type["torch.Tensor"], str], /) -> None:
|
||||
r"""
|
||||
.. warning::
|
||||
|
||||
@ -2007,7 +2003,7 @@ class QUInt2x4Storage(_LegacyStorage):
|
||||
return torch.quint2x4
|
||||
|
||||
|
||||
_storage_classes: _Set[_Type[_Union[TypedStorage, UntypedStorage]]] = {
|
||||
_storage_classes: set[type[_Union[TypedStorage, UntypedStorage]]] = {
|
||||
UntypedStorage,
|
||||
DoubleStorage,
|
||||
FloatStorage,
|
||||
@ -2030,7 +2026,7 @@ _storage_classes: _Set[_Type[_Union[TypedStorage, UntypedStorage]]] = {
|
||||
}
|
||||
|
||||
# The _tensor_classes set is initialized by the call to initialize_python_bindings.
|
||||
_tensor_classes: _Set[_Type["torch.Tensor"]] = set()
|
||||
_tensor_classes: set[type["torch.Tensor"]] = set()
|
||||
|
||||
# If you edit these imports, please update torch/__init__.py.in as well
|
||||
from torch import amp as amp, random as random, serialization as serialization
|
||||
@ -2282,7 +2278,7 @@ class _TorchCompileInductorWrapper:
|
||||
def __init__(self, mode, options, dynamic):
|
||||
from torch._inductor.compiler_bisector import CompilerBisector
|
||||
|
||||
self.config: _Dict[str, _Any] = {}
|
||||
self.config: dict[str, _Any] = {}
|
||||
self.dynamic = dynamic
|
||||
self.apply_mode(mode)
|
||||
self.apply_options(options)
|
||||
@ -2309,13 +2305,13 @@ class _TorchCompileInductorWrapper:
|
||||
|
||||
self.apply_options(list_mode_options(mode, self.dynamic))
|
||||
|
||||
def apply_options(self, options: _Optional[_Dict[str, _Any]]):
|
||||
def apply_options(self, options: _Optional[dict[str, _Any]]):
|
||||
if not options:
|
||||
return
|
||||
|
||||
from torch._inductor import config
|
||||
|
||||
current_config: _Dict[str, _Any] = config.get_config_copy()
|
||||
current_config: dict[str, _Any] = config.get_config_copy()
|
||||
|
||||
for key, val in options.items():
|
||||
attr_name = key.replace("-", "_")
|
||||
@ -2403,7 +2399,7 @@ def compile(
|
||||
dynamic: _Optional[builtins.bool] = None,
|
||||
backend: _Union[str, _Callable] = "inductor",
|
||||
mode: _Union[str, None] = None,
|
||||
options: _Optional[_Dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
|
||||
options: _Optional[dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
|
||||
disable: builtins.bool = False,
|
||||
) -> _Callable[_InputT, _RetT]: ...
|
||||
|
||||
@ -2416,7 +2412,7 @@ def compile(
|
||||
dynamic: _Optional[builtins.bool] = None,
|
||||
backend: _Union[str, _Callable] = "inductor",
|
||||
mode: _Union[str, None] = None,
|
||||
options: _Optional[_Dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
|
||||
options: _Optional[dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
|
||||
disable: builtins.bool = False,
|
||||
) -> _Callable[[_Callable[_InputT, _RetT]], _Callable[_InputT, _RetT]]: ...
|
||||
|
||||
@ -2428,7 +2424,7 @@ def compile(
|
||||
dynamic: _Optional[builtins.bool] = None,
|
||||
backend: _Union[str, _Callable] = "inductor",
|
||||
mode: _Union[str, None] = None,
|
||||
options: _Optional[_Dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
|
||||
options: _Optional[dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
|
||||
disable: builtins.bool = False,
|
||||
) -> _Union[
|
||||
_Callable[[_Callable[_InputT, _RetT]], _Callable[_InputT, _RetT]],
|
||||
@ -2624,7 +2620,7 @@ if not _running_with_deploy():
|
||||
|
||||
class _TritonLibrary:
|
||||
lib = torch.library.Library("triton", "DEF")
|
||||
ops_table: _Dict[_Tuple[str, str], _Callable] = {}
|
||||
ops_table: dict[tuple[str, str], _Callable] = {}
|
||||
|
||||
@classmethod
|
||||
def registerOp(cls, op_key, full_schema, op_impl, dispatch_key):
|
||||
|
@ -108,7 +108,7 @@ def custom_op(
|
||||
# An example usage is FakeTensor: FakeTensor checks if a specific operator
|
||||
# has an implementation registered via the CustomOp API.
|
||||
# Indexed by qualname (e.g. aten::foo)
|
||||
global_registry: typing.Dict[str, "CustomOp"] = {}
|
||||
global_registry: dict[str, "CustomOp"] = {}
|
||||
|
||||
|
||||
class CustomOp:
|
||||
@ -136,7 +136,7 @@ class CustomOp:
|
||||
self.__name__ = None # mypy requires this
|
||||
# NB: Some of these impls are registered as kernels to DispatchKeys.
|
||||
# Modifying the _impls dict directly won't do anything in that case.
|
||||
self._impls: typing.Dict[str, typing.Optional[FuncAndLocation]] = {}
|
||||
self._impls: dict[str, typing.Optional[FuncAndLocation]] = {}
|
||||
# See NOTE [CustomOp autograd kernel indirection]
|
||||
self._registered_autograd_kernel_indirection = False
|
||||
|
||||
@ -476,7 +476,7 @@ def validate_schema(schema: FunctionSchema) -> None:
|
||||
)
|
||||
|
||||
|
||||
def parse_qualname(qualname: str) -> typing.Tuple[str, str]:
|
||||
def parse_qualname(qualname: str) -> tuple[str, str]:
|
||||
names = qualname.split("::", 1)
|
||||
if len(names) != 2:
|
||||
raise ValueError(f"Expected there to be a namespace in {qualname}, i.e. The "
|
||||
|
@ -1,8 +1,8 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import itertools
|
||||
import unittest.mock
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from typing import Iterator
|
||||
|
||||
import torch
|
||||
import torch._C
|
||||
|
@ -17,13 +17,9 @@ from dataclasses import dataclass
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Generic,
|
||||
List,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
TYPE_CHECKING,
|
||||
TypeVar,
|
||||
Union,
|
||||
@ -260,8 +256,8 @@ class Guard:
|
||||
create_fn: Callable[[GuardBuilderBase, Guard], None]
|
||||
|
||||
# Export only. These values are written to at time of guard check_fn creation.
|
||||
guard_types: Optional[List[str]] = None
|
||||
code_list: Optional[List[str]] = None
|
||||
guard_types: Optional[list[str]] = None
|
||||
code_list: Optional[list[str]] = None
|
||||
obj_weakref: Optional[object] = None
|
||||
guarded_class_weakref: Optional[type] = None
|
||||
|
||||
@ -448,8 +444,8 @@ overlapping with any other input, overlapping_sources represent tensors that eit
|
||||
|
||||
@dataclasses.dataclass
|
||||
class StorageOverlap(GuardEnvExpr):
|
||||
overlapping_sources: List[Source]
|
||||
non_overlapping_sources: List[Source]
|
||||
overlapping_sources: list[Source]
|
||||
non_overlapping_sources: list[Source]
|
||||
|
||||
|
||||
"""
|
||||
@ -478,7 +474,7 @@ class GuardsCheckpointState:
|
||||
The GuardCheckpointState - it is the T of Checkpointable[T] for GuardsContext
|
||||
"""
|
||||
|
||||
dynamo_guards: Set[Guard] = set()
|
||||
dynamo_guards: set[Guard] = set()
|
||||
|
||||
def __init__(self, dynamo_guards):
|
||||
self.dynamo_guards = dynamo_guards
|
||||
@ -500,7 +496,7 @@ class GuardsCheckpointState:
|
||||
|
||||
|
||||
class ModuleContextCheckpointState:
|
||||
nn_modules: Dict[str, torch.nn.Module] = {}
|
||||
nn_modules: dict[str, torch.nn.Module] = {}
|
||||
|
||||
def __init__(self, nn_modules):
|
||||
self.nn_modules = nn_modules
|
||||
@ -523,7 +519,7 @@ class ModuleContextCheckpointState:
|
||||
|
||||
class ModuleContext(Checkpointable[ModuleContextCheckpointState]):
|
||||
def __init__(self) -> None:
|
||||
self.nn_modules: Dict[str, Any] = {}
|
||||
self.nn_modules: dict[str, Any] = {}
|
||||
|
||||
def copy_graphstate(self):
|
||||
return ModuleContextCheckpointState(dict(self.nn_modules))
|
||||
@ -534,7 +530,7 @@ class ModuleContext(Checkpointable[ModuleContextCheckpointState]):
|
||||
|
||||
|
||||
class GlobalContextCheckpointState:
|
||||
global_state: Dict[str, Tuple[Callable, ...]] = {}
|
||||
global_state: dict[str, tuple[Callable, ...]] = {}
|
||||
|
||||
def __init__(self, global_states):
|
||||
self.global_state = global_states
|
||||
@ -572,7 +568,7 @@ class GlobalContext(Checkpointable[GlobalContextCheckpointState]):
|
||||
}
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.global_state: Dict[str, Tuple[Callable, ...]] = {}
|
||||
self.global_state: dict[str, tuple[Callable, ...]] = {}
|
||||
|
||||
def copy_graphstate(self):
|
||||
return GlobalContextCheckpointState(dict(self.global_state))
|
||||
@ -628,7 +624,7 @@ class GuardsSet:
|
||||
guard.user_stack = TracingContext.extract_stack()
|
||||
self.inner.add(guard)
|
||||
|
||||
def update(self, *others: Set[Guard]):
|
||||
def update(self, *others: set[Guard]):
|
||||
for o in others:
|
||||
for g in o:
|
||||
self.add(g, skip=1)
|
||||
@ -641,7 +637,7 @@ class GuardsSet:
|
||||
class GuardsContext(Checkpointable[GuardsCheckpointState]):
|
||||
def __init__(self) -> None:
|
||||
self.dynamo_guards: GuardsSet = GuardsSet()
|
||||
self.aotautograd_guards: List[GuardEnvExpr] = []
|
||||
self.aotautograd_guards: list[GuardEnvExpr] = []
|
||||
|
||||
def copy_graphstate(self):
|
||||
return GuardsCheckpointState(set(self.dynamo_guards.inner))
|
||||
@ -674,9 +670,9 @@ class HopSubgraphCache:
|
||||
|
||||
class InvokeSubgraphCache(HopSubgraphCache):
|
||||
def __init__(self) -> None:
|
||||
self.autograd_cache: Dict[str, Callable] = {}
|
||||
self.proxy_dispatch_cache: Dict[str, Callable] = {}
|
||||
self.dynamo_identifiers: Dict[str, str] = {}
|
||||
self.autograd_cache: dict[str, Callable] = {}
|
||||
self.proxy_dispatch_cache: dict[str, Callable] = {}
|
||||
self.dynamo_identifiers: dict[str, str] = {}
|
||||
|
||||
def add_dynamo_identifier(self, cache_key: str, identifier: str):
|
||||
self.dynamo_identifiers[cache_key] = identifier
|
||||
@ -748,7 +744,7 @@ class CompileContext:
|
||||
self.compile_id: Optional[CompileId] = compile_id
|
||||
self.attempt = 0
|
||||
# Verbose ShapeEnv guards produced.
|
||||
self.shape_env_guards: List[str] = []
|
||||
self.shape_env_guards: list[str] = []
|
||||
|
||||
@staticmethod
|
||||
def current_compile_id():
|
||||
@ -816,7 +812,7 @@ class TracingContext:
|
||||
# careful not to accidentally induce guards on the SymInt if
|
||||
# you ever do change this in aot_autograd.py; you should check
|
||||
# on permutations preferentially.)
|
||||
self.output_strides: Optional[List[Optional[Tuple[int, ...]]]] = None
|
||||
self.output_strides: Optional[list[Optional[tuple[int, ...]]]] = None
|
||||
# When this is True, whenever we encounter an int in Dynamo tracing,
|
||||
# we will (1) force unspec it and (2) force it as a size-like unbacked
|
||||
# integer. This is currently used when processing certain lists of
|
||||
|
@ -20,7 +20,7 @@ import types
|
||||
import typing
|
||||
import warnings
|
||||
import weakref
|
||||
from typing import (
|
||||
from typing import ( # noqa: F401 # (Dict, List, Tuple) imported by torch.jit.annotations
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
@ -31,7 +31,6 @@ from typing import (
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
@ -51,7 +50,7 @@ from torch.futures import Future
|
||||
IS_PY39_PLUS: Final[bool] = sys.version_info >= (3, 9)
|
||||
IS_PY310_PLUS: Final[bool] = sys.version_info >= (3, 10)
|
||||
|
||||
BuiltinUnionType: Union[Type, Tuple[Type, ...]]
|
||||
BuiltinUnionType: Union[type, tuple[type, ...]]
|
||||
if sys.version_info >= (3, 10):
|
||||
# NOTE: IS_PY310_PLUS doesn't work with mypy.
|
||||
# cf. https://mypy.readthedocs.io/en/stable/common_issues.html#python-version-and-system-platform-checks
|
||||
@ -59,7 +58,7 @@ if sys.version_info >= (3, 10):
|
||||
else:
|
||||
BuiltinUnionType = () # trick: this makes isinstance short circuit.
|
||||
|
||||
LockType: Type
|
||||
LockType: type
|
||||
try:
|
||||
import _thread
|
||||
|
||||
@ -71,7 +70,7 @@ except ImportError:
|
||||
|
||||
# Wrapper functions that can call either of 2 functions depending on a boolean
|
||||
# argument
|
||||
boolean_dispatched: "weakref.WeakKeyDictionary[Callable, Dict[str, Callable]]" = (
|
||||
boolean_dispatched: "weakref.WeakKeyDictionary[Callable, dict[str, Callable]]" = (
|
||||
weakref.WeakKeyDictionary()
|
||||
) # noqa: T484
|
||||
|
||||
@ -225,7 +224,7 @@ def createResolutionCallbackFromEnv(lookup_base):
|
||||
else:
|
||||
return getattr(module, qualified_name)
|
||||
|
||||
def parseNestedExpr(expr, module) -> Tuple[Any, int]:
|
||||
def parseNestedExpr(expr, module) -> tuple[Any, int]:
|
||||
i = 0
|
||||
while i < len(expr) and expr[i] not in (",", "[", "]"):
|
||||
i += 1
|
||||
@ -425,7 +424,7 @@ def can_compile_class(cls) -> bool:
|
||||
return all(has_code)
|
||||
|
||||
|
||||
def get_callable_argument_names(fn) -> List[str]:
|
||||
def get_callable_argument_names(fn) -> list[str]:
|
||||
"""
|
||||
Gets names of all POSITIONAL_OR_KEYWORD arguments for callable `fn`.
|
||||
Returns an empty list when other types of arguments are present.
|
||||
@ -957,7 +956,7 @@ def copy_torchscript_modifier(orig, new) -> None:
|
||||
# so that they can be imported in nn/functional.py without an import cycle
|
||||
|
||||
# qualified_name => list[overload_functions]
|
||||
_overloaded_fns: Dict[str, List[Callable]] = {} # noqa: T484
|
||||
_overloaded_fns: dict[str, list[Callable]] = {} # noqa: T484
|
||||
|
||||
|
||||
_OVERLOAD_EXAMPLE = """
|
||||
@ -1042,7 +1041,7 @@ def _clear_fn_overloads(qual_name) -> None:
|
||||
del _overloaded_fns[qual_name]
|
||||
|
||||
|
||||
def get_class_name_lineno(method) -> Tuple[str, int]:
|
||||
def get_class_name_lineno(method) -> tuple[str, int]:
|
||||
current_frame = inspect.currentframe()
|
||||
|
||||
# one for the get_class_name call, one for _overload_method call
|
||||
@ -1068,11 +1067,11 @@ def get_class_name_lineno(method) -> Tuple[str, int]:
|
||||
# when modules of the same name are in the same file
|
||||
|
||||
# qualified_name => class name => list[overload_functions]
|
||||
_overloaded_methods: Dict[str, Dict[str, List[Callable]]] = {} # noqa: T484
|
||||
_overloaded_methods: dict[str, dict[str, list[Callable]]] = {} # noqa: T484
|
||||
|
||||
|
||||
# (qualified_name, class name) => class_fileno
|
||||
_overloaded_method_class_fileno: Dict[Tuple[str, str], int] = {}
|
||||
_overloaded_method_class_fileno: dict[tuple[str, str], int] = {}
|
||||
|
||||
|
||||
def _overload_method(func):
|
||||
@ -1324,8 +1323,8 @@ def _get_named_tuple_properties(
|
||||
def _create_named_tuple(
|
||||
t,
|
||||
unqual_name: str,
|
||||
field_names: List[str],
|
||||
defaults: Tuple[Any, ...],
|
||||
field_names: list[str],
|
||||
defaults: tuple[Any, ...],
|
||||
):
|
||||
TupleType = collections.namedtuple(unqual_name, field_names, defaults=defaults) # type: ignore[call-arg, no-redef, misc]
|
||||
return TupleType(*t)
|
||||
@ -1487,7 +1486,7 @@ def _isinstance(obj, target_type) -> bool:
|
||||
|
||||
|
||||
class _TensorExtractor(pickle.Pickler):
|
||||
def __init__(self, *args, tensors: List[torch.Tensor], **kwargs):
|
||||
def __init__(self, *args, tensors: list[torch.Tensor], **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.tensors = tensors
|
||||
|
||||
@ -1523,7 +1522,7 @@ def _extract_tensors(obj):
|
||||
|
||||
It extracts the tensors contained in the given object, through pickling.
|
||||
"""
|
||||
tensors: List[torch.Tensor] = []
|
||||
tensors: list[torch.Tensor] = []
|
||||
extractor = _TensorExtractor(io.BytesIO(), protocol=-1, tensors=tensors)
|
||||
extractor.dump(obj)
|
||||
return tensors
|
||||
|
@ -1,7 +1,7 @@
|
||||
# mypy: allow-untyped-defs
|
||||
"""Various linear algebra utility methods for internal use."""
|
||||
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
@ -57,7 +57,7 @@ def basis(A):
|
||||
return torch.linalg.qr(A).Q
|
||||
|
||||
|
||||
def symeig(A: Tensor, largest: Optional[bool] = False) -> Tuple[Tensor, Tensor]:
|
||||
def symeig(A: Tensor, largest: Optional[bool] = False) -> tuple[Tensor, Tensor]:
|
||||
"""Return eigenpairs of A with specified ordering."""
|
||||
if largest is None:
|
||||
largest = False
|
||||
@ -79,7 +79,7 @@ def matrix_rank(input, tol=None, symmetric=False, *, out=None) -> Tensor:
|
||||
)
|
||||
|
||||
|
||||
def solve(input: Tensor, A: Tensor, *, out=None) -> Tuple[Tensor, Tensor]:
|
||||
def solve(input: Tensor, A: Tensor, *, out=None) -> tuple[Tensor, Tensor]:
|
||||
raise RuntimeError(
|
||||
"This function was deprecated since version 1.9 and is now removed. "
|
||||
"`torch.solve` is deprecated in favor of `torch.linalg.solve`. "
|
||||
@ -91,7 +91,7 @@ def solve(input: Tensor, A: Tensor, *, out=None) -> Tuple[Tensor, Tensor]:
|
||||
)
|
||||
|
||||
|
||||
def lstsq(input: Tensor, A: Tensor, *, out=None) -> Tuple[Tensor, Tensor]:
|
||||
def lstsq(input: Tensor, A: Tensor, *, out=None) -> tuple[Tensor, Tensor]:
|
||||
raise RuntimeError(
|
||||
"This function was deprecated since version 1.9 and is now removed. "
|
||||
"`torch.lstsq` is deprecated in favor of `torch.linalg.lstsq`.\n"
|
||||
@ -114,7 +114,7 @@ def _symeig(
|
||||
upper=True,
|
||||
*,
|
||||
out=None,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
) -> tuple[Tensor, Tensor]:
|
||||
raise RuntimeError(
|
||||
"This function was deprecated since version 1.9 and is now removed. "
|
||||
"The default behavior has changed from using the upper triangular portion of the matrix by default "
|
||||
@ -135,7 +135,7 @@ def eig(
|
||||
*,
|
||||
e=None,
|
||||
v=None,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
) -> tuple[Tensor, Tensor]:
|
||||
raise RuntimeError(
|
||||
"This function was deprecated since version 1.9 and is now removed. "
|
||||
"`torch.linalg.eig` returns complex tensors of dtype `cfloat` or `cdouble` rather than real tensors "
|
||||
|
@ -3,7 +3,7 @@
|
||||
# Author: Pearu Peterson
|
||||
# Created: February 2020
|
||||
|
||||
from typing import Dict, Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import _linalg_utils as _utils, Tensor
|
||||
@ -268,10 +268,10 @@ class LOBPCGAutogradFunction(torch.autograd.Function):
|
||||
largest: Optional[bool] = None,
|
||||
method: Optional[str] = None,
|
||||
tracker: None = None,
|
||||
ortho_iparams: Optional[Dict[str, int]] = None,
|
||||
ortho_fparams: Optional[Dict[str, float]] = None,
|
||||
ortho_bparams: Optional[Dict[str, bool]] = None,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
ortho_iparams: Optional[dict[str, int]] = None,
|
||||
ortho_fparams: Optional[dict[str, float]] = None,
|
||||
ortho_bparams: Optional[dict[str, bool]] = None,
|
||||
) -> tuple[Tensor, Tensor]:
|
||||
# makes sure that input is contiguous for efficiency.
|
||||
# Note: autograd does not support dense gradients for sparse input yet.
|
||||
A = A.contiguous() if (not A.is_sparse) else A
|
||||
@ -354,10 +354,10 @@ def lobpcg(
|
||||
largest: Optional[bool] = None,
|
||||
method: Optional[str] = None,
|
||||
tracker: None = None,
|
||||
ortho_iparams: Optional[Dict[str, int]] = None,
|
||||
ortho_fparams: Optional[Dict[str, float]] = None,
|
||||
ortho_bparams: Optional[Dict[str, bool]] = None,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
ortho_iparams: Optional[dict[str, int]] = None,
|
||||
ortho_fparams: Optional[dict[str, float]] = None,
|
||||
ortho_bparams: Optional[dict[str, bool]] = None,
|
||||
) -> tuple[Tensor, Tensor]:
|
||||
"""Find the k largest (or smallest) eigenvalues and the corresponding
|
||||
eigenvectors of a symmetric positive definite generalized
|
||||
eigenvalue problem using matrix-free LOBPCG methods.
|
||||
@ -591,10 +591,10 @@ def _lobpcg(
|
||||
largest: Optional[bool] = None,
|
||||
method: Optional[str] = None,
|
||||
tracker: None = None,
|
||||
ortho_iparams: Optional[Dict[str, int]] = None,
|
||||
ortho_fparams: Optional[Dict[str, float]] = None,
|
||||
ortho_bparams: Optional[Dict[str, bool]] = None,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
ortho_iparams: Optional[dict[str, int]] = None,
|
||||
ortho_fparams: Optional[dict[str, float]] = None,
|
||||
ortho_bparams: Optional[dict[str, bool]] = None,
|
||||
) -> tuple[Tensor, Tensor]:
|
||||
# A must be square:
|
||||
assert A.shape[-2] == A.shape[-1], A.shape
|
||||
if B is not None:
|
||||
@ -697,9 +697,9 @@ class LOBPCG:
|
||||
B: Optional[Tensor],
|
||||
X: Tensor,
|
||||
iK: Optional[Tensor],
|
||||
iparams: Dict[str, int],
|
||||
fparams: Dict[str, float],
|
||||
bparams: Dict[str, bool],
|
||||
iparams: dict[str, int],
|
||||
fparams: dict[str, float],
|
||||
bparams: dict[str, bool],
|
||||
method: str,
|
||||
tracker: None,
|
||||
) -> None:
|
||||
@ -720,10 +720,10 @@ class LOBPCG:
|
||||
self.E = torch.zeros((n,), dtype=X.dtype, device=X.device)
|
||||
self.R = torch.zeros((m, n), dtype=X.dtype, device=X.device)
|
||||
self.S = torch.zeros((m, 3 * n), dtype=X.dtype, device=X.device)
|
||||
self.tvars: Dict[str, Tensor] = {}
|
||||
self.ivars: Dict[str, int] = {"istep": 0}
|
||||
self.fvars: Dict[str, float] = {"_": 0.0}
|
||||
self.bvars: Dict[str, bool] = {"_": False}
|
||||
self.tvars: dict[str, Tensor] = {}
|
||||
self.ivars: dict[str, int] = {"istep": 0}
|
||||
self.fvars: dict[str, float] = {"_": 0.0}
|
||||
self.bvars: dict[str, bool] = {"_": False}
|
||||
|
||||
def __str__(self):
|
||||
lines = ["LOPBCG:"]
|
||||
|
@ -14,7 +14,7 @@ import tempfile
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from weakref import WeakSet
|
||||
|
||||
import torch._logging.structured
|
||||
@ -53,37 +53,37 @@ class LogRegistry:
|
||||
# Note: this only contains loggers registered
|
||||
# from register_log
|
||||
# e.g. "dynamo" -> "torch._dynamo"
|
||||
log_alias_to_log_qnames: Dict[str, List[str]] = field(default_factory=dict)
|
||||
log_alias_to_log_qnames: dict[str, list[str]] = field(default_factory=dict)
|
||||
|
||||
# artifact logger qualified names,
|
||||
# this is populated lazily, as calls to getArtifactLogger
|
||||
# currently formatted as <module>.__<artifact_name>
|
||||
# e.g. "torch._dynamo.convert_frame.__guards"
|
||||
artifact_log_qnames: Set[str] = field(default_factory=set)
|
||||
artifact_log_qnames: set[str] = field(default_factory=set)
|
||||
|
||||
# child logs of registered logs if specified via open
|
||||
# registration by the user (ie placing "torch._dynamo.output_graph" in the env var)
|
||||
# these need to be tracked so their levels can be reset properly
|
||||
# e.g. "torch._dynamo.output_graph"
|
||||
child_log_qnames: Set[str] = field(default_factory=set)
|
||||
child_log_qnames: set[str] = field(default_factory=set)
|
||||
|
||||
# artifact names, populated by register_artifact
|
||||
# e.g. "guards"
|
||||
artifact_names: Set[str] = field(default_factory=set)
|
||||
artifact_names: set[str] = field(default_factory=set)
|
||||
|
||||
# Artifacts that should be visible by default in the error message
|
||||
visible_artifacts: Set[str] = field(default_factory=set)
|
||||
visible_artifacts: set[str] = field(default_factory=set)
|
||||
|
||||
# A short description of each artifact
|
||||
artifact_descriptions: Dict[str, str] = field(default_factory=dict)
|
||||
artifact_descriptions: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
# artifacts which are not displayed unless explicitly named in the
|
||||
# settings. Ex. output_code is NOT displayed even if the inductor
|
||||
# log level is set to DEBUG. It must be explicitly named in the settings
|
||||
off_by_default_artifact_names: Set[str] = field(default_factory=set)
|
||||
off_by_default_artifact_names: set[str] = field(default_factory=set)
|
||||
|
||||
# logging format string for artifacts
|
||||
artifact_log_formatters: Dict[str, logging.Formatter] = field(default_factory=dict)
|
||||
artifact_log_formatters: dict[str, logging.Formatter] = field(default_factory=dict)
|
||||
|
||||
def is_artifact(self, name):
|
||||
return name in self.artifact_names
|
||||
@ -92,7 +92,7 @@ class LogRegistry:
|
||||
return alias in self.log_alias_to_log_qnames
|
||||
|
||||
# register a log with an alias
|
||||
def register_log(self, alias, log_qnames: Union[str, List[str]]):
|
||||
def register_log(self, alias, log_qnames: Union[str, list[str]]):
|
||||
if isinstance(log_qnames, str):
|
||||
log_qnames = [log_qnames]
|
||||
self.log_alias_to_log_qnames[alias] = log_qnames
|
||||
@ -124,7 +124,7 @@ class LogRegistry:
|
||||
self.child_log_qnames.add(log_qname)
|
||||
|
||||
# flattens all the qnames together (TODO: consider memoizing?)
|
||||
def get_log_qnames(self) -> Set[str]:
|
||||
def get_log_qnames(self) -> set[str]:
|
||||
return {
|
||||
qname
|
||||
for qnames in self.log_alias_to_log_qnames.values()
|
||||
@ -144,10 +144,10 @@ class LogRegistry:
|
||||
@dataclass
|
||||
class LogState:
|
||||
# qualified log names -> currently set log level
|
||||
log_qname_to_level: Dict[str, str] = field(default_factory=dict)
|
||||
log_qname_to_level: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
# the set of currently enabled artifacts
|
||||
artifact_names: Set[str] = field(default_factory=set)
|
||||
artifact_names: set[str] = field(default_factory=set)
|
||||
|
||||
def enable_artifact(self, artifact_name):
|
||||
self.artifact_names.add(artifact_name)
|
||||
@ -235,7 +235,7 @@ def set_logs(
|
||||
fusion: bool = False,
|
||||
overlap: bool = False,
|
||||
export: Optional[int] = None,
|
||||
modules: Optional[Dict[str, Union[int, bool]]] = None,
|
||||
modules: Optional[dict[str, Union[int, bool]]] = None,
|
||||
cudagraphs: bool = False,
|
||||
sym_node: bool = False,
|
||||
compiled_autograd: bool = False,
|
||||
@ -1105,7 +1105,7 @@ class LazyString:
|
||||
|
||||
# Logs the time it takes to do structured logging by frame/compile id
|
||||
# key is always {frame_id}_{frame_compile_id}
|
||||
structured_logging_overhead: Dict[str, float] = defaultdict(float)
|
||||
structured_logging_overhead: dict[str, float] = defaultdict(float)
|
||||
|
||||
|
||||
def add_structured_logging_overhead(time_spent: float) -> None:
|
||||
@ -1157,7 +1157,7 @@ def trace_structured(
|
||||
name: str,
|
||||
# NB: metadata expected to be dict so adding more info is forward compatible
|
||||
# Tuple[str, int] is a special case for string interning
|
||||
metadata_fn: Callable[[], Union[Dict[str, Any], Tuple[str, int]]] = dict,
|
||||
metadata_fn: Callable[[], Union[dict[str, Any], tuple[str, int]]] = dict,
|
||||
*,
|
||||
payload_fn: Callable[[], Optional[Union[str, object]]] = lambda: None,
|
||||
suppress_context: bool = False,
|
||||
@ -1189,7 +1189,7 @@ def trace_structured(
|
||||
# are handlers instead of checking the log level
|
||||
if trace_log.handlers:
|
||||
start_time = time.time_ns()
|
||||
record: Dict[str, object] = {}
|
||||
record: dict[str, object] = {}
|
||||
record[name] = metadata_fn()
|
||||
if not suppress_context:
|
||||
# TODO: Actually, the rank probably should just be emitted once at
|
||||
@ -1256,7 +1256,7 @@ def dtrace_structured(
|
||||
name: str,
|
||||
# NB: metadata expected to be dict so adding more info is forward compatible
|
||||
# Tuple[str, int] is a special case for string interning
|
||||
metadata_fn: Callable[[], Union[Dict[str, Any], Tuple[str, int]]] = dict,
|
||||
metadata_fn: Callable[[], Union[dict[str, Any], tuple[str, int]]] = dict,
|
||||
*,
|
||||
payload_fn: Callable[[], Optional[Union[str, object]]] = lambda: None,
|
||||
suppress_context: bool = False,
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Callable, List, Union
|
||||
from typing import Callable, Union
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
|
||||
@ -8,7 +8,7 @@ try:
|
||||
)
|
||||
except ImportError:
|
||||
TAtom: TypeAlias = Union[int, float, bool, str]
|
||||
TField: TypeAlias = Union[TAtom, List[TAtom]]
|
||||
TField: TypeAlias = Union[TAtom, list[TAtom]]
|
||||
TLazyField: TypeAlias = Union[TField, Callable[[], TField]]
|
||||
|
||||
def make_scribe_logger(name: str, thrift_src: str) -> Callable[..., None]:
|
||||
|
@ -3,15 +3,16 @@ Utilities for converting data types into structured JSON for dumping.
|
||||
"""
|
||||
|
||||
import traceback
|
||||
from typing import Any, Dict, List, Sequence, Set
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
import torch._logging._internal
|
||||
|
||||
|
||||
INTERN_TABLE: Dict[str, int] = {}
|
||||
INTERN_TABLE: dict[str, int] = {}
|
||||
|
||||
|
||||
DUMPED_FILES: Set[str] = set()
|
||||
DUMPED_FILES: set[str] = set()
|
||||
|
||||
|
||||
def intern_string(s: str) -> int:
|
||||
@ -42,7 +43,7 @@ def dump_file(filename: str) -> None:
|
||||
)
|
||||
|
||||
|
||||
def from_traceback(tb: Sequence[traceback.FrameSummary]) -> List[Dict[str, Any]]:
|
||||
def from_traceback(tb: Sequence[traceback.FrameSummary]) -> list[dict[str, Any]]:
|
||||
# dict naming convention here coincides with
|
||||
# python/combined_traceback.cpp
|
||||
r = [
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
__all__ = ["svd_lowrank", "pca_lowrank"]
|
||||
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import _linalg_utils as _utils, Tensor
|
||||
@ -88,7 +88,7 @@ def svd_lowrank(
|
||||
q: Optional[int] = 6,
|
||||
niter: Optional[int] = 2,
|
||||
M: Optional[Tensor] = None,
|
||||
) -> Tuple[Tensor, Tensor, Tensor]:
|
||||
) -> tuple[Tensor, Tensor, Tensor]:
|
||||
r"""Return the singular value decomposition ``(U, S, V)`` of a matrix,
|
||||
batches of matrices, or a sparse matrix :math:`A` such that
|
||||
:math:`A \approx U \operatorname{diag}(S) V^{\text{H}}`. In case :math:`M` is given, then
|
||||
@ -152,7 +152,7 @@ def _svd_lowrank(
|
||||
q: Optional[int] = 6,
|
||||
niter: Optional[int] = 2,
|
||||
M: Optional[Tensor] = None,
|
||||
) -> Tuple[Tensor, Tensor, Tensor]:
|
||||
) -> tuple[Tensor, Tensor, Tensor]:
|
||||
# Algorithm 5.1 in Halko et al., 2009
|
||||
|
||||
q = 6 if q is None else q
|
||||
@ -186,7 +186,7 @@ def pca_lowrank(
|
||||
q: Optional[int] = None,
|
||||
center: bool = True,
|
||||
niter: int = 2,
|
||||
) -> Tuple[Tensor, Tensor, Tensor]:
|
||||
) -> tuple[Tensor, Tensor, Tensor]:
|
||||
r"""Performs linear Principal Component Analysis (PCA) on a low-rank
|
||||
matrix, batches of such matrices, or sparse matrix.
|
||||
|
||||
|
@ -1,8 +1,9 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import math
|
||||
from collections.abc import Sequence
|
||||
from enum import Enum
|
||||
from functools import wraps
|
||||
from typing import Callable, List, Optional, Sequence, Tuple, TypeVar, Union
|
||||
from typing import Callable, Optional, TypeVar, Union
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
import torch
|
||||
@ -1054,7 +1055,7 @@ def linalg_ldl_factor_ex_meta(
|
||||
*,
|
||||
hermitian: bool = False,
|
||||
check_errors: bool = False,
|
||||
) -> Tuple[Tensor, Tensor, Tensor]:
|
||||
) -> tuple[Tensor, Tensor, Tensor]:
|
||||
squareCheckInputs(self, "torch.linalg.ldl_factor_ex")
|
||||
checkFloatingOrComplex(self, "torch.linalg.ldl_factor_ex")
|
||||
LD = torch.empty_strided(
|
||||
@ -1114,7 +1115,7 @@ def linalg_ldl_solve_meta(
|
||||
|
||||
@register_meta([aten.linalg_lu.default, aten.linalg_lu.out])
|
||||
@out_wrapper("P", "L", "U")
|
||||
def linalg_lu_meta(A: Tensor, *, pivot: bool = True) -> Tuple[Tensor, Tensor, Tensor]:
|
||||
def linalg_lu_meta(A: Tensor, *, pivot: bool = True) -> tuple[Tensor, Tensor, Tensor]:
|
||||
torch._check(
|
||||
A.ndim >= 2,
|
||||
lambda: f"linalg.lu: Expected tensor with 2 or more dimensions. Got size: {A.shape} instead",
|
||||
@ -1147,7 +1148,7 @@ def linalg_lu_factor_ex_meta(
|
||||
*,
|
||||
pivot: bool = True,
|
||||
check_errors: bool = False,
|
||||
) -> Tuple[Tensor, Tensor, Tensor]:
|
||||
) -> tuple[Tensor, Tensor, Tensor]:
|
||||
torch._check(
|
||||
A.ndim >= 2,
|
||||
lambda: f"torch.lu_factor: Expected tensor with 2 or more dimensions. Got size: {A.shape} instead",
|
||||
@ -1240,7 +1241,7 @@ def lu_unpack_meta(
|
||||
pivots: Tensor,
|
||||
unpack_data: bool = True,
|
||||
unpack_pivots: bool = True,
|
||||
) -> Tuple[Tensor, Tensor, Tensor]:
|
||||
) -> tuple[Tensor, Tensor, Tensor]:
|
||||
torch._check(
|
||||
LU.ndim >= 2,
|
||||
lambda: f"torch.lu_unpack: Expected tensor with 2 or more dimensions. Got size: {LU.shape} instead",
|
||||
@ -1275,7 +1276,7 @@ def lu_unpack_meta(
|
||||
|
||||
|
||||
# parse the "mode" param in linalg_qr: return a tuple of bools (compute_q, reduced)
|
||||
def _parse_qr_mode(mode: str) -> Tuple[bool, bool]:
|
||||
def _parse_qr_mode(mode: str) -> tuple[bool, bool]:
|
||||
if mode == "reduced":
|
||||
compute_q = True
|
||||
reduced = True
|
||||
@ -1298,7 +1299,7 @@ def _parse_qr_mode(mode: str) -> Tuple[bool, bool]:
|
||||
|
||||
@register_meta([aten.linalg_qr.default, aten.linalg_qr.out])
|
||||
@out_wrapper("Q", "R")
|
||||
def linalg_qr_meta(A: Tensor, mode: str = "reduced") -> Tuple[Tensor, Tensor]:
|
||||
def linalg_qr_meta(A: Tensor, mode: str = "reduced") -> tuple[Tensor, Tensor]:
|
||||
checkIsMatrix(A, "linalg.qr")
|
||||
checkFloatingOrComplex(A, "linalg.qr")
|
||||
|
||||
@ -1326,7 +1327,7 @@ def linalg_qr_meta(A: Tensor, mode: str = "reduced") -> Tuple[Tensor, Tensor]:
|
||||
|
||||
@register_meta([aten._linalg_slogdet.default, aten._linalg_slogdet.sign])
|
||||
@out_wrapper("sign", "logabsdet", "LU", "pivots")
|
||||
def _linalg_slogdet(A: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
|
||||
def _linalg_slogdet(A: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]:
|
||||
squareCheckInputs(A, "linalg.slogdet")
|
||||
checkFloatingOrComplex(A, "linalg.slogdet", False)
|
||||
shape = A.shape
|
||||
@ -1385,7 +1386,7 @@ def _linalg_svd_meta(
|
||||
def _linalg_broadcast_batch_dims(
|
||||
arg1: Tensor,
|
||||
arg2: Tensor,
|
||||
) -> Tuple[List[int], List[int]]:
|
||||
) -> tuple[list[int], list[int]]:
|
||||
# broadcast the batch dimensions of arg1 and arg2.
|
||||
arg1_batch_sizes = arg1.shape[:-2]
|
||||
arg2_batch_sizes = arg2.shape[:-2]
|
||||
@ -1403,7 +1404,7 @@ def _linalg_broadcast_batch_dims_name(
|
||||
arg1: Tensor,
|
||||
arg2: Tensor,
|
||||
name: Optional[str],
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
) -> tuple[Tensor, Tensor]:
|
||||
# If there's no name we assume we don't want to check the errors
|
||||
if name:
|
||||
linearSolveCheckInputs(arg1, arg2, name)
|
||||
@ -1438,7 +1439,7 @@ def _linalg_solve_ex(
|
||||
LU: Optional[Tensor] = None,
|
||||
pivots: Optional[Tensor] = None,
|
||||
info: Optional[Tensor] = None,
|
||||
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
|
||||
) -> tuple[Tensor, Tensor, Tensor, Tensor]:
|
||||
checkFloatingOrComplex(A, "linalg.solve")
|
||||
torch._check(
|
||||
A.dtype == B.dtype,
|
||||
@ -1520,7 +1521,7 @@ def triangular_solve_meta(
|
||||
upper: bool = True,
|
||||
transpose: bool = False,
|
||||
unitriangular: bool = False,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
) -> tuple[Tensor, Tensor]:
|
||||
torch._check(
|
||||
self.ndim >= 2,
|
||||
lambda: (
|
||||
@ -2159,12 +2160,12 @@ def device_hint(tensor) -> "str":
|
||||
def calc_conv_nd_return_shape(
|
||||
input_tensor: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
stride: Union[List[int], int],
|
||||
padding: Union[List[int], int],
|
||||
dilation: Union[List[int], int],
|
||||
stride: Union[list[int], int],
|
||||
padding: Union[list[int], int],
|
||||
dilation: Union[list[int], int],
|
||||
is_transposed: bool,
|
||||
groups: int,
|
||||
output_padding: Optional[Union[List[int], int]] = None,
|
||||
output_padding: Optional[Union[list[int], int]] = None,
|
||||
):
|
||||
def _formula(ln: int, p: int, d: int, k: int, s: int) -> int:
|
||||
"""
|
||||
@ -2227,7 +2228,7 @@ def calc_conv_nd_return_shape(
|
||||
elif len(dilation) == 1:
|
||||
dilation = [dilation[0]] * len(dims)
|
||||
|
||||
output_padding_list: Optional[List[int]] = None
|
||||
output_padding_list: Optional[list[int]] = None
|
||||
if output_padding:
|
||||
if isinstance(output_padding, IntLike):
|
||||
output_padding_list = [output_padding] * len(dims)
|
||||
@ -2310,11 +2311,11 @@ def meta_conv(
|
||||
input_tensor: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
stride: List[int],
|
||||
padding: List[int],
|
||||
dilation: List[int],
|
||||
stride: list[int],
|
||||
padding: list[int],
|
||||
dilation: list[int],
|
||||
is_transposed: bool,
|
||||
output_padding: List[int],
|
||||
output_padding: list[int],
|
||||
groups: int,
|
||||
):
|
||||
def pick_memory_format():
|
||||
@ -3176,7 +3177,7 @@ def meta_index_Tensor(self, indices):
|
||||
torch._check(bool(indices), lambda: "at least one index must be provided")
|
||||
# aten::index is the internal advanced indexing implementation
|
||||
# checkIndexTensorTypes and expandTensors
|
||||
result: List[Optional[Tensor]] = []
|
||||
result: list[Optional[Tensor]] = []
|
||||
for i, index in enumerate(indices):
|
||||
if index is not None:
|
||||
torch._check(
|
||||
@ -3257,9 +3258,9 @@ def meta_index_Tensor(self, indices):
|
||||
# to put the input and indices in a form so that TensorIterator can
|
||||
# take them. If we write a ref for this, probably that logic should
|
||||
# get implemented
|
||||
before_shape: List[int] = []
|
||||
after_shape: List[int] = []
|
||||
replacement_shape: List[int] = []
|
||||
before_shape: list[int] = []
|
||||
after_shape: list[int] = []
|
||||
replacement_shape: list[int] = []
|
||||
for dim, index in enumerate(indices):
|
||||
if index is None:
|
||||
if replacement_shape:
|
||||
@ -3379,7 +3380,7 @@ def meta__fused_adam_(
|
||||
):
|
||||
for l in [self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps]:
|
||||
torch._check(
|
||||
isinstance(l, List),
|
||||
isinstance(l, list),
|
||||
lambda: f"exponent must be a tensor list but got {type(l)}",
|
||||
)
|
||||
|
||||
@ -3405,7 +3406,7 @@ def meta__fused_adam(
|
||||
):
|
||||
for l in [self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps]:
|
||||
torch._check(
|
||||
isinstance(l, List),
|
||||
isinstance(l, list),
|
||||
lambda: f"exponent must be a tensor list but got {type(l)}",
|
||||
)
|
||||
|
||||
@ -5636,7 +5637,7 @@ def meta__scaled_dot_product_efficient_backward(
|
||||
philox_seed: Tensor,
|
||||
philox_offset: Tensor,
|
||||
dropout_p: float,
|
||||
grad_input_mask: List[bool],
|
||||
grad_input_mask: list[bool],
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
):
|
||||
@ -6887,8 +6888,8 @@ def meta_local_scalar_dense(self: Tensor):
|
||||
@register_meta(aten._jagged_to_padded_dense_forward.default)
|
||||
def meta__jagged_to_padded_dense_forward(
|
||||
values: Tensor,
|
||||
offsets: List[Tensor],
|
||||
max_lengths: List[int],
|
||||
offsets: list[Tensor],
|
||||
max_lengths: list[int],
|
||||
padding_value: float = 0.0,
|
||||
):
|
||||
# only one jagged dim is supported for now
|
||||
|
@ -6,18 +6,7 @@ import importlib
|
||||
import inspect
|
||||
import sys
|
||||
import types
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Type,
|
||||
TYPE_CHECKING,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union
|
||||
from typing_extensions import Concatenate, ParamSpec
|
||||
|
||||
import torch
|
||||
@ -79,7 +68,7 @@ class OperatorBase:
|
||||
# for use with OpOverload; cache lookup is done entirely from C++
|
||||
# for speed.
|
||||
# TODO: The cache is NOT currently used by HigherOrderOperator, but it should!
|
||||
self._dispatch_cache: Dict[
|
||||
self._dispatch_cache: dict[
|
||||
DispatchKey, Union[DispatchKey, Callable[..., Any]]
|
||||
] = {}
|
||||
|
||||
@ -90,7 +79,7 @@ class OperatorBase:
|
||||
# in case you need something unusual, and don't want to clobber
|
||||
# the existing registrations using the Python operator registration
|
||||
# API.
|
||||
self.py_kernels: Dict[DispatchKey, Callable[..., Any]] = {}
|
||||
self.py_kernels: dict[DispatchKey, Callable[..., Any]] = {}
|
||||
|
||||
# This table allows you to override the behavior of a particular
|
||||
# operator for a particular TorchDispatchMode. In practice,
|
||||
@ -98,8 +87,8 @@ class OperatorBase:
|
||||
# thought of as an open world extension of dispatch keys, so it
|
||||
# makes sense that you should be able to register them, the same
|
||||
# way you can register dispatch keys.
|
||||
self.python_key_table: Dict[
|
||||
Union[Type[TorchDispatchMode], Type[torch.Tensor]], Callable[..., Any]
|
||||
self.python_key_table: dict[
|
||||
type[Union[TorchDispatchMode, torch.Tensor]], Callable[..., Any]
|
||||
] = {}
|
||||
|
||||
# This table allows you to override the behavior of functorch
|
||||
@ -122,8 +111,8 @@ class OperatorBase:
|
||||
def py_impl(
|
||||
self,
|
||||
k: Union[
|
||||
Type[TorchDispatchMode],
|
||||
Type[torch.Tensor],
|
||||
type[TorchDispatchMode],
|
||||
type[torch.Tensor],
|
||||
TransformType,
|
||||
DispatchKey,
|
||||
],
|
||||
@ -258,7 +247,7 @@ def resolve_key(op: OperatorBase, k: DispatchKey): # type: ignore[valid-type]
|
||||
raise NotImplementedError(f"could not find kernel for {op} at dispatch key {k}")
|
||||
|
||||
|
||||
_higher_order_ops: Dict[str, "HigherOrderOperator"] = {}
|
||||
_higher_order_ops: dict[str, "HigherOrderOperator"] = {}
|
||||
|
||||
_HIGHER_ORDER_OP_DEFAULT_FALLTHROUGH_DISPATCH_KEYS = [
|
||||
DispatchKey.PythonDispatcher, # type: ignore[attr-defined]
|
||||
@ -307,8 +296,8 @@ class HigherOrderOperator(OperatorBase, abc.ABC):
|
||||
def py_impl(
|
||||
self,
|
||||
k: Union[
|
||||
Type[TorchDispatchMode],
|
||||
Type[torch.Tensor],
|
||||
type[TorchDispatchMode],
|
||||
type[torch.Tensor],
|
||||
TransformType,
|
||||
DispatchKey,
|
||||
],
|
||||
@ -668,7 +657,7 @@ def mode_stack_state_for_pre_dispatch():
|
||||
return _mode_stack_state_for_pre_dispatch
|
||||
|
||||
|
||||
cached_ops: Set["OpOverload"] = set()
|
||||
cached_ops: set["OpOverload"] = set()
|
||||
|
||||
|
||||
def add_cached_op(op_overload):
|
||||
@ -930,7 +919,7 @@ class OpOverload(OperatorBase):
|
||||
# TorchBindOpOverload will skip C++ dispatcher and purely dispatched in python
|
||||
# when its inputs contain FakeScriptObject in a similar way as higher order ops.
|
||||
class TorchBindOpOverload(OpOverload):
|
||||
def _fallthrough_keys(self) -> List[DispatchKey]:
|
||||
def _fallthrough_keys(self) -> list[DispatchKey]:
|
||||
# TODO: we should be calling the fallback for these, but a fallthrough is almost close
|
||||
# enough to the fallback in most cases that we care about.
|
||||
_DEFAULT_FALLTHROUGH_KEYS = [
|
||||
|
@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
import operator
|
||||
import typing
|
||||
import warnings
|
||||
from collections.abc import Sequence
|
||||
from contextlib import nullcontext
|
||||
from enum import Enum
|
||||
from functools import reduce
|
||||
@ -15,7 +16,6 @@ from typing import (
|
||||
NamedTuple,
|
||||
Optional,
|
||||
overload,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
TYPE_CHECKING,
|
||||
@ -51,12 +51,12 @@ if TYPE_CHECKING:
|
||||
_IntLikeT = TypeVar("_IntLikeT", bound=_WorksWithInt)
|
||||
|
||||
|
||||
ShapeType: TypeAlias = Union[torch.Size, List[int], Tuple[int, ...]]
|
||||
StrideType: TypeAlias = Union[List[int], Tuple[int, ...]]
|
||||
DimsType: TypeAlias = Union[int, List[int], Tuple[int, ...]]
|
||||
DimsSequenceType: TypeAlias = Union[List[int], Tuple[int, ...]]
|
||||
ShapeType: TypeAlias = Union[torch.Size, list[int], tuple[int, ...]]
|
||||
StrideType: TypeAlias = Union[list[int], tuple[int, ...]]
|
||||
DimsType: TypeAlias = Union[int, list[int], tuple[int, ...]]
|
||||
DimsSequenceType: TypeAlias = Union[list[int], tuple[int, ...]]
|
||||
# TODO: Type[torch.SymInt], Type[torch.SymFloat]
|
||||
NumberTypeType: TypeAlias = Union[Type[bool], Type[int], Type[float], Type[complex]]
|
||||
NumberTypeType: TypeAlias = Union[type[bool], type[int], type[float], type[complex]]
|
||||
# TODO: This needs a lot more type annotations
|
||||
# NumberType = Union[bool, int, float, complex, torch.SymInt, torch.SymFloat]
|
||||
NumberType: TypeAlias = Union[bool, int, float, complex]
|
||||
@ -107,7 +107,7 @@ torch_function_passthrough = {
|
||||
|
||||
TensorLikeType = torch.Tensor
|
||||
TensorLike = torch.Tensor
|
||||
TensorSequenceType: TypeAlias = Union[List[TensorLikeType], Tuple[TensorLikeType, ...]]
|
||||
TensorSequenceType: TypeAlias = Union[list[TensorLikeType], tuple[TensorLikeType, ...]]
|
||||
TensorOrNumberLikeType: TypeAlias = Union[TensorLikeType, NumberType]
|
||||
|
||||
CustomOutParamAnnotation = "__custom_out_param__"
|
||||
@ -224,7 +224,7 @@ def _check_strides_helper(
|
||||
only_cuda=True,
|
||||
significant_only=True,
|
||||
allow_rhs_unbacked=False,
|
||||
) -> Tuple[bool, Optional[int]]:
|
||||
) -> tuple[bool, Optional[int]]:
|
||||
# NOTE: only on CUDA because CPU elementwise strides are incorrect in PyTorch
|
||||
# See https://github.com/pytorch/pytorch/issues/77553
|
||||
# Only compares strides that are "meaningful" -- strides for dimensions with length > 1
|
||||
@ -245,7 +245,7 @@ def _check_strides_helper(
|
||||
|
||||
def check_significant_strides(
|
||||
a: TensorLikeType, b: TensorLikeType, *, only_cuda=True, allow_rhs_unbacked=False
|
||||
) -> Tuple[bool, Optional[int]]:
|
||||
) -> tuple[bool, Optional[int]]:
|
||||
return _check_strides_helper(
|
||||
a,
|
||||
b,
|
||||
@ -257,7 +257,7 @@ def check_significant_strides(
|
||||
|
||||
def check_all_strides(
|
||||
a: TensorLikeType, b: TensorLikeType, *, only_cuda=True
|
||||
) -> Tuple[bool, Optional[int]]:
|
||||
) -> tuple[bool, Optional[int]]:
|
||||
return _check_strides_helper(a, b, only_cuda=only_cuda, significant_only=False)
|
||||
|
||||
|
||||
@ -454,7 +454,7 @@ def is_non_overlapping_and_dense(a: Tensor) -> bool:
|
||||
# short-circuit, which can cause different strides.
|
||||
def compute_elementwise_output_logical_to_physical_perm(
|
||||
*tensors, _skip_checks=False
|
||||
) -> List[int]:
|
||||
) -> list[int]:
|
||||
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
|
||||
|
||||
if not _skip_checks and len(tensors) == 0:
|
||||
@ -549,7 +549,7 @@ def compute_elementwise_output_logical_to_physical_perm(
|
||||
return list(reversed(perm))
|
||||
|
||||
|
||||
def compute_elementwise_output_strides(*tensors) -> Tuple[int, ...]:
|
||||
def compute_elementwise_output_strides(*tensors) -> tuple[int, ...]:
|
||||
"""
|
||||
Computes the output strides for elementwise operations.
|
||||
"""
|
||||
@ -708,7 +708,7 @@ def canonicalize_dim(rank: int, idx: int, wrap_scalar: bool = True) -> int:
|
||||
@overload
|
||||
def canonicalize_dims(
|
||||
rank: int, indices: Sequence[int], wrap_scalar: bool = True
|
||||
) -> Tuple[int, ...]:
|
||||
) -> tuple[int, ...]:
|
||||
pass
|
||||
|
||||
|
||||
@ -854,20 +854,20 @@ def extract_shape(*args, allow_cpu_scalar_tensors: bool) -> Optional[ShapeType]:
|
||||
# Extracts dimensions that might be passed either as a list/tuple or as varargs.
|
||||
# A typical case is Tensor.permute .
|
||||
def extract_dims_from_varargs(
|
||||
dims: Union[DimsSequenceType, Tuple[DimsSequenceType, ...]]
|
||||
dims: Union[DimsSequenceType, tuple[DimsSequenceType, ...]]
|
||||
) -> DimsSequenceType:
|
||||
if dims and isinstance(dims[0], Sequence):
|
||||
assert len(dims) == 1
|
||||
dims = cast(Tuple[DimsSequenceType], dims)
|
||||
dims = cast(tuple[DimsSequenceType], dims)
|
||||
return dims[0]
|
||||
else:
|
||||
return cast(DimsSequenceType, dims)
|
||||
|
||||
|
||||
def extract_shape_from_varargs(
|
||||
shape: Union[ShapeType, Tuple[ShapeType]],
|
||||
shape: Union[ShapeType, tuple[ShapeType]],
|
||||
validate=True,
|
||||
) -> Tuple[int, ...]:
|
||||
) -> tuple[int, ...]:
|
||||
"""
|
||||
Returns a shape from varargs.
|
||||
|
||||
@ -895,7 +895,7 @@ def extract_shape_from_varargs(
|
||||
return shape # type: ignore[return-value]
|
||||
|
||||
|
||||
def infer_size_shapes(a: ShapeType, b: ShapeType) -> Tuple[int, ...]:
|
||||
def infer_size_shapes(a: ShapeType, b: ShapeType) -> tuple[int, ...]:
|
||||
ndim = max(len(a), len(b))
|
||||
expandedSizes = [0] * ndim
|
||||
|
||||
@ -920,7 +920,7 @@ def infer_size_shapes(a: ShapeType, b: ShapeType) -> Tuple[int, ...]:
|
||||
return tuple(expandedSizes)
|
||||
|
||||
|
||||
def infer_size(shape: ShapeType, numel: int) -> Tuple[int, ...]:
|
||||
def infer_size(shape: ShapeType, numel: int) -> tuple[int, ...]:
|
||||
"""
|
||||
Infers the size of a dim with size -1, if it exists.
|
||||
Also checks that new shape is compatible with the number of elements.
|
||||
@ -1390,7 +1390,7 @@ class RETURN_TYPE(Enum):
|
||||
# TODO: when NumberType contains the sym types, can simplify this
|
||||
def number_type(
|
||||
x: Union[NumberType, torch.SymInt, torch.SymFloat, torch.SymBool]
|
||||
) -> Type:
|
||||
) -> type:
|
||||
if isinstance(x, torch.SymInt):
|
||||
return int
|
||||
elif isinstance(x, torch.SymFloat):
|
||||
@ -1401,7 +1401,7 @@ def number_type(
|
||||
return type(x)
|
||||
|
||||
|
||||
def expr_type(x: sympy.Basic) -> Type:
|
||||
def expr_type(x: sympy.Basic) -> type:
|
||||
import sympy
|
||||
|
||||
if x.kind is sympy.core.kind.BooleanKind:
|
||||
@ -1417,7 +1417,7 @@ def expr_type(x: sympy.Basic) -> Type:
|
||||
def elementwise_dtypes(
|
||||
*_args,
|
||||
type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND,
|
||||
) -> Tuple[torch.dtype, torch.dtype]:
|
||||
) -> tuple[torch.dtype, torch.dtype]:
|
||||
"""
|
||||
Computes the computation and result dtypes for elementwise type promotion
|
||||
on the given arguments and with the given elementwise type promotion kind.
|
||||
@ -1601,7 +1601,7 @@ def reduction_dtypes(
|
||||
arg,
|
||||
output_dtype_kind: REDUCTION_OUTPUT_TYPE_KIND,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
) -> Tuple[torch.dtype, Optional[torch.dtype]]:
|
||||
) -> tuple[torch.dtype, Optional[torch.dtype]]:
|
||||
# even though some reductions, like amin or amax, don't strictly require type promotion,
|
||||
# all the math ops (including comparisons) are still defined only for a computation type,
|
||||
# so promotion will still happen. We are doing it explicitly here
|
||||
@ -1628,7 +1628,7 @@ def reduction_dtypes(
|
||||
# batched_matrix_contiguous_strides and contiguous_strides
|
||||
def make_contiguous_strides_for(
|
||||
shape: ShapeType, row_major: bool = True
|
||||
) -> Tuple[Union[_IntLikeT, int], ...]:
|
||||
) -> tuple[Union[_IntLikeT, int], ...]:
|
||||
"""
|
||||
Returns the strides of a contiguous tensor if row_major
|
||||
If row_major=True, it returns the strides of a contiguous batch of Fortran-contiguous matrices
|
||||
@ -1662,14 +1662,14 @@ def make_contiguous_strides_for(
|
||||
|
||||
def make_channels_last_1d_strides_for(
|
||||
shape: Sequence[_IntLikeT],
|
||||
) -> Tuple[Union[_IntLikeT, int], ...]:
|
||||
) -> tuple[Union[_IntLikeT, int], ...]:
|
||||
torch._check(
|
||||
len(shape) == 3,
|
||||
lambda: "Only tensors of rank 3 can use the channels_last_1d memory format",
|
||||
)
|
||||
|
||||
multiplier: Union[_IntLikeT, int] = 1
|
||||
strides: List[Union[_IntLikeT, int]] = [0] * 3
|
||||
strides: list[Union[_IntLikeT, int]] = [0] * 3
|
||||
for idx in (1, -1, 0):
|
||||
# NOTE: intentionally divergence from make_contiguous_strides_for
|
||||
# This is consistent with eager
|
||||
@ -1681,7 +1681,7 @@ def make_channels_last_1d_strides_for(
|
||||
|
||||
def make_channels_last_2d_strides_for(
|
||||
shape: Sequence[_IntLikeT],
|
||||
) -> Tuple[Union[_IntLikeT, int], ...]:
|
||||
) -> tuple[Union[_IntLikeT, int], ...]:
|
||||
# TODO: maybe inform the user of channels_last_3d if rank of the tensor is 5?
|
||||
torch._check(
|
||||
len(shape) == 4,
|
||||
@ -1689,7 +1689,7 @@ def make_channels_last_2d_strides_for(
|
||||
)
|
||||
|
||||
multiplier: Union[_IntLikeT, int] = 1
|
||||
strides: List[Union[_IntLikeT, int]] = [0] * 4
|
||||
strides: list[Union[_IntLikeT, int]] = [0] * 4
|
||||
for idx in (1, -1, -2, 0):
|
||||
# NOTE: intentionally divergence from make_contiguous_strides_for
|
||||
# This is consistent with eager
|
||||
@ -1701,14 +1701,14 @@ def make_channels_last_2d_strides_for(
|
||||
|
||||
def make_channels_last_3d_strides_for(
|
||||
shape: Sequence[_IntLikeT],
|
||||
) -> Tuple[Union[_IntLikeT, int], ...]:
|
||||
) -> tuple[Union[_IntLikeT, int], ...]:
|
||||
torch._check(
|
||||
len(shape) == 5,
|
||||
lambda: "Only tensors of rank 5 can use the channels_last_3d memory format",
|
||||
)
|
||||
|
||||
multiplier: Union[_IntLikeT, int] = 1
|
||||
strides: List[Union[_IntLikeT, int]] = [0] * 5
|
||||
strides: list[Union[_IntLikeT, int]] = [0] * 5
|
||||
for idx in (1, -1, -2, -3, 0):
|
||||
# NOTE: intentionally divergence from make_contiguous_strides_for
|
||||
# This is consistent with eager
|
||||
@ -1720,7 +1720,7 @@ def make_channels_last_3d_strides_for(
|
||||
|
||||
def make_channels_last_strides_for(
|
||||
shape: Sequence[_IntLikeT],
|
||||
) -> Tuple[Union[_IntLikeT, int], ...]:
|
||||
) -> tuple[Union[_IntLikeT, int], ...]:
|
||||
ndim = len(shape) if isinstance(shape, Sequence) else 1
|
||||
if ndim == 3:
|
||||
return make_channels_last_1d_strides_for(shape)
|
||||
@ -1736,7 +1736,7 @@ def make_channels_last_strides_for(
|
||||
|
||||
def compute_reduction_output_shape(
|
||||
shape: ShapeType, dimensions: Sequence
|
||||
) -> Tuple[int, ...]:
|
||||
) -> tuple[int, ...]:
|
||||
for idx in dimensions:
|
||||
validate_idx(len(shape), idx)
|
||||
|
||||
@ -1755,7 +1755,7 @@ def validate_no_repeating_dims(dims: Sequence):
|
||||
raise RuntimeError("duplicate value in the list of dims")
|
||||
|
||||
|
||||
def reduction_dims(shape: ShapeType, dims: Optional[Sequence]) -> Tuple[int, ...]:
|
||||
def reduction_dims(shape: ShapeType, dims: Optional[Sequence]) -> tuple[int, ...]:
|
||||
if dims is None:
|
||||
return tuple(range(len(shape)))
|
||||
dims = tuple(canonicalize_dim(len(shape), idx) for idx in dims)
|
||||
@ -1848,7 +1848,7 @@ def check_in_bounds_for_storage(
|
||||
category=FutureWarning,
|
||||
)
|
||||
def check(
|
||||
b: bool, s: Callable[[], str], exc_type: Type[Exception] = RuntimeError
|
||||
b: bool, s: Callable[[], str], exc_type: type[Exception] = RuntimeError
|
||||
) -> None:
|
||||
"""
|
||||
Helper function for raising an error_type (default: RuntimeError) if a boolean condition fails.
|
||||
|
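The pattern in the hunks above repeats across the whole commit: the deprecated typing aliases (Tuple, List, Dict, Set, Type) are replaced by the subscriptable builtins from PEP 585. A minimal illustrative sketch, not taken from this commit, showing the new spellings working at runtime on Python 3.9+ (the stride helper is a simplified stand-in, not torch's implementation):

# Illustrative sketch only; names are simplified stand-ins, not torch internals.
def contiguous_strides(shape: tuple[int, ...]) -> tuple[int, ...]:
    # Builtin ``tuple`` is subscriptable on Python 3.9+, so ``typing.Tuple`` is unnecessary.
    strides: list[int] = []
    multiplier = 1
    for dim in reversed(shape):
        strides.append(multiplier)
        multiplier *= dim
    return tuple(reversed(strides))

def check(b: bool, msg: str, exc_type: type[Exception] = RuntimeError) -> None:
    # ``type[Exception]`` replaces ``typing.Type[Exception]``.
    if not b:
        raise exc_type(msg)

assert contiguous_strides((2, 3, 4)) == (12, 4, 1)
check(True, "never raised")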
@ -1,18 +1,11 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import inspect
|
||||
import types
|
||||
import warnings
|
||||
from collections.abc import Sequence
|
||||
from functools import wraps
|
||||
from types import GenericAlias
|
||||
from typing import (
|
||||
Callable,
|
||||
List,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
overload,
|
||||
Sequence,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
)
|
||||
from typing import Callable, NamedTuple, Optional, overload, TypeVar
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
import torch
|
||||
@ -272,7 +265,9 @@ def out_wrapper(
|
||||
bc_out_type = (
|
||||
TensorLikeType
|
||||
if is_tensor
|
||||
else Tuple[tuple(TensorLikeType for _ in range(len(out_names)))]
|
||||
else types.GenericAlias(
|
||||
tuple, tuple(TensorLikeType for _ in range(len(out_names)))
|
||||
)
|
||||
)
|
||||
return_type = (
|
||||
TensorLikeType
|
||||
@ -316,7 +311,7 @@ def out_wrapper(
|
||||
)
|
||||
or (
|
||||
fn.__name__ == "unbind"
|
||||
and isinstance(result, (List, tuple)) # type: ignore[arg-type]
|
||||
and isinstance(result, (list, tuple)) # type: ignore[arg-type]
|
||||
)
|
||||
)
|
||||
# unbind_copy is a special case: see https://github.com/pytorch/pytorch/issues/130829
|
||||
|
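One non-mechanical change above is in out_wrapper: the return annotation is built dynamically, so Tuple[tuple(...)] cannot simply become a subscripted builtin; the diff constructs the alias with types.GenericAlias instead. A hedged sketch, not from this commit, showing that the dynamic construction matches subscripting the builtin directly:

# Illustrative only: types.GenericAlias is the object that ``tuple[...]`` creates.
import types

num_outputs = 3
dynamic = types.GenericAlias(tuple, tuple(int for _ in range(num_outputs)))
static = tuple[int, int, int]

assert dynamic == static
assert dynamic.__origin__ is tuple and dynamic.__args__ == (int, int, int)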
@ -3,7 +3,7 @@ import ast
|
||||
import functools
|
||||
import inspect
|
||||
from textwrap import dedent
|
||||
from typing import Any, List, NamedTuple, Optional, Tuple
|
||||
from typing import Any, NamedTuple, Optional
|
||||
|
||||
from torch._C import ErrorReport
|
||||
from torch._C._jit_tree_views import SourceRangeFactory
|
||||
@ -12,7 +12,7 @@ from torch._C._jit_tree_views import SourceRangeFactory
|
||||
def get_source_lines_and_file(
|
||||
obj: Any,
|
||||
error_msg: Optional[str] = None,
|
||||
) -> Tuple[List[str], int, Optional[str]]:
|
||||
) -> tuple[list[str], int, Optional[str]]:
|
||||
"""
|
||||
Wrapper around inspect.getsourcelines and inspect.getsourcefile.
|
||||
|
||||
@ -35,7 +35,7 @@ def get_source_lines_and_file(
|
||||
return sourcelines, file_lineno, filename
|
||||
|
||||
|
||||
def normalize_source_lines(sourcelines: List[str]) -> List[str]:
|
||||
def normalize_source_lines(sourcelines: list[str]) -> list[str]:
|
||||
"""
|
||||
This helper function accepts a list of source lines. It finds the
|
||||
indentation level of the function definition (`def`), then it indents
|
||||
@ -100,7 +100,7 @@ class SourceContext(SourceRangeFactory):
|
||||
self.funcname = funcname
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
@functools.cache
|
||||
def make_source_context(*args):
|
||||
return SourceContext(*args)
|
||||
|
||||
|
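Besides the typing changes, the hunk above swaps functools.lru_cache(maxsize=None) for functools.cache, the equivalent unbounded spelling available since Python 3.9. A small illustrative sketch, not from this commit:

# Illustrative only: functools.cache is an unbounded lru_cache.
import functools

@functools.cache
def fib(n: int) -> int:
    return n if n < 2 else fib(n - 1) + fib(n - 2)

@functools.lru_cache(maxsize=None)  # older, equivalent spelling
def fib_old(n: int) -> int:
    return n if n < 2 else fib_old(n - 1) + fib_old(n - 2)

assert fib(30) == fib_old(30) == 832040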
@ -6,7 +6,7 @@ import warnings
|
||||
from collections import OrderedDict
|
||||
from copy import deepcopy
|
||||
from numbers import Number
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch._C as _C
|
||||
@ -173,8 +173,8 @@ class Tensor(torch._C.TensorBase):
|
||||
if self.is_quantized:
|
||||
# quantizer_params can be different type based on torch attribute
|
||||
quantizer_params: Union[
|
||||
Tuple[torch.qscheme, float, int],
|
||||
Tuple[torch.qscheme, Tensor, Tensor, int],
|
||||
tuple[torch.qscheme, float, int],
|
||||
tuple[torch.qscheme, Tensor, Tensor, int],
|
||||
]
|
||||
if self.qscheme() == torch.per_tensor_affine:
|
||||
quantizer_params = (
|
||||
@ -317,7 +317,7 @@ class Tensor(torch._C.TensorBase):
|
||||
|
||||
# See Note [Don't serialize hooks]
|
||||
warn_if_has_hooks(self)
|
||||
backward_hooks: Dict[Any, Any] = OrderedDict()
|
||||
backward_hooks: dict[Any, Any] = OrderedDict()
|
||||
|
||||
skip_data = torch.serialization._serialization_tls.skip_data
|
||||
materialize_fake_tensors = (
|
||||
@ -386,7 +386,7 @@ class Tensor(torch._C.TensorBase):
|
||||
)
|
||||
# quantizer_params can be different type based on torch attribute
|
||||
quantizer_params: Union[
|
||||
Tuple[torch.qscheme, float, int], Tuple[Any, Tensor, Tensor, int]
|
||||
tuple[torch.qscheme, float, int], tuple[Any, Tensor, Tensor, int]
|
||||
]
|
||||
if self.qscheme() == torch.per_tensor_affine:
|
||||
quantizer_params = (
|
||||
@ -750,7 +750,7 @@ class Tensor(torch._C.TensorBase):
|
||||
"post accumulate grad hooks cannot be registered on non-leaf tensors"
|
||||
)
|
||||
if self._post_accumulate_grad_hooks is None:
|
||||
self._post_accumulate_grad_hooks: Dict[Any, Any] = OrderedDict()
|
||||
self._post_accumulate_grad_hooks: dict[Any, Any] = OrderedDict()
|
||||
|
||||
from torch.utils.hooks import RemovableHandle
|
||||
|
||||
@ -1493,7 +1493,7 @@ class Tensor(torch._C.TensorBase):
|
||||
return self.to_sparse()
|
||||
|
||||
def dim_order(
|
||||
self, *, ambiguity_check: Union[bool, List[torch.memory_format]] = False
|
||||
self, *, ambiguity_check: Union[bool, list[torch.memory_format]] = False
|
||||
):
|
||||
"""
|
||||
dim_order(ambiguity_check=False) -> tuple
|
||||
@ -1725,7 +1725,7 @@ class Tensor(torch._C.TensorBase):
|
||||
return xla_dlpack.to_dlpack(self)
|
||||
return torch.to_dlpack(self)
|
||||
|
||||
def __dlpack_device__(self) -> Tuple[enum.IntEnum, int]:
|
||||
def __dlpack_device__(self) -> tuple[enum.IntEnum, int]:
|
||||
if has_torch_function_unary(self):
|
||||
return handle_torch_function(Tensor.__dlpack_device__, (self,), self)
|
||||
|
||||
|
@ -3,7 +3,7 @@ import contextlib
|
||||
import dataclasses
|
||||
import math
|
||||
import textwrap
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch import inf
|
||||
@ -95,7 +95,7 @@ def set_printoptions(
|
||||
PRINT_OPTS.sci_mode = sci_mode
|
||||
|
||||
|
||||
def get_printoptions() -> Dict[str, Any]:
|
||||
def get_printoptions() -> dict[str, Any]:
|
||||
r"""Gets the current options for printing, as a dictionary that
|
||||
can be passed as ``**kwargs`` to set_printoptions().
|
||||
"""
|
||||
|
@ -2,7 +2,6 @@
|
||||
"""Adds docstrings to functions defined in the torch._C module."""
|
||||
|
||||
import re
|
||||
from typing import Dict
|
||||
|
||||
import torch._C
|
||||
from torch._C import _add_docstr as add_docstr
|
||||
@ -171,7 +170,7 @@ rocm_fp16_notes = {
|
||||
:ref:`different precision<fp16_on_mi200>` for backward."""
|
||||
}
|
||||
|
||||
reproducibility_notes: Dict[str, str] = {
|
||||
reproducibility_notes: dict[str, str] = {
|
||||
"forward_reproducibility_note": """This operation may behave nondeterministically when given tensors on \
|
||||
a CUDA device. See :doc:`/notes/randomness` for more information.""",
|
||||
"backward_reproducibility_note": """This operation may produce nondeterministic gradients when given tensors on \
|
||||
|
@ -6,7 +6,7 @@ import sys
|
||||
import traceback
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from typing import Any, Callable, DefaultDict, Generic, List, Optional, TYPE_CHECKING
|
||||
from typing import Any, Callable, Generic, Optional, TYPE_CHECKING
|
||||
from typing_extensions import deprecated, ParamSpec
|
||||
|
||||
import torch
|
||||
@ -245,7 +245,7 @@ def _rebuild_tensor_v3(
|
||||
return t
|
||||
|
||||
|
||||
_sparse_tensors_to_validate: List["torch.Tensor"] = []
|
||||
_sparse_tensors_to_validate: list["torch.Tensor"] = []
|
||||
|
||||
|
||||
# In _legacy_load() in serialization.py we unpickle storages after the sparse
|
||||
@ -635,7 +635,7 @@ def _take_tensors(tensors, size_limit):
|
||||
Blocks of tensors of same type and within size_limit. The yielded
|
||||
tensors are only ordered as the original sequence within its types.
|
||||
"""
|
||||
buf_dict: DefaultDict[str, List] = defaultdict(lambda: [[], 0])
|
||||
buf_dict: defaultdict[str, list] = defaultdict(lambda: [[], 0])
|
||||
for tensor in tensors:
|
||||
t = tensor.type()
|
||||
if tensor.is_sparse:
|
||||
@ -674,7 +674,7 @@ def render_call(fn, args, kwargs):
|
||||
if str_fn is None:
|
||||
str_fn = str(fn)
|
||||
|
||||
str_args: List[str] = []
|
||||
str_args: list[str] = []
|
||||
with torch._tensor_str.printoptions(threshold=0, edgeitems=0):
|
||||
str_args.extend(repr(a) for a in args)
|
||||
str_args.extend(f"{k}={repr(v)}" for k, v in kwargs.items())
|
||||
@ -986,7 +986,7 @@ class _LazySeedTracker:
|
||||
# update seed to be latest
|
||||
self.call_order = [self.manual_seed_all_cb, self.manual_seed_cb]
|
||||
|
||||
def get_calls(self) -> List:
|
||||
def get_calls(self) -> list:
|
||||
return self.call_order
|
||||
|
||||
|
||||
@ -997,7 +997,7 @@ P = ParamSpec("P")
|
||||
class CallbackRegistry(Generic[P]):
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
self.callback_list: List[Callable[P, None]] = []
|
||||
self.callback_list: list[Callable[P, None]] = []
|
||||
|
||||
def add_callback(self, cb: Callable[P, None]) -> None:
|
||||
self.callback_list.append(cb)
|
||||
|
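PEP 585 also covers the concrete collections classes, which is why DefaultDict[str, List] above becomes defaultdict[str, list] (and, later in the commit, Deque[Node] becomes deque[Node]). A short illustrative sketch, not from this commit:

# Illustrative only: collections classes are themselves subscriptable on 3.9+.
from collections import defaultdict, deque

buckets: defaultdict[str, list[int]] = defaultdict(list)
buckets["even"].append(2)

pending: deque[str] = deque(["root"])
pending.append("child")

assert dict(buckets) == {"even": [2]}
assert list(pending) == ["root", "child"]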
@ -4,7 +4,7 @@ import logging
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
|
||||
from typing import Any, Callable, Optional, TypeVar
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
import torch
|
||||
@ -116,7 +116,7 @@ def compile_time_strobelight_meta(
|
||||
#
|
||||
# Killswitch is at
|
||||
# https://www.internalfb.com/intern/justknobs/?name=pytorch%2Fsignpost#event
|
||||
def signpost_event(category: str, name: str, parameters: Dict[str, Any]):
|
||||
def signpost_event(category: str, name: str, parameters: dict[str, Any]):
|
||||
log.info("%s %s: %r", category, name, parameters)
|
||||
|
||||
|
||||
@ -231,7 +231,7 @@ def max_clock_rate():
|
||||
return 1100
|
||||
|
||||
|
||||
def get_mast_job_name_version() -> Optional[Tuple[str, int]]:
|
||||
def get_mast_job_name_version() -> Optional[tuple[str, int]]:
|
||||
return None
|
||||
|
||||
|
||||
@ -256,8 +256,8 @@ def maybe_upload_prof_stats_to_manifold(profile_path: str) -> Optional[str]:
|
||||
|
||||
|
||||
def log_chromium_event_internal(
|
||||
event: Dict[str, Any],
|
||||
stack: List[str],
|
||||
event: dict[str, Any],
|
||||
stack: list[str],
|
||||
logger_uuid: str,
|
||||
start_time_ns: int,
|
||||
):
|
||||
@ -265,6 +265,6 @@ def log_chromium_event_internal(
|
||||
|
||||
|
||||
def record_chromium_event_internal(
|
||||
event: Dict[str, Any],
|
||||
event: dict[str, Any],
|
||||
):
|
||||
return None
|
||||
|
@ -1,6 +1,6 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import functools
|
||||
from typing import Any, Callable, List, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from typing_extensions import deprecated
|
||||
|
||||
import torch
|
||||
@ -8,14 +8,14 @@ from torch import Tensor
|
||||
from torch.utils._pytree import _broadcast_to_and_flatten, tree_flatten, tree_unflatten
|
||||
|
||||
|
||||
in_dims_t = Union[int, Tuple]
|
||||
out_dims_t = Union[int, Tuple[int, ...]]
|
||||
in_dims_t = Union[int, tuple]
|
||||
out_dims_t = Union[int, tuple[int, ...]]
|
||||
|
||||
|
||||
# Checks that all args-to-be-batched have the same batch dim size
|
||||
def _validate_and_get_batch_size(
|
||||
flat_in_dims: List[Optional[int]],
|
||||
flat_args: List,
|
||||
flat_in_dims: list[Optional[int]],
|
||||
flat_args: list,
|
||||
) -> int:
|
||||
batch_sizes = [
|
||||
arg.size(in_dim)
|
||||
@ -30,7 +30,7 @@ def _validate_and_get_batch_size(
|
||||
return batch_sizes[0]
|
||||
|
||||
|
||||
def _num_outputs(batched_outputs: Union[Tensor, Tuple[Tensor, ...]]) -> int:
|
||||
def _num_outputs(batched_outputs: Union[Tensor, tuple[Tensor, ...]]) -> int:
|
||||
if isinstance(batched_outputs, tuple):
|
||||
return len(batched_outputs)
|
||||
return 1
|
||||
@ -42,7 +42,7 @@ def _as_tuple(
|
||||
value: Any,
|
||||
num_elements: int,
|
||||
error_message_lambda: Callable[[], str],
|
||||
) -> Tuple:
|
||||
) -> tuple:
|
||||
if not isinstance(value, tuple):
|
||||
return (value,) * num_elements
|
||||
if len(value) != num_elements:
|
||||
@ -54,10 +54,10 @@ def _as_tuple(
|
||||
# Returns the (potentially) batched arguments and the batch_size.
|
||||
def _create_batched_inputs(
|
||||
in_dims: in_dims_t,
|
||||
args: Tuple,
|
||||
args: tuple,
|
||||
vmap_level: int,
|
||||
func: Callable,
|
||||
) -> Tuple[Tuple, int]:
|
||||
) -> tuple[tuple, int]:
|
||||
if not isinstance(in_dims, int) and not isinstance(in_dims, tuple):
|
||||
raise ValueError(
|
||||
f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
|
||||
@ -114,13 +114,13 @@ def _create_batched_inputs(
|
||||
|
||||
# Undos the batching (and any batch dimensions) associated with the `vmap_level`.
|
||||
def _unwrap_batched(
|
||||
batched_outputs: Union[Tensor, Tuple[Tensor, ...]],
|
||||
batched_outputs: Union[Tensor, tuple[Tensor, ...]],
|
||||
out_dims: out_dims_t,
|
||||
vmap_level: int,
|
||||
batch_size: int,
|
||||
func: Callable,
|
||||
allow_none_pass_through: bool = False,
|
||||
) -> Tuple:
|
||||
) -> tuple:
|
||||
num_outputs = _num_outputs(batched_outputs)
|
||||
out_dims_as_tuple = _as_tuple(
|
||||
out_dims,
|
||||
|
@ -68,7 +68,7 @@ from pickle import (
|
||||
)
|
||||
from struct import unpack
|
||||
from sys import maxsize
|
||||
from typing import Any, Callable, Dict, List, Set, Tuple, Union
|
||||
from typing import Any, Callable, Union
|
||||
|
||||
import torch
|
||||
from torch._utils import IMPORT_MAPPING, NAME_MAPPING
|
||||
@ -83,15 +83,15 @@ _blocklisted_modules = [
|
||||
"nt",
|
||||
]
|
||||
|
||||
_marked_safe_globals_set: Set[Union[Callable, Tuple[Callable, str]]] = set()
|
||||
_marked_safe_globals_set: set[Union[Callable, tuple[Callable, str]]] = set()
|
||||
|
||||
|
||||
def _add_safe_globals(safe_globals: List[Union[Callable, Tuple[Callable, str]]]):
|
||||
def _add_safe_globals(safe_globals: list[Union[Callable, tuple[Callable, str]]]):
|
||||
global _marked_safe_globals_set
|
||||
_marked_safe_globals_set = _marked_safe_globals_set.union(set(safe_globals))
|
||||
|
||||
|
||||
def _get_safe_globals() -> List[Union[Callable, Tuple[Callable, str]]]:
|
||||
def _get_safe_globals() -> list[Union[Callable, tuple[Callable, str]]]:
|
||||
global _marked_safe_globals_set
|
||||
return list(_marked_safe_globals_set)
|
||||
|
||||
@ -102,14 +102,14 @@ def _clear_safe_globals():
|
||||
|
||||
|
||||
def _remove_safe_globals(
|
||||
globals_to_remove: List[Union[Callable, Tuple[Callable, str]]],
|
||||
globals_to_remove: list[Union[Callable, tuple[Callable, str]]],
|
||||
):
|
||||
global _marked_safe_globals_set
|
||||
_marked_safe_globals_set = _marked_safe_globals_set - set(globals_to_remove)
|
||||
|
||||
|
||||
class _safe_globals:
|
||||
def __init__(self, safe_globals: List[Union[Callable, Tuple[Callable, str]]]):
|
||||
def __init__(self, safe_globals: list[Union[Callable, tuple[Callable, str]]]):
|
||||
self.safe_globals = safe_globals
|
||||
|
||||
def __enter__(self):
|
||||
@ -127,7 +127,7 @@ class _safe_globals:
|
||||
# the dynamic additions to safe_globals would not be picked up by
|
||||
# _get_allowed_globals due to the lru_cache
|
||||
def _get_user_allowed_globals():
|
||||
rc: Dict[str, Any] = {}
|
||||
rc: dict[str, Any] = {}
|
||||
for f in _marked_safe_globals_set:
|
||||
if isinstance(f, tuple):
|
||||
if len(f) != 2:
|
||||
@ -171,7 +171,7 @@ def _tensor_rebuild_functions():
|
||||
# Unpickling machinery
|
||||
@_functools.lru_cache(maxsize=1)
|
||||
def _get_allowed_globals():
|
||||
rc: Dict[str, Any] = {
|
||||
rc: dict[str, Any] = {
|
||||
"collections.OrderedDict": OrderedDict,
|
||||
"collections.Counter": Counter,
|
||||
"torch.nn.parameter.Parameter": torch.nn.Parameter,
|
||||
@ -221,7 +221,7 @@ def _get_allowed_globals():
|
||||
return rc
|
||||
|
||||
|
||||
def _read_global_instruction(readline: Callable) -> Tuple[str, str]:
|
||||
def _read_global_instruction(readline: Callable) -> tuple[str, str]:
|
||||
module = readline()[:-1].decode("utf-8")
|
||||
name = readline()[:-1].decode("utf-8")
|
||||
# Patch since torch.save default protocol is 2
|
||||
@ -233,7 +233,7 @@ def _read_global_instruction(readline: Callable) -> Tuple[str, str]:
|
||||
return module, name
|
||||
|
||||
|
||||
def get_globals_in_pkl(file) -> Set[str]:
|
||||
def get_globals_in_pkl(file) -> set[str]:
|
||||
globals_in_checkpoint = set()
|
||||
read = file.read
|
||||
readline = file.readline
|
||||
@ -302,7 +302,7 @@ class Unpickler:
|
||||
self.encoding = encoding
|
||||
self.readline = file.readline
|
||||
self.read = file.read
|
||||
self.memo: Dict[int, Any] = {}
|
||||
self.memo: dict[int, Any] = {}
|
||||
self.proto: int = -1
|
||||
|
||||
def load(self):
|
||||
@ -311,7 +311,7 @@ class Unpickler:
|
||||
Return the reconstituted object hierarchy specified in the file.
|
||||
"""
|
||||
self.metastack = []
|
||||
self.stack: List[Any] = []
|
||||
self.stack: list[Any] = []
|
||||
self.append = self.stack.append
|
||||
read = self.read
|
||||
while True:
|
||||
|
@ -5,11 +5,15 @@ import inspect
|
||||
import warnings
|
||||
from collections import abc, defaultdict
|
||||
from enum import Enum
|
||||
from typing import Any, cast, Dict, Iterable, List, Optional, overload, Tuple, Union
|
||||
from typing import Any, cast, Optional, overload, TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterable
|
||||
|
||||
|
||||
__all__ = ["OptState", "GradScaler"]
|
||||
|
||||
|
||||
@ -21,7 +25,7 @@ class _MultiDeviceReplicator:
|
||||
|
||||
def __init__(self, master_tensor: torch.Tensor) -> None:
|
||||
self.master = master_tensor
|
||||
self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}
|
||||
self._per_device_tensors: dict[torch.device, torch.Tensor] = {}
|
||||
|
||||
def get(self, device: torch.device) -> torch.Tensor:
|
||||
retval = self._per_device_tensors.get(device, None)
|
||||
@ -42,7 +46,7 @@ class OptState(Enum):
|
||||
STEPPED = 2
|
||||
|
||||
|
||||
def _refresh_per_optimizer_state() -> Dict[str, Any]:
|
||||
def _refresh_per_optimizer_state() -> dict[str, Any]:
|
||||
return {"stage": OptState.READY, "found_inf_per_device": {}}
|
||||
|
||||
|
||||
@ -147,13 +151,13 @@ class GradScaler:
|
||||
self._init_growth_tracker = 0
|
||||
# self._growth_tracker will be lazily initialized during the first call to scale()
|
||||
self._growth_tracker: Optional[torch.Tensor] = None
|
||||
self._per_optimizer_states: Dict[int, Dict[str, Any]] = defaultdict(
|
||||
self._per_optimizer_states: dict[int, dict[str, Any]] = defaultdict(
|
||||
_refresh_per_optimizer_state
|
||||
)
|
||||
|
||||
def _check_scale_growth_tracker(
|
||||
self, funcname: str
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration."
|
||||
assert self._scale is not None, (
|
||||
f"Attempted {funcname} but _scale is None. " + fix
|
||||
@ -175,11 +179,11 @@ class GradScaler:
|
||||
...
|
||||
|
||||
@overload
|
||||
def scale(self, outputs: List[torch.Tensor]) -> List[torch.Tensor]:
|
||||
def scale(self, outputs: list[torch.Tensor]) -> list[torch.Tensor]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def scale(self, outputs: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
|
||||
def scale(self, outputs: tuple[torch.Tensor, ...]) -> tuple[torch.Tensor, ...]:
|
||||
...
|
||||
|
||||
@overload
|
||||
@ -210,7 +214,7 @@ class GradScaler:
|
||||
return outputs * self._scale.to(device=outputs.device, non_blocking=True)
|
||||
|
||||
# Invoke the more complex machinery only if we're treating multiple outputs.
|
||||
stash: List[
|
||||
stash: list[
|
||||
_MultiDeviceReplicator
|
||||
] = [] # holds a reference that can be overwritten by apply_scale
|
||||
|
||||
@ -237,7 +241,7 @@ class GradScaler:
|
||||
inv_scale: torch.Tensor,
|
||||
found_inf: torch.Tensor,
|
||||
allow_fp16: bool,
|
||||
) -> Dict[torch.device, torch.Tensor]:
|
||||
) -> dict[torch.device, torch.Tensor]:
|
||||
per_device_inv_scale = _MultiDeviceReplicator(inv_scale)
|
||||
per_device_found_inf = _MultiDeviceReplicator(found_inf)
|
||||
|
||||
@ -247,8 +251,8 @@ class GradScaler:
|
||||
|
||||
# https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
|
||||
# Google says mypy struggles with defaultdicts type annotations.
|
||||
per_device_and_dtype_grads: Dict[
|
||||
torch.device, Dict[torch.dtype, List[torch.Tensor]]
|
||||
per_device_and_dtype_grads: dict[
|
||||
torch.device, dict[torch.dtype, list[torch.Tensor]]
|
||||
] = defaultdict(lambda: defaultdict(list))
|
||||
with torch.no_grad():
|
||||
for group in optimizer.param_groups:
|
||||
@ -343,7 +347,7 @@ class GradScaler:
|
||||
def _maybe_opt_step(
|
||||
self,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
optimizer_state: Dict[str, Any],
|
||||
optimizer_state: dict[str, Any],
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Optional[float]:
|
||||
@ -596,7 +600,7 @@ class GradScaler:
|
||||
r"""Return a bool indicating whether this instance is enabled."""
|
||||
return self._enabled
|
||||
|
||||
def state_dict(self) -> Dict[str, Any]:
|
||||
def state_dict(self) -> dict[str, Any]:
|
||||
r"""Return the state of the scaler as a :class:`dict`.
|
||||
|
||||
It contains five entries:
|
||||
@ -623,7 +627,7 @@ class GradScaler:
|
||||
}
|
||||
return {}
|
||||
|
||||
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
|
||||
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
|
||||
r"""Load the scaler state.
|
||||
|
||||
If this instance is disabled, :meth:`load_state_dict` is a no-op.
|
||||
@ -650,7 +654,7 @@ class GradScaler:
|
||||
if self._growth_tracker is not None:
|
||||
self._growth_tracker.fill_(state_dict["_growth_tracker"])
|
||||
|
||||
def __getstate__(self) -> Dict[str, Any]:
|
||||
def __getstate__(self) -> dict[str, Any]:
|
||||
state = self.__dict__.copy()
|
||||
if self._enabled:
|
||||
assert len(self._per_optimizer_states) == 0, (
|
||||
@ -666,10 +670,10 @@ class GradScaler:
|
||||
state["_growth_tracker"] = None
|
||||
return state
|
||||
|
||||
def __setstate__(self, state: Dict[str, Any]) -> None:
|
||||
def __setstate__(self, state: dict[str, Any]) -> None:
|
||||
self.__dict__.update(state)
|
||||
|
||||
def _check_inf_per_device(self, optimizer: torch.optim.Optimizer) -> Dict[str, Any]:
|
||||
def _check_inf_per_device(self, optimizer: torch.optim.Optimizer) -> dict[str, Any]:
|
||||
_scale, _ = self._check_scale_growth_tracker("_check_inf_per_device")
|
||||
|
||||
dummy_inv_scale = torch.full((), 1.0, dtype=torch.float32, device=_scale.device)
|
||||
@ -681,5 +685,5 @@ class GradScaler:
|
||||
|
||||
return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]
|
||||
|
||||
def _found_inf_per_device(self, optimizer: torch.optim.Optimizer) -> Dict[str, Any]:
|
||||
def _found_inf_per_device(self, optimizer: torch.optim.Optimizer) -> dict[str, Any]:
|
||||
return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]
|
||||
|
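The grad_scaler hunks above also show how the commit treats names needed only for annotations: Iterable moves out of the runtime typing import and into a TYPE_CHECKING block. A hedged sketch of that pattern with illustrative names, not torch code; it relies on "from __future__ import annotations" (or string annotations) so the guarded name is never evaluated at runtime:

# Illustrative only: annotation-only imports guarded by TYPE_CHECKING.
from __future__ import annotations

from typing import Optional, TYPE_CHECKING

if TYPE_CHECKING:
    # Imported for type checkers only; not executed at runtime.
    from collections.abc import Iterable


def first_positive(values: Iterable[int]) -> Optional[int]:
    for value in values:
        if value > 0:
            return value
    return None


assert first_positive([-1, 0, 3]) == 3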
@ -330,7 +330,7 @@ def backward(
|
||||
|
||||
if is_tensor_like(tensors) or isinstance(tensors, graph.GradientEdge):
|
||||
tensors = cast(
|
||||
Union[Tuple[torch.Tensor], Tuple[graph.GradientEdge]], (tensors,)
|
||||
Union[tuple[torch.Tensor], tuple[graph.GradientEdge]], (tensors,)
|
||||
)
|
||||
else:
|
||||
tensors = tuple(tensors)
|
||||
|
@ -9,7 +9,6 @@ from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
cast,
|
||||
Deque,
|
||||
Literal,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
@ -764,7 +763,7 @@ def _register_logging_hooks_on_whole_graph(
|
||||
if not roots:
|
||||
return
|
||||
seen: set[Node] = set()
|
||||
q: Deque[Node] = deque()
|
||||
q: deque[Node] = deque()
|
||||
for node in roots:
|
||||
if node is not None:
|
||||
seen.add(node)
|
||||
|
@ -2,7 +2,6 @@
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from functools import partial
|
||||
from typing import DefaultDict
|
||||
|
||||
import torch
|
||||
|
||||
@ -115,7 +114,7 @@ def visualize_rec(graph, value_map, name_prefix, pb_graph, executors_it=None):
|
||||
for out, val in zip(subgraph.outputs(), node.outputs()):
|
||||
value_map[val.unique()] = rec_value_map[out.unique()]
|
||||
|
||||
op_id_counter: DefaultDict[str, int] = defaultdict(int)
|
||||
op_id_counter: defaultdict[str, int] = defaultdict(int)
|
||||
|
||||
def name_for(node):
|
||||
kind = node.kind()[node.kind().index("::") + 2 :]
|
||||
|
@ -1,7 +1,8 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import itertools
|
||||
import operator
|
||||
from typing import Any, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Optional, TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@ -154,9 +155,9 @@ def broadcast_shapes(*shapes):
|
||||
|
||||
def split(
|
||||
tensor: Tensor,
|
||||
split_size_or_sections: Union[int, List[int]],
|
||||
split_size_or_sections: Union[int, list[int]],
|
||||
dim: int = 0,
|
||||
) -> Tuple[Tensor, ...]:
|
||||
) -> tuple[Tensor, ...]:
|
||||
r"""Splits the tensor into chunks. Each chunk is a view of the original tensor.
|
||||
|
||||
If :attr:`split_size_or_sections` is an integer type, then :attr:`tensor` will
|
||||
@ -421,13 +422,13 @@ def einsum(*args: Any) -> Tensor:
|
||||
if TYPE_CHECKING:
|
||||
# The JIT doesn't understand Union, so only add type annotation for mypy
|
||||
def meshgrid(
|
||||
*tensors: Union[Tensor, List[Tensor]], indexing: Optional[str] = None
|
||||
) -> Tuple[Tensor, ...]:
|
||||
*tensors: Union[Tensor, list[Tensor]], indexing: Optional[str] = None
|
||||
) -> tuple[Tensor, ...]:
|
||||
return _meshgrid(*tensors, indexing=indexing)
|
||||
|
||||
else:
|
||||
|
||||
def meshgrid(*tensors, indexing: Optional[str] = None) -> Tuple[Tensor, ...]:
|
||||
def meshgrid(*tensors, indexing: Optional[str] = None) -> tuple[Tensor, ...]:
|
||||
r"""Creates grids of coordinates specified by the 1D inputs in `attr`:tensors.
|
||||
|
||||
This is helpful when you want to visualize data over some
|
||||
@ -807,7 +808,7 @@ if TYPE_CHECKING:
|
||||
# done by the caller of the _impl function
|
||||
_unique_impl_out = Any
|
||||
else:
|
||||
_unique_impl_out = Tuple[Tensor, Tensor, Tensor]
|
||||
_unique_impl_out = tuple[Tensor, Tensor, Tensor]
|
||||
|
||||
|
||||
def _unique_impl(
|
||||
@ -817,7 +818,7 @@ def _unique_impl(
|
||||
return_counts: bool = False,
|
||||
dim: Optional[int] = None,
|
||||
) -> _unique_impl_out:
|
||||
r"""unique(input, sorted=True, return_inverse=False, return_counts=False, dim=None) -> Tuple[Tensor, Tensor, Tensor]
|
||||
r"""unique(input, sorted=True, return_inverse=False, return_counts=False, dim=None) -> tuple[Tensor, Tensor, Tensor]
|
||||
|
||||
Returns the unique elements of the input tensor.
|
||||
|
||||
@ -1056,7 +1057,7 @@ def _return_counts(
|
||||
return_counts=False,
|
||||
dim=None,
|
||||
):
|
||||
# type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
|
||||
# type: (Tensor, bool, bool, bool, Optional[int]) -> tuple[Tensor, Tensor]
|
||||
|
||||
if has_torch_function_unary(input):
|
||||
return _unique_impl(input, sorted, return_inverse, return_counts, dim)
|
||||
@ -1088,7 +1089,7 @@ def _return_inverse(
|
||||
return_counts=False,
|
||||
dim=None,
|
||||
):
|
||||
# type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
|
||||
# type: (Tensor, bool, bool, bool, Optional[int]) -> tuple[Tensor, Tensor]
|
||||
|
||||
if has_torch_function_unary(input):
|
||||
return _unique_impl(input, sorted, return_inverse, return_counts, dim)
|
||||
@ -1140,7 +1141,7 @@ def _consecutive_return_counts(
|
||||
return_counts=False,
|
||||
dim=None,
|
||||
):
|
||||
# type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
|
||||
# type: (Tensor, bool, bool, Optional[int]) -> tuple[Tensor, Tensor]
|
||||
|
||||
if has_torch_function_unary(input):
|
||||
return _unique_consecutive_impl(input, return_inverse, return_counts, dim)
|
||||
@ -1172,7 +1173,7 @@ def _consecutive_return_inverse(
|
||||
return_counts=False,
|
||||
dim=None,
|
||||
):
|
||||
# type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
|
||||
# type: (Tensor, bool, bool, Optional[int]) -> tuple[Tensor, Tensor]
|
||||
|
||||
if has_torch_function_unary(input):
|
||||
return _unique_consecutive_impl(input, return_inverse, return_counts, dim)
|
||||
@ -1236,7 +1237,7 @@ else:
|
||||
def tensordot( # noqa: F811
|
||||
a,
|
||||
b,
|
||||
dims: Tuple[List[int], List[int]],
|
||||
dims: tuple[list[int], list[int]],
|
||||
out: Optional[torch.Tensor] = None,
|
||||
):
|
||||
pass
|
||||
@ -1245,7 +1246,7 @@ else:
|
||||
def tensordot( # noqa: F811
|
||||
a,
|
||||
b,
|
||||
dims: List[List[int]],
|
||||
dims: list[list[int]],
|
||||
out: Optional[torch.Tensor] = None,
|
||||
):
|
||||
pass
|
||||
@ -1322,13 +1323,13 @@ def tensordot( # noqa: F811
|
||||
if not isinstance(dims, (tuple, list, torch.Tensor, int, torch.SymInt)):
|
||||
raise RuntimeError(
|
||||
"tensordot expects dims to be int or "
|
||||
+ "Tuple[List[int], List[int]] or "
|
||||
+ "List[List[int]] containing two lists, but got "
|
||||
+ "tuple[list[int], list[int]] or "
|
||||
+ "list[list[int]] containing two lists, but got "
|
||||
+ f"dims={dims}"
|
||||
)
|
||||
|
||||
dims_a: List[int] = []
|
||||
dims_b: List[int] = []
|
||||
dims_a: list[int] = []
|
||||
dims_b: list[int] = []
|
||||
|
||||
if isinstance(dims, (tuple, list)):
|
||||
dims_a, dims_b = dims
|
||||
@ -1337,8 +1338,8 @@ def tensordot( # noqa: F811
|
||||
num_elements = dims.numel()
|
||||
if num_elements > 1:
|
||||
assert dims.size()[0] == 2
|
||||
dims_a = torch.jit.annotate(List[int], dims[0].tolist())
|
||||
dims_b = torch.jit.annotate(List[int], dims[1].tolist())
|
||||
dims_a = torch.jit.annotate(list[int], dims[0].tolist())
|
||||
dims_b = torch.jit.annotate(list[int], dims[1].tolist())
|
||||
else:
|
||||
dims_val = int(dims.item())
|
||||
if dims_val < 0:
|
||||
@ -1896,7 +1897,7 @@ def norm( # noqa: F811
|
||||
def unravel_index(
|
||||
indices: Tensor,
|
||||
shape: Union[int, Sequence[int], torch.Size],
|
||||
) -> Tuple[Tensor, ...]:
|
||||
) -> tuple[Tensor, ...]:
|
||||
r"""Converts a tensor of flat indices into a tuple of coordinate tensors that
|
||||
index into an arbitrary tensor of the specified shape.
|
||||
|
||||
@ -2041,7 +2042,7 @@ def chain_matmul(*matrices, out=None):
|
||||
|
||||
|
||||
def _lu_impl(A, pivot=True, get_infos=False, out=None):
|
||||
# type: (Tensor, bool, bool, Any) -> Tuple[Tensor, Tensor, Tensor]
|
||||
# type: (Tensor, bool, bool, Any) -> tuple[Tensor, Tensor, Tensor]
|
||||
r"""Computes the LU factorization of a matrix or batches of matrices
|
||||
:attr:`A`. Returns a tuple containing the LU factorization and
|
||||
pivots of :attr:`A`. Pivoting is done if :attr:`pivot` is set to
|
||||
@ -2143,7 +2144,7 @@ def _lu_impl(A, pivot=True, get_infos=False, out=None):
|
||||
if TYPE_CHECKING:
|
||||
_ListOrSeq = Sequence[Tensor]
|
||||
else:
|
||||
_ListOrSeq = List[Tensor]
|
||||
_ListOrSeq = list[Tensor]
|
||||
|
||||
|
||||
def _check_list_size(out_len: int, get_infos: bool, out: _ListOrSeq) -> None:
|
||||
@ -2159,7 +2160,7 @@ def _check_list_size(out_len: int, get_infos: bool, out: _ListOrSeq) -> None:
|
||||
|
||||
|
||||
def _lu_with_infos(A, pivot=True, get_infos=False, out=None):
|
||||
# type: (Tensor, bool, bool, Optional[Tuple[Tensor, Tensor, Tensor]]) -> Tuple[Tensor, Tensor, Tensor]
|
||||
# type: (Tensor, bool, bool, Optional[tuple[Tensor, Tensor, Tensor]]) -> tuple[Tensor, Tensor, Tensor]
|
||||
if has_torch_function_unary(A):
|
||||
return handle_torch_function(
|
||||
lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out
|
||||
@ -2175,7 +2176,7 @@ def _lu_with_infos(A, pivot=True, get_infos=False, out=None):
|
||||
|
||||
|
||||
def _lu_no_infos(A, pivot=True, get_infos=False, out=None):
|
||||
# type: (Tensor, bool, bool, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor]
|
||||
# type: (Tensor, bool, bool, Optional[tuple[Tensor, Tensor]]) -> tuple[Tensor, Tensor]
|
||||
# need to check for torch_function here so that we exit if
|
||||
if has_torch_function_unary(A):
|
||||
return handle_torch_function(
|
||||
|
@ -27,7 +27,7 @@ class Future(torch._C.Future, Generic[T], metaclass=_PyFutureMeta):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, *, devices: Optional[List[Union[int, str, torch.device]]] = None
|
||||
self, *, devices: Optional[list[Union[int, str, torch.device]]] = None
|
||||
):
|
||||
r"""
|
||||
Create an empty unset ``Future``. If the future is intended to hold
|
||||
@ -278,7 +278,7 @@ class Future(torch._C.Future, Generic[T], metaclass=_PyFutureMeta):
|
||||
self.set_result(result) # type: ignore[arg-type]
|
||||
|
||||
|
||||
def collect_all(futures: List[Future]) -> Future[List[Future]]:
|
||||
def collect_all(futures: list[Future]) -> Future[list[Future]]:
|
||||
r"""
|
||||
Collects the provided :class:`~torch.futures.Future` objects into a single
|
||||
combined :class:`~torch.futures.Future` that is completed when all of the
|
||||
@ -305,12 +305,12 @@ def collect_all(futures: List[Future]) -> Future[List[Future]]:
|
||||
fut1 result = 1
|
||||
"""
|
||||
return cast(
|
||||
Future[List[Future]],
|
||||
torch._C._collect_all(cast(List[torch._C.Future], futures)),
|
||||
Future[list[Future]],
|
||||
torch._C._collect_all(cast(list[torch._C.Future], futures)),
|
||||
)
|
||||
|
||||
|
||||
def wait_all(futures: List[Future]) -> List:
|
||||
def wait_all(futures: list[Future]) -> list:
|
||||
r"""
|
||||
Waits for all provided futures to be complete, and returns
|
||||
the list of completed values. If any of the futures encounters an error,
|
||||
@ -327,5 +327,5 @@ def wait_all(futures: List[Future]) -> List:
|
||||
"""
|
||||
return [
|
||||
fut.wait()
|
||||
for fut in torch._C._collect_all(cast(List[torch._C.Future], futures)).wait()
|
||||
for fut in torch._C._collect_all(cast(list[torch._C.Future], futures)).wait()
|
||||
]
|
||||
|
@ -12,7 +12,7 @@ import uuid
|
||||
import warnings
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from typing import Any, Optional, Union
|
||||
from typing_extensions import deprecated
|
||||
from urllib.error import HTTPError, URLError
|
||||
from urllib.parse import urlparse # noqa: F401
|
||||
@ -784,7 +784,7 @@ def _legacy_zip_load(
|
||||
model_dir: str,
|
||||
map_location: MAP_LOCATION,
|
||||
weights_only: bool,
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
# Note: extractall() defaults to overwrite file if exists. No need to clean up beforehand.
|
||||
# We deliberately don't handle tarfile here since our legacy serialization format was in tar.
|
||||
# E.g. resnet18-5c106cde.pth which is widely used.
|
||||
@ -808,7 +808,7 @@ def load_state_dict_from_url(
|
||||
check_hash: bool = False,
|
||||
file_name: Optional[str] = None,
|
||||
weights_only: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
r"""Loads the Torch serialized object at the given URL.
|
||||
|
||||
If downloaded file is a zip file, it will be automatically
|
||||
|
@ -7,22 +7,22 @@
|
||||
# a type attached and restored via `restore_type_tag` below. The legacy
|
||||
# functions should stick around for backwards-compatibility.
|
||||
|
||||
from typing import List, Union
|
||||
from typing import Union
|
||||
|
||||
|
||||
def build_intlist(data: List[int]) -> List[int]:
|
||||
def build_intlist(data: list[int]) -> list[int]:
|
||||
return data
|
||||
|
||||
|
||||
def build_tensorlist(data: List[object]) -> List[object]:
|
||||
def build_tensorlist(data: list[object]) -> list[object]:
|
||||
return data
|
||||
|
||||
|
||||
def build_doublelist(data: List[float]) -> List[float]:
|
||||
def build_doublelist(data: list[float]) -> list[float]:
|
||||
return data
|
||||
|
||||
|
||||
def build_boollist(data: List[bool]) -> List[bool]:
|
||||
def build_boollist(data: list[bool]) -> list[bool]:
|
||||
return data
|
||||
|
||||
|
||||
|
@ -6,17 +6,13 @@ import re
|
||||
import sys
|
||||
import traceback
|
||||
import weakref
|
||||
from collections.abc import Sequence
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
overload,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
TYPE_CHECKING,
|
||||
TypeVar,
|
||||
Union,
|
||||
@ -59,8 +55,8 @@ _P = ParamSpec("_P")
|
||||
# The keys in the set are of the form `namespace + "/" + op_name + "/" + dispatch_key`.
|
||||
# This set is maintained to ensure that two libraries don't try to override the exact same functionality to avoid
|
||||
# libraries calling into kernels not intended to be called.
|
||||
_impls: Set[str] = set()
|
||||
_defs: Set[str] = set()
|
||||
_impls: set[str] = set()
|
||||
_defs: set[str] = set()
|
||||
|
||||
# prim is reserved by TorchScript interpreter
|
||||
_reserved_namespaces = ["prim"]
|
||||
@ -111,9 +107,9 @@ class Library:
|
||||
kind, ns, dispatch_key, filename, lineno
|
||||
)
|
||||
self.ns = ns
|
||||
self._op_defs: Set[str] = set()
|
||||
self._op_impls: Set[str] = set()
|
||||
self._registration_handles: List[torch._library.utils.RegistrationHandle] = []
|
||||
self._op_defs: set[str] = set()
|
||||
self._op_impls: set[str] = set()
|
||||
self._registration_handles: list[torch._library.utils.RegistrationHandle] = []
|
||||
self.kind = kind
|
||||
self.dispatch_key = dispatch_key
|
||||
# Use a finalizer to setup the "destructor" instead of __del__.
|
||||
@ -459,7 +455,7 @@ def _scoped_library(*args, **kwargs):
|
||||
lib._destroy()
|
||||
|
||||
|
||||
_keep_alive: List[Library] = []
|
||||
_keep_alive: list[Library] = []
|
||||
|
||||
|
||||
NAMELESS_SCHEMA = re.compile(r"\(.*\) -> .*")
|
||||
@ -1362,12 +1358,12 @@ _OPCHECK_DEFAULT_UTILS = (
|
||||
|
||||
def opcheck(
|
||||
op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket, CustomOpDef],
|
||||
args: Tuple[Any, ...],
|
||||
kwargs: Optional[Dict[str, Any]] = None,
|
||||
args: tuple[Any, ...],
|
||||
kwargs: Optional[dict[str, Any]] = None,
|
||||
*,
|
||||
test_utils: Union[str, Sequence[str]] = _OPCHECK_DEFAULT_UTILS,
|
||||
raise_exception: bool = True,
|
||||
) -> Dict[str, str]:
|
||||
) -> dict[str, str]:
|
||||
"""Given an operator and some sample arguments, tests if the operator is
|
||||
registered correctly.
|
||||
|
||||
|
@ -27,8 +27,9 @@ import contextlib
|
||||
import functools
|
||||
import types
|
||||
import warnings
|
||||
from collections.abc import Iterable
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
from torch._C import (
|
||||
@ -95,7 +96,7 @@ def _disable_user_warnings(
|
||||
|
||||
@functools.lru_cache(None)
|
||||
@_disable_user_warnings
|
||||
def get_ignored_functions() -> Set[Callable]:
|
||||
def get_ignored_functions() -> set[Callable]:
|
||||
"""
|
||||
Return public functions that cannot be overridden by ``__torch_function__``.
|
||||
|
||||
@ -374,7 +375,7 @@ def get_ignored_functions() -> Set[Callable]:
|
||||
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def get_default_nowrap_functions() -> Set[Callable]:
|
||||
def get_default_nowrap_functions() -> set[Callable]:
|
||||
"""
|
||||
Return public functions that do not wrap in a subclass when invoked by
|
||||
the default ``Tensor.__torch_function__`` that preserves subclasses. Typically,
|
||||
@ -401,7 +402,7 @@ def get_default_nowrap_functions() -> Set[Callable]:
|
||||
|
||||
@functools.lru_cache(None)
|
||||
@_disable_user_warnings
|
||||
def get_testing_overrides() -> Dict[Callable, Callable]:
|
||||
def get_testing_overrides() -> dict[Callable, Callable]:
|
||||
"""Return a dict containing dummy overrides for all overridable functions
|
||||
|
||||
Returns
|
||||
@ -427,7 +428,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
||||
# function signatures for native kernels that can be consumed by inspect.
|
||||
# See Issue #28233.
|
||||
Tensor = torch.Tensor
|
||||
ret: Dict[Callable, Callable] = {
|
||||
ret: dict[Callable, Callable] = {
|
||||
torch.abs: lambda input, out=None: -1,
|
||||
torch.absolute: lambda input, out=None: -1,
|
||||
torch.adaptive_avg_pool1d: lambda input, output_size: -1,
|
||||
@ -1592,8 +1593,8 @@ def wrap_torch_function(dispatcher: Callable):
|
||||
|
||||
def _get_overloaded_args(
|
||||
relevant_args: Iterable[Any],
|
||||
get_type_fn: Optional[Callable[[Any], Type]] = None,
|
||||
) -> List[Any]:
|
||||
get_type_fn: Optional[Callable[[Any], type]] = None,
|
||||
) -> list[Any]:
|
||||
"""Returns a list of arguments on which to call __torch_function__.
|
||||
|
||||
Checks arguments in relevant_args for __torch_function__ implementations,
|
||||
@ -1634,8 +1635,8 @@ def _get_overloaded_args(
|
||||
if not torch._C._is_torch_function_enabled():
|
||||
return []
|
||||
# Runtime is O(num_arguments * num_unique_types)
|
||||
overloaded_types: Set[Type] = set()
|
||||
overloaded_args: List[Any] = []
|
||||
overloaded_types: set[type] = set()
|
||||
overloaded_args: list[Any] = []
|
||||
for arg in relevant_args:
|
||||
arg_type = get_type_fn(arg)
|
||||
# We only collect arguments if they have a unique type, which ensures
|
||||
@ -1807,7 +1808,7 @@ has_torch_function_variadic = _add_docstr(
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def _get_overridable_functions() -> (
|
||||
Tuple[Dict[Any, List[Callable]], Dict[Callable, str]]
|
||||
tuple[dict[Any, list[Callable]], dict[Callable, str]]
|
||||
):
|
||||
overridable_funcs = collections.defaultdict(list)
|
||||
index = {}
|
||||
@ -1893,7 +1894,7 @@ def _get_overridable_functions() -> (
|
||||
|
||||
|
||||
@_disable_user_warnings
|
||||
def get_overridable_functions() -> Dict[Any, List[Callable]]:
|
||||
def get_overridable_functions() -> dict[Any, list[Callable]]:
|
||||
"""List functions that are overridable via __torch_function__
|
||||
|
||||
Returns
|
||||
@ -1927,7 +1928,7 @@ def resolve_name(f):
|
||||
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def _get_tensor_methods() -> Set[Callable]:
|
||||
def _get_tensor_methods() -> set[Callable]:
|
||||
"""Returns a set of the overridable methods on ``torch.Tensor``"""
|
||||
overridable_funcs = get_overridable_functions()
|
||||
methods = set(overridable_funcs[torch.Tensor])
|
||||
|
@ -1,7 +1,7 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import contextlib
|
||||
import warnings
|
||||
from typing import Generator
|
||||
from collections.abc import Generator
|
||||
|
||||
import torch
|
||||
from torch._C import default_generator
|
||||
|
@ -1,6 +1,7 @@
|
||||
# mypy: allow-untyped-defs
|
||||
from collections.abc import Iterable
|
||||
from math import sqrt
|
||||
from typing import Callable, Iterable, Optional, TypeVar
|
||||
from typing import Callable, Optional, TypeVar
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
@ -8,16 +8,7 @@ import functools
|
||||
import io
|
||||
import threading
|
||||
import warnings
|
||||
from typing import (
|
||||
Any,
|
||||
cast,
|
||||
Dict as _Dict,
|
||||
Optional as _Optional,
|
||||
Type,
|
||||
TYPE_CHECKING,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
from typing import Any, cast, Optional as _Optional, TYPE_CHECKING, TypeVar, Union
|
||||
from typing_extensions import Self
|
||||
|
||||
import torch
|
||||
@ -42,7 +33,7 @@ except ModuleNotFoundError:
|
||||
|
||||
|
||||
_share_memory_lock = threading.Lock()
|
||||
_share_memory_map: _Dict[int, threading.RLock] = {}
|
||||
_share_memory_map: dict[int, threading.RLock] = {}
|
||||
|
||||
T = TypeVar("T", bound="Union[_StorageBase, TypedStorage]")
|
||||
|
||||
@ -136,35 +127,35 @@ class _StorageBase:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def _new_using_filename_cpu(cls: Type[T], size: _int) -> T:
|
||||
def _new_using_filename_cpu(cls, size: _int) -> Self:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def _new_using_fd_cpu(cls: Type[T], size: _int) -> T:
|
||||
def _new_using_fd_cpu(cls, size: _int) -> Self:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def from_buffer(cls: Type[T], *args, **kwargs) -> T:
|
||||
def from_buffer(cls, *args, **kwargs) -> Self:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def _new_shared_filename_cpu(
|
||||
cls: Type[T],
|
||||
cls,
|
||||
manager,
|
||||
obj,
|
||||
size,
|
||||
*,
|
||||
device=None,
|
||||
dtype=None,
|
||||
) -> T:
|
||||
) -> Self:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def _release_ipc_counter_cuda(cls: Type[T], *args, **kwargs) -> T:
|
||||
def _release_ipc_counter_cuda(cls, *args, **kwargs) -> Self:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def _new_with_weak_ptr(cls: Type[T], *args, **kwargs) -> T:
|
||||
def _new_with_weak_ptr(cls, *args, **kwargs) -> Self:
|
||||
raise NotImplementedError
|
||||
|
||||
def _shared_decref(self) -> Union[_StorageBase, TypedStorage]:
|
||||
@ -192,7 +183,7 @@ class _StorageBase:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def _new_shared_cuda(cls: Type[T], *args, **kwargs) -> T:
|
||||
def _new_shared_cuda(cls, *args, **kwargs) -> Self:
|
||||
raise NotImplementedError
|
||||
|
||||
def _shared_incref(self, *args, **kwargs):
|
||||
@ -535,7 +526,7 @@ def _load_from_bytes(b):
|
||||
return torch.load(io.BytesIO(b), weights_only=False)
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
@functools.cache
|
||||
def _new_dtypes():
|
||||
# These are dtypes serialized as UntypedStorage unlike those in
|
||||
# _dtype_to_storage_type_map
|
||||
@ -556,7 +547,7 @@ def _new_dtypes():
|
||||
}
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
@functools.cache
|
||||
def _dtype_to_storage_type_map():
|
||||
# NOTE: We should no longer add dtypes to this map. This map
|
||||
# is only used for BC/FC with older PyTorch versions. Going forward,
|
||||
@ -584,7 +575,7 @@ def _dtype_to_storage_type_map():
|
||||
}
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
@functools.cache
|
||||
def _storage_type_to_dtype_map():
|
||||
dtype_map = {val: key for key, val in _dtype_to_storage_type_map().items()}
|
||||
return dtype_map
|
||||
|
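The storage hunks above replace the old "cls: Type[T] ... -> T" classmethod signatures with typing_extensions.Self, which also removes the hand-written TypeVar. A minimal illustrative sketch; the class and method names here are made up, not torch's:

# Illustrative only; Self is typing.Self on Python 3.11+, typing_extensions.Self before that.
from typing_extensions import Self


class Buffer:
    def __init__(self, size: int) -> None:
        self.size = size

    @classmethod
    def from_bytes(cls, data: bytes) -> Self:
        # ``Self`` binds to the class the method is called on, so subclasses
        # get the precise return type without a module-level TypeVar.
        return cls(len(data))


class PinnedBuffer(Buffer):
    pass


buf = PinnedBuffer.from_bytes(b"abcd")
assert isinstance(buf, PinnedBuffer) and buf.size == 4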
@ -1,4 +1,5 @@
|
||||
from typing import Any, Iterable
|
||||
from collections.abc import Iterable
|
||||
from typing import Any
|
||||
|
||||
from torch._vendor.packaging.version import InvalidVersion, Version
|
||||
from torch.version import __version__ as internal_version
|
||||
|
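The torch_version hunk above is the import-side counterpart of the annotation changes: abstract container types such as Iterable and Sequence now come from collections.abc, while typing keeps only names with no runtime home (Any, Union, TYPE_CHECKING, and so on). A brief illustrative sketch, not from this commit:

# Illustrative only: the collections.abc classes are generic at runtime on 3.9+.
from collections.abc import Iterable, Sequence

def total(values: Iterable[int]) -> int:
    return sum(values)

def head(items: Sequence[str]) -> str:
    return items[0]

assert total(range(4)) == 6
assert head(["a", "b"]) == "a"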
@ -12,7 +12,8 @@ from builtins import ( # noqa: F401
|
||||
int as _int,
|
||||
str as _str,
|
||||
)
|
||||
from typing import Any, Dict, List, Sequence, Tuple, TYPE_CHECKING, Union
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, TYPE_CHECKING, Union
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
# `as` imports have better static analysis support than assignment `ExposedType: TypeAlias = HiddenType`
|
||||
@ -46,7 +47,7 @@ _TensorOrTensorsOrGradEdge: TypeAlias = Union[ # noqa: PYI047
|
||||
Sequence["GradientEdge"],
|
||||
]
|
||||
|
||||
_size: TypeAlias = Union[Size, List[int], Tuple[int, ...]] # noqa: PYI042,PYI047
|
||||
_size: TypeAlias = Union[Size, list[int], tuple[int, ...]] # noqa: PYI042,PYI047
|
||||
_symsize: TypeAlias = Union[Size, Sequence[Union[int, SymInt]]] # noqa: PYI042,PYI047
|
||||
_dispatchkey: TypeAlias = Union[str, DispatchKey] # noqa: PYI042,PYI047
|
||||
|
||||
@ -76,7 +77,7 @@ class Storage:
|
||||
dtype: _dtype
|
||||
_torch_load_uninitialized: bool
|
||||
|
||||
def __deepcopy__(self, memo: Dict[int, Any]) -> "Storage":
|
||||
def __deepcopy__(self, memo: dict[int, Any]) -> "Storage":
|
||||
raise NotImplementedError
|
||||
|
||||
def _new_shared(self, size: int) -> "Storage":
|
||||
|