mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	Allow storage() to work on python tensor subclasses, but error on future data accesses (#107417)
This was discussed in feedback from the original version of my "reorder proxy/fake" PR. This PR allows calls to `tensor.untyped_storage()` to **always** return a python storage object to the user. Previously, we would error loudly if we detected that the storage had a null dataptr. Instead, I updated the python bindings for the python storage methods that I saw involve data access, to throw an error later, only if you try to access those methods (e.g. `storage.data_ptr()` will now raise an error if the data ptr is null). Pull Request resolved: https://github.com/pytorch/pytorch/pull/107417 Approved by: https://github.com/albanD, https://github.com/ezyang, https://github.com/zou3519
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							df42f15e28
						
					
				
				
					commit
					2c8759df9d
				
			| @ -254,6 +254,27 @@ class TestSubclass(TestCase): | ||||
|         with self.assertRaisesRegex(RuntimeError, r"requires that detach\(\) returns an instance of the same type"): | ||||
|             param = nn.Parameter(NonRewrappingTensor(torch.randn(3))) | ||||
|  | ||||
|     def test_tensor_subclass_storage_data_accesses_throw(self): | ||||
|         from torch.testing._internal.logging_tensor import LoggingTensor | ||||
|         x = torch.ones(2) | ||||
|         x_log = LoggingTensor(x) | ||||
|         # Accessing storage on a tensor subclass is valid | ||||
|         storage = x_log.untyped_storage() | ||||
|         # This includes accessing metadata on the storage | ||||
|         sz = storage.size() | ||||
|         # But storage methods that access data will throw | ||||
|         with self.assertRaisesRegex(RuntimeError, "on an invalid python storage"): | ||||
|             storage.data_ptr() | ||||
|         with self.assertRaisesRegex(RuntimeError, "on an invalid python storage"): | ||||
|             storage.resize_(0) | ||||
|         with self.assertRaisesRegex(RuntimeError, "on an invalid python storage"): | ||||
|             storage.copy_(storage) | ||||
|         with self.assertRaisesRegex(RuntimeError, "on an invalid python storage"): | ||||
|             storage.fill_(0) | ||||
|         with self.assertRaisesRegex(RuntimeError, "on an invalid python storage"): | ||||
|             storage._write_file("file") | ||||
|  | ||||
|  | ||||
| instantiate_parametrized_tests(TestSubclass) | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|  | ||||
		Reference in New Issue
	
	Block a user