diff --git a/BUILD.bazel b/BUILD.bazel
index 59d2ea857a14..0afee2d8d71c 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -413,6 +413,7 @@ cc_library(
         "@cuda//:cusolver",
         "@cuda//:nvrtc",
         "@cudnn",
+        "@cudnn_frontend",
     ],
     alwayslink = True,
 )
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 9194e520bb00..9ea0682b4964 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -213,9 +213,6 @@ cmake_dependent_option(
 cmake_dependent_option(
     USE_CUSPARSELT "Use cuSPARSELt" ON
     "USE_CUDA" OFF)
-cmake_dependent_option(
-    USE_EXPERIMENTAL_CUDNN_V8_API "Use experimental cuDNN v8 API" ON
-    "USE_CUDNN" OFF)
 option(USE_FBGEMM "Use FBGEMM (quantized 8-bit server operators)" ON)
 option(USE_KINETO "Use Kineto profiling library" ON)
 option(USE_CUPTI_SO "Use CUPTI as a shared library" ON)
diff --git a/WORKSPACE b/WORKSPACE
index 412ba2fbf04d..b187949d663e 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -246,6 +246,12 @@ new_local_repository(
     path = "/usr/local/cuda",
 )

+new_local_repository(
+    name = "cudnn_frontend",
+    build_file = "@//third_party:cudnn_frontend.BUILD",
+    path = "third_party/cudnn_frontend/",
+)
+
 local_repository(
     name = "com_github_google_flatbuffers",
     path = "third_party/flatbuffers",
diff --git a/aten/src/ATen/cudnn/Descriptors.h b/aten/src/ATen/cudnn/Descriptors.h
index f2b8e67a0dc1..694e93216b7a 100644
--- a/aten/src/ATen/cudnn/Descriptors.h
+++ b/aten/src/ATen/cudnn/Descriptors.h
@@ -346,7 +346,6 @@ struct TORCH_CUDA_CPP_API CTCLossDescriptor
   void set(cudnnDataType_t datatype) {
     AT_CUDNN_CHECK(cudnnSetCTCLossDescriptor(mut_desc(), datatype));
   }
-#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 7600
   void setEx(
       cudnnDataType_t datatype,
       cudnnLossNormalizationMode_t normMode,
@@ -354,7 +353,6 @@ struct TORCH_CUDA_CPP_API CTCLossDescriptor
     AT_CUDNN_CHECK(
         cudnnSetCTCLossDescriptorEx(mut_desc(), datatype, normMode, gradMode));
   }
-#endif
 };

 struct TORCH_CUDA_CPP_API ActivationDescriptor
diff --git a/aten/src/ATen/native/cudnn/BatchNorm.cpp b/aten/src/ATen/native/cudnn/BatchNorm.cpp
index 58d6046c608a..f18318fd0dcf 100644
--- a/aten/src/ATen/native/cudnn/BatchNorm.cpp
+++ b/aten/src/ATen/native/cudnn/BatchNorm.cpp
@@ -59,11 +59,7 @@ cudnnBatchNormMode_t getCudnnBatchNormMode(bool training, at::MemoryFormat memor
     return CUDNN_BATCHNORM_PER_ACTIVATION;
   } else if (training && memory_format == at::MemoryFormat::ChannelsLast) {

-#if CUDNN_VERSION >= 7400
     return CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
-#else
-    return CUDNN_BATCHNORM_SPATIAL;
-#endif // CUDNN_VERSION >= 7400

   } else if (training && memory_format == at::MemoryFormat::ChannelsLast3d) {
@@ -152,7 +148,6 @@ std::tuple cudnn_batch_norm(
     save_mean = at::empty({ num_features }, weight_t.options());
     save_var = at::empty({ num_features }, weight_t.options());

-#if CUDNN_VERSION >= 7400
     auto op = CUDNN_BATCHNORM_OPS_BN;
     size_t workspace_size;
     AT_CUDNN_CHECK(cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(
@@ -204,22 +199,6 @@ std::tuple cudnn_batch_norm(
         workspace_size,
         reserve.mutable_data_ptr(),
         reserve_size));
-#else
-    reserve = at::empty({0}, input->options().dtype(kByte));
-    AT_CUDNN_CHECK(cudnnBatchNormalizationForwardTraining(
-      handle, mode, &one, &zero,
-      idesc.desc(), input->data_ptr(),
-      idesc.desc(), output->data_ptr(),
-      wdesc.desc(),
-      weight->data_ptr(),
-      bias->data_ptr(),
-      exponential_average_factor,
-      at::maybe_data_ptr(running_mean),
-      at::maybe_data_ptr(running_var),
-      epsilon,
-      save_mean.mutable_data_ptr(),
-      save_var.mutable_data_ptr()));
-#endif // CUDNN_VERSION >= 7400
   } else {
     reserve = at::empty({0}, input->options().dtype(kByte));
     // This keeps a consistent output with native_batch_norm
@@ -317,7 +296,6 @@ std::tuple cudnn_batch_norm_backward(
   Constant one(dataType, 1);
   Constant zero(dataType, 0);

-#if CUDNN_VERSION >= 7400
   auto op = CUDNN_BATCHNORM_OPS_BN;

   size_t workspace_size;
@@ -354,19 +332,6 @@ std::tuple cudnn_batch_norm_backward(
       workspace_size,
       reserve->data_ptr(),
       reserve->numel()));
-#else
-  AT_CUDNN_CHECK(cudnnBatchNormalizationBackward(
-    handle, mode, &one, &zero, &one, &zero,
-    idesc.desc(), input->data_ptr(),
-    odesc.desc(), grad_output->data_ptr(),
-    idesc.desc(), grad_input_t.data_ptr(),
-    wdesc.desc(), weight->data_ptr(),
-    grad_weight_t.data_ptr(),
-    grad_bias_t.data_ptr(),
-    epsilon,
-    save_mean->data_ptr(),
-    save_var->data_ptr()));
-#endif // CUDNN_VERSION >= 7400

   return std::tuple{grad_input_t, grad_weight_t, grad_bias_t};
 }
diff --git a/aten/src/ATen/native/cudnn/ConvShared.h b/aten/src/ATen/native/cudnn/ConvShared.h
index fa06d0940471..89986adadac1 100644
--- a/aten/src/ATen/native/cudnn/ConvShared.h
+++ b/aten/src/ATen/native/cudnn/ConvShared.h
@@ -109,9 +109,7 @@ void raw_cudnn_convolution_add_relu_fallback_out(

 #if AT_CUDNN_ENABLED()

-#include

-#if HAS_CUDNN_V8()
 // v7 functions are preserved here to allow for runtime switching to v7
 // (e.g., TORCH_CUDNN_V8_API_DISABLED=1).
 // Note that v7 forward/backward out can have different behavior from the v8
@@ -149,5 +147,4 @@ void raw_cudnn_convolution_add_relu_out_v7(
     bool deterministic,
     bool allow_tf32);
 #endif
-#endif
 }}
diff --git a/aten/src/ATen/native/cudnn/Conv_v7.cpp b/aten/src/ATen/native/cudnn/Conv_v7.cpp
index 5155d4449025..ef3a70a2232f 100644
--- a/aten/src/ATen/native/cudnn/Conv_v7.cpp
+++ b/aten/src/ATen/native/cudnn/Conv_v7.cpp
@@ -3,7 +3,6 @@

 #if AT_CUDNN_ENABLED()

-#include
 #include

 #ifndef AT_PER_OPERATOR_HEADERS
@@ -60,10 +59,6 @@
 // with the best algo, under the hood, cudnn will run with the slower kernel
 // since it sees fastest algorithm combination with a sub optimal mathType.

-// Note [blocklist fft algorithms for strided dgrad]
-// This is a workaround for a CuDNN bug that gave wrong results in certain strided convolution
-// gradient setups. Check Issue #16610 for bug details. Bug is there for CUDNN version < 7.5 .
-
 constexpr size_t operator "" _TiB(unsigned long long n) {
   return size_t(n) * 1024 * 1024 * 1024 * 1024;
 }
@@ -225,15 +220,6 @@ size_t getMaxWorkspaceSize(

 template
 std::vector getValidAlgorithms(perf_t *perfResults, const ConvolutionArgs& args, int n_algo) {
-// See Note [blocklist fft algorithms for strided dgrad]
-#if CUDNN_VERSION < 7500
-  bool blocklist = std::is_same::value;
-  int stride_dim = args.input.dim() - 2;
-  blocklist &= std::any_of(std::begin(args.params.stride),
-                           std::begin(args.params.stride) + stride_dim,
-                           [=](int n){return n != 1;});
-#endif
-
   std::vector result;
   result.reserve(n_algo);
   for (const auto i : c10::irange(n_algo)) {
@@ -244,16 +230,6 @@ std::vector getValidAlgorithms(perf_t *perfResults, const ConvolutionArg
     if (perf.status == CUDNN_STATUS_SUCCESS) {
       if (!args.params.deterministic || perf.determinism == CUDNN_DETERMINISTIC) {

-        // See Note [blocklist fft algorithms for strided dgrad]
-#if CUDNN_VERSION < 7500
-        bool skip = blocklist;
-        skip &= (static_cast(perfResults[i].algo) == CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING ||
-                 static_cast(perfResults[i].algo) == CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT);
-        if (skip) {
-          continue;
-        }
-#endif
-
         result.push_back(perf);
       }
     }
@@ -493,11 +469,9 @@ public:
       perfResults[0].mathType = CUDNN_TENSOR_OP_MATH;
     } else {
       perfResults[0].mathType = CUDNN_DEFAULT_MATH;
-#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 8000
       if (args.params.dataType == CUDNN_DATA_FLOAT && !args.params.allow_tf32) {
         perfResults[0].mathType = CUDNN_FMA_MATH;
       }
-#endif
     }
     search::getWorkspaceSize(args, perfResults[0].algo, &(perfResults[0].memory));
     return perfResults;
@@ -610,14 +584,10 @@ static inline void split_batch_dim_to_32bit_out(
 }


-#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 8000
 #define ASSERT_CORRECT_PRECISION(math_type) \
   if (args.params.dataType == CUDNN_DATA_FLOAT) { \
     TORCH_INTERNAL_ASSERT(args.params.allow_tf32 || math_type == CUDNN_FMA_MATH); \
   }
-#else
-#define ASSERT_CORRECT_PRECISION(math_type)
-#endif // CUDNN_VERSION >= 8000


 // ---------------------------------------------------------------------
@@ -672,11 +642,7 @@ void raw_cudnn_convolution_forward_out_32bit(
 }


-#if !HAS_CUDNN_V8()
-void raw_cudnn_convolution_forward_out(
-#else
 void raw_cudnn_convolution_forward_out_v7(
-#endif
     const Tensor& output, const Tensor& input, const Tensor& weight,
     IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
     bool benchmark, bool deterministic, bool allow_tf32) {
@@ -734,11 +700,7 @@ void raw_cudnn_convolution_backward_input_out_32bit(
   );
 }

-#if !HAS_CUDNN_V8()
-void raw_cudnn_convolution_backward_input_out(
-#else
 void raw_cudnn_convolution_backward_input_out_v7(
-#endif
     const at::Tensor& grad_input,
     const at::Tensor& grad_output,
     const at::Tensor& weight,
@@ -797,11 +759,7 @@ void raw_cudnn_convolution_backward_weight_out_32bit(
   );
 }

-#if !HAS_CUDNN_V8()
-void raw_cudnn_convolution_backward_weight_out(
-#else
 void raw_cudnn_convolution_backward_weight_out_v7(
-#endif
     const Tensor& grad_weight, const Tensor& grad_output, const Tensor& input,
     IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
     bool benchmark, bool deterministic, bool allow_tf32) {
@@ -853,12 +811,7 @@ void raw_cudnn_convolution_backward_weight_out_v7(
   TORCH_INTERNAL_ASSERT(false, "This case should not be dispatched to cuDNN.");
 }

-#if !HAS_CUDNN_V8()
-void raw_cudnn_convolution_add_relu_out(
-#else
 void raw_cudnn_convolution_add_relu_out_v7(
-#endif
-
     const Tensor& output,
     const Tensor& input,
     const Tensor& weight,
diff --git a/aten/src/ATen/native/cudnn/Conv_v8.cpp b/aten/src/ATen/native/cudnn/Conv_v8.cpp
index d0f9d0a5427f..aa582fc19e14 100644
--- a/aten/src/ATen/native/cudnn/Conv_v8.cpp
+++ b/aten/src/ATen/native/cudnn/Conv_v8.cpp
@@ -4,10 +4,6 @@

 #if AT_CUDNN_ENABLED()

-#include
-
-#if HAS_CUDNN_V8()
-
 #include

 #include
@@ -780,5 +776,4 @@ void raw_cudnn_convolution_add_relu_out(

 }} // at::native

-#endif // HAS_CUDNN_V8
 #endif // AT_CUDNN_ENABLED
diff --git a/aten/src/ATen/native/cudnn/LossCTC.cpp b/aten/src/ATen/native/cudnn/LossCTC.cpp
index 96b31175f1ae..cb08b57c309c 100644
--- a/aten/src/ATen/native/cudnn/LossCTC.cpp
+++ b/aten/src/ATen/native/cudnn/LossCTC.cpp
@@ -18,7 +18,7 @@
 #include
 #endif

-#if (!AT_CUDNN_ENABLED()) || (CUDNN_VERSION < 7600)
+#if (!AT_CUDNN_ENABLED())

 namespace at { namespace native {
diff --git a/aten/src/ATen/native/cudnn/Macros.h b/aten/src/ATen/native/cudnn/Macros.h
deleted file mode 100644
index 941d77233b92..000000000000
--- a/aten/src/ATen/native/cudnn/Macros.h
+++ /dev/null
@@ -1,12 +0,0 @@
-#pragma once
-
-#include
-
-// Note: The version below should not actually be 8000. Instead, it should
-// be whatever version of cuDNN that v8 API work with PyTorch correctly.
-// The version is set to 8000 today for convenience of debugging.
-#if defined(USE_EXPERIMENTAL_CUDNN_V8_API) && defined(CUDNN_VERSION) && CUDNN_VERSION >= 8200
-#define HAS_CUDNN_V8() true
-#else
-#define HAS_CUDNN_V8() false
-#endif
diff --git a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp
index c2d363045563..2d15e54c4052 100644
--- a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp
+++ b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp
@@ -28,7 +28,6 @@
 #include
 #endif

-int register_linear_params();
 int register_embedding_params();

 #ifdef USE_FBGEMM
@@ -437,7 +436,9 @@ TORCH_API int register_conv_params<2>();
 template
 TORCH_API int register_conv_params<3>();

-int register_linear_params() {
+TORCH_API int register_linear_params();
+
+TORCH_API int register_linear_params() {
   using SerializationType = std::tuple>;
   static auto register_linear_params =
       torch::selective_class_(
diff --git a/aten/src/ATen/native/quantized/cudnn/BinaryOps.cpp b/aten/src/ATen/native/quantized/cudnn/BinaryOps.cpp
index fbb46b4b0174..a225a86eeb90 100644
--- a/aten/src/ATen/native/quantized/cudnn/BinaryOps.cpp
+++ b/aten/src/ATen/native/quantized/cudnn/BinaryOps.cpp
@@ -2,8 +2,6 @@
 #include // for the definition of AT_CUDNN_ENABLED

 #if AT_CUDNN_ENABLED()
-#include
-#if HAS_CUDNN_V8()

 #include
 #include
@@ -259,6 +257,5 @@ TORCH_LIBRARY_IMPL(quantized, QuantizedCUDA, m) {
 } // namespace native
 } // namespace at

-#endif // HAS_CUDNN_V8
 #endif // AT_CUDNN_ENABLED
 #endif // USE_CUDA
diff --git a/aten/src/ATen/native/quantized/cudnn/Conv.cpp b/aten/src/ATen/native/quantized/cudnn/Conv.cpp
index 8d7d940f8986..4cb496640746 100644
--- a/aten/src/ATen/native/quantized/cudnn/Conv.cpp
+++ b/aten/src/ATen/native/quantized/cudnn/Conv.cpp
@@ -3,11 +3,8 @@

 #if AT_CUDNN_ENABLED()

-#include
 #include

-#if HAS_CUDNN_V8()
-
 #include
 #include
 #include
@@ -25,6 +22,12 @@
 #include
 #include

+template
+int register_conv_params();
+
+extern template int register_conv_params<2>();
+extern template int register_conv_params<3>();
+
 // TODO: there is a table from input dtype and weight dtype to operator qdtype,
 // we can derive the operator dtype based on input dtype
 cudnn_frontend::ConvDesc_v8 getConvDescriptor(cudnnDataType_t dataType, c10::IntArrayRef padding, c10::IntArrayRef stride, c10::IntArrayRef dilation) {
@@ -391,6 +394,8 @@ TORCH_LIBRARY_IMPL(quantized, QuantizedCUDA, m) {
   // this is inconsistent with what has been done for conv2d where new variants use packed weights, and
   // old variant does not. we adopt this inconsistency for now to be consistent with QuantizedCPU's conv1d
   // and will eventually deprecate the old variants
+  register_conv_params<2>();
+  register_conv_params<3>();
   m.impl(TORCH_SELECTIVE_NAME("quantized::conv1d"), QConv1dInt8::run);
   m.impl(TORCH_SELECTIVE_NAME("quantized::conv1d_relu"), QConv1dInt8::run);
   m.impl(TORCH_SELECTIVE_NAME("quantized::conv2d.new"), QConvInt8<2, false>::run);
@@ -401,6 +406,5 @@ TORCH_LIBRARY_IMPL(quantized, QuantizedCUDA, m) {

 } // namespace at::native

-#endif // HAS_CUDNN_V8
 #endif // AT_CUDNN_ENABLED
 #endif // USE_CUDA
diff --git a/aten/src/ATen/native/quantized/cudnn/ConvPrepack.cpp b/aten/src/ATen/native/quantized/cudnn/ConvPrepack.cpp
index 4269f8738640..44d37f27bf6f 100644
--- a/aten/src/ATen/native/quantized/cudnn/ConvPrepack.cpp
+++ b/aten/src/ATen/native/quantized/cudnn/ConvPrepack.cpp
@@ -3,10 +3,6 @@

 #if AT_CUDNN_ENABLED()

-#include
-
-#if HAS_CUDNN_V8()
-
 #include
 #include
 #include
@@ -20,6 +16,12 @@
 #include
 #include

+template
+int register_conv_params();
+
+extern template int register_conv_params<2>();
+extern template int register_conv_params<3>();
+
 template
 c10::intrusive_ptr> PackedConvWeightCudnn<
     kSpatialDim>::
@@ -203,6 +205,8 @@ class QConv1dPackWeightInt8Cudnn final {
 };

 TORCH_LIBRARY_IMPL(quantized, QuantizedCUDA, m) {
+  register_conv_params<2>();
+  register_conv_params<3>();
   m.impl(TORCH_SELECTIVE_NAME("quantized::conv1d_prepack"), TORCH_FN(QConv1dPackWeightInt8Cudnn::run_conv));
   m.impl(TORCH_SELECTIVE_NAME("quantized::conv2d_prepack"), TORCH_FN(QConvPackWeightInt8Cudnn<2>::run_conv));
 }
@@ -211,6 +215,5 @@ TORCH_LIBRARY_IMPL(quantized, QuantizedCUDA, m) {

 } // namespace native
 } // namespace at

-#endif // HAS_CUDNN_V8
 #endif // AT_CUDNN_ENABLED
 #endif // USE_CUDA
diff --git a/aten/src/ATen/native/quantized/cudnn/ConvUnpackImpl.cpp b/aten/src/ATen/native/quantized/cudnn/ConvUnpackImpl.cpp
index 0b2cae06bca2..ce5ee36cad4f 100644
--- a/aten/src/ATen/native/quantized/cudnn/ConvUnpackImpl.cpp
+++ b/aten/src/ATen/native/quantized/cudnn/ConvUnpackImpl.cpp
@@ -3,10 +3,6 @@

 #if AT_CUDNN_ENABLED()

-#include
-
-#if HAS_CUDNN_V8()
-
 #include
 #include
 #include
@@ -23,6 +19,5 @@ std::tuple> PackedConvWeightCudnn<
 template std::tuple> PackedConvWeightCudnn<
     2>::unpack();

-#endif // HAS_CUDNN_V8
 #endif // AT_CUDNN_ENABLED
 #endif // USE_CUDA
diff --git a/aten/src/ATen/native/quantized/cudnn/Linear.cpp b/aten/src/ATen/native/quantized/cudnn/Linear.cpp
index 6cef8e10ed0f..37e08ba7861d 100644
--- a/aten/src/ATen/native/quantized/cudnn/Linear.cpp
+++ b/aten/src/ATen/native/quantized/cudnn/Linear.cpp
@@ -3,11 +3,8 @@

 #if AT_CUDNN_ENABLED()

-#include
 #include

-#if HAS_CUDNN_V8()
-
 #include
 #include
 #include
@@ -25,6 +22,8 @@
 #include
 #include

+int register_linear_params();
+
 // TODO: there is a table from input dtype and weight dtype to operator dtype,
 // we can derive the operator dtype based on input dtype
 cudnn_frontend::MatMulDesc_v8 getLinearDescriptor(cudnnDataType_t dataType) {
@@ -358,6 +357,7 @@ class QLinearInt8 final {
 };

 TORCH_LIBRARY_IMPL(quantized, QuantizedCUDA, m) {
+  register_linear_params();
   m.impl(TORCH_SELECTIVE_NAME("quantized::linear"), QLinearInt8::run);
   m.impl(TORCH_SELECTIVE_NAME("quantized::linear_relu"), QLinearInt8::run);
 }
@@ -367,6 +367,5 @@ TORCH_LIBRARY_IMPL(quantized, QuantizedCUDA, m) {

 } // namespace at

-#endif // HAS_CUDNN_V8
 #endif // AT_CUDNN_ENABLED
 #endif // USE_CUDA
diff --git a/aten/src/ATen/native/quantized/cudnn/LinearPrepack.cpp b/aten/src/ATen/native/quantized/cudnn/LinearPrepack.cpp
index 212bab93d0d7..abbb5922f393 100644
--- a/aten/src/ATen/native/quantized/cudnn/LinearPrepack.cpp
+++ b/aten/src/ATen/native/quantized/cudnn/LinearPrepack.cpp
@@ -3,10 +3,6 @@

 #if AT_CUDNN_ENABLED()

-#include
-
-#if HAS_CUDNN_V8()
-
 #include
 #include
 #include
@@ -16,6 +12,8 @@
 #include
 #include

+int register_linear_params();
+
 c10::intrusive_ptr PackedLinearWeightCudnn::prepack(
     at::Tensor weight,
     c10::optional bias) {
@@ -50,6 +48,7 @@ class QLinearPackWeightInt8Cudnn final {
 };

 TORCH_LIBRARY_IMPL(quantized, QuantizedCUDA, m) {
+  register_linear_params();
   m.impl(TORCH_SELECTIVE_NAME("quantized::linear_prepack"), TORCH_FN(QLinearPackWeightInt8Cudnn::run));
 }

@@ -58,6 +57,5 @@ TORCH_LIBRARY_IMPL(quantized, QuantizedCUDA, m) {

 } // namespace native
 } // namespace at

-#endif // HAS_CUDNN_V8
 #endif // AT_CUDNN_ENABLED
 #endif // USE_CUDA
diff --git a/aten/src/ATen/native/quantized/cudnn/LinearUnpackImpl.cpp b/aten/src/ATen/native/quantized/cudnn/LinearUnpackImpl.cpp
index 2212d054b8c7..7200872480ef 100644
--- a/aten/src/ATen/native/quantized/cudnn/LinearUnpackImpl.cpp
+++ b/aten/src/ATen/native/quantized/cudnn/LinearUnpackImpl.cpp
@@ -3,10 +3,6 @@

 #if AT_CUDNN_ENABLED()

-#include
-
-#if HAS_CUDNN_V8()
-
 #include
 #include
 #include
@@ -18,6 +14,5 @@ std::tuple> PackedLinearWeightCudnn::unpac
   return std::tuple>{orig_weight, bias_};
 }

-#endif // HAS_CUDNN_V8
 #endif // AT_CUDNN_ENABLED
 #endif // USE_CUDA
diff --git a/aten/src/ATen/native/quantized/cudnn/Pooling.cpp b/aten/src/ATen/native/quantized/cudnn/Pooling.cpp
index 711afad775df..44d4251fbc62 100644
--- a/aten/src/ATen/native/quantized/cudnn/Pooling.cpp
+++ b/aten/src/ATen/native/quantized/cudnn/Pooling.cpp
@@ -3,7 +3,6 @@
 #include // for the definition of AT_CUDNN_ENABLED

 #if AT_CUDNN_ENABLED()
-#include
 #include
 #include
 #include
@@ -54,7 +53,6 @@ Tensor adaptive_avg_pool2d_quantized_cuda(
   // TODO: renable these cudnn preprocessors like quantized_max_pool2d_cudnn below when we implement this function with cudnn
 #ifdef USE_CUDA
 // #if AT_CUDNN_ENABLED()
-// #if HAS_CUDNN_V8()
   // TODO: limit this to per tensor quantized tensors for now, though should be easy to adapt
   // to per channel quantized tensors
   TORCH_CHECK(input.qscheme() == at::kPerTensorAffine, "adaptive_avg_pool2d_quantized_cuda oonly supports per tensor quantized tensors");
@@ -91,7 +89,6 @@ Tensor quantized_max_pool2d_cudnn(
     bool ceil_mode) {
 #ifdef USE_CUDA
 #if AT_CUDNN_ENABLED()
-#if HAS_CUDNN_V8()
   check_maxpool2d_params(
       kernel_size,
       stride,
@@ -207,10 +204,6 @@ Tensor quantized_max_pool2d_cudnn(
   // recall we casted our input and output to 4D if qx was 3D, so we recast it back to 3D prior to returning
   return (ndim == 3 ? qy.view(std::vector(output_shape.begin() + 1, output_shape.end())) : qy);
-#else // HAS_CUDNN_V8()
-  AT_ERROR("at::native::quantized_max_pool2d_cudnn: ATen not compiled with cuDNN v8 support");
-  return Tensor{}; // never reached, placates the compiler
-#endif // HAS_CUDNN_V8()
 #else // AT_CUDNN_ENABLED()
   AT_ERROR("at::native::quantized_max_pool2d_cudnn: ATen not compiled with cuDNN support");
   return Tensor{}; // never reached, placates the compiler
diff --git a/aten/src/ATen/native/quantized/cudnn/utils.h b/aten/src/ATen/native/quantized/cudnn/utils.h
index 3f2939fa5dab..18c891fcaa1c 100644
--- a/aten/src/ATen/native/quantized/cudnn/utils.h
+++ b/aten/src/ATen/native/quantized/cudnn/utils.h
@@ -8,10 +8,6 @@ This file contains some of the auxiliary functions used by both Conv.cpp & Linea

 #if AT_CUDNN_ENABLED()

-#include
-
-#if HAS_CUDNN_V8()
-
 #include
 #include
 #include
@@ -354,6 +350,5 @@ cudnn_frontend::ExecutionPlan get_execplan_from_heuristics_else_fall_back(cudnn_
 } // anonymous
 } // cudnn_utils

-#endif // HAS_CUDNN_V8
 #endif // AT_CUDNN_ENABLED
 #endif // USE_CUDA
diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt
index b035e50f81ae..72d40564fa40 100644
--- a/caffe2/CMakeLists.txt
+++ b/caffe2/CMakeLists.txt
@@ -1414,12 +1414,6 @@ elseif(USE_ROCM)
   target_compile_definitions(torch_hip PRIVATE TORCH_HIP_BUILD_MAIN_LIB)
 endif()

-if(USE_EXPERIMENTAL_CUDNN_V8_API)
-  if(USE_CUDA)
-    target_compile_definitions(torch_cuda PRIVATE "-DUSE_EXPERIMENTAL_CUDNN_V8_API")
-  endif()
-endif()
-
 set(EXPERIMENTAL_SINGLE_THREAD_POOL "0" CACHE STRING
     "Experimental option to use a single thread pool for inter- and intra-op parallelism")
 if("${EXPERIMENTAL_SINGLE_THREAD_POOL}")
diff --git a/third_party/cudnn_frontend.BUILD b/third_party/cudnn_frontend.BUILD
new file mode 100644
index 000000000000..0af69797a3e2
--- /dev/null
+++ b/third_party/cudnn_frontend.BUILD
@@ -0,0 +1,22 @@
+# Adopted from: https://github.com/tensorflow/tensorflow/blob/master/third_party/cudnn_frontend.BUILD
+
+# Description:
+# The cuDNN Frontend API is a C++ header-only library that demonstrates how
+# to use the cuDNN C backend API.
+
+load("@rules_cc//cc:defs.bzl", "cc_library")
+
+package(
+    default_visibility = ["//visibility:public"],
+)
+
+licenses(["notice"]) # MIT
+
+exports_files(["LICENSE.txt"])
+
+cc_library(
+    name = "cudnn_frontend",
+    hdrs = glob(["include/**"]),
+    includes = ["include/"],
+    include_prefix = "third_party/cudnn_frontend",
+)
diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt
index b0d7bd842d33..24903a207ecc 100644
--- a/torch/CMakeLists.txt
+++ b/torch/CMakeLists.txt
@@ -145,10 +145,6 @@ if(USE_ROCM)
   list(APPEND TORCH_PYTHON_LINK_LIBRARIES ${ROCM_ROCTX_LIB})
 endif()

-if(USE_EXPERIMENTAL_CUDNN_V8_API)
-  list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_EXPERIMENTAL_CUDNN_V8_API)
-endif()
-
 if(USE_CUDNN OR USE_ROCM)
   list(APPEND TORCH_PYTHON_SRCS
     ${TORCH_SRC_DIR}/csrc/cuda/shared/cudnn.cpp
diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp
index 09f8249d8b05..8f8c54e75b52 100644
--- a/torch/csrc/cuda/Module.cpp
+++ b/torch/csrc/cuda/Module.cpp
@@ -12,8 +12,6 @@

 #if AT_CUDNN_ENABLED()

-#include
-
 #endif
 #include
 #include
@@ -1359,15 +1357,8 @@ PyObject* THCPModule_setBenchmarkLimitCuDNN(PyObject* _unused, PyObject* arg) {
   TORCH_WARN_ONCE(
       "cuDNN Benchmark limit is not supported in MIOpen and will have no effect.");
 #endif
-#if AT_CUDNN_ENABLED()
-#if HAS_CUDNN_V8()
   auto benchmark_limit = static_cast(THPUtils_unpackLong(arg));
   at::globalContext().setBenchmarkLimitCuDNN(benchmark_limit);
-#else
-  TORCH_WARN_ONCE(
-      "cuDNN Benchmark limit is not supported with cuDNN v7 API and will have no effect.");
-#endif
-#endif
   Py_RETURN_NONE;
 }
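
A minimal usage sketch follows; it is not part of the patch. It assumes a CUDA build of this branch with cuDNN 8.x available (8200 was the compile-time floor in the deleted Macros.h) and a visible GPU. With the HAS_CUDNN_V8() gate removed, the benchmark-limit setter in the last hunk always reaches at::globalContext().setBenchmarkLimitCuDNN(), and the preserved v7 convolution path is only selected through the TORCH_CUDNN_V8_API_DISABLED environment variable noted in ConvShared.h.

# smoke_test.py -- hypothetical example, not included in this diff
import torch

print(torch.backends.cudnn.version())  # >= 8200 was the old HAS_CUDNN_V8 compile-time floor

torch.backends.cudnn.benchmark = True
# After this change the limit is always forwarded to setBenchmarkLimitCuDNN();
# previously a v7-only build emitted a warning and ignored it.
torch.backends.cudnn.benchmark_limit = 10

conv = torch.nn.Conv2d(3, 16, kernel_size=3, padding=1).cuda()
out = conv(torch.randn(2, 3, 64, 64, device="cuda"))  # runs through the cuDNN v8 engine path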