Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
PEP585: Missed conversions (#145342)
Differential Revision: [D68785969](https://our.internmc.facebook.com/intern/diff/D68785969) Pull Request resolved: https://github.com/pytorch/pytorch/pull/145342 Approved by: https://github.com/bobrenjc93
Committed by: PyTorch MergeBot
Parent: 8696e59ae2
Commit: 7178b827d7
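This PR continues the PEP 585 migration: from Python 3.9 on, the builtin containers are generic themselves, so the deprecated typing aliases (List, Dict, Tuple, Set, Type) can be written as list, dict, tuple, set and type directly in annotations. Where the old spelling has to stay for now (re-exports consumed by torch.jit.annotations, an alias that XLA still patches), the diff keeps it and adds a # noqa: UP006 / # noqa: UP035 lint suppression instead of converting it. A minimal before/after sketch of the conversion pattern (illustrative only, not code from this diff):

# Before: deprecated aliases imported from typing
from typing import Dict, List, Tuple

def bucketize_old(xs: List[int]) -> Dict[str, Tuple[int, int]]:
    ...

# After (PEP 585, Python 3.9+): the builtins are subscriptable directly
def bucketize_new(xs: list[int]) -> dict[str, tuple[int, int]]:
    ...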
@@ -20,7 +20,7 @@ import types
 import typing
 import warnings
 import weakref
-from typing import ( # noqa: F401 # (Dict, List, Tuple) imported by torch.jit.annotations
+from typing import ( # noqa: UP035, F401 # (Dict, List, Tuple) imported by torch.jit.annotations
     Any,
     Callable,
     Dict,

@@ -1125,7 +1125,8 @@ def _get_overloaded_methods(method, mod_class):


 def is_tuple(ann) -> bool:
-    if ann is Tuple:
+    # Check for typing.Tuple missing args (but `tuple` is fine)
+    if ann is typing.Tuple: # noqa: UP006
         raise_error_container_parameter_missing("Tuple")

     # For some reason Python 3.7 violates the Type[A, B].__origin__ == Type rule

@@ -1133,35 +1134,31 @@ def is_tuple(ann) -> bool:
         return False

     ann_origin = get_origin(ann)
-    if IS_PY39_PLUS and ann.__module__ == "builtins" and ann_origin is tuple:
-        return True
-    return ann.__module__ == "typing" and (ann_origin is Tuple or ann_origin is tuple)
+    return ann.__module__ in ("builtins", "typing") and ann_origin is tuple


 def is_list(ann) -> bool:
-    if ann is List:
+    # Check for typing.List missing args (but `list` is fine)
+    if ann is typing.List: # noqa: UP006
         raise_error_container_parameter_missing("List")

     if not hasattr(ann, "__module__"):
         return False

     ann_origin = get_origin(ann)
-    if IS_PY39_PLUS and ann.__module__ == "builtins" and ann_origin is list:
-        return True
-    return ann.__module__ == "typing" and (ann_origin is List or ann_origin is list)
+    return ann.__module__ in ("builtins", "typing") and ann_origin is list


 def is_dict(ann) -> bool:
-    if ann is Dict:
+    # Check for typing.Dict missing args (but `dict` is fine)
+    if ann is typing.Dict: # noqa: UP006
         raise_error_container_parameter_missing("Dict")

     if not hasattr(ann, "__module__"):
         return False

     ann_origin = get_origin(ann)
-    if IS_PY39_PLUS and ann.__module__ == "builtins" and ann_origin is dict:
-        return True
-    return ann.__module__ == "typing" and (ann_origin is Dict or ann_origin is dict)
+    return ann.__module__ in ("builtins", "typing") and ann_origin is dict


 def is_union(ann):

@@ -1371,11 +1368,11 @@ def raise_error_container_parameter_missing(target_type) -> None:


 def check_args_exist(target_type) -> None:
-    if target_type is List or target_type is list:
+    if target_type is typing.List or target_type is list: # noqa: UP006
         raise_error_container_parameter_missing("List")
-    elif target_type is Tuple or target_type is tuple:
+    elif target_type is typing.Tuple or target_type is tuple: # noqa: UP006
         raise_error_container_parameter_missing("Tuple")
-    elif target_type is Dict or target_type is dict:
+    elif target_type is typing.Dict or target_type is dict: # noqa: UP006
         raise_error_container_parameter_missing("Dict")
     elif target_type is None or target_type is Optional:
         raise_error_container_parameter_missing("Optional")

@@ -1399,7 +1396,7 @@ def container_checker(obj, target_type) -> bool:
     check_args_exist(target_type)
     if origin_type is None:
         return False
-    elif origin_type is list or origin_type is List:
+    elif origin_type is list or origin_type is typing.List: # noqa: UP006
         check_empty_containers(obj)
         if not isinstance(obj, list):
             return False

@@ -1413,7 +1410,7 @@ def container_checker(obj, target_type) -> bool:
             elif not isinstance(el, arg_type):
                 return False
         return True
-    elif origin_type is Dict or origin_type is dict:
+    elif origin_type is typing.Dict or origin_type is dict: # noqa: UP006
         check_empty_containers(obj)
         if not isinstance(obj, dict):
             return False

@@ -1430,7 +1427,7 @@ def container_checker(obj, target_type) -> bool:
             elif not isinstance(val, val_type):
                 return False
         return True
-    elif origin_type is Tuple or origin_type is tuple:
+    elif origin_type is typing.Tuple or origin_type is tuple: # noqa: UP006
         check_empty_containers(obj)
         if not isinstance(obj, tuple):
             return False
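The is_tuple / is_list / is_dict rewrite above collapses the old two-step check (an IS_PY39_PLUS special case for builtin generics plus a typing-module check) into a single membership test. This works because typing.get_origin returns the builtin container for both spellings, while __module__ is "builtins" for list[int]-style aliases and "typing" for typing.List[int]-style ones. A standalone sketch of the same idea (my own helper name, not the torch function):

import typing
from typing import get_origin


def _looks_like_list(ann) -> bool:
    # list[int] reports __module__ == "builtins", typing.List[int] reports "typing";
    # both report list as their origin, so one test covers both spellings.
    return getattr(ann, "__module__", None) in ("builtins", "typing") and get_origin(ann) is list


assert _looks_like_list(list[int])
assert _looks_like_list(typing.List[int])
assert not _looks_like_list(dict[str, int])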
@@ -209,8 +209,8 @@ def derived_types(

 def derived_seq_types(typ: Union[type, typing._SpecialForm]):
     return (
-        typing.Sequence[typ], # type: ignore[valid-type]
-        typing.List[typ], # type: ignore[valid-type]
+        typing.Sequence[typ], # type: ignore[valid-type] # noqa: UP006
+        typing.List[typ], # type: ignore[valid-type] # noqa: UP006
         GenericAlias(collections.abc.Sequence, (typ,)),
         GenericAlias(list, (typ,)),
     )

@@ -252,7 +252,7 @@ def get_supported_param_types():

 SUPPORTED_RETURN_TYPES = {
     Tensor: "Tensor",
-    typing.List[Tensor]: "Tensor[]",
+    typing.List[Tensor]: "Tensor[]", # noqa: UP006
     list[Tensor]: "Tensor[]",
     int: "SymInt",
     float: "float",

@@ -306,7 +306,7 @@ def tuple_to_list(tuple_type: type[tuple]) -> type[list]:
     # Account for different python versions, e.g. python 3.8 would give ()
     # but python 3.12 would give None.
     if (
-        tuple_type is typing.Tuple
+        tuple_type is typing.Tuple # noqa: UP006
         or tuple_type is tuple
         or type_args == ()
        or type_args is None
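Note that SUPPORTED_RETURN_TYPES above ends up with both a typing.List[Tensor] key (kept with a noqa) and a list[Tensor] key. A plausible reason, stated here as an assumption rather than something the diff spells out, is that the two alias objects are distinct dictionary keys: a typing alias and a builtin generic alias with the same origin and arguments do not compare equal, so a lookup written with one spelling will not find the other.

import typing

# Same origin and arguments, but different alias objects:
assert typing.get_origin(typing.List[int]) is list and typing.get_origin(list[int]) is list
assert typing.List[int] != list[int]  # so a dict keyed on one spelling misses the other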
@@ -2,7 +2,7 @@
 import contextlib
 import sys
 import warnings
-from typing import Any, cast, List, Optional, TYPE_CHECKING, Union
+from typing import Any, cast, Optional, TYPE_CHECKING, Union

 import torch
 import torch.distributed as dist

@@ -787,7 +787,7 @@ def _resolve_group_name(group: RANK_TYPES, tag: str = "") -> str:
                 FutureWarning,
                 stacklevel=3,
             )
-        return c10d._resolve_group_name_by_ranks_and_tag(cast(List[int], group), tag)
+        return c10d._resolve_group_name_by_ranks_and_tag(cast(list[int], group), tag)
     else:
         raise ValueError(f"Unsupported group type: {type(group)}, {group}")

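The cast(...) change in the collectives hunk above is behaviour-preserving: typing.cast returns its second argument unchanged at runtime, so only the spelling seen by the type checker moves from List[int] to list[int]. A toy illustration (unrelated values, just to show that cast performs no conversion):

from typing import cast

ranks = cast(list[int], (0, 1, 2))  # cast is a no-op at runtime
assert isinstance(ranks, tuple) and ranks == (0, 1, 2)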
@@ -3,7 +3,7 @@
 import abc
 import io
 from collections.abc import Sequence
-from typing import cast, IO, Optional, Type
+from typing import cast, IO, Optional

 # introduced as collections.abc.Buffer in Python 3.12
 from typing_extensions import Buffer

@@ -189,11 +189,11 @@ class ZStandard(StreamTransformExtension):
 class ExtensionRegistry:
     def __init__(self) -> None:
         # Populate default registry contents
-        self.extensions: dict[str, Type[Extension]] = {
+        self.extensions: dict[str, type[Extension]] = {
             cls.registry_name(): cls for cls in (ZStandard,)
         }

-    def register(self, cls: Type[Extension]) -> None:
+    def register(self, cls: type[Extension]) -> None:
         self.extensions[cls.registry_name()] = cls

     def from_descriptor_list(self, descriptors: Sequence[str]) -> Sequence[Extension]:
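The registry hunk above swaps typing.Type[Extension] for the PEP 585 form type[Extension], i.e. "a class object that is Extension or a subclass of it". A minimal sketch of that registry pattern with hypothetical class names (not the torch.distributed.checkpoint extension classes):

class Codec:  # hypothetical stand-in for the Extension base class
    @classmethod
    def registry_name(cls) -> str:
        return cls.__name__.lower()


class Zstd(Codec):
    pass


registry: dict[str, type[Codec]] = {cls.registry_name(): cls for cls in (Zstd,)}


def register(cls: type[Codec]) -> None:
    registry[cls.registry_name()] = cls


register(Zstd)
assert registry["zstd"] is Zstd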
@@ -1,6 +1,6 @@
 # mypy: allow-untyped-defs
 import math
-from typing import List, Optional, Union
+from typing import Optional, Union
 from typing_extensions import deprecated

 import torch

@@ -788,8 +788,8 @@ class _ConvTransposeNd(_ConvNd):
                     f"or {num_non_spatial_dims + num_spatial_dims} elements (got {len(output_size)})"
                 )

-            min_sizes = torch.jit.annotate(List[int], [])
-            max_sizes = torch.jit.annotate(List[int], [])
+            min_sizes = torch.jit.annotate(list[int], [])
+            max_sizes = torch.jit.annotate(list[int], [])
             for d in range(num_spatial_dims):
                 dim_size = (
                     (input.size(d + num_non_spatial_dims) - 1) * stride[d]

@@ -811,7 +811,7 @@ class _ConvTransposeNd(_ConvNd):
                     f"from {min_sizes} to {max_sizes} (for an input of {input.size()[2:]})"
                 )

-            res = torch.jit.annotate(List[int], [])
+            res = torch.jit.annotate(list[int], [])
             for d in range(num_spatial_dims):
                 res.append(output_size[d] - min_sizes[d])

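In the conv hunks above, torch.jit.annotate keeps its usual role: in eager mode it simply returns the value, and under TorchScript it tells the compiler the element type of an otherwise untyped empty list; the diff only changes the annotation spelling to the builtin generic. A quick eager-mode check of that behaviour:

import torch

sizes = torch.jit.annotate(list[int], [])  # eager mode: annotate just returns its value
assert sizes == [] and isinstance(sizes, list)
sizes.append(3)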
@@ -1,7 +1,8 @@
 # mypy: allow-untyped-decorators
 # mypy: allow-untyped-defs
 import functools
-from typing import cast, Dict, Iterable, List, Optional, Tuple, Union
+import typing
+from typing import cast, Optional, Union
 from typing_extensions import deprecated

 import torch

@@ -20,7 +21,10 @@ __all__ = [
 ]


-_tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]]
+_tensor_or_tensors = Union[
+    torch.Tensor,
+    typing.Iterable[torch.Tensor], # noqa: UP006 - needed until XLA's patch is updated
+]


 def _no_grad(func):

@@ -73,13 +77,13 @@ def _get_total_norm(
     if len(tensors) == 0:
         return torch.tensor(0.0)
     first_device = tensors[0].device
-    grouped_tensors: Dict[
-        Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]]
+    grouped_tensors: dict[
+        tuple[torch.device, torch.dtype], tuple[list[list[Tensor]], list[int]]
     ] = _group_tensors_by_device_and_dtype(
         [tensors] # type: ignore[list-item]
     ) # type: ignore[assignment]

-    norms: List[Tensor] = []
+    norms: list[Tensor] = []
     for (device, _), ([device_tensors], _) in grouped_tensors.items():
         if (foreach is None and _has_foreach_support(device_tensors, device)) or (
             foreach and _device_has_foreach_support(device)

@@ -146,8 +150,8 @@ def _clip_grads_with_norm_(
     max_norm = float(max_norm)
     if len(grads) == 0:
         return
-    grouped_grads: Dict[
-        Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]]
+    grouped_grads: dict[
+        tuple[torch.device, torch.dtype], tuple[list[list[Tensor]], list[int]]
     ] = _group_tensors_by_device_and_dtype(
         [grads]
     ) # type: ignore[assignment]

@@ -269,10 +273,10 @@ def clip_grad_value_(
         for (device, _), ([grads], _) in grouped_grads.items():
             if (
                 foreach is None
-                and _has_foreach_support(cast(List[Tensor], grads), device=device)
+                and _has_foreach_support(cast(list[Tensor], grads), device=device)
             ) or (foreach and _device_has_foreach_support(device)):
-                torch._foreach_clamp_min_(cast(List[Tensor], grads), -clip_value)
-                torch._foreach_clamp_max_(cast(List[Tensor], grads), clip_value)
+                torch._foreach_clamp_min_(cast(list[Tensor], grads), -clip_value)
+                torch._foreach_clamp_max_(cast(list[Tensor], grads), clip_value)
             elif foreach:
                 raise RuntimeError(
                     f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
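In the clip_grad hunks above, the _tensor_or_tensors alias deliberately keeps typing.Iterable (with a UP006 suppression) instead of moving to a builtin or collections.abc spelling; the inline comment ties this to XLA's patch. The likely reason, stated as an assumption, is that the two spellings are different objects at runtime, so external code that compares against or monkey-patches the typing alias would stop matching after a conversion:

import collections.abc
import typing

assert typing.Iterable is not collections.abc.Iterable  # distinct module-level objects
assert typing.Iterable[int] != collections.abc.Iterable[int]  # and distinct parameterized aliases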
@@ -1,5 +1,5 @@
 # mypy: allow-untyped-defs
-from typing import cast, List, Optional, Tuple, Union
+from typing import cast, Optional, Union

 import torch
 from torch import Tensor

@@ -30,7 +30,7 @@ class Adamax(Optimizer):
         self,
         params: ParamsT,
         lr: Union[float, Tensor] = 2e-3,
-        betas: Tuple[float, float] = (0.9, 0.999),
+        betas: tuple[float, float] = (0.9, 0.999),
         eps: float = 1e-8,
         weight_decay: float = 0,
         foreach: Optional[bool] = None,

@@ -134,11 +134,11 @@ class Adamax(Optimizer):
                 loss = closure()

         for group in self.param_groups:
-            params_with_grad: List[Tensor] = []
-            grads: List[Tensor] = []
-            exp_avgs: List[Tensor] = []
-            exp_infs: List[Tensor] = []
-            state_steps: List[Tensor] = []
+            params_with_grad: list[Tensor] = []
+            grads: list[Tensor] = []
+            exp_avgs: list[Tensor] = []
+            exp_infs: list[Tensor] = []
+            state_steps: list[Tensor] = []

             beta1, beta2 = group["betas"]
             eps = group["eps"]

@@ -223,11 +223,11 @@ Adamax.__doc__ = (


 def _single_tensor_adamax(
-    params: List[Tensor],
-    grads: List[Tensor],
-    exp_avgs: List[Tensor],
-    exp_infs: List[Tensor],
-    state_steps: List[Tensor],
+    params: list[Tensor],
+    grads: list[Tensor],
+    exp_avgs: list[Tensor],
+    exp_infs: list[Tensor],
+    state_steps: list[Tensor],
     *,
     eps: float,
     beta1: float,

@@ -297,11 +297,11 @@ def _single_tensor_adamax(


 def _multi_tensor_adamax(
-    params: List[Tensor],
-    grads: List[Tensor],
-    exp_avgs: List[Tensor],
-    exp_infs: List[Tensor],
-    state_steps: List[Tensor],
+    params: list[Tensor],
+    grads: list[Tensor],
+    exp_avgs: list[Tensor],
+    exp_infs: list[Tensor],
+    state_steps: list[Tensor],
     *,
     eps: float,
     beta1: float,

@@ -339,11 +339,11 @@ def _multi_tensor_adamax(
         grouped_exp_infs_,
         grouped_state_steps_,
     ), _ in grouped_tensors.values():
-        grouped_params = cast(List[Tensor], grouped_params_)
-        grouped_grads = cast(List[Tensor], grouped_grads_)
-        grouped_exp_avgs = cast(List[Tensor], grouped_exp_avgs_)
-        grouped_exp_infs = cast(List[Tensor], grouped_exp_infs_)
-        grouped_state_steps = cast(List[Tensor], grouped_state_steps_)
+        grouped_params = cast(list[Tensor], grouped_params_)
+        grouped_grads = cast(list[Tensor], grouped_grads_)
+        grouped_exp_avgs = cast(list[Tensor], grouped_exp_avgs_)
+        grouped_exp_infs = cast(list[Tensor], grouped_exp_infs_)
+        grouped_state_steps = cast(list[Tensor], grouped_state_steps_)

         if has_complex:
             _view_as_real(

@@ -389,7 +389,7 @@ def _multi_tensor_adamax(
             torch._foreach_add_(grouped_grads, eps)
             torch._foreach_maximum_(grouped_exp_infs, grouped_grads)

-        bias_corrections: Union[Tuple[Tensor, ...], List[Tensor]]
+        bias_corrections: Union[tuple[Tensor, ...], list[Tensor]]
         if capturable:
             bias_corrections = torch._foreach_pow(beta1, grouped_state_steps)
             # foreach_sub doesn't allow a scalar as the first arg

@@ -410,11 +410,11 @@ def _multi_tensor_adamax(

 @_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adamax)
 def adamax(
-    params: List[Tensor],
-    grads: List[Tensor],
-    exp_avgs: List[Tensor],
-    exp_infs: List[Tensor],
-    state_steps: List[Tensor],
+    params: list[Tensor],
+    grads: list[Tensor],
+    exp_avgs: list[Tensor],
+    exp_infs: list[Tensor],
+    state_steps: list[Tensor],
     # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
     # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
     foreach: Optional[bool] = None,
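The Adamax changes above (and the NAdam/RAdam hunks that follow) touch only annotations and cast() calls, so optimizer behaviour is unchanged; the builtin-generic spellings simply require Python 3.9+, which is what PEP 585 assumes. A toy function in the converted style, not part of the optimizer API, showing that the annotation object is just the parameterized builtin:

from torch import Tensor


def scale(params: list[Tensor], factor: float) -> list[Tensor]:
    # hypothetical helper written in the converted annotation style
    return [p * factor for p in params]


print(scale.__annotations__["params"])  # list[torch.Tensor] (a types.GenericAlias)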
@@ -1,6 +1,6 @@
 # mypy: allow-untyped-defs
 r"""Implementation for the NAdam algorithm."""
-from typing import cast, List, Optional, Tuple, Union
+from typing import cast, Optional, Union

 import torch
 from torch import Tensor

@@ -32,7 +32,7 @@ class NAdam(Optimizer): # noqa: D101
         self,
         params: ParamsT,
         lr: Union[float, Tensor] = 2e-3,
-        betas: Tuple[float, float] = (0.9, 0.999),
+        betas: tuple[float, float] = (0.9, 0.999),
         eps: float = 1e-8,
         weight_decay: float = 0,
         momentum_decay: float = 4e-3,

@@ -167,13 +167,13 @@ class NAdam(Optimizer): # noqa: D101
                 loss = closure()

         for group in self.param_groups:
-            params_with_grad: List[Tensor] = []
-            grads: List[Tensor] = []
-            exp_avgs: List[Tensor] = []
-            exp_avg_sqs: List[Tensor] = []
-            mu_products: List[Tensor] = []
-            state_steps: List[Tensor] = []
-            beta1, beta2 = cast(Tuple[float, float], group["betas"])
+            params_with_grad: list[Tensor] = []
+            grads: list[Tensor] = []
+            exp_avgs: list[Tensor] = []
+            exp_avg_sqs: list[Tensor] = []
+            mu_products: list[Tensor] = []
+            state_steps: list[Tensor] = []
+            beta1, beta2 = cast(tuple[float, float], group["betas"])

             has_complex = self._init_group(
                 group,

@@ -277,12 +277,12 @@ NAdam.__doc__ = (


 def _single_tensor_nadam(
-    params: List[Tensor],
-    grads: List[Tensor],
-    exp_avgs: List[Tensor],
-    exp_avg_sqs: List[Tensor],
-    mu_products: List[Tensor],
-    state_steps: List[Tensor],
+    params: list[Tensor],
+    grads: list[Tensor],
+    exp_avgs: list[Tensor],
+    exp_avg_sqs: list[Tensor],
+    mu_products: list[Tensor],
+    state_steps: list[Tensor],
     *,
     beta1: float,
     beta2: float,

@@ -371,12 +371,12 @@ def _single_tensor_nadam(


 def _multi_tensor_nadam(
-    params: List[Tensor],
-    grads: List[Tensor],
-    exp_avgs: List[Tensor],
-    exp_avg_sqs: List[Tensor],
-    mu_products: List[Tensor],
-    state_steps: List[Tensor],
+    params: list[Tensor],
+    grads: list[Tensor],
+    exp_avgs: list[Tensor],
+    exp_avg_sqs: list[Tensor],
+    mu_products: list[Tensor],
+    state_steps: list[Tensor],
     *,
     beta1: float,
     beta2: float,

@@ -417,12 +417,12 @@ def _multi_tensor_nadam(
         grouped_mu_products_,
         grouped_state_steps_,
     ), _ in grouped_tensors.values():
-        grouped_params = cast(List[Tensor], grouped_params_)
-        grouped_grads = cast(List[Tensor], grouped_grads_)
-        grouped_exp_avgs = cast(List[Tensor], grouped_exp_avgs_)
-        grouped_exp_avg_sqs = cast(List[Tensor], grouped_exp_avg_sqs_)
-        grouped_mu_products = cast(List[Tensor], grouped_mu_products_)
-        grouped_state_steps = cast(List[Tensor], grouped_state_steps_)
+        grouped_params = cast(list[Tensor], grouped_params_)
+        grouped_grads = cast(list[Tensor], grouped_grads_)
+        grouped_exp_avgs = cast(list[Tensor], grouped_exp_avgs_)
+        grouped_exp_avg_sqs = cast(list[Tensor], grouped_exp_avg_sqs_)
+        grouped_mu_products = cast(list[Tensor], grouped_mu_products_)
+        grouped_state_steps = cast(list[Tensor], grouped_state_steps_)

         # handle complex
         if has_complex:

@@ -469,9 +469,9 @@ def _multi_tensor_nadam(

         exp_avg_sq_sqrt = torch._foreach_sqrt(grouped_exp_avg_sqs)

-        bias_correction_sqrt: Union[Tuple[Tensor, ...], List[Tensor]]
-        mus: Union[Tuple[Tensor, ...], List[Tensor]]
-        mu_nexts: Union[Tuple[Tensor, ...], List[Tensor]]
+        bias_correction_sqrt: Union[tuple[Tensor, ...], list[Tensor]]
+        mus: Union[tuple[Tensor, ...], list[Tensor]]
+        mu_nexts: Union[tuple[Tensor, ...], list[Tensor]]
         if capturable:
             # mus will be beta1 * (1 - 0.5 * 0.96 ** (step * momentum_decay))
             exponent = torch._foreach_mul(grouped_state_steps, momentum_decay)

@@ -579,12 +579,12 @@ def _multi_tensor_nadam(

 @_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_nadam)
 def nadam(
-    params: List[Tensor],
-    grads: List[Tensor],
-    exp_avgs: List[Tensor],
-    exp_avg_sqs: List[Tensor],
-    mu_products: List[Tensor],
-    state_steps: List[Tensor],
+    params: list[Tensor],
+    grads: list[Tensor],
+    exp_avgs: list[Tensor],
+    exp_avg_sqs: list[Tensor],
+    mu_products: list[Tensor],
+    state_steps: list[Tensor],
     # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
     # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
     decoupled_weight_decay: bool = False,
@@ -1,6 +1,6 @@
 # mypy: allow-untyped-defs
 r"""Implementation for the RAdam algorithm."""
-from typing import cast, List, Optional, Tuple, Union
+from typing import cast, Optional, Union

 import torch
 from torch import Tensor

@@ -31,7 +31,7 @@ class RAdam(Optimizer): # noqa: D101
         self,
         params: ParamsT,
         lr: Union[float, Tensor] = 1e-3,
-        betas: Tuple[float, float] = (0.9, 0.999),
+        betas: tuple[float, float] = (0.9, 0.999),
         eps: float = 1e-8,
         weight_decay: float = 0,
         decoupled_weight_decay: bool = False,

@@ -138,12 +138,12 @@ class RAdam(Optimizer): # noqa: D101
                 loss = closure()

         for group in self.param_groups:
-            params_with_grad: List[Tensor] = []
-            grads: List[Tensor] = []
-            exp_avgs: List[Tensor] = []
-            exp_avg_sqs: List[Tensor] = []
-            state_steps: List[Tensor] = []
-            beta1, beta2 = cast(Tuple[float, float], group["betas"])
+            params_with_grad: list[Tensor] = []
+            grads: list[Tensor] = []
+            exp_avgs: list[Tensor] = []
+            exp_avg_sqs: list[Tensor] = []
+            state_steps: list[Tensor] = []
+            beta1, beta2 = cast(tuple[float, float], group["betas"])

             has_complex = self._init_group(
                 group, params_with_grad, grads, exp_avgs, exp_avg_sqs, state_steps

@@ -252,11 +252,11 @@ RAdam.__doc__ = (


 def _single_tensor_radam(
-    params: List[Tensor],
-    grads: List[Tensor],
-    exp_avgs: List[Tensor],
-    exp_avg_sqs: List[Tensor],
-    state_steps: List[Tensor],
+    params: list[Tensor],
+    grads: list[Tensor],
+    exp_avgs: list[Tensor],
+    exp_avg_sqs: list[Tensor],
+    state_steps: list[Tensor],
     *,
     beta1: float,
     beta2: float,

@@ -351,11 +351,11 @@ def _single_tensor_radam(


 def _multi_tensor_radam(
-    params: List[Tensor],
-    grads: List[Tensor],
-    exp_avgs: List[Tensor],
-    exp_avg_sqs: List[Tensor],
-    state_steps: List[Tensor],
+    params: list[Tensor],
+    grads: list[Tensor],
+    exp_avgs: list[Tensor],
+    exp_avg_sqs: list[Tensor],
+    state_steps: list[Tensor],
     *,
     beta1: float,
     beta2: float,

@@ -394,11 +394,11 @@ def _multi_tensor_radam(
         grouped_exp_avg_sqs_,
         grouped_state_steps_,
     ), _ in grouped_tensors.values():
-        grouped_params = cast(List[Tensor], grouped_params_)
-        grouped_grads = cast(List[Tensor], grouped_grads_)
-        grouped_exp_avgs = cast(List[Tensor], grouped_exp_avgs_)
-        grouped_exp_avg_sqs = cast(List[Tensor], grouped_exp_avg_sqs_)
-        grouped_state_steps = cast(List[Tensor], grouped_state_steps_)
+        grouped_params = cast(list[Tensor], grouped_params_)
+        grouped_grads = cast(list[Tensor], grouped_grads_)
+        grouped_exp_avgs = cast(list[Tensor], grouped_exp_avgs_)
+        grouped_exp_avg_sqs = cast(list[Tensor], grouped_exp_avg_sqs_)
+        grouped_state_steps = cast(list[Tensor], grouped_state_steps_)

         # Update steps
         # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over

@@ -422,9 +422,9 @@ def _multi_tensor_radam(
         # maximum length of the approximated SMA
         rho_inf = 2 / (1 - beta2) - 1
         # compute the length of the approximated SMA
-        bias_correction1: Union[Tuple[Tensor, ...], List[Tensor]]
-        bias_correction2: Union[Tuple[Tensor, ...], List[Tensor]]
-        rho_t_list: Union[Tuple[Tensor, ...], List[Tensor]]
+        bias_correction1: Union[tuple[Tensor, ...], list[Tensor]]
+        bias_correction2: Union[tuple[Tensor, ...], list[Tensor]]
+        rho_t_list: Union[tuple[Tensor, ...], list[Tensor]]
         if capturable:
             bias_correction1 = torch._foreach_pow(beta2, grouped_state_steps)
             torch._foreach_neg_(bias_correction1)

@@ -547,11 +547,11 @@ def _multi_tensor_radam(

 @_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_radam)
 def radam(
-    params: List[Tensor],
-    grads: List[Tensor],
-    exp_avgs: List[Tensor],
-    exp_avg_sqs: List[Tensor],
-    state_steps: List[Tensor],
+    params: list[Tensor],
+    grads: list[Tensor],
+    exp_avgs: list[Tensor],
+    exp_avg_sqs: list[Tensor],
+    state_steps: list[Tensor],
     # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
     # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
     decoupled_weight_decay: bool = False,
@@ -1,15 +1,16 @@
+from collections.abc import Sequence
 from pathlib import Path
 from re import match as _match
-from typing import List, Optional, Sequence, Set, Union
+from typing import Optional, Union


-def read_file(fname: Union[Path, str]) -> List[str]:
+def read_file(fname: Union[Path, str]) -> list[str]:
     with open(fname, encoding="utf-8") as f:
         return f.readlines()


 def _embed_headers(
-    content: List[str], include_dirs: List[Path], processed_files: Set[str]
+    content: list[str], include_dirs: list[Path], processed_files: set[str]
 ) -> str:
     for line_idx, cur_line in enumerate(content):
         m = _match('^\\s*#include\\s*[<"]([^>"]+)[>"]', cur_line)
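The header-embedding utility above drops List, Sequence and Set from typing entirely: PEP 585 also covers the collections.abc interfaces, so Sequence is now imported from collections.abc and set[str] replaces typing.Set[str]. A small sketch of the same import style:

from collections.abc import Sequence
from typing import Optional


def first_or_none(items: Sequence[str]) -> Optional[str]:
    # collections.abc.Sequence is subscriptable at runtime on Python 3.9+
    return items[0] if items else None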
@@ -32,7 +32,6 @@ from typing import (
     Callable,
     cast,
     Generic,
-    List,
     Optional,
     overload,
     Protocol,

@@ -747,7 +746,7 @@ class TreeSpec:
         return self.num_nodes == 1 and self.num_leaves == 1

     def flatten_up_to(self, tree: PyTree) -> list[PyTree]:
-        def helper(treespec: TreeSpec, tree: PyTree, subtrees: List[PyTree]) -> None:
+        def helper(treespec: TreeSpec, tree: PyTree, subtrees: list[PyTree]) -> None:
             if treespec.is_leaf():
                 subtrees.append(tree)
                 return

@@ -881,7 +880,7 @@ def tree_flatten(
     to reconstruct the pytree.
     """

-    def helper(node: PyTree, leaves: List[Any]) -> TreeSpec:
+    def helper(node: PyTree, leaves: list[Any]) -> TreeSpec:
         if _is_leaf(node, is_leaf=is_leaf):
             leaves.append(node)
             return _LEAF_SPEC