Add back HIP support to async net (#13400)

Summary:
We lost HIP support in last refactoring 620ece2668
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13400

Differential Revision: D12868211

Pulled By: bddppq

fbshipit-source-id: 72dbfda105b826bee28ddf480e88fca7d63f93d8
This commit is contained in:
Junjie Bai
2018-10-31 17:48:46 -07:00
committed by Facebook Github Bot
parent eaf141dd64
commit a682ce9144
3 changed files with 16 additions and 6 deletions

View File

@ -167,7 +167,8 @@ TaskThreadPoolBase* AsyncNetBase::pool(const DeviceOption& device_option) {
if (use_single_pool_) {
return poolGetter(cpu_pools_, PROTO_CPU, -1, num_workers_);
}
if (IsCPUDeviceType(device_option.device_type())) {
const auto device_type = device_option.device_type();
if (IsCPUDeviceType(device_type)) {
auto numa_node_id = -1;
if (device_option.has_numa_node_id()) {
numa_node_id = device_option.numa_node_id();
@ -178,24 +179,24 @@ TaskThreadPoolBase* AsyncNetBase::pool(const DeviceOption& device_option) {
FLAGS_caffe2_net_async_max_numa_nodes,
"Invalid NUMA node id: ",
numa_node_id);
return poolGetter(cpu_pools_, PROTO_CPU, numa_node_id, num_workers_);
} else if (device_option.device_type() == PROTO_CUDA) {
return poolGetter(cpu_pools_, device_type, numa_node_id, num_workers_);
} else if (IsGPUDeviceType(device_type)) {
auto gpu_id = device_option.device_id();
CAFFE_ENFORCE(
gpu_id >= 0 && gpu_id < FLAGS_caffe2_net_async_max_gpus,
"Invalid GPU id: " + caffe2::to_string(gpu_id));
return poolGetter(gpu_pools_, PROTO_CUDA, gpu_id, num_workers_);
return poolGetter(gpu_pools_, device_type, gpu_id, num_workers_);
} else {
CAFFE_THROW(
"Unsupported device type " +
caffe2::to_string(device_option.device_type()));
caffe2::to_string(device_type));
}
}
int AsyncNetBase::stream(int task_id) {
const auto& device_option = event(task_id).GetDeviceOption();
int stream_id = 0;
if (device_option.device_type() == PROTO_CUDA) {
if (IsGPUDeviceType(device_option.device_type())) {
int gpu_id = device_option.device_id();
CAFFE_ENFORCE_GE(gpu_id, 0, "Invalid gpu id: " + caffe2::to_string(gpu_id));
if ((unsigned)gpu_id >= getStreamCounters().size()) {

View File

@ -58,6 +58,14 @@ C10_EXPORT bool IsCPUDeviceType(int device_type) {
return cpu_types.count(device_type);
}
C10_EXPORT bool IsGPUDeviceType(int device_type) {
static const std::unordered_set<int> gpu_types{
PROTO_CUDA,
PROTO_HIP,
};
return gpu_types.count(device_type);
}
C10_EXPORT bool ReadStringFromFile(const char* filename, string* str) {
std::ifstream ifs(filename, std::ios::in);
if (!ifs) {

View File

@ -31,6 +31,7 @@ CAFFE2_API int DeviceId(const DeviceOption& option);
CAFFE2_API bool IsSameDevice(const DeviceOption& lhs, const DeviceOption& rhs);
CAFFE2_API bool IsCPUDeviceType(int device_type);
CAFFE2_API bool IsGPUDeviceType(int device_type);
// Common interfaces that reads file contents into a string.
CAFFE2_API bool ReadStringFromFile(const char* filename, string* str);