add PrivateUse1 device support in function options_from_string. (#118627)

add PrivateUse1 device support in function options_from_string.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118627
Approved by: https://github.com/soulitzer
This commit is contained in:
Shan19900305
2024-01-31 18:52:54 +00:00
committed by PyTorch MergeBot
parent 7aff92c838
commit 99b69e1ffb
4 changed files with 27 additions and 0 deletions

View File

@ -94,6 +94,10 @@ class TORCH_API DeprecatedTypeProperties {
return toBackend(Backend::HIP);
}
DeprecatedTypeProperties & privateUser1() const {
return toBackend(Backend::PrivateUse1);
}
/// Constructs the `TensorOptions` from a type and a `device_index`.
TensorOptions options(int16_t device_index = -1) const {
return TensorOptions().dtype(typeMeta())

View File

@ -48,6 +48,7 @@ namespace VariableType {
TORCH_API std::vector<at::DeprecatedTypeProperties*> allCUDATypes();
TORCH_API std::vector<at::DeprecatedTypeProperties*> allXPUTypes();
TORCH_API std::vector<at::DeprecatedTypeProperties*> allCPUTypes();
TORCH_API std::vector<at::DeprecatedTypeProperties*> allPrivateUser1Types();
at::Tensor & unpack(Tensor & t, const char * name, int pos);
const at::Tensor & unpack(const Tensor & t, const char * name, int pos);

View File

@ -50,6 +50,12 @@ std::vector<at::DeprecatedTypeProperties*> allXPUTypes() {
return allTypesForBackends({Backend::XPU, Backend::SparseXPU});
}
std::vector<at::DeprecatedTypeProperties*> allPrivateUser1Types() {
at::globalContext().lazyInitPrivateUse1();
return allTypesForBackends(
{Backend::PrivateUse1, Backend::SparsePrivateUse1});
}
namespace {
const Variable& checked_cast_variable(
const Tensor& t,

View File

@ -78,13 +78,18 @@ std::string type_to_string(const at::DeprecatedTypeProperties& type) {
at::TensorOptions options_from_string(const std::string& str) {
static std::string cuda_prefix("torch.cuda.");
static std::string xpu_prefix("torch.xpu.");
static std::string privateUser_prefix(
std::string(parse_privateuseone_backend()) + ".");
static c10::once_flag cpu_once;
static c10::once_flag cuda_once;
static c10::once_flag xpu_once;
static c10::once_flag privateUser1_once;
static std::unordered_map<std::string, at::DeprecatedTypeProperties*> cpu_map;
static std::unordered_map<std::string, at::DeprecatedTypeProperties*> xpu_map;
static std::unordered_map<std::string, at::DeprecatedTypeProperties*>
cuda_map;
static std::unordered_map<std::string, at::DeprecatedTypeProperties*>
privateUser1_map;
const std::unordered_map<std::string, at::DeprecatedTypeProperties*>* map =
nullptr;
@ -115,6 +120,17 @@ at::TensorOptions options_from_string(const std::string& str) {
}
});
map = &xpu_map;
} else if (
std::mismatch(
privateUser_prefix.begin(), privateUser_prefix.end(), str.begin())
.first == privateUser_prefix.end()) {
// torch.foo. foo is privateUser1 name
c10::call_once(privateUser1_once, []() {
for (auto type : autograd::VariableType::allPrivateUser1Types()) {
privateUser1_map.emplace(type_to_string(*type), type);
}
});
map = &privateUser1_map;
} else {
c10::call_once(cpu_once, []() {
for (auto type : autograd::VariableType::allCPUTypes()) {