diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 99e5877ba62b..5c8a07e21e17 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -3,7 +3,7 @@ import importlib import math import warnings -from typing import Callable, Optional, TYPE_CHECKING, Union +from typing import Callable, List, Optional, TYPE_CHECKING, Union import torch from torch import _VF, sym_int as _sym_int, Tensor @@ -708,7 +708,7 @@ def max_pool1d_with_indices( return_indices=return_indices, ) if stride is None: - stride = torch.jit.annotate(list[int], []) + stride = torch.jit.annotate(List[int], []) return torch.max_pool1d_with_indices( input, kernel_size, stride, padding, dilation, ceil_mode ) @@ -736,7 +736,7 @@ def _max_pool1d( return_indices=return_indices, ) if stride is None: - stride = torch.jit.annotate(list[int], []) + stride = torch.jit.annotate(List[int], []) return torch.max_pool1d(input, kernel_size, stride, padding, dilation, ceil_mode) @@ -798,7 +798,7 @@ def max_pool2d_with_indices( return_indices=return_indices, ) if stride is None: - stride = torch.jit.annotate(list[int], []) + stride = torch.jit.annotate(List[int], []) return torch._C._nn.max_pool2d_with_indices( input, kernel_size, stride, padding, dilation, ceil_mode ) @@ -826,7 +826,7 @@ def _max_pool2d( return_indices=return_indices, ) if stride is None: - stride = torch.jit.annotate(list[int], []) + stride = torch.jit.annotate(List[int], []) return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode) @@ -888,7 +888,7 @@ def max_pool3d_with_indices( return_indices=return_indices, ) if stride is None: - stride = torch.jit.annotate(list[int], []) + stride = torch.jit.annotate(List[int], []) return torch._C._nn.max_pool3d_with_indices( input, kernel_size, stride, padding, dilation, ceil_mode ) @@ -916,7 +916,7 @@ def _max_pool3d( return_indices=return_indices, ) if stride is None: - stride = torch.jit.annotate(list[int], []) + stride = torch.jit.annotate(List[int], []) return torch.max_pool3d(input, kernel_size, stride, padding, dilation, ceil_mode) @@ -939,7 +939,7 @@ def _unpool_output_size( output_size: Optional[list[int]], ) -> list[int]: input_size = input.size() - default_size = torch.jit.annotate(list[int], []) + default_size = torch.jit.annotate(List[int], []) for d in range(len(kernel_size)): default_size.append( (input_size[-len(kernel_size) + d] - 1) * stride[d] diff --git a/torch/nn/modules/conv.py b/torch/nn/modules/conv.py index af9f5a8386cc..e845d71f8b88 100644 --- a/torch/nn/modules/conv.py +++ b/torch/nn/modules/conv.py @@ -1,6 +1,6 @@ # mypy: allow-untyped-defs import math -from typing import Optional, Union +from typing import List, Optional, Union from typing_extensions import deprecated import torch @@ -788,8 +788,8 @@ class _ConvTransposeNd(_ConvNd): f"or {num_non_spatial_dims + num_spatial_dims} elements (got {len(output_size)})" ) - min_sizes = torch.jit.annotate(list[int], []) - max_sizes = torch.jit.annotate(list[int], []) + min_sizes = torch.jit.annotate(List[int], []) + max_sizes = torch.jit.annotate(List[int], []) for d in range(num_spatial_dims): dim_size = ( (input.size(d + num_non_spatial_dims) - 1) * stride[d] @@ -811,7 +811,7 @@ class _ConvTransposeNd(_ConvNd): f"from {min_sizes} to {max_sizes} (for an input of {input.size()[2:]})" ) - res = torch.jit.annotate(list[int], []) + res = torch.jit.annotate(List[int], []) for d in range(num_spatial_dims): res.append(output_size[d] - min_sizes[d]) diff --git a/torch/optim/_adafactor.py b/torch/optim/_adafactor.py index 
f499045dbbbc..b47beae7643a 100644 --- a/torch/optim/_adafactor.py +++ b/torch/optim/_adafactor.py @@ -1,6 +1,6 @@ # mypy: allow-untyped-decorators # mypy: allow-untyped-defs -from typing import cast, Optional, TYPE_CHECKING, Union +from typing import cast, Dict, List, Optional, Tuple, TYPE_CHECKING, Union import torch from torch import Tensor @@ -25,7 +25,7 @@ class Adafactor(Optimizer): params: ParamsT, lr: Union[float, Tensor] = 1e-2, beta2_decay: float = -0.8, - eps: tuple[Optional[float], float] = (None, 1e-3), + eps: Tuple[Optional[float], float] = (None, 1e-3), d: float = 1.0, weight_decay: float = 0.0, *, @@ -133,12 +133,12 @@ class Adafactor(Optimizer): loss = closure() for group in self.param_groups: - params_with_grad: list[Tensor] = [] - grads: list[Tensor] = [] - row_vars: list[Optional[Tensor]] = [] - col_vars: list[Optional[Tensor]] = [] - variances: list[Optional[Tensor]] = [] - state_steps: list[Tensor] = [] + params_with_grad: List[Tensor] = [] + grads: List[Tensor] = [] + row_vars: List[Optional[Tensor]] = [] + col_vars: List[Optional[Tensor]] = [] + variances: List[Optional[Tensor]] = [] + state_steps: List[Tensor] = [] eps1, eps2 = group["eps"] has_complex = self._init_group( @@ -324,16 +324,16 @@ Adafactor.__doc__ = ( def _single_tensor_adafactor( - params: list[Tensor], - grads: list[Tensor], + params: List[Tensor], + grads: List[Tensor], # If grad is 1-dimensional (aka a vector), there is no factorization necessary # so row_var and col_var will be None while variance will be filled. # Contrarily, for a grad with multiple dimensions, we will factor along the last # 2 dimensions, and so row_var and col_var will be filled and variance will be None. - row_vars: list[Optional[Tensor]], - col_vars: list[Optional[Tensor]], - variances: list[Optional[Tensor]], - state_steps: list[Tensor], + row_vars: List[Optional[Tensor]], + col_vars: List[Optional[Tensor]], + variances: List[Optional[Tensor]], + state_steps: List[Tensor], grad_scale: Optional[Tensor], found_inf: Optional[Tensor], *, @@ -411,17 +411,17 @@ def _single_tensor_adafactor( def _group_tensors_by_device_dtype_and_is_multidim( tensorlists: TensorListList, -) -> dict[ - tuple[Optional[torch.device], Optional[torch.dtype], bool], - list[list[Optional[Tensor]]], +) -> Dict[ + Tuple[Optional[torch.device], Optional[torch.dtype], bool], + List[List[Optional[Tensor]]], ]: """Groups tensors by device, dtype, AND multidimensionality -- whether the tensor has multiple dims or just one dim (is a vector). This allows the foreach impl of Adafactor to assume that every group of params will either be factored or not.""" grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(tensorlists) - ultra_grouped_tensors: dict[ - tuple[Optional[torch.device], Optional[torch.dtype], bool], - list[list[Optional[Tensor]]], + ultra_grouped_tensors: Dict[ + Tuple[Optional[torch.device], Optional[torch.dtype], bool], + List[List[Optional[Tensor]]], ] = {} for (device, dtype), (tensorlists, _) in grouped_tensors.items(): matrix_key = (device, dtype, True) @@ -444,16 +444,16 @@ def _group_tensors_by_device_dtype_and_is_multidim( def _multi_tensor_adafactor( - params: list[Tensor], - grads: list[Tensor], + params: List[Tensor], + grads: List[Tensor], # If grad is 1-dimensional (aka a vector), there is no factorization necessary # so row_var and col_var will be None while variance will be filled. 
# Contrarily, for a grad with multiple dimensions, we will factor along the last # 2 dimensions, and so row_var and col_var will be filled and variance will be None. - row_vars: list[Optional[Tensor]], - col_vars: list[Optional[Tensor]], - variances: list[Optional[Tensor]], - state_steps: list[Tensor], + row_vars: List[Optional[Tensor]], + col_vars: List[Optional[Tensor]], + variances: List[Optional[Tensor]], + state_steps: List[Tensor], grad_scale: Optional[Tensor], found_inf: Optional[Tensor], *, @@ -486,9 +486,9 @@ def _multi_tensor_adafactor( device_state_steps_, ) ) in grouped_tensors.items(): - device_params = cast(list[Tensor], device_params_) - device_grads = cast(list[Tensor], device_grads_) - device_state_steps = cast(list[Tensor], device_state_steps_) + device_params = cast(List[Tensor], device_params_) + device_grads = cast(List[Tensor], device_grads_) + device_state_steps = cast(List[Tensor], device_state_steps_) if eps1 is None: assert ( dtype is not None @@ -530,8 +530,8 @@ def _multi_tensor_adafactor( torch._foreach_mul_(device_params, 1 - lr * weight_decay) if is_multidim: - device_row_vars = cast(list[Tensor], device_row_vars_) - device_col_vars = cast(list[Tensor], device_col_vars_) + device_row_vars = cast(List[Tensor], device_row_vars_) + device_col_vars = cast(List[Tensor], device_col_vars_) assert ( device_row_vars[0] is not None and device_col_vars[0] is not None ), "row_var and col_var should be defined when grad is multidimensional" @@ -564,7 +564,7 @@ def _multi_tensor_adafactor( torch._foreach_div_(var_estimates, row_var_means) del row_var_means else: - device_variances = cast(list[Tensor], device_variances_) + device_variances = cast(List[Tensor], device_variances_) assert ( device_variances[0] is not None ), "variance should be defined when grad is a vector" @@ -592,12 +592,12 @@ def _multi_tensor_adafactor( @_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adafactor) def adafactor( - params: list[Tensor], - grads: list[Tensor], - row_vars: list[Optional[Tensor]], - col_vars: list[Optional[Tensor]], - variances: list[Optional[Tensor]], - state_steps: list[Tensor], + params: List[Tensor], + grads: List[Tensor], + row_vars: List[Optional[Tensor]], + col_vars: List[Optional[Tensor]], + variances: List[Optional[Tensor]], + state_steps: List[Tensor], # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 # setting this as kwarg for now as functional API is compiled by torch/distributed/optim foreach: Optional[bool] = None, diff --git a/torch/optim/_functional.py b/torch/optim/_functional.py index f48311fb11d8..a307cc76846d 100644 --- a/torch/optim/_functional.py +++ b/torch/optim/_functional.py @@ -1,6 +1,7 @@ # mypy: allow-untyped-defs r"""Functional interface.""" import math +from typing import List from torch import Tensor @@ -21,11 +22,11 @@ from .sgd import sgd # type: ignore[attr-defined] # noqa: F401 def sparse_adam( - params: list[Tensor], - grads: list[Tensor], - exp_avgs: list[Tensor], - exp_avg_sqs: list[Tensor], - state_steps: list[int], + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[int], *, eps: float, beta1: float, diff --git a/torch/optim/adadelta.py b/torch/optim/adadelta.py index e1d2f3d203bf..4eb3c0f6b58a 100644 --- a/torch/optim/adadelta.py +++ b/torch/optim/adadelta.py @@ -1,5 +1,5 @@ # mypy: allow-untyped-defs -from typing import Any, cast, Optional, Union +from typing import Any, cast, Dict, List, Optional, Union 
import torch from torch import Tensor @@ -82,12 +82,12 @@ class Adadelta(Optimizer): def _init_group( self, - group: dict[str, Any], - params_with_grad: list[Tensor], - grads: list[Tensor], - square_avgs: list[Tensor], - acc_deltas: list[Tensor], - state_steps: list[Tensor], + group: Dict[str, Any], + params_with_grad: List[Tensor], + grads: List[Tensor], + square_avgs: List[Tensor], + acc_deltas: List[Tensor], + state_steps: List[Tensor], ): has_complex = False p: Tensor @@ -139,11 +139,11 @@ class Adadelta(Optimizer): loss = closure() for group in self.param_groups: - params_with_grad: list[Tensor] = [] - grads: list[Tensor] = [] - square_avgs: list[Tensor] = [] - acc_deltas: list[Tensor] = [] - state_steps: list[Tensor] = [] + params_with_grad: List[Tensor] = [] + grads: List[Tensor] = [] + square_avgs: List[Tensor] = [] + acc_deltas: List[Tensor] = [] + state_steps: List[Tensor] = [] ( lr, rho, @@ -242,11 +242,11 @@ Adadelta.__doc__ = ( def _single_tensor_adadelta( - params: list[Tensor], - grads: list[Tensor], - square_avgs: list[Tensor], - acc_deltas: list[Tensor], - state_steps: list[Tensor], + params: List[Tensor], + grads: List[Tensor], + square_avgs: List[Tensor], + acc_deltas: List[Tensor], + state_steps: List[Tensor], *, lr: float, rho: float, @@ -296,11 +296,11 @@ def _single_tensor_adadelta( def _multi_tensor_adadelta( - params: list[Tensor], - grads: list[Tensor], - square_avgs: list[Tensor], - acc_deltas: list[Tensor], - state_steps: list[Tensor], + params: List[Tensor], + grads: List[Tensor], + square_avgs: List[Tensor], + acc_deltas: List[Tensor], + state_steps: List[Tensor], *, lr: float, rho: float, @@ -337,11 +337,11 @@ def _multi_tensor_adadelta( device_acc_deltas_, device_state_steps_, ), _ in grouped_tensors.values(): - device_params = cast(list[Tensor], device_params_) - device_grads = cast(list[Tensor], device_grads_) - device_square_avgs = cast(list[Tensor], device_square_avgs_) - device_acc_deltas = cast(list[Tensor], device_acc_deltas_) - device_state_steps = cast(list[Tensor], device_state_steps_) + device_params = cast(List[Tensor], device_params_) + device_grads = cast(List[Tensor], device_grads_) + device_square_avgs = cast(List[Tensor], device_square_avgs_) + device_acc_deltas = cast(List[Tensor], device_acc_deltas_) + device_state_steps = cast(List[Tensor], device_state_steps_) if has_complex: _view_as_real( device_params, device_grads, device_square_avgs, device_acc_deltas @@ -397,11 +397,11 @@ def _multi_tensor_adadelta( @_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adadelta) def adadelta( - params: list[Tensor], - grads: list[Tensor], - square_avgs: list[Tensor], - acc_deltas: list[Tensor], - state_steps: list[Tensor], + params: List[Tensor], + grads: List[Tensor], + square_avgs: List[Tensor], + acc_deltas: List[Tensor], + state_steps: List[Tensor], # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 # setting this as kwarg for now as functional API is compiled by torch/distributed/optim capturable: bool = False, diff --git a/torch/optim/adagrad.py b/torch/optim/adagrad.py index 451135c1ad83..8e08b62d1a19 100644 --- a/torch/optim/adagrad.py +++ b/torch/optim/adagrad.py @@ -1,5 +1,5 @@ # mypy: allow-untyped-defs -from typing import cast, Optional, Union +from typing import cast, List, Optional, Union import torch from torch import Tensor @@ -157,10 +157,10 @@ class Adagrad(Optimizer): loss = closure() for group in self.param_groups: - params_with_grad: list[Tensor] = [] - grads: 
list[Tensor] = [] - state_sums: list[Tensor] = [] - state_steps: list[Tensor] = [] + params_with_grad: List[Tensor] = [] + grads: List[Tensor] = [] + state_sums: List[Tensor] = [] + state_steps: List[Tensor] = [] has_sparse_grad, has_complex = self._init_group( group, params_with_grad, grads, state_sums, state_steps @@ -240,10 +240,10 @@ Adagrad.__doc__ = ( def adagrad( - params: list[Tensor], - grads: list[Tensor], - state_sums: list[Tensor], - state_steps: list[Tensor], + params: List[Tensor], + grads: List[Tensor], + state_sums: List[Tensor], + state_steps: List[Tensor], fused: Optional[bool] = None, grad_scale: Optional[Tensor] = None, found_inf: Optional[Tensor] = None, @@ -319,10 +319,10 @@ def _make_sparse(grad, grad_indices, values): def _single_tensor_adagrad( - params: list[Tensor], - grads: list[Tensor], - state_sums: list[Tensor], - state_steps: list[Tensor], + params: List[Tensor], + grads: List[Tensor], + state_sums: List[Tensor], + state_steps: List[Tensor], grad_scale: Optional[Tensor], found_inf: Optional[Tensor], *, @@ -380,10 +380,10 @@ def _single_tensor_adagrad( def _multi_tensor_adagrad( - params: list[Tensor], - grads: list[Tensor], - state_sums: list[Tensor], - state_steps: list[Tensor], + params: List[Tensor], + grads: List[Tensor], + state_sums: List[Tensor], + state_steps: List[Tensor], grad_scale: Optional[Tensor], found_inf: Optional[Tensor], *, @@ -412,10 +412,10 @@ def _multi_tensor_adagrad( device_state_sums_, device_state_steps_, ), _ in grouped_tensorlists.values(): - device_params = cast(list[Tensor], device_params_) - device_grads = cast(list[Tensor], device_grads_) - device_state_sums = cast(list[Tensor], device_state_sums_) - device_state_steps = cast(list[Tensor], device_state_steps_) + device_params = cast(List[Tensor], device_params_) + device_grads = cast(List[Tensor], device_grads_) + device_state_sums = cast(List[Tensor], device_state_sums_) + device_state_steps = cast(List[Tensor], device_state_steps_) device_has_sparse_grad = has_sparse_grad and any( grad.is_sparse for grad in device_grads @@ -487,10 +487,10 @@ def _multi_tensor_adagrad( def _fused_adagrad( - params: list[Tensor], - grads: list[Tensor], - state_sums: list[Tensor], - state_steps: list[Tensor], + params: List[Tensor], + grads: List[Tensor], + state_sums: List[Tensor], + state_steps: List[Tensor], grad_scale: Optional[Tensor], found_inf: Optional[Tensor], *, @@ -530,10 +530,10 @@ def _fused_adagrad( ), _, ) in grouped_tensors.items(): - device_params = cast(list[Tensor], device_params_) - device_grads = cast(list[Tensor], device_grads_) - device_state_sums = cast(list[Tensor], device_state_sums_) - device_state_steps = cast(list[Tensor], device_state_steps_) + device_params = cast(List[Tensor], device_params_) + device_grads = cast(List[Tensor], device_grads_) + device_state_sums = cast(List[Tensor], device_state_sums_) + device_state_steps = cast(List[Tensor], device_state_steps_) device_grad_scale, device_found_inf = None, None if grad_scale is not None and grad_scale_dict is not None: diff --git a/torch/optim/adam.py b/torch/optim/adam.py index 9623236f47d0..6b34a7a1b09c 100644 --- a/torch/optim/adam.py +++ b/torch/optim/adam.py @@ -1,5 +1,5 @@ # mypy: allow-untyped-defs -from typing import cast, Optional, Union +from typing import cast, List, Optional, Tuple, Union import torch from torch import Tensor @@ -35,7 +35,7 @@ class Adam(Optimizer): self, params: ParamsT, lr: Union[float, Tensor] = 1e-3, - betas: tuple[Union[float, Tensor], Union[float, Tensor]] = (0.9, 0.999), + 
betas: Tuple[Union[float, Tensor], Union[float, Tensor]] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0, amsgrad: bool = False, @@ -225,12 +225,12 @@ class Adam(Optimizer): loss = closure() for group in self.param_groups: - params_with_grad: list[Tensor] = [] - grads: list[Tensor] = [] - exp_avgs: list[Tensor] = [] - exp_avg_sqs: list[Tensor] = [] - max_exp_avg_sqs: list[Tensor] = [] - state_steps: list[Tensor] = [] + params_with_grad: List[Tensor] = [] + grads: List[Tensor] = [] + exp_avgs: List[Tensor] = [] + exp_avg_sqs: List[Tensor] = [] + max_exp_avg_sqs: List[Tensor] = [] + state_steps: List[Tensor] = [] beta1, beta2 = group["betas"] has_complex = self._init_group( @@ -342,12 +342,12 @@ Adam.__doc__ = ( def _single_tensor_adam( - params: list[Tensor], - grads: list[Tensor], - exp_avgs: list[Tensor], - exp_avg_sqs: list[Tensor], - max_exp_avg_sqs: list[Tensor], - state_steps: list[Tensor], + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], grad_scale: Optional[Tensor], found_inf: Optional[Tensor], *, @@ -532,12 +532,12 @@ def _single_tensor_adam( def _multi_tensor_adam( - params: list[Tensor], - grads: list[Tensor], - exp_avgs: list[Tensor], - exp_avg_sqs: list[Tensor], - max_exp_avg_sqs: list[Tensor], - state_steps: list[Tensor], + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], grad_scale: Optional[Tensor], found_inf: Optional[Tensor], *, @@ -612,11 +612,11 @@ def _multi_tensor_adam( device_max_exp_avg_sqs_, device_state_steps_, ), _ in grouped_tensors.values(): - device_params = cast(list[Tensor], device_params_) - device_grads = cast(list[Tensor], device_grads_) - device_exp_avgs = cast(list[Tensor], device_exp_avgs_) - device_exp_avg_sqs = cast(list[Tensor], device_exp_avg_sqs_) - device_state_steps = cast(list[Tensor], device_state_steps_) + device_params = cast(List[Tensor], device_params_) + device_grads = cast(List[Tensor], device_grads_) + device_exp_avgs = cast(List[Tensor], device_exp_avgs_) + device_exp_avg_sqs = cast(List[Tensor], device_exp_avg_sqs_) + device_state_steps = cast(List[Tensor], device_state_steps_) device = device_params[0].device if beta1_dict is not None and device not in beta1_dict: @@ -627,7 +627,7 @@ def _multi_tensor_adam( # Handle complex parameters if has_complex: if amsgrad: - device_max_exp_avg_sqs = cast(list[Tensor], device_max_exp_avg_sqs_) + device_max_exp_avg_sqs = cast(List[Tensor], device_max_exp_avg_sqs_) _view_as_real( device_params, device_grads, @@ -693,9 +693,9 @@ def _multi_tensor_adam( del device_grads del scaled_device_grads - bias_correction1: Union[tuple[Tensor, ...], list[Tensor]] - bias_correction2: Union[tuple[Tensor, ...], list[Tensor]] - bias_correction2_sqrt: Union[tuple[Tensor, ...], list[Tensor]] + bias_correction1: Union[Tuple[Tensor, ...], List[Tensor]] + bias_correction2: Union[Tuple[Tensor, ...], List[Tensor]] + bias_correction2_sqrt: Union[Tuple[Tensor, ...], List[Tensor]] if capturable: bias_correction1 = torch._foreach_pow(beta1, device_state_steps) # type: ignore[arg-type] @@ -719,7 +719,7 @@ def _multi_tensor_adam( bias_correction2_sqrt = bias_correction2 if amsgrad: - device_max_exp_avg_sqs = cast(list[Tensor], device_max_exp_avg_sqs_) + device_max_exp_avg_sqs = cast(List[Tensor], device_max_exp_avg_sqs_) # Maintains the maximum of all 2nd moment running avg. 
till now torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs) # type: ignore[assignment] @@ -747,7 +747,7 @@ def _multi_tensor_adam( bias_correction2_sqrt = [bc**0.5 for bc in bias_correction2] # type: ignore[arg-type] if amsgrad: - device_max_exp_avg_sqs = cast(list[Tensor], device_max_exp_avg_sqs_) + device_max_exp_avg_sqs = cast(List[Tensor], device_max_exp_avg_sqs_) # Maintains the maximum of all 2nd moment running avg. till now torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs) @@ -764,12 +764,12 @@ def _multi_tensor_adam( def _fused_adam( - params: list[Tensor], - grads: list[Tensor], - exp_avgs: list[Tensor], - exp_avg_sqs: list[Tensor], - max_exp_avg_sqs: list[Tensor], - state_steps: list[Tensor], + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], grad_scale: Optional[Tensor], found_inf: Optional[Tensor], *, @@ -816,11 +816,11 @@ def _fused_adam( ), _, ) in grouped_tensors.items(): - device_params = cast(list[Tensor], device_params_) - device_grads = cast(list[Tensor], device_grads_) - device_exp_avgs = cast(list[Tensor], device_exp_avgs_) - device_exp_avg_sqs = cast(list[Tensor], device_exp_avg_sqs_) - device_state_steps = cast(list[Tensor], device_state_steps_) + device_params = cast(List[Tensor], device_params_) + device_grads = cast(List[Tensor], device_grads_) + device_exp_avgs = cast(List[Tensor], device_exp_avgs_) + device_exp_avg_sqs = cast(List[Tensor], device_exp_avg_sqs_) + device_state_steps = cast(List[Tensor], device_state_steps_) if device.type == "mps": # type: ignore[union-attr] assert found_inf is None and grad_scale is None @@ -864,12 +864,12 @@ def _fused_adam( @_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adam) def adam( - params: list[Tensor], - grads: list[Tensor], - exp_avgs: list[Tensor], - exp_avg_sqs: list[Tensor], - max_exp_avg_sqs: list[Tensor], - state_steps: list[Tensor], + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 # setting this as kwarg for now as functional API is compiled by torch/distributed/optim foreach: Optional[bool] = None, diff --git a/torch/optim/adamw.py b/torch/optim/adamw.py index b61a3f61b668..63984b1e9326 100644 --- a/torch/optim/adamw.py +++ b/torch/optim/adamw.py @@ -1,5 +1,5 @@ # mypy: allow-untyped-defs -from typing import Optional, Union +from typing import List, Optional, Tuple, Union from torch import Tensor @@ -23,7 +23,7 @@ class AdamW(Adam): self, params: ParamsT, lr: Union[float, Tensor] = 1e-3, - betas: tuple[Union[float, Tensor], Union[float, Tensor]] = (0.9, 0.999), + betas: Tuple[Union[float, Tensor], Union[float, Tensor]] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 1e-2, amsgrad: bool = False, @@ -128,12 +128,12 @@ AdamW.__doc__ = ( # @_disable_dynamo_if_unsupported logic occurs in the decorator that's applied to F.adam def adamw( - params: list[Tensor], - grads: list[Tensor], - exp_avgs: list[Tensor], - exp_avg_sqs: list[Tensor], - max_exp_avg_sqs: list[Tensor], - state_steps: list[Tensor], + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], # kwonly args with defaults are not supported by functions compiled with 
torchscript issue #70627 # setting this as kwarg for now as functional API is compiled by torch/distributed/optim foreach: Optional[bool] = None, diff --git a/torch/optim/asgd.py b/torch/optim/asgd.py index ca798a24f38a..2fcabe3786ad 100644 --- a/torch/optim/asgd.py +++ b/torch/optim/asgd.py @@ -1,5 +1,5 @@ # mypy: allow-untyped-defs -from typing import cast, Optional, Union +from typing import cast, List, Optional, Tuple, Union import torch from torch import Tensor @@ -135,12 +135,12 @@ class ASGD(Optimizer): loss = closure() for group in self.param_groups: - params_with_grad: list[Tensor] = [] - grads: list[Tensor] = [] - mus: list[Tensor] = [] - axs: list[Tensor] = [] - etas: list[Tensor] = [] - state_steps: list[Tensor] = [] + params_with_grad: List[Tensor] = [] + grads: List[Tensor] = [] + mus: List[Tensor] = [] + axs: List[Tensor] = [] + etas: List[Tensor] = [] + state_steps: List[Tensor] = [] has_complex = self._init_group( group, params_with_grad, grads, mus, axs, etas, state_steps @@ -192,12 +192,12 @@ ASGD.__doc__ = rf"""Implements Averaged Stochastic Gradient Descent. def _single_tensor_asgd( - params: list[Tensor], - grads: list[Tensor], - axs: list[Tensor], - mus: list[Tensor], - etas: list[Tensor], - state_steps: list[Tensor], + params: List[Tensor], + grads: List[Tensor], + axs: List[Tensor], + mus: List[Tensor], + etas: List[Tensor], + state_steps: List[Tensor], *, lambd: float, lr: float, @@ -268,12 +268,12 @@ def _single_tensor_asgd( def _multi_tensor_asgd( - params: list[Tensor], - grads: list[Tensor], - axs: list[Tensor], - mus: list[Tensor], - etas: list[Tensor], - state_steps: list[Tensor], + params: List[Tensor], + grads: List[Tensor], + axs: List[Tensor], + mus: List[Tensor], + etas: List[Tensor], + state_steps: List[Tensor], *, lambd: float, lr: float, @@ -315,12 +315,12 @@ def _multi_tensor_asgd( ), _, ) in grouped_tensors.items(): - grouped_params = cast(list[Tensor], grouped_params_) - grouped_grads = cast(list[Tensor], grouped_grads_) - grouped_axs = cast(list[Tensor], grouped_axs_) - grouped_mus = cast(list[Tensor], grouped_mus_) - grouped_etas = cast(list[Tensor], grouped_etas_) - grouped_state_steps = cast(list[Tensor], grouped_state_steps_) + grouped_params = cast(List[Tensor], grouped_params_) + grouped_grads = cast(List[Tensor], grouped_grads_) + grouped_axs = cast(List[Tensor], grouped_axs_) + grouped_mus = cast(List[Tensor], grouped_mus_) + grouped_etas = cast(List[Tensor], grouped_etas_) + grouped_state_steps = cast(List[Tensor], grouped_state_steps_) if has_complex: _view_as_real(grouped_params, grouped_grads, grouped_axs) @@ -340,7 +340,7 @@ def _multi_tensor_asgd( torch._foreach_add_(grouped_state_steps, 1) # intermediate = grad + param * lambd - intermediate: Union[tuple[Tensor, ...], list[Tensor]] + intermediate: Union[Tuple[Tensor, ...], List[Tensor]] if weight_decay != 0: if maximize: torch._foreach_add_(grouped_grads, grouped_params, alpha=weight_decay) @@ -375,8 +375,8 @@ def _multi_tensor_asgd( torch._foreach_addcmul_(grouped_axs, intermediate, grouped_mus) del intermediate - new_etas: Union[tuple[Tensor, ...], list[Tensor]] - new_mus: Union[tuple[Tensor, ...], list[Tensor]] + new_etas: Union[Tuple[Tensor, ...], List[Tensor]] + new_mus: Union[Tuple[Tensor, ...], List[Tensor]] if capturable: # update grouped_mus new_mus = torch._foreach_sub(grouped_state_steps, t0) @@ -408,12 +408,12 @@ def _multi_tensor_asgd( @_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_asgd) def asgd( - params: list[Tensor], - grads: list[Tensor], - axs: 
list[Tensor], - mus: list[Tensor], - etas: list[Tensor], - state_steps: list[Tensor], + params: List[Tensor], + grads: List[Tensor], + axs: List[Tensor], + mus: List[Tensor], + etas: List[Tensor], + state_steps: List[Tensor], # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 # setting this as kwarg for now as functional API is compiled by torch/distributed/optim foreach: Optional[bool] = None, diff --git a/torch/optim/lr_scheduler.py b/torch/optim/lr_scheduler.py index 3e65e7bc156c..abbeb51edfb0 100644 --- a/torch/optim/lr_scheduler.py +++ b/torch/optim/lr_scheduler.py @@ -5,14 +5,17 @@ import types import warnings from bisect import bisect_right from collections import Counter -from collections.abc import Iterable, Sequence from functools import partial, wraps from typing import ( Any, Callable, cast, + Dict, + Iterable, + List, Literal, Optional, + Sequence, SupportsFloat, TypedDict, Union, @@ -113,7 +116,7 @@ class LRScheduler: "param 'initial_lr' is not specified " f"in param_groups[{i}] when resuming an optimizer" ) - self.base_lrs: list[float] = [ + self.base_lrs: List[float] = [ group["initial_lr"] for group in optimizer.param_groups ] self.last_epoch = last_epoch @@ -160,7 +163,7 @@ class LRScheduler: key: value for key, value in self.__dict__.items() if key != "optimizer" } - def load_state_dict(self, state_dict: dict[str, Any]): + def load_state_dict(self, state_dict: Dict[str, Any]): """Load the scheduler's state. Args: @@ -169,18 +172,18 @@ class LRScheduler: """ self.__dict__.update(state_dict) - def get_last_lr(self) -> list[float]: + def get_last_lr(self) -> List[float]: """Return last computed learning rate by current scheduler.""" return self._last_lr - def get_lr(self) -> list[float]: + def get_lr(self) -> List[float]: """Compute learning rate using chainable form of the scheduler.""" raise NotImplementedError def print_lr( self, is_verbose: bool, - group: dict[str, Any], + group: Dict[str, Any], lr: float, epoch: Optional[int] = None, ): @@ -240,7 +243,7 @@ class LRScheduler: warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning) self.last_epoch = epoch if hasattr(self, "_get_closed_form_lr"): - values = cast(list[float], self._get_closed_form_lr()) + values = cast(List[float], self._get_closed_form_lr()) else: values = self.get_lr() @@ -250,7 +253,7 @@ class LRScheduler: else: param_group["lr"] = lr - self._last_lr: list[float] = [ + self._last_lr: List[float] = [ group["lr"] for group in self.optimizer.param_groups ] @@ -317,13 +320,13 @@ class LambdaLR(LRScheduler): def __init__( self, optimizer: Optimizer, - lr_lambda: Union[Callable[[int], float], list[Callable[[int], float]]], + lr_lambda: Union[Callable[[int], float], List[Callable[[int], float]]], last_epoch: int = -1, verbose="deprecated", ): # noqa: D107 self.optimizer = optimizer - self.lr_lambdas: list[Callable[[int], float]] + self.lr_lambdas: List[Callable[[int], float]] if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple): self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups) else: @@ -417,13 +420,13 @@ class MultiplicativeLR(LRScheduler): def __init__( self, optimizer: Optimizer, - lr_lambda: Union[Callable[[int], float], list[Callable[[int], float]]], + lr_lambda: Union[Callable[[int], float], List[Callable[[int], float]]], last_epoch: int = -1, verbose="deprecated", ): # noqa: D107 self.optimizer = optimizer - self.lr_lambdas: list[Callable[[int], float]] + self.lr_lambdas: List[Callable[[int], float]] if not isinstance(lr_lambda, 
list) and not isinstance(lr_lambda, tuple): self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups) else: @@ -860,8 +863,8 @@ class SequentialLR(LRScheduler): def __init__( self, optimizer: Optimizer, - schedulers: list[LRScheduler], - milestones: list[int], + schedulers: List[LRScheduler], + milestones: List[int], last_epoch: int = -1, verbose="deprecated", ): # noqa: D107 @@ -1308,7 +1311,7 @@ class ReduceLROnPlateau(LRScheduler): threshold: float = 1e-4, threshold_mode: Literal["rel", "abs"] = "rel", cooldown: int = 0, - min_lr: Union[list[float], float] = 0, + min_lr: Union[List[float], float] = 0, eps: float = 1e-8, verbose="deprecated", ): # noqa: D107 @@ -1554,8 +1557,8 @@ class CyclicLR(LRScheduler): def __init__( self, optimizer: Optimizer, - base_lr: Union[float, list[float]], - max_lr: Union[float, list[float]], + base_lr: Union[float, List[float]], + max_lr: Union[float, List[float]], step_size_up: int = 2000, step_size_down: Optional[int] = None, mode: Literal["triangular", "triangular2", "exp_range"] = "triangular", @@ -1985,15 +1988,15 @@ class OneCycleLR(LRScheduler): def __init__( self, optimizer: Optimizer, - max_lr: Union[float, list[float]], + max_lr: Union[float, List[float]], total_steps: Optional[int] = None, epochs: Optional[int] = None, steps_per_epoch: Optional[int] = None, pct_start: float = 0.3, anneal_strategy: Literal["cos", "linear"] = "cos", cycle_momentum: bool = True, - base_momentum: Union[float, list[float]] = 0.85, - max_momentum: Union[float, list[float]] = 0.95, + base_momentum: Union[float, List[float]] = 0.85, + max_momentum: Union[float, List[float]] = 0.95, div_factor: float = 25.0, final_div_factor: float = 1e4, three_phase: bool = False, @@ -2025,7 +2028,7 @@ class OneCycleLR(LRScheduler): "You must define either total_steps OR (epochs AND steps_per_epoch)" ) - self._schedule_phases: list[_SchedulePhase] + self._schedule_phases: List[_SchedulePhase] if three_phase: self._schedule_phases = [ { diff --git a/torch/optim/optimizer.py b/torch/optim/optimizer.py index cee7b25c8f2e..dedb248cdec0 100644 --- a/torch/optim/optimizer.py +++ b/torch/optim/optimizer.py @@ -3,10 +3,25 @@ import functools import warnings from collections import defaultdict, OrderedDict -from collections.abc import Hashable, Iterable, Sequence from copy import deepcopy from itertools import chain -from typing import Any, Callable, cast, Optional, overload, TypeVar, Union +from typing import ( + Any, + Callable, + cast, + DefaultDict, + Dict, + Hashable, + Iterable, + List, + Optional, + overload, + Sequence, + Set, + Tuple, + TypeVar, + Union, +) from typing_extensions import ParamSpec, Self, TypeAlias import torch @@ -24,15 +39,15 @@ from torch.utils.hooks import RemovableHandle _T = TypeVar("_T") _P = ParamSpec("_P") -Args: TypeAlias = tuple[Any, ...] -Kwargs: TypeAlias = dict[str, Any] -StateDict: TypeAlias = dict[str, Any] -DeviceDict = dict[Optional[torch.device], torch.Tensor] -DeviceDtypeDict = dict[Optional[tuple[torch.device, torch.dtype]], torch.Tensor] +Args: TypeAlias = Tuple[Any, ...] 
+Kwargs: TypeAlias = Dict[str, Any] +StateDict: TypeAlias = Dict[str, Any] +DeviceDict = Dict[Optional[torch.device], torch.Tensor] +DeviceDtypeDict = Dict[Optional[Tuple[torch.device, torch.dtype]], torch.Tensor] GlobalOptimizerPreHook: TypeAlias = Callable[ - ["Optimizer", Args, Kwargs], Optional[tuple[Args, Kwargs]] + ["Optimizer", Args, Kwargs], Optional[Tuple[Args, Kwargs]] ] GlobalOptimizerPostHook: TypeAlias = Callable[["Optimizer", Args, Kwargs], None] @@ -41,8 +56,8 @@ __all__ = [ "register_optimizer_step_pre_hook", "register_optimizer_step_post_hook", ] -_global_optimizer_pre_hooks: dict[int, GlobalOptimizerPreHook] = OrderedDict() -_global_optimizer_post_hooks: dict[int, GlobalOptimizerPostHook] = OrderedDict() +_global_optimizer_pre_hooks: Dict[int, GlobalOptimizerPreHook] = OrderedDict() +_global_optimizer_post_hooks: Dict[int, GlobalOptimizerPostHook] = OrderedDict() _foreach_supported_types = [torch.Tensor, torch.nn.parameter.Parameter] @@ -158,8 +173,8 @@ def _disable_dynamo_if_unsupported( # torch.jit.script nor differentiable, so we fall back to the single tensor # implementation in those cases. def _default_to_fused_or_foreach( - params: list[torch.Tensor], differentiable: bool, use_fused: bool = False -) -> tuple[bool, bool]: + params: List[torch.Tensor], differentiable: bool, use_fused: bool = False +) -> Tuple[bool, bool]: if torch.jit.is_scripting() or differentiable: return False, False @@ -214,7 +229,7 @@ def _get_scalar_dtype(is_fused=None): ) -def _get_capturable_supported_devices(supports_xla: bool = True) -> list[str]: +def _get_capturable_supported_devices(supports_xla: bool = True) -> List[str]: r"""Return the device type list that supports capturable optimizer.""" capturable_supported_devices = ["cuda", "xpu", "hpu"] if not torch.jit.is_scripting(): @@ -306,7 +321,7 @@ def register_optimizer_step_post_hook(hook: GlobalOptimizerPostHook) -> Removabl ParamsT: TypeAlias = Union[ - Iterable[torch.Tensor], Iterable[dict[str, Any]], Iterable[tuple[str, torch.Tensor]] + Iterable[torch.Tensor], Iterable[Dict[str, Any]], Iterable[Tuple[str, torch.Tensor]] ] R = TypeVar("R") @@ -328,17 +343,17 @@ class Optimizer: options (used when a parameter group doesn't specify them). 
""" - OptimizerPreHook: TypeAlias = Callable[[Self, Args, Kwargs], Optional[tuple[Args, Kwargs]]] # type: ignore[misc] + OptimizerPreHook: TypeAlias = Callable[[Self, Args, Kwargs], Optional[Tuple[Args, Kwargs]]] # type: ignore[misc] OptimizerPostHook: TypeAlias = Callable[[Self, Args, Kwargs], None] # type: ignore[misc] - _optimizer_step_pre_hooks: dict[int, OptimizerPreHook] - _optimizer_step_post_hooks: dict[int, OptimizerPostHook] + _optimizer_step_pre_hooks: Dict[int, OptimizerPreHook] + _optimizer_step_post_hooks: Dict[int, OptimizerPostHook] _optimizer_state_dict_pre_hooks: 'OrderedDict[int, Callable[["Optimizer"], None]]' _optimizer_state_dict_post_hooks: 'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]' _optimizer_load_state_dict_pre_hooks: 'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]' _optimizer_load_state_dict_post_hooks: 'OrderedDict[int, Callable[["Optimizer"], None]]' - def __init__(self, params: ParamsT, defaults: dict[str, Any]) -> None: # noqa: D107 + def __init__(self, params: ParamsT, defaults: Dict[str, Any]) -> None: # noqa: D107 torch._C._log_api_usage_once("python.optimizer") self.defaults = defaults self._optimizer_step_pre_hooks = OrderedDict() @@ -356,8 +371,8 @@ class Optimizer: "an iterable of Tensors or dicts, but got " + torch.typename(params) ) - self.state: defaultdict[torch.Tensor, Any] = defaultdict(dict) - self.param_groups: list[dict[str, Any]] = [] + self.state: DefaultDict[torch.Tensor, Any] = defaultdict(dict) + self.param_groups: List[Dict[str, Any]] = [] param_groups = list(params) if len(param_groups) == 0: @@ -373,14 +388,14 @@ class Optimizer: # https://github.com/pytorch/pytorch/issues/72948 self._warned_capturable_if_run_uncaptured = True - def __getstate__(self) -> dict[str, Any]: # noqa: D105 + def __getstate__(self) -> Dict[str, Any]: # noqa: D105 return { "defaults": self.defaults, "state": self.state, "param_groups": self.param_groups, } - def __setstate__(self, state: dict[str, Any]) -> None: # noqa: D105 + def __setstate__(self, state: Dict[str, Any]) -> None: # noqa: D105 self.__dict__.update(state) if "_optimizer_step_pre_hooks" not in self.__dict__: self._optimizer_step_pre_hooks = OrderedDict() @@ -501,8 +516,8 @@ class Optimizer: tensorlistlist: TensorListList, with_indices: bool = False, ) -> Union[ - dict[tuple[None, None], tuple[TensorListList, Indices]], - dict[tuple[torch.device, torch.dtype], tuple[TensorListList, Indices]], + Dict[Tuple[None, None], Tuple[TensorListList, Indices]], + Dict[Tuple[torch.device, torch.dtype], Tuple[TensorListList, Indices]], ]: """Group a list of lists of tensors by device and dtype. @@ -690,10 +705,10 @@ class Optimizer: pre_hook(self) # Save order indices instead of Tensors - param_mappings: dict[int, int] = {} + param_mappings: Dict[int, int] = {} start_index = 0 - def pack_group(group: dict[str, Any]) -> dict[str, Any]: + def pack_group(group: Dict[str, Any]) -> Dict[str, Any]: nonlocal start_index packed = {k: v for k, v in group.items() if k != "params"} param_mappings.update( @@ -730,7 +745,7 @@ class Optimizer: param: torch.Tensor, value: torch.Tensor, param_id: int, - param_groups: list[dict[Any, Any]], + param_groups: List[Dict[Any, Any]], key: Hashable = None, ) -> torch.Tensor: # Floating-point types are a bit special here. They are the only ones @@ -903,7 +918,7 @@ class Optimizer: # Copy state assigned to params (and cast tensors to appropriate types). 
# State that is not assigned to params is copied as is (needed for # backward compatibility). - state: defaultdict[torch.Tensor, dict[Any, Any]] = defaultdict(dict) + state: DefaultDict[torch.Tensor, Dict[Any, Any]] = defaultdict(dict) for k, v in state_dict["state"].items(): if k in id_map: param = id_map[k] @@ -915,8 +930,8 @@ class Optimizer: # Update parameter groups, setting their 'params' value def update_group( - group: dict[str, Any], new_group: dict[str, Any] - ) -> dict[str, Any]: + group: Dict[str, Any], new_group: Dict[str, Any] + ) -> Dict[str, Any]: new_group["params"] = group["params"] if "param_names" in group and "param_names" not in new_group: new_group["param_names"] = group["param_names"] @@ -952,7 +967,7 @@ class Optimizer: self._patch_step_function() per_device_and_dtype_grads: Optional[ - defaultdict[torch.device, defaultdict[torch.dtype, list[torch.Tensor]]] + DefaultDict[torch.device, DefaultDict[torch.dtype, List[torch.Tensor]]] ] if foreach: per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) @@ -1001,7 +1016,7 @@ class Optimizer: raise NotImplementedError @torch._disable_dynamo - def add_param_group(self, param_group: dict[str, Any]) -> None: + def add_param_group(self, param_group: Dict[str, Any]) -> None: r"""Add a param group to the :class:`Optimizer` s `param_groups`. This can be useful when fine tuning a pre-trained network as frozen layers can be made @@ -1072,7 +1087,7 @@ class Optimizer: stacklevel=3, ) - param_set: set[torch.Tensor] = set() + param_set: Set[torch.Tensor] = set() for group in self.param_groups: param_set.update(set(group["params"])) if ("param_names" in param_group) != ("param_names" in group): diff --git a/torch/optim/rmsprop.py b/torch/optim/rmsprop.py index 21c06721165f..399cddfdb761 100644 --- a/torch/optim/rmsprop.py +++ b/torch/optim/rmsprop.py @@ -1,6 +1,6 @@ # mypy: allow-untyped-defs r"""Implementation for the RMSprop algorithm.""" -from typing import cast, Optional, Union +from typing import cast, List, Optional, Union import torch from torch import Tensor @@ -155,12 +155,12 @@ class RMSprop(Optimizer): # noqa: D101 loss = closure() for group in self.param_groups: - params_with_grad: list[Tensor] = [] - grads: list[Tensor] = [] - square_avgs: list[Tensor] = [] - grad_avgs: list[Tensor] = [] - momentum_buffer_list: list[Tensor] = [] - state_steps: list[Tensor] = [] + params_with_grad: List[Tensor] = [] + grads: List[Tensor] = [] + square_avgs: List[Tensor] = [] + grad_avgs: List[Tensor] = [] + momentum_buffer_list: List[Tensor] = [] + state_steps: List[Tensor] = [] has_complex = self._init_group( group, @@ -261,12 +261,12 @@ RMSprop.__doc__ = ( def _single_tensor_rmsprop( - params: list[Tensor], - grads: list[Tensor], - square_avgs: list[Tensor], - grad_avgs: list[Tensor], - momentum_buffer_list: list[Tensor], - state_steps: list[Tensor], + params: List[Tensor], + grads: List[Tensor], + square_avgs: List[Tensor], + grad_avgs: List[Tensor], + momentum_buffer_list: List[Tensor], + state_steps: List[Tensor], *, lr: float, alpha: float, @@ -332,12 +332,12 @@ def _single_tensor_rmsprop( def _multi_tensor_rmsprop( - params: list[Tensor], - grads: list[Tensor], - square_avgs: list[Tensor], - grad_avgs: list[Tensor], - momentum_buffer_list: list[Tensor], - state_steps: list[Tensor], + params: List[Tensor], + grads: List[Tensor], + square_avgs: List[Tensor], + grad_avgs: List[Tensor], + momentum_buffer_list: List[Tensor], + state_steps: List[Tensor], *, lr: float, alpha: float, @@ -377,20 +377,20 @@ def 
_multi_tensor_rmsprop( grouped_state_steps_, ) ), _ in grouped_tensors.values(): - grouped_params = cast(list[Tensor], grouped_params_) - grouped_grads = cast(list[Tensor], grouped_grads_) - grouped_square_avgs = cast(list[Tensor], grouped_square_avgs_) - grouped_state_steps = cast(list[Tensor], grouped_state_steps_) + grouped_params = cast(List[Tensor], grouped_params_) + grouped_grads = cast(List[Tensor], grouped_grads_) + grouped_square_avgs = cast(List[Tensor], grouped_square_avgs_) + grouped_state_steps = cast(List[Tensor], grouped_state_steps_) if has_complex: state_and_grads = [grouped_grads, grouped_square_avgs] if momentum > 0: grouped_momentum_buffer_list = cast( - list[Tensor], grouped_momentum_buffer_list_ + List[Tensor], grouped_momentum_buffer_list_ ) state_and_grads.append(grouped_momentum_buffer_list) if centered: - grouped_grad_avgs = cast(list[Tensor], grouped_grad_avgs_) + grouped_grad_avgs = cast(List[Tensor], grouped_grad_avgs_) state_and_grads.append(grouped_grad_avgs) _view_as_real(grouped_params, *state_and_grads) @@ -423,7 +423,7 @@ def _multi_tensor_rmsprop( ) if centered: - grouped_grad_avgs = cast(list[Tensor], grouped_grad_avgs_) + grouped_grad_avgs = cast(List[Tensor], grouped_grad_avgs_) torch._foreach_lerp_(grouped_grad_avgs, grouped_grads, 1 - alpha) avg = torch._foreach_addcmul( grouped_square_avgs, grouped_grad_avgs, grouped_grad_avgs, value=-1 @@ -436,7 +436,7 @@ def _multi_tensor_rmsprop( if momentum > 0: grouped_momentum_buffer_list = cast( - list[Tensor], grouped_momentum_buffer_list_ + List[Tensor], grouped_momentum_buffer_list_ ) torch._foreach_mul_(grouped_momentum_buffer_list, momentum) torch._foreach_addcdiv_(grouped_momentum_buffer_list, grouped_grads, avg) @@ -461,12 +461,12 @@ def _multi_tensor_rmsprop( @_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_rmsprop) def rmsprop( - params: list[Tensor], - grads: list[Tensor], - square_avgs: list[Tensor], - grad_avgs: list[Tensor], - momentum_buffer_list: list[Tensor], - state_steps: list[Tensor], + params: List[Tensor], + grads: List[Tensor], + square_avgs: List[Tensor], + grad_avgs: List[Tensor], + momentum_buffer_list: List[Tensor], + state_steps: List[Tensor], # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 # setting this as kwarg for now as functional API is compiled by torch/distributed/optim foreach: Optional[bool] = None, diff --git a/torch/optim/rprop.py b/torch/optim/rprop.py index 69f489fc9458..d01b879738a7 100644 --- a/torch/optim/rprop.py +++ b/torch/optim/rprop.py @@ -1,6 +1,6 @@ # mypy: allow-untyped-defs r"""Implementation for the Resilient backpropagation.""" -from typing import cast, Optional, Union +from typing import cast, List, Optional, Tuple, Union import torch from torch import Tensor @@ -30,8 +30,8 @@ class Rprop(Optimizer): # noqa: D101 self, params: ParamsT, lr: Union[float, Tensor] = 1e-2, - etas: tuple[float, float] = (0.5, 1.2), - step_sizes: tuple[float, float] = (1e-6, 50), + etas: Tuple[float, float] = (0.5, 1.2), + step_sizes: Tuple[float, float] = (1e-6, 50), *, capturable: bool = False, foreach: Optional[bool] = None, @@ -129,11 +129,11 @@ class Rprop(Optimizer): # noqa: D101 loss = closure() for group in self.param_groups: - params: list[Tensor] = [] - grads: list[Tensor] = [] - prevs: list[Tensor] = [] - step_sizes: list[Tensor] = [] - state_steps: list[Tensor] = [] + params: List[Tensor] = [] + grads: List[Tensor] = [] + prevs: List[Tensor] = [] + step_sizes: List[Tensor] = [] + state_steps: 
List[Tensor] = [] etaminus, etaplus = group["etas"] step_size_min, step_size_max = group["step_sizes"] @@ -219,11 +219,11 @@ Rprop.__doc__ = ( def _single_tensor_rprop( - params: list[Tensor], - grads: list[Tensor], - prevs: list[Tensor], - step_sizes: list[Tensor], - state_steps: list[Tensor], + params: List[Tensor], + grads: List[Tensor], + prevs: List[Tensor], + step_sizes: List[Tensor], + state_steps: List[Tensor], *, step_size_min: float, step_size_max: float, @@ -287,11 +287,11 @@ def _single_tensor_rprop( def _multi_tensor_rprop( - params: list[Tensor], - grads: list[Tensor], - prevs: list[Tensor], - step_sizes: list[Tensor], - state_steps: list[Tensor], + params: List[Tensor], + grads: List[Tensor], + prevs: List[Tensor], + step_sizes: List[Tensor], + state_steps: List[Tensor], *, step_size_min: float, step_size_max: float, @@ -326,11 +326,11 @@ def _multi_tensor_rprop( grouped_step_sizes_, grouped_state_steps_, ), _ in grouped_tensors.values(): - grouped_params = cast(list[Tensor], grouped_params_) - grouped_grads = cast(list[Tensor], grouped_grads_) - grouped_prevs = cast(list[Tensor], grouped_prevs_) - grouped_step_sizes = cast(list[Tensor], grouped_step_sizes_) - grouped_state_steps = cast(list[Tensor], grouped_state_steps_) + grouped_params = cast(List[Tensor], grouped_params_) + grouped_grads = cast(List[Tensor], grouped_grads_) + grouped_prevs = cast(List[Tensor], grouped_prevs_) + grouped_step_sizes = cast(List[Tensor], grouped_step_sizes_) + grouped_state_steps = cast(List[Tensor], grouped_state_steps_) # Update steps # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over @@ -402,11 +402,11 @@ def _multi_tensor_rprop( @_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_rprop) def rprop( - params: list[Tensor], - grads: list[Tensor], - prevs: list[Tensor], - step_sizes: list[Tensor], - state_steps: list[Tensor], + params: List[Tensor], + grads: List[Tensor], + prevs: List[Tensor], + step_sizes: List[Tensor], + state_steps: List[Tensor], # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 # setting this as kwarg for now as functional API is compiled by torch/distributed/optim foreach: Optional[bool] = None, diff --git a/torch/optim/sgd.py b/torch/optim/sgd.py index 7e9a964c2f21..440b09f0f257 100644 --- a/torch/optim/sgd.py +++ b/torch/optim/sgd.py @@ -1,6 +1,6 @@ # mypy: allow-untyped-defs r"""Implementation for Stochastic Gradient Descent optimizer.""" -from typing import cast, Optional, Union +from typing import cast, List, Optional, Union import torch from torch import Tensor @@ -114,9 +114,9 @@ class SGD(Optimizer): # noqa: D101 loss = closure() for group in self.param_groups: - params: list[Tensor] = [] - grads: list[Tensor] = [] - momentum_buffer_list: list[Optional[Tensor]] = [] + params: List[Tensor] = [] + grads: List[Tensor] = [] + momentum_buffer_list: List[Optional[Tensor]] = [] has_sparse_grad = self._init_group( group, params, grads, momentum_buffer_list @@ -244,9 +244,9 @@ SGD.__doc__ = ( def sgd( - params: list[Tensor], - d_p_list: list[Tensor], - momentum_buffer_list: list[Optional[Tensor]], + params: List[Tensor], + d_p_list: List[Tensor], + momentum_buffer_list: List[Optional[Tensor]], # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 # setting this as kwarg for now as functional API is compiled by torch/distributed/optim has_sparse_grad: bool = False, @@ -314,9 +314,9 @@ def sgd( def 
_single_tensor_sgd( - params: list[Tensor], - grads: list[Tensor], - momentum_buffer_list: list[Optional[Tensor]], + params: List[Tensor], + grads: List[Tensor], + momentum_buffer_list: List[Optional[Tensor]], grad_scale: Optional[Tensor], found_inf: Optional[Tensor], *, @@ -369,9 +369,9 @@ def _single_tensor_sgd( def _multi_tensor_sgd( - params: list[Tensor], - grads: list[Tensor], - momentum_buffer_list: list[Optional[Tensor]], + params: List[Tensor], + grads: List[Tensor], + momentum_buffer_list: List[Optional[Tensor]], grad_scale: Optional[Tensor], found_inf: Optional[Tensor], *, @@ -397,8 +397,8 @@ def _multi_tensor_sgd( device_grads_, device_momentum_buffer_list, ), indices in grouped_tensors.values(): - device_params: list[Tensor] = cast(list[Tensor], device_params_) - device_grads: list[Tensor] = cast(list[Tensor], device_grads_) + device_params: List[Tensor] = cast(List[Tensor], device_params_) + device_grads: List[Tensor] = cast(List[Tensor], device_grads_) device_has_sparse_grad = has_sparse_grad and any( grad.is_sparse for grad in device_grads @@ -417,7 +417,7 @@ def _multi_tensor_sgd( ) if momentum != 0: - bufs: list[Tensor] = [] + bufs: List[Tensor] = [] all_states_with_momentum_buffer = True for i in range(len(device_momentum_buffer_list)): @@ -462,9 +462,9 @@ def _multi_tensor_sgd( def _fused_sgd( - params: list[Tensor], - grads: list[Tensor], - momentum_buffer_list: list[Optional[Tensor]], + params: List[Tensor], + grads: List[Tensor], + momentum_buffer_list: List[Optional[Tensor]], grad_scale: Optional[Tensor], found_inf: Optional[Tensor], *, @@ -501,8 +501,8 @@ def _fused_sgd( (device_params_, device_grads_, device_momentum_buffer_list), _, ) in grouped_tensors.items(): - device_params: list[Tensor] = cast(list[Tensor], device_params_) - device_grads: list[Tensor] = cast(list[Tensor], device_grads_) + device_params: List[Tensor] = cast(List[Tensor], device_params_) + device_grads: List[Tensor] = cast(List[Tensor], device_grads_) device_grad_scale, device_found_inf = None, None if grad_scale is not None: device_grad_scale = grad_scale_dict.setdefault( @@ -515,7 +515,7 @@ def _fused_sgd( device_grads, [] if no_momentum_buffer - else cast(list[Tensor], device_momentum_buffer_list), + else cast(List[Tensor], device_momentum_buffer_list), weight_decay=weight_decay, momentum=momentum, lr=lr, diff --git a/torch/optim/sparse_adam.py b/torch/optim/sparse_adam.py index 09814a9746c0..23ac70678e2e 100644 --- a/torch/optim/sparse_adam.py +++ b/torch/optim/sparse_adam.py @@ -1,5 +1,5 @@ # mypy: allow-untyped-defs -from typing import Union +from typing import List, Tuple, Union import torch from torch import Tensor @@ -16,7 +16,7 @@ class SparseAdam(Optimizer): self, params: ParamsT, lr: Union[float, Tensor] = 1e-3, - betas: tuple[float, float] = (0.9, 0.999), + betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, maximize: bool = False, ): @@ -69,11 +69,11 @@ class SparseAdam(Optimizer): loss = closure() for group in self.param_groups: - params_with_grad: list[Tensor] = [] - grads: list[Tensor] = [] - exp_avgs: list[Tensor] = [] - exp_avg_sqs: list[Tensor] = [] - state_steps: list[int] = [] + params_with_grad: List[Tensor] = [] + grads: List[Tensor] = [] + exp_avgs: List[Tensor] = [] + exp_avg_sqs: List[Tensor] = [] + state_steps: List[int] = [] beta1, beta2 = group["betas"] maximize = group.get("maximize", False) diff --git a/torch/optim/swa_utils.py b/torch/optim/swa_utils.py index fffd9462dd22..d568283d2d8f 100644 --- a/torch/optim/swa_utils.py +++ 
b/torch/optim/swa_utils.py @@ -3,9 +3,8 @@ r"""Implementation for Stochastic Weight Averaging implementation.""" import itertools import math import warnings -from collections.abc import Iterable from copy import deepcopy -from typing import Any, Callable, Literal, Optional, Union +from typing import Any, Callable, Iterable, List, Literal, Optional, Tuple, Union import torch from torch import Tensor @@ -29,7 +28,7 @@ __all__ = [ from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype -PARAM_LIST = Union[tuple[Tensor, ...], list[Tensor]] +PARAM_LIST = Union[Tuple[Tensor, ...], List[Tensor]] def get_ema_multi_avg_fn(decay=0.999): @@ -254,8 +253,8 @@ class AveragedModel(Module): if self.use_buffers else model.parameters() ) - self_param_detached: list[Optional[Tensor]] = [] - model_param_detached: list[Optional[Tensor]] = [] + self_param_detached: List[Optional[Tensor]] = [] + model_param_detached: List[Optional[Tensor]] = [] for p_averaged, p_model in zip(self_param, model_param): p_model_ = p_model.detach().to(p_averaged.device) self_param_detached.append(p_averaged.detach()) diff --git a/torch/package/_digraph.py b/torch/package/_digraph.py index b98b49b507a3..8b753f7ebdc4 100644 --- a/torch/package/_digraph.py +++ b/torch/package/_digraph.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs from collections import deque +from typing import List, Set class DiGraph: @@ -89,7 +90,7 @@ class DiGraph: except TypeError: return False - def forward_transitive_closure(self, src: str) -> set[str]: + def forward_transitive_closure(self, src: str) -> Set[str]: """Returns a set of nodes that are reachable from src""" result = set(src) @@ -102,7 +103,7 @@ class DiGraph: working_set.append(n) return result - def backward_transitive_closure(self, src: str) -> set[str]: + def backward_transitive_closure(self, src: str) -> Set[str]: """Returns a set of nodes that are reachable from src in reverse direction""" result = set(src) @@ -139,7 +140,7 @@ class DiGraph: return result_graph.to_dot() - def first_path(self, dst: str) -> list[str]: + def first_path(self, dst: str) -> List[str]: """Returns a list of nodes that show the first path that resulted in dst being added to the graph.""" path = [] diff --git a/torch/package/analyze/find_first_use_of_broken_modules.py b/torch/package/analyze/find_first_use_of_broken_modules.py index 728f3289b5cd..b3016a56c2a4 100644 --- a/torch/package/analyze/find_first_use_of_broken_modules.py +++ b/torch/package/analyze/find_first_use_of_broken_modules.py @@ -1,10 +1,12 @@ +from typing import Dict, List + from torch.package.package_exporter import PackagingError __all__ = ["find_first_use_of_broken_modules"] -def find_first_use_of_broken_modules(exc: PackagingError) -> dict[str, list[str]]: +def find_first_use_of_broken_modules(exc: PackagingError) -> Dict[str, List[str]]: """ Find all broken modules in a PackagingError, and for each one, return the dependency path in which the module was first encountered. 
diff --git a/torch/package/analyze/trace_dependencies.py b/torch/package/analyze/trace_dependencies.py index e029fe130bdd..23f6c998385b 100644 --- a/torch/package/analyze/trace_dependencies.py +++ b/torch/package/analyze/trace_dependencies.py @@ -1,15 +1,14 @@ # mypy: allow-untyped-defs import sys -from collections.abc import Iterable -from typing import Any, Callable +from typing import Any, Callable, Iterable, List, Tuple __all__ = ["trace_dependencies"] def trace_dependencies( - callable: Callable[[Any], Any], inputs: Iterable[tuple[Any, ...]] -) -> list[str]: + callable: Callable[[Any], Any], inputs: Iterable[Tuple[Any, ...]] +) -> List[str]: """Trace the execution of a callable in order to determine which modules it uses. Args: diff --git a/torch/package/file_structure_representation.py b/torch/package/file_structure_representation.py index 8ef00e0159d8..e1137234ab73 100644 --- a/torch/package/file_structure_representation.py +++ b/torch/package/file_structure_representation.py @@ -1,4 +1,5 @@ # mypy: allow-untyped-defs +from typing import Dict, List from .glob_group import GlobGroup, GlobPattern @@ -14,9 +15,9 @@ class Directory: def __init__(self, name: str, is_dir: bool): self.name = name self.is_dir = is_dir - self.children: dict[str, Directory] = {} + self.children: Dict[str, Directory] = {} - def _get_dir(self, dirs: list[str]) -> "Directory": + def _get_dir(self, dirs: List[str]) -> "Directory": """Builds path of Directories if not yet built and returns last directory in list. @@ -63,13 +64,13 @@ class Directory: return False def __str__(self): - str_list: list[str] = [] + str_list: List[str] = [] self._stringify_tree(str_list) return "".join(str_list) def _stringify_tree( self, - str_list: list[str], + str_list: List[str], preamble: str = "", dir_ptr: str = "\u2500\u2500\u2500 ", ): @@ -88,8 +89,8 @@ class Directory: else: preamble = preamble + space - file_keys: list[str] = [] - dir_keys: list[str] = [] + file_keys: List[str] = [] + dir_keys: List[str] = [] for key, val in self.children.items(): if val.is_dir: dir_keys.append(key) @@ -108,7 +109,7 @@ class Directory: def _create_directory_from_file_list( filename: str, - file_list: list[str], + file_list: List[str], include: "GlobPattern" = "**", exclude: "GlobPattern" = (), ) -> Directory: diff --git a/torch/package/find_file_dependencies.py b/torch/package/find_file_dependencies.py index 216af0d6aebe..7f2386b3ce50 100644 --- a/torch/package/find_file_dependencies.py +++ b/torch/package/find_file_dependencies.py @@ -1,6 +1,6 @@ # mypy: allow-untyped-defs import ast -from typing import Optional +from typing import List, Optional, Tuple from ._importlib import _resolve_name @@ -11,7 +11,7 @@ class _ExtractModuleReferences(ast.NodeVisitor): """ @classmethod - def run(cls, src: str, package: str) -> list[tuple[str, Optional[str]]]: + def run(cls, src: str, package: str) -> List[Tuple[str, Optional[str]]]: visitor = cls(package) tree = ast.parse(src) visitor.visit(tree) @@ -53,7 +53,7 @@ class _ExtractModuleReferences(ast.NodeVisitor): if hasattr(node.func, "id") and node.func.id == "__import__": try: name = self._grab_node_str(node.args[0]) - fromlist: list[str] = [] + fromlist: List[str] = [] level = 0 if len(node.args) > 3: fromlist.extend(self._grab_node_str(v) for v in node.args[3].elts) diff --git a/torch/package/glob_group.py b/torch/package/glob_group.py index 986938cd256e..1c1d31930fd1 100644 --- a/torch/package/glob_group.py +++ b/torch/package/glob_group.py @@ -1,7 +1,6 @@ # mypy: allow-untyped-defs import re -from 
collections.abc import Iterable -from typing import Union +from typing import Iterable, Union GlobPattern = Union[str, Iterable[str]] diff --git a/torch/package/importer.py b/torch/package/importer.py index 49b4512f79a6..2fb2891e076c 100644 --- a/torch/package/importer.py +++ b/torch/package/importer.py @@ -7,7 +7,7 @@ from pickle import ( # type: ignore[attr-defined] whichmodule as _pickle_whichmodule, ) from types import ModuleType -from typing import Any, Optional +from typing import Any, Dict, List, Optional, Tuple from ._mangling import demangle, get_mangle_prefix, is_mangled @@ -44,7 +44,7 @@ class Importer(ABC): assert obj1 is obj2 """ - modules: dict[str, ModuleType] + modules: Dict[str, ModuleType] @abstractmethod def import_module(self, module_name: str) -> ModuleType: @@ -53,7 +53,7 @@ class Importer(ABC): The contract is the same as for importlib.import_module. """ - def get_name(self, obj: Any, name: Optional[str] = None) -> tuple[str, str]: + def get_name(self, obj: Any, name: Optional[str] = None) -> Tuple[str, str]: """Given an object, return a name that can be used to retrieve the object from this environment. @@ -184,7 +184,7 @@ class OrderedImporter(Importer): """ def __init__(self, *args): - self._importers: list[Importer] = list(args) + self._importers: List[Importer] = list(args) def _is_torchpackage_dummy(self, module): """Returns true iff this module is an empty PackageNode in a torch.package. diff --git a/torch/package/package_exporter.py b/torch/package/package_exporter.py index 796936b1f3ed..2ece831fab00 100644 --- a/torch/package/package_exporter.py +++ b/torch/package/package_exporter.py @@ -7,12 +7,23 @@ import pickletools import platform import types from collections import defaultdict, OrderedDict -from collections.abc import Sequence from dataclasses import dataclass from enum import Enum from importlib.machinery import SourceFileLoader from pathlib import Path -from typing import Any, BinaryIO, Callable, cast, Optional, Union +from typing import ( + Any, + BinaryIO, + Callable, + cast, + DefaultDict, + Dict, + List, + Optional, + Sequence, + Set, + Union, +) import torch from torch.serialization import location_tag, normalize_storage_type @@ -122,7 +133,7 @@ class PackagingError(Exception): def __init__(self, dependency_graph: DiGraph, debug=False): # Group errors by reason. - broken: dict[PackagingErrorReason, list[str]] = defaultdict(list) + broken: Dict[PackagingErrorReason, List[str]] = defaultdict(list) for module_name, attrs in dependency_graph.nodes.items(): error = attrs.get("error") if error is None: @@ -225,9 +236,9 @@ class PackageExporter: self.zip_file = torch._C.PyTorchFileWriter(f) self.zip_file.set_min_version(6) - self._written_files: set[str] = set() + self._written_files: Set[str] = set() - self.serialized_reduces: dict[int, Any] = {} + self.serialized_reduces: Dict[int, Any] = {} # A graph tracking all the modules and pickle objects added to this # package and the dependencies between them. @@ -255,7 +266,7 @@ class PackageExporter: ) self.importer = OrderedImporter(*importer) - self.patterns: dict[GlobGroup, _PatternInfo] = {} + self.patterns: Dict[GlobGroup, _PatternInfo] = {} self._unique_id = 0 def save_source_file( @@ -320,7 +331,7 @@ class PackageExporter: def _get_dependencies( self, src: str, module_name: str, is_package: bool - ) -> list[str]: + ) -> List[str]: """Return all modules that this source code depends on. Dependencies are found by scanning the source code for import-like statements. 
@@ -648,7 +659,7 @@ class PackageExporter: all_dependencies = [] module = None field = None - memo: defaultdict[int, str] = defaultdict(None) + memo: DefaultDict[int, str] = defaultdict(None) memo_count = 0 # pickletools.dis(data_value) for opcode, arg, _pos in pickletools.genops(data_value): @@ -1104,7 +1115,7 @@ class PackageExporter: def _nodes_with_action_type( self, action: Optional[_ModuleProviderAction] - ) -> list[str]: + ) -> List[str]: result = [] for name, node_dict in self.dependency_graph.nodes.items(): node_action = node_dict.get("action", None) @@ -1113,7 +1124,7 @@ class PackageExporter: result.sort() return result - def externed_modules(self) -> list[str]: + def externed_modules(self) -> List[str]: """Return all modules that are currently externed. Returns: @@ -1122,7 +1133,7 @@ class PackageExporter: """ return self._nodes_with_action_type(_ModuleProviderAction.EXTERN) - def interned_modules(self) -> list[str]: + def interned_modules(self) -> List[str]: """Return all modules that are currently interned. Returns: @@ -1131,7 +1142,7 @@ class PackageExporter: """ return self._nodes_with_action_type(_ModuleProviderAction.INTERN) - def mocked_modules(self) -> list[str]: + def mocked_modules(self) -> List[str]: """Return all modules that are currently mocked. Returns: @@ -1140,7 +1151,7 @@ class PackageExporter: """ return self._nodes_with_action_type(_ModuleProviderAction.MOCK) - def denied_modules(self) -> list[str]: + def denied_modules(self) -> List[str]: """Return all modules that are currently denied. Returns: @@ -1149,7 +1160,7 @@ class PackageExporter: """ return self._nodes_with_action_type(_ModuleProviderAction.DENY) - def get_rdeps(self, module_name: str) -> list[str]: + def get_rdeps(self, module_name: str) -> List[str]: """Return a list of all modules which depend on the module ``module_name``. Returns: diff --git a/torch/package/package_importer.py b/torch/package/package_importer.py index 971b5398ec63..a9a577c4fcba 100644 --- a/torch/package/package_importer.py +++ b/torch/package/package_importer.py @@ -8,9 +8,19 @@ import linecache import os import sys import types -from collections.abc import Iterable from contextlib import contextmanager -from typing import Any, BinaryIO, Callable, cast, Optional, TYPE_CHECKING, Union +from typing import ( + Any, + BinaryIO, + Callable, + cast, + Dict, + Iterable, + List, + Optional, + TYPE_CHECKING, + Union, +) from weakref import WeakValueDictionary import torch @@ -54,7 +64,7 @@ IMPLICIT_IMPORT_ALLOWLIST: Iterable[str] = [ # The primary motivation is to enable Numpy upgrade that many modules # depend on. The latest release of Numpy removed `numpy.str` and # `numpy.bool` breaking unpickling for many modules. -EXTERN_IMPORT_COMPAT_NAME_MAPPING: dict[str, dict[str, Any]] = { +EXTERN_IMPORT_COMPAT_NAME_MAPPING: Dict[str, Dict[str, Any]] = { "numpy": { "str": str, "bool": bool, @@ -80,7 +90,7 @@ class PackageImporter(Importer): local to this importer. 
""" - modules: dict[str, types.ModuleType] + modules: Dict[str, types.ModuleType] def __init__( self, @@ -636,7 +646,7 @@ class PackageImporter(Importer): return f"{name.replace('.', '/')}" def _get_or_create_package( - self, atoms: list[str] + self, atoms: List[str] ) -> "Union[_PackageNode, _ExternNode]": cur = self.root for i, atom in enumerate(atoms): @@ -695,7 +705,7 @@ class _PathNode: class _PackageNode(_PathNode): def __init__(self, source_file: Optional[str]): self.source_file = source_file - self.children: dict[str, _PathNode] = {} + self.children: Dict[str, _PathNode] = {} class _ModuleNode(_PathNode): diff --git a/torch/profiler/_memory_profiler.py b/torch/profiler/_memory_profiler.py index 72c13c87a25a..8c72300ce0e4 100644 --- a/torch/profiler/_memory_profiler.py +++ b/torch/profiler/_memory_profiler.py @@ -4,8 +4,7 @@ import dataclasses import enum import itertools as it import logging -from collections.abc import Iterator -from typing import Any, cast, Optional, Union +from typing import Any, cast, DefaultDict, Dict, Iterator, List, Optional, Set, Union from typing_extensions import Literal import torch @@ -227,7 +226,7 @@ class SchemaMatcher: overload. If we cannot find any valid schema then we must be conservative and assume all inputs are mutable. """ - mutable: Optional[list[bool]] = None + mutable: Optional[List[bool]] = None for schema in cls.match_schemas(t): mutable = mutable or [False for _ in schema.arguments] for i, arg in enumerate(schema.arguments): @@ -328,7 +327,7 @@ class OpTree: class SizeMap: def __init__(self, op_tree: OpTree) -> None: - self._values: dict[TensorKey, int] = {} + self._values: Dict[TensorKey, int] = {} for node in op_tree.sorted_nodes: if node.typed[0] == _EventType.TorchOp: @@ -350,7 +349,7 @@ class SizeMap: for _, t in state: self._update_values(t) - allocations: dict[TensorKey, int] = {} + allocations: Dict[TensorKey, int] = {} for node in op_tree.sorted_nodes: if node.typed[0] == _EventType.Allocation: alloc_fields = node.typed[1] @@ -411,7 +410,7 @@ class DataFlowNode: def __init__(self, event: _ProfilerEvent, graph: "DataFlowGraph") -> None: self._event = event self._graph = graph - self._edges: dict[TensorKey, DataFlowEdge] = self._determine_edges() + self._edges: Dict[TensorKey, DataFlowEdge] = self._determine_edges() for key, edge in self._edges.items(): if edge.mutated and not edge.is_allocation: @@ -421,11 +420,11 @@ class DataFlowNode: versions = {k: (v, self._graph.lookup(k)) for k, v in self.outputs.items()} assert all(i == j for i, j in versions.values()), f"{versions}, {self._edges}" - def _determine_edges(self) -> dict[TensorKey, DataFlowEdge]: + def _determine_edges(self) -> Dict[TensorKey, DataFlowEdge]: subtree = tuple(_utils.traverse_dfs([self._event])) # Start by populating edges from op inputs and outputs. 
- mutable_by_key: dict[Optional[TensorKey], set[Optional[bool]]] = {} + mutable_by_key: Dict[Optional[TensorKey], Set[Optional[bool]]] = {} for op in (i.typed[1] for i in subtree if i.typed[0] == _EventType.TorchOp): for op_input, mutable in zip( op.inputs, SchemaMatcher.inputs_are_mutable(op) @@ -441,7 +440,7 @@ class DataFlowNode: key = TensorKey.from_tensor(op_input_i) mutable_by_key.setdefault(key, set()).add(mutable) - edges: collections.defaultdict[Optional[TensorKey], DataFlowEdge] + edges: DefaultDict[Optional[TensorKey], DataFlowEdge] edges = collections.defaultdict(DataFlowEdge) for key, mutable_set in mutable_by_key.items(): if key is not None: @@ -473,7 +472,7 @@ class DataFlowNode: return dict(sorted((k, v) for k, v in edges.items() if k is not None)) @property - def inputs(self) -> dict[TensorKey, tuple[bool, int]]: + def inputs(self) -> Dict[TensorKey, tuple[bool, int]]: return { # MyPy can't see through `is_allocation` to know that # `v.input_version` is not None. @@ -483,7 +482,7 @@ class DataFlowNode: } @property - def outputs(self) -> dict[TensorKey, int]: + def outputs(self) -> Dict[TensorKey, int]: return { k: 0 if v.input_version is None else v.input_version + 1 for k, v in self._edges.items() @@ -505,7 +504,7 @@ class DataFlowGraph: def __init__(self, op_tree: OpTree) -> None: self._op_tree = op_tree self._leaf_events = self._extract_leaf_events(op_tree) - self._active_version: dict[TensorKey, Optional[int]] = {} + self._active_version: Dict[TensorKey, Optional[int]] = {} self._flow_nodes = [DataFlowNode(e, self) for e in self.leaf_events] self._flow_nodes.sort(key=lambda x: x.start_time) self.validate() @@ -516,7 +515,7 @@ class DataFlowGraph: def validate(self): # Check that each (Tensor, version) pair has a unique creation node - outputs: set[tuple[TensorKey, int]] = set() + outputs: Set[tuple[TensorKey, int]] = set() for node in self.flow_nodes: node_outputs = set(node.outputs.items()) duplicates = outputs & node_outputs @@ -524,7 +523,7 @@ class DataFlowGraph: outputs |= node_outputs # And check that `self._nodes` forms a valid topologically sorted DAG. - tensor_versions: dict[TensorKey, int] = {} + tensor_versions: Dict[TensorKey, int] = {} for node in self.flow_nodes: for key, (_, version) in node.inputs.items(): expected = tensor_versions.get(key, 0) @@ -572,7 +571,7 @@ class DataFlowGraph: form the logical nodes in our data flow graph. """ - leaf_events: list[_ProfilerEvent] = [] + leaf_events: List[_ProfilerEvent] = [] def leaf_op(e: _ProfilerEvent) -> bool: return e.typed[0] == _EventType.TorchOp and ( @@ -610,17 +609,17 @@ class DataFlowGraph: @dataclasses.dataclass class CategoryElement: by_id: Optional[Category] = None - by_key: dict[TensorKey, Category] = dataclasses.field(default_factory=dict) - by_version: dict[TensorAndID, Category] = dataclasses.field(default_factory=dict) + by_key: Dict[TensorKey, Category] = dataclasses.field(default_factory=dict) + by_version: Dict[TensorAndID, Category] = dataclasses.field(default_factory=dict) # Used by unit tests to check internals. (And consequently by # MemoryProfile.lookup) This should not be used in any other capacity. 
- _by_id_keyset: set[TensorKey] = dataclasses.field(default_factory=set) + _by_id_keyset: Set[TensorKey] = dataclasses.field(default_factory=set) @dataclasses.dataclass class CategoryDict: - _values: collections.defaultdict[int, CategoryElement] = dataclasses.field( + _values: DefaultDict[int, CategoryElement] = dataclasses.field( default_factory=lambda: collections.defaultdict(CategoryElement) ) @@ -667,9 +666,9 @@ class MemoryProfile: @property def timeline(self) -> tuple[tuple[int, Action, KeyAndID, int], ...]: - output: list[tuple[int, Action, KeyAndID, int]] = [] - allocation_times: dict[tuple[TensorKey, bool], int] = {} - live_unknown: dict[tuple[int, torch.device], Literal[True]] = {} + output: List[tuple[int, Action, KeyAndID, int]] = [] + allocation_times: Dict[tuple[TensorKey, bool], int] = {} + live_unknown: Dict[tuple[int, torch.device], Literal[True]] = {} for event in self._op_tree.dfs(): if event.typed[0] == _EventType.Allocation: alloc_fields = event.typed[1] @@ -702,7 +701,7 @@ class MemoryProfile: snapshot = self._category_snapshot() last_version = dict(sorted(snapshot.keys())) - events: list[tuple[int, Action, TensorAndID]] = [ + events: List[tuple[int, Action, TensorAndID]] = [ (-1, Action.PREEXISTING, (key, version)) for key, version in snapshot.keys() if (key, True) not in allocation_times and version == 0 @@ -735,8 +734,8 @@ class MemoryProfile: def _is_gradient(self, *args, **kwargs) -> bool: return self._categories.get(*args, **kwargs) == Category.GRADIENT - def _category_snapshot(self) -> dict[TensorAndID, Optional[Category]]: - all_tensor_versions: set[TensorAndID] = set() + def _category_snapshot(self) -> Dict[TensorAndID, Optional[Category]]: + all_tensor_versions: Set[TensorAndID] = set() for node in self._data_flow_graph.flow_nodes: all_tensor_versions.update(((k, v) for k, (_, v) in node.inputs.items())) @@ -751,7 +750,7 @@ class MemoryProfile: for key, version in sorted(all_tensor_versions) } - def _any_version_depends_on_gradient(self) -> set[int]: + def _any_version_depends_on_gradient(self) -> Set[int]: """Extract IDs of Tensors which depend or will depend on a gradient. Note that this weakened definition of "depends" requires us to loop @@ -762,7 +761,7 @@ class MemoryProfile: acyclic data flow graph into a cyclic graph and we are attempting to partition cycles involving a gradient from the rest of the graph. """ - depends_on_gradient: set[int] = set() + depends_on_gradient: Set[int] = set() while True: start_size = len(depends_on_gradient) for node in self._data_flow_graph.flow_nodes: @@ -838,7 +837,7 @@ class MemoryProfile: # We only want to annotate Tensors which actually contribute to the # model calculation. - produces_gradient: set[TensorAndID] = set() + produces_gradient: Set[TensorAndID] = set() for node in reversed(self._data_flow_graph.flow_nodes): tensors = {(key, version) for key, (_, version) in node.inputs.items()} tensors |= node.outputs.items() @@ -895,8 +894,8 @@ class MemoryProfile: # data flow. Note this these are only candidates; we filter nodes that # we know are part of the backward pass but that doesn't guarantee that # they are part of the forward pass. 
- candidate_parameters: set[TensorAndID] = set() - candidate_fwd_tensors: set[TensorAndID] = { + candidate_parameters: Set[TensorAndID] = set() + candidate_fwd_tensors: Set[TensorAndID] = { i for i, category in snapshot.items() if category == Category.INPUT } @@ -915,7 +914,7 @@ class MemoryProfile: candidate_parameters |= inputs.difference(candidate_fwd_tensors) # Require that each parameter eventually contributes to the value of a gradient - used_for_gradient: set[TensorAndID] = set() + used_for_gradient: Set[TensorAndID] = set() for node in reversed(self._data_flow_graph.flow_nodes): if any( self._is_gradient(*i) or i in used_for_gradient @@ -994,8 +993,8 @@ class MemoryProfileTimeline: Output: [timestamps, sizes by category] """ device = torch.device(device_str) - times: list[int] = [] - sizes: list[list[int]] = [] + times: List[int] = [] + sizes: List[List[int]] = [] def update(key, version, delta): category = ( @@ -1062,7 +1061,7 @@ class MemoryProfileTimeline: as a JSON formatted file to the given path for the given device.""" device = torch.device(device_str) - raw_events: list[tuple[int, int, int, int]] = [] + raw_events: List[tuple[int, int, int, int]] = [] def get_category_index(key, version): category = ( diff --git a/torch/profiler/_pattern_matcher.py b/torch/profiler/_pattern_matcher.py index 41748ea39545..6bac511dbbfc 100644 --- a/torch/profiler/_pattern_matcher.py +++ b/torch/profiler/_pattern_matcher.py @@ -3,7 +3,7 @@ import json import math import os import re -from typing import Optional +from typing import Dict, List, Optional, Set import torch import torch.utils.benchmark as benchmark @@ -34,7 +34,7 @@ class Pattern: self.url = "" assert prof.profiler is not None and prof.profiler.kineto_results is not None self.event_tree = prof.profiler.kineto_results.experimental_event_tree() - self.tid_root: dict[int, list[_ProfilerEvent]] = {} + self.tid_root: Dict[int, List[_ProfilerEvent]] = {} for event in self.event_tree: self.tid_root.setdefault(event.start_tid, []).append(event) @@ -55,7 +55,7 @@ class Pattern: """ yield from traverse_dfs(self.event_tree) - def summary(self, events: list[_ProfilerEvent]): + def summary(self, events: List[_ProfilerEvent]): default_summary = f"{self.name}: {len(events)} events matched." if self.should_benchmark: # If benchmark summary is not empty, use it. @@ -66,7 +66,7 @@ class Pattern: ) return default_summary - def benchmark_summary(self, events: list[_ProfilerEvent]): + def benchmark_summary(self, events: List[_ProfilerEvent]): def format_time(time_ns: int): unit_lst = ["ns", "us", "ms"] for unit in unit_lst: @@ -215,7 +215,7 @@ class ExtraCUDACopyPattern(Pattern): return event.name in self.init_ops # TODO: Check if tensor is reused - def benchmark(self, events: list[_ProfilerEvent]): + def benchmark(self, events: List[_ProfilerEvent]): shapes_factor_map = {input_shapes(event): 0.0 for event in events} for shape in shapes_factor_map: size = shape[0] @@ -252,7 +252,7 @@ class ForLoopIndexingPattern(Pattern): super().__init__(prof, should_benchmark) self.name = "For Loop Indexing Pattern" self.description = "For loop indexing detected. Vectorization recommended." 
- self.visited: set[int] = set() + self.visited: Set[int] = set() def eventTreeTraversal(self): """ @@ -326,7 +326,7 @@ class FP32MatMulPattern(Pattern): def report(self, event: _ProfilerEvent): return self.description - def benchmark(self, events: list[_ProfilerEvent]): + def benchmark(self, events: List[_ProfilerEvent]): shapes_factor_map = {input_shapes(event): 0.0 for event in events} for shape in shapes_factor_map: matrixA = torch.randn(shape[0], device="cuda", dtype=torch.float32) @@ -553,7 +553,7 @@ class MatMulDimInFP16Pattern(Pattern): return True return False - def benchmark(self, events: list[_ProfilerEvent]): + def benchmark(self, events: List[_ProfilerEvent]): def closest_multiple(shapes, multiple): return [multiple * math.ceil(shape / multiple) for shape in shapes] @@ -609,7 +609,7 @@ def report_all_anti_patterns( print_enable: bool = True, json_report_dir: Optional[str] = None, ): - report_dict: dict = {} + report_dict: Dict = {} anti_patterns = [ ExtraCUDACopyPattern(prof, should_benchmark), # ForLoopIndexingPattern(prof, should_benchmark), diff --git a/torch/profiler/_utils.py b/torch/profiler/_utils.py index af693aecdde1..283d31c87024 100644 --- a/torch/profiler/_utils.py +++ b/torch/profiler/_utils.py @@ -4,7 +4,7 @@ import operator import re from collections import deque from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import Dict, List, TYPE_CHECKING from torch.autograd.profiler import profile from torch.profiler import DeviceType @@ -64,7 +64,7 @@ class EventKey: def __repr__(self): return f"{self.event.name}" - def intervals_overlap(self, intervals: list[Interval]): + def intervals_overlap(self, intervals: List[Interval]): overlap_time = 0 intervals = sorted(intervals, key=lambda x: x.start) @@ -100,13 +100,13 @@ class EventKey: class BasicEvaluation: def __init__(self, prof: profile): self.profile = prof - self.metrics: dict[EventKey, EventMetrics] = {} + self.metrics: Dict[EventKey, EventMetrics] = {} self.compute_self_time() self.event_keys = sorted( (e for e in self.metrics.keys()), key=lambda x: x.event.start_time_ns ) self.events = [e.event for e in self.event_keys] - self.cuda_events: list[_KinetoEvent] = [] + self.cuda_events: List[_KinetoEvent] = [] self.queue_depth_list = self.compute_queue_depth() self.compute_idle_time() @@ -162,7 +162,7 @@ class BasicEvaluation: cuda_launch_events + cuda_kernel_events, key=lambda x: x.start_ns() ) - kernel_mapping: dict[_KinetoEvent, int] = {} + kernel_mapping: Dict[_KinetoEvent, int] = {} last_mapped_kernel = 0 for cuda_launch_event in cuda_launch_events: index = index_of_first_match( @@ -188,7 +188,7 @@ class BasicEvaluation: return event.start_time_ns raise Exception("Unknown Event Type") # noqa: TRY002 - queue_depth_list: list[Interval] = [] + queue_depth_list: List[Interval] = [] all_events.sort(key=new_old_event_comparator) for event in all_events: # Find latest cuda kernel event @@ -233,7 +233,7 @@ class BasicEvaluation: # Based on queue_depth_list, we can calculate idle time for all the events idle = False idle_start = 0 - idle_intervals: list[Interval] = [] + idle_intervals: List[Interval] = [] if self.queue_depth_list and self.events: idle_intervals += [ Interval(self.events[0].start_time_ns, self.queue_depth_list[0].start), diff --git a/torch/profiler/profiler.py b/torch/profiler/profiler.py index b328270d96ab..3a9c03bea018 100644 --- a/torch/profiler/profiler.py +++ b/torch/profiler/profiler.py @@ -5,10 +5,9 @@ import os import shutil import tempfile from abc import ABC, 
abstractmethod -from collections.abc import Iterable from enum import Enum from functools import partial -from typing import Any, Callable, Optional +from typing import Any, Callable, Dict, Iterable, List, Optional from typing_extensions import Self from warnings import warn @@ -168,7 +167,7 @@ class _KinetoProfile: self.use_device = _get_privateuse1_backend_name() # user-defined metadata to be amended to the trace - self.preset_metadata: dict[str, str] = {} + self.preset_metadata: Dict[str, str] = {} def start(self): self.prepare_trace() @@ -724,8 +723,8 @@ class profile(_KinetoProfile): self.current_action = self.schedule(self.step_num) self.step_rec_fn: Optional[prof.record_function] = None - self.action_map: dict[ - tuple[ProfilerAction, Optional[ProfilerAction]], list[Any] + self.action_map: Dict[ + tuple[ProfilerAction, Optional[ProfilerAction]], List[Any] ] = { # key is (prev_action, current_action), value is action list corresponding to the state pair. (ProfilerAction.NONE, ProfilerAction.NONE): [], diff --git a/torch/profiler/python_tracer.py b/torch/profiler/python_tracer.py index aff0fbc32ff3..b3e624911f95 100644 --- a/torch/profiler/python_tracer.py +++ b/torch/profiler/python_tracer.py @@ -1,11 +1,12 @@ import os import site import sys +import typing import torch -def _prefix_regex() -> list[str]: +def _prefix_regex() -> typing.List[str]: raw_paths = ( site.getsitepackages() + sys.path diff --git a/torch/serialization.py b/torch/serialization.py index c7046049f6b2..85333525ed79 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -15,7 +15,19 @@ import threading import warnings from contextlib import closing, contextmanager from enum import Enum -from typing import Any, BinaryIO, Callable, cast, IO, Optional, Union +from typing import ( + Any, + BinaryIO, + Callable, + cast, + Dict, + IO, + List, + Optional, + Tuple, + Type, + Union, +) from typing_extensions import TypeAlias, TypeIs import torch @@ -67,7 +79,7 @@ STORAGE_KEY_SEPARATOR = "," FILE_LIKE: TypeAlias = Union[str, os.PathLike, BinaryIO, IO[bytes]] MAP_LOCATION: TypeAlias = Optional[ - Union[Callable[[Storage, str], Storage], torch.device, str, dict[str, str]] + Union[Callable[[Storage, str], Storage], torch.device, str, Dict[str, str]] ] STORAGE: TypeAlias = Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage] @@ -120,8 +132,8 @@ def mkdtemp(): shutil.rmtree(path) -_package_registry: list[ - tuple[ +_package_registry: List[ + Tuple[ int, Callable[[STORAGE], Optional[str]], Callable[[STORAGE, str], Optional[STORAGE]], @@ -258,14 +270,14 @@ def clear_safe_globals() -> None: _weights_only_unpickler._clear_safe_globals() -def get_safe_globals() -> list[Union[Callable, tuple[Callable, str]]]: +def get_safe_globals() -> List[Union[Callable, Tuple[Callable, str]]]: """ Returns the list of user-added globals that are safe for ``weights_only`` load. """ return _weights_only_unpickler._get_safe_globals() -def add_safe_globals(safe_globals: list[Union[Callable, tuple[Callable, str]]]) -> None: +def add_safe_globals(safe_globals: List[Union[Callable, Tuple[Callable, str]]]) -> None: """ Marks the given globals as safe for ``weights_only`` load. 
For example, functions added to this list can be called during unpickling, classes could be instantiated @@ -326,7 +338,7 @@ class safe_globals(_weights_only_unpickler._safe_globals): """ -def get_unsafe_globals_in_checkpoint(f: FILE_LIKE) -> list[str]: +def get_unsafe_globals_in_checkpoint(f: FILE_LIKE) -> List[str]: """Returns a list of strings of functions/classes in a ``torch.save`` object that are not safe for ``weights_only``. For a given function or class ``f``, the corresponding string will be of the form @@ -792,7 +804,7 @@ class _open_zipfile_writer_buffer(_opener): def _open_zipfile_writer(name_or_buffer): - container: type[_opener] + container: Type[_opener] if _is_path(name_or_buffer): container = _open_zipfile_writer_file else: @@ -953,15 +965,15 @@ def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None: import torch.nn as nn serialized_container_types = {} - serialized_storages: dict[str, tuple[torch.UntypedStorage, torch.dtype]] = {} + serialized_storages: Dict[str, Tuple[torch.UntypedStorage, torch.dtype]] = {} # Since loading storages that view the same data with different dtypes is # not supported, we need to keep track of the dtype associated with each # storage data_ptr and throw an error if the dtype is ever different. # TODO: This feature could be added in the future - storage_dtypes: dict[int, torch.dtype] = {} + storage_dtypes: Dict[int, torch.dtype] = {} - def persistent_id(obj: Any) -> Optional[tuple]: + def persistent_id(obj: Any) -> Optional[Tuple]: # FIXME: the docs say that persistent_id should only return a string # but torch store returns tuples. This works only in the binary protocol # see @@ -1020,7 +1032,7 @@ def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None: else: storage_dtypes[storage.data_ptr()] = storage_dtype - view_metadata: Optional[tuple[str, int, int]] + view_metadata: Optional[Tuple[str, int, int]] # Offset is always 0, but we keep it for backwards compatibility # with the old serialization format (which supported storage views) @@ -1115,13 +1127,13 @@ def _save( _disable_byteorder_record, ): serialized_storages = {} - id_map: dict[int, str] = {} + id_map: Dict[int, str] = {} # Since loading storages that view the same data with different dtypes is # not supported, we need to keep track of the dtype associated with each # storage data_ptr and throw an error if the dtype is ever different. 
# TODO: This feature could be added in the future - storage_dtypes: dict[int, torch.dtype] = {} + storage_dtypes: Dict[int, torch.dtype] = {} def persistent_id(obj): # FIXME: the docs say that persistent_id should only return a string @@ -1522,7 +1534,7 @@ copyreg.pickle(torch.layout, lambda obj: (_get_layout, (str(obj),))) def _legacy_load(f, map_location, pickle_module, **pickle_load_args): - deserialized_objects: dict[int, Any] = {} + deserialized_objects: Dict[int, Any] = {} restore_location = _get_restore_location(map_location) @@ -1587,7 +1599,7 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args): warnings.warn(msg, SourceChangeWarning) def legacy_load(f): - deserialized_objects: dict[int, Any] = {} + deserialized_objects: Dict[int, Any] = {} def persistent_load(saved_id): if isinstance(saved_id, tuple): @@ -1938,7 +1950,7 @@ def _load( return typed_storage - load_module_mapping: dict[str, str] = { + load_module_mapping: Dict[str, str] = { # See https://github.com/pytorch/pytorch/pull/51633 "torch.tensor": "torch._tensor" } diff --git a/torch/sparse/__init__.py b/torch/sparse/__init__.py index 801334cdd8f0..7203bf6a6fa4 100644 --- a/torch/sparse/__init__.py +++ b/torch/sparse/__init__.py @@ -19,11 +19,11 @@ from .semi_structured import ( if TYPE_CHECKING: from torch.types import _dtype as DType - DimOrDims = Optional[Union[int, tuple[int, ...], list[int]]] + DimOrDims = Optional[Union[int, Tuple[int, ...], List[int]]] else: # The JIT doesn't understand Union, nor torch.dtype here DType = int - DimOrDims = Optional[tuple[int]] + DimOrDims = Optional[Tuple[int]] __all__ = [ @@ -591,7 +591,7 @@ def as_sparse_gradcheck(gradcheck): """Convert differentiable non-strided tensors to a representation containing differentiable strided tensors.""" if not isinstance(args, (list, tuple)): args = (args,) - new_args: list[Any] = [] + new_args: List[Any] = [] for obj in args: if ( isinstance(obj, torch.Tensor) diff --git a/torch/sparse/_triton_ops.py b/torch/sparse/_triton_ops.py index ce0e8446cba2..ebc59b18d5a7 100644 --- a/torch/sparse/_triton_ops.py +++ b/torch/sparse/_triton_ops.py @@ -4,7 +4,7 @@ import math import os import weakref from functools import lru_cache -from typing import Optional +from typing import Optional, Tuple import torch from torch._dynamo.utils import warn_once @@ -1124,7 +1124,7 @@ def _int_bsr_dense_addmm( right_alpha: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, skip_checks: bool = False, - max_grid: Optional[tuple[Optional[int], Optional[int], Optional[int]]] = None, + max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None, meta: Optional[dict] = None, ): if out is None and dense.dtype is torch.int8: @@ -1165,7 +1165,7 @@ def bsr_dense_addmm( right_alpha: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, skip_checks: bool = False, - max_grid: Optional[tuple[Optional[int], Optional[int], Optional[int]]] = None, + max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None, meta: Optional[dict] = None, ): """Compute @@ -1647,7 +1647,7 @@ if has_triton(): alpha=1.0, out: Optional[torch.Tensor] = None, skip_checks: bool = False, - max_grid: Optional[tuple[Optional[int], Optional[int], Optional[int]]] = None, + max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None, ): f_name = "sampled_addmm" @@ -1731,7 +1731,7 @@ if has_triton(): *, out: Optional[torch.Tensor] = None, skip_checks: bool = False, - max_grid: Optional[tuple[Optional[int], Optional[int], 
Optional[int]]] = None, + max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None, meta: Optional[dict] = None, ): f_name = "bsr_dense_mm" diff --git a/torch/sparse/_triton_ops_meta.py b/torch/sparse/_triton_ops_meta.py index 08471ac05888..e65981f72651 100644 --- a/torch/sparse/_triton_ops_meta.py +++ b/torch/sparse/_triton_ops_meta.py @@ -103,7 +103,7 @@ import inspect import itertools import re import warnings -from typing import Any +from typing import Any, Dict import torch from torch.hub import tqdm @@ -937,7 +937,7 @@ def main(op="scatter_mm", force=False, dtype=torch.float16, verbose=True): dump() -_operation_device_version_data: dict[Any, dict] = { +_operation_device_version_data: Dict[Any, Dict] = { # Warning: the data in between the BEGIN/END DATA comment lines # below is generated. It can be updated either manually or via # calling dump function defined above. diff --git a/torch/sparse/semi_structured.py b/torch/sparse/semi_structured.py index 82b20ab792d2..0ca2202cc4ba 100644 --- a/torch/sparse/semi_structured.py +++ b/torch/sparse/semi_structured.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs import warnings from collections import namedtuple -from typing import Any, Callable, Optional +from typing import Any, Callable, Dict, List, Optional, Tuple import torch from torch.sparse._semi_structured_conversions import ( @@ -54,13 +54,13 @@ class SparseSemiStructuredTensor(torch.Tensor): """ _DEFAULT_ALG_ID: int = 0 - _DTYPE_SHAPE_CONSTRAINTS: dict[torch.dtype, _SEMI_STRUCTURED_SPARSE_CONFIG] + _DTYPE_SHAPE_CONSTRAINTS: Dict[torch.dtype, _SEMI_STRUCTURED_SPARSE_CONFIG] _FORCE_CUTLASS: bool = False _FUSE_TRANSPOSE: bool = False _PROTOTYPE_WARNING_SHOWN: bool = False BACKEND: str - SPARSE_DISPATCH: dict[Callable, Callable] + SPARSE_DISPATCH: Dict[Callable, Callable] packed: Optional[torch.Tensor] meta: Optional[torch.Tensor] @@ -161,7 +161,7 @@ class SparseSemiStructuredTensor(torch.Tensor): def __tensor_flatten__( self, - ) -> tuple[list[str], tuple[torch.Size, bool, int, bool]]: + ) -> Tuple[List[str], Tuple[torch.Size, bool, int, bool]]: inner_tensors = list( filter(lambda x: getattr(self, x) is not None, self.__slots__) ) @@ -177,7 +177,7 @@ class SparseSemiStructuredTensor(torch.Tensor): def __tensor_unflatten__( cls, inner_tensors, - tensor_meta: tuple[torch.Size, bool, int, bool], + tensor_meta: Tuple[torch.Size, bool, int, bool], outer_size, outer_stride, ) -> torch.Tensor: diff --git a/torch/xpu/__init__.py b/torch/xpu/__init__.py index 44d25c080275..2d4ade6ebc4a 100644 --- a/torch/xpu/__init__.py +++ b/torch/xpu/__init__.py @@ -23,13 +23,13 @@ from .streams import Event, Stream _initialized = False _tls = threading.local() _initialization_lock = threading.Lock() -_queued_calls: list[ - tuple[Callable[[], None], list[str]] +_queued_calls: List[ + Tuple[Callable[[], None], List[str]] ] = [] # don't invoke these until initialization occurs _is_in_bad_fork = getattr(torch._C, "_xpu_isInBadFork", lambda: False) _device_t = Union[_device, str, int, None] _lazy_seed_tracker = _LazySeedTracker() -default_generators: tuple[torch._C.Generator] = () # type: ignore[assignment] +default_generators: Tuple[torch._C.Generator] = () # type: ignore[assignment] def _is_compiled() -> bool: @@ -216,7 +216,7 @@ def get_device_name(device: Optional[_device_t] = None) -> str: @lru_cache(None) -def get_device_capability(device: Optional[_device_t] = None) -> dict[str, Any]: +def get_device_capability(device: Optional[_device_t] = None) -> Dict[str, Any]: r"""Get the xpu 
capability of a device. Args: @@ -418,7 +418,7 @@ def synchronize(device: _device_t = None) -> None: return torch._C._xpu_synchronize(device) -def get_arch_list() -> list[str]: +def get_arch_list() -> List[str]: r"""Return list XPU architectures this library was compiled for.""" if not is_available(): return [] diff --git a/torch/xpu/memory.py b/torch/xpu/memory.py index 2d3ea4995419..2d8bd296dff0 100644 --- a/torch/xpu/memory.py +++ b/torch/xpu/memory.py @@ -1,5 +1,5 @@ import collections -from typing import Any, Union +from typing import Any, Dict, Tuple, Union import torch from torch.types import Device @@ -53,7 +53,7 @@ def reset_accumulated_memory_stats(device: _device_t = None) -> None: return torch._C._xpu_resetAccumulatedMemoryStats(device) -def memory_stats_as_nested_dict(device: _device_t = None) -> dict[str, Any]: +def memory_stats_as_nested_dict(device: _device_t = None) -> Dict[str, Any]: r"""Return the result of :func:`~torch.xpu.memory_stats` as a nested dictionary.""" if not is_initialized(): return {} @@ -61,7 +61,7 @@ def memory_stats_as_nested_dict(device: _device_t = None) -> dict[str, Any]: return torch._C._xpu_memoryStats(device) -def memory_stats(device: _device_t = None) -> dict[str, Any]: +def memory_stats(device: _device_t = None) -> Dict[str, Any]: r"""Return a dictionary of XPU memory allocator statistics for a given device. The return value of this function is a dictionary of statistics, each of @@ -178,7 +178,7 @@ def max_memory_reserved(device: _device_t = None) -> int: return memory_stats(device=device).get("reserved_bytes.all.peak", 0) -def mem_get_info(device: _device_t = None) -> tuple[int, int]: +def mem_get_info(device: _device_t = None) -> Tuple[int, int]: r"""Return the global free and total GPU memory for a given device. Args: diff --git a/torch/xpu/random.py b/torch/xpu/random.py index 8cd74d385def..5bc142418637 100644 --- a/torch/xpu/random.py +++ b/torch/xpu/random.py @@ -1,6 +1,5 @@ # mypy: allow-untyped-defs -from collections.abc import Iterable -from typing import Union +from typing import Iterable, List, Union import torch from torch import Tensor @@ -30,7 +29,7 @@ def get_rng_state(device: Union[int, str, torch.device] = "xpu") -> Tensor: return default_generator.get_state() -def get_rng_state_all() -> list[Tensor]: +def get_rng_state_all() -> List[Tensor]: r"""Return a list of ByteTensor representing the random number states of all devices.""" results = [get_rng_state(i) for i in range(device_count())] return results
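
The torch/profiler/_memory_profiler.py hunks earlier in this patch mention a validation pass over the data-flow graph: every (TensorKey, version) pair must have a unique creation node, and each node's inputs must see the latest version of a tensor. As a self-contained toy version of that invariant (assumed from the docstrings and comments in the hunk, not the torch.profiler implementation), a checker over hypothetical (inputs, outputs) nodes could look like this.

from typing import Dict, List, Set, Tuple

# Toy flow node: the tensor versions it reads and the versions it creates.
# Tensor names are plain strings here, standing in for TensorKey.
Node = Tuple[Dict[str, int], Dict[str, int]]  # (inputs, outputs)


def validate(nodes: List[Node]) -> None:
    # Each (tensor, version) pair may be created by exactly one node, and
    # every read must see the version written by the most recent producer.
    produced: Set[Tuple[str, int]] = set()
    current_version: Dict[str, int] = {}
    for inputs, outputs in nodes:
        for key, version in inputs.items():
            assert version == current_version.get(key, 0), f"stale read of {key}"
        for key, version in outputs.items():
            assert (key, version) not in produced, f"{key} v{version} created twice"
            produced.add((key, version))
            current_version[key] = version


if __name__ == "__main__":
    # Node 0 creates x at version 0; node 1 reads x v0 and produces x v1.
    validate([({}, {"x": 0}), ({"x": 0}, {"x": 1})])
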
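
Taken together, the hunks in this part of the patch apply one mechanical rewrite: subscripted builtin annotations such as list[str], dict[str, Any], tuple[int, int], and set[str] become their typing equivalents (List, Dict, Tuple, Set, DefaultDict), and Iterable is imported from typing rather than collections.abc, presumably so the annotations also evaluate under interpreters and tooling that do not accept subscripted builtins. A minimal sketch of the target style follows; the functions are hypothetical and are not part of torch.

from typing import Dict, Iterable, List, Optional, Tuple

# Hypothetical helpers that only demonstrate the typing-module annotation
# style this patch converges on.


def bucket_by_parity(values: Iterable[int]) -> Dict[str, List[int]]:
    # Group integers into "even" / "odd" buckets.
    buckets: Dict[str, List[int]] = {"even": [], "odd": []}
    for v in values:
        buckets["even" if v % 2 == 0 else "odd"].append(v)
    return buckets


def min_max(values: List[int]) -> Optional[Tuple[int, int]]:
    # Return (min, max), or None for an empty list.
    if not values:
        return None
    return min(values), max(values)


if __name__ == "__main__":
    assert bucket_by_parity([1, 2, 3, 4]) == {"even": [2, 4], "odd": [1, 3]}
    assert min_max([3, 1, 2]) == (1, 3)
    assert min_max([]) is None
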