mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[4/N] Apply py39 ruff and pyupgrade fixes (#143257)
```torch/fx/passes/annotate_getitem_nodes.py``` was changed to support the new type hinting annotations. Pull Request resolved: https://github.com/pytorch/pytorch/pull/143257 Approved by: https://github.com/justinchuby, https://github.com/albanD
This commit is contained in:
@ -1200,6 +1200,15 @@ class {test_classname}(torch.nn.Module):
|
||||
inp3_y = inp3.y
|
||||
return inp_0 + inp_1 + inp2_0 + inp3_x + inp3_y
|
||||
|
||||
class MyModule2(torch.nn.Module):
|
||||
def forward(self, inp: tuple[CustomType, torch.Tensor], inp2: list[CustomType], inp3: CustomNamedTuple):
|
||||
inp_0 = inp[0]
|
||||
inp_1 = inp[1]
|
||||
inp2_0 = inp2[0]
|
||||
inp3_x = inp3.x
|
||||
inp3_y = inp3.y
|
||||
return inp_0 + inp_1 + inp2_0 + inp3_x + inp3_y
|
||||
|
||||
my_module = MyModule()
|
||||
my_module_traced = torch.fx.symbolic_trace(my_module)
|
||||
|
||||
@ -1214,6 +1223,20 @@ class {test_classname}(torch.nn.Module):
|
||||
if node.target == operator.getitem:
|
||||
self.assertIsNotNone(node.type, f"Node {node} should be annotated but is not.")
|
||||
|
||||
my_module = MyModule2()
|
||||
my_module_traced = torch.fx.symbolic_trace(my_module)
|
||||
|
||||
# by default, fx transform loses type annotation of getitem nodes.
|
||||
for node in my_module_traced.graph.nodes:
|
||||
if node.target == operator.getitem:
|
||||
assert node.type is None
|
||||
|
||||
annotate_getitem_nodes(my_module_traced.graph)
|
||||
|
||||
for node in my_module_traced.graph.nodes:
|
||||
if node.target == operator.getitem:
|
||||
self.assertIsNotNone(node.type, f"Node {node} should be annotated but is not.")
|
||||
|
||||
def test_subgraph_uniquename(self):
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
|
Reference in New Issue
Block a user