mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
exhaustive search for cudnn
This commit is contained in:
@ -265,6 +265,12 @@ class CuDNNWorkspaceWrapper {
|
|||||||
|
|
||||||
std::mutex& mutex() { return mutex_; }
|
std::mutex& mutex() { return mutex_; }
|
||||||
|
|
||||||
|
void Reset() {
|
||||||
|
if (data_) CUDAContext::Delete(data_);
|
||||||
|
data_ = nullptr;
|
||||||
|
nbytes_ = 0;
|
||||||
|
}
|
||||||
|
|
||||||
void* Get(const size_t nbytes) {
|
void* Get(const size_t nbytes) {
|
||||||
if (nbytes > nbytes_) {
|
if (nbytes > nbytes_) {
|
||||||
if (data_) CUDAContext::Delete(data_);
|
if (data_) CUDAContext::Delete(data_);
|
||||||
|
@ -15,7 +15,9 @@ class CudnnConvOpBase : public ConvPoolOpBase<CUDAContext> {
|
|||||||
OperatorBase::GetSingleArgument<int>(
|
OperatorBase::GetSingleArgument<int>(
|
||||||
"ws_nbytes_limit", kCONV_CUDNN_WORKSPACE_LIMIT_BYTES)),
|
"ws_nbytes_limit", kCONV_CUDNN_WORKSPACE_LIMIT_BYTES)),
|
||||||
shared_ws_name_(
|
shared_ws_name_(
|
||||||
OperatorBase::GetSingleArgument<string>("shared_ws_name", "")) {
|
OperatorBase::GetSingleArgument<string>("shared_ws_name", "")),
|
||||||
|
exhaustive_search_(
|
||||||
|
OperatorBase::GetSingleArgument<int>("exhaustive_search", 0)) {
|
||||||
CUDNN_CHECK(cudnnCreateTensorDescriptor(&bottom_desc_));
|
CUDNN_CHECK(cudnnCreateTensorDescriptor(&bottom_desc_));
|
||||||
CUDNN_CHECK(cudnnCreateFilterDescriptor(&filter_desc_));
|
CUDNN_CHECK(cudnnCreateFilterDescriptor(&filter_desc_));
|
||||||
CUDNN_CHECK(cudnnCreateTensorDescriptor(&bias_desc_));
|
CUDNN_CHECK(cudnnCreateTensorDescriptor(&bias_desc_));
|
||||||
@ -72,6 +74,7 @@ class CudnnConvOpBase : public ConvPoolOpBase<CUDAContext> {
|
|||||||
string shared_ws_name_;
|
string shared_ws_name_;
|
||||||
size_t cudnn_ws_nbytes_;
|
size_t cudnn_ws_nbytes_;
|
||||||
CuDNNWorkspaceWrapper* cudnn_ws_;
|
CuDNNWorkspaceWrapper* cudnn_ws_;
|
||||||
|
bool exhaustive_search_;
|
||||||
std::unique_ptr<CuDNNWorkspaceWrapper> local_cudnn_ws_;
|
std::unique_ptr<CuDNNWorkspaceWrapper> local_cudnn_ws_;
|
||||||
DISABLE_COPY_AND_ASSIGN(CudnnConvOpBase);
|
DISABLE_COPY_AND_ASSIGN(CudnnConvOpBase);
|
||||||
};
|
};
|
||||||
@ -192,13 +195,31 @@ bool CudnnConvOp<T>::RunWithCudnnWorkspace(
|
|||||||
CUDNN_CHECK(cudnnSetConvolution2dDescriptor(
|
CUDNN_CHECK(cudnnSetConvolution2dDescriptor(
|
||||||
conv_desc_, pad_t_, pad_l_, stride_h_, stride_w_, 1, 1,
|
conv_desc_, pad_t_, pad_l_, stride_h_, stride_w_, 1, 1,
|
||||||
CUDNN_CROSS_CORRELATION));
|
CUDNN_CROSS_CORRELATION));
|
||||||
// Set the workspace
|
if (exhaustive_search_) {
|
||||||
CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm(
|
// When we do an exhaustive search, we will ignore the workspace size
|
||||||
cudnn_wrapper_.cudnn_handle(),
|
// limit and simply go for the fastest algorithm. If you happen to run
|
||||||
bottom_desc_, filter_desc_, conv_desc_, top_desc_,
|
// out of memory later, you will be on your own...
|
||||||
CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
|
int returned_algo_count;
|
||||||
cudnn_ws_nbytes_limit_,
|
cudnnConvolutionFwdAlgoPerf_t perf_stat;
|
||||||
&algo_));
|
// We clean up the current workspace memory so that the forward algorithm
|
||||||
|
// is free to allocate memory.
|
||||||
|
cudnn_ws_wrapper->Reset();
|
||||||
|
// Actually run the search.
|
||||||
|
CUDNN_CHECK(cudnnFindConvolutionForwardAlgorithm(
|
||||||
|
cudnn_wrapper_.cudnn_handle(),
|
||||||
|
bottom_desc_, filter_desc_, conv_desc_, top_desc_,
|
||||||
|
1, &returned_algo_count, &perf_stat));
|
||||||
|
CAFFE_DCHECK_EQ(returned_algo_count, 1);
|
||||||
|
algo_ = perf_stat.algo;
|
||||||
|
} else {
|
||||||
|
// Get the convolution algorithm based on the workspace limit.
|
||||||
|
CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm(
|
||||||
|
cudnn_wrapper_.cudnn_handle(),
|
||||||
|
bottom_desc_, filter_desc_, conv_desc_, top_desc_,
|
||||||
|
CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
|
||||||
|
cudnn_ws_nbytes_limit_,
|
||||||
|
&algo_));
|
||||||
|
}
|
||||||
CUDNN_CHECK(cudnnGetConvolutionForwardWorkspaceSize(
|
CUDNN_CHECK(cudnnGetConvolutionForwardWorkspaceSize(
|
||||||
cudnn_wrapper_.cudnn_handle(),
|
cudnn_wrapper_.cudnn_handle(),
|
||||||
bottom_desc_, filter_desc_, conv_desc_, top_desc_,
|
bottom_desc_, filter_desc_, conv_desc_, top_desc_,
|
||||||
|
@ -205,7 +205,8 @@ def Benchmark(model_gen, order, batch_size, cudnn_limit, forward_only,
|
|||||||
for op in model.net._net.op:
|
for op in model.net._net.op:
|
||||||
if op.type == 'Conv':
|
if op.type == 'Conv':
|
||||||
op.engine = 'CUDNN'
|
op.engine = 'CUDNN'
|
||||||
op.arg.add().CopyFrom(utils.MakeArgument('ws_nbytes_limit', cudnn_limit))
|
#op.arg.add().CopyFrom(utils.MakeArgument('ws_nbytes_limit', cudnn_limit))
|
||||||
|
op.arg.add().CopyFrom(utils.MakeArgument('exhaustive_search', 1))
|
||||||
op.arg.add().CopyFrom(utils.MakeArgument('shared_ws_name', 'cudnn_workspace'))
|
op.arg.add().CopyFrom(utils.MakeArgument('shared_ws_name', 'cudnn_workspace'))
|
||||||
elif op.type in ['Relu', 'Softmax']:
|
elif op.type in ['Relu', 'Softmax']:
|
||||||
op.engine = 'CUDNN'
|
op.engine = 'CUDNN'
|
||||||
|
Reference in New Issue
Block a user