mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix privateuse1 backend name case (#132980)
### Problem `get_privateuse1_backend(bool lower_case)` always returns a lower case name and `lower_case` is not used. Pull Request resolved: https://github.com/pytorch/pytorch/pull/132980 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
c8275e25a7
commit
343071cd96
@ -131,6 +131,9 @@ std::string get_privateuse1_backend(bool lower_case) {
|
||||
// set, and will never be written to.
|
||||
auto backend_name =
|
||||
name_registered ? privateuse1_backend_name : "privateuseone";
|
||||
auto op_case = lower_case ? ::tolower : ::toupper;
|
||||
std::transform(
|
||||
backend_name.begin(), backend_name.end(), backend_name.begin(), op_case);
|
||||
return backend_name;
|
||||
}
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <c10/core/Device.h>
|
||||
#include <c10/core/DeviceType.h>
|
||||
#include <c10/util/Exception.h>
|
||||
|
||||
// -- Device -------------------------------------------------------
|
||||
@ -46,3 +47,10 @@ TEST(DeviceTest, BasicConstruction) {
|
||||
EXPECT_THROW(make_device(ds), c10::Error) << "Device String: " << ds;
|
||||
}
|
||||
}
|
||||
|
||||
TEST(DeviceTypeTest, PrivateUseOneDeviceType) {
|
||||
c10::register_privateuse1_backend("my_privateuse1_backend");
|
||||
ASSERT_TRUE(c10::is_privateuse1_backend_registered());
|
||||
ASSERT_EQ(c10::get_privateuse1_backend(true), "my_privateuse1_backend");
|
||||
ASSERT_EQ(c10::get_privateuse1_backend(false), "MY_PRIVATEUSE1_BACKEND");
|
||||
}
|
||||
|
Reference in New Issue
Block a user