PEP585 update - torch/nn torch/optim torch/package torch/profiler torch/serialization torch/sparse torch/xpu (#145175)

See #145101 for details.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145175
Approved by: https://github.com/bobrenjc93
Aaron Orenstein
2025-01-19 21:55:48 -08:00
committed by PyTorch MergeBot
parent bd97ce0b45
commit 54a00af2c6
38 changed files with 499 additions and 550 deletions
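For reference, PEP 585 (Python 3.9) makes the builtin containers and the collections.abc ABCs usable as generic types directly, deprecating the typing.List/Dict/Tuple/Set aliases and the typing re-exports of Iterable/Sequence that this diff removes. A minimal before/after sketch of the pattern, using illustrative names rather than code from the diff:

# Before: typing aliases, deprecated since Python 3.9 by PEP 585
# from typing import Dict, Iterable, List, Optional, Tuple
# def group_lrs(groups: List[Dict[str, float]]) -> Tuple[float, ...]: ...

# After: builtin generics; abstract types come from collections.abc
from collections.abc import Iterable
from typing import Optional

def group_lrs(groups: list[dict[str, float]]) -> tuple[float, ...]:
    # Optional/Union/Callable are untouched by PEP 585; only the container aliases change.
    return tuple(g["lr"] for g in groups)

def first(xs: Iterable[int]) -> Optional[int]:
    return next(iter(xs), None)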

View File

@ -3,7 +3,7 @@
import importlib
import math
import warnings
from typing import Callable, List, Optional, TYPE_CHECKING, Union
from typing import Callable, Optional, TYPE_CHECKING, Union
import torch
from torch import _VF, sym_int as _sym_int, Tensor
@ -708,7 +708,7 @@ def max_pool1d_with_indices(
return_indices=return_indices,
)
if stride is None:
stride = torch.jit.annotate(List[int], [])
stride = torch.jit.annotate(list[int], [])
return torch.max_pool1d_with_indices(
input, kernel_size, stride, padding, dilation, ceil_mode
)
@ -736,7 +736,7 @@ def _max_pool1d(
return_indices=return_indices,
)
if stride is None:
stride = torch.jit.annotate(List[int], [])
stride = torch.jit.annotate(list[int], [])
return torch.max_pool1d(input, kernel_size, stride, padding, dilation, ceil_mode)
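Note that torch.jit.annotate receives the type as an ordinary runtime value, so the list[int] spelling above works only because PEP 585 makes the builtins subscriptable at runtime on Python 3.9+ (in eager mode the call simply returns its second argument). A minimal sketch, where default_stride is a hypothetical wrapper rather than code from this file:

import torch

def default_stride(stride=None):
    # TorchScript needs the element type of the empty default spelled out;
    # list[int] evaluates to a types.GenericAlias at runtime on Python >= 3.9.
    if stride is None:
        stride = torch.jit.annotate(list[int], [])
    return stride

print(default_stride())  # []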
@ -798,7 +798,7 @@ def max_pool2d_with_indices(
return_indices=return_indices,
)
if stride is None:
stride = torch.jit.annotate(List[int], [])
stride = torch.jit.annotate(list[int], [])
return torch._C._nn.max_pool2d_with_indices(
input, kernel_size, stride, padding, dilation, ceil_mode
)
@ -826,7 +826,7 @@ def _max_pool2d(
return_indices=return_indices,
)
if stride is None:
stride = torch.jit.annotate(List[int], [])
stride = torch.jit.annotate(list[int], [])
return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
@ -888,7 +888,7 @@ def max_pool3d_with_indices(
return_indices=return_indices,
)
if stride is None:
stride = torch.jit.annotate(List[int], [])
stride = torch.jit.annotate(list[int], [])
return torch._C._nn.max_pool3d_with_indices(
input, kernel_size, stride, padding, dilation, ceil_mode
)
@ -916,7 +916,7 @@ def _max_pool3d(
return_indices=return_indices,
)
if stride is None:
stride = torch.jit.annotate(List[int], [])
stride = torch.jit.annotate(list[int], [])
return torch.max_pool3d(input, kernel_size, stride, padding, dilation, ceil_mode)
@ -939,7 +939,7 @@ def _unpool_output_size(
output_size: Optional[list[int]],
) -> list[int]:
input_size = input.size()
default_size = torch.jit.annotate(List[int], [])
default_size = torch.jit.annotate(list[int], [])
for d in range(len(kernel_size)):
default_size.append(
(input_size[-len(kernel_size) + d] - 1) * stride[d]

View File

@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
import math
from typing import List, Optional, Union
from typing import Optional, Union
from typing_extensions import deprecated
import torch
@ -788,8 +788,8 @@ class _ConvTransposeNd(_ConvNd):
f"or {num_non_spatial_dims + num_spatial_dims} elements (got {len(output_size)})"
)
min_sizes = torch.jit.annotate(List[int], [])
max_sizes = torch.jit.annotate(List[int], [])
min_sizes = torch.jit.annotate(list[int], [])
max_sizes = torch.jit.annotate(list[int], [])
for d in range(num_spatial_dims):
dim_size = (
(input.size(d + num_non_spatial_dims) - 1) * stride[d]
@ -811,7 +811,7 @@ class _ConvTransposeNd(_ConvNd):
f"from {min_sizes} to {max_sizes} (for an input of {input.size()[2:]})"
)
res = torch.jit.annotate(List[int], [])
res = torch.jit.annotate(list[int], [])
for d in range(num_spatial_dims):
res.append(output_size[d] - min_sizes[d])

View File

@ -1,6 +1,6 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from typing import cast, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
from typing import cast, Optional, TYPE_CHECKING, Union
import torch
from torch import Tensor
@ -25,7 +25,7 @@ class Adafactor(Optimizer):
params: ParamsT,
lr: Union[float, Tensor] = 1e-2,
beta2_decay: float = -0.8,
eps: Tuple[Optional[float], float] = (None, 1e-3),
eps: tuple[Optional[float], float] = (None, 1e-3),
d: float = 1.0,
weight_decay: float = 0.0,
*,
@ -133,12 +133,12 @@ class Adafactor(Optimizer):
loss = closure()
for group in self.param_groups:
params_with_grad: List[Tensor] = []
grads: List[Tensor] = []
row_vars: List[Optional[Tensor]] = []
col_vars: List[Optional[Tensor]] = []
variances: List[Optional[Tensor]] = []
state_steps: List[Tensor] = []
params_with_grad: list[Tensor] = []
grads: list[Tensor] = []
row_vars: list[Optional[Tensor]] = []
col_vars: list[Optional[Tensor]] = []
variances: list[Optional[Tensor]] = []
state_steps: list[Tensor] = []
eps1, eps2 = group["eps"]
has_complex = self._init_group(
@ -324,16 +324,16 @@ Adafactor.__doc__ = (
def _single_tensor_adafactor(
params: List[Tensor],
grads: List[Tensor],
params: list[Tensor],
grads: list[Tensor],
# If grad is 1-dimensional (aka a vector), there is no factorization necessary
# so row_var and col_var will be None while variance will be filled.
# Contrarily, for a grad with multiple dimensions, we will factor along the last
# 2 dimensions, and so row_var and col_var will be filled and variance will be None.
row_vars: List[Optional[Tensor]],
col_vars: List[Optional[Tensor]],
variances: List[Optional[Tensor]],
state_steps: List[Tensor],
row_vars: list[Optional[Tensor]],
col_vars: list[Optional[Tensor]],
variances: list[Optional[Tensor]],
state_steps: list[Tensor],
grad_scale: Optional[Tensor],
found_inf: Optional[Tensor],
*,
@ -411,17 +411,17 @@ def _single_tensor_adafactor(
def _group_tensors_by_device_dtype_and_is_multidim(
tensorlists: TensorListList,
) -> Dict[
Tuple[Optional[torch.device], Optional[torch.dtype], bool],
List[List[Optional[Tensor]]],
) -> dict[
tuple[Optional[torch.device], Optional[torch.dtype], bool],
list[list[Optional[Tensor]]],
]:
"""Groups tensors by device, dtype, AND multidimensionality -- whether the tensor
has multiple dims or just one dim (is a vector). This allows the foreach impl of
Adafactor to assume that every group of params will either be factored or not."""
grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(tensorlists)
ultra_grouped_tensors: Dict[
Tuple[Optional[torch.device], Optional[torch.dtype], bool],
List[List[Optional[Tensor]]],
ultra_grouped_tensors: dict[
tuple[Optional[torch.device], Optional[torch.dtype], bool],
list[list[Optional[Tensor]]],
] = {}
for (device, dtype), (tensorlists, _) in grouped_tensors.items():
matrix_key = (device, dtype, True)
@ -444,16 +444,16 @@ def _group_tensors_by_device_dtype_and_is_multidim(
def _multi_tensor_adafactor(
params: List[Tensor],
grads: List[Tensor],
params: list[Tensor],
grads: list[Tensor],
# If grad is 1-dimensional (aka a vector), there is no factorization necessary
# so row_var and col_var will be None while variance will be filled.
# Contrarily, for a grad with multiple dimensions, we will factor along the last
# 2 dimensions, and so row_var and col_var will be filled and variance will be None.
row_vars: List[Optional[Tensor]],
col_vars: List[Optional[Tensor]],
variances: List[Optional[Tensor]],
state_steps: List[Tensor],
row_vars: list[Optional[Tensor]],
col_vars: list[Optional[Tensor]],
variances: list[Optional[Tensor]],
state_steps: list[Tensor],
grad_scale: Optional[Tensor],
found_inf: Optional[Tensor],
*,
@ -486,9 +486,9 @@ def _multi_tensor_adafactor(
device_state_steps_,
)
) in grouped_tensors.items():
device_params = cast(List[Tensor], device_params_)
device_grads = cast(List[Tensor], device_grads_)
device_state_steps = cast(List[Tensor], device_state_steps_)
device_params = cast(list[Tensor], device_params_)
device_grads = cast(list[Tensor], device_grads_)
device_state_steps = cast(list[Tensor], device_state_steps_)
if eps1 is None:
assert (
dtype is not None
@ -530,8 +530,8 @@ def _multi_tensor_adafactor(
torch._foreach_mul_(device_params, 1 - lr * weight_decay)
if is_multidim:
device_row_vars = cast(List[Tensor], device_row_vars_)
device_col_vars = cast(List[Tensor], device_col_vars_)
device_row_vars = cast(list[Tensor], device_row_vars_)
device_col_vars = cast(list[Tensor], device_col_vars_)
assert (
device_row_vars[0] is not None and device_col_vars[0] is not None
), "row_var and col_var should be defined when grad is multidimensional"
@ -564,7 +564,7 @@ def _multi_tensor_adafactor(
torch._foreach_div_(var_estimates, row_var_means)
del row_var_means
else:
device_variances = cast(List[Tensor], device_variances_)
device_variances = cast(list[Tensor], device_variances_)
assert (
device_variances[0] is not None
), "variance should be defined when grad is a vector"
@ -592,12 +592,12 @@ def _multi_tensor_adafactor(
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adafactor)
def adafactor(
params: List[Tensor],
grads: List[Tensor],
row_vars: List[Optional[Tensor]],
col_vars: List[Optional[Tensor]],
variances: List[Optional[Tensor]],
state_steps: List[Tensor],
params: list[Tensor],
grads: list[Tensor],
row_vars: list[Optional[Tensor]],
col_vars: list[Optional[Tensor]],
variances: list[Optional[Tensor]],
state_steps: list[Tensor],
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
foreach: Optional[bool] = None,

View File

@ -1,7 +1,6 @@
# mypy: allow-untyped-defs
r"""Functional interface."""
import math
from typing import List
from torch import Tensor
@ -22,11 +21,11 @@ from .sgd import sgd # type: ignore[attr-defined] # noqa: F401
def sparse_adam(
params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
exp_avg_sqs: List[Tensor],
state_steps: List[int],
params: list[Tensor],
grads: list[Tensor],
exp_avgs: list[Tensor],
exp_avg_sqs: list[Tensor],
state_steps: list[int],
*,
eps: float,
beta1: float,

View File

@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import Any, cast, Dict, List, Optional, Union
from typing import Any, cast, Optional, Union
import torch
from torch import Tensor
@ -82,12 +82,12 @@ class Adadelta(Optimizer):
def _init_group(
self,
group: Dict[str, Any],
params_with_grad: List[Tensor],
grads: List[Tensor],
square_avgs: List[Tensor],
acc_deltas: List[Tensor],
state_steps: List[Tensor],
group: dict[str, Any],
params_with_grad: list[Tensor],
grads: list[Tensor],
square_avgs: list[Tensor],
acc_deltas: list[Tensor],
state_steps: list[Tensor],
):
has_complex = False
p: Tensor
@ -139,11 +139,11 @@ class Adadelta(Optimizer):
loss = closure()
for group in self.param_groups:
params_with_grad: List[Tensor] = []
grads: List[Tensor] = []
square_avgs: List[Tensor] = []
acc_deltas: List[Tensor] = []
state_steps: List[Tensor] = []
params_with_grad: list[Tensor] = []
grads: list[Tensor] = []
square_avgs: list[Tensor] = []
acc_deltas: list[Tensor] = []
state_steps: list[Tensor] = []
(
lr,
rho,
@ -242,11 +242,11 @@ Adadelta.__doc__ = (
def _single_tensor_adadelta(
params: List[Tensor],
grads: List[Tensor],
square_avgs: List[Tensor],
acc_deltas: List[Tensor],
state_steps: List[Tensor],
params: list[Tensor],
grads: list[Tensor],
square_avgs: list[Tensor],
acc_deltas: list[Tensor],
state_steps: list[Tensor],
*,
lr: float,
rho: float,
@ -296,11 +296,11 @@ def _single_tensor_adadelta(
def _multi_tensor_adadelta(
params: List[Tensor],
grads: List[Tensor],
square_avgs: List[Tensor],
acc_deltas: List[Tensor],
state_steps: List[Tensor],
params: list[Tensor],
grads: list[Tensor],
square_avgs: list[Tensor],
acc_deltas: list[Tensor],
state_steps: list[Tensor],
*,
lr: float,
rho: float,
@ -337,11 +337,11 @@ def _multi_tensor_adadelta(
device_acc_deltas_,
device_state_steps_,
), _ in grouped_tensors.values():
device_params = cast(List[Tensor], device_params_)
device_grads = cast(List[Tensor], device_grads_)
device_square_avgs = cast(List[Tensor], device_square_avgs_)
device_acc_deltas = cast(List[Tensor], device_acc_deltas_)
device_state_steps = cast(List[Tensor], device_state_steps_)
device_params = cast(list[Tensor], device_params_)
device_grads = cast(list[Tensor], device_grads_)
device_square_avgs = cast(list[Tensor], device_square_avgs_)
device_acc_deltas = cast(list[Tensor], device_acc_deltas_)
device_state_steps = cast(list[Tensor], device_state_steps_)
if has_complex:
_view_as_real(
device_params, device_grads, device_square_avgs, device_acc_deltas
@ -397,11 +397,11 @@ def _multi_tensor_adadelta(
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adadelta)
def adadelta(
params: List[Tensor],
grads: List[Tensor],
square_avgs: List[Tensor],
acc_deltas: List[Tensor],
state_steps: List[Tensor],
params: list[Tensor],
grads: list[Tensor],
square_avgs: list[Tensor],
acc_deltas: list[Tensor],
state_steps: list[Tensor],
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
capturable: bool = False,

View File

@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import cast, List, Optional, Union
from typing import cast, Optional, Union
import torch
from torch import Tensor
@ -157,10 +157,10 @@ class Adagrad(Optimizer):
loss = closure()
for group in self.param_groups:
params_with_grad: List[Tensor] = []
grads: List[Tensor] = []
state_sums: List[Tensor] = []
state_steps: List[Tensor] = []
params_with_grad: list[Tensor] = []
grads: list[Tensor] = []
state_sums: list[Tensor] = []
state_steps: list[Tensor] = []
has_sparse_grad, has_complex = self._init_group(
group, params_with_grad, grads, state_sums, state_steps
@ -240,10 +240,10 @@ Adagrad.__doc__ = (
def adagrad(
params: List[Tensor],
grads: List[Tensor],
state_sums: List[Tensor],
state_steps: List[Tensor],
params: list[Tensor],
grads: list[Tensor],
state_sums: list[Tensor],
state_steps: list[Tensor],
fused: Optional[bool] = None,
grad_scale: Optional[Tensor] = None,
found_inf: Optional[Tensor] = None,
@ -319,10 +319,10 @@ def _make_sparse(grad, grad_indices, values):
def _single_tensor_adagrad(
params: List[Tensor],
grads: List[Tensor],
state_sums: List[Tensor],
state_steps: List[Tensor],
params: list[Tensor],
grads: list[Tensor],
state_sums: list[Tensor],
state_steps: list[Tensor],
grad_scale: Optional[Tensor],
found_inf: Optional[Tensor],
*,
@ -380,10 +380,10 @@ def _single_tensor_adagrad(
def _multi_tensor_adagrad(
params: List[Tensor],
grads: List[Tensor],
state_sums: List[Tensor],
state_steps: List[Tensor],
params: list[Tensor],
grads: list[Tensor],
state_sums: list[Tensor],
state_steps: list[Tensor],
grad_scale: Optional[Tensor],
found_inf: Optional[Tensor],
*,
@ -412,10 +412,10 @@ def _multi_tensor_adagrad(
device_state_sums_,
device_state_steps_,
), _ in grouped_tensorlists.values():
device_params = cast(List[Tensor], device_params_)
device_grads = cast(List[Tensor], device_grads_)
device_state_sums = cast(List[Tensor], device_state_sums_)
device_state_steps = cast(List[Tensor], device_state_steps_)
device_params = cast(list[Tensor], device_params_)
device_grads = cast(list[Tensor], device_grads_)
device_state_sums = cast(list[Tensor], device_state_sums_)
device_state_steps = cast(list[Tensor], device_state_steps_)
device_has_sparse_grad = has_sparse_grad and any(
grad.is_sparse for grad in device_grads
@ -487,10 +487,10 @@ def _multi_tensor_adagrad(
def _fused_adagrad(
params: List[Tensor],
grads: List[Tensor],
state_sums: List[Tensor],
state_steps: List[Tensor],
params: list[Tensor],
grads: list[Tensor],
state_sums: list[Tensor],
state_steps: list[Tensor],
grad_scale: Optional[Tensor],
found_inf: Optional[Tensor],
*,
@ -530,10 +530,10 @@ def _fused_adagrad(
),
_,
) in grouped_tensors.items():
device_params = cast(List[Tensor], device_params_)
device_grads = cast(List[Tensor], device_grads_)
device_state_sums = cast(List[Tensor], device_state_sums_)
device_state_steps = cast(List[Tensor], device_state_steps_)
device_params = cast(list[Tensor], device_params_)
device_grads = cast(list[Tensor], device_grads_)
device_state_sums = cast(list[Tensor], device_state_sums_)
device_state_steps = cast(list[Tensor], device_state_steps_)
device_grad_scale, device_found_inf = None, None
if grad_scale is not None and grad_scale_dict is not None:

View File

@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import cast, List, Optional, Tuple, Union
from typing import cast, Optional, Union
import torch
from torch import Tensor
@ -35,7 +35,7 @@ class Adam(Optimizer):
self,
params: ParamsT,
lr: Union[float, Tensor] = 1e-3,
betas: Tuple[Union[float, Tensor], Union[float, Tensor]] = (0.9, 0.999),
betas: tuple[Union[float, Tensor], Union[float, Tensor]] = (0.9, 0.999),
eps: float = 1e-8,
weight_decay: float = 0,
amsgrad: bool = False,
@ -225,12 +225,12 @@ class Adam(Optimizer):
loss = closure()
for group in self.param_groups:
params_with_grad: List[Tensor] = []
grads: List[Tensor] = []
exp_avgs: List[Tensor] = []
exp_avg_sqs: List[Tensor] = []
max_exp_avg_sqs: List[Tensor] = []
state_steps: List[Tensor] = []
params_with_grad: list[Tensor] = []
grads: list[Tensor] = []
exp_avgs: list[Tensor] = []
exp_avg_sqs: list[Tensor] = []
max_exp_avg_sqs: list[Tensor] = []
state_steps: list[Tensor] = []
beta1, beta2 = group["betas"]
has_complex = self._init_group(
@ -342,12 +342,12 @@ Adam.__doc__ = (
def _single_tensor_adam(
params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
exp_avg_sqs: List[Tensor],
max_exp_avg_sqs: List[Tensor],
state_steps: List[Tensor],
params: list[Tensor],
grads: list[Tensor],
exp_avgs: list[Tensor],
exp_avg_sqs: list[Tensor],
max_exp_avg_sqs: list[Tensor],
state_steps: list[Tensor],
grad_scale: Optional[Tensor],
found_inf: Optional[Tensor],
*,
@ -532,12 +532,12 @@ def _single_tensor_adam(
def _multi_tensor_adam(
params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
exp_avg_sqs: List[Tensor],
max_exp_avg_sqs: List[Tensor],
state_steps: List[Tensor],
params: list[Tensor],
grads: list[Tensor],
exp_avgs: list[Tensor],
exp_avg_sqs: list[Tensor],
max_exp_avg_sqs: list[Tensor],
state_steps: list[Tensor],
grad_scale: Optional[Tensor],
found_inf: Optional[Tensor],
*,
@ -612,11 +612,11 @@ def _multi_tensor_adam(
device_max_exp_avg_sqs_,
device_state_steps_,
), _ in grouped_tensors.values():
device_params = cast(List[Tensor], device_params_)
device_grads = cast(List[Tensor], device_grads_)
device_exp_avgs = cast(List[Tensor], device_exp_avgs_)
device_exp_avg_sqs = cast(List[Tensor], device_exp_avg_sqs_)
device_state_steps = cast(List[Tensor], device_state_steps_)
device_params = cast(list[Tensor], device_params_)
device_grads = cast(list[Tensor], device_grads_)
device_exp_avgs = cast(list[Tensor], device_exp_avgs_)
device_exp_avg_sqs = cast(list[Tensor], device_exp_avg_sqs_)
device_state_steps = cast(list[Tensor], device_state_steps_)
device = device_params[0].device
if beta1_dict is not None and device not in beta1_dict:
@ -627,7 +627,7 @@ def _multi_tensor_adam(
# Handle complex parameters
if has_complex:
if amsgrad:
device_max_exp_avg_sqs = cast(List[Tensor], device_max_exp_avg_sqs_)
device_max_exp_avg_sqs = cast(list[Tensor], device_max_exp_avg_sqs_)
_view_as_real(
device_params,
device_grads,
@ -693,9 +693,9 @@ def _multi_tensor_adam(
del device_grads
del scaled_device_grads
bias_correction1: Union[Tuple[Tensor, ...], List[Tensor]]
bias_correction2: Union[Tuple[Tensor, ...], List[Tensor]]
bias_correction2_sqrt: Union[Tuple[Tensor, ...], List[Tensor]]
bias_correction1: Union[tuple[Tensor, ...], list[Tensor]]
bias_correction2: Union[tuple[Tensor, ...], list[Tensor]]
bias_correction2_sqrt: Union[tuple[Tensor, ...], list[Tensor]]
if capturable:
bias_correction1 = torch._foreach_pow(beta1, device_state_steps) # type: ignore[arg-type]
@ -719,7 +719,7 @@ def _multi_tensor_adam(
bias_correction2_sqrt = bias_correction2
if amsgrad:
device_max_exp_avg_sqs = cast(List[Tensor], device_max_exp_avg_sqs_)
device_max_exp_avg_sqs = cast(list[Tensor], device_max_exp_avg_sqs_)
# Maintains the maximum of all 2nd moment running avg. till now
torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs) # type: ignore[assignment]
@ -747,7 +747,7 @@ def _multi_tensor_adam(
bias_correction2_sqrt = [bc**0.5 for bc in bias_correction2] # type: ignore[arg-type]
if amsgrad:
device_max_exp_avg_sqs = cast(List[Tensor], device_max_exp_avg_sqs_)
device_max_exp_avg_sqs = cast(list[Tensor], device_max_exp_avg_sqs_)
# Maintains the maximum of all 2nd moment running avg. till now
torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs)
@ -764,12 +764,12 @@ def _multi_tensor_adam(
def _fused_adam(
params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
exp_avg_sqs: List[Tensor],
max_exp_avg_sqs: List[Tensor],
state_steps: List[Tensor],
params: list[Tensor],
grads: list[Tensor],
exp_avgs: list[Tensor],
exp_avg_sqs: list[Tensor],
max_exp_avg_sqs: list[Tensor],
state_steps: list[Tensor],
grad_scale: Optional[Tensor],
found_inf: Optional[Tensor],
*,
@ -816,11 +816,11 @@ def _fused_adam(
),
_,
) in grouped_tensors.items():
device_params = cast(List[Tensor], device_params_)
device_grads = cast(List[Tensor], device_grads_)
device_exp_avgs = cast(List[Tensor], device_exp_avgs_)
device_exp_avg_sqs = cast(List[Tensor], device_exp_avg_sqs_)
device_state_steps = cast(List[Tensor], device_state_steps_)
device_params = cast(list[Tensor], device_params_)
device_grads = cast(list[Tensor], device_grads_)
device_exp_avgs = cast(list[Tensor], device_exp_avgs_)
device_exp_avg_sqs = cast(list[Tensor], device_exp_avg_sqs_)
device_state_steps = cast(list[Tensor], device_state_steps_)
if device.type == "mps": # type: ignore[union-attr]
assert found_inf is None and grad_scale is None
@ -864,12 +864,12 @@ def _fused_adam(
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adam)
def adam(
params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
exp_avg_sqs: List[Tensor],
max_exp_avg_sqs: List[Tensor],
state_steps: List[Tensor],
params: list[Tensor],
grads: list[Tensor],
exp_avgs: list[Tensor],
exp_avg_sqs: list[Tensor],
max_exp_avg_sqs: list[Tensor],
state_steps: list[Tensor],
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
foreach: Optional[bool] = None,

View File

@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import List, Optional, Tuple, Union
from typing import Optional, Union
from torch import Tensor
@ -23,7 +23,7 @@ class AdamW(Adam):
self,
params: ParamsT,
lr: Union[float, Tensor] = 1e-3,
betas: Tuple[Union[float, Tensor], Union[float, Tensor]] = (0.9, 0.999),
betas: tuple[Union[float, Tensor], Union[float, Tensor]] = (0.9, 0.999),
eps: float = 1e-8,
weight_decay: float = 1e-2,
amsgrad: bool = False,
@ -128,12 +128,12 @@ AdamW.__doc__ = (
# @_disable_dynamo_if_unsupported logic occurs in the decorator that's applied to F.adam
def adamw(
params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
exp_avg_sqs: List[Tensor],
max_exp_avg_sqs: List[Tensor],
state_steps: List[Tensor],
params: list[Tensor],
grads: list[Tensor],
exp_avgs: list[Tensor],
exp_avg_sqs: list[Tensor],
max_exp_avg_sqs: list[Tensor],
state_steps: list[Tensor],
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
foreach: Optional[bool] = None,

View File

@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import cast, List, Optional, Tuple, Union
from typing import cast, Optional, Union
import torch
from torch import Tensor
@ -135,12 +135,12 @@ class ASGD(Optimizer):
loss = closure()
for group in self.param_groups:
params_with_grad: List[Tensor] = []
grads: List[Tensor] = []
mus: List[Tensor] = []
axs: List[Tensor] = []
etas: List[Tensor] = []
state_steps: List[Tensor] = []
params_with_grad: list[Tensor] = []
grads: list[Tensor] = []
mus: list[Tensor] = []
axs: list[Tensor] = []
etas: list[Tensor] = []
state_steps: list[Tensor] = []
has_complex = self._init_group(
group, params_with_grad, grads, mus, axs, etas, state_steps
@ -192,12 +192,12 @@ ASGD.__doc__ = rf"""Implements Averaged Stochastic Gradient Descent.
def _single_tensor_asgd(
params: List[Tensor],
grads: List[Tensor],
axs: List[Tensor],
mus: List[Tensor],
etas: List[Tensor],
state_steps: List[Tensor],
params: list[Tensor],
grads: list[Tensor],
axs: list[Tensor],
mus: list[Tensor],
etas: list[Tensor],
state_steps: list[Tensor],
*,
lambd: float,
lr: float,
@ -268,12 +268,12 @@ def _single_tensor_asgd(
def _multi_tensor_asgd(
params: List[Tensor],
grads: List[Tensor],
axs: List[Tensor],
mus: List[Tensor],
etas: List[Tensor],
state_steps: List[Tensor],
params: list[Tensor],
grads: list[Tensor],
axs: list[Tensor],
mus: list[Tensor],
etas: list[Tensor],
state_steps: list[Tensor],
*,
lambd: float,
lr: float,
@ -315,12 +315,12 @@ def _multi_tensor_asgd(
),
_,
) in grouped_tensors.items():
grouped_params = cast(List[Tensor], grouped_params_)
grouped_grads = cast(List[Tensor], grouped_grads_)
grouped_axs = cast(List[Tensor], grouped_axs_)
grouped_mus = cast(List[Tensor], grouped_mus_)
grouped_etas = cast(List[Tensor], grouped_etas_)
grouped_state_steps = cast(List[Tensor], grouped_state_steps_)
grouped_params = cast(list[Tensor], grouped_params_)
grouped_grads = cast(list[Tensor], grouped_grads_)
grouped_axs = cast(list[Tensor], grouped_axs_)
grouped_mus = cast(list[Tensor], grouped_mus_)
grouped_etas = cast(list[Tensor], grouped_etas_)
grouped_state_steps = cast(list[Tensor], grouped_state_steps_)
if has_complex:
_view_as_real(grouped_params, grouped_grads, grouped_axs)
@ -340,7 +340,7 @@ def _multi_tensor_asgd(
torch._foreach_add_(grouped_state_steps, 1)
# intermediate = grad + param * lambd
intermediate: Union[Tuple[Tensor, ...], List[Tensor]]
intermediate: Union[tuple[Tensor, ...], list[Tensor]]
if weight_decay != 0:
if maximize:
torch._foreach_add_(grouped_grads, grouped_params, alpha=weight_decay)
@ -375,8 +375,8 @@ def _multi_tensor_asgd(
torch._foreach_addcmul_(grouped_axs, intermediate, grouped_mus)
del intermediate
new_etas: Union[Tuple[Tensor, ...], List[Tensor]]
new_mus: Union[Tuple[Tensor, ...], List[Tensor]]
new_etas: Union[tuple[Tensor, ...], list[Tensor]]
new_mus: Union[tuple[Tensor, ...], list[Tensor]]
if capturable:
# update grouped_mus
new_mus = torch._foreach_sub(grouped_state_steps, t0)
@ -408,12 +408,12 @@ def _multi_tensor_asgd(
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_asgd)
def asgd(
params: List[Tensor],
grads: List[Tensor],
axs: List[Tensor],
mus: List[Tensor],
etas: List[Tensor],
state_steps: List[Tensor],
params: list[Tensor],
grads: list[Tensor],
axs: list[Tensor],
mus: list[Tensor],
etas: list[Tensor],
state_steps: list[Tensor],
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
foreach: Optional[bool] = None,

View File

@ -5,17 +5,14 @@ import types
import warnings
from bisect import bisect_right
from collections import Counter
from collections.abc import Iterable, Sequence
from functools import partial, wraps
from typing import (
Any,
Callable,
cast,
Dict,
Iterable,
List,
Literal,
Optional,
Sequence,
SupportsFloat,
TypedDict,
Union,
@ -116,7 +113,7 @@ class LRScheduler:
"param 'initial_lr' is not specified "
f"in param_groups[{i}] when resuming an optimizer"
)
self.base_lrs: List[float] = [
self.base_lrs: list[float] = [
group["initial_lr"] for group in optimizer.param_groups
]
self.last_epoch = last_epoch
@ -163,7 +160,7 @@ class LRScheduler:
key: value for key, value in self.__dict__.items() if key != "optimizer"
}
def load_state_dict(self, state_dict: Dict[str, Any]):
def load_state_dict(self, state_dict: dict[str, Any]):
"""Load the scheduler's state.
Args:
@ -172,18 +169,18 @@ class LRScheduler:
"""
self.__dict__.update(state_dict)
def get_last_lr(self) -> List[float]:
def get_last_lr(self) -> list[float]:
"""Return last computed learning rate by current scheduler."""
return self._last_lr
def get_lr(self) -> List[float]:
def get_lr(self) -> list[float]:
"""Compute learning rate using chainable form of the scheduler."""
raise NotImplementedError
def print_lr(
self,
is_verbose: bool,
group: Dict[str, Any],
group: dict[str, Any],
lr: float,
epoch: Optional[int] = None,
):
@ -243,7 +240,7 @@ class LRScheduler:
warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
self.last_epoch = epoch
if hasattr(self, "_get_closed_form_lr"):
values = cast(List[float], self._get_closed_form_lr())
values = cast(list[float], self._get_closed_form_lr())
else:
values = self.get_lr()
@ -253,7 +250,7 @@ class LRScheduler:
else:
param_group["lr"] = lr
self._last_lr: List[float] = [
self._last_lr: list[float] = [
group["lr"] for group in self.optimizer.param_groups
]
@ -320,13 +317,13 @@ class LambdaLR(LRScheduler):
def __init__(
self,
optimizer: Optimizer,
lr_lambda: Union[Callable[[int], float], List[Callable[[int], float]]],
lr_lambda: Union[Callable[[int], float], list[Callable[[int], float]]],
last_epoch: int = -1,
verbose="deprecated",
): # noqa: D107
self.optimizer = optimizer
self.lr_lambdas: List[Callable[[int], float]]
self.lr_lambdas: list[Callable[[int], float]]
if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple):
self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups)
else:
@ -420,13 +417,13 @@ class MultiplicativeLR(LRScheduler):
def __init__(
self,
optimizer: Optimizer,
lr_lambda: Union[Callable[[int], float], List[Callable[[int], float]]],
lr_lambda: Union[Callable[[int], float], list[Callable[[int], float]]],
last_epoch: int = -1,
verbose="deprecated",
): # noqa: D107
self.optimizer = optimizer
self.lr_lambdas: List[Callable[[int], float]]
self.lr_lambdas: list[Callable[[int], float]]
if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple):
self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups)
else:
@ -863,8 +860,8 @@ class SequentialLR(LRScheduler):
def __init__(
self,
optimizer: Optimizer,
schedulers: List[LRScheduler],
milestones: List[int],
schedulers: list[LRScheduler],
milestones: list[int],
last_epoch: int = -1,
verbose="deprecated",
): # noqa: D107
@ -1311,7 +1308,7 @@ class ReduceLROnPlateau(LRScheduler):
threshold: float = 1e-4,
threshold_mode: Literal["rel", "abs"] = "rel",
cooldown: int = 0,
min_lr: Union[List[float], float] = 0,
min_lr: Union[list[float], float] = 0,
eps: float = 1e-8,
verbose="deprecated",
): # noqa: D107
@ -1557,8 +1554,8 @@ class CyclicLR(LRScheduler):
def __init__(
self,
optimizer: Optimizer,
base_lr: Union[float, List[float]],
max_lr: Union[float, List[float]],
base_lr: Union[float, list[float]],
max_lr: Union[float, list[float]],
step_size_up: int = 2000,
step_size_down: Optional[int] = None,
mode: Literal["triangular", "triangular2", "exp_range"] = "triangular",
@ -1988,15 +1985,15 @@ class OneCycleLR(LRScheduler):
def __init__(
self,
optimizer: Optimizer,
max_lr: Union[float, List[float]],
max_lr: Union[float, list[float]],
total_steps: Optional[int] = None,
epochs: Optional[int] = None,
steps_per_epoch: Optional[int] = None,
pct_start: float = 0.3,
anneal_strategy: Literal["cos", "linear"] = "cos",
cycle_momentum: bool = True,
base_momentum: Union[float, List[float]] = 0.85,
max_momentum: Union[float, List[float]] = 0.95,
base_momentum: Union[float, list[float]] = 0.85,
max_momentum: Union[float, list[float]] = 0.95,
div_factor: float = 25.0,
final_div_factor: float = 1e4,
three_phase: bool = False,
@ -2028,7 +2025,7 @@ class OneCycleLR(LRScheduler):
"You must define either total_steps OR (epochs AND steps_per_epoch)"
)
self._schedule_phases: List[_SchedulePhase]
self._schedule_phases: list[_SchedulePhase]
if three_phase:
self._schedule_phases = [
{

View File

@ -3,25 +3,10 @@
import functools
import warnings
from collections import defaultdict, OrderedDict
from collections.abc import Hashable, Iterable, Sequence
from copy import deepcopy
from itertools import chain
from typing import (
Any,
Callable,
cast,
DefaultDict,
Dict,
Hashable,
Iterable,
List,
Optional,
overload,
Sequence,
Set,
Tuple,
TypeVar,
Union,
)
from typing import Any, Callable, cast, Optional, overload, TypeVar, Union
from typing_extensions import ParamSpec, Self, TypeAlias
import torch
@ -39,15 +24,15 @@ from torch.utils.hooks import RemovableHandle
_T = TypeVar("_T")
_P = ParamSpec("_P")
Args: TypeAlias = Tuple[Any, ...]
Kwargs: TypeAlias = Dict[str, Any]
StateDict: TypeAlias = Dict[str, Any]
DeviceDict = Dict[Optional[torch.device], torch.Tensor]
DeviceDtypeDict = Dict[Optional[Tuple[torch.device, torch.dtype]], torch.Tensor]
Args: TypeAlias = tuple[Any, ...]
Kwargs: TypeAlias = dict[str, Any]
StateDict: TypeAlias = dict[str, Any]
DeviceDict = dict[Optional[torch.device], torch.Tensor]
DeviceDtypeDict = dict[Optional[tuple[torch.device, torch.dtype]], torch.Tensor]
GlobalOptimizerPreHook: TypeAlias = Callable[
["Optimizer", Args, Kwargs], Optional[Tuple[Args, Kwargs]]
["Optimizer", Args, Kwargs], Optional[tuple[Args, Kwargs]]
]
GlobalOptimizerPostHook: TypeAlias = Callable[["Optimizer", Args, Kwargs], None]
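These module-level aliases are evaluated at import time, so the builtin generics are runtime expressions here as well; TypeAlias still comes from typing_extensions, as in the file. A self-contained sketch of the same pattern, with PreHook and noop_hook as illustrative stand-ins rather than the real optimizer hook types:

from typing import Any, Callable, Optional
from typing_extensions import TypeAlias

Args: TypeAlias = tuple[Any, ...]
Kwargs: TypeAlias = dict[str, Any]
# A pre-hook may return replacement (args, kwargs) or None to leave them unchanged.
PreHook: TypeAlias = Callable[[Args, Kwargs], Optional[tuple[Args, Kwargs]]]

def noop_hook(args: Args, kwargs: Kwargs) -> Optional[tuple[Args, Kwargs]]:
    return None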
@ -56,8 +41,8 @@ __all__ = [
"register_optimizer_step_pre_hook",
"register_optimizer_step_post_hook",
]
_global_optimizer_pre_hooks: Dict[int, GlobalOptimizerPreHook] = OrderedDict()
_global_optimizer_post_hooks: Dict[int, GlobalOptimizerPostHook] = OrderedDict()
_global_optimizer_pre_hooks: dict[int, GlobalOptimizerPreHook] = OrderedDict()
_global_optimizer_post_hooks: dict[int, GlobalOptimizerPostHook] = OrderedDict()
_foreach_supported_types = [torch.Tensor, torch.nn.parameter.Parameter]
@ -173,8 +158,8 @@ def _disable_dynamo_if_unsupported(
# torch.jit.script nor differentiable, so we fall back to the single tensor
# implementation in those cases.
def _default_to_fused_or_foreach(
params: List[torch.Tensor], differentiable: bool, use_fused: bool = False
) -> Tuple[bool, bool]:
params: list[torch.Tensor], differentiable: bool, use_fused: bool = False
) -> tuple[bool, bool]:
if torch.jit.is_scripting() or differentiable:
return False, False
@ -229,7 +214,7 @@ def _get_scalar_dtype(is_fused=None):
)
def _get_capturable_supported_devices(supports_xla: bool = True) -> List[str]:
def _get_capturable_supported_devices(supports_xla: bool = True) -> list[str]:
r"""Return the device type list that supports capturable optimizer."""
capturable_supported_devices = ["cuda", "xpu", "hpu"]
if not torch.jit.is_scripting():
@ -321,7 +306,7 @@ def register_optimizer_step_post_hook(hook: GlobalOptimizerPostHook) -> Removabl
ParamsT: TypeAlias = Union[
Iterable[torch.Tensor], Iterable[Dict[str, Any]], Iterable[Tuple[str, torch.Tensor]]
Iterable[torch.Tensor], Iterable[dict[str, Any]], Iterable[tuple[str, torch.Tensor]]
]
R = TypeVar("R")
@ -343,17 +328,17 @@ class Optimizer:
options (used when a parameter group doesn't specify them).
"""
OptimizerPreHook: TypeAlias = Callable[[Self, Args, Kwargs], Optional[Tuple[Args, Kwargs]]] # type: ignore[misc]
OptimizerPreHook: TypeAlias = Callable[[Self, Args, Kwargs], Optional[tuple[Args, Kwargs]]] # type: ignore[misc]
OptimizerPostHook: TypeAlias = Callable[[Self, Args, Kwargs], None] # type: ignore[misc]
_optimizer_step_pre_hooks: Dict[int, OptimizerPreHook]
_optimizer_step_post_hooks: Dict[int, OptimizerPostHook]
_optimizer_step_pre_hooks: dict[int, OptimizerPreHook]
_optimizer_step_post_hooks: dict[int, OptimizerPostHook]
_optimizer_state_dict_pre_hooks: 'OrderedDict[int, Callable[["Optimizer"], None]]'
_optimizer_state_dict_post_hooks: 'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]'
_optimizer_load_state_dict_pre_hooks: 'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]'
_optimizer_load_state_dict_post_hooks: 'OrderedDict[int, Callable[["Optimizer"], None]]'
def __init__(self, params: ParamsT, defaults: Dict[str, Any]) -> None: # noqa: D107
def __init__(self, params: ParamsT, defaults: dict[str, Any]) -> None: # noqa: D107
torch._C._log_api_usage_once("python.optimizer")
self.defaults = defaults
self._optimizer_step_pre_hooks = OrderedDict()
@ -371,8 +356,8 @@ class Optimizer:
"an iterable of Tensors or dicts, but got " + torch.typename(params)
)
self.state: DefaultDict[torch.Tensor, Any] = defaultdict(dict)
self.param_groups: List[Dict[str, Any]] = []
self.state: defaultdict[torch.Tensor, Any] = defaultdict(dict)
self.param_groups: list[dict[str, Any]] = []
param_groups = list(params)
if len(param_groups) == 0:
@ -388,14 +373,14 @@ class Optimizer:
# https://github.com/pytorch/pytorch/issues/72948
self._warned_capturable_if_run_uncaptured = True
def __getstate__(self) -> Dict[str, Any]: # noqa: D105
def __getstate__(self) -> dict[str, Any]: # noqa: D105
return {
"defaults": self.defaults,
"state": self.state,
"param_groups": self.param_groups,
}
def __setstate__(self, state: Dict[str, Any]) -> None: # noqa: D105
def __setstate__(self, state: dict[str, Any]) -> None: # noqa: D105
self.__dict__.update(state)
if "_optimizer_step_pre_hooks" not in self.__dict__:
self._optimizer_step_pre_hooks = OrderedDict()
@ -516,8 +501,8 @@ class Optimizer:
tensorlistlist: TensorListList,
with_indices: bool = False,
) -> Union[
Dict[Tuple[None, None], Tuple[TensorListList, Indices]],
Dict[Tuple[torch.device, torch.dtype], Tuple[TensorListList, Indices]],
dict[tuple[None, None], tuple[TensorListList, Indices]],
dict[tuple[torch.device, torch.dtype], tuple[TensorListList, Indices]],
]:
"""Group a list of lists of tensors by device and dtype.
@ -705,10 +690,10 @@ class Optimizer:
pre_hook(self)
# Save order indices instead of Tensors
param_mappings: Dict[int, int] = {}
param_mappings: dict[int, int] = {}
start_index = 0
def pack_group(group: Dict[str, Any]) -> Dict[str, Any]:
def pack_group(group: dict[str, Any]) -> dict[str, Any]:
nonlocal start_index
packed = {k: v for k, v in group.items() if k != "params"}
param_mappings.update(
@ -745,7 +730,7 @@ class Optimizer:
param: torch.Tensor,
value: torch.Tensor,
param_id: int,
param_groups: List[Dict[Any, Any]],
param_groups: list[dict[Any, Any]],
key: Hashable = None,
) -> torch.Tensor:
# Floating-point types are a bit special here. They are the only ones
@ -918,7 +903,7 @@ class Optimizer:
# Copy state assigned to params (and cast tensors to appropriate types).
# State that is not assigned to params is copied as is (needed for
# backward compatibility).
state: DefaultDict[torch.Tensor, Dict[Any, Any]] = defaultdict(dict)
state: defaultdict[torch.Tensor, dict[Any, Any]] = defaultdict(dict)
for k, v in state_dict["state"].items():
if k in id_map:
param = id_map[k]
@ -930,8 +915,8 @@ class Optimizer:
# Update parameter groups, setting their 'params' value
def update_group(
group: Dict[str, Any], new_group: Dict[str, Any]
) -> Dict[str, Any]:
group: dict[str, Any], new_group: dict[str, Any]
) -> dict[str, Any]:
new_group["params"] = group["params"]
if "param_names" in group and "param_names" not in new_group:
new_group["param_names"] = group["param_names"]
@ -967,7 +952,7 @@ class Optimizer:
self._patch_step_function()
per_device_and_dtype_grads: Optional[
DefaultDict[torch.device, DefaultDict[torch.dtype, List[torch.Tensor]]]
defaultdict[torch.device, defaultdict[torch.dtype, list[torch.Tensor]]]
]
if foreach:
per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))
@ -1016,7 +1001,7 @@ class Optimizer:
raise NotImplementedError
@torch._disable_dynamo
def add_param_group(self, param_group: Dict[str, Any]) -> None:
def add_param_group(self, param_group: dict[str, Any]) -> None:
r"""Add a param group to the :class:`Optimizer` s `param_groups`.
This can be useful when fine tuning a pre-trained network as frozen layers can be made
@ -1087,7 +1072,7 @@ class Optimizer:
stacklevel=3,
)
param_set: Set[torch.Tensor] = set()
param_set: set[torch.Tensor] = set()
for group in self.param_groups:
param_set.update(set(group["params"]))
if ("param_names" in param_group) != ("param_names" in group):

View File

@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
r"""Implementation for the RMSprop algorithm."""
from typing import cast, List, Optional, Union
from typing import cast, Optional, Union
import torch
from torch import Tensor
@ -155,12 +155,12 @@ class RMSprop(Optimizer): # noqa: D101
loss = closure()
for group in self.param_groups:
params_with_grad: List[Tensor] = []
grads: List[Tensor] = []
square_avgs: List[Tensor] = []
grad_avgs: List[Tensor] = []
momentum_buffer_list: List[Tensor] = []
state_steps: List[Tensor] = []
params_with_grad: list[Tensor] = []
grads: list[Tensor] = []
square_avgs: list[Tensor] = []
grad_avgs: list[Tensor] = []
momentum_buffer_list: list[Tensor] = []
state_steps: list[Tensor] = []
has_complex = self._init_group(
group,
@ -261,12 +261,12 @@ RMSprop.__doc__ = (
def _single_tensor_rmsprop(
params: List[Tensor],
grads: List[Tensor],
square_avgs: List[Tensor],
grad_avgs: List[Tensor],
momentum_buffer_list: List[Tensor],
state_steps: List[Tensor],
params: list[Tensor],
grads: list[Tensor],
square_avgs: list[Tensor],
grad_avgs: list[Tensor],
momentum_buffer_list: list[Tensor],
state_steps: list[Tensor],
*,
lr: float,
alpha: float,
@ -332,12 +332,12 @@ def _single_tensor_rmsprop(
def _multi_tensor_rmsprop(
params: List[Tensor],
grads: List[Tensor],
square_avgs: List[Tensor],
grad_avgs: List[Tensor],
momentum_buffer_list: List[Tensor],
state_steps: List[Tensor],
params: list[Tensor],
grads: list[Tensor],
square_avgs: list[Tensor],
grad_avgs: list[Tensor],
momentum_buffer_list: list[Tensor],
state_steps: list[Tensor],
*,
lr: float,
alpha: float,
@ -377,20 +377,20 @@ def _multi_tensor_rmsprop(
grouped_state_steps_,
)
), _ in grouped_tensors.values():
grouped_params = cast(List[Tensor], grouped_params_)
grouped_grads = cast(List[Tensor], grouped_grads_)
grouped_square_avgs = cast(List[Tensor], grouped_square_avgs_)
grouped_state_steps = cast(List[Tensor], grouped_state_steps_)
grouped_params = cast(list[Tensor], grouped_params_)
grouped_grads = cast(list[Tensor], grouped_grads_)
grouped_square_avgs = cast(list[Tensor], grouped_square_avgs_)
grouped_state_steps = cast(list[Tensor], grouped_state_steps_)
if has_complex:
state_and_grads = [grouped_grads, grouped_square_avgs]
if momentum > 0:
grouped_momentum_buffer_list = cast(
List[Tensor], grouped_momentum_buffer_list_
list[Tensor], grouped_momentum_buffer_list_
)
state_and_grads.append(grouped_momentum_buffer_list)
if centered:
grouped_grad_avgs = cast(List[Tensor], grouped_grad_avgs_)
grouped_grad_avgs = cast(list[Tensor], grouped_grad_avgs_)
state_and_grads.append(grouped_grad_avgs)
_view_as_real(grouped_params, *state_and_grads)
@ -423,7 +423,7 @@ def _multi_tensor_rmsprop(
)
if centered:
grouped_grad_avgs = cast(List[Tensor], grouped_grad_avgs_)
grouped_grad_avgs = cast(list[Tensor], grouped_grad_avgs_)
torch._foreach_lerp_(grouped_grad_avgs, grouped_grads, 1 - alpha)
avg = torch._foreach_addcmul(
grouped_square_avgs, grouped_grad_avgs, grouped_grad_avgs, value=-1
@ -436,7 +436,7 @@ def _multi_tensor_rmsprop(
if momentum > 0:
grouped_momentum_buffer_list = cast(
List[Tensor], grouped_momentum_buffer_list_
list[Tensor], grouped_momentum_buffer_list_
)
torch._foreach_mul_(grouped_momentum_buffer_list, momentum)
torch._foreach_addcdiv_(grouped_momentum_buffer_list, grouped_grads, avg)
@ -461,12 +461,12 @@ def _multi_tensor_rmsprop(
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_rmsprop)
def rmsprop(
params: List[Tensor],
grads: List[Tensor],
square_avgs: List[Tensor],
grad_avgs: List[Tensor],
momentum_buffer_list: List[Tensor],
state_steps: List[Tensor],
params: list[Tensor],
grads: list[Tensor],
square_avgs: list[Tensor],
grad_avgs: list[Tensor],
momentum_buffer_list: list[Tensor],
state_steps: list[Tensor],
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
foreach: Optional[bool] = None,

View File

@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
r"""Implementation for the Resilient backpropagation."""
from typing import cast, List, Optional, Tuple, Union
from typing import cast, Optional, Union
import torch
from torch import Tensor
@ -30,8 +30,8 @@ class Rprop(Optimizer): # noqa: D101
self,
params: ParamsT,
lr: Union[float, Tensor] = 1e-2,
etas: Tuple[float, float] = (0.5, 1.2),
step_sizes: Tuple[float, float] = (1e-6, 50),
etas: tuple[float, float] = (0.5, 1.2),
step_sizes: tuple[float, float] = (1e-6, 50),
*,
capturable: bool = False,
foreach: Optional[bool] = None,
@ -129,11 +129,11 @@ class Rprop(Optimizer): # noqa: D101
loss = closure()
for group in self.param_groups:
params: List[Tensor] = []
grads: List[Tensor] = []
prevs: List[Tensor] = []
step_sizes: List[Tensor] = []
state_steps: List[Tensor] = []
params: list[Tensor] = []
grads: list[Tensor] = []
prevs: list[Tensor] = []
step_sizes: list[Tensor] = []
state_steps: list[Tensor] = []
etaminus, etaplus = group["etas"]
step_size_min, step_size_max = group["step_sizes"]
@ -219,11 +219,11 @@ Rprop.__doc__ = (
def _single_tensor_rprop(
params: List[Tensor],
grads: List[Tensor],
prevs: List[Tensor],
step_sizes: List[Tensor],
state_steps: List[Tensor],
params: list[Tensor],
grads: list[Tensor],
prevs: list[Tensor],
step_sizes: list[Tensor],
state_steps: list[Tensor],
*,
step_size_min: float,
step_size_max: float,
@ -287,11 +287,11 @@ def _single_tensor_rprop(
def _multi_tensor_rprop(
params: List[Tensor],
grads: List[Tensor],
prevs: List[Tensor],
step_sizes: List[Tensor],
state_steps: List[Tensor],
params: list[Tensor],
grads: list[Tensor],
prevs: list[Tensor],
step_sizes: list[Tensor],
state_steps: list[Tensor],
*,
step_size_min: float,
step_size_max: float,
@ -326,11 +326,11 @@ def _multi_tensor_rprop(
grouped_step_sizes_,
grouped_state_steps_,
), _ in grouped_tensors.values():
grouped_params = cast(List[Tensor], grouped_params_)
grouped_grads = cast(List[Tensor], grouped_grads_)
grouped_prevs = cast(List[Tensor], grouped_prevs_)
grouped_step_sizes = cast(List[Tensor], grouped_step_sizes_)
grouped_state_steps = cast(List[Tensor], grouped_state_steps_)
grouped_params = cast(list[Tensor], grouped_params_)
grouped_grads = cast(list[Tensor], grouped_grads_)
grouped_prevs = cast(list[Tensor], grouped_prevs_)
grouped_step_sizes = cast(list[Tensor], grouped_step_sizes_)
grouped_state_steps = cast(list[Tensor], grouped_state_steps_)
# Update steps
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
@ -402,11 +402,11 @@ def _multi_tensor_rprop(
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_rprop)
def rprop(
params: List[Tensor],
grads: List[Tensor],
prevs: List[Tensor],
step_sizes: List[Tensor],
state_steps: List[Tensor],
params: list[Tensor],
grads: list[Tensor],
prevs: list[Tensor],
step_sizes: list[Tensor],
state_steps: list[Tensor],
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
foreach: Optional[bool] = None,

View File

@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
r"""Implementation for Stochastic Gradient Descent optimizer."""
from typing import cast, List, Optional, Union
from typing import cast, Optional, Union
import torch
from torch import Tensor
@ -114,9 +114,9 @@ class SGD(Optimizer): # noqa: D101
loss = closure()
for group in self.param_groups:
params: List[Tensor] = []
grads: List[Tensor] = []
momentum_buffer_list: List[Optional[Tensor]] = []
params: list[Tensor] = []
grads: list[Tensor] = []
momentum_buffer_list: list[Optional[Tensor]] = []
has_sparse_grad = self._init_group(
group, params, grads, momentum_buffer_list
@ -244,9 +244,9 @@ SGD.__doc__ = (
def sgd(
params: List[Tensor],
d_p_list: List[Tensor],
momentum_buffer_list: List[Optional[Tensor]],
params: list[Tensor],
d_p_list: list[Tensor],
momentum_buffer_list: list[Optional[Tensor]],
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
has_sparse_grad: bool = False,
@ -314,9 +314,9 @@ def sgd(
def _single_tensor_sgd(
params: List[Tensor],
grads: List[Tensor],
momentum_buffer_list: List[Optional[Tensor]],
params: list[Tensor],
grads: list[Tensor],
momentum_buffer_list: list[Optional[Tensor]],
grad_scale: Optional[Tensor],
found_inf: Optional[Tensor],
*,
@ -369,9 +369,9 @@ def _single_tensor_sgd(
def _multi_tensor_sgd(
params: List[Tensor],
grads: List[Tensor],
momentum_buffer_list: List[Optional[Tensor]],
params: list[Tensor],
grads: list[Tensor],
momentum_buffer_list: list[Optional[Tensor]],
grad_scale: Optional[Tensor],
found_inf: Optional[Tensor],
*,
@ -397,8 +397,8 @@ def _multi_tensor_sgd(
device_grads_,
device_momentum_buffer_list,
), indices in grouped_tensors.values():
device_params: List[Tensor] = cast(List[Tensor], device_params_)
device_grads: List[Tensor] = cast(List[Tensor], device_grads_)
device_params: list[Tensor] = cast(list[Tensor], device_params_)
device_grads: list[Tensor] = cast(list[Tensor], device_grads_)
device_has_sparse_grad = has_sparse_grad and any(
grad.is_sparse for grad in device_grads
@ -417,7 +417,7 @@ def _multi_tensor_sgd(
)
if momentum != 0:
bufs: List[Tensor] = []
bufs: list[Tensor] = []
all_states_with_momentum_buffer = True
for i in range(len(device_momentum_buffer_list)):
@ -462,9 +462,9 @@ def _multi_tensor_sgd(
def _fused_sgd(
params: List[Tensor],
grads: List[Tensor],
momentum_buffer_list: List[Optional[Tensor]],
params: list[Tensor],
grads: list[Tensor],
momentum_buffer_list: list[Optional[Tensor]],
grad_scale: Optional[Tensor],
found_inf: Optional[Tensor],
*,
@ -501,8 +501,8 @@ def _fused_sgd(
(device_params_, device_grads_, device_momentum_buffer_list),
_,
) in grouped_tensors.items():
device_params: List[Tensor] = cast(List[Tensor], device_params_)
device_grads: List[Tensor] = cast(List[Tensor], device_grads_)
device_params: list[Tensor] = cast(list[Tensor], device_params_)
device_grads: list[Tensor] = cast(list[Tensor], device_grads_)
device_grad_scale, device_found_inf = None, None
if grad_scale is not None:
device_grad_scale = grad_scale_dict.setdefault(
@ -515,7 +515,7 @@ def _fused_sgd(
device_grads,
[]
if no_momentum_buffer
else cast(List[Tensor], device_momentum_buffer_list),
else cast(list[Tensor], device_momentum_buffer_list),
weight_decay=weight_decay,
momentum=momentum,
lr=lr,

View File

@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import List, Tuple, Union
from typing import Union
import torch
from torch import Tensor
@ -16,7 +16,7 @@ class SparseAdam(Optimizer):
self,
params: ParamsT,
lr: Union[float, Tensor] = 1e-3,
betas: Tuple[float, float] = (0.9, 0.999),
betas: tuple[float, float] = (0.9, 0.999),
eps: float = 1e-8,
maximize: bool = False,
):
@ -69,11 +69,11 @@ class SparseAdam(Optimizer):
loss = closure()
for group in self.param_groups:
params_with_grad: List[Tensor] = []
grads: List[Tensor] = []
exp_avgs: List[Tensor] = []
exp_avg_sqs: List[Tensor] = []
state_steps: List[int] = []
params_with_grad: list[Tensor] = []
grads: list[Tensor] = []
exp_avgs: list[Tensor] = []
exp_avg_sqs: list[Tensor] = []
state_steps: list[int] = []
beta1, beta2 = group["betas"]
maximize = group.get("maximize", False)

View File

@ -3,8 +3,9 @@ r"""Implementation for Stochastic Weight Averaging implementation."""
import itertools
import math
import warnings
from collections.abc import Iterable
from copy import deepcopy
from typing import Any, Callable, Iterable, List, Literal, Optional, Tuple, Union
from typing import Any, Callable, Literal, Optional, Union
import torch
from torch import Tensor
@ -28,7 +29,7 @@ __all__ = [
from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype
PARAM_LIST = Union[Tuple[Tensor, ...], List[Tensor]]
PARAM_LIST = Union[tuple[Tensor, ...], list[Tensor]]
def get_ema_multi_avg_fn(decay=0.999):
@ -253,8 +254,8 @@ class AveragedModel(Module):
if self.use_buffers
else model.parameters()
)
self_param_detached: List[Optional[Tensor]] = []
model_param_detached: List[Optional[Tensor]] = []
self_param_detached: list[Optional[Tensor]] = []
model_param_detached: list[Optional[Tensor]] = []
for p_averaged, p_model in zip(self_param, model_param):
p_model_ = p_model.detach().to(p_averaged.device)
self_param_detached.append(p_averaged.detach())

View File

@ -1,6 +1,5 @@
# mypy: allow-untyped-defs
from collections import deque
from typing import List, Set
class DiGraph:
@ -90,7 +89,7 @@ class DiGraph:
except TypeError:
return False
def forward_transitive_closure(self, src: str) -> Set[str]:
def forward_transitive_closure(self, src: str) -> set[str]:
"""Returns a set of nodes that are reachable from src"""
result = set(src)
@ -103,7 +102,7 @@ class DiGraph:
working_set.append(n)
return result
def backward_transitive_closure(self, src: str) -> Set[str]:
def backward_transitive_closure(self, src: str) -> set[str]:
"""Returns a set of nodes that are reachable from src in reverse direction"""
result = set(src)
@ -140,7 +139,7 @@ class DiGraph:
return result_graph.to_dot()
def first_path(self, dst: str) -> List[str]:
def first_path(self, dst: str) -> list[str]:
"""Returns a list of nodes that show the first path that resulted in dst being added to the graph."""
path = []

View File

@ -1,12 +1,10 @@
from typing import Dict, List
from torch.package.package_exporter import PackagingError
__all__ = ["find_first_use_of_broken_modules"]
def find_first_use_of_broken_modules(exc: PackagingError) -> Dict[str, List[str]]:
def find_first_use_of_broken_modules(exc: PackagingError) -> dict[str, list[str]]:
"""
Find all broken modules in a PackagingError, and for each one, return the
dependency path in which the module was first encountered.

View File

@ -1,14 +1,15 @@
# mypy: allow-untyped-defs
import sys
from typing import Any, Callable, Iterable, List, Tuple
from collections.abc import Iterable
from typing import Any, Callable
__all__ = ["trace_dependencies"]
def trace_dependencies(
callable: Callable[[Any], Any], inputs: Iterable[Tuple[Any, ...]]
) -> List[str]:
callable: Callable[[Any], Any], inputs: Iterable[tuple[Any, ...]]
) -> list[str]:
"""Trace the execution of a callable in order to determine which modules it uses.
Args:

View File

@ -1,5 +1,4 @@
# mypy: allow-untyped-defs
from typing import Dict, List
from .glob_group import GlobGroup, GlobPattern
@ -15,9 +14,9 @@ class Directory:
def __init__(self, name: str, is_dir: bool):
self.name = name
self.is_dir = is_dir
self.children: Dict[str, Directory] = {}
self.children: dict[str, Directory] = {}
def _get_dir(self, dirs: List[str]) -> "Directory":
def _get_dir(self, dirs: list[str]) -> "Directory":
"""Builds path of Directories if not yet built and returns last directory
in list.
@ -64,13 +63,13 @@ class Directory:
return False
def __str__(self):
str_list: List[str] = []
str_list: list[str] = []
self._stringify_tree(str_list)
return "".join(str_list)
def _stringify_tree(
self,
str_list: List[str],
str_list: list[str],
preamble: str = "",
dir_ptr: str = "\u2500\u2500\u2500 ",
):
@ -89,8 +88,8 @@ class Directory:
else:
preamble = preamble + space
file_keys: List[str] = []
dir_keys: List[str] = []
file_keys: list[str] = []
dir_keys: list[str] = []
for key, val in self.children.items():
if val.is_dir:
dir_keys.append(key)
@ -109,7 +108,7 @@ class Directory:
def _create_directory_from_file_list(
filename: str,
file_list: List[str],
file_list: list[str],
include: "GlobPattern" = "**",
exclude: "GlobPattern" = (),
) -> Directory:

View File

@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
import ast
from typing import List, Optional, Tuple
from typing import Optional
from ._importlib import _resolve_name
@ -11,7 +11,7 @@ class _ExtractModuleReferences(ast.NodeVisitor):
"""
@classmethod
def run(cls, src: str, package: str) -> List[Tuple[str, Optional[str]]]:
def run(cls, src: str, package: str) -> list[tuple[str, Optional[str]]]:
visitor = cls(package)
tree = ast.parse(src)
visitor.visit(tree)
@ -53,7 +53,7 @@ class _ExtractModuleReferences(ast.NodeVisitor):
if hasattr(node.func, "id") and node.func.id == "__import__":
try:
name = self._grab_node_str(node.args[0])
fromlist: List[str] = []
fromlist: list[str] = []
level = 0
if len(node.args) > 3:
fromlist.extend(self._grab_node_str(v) for v in node.args[3].elts)

View File

@ -1,6 +1,7 @@
# mypy: allow-untyped-defs
import re
from typing import Iterable, Union
from collections.abc import Iterable
from typing import Union
GlobPattern = Union[str, Iterable[str]]
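PEP 585 also deprecates the typing re-exports of the abstract container types, which is why Iterable (and Sequence/Hashable elsewhere in this diff) now comes from collections.abc while Union stays in typing. A small sketch of how the resulting alias is used; as_patterns is a hypothetical helper, not part of torch.package:

from collections.abc import Iterable
from typing import Union

GlobPattern = Union[str, Iterable[str]]

def as_patterns(pattern: GlobPattern) -> list[str]:
    # Normalize a single glob string or an iterable of globs to a list.
    return [pattern] if isinstance(pattern, str) else list(pattern)

print(as_patterns("torch.**"))        # ['torch.**']
print(as_patterns(("a.*", "b.*")))    # ['a.*', 'b.*']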

View File

@ -7,7 +7,7 @@ from pickle import ( # type: ignore[attr-defined]
whichmodule as _pickle_whichmodule,
)
from types import ModuleType
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Optional
from ._mangling import demangle, get_mangle_prefix, is_mangled
@ -44,7 +44,7 @@ class Importer(ABC):
assert obj1 is obj2
"""
modules: Dict[str, ModuleType]
modules: dict[str, ModuleType]
@abstractmethod
def import_module(self, module_name: str) -> ModuleType:
@ -53,7 +53,7 @@ class Importer(ABC):
The contract is the same as for importlib.import_module.
"""
def get_name(self, obj: Any, name: Optional[str] = None) -> Tuple[str, str]:
def get_name(self, obj: Any, name: Optional[str] = None) -> tuple[str, str]:
"""Given an object, return a name that can be used to retrieve the
object from this environment.
@ -184,7 +184,7 @@ class OrderedImporter(Importer):
"""
def __init__(self, *args):
self._importers: List[Importer] = list(args)
self._importers: list[Importer] = list(args)
def _is_torchpackage_dummy(self, module):
"""Returns true iff this module is an empty PackageNode in a torch.package.

View File

@ -7,23 +7,12 @@ import pickletools
import platform
import types
from collections import defaultdict, OrderedDict
from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum
from importlib.machinery import SourceFileLoader
from pathlib import Path
from typing import (
Any,
BinaryIO,
Callable,
cast,
DefaultDict,
Dict,
List,
Optional,
Sequence,
Set,
Union,
)
from typing import Any, BinaryIO, Callable, cast, Optional, Union
import torch
from torch.serialization import location_tag, normalize_storage_type
@ -133,7 +122,7 @@ class PackagingError(Exception):
def __init__(self, dependency_graph: DiGraph, debug=False):
# Group errors by reason.
broken: Dict[PackagingErrorReason, List[str]] = defaultdict(list)
broken: dict[PackagingErrorReason, list[str]] = defaultdict(list)
for module_name, attrs in dependency_graph.nodes.items():
error = attrs.get("error")
if error is None:
@ -236,9 +225,9 @@ class PackageExporter:
self.zip_file = torch._C.PyTorchFileWriter(f)
self.zip_file.set_min_version(6)
self._written_files: Set[str] = set()
self._written_files: set[str] = set()
self.serialized_reduces: Dict[int, Any] = {}
self.serialized_reduces: dict[int, Any] = {}
# A graph tracking all the modules and pickle objects added to this
# package and the dependencies between them.
@ -266,7 +255,7 @@ class PackageExporter:
)
self.importer = OrderedImporter(*importer)
self.patterns: Dict[GlobGroup, _PatternInfo] = {}
self.patterns: dict[GlobGroup, _PatternInfo] = {}
self._unique_id = 0
def save_source_file(
@ -331,7 +320,7 @@ class PackageExporter:
def _get_dependencies(
self, src: str, module_name: str, is_package: bool
) -> List[str]:
) -> list[str]:
"""Return all modules that this source code depends on.
Dependencies are found by scanning the source code for import-like statements.
@ -659,7 +648,7 @@ class PackageExporter:
all_dependencies = []
module = None
field = None
memo: DefaultDict[int, str] = defaultdict(None)
memo: defaultdict[int, str] = defaultdict(None)
memo_count = 0
# pickletools.dis(data_value)
for opcode, arg, _pos in pickletools.genops(data_value):
@ -1115,7 +1104,7 @@ class PackageExporter:
def _nodes_with_action_type(
self, action: Optional[_ModuleProviderAction]
) -> List[str]:
) -> list[str]:
result = []
for name, node_dict in self.dependency_graph.nodes.items():
node_action = node_dict.get("action", None)
@ -1124,7 +1113,7 @@ class PackageExporter:
result.sort()
return result
def externed_modules(self) -> List[str]:
def externed_modules(self) -> list[str]:
"""Return all modules that are currently externed.
Returns:
@ -1133,7 +1122,7 @@ class PackageExporter:
"""
return self._nodes_with_action_type(_ModuleProviderAction.EXTERN)
def interned_modules(self) -> List[str]:
def interned_modules(self) -> list[str]:
"""Return all modules that are currently interned.
Returns:
@ -1142,7 +1131,7 @@ class PackageExporter:
"""
return self._nodes_with_action_type(_ModuleProviderAction.INTERN)
def mocked_modules(self) -> List[str]:
def mocked_modules(self) -> list[str]:
"""Return all modules that are currently mocked.
Returns:
@ -1151,7 +1140,7 @@ class PackageExporter:
"""
return self._nodes_with_action_type(_ModuleProviderAction.MOCK)
def denied_modules(self) -> List[str]:
def denied_modules(self) -> list[str]:
"""Return all modules that are currently denied.
Returns:
@ -1160,7 +1149,7 @@ class PackageExporter:
"""
return self._nodes_with_action_type(_ModuleProviderAction.DENY)
def get_rdeps(self, module_name: str) -> List[str]:
def get_rdeps(self, module_name: str) -> list[str]:
"""Return a list of all modules which depend on the module ``module_name``.
Returns:
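# Sketch of the defaultdict annotation change above: on Python 3.9+
# collections.defaultdict is subscriptable itself, so typing.DefaultDict is
# no longer needed. The grouping mirrors the "group errors by reason" idea
# with made-up keys and module names.
from collections import defaultdict

broken: defaultdict[str, list[str]] = defaultdict(list)
broken["missing dependency"].append("some.module")  # hypothetical module name
broken["missing dependency"].append("another.module")
print(dict(broken))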

View File

@ -8,19 +8,9 @@ import linecache
import os
import sys
import types
from collections.abc import Iterable
from contextlib import contextmanager
from typing import (
Any,
BinaryIO,
Callable,
cast,
Dict,
Iterable,
List,
Optional,
TYPE_CHECKING,
Union,
)
from typing import Any, BinaryIO, Callable, cast, Optional, TYPE_CHECKING, Union
from weakref import WeakValueDictionary
import torch
@ -64,7 +54,7 @@ IMPLICIT_IMPORT_ALLOWLIST: Iterable[str] = [
# The primary motivation is to enable the Numpy upgrade that many modules
# depend on. The latest release of Numpy removed `numpy.str` and
# `numpy.bool`, breaking unpickling for many modules.
EXTERN_IMPORT_COMPAT_NAME_MAPPING: Dict[str, Dict[str, Any]] = {
EXTERN_IMPORT_COMPAT_NAME_MAPPING: dict[str, dict[str, Any]] = {
"numpy": {
"str": str,
"bool": bool,
@ -90,7 +80,7 @@ class PackageImporter(Importer):
local to this importer.
"""
modules: Dict[str, types.ModuleType]
modules: dict[str, types.ModuleType]
def __init__(
self,
@ -646,7 +636,7 @@ class PackageImporter(Importer):
return f"{name.replace('.', '/')}"
def _get_or_create_package(
self, atoms: List[str]
self, atoms: list[str]
) -> "Union[_PackageNode, _ExternNode]":
cur = self.root
for i, atom in enumerate(atoms):
@ -705,7 +695,7 @@ class _PathNode:
class _PackageNode(_PathNode):
def __init__(self, source_file: Optional[str]):
self.source_file = source_file
self.children: Dict[str, _PathNode] = {}
self.children: dict[str, _PathNode] = {}
class _ModuleNode(_PathNode):

View File

@ -4,7 +4,8 @@ import dataclasses
import enum
import itertools as it
import logging
from typing import Any, cast, DefaultDict, Dict, Iterator, List, Optional, Set, Union
from collections.abc import Iterator
from typing import Any, cast, Optional, Union
from typing_extensions import Literal
import torch
@ -226,7 +227,7 @@ class SchemaMatcher:
overload. If we cannot find any valid schema then we must be
conservative and assume all inputs are mutable.
"""
mutable: Optional[List[bool]] = None
mutable: Optional[list[bool]] = None
for schema in cls.match_schemas(t):
mutable = mutable or [False for _ in schema.arguments]
for i, arg in enumerate(schema.arguments):
@ -327,7 +328,7 @@ class OpTree:
class SizeMap:
def __init__(self, op_tree: OpTree) -> None:
self._values: Dict[TensorKey, int] = {}
self._values: dict[TensorKey, int] = {}
for node in op_tree.sorted_nodes:
if node.typed[0] == _EventType.TorchOp:
@ -349,7 +350,7 @@ class SizeMap:
for _, t in state:
self._update_values(t)
allocations: Dict[TensorKey, int] = {}
allocations: dict[TensorKey, int] = {}
for node in op_tree.sorted_nodes:
if node.typed[0] == _EventType.Allocation:
alloc_fields = node.typed[1]
@ -410,7 +411,7 @@ class DataFlowNode:
def __init__(self, event: _ProfilerEvent, graph: "DataFlowGraph") -> None:
self._event = event
self._graph = graph
self._edges: Dict[TensorKey, DataFlowEdge] = self._determine_edges()
self._edges: dict[TensorKey, DataFlowEdge] = self._determine_edges()
for key, edge in self._edges.items():
if edge.mutated and not edge.is_allocation:
@ -420,11 +421,11 @@ class DataFlowNode:
versions = {k: (v, self._graph.lookup(k)) for k, v in self.outputs.items()}
assert all(i == j for i, j in versions.values()), f"{versions}, {self._edges}"
def _determine_edges(self) -> Dict[TensorKey, DataFlowEdge]:
def _determine_edges(self) -> dict[TensorKey, DataFlowEdge]:
subtree = tuple(_utils.traverse_dfs([self._event]))
# Start by populating edges from op inputs and outputs.
mutable_by_key: Dict[Optional[TensorKey], Set[Optional[bool]]] = {}
mutable_by_key: dict[Optional[TensorKey], set[Optional[bool]]] = {}
for op in (i.typed[1] for i in subtree if i.typed[0] == _EventType.TorchOp):
for op_input, mutable in zip(
op.inputs, SchemaMatcher.inputs_are_mutable(op)
@ -440,7 +441,7 @@ class DataFlowNode:
key = TensorKey.from_tensor(op_input_i)
mutable_by_key.setdefault(key, set()).add(mutable)
edges: DefaultDict[Optional[TensorKey], DataFlowEdge]
edges: collections.defaultdict[Optional[TensorKey], DataFlowEdge]
edges = collections.defaultdict(DataFlowEdge)
for key, mutable_set in mutable_by_key.items():
if key is not None:
@ -472,7 +473,7 @@ class DataFlowNode:
return dict(sorted((k, v) for k, v in edges.items() if k is not None))
@property
def inputs(self) -> Dict[TensorKey, tuple[bool, int]]:
def inputs(self) -> dict[TensorKey, tuple[bool, int]]:
return {
# MyPy can't see through `is_allocation` to know that
# `v.input_version` is not None.
@ -482,7 +483,7 @@ class DataFlowNode:
}
@property
def outputs(self) -> Dict[TensorKey, int]:
def outputs(self) -> dict[TensorKey, int]:
return {
k: 0 if v.input_version is None else v.input_version + 1
for k, v in self._edges.items()
@ -504,7 +505,7 @@ class DataFlowGraph:
def __init__(self, op_tree: OpTree) -> None:
self._op_tree = op_tree
self._leaf_events = self._extract_leaf_events(op_tree)
self._active_version: Dict[TensorKey, Optional[int]] = {}
self._active_version: dict[TensorKey, Optional[int]] = {}
self._flow_nodes = [DataFlowNode(e, self) for e in self.leaf_events]
self._flow_nodes.sort(key=lambda x: x.start_time)
self.validate()
@ -515,7 +516,7 @@ class DataFlowGraph:
def validate(self):
# Check that each (Tensor, version) pair has a unique creation node
outputs: Set[tuple[TensorKey, int]] = set()
outputs: set[tuple[TensorKey, int]] = set()
for node in self.flow_nodes:
node_outputs = set(node.outputs.items())
duplicates = outputs & node_outputs
@ -523,7 +524,7 @@ class DataFlowGraph:
outputs |= node_outputs
# And check that `self._nodes` forms a valid topologically sorted DAG.
tensor_versions: Dict[TensorKey, int] = {}
tensor_versions: dict[TensorKey, int] = {}
for node in self.flow_nodes:
for key, (_, version) in node.inputs.items():
expected = tensor_versions.get(key, 0)
@ -571,7 +572,7 @@ class DataFlowGraph:
form the logical nodes in our data flow graph.
"""
leaf_events: List[_ProfilerEvent] = []
leaf_events: list[_ProfilerEvent] = []
def leaf_op(e: _ProfilerEvent) -> bool:
return e.typed[0] == _EventType.TorchOp and (
@ -609,17 +610,17 @@ class DataFlowGraph:
@dataclasses.dataclass
class CategoryElement:
by_id: Optional[Category] = None
by_key: Dict[TensorKey, Category] = dataclasses.field(default_factory=dict)
by_version: Dict[TensorAndID, Category] = dataclasses.field(default_factory=dict)
by_key: dict[TensorKey, Category] = dataclasses.field(default_factory=dict)
by_version: dict[TensorAndID, Category] = dataclasses.field(default_factory=dict)
# Used by unit tests to check internals. (And consequently by
# MemoryProfile.lookup) This should not be used in any other capacity.
_by_id_keyset: Set[TensorKey] = dataclasses.field(default_factory=set)
_by_id_keyset: set[TensorKey] = dataclasses.field(default_factory=set)
@dataclasses.dataclass
class CategoryDict:
_values: DefaultDict[int, CategoryElement] = dataclasses.field(
_values: collections.defaultdict[int, CategoryElement] = dataclasses.field(
default_factory=lambda: collections.defaultdict(CategoryElement)
)
@ -666,9 +667,9 @@ class MemoryProfile:
@property
def timeline(self) -> tuple[tuple[int, Action, KeyAndID, int], ...]:
output: List[tuple[int, Action, KeyAndID, int]] = []
allocation_times: Dict[tuple[TensorKey, bool], int] = {}
live_unknown: Dict[tuple[int, torch.device], Literal[True]] = {}
output: list[tuple[int, Action, KeyAndID, int]] = []
allocation_times: dict[tuple[TensorKey, bool], int] = {}
live_unknown: dict[tuple[int, torch.device], Literal[True]] = {}
for event in self._op_tree.dfs():
if event.typed[0] == _EventType.Allocation:
alloc_fields = event.typed[1]
@ -701,7 +702,7 @@ class MemoryProfile:
snapshot = self._category_snapshot()
last_version = dict(sorted(snapshot.keys()))
events: List[tuple[int, Action, TensorAndID]] = [
events: list[tuple[int, Action, TensorAndID]] = [
(-1, Action.PREEXISTING, (key, version))
for key, version in snapshot.keys()
if (key, True) not in allocation_times and version == 0
@ -734,8 +735,8 @@ class MemoryProfile:
def _is_gradient(self, *args, **kwargs) -> bool:
return self._categories.get(*args, **kwargs) == Category.GRADIENT
def _category_snapshot(self) -> Dict[TensorAndID, Optional[Category]]:
all_tensor_versions: Set[TensorAndID] = set()
def _category_snapshot(self) -> dict[TensorAndID, Optional[Category]]:
all_tensor_versions: set[TensorAndID] = set()
for node in self._data_flow_graph.flow_nodes:
all_tensor_versions.update(((k, v) for k, (_, v) in node.inputs.items()))
@ -750,7 +751,7 @@ class MemoryProfile:
for key, version in sorted(all_tensor_versions)
}
def _any_version_depends_on_gradient(self) -> Set[int]:
def _any_version_depends_on_gradient(self) -> set[int]:
"""Extract IDs of Tensors which depend or will depend on a gradient.
Note that this weakened definition of "depends" requires us to loop
@ -761,7 +762,7 @@ class MemoryProfile:
acyclic data flow graph into a cyclic graph and we are attempting to
partition cycles involving a gradient from the rest of the graph.
"""
depends_on_gradient: Set[int] = set()
depends_on_gradient: set[int] = set()
while True:
start_size = len(depends_on_gradient)
for node in self._data_flow_graph.flow_nodes:
@ -837,7 +838,7 @@ class MemoryProfile:
# We only want to annotate Tensors which actually contribute to the
# model calculation.
produces_gradient: Set[TensorAndID] = set()
produces_gradient: set[TensorAndID] = set()
for node in reversed(self._data_flow_graph.flow_nodes):
tensors = {(key, version) for key, (_, version) in node.inputs.items()}
tensors |= node.outputs.items()
@ -894,8 +895,8 @@ class MemoryProfile:
# data flow. Note that these are only candidates; we filter nodes that
# we know are part of the backward pass but that doesn't guarantee that
# they are part of the forward pass.
candidate_parameters: Set[TensorAndID] = set()
candidate_fwd_tensors: Set[TensorAndID] = {
candidate_parameters: set[TensorAndID] = set()
candidate_fwd_tensors: set[TensorAndID] = {
i for i, category in snapshot.items() if category == Category.INPUT
}
@ -914,7 +915,7 @@ class MemoryProfile:
candidate_parameters |= inputs.difference(candidate_fwd_tensors)
# Require that each parameter eventually contributes to the value of a gradient
used_for_gradient: Set[TensorAndID] = set()
used_for_gradient: set[TensorAndID] = set()
for node in reversed(self._data_flow_graph.flow_nodes):
if any(
self._is_gradient(*i) or i in used_for_gradient
@ -993,8 +994,8 @@ class MemoryProfileTimeline:
Output: [timestamps, sizes by category]
"""
device = torch.device(device_str)
times: List[int] = []
sizes: List[List[int]] = []
times: list[int] = []
sizes: list[list[int]] = []
def update(key, version, delta):
category = (
@ -1061,7 +1062,7 @@ class MemoryProfileTimeline:
as a JSON formatted file to the given path for the given
device."""
device = torch.device(device_str)
raw_events: List[tuple[int, int, int, int]] = []
raw_events: list[tuple[int, int, int, int]] = []
def get_category_index(key, version):
category = (
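# Sketch of the dataclass-field pattern above, with builtin generics in the
# field annotations instead of typing.Dict/typing.Set. Plain str keys stand
# in for the profiler's TensorKey type; the class is a stand-in, not the
# profiler's CategoryElement.
import dataclasses


@dataclasses.dataclass
class Element:
    by_key: dict[str, int] = dataclasses.field(default_factory=dict)
    seen: set[str] = dataclasses.field(default_factory=set)


elem = Element()
elem.by_key["gradient"] = 1
elem.seen.add("gradient")
print(elem)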

View File

@ -3,7 +3,7 @@ import json
import math
import os
import re
from typing import Dict, List, Optional, Set
from typing import Optional
import torch
import torch.utils.benchmark as benchmark
@ -34,7 +34,7 @@ class Pattern:
self.url = ""
assert prof.profiler is not None and prof.profiler.kineto_results is not None
self.event_tree = prof.profiler.kineto_results.experimental_event_tree()
self.tid_root: Dict[int, List[_ProfilerEvent]] = {}
self.tid_root: dict[int, list[_ProfilerEvent]] = {}
for event in self.event_tree:
self.tid_root.setdefault(event.start_tid, []).append(event)
@ -55,7 +55,7 @@ class Pattern:
"""
yield from traverse_dfs(self.event_tree)
def summary(self, events: List[_ProfilerEvent]):
def summary(self, events: list[_ProfilerEvent]):
default_summary = f"{self.name}: {len(events)} events matched."
if self.should_benchmark:
# If benchmark summary is not empty, use it.
@ -66,7 +66,7 @@ class Pattern:
)
return default_summary
def benchmark_summary(self, events: List[_ProfilerEvent]):
def benchmark_summary(self, events: list[_ProfilerEvent]):
def format_time(time_ns: int):
unit_lst = ["ns", "us", "ms"]
for unit in unit_lst:
@ -215,7 +215,7 @@ class ExtraCUDACopyPattern(Pattern):
return event.name in self.init_ops
# TODO: Check if tensor is reused
def benchmark(self, events: List[_ProfilerEvent]):
def benchmark(self, events: list[_ProfilerEvent]):
shapes_factor_map = {input_shapes(event): 0.0 for event in events}
for shape in shapes_factor_map:
size = shape[0]
@ -252,7 +252,7 @@ class ForLoopIndexingPattern(Pattern):
super().__init__(prof, should_benchmark)
self.name = "For Loop Indexing Pattern"
self.description = "For loop indexing detected. Vectorization recommended."
self.visited: Set[int] = set()
self.visited: set[int] = set()
def eventTreeTraversal(self):
"""
@ -326,7 +326,7 @@ class FP32MatMulPattern(Pattern):
def report(self, event: _ProfilerEvent):
return self.description
def benchmark(self, events: List[_ProfilerEvent]):
def benchmark(self, events: list[_ProfilerEvent]):
shapes_factor_map = {input_shapes(event): 0.0 for event in events}
for shape in shapes_factor_map:
matrixA = torch.randn(shape[0], device="cuda", dtype=torch.float32)
@ -553,7 +553,7 @@ class MatMulDimInFP16Pattern(Pattern):
return True
return False
def benchmark(self, events: List[_ProfilerEvent]):
def benchmark(self, events: list[_ProfilerEvent]):
def closest_multiple(shapes, multiple):
return [multiple * math.ceil(shape / multiple) for shape in shapes]
@ -609,7 +609,7 @@ def report_all_anti_patterns(
print_enable: bool = True,
json_report_dir: Optional[str] = None,
):
report_dict: Dict = {}
report_dict: dict = {}
anti_patterns = [
ExtraCUDACopyPattern(prof, should_benchmark),
# ForLoopIndexingPattern(prof, should_benchmark),
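# Sketch of the tid_root-style grouping above, annotated with the builtin
# dict/list generics; the (tid, name) event tuples are made up.
events = [(1, "aten::mm"), (2, "aten::add"), (1, "aten::relu")]
by_tid: dict[int, list[str]] = {}
for tid, name in events:
    by_tid.setdefault(tid, []).append(name)
print(by_tid)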

View File

@ -4,7 +4,7 @@ import operator
import re
from collections import deque
from dataclasses import dataclass
from typing import Dict, List, TYPE_CHECKING
from typing import TYPE_CHECKING
from torch.autograd.profiler import profile
from torch.profiler import DeviceType
@ -64,7 +64,7 @@ class EventKey:
def __repr__(self):
return f"{self.event.name}"
def intervals_overlap(self, intervals: List[Interval]):
def intervals_overlap(self, intervals: list[Interval]):
overlap_time = 0
intervals = sorted(intervals, key=lambda x: x.start)
@ -100,13 +100,13 @@ class EventKey:
class BasicEvaluation:
def __init__(self, prof: profile):
self.profile = prof
self.metrics: Dict[EventKey, EventMetrics] = {}
self.metrics: dict[EventKey, EventMetrics] = {}
self.compute_self_time()
self.event_keys = sorted(
(e for e in self.metrics.keys()), key=lambda x: x.event.start_time_ns
)
self.events = [e.event for e in self.event_keys]
self.cuda_events: List[_KinetoEvent] = []
self.cuda_events: list[_KinetoEvent] = []
self.queue_depth_list = self.compute_queue_depth()
self.compute_idle_time()
@ -162,7 +162,7 @@ class BasicEvaluation:
cuda_launch_events + cuda_kernel_events, key=lambda x: x.start_ns()
)
kernel_mapping: Dict[_KinetoEvent, int] = {}
kernel_mapping: dict[_KinetoEvent, int] = {}
last_mapped_kernel = 0
for cuda_launch_event in cuda_launch_events:
index = index_of_first_match(
@ -188,7 +188,7 @@ class BasicEvaluation:
return event.start_time_ns
raise Exception("Unknown Event Type") # noqa: TRY002
queue_depth_list: List[Interval] = []
queue_depth_list: list[Interval] = []
all_events.sort(key=new_old_event_comparator)
for event in all_events:
# Find latest cuda kernel event
@ -233,7 +233,7 @@ class BasicEvaluation:
# Based on queue_depth_list, we can calculate idle time for all the events
idle = False
idle_start = 0
idle_intervals: List[Interval] = []
idle_intervals: list[Interval] = []
if self.queue_depth_list and self.events:
idle_intervals += [
Interval(self.events[0].start_time_ns, self.queue_depth_list[0].start),

View File

@ -5,9 +5,10 @@ import os
import shutil
import tempfile
from abc import ABC, abstractmethod
from collections.abc import Iterable
from enum import Enum
from functools import partial
from typing import Any, Callable, Dict, Iterable, List, Optional
from typing import Any, Callable, Optional
from typing_extensions import Self
from warnings import warn
@ -167,7 +168,7 @@ class _KinetoProfile:
self.use_device = _get_privateuse1_backend_name()
# user-defined metadata to be amended to the trace
self.preset_metadata: Dict[str, str] = {}
self.preset_metadata: dict[str, str] = {}
def start(self):
self.prepare_trace()
@ -723,8 +724,8 @@ class profile(_KinetoProfile):
self.current_action = self.schedule(self.step_num)
self.step_rec_fn: Optional[prof.record_function] = None
self.action_map: Dict[
tuple[ProfilerAction, Optional[ProfilerAction]], List[Any]
self.action_map: dict[
tuple[ProfilerAction, Optional[ProfilerAction]], list[Any]
] = {
# key is (prev_action, current_action), value is action list corresponding to the state pair.
(ProfilerAction.NONE, ProfilerAction.NONE): [],
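# Sketch of the action_map idea above: a transition table keyed by
# (previous, current) state tuples, annotated with builtin generics. The
# states and actions are stand-ins, not the profiler's ProfilerAction.
from enum import Enum
from typing import Callable, Optional


class State(Enum):
    NONE = 0
    WARMUP = 1
    RECORD = 2


action_map: dict[tuple[State, Optional[State]], list[Callable[[], None]]] = {
    (State.NONE, State.WARMUP): [lambda: print("start warmup")],
    (State.WARMUP, State.RECORD): [lambda: print("start recording")],
}

for action in action_map[(State.NONE, State.WARMUP)]:
    action()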

View File

@ -1,12 +1,11 @@
import os
import site
import sys
import typing
import torch
def _prefix_regex() -> typing.List[str]:
def _prefix_regex() -> list[str]:
raw_paths = (
site.getsitepackages()
+ sys.path

View File

@ -15,19 +15,7 @@ import threading
import warnings
from contextlib import closing, contextmanager
from enum import Enum
from typing import (
Any,
BinaryIO,
Callable,
cast,
Dict,
IO,
List,
Optional,
Tuple,
Type,
Union,
)
from typing import Any, BinaryIO, Callable, cast, IO, Optional, Union
from typing_extensions import TypeAlias, TypeIs
import torch
@ -79,7 +67,7 @@ STORAGE_KEY_SEPARATOR = ","
FILE_LIKE: TypeAlias = Union[str, os.PathLike, BinaryIO, IO[bytes]]
MAP_LOCATION: TypeAlias = Optional[
Union[Callable[[Storage, str], Storage], torch.device, str, Dict[str, str]]
Union[Callable[[Storage, str], Storage], torch.device, str, dict[str, str]]
]
STORAGE: TypeAlias = Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage]
@ -132,8 +120,8 @@ def mkdtemp():
shutil.rmtree(path)
_package_registry: List[
Tuple[
_package_registry: list[
tuple[
int,
Callable[[STORAGE], Optional[str]],
Callable[[STORAGE, str], Optional[STORAGE]],
@ -270,14 +258,14 @@ def clear_safe_globals() -> None:
_weights_only_unpickler._clear_safe_globals()
def get_safe_globals() -> List[Union[Callable, Tuple[Callable, str]]]:
def get_safe_globals() -> list[Union[Callable, tuple[Callable, str]]]:
"""
Returns the list of user-added globals that are safe for ``weights_only`` load.
"""
return _weights_only_unpickler._get_safe_globals()
def add_safe_globals(safe_globals: List[Union[Callable, Tuple[Callable, str]]]) -> None:
def add_safe_globals(safe_globals: list[Union[Callable, tuple[Callable, str]]]) -> None:
"""
Marks the given globals as safe for ``weights_only`` load. For example, functions
added to this list can be called during unpickling, classes could be instantiated
@ -338,7 +326,7 @@ class safe_globals(_weights_only_unpickler._safe_globals):
"""
def get_unsafe_globals_in_checkpoint(f: FILE_LIKE) -> List[str]:
def get_unsafe_globals_in_checkpoint(f: FILE_LIKE) -> list[str]:
"""Returns a list of strings of functions/classes in a ``torch.save`` object that are not safe for ``weights_only``.
For a given function or class ``f``, the corresponding string will be of the form
@ -804,7 +792,7 @@ class _open_zipfile_writer_buffer(_opener):
def _open_zipfile_writer(name_or_buffer):
container: Type[_opener]
container: type[_opener]
if _is_path(name_or_buffer):
container = _open_zipfile_writer_file
else:
@ -965,15 +953,15 @@ def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
import torch.nn as nn
serialized_container_types = {}
serialized_storages: Dict[str, Tuple[torch.UntypedStorage, torch.dtype]] = {}
serialized_storages: dict[str, tuple[torch.UntypedStorage, torch.dtype]] = {}
# Since loading storages that view the same data with different dtypes is
# not supported, we need to keep track of the dtype associated with each
# storage data_ptr and throw an error if the dtype is ever different.
# TODO: This feature could be added in the future
storage_dtypes: Dict[int, torch.dtype] = {}
storage_dtypes: dict[int, torch.dtype] = {}
def persistent_id(obj: Any) -> Optional[Tuple]:
def persistent_id(obj: Any) -> Optional[tuple]:
# FIXME: the docs say that persistent_id should only return a string
# but torch store returns tuples. This works only in the binary protocol
# see
@ -1032,7 +1020,7 @@ def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
else:
storage_dtypes[storage.data_ptr()] = storage_dtype
view_metadata: Optional[Tuple[str, int, int]]
view_metadata: Optional[tuple[str, int, int]]
# Offset is always 0, but we keep it for backwards compatibility
# with the old serialization format (which supported storage views)
@ -1127,13 +1115,13 @@ def _save(
_disable_byteorder_record,
):
serialized_storages = {}
id_map: Dict[int, str] = {}
id_map: dict[int, str] = {}
# Since loading storages that view the same data with different dtypes is
# not supported, we need to keep track of the dtype associated with each
# storage data_ptr and throw an error if the dtype is ever different.
# TODO: This feature could be added in the future
storage_dtypes: Dict[int, torch.dtype] = {}
storage_dtypes: dict[int, torch.dtype] = {}
def persistent_id(obj):
# FIXME: the docs say that persistent_id should only return a string
@ -1534,7 +1522,7 @@ copyreg.pickle(torch.layout, lambda obj: (_get_layout, (str(obj),)))
def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
deserialized_objects: Dict[int, Any] = {}
deserialized_objects: dict[int, Any] = {}
restore_location = _get_restore_location(map_location)
@ -1599,7 +1587,7 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
warnings.warn(msg, SourceChangeWarning)
def legacy_load(f):
deserialized_objects: Dict[int, Any] = {}
deserialized_objects: dict[int, Any] = {}
def persistent_load(saved_id):
if isinstance(saved_id, tuple):
@ -1950,7 +1938,7 @@ def _load(
return typed_storage
load_module_mapping: Dict[str, str] = {
load_module_mapping: dict[str, str] = {
# See https://github.com/pytorch/pytorch/pull/51633
"torch.tensor": "torch._tensor"
}
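# Sketch of the MAP_LOCATION alias above: map_location may be a plain
# dict[str, str] that remaps storage location tags at load time. This
# round-trips through an in-memory buffer so it runs without a GPU.
import io

import torch

buffer = io.BytesIO()
torch.save(torch.ones(2), buffer)
buffer.seek(0)
tensor = torch.load(buffer, map_location={"cuda:0": "cpu"})
print(tensor.device)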

View File

@ -19,11 +19,11 @@ from .semi_structured import (
if TYPE_CHECKING:
from torch.types import _dtype as DType
DimOrDims = Optional[Union[int, Tuple[int, ...], List[int]]]
DimOrDims = Optional[Union[int, tuple[int, ...], list[int]]]
else:
# The JIT doesn't understand Union, nor torch.dtype here
DType = int
DimOrDims = Optional[Tuple[int]]
DimOrDims = Optional[tuple[int]]
__all__ = [
@ -591,7 +591,7 @@ def as_sparse_gradcheck(gradcheck):
"""Convert differentiable non-strided tensors to a representation containing differentiable strided tensors."""
if not isinstance(args, (list, tuple)):
args = (args,)
new_args: List[Any] = []
new_args: list[Any] = []
for obj in args:
if (
isinstance(obj, torch.Tensor)

View File

@ -4,7 +4,7 @@ import math
import os
import weakref
from functools import lru_cache
from typing import Optional, Tuple
from typing import Optional
import torch
from torch._dynamo.utils import warn_once
@ -1124,7 +1124,7 @@ def _int_bsr_dense_addmm(
right_alpha: Optional[torch.Tensor] = None,
out: Optional[torch.Tensor] = None,
skip_checks: bool = False,
max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None,
max_grid: Optional[tuple[Optional[int], Optional[int], Optional[int]]] = None,
meta: Optional[dict] = None,
):
if out is None and dense.dtype is torch.int8:
@ -1165,7 +1165,7 @@ def bsr_dense_addmm(
right_alpha: Optional[torch.Tensor] = None,
out: Optional[torch.Tensor] = None,
skip_checks: bool = False,
max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None,
max_grid: Optional[tuple[Optional[int], Optional[int], Optional[int]]] = None,
meta: Optional[dict] = None,
):
"""Compute
@ -1647,7 +1647,7 @@ if has_triton():
alpha=1.0,
out: Optional[torch.Tensor] = None,
skip_checks: bool = False,
max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None,
max_grid: Optional[tuple[Optional[int], Optional[int], Optional[int]]] = None,
):
f_name = "sampled_addmm"
@ -1731,7 +1731,7 @@ if has_triton():
*,
out: Optional[torch.Tensor] = None,
skip_checks: bool = False,
max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None,
max_grid: Optional[tuple[Optional[int], Optional[int], Optional[int]]] = None,
meta: Optional[dict] = None,
):
f_name = "bsr_dense_mm"

View File

@ -103,7 +103,7 @@ import inspect
import itertools
import re
import warnings
from typing import Any, Dict
from typing import Any
import torch
from torch.hub import tqdm
@ -937,7 +937,7 @@ def main(op="scatter_mm", force=False, dtype=torch.float16, verbose=True):
dump()
_operation_device_version_data: Dict[Any, Dict] = {
_operation_device_version_data: dict[Any, dict] = {
# Warning: the data in between the BEGIN/END DATA comment lines
# below is generated. It can be updated either manually or via
# calling dump function defined above.

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
import warnings
from collections import namedtuple
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Optional
import torch
from torch.sparse._semi_structured_conversions import (
@ -54,13 +54,13 @@ class SparseSemiStructuredTensor(torch.Tensor):
"""
_DEFAULT_ALG_ID: int = 0
_DTYPE_SHAPE_CONSTRAINTS: Dict[torch.dtype, _SEMI_STRUCTURED_SPARSE_CONFIG]
_DTYPE_SHAPE_CONSTRAINTS: dict[torch.dtype, _SEMI_STRUCTURED_SPARSE_CONFIG]
_FORCE_CUTLASS: bool = False
_FUSE_TRANSPOSE: bool = False
_PROTOTYPE_WARNING_SHOWN: bool = False
BACKEND: str
SPARSE_DISPATCH: Dict[Callable, Callable]
SPARSE_DISPATCH: dict[Callable, Callable]
packed: Optional[torch.Tensor]
meta: Optional[torch.Tensor]
@ -161,7 +161,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
def __tensor_flatten__(
self,
) -> Tuple[List[str], Tuple[torch.Size, bool, int, bool]]:
) -> tuple[list[str], tuple[torch.Size, bool, int, bool]]:
inner_tensors = list(
filter(lambda x: getattr(self, x) is not None, self.__slots__)
)
@ -177,7 +177,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
def __tensor_unflatten__(
cls,
inner_tensors,
tensor_meta: Tuple[torch.Size, bool, int, bool],
tensor_meta: tuple[torch.Size, bool, int, bool],
outer_size,
outer_stride,
) -> torch.Tensor:

View File

@ -23,13 +23,13 @@ from .streams import Event, Stream
_initialized = False
_tls = threading.local()
_initialization_lock = threading.Lock()
_queued_calls: List[
Tuple[Callable[[], None], List[str]]
_queued_calls: list[
tuple[Callable[[], None], list[str]]
] = [] # don't invoke these until initialization occurs
_is_in_bad_fork = getattr(torch._C, "_xpu_isInBadFork", lambda: False)
_device_t = Union[_device, str, int, None]
_lazy_seed_tracker = _LazySeedTracker()
default_generators: Tuple[torch._C.Generator] = () # type: ignore[assignment]
default_generators: tuple[torch._C.Generator] = () # type: ignore[assignment]
def _is_compiled() -> bool:
@ -216,7 +216,7 @@ def get_device_name(device: Optional[_device_t] = None) -> str:
@lru_cache(None)
def get_device_capability(device: Optional[_device_t] = None) -> Dict[str, Any]:
def get_device_capability(device: Optional[_device_t] = None) -> dict[str, Any]:
r"""Get the xpu capability of a device.
Args:
@ -418,7 +418,7 @@ def synchronize(device: _device_t = None) -> None:
return torch._C._xpu_synchronize(device)
def get_arch_list() -> List[str]:
def get_arch_list() -> list[str]:
r"""Return list XPU architectures this library was compiled for."""
if not is_available():
return []

View File

@ -1,5 +1,5 @@
import collections
from typing import Any, Dict, Tuple, Union
from typing import Any, Union
import torch
from torch.types import Device
@ -53,7 +53,7 @@ def reset_accumulated_memory_stats(device: _device_t = None) -> None:
return torch._C._xpu_resetAccumulatedMemoryStats(device)
def memory_stats_as_nested_dict(device: _device_t = None) -> Dict[str, Any]:
def memory_stats_as_nested_dict(device: _device_t = None) -> dict[str, Any]:
r"""Return the result of :func:`~torch.xpu.memory_stats` as a nested dictionary."""
if not is_initialized():
return {}
@ -61,7 +61,7 @@ def memory_stats_as_nested_dict(device: _device_t = None) -> Dict[str, Any]:
return torch._C._xpu_memoryStats(device)
def memory_stats(device: _device_t = None) -> Dict[str, Any]:
def memory_stats(device: _device_t = None) -> dict[str, Any]:
r"""Return a dictionary of XPU memory allocator statistics for a given device.
The return value of this function is a dictionary of statistics, each of
@ -178,7 +178,7 @@ def max_memory_reserved(device: _device_t = None) -> int:
return memory_stats(device=device).get("reserved_bytes.all.peak", 0)
def mem_get_info(device: _device_t = None) -> Tuple[int, int]:
def mem_get_info(device: _device_t = None) -> tuple[int, int]:
r"""Return the global free and total GPU memory for a given device.
Args:

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
from typing import Iterable, List, Union
from collections.abc import Iterable
from typing import Union
import torch
from torch import Tensor
@ -29,7 +30,7 @@ def get_rng_state(device: Union[int, str, torch.device] = "xpu") -> Tensor:
return default_generator.get_state()
def get_rng_state_all() -> List[Tensor]:
def get_rng_state_all() -> list[Tensor]:
r"""Return a list of ByteTensor representing the random number states of all devices."""
results = [get_rng_state(i) for i in range(device_count())]
return results
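# Sketch of the list[Tensor]-returning accessor at the end of the diff,
# written against the CPU default generator so it runs without an XPU device;
# collect_states is a hypothetical helper, not a torch API.
import torch


def collect_states(n: int) -> list[torch.Tensor]:
    # Gather n copies of the default CPU generator state (a uint8 tensor).
    return [torch.get_rng_state() for _ in range(n)]


states = collect_states(2)
print(len(states), states[0].dtype)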