mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add backward compatibility with cuDNN R2 for TensorFlow benchmark references
This commit is contained in:
@ -13,8 +13,8 @@
|
||||
#include "caffe2/proto/caffe2.pb.h"
|
||||
#include "caffe2/core/logging.h"
|
||||
|
||||
static_assert(CUDNN_VERSION >= 3000,
|
||||
"Caffe2 requires cudnn version 3.0 or above.");
|
||||
static_assert(CUDNN_VERSION >= 2000,
|
||||
"Caffe2 requires cudnn version 2.0 or above.");
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
@ -78,10 +78,12 @@ template<> class cudnnTypeWrapper<double> {
|
||||
static const cudnnDataType_t type = CUDNN_DATA_DOUBLE;
|
||||
};
|
||||
|
||||
// cudnnTypeWrapper specialization for float16: exposes CUDNN_DATA_HALF as the
// cuDNN data type for half-precision tensors. The specialization is compiled
// only when CUDNN_VERSION >= 3000.
// NOTE(review): presumably CUDNN_DATA_HALF is not defined by older cuDNN
// headers, which is why it is guarded -- confirm against the cuDNN R2 headers.
#if CUDNN_VERSION >= 3000
|
||||
template<> class cudnnTypeWrapper<float16> {
|
||||
public:
|
||||
static const cudnnDataType_t type = CUDNN_DATA_HALF;
|
||||
};
|
||||
#endif // CUDNN_VERSION >= 3000
|
||||
|
||||
/**
|
||||
* A wrapper function to convert the Caffe storage order to cudnn storage order
|
||||
|
@ -106,9 +106,10 @@ class CudnnConvGradientOp final : public CudnnConvOpBase {
|
||||
bool RunWithCudnnWorkspace(CuDNNWorkspaceWrapper* cudnn_ws_wrapper) override;
|
||||
|
||||
private:
|
||||
#if CUDNN_VERSION >= 3000
|
||||
cudnnConvolutionBwdFilterAlgo_t bwd_filter_algo_;
|
||||
cudnnConvolutionBwdDataAlgo_t bwd_data_algo_;
|
||||
|
||||
#endif // CUDNN_VERSION >= 3000
|
||||
// input: X, W, dY
|
||||
// output: dW, db, and optionally dX
|
||||
INPUT_TAGS(INPUT, FILTER, OUTPUT_GRAD);
|
||||
@ -216,9 +217,16 @@ bool CudnnConvOp<T>::RunWithCudnnWorkspace(
|
||||
algo_, cudnn_ws_wrapper->Get(cudnn_ws_nbytes_), cudnn_ws_nbytes_, &kZero,
|
||||
top_desc_, Y->template mutable_data<T>()));
|
||||
// Bias
|
||||
#if CUDNN_VERSION >= 3000
|
||||
CUDNN_CHECK(cudnnAddTensor_v3(
|
||||
cudnn_wrapper_.cudnn_handle(), &kOne, bias_desc_,
|
||||
bias.template data<T>(), &kOne, top_desc_, Y->template mutable_data<T>()));
|
||||
cudnn_wrapper_.cudnn_handle(), &kOne, bias_desc_,
|
||||
bias.template data<T>(), &kOne, top_desc_, Y->template mutable_data<T>()));
|
||||
#else // CUDNN_VERSION >= 3000
|
||||
CUDNN_CHECK(cudnnAddTensor(
|
||||
cudnn_wrapper_.cudnn_handle(), CUDNN_ADD_SAME_C,
|
||||
&kOne, bias_desc_, bias.template data<T>(), &kOne, top_desc_,
|
||||
Y->template mutable_data<T>()));
|
||||
#endif
|
||||
// Done.
|
||||
return true;
|
||||
}
|
||||
@ -296,6 +304,7 @@ bool CudnnConvGradientOp<T>::RunWithCudnnWorkspace(
|
||||
|
||||
size_t bwd_filter_ws_size, bwd_data_ws_size;
|
||||
|
||||
#if CUDNN_VERSION >= 3000
|
||||
// choose backward algorithm for filter
|
||||
CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm(
|
||||
cudnn_wrapper_.cudnn_handle(),
|
||||
@ -320,6 +329,7 @@ bool CudnnConvGradientOp<T>::RunWithCudnnWorkspace(
|
||||
bwd_data_algo_, &bwd_data_ws_size));
|
||||
cudnn_ws_nbytes_ = std::max(bwd_filter_ws_size, bwd_data_ws_size);
|
||||
CAFFE_VLOG(1) << "CuDNN workspace size: " << cudnn_ws_nbytes_;
|
||||
#endif // CUDNN_VERSION >= 3000
|
||||
}
|
||||
|
||||
// Now, actually run the computation.
|
||||
@ -328,30 +338,46 @@ bool CudnnConvGradientOp<T>::RunWithCudnnWorkspace(
|
||||
CUDNN_CHECK(cudnnConvolutionBackwardBias(
|
||||
cudnn_wrapper_.cudnn_handle(), &kOne, top_desc_, dY.template data<T>(),
|
||||
&kZero, bias_desc_, dbias->template mutable_data<T>()));
|
||||
#if CUDNN_VERSION >= 3000
|
||||
CUDNN_CHECK(cudnnConvolutionBackwardFilter_v3(
|
||||
cudnn_wrapper_.cudnn_handle(), &kOne, bottom_desc_, X.template data<T>(),
|
||||
top_desc_, dY.template data<T>(), conv_desc_, bwd_filter_algo_,
|
||||
cudnn_ws_wrapper->Get(cudnn_ws_nbytes_), cudnn_ws_nbytes_,
|
||||
&kZero, filter_desc_, dfilter->template mutable_data<T>()));
|
||||
#else // CUDNN_VERSION >= 3000
|
||||
CUDNN_CHECK(cudnnConvolutionBackwardFilter(
|
||||
cudnn_wrapper_.cudnn_handle(), &kOne, bottom_desc_, X.template data<T>(),
|
||||
top_desc_, dY.template data<T>(), conv_desc_,
|
||||
&kZero, filter_desc_, dfilter->template mutable_data<T>()));
|
||||
#endif // CUDNN_VERSION >= 3000
|
||||
|
||||
if (OutputSize() == 3) {
|
||||
// Compute the gradient w.r.t. the input.
|
||||
auto *dX = Output(INPUT_GRAD);
|
||||
dX->ReshapeLike(X);
|
||||
#if CUDNN_VERSION >= 3000
|
||||
CUDNN_CHECK(cudnnConvolutionBackwardData_v3(
|
||||
cudnn_wrapper_.cudnn_handle(), &kOne, filter_desc_,
|
||||
filter.template data<T>(), top_desc_, dY.template data<T>(),
|
||||
conv_desc_, bwd_data_algo_,
|
||||
cudnn_ws_wrapper->Get(cudnn_ws_nbytes_), cudnn_ws_nbytes_,
|
||||
&kZero, bottom_desc_, dX->template mutable_data<T>()));
|
||||
#else // CUDNN_VERSION >= 3000
|
||||
CUDNN_CHECK(cudnnConvolutionBackwardData(
|
||||
cudnn_wrapper_.cudnn_handle(), &kOne, filter_desc_,
|
||||
filter.template data<T>(), top_desc_, dY.template data<T>(),
|
||||
conv_desc_, &kZero, bottom_desc_, dX->template mutable_data<T>()));
|
||||
#endif // CUDNN_VERSION >= 3000
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Operator registrations. The float instantiations are registered
// unconditionally under the names Conv and ConvGradient.
REGISTER_CUDNN_OPERATOR(Conv, CudnnConvOp<float>)
|
||||
REGISTER_CUDNN_OPERATOR(ConvGradient, CudnnConvGradientOp<float>)
|
||||
// The float16 instantiations (ConvFp16 / ConvFp16Gradient) are registered only
// when building against CUDNN_VERSION >= 3000, so on older cuDNN versions
// these operator names are not registered at all.
#if CUDNN_VERSION >= 3000
|
||||
REGISTER_CUDNN_OPERATOR(ConvFp16, CudnnConvOp<float16>)
|
||||
REGISTER_CUDNN_OPERATOR(ConvFp16Gradient, CudnnConvGradientOp<float16>)
|
||||
#endif // CUDNN_VERSION >= 3000
|
||||
|
||||
|
||||
} // namespace caffe2
|
||||
|
Reference in New Issue
Block a user