mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[4/N] Apply py39 ruff and pyupgrade fixes (#143257)
```torch/fx/passes/annotate_getitem_nodes.py``` was changed to support the new type hinting annotations. Pull Request resolved: https://github.com/pytorch/pytorch/pull/143257 Approved by: https://github.com/justinchuby, https://github.com/albanD
This commit is contained in:
@ -1,8 +1,3 @@
|
||||
import dis
|
||||
import inspect
|
||||
from collections.abc import Sequence
|
||||
from typing import Union
|
||||
|
||||
import functorch._C
|
||||
import torch
|
||||
from functorch._C import dim as _C
|
||||
|
@ -27,7 +27,7 @@ from __future__ import annotations
|
||||
|
||||
import keyword
|
||||
import warnings
|
||||
from typing import List, Optional, Set, Tuple, TYPE_CHECKING, Union
|
||||
from typing import Optional, TYPE_CHECKING, Union
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -73,11 +73,11 @@ class ParsedExpression:
|
||||
"""
|
||||
self.has_ellipsis: bool = False
|
||||
self.has_ellipsis_parenthesized: Optional[bool] = None
|
||||
self.identifiers: Set[Union[str, AnonymousAxis]] = set()
|
||||
self.identifiers: set[Union[str, AnonymousAxis]] = set()
|
||||
# that's axes like 2, 3, 4 or 5. Axes with size 1 are exceptional and replaced with empty composition
|
||||
self.has_non_unitary_anonymous_axes: bool = False
|
||||
# composition keeps structure of composite axes, see how different corner cases are handled in tests
|
||||
self.composition: List[Union[List[Union[str, AnonymousAxis]], str]] = []
|
||||
self.composition: list[Union[list[Union[str, AnonymousAxis]], str]] = []
|
||||
if "." in expression:
|
||||
if "..." not in expression:
|
||||
raise ValueError(
|
||||
@ -90,7 +90,7 @@ class ParsedExpression:
|
||||
expression = expression.replace("...", _ellipsis)
|
||||
self.has_ellipsis = True
|
||||
|
||||
bracket_group: Optional[List[Union[str, AnonymousAxis]]] = None
|
||||
bracket_group: Optional[list[Union[str, AnonymousAxis]]] = None
|
||||
|
||||
def add_axis_name(x: str) -> None:
|
||||
if x in self.identifiers:
|
||||
@ -164,7 +164,7 @@ class ParsedExpression:
|
||||
@staticmethod
|
||||
def check_axis_name_return_reason(
|
||||
name: str, allow_underscore: bool = False
|
||||
) -> Tuple[bool, str]:
|
||||
) -> tuple[bool, str]:
|
||||
"""Check if the given axis name is valid, and a message explaining why if not.
|
||||
|
||||
Valid axes names are python identifiers except keywords, and should not start or end with an underscore.
|
||||
@ -174,7 +174,7 @@ class ParsedExpression:
|
||||
allow_underscore (bool): whether axis names are allowed to start with an underscore
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: whether the axis name is valid, a message explaining why if not
|
||||
tuple[bool, str]: whether the axis name is valid, a message explaining why if not
|
||||
"""
|
||||
if not str.isidentifier(name):
|
||||
return False, "not a valid python identifier"
|
||||
@ -211,7 +211,7 @@ class ParsedExpression:
|
||||
|
||||
def parse_pattern(
|
||||
pattern: str, axes_lengths: Mapping[str, int]
|
||||
) -> Tuple[ParsedExpression, ParsedExpression]:
|
||||
) -> tuple[ParsedExpression, ParsedExpression]:
|
||||
"""Parse an `einops`-style pattern into a left-hand side and right-hand side `ParsedExpression` object.
|
||||
|
||||
Args:
|
||||
@ -219,7 +219,7 @@ def parse_pattern(
|
||||
axes_lengths (Mapping[str, int]): any additional length specifications for dimensions
|
||||
|
||||
Returns:
|
||||
Tuple[ParsedExpression, ParsedExpression]: a tuple containing the left-hand side and right-hand side expressions
|
||||
tuple[ParsedExpression, ParsedExpression]: a tuple containing the left-hand side and right-hand side expressions
|
||||
"""
|
||||
# adapted from einops.einops._prepare_transformation_recipe
|
||||
# https://github.com/arogozhnikov/einops/blob/230ac1526c1f42c9e1f7373912c7f8047496df11/einops/einops.py
|
||||
|
@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
from typing import Callable, Dict, List, Tuple, TYPE_CHECKING, Union
|
||||
from typing import Callable, TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
from functorch._C import dim as _C
|
||||
@ -18,7 +18,6 @@ from ._parsing import (
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
|
||||
__all__ = ["rearrange"]
|
||||
|
||||
dims = _C.dims
|
||||
@ -69,9 +68,9 @@ def _create_rearrange_callable(
|
||||
# an identity rearrangement on a 0-dimension tensor
|
||||
return lambda tensor: tensor
|
||||
|
||||
first_class_dims: Tuple[str, ...] = tuple(f"d{i}" for i in range(n_dims))
|
||||
identifier_dim_map: Dict[Union[str, AnonymousAxis], Tuple[str, ...]] = {}
|
||||
anon_axes: List[AnonymousAxis] = []
|
||||
first_class_dims: tuple[str, ...] = tuple(f"d{i}" for i in range(n_dims))
|
||||
identifier_dim_map: dict[Union[str, AnonymousAxis], tuple[str, ...]] = {}
|
||||
anon_axes: list[AnonymousAxis] = []
|
||||
|
||||
# map the left-hand side identifiers to strings representing first class dims
|
||||
dims_i = 0
|
||||
@ -99,11 +98,11 @@ def _create_rearrange_callable(
|
||||
raise ValueError(f"Unexpected dimension: {dimension}")
|
||||
|
||||
def composition_to_dims(
|
||||
composition: Sequence[Union[List[Union[str, AnonymousAxis]], str]],
|
||||
) -> List[Union[str, Tuple[str, ...]]]:
|
||||
composition: Sequence[Union[list[Union[str, AnonymousAxis]], str]],
|
||||
) -> list[Union[str, tuple[str, ...]]]:
|
||||
"""Convert a `ParsedExpression.composition` into a `Tensor.__getitem__` index of strings representing first
|
||||
class dims."""
|
||||
dim_composition: List[Union[str, Tuple[str, ...]]] = []
|
||||
dim_composition: list[Union[str, tuple[str, ...]]] = []
|
||||
for dimension in composition:
|
||||
if isinstance(dimension, list):
|
||||
dim_composition.append(
|
||||
@ -152,7 +151,7 @@ def _create_rearrange_callable(
|
||||
|
||||
|
||||
def rearrange(
|
||||
tensor: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]],
|
||||
tensor: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor, ...]],
|
||||
pattern: str,
|
||||
**axes_lengths: int,
|
||||
) -> torch.Tensor:
|
||||
|
Reference in New Issue
Block a user