From ba06951c661d7c4628a742e754106fa0dc1c59b9 Mon Sep 17 00:00:00 2001 From: Eddie Yan Date: Wed, 3 Jan 2024 15:41:28 +0000 Subject: [PATCH] [BE] [cuDNN] Always build assuming cuDNN >= 8.1 (#95722) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### 🤖 Generated by Copilot at 27084ed This pull request simplifies and cleans up the code that uses the cuDNN library for convolution, batch normalization, CTC loss, and quantized operations. It removes the unnecessary checks and conditions for older cuDNN versions and the experimental cuDNN v8 API, and ~~replaces them with the stable `cudnn_frontend` API that requires cuDNN v8 or higher. It also adds the dependency and configuration for the `cudnn_frontend` library in the cmake and bazel files.~~ Correction: The v7 API will still be available with this PR, and can still be used, without any changes to the defaults. This change simply always _builds_ the v8 API, and removes the case where _only_ the v7 API is built. This is a re-land of https://github.com/pytorch/pytorch/pull/91527 Pull Request resolved: https://github.com/pytorch/pytorch/pull/95722 Approved by: https://github.com/malfet, https://github.com/atalman --- BUILD.bazel | 1 + CMakeLists.txt | 3 -- WORKSPACE | 6 +++ aten/src/ATen/cudnn/Descriptors.h | 2 - aten/src/ATen/native/cudnn/BatchNorm.cpp | 35 -------------- aten/src/ATen/native/cudnn/ConvShared.h | 3 -- aten/src/ATen/native/cudnn/Conv_v7.cpp | 47 ------------------- aten/src/ATen/native/cudnn/Conv_v8.cpp | 5 -- aten/src/ATen/native/cudnn/LossCTC.cpp | 2 +- aten/src/ATen/native/cudnn/Macros.h | 12 ----- .../native/quantized/cpu/fbgemm_utils.cpp | 5 +- .../ATen/native/quantized/cudnn/BinaryOps.cpp | 3 -- aten/src/ATen/native/quantized/cudnn/Conv.cpp | 12 +++-- .../native/quantized/cudnn/ConvPrepack.cpp | 13 +++-- .../native/quantized/cudnn/ConvUnpackImpl.cpp | 5 -- .../ATen/native/quantized/cudnn/Linear.cpp | 7 ++- .../native/quantized/cudnn/LinearPrepack.cpp | 8 ++-- .../quantized/cudnn/LinearUnpackImpl.cpp | 5 -- .../ATen/native/quantized/cudnn/Pooling.cpp | 7 --- aten/src/ATen/native/quantized/cudnn/utils.h | 5 -- caffe2/CMakeLists.txt | 6 --- third_party/cudnn_frontend.BUILD | 22 +++++++++ torch/CMakeLists.txt | 4 -- torch/csrc/cuda/Module.cpp | 9 ---- 24 files changed, 55 insertions(+), 172 deletions(-) delete mode 100644 aten/src/ATen/native/cudnn/Macros.h create mode 100644 third_party/cudnn_frontend.BUILD diff --git a/BUILD.bazel b/BUILD.bazel index 59d2ea857a14..0afee2d8d71c 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -413,6 +413,7 @@ cc_library( "@cuda//:cusolver", "@cuda//:nvrtc", "@cudnn", + "@cudnn_frontend", ], alwayslink = True, ) diff --git a/CMakeLists.txt b/CMakeLists.txt index 9194e520bb00..9ea0682b4964 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -213,9 +213,6 @@ cmake_dependent_option( cmake_dependent_option( USE_CUSPARSELT "Use cuSPARSELt" ON "USE_CUDA" OFF) -cmake_dependent_option( - USE_EXPERIMENTAL_CUDNN_V8_API "Use experimental cuDNN v8 API" ON - "USE_CUDNN" OFF) option(USE_FBGEMM "Use FBGEMM (quantized 8-bit server operators)" ON) option(USE_KINETO "Use Kineto profiling library" ON) option(USE_CUPTI_SO "Use CUPTI as a shared library" ON) diff --git a/WORKSPACE b/WORKSPACE index 412ba2fbf04d..b187949d663e 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -246,6 +246,12 @@ new_local_repository( path = "/usr/local/cuda", ) +new_local_repository( + name = "cudnn_frontend", + build_file = "@//third_party:cudnn_frontend.BUILD", + path = "third_party/cudnn_frontend/", +) 
+ local_repository( name = "com_github_google_flatbuffers", path = "third_party/flatbuffers", diff --git a/aten/src/ATen/cudnn/Descriptors.h b/aten/src/ATen/cudnn/Descriptors.h index f2b8e67a0dc1..694e93216b7a 100644 --- a/aten/src/ATen/cudnn/Descriptors.h +++ b/aten/src/ATen/cudnn/Descriptors.h @@ -346,7 +346,6 @@ struct TORCH_CUDA_CPP_API CTCLossDescriptor void set(cudnnDataType_t datatype) { AT_CUDNN_CHECK(cudnnSetCTCLossDescriptor(mut_desc(), datatype)); } -#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 7600 void setEx( cudnnDataType_t datatype, cudnnLossNormalizationMode_t normMode, @@ -354,7 +353,6 @@ struct TORCH_CUDA_CPP_API CTCLossDescriptor AT_CUDNN_CHECK( cudnnSetCTCLossDescriptorEx(mut_desc(), datatype, normMode, gradMode)); } -#endif }; struct TORCH_CUDA_CPP_API ActivationDescriptor diff --git a/aten/src/ATen/native/cudnn/BatchNorm.cpp b/aten/src/ATen/native/cudnn/BatchNorm.cpp index 58d6046c608a..f18318fd0dcf 100644 --- a/aten/src/ATen/native/cudnn/BatchNorm.cpp +++ b/aten/src/ATen/native/cudnn/BatchNorm.cpp @@ -59,11 +59,7 @@ cudnnBatchNormMode_t getCudnnBatchNormMode(bool training, at::MemoryFormat memor return CUDNN_BATCHNORM_PER_ACTIVATION; } else if (training && memory_format == at::MemoryFormat::ChannelsLast) { -#if CUDNN_VERSION >= 7400 return CUDNN_BATCHNORM_SPATIAL_PERSISTENT; -#else - return CUDNN_BATCHNORM_SPATIAL; -#endif // CUDNN_VERSION >= 7400 } else if (training && memory_format == at::MemoryFormat::ChannelsLast3d) { @@ -152,7 +148,6 @@ std::tuple cudnn_batch_norm( save_mean = at::empty({ num_features }, weight_t.options()); save_var = at::empty({ num_features }, weight_t.options()); -#if CUDNN_VERSION >= 7400 auto op = CUDNN_BATCHNORM_OPS_BN; size_t workspace_size; AT_CUDNN_CHECK(cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize( @@ -204,22 +199,6 @@ std::tuple cudnn_batch_norm( workspace_size, reserve.mutable_data_ptr(), reserve_size)); -#else - reserve = at::empty({0}, input->options().dtype(kByte)); - AT_CUDNN_CHECK(cudnnBatchNormalizationForwardTraining( - handle, mode, &one, &zero, - idesc.desc(), input->data_ptr(), - idesc.desc(), output->data_ptr(), - wdesc.desc(), - weight->data_ptr(), - bias->data_ptr(), - exponential_average_factor, - at::maybe_data_ptr(running_mean), - at::maybe_data_ptr(running_var), - epsilon, - save_mean.mutable_data_ptr(), - save_var.mutable_data_ptr())); -#endif // CUDNN_VERSION >= 7400 } else { reserve = at::empty({0}, input->options().dtype(kByte)); // This keeps a consistent output with native_batch_norm @@ -317,7 +296,6 @@ std::tuple cudnn_batch_norm_backward( Constant one(dataType, 1); Constant zero(dataType, 0); -#if CUDNN_VERSION >= 7400 auto op = CUDNN_BATCHNORM_OPS_BN; size_t workspace_size; @@ -354,19 +332,6 @@ std::tuple cudnn_batch_norm_backward( workspace_size, reserve->data_ptr(), reserve->numel())); -#else - AT_CUDNN_CHECK(cudnnBatchNormalizationBackward( - handle, mode, &one, &zero, &one, &zero, - idesc.desc(), input->data_ptr(), - odesc.desc(), grad_output->data_ptr(), - idesc.desc(), grad_input_t.data_ptr(), - wdesc.desc(), weight->data_ptr(), - grad_weight_t.data_ptr(), - grad_bias_t.data_ptr(), - epsilon, - save_mean->data_ptr(), - save_var->data_ptr())); -#endif // CUDNN_VERSION >= 7400 return std::tuple{grad_input_t, grad_weight_t, grad_bias_t}; } diff --git a/aten/src/ATen/native/cudnn/ConvShared.h b/aten/src/ATen/native/cudnn/ConvShared.h index fa06d0940471..89986adadac1 100644 --- a/aten/src/ATen/native/cudnn/ConvShared.h +++ b/aten/src/ATen/native/cudnn/ConvShared.h @@ -109,9 +109,7 @@ void 
raw_cudnn_convolution_add_relu_fallback_out( #if AT_CUDNN_ENABLED() -#include -#if HAS_CUDNN_V8() // v7 functions are preserved here to allow for runtime switching to v7 // (e.g., TORCH_CUDNN_V8_API_DISABLED=1). // Note that v7 forward/backward out can have different behavior from the v8 @@ -149,5 +147,4 @@ void raw_cudnn_convolution_add_relu_out_v7( bool deterministic, bool allow_tf32); #endif -#endif }} diff --git a/aten/src/ATen/native/cudnn/Conv_v7.cpp b/aten/src/ATen/native/cudnn/Conv_v7.cpp index 5155d4449025..ef3a70a2232f 100644 --- a/aten/src/ATen/native/cudnn/Conv_v7.cpp +++ b/aten/src/ATen/native/cudnn/Conv_v7.cpp @@ -3,7 +3,6 @@ #if AT_CUDNN_ENABLED() -#include #include #ifndef AT_PER_OPERATOR_HEADERS @@ -60,10 +59,6 @@ // with the best algo, under the hood, cudnn will run with the slower kernel // since it sees fastest algorithm combination with a sub optimal mathType. -// Note [blocklist fft algorithms for strided dgrad] -// This is a workaround for a CuDNN bug that gave wrong results in certain strided convolution -// gradient setups. Check Issue #16610 for bug details. Bug is there for CUDNN version < 7.5 . - constexpr size_t operator "" _TiB(unsigned long long n) { return size_t(n) * 1024 * 1024 * 1024 * 1024; } @@ -225,15 +220,6 @@ size_t getMaxWorkspaceSize( template std::vector getValidAlgorithms(perf_t *perfResults, const ConvolutionArgs& args, int n_algo) { -// See Note [blocklist fft algorithms for strided dgrad] -#if CUDNN_VERSION < 7500 - bool blocklist = std::is_same::value; - int stride_dim = args.input.dim() - 2; - blocklist &= std::any_of(std::begin(args.params.stride), - std::begin(args.params.stride) + stride_dim, - [=](int n){return n != 1;}); -#endif - std::vector result; result.reserve(n_algo); for (const auto i : c10::irange(n_algo)) { @@ -244,16 +230,6 @@ std::vector getValidAlgorithms(perf_t *perfResults, const ConvolutionArg if (perf.status == CUDNN_STATUS_SUCCESS) { if (!args.params.deterministic || perf.determinism == CUDNN_DETERMINISTIC) { - // See Note [blocklist fft algorithms for strided dgrad] -#if CUDNN_VERSION < 7500 - bool skip = blocklist; - skip &= (static_cast(perfResults[i].algo) == CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING || - static_cast(perfResults[i].algo) == CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT); - if (skip) { - continue; - } -#endif - result.push_back(perf); } } @@ -493,11 +469,9 @@ public: perfResults[0].mathType = CUDNN_TENSOR_OP_MATH; } else { perfResults[0].mathType = CUDNN_DEFAULT_MATH; -#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 8000 if (args.params.dataType == CUDNN_DATA_FLOAT && !args.params.allow_tf32) { perfResults[0].mathType = CUDNN_FMA_MATH; } -#endif } search::getWorkspaceSize(args, perfResults[0].algo, &(perfResults[0].memory)); return perfResults; @@ -610,14 +584,10 @@ static inline void split_batch_dim_to_32bit_out( } -#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 8000 #define ASSERT_CORRECT_PRECISION(math_type) \ if (args.params.dataType == CUDNN_DATA_FLOAT) { \ TORCH_INTERNAL_ASSERT(args.params.allow_tf32 || math_type == CUDNN_FMA_MATH); \ } -#else -#define ASSERT_CORRECT_PRECISION(math_type) -#endif // CUDNN_VERSION >= 8000 // --------------------------------------------------------------------- @@ -672,11 +642,7 @@ void raw_cudnn_convolution_forward_out_32bit( } -#if !HAS_CUDNN_V8() -void raw_cudnn_convolution_forward_out( -#else void raw_cudnn_convolution_forward_out_v7( -#endif const Tensor& output, const Tensor& input, const Tensor& weight, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, 
int64_t groups, bool benchmark, bool deterministic, bool allow_tf32) { @@ -734,11 +700,7 @@ void raw_cudnn_convolution_backward_input_out_32bit( ); } -#if !HAS_CUDNN_V8() -void raw_cudnn_convolution_backward_input_out( -#else void raw_cudnn_convolution_backward_input_out_v7( -#endif const at::Tensor& grad_input, const at::Tensor& grad_output, const at::Tensor& weight, @@ -797,11 +759,7 @@ void raw_cudnn_convolution_backward_weight_out_32bit( ); } -#if !HAS_CUDNN_V8() -void raw_cudnn_convolution_backward_weight_out( -#else void raw_cudnn_convolution_backward_weight_out_v7( -#endif const Tensor& grad_weight, const Tensor& grad_output, const Tensor& input, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic, bool allow_tf32) { @@ -853,12 +811,7 @@ void raw_cudnn_convolution_backward_weight_out_v7( TORCH_INTERNAL_ASSERT(false, "This case should not be dispatched to cuDNN."); } -#if !HAS_CUDNN_V8() -void raw_cudnn_convolution_add_relu_out( -#else void raw_cudnn_convolution_add_relu_out_v7( -#endif - const Tensor& output, const Tensor& input, const Tensor& weight, diff --git a/aten/src/ATen/native/cudnn/Conv_v8.cpp b/aten/src/ATen/native/cudnn/Conv_v8.cpp index d0f9d0a5427f..aa582fc19e14 100644 --- a/aten/src/ATen/native/cudnn/Conv_v8.cpp +++ b/aten/src/ATen/native/cudnn/Conv_v8.cpp @@ -4,10 +4,6 @@ #if AT_CUDNN_ENABLED() -#include - -#if HAS_CUDNN_V8() - #include #include @@ -780,5 +776,4 @@ void raw_cudnn_convolution_add_relu_out( }} // at::native -#endif // HAS_CUDNN_V8 #endif // AT_CUDNN_ENABLED diff --git a/aten/src/ATen/native/cudnn/LossCTC.cpp b/aten/src/ATen/native/cudnn/LossCTC.cpp index 96b31175f1ae..cb08b57c309c 100644 --- a/aten/src/ATen/native/cudnn/LossCTC.cpp +++ b/aten/src/ATen/native/cudnn/LossCTC.cpp @@ -18,7 +18,7 @@ #include #endif -#if (!AT_CUDNN_ENABLED()) || (CUDNN_VERSION < 7600) +#if (!AT_CUDNN_ENABLED()) namespace at { namespace native { diff --git a/aten/src/ATen/native/cudnn/Macros.h b/aten/src/ATen/native/cudnn/Macros.h deleted file mode 100644 index 941d77233b92..000000000000 --- a/aten/src/ATen/native/cudnn/Macros.h +++ /dev/null @@ -1,12 +0,0 @@ -#pragma once - -#include - -// Note: The version below should not actually be 8000. Instead, it should -// be whatever version of cuDNN that v8 API work with PyTorch correctly. -// The version is set to 8000 today for convenience of debugging. 
-#if defined(USE_EXPERIMENTAL_CUDNN_V8_API) && defined(CUDNN_VERSION) && CUDNN_VERSION >= 8200 -#define HAS_CUDNN_V8() true -#else -#define HAS_CUDNN_V8() false -#endif diff --git a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp index c2d363045563..2d15e54c4052 100644 --- a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp +++ b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp @@ -28,7 +28,6 @@ #include #endif -int register_linear_params(); int register_embedding_params(); #ifdef USE_FBGEMM @@ -437,7 +436,9 @@ TORCH_API int register_conv_params<2>(); template TORCH_API int register_conv_params<3>(); -int register_linear_params() { +TORCH_API int register_linear_params(); + +TORCH_API int register_linear_params() { using SerializationType = std::tuple>; static auto register_linear_params = torch::selective_class_( diff --git a/aten/src/ATen/native/quantized/cudnn/BinaryOps.cpp b/aten/src/ATen/native/quantized/cudnn/BinaryOps.cpp index fbb46b4b0174..a225a86eeb90 100644 --- a/aten/src/ATen/native/quantized/cudnn/BinaryOps.cpp +++ b/aten/src/ATen/native/quantized/cudnn/BinaryOps.cpp @@ -2,8 +2,6 @@ #include // for the definition of AT_CUDNN_ENABLED #if AT_CUDNN_ENABLED() -#include -#if HAS_CUDNN_V8() #include #include @@ -259,6 +257,5 @@ TORCH_LIBRARY_IMPL(quantized, QuantizedCUDA, m) { } // namespace native } // namespace at -#endif // HAS_CUDNN_V8 #endif // AT_CUDNN_ENABLED #endif // USE_CUDA diff --git a/aten/src/ATen/native/quantized/cudnn/Conv.cpp b/aten/src/ATen/native/quantized/cudnn/Conv.cpp index 8d7d940f8986..4cb496640746 100644 --- a/aten/src/ATen/native/quantized/cudnn/Conv.cpp +++ b/aten/src/ATen/native/quantized/cudnn/Conv.cpp @@ -3,11 +3,8 @@ #if AT_CUDNN_ENABLED() -#include #include -#if HAS_CUDNN_V8() - #include #include #include @@ -25,6 +22,12 @@ #include #include +template +int register_conv_params(); + +extern template int register_conv_params<2>(); +extern template int register_conv_params<3>(); + // TODO: there is a table from input dtype and weight dtype to operator qdtype, // we can derive the operator dtype based on input dtype cudnn_frontend::ConvDesc_v8 getConvDescriptor(cudnnDataType_t dataType, c10::IntArrayRef padding, c10::IntArrayRef stride, c10::IntArrayRef dilation) { @@ -391,6 +394,8 @@ TORCH_LIBRARY_IMPL(quantized, QuantizedCUDA, m) { // this is inconsistent with what has been done for conv2d where new variants use packed weights, and // old variant does not. 
we adopt this inconsistency for now to be consistent with QuantizedCPU's conv1d // and will eventually deprecate the old variants + register_conv_params<2>(); + register_conv_params<3>(); m.impl(TORCH_SELECTIVE_NAME("quantized::conv1d"), QConv1dInt8::run); m.impl(TORCH_SELECTIVE_NAME("quantized::conv1d_relu"), QConv1dInt8::run); m.impl(TORCH_SELECTIVE_NAME("quantized::conv2d.new"), QConvInt8<2, false>::run); @@ -401,6 +406,5 @@ TORCH_LIBRARY_IMPL(quantized, QuantizedCUDA, m) { } // namespace at::native -#endif // HAS_CUDNN_V8 #endif // AT_CUDNN_ENABLED #endif // USE_CUDA diff --git a/aten/src/ATen/native/quantized/cudnn/ConvPrepack.cpp b/aten/src/ATen/native/quantized/cudnn/ConvPrepack.cpp index 4269f8738640..44d37f27bf6f 100644 --- a/aten/src/ATen/native/quantized/cudnn/ConvPrepack.cpp +++ b/aten/src/ATen/native/quantized/cudnn/ConvPrepack.cpp @@ -3,10 +3,6 @@ #if AT_CUDNN_ENABLED() -#include - -#if HAS_CUDNN_V8() - #include #include #include @@ -20,6 +16,12 @@ #include #include +template +int register_conv_params(); + +extern template int register_conv_params<2>(); +extern template int register_conv_params<3>(); + template c10::intrusive_ptr> PackedConvWeightCudnn< kSpatialDim>:: @@ -203,6 +205,8 @@ class QConv1dPackWeightInt8Cudnn final { }; TORCH_LIBRARY_IMPL(quantized, QuantizedCUDA, m) { + register_conv_params<2>(); + register_conv_params<3>(); m.impl(TORCH_SELECTIVE_NAME("quantized::conv1d_prepack"), TORCH_FN(QConv1dPackWeightInt8Cudnn::run_conv)); m.impl(TORCH_SELECTIVE_NAME("quantized::conv2d_prepack"), TORCH_FN(QConvPackWeightInt8Cudnn<2>::run_conv)); } @@ -211,6 +215,5 @@ TORCH_LIBRARY_IMPL(quantized, QuantizedCUDA, m) { } // namespace native } // namespace at -#endif // HAS_CUDNN_V8 #endif // AT_CUDNN_ENABLED #endif // USE_CUDA diff --git a/aten/src/ATen/native/quantized/cudnn/ConvUnpackImpl.cpp b/aten/src/ATen/native/quantized/cudnn/ConvUnpackImpl.cpp index 0b2cae06bca2..ce5ee36cad4f 100644 --- a/aten/src/ATen/native/quantized/cudnn/ConvUnpackImpl.cpp +++ b/aten/src/ATen/native/quantized/cudnn/ConvUnpackImpl.cpp @@ -3,10 +3,6 @@ #if AT_CUDNN_ENABLED() -#include - -#if HAS_CUDNN_V8() - #include #include #include @@ -23,6 +19,5 @@ std::tuple> PackedConvWeightCudnn< template std::tuple> PackedConvWeightCudnn< 2>::unpack(); -#endif // HAS_CUDNN_V8 #endif // AT_CUDNN_ENABLED #endif // USE_CUDA diff --git a/aten/src/ATen/native/quantized/cudnn/Linear.cpp b/aten/src/ATen/native/quantized/cudnn/Linear.cpp index 6cef8e10ed0f..37e08ba7861d 100644 --- a/aten/src/ATen/native/quantized/cudnn/Linear.cpp +++ b/aten/src/ATen/native/quantized/cudnn/Linear.cpp @@ -3,11 +3,8 @@ #if AT_CUDNN_ENABLED() -#include #include -#if HAS_CUDNN_V8() - #include #include #include @@ -25,6 +22,8 @@ #include #include +int register_linear_params(); + // TODO: there is a table from input dtype and weight dtype to operator dtype, // we can derive the operator dtype based on input dtype cudnn_frontend::MatMulDesc_v8 getLinearDescriptor(cudnnDataType_t dataType) { @@ -358,6 +357,7 @@ class QLinearInt8 final { }; TORCH_LIBRARY_IMPL(quantized, QuantizedCUDA, m) { + register_linear_params(); m.impl(TORCH_SELECTIVE_NAME("quantized::linear"), QLinearInt8::run); m.impl(TORCH_SELECTIVE_NAME("quantized::linear_relu"), QLinearInt8::run); } @@ -367,6 +367,5 @@ TORCH_LIBRARY_IMPL(quantized, QuantizedCUDA, m) { } // namespace at -#endif // HAS_CUDNN_V8 #endif // AT_CUDNN_ENABLED #endif // USE_CUDA diff --git a/aten/src/ATen/native/quantized/cudnn/LinearPrepack.cpp b/aten/src/ATen/native/quantized/cudnn/LinearPrepack.cpp index 
212bab93d0d7..abbb5922f393 100644 --- a/aten/src/ATen/native/quantized/cudnn/LinearPrepack.cpp +++ b/aten/src/ATen/native/quantized/cudnn/LinearPrepack.cpp @@ -3,10 +3,6 @@ #if AT_CUDNN_ENABLED() -#include - -#if HAS_CUDNN_V8() - #include #include #include @@ -16,6 +12,8 @@ #include #include +int register_linear_params(); + c10::intrusive_ptr PackedLinearWeightCudnn::prepack( at::Tensor weight, c10::optional bias) { @@ -50,6 +48,7 @@ class QLinearPackWeightInt8Cudnn final { }; TORCH_LIBRARY_IMPL(quantized, QuantizedCUDA, m) { + register_linear_params(); m.impl(TORCH_SELECTIVE_NAME("quantized::linear_prepack"), TORCH_FN(QLinearPackWeightInt8Cudnn::run)); } @@ -58,6 +57,5 @@ TORCH_LIBRARY_IMPL(quantized, QuantizedCUDA, m) { } // namespace native } // namespace at -#endif // HAS_CUDNN_V8 #endif // AT_CUDNN_ENABLED #endif // USE_CUDA diff --git a/aten/src/ATen/native/quantized/cudnn/LinearUnpackImpl.cpp b/aten/src/ATen/native/quantized/cudnn/LinearUnpackImpl.cpp index 2212d054b8c7..7200872480ef 100644 --- a/aten/src/ATen/native/quantized/cudnn/LinearUnpackImpl.cpp +++ b/aten/src/ATen/native/quantized/cudnn/LinearUnpackImpl.cpp @@ -3,10 +3,6 @@ #if AT_CUDNN_ENABLED() -#include - -#if HAS_CUDNN_V8() - #include #include #include @@ -18,6 +14,5 @@ std::tuple> PackedLinearWeightCudnn::unpac return std::tuple>{orig_weight, bias_}; } -#endif // HAS_CUDNN_V8 #endif // AT_CUDNN_ENABLED #endif // USE_CUDA diff --git a/aten/src/ATen/native/quantized/cudnn/Pooling.cpp b/aten/src/ATen/native/quantized/cudnn/Pooling.cpp index 711afad775df..44d4251fbc62 100644 --- a/aten/src/ATen/native/quantized/cudnn/Pooling.cpp +++ b/aten/src/ATen/native/quantized/cudnn/Pooling.cpp @@ -3,7 +3,6 @@ #include // for the definition of AT_CUDNN_ENABLED #if AT_CUDNN_ENABLED() -#include #include #include #include @@ -54,7 +53,6 @@ Tensor adaptive_avg_pool2d_quantized_cuda( // TODO: renable these cudnn preprocessors like quantized_max_pool2d_cudnn below when we implement this function with cudnn #ifdef USE_CUDA // #if AT_CUDNN_ENABLED() -// #if HAS_CUDNN_V8() // TODO: limit this to per tensor quantized tensors for now, though should be easy to adapt // to per channel quantized tensors TORCH_CHECK(input.qscheme() == at::kPerTensorAffine, "adaptive_avg_pool2d_quantized_cuda oonly supports per tensor quantized tensors"); @@ -91,7 +89,6 @@ Tensor quantized_max_pool2d_cudnn( bool ceil_mode) { #ifdef USE_CUDA #if AT_CUDNN_ENABLED() -#if HAS_CUDNN_V8() check_maxpool2d_params( kernel_size, stride, @@ -207,10 +204,6 @@ Tensor quantized_max_pool2d_cudnn( // recall we casted our input and output to 4D if qx was 3D, so we recast it back to 3D prior to returning return (ndim == 3 ? 
qy.view(std::vector(output_shape.begin() + 1, output_shape.end())) : qy); -#else // HAS_CUDNN_V8() - AT_ERROR("at::native::quantized_max_pool2d_cudnn: ATen not compiled with cuDNN v8 support"); - return Tensor{}; // never reached, placates the compiler -#endif // HAS_CUDNN_V8() #else // AT_CUDNN_ENABLED() AT_ERROR("at::native::quantized_max_pool2d_cudnn: ATen not compiled with cuDNN support"); return Tensor{}; // never reached, placates the compiler diff --git a/aten/src/ATen/native/quantized/cudnn/utils.h b/aten/src/ATen/native/quantized/cudnn/utils.h index 3f2939fa5dab..18c891fcaa1c 100644 --- a/aten/src/ATen/native/quantized/cudnn/utils.h +++ b/aten/src/ATen/native/quantized/cudnn/utils.h @@ -8,10 +8,6 @@ This file contains some of the auxiliary functions used by both Conv.cpp & Linea #if AT_CUDNN_ENABLED() -#include - -#if HAS_CUDNN_V8() - #include #include #include @@ -354,6 +350,5 @@ cudnn_frontend::ExecutionPlan get_execplan_from_heuristics_else_fall_back(cudnn_ } // anonymous } // cudnn_utils -#endif // HAS_CUDNN_V8 #endif // AT_CUDNN_ENABLED #endif // USE_CUDA diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index b035e50f81ae..72d40564fa40 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -1414,12 +1414,6 @@ elseif(USE_ROCM) target_compile_definitions(torch_hip PRIVATE TORCH_HIP_BUILD_MAIN_LIB) endif() -if(USE_EXPERIMENTAL_CUDNN_V8_API) - if(USE_CUDA) - target_compile_definitions(torch_cuda PRIVATE "-DUSE_EXPERIMENTAL_CUDNN_V8_API") - endif() -endif() - set(EXPERIMENTAL_SINGLE_THREAD_POOL "0" CACHE STRING "Experimental option to use a single thread pool for inter- and intra-op parallelism") if("${EXPERIMENTAL_SINGLE_THREAD_POOL}") diff --git a/third_party/cudnn_frontend.BUILD b/third_party/cudnn_frontend.BUILD new file mode 100644 index 000000000000..0af69797a3e2 --- /dev/null +++ b/third_party/cudnn_frontend.BUILD @@ -0,0 +1,22 @@ +# Adopted from: https://github.com/tensorflow/tensorflow/blob/master/third_party/cudnn_frontend.BUILD + +# Description: +# The cuDNN Frontend API is a C++ header-only library that demonstrates how +# to use the cuDNN C backend API. 
+ +load("@rules_cc//cc:defs.bzl", "cc_library") + +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) # MIT + +exports_files(["LICENSE.txt"]) + +cc_library( + name = "cudnn_frontend", + hdrs = glob(["include/**"]), + includes = ["include/"], + include_prefix = "third_party/cudnn_frontend", +) diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt index b0d7bd842d33..24903a207ecc 100644 --- a/torch/CMakeLists.txt +++ b/torch/CMakeLists.txt @@ -145,10 +145,6 @@ if(USE_ROCM) list(APPEND TORCH_PYTHON_LINK_LIBRARIES ${ROCM_ROCTX_LIB}) endif() -if(USE_EXPERIMENTAL_CUDNN_V8_API) - list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_EXPERIMENTAL_CUDNN_V8_API) -endif() - if(USE_CUDNN OR USE_ROCM) list(APPEND TORCH_PYTHON_SRCS ${TORCH_SRC_DIR}/csrc/cuda/shared/cudnn.cpp diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp index 09f8249d8b05..8f8c54e75b52 100644 --- a/torch/csrc/cuda/Module.cpp +++ b/torch/csrc/cuda/Module.cpp @@ -12,8 +12,6 @@ #if AT_CUDNN_ENABLED() -#include - #endif #include #include @@ -1359,15 +1357,8 @@ PyObject* THCPModule_setBenchmarkLimitCuDNN(PyObject* _unused, PyObject* arg) { TORCH_WARN_ONCE( "cuDNN Benchmark limit is not supported in MIOpen and will have no effect."); #endif -#if AT_CUDNN_ENABLED() -#if HAS_CUDNN_V8() auto benchmark_limit = static_cast(THPUtils_unpackLong(arg)); at::globalContext().setBenchmarkLimitCuDNN(benchmark_limit); -#else - TORCH_WARN_ONCE( - "cuDNN Benchmark limit is not supported with cuDNN v7 API and will have no effect."); -#endif -#endif Py_RETURN_NONE; }
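
A hedged usage sketch for the preserved v7 path: the comment this patch adds to `ConvShared.h` notes that the v7 convolution functions are kept "to allow for runtime switching to v7 (e.g., TORCH_CUDNN_V8_API_DISABLED=1)". The snippet below shows one way that switch might be exercised from Python. Only the environment-variable name comes from this patch; the script structure, the shapes, and the assumption that the variable is read from the process environment before the first cuDNN convolution dispatch are illustrative.

```python
# Hedged sketch: opt back into the preserved v7 convolution kernels at runtime.
# TORCH_CUDNN_V8_API_DISABLED is the variable named in the ConvShared.h comment;
# everything else here (shapes, dtype, module) is purely illustrative.
import os

os.environ["TORCH_CUDNN_V8_API_DISABLED"] = "1"  # set before the first cuDNN conv runs

import torch

if torch.cuda.is_available() and torch.backends.cudnn.is_available():
    conv = torch.nn.Conv2d(3, 16, kernel_size=3, padding=1).cuda().half()
    x = torch.randn(8, 3, 32, 32, device="cuda", dtype=torch.half)
    y = conv(x)  # intended to take the raw_cudnn_convolution_*_v7 path per the comment
    print(tuple(y.shape), "cuDNN version:", torch.backends.cudnn.version())
```

Leaving the variable unset keeps the default (v8) path, consistent with the "without any changes to the defaults" correction in the description.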