support meta_tensor.to(device='cpu') under fake_mode (#146729)

Fixing this is actually a bit annoying:

(1) FakeTensorMode sees a function where all of its inputs are real tensors, so it tries to run the real compute (constant propagation) before converting the output to a FakeTensor.

(2) We don't actually want this, because the "real compute" is supposed to error under normal execution: calling `meta_tensor.to(device='cpu')` on a real meta tensor raises. Instead, we want FakeTensorMode to skip constant prop and run the normal FakeTensor implementation, which will not error.
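
A minimal repro of the fixed behavior (a sketch mirroring the new test below, assuming a build that includes this patch):

import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode

x = torch.randn(4, 4, device='meta')

# Without this fix, FakeTensorMode tries to constant-prop by running the
# real .to() on the meta tensor, which raises ("Cannot copy out of meta
# tensor"); with it, the FakeTensor implementation runs instead and
# produces a fake CPU tensor.
with FakeTensorMode(allow_non_fake_inputs=True):
    x_cpu = x.to(device='cpu')

assert isinstance(x_cpu, FakeTensor)
assert x_cpu.device == torch.device('cpu')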

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146729
Approved by: https://github.com/zou3519, https://github.com/SherlockNoMad, https://github.com/albanD
ghstack dependencies: #146642
Author: Brian Hirsh (2025-02-11 12:36:22 -08:00)
Committed by: PyTorch MergeBot
Parent: ec0b318ddb
Commit: 5cda021cac
2 changed files with 13 additions and 0 deletions


@@ -2077,6 +2077,13 @@ class FakeTensorDispatchCache(TestCase):
         self.assertTrue(isinstance(t, FakeTensor))
         self.assertEqual(t.device, torch.device('cpu'))

+    def test_meta_tensor_to_fake_cpu(self):
+        x = torch.randn(4, 4, device='meta')
+        with FakeTensorMode(allow_non_fake_inputs=True):
+            x_cpu = x.to(device='cpu')
+            self.assertTrue(isinstance(x_cpu, FakeTensor))
+            self.assertEqual(x_cpu.device, torch.device('cpu'))
+
     def test_cache_tuple_outputs(self):
         """
         Test to check that ops with tuple outputs work.


@@ -2000,6 +2000,11 @@ class FakeTensorMode(TorchDispatchMode):
         converter = self.fake_tensor_converter

         is_lift_func = func in self.lift_fns

+        device_conversion_skip_const_prop = (
+            func is torch.ops.aten._to_copy.default
+            and isinstance(args[0], torch.Tensor)
+            and args[0].device.type == "meta"
+        )
         # To constant propagate through these functions:
         # 1, If this is a lift due to a torch.tensor call,
@@ -2013,6 +2018,7 @@ class FakeTensorMode(TorchDispatchMode):
             should_allow_numbers_as_tensors(func)
             and not has_symbolic_sizes
             and not flat_arg_fake_tensors
+            and not device_conversion_skip_const_prop
         ):
             assert all(
                 t.constant is not None for t in flat_arg_fake_tensors
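
For reference, a standalone sketch of the gating the hunks above add (the helper name here is hypothetical; the real check lives inline in FakeTensorMode's dispatch path):

import torch

# Constant prop must be bypassed when the op is aten._to_copy on a real
# meta tensor: meta_tensor.to(device='cpu') dispatches to _to_copy, and
# running it for real would error, so the FakeTensor path runs instead.
def skips_const_prop_for_device_conversion(func, args) -> bool:
    return (
        func is torch.ops.aten._to_copy.default
        and isinstance(args[0], torch.Tensor)
        and args[0].device.type == "meta"
    )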