[BE] [cuDNN] Always build assuming cuDNN >= 8.0 (#95722)

<!--
copilot:summary
-->
### <samp>🤖 Generated by Copilot at 27084ed</samp>

This pull request simplifies and cleans up the code that uses the cuDNN library for convolution, batch normalization, CTC loss, and quantized operations. It removes the unnecessary checks and conditions for older cuDNN versions and the experimental cuDNN v8 API, and ~~replaces them with the stable `cudnn_frontend` API that requires cuDNN v8 or higher. It also adds the dependency and configuration for the `cudnn_frontend` library in the cmake and bazel files.~~ Correction: The v7 API will still be available with this PR, and can still be used, without any changes to the defaults. This change simply always _builds_ the v8 API, and removes the case where _only_ the v7 API is built.
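
For illustration, a minimal sketch (not PyTorch's actual dispatch code) of the kind of runtime gate this relies on: the v8 path is always compiled in, and the `TORCH_CUDNN_V8_API_DISABLED=1` environment variable referenced in the preserved v7 sources can still route calls back to the `*_v7` entry points. The helper below is hypothetical.

```cpp
#include <cstdlib>
#include <cstring>

// Hypothetical sketch only: illustrates how a build that always contains the
// v8 code path can still fall back to the preserved v7 path at runtime.
static bool cudnn_v8_api_disabled() {
  const char* flag = std::getenv("TORCH_CUDNN_V8_API_DISABLED");
  return flag != nullptr && std::strcmp(flag, "1") == 0;
}

void raw_cudnn_convolution_forward_out(/* tensors and conv params */) {
  if (cudnn_v8_api_disabled()) {
    // raw_cudnn_convolution_forward_out_v7(...);          // legacy path, still compiled
  } else {
    // build and run a cudnn_frontend (v8) execution plan  // default path
  }
}
```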

This is a re-land of https://github.com/pytorch/pytorch/pull/91527

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95722
Approved by: https://github.com/malfet
Authored by: Eddie Yan, 2023-11-08 07:53:23 +00:00
Committed by: PyTorch MergeBot
parent 8ba11bf79d
commit df4f0b3829
24 changed files with 55 additions and 170 deletions

View File

@@ -413,6 +413,7 @@ cc_library(
         "@cuda//:cusolver",
         "@cuda//:nvrtc",
         "@cudnn",
+        "@cudnn_frontend",
     ],
     alwayslink = True,
 )

View File

@@ -213,9 +213,6 @@ cmake_dependent_option(
 cmake_dependent_option(
   USE_CUSPARSELT "Use cuSPARSELt" ON
   "USE_CUDA" OFF)
-cmake_dependent_option(
-  USE_EXPERIMENTAL_CUDNN_V8_API "Use experimental cuDNN v8 API" ON
-  "USE_CUDNN" OFF)
 option(USE_FBGEMM "Use FBGEMM (quantized 8-bit server operators)" ON)
 option(USE_KINETO "Use Kineto profiling library" ON)
 option(USE_CUPTI_SO "Use CUPTI as a shared library" ON)

View File

@@ -246,6 +246,12 @@ new_local_repository(
     path = "/usr/",
 )
+new_local_repository(
+    name = "cudnn_frontend",
+    build_file = "@//third_party:cudnn_frontend.BUILD",
+    path = "third_party/cudnn_frontend/",
+)
 local_repository(
     name = "com_github_google_flatbuffers",
     path = "third_party/flatbuffers",

View File

@@ -305,7 +305,6 @@ struct TORCH_CUDA_CPP_API CTCLossDescriptor
   void set(cudnnDataType_t datatype) {
     AT_CUDNN_CHECK(cudnnSetCTCLossDescriptor(mut_desc(), datatype));
   }
-#if CUDNN_VERSION >= 7600
   void setEx(
       cudnnDataType_t datatype,
       cudnnLossNormalizationMode_t normMode,
@@ -313,7 +312,6 @@ struct TORCH_CUDA_CPP_API CTCLossDescriptor
     AT_CUDNN_CHECK(
         cudnnSetCTCLossDescriptorEx(mut_desc(), datatype, normMode, gradMode));
   }
-#endif
 };
 struct TORCH_CUDA_CPP_API ActivationDescriptor

View File

@@ -59,11 +59,7 @@ cudnnBatchNormMode_t getCudnnBatchNormMode(bool training, at::MemoryFormat memory_format
     return CUDNN_BATCHNORM_PER_ACTIVATION;
   } else if (training && memory_format == at::MemoryFormat::ChannelsLast) {
-#if CUDNN_VERSION >= 7400
     return CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
-#else
-    return CUDNN_BATCHNORM_SPATIAL;
-#endif // CUDNN_VERSION >= 7400
   } else if (training && memory_format == at::MemoryFormat::ChannelsLast3d) {
@@ -152,7 +148,6 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> cudnn_batch_norm(
     save_mean = at::empty({ num_features }, weight_t.options());
     save_var = at::empty({ num_features }, weight_t.options());
-#if CUDNN_VERSION >= 7400
     auto op = CUDNN_BATCHNORM_OPS_BN;
     size_t workspace_size;
     AT_CUDNN_CHECK(cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(
@@ -204,22 +199,6 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> cudnn_batch_norm(
         workspace_size,
         reserve.mutable_data_ptr(),
         reserve_size));
-#else
-    reserve = at::empty({0}, input->options().dtype(kByte));
-    AT_CUDNN_CHECK(cudnnBatchNormalizationForwardTraining(
-        handle, mode, &one, &zero,
-        idesc.desc(), input->data_ptr(),
-        idesc.desc(), output->data_ptr(),
-        wdesc.desc(),
-        weight->data_ptr(),
-        bias->data_ptr(),
-        exponential_average_factor,
-        at::maybe_data_ptr(running_mean),
-        at::maybe_data_ptr(running_var),
-        epsilon,
-        save_mean.mutable_data_ptr(),
-        save_var.mutable_data_ptr()));
-#endif // CUDNN_VERSION >= 7400
   } else {
     reserve = at::empty({0}, input->options().dtype(kByte));
     // This keeps a consistent output with native_batch_norm
@@ -317,7 +296,6 @@ std::tuple<Tensor, Tensor, Tensor> cudnn_batch_norm_backward(
   Constant one(dataType, 1);
   Constant zero(dataType, 0);
-#if CUDNN_VERSION >= 7400
   auto op = CUDNN_BATCHNORM_OPS_BN;
   size_t workspace_size;
@@ -354,19 +332,6 @@ std::tuple<Tensor, Tensor, Tensor> cudnn_batch_norm_backward(
       workspace_size,
       reserve->data_ptr(),
       reserve->numel()));
-#else
-  AT_CUDNN_CHECK(cudnnBatchNormalizationBackward(
-      handle, mode, &one, &zero, &one, &zero,
-      idesc.desc(), input->data_ptr(),
-      odesc.desc(), grad_output->data_ptr(),
-      idesc.desc(), grad_input_t.data_ptr(),
-      wdesc.desc(), weight->data_ptr(),
-      grad_weight_t.data_ptr(),
-      grad_bias_t.data_ptr(),
-      epsilon,
-      save_mean->data_ptr(),
-      save_var->data_ptr()));
-#endif // CUDNN_VERSION >= 7400
   return std::tuple<Tensor,Tensor,Tensor>{grad_input_t, grad_weight_t, grad_bias_t};
 }

View File

@@ -109,9 +109,7 @@ void raw_cudnn_convolution_add_relu_fallback_out(
 #if AT_CUDNN_ENABLED()
-#include <ATen/native/cudnn/Macros.h>
-#if HAS_CUDNN_V8()
 // v7 functions are preserved here to allow for runtime switching to v7
 // (e.g., TORCH_CUDNN_V8_API_DISABLED=1).
 // Note that v7 forward/backward out can have different behavior from the v8
@@ -149,5 +147,4 @@ void raw_cudnn_convolution_add_relu_out_v7(
     bool deterministic,
     bool allow_tf32);
 #endif
-#endif
 }}

View File

@@ -3,7 +3,6 @@
 #if AT_CUDNN_ENABLED()
-#include <ATen/native/cudnn/Macros.h>
 #include <ATen/core/Tensor.h>
 #ifndef AT_PER_OPERATOR_HEADERS
@@ -60,10 +59,6 @@
 // with the best algo, under the hood, cudnn will run with the slower kernel
 // since it sees fastest algorithm combination with a sub optimal mathType.
-// Note [blocklist fft algorithms for strided dgrad]
-// This is a workaround for a CuDNN bug that gave wrong results in certain strided convolution
-// gradient setups. Check Issue #16610 for bug details. Bug is there for CUDNN version < 7.5 .
 constexpr size_t operator "" _TiB(unsigned long long n) {
   return size_t(n) * 1024 * 1024 * 1024 * 1024;
 }
@@ -225,15 +220,6 @@ size_t getMaxWorkspaceSize(
 template<typename perf_t>
 std::vector<perf_t> getValidAlgorithms(perf_t *perfResults, const ConvolutionArgs& args, int n_algo) {
-  // See Note [blocklist fft algorithms for strided dgrad]
-#if CUDNN_VERSION < 7500
-  bool blocklist = std::is_same<decltype(perfResults[0].algo), cudnnConvolutionBwdDataAlgo_t>::value;
-  int stride_dim = args.input.dim() - 2;
-  blocklist &= std::any_of(std::begin(args.params.stride),
-                           std::begin(args.params.stride) + stride_dim,
-                           [=](int n){return n != 1;});
-#endif
   std::vector<perf_t> result;
   result.reserve(n_algo);
   for (const auto i : c10::irange(n_algo)) {
@@ -244,16 +230,6 @@ std::vector<perf_t> getValidAlgorithms(perf_t *perfResults, const ConvolutionArg
     if (perf.status == CUDNN_STATUS_SUCCESS) {
       if (!args.params.deterministic || perf.determinism == CUDNN_DETERMINISTIC) {
-        // See Note [blocklist fft algorithms for strided dgrad]
-#if CUDNN_VERSION < 7500
-        bool skip = blocklist;
-        skip &= (static_cast<cudnnConvolutionBwdDataAlgo_t>(perfResults[i].algo) == CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING ||
-                 static_cast<cudnnConvolutionBwdDataAlgo_t>(perfResults[i].algo) == CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT);
-        if (skip) {
-          continue;
-        }
-#endif
         result.push_back(perf);
       }
     }
@@ -493,11 +469,9 @@ public:
       perfResults[0].mathType = CUDNN_TENSOR_OP_MATH;
     } else {
       perfResults[0].mathType = CUDNN_DEFAULT_MATH;
-#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 8000
       if (args.params.dataType == CUDNN_DATA_FLOAT && !args.params.allow_tf32) {
        perfResults[0].mathType = CUDNN_FMA_MATH;
      }
-#endif
    }
    search::getWorkspaceSize(args, perfResults[0].algo, &(perfResults[0].memory));
    return perfResults;
@@ -610,14 +584,10 @@ static inline void split_batch_dim_to_32bit_out(
 }
-#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 8000
 #define ASSERT_CORRECT_PRECISION(math_type) \
   if (args.params.dataType == CUDNN_DATA_FLOAT) { \
     TORCH_INTERNAL_ASSERT(args.params.allow_tf32 || math_type == CUDNN_FMA_MATH); \
   }
-#else
-#define ASSERT_CORRECT_PRECISION(math_type)
-#endif // CUDNN_VERSION >= 8000
 // ---------------------------------------------------------------------
@@ -672,11 +642,7 @@ void raw_cudnn_convolution_forward_out_32bit(
 }
-#if !HAS_CUDNN_V8()
-void raw_cudnn_convolution_forward_out(
-#else
 void raw_cudnn_convolution_forward_out_v7(
-#endif
     const Tensor& output, const Tensor& input, const Tensor& weight,
     IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
     bool benchmark, bool deterministic, bool allow_tf32) {
@@ -734,11 +700,7 @@ void raw_cudnn_convolution_backward_input_out_32bit(
   );
 }
-#if !HAS_CUDNN_V8()
-void raw_cudnn_convolution_backward_input_out(
-#else
 void raw_cudnn_convolution_backward_input_out_v7(
-#endif
     const at::Tensor& grad_input,
     const at::Tensor& grad_output,
     const at::Tensor& weight,
@@ -797,11 +759,7 @@ void raw_cudnn_convolution_backward_weight_out_32bit(
   );
 }
-#if !HAS_CUDNN_V8()
-void raw_cudnn_convolution_backward_weight_out(
-#else
 void raw_cudnn_convolution_backward_weight_out_v7(
-#endif
     const Tensor& grad_weight, const Tensor& grad_output, const Tensor& input,
     IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
     bool benchmark, bool deterministic, bool allow_tf32) {
@@ -853,12 +811,7 @@ void raw_cudnn_convolution_backward_weight_out_v7(
   TORCH_INTERNAL_ASSERT(false, "This case should not be dispatched to cuDNN.");
 }
-#if !HAS_CUDNN_V8()
-void raw_cudnn_convolution_add_relu_out(
-#else
 void raw_cudnn_convolution_add_relu_out_v7(
-#endif
     const Tensor& output,
     const Tensor& input,
     const Tensor& weight,

View File

@@ -4,10 +4,6 @@
 #if AT_CUDNN_ENABLED()
-#include <ATen/native/cudnn/Macros.h>
-#if HAS_CUDNN_V8()
 #include <ATen/cudnn/cudnn-wrapper.h>
 #include <c10/macros/Macros.h>
@@ -800,5 +796,4 @@
 }} // at::native
-#endif // HAS_CUDNN_V8
 #endif // AT_CUDNN_ENABLED

View File

@@ -18,7 +18,7 @@
 #include <ATen/ops/empty_like.h>
 #endif
-#if (!AT_CUDNN_ENABLED()) || (CUDNN_VERSION < 7600)
+#if (!AT_CUDNN_ENABLED())
 namespace at { namespace native {

View File

@@ -1,12 +0,0 @@
-#pragma once
-#include <ATen/cudnn/cudnn-wrapper.h>
-// Note: The version below should not actually be 8000. Instead, it should
-// be whatever version of cuDNN that v8 API work with PyTorch correctly.
-// The version is set to 8000 today for convenience of debugging.
-#if defined(USE_EXPERIMENTAL_CUDNN_V8_API) && defined(CUDNN_VERSION) && CUDNN_VERSION >= 8200
-#define HAS_CUDNN_V8() true
-#else
-#define HAS_CUDNN_V8() false
-#endif

View File

@@ -28,7 +28,6 @@
 #include <utility>
 #endif
-int register_linear_params();
 int register_embedding_params();
 #ifdef USE_FBGEMM
@@ -437,7 +436,9 @@ TORCH_API int register_conv_params<2>();
 template
 TORCH_API int register_conv_params<3>();
-int register_linear_params() {
+TORCH_API int register_linear_params();
+TORCH_API int register_linear_params() {
   using SerializationType = std::tuple<at::Tensor, c10::optional<at::Tensor>>;
   static auto register_linear_params =
       torch::selective_class_<LinearPackedParamsBase>(

View File

@@ -2,8 +2,6 @@
 #include <ATen/cuda/CUDAConfig.h> // for the definition of AT_CUDNN_ENABLED
 #if AT_CUDNN_ENABLED()
-#include <ATen/native/cudnn/Macros.h>
-#if HAS_CUDNN_V8()
 #include <ATen/core/TensorBase.h>
 #include <ATen/core/TensorBody.h>
@@ -259,6 +257,5 @@
 } // namespace native
 } // namespace at
-#endif // HAS_CUDNN_V8
 #endif // AT_CUDNN_ENABLED
 #endif // USE_CUDA

View File

@@ -3,11 +3,8 @@
 #if AT_CUDNN_ENABLED()
-#include <ATen/native/cudnn/Macros.h>
 #include <c10/util/ArrayRef.h>
-#if HAS_CUDNN_V8()
 #include <ATen/ATen.h>
 #include <ATen/cuda/Exceptions.h>
 #include <ATen/cudnn/Handle.h>
@@ -25,6 +22,12 @@
 #include <unordered_map>
 #include <vector>
+template <int kSpatialDim = 2>
+int register_conv_params();
+extern template int register_conv_params<2>();
+extern template int register_conv_params<3>();
 // TODO: there is a table from input dtype and weight dtype to operator qdtype,
 // we can derive the operator dtype based on input dtype
 cudnn_frontend::ConvDesc_v8 getConvDescriptor(cudnnDataType_t dataType, c10::IntArrayRef padding, c10::IntArrayRef stride, c10::IntArrayRef dilation) {
@@ -391,6 +394,8 @@ TORCH_LIBRARY_IMPL(quantized, QuantizedCUDA, m) {
   // this is inconsistent with what has been done for conv2d where new variants use packed weights, and
   // old variant does not. we adopt this inconsistency for now to be consistent with QuantizedCPU's conv1d
   // and will eventually deprecate the old variants
+  register_conv_params<2>();
+  register_conv_params<3>();
   m.impl(TORCH_SELECTIVE_NAME("quantized::conv1d"), QConv1dInt8<false>::run);
   m.impl(TORCH_SELECTIVE_NAME("quantized::conv1d_relu"), QConv1dInt8<true>::run);
   m.impl(TORCH_SELECTIVE_NAME("quantized::conv2d.new"), QConvInt8<2, false>::run);
@@ -401,6 +406,5 @@ TORCH_LIBRARY_IMPL(quantized, QuantizedCUDA, m) {
 } // namespace at::native
-#endif // HAS_CUDNN_V8
 #endif // AT_CUDNN_ENABLED
 #endif // USE_CUDA
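
The `extern template` declarations added above tell these translation units that `register_conv_params<2>()` and `register_conv_params<3>()` are explicitly instantiated elsewhere, so the TORCH_LIBRARY_IMPL blocks can call them without re-instantiating the template locally. A self-contained sketch of the pattern (declarations mirror the diff; the definition body is hypothetical):

```cpp
#include <cstdio>

template <int kSpatialDim = 2>
int register_conv_params();                     // declaration, as in the files above

extern template int register_conv_params<2>();  // promise: instantiated in another TU
extern template int register_conv_params<3>();

// In exactly one translation unit (collapsed into this file for the sketch) the
// template is defined and explicitly instantiated.
template <int kSpatialDim>
int register_conv_params() {
  std::printf("registering conv params for %d spatial dims\n", kSpatialDim);
  return kSpatialDim;
}
template int register_conv_params<2>();
template int register_conv_params<3>();

int main() {
  // Call sites, like the TORCH_LIBRARY_IMPL blocks in this PR, simply invoke:
  register_conv_params<2>();
  register_conv_params<3>();
  return 0;
}
```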

View File

@@ -3,10 +3,6 @@
 #if AT_CUDNN_ENABLED()
-#include <ATen/native/cudnn/Macros.h>
-#if HAS_CUDNN_V8()
 #include <ATen/ATen.h>
 #include <torch/library.h>
 #include <ATen/native/quantized/cpu/QuantUtils.h>
@@ -20,6 +16,12 @@
 #include <array>
 #include <vector>
+template <int kSpatialDim = 2>
+int register_conv_params();
+extern template int register_conv_params<2>();
+extern template int register_conv_params<3>();
 template <int kSpatialDim>
 c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> PackedConvWeightCudnn<
     kSpatialDim>::
@@ -203,6 +205,8 @@ class QConv1dPackWeightInt8Cudnn final {
 };
 TORCH_LIBRARY_IMPL(quantized, QuantizedCUDA, m) {
+  register_conv_params<2>();
+  register_conv_params<3>();
   m.impl(TORCH_SELECTIVE_NAME("quantized::conv1d_prepack"), TORCH_FN(QConv1dPackWeightInt8Cudnn::run_conv));
   m.impl(TORCH_SELECTIVE_NAME("quantized::conv2d_prepack"), TORCH_FN(QConvPackWeightInt8Cudnn<2>::run_conv));
 }
@@ -211,6 +215,5 @@ TORCH_LIBRARY_IMPL(quantized, QuantizedCUDA, m) {
 } // namespace native
 } // namespace at
-#endif // HAS_CUDNN_V8
 #endif // AT_CUDNN_ENABLED
 #endif // USE_CUDA

View File

@@ -3,10 +3,6 @@
 #if AT_CUDNN_ENABLED()
-#include <ATen/native/cudnn/Macros.h>
-#if HAS_CUDNN_V8()
 #include <ATen/ATen.h>
 #include <ATen/native/quantized/cudnn/utils.h>
 #include <ATen/native/quantized/PackedParams.h>
@@ -23,6 +19,5 @@ std::tuple<at::Tensor, c10::optional<at::Tensor>> PackedConvWeightCudnn<
 template std::tuple<at::Tensor, c10::optional<at::Tensor>> PackedConvWeightCudnn<
     2>::unpack();
-#endif // HAS_CUDNN_V8
 #endif // AT_CUDNN_ENABLED
 #endif // USE_CUDA

View File

@@ -3,11 +3,8 @@
 #if AT_CUDNN_ENABLED()
-#include <ATen/native/cudnn/Macros.h>
 #include <c10/util/ArrayRef.h>
-#if HAS_CUDNN_V8()
 #include <ATen/ATen.h>
 #include <ATen/cuda/Exceptions.h>
 #include <ATen/cudnn/Handle.h>
@@ -25,6 +22,8 @@
 #include <iostream>
 #include <unordered_map>
+int register_linear_params();
 // TODO: there is a table from input dtype and weight dtype to operator dtype,
 // we can derive the operator dtype based on input dtype
 cudnn_frontend::MatMulDesc_v8 getLinearDescriptor(cudnnDataType_t dataType) {
@@ -358,6 +357,7 @@ class QLinearInt8 final {
 };
 TORCH_LIBRARY_IMPL(quantized, QuantizedCUDA, m) {
+  register_linear_params();
   m.impl(TORCH_SELECTIVE_NAME("quantized::linear"), QLinearInt8<false>::run);
   m.impl(TORCH_SELECTIVE_NAME("quantized::linear_relu"), QLinearInt8<true>::run);
 }
@@ -367,6 +367,5 @@ TORCH_LIBRARY_IMPL(quantized, QuantizedCUDA, m) {
 } // namespace at
-#endif // HAS_CUDNN_V8
 #endif // AT_CUDNN_ENABLED
 #endif // USE_CUDA

View File

@@ -3,10 +3,6 @@
 #if AT_CUDNN_ENABLED()
-#include <ATen/native/cudnn/Macros.h>
-#if HAS_CUDNN_V8()
 #include <ATen/ATen.h>
 #include <torch/library.h>
 #include <ATen/native/quantized/cudnn/utils.h>
@@ -16,6 +12,8 @@
 #include <c10/util/irange.h>
 #include <torch/library.h>
+int register_linear_params();
 c10::intrusive_ptr<LinearPackedParamsBase> PackedLinearWeightCudnn::prepack(
     at::Tensor weight,
     c10::optional<at::Tensor> bias) {
@@ -50,6 +48,7 @@ class QLinearPackWeightInt8Cudnn final {
 };
 TORCH_LIBRARY_IMPL(quantized, QuantizedCUDA, m) {
+  register_linear_params();
   m.impl(TORCH_SELECTIVE_NAME("quantized::linear_prepack"), TORCH_FN(QLinearPackWeightInt8Cudnn::run));
 }
@@ -58,6 +57,5 @@ TORCH_LIBRARY_IMPL(quantized, QuantizedCUDA, m) {
 } // namespace native
 } // namespace at
-#endif // HAS_CUDNN_V8
 #endif // AT_CUDNN_ENABLED
 #endif // USE_CUDA

View File

@@ -3,10 +3,6 @@
 #if AT_CUDNN_ENABLED()
-#include <ATen/native/cudnn/Macros.h>
-#if HAS_CUDNN_V8()
 #include <ATen/ATen.h>
 #include <ATen/native/quantized/cudnn/utils.h>
 #include <ATen/native/quantized/PackedParams.h>
@@ -18,6 +14,5 @@ std::tuple<at::Tensor, c10::optional<at::Tensor>> PackedLinearWeightCudnn::unpac
   return std::tuple<at::Tensor, c10::optional<at::Tensor>>{orig_weight, bias_};
 }
-#endif // HAS_CUDNN_V8
 #endif // AT_CUDNN_ENABLED
 #endif // USE_CUDA

View File

@@ -3,7 +3,6 @@
 #include <ATen/cuda/CUDAConfig.h> // for the definition of AT_CUDNN_ENABLED
 #if AT_CUDNN_ENABLED()
-#include <ATen/native/cudnn/Macros.h>
 #include <ATen/cuda/Exceptions.h>
 #include <ATen/cudnn/Descriptors.h>
 #include <ATen/cudnn/Handle.h>
@@ -54,7 +53,6 @@ Tensor adaptive_avg_pool2d_quantized_cuda(
   // TODO: renable these cudnn preprocessors like quantized_max_pool2d_cudnn below when we implement this function with cudnn
 #ifdef USE_CUDA
   // #if AT_CUDNN_ENABLED()
-  // #if HAS_CUDNN_V8()
   // TODO: limit this to per tensor quantized tensors for now, though should be easy to adapt
   // to per channel quantized tensors
   TORCH_CHECK(input.qscheme() == at::kPerTensorAffine, "adaptive_avg_pool2d_quantized_cuda oonly supports per tensor quantized tensors");
@@ -91,7 +89,6 @@ Tensor quantized_max_pool2d_cudnn(
     bool ceil_mode) {
 #ifdef USE_CUDA
 #if AT_CUDNN_ENABLED()
-#if HAS_CUDNN_V8()
   check_maxpool2d_params(
       kernel_size,
       stride,
@@ -207,10 +204,6 @@ Tensor quantized_max_pool2d_cudnn(
   // recall we casted our input and output to 4D if qx was 3D, so we recast it back to 3D prior to returning
   return (ndim == 3 ? qy.view(std::vector<int64_t>(output_shape.begin() + 1, output_shape.end())) : qy);
-#else // HAS_CUDNN_V8()
-  AT_ERROR("at::native::quantized_max_pool2d_cudnn: ATen not compiled with cuDNN v8 support");
-  return Tensor{}; // never reached, placates the compiler
-#endif // HAS_CUDNN_V8()
 #else // AT_CUDNN_ENABLED()
   AT_ERROR("at::native::quantized_max_pool2d_cudnn: ATen not compiled with cuDNN support");
   return Tensor{}; // never reached, placates the compiler

View File

@@ -8,10 +8,6 @@ This file contains some of the auxiliary functions used by both Conv.cpp & Linea
 #if AT_CUDNN_ENABLED()
-#include <ATen/native/cudnn/Macros.h>
-#if HAS_CUDNN_V8()
 #include <ATen/cudnn/Types.h>
 #include <ATen/Tensor.h>
 #include <ATen/native/quantized/PackedParams.h>
@@ -354,6 +350,5 @@ cudnn_frontend::ExecutionPlan get_execplan_from_heuristics_else_fall_back(cudnn_
 } // anonymous
 } // cudnn_utils
-#endif // HAS_CUDNN_V8
 #endif // AT_CUDNN_ENABLED
 #endif // USE_CUDA

View File

@@ -1394,12 +1394,6 @@ elseif(USE_ROCM)
     target_compile_definitions(torch_hip PRIVATE TORCH_HIP_BUILD_MAIN_LIB)
   endif()
-  if(USE_EXPERIMENTAL_CUDNN_V8_API)
-    if(USE_CUDA)
-      target_compile_definitions(torch_cuda PRIVATE "-DUSE_EXPERIMENTAL_CUDNN_V8_API")
-    endif()
-  endif()
   set(EXPERIMENTAL_SINGLE_THREAD_POOL "0" CACHE STRING
       "Experimental option to use a single thread pool for inter- and intra-op parallelism")
   if("${EXPERIMENTAL_SINGLE_THREAD_POOL}")

third_party/cudnn_frontend.BUILD (new vendored file, 22 lines added)
View File

@@ -0,0 +1,22 @@
+# Adopted from: https://github.com/tensorflow/tensorflow/blob/master/third_party/cudnn_frontend.BUILD
+# Description:
+# The cuDNN Frontend API is a C++ header-only library that demonstrates how
+# to use the cuDNN C backend API.
+load("@rules_cc//cc:defs.bzl", "cc_library")
+package(
+    default_visibility = ["//visibility:public"],
+)
+licenses(["notice"])  # MIT
+exports_files(["LICENSE.txt"])
+cc_library(
+    name = "cudnn_frontend",
+    hdrs = glob(["include/**"]),
+    includes = ["include/"],
+    include_prefix = "third_party/cudnn_frontend",
+)
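
Because `cudnn_frontend` is header-only, depending targets only need the include path and the umbrella header; descriptors are then assembled through fluent builders, which is how the `getConvDescriptor`/`getLinearDescriptor` helpers elsewhere in this PR use it. A rough sketch; the builder method names are assumed from the cudnn_frontend v0.x API and should be checked against the vendored headers:

```cpp
#include <cudnn_frontend.h>  // umbrella header from third_party/cudnn_frontend/include

#include <array>
#include <cstdint>

// Rough sketch (assumed v0.x builder names): build a 2-D convolution descriptor
// in the same style as the getConvDescriptor helpers touched by this PR.
cudnn_frontend::ConvDesc_v8 make_conv_descriptor() {
  constexpr int64_t ndims = 2;
  std::array<int64_t, ndims> stride{1, 1};
  std::array<int64_t, ndims> padding{0, 0};
  std::array<int64_t, ndims> dilation{1, 1};
  return cudnn_frontend::ConvDescBuilder()
      .setDataType(CUDNN_DATA_FLOAT)  // newer releases also expose setComputeType
      .setMathMode(CUDNN_CROSS_CORRELATION)
      .setNDims(ndims)
      .setStrides(ndims, stride.data())
      .setPrePadding(ndims, padding.data())
      .setPostPadding(ndims, padding.data())
      .setDilation(ndims, dilation.data())
      .build();
}
```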

View File

@@ -145,10 +145,6 @@ if(USE_ROCM)
     list(APPEND TORCH_PYTHON_LINK_LIBRARIES ${ROCM_ROCTX_LIB})
 endif()
-if(USE_EXPERIMENTAL_CUDNN_V8_API)
-  list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_EXPERIMENTAL_CUDNN_V8_API)
-endif()
 if(USE_CUDNN OR USE_ROCM)
   list(APPEND TORCH_PYTHON_SRCS
     ${TORCH_SRC_DIR}/csrc/cuda/shared/cudnn.cpp

View File

@@ -12,8 +12,6 @@
 #if AT_CUDNN_ENABLED()
-#include <ATen/native/cudnn/Macros.h>
 #endif
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/cuda/CUDAGeneratorImpl.h>
@@ -1356,12 +1354,7 @@ PyObject* THCPModule_setBenchmarkLimitCuDNN(PyObject* _unused, PyObject* arg) {
       "cuDNN Benchmark limit is not supported in MIOpen and will have no effect.");
 #endif
 #if AT_CUDNN_ENABLED()
-#if HAS_CUDNN_V8()
   at::globalContext().setBenchmarkLimitCuDNN(benchmark_limit);
-#else
-  TORCH_WARN_ONCE(
-      "cuDNN Benchmark limit is not supported with cuDNN v7 API and will have no effect.");
-#endif
 #endif
   Py_RETURN_NONE;
 }