mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[caffe2] Use extended versions of cuDNN calls for SpatialBN
Summary: Using `cudnnBatchNormalizationForwardTrainingEx` and `cudnnBatchNormalizationBackwardEx` if cuDNN version is greater than 8.0.0. Reviewed By: xw285cornell Differential Revision: D26794173 fbshipit-source-id: dc4994375350f303a3fa0aee03255e8f8be1c605
This commit is contained in:
committed by
Facebook GitHub Bot
parent
758fb94fcb
commit
369601355f
@ -191,6 +191,66 @@ class CuDNNSpatialBNOp final : public SpatialBNOp<CUDAContext> {
|
||||
return true;
|
||||
}
|
||||
const double alpha = static_cast<double>(1.0f - momentum_);
|
||||
|
||||
#if CUDNN_VERSION_MIN(8, 0, 0)
|
||||
// Currently not supporting CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION
|
||||
auto op = CUDNN_BATCHNORM_OPS_BN;
|
||||
|
||||
// Calculate the workspace size
|
||||
size_t workspace_size;
|
||||
CUDNN_ENFORCE(cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(
|
||||
cudnn_wrapper_.inline_cudnn_handle(),
|
||||
mode_,
|
||||
op,
|
||||
data_desc_,
|
||||
NULL,
|
||||
data_desc_,
|
||||
param_desc_,
|
||||
NULL,
|
||||
&workspace_size));
|
||||
|
||||
// Calculate the reserved space size - common function for forward and backward
|
||||
size_t reserve_size;
|
||||
CUDNN_ENFORCE(cudnnGetBatchNormalizationTrainingExReserveSpaceSize(
|
||||
cudnn_wrapper_.inline_cudnn_handle(),
|
||||
mode_,
|
||||
op,
|
||||
NULL,
|
||||
data_desc_,
|
||||
&reserve_size));
|
||||
|
||||
// CUDNN state is needed to access the workspace
|
||||
size_t cudnn_state_(OperatorBase::GetSingleArgument<int>("cudnn_state", 0));
|
||||
cudnn_wrapper_.with_cudnn_state(
|
||||
cudnn_state_, [&](CuDNNState* state) {
|
||||
CUDNN_ENFORCE(cudnnBatchNormalizationForwardTrainingEx(
|
||||
cudnn_wrapper_.inline_cudnn_handle(),
|
||||
mode_,
|
||||
CUDNN_BATCHNORM_OPS_BN,
|
||||
cudnnTypeWrapper<T>::kOne(),
|
||||
cudnnTypeWrapper<T>::kZero(),
|
||||
data_desc_,
|
||||
X_data,
|
||||
NULL,
|
||||
NULL,
|
||||
data_desc_,
|
||||
Y_data,
|
||||
param_desc_,
|
||||
scale_data,
|
||||
bias_data,
|
||||
alpha,
|
||||
running_mean_data,
|
||||
running_var_data,
|
||||
epsilon_,
|
||||
saved_mean_data,
|
||||
saved_inv_std_data,
|
||||
NULL,
|
||||
state->workspace().get(workspace_size),
|
||||
workspace_size,
|
||||
state->workspace().get(reserve_size),
|
||||
reserve_size));
|
||||
});
|
||||
#else
|
||||
CUDNN_ENFORCE(cudnnBatchNormalizationForwardTraining(
|
||||
cudnn_wrapper_.inline_cudnn_handle(),
|
||||
mode_,
|
||||
@ -209,6 +269,7 @@ class CuDNNSpatialBNOp final : public SpatialBNOp<CUDAContext> {
|
||||
epsilon_,
|
||||
saved_mean_data,
|
||||
saved_inv_std_data));
|
||||
#endif // CUDNN_VERSION_MIN(8, 0, 0)
|
||||
}
|
||||
return true;
|
||||
}
|
||||
@ -314,6 +375,71 @@ class CuDNNSpatialBNGradientOp final : public SpatialBNGradientOp<CUDAContext> {
|
||||
data_desc_,
|
||||
param_desc_);
|
||||
}
|
||||
#if CUDNN_VERSION_MIN(8, 0, 0)
|
||||
// Currently not supporting CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION
|
||||
auto op = CUDNN_BATCHNORM_OPS_BN;
|
||||
|
||||
size_t workspace_size;
|
||||
CUDNN_ENFORCE(cudnnGetBatchNormalizationBackwardExWorkspaceSize(
|
||||
cudnn_wrapper_.inline_cudnn_handle(),
|
||||
mode_,
|
||||
op,
|
||||
data_desc_,
|
||||
NULL,
|
||||
data_desc_,
|
||||
NULL,
|
||||
data_desc_,
|
||||
param_desc_,
|
||||
NULL,
|
||||
&workspace_size));
|
||||
|
||||
// Calculate the reserved space size - common function for forward and backward
|
||||
size_t reserve_size;
|
||||
CUDNN_ENFORCE(cudnnGetBatchNormalizationTrainingExReserveSpaceSize(
|
||||
cudnn_wrapper_.inline_cudnn_handle(),
|
||||
mode_,
|
||||
op,
|
||||
NULL,
|
||||
data_desc_,
|
||||
&reserve_size));
|
||||
|
||||
// CUDNN state is needed to access the workspace
|
||||
size_t cudnn_state_(OperatorBase::GetSingleArgument<int>("cudnn_state", 0));
|
||||
cudnn_wrapper_.with_cudnn_state(
|
||||
cudnn_state_, [&](CuDNNState* state) {
|
||||
CUDNN_ENFORCE(cudnnBatchNormalizationBackwardEx(
|
||||
cudnn_wrapper_.inline_cudnn_handle(),
|
||||
mode_,
|
||||
op,
|
||||
cudnnTypeWrapper<T>::kOne(),
|
||||
cudnnTypeWrapper<T>::kZero(),
|
||||
cudnnTypeWrapper<T>::kOne(),
|
||||
cudnnTypeWrapper<T>::kZero(),
|
||||
data_desc_,
|
||||
X_data,
|
||||
NULL,
|
||||
NULL,
|
||||
data_desc_,
|
||||
dY_data,
|
||||
NULL,
|
||||
NULL,
|
||||
data_desc_,
|
||||
dX_data,
|
||||
param_desc_,
|
||||
scale_data,
|
||||
NULL,
|
||||
dscale_data,
|
||||
dbias_data,
|
||||
epsilon_,
|
||||
saved_mean_data,
|
||||
saved_rstd_data,
|
||||
NULL,
|
||||
state->workspace().get(workspace_size),
|
||||
workspace_size,
|
||||
state->workspace().get(reserve_size),
|
||||
reserve_size));
|
||||
});
|
||||
#else
|
||||
CUDNN_ENFORCE(cudnnBatchNormalizationBackward(
|
||||
cudnn_wrapper_.inline_cudnn_handle(),
|
||||
mode_,
|
||||
@ -334,7 +460,7 @@ class CuDNNSpatialBNGradientOp final : public SpatialBNGradientOp<CUDAContext> {
|
||||
epsilon_,
|
||||
saved_mean_data,
|
||||
saved_rstd_data));
|
||||
|
||||
#endif // CUDNN_VERSION_MIN(8, 0, 0)
|
||||
return true;
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user