[4/N] Apply py39 ruff and pyupgrade fixes (#143257)

```torch/fx/passes/annotate_getitem_nodes.py``` was changed to support the new type hinting annotations.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143257
Approved by: https://github.com/justinchuby, https://github.com/albanD
This commit is contained in:
cyy
2025-01-04 10:47:51 +00:00
committed by PyTorch MergeBot
parent a881954b0c
commit df458be4e5
55 changed files with 247 additions and 227 deletions

View File

@ -1200,6 +1200,15 @@ class {test_classname}(torch.nn.Module):
inp3_y = inp3.y
return inp_0 + inp_1 + inp2_0 + inp3_x + inp3_y
class MyModule2(torch.nn.Module):
def forward(self, inp: tuple[CustomType, torch.Tensor], inp2: list[CustomType], inp3: CustomNamedTuple):
inp_0 = inp[0]
inp_1 = inp[1]
inp2_0 = inp2[0]
inp3_x = inp3.x
inp3_y = inp3.y
return inp_0 + inp_1 + inp2_0 + inp3_x + inp3_y
my_module = MyModule()
my_module_traced = torch.fx.symbolic_trace(my_module)
@ -1214,6 +1223,20 @@ class {test_classname}(torch.nn.Module):
if node.target == operator.getitem:
self.assertIsNotNone(node.type, f"Node {node} should be annotated but is not.")
my_module = MyModule2()
my_module_traced = torch.fx.symbolic_trace(my_module)
# by default, fx transform loses type annotation of getitem nodes.
for node in my_module_traced.graph.nodes:
if node.target == operator.getitem:
assert node.type is None
annotate_getitem_nodes(my_module_traced.graph)
for node in my_module_traced.graph.nodes:
if node.target == operator.getitem:
self.assertIsNotNone(node.type, f"Node {node} should be annotated but is not.")
def test_subgraph_uniquename(self):
class MyModule(torch.nn.Module):
def __init__(self) -> None: