mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Enable detectron on AMD GPU
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17862 Differential Revision: D14429234 Pulled By: bddppq fbshipit-source-id: 5cb8750bd9db0ff8a179977d2bfbb180265cce81
This commit is contained in:
committed by
Facebook Github Bot
parent
1cfb50334f
commit
6248266d91
@ -1,3 +1,4 @@
|
||||
project(modules CXX C)
|
||||
add_subdirectory(detectron)
|
||||
add_subdirectory(module_test)
|
||||
add_subdirectory(observers)
|
||||
|
@ -1,5 +1,6 @@
|
||||
file(GLOB Detectron_CPU_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/*.cc)
|
||||
file(GLOB Detectron_GPU_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/*.cu)
|
||||
file(GLOB_RECURSE Detectron_HIP_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/*.hip)
|
||||
|
||||
if (BUILD_CAFFE2_OPS)
|
||||
if (USE_OPENMP AND OPENMP_FOUND)
|
||||
@ -19,6 +20,16 @@ if (BUILD_CAFFE2_OPS)
|
||||
if (MSVC)
|
||||
install(FILES $<TARGET_PDB_FILE:caffe2_detectron_ops_gpu> DESTINATION lib OPTIONAL)
|
||||
endif()
|
||||
elseif(USE_ROCM)
|
||||
hip_include_directories(${Caffe2_HIP_INCLUDES})
|
||||
set_source_files_properties(${Detectron_HIP_SRCS} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)
|
||||
HIP_ADD_LIBRARY(
|
||||
caffe2_detectron_ops_hip SHARED
|
||||
${Detectron_CPU_SRCS}
|
||||
${Detectron_HIP_SRCS})
|
||||
target_compile_options(caffe2_detectron_ops_hip PRIVATE ${HIP_CXX_FLAGS})
|
||||
target_link_libraries(caffe2_detectron_ops_hip caffe2_hip)
|
||||
install(TARGETS caffe2_detectron_ops_hip DESTINATION lib)
|
||||
elseif(NOT IOS_PLATFORM)
|
||||
add_library(caffe2_detectron_ops SHARED ${Detectron_CPU_SRCS})
|
||||
target_link_libraries(caffe2_detectron_ops caffe2 ${OpenMP_link})
|
||||
|
@ -14,7 +14,7 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "batch_permutation_op.h"
|
||||
#include "modules/detectron/batch_permutation_op.h"
|
||||
#include "caffe2/core/context_gpu.h"
|
||||
|
||||
namespace caffe2 {
|
||||
|
@ -17,7 +17,7 @@
|
||||
#include <cfloat>
|
||||
|
||||
#include "caffe2/core/context_gpu.h"
|
||||
#include "group_spatial_softmax_op.h"
|
||||
#include "modules/detectron/group_spatial_softmax_op.h"
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
|
@ -73,7 +73,7 @@
|
||||
#include <cfloat>
|
||||
|
||||
#include "caffe2/core/context_gpu.h"
|
||||
#include "ps_roi_pool_op.h"
|
||||
#include "modules/detectron/ps_roi_pool_op.h"
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
@ -123,8 +123,8 @@ __global__ void PSRoIPoolForward(
|
||||
roundf(offset_bottom_rois[4]) + 1.) * spatial_scale;
|
||||
|
||||
// Force too small ROIs to be 1x1
|
||||
T roi_width = max(roi_end_w - roi_start_w, 0.1); // avoid 0
|
||||
T roi_height = max(roi_end_h - roi_start_h, 0.1);
|
||||
T roi_width = c10::cuda::compat::max(roi_end_w - roi_start_w, static_cast<T>(0.1)); // avoid 0
|
||||
T roi_height = c10::cuda::compat::max(roi_end_h - roi_start_h, static_cast<T>(0.1));
|
||||
|
||||
// Compute w and h at bottom
|
||||
T bin_size_h = roi_height / static_cast<T>(pooled_height);
|
||||
@ -200,8 +200,8 @@ __global__ void PSRoIPoolBackward(
|
||||
roundf(offset_bottom_rois[4]) + 1.) * spatial_scale;
|
||||
|
||||
// Force too small ROIs to be 1x1
|
||||
T roi_width = max(roi_end_w - roi_start_w, 0.1); //avoid 0
|
||||
T roi_height = max(roi_end_h - roi_start_h, 0.1);
|
||||
T roi_width = c10::cuda::compat::max(roi_end_w - roi_start_w, static_cast<T>(0.1)); //avoid 0
|
||||
T roi_height = c10::cuda::compat::max(roi_end_h - roi_start_h, static_cast<T>(0.1));
|
||||
|
||||
// Compute w and h at bottom
|
||||
T bin_size_h = roi_height / static_cast<T>(pooled_height);
|
||||
|
@ -17,7 +17,7 @@
|
||||
#include <cfloat>
|
||||
|
||||
#include "caffe2/core/context_gpu.h"
|
||||
#include "roi_pool_f_op.h"
|
||||
#include "modules/detectron/roi_pool_f_op.h"
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
|
@ -21,7 +21,7 @@ Y's output samples are the samples of X for which L > 0.
|
||||
#include <cfloat>
|
||||
|
||||
#include "caffe2/core/context_gpu.h"
|
||||
#include "sample_as_op.h"
|
||||
#include "modules/detectron/sample_as_op.h"
|
||||
|
||||
#include <stdio.h>
|
||||
|
||||
|
@ -15,7 +15,7 @@
|
||||
*/
|
||||
|
||||
#include "caffe2/core/context_gpu.h"
|
||||
#include "select_smooth_l1_loss_op.h"
|
||||
#include "modules/detectron/select_smooth_l1_loss_op.h"
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
@ -38,11 +38,11 @@ __global__ void SelectSmoothL1Kernel(
|
||||
float y_hat = Y_hat[ind];
|
||||
float y = Y[i * 4 + j];
|
||||
float val = y_hat - y;
|
||||
float abs_val = abs(val);
|
||||
float abs_val = c10::cuda::compat::abs(val);
|
||||
if (abs_val < beta) {
|
||||
out[ind] = (0.5 * val * val / beta) / max(S[0], 1.0);
|
||||
out[ind] = (0.5 * val * val / beta) / c10::cuda::compat::max(S[0], static_cast<float>(1.0));
|
||||
} else {
|
||||
out[ind] = (abs_val - 0.5 * beta) / max(S[0], 1.0);
|
||||
out[ind] = (abs_val - 0.5 * beta) / c10::cuda::compat::max(S[0], static_cast<float>(1.0));
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -75,11 +75,11 @@ __global__ void SelectSmoothL1GradientKernel(
|
||||
float y_hat = Y_hat[ind];
|
||||
float y = Y[i * 4 + j];
|
||||
float val = y_hat - y;
|
||||
float abs_val = abs(val);
|
||||
float abs_val = c10::cuda::compat::abs(val);
|
||||
if (abs_val < beta) {
|
||||
out[ind] = norm * d_loss * val / beta / max(S[0], 1.0);
|
||||
out[ind] = norm * d_loss * val / beta / c10::cuda::compat::max(S[0], static_cast<float>(1.0));
|
||||
} else {
|
||||
out[ind] = norm * d_loss * ((float(0) < val) - (val < float(0))) / max(S[0], 1.0);
|
||||
out[ind] = norm * d_loss * ((float(0) < val) - (val < float(0))) / c10::cuda::compat::max(S[0], static_cast<float>(1.0));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -15,7 +15,7 @@
|
||||
*/
|
||||
|
||||
#include "caffe2/core/context_gpu.h"
|
||||
#include "sigmoid_cross_entropy_loss_op.h"
|
||||
#include "modules/detectron/sigmoid_cross_entropy_loss_op.h"
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
|
@ -17,7 +17,7 @@
|
||||
#include <cfloat>
|
||||
|
||||
#include "caffe2/core/context_gpu.h"
|
||||
#include "sigmoid_focal_loss_op.h"
|
||||
#include "modules/detectron/sigmoid_focal_loss_op.h"
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
@ -45,7 +45,7 @@ __global__ void SigmoidFocalLossKernel(
|
||||
float c1 = (t == (d + 1));
|
||||
float c2 = (t != -1 & t != (d + 1));
|
||||
|
||||
float Np = max(weight_pos[0], 1.0);
|
||||
float Np = c10::cuda::compat::max(weight_pos[0], static_cast<float>(1.0));
|
||||
float zn = (1.0 - alpha) / Np;
|
||||
float zp = alpha / Np;
|
||||
|
||||
@ -53,7 +53,7 @@ __global__ void SigmoidFocalLossKernel(
|
||||
float p = 1. / (1. + expf(-logits[i]));
|
||||
|
||||
// (1 - p)**gamma * log(p) where
|
||||
float term1 = powf((1. - p), gamma) * logf(max(p, FLT_MIN));
|
||||
float term1 = powf((1. - p), gamma) * logf(c10::cuda::compat::max(p, FLT_MIN));
|
||||
// p**gamma * log(1 - p)
|
||||
float term2 =
|
||||
powf(p, gamma) *
|
||||
@ -82,7 +82,7 @@ __global__ void SigmoidFocalLossGradientKernel(
|
||||
int a = c / num_classes; // current anchor
|
||||
int d = c % num_classes; // current class
|
||||
|
||||
float Np = max(weight_pos[0], 1.0);
|
||||
float Np = c10::cuda::compat::max(weight_pos[0], static_cast<float>(1.0));
|
||||
float zn = (1.0 - alpha) / Np;
|
||||
float zp = alpha / Np;
|
||||
int t = targets[n * (H * W * A) + a * (H * W) + y * W + x];
|
||||
@ -94,7 +94,7 @@ __global__ void SigmoidFocalLossGradientKernel(
|
||||
// (1-p)**g * (1 - p - g*p*log(p))
|
||||
float term1 =
|
||||
powf((1. - p), gamma) *
|
||||
(1. - p - (p * gamma * logf(max(p, FLT_MIN))));
|
||||
(1. - p - (p * gamma * logf(c10::cuda::compat::max(p, FLT_MIN))));
|
||||
// (p**g) * (g*(1-p)*log(1-p) - p)
|
||||
float term2 =
|
||||
powf(p, gamma) *
|
||||
|
@ -15,7 +15,7 @@
|
||||
*/
|
||||
|
||||
#include "caffe2/core/context_gpu.h"
|
||||
#include "smooth_l1_loss_op.h"
|
||||
#include "modules/detectron/smooth_l1_loss_op.h"
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
@ -27,7 +27,7 @@ __global__ void SmoothL1Kernel(
|
||||
// |x| - 0.5 * beta otherwise
|
||||
CUDA_1D_KERNEL_LOOP(index, n) {
|
||||
T val = in[index];
|
||||
T abs_val = abs(val);
|
||||
T abs_val = c10::cuda::compat::abs(val);
|
||||
if (abs_val < beta) {
|
||||
out[index] = 0.5 * val * val / beta;
|
||||
} else {
|
||||
@ -49,7 +49,7 @@ __global__ void SmoothL1GradientKernel(
|
||||
// We also scale by norm * d_loss in this kernel for convenience
|
||||
CUDA_1D_KERNEL_LOOP(index, n) {
|
||||
T val = in[index];
|
||||
T abs_val = abs(val);
|
||||
T abs_val = c10::cuda::compat::abs(val);
|
||||
T d_loss = *d_loss_data;
|
||||
if (abs_val < beta) {
|
||||
out[index] = norm * d_loss * val / beta;
|
||||
|
@ -17,7 +17,7 @@
|
||||
#include <cfloat>
|
||||
|
||||
#include "caffe2/core/context_gpu.h"
|
||||
#include "softmax_focal_loss_op.h"
|
||||
#include "modules/detectron/softmax_focal_loss_op.h"
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
@ -69,7 +69,7 @@ __global__ void SoftmaxFocalLossKernel(
|
||||
int n = i / (W * H * A);
|
||||
const int label = static_cast<int>(targets[i]);
|
||||
|
||||
float Np = max(weight_pos[0], 1.0);
|
||||
float Np = c10::cuda::compat::max(weight_pos[0], static_cast<float>(1.0));
|
||||
float z = (label == 0) * (1 - alpha) / Np +
|
||||
(label >= 1) * alpha / Np;
|
||||
|
||||
@ -79,7 +79,7 @@ __global__ void SoftmaxFocalLossKernel(
|
||||
int idx = n * (H * W * D) + (offset + label) * (H * W) + y * W + x;
|
||||
losses[i] =
|
||||
-(pow(1.0f - Pdata[idx], gamma) *
|
||||
log(max(Pdata[idx], FLT_MIN))) * z;
|
||||
log(c10::cuda::compat::max(Pdata[idx], FLT_MIN))) * z;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -97,7 +97,7 @@ __global__ void SoftmaxFocalLossGradientWeightKernel(
|
||||
int a = (i / (W * H)) % A;
|
||||
int n = i / (W * H * A);
|
||||
const int label = static_cast<int>(targets[i]);
|
||||
float Np = max(weight_pos[0], 1.0);
|
||||
float Np = c10::cuda::compat::max(weight_pos[0], static_cast<float>(1.0));
|
||||
float z = (label == 0) * (1 - alpha) / Np +
|
||||
(label >= 1) * alpha / Np;
|
||||
|
||||
@ -109,7 +109,7 @@ __global__ void SoftmaxFocalLossGradientWeightKernel(
|
||||
float p = Pdata[idx];
|
||||
buff[i] =
|
||||
(-pow(onemp, gamma) +
|
||||
gamma * pow(onemp, gamma - 1) * p * log(max(p, FLT_MIN))) * z;
|
||||
gamma * pow(onemp, gamma - 1) * p * log(c10::cuda::compat::max(p, FLT_MIN))) * z;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -16,7 +16,7 @@
|
||||
|
||||
#include "caffe2/core/context_gpu.h"
|
||||
#include "caffe2/core/operator.h"
|
||||
#include "spatial_narrow_as_op.h"
|
||||
#include "modules/detectron/spatial_narrow_as_op.h"
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
|
@ -58,7 +58,7 @@
|
||||
|
||||
|
||||
#include "caffe2/core/context_gpu.h"
|
||||
#include "upsample_nearest_op.h"
|
||||
#include "modules/detectron/upsample_nearest_op.h"
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
|
@ -63,6 +63,7 @@ includes = [
|
||||
"caffe2/utils/*",
|
||||
"c10/cuda/*",
|
||||
"c10/cuda/test/CMakeLists.txt",
|
||||
"modules/*",
|
||||
# PyTorch paths
|
||||
# Keep this synchronized with is_pytorch_file in hipify_python.py
|
||||
"aten/src/ATen/cuda/*",
|
||||
|
Reference in New Issue
Block a user