Get Device instance with correct type when privateuse1 backend is registered (#117966)

Fixes #ISSUE_NUMBER
If privateuse1 backend is registered. Let torch.device return corresponding instance of Device when only index is given.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117966
Approved by: https://github.com/albanD, https://github.com/malfet
This commit is contained in:
dilililiwhy
2024-01-24 19:03:13 +00:00
committed by PyTorch MergeBot
parent 6fc015fedc
commit b025e5984c
3 changed files with 11 additions and 0 deletions

View File

@ -148,4 +148,8 @@ void register_privateuse1_backend(const std::string& backend_name) {
privateuse1_backend_name_set.store(true, std::memory_order_relaxed);
}
bool is_privateuse1_backend_registered() {
return privateuse1_backend_name_set.load(std::memory_order_acquire);
}
} // namespace c10

View File

@ -104,6 +104,8 @@ C10_API std::ostream& operator<<(std::ostream& stream, DeviceType type);
C10_API void register_privateuse1_backend(const std::string& backend_name);
C10_API std::string get_privateuse1_backend(bool lower_case = true);
C10_API bool is_privateuse1_backend_registered();
} // namespace c10
namespace std {

View File

@ -808,6 +808,11 @@ inline at::Device toDevice(PyObject* obj) {
if (THPUtils_checkLong(obj)) {
const auto device_index = THPUtils_unpackLong(obj);
TORCH_CHECK(device_index >= 0, "Device index must not be negative");
if (c10::is_privateuse1_backend_registered()) {
return at::Device(
c10::DeviceType::PrivateUse1,
static_cast<c10::DeviceIndex>(device_index));
}
return at::Device(
c10::DeviceType::CUDA, static_cast<c10::DeviceIndex>(device_index));
}