mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
backward support to cudnn R2 for TensorFlow benchmark references
This commit is contained in:
@ -13,8 +13,8 @@
|
|||||||
#include "caffe2/proto/caffe2.pb.h"
|
#include "caffe2/proto/caffe2.pb.h"
|
||||||
#include "caffe2/core/logging.h"
|
#include "caffe2/core/logging.h"
|
||||||
|
|
||||||
static_assert(CUDNN_VERSION >= 3000,
|
static_assert(CUDNN_VERSION >= 2000,
|
||||||
"Caffe2 requires cudnn version 3.0 or above.");
|
"Caffe2 requires cudnn version 2.0 or above.");
|
||||||
|
|
||||||
namespace caffe2 {
|
namespace caffe2 {
|
||||||
|
|
||||||
@ -78,10 +78,12 @@ template<> class cudnnTypeWrapper<double> {
|
|||||||
static const cudnnDataType_t type = CUDNN_DATA_DOUBLE;
|
static const cudnnDataType_t type = CUDNN_DATA_DOUBLE;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#if CUDNN_VERSION >= 3000
|
||||||
template<> class cudnnTypeWrapper<float16> {
|
template<> class cudnnTypeWrapper<float16> {
|
||||||
public:
|
public:
|
||||||
static const cudnnDataType_t type = CUDNN_DATA_HALF;
|
static const cudnnDataType_t type = CUDNN_DATA_HALF;
|
||||||
};
|
};
|
||||||
|
#endif // CUDNN_VERSION >= 3000
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A wrapper function to convert the Caffe storage order to cudnn storage order
|
* A wrapper function to convert the Caffe storage order to cudnn storage order
|
||||||
|
|||||||
@ -106,9 +106,10 @@ class CudnnConvGradientOp final : public CudnnConvOpBase {
|
|||||||
bool RunWithCudnnWorkspace(CuDNNWorkspaceWrapper* cudnn_ws_wrapper) override;
|
bool RunWithCudnnWorkspace(CuDNNWorkspaceWrapper* cudnn_ws_wrapper) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
#if CUDNN_VERSION >= 3000
|
||||||
cudnnConvolutionBwdFilterAlgo_t bwd_filter_algo_;
|
cudnnConvolutionBwdFilterAlgo_t bwd_filter_algo_;
|
||||||
cudnnConvolutionBwdDataAlgo_t bwd_data_algo_;
|
cudnnConvolutionBwdDataAlgo_t bwd_data_algo_;
|
||||||
|
#endif // CUDNN_VERSION >= 3000
|
||||||
// input: X, W, dY
|
// input: X, W, dY
|
||||||
// output: dW, db, and optionally dX
|
// output: dW, db, and optionally dX
|
||||||
INPUT_TAGS(INPUT, FILTER, OUTPUT_GRAD);
|
INPUT_TAGS(INPUT, FILTER, OUTPUT_GRAD);
|
||||||
@ -216,9 +217,16 @@ bool CudnnConvOp<T>::RunWithCudnnWorkspace(
|
|||||||
algo_, cudnn_ws_wrapper->Get(cudnn_ws_nbytes_), cudnn_ws_nbytes_, &kZero,
|
algo_, cudnn_ws_wrapper->Get(cudnn_ws_nbytes_), cudnn_ws_nbytes_, &kZero,
|
||||||
top_desc_, Y->template mutable_data<T>()));
|
top_desc_, Y->template mutable_data<T>()));
|
||||||
// Bias
|
// Bias
|
||||||
|
#if CUDNN_VERSION >= 3000
|
||||||
CUDNN_CHECK(cudnnAddTensor_v3(
|
CUDNN_CHECK(cudnnAddTensor_v3(
|
||||||
cudnn_wrapper_.cudnn_handle(), &kOne, bias_desc_,
|
cudnn_wrapper_.cudnn_handle(), &kOne, bias_desc_,
|
||||||
bias.template data<T>(), &kOne, top_desc_, Y->template mutable_data<T>()));
|
bias.template data<T>(), &kOne, top_desc_, Y->template mutable_data<T>()));
|
||||||
|
#else // CUDNN_VERSION >= 3000
|
||||||
|
CUDNN_CHECK(cudnnAddTensor(
|
||||||
|
cudnn_wrapper_.cudnn_handle(), CUDNN_ADD_SAME_C,
|
||||||
|
&kOne, bias_desc_, bias.template data<T>(), &kOne, top_desc_,
|
||||||
|
Y->template mutable_data<T>()));
|
||||||
|
#endif
|
||||||
// Done.
|
// Done.
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
@ -296,6 +304,7 @@ bool CudnnConvGradientOp<T>::RunWithCudnnWorkspace(
|
|||||||
|
|
||||||
size_t bwd_filter_ws_size, bwd_data_ws_size;
|
size_t bwd_filter_ws_size, bwd_data_ws_size;
|
||||||
|
|
||||||
|
#if CUDNN_VERSION >= 3000
|
||||||
// choose backward algorithm for filter
|
// choose backward algorithm for filter
|
||||||
CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm(
|
CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm(
|
||||||
cudnn_wrapper_.cudnn_handle(),
|
cudnn_wrapper_.cudnn_handle(),
|
||||||
@ -320,6 +329,7 @@ bool CudnnConvGradientOp<T>::RunWithCudnnWorkspace(
|
|||||||
bwd_data_algo_, &bwd_data_ws_size));
|
bwd_data_algo_, &bwd_data_ws_size));
|
||||||
cudnn_ws_nbytes_ = std::max(bwd_filter_ws_size, bwd_data_ws_size);
|
cudnn_ws_nbytes_ = std::max(bwd_filter_ws_size, bwd_data_ws_size);
|
||||||
CAFFE_VLOG(1) << "CuDNN workspace size: " << cudnn_ws_nbytes_;
|
CAFFE_VLOG(1) << "CuDNN workspace size: " << cudnn_ws_nbytes_;
|
||||||
|
#endif // CUDNN_VERSION >= 3000
|
||||||
}
|
}
|
||||||
|
|
||||||
// Now, actually run the computation.
|
// Now, actually run the computation.
|
||||||
@ -328,30 +338,46 @@ bool CudnnConvGradientOp<T>::RunWithCudnnWorkspace(
|
|||||||
CUDNN_CHECK(cudnnConvolutionBackwardBias(
|
CUDNN_CHECK(cudnnConvolutionBackwardBias(
|
||||||
cudnn_wrapper_.cudnn_handle(), &kOne, top_desc_, dY.template data<T>(),
|
cudnn_wrapper_.cudnn_handle(), &kOne, top_desc_, dY.template data<T>(),
|
||||||
&kZero, bias_desc_, dbias->template mutable_data<T>()));
|
&kZero, bias_desc_, dbias->template mutable_data<T>()));
|
||||||
|
#if CUDNN_VERSION >= 3000
|
||||||
CUDNN_CHECK(cudnnConvolutionBackwardFilter_v3(
|
CUDNN_CHECK(cudnnConvolutionBackwardFilter_v3(
|
||||||
cudnn_wrapper_.cudnn_handle(), &kOne, bottom_desc_, X.template data<T>(),
|
cudnn_wrapper_.cudnn_handle(), &kOne, bottom_desc_, X.template data<T>(),
|
||||||
top_desc_, dY.template data<T>(), conv_desc_, bwd_filter_algo_,
|
top_desc_, dY.template data<T>(), conv_desc_, bwd_filter_algo_,
|
||||||
cudnn_ws_wrapper->Get(cudnn_ws_nbytes_), cudnn_ws_nbytes_,
|
cudnn_ws_wrapper->Get(cudnn_ws_nbytes_), cudnn_ws_nbytes_,
|
||||||
&kZero, filter_desc_, dfilter->template mutable_data<T>()));
|
&kZero, filter_desc_, dfilter->template mutable_data<T>()));
|
||||||
|
#else // CUDNN_VERSION >= 3000
|
||||||
|
CUDNN_CHECK(cudnnConvolutionBackwardFilter(
|
||||||
|
cudnn_wrapper_.cudnn_handle(), &kOne, bottom_desc_, X.template data<T>(),
|
||||||
|
top_desc_, dY.template data<T>(), conv_desc_,
|
||||||
|
&kZero, filter_desc_, dfilter->template mutable_data<T>()));
|
||||||
|
#endif // CUDNN_VERSION >= 3000
|
||||||
|
|
||||||
if (OutputSize() == 3) {
|
if (OutputSize() == 3) {
|
||||||
// Compute the gradient w.r.t. the input.
|
// Compute the gradient w.r.t. the input.
|
||||||
auto *dX = Output(INPUT_GRAD);
|
auto *dX = Output(INPUT_GRAD);
|
||||||
dX->ReshapeLike(X);
|
dX->ReshapeLike(X);
|
||||||
|
#if CUDNN_VERSION >= 3000
|
||||||
CUDNN_CHECK(cudnnConvolutionBackwardData_v3(
|
CUDNN_CHECK(cudnnConvolutionBackwardData_v3(
|
||||||
cudnn_wrapper_.cudnn_handle(), &kOne, filter_desc_,
|
cudnn_wrapper_.cudnn_handle(), &kOne, filter_desc_,
|
||||||
filter.template data<T>(), top_desc_, dY.template data<T>(),
|
filter.template data<T>(), top_desc_, dY.template data<T>(),
|
||||||
conv_desc_, bwd_data_algo_,
|
conv_desc_, bwd_data_algo_,
|
||||||
cudnn_ws_wrapper->Get(cudnn_ws_nbytes_), cudnn_ws_nbytes_,
|
cudnn_ws_wrapper->Get(cudnn_ws_nbytes_), cudnn_ws_nbytes_,
|
||||||
&kZero, bottom_desc_, dX->template mutable_data<T>()));
|
&kZero, bottom_desc_, dX->template mutable_data<T>()));
|
||||||
|
#else // CUDNN_VERSION >= 3000
|
||||||
|
CUDNN_CHECK(cudnnConvolutionBackwardData(
|
||||||
|
cudnn_wrapper_.cudnn_handle(), &kOne, filter_desc_,
|
||||||
|
filter.template data<T>(), top_desc_, dY.template data<T>(),
|
||||||
|
conv_desc_, &kZero, bottom_desc_, dX->template mutable_data<T>()));
|
||||||
|
#endif // CUDNN_VERSION >= 3000
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
REGISTER_CUDNN_OPERATOR(Conv, CudnnConvOp<float>)
|
REGISTER_CUDNN_OPERATOR(Conv, CudnnConvOp<float>)
|
||||||
REGISTER_CUDNN_OPERATOR(ConvGradient, CudnnConvGradientOp<float>)
|
REGISTER_CUDNN_OPERATOR(ConvGradient, CudnnConvGradientOp<float>)
|
||||||
|
#if CUDNN_VERSION >= 3000
|
||||||
REGISTER_CUDNN_OPERATOR(ConvFp16, CudnnConvOp<float16>)
|
REGISTER_CUDNN_OPERATOR(ConvFp16, CudnnConvOp<float16>)
|
||||||
REGISTER_CUDNN_OPERATOR(ConvFp16Gradient, CudnnConvGradientOp<float16>)
|
REGISTER_CUDNN_OPERATOR(ConvFp16Gradient, CudnnConvGradientOp<float16>)
|
||||||
|
#endif // CUDNN_VERSION >= 3000
|
||||||
|
|
||||||
|
|
||||||
} // namespace caffe2
|
} // namespace caffe2
|
||||||
|
|||||||
Reference in New Issue
Block a user