[4/N] Apply py39 ruff and pyupgrade fixes (#143257)

```torch/fx/passes/annotate_getitem_nodes.py``` was changed to support the new type hinting annotations.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143257
Approved by: https://github.com/justinchuby, https://github.com/albanD
This commit is contained in:
cyy
2025-01-04 10:47:51 +00:00
committed by PyTorch MergeBot
parent a881954b0c
commit df458be4e5
55 changed files with 247 additions and 227 deletions

View File

@ -1,8 +1,3 @@
import dis
import inspect
from collections.abc import Sequence
from typing import Union
import functorch._C
import torch
from functorch._C import dim as _C

View File

@ -27,7 +27,7 @@ from __future__ import annotations
import keyword
import warnings
from typing import List, Optional, Set, Tuple, TYPE_CHECKING, Union
from typing import Optional, TYPE_CHECKING, Union
if TYPE_CHECKING:
@ -73,11 +73,11 @@ class ParsedExpression:
"""
self.has_ellipsis: bool = False
self.has_ellipsis_parenthesized: Optional[bool] = None
self.identifiers: Set[Union[str, AnonymousAxis]] = set()
self.identifiers: set[Union[str, AnonymousAxis]] = set()
# that's axes like 2, 3, 4 or 5. Axes with size 1 are exceptional and replaced with empty composition
self.has_non_unitary_anonymous_axes: bool = False
# composition keeps structure of composite axes, see how different corner cases are handled in tests
self.composition: List[Union[List[Union[str, AnonymousAxis]], str]] = []
self.composition: list[Union[list[Union[str, AnonymousAxis]], str]] = []
if "." in expression:
if "..." not in expression:
raise ValueError(
@ -90,7 +90,7 @@ class ParsedExpression:
expression = expression.replace("...", _ellipsis)
self.has_ellipsis = True
bracket_group: Optional[List[Union[str, AnonymousAxis]]] = None
bracket_group: Optional[list[Union[str, AnonymousAxis]]] = None
def add_axis_name(x: str) -> None:
if x in self.identifiers:
@ -164,7 +164,7 @@ class ParsedExpression:
@staticmethod
def check_axis_name_return_reason(
name: str, allow_underscore: bool = False
) -> Tuple[bool, str]:
) -> tuple[bool, str]:
"""Check if the given axis name is valid, and a message explaining why if not.
Valid axes names are python identifiers except keywords, and should not start or end with an underscore.
@ -174,7 +174,7 @@ class ParsedExpression:
allow_underscore (bool): whether axis names are allowed to start with an underscore
Returns:
Tuple[bool, str]: whether the axis name is valid, a message explaining why if not
tuple[bool, str]: whether the axis name is valid, a message explaining why if not
"""
if not str.isidentifier(name):
return False, "not a valid python identifier"
@ -211,7 +211,7 @@ class ParsedExpression:
def parse_pattern(
pattern: str, axes_lengths: Mapping[str, int]
) -> Tuple[ParsedExpression, ParsedExpression]:
) -> tuple[ParsedExpression, ParsedExpression]:
"""Parse an `einops`-style pattern into a left-hand side and right-hand side `ParsedExpression` object.
Args:
@ -219,7 +219,7 @@ def parse_pattern(
axes_lengths (Mapping[str, int]): any additional length specifications for dimensions
Returns:
Tuple[ParsedExpression, ParsedExpression]: a tuple containing the left-hand side and right-hand side expressions
tuple[ParsedExpression, ParsedExpression]: a tuple containing the left-hand side and right-hand side expressions
"""
# adapted from einops.einops._prepare_transformation_recipe
# https://github.com/arogozhnikov/einops/blob/230ac1526c1f42c9e1f7373912c7f8047496df11/einops/einops.py

View File

@ -1,7 +1,7 @@
from __future__ import annotations
import functools
from typing import Callable, Dict, List, Tuple, TYPE_CHECKING, Union
from typing import Callable, TYPE_CHECKING, Union
import torch
from functorch._C import dim as _C
@ -18,7 +18,6 @@ from ._parsing import (
if TYPE_CHECKING:
from collections.abc import Sequence
__all__ = ["rearrange"]
dims = _C.dims
@ -69,9 +68,9 @@ def _create_rearrange_callable(
# an identity rearrangement on a 0-dimension tensor
return lambda tensor: tensor
first_class_dims: Tuple[str, ...] = tuple(f"d{i}" for i in range(n_dims))
identifier_dim_map: Dict[Union[str, AnonymousAxis], Tuple[str, ...]] = {}
anon_axes: List[AnonymousAxis] = []
first_class_dims: tuple[str, ...] = tuple(f"d{i}" for i in range(n_dims))
identifier_dim_map: dict[Union[str, AnonymousAxis], tuple[str, ...]] = {}
anon_axes: list[AnonymousAxis] = []
# map the left-hand side identifiers to strings representing first class dims
dims_i = 0
@ -99,11 +98,11 @@ def _create_rearrange_callable(
raise ValueError(f"Unexpected dimension: {dimension}")
def composition_to_dims(
composition: Sequence[Union[List[Union[str, AnonymousAxis]], str]],
) -> List[Union[str, Tuple[str, ...]]]:
composition: Sequence[Union[list[Union[str, AnonymousAxis]], str]],
) -> list[Union[str, tuple[str, ...]]]:
"""Convert a `ParsedExpression.composition` into a `Tensor.__getitem__` index of strings representing first
class dims."""
dim_composition: List[Union[str, Tuple[str, ...]]] = []
dim_composition: list[Union[str, tuple[str, ...]]] = []
for dimension in composition:
if isinstance(dimension, list):
dim_composition.append(
@ -152,7 +151,7 @@ def _create_rearrange_callable(
def rearrange(
tensor: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]],
tensor: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor, ...]],
pattern: str,
**axes_lengths: int,
) -> torch.Tensor: