[PT2] Add tolist() to FunctionalTensor for torch.export (#121242)

Adding tolist() to FunctionalTensor for torch.exporting TorchRec data types
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121242
Approved by: https://github.com/ezyang
This commit is contained in:
PaulZhang12
2024-03-06 18:10:40 +00:00
committed by PyTorch MergeBot
parent 05c256849b
commit c66d68ba51
2 changed files with 14 additions and 0 deletions

View File

@ -3846,6 +3846,12 @@ class TestExportCustomClass(TorchTestCase):
arg = node.args[0]
self.assertTrue(arg.op == "placeholder")
def test_tolist_nonstrict_output(self):
class M(torch.nn.Module):
def forward(self, x):
x.tolist()
ep = torch.export.export(M(), (torch.ones(3),), strict=False)
if __name__ == '__main__':
run_tests()

View File

@ -213,6 +213,14 @@ class FunctionalTensor(torch.Tensor):
def mark_mutation_hidden_from_autograd(self) -> None:
torch._functionalize_mark_mutation_hidden_from_autograd(self.elem)
def tolist(self) -> Any:
if self.elem.dim() == 0:
return self.elem.item()
elif self.elem.dim() == 1:
return [elem.item() for elem in self.elem]
else:
return [elem.tolist() for elem in self.elem]
class FunctionalTensorMode(TorchDispatchMode):
def __init__(self, pre_dispatch=False, export=False, _allow_token_discovery=False):