mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[4/N] Apply py39 ruff and pyupgrade fixes (#143257)
```torch/fx/passes/annotate_getitem_nodes.py``` was changed to support the new type hinting annotations. Pull Request resolved: https://github.com/pytorch/pytorch/pull/143257 Approved by: https://github.com/justinchuby, https://github.com/albanD
This commit is contained in:
@ -7,7 +7,7 @@ def annotate_getitem_nodes(graph: torch.fx.Graph) -> None:
|
||||
"""
|
||||
Annotate the type of getitem nodes, inferred from the type of sequence node.
|
||||
If sequence node is not annotated with a type, do nothing.
|
||||
Currently support getitem nodes from Tuple, List, and NamedTuple sequence node.
|
||||
Currently support getitem nodes from tuple, list, and NamedTuple sequence node.
|
||||
|
||||
This is helpful since annotations on local names within function are lost during FX transforms.
|
||||
Adding back known type annotation for getitem nodes to improve jit scriptability.
|
||||
@ -35,6 +35,21 @@ def annotate_getitem_nodes(graph: torch.fx.Graph) -> None:
|
||||
elif sequence_node.type._name == "List":
|
||||
assert len(parameterized_types) == 1
|
||||
node.type = parameterized_types[0]
|
||||
# Generic Alias Type
|
||||
elif hasattr(sequence_node.type, "__origin__"):
|
||||
parameterized_types = sequence_node.type.__args__
|
||||
if sequence_node.type.__origin__ is tuple:
|
||||
if len(parameterized_types) == 2 and isinstance(
|
||||
parameterized_types[1], type(...)
|
||||
):
|
||||
node.type = parameterized_types[0]
|
||||
else:
|
||||
assert len(parameterized_types) > index_node
|
||||
node_type = parameterized_types[index_node]
|
||||
node.type = node_type
|
||||
elif sequence_node.type.__origin__ is list:
|
||||
assert len(parameterized_types) == 1
|
||||
node.type = parameterized_types[0]
|
||||
# NamedTuple type
|
||||
elif hasattr(sequence_node.type, "__annotations__"):
|
||||
if sequence_node.type == torch.Tensor:
|
||||
|
Reference in New Issue
Block a user