diff --git a/torch/_C/_dynamo/compiled_autograd.pyi b/torch/_C/_dynamo/compiled_autograd.pyi index 97d114e06fbb..2f2a1fec522b 100644 --- a/torch/_C/_dynamo/compiled_autograd.pyi +++ b/torch/_C/_dynamo/compiled_autograd.pyi @@ -1,11 +1,11 @@ -from typing import Callable, Tuple +from typing import Callable from torch._dynamo.compiled_autograd import AutogradCompilerInstance def set_autograd_compiler( autograd_compiler: Callable[[], AutogradCompilerInstance] | None, dynamic: bool, -) -> Tuple[Callable[[], AutogradCompilerInstance] | None, bool]: ... +) -> tuple[Callable[[], AutogradCompilerInstance] | None, bool]: ... def clear_cache() -> None: ... def is_cache_empty() -> bool: ... def set_verbose_logger(fn: Callable[[str], None] | None) -> bool: ... diff --git a/torch/_C/_dynamo/eval_frame.pyi b/torch/_C/_dynamo/eval_frame.pyi index a8f91ff9e20d..49b1c6d1b735 100644 --- a/torch/_C/_dynamo/eval_frame.pyi +++ b/torch/_C/_dynamo/eval_frame.pyi @@ -1,5 +1,5 @@ import types -from typing import Dict, NewType, Tuple +from typing import NewType from torch._dynamo.types import DynamoCallback, DynamoGuardHook @@ -31,17 +31,17 @@ class _ExtraState: # properties Dynamo cares about for a frame. class _PyInterpreterFrame: f_code: types.CodeType - f_locals: Dict[str, object] - f_globals: Dict[str, object] - f_builtins: Dict[str, object] + f_locals: dict[str, object] + f_globals: dict[str, object] + f_builtins: dict[str, object] f_lasti: int f_lineo: int f_back: types.FrameType # A tuple containing cell objects captured by this frame. - closure: Tuple[types.CellType] + closure: tuple[types.CellType] def _debug_get_cache_entry_list(code: types.CodeType) -> list[_CacheEntry]: ... py_opcode_caches: list[int] -def code_framelocals_names(code: types.CodeType) -> Tuple[str]: ... +def code_framelocals_names(code: types.CodeType) -> tuple[str]: ... diff --git a/torch/_C/_dynamo/guards.pyi b/torch/_C/_dynamo/guards.pyi index 9ec79a08301b..efc8306eebd4 100644 --- a/torch/_C/_dynamo/guards.pyi +++ b/torch/_C/_dynamo/guards.pyi @@ -1,5 +1,5 @@ # mypy: allow-untyped-defs -from typing import Any, Callable, Dict +from typing import Any, Callable import torch @@ -121,7 +121,7 @@ def install_storage_overlapping_guard( ): ... def profile_guard_manager( guard_manager: GuardManager, - f_locals: Dict[str, Any], + f_locals: dict[str, Any], ) -> float: ... class TensorGuards: diff --git a/torch/_C/_functions.pyi b/torch/_C/_functions.pyi index 422e59984d03..5b0dee51a710 100644 --- a/torch/_C/_functions.pyi +++ b/torch/_C/_functions.pyi @@ -1,4 +1,4 @@ -from typing import AnyStr, overload, Tuple +from typing import AnyStr, overload from torch import Tensor @@ -16,4 +16,4 @@ class DelayedError: @overload def __call__(self, i0: Tensor) -> Tensor: ... @overload - def __call__(self, *args: Tensor) -> Tuple[Tensor, ...]: ... + def __call__(self, *args: Tensor) -> tuple[Tensor, ...]: ... 
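[Note, not part of the patch] Every hunk in this diff applies the same PEP 585 cleanup: the deprecated `typing.Dict` / `List` / `Tuple` / `Set` aliases become the builtin `dict` / `list` / `tuple` / `set` generics, and abstract container types such as `Sequence`, `Iterable`, `Iterator`, and `Generator` move from `typing` to `collections.abc`. A minimal sketch of the before/after pattern, with hypothetical names (PEP 585 requires Python 3.9+):

```python
from collections.abc import Sequence   # was: from typing import Sequence
from typing import Callable, Optional  # Callable/Optional/Union still come from typing

# was: registry: Dict[str, Callable[..., object]] = {}
registry: dict[str, Callable[..., object]] = {}

# was: def pad_sizes(sizes: Sequence[int], ...) -> List[Tuple[int, int]]:
def pad_sizes(sizes: Sequence[int], fill: Optional[int] = None) -> list[tuple[int, int]]:
    value = 0 if fill is None else fill
    return [(s, value) for s in sizes]

print(pad_sizes([1, 2, 3]))  # [(1, 0), (2, 0), (3, 0)]
```

Where a module only needs `Sequence` for annotations, some hunks below (for example torch/_numpy/linalg.py and torch/_numpy/_funcs_impl.py) guard the `collections.abc` import behind `if TYPE_CHECKING:` so nothing extra is imported at runtime.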
diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py index 39860d3d4c5f..37b50a2efddf 100644 --- a/torch/_decomp/__init__.py +++ b/torch/_decomp/__init__.py @@ -1,6 +1,7 @@ # mypy: allow-untyped-defs import inspect from collections import defaultdict +from collections.abc import Sequence from functools import lru_cache, partial, wraps from itertools import chain from typing import ( @@ -9,7 +10,6 @@ from typing import ( FrozenSet, List, Optional, - Sequence, Set, TYPE_CHECKING, TypeVar, @@ -44,8 +44,8 @@ _P = ParamSpec("_P") # TODO: relax key type here; torch registrations should be possible to; but # right now this type is accurate -global_decomposition_table: Dict[ - str, Dict[torch._ops.OperatorBase, Callable] +global_decomposition_table: dict[ + str, dict[torch._ops.OperatorBase, Callable] ] = defaultdict(dict) decomposition_table = global_decomposition_table["post_autograd"] @@ -78,7 +78,7 @@ def _add_op_to_registry(registry, op, fn): If op is OpOverload, it will be added to the registry directly. If op is OpOverloadPacket, all the valid op_overloads in the packet will be added to the registry. """ - overloads: List[Union[torch._ops.OperatorBase]] = [] + overloads: list[Union[torch._ops.OperatorBase]] = [] if isinstance(op, HigherOrderOperator): # There's no concept of overloads for HigherOrderOperator registry[op] = fn @@ -232,7 +232,7 @@ def register_decomposition( def get_decompositions( aten_ops: Sequence[Union[torch._ops.OperatorBase, OpOverloadPacket]], type: str = "post_autograd", -) -> Dict[torch._ops.OperatorBase, Callable]: +) -> dict[torch._ops.OperatorBase, Callable]: """ Retrieve a dictionary of decompositions corresponding to the list of operator overloads and overload packets passed as input. Overload @@ -251,7 +251,7 @@ def get_decompositions( for opo in registry: if isinstance(opo, (OpOverload, OpOverloadPacket)): packets_to_overloads[opo.overloadpacket].append(opo) - decompositions: Dict[torch._ops.OperatorBase, Callable] = {} + decompositions: dict[torch._ops.OperatorBase, Callable] = {} for op in aten_ops: if isinstance(op, OpOverloadPacket) and op in packets_to_overloads: for op_overload in packets_to_overloads[op]: @@ -262,7 +262,7 @@ def get_decompositions( def remove_decompositions( - decompositions: Dict[torch._ops.OperatorBase, Callable], + decompositions: dict[torch._ops.OperatorBase, Callable], aten_ops: Sequence[Union[OpOverload, OpOverloadPacket]], ) -> None: """ @@ -297,7 +297,7 @@ def core_aten_decompositions() -> "CustomDecompTable": # excluding decompositions that results in prim ops # Resulting opset of decomposition is core aten ops def _core_aten_decompositions_post_autograd() -> ( - Dict[torch._ops.OperatorBase, Callable] + dict[torch._ops.OperatorBase, Callable] ): aten = torch.ops.aten return get_decompositions( diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index 1ae86e04aef2..9543da39dd14 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -5,10 +5,11 @@ import itertools import numbers import operator import sys +from collections.abc import Iterable from enum import Enum from functools import partial, reduce from itertools import chain, product -from typing import Any, Callable, cast, Iterable, List, Optional, Tuple, Union +from typing import Any, Callable, cast, Optional, Union import torch import torch._meta_registrations @@ -39,7 +40,7 @@ DispatchKey = torch._C.DispatchKey # type: ignore[attr-defined] # None of these functions are publicly accessible; get at 
them # from torch._decomps -__all__: List[str] = [] +__all__: list[str] = [] aten = torch._ops.ops.aten @@ -299,7 +300,7 @@ def _prelu_kernel_backward( grad_output: Tensor, self: Tensor, weight: Tensor, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: input_grad = torch.where(self > 0, grad_output, weight * grad_output) weight_grad = torch.where(self > 0, 0.0, self * grad_output) return (input_grad, weight_grad) @@ -690,7 +691,7 @@ def _euclidean_dist(x1: Tensor, x2: Tensor) -> Tensor: @out_wrapper() def slice_backward( grad_output: Tensor, - input_sizes: List[int], + input_sizes: list[int], dim: int, start: int, end: int, @@ -760,7 +761,7 @@ def slice_forward( def _normalize_start_end( x: Tensor, dim: int, start: Optional[int], end: Optional[int] -) -> Tuple[int, int]: +) -> tuple[int, int]: """ Normalize start and end such that both are in the range [0, x.get_size()[dim]] and start <= end. @@ -824,7 +825,7 @@ def slice_scatter( @register_decomposition(aten.select_backward) @out_wrapper() -def select_backward(grad_output: Tensor, input_sizes: List[int], dim: int, index: int): +def select_backward(grad_output: Tensor, input_sizes: list[int], dim: int, index: int): grad_input = grad_output.new_zeros(input_sizes) return torch.select_scatter(grad_input, grad_output, dim, index) @@ -832,7 +833,7 @@ def select_backward(grad_output: Tensor, input_sizes: List[int], dim: int, index @register_decomposition(aten.diagonal_backward) @out_wrapper() def diagonal_backward( - grad_output: Tensor, input_sizes: List[int], offset: int, dim1: int, dim2: int + grad_output: Tensor, input_sizes: list[int], offset: int, dim1: int, dim2: int ): grad_input = grad_output.new_zeros(input_sizes) return torch.diagonal_scatter(grad_input, grad_output, offset, dim1, dim2) @@ -899,10 +900,10 @@ def _im2col_col2im_indices_along_dim( @out_wrapper() def im2col( input: Tensor, - kernel_size: List[int], - dilation: List[int], - padding: List[int], - stride: List[int], + kernel_size: list[int], + dilation: list[int], + padding: list[int], + stride: list[int], ) -> Tensor: torch._check(len(kernel_size) == 2, lambda: "im2col(): only 2D kernel supported") torch._check(len(dilation) == 2, lambda: "im2col(): only 2D dilation supported") @@ -982,11 +983,11 @@ def im2col( @pw_cast_for_opmath def col2im( input: Tensor, - output_size: List[int], - kernel_size: List[int], - dilation: List[int], - padding: List[int], - stride: List[int], + output_size: list[int], + kernel_size: list[int], + dilation: list[int], + padding: list[int], + stride: list[int], ) -> Tensor: torch._check(len(output_size) == 2, lambda: "only 2D output_size supported") torch._check(len(kernel_size) == 2, lambda: "only 2D kernel supported") @@ -1094,7 +1095,7 @@ def native_dropout_backward(grad_output: Tensor, mask: Tensor, scale: float): @register_decomposition(aten.unfold_backward) @out_wrapper() def unfold_backward( - grad: Tensor, input_size: List[int], dimension: int, size: int, step: int + grad: Tensor, input_size: list[int], dimension: int, size: int, step: int ) -> Tensor: if len(input_size) == 0: return torch.squeeze_copy(grad, 0) @@ -1257,7 +1258,7 @@ def embedding_dense_backward( ) -def prod(x: List[int]): +def prod(x: list[int]): r = 1 for i in x: r *= i @@ -1265,10 +1266,10 @@ def prod(x: List[int]): def _pad_chunk( - tensors: List[Tensor], + tensors: list[Tensor], dim: int, num_chunks: int, -) -> List[Tensor]: +) -> list[Tensor]: padded_tensors = [] for tensor in tensors: tensor_size = tensor.size() @@ -1285,7 +1286,7 @@ def _pad_chunk( return 
padded_tensors -def have_same_ndims(tensors: List[Tensor]): +def have_same_ndims(tensors: list[Tensor]): ndim = tensors[0].ndim for tensor in tensors: if tensor.ndim != ndim: @@ -1293,7 +1294,7 @@ def have_same_ndims(tensors: List[Tensor]): return True -def leading_dimension_matches(tensors: List[Tensor], dim: int): +def leading_dimension_matches(tensors: list[Tensor], dim: int): leading_dim_sizes = tensors[0].size()[:dim] for tensor in tensors: torch._check( @@ -1303,7 +1304,7 @@ def leading_dimension_matches(tensors: List[Tensor], dim: int): def _preprocess_chunk_cat_inputs( - tensors: List[Tensor], + tensors: list[Tensor], dim: int, num_chunks: int, ): @@ -1341,7 +1342,7 @@ def _preprocess_chunk_cat_inputs( @register_decomposition([aten._chunk_cat.default, aten._chunk_cat.out]) def _chunk_cat( - tensors: List[Tensor], + tensors: list[Tensor], dim: int, num_chunks: int, out: Optional[Tensor] = None, @@ -1361,10 +1362,10 @@ def _chunk_cat( ) def split_with_sizes_copy( self: Tensor, - split_sizes: List[int], + split_sizes: list[int], dim: int = 0, - out: Optional[List[Tensor]] = None, -) -> Optional[List[Tensor]]: + out: Optional[list[Tensor]] = None, +) -> Optional[list[Tensor]]: splits = aten.split_with_sizes(self, split_sizes, dim=dim) if out is None: return [s.clone(memory_format=torch.contiguous_format) for s in splits] @@ -1376,19 +1377,19 @@ def split_with_sizes_copy( @register_decomposition(aten.unsafe_split.Tensor) -def unsafe_split(input: Tensor, split_size: int, dim: int = 0) -> Tuple[Tensor, ...]: +def unsafe_split(input: Tensor, split_size: int, dim: int = 0) -> tuple[Tensor, ...]: return aten.split.Tensor(input, split_size, dim) @register_decomposition(aten.unsafe_split_with_sizes.default) def unsafe_split_with_sizes( - input: Tensor, split_sizes: List[int], dim: int = 0 -) -> Tuple[Tensor, ...]: + input: Tensor, split_sizes: list[int], dim: int = 0 +) -> tuple[Tensor, ...]: return aten.split_with_sizes.default(input, split_sizes, dim) @register_decomposition(aten.split.Tensor) -def split(self: Tensor, split_size: int, dim: int = 0) -> Tuple[Tensor, ...]: +def split(self: Tensor, split_size: int, dim: int = 0) -> tuple[Tensor, ...]: input_sizes = self.shape dim_size = input_sizes[dim] if split_size == 0: @@ -1412,7 +1413,7 @@ def tensor_split_tensor_indices_or_sections_py_impl( self: Tensor, tensor_indices_or_sections: Tensor, dim: int = 0, -) -> Tuple[Tensor, ...]: +) -> tuple[Tensor, ...]: assert tensor_indices_or_sections.device.type == "cpu" assert tensor_indices_or_sections.dtype == torch.int64 split_dim = tensor_indices_or_sections.dim() @@ -1505,8 +1506,8 @@ def native_group_norm_backward( C: int, HxW: int, group: int, - output_mask: List[bool], -) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]: + output_mask: list[bool], +) -> tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]: utils.check_same_device( grad_output, input, mean, rstd, allow_cpu_scalar_tensors=False ) @@ -1593,12 +1594,12 @@ def native_group_norm_backward_out( C: int, HxW: int, group: int, - output_mask: List[bool], + output_mask: list[bool], *, out0: torch.Tensor, out1: torch.Tensor, out2: torch.Tensor, -) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]: +) -> tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]: result = native_group_norm_backward( grad_output, input, mean, rstd, gamma, N, C, HxW, group, output_mask ) @@ -1622,13 +1623,13 @@ def _maybe_cast(x: Optional[Tensor], dtype) -> Optional[Tensor]: def native_layer_norm_backward( grad_out: Tensor, 
input: Tensor, - normalized_shape: List[int], + normalized_shape: list[int], mean: Tensor, rstd: Tensor, weight: Optional[Tensor], bias: Optional[Tensor], - output_mask: List[bool], -) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]: + output_mask: list[bool], +) -> tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]: input_shape = input.shape input_ndim = input.dim() computation_dtype = utils.get_computation_dtype(input.dtype) @@ -1643,8 +1644,8 @@ def native_layer_norm_backward( axis = input_ndim - len(normalized_shape) inner_dims = input_shape[axis:] outer_dims = input_shape[:axis] - inner_dim_indices: List[int] = [] - outer_dim_indices: List[int] = [] + inner_dim_indices: list[int] = [] + outer_dim_indices: list[int] = [] for i in range(input_ndim): if i >= axis: inner_dim_indices.append(i) @@ -1705,17 +1706,17 @@ def native_layer_norm_backward( def native_layer_norm_backward_out( grad_out: Tensor, input: Tensor, - normalized_shape: List[int], + normalized_shape: list[int], mean: Tensor, rstd: Tensor, weight: Optional[Tensor], bias: Optional[Tensor], - output_mask: List[bool], + output_mask: list[bool], *, out0: torch.Tensor, out1: torch.Tensor, out2: torch.Tensor, -) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]: +) -> tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]: result = native_layer_norm_backward( grad_out, input, normalized_shape, mean, rstd, weight, bias, output_mask ) @@ -1738,7 +1739,7 @@ def native_batch_norm_helper( momentum: float, eps: float, functional: bool, -) -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: +) -> tuple[Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: reduction_dims = [0] + list(range(2, input.dim())) computation_dtype = utils.get_computation_dtype(input.dtype) new_running_mean = running_mean @@ -1821,7 +1822,7 @@ def native_batch_norm( training: bool, momentum: float, eps: float, -) -> Tuple[Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor]: output, save_mean, save_rstd, _, _ = native_batch_norm_helper( input, weight, bias, running_mean, running_var, training, momentum, eps, False ) @@ -1849,7 +1850,7 @@ def native_batch_norm_decomposition( training: bool, momentum: float, eps: float, -) -> Tuple[Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor]: if running_mean is None and running_var is None: return aten._native_batch_norm_legit( input, weight, bias, training, momentum, eps @@ -1876,7 +1877,7 @@ def native_batch_norm_decomposition( @aten.unsafe_chunk.default.py_impl(DispatchKey.CompositeImplicitAutograd) -def unsafe_chunk_py_impl(tensor, chunks, dim=0) -> List[Tensor]: +def unsafe_chunk_py_impl(tensor, chunks, dim=0) -> list[Tensor]: dim_size = tensor.size(dim) split_size = (dim_size + chunks - 1) // chunks @@ -1896,7 +1897,7 @@ def _native_batch_norm_legit_no_training( running_var: Tensor, momentum: float, eps: float, -) -> Tuple[Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor]: return aten._native_batch_norm_legit.default( input, weight, @@ -1919,7 +1920,7 @@ def _native_batch_norm_legit( training: bool, momentum: float, eps: float, -) -> Tuple[Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor]: output, save_mean, save_rstd, _, _ = native_batch_norm_helper( input, weight, bias, running_mean, running_var, training, momentum, eps, False ) @@ -1934,7 +1935,7 @@ def _native_batch_norm_legit_no_stats( training: bool, momentum: float, eps: float, -) -> Tuple[Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor]: 
output, save_mean, save_rstd, _, _ = native_batch_norm_helper( input, weight, bias, None, None, training, momentum, eps, False ) @@ -1951,7 +1952,7 @@ def _native_batch_norm_legit_functional( training: bool, momentum: float, eps: float, -) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: ( output, save_mean, @@ -2002,7 +2003,7 @@ def _batch_norm_with_update( running_var: Tensor, momentum: float, eps: float, -) -> Tuple[Tensor, Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor, Tensor]: output, save_mean, save_rstd, _, _ = native_batch_norm_helper( input, weight, @@ -2029,7 +2030,7 @@ def _batch_norm_with_update_functional( running_var: Tensor, momentum: float, eps: float, -) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: ( output, save_mean, @@ -2056,7 +2057,7 @@ def _batch_norm_no_update( running_var: Tensor, momentum: float, eps: float, -) -> Tuple[Tensor, Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor, Tensor]: output, save_mean, save_rstd, _, _ = native_batch_norm_helper( input, weight, @@ -2190,9 +2191,9 @@ def batch_norm_backward( save_invstd: Optional[Tensor], train: bool, eps: float, - output_mask: List[bool], + output_mask: list[bool], reserve: Tensor, -) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: +) -> tuple[Tensor, Optional[Tensor], Optional[Tensor]]: return native_batch_norm_backward( grad_out, input, @@ -2218,8 +2219,8 @@ def native_batch_norm_backward( save_invstd: Optional[Tensor], train: bool, eps: float, - output_mask: List[bool], -) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: + output_mask: list[bool], +) -> tuple[Tensor, Optional[Tensor], Optional[Tensor]]: input_dtype = input.dtype if weight is not None: weight_dtype = weight.dtype @@ -2262,10 +2263,10 @@ def native_batch_norm_backward( mean = running_mean_cast invstd = torch.rsqrt(running_var_cast + eps) - broadcast_mask: List[int] = [1] * input_rank + broadcast_mask: list[int] = [1] * input_rank broadcast_mask[axis] = input_shape[axis] - reduction_axes: List[int] = [] + reduction_axes: list[int] = [] for i in range(input_rank): if i != axis: reduction_axes.append(i) @@ -2320,12 +2321,12 @@ def native_batch_norm_backward_out( save_invstd: Optional[Tensor], train: bool, eps: float, - output_mask: List[bool], + output_mask: list[bool], *, out0: torch.Tensor, out1: torch.Tensor, out2: torch.Tensor, -) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: +) -> tuple[Tensor, Optional[Tensor], Optional[Tensor]]: result = native_batch_norm_backward( grad_out, input, @@ -2403,7 +2404,7 @@ def cudnn_batch_norm_backward( @register_decomposition(aten._adaptive_avg_pool2d) @out_wrapper() @pw_cast_for_opmath -def adaptive_avg_pool2d(input: Tensor, output_size: Tuple[int, int]): +def adaptive_avg_pool2d(input: Tensor, output_size: tuple[int, int]): # Preconditions device = input.device shape = input.shape @@ -2506,7 +2507,7 @@ def adaptive_avg_pool2d(input: Tensor, output_size: Tuple[int, int]): def _max_unpoolnd( - self: TensorLike, indices: TensorLike, output_size: List[int], dim: int + self: TensorLike, indices: TensorLike, output_size: list[int], dim: int ): # If the input tensors self and indices came from max_pool call as # required by the documentation, this operation is deterministic @@ -2534,7 +2535,7 @@ def _max_unpoolnd( def max_unpool2d( self: TensorLike, indices: TensorLike, - output_size: List[int], + output_size: list[int], ): torch._check( 
indices.dtype == torch.int64, @@ -2581,9 +2582,9 @@ def max_unpool2d( def max_unpool3d( input: TensorLike, indices: TensorLike, - output_size: List[int], - stride: List[int], - padding: List[int], + output_size: list[int], + stride: list[int], + padding: list[int], ): torch._check( indices.dtype == torch.int64, lambda: "elements in indices should be type int64" @@ -2761,7 +2762,7 @@ def _index_copy( @register_decomposition(aten.log_sigmoid_forward) @out_wrapper("output", "buffer") @pw_cast_for_opmath -def log_sigmoid_forward(self: Tensor) -> Tuple[Tensor, Tensor]: +def log_sigmoid_forward(self: Tensor) -> tuple[Tensor, Tensor]: min = torch.minimum(self.new_zeros(()), self) z = torch.exp(-torch.abs(self)) if self.is_cuda or self.is_xpu: @@ -2840,8 +2841,8 @@ def get_scale_value(scales, idx): @aten.upsample_nearest3d.vec.py_impl(DispatchKey.Autograd) def _upsample_nearest_vec( input: Tensor, - output_size: Optional[List[int]], - scale_factors: Optional[List[float]], + output_size: Optional[list[int]], + scale_factors: Optional[list[float]], ) -> Tensor: osize = upsample_compute_output_size(input.size(), output_size, scale_factors) scales = ( @@ -2861,8 +2862,8 @@ def _upsample_nearest_vec( @aten._upsample_nearest_exact3d.vec.py_impl(DispatchKey.Autograd) def _upsample_nearest_exact_vec( input: Tensor, - output_size: Optional[List[int]], - scale_factors: Optional[List[float]], + output_size: Optional[list[int]], + scale_factors: Optional[list[float]], ) -> Tensor: osize = upsample_compute_output_size(input.size(), output_size, scale_factors) scales = ( @@ -2909,7 +2910,7 @@ def _compute_upsample_nearest_indices(input, output_size, scales, exact=False): @out_wrapper(preserve_memory_format=True, exact_dtype=True) def upsample_nearest1d( input: Tensor, - output_size: List[int], + output_size: list[int], scales: Optional[float] = None, ) -> Tensor: return _upsample_nearest(input, output_size, [scales]) @@ -2923,7 +2924,7 @@ def upsample_nearest1d( @out_wrapper(preserve_memory_format=True, exact_dtype=True) def upsample_nearest_exact1d( input: Tensor, - output_size: List[int], + output_size: list[int], scales: Optional[float] = None, ) -> Tensor: return _upsample_nearest(input, output_size, [scales], exact=True) @@ -2935,7 +2936,7 @@ def upsample_nearest_exact1d( @out_wrapper(preserve_memory_format=True, exact_dtype=True) def upsample_nearest2d( input: Tensor, - output_size: List[int], + output_size: list[int], scales_h: Optional[float] = None, scales_w: Optional[float] = None, ) -> Tensor: @@ -2950,7 +2951,7 @@ def upsample_nearest2d( @out_wrapper(preserve_memory_format=True, exact_dtype=True) def _upsample_nearest_exact2d( input: Tensor, - output_size: List[int], + output_size: list[int], scales_h: Optional[float] = None, scales_w: Optional[float] = None, ) -> Tensor: @@ -2963,7 +2964,7 @@ def _upsample_nearest_exact2d( @out_wrapper(preserve_memory_format=True, exact_dtype=True) def upsample_nearest3d( input: Tensor, - output_size: List[int], + output_size: list[int], scales_d: Optional[float] = None, scales_h: Optional[float] = None, scales_w: Optional[float] = None, @@ -2979,7 +2980,7 @@ def upsample_nearest3d( @out_wrapper(preserve_memory_format=True, exact_dtype=True) def _upsample_nearest_exact3d( input: Tensor, - output_size: List[int], + output_size: list[int], scales_d: Optional[float] = None, scales_h: Optional[float] = None, scales_w: Optional[float] = None, @@ -2992,8 +2993,8 @@ def _upsample_nearest_exact3d( @pw_cast_for_opmath def _upsample_nearest( input: Tensor, - output_size: 
List[int], - scales: List[Optional[float]], + output_size: list[int], + scales: list[Optional[float]], exact: bool = False, ) -> Tensor: spatial_indices = _compute_upsample_nearest_indices( @@ -3072,7 +3073,7 @@ def one_layer_rnn_data( hh_bias = params[3] if has_biases else None step_output = [] - hiddens: List[torch.Tensor] = [] + hiddens: list[torch.Tensor] = [] last_batch_size = batch_sizes[-1] if reverse else batch_sizes[0] cur_hidden = hidden.narrow(0, 0, last_batch_size) @@ -3159,7 +3160,7 @@ def mkldnn_one_layer_lstm(inp, hidden, params, has_biases, reverse=False): hx = hidden[0].unsqueeze(0) cx = hidden[1].unsqueeze(0) - batch_sizes: List[int] = [] + batch_sizes: list[int] = [] mode = 2 # third_party/ideep/include/ideep/abstract_types.hpp: ideep::rnn_kind::LSTM = 2 hidden_size = hx.size(2) num_layers = 1 @@ -3710,7 +3711,7 @@ def _upsample_linear_vec(input, output_size, align_corners, scale_factors): @out_wrapper() def upsample_linear1d( input: Tensor, - output_size: List[int], + output_size: list[int], align_corners: bool, scales_w: Optional[float] = None, ) -> Tensor: @@ -3724,7 +3725,7 @@ def upsample_linear1d( @out_wrapper() def upsample_bilinear2d( input: Tensor, - output_size: List[int], + output_size: list[int], align_corners: bool, scales_h: Optional[float] = None, scales_w: Optional[float] = None, @@ -3738,7 +3739,7 @@ def upsample_bilinear2d( @out_wrapper() def upsample_trilinear3d( input: Tensor, - output_size: List[int], + output_size: list[int], align_corners: bool, scales_d: Optional[float] = None, scales_h: Optional[float] = None, @@ -3785,9 +3786,9 @@ def _compute_weight_precision(weights: TensorSequenceType) -> Tensor: @pw_cast_for_opmath def _upsample_linear( input: Tensor, - output_size: List[int], + output_size: list[int], align_corners: bool, - scales: List[Optional[float]], + scales: list[Optional[float]], ) -> Tensor: # get dimensions of original image n_channels = input.shape[1] @@ -3937,7 +3938,7 @@ def _nll_loss_forward( weight: Optional[Tensor], reduction: int, ignore_index: int, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: # self can be [N, C] or [C] # target can be [N] or [] @@ -3992,7 +3993,7 @@ def nll_loss_forward( weight: Optional[Tensor], reduction: int, ignore_index: int, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: assert self.dim() > 0 and self.dim() <= 2, "input tensor should be 1D or 2D" assert ( target.dim() <= 1 @@ -4020,7 +4021,7 @@ def nll_loss2d_forward( weight: Optional[Tensor], reduction: int, ignore_index: int, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: return _nll_loss_forward(self, target, weight, reduction, ignore_index) @@ -4108,7 +4109,7 @@ def _make_base_grid_5d(theta: Tensor, d: int, h: int, w: int, align_corners: boo return grid_x + grid_y + grid_z + grid_one -def _affine_grid_generator_4d(theta: Tensor, size: List[int], align_corners: bool): +def _affine_grid_generator_4d(theta: Tensor, size: list[int], align_corners: bool): n, _, h, w = size base_grid = _make_base_grid_4d(theta, h, w, align_corners=align_corners) # base_grid shape is (h, w, 3) and theta shape is (n, 2, 3) @@ -4118,7 +4119,7 @@ def _affine_grid_generator_4d(theta: Tensor, size: List[int], align_corners: boo return grid.view(n, h, w, 2) -def _affine_grid_generator_5d(theta: Tensor, size: List[int], align_corners: bool): +def _affine_grid_generator_5d(theta: Tensor, size: list[int], align_corners: bool): n, _, d, h, w = size base_grid = _make_base_grid_5d(theta, d, h, w, align_corners=align_corners) # base_grid shape is 
(d, h, w, 4) and theta shape is (n, 3, 4) @@ -4131,7 +4132,7 @@ def _affine_grid_generator_5d(theta: Tensor, size: List[int], align_corners: boo @register_decomposition(aten.affine_grid_generator) @out_wrapper() @pw_cast_for_opmath -def affine_grid_generator(theta: Tensor, size: List[int], align_corners: bool): +def affine_grid_generator(theta: Tensor, size: list[int], align_corners: bool): torch._check( len(size) in (4, 5), lambda: "affine_grid_generator needs 4d (spatial) or 5d (volumetric) inputs.", @@ -4454,7 +4455,7 @@ def matmul(tensor1, tensor2, *, is_out=False): m2 = tensor2.size(-2) if dim_tensor2 > 1 else tensor2.size(-1) p = tensor2.size(-1) if dim_tensor2 > 1 else 1 - batch_tensor2: List[int] = [] + batch_tensor2: list[int] = [] # TODO: handling of slice for i in range(dim_tensor2 - 2): batch_tensor2.append(tensor2.size(i)) @@ -4520,7 +4521,7 @@ def matmul(tensor1, tensor2, *, is_out=False): @pw_cast_for_opmath def upsample_bicubic2d_default( input: Tensor, - output_size: Tuple[int, int], + output_size: tuple[int, int], align_corners: bool, scale_h: Optional[float] = None, scale_w: Optional[float] = None, @@ -4608,9 +4609,9 @@ def upsample_bicubic2d_default( @pw_cast_for_opmath def upsample_bicubic2d_vec( a: Tensor, - output_size: Optional[Tuple[int, int]], + output_size: Optional[tuple[int, int]], align_corners: bool, - scale_factors: Optional[Tuple[float, float]] = None, + scale_factors: Optional[tuple[float, float]] = None, ) -> Tensor: torch._check( bool(output_size) + bool(scale_factors) == 1, @@ -4619,7 +4620,7 @@ def upsample_bicubic2d_vec( if output_size is None: assert scale_factors is not None output_size = cast( - Tuple[int, int], + tuple[int, int], tuple( sym_int(sym_float(w) * scale) for w, scale in zip(a.shape[2:], scale_factors) @@ -4634,7 +4635,7 @@ def upsample_bicubic2d_vec( @register_decomposition(aten.reflection_pad3d) @pw_cast_for_opmath @out_wrapper() -def _reflection_pad(a: Tensor, padding: Tuple[int, ...]) -> Tensor: +def _reflection_pad(a: Tensor, padding: tuple[int, ...]) -> Tensor: def idx(left, middle, right): dim_idx = torch.arange(-left, middle + right, device=a.device) return middle - 1 - (middle - 1 - dim_idx.abs()).abs() @@ -4651,7 +4652,7 @@ def _reflection_pad(a: Tensor, padding: Tuple[int, ...]) -> Tensor: @register_decomposition(aten.replication_pad3d) @pw_cast_for_opmath @out_wrapper() -def _replication_pad(a: Tensor, padding: Tuple[int, ...]) -> Tensor: +def _replication_pad(a: Tensor, padding: tuple[int, ...]) -> Tensor: def idx(left, middle, right): dim_idx = torch.arange(-left, middle + right, device=a.device) return torch.clamp(dim_idx, 0, middle - 1) @@ -4665,7 +4666,7 @@ def _replication_pad(a: Tensor, padding: Tuple[int, ...]) -> Tensor: def _reflection_or_replication_pad( a: Tensor, - padding: Tuple[int, ...], + padding: tuple[int, ...], idx_fn: Callable[[int, int, int], Tensor], ) -> Tensor: dim = len(padding) // 2 @@ -4681,7 +4682,7 @@ def _reflection_or_replication_pad( result = a for i in range(dim): - idx: List[Any] = [None] * result.dim() + idx: list[Any] = [None] * result.dim() idx[i + nc_dim] = idx_fn(padding_left[i], inp_shape[i], padding_right[i]) result = aten._unsafe_index(result, idx) @@ -4887,7 +4888,7 @@ def multilabel_margin_loss_forward( input: Tensor, target: Tensor, reduction: int, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: orig_input_shape = input.shape orig_target_shape = target.shape input = torch.atleast_2d(input) @@ -4953,7 +4954,7 @@ def scaled_dot_product_flash_attention_for_cpu( *, 
attn_mask: Optional[Tensor] = None, scale: Optional[float] = None, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: torch._check( torch.is_floating_point(query), lambda: f"query must be FP32, FP64, BF16, FP16 but got {query.dtype}", diff --git a/torch/_decomp/decompositions_for_jvp.py b/torch/_decomp/decompositions_for_jvp.py index d47f91d4c888..60a19f320059 100644 --- a/torch/_decomp/decompositions_for_jvp.py +++ b/torch/_decomp/decompositions_for_jvp.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-decorators # mypy: allow-untyped-defs import inspect -from typing import Callable, Dict, List, Optional, Tuple +from typing import Callable, Optional import torch import torch._decomp @@ -10,7 +10,7 @@ from torch._prims_common.wrappers import _maybe_remove_out_wrapper decomposition_table = torch._decomp.decomposition_table -decomposition_table_for_jvp: Dict[torch._ops.OperatorBase, Callable] = {} +decomposition_table_for_jvp: dict[torch._ops.OperatorBase, Callable] = {} register_decomposition = torch._decomp.register_decomposition aten = torch.ops.aten @@ -104,7 +104,7 @@ def trace(self: Tensor) -> Tensor: @maybe_register_decomposition(aten.log_sigmoid_forward.default) -def log_sigmoid_forward(self: Tensor) -> Tuple[Tensor, Tensor]: +def log_sigmoid_forward(self: Tensor) -> tuple[Tensor, Tensor]: min = torch.minimum(self.new_zeros(()), self) z = torch.exp(-torch.abs(self)) if self.is_cuda or self.is_xpu: @@ -115,7 +115,7 @@ def log_sigmoid_forward(self: Tensor) -> Tuple[Tensor, Tensor]: def recompute_mean_var( - input: Tensor, rstd: Tensor, inner_dim_indices: List[int], keepdim: bool + input: Tensor, rstd: Tensor, inner_dim_indices: list[int], keepdim: bool ): # for most norm decompositions, it will be the same as the core version except for here. # We recompute the mean and variance so that they track gradients through input @@ -132,13 +132,13 @@ def recompute_mean_var( def native_layer_norm_backward( grad_out: Tensor, input: Tensor, - normalized_shape: List[int], + normalized_shape: list[int], mean: Tensor, rstd: Tensor, weight: Optional[Tensor], bias: Optional[Tensor], - output_mask: List[bool], -) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]: + output_mask: list[bool], +) -> tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]: input_shape = input.shape input_ndim = input.dim() @@ -205,7 +205,7 @@ def native_layer_norm_backward( return (d_input, d_weight, d_bias) -def prod(x: List[int]): +def prod(x: list[int]): r = 1 for i in x: r *= i @@ -223,8 +223,8 @@ def native_batch_norm_backward( save_invstd: Optional[Tensor], train: bool, eps: float, - output_mask: List[bool], -) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: + output_mask: list[bool], +) -> tuple[Tensor, Optional[Tensor], Optional[Tensor]]: input_shape = input.shape input_rank = input.dim() assert input_rank >= 2, "rank of the input must be at least 2" @@ -251,7 +251,7 @@ def native_batch_norm_backward( broadcast_mask = [1] * input_rank broadcast_mask[axis] = input_shape[axis] - reduction_axes: List[int] = [] + reduction_axes: list[int] = [] for i in range(input_rank): if i != axis: reduction_axes.append(i) @@ -305,9 +305,9 @@ def batch_norm_backward( save_var: Optional[Tensor], update: bool, eps: float, - output_mask: List[bool], + output_mask: list[bool], reserve: Tensor, -) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: +) -> tuple[Tensor, Optional[Tensor], Optional[Tensor]]: return native_batch_norm_backward( grad_out, input, diff --git a/torch/_decomp/decompositions_for_rng.py 
b/torch/_decomp/decompositions_for_rng.py index a62a28f783b7..256045498cbf 100644 --- a/torch/_decomp/decompositions_for_rng.py +++ b/torch/_decomp/decompositions_for_rng.py @@ -2,7 +2,7 @@ # mypy: allow-untyped-defs import functools from collections import defaultdict -from typing import Callable, Dict +from typing import Callable import torch import torch._decomp as decomp @@ -12,7 +12,7 @@ from torch._ops import OpOverload aten = torch.ops.aten -rng_decompositions: Dict[str, Dict[OpOverload, Callable]] = defaultdict(dict) +rng_decompositions: dict[str, dict[OpOverload, Callable]] = defaultdict(dict) def register_rng_decomposition(aten_op): diff --git a/torch/_lazy/device_context.py b/torch/_lazy/device_context.py index eef60961e65b..49f33cf7f7c6 100644 --- a/torch/_lazy/device_context.py +++ b/torch/_lazy/device_context.py @@ -1,11 +1,11 @@ import threading -from typing import Any, Dict, Optional +from typing import Any, Optional import torch._C._lazy class DeviceContext: - _CONTEXTS: Dict[str, Any] = {} + _CONTEXTS: dict[str, Any] = {} _CONTEXTS_LOCK = threading.Lock() def __init__(self, device: str) -> None: diff --git a/torch/_lazy/extract_compiled_graph.py b/torch/_lazy/extract_compiled_graph.py index f46eea4eee9b..d014c272490b 100644 --- a/torch/_lazy/extract_compiled_graph.py +++ b/torch/_lazy/extract_compiled_graph.py @@ -3,7 +3,7 @@ import copy import dataclasses import itertools import os -from typing import Any, Callable, Dict, List +from typing import Any, Callable import torch import torch._lazy as lazy @@ -28,14 +28,14 @@ class GraphInputMatcher: TS/XLA graph inputs. """ - tensor_id_to_arg_idx: Dict[int, int] - graph_input_tensor_ids: List[int] + tensor_id_to_arg_idx: dict[int, int] + graph_input_tensor_ids: list[int] # there are 2 categories of graph_input_tensors. # Category 1: those whose id are not found in tensor_id_to_arg_idx. These are # most likely const tensors and we can get its content from graph_input_tensors # Category 2: those whose id are found in tensor_id_to_arg_idx. 
We should get # the tensor from method arguments - graph_input_ivalues: List[Any] + graph_input_ivalues: list[Any] # get the real graph input tensors def __call__(self, args): @@ -71,10 +71,10 @@ class ReturnValueHandler: """ def __init__(self, lazy_out_list): - self.index: List[List[int]] = [] + self.index: list[list[int]] = [] self.total_count = len(lazy_out_list) - tensor_id_to_idx: Dict[int, int] = {} + tensor_id_to_idx: dict[int, int] = {} for dup_idx, lazy_tensor in enumerate(lazy_out_list): uniq_idx = tensor_id_to_idx.get(id(lazy_tensor), None) if uniq_idx is not None: diff --git a/torch/_library/autograd.py b/torch/_library/autograd.py index 75997ec63eb1..5c8c713b6e42 100644 --- a/torch/_library/autograd.py +++ b/torch/_library/autograd.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs import dataclasses from dataclasses import dataclass -from typing import Any, Callable, Dict, Optional, Protocol +from typing import Any, Callable, Optional, Protocol from torch import _C, _ops, autograd, Tensor from torch.utils import _pytree @@ -28,7 +28,7 @@ def make_autograd_impl(op: _ops.OpOverload, info: InfoProtocol) -> Callable: @dataclass class Metadata: keyset: _C.DispatchKeySet - keyword_only_args: Dict[str, Any] + keyword_only_args: dict[str, Any] def forward_no_grad(*args): metadata = args[-1] diff --git a/torch/_library/custom_ops.py b/torch/_library/custom_ops.py index 2108025a0b0c..3bcd3df98e83 100644 --- a/torch/_library/custom_ops.py +++ b/torch/_library/custom_ops.py @@ -2,20 +2,9 @@ import inspect import logging import weakref +from collections.abc import Iterable, Sequence from contextlib import contextmanager -from typing import ( - Any, - Callable, - Dict, - Iterable, - List, - Literal, - Optional, - overload, - Sequence, - Set, - Union, -) +from typing import Any, Callable, Literal, Optional, overload, Union import torch from torch import _C, _ops, Tensor @@ -200,16 +189,16 @@ class CustomOpDef: self._init_fn = fn - self._backend_fns: Dict[Union[str, None], Callable] = {} + self._backend_fns: dict[Union[str, None], Callable] = {} self._abstract_fn: Optional[Callable] = None self._setup_context_fn: Optional[Callable] = None self._backward_fn: Optional[Callable] = None - self._torch_dispatch_fns: Dict[type, Callable] = {} + self._torch_dispatch_fns: dict[type, Callable] = {} self._vmap_fn: Optional[Callable] = None self._lib = get_library_allowing_overwrite(self._namespace, self._name) self._register_to_dispatcher() - self._disabled_kernel: Set = set() + self._disabled_kernel: set = set() OPDEFS[self._qualname] = self @property @@ -332,7 +321,7 @@ class CustomOpDef: def inner(fn): if device_types is None or isinstance(device_types, str): - dtypes: List[Union[str, None]] = [device_types] + dtypes: list[Union[str, None]] = [device_types] else: dtypes = list(device_types) for device_type in dtypes: @@ -807,7 +796,7 @@ def increment_version(val: Any) -> None: # decorator. 
-OPDEF_TO_LIB: Dict[str, "torch.library.Library"] = {} +OPDEF_TO_LIB: dict[str, "torch.library.Library"] = {} OPDEFS: weakref.WeakValueDictionary = weakref.WeakValueDictionary() diff --git a/torch/_library/fake_class_registry.py b/torch/_library/fake_class_registry.py index d1614449b42d..3009e6d4ea42 100644 --- a/torch/_library/fake_class_registry.py +++ b/torch/_library/fake_class_registry.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs import copy import logging -from typing import Any, Dict, Optional, Protocol, Tuple, Union +from typing import Any, Optional, Protocol, Union import torch from torch._library.utils import parse_namespace @@ -56,7 +56,7 @@ class HasStaticMethodFromReal(Protocol): class FakeClassRegistry: def __init__(self) -> None: - self._registered_class: Dict[str, Any] = {} + self._registered_class: dict[str, Any] = {} def has_impl(self, full_qualname: str) -> bool: return full_qualname in self._registered_class @@ -290,7 +290,7 @@ def _full_qual_class_name(qualname: str) -> str: # Return the namespace and class name from fully qualified name. -def _ns_and_class_name(full_qualname: str) -> Tuple[str, str]: +def _ns_and_class_name(full_qualname: str) -> tuple[str, str]: splits = full_qualname.split(".") assert len(splits) == 5, f"Could not split {full_qualname=}" _torch, _torch_ns, _classes, ns, class_name = splits diff --git a/torch/_library/triton.py b/torch/_library/triton.py index bc68c285311d..dc0db8a07bc0 100644 --- a/torch/_library/triton.py +++ b/torch/_library/triton.py @@ -1,6 +1,7 @@ import contextlib import threading -from typing import Any, Callable, Generator, Iterable, Optional, Union +from collections.abc import Generator, Iterable +from typing import Any, Callable, Optional, Union from torch.utils._exposed_in import exposed_in diff --git a/torch/_library/utils.py b/torch/_library/utils.py index 45c8208d9e30..8348883cee30 100644 --- a/torch/_library/utils.py +++ b/torch/_library/utils.py @@ -3,7 +3,8 @@ import dataclasses import inspect import sys import warnings -from typing import Any, Callable, Dict, Iterable, Iterator, List, Tuple, Union +from collections.abc import Iterable, Iterator +from typing import Any, Callable, Union import torch import torch.utils._pytree as pytree @@ -55,7 +56,7 @@ def get_source(stacklevel: int) -> str: return source -def parse_namespace(qualname: str) -> Tuple[str, str]: +def parse_namespace(qualname: str) -> tuple[str, str]: splits = qualname.split("::") if len(splits) != 2: raise ValueError( @@ -189,8 +190,8 @@ def fill_defaults(schema, args, kwargs): def zip_schema( - schema: _C.FunctionSchema, args: Tuple[Any, ...], kwargs: Dict[str, Any] -) -> Iterable[Tuple[_C.Argument, Any]]: + schema: _C.FunctionSchema, args: tuple[Any, ...], kwargs: dict[str, Any] +) -> Iterable[tuple[_C.Argument, Any]]: """zips schema.arguments and (args, kwargs) together. 
Assumes that (args, kwargs) were the inputs to some torch._ops.OpOverload: @@ -332,7 +333,7 @@ def get_device_arg_index(schema: _C.FunctionSchema) -> Union[int, None]: def iter_tensors( - args: Tuple[Any], kwargs: Dict[str, Any], allowed_nesting: int = 1 + args: tuple[Any], kwargs: dict[str, Any], allowed_nesting: int = 1 ) -> Iterator[torch.Tensor]: def check(arg): if isinstance(arg, torch.Tensor): @@ -465,7 +466,7 @@ def has_fake_kernel(op: torch._ops.OpOverload) -> bool: return False -def mutated_args_kwargs(schema: _C.FunctionSchema) -> Tuple[List[int], List[str]]: +def mutated_args_kwargs(schema: _C.FunctionSchema) -> tuple[list[int], list[str]]: idxs = [] keys = [] for i, info in enumerate(schema.arguments): diff --git a/torch/_numpy/_funcs_impl.py b/torch/_numpy/_funcs_impl.py index 94d593684a19..3579cfe83b42 100644 --- a/torch/_numpy/_funcs_impl.py +++ b/torch/_numpy/_funcs_impl.py @@ -12,7 +12,7 @@ from __future__ import annotations import builtins import itertools import operator -from typing import Optional, Sequence, TYPE_CHECKING +from typing import Optional, TYPE_CHECKING import torch @@ -20,6 +20,8 @@ from . import _dtypes_impl, _util if TYPE_CHECKING: + from collections.abc import Sequence + from ._normalizations import ( ArrayLike, ArrayLikeOrScalar, diff --git a/torch/_numpy/_ndarray.py b/torch/_numpy/_ndarray.py index 73d60b24b5c3..20ebd9db8182 100644 --- a/torch/_numpy/_ndarray.py +++ b/torch/_numpy/_ndarray.py @@ -5,7 +5,7 @@ from __future__ import annotations import builtins import math import operator -from typing import Sequence +from collections.abc import Sequence import torch diff --git a/torch/_numpy/linalg.py b/torch/_numpy/linalg.py index 093851142dbc..4ea3b46f23e6 100644 --- a/torch/_numpy/linalg.py +++ b/torch/_numpy/linalg.py @@ -4,7 +4,7 @@ from __future__ import annotations import functools import math -from typing import Sequence +from typing import TYPE_CHECKING import torch @@ -12,6 +12,10 @@ from . import _dtypes_impl, _util from ._normalizations import ArrayLike, KeepDims, normalizer +if TYPE_CHECKING: + from collections.abc import Sequence + + class LinAlgError(Exception): pass diff --git a/torch/_prims/__init__.py b/torch/_prims/__init__.py index 9e07a977658e..172e5728d034 100644 --- a/torch/_prims/__init__.py +++ b/torch/_prims/__init__.py @@ -1,8 +1,9 @@ # mypy: allow-untyped-defs import operator +from collections.abc import Sequence from enum import Enum from functools import partial, reduce -from typing import Callable, List, Optional, Sequence, Tuple, Type, Union +from typing import Callable, List, Optional, Tuple, Type, Union import torch import torch._prims_common as utils @@ -231,8 +232,8 @@ def TensorMeta( if isinstance(tensorlike, Number): assert not shape and (shape is None or isinstance(shape, Sequence)) assert not strides and (strides is None or isinstance(strides, Sequence)) - inferred_shape: Tuple[int, ...] = () - inferred_strides: Tuple[int, ...] = () + inferred_shape: tuple[int, ...] = () + inferred_strides: tuple[int, ...] 
= () inferred_dtype = type_to_dtype(type(tensorlike)) inferred_device = torch.device("cpu") # TODO: This looks wrong, a number that is wrapped into a tensor @@ -266,7 +267,7 @@ def TensorMeta( def _make_prim( *, schema: str, - return_type: Union[RETURN_TYPE, Tuple[RETURN_TYPE, ...]], + return_type: Union[RETURN_TYPE, tuple[RETURN_TYPE, ...]], meta: Callable, impl_aten: Callable, doc: str, @@ -383,7 +384,7 @@ class ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND(Enum): def _prim_elementwise_meta( *args, type_promotion: ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND, - args_with_fixed_dtypes: Optional[Tuple[TensorLikeType, ...]] = None, + args_with_fixed_dtypes: Optional[tuple[TensorLikeType, ...]] = None, ) -> FakeTensor: """ Meta function for elementwise operations that produce outputs in the same dtype @@ -1358,7 +1359,7 @@ def _validate_collapse_args(a: Tensor, start: int, end: int) -> None: ) -def _collapsed_shape(shape: ShapeType, start: int, end: int) -> Tuple[int, ...]: +def _collapsed_shape(shape: ShapeType, start: int, end: int) -> tuple[int, ...]: """ Returns the shape of a with dims in [start, end) merged into a single dimension. """ @@ -1374,7 +1375,7 @@ def _collapsed_shape(shape: ShapeType, start: int, end: int) -> Tuple[int, ...]: def _collapse_view_helper( a: TensorLikeType, start: int, end: int -) -> Tuple[Optional[ShapeType], Optional[StrideType]]: +) -> tuple[Optional[ShapeType], Optional[StrideType]]: assert isinstance(a, TensorLike) from torch.fx.experimental.symbolic_shapes import guard_size_oblivious @@ -1534,8 +1535,8 @@ def _split_dim_meta(a: TensorLikeType, dim: int, outer_length: int) -> TensorLik ) raise ValueError(msg) - new_shape: List[int] = [] - new_strides: List[int] = [] + new_shape: list[int] = [] + new_strides: list[int] = [] for idx in range(a.ndim): if idx == dim: new_shape.extend((outer_length, inner_length)) @@ -1797,7 +1798,7 @@ def _cat_meta(tensors: Sequence[TensorLikeType], dim: int) -> TensorLikeType: ) -def _cat_aten(tensors: Union[Tuple[Tensor, ...], List[Tensor]], dim: int) -> Tensor: +def _cat_aten(tensors: Union[tuple[Tensor, ...], list[Tensor]], dim: int) -> Tensor: return torch.cat(tensors, dim) @@ -2609,7 +2610,7 @@ scalar_tensor = _make_prim( def _svd_meta( A: TensorLikeType, *, full_matrices: bool -) -> Tuple[TensorLikeType, TensorLikeType, TensorLikeType]: +) -> tuple[TensorLikeType, TensorLikeType, TensorLikeType]: utils.check_is_matrix(A, "linalg.svd") utils.check_fp_or_complex(A.dtype, "linalg.svd", allow_low_precision_dtypes=False) @@ -2646,7 +2647,7 @@ def _svd_meta( def _svd_aten( A: TensorLikeType, *, full_matrices: bool -) -> Tuple[Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor]: return torch.linalg.svd(A, full_matrices=full_matrices) @@ -2899,7 +2900,7 @@ fft_c2r = _make_prim( ) -def _frexp_meta(self: TensorLikeType) -> Tuple[TensorLikeType, TensorLikeType]: +def _frexp_meta(self: TensorLikeType) -> tuple[TensorLikeType, TensorLikeType]: torch._check( self.dtype.is_floating_point, lambda: "torch.frexp() only supports floating-point dtypes", diff --git a/torch/_prims/context.py b/torch/_prims/context.py index 475692469136..36cb40e79165 100644 --- a/torch/_prims/context.py +++ b/torch/_prims/context.py @@ -1,7 +1,8 @@ # mypy: allow-untyped-defs import functools +from collections.abc import Sequence from contextlib import nullcontext -from typing import Any, Callable, Dict, Optional, Sequence +from typing import Any, Callable, Optional import torch import torch._decomp @@ -28,7 +29,7 @@ def torch_to_refs_map(): (torch.fft, 
torch._refs.fft), (torch.linalg, torch._refs.linalg), ] - r: Dict[Any, Any] = { + r: dict[Any, Any] = { torch.Tensor.__invert__: torch._refs.bitwise_not, torch.Tensor.__xor__: torch._refs.bitwise_xor, torch.Tensor.__and__: torch._refs.bitwise_and, @@ -107,7 +108,7 @@ class TorchRefsMode(torch.overrides.TorchFunctionMode): orig_func: Callable, types: Sequence, args: Sequence[Any] = (), - kwargs: Optional[Dict] = None, + kwargs: Optional[dict] = None, ): if kwargs is None: kwargs = {} diff --git a/torch/_prims/rng_prims.py b/torch/_prims/rng_prims.py index bbbdb8958f9a..d4d9203ef6ab 100644 --- a/torch/_prims/rng_prims.py +++ b/torch/_prims/rng_prims.py @@ -1,5 +1,5 @@ # mypy: allow-untyped-defs -from typing import Optional, Tuple +from typing import Optional import torch import torch.utils._pytree as pytree @@ -85,7 +85,7 @@ def register_philox_rand(): shape: torch.Size, seed: torch.Tensor, offset: torch.Tensor, - stride: Optional[Tuple[int, ...]], + stride: Optional[tuple[int, ...]], device: _device, dtype: _dtype, ): @@ -102,7 +102,7 @@ def register_philox_rand(): shape: torch.Size, seed: torch.Tensor, offset: torch.Tensor, - stride: Optional[Tuple[int, ...]], + stride: Optional[tuple[int, ...]], device: _device, dtype: _dtype, ): diff --git a/torch/_prims_common/wrappers.py b/torch/_prims_common/wrappers.py index c618545e02be..3ee9e5dcd50f 100644 --- a/torch/_prims_common/wrappers.py +++ b/torch/_prims_common/wrappers.py @@ -2,6 +2,7 @@ import inspect import warnings from functools import wraps +from types import GenericAlias from typing import ( Callable, List, @@ -260,6 +261,15 @@ def out_wrapper( Adds the out parameter to a Python reference. """ out_type = ( + TensorLikeType + if is_tensor + else GenericAlias( + tuple, tuple(TensorLikeType for _ in range(len(out_names))) + ) + ) + # For backward compatibility - should be able to remove once PEP585 + # conversion is complete. 
+    bc_out_type = (
+        TensorLikeType
+        if is_tensor
+        else Tuple[tuple(TensorLikeType for _ in range(len(out_names)))]
@@ -301,12 +311,12 @@ def out_wrapper(
             assert (
                 (isinstance(result, TensorLike) and is_tensor)
                 or (
-                    isinstance(result, Tuple)  # type: ignore[arg-type]
+                    isinstance(result, tuple)  # type: ignore[arg-type]
                     and len(result) == len(out_names)  # type: ignore[arg-type]
                 )
                 or (
                     fn.__name__ == "unbind"
-                    and isinstance(result, (List, Tuple))  # type: ignore[arg-type]
+                    and isinstance(result, (List, tuple))  # type: ignore[arg-type]
                 )
             )
             # unbind_copy is a special case: see https://github.com/pytorch/pytorch/issues/130829
@@ -336,9 +346,9 @@ def out_wrapper(
                     _safe_copy_out(copy_from=result, copy_to=out, exact_dtype=exact_dtype)  # type: ignore[arg-type]
                 else:
                     if fn.__name__ != "unbind":
-                        assert isinstance(out, Tuple)  # type: ignore[arg-type]
+                        assert isinstance(out, tuple)  # type: ignore[arg-type]
                     else:
-                        assert isinstance(out, (List, Tuple))  # type: ignore[arg-type]
+                        assert isinstance(out, (list, tuple))  # type: ignore[arg-type]
                     torch._check_type(
                         len(out) == len(result),  # type: ignore[arg-type]
                         lambda: f"expected tuple of {len(result)} elements but got {len(out)}",  # type: ignore[arg-type]
@@ -362,6 +372,7 @@ def out_wrapper(
         assert isinstance(sig.return_annotation, str) or sig.return_annotation in (
             sig.empty,
             out_type,
+            bc_out_type,
         )

         params = *sig.parameters.values(), out_param
diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py
index 14e28d4d4194..14221f0bde0f 100644
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -7,21 +7,10 @@ import itertools
 import math
 import operator
 import warnings
-from collections.abc import Iterable
+from collections.abc import Iterable, Sequence
 from enum import Enum
 from functools import partial, reduce, singledispatch, wraps
-from typing import (
-    Any,
-    Callable,
-    cast,
-    Dict,
-    List,
-    Optional,
-    overload,
-    Sequence,
-    Tuple,
-    Union,
-)
+from typing import Any, Callable, cast, Dict, List, Optional, overload, Tuple, Union

 import torch
 import torch._prims as prims
@@ -411,7 +400,7 @@ def _broadcast_shapes(*_shapes):
         assert isinstance(shape, Sequence)

     # Computes common shape
-    common_shape: List[Union[int, torch.SymInt]] = [
+    common_shape: list[Union[int, torch.SymInt]] = [
         1,
     ] * reduce(max, (len(shape) for shape in shapes))
     for arg_idx, shape in enumerate(shapes):
@@ -1421,7 +1410,7 @@ def fmod(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:

 @register_decomposition(aten.frexp)
 @out_wrapper("mantissa", "exponent")
-def frexp(self: TensorLikeType) -> Tuple[TensorLikeType, TensorLikeType]:
+def frexp(self: TensorLikeType) -> tuple[TensorLikeType, TensorLikeType]:
     return torch.return_types.frexp(prims.frexp(self))


@@ -2052,7 +2041,7 @@ def _to_device(
     non_blocking: bool = False,
     copy: bool = False,
     memory_format: Optional[torch.memory_format] = None,
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
     kwargs = {
         "device": device,
         "dtype": dtype,
@@ -2070,7 +2059,7 @@ def _to_device_str(
     non_blocking: bool = False,
     copy: bool = False,
     memory_format: Optional[torch.memory_format] = None,
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
     kwargs = {
         "device": torch.device(device),
         "dtype": dtype,
@@ -2087,7 +2076,7 @@ def _to_dtype(
     non_blocking: bool = False,
     copy: bool = False,
     memory_format: Optional[torch.memory_format] = None,
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
     kwargs = {
         "dtype": dtype,
         "non_blocking": non_blocking,
@@ -2103,7 +2092,7 @@ def _to_other(
     non_blocking: bool = False,
     copy: bool = False,
     memory_format: Optional[torch.memory_format] = None,
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
     device = other.device
     dtype = other.dtype
     layout = other.layout
@@ -2311,7 +2300,7 @@ def any(
 @register_decomposition([aten.sum.dim_IntList, aten.sum.IntList_out])
 def sum(
     a: TensorLikeType,
-    dim: Union[Optional[int], Optional[List[int]]] = None,
+    dim: Union[Optional[int], Optional[list[int]]] = None,
     keepdim: bool = False,
     *,
     dtype: Optional[torch.dtype] = None,
@@ -2363,7 +2352,7 @@ def sum_to_size(
 @register_decomposition(aten.prod)
 def prod(
     a: TensorLikeType,
-    dim: Union[Optional[int], Optional[List[int]]] = None,
+    dim: Union[Optional[int], Optional[list[int]]] = None,
     keepdim: bool = False,
     *,
     dtype=None,
@@ -2481,7 +2470,7 @@ def var(
 @out_wrapper()
 def std(
     a: TensorLikeType,
-    dim: Union[Optional[int], Optional[List[int]]] = None,
+    dim: Union[Optional[int], Optional[list[int]]] = None,
     unbiased: Optional[bool] = None,
     keepdim: bool = False,
     *,
@@ -2660,7 +2649,7 @@ def addr(
 # CompositeImplicitAutograd - don't register decomp
 def atleast_1d(
     arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType
-) -> Union[TensorLikeType, Tuple[TensorLikeType, ...]]:
+) -> Union[TensorLikeType, tuple[TensorLikeType, ...]]:
     """Reference implementation of :func:`torch.atleast_1d`."""
     if not args and isinstance(arg, collections.abc.Sequence):
         args_ = arg
@@ -2684,7 +2673,7 @@ def _unsqueeze_atleast(
 # CompositeImplicitAutograd - don't register decomp
 def atleast_2d(
     arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType
-) -> Union[TensorLikeType, Tuple[TensorLikeType, ...]]:
+) -> Union[TensorLikeType, tuple[TensorLikeType, ...]]:
     """Reference implementation of :func:`torch.atleast_2d`."""
     if not args and isinstance(arg, collections.abc.Sequence):
         args_ = arg
@@ -2699,7 +2688,7 @@ def atleast_2d(
 # CompositeImplicitAutograd - don't register decomp
 def atleast_3d(
     arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType
-) -> Union[TensorLikeType, Tuple[TensorLikeType, ...]]:
+) -> Union[TensorLikeType, tuple[TensorLikeType, ...]]:
     """Reference implementation of :func:`torch.atleast_3d`."""
     if not args and isinstance(arg, collections.abc.Sequence):
         args_ = arg
@@ -2742,7 +2731,7 @@ def broadcast_shapes(*shapes) -> ShapeType:

 @aten.broadcast_tensors.default.py_impl(DispatchKey.CompositeImplicitAutograd)
 @aten.broadcast_tensors.default.py_impl(DispatchKey.Meta)
-def broadcast_tensors(*tensors) -> List[TensorLikeType]:
+def broadcast_tensors(*tensors) -> list[TensorLikeType]:
     if len(tensors) == 1 and not isinstance(tensors[0], Tensor):
         tensors = tensors[0]
     return list(_maybe_broadcast(*tensors, preserve_cpu_scalar_tensors=False))
@@ -2900,7 +2889,7 @@ def conj(input: TensorLikeType) -> TensorLikeType:
 @register_decomposition(aten.constant_pad_nd)
 @out_wrapper()
 def constant_pad_nd(
-    input: TensorLikeType, pad: List[int], value: NumberType = 0
+    input: TensorLikeType, pad: list[int], value: NumberType = 0
 ) -> TensorLikeType:
     torch._check(
         len(pad) % 2 == 0,
@@ -3045,7 +3034,7 @@ def expand_as(a: Tensor, b: Tensor) -> Tensor:
     return a.expand(b.shape)


-def chunk(a: TensorLikeType, chunks: int, dim: int = 0) -> Tuple[TensorLikeType, ...]:
+def chunk(a: TensorLikeType, chunks: int, dim: int = 0) -> tuple[TensorLikeType, ...]:
     if chunks <= 0:
         msg = f"Expected at least one chunk, but got {chunks}!"
         raise ValueError(msg)
@@ -3148,7 +3137,7 @@ def narrow(

 def _normalize(
     a: Tensor, norm_dims: DimsType, eps: float
-) -> Tuple[Tensor, Tensor, Tensor]:
+) -> tuple[Tensor, Tensor, Tensor]:
     """Computes mean and 1/std of a tensor along norm_dims.

     Used as a helper function for normalization layers.
@@ -3176,7 +3165,7 @@ def _normalize(


 # add all specified dimensions
-def _unsqueeze_multiple(x: TensorLikeType, dimensions: List[int]) -> TensorLikeType:
+def _unsqueeze_multiple(x: TensorLikeType, dimensions: list[int]) -> TensorLikeType:
     for dim in sorted(dimensions):
         x = torch.unsqueeze(x, dim)
     return x
@@ -3192,7 +3181,7 @@ def native_group_norm(
     flattened_inner_size: int,
     num_groups: int,
     eps: float,
-) -> Tuple[Tensor, Tensor, Tensor]:
+) -> tuple[Tensor, Tensor, Tensor]:
     torch._check(
         input.ndim >= 2,
         lambda: f"Expected at least 2 dimensions for input tensor but received {input.ndim}",
@@ -3243,7 +3232,7 @@ def native_layer_norm(
     weight: Optional[Tensor],
     bias: Optional[Tensor],
     eps: float,
-) -> Tuple[Tensor, Tensor, Tensor]:
+) -> tuple[Tensor, Tensor, Tensor]:
     normalized_ndim = len(normalized_shape)
     torch._check(
         normalized_ndim >= 1,
@@ -4163,8 +4152,8 @@ def squeeze(a: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType

 @register_decomposition(aten.split_with_sizes)
 def split_with_sizes(
-    self: Tensor, split_sizes: List[int], dim: int = 0
-) -> List[Tensor]:
+    self: Tensor, split_sizes: list[int], dim: int = 0
+) -> list[Tensor]:
     # NB: Perform the check_is_size tests first so that the
     # sum test does not try to do a replacement
     for i in range(len(split_sizes)):
@@ -4197,7 +4186,7 @@ def tensor_split(
     a: TensorLikeType,
     indices_or_sections: Union[Tensor, DimsType],
     dim: int = 0,
-) -> Tuple[TensorLikeType, ...]:
+) -> tuple[TensorLikeType, ...]:
     _dim = utils.canonicalize_dim(a.ndim, dim)
     if a.ndim == 0:
         msg = "tensor_split: received a rank zero tensor, but expected a tensor of rank one or greater!"
@@ -4263,7 +4252,7 @@ def tensor_split(
 # CompositeImplicitAutograd - don't register decomp
 def hsplit(
     a: TensorLikeType, indices_or_sections: DimsType
-) -> Tuple[TensorLikeType, ...]:
+) -> tuple[TensorLikeType, ...]:
     torch._check(
         a.ndim >= 1,
         lambda: (
@@ -4305,7 +4294,7 @@ def hsplit(
 # CompositeImplicitAutograd - don't register decomp
 def vsplit(
     a: TensorLikeType, indices_or_sections: DimsType
-) -> Tuple[TensorLikeType, ...]:
+) -> tuple[TensorLikeType, ...]:
     torch._check(
         a.ndim >= 2,
         lambda: (
@@ -4480,7 +4469,7 @@ def diag_embed(

 @register_decomposition(aten.block_diag)
 @out_wrapper()
-def _block_diag_iterable(tensors: List[TensorLikeType]) -> TensorLikeType:
+def _block_diag_iterable(tensors: list[TensorLikeType]) -> TensorLikeType:
     """
     Reference implementation of torch.block_diag
     """
@@ -4516,7 +4505,7 @@ def _block_diag_iterable(tensors: List[TensorLikeType]) -> TensorLikeType:
     return torch.cat(result, dim=0)


-def block_diag(*tensors: List[TensorLikeType]) -> TensorLikeType:
+def block_diag(*tensors: list[TensorLikeType]) -> TensorLikeType:
     """
     This is used as an input to PythonRefInfo. `torch.block_diag`
     expects arguments splatted, but `aten.block_diag` expects only
@@ -5305,9 +5294,9 @@ def meshgrid(*tensors: TensorLikeType, indexing: str):

 @register_decomposition(aten.meshgrid)  # type: ignore[misc]
 def meshgrid(
-    *tensors: Union[TensorLikeType, List[TensorLikeType], Tuple[TensorLikeType]],
+    *tensors: Union[TensorLikeType, list[TensorLikeType], tuple[TensorLikeType]],
     indexing: str,
-) -> List[TensorLikeType]:
+) -> list[TensorLikeType]:
     # This ref simultaneously handles two overloads (see stubs above)
     # The `indexing` argument is currently optional for torch.meshgrid, but we
     # plan to make the argument required: https://github.com/pytorch/pytorch/issues/50276
@@ -5346,7 +5335,7 @@ def meshgrid(
         ),
     )

-    result_shape: List[int] = []
+    result_shape: list[int] = []
     for t in tensors:
         assert isinstance(t, TensorLike)  # mypy
         torch._check(
@@ -5355,7 +5344,7 @@ def meshgrid(
         )
         result_shape.append(t.numel())

-    grids: List[TensorLikeType] = []
+    grids: list[TensorLikeType] = []
     for i, t in enumerate(tensors):
         assert isinstance(t, TensorLike)  # mypy
         if t.ndim == 0:
@@ -5436,7 +5425,7 @@ def movedim(
 @register_decomposition(aten.empty_strided)
 @out_wrapper()
 def empty_strided(
-    shape: Union[ShapeType, Tuple[ShapeType]],
+    shape: Union[ShapeType, tuple[ShapeType]],
     strides: StrideType,
     *,
     dtype: Optional[torch.dtype] = None,
@@ -5854,7 +5843,7 @@ def tril(a: TensorLikeType, diagonal: int = 0) -> TensorLikeType:
 # form a pentagon that can be broken down into a top trapezoid and a bottom
 # rectangle. For the implementation of tril_indices, we need the sizes of
 # both of these, as well as the length of the top side of the trapezoid.
-def _get_tril_sizes(row: int, col: int, offset: int) -> Tuple[int, int, int]:
+def _get_tril_sizes(row: int, col: int, offset: int) -> tuple[int, int, int]:
     if row == 0 or col == 0:
         return 0, 0, 0

@@ -5932,7 +5921,7 @@ def tril_indices(
 # a bottom rectangle instead. Note that you can't reduce this to
 # _get_tril_sizes(col, row, -offset) because that would correspond to
 # decomposing into a left trapezoid and right rectangle.
-def _get_triu_sizes(row: int, col: int, offset: int) -> Tuple[int, int, int]:
+def _get_triu_sizes(row: int, col: int, offset: int) -> tuple[int, int, int]:
     if row == 0 or col == 0:
         return 0, 0, 0

diff --git a/torch/_refs/fft.py b/torch/_refs/fft.py
index 2558dcf6da0f..c95a5bab02f2 100644
--- a/torch/_refs/fft.py
+++ b/torch/_refs/fft.py
@@ -1,5 +1,6 @@
 import math
-from typing import Iterable, List, Literal, NamedTuple, Optional, Sequence, Tuple, Union
+from collections.abc import Iterable, Sequence
+from typing import Literal, NamedTuple, Optional, Union

 import torch
 import torch._prims as prims
@@ -88,7 +89,7 @@ def _maybe_promote_tensor_fft(


 def _resize_fft_input(
-    x: TensorLikeType, dims: Tuple[int, ...], sizes: Tuple[int, ...]
+    x: TensorLikeType, dims: tuple[int, ...], sizes: tuple[int, ...]
 ) -> TensorLikeType:
     """
     Fixes the shape of x such that x.size(dims[i]) == sizes[i],
@@ -268,8 +269,8 @@ def ihfft(


 class _ShapeAndDims(NamedTuple):
-    shape: Tuple[int, ...]
-    dims: Tuple[int, ...]
+    shape: tuple[int, ...]
+    dims: tuple[int, ...]


 def _canonicalize_fft_shape_and_dim_args(
@@ -339,8 +340,8 @@ def _prod(xs: Iterable[int]) -> int:
 def _fftn_c2c(
     function_name: str,
     input: TensorLikeType,
-    shape: Tuple[int, ...],
-    dim: Tuple[int, ...],
+    shape: tuple[int, ...],
+    dim: tuple[int, ...],
     norm: NormType,
     forward: bool,
 ) -> TensorLikeType:
@@ -429,8 +430,8 @@ def ihfftn(


 class _CanonicalizeC2rReturn(NamedTuple):
-    shape: Tuple[int, ...]
-    dim: Tuple[int, ...]
+    shape: tuple[int, ...]
+    dim: tuple[int, ...]
     last_dim_size: int


@@ -566,7 +567,7 @@ def ihfft2(
     return torch.fft.ihfftn(input, s=s, dim=dim, norm=norm)


-def _default_alldims(dim: Optional[DimsType], x: TensorLikeType) -> List[int]:
+def _default_alldims(dim: Optional[DimsType], x: TensorLikeType) -> list[int]:
     """Convert Optional[DimsType] to a simple list, defaulting to all dimensions"""
     if dim is None:
         return list(range(x.ndim))
diff --git a/torch/_refs/linalg/__init__.py b/torch/_refs/linalg/__init__.py
index 6585f57e3d64..04187913aacf 100644
--- a/torch/_refs/linalg/__init__.py
+++ b/torch/_refs/linalg/__init__.py
@@ -288,7 +288,7 @@ def norm(

 # CompositeImplicitAutograd
 @out_wrapper("U", "S", "Vh", exact_dtype=True)
-def svd(A: TensorLikeType, full_matrices: bool = True) -> Tuple[Tensor, Tensor, Tensor]:
+def svd(A: TensorLikeType, full_matrices: bool = True) -> tuple[Tensor, Tensor, Tensor]:
     return prims.svd(A, full_matrices=full_matrices)


diff --git a/torch/_refs/nn/__init__.py b/torch/_refs/nn/__init__.py
index b7414d43515a..840ecd9ca20b 100644
--- a/torch/_refs/nn/__init__.py
+++ b/torch/_refs/nn/__init__.py
@@ -1,4 +1,4 @@
 from typing import List


-__all__: List[str] = []
+__all__: list[str] = []
diff --git a/torch/_strobelight/cli_function_profiler.py b/torch/_strobelight/cli_function_profiler.py
index e45fe6177baf..4fe133cafc03 100644
--- a/torch/_strobelight/cli_function_profiler.py
+++ b/torch/_strobelight/cli_function_profiler.py
@@ -6,9 +6,10 @@ import os
 import re
 import subprocess
 import time
+from collections.abc import Sequence
 from threading import Lock
 from timeit import default_timer as timer
-from typing import Any, Callable, List, Optional, Sequence, TypeVar
+from typing import Any, Callable, Optional, TypeVar
 from typing_extensions import ParamSpec

@@ -77,8 +78,8 @@ class StrobelightCLIFunctionProfiler:
         run_user_name: str = "pytorch-strobelight-ondemand",
         timeout_wait_for_running_sec: int = 60,
         timeout_wait_for_finished_sec: int = 60,
-        recorded_env_variables: Optional[List[str]] = None,
-        sample_tags: Optional[List[str]] = None,
+        recorded_env_variables: Optional[list[str]] = None,
+        sample_tags: Optional[list[str]] = None,
         stack_max_len: int = 127,
         async_stack_max_len: int = 127,
     ):
@@ -91,7 +92,7 @@ class StrobelightCLIFunctionProfiler:
         # Results of the most recent run.
         # Tracks the strobelight run id of the most recent run
         self.current_run_id: Optional[int] = None
-        self.profile_result: Optional[List[str]] = None
+        self.profile_result: Optional[list[str]] = None
         self.sample_tags = sample_tags

     def _run_async(self) -> None: