PEP585: Missed conversions (#145342)

Differential Revision: [D68785969](https://our.internmc.facebook.com/intern/diff/D68785969)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145342
Approved by: https://github.com/bobrenjc93
This commit is contained in:
Aaron Orenstein
2025-01-25 11:54:14 -08:00
committed by PyTorch MergeBot
parent 8696e59ae2
commit 7178b827d7
11 changed files with 144 additions and 143 deletions

View File

@ -20,7 +20,7 @@ import types
import typing
import warnings
import weakref
from typing import ( # noqa: F401 # (Dict, List, Tuple) imported by torch.jit.annotations
from typing import ( # noqa: UP035, F401 # (Dict, List, Tuple) imported by torch.jit.annotations
Any,
Callable,
Dict,
@ -1125,7 +1125,8 @@ def _get_overloaded_methods(method, mod_class):
def is_tuple(ann) -> bool:
if ann is Tuple:
# Check for typing.Tuple missing args (but `tuple` is fine)
if ann is typing.Tuple: # noqa: UP006
raise_error_container_parameter_missing("Tuple")
# For some reason Python 3.7 violates the Type[A, B].__origin__ == Type rule
@ -1133,35 +1134,31 @@ def is_tuple(ann) -> bool:
return False
ann_origin = get_origin(ann)
if IS_PY39_PLUS and ann.__module__ == "builtins" and ann_origin is tuple:
return True
return ann.__module__ == "typing" and (ann_origin is Tuple or ann_origin is tuple)
return ann.__module__ in ("builtins", "typing") and ann_origin is tuple
def is_list(ann) -> bool:
if ann is List:
# Check for typing.List missing args (but `list` is fine)
if ann is typing.List: # noqa: UP006
raise_error_container_parameter_missing("List")
if not hasattr(ann, "__module__"):
return False
ann_origin = get_origin(ann)
if IS_PY39_PLUS and ann.__module__ == "builtins" and ann_origin is list:
return True
return ann.__module__ == "typing" and (ann_origin is List or ann_origin is list)
return ann.__module__ in ("builtins", "typing") and ann_origin is list
def is_dict(ann) -> bool:
if ann is Dict:
# Check for typing.Dict missing args (but `dict` is fine)
if ann is typing.Dict: # noqa: UP006
raise_error_container_parameter_missing("Dict")
if not hasattr(ann, "__module__"):
return False
ann_origin = get_origin(ann)
if IS_PY39_PLUS and ann.__module__ == "builtins" and ann_origin is dict:
return True
return ann.__module__ == "typing" and (ann_origin is Dict or ann_origin is dict)
return ann.__module__ in ("builtins", "typing") and ann_origin is dict
def is_union(ann):
@ -1371,11 +1368,11 @@ def raise_error_container_parameter_missing(target_type) -> None:
def check_args_exist(target_type) -> None:
if target_type is List or target_type is list:
if target_type is typing.List or target_type is list: # noqa: UP006
raise_error_container_parameter_missing("List")
elif target_type is Tuple or target_type is tuple:
elif target_type is typing.Tuple or target_type is tuple: # noqa: UP006
raise_error_container_parameter_missing("Tuple")
elif target_type is Dict or target_type is dict:
elif target_type is typing.Dict or target_type is dict: # noqa: UP006
raise_error_container_parameter_missing("Dict")
elif target_type is None or target_type is Optional:
raise_error_container_parameter_missing("Optional")
@ -1399,7 +1396,7 @@ def container_checker(obj, target_type) -> bool:
check_args_exist(target_type)
if origin_type is None:
return False
elif origin_type is list or origin_type is List:
elif origin_type is list or origin_type is typing.List: # noqa: UP006
check_empty_containers(obj)
if not isinstance(obj, list):
return False
@ -1413,7 +1410,7 @@ def container_checker(obj, target_type) -> bool:
elif not isinstance(el, arg_type):
return False
return True
elif origin_type is Dict or origin_type is dict:
elif origin_type is typing.Dict or origin_type is dict: # noqa: UP006
check_empty_containers(obj)
if not isinstance(obj, dict):
return False
@ -1430,7 +1427,7 @@ def container_checker(obj, target_type) -> bool:
elif not isinstance(val, val_type):
return False
return True
elif origin_type is Tuple or origin_type is tuple:
elif origin_type is typing.Tuple or origin_type is tuple: # noqa: UP006
check_empty_containers(obj)
if not isinstance(obj, tuple):
return False

View File

@ -209,8 +209,8 @@ def derived_types(
def derived_seq_types(typ: Union[type, typing._SpecialForm]):
return (
typing.Sequence[typ], # type: ignore[valid-type]
typing.List[typ], # type: ignore[valid-type]
typing.Sequence[typ], # type: ignore[valid-type] # noqa: UP006
typing.List[typ], # type: ignore[valid-type] # noqa: UP006
GenericAlias(collections.abc.Sequence, (typ,)),
GenericAlias(list, (typ,)),
)
@ -252,7 +252,7 @@ def get_supported_param_types():
SUPPORTED_RETURN_TYPES = {
Tensor: "Tensor",
typing.List[Tensor]: "Tensor[]",
typing.List[Tensor]: "Tensor[]", # noqa: UP006
list[Tensor]: "Tensor[]",
int: "SymInt",
float: "float",
@ -306,7 +306,7 @@ def tuple_to_list(tuple_type: type[tuple]) -> type[list]:
# Account for different python versions, e.g. python 3.8 would give ()
# but python 3.12 would give None.
if (
tuple_type is typing.Tuple
tuple_type is typing.Tuple # noqa: UP006
or tuple_type is tuple
or type_args == ()
or type_args is None

View File

@ -2,7 +2,7 @@
import contextlib
import sys
import warnings
from typing import Any, cast, List, Optional, TYPE_CHECKING, Union
from typing import Any, cast, Optional, TYPE_CHECKING, Union
import torch
import torch.distributed as dist
@ -787,7 +787,7 @@ def _resolve_group_name(group: RANK_TYPES, tag: str = "") -> str:
FutureWarning,
stacklevel=3,
)
return c10d._resolve_group_name_by_ranks_and_tag(cast(List[int], group), tag)
return c10d._resolve_group_name_by_ranks_and_tag(cast(list[int], group), tag)
else:
raise ValueError(f"Unsupported group type: {type(group)}, {group}")

View File

@ -3,7 +3,7 @@
import abc
import io
from collections.abc import Sequence
from typing import cast, IO, Optional, Type
from typing import cast, IO, Optional
# introduced as collections.abc.Buffer in Python 3.12
from typing_extensions import Buffer
@ -189,11 +189,11 @@ class ZStandard(StreamTransformExtension):
class ExtensionRegistry:
def __init__(self) -> None:
# Populate default registry contents
self.extensions: dict[str, Type[Extension]] = {
self.extensions: dict[str, type[Extension]] = {
cls.registry_name(): cls for cls in (ZStandard,)
}
def register(self, cls: Type[Extension]) -> None:
def register(self, cls: type[Extension]) -> None:
self.extensions[cls.registry_name()] = cls
def from_descriptor_list(self, descriptors: Sequence[str]) -> Sequence[Extension]:

View File

@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
import math
from typing import List, Optional, Union
from typing import Optional, Union
from typing_extensions import deprecated
import torch
@ -788,8 +788,8 @@ class _ConvTransposeNd(_ConvNd):
f"or {num_non_spatial_dims + num_spatial_dims} elements (got {len(output_size)})"
)
min_sizes = torch.jit.annotate(List[int], [])
max_sizes = torch.jit.annotate(List[int], [])
min_sizes = torch.jit.annotate(list[int], [])
max_sizes = torch.jit.annotate(list[int], [])
for d in range(num_spatial_dims):
dim_size = (
(input.size(d + num_non_spatial_dims) - 1) * stride[d]
@ -811,7 +811,7 @@ class _ConvTransposeNd(_ConvNd):
f"from {min_sizes} to {max_sizes} (for an input of {input.size()[2:]})"
)
res = torch.jit.annotate(List[int], [])
res = torch.jit.annotate(list[int], [])
for d in range(num_spatial_dims):
res.append(output_size[d] - min_sizes[d])

View File

@ -1,7 +1,8 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import functools
from typing import cast, Dict, Iterable, List, Optional, Tuple, Union
import typing
from typing import cast, Optional, Union
from typing_extensions import deprecated
import torch
@ -20,7 +21,10 @@ __all__ = [
]
_tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]]
_tensor_or_tensors = Union[
torch.Tensor,
typing.Iterable[torch.Tensor], # noqa: UP006 - needed until XLA's patch is updated
]
def _no_grad(func):
@ -73,13 +77,13 @@ def _get_total_norm(
if len(tensors) == 0:
return torch.tensor(0.0)
first_device = tensors[0].device
grouped_tensors: Dict[
Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]]
grouped_tensors: dict[
tuple[torch.device, torch.dtype], tuple[list[list[Tensor]], list[int]]
] = _group_tensors_by_device_and_dtype(
[tensors] # type: ignore[list-item]
) # type: ignore[assignment]
norms: List[Tensor] = []
norms: list[Tensor] = []
for (device, _), ([device_tensors], _) in grouped_tensors.items():
if (foreach is None and _has_foreach_support(device_tensors, device)) or (
foreach and _device_has_foreach_support(device)
@ -146,8 +150,8 @@ def _clip_grads_with_norm_(
max_norm = float(max_norm)
if len(grads) == 0:
return
grouped_grads: Dict[
Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]]
grouped_grads: dict[
tuple[torch.device, torch.dtype], tuple[list[list[Tensor]], list[int]]
] = _group_tensors_by_device_and_dtype(
[grads]
) # type: ignore[assignment]
@ -269,10 +273,10 @@ def clip_grad_value_(
for (device, _), ([grads], _) in grouped_grads.items():
if (
foreach is None
and _has_foreach_support(cast(List[Tensor], grads), device=device)
and _has_foreach_support(cast(list[Tensor], grads), device=device)
) or (foreach and _device_has_foreach_support(device)):
torch._foreach_clamp_min_(cast(List[Tensor], grads), -clip_value)
torch._foreach_clamp_max_(cast(List[Tensor], grads), clip_value)
torch._foreach_clamp_min_(cast(list[Tensor], grads), -clip_value)
torch._foreach_clamp_max_(cast(list[Tensor], grads), clip_value)
elif foreach:
raise RuntimeError(
f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"

View File

@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import cast, List, Optional, Tuple, Union
from typing import cast, Optional, Union
import torch
from torch import Tensor
@ -30,7 +30,7 @@ class Adamax(Optimizer):
self,
params: ParamsT,
lr: Union[float, Tensor] = 2e-3,
betas: Tuple[float, float] = (0.9, 0.999),
betas: tuple[float, float] = (0.9, 0.999),
eps: float = 1e-8,
weight_decay: float = 0,
foreach: Optional[bool] = None,
@ -134,11 +134,11 @@ class Adamax(Optimizer):
loss = closure()
for group in self.param_groups:
params_with_grad: List[Tensor] = []
grads: List[Tensor] = []
exp_avgs: List[Tensor] = []
exp_infs: List[Tensor] = []
state_steps: List[Tensor] = []
params_with_grad: list[Tensor] = []
grads: list[Tensor] = []
exp_avgs: list[Tensor] = []
exp_infs: list[Tensor] = []
state_steps: list[Tensor] = []
beta1, beta2 = group["betas"]
eps = group["eps"]
@ -223,11 +223,11 @@ Adamax.__doc__ = (
def _single_tensor_adamax(
params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
exp_infs: List[Tensor],
state_steps: List[Tensor],
params: list[Tensor],
grads: list[Tensor],
exp_avgs: list[Tensor],
exp_infs: list[Tensor],
state_steps: list[Tensor],
*,
eps: float,
beta1: float,
@ -297,11 +297,11 @@ def _single_tensor_adamax(
def _multi_tensor_adamax(
params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
exp_infs: List[Tensor],
state_steps: List[Tensor],
params: list[Tensor],
grads: list[Tensor],
exp_avgs: list[Tensor],
exp_infs: list[Tensor],
state_steps: list[Tensor],
*,
eps: float,
beta1: float,
@ -339,11 +339,11 @@ def _multi_tensor_adamax(
grouped_exp_infs_,
grouped_state_steps_,
), _ in grouped_tensors.values():
grouped_params = cast(List[Tensor], grouped_params_)
grouped_grads = cast(List[Tensor], grouped_grads_)
grouped_exp_avgs = cast(List[Tensor], grouped_exp_avgs_)
grouped_exp_infs = cast(List[Tensor], grouped_exp_infs_)
grouped_state_steps = cast(List[Tensor], grouped_state_steps_)
grouped_params = cast(list[Tensor], grouped_params_)
grouped_grads = cast(list[Tensor], grouped_grads_)
grouped_exp_avgs = cast(list[Tensor], grouped_exp_avgs_)
grouped_exp_infs = cast(list[Tensor], grouped_exp_infs_)
grouped_state_steps = cast(list[Tensor], grouped_state_steps_)
if has_complex:
_view_as_real(
@ -389,7 +389,7 @@ def _multi_tensor_adamax(
torch._foreach_add_(grouped_grads, eps)
torch._foreach_maximum_(grouped_exp_infs, grouped_grads)
bias_corrections: Union[Tuple[Tensor, ...], List[Tensor]]
bias_corrections: Union[tuple[Tensor, ...], list[Tensor]]
if capturable:
bias_corrections = torch._foreach_pow(beta1, grouped_state_steps)
# foreach_sub doesn't allow a scalar as the first arg
@ -410,11 +410,11 @@ def _multi_tensor_adamax(
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adamax)
def adamax(
params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
exp_infs: List[Tensor],
state_steps: List[Tensor],
params: list[Tensor],
grads: list[Tensor],
exp_avgs: list[Tensor],
exp_infs: list[Tensor],
state_steps: list[Tensor],
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
foreach: Optional[bool] = None,

View File

@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
r"""Implementation for the NAdam algorithm."""
from typing import cast, List, Optional, Tuple, Union
from typing import cast, Optional, Union
import torch
from torch import Tensor
@ -32,7 +32,7 @@ class NAdam(Optimizer): # noqa: D101
self,
params: ParamsT,
lr: Union[float, Tensor] = 2e-3,
betas: Tuple[float, float] = (0.9, 0.999),
betas: tuple[float, float] = (0.9, 0.999),
eps: float = 1e-8,
weight_decay: float = 0,
momentum_decay: float = 4e-3,
@ -167,13 +167,13 @@ class NAdam(Optimizer): # noqa: D101
loss = closure()
for group in self.param_groups:
params_with_grad: List[Tensor] = []
grads: List[Tensor] = []
exp_avgs: List[Tensor] = []
exp_avg_sqs: List[Tensor] = []
mu_products: List[Tensor] = []
state_steps: List[Tensor] = []
beta1, beta2 = cast(Tuple[float, float], group["betas"])
params_with_grad: list[Tensor] = []
grads: list[Tensor] = []
exp_avgs: list[Tensor] = []
exp_avg_sqs: list[Tensor] = []
mu_products: list[Tensor] = []
state_steps: list[Tensor] = []
beta1, beta2 = cast(tuple[float, float], group["betas"])
has_complex = self._init_group(
group,
@ -277,12 +277,12 @@ NAdam.__doc__ = (
def _single_tensor_nadam(
params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
exp_avg_sqs: List[Tensor],
mu_products: List[Tensor],
state_steps: List[Tensor],
params: list[Tensor],
grads: list[Tensor],
exp_avgs: list[Tensor],
exp_avg_sqs: list[Tensor],
mu_products: list[Tensor],
state_steps: list[Tensor],
*,
beta1: float,
beta2: float,
@ -371,12 +371,12 @@ def _single_tensor_nadam(
def _multi_tensor_nadam(
params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
exp_avg_sqs: List[Tensor],
mu_products: List[Tensor],
state_steps: List[Tensor],
params: list[Tensor],
grads: list[Tensor],
exp_avgs: list[Tensor],
exp_avg_sqs: list[Tensor],
mu_products: list[Tensor],
state_steps: list[Tensor],
*,
beta1: float,
beta2: float,
@ -417,12 +417,12 @@ def _multi_tensor_nadam(
grouped_mu_products_,
grouped_state_steps_,
), _ in grouped_tensors.values():
grouped_params = cast(List[Tensor], grouped_params_)
grouped_grads = cast(List[Tensor], grouped_grads_)
grouped_exp_avgs = cast(List[Tensor], grouped_exp_avgs_)
grouped_exp_avg_sqs = cast(List[Tensor], grouped_exp_avg_sqs_)
grouped_mu_products = cast(List[Tensor], grouped_mu_products_)
grouped_state_steps = cast(List[Tensor], grouped_state_steps_)
grouped_params = cast(list[Tensor], grouped_params_)
grouped_grads = cast(list[Tensor], grouped_grads_)
grouped_exp_avgs = cast(list[Tensor], grouped_exp_avgs_)
grouped_exp_avg_sqs = cast(list[Tensor], grouped_exp_avg_sqs_)
grouped_mu_products = cast(list[Tensor], grouped_mu_products_)
grouped_state_steps = cast(list[Tensor], grouped_state_steps_)
# handle complex
if has_complex:
@ -469,9 +469,9 @@ def _multi_tensor_nadam(
exp_avg_sq_sqrt = torch._foreach_sqrt(grouped_exp_avg_sqs)
bias_correction_sqrt: Union[Tuple[Tensor, ...], List[Tensor]]
mus: Union[Tuple[Tensor, ...], List[Tensor]]
mu_nexts: Union[Tuple[Tensor, ...], List[Tensor]]
bias_correction_sqrt: Union[tuple[Tensor, ...], list[Tensor]]
mus: Union[tuple[Tensor, ...], list[Tensor]]
mu_nexts: Union[tuple[Tensor, ...], list[Tensor]]
if capturable:
# mus will be beta1 * (1 - 0.5 * 0.96 ** (step * momentum_decay))
exponent = torch._foreach_mul(grouped_state_steps, momentum_decay)
@ -579,12 +579,12 @@ def _multi_tensor_nadam(
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_nadam)
def nadam(
params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
exp_avg_sqs: List[Tensor],
mu_products: List[Tensor],
state_steps: List[Tensor],
params: list[Tensor],
grads: list[Tensor],
exp_avgs: list[Tensor],
exp_avg_sqs: list[Tensor],
mu_products: list[Tensor],
state_steps: list[Tensor],
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
decoupled_weight_decay: bool = False,

View File

@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
r"""Implementation for the RAdam algorithm."""
from typing import cast, List, Optional, Tuple, Union
from typing import cast, Optional, Union
import torch
from torch import Tensor
@ -31,7 +31,7 @@ class RAdam(Optimizer): # noqa: D101
self,
params: ParamsT,
lr: Union[float, Tensor] = 1e-3,
betas: Tuple[float, float] = (0.9, 0.999),
betas: tuple[float, float] = (0.9, 0.999),
eps: float = 1e-8,
weight_decay: float = 0,
decoupled_weight_decay: bool = False,
@ -138,12 +138,12 @@ class RAdam(Optimizer): # noqa: D101
loss = closure()
for group in self.param_groups:
params_with_grad: List[Tensor] = []
grads: List[Tensor] = []
exp_avgs: List[Tensor] = []
exp_avg_sqs: List[Tensor] = []
state_steps: List[Tensor] = []
beta1, beta2 = cast(Tuple[float, float], group["betas"])
params_with_grad: list[Tensor] = []
grads: list[Tensor] = []
exp_avgs: list[Tensor] = []
exp_avg_sqs: list[Tensor] = []
state_steps: list[Tensor] = []
beta1, beta2 = cast(tuple[float, float], group["betas"])
has_complex = self._init_group(
group, params_with_grad, grads, exp_avgs, exp_avg_sqs, state_steps
@ -252,11 +252,11 @@ RAdam.__doc__ = (
def _single_tensor_radam(
params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
exp_avg_sqs: List[Tensor],
state_steps: List[Tensor],
params: list[Tensor],
grads: list[Tensor],
exp_avgs: list[Tensor],
exp_avg_sqs: list[Tensor],
state_steps: list[Tensor],
*,
beta1: float,
beta2: float,
@ -351,11 +351,11 @@ def _single_tensor_radam(
def _multi_tensor_radam(
params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
exp_avg_sqs: List[Tensor],
state_steps: List[Tensor],
params: list[Tensor],
grads: list[Tensor],
exp_avgs: list[Tensor],
exp_avg_sqs: list[Tensor],
state_steps: list[Tensor],
*,
beta1: float,
beta2: float,
@ -394,11 +394,11 @@ def _multi_tensor_radam(
grouped_exp_avg_sqs_,
grouped_state_steps_,
), _ in grouped_tensors.values():
grouped_params = cast(List[Tensor], grouped_params_)
grouped_grads = cast(List[Tensor], grouped_grads_)
grouped_exp_avgs = cast(List[Tensor], grouped_exp_avgs_)
grouped_exp_avg_sqs = cast(List[Tensor], grouped_exp_avg_sqs_)
grouped_state_steps = cast(List[Tensor], grouped_state_steps_)
grouped_params = cast(list[Tensor], grouped_params_)
grouped_grads = cast(list[Tensor], grouped_grads_)
grouped_exp_avgs = cast(list[Tensor], grouped_exp_avgs_)
grouped_exp_avg_sqs = cast(list[Tensor], grouped_exp_avg_sqs_)
grouped_state_steps = cast(list[Tensor], grouped_state_steps_)
# Update steps
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
@ -422,9 +422,9 @@ def _multi_tensor_radam(
# maximum length of the approximated SMA
rho_inf = 2 / (1 - beta2) - 1
# compute the length of the approximated SMA
bias_correction1: Union[Tuple[Tensor, ...], List[Tensor]]
bias_correction2: Union[Tuple[Tensor, ...], List[Tensor]]
rho_t_list: Union[Tuple[Tensor, ...], List[Tensor]]
bias_correction1: Union[tuple[Tensor, ...], list[Tensor]]
bias_correction2: Union[tuple[Tensor, ...], list[Tensor]]
rho_t_list: Union[tuple[Tensor, ...], list[Tensor]]
if capturable:
bias_correction1 = torch._foreach_pow(beta2, grouped_state_steps)
torch._foreach_neg_(bias_correction1)
@ -547,11 +547,11 @@ def _multi_tensor_radam(
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_radam)
def radam(
params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
exp_avg_sqs: List[Tensor],
state_steps: List[Tensor],
params: list[Tensor],
grads: list[Tensor],
exp_avgs: list[Tensor],
exp_avg_sqs: list[Tensor],
state_steps: list[Tensor],
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
decoupled_weight_decay: bool = False,

View File

@ -1,15 +1,16 @@
from collections.abc import Sequence
from pathlib import Path
from re import match as _match
from typing import List, Optional, Sequence, Set, Union
from typing import Optional, Union
def read_file(fname: Union[Path, str]) -> List[str]:
def read_file(fname: Union[Path, str]) -> list[str]:
with open(fname, encoding="utf-8") as f:
return f.readlines()
def _embed_headers(
content: List[str], include_dirs: List[Path], processed_files: Set[str]
content: list[str], include_dirs: list[Path], processed_files: set[str]
) -> str:
for line_idx, cur_line in enumerate(content):
m = _match('^\\s*#include\\s*[<"]([^>"]+)[>"]', cur_line)

View File

@ -32,7 +32,6 @@ from typing import (
Callable,
cast,
Generic,
List,
Optional,
overload,
Protocol,
@ -747,7 +746,7 @@ class TreeSpec:
return self.num_nodes == 1 and self.num_leaves == 1
def flatten_up_to(self, tree: PyTree) -> list[PyTree]:
def helper(treespec: TreeSpec, tree: PyTree, subtrees: List[PyTree]) -> None:
def helper(treespec: TreeSpec, tree: PyTree, subtrees: list[PyTree]) -> None:
if treespec.is_leaf():
subtrees.append(tree)
return
@ -881,7 +880,7 @@ def tree_flatten(
to reconstruct the pytree.
"""
def helper(node: PyTree, leaves: List[Any]) -> TreeSpec:
def helper(node: PyTree, leaves: list[Any]) -> TreeSpec:
if _is_leaf(node, is_leaf=is_leaf):
leaves.append(node)
return _LEAF_SPEC