mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Re-implement pin_memory to be device-agnostic by leveraging the Accelerator concept (#126376)
This PR re-implements pin memory aiming to get rid of the optional `device` argument and makes all related APIs to be device-agnostic. We add two new abstract APIs in [AcceleratorHooksInterface](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/detail/AcceleratorHooksInterface.h#L12) and redefine pin memory as: "Pin memory is always pinned for the current accelerator device". In detail, it uses [getAcceleratorHooksInterface](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/Context.h#L61) in pin_memory/is_pinned to get an appropriate device and invoke the corresponding overridden interfaces, instead of using BackendSelect and then dispatching to CUDA or other specific backends' implement methods. Note: For new backends who want to implement and use pin memory, just inherit AcceleratorHooksInterface and overwrite the `isPinnedPtr` and `getPinnedMemoryAllocator` methods. Additional context: To avoid BC-breaking, this PR just preserves the `device` arg of related APIs and would throw a deprecation warning if `device` arg is passed. Another PR will be submitted to update all PT callers (`Tensor.is_pinned()`, `Tensor.pin_memory()`...) not to pass this arg based on this PR. In future, `device` arg will be actually removed. Relates #124908 Relates #14560 Pull Request resolved: https://github.com/pytorch/pytorch/pull/126376 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
074b420641
commit
8963623494
@ -14,6 +14,7 @@ from torch.testing._internal.common_utils import (
|
||||
IS_LINUX,
|
||||
skipIfTorchDynamo,
|
||||
TEST_CUDA,
|
||||
TEST_MPS,
|
||||
TEST_PRIVATEUSE1,
|
||||
TEST_XPU,
|
||||
)
|
||||
@ -37,7 +38,13 @@ def remove_build_path():
|
||||
# Since we use a fake MTIA device backend to test generic Stream/Event, device backends are mutual exclusive to each other.
|
||||
# The test will be skipped if any of the following conditions are met:
|
||||
@unittest.skipIf(
|
||||
IS_ARM64 or not IS_LINUX or TEST_CUDA or TEST_XPU or TEST_PRIVATEUSE1 or TEST_ROCM,
|
||||
IS_ARM64
|
||||
or not IS_LINUX
|
||||
or TEST_CUDA
|
||||
or TEST_XPU
|
||||
or TEST_MPS
|
||||
or TEST_PRIVATEUSE1
|
||||
or TEST_ROCM,
|
||||
"Only on linux platform and mutual exclusive to other backends",
|
||||
)
|
||||
@torch.testing._internal.common_utils.markDynamoStrictTest
|
||||
|
Reference in New Issue
Block a user