mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
fix fake tensor tolist implementation (#135131)
Summary: When exporting for training with `tolist`, we do not hit `FunctionalTensor.tolist` since we do not functionalize. Unfortunately, this means we hit `FakeTensor.tolist`, which creates unbacked symints that are not backed by proxies. Rather than trying to patch up this low-level implementation, we replace it with essentially what `FunctionalTensor.tolist` does, which is higher-level: we essentially desugar to `item()` calls and let it take care of unbacked symints. Test Plan: Some expected failures are gone now. Also found a test for `tolist` that was written when `FunctionalTensor.tolist` was implemented but not really doing much; repurposed it now to exercise more modes. Differential Revision: D62197742 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135131 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
65e1c34061
commit
43f4947d44
@ -2239,6 +2239,7 @@ class TestNestedTensor(torch._dynamo.test_case.TestCase, NestedTensorTestCase):
|
||||
for ref_v, res_v in zip(values_copy, values):
|
||||
self.assertEqual(ref_v.grad, res_v.grad)
|
||||
|
||||
@torch._dynamo.config.patch({"capture_scalar_outputs": True})
|
||||
def test_unbind(self):
|
||||
# NB: If we have shape e.g. (3, j0, 3), duck sizing will give us (s0, s1, s0).
|
||||
# This causes a recompile later on when it realizes the batch and last dim
|
||||
|
@ -2286,8 +2286,6 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
||||
M = M_v3
|
||||
export(N(), (t,), strict=strict)
|
||||
|
||||
@testing.expectedFailureTrainingIRToRunDecomp
|
||||
@testing.expectedFailureTrainingIRToRunDecompNonStrict # unbacked symint not tracked?
|
||||
@testing.expectedFailureSerDer # T195866111
|
||||
def test_suggested_fixes_for_data_dependent_errors_puzzlers(self):
|
||||
# suggested fixes for data-dependent errors only work in non-strict mode
|
||||
@ -2418,6 +2416,14 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
||||
strict=strict,
|
||||
)
|
||||
|
||||
def test_tolist(self):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return x.tolist()
|
||||
|
||||
ep = export(M(), (torch.ones(3, dtype=torch.int),))
|
||||
self.assertEqual(ep.module()(torch.tensor([1, 2, 3])), [1, 2, 3])
|
||||
|
||||
def test_if_functional(self):
|
||||
class Module(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
@ -7920,13 +7926,6 @@ class TestExportCustomClass(TorchTestCase):
|
||||
arg = node.args[0]
|
||||
self.assertTrue(arg.op == "placeholder")
|
||||
|
||||
def test_tolist_nonstrict_output(self):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
x.tolist()
|
||||
|
||||
ep = torch.export.export(M(), (torch.ones(3),), strict=False)
|
||||
|
||||
def test_preserve_non_cia_op(self):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
|
@ -1025,10 +1025,15 @@ class TestNestedTensorDeviceType(NestedTensorTestCase):
|
||||
)
|
||||
emb = torch.nn.Embedding(100, 8, device=device)
|
||||
y = emb(x)
|
||||
|
||||
@torch._dynamo.disable
|
||||
def check(inputs, y):
|
||||
ys = y.unbind()
|
||||
for i, inp in enumerate(inputs):
|
||||
self.assertEqual(emb(inp), ys[i])
|
||||
|
||||
check(inputs, y)
|
||||
|
||||
@skipMeta
|
||||
@torch.inference_mode()
|
||||
@dtypes(*floating_types_and_half())
|
||||
@ -7120,9 +7125,13 @@ class TestNestedTensorSubclass(NestedTensorTestCase):
|
||||
a, b, c = nt.unbind()
|
||||
b.sum().backward()
|
||||
|
||||
@torch._dynamo.disable
|
||||
def check(nt):
|
||||
expected_grad = torch.zeros_like(nt)
|
||||
expected_grad.unbind()[1].add_(1.0)
|
||||
torch._dynamo.disable(self.assertEqual)(nt.grad, expected_grad)
|
||||
self.assertEqual(nt.grad, expected_grad)
|
||||
|
||||
check(nt)
|
||||
|
||||
|
||||
FORWARD_FAILURES = {
|
||||
|
@ -14,6 +14,7 @@ import weakref
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
cast,
|
||||
Dict,
|
||||
@ -893,28 +894,14 @@ class FakeTensor(Tensor):
|
||||
)
|
||||
return self.nested_int_memo * coeff
|
||||
|
||||
# We must handle tolist in a special way for FakeTensors here in the case
|
||||
# where tolist is called from torch dispatch for tensor subclasses.
|
||||
# Ordinarily, if a program calls .tolist compiling still works because there is
|
||||
# special handling in dynamo, but for tensor subclasses if .tolist is called
|
||||
# inside torch dispatch, the .tolist call may be directly on a FakeTensor.
|
||||
# This would result in an error since wrapper subclasses don't have storage.
|
||||
# To avoid this, we handle the FakeTensor case by (1) specializing on the size
|
||||
# of the tensor to create the output Python list, and (2) creating unbacked
|
||||
# symints for each element of the list.
|
||||
def tolist(self) -> List[SymInt]:
|
||||
assert self.dim() == 1, "NYI for higher dims"
|
||||
shape_env = self.fake_mode.shape_env
|
||||
assert shape_env is not None
|
||||
out = []
|
||||
# Specialize on the length of the list
|
||||
for _ in range(self.shape[0]):
|
||||
s = shape_env.create_unbacked_symint()
|
||||
# max value?
|
||||
torch._check_is_size(s)
|
||||
torch._check(s >= 2)
|
||||
out.append(s)
|
||||
return out
|
||||
# Similar to FunctionalTensor.tolist
|
||||
def tolist(self) -> Any:
|
||||
if self.dim() == 0:
|
||||
return self.item()
|
||||
elif self.dim() == 1:
|
||||
return [elem.item() for elem in self]
|
||||
else:
|
||||
return [elem.tolist() for elem in self]
|
||||
|
||||
|
||||
_MetadataIntLike = Union[IntLikeType, "_PySymInputStub", "_SymIntOutputStub"]
|
||||
|
@ -23797,8 +23797,6 @@ python_ref_db = [
|
||||
"_refs.tensor_split",
|
||||
torch_opinfo_name="tensor_split",
|
||||
skips=(
|
||||
# TensorMeta doesn't support tolist
|
||||
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta'),
|
||||
# RuntimeError: no _refs support for torch.Tensor.tolist
|
||||
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'),
|
||||
),
|
||||
|
Reference in New Issue
Block a user