mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[sigmoid] fix for FX tracing unflattened modules (#115708)
Differential Revision: [D52095387](https://our.internmc.facebook.com/intern/diff/D52095387/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/115708 Approved by: https://github.com/zhxchen17
This commit is contained in:
@ -273,6 +273,21 @@ class TestFX(JitTestCase):
|
||||
t = T()
|
||||
self.checkGraphModule(t, (torch.rand(1), torch.rand(1)), {'foo': torch.rand(1)})
|
||||
|
||||
def test_varargs_concrete(self):
|
||||
class T(torch.nn.Module):
|
||||
def forward(self, *args, **kwargs):
|
||||
x = args[0] + args[1]
|
||||
return x
|
||||
|
||||
args = (torch.rand(1), torch.rand(1))
|
||||
|
||||
t = T()
|
||||
ref_outs = t(*args)
|
||||
gm = symbolic_trace(t, concrete_args=(torch.fx.PH, torch.fx.PH))
|
||||
gm.graph.lint()
|
||||
test_outs = gm(*args)
|
||||
self.assertEqual(ref_outs, test_outs)
|
||||
|
||||
def test_args_kwargs_no_self(self):
|
||||
class T(torch.nn.Module):
|
||||
def forward(*args, **kwargs): # noqa: B902
|
||||
|
Reference in New Issue
Block a user