Re-implement pin_memory to be device-agnostic by leveraging the Accelerator concept (#126376)

This PR re-implements pin memory to remove the optional `device` argument and make all related APIs device-agnostic. We add two new abstract APIs to [AcceleratorHooksInterface](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/detail/AcceleratorHooksInterface.h#L12) and redefine pin memory as: "memory is always pinned for the current accelerator device". Concretely, `pin_memory`/`is_pinned` use [getAcceleratorHooksInterface](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/Context.h#L61) to obtain the appropriate device and invoke the corresponding overridden interfaces, instead of going through BackendSelect and then dispatching to CUDA or another backend's specific implementation.

Note: A new backend that wants to support pin memory only needs to inherit from AcceleratorHooksInterface and override the `isPinnedPtr` and `getPinnedMemoryAllocator` methods.
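As a rough illustration of that hook pattern, here is a simplified, self-contained Python sketch. The real interface is the C++ class in `aten/src/ATen/detail/AcceleratorHooksInterface.h`; the names below (`MyBackendHooks`, the snake_case method names, the allocator closure) are illustrative stand-ins, not PyTorch's actual signatures.

```python
from abc import ABC, abstractmethod

# Simplified stand-in for the C++ AcceleratorHooksInterface: each backend
# answers "is this host pointer pinned?" and supplies a pinning allocator.
class AcceleratorHooksInterface(ABC):
    @abstractmethod
    def is_pinned_ptr(self, ptr) -> bool: ...

    @abstractmethod
    def get_pinned_memory_allocator(self): ...

# Hypothetical out-of-tree backend overriding the two hooks.
class MyBackendHooks(AcceleratorHooksInterface):
    def __init__(self):
        self._pinned = set()

    def is_pinned_ptr(self, ptr) -> bool:
        return ptr in self._pinned

    def get_pinned_memory_allocator(self):
        def allocate(nbytes):
            ptr = object()  # stand-in for a pinned host buffer
            self._pinned.add(ptr)
            return ptr
        return allocate

hooks = MyBackendHooks()
p = hooks.get_pinned_memory_allocator()(16)
print(hooks.is_pinned_ptr(p))  # True
```

The point of the design is that `pin_memory`/`is_pinned` only talk to the base interface, so the dispatcher never needs a backend-specific code path.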

Additional context: To avoid a BC-breaking change, this PR preserves the `device` arg of the related APIs and emits a deprecation warning if it is passed. A follow-up PR, based on this one, will update all PyTorch callers (`Tensor.is_pinned()`, `Tensor.pin_memory()`, ...) to stop passing this arg; the `device` arg will then be removed entirely.
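The BC-compat behavior described above can be sketched as follows. This mirrors the deprecation pattern, not PyTorch's actual code: the function name, signature, and warning text are stand-ins.

```python
import warnings

# Hedged sketch: keep the deprecated `device` argument so old call sites
# still run, but warn that it no longer has any effect.
def pin_memory(data, device=None):
    if device is not None:
        warnings.warn(
            "The 'device' argument is deprecated; memory is always pinned "
            "for the current accelerator device.",
            DeprecationWarning,
        )
    return data  # the real implementation copies into pinned host memory

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    pin_memory([1, 2, 3], device="cuda")  # old-style call still works
print(len(caught), caught[0].category.__name__)  # 1 DeprecationWarning
```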

Relates #124908
Relates #14560
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126376
Approved by: https://github.com/albanD
wizzniu
2024-07-23 01:44:15 +00:00
committed by PyTorch MergeBot
parent 074b420641
commit 8963623494
25 changed files with 210 additions and 203 deletions


@@ -14,6 +14,7 @@ from torch.testing._internal.common_utils import (
     IS_LINUX,
     skipIfTorchDynamo,
     TEST_CUDA,
+    TEST_MPS,
     TEST_PRIVATEUSE1,
     TEST_XPU,
 )
@@ -37,7 +38,13 @@ def remove_build_path():
 # Since we use a fake MTIA device backend to test generic Stream/Event, device backends are mutual exclusive to each other.
 # The test will be skipped if any of the following conditions are met:
 @unittest.skipIf(
-    IS_ARM64 or not IS_LINUX or TEST_CUDA or TEST_XPU or TEST_PRIVATEUSE1 or TEST_ROCM,
+    IS_ARM64
+    or not IS_LINUX
+    or TEST_CUDA
+    or TEST_XPU
+    or TEST_MPS
+    or TEST_PRIVATEUSE1
+    or TEST_ROCM,
     "Only on linux platform and mutual exclusive to other backends",
 )
 @torch.testing._internal.common_utils.markDynamoStrictTest