fix fake tensor tolist implementation (#135131)

Summary:
When exporting for training with `tolist`, we do not hit `FunctionalTensor.tolist` since we do not functionalize. Unfortunately, this means we hit `FakeTensor.tolist`, which creates unbacked symints that are not backed by proxies.

Rather than trying to patch up this low-level implementation, we replace it with essentially what `FunctionalTensor.tolist` does, which is higher-level: we desugar to `item()` calls and let `item()` take care of unbacked symints.

Test Plan:
Some expected failures are gone now.
Also found a test for `tolist` that was written when `FunctionalTensor.tolist` was implemented but was not really doing much; repurposed it to exercise more modes.

Differential Revision: D62197742

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135131
Approved by: https://github.com/ezyang
This commit is contained in:
Avik Chaudhuri
2024-09-05 23:20:31 +00:00
committed by PyTorch MergeBot
parent 65e1c34061
commit 43f4947d44
5 changed files with 33 additions and 39 deletions

View File

@ -2239,6 +2239,7 @@ class TestNestedTensor(torch._dynamo.test_case.TestCase, NestedTensorTestCase):
for ref_v, res_v in zip(values_copy, values):
self.assertEqual(ref_v.grad, res_v.grad)
@torch._dynamo.config.patch({"capture_scalar_outputs": True})
def test_unbind(self):
# NB: If we have shape e.g. (3, j0, 3), duck sizing will give us (s0, s1, s0).
# This causes a recompile later on when it realizes the batch and last dim

View File

@ -2286,8 +2286,6 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
M = M_v3
export(N(), (t,), strict=strict)
@testing.expectedFailureTrainingIRToRunDecomp
@testing.expectedFailureTrainingIRToRunDecompNonStrict # unbacked symint not tracked?
@testing.expectedFailureSerDer # T195866111
def test_suggested_fixes_for_data_dependent_errors_puzzlers(self):
# suggested fixes for data-dependent errors only work in non-strict mode
@ -2418,6 +2416,14 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
strict=strict,
)
def test_tolist(self):
class M(torch.nn.Module):
def forward(self, x):
return x.tolist()
ep = export(M(), (torch.ones(3, dtype=torch.int),))
self.assertEqual(ep.module()(torch.tensor([1, 2, 3])), [1, 2, 3])
def test_if_functional(self):
class Module(torch.nn.Module):
def forward(self, x):
@ -7920,13 +7926,6 @@ class TestExportCustomClass(TorchTestCase):
arg = node.args[0]
self.assertTrue(arg.op == "placeholder")
def test_tolist_nonstrict_output(self):
class M(torch.nn.Module):
def forward(self, x):
x.tolist()
ep = torch.export.export(M(), (torch.ones(3),), strict=False)
def test_preserve_non_cia_op(self):
class M(torch.nn.Module):
def forward(self, x):

View File

@ -1025,9 +1025,14 @@ class TestNestedTensorDeviceType(NestedTensorTestCase):
)
emb = torch.nn.Embedding(100, 8, device=device)
y = emb(x)
ys = y.unbind()
for i, inp in enumerate(inputs):
self.assertEqual(emb(inp), ys[i])
@torch._dynamo.disable
def check(inputs, y):
ys = y.unbind()
for i, inp in enumerate(inputs):
self.assertEqual(emb(inp), ys[i])
check(inputs, y)
@skipMeta
@torch.inference_mode()
@ -7120,9 +7125,13 @@ class TestNestedTensorSubclass(NestedTensorTestCase):
a, b, c = nt.unbind()
b.sum().backward()
expected_grad = torch.zeros_like(nt)
expected_grad.unbind()[1].add_(1.0)
torch._dynamo.disable(self.assertEqual)(nt.grad, expected_grad)
@torch._dynamo.disable
def check(nt):
expected_grad = torch.zeros_like(nt)
expected_grad.unbind()[1].add_(1.0)
self.assertEqual(nt.grad, expected_grad)
check(nt)
FORWARD_FAILURES = {

View File

@ -14,6 +14,7 @@ import weakref
from collections import defaultdict
from dataclasses import dataclass
from typing import (
Any,
Callable,
cast,
Dict,
@ -893,28 +894,14 @@ class FakeTensor(Tensor):
)
return self.nested_int_memo * coeff
# We must handle tolist in a special way for FakeTensors here in the case
# where tolist is called from torch dispatch for tensor subclasses.
# Ordinarily, if a program calls .tolist compiling still works because there is
# special handling in dynamo, but for tensor subclasses if .tolist is called
# inside torch dispatch, the .tolist call may be directly on a FakeTensor.
# This would result in an error since wrapper subclasses don't have storage.
# To avoid this, we handle the FakeTensor case by (1) specializing on the size
# of the tensor to create the output Python list, and (2) creating unbacked
# symints for each element of the list.
def tolist(self) -> List[SymInt]:
assert self.dim() == 1, "NYI for higher dims"
shape_env = self.fake_mode.shape_env
assert shape_env is not None
out = []
# Specialize on the length of the list
for _ in range(self.shape[0]):
s = shape_env.create_unbacked_symint()
# max value?
torch._check_is_size(s)
torch._check(s >= 2)
out.append(s)
return out
# Similar to FunctionalTensor.tolist
def tolist(self) -> Any:
if self.dim() == 0:
return self.item()
elif self.dim() == 1:
return [elem.item() for elem in self]
else:
return [elem.tolist() for elem in self]
_MetadataIntLike = Union[IntLikeType, "_PySymInputStub", "_SymIntOutputStub"]

View File

@ -23797,8 +23797,6 @@ python_ref_db = [
"_refs.tensor_split",
torch_opinfo_name="tensor_split",
skips=(
# TensorMeta doesn't support tolist
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta'),
# RuntimeError: no _refs support for torch.Tensor.tolist
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'),
),