mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	Re-implement pin_memory to be device-agnostic by leveraging the Accelerator concept (#126376)
This PR re-implements pin memory aiming to get rid of the optional `device` argument and makes all related APIs to be device-agnostic. We add two new abstract APIs in [AcceleratorHooksInterface](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/detail/AcceleratorHooksInterface.h#L12) and redefine pin memory as: "Pin memory is always pinned for the current accelerator device". In detail, it uses [getAcceleratorHooksInterface](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/Context.h#L61) in pin_memory/is_pinned to get an appropriate device and invoke the corresponding overridden interfaces, instead of using BackendSelect and then dispatching to CUDA or other specific backends' implement methods. Note: For new backends who want to implement and use pin memory, just inherit AcceleratorHooksInterface and overwrite the `isPinnedPtr` and `getPinnedMemoryAllocator` methods. Additional context: To avoid BC-breaking, this PR just preserves the `device` arg of related APIs and would throw a deprecation warning if `device` arg is passed. Another PR will be submitted to update all PT callers (`Tensor.is_pinned()`, `Tensor.pin_memory()`...) not to pass this arg based on this PR. In future, `device` arg will be actually removed. Relates #124908 Relates #14560 Pull Request resolved: https://github.com/pytorch/pytorch/pull/126376 Approved by: https://github.com/albanD
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							38b7d89aa4
						
					
				
				
					commit
					c986aeea2d
				
			| @ -14,6 +14,7 @@ from torch.testing._internal.common_utils import ( | ||||
|     IS_LINUX, | ||||
|     skipIfTorchDynamo, | ||||
|     TEST_CUDA, | ||||
|     TEST_MPS, | ||||
|     TEST_PRIVATEUSE1, | ||||
|     TEST_XPU, | ||||
| ) | ||||
| @ -37,7 +38,13 @@ def remove_build_path(): | ||||
| # Since we use a fake MTIA device backend to test generic Stream/Event, device backends are mutual exclusive to each other. | ||||
| # The test will be skipped if any of the following conditions are met: | ||||
| @unittest.skipIf( | ||||
|     IS_ARM64 or not IS_LINUX or TEST_CUDA or TEST_XPU or TEST_PRIVATEUSE1 or TEST_ROCM, | ||||
|     IS_ARM64 | ||||
|     or not IS_LINUX | ||||
|     or TEST_CUDA | ||||
|     or TEST_XPU | ||||
|     or TEST_MPS | ||||
|     or TEST_PRIVATEUSE1 | ||||
|     or TEST_ROCM, | ||||
|     "Only on linux platform and mutual exclusive to other backends", | ||||
| ) | ||||
| @torch.testing._internal.common_utils.markDynamoStrictTest | ||||
|  | ||||
		Reference in New Issue
	
	Block a user