clip operator, not tested.

This commit is contained in:
Yangqing Jia
2015-08-26 15:24:29 -07:00
parent a57de4ece7
commit 4f4aa1f205
6 changed files with 189 additions and 0 deletions

View File

@ -185,6 +185,7 @@ class Operator : public OperatorBase {
};
#define USE_OPERATOR_BASE_FUNCTIONS \
using OperatorBase::HasArgument; \
using OperatorBase::GetSingleArgument; \
using OperatorBase::GetRepeatedArgument; \
using OperatorBase::def; \

View File

@ -9,6 +9,7 @@ cc_library(
"accumulate_op.cc",
"accuracy_op.cc",
"averagepool_op.cc",
"clip_op.cc",
"conv_op.cc",
"cross_entropy_op.cc",
"depth_split_op.cc",
@ -43,6 +44,7 @@ cuda_library(
"accumulate_op.cu",
"accuracy_op.cu",
"averagepool_op.cu",
"clip_op.cu",
"conv_op.cu",
"cross_entropy_op.cu",
"depth_split_op.cu",

View File

@ -0,0 +1,40 @@
#include "caffe2/operators/clip_op.h"
namespace caffe2 {
template <>
bool ClipOp<float, CPUContext>::RunOnDevice() {
auto& X = Input(0);
auto* Y = Output(0);
DCHECK_GT(X.size(), 0);
Y->ReshapeLike(X);
const float* Xdata = X.data();
float* Ydata = Y->mutable_data();
for (int i = 0; i < X.size(); ++i) {
Ydata[i] = std::min(std::max(Xdata[i], min_), max_);
}
return true;
}
template <>
bool ClipGradientOp<float, CPUContext>::RunOnDevice() {
auto& X = Input(0);
auto& dY = Input(1);
auto* dX = Output(0);
DCHECK_GT(X.size(), 0);
DCHECK_EQ(dY.size(), X.size());
dX->ReshapeLike(X);
const float* Xdata = X.data();
const float* dYdata = dY.data();
float* dXdata = dX->mutable_data();
for (int i = 0; i < X.size(); ++i) {
dXdata[i] = dYdata[i] * (Xdata[i] > min_ && Xdata[i] < max_);
}
return true;
}
namespace {
REGISTER_CPU_OPERATOR(Clip, ClipOp<float, CPUContext>)
REGISTER_CPU_OPERATOR(ClipGradient, ClipGradientOp<float, CPUContext>)
} // namespace
} // namespace caffe2

View File

@ -0,0 +1,70 @@
#include "caffe2/core/context_gpu.h"
#include "caffe2/operators/clip_op.h"
namespace caffe2 {
namespace {
template <typename dtype>
__device__ dtype cuda_min(dtype x, dtype y);
template <typename dtype>
__device__ dtype cuda_max(dtype x, dtype y);
template <>
__device__ float cuda_min(float x, float y) { return fminf(x, y); }
template <>
__device__ float cuda_max(float x, float y) { return fmaxf(x, y); }
template <>
__device__ double cuda_min(double x, double y) { return fmin(x, y); }
template <>
__device__ double cuda_max(double x, double y) { return fmax(x, y); }
template <typename dtype>
__global__ void ClipKernel(const int N, const dtype minval, const dtype maxval,
const dtype* X, dtype* Y) {
CUDA_1D_KERNEL_LOOP(i, N) {
Y[i] = cuda_min<dtype>(cuda_max<dtype>(X[i], minval), maxval);
}
}
template <typename dtype>
__global__ void ClipGradientKernel(const int N, const dtype minval,
const dtype maxval, const dtype* X,
const dtype* dY, dtype* dX) {
CUDA_1D_KERNEL_LOOP(i, N) {
dX[i] = dY[i] * (X[i] > minval && X[i] < maxval);
}
}
} // namespace
template <>
bool ClipOp<float, CUDAContext>::RunOnDevice() {
auto& X = Input(0);
auto* Y = Output(0);
DCHECK_GT(X.size(), 0);
Y->ReshapeLike(X);
ClipKernel<<<CAFFE_GET_BLOCKS(X.size()), CAFFE_CUDA_NUM_THREADS,
0, device_context_.cuda_stream()>>>(
X.size(), min_, max_, X.data(), Y->mutable_data());
return true;
}
template <>
bool ClipGradientOp<float, CUDAContext>::RunOnDevice() {
auto& X = Input(0);
auto& dY = Input(1);
auto* dX = Output(0);
DCHECK_GT(X.size(), 0);
DCHECK_EQ(dY.size(), X.size());
dX->ReshapeLike(X);
ClipGradientKernel<<<CAFFE_GET_BLOCKS(X.size()), CAFFE_CUDA_NUM_THREADS,
0, device_context_.cuda_stream()>>>(
X.size(), min_, max_, X.data(), dY.data(), dX->mutable_data());
return true;
}
namespace {
REGISTER_CUDA_OPERATOR(Clip, ClipOp<float, CUDAContext>)
REGISTER_CUDA_OPERATOR(ClipGradient, ClipGradientOp<float, CUDAContext>)
} // namespace
} // namespace caffe2

View File

@ -0,0 +1,70 @@
#ifndef CAFFE2_OPERATORS_CLIP_OP_H_
#define CAFFE2_OPERATORS_CLIP_OP_H_
#include <limits>
#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
#include "glog/logging.h"
namespace caffe2 {
template <typename dtype, class DeviceContext>
class ClipOp final : public Operator<dtype, DeviceContext> {
public:
USE_OPERATOR_BASE_FUNCTIONS;
ClipOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<dtype, DeviceContext>(operator_def, ws),
min_(std::numeric_limits<dtype>::min()),
max_(std::numeric_limits<dtype>::max()) {
if (HasArgument("min")) {
min_ = static_cast<dtype>(
OperatorBase::GetSingleArgument<float>("min", 0));
}
if (HasArgument("max")) {
max_ = static_cast<dtype>(
OperatorBase::GetSingleArgument<float>("max", 0));
}
}
bool RunOnDevice();
protected:
dtype min_;
dtype max_;
INPUT_OUTPUT_STATS(1, 1, 1, 1);
DISABLE_COPY_AND_ASSIGN(ClipOp);
};
template <typename dtype, class DeviceContext>
class ClipGradientOp final : public Operator<dtype, DeviceContext> {
public:
USE_OPERATOR_BASE_FUNCTIONS;
ClipGradientOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<dtype, DeviceContext>(operator_def, ws),
min_(std::numeric_limits<dtype>::min()),
max_(std::numeric_limits<dtype>::max()) {
if (HasArgument("min")) {
min_ = static_cast<dtype>(
OperatorBase::GetSingleArgument<float>("min", 0));
}
if (HasArgument("max")) {
max_ = static_cast<dtype>(
OperatorBase::GetSingleArgument<float>("max", 0));
}
}
bool RunOnDevice();
protected:
dtype min_;
dtype max_;
// Input: X, dY; Output: dX
INPUT_OUTPUT_STATS(2, 2, 1, 1);
DISABLE_COPY_AND_ASSIGN(ClipGradientOp);
};
} // namespace caffe2
#endif // CAFFE2_OPERATORS_CLIP_OP_H_

View File

@ -60,6 +60,12 @@ def AddReluGradient(op):
[op.input[0], GetGradientName(op.output[0])],
[GetGradientName(op.input[0])])
@GradientRegistry.RegisterGradient("Clip")
def AddReluGradient(op):
return CreateOperator("ClipGradient")(
[op.input[0], GetGradientName(op.output[0])],
[GetGradientName(op.input[0])])
@GradientRegistry.RegisterGradient("MaxPool")
def AddMaxPoolGradient(op):
return CreateOperator("MaxPoolGradient")(