[export] make UnflattenedModule not inherit from GraphModule (#115408)

UnflattenedModule doesn't really behave like a graph module; we customize `__call__` to do something completely different than what GraphModule does. So, things that test `isinstance(unflattened_module, GraphModule)` and do something with the GraphModule are often broken.

This change makes UnflattenedModule it's own thing.

Differential Revision: [D51959097](https://our.internmc.facebook.com/intern/diff/D51959097/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115408
Approved by: https://github.com/zhxchen17
This commit is contained in:
suo
2023-12-11 09:45:18 -08:00
committed by PyTorch MergeBot
parent 8c1567d021
commit c137335b5c
4 changed files with 78 additions and 50 deletions

View File

@ -1804,6 +1804,24 @@ class TestFX(JitTestCase):
self.assertEqual(interpreter.run(input), gm(input))
self.assertEqual(interpreter.run(input), m(input))
def test_interpreter_other_graph(self):
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.param = torch.nn.Parameter(torch.rand(3, 4))
self.linear = torch.nn.Linear(4, 5)
def forward(self, x):
return self.linear(x + self.param).clamp(min=0.0, max=1.0)
m = MyModule()
gm = torch.fx.symbolic_trace(m)
interpreter = Interpreter(gm, graph=gm.graph)
input = torch.randn(3, 4)
self.assertEqual(interpreter.run(input), gm(input))
self.assertEqual(interpreter.run(input), m(input))
def test_interpreter_run_node_override(self):
class MyModule(torch.nn.Module):
def __init__(self):