mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	support meta_tensor.to(device='cpu') under fake_mode (#146729)
Fixing this is actually a bit annoying: (1) FakeTensorMode sees a function where all of its inputs are real tensors, so it tries to run the real compute before converting the output to a FakeTensor (2) we don't actually want this, because the "real compute" is support to error normally, when you do `meta_tensor.to(device='cpu')`. Instead, we want FakeTensor to actually skip constant prop and run the normal FakeTensor implementation, which will not error Pull Request resolved: https://github.com/pytorch/pytorch/pull/146729 Approved by: https://github.com/zou3519, https://github.com/SherlockNoMad, https://github.com/albanD ghstack dependencies: #146642
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							ec0b318ddb
						
					
				
				
					commit
					5cda021cac
				
			| @ -2077,6 +2077,13 @@ class FakeTensorDispatchCache(TestCase): | ||||
|             self.assertTrue(isinstance(t, FakeTensor)) | ||||
|             self.assertEqual(t.device, torch.device('cpu')) | ||||
|  | ||||
|     def test_meta_tensor_to_fake_cpu(self): | ||||
|         x = torch.randn(4, 4, device='meta') | ||||
|         with FakeTensorMode(allow_non_fake_inputs=True): | ||||
|             x_cpu = x.to(device='cpu') | ||||
|         self.assertTrue(isinstance(x_cpu, FakeTensor)) | ||||
|         self.assertEqual(x_cpu.device, torch.device('cpu')) | ||||
|  | ||||
|     def test_cache_tuple_outputs(self): | ||||
|         """ | ||||
|         Test to check that ops with tuple outputs work. | ||||
|  | ||||
| @ -2000,6 +2000,11 @@ class FakeTensorMode(TorchDispatchMode): | ||||
|         converter = self.fake_tensor_converter | ||||
|  | ||||
|         is_lift_func = func in self.lift_fns | ||||
|         device_conversion_skip_const_prop = ( | ||||
|             func is torch.ops.aten._to_copy.default | ||||
|             and isinstance(args[0], torch.Tensor) | ||||
|             and args[0].device.type == "meta" | ||||
|         ) | ||||
|  | ||||
|         # To constant propagate through these functions: | ||||
|         # 1, If this is a lift due to a torch.tensor call, | ||||
| @ -2013,6 +2018,7 @@ class FakeTensorMode(TorchDispatchMode): | ||||
|             should_allow_numbers_as_tensors(func) | ||||
|             and not has_symbolic_sizes | ||||
|             and not flat_arg_fake_tensors | ||||
|             and not device_conversion_skip_const_prop | ||||
|         ): | ||||
|             assert all( | ||||
|                 t.constant is not None for t in flat_arg_fake_tensors | ||||
|  | ||||
		Reference in New Issue
	
	Block a user