mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[PT2] Add tolist() to FunctionalTensor for torch.export (#121242)
Adding tolist() to FunctionalTensor for torch.exporting TorchRec data types Pull Request resolved: https://github.com/pytorch/pytorch/pull/121242 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
05c256849b
commit
c66d68ba51
@ -3846,6 +3846,12 @@ class TestExportCustomClass(TorchTestCase):
|
||||
arg = node.args[0]
|
||||
self.assertTrue(arg.op == "placeholder")
|
||||
|
||||
def test_tolist_nonstrict_output(self):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
x.tolist()
|
||||
|
||||
ep = torch.export.export(M(), (torch.ones(3),), strict=False)
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
||||
|
@ -213,6 +213,14 @@ class FunctionalTensor(torch.Tensor):
|
||||
def mark_mutation_hidden_from_autograd(self) -> None:
|
||||
torch._functionalize_mark_mutation_hidden_from_autograd(self.elem)
|
||||
|
||||
def tolist(self) -> Any:
|
||||
if self.elem.dim() == 0:
|
||||
return self.elem.item()
|
||||
elif self.elem.dim() == 1:
|
||||
return [elem.item() for elem in self.elem]
|
||||
else:
|
||||
return [elem.tolist() for elem in self.elem]
|
||||
|
||||
|
||||
class FunctionalTensorMode(TorchDispatchMode):
|
||||
def __init__(self, pre_dispatch=False, export=False, _allow_token_discovery=False):
|
||||
|
Reference in New Issue
Block a user