mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Add back HIP support to async net (#13400)
Summary:
We lost HIP support in last refactoring 620ece2668
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13400
Differential Revision: D12868211
Pulled By: bddppq
fbshipit-source-id: 72dbfda105b826bee28ddf480e88fca7d63f93d8
This commit is contained in:
committed by
Facebook Github Bot
parent
eaf141dd64
commit
a682ce9144
@ -167,7 +167,8 @@ TaskThreadPoolBase* AsyncNetBase::pool(const DeviceOption& device_option) {
|
||||
if (use_single_pool_) {
|
||||
return poolGetter(cpu_pools_, PROTO_CPU, -1, num_workers_);
|
||||
}
|
||||
if (IsCPUDeviceType(device_option.device_type())) {
|
||||
const auto device_type = device_option.device_type();
|
||||
if (IsCPUDeviceType(device_type)) {
|
||||
auto numa_node_id = -1;
|
||||
if (device_option.has_numa_node_id()) {
|
||||
numa_node_id = device_option.numa_node_id();
|
||||
@ -178,24 +179,24 @@ TaskThreadPoolBase* AsyncNetBase::pool(const DeviceOption& device_option) {
|
||||
FLAGS_caffe2_net_async_max_numa_nodes,
|
||||
"Invalid NUMA node id: ",
|
||||
numa_node_id);
|
||||
return poolGetter(cpu_pools_, PROTO_CPU, numa_node_id, num_workers_);
|
||||
} else if (device_option.device_type() == PROTO_CUDA) {
|
||||
return poolGetter(cpu_pools_, device_type, numa_node_id, num_workers_);
|
||||
} else if (IsGPUDeviceType(device_type)) {
|
||||
auto gpu_id = device_option.device_id();
|
||||
CAFFE_ENFORCE(
|
||||
gpu_id >= 0 && gpu_id < FLAGS_caffe2_net_async_max_gpus,
|
||||
"Invalid GPU id: " + caffe2::to_string(gpu_id));
|
||||
return poolGetter(gpu_pools_, PROTO_CUDA, gpu_id, num_workers_);
|
||||
return poolGetter(gpu_pools_, device_type, gpu_id, num_workers_);
|
||||
} else {
|
||||
CAFFE_THROW(
|
||||
"Unsupported device type " +
|
||||
caffe2::to_string(device_option.device_type()));
|
||||
caffe2::to_string(device_type));
|
||||
}
|
||||
}
|
||||
|
||||
int AsyncNetBase::stream(int task_id) {
|
||||
const auto& device_option = event(task_id).GetDeviceOption();
|
||||
int stream_id = 0;
|
||||
if (device_option.device_type() == PROTO_CUDA) {
|
||||
if (IsGPUDeviceType(device_option.device_type())) {
|
||||
int gpu_id = device_option.device_id();
|
||||
CAFFE_ENFORCE_GE(gpu_id, 0, "Invalid gpu id: " + caffe2::to_string(gpu_id));
|
||||
if ((unsigned)gpu_id >= getStreamCounters().size()) {
|
||||
|
@ -58,6 +58,14 @@ C10_EXPORT bool IsCPUDeviceType(int device_type) {
|
||||
return cpu_types.count(device_type);
|
||||
}
|
||||
|
||||
C10_EXPORT bool IsGPUDeviceType(int device_type) {
|
||||
static const std::unordered_set<int> gpu_types{
|
||||
PROTO_CUDA,
|
||||
PROTO_HIP,
|
||||
};
|
||||
return gpu_types.count(device_type);
|
||||
}
|
||||
|
||||
C10_EXPORT bool ReadStringFromFile(const char* filename, string* str) {
|
||||
std::ifstream ifs(filename, std::ios::in);
|
||||
if (!ifs) {
|
||||
|
@ -31,6 +31,7 @@ CAFFE2_API int DeviceId(const DeviceOption& option);
|
||||
CAFFE2_API bool IsSameDevice(const DeviceOption& lhs, const DeviceOption& rhs);
|
||||
|
||||
CAFFE2_API bool IsCPUDeviceType(int device_type);
|
||||
CAFFE2_API bool IsGPUDeviceType(int device_type);
|
||||
|
||||
// Common interfaces that reads file contents into a string.
|
||||
CAFFE2_API bool ReadStringFromFile(const char* filename, string* str);
|
||||
|
Reference in New Issue
Block a user