diff --git a/aten/src/ATen/cudnn/Descriptors.h b/aten/src/ATen/cudnn/Descriptors.h index ac78cf196652..79a2fe58ad00 100644 --- a/aten/src/ATen/cudnn/Descriptors.h +++ b/aten/src/ATen/cudnn/Descriptors.h @@ -210,9 +210,7 @@ struct TORCH_CUDA_CPP_API ConvolutionDescriptor if(dataType == CUDNN_DATA_HALF) { AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_TENSOR_OP_MATH)); } else if (dataType == CUDNN_DATA_FLOAT && !allow_tf32) { -#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 8000 AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_FMA_MATH)); -#endif } } }; @@ -304,13 +302,9 @@ struct TORCH_CUDA_CPP_API RNNDescriptor : public Descriptor< if (input_type == CUDNN_DATA_HALF) { cudnnSetRNNMatrixMathType(mut_desc(), CUDNN_TENSOR_OP_MATH); } -#endif -#if !defined(USE_CUDNN_RNN_V8_API) && defined(CUDNN_VERSION) && CUDNN_VERSION >= 8000 else if (input_type == CUDNN_DATA_FLOAT && !allow_tf32) { cudnnSetRNNMatrixMathType(mut_desc(), CUDNN_FMA_MATH); } -#endif -#ifndef USE_CUDNN_RNN_V8_API else { // Technically, as the default it's not necessary to explicitly // set this. diff --git a/aten/src/ATen/native/cudnn/BatchNorm.cpp b/aten/src/ATen/native/cudnn/BatchNorm.cpp index 2efe7a76169b..eb21f8fb5a73 100644 --- a/aten/src/ATen/native/cudnn/BatchNorm.cpp +++ b/aten/src/ATen/native/cudnn/BatchNorm.cpp @@ -78,15 +78,8 @@ cudnnBatchNormMode_t getCudnnBatchNormMode( return CUDNN_BATCHNORM_PER_ACTIVATION; } else if (training && memory_format == at::MemoryFormat::ChannelsLast) { return CUDNN_BATCHNORM_SPATIAL_PERSISTENT; - } else if (training && memory_format == at::MemoryFormat::ChannelsLast3d) { - -#if CUDNN_VERSION >= 8100 return CUDNN_BATCHNORM_SPATIAL_PERSISTENT; -#else - return CUDNN_BATCHNORM_SPATIAL; -#endif // CUDNN_VERSION >= 8100 - } else { // TODO: The new CUDNN_BATCHNORM_SPATIAL_PERSISTENT mode was // introduced in CuDNN 7 for performance optimization, but it results in diff --git a/aten/src/ATen/native/cudnn/ConvShared.cpp b/aten/src/ATen/native/cudnn/ConvShared.cpp index 3fb422fbb6e0..104ae8c70803 100644 --- a/aten/src/ATen/native/cudnn/ConvShared.cpp +++ b/aten/src/ATen/native/cudnn/ConvShared.cpp @@ -735,23 +735,6 @@ Tensor cudnn_convolution_relu( output_t.options().device_opt(), output_t.options().pinned_memory_opt()); -#ifdef AT_CUDNN_CONV_BIAS_RELU_FALLBACK - raw_cudnn_convolution_add_relu_fallback_out( - output_t, - input, - weight, - output_t, // use output_t as z to satisfy CUDNN API - 0, // alpha - _bias, - stride, - padding, - dilation, - groups, - benchmark, // benchmark - false, // deterministic - allow_tf32 // allow_tf32 - ); -#else // AT_CUDNN_CONV_BIAS_RELU_FALLBACK raw_cudnn_convolution_add_relu_out( output_t, input, @@ -767,7 +750,6 @@ Tensor cudnn_convolution_relu( false, // deterministic allow_tf32 // allow_tf32 ); -#endif return output_t; } @@ -813,23 +795,6 @@ Tensor cudnn_convolution_add_relu( output_t.options().device_opt(), output_t.options().pinned_memory_opt()); -#ifdef AT_CUDNN_CONV_BIAS_RELU_FALLBACK - raw_cudnn_convolution_add_relu_fallback_out( - output_t, - input, - weight, - z, - _alpha, - _bias, - stride, - padding, - dilation, - groups, - benchmark, - false, // deterministic - allow_tf32 // allow_tf32 - ); -#else // AT_CUDNN_CONV_BIAS_RELU_FALLBACK raw_cudnn_convolution_add_relu_out( output_t, input, @@ -845,7 +810,6 @@ Tensor cudnn_convolution_add_relu( false, // deterministic allow_tf32 // allow_tf32 ); -#endif // AT_CUDNN_CONV_BIAS_RELU_FALLBACK return output_t; } diff --git a/aten/src/ATen/native/cudnn/ConvShared.h 
b/aten/src/ATen/native/cudnn/ConvShared.h index fa59e74bb808..ae68bfc7d20d 100644 --- a/aten/src/ATen/native/cudnn/ConvShared.h +++ b/aten/src/ATen/native/cudnn/ConvShared.h @@ -6,10 +6,6 @@ #include #include -#if CUDNN_VERSION < 8000 -#define AT_CUDNN_CONV_BIAS_RELU_FALLBACK -#endif - namespace at { namespace native { diff --git a/caffe2/core/common_cudnn.h b/caffe2/core/common_cudnn.h index 5161d26314d2..b130103fb5cb 100644 --- a/caffe2/core/common_cudnn.h +++ b/caffe2/core/common_cudnn.h @@ -16,17 +16,12 @@ #include static_assert( - CUDNN_VERSION >= 5000, - "Caffe2 requires cudnn version 5.0 or above."); - -#if CUDNN_VERSION < 6000 -#pragma message "CUDNN version under 6.0 is supported at best effort." -#pragma message "We strongly encourage you to move to 6.0 and above." -#pragma message "This message is intended to annoy you enough to update." -#endif // CUDNN_VERSION < 6000 + CUDNN_VERSION >= 8200, + "Caffe2 requires cudnn version 8.2 or above."); #define CUDNN_VERSION_MIN(major, minor, patch) \ - (CUDNN_VERSION >= ((major) * 1000 + (minor) * 100 + (patch))) + (major >= 9 ? CUDNN_VERSION >= ((major) * 10000 + (minor) * 100 + (patch)) : \ + CUDNN_VERSION >= ((major) * 1000 + (minor) * 100 + (patch))) namespace caffe2 { @@ -135,7 +130,6 @@ class cudnnTypeWrapper { } }; -#if CUDNN_VERSION_MIN(6, 0, 0) template <> class cudnnTypeWrapper { public: @@ -151,7 +145,6 @@ class cudnnTypeWrapper { return &v; } }; -#endif // CUDNN_VERSION_MIN(6, 0, 0) template <> class cudnnTypeWrapper { diff --git a/caffe2/operators/conv_op_cudnn.cc b/caffe2/operators/conv_op_cudnn.cc index 36a5f4c4bcae..0ef2f35dba9e 100644 --- a/caffe2/operators/conv_op_cudnn.cc +++ b/caffe2/operators/conv_op_cudnn.cc @@ -37,29 +37,9 @@ class CudnnConvOpBase : public ConvPoolOpBase { "The current padding scheme leads to unequal padding on the left " "and right, which is not supported by cudnn."); } - // dilated convolution supported by some algorithms in cuDNN v6 -#if !(CUDNN_VERSION_MIN(6, 0, 0)) - OPERATOR_NEEDS_FEATURE( - dilation_h() == 1 && dilation_w() == 1, - "The cudnn convolution does not support dilation yet."); -#endif - // dilated grouped convolution supported in cuDNN v7.1 -#if !(CUDNN_VERSION_MIN(7, 1, 0)) - if (group_ != 1) { - for (int dim = 0; dim < kernel_.size(); ++dim) { - OPERATOR_NEEDS_FEATURE( - dilation_[dim] == 1, - "When group is used, dilation should not be set at the same time."); - } - } -#endif -#if CUDNN_VERSION_MIN(7, 0, 0) // verify TensorCore math is supported enable_tensor_core_ &= TensorCoreAvailable(); -#else - enable_tensor_core_ = false; -#endif bool individual_force_algo = OperatorBase::HasArgument("force_algo_fwd") || OperatorBase::HasArgument("force_algo_dgrad") || @@ -108,11 +88,7 @@ class CudnnConvOpBase : public ConvPoolOpBase { int H, int W, int D) { -#if CUDNN_VERSION_MIN(7, 0, 0) const int CC = C; -#else - const int CC = C / group_; -#endif switch (order_) { case StorageOrder::NHWC: if (size == 4) { @@ -182,7 +158,6 @@ class CudnnConvOpBase : public ConvPoolOpBase { int dilation_height = 0; int dilation_width = 0; -#if CUDNN_VERSION_MIN(6, 0, 0) CUDNN_ENFORCE(cudnnGetConvolution2dDescriptor( input, &pad_height, @@ -193,19 +168,7 @@ class CudnnConvOpBase : public ConvPoolOpBase { &dilation_width, &mode, &dataType)); -#else - CUDNN_ENFORCE(cudnnGetConvolution2dDescriptor( - input, - &pad_height, - &pad_width, - &stride_height, - &stride_width, - &dilation_height, - &dilation_width, - &mode)); -#endif -#if CUDNN_VERSION_MIN(6, 0, 0) CUDNN_ENFORCE(cudnnSetConvolution2dDescriptor( 
copy, pad_height, @@ -216,17 +179,6 @@ class CudnnConvOpBase : public ConvPoolOpBase { dilation_width, mode, dataType)); -#else - CUDNN_ENFORCE(cudnnSetConvolution2dDescriptor( - copy, - pad_height, - pad_width, - stride_height, - stride_width, - dilation_height, - dilation_width, - mode)); -#endif } else { cudnnConvolutionMode_t mode; cudnnDataType_t dataType; @@ -278,7 +230,6 @@ class CudnnConvOpBase : public ConvPoolOpBase { } void SetConvDescFromArguments() { -#if CUDNN_VERSION_MIN(6, 0, 0) if (kernel_.size() == 1 || kernel_.size() == 2) { CUDNN_ENFORCE(cudnnSetConvolution2dDescriptor( conv_desc_, @@ -300,29 +251,6 @@ class CudnnConvOpBase : public ConvPoolOpBase { CUDNN_CROSS_CORRELATION, compute_type_)); } -#else - if (kernel_.size() == 2) { - CUDNN_ENFORCE(cudnnSetConvolution2dDescriptor( - conv_desc_, - pad_t(), - pad_l(), - stride_h(), - stride_w(), - 1, - 1, - CUDNN_CROSS_CORRELATION)); - } else { - vector ones(dilation_.size(), 1); - CUDNN_ENFORCE(cudnnSetConvolutionNdDescriptor( - conv_desc_, - kernel_.size(), - pads_.data(), - stride_.data(), - ones.data(), - CUDNN_CROSS_CORRELATION, - compute_type_)); - } -#endif } void SetConvDescComputeType( @@ -338,7 +266,6 @@ class CudnnConvOpBase : public ConvPoolOpBase { int dilation_height = 0; int dilation_width = 0; -#if CUDNN_VERSION_MIN(6, 0, 0) CUDNN_ENFORCE(cudnnGetConvolution2dDescriptor( conv_desc, &pad_height, @@ -349,19 +276,7 @@ class CudnnConvOpBase : public ConvPoolOpBase { &dilation_width, &mode, &dataType)); -#else - CUDNN_ENFORCE(cudnnGetConvolution2dDescriptor( - conv_desc, - &pad_height, - &pad_width, - &stride_height, - &stride_width, - &dilation_height, - &dilation_width, - &mode)); -#endif -#if CUDNN_VERSION_MIN(6, 0, 0) CUDNN_ENFORCE(cudnnSetConvolution2dDescriptor( conv_desc, pad_height, @@ -372,17 +287,6 @@ class CudnnConvOpBase : public ConvPoolOpBase { dilation_width, mode, math)); -#else - CUDNN_ENFORCE(cudnnSetConvolution2dDescriptor( - conv_desc, - pad_height, - pad_width, - stride_height, - stride_width, - dilation_height, - dilation_width, - mode)); -#endif } else { cudnnConvolutionMode_t mode; cudnnDataType_t dataType; @@ -576,9 +480,6 @@ bool CudnnConvOp::DoRunWithType() { return true; } -#if !CUDNN_VERSION_MIN(7, 0, 0) - int group_offset_filter = filter.numel() / group_; -#endif // Set up the cudnn algorithms & workspace if necessary bool input_changed = (X.sizes() != cudnn_input_dims_); @@ -592,11 +493,7 @@ bool CudnnConvOp::DoRunWithType() { if (filter_changed) { cudnn_filter_dims_ = filter.sizes().vec(); if (kernel_.size() == 1 || kernel_.size() == 2) { -#if CUDNN_VERSION_MIN(7, 0, 0) const int MM = M; -#else - const int MM = M / group_; -#endif CUDNN_ENFORCE(cudnnSetFilter4dDescriptor( filter_desc_, cudnnTypeWrapper::type, @@ -607,12 +504,6 @@ bool CudnnConvOp::DoRunWithType() { kernel_.size() == 1 ? 1 : kernel_w())); } else { vector dims(filter.sizes().begin(), filter.sizes().end()); -#if !CUDNN_VERSION_MIN(7, 0, 0) - // We only need to divide dims by group_ when CUDNN version < 7.0 - // see CUDA group convolution doc: https://fburl.com/dgj6dvpd - order_ == StorageOrder::NCHW ? 
dims[1] /= group_ - : dims[filter.ndim() - 1] /= group_; -#endif CUDNN_ENFORCE(cudnnSetFilterNdDescriptor( filter_desc_, cudnnTypeWrapper::type, @@ -674,7 +565,6 @@ bool CudnnConvOp::DoRunWithType() { compute_type_ = DetermineComputeTypeFromInput(X); SetConvDescFromArguments(); -#if CUDNN_VERSION_MIN(7, 0, 0) if (enable_tensor_core_) { CUDNN_ENFORCE( cudnnSetConvolutionMathType(conv_desc_, CUDNN_TENSOR_OP_MATH)); @@ -682,7 +572,6 @@ bool CudnnConvOp::DoRunWithType() { // enable cuDNN conv groups CUDNN_CHECK(cudnnSetConvolutionGroupCount(conv_desc_, group_)); -#endif if (force_algo_[ALGO_FWD] >= 0) { algo_ = (cudnnConvolutionFwdAlgo_t)force_algo_[ALGO_FWD]; @@ -808,7 +697,6 @@ bool CudnnConvOp::DoRunWithType() { // Now, actually run the computation. // Run directly through cuDNN if possible -#if CUDNN_VERSION_MIN(7, 0, 0) cudnn_wrapper_.with_cudnn_state(cudnn_state_, [&](CuDNNState* state) { CUDNN_ENFORCE(cudnnConvolutionForward( state->cudnn_handle(), @@ -825,27 +713,6 @@ bool CudnnConvOp::DoRunWithType() { top_desc_, Y->template mutable_data())); }); -#else - // otherwise manually run through groups - for (int i = 0; i < group_; ++i) { - cudnn_wrapper_.with_cudnn_state(cudnn_state_, [&](CuDNNState* state) { - CUDNN_ENFORCE(cudnnConvolutionForward( - state->cudnn_handle(), - cudnnTypeWrapper::kOne(), - bottom_desc_, - X.template data() + i * group_offset_X, - filter_desc_, - filter.template data() + i * group_offset_filter, - conv_desc_, - algo_, - state->workspace().get(cudnn_ws_nbytes_), - cudnn_ws_nbytes_, - cudnnTypeWrapper::kZero(), - top_desc_, - Y->template mutable_data() + i * group_offset_Y)); - }); - } -#endif // Bias if (InputSize() == 3) { auto& bias = Input(BIAS); @@ -953,9 +820,6 @@ bool CudnnConvGradientOp::DoRunWithType() { "If you set group, the number of output channels should be divisible " "by group."); -#if !CUDNN_VERSION_MIN(7, 0, 0) - int group_offset_filter = filter.numel() / group_; -#endif if (kernel_.size() == 1) { ConvPoolOpBase::ComputePads({H}); } else if (kernel_.size() == 2) { @@ -1003,11 +867,7 @@ bool CudnnConvGradientOp::DoRunWithType() { if (filter_changed) { cudnn_filter_dims_ = filter.sizes().vec(); if (kernel_.size() == 1 || kernel_.size() == 2) { -#if CUDNN_VERSION_MIN(7, 0, 0) const int MM = M; -#else - const int MM = M / group_; -#endif CUDNN_ENFORCE(cudnnSetFilter4dDescriptor( filter_desc_, cudnnTypeWrapper::type, @@ -1018,12 +878,6 @@ bool CudnnConvGradientOp::DoRunWithType() { kernel_.size() == 1 ? 1 : kernel_w())); } else { vector dims(filter.sizes().begin(), filter.sizes().end()); -#if !CUDNN_VERSION_MIN(7, 0, 0) - // We only need to divide dims by group_ when CUDNN version < 7.0 - // see CUDA group convolution doc: https://fburl.com/dgj6dvpd - order_ == StorageOrder::NCHW ? 
dims[1] /= group_ - : dims[filter.ndim() - 1] /= group_; -#endif CUDNN_ENFORCE(cudnnSetFilterNdDescriptor( filter_desc_, @@ -1091,7 +945,6 @@ bool CudnnConvGradientOp::DoRunWithType() { DuplicateConvDesc( conv_desc_, kernel_.size(), dilation_.size(), bwd_data_conv_desc_); -#if CUDNN_VERSION_MIN(7, 0, 0) if (enable_tensor_core_) { CUDNN_ENFORCE(cudnnSetConvolutionMathType( bwd_filter_conv_desc_, CUDNN_TENSOR_OP_MATH)); @@ -1102,7 +955,6 @@ bool CudnnConvGradientOp::DoRunWithType() { // set cuDNN groups if appropriate CUDNN_CHECK(cudnnSetConvolutionGroupCount(bwd_filter_conv_desc_, group_)); CUDNN_CHECK(cudnnSetConvolutionGroupCount(bwd_data_conv_desc_, group_)); -#endif // Choose dW algorithm if (force_algo_[ALGO_WGRAD] >= 0) { @@ -1388,7 +1240,6 @@ bool CudnnConvGradientOp::DoRunWithType() { dbias->template mutable_data())); } -#if CUDNN_VERSION_MIN(7, 0, 0) cudnn_wrapper_.with_cudnn_state(cudnn_state_, [&](CuDNNState* state) { CUDNN_ENFORCE(cudnnConvolutionBackwardFilter( state->cudnn_handle(), @@ -1427,45 +1278,6 @@ bool CudnnConvGradientOp::DoRunWithType() { dX->template mutable_data())); } }); -#else - for (int i = 0; i < group_; ++i) { - cudnn_wrapper_.with_cudnn_state(cudnn_state_, [&](CuDNNState* state) { - CUDNN_ENFORCE(cudnnConvolutionBackwardFilter( - state->cudnn_handle(), - cudnnTypeWrapper::kOne(), - bottom_desc_, - X.template data() + i * group_offset_X, - top_desc_, - dY.template data() + i * group_offset_Y, - bwd_filter_conv_desc_, - bwd_filter_algo_, - state->workspace().get(cudnn_ws_nbytes_), - cudnn_ws_nbytes_, - cudnnTypeWrapper::kZero(), - filter_desc_, - dfilter->template mutable_data() + i * group_offset_filter)); - if (OutputSize() == 3 || (no_bias_ && (OutputSize() == 2))) { - // Compute the gradient w.r.t. the input. - auto* dX = Output(no_bias_ ? 
BIAS_OR_INPUT_GRAD : INPUT_GRAD); - dX->ResizeLike(X); - CUDNN_ENFORCE(cudnnConvolutionBackwardData( - state->cudnn_handle(), - cudnnTypeWrapper::kOne(), - filter_desc_, - filter.template data() + i * group_offset_filter, - top_desc_, - dY.template data() + i * group_offset_Y, - bwd_data_conv_desc_, - bwd_data_algo_, - state->workspace().get(cudnn_ws_nbytes_), - cudnn_ws_nbytes_, - cudnnTypeWrapper::kZero(), - bottom_desc_, - dX->template mutable_data() + i * group_offset_X)); - } - }); - } -#endif return true; } diff --git a/caffe2/operators/conv_transpose_op_cudnn.cc b/caffe2/operators/conv_transpose_op_cudnn.cc index d432e4d30780..b9333af99e3e 100644 --- a/caffe2/operators/conv_transpose_op_cudnn.cc +++ b/caffe2/operators/conv_transpose_op_cudnn.cc @@ -77,11 +77,7 @@ class CudnnConvTransposeOpBase : public ConvTransposeUnpoolBase { const int H, const int W, cudnnTensorDescriptor_t* desc) const { -#if CUDNN_VERSION_MIN(7, 0, 0) const int CC = C; -#else - const int CC = C / group_; -#endif switch (order_) { case StorageOrder::NCHW: { CUDNN_ENFORCE(cudnnSetTensor4dDescriptorEx( @@ -242,11 +238,7 @@ bool CudnnConvTransposeOp::RunOnDevice() { } if (filter_changed) { cudnn_filter_dims_ = filter.sizes().vec(); -#if CUDNN_VERSION_MIN(7, 0, 0) const int MM = M; -#else - const int MM = M / group_; -#endif CUDNN_ENFORCE(cudnnSetFilter4dDescriptor( filter_desc_, cudnnTypeWrapper::type, @@ -292,7 +284,6 @@ bool CudnnConvTransposeOp::RunOnDevice() { "The current padding scheme leads to unequal padding on the left " "and right, which is not supported by cudnn."); // Set the convolution descriptor -#if CUDNN_VERSION_MIN(6, 0, 0) CUDNN_ENFORCE(cudnnSetConvolution2dDescriptor( conv_desc_, pad_t(), @@ -303,19 +294,7 @@ bool CudnnConvTransposeOp::RunOnDevice() { 1, CUDNN_CROSS_CORRELATION, cudnnTypeWrapper::type)); -#else - CUDNN_ENFORCE(cudnnSetConvolution2dDescriptor( - conv_desc_, - pad_t(), - pad_l(), - stride_h(), - stride_w(), - 1, - 1, - CUDNN_CROSS_CORRELATION)); -#endif -#if CUDNN_VERSION_MIN(7, 0, 0) // enable TensorCore math if desired enable_tensor_core_ &= TensorCoreAvailable(); if (enable_tensor_core_) { @@ -324,7 +303,6 @@ bool CudnnConvTransposeOp::RunOnDevice() { } // set cuDNN groups if appropriate CUDNN_ENFORCE(cudnnSetConvolutionGroupCount(conv_desc_, group_)); -#endif if (force_algo_[ALGO_DGRAD] >= 0) { bwd_data_algo_ = (cudnnConvolutionBwdDataAlgo_t)force_algo_[ALGO_DGRAD]; @@ -400,7 +378,6 @@ bool CudnnConvTransposeOp::RunOnDevice() { // Now, actually run the computation. // Filter -#if CUDNN_VERSION_MIN(7, 0, 0) cudnn_wrapper_.with_cudnn_state(cudnn_state_, [&](CuDNNState* state) { CUDNN_ENFORCE(cudnnConvolutionBackwardData( state->cudnn_handle(), @@ -417,33 +394,6 @@ bool CudnnConvTransposeOp::RunOnDevice() { top_desc_, Y_data)); }); -#else - const int X_HxW = H * W; - const int Y_HxW = H_out * W_out; - const int group_offset_X = - order_ == StorageOrder::NCHW ? M / group_ * X_HxW : M / group_; - const int group_offset_Y = - order_ == StorageOrder::NCHW ? 
C / group_ * Y_HxW : C / group_; - const int group_offset_filter = filter.numel() / group_; - for (int i = 0; i < group_; ++i) { - cudnn_wrapper_.with_cudnn_state(cudnn_state_, [&](CuDNNState* state) { - CUDNN_ENFORCE( - cudnnConvolutionBackwardData(state->cudnn_handle(), - cudnnTypeWrapper::kOne(), - filter_desc_, - filter_data + i * group_offset_filter, - bottom_desc_, - X_data + i * group_offset_X; - conv_desc_, - bwd_data_algo_, - state->workspace().get(cudnn_ws_nbytes_), - cudnn_ws_nbytes_, - cudnnTypeWrapper::kZero(), - top_desc_, - Y_data + i * group_offset_Y)); - }); - } -#endif // Bias if (InputSize() == 3) { CUDNN_ENFORCE(cudnnAddTensor( @@ -527,11 +477,7 @@ bool CudnnConvTransposeGradientOp::RunOnDevice() { } if (filter_changed) { cudnn_filter_dims_ = filter.sizes().vec(); -#if CUDNN_VERSION_MIN(7, 0, 0) const int MM = M; -#else - const int MM = M / group_; -#endif CUDNN_ENFORCE(cudnnSetFilter4dDescriptor( filter_desc_, cudnnTypeWrapper::type, @@ -576,7 +522,6 @@ bool CudnnConvTransposeGradientOp::RunOnDevice() { pad_r(), "The current padding scheme leads to unequal padding on the left " "and right, which is not supported by cudnn."); -#if CUDNN_VERSION_MIN(6, 0, 0) CUDNN_ENFORCE(cudnnSetConvolution2dDescriptor( conv_desc_, pad_t(), @@ -587,18 +532,6 @@ bool CudnnConvTransposeGradientOp::RunOnDevice() { 1, CUDNN_CROSS_CORRELATION, cudnnTypeWrapper::type)); -#else - CUDNN_ENFORCE(cudnnSetConvolution2dDescriptor( - conv_desc_, - pad_t(), - pad_l(), - stride_h(), - stride_w(), - 1, - 1, - CUDNN_CROSS_CORRELATION)); -#endif -#if CUDNN_VERSION_MIN(7, 0, 0) // enable TensorCore math if desired enable_tensor_core_ &= TensorCoreAvailable(); if (enable_tensor_core_) { @@ -607,7 +540,6 @@ bool CudnnConvTransposeGradientOp::RunOnDevice() { } // set cuDNN groups if appropriate CUDNN_CHECK(cudnnSetConvolutionGroupCount(conv_desc_, group_)); -#endif if (force_algo_[ALGO_WGRAD] >= 0) { bwd_filter_algo_ = (cudnnConvolutionBwdFilterAlgo_t)force_algo_[ALGO_WGRAD]; @@ -762,7 +694,6 @@ bool CudnnConvTransposeGradientOp::RunOnDevice() { dbias->template mutable_data())); } -#if CUDNN_VERSION_MIN(7, 0, 0) cudnn_wrapper_.with_cudnn_state(cudnn_state_, [&](CuDNNState* state) { CUDNN_ENFORCE(cudnnConvolutionBackwardFilter( state->cudnn_handle(), @@ -801,55 +732,6 @@ bool CudnnConvTransposeGradientOp::RunOnDevice() { dX->template mutable_data())); } }); -#else - const int X_HxW = H * W; - const int Y_HxW = H_out * W_out; - const int group_offset_X = - order_ == StorageOrder::NCHW ? M / group_ * X_HxW : M / group_; - const int group_offset_Y = - order_ == StorageOrder::NCHW ? C / group_ * Y_HxW : C / group_; - const int group_offset_filter = filter.numel() / group_; - for (int i = 0; i < group_; ++i) { - cudnn_wrapper_.with_cudnn_state(cudnn_state_, [&](CuDNNState* state) { - CUDNN_ENFORCE(cudnnConvolutionBackwardFilter( - state->cudnn_handle(), - cudnnTypeWrapper::kOne(), - top_desc_, - dY.template data() + i * group_offset_Y, - bottom_desc_, - X.template data() + i * group_offset_X, - conv_desc_, - bwd_filter_algo_, - state->workspace().get(cudnn_ws_nbytes_), - cudnn_ws_nbytes_, - cudnnTypeWrapper::kZero(), - filter_desc_, - dfilter->template mutable_data() + i * group_offset_filter)); - if (OutputSize() == 3 || (no_bias_ && (OutputSize() == 2))) { - // Compute the gradient w.r.t. the input. - auto* dX = Output( - no_bias_ ? 
BIAS_OR_INPUT_GRAD : INPUT_GRAD, - X.sizes(), - at::dtype()); - cudnn_wrapper_.with_cudnn_state(cudnn_state_, [&](CuDNNState* state) { - CUDNN_ENFORCE(cudnnConvolutionForward( - state->cudnn_handle(), - cudnnTypeWrapper::kOne(), - top_desc_, - dY.template data() + i * group_offset_Y, - filter_desc_, - filter.template data() + i * group_offset_filter, - conv_desc_, - algo_, - state->workspace().get(cudnn_ws_nbytes_), - cudnn_ws_nbytes_, - cudnnTypeWrapper::kZero(), - bottom_desc_, - dX->template mutable_data() + i * group_offset_X)); - }); - } - } -#endif return true; } diff --git a/caffe2/operators/dropout_op_cudnn.cc b/caffe2/operators/dropout_op_cudnn.cc index 01f21544867d..2d0ca63e4ebe 100644 --- a/caffe2/operators/dropout_op_cudnn.cc +++ b/caffe2/operators/dropout_op_cudnn.cc @@ -5,10 +5,6 @@ namespace caffe2 { -// cudnnRestoreDropoutDescriptor is needed for correctness and -// doesn't exist prior to cuDNN v7 -#if CUDNN_VERSION_MIN(7,0,0) - class CuDNNDropoutOp final : public Operator { public: USE_OPERATOR_FUNCTIONS(CUDAContext); @@ -293,6 +289,4 @@ REGISTER_CUDNN_OPERATOR(Dropout, CuDNNDropoutOp); REGISTER_CUDNN_OPERATOR(DropoutGrad, CuDNNDropoutGradientOp); } -#endif - }; // namespace caffe2 diff --git a/caffe2/operators/op_utils_cudnn.h b/caffe2/operators/op_utils_cudnn.h index ca5c19e62918..cb7377a2eca3 100644 --- a/caffe2/operators/op_utils_cudnn.h +++ b/caffe2/operators/op_utils_cudnn.h @@ -14,7 +14,6 @@ static constexpr size_t kCONV_CUDNN_WORKSPACE_LIMIT_BYTES = 64 * 1024 * 1024; // This does not have any performance implications, as we will always find the // fastest algorithm; setting them to the right number of algorithms will enable // us to best report the statistics when doing an exhaustive search, though. -#if CUDNN_VERSION_MIN(7, 0, 0) // Note: Double each of these due to potential // tensorcore + non-tensorcore versions // which are treated as separate returned algos @@ -24,11 +23,6 @@ static constexpr size_t kNUM_CUDNN_BWD_FILTER_ALGS = 2 * CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT; static constexpr size_t kNUM_CUDNN_BWD_DATA_ALGS = 2 * CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT; -#else -static constexpr size_t kNUM_CUDNN_FWD_ALGS = 7; -static constexpr size_t kNUM_CUDNN_BWD_FILTER_ALGS = 4; -static constexpr size_t kNUM_CUDNN_BWD_DATA_ALGS = 5; -#endif namespace { template diff --git a/caffe2/operators/pool_op_cudnn.cc b/caffe2/operators/pool_op_cudnn.cc index e65680148c71..251fc380bc83 100644 --- a/caffe2/operators/pool_op_cudnn.cc +++ b/caffe2/operators/pool_op_cudnn.cc @@ -434,11 +434,7 @@ struct CuDNNMaxPoolFunctor { deterministic(op.GetSingleArgument("deterministic", false)) {} cudnnPoolingMode_t GetPoolingMode() const { -#if CUDNN_VERSION_MIN(6, 0, 0) return deterministic ? CUDNN_POOLING_MAX_DETERMINISTIC : CUDNN_POOLING_MAX; -#else - return CUDNN_POOLING_MAX; -#endif } template diff --git a/caffe2/operators/rnn/recurrent_op_cudnn.cc b/caffe2/operators/rnn/recurrent_op_cudnn.cc index 3679c9d2a7d3..c23662f13f21 100644 --- a/caffe2/operators/rnn/recurrent_op_cudnn.cc +++ b/caffe2/operators/rnn/recurrent_op_cudnn.cc @@ -98,7 +98,6 @@ void RecurrentBaseOp::initialize( // RNN setup { -#if CUDNN_VERSION_MIN(7, 0, 0) CUDNN_ENFORCE(cudnnSetRNNDescriptor_v6( cudnn_wrapper_.inline_cudnn_handle(), rnnDesc_, @@ -110,17 +109,6 @@ void RecurrentBaseOp::initialize( rnnMode, CUDNN_RNN_ALGO_STANDARD, // TODO: verify correctness / efficiency. 
cudnnTypeWrapper::type)); -#else - CUDNN_ENFORCE(cudnnSetRNNDescriptor( - rnnDesc_, - hiddenSize, - numLayers, - dropoutDesc_, - rnnInput, - rnnDirection, - rnnMode, - cudnnTypeWrapper::type)); -#endif } // X setup { diff --git a/caffe2/operators/spatial_batch_norm_op_cudnn.cu b/caffe2/operators/spatial_batch_norm_op_cudnn.cu index 4b629caa8cc0..d5ad3ac54efd 100644 --- a/caffe2/operators/spatial_batch_norm_op_cudnn.cu +++ b/caffe2/operators/spatial_batch_norm_op_cudnn.cu @@ -10,7 +10,6 @@ #include "caffe2/operators/spatial_batch_norm_op_impl.cuh" #include "caffe2/utils/math.h" -#if CUDNN_VERSION_MIN(5, 0, 0) namespace caffe2 { @@ -63,16 +62,12 @@ class CuDNNSpatialBNOp final : public SpatialBNOp { CuDNNSpatialBNOp(const OperatorDef& operator_def, Workspace* ws) : SpatialBNOp(operator_def, ws), cudnn_wrapper_(&context_), -#if CUDNN_VERSION_MIN(7, 0, 0) // TODO(T31829456): The new CUDNN_BATCHNORM_SPATIAL_PERSISTENT mode was // introduced in CuDNN 7 for performance optimization, but it results in // accuracy losses in convolution models such as ResNeXt-101 and // video R(2+1)D. We will fall back to the normal // CUDNN_BATCHNORM_SPATIAL for now mode_(CUDNN_BATCHNORM_SPATIAL) { -#else - mode_(CUDNN_BATCHNORM_SPATIAL) { -#endif CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&data_desc_)); CUDNN_ENFORCE(cudnnCreateTensorDescriptor(¶m_desc_)); if (epsilon_ < CUDNN_BN_MIN_EPSILON) { @@ -192,7 +187,6 @@ class CuDNNSpatialBNOp final : public SpatialBNOp { } const double alpha = static_cast(1.0f - momentum_); -#if CUDNN_VERSION_MIN(8, 0, 0) // Currently not supporting CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION auto op = CUDNN_BATCHNORM_OPS_BN; @@ -250,26 +244,6 @@ class CuDNNSpatialBNOp final : public SpatialBNOp { state->workspace().get(reserve_size), reserve_size)); }); -#else - CUDNN_ENFORCE(cudnnBatchNormalizationForwardTraining( - cudnn_wrapper_.inline_cudnn_handle(), - mode_, - cudnnTypeWrapper::kOne(), - cudnnTypeWrapper::kZero(), - data_desc_, - X_data, - data_desc_, - Y_data, - param_desc_, - scale_data, - bias_data, - alpha, - running_mean_data, - running_var_data, - epsilon_, - saved_mean_data, - saved_inv_std_data)); -#endif // CUDNN_VERSION_MIN(8, 0, 0) } return true; } @@ -290,16 +264,12 @@ class CuDNNSpatialBNGradientOp final : public SpatialBNGradientOp { CuDNNSpatialBNGradientOp(const OperatorDef& operator_def, Workspace* ws) : SpatialBNGradientOp(operator_def, ws), cudnn_wrapper_(&context_), -#if CUDNN_VERSION_MIN(7, 0, 0) // TODO(T31829456): The new CUDNN_BATCHNORM_SPATIAL_PERSISTENT mode was // introduced in CuDNN 7 for performance optimization, but it results in // accuracy losses in convolution models such as ResNeXt-101 and // video R(2+1)D. 
We will fall back to the normal // CUDNN_BATCHNORM_SPATIAL for now mode_(CUDNN_BATCHNORM_SPATIAL) { -#else - mode_(CUDNN_BATCHNORM_SPATIAL) { -#endif CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&data_desc_)); CUDNN_ENFORCE(cudnnCreateTensorDescriptor(¶m_desc_)); if (epsilon_ < CUDNN_BN_MIN_EPSILON) { @@ -375,7 +345,6 @@ class CuDNNSpatialBNGradientOp final : public SpatialBNGradientOp { data_desc_, param_desc_); } -#if CUDNN_VERSION_MIN(8, 0, 0) // Currently not supporting CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION auto op = CUDNN_BATCHNORM_OPS_BN; @@ -439,28 +408,6 @@ class CuDNNSpatialBNGradientOp final : public SpatialBNGradientOp { state->workspace().get(reserve_size), reserve_size)); }); -#else - CUDNN_ENFORCE(cudnnBatchNormalizationBackward( - cudnn_wrapper_.inline_cudnn_handle(), - mode_, - cudnnTypeWrapper::kOne(), - cudnnTypeWrapper::kZero(), - cudnnTypeWrapper::kOne(), - cudnnTypeWrapper::kZero(), - data_desc_, - X_data, - data_desc_, - dY_data, - data_desc_, - dX_data, - param_desc_, - scale_data, - dscale_data, - dbias_data, - epsilon_, - saved_mean_data, - saved_rstd_data)); -#endif // CUDNN_VERSION_MIN(8, 0, 0) return true; } @@ -478,5 +425,3 @@ REGISTER_CUDNN_OPERATOR(SpatialBN, CuDNNSpatialBNOp); REGISTER_CUDNN_OPERATOR(SpatialBNGradient, CuDNNSpatialBNGradientOp); } // namespace caffe2 - -#endif // CUDNN_VERSION_MIN(5, 0, 0) diff --git a/caffe2/operators/transpose_op_cudnn.cc b/caffe2/operators/transpose_op_cudnn.cc index 7e9b5b73610e..5f08c72cbe1e 100644 --- a/caffe2/operators/transpose_op_cudnn.cc +++ b/caffe2/operators/transpose_op_cudnn.cc @@ -134,33 +134,6 @@ class CuDNNTransposeOp final : public Operator { std::vector axes_; }; -#if !CUDNN_VERSION_MIN(6, 0, 0) - -// CuDNN 5.1 does not have int support yet. -template <> -bool CuDNNTransposeOp::DoRunWithType() { - const auto& X = Input(0); - const int ndim = X.dim(); - if (axes_.empty()) { - axes_.resize(ndim); - std::iota(axes_.rbegin(), axes_.rend(), 0); - } else { - CAFFE_ENFORCE_EQ(axes_.size(), ndim); - } - std::vector X_dims = X.sizes().vec(); - std::vector Y_dims(ndim); - for (int i = 0; i < ndim; ++i) { - Y_dims[i] = X_dims[axes_[i]]; - } - auto* Y = Output(0, Y_dims, at::dtype()); - const T* X_data = X.template data(); - T* Y_data = Y->template mutable_data(); - math::Transpose( - ndim, X_dims.data(), axes_.data(), X_data, Y_data, &context_); - return true; -} - -#endif // !CUDNN_VERSION_MIN(6, 0, 0) } // namespace diff --git a/cmake/public/cuda.cmake b/cmake/public/cuda.cmake index c7595774d810..8160b5e1fa88 100644 --- a/cmake/public/cuda.cmake +++ b/cmake/public/cuda.cmake @@ -251,8 +251,8 @@ if(CAFFE2_USE_CUDNN) "Cannot find cuDNN library. 
Turning the option off") set(CAFFE2_USE_CUDNN OFF) else() - if(CUDNN_VERSION VERSION_LESS "8.0.0") - message(FATAL_ERROR "PyTorch requires cuDNN 8 and above.") + if(CUDNN_VERSION VERSION_LESS "8.1.0") + message(FATAL_ERROR "PyTorch requires cuDNN 8.1 and above.") endif() endif() diff --git a/test/test_nn.py b/test/test_nn.py index 9c421e53dd3d..37c2be3ba82a 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -35,8 +35,7 @@ from torch.testing._internal.common_utils import freeze_rng_state, run_tests, Te IS_PPC, \ parametrize as parametrize_test, subtest, instantiate_parametrized_tests, \ skipIfTorchDynamo, gcIfJetson, set_default_dtype -from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, TEST_CUDNN_VERSION, \ - PLATFORM_SUPPORTS_FLASH_ATTENTION +from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, PLATFORM_SUPPORTS_FLASH_ATTENTION from torch.testing._internal.common_nn import NNTestCase, NewModuleTest, CriterionTest, \ module_tests, criterion_tests, loss_reference_fns, _create_basic_net, \ ctcloss_reference, new_module_tests, single_batch_reference_fn, _test_bfloat16_ops, _test_module_empty_input @@ -4147,7 +4146,7 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""") dtype = torch.double self._test_RNN_cpu_vs_cudnn(0, dtype) - @unittest.skipIf(not (TEST_CUDNN and (TEST_CUDNN_VERSION if TEST_CUDNN_VERSION else 0) >= 5103), "needs cudnn >= 5.1") + @unittest.skipIf(not TEST_CUDNN, "needs cudnn") def test_RNN_cpu_vs_cudnn_with_dropout(self): # Because of dropout randomness, can only compare dropout=0 and dropout=1 self._test_RNN_cpu_vs_cudnn(1) @@ -4206,8 +4205,7 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""") warnings.simplefilter("always") self.assertEqual(m(inp)[0].cpu(), out_expected[0]) - - @unittest.skipIf(not (TEST_CUDNN and (TEST_CUDNN_VERSION if TEST_CUDNN_VERSION else 0) >= 5103), "needs cudnn >= 5.1") + @unittest.skipIf(not TEST_CUDNN, "needs cudnn") @set_default_dtype(torch.double) def test_RNN_dropout(self): # checking the assumption that cuDNN sticks dropout in between @@ -4251,6 +4249,7 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""") self.assertEqual(hy.data[0][0][0], 10) self.assertEqual(hy.data[1][0][0], output_val) + @unittest.skipIf(not TEST_CUDNN, "needs cudnn") @set_default_dtype(torch.double) def test_error_RNN_seq_len_zero(self): # checking error message when RNN has seq_len = 0 @@ -4279,7 +4278,7 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""") # Check that backward does not cause a hard error outs[0].sum().backward() - @unittest.skipIf(not (TEST_CUDNN and (TEST_CUDNN_VERSION if TEST_CUDNN_VERSION else 0) >= 5103), "needs cudnn >= 5.1") + @unittest.skipIf(not TEST_CUDNN, "needs cudnn") def test_RNN_dropout_state(self): for p in (0, 0.1234): for train in (True, False): @@ -4319,7 +4318,7 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""") self.assertNotEqual(hy1, hy2) self.assertNotEqual(hy1, hy3) - @unittest.skipIf(not (TEST_CUDNN and (TEST_CUDNN_VERSION if TEST_CUDNN_VERSION else 0) >= 5103), "needs cudnn >= 5.1") + @unittest.skipIf(not TEST_CUDNN, "needs cudnn") @set_default_dtype(torch.double) def test_RNN_change_dropout(self): for train, cuda in product((True, False), repeat=2):
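
Note on the CUDNN_VERSION_MIN change in caffe2/core/common_cudnn.h: cuDNN 9 switched the packed CUDNN_VERSION constant from major*1000 + minor*100 + patch to major*10000 + minor*100 + patch, so the rewritten macro builds the comparison threshold differently depending on the requested major version. A minimal standalone sketch of the same check as a constexpr helper (the name cudnn_version_min is illustrative, not part of this patch):

#include <cudnn.h>  // defines CUDNN_VERSION

// Sketch only: mirrors the behaviour of the updated CUDNN_VERSION_MIN macro.
constexpr bool cudnn_version_min(long major, long minor, long patch) {
  // cuDNN < 9 packs versions as major*1000 + minor*100 + patch (8.2.4 -> 8204);
  // cuDNN >= 9 packs them as major*10000 + minor*100 + patch (9.1.0 -> 90100),
  // so the threshold must be encoded the same way as the header it is compared to.
  const long threshold = (major >= 9)
      ? major * 10000 + minor * 100 + patch
      : major * 1000 + minor * 100 + patch;
  return CUDNN_VERSION >= threshold;
}

static_assert(cudnn_version_min(8, 2, 0),
              "Caffe2 requires cudnn version 8.2 or above.");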
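
The ConvolutionDescriptor hunk in aten/src/ATen/cudnn/Descriptors.h drops the #if CUDNN_VERSION >= 8000 guard around CUDNN_FMA_MATH because that enum exists on every cuDNN version the build now accepts. The resulting policy is: half-precision inputs opt into tensor cores, float inputs force plain FMA kernels when TF32 is disallowed, and everything else keeps the cuDNN default. A rough sketch of that decision as a free function (pick_conv_math_type is a hypothetical name, not code from this patch):

#include <cudnn.h>

// Sketch of the math-type policy the patched descriptor code applies.
cudnnMathType_t pick_conv_math_type(cudnnDataType_t data_type, bool allow_tf32) {
  if (data_type == CUDNN_DATA_HALF) {
    return CUDNN_TENSOR_OP_MATH;   // FP16: always allow tensor cores
  }
  if (data_type == CUDNN_DATA_FLOAT && !allow_tf32) {
    return CUDNN_FMA_MATH;         // FP32 with TF32 disabled: plain FMA kernels only
  }
  return CUDNN_DEFAULT_MATH;       // otherwise keep the default (TF32 permitted)
}

// Usage, roughly matching ConvolutionDescriptor::set():
//   AT_CUDNN_CHECK(cudnnSetConvolutionMathType(
//       mut_desc(), pick_conv_math_type(dataType, allow_tf32)));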
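
Most of the bulk deleted from conv_op_cudnn.cc and conv_transpose_op_cudnn.cc is the pre-cuDNN-7 emulation of grouped convolution, which sliced the input, output, and filter by group_offset_* and looped the cuDNN call once per group. From cuDNN 7 onward the group count is a property of the convolution descriptor and a single call covers all groups, which is the path this patch keeps. A condensed sketch of that retained path (identifiers below are illustrative, not lifted from the file; status checks are omitted for brevity):

#include <cudnn.h>

// Sketch: native grouped convolution on cuDNN >= 7. The tensor and filter
// descriptors describe the full, ungrouped shapes; cuDNN partitions them
// internally once the group count is set on the convolution descriptor.
void run_grouped_conv_fwd(
    cudnnHandle_t handle,
    cudnnTensorDescriptor_t x_desc, const float* x,
    cudnnFilterDescriptor_t w_desc, const float* w,
    cudnnConvolutionDescriptor_t conv_desc,
    cudnnConvolutionFwdAlgo_t algo,
    void* workspace, size_t workspace_bytes,
    cudnnTensorDescriptor_t y_desc, float* y,
    int groups, bool enable_tensor_core) {
  if (enable_tensor_core) {
    cudnnSetConvolutionMathType(conv_desc, CUDNN_TENSOR_OP_MATH);
  }
  // This one descriptor setting replaces the old per-group loop and the
  // group_offset_X / group_offset_Y / group_offset_filter bookkeeping.
  cudnnSetConvolutionGroupCount(conv_desc, groups);

  const float one = 1.0f, zero = 0.0f;
  cudnnConvolutionForward(
      handle, &one,
      x_desc, x,
      w_desc, w,
      conv_desc, algo,
      workspace, workspace_bytes,
      &zero,
      y_desc, y);
}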