mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
exhaustive search for cudnn
This commit is contained in:
@ -265,6 +265,12 @@ class CuDNNWorkspaceWrapper {
|
||||
|
||||
std::mutex& mutex() { return mutex_; }
|
||||
|
||||
void Reset() {
|
||||
if (data_) CUDAContext::Delete(data_);
|
||||
data_ = nullptr;
|
||||
nbytes_ = 0;
|
||||
}
|
||||
|
||||
void* Get(const size_t nbytes) {
|
||||
if (nbytes > nbytes_) {
|
||||
if (data_) CUDAContext::Delete(data_);
|
||||
|
@ -15,7 +15,9 @@ class CudnnConvOpBase : public ConvPoolOpBase<CUDAContext> {
|
||||
OperatorBase::GetSingleArgument<int>(
|
||||
"ws_nbytes_limit", kCONV_CUDNN_WORKSPACE_LIMIT_BYTES)),
|
||||
shared_ws_name_(
|
||||
OperatorBase::GetSingleArgument<string>("shared_ws_name", "")) {
|
||||
OperatorBase::GetSingleArgument<string>("shared_ws_name", "")),
|
||||
exhaustive_search_(
|
||||
OperatorBase::GetSingleArgument<int>("exhaustive_search", 0)) {
|
||||
CUDNN_CHECK(cudnnCreateTensorDescriptor(&bottom_desc_));
|
||||
CUDNN_CHECK(cudnnCreateFilterDescriptor(&filter_desc_));
|
||||
CUDNN_CHECK(cudnnCreateTensorDescriptor(&bias_desc_));
|
||||
@ -72,6 +74,7 @@ class CudnnConvOpBase : public ConvPoolOpBase<CUDAContext> {
|
||||
string shared_ws_name_;
|
||||
size_t cudnn_ws_nbytes_;
|
||||
CuDNNWorkspaceWrapper* cudnn_ws_;
|
||||
bool exhaustive_search_;
|
||||
std::unique_ptr<CuDNNWorkspaceWrapper> local_cudnn_ws_;
|
||||
DISABLE_COPY_AND_ASSIGN(CudnnConvOpBase);
|
||||
};
|
||||
@ -192,13 +195,31 @@ bool CudnnConvOp<T>::RunWithCudnnWorkspace(
|
||||
CUDNN_CHECK(cudnnSetConvolution2dDescriptor(
|
||||
conv_desc_, pad_t_, pad_l_, stride_h_, stride_w_, 1, 1,
|
||||
CUDNN_CROSS_CORRELATION));
|
||||
// Set the workspace
|
||||
CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm(
|
||||
cudnn_wrapper_.cudnn_handle(),
|
||||
bottom_desc_, filter_desc_, conv_desc_, top_desc_,
|
||||
CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
|
||||
cudnn_ws_nbytes_limit_,
|
||||
&algo_));
|
||||
if (exhaustive_search_) {
|
||||
// When we do an exhaustive search, we will ignore the workspace size
|
||||
// limit and simply go for the fastest algorithm. If you happen to run
|
||||
// out of memory later, you will be on your own...
|
||||
int returned_algo_count;
|
||||
cudnnConvolutionFwdAlgoPerf_t perf_stat;
|
||||
// We clean up the current workspace memory so that the forward algorithm
|
||||
// is free to allocate memory.
|
||||
cudnn_ws_wrapper->Reset();
|
||||
// Actually run the search.
|
||||
CUDNN_CHECK(cudnnFindConvolutionForwardAlgorithm(
|
||||
cudnn_wrapper_.cudnn_handle(),
|
||||
bottom_desc_, filter_desc_, conv_desc_, top_desc_,
|
||||
1, &returned_algo_count, &perf_stat));
|
||||
CAFFE_DCHECK_EQ(returned_algo_count, 1);
|
||||
algo_ = perf_stat.algo;
|
||||
} else {
|
||||
// Get the convolution algorithm based on the workspace limit.
|
||||
CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm(
|
||||
cudnn_wrapper_.cudnn_handle(),
|
||||
bottom_desc_, filter_desc_, conv_desc_, top_desc_,
|
||||
CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
|
||||
cudnn_ws_nbytes_limit_,
|
||||
&algo_));
|
||||
}
|
||||
CUDNN_CHECK(cudnnGetConvolutionForwardWorkspaceSize(
|
||||
cudnn_wrapper_.cudnn_handle(),
|
||||
bottom_desc_, filter_desc_, conv_desc_, top_desc_,
|
||||
|
Reference in New Issue
Block a user