Add cudnn activation ops (#9379)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9379

Add cudnn activation ops

Reviewed By: houseroad

Differential Revision: D8818013

fbshipit-source-id: d3881c634a46578b9331da07f9fdf7e1f31d7e8a
Xiaomeng Yang
2018-07-12 23:02:31 -07:00
committed by Facebook Github Bot
parent b15a7d05ce
commit bb9ff58c6d
23 changed files with 928 additions and 775 deletions
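For context (not part of the diff), a minimal usage sketch showing how the new cuDNN-backed implementations are selected from Python, following the engine="CUDNN" pattern used by the updated tests below; the blob names, tensor shape, and GPU id are illustrative assumptions and require a CUDA-enabled build:

import numpy as np

from caffe2.proto import caffe2_pb2
from caffe2.python import core, workspace

# Run the Relu operator through the cuDNN engine on GPU 0 (assumed available).
device = core.DeviceOption(caffe2_pb2.CUDA, 0)
op = core.CreateOperator(
    "Relu",
    ["X"],
    ["Y"],
    engine="CUDNN",
    device_option=device,
)

workspace.FeedBlob("X", np.random.randn(4, 8).astype(np.float32), device)
workspace.RunOperatorOnce(op)
Y = workspace.FetchBlob("Y")  # elementwise max(X, 0), computed via cudnnActivationForward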

View File

@ -0,0 +1,136 @@
#ifndef CAFFE2_OPERATORS_ACTIVATION_OPS_CUDNN_H_
#define CAFFE2_OPERATORS_ACTIVATION_OPS_CUDNN_H_
#include "caffe2/core/context_gpu.h"
#include "caffe2/core/cudnn_wrappers.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/tensor.h"
#include "caffe2/core/types.h"
namespace caffe2 {
class CuDNNActivationOpBase : public Operator<CUDAContext> {
public:
USE_OPERATOR_FUNCTIONS(CUDAContext);
CuDNNActivationOpBase(const OperatorDef& operator_def, Workspace* ws)
: Operator<CUDAContext>(operator_def, ws), cudnn_wrapper_(&context_) {
CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&data_desc_));
CUDNN_ENFORCE(cudnnCreateActivationDescriptor(&act_desc_));
}
~CuDNNActivationOpBase() {
CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(data_desc_));
CUDNN_ENFORCE(cudnnDestroyActivationDescriptor(act_desc_));
}
protected:
void SetTensorDescriptor(
const cudnnDataType_t data_type,
const int data_size) {
if (data_size != input_size_) {
// Since the best performance is obtained when the tensor is HW-packed, we
// put X.size() to W.
input_size_ = data_size;
CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
data_desc_,
GetCudnnTensorFormat(StorageOrder::NCHW),
data_type,
1,
1,
1,
input_size_));
}
}
CuDNNWrapper cudnn_wrapper_;
cudnnTensorDescriptor_t data_desc_;
cudnnActivationDescriptor_t act_desc_;
int input_size_ = 0;
};
template <cudnnActivationMode_t kCuDNNActivationMode>
class CuDNNActivationOp final : public CuDNNActivationOpBase {
public:
USE_OPERATOR_FUNCTIONS(CUDAContext);
CuDNNActivationOp(const OperatorDef& operator_def, Workspace* ws)
: CuDNNActivationOpBase(operator_def, ws) {
CUDNN_ENFORCE(cudnnSetActivationDescriptor(
act_desc_, kCuDNNActivationMode, CUDNN_PROPAGATE_NAN, 0.0));
}
bool RunOnDevice() override {
return DispatchHelper<TensorTypes<float, float16>>::call(this, Input(0));
}
template <typename T>
bool DoRunWithType() {
const auto& X = Input(0);
auto* Y = Output(0);
Y->ResizeLike(X);
if (X.size() == 0) {
Y->template mutable_data<T>();
return true;
}
this->SetTensorDescriptor(cudnnTypeWrapper<T>::type, X.size());
CUDNN_ENFORCE(cudnnActivationForward(
this->cudnn_wrapper_.inline_cudnn_handle(),
this->act_desc_,
cudnnTypeWrapper<T>::kOne(),
this->data_desc_,
X.template data<T>(),
cudnnTypeWrapper<T>::kZero(),
this->data_desc_,
Y->template mutable_data<T>()));
return true;
}
};
template <cudnnActivationMode_t kCuDNNActivationMode>
class CuDNNActivationGradientOp final : public CuDNNActivationOpBase {
public:
USE_OPERATOR_FUNCTIONS(CUDAContext);
CuDNNActivationGradientOp(const OperatorDef& operator_def, Workspace* ws)
: CuDNNActivationOpBase(operator_def, ws) {
CUDNN_ENFORCE(cudnnSetActivationDescriptor(
act_desc_, kCuDNNActivationMode, CUDNN_PROPAGATE_NAN, 0.0));
}
bool RunOnDevice() override {
return DispatchHelper<TensorTypes<float, float16>>::call(this, Input(0));
}
template <typename T>
bool DoRunWithType() {
const auto& Y = Input(0);
const auto& dY = Input(1);
auto* dX = Output(0);
dX->ResizeLike(Y);
if (Y.size() == 0) {
dX->template mutable_data<T>();
return true;
}
this->SetTensorDescriptor(cudnnTypeWrapper<T>::type, Y.size());
CUDNN_ENFORCE(cudnnActivationBackward(
this->cudnn_wrapper_.inline_cudnn_handle(),
this->act_desc_,
cudnnTypeWrapper<T>::kOne(),
this->data_desc_,
Y.template data<T>(),
this->data_desc_,
dY.template data<T>(),
this->data_desc_,
Y.template data<T>(), // Use Y_data as placeholder here.
cudnnTypeWrapper<T>::kZero(),
this->data_desc_,
dX->template mutable_data<T>()));
return true;
}
};
} // namespace caffe2
#endif // CAFFE2_OPERATORS_ACTIVATION_OPS_CUDNN_H_

View File

@ -1,46 +1,53 @@
#include "caffe2/operators/elu_op.h"
#include <algorithm>
#include <functional>
#include <string>
#include "caffe2/utils/eigen_utils.h"
#include "caffe2/utils/math.h"
namespace caffe2 {
template <>
bool EluOp<float, CPUContext>::RunOnDevice() {
auto& X = Input(0);
auto* Y = Output(0);
// Otherwise inplace gradient and Elu doesn't make sense.
CAFFE_ENFORCE_GE(alpha_, 0);
Y->ResizeLike(X);
const auto* Xdata = X.template data<float>();
auto* Ydata = Y->template mutable_data<float>();
ConstEigenVectorArrayMap<float> Xvec(Xdata, X.size());
EigenVectorArrayMap<float> Yvec(Ydata, Y->size());
Yvec = Xvec.cwiseMax(0.f) + (alpha_ * (Xvec.exp() - 1.0f)).cwiseMin(0.f);
template <typename T>
bool EluFunctor<CPUContext>::
operator()(const int N, const T* X, T* Y, CPUContext* /* context */) const {
ConstEigenVectorArrayMap<T> X_arr(X, N);
EigenVectorMap<T>(Y, N) =
(X_arr < 0).select(alpha * (X_arr.exp() - T(1)), X_arr);
return true;
}
template <>
bool EluGradientOp<float, CPUContext>::RunOnDevice() {
auto& Y = Input(0);
auto& dY = Input(1);
auto* dX = Output(0);
DCHECK_GT(Y.size(), 0);
DCHECK_EQ(dY.size(), Y.size());
dX->ResizeLike(Y);
const float* Ydata = Y.data<float>();
const float* dYdata = dY.data<float>();
float* dXdata = dX->mutable_data<float>();
ConstEigenVectorArrayMap<float> Yvec(Ydata, Y.size());
ConstEigenVectorArrayMap<float> dYvec(dYdata, dY.size());
EigenVectorArrayMap<float> dXvec(dXdata, dX->size());
dXvec = (Yvec > 0).select(dYvec, dYvec * (Yvec + alpha_));
template <typename T>
bool EluGradientFunctor<CPUContext>::Forward(
const std::vector<int>& Y_dims,
const std::vector<int>& /* dY_dims */,
const T* Y,
const T* dY,
T* dX,
CPUContext* /* context */) const {
const int size = std::accumulate(
Y_dims.cbegin(), Y_dims.cend(), 1, std::multiplies<int>());
ConstEigenVectorArrayMap<T> Y_arr(Y, size);
ConstEigenVectorArrayMap<T> dY_arr(dY, size);
EigenVectorArrayMap<T>(dX, size) =
(Y_arr < 0).select(dY_arr * (Y_arr + alpha), dY_arr);
return true;
}
REGISTER_CPU_OPERATOR(Elu, EluOp<float, CPUContext>);
REGISTER_CPU_OPERATOR(EluGradient, EluGradientOp<float, CPUContext>);
REGISTER_CPU_OPERATOR(
Elu,
UnaryElementwiseWithArgsOp<
TensorTypes<float>,
CPUContext,
EluFunctor<CPUContext>>);
REGISTER_CPU_OPERATOR(
EluGradient,
BinaryElementwiseWithArgsOp<
TensorTypes<float>,
CPUContext,
EluGradientFunctor<CPUContext>>);
// Input: X, output: Y
OPERATOR_SCHEMA(Elu)
@ -103,10 +110,11 @@ Y:
)DOC")
.Input(0, "X", "1D input tensor of data to be operated on.")
.Output(0, "Y", "1D input tensor, calculated as described above.")
.Arg("alpha", "*(type: float; default: 1.0)* Defines alpha parameter used in calculation.")
.Arg(
"alpha",
"*(type: float; default: 1.0)* Defines alpha parameter used in calculation.")
.InheritOnnxSchema("Elu");
// Input: Y, dY, output: dX
OPERATOR_SCHEMA(EluGradient)
.NumInputs(2)
@ -117,16 +125,21 @@ EluGradient takes both Y and dY and uses this to update dX according to the
chain rule and derivatives of the rectified linear function.
)DOC");
namespace {
class GetEluGradient : public GradientMakerBase {
using GradientMakerBase::GradientMakerBase;
vector<OperatorDef> GetGradientDefs() override {
std::vector<OperatorDef> GetGradientDefs() override {
return SingleGradientDef(
def_.type() + "Gradient",
"",
vector<string>{O(0), GO(0)},
vector<string>{GI(0)});
std::vector<std::string>{O(0), GO(0)},
std::vector<std::string>{GI(0)});
}
};
} // namespace
REGISTER_GRADIENT(Elu, GetEluGradient);
} // namespace caffe2
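A quick note on the backward formula used above (not part of the diff): for x < 0 the forward pass gives y = alpha * (exp(x) - 1), so dy/dx = alpha * exp(x) = y + alpha. This is why the gradient can be computed from the saved output alone, as dX = dY * (Y + alpha) where Y < 0 and dX = dY elsewhere.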

View File

@ -1,74 +1,91 @@
#include "caffe2/operators/elu_op.h"
#include <algorithm>
#include <functional>
#include "caffe2/core/context_gpu.h"
namespace caffe2 {
namespace {
template <typename T>
__global__ void EluCUDAKernel(const int N, const T alpha, const T* X, T* Y);
template <>
__global__ void
elu_kernel(const int N, const float alpha, const float* x, float* y) {
EluCUDAKernel<float>(const int N, const float alpha, const float* X, float* Y) {
CUDA_1D_KERNEL_LOOP(i, N) {
if (x[i] > 0) {
y[i] = x[i];
} else {
y[i] = alpha * (__expf(x[i]) - 1);
}
#if __CUDA_ARCH__ >= 350
Y[i] =
__ldg(X + i) < 0 ? alpha * (expf(__ldg(X + i)) - 1.0f) : __ldg(X + i);
#else
Y[i] = X[i] < 0 ? alpha * (expf(X[i]) - 1.0f) : X[i];
#endif
}
}
__global__ void elu_gradient_kernel(
template <typename T>
__global__ void EluGradientCUDAKernel(
const int N,
const float alpha,
const float* y,
const float* dy,
float* dx) {
const T alpha,
const T* dY,
const T* Y,
T* dX) {
CUDA_1D_KERNEL_LOOP(i, N) {
if (y[i] > 0) {
dx[i] = dy[i];
} else {
dx[i] = dy[i] * (y[i] + alpha);
}
#if __CUDA_ARCH__ >= 350
dX[i] = __ldg(Y + i) < 0 ? __ldg(dY + i) * (__ldg(Y + i) + alpha)
: __ldg(dY + i);
#else
dX[i] = Y[i] < 0 ? dY[i] * (Y[i] + alpha) : dY[i];
#endif
}
}
} // namespace
template <>
bool EluOp<float, CUDAContext>::RunOnDevice() {
auto& X = Input(0);
auto* Y = Output(0);
// Otherwise inplace gradient and Elu doesn't make sense.
CAFFE_ENFORCE_GE(alpha_, 0);
Y->ResizeLike(X);
const auto* Xdata = X.data<float>();
auto* Ydata = Y->mutable_data<float>();
elu_kernel<<<
CAFFE_GET_BLOCKS(X.size()),
CAFFE_CUDA_NUM_THREADS,
0,
context_.cuda_stream()>>>(X.size(), alpha_, Xdata, Ydata);
template <typename T>
bool EluFunctor<CUDAContext>::
operator()(const int N, const T* X, T* Y, CUDAContext* context) const {
EluCUDAKernel<T>
<<<CAFFE_GET_BLOCKS(N),
CAFFE_CUDA_NUM_THREADS,
0,
context->cuda_stream()>>>(N, alpha, X, Y);
return true;
}
template <>
bool EluGradientOp<float, CUDAContext>::RunOnDevice() {
auto& Y = Input(0);
auto& dY = Input(1);
auto* dX = Output(0);
DCHECK_GT(Y.size(), 0);
DCHECK_EQ(dY.size(), Y.size());
dX->ResizeLike(Y);
const float* Ydata = Y.data<float>();
const float* dYdata = dY.data<float>();
float* dXdata = dX->mutable_data<float>();
elu_gradient_kernel<<<
CAFFE_GET_BLOCKS(Y.size()),
CAFFE_CUDA_NUM_THREADS,
0,
context_.cuda_stream()>>>(Y.size(), alpha_, Ydata, dYdata, dXdata);
template <typename T>
bool EluGradientFunctor<CUDAContext>::Forward(
const std::vector<int>& Y_dims,
const std::vector<int>& /* dY_dims */,
const T* Y,
const T* dY,
T* dX,
CUDAContext* context) const {
const int size = std::accumulate(
Y_dims.cbegin(), Y_dims.cend(), 1, std::multiplies<int>());
EluGradientCUDAKernel<T>
<<<CAFFE_GET_BLOCKS(size),
CAFFE_CUDA_NUM_THREADS,
0,
context->cuda_stream()>>>(size, alpha, dY, Y, dX);
return true;
}
REGISTER_CUDA_OPERATOR(Elu, EluOp<float, CUDAContext>);
REGISTER_CUDA_OPERATOR(EluGradient, EluGradientOp<float, CUDAContext>);
}
REGISTER_CUDA_OPERATOR(
Elu,
UnaryElementwiseWithArgsOp<
TensorTypes<float>,
CUDAContext,
EluFunctor<CUDAContext>>);
REGISTER_CUDA_OPERATOR(
EluGradient,
BinaryElementwiseWithArgsOp<
TensorTypes<float>,
CUDAContext,
EluGradientFunctor<CUDAContext>>);
} // namespace caffe2

View File

@ -1,37 +1,40 @@
#pragma once
#ifndef CAFFE2_OPERATORS_ELU_OP_H_
#define CAFFE2_OPERATORS_ELU_OP_H_
#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include <vector>
#include "caffe2/operators/elementwise_ops.h"
namespace caffe2 {
template <typename T, class Context>
class EluOp final : public Operator<Context> {
public:
EluOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
alpha_(OperatorBase::GetSingleArgument<float>("alpha", 1.0)) {}
USE_OPERATOR_CONTEXT_FUNCTIONS;
template <class Context>
struct EluFunctor {
explicit EluFunctor(OperatorBase& op)
: alpha(op.GetSingleArgument<float>("alpha", 1.0f)) {}
bool RunOnDevice() override;
template <typename T>
bool operator()(const int N, const T* X, T* Y, Context* context) const;
protected:
T alpha_;
const float alpha;
};
template <typename T, class Context>
class EluGradientOp final : public Operator<Context> {
public:
EluGradientOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
alpha_(OperatorBase::GetSingleArgument<float>("alpha", 1.0)) {}
USE_OPERATOR_CONTEXT_FUNCTIONS;
template <class Context>
struct EluGradientFunctor {
explicit EluGradientFunctor(OperatorBase& op)
: alpha(op.GetSingleArgument<float>("alpha", 1.0f)) {}
bool RunOnDevice() override;
template <typename T>
bool Forward(
const std::vector<int>& Y_dims,
const std::vector<int>& dY_dims,
const T* Y,
const T* dY,
T* dX,
Context* context) const;
protected:
T alpha_;
const float alpha;
};
} // namespace caffe2
#endif // CAFFE2_OPERATORS_ELU_OP_H_

View File

@ -0,0 +1,109 @@
#include "caffe2/operators/elu_op.h"
#include "caffe2/operators/activation_ops_cudnn.h"
namespace caffe2 {
template <>
class CuDNNActivationOp<CUDNN_ACTIVATION_ELU> final
: public CuDNNActivationOpBase {
public:
USE_OPERATOR_FUNCTIONS(CUDAContext);
CuDNNActivationOp(const OperatorDef& operator_def, Workspace* ws)
: CuDNNActivationOpBase(operator_def, ws),
OP_SINGLE_ARG(float, "alpha", alpha_, 1.0f) {
CUDNN_ENFORCE(cudnnSetActivationDescriptor(
act_desc_,
CUDNN_ACTIVATION_ELU,
CUDNN_PROPAGATE_NAN,
static_cast<double>(alpha_)));
}
bool RunOnDevice() override {
return DispatchHelper<TensorTypes<float, float16>>::call(this, Input(0));
}
template <typename T>
bool DoRunWithType() {
const auto& X = Input(0);
auto* Y = Output(0);
Y->ResizeLike(X);
if (X.size() == 0) {
Y->template mutable_data<T>();
return true;
}
this->SetTensorDescriptor(cudnnTypeWrapper<T>::type, X.size());
CUDNN_ENFORCE(cudnnActivationForward(
this->cudnn_wrapper_.inline_cudnn_handle(),
this->act_desc_,
cudnnTypeWrapper<T>::kOne(),
this->data_desc_,
X.template data<T>(),
cudnnTypeWrapper<T>::kZero(),
this->data_desc_,
Y->template mutable_data<T>()));
return true;
}
private:
const float alpha_;
};
template <>
class CuDNNActivationGradientOp<CUDNN_ACTIVATION_ELU> final
: public CuDNNActivationOpBase {
public:
USE_OPERATOR_FUNCTIONS(CUDAContext);
CuDNNActivationGradientOp(const OperatorDef& operator_def, Workspace* ws)
: CuDNNActivationOpBase(operator_def, ws),
OP_SINGLE_ARG(float, "alpha", alpha_, 1.0f) {
CUDNN_ENFORCE(cudnnSetActivationDescriptor(
act_desc_,
CUDNN_ACTIVATION_ELU,
CUDNN_PROPAGATE_NAN,
static_cast<double>(alpha_)));
}
bool RunOnDevice() override {
return DispatchHelper<TensorTypes<float, float16>>::call(this, Input(0));
}
template <typename T>
bool DoRunWithType() {
const auto& Y = Input(0);
const auto& dY = Input(1);
auto* dX = Output(0);
dX->ResizeLike(Y);
if (Y.size() == 0) {
dX->template mutable_data<T>();
return true;
}
this->SetTensorDescriptor(cudnnTypeWrapper<T>::type, Y.size());
CUDNN_ENFORCE(cudnnActivationBackward(
this->cudnn_wrapper_.inline_cudnn_handle(),
this->act_desc_,
cudnnTypeWrapper<T>::kOne(),
this->data_desc_,
Y.template data<T>(),
this->data_desc_,
dY.template data<T>(),
this->data_desc_,
Y.template data<T>(), // Use Y_data as placeholder here.
cudnnTypeWrapper<T>::kZero(),
this->data_desc_,
dX->template mutable_data<T>()));
return true;
}
private:
const float alpha_;
};
REGISTER_CUDNN_OPERATOR(Elu, CuDNNActivationOp<CUDNN_ACTIVATION_ELU>);
REGISTER_CUDNN_OPERATOR(
EluGradient,
CuDNNActivationGradientOp<CUDNN_ACTIVATION_ELU>);
} // namespace caffe2

View File

@ -1,69 +1,42 @@
/**
* Copyright (c) 2016-present, Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "caffe2/operators/relu_n_op.h"
#include <algorithm>
#include <functional>
#include <string>
#include "caffe2/utils/eigen_utils.h"
#include "caffe2/utils/math.h"
namespace caffe2 {
template <>
bool ReluNOp<float, CPUContext>::RunOnDevice() {
auto& X = Input(0);
auto* Y = Output(0);
Y->ResizeLike(X);
EigenVectorMap<float>(Y->mutable_data<float>(), X.size()) =
ConstEigenVectorMap<float>(X.data<float>(), X.size())
.cwiseMax(0.f)
.cwiseMin(n);
template <typename T>
bool ReluNFunctor<CPUContext>::
operator()(const int N, const T* X, T* Y, CPUContext* /* context */) const {
EigenVectorMap<T>(Y, N) =
ConstEigenVectorMap<T>(X, N).cwiseMax(T(0)).cwiseMin(T(n));
return true;
}
// define a custom template unary functor
template <typename Scalar>
struct CwiseClampSignOp {
CwiseClampSignOp(const Scalar& sup) : m_sup(sup) {}
const Scalar operator()(const Scalar& x) const {
return x < 0 ? 0 : (x >= m_sup ? 0 : 1);
}
Scalar m_sup;
};
template <>
bool ReluNGradientOp<float, CPUContext>::RunOnDevice() {
auto& Y = Input(0);
auto& dY = Input(1);
auto* dX = Output(0);
CAFFE_ENFORCE_EQ(dY.size(), Y.size());
dX->ResizeLike(Y);
const float* Ydata = Y.data<float>();
const float* dYdata = dY.data<float>();
float* dXdata = dX->mutable_data<float>();
// TODO: proper vectorization with Eigen
EigenVectorArrayMap<float> dXvec(dXdata, dX->size());
ConstEigenVectorArrayMap<float> Yvec(Ydata, Y.size());
ConstEigenVectorArrayMap<float> dYvec(dYdata, dY.size());
dXvec = dYvec * Yvec.unaryExpr(CwiseClampSignOp<float>(n));
template <typename T>
bool ReluNGradientFunctor<CPUContext>::Forward(
const std::vector<int>& Y_dims,
const std::vector<int>& /* dY_dims */,
const T* Y,
const T* dY,
T* dX,
CPUContext* /* context */) const {
const int size = std::accumulate(
Y_dims.cbegin(), Y_dims.cend(), 1, std::multiplies<int>());
ConstEigenVectorArrayMap<T> Y_arr(Y, size);
EigenVectorArrayMap<T>(dX, size) =
(Y_arr > T(0) && Y_arr < T(n))
.select(ConstEigenVectorArrayMap<T>(dY, size), T(0));
return true;
}
namespace {
OpSchema::Cost CostInferenceForReluN(
const OperatorDef& def,
const vector<TensorShape>& in) {
@ -71,10 +44,21 @@ OpSchema::Cost CostInferenceForReluN(
cost.params_bytes = 0;
return cost;
}
} // namespace
REGISTER_CPU_OPERATOR(ReluN, ReluNOp<float, CPUContext>);
REGISTER_CPU_OPERATOR(ReluNGradient, ReluNGradientOp<float, CPUContext>);
REGISTER_CPU_OPERATOR(
ReluN,
UnaryElementwiseWithArgsOp<
TensorTypes<float>,
CPUContext,
ReluNFunctor<CPUContext>>);
REGISTER_CPU_OPERATOR(
ReluNGradient,
BinaryElementwiseWithArgsOp<
TensorTypes<float>,
CPUContext,
ReluNGradientFunctor<CPUContext>>);
// Input: X, output: Y
OPERATOR_SCHEMA(ReluN)
@ -103,16 +87,21 @@ ReluGradient takes both Y and dY and uses this to update dX according to the
chain rule and derivatives of the rectified linear function.
)DOC");
namespace {
class GetReluNGradient : public GradientMakerBase {
using GradientMakerBase::GradientMakerBase;
vector<OperatorDef> GetGradientDefs() override {
std::vector<OperatorDef> GetGradientDefs() override {
return SingleGradientDef(
def_.type() + "Gradient",
"",
vector<string>{O(0), GO(0)},
vector<string>{GI(0)});
std::vector<std::string>{O(0), GO(0)},
std::vector<std::string>{GI(0)});
}
};
} // namespace
REGISTER_GRADIENT(ReluN, GetReluNGradient);
} // namespace caffe2

View File

@ -1,82 +1,88 @@
/**
* Copyright (c) 2016-present, Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "caffe2/core/context_gpu.h"
#include "caffe2/operators/relu_n_op.h"
#include <algorithm>
#include <functional>
#include "caffe2/core/context_gpu.h"
namespace caffe2 {
namespace {
template <typename T>
__global__ void ReluNKernel(const int N, const T* X, T* Y, const T thres) {
__global__ void
ReluNCUDAKernel(const int N, const T threshold, const T* X, T* Y) {
CUDA_1D_KERNEL_LOOP(i, N) {
auto data = X[i];
Y[i] = data > 0 ? (data > thres ? thres : data) : 0;
#if __CUDA_ARCH__ >= 350
Y[i] = __ldg(X + i) > 0
? (__ldg(X + i) < threshold ? __ldg(X + i) : threshold)
: T(0);
#else
Y[i] = X[i] > 0 ? (X[i] < threshold ? X[i] : threshold) : T(0);
#endif
}
}
template <typename T>
__global__ void ReluNGradientKernel(
__global__ void ReluNGradientCUDAKernel(
const int N,
const T* Y,
const T threshold,
const T* dY,
T* dX,
const T thres) {
const T* Y,
T* dX) {
CUDA_1D_KERNEL_LOOP(i, N) {
auto data = Y[i];
dX[i] = data > 0 ? (data >= thres ? 0 : dY[i]) : 0;
#if __CUDA_ARCH__ >= 350
dX[i] = (__ldg(Y + i) > 0 && __ldg(Y + i) < threshold) ? dY[i] : T(0);
#else
dX[i] = (Y[i] > 0 && Y[i] < threshold) ? dY[i] : T(0);
#endif
}
}
} // namespace
template <>
bool ReluNOp<float, CUDAContext>::RunOnDevice() {
auto& X = Input(0);
auto* Y = Output(0);
CAFFE_ENFORCE_GT(X.size(), 0);
Y->ResizeLike(X);
ReluNKernel<<<
CAFFE_GET_BLOCKS(X.size()),
CAFFE_CUDA_NUM_THREADS,
0,
context_.cuda_stream()>>>(
X.size(), X.data<float>(), Y->mutable_data<float>(), n);
template <typename T>
bool ReluNFunctor<CUDAContext>::
operator()(const int N, const T* X, T* Y, CUDAContext* context) const {
ReluNCUDAKernel<T>
<<<CAFFE_GET_BLOCKS(N),
CAFFE_CUDA_NUM_THREADS,
0,
context->cuda_stream()>>>(N, n, X, Y);
return true;
}
template <>
bool ReluNGradientOp<float, CUDAContext>::RunOnDevice() {
auto& Y = Input(0);
auto& dY = Input(1);
auto* dX = Output(0);
CAFFE_ENFORCE_GT(Y.size(), 0);
CAFFE_ENFORCE_EQ(dY.size(), Y.size());
dX->ResizeLike(Y);
ReluNGradientKernel<float>
<<<CAFFE_GET_BLOCKS(Y.size()),
template <typename T>
bool ReluNGradientFunctor<CUDAContext>::Forward(
const std::vector<int>& Y_dims,
const std::vector<int>& /* dY_dims */,
const T* Y,
const T* dY,
T* dX,
CUDAContext* context) const {
const int size = std::accumulate(
Y_dims.cbegin(), Y_dims.cend(), 1, std::multiplies<int>());
ReluNGradientCUDAKernel<T>
<<<CAFFE_GET_BLOCKS(size),
CAFFE_CUDA_NUM_THREADS,
0,
context_.cuda_stream()>>>(
Y.size(),
Y.data<float>(),
dY.data<float>(),
dX->mutable_data<float>(),
n);
context->cuda_stream()>>>(size, n, dY, Y, dX);
return true;
}
REGISTER_CUDA_OPERATOR(ReluN, ReluNOp<float, CUDAContext>);
REGISTER_CUDA_OPERATOR(ReluNGradient, ReluNGradientOp<float, CUDAContext>);
REGISTER_CUDA_OPERATOR(
ReluN,
UnaryElementwiseWithArgsOp<
TensorTypes<float>,
CUDAContext,
ReluNFunctor<CUDAContext>>);
REGISTER_CUDA_OPERATOR(
ReluNGradient,
BinaryElementwiseWithArgsOp<
TensorTypes<float>,
CUDAContext,
ReluNGradientFunctor<CUDAContext>>);
} // namespace caffe2

View File

@ -1,60 +1,42 @@
/**
* Copyright (c) 2016-present, Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef CAFFE2_OPERATORS_RELU_N_OP_H_
#define CAFFE2_OPERATORS_RELU_N_OP_H_
#include "caffe2/core/common_omp.h"
#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include <vector>
#include "caffe2/operators/elementwise_ops.h"
namespace caffe2 {
template <typename T, class Context>
class ReluNOp final : public Operator<Context> {
public:
ReluNOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
n(OperatorBase::GetSingleArgument<float>("n", 6.0)) {
template <class Context>
struct ReluNFunctor {
explicit ReluNFunctor(OperatorBase& op)
: n(op.GetSingleArgument<float>("n", 6.0f)) {
CAFFE_ENFORCE_GT(n, 0, "n should be greater than 0");
}
USE_OPERATOR_CONTEXT_FUNCTIONS;
bool RunOnDevice() override;
template <typename T>
bool operator()(const int N, const T* X, T* Y, Context* context) const;
protected:
float n;
const float n;
};
template <typename T, class Context>
class ReluNGradientOp final : public Operator<Context> {
public:
ReluNGradientOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
n(OperatorBase::GetSingleArgument<float>("n", 6.0)) {
template <class Context>
struct ReluNGradientFunctor {
explicit ReluNGradientFunctor(OperatorBase& op)
: n(op.GetSingleArgument<float>("n", 6.0f)) {
CAFFE_ENFORCE_GT(n, 0, "n should be greater than 0");
}
USE_OPERATOR_CONTEXT_FUNCTIONS;
bool RunOnDevice() override;
template <typename T>
bool Forward(
const std::vector<int>& Y_dims,
const std::vector<int>& dY_dims,
const T* Y,
const T* dY,
T* dX,
Context* context) const;
protected:
// Input: Y, dY; Output: dX
float n;
const float n;
};
} // namespace caffe2

View File

@ -1,58 +1,56 @@
#include "caffe2/operators/relu_op.h"
#include <algorithm>
#include <functional>
#include <string>
#include "caffe2/utils/eigen_utils.h"
#include "caffe2/utils/math.h"
namespace caffe2 {
template <>
bool ReluOp<float, CPUContext>::RunOnDevice() {
auto& X = Input(0);
auto* Y = Output(0);
Y->ResizeLike(X);
#ifdef CAFFE2_USE_ACCELERATE
const float zero = 0.0f;
vDSP_vthres(X.data<float>(), 1, &zero, Y->mutable_data<float>(), 1, X.size());
#else
EigenVectorMap<float>(Y->mutable_data<float>(), X.size()) =
ConstEigenVectorMap<float>(X.data<float>(), X.size()).cwiseMax(0.f);
#endif
/* Naive implementation
const float* Xdata = X.data<float>();
float* Ydata = Y->mutable_data<float>();
for (int i = 0; i < X.size(); ++i) {
Ydata[i] = std::max(Xdata[i], 0.f);
}
*/
template <typename T>
bool ReluFunctor<CPUContext>::
operator()(const int N, const T* X, T* Y, CPUContext* /* context */) const {
EigenVectorMap<T>(Y, N) = ConstEigenVectorMap<float>(X, N).cwiseMax(T(0));
return true;
}
template <>
bool ReluGradientOp<float, CPUContext>::RunOnDevice() {
auto& Y = Input(0);
auto& dY = Input(1);
auto* dX = Output(0);
CAFFE_ENFORCE_EQ(dY.size(), Y.size());
dX->ResizeLike(Y);
#ifdef CAFFE2_USE_ACCELERATE
const float* Ydata = Y.data<float>();
const float* dYdata = dY.data<float>();
float* dXdata = dX->mutable_data<float>();
// TODO: proper vectorization with Eigen
EigenVectorArrayMap<float> dXvec(dXdata, dX->size());
ConstEigenVectorArrayMap<float> Yvec(Ydata, Y.size());
ConstEigenVectorArrayMap<float> dYvec(dYdata, dY.size());
dXvec = dYvec * Yvec.cwiseSign();
/* Previous implementation
for (int i = 0; i < Y.size(); ++i) {
dXdata[i] = Ydata[i] > 0 ? dYdata[i] : 0;
}
*/
template <>
template <>
bool ReluFunctor<CPUContext>::operator()<float>(
const int N,
const float* X,
float* Y,
CPUContext* /* context */) const {
const float zero = 0.0f;
vDSP_vthres(X, 1, &zero, Y, 1, N);
return true;
}
#endif // CAFFE2_USE_ACCELERATE
template <>
template <typename T>
bool ReluGradientFunctor<CPUContext>::Forward(
const std::vector<int>& Y_dims,
const std::vector<int>& /* dY_dims */,
const T* Y,
const T* dY,
T* dX,
CPUContext* /* context */) const {
const int size = std::accumulate(
Y_dims.cbegin(), Y_dims.cend(), 1, std::multiplies<int>());
EigenVectorArrayMap<T>(dX, size) =
(ConstEigenVectorArrayMap<T>(Y, size) > T(0))
.select(ConstEigenVectorArrayMap<T>(dY, size), T(0));
return true;
}
namespace {
OpSchema::Cost CostInferenceForRelu(
const OperatorDef& def,
const vector<TensorShape>& in) {
@ -60,10 +58,21 @@ OpSchema::Cost CostInferenceForRelu(
cost.params_bytes = 0;
return cost;
}
} // namespace
REGISTER_CPU_OPERATOR(Relu, ReluOp<float, CPUContext>);
REGISTER_CPU_OPERATOR(ReluGradient, ReluGradientOp<float, CPUContext>);
REGISTER_CPU_OPERATOR(
Relu,
UnaryElementwiseOp<
TensorTypes<float>,
CPUContext,
ReluFunctor<CPUContext>>);
REGISTER_CPU_OPERATOR(
ReluGradient,
BinaryElementwiseOp<
TensorTypes<float>,
CPUContext,
ReluGradientFunctor<CPUContext>>);
// Input: X, output: Y
OPERATOR_SCHEMA(Relu)
@ -140,17 +149,21 @@ ReluGradient takes both Y and dY and uses this to update dX according to the
chain rule and derivatives of the rectified linear function.
)DOC");
namespace {
class GetReluGradient : public GradientMakerBase {
using GradientMakerBase::GradientMakerBase;
vector<OperatorDef> GetGradientDefs() override {
std::vector<OperatorDef> GetGradientDefs() override {
return SingleGradientDef(
def_.type() + "Gradient",
"",
vector<string>{O(0), GO(0)},
vector<string>{GI(0)});
std::vector<std::string>{O(0), GO(0)},
std::vector<std::string>{GI(0)});
}
};
REGISTER_GRADIENT(Relu, GetReluGradient);
REGISTER_GRADIENT(ReluFp16, GetReluGradient);
} // namespace caffe2
} // namespace
REGISTER_GRADIENT(Relu, GetReluGradient);
} // namespace caffe2

View File

@ -1,57 +1,198 @@
#include "caffe2/core/context_gpu.h"
#include "caffe2/operators/relu_op.h"
#include <algorithm>
#include <functional>
#include "caffe2/core/context_gpu.h"
namespace caffe2 {
namespace {
template <typename T>
__global__ void ReluKernel(const int N, const T* X, T* Y) {
__global__ void ReluCUDAKernel(const int N, const T* X, T* Y) {
CUDA_1D_KERNEL_LOOP(i, N) {
Y[i] = X[i] > 0 ? X[i] : 0;
#if __CUDA_ARCH__ >= 350
Y[i] = __ldg(X + i) > 0 ? __ldg(X + i) : T(0);
#else
Y[i] = X[i] > 0 ? X[i] : T(0);
#endif
}
}
__global__ void ReluHalfCUDAKernel(const int N, const half* X, half* Y) {
const half kZero = __float2half(0.0f);
CUDA_1D_KERNEL_LOOP(i, N) {
#if __CUDA_ARCH__ >= 530
Y[i] = __hgt(__ldg(X + i), kZero) ? __ldg(X + i) : kZero;
#else
Y[i] = (__half2float(X[i]) > 0) ? X[i] : kZero;
#endif
}
}
__global__ void ReluHalf2CUDAKernel(const int N, const half2* X, half2* Y) {
const half2 kZero = __float2half2_rn(0.0f);
CUDA_1D_KERNEL_LOOP(i, N) {
#if __CUDA_ARCH__ >= 530
Y[i] = __hmul2(__hgt2(__ldg(X + i), kZero), __ldg(X + i));
#else
const float2 xx = __half22float2(X[i]);
Y[i] = __floats2half2_rn(xx.x > 0 ? xx.x : 0.f, xx.y > 0 ? xx.y : 0.f);
#endif
}
}
template <typename T>
__global__ void
ReluGradientKernel(const int N, const T* Y, const T* dY, T* dX) {
ReluGradientCUDAKernel(const int N, const T* dY, const T* Y, T* dX) {
CUDA_1D_KERNEL_LOOP(i, N) {
#if __CUDA_ARCH__ >= 350
dX[i] = __ldg(Y + i) > 0 ? __ldg(dY + i) : 0;
#else
dX[i] = Y[i] > 0 ? dY[i] : 0;
#endif
}
}
__global__ void ReluGradientHalfCUDAKernel(
const int N,
const half* dY,
const half* Y,
half* dX) {
const half kZero = __float2half(0.0f);
CUDA_1D_KERNEL_LOOP(i, N) {
#if __CUDA_ARCH__ >= 530
dX[i] = __hgt(__ldg(Y + i), kZero) ? __ldg(dY + i) : kZero;
#else
dX[i] = (__half2float(Y[i]) > 0) ? dY[i] : kZero;
#endif
}
}
__global__ void ReluGradientHalf2CUDAKernel(
const int N,
const half2* dY,
const half2* Y,
half2* dX) {
const half2 kZero = __float2half2_rn(0.0f);
CUDA_1D_KERNEL_LOOP(i, N) {
#if __CUDA_ARCH__ >= 530
dX[i] = __hmul2(__hgt2(__ldg(Y + i), kZero), __ldg(dY + i));
#else
const float2 dy = __half22float2(dY[i]);
const float2 yy = __half22float2(Y[i]);
dX[i] = __floats2half2_rn(yy.x > 0 ? dy.x : 0.f, yy.y > 0 ? dy.y : 0.f);
#endif
}
}
} // namespace
template <>
bool ReluOp<float, CUDAContext>::RunOnDevice() {
auto& X = Input(0);
auto* Y = Output(0);
CAFFE_ENFORCE_GE(X.size(), 0);
Y->ResizeLike(X);
ReluKernel<<<
CAFFE_GET_BLOCKS(X.size()),
CAFFE_CUDA_NUM_THREADS,
0,
context_.cuda_stream()>>>(
X.size(), X.data<float>(), Y->mutable_data<float>());
template <typename T>
bool ReluFunctor<CUDAContext>::
operator()(const int N, const T* X, T* Y, CUDAContext* context) const {
ReluCUDAKernel<T>
<<<CAFFE_GET_BLOCKS(N),
CAFFE_CUDA_NUM_THREADS,
0,
context->cuda_stream()>>>(N, X, Y);
return true;
}
template <>
bool ReluGradientOp<float, CUDAContext>::RunOnDevice() {
auto& Y = Input(0);
auto& dY = Input(1);
auto* dX = Output(0);
CAFFE_ENFORCE_GE(Y.size(), 0);
CAFFE_ENFORCE_EQ(dY.size(), Y.size());
dX->ResizeLike(Y);
ReluGradientKernel<<<
CAFFE_GET_BLOCKS(Y.size()),
CAFFE_CUDA_NUM_THREADS,
0,
context_.cuda_stream()>>>(
Y.size(), Y.data<float>(), dY.data<float>(), dX->mutable_data<float>());
template <>
bool ReluFunctor<CUDAContext>::operator()<float16>(
const int N,
const float16* X,
float16* Y,
CUDAContext* context) const {
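// If N is even, reinterpret the fp16 buffers as half2 and process two
// elements per thread; otherwise fall back to the scalar half kernel.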
if ((N & 1) == 0) {
ReluHalf2CUDAKernel<<<
CAFFE_GET_BLOCKS((N >> 1)),
CAFFE_CUDA_NUM_THREADS,
0,
context->cuda_stream()>>>(
(N >> 1),
reinterpret_cast<const half2*>(X),
reinterpret_cast<half2*>(Y));
} else {
ReluHalfCUDAKernel<<<
CAFFE_GET_BLOCKS(N),
CAFFE_CUDA_NUM_THREADS,
0,
context->cuda_stream()>>>(
N, reinterpret_cast<const half*>(X), reinterpret_cast<half*>(Y));
}
return true;
}
REGISTER_CUDA_OPERATOR(Relu, ReluOp<float, CUDAContext>);
REGISTER_CUDA_OPERATOR(ReluGradient, ReluGradientOp<float, CUDAContext>);
template <>
template <typename T>
bool ReluGradientFunctor<CUDAContext>::Forward(
const std::vector<int>& Y_dims,
const std::vector<int>& /* dY_dims */,
const T* Y,
const T* dY,
T* dX,
CUDAContext* context) const {
const int size = std::accumulate(
Y_dims.cbegin(), Y_dims.cend(), 1, std::multiplies<int>());
ReluGradientCUDAKernel<T>
<<<CAFFE_GET_BLOCKS(size),
CAFFE_CUDA_NUM_THREADS,
0,
context->cuda_stream()>>>(size, dY, Y, dX);
return true;
}
template <>
template <>
bool ReluGradientFunctor<CUDAContext>::Forward<float16>(
const std::vector<int>& Y_dims,
const std::vector<int>& /* dY_dims */,
const float16* Y,
const float16* dY,
float16* dX,
CUDAContext* context) const {
const int size = std::accumulate(
Y_dims.cbegin(), Y_dims.cend(), 1, std::multiplies<int>());
if ((size & 1) == 0) {
ReluGradientHalf2CUDAKernel<<<
CAFFE_GET_BLOCKS((size >> 1)),
CAFFE_CUDA_NUM_THREADS,
0,
context->cuda_stream()>>>(
(size >> 1),
reinterpret_cast<const half2*>(dY),
reinterpret_cast<const half2*>(Y),
reinterpret_cast<half2*>(dX));
} else {
ReluGradientHalfCUDAKernel<<<
CAFFE_GET_BLOCKS(size),
CAFFE_CUDA_NUM_THREADS,
0,
context->cuda_stream()>>>(
size,
reinterpret_cast<const half*>(dY),
reinterpret_cast<const half*>(Y),
reinterpret_cast<half*>(dX));
}
return true;
}
REGISTER_CUDA_OPERATOR(
Relu,
UnaryElementwiseOp<
TensorTypes<float, float16>,
CUDAContext,
ReluFunctor<CUDAContext>>);
REGISTER_CUDA_OPERATOR(
ReluGradient,
BinaryElementwiseOp<
TensorTypes<float, float16>,
CUDAContext,
ReluGradientFunctor<CUDAContext>>);
} // namespace caffe2

View File

@ -1,34 +1,28 @@
#ifndef CAFFE2_OPERATORS_RELU_OP_H_
#define CAFFE2_OPERATORS_RELU_OP_H_
#include "caffe2/core/common_omp.h"
#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include <vector>
#include "caffe2/operators/elementwise_ops.h"
namespace caffe2 {
template <typename T, class Context>
class ReluOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(ReluOp);
USE_OPERATOR_CONTEXT_FUNCTIONS;
bool RunOnDevice() override;
protected:
template <class Context>
struct ReluFunctor {
template <typename T>
bool operator()(const int N, const T* X, T* Y, Context* context) const;
};
template <typename T, class Context>
class ReluGradientOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(ReluGradientOp);
USE_OPERATOR_CONTEXT_FUNCTIONS;
bool RunOnDevice() override;
protected:
// Input: Y, dY; Output: dX
template <class Context>
struct ReluGradientFunctor {
template <typename T>
bool Forward(
const std::vector<int>& Y_dims,
const std::vector<int>& dY_dims,
const T* Y,
const T* dY,
T* dX,
Context* context) const;
};
} // namespace caffe2

View File

@ -1,208 +1,12 @@
#include "caffe2/core/context_gpu.h"
#include "caffe2/core/cudnn_wrappers.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/types.h"
#include "caffe2/operators/relu_op.h"
#include "caffe2/operators/activation_ops_cudnn.h"
namespace caffe2 {
class CuDNNReluOp final : public Operator<CUDAContext> {
public:
CuDNNReluOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<CUDAContext>(operator_def, ws),
cudnn_wrapper_(&context_),
order_(StringToStorageOrder(
OperatorBase::GetSingleArgument<string>("order", "NCHW"))) {
CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&data_desc_));
CUDNN_ENFORCE(cudnnCreateActivationDescriptor(&activ_desc_));
CUDNN_ENFORCE(cudnnSetActivationDescriptor(
activ_desc_, CUDNN_ACTIVATION_RELU, CUDNN_PROPAGATE_NAN, 0.0));
}
REGISTER_CUDNN_OPERATOR(Relu, CuDNNActivationOp<CUDNN_ACTIVATION_RELU>);
REGISTER_CUDNN_OPERATOR(
ReluGradient,
CuDNNActivationGradientOp<CUDNN_ACTIVATION_RELU>);
~CuDNNReluOp() {
CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(data_desc_));
CUDNN_ENFORCE(cudnnDestroyActivationDescriptor(activ_desc_));
}
template <typename T>
bool DoRunWithType() {
const auto& X = Input(0);
auto* Y = Output(0);
// Return if X is empty
if (X.size() == 0) {
Y->mutable_data<T>();
return true;
}
// See if we need to reshape.
if (X.dims() != cudnn_input_dims_) {
VLOG(1) << "Setting descriptors.";
cudnn_input_dims_ = X.dims();
int C = 1, H = 1, W = 1;
if (X.ndim() == 4) {
// Normal 4-dimensional tensors for images.
C = (order_ == StorageOrder::NCHW ? X.dim32(1) : X.dim32(3));
H = (order_ == StorageOrder::NCHW ? X.dim32(2) : X.dim32(1));
W = (order_ == StorageOrder::NCHW ? X.dim32(3) : X.dim32(2));
} else {
// If X is not 4-dimensional, we will simply use H = 1 and W = 1
// and wrap everything into C.
C = X.size() / X.dim32(0);
}
CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
data_desc_,
GetCudnnTensorFormat(order_),
cudnnTypeWrapper<T>::type,
X.dim32(0),
C,
H,
W));
}
CUDNN_ENFORCE(cudnnActivationForward(
cudnn_wrapper_.inline_cudnn_handle(),
activ_desc_,
cudnnTypeWrapper<T>::kOne(),
data_desc_,
X.template data<T>(),
cudnnTypeWrapper<T>::kZero(),
data_desc_,
Y->template mutable_data<T>()));
return true;
}
bool RunOnDevice() override {
// dispatch based on contents of tensor(s)
const auto& X = Input(0);
auto* Y = Output(0);
Y->ResizeLike(X);
if (X.IsType<float>()) {
return DoRunWithType<float>();
} else if (X.IsType<float16>()) {
return DoRunWithType<float16>();
} else {
LOG(FATAL) << "Unsupported input types";
}
return true;
}
protected:
CuDNNWrapper cudnn_wrapper_;
cudnnTensorDescriptor_t data_desc_;
cudnnActivationDescriptor_t activ_desc_;
vector<TIndex> cudnn_input_dims_;
StorageOrder order_;
};
// Note: You can see that in CuDNNReluGradientOp, we abused the cudnn interface
// by passing in the output tensor for both bottom and top. This is dependent on
// the assumption that the Relu gradient actually does not rely on the bottom
// data, or it treats input=0 the same way as input<0. This is of course not
// very safe, but we have been running in this way in Caffe for a while so it
// *might* be safe to assume so.
class CuDNNReluGradientOp final : public Operator<CUDAContext> {
public:
CuDNNReluGradientOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<CUDAContext>(operator_def, ws),
cudnn_wrapper_(&context_),
order_(StringToStorageOrder(
OperatorBase::GetSingleArgument<string>("order", "NCHW"))) {
CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&data_desc_));
CUDNN_ENFORCE(cudnnCreateActivationDescriptor(&activ_desc_));
CUDNN_ENFORCE(cudnnSetActivationDescriptor(
activ_desc_, CUDNN_ACTIVATION_RELU, CUDNN_PROPAGATE_NAN, 0.0));
}
~CuDNNReluGradientOp() {
CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(data_desc_));
CUDNN_ENFORCE(cudnnDestroyActivationDescriptor(activ_desc_));
}
template <typename T>
bool DoRunWithType() {
const auto& Y = Input(0);
const auto& dY = Input(1);
auto* dX = Output(0);
// Return if Y is empty
if (Y.size() == 0) {
dX->mutable_data<T>();
return true;
}
// See if we need to reshape.
if (Y.dims() != cudnn_input_dims_) {
VLOG(1) << "Setting descriptors.";
cudnn_input_dims_ = Y.dims();
int C = 1, H = 1, W = 1;
if (Y.ndim() == 4) {
// Normal 4-dimensional tensors for images.
C = (order_ == StorageOrder::NCHW ? Y.dim32(1) : Y.dim32(3));
H = (order_ == StorageOrder::NCHW ? Y.dim32(2) : Y.dim32(1));
W = (order_ == StorageOrder::NCHW ? Y.dim32(3) : Y.dim32(2));
} else {
// If Y is not 4-dimensional, we will simply use H = 1 and W = 1
// and wrap everything into C.
C = Y.size() / Y.dim32(0);
}
CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
data_desc_,
GetCudnnTensorFormat(order_),
cudnnTypeWrapper<T>::type,
Y.dim32(0),
C,
H,
W));
}
CUDNN_ENFORCE(cudnnActivationBackward(
cudnn_wrapper_.inline_cudnn_handle(),
activ_desc_,
cudnnTypeWrapper<T>::kOne(),
data_desc_,
Y.template data<T>(),
data_desc_,
dY.template data<T>(),
data_desc_,
// Note: strictly speaking, we should be using the input data in this
// case, but for the ReLU case we rely on the underlying implementation
// that only the output is needed to calculate the Relu gradient. This
// will enable us to do memory optimization for in-place relu. To
// ensure this is correct, a unit test is provided at
// caffe2/python/operator_test/relu_op_test.py
Y.template data<T>(),
cudnnTypeWrapper<T>::kZero(),
data_desc_,
dX->template mutable_data<T>()));
return true;
}
bool RunOnDevice() override {
const auto& Y = Input(0);
auto* dX = Output(0);
dX->ResizeLike(Y);
if (Y.IsType<float>()) {
return DoRunWithType<float>();
} else if (Y.IsType<float16>()) {
return DoRunWithType<float16>();
} else {
LOG(FATAL) << "Unsupported input types";
}
return true;
}
protected:
CuDNNWrapper cudnn_wrapper_;
cudnnTensorDescriptor_t data_desc_;
cudnnActivationDescriptor_t activ_desc_;
vector<TIndex> cudnn_input_dims_;
StorageOrder order_;
// Input: Y, dY; Output: dX
};
namespace {
REGISTER_CUDNN_OPERATOR(Relu, CuDNNReluOp);
REGISTER_CUDNN_OPERATOR(ReluGradient, CuDNNReluGradientOp);
} // namespace
} // namespace caffe2
} // namespace caffe2

View File

@ -1,90 +0,0 @@
#include "caffe2/core/common_gpu.h"
#ifdef CAFFE_HAS_CUDA_FP16
#include "caffe2/core/context_gpu.h"
#include "caffe2/operators/relu_op.h"
namespace caffe2 {
namespace {
__global__ void ReluKernelHalf(const int N, const half* X, half* Y) {
const half kZero = __float2half(0.0);
CUDA_1D_KERNEL_LOOP(i, N) {
#if __CUDA_ARCH__ >= 530
Y[i] = __hgt(X[i], kZero) ? X[i] : kZero;
#else
Y[i] = (__half2float(X[i]) > 0) ? X[i] : kZero;
#endif
}
}
__global__ void ReluKernelHalf2(const int N, const half2* X, half2* Y) {
const half2 kZero = __float2half2_rn(0.0);
CUDA_1D_KERNEL_LOOP(i, N) {
#if __CUDA_ARCH__ >= 530
Y[i] = __hmul2(__hgt2(X[i], kZero), X[i]);
#else
float2 xx = __half22float2(X[i]);
Y[i] = __floats2half2_rn(xx.x > 0 ? xx.x : 0.f,
xx.y > 0 ? xx.y : 0.f);
#endif
}
}
__global__ void ReluGradientKernelHalf(
const int N, const half* Y, const half* dY, half* dX) {
const half kZero = __float2half(0.0);
CUDA_1D_KERNEL_LOOP(i, N) {
#if __CUDA_ARCH__ >= 530
dX[i] = __hgt(Y[i], kZero) ? dY[i] : kZero;
#else
dX[i] = (__half2float(Y[i]) > 0) ? dY[i] : kZero;
#endif
}
}
} // namespace
template <>
bool ReluOp<float16, CUDAContext>::RunOnDevice() {
auto& X = Input(0);
auto* Y = Output(0);
CAFFE_ENFORCE_GT(X.size(), 0);
Y->ResizeLike(X);
if (X.size() % 2 == 0) {
ReluKernelHalf2<<<CAFFE_GET_BLOCKS(X.size() / 2), CAFFE_CUDA_NUM_THREADS,
0, context_.cuda_stream()>>>(
X.size() / 2, reinterpret_cast<const half2*>(X.data<float16>()),
reinterpret_cast<half2*>(Y->mutable_data<float16>()));
return true;
} else {
ReluKernelHalf<<<CAFFE_GET_BLOCKS(X.size()), CAFFE_CUDA_NUM_THREADS,
0, context_.cuda_stream()>>>(
X.size(), reinterpret_cast<const half*>(X.data<float16>()),
reinterpret_cast<half*>(Y->mutable_data<float16>()));
return true;
}
}
template <>
bool ReluGradientOp<float16, CUDAContext>::RunOnDevice() {
auto& Y = Input(0);
auto& dY = Input(1);
auto* dX = Output(0);
CAFFE_ENFORCE_GT(Y.size(), 0);
CAFFE_ENFORCE_EQ(dY.size(), Y.size());
dX->ResizeLike(Y);
ReluGradientKernelHalf<<<CAFFE_GET_BLOCKS(Y.size()), CAFFE_CUDA_NUM_THREADS,
0, context_.cuda_stream()>>>(
Y.size(), reinterpret_cast<const half*>(Y.data<float16>()),
reinterpret_cast<const half*>(dY.data<float16>()),
reinterpret_cast<half*>(dX->mutable_data<float16>()));
return true;
}
OPERATOR_SCHEMA(ReluFp16);
OPERATOR_SCHEMA(ReluFp16Gradient);
REGISTER_CUDA_OPERATOR(ReluFp16, ReluOp<float16, CUDAContext>);
REGISTER_CUDA_OPERATOR(ReluFp16Gradient, ReluGradientOp<float16, CUDAContext>);
} // namespace caffe2
#endif // CAFFE_HAS_CUDA_FP16

View File

@ -22,7 +22,7 @@ bool SigmoidGradientFunctor<CPUContext>::Forward(
Y_dims.cbegin(), Y_dims.cend(), 1, std::multiplies<int>());
ConstEigenVectorArrayMap<T> dY_arr(dY, size);
ConstEigenVectorArrayMap<T> Y_arr(Y, size);
EigenVectorArrayMap<T>(dX, size) = dY_arr * Y_arr * (1. - Y_arr);
EigenVectorArrayMap<T>(dX, size) = dY_arr * Y_arr * (T(1) - Y_arr);
return true;
}

View File

@ -8,8 +8,8 @@ template <>
template <typename T>
bool SigmoidFunctor<CPUContext>::
operator()(const int N, const T* X, T* Y, CPUContext* /* context */) const {
ConstEigenVectorArrayMap<T> X_arr(X, N);
EigenVectorArrayMap<T>(Y, N) = 1. / (1. + (-X_arr).exp());
EigenVectorArrayMap<T>(Y, N) =
T(1) / (T(1) + (-ConstEigenVectorArrayMap<T>(X, N)).exp());
return true;
}

View File

@ -10,19 +10,23 @@ namespace caffe2 {
namespace {
template <typename T>
__global__ void SigmoidKernel(const int N, const T* X, T* Y) {
__global__ void SigmoidCUDAKernel(const int N, const T* X, T* Y);
template <>
__global__ void
SigmoidCUDAKernel<float>(const int N, const float* X, float* Y) {
CUDA_1D_KERNEL_LOOP(i, N) {
#if __CUDA_ARCH__ >= 350
Y[i] = T(1) / (T(1) + exp(-__ldg(X + i)));
Y[i] = 1.0f / (1.0f + expf(-__ldg(X + i)));
#else
Y[i] = T(1) / (T(1) + exp(-X[i]));
Y[i] = 1.0f / (1.0f + expf(-X[i]));
#endif
}
}
template <typename T>
__global__ void
SigmoidGradientKernel(const int N, const T* dY, const T* Y, T* dX) {
SigmoidGradientCUDAKernel(const int N, const T* dY, const T* Y, T* dX) {
CUDA_1D_KERNEL_LOOP(i, N) {
#if __CUDA_ARCH__ >= 350
dX[i] = __ldg(dY + i) * __ldg(Y + i) * (T(1) - __ldg(Y + i));
@ -38,7 +42,7 @@ template <>
template <typename T>
bool SigmoidFunctor<CUDAContext>::
operator()(const int N, const T* X, T* Y, CUDAContext* context) const {
SigmoidKernel<T>
SigmoidCUDAKernel<T>
<<<CAFFE_GET_BLOCKS(N),
CAFFE_CUDA_NUM_THREADS,
0,
@ -57,7 +61,7 @@ bool SigmoidGradientFunctor<CUDAContext>::Forward(
CUDAContext* context) const {
const int size = std::accumulate(
Y_dims.cbegin(), Y_dims.cend(), 1, std::multiplies<int>());
SigmoidGradientKernel<T>
SigmoidGradientCUDAKernel<T>
<<<CAFFE_GET_BLOCKS(size),
CAFFE_CUDA_NUM_THREADS,
0,

View File

@ -0,0 +1,12 @@
#include "caffe2/operators/sigmoid_op.h"
#include "caffe2/operators/activation_ops_cudnn.h"
namespace caffe2 {
REGISTER_CUDNN_OPERATOR(Sigmoid, CuDNNActivationOp<CUDNN_ACTIVATION_SIGMOID>);
REGISTER_CUDNN_OPERATOR(
SigmoidGradient,
CuDNNActivationGradientOp<CUDNN_ACTIVATION_SIGMOID>);
} // namespace caffe2

View File

@ -0,0 +1,12 @@
#include "caffe2/operators/tanh_op.h"
#include "caffe2/operators/activation_ops_cudnn.h"
namespace caffe2 {
REGISTER_CUDNN_OPERATOR(Tanh, CuDNNActivationOp<CUDNN_ACTIVATION_TANH>);
REGISTER_CUDNN_OPERATOR(
TanhGradient,
CuDNNActivationGradientOp<CUDNN_ACTIVATION_TANH>);
} // namespace caffe2

View File

@ -8,35 +8,123 @@ import numpy as np
from hypothesis import given
import hypothesis.strategies as st
from caffe2.python import core
from caffe2.python import core, workspace
import caffe2.python.hypothesis_test_util as hu
import caffe2.python.mkl_test_util as mu
import unittest
class TestActivations(hu.HypothesisTestCase):
@given(X=hu.tensor(), in_place=st.booleans(),
engine=st.sampled_from(["", "CUDNN"]), **mu.gcs)
def test_relu(self, X, in_place, engine, gc, dc):
if gc == mu.mkl_do:
in_place = False
op = core.CreateOperator(
"Relu",
["X"],
["X"] if in_place else ["Y"],
engine=engine,
)
def relu_ref(X):
return [np.maximum(X, 0.0)]
# go away from the origin point to avoid kink problems
X += 0.02 * np.sign(X)
X[X == 0.0] += 0.02
self.assertReferenceChecks(gc, op, [X], relu_ref)
self.assertDeviceChecks(dc, op, [X], [0])
self.assertGradientChecks(gc, op, [X], 0, [0])
@unittest.skipIf(not workspace.has_gpu_support,
"Relu for float16 can only run on GPU now.")
@given(X=hu.tensor(dtype=np.float16), in_place=st.booleans(),
engine=st.sampled_from(["", "CUDNN"]), **hu.gcs_gpu_only)
def test_relu_fp16(self, X, in_place, engine, gc, dc):
op = core.CreateOperator(
"Relu",
["X"],
["X"] if in_place else ["Y"],
engine=engine,
)
def relu_ref(X):
return [np.maximum(X, 0.0)]
def relu_grad_ref(g_out, outputs, fwd_inputs):
dY = g_out
[Y] = outputs
dX = dY
dX[Y == 0] = 0
return [dX]
# go away from the origin point to avoid kink problems
X += 0.02 * np.sign(X)
X[X == 0.0] += 0.02
self.assertReferenceChecks(
hu.gpu_do,
op,
[X],
relu_ref,
output_to_grad="X" if in_place else "Y",
grad_reference=relu_grad_ref)
@given(X=hu.tensor(elements=st.floats(-3.0, 3.0)),
n=st.floats(min_value=0.5, max_value=2.0),
in_place=st.booleans(), **hu.gcs)
def test_relu_n(self, X, n, in_place, gc, dc):
op = core.CreateOperator(
"ReluN",
["X"],
["X"] if in_place else ["Y"],
n=n,
)
def relu_n_ref(X):
return [np.minimum(np.maximum(X, 0.0), n)]
# go away from 0 and n to avoid kink problems
X += 0.04 * np.sign(X)
X[X == 0.0] += 0.04
X -= n
X += 0.02 * np.sign(X)
X[X == 0.0] -= 0.02
X += n
self.assertReferenceChecks(gc, op, [X], relu_n_ref)
self.assertDeviceChecks(dc, op, [X], [0])
self.assertGradientChecks(gc, op, [X], 0, [0], stepsize=0.005)
@given(X=hu.tensor(),
alpha=st.floats(min_value=0.1, max_value=2.0),
inplace=st.booleans(),
in_place=st.booleans(), engine=st.sampled_from(["", "CUDNN"]),
**hu.gcs)
def test_elu(self, X, alpha, inplace, gc, dc):
def test_elu(self, X, alpha, in_place, engine, gc, dc):
op = core.CreateOperator(
"Elu",
["X"],
["X"] if in_place else ["Y"],
alpha=alpha,
engine=engine,
)
def elu_ref(X):
Y = X
Y[X < 0] = alpha * (np.exp(X[X < 0]) - 1.0)
return [Y]
# go away from the origin point to avoid kink problems
X += 0.04 * np.sign(X)
X[X == 0.0] += 0.04
def elu_ref(X):
Y = X.copy()
neg_indices = X <= 0
Y[neg_indices] = alpha * (np.exp(Y[neg_indices]) - 1)
return (Y,)
op = core.CreateOperator(
"Elu",
["X"], ["Y" if not inplace else "X"],
alpha=alpha)
self.assertReferenceChecks(gc, op, [X], elu_ref)
# Check over multiple devices
self.assertDeviceChecks(dc, op, [X], [0])
# Gradient check wrt X
self.assertGradientChecks(gc, op, [X], 0, [0])
self.assertGradientChecks(gc, op, [X], 0, [0], stepsize=1e-2)
@given(X=hu.tensor(min_dim=4, max_dim=4),
alpha=st.floats(min_value=0.1, max_value=2.0),
@ -124,8 +212,3 @@ class TestActivations(hu.HypothesisTestCase):
self.assertReferenceChecks(gc, op, [X], leaky_relu_ref)
# Check over multiple devices
self.assertDeviceChecks(dc, op, [X], [0])
if __name__ == "__main__":
import unittest
unittest.main()

View File

@ -311,12 +311,14 @@ class TestElementwiseOps(hu.HypothesisTestCase):
reference=swish_gradient,
)
@given(X=hu.tensor(dtype=np.float32), inplace=st.booleans(), **hu.gcs)
def test_sigmoid(self, X, inplace, gc, dc):
@given(X=hu.tensor(dtype=np.float32), inplace=st.booleans(),
engine=st.sampled_from(["", "CUDNN"]), **hu.gcs)
def test_sigmoid(self, X, inplace, engine, gc, dc):
op = core.CreateOperator(
"Sigmoid",
["X"],
["X"] if inplace else ["Y"],
engine=engine,
)
def sigmoid_ref(X):

View File

@ -11,8 +11,12 @@ import numpy as np
class TestHyperbolicOps(hu.HypothesisTestCase):
def _test_hyperbolic_op(self, op_name, np_ref, X, in_place, gc, dc):
op = core.CreateOperator(op_name, ["X"], ["X"] if in_place else ["Y"])
def _test_hyperbolic_op(self, op_name, np_ref, X, in_place, engine, gc, dc):
op = core.CreateOperator(
op_name,
["X"],
["X"] if in_place else ["Y"],
engine=engine,)
def ref(X):
return [np_ref(X)]
@ -26,14 +30,15 @@ class TestHyperbolicOps(hu.HypothesisTestCase):
self.assertDeviceChecks(dc, op, [X], [0])
self.assertGradientChecks(gc, op, [X], 0, [0])
@given(X=hu.tensor(dtype=np.float32), in_place=st.booleans(), **hu.gcs)
def test_tanh(self, X, in_place, gc, dc):
self._test_hyperbolic_op("Tanh", np.tanh, X, in_place, gc, dc)
@given(X=hu.tensor(dtype=np.float32), **hu.gcs)
def test_sinh(self, X, gc, dc):
self._test_hyperbolic_op("Sinh", np.sinh, X, False, gc, dc)
self._test_hyperbolic_op("Sinh", np.sinh, X, False, "", gc, dc)
@given(X=hu.tensor(dtype=np.float32), **hu.gcs)
def test_cosh(self, X, gc, dc):
self._test_hyperbolic_op("Cosh", np.cosh, X, False, gc, dc)
self._test_hyperbolic_op("Cosh", np.cosh, X, False, "", gc, dc)
@given(X=hu.tensor(dtype=np.float32), in_place=st.booleans(),
engine=st.sampled_from(["", "CUDNN"]), **hu.gcs)
def test_tanh(self, X, in_place, engine, gc, dc):
self._test_hyperbolic_op("Tanh", np.tanh, X, in_place, engine, gc, dc)

View File

@ -1,51 +0,0 @@
# Copyright (c) 2016-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from caffe2.python import core
from hypothesis import given
import caffe2.python.hypothesis_test_util as hu
import numpy as np
import unittest
class TestRelu(hu.HypothesisTestCase):
@given(X=hu.tensor(),
**hu.gcs)
def test_relu_n(self, X, gc, dc):
X = 0.8 * np.sign(X)
X = X - 0.5
X[X == 0.0] = 0.01
n = max(np.max(X), 1.0) / 2
X[X == 0.2] = 0.01
def relu_n_ref(X):
Y = np.minimum(np.maximum(X, 0), n)
return [Y]
op = core.CreateOperator("ReluN", ["X"], ["Y"], n=n)
self.assertReferenceChecks(gc, op, [X], relu_n_ref)
self.assertDeviceChecks(dc, op, [X], [0])
self.assertGradientChecks(gc, op, [X], 0, [0], stepsize=0.001, threshold=0.001)
if __name__ == "__main__":
unittest.main()

View File

@ -1,31 +0,0 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from caffe2.python import core
from hypothesis import given
import hypothesis.strategies as st
import caffe2.python.hypothesis_test_util as hu
import caffe2.python.mkl_test_util as mu
import numpy as np
import unittest
class TestRelu(hu.HypothesisTestCase):
@given(X=hu.tensor(),
engine=st.sampled_from(["", "CUDNN"]),
**mu.gcs)
def test_relu(self, X, gc, dc, engine):
op = core.CreateOperator("Relu", ["X"], ["Y"], engine=engine)
# go away from the origin point to avoid kink problems
X += 0.02 * np.sign(X)
X[X == 0.0] += 0.02
self.assertDeviceChecks(dc, op, [X], [0])
self.assertGradientChecks(gc, op, [X], 0, [0])
if __name__ == "__main__":
unittest.main()