[4/N] Apply py39 ruff and pyupgrade fixes (#143257)

```torch/fx/passes/annotate_getitem_nodes.py``` was changed to support the new type hinting annotations.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143257
Approved by: https://github.com/justinchuby, https://github.com/albanD
This commit is contained in:
cyy
2025-01-04 10:47:51 +00:00
committed by PyTorch MergeBot
parent a881954b0c
commit df458be4e5
55 changed files with 247 additions and 227 deletions

View File

@ -7,7 +7,7 @@ def annotate_getitem_nodes(graph: torch.fx.Graph) -> None:
"""
Annotate the type of getitem nodes, inferred from the type of sequence node.
If sequence node is not annotated with a type, do nothing.
Currently support getitem nodes from Tuple, List, and NamedTuple sequence node.
Currently support getitem nodes from tuple, list, and NamedTuple sequence node.
This is helpful since annotations on local names within function are lost during FX transforms.
Adding back known type annotation for getitem nodes to improve jit scriptability.
@ -35,6 +35,21 @@ def annotate_getitem_nodes(graph: torch.fx.Graph) -> None:
elif sequence_node.type._name == "List":
assert len(parameterized_types) == 1
node.type = parameterized_types[0]
# Generic Alias Type
elif hasattr(sequence_node.type, "__origin__"):
parameterized_types = sequence_node.type.__args__
if sequence_node.type.__origin__ is tuple:
if len(parameterized_types) == 2 and isinstance(
parameterized_types[1], type(...)
):
node.type = parameterized_types[0]
else:
assert len(parameterized_types) > index_node
node_type = parameterized_types[index_node]
node.type = node_type
elif sequence_node.type.__origin__ is list:
assert len(parameterized_types) == 1
node.type = parameterized_types[0]
# NamedTuple type
elif hasattr(sequence_node.type, "__annotations__"):
if sequence_node.type == torch.Tensor: