support meta_tensor.to(device='cpu') under fake_mode (#146729)

Fixing this is actually a bit annoying:

(1) FakeTensorMode sees a function where all of its inputs are real tensors, so it tries to run the real compute before converting the output to a FakeTensor

(2) we don't actually want this, because the "real compute" is supposed to error normally when you do `meta_tensor.to(device='cpu')`. Instead, we want FakeTensorMode to skip constant prop and run the normal FakeTensor implementation, which will not error

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146729
Approved by: https://github.com/zou3519, https://github.com/SherlockNoMad, https://github.com/albanD
ghstack dependencies: #146642
Brian Hirsh
2025-02-11 12:36:22 -08:00
committed by PyTorch MergeBot
parent ec0b318ddb
commit 5cda021cac
2 changed files with 13 additions and 0 deletions


@@ -2077,6 +2077,13 @@ class FakeTensorDispatchCache(TestCase):
        self.assertTrue(isinstance(t, FakeTensor))
        self.assertEqual(t.device, torch.device('cpu'))

    def test_meta_tensor_to_fake_cpu(self):
        x = torch.randn(4, 4, device='meta')
        with FakeTensorMode(allow_non_fake_inputs=True):
            x_cpu = x.to(device='cpu')
        self.assertTrue(isinstance(x_cpu, FakeTensor))
        self.assertEqual(x_cpu.device, torch.device('cpu'))

    def test_cache_tuple_outputs(self):
"""
Test to check that ops with tuple outputs work.