mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
support meta_tensor.to(device='cpu') under fake_mode (#146729)
Fixing this is actually a bit annoying: (1) FakeTensorMode sees a function where all of its inputs are real tensors, so it tries to run the real compute before converting the output to a FakeTensor (2) we don't actually want this, because the "real compute" is support to error normally, when you do `meta_tensor.to(device='cpu')`. Instead, we want FakeTensor to actually skip constant prop and run the normal FakeTensor implementation, which will not error Pull Request resolved: https://github.com/pytorch/pytorch/pull/146729 Approved by: https://github.com/zou3519, https://github.com/SherlockNoMad, https://github.com/albanD ghstack dependencies: #146642
This commit is contained in:
committed by
PyTorch MergeBot
parent
ec0b318ddb
commit
5cda021cac
@ -2077,6 +2077,13 @@ class FakeTensorDispatchCache(TestCase):
|
||||
self.assertTrue(isinstance(t, FakeTensor))
|
||||
self.assertEqual(t.device, torch.device('cpu'))
|
||||
|
||||
def test_meta_tensor_to_fake_cpu(self):
|
||||
x = torch.randn(4, 4, device='meta')
|
||||
with FakeTensorMode(allow_non_fake_inputs=True):
|
||||
x_cpu = x.to(device='cpu')
|
||||
self.assertTrue(isinstance(x_cpu, FakeTensor))
|
||||
self.assertEqual(x_cpu.device, torch.device('cpu'))
|
||||
|
||||
def test_cache_tuple_outputs(self):
|
||||
"""
|
||||
Test to check that ops with tuple outputs work.
|
||||
|
@ -2000,6 +2000,11 @@ class FakeTensorMode(TorchDispatchMode):
|
||||
converter = self.fake_tensor_converter
|
||||
|
||||
is_lift_func = func in self.lift_fns
|
||||
device_conversion_skip_const_prop = (
|
||||
func is torch.ops.aten._to_copy.default
|
||||
and isinstance(args[0], torch.Tensor)
|
||||
and args[0].device.type == "meta"
|
||||
)
|
||||
|
||||
# To constant propagate through these functions:
|
||||
# 1, If this is a lift due to a torch.tensor call,
|
||||
@ -2013,6 +2018,7 @@ class FakeTensorMode(TorchDispatchMode):
|
||||
should_allow_numbers_as_tensors(func)
|
||||
and not has_symbolic_sizes
|
||||
and not flat_arg_fake_tensors
|
||||
and not device_conversion_skip_const_prop
|
||||
):
|
||||
assert all(
|
||||
t.constant is not None for t in flat_arg_fake_tensors
|
||||
|
Reference in New Issue
Block a user