Enable detectron on AMD GPU

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17862

Differential Revision: D14429234

Pulled By: bddppq

fbshipit-source-id: 5cb8750bd9db0ff8a179977d2bfbb180265cce81
This commit is contained in:
Sandeep Kumar
2019-03-12 16:22:41 -07:00
committed by Facebook Github Bot
parent 1cfb50334f
commit 6248266d91
15 changed files with 45 additions and 32 deletions

View File

@ -1,3 +1,4 @@
# Declares the `modules` project with the C++ and C languages enabled, then
# pulls in each module's own build directory.
project(modules CXX C)
add_subdirectory(detectron)
add_subdirectory(module_test)
add_subdirectory(observers)

View File

@ -1,5 +1,6 @@
# Collect the Detectron operator sources by file extension. The .cc and .cu
# globs only match files directly inside this directory.
file(GLOB Detectron_CPU_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/*.cc)
file(GLOB Detectron_GPU_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/*.cu)
# Unlike the two globs above, this one is recursive and also matches .hip
# files in subdirectories. The result is consumed by the USE_ROCM branch below.
file(GLOB_RECURSE Detectron_HIP_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/*.hip)
if (BUILD_CAFFE2_OPS)
if (USE_OPENMP AND OPENMP_FOUND)
@ -19,6 +20,16 @@ if (BUILD_CAFFE2_OPS)
if (MSVC)
install(FILES $<TARGET_PDB_FILE:caffe2_detectron_ops_gpu> DESTINATION lib OPTIONAL)
endif()
elseif(USE_ROCM)
# ROCm build: produce a single shared library, caffe2_detectron_ops_hip, that
# contains both the CPU (.cc) and the HIP (.hip) operator sources.
hip_include_directories(${Caffe2_HIP_INCLUDES})
# NOTE(review): HIP_SOURCE_PROPERTY_FORMAT presumably marks these files for
# compilation by the HIP compiler inside HIP_ADD_LIBRARY - confirm against the
# FindHIP module in use.
set_source_files_properties(${Detectron_HIP_SRCS} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)
HIP_ADD_LIBRARY(
caffe2_detectron_ops_hip SHARED
${Detectron_CPU_SRCS}
${Detectron_HIP_SRCS})
# HIP_CXX_FLAGS is applied only to this target's own compilation (PRIVATE).
target_compile_options(caffe2_detectron_ops_hip PRIVATE ${HIP_CXX_FLAGS})
# Keyword-less link signature. NOTE(review): if this is ever changed to
# PRIVATE/PUBLIC, first check which signature HIP_ADD_LIBRARY itself uses on
# the target - the two forms cannot be mixed on one target.
target_link_libraries(caffe2_detectron_ops_hip caffe2_hip)
install(TARGETS caffe2_detectron_ops_hip DESTINATION lib)
elseif(NOT IOS_PLATFORM)
add_library(caffe2_detectron_ops SHARED ${Detectron_CPU_SRCS})
target_link_libraries(caffe2_detectron_ops caffe2 ${OpenMP_link})

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
#include "batch_permutation_op.h"
#include "modules/detectron/batch_permutation_op.h"
#include "caffe2/core/context_gpu.h"
namespace caffe2 {

View File

@ -17,7 +17,7 @@
#include <cfloat>
#include "caffe2/core/context_gpu.h"
#include "group_spatial_softmax_op.h"
#include "modules/detectron/group_spatial_softmax_op.h"
namespace caffe2 {

View File

@ -73,7 +73,7 @@
#include <cfloat>
#include "caffe2/core/context_gpu.h"
#include "ps_roi_pool_op.h"
#include "modules/detectron/ps_roi_pool_op.h"
namespace caffe2 {
@ -123,8 +123,8 @@ __global__ void PSRoIPoolForward(
roundf(offset_bottom_rois[4]) + 1.) * spatial_scale;
// Force too small ROIs to be 1x1
T roi_width = max(roi_end_w - roi_start_w, 0.1); // avoid 0
T roi_height = max(roi_end_h - roi_start_h, 0.1);
T roi_width = c10::cuda::compat::max(roi_end_w - roi_start_w, static_cast<T>(0.1)); // avoid 0
T roi_height = c10::cuda::compat::max(roi_end_h - roi_start_h, static_cast<T>(0.1));
// Compute w and h at bottom
T bin_size_h = roi_height / static_cast<T>(pooled_height);
@ -200,8 +200,8 @@ __global__ void PSRoIPoolBackward(
roundf(offset_bottom_rois[4]) + 1.) * spatial_scale;
// Force too small ROIs to be 1x1
T roi_width = max(roi_end_w - roi_start_w, 0.1); //avoid 0
T roi_height = max(roi_end_h - roi_start_h, 0.1);
T roi_width = c10::cuda::compat::max(roi_end_w - roi_start_w, static_cast<T>(0.1)); //avoid 0
T roi_height = c10::cuda::compat::max(roi_end_h - roi_start_h, static_cast<T>(0.1));
// Compute w and h at bottom
T bin_size_h = roi_height / static_cast<T>(pooled_height);

View File

@ -17,7 +17,7 @@
#include <cfloat>
#include "caffe2/core/context_gpu.h"
#include "roi_pool_f_op.h"
#include "modules/detectron/roi_pool_f_op.h"
namespace caffe2 {

View File

@ -21,7 +21,7 @@ Y's output samples are the samples of X for which L > 0.
#include <cfloat>
#include "caffe2/core/context_gpu.h"
#include "sample_as_op.h"
#include "modules/detectron/sample_as_op.h"
#include <stdio.h>

View File

@ -15,7 +15,7 @@
*/
#include "caffe2/core/context_gpu.h"
#include "select_smooth_l1_loss_op.h"
#include "modules/detectron/select_smooth_l1_loss_op.h"
namespace caffe2 {
@ -38,11 +38,11 @@ __global__ void SelectSmoothL1Kernel(
float y_hat = Y_hat[ind];
float y = Y[i * 4 + j];
float val = y_hat - y;
float abs_val = abs(val);
float abs_val = c10::cuda::compat::abs(val);
if (abs_val < beta) {
out[ind] = (0.5 * val * val / beta) / max(S[0], 1.0);
out[ind] = (0.5 * val * val / beta) / c10::cuda::compat::max(S[0], static_cast<float>(1.0));
} else {
out[ind] = (abs_val - 0.5 * beta) / max(S[0], 1.0);
out[ind] = (abs_val - 0.5 * beta) / c10::cuda::compat::max(S[0], static_cast<float>(1.0));
}
}
}
@ -75,11 +75,11 @@ __global__ void SelectSmoothL1GradientKernel(
float y_hat = Y_hat[ind];
float y = Y[i * 4 + j];
float val = y_hat - y;
float abs_val = abs(val);
float abs_val = c10::cuda::compat::abs(val);
if (abs_val < beta) {
out[ind] = norm * d_loss * val / beta / max(S[0], 1.0);
out[ind] = norm * d_loss * val / beta / c10::cuda::compat::max(S[0], static_cast<float>(1.0));
} else {
out[ind] = norm * d_loss * ((float(0) < val) - (val < float(0))) / max(S[0], 1.0);
out[ind] = norm * d_loss * ((float(0) < val) - (val < float(0))) / c10::cuda::compat::max(S[0], static_cast<float>(1.0));
}
}
}

View File

@ -15,7 +15,7 @@
*/
#include "caffe2/core/context_gpu.h"
#include "sigmoid_cross_entropy_loss_op.h"
#include "modules/detectron/sigmoid_cross_entropy_loss_op.h"
namespace caffe2 {

View File

@ -17,7 +17,7 @@
#include <cfloat>
#include "caffe2/core/context_gpu.h"
#include "sigmoid_focal_loss_op.h"
#include "modules/detectron/sigmoid_focal_loss_op.h"
namespace caffe2 {
@ -45,7 +45,7 @@ __global__ void SigmoidFocalLossKernel(
float c1 = (t == (d + 1));
float c2 = (t != -1 & t != (d + 1));
float Np = max(weight_pos[0], 1.0);
float Np = c10::cuda::compat::max(weight_pos[0], static_cast<float>(1.0));
float zn = (1.0 - alpha) / Np;
float zp = alpha / Np;
@ -53,7 +53,7 @@ __global__ void SigmoidFocalLossKernel(
float p = 1. / (1. + expf(-logits[i]));
// (1 - p)**gamma * log(p) where
float term1 = powf((1. - p), gamma) * logf(max(p, FLT_MIN));
float term1 = powf((1. - p), gamma) * logf(c10::cuda::compat::max(p, FLT_MIN));
// p**gamma * log(1 - p)
float term2 =
powf(p, gamma) *
@ -82,7 +82,7 @@ __global__ void SigmoidFocalLossGradientKernel(
int a = c / num_classes; // current anchor
int d = c % num_classes; // current class
float Np = max(weight_pos[0], 1.0);
float Np = c10::cuda::compat::max(weight_pos[0], static_cast<float>(1.0));
float zn = (1.0 - alpha) / Np;
float zp = alpha / Np;
int t = targets[n * (H * W * A) + a * (H * W) + y * W + x];
@ -94,7 +94,7 @@ __global__ void SigmoidFocalLossGradientKernel(
// (1-p)**g * (1 - p - g*p*log(p))
float term1 =
powf((1. - p), gamma) *
(1. - p - (p * gamma * logf(max(p, FLT_MIN))));
(1. - p - (p * gamma * logf(c10::cuda::compat::max(p, FLT_MIN))));
// (p**g) * (g*(1-p)*log(1-p) - p)
float term2 =
powf(p, gamma) *

View File

@ -15,7 +15,7 @@
*/
#include "caffe2/core/context_gpu.h"
#include "smooth_l1_loss_op.h"
#include "modules/detectron/smooth_l1_loss_op.h"
namespace caffe2 {
@ -27,7 +27,7 @@ __global__ void SmoothL1Kernel(
// |x| - 0.5 * beta otherwise
CUDA_1D_KERNEL_LOOP(index, n) {
T val = in[index];
T abs_val = abs(val);
T abs_val = c10::cuda::compat::abs(val);
if (abs_val < beta) {
out[index] = 0.5 * val * val / beta;
} else {
@ -49,7 +49,7 @@ __global__ void SmoothL1GradientKernel(
// We also scale by norm * d_loss in this kernel for convenience
CUDA_1D_KERNEL_LOOP(index, n) {
T val = in[index];
T abs_val = abs(val);
T abs_val = c10::cuda::compat::abs(val);
T d_loss = *d_loss_data;
if (abs_val < beta) {
out[index] = norm * d_loss * val / beta;

View File

@ -17,7 +17,7 @@
#include <cfloat>
#include "caffe2/core/context_gpu.h"
#include "softmax_focal_loss_op.h"
#include "modules/detectron/softmax_focal_loss_op.h"
namespace caffe2 {
@ -69,7 +69,7 @@ __global__ void SoftmaxFocalLossKernel(
int n = i / (W * H * A);
const int label = static_cast<int>(targets[i]);
float Np = max(weight_pos[0], 1.0);
float Np = c10::cuda::compat::max(weight_pos[0], static_cast<float>(1.0));
float z = (label == 0) * (1 - alpha) / Np +
(label >= 1) * alpha / Np;
@ -79,7 +79,7 @@ __global__ void SoftmaxFocalLossKernel(
int idx = n * (H * W * D) + (offset + label) * (H * W) + y * W + x;
losses[i] =
-(pow(1.0f - Pdata[idx], gamma) *
log(max(Pdata[idx], FLT_MIN))) * z;
log(c10::cuda::compat::max(Pdata[idx], FLT_MIN))) * z;
}
}
}
@ -97,7 +97,7 @@ __global__ void SoftmaxFocalLossGradientWeightKernel(
int a = (i / (W * H)) % A;
int n = i / (W * H * A);
const int label = static_cast<int>(targets[i]);
float Np = max(weight_pos[0], 1.0);
float Np = c10::cuda::compat::max(weight_pos[0], static_cast<float>(1.0));
float z = (label == 0) * (1 - alpha) / Np +
(label >= 1) * alpha / Np;
@ -109,7 +109,7 @@ __global__ void SoftmaxFocalLossGradientWeightKernel(
float p = Pdata[idx];
buff[i] =
(-pow(onemp, gamma) +
gamma * pow(onemp, gamma - 1) * p * log(max(p, FLT_MIN))) * z;
gamma * pow(onemp, gamma - 1) * p * log(c10::cuda::compat::max(p, FLT_MIN))) * z;
}
}
}

View File

@ -16,7 +16,7 @@
#include "caffe2/core/context_gpu.h"
#include "caffe2/core/operator.h"
#include "spatial_narrow_as_op.h"
#include "modules/detectron/spatial_narrow_as_op.h"
namespace caffe2 {

View File

@ -58,7 +58,7 @@
#include "caffe2/core/context_gpu.h"
#include "upsample_nearest_op.h"
#include "modules/detectron/upsample_nearest_op.h"
namespace caffe2 {

View File

@ -63,6 +63,7 @@ includes = [
"caffe2/utils/*",
"c10/cuda/*",
"c10/cuda/test/CMakeLists.txt",
"modules/*",
# PyTorch paths
# Keep this synchronized with is_pytorch_file in hipify_python.py
"aten/src/ATen/cuda/*",