mirror of https://github.com/pytorch/pytorch.git
[export] make UnflattenedModule not inherit from GraphModule (#115408)
UnflattenedModule doesn't really behave like a GraphModule: we customize `__call__` to do something completely different from what GraphModule does. As a result, code that checks `isinstance(unflattened_module, GraphModule)` and then operates on the result as a GraphModule is often broken. This change makes UnflattenedModule its own class.

Differential Revision: [D51959097](https://our.internmc.facebook.com/intern/diff/D51959097/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115408

Approved by: https://github.com/zhxchen17
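A minimal sketch of the kind of caller this breakage affects, using a hypothetical helper `run_with_interpreter` (not from the PR) and relying on the `graph=` keyword that `Interpreter` accepts in the test added below; instead of assuming every traced artifact is a GraphModule, the helper falls back to any module that exposes an `fx.Graph`:

```python
import torch
from torch.fx import GraphModule, Interpreter, symbolic_trace


def run_with_interpreter(module: torch.nn.Module, *args):
    # Hypothetical helper, for illustration only.
    if isinstance(module, GraphModule):
        # Strict check: only true GraphModules take this path, so an
        # UnflattenedModule no longer qualifies after this change.
        return Interpreter(module).run(*args)
    graph = getattr(module, "graph", None)
    if graph is not None:
        # Anything that still carries an fx.Graph can be interpreted by
        # handing the graph to Interpreter explicitly, mirroring the
        # Interpreter(gm, graph=gm.graph) call in the test below.
        return Interpreter(module, graph=graph).run(*args)
    # No graph at all: fall back to a plain forward call.
    return module(*args)


m = torch.nn.Linear(4, 5)
gm = symbolic_trace(m)
x = torch.randn(3, 4)
assert torch.equal(run_with_interpreter(gm, x), gm(x))
```

The point of the sketch is the order of the checks: `isinstance(module, GraphModule)` alone is no longer a reliable way to decide whether something can be handed to `Interpreter`, so callers should either pass the graph explicitly or fall back to invoking the module directly.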
@@ -1804,6 +1804,24 @@ class TestFX(JitTestCase):
         self.assertEqual(interpreter.run(input), gm(input))
         self.assertEqual(interpreter.run(input), m(input))
 
+    def test_interpreter_other_graph(self):
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.param = torch.nn.Parameter(torch.rand(3, 4))
+                self.linear = torch.nn.Linear(4, 5)
+
+            def forward(self, x):
+                return self.linear(x + self.param).clamp(min=0.0, max=1.0)
+
+        m = MyModule()
+        gm = torch.fx.symbolic_trace(m)
+
+        interpreter = Interpreter(gm, graph=gm.graph)
+        input = torch.randn(3, 4)
+        self.assertEqual(interpreter.run(input), gm(input))
+        self.assertEqual(interpreter.run(input), m(input))
+
     def test_interpreter_run_node_override(self):
         class MyModule(torch.nn.Module):
             def __init__(self):