mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
clip operator, not tested.
This commit is contained in:
@ -185,6 +185,7 @@ class Operator : public OperatorBase {
|
||||
};
|
||||
|
||||
#define USE_OPERATOR_BASE_FUNCTIONS \
|
||||
using OperatorBase::HasArgument; \
|
||||
using OperatorBase::GetSingleArgument; \
|
||||
using OperatorBase::GetRepeatedArgument; \
|
||||
using OperatorBase::def; \
|
||||
|
@ -9,6 +9,7 @@ cc_library(
|
||||
"accumulate_op.cc",
|
||||
"accuracy_op.cc",
|
||||
"averagepool_op.cc",
|
||||
"clip_op.cc",
|
||||
"conv_op.cc",
|
||||
"cross_entropy_op.cc",
|
||||
"depth_split_op.cc",
|
||||
@ -43,6 +44,7 @@ cuda_library(
|
||||
"accumulate_op.cu",
|
||||
"accuracy_op.cu",
|
||||
"averagepool_op.cu",
|
||||
"clip_op.cu",
|
||||
"conv_op.cu",
|
||||
"cross_entropy_op.cu",
|
||||
"depth_split_op.cu",
|
||||
|
40
caffe2/operators/clip_op.cc
Normal file
40
caffe2/operators/clip_op.cc
Normal file
@ -0,0 +1,40 @@
|
||||
#include "caffe2/operators/clip_op.h"
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
template <>
|
||||
bool ClipOp<float, CPUContext>::RunOnDevice() {
|
||||
auto& X = Input(0);
|
||||
auto* Y = Output(0);
|
||||
DCHECK_GT(X.size(), 0);
|
||||
Y->ReshapeLike(X);
|
||||
const float* Xdata = X.data();
|
||||
float* Ydata = Y->mutable_data();
|
||||
for (int i = 0; i < X.size(); ++i) {
|
||||
Ydata[i] = std::min(std::max(Xdata[i], min_), max_);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <>
|
||||
bool ClipGradientOp<float, CPUContext>::RunOnDevice() {
|
||||
auto& X = Input(0);
|
||||
auto& dY = Input(1);
|
||||
auto* dX = Output(0);
|
||||
DCHECK_GT(X.size(), 0);
|
||||
DCHECK_EQ(dY.size(), X.size());
|
||||
dX->ReshapeLike(X);
|
||||
const float* Xdata = X.data();
|
||||
const float* dYdata = dY.data();
|
||||
float* dXdata = dX->mutable_data();
|
||||
for (int i = 0; i < X.size(); ++i) {
|
||||
dXdata[i] = dYdata[i] * (Xdata[i] > min_ && Xdata[i] < max_);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
namespace {
|
||||
REGISTER_CPU_OPERATOR(Clip, ClipOp<float, CPUContext>)
|
||||
REGISTER_CPU_OPERATOR(ClipGradient, ClipGradientOp<float, CPUContext>)
|
||||
} // namespace
|
||||
} // namespace caffe2
|
70
caffe2/operators/clip_op.cu
Normal file
70
caffe2/operators/clip_op.cu
Normal file
@ -0,0 +1,70 @@
|
||||
#include "caffe2/core/context_gpu.h"
|
||||
#include "caffe2/operators/clip_op.h"
|
||||
|
||||
namespace caffe2 {
|
||||
namespace {
|
||||
|
||||
// Device-side min/max helpers selected by element type. Only the float and
// double specializations below are defined; the primary templates are
// declarations only.
template <typename dtype>
__device__ dtype cuda_min(dtype x, dtype y);

template <typename dtype>
__device__ dtype cuda_max(dtype x, dtype y);

// Single precision: forwards to fminf / fmaxf.
template <>
__device__ float cuda_min<float>(float x, float y) {
  return fminf(x, y);
}

template <>
__device__ float cuda_max<float>(float x, float y) {
  return fmaxf(x, y);
}

// Double precision: forwards to fmin / fmax.
template <>
__device__ double cuda_min<double>(double x, double y) {
  return fmin(x, y);
}

template <>
__device__ double cuda_max<double>(double x, double y) {
  return fmax(x, y);
}
|
||||
|
||||
|
||||
|
||||
// Writes Y[i] = min(max(X[i], minval), maxval) for the indices visited by
// CUDA_1D_KERNEL_LOOP.
// NOTE(review): CUDA_1D_KERNEL_LOOP is defined elsewhere; presumably it
// covers every index in [0, N) across the launched threads -- confirm.
template <typename dtype>
__global__ void ClipKernel(const int N, const dtype minval, const dtype maxval,
                           const dtype* X, dtype* Y) {
  CUDA_1D_KERNEL_LOOP(idx, N) {
    // Apply the lower bound first, then the upper bound.
    const dtype raised = cuda_max<dtype>(X[idx], minval);
    Y[idx] = cuda_min<dtype>(raised, maxval);
  }
}
|
||||
|
||||
// Writes dX[i] = dY[i] where X[i] lies strictly between minval and maxval,
// and dY[i] * 0 otherwise, for the indices visited by CUDA_1D_KERNEL_LOOP.
// NOTE(review): CUDA_1D_KERNEL_LOOP is defined elsewhere; presumably it
// covers every index in [0, N) across the launched threads -- confirm.
template <typename dtype>
__global__ void ClipGradientKernel(const int N, const dtype minval,
                                   const dtype maxval, const dtype* X,
                                   const dtype* dY, dtype* dX) {
  CUDA_1D_KERNEL_LOOP(idx, N) {
    const dtype x = X[idx];
    // Strict comparisons: values exactly on a bound get no gradient.
    const bool strictly_inside = (x > minval) && (x < maxval);
    dX[idx] = dY[idx] * static_cast<dtype>(strictly_inside);
  }
}
|
||||
} // namespace
|
||||
|
||||
// CUDA implementation of Clip for float tensors.
//
// Reads Input(0), shapes Output(0) like it, and launches ClipKernel on the
// operator's CUDA stream to clamp every element to [min_, max_].
//
// Returns true if the kernel launch was accepted, false otherwise. The
// launch is asynchronous, so faults raised while the kernel runs are not
// detected here.
template <>
bool ClipOp<float, CUDAContext>::RunOnDevice() {
  auto& X = Input(0);
  auto* Y = Output(0);
  DCHECK_GT(X.size(), 0);
  Y->ReshapeLike(X);
  ClipKernel<float><<<CAFFE_GET_BLOCKS(X.size()), CAFFE_CUDA_NUM_THREADS,
                      0, device_context_.cuda_stream()>>>(
      X.size(), min_, max_, X.data(), Y->mutable_data());
  // A <<<...>>> launch returns no status; a bad launch configuration is only
  // visible through cudaGetLastError(), so check it instead of reporting
  // success unconditionally.
  const cudaError_t launch_status = cudaGetLastError();
  if (launch_status != cudaSuccess) {
    LOG(ERROR) << "ClipKernel launch failed: "
               << cudaGetErrorString(launch_status);
    return false;
  }
  return true;
}
|
||||
|
||||
// CUDA implementation of ClipGradient for float tensors.
//
// Input(0) is the forward input X, Input(1) is the output gradient dY, and
// Output(0) receives dX shaped like X. Launches ClipGradientKernel on the
// operator's CUDA stream.
//
// Returns true if the kernel launch was accepted, false otherwise. The
// launch is asynchronous, so faults raised while the kernel runs are not
// detected here.
template <>
bool ClipGradientOp<float, CUDAContext>::RunOnDevice() {
  auto& X = Input(0);
  auto& dY = Input(1);
  auto* dX = Output(0);
  DCHECK_GT(X.size(), 0);
  DCHECK_EQ(dY.size(), X.size());
  dX->ReshapeLike(X);
  ClipGradientKernel<float><<<CAFFE_GET_BLOCKS(X.size()),
                              CAFFE_CUDA_NUM_THREADS,
                              0, device_context_.cuda_stream()>>>(
      X.size(), min_, max_, X.data(), dY.data(), dX->mutable_data());
  // A <<<...>>> launch returns no status; a bad launch configuration is only
  // visible through cudaGetLastError(), so check it instead of reporting
  // success unconditionally.
  const cudaError_t launch_status = cudaGetLastError();
  if (launch_status != cudaSuccess) {
    LOG(ERROR) << "ClipGradientKernel launch failed: "
               << cudaGetErrorString(launch_status);
    return false;
  }
  return true;
}
|
||||
|
||||
// Hook the float CUDA implementations into the operator registry.
// NOTE(review): REGISTER_CUDA_OPERATOR is defined elsewhere; presumably the
// first argument is the operator name used for lookup -- confirm.
namespace {
REGISTER_CUDA_OPERATOR(Clip, ClipOp<float, CUDAContext>)
REGISTER_CUDA_OPERATOR(ClipGradient, ClipGradientOp<float, CUDAContext>)
}  // namespace
|
||||
} // namespace caffe2
|
70
caffe2/operators/clip_op.h
Normal file
70
caffe2/operators/clip_op.h
Normal file
@ -0,0 +1,70 @@
|
||||
#ifndef CAFFE2_OPERATORS_CLIP_OP_H_
|
||||
#define CAFFE2_OPERATORS_CLIP_OP_H_
|
||||
|
||||
#include <limits>
|
||||
|
||||
#include "caffe2/core/context.h"
|
||||
#include "caffe2/core/operator.h"
|
||||
#include "caffe2/utils/math.h"
|
||||
#include "glog/logging.h"
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
template <typename dtype, class DeviceContext>
|
||||
class ClipOp final : public Operator<dtype, DeviceContext> {
|
||||
public:
|
||||
USE_OPERATOR_BASE_FUNCTIONS;
|
||||
ClipOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<dtype, DeviceContext>(operator_def, ws),
|
||||
min_(std::numeric_limits<dtype>::min()),
|
||||
max_(std::numeric_limits<dtype>::max()) {
|
||||
if (HasArgument("min")) {
|
||||
min_ = static_cast<dtype>(
|
||||
OperatorBase::GetSingleArgument<float>("min", 0));
|
||||
}
|
||||
if (HasArgument("max")) {
|
||||
max_ = static_cast<dtype>(
|
||||
OperatorBase::GetSingleArgument<float>("max", 0));
|
||||
}
|
||||
}
|
||||
|
||||
bool RunOnDevice();
|
||||
|
||||
protected:
|
||||
dtype min_;
|
||||
dtype max_;
|
||||
INPUT_OUTPUT_STATS(1, 1, 1, 1);
|
||||
DISABLE_COPY_AND_ASSIGN(ClipOp);
|
||||
};
|
||||
|
||||
template <typename dtype, class DeviceContext>
|
||||
class ClipGradientOp final : public Operator<dtype, DeviceContext> {
|
||||
public:
|
||||
USE_OPERATOR_BASE_FUNCTIONS;
|
||||
ClipGradientOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<dtype, DeviceContext>(operator_def, ws),
|
||||
min_(std::numeric_limits<dtype>::min()),
|
||||
max_(std::numeric_limits<dtype>::max()) {
|
||||
if (HasArgument("min")) {
|
||||
min_ = static_cast<dtype>(
|
||||
OperatorBase::GetSingleArgument<float>("min", 0));
|
||||
}
|
||||
if (HasArgument("max")) {
|
||||
max_ = static_cast<dtype>(
|
||||
OperatorBase::GetSingleArgument<float>("max", 0));
|
||||
}
|
||||
}
|
||||
|
||||
bool RunOnDevice();
|
||||
|
||||
protected:
|
||||
dtype min_;
|
||||
dtype max_;
|
||||
// Input: X, dY; Output: dX
|
||||
INPUT_OUTPUT_STATS(2, 2, 1, 1);
|
||||
DISABLE_COPY_AND_ASSIGN(ClipGradientOp);
|
||||
};
|
||||
|
||||
} // namespace caffe2
|
||||
|
||||
#endif // CAFFE2_OPERATORS_CLIP_OP_H_
|
@ -60,6 +60,12 @@ def AddReluGradient(op):
|
||||
[op.input[0], GetGradientName(op.output[0])],
|
||||
[GetGradientName(op.input[0])])
|
||||
|
||||
@GradientRegistry.RegisterGradient("Clip")
def AddClipGradient(op):
  """Builds the ClipGradient operator for a forward Clip operator.

  The gradient op takes the forward op's first input and the gradient of its
  first output, and produces the gradient of its first input.

  Named AddClipGradient (not AddReluGradient) so that it does not redefine
  the ReLU gradient builder of the same name at module level.
  """
  return CreateOperator("ClipGradient")(
      [op.input[0], GetGradientName(op.output[0])],
      [GetGradientName(op.input[0])])
|
||||
|
||||
@GradientRegistry.RegisterGradient("MaxPool")
|
||||
def AddMaxPoolGradient(op):
|
||||
return CreateOperator("MaxPoolGradient")(
|
||||
|
Reference in New Issue
Block a user