Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
PEP585 update - torch/nn torch/optim torch/package torch/profiler torch/serialization torch/sparse torch/xpu (#145175)
See #145101 for details.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145175
Approved by: https://github.com/bobrenjc93
Committed by: PyTorch MergeBot
Parent: bd97ce0b45
Commit: 54a00af2c6
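What the rewrite looks like: PEP 585 (Python 3.9+) lets the builtin containers be used directly as generic types, so annotations spelled with the deprecated typing aliases (List, Dict, Tuple, Set, DefaultDict) become list, dict, tuple, set, and collections.defaultdict, while Optional, Union, Callable and friends keep coming from typing and the ABCs (Iterable, Sequence, Hashable) move to collections.abc. The sketch below is illustrative only; the helper name is made up and nothing in it is copied from the diff. Runtime behavior is unchanged, since only annotations are touched.

# Before (pre-PEP 585): container generics imported from typing
from typing import Dict, List, Optional, Tuple

def summarize_lengths_old(batches: List[List[int]], label: Optional[str] = None) -> Tuple[Dict[str, int], int]:
    counts: Dict[str, int] = {label or "total": sum(len(b) for b in batches)}
    return counts, len(batches)

# After (PEP 585): builtin generics; Optional still comes from typing
from typing import Optional

def summarize_lengths(batches: list[list[int]], label: Optional[str] = None) -> tuple[dict[str, int], int]:
    counts: dict[str, int] = {label or "total": sum(len(b) for b in batches)}
    return counts, len(batches)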
@@ -3,7 +3,7 @@
import importlib
import math
import warnings
- from typing import Callable, List, Optional, TYPE_CHECKING, Union
+ from typing import Callable, Optional, TYPE_CHECKING, Union

import torch
from torch import _VF, sym_int as _sym_int, Tensor
@@ -708,7 +708,7 @@ def max_pool1d_with_indices(
return_indices=return_indices,
)
if stride is None:
- stride = torch.jit.annotate(List[int], [])
+ stride = torch.jit.annotate(list[int], [])
return torch.max_pool1d_with_indices(
input, kernel_size, stride, padding, dilation, ceil_mode
)
@@ -736,7 +736,7 @@ def _max_pool1d(
return_indices=return_indices,
)
if stride is None:
- stride = torch.jit.annotate(List[int], [])
+ stride = torch.jit.annotate(list[int], [])
return torch.max_pool1d(input, kernel_size, stride, padding, dilation, ceil_mode)


@@ -798,7 +798,7 @@ def max_pool2d_with_indices(
return_indices=return_indices,
)
if stride is None:
- stride = torch.jit.annotate(List[int], [])
+ stride = torch.jit.annotate(list[int], [])
return torch._C._nn.max_pool2d_with_indices(
input, kernel_size, stride, padding, dilation, ceil_mode
)
@@ -826,7 +826,7 @@ def _max_pool2d(
return_indices=return_indices,
)
if stride is None:
- stride = torch.jit.annotate(List[int], [])
+ stride = torch.jit.annotate(list[int], [])
return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


@@ -888,7 +888,7 @@ def max_pool3d_with_indices(
return_indices=return_indices,
)
if stride is None:
- stride = torch.jit.annotate(List[int], [])
+ stride = torch.jit.annotate(list[int], [])
return torch._C._nn.max_pool3d_with_indices(
input, kernel_size, stride, padding, dilation, ceil_mode
)
@@ -916,7 +916,7 @@ def _max_pool3d(
return_indices=return_indices,
)
if stride is None:
- stride = torch.jit.annotate(List[int], [])
+ stride = torch.jit.annotate(list[int], [])
return torch.max_pool3d(input, kernel_size, stride, padding, dilation, ceil_mode)


@@ -939,7 +939,7 @@ def _unpool_output_size(
output_size: Optional[list[int]],
) -> list[int]:
input_size = input.size()
- default_size = torch.jit.annotate(List[int], [])
+ default_size = torch.jit.annotate(list[int], [])
for d in range(len(kernel_size)):
default_size.append(
(input_size[-len(kernel_size) + d] - 1) * stride[d]
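Every hunk above touches the same TorchScript idiom: an empty list whose element type has to be declared explicitly with torch.jit.annotate. A minimal sketch of that pattern under PEP 585 follows; it assumes a PyTorch build whose TorchScript frontend accepts the builtin generics (which this PR relies on), and the function itself is illustrative, not part of the PyTorch API.

from typing import Optional

import torch

@torch.jit.script
def normalize_stride(stride: Optional[list[int]]) -> list[int]:
    # TorchScript cannot infer the element type of a bare [], so the empty
    # default is declared via torch.jit.annotate; the builtin list[int] now
    # works where typing.List[int] was previously required.
    if stride is None:
        stride = torch.jit.annotate(list[int], [])
    return stride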
@@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
import math
- from typing import List, Optional, Union
+ from typing import Optional, Union
from typing_extensions import deprecated

import torch
@@ -788,8 +788,8 @@ class _ConvTransposeNd(_ConvNd):
f"or {num_non_spatial_dims + num_spatial_dims} elements (got {len(output_size)})"
)

- min_sizes = torch.jit.annotate(List[int], [])
- max_sizes = torch.jit.annotate(List[int], [])
+ min_sizes = torch.jit.annotate(list[int], [])
+ max_sizes = torch.jit.annotate(list[int], [])
for d in range(num_spatial_dims):
dim_size = (
(input.size(d + num_non_spatial_dims) - 1) * stride[d]
@@ -811,7 +811,7 @@ class _ConvTransposeNd(_ConvNd):
f"from {min_sizes} to {max_sizes} (for an input of {input.size()[2:]})"
)

- res = torch.jit.annotate(List[int], [])
+ res = torch.jit.annotate(list[int], [])
for d in range(num_spatial_dims):
res.append(output_size[d] - min_sizes[d])
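The torch/optim diffs that make up the rest of this commit repeat the substitution in two more places: typing.cast calls that narrow the grouped tensor lists inside the foreach/fused implementations, and module-level type aliases. A rough sketch of both follows; the aliases mirror ones visible later in the diff, while the helper function and its names are made up.

from typing import Any, Optional, cast

from typing_extensions import TypeAlias

import torch
from torch import Tensor

# Module-level aliases: Dict/Tuple spellings become dict/tuple.
Kwargs: TypeAlias = dict[str, Any]
DeviceDict: TypeAlias = dict[Optional[torch.device], Tensor]

def first_device_param(device_params_: list[object]) -> Tensor:
    # cast() is a no-op at runtime; it only tells the type checker what the
    # grouped list contains, so List[Tensor] -> list[Tensor] is purely a
    # static-typing change.
    device_params = cast(list[Tensor], device_params_)
    return device_params[0]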
@ -1,6 +1,6 @@
|
||||
# mypy: allow-untyped-decorators
|
||||
# mypy: allow-untyped-defs
|
||||
from typing import cast, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
|
||||
from typing import cast, Optional, TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
@ -25,7 +25,7 @@ class Adafactor(Optimizer):
|
||||
params: ParamsT,
|
||||
lr: Union[float, Tensor] = 1e-2,
|
||||
beta2_decay: float = -0.8,
|
||||
eps: Tuple[Optional[float], float] = (None, 1e-3),
|
||||
eps: tuple[Optional[float], float] = (None, 1e-3),
|
||||
d: float = 1.0,
|
||||
weight_decay: float = 0.0,
|
||||
*,
|
||||
@ -133,12 +133,12 @@ class Adafactor(Optimizer):
|
||||
loss = closure()
|
||||
|
||||
for group in self.param_groups:
|
||||
params_with_grad: List[Tensor] = []
|
||||
grads: List[Tensor] = []
|
||||
row_vars: List[Optional[Tensor]] = []
|
||||
col_vars: List[Optional[Tensor]] = []
|
||||
variances: List[Optional[Tensor]] = []
|
||||
state_steps: List[Tensor] = []
|
||||
params_with_grad: list[Tensor] = []
|
||||
grads: list[Tensor] = []
|
||||
row_vars: list[Optional[Tensor]] = []
|
||||
col_vars: list[Optional[Tensor]] = []
|
||||
variances: list[Optional[Tensor]] = []
|
||||
state_steps: list[Tensor] = []
|
||||
eps1, eps2 = group["eps"]
|
||||
|
||||
has_complex = self._init_group(
|
||||
@ -324,16 +324,16 @@ Adafactor.__doc__ = (
|
||||
|
||||
|
||||
def _single_tensor_adafactor(
|
||||
params: List[Tensor],
|
||||
grads: List[Tensor],
|
||||
params: list[Tensor],
|
||||
grads: list[Tensor],
|
||||
# If grad is 1-dimensional (aka a vector), there is no factorization necessary
|
||||
# so row_var and col_var will be None while variance will be filled.
|
||||
# Contrarily, for a grad with multiple dimensions, we will factor along the last
|
||||
# 2 dimensions, and so row_var and col_var will be filled and variance will be None.
|
||||
row_vars: List[Optional[Tensor]],
|
||||
col_vars: List[Optional[Tensor]],
|
||||
variances: List[Optional[Tensor]],
|
||||
state_steps: List[Tensor],
|
||||
row_vars: list[Optional[Tensor]],
|
||||
col_vars: list[Optional[Tensor]],
|
||||
variances: list[Optional[Tensor]],
|
||||
state_steps: list[Tensor],
|
||||
grad_scale: Optional[Tensor],
|
||||
found_inf: Optional[Tensor],
|
||||
*,
|
||||
@ -411,17 +411,17 @@ def _single_tensor_adafactor(
|
||||
|
||||
def _group_tensors_by_device_dtype_and_is_multidim(
|
||||
tensorlists: TensorListList,
|
||||
) -> Dict[
|
||||
Tuple[Optional[torch.device], Optional[torch.dtype], bool],
|
||||
List[List[Optional[Tensor]]],
|
||||
) -> dict[
|
||||
tuple[Optional[torch.device], Optional[torch.dtype], bool],
|
||||
list[list[Optional[Tensor]]],
|
||||
]:
|
||||
"""Groups tensors by device, dtype, AND multidimensionality -- whether the tensor
|
||||
has multiple dims or just one dim (is a vector). This allows the foreach impl of
|
||||
Adafactor to assume that every group of params will either be factored or not."""
|
||||
grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(tensorlists)
|
||||
ultra_grouped_tensors: Dict[
|
||||
Tuple[Optional[torch.device], Optional[torch.dtype], bool],
|
||||
List[List[Optional[Tensor]]],
|
||||
ultra_grouped_tensors: dict[
|
||||
tuple[Optional[torch.device], Optional[torch.dtype], bool],
|
||||
list[list[Optional[Tensor]]],
|
||||
] = {}
|
||||
for (device, dtype), (tensorlists, _) in grouped_tensors.items():
|
||||
matrix_key = (device, dtype, True)
|
||||
@ -444,16 +444,16 @@ def _group_tensors_by_device_dtype_and_is_multidim(
|
||||
|
||||
|
||||
def _multi_tensor_adafactor(
|
||||
params: List[Tensor],
|
||||
grads: List[Tensor],
|
||||
params: list[Tensor],
|
||||
grads: list[Tensor],
|
||||
# If grad is 1-dimensional (aka a vector), there is no factorization necessary
|
||||
# so row_var and col_var will be None while variance will be filled.
|
||||
# Contrarily, for a grad with multiple dimensions, we will factor along the last
|
||||
# 2 dimensions, and so row_var and col_var will be filled and variance will be None.
|
||||
row_vars: List[Optional[Tensor]],
|
||||
col_vars: List[Optional[Tensor]],
|
||||
variances: List[Optional[Tensor]],
|
||||
state_steps: List[Tensor],
|
||||
row_vars: list[Optional[Tensor]],
|
||||
col_vars: list[Optional[Tensor]],
|
||||
variances: list[Optional[Tensor]],
|
||||
state_steps: list[Tensor],
|
||||
grad_scale: Optional[Tensor],
|
||||
found_inf: Optional[Tensor],
|
||||
*,
|
||||
@ -486,9 +486,9 @@ def _multi_tensor_adafactor(
|
||||
device_state_steps_,
|
||||
)
|
||||
) in grouped_tensors.items():
|
||||
device_params = cast(List[Tensor], device_params_)
|
||||
device_grads = cast(List[Tensor], device_grads_)
|
||||
device_state_steps = cast(List[Tensor], device_state_steps_)
|
||||
device_params = cast(list[Tensor], device_params_)
|
||||
device_grads = cast(list[Tensor], device_grads_)
|
||||
device_state_steps = cast(list[Tensor], device_state_steps_)
|
||||
if eps1 is None:
|
||||
assert (
|
||||
dtype is not None
|
||||
@ -530,8 +530,8 @@ def _multi_tensor_adafactor(
|
||||
torch._foreach_mul_(device_params, 1 - lr * weight_decay)
|
||||
|
||||
if is_multidim:
|
||||
device_row_vars = cast(List[Tensor], device_row_vars_)
|
||||
device_col_vars = cast(List[Tensor], device_col_vars_)
|
||||
device_row_vars = cast(list[Tensor], device_row_vars_)
|
||||
device_col_vars = cast(list[Tensor], device_col_vars_)
|
||||
assert (
|
||||
device_row_vars[0] is not None and device_col_vars[0] is not None
|
||||
), "row_var and col_var should be defined when grad is multidimensional"
|
||||
@ -564,7 +564,7 @@ def _multi_tensor_adafactor(
|
||||
torch._foreach_div_(var_estimates, row_var_means)
|
||||
del row_var_means
|
||||
else:
|
||||
device_variances = cast(List[Tensor], device_variances_)
|
||||
device_variances = cast(list[Tensor], device_variances_)
|
||||
assert (
|
||||
device_variances[0] is not None
|
||||
), "variance should be defined when grad is a vector"
|
||||
@ -592,12 +592,12 @@ def _multi_tensor_adafactor(
|
||||
|
||||
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adafactor)
|
||||
def adafactor(
|
||||
params: List[Tensor],
|
||||
grads: List[Tensor],
|
||||
row_vars: List[Optional[Tensor]],
|
||||
col_vars: List[Optional[Tensor]],
|
||||
variances: List[Optional[Tensor]],
|
||||
state_steps: List[Tensor],
|
||||
params: list[Tensor],
|
||||
grads: list[Tensor],
|
||||
row_vars: list[Optional[Tensor]],
|
||||
col_vars: list[Optional[Tensor]],
|
||||
variances: list[Optional[Tensor]],
|
||||
state_steps: list[Tensor],
|
||||
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
|
||||
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
|
||||
foreach: Optional[bool] = None,
|
||||
|
@ -1,7 +1,6 @@
|
||||
# mypy: allow-untyped-defs
|
||||
r"""Functional interface."""
|
||||
import math
|
||||
from typing import List
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
@ -22,11 +21,11 @@ from .sgd import sgd # type: ignore[attr-defined] # noqa: F401
|
||||
|
||||
|
||||
def sparse_adam(
|
||||
params: List[Tensor],
|
||||
grads: List[Tensor],
|
||||
exp_avgs: List[Tensor],
|
||||
exp_avg_sqs: List[Tensor],
|
||||
state_steps: List[int],
|
||||
params: list[Tensor],
|
||||
grads: list[Tensor],
|
||||
exp_avgs: list[Tensor],
|
||||
exp_avg_sqs: list[Tensor],
|
||||
state_steps: list[int],
|
||||
*,
|
||||
eps: float,
|
||||
beta1: float,
|
||||
|
@ -1,5 +1,5 @@
|
||||
# mypy: allow-untyped-defs
|
||||
from typing import Any, cast, Dict, List, Optional, Union
|
||||
from typing import Any, cast, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
@ -82,12 +82,12 @@ class Adadelta(Optimizer):
|
||||
|
||||
def _init_group(
|
||||
self,
|
||||
group: Dict[str, Any],
|
||||
params_with_grad: List[Tensor],
|
||||
grads: List[Tensor],
|
||||
square_avgs: List[Tensor],
|
||||
acc_deltas: List[Tensor],
|
||||
state_steps: List[Tensor],
|
||||
group: dict[str, Any],
|
||||
params_with_grad: list[Tensor],
|
||||
grads: list[Tensor],
|
||||
square_avgs: list[Tensor],
|
||||
acc_deltas: list[Tensor],
|
||||
state_steps: list[Tensor],
|
||||
):
|
||||
has_complex = False
|
||||
p: Tensor
|
||||
@ -139,11 +139,11 @@ class Adadelta(Optimizer):
|
||||
loss = closure()
|
||||
|
||||
for group in self.param_groups:
|
||||
params_with_grad: List[Tensor] = []
|
||||
grads: List[Tensor] = []
|
||||
square_avgs: List[Tensor] = []
|
||||
acc_deltas: List[Tensor] = []
|
||||
state_steps: List[Tensor] = []
|
||||
params_with_grad: list[Tensor] = []
|
||||
grads: list[Tensor] = []
|
||||
square_avgs: list[Tensor] = []
|
||||
acc_deltas: list[Tensor] = []
|
||||
state_steps: list[Tensor] = []
|
||||
(
|
||||
lr,
|
||||
rho,
|
||||
@ -242,11 +242,11 @@ Adadelta.__doc__ = (
|
||||
|
||||
|
||||
def _single_tensor_adadelta(
|
||||
params: List[Tensor],
|
||||
grads: List[Tensor],
|
||||
square_avgs: List[Tensor],
|
||||
acc_deltas: List[Tensor],
|
||||
state_steps: List[Tensor],
|
||||
params: list[Tensor],
|
||||
grads: list[Tensor],
|
||||
square_avgs: list[Tensor],
|
||||
acc_deltas: list[Tensor],
|
||||
state_steps: list[Tensor],
|
||||
*,
|
||||
lr: float,
|
||||
rho: float,
|
||||
@ -296,11 +296,11 @@ def _single_tensor_adadelta(
|
||||
|
||||
|
||||
def _multi_tensor_adadelta(
|
||||
params: List[Tensor],
|
||||
grads: List[Tensor],
|
||||
square_avgs: List[Tensor],
|
||||
acc_deltas: List[Tensor],
|
||||
state_steps: List[Tensor],
|
||||
params: list[Tensor],
|
||||
grads: list[Tensor],
|
||||
square_avgs: list[Tensor],
|
||||
acc_deltas: list[Tensor],
|
||||
state_steps: list[Tensor],
|
||||
*,
|
||||
lr: float,
|
||||
rho: float,
|
||||
@ -337,11 +337,11 @@ def _multi_tensor_adadelta(
|
||||
device_acc_deltas_,
|
||||
device_state_steps_,
|
||||
), _ in grouped_tensors.values():
|
||||
device_params = cast(List[Tensor], device_params_)
|
||||
device_grads = cast(List[Tensor], device_grads_)
|
||||
device_square_avgs = cast(List[Tensor], device_square_avgs_)
|
||||
device_acc_deltas = cast(List[Tensor], device_acc_deltas_)
|
||||
device_state_steps = cast(List[Tensor], device_state_steps_)
|
||||
device_params = cast(list[Tensor], device_params_)
|
||||
device_grads = cast(list[Tensor], device_grads_)
|
||||
device_square_avgs = cast(list[Tensor], device_square_avgs_)
|
||||
device_acc_deltas = cast(list[Tensor], device_acc_deltas_)
|
||||
device_state_steps = cast(list[Tensor], device_state_steps_)
|
||||
if has_complex:
|
||||
_view_as_real(
|
||||
device_params, device_grads, device_square_avgs, device_acc_deltas
|
||||
@ -397,11 +397,11 @@ def _multi_tensor_adadelta(
|
||||
|
||||
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adadelta)
|
||||
def adadelta(
|
||||
params: List[Tensor],
|
||||
grads: List[Tensor],
|
||||
square_avgs: List[Tensor],
|
||||
acc_deltas: List[Tensor],
|
||||
state_steps: List[Tensor],
|
||||
params: list[Tensor],
|
||||
grads: list[Tensor],
|
||||
square_avgs: list[Tensor],
|
||||
acc_deltas: list[Tensor],
|
||||
state_steps: list[Tensor],
|
||||
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
|
||||
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
|
||||
capturable: bool = False,
|
||||
|
@ -1,5 +1,5 @@
|
||||
# mypy: allow-untyped-defs
|
||||
from typing import cast, List, Optional, Union
|
||||
from typing import cast, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
@ -157,10 +157,10 @@ class Adagrad(Optimizer):
|
||||
loss = closure()
|
||||
|
||||
for group in self.param_groups:
|
||||
params_with_grad: List[Tensor] = []
|
||||
grads: List[Tensor] = []
|
||||
state_sums: List[Tensor] = []
|
||||
state_steps: List[Tensor] = []
|
||||
params_with_grad: list[Tensor] = []
|
||||
grads: list[Tensor] = []
|
||||
state_sums: list[Tensor] = []
|
||||
state_steps: list[Tensor] = []
|
||||
|
||||
has_sparse_grad, has_complex = self._init_group(
|
||||
group, params_with_grad, grads, state_sums, state_steps
|
||||
@ -240,10 +240,10 @@ Adagrad.__doc__ = (
|
||||
|
||||
|
||||
def adagrad(
|
||||
params: List[Tensor],
|
||||
grads: List[Tensor],
|
||||
state_sums: List[Tensor],
|
||||
state_steps: List[Tensor],
|
||||
params: list[Tensor],
|
||||
grads: list[Tensor],
|
||||
state_sums: list[Tensor],
|
||||
state_steps: list[Tensor],
|
||||
fused: Optional[bool] = None,
|
||||
grad_scale: Optional[Tensor] = None,
|
||||
found_inf: Optional[Tensor] = None,
|
||||
@ -319,10 +319,10 @@ def _make_sparse(grad, grad_indices, values):
|
||||
|
||||
|
||||
def _single_tensor_adagrad(
|
||||
params: List[Tensor],
|
||||
grads: List[Tensor],
|
||||
state_sums: List[Tensor],
|
||||
state_steps: List[Tensor],
|
||||
params: list[Tensor],
|
||||
grads: list[Tensor],
|
||||
state_sums: list[Tensor],
|
||||
state_steps: list[Tensor],
|
||||
grad_scale: Optional[Tensor],
|
||||
found_inf: Optional[Tensor],
|
||||
*,
|
||||
@ -380,10 +380,10 @@ def _single_tensor_adagrad(
|
||||
|
||||
|
||||
def _multi_tensor_adagrad(
|
||||
params: List[Tensor],
|
||||
grads: List[Tensor],
|
||||
state_sums: List[Tensor],
|
||||
state_steps: List[Tensor],
|
||||
params: list[Tensor],
|
||||
grads: list[Tensor],
|
||||
state_sums: list[Tensor],
|
||||
state_steps: list[Tensor],
|
||||
grad_scale: Optional[Tensor],
|
||||
found_inf: Optional[Tensor],
|
||||
*,
|
||||
@ -412,10 +412,10 @@ def _multi_tensor_adagrad(
|
||||
device_state_sums_,
|
||||
device_state_steps_,
|
||||
), _ in grouped_tensorlists.values():
|
||||
device_params = cast(List[Tensor], device_params_)
|
||||
device_grads = cast(List[Tensor], device_grads_)
|
||||
device_state_sums = cast(List[Tensor], device_state_sums_)
|
||||
device_state_steps = cast(List[Tensor], device_state_steps_)
|
||||
device_params = cast(list[Tensor], device_params_)
|
||||
device_grads = cast(list[Tensor], device_grads_)
|
||||
device_state_sums = cast(list[Tensor], device_state_sums_)
|
||||
device_state_steps = cast(list[Tensor], device_state_steps_)
|
||||
|
||||
device_has_sparse_grad = has_sparse_grad and any(
|
||||
grad.is_sparse for grad in device_grads
|
||||
@ -487,10 +487,10 @@ def _multi_tensor_adagrad(
|
||||
|
||||
|
||||
def _fused_adagrad(
|
||||
params: List[Tensor],
|
||||
grads: List[Tensor],
|
||||
state_sums: List[Tensor],
|
||||
state_steps: List[Tensor],
|
||||
params: list[Tensor],
|
||||
grads: list[Tensor],
|
||||
state_sums: list[Tensor],
|
||||
state_steps: list[Tensor],
|
||||
grad_scale: Optional[Tensor],
|
||||
found_inf: Optional[Tensor],
|
||||
*,
|
||||
@ -530,10 +530,10 @@ def _fused_adagrad(
|
||||
),
|
||||
_,
|
||||
) in grouped_tensors.items():
|
||||
device_params = cast(List[Tensor], device_params_)
|
||||
device_grads = cast(List[Tensor], device_grads_)
|
||||
device_state_sums = cast(List[Tensor], device_state_sums_)
|
||||
device_state_steps = cast(List[Tensor], device_state_steps_)
|
||||
device_params = cast(list[Tensor], device_params_)
|
||||
device_grads = cast(list[Tensor], device_grads_)
|
||||
device_state_sums = cast(list[Tensor], device_state_sums_)
|
||||
device_state_steps = cast(list[Tensor], device_state_steps_)
|
||||
|
||||
device_grad_scale, device_found_inf = None, None
|
||||
if grad_scale is not None and grad_scale_dict is not None:
|
||||
|
@ -1,5 +1,5 @@
|
||||
# mypy: allow-untyped-defs
|
||||
from typing import cast, List, Optional, Tuple, Union
|
||||
from typing import cast, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
@ -35,7 +35,7 @@ class Adam(Optimizer):
|
||||
self,
|
||||
params: ParamsT,
|
||||
lr: Union[float, Tensor] = 1e-3,
|
||||
betas: Tuple[Union[float, Tensor], Union[float, Tensor]] = (0.9, 0.999),
|
||||
betas: tuple[Union[float, Tensor], Union[float, Tensor]] = (0.9, 0.999),
|
||||
eps: float = 1e-8,
|
||||
weight_decay: float = 0,
|
||||
amsgrad: bool = False,
|
||||
@ -225,12 +225,12 @@ class Adam(Optimizer):
|
||||
loss = closure()
|
||||
|
||||
for group in self.param_groups:
|
||||
params_with_grad: List[Tensor] = []
|
||||
grads: List[Tensor] = []
|
||||
exp_avgs: List[Tensor] = []
|
||||
exp_avg_sqs: List[Tensor] = []
|
||||
max_exp_avg_sqs: List[Tensor] = []
|
||||
state_steps: List[Tensor] = []
|
||||
params_with_grad: list[Tensor] = []
|
||||
grads: list[Tensor] = []
|
||||
exp_avgs: list[Tensor] = []
|
||||
exp_avg_sqs: list[Tensor] = []
|
||||
max_exp_avg_sqs: list[Tensor] = []
|
||||
state_steps: list[Tensor] = []
|
||||
beta1, beta2 = group["betas"]
|
||||
|
||||
has_complex = self._init_group(
|
||||
@ -342,12 +342,12 @@ Adam.__doc__ = (
|
||||
|
||||
|
||||
def _single_tensor_adam(
|
||||
params: List[Tensor],
|
||||
grads: List[Tensor],
|
||||
exp_avgs: List[Tensor],
|
||||
exp_avg_sqs: List[Tensor],
|
||||
max_exp_avg_sqs: List[Tensor],
|
||||
state_steps: List[Tensor],
|
||||
params: list[Tensor],
|
||||
grads: list[Tensor],
|
||||
exp_avgs: list[Tensor],
|
||||
exp_avg_sqs: list[Tensor],
|
||||
max_exp_avg_sqs: list[Tensor],
|
||||
state_steps: list[Tensor],
|
||||
grad_scale: Optional[Tensor],
|
||||
found_inf: Optional[Tensor],
|
||||
*,
|
||||
@ -532,12 +532,12 @@ def _single_tensor_adam(
|
||||
|
||||
|
||||
def _multi_tensor_adam(
|
||||
params: List[Tensor],
|
||||
grads: List[Tensor],
|
||||
exp_avgs: List[Tensor],
|
||||
exp_avg_sqs: List[Tensor],
|
||||
max_exp_avg_sqs: List[Tensor],
|
||||
state_steps: List[Tensor],
|
||||
params: list[Tensor],
|
||||
grads: list[Tensor],
|
||||
exp_avgs: list[Tensor],
|
||||
exp_avg_sqs: list[Tensor],
|
||||
max_exp_avg_sqs: list[Tensor],
|
||||
state_steps: list[Tensor],
|
||||
grad_scale: Optional[Tensor],
|
||||
found_inf: Optional[Tensor],
|
||||
*,
|
||||
@ -612,11 +612,11 @@ def _multi_tensor_adam(
|
||||
device_max_exp_avg_sqs_,
|
||||
device_state_steps_,
|
||||
), _ in grouped_tensors.values():
|
||||
device_params = cast(List[Tensor], device_params_)
|
||||
device_grads = cast(List[Tensor], device_grads_)
|
||||
device_exp_avgs = cast(List[Tensor], device_exp_avgs_)
|
||||
device_exp_avg_sqs = cast(List[Tensor], device_exp_avg_sqs_)
|
||||
device_state_steps = cast(List[Tensor], device_state_steps_)
|
||||
device_params = cast(list[Tensor], device_params_)
|
||||
device_grads = cast(list[Tensor], device_grads_)
|
||||
device_exp_avgs = cast(list[Tensor], device_exp_avgs_)
|
||||
device_exp_avg_sqs = cast(list[Tensor], device_exp_avg_sqs_)
|
||||
device_state_steps = cast(list[Tensor], device_state_steps_)
|
||||
|
||||
device = device_params[0].device
|
||||
if beta1_dict is not None and device not in beta1_dict:
|
||||
@ -627,7 +627,7 @@ def _multi_tensor_adam(
|
||||
# Handle complex parameters
|
||||
if has_complex:
|
||||
if amsgrad:
|
||||
device_max_exp_avg_sqs = cast(List[Tensor], device_max_exp_avg_sqs_)
|
||||
device_max_exp_avg_sqs = cast(list[Tensor], device_max_exp_avg_sqs_)
|
||||
_view_as_real(
|
||||
device_params,
|
||||
device_grads,
|
||||
@ -693,9 +693,9 @@ def _multi_tensor_adam(
|
||||
del device_grads
|
||||
del scaled_device_grads
|
||||
|
||||
bias_correction1: Union[Tuple[Tensor, ...], List[Tensor]]
|
||||
bias_correction2: Union[Tuple[Tensor, ...], List[Tensor]]
|
||||
bias_correction2_sqrt: Union[Tuple[Tensor, ...], List[Tensor]]
|
||||
bias_correction1: Union[tuple[Tensor, ...], list[Tensor]]
|
||||
bias_correction2: Union[tuple[Tensor, ...], list[Tensor]]
|
||||
bias_correction2_sqrt: Union[tuple[Tensor, ...], list[Tensor]]
|
||||
|
||||
if capturable:
|
||||
bias_correction1 = torch._foreach_pow(beta1, device_state_steps) # type: ignore[arg-type]
|
||||
@ -719,7 +719,7 @@ def _multi_tensor_adam(
|
||||
bias_correction2_sqrt = bias_correction2
|
||||
|
||||
if amsgrad:
|
||||
device_max_exp_avg_sqs = cast(List[Tensor], device_max_exp_avg_sqs_)
|
||||
device_max_exp_avg_sqs = cast(list[Tensor], device_max_exp_avg_sqs_)
|
||||
# Maintains the maximum of all 2nd moment running avg. till now
|
||||
torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs) # type: ignore[assignment]
|
||||
|
||||
@ -747,7 +747,7 @@ def _multi_tensor_adam(
|
||||
bias_correction2_sqrt = [bc**0.5 for bc in bias_correction2] # type: ignore[arg-type]
|
||||
|
||||
if amsgrad:
|
||||
device_max_exp_avg_sqs = cast(List[Tensor], device_max_exp_avg_sqs_)
|
||||
device_max_exp_avg_sqs = cast(list[Tensor], device_max_exp_avg_sqs_)
|
||||
# Maintains the maximum of all 2nd moment running avg. till now
|
||||
torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs)
|
||||
|
||||
@ -764,12 +764,12 @@ def _multi_tensor_adam(
|
||||
|
||||
|
||||
def _fused_adam(
|
||||
params: List[Tensor],
|
||||
grads: List[Tensor],
|
||||
exp_avgs: List[Tensor],
|
||||
exp_avg_sqs: List[Tensor],
|
||||
max_exp_avg_sqs: List[Tensor],
|
||||
state_steps: List[Tensor],
|
||||
params: list[Tensor],
|
||||
grads: list[Tensor],
|
||||
exp_avgs: list[Tensor],
|
||||
exp_avg_sqs: list[Tensor],
|
||||
max_exp_avg_sqs: list[Tensor],
|
||||
state_steps: list[Tensor],
|
||||
grad_scale: Optional[Tensor],
|
||||
found_inf: Optional[Tensor],
|
||||
*,
|
||||
@ -816,11 +816,11 @@ def _fused_adam(
|
||||
),
|
||||
_,
|
||||
) in grouped_tensors.items():
|
||||
device_params = cast(List[Tensor], device_params_)
|
||||
device_grads = cast(List[Tensor], device_grads_)
|
||||
device_exp_avgs = cast(List[Tensor], device_exp_avgs_)
|
||||
device_exp_avg_sqs = cast(List[Tensor], device_exp_avg_sqs_)
|
||||
device_state_steps = cast(List[Tensor], device_state_steps_)
|
||||
device_params = cast(list[Tensor], device_params_)
|
||||
device_grads = cast(list[Tensor], device_grads_)
|
||||
device_exp_avgs = cast(list[Tensor], device_exp_avgs_)
|
||||
device_exp_avg_sqs = cast(list[Tensor], device_exp_avg_sqs_)
|
||||
device_state_steps = cast(list[Tensor], device_state_steps_)
|
||||
|
||||
if device.type == "mps": # type: ignore[union-attr]
|
||||
assert found_inf is None and grad_scale is None
|
||||
@ -864,12 +864,12 @@ def _fused_adam(
|
||||
|
||||
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adam)
|
||||
def adam(
|
||||
params: List[Tensor],
|
||||
grads: List[Tensor],
|
||||
exp_avgs: List[Tensor],
|
||||
exp_avg_sqs: List[Tensor],
|
||||
max_exp_avg_sqs: List[Tensor],
|
||||
state_steps: List[Tensor],
|
||||
params: list[Tensor],
|
||||
grads: list[Tensor],
|
||||
exp_avgs: list[Tensor],
|
||||
exp_avg_sqs: list[Tensor],
|
||||
max_exp_avg_sqs: list[Tensor],
|
||||
state_steps: list[Tensor],
|
||||
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
|
||||
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
|
||||
foreach: Optional[bool] = None,
|
||||
|
@ -1,5 +1,5 @@
|
||||
# mypy: allow-untyped-defs
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
@ -23,7 +23,7 @@ class AdamW(Adam):
|
||||
self,
|
||||
params: ParamsT,
|
||||
lr: Union[float, Tensor] = 1e-3,
|
||||
betas: Tuple[Union[float, Tensor], Union[float, Tensor]] = (0.9, 0.999),
|
||||
betas: tuple[Union[float, Tensor], Union[float, Tensor]] = (0.9, 0.999),
|
||||
eps: float = 1e-8,
|
||||
weight_decay: float = 1e-2,
|
||||
amsgrad: bool = False,
|
||||
@ -128,12 +128,12 @@ AdamW.__doc__ = (
|
||||
|
||||
# @_disable_dynamo_if_unsupported logic occurs in the decorator that's applied to F.adam
|
||||
def adamw(
|
||||
params: List[Tensor],
|
||||
grads: List[Tensor],
|
||||
exp_avgs: List[Tensor],
|
||||
exp_avg_sqs: List[Tensor],
|
||||
max_exp_avg_sqs: List[Tensor],
|
||||
state_steps: List[Tensor],
|
||||
params: list[Tensor],
|
||||
grads: list[Tensor],
|
||||
exp_avgs: list[Tensor],
|
||||
exp_avg_sqs: list[Tensor],
|
||||
max_exp_avg_sqs: list[Tensor],
|
||||
state_steps: list[Tensor],
|
||||
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
|
||||
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
|
||||
foreach: Optional[bool] = None,
|
||||
|
@ -1,5 +1,5 @@
|
||||
# mypy: allow-untyped-defs
|
||||
from typing import cast, List, Optional, Tuple, Union
|
||||
from typing import cast, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
@ -135,12 +135,12 @@ class ASGD(Optimizer):
|
||||
loss = closure()
|
||||
|
||||
for group in self.param_groups:
|
||||
params_with_grad: List[Tensor] = []
|
||||
grads: List[Tensor] = []
|
||||
mus: List[Tensor] = []
|
||||
axs: List[Tensor] = []
|
||||
etas: List[Tensor] = []
|
||||
state_steps: List[Tensor] = []
|
||||
params_with_grad: list[Tensor] = []
|
||||
grads: list[Tensor] = []
|
||||
mus: list[Tensor] = []
|
||||
axs: list[Tensor] = []
|
||||
etas: list[Tensor] = []
|
||||
state_steps: list[Tensor] = []
|
||||
|
||||
has_complex = self._init_group(
|
||||
group, params_with_grad, grads, mus, axs, etas, state_steps
|
||||
@ -192,12 +192,12 @@ ASGD.__doc__ = rf"""Implements Averaged Stochastic Gradient Descent.
|
||||
|
||||
|
||||
def _single_tensor_asgd(
|
||||
params: List[Tensor],
|
||||
grads: List[Tensor],
|
||||
axs: List[Tensor],
|
||||
mus: List[Tensor],
|
||||
etas: List[Tensor],
|
||||
state_steps: List[Tensor],
|
||||
params: list[Tensor],
|
||||
grads: list[Tensor],
|
||||
axs: list[Tensor],
|
||||
mus: list[Tensor],
|
||||
etas: list[Tensor],
|
||||
state_steps: list[Tensor],
|
||||
*,
|
||||
lambd: float,
|
||||
lr: float,
|
||||
@ -268,12 +268,12 @@ def _single_tensor_asgd(
|
||||
|
||||
|
||||
def _multi_tensor_asgd(
|
||||
params: List[Tensor],
|
||||
grads: List[Tensor],
|
||||
axs: List[Tensor],
|
||||
mus: List[Tensor],
|
||||
etas: List[Tensor],
|
||||
state_steps: List[Tensor],
|
||||
params: list[Tensor],
|
||||
grads: list[Tensor],
|
||||
axs: list[Tensor],
|
||||
mus: list[Tensor],
|
||||
etas: list[Tensor],
|
||||
state_steps: list[Tensor],
|
||||
*,
|
||||
lambd: float,
|
||||
lr: float,
|
||||
@ -315,12 +315,12 @@ def _multi_tensor_asgd(
|
||||
),
|
||||
_,
|
||||
) in grouped_tensors.items():
|
||||
grouped_params = cast(List[Tensor], grouped_params_)
|
||||
grouped_grads = cast(List[Tensor], grouped_grads_)
|
||||
grouped_axs = cast(List[Tensor], grouped_axs_)
|
||||
grouped_mus = cast(List[Tensor], grouped_mus_)
|
||||
grouped_etas = cast(List[Tensor], grouped_etas_)
|
||||
grouped_state_steps = cast(List[Tensor], grouped_state_steps_)
|
||||
grouped_params = cast(list[Tensor], grouped_params_)
|
||||
grouped_grads = cast(list[Tensor], grouped_grads_)
|
||||
grouped_axs = cast(list[Tensor], grouped_axs_)
|
||||
grouped_mus = cast(list[Tensor], grouped_mus_)
|
||||
grouped_etas = cast(list[Tensor], grouped_etas_)
|
||||
grouped_state_steps = cast(list[Tensor], grouped_state_steps_)
|
||||
|
||||
if has_complex:
|
||||
_view_as_real(grouped_params, grouped_grads, grouped_axs)
|
||||
@ -340,7 +340,7 @@ def _multi_tensor_asgd(
|
||||
torch._foreach_add_(grouped_state_steps, 1)
|
||||
|
||||
# intermediate = grad + param * lambd
|
||||
intermediate: Union[Tuple[Tensor, ...], List[Tensor]]
|
||||
intermediate: Union[tuple[Tensor, ...], list[Tensor]]
|
||||
if weight_decay != 0:
|
||||
if maximize:
|
||||
torch._foreach_add_(grouped_grads, grouped_params, alpha=weight_decay)
|
||||
@ -375,8 +375,8 @@ def _multi_tensor_asgd(
|
||||
torch._foreach_addcmul_(grouped_axs, intermediate, grouped_mus)
|
||||
del intermediate
|
||||
|
||||
new_etas: Union[Tuple[Tensor, ...], List[Tensor]]
|
||||
new_mus: Union[Tuple[Tensor, ...], List[Tensor]]
|
||||
new_etas: Union[tuple[Tensor, ...], list[Tensor]]
|
||||
new_mus: Union[tuple[Tensor, ...], list[Tensor]]
|
||||
if capturable:
|
||||
# update grouped_mus
|
||||
new_mus = torch._foreach_sub(grouped_state_steps, t0)
|
||||
@ -408,12 +408,12 @@ def _multi_tensor_asgd(
|
||||
|
||||
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_asgd)
|
||||
def asgd(
|
||||
params: List[Tensor],
|
||||
grads: List[Tensor],
|
||||
axs: List[Tensor],
|
||||
mus: List[Tensor],
|
||||
etas: List[Tensor],
|
||||
state_steps: List[Tensor],
|
||||
params: list[Tensor],
|
||||
grads: list[Tensor],
|
||||
axs: list[Tensor],
|
||||
mus: list[Tensor],
|
||||
etas: list[Tensor],
|
||||
state_steps: list[Tensor],
|
||||
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
|
||||
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
|
||||
foreach: Optional[bool] = None,
|
||||
|
@ -5,17 +5,14 @@ import types
|
||||
import warnings
|
||||
from bisect import bisect_right
|
||||
from collections import Counter
|
||||
from collections.abc import Iterable, Sequence
|
||||
from functools import partial, wraps
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
cast,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
SupportsFloat,
|
||||
TypedDict,
|
||||
Union,
|
||||
@ -116,7 +113,7 @@ class LRScheduler:
|
||||
"param 'initial_lr' is not specified "
|
||||
f"in param_groups[{i}] when resuming an optimizer"
|
||||
)
|
||||
self.base_lrs: List[float] = [
|
||||
self.base_lrs: list[float] = [
|
||||
group["initial_lr"] for group in optimizer.param_groups
|
||||
]
|
||||
self.last_epoch = last_epoch
|
||||
@ -163,7 +160,7 @@ class LRScheduler:
|
||||
key: value for key, value in self.__dict__.items() if key != "optimizer"
|
||||
}
|
||||
|
||||
def load_state_dict(self, state_dict: Dict[str, Any]):
|
||||
def load_state_dict(self, state_dict: dict[str, Any]):
|
||||
"""Load the scheduler's state.
|
||||
|
||||
Args:
|
||||
@ -172,18 +169,18 @@ class LRScheduler:
|
||||
"""
|
||||
self.__dict__.update(state_dict)
|
||||
|
||||
def get_last_lr(self) -> List[float]:
|
||||
def get_last_lr(self) -> list[float]:
|
||||
"""Return last computed learning rate by current scheduler."""
|
||||
return self._last_lr
|
||||
|
||||
def get_lr(self) -> List[float]:
|
||||
def get_lr(self) -> list[float]:
|
||||
"""Compute learning rate using chainable form of the scheduler."""
|
||||
raise NotImplementedError
|
||||
|
||||
def print_lr(
|
||||
self,
|
||||
is_verbose: bool,
|
||||
group: Dict[str, Any],
|
||||
group: dict[str, Any],
|
||||
lr: float,
|
||||
epoch: Optional[int] = None,
|
||||
):
|
||||
@ -243,7 +240,7 @@ class LRScheduler:
|
||||
warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
|
||||
self.last_epoch = epoch
|
||||
if hasattr(self, "_get_closed_form_lr"):
|
||||
values = cast(List[float], self._get_closed_form_lr())
|
||||
values = cast(list[float], self._get_closed_form_lr())
|
||||
else:
|
||||
values = self.get_lr()
|
||||
|
||||
@ -253,7 +250,7 @@ class LRScheduler:
|
||||
else:
|
||||
param_group["lr"] = lr
|
||||
|
||||
self._last_lr: List[float] = [
|
||||
self._last_lr: list[float] = [
|
||||
group["lr"] for group in self.optimizer.param_groups
|
||||
]
|
||||
|
||||
@ -320,13 +317,13 @@ class LambdaLR(LRScheduler):
|
||||
def __init__(
|
||||
self,
|
||||
optimizer: Optimizer,
|
||||
lr_lambda: Union[Callable[[int], float], List[Callable[[int], float]]],
|
||||
lr_lambda: Union[Callable[[int], float], list[Callable[[int], float]]],
|
||||
last_epoch: int = -1,
|
||||
verbose="deprecated",
|
||||
): # noqa: D107
|
||||
self.optimizer = optimizer
|
||||
|
||||
self.lr_lambdas: List[Callable[[int], float]]
|
||||
self.lr_lambdas: list[Callable[[int], float]]
|
||||
if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple):
|
||||
self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups)
|
||||
else:
|
||||
@ -420,13 +417,13 @@ class MultiplicativeLR(LRScheduler):
|
||||
def __init__(
|
||||
self,
|
||||
optimizer: Optimizer,
|
||||
lr_lambda: Union[Callable[[int], float], List[Callable[[int], float]]],
|
||||
lr_lambda: Union[Callable[[int], float], list[Callable[[int], float]]],
|
||||
last_epoch: int = -1,
|
||||
verbose="deprecated",
|
||||
): # noqa: D107
|
||||
self.optimizer = optimizer
|
||||
|
||||
self.lr_lambdas: List[Callable[[int], float]]
|
||||
self.lr_lambdas: list[Callable[[int], float]]
|
||||
if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple):
|
||||
self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups)
|
||||
else:
|
||||
@ -863,8 +860,8 @@ class SequentialLR(LRScheduler):
|
||||
def __init__(
|
||||
self,
|
||||
optimizer: Optimizer,
|
||||
schedulers: List[LRScheduler],
|
||||
milestones: List[int],
|
||||
schedulers: list[LRScheduler],
|
||||
milestones: list[int],
|
||||
last_epoch: int = -1,
|
||||
verbose="deprecated",
|
||||
): # noqa: D107
|
||||
@ -1311,7 +1308,7 @@ class ReduceLROnPlateau(LRScheduler):
|
||||
threshold: float = 1e-4,
|
||||
threshold_mode: Literal["rel", "abs"] = "rel",
|
||||
cooldown: int = 0,
|
||||
min_lr: Union[List[float], float] = 0,
|
||||
min_lr: Union[list[float], float] = 0,
|
||||
eps: float = 1e-8,
|
||||
verbose="deprecated",
|
||||
): # noqa: D107
|
||||
@ -1557,8 +1554,8 @@ class CyclicLR(LRScheduler):
|
||||
def __init__(
|
||||
self,
|
||||
optimizer: Optimizer,
|
||||
base_lr: Union[float, List[float]],
|
||||
max_lr: Union[float, List[float]],
|
||||
base_lr: Union[float, list[float]],
|
||||
max_lr: Union[float, list[float]],
|
||||
step_size_up: int = 2000,
|
||||
step_size_down: Optional[int] = None,
|
||||
mode: Literal["triangular", "triangular2", "exp_range"] = "triangular",
|
||||
@ -1988,15 +1985,15 @@ class OneCycleLR(LRScheduler):
|
||||
def __init__(
|
||||
self,
|
||||
optimizer: Optimizer,
|
||||
max_lr: Union[float, List[float]],
|
||||
max_lr: Union[float, list[float]],
|
||||
total_steps: Optional[int] = None,
|
||||
epochs: Optional[int] = None,
|
||||
steps_per_epoch: Optional[int] = None,
|
||||
pct_start: float = 0.3,
|
||||
anneal_strategy: Literal["cos", "linear"] = "cos",
|
||||
cycle_momentum: bool = True,
|
||||
base_momentum: Union[float, List[float]] = 0.85,
|
||||
max_momentum: Union[float, List[float]] = 0.95,
|
||||
base_momentum: Union[float, list[float]] = 0.85,
|
||||
max_momentum: Union[float, list[float]] = 0.95,
|
||||
div_factor: float = 25.0,
|
||||
final_div_factor: float = 1e4,
|
||||
three_phase: bool = False,
|
||||
@ -2028,7 +2025,7 @@ class OneCycleLR(LRScheduler):
|
||||
"You must define either total_steps OR (epochs AND steps_per_epoch)"
|
||||
)
|
||||
|
||||
self._schedule_phases: List[_SchedulePhase]
|
||||
self._schedule_phases: list[_SchedulePhase]
|
||||
if three_phase:
|
||||
self._schedule_phases = [
|
||||
{
|
||||
|
@ -3,25 +3,10 @@
|
||||
import functools
|
||||
import warnings
|
||||
from collections import defaultdict, OrderedDict
|
||||
from collections.abc import Hashable, Iterable, Sequence
|
||||
from copy import deepcopy
|
||||
from itertools import chain
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
cast,
|
||||
DefaultDict,
|
||||
Dict,
|
||||
Hashable,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
overload,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
from typing import Any, Callable, cast, Optional, overload, TypeVar, Union
|
||||
from typing_extensions import ParamSpec, Self, TypeAlias
|
||||
|
||||
import torch
|
||||
@ -39,15 +24,15 @@ from torch.utils.hooks import RemovableHandle
|
||||
_T = TypeVar("_T")
|
||||
_P = ParamSpec("_P")
|
||||
|
||||
Args: TypeAlias = Tuple[Any, ...]
|
||||
Kwargs: TypeAlias = Dict[str, Any]
|
||||
StateDict: TypeAlias = Dict[str, Any]
|
||||
DeviceDict = Dict[Optional[torch.device], torch.Tensor]
|
||||
DeviceDtypeDict = Dict[Optional[Tuple[torch.device, torch.dtype]], torch.Tensor]
|
||||
Args: TypeAlias = tuple[Any, ...]
|
||||
Kwargs: TypeAlias = dict[str, Any]
|
||||
StateDict: TypeAlias = dict[str, Any]
|
||||
DeviceDict = dict[Optional[torch.device], torch.Tensor]
|
||||
DeviceDtypeDict = dict[Optional[tuple[torch.device, torch.dtype]], torch.Tensor]
|
||||
|
||||
|
||||
GlobalOptimizerPreHook: TypeAlias = Callable[
|
||||
["Optimizer", Args, Kwargs], Optional[Tuple[Args, Kwargs]]
|
||||
["Optimizer", Args, Kwargs], Optional[tuple[Args, Kwargs]]
|
||||
]
|
||||
GlobalOptimizerPostHook: TypeAlias = Callable[["Optimizer", Args, Kwargs], None]
|
||||
|
||||
@ -56,8 +41,8 @@ __all__ = [
|
||||
"register_optimizer_step_pre_hook",
|
||||
"register_optimizer_step_post_hook",
|
||||
]
|
||||
_global_optimizer_pre_hooks: Dict[int, GlobalOptimizerPreHook] = OrderedDict()
|
||||
_global_optimizer_post_hooks: Dict[int, GlobalOptimizerPostHook] = OrderedDict()
|
||||
_global_optimizer_pre_hooks: dict[int, GlobalOptimizerPreHook] = OrderedDict()
|
||||
_global_optimizer_post_hooks: dict[int, GlobalOptimizerPostHook] = OrderedDict()
|
||||
_foreach_supported_types = [torch.Tensor, torch.nn.parameter.Parameter]
|
||||
|
||||
|
||||
@ -173,8 +158,8 @@ def _disable_dynamo_if_unsupported(
|
||||
# torch.jit.script nor differentiable, so we fall back to the single tensor
|
||||
# implementation in those cases.
|
||||
def _default_to_fused_or_foreach(
|
||||
params: List[torch.Tensor], differentiable: bool, use_fused: bool = False
|
||||
) -> Tuple[bool, bool]:
|
||||
params: list[torch.Tensor], differentiable: bool, use_fused: bool = False
|
||||
) -> tuple[bool, bool]:
|
||||
if torch.jit.is_scripting() or differentiable:
|
||||
return False, False
|
||||
|
||||
@ -229,7 +214,7 @@ def _get_scalar_dtype(is_fused=None):
|
||||
)
|
||||
|
||||
|
||||
def _get_capturable_supported_devices(supports_xla: bool = True) -> List[str]:
|
||||
def _get_capturable_supported_devices(supports_xla: bool = True) -> list[str]:
|
||||
r"""Return the device type list that supports capturable optimizer."""
|
||||
capturable_supported_devices = ["cuda", "xpu", "hpu"]
|
||||
if not torch.jit.is_scripting():
|
||||
@ -321,7 +306,7 @@ def register_optimizer_step_post_hook(hook: GlobalOptimizerPostHook) -> Removabl
|
||||
|
||||
|
||||
ParamsT: TypeAlias = Union[
|
||||
Iterable[torch.Tensor], Iterable[Dict[str, Any]], Iterable[Tuple[str, torch.Tensor]]
|
||||
Iterable[torch.Tensor], Iterable[dict[str, Any]], Iterable[tuple[str, torch.Tensor]]
|
||||
]
|
||||
|
||||
R = TypeVar("R")
|
||||
@ -343,17 +328,17 @@ class Optimizer:
|
||||
options (used when a parameter group doesn't specify them).
|
||||
"""
|
||||
|
||||
OptimizerPreHook: TypeAlias = Callable[[Self, Args, Kwargs], Optional[Tuple[Args, Kwargs]]] # type: ignore[misc]
|
||||
OptimizerPreHook: TypeAlias = Callable[[Self, Args, Kwargs], Optional[tuple[Args, Kwargs]]] # type: ignore[misc]
|
||||
OptimizerPostHook: TypeAlias = Callable[[Self, Args, Kwargs], None] # type: ignore[misc]
|
||||
|
||||
_optimizer_step_pre_hooks: Dict[int, OptimizerPreHook]
|
||||
_optimizer_step_post_hooks: Dict[int, OptimizerPostHook]
|
||||
_optimizer_step_pre_hooks: dict[int, OptimizerPreHook]
|
||||
_optimizer_step_post_hooks: dict[int, OptimizerPostHook]
|
||||
_optimizer_state_dict_pre_hooks: 'OrderedDict[int, Callable[["Optimizer"], None]]'
|
||||
_optimizer_state_dict_post_hooks: 'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]'
|
||||
_optimizer_load_state_dict_pre_hooks: 'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]'
|
||||
_optimizer_load_state_dict_post_hooks: 'OrderedDict[int, Callable[["Optimizer"], None]]'
|
||||
|
||||
def __init__(self, params: ParamsT, defaults: Dict[str, Any]) -> None: # noqa: D107
|
||||
def __init__(self, params: ParamsT, defaults: dict[str, Any]) -> None: # noqa: D107
|
||||
torch._C._log_api_usage_once("python.optimizer")
|
||||
self.defaults = defaults
|
||||
self._optimizer_step_pre_hooks = OrderedDict()
|
||||
@ -371,8 +356,8 @@ class Optimizer:
|
||||
"an iterable of Tensors or dicts, but got " + torch.typename(params)
|
||||
)
|
||||
|
||||
self.state: DefaultDict[torch.Tensor, Any] = defaultdict(dict)
|
||||
self.param_groups: List[Dict[str, Any]] = []
|
||||
self.state: defaultdict[torch.Tensor, Any] = defaultdict(dict)
|
||||
self.param_groups: list[dict[str, Any]] = []
|
||||
|
||||
param_groups = list(params)
|
||||
if len(param_groups) == 0:
|
||||
@ -388,14 +373,14 @@ class Optimizer:
|
||||
# https://github.com/pytorch/pytorch/issues/72948
|
||||
self._warned_capturable_if_run_uncaptured = True
|
||||
|
||||
def __getstate__(self) -> Dict[str, Any]: # noqa: D105
|
||||
def __getstate__(self) -> dict[str, Any]: # noqa: D105
|
||||
return {
|
||||
"defaults": self.defaults,
|
||||
"state": self.state,
|
||||
"param_groups": self.param_groups,
|
||||
}
|
||||
|
||||
def __setstate__(self, state: Dict[str, Any]) -> None: # noqa: D105
|
||||
def __setstate__(self, state: dict[str, Any]) -> None: # noqa: D105
|
||||
self.__dict__.update(state)
|
||||
if "_optimizer_step_pre_hooks" not in self.__dict__:
|
||||
self._optimizer_step_pre_hooks = OrderedDict()
|
||||
@ -516,8 +501,8 @@ class Optimizer:
|
||||
tensorlistlist: TensorListList,
|
||||
with_indices: bool = False,
|
||||
) -> Union[
|
||||
Dict[Tuple[None, None], Tuple[TensorListList, Indices]],
|
||||
Dict[Tuple[torch.device, torch.dtype], Tuple[TensorListList, Indices]],
|
||||
dict[tuple[None, None], tuple[TensorListList, Indices]],
|
||||
dict[tuple[torch.device, torch.dtype], tuple[TensorListList, Indices]],
|
||||
]:
|
||||
"""Group a list of lists of tensors by device and dtype.
|
||||
|
||||
@ -705,10 +690,10 @@ class Optimizer:
|
||||
pre_hook(self)
|
||||
|
||||
# Save order indices instead of Tensors
|
||||
param_mappings: Dict[int, int] = {}
|
||||
param_mappings: dict[int, int] = {}
|
||||
start_index = 0
|
||||
|
||||
def pack_group(group: Dict[str, Any]) -> Dict[str, Any]:
|
||||
def pack_group(group: dict[str, Any]) -> dict[str, Any]:
|
||||
nonlocal start_index
|
||||
packed = {k: v for k, v in group.items() if k != "params"}
|
||||
param_mappings.update(
|
||||
@ -745,7 +730,7 @@ class Optimizer:
|
||||
param: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
param_id: int,
|
||||
param_groups: List[Dict[Any, Any]],
|
||||
param_groups: list[dict[Any, Any]],
|
||||
key: Hashable = None,
|
||||
) -> torch.Tensor:
|
||||
# Floating-point types are a bit special here. They are the only ones
|
||||
@ -918,7 +903,7 @@ class Optimizer:
|
||||
# Copy state assigned to params (and cast tensors to appropriate types).
|
||||
# State that is not assigned to params is copied as is (needed for
|
||||
# backward compatibility).
|
||||
state: DefaultDict[torch.Tensor, Dict[Any, Any]] = defaultdict(dict)
|
||||
state: defaultdict[torch.Tensor, dict[Any, Any]] = defaultdict(dict)
|
||||
for k, v in state_dict["state"].items():
|
||||
if k in id_map:
|
||||
param = id_map[k]
|
||||
@ -930,8 +915,8 @@ class Optimizer:
|
||||
|
||||
# Update parameter groups, setting their 'params' value
|
||||
def update_group(
|
||||
group: Dict[str, Any], new_group: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
group: dict[str, Any], new_group: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
new_group["params"] = group["params"]
|
||||
if "param_names" in group and "param_names" not in new_group:
|
||||
new_group["param_names"] = group["param_names"]
|
||||
@ -967,7 +952,7 @@ class Optimizer:
|
||||
self._patch_step_function()
|
||||
|
||||
per_device_and_dtype_grads: Optional[
|
||||
DefaultDict[torch.device, DefaultDict[torch.dtype, List[torch.Tensor]]]
|
||||
defaultdict[torch.device, defaultdict[torch.dtype, list[torch.Tensor]]]
|
||||
]
|
||||
if foreach:
|
||||
per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))
|
||||
@ -1016,7 +1001,7 @@ class Optimizer:
|
||||
raise NotImplementedError
|
||||
|
||||
@torch._disable_dynamo
|
||||
def add_param_group(self, param_group: Dict[str, Any]) -> None:
|
||||
def add_param_group(self, param_group: dict[str, Any]) -> None:
|
||||
r"""Add a param group to the :class:`Optimizer` s `param_groups`.
|
||||
|
||||
This can be useful when fine tuning a pre-trained network as frozen layers can be made
|
||||
@ -1087,7 +1072,7 @@ class Optimizer:
|
||||
stacklevel=3,
|
||||
)
|
||||
|
||||
param_set: Set[torch.Tensor] = set()
|
||||
param_set: set[torch.Tensor] = set()
|
||||
for group in self.param_groups:
|
||||
param_set.update(set(group["params"]))
|
||||
if ("param_names" in param_group) != ("param_names" in group):
|
||||
|
@ -1,6 +1,6 @@
|
||||
# mypy: allow-untyped-defs
|
||||
r"""Implementation for the RMSprop algorithm."""
|
||||
from typing import cast, List, Optional, Union
|
||||
from typing import cast, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
@ -155,12 +155,12 @@ class RMSprop(Optimizer): # noqa: D101
|
||||
loss = closure()
|
||||
|
||||
for group in self.param_groups:
|
||||
params_with_grad: List[Tensor] = []
|
||||
grads: List[Tensor] = []
|
||||
square_avgs: List[Tensor] = []
|
||||
grad_avgs: List[Tensor] = []
|
||||
momentum_buffer_list: List[Tensor] = []
|
||||
state_steps: List[Tensor] = []
|
||||
params_with_grad: list[Tensor] = []
|
||||
grads: list[Tensor] = []
|
||||
square_avgs: list[Tensor] = []
|
||||
grad_avgs: list[Tensor] = []
|
||||
momentum_buffer_list: list[Tensor] = []
|
||||
state_steps: list[Tensor] = []
|
||||
|
||||
has_complex = self._init_group(
|
||||
group,
|
||||
@ -261,12 +261,12 @@ RMSprop.__doc__ = (
|
||||
|
||||
|
||||
def _single_tensor_rmsprop(
|
||||
params: List[Tensor],
|
||||
grads: List[Tensor],
|
||||
square_avgs: List[Tensor],
|
||||
grad_avgs: List[Tensor],
|
||||
momentum_buffer_list: List[Tensor],
|
||||
state_steps: List[Tensor],
|
||||
params: list[Tensor],
|
||||
grads: list[Tensor],
|
||||
square_avgs: list[Tensor],
|
||||
grad_avgs: list[Tensor],
|
||||
momentum_buffer_list: list[Tensor],
|
||||
state_steps: list[Tensor],
|
||||
*,
|
||||
lr: float,
|
||||
alpha: float,
|
||||
@ -332,12 +332,12 @@ def _single_tensor_rmsprop(
|
||||
|
||||
|
||||
def _multi_tensor_rmsprop(
|
||||
params: List[Tensor],
|
||||
grads: List[Tensor],
|
||||
square_avgs: List[Tensor],
|
||||
grad_avgs: List[Tensor],
|
||||
momentum_buffer_list: List[Tensor],
|
||||
state_steps: List[Tensor],
|
||||
params: list[Tensor],
|
||||
grads: list[Tensor],
|
||||
square_avgs: list[Tensor],
|
||||
grad_avgs: list[Tensor],
|
||||
momentum_buffer_list: list[Tensor],
|
||||
state_steps: list[Tensor],
|
||||
*,
|
||||
lr: float,
|
||||
alpha: float,
|
||||
@ -377,20 +377,20 @@ def _multi_tensor_rmsprop(
|
||||
grouped_state_steps_,
|
||||
)
|
||||
), _ in grouped_tensors.values():
|
||||
grouped_params = cast(List[Tensor], grouped_params_)
|
||||
grouped_grads = cast(List[Tensor], grouped_grads_)
|
||||
grouped_square_avgs = cast(List[Tensor], grouped_square_avgs_)
|
||||
grouped_state_steps = cast(List[Tensor], grouped_state_steps_)
|
||||
grouped_params = cast(list[Tensor], grouped_params_)
|
||||
grouped_grads = cast(list[Tensor], grouped_grads_)
|
||||
grouped_square_avgs = cast(list[Tensor], grouped_square_avgs_)
|
||||
grouped_state_steps = cast(list[Tensor], grouped_state_steps_)
|
||||
|
||||
if has_complex:
|
||||
state_and_grads = [grouped_grads, grouped_square_avgs]
|
||||
if momentum > 0:
|
||||
grouped_momentum_buffer_list = cast(
|
||||
List[Tensor], grouped_momentum_buffer_list_
|
||||
list[Tensor], grouped_momentum_buffer_list_
|
||||
)
|
||||
state_and_grads.append(grouped_momentum_buffer_list)
|
||||
if centered:
|
||||
grouped_grad_avgs = cast(List[Tensor], grouped_grad_avgs_)
|
||||
grouped_grad_avgs = cast(list[Tensor], grouped_grad_avgs_)
|
||||
state_and_grads.append(grouped_grad_avgs)
|
||||
_view_as_real(grouped_params, *state_and_grads)
|
||||
|
||||
@ -423,7 +423,7 @@ def _multi_tensor_rmsprop(
|
||||
)
|
||||
|
||||
if centered:
|
||||
grouped_grad_avgs = cast(List[Tensor], grouped_grad_avgs_)
|
||||
grouped_grad_avgs = cast(list[Tensor], grouped_grad_avgs_)
|
||||
torch._foreach_lerp_(grouped_grad_avgs, grouped_grads, 1 - alpha)
|
||||
avg = torch._foreach_addcmul(
|
||||
grouped_square_avgs, grouped_grad_avgs, grouped_grad_avgs, value=-1
|
||||
@ -436,7 +436,7 @@ def _multi_tensor_rmsprop(
|
||||
|
||||
if momentum > 0:
|
||||
grouped_momentum_buffer_list = cast(
|
||||
List[Tensor], grouped_momentum_buffer_list_
|
||||
list[Tensor], grouped_momentum_buffer_list_
|
||||
)
|
||||
torch._foreach_mul_(grouped_momentum_buffer_list, momentum)
|
||||
torch._foreach_addcdiv_(grouped_momentum_buffer_list, grouped_grads, avg)
|
||||
@ -461,12 +461,12 @@ def _multi_tensor_rmsprop(
|
||||
|
||||
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_rmsprop)
|
||||
def rmsprop(
|
||||
params: List[Tensor],
|
||||
grads: List[Tensor],
|
||||
square_avgs: List[Tensor],
|
||||
grad_avgs: List[Tensor],
|
||||
momentum_buffer_list: List[Tensor],
|
||||
state_steps: List[Tensor],
|
||||
params: list[Tensor],
|
||||
grads: list[Tensor],
|
||||
square_avgs: list[Tensor],
|
||||
grad_avgs: list[Tensor],
|
||||
momentum_buffer_list: list[Tensor],
|
||||
state_steps: list[Tensor],
|
||||
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
|
||||
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
|
||||
foreach: Optional[bool] = None,
|
||||
|
@ -1,6 +1,6 @@
|
||||
# mypy: allow-untyped-defs
|
||||
r"""Implementation for the Resilient backpropagation."""
|
||||
from typing import cast, List, Optional, Tuple, Union
|
||||
from typing import cast, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
@ -30,8 +30,8 @@ class Rprop(Optimizer): # noqa: D101
|
||||
self,
|
||||
params: ParamsT,
|
||||
lr: Union[float, Tensor] = 1e-2,
|
||||
etas: Tuple[float, float] = (0.5, 1.2),
|
||||
step_sizes: Tuple[float, float] = (1e-6, 50),
|
||||
etas: tuple[float, float] = (0.5, 1.2),
|
||||
step_sizes: tuple[float, float] = (1e-6, 50),
|
||||
*,
|
||||
capturable: bool = False,
|
||||
foreach: Optional[bool] = None,
|
||||
@ -129,11 +129,11 @@ class Rprop(Optimizer): # noqa: D101
|
||||
loss = closure()
|
||||
|
||||
for group in self.param_groups:
|
||||
params: List[Tensor] = []
|
||||
grads: List[Tensor] = []
|
||||
prevs: List[Tensor] = []
|
||||
step_sizes: List[Tensor] = []
|
||||
state_steps: List[Tensor] = []
|
||||
params: list[Tensor] = []
|
||||
grads: list[Tensor] = []
|
||||
prevs: list[Tensor] = []
|
||||
step_sizes: list[Tensor] = []
|
||||
state_steps: list[Tensor] = []
|
||||
|
||||
etaminus, etaplus = group["etas"]
|
||||
step_size_min, step_size_max = group["step_sizes"]
|
||||
@ -219,11 +219,11 @@ Rprop.__doc__ = (
|
||||
|
||||
|
||||
def _single_tensor_rprop(
|
||||
params: List[Tensor],
|
||||
grads: List[Tensor],
|
||||
prevs: List[Tensor],
|
||||
step_sizes: List[Tensor],
|
||||
state_steps: List[Tensor],
|
||||
params: list[Tensor],
|
||||
grads: list[Tensor],
|
||||
prevs: list[Tensor],
|
||||
step_sizes: list[Tensor],
|
||||
state_steps: list[Tensor],
|
||||
*,
|
||||
step_size_min: float,
|
||||
step_size_max: float,
|
||||
@ -287,11 +287,11 @@ def _single_tensor_rprop(
|
||||
|
||||
|
||||
def _multi_tensor_rprop(
|
||||
params: List[Tensor],
|
||||
grads: List[Tensor],
|
||||
prevs: List[Tensor],
|
||||
step_sizes: List[Tensor],
|
||||
state_steps: List[Tensor],
|
||||
params: list[Tensor],
|
||||
grads: list[Tensor],
|
||||
prevs: list[Tensor],
|
||||
step_sizes: list[Tensor],
|
||||
state_steps: list[Tensor],
|
||||
*,
|
||||
step_size_min: float,
|
||||
step_size_max: float,
|
||||
@ -326,11 +326,11 @@ def _multi_tensor_rprop(
|
||||
grouped_step_sizes_,
|
||||
grouped_state_steps_,
|
||||
), _ in grouped_tensors.values():
|
||||
grouped_params = cast(List[Tensor], grouped_params_)
|
||||
grouped_grads = cast(List[Tensor], grouped_grads_)
|
||||
grouped_prevs = cast(List[Tensor], grouped_prevs_)
|
||||
grouped_step_sizes = cast(List[Tensor], grouped_step_sizes_)
|
||||
grouped_state_steps = cast(List[Tensor], grouped_state_steps_)
|
||||
grouped_params = cast(list[Tensor], grouped_params_)
|
||||
grouped_grads = cast(list[Tensor], grouped_grads_)
|
||||
grouped_prevs = cast(list[Tensor], grouped_prevs_)
|
||||
grouped_step_sizes = cast(list[Tensor], grouped_step_sizes_)
|
||||
grouped_state_steps = cast(list[Tensor], grouped_state_steps_)
|
||||
|
||||
# Update steps
|
||||
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
|
||||
@ -402,11 +402,11 @@ def _multi_tensor_rprop(
|
||||
|
||||
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_rprop)
|
||||
def rprop(
|
||||
params: List[Tensor],
|
||||
grads: List[Tensor],
|
||||
prevs: List[Tensor],
|
||||
step_sizes: List[Tensor],
|
||||
state_steps: List[Tensor],
|
||||
params: list[Tensor],
|
||||
grads: list[Tensor],
|
||||
prevs: list[Tensor],
|
||||
step_sizes: list[Tensor],
|
||||
state_steps: list[Tensor],
|
||||
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
|
||||
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
|
||||
foreach: Optional[bool] = None,
|
||||
|
@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
r"""Implementation for Stochastic Gradient Descent optimizer."""
from typing import cast, List, Optional, Union
from typing import cast, Optional, Union

import torch
from torch import Tensor
|
||||
@ -114,9 +114,9 @@ class SGD(Optimizer): # noqa: D101
|
||||
loss = closure()
|
||||
|
||||
for group in self.param_groups:
|
||||
params: List[Tensor] = []
|
||||
grads: List[Tensor] = []
|
||||
momentum_buffer_list: List[Optional[Tensor]] = []
|
||||
params: list[Tensor] = []
|
||||
grads: list[Tensor] = []
|
||||
momentum_buffer_list: list[Optional[Tensor]] = []
|
||||
|
||||
has_sparse_grad = self._init_group(
|
||||
group, params, grads, momentum_buffer_list
|
||||
@ -244,9 +244,9 @@ SGD.__doc__ = (
|
||||
|
||||
|
||||
def sgd(
|
||||
params: List[Tensor],
|
||||
d_p_list: List[Tensor],
|
||||
momentum_buffer_list: List[Optional[Tensor]],
|
||||
params: list[Tensor],
|
||||
d_p_list: list[Tensor],
|
||||
momentum_buffer_list: list[Optional[Tensor]],
|
||||
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
|
||||
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
|
||||
has_sparse_grad: bool = False,
|
||||
@ -314,9 +314,9 @@ def sgd(
|
||||
|
||||
|
||||
def _single_tensor_sgd(
|
||||
params: List[Tensor],
|
||||
grads: List[Tensor],
|
||||
momentum_buffer_list: List[Optional[Tensor]],
|
||||
params: list[Tensor],
|
||||
grads: list[Tensor],
|
||||
momentum_buffer_list: list[Optional[Tensor]],
|
||||
grad_scale: Optional[Tensor],
|
||||
found_inf: Optional[Tensor],
|
||||
*,
|
||||
@ -369,9 +369,9 @@ def _single_tensor_sgd(
|
||||
|
||||
|
||||
def _multi_tensor_sgd(
|
||||
params: List[Tensor],
|
||||
grads: List[Tensor],
|
||||
momentum_buffer_list: List[Optional[Tensor]],
|
||||
params: list[Tensor],
|
||||
grads: list[Tensor],
|
||||
momentum_buffer_list: list[Optional[Tensor]],
|
||||
grad_scale: Optional[Tensor],
|
||||
found_inf: Optional[Tensor],
|
||||
*,
|
||||
@ -397,8 +397,8 @@ def _multi_tensor_sgd(
|
||||
device_grads_,
|
||||
device_momentum_buffer_list,
|
||||
), indices in grouped_tensors.values():
|
||||
device_params: List[Tensor] = cast(List[Tensor], device_params_)
|
||||
device_grads: List[Tensor] = cast(List[Tensor], device_grads_)
|
||||
device_params: list[Tensor] = cast(list[Tensor], device_params_)
|
||||
device_grads: list[Tensor] = cast(list[Tensor], device_grads_)
|
||||
|
||||
device_has_sparse_grad = has_sparse_grad and any(
|
||||
grad.is_sparse for grad in device_grads
|
||||
@ -417,7 +417,7 @@ def _multi_tensor_sgd(
|
||||
)
|
||||
|
||||
if momentum != 0:
|
||||
bufs: List[Tensor] = []
|
||||
bufs: list[Tensor] = []
|
||||
|
||||
all_states_with_momentum_buffer = True
|
||||
for i in range(len(device_momentum_buffer_list)):
|
||||
@ -462,9 +462,9 @@ def _multi_tensor_sgd(
|
||||
|
||||
|
||||
def _fused_sgd(
|
||||
params: List[Tensor],
|
||||
grads: List[Tensor],
|
||||
momentum_buffer_list: List[Optional[Tensor]],
|
||||
params: list[Tensor],
|
||||
grads: list[Tensor],
|
||||
momentum_buffer_list: list[Optional[Tensor]],
|
||||
grad_scale: Optional[Tensor],
|
||||
found_inf: Optional[Tensor],
|
||||
*,
|
||||
@ -501,8 +501,8 @@ def _fused_sgd(
|
||||
(device_params_, device_grads_, device_momentum_buffer_list),
|
||||
_,
|
||||
) in grouped_tensors.items():
|
||||
device_params: List[Tensor] = cast(List[Tensor], device_params_)
|
||||
device_grads: List[Tensor] = cast(List[Tensor], device_grads_)
|
||||
device_params: list[Tensor] = cast(list[Tensor], device_params_)
|
||||
device_grads: list[Tensor] = cast(list[Tensor], device_grads_)
|
||||
device_grad_scale, device_found_inf = None, None
|
||||
if grad_scale is not None:
|
||||
device_grad_scale = grad_scale_dict.setdefault(
|
||||
@ -515,7 +515,7 @@ def _fused_sgd(
|
||||
device_grads,
|
||||
[]
|
||||
if no_momentum_buffer
|
||||
else cast(List[Tensor], device_momentum_buffer_list),
|
||||
else cast(list[Tensor], device_momentum_buffer_list),
|
||||
weight_decay=weight_decay,
|
||||
momentum=momentum,
|
||||
lr=lr,
|
||||
|
@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import List, Tuple, Union
from typing import Union

import torch
from torch import Tensor
|
||||
@ -16,7 +16,7 @@ class SparseAdam(Optimizer):
|
||||
self,
|
||||
params: ParamsT,
|
||||
lr: Union[float, Tensor] = 1e-3,
|
||||
betas: Tuple[float, float] = (0.9, 0.999),
|
||||
betas: tuple[float, float] = (0.9, 0.999),
|
||||
eps: float = 1e-8,
|
||||
maximize: bool = False,
|
||||
):
|
||||
@ -69,11 +69,11 @@ class SparseAdam(Optimizer):
|
||||
loss = closure()
|
||||
|
||||
for group in self.param_groups:
|
||||
params_with_grad: List[Tensor] = []
|
||||
grads: List[Tensor] = []
|
||||
exp_avgs: List[Tensor] = []
|
||||
exp_avg_sqs: List[Tensor] = []
|
||||
state_steps: List[int] = []
|
||||
params_with_grad: list[Tensor] = []
|
||||
grads: list[Tensor] = []
|
||||
exp_avgs: list[Tensor] = []
|
||||
exp_avg_sqs: list[Tensor] = []
|
||||
state_steps: list[int] = []
|
||||
beta1, beta2 = group["betas"]
|
||||
maximize = group.get("maximize", False)
|
||||
|
||||
|
@ -3,8 +3,9 @@ r"""Implementation for Stochastic Weight Averaging implementation."""
import itertools
import math
import warnings
from collections.abc import Iterable
from copy import deepcopy
from typing import Any, Callable, Iterable, List, Literal, Optional, Tuple, Union
from typing import Any, Callable, Literal, Optional, Union

import torch
from torch import Tensor
@ -28,7 +29,7 @@ __all__ = [
from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype


PARAM_LIST = Union[Tuple[Tensor, ...], List[Tensor]]
PARAM_LIST = Union[tuple[Tensor, ...], list[Tensor]]


def get_ema_multi_avg_fn(decay=0.999):
|
||||
@ -253,8 +254,8 @@ class AveragedModel(Module):
|
||||
if self.use_buffers
|
||||
else model.parameters()
|
||||
)
|
||||
self_param_detached: List[Optional[Tensor]] = []
|
||||
model_param_detached: List[Optional[Tensor]] = []
|
||||
self_param_detached: list[Optional[Tensor]] = []
|
||||
model_param_detached: list[Optional[Tensor]] = []
|
||||
for p_averaged, p_model in zip(self_param, model_param):
|
||||
p_model_ = p_model.detach().to(p_averaged.device)
|
||||
self_param_detached.append(p_averaged.detach())
|
||||
|
@ -1,6 +1,5 @@
# mypy: allow-untyped-defs
from collections import deque
from typing import List, Set


class DiGraph:
|
||||
@ -90,7 +89,7 @@ class DiGraph:
|
||||
except TypeError:
|
||||
return False
|
||||
|
||||
def forward_transitive_closure(self, src: str) -> Set[str]:
|
||||
def forward_transitive_closure(self, src: str) -> set[str]:
|
||||
"""Returns a set of nodes that are reachable from src"""
|
||||
|
||||
result = set(src)
|
||||
@ -103,7 +102,7 @@ class DiGraph:
|
||||
working_set.append(n)
|
||||
return result
|
||||
|
||||
def backward_transitive_closure(self, src: str) -> Set[str]:
|
||||
def backward_transitive_closure(self, src: str) -> set[str]:
|
||||
"""Returns a set of nodes that are reachable from src in reverse direction"""
|
||||
|
||||
result = set(src)
|
||||
@ -140,7 +139,7 @@ class DiGraph:
|
||||
|
||||
return result_graph.to_dot()
|
||||
|
||||
def first_path(self, dst: str) -> List[str]:
|
||||
def first_path(self, dst: str) -> list[str]:
|
||||
"""Returns a list of nodes that show the first path that resulted in dst being added to the graph."""
|
||||
path = []
|
||||
|
||||
|
@ -1,12 +1,10 @@
from typing import Dict, List

from torch.package.package_exporter import PackagingError


__all__ = ["find_first_use_of_broken_modules"]


def find_first_use_of_broken_modules(exc: PackagingError) -> Dict[str, List[str]]:
def find_first_use_of_broken_modules(exc: PackagingError) -> dict[str, list[str]]:
|
||||
"""
|
||||
Find all broken modules in a PackagingError, and for each one, return the
|
||||
dependency path in which the module was first encountered.
|
||||
|
@ -1,14 +1,15 @@
# mypy: allow-untyped-defs
import sys
from typing import Any, Callable, Iterable, List, Tuple
from collections.abc import Iterable
from typing import Any, Callable


__all__ = ["trace_dependencies"]


def trace_dependencies(
callable: Callable[[Any], Any], inputs: Iterable[Tuple[Any, ...]]
) -> List[str]:
callable: Callable[[Any], Any], inputs: Iterable[tuple[Any, ...]]
) -> list[str]:
|
||||
"""Trace the execution of a callable in order to determine which modules it uses.
|
||||
|
||||
Args:
|
||||
|
@ -1,5 +1,4 @@
# mypy: allow-untyped-defs
from typing import Dict, List

from .glob_group import GlobGroup, GlobPattern
|
||||
|
||||
@ -15,9 +14,9 @@ class Directory:
|
||||
def __init__(self, name: str, is_dir: bool):
|
||||
self.name = name
|
||||
self.is_dir = is_dir
|
||||
self.children: Dict[str, Directory] = {}
|
||||
self.children: dict[str, Directory] = {}
|
||||
|
||||
def _get_dir(self, dirs: List[str]) -> "Directory":
|
||||
def _get_dir(self, dirs: list[str]) -> "Directory":
|
||||
"""Builds path of Directories if not yet built and returns last directory
|
||||
in list.
|
||||
|
||||
@ -64,13 +63,13 @@ class Directory:
|
||||
return False
|
||||
|
||||
def __str__(self):
|
||||
str_list: List[str] = []
|
||||
str_list: list[str] = []
|
||||
self._stringify_tree(str_list)
|
||||
return "".join(str_list)
|
||||
|
||||
def _stringify_tree(
|
||||
self,
|
||||
str_list: List[str],
|
||||
str_list: list[str],
|
||||
preamble: str = "",
|
||||
dir_ptr: str = "\u2500\u2500\u2500 ",
|
||||
):
|
||||
@ -89,8 +88,8 @@ class Directory:
|
||||
else:
|
||||
preamble = preamble + space
|
||||
|
||||
file_keys: List[str] = []
|
||||
dir_keys: List[str] = []
|
||||
file_keys: list[str] = []
|
||||
dir_keys: list[str] = []
|
||||
for key, val in self.children.items():
|
||||
if val.is_dir:
|
||||
dir_keys.append(key)
|
||||
@ -109,7 +108,7 @@ class Directory:
|
||||
|
||||
def _create_directory_from_file_list(
|
||||
filename: str,
|
||||
file_list: List[str],
|
||||
file_list: list[str],
|
||||
include: "GlobPattern" = "**",
|
||||
exclude: "GlobPattern" = (),
|
||||
) -> Directory:
|
||||
|
@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
import ast
from typing import List, Optional, Tuple
from typing import Optional

from ._importlib import _resolve_name
|
||||
|
||||
@ -11,7 +11,7 @@ class _ExtractModuleReferences(ast.NodeVisitor):
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def run(cls, src: str, package: str) -> List[Tuple[str, Optional[str]]]:
|
||||
def run(cls, src: str, package: str) -> list[tuple[str, Optional[str]]]:
|
||||
visitor = cls(package)
|
||||
tree = ast.parse(src)
|
||||
visitor.visit(tree)
|
||||
@ -53,7 +53,7 @@ class _ExtractModuleReferences(ast.NodeVisitor):
|
||||
if hasattr(node.func, "id") and node.func.id == "__import__":
|
||||
try:
|
||||
name = self._grab_node_str(node.args[0])
|
||||
fromlist: List[str] = []
|
||||
fromlist: list[str] = []
|
||||
level = 0
|
||||
if len(node.args) > 3:
|
||||
fromlist.extend(self._grab_node_str(v) for v in node.args[3].elts)
|
||||
|
@ -1,6 +1,7 @@
# mypy: allow-untyped-defs
import re
from typing import Iterable, Union
from collections.abc import Iterable
from typing import Union


GlobPattern = Union[str, Iterable[str]]
|
||||
|
@ -7,7 +7,7 @@ from pickle import ( # type: ignore[attr-defined]
whichmodule as _pickle_whichmodule,
)
from types import ModuleType
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Optional

from ._mangling import demangle, get_mangle_prefix, is_mangled
|
||||
|
||||
@ -44,7 +44,7 @@ class Importer(ABC):
|
||||
assert obj1 is obj2
|
||||
"""
|
||||
|
||||
modules: Dict[str, ModuleType]
|
||||
modules: dict[str, ModuleType]
|
||||
|
||||
@abstractmethod
|
||||
def import_module(self, module_name: str) -> ModuleType:
|
||||
@ -53,7 +53,7 @@ class Importer(ABC):
|
||||
The contract is the same as for importlib.import_module.
|
||||
"""
|
||||
|
||||
def get_name(self, obj: Any, name: Optional[str] = None) -> Tuple[str, str]:
|
||||
def get_name(self, obj: Any, name: Optional[str] = None) -> tuple[str, str]:
|
||||
"""Given an object, return a name that can be used to retrieve the
|
||||
object from this environment.
|
||||
|
||||
@ -184,7 +184,7 @@ class OrderedImporter(Importer):
|
||||
"""
|
||||
|
||||
def __init__(self, *args):
|
||||
self._importers: List[Importer] = list(args)
|
||||
self._importers: list[Importer] = list(args)
|
||||
|
||||
def _is_torchpackage_dummy(self, module):
|
||||
"""Returns true iff this module is an empty PackageNode in a torch.package.
|
||||
|
@ -7,23 +7,12 @@ import pickletools
import platform
import types
from collections import defaultdict, OrderedDict
from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum
from importlib.machinery import SourceFileLoader
from pathlib import Path
from typing import (
Any,
BinaryIO,
Callable,
cast,
DefaultDict,
Dict,
List,
Optional,
Sequence,
Set,
Union,
)
from typing import Any, BinaryIO, Callable, cast, Optional, Union

import torch
from torch.serialization import location_tag, normalize_storage_type
|
||||
@ -133,7 +122,7 @@ class PackagingError(Exception):
|
||||
|
||||
def __init__(self, dependency_graph: DiGraph, debug=False):
|
||||
# Group errors by reason.
|
||||
broken: Dict[PackagingErrorReason, List[str]] = defaultdict(list)
|
||||
broken: dict[PackagingErrorReason, list[str]] = defaultdict(list)
|
||||
for module_name, attrs in dependency_graph.nodes.items():
|
||||
error = attrs.get("error")
|
||||
if error is None:
|
||||
@ -236,9 +225,9 @@ class PackageExporter:
|
||||
|
||||
self.zip_file = torch._C.PyTorchFileWriter(f)
|
||||
self.zip_file.set_min_version(6)
|
||||
self._written_files: Set[str] = set()
|
||||
self._written_files: set[str] = set()
|
||||
|
||||
self.serialized_reduces: Dict[int, Any] = {}
|
||||
self.serialized_reduces: dict[int, Any] = {}
|
||||
|
||||
# A graph tracking all the modules and pickle objects added to this
|
||||
# package and the dependencies between them.
|
||||
@ -266,7 +255,7 @@ class PackageExporter:
|
||||
)
|
||||
self.importer = OrderedImporter(*importer)
|
||||
|
||||
self.patterns: Dict[GlobGroup, _PatternInfo] = {}
|
||||
self.patterns: dict[GlobGroup, _PatternInfo] = {}
|
||||
self._unique_id = 0
|
||||
|
||||
def save_source_file(
|
||||
@ -331,7 +320,7 @@ class PackageExporter:
|
||||
|
||||
def _get_dependencies(
|
||||
self, src: str, module_name: str, is_package: bool
|
||||
) -> List[str]:
|
||||
) -> list[str]:
|
||||
"""Return all modules that this source code depends on.
|
||||
|
||||
Dependencies are found by scanning the source code for import-like statements.
|
||||
@ -659,7 +648,7 @@ class PackageExporter:
|
||||
all_dependencies = []
|
||||
module = None
|
||||
field = None
|
||||
memo: DefaultDict[int, str] = defaultdict(None)
|
||||
memo: defaultdict[int, str] = defaultdict(None)
|
||||
memo_count = 0
|
||||
# pickletools.dis(data_value)
|
||||
for opcode, arg, _pos in pickletools.genops(data_value):
|
||||
@ -1115,7 +1104,7 @@ class PackageExporter:
|
||||
|
||||
def _nodes_with_action_type(
|
||||
self, action: Optional[_ModuleProviderAction]
|
||||
) -> List[str]:
|
||||
) -> list[str]:
|
||||
result = []
|
||||
for name, node_dict in self.dependency_graph.nodes.items():
|
||||
node_action = node_dict.get("action", None)
|
||||
@ -1124,7 +1113,7 @@ class PackageExporter:
|
||||
result.sort()
|
||||
return result
|
||||
|
||||
def externed_modules(self) -> List[str]:
|
||||
def externed_modules(self) -> list[str]:
|
||||
"""Return all modules that are currently externed.
|
||||
|
||||
Returns:
|
||||
@ -1133,7 +1122,7 @@ class PackageExporter:
|
||||
"""
|
||||
return self._nodes_with_action_type(_ModuleProviderAction.EXTERN)
|
||||
|
||||
def interned_modules(self) -> List[str]:
|
||||
def interned_modules(self) -> list[str]:
|
||||
"""Return all modules that are currently interned.
|
||||
|
||||
Returns:
|
||||
@ -1142,7 +1131,7 @@ class PackageExporter:
|
||||
"""
|
||||
return self._nodes_with_action_type(_ModuleProviderAction.INTERN)
|
||||
|
||||
def mocked_modules(self) -> List[str]:
|
||||
def mocked_modules(self) -> list[str]:
|
||||
"""Return all modules that are currently mocked.
|
||||
|
||||
Returns:
|
||||
@ -1151,7 +1140,7 @@ class PackageExporter:
|
||||
"""
|
||||
return self._nodes_with_action_type(_ModuleProviderAction.MOCK)
|
||||
|
||||
def denied_modules(self) -> List[str]:
|
||||
def denied_modules(self) -> list[str]:
|
||||
"""Return all modules that are currently denied.
|
||||
|
||||
Returns:
|
||||
@ -1160,7 +1149,7 @@ class PackageExporter:
|
||||
"""
|
||||
return self._nodes_with_action_type(_ModuleProviderAction.DENY)
|
||||
|
||||
def get_rdeps(self, module_name: str) -> List[str]:
|
||||
def get_rdeps(self, module_name: str) -> list[str]:
|
||||
"""Return a list of all modules which depend on the module ``module_name``.
|
||||
|
||||
Returns:
|
||||
|
@ -8,19 +8,9 @@ import linecache
import os
import sys
import types
from collections.abc import Iterable
from contextlib import contextmanager
from typing import (
Any,
BinaryIO,
Callable,
cast,
Dict,
Iterable,
List,
Optional,
TYPE_CHECKING,
Union,
)
from typing import Any, BinaryIO, Callable, cast, Optional, TYPE_CHECKING, Union
from weakref import WeakValueDictionary

import torch
|
||||
@ -64,7 +54,7 @@ IMPLICIT_IMPORT_ALLOWLIST: Iterable[str] = [
|
||||
# The primary motivation is to enable Numpy upgrade that many modules
|
||||
# depend on. The latest release of Numpy removed `numpy.str` and
|
||||
# `numpy.bool` breaking unpickling for many modules.
|
||||
EXTERN_IMPORT_COMPAT_NAME_MAPPING: Dict[str, Dict[str, Any]] = {
|
||||
EXTERN_IMPORT_COMPAT_NAME_MAPPING: dict[str, dict[str, Any]] = {
|
||||
"numpy": {
|
||||
"str": str,
|
||||
"bool": bool,
|
||||
@ -90,7 +80,7 @@ class PackageImporter(Importer):
|
||||
local to this importer.
|
||||
"""
|
||||
|
||||
modules: Dict[str, types.ModuleType]
|
||||
modules: dict[str, types.ModuleType]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -646,7 +636,7 @@ class PackageImporter(Importer):
|
||||
return f"{name.replace('.', '/')}"
|
||||
|
||||
def _get_or_create_package(
|
||||
self, atoms: List[str]
|
||||
self, atoms: list[str]
|
||||
) -> "Union[_PackageNode, _ExternNode]":
|
||||
cur = self.root
|
||||
for i, atom in enumerate(atoms):
|
||||
@ -705,7 +695,7 @@ class _PathNode:
|
||||
class _PackageNode(_PathNode):
|
||||
def __init__(self, source_file: Optional[str]):
|
||||
self.source_file = source_file
|
||||
self.children: Dict[str, _PathNode] = {}
|
||||
self.children: dict[str, _PathNode] = {}
|
||||
|
||||
|
||||
class _ModuleNode(_PathNode):
|
||||
|
@ -4,7 +4,8 @@ import dataclasses
import enum
import itertools as it
import logging
from typing import Any, cast, DefaultDict, Dict, Iterator, List, Optional, Set, Union
from collections.abc import Iterator
from typing import Any, cast, Optional, Union
from typing_extensions import Literal

import torch
|
||||
@ -226,7 +227,7 @@ class SchemaMatcher:
|
||||
overload. If we cannot find any valid schema then we must be
|
||||
conservative and assume all inputs are mutable.
|
||||
"""
|
||||
mutable: Optional[List[bool]] = None
|
||||
mutable: Optional[list[bool]] = None
|
||||
for schema in cls.match_schemas(t):
|
||||
mutable = mutable or [False for _ in schema.arguments]
|
||||
for i, arg in enumerate(schema.arguments):
|
||||
@ -327,7 +328,7 @@ class OpTree:
|
||||
|
||||
class SizeMap:
|
||||
def __init__(self, op_tree: OpTree) -> None:
|
||||
self._values: Dict[TensorKey, int] = {}
|
||||
self._values: dict[TensorKey, int] = {}
|
||||
|
||||
for node in op_tree.sorted_nodes:
|
||||
if node.typed[0] == _EventType.TorchOp:
|
||||
@ -349,7 +350,7 @@ class SizeMap:
|
||||
for _, t in state:
|
||||
self._update_values(t)
|
||||
|
||||
allocations: Dict[TensorKey, int] = {}
|
||||
allocations: dict[TensorKey, int] = {}
|
||||
for node in op_tree.sorted_nodes:
|
||||
if node.typed[0] == _EventType.Allocation:
|
||||
alloc_fields = node.typed[1]
|
||||
@ -410,7 +411,7 @@ class DataFlowNode:
|
||||
def __init__(self, event: _ProfilerEvent, graph: "DataFlowGraph") -> None:
|
||||
self._event = event
|
||||
self._graph = graph
|
||||
self._edges: Dict[TensorKey, DataFlowEdge] = self._determine_edges()
|
||||
self._edges: dict[TensorKey, DataFlowEdge] = self._determine_edges()
|
||||
|
||||
for key, edge in self._edges.items():
|
||||
if edge.mutated and not edge.is_allocation:
|
||||
@ -420,11 +421,11 @@ class DataFlowNode:
|
||||
versions = {k: (v, self._graph.lookup(k)) for k, v in self.outputs.items()}
|
||||
assert all(i == j for i, j in versions.values()), f"{versions}, {self._edges}"
|
||||
|
||||
def _determine_edges(self) -> Dict[TensorKey, DataFlowEdge]:
|
||||
def _determine_edges(self) -> dict[TensorKey, DataFlowEdge]:
|
||||
subtree = tuple(_utils.traverse_dfs([self._event]))
|
||||
|
||||
# Start by populating edges from op inputs and outputs.
|
||||
mutable_by_key: Dict[Optional[TensorKey], Set[Optional[bool]]] = {}
|
||||
mutable_by_key: dict[Optional[TensorKey], set[Optional[bool]]] = {}
|
||||
for op in (i.typed[1] for i in subtree if i.typed[0] == _EventType.TorchOp):
|
||||
for op_input, mutable in zip(
|
||||
op.inputs, SchemaMatcher.inputs_are_mutable(op)
|
||||
@ -440,7 +441,7 @@ class DataFlowNode:
|
||||
key = TensorKey.from_tensor(op_input_i)
|
||||
mutable_by_key.setdefault(key, set()).add(mutable)
|
||||
|
||||
edges: DefaultDict[Optional[TensorKey], DataFlowEdge]
|
||||
edges: collections.defaultdict[Optional[TensorKey], DataFlowEdge]
|
||||
edges = collections.defaultdict(DataFlowEdge)
|
||||
for key, mutable_set in mutable_by_key.items():
|
||||
if key is not None:
|
||||
@ -472,7 +473,7 @@ class DataFlowNode:
|
||||
return dict(sorted((k, v) for k, v in edges.items() if k is not None))
|
||||
|
||||
@property
|
||||
def inputs(self) -> Dict[TensorKey, tuple[bool, int]]:
|
||||
def inputs(self) -> dict[TensorKey, tuple[bool, int]]:
|
||||
return {
|
||||
# MyPy can't see through `is_allocation` to know that
|
||||
# `v.input_version` is not None.
|
||||
@ -482,7 +483,7 @@ class DataFlowNode:
|
||||
}
|
||||
|
||||
@property
|
||||
def outputs(self) -> Dict[TensorKey, int]:
|
||||
def outputs(self) -> dict[TensorKey, int]:
|
||||
return {
|
||||
k: 0 if v.input_version is None else v.input_version + 1
|
||||
for k, v in self._edges.items()
|
||||
@ -504,7 +505,7 @@ class DataFlowGraph:
|
||||
def __init__(self, op_tree: OpTree) -> None:
|
||||
self._op_tree = op_tree
|
||||
self._leaf_events = self._extract_leaf_events(op_tree)
|
||||
self._active_version: Dict[TensorKey, Optional[int]] = {}
|
||||
self._active_version: dict[TensorKey, Optional[int]] = {}
|
||||
self._flow_nodes = [DataFlowNode(e, self) for e in self.leaf_events]
|
||||
self._flow_nodes.sort(key=lambda x: x.start_time)
|
||||
self.validate()
|
||||
@ -515,7 +516,7 @@ class DataFlowGraph:
|
||||
|
||||
def validate(self):
|
||||
# Check that each (Tensor, version) pair has a unique creation node
|
||||
outputs: Set[tuple[TensorKey, int]] = set()
|
||||
outputs: set[tuple[TensorKey, int]] = set()
|
||||
for node in self.flow_nodes:
|
||||
node_outputs = set(node.outputs.items())
|
||||
duplicates = outputs & node_outputs
|
||||
@ -523,7 +524,7 @@ class DataFlowGraph:
|
||||
outputs |= node_outputs
|
||||
|
||||
# And check that `self._nodes` forms a valid topologically sorted DAG.
|
||||
tensor_versions: Dict[TensorKey, int] = {}
|
||||
tensor_versions: dict[TensorKey, int] = {}
|
||||
for node in self.flow_nodes:
|
||||
for key, (_, version) in node.inputs.items():
|
||||
expected = tensor_versions.get(key, 0)
|
||||
@ -571,7 +572,7 @@ class DataFlowGraph:
|
||||
form the logical nodes in our data flow graph.
|
||||
"""
|
||||
|
||||
leaf_events: List[_ProfilerEvent] = []
|
||||
leaf_events: list[_ProfilerEvent] = []
|
||||
|
||||
def leaf_op(e: _ProfilerEvent) -> bool:
|
||||
return e.typed[0] == _EventType.TorchOp and (
|
||||
@ -609,17 +610,17 @@ class DataFlowGraph:
|
||||
@dataclasses.dataclass
|
||||
class CategoryElement:
|
||||
by_id: Optional[Category] = None
|
||||
by_key: Dict[TensorKey, Category] = dataclasses.field(default_factory=dict)
|
||||
by_version: Dict[TensorAndID, Category] = dataclasses.field(default_factory=dict)
|
||||
by_key: dict[TensorKey, Category] = dataclasses.field(default_factory=dict)
|
||||
by_version: dict[TensorAndID, Category] = dataclasses.field(default_factory=dict)
|
||||
|
||||
# Used by unit tests to check internals. (And consequently by
|
||||
# MemoryProfile.lookup) This should not be used in any other capacity.
|
||||
_by_id_keyset: Set[TensorKey] = dataclasses.field(default_factory=set)
|
||||
_by_id_keyset: set[TensorKey] = dataclasses.field(default_factory=set)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CategoryDict:
|
||||
_values: DefaultDict[int, CategoryElement] = dataclasses.field(
|
||||
_values: collections.defaultdict[int, CategoryElement] = dataclasses.field(
|
||||
default_factory=lambda: collections.defaultdict(CategoryElement)
|
||||
)
|
||||
|
||||
@ -666,9 +667,9 @@ class MemoryProfile:
|
||||
|
||||
@property
|
||||
def timeline(self) -> tuple[tuple[int, Action, KeyAndID, int], ...]:
|
||||
output: List[tuple[int, Action, KeyAndID, int]] = []
|
||||
allocation_times: Dict[tuple[TensorKey, bool], int] = {}
|
||||
live_unknown: Dict[tuple[int, torch.device], Literal[True]] = {}
|
||||
output: list[tuple[int, Action, KeyAndID, int]] = []
|
||||
allocation_times: dict[tuple[TensorKey, bool], int] = {}
|
||||
live_unknown: dict[tuple[int, torch.device], Literal[True]] = {}
|
||||
for event in self._op_tree.dfs():
|
||||
if event.typed[0] == _EventType.Allocation:
|
||||
alloc_fields = event.typed[1]
|
||||
@ -701,7 +702,7 @@ class MemoryProfile:
|
||||
snapshot = self._category_snapshot()
|
||||
last_version = dict(sorted(snapshot.keys()))
|
||||
|
||||
events: List[tuple[int, Action, TensorAndID]] = [
|
||||
events: list[tuple[int, Action, TensorAndID]] = [
|
||||
(-1, Action.PREEXISTING, (key, version))
|
||||
for key, version in snapshot.keys()
|
||||
if (key, True) not in allocation_times and version == 0
|
||||
@ -734,8 +735,8 @@ class MemoryProfile:
|
||||
def _is_gradient(self, *args, **kwargs) -> bool:
|
||||
return self._categories.get(*args, **kwargs) == Category.GRADIENT
|
||||
|
||||
def _category_snapshot(self) -> Dict[TensorAndID, Optional[Category]]:
|
||||
all_tensor_versions: Set[TensorAndID] = set()
|
||||
def _category_snapshot(self) -> dict[TensorAndID, Optional[Category]]:
|
||||
all_tensor_versions: set[TensorAndID] = set()
|
||||
|
||||
for node in self._data_flow_graph.flow_nodes:
|
||||
all_tensor_versions.update(((k, v) for k, (_, v) in node.inputs.items()))
|
||||
@ -750,7 +751,7 @@ class MemoryProfile:
|
||||
for key, version in sorted(all_tensor_versions)
|
||||
}
|
||||
|
||||
def _any_version_depends_on_gradient(self) -> Set[int]:
|
||||
def _any_version_depends_on_gradient(self) -> set[int]:
|
||||
"""Extract IDs of Tensors which depend or will depend on a gradient.
|
||||
|
||||
Note that this weakened definition of "depends" requires us to loop
|
||||
@ -761,7 +762,7 @@ class MemoryProfile:
|
||||
acyclic data flow graph into a cyclic graph and we are attempting to
|
||||
partition cycles involving a gradient from the rest of the graph.
|
||||
"""
|
||||
depends_on_gradient: Set[int] = set()
|
||||
depends_on_gradient: set[int] = set()
|
||||
while True:
|
||||
start_size = len(depends_on_gradient)
|
||||
for node in self._data_flow_graph.flow_nodes:
|
||||
@ -837,7 +838,7 @@ class MemoryProfile:
|
||||
|
||||
# We only want to annotate Tensors which actually contribute to the
|
||||
# model calculation.
|
||||
produces_gradient: Set[TensorAndID] = set()
|
||||
produces_gradient: set[TensorAndID] = set()
|
||||
for node in reversed(self._data_flow_graph.flow_nodes):
|
||||
tensors = {(key, version) for key, (_, version) in node.inputs.items()}
|
||||
tensors |= node.outputs.items()
|
||||
@ -894,8 +895,8 @@ class MemoryProfile:
|
||||
# data flow. Note this these are only candidates; we filter nodes that
|
||||
# we know are part of the backward pass but that doesn't guarantee that
|
||||
# they are part of the forward pass.
|
||||
candidate_parameters: Set[TensorAndID] = set()
|
||||
candidate_fwd_tensors: Set[TensorAndID] = {
|
||||
candidate_parameters: set[TensorAndID] = set()
|
||||
candidate_fwd_tensors: set[TensorAndID] = {
|
||||
i for i, category in snapshot.items() if category == Category.INPUT
|
||||
}
|
||||
|
||||
@ -914,7 +915,7 @@ class MemoryProfile:
|
||||
candidate_parameters |= inputs.difference(candidate_fwd_tensors)
|
||||
|
||||
# Require that each parameter eventually contributes to the value of a gradient
|
||||
used_for_gradient: Set[TensorAndID] = set()
|
||||
used_for_gradient: set[TensorAndID] = set()
|
||||
for node in reversed(self._data_flow_graph.flow_nodes):
|
||||
if any(
|
||||
self._is_gradient(*i) or i in used_for_gradient
|
||||
@ -993,8 +994,8 @@ class MemoryProfileTimeline:
|
||||
Output: [timestamps, sizes by category]
|
||||
"""
|
||||
device = torch.device(device_str)
|
||||
times: List[int] = []
|
||||
sizes: List[List[int]] = []
|
||||
times: list[int] = []
|
||||
sizes: list[list[int]] = []
|
||||
|
||||
def update(key, version, delta):
|
||||
category = (
|
||||
@ -1061,7 +1062,7 @@ class MemoryProfileTimeline:
|
||||
as a JSON formatted file to the given path for the given
|
||||
device."""
|
||||
device = torch.device(device_str)
|
||||
raw_events: List[tuple[int, int, int, int]] = []
|
||||
raw_events: list[tuple[int, int, int, int]] = []
|
||||
|
||||
def get_category_index(key, version):
|
||||
category = (
|
||||
|
@ -3,7 +3,7 @@ import json
import math
import os
import re
from typing import Dict, List, Optional, Set
from typing import Optional

import torch
import torch.utils.benchmark as benchmark
|
||||
@ -34,7 +34,7 @@ class Pattern:
|
||||
self.url = ""
|
||||
assert prof.profiler is not None and prof.profiler.kineto_results is not None
|
||||
self.event_tree = prof.profiler.kineto_results.experimental_event_tree()
|
||||
self.tid_root: Dict[int, List[_ProfilerEvent]] = {}
|
||||
self.tid_root: dict[int, list[_ProfilerEvent]] = {}
|
||||
for event in self.event_tree:
|
||||
self.tid_root.setdefault(event.start_tid, []).append(event)
|
||||
|
||||
@ -55,7 +55,7 @@ class Pattern:
|
||||
"""
|
||||
yield from traverse_dfs(self.event_tree)
|
||||
|
||||
def summary(self, events: List[_ProfilerEvent]):
|
||||
def summary(self, events: list[_ProfilerEvent]):
|
||||
default_summary = f"{self.name}: {len(events)} events matched."
|
||||
if self.should_benchmark:
|
||||
# If benchmark summary is not empty, use it.
|
||||
@ -66,7 +66,7 @@ class Pattern:
|
||||
)
|
||||
return default_summary
|
||||
|
||||
def benchmark_summary(self, events: List[_ProfilerEvent]):
|
||||
def benchmark_summary(self, events: list[_ProfilerEvent]):
|
||||
def format_time(time_ns: int):
|
||||
unit_lst = ["ns", "us", "ms"]
|
||||
for unit in unit_lst:
|
||||
@ -215,7 +215,7 @@ class ExtraCUDACopyPattern(Pattern):
|
||||
return event.name in self.init_ops
|
||||
# TODO: Check if tensor is reused
|
||||
|
||||
def benchmark(self, events: List[_ProfilerEvent]):
|
||||
def benchmark(self, events: list[_ProfilerEvent]):
|
||||
shapes_factor_map = {input_shapes(event): 0.0 for event in events}
|
||||
for shape in shapes_factor_map:
|
||||
size = shape[0]
|
||||
@ -252,7 +252,7 @@ class ForLoopIndexingPattern(Pattern):
|
||||
super().__init__(prof, should_benchmark)
|
||||
self.name = "For Loop Indexing Pattern"
|
||||
self.description = "For loop indexing detected. Vectorization recommended."
|
||||
self.visited: Set[int] = set()
|
||||
self.visited: set[int] = set()
|
||||
|
||||
def eventTreeTraversal(self):
|
||||
"""
|
||||
@ -326,7 +326,7 @@ class FP32MatMulPattern(Pattern):
|
||||
def report(self, event: _ProfilerEvent):
|
||||
return self.description
|
||||
|
||||
def benchmark(self, events: List[_ProfilerEvent]):
|
||||
def benchmark(self, events: list[_ProfilerEvent]):
|
||||
shapes_factor_map = {input_shapes(event): 0.0 for event in events}
|
||||
for shape in shapes_factor_map:
|
||||
matrixA = torch.randn(shape[0], device="cuda", dtype=torch.float32)
|
||||
@ -553,7 +553,7 @@ class MatMulDimInFP16Pattern(Pattern):
|
||||
return True
|
||||
return False
|
||||
|
||||
def benchmark(self, events: List[_ProfilerEvent]):
|
||||
def benchmark(self, events: list[_ProfilerEvent]):
|
||||
def closest_multiple(shapes, multiple):
|
||||
return [multiple * math.ceil(shape / multiple) for shape in shapes]
|
||||
|
||||
@ -609,7 +609,7 @@ def report_all_anti_patterns(
|
||||
print_enable: bool = True,
|
||||
json_report_dir: Optional[str] = None,
|
||||
):
|
||||
report_dict: Dict = {}
|
||||
report_dict: dict = {}
|
||||
anti_patterns = [
|
||||
ExtraCUDACopyPattern(prof, should_benchmark),
|
||||
# ForLoopIndexingPattern(prof, should_benchmark),
|
||||
|
@ -4,7 +4,7 @@ import operator
import re
from collections import deque
from dataclasses import dataclass
from typing import Dict, List, TYPE_CHECKING
from typing import TYPE_CHECKING

from torch.autograd.profiler import profile
from torch.profiler import DeviceType
|
||||
@ -64,7 +64,7 @@ class EventKey:
|
||||
def __repr__(self):
|
||||
return f"{self.event.name}"
|
||||
|
||||
def intervals_overlap(self, intervals: List[Interval]):
|
||||
def intervals_overlap(self, intervals: list[Interval]):
|
||||
overlap_time = 0
|
||||
intervals = sorted(intervals, key=lambda x: x.start)
|
||||
|
||||
@ -100,13 +100,13 @@ class EventKey:
|
||||
class BasicEvaluation:
|
||||
def __init__(self, prof: profile):
|
||||
self.profile = prof
|
||||
self.metrics: Dict[EventKey, EventMetrics] = {}
|
||||
self.metrics: dict[EventKey, EventMetrics] = {}
|
||||
self.compute_self_time()
|
||||
self.event_keys = sorted(
|
||||
(e for e in self.metrics.keys()), key=lambda x: x.event.start_time_ns
|
||||
)
|
||||
self.events = [e.event for e in self.event_keys]
|
||||
self.cuda_events: List[_KinetoEvent] = []
|
||||
self.cuda_events: list[_KinetoEvent] = []
|
||||
self.queue_depth_list = self.compute_queue_depth()
|
||||
self.compute_idle_time()
|
||||
|
||||
@ -162,7 +162,7 @@ class BasicEvaluation:
|
||||
cuda_launch_events + cuda_kernel_events, key=lambda x: x.start_ns()
|
||||
)
|
||||
|
||||
kernel_mapping: Dict[_KinetoEvent, int] = {}
|
||||
kernel_mapping: dict[_KinetoEvent, int] = {}
|
||||
last_mapped_kernel = 0
|
||||
for cuda_launch_event in cuda_launch_events:
|
||||
index = index_of_first_match(
|
||||
@ -188,7 +188,7 @@ class BasicEvaluation:
|
||||
return event.start_time_ns
|
||||
raise Exception("Unknown Event Type") # noqa: TRY002
|
||||
|
||||
queue_depth_list: List[Interval] = []
|
||||
queue_depth_list: list[Interval] = []
|
||||
all_events.sort(key=new_old_event_comparator)
|
||||
for event in all_events:
|
||||
# Find latest cuda kernel event
|
||||
@ -233,7 +233,7 @@ class BasicEvaluation:
|
||||
# Based on queue_depth_list, we can calculate idle time for all the events
|
||||
idle = False
|
||||
idle_start = 0
|
||||
idle_intervals: List[Interval] = []
|
||||
idle_intervals: list[Interval] = []
|
||||
if self.queue_depth_list and self.events:
|
||||
idle_intervals += [
|
||||
Interval(self.events[0].start_time_ns, self.queue_depth_list[0].start),
|
||||
|
@ -5,9 +5,10 @@ import os
import shutil
import tempfile
from abc import ABC, abstractmethod
from collections.abc import Iterable
from enum import Enum
from functools import partial
from typing import Any, Callable, Dict, Iterable, List, Optional
from typing import Any, Callable, Optional
from typing_extensions import Self
from warnings import warn
|
||||
|
||||
@ -167,7 +168,7 @@ class _KinetoProfile:
|
||||
self.use_device = _get_privateuse1_backend_name()
|
||||
|
||||
# user-defined metadata to be amended to the trace
|
||||
self.preset_metadata: Dict[str, str] = {}
|
||||
self.preset_metadata: dict[str, str] = {}
|
||||
|
||||
def start(self):
|
||||
self.prepare_trace()
|
||||
@ -723,8 +724,8 @@ class profile(_KinetoProfile):
|
||||
self.current_action = self.schedule(self.step_num)
|
||||
self.step_rec_fn: Optional[prof.record_function] = None
|
||||
|
||||
self.action_map: Dict[
|
||||
tuple[ProfilerAction, Optional[ProfilerAction]], List[Any]
|
||||
self.action_map: dict[
|
||||
tuple[ProfilerAction, Optional[ProfilerAction]], list[Any]
|
||||
] = {
|
||||
# key is (prev_action, current_action), value is action list corresponding to the state pair.
|
||||
(ProfilerAction.NONE, ProfilerAction.NONE): [],
|
||||
|
@ -1,12 +1,11 @@
import os
import site
import sys
import typing

import torch


def _prefix_regex() -> typing.List[str]:
def _prefix_regex() -> list[str]:
raw_paths = (
site.getsitepackages()
+ sys.path
|
||||
|
@ -15,19 +15,7 @@ import threading
import warnings
from contextlib import closing, contextmanager
from enum import Enum
from typing import (
Any,
BinaryIO,
Callable,
cast,
Dict,
IO,
List,
Optional,
Tuple,
Type,
Union,
)
from typing import Any, BinaryIO, Callable, cast, IO, Optional, Union
from typing_extensions import TypeAlias, TypeIs

import torch
|
||||
@ -79,7 +67,7 @@ STORAGE_KEY_SEPARATOR = ","
|
||||
|
||||
FILE_LIKE: TypeAlias = Union[str, os.PathLike, BinaryIO, IO[bytes]]
|
||||
MAP_LOCATION: TypeAlias = Optional[
|
||||
Union[Callable[[Storage, str], Storage], torch.device, str, Dict[str, str]]
|
||||
Union[Callable[[Storage, str], Storage], torch.device, str, dict[str, str]]
|
||||
]
|
||||
STORAGE: TypeAlias = Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage]
|
||||
|
||||
@ -132,8 +120,8 @@ def mkdtemp():
|
||||
shutil.rmtree(path)
|
||||
|
||||
|
||||
_package_registry: List[
|
||||
Tuple[
|
||||
_package_registry: list[
|
||||
tuple[
|
||||
int,
|
||||
Callable[[STORAGE], Optional[str]],
|
||||
Callable[[STORAGE, str], Optional[STORAGE]],
|
||||
@ -270,14 +258,14 @@ def clear_safe_globals() -> None:
|
||||
_weights_only_unpickler._clear_safe_globals()
|
||||
|
||||
|
||||
def get_safe_globals() -> List[Union[Callable, Tuple[Callable, str]]]:
|
||||
def get_safe_globals() -> list[Union[Callable, tuple[Callable, str]]]:
|
||||
"""
|
||||
Returns the list of user-added globals that are safe for ``weights_only`` load.
|
||||
"""
|
||||
return _weights_only_unpickler._get_safe_globals()
|
||||
|
||||
|
||||
def add_safe_globals(safe_globals: List[Union[Callable, Tuple[Callable, str]]]) -> None:
|
||||
def add_safe_globals(safe_globals: list[Union[Callable, tuple[Callable, str]]]) -> None:
|
||||
"""
|
||||
Marks the given globals as safe for ``weights_only`` load. For example, functions
|
||||
added to this list can be called during unpickling, classes could be instantiated
|
||||
@ -338,7 +326,7 @@ class safe_globals(_weights_only_unpickler._safe_globals):
|
||||
"""
|
||||
|
||||
|
||||
def get_unsafe_globals_in_checkpoint(f: FILE_LIKE) -> List[str]:
|
||||
def get_unsafe_globals_in_checkpoint(f: FILE_LIKE) -> list[str]:
|
||||
"""Returns a list of strings of functions/classes in a ``torch.save`` object that are not safe for ``weights_only``.
|
||||
|
||||
For a given function or class ``f``, the corresponding string will be of the form
|
||||
@ -804,7 +792,7 @@ class _open_zipfile_writer_buffer(_opener):
|
||||
|
||||
|
||||
def _open_zipfile_writer(name_or_buffer):
|
||||
container: Type[_opener]
|
||||
container: type[_opener]
|
||||
if _is_path(name_or_buffer):
|
||||
container = _open_zipfile_writer_file
|
||||
else:
|
||||
@ -965,15 +953,15 @@ def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
|
||||
import torch.nn as nn
|
||||
|
||||
serialized_container_types = {}
|
||||
serialized_storages: Dict[str, Tuple[torch.UntypedStorage, torch.dtype]] = {}
|
||||
serialized_storages: dict[str, tuple[torch.UntypedStorage, torch.dtype]] = {}
|
||||
|
||||
# Since loading storages that view the same data with different dtypes is
|
||||
# not supported, we need to keep track of the dtype associated with each
|
||||
# storage data_ptr and throw an error if the dtype is ever different.
|
||||
# TODO: This feature could be added in the future
|
||||
storage_dtypes: Dict[int, torch.dtype] = {}
|
||||
storage_dtypes: dict[int, torch.dtype] = {}
|
||||
|
||||
def persistent_id(obj: Any) -> Optional[Tuple]:
|
||||
def persistent_id(obj: Any) -> Optional[tuple]:
|
||||
# FIXME: the docs say that persistent_id should only return a string
|
||||
# but torch store returns tuples. This works only in the binary protocol
|
||||
# see
|
||||
@ -1032,7 +1020,7 @@ def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
|
||||
else:
|
||||
storage_dtypes[storage.data_ptr()] = storage_dtype
|
||||
|
||||
view_metadata: Optional[Tuple[str, int, int]]
|
||||
view_metadata: Optional[tuple[str, int, int]]
|
||||
|
||||
# Offset is always 0, but we keep it for backwards compatibility
|
||||
# with the old serialization format (which supported storage views)
|
||||
@ -1127,13 +1115,13 @@ def _save(
|
||||
_disable_byteorder_record,
|
||||
):
|
||||
serialized_storages = {}
|
||||
id_map: Dict[int, str] = {}
|
||||
id_map: dict[int, str] = {}
|
||||
|
||||
# Since loading storages that view the same data with different dtypes is
|
||||
# not supported, we need to keep track of the dtype associated with each
|
||||
# storage data_ptr and throw an error if the dtype is ever different.
|
||||
# TODO: This feature could be added in the future
|
||||
storage_dtypes: Dict[int, torch.dtype] = {}
|
||||
storage_dtypes: dict[int, torch.dtype] = {}
|
||||
|
||||
def persistent_id(obj):
|
||||
# FIXME: the docs say that persistent_id should only return a string
|
||||
@ -1534,7 +1522,7 @@ copyreg.pickle(torch.layout, lambda obj: (_get_layout, (str(obj),)))
|
||||
|
||||
|
||||
def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
|
||||
deserialized_objects: Dict[int, Any] = {}
|
||||
deserialized_objects: dict[int, Any] = {}
|
||||
|
||||
restore_location = _get_restore_location(map_location)
|
||||
|
||||
@ -1599,7 +1587,7 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
|
||||
warnings.warn(msg, SourceChangeWarning)
|
||||
|
||||
def legacy_load(f):
|
||||
deserialized_objects: Dict[int, Any] = {}
|
||||
deserialized_objects: dict[int, Any] = {}
|
||||
|
||||
def persistent_load(saved_id):
|
||||
if isinstance(saved_id, tuple):
|
||||
@ -1950,7 +1938,7 @@ def _load(
|
||||
|
||||
return typed_storage
|
||||
|
||||
load_module_mapping: Dict[str, str] = {
|
||||
load_module_mapping: dict[str, str] = {
|
||||
# See https://github.com/pytorch/pytorch/pull/51633
|
||||
"torch.tensor": "torch._tensor"
|
||||
}
|
||||
|
@ -19,11 +19,11 @@ from .semi_structured import (
if TYPE_CHECKING:
from torch.types import _dtype as DType

DimOrDims = Optional[Union[int, Tuple[int, ...], List[int]]]
DimOrDims = Optional[Union[int, tuple[int, ...], list[int]]]
else:
# The JIT doesn't understand Union, nor torch.dtype here
DType = int
DimOrDims = Optional[Tuple[int]]
DimOrDims = Optional[tuple[int]]


__all__ = [
|
||||
@ -591,7 +591,7 @@ def as_sparse_gradcheck(gradcheck):
|
||||
"""Convert differentiable non-strided tensors to a representation containing differentiable strided tensors."""
|
||||
if not isinstance(args, (list, tuple)):
|
||||
args = (args,)
|
||||
new_args: List[Any] = []
|
||||
new_args: list[Any] = []
|
||||
for obj in args:
|
||||
if (
|
||||
isinstance(obj, torch.Tensor)
|
||||
|
@ -4,7 +4,7 @@ import math
import os
import weakref
from functools import lru_cache
from typing import Optional, Tuple
from typing import Optional

import torch
from torch._dynamo.utils import warn_once
|
||||
@ -1124,7 +1124,7 @@ def _int_bsr_dense_addmm(
|
||||
right_alpha: Optional[torch.Tensor] = None,
|
||||
out: Optional[torch.Tensor] = None,
|
||||
skip_checks: bool = False,
|
||||
max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None,
|
||||
max_grid: Optional[tuple[Optional[int], Optional[int], Optional[int]]] = None,
|
||||
meta: Optional[dict] = None,
|
||||
):
|
||||
if out is None and dense.dtype is torch.int8:
|
||||
@ -1165,7 +1165,7 @@ def bsr_dense_addmm(
|
||||
right_alpha: Optional[torch.Tensor] = None,
|
||||
out: Optional[torch.Tensor] = None,
|
||||
skip_checks: bool = False,
|
||||
max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None,
|
||||
max_grid: Optional[tuple[Optional[int], Optional[int], Optional[int]]] = None,
|
||||
meta: Optional[dict] = None,
|
||||
):
|
||||
"""Compute
|
||||
@ -1647,7 +1647,7 @@ if has_triton():
|
||||
alpha=1.0,
|
||||
out: Optional[torch.Tensor] = None,
|
||||
skip_checks: bool = False,
|
||||
max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None,
|
||||
max_grid: Optional[tuple[Optional[int], Optional[int], Optional[int]]] = None,
|
||||
):
|
||||
f_name = "sampled_addmm"
|
||||
|
||||
@ -1731,7 +1731,7 @@ if has_triton():
|
||||
*,
|
||||
out: Optional[torch.Tensor] = None,
|
||||
skip_checks: bool = False,
|
||||
max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None,
|
||||
max_grid: Optional[tuple[Optional[int], Optional[int], Optional[int]]] = None,
|
||||
meta: Optional[dict] = None,
|
||||
):
|
||||
f_name = "bsr_dense_mm"
|
||||
|
@ -103,7 +103,7 @@ import inspect
import itertools
import re
import warnings
from typing import Any, Dict
from typing import Any

import torch
from torch.hub import tqdm
|
||||
@ -937,7 +937,7 @@ def main(op="scatter_mm", force=False, dtype=torch.float16, verbose=True):
|
||||
dump()
|
||||
|
||||
|
||||
_operation_device_version_data: Dict[Any, Dict] = {
|
||||
_operation_device_version_data: dict[Any, dict] = {
|
||||
# Warning: the data in between the BEGIN/END DATA comment lines
|
||||
# below is generated. It can be updated either manually or via
|
||||
# calling dump function defined above.
|
||||
|
@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
import warnings
from collections import namedtuple
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Optional

import torch
from torch.sparse._semi_structured_conversions import (
|
||||
@ -54,13 +54,13 @@ class SparseSemiStructuredTensor(torch.Tensor):
|
||||
"""
|
||||
|
||||
_DEFAULT_ALG_ID: int = 0
|
||||
_DTYPE_SHAPE_CONSTRAINTS: Dict[torch.dtype, _SEMI_STRUCTURED_SPARSE_CONFIG]
|
||||
_DTYPE_SHAPE_CONSTRAINTS: dict[torch.dtype, _SEMI_STRUCTURED_SPARSE_CONFIG]
|
||||
_FORCE_CUTLASS: bool = False
|
||||
_FUSE_TRANSPOSE: bool = False
|
||||
_PROTOTYPE_WARNING_SHOWN: bool = False
|
||||
|
||||
BACKEND: str
|
||||
SPARSE_DISPATCH: Dict[Callable, Callable]
|
||||
SPARSE_DISPATCH: dict[Callable, Callable]
|
||||
|
||||
packed: Optional[torch.Tensor]
|
||||
meta: Optional[torch.Tensor]
|
||||
@ -161,7 +161,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
|
||||
|
||||
def __tensor_flatten__(
|
||||
self,
|
||||
) -> Tuple[List[str], Tuple[torch.Size, bool, int, bool]]:
|
||||
) -> tuple[list[str], tuple[torch.Size, bool, int, bool]]:
|
||||
inner_tensors = list(
|
||||
filter(lambda x: getattr(self, x) is not None, self.__slots__)
|
||||
)
|
||||
@ -177,7 +177,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
|
||||
def __tensor_unflatten__(
|
||||
cls,
|
||||
inner_tensors,
|
||||
tensor_meta: Tuple[torch.Size, bool, int, bool],
|
||||
tensor_meta: tuple[torch.Size, bool, int, bool],
|
||||
outer_size,
|
||||
outer_stride,
|
||||
) -> torch.Tensor:
|
||||
|
@ -23,13 +23,13 @@ from .streams import Event, Stream
_initialized = False
_tls = threading.local()
_initialization_lock = threading.Lock()
_queued_calls: List[
Tuple[Callable[[], None], List[str]]
_queued_calls: list[
tuple[Callable[[], None], list[str]]
] = [] # don't invoke these until initialization occurs
_is_in_bad_fork = getattr(torch._C, "_xpu_isInBadFork", lambda: False)
_device_t = Union[_device, str, int, None]
_lazy_seed_tracker = _LazySeedTracker()
default_generators: Tuple[torch._C.Generator] = () # type: ignore[assignment]
default_generators: tuple[torch._C.Generator] = () # type: ignore[assignment]
|
||||
|
||||
|
||||
def _is_compiled() -> bool:
|
||||
@ -216,7 +216,7 @@ def get_device_name(device: Optional[_device_t] = None) -> str:
|
||||
|
||||
|
||||
@lru_cache(None)
|
||||
def get_device_capability(device: Optional[_device_t] = None) -> Dict[str, Any]:
|
||||
def get_device_capability(device: Optional[_device_t] = None) -> dict[str, Any]:
|
||||
r"""Get the xpu capability of a device.
|
||||
|
||||
Args:
|
||||
@ -418,7 +418,7 @@ def synchronize(device: _device_t = None) -> None:
|
||||
return torch._C._xpu_synchronize(device)
|
||||
|
||||
|
||||
def get_arch_list() -> List[str]:
|
||||
def get_arch_list() -> list[str]:
|
||||
r"""Return list XPU architectures this library was compiled for."""
|
||||
if not is_available():
|
||||
return []
|
||||
|
@ -1,5 +1,5 @@
import collections
from typing import Any, Dict, Tuple, Union
from typing import Any, Union

import torch
from torch.types import Device
|
||||
@ -53,7 +53,7 @@ def reset_accumulated_memory_stats(device: _device_t = None) -> None:
|
||||
return torch._C._xpu_resetAccumulatedMemoryStats(device)
|
||||
|
||||
|
||||
def memory_stats_as_nested_dict(device: _device_t = None) -> Dict[str, Any]:
|
||||
def memory_stats_as_nested_dict(device: _device_t = None) -> dict[str, Any]:
|
||||
r"""Return the result of :func:`~torch.xpu.memory_stats` as a nested dictionary."""
|
||||
if not is_initialized():
|
||||
return {}
|
||||
@ -61,7 +61,7 @@ def memory_stats_as_nested_dict(device: _device_t = None) -> Dict[str, Any]:
|
||||
return torch._C._xpu_memoryStats(device)
|
||||
|
||||
|
||||
def memory_stats(device: _device_t = None) -> Dict[str, Any]:
|
||||
def memory_stats(device: _device_t = None) -> dict[str, Any]:
|
||||
r"""Return a dictionary of XPU memory allocator statistics for a given device.
|
||||
|
||||
The return value of this function is a dictionary of statistics, each of
|
||||
@ -178,7 +178,7 @@ def max_memory_reserved(device: _device_t = None) -> int:
|
||||
return memory_stats(device=device).get("reserved_bytes.all.peak", 0)
|
||||
|
||||
|
||||
def mem_get_info(device: _device_t = None) -> Tuple[int, int]:
|
||||
def mem_get_info(device: _device_t = None) -> tuple[int, int]:
|
||||
r"""Return the global free and total GPU memory for a given device.
|
||||
|
||||
Args:
|
||||
|
@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
from typing import Iterable, List, Union
from collections.abc import Iterable
from typing import Union

import torch
from torch import Tensor
|
||||
@ -29,7 +30,7 @@ def get_rng_state(device: Union[int, str, torch.device] = "xpu") -> Tensor:
return default_generator.get_state()


def get_rng_state_all() -> List[Tensor]:
def get_rng_state_all() -> list[Tensor]:
r"""Return a list of ByteTensor representing the random number states of all devices."""
results = [get_rng_state(i) for i in range(device_count())]
return results
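The hunks above all apply one mechanical rewrite: the capitalized aliases (List, Dict, Tuple, Set, Type, DefaultDict) are dropped from typing in favor of the subscriptable builtins (list, dict, tuple, set, type, collections.defaultdict), and ABCs such as Iterable, Iterator, and Sequence move to collections.abc; builtin subscripting is available on Python 3.9+ (PEP 585). A minimal before/after sketch of the pattern, using hypothetical names that do not come from any of the files above:

# Hypothetical example only; it illustrates the rewrite, not any module touched above.
from collections import defaultdict
from collections.abc import Iterable
from typing import Optional, Union  # Optional and Union still come from typing


# The old spelling would have been:
#   def group_sizes(items: List[Tuple[str, int]]) -> Dict[str, List[int]]: ...
def group_sizes(items: Iterable[tuple[str, int]]) -> dict[str, list[int]]:
    """Collect the int component of each (key, value) pair under its key."""
    out: defaultdict[str, list[int]] = defaultdict(list)
    for key, value in items:
        out[key].append(value)
    return dict(out)


# Builtin generics also nest inside Optional/Union and in variable annotations.
sizes: Optional[Union[tuple[int, ...], list[int]]] = None

print(group_sizes([("a", 1), ("b", 2), ("a", 3)]))  # {'a': [1, 3], 'b': [2]}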