Add backward support to cudnn R2 for TensorFlow benchmark references

Yangqing Jia
2015-12-02 15:12:04 -08:00
parent acc16645d3
commit 01b45fd052
2 changed files with 33 additions and 5 deletions


@@ -13,8 +13,8 @@
#include "caffe2/proto/caffe2.pb.h"
#include "caffe2/core/logging.h"
static_assert(CUDNN_VERSION >= 3000,
"Caffe2 requires cudnn version 3.0 or above.");
static_assert(CUDNN_VERSION >= 2000,
"Caffe2 requires cudnn version 2.0 or above.");
namespace caffe2 {
@@ -78,10 +78,12 @@ template<> class cudnnTypeWrapper<double> {
static const cudnnDataType_t type = CUDNN_DATA_DOUBLE;
};
#if CUDNN_VERSION >= 3000
template<> class cudnnTypeWrapper<float16> {
public:
static const cudnnDataType_t type = CUDNN_DATA_HALF;
};
#endif // CUDNN_VERSION >= 3000
/**
* A wrapper function to convert the Caffe storage order to cudnn storage order
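
The header change above relaxes the compile-time requirement from cuDNN R3 to R2 and fences the float16 trait behind the version macro, since CUDNN_DATA_HALF only exists from R3 onward. A minimal sketch of the same gating pattern, assuming cudnn.h defines CUDNN_VERSION as 2000 for R2 and 3000 for R3, and using a placeholder float16 struct rather than Caffe2's real type:

#include <cudnn.h>

// Placeholder for caffe2::float16 (assumption: an opaque 16-bit storage type).
struct float16 { unsigned short x; };

template <typename T> class cudnnTypeWrapper;

template <> class cudnnTypeWrapper<float> {
 public:
  static const cudnnDataType_t type = CUDNN_DATA_FLOAT;
};

template <> class cudnnTypeWrapper<double> {
 public:
  static const cudnnDataType_t type = CUDNN_DATA_DOUBLE;
};

#if CUDNN_VERSION >= 3000
// CUDNN_DATA_HALF is only declared by the R3 headers, so this specialization
// disappears entirely when building against R2.
template <> class cudnnTypeWrapper<float16> {
 public:
  static const cudnnDataType_t type = CUDNN_DATA_HALF;
};
#endif  // CUDNN_VERSION >= 3000

Code that sets descriptors can then request cudnnTypeWrapper<T>::type, and a template instantiated for float16 simply fails to compile on R2 instead of failing at runtime.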


@@ -106,9 +106,10 @@ class CudnnConvGradientOp final : public CudnnConvOpBase {
bool RunWithCudnnWorkspace(CuDNNWorkspaceWrapper* cudnn_ws_wrapper) override;
private:
#if CUDNN_VERSION >= 3000
cudnnConvolutionBwdFilterAlgo_t bwd_filter_algo_;
cudnnConvolutionBwdDataAlgo_t bwd_data_algo_;
#endif // CUDNN_VERSION >= 3000
// input: X, W, dY
// output: dW, db, and optionally dX
INPUT_TAGS(INPUT, FILTER, OUTPUT_GRAD);
@@ -216,9 +217,16 @@ bool CudnnConvOp<T>::RunWithCudnnWorkspace(
algo_, cudnn_ws_wrapper->Get(cudnn_ws_nbytes_), cudnn_ws_nbytes_, &kZero,
top_desc_, Y->template mutable_data<T>()));
// Bias
#if CUDNN_VERSION >= 3000
CUDNN_CHECK(cudnnAddTensor_v3(
cudnn_wrapper_.cudnn_handle(), &kOne, bias_desc_,
bias.template data<T>(), &kOne, top_desc_, Y->template mutable_data<T>()));
#else // CUDNN_VERSION >= 3000
CUDNN_CHECK(cudnnAddTensor(
cudnn_wrapper_.cudnn_handle(), CUDNN_ADD_SAME_C,
&kOne, bias_desc_, bias.template data<T>(), &kOne, top_desc_,
Y->template mutable_data<T>()));
#endif
// Done.
return true;
}
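
cuDNN R3 renamed in-place tensor addition to cudnnAddTensor_v3 and dropped the explicit add mode, so the bias term needs the version-dependent call shown above. A hedged sketch that isolates just the dispatch (AddBias is a hypothetical helper, not Caffe2 API; float scaling factors and already-configured descriptors are assumed):

#include <cudnn.h>

// Adds the bias described by bias_desc onto the output tensor in place:
// top = 1 * bias (broadcast) + 1 * top.
inline cudnnStatus_t AddBias(
    cudnnHandle_t handle,
    cudnnTensorDescriptor_t bias_desc, const void* bias_data,
    cudnnTensorDescriptor_t top_desc, void* top_data) {
  const float kOne = 1.0f;
#if CUDNN_VERSION >= 3000
  // R3 signature: no add mode; the broadcast is inferred from the descriptors.
  return cudnnAddTensor_v3(handle, &kOne, bias_desc, bias_data,
                           &kOne, top_desc, top_data);
#else
  // R2 signature: the broadcast pattern is chosen explicitly; SAME_C
  // replicates a 1xCx1x1 bias across N, H and W of an NCHW output.
  return cudnnAddTensor(handle, CUDNN_ADD_SAME_C, &kOne, bias_desc, bias_data,
                        &kOne, top_desc, top_data);
#endif
}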
@@ -296,6 +304,7 @@ bool CudnnConvGradientOp<T>::RunWithCudnnWorkspace(
size_t bwd_filter_ws_size, bwd_data_ws_size;
#if CUDNN_VERSION >= 3000
// choose backward algorithm for filter
CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm(
cudnn_wrapper_.cudnn_handle(),
@@ -320,6 +329,7 @@ bool CudnnConvGradientOp<T>::RunWithCudnnWorkspace(
bwd_data_algo_, &bwd_data_ws_size));
cudnn_ws_nbytes_ = std::max(bwd_filter_ws_size, bwd_data_ws_size);
CAFFE_VLOG(1) << "CuDNN workspace size: " << cudnn_ws_nbytes_;
#endif // CUDNN_VERSION >= 3000
}
// Now, actually run the computation.
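
Explicit algorithm selection and workspace sizing only exist in the R3 API, which is why the query block above is compiled out on R2; the legacy backward calls manage their own scratch memory. A sketch of the R3-only sequence in isolation, assuming already-configured descriptors; the PREFER_FASTEST preference and zero memory limit are assumptions, since those arguments are not visible in the diff:

#include <algorithm>
#include <cudnn.h>

#if CUDNN_VERSION >= 3000
// Queries cuDNN for backward filter/data algorithms and the scratch space
// they need; the shared workspace is sized for the larger of the two.
inline cudnnStatus_t ChooseBackwardAlgorithms(
    cudnnHandle_t handle,
    cudnnTensorDescriptor_t bottom_desc,
    cudnnFilterDescriptor_t filter_desc,
    cudnnConvolutionDescriptor_t conv_desc,
    cudnnTensorDescriptor_t top_desc,
    cudnnConvolutionBwdFilterAlgo_t* bwd_filter_algo,
    cudnnConvolutionBwdDataAlgo_t* bwd_data_algo,
    size_t* ws_nbytes) {
  size_t filter_ws = 0, data_ws = 0;
  cudnnStatus_t status = cudnnGetConvolutionBackwardFilterAlgorithm(
      handle, bottom_desc, top_desc, conv_desc, filter_desc,
      CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST, 0, bwd_filter_algo);
  if (status != CUDNN_STATUS_SUCCESS) return status;
  status = cudnnGetConvolutionBackwardFilterWorkspaceSize(
      handle, bottom_desc, top_desc, conv_desc, filter_desc,
      *bwd_filter_algo, &filter_ws);
  if (status != CUDNN_STATUS_SUCCESS) return status;
  status = cudnnGetConvolutionBackwardDataAlgorithm(
      handle, filter_desc, top_desc, conv_desc, bottom_desc,
      CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST, 0, bwd_data_algo);
  if (status != CUDNN_STATUS_SUCCESS) return status;
  status = cudnnGetConvolutionBackwardDataWorkspaceSize(
      handle, filter_desc, top_desc, conv_desc, bottom_desc,
      *bwd_data_algo, &data_ws);
  if (status != CUDNN_STATUS_SUCCESS) return status;
  *ws_nbytes = std::max(filter_ws, data_ws);
  return CUDNN_STATUS_SUCCESS;
}
#endif  // CUDNN_VERSION >= 3000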
@@ -328,30 +338,46 @@ bool CudnnConvGradientOp<T>::RunWithCudnnWorkspace(
CUDNN_CHECK(cudnnConvolutionBackwardBias(
cudnn_wrapper_.cudnn_handle(), &kOne, top_desc_, dY.template data<T>(),
&kZero, bias_desc_, dbias->template mutable_data<T>()));
#if CUDNN_VERSION >= 3000
CUDNN_CHECK(cudnnConvolutionBackwardFilter_v3(
cudnn_wrapper_.cudnn_handle(), &kOne, bottom_desc_, X.template data<T>(),
top_desc_, dY.template data<T>(), conv_desc_, bwd_filter_algo_,
cudnn_ws_wrapper->Get(cudnn_ws_nbytes_), cudnn_ws_nbytes_,
&kZero, filter_desc_, dfilter->template mutable_data<T>()));
#else // CUDNN_VERSION >= 3000
CUDNN_CHECK(cudnnConvolutionBackwardFilter(
cudnn_wrapper_.cudnn_handle(), &kOne, bottom_desc_, X.template data<T>(),
top_desc_, dY.template data<T>(), conv_desc_,
&kZero, filter_desc_, dfilter->template mutable_data<T>()));
#endif // CUDNN_VERSION >= 3000
if (OutputSize() == 3) {
// Compute the gradient w.r.t. the input.
auto *dX = Output(INPUT_GRAD);
dX->ReshapeLike(X);
#if CUDNN_VERSION >= 3000
CUDNN_CHECK(cudnnConvolutionBackwardData_v3(
cudnn_wrapper_.cudnn_handle(), &kOne, filter_desc_,
filter.template data<T>(), top_desc_, dY.template data<T>(),
conv_desc_, bwd_data_algo_,
cudnn_ws_wrapper->Get(cudnn_ws_nbytes_), cudnn_ws_nbytes_,
&kZero, bottom_desc_, dX->template mutable_data<T>()));
#else // CUDNN_VERSION >= 3000
CUDNN_CHECK(cudnnConvolutionBackwardData(
cudnn_wrapper_.cudnn_handle(), &kOne, filter_desc_,
filter.template data<T>(), top_desc_, dY.template data<T>(),
conv_desc_, &kZero, bottom_desc_, dX->template mutable_data<T>()));
#endif // CUDNN_VERSION >= 3000
}
return true;
}
REGISTER_CUDNN_OPERATOR(Conv, CudnnConvOp<float>)
REGISTER_CUDNN_OPERATOR(ConvGradient, CudnnConvGradientOp<float>)
#if CUDNN_VERSION >= 3000
REGISTER_CUDNN_OPERATOR(ConvFp16, CudnnConvOp<float16>)
REGISTER_CUDNN_OPERATOR(ConvFp16Gradient, CudnnConvGradientOp<float16>)
#endif // CUDNN_VERSION >= 3000
} // namespace caffe2
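
The backward filter and data paths follow the same pattern as the bias addition: the R3 _v3 entry points take an explicit algorithm plus a caller-provided workspace, while the R2 functions take neither, and the float16 operator registrations at the end are likewise compiled out on R2. A hedged sketch of the data-gradient dispatch on its own (ConvBackwardData is a hypothetical helper, not Caffe2 API; float scaling factors assumed):

#include <cudnn.h>

#if CUDNN_VERSION >= 3000
// R3: the caller picks the algorithm and provides an external workspace.
inline cudnnStatus_t ConvBackwardData(
    cudnnHandle_t handle,
    cudnnFilterDescriptor_t filter_desc, const void* filter_data,
    cudnnTensorDescriptor_t top_desc, const void* dY_data,
    cudnnConvolutionDescriptor_t conv_desc,
    cudnnConvolutionBwdDataAlgo_t bwd_data_algo,
    void* workspace, size_t workspace_nbytes,
    cudnnTensorDescriptor_t bottom_desc, void* dX_data) {
  const float kOne = 1.0f, kZero = 0.0f;
  return cudnnConvolutionBackwardData_v3(
      handle, &kOne, filter_desc, filter_data, top_desc, dY_data, conv_desc,
      bwd_data_algo, workspace, workspace_nbytes,
      &kZero, bottom_desc, dX_data);
}
#else
// R2: no algorithm choice and no external workspace; cuDNN manages its own
// scratch memory internally.
inline cudnnStatus_t ConvBackwardData(
    cudnnHandle_t handle,
    cudnnFilterDescriptor_t filter_desc, const void* filter_data,
    cudnnTensorDescriptor_t top_desc, const void* dY_data,
    cudnnConvolutionDescriptor_t conv_desc,
    cudnnTensorDescriptor_t bottom_desc, void* dX_data) {
  const float kOne = 1.0f, kZero = 0.0f;
  return cudnnConvolutionBackwardData(
      handle, &kOne, filter_desc, filter_data, top_desc, dY_data, conv_desc,
      &kZero, bottom_desc, dX_data);
}
#endif  // CUDNN_VERSION >= 3000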