mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
add PrivateUse1 device support in function options_from_string. (#118627)
add PrivateUse1 device support in function options_from_string. Pull Request resolved: https://github.com/pytorch/pytorch/pull/118627 Approved by: https://github.com/soulitzer
This commit is contained in:
committed by
PyTorch MergeBot
parent
7aff92c838
commit
99b69e1ffb
@ -94,6 +94,10 @@ class TORCH_API DeprecatedTypeProperties {
|
||||
return toBackend(Backend::HIP);
|
||||
}
|
||||
|
||||
DeprecatedTypeProperties & privateUser1() const {
|
||||
return toBackend(Backend::PrivateUse1);
|
||||
}
|
||||
|
||||
/// Constructs the `TensorOptions` from a type and a `device_index`.
|
||||
TensorOptions options(int16_t device_index = -1) const {
|
||||
return TensorOptions().dtype(typeMeta())
|
||||
|
@ -48,6 +48,7 @@ namespace VariableType {
|
||||
TORCH_API std::vector<at::DeprecatedTypeProperties*> allCUDATypes();
|
||||
TORCH_API std::vector<at::DeprecatedTypeProperties*> allXPUTypes();
|
||||
TORCH_API std::vector<at::DeprecatedTypeProperties*> allCPUTypes();
|
||||
TORCH_API std::vector<at::DeprecatedTypeProperties*> allPrivateUser1Types();
|
||||
|
||||
at::Tensor & unpack(Tensor & t, const char * name, int pos);
|
||||
const at::Tensor & unpack(const Tensor & t, const char * name, int pos);
|
||||
|
@ -50,6 +50,12 @@ std::vector<at::DeprecatedTypeProperties*> allXPUTypes() {
|
||||
return allTypesForBackends({Backend::XPU, Backend::SparseXPU});
|
||||
}
|
||||
|
||||
std::vector<at::DeprecatedTypeProperties*> allPrivateUser1Types() {
|
||||
at::globalContext().lazyInitPrivateUse1();
|
||||
return allTypesForBackends(
|
||||
{Backend::PrivateUse1, Backend::SparsePrivateUse1});
|
||||
}
|
||||
|
||||
namespace {
|
||||
const Variable& checked_cast_variable(
|
||||
const Tensor& t,
|
||||
|
@ -78,13 +78,18 @@ std::string type_to_string(const at::DeprecatedTypeProperties& type) {
|
||||
at::TensorOptions options_from_string(const std::string& str) {
|
||||
static std::string cuda_prefix("torch.cuda.");
|
||||
static std::string xpu_prefix("torch.xpu.");
|
||||
static std::string privateUser_prefix(
|
||||
std::string(parse_privateuseone_backend()) + ".");
|
||||
static c10::once_flag cpu_once;
|
||||
static c10::once_flag cuda_once;
|
||||
static c10::once_flag xpu_once;
|
||||
static c10::once_flag privateUser1_once;
|
||||
static std::unordered_map<std::string, at::DeprecatedTypeProperties*> cpu_map;
|
||||
static std::unordered_map<std::string, at::DeprecatedTypeProperties*> xpu_map;
|
||||
static std::unordered_map<std::string, at::DeprecatedTypeProperties*>
|
||||
cuda_map;
|
||||
static std::unordered_map<std::string, at::DeprecatedTypeProperties*>
|
||||
privateUser1_map;
|
||||
|
||||
const std::unordered_map<std::string, at::DeprecatedTypeProperties*>* map =
|
||||
nullptr;
|
||||
@ -115,6 +120,17 @@ at::TensorOptions options_from_string(const std::string& str) {
|
||||
}
|
||||
});
|
||||
map = &xpu_map;
|
||||
} else if (
|
||||
std::mismatch(
|
||||
privateUser_prefix.begin(), privateUser_prefix.end(), str.begin())
|
||||
.first == privateUser_prefix.end()) {
|
||||
// torch.foo. foo is privateUser1 name
|
||||
c10::call_once(privateUser1_once, []() {
|
||||
for (auto type : autograd::VariableType::allPrivateUser1Types()) {
|
||||
privateUser1_map.emplace(type_to_string(*type), type);
|
||||
}
|
||||
});
|
||||
map = &privateUser1_map;
|
||||
} else {
|
||||
c10::call_once(cpu_once, []() {
|
||||
for (auto type : autograd::VariableType::allCPUTypes()) {
|
||||
|
Reference in New Issue
Block a user