mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Get Device instance with correct type when privateuse1 backend is registered (#117966)
Fixes #ISSUE_NUMBER If privateuse1 backend is registered. Let torch.device return corresponding instance of Device when only index is given. Pull Request resolved: https://github.com/pytorch/pytorch/pull/117966 Approved by: https://github.com/albanD, https://github.com/malfet
This commit is contained in:
committed by
PyTorch MergeBot
parent
6fc015fedc
commit
b025e5984c
@ -148,4 +148,8 @@ void register_privateuse1_backend(const std::string& backend_name) {
|
||||
privateuse1_backend_name_set.store(true, std::memory_order_relaxed);
|
||||
}
|
||||
|
||||
bool is_privateuse1_backend_registered() {
|
||||
return privateuse1_backend_name_set.load(std::memory_order_acquire);
|
||||
}
|
||||
|
||||
} // namespace c10
|
||||
|
@ -104,6 +104,8 @@ C10_API std::ostream& operator<<(std::ostream& stream, DeviceType type);
|
||||
C10_API void register_privateuse1_backend(const std::string& backend_name);
|
||||
C10_API std::string get_privateuse1_backend(bool lower_case = true);
|
||||
|
||||
C10_API bool is_privateuse1_backend_registered();
|
||||
|
||||
} // namespace c10
|
||||
|
||||
namespace std {
|
||||
|
@ -808,6 +808,11 @@ inline at::Device toDevice(PyObject* obj) {
|
||||
if (THPUtils_checkLong(obj)) {
|
||||
const auto device_index = THPUtils_unpackLong(obj);
|
||||
TORCH_CHECK(device_index >= 0, "Device index must not be negative");
|
||||
if (c10::is_privateuse1_backend_registered()) {
|
||||
return at::Device(
|
||||
c10::DeviceType::PrivateUse1,
|
||||
static_cast<c10::DeviceIndex>(device_index));
|
||||
}
|
||||
return at::Device(
|
||||
c10::DeviceType::CUDA, static_cast<c10::DeviceIndex>(device_index));
|
||||
}
|
||||
|
Reference in New Issue
Block a user