[caffe2] Use extended versions of cuDNN calls for SpatialBN

Summary: Use `cudnnBatchNormalizationForwardTrainingEx` and `cudnnBatchNormalizationBackwardEx` when the cuDNN version is at least 8.0.0.

Reviewed By: xw285cornell

Differential Revision: D26794173

fbshipit-source-id: dc4994375350f303a3fa0aee03255e8f8be1c605
Author: Valentin Andrei
Date: 2021-03-05 18:15:56 -08:00
Committed by: Facebook GitHub Bot
Parent: 758fb94fcb
Commit: 369601355f

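The hunks below guard the new path with `CUDNN_VERSION_MIN(8, 0, 0)`: they query the workspace and reserve-space sizes, then issue the extended forward/backward calls through the operator's cuDNN state wrapper. For context, here is a minimal standalone sketch of the same forward-training pattern against the raw cuDNN API. It is not part of the commit; the function name `bn_forward_training_ex`, the single descriptor shared by input and output, float data, `CUDNN_BATCHNORM_SPATIAL` mode, and the ad-hoc `cudaMalloc` scratch buffers are illustrative assumptions.

```cpp
#include <cuda_runtime.h>
#include <cudnn.h>

#define CHECK_CUDNN(expr)                      \
  do {                                         \
    cudnnStatus_t _s = (expr);                 \
    if (_s != CUDNN_STATUS_SUCCESS) return _s; \
  } while (0)

// All descriptors and device pointers are assumed to be created and filled by
// the caller; x_desc describes both the input and the output tensor.
cudnnStatus_t bn_forward_training_ex(
    cudnnHandle_t handle,
    cudnnTensorDescriptor_t x_desc,
    cudnnTensorDescriptor_t param_desc,  // scale/bias/mean/var descriptor
    const float* x, float* y,
    const float* scale, const float* bias,
    float* running_mean, float* running_var,
    float* saved_mean, float* saved_inv_var,
    double momentum, double epsilon) {
#if CUDNN_VERSION >= 8000
  const cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL;
  const cudnnBatchNormOps_t op = CUDNN_BATCHNORM_OPS_BN;  // no fused add/activation

  // 1) Workspace size for the extended forward-training call.
  size_t ws_size = 0;
  CHECK_CUDNN(cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(
      handle, mode, op, x_desc, /*zDesc=*/nullptr, x_desc, param_desc,
      /*activationDesc=*/nullptr, &ws_size));

  // 2) Reserve-space size; the same query serves forward and backward.
  size_t rs_size = 0;
  CHECK_CUDNN(cudnnGetBatchNormalizationTrainingExReserveSpaceSize(
      handle, mode, op, /*activationDesc=*/nullptr, x_desc, &rs_size));

  void* ws = nullptr;
  void* rs = nullptr;
  if (ws_size > 0) cudaMalloc(&ws, ws_size);
  if (rs_size > 0) cudaMalloc(&rs, rs_size);

  // 3) Extended forward-training call; alpha/beta blend y as 1*result + 0*y.
  const float one = 1.0f, zero = 0.0f;
  const cudnnStatus_t status = cudnnBatchNormalizationForwardTrainingEx(
      handle, mode, op, &one, &zero,
      x_desc, x, /*zDesc=*/nullptr, /*zData=*/nullptr, x_desc, y,
      param_desc, scale, bias,
      /*exponentialAverageFactor=*/1.0 - momentum,
      running_mean, running_var, epsilon, saved_mean, saved_inv_var,
      /*activationDesc=*/nullptr, ws, ws_size, rs, rs_size);

  // Real training code would keep the reserve space alive and hand it to the
  // backward call; it is freed here only to keep the sketch self-contained.
  cudaFree(ws);
  cudaFree(rs);
  return status;
#else
  // Pre-8.0 builds would fall back to cudnnBatchNormalizationForwardTraining.
  return CUDNN_STATUS_NOT_SUPPORTED;
#endif
}
```

The commit itself draws both scratch buffers from the operator's CuDNN workspace via `state->workspace().get(...)` rather than allocating them per call.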

@@ -191,6 +191,66 @@ class CuDNNSpatialBNOp final : public SpatialBNOp<CUDAContext> {
return true;
}
const double alpha = static_cast<double>(1.0f - momentum_);
#if CUDNN_VERSION_MIN(8, 0, 0)
// Currently not supporting CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION
auto op = CUDNN_BATCHNORM_OPS_BN;
// Calculate the workspace size
size_t workspace_size;
CUDNN_ENFORCE(cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(
cudnn_wrapper_.inline_cudnn_handle(),
mode_,
op,
data_desc_,
NULL,
data_desc_,
param_desc_,
NULL,
&workspace_size));
// Calculate the reserve space size; the same query serves forward and backward
size_t reserve_size;
CUDNN_ENFORCE(cudnnGetBatchNormalizationTrainingExReserveSpaceSize(
cudnn_wrapper_.inline_cudnn_handle(),
mode_,
op,
NULL,
data_desc_,
&reserve_size));
// CUDNN state is needed to access the workspace
size_t cudnn_state_(OperatorBase::GetSingleArgument<int>("cudnn_state", 0));
cudnn_wrapper_.with_cudnn_state(
cudnn_state_, [&](CuDNNState* state) {
CUDNN_ENFORCE(cudnnBatchNormalizationForwardTrainingEx(
cudnn_wrapper_.inline_cudnn_handle(),
mode_,
op,
cudnnTypeWrapper<T>::kOne(),
cudnnTypeWrapper<T>::kZero(),
data_desc_,
X_data,
NULL,
NULL,
data_desc_,
Y_data,
param_desc_,
scale_data,
bias_data,
alpha,
running_mean_data,
running_var_data,
epsilon_,
saved_mean_data,
saved_inv_std_data,
NULL,
state->workspace().get(workspace_size),
workspace_size,
state->workspace().get(reserve_size),
reserve_size));
});
#else
CUDNN_ENFORCE(cudnnBatchNormalizationForwardTraining(
cudnn_wrapper_.inline_cudnn_handle(),
mode_,
@@ -209,6 +269,7 @@ class CuDNNSpatialBNOp final : public SpatialBNOp<CUDAContext> {
epsilon_,
saved_mean_data,
saved_inv_std_data));
#endif // CUDNN_VERSION_MIN(8, 0, 0)
}
return true;
}
@@ -314,6 +375,71 @@ class CuDNNSpatialBNGradientOp final : public SpatialBNGradientOp<CUDAContext> {
data_desc_,
param_desc_);
}
#if CUDNN_VERSION_MIN(8, 0, 0)
// Currently not supporting CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION
auto op = CUDNN_BATCHNORM_OPS_BN;
size_t workspace_size;
CUDNN_ENFORCE(cudnnGetBatchNormalizationBackwardExWorkspaceSize(
cudnn_wrapper_.inline_cudnn_handle(),
mode_,
op,
data_desc_,
NULL,
data_desc_,
NULL,
data_desc_,
param_desc_,
NULL,
&workspace_size));
// Calculate the reserve space size; the same query serves forward and backward
size_t reserve_size;
CUDNN_ENFORCE(cudnnGetBatchNormalizationTrainingExReserveSpaceSize(
cudnn_wrapper_.inline_cudnn_handle(),
mode_,
op,
NULL,
data_desc_,
&reserve_size));
// CUDNN state is needed to access the workspace
size_t cudnn_state_(OperatorBase::GetSingleArgument<int>("cudnn_state", 0));
cudnn_wrapper_.with_cudnn_state(
cudnn_state_, [&](CuDNNState* state) {
CUDNN_ENFORCE(cudnnBatchNormalizationBackwardEx(
cudnn_wrapper_.inline_cudnn_handle(),
mode_,
op,
cudnnTypeWrapper<T>::kOne(),
cudnnTypeWrapper<T>::kZero(),
cudnnTypeWrapper<T>::kOne(),
cudnnTypeWrapper<T>::kZero(),
data_desc_,
X_data,
NULL,
NULL,
data_desc_,
dY_data,
NULL,
NULL,
data_desc_,
dX_data,
param_desc_,
scale_data,
NULL,
dscale_data,
dbias_data,
epsilon_,
saved_mean_data,
saved_rstd_data,
NULL,
state->workspace().get(workspace_size),
workspace_size,
state->workspace().get(reserve_size),
reserve_size));
});
#else
CUDNN_ENFORCE(cudnnBatchNormalizationBackward(
cudnn_wrapper_.inline_cudnn_handle(),
mode_,
@@ -334,7 +460,7 @@ class CuDNNSpatialBNGradientOp final : public SpatialBNGradientOp<CUDAContext> {
epsilon_,
saved_mean_data,
saved_rstd_data));
#endif // CUDNN_VERSION_MIN(8, 0, 0)
return true;
}
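
A comparable sketch of the backward path, again not part of the commit and under the same assumptions (illustrative `bn_backward_ex` name, one descriptor shared by x, dY and dX, float data). The reserve space filled by the forward call is passed back in along with its size; the workspace size comes from `cudnnGetBatchNormalizationBackwardExWorkspaceSize`, mirroring the query sequence in the hunk above.

```cpp
#include <cuda_runtime.h>
#include <cudnn.h>

// Assumes the forward pass already produced saved_mean/saved_inv_var and a
// filled reserve_space of reserve_size bytes.
cudnnStatus_t bn_backward_ex(
    cudnnHandle_t handle,
    cudnnTensorDescriptor_t x_desc,      // shared by x, dy and dx
    cudnnTensorDescriptor_t param_desc,
    const float* x, const float* dy, float* dx,
    const float* scale, float* dscale, float* dbias,
    const float* saved_mean, const float* saved_inv_var,
    void* reserve_space, size_t reserve_size,
    double epsilon) {
#if CUDNN_VERSION >= 8000
  const cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL;
  const cudnnBatchNormOps_t op = CUDNN_BATCHNORM_OPS_BN;

  // Workspace size for the extended backward call.
  size_t ws_size = 0;
  cudnnStatus_t status = cudnnGetBatchNormalizationBackwardExWorkspaceSize(
      handle, mode, op,
      x_desc, /*yDesc=*/nullptr, x_desc, /*dzDesc=*/nullptr, x_desc,
      param_desc, /*activationDesc=*/nullptr, &ws_size);
  if (status != CUDNN_STATUS_SUCCESS) {
    return status;
  }

  void* ws = nullptr;
  if (ws_size > 0) {
    cudaMalloc(&ws, ws_size);
  }

  // Separate alpha/beta pairs blend the data gradient and the param gradients.
  const float one = 1.0f, zero = 0.0f;
  status = cudnnBatchNormalizationBackwardEx(
      handle, mode, op,
      &one, &zero,   // alphaDataDiff / betaDataDiff
      &one, &zero,   // alphaParamDiff / betaParamDiff
      x_desc, x, /*yDesc=*/nullptr, /*yData=*/nullptr,
      x_desc, dy, /*dzDesc=*/nullptr, /*dzData=*/nullptr,
      x_desc, dx,
      param_desc, scale, /*bnBiasData=*/nullptr, dscale, dbias,
      epsilon, saved_mean, saved_inv_var,
      /*activationDesc=*/nullptr, ws, ws_size, reserve_space, reserve_size);

  cudaFree(ws);
  return status;
#else
  // Pre-8.0 builds would fall back to cudnnBatchNormalizationBackward.
  return CUDNN_STATUS_NOT_SUPPORTED;
#endif
}
```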