mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	Add cudnn activation ops (#9379)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/9379 Add cudnn activation ops Reviewed By: houseroad Differential Revision: D8818013 fbshipit-source-id: d3881c634a46578b9331da07f9fdf7e1f31d7e8a
This commit is contained in:
		
				
					committed by
					
						 Facebook Github Bot
						Facebook Github Bot
					
				
			
			
				
	
			
			
			
						parent
						
							b15a7d05ce
						
					
				
				
					commit
					bb9ff58c6d
				
			
							
								
								
									
										136
									
								
								caffe2/operators/activation_ops_cudnn.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										136
									
								
								caffe2/operators/activation_ops_cudnn.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,136 @@ | ||||
| #ifndef CAFFE2_OPERATORS_ACTIVATION_OPS_CUDNN_H_ | ||||
| #define CAFFE2_OPERATORS_ACTIVATION_OPS_CUDNN_H_ | ||||
|  | ||||
| #include "caffe2/core/context_gpu.h" | ||||
| #include "caffe2/core/cudnn_wrappers.h" | ||||
| #include "caffe2/core/operator.h" | ||||
| #include "caffe2/core/tensor.h" | ||||
| #include "caffe2/core/types.h" | ||||
|  | ||||
| namespace caffe2 { | ||||
|  | ||||
| class CuDNNActivationOpBase : public Operator<CUDAContext> { | ||||
|  public: | ||||
|   USE_OPERATOR_FUNCTIONS(CUDAContext); | ||||
|  | ||||
|   CuDNNActivationOpBase(const OperatorDef& operator_def, Workspace* ws) | ||||
|       : Operator<CUDAContext>(operator_def, ws), cudnn_wrapper_(&context_) { | ||||
|     CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&data_desc_)); | ||||
|     CUDNN_ENFORCE(cudnnCreateActivationDescriptor(&act_desc_)); | ||||
|   } | ||||
|  | ||||
|   ~CuDNNActivationOpBase() { | ||||
|     CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(data_desc_)); | ||||
|     CUDNN_ENFORCE(cudnnDestroyActivationDescriptor(act_desc_)); | ||||
|   } | ||||
|  | ||||
|  protected: | ||||
|   void SetTensorDescriptor( | ||||
|       const cudnnDataType_t data_type, | ||||
|       const int data_size) { | ||||
|     if (data_size != input_size_) { | ||||
|       // Since the best performance is obtained when the tesor is HW-packed, we | ||||
|       // put X.size() to W. | ||||
|       input_size_ = data_size; | ||||
|       CUDNN_ENFORCE(cudnnSetTensor4dDescriptor( | ||||
|           data_desc_, | ||||
|           GetCudnnTensorFormat(StorageOrder::NCHW), | ||||
|           data_type, | ||||
|           1, | ||||
|           1, | ||||
|           1, | ||||
|           input_size_)); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   CuDNNWrapper cudnn_wrapper_; | ||||
|   cudnnTensorDescriptor_t data_desc_; | ||||
|   cudnnActivationDescriptor_t act_desc_; | ||||
|  | ||||
|   int input_size_ = 0; | ||||
| }; | ||||
|  | ||||
| template <cudnnActivationMode_t kCuDNNActivationMode> | ||||
| class CuDNNActivationOp final : public CuDNNActivationOpBase { | ||||
|  public: | ||||
|   USE_OPERATOR_FUNCTIONS(CUDAContext); | ||||
|  | ||||
|   CuDNNActivationOp(const OperatorDef& operator_def, Workspace* ws) | ||||
|       : CuDNNActivationOpBase(operator_def, ws) { | ||||
|     CUDNN_ENFORCE(cudnnSetActivationDescriptor( | ||||
|         act_desc_, kCuDNNActivationMode, CUDNN_PROPAGATE_NAN, 0.0)); | ||||
|   } | ||||
|  | ||||
|   bool RunOnDevice() override { | ||||
|     return DispatchHelper<TensorTypes<float, float16>>::call(this, Input(0)); | ||||
|   } | ||||
|  | ||||
|   template <typename T> | ||||
|   bool DoRunWithType() { | ||||
|     const auto& X = Input(0); | ||||
|     auto* Y = Output(0); | ||||
|     Y->ResizeLike(X); | ||||
|     if (X.size() == 0) { | ||||
|       Y->template mutable_data<T>(); | ||||
|       return true; | ||||
|     } | ||||
|     this->SetTensorDescriptor(cudnnTypeWrapper<T>::type, X.size()); | ||||
|     CUDNN_ENFORCE(cudnnActivationForward( | ||||
|         this->cudnn_wrapper_.inline_cudnn_handle(), | ||||
|         this->act_desc_, | ||||
|         cudnnTypeWrapper<T>::kOne(), | ||||
|         this->data_desc_, | ||||
|         X.template data<T>(), | ||||
|         cudnnTypeWrapper<T>::kZero(), | ||||
|         this->data_desc_, | ||||
|         Y->template mutable_data<T>())); | ||||
|     return true; | ||||
|   } | ||||
| }; | ||||
|  | ||||
| template <cudnnActivationMode_t kCuDNNActivationMode> | ||||
| class CuDNNActivationGradientOp final : public CuDNNActivationOpBase { | ||||
|  public: | ||||
|   USE_OPERATOR_FUNCTIONS(CUDAContext); | ||||
|  | ||||
|   CuDNNActivationGradientOp(const OperatorDef& operator_def, Workspace* ws) | ||||
|       : CuDNNActivationOpBase(operator_def, ws) { | ||||
|     CUDNN_ENFORCE(cudnnSetActivationDescriptor( | ||||
|         act_desc_, kCuDNNActivationMode, CUDNN_PROPAGATE_NAN, 0.0)); | ||||
|   } | ||||
|  | ||||
|   bool RunOnDevice() override { | ||||
|     return DispatchHelper<TensorTypes<float, float16>>::call(this, Input(0)); | ||||
|   } | ||||
|  | ||||
|   template <typename T> | ||||
|   bool DoRunWithType() { | ||||
|     const auto& Y = Input(0); | ||||
|     const auto& dY = Input(1); | ||||
|     auto* dX = Output(0); | ||||
|     dX->ResizeLike(Y); | ||||
|     if (Y.size() == 0) { | ||||
|       dX->template mutable_data<T>(); | ||||
|       return true; | ||||
|     } | ||||
|     this->SetTensorDescriptor(cudnnTypeWrapper<T>::type, Y.size()); | ||||
|     CUDNN_ENFORCE(cudnnActivationBackward( | ||||
|         this->cudnn_wrapper_.inline_cudnn_handle(), | ||||
|         this->act_desc_, | ||||
|         cudnnTypeWrapper<T>::kOne(), | ||||
|         this->data_desc_, | ||||
|         Y.template data<T>(), | ||||
|         this->data_desc_, | ||||
|         dY.template data<T>(), | ||||
|         this->data_desc_, | ||||
|         Y.template data<T>(), // Use Y_data as placeholder here. | ||||
|         cudnnTypeWrapper<T>::kZero(), | ||||
|         this->data_desc_, | ||||
|         dX->template mutable_data<T>())); | ||||
|     return true; | ||||
|   } | ||||
| }; | ||||
|  | ||||
| } // namespace caffe2 | ||||
|  | ||||
| #endif // CAFFE2_OPERATORS_ACTIVATION_OPS_CUDNN_H_ | ||||
| @ -1,46 +1,53 @@ | ||||
| #include "caffe2/operators/elu_op.h" | ||||
|  | ||||
| #include <algorithm> | ||||
| #include <functional> | ||||
| #include <string> | ||||
|  | ||||
| #include "caffe2/utils/eigen_utils.h" | ||||
| #include "caffe2/utils/math.h" | ||||
|  | ||||
| namespace caffe2 { | ||||
|  | ||||
| template <> | ||||
| bool EluOp<float, CPUContext>::RunOnDevice() { | ||||
|   auto& X = Input(0); | ||||
|   auto* Y = Output(0); | ||||
|   // Otherwise inplace gradient and Elu dosen't make sense. | ||||
|   CAFFE_ENFORCE_GE(alpha_, 0); | ||||
|   Y->ResizeLike(X); | ||||
|   const auto* Xdata = X.template data<float>(); | ||||
|   auto* Ydata = Y->template mutable_data<float>(); | ||||
|   ConstEigenVectorArrayMap<float> Xvec(Xdata, X.size()); | ||||
|   EigenVectorArrayMap<float> Yvec(Ydata, Y->size()); | ||||
|   Yvec = Xvec.cwiseMax(0.f) + (alpha_ * (Xvec.exp() - 1.0f)).cwiseMin(0.f); | ||||
| template <typename T> | ||||
| bool EluFunctor<CPUContext>:: | ||||
| operator()(const int N, const T* X, T* Y, CPUContext* /* context */) const { | ||||
|   ConstEigenVectorArrayMap<T> X_arr(X, N); | ||||
|   EigenVectorMap<T>(Y, N) = | ||||
|       (X_arr < 0).select(alpha * (X_arr.exp() - T(1)), X_arr); | ||||
|   return true; | ||||
| } | ||||
|  | ||||
| template <> | ||||
| bool EluGradientOp<float, CPUContext>::RunOnDevice() { | ||||
|   auto& Y = Input(0); | ||||
|   auto& dY = Input(1); | ||||
|   auto* dX = Output(0); | ||||
|   DCHECK_GT(Y.size(), 0); | ||||
|   DCHECK_EQ(dY.size(), Y.size()); | ||||
|   dX->ResizeLike(Y); | ||||
|  | ||||
|   const float* Ydata = Y.data<float>(); | ||||
|   const float* dYdata = dY.data<float>(); | ||||
|   float* dXdata = dX->mutable_data<float>(); | ||||
|   ConstEigenVectorArrayMap<float> Yvec(Ydata, Y.size()); | ||||
|   ConstEigenVectorArrayMap<float> dYvec(dYdata, dY.size()); | ||||
|   EigenVectorArrayMap<float> dXvec(dXdata, dX->size()); | ||||
|   dXvec = (Yvec > 0).select(dYvec, dYvec * (Yvec + alpha_)); | ||||
| template <typename T> | ||||
| bool EluGradientFunctor<CPUContext>::Forward( | ||||
|     const std::vector<int>& Y_dims, | ||||
|     const std::vector<int>& /* dY_dims */, | ||||
|     const T* Y, | ||||
|     const T* dY, | ||||
|     T* dX, | ||||
|     CPUContext* /* context */) const { | ||||
|   const int size = std::accumulate( | ||||
|       Y_dims.cbegin(), Y_dims.cend(), 1, std::multiplies<int>()); | ||||
|   ConstEigenVectorArrayMap<T> Y_arr(Y, size); | ||||
|   ConstEigenVectorArrayMap<T> dY_arr(dY, size); | ||||
|   EigenVectorArrayMap<T>(dX, size) = | ||||
|       (Y_arr < 0).select(dY_arr * (Y_arr + alpha), dY_arr); | ||||
|   return true; | ||||
| } | ||||
|  | ||||
| REGISTER_CPU_OPERATOR(Elu, EluOp<float, CPUContext>); | ||||
| REGISTER_CPU_OPERATOR(EluGradient, EluGradientOp<float, CPUContext>); | ||||
| REGISTER_CPU_OPERATOR( | ||||
|     Elu, | ||||
|     UnaryElementwiseWithArgsOp< | ||||
|         TensorTypes<float>, | ||||
|         CPUContext, | ||||
|         EluFunctor<CPUContext>>); | ||||
| REGISTER_CPU_OPERATOR( | ||||
|     EluGradient, | ||||
|     BinaryElementwiseWithArgsOp< | ||||
|         TensorTypes<float>, | ||||
|         CPUContext, | ||||
|         EluGradientFunctor<CPUContext>>); | ||||
|  | ||||
| // Input: X, output: Y | ||||
| OPERATOR_SCHEMA(Elu) | ||||
| @ -103,10 +110,11 @@ Y: | ||||
| )DOC") | ||||
|     .Input(0, "X", "1D input tensor of data to be operated on.") | ||||
|     .Output(0, "Y", "1D input tensor, calculated as described above.") | ||||
|     .Arg("alpha", "*(type: float; default: 1.0)* Defines alpha parameter used in calculation.") | ||||
|     .Arg( | ||||
|         "alpha", | ||||
|         "*(type: float; default: 1.0)* Defines alpha parameter used in calculation.") | ||||
|     .InheritOnnxSchema("Elu"); | ||||
|  | ||||
|  | ||||
| // Input: Y, dY, output: dX | ||||
| OPERATOR_SCHEMA(EluGradient) | ||||
|     .NumInputs(2) | ||||
| @ -117,16 +125,21 @@ EluGradient takes both Y and dY and uses this to update dX according to the | ||||
| chain rule and derivatives of the rectified linear function. | ||||
| )DOC"); | ||||
|  | ||||
| namespace { | ||||
|  | ||||
| class GetEluGradient : public GradientMakerBase { | ||||
|   using GradientMakerBase::GradientMakerBase; | ||||
|   vector<OperatorDef> GetGradientDefs() override { | ||||
|   std::vector<OperatorDef> GetGradientDefs() override { | ||||
|     return SingleGradientDef( | ||||
|         def_.type() + "Gradient", | ||||
|         "", | ||||
|         vector<string>{O(0), GO(0)}, | ||||
|         vector<string>{GI(0)}); | ||||
|         std::vector<std::string>{O(0), GO(0)}, | ||||
|         std::vector<std::string>{GI(0)}); | ||||
|   } | ||||
| }; | ||||
|  | ||||
| } // namespace | ||||
|  | ||||
| REGISTER_GRADIENT(Elu, GetEluGradient); | ||||
|  | ||||
| } // namespace caffe2 | ||||
|  | ||||
| @ -1,74 +1,91 @@ | ||||
| #include "caffe2/operators/elu_op.h" | ||||
|  | ||||
| #include <algorithm> | ||||
| #include <functional> | ||||
|  | ||||
| #include "caffe2/core/context_gpu.h" | ||||
|  | ||||
| namespace caffe2 { | ||||
|  | ||||
| namespace { | ||||
|  | ||||
| template <typename T> | ||||
| __global__ void EluCUDAKernel(const int N, const T alpha, const T* X, T* Y); | ||||
|  | ||||
| template <> | ||||
| __global__ void | ||||
| elu_kernel(const int N, const float alpha, const float* x, float* y) { | ||||
| EluCUDAKernel<float>(const int N, const float alpha, const float* X, float* Y) { | ||||
|   CUDA_1D_KERNEL_LOOP(i, N) { | ||||
|     if (x[i] > 0) { | ||||
|       y[i] = x[i]; | ||||
|     } else { | ||||
|       y[i] = alpha * (__expf(x[i]) - 1); | ||||
|     } | ||||
| #if __CUDA_ARCH__ >= 350 | ||||
|     Y[i] = | ||||
|         __ldg(X + i) < 0 ? alpha * (expf(__ldg(X + i)) - 1.0f) : __ldg(X + i); | ||||
| #else | ||||
|     Y[i] = X[i] < 0 ? alpha * (expf(X[i]) - 1.0f) : X[i]; | ||||
| #endif | ||||
|   } | ||||
| } | ||||
|  | ||||
| __global__ void elu_gradient_kernel( | ||||
| template <typename T> | ||||
| __global__ void EluGradientCUDAKernel( | ||||
|     const int N, | ||||
|     const float alpha, | ||||
|     const float* y, | ||||
|     const float* dy, | ||||
|     float* dx) { | ||||
|     const T alpha, | ||||
|     const T* dY, | ||||
|     const T* Y, | ||||
|     T* dX) { | ||||
|   CUDA_1D_KERNEL_LOOP(i, N) { | ||||
|     if (y[i] > 0) { | ||||
|       dx[i] = dy[i]; | ||||
|     } else { | ||||
|       dx[i] = dy[i] * (y[i] + alpha); | ||||
|     } | ||||
| #if __CUDA_ARCH__ >= 350 | ||||
|     dX[i] = __ldg(Y + i) < 0 ? __ldg(dY + i) * (__ldg(Y + i) + alpha) | ||||
|                              : __ldg(dY + i); | ||||
| #else | ||||
|     dX[i] = Y[i] < 0 ? dY[i] * (Y[i] + alpha) : dY[i]; | ||||
| #endif | ||||
|   } | ||||
| } | ||||
|  | ||||
| } // namespace | ||||
|  | ||||
| template <> | ||||
| bool EluOp<float, CUDAContext>::RunOnDevice() { | ||||
|   auto& X = Input(0); | ||||
|   auto* Y = Output(0); | ||||
|   // Otherwise inplace gradient and Elu dosen't make sense. | ||||
|   CAFFE_ENFORCE_GE(alpha_, 0); | ||||
|   Y->ResizeLike(X); | ||||
|   const auto* Xdata = X.data<float>(); | ||||
|   auto* Ydata = Y->mutable_data<float>(); | ||||
|   elu_kernel<<< | ||||
|       CAFFE_GET_BLOCKS(X.size()), | ||||
|       CAFFE_CUDA_NUM_THREADS, | ||||
|       0, | ||||
|       context_.cuda_stream()>>>(X.size(), alpha_, Xdata, Ydata); | ||||
| template <typename T> | ||||
| bool EluFunctor<CUDAContext>:: | ||||
| operator()(const int N, const T* X, T* Y, CUDAContext* context) const { | ||||
|   EluCUDAKernel<T> | ||||
|       <<<CAFFE_GET_BLOCKS(N), | ||||
|          CAFFE_CUDA_NUM_THREADS, | ||||
|          0, | ||||
|          context->cuda_stream()>>>(N, alpha, X, Y); | ||||
|   return true; | ||||
| } | ||||
|  | ||||
| template <> | ||||
| bool EluGradientOp<float, CUDAContext>::RunOnDevice() { | ||||
|   auto& Y = Input(0); | ||||
|   auto& dY = Input(1); | ||||
|   auto* dX = Output(0); | ||||
|   DCHECK_GT(Y.size(), 0); | ||||
|   DCHECK_EQ(dY.size(), Y.size()); | ||||
|   dX->ResizeLike(Y); | ||||
|  | ||||
|   const float* Ydata = Y.data<float>(); | ||||
|   const float* dYdata = dY.data<float>(); | ||||
|   float* dXdata = dX->mutable_data<float>(); | ||||
|   elu_gradient_kernel<<< | ||||
|       CAFFE_GET_BLOCKS(Y.size()), | ||||
|       CAFFE_CUDA_NUM_THREADS, | ||||
|       0, | ||||
|       context_.cuda_stream()>>>(Y.size(), alpha_, Ydata, dYdata, dXdata); | ||||
| template <typename T> | ||||
| bool EluGradientFunctor<CUDAContext>::Forward( | ||||
|     const std::vector<int>& Y_dims, | ||||
|     const std::vector<int>& /* dY_dims */, | ||||
|     const T* Y, | ||||
|     const T* dY, | ||||
|     T* dX, | ||||
|     CUDAContext* context) const { | ||||
|   const int size = std::accumulate( | ||||
|       Y_dims.cbegin(), Y_dims.cend(), 1, std::multiplies<int>()); | ||||
|   EluGradientCUDAKernel<T> | ||||
|       <<<CAFFE_GET_BLOCKS(size), | ||||
|          CAFFE_CUDA_NUM_THREADS, | ||||
|          0, | ||||
|          context->cuda_stream()>>>(size, alpha, dY, Y, dX); | ||||
|   return true; | ||||
| } | ||||
|  | ||||
| REGISTER_CUDA_OPERATOR(Elu, EluOp<float, CUDAContext>); | ||||
| REGISTER_CUDA_OPERATOR(EluGradient, EluGradientOp<float, CUDAContext>); | ||||
| } | ||||
| REGISTER_CUDA_OPERATOR( | ||||
|     Elu, | ||||
|     UnaryElementwiseWithArgsOp< | ||||
|         TensorTypes<float>, | ||||
|         CUDAContext, | ||||
|         EluFunctor<CUDAContext>>); | ||||
| REGISTER_CUDA_OPERATOR( | ||||
|     EluGradient, | ||||
|     BinaryElementwiseWithArgsOp< | ||||
|         TensorTypes<float>, | ||||
|         CUDAContext, | ||||
|         EluGradientFunctor<CUDAContext>>); | ||||
|  | ||||
| } // namespace caffe2 | ||||
|  | ||||
| @ -1,37 +1,40 @@ | ||||
| #pragma once | ||||
| #ifndef CAFFE2_OPERATORS_ELU_OP_H_ | ||||
| #define CAFFE2_OPERATORS_ELU_OP_H_ | ||||
|  | ||||
| #include "caffe2/core/context.h" | ||||
| #include "caffe2/core/logging.h" | ||||
| #include "caffe2/core/operator.h" | ||||
| #include <vector> | ||||
|  | ||||
| #include "caffe2/operators/elementwise_ops.h" | ||||
|  | ||||
| namespace caffe2 { | ||||
|  | ||||
| template <typename T, class Context> | ||||
| class EluOp final : public Operator<Context> { | ||||
|  public: | ||||
|   EluOp(const OperatorDef& operator_def, Workspace* ws) | ||||
|       : Operator<Context>(operator_def, ws), | ||||
|         alpha_(OperatorBase::GetSingleArgument<float>("alpha", 1.0)) {} | ||||
|   USE_OPERATOR_CONTEXT_FUNCTIONS; | ||||
| template <class Context> | ||||
| struct EluFunctor { | ||||
|   explicit EluFunctor(OperatorBase& op) | ||||
|       : alpha(op.GetSingleArgument<float>("alpha", 1.0f)) {} | ||||
|  | ||||
|   bool RunOnDevice() override; | ||||
|   template <typename T> | ||||
|   bool operator()(const int N, const T* X, T* Y, Context* context) const; | ||||
|  | ||||
|  protected: | ||||
|   T alpha_; | ||||
|   const float alpha; | ||||
| }; | ||||
|  | ||||
| template <typename T, class Context> | ||||
| class EluGradientOp final : public Operator<Context> { | ||||
|  public: | ||||
|   EluGradientOp(const OperatorDef& operator_def, Workspace* ws) | ||||
|       : Operator<Context>(operator_def, ws), | ||||
|         alpha_(OperatorBase::GetSingleArgument<float>("alpha", 1.0)) {} | ||||
|   USE_OPERATOR_CONTEXT_FUNCTIONS; | ||||
| template <class Context> | ||||
| struct EluGradientFunctor { | ||||
|   explicit EluGradientFunctor(OperatorBase& op) | ||||
|       : alpha(op.GetSingleArgument<float>("alpha", 1.0f)) {} | ||||
|  | ||||
|   bool RunOnDevice() override; | ||||
|   template <typename T> | ||||
|   bool Forward( | ||||
|       const std::vector<int>& Y_dims, | ||||
|       const std::vector<int>& dY_dims, | ||||
|       const T* Y, | ||||
|       const T* dY, | ||||
|       T* dX, | ||||
|       Context* context) const; | ||||
|  | ||||
|  protected: | ||||
|   T alpha_; | ||||
|   const float alpha; | ||||
| }; | ||||
|  | ||||
| } // namespace caffe2 | ||||
|  | ||||
| #endif // CAFFE2_OPERATORS_ELU_OP_H_ | ||||
|  | ||||
							
								
								
									
										109
									
								
								caffe2/operators/elu_op_cudnn.cc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										109
									
								
								caffe2/operators/elu_op_cudnn.cc
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,109 @@ | ||||
| #include "caffe2/operators/elu_op.h" | ||||
|  | ||||
| #include "caffe2/operators/activation_ops_cudnn.h" | ||||
|  | ||||
| namespace caffe2 { | ||||
|  | ||||
| template <> | ||||
| class CuDNNActivationOp<CUDNN_ACTIVATION_ELU> final | ||||
|     : public CuDNNActivationOpBase { | ||||
|  public: | ||||
|   USE_OPERATOR_FUNCTIONS(CUDAContext); | ||||
|  | ||||
|   CuDNNActivationOp(const OperatorDef& operator_def, Workspace* ws) | ||||
|       : CuDNNActivationOpBase(operator_def, ws), | ||||
|         OP_SINGLE_ARG(float, "alpha", alpha_, 1.0f) { | ||||
|     CUDNN_ENFORCE(cudnnSetActivationDescriptor( | ||||
|         act_desc_, | ||||
|         CUDNN_ACTIVATION_ELU, | ||||
|         CUDNN_PROPAGATE_NAN, | ||||
|         static_cast<double>(alpha_))); | ||||
|   } | ||||
|  | ||||
|   bool RunOnDevice() override { | ||||
|     return DispatchHelper<TensorTypes<float, float16>>::call(this, Input(0)); | ||||
|   } | ||||
|  | ||||
|   template <typename T> | ||||
|   bool DoRunWithType() { | ||||
|     const auto& X = Input(0); | ||||
|     auto* Y = Output(0); | ||||
|     Y->ResizeLike(X); | ||||
|     if (X.size() == 0) { | ||||
|       Y->template mutable_data<T>(); | ||||
|       return true; | ||||
|     } | ||||
|     this->SetTensorDescriptor(cudnnTypeWrapper<T>::type, X.size()); | ||||
|     CUDNN_ENFORCE(cudnnActivationForward( | ||||
|         this->cudnn_wrapper_.inline_cudnn_handle(), | ||||
|         this->act_desc_, | ||||
|         cudnnTypeWrapper<T>::kOne(), | ||||
|         this->data_desc_, | ||||
|         X.template data<T>(), | ||||
|         cudnnTypeWrapper<T>::kZero(), | ||||
|         this->data_desc_, | ||||
|         Y->template mutable_data<T>())); | ||||
|     return true; | ||||
|   } | ||||
|  | ||||
|  private: | ||||
|   const float alpha_; | ||||
| }; | ||||
|  | ||||
| template <> | ||||
| class CuDNNActivationGradientOp<CUDNN_ACTIVATION_ELU> final | ||||
|     : public CuDNNActivationOpBase { | ||||
|  public: | ||||
|   USE_OPERATOR_FUNCTIONS(CUDAContext); | ||||
|  | ||||
|   CuDNNActivationGradientOp(const OperatorDef& operator_def, Workspace* ws) | ||||
|       : CuDNNActivationOpBase(operator_def, ws), | ||||
|         OP_SINGLE_ARG(float, "alpha", alpha_, 1.0f) { | ||||
|     CUDNN_ENFORCE(cudnnSetActivationDescriptor( | ||||
|         act_desc_, | ||||
|         CUDNN_ACTIVATION_ELU, | ||||
|         CUDNN_PROPAGATE_NAN, | ||||
|         static_cast<double>(alpha_))); | ||||
|   } | ||||
|  | ||||
|   bool RunOnDevice() override { | ||||
|     return DispatchHelper<TensorTypes<float, float16>>::call(this, Input(0)); | ||||
|   } | ||||
|  | ||||
|   template <typename T> | ||||
|   bool DoRunWithType() { | ||||
|     const auto& Y = Input(0); | ||||
|     const auto& dY = Input(1); | ||||
|     auto* dX = Output(0); | ||||
|     dX->ResizeLike(Y); | ||||
|     if (Y.size() == 0) { | ||||
|       dX->template mutable_data<T>(); | ||||
|       return true; | ||||
|     } | ||||
|     this->SetTensorDescriptor(cudnnTypeWrapper<T>::type, Y.size()); | ||||
|     CUDNN_ENFORCE(cudnnActivationBackward( | ||||
|         this->cudnn_wrapper_.inline_cudnn_handle(), | ||||
|         this->act_desc_, | ||||
|         cudnnTypeWrapper<T>::kOne(), | ||||
|         this->data_desc_, | ||||
|         Y.template data<T>(), | ||||
|         this->data_desc_, | ||||
|         dY.template data<T>(), | ||||
|         this->data_desc_, | ||||
|         Y.template data<T>(), // Use Y_data as placeholder here. | ||||
|         cudnnTypeWrapper<T>::kZero(), | ||||
|         this->data_desc_, | ||||
|         dX->template mutable_data<T>())); | ||||
|     return true; | ||||
|   } | ||||
|  | ||||
|  private: | ||||
|   const float alpha_; | ||||
| }; | ||||
|  | ||||
| REGISTER_CUDNN_OPERATOR(Elu, CuDNNActivationOp<CUDNN_ACTIVATION_ELU>); | ||||
| REGISTER_CUDNN_OPERATOR( | ||||
|     EluGradient, | ||||
|     CuDNNActivationGradientOp<CUDNN_ACTIVATION_ELU>); | ||||
|  | ||||
| } // namespace caffe2 | ||||
| @ -1,69 +1,42 @@ | ||||
| /** | ||||
|  * Copyright (c) 2016-present, Facebook, Inc. | ||||
|  * | ||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  * you may not use this file except in compliance with the License. | ||||
|  * You may obtain a copy of the License at | ||||
|  * | ||||
|  *     http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  * See the License for the specific language governing permissions and | ||||
|  * limitations under the License. | ||||
|  */ | ||||
|  | ||||
| #include "caffe2/operators/relu_n_op.h" | ||||
|  | ||||
| #include <algorithm> | ||||
| #include <functional> | ||||
| #include <string> | ||||
|  | ||||
| #include "caffe2/utils/eigen_utils.h" | ||||
| #include "caffe2/utils/math.h" | ||||
|  | ||||
| namespace caffe2 { | ||||
|  | ||||
| template <> | ||||
| bool ReluNOp<float, CPUContext>::RunOnDevice() { | ||||
|   auto& X = Input(0); | ||||
|   auto* Y = Output(0); | ||||
|   Y->ResizeLike(X); | ||||
|  | ||||
|   EigenVectorMap<float>(Y->mutable_data<float>(), X.size()) = | ||||
|       ConstEigenVectorMap<float>(X.data<float>(), X.size()) | ||||
|           .cwiseMax(0.f) | ||||
|           .cwiseMin(n); | ||||
| template <typename T> | ||||
| bool ReluNFunctor<CPUContext>:: | ||||
| operator()(const int N, const T* X, T* Y, CPUContext* /* context */) const { | ||||
|   EigenVectorMap<T>(Y, N) = | ||||
|       ConstEigenVectorMap<T>(X, N).cwiseMax(T(0)).cwiseMin(T(n)); | ||||
|   return true; | ||||
| } | ||||
|  | ||||
// Custom unary functor that maps a ReluN output value to its gradient mask:
// 1 when the value lies strictly between 0 and the upper bound, 0 otherwise.
//
// The lower bound is exclusive on purpose: an output of exactly 0 means the
// unit was clamped, so no gradient may pass through it. (Testing `x < 0`
// instead would let the gradient through for every clamped element.)
template <typename Scalar>
struct CwiseClampSignOp {
  // sup: the upper clamp value (the "n" of ReluN).
  CwiseClampSignOp(const Scalar& sup) : m_sup(sup) {}
  // Returns 1 if 0 < x < m_sup, else 0.
  const Scalar operator()(const Scalar& x) const {
    return (x > 0 && x < m_sup) ? 1 : 0;
  }
  Scalar m_sup;
};
|  | ||||
| template <> | ||||
| bool ReluNGradientOp<float, CPUContext>::RunOnDevice() { | ||||
|   auto& Y = Input(0); | ||||
|   auto& dY = Input(1); | ||||
|   auto* dX = Output(0); | ||||
|   CAFFE_ENFORCE_EQ(dY.size(), Y.size()); | ||||
|   dX->ResizeLike(Y); | ||||
|  | ||||
|   const float* Ydata = Y.data<float>(); | ||||
|   const float* dYdata = dY.data<float>(); | ||||
|   float* dXdata = dX->mutable_data<float>(); | ||||
|   // TODO: proper vectorization with Eigen | ||||
|   EigenVectorArrayMap<float> dXvec(dXdata, dX->size()); | ||||
|   ConstEigenVectorArrayMap<float> Yvec(Ydata, Y.size()); | ||||
|   ConstEigenVectorArrayMap<float> dYvec(dYdata, dY.size()); | ||||
|   dXvec = dYvec * Yvec.unaryExpr(CwiseClampSignOp<float>(n)); | ||||
| template <typename T> | ||||
| bool ReluNGradientFunctor<CPUContext>::Forward( | ||||
|     const std::vector<int>& Y_dims, | ||||
|     const std::vector<int>& /* dY_dims */, | ||||
|     const T* Y, | ||||
|     const T* dY, | ||||
|     T* dX, | ||||
|     CPUContext* /* context */) const { | ||||
|   const int size = std::accumulate( | ||||
|       Y_dims.cbegin(), Y_dims.cend(), 1, std::multiplies<int>()); | ||||
|   ConstEigenVectorArrayMap<T> Y_arr(Y, size); | ||||
|   EigenVectorArrayMap<T>(dX, size) = | ||||
|       (Y_arr > T(0) && Y_arr < T(n)) | ||||
|           .select(ConstEigenVectorArrayMap<T>(dY, size), T(0)); | ||||
|   return true; | ||||
| } | ||||
|  | ||||
| namespace { | ||||
|  | ||||
| OpSchema::Cost CostInferenceForReluN( | ||||
|     const OperatorDef& def, | ||||
|     const vector<TensorShape>& in) { | ||||
| @ -71,10 +44,21 @@ OpSchema::Cost CostInferenceForReluN( | ||||
|   cost.params_bytes = 0; | ||||
|   return cost; | ||||
| } | ||||
|  | ||||
| } // namespace | ||||
|  | ||||
| REGISTER_CPU_OPERATOR(ReluN, ReluNOp<float, CPUContext>); | ||||
| REGISTER_CPU_OPERATOR(ReluNGradient, ReluNGradientOp<float, CPUContext>); | ||||
| REGISTER_CPU_OPERATOR( | ||||
|     ReluN, | ||||
|     UnaryElementwiseWithArgsOp< | ||||
|         TensorTypes<float>, | ||||
|         CPUContext, | ||||
|         ReluNFunctor<CPUContext>>); | ||||
| REGISTER_CPU_OPERATOR( | ||||
|     ReluNGradient, | ||||
|     BinaryElementwiseWithArgsOp< | ||||
|         TensorTypes<float>, | ||||
|         CPUContext, | ||||
|         ReluNGradientFunctor<CPUContext>>); | ||||
|  | ||||
| // Input: X, output: Y | ||||
| OPERATOR_SCHEMA(ReluN) | ||||
| @ -103,16 +87,21 @@ ReluGradient takes both Y and dY and uses this to update dX according to the | ||||
| chain rule and derivatives of the rectified linear function. | ||||
| )DOC"); | ||||
|  | ||||
| namespace { | ||||
|  | ||||
| class GetReluNGradient : public GradientMakerBase { | ||||
|   using GradientMakerBase::GradientMakerBase; | ||||
|   vector<OperatorDef> GetGradientDefs() override { | ||||
|   std::vector<OperatorDef> GetGradientDefs() override { | ||||
|     return SingleGradientDef( | ||||
|         def_.type() + "Gradient", | ||||
|         "", | ||||
|         vector<string>{O(0), GO(0)}, | ||||
|         vector<string>{GI(0)}); | ||||
|         std::vector<std::string>{O(0), GO(0)}, | ||||
|         std::vector<std::string>{GI(0)}); | ||||
|   } | ||||
| }; | ||||
|  | ||||
| } // namespace | ||||
|  | ||||
| REGISTER_GRADIENT(ReluN, GetReluNGradient); | ||||
|  | ||||
| } // namespace caffe2 | ||||
|  | ||||
| @ -1,82 +1,88 @@ | ||||
| /** | ||||
|  * Copyright (c) 2016-present, Facebook, Inc. | ||||
|  * | ||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  * you may not use this file except in compliance with the License. | ||||
|  * You may obtain a copy of the License at | ||||
|  * | ||||
|  *     http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  * See the License for the specific language governing permissions and | ||||
|  * limitations under the License. | ||||
|  */ | ||||
|  | ||||
| #include "caffe2/core/context_gpu.h" | ||||
| #include "caffe2/operators/relu_n_op.h" | ||||
|  | ||||
| #include <algorithm> | ||||
| #include <functional> | ||||
|  | ||||
| #include "caffe2/core/context_gpu.h" | ||||
|  | ||||
| namespace caffe2 { | ||||
|  | ||||
| namespace { | ||||
|  | ||||
| template <typename T> | ||||
| __global__ void ReluNKernel(const int N, const T* X, T* Y, const T thres) { | ||||
| __global__ void | ||||
| ReluNCUDAKernel(const int N, const T threshold, const T* X, T* Y) { | ||||
|   CUDA_1D_KERNEL_LOOP(i, N) { | ||||
|     auto data = X[i]; | ||||
|     Y[i] = data > 0 ? (data > thres ? thres : data) : 0; | ||||
| #if __CUDA_ARCH__ >= 350 | ||||
|     Y[i] = __ldg(X + i) > 0 | ||||
|         ? (__ldg(X + i) < threshold ? __ldg(X + i) : threshold) | ||||
|         : T(0); | ||||
| #else | ||||
|     Y[i] = X[i] > 0 ? (X[i] < threshold ? X[i] : threshold) : T(0); | ||||
| #endif | ||||
|   } | ||||
| } | ||||
|  | ||||
| template <typename T> | ||||
| __global__ void ReluNGradientKernel( | ||||
| __global__ void ReluNGradientCUDAKernel( | ||||
|     const int N, | ||||
|     const T* Y, | ||||
|     const T threshold, | ||||
|     const T* dY, | ||||
|     T* dX, | ||||
|     const T thres) { | ||||
|     const T* Y, | ||||
|     T* dX) { | ||||
|   CUDA_1D_KERNEL_LOOP(i, N) { | ||||
|     auto data = Y[i]; | ||||
|     dX[i] = data > 0 ? (data >= thres ? 0 : dY[i]) : 0; | ||||
| #if __CUDA_ARCH__ >= 350 | ||||
|     dX[i] = (__ldg(Y + i) > 0 && __ldg(Y + i) < threshold) ? dY[i] : T(0); | ||||
| #else | ||||
|     dX[i] = (Y[i] > 0 && Y[i] < threshold) ? dY[i] : T(0); | ||||
| #endif | ||||
|   } | ||||
| } | ||||
|  | ||||
| } // namespace | ||||
|  | ||||
| template <> | ||||
| bool ReluNOp<float, CUDAContext>::RunOnDevice() { | ||||
|   auto& X = Input(0); | ||||
|   auto* Y = Output(0); | ||||
|   CAFFE_ENFORCE_GT(X.size(), 0); | ||||
|   Y->ResizeLike(X); | ||||
|   ReluNKernel<<< | ||||
|       CAFFE_GET_BLOCKS(X.size()), | ||||
|       CAFFE_CUDA_NUM_THREADS, | ||||
|       0, | ||||
|       context_.cuda_stream()>>>( | ||||
|       X.size(), X.data<float>(), Y->mutable_data<float>(), n); | ||||
| template <typename T> | ||||
| bool ReluNFunctor<CUDAContext>:: | ||||
| operator()(const int N, const T* X, T* Y, CUDAContext* context) const { | ||||
|   ReluNCUDAKernel<T> | ||||
|       <<<CAFFE_GET_BLOCKS(N), | ||||
|          CAFFE_CUDA_NUM_THREADS, | ||||
|          0, | ||||
|          context->cuda_stream()>>>(N, n, X, Y); | ||||
|   return true; | ||||
| } | ||||
|  | ||||
| template <> | ||||
| bool ReluNGradientOp<float, CUDAContext>::RunOnDevice() { | ||||
|   auto& Y = Input(0); | ||||
|   auto& dY = Input(1); | ||||
|   auto* dX = Output(0); | ||||
|   CAFFE_ENFORCE_GT(Y.size(), 0); | ||||
|   CAFFE_ENFORCE_EQ(dY.size(), Y.size()); | ||||
|   dX->ResizeLike(Y); | ||||
|   ReluNGradientKernel<float> | ||||
|       <<<CAFFE_GET_BLOCKS(Y.size()), | ||||
| template <typename T> | ||||
| bool ReluNGradientFunctor<CUDAContext>::Forward( | ||||
|     const std::vector<int>& Y_dims, | ||||
|     const std::vector<int>& /* dY_dims */, | ||||
|     const T* Y, | ||||
|     const T* dY, | ||||
|     T* dX, | ||||
|     CUDAContext* context) const { | ||||
|   const int size = std::accumulate( | ||||
|       Y_dims.cbegin(), Y_dims.cend(), 1, std::multiplies<int>()); | ||||
|   ReluNGradientCUDAKernel<T> | ||||
|       <<<CAFFE_GET_BLOCKS(size), | ||||
|          CAFFE_CUDA_NUM_THREADS, | ||||
|          0, | ||||
|          context_.cuda_stream()>>>( | ||||
|           Y.size(), | ||||
|           Y.data<float>(), | ||||
|           dY.data<float>(), | ||||
|           dX->mutable_data<float>(), | ||||
|           n); | ||||
|          context->cuda_stream()>>>(size, n, dY, Y, dX); | ||||
|   return true; | ||||
| } | ||||
|  | ||||
| REGISTER_CUDA_OPERATOR(ReluN, ReluNOp<float, CUDAContext>); | ||||
| REGISTER_CUDA_OPERATOR(ReluNGradient, ReluNGradientOp<float, CUDAContext>); | ||||
| REGISTER_CUDA_OPERATOR( | ||||
|     ReluN, | ||||
|     UnaryElementwiseWithArgsOp< | ||||
|         TensorTypes<float>, | ||||
|         CUDAContext, | ||||
|         ReluNFunctor<CUDAContext>>); | ||||
| REGISTER_CUDA_OPERATOR( | ||||
|     ReluNGradient, | ||||
|     BinaryElementwiseWithArgsOp< | ||||
|         TensorTypes<float>, | ||||
|         CUDAContext, | ||||
|         ReluNGradientFunctor<CUDAContext>>); | ||||
|  | ||||
| } // namespace caffe2 | ||||
|  | ||||
| @ -1,60 +1,42 @@ | ||||
| /** | ||||
|  * Copyright (c) 2016-present, Facebook, Inc. | ||||
|  * | ||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  * you may not use this file except in compliance with the License. | ||||
|  * You may obtain a copy of the License at | ||||
|  * | ||||
|  *     http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  * See the License for the specific language governing permissions and | ||||
|  * limitations under the License. | ||||
|  */ | ||||
|  | ||||
| #ifndef CAFFE2_OPERATORS_RELU_N_OP_H_ | ||||
| #define CAFFE2_OPERATORS_RELU_N_OP_H_ | ||||
|  | ||||
| #include "caffe2/core/common_omp.h" | ||||
| #include "caffe2/core/context.h" | ||||
| #include "caffe2/core/logging.h" | ||||
| #include "caffe2/core/operator.h" | ||||
| #include <vector> | ||||
|  | ||||
| #include "caffe2/operators/elementwise_ops.h" | ||||
|  | ||||
| namespace caffe2 { | ||||
|  | ||||
| template <typename T, class Context> | ||||
| class ReluNOp final : public Operator<Context> { | ||||
|  public: | ||||
|   ReluNOp(const OperatorDef& operator_def, Workspace* ws) | ||||
|       : Operator<Context>(operator_def, ws), | ||||
|         n(OperatorBase::GetSingleArgument<float>("n", 6.0)) { | ||||
| template <class Context> | ||||
| struct ReluNFunctor { | ||||
|   explicit ReluNFunctor(OperatorBase& op) | ||||
|       : n(op.GetSingleArgument<float>("n", 6.0f)) { | ||||
|     CAFFE_ENFORCE_GT(n, 0, "n should be greater than 0"); | ||||
|   } | ||||
|   USE_OPERATOR_CONTEXT_FUNCTIONS; | ||||
|  | ||||
|   bool RunOnDevice() override; | ||||
|   template <typename T> | ||||
|   bool operator()(const int N, const T* X, T* Y, Context* context) const; | ||||
|  | ||||
|  protected: | ||||
|   float n; | ||||
|   const float n; | ||||
| }; | ||||
|  | ||||
| template <typename T, class Context> | ||||
| class ReluNGradientOp final : public Operator<Context> { | ||||
|  public: | ||||
|   ReluNGradientOp(const OperatorDef& operator_def, Workspace* ws) | ||||
|       : Operator<Context>(operator_def, ws), | ||||
|         n(OperatorBase::GetSingleArgument<float>("n", 6.0)) { | ||||
| template <class Context> | ||||
| struct ReluNGradientFunctor { | ||||
|   explicit ReluNGradientFunctor(OperatorBase& op) | ||||
|       : n(op.GetSingleArgument<float>("n", 6.0f)) { | ||||
|     CAFFE_ENFORCE_GT(n, 0, "n should be greater than 0"); | ||||
|   } | ||||
|   USE_OPERATOR_CONTEXT_FUNCTIONS; | ||||
|  | ||||
|   bool RunOnDevice() override; | ||||
|   template <typename T> | ||||
|   bool Forward( | ||||
|       const std::vector<int>& Y_dims, | ||||
|       const std::vector<int>& dY_dims, | ||||
|       const T* Y, | ||||
|       const T* dY, | ||||
|       T* dX, | ||||
|       Context* context) const; | ||||
|  | ||||
|  protected: | ||||
|   // Input: Y, dY; Output: dX | ||||
|   float n; | ||||
|   const float n; | ||||
| }; | ||||
|  | ||||
| } // namespace caffe2 | ||||
|  | ||||
| @ -1,58 +1,56 @@ | ||||
| #include "caffe2/operators/relu_op.h" | ||||
|  | ||||
| #include <algorithm> | ||||
| #include <functional> | ||||
| #include <string> | ||||
|  | ||||
| #include "caffe2/utils/eigen_utils.h" | ||||
| #include "caffe2/utils/math.h" | ||||
|  | ||||
| namespace caffe2 { | ||||
|  | ||||
| template <> | ||||
| bool ReluOp<float, CPUContext>::RunOnDevice() { | ||||
|   auto& X = Input(0); | ||||
|   auto* Y = Output(0); | ||||
|   Y->ResizeLike(X); | ||||
|  | ||||
| #ifdef CAFFE2_USE_ACCELERATE | ||||
|   const float zero = 0.0f; | ||||
|   vDSP_vthres(X.data<float>(), 1, &zero, Y->mutable_data<float>(), 1, X.size()); | ||||
| #else | ||||
|   EigenVectorMap<float>(Y->mutable_data<float>(), X.size()) = | ||||
|       ConstEigenVectorMap<float>(X.data<float>(), X.size()).cwiseMax(0.f); | ||||
| #endif | ||||
|   /* Naive implementation | ||||
|   const float* Xdata = X.data<float>(); | ||||
|   float* Ydata = Y->mutable_data<float>(); | ||||
|   for (int i = 0; i < X.size(); ++i) { | ||||
|     Ydata[i] = std::max(Xdata[i], 0.f); | ||||
|   } | ||||
|   */ | ||||
| template <typename T> | ||||
| bool ReluFunctor<CPUContext>:: | ||||
| operator()(const int N, const T* X, T* Y, CPUContext* /* context */) const { | ||||
|   // Element-wise ReLU on CPU: Y[i] = max(X[i], 0) for i in [0, N). | ||||
|   // The input is mapped with element type T (not a hard-coded float) so the | ||||
|   // template is well-formed for every T it is instantiated with. | ||||
|   EigenVectorMap<T>(Y, N) = ConstEigenVectorMap<T>(X, N).cwiseMax(T(0)); | ||||
|   return true; | ||||
| } | ||||
|  | ||||
| template <> | ||||
| bool ReluGradientOp<float, CPUContext>::RunOnDevice() { | ||||
|   auto& Y = Input(0); | ||||
|   auto& dY = Input(1); | ||||
|   auto* dX = Output(0); | ||||
|   CAFFE_ENFORCE_EQ(dY.size(), Y.size()); | ||||
|   dX->ResizeLike(Y); | ||||
| #ifdef CAFFE2_USE_ACCELERATE | ||||
|  | ||||
|   const float* Ydata = Y.data<float>(); | ||||
|   const float* dYdata = dY.data<float>(); | ||||
|   float* dXdata = dX->mutable_data<float>(); | ||||
|   // TODO: proper vectorization with Eigen | ||||
|   EigenVectorArrayMap<float> dXvec(dXdata, dX->size()); | ||||
|   ConstEigenVectorArrayMap<float> Yvec(Ydata, Y.size()); | ||||
|   ConstEigenVectorArrayMap<float> dYvec(dYdata, dY.size()); | ||||
|   dXvec = dYvec * Yvec.cwiseSign(); | ||||
|   /* Previous implementation | ||||
|   for (int i = 0; i < Y.size(); ++i) { | ||||
|     dXdata[i] = Ydata[i] > 0 ? dYdata[i] : 0; | ||||
|   } | ||||
|   */ | ||||
| template <> | ||||
| template <> | ||||
| bool ReluFunctor<CPUContext>::operator()<float>( | ||||
|     const int N, | ||||
|     const float* X, | ||||
|     float* Y, | ||||
|     CPUContext* /* context */) const { | ||||
|   // float specialization that hands the whole range to vDSP_vthres with a | ||||
|   // threshold of 0 and unit strides on both input and output. | ||||
|   // NOTE(review): presumably vDSP_vthres zero-fills values below the | ||||
|   // threshold, which gives ReLU - confirm against the Accelerate docs. | ||||
|   const float threshold = 0.0f; | ||||
|   vDSP_vthres(X, 1, &threshold, Y, 1, N); | ||||
|   return true; | ||||
| } | ||||
|  | ||||
| #endif // CAFFE2_USE_ACCELERATE | ||||
|  | ||||
| template <> | ||||
| template <typename T> | ||||
| bool ReluGradientFunctor<CPUContext>::Forward( | ||||
|     const std::vector<int>& Y_dims, | ||||
|     const std::vector<int>& /* dY_dims */, | ||||
|     const T* Y, | ||||
|     const T* dY, | ||||
|     T* dX, | ||||
|     CPUContext* /* context */) const { | ||||
|   // The element count is the product of all entries of Y_dims (1 if empty). | ||||
|   const int size = std::accumulate( | ||||
|       Y_dims.cbegin(), Y_dims.cend(), 1, std::multiplies<int>()); | ||||
|   ConstEigenVectorArrayMap<T> Y_arr(Y, size); | ||||
|   ConstEigenVectorArrayMap<T> dY_arr(dY, size); | ||||
|   // dX[i] = dY[i] where Y[i] > 0, and 0 everywhere else. | ||||
|   EigenVectorArrayMap<T>(dX, size) = (Y_arr > T(0)).select(dY_arr, T(0)); | ||||
|   return true; | ||||
| } | ||||
|  | ||||
| namespace { | ||||
|  | ||||
| OpSchema::Cost CostInferenceForRelu( | ||||
|     const OperatorDef& def, | ||||
|     const vector<TensorShape>& in) { | ||||
| @ -60,10 +58,21 @@ OpSchema::Cost CostInferenceForRelu( | ||||
|   cost.params_bytes = 0; | ||||
|   return cost; | ||||
| } | ||||
|  | ||||
| } // namespace | ||||
|  | ||||
| REGISTER_CPU_OPERATOR(Relu, ReluOp<float, CPUContext>); | ||||
| REGISTER_CPU_OPERATOR(ReluGradient, ReluGradientOp<float, CPUContext>); | ||||
| REGISTER_CPU_OPERATOR( | ||||
|     Relu, | ||||
|     UnaryElementwiseOp< | ||||
|         TensorTypes<float>, | ||||
|         CPUContext, | ||||
|         ReluFunctor<CPUContext>>); | ||||
| REGISTER_CPU_OPERATOR( | ||||
|     ReluGradient, | ||||
|     BinaryElementwiseOp< | ||||
|         TensorTypes<float>, | ||||
|         CPUContext, | ||||
|         ReluGradientFunctor<CPUContext>>); | ||||
|  | ||||
| // Input: X, output: Y | ||||
| OPERATOR_SCHEMA(Relu) | ||||
| @ -140,17 +149,21 @@ ReluGradient takes both Y and dY and uses this to update dX according to the | ||||
| chain rule and derivatives of the rectified linear function. | ||||
| )DOC"); | ||||
|  | ||||
| namespace { | ||||
|  | ||||
| class GetReluGradient : public GradientMakerBase { | ||||
|   using GradientMakerBase::GradientMakerBase; | ||||
|   vector<OperatorDef> GetGradientDefs() override { | ||||
|   std::vector<OperatorDef> GetGradientDefs() override { | ||||
|     return SingleGradientDef( | ||||
|         def_.type() + "Gradient", | ||||
|         "", | ||||
|         vector<string>{O(0), GO(0)}, | ||||
|         vector<string>{GI(0)}); | ||||
|         std::vector<std::string>{O(0), GO(0)}, | ||||
|         std::vector<std::string>{GI(0)}); | ||||
|   } | ||||
| }; | ||||
| REGISTER_GRADIENT(Relu, GetReluGradient); | ||||
| REGISTER_GRADIENT(ReluFp16, GetReluGradient); | ||||
|  | ||||
| }  // namespace caffe2 | ||||
| } // namespace | ||||
|  | ||||
| REGISTER_GRADIENT(Relu, GetReluGradient); | ||||
|  | ||||
| } // namespace caffe2 | ||||
|  | ||||
| @ -1,57 +1,198 @@ | ||||
| #include "caffe2/core/context_gpu.h" | ||||
| #include "caffe2/operators/relu_op.h" | ||||
|  | ||||
| #include <algorithm> | ||||
| #include <functional> | ||||
|  | ||||
| #include "caffe2/core/context_gpu.h" | ||||
|  | ||||
| namespace caffe2 { | ||||
|  | ||||
| namespace { | ||||
|  | ||||
| template <typename T> | ||||
| __global__ void ReluKernel(const int N, const T* X, T* Y) { | ||||
| __global__ void ReluCUDAKernel(const int N, const T* X, T* Y) { | ||||
|   CUDA_1D_KERNEL_LOOP(i, N) { | ||||
|     Y[i] = X[i] > 0 ? X[i] : 0; | ||||
| #if __CUDA_ARCH__ >= 350 | ||||
|     Y[i] = __ldg(X + i) > 0 ? __ldg(X + i) : T(0); | ||||
| #else | ||||
|     Y[i] = X[i] > 0 ? X[i] : T(0); | ||||
| #endif | ||||
|   } | ||||
| } | ||||
|  | ||||
| // Scalar half-precision ReLU: Y[i] = X[i] if X[i] > 0, otherwise 0. | ||||
| __global__ void ReluHalfCUDAKernel(const int N, const half* X, half* Y) { | ||||
|   const half kZero = __float2half(0.0f); | ||||
|   CUDA_1D_KERNEL_LOOP(i, N) { | ||||
| #if __CUDA_ARCH__ >= 530 | ||||
|     // Compare directly in half precision. | ||||
|     const half x = __ldg(X + i); | ||||
|     const bool is_positive = __hgt(x, kZero); | ||||
| #else | ||||
|     // Otherwise do the comparison in float. | ||||
|     const half x = X[i]; | ||||
|     const bool is_positive = __half2float(x) > 0; | ||||
| #endif | ||||
|     Y[i] = is_positive ? x : kZero; | ||||
|   } | ||||
| } | ||||
|  | ||||
| // Vectorized half-precision ReLU. N counts half2 pairs (the caller passes | ||||
| // half of the element count); each lane gets y = x if x > 0, otherwise 0. | ||||
| __global__ void ReluHalf2CUDAKernel(const int N, const half2* X, half2* Y) { | ||||
|   const half kZero = __float2half(0.0f); | ||||
|   CUDA_1D_KERNEL_LOOP(i, N) { | ||||
| #if __CUDA_ARCH__ >= 530 | ||||
|     // Select per lane instead of multiplying by a 0/1 mask: 0 * (-inf) and | ||||
|     // 0 * NaN are NaN, which would make this path disagree with the scalar | ||||
|     // kernel and with the float fallback below (both produce 0). | ||||
|     const half2 x = __ldg(X + i); | ||||
|     const half lo = __low2half(x); | ||||
|     const half hi = __high2half(x); | ||||
|     Y[i] = __halves2half2( | ||||
|         __hgt(lo, kZero) ? lo : kZero, __hgt(hi, kZero) ? hi : kZero); | ||||
| #else | ||||
|     const float2 xx = __half22float2(X[i]); | ||||
|     Y[i] = __floats2half2_rn(xx.x > 0 ? xx.x : 0.f, xx.y > 0 ? xx.y : 0.f); | ||||
| #endif | ||||
|   } | ||||
| } | ||||
|  | ||||
| template <typename T> | ||||
| __global__ void | ||||
| ReluGradientKernel(const int N, const T* Y, const T* dY, T* dX) { | ||||
| ReluGradientCUDAKernel(const int N, const T* dY, const T* Y, T* dX) { | ||||
|   CUDA_1D_KERNEL_LOOP(i, N) { | ||||
| #if __CUDA_ARCH__ >= 350 | ||||
|     dX[i] = __ldg(Y + i) > 0 ? __ldg(dY + i) : 0; | ||||
| #else | ||||
|     dX[i] = Y[i] > 0 ? dY[i] : 0; | ||||
| #endif | ||||
|   } | ||||
| } | ||||
|  | ||||
| // Scalar half-precision ReLU gradient: dX[i] = dY[i] if Y[i] > 0, else 0. | ||||
| // Note the argument order: dY comes before Y. | ||||
| __global__ void ReluGradientHalfCUDAKernel( | ||||
|     const int N, | ||||
|     const half* dY, | ||||
|     const half* Y, | ||||
|     half* dX) { | ||||
|   const half kZero = __float2half(0.0f); | ||||
|   CUDA_1D_KERNEL_LOOP(i, N) { | ||||
| #if __CUDA_ARCH__ >= 530 | ||||
|     // Compare directly in half precision. | ||||
|     if (__hgt(__ldg(Y + i), kZero)) { | ||||
|       dX[i] = __ldg(dY + i); | ||||
|     } else { | ||||
|       dX[i] = kZero; | ||||
|     } | ||||
| #else | ||||
|     // Otherwise do the comparison in float. | ||||
|     if (__half2float(Y[i]) > 0) { | ||||
|       dX[i] = dY[i]; | ||||
|     } else { | ||||
|       dX[i] = kZero; | ||||
|     } | ||||
| #endif | ||||
|   } | ||||
| } | ||||
|  | ||||
| // Vectorized half-precision ReLU gradient. N counts half2 pairs (the caller | ||||
| // passes half of the element count); each lane gets dx = dy if y > 0, else 0. | ||||
| // Note the argument order: dY comes before Y. | ||||
| __global__ void ReluGradientHalf2CUDAKernel( | ||||
|     const int N, | ||||
|     const half2* dY, | ||||
|     const half2* Y, | ||||
|     half2* dX) { | ||||
|   const half kZero = __float2half(0.0f); | ||||
|   CUDA_1D_KERNEL_LOOP(i, N) { | ||||
| #if __CUDA_ARCH__ >= 530 | ||||
|     // Select per lane instead of multiplying dy by a 0/1 mask: an inf or NaN | ||||
|     // incoming gradient in a lane with y <= 0 would turn into NaN (0 * inf), | ||||
|     // whereas the scalar kernel and the float fallback below produce 0. | ||||
|     const half2 y = __ldg(Y + i); | ||||
|     const half2 dy = __ldg(dY + i); | ||||
|     dX[i] = __halves2half2( | ||||
|         __hgt(__low2half(y), kZero) ? __low2half(dy) : kZero, | ||||
|         __hgt(__high2half(y), kZero) ? __high2half(dy) : kZero); | ||||
| #else | ||||
|     const float2 dy = __half22float2(dY[i]); | ||||
|     const float2 yy = __half22float2(Y[i]); | ||||
|     dX[i] = __floats2half2_rn(yy.x > 0 ? dy.x : 0.f, yy.y > 0 ? dy.y : 0.f); | ||||
| #endif | ||||
|   } | ||||
| } | ||||
|  | ||||
| } // namespace | ||||
|  | ||||
| template <> | ||||
| bool ReluOp<float, CUDAContext>::RunOnDevice() { | ||||
|   auto& X = Input(0); | ||||
|   auto* Y = Output(0); | ||||
|   CAFFE_ENFORCE_GE(X.size(), 0); | ||||
|  | ||||
|   Y->ResizeLike(X); | ||||
|   ReluKernel<<< | ||||
|       CAFFE_GET_BLOCKS(X.size()), | ||||
|       CAFFE_CUDA_NUM_THREADS, | ||||
|       0, | ||||
|       context_.cuda_stream()>>>( | ||||
|       X.size(), X.data<float>(), Y->mutable_data<float>()); | ||||
| template <typename T> | ||||
| bool ReluFunctor<CUDAContext>::operator()( | ||||
|     const int N, | ||||
|     const T* X, | ||||
|     T* Y, | ||||
|     CUDAContext* context) const { | ||||
|   // Enqueue the element-wise ReLU kernel for all N elements on the stream | ||||
|   // returned by context->cuda_stream(). The launch result is not checked | ||||
|   // here, so this always returns true. | ||||
|   const auto num_blocks = CAFFE_GET_BLOCKS(N); | ||||
|   const auto stream = context->cuda_stream(); | ||||
|   ReluCUDAKernel<T> | ||||
|       <<<num_blocks, CAFFE_CUDA_NUM_THREADS, 0, stream>>>(N, X, Y); | ||||
|   return true; | ||||
| } | ||||
|  | ||||
| template <> | ||||
| bool ReluGradientOp<float, CUDAContext>::RunOnDevice() { | ||||
|   auto& Y = Input(0); | ||||
|   auto& dY = Input(1); | ||||
|   auto* dX = Output(0); | ||||
|   CAFFE_ENFORCE_GE(Y.size(), 0); | ||||
|   CAFFE_ENFORCE_EQ(dY.size(), Y.size()); | ||||
|   dX->ResizeLike(Y); | ||||
|   ReluGradientKernel<<< | ||||
|       CAFFE_GET_BLOCKS(Y.size()), | ||||
|       CAFFE_CUDA_NUM_THREADS, | ||||
|       0, | ||||
|       context_.cuda_stream()>>>( | ||||
|       Y.size(), Y.data<float>(), dY.data<float>(), dX->mutable_data<float>()); | ||||
| template <> | ||||
| bool ReluFunctor<CUDAContext>::operator()<float16>( | ||||
|     const int N, | ||||
|     const float16* X, | ||||
|     float16* Y, | ||||
|     CUDAContext* context) const { | ||||
|   // float16 specialization: an odd element count goes through the scalar | ||||
|   // half kernel; an even count is reinterpreted as N / 2 half2 pairs and | ||||
|   // goes through the half2 kernel. Always returns true. | ||||
|   const auto stream = context->cuda_stream(); | ||||
|   const bool is_odd = (N & 1) != 0; | ||||
|   if (is_odd) { | ||||
|     ReluHalfCUDAKernel<<< | ||||
|         CAFFE_GET_BLOCKS(N), | ||||
|         CAFFE_CUDA_NUM_THREADS, | ||||
|         0, | ||||
|         stream>>>( | ||||
|         N, reinterpret_cast<const half*>(X), reinterpret_cast<half*>(Y)); | ||||
|     return true; | ||||
|   } | ||||
|   const int num_pairs = N >> 1; | ||||
|   ReluHalf2CUDAKernel<<< | ||||
|       CAFFE_GET_BLOCKS(num_pairs), | ||||
|       CAFFE_CUDA_NUM_THREADS, | ||||
|       0, | ||||
|       stream>>>( | ||||
|       num_pairs, | ||||
|       reinterpret_cast<const half2*>(X), | ||||
|       reinterpret_cast<half2*>(Y)); | ||||
|   return true; | ||||
| } | ||||
|  | ||||
| REGISTER_CUDA_OPERATOR(Relu, ReluOp<float, CUDAContext>); | ||||
| REGISTER_CUDA_OPERATOR(ReluGradient, ReluGradientOp<float, CUDAContext>); | ||||
| template <> | ||||
| template <typename T> | ||||
| bool ReluGradientFunctor<CUDAContext>::Forward( | ||||
|     const std::vector<int>& Y_dims, | ||||
|     const std::vector<int>& /* dY_dims */, | ||||
|     const T* Y, | ||||
|     const T* dY, | ||||
|     T* dX, | ||||
|     CUDAContext* context) const { | ||||
|   // The element count is the product of all entries of Y_dims (1 if empty). | ||||
|   const int size = std::accumulate( | ||||
|       Y_dims.cbegin(), Y_dims.cend(), 1, std::multiplies<int>()); | ||||
|   // Enqueue the gradient kernel on the context's stream. The kernel takes | ||||
|   // dY before Y. The launch result is not checked, so this returns true. | ||||
|   const auto num_blocks = CAFFE_GET_BLOCKS(size); | ||||
|   const auto stream = context->cuda_stream(); | ||||
|   ReluGradientCUDAKernel<T> | ||||
|       <<<num_blocks, CAFFE_CUDA_NUM_THREADS, 0, stream>>>(size, dY, Y, dX); | ||||
|   return true; | ||||
| } | ||||
|  | ||||
| template <> | ||||
| template <> | ||||
| bool ReluGradientFunctor<CUDAContext>::Forward<float16>( | ||||
|     const std::vector<int>& Y_dims, | ||||
|     const std::vector<int>& /* dY_dims */, | ||||
|     const float16* Y, | ||||
|     const float16* dY, | ||||
|     float16* dX, | ||||
|     CUDAContext* context) const { | ||||
|   // The element count is the product of all entries of Y_dims (1 if empty). | ||||
|   const int size = std::accumulate( | ||||
|       Y_dims.cbegin(), Y_dims.cend(), 1, std::multiplies<int>()); | ||||
|   const auto stream = context->cuda_stream(); | ||||
|   // An odd element count goes through the scalar half kernel. | ||||
|   const bool is_odd = (size & 1) != 0; | ||||
|   if (is_odd) { | ||||
|     ReluGradientHalfCUDAKernel<<< | ||||
|         CAFFE_GET_BLOCKS(size), | ||||
|         CAFFE_CUDA_NUM_THREADS, | ||||
|         0, | ||||
|         stream>>>( | ||||
|         size, | ||||
|         reinterpret_cast<const half*>(dY), | ||||
|         reinterpret_cast<const half*>(Y), | ||||
|         reinterpret_cast<half*>(dX)); | ||||
|     return true; | ||||
|   } | ||||
|   // An even count is reinterpreted as size / 2 half2 pairs. | ||||
|   const int num_pairs = size >> 1; | ||||
|   ReluGradientHalf2CUDAKernel<<< | ||||
|       CAFFE_GET_BLOCKS(num_pairs), | ||||
|       CAFFE_CUDA_NUM_THREADS, | ||||
|       0, | ||||
|       stream>>>( | ||||
|       num_pairs, | ||||
|       reinterpret_cast<const half2*>(dY), | ||||
|       reinterpret_cast<const half2*>(Y), | ||||
|       reinterpret_cast<half2*>(dX)); | ||||
|   return true; | ||||
| } | ||||
|  | ||||
| REGISTER_CUDA_OPERATOR( | ||||
|     Relu, | ||||
|     UnaryElementwiseOp< | ||||
|         TensorTypes<float, float16>, | ||||
|         CUDAContext, | ||||
|         ReluFunctor<CUDAContext>>); | ||||
| REGISTER_CUDA_OPERATOR( | ||||
|     ReluGradient, | ||||
|     BinaryElementwiseOp< | ||||
|         TensorTypes<float, float16>, | ||||
|         CUDAContext, | ||||
|         ReluGradientFunctor<CUDAContext>>); | ||||
|  | ||||
| } // namespace caffe2 | ||||
|  | ||||
| @ -1,34 +1,28 @@ | ||||
| #ifndef CAFFE2_OPERATORS_RELU_OP_H_ | ||||
| #define CAFFE2_OPERATORS_RELU_OP_H_ | ||||
|  | ||||
| #include "caffe2/core/common_omp.h" | ||||
| #include "caffe2/core/context.h" | ||||
| #include "caffe2/core/logging.h" | ||||
| #include "caffe2/core/operator.h" | ||||
| #include <vector> | ||||
|  | ||||
| #include "caffe2/operators/elementwise_ops.h" | ||||
|  | ||||
| namespace caffe2 { | ||||
|  | ||||
| template <typename T, class Context> | ||||
| class ReluOp final : public Operator<Context> { | ||||
|  public: | ||||
|   USE_SIMPLE_CTOR_DTOR(ReluOp); | ||||
|   USE_OPERATOR_CONTEXT_FUNCTIONS; | ||||
|  | ||||
|   bool RunOnDevice() override; | ||||
|  | ||||
|  protected: | ||||
| // Forward-pass functor for the Relu operator, parameterized on the device | ||||
| // context. operator() is only declared here; it is defined per context. | ||||
| template <class Context> | ||||
| struct ReluFunctor { | ||||
|   // Reads N elements from X and writes N results to Y. | ||||
|   template <typename T> | ||||
|   bool operator()( | ||||
|       const int N, | ||||
|       const T* X, | ||||
|       T* Y, | ||||
|       Context* context) const; | ||||
| }; | ||||
|  | ||||
| template <typename T, class Context> | ||||
| class ReluGradientOp final : public Operator<Context> { | ||||
|  public: | ||||
|   USE_SIMPLE_CTOR_DTOR(ReluGradientOp); | ||||
|   USE_OPERATOR_CONTEXT_FUNCTIONS; | ||||
|  | ||||
|   bool RunOnDevice() override; | ||||
|  | ||||
|  protected: | ||||
|   // Input: Y, dY; Output: dX | ||||
| // Backward-pass functor for the ReluGradient operator, parameterized on the | ||||
| // device context. Forward() is only declared here; it is defined per context. | ||||
| template <class Context> | ||||
| struct ReluGradientFunctor { | ||||
|   // Inputs are Y and dY with their dimension vectors; the output is dX. | ||||
|   template <typename T> | ||||
|   bool Forward( | ||||
|       const std::vector<int>& Y_dims, const std::vector<int>& dY_dims, | ||||
|       const T* Y, const T* dY, | ||||
|       T* dX, | ||||
|       Context* context) const; | ||||
| }; | ||||
|  | ||||
| } // namespace caffe2 | ||||
|  | ||||
| @ -1,208 +1,12 @@ | ||||
| #include "caffe2/core/context_gpu.h" | ||||
| #include "caffe2/core/cudnn_wrappers.h" | ||||
| #include "caffe2/core/operator.h" | ||||
| #include "caffe2/core/types.h" | ||||
| #include "caffe2/operators/relu_op.h" | ||||
|  | ||||
| #include "caffe2/operators/activation_ops_cudnn.h" | ||||
|  | ||||
| namespace caffe2 { | ||||
|  | ||||
| class CuDNNReluOp final : public Operator<CUDAContext> { | ||||
|  public: | ||||
|   CuDNNReluOp(const OperatorDef& operator_def, Workspace* ws) | ||||
|       : Operator<CUDAContext>(operator_def, ws), | ||||
|         cudnn_wrapper_(&context_), | ||||
|         order_(StringToStorageOrder( | ||||
|             OperatorBase::GetSingleArgument<string>("order", "NCHW"))) { | ||||
|     CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&data_desc_)); | ||||
|     CUDNN_ENFORCE(cudnnCreateActivationDescriptor(&activ_desc_)); | ||||
|     CUDNN_ENFORCE(cudnnSetActivationDescriptor( | ||||
|         activ_desc_, CUDNN_ACTIVATION_RELU, CUDNN_PROPAGATE_NAN, 0.0)); | ||||
|   } | ||||
| REGISTER_CUDNN_OPERATOR(Relu, CuDNNActivationOp<CUDNN_ACTIVATION_RELU>); | ||||
| REGISTER_CUDNN_OPERATOR( | ||||
|     ReluGradient, | ||||
|     CuDNNActivationGradientOp<CUDNN_ACTIVATION_RELU>); | ||||
|  | ||||
|   ~CuDNNReluOp() { | ||||
|     CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(data_desc_)); | ||||
|     CUDNN_ENFORCE(cudnnDestroyActivationDescriptor(activ_desc_)); | ||||
|   } | ||||
|  | ||||
|   template <typename T> | ||||
|   bool DoRunWithType() { | ||||
|     const auto& X = Input(0); | ||||
|     auto* Y = Output(0); | ||||
|  | ||||
|     // Return if X is empty | ||||
|     if (X.size() == 0) { | ||||
|       Y->mutable_data<T>(); | ||||
|       return true; | ||||
|     } | ||||
|  | ||||
|     // See if we need to reshape. | ||||
|     if (X.dims() != cudnn_input_dims_) { | ||||
|       VLOG(1) << "Setting descriptors."; | ||||
|       cudnn_input_dims_ = X.dims(); | ||||
|       int C = 1, H = 1, W = 1; | ||||
|       if (X.ndim() == 4) { | ||||
|         // Normal 4-dimensional tensors for images. | ||||
|         C = (order_ == StorageOrder::NCHW ? X.dim32(1) : X.dim32(3)); | ||||
|         H = (order_ == StorageOrder::NCHW ? X.dim32(2) : X.dim32(1)); | ||||
|         W = (order_ == StorageOrder::NCHW ? X.dim32(3) : X.dim32(2)); | ||||
|       } else { | ||||
|         // If X is not 4-dimensional, we will simply use H = 1 and W = 1 | ||||
|         // and wrap everything into C. | ||||
|         C = X.size() / X.dim32(0); | ||||
|       } | ||||
|       CUDNN_ENFORCE(cudnnSetTensor4dDescriptor( | ||||
|           data_desc_, | ||||
|           GetCudnnTensorFormat(order_), | ||||
|           cudnnTypeWrapper<T>::type, | ||||
|           X.dim32(0), | ||||
|           C, | ||||
|           H, | ||||
|           W)); | ||||
|     } | ||||
|     CUDNN_ENFORCE(cudnnActivationForward( | ||||
|         cudnn_wrapper_.inline_cudnn_handle(), | ||||
|         activ_desc_, | ||||
|         cudnnTypeWrapper<T>::kOne(), | ||||
|         data_desc_, | ||||
|         X.template data<T>(), | ||||
|         cudnnTypeWrapper<T>::kZero(), | ||||
|         data_desc_, | ||||
|         Y->template mutable_data<T>())); | ||||
|     return true; | ||||
|   } | ||||
|  | ||||
|   bool RunOnDevice() override { | ||||
|     // dispatch based on contents of tensor(s) | ||||
|     const auto& X = Input(0); | ||||
|     auto* Y = Output(0); | ||||
|     Y->ResizeLike(X); | ||||
|  | ||||
|     if (X.IsType<float>()) { | ||||
|       return DoRunWithType<float>(); | ||||
|     } else if (X.IsType<float16>()) { | ||||
|       return DoRunWithType<float16>(); | ||||
|     } else { | ||||
|       LOG(FATAL) << "Unsupported input types"; | ||||
|     } | ||||
|     return true; | ||||
|   } | ||||
|  | ||||
|  protected: | ||||
|   CuDNNWrapper cudnn_wrapper_; | ||||
|   cudnnTensorDescriptor_t data_desc_; | ||||
|   cudnnActivationDescriptor_t activ_desc_; | ||||
|   vector<TIndex> cudnn_input_dims_; | ||||
|   StorageOrder order_; | ||||
| }; | ||||
|  | ||||
|  | ||||
| // Note: You can see that in CuDNNReluGradientOp, we abused the cudnn interface | ||||
| // by passing in the output tensor for both bottom and top. This is dependent on | ||||
| // the assumption that the Relu gradient actually does not rely on the bottom | ||||
| // data, or it treats input=0 the same way as input<0. This is of course not | ||||
| // very safe, but we have been running in this way in Caffe for a while so it | ||||
| // *might* be safe to assume so. | ||||
| class CuDNNReluGradientOp final : public Operator<CUDAContext> { | ||||
|  public: | ||||
|   CuDNNReluGradientOp(const OperatorDef& operator_def, Workspace* ws) | ||||
|       : Operator<CUDAContext>(operator_def, ws), | ||||
|         cudnn_wrapper_(&context_), | ||||
|         order_(StringToStorageOrder( | ||||
|             OperatorBase::GetSingleArgument<string>("order", "NCHW"))) { | ||||
|     CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&data_desc_)); | ||||
|     CUDNN_ENFORCE(cudnnCreateActivationDescriptor(&activ_desc_)); | ||||
|     CUDNN_ENFORCE(cudnnSetActivationDescriptor( | ||||
|         activ_desc_, CUDNN_ACTIVATION_RELU, CUDNN_PROPAGATE_NAN, 0.0)); | ||||
|   } | ||||
|  | ||||
|   ~CuDNNReluGradientOp() { | ||||
|     CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(data_desc_)); | ||||
|     CUDNN_ENFORCE(cudnnDestroyActivationDescriptor(activ_desc_)); | ||||
|   } | ||||
|  | ||||
|   template <typename T> | ||||
|   bool DoRunWithType() { | ||||
|     const auto& Y = Input(0); | ||||
|     const auto& dY = Input(1); | ||||
|     auto* dX = Output(0); | ||||
|  | ||||
|     // Return if Y is empty | ||||
|     if (Y.size() == 0) { | ||||
|       dX->mutable_data<T>(); | ||||
|       return true; | ||||
|     } | ||||
|  | ||||
|     // See if we need to reshape. | ||||
|     if (Y.dims() != cudnn_input_dims_) { | ||||
|       VLOG(1) << "Setting descriptors."; | ||||
|       cudnn_input_dims_ = Y.dims(); | ||||
|       int C = 1, H = 1, W = 1; | ||||
|       if (Y.ndim() == 4) { | ||||
|         // Normal 4-dimensional tensors for images. | ||||
|         C = (order_ == StorageOrder::NCHW ? Y.dim32(1) : Y.dim32(3)); | ||||
|         H = (order_ == StorageOrder::NCHW ? Y.dim32(2) : Y.dim32(1)); | ||||
|         W = (order_ == StorageOrder::NCHW ? Y.dim32(3) : Y.dim32(2)); | ||||
|       } else { | ||||
|         // If Y is not 4-dimensional, we will simply use H = 1 and W = 1 | ||||
|         // and wrap everything into C. | ||||
|         C = Y.size() / Y.dim32(0); | ||||
|       } | ||||
|       CUDNN_ENFORCE(cudnnSetTensor4dDescriptor( | ||||
|           data_desc_, | ||||
|           GetCudnnTensorFormat(order_), | ||||
|           cudnnTypeWrapper<T>::type, | ||||
|           Y.dim32(0), | ||||
|           C, | ||||
|           H, | ||||
|           W)); | ||||
|     } | ||||
|     CUDNN_ENFORCE(cudnnActivationBackward( | ||||
|         cudnn_wrapper_.inline_cudnn_handle(), | ||||
|         activ_desc_, | ||||
|         cudnnTypeWrapper<T>::kOne(), | ||||
|         data_desc_, | ||||
|         Y.template data<T>(), | ||||
|         data_desc_, | ||||
|         dY.template data<T>(), | ||||
|         data_desc_, | ||||
|         // Note: strictly speaking, we should be using the input data in this | ||||
|         // case, but for the ReLU case we rely on the underlying implementation | ||||
|         // that only the output is needed to calculate the Relu gradient. This | ||||
|         // will enable us to do memory optimization for in-place relu. To | ||||
|         // ensure this is correct, a unit test is provided at | ||||
|         // caffe2/python/operator_test/relu_op_test.py | ||||
|         Y.template data<T>(), | ||||
|         cudnnTypeWrapper<T>::kZero(), | ||||
|         data_desc_, | ||||
|         dX->template mutable_data<T>())); | ||||
|     return true; | ||||
|   } | ||||
|  | ||||
|   bool RunOnDevice() override { | ||||
|     const auto& Y = Input(0); | ||||
|     auto* dX = Output(0); | ||||
|     dX->ResizeLike(Y); | ||||
|  | ||||
|     if (Y.IsType<float>()) { | ||||
|       return DoRunWithType<float>(); | ||||
|     } else if (Y.IsType<float16>()) { | ||||
|       return DoRunWithType<float16>(); | ||||
|     } else { | ||||
|       LOG(FATAL) << "Unsupported input types"; | ||||
|     } | ||||
|     return true; | ||||
|   } | ||||
|  | ||||
|  protected: | ||||
|   CuDNNWrapper cudnn_wrapper_; | ||||
|   cudnnTensorDescriptor_t data_desc_; | ||||
|   cudnnActivationDescriptor_t activ_desc_; | ||||
|   vector<TIndex> cudnn_input_dims_; | ||||
|   StorageOrder order_; | ||||
|   // Input: Y, dY; Output: dX | ||||
| }; | ||||
|  | ||||
| namespace { | ||||
| REGISTER_CUDNN_OPERATOR(Relu, CuDNNReluOp); | ||||
| REGISTER_CUDNN_OPERATOR(ReluGradient, CuDNNReluGradientOp); | ||||
| }  // namespace | ||||
| }  // namespace caffe2 | ||||
| } // namespace caffe2 | ||||
|  | ||||
| @ -1,90 +0,0 @@ | ||||
| #include "caffe2/core/common_gpu.h" | ||||
| #ifdef CAFFE_HAS_CUDA_FP16 | ||||
|  | ||||
| #include "caffe2/core/context_gpu.h" | ||||
| #include "caffe2/operators/relu_op.h" | ||||
|  | ||||
| namespace caffe2 { | ||||
| namespace { | ||||
| __global__ void ReluKernelHalf(const int N, const half* X, half* Y) { | ||||
|   const half kZero = __float2half(0.0); | ||||
|   CUDA_1D_KERNEL_LOOP(i, N) { | ||||
| #if __CUDA_ARCH__ >= 530 | ||||
|     Y[i] = __hgt(X[i], kZero) ? X[i] : kZero; | ||||
| #else | ||||
|     Y[i] = (__half2float(X[i]) > 0) ? X[i] : kZero; | ||||
| #endif | ||||
|   } | ||||
| } | ||||
|  | ||||
| __global__ void ReluKernelHalf2(const int N, const half2* X, half2* Y) { | ||||
|   const half2 kZero = __float2half2_rn(0.0); | ||||
|   CUDA_1D_KERNEL_LOOP(i, N) { | ||||
| #if __CUDA_ARCH__ >= 530 | ||||
|     Y[i] = __hmul2(__hgt2(X[i], kZero), X[i]); | ||||
| #else | ||||
|     float2 xx = __half22float2(X[i]); | ||||
|     Y[i] = __floats2half2_rn(xx.x > 0 ? xx.x : 0.f, | ||||
|                              xx.y > 0 ? xx.y : 0.f); | ||||
| #endif | ||||
|   } | ||||
| } | ||||
|  | ||||
| __global__ void ReluGradientKernelHalf( | ||||
|     const int N, const half* Y, const half* dY, half* dX) { | ||||
|   const half kZero = __float2half(0.0); | ||||
|   CUDA_1D_KERNEL_LOOP(i, N) { | ||||
| #if __CUDA_ARCH__ >= 530 | ||||
|     dX[i] = __hgt(Y[i], kZero) ? dY[i] : kZero; | ||||
| #else | ||||
|     dX[i] = (__half2float(Y[i]) > 0) ? dY[i] : kZero; | ||||
| #endif | ||||
|   } | ||||
| } | ||||
| }  // namespace | ||||
|  | ||||
| template <> | ||||
| bool ReluOp<float16, CUDAContext>::RunOnDevice() { | ||||
|   auto& X = Input(0); | ||||
|   auto* Y = Output(0); | ||||
|   CAFFE_ENFORCE_GT(X.size(), 0); | ||||
|   Y->ResizeLike(X); | ||||
|   if (X.size() % 2 == 0) { | ||||
|     ReluKernelHalf2<<<CAFFE_GET_BLOCKS(X.size() / 2), CAFFE_CUDA_NUM_THREADS, | ||||
|                       0, context_.cuda_stream()>>>( | ||||
|         X.size() / 2, reinterpret_cast<const half2*>(X.data<float16>()), | ||||
|         reinterpret_cast<half2*>(Y->mutable_data<float16>())); | ||||
|     return true; | ||||
|   } else { | ||||
|     ReluKernelHalf<<<CAFFE_GET_BLOCKS(X.size()), CAFFE_CUDA_NUM_THREADS, | ||||
|                      0, context_.cuda_stream()>>>( | ||||
|         X.size(), reinterpret_cast<const half*>(X.data<float16>()), | ||||
|         reinterpret_cast<half*>(Y->mutable_data<float16>())); | ||||
|     return true; | ||||
|   } | ||||
| } | ||||
|  | ||||
| // float16 ReLU backward pass. Input(0) is the forward output Y, Input(1) is | ||||
| // the incoming gradient dY; both must be non-empty and equal in size. The | ||||
| // result dX is resized like Y and filled by ReluGradientKernelHalf on the | ||||
| // context's CUDA stream. | ||||
| template <> | ||||
| bool ReluGradientOp<float16, CUDAContext>::RunOnDevice() { | ||||
|   const auto& Y = Input(0); | ||||
|   const auto& dY = Input(1); | ||||
|   auto* dX = Output(0); | ||||
|   CAFFE_ENFORCE_GT(Y.size(), 0); | ||||
|   CAFFE_ENFORCE_EQ(dY.size(), Y.size()); | ||||
|   dX->ResizeLike(Y); | ||||
|   const half* Y_data = reinterpret_cast<const half*>(Y.data<float16>()); | ||||
|   const half* dY_data = reinterpret_cast<const half*>(dY.data<float16>()); | ||||
|   half* dX_data = reinterpret_cast<half*>(dX->mutable_data<float16>()); | ||||
|   ReluGradientKernelHalf<<<CAFFE_GET_BLOCKS(Y.size()), CAFFE_CUDA_NUM_THREADS, | ||||
|                            0, context_.cuda_stream()>>>( | ||||
|       Y.size(), Y_data, dY_data, dX_data); | ||||
|   return true; | ||||
| } | ||||
|  | ||||
| // Bare schemas for the fp16 operators; no input/output constraints are | ||||
| // chained onto them here. | ||||
| OPERATOR_SCHEMA(ReluFp16); | ||||
| OPERATOR_SCHEMA(ReluFp16Gradient); | ||||
|  | ||||
| // Register the float16 specializations of ReluOp / ReluGradientOp as the | ||||
| // CUDA implementations of ReluFp16 / ReluFp16Gradient. | ||||
| REGISTER_CUDA_OPERATOR(ReluFp16, ReluOp<float16, CUDAContext>); | ||||
| REGISTER_CUDA_OPERATOR(ReluFp16Gradient, ReluGradientOp<float16, CUDAContext>); | ||||
| }  // namespace caffe2 | ||||
|  | ||||
| #endif  // CAFFE_HAS_CUDA_FP16 | ||||
| @ -22,7 +22,7 @@ bool SigmoidGradientFunctor<CPUContext>::Forward( | ||||
|       Y_dims.cbegin(), Y_dims.cend(), 1, std::multiplies<int>()); | ||||
|   ConstEigenVectorArrayMap<T> dY_arr(dY, size); | ||||
|   ConstEigenVectorArrayMap<T> Y_arr(Y, size); | ||||
|   EigenVectorArrayMap<T>(dX, size) = dY_arr * Y_arr * (1. - Y_arr); | ||||
|   EigenVectorArrayMap<T>(dX, size) = dY_arr * Y_arr * (T(1) - Y_arr); | ||||
|   return true; | ||||
| } | ||||
|  | ||||
|  | ||||
| @ -8,8 +8,8 @@ template <> | ||||
| template <typename T> | ||||
| bool SigmoidFunctor<CPUContext>:: | ||||
| operator()(const int N, const T* X, T* Y, CPUContext* /* context */) const { | ||||
|   ConstEigenVectorArrayMap<T> X_arr(X, N); | ||||
|   EigenVectorArrayMap<T>(Y, N) = 1. / (1. + (-X_arr).exp()); | ||||
|   EigenVectorArrayMap<T>(Y, N) = | ||||
|       T(1) / (T(1) + (-ConstEigenVectorArrayMap<T>(X, N)).exp()); | ||||
|   return true; | ||||
| } | ||||
|  | ||||
|  | ||||
| @ -10,19 +10,23 @@ namespace caffe2 { | ||||
| namespace { | ||||
|  | ||||
| template <typename T> | ||||
| __global__ void SigmoidKernel(const int N, const T* X, T* Y) { | ||||
| __global__ void SigmoidCUDAKernel(const int N, const T* X, T* Y); | ||||
|  | ||||
| template <> | ||||
| __global__ void | ||||
| SigmoidCUDAKernel<float>(const int N, const float* X, float* Y) { | ||||
|   CUDA_1D_KERNEL_LOOP(i, N) { | ||||
| #if __CUDA_ARCH__ >= 350 | ||||
|     Y[i] = T(1) / (T(1) + exp(-__ldg(X + i))); | ||||
|     Y[i] = 1.0f / (1.0f + expf(-__ldg(X + i))); | ||||
| #else | ||||
|     Y[i] = T(1) / (T(1) + exp(-X[i])); | ||||
|     Y[i] = 1.0f / (1.0f + expf(-X[i])); | ||||
| #endif | ||||
|   } | ||||
| } | ||||
|  | ||||
| template <typename T> | ||||
| __global__ void | ||||
| SigmoidGradientKernel(const int N, const T* dY, const T* Y, T* dX) { | ||||
| SigmoidGradientCUDAKernel(const int N, const T* dY, const T* Y, T* dX) { | ||||
|   CUDA_1D_KERNEL_LOOP(i, N) { | ||||
| #if __CUDA_ARCH__ >= 350 | ||||
|     dX[i] = __ldg(dY + i) * __ldg(Y + i) * (T(1) - __ldg(Y + i)); | ||||
| @ -38,7 +42,7 @@ template <> | ||||
| template <typename T> | ||||
| bool SigmoidFunctor<CUDAContext>:: | ||||
| operator()(const int N, const T* X, T* Y, CUDAContext* context) const { | ||||
|   SigmoidKernel<T> | ||||
|   SigmoidCUDAKernel<T> | ||||
|       <<<CAFFE_GET_BLOCKS(N), | ||||
|          CAFFE_CUDA_NUM_THREADS, | ||||
|          0, | ||||
| @ -57,7 +61,7 @@ bool SigmoidGradientFunctor<CUDAContext>::Forward( | ||||
|     CUDAContext* context) const { | ||||
|   const int size = std::accumulate( | ||||
|       Y_dims.cbegin(), Y_dims.cend(), 1, std::multiplies<int>()); | ||||
|   SigmoidGradientKernel<T> | ||||
|   SigmoidGradientCUDAKernel<T> | ||||
|       <<<CAFFE_GET_BLOCKS(size), | ||||
|          CAFFE_CUDA_NUM_THREADS, | ||||
|          0, | ||||
|  | ||||
							
								
								
									
										12
									
								
								caffe2/operators/sigmoid_op_cudnn.cc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										12
									
								
								caffe2/operators/sigmoid_op_cudnn.cc
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,12 @@ | ||||
| #include "caffe2/operators/sigmoid_op.h" | ||||
|  | ||||
| #include "caffe2/operators/activation_ops_cudnn.h" | ||||
|  | ||||
| namespace caffe2 { | ||||
|  | ||||
| // Register the cuDNN implementations of Sigmoid and SigmoidGradient. Both are | ||||
| // instantiations of the shared cuDNN activation op templates, parameterized | ||||
| // with CUDNN_ACTIVATION_SIGMOID. | ||||
| REGISTER_CUDNN_OPERATOR(Sigmoid, CuDNNActivationOp<CUDNN_ACTIVATION_SIGMOID>); | ||||
| REGISTER_CUDNN_OPERATOR( | ||||
|     SigmoidGradient, | ||||
|     CuDNNActivationGradientOp<CUDNN_ACTIVATION_SIGMOID>); | ||||
|  | ||||
| } // namespace caffe2 | ||||
							
								
								
									
										12
									
								
								caffe2/operators/tanh_op_cudnn.cc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										12
									
								
								caffe2/operators/tanh_op_cudnn.cc
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,12 @@ | ||||
| #include "caffe2/operators/tanh_op.h" | ||||
|  | ||||
| #include "caffe2/operators/activation_ops_cudnn.h" | ||||
|  | ||||
| namespace caffe2 { | ||||
|  | ||||
| // Register the cuDNN implementations of Tanh and TanhGradient. Both are | ||||
| // instantiations of the shared cuDNN activation op templates, parameterized | ||||
| // with CUDNN_ACTIVATION_TANH. | ||||
| REGISTER_CUDNN_OPERATOR(Tanh, CuDNNActivationOp<CUDNN_ACTIVATION_TANH>); | ||||
| REGISTER_CUDNN_OPERATOR( | ||||
|     TanhGradient, | ||||
|     CuDNNActivationGradientOp<CUDNN_ACTIVATION_TANH>); | ||||
|  | ||||
| } // namespace caffe2 | ||||
| @ -8,35 +8,123 @@ import numpy as np | ||||
| from hypothesis import given | ||||
| import hypothesis.strategies as st | ||||
|  | ||||
| from caffe2.python import core | ||||
| from caffe2.python import core, workspace | ||||
| import caffe2.python.hypothesis_test_util as hu | ||||
| import caffe2.python.mkl_test_util as mu | ||||
|  | ||||
| import unittest | ||||
|  | ||||
|  | ||||
| class TestActivations(hu.HypothesisTestCase): | ||||
|     @given(X=hu.tensor(), in_place=st.booleans(), | ||||
|            engine=st.sampled_from(["", "CUDNN"]), **mu.gcs) | ||||
|     def test_relu(self, X, in_place, engine, gc, dc): | ||||
|         if gc == mu.mkl_do: | ||||
|             in_place = False | ||||
|  | ||||
|         op = core.CreateOperator( | ||||
|             "Relu", | ||||
|             ["X"], | ||||
|             ["X"] if in_place else ["Y"], | ||||
|             engine=engine, | ||||
|         ) | ||||
|  | ||||
|         def relu_ref(X): | ||||
|             return [np.maximum(X, 0.0)] | ||||
|  | ||||
|         # go away from the origin point to avoid kink problems | ||||
|         X += 0.02 * np.sign(X) | ||||
|         X[X == 0.0] += 0.02 | ||||
|  | ||||
|         self.assertReferenceChecks(gc, op, [X], relu_ref) | ||||
|         self.assertDeviceChecks(dc, op, [X], [0]) | ||||
|         self.assertGradientChecks(gc, op, [X], 0, [0]) | ||||
|  | ||||
|     @unittest.skipIf(not workspace.has_gpu_support, | ||||
|                      "Relu for float16 can only run on GPU now.") | ||||
|     @given(X=hu.tensor(dtype=np.float16), in_place=st.booleans(), | ||||
|            engine=st.sampled_from(["", "CUDNN"]), **hu.gcs_gpu_only) | ||||
|     def test_relu_fp16(self, X, in_place, engine, gc, dc): | ||||
|         op = core.CreateOperator( | ||||
|             "Relu", | ||||
|             ["X"], | ||||
|             ["X"] if in_place else ["Y"], | ||||
|             engine=engine, | ||||
|         ) | ||||
|  | ||||
|         def relu_ref(X): | ||||
|             return [np.maximum(X, 0.0)] | ||||
|  | ||||
|         def relu_grad_ref(g_out, outputs, fwd_inputs): | ||||
|             dY = g_out | ||||
|             [Y] = outputs | ||||
|             dX = dY | ||||
|             dX[Y == 0] = 0 | ||||
|             return [dX] | ||||
|  | ||||
|         # go away from the origin point to avoid kink problems | ||||
|         X += 0.02 * np.sign(X) | ||||
|         X[X == 0.0] += 0.02 | ||||
|  | ||||
|         self.assertReferenceChecks( | ||||
|             hu.gpu_do, | ||||
|             op, | ||||
|             [X], | ||||
|             relu_ref, | ||||
|             output_to_grad="X" if in_place else "Y", | ||||
|             grad_reference=relu_grad_ref) | ||||
|  | ||||
|     @given(X=hu.tensor(elements=st.floats(-3.0, 3.0)), | ||||
|            n=st.floats(min_value=0.5, max_value=2.0), | ||||
|            in_place=st.booleans(), **hu.gcs) | ||||
|     def test_relu_n(self, X, n, in_place, gc, dc): | ||||
|         op = core.CreateOperator( | ||||
|             "ReluN", | ||||
|             ["X"], | ||||
|             ["X"] if in_place else ["Y"], | ||||
|             n=n, | ||||
|         ) | ||||
|  | ||||
|         def relu_n_ref(X): | ||||
|             return [np.minimum(np.maximum(X, 0.0), n)] | ||||
|  | ||||
|         # go away from 0 and n to avoid kink problems | ||||
|         X += 0.04 * np.sign(X) | ||||
|         X[X == 0.0] += 0.04 | ||||
|         X -= n | ||||
|         X += 0.02 * np.sign(X) | ||||
|         X[X == 0.0] -= 0.02 | ||||
|         X += n | ||||
|  | ||||
|         self.assertReferenceChecks(gc, op, [X], relu_n_ref) | ||||
|         self.assertDeviceChecks(dc, op, [X], [0]) | ||||
|         self.assertGradientChecks(gc, op, [X], 0, [0], stepsize=0.005) | ||||
|  | ||||
|     @given(X=hu.tensor(), | ||||
|            alpha=st.floats(min_value=0.1, max_value=2.0), | ||||
|            inplace=st.booleans(), | ||||
|            in_place=st.booleans(), engine=st.sampled_from(["", "CUDNN"]), | ||||
|            **hu.gcs) | ||||
|     def test_elu(self, X, alpha, inplace, gc, dc): | ||||
|     def test_elu(self, X, alpha, in_place, engine, gc, dc): | ||||
|         op = core.CreateOperator( | ||||
|             "Elu", | ||||
|             ["X"], | ||||
|             ["X"] if in_place else ["Y"], | ||||
|             alpha=alpha, | ||||
|             engine=engine, | ||||
|         ) | ||||
|  | ||||
|         def elu_ref(X): | ||||
|             Y = X | ||||
|             Y[X < 0] = alpha * (np.exp(X[X < 0]) - 1.0) | ||||
|             return [Y] | ||||
|  | ||||
|         # go away from the origin point to avoid kink problems | ||||
|         X += 0.04 * np.sign(X) | ||||
|         X[X == 0.0] += 0.04 | ||||
|  | ||||
|         def elu_ref(X): | ||||
|             Y = X.copy() | ||||
|             neg_indices = X <= 0 | ||||
|             Y[neg_indices] = alpha * (np.exp(Y[neg_indices]) - 1) | ||||
|             return (Y,) | ||||
|  | ||||
|         op = core.CreateOperator( | ||||
|             "Elu", | ||||
|             ["X"], ["Y" if not inplace else "X"], | ||||
|             alpha=alpha) | ||||
|         self.assertReferenceChecks(gc, op, [X], elu_ref) | ||||
|         # Check over multiple devices | ||||
|         self.assertDeviceChecks(dc, op, [X], [0]) | ||||
|         # Gradient check wrt X | ||||
|         self.assertGradientChecks(gc, op, [X], 0, [0]) | ||||
|         self.assertGradientChecks(gc, op, [X], 0, [0], stepsize=1e-2) | ||||
|  | ||||
|     @given(X=hu.tensor(min_dim=4, max_dim=4), | ||||
|            alpha=st.floats(min_value=0.1, max_value=2.0), | ||||
| @ -124,8 +212,3 @@ class TestActivations(hu.HypothesisTestCase): | ||||
|         self.assertReferenceChecks(gc, op, [X], leaky_relu_ref) | ||||
|         # Check over multiple devices | ||||
|         self.assertDeviceChecks(dc, op, [X], [0]) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     import unittest | ||||
|     unittest.main() | ||||
|  | ||||
| @ -311,12 +311,14 @@ class TestElementwiseOps(hu.HypothesisTestCase): | ||||
|             reference=swish_gradient, | ||||
|         ) | ||||
|  | ||||
|     @given(X=hu.tensor(dtype=np.float32), inplace=st.booleans(), **hu.gcs) | ||||
|     def test_sigmoid(self, X, inplace, gc, dc): | ||||
|     @given(X=hu.tensor(dtype=np.float32), inplace=st.booleans(), | ||||
|            engine=st.sampled_from(["", "CUDNN"]), **hu.gcs) | ||||
|     def test_sigmoid(self, X, inplace, engine, gc, dc): | ||||
|         op = core.CreateOperator( | ||||
|             "Sigmoid", | ||||
|             ["X"], | ||||
|             ["X"] if inplace else ["Y"], | ||||
|             engine=engine, | ||||
|         ) | ||||
|  | ||||
|         def sigmoid_ref(X): | ||||
|  | ||||
| @ -11,8 +11,12 @@ import numpy as np | ||||
|  | ||||
|  | ||||
| class TestHyperbolicOps(hu.HypothesisTestCase): | ||||
|     def _test_hyperbolic_op(self, op_name, np_ref, X, in_place, gc, dc): | ||||
|         op = core.CreateOperator(op_name, ["X"], ["X"] if in_place else ["Y"]) | ||||
|     def _test_hyperbolic_op(self, op_name, np_ref, X, in_place, engine, gc, dc): | ||||
|         op = core.CreateOperator( | ||||
|             op_name, | ||||
|             ["X"], | ||||
|             ["X"] if in_place else ["Y"], | ||||
|             engine=engine,) | ||||
|  | ||||
|         def ref(X): | ||||
|             return [np_ref(X)] | ||||
| @ -26,14 +30,15 @@ class TestHyperbolicOps(hu.HypothesisTestCase): | ||||
|         self.assertDeviceChecks(dc, op, [X], [0]) | ||||
|         self.assertGradientChecks(gc, op, [X], 0, [0]) | ||||
|  | ||||
|     @given(X=hu.tensor(dtype=np.float32), in_place=st.booleans(), **hu.gcs) | ||||
|     def test_tanh(self, X, in_place, gc, dc): | ||||
|         self._test_hyperbolic_op("Tanh", np.tanh, X, in_place, gc, dc) | ||||
|  | ||||
|     @given(X=hu.tensor(dtype=np.float32), **hu.gcs) | ||||
|     def test_sinh(self, X, gc, dc): | ||||
|         self._test_hyperbolic_op("Sinh", np.sinh, X, False, gc, dc) | ||||
|         self._test_hyperbolic_op("Sinh", np.sinh, X, False, "", gc, dc) | ||||
|  | ||||
|     @given(X=hu.tensor(dtype=np.float32), **hu.gcs) | ||||
|     def test_cosh(self, X, gc, dc): | ||||
|         self._test_hyperbolic_op("Cosh", np.cosh, X, False, gc, dc) | ||||
|         self._test_hyperbolic_op("Cosh", np.cosh, X, False, "", gc, dc) | ||||
|  | ||||
|     @given(X=hu.tensor(dtype=np.float32), in_place=st.booleans(), | ||||
|            engine=st.sampled_from(["", "CUDNN"]), **hu.gcs) | ||||
|     def test_tanh(self, X, in_place, engine, gc, dc): | ||||
|         self._test_hyperbolic_op("Tanh", np.tanh, X, in_place, engine, gc, dc) | ||||
|  | ||||
| @ -1,51 +0,0 @@ | ||||
| # Copyright (c) 2016-present, Facebook, Inc. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| ############################################################################## | ||||
|  | ||||
| from __future__ import absolute_import | ||||
| from __future__ import division | ||||
| from __future__ import print_function | ||||
| from __future__ import unicode_literals | ||||
|  | ||||
| from caffe2.python import core | ||||
| from hypothesis import given | ||||
| import caffe2.python.hypothesis_test_util as hu | ||||
| import numpy as np | ||||
|  | ||||
| import unittest | ||||
|  | ||||
|  | ||||
| class TestRelu(hu.HypothesisTestCase): | ||||
|  | ||||
|     @given(X=hu.tensor(), | ||||
|            **hu.gcs) | ||||
|     def test_relu_n(self, X, gc, dc): | ||||
|         X = 0.8 * np.sign(X) | ||||
|         X = X - 0.5 | ||||
|         X[X == 0.0] = 0.01 | ||||
|         n = max(np.max(X), 1.0) / 2 | ||||
|         X[X == 0.2] = 0.01 | ||||
|  | ||||
|         def relu_n_ref(X): | ||||
|             Y = np.minimum(np.maximum(X, 0), n) | ||||
|             return [Y] | ||||
|  | ||||
|         op = core.CreateOperator("ReluN", ["X"], ["Y"], n=n) | ||||
|         self.assertReferenceChecks(gc, op, [X], relu_n_ref) | ||||
|         self.assertDeviceChecks(dc, op, [X], [0]) | ||||
|         self.assertGradientChecks(gc, op, [X], 0, [0], stepsize=0.001, threshold=0.001) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
| @ -1,31 +0,0 @@ | ||||
| from __future__ import absolute_import | ||||
| from __future__ import division | ||||
| from __future__ import print_function | ||||
| from __future__ import unicode_literals | ||||
|  | ||||
| from caffe2.python import core | ||||
| from hypothesis import given | ||||
| import hypothesis.strategies as st | ||||
| import caffe2.python.hypothesis_test_util as hu | ||||
| import caffe2.python.mkl_test_util as mu | ||||
| import numpy as np | ||||
|  | ||||
| import unittest | ||||
|  | ||||
|  | ||||
| class TestRelu(hu.HypothesisTestCase): | ||||
|  | ||||
|     @given(X=hu.tensor(), | ||||
|            engine=st.sampled_from(["", "CUDNN"]), | ||||
|            **mu.gcs) | ||||
|     def test_relu(self, X, gc, dc, engine): | ||||
|         op = core.CreateOperator("Relu", ["X"], ["Y"], engine=engine) | ||||
|         # go away from the origin point to avoid kink problems | ||||
|         X += 0.02 * np.sign(X) | ||||
|         X[X == 0.0] += 0.02 | ||||
|         self.assertDeviceChecks(dc, op, [X], [0]) | ||||
|         self.assertGradientChecks(gc, op, [X], 0, [0]) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
		Reference in New Issue
	
	Block a user