mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fixes #160547

### Summary:

Bug:

```
def test_namedtuple(self):
    from collections import namedtuple

    Point = namedtuple('Point', 'x y')

    class M(torch.nn.Module):
        def forward(self, x, y):
            return x + y

    inp = Point(torch.ones(3), torch.ones(3))
    print(M()(*inp))

    # errors: non-strict export fails on the namedtuple input
    ep = torch.export.export(M(), inp, strict=False)
    print(ep)

    # succeeds: strict export accepts the same input
    ep = torch.export.export(M(), inp, strict=True)
    print(ep)

    # workaround: convert the namedtuple into kwargs
    inp_kwargs = {field: getattr(inp, field) for field in inp._fields}
    ep = torch.export.export(M(), (), inp_kwargs)
    print(ep)
```

Fix: a namedtuple is a subclass of tuple, but export did not expect a namedtuple as its `args`. This change handles the namedtuple case.

I have added a 🧪 test case for this as well.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162959
Approved by: https://github.com/angelayi

Co-authored-by: Angela Yi <angelayi@meta.com>
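The core of the bug is that a namedtuple passes `isinstance(obj, tuple)` yet is not a plain `tuple`. The minimal sketch below, in plain Python with no torch dependency, shows that distinction, one way to normalize a namedtuple into a plain tuple, and the kwargs workaround from the description. It assumes the failure came from code that expects an exact `tuple`; the helpers `is_namedtuple`, `normalize_args`, and `namedtuple_to_kwargs` are hypothetical and are not the actual `torch.export` implementation.

```python
from collections import namedtuple

# Hypothetical helpers for illustration only; not the torch.export source.
Point = namedtuple("Point", "x y")


def is_namedtuple(obj):
    # A namedtuple is a tuple subclass whose class also defines `_fields`.
    return isinstance(obj, tuple) and hasattr(type(obj), "_fields")


def normalize_args(args):
    # Convert a namedtuple of positional inputs into a plain tuple so that
    # code expecting `type(args) is tuple` accepts it; other values pass through.
    if is_namedtuple(args):
        return tuple(args)
    return args


def namedtuple_to_kwargs(nt):
    # The workaround from the description: pass each field as a keyword argument.
    return {field: getattr(nt, field) for field in nt._fields}


p = Point(1, 2)
print(isinstance(p, tuple))              # True: a namedtuple is a tuple subclass
print(type(p) is tuple)                  # False: an exact-type check rejects it
print(type(normalize_args(p)) is tuple)  # True after normalization
print(namedtuple_to_kwargs(p))           # {'x': 1, 'y': 2}
```

Normalizing to a plain tuple keeps the positional calling convention intact, whereas the kwargs workaround changes how the inputs are bound to `forward`.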