Fix FakeTensor printing (#99205)

I got too confused by the FakeTensor printing, so this PR fixes it to
print normally.

Before:
```
with FakeTensorMode():
    x = torch.empty(2, 2, device="cpu")
    print(x)
    # FakeTensor(FakeTensor(..., device='meta', size=(2, 2)), cpu)
```
After (Tensor printing doesn't print the default device):
```
FakeTensor(..., size=(2, 2))
```

Test Plan:
- new test
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99205
Approved by: https://github.com/eellison
This commit is contained in:
Richard Zou
2023-04-17 10:10:01 -07:00
committed by PyTorch MergeBot
parent 20a90a1f80
commit 57e1a50da3
3 changed files with 12 additions and 8 deletions

View File

@ -76,6 +76,13 @@ class FakeTensorTest(TestCase):
self.assertEqual(out.device.type, "cpu")
self.assertTrue(isinstance(out, FakeTensor))
def test_repr(self):
    """repr of a FakeTensor matches plain Tensor printing: the default
    device ("cpu") is omitted, a non-default device ("meta") is shown."""
    cases = [
        ("cpu", 'FakeTensor(..., size=(2, 2))'),
        ("meta", "FakeTensor(..., device='meta', size=(2, 2))"),
    ]
    with FakeTensorMode():
        for device, expected in cases:
            t = torch.empty(2, 2, device=device)
            self.assertEqual(repr(t), expected)
@unittest.skipIf(not RUN_CUDA, "requires cuda")
def test_zero_dim(self):
with FakeTensorMode() as mode:

View File

@ -924,12 +924,6 @@ class FakeTensor(torch.Tensor):
def from_tensor(t, fake_mode):
    """Return a fake tensor for ``t`` by delegating to ``fake_mode.from_tensor``."""
    return fake_mode.from_tensor(t)
# TODO: resolve error in default __repr__
def __repr__(self):
    """Print as ``FakeTensor(<meta repr>, <fake device>)``.

    The super() repr runs under ``in_kernel_invocation_manager`` so the
    underlying (meta) tensor can be formatted without tripping fake-mode
    dispatch; the faked device is appended alongside it.
    """
    with in_kernel_invocation_manager(self.fake_mode):
        inner_repr = super().__repr__()
        return f"FakeTensor({inner_repr}, {self.fake_device})"
@classmethod
@count
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):

View File

@ -536,12 +536,15 @@ def _str_intern(inp, *, tensor_contents=None):
prefix = "_to_functional_tensor("
tensor_str = repr(torch._from_functional_tensor(self))
else:
if self.is_meta:
# Circular import problem, so we import it here
from torch._subclasses.fake_tensor import FakeTensor
if self.is_meta or isinstance(self, FakeTensor):
suffixes.append("size=" + str(tuple(self.shape)))
if self.dtype != torch.get_default_dtype():
suffixes.append("dtype=" + str(self.dtype))
# TODO: This implies that ellipses is valid syntax for allocating
# a meta tensor, which it could be, but it isn't right now
# a meta tensor or FakeTensor, which it could be, but it isn't right now
if not custom_contents_provided:
tensor_str = "..."
else: