mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Refactor make device agnostic in accelerator hooks (#137558)
Make `AcceleratorHooksInterface` consistent for multiple accelerators - Add `getDeviceFromPtr` method declaration in `AcceleratorHooksInterface` - Fix clangtidy warning Pull Request resolved: https://github.com/pytorch/pytorch/pull/137558 Approved by: https://github.com/FFFrog, https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
0430e72e75
commit
47c8aa8090
@ -86,14 +86,8 @@ class TORCH_API Context {
|
||||
initXPUIfNeeded(device_type);
|
||||
if (device_type == at::kCPU) {
|
||||
return c10::DeviceType::CPU;
|
||||
} else if (device_type == at::kCUDA) {
|
||||
return at::detail::getCUDAHooks().getDeviceFromPtr(data);
|
||||
} else if (device_type == at::kXPU) {
|
||||
return at::detail::getXPUHooks().getDeviceFromPtr(data);
|
||||
} else if (device_type == at::kPrivateUse1) {
|
||||
return at::detail::getPrivateUse1Hooks().getDeviceFromPtr(data);
|
||||
} else {
|
||||
AT_ERROR(c10::DeviceTypeName(device_type), " device type not enabled.");
|
||||
return getAcceleratorHooksInterface(device_type).getDeviceFromPtr(data);
|
||||
}
|
||||
}
|
||||
bool isPinnedPtr(
|
||||
|
@ -50,6 +50,10 @@ struct TORCH_API AcceleratorHooksInterface {
|
||||
TORCH_CHECK(false, "Backend doesn't support getPinnedMemoryAllocator()");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
virtual Device getDeviceFromPtr(void* data) const {
|
||||
TORCH_CHECK(false, "Backend doesn't support getDeviceFromPtr()");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace at
|
||||
|
@ -73,7 +73,7 @@ struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface {
|
||||
TORCH_CHECK(false, "Cannot get default CUDA generator without ATen_cuda library. ", CUDA_HELP);
|
||||
}
|
||||
|
||||
virtual Device getDeviceFromPtr(void* /*data*/) const {
|
||||
Device getDeviceFromPtr(void* /*data*/) const override {
|
||||
TORCH_CHECK(false, "Cannot get device of pointer on CUDA without ATen_cuda library. ", CUDA_HELP);
|
||||
}
|
||||
|
||||
|
@ -18,7 +18,7 @@ struct TORCH_API PrivateUse1HooksInterface : AcceleratorHooksInterface {
|
||||
"You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `getDefaultGenerator`.");
|
||||
}
|
||||
|
||||
virtual at::Device getDeviceFromPtr(void* data) const {
|
||||
at::Device getDeviceFromPtr(void* data) const override {
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false,
|
||||
"You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `getDeviceFromPtr`.");
|
||||
|
@ -50,7 +50,7 @@ struct TORCH_API XPUHooksInterface : AcceleratorHooksInterface{
|
||||
TORCH_CHECK(false, "Cannot get current device on XPU without ATen_xpu library.");
|
||||
}
|
||||
|
||||
virtual Device getDeviceFromPtr(void* /*data*/) const {
|
||||
Device getDeviceFromPtr(void* /*data*/) const override {
|
||||
TORCH_CHECK(false, "Cannot get device of pointer on XPU without ATen_xpu library.");
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user