diff --git a/BUILD.bazel b/BUILD.bazel
index b3f0435006d0..887647b2363e 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -407,7 +407,6 @@ cc_library(
         "@cuda//:cusolver",
         "@cuda//:nvrtc",
         "@cudnn",
-        "@cudnn_frontend",
     ],
     alwayslink = True,
 )
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 8e948801a501..7081ad868298 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -195,6 +195,9 @@ cmake_dependent_option(
 cmake_dependent_option(
     BUILD_NVFUSER_BENCHMARK "Build C++ binaries for nvfuser benchmarks" OFF
     "USE_CUDA" OFF)
+cmake_dependent_option(
+    USE_EXPERIMENTAL_CUDNN_V8_API "Use experimental cuDNN v8 API" ON
+    "USE_CUDNN" OFF)
 option(USE_FBGEMM "Use FBGEMM (quantized 8-bit server operators)" ON)
 option(USE_KINETO "Use Kineto profiling library" ON)
 option(USE_CUPTI_SO "Use CUPTI as a shared library" ON)
diff --git a/WORKSPACE b/WORKSPACE
index 925e5ba3ea6e..e8591f291abd 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -203,12 +203,6 @@ new_local_repository(
     path = "/usr/",
 )
 
-new_local_repository(
-    name = "cudnn_frontend",
-    build_file = "@//third_party:cudnn_frontend.BUILD",
-    path = "third_party/cudnn_frontend/",
-)
-
 local_repository(
     name = "com_github_google_flatbuffers",
     path = "third_party/flatbuffers",
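A note on how the pieces fit: the new CMake option only matters once it reaches the compiler as the USE_EXPERIMENTAL_CUDNN_V8_API definition (wired up in caffe2/CMakeLists.txt and torch/CMakeLists.txt further down), and the source guards consume it through the HAS_CUDNN_V8() macro. The macro's header is not part of this patch, so the following is a hypothetical sketch of its shape; the version floor in particular is an assumption:

// Hypothetical sketch of the HAS_CUDNN_V8() gate the guards below rely on.
// The real helper header is outside this patch; the exact version floor is
// an assumption. CUDNN_VERSION comes from <cudnn.h> when cuDNN is in use.
#if defined(USE_EXPERIMENTAL_CUDNN_V8_API) && defined(CUDNN_VERSION) && \
    CUDNN_VERSION >= 8000
#define HAS_CUDNN_V8() 1
#else
#define HAS_CUDNN_V8() 0
#endif

With this shape, opting out at configure time (or building against cuDNN 7) collapses every HAS_CUDNN_V8() block below to the v7 code path.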
diff --git a/aten/src/ATen/cudnn/Descriptors.h b/aten/src/ATen/cudnn/Descriptors.h
index 56ad0784fe62..e111987785cc 100644
--- a/aten/src/ATen/cudnn/Descriptors.h
+++ b/aten/src/ATen/cudnn/Descriptors.h
@@ -305,6 +305,7 @@ struct TORCH_CUDA_CPP_API CTCLossDescriptor
   void set(cudnnDataType_t datatype) {
     AT_CUDNN_CHECK(cudnnSetCTCLossDescriptor(mut_desc(), datatype));
   }
+#if CUDNN_VERSION >= 7600
   void setEx(
       cudnnDataType_t datatype,
       cudnnLossNormalizationMode_t normMode,
@@ -312,6 +313,7 @@ struct TORCH_CUDA_CPP_API CTCLossDescriptor
     AT_CUDNN_CHECK(
         cudnnSetCTCLossDescriptorEx(mut_desc(), datatype, normMode, gradMode));
   }
+#endif
 };
 
 struct TORCH_CUDA_CPP_API ActivationDescriptor
diff --git a/aten/src/ATen/native/cudnn/BatchNorm.cpp b/aten/src/ATen/native/cudnn/BatchNorm.cpp
index da57ec21500d..f1f275e63885 100644
--- a/aten/src/ATen/native/cudnn/BatchNorm.cpp
+++ b/aten/src/ATen/native/cudnn/BatchNorm.cpp
@@ -59,7 +59,11 @@ cudnnBatchNormMode_t getCudnnBatchNormMode(bool training, at::MemoryFormat memory_format) {
     return CUDNN_BATCHNORM_PER_ACTIVATION;
   } else if (training && memory_format == at::MemoryFormat::ChannelsLast) {
+#if CUDNN_VERSION >= 7400
     return CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
+#else
+    return CUDNN_BATCHNORM_SPATIAL;
+#endif // CUDNN_VERSION >= 7400
   } else if (training && memory_format == at::MemoryFormat::ChannelsLast3d) {
@@ -148,6 +152,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> cudnn_batch_norm(
     save_mean = at::empty({ num_features }, weight_t.options());
     save_var = at::empty({ num_features }, weight_t.options());
 
+#if CUDNN_VERSION >= 7400
     auto op = CUDNN_BATCHNORM_OPS_BN;
     size_t workspace_size;
     AT_CUDNN_CHECK(cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(
@@ -199,6 +204,22 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> cudnn_batch_norm(
         workspace_size,
         reserve.data_ptr(),
         reserve_size));
+#else
+    reserve = at::empty({0}, input->options().dtype(kByte));
+    AT_CUDNN_CHECK(cudnnBatchNormalizationForwardTraining(
+        handle, mode, &one, &zero,
+        idesc.desc(), input->data_ptr(),
+        idesc.desc(), output->data_ptr(),
+        wdesc.desc(),
+        weight->data_ptr(),
+        bias->data_ptr(),
+        exponential_average_factor,
+        at::maybe_data_ptr(running_mean),
+        at::maybe_data_ptr(running_var),
+        epsilon,
+        save_mean.data_ptr(),
+        save_var.data_ptr()));
+#endif // CUDNN_VERSION >= 7400
   } else {
     reserve = at::empty({0}, input->options().dtype(kByte));
     // This keeps a consistent output with native_batch_norm
@@ -296,6 +317,7 @@ std::tuple<Tensor, Tensor, Tensor> cudnn_batch_norm_backward(
   Constant one(dataType, 1);
   Constant zero(dataType, 0);
 
+#if CUDNN_VERSION >= 7400
   auto op = CUDNN_BATCHNORM_OPS_BN;
 
   size_t workspace_size;
@@ -332,6 +354,19 @@ std::tuple<Tensor, Tensor, Tensor> cudnn_batch_norm_backward(
       workspace_size,
       reserve->data_ptr(),
       reserve->numel()));
+#else
+  AT_CUDNN_CHECK(cudnnBatchNormalizationBackward(
+      handle, mode, &one, &zero, &one, &zero,
+      idesc.desc(), input->data_ptr(),
+      odesc.desc(), grad_output->data_ptr(),
+      idesc.desc(), grad_input_t.data_ptr(),
+      wdesc.desc(), weight->data_ptr(),
+      grad_weight_t.data_ptr(),
+      grad_bias_t.data_ptr(),
+      epsilon,
+      save_mean->data_ptr(),
+      save_var->data_ptr()));
+#endif // CUDNN_VERSION >= 7400
 
   return std::tuple<Tensor, Tensor, Tensor>{grad_input_t, grad_weight_t, grad_bias_t};
 }
diff --git a/aten/src/ATen/native/cudnn/ConvShared.h b/aten/src/ATen/native/cudnn/ConvShared.h
index ec2cd50b1e0a..fa06d0940471 100644
--- a/aten/src/ATen/native/cudnn/ConvShared.h
+++ b/aten/src/ATen/native/cudnn/ConvShared.h
@@ -111,6 +111,7 @@ void raw_cudnn_convolution_add_relu_fallback_out(
 
 #if AT_CUDNN_ENABLED()
 #include <ATen/native/cudnn/Macros.h>
 
+#if HAS_CUDNN_V8()
 // v7 functions are preserved here to allow for runtime switching to v7
 // (e.g., TORCH_CUDNN_V8_API_DISABLED=1).
 // Note that v7 forward/backward out can have different behavior from the v8
@@ -148,4 +149,5 @@ void raw_cudnn_convolution_add_relu_out_v7(
     bool deterministic,
     bool allow_tf32);
 #endif
+#endif
 }}
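The comment in ConvShared.h above is the key design point: v8-enabled builds keep the v7 kernels compiled in under _v7 names so the implementation can be switched at runtime. The dispatcher itself is outside this patch; the following is a minimal sketch of the idea, assuming an environment-variable check and simplified stand-in function names:

// Sketch only: how a v8-capable build might honor the documented
// TORCH_CUDNN_V8_API_DISABLED escape hatch. The real dispatch code is not
// part of this patch; use_v8_api() and the run_conv_* names are hypothetical.
#include <cstdlib>

static bool use_v8_api() {
  // Checked once; setting TORCH_CUDNN_V8_API_DISABLED forces the v7 kernels.
  static const bool disabled =
      std::getenv("TORCH_CUDNN_V8_API_DISABLED") != nullptr;
  return !disabled;
}

void run_conv_v7() { /* legacy cudnnConvolutionForward path */ }
void run_conv_v8() { /* cudnn_frontend graph path */ }

void run_conv() {
  if (use_v8_api()) {
    run_conv_v8();
  } else {
    run_conv_v7();  // same math, pre-frontend API
  }
}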
diff --git a/aten/src/ATen/native/cudnn/Conv_v7.cpp b/aten/src/ATen/native/cudnn/Conv_v7.cpp
index 0b99d1e58676..f5c7af79a740 100644
--- a/aten/src/ATen/native/cudnn/Conv_v7.cpp
+++ b/aten/src/ATen/native/cudnn/Conv_v7.cpp
@@ -3,6 +3,7 @@
 
 #if AT_CUDNN_ENABLED()
 
+#include <ATen/native/cudnn/Macros.h>
 #include
 
 #ifndef AT_PER_OPERATOR_HEADERS
@@ -59,6 +60,10 @@
 // with the best algo, under the hood, cudnn will run with the slower kernel
 // since it sees fastest algorithm combination with a sub optimal mathType.
 
+// Note [blocklist fft algorithms for strided dgrad]
+// This is a workaround for a cuDNN bug that gave wrong results for certain strided
+// convolution gradient setups; see issue #16610 for details. The bug is present in cuDNN versions < 7.5.
+
 constexpr size_t operator "" _TiB(unsigned long long n) {
   return size_t(n) * 1024 * 1024 * 1024 * 1024;
 }
@@ -220,6 +225,15 @@ size_t getMaxWorkspaceSize(
 template<typename perf_t>
 std::vector<perf_t> getValidAlgorithms(perf_t *perfResults, const ConvolutionArgs& args, int n_algo) {
+// See Note [blocklist fft algorithms for strided dgrad]
+#if CUDNN_VERSION < 7500
+  bool blocklist = std::is_same<decltype(perfResults[0].algo), cudnnConvolutionBwdDataAlgo_t>::value;
+  int stride_dim = args.input.dim() - 2;
+  blocklist &= std::any_of(std::begin(args.params.stride),
+                           std::begin(args.params.stride) + stride_dim,
+                           [=](int n){return n != 1;});
+#endif
+
   std::vector<perf_t> result;
   result.reserve(n_algo);
   for (const auto i : c10::irange(n_algo)) {
@@ -230,6 +244,16 @@ std::vector<perf_t> getValidAlgorithms(perf_t *perfResults, const ConvolutionArgs& args, int n_algo) {
     if (perf.status == CUDNN_STATUS_SUCCESS) {
       if (!args.params.deterministic || perf.determinism == CUDNN_DETERMINISTIC) {
 
+        // See Note [blocklist fft algorithms for strided dgrad]
+#if CUDNN_VERSION < 7500
+        bool skip = blocklist;
+        skip &= (static_cast<cudnnConvolutionBwdDataAlgo_t>(perfResults[i].algo) == CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING ||
+                 static_cast<cudnnConvolutionBwdDataAlgo_t>(perfResults[i].algo) == CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT);
+        if (skip) {
+          continue;
+        }
+#endif
+
         result.push_back(perf);
       }
     }
@@ -469,9 +493,11 @@ public:
       perfResults[0].mathType = CUDNN_TENSOR_OP_MATH;
     } else {
       perfResults[0].mathType = CUDNN_DEFAULT_MATH;
+#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 8000
       if (args.params.dataType == CUDNN_DATA_FLOAT && !args.params.allow_tf32) {
         perfResults[0].mathType = CUDNN_FMA_MATH;
       }
+#endif
     }
     search::getWorkspaceSize(args, perfResults[0].algo, &(perfResults[0].memory));
     return perfResults;
@@ -584,10 +610,14 @@ static inline void split_batch_dim_to_32bit_out(
 }
 
+#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 8000
 #define ASSERT_CORRECT_PRECISION(math_type) \
   if (args.params.dataType == CUDNN_DATA_FLOAT) { \
     TORCH_INTERNAL_ASSERT(args.params.allow_tf32 || math_type == CUDNN_FMA_MATH); \
   }
+#else
+#define ASSERT_CORRECT_PRECISION(math_type)
+#endif // CUDNN_VERSION >= 8000
 
 // ---------------------------------------------------------------------
 
@@ -642,7 +672,11 @@ void raw_cudnn_convolution_forward_out_32bit(
 }
 
+#if !HAS_CUDNN_V8()
+void raw_cudnn_convolution_forward_out(
+#else
 void raw_cudnn_convolution_forward_out_v7(
+#endif
     const Tensor& output, const Tensor& input, const Tensor& weight,
     IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
     bool benchmark, bool deterministic, bool allow_tf32) {
@@ -700,7 +734,11 @@ void raw_cudnn_convolution_backward_input_out_32bit(
   );
 }
 
+#if !HAS_CUDNN_V8()
+void raw_cudnn_convolution_backward_input_out(
+#else
 void raw_cudnn_convolution_backward_input_out_v7(
+#endif
     const at::Tensor& grad_input,
     const at::Tensor& grad_output,
     const at::Tensor& weight,
@@ -759,7 +797,11 @@ void raw_cudnn_convolution_backward_weight_out_32bit(
   );
 }
 
+#if !HAS_CUDNN_V8()
+void raw_cudnn_convolution_backward_weight_out(
+#else
 void raw_cudnn_convolution_backward_weight_out_v7(
+#endif
     const Tensor& grad_weight, const Tensor& grad_output, const Tensor& input,
     IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
     bool benchmark, bool deterministic, bool allow_tf32) {
@@ -811,7 +853,12 @@ void raw_cudnn_convolution_backward_weight_out_v7(
   TORCH_INTERNAL_ASSERT(false, "This case should not be dispatched to cuDNN.");
 }
 
+#if !HAS_CUDNN_V8()
+void raw_cudnn_convolution_add_relu_out(
+#else
 void raw_cudnn_convolution_add_relu_out_v7(
+#endif
+
     const Tensor& output,
     const Tensor& input,
     const Tensor& weight,
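The CUDNN_VERSION >= 8000 guards above exist because CUDNN_FMA_MATH was only added to the cudnnMathType_t enum in cuDNN 8.0 (together with TF32 support), so even naming it against older headers is a compile error. A minimal sketch of the idiom, using a hypothetical helper name:

#include <cudnn.h>

// Sketch of the guard idiom used above; pick_math_type is a hypothetical
// helper, not part of the patch. CUDNN_FMA_MATH (and TF32 itself) only exist
// from cuDNN 8.0 on, so the reference must be compiled out for older headers.
cudnnMathType_t pick_math_type(cudnnDataType_t dtype, bool allow_tf32) {
#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 8000
  if (dtype == CUDNN_DATA_FLOAT && !allow_tf32) {
    return CUDNN_FMA_MATH;  // keep fp32 exact: no TF32 down-conversion
  }
#endif
  return CUDNN_DEFAULT_MATH;  // pre-8.0 builds: plain fp32 math is the default
}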
diff --git a/aten/src/ATen/native/cudnn/Conv_v8.cpp b/aten/src/ATen/native/cudnn/Conv_v8.cpp
index de38de30d49d..916ad6bbf920 100644
--- a/aten/src/ATen/native/cudnn/Conv_v8.cpp
+++ b/aten/src/ATen/native/cudnn/Conv_v8.cpp
@@ -4,6 +4,10 @@
 
 #if AT_CUDNN_ENABLED()
 
+#include <ATen/native/cudnn/Macros.h>
+
+#if HAS_CUDNN_V8()
+
 #include
 #include
 
@@ -783,4 +787,5 @@ void raw_cudnn_convolution_add_relu_out(
 
 }} // at::native
 
+#endif // HAS_CUDNN_V8
 #endif // AT_CUDNN_ENABLED
diff --git a/aten/src/ATen/native/cudnn/LossCTC.cpp b/aten/src/ATen/native/cudnn/LossCTC.cpp
index eba108823fec..7737e91d4417 100644
--- a/aten/src/ATen/native/cudnn/LossCTC.cpp
+++ b/aten/src/ATen/native/cudnn/LossCTC.cpp
@@ -18,7 +18,7 @@
 #include
 #endif
 
-#if (!AT_CUDNN_ENABLED())
+#if (!AT_CUDNN_ENABLED()) || (CUDNN_VERSION < 7600)
 
 namespace at { namespace native {
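The LossCTC.cpp change uses a different gate than HAS_CUDNN_V8(): cudnnSetCTCLossDescriptorEx (reached through CTCLossDescriptor::setEx above) only exists from cuDNN 7.6, so the whole cuDNN CTC path now compiles only on 7.6 and newer, and older builds fall into the pre-existing stub branch. A self-contained sketch of that one-guard-two-bodies layout, with a simplified stand-in name:

#include <stdexcept>

// Defaults so the sketch compiles standalone; in ATen these come from the
// build system (CUDAConfig.h) and from cudnn.h.
#ifndef AT_CUDNN_ENABLED
#define AT_CUDNN_ENABLED() 0
#endif
#ifndef CUDNN_VERSION
#define CUDNN_VERSION 0
#endif

#if (!AT_CUDNN_ENABLED()) || (CUDNN_VERSION < 7600)
// cuDNN missing or too old: keep the symbol so callers link, but fail loudly
// if it is ever reached (availability checks route to the native CTC kernel).
void ctc_loss_cudnn() {
  throw std::runtime_error(
      "ctc_loss_cudnn: ATen not compiled with cuDNN >= 7.6 support");
}
#else
// cuDNN 7.6+: real implementation, built on cudnnSetCTCLossDescriptorEx.
void ctc_loss_cudnn() { /* ... */ }
#endif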
diff --git a/aten/src/ATen/native/quantized/cudnn/BinaryOps.cpp b/aten/src/ATen/native/quantized/cudnn/BinaryOps.cpp
index a225a86eeb90..fbb46b4b0174 100644
--- a/aten/src/ATen/native/quantized/cudnn/BinaryOps.cpp
+++ b/aten/src/ATen/native/quantized/cudnn/BinaryOps.cpp
@@ -2,6 +2,8 @@
 #include <ATen/cuda/CUDAConfig.h>  // for the definition of AT_CUDNN_ENABLED
 
 #if AT_CUDNN_ENABLED()
+#include <ATen/native/cudnn/Macros.h>
+#if HAS_CUDNN_V8()
 
 #include
 #include
@@ -257,5 +259,6 @@ TORCH_LIBRARY_IMPL(quantized, QuantizedCUDA, m) {
 } // namespace native
 } // namespace at
 
+#endif // HAS_CUDNN_V8
 #endif // AT_CUDNN_ENABLED
 #endif // USE_CUDA
diff --git a/aten/src/ATen/native/quantized/cudnn/Conv.cpp b/aten/src/ATen/native/quantized/cudnn/Conv.cpp
index c93d669dad99..ca1c3e146684 100644
--- a/aten/src/ATen/native/quantized/cudnn/Conv.cpp
+++ b/aten/src/ATen/native/quantized/cudnn/Conv.cpp
@@ -3,8 +3,11 @@
 
 #if AT_CUDNN_ENABLED()
 
+#include <ATen/native/cudnn/Macros.h>
 #include
 
+#if HAS_CUDNN_V8()
+
 #include
 #include
 #include
@@ -429,5 +432,6 @@ TORCH_LIBRARY_IMPL(quantized, QuantizedCUDA, m) {
 
 } // namespace at
 
+#endif // HAS_CUDNN_V8
 #endif // AT_CUDNN_ENABLED
 #endif // USE_CUDA
diff --git a/aten/src/ATen/native/quantized/cudnn/ConvPrepack.cpp b/aten/src/ATen/native/quantized/cudnn/ConvPrepack.cpp
index c8068468431f..e214ab6492df 100644
--- a/aten/src/ATen/native/quantized/cudnn/ConvPrepack.cpp
+++ b/aten/src/ATen/native/quantized/cudnn/ConvPrepack.cpp
@@ -3,6 +3,10 @@
 
 #if AT_CUDNN_ENABLED()
 
+#include <ATen/native/cudnn/Macros.h>
+
+#if HAS_CUDNN_V8()
+
 #include
 #include
 #include
@@ -208,5 +212,6 @@ TORCH_LIBRARY_IMPL(quantized, QuantizedCUDA, m) {
 } // namespace native
 } // namespace at
 
+#endif // HAS_CUDNN_V8
 #endif // AT_CUDNN_ENABLED
 #endif // USE_CUDA
diff --git a/aten/src/ATen/native/quantized/cudnn/ConvUnpackImpl.cpp b/aten/src/ATen/native/quantized/cudnn/ConvUnpackImpl.cpp
index ce5ee36cad4f..0b2cae06bca2 100644
--- a/aten/src/ATen/native/quantized/cudnn/ConvUnpackImpl.cpp
+++ b/aten/src/ATen/native/quantized/cudnn/ConvUnpackImpl.cpp
@@ -3,6 +3,10 @@
 
 #if AT_CUDNN_ENABLED()
 
+#include <ATen/native/cudnn/Macros.h>
+
+#if HAS_CUDNN_V8()
+
 #include
 #include
 #include
@@ -19,5 +23,6 @@ std::tuple<at::Tensor, c10::optional<at::Tensor>> PackedConvWeightCudnn<
 template std::tuple<at::Tensor, c10::optional<at::Tensor>> PackedConvWeightCudnn<
     2>::unpack();
 
+#endif // HAS_CUDNN_V8
 #endif // AT_CUDNN_ENABLED
 #endif // USE_CUDA
diff --git a/aten/src/ATen/native/quantized/cudnn/Linear.cpp b/aten/src/ATen/native/quantized/cudnn/Linear.cpp
index 3b477eeb8e79..6cef8e10ed0f 100644
--- a/aten/src/ATen/native/quantized/cudnn/Linear.cpp
+++ b/aten/src/ATen/native/quantized/cudnn/Linear.cpp
@@ -3,8 +3,11 @@
 
 #if AT_CUDNN_ENABLED()
 
+#include <ATen/native/cudnn/Macros.h>
 #include
 
+#if HAS_CUDNN_V8()
+
 #include
 #include
 #include
@@ -364,5 +367,6 @@ TORCH_LIBRARY_IMPL(quantized, QuantizedCUDA, m) {
 
 } // namespace at
 
+#endif // HAS_CUDNN_V8
 #endif // AT_CUDNN_ENABLED
 #endif // USE_CUDA
diff --git a/aten/src/ATen/native/quantized/cudnn/LinearPrepack.cpp b/aten/src/ATen/native/quantized/cudnn/LinearPrepack.cpp
index 1d21e4d80315..212bab93d0d7 100644
--- a/aten/src/ATen/native/quantized/cudnn/LinearPrepack.cpp
+++ b/aten/src/ATen/native/quantized/cudnn/LinearPrepack.cpp
@@ -3,6 +3,10 @@
 
 #if AT_CUDNN_ENABLED()
 
+#include <ATen/native/cudnn/Macros.h>
+
+#if HAS_CUDNN_V8()
+
 #include
 #include
 #include
@@ -54,5 +58,6 @@ TORCH_LIBRARY_IMPL(quantized, QuantizedCUDA, m) {
 } // namespace native
 } // namespace at
 
+#endif // HAS_CUDNN_V8
 #endif // AT_CUDNN_ENABLED
 #endif // USE_CUDA
diff --git a/aten/src/ATen/native/quantized/cudnn/LinearUnpackImpl.cpp b/aten/src/ATen/native/quantized/cudnn/LinearUnpackImpl.cpp
index 7200872480ef..2212d054b8c7 100644
--- a/aten/src/ATen/native/quantized/cudnn/LinearUnpackImpl.cpp
+++ b/aten/src/ATen/native/quantized/cudnn/LinearUnpackImpl.cpp
@@ -3,6 +3,10 @@
 
 #if AT_CUDNN_ENABLED()
 
+#include <ATen/native/cudnn/Macros.h>
+
+#if HAS_CUDNN_V8()
+
 #include
 #include
 #include
@@ -14,5 +18,6 @@ std::tuple<at::Tensor, c10::optional<at::Tensor>> PackedLinearWeightCudnn::unpack(
   return std::tuple<at::Tensor, c10::optional<at::Tensor>>{orig_weight, bias_};
 }
 
+#endif // HAS_CUDNN_V8
 #endif // AT_CUDNN_ENABLED
 #endif // USE_CUDA
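Every file under aten/src/ATen/native/quantized/cudnn/ receives the same treatment, since all of them are written against the header-only cudnn_frontend library, which requires the v8 backend API. The shared skeleton looks like this (the include paths are assumptions, reconstructed from the guard comments used elsewhere in this patch):

// Sketch of the guard skeleton now shared by the quantized cuDNN sources.
// Nothing in these translation units survives unless all three gates pass:
// built with CUDA, built with cuDNN, and the v8 frontend API opted in.
#ifdef USE_CUDA
#include <ATen/cuda/CUDAConfig.h>  // for the definition of AT_CUDNN_ENABLED

#if AT_CUDNN_ENABLED()
#include <ATen/native/cudnn/Macros.h>

#if HAS_CUDNN_V8()

// ... actual operator definitions, written against cudnn_frontend ...

#endif  // HAS_CUDNN_V8
#endif  // AT_CUDNN_ENABLED
#endif  // USE_CUDA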
diff --git a/aten/src/ATen/native/quantized/cudnn/Pooling.cpp b/aten/src/ATen/native/quantized/cudnn/Pooling.cpp
index bad45329fb95..ffab667dc3d1 100644
--- a/aten/src/ATen/native/quantized/cudnn/Pooling.cpp
+++ b/aten/src/ATen/native/quantized/cudnn/Pooling.cpp
@@ -3,6 +3,7 @@
 #include <ATen/cuda/CUDAConfig.h>  // for the definition of AT_CUDNN_ENABLED
 
 #if AT_CUDNN_ENABLED()
+#include <ATen/native/cudnn/Macros.h>
 #include
 #include
 #include
@@ -53,6 +54,7 @@ Tensor adaptive_avg_pool2d_quantized_cuda(
   // TODO: re-enable these cudnn preprocessors like quantized_max_pool2d_cudnn below when we implement this function with cudnn
 #ifdef USE_CUDA
   // #if AT_CUDNN_ENABLED()
+  // #if HAS_CUDNN_V8()
   // TODO: limit this to per tensor quantized tensors for now, though should be easy to adapt
   // to per channel quantized tensors
   TORCH_CHECK(input.qscheme() == at::kPerTensorAffine, "adaptive_avg_pool2d_quantized_cuda only supports per tensor quantized tensors");
@@ -89,6 +91,7 @@ Tensor quantized_max_pool2d_cudnn(
     bool ceil_mode) {
 #ifdef USE_CUDA
 #if AT_CUDNN_ENABLED()
+#if HAS_CUDNN_V8()
   check_maxpool2d_params(
       kernel_size,
       stride,
@@ -204,6 +207,10 @@ Tensor quantized_max_pool2d_cudnn(
   // recall we casted our input and output to 4D if qx was 3D, so we recast it back to 3D prior to returning
   return (ndim == 3 ? qy.view(std::vector<int64_t>(output_shape.begin() + 1, output_shape.end())) : qy);
+#else // HAS_CUDNN_V8()
+  AT_ERROR("at::native::quantized_max_pool2d_cudnn: ATen not compiled with cuDNN v8 support");
+  return Tensor{}; // never reached, placates the compiler
+#endif // HAS_CUDNN_V8()
 #else // AT_CUDNN_ENABLED()
   AT_ERROR("at::native::quantized_max_pool2d_cudnn: ATen not compiled with cuDNN support");
   return Tensor{}; // never reached, placates the compiler
diff --git a/aten/src/ATen/native/quantized/cudnn/utils.h b/aten/src/ATen/native/quantized/cudnn/utils.h
index 6f6a3b1bf60f..5fd021383b12 100644
--- a/aten/src/ATen/native/quantized/cudnn/utils.h
+++ b/aten/src/ATen/native/quantized/cudnn/utils.h
@@ -8,6 +8,10 @@ This file contains some of the auxiliary functions used by both Conv.cpp & Linear.cpp
 
 #if AT_CUDNN_ENABLED()
 
+#include <ATen/native/cudnn/Macros.h>
+
+#if HAS_CUDNN_V8()
+
 #include
 #include
 #include
@@ -350,5 +354,6 @@ cudnn_frontend::ExecutionPlan get_execplan_from_heuristics_else_fall_back(cudnn_
 
 } // anonymous
 } // cudnn_utils
 
+#endif // HAS_CUDNN_V8
 #endif // AT_CUDNN_ENABLED
 #endif // USE_CUDA
diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt
index 7e71a542a96c..c0585b9f05ae 100644
--- a/caffe2/CMakeLists.txt
+++ b/caffe2/CMakeLists.txt
@@ -1393,6 +1393,12 @@ elseif(USE_ROCM)
   target_compile_definitions(torch_hip PRIVATE "-DTORCH_HIP_BUILD_MAIN_LIB")
 endif()
 
+if(USE_EXPERIMENTAL_CUDNN_V8_API)
+  if(USE_CUDA)
+    target_compile_definitions(torch_cuda PRIVATE "-DUSE_EXPERIMENTAL_CUDNN_V8_API")
+  endif()
+endif()
+
 set(EXPERIMENTAL_SINGLE_THREAD_POOL "0" CACHE STRING
     "Experimental option to use a single thread pool for inter- and intra-op parallelism")
 if("${EXPERIMENTAL_SINGLE_THREAD_POOL}")
diff --git a/cmake/Summary.cmake b/cmake/Summary.cmake
index 3b41664e6310..23c9cd8eeb77 100644
--- a/cmake/Summary.cmake
+++ b/cmake/Summary.cmake
@@ -77,6 +77,7 @@ function(caffe2_print_configuration_summary)
     message(STATUS "    Split CUDA           : ${BUILD_SPLIT_CUDA}")
     message(STATUS "    CUDA static link     : ${CAFFE2_STATIC_LINK_CUDA}")
     message(STATUS "    USE_CUDNN            : ${USE_CUDNN}")
+    message(STATUS "    USE_EXPERIMENTAL_CUDNN_V8_API: ${USE_EXPERIMENTAL_CUDNN_V8_API}")
    message(STATUS "    CUDA version         : ${CUDA_VERSION}")
     message(STATUS "    USE_FLASH_ATTENTION  : ${USE_FLASH_ATTENTION}")
     if(${USE_CUDNN})
diff --git a/third_party/cudnn_frontend.BUILD b/third_party/cudnn_frontend.BUILD
deleted file mode 100644
index 0af69797a3e2..000000000000
--- a/third_party/cudnn_frontend.BUILD
+++ /dev/null
@@ -1,22 +0,0 @@
-# Adopted from: https://github.com/tensorflow/tensorflow/blob/master/third_party/cudnn_frontend.BUILD
-
-# Description:
-# The cuDNN Frontend API is a C++ header-only library that demonstrates how
-# to use the cuDNN C backend API.
-
-load("@rules_cc//cc:defs.bzl", "cc_library")
-
-package(
-    default_visibility = ["//visibility:public"],
-)
-
-licenses(["notice"])  # MIT
-
-exports_files(["LICENSE.txt"])
-
-cc_library(
-    name = "cudnn_frontend",
-    hdrs = glob(["include/**"]),
-    includes = ["include/"],
-    include_prefix = "third_party/cudnn_frontend",
-)
diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt
index e775c83cd9b4..e5d13b57535d 100644
--- a/torch/CMakeLists.txt
+++ b/torch/CMakeLists.txt
@@ -150,6 +150,10 @@ if(USE_ROCM)
     list(APPEND TORCH_PYTHON_INCLUDE_DIRECTORIES ${roctracer_INCLUDE_DIRS})
 endif()
 
+if(USE_EXPERIMENTAL_CUDNN_V8_API)
+    list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_EXPERIMENTAL_CUDNN_V8_API)
+endif()
+
 if(USE_CUDNN OR USE_ROCM)
     list(APPEND TORCH_PYTHON_SRCS
       ${TORCH_SRC_DIR}/csrc/cuda/shared/cudnn.cpp
diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp
index 20f0a407510d..331b6add4434 100644
--- a/torch/csrc/cuda/Module.cpp
+++ b/torch/csrc/cuda/Module.cpp
@@ -1146,7 +1146,12 @@ PyObject* THCPModule_setBenchmarkLimitCuDNN(PyObject* _unused, PyObject* arg) {
       "cuDNN Benchmark limit is not supported in MIOpen and will have no effect.");
 #endif
 #if AT_CUDNN_ENABLED()
+#if HAS_CUDNN_V8()
   at::globalContext().setBenchmarkLimitCuDNN(benchmark_limit);
+#else
+  TORCH_WARN_ONCE(
+      "cuDNN Benchmark limit is not supported with cuDNN v7 API and will have no effect.");
+#endif
 #endif
   Py_RETURN_NONE;
 }
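Taken together with the Module.cpp hunk above, the benchmark limit becomes a v8-only feature: a v7-only build still accepts the call but warns once and drops the value. A standalone approximation of that fallback idiom (the real code uses TORCH_WARN_ONCE and stores the limit in the global Context):

#include <iostream>

// Standalone sketch of the v7 fallback wired up above; set_benchmark_limit
// is a simplified stand-in for THCPModule_setBenchmarkLimitCuDNN.
void set_benchmark_limit(int limit) {
#if defined(HAS_CUDNN_V8) && HAS_CUDNN_V8()
  // v8 builds: persist the limit; the benchmarking path reads it later.
  (void)limit;
#else
  // Warn exactly once, then become a no-op, mirroring TORCH_WARN_ONCE.
  static const bool warned = [] {
    std::cerr << "cuDNN Benchmark limit is not supported with the cuDNN v7 "
                 "API and will have no effect.\n";
    return true;
  }();
  (void)warned;
  (void)limit;
#endif
}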