[sigmoid] fix for FX tracing unflattened modules (#115708)

Differential Revision: [D52095387](https://our.internmc.facebook.com/intern/diff/D52095387/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115708
Approved by: https://github.com/zhxchen17
Author: suo
Date: 2023-12-12 22:44:05 -08:00
Committed by: PyTorch MergeBot
Parent: 75d3bbaaa2
Commit: 926236305f
2 changed files with 51 additions and 11 deletions
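
The added test (in the diff below) traces a module whose forward takes *args/**kwargs by passing torch.fx.PH placeholders as concrete_args. The following is a minimal sketch of that pattern, separate from the commit itself: the VarargsAdd module name is invented for illustration, while the symbolic_trace and torch.fx.PH usage mirrors the added test.

import torch
from torch.fx import symbolic_trace


# Hypothetical example module: forward accepts varargs, like the test's T.
class VarargsAdd(torch.nn.Module):
    def forward(self, *args, **kwargs):
        return args[0] + args[1]


mod = VarargsAdd()
inputs = (torch.rand(1), torch.rand(1))

# torch.fx.PH marks each positional vararg as a placeholder in the traced graph.
gm = symbolic_trace(mod, concrete_args=(torch.fx.PH, torch.fx.PH))
gm.graph.lint()

# The traced GraphModule should match eager execution.
assert torch.equal(gm(*inputs), mod(*inputs))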


@@ -273,6 +273,21 @@ class TestFX(JitTestCase):
        t = T()
        self.checkGraphModule(t, (torch.rand(1), torch.rand(1)), {'foo': torch.rand(1)})

    def test_varargs_concrete(self):
        class T(torch.nn.Module):
            def forward(self, *args, **kwargs):
                x = args[0] + args[1]
                return x

        args = (torch.rand(1), torch.rand(1))
        t = T()
        ref_outs = t(*args)
        gm = symbolic_trace(t, concrete_args=(torch.fx.PH, torch.fx.PH))
        gm.graph.lint()
        test_outs = gm(*args)
        self.assertEqual(ref_outs, test_outs)

    def test_args_kwargs_no_self(self):
        class T(torch.nn.Module):
            def forward(*args, **kwargs):  # noqa: B902