[fx] Make NormalizeArgs preserve node type (#85637)

Summary: Make `NormalizeArgs` preserve node types when transforming the graph. This bug is preventing me from scripting a graph that goes through the fx2trt `acc_tracer`.

Test Plan: New unit test

Reviewed By: ipiszy

Differential Revision: D39753021

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85637
Approved by: https://github.com/Chillee
This commit is contained in:
Mike Iovine
2022-09-26 21:30:16 +00:00
committed by PyTorch MergeBot
parent 5547c6aa4e
commit f325c29b05
2 changed files with 14 additions and 0 deletions

View File

@ -1014,6 +1014,19 @@ class {test_classname}(torch.nn.Module):
else:
self.fail("Didn't find call_function torch.add")
def test_normalize_args_perserve_type(self):
class MyModule(torch.nn.Module):
def forward(self, a: List[torch.Tensor]):
return torch.add(a[0], a[1])
m = MyModule()
traced = symbolic_trace(m)
traced = NormalizeArgs(traced).transform()
for node in traced.graph.nodes:
if node.op == "placeholder":
self.assertEqual(node.type, List[torch.Tensor])
@skipIfNoTorchVision
def test_annotate_returns_with_schema(self):
m = resnet18()

View File

@ -59,6 +59,7 @@ class NormalizeArgs(Transformer):
if n.op != "output":
self.node_map[out] = n
out.node.meta = n.meta
out.node.type = n.type
return out
def call_function(