mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix FakeTensor printing (#99205)
I got too confused by the FakeTensor printing, so this PR fixes it to print normally. Before: ``` with FakeTensorMode(): x = torch.empty(2, 2, device="cpu") print(x) # FakeTensor(FakeTensor(..., device='meta', shape=(2, 2)), cpu) ``` After (Tensor printing doesn't print the default device): ``` FakeTensor(..., shape=(2, 2)) ``` Test Plan: - new test Pull Request resolved: https://github.com/pytorch/pytorch/pull/99205 Approved by: https://github.com/eellison
This commit is contained in:
committed by
PyTorch MergeBot
parent
20a90a1f80
commit
57e1a50da3
@ -76,6 +76,13 @@ class FakeTensorTest(TestCase):
|
||||
self.assertEqual(out.device.type, "cpu")
|
||||
self.assertTrue(isinstance(out, FakeTensor))
|
||||
|
||||
def test_repr(self):
|
||||
with FakeTensorMode():
|
||||
x = torch.empty(2, 2, device="cpu")
|
||||
self.assertEqual(repr(x), 'FakeTensor(..., size=(2, 2))')
|
||||
x = torch.empty(2, 2, device="meta")
|
||||
self.assertEqual(repr(x), "FakeTensor(..., device='meta', size=(2, 2))")
|
||||
|
||||
@unittest.skipIf(not RUN_CUDA, "requires cuda")
|
||||
def test_zero_dim(self):
|
||||
with FakeTensorMode() as mode:
|
||||
|
@ -924,12 +924,6 @@ class FakeTensor(torch.Tensor):
|
||||
def from_tensor(t, fake_mode):
|
||||
return fake_mode.from_tensor(t)
|
||||
|
||||
# TODO: resolve error in default __repr__
|
||||
def __repr__(self):
|
||||
with in_kernel_invocation_manager(self.fake_mode):
|
||||
self_repr = super().__repr__()
|
||||
return f"FakeTensor({self_repr}, {self.fake_device})"
|
||||
|
||||
@classmethod
|
||||
@count
|
||||
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
||||
|
@ -536,12 +536,15 @@ def _str_intern(inp, *, tensor_contents=None):
|
||||
prefix = "_to_functional_tensor("
|
||||
tensor_str = repr(torch._from_functional_tensor(self))
|
||||
else:
|
||||
if self.is_meta:
|
||||
# Circular import problem, so we import it here
|
||||
from torch._subclasses.fake_tensor import FakeTensor
|
||||
|
||||
if self.is_meta or isinstance(self, FakeTensor):
|
||||
suffixes.append("size=" + str(tuple(self.shape)))
|
||||
if self.dtype != torch.get_default_dtype():
|
||||
suffixes.append("dtype=" + str(self.dtype))
|
||||
# TODO: This implies that ellipses is valid syntax for allocating
|
||||
# a meta tensor, which it could be, but it isn't right now
|
||||
# a meta tensor or FakeTensor, which it could be, but it isn't right now
|
||||
if not custom_contents_provided:
|
||||
tensor_str = "..."
|
||||
else:
|
||||
|
Reference in New Issue
Block a user