exhaustive search for cudnn

This commit is contained in:
Yangqing Jia
2015-12-15 22:20:21 -08:00
parent 61c114971b
commit d79cfb4ae7
3 changed files with 37 additions and 9 deletions

View File

@ -265,6 +265,12 @@ class CuDNNWorkspaceWrapper {
// Accessor for the wrapper's mutex_ member, returned by mutable reference so
// the caller can lock it directly.
// NOTE(review): presumably callers hold this lock while using the workspace
// buffer -- confirm against the call sites.
std::mutex& mutex() {
  return mutex_;
}
// Drops the currently held workspace buffer, if there is one, by handing it
// to CUDAContext::Delete, then marks the wrapper as empty (null pointer and
// zero recorded size).
void Reset() {
  if (data_ != nullptr) {
    CUDAContext::Delete(data_);
  }
  // Clear the bookkeeping so the wrapper no longer refers to the old buffer.
  nbytes_ = 0;
  data_ = nullptr;
}
void* Get(const size_t nbytes) {
if (nbytes > nbytes_) {
if (data_) CUDAContext::Delete(data_);

View File

@ -15,7 +15,9 @@ class CudnnConvOpBase : public ConvPoolOpBase<CUDAContext> {
OperatorBase::GetSingleArgument<int>(
"ws_nbytes_limit", kCONV_CUDNN_WORKSPACE_LIMIT_BYTES)),
shared_ws_name_(
OperatorBase::GetSingleArgument<string>("shared_ws_name", "")) {
OperatorBase::GetSingleArgument<string>("shared_ws_name", "")),
exhaustive_search_(
OperatorBase::GetSingleArgument<int>("exhaustive_search", 0)) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&bottom_desc_));
CUDNN_CHECK(cudnnCreateFilterDescriptor(&filter_desc_));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&bias_desc_));
@ -72,6 +74,7 @@ class CudnnConvOpBase : public ConvPoolOpBase<CUDAContext> {
string shared_ws_name_;
size_t cudnn_ws_nbytes_;
CuDNNWorkspaceWrapper* cudnn_ws_;
bool exhaustive_search_;
std::unique_ptr<CuDNNWorkspaceWrapper> local_cudnn_ws_;
DISABLE_COPY_AND_ASSIGN(CudnnConvOpBase);
};
@ -192,13 +195,31 @@ bool CudnnConvOp<T>::RunWithCudnnWorkspace(
CUDNN_CHECK(cudnnSetConvolution2dDescriptor(
conv_desc_, pad_t_, pad_l_, stride_h_, stride_w_, 1, 1,
CUDNN_CROSS_CORRELATION));
// Set the workspace
CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm(
cudnn_wrapper_.cudnn_handle(),
bottom_desc_, filter_desc_, conv_desc_, top_desc_,
CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
cudnn_ws_nbytes_limit_,
&algo_));
if (exhaustive_search_) {
// When we do an exhaustive search, we will ignore the workspace size
// limit and simply go for the fastest algorithm. If you happen to run
// out of memory later, you will be on your own...
int returned_algo_count;
cudnnConvolutionFwdAlgoPerf_t perf_stat;
// We clean up the current workspace memory so that the forward algorithm
// is free to allocate memory.
cudnn_ws_wrapper->Reset();
// Actually run the search.
CUDNN_CHECK(cudnnFindConvolutionForwardAlgorithm(
cudnn_wrapper_.cudnn_handle(),
bottom_desc_, filter_desc_, conv_desc_, top_desc_,
1, &returned_algo_count, &perf_stat));
CAFFE_DCHECK_EQ(returned_algo_count, 1);
algo_ = perf_stat.algo;
} else {
// Get the convolution algorithm based on the workspace limit.
CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm(
cudnn_wrapper_.cudnn_handle(),
bottom_desc_, filter_desc_, conv_desc_, top_desc_,
CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
cudnn_ws_nbytes_limit_,
&algo_));
}
CUDNN_CHECK(cudnnGetConvolutionForwardWorkspaceSize(
cudnn_wrapper_.cudnn_handle(),
bottom_desc_, filter_desc_, conv_desc_, top_desc_,