Refactor make device agnostic in accelerator hooks (#137558)

Make `AcceleratorHooksInterface` consistent for multiple accelerators
- Add `getDeviceFromPtr` method declaration in `AcceleratorHooksInterface`
- Fix clangtidy warning

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137558
Approved by: https://github.com/FFFrog, https://github.com/ezyang
This commit is contained in:
zeshengzong
2024-10-12 18:13:52 +00:00
committed by PyTorch MergeBot
parent 0430e72e75
commit 47c8aa8090
5 changed files with 8 additions and 10 deletions

View File

@ -86,14 +86,8 @@ class TORCH_API Context {
initXPUIfNeeded(device_type);
if (device_type == at::kCPU) {
return c10::DeviceType::CPU;
} else if (device_type == at::kCUDA) {
return at::detail::getCUDAHooks().getDeviceFromPtr(data);
} else if (device_type == at::kXPU) {
return at::detail::getXPUHooks().getDeviceFromPtr(data);
} else if (device_type == at::kPrivateUse1) {
return at::detail::getPrivateUse1Hooks().getDeviceFromPtr(data);
} else {
AT_ERROR(c10::DeviceTypeName(device_type), " device type not enabled.");
return getAcceleratorHooksInterface(device_type).getDeviceFromPtr(data);
}
}
bool isPinnedPtr(

View File

@ -50,6 +50,10 @@ struct TORCH_API AcceleratorHooksInterface {
TORCH_CHECK(false, "Backend doesn't support getPinnedMemoryAllocator()");
return nullptr;
}
virtual Device getDeviceFromPtr(void* data) const {
TORCH_CHECK(false, "Backend doesn't support getDeviceFromPtr()");
}
};
} // namespace at

View File

@ -73,7 +73,7 @@ struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface {
TORCH_CHECK(false, "Cannot get default CUDA generator without ATen_cuda library. ", CUDA_HELP);
}
virtual Device getDeviceFromPtr(void* /*data*/) const {
Device getDeviceFromPtr(void* /*data*/) const override {
TORCH_CHECK(false, "Cannot get device of pointer on CUDA without ATen_cuda library. ", CUDA_HELP);
}

View File

@ -18,7 +18,7 @@ struct TORCH_API PrivateUse1HooksInterface : AcceleratorHooksInterface {
"You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `getDefaultGenerator`.");
}
virtual at::Device getDeviceFromPtr(void* data) const {
at::Device getDeviceFromPtr(void* data) const override {
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `getDeviceFromPtr`.");

View File

@ -50,7 +50,7 @@ struct TORCH_API XPUHooksInterface : AcceleratorHooksInterface{
TORCH_CHECK(false, "Cannot get current device on XPU without ATen_xpu library.");
}
virtual Device getDeviceFromPtr(void* /*data*/) const {
Device getDeviceFromPtr(void* /*data*/) const override {
TORCH_CHECK(false, "Cannot get device of pointer on XPU without ATen_xpu library.");
}