Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
HIP Operators Generator --> HipOpG (#9322)

Summary: The goal of this PR is to add the infrastructure to convert (hipify) CUDA ops into [HIP](https://github.com/ROCm-Developer-Tools/HIP) ops at **compile** time. Note that HIP ops are portable C++ code and can run on both AMD and NVIDIA platforms.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/9322
Differential Revision: D8884707
Pulled By: bddppq
fbshipit-source-id: dabc6319546002c308c10528238e6684f7aef0f8
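For context, the sketch below illustrates the kind of rewrite the hipify pass applies to a Caffe2 CUDA kernel launch: the triple-chevron launch becomes a `hipLaunchKernelGGL` call, the grid/block arguments are wrapped in `dim3()`, and CUDA-named helpers are mapped to their HIP counterparts (see `CAFFE2_SPECIFIC_MAPPINGS` and `processKernelLaunches` in the diff below). The kernel, its arguments, and the launch configuration are illustrative and not taken from this PR.

```cpp
#include "hip/hip_runtime.h"  // provides hipLaunchKernelGGL, dim3, hipStream_t

// Illustrative kernel (not from the PR): y[i] = a * x[i].
template <typename T>
__global__ void ScaleKernel(const int n, const T a, const T* x, T* y) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += blockDim.x * gridDim.x) {
    y[i] = a * x[i];
  }
}

void LaunchScale(int n, const float* x, float* y, hipStream_t stream) {
  // CUDA source form found in the .cu files:
  //   ScaleKernel<float><<<num_blocks, threads_per_block, 0, stream>>>(n, 2.f, x, y);
  // Roughly the hipified form emitted by the script:
  const int threads_per_block = 256;
  const int num_blocks = (n + threads_per_block - 1) / threads_per_block;
  hipLaunchKernelGGL(
      (ScaleKernel<float>),
      dim3(num_blocks),
      dim3(threads_per_block),
      0,       // dynamic shared memory
      stream,  // HIP stream
      n, 2.f, x, y);
}
```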
Committed by: Facebook Github Bot
Parent: 45f0d05202
Commit: 54db14e390
@@ -155,6 +155,9 @@ if [[ $BUILD_ENVIRONMENT == *rocm* ]]; then
export LANG=C.UTF-8
export LC_ALL=C.UTF-8
export HCC_AMDGPU_TARGET=gfx900

########## HIPIFY Caffe2 operators
${PYTHON} "${ROOT_DIR}/tools/amd_build/build_caffe2_amd.py"
fi

# Try to include Redis support for Linux builds
@@ -195,6 +198,7 @@ else
fi


###############################################################################
# Configure and make
###############################################################################
@@ -462,7 +462,8 @@ if(BUILD_CAFFE2)
endif()

if(USE_ROCM)
add_library(caffe2_pybind11_state_hip MODULE ${Caffe2_HIP_PYTHON_SRCS})
hip_add_library(caffe2_pybind11_state_hip MODULE ${Caffe2_HIP_PYTHON_SRCS})
set_target_properties(caffe2_pybind11_state_hip PROPERTIES LINKER_LANGUAGE HIP)
set_target_properties(caffe2_pybind11_state_hip PROPERTIES COMPILE_FLAGS "${HIP_HIPCC_FLAGS} -fvisibility=hidden")
set_target_properties(caffe2_pybind11_state_hip PROPERTIES PREFIX "")
set_target_properties(caffe2_pybind11_state_hip PROPERTIES SUFFIX ${PY_EXT_SUFFIX})
@@ -49,7 +49,7 @@ TEST(EnginePrefTest, GPUDeviceDefaultPreferredEngines)
{
const auto op = CreateOperator(op_def, &ws);
EXPECT_NE(nullptr, op.get());
EXPECT_EQ(static_cast<JustTest*>(op.get())->type(), "HIP");
EXPECT_EQ(static_cast<JustTest*>(op.get())->type(), "MIOPEN");
}
}

@@ -70,8 +70,8 @@ PerOpEnginePrefType& g_per_op_engine_pref() {
}

GlobalEnginePrefType& g_global_engine_pref() {
static auto* g_global_engine_pref_ =
new GlobalEnginePrefType{{DeviceType::CUDA, {"CUDNN"}}};
static auto* g_global_engine_pref_ = new GlobalEnginePrefType{
{DeviceType::CUDA, {"CUDNN"}}, {DeviceType::HIP, {"MIOPEN"}}};
return *g_global_engine_pref_;
}
@@ -1,4 +1,4 @@
#include "file_store_handler_op.h"
#include "caffe2/distributed/file_store_handler_op.h"

#include <caffe2/core/context_gpu.h>

@@ -1,4 +1,4 @@
#include "redis_store_handler_op.h"
#include "caffe2/distributed/redis_store_handler_op.h"

#include <caffe2/core/context_gpu.h>
@@ -11,6 +11,14 @@ if(USE_OPENCV AND OpenCV_FOUND)
file(GLOB tmp *_test.cc)
exclude(Caffe2_GPU_SRCS "${Caffe2_GPU_SRCS}" ${tmp})

# ---[ HIP files
# ------[ general hip
file(GLOB_RECURSE tmp *_hip.cc)
set(Caffe2_HIP_SRCS ${Caffe2_HIP_SRCS} ${tmp})
# exclude test files
file(GLOB_RECURSE tmp *_test.cc)
exclude(Caffe2_HIP_SRCS "${Caffe2_HIP_SRCS}" ${tmp})

# ---[ CPU files.
file(GLOB tmp *.cc)
set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${tmp})
@@ -18,21 +26,29 @@ if(USE_OPENCV AND OpenCV_FOUND)
file(GLOB tmp *_test.cc)
exclude(Caffe2_CPU_SRCS "${Caffe2_CPU_SRCS}" ${tmp})
exclude(Caffe2_CPU_SRCS "${Caffe2_CPU_SRCS}" ${Caffe2_GPU_SRCS})
exclude(Caffe2_CPU_SRCS "${Caffe2_CPU_SRCS}" ${Caffe2_HIP_SRCS})

# ---[ GPU test files
file(GLOB tmp *_gpu_test.cc)
set(Caffe2_GPU_TEST_SRCS ${Caffe2_GPU_TEST_SRCS} ${tmp})

# ---[ HIP test files
file(GLOB_RECURSE tmp *_hip_test.cc)
set(Caffe2_HIP_TEST_SRCS ${Caffe2_HIP_TEST_SRCS} ${tmp})

# ---[ CPU test files
file(GLOB tmp *_test.cc)
set(Caffe2_CPU_TEST_SRCS ${Caffe2_CPU_TEST_SRCS} ${tmp})
exclude(Caffe2_CPU_TEST_SRCS "${Caffe2_CPU_TEST_SRCS}" ${Caffe2_GPU_TEST_SRCS})
exclude(Caffe2_CPU_TEST_SRCS "${Caffe2_CPU_TEST_SRCS}" ${Caffe2_HIP_TEST_SRCS})

# ---[ Send the lists to the parent scope.
set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} PARENT_SCOPE)
set(Caffe2_GPU_SRCS ${Caffe2_GPU_SRCS} PARENT_SCOPE)
set(Caffe2_HIP_SRCS ${Caffe2_HIP_SRCS} PARENT_SCOPE)
set(Caffe2_CPU_TEST_SRCS ${Caffe2_CPU_TEST_SRCS} PARENT_SCOPE)
set(Caffe2_GPU_TEST_SRCS ${Caffe2_GPU_TEST_SRCS} PARENT_SCOPE)
set(Caffe2_HIP_TEST_SRCS ${Caffe2_HIP_TEST_SRCS} PARENT_SCOPE)
else()
message(STATUS "Excluding image processing operators due to no opencv")
endif()
@@ -6,13 +6,12 @@
namespace caffe2 {

namespace {
template <typename T>
__global__ void BooleanMaskCopyKernel(
const TIndex numOfOutput,
const TIndex numBytes,
const TIndex* indices,
const T* src,
T* dest) {
const uint8_t* src,
uint8_t* dest) {
for (TIndex i = blockIdx.x; i < numOfOutput; i += gridDim.x) {
const auto srcBase = indices[i] * numBytes;
const auto destBase = i * numBytes;
@@ -81,8 +80,8 @@ class BooleanMaskOp<CUDAContext> final : public Operator<CUDAContext> {
std::vector<TIndex> dims = src.dims();
dims[0] = numOfOutput;
dest->Resize(dims);
auto* destData = (char*)dest->raw_mutable_data(src.meta());
const auto* srcData = (char*)src.raw_data();
auto* destData = (uint8_t*)dest->raw_mutable_data(src.meta());
const auto* srcData = (uint8_t*)src.raw_data();
if (OutputSize() == 2) {
auto* indicesOut = Output(1);
indicesOut->Resize(numOfOutput);
@@ -27,7 +27,7 @@ __global__ void FillValuesKernel(
const size_t itemSize,
const int* indices,
char* const values[],
int valueSizes[],
int* valueSizes,
char* dest) {
CUDA_1D_KERNEL_LOOP(j, numMasks) {
int k = 0;
@@ -1,5 +1,5 @@
#include "caffe2/core/context_gpu.h"
#include "channel_shuffle_op.h"
#include "caffe2/operators/channel_shuffle_op.h"

namespace caffe2 {

@@ -1,5 +1,5 @@
#include "caffe2/core/context_gpu.h"
#include "conv_op_shared.h"
#include "caffe2/operators/conv_op_shared.h"

namespace caffe2 {
@@ -9,7 +9,7 @@ __global__ void CECKernel(
const int N, const float* S, const int* Y, const float margin,
float* output) {
CUDA_1D_KERNEL_LOOP(i, N) {
output[i] = Y[i] == 1 ? (1. - S[i]) : max(0.f, S[i] - margin);
output[i] = Y[i] == 1 ? (1. - S[i]) : fmaxf(0.f, S[i] - margin);
}
}

@@ -1,5 +1,5 @@
#include "caffe2/core/context_gpu.h"
#include "counter_ops.h"
#include "caffe2/operators/counter_ops.h"

namespace caffe2 {
REGISTER_CUDA_OPERATOR(CreateCounter, CreateCounterOp<int64_t, CUDAContext>);
@@ -131,9 +131,9 @@ __global__ void L1DistanceKernel(
for (int i = blockIdx.x; i < N; i += gridDim.x) {
float sum = 0.0f;
for (int j = threadIdx.x; j < D; j += blockDim.x) {
sum +=
abs(convert::To<T, float>(X[i * D + j]) -
convert::To<T, float>(Y[i * D + j]));
sum += fabsf(
convert::To<T, float>(X[i * D + j]) -
convert::To<T, float>(Y[i * D + j]));
}

float aggregate = BlockReduce(temp_storage).Sum(sum);
@@ -395,33 +395,33 @@ bool CosineSimilarityGradientOp<float, CUDAContext>::RunOnDevice() {
context_.cuda_stream()>>>(N, D, X_data, Y_data, xy);
math::Div<float, CUDAContext>(N, dCos_data, xyn, scale, &context_);
// dX
BatchedMul<<<
BatchedMul<float><<<
std::min(N, CAFFE_MAXIMUM_NUM_BLOCKS),
CAFFE_CUDA_NUM_THREADS,
0,
context_.cuda_stream()>>>(N, D, Y_data, scale, dX_data);
Scale2AxpyScale<<<
Scale2AxpyScale<float><<<
std::min(N, CAFFE_MAXIMUM_NUM_BLOCKS),
CAFFE_CUDA_NUM_THREADS,
0,
context_.cuda_stream()>>>(N, scale, xy, xn, axpy_scale);
BatchedAxpy<<<
BatchedAxpy<float><<<
std::min(N, CAFFE_MAXIMUM_NUM_BLOCKS),
CAFFE_CUDA_NUM_THREADS,
0,
context_.cuda_stream()>>>(N, D, axpy_scale, X_data, dX_data);
// dY
BatchedMul<<<
BatchedMul<float><<<
std::min(N, CAFFE_MAXIMUM_NUM_BLOCKS),
CAFFE_CUDA_NUM_THREADS,
0,
context_.cuda_stream()>>>(N, D, X_data, scale, dY_data);
Scale2AxpyScale<<<
Scale2AxpyScale<float><<<
std::min(N, CAFFE_MAXIMUM_NUM_BLOCKS),
CAFFE_CUDA_NUM_THREADS,
0,
context_.cuda_stream()>>>(N, scale, xy, yn, axpy_scale);
BatchedAxpy<<<
BatchedAxpy<float><<<
std::min(N, CAFFE_MAXIMUM_NUM_BLOCKS),
CAFFE_CUDA_NUM_THREADS,
0,
@@ -1,6 +1,6 @@
#include <assert.h>

#include "elementwise_linear_op.h"
#include "caffe2/operators/elementwise_linear_op.h"

#include "caffe2/core/context_gpu.h"
#include "caffe2/operators/operator_fallback_gpu.h"
@@ -8,6 +8,11 @@
#include "caffe2/core/context_gpu.h"
#include "caffe2/utils/conversions.h"

#ifdef __HIPCC__
// rocblas doesn't fully support fp16 yet
#define ROCBLAS_FP16 0
#endif

namespace caffe2 {

REGISTER_CUDA_OPERATOR(
@@ -111,6 +116,9 @@ void device_reduce<float16>(
int N,
Tensor<CUDAContext>* buffer,
CUDAContext* context) {
#if defined(__HIPCC__) && !ROCBLAS_FP16
CAFFE_THROW("HIP rocblas doesn't fully support fp16 device_reduce yet.");
#else
auto buffer_size = 1;

if (buffer->size() != buffer_size) {
@@ -135,6 +143,7 @@ void device_reduce<float16>(
out,
CUDA_R_16F,
CUDA_R_32F));
#endif
}

template <typename T, int BLOCK_THREADS>
@@ -6,7 +6,7 @@
// This is a stand-alone op: Y = gamma * (X - mu) / sig + beta
// ------------------------------------------------------------------

#include "group_norm_op.h"
#include "caffe2/operators/group_norm_op.h"

#include <array>

@@ -2,7 +2,7 @@
#include <cmath>
#include <vector>
#include "caffe2/core/context_gpu.h"
#include "gru_unit_op.h"
#include "caffe2/operators/gru_unit_op.h"

namespace caffe2 {
@ -1,114 +0,0 @@
|
||||
#ifndef CAFFE2_OPERATORS_OPERATOR_FALLBACK_H_
|
||||
#define CAFFE2_OPERATORS_OPERATOR_FALLBACK_H_
|
||||
|
||||
#include "caffe2/core/common.h"
|
||||
#include "caffe2/core/context.h"
|
||||
#include "caffe2/core/hip/context_hip.h"
|
||||
#include "caffe2/core/operator.h"
|
||||
#include "caffe2/proto/caffe2.pb.h"
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
/**
|
||||
* @brief A templated class to allow one to wrap a CPU operator as a CUDA
|
||||
* operator.
|
||||
*
|
||||
* This class can be used when one does not have the CUDA implementation ready
|
||||
* yet for an operator. Essentially, what this op does is to automatically
|
||||
* deal with data copy for you. Plausibly, this causes a lot of overhead and
|
||||
* is not optimal, so you should use this operator mostly for quick prototyping
|
||||
* purpose.
|
||||
*
|
||||
* All the input and output of the original operator should be TensorCPU.
|
||||
*
|
||||
* Example usage: if you have a class MyMagicOp that is CPU based, and you use
|
||||
* the registration code
|
||||
* REGISTER_CPU_OPERATOR(MyMagic, MyMagicOp);
|
||||
* to register the CPU side, you can create its corresponding GPU operator
|
||||
* (with performance hits of course) via
|
||||
* REGISTER_HIP_OPERATOR(MyMagic,
|
||||
* GPUFallbackOp<MyMagicOp>);
|
||||
*
|
||||
* Advanced usage: if you want to have some specific outputs never copied, you
|
||||
* can use the SkipOutputCopy template argument to do that. For example, if
|
||||
* MyMagic produces two outputs and the first output is always going to live on
|
||||
* the CPU, you can do
|
||||
* REGISTER_HIP_OPERATOR(MyMagic,
|
||||
* GPUFallbackOp<MyMagicOp, SkipIndices<0>>);
|
||||
*/
|
||||
template <class CPUOp, typename SkipOutputCopy = SkipIndices<>>
|
||||
class GPUFallbackOp final : public Operator<HIPContext> {
|
||||
public:
|
||||
USE_OPERATOR_FUNCTIONS(HIPContext);
|
||||
GPUFallbackOp(const OperatorDef& def, Workspace* ws)
|
||||
: Operator<HIPContext>(def, ws) {
|
||||
CAFFE_ENFORCE_EQ(def.device_option().device_type(), HIP);
|
||||
OperatorDef base_def_(def);
|
||||
// base_def_ runs on CPU, so we will set its device option to CPU.
|
||||
base_def_.clear_device_option();
|
||||
base_def_.mutable_device_option()->set_device_type(CPU);
|
||||
// Set up the symbols for the local workspace.
|
||||
for (const string& name : def.input()) {
|
||||
local_input_blobs_.push_back(local_ws_.CreateBlob(name));
|
||||
CHECK_NOTNULL(local_input_blobs_.back());
|
||||
}
|
||||
base_op_.reset(new CPUOp(base_def_, &local_ws_));
|
||||
for (const string& name : def.output()) {
|
||||
local_output_blobs_.push_back(local_ws_.GetBlob(name));
|
||||
CHECK_NOTNULL(local_output_blobs_.back());
|
||||
}
|
||||
}
|
||||
|
||||
bool RunOnDevice() override {
|
||||
bool need_sync = false;
|
||||
for (int i = 0; i < InputSize(); ++i) {
|
||||
if (OperatorBase::InputIsType<TensorHIP>(i)) {
|
||||
local_input_blobs_[i]->template GetMutable<TensorCPU>()->CopyFrom(
|
||||
Input(i), &context_);
|
||||
need_sync = true;
|
||||
} else {
|
||||
VLOG(1) << "Input " << i << " is not TensorHIP. Skipping copy.";
|
||||
// Note(jiayq): This removes a const but conceptually
|
||||
// local_input_blobs will only be used as const blob input for the
|
||||
// base op so we are still fine.
|
||||
local_input_blobs_[i]->ShareExternal(
|
||||
const_cast<void*>(OperatorBase::Inputs()[i]->GetRaw()),
|
||||
OperatorBase::Inputs()[i]->meta());
|
||||
}
|
||||
}
|
||||
|
||||
// Sync to make sure copies are done.
|
||||
if (need_sync) {
|
||||
context_.FinishDeviceComputation();
|
||||
}
|
||||
|
||||
if (!base_op_->Run()) {
|
||||
LOG(ERROR) << "Base op run failed in GPUFallbackOp. Def: "
|
||||
<< ProtoDebugString(this->debug_def());
|
||||
return false;
|
||||
}
|
||||
for (int i = 0; i < OutputSize(); ++i) {
|
||||
if (SkipOutputCopy::Contains(i)) {
|
||||
VLOG(1) << "Copy output: index " << i << " skipped.";
|
||||
continue;
|
||||
}
|
||||
CAFFE_ENFORCE(
|
||||
local_output_blobs_[i]->template IsType<TensorCPU>(),
|
||||
"GPU fallback op currently does not support non-TensorCPU "
|
||||
"output type who needs copying.");
|
||||
Output(i)->CopyFrom(
|
||||
local_output_blobs_[i]->template Get<TensorCPU>(), &context_);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
protected:
|
||||
Workspace local_ws_;
|
||||
vector<Blob*> local_input_blobs_;
|
||||
vector<Blob*> local_output_blobs_;
|
||||
std::unique_ptr<CPUOp> base_op_;
|
||||
};
|
||||
|
||||
} // namespace caffe2
|
||||
|
||||
#endif // CAFFE2_OPERATORS_OPERATOR_FALLBACK_H_
|
@ -1,80 +0,0 @@
|
||||
#include <iostream>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "caffe2/core/operator.h"
|
||||
#include "caffe2/operators/hip/operator_fallback_hip.h"
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
class IncrementByOneOp final : public Operator<CPUContext> {
|
||||
public:
|
||||
IncrementByOneOp(const OperatorDef& def, Workspace* ws)
|
||||
: Operator<CPUContext>(def, ws) {}
|
||||
bool RunOnDevice() {
|
||||
const auto& in = Input(0);
|
||||
auto* out = Output(0);
|
||||
out->ResizeLike(in);
|
||||
const float* in_data = in.template data<float>();
|
||||
float* out_data = out->template mutable_data<float>();
|
||||
for (int i = 0; i < in.size(); ++i) {
|
||||
out_data[i] = in_data[i] + 1.f;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
OPERATOR_SCHEMA(IncrementByOne)
|
||||
.NumInputs(1)
|
||||
.NumOutputs(1)
|
||||
.AllowInplace({{0, 0}});
|
||||
|
||||
REGISTER_CPU_OPERATOR(IncrementByOne, IncrementByOneOp);
|
||||
REGISTER_HIP_OPERATOR(IncrementByOne, GPUFallbackOp<IncrementByOneOp>);
|
||||
|
||||
TEST(OperatorFallbackTest, IncrementByOneOp) {
|
||||
OperatorDef op_def = CreateOperatorDef(
|
||||
"IncrementByOne", "", vector<string>{"X"}, vector<string>{"X"});
|
||||
Workspace ws;
|
||||
TensorCPU source_tensor(vector<TIndex>{2, 3});
|
||||
for (int i = 0; i < 6; ++i) {
|
||||
source_tensor.mutable_data<float>()[i] = i;
|
||||
}
|
||||
ws.CreateBlob("X")->GetMutable<TensorCPU>()->CopyFrom(source_tensor);
|
||||
unique_ptr<OperatorBase> op(CreateOperator(op_def, &ws));
|
||||
EXPECT_TRUE(op.get() != nullptr);
|
||||
EXPECT_TRUE(op->Run());
|
||||
const TensorCPU& output = ws.GetBlob("X")->Get<TensorCPU>();
|
||||
EXPECT_EQ(output.ndim(), 2);
|
||||
EXPECT_EQ(output.dim(0), 2);
|
||||
EXPECT_EQ(output.dim(1), 3);
|
||||
for (int i = 0; i < 6; ++i) {
|
||||
EXPECT_EQ(output.data<float>()[i], i + 1);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(OperatorFallbackTest, GPUIncrementByOneOp) {
|
||||
if (!HasHipGPU())
|
||||
return;
|
||||
OperatorDef op_def = CreateOperatorDef(
|
||||
"IncrementByOne", "", vector<string>{"X"}, vector<string>{"X"});
|
||||
op_def.mutable_device_option()->set_device_type(HIP);
|
||||
Workspace ws;
|
||||
TensorCPU source_tensor(vector<TIndex>{2, 3});
|
||||
for (int i = 0; i < 6; ++i) {
|
||||
source_tensor.mutable_data<float>()[i] = i;
|
||||
}
|
||||
ws.CreateBlob("X")->GetMutable<TensorHIP>()->CopyFrom(source_tensor);
|
||||
unique_ptr<OperatorBase> op(CreateOperator(op_def, &ws));
|
||||
EXPECT_TRUE(op.get() != nullptr);
|
||||
EXPECT_TRUE(op->Run());
|
||||
const TensorHIP& output = ws.GetBlob("X")->Get<TensorHIP>();
|
||||
TensorCPU output_cpu(output);
|
||||
EXPECT_EQ(output.ndim(), 2);
|
||||
EXPECT_EQ(output.dim(0), 2);
|
||||
EXPECT_EQ(output.dim(1), 3);
|
||||
for (int i = 0; i < 6; ++i) {
|
||||
EXPECT_EQ(output_cpu.data<float>()[i], i + 1);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace caffe2
|
@ -51,7 +51,7 @@ __global__ void InstanceNormInvStdevKernel(
|
||||
}
|
||||
inv_stdev_data[i] /= dim;
|
||||
inv_stdev_data[i] += epsilon;
|
||||
inv_stdev_data[i] = 1.0 / std::sqrt(inv_stdev_data[i]);
|
||||
inv_stdev_data[i] = 1.0 / sqrtf(inv_stdev_data[i]);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
#include "caffe2/core/context_gpu.h"
|
||||
#include "integral_image_op.h"
|
||||
#include "caffe2/operators/integral_image_op.h"
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
|
@ -116,7 +116,7 @@ bool LayerNormOp<CUDAContext>::DoRunWithType<float>() {
|
||||
mean->CopyFrom(input);
|
||||
mean->Resize(stats_dims);
|
||||
math::Set<float, CUDAContext>(
|
||||
left, std::sqrt(epsilon_), stdev->mutable_data<float>(), &context_);
|
||||
left, sqrtf(epsilon_), stdev->mutable_data<float>(), &context_);
|
||||
} else {
|
||||
// Calculate row-wise means
|
||||
// First stage: sum up feature vectors
|
||||
|
@ -2,7 +2,7 @@
|
||||
#include <cmath>
|
||||
#include <vector>
|
||||
#include "caffe2/core/context_gpu.h"
|
||||
#include "lstm_unit_op.h"
|
||||
#include "caffe2/operators/lstm_unit_op.h"
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
#include "caffe2/operators/max_pool_with_index.h"
|
||||
#include "caffe2/operators/max_pool_with_index_gpu.h"
|
||||
#include "caffe2/utils/conversions.h"
|
||||
|
||||
namespace caffe2 {
|
||||
|
@ -1,5 +1,4 @@
|
||||
#ifndef CAFFE2_OPERATORS_MAX_POOL_WITH_INDEX_H_
|
||||
#define CAFFE2_OPERATORS_MAX_POOL_WITH_INDEX_H_
|
||||
#pragma once
|
||||
|
||||
#include <cfloat>
|
||||
#include "caffe2/core/context.h"
|
||||
@ -45,5 +44,3 @@ class MaxPoolWithIndexGradientOp final : public ConvPoolOpBase<CUDAContext> {
|
||||
};
|
||||
|
||||
}; // namespace caffe2
|
||||
|
||||
#endif // CAFFE2_OPERATORS_MAX_POOL_WITH_INDEX_H_
|
@ -27,7 +27,7 @@ __global__ void NormalizeKernel(
|
||||
float reduce_result = BlockReduce(temp_storage).Sum(sum);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
norm = sqrt(reduce_result);
|
||||
norm = sqrtf(reduce_result);
|
||||
}
|
||||
__syncthreads();
|
||||
if (norm != 0) {
|
||||
@ -66,8 +66,8 @@ __global__ void NormalizeGradientKernel(
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
row_sum = reduce_result;
|
||||
row_norm = sqrt(reduce_norm);
|
||||
row_norm_3 = pow(row_norm, 3);
|
||||
row_norm = sqrtf(reduce_norm);
|
||||
row_norm_3 = powf(row_norm, 3);
|
||||
}
|
||||
__syncthreads();
|
||||
for (int j = threadIdx.x; j < N; j += blockDim.x) {
|
||||
@ -131,7 +131,7 @@ __global__ void NormalizeL1Kernel(
|
||||
__shared__ float norm;
|
||||
for (int j = threadIdx.x; j < m; j += blockDim.x) {
|
||||
const auto x_ij = xData[base + j * sf];
|
||||
sum += abs(x_ij);
|
||||
sum += fabsf(x_ij);
|
||||
}
|
||||
float reduce_result = BlockReduce(temp_storage).Sum(sum);
|
||||
|
||||
|
@ -256,8 +256,8 @@ bool PiecewiseLinearTransformOp<float, CUDAContext>::TransformBinary() {
|
||||
X.data<float>(),
|
||||
Y->mutable_data<float>());
|
||||
} else {
|
||||
// don't want N*M threads, only N*M/2
|
||||
PieceWiseLinearTransformBinaryKernel2<<<
|
||||
// don't want N*M threads, only N*M/2
|
||||
CAFFE_GET_BLOCKS(X.size() / 2),
|
||||
CAFFE_CUDA_NUM_THREADS,
|
||||
0,
|
||||
|
@ -9,6 +9,10 @@ namespace caffe2 {
|
||||
|
||||
namespace {
|
||||
|
||||
#ifdef __HIPCC__
|
||||
typedef __half2 half2;
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
__global__ void ReluCUDAKernel(const int N, const T* X, T* Y) {
|
||||
CUDA_1D_KERNEL_LOOP(i, N) {
|
||||
|
@ -1,6 +1,6 @@
|
||||
#include "caffe2/core/context_gpu.h"
|
||||
#include "caffe2/utils/math.h"
|
||||
#include "resize_op.h"
|
||||
#include "caffe2/operators/resize_op.h"
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
#include "caffe2/core/context_gpu.h"
|
||||
#include "reverse_packed_segs_op.h"
|
||||
#include "caffe2/operators/reverse_packed_segs_op.h"
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
#include "roi_align_gradient_op.h"
|
||||
#include "caffe2/operators/roi_align_gradient_op.h"
|
||||
|
||||
#include <stdio.h>
|
||||
#include <cfloat>
|
||||
|
@ -1,4 +1,4 @@
|
||||
#include "roi_align_op.h"
|
||||
#include "caffe2/operators/roi_align_op.h"
|
||||
|
||||
#include <stdio.h>
|
||||
#include <cfloat>
|
||||
|
@ -1,5 +1,5 @@
|
||||
#include "caffe2/utils/eigen_utils.h"
|
||||
#include "roi_align_op.h"
|
||||
#include "caffe2/operators/roi_align_op.h"
|
||||
|
||||
#include "caffe2/core/context_gpu.h"
|
||||
#include "caffe2/core/flags.h"
|
||||
|
@ -3,7 +3,7 @@
|
||||
#endif // _MSC_VER
|
||||
#include <cmath>
|
||||
|
||||
#include "roi_align_rotated_gradient_op.h"
|
||||
#include "caffe2/operators/roi_align_rotated_gradient_op.h"
|
||||
|
||||
#include <stdio.h>
|
||||
#include <cfloat>
|
||||
|
@ -3,7 +3,7 @@
|
||||
#endif // _MSC_VER
|
||||
#include <cmath>
|
||||
|
||||
#include "roi_align_rotated_op.h"
|
||||
#include "caffe2/operators/roi_align_rotated_op.h"
|
||||
|
||||
#include <stdio.h>
|
||||
#include <cfloat>
|
||||
|
@ -1,7 +1,7 @@
|
||||
#include <cfloat>
|
||||
|
||||
#include "caffe2/core/context_gpu.h"
|
||||
#include "roi_pool_op.h"
|
||||
#include "caffe2/operators/roi_pool_op.h"
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
|
@ -33,7 +33,7 @@ bool SeluOp<float, CUDAContext>::RunOnDevice() {
|
||||
auto* Y = Output(0);
|
||||
CAFFE_ENFORCE_GT(X.size(), 0);
|
||||
Y->ResizeLike(X);
|
||||
SeluKernel<<<
|
||||
SeluKernel<float><<<
|
||||
CAFFE_GET_BLOCKS(X.size()),
|
||||
CAFFE_CUDA_NUM_THREADS,
|
||||
0,
|
||||
@ -50,7 +50,7 @@ bool SeluGradientOp<float, CUDAContext>::RunOnDevice() {
|
||||
CAFFE_ENFORCE_GT(Y.size(), 0);
|
||||
CAFFE_ENFORCE_EQ(dY.size(), Y.size());
|
||||
dX->ResizeLike(Y);
|
||||
SeluGradientKernel<<<
|
||||
SeluGradientKernel<float><<<
|
||||
CAFFE_GET_BLOCKS(Y.size()),
|
||||
CAFFE_CUDA_NUM_THREADS,
|
||||
0,
|
||||
|
@ -2,9 +2,9 @@
|
||||
#include <cub/block/block_reduce.cuh>
|
||||
|
||||
#include "caffe2/core/context_gpu.h"
|
||||
#include "softmax_op.h"
|
||||
#include "softmax_with_loss_op.h"
|
||||
#include "spatial_softmax_with_loss_op.h"
|
||||
#include "caffe2/operators/softmax_op.h"
|
||||
#include "caffe2/operators/softmax_with_loss_op.h"
|
||||
#include "caffe2/operators/spatial_softmax_with_loss_op.h"
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
@ -70,7 +70,7 @@ __global__ void ProbCrossEntropyKernel(
|
||||
int idx = i * D + j;
|
||||
CUDA_KERNEL_ASSERT(labeldata[idx] >= 0);
|
||||
total_prob += labeldata[idx];
|
||||
sum += -logf(max(Pdata[idx], FLT_MIN)) * labeldata[idx] * weight;
|
||||
sum += -logf(fmaxf(Pdata[idx], FLT_MIN)) * labeldata[idx] * weight;
|
||||
}
|
||||
float tot = BlockReduce(temp_storage).Sum(sum);
|
||||
__syncthreads();
|
||||
@ -78,7 +78,7 @@ __global__ void ProbCrossEntropyKernel(
|
||||
if (threadIdx.x == 0) {
|
||||
Ydata[i] = tot;
|
||||
// Sanity check
|
||||
CUDA_KERNEL_ASSERT(abs(1.0 - total_prob_sum) < 1e-5f);
|
||||
CUDA_KERNEL_ASSERT(fabsf(1.0 - total_prob_sum) < 1e-5f);
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
@ -118,14 +118,14 @@ __global__ void SpatialSoftmaxKernel(
|
||||
float max_val = -FLT_MAX;
|
||||
for(int c = 0; c < D; ++c) {
|
||||
int idx = i * (H * W * D) + c * (H * W) + y * W + x;
|
||||
max_val = max(max_val, Xdata[idx]);
|
||||
max_val = fmaxf(max_val, Xdata[idx]);
|
||||
}
|
||||
|
||||
// Exponentiate
|
||||
float expsum = 0.0f;
|
||||
for(int c = 0; c < D; ++c) {
|
||||
int idx = i * (H * W * D) + c * (H * W) + y * W + x;
|
||||
float expx = exp(Xdata[idx] - max_val);
|
||||
float expx = expf(Xdata[idx] - max_val);
|
||||
Pdata[idx] = expx;
|
||||
expsum += expx;
|
||||
}
|
||||
@ -160,7 +160,7 @@ __global__ void SpatialCrossEntropyLossKernel(
|
||||
if (label != DONTCARE) {
|
||||
CUDA_KERNEL_ASSERT(label >= 0 && label < D);
|
||||
float weight = (weights == NULL ? 1.0 : weights[index]);
|
||||
loss_data[index] = -log(max(
|
||||
loss_data[index] = -logf(fmaxf(
|
||||
Pdata[i * W * H * D + label * W * H + y * W + x], 1e-20f)) * weight;
|
||||
weight_data[index] = weight;
|
||||
} else {
|
||||
@ -213,7 +213,7 @@ __global__ void SoftmaxNormalizeLogsKernel(
|
||||
float* out_log) {
|
||||
CUDA_1D_KERNEL_LOOP(index, nthreads) {
|
||||
int n = index / D;
|
||||
out_log[index] = logits[index] - rowmax[n] - logf(max(scales[n], FLT_MIN));
|
||||
out_log[index] = logits[index] - rowmax[n] - logf(fmaxf(scales[n], FLT_MIN));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -26,7 +26,7 @@ bool SoftplusOp<float, CUDAContext>::RunOnDevice() {
|
||||
auto* Y = Output(0);
|
||||
DCHECK_GT(X.size(), 0);
|
||||
Y->ResizeLike(X);
|
||||
SoftplusKernel<<<
|
||||
SoftplusKernel<float><<<
|
||||
CAFFE_GET_BLOCKS(X.size()),
|
||||
CAFFE_CUDA_NUM_THREADS,
|
||||
0,
|
||||
@ -43,7 +43,7 @@ bool SoftplusGradientOp<float, CUDAContext>::RunOnDevice() {
|
||||
DCHECK_GT(Y.size(), 0);
|
||||
DCHECK_EQ(dY.size(), Y.size());
|
||||
dX->ResizeLike(Y);
|
||||
SoftplusGradientKernel<<<
|
||||
SoftplusGradientKernel<float><<<
|
||||
CAFFE_GET_BLOCKS(Y.size()),
|
||||
CAFFE_CUDA_NUM_THREADS,
|
||||
0,
|
||||
|
@ -14,13 +14,26 @@ inline __host__ __device__ T SquareCUDA(const T x) {
|
||||
return x * x;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline __device__ T typed_abs(T x);
|
||||
|
||||
template <>
|
||||
inline __device__ float typed_abs(float x) {
|
||||
return fabsf(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __device__ double typed_abs(double x) {
|
||||
return fabs(x);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void SoftsignCUDAKernel(const int N, const T* X, T* Y) {
|
||||
CUDA_1D_KERNEL_LOOP(i, N) {
|
||||
#if __CUDA_ARCH__ >= 350
|
||||
Y[i] = __ldg(X + i) / (T(1) + abs(__ldg(X + i)));
|
||||
Y[i] = __ldg(X + i) / (T(1) + typed_abs(__ldg(X + i)));
|
||||
#else
|
||||
Y[i] = X[i] / (T(1) + abs(X[i]));
|
||||
Y[i] = X[i] / (T(1) + typed_abs(X[i]));
|
||||
#endif
|
||||
}
|
||||
}
|
||||
@ -30,9 +43,9 @@ __global__ void
|
||||
SoftsignGradientCUDAKernel(const int N, const T* dY, const T* X, T* dX) {
|
||||
CUDA_1D_KERNEL_LOOP(i, N) {
|
||||
#if __CUDA_ARCH__ >= 350
|
||||
dX[i] = __ldg(dY + i) / SquareCUDA(T(1) + abs(__ldg(X + i)));
|
||||
dX[i] = __ldg(dY + i) / SquareCUDA(T(1) + typed_abs(__ldg(X + i)));
|
||||
#else
|
||||
dX[i] = dY[i] / SquareCUDA(T(1) + abs(X[i]));
|
||||
dX[i] = dY[i] / SquareCUDA(T(1) + typed_abs(X[i]));
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
#include "sparse_to_dense_op.h"
|
||||
#include "caffe2/operators/sparse_to_dense_op.h"
|
||||
|
||||
#include "caffe2/core/common_gpu.h"
|
||||
#include "caffe2/core/context_gpu.h"
|
||||
|
@ -1,9 +1,3 @@
|
||||
#include <math.h>
|
||||
#include <cfloat>
|
||||
// TODO(jamesreed): I would use <cmath> here but std::isnan
|
||||
// and std::isinf are declared constexpr there and the nvidia
|
||||
// compiler throws an error because of it
|
||||
|
||||
#include "caffe2/core/context_gpu.h"
|
||||
#include "caffe2/operators/flatten_op.h"
|
||||
#include "caffe2/operators/minmax_ops.h"
|
||||
@ -169,7 +163,7 @@ bool NanCheckOp<CUDAContext>::RunOnDevice() {
|
||||
std::cerr << "NaN idxs:" << std::endl;
|
||||
auto* cpu_X_data = cpu_X.data<float>();
|
||||
for (size_t i = 0; i < cpu_X.size(); ++i) {
|
||||
if (isnan(cpu_X_data[i]) || isinf(cpu_X_data[i])) {
|
||||
if (std::isnan(cpu_X_data[i]) || std::isinf(cpu_X_data[i])) {
|
||||
std::cerr << i << " ";
|
||||
}
|
||||
}
|
||||
@ -404,7 +398,7 @@ bool ScatterWeightedSumOp<float, CUDAContext>::DoRunWithType() {
|
||||
TIndex K = indices.size();
|
||||
TIndex block_size = M / N;
|
||||
|
||||
T* data = output->template mutable_data<T>();
|
||||
float* data = output->template mutable_data<float>();
|
||||
|
||||
// In order to have all device pointers of x_i (and weight_i similarly)
|
||||
// consecutively in device memory, copy pointers to a host vector and then
|
||||
|
@@ -9,33 +9,31 @@ from caffe2.python import extension_loader
# attempt to load the cpu version. The cpu backend is the minimum required, so
# if that still fails, we will exit loud.
with extension_loader.DlopenGuard():
has_hip_support = False
has_gpu_support = False

try:
from caffe2.python.caffe2_pybind11_state_gpu import * # noqa
if num_cuda_devices(): # noqa
has_gpu_support = True
else:
has_gpu_support = False
except ImportError as e:
has_gpu_support = False
except ImportError as gpu_e:
logging.info('Failed to import cuda module: {}'.format(gpu_e))
try:
from caffe2.python.caffe2_pybind11_state_hip import * # noqa
if num_hip_devices():
has_hip_support = True
logging.info('This caffe2 python run has AMD GPU support!')
else:
has_hip_support = False
except ImportError as e:
logging.info('Failed to import AMD hip module: {}'.format(e))
except ImportError as hip_e:
logging.info('Failed to import AMD hip module: {}'.format(hip_e))

logging.warning(
'This caffe2 python run does not have GPU support. '
'Will run in CPU only mode.')
logging.warning('Debug message: {0}'.format(str(e)))
try:
from caffe2.python.caffe2_pybind11_state import * # noqa
except ImportError as e:
except ImportError as cpu_e:
logging.critical(
'Cannot load caffe2.python. Error: {0}'.format(str(e)))
'Cannot load caffe2.python. Error: {0}'.format(str(cpu_e)))
sys.exit(1)

# libcaffe2_python contains a global Workspace that we need to properly delete
@@ -251,7 +251,8 @@ def tensors1d(n, min_len=1, max_len=64, dtype=np.float32, elements=None):

cpu_do = caffe2_pb2.DeviceOption()
gpu_do = caffe2_pb2.DeviceOption(device_type=caffe2_pb2.CUDA)
device_options = [cpu_do] + ([gpu_do] if workspace.has_gpu_support else [])
hip_do = caffe2_pb2.DeviceOption(device_type=caffe2_pb2.HIP)
device_options = [cpu_do] + ([gpu_do] if workspace.has_gpu_support else []) + ([hip_do] if workspace.has_hip_support else [])
# Include device option for each GPU
expanded_device_options = [cpu_do] + (
[caffe2_pb2.DeviceOption(device_type=caffe2_pb2.CUDA, cuda_gpu_id=i)
@@ -42,6 +42,7 @@ operator_tracebacks = defaultdict(dict)

is_asan = C.is_asan
has_gpu_support = C.has_gpu_support
has_hip_support = C.has_hip_support
if has_gpu_support:
NumCudaDevices = C.num_cuda_devices
GetCUDAVersion = C.get_cuda_version
@@ -9,6 +9,14 @@ set(Caffe2_GPU_SRCS ${Caffe2_GPU_SRCS} ${tmp})
file(GLOB tmp *_test.cc)
exclude(Caffe2_GPU_SRCS "${Caffe2_GPU_SRCS}" ${tmp})

# ---[ HIP files
# ------[ general GPU
file(GLOB_RECURSE tmp *_hip.cc)
set(Caffe2_HIP_SRCS ${Caffe2_HIP_SRCS} ${tmp})
# exclude test files
file(GLOB_RECURSE tmp *_test.cc)
exclude(Caffe2_HIP_SRCS "${Caffe2_HIP_SRCS}" ${tmp})

# ---[ CPU files.
file(GLOB tmp *.cc)
set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${tmp})
@@ -16,18 +24,26 @@ set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${tmp})
file(GLOB tmp *_test.cc)
exclude(Caffe2_CPU_SRCS "${Caffe2_CPU_SRCS}" ${tmp})
exclude(Caffe2_CPU_SRCS "${Caffe2_CPU_SRCS}" ${Caffe2_GPU_SRCS})
exclude(Caffe2_CPU_SRCS "${Caffe2_CPU_SRCS}" ${Caffe2_HIP_SRCS})

# ---[ GPU test files
file(GLOB tmp *_gpu_test.cc)
set(Caffe2_GPU_TEST_SRCS ${Caffe2_GPU_TEST_SRCS} ${tmp})

# ---[ HIP test files
file(GLOB_RECURSE tmp *_hip_test.cc)
set(Caffe2_HIP_TEST_SRCS ${Caffe2_HIP_TEST_SRCS} ${tmp})

# ---[ CPU test files
file(GLOB tmp *_test.cc)
set(Caffe2_CPU_TEST_SRCS ${Caffe2_CPU_TEST_SRCS} ${tmp})
exclude(Caffe2_CPU_TEST_SRCS "${Caffe2_CPU_TEST_SRCS}" ${Caffe2_GPU_TEST_SRCS})
exclude(Caffe2_CPU_TEST_SRCS "${Caffe2_CPU_TEST_SRCS}" ${Caffe2_HIP_TEST_SRCS})

# ---[ Send the lists to the parent scope.
set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} PARENT_SCOPE)
set(Caffe2_GPU_SRCS ${Caffe2_GPU_SRCS} PARENT_SCOPE)
set(Caffe2_HIP_SRCS ${Caffe2_HIP_SRCS} PARENT_SCOPE)
set(Caffe2_CPU_TEST_SRCS ${Caffe2_CPU_TEST_SRCS} PARENT_SCOPE)
set(Caffe2_GPU_TEST_SRCS ${Caffe2_GPU_TEST_SRCS} PARENT_SCOPE)
set(Caffe2_HIP_TEST_SRCS ${Caffe2_HIP_TEST_SRCS} PARENT_SCOPE)
@@ -1,5 +1,5 @@
#include <cub/block/block_reduce.cuh>
#include "adagrad_op.h"
#include "caffe2/sgd/adagrad_op.h"
#include "caffe2/core/common_gpu.h"
#include "caffe2/core/context_gpu.h"
#include "caffe2/utils/mixed_utils.h"
@@ -19,7 +19,7 @@ __global__ void AdagradUpdate(
CUDA_1D_KERNEL_LOOP(i, N) {
float gi = g[i];
float hi = nh[i] = decay * h[i] + gi * gi;
nw[i] = w[i] + lr[0] * gi / (std::sqrt(hi) + epsilon);
nw[i] = w[i] + lr[0] * gi / (sqrtf(hi) + epsilon);
}
}

@@ -63,7 +63,7 @@ __global__ void SparseAdagradKernel(
mixed_add(grad[gradIdx] * grad[gradIdx], param_mom[paramIdx]);
mixed_store(&mom_new, &(param_mom[paramIdx]));
float param_new = mixed_add(
LR * grad[gradIdx] / (sqrt(mom_new) + epsilon), param[paramIdx]);
LR * grad[gradIdx] / (sqrtf(mom_new) + epsilon), param[paramIdx]);
mixed_store(&param_new, &(param[paramIdx]));
}
}
@@ -107,7 +107,7 @@ __global__ void RowWiseSparseAdagradKernel(
}
__syncthreads();
// update param
float step = lr[0] / (std::sqrt(param_mom[index]) + epsilon);
float step = lr[0] / (sqrtf(param_mom[index]) + epsilon);
for (int j = threadIdx.x; j < N; j += blockDim.x) {
param[index * N + j] = param[index * N + j] + grad[i * N + j] * step;
}
@@ -1,4 +1,4 @@
#include "adam_op.h"
#include "caffe2/sgd/adam_op.h"
#include "caffe2/core/common_gpu.h"
#include "caffe2/core/context_gpu.h"

@@ -21,7 +21,7 @@ __global__ void AdamUpdate(
float gi = g[i];
float mi = nm[i] = m[i] * beta1 + gi * (1 - beta1);
float vi = nv[i] = v[i] * beta2 + gi * gi * (1 - beta2);
ng[i] = lr[0] * correction * mi / (std::sqrt(vi) + eps_hat);
ng[i] = lr[0] * correction * mi / (sqrtf(vi) + eps_hat);
}
}

@@ -66,7 +66,7 @@ __global__ void AdamCompute(
float gi = g[i];
float mi = nm[i] = m[i] * beta1 + gi * (1 - beta1);
float vi = nv[i] = v[i] * beta2 + gi * gi * (1 - beta2);
float ng = lr[0] * correction * mi / (std::sqrt(vi) + eps_hat);
float ng = lr[0] * correction * mi / (sqrtf(vi) + eps_hat);
nw[i] = w[i] + ng;
}
}
@@ -130,7 +130,7 @@ bool SparseAdamOp<float, CUDAContext>::DoRunWithType() {
auto grad_slice_sz = Input(GRAD).size_from_dim(Input(INDICES).ndim());
const auto iter =
OperatorBase::Input<TensorCPU>(ITER).template data<int64_t>()[0];
const float correction = std::sqrt(1.0f - std::pow(beta2_, iter + 1)) /
const float correction = sqrtf(1.0f - std::pow(beta2_, iter + 1)) /
(1.0f - std::pow(beta1_, iter + 1));

SparseAdamKernel<SIndex>
@@ -1,10 +1,16 @@
#include "caffe2/core/common_gpu.h"
#include "caffe2/core/context_gpu.h"

#include "fp16_momentum_sgd_op.h"
#include "caffe2/sgd/fp16_momentum_sgd_op.h"

namespace caffe2 {
namespace {

#ifdef __HIPCC__
typedef __half half;
typedef __half2 half2;
#endif

__global__ void FP16MomentumSGDKernel(
int N,
const half2* g,
@@ -1,7 +1,7 @@
#include "caffe2/core/common_gpu.h"
#include "caffe2/core/context_gpu.h"

#include "fp32_momentum_sgd_op.h"
#include "caffe2/sgd/fp32_momentum_sgd_op.h"

namespace caffe2 {
namespace {

@@ -1,4 +1,4 @@
#include "momentum_sgd_op.h"
#include "caffe2/sgd/momentum_sgd_op.h"
#include "caffe2/core/common_gpu.h"
#include "caffe2/core/context_gpu.h"
@@ -1,4 +1,4 @@
#include "rmsprop_op.h"
#include "caffe2/sgd/rmsprop_op.h"
#include "caffe2/core/common_gpu.h"
#include "caffe2/core/context_gpu.h"

@@ -21,7 +21,7 @@ __global__ void RmsPropUpdate(
nms[i] = ms[i] + (1.0f - decay) * (g[i] * g[i] - ms[i]);
// Update momentum estimate
nmom[i] =
mom[i] * momentum + lr[0] * g[i] / std::sqrt(epsilon + nms[i]);
mom[i] * momentum + lr[0] * g[i] / sqrtf(epsilon + nms[i]);
// New gradient is the momentum
ng[i] = nmom[i];
}
(File diff suppressed because it is too large.)

caffe2/utils/mixed_utils_hip.h (new file, 116 lines)
@ -0,0 +1,116 @@
|
||||
// Copyright 2004-present Facebook. All Rights Reserved.
|
||||
#ifndef CAFFE2_UTILS_MIXED_UTILS_HIP_H
|
||||
#define CAFFE2_UTILS_MIXED_UTILS_HIP_H
|
||||
|
||||
#include "caffe2/core/hip/common_hip.h"
|
||||
#include "caffe2/core/hip/context_hip.h"
|
||||
|
||||
// define functions to allow add/mult/store operaions for input/output with
|
||||
// mixed precisions.
|
||||
namespace caffe2 {
|
||||
|
||||
// functions that will only be triggered when there is no spcialized version
|
||||
// supported
|
||||
template <typename T, typename T2>
|
||||
inline __device__ T mixed_mult(T data1, T2 data2)
|
||||
{
|
||||
return data1 * data2;
|
||||
};
|
||||
|
||||
template <typename T, typename T2>
|
||||
inline __device__ T mixed_add(T data1, T2 data2)
|
||||
{
|
||||
return data1 + data2;
|
||||
};
|
||||
|
||||
template <typename TIN, typename TOUT>
|
||||
inline __device__ void mixed_store(TIN* data_in, TOUT* data_out)
|
||||
{
|
||||
*data_out = *data_in;
|
||||
return;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
inline __device__ void mixed_store(T* data_in, T* data_out)
|
||||
{
|
||||
*data_out = *data_in;
|
||||
return;
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __device__ float mixed_mult(float data1, const float data2)
|
||||
{
|
||||
return data1 * data2;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __device__ float mixed_mult(float data1, const half data2)
|
||||
{
|
||||
return data1 * __half2float(data2);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __device__ float mixed_mult(float data1, float16 data2)
|
||||
{
|
||||
half* data2_half = reinterpret_cast<half*>(&data2);
|
||||
return data1 * __half2float(*data2_half);
|
||||
}
|
||||
template <>
|
||||
inline __device__ float mixed_add(float data1, const float data2)
|
||||
{
|
||||
return data1 + data2;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __device__ float mixed_add(float data1, const half data2)
|
||||
{
|
||||
return data1 + __half2float(data2);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __device__ float mixed_add(float data1, float16 data2)
|
||||
{
|
||||
half* data2_half = reinterpret_cast<half*>(&data2);
|
||||
return data1 + __half2float(*data2_half);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __device__ void mixed_store(float* data_in, float* data_out)
|
||||
{
|
||||
*data_out = *data_in;
|
||||
return;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __device__ void mixed_store(half* data_in, float* data_out)
|
||||
{
|
||||
*data_out = __half2float(*data_in);
|
||||
return;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __device__ void mixed_store(float16* data_in, float* data_out)
|
||||
{
|
||||
half* data_in_half = reinterpret_cast<half*>(data_in);
|
||||
*data_out = __half2float(*data_in_half);
|
||||
return;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __device__ void mixed_store(float* data_in, float16* data_out)
|
||||
{
|
||||
half data_in_half = __float2half(*data_in);
|
||||
float16* data_in_float16 = reinterpret_cast<float16*>(&data_in_half);
|
||||
*data_out = *data_in_float16;
|
||||
return;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __device__ void mixed_store(float* data_in, half* data_out)
|
||||
{
|
||||
half data_in_half = __float2half(*data_in);
|
||||
*data_out = data_in_half;
|
||||
return;
|
||||
}
|
||||
} // namespace caffe2
|
||||
#endif // for CAFFE2_UTILS_MIXED_UTILS_HIP_H
|
tools/amd_build/build_caffe2_amd.py (new executable file, 47 lines)
@@ -0,0 +1,47 @@
#!/usr/bin/env python

import os
import sys
import subprocess

amd_build_dir = os.path.dirname(os.path.realpath(__file__))
proj_dir = os.path.join(os.path.dirname(os.path.dirname(amd_build_dir)))

includes = [
"caffe2/operators/*",
"caffe2/sgd/*",
"caffe2/image/*",
"caffe2/transforms/*",
"caffe2/video/*",
"caffe2/distributed/*",
]

ignores = [
"caffe2/operators/depthwise_3x3_conv_op.cu",
"caffe2/operators/depthwise_3x3_conv_op_cudnn.cu",
"caffe2/operators/top_k.cu",
"caffe2/operators/top_k_radix_selection.cuh",
"caffe2/operators/top_k_heap_selection.cuh",
"caffe2/operators/pool_op_cudnn.cu",
"caffe2/operators/roi_align_op_gpu_test.cc",
# elementwise ops test is failing
"caffe2/operators/elementwise_op_gpu_test.cc",
'**/hip/**',
]

file_extensions = ['.cc', '.cu', '.h', '.cuh']

# Execute the Hipify Script.
args = [
"--project-directory", proj_dir,
"--output-directory", proj_dir,
"--includes"] + includes + \
["--extensions"] + file_extensions + \
["--ignores"] + ignores + \
["--hipify_caffe2", "True"] + \
["--add-static-casts", "True"]

subprocess.check_call([
sys.executable,
os.path.join(amd_build_dir, "pyHIPIFY", "hipify-python.py"),
] + args)
@@ -8,9 +8,9 @@ from functools import reduce

amd_build_dir = os.path.dirname(os.path.realpath(__file__))
proj_dir = os.path.dirname(os.path.dirname(amd_build_dir))
include_dirs = [
"aten",
"torch"
includes = [
"aten/*",
"torch/*"
]

# List of operators currently disabled
@@ -63,9 +63,12 @@ for root, _directories, files in os.walk(os.path.join(proj_dir, "torch")):
# Execute the Hipify Script.
args = (["--project-directory", proj_dir] +
["--output-directory", proj_dir] +
["--include-dirs"] + include_dirs +
["--includes"] + includes +
["--yaml-settings", yaml_file] +
["--add-static-casts", "True"] +
["--show-progress", "False"])

os.execv(os.path.join(amd_build_dir, "pyHIPIFY", "hipify-python.py"), ['python'] + args)
subprocess.check_call([
sys.executable,
os.path.join(amd_build_dir, "pyHIPIFY", "hipify-python.py")
] + args)
@@ -53,3 +53,4 @@ API_LAST = 42

HIP_UNSUPPORTED = 43
API_PYTORCH = 1337
API_CAFFE2 = 1338
@@ -2106,5 +2106,29 @@ PYTORCH_SPECIFIC_MAPPINGS = {
"define MAX_NUM_BLOCKS 200": ("define MAX_NUM_BLOCKS 64", API_PYTORCH),
}

CAFFE2_SPECIFIC_MAPPINGS = {
"CUDA" :("HIP", API_CAFFE2),
"REGISTER_CUDA_OPERATOR" : ("REGISTER_HIP_OPERATOR", API_CAFFE2),
"cuda_stream" : ("hip_stream", API_CAFFE2),
"context_gpu" : ("hip/context_hip", API_CAFFE2),
"common_gpu" : ("hip/common_hip", API_CAFFE2),
"mixed_utils" : ("hip/mixed_utils_hip", API_CAFFE2),
"operator_fallback_gpu" : ("hip/operator_fallback_hip", API_CAFFE2),
"recurrent_network_executor_gpu" : ("hip/recurrent_network_executor_hip", API_CAFFE2),
"max_pool_with_index_gpu": ("hip/max_pool_with_index_hip", API_CAFFE2),
"CUDA_1D_KERNEL_LOOP" : ("HIP_1D_KERNEL_LOOP", API_CAFFE2),
"CUDAContext" : ("HIPContext", API_CAFFE2),
"CAFFE_CUDA_NUM_THREADS" : ("CAFFE_HIP_NUM_THREADS", API_CAFFE2),
"HasCudaGPU" : ("HasHipGPU", API_CAFFE2),
"__expf" : ("expf", API_CAFFE2),
"CUBLAS_ENFORCE" : ("ROCBLAS_ENFORCE", API_CAFFE2),
"cublas_handle" : ("rocblas_handle", API_CAFFE2),
"CURAND_ENFORCE" :("HIPRAND_ENFORCE", API_CAFFE2),
"curandGenerateUniform" : ("hiprandGenerateUniform", API_CAFFE2),
"curand_generator" : ("hiprand_generator", API_CAFFE2),
"set_cuda_gpu_id" : ("set_hip_gpu_id", API_CAFFE2),
"CaffeCudaGetDevice" : ("CaffeHipGetDevice", API_CAFFE2),
}

CUDA_TO_HIP_MAPPINGS = [CUDA_TYPE_NAME_MAP, CUDA_IDENTIFIER_MAP,
CUDA_INCLUDE_MAP, CUDA_SPARSE_MAP, PYTORCH_SPECIFIC_MAPPINGS]
CUDA_INCLUDE_MAP, CUDA_SPARSE_MAP, PYTORCH_SPECIFIC_MAPPINGS, CAFFE2_SPECIFIC_MAPPINGS]
@ -26,11 +26,13 @@
|
||||
|
||||
import argparse
|
||||
import constants
|
||||
import fnmatch
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
import os
|
||||
import yaml
|
||||
import ast
|
||||
|
||||
from functools import reduce
|
||||
from enum import Enum
|
||||
@ -40,6 +42,7 @@ from cuda_to_hip_mappings import CUDA_TO_HIP_MAPPINGS
|
||||
"""This dictionary provides the mapping from PyTorch kernel template types
|
||||
to their actual types."""
|
||||
PYTORCH_TEMPLATE_MAP = {"Dtype": "real", "T": "real"}
|
||||
CAFFE2_TEMPLATE_MAP = {}
|
||||
|
||||
|
||||
def openf(filename, mode):
|
||||
@ -210,72 +213,47 @@ def update_progress_bar(total, progress):
|
||||
sys.stderr.flush()
|
||||
|
||||
|
||||
def filename_ends_with_extension(filename, extensions):
|
||||
"""Helper method to see if filename ends with certain extension"""
|
||||
for ext in extensions:
|
||||
if filename.endswith("." + ext):
|
||||
return True
|
||||
def matched_files_iter(root_path, includes=('*',), ignores=(), extensions=(), hipify_caffe2=False):
|
||||
def _fnmatch(filepath, patterns):
|
||||
return any(fnmatch.fnmatch(filepath, pattern) for pattern in patterns)
|
||||
|
||||
return False
|
||||
def match_extensions(filename):
|
||||
"""Helper method to see if filename ends with certain extension"""
|
||||
return os.path.splitext(filename)[1] in extensions
|
||||
|
||||
for (dirpath, _, filenames) in os.walk(root_path, topdown=True):
|
||||
for fn in filenames:
|
||||
filepath = os.path.join(dirpath, fn)
|
||||
rel_filepath = os.path.relpath(filepath, root_path)
|
||||
if _fnmatch(rel_filepath, includes) and (not _fnmatch(rel_filepath, ignores)) and match_extensions(fn):
|
||||
if hipify_caffe2 and not is_caffe2_gpu_file(filepath):
|
||||
continue
|
||||
|
||||
yield filepath
|
||||
|
||||
|
||||
def inside_included_directories(dirpath, rootpath, include_dirs):
|
||||
"""Helper method to see if filename within included directories"""
|
||||
for included_directory in include_dirs:
|
||||
if re.match(r'{0}\b'.format(os.path.join(rootpath, included_directory)), dirpath):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def walk_over_directory(rootpath, extensions, show_detailed=False, include_dirs=None, show_progress=True):
|
||||
def preprocess(all_files, show_detailed=False, show_progress=True, hipify_caffe2=False):
|
||||
"""
|
||||
Recursively walk over the directory and call preprocessor on selected files.
|
||||
Call preprocessor on selected files.
|
||||
|
||||
Arguments)
|
||||
extensions - A plist of file extensions ['cu', 'cuh', ..]
|
||||
|
||||
include_dirs - Directories under the rootpath that should be included in the walk.
|
||||
|
||||
show_detailed - Show a detailed summary of the transpilation process.
|
||||
"""
|
||||
|
||||
# Default argument for excluded directories.
|
||||
if include_dirs is None:
|
||||
include_dirs = []
|
||||
|
||||
# Compute the total number of files to be traversed.
|
||||
total_files = 0
|
||||
for (dirpath, _dirnames, filenames) in os.walk(rootpath):
|
||||
if inside_included_directories(dirpath, rootpath, include_dirs):
|
||||
for filename in filenames:
|
||||
total_files += filename_ends_with_extension(filename, extensions)
|
||||
|
||||
current_file = 0
|
||||
total_count = len(all_files)
|
||||
finished_count = 0
|
||||
|
||||
# Preprocessing statistics.
|
||||
stats = {"unsupported_calls": [], "kernel_launches": []}
|
||||
|
||||
# Begin traversing the files.
|
||||
for (dirpath, _dirnames, filenames) in os.walk(rootpath, topdown=True):
|
||||
# Check if file ends with a valid extensions
|
||||
if not inside_included_directories(dirpath, rootpath, include_dirs):
|
||||
continue
|
||||
|
||||
for filename in filenames:
|
||||
if filename_ends_with_extension(filename, extensions):
|
||||
# Construct the file's full path
|
||||
filepath = os.sep.join([dirpath, filename])
|
||||
|
||||
# Execute the preprocessor on the specified file.
|
||||
preprocessor(filepath, stats)
|
||||
|
||||
# Update the progress
|
||||
if show_progress:
|
||||
print(os.path.join(dirpath, filename))
|
||||
update_progress_bar(total_files, current_file)
|
||||
|
||||
current_file += 1
|
||||
for filepath in all_files:
|
||||
preprocessor(filepath, stats, hipify_caffe2)
|
||||
# Update the progress
|
||||
if show_progress:
|
||||
print(filepath)
|
||||
update_progress_bar(total_count, finished_count)
|
||||
finished_count += 1
|
||||
|
||||
print(bcolors.OKGREEN + "Successfully preprocessed all matching files." + bcolors.ENDC)
|
||||
|
||||
@ -297,6 +275,41 @@ def compute_stats(stats):
|
||||
print("\nTotal number of replaced kernel launches: {0:d}".format(len(stats["kernel_launches"])))
|
||||
|
||||
|
||||
def add_dim3(kernel_string, cuda_kernel):
|
||||
'''adds dim3() to the second and third arguments in the kernel launch'''
|
||||
count = 0
|
||||
closure = 0
|
||||
kernel_string = kernel_string.replace("<<<", "").replace(">>>", "")
|
||||
arg_locs = [{} for _ in range(2)]
|
||||
arg_locs[count]['start'] = 0
|
||||
for ind, c in enumerate(kernel_string):
|
||||
if count > 1:
|
||||
break
|
||||
if c == "(":
|
||||
closure += 1
|
||||
elif c == ")":
|
||||
closure -= 1
|
||||
elif (c == "," or ind == len(kernel_string) - 1) and closure == 0:
|
||||
arg_locs[count]['end'] = ind
|
||||
count += 1
|
||||
if count < 2:
|
||||
arg_locs[count]['start'] = ind + 1
|
||||
|
||||
first_arg_raw = kernel_string[arg_locs[0]['start']:arg_locs[0]['end'] + 1]
|
||||
second_arg_raw = kernel_string[arg_locs[1]['start']:arg_locs[1]['end']]
|
||||
|
||||
first_arg_clean = kernel_string[arg_locs[0]['start']:arg_locs[0]['end']].replace("\n", "").strip(" ")
|
||||
second_arg_clean = kernel_string[arg_locs[1]['start']:arg_locs[1]['end']].replace("\n", "").strip(" ")
|
||||
|
||||
first_arg_dim3 = "dim3({})".format(first_arg_clean)
|
||||
second_arg_dim3 = "dim3({})".format(second_arg_clean)
|
||||
|
||||
first_arg_raw_dim3 = first_arg_raw.replace(first_arg_clean, first_arg_dim3)
|
||||
second_arg_raw_dim3 = second_arg_raw.replace(second_arg_clean, second_arg_dim3)
|
||||
cuda_kernel = cuda_kernel.replace(first_arg_raw + second_arg_raw, first_arg_raw_dim3 + second_arg_raw_dim3)
|
||||
return cuda_kernel
|
||||
|
||||
|
||||
def processKernelLaunches(string, stats):
|
||||
""" Replace the CUDA style Kernel launches with the HIP style kernel launches."""
|
||||
# Concat the namespace with the kernel names. (Find cleaner way of doing this later).
|
||||
@ -396,12 +409,12 @@ def processKernelLaunches(string, stats):
|
||||
|
||||
# Extract cuda kernel
|
||||
cuda_kernel = string[params[0]["start"]:parenthesis + 1]
|
||||
|
||||
kernel_string = string[kernel['start']:kernel['end']]
|
||||
cuda_kernel_dim3 = add_dim3(kernel_string, cuda_kernel)
|
||||
# Keep number of kernel launch params consistent (grid dims, group dims, stream, dynamic shared size)
|
||||
num_klp = len(extract_arguments(0, kernel["group"].replace("<<<", "(").replace(">>>", ")")))
|
||||
|
||||
# Transform cuda kernel to hip kernel
|
||||
hip_kernel = "hipLaunchKernelGGL(" + cuda_kernel[0:-1].replace(
|
||||
hip_kernel = "hipLaunchKernelGGL(" + cuda_kernel_dim3[0:-1].replace(
|
||||
">>>", ", 0" * (4 - num_klp) + ">>>").replace("<<<", ", ").replace(">>>", ", ")
|
||||
|
||||
# Replace cuda kernel with hip kernel
|
||||
@ -450,6 +463,7 @@ def disable_asserts(input_string):
|
||||
output_string = output_string.replace(input_string[start:p_end + 1], "")
|
||||
return output_string
|
||||
|
||||
|
||||
def replace_forceinline(input_string):
|
||||
"""__forceinline__'d methods can cause 'symbol multiply defined' errors in HIP.
|
||||
Adding 'static' to all such methods leads to compilation errors, so
|
||||
@ -460,6 +474,7 @@ def replace_forceinline(input_string):
|
||||
output_string = re.sub("__forceinline__", "inline", output_string)
|
||||
return output_string
|
||||
|
||||
|
||||
def replace_math_functions(input_string):
|
||||
""" FIXME: Temporarily replace std:: invocations of math functions with non-std:: versions to prevent linker errors
|
||||
NOTE: This can lead to correctness issues when running tests, since the correct version of the math function (exp/expf) might not get called.
|
||||
@ -471,6 +486,7 @@ def replace_math_functions(input_string):
|
||||
output_string = re.sub("std::pow\(", "::pow(", output_string)
|
||||
return output_string
|
||||
|
||||
|
||||
def disable_function(input_string, function, replace_style):
""" Finds and disables a function in a particular file.

@ -610,11 +626,42 @@ def disable_function(input_string, function, replace_style):
return output_string


def preprocessor(filepath, stats):
""" Executes the CUDA -> HIP conversion on the specified file. """
with openf(filepath, "r+") as fileobj:
output_source = fileobj.read()
def get_hip_file_path(filepath, hipify_caffe2):
""" Returns the new name of the hipified file """
if not hipify_caffe2:
return filepath

dirpath, filename = os.path.split(filepath)
filename_without_ext, ext = os.path.splitext(filename)

if 'gpu' in filename_without_ext:
filename_without_ext = filename_without_ext.replace('gpu', 'hip')
else:
filename_without_ext += '_hip'

if ext == '.cu':
ext = '.cc'

return os.path.join(dirpath, 'hip', filename_without_ext + ext)


def is_caffe2_gpu_file(filepath):
filename = os.path.basename(filepath)
_, ext = os.path.splitext(filename)
return 'gpu' in filename or ext in ['.cu', '.cuh']
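To make the naming scheme concrete, a couple of hedged examples (the input paths are hypothetical):

    get_hip_file_path("caffe2/operators/relu_op.cu", hipify_caffe2=True)
    # -> "caffe2/operators/hip/relu_op_hip.cc"
    get_hip_file_path("caffe2/core/context_gpu.cc", hipify_caffe2=True)
    # -> "caffe2/core/hip/context_hip.cc"
    is_caffe2_gpu_file("caffe2/utils/math_cpu.cc")
    # -> False: no 'gpu' in the name and not a .cu/.cuh file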
def preprocessor(filepath, stats, hipify_caffe2):
""" Executes the CUDA -> HIP conversion on the specified file. """
fin_path = filepath
with open(fin_path, 'r') as fin:
output_source = fin.read()

fout_path = get_hip_file_path(filepath, hipify_caffe2)
if not os.path.exists(os.path.dirname(fout_path)):
os.makedirs(os.path.dirname(fout_path))

with open(fout_path, 'w') as fout:
# Perform type, method, constant replacements
for mapping in CUDA_TO_HIP_MAPPINGS:
for cuda_type, value in mapping.items():
@ -622,13 +669,22 @@ def preprocessor(filepath, stats):
hip_type = value[0]
meta_data = value[1:]

if constants.API_CAFFE2 in meta_data and not hipify_caffe2:
continue
if constants.API_RAND in meta_data and hipify_caffe2:
continue

if output_source.find(cuda_type) > -1:
# Check if supported
if constants.HIP_UNSUPPORTED in meta_data:
stats["unsupported_calls"].append((cuda_type, filepath))

if cuda_type in output_source:
output_source = re.sub(r'\b({0})\b'.format(cuda_type), lambda x: hip_type, output_source)
if hipify_caffe2:
pattern = r'({0})'.format(cuda_type)
else:
pattern = r'(\b{0}\b)'.format(cuda_type)
output_source = re.sub(pattern, hip_type, output_source)

# Perform Kernel Launch Replacements
output_source = processKernelLaunches(output_source, stats)
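Reduced to a single mapping entry, the replacement step amounts to the following (the cudaStream_t -> hipStream_t pair is an assumed entry of CUDA_TO_HIP_MAPPINGS; the word-boundary pattern is the non-Caffe2 branch):

    import re

    cuda_type, hip_type = "cudaStream_t", "hipStream_t"  # assumed mapping entry
    pattern = r'(\b{0}\b)'.format(cuda_type)
    print(re.sub(pattern, hip_type, "cudaStream_t stream;"))
    # hipStream_t stream;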
@ -643,14 +699,7 @@ def preprocessor(filepath, stats):
# Replace __forceinline__ with inline
output_source = replace_forceinline(output_source)

# Overwrite file contents
fileobj.seek(0)
fileobj.write(output_source)
fileobj.truncate()
fileobj.flush()

# Flush to disk
os.fsync(fileobj)
fout.write(output_source)
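Per file, the new flow would be driven roughly as below (the stats keys are inferred from the appends above and from processKernelLaunches, and the input path is hypothetical):

    stats = {"unsupported_calls": [], "kernel_launches": []}
    preprocessor("caffe2/operators/relu_op.cu", stats, hipify_caffe2=True)
    # reads the .cu source and writes the converted code to
    # caffe2/operators/hip/relu_op_hip.cc (see get_hip_file_path above)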
def file_specific_replacement(filepath, search_string, replace_string, strict=False):
@ -847,7 +896,7 @@ def extract_arguments(start, string):
closures["("] -= 1
elif string[current_position] == "<":
closures["<"] += 1
elif string[current_position] == ">" and string[current_position - 1] != "-":
elif string[current_position] == ">" and string[current_position - 1] != "-" and closures["<"] > 0:
closures["<"] -= 1

# Finished all arguments
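The added closures["<"] > 0 guard keeps a bare '>' inside an argument (for example a comparison) from being miscounted as the end of a template parameter list. A hedged illustration, assuming extract_arguments returns one {"start", "end"} slice per top-level argument:

    src = '(dim3(grid), dim3(block), 0, 0, x, n > 1)'
    args = extract_arguments(0, src)
    print([src[a["start"]:a["end"]] for a in args])
    # expected: six argument slices, with 'n > 1' kept intact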
@ -867,7 +916,7 @@ def extract_arguments(start, string):


# Add static_cast to ensure that the type of kernel arguments matches that in the corresponding kernel definition
def add_static_casts(directory, extensions, KernelTemplateParams):
def add_static_casts(filepath, KernelTemplateParams):
"""Add static casts to kernel launches in order to keep launch argument types and kernel definition types matching.

Example:
@ -884,73 +933,70 @@ def add_static_casts(directory, extensions, KernelTemplateParams):
static_cast_types = ["int", "const int", "int64_t", "THCIndex_t *",
"const int *", "ptrdiff_t", "long", "const int64_t*", "int64_t *", "double"]

# Add static_casts<> to all kernel launches.
for (dirpath, _dirnames, filenames) in os.walk(directory):
for filename in filenames:
if filename_ends_with_extension(filename, extensions):
filepath = os.sep.join([dirpath, filename])
with openf(filepath, "r+") as fileobj:
input_source = fileobj.read()
new_output_source = input_source
for kernel in re.finditer("hipLaunchKernelGGL\(", input_source):
arguments = extract_arguments(kernel.end() - 1, input_source)
with openf(filepath, "r+") as fileobj:
input_source = fileobj.read()
new_output_source = input_source
for kernel in re.finditer("hipLaunchKernelGGL\(", input_source):
arguments = extract_arguments(kernel.end() - 1, input_source)

# Check if we have templating + static_cast information
argument_strings = [input_source[arg["start"]:arg["end"]] for arg in arguments]
original_kernel_name_with_template = argument_strings[0].strip()
kernel_name = original_kernel_name_with_template.split("<")[0].strip()
ignore = ["upscale"]
if kernel_name in KernelTemplateParams and kernel_name not in ignore:
# Add template to the kernel
# Add static_casts to relevant arguments
kernel_name_with_template = KernelTemplateParams[kernel_name]["kernel_with_template"]
argument_types = KernelTemplateParams[kernel_name]["arg_types"]

# The first 5 arguments are simply (function, number blocks, dimension blocks, shared memory, stream)
# old_kernel_launch_parameters - will contain the actual arguments to the function itself.
old_kernel_launch_parameters = input_source[arguments[5]["start"]:arguments[-1]["end"]]
new_kernel_launch_parameters = old_kernel_launch_parameters

# full_old_kernel_launch - will contain the entire kernel launch closure.
full_old_kernel_launch = input_source[arguments[0]["start"]:arguments[-1]["end"]]
full_new_kernel_launch = full_old_kernel_launch
# Check if we have templating + static_cast information
argument_strings = [input_source[arg["start"]:arg["end"]] for arg in arguments]
original_kernel_name_with_template = argument_strings[0].strip()
kernel_name = original_kernel_name_with_template.split("<")[0].strip()
ignore = ["upscale"]
if kernel_name in KernelTemplateParams and kernel_name not in ignore:
# Add template to the kernel
# Add static_casts to relevant arguments
kernel_name_with_template = KernelTemplateParams[kernel_name]["kernel_with_template"]
argument_types = KernelTemplateParams[kernel_name]["arg_types"]
kernel_params = argument_strings[5:]
for arg_idx, arg in enumerate(kernel_params):
if arg_idx in argument_types:
the_type = argument_types[arg_idx]
the_arg = arg.replace("\n", "").replace("\\", "").strip()
# Not all types have issues with the hipLaunchKernelGGL.
if the_type in static_cast_types:
static_argument = "static_cast<{0}>({1})".format(the_type, the_arg)
# The first 5 arguments are simply (function, number blocks, dimension blocks, shared memory, stream)
# old_kernel_launch_parameters - will contain the actual arguments to the function itself.
old_kernel_launch_parameters = input_source[arguments[5]["start"]:arguments[-1]["end"]]
new_kernel_launch_parameters = old_kernel_launch_parameters

def replace_arg(match):
return match.group(1) + static_argument + match.group(3)
# Update to static_cast, account for cases where argument is at start/end of string
new_kernel_launch_parameters = re.sub(r'(^|\W)({0})(\W|$)'.format(
re.escape(the_arg)), replace_arg, new_kernel_launch_parameters)

# replace kernel arguments in full kernel launch arguments w/ static_cast ones
full_new_kernel_launch = full_new_kernel_launch.replace(old_kernel_launch_parameters, new_kernel_launch_parameters)
# full_old_kernel_launch - will contain the entire kernel launch closure.
full_old_kernel_launch = input_source[arguments[0]["start"]:arguments[-1]["end"]]
full_new_kernel_launch = full_old_kernel_launch

# PyTorch Specific: Add template type
# Here the template value will be resolved from <real> to <Dtype>.
if "THCUNN" in filepath.split("/") and "generic" not in filepath.split("/"):
kernel_name_with_template = kernel_name_with_template.replace("<real>", "<Dtype>")
full_new_kernel_launch = re.sub(r'\b{0}\b'.format(original_kernel_name_with_template),
lambda x: kernel_name_with_template, full_new_kernel_launch)
kernel_params = argument_strings[5:]
for arg_idx, arg in enumerate(kernel_params):
if arg_idx in argument_types:
the_type = argument_types[arg_idx]
the_arg = arg.replace("\n", "").replace("\\", "").strip()
# Not all types have issues with the hipLaunchKernelGGL.
if the_type in static_cast_types:
static_argument = "static_cast<{0}>({1})".format(the_type, the_arg)

# Replace Launch
new_output_source = new_output_source.replace(full_old_kernel_launch, full_new_kernel_launch)
def replace_arg(match):
return match.group(1) + static_argument + match.group(3)
# Update to static_cast, account for cases where argument is at start/end of string
new_kernel_launch_parameters = re.sub(r'(^|\W)({0})(\W|$)'.format(
re.escape(the_arg)), replace_arg, new_kernel_launch_parameters)

# Overwrite file contents
fileobj.seek(0)
fileobj.write(new_output_source)
fileobj.truncate()
fileobj.flush()
# replace kernel arguments in full kernel launch arguments w/ static_cast ones
full_new_kernel_launch = full_new_kernel_launch.replace(
old_kernel_launch_parameters, new_kernel_launch_parameters)

# Flush to disk
os.fsync(fileobj)
# PyTorch Specific: Add template type
# Here the template value will be resolved from <real> to <Dtype>.
if "THCUNN" in filepath.split("/") and "generic" not in filepath.split("/"):
kernel_name_with_template = kernel_name_with_template.replace("<real>", "<Dtype>")

full_new_kernel_launch = re.sub(r'\b{0}\b'.format(original_kernel_name_with_template),
lambda x: kernel_name_with_template, full_new_kernel_launch)

# Replace Launch
new_output_source = new_output_source.replace(full_old_kernel_launch, full_new_kernel_launch)

# Overwrite file contents
fileobj.seek(0)
fileobj.write(new_output_source)
fileobj.truncate()
fileobj.flush()

# Flush to disk
os.fsync(fileobj)
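The overall effect of add_static_casts can be sketched with an assumed KernelTemplateParams entry (the kernel name, template, and argument types below are illustrative, not taken from the repository):

    # Assumed metadata as collected by get_kernel_template_params.
    KernelTemplateParams = {
        "gather_kernel": {
            "kernel_with_template": "gather_kernel<float>",
            "arg_types": {2: "int64_t"},  # third kernel argument needs a cast
        }
    }
    before = "hipLaunchKernelGGL(gather_kernel, dim3(grid), dim3(block), 0, stream, out, in, n)"
    after = "hipLaunchKernelGGL(gather_kernel<float>, dim3(grid), dim3(block), 0, stream, out, in, static_cast<int64_t>(n))"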
def str2bool(v):
@ -990,7 +1036,7 @@ def main():
parser.add_argument(
'--extensions',
nargs='+',
default=["cu", "cuh", "c", "cpp", "h", "in", "hpp"],
default=[".cu", ".cuh", ".c", ".cpp", ".h", ".in", ".hpp"],
help="The extensions for files to run the Hipify script over.",
required=False)

@ -1002,10 +1048,10 @@ def main():
required=False)

parser.add_argument(
'--include-dirs',
'--includes',
nargs='+',
default=[],
help="The directories under the root that should be included.",
help="The patterns of files that should be included.",
required=False)

parser.add_argument(
@ -1022,6 +1068,19 @@ def main():
help="Whether to automatically add static_casts to kernel arguments.",
required=False)

parser.add_argument(
'--hipify_caffe2',
type=str2bool,
default=False,
help="Whether to hipify caffe2 source",
required=False)

parser.add_argument(
'--ignores',
nargs='+',
default=[],
help="list of patterns to ignore for hipifying")
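A self-contained sketch (not the script's real parser) of how the new Caffe2-related flags added above would be consumed; str2bool_sketch stands in for the script's str2bool helper:

    import argparse

    def str2bool_sketch(v):
        return str(v).lower() in ("yes", "true", "t", "1")

    p = argparse.ArgumentParser()
    p.add_argument("--hipify_caffe2", type=str2bool_sketch, default=False)
    p.add_argument("--includes", nargs="+", default=[])
    p.add_argument("--ignores", nargs="+", default=[])
    args = p.parse_args(["--hipify_caffe2", "True", "--includes", "caffe2/operators/*"])
    print(args.hipify_caffe2, args.includes, args.ignores)
    # True ['caffe2/operators/*'] []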
parser.add_argument(
'--show-progress',
type=str2bool,
@ -1037,33 +1096,14 @@ def main():
sys.exit(1)

# If no output directory, provide a default one.
if args.output_directory is "":
if not args.output_directory:
args.project_directory.rstrip("/")
args.output_directory = args.project_directory + "_amd"

# Make sure output directory does not exist.
if not os.path.exists(args.output_directory):
print("The output folder already exists.")
sys.exit(2)

# Copy from project directory to output directory if not done already.
if not os.path.exists(args.output_directory):
shutil.copytree(args.project_directory, args.output_directory)

# Extract all of the kernel parameter and template type information.
if args.add_static_casts:
KernelTemplateParams = {}
for (dirpath, _dirnames, filenames) in os.walk(args.output_directory):
for filename in filenames:
if filename_ends_with_extension(filename, args.extensions) and inside_included_directories(dirpath, args.output_directory, args.include_dirs):
the_file = os.sep.join([dirpath, filename])

# Store param information inside KernelTemplateParams
get_kernel_template_params(
the_file,
KernelTemplateParams,
PYTORCH_TEMPLATE_MAP)

# Open YAML file with disable information.
if args.yaml_settings != "":
with openf(args.yaml_settings, "r") as f:
@ -1152,17 +1192,28 @@ def main():
f.write(txt)
f.truncate()
# Start Preprocessor
walk_over_directory(
args.output_directory,
extensions=args.extensions,
show_detailed=args.show_detailed,
include_dirs=args.include_dirs,
show_progress=args.show_progress)
all_files = list(matched_files_iter(args.output_directory, includes=args.includes,
ignores=args.ignores, extensions=args.extensions, hipify_caffe2=args.hipify_caffe2))

# Start Preprocessor
preprocess(
all_files,
show_detailed=args.show_detailed,
show_progress=args.show_progress,
hipify_caffe2=args.hipify_caffe2)

# Extract all of the kernel parameter and template type information.
if args.add_static_casts:
KernelTemplateParams = {}
for filepath in all_files:
get_kernel_template_params(
filepath,
KernelTemplateParams,
CAFFE2_TEMPLATE_MAP if args.hipify_caffe2 else PYTORCH_TEMPLATE_MAP)

# Execute the Clang Tool to Automatically add static casts
add_static_casts(args.output_directory, args.extensions, KernelTemplateParams)
for filepath in all_files:
add_static_casts(get_hip_file_path(filepath, hipify_caffe2=args.hipify_caffe2), KernelTemplateParams)

if __name__ == '__main__':