mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
support meta_tensor.to(device='cpu') under fake_mode (#146729)
Fixing this is actually a bit annoying: (1) FakeTensorMode sees a function where all of its inputs are real tensors, so it tries to run the real compute before converting the output to a FakeTensor (2) we don't actually want this, because the "real compute" is support to error normally, when you do `meta_tensor.to(device='cpu')`. Instead, we want FakeTensor to actually skip constant prop and run the normal FakeTensor implementation, which will not error Pull Request resolved: https://github.com/pytorch/pytorch/pull/146729 Approved by: https://github.com/zou3519, https://github.com/SherlockNoMad, https://github.com/albanD ghstack dependencies: #146642
This commit is contained in:
committed by
PyTorch MergeBot
parent
ec0b318ddb
commit
5cda021cac
@ -2077,6 +2077,13 @@ class FakeTensorDispatchCache(TestCase):
|
||||
self.assertTrue(isinstance(t, FakeTensor))
|
||||
self.assertEqual(t.device, torch.device('cpu'))
|
||||
|
||||
def test_meta_tensor_to_fake_cpu(self):
|
||||
x = torch.randn(4, 4, device='meta')
|
||||
with FakeTensorMode(allow_non_fake_inputs=True):
|
||||
x_cpu = x.to(device='cpu')
|
||||
self.assertTrue(isinstance(x_cpu, FakeTensor))
|
||||
self.assertEqual(x_cpu.device, torch.device('cpu'))
|
||||
|
||||
def test_cache_tuple_outputs(self):
|
||||
"""
|
||||
Test to check that ops with tuple outputs work.
|
||||
|
Reference in New Issue
Block a user