mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[fx] Make NormalizeArgs preserve node type (#85637)
Summary: Make `NormalizeArgs` preserve node types when transforming the graph. This bug is preventing me from scripting a graph that goes through the fx2trt `acc_tracer`. Test Plan: New unit test Reviewed By: ipiszy Differential Revision: D39753021 Pull Request resolved: https://github.com/pytorch/pytorch/pull/85637 Approved by: https://github.com/Chillee
This commit is contained in:
committed by
PyTorch MergeBot
parent
5547c6aa4e
commit
f325c29b05
@ -1014,6 +1014,19 @@ class {test_classname}(torch.nn.Module):
|
||||
else:
|
||||
self.fail("Didn't find call_function torch.add")
|
||||
|
||||
def test_normalize_args_perserve_type(self):
|
||||
class MyModule(torch.nn.Module):
|
||||
def forward(self, a: List[torch.Tensor]):
|
||||
return torch.add(a[0], a[1])
|
||||
|
||||
m = MyModule()
|
||||
traced = symbolic_trace(m)
|
||||
traced = NormalizeArgs(traced).transform()
|
||||
|
||||
for node in traced.graph.nodes:
|
||||
if node.op == "placeholder":
|
||||
self.assertEqual(node.type, List[torch.Tensor])
|
||||
|
||||
@skipIfNoTorchVision
|
||||
def test_annotate_returns_with_schema(self):
|
||||
m = resnet18()
|
||||
|
@ -59,6 +59,7 @@ class NormalizeArgs(Transformer):
|
||||
if n.op != "output":
|
||||
self.node_map[out] = n
|
||||
out.node.meta = n.meta
|
||||
out.node.type = n.type
|
||||
return out
|
||||
|
||||
def call_function(
|
||||
|
Reference in New Issue
Block a user