mirror of https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00

[cuDNN] Cleanup cuDNN < 8.1 ifdefs (#120862)

Follow-up of #95722

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120862
Approved by: https://github.com/Skylion007

committed by PyTorch MergeBot
parent b9087f8571
commit 967dd31621
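The change removes preprocessor branches that only existed to tolerate cuDNN releases older than 8.1. A hypothetical, self-contained illustration of why such guards become dead once the minimum version is raised (the names below are made up for the demo, not PyTorch's):

// Demo only: DEMO_CUDNN_VERSION is a stand-in for CUDNN_VERSION on a supported build.
#include <iostream>

#define DEMO_CUDNN_VERSION 8201  // assumes the new floor of cuDNN 8.1

static_assert(DEMO_CUDNN_VERSION >= 8100, "demo mirrors the new minimum of cuDNN 8.1");

int main() {
#if DEMO_CUDNN_VERSION >= 8000
  // Previously this kind of call sat behind a version guard; with the floor at 8.1
  // the guard is always true, so the #if/#else can be flattened away.
  std::cout << "would call cudnnSetConvolutionMathType(desc, CUDNN_FMA_MATH)\n";
#else
  std::cout << "pre-8.0 fallback (dead code after this commit)\n";
#endif
  return 0;
}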
@@ -210,9 +210,7 @@ struct TORCH_CUDA_CPP_API ConvolutionDescriptor
if(dataType == CUDNN_DATA_HALF) {
AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_TENSOR_OP_MATH));
} else if (dataType == CUDNN_DATA_FLOAT && !allow_tf32) {
#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 8000
AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_FMA_MATH));
#endif
}
}
};
@@ -304,13 +302,9 @@ struct TORCH_CUDA_CPP_API RNNDescriptor : public Descriptor<
if (input_type == CUDNN_DATA_HALF) {
cudnnSetRNNMatrixMathType(mut_desc(), CUDNN_TENSOR_OP_MATH);
}
#endif
#if !defined(USE_CUDNN_RNN_V8_API) && defined(CUDNN_VERSION) && CUDNN_VERSION >= 8000
else if (input_type == CUDNN_DATA_FLOAT && !allow_tf32) {
cudnnSetRNNMatrixMathType(mut_desc(), CUDNN_FMA_MATH);
}
#endif
#ifndef USE_CUDNN_RNN_V8_API
else {
// Technically, as the default it's not necessary to explicitly
// set this.
@@ -78,15 +78,8 @@ cudnnBatchNormMode_t getCudnnBatchNormMode(
return CUDNN_BATCHNORM_PER_ACTIVATION;
} else if (training && memory_format == at::MemoryFormat::ChannelsLast) {
return CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
} else if (training && memory_format == at::MemoryFormat::ChannelsLast3d) {
#if CUDNN_VERSION >= 8100
return CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
#else
return CUDNN_BATCHNORM_SPATIAL;
#endif // CUDNN_VERSION >= 8100
} else {
// TODO: The new CUDNN_BATCHNORM_SPATIAL_PERSISTENT mode was
// introduced in CuDNN 7 for performance optimization, but it results in
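For orientation, a hedged sketch of the branch this hunk simplifies: with cuDNN >= 8.1 guaranteed, ChannelsLast and ChannelsLast3d training inputs both map to the persistent spatial mode, and the version guard disappears. The signature and the fallback arm are assumptions for the sketch, not the verbatim PyTorch source (c10::MemoryFormat is used here in place of the at:: alias):

// Sketch only, assuming a (training, memory_format) parameter pair as suggested by the hunk above.
#include <cudnn.h>
#include <c10/core/MemoryFormat.h>

static cudnnBatchNormMode_t getCudnnBatchNormModeSketch(
    bool training, c10::MemoryFormat memory_format) {
  if (training && (memory_format == c10::MemoryFormat::ChannelsLast ||
                   memory_format == c10::MemoryFormat::ChannelsLast3d)) {
    // No #if CUDNN_VERSION >= 8100 needed once 8.1 is the minimum.
    return CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
  }
  // Assumed default arm, mirroring the pre-existing fallback mode.
  return CUDNN_BATCHNORM_SPATIAL;
}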
@@ -735,23 +735,6 @@ Tensor cudnn_convolution_relu(
output_t.options().device_opt(),
output_t.options().pinned_memory_opt());
#ifdef AT_CUDNN_CONV_BIAS_RELU_FALLBACK
raw_cudnn_convolution_add_relu_fallback_out(
output_t,
input,
weight,
output_t, // use output_t as z to satisfy CUDNN API
0, // alpha
_bias,
stride,
padding,
dilation,
groups,
benchmark, // benchmark
false, // deterministic
allow_tf32 // allow_tf32
);
#else // AT_CUDNN_CONV_BIAS_RELU_FALLBACK
raw_cudnn_convolution_add_relu_out(
output_t,
input,
@@ -767,7 +750,6 @@ Tensor cudnn_convolution_relu(
false, // deterministic
allow_tf32 // allow_tf32
);
#endif
return output_t;
}
@@ -813,23 +795,6 @@ Tensor cudnn_convolution_add_relu(
output_t.options().device_opt(),
output_t.options().pinned_memory_opt());
#ifdef AT_CUDNN_CONV_BIAS_RELU_FALLBACK
raw_cudnn_convolution_add_relu_fallback_out(
output_t,
input,
weight,
z,
_alpha,
_bias,
stride,
padding,
dilation,
groups,
benchmark,
false, // deterministic
allow_tf32 // allow_tf32
);
#else // AT_CUDNN_CONV_BIAS_RELU_FALLBACK
raw_cudnn_convolution_add_relu_out(
output_t,
input,
@@ -845,7 +810,6 @@ Tensor cudnn_convolution_add_relu(
false, // deterministic
allow_tf32 // allow_tf32
);
#endif // AT_CUDNN_CONV_BIAS_RELU_FALLBACK
return output_t;
}
@@ -6,10 +6,6 @@
#include <ATen/cudnn/cudnn-wrapper.h>
#include <ATen/native/ConvUtils.h>
#if CUDNN_VERSION < 8000
#define AT_CUDNN_CONV_BIAS_RELU_FALLBACK
#endif
namespace at {
namespace native {
@@ -16,17 +16,12 @@
#include <cudnn.h>
static_assert(
CUDNN_VERSION >= 5000,
"Caffe2 requires cudnn version 5.0 or above.");
#if CUDNN_VERSION < 6000
#pragma message "CUDNN version under 6.0 is supported at best effort."
#pragma message "We strongly encourage you to move to 6.0 and above."
#pragma message "This message is intended to annoy you enough to update."
#endif // CUDNN_VERSION < 6000
CUDNN_VERSION >= 8200,
"Caffe2 requires cudnn version 8.2 or above.");
#define CUDNN_VERSION_MIN(major, minor, patch) \
(CUDNN_VERSION >= ((major) * 1000 + (minor) * 100 + (patch)))
(major >= 9 ? CUDNN_VERSION >= ((major) * 10000 + (minor) * 100 + (patch)) : \
CUDNN_VERSION >= ((major) * 1000 + (minor) * 100 + (patch)))
namespace caffe2 {
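The redefinition of CUDNN_VERSION_MIN above accounts for cuDNN 9 switching its version encoding from major * 1000 + minor * 100 + patch to major * 10000 + minor * 100 + patch. A small self-contained check of that arithmetic, using an illustrative stand-in for CUDNN_VERSION (the SKETCH_* names are hypothetical, not part of the patch):

// Sketch of the version arithmetic the new macro handles.
#include <cstdio>

#define SKETCH_CUDNN_VERSION 8907  // e.g. cuDNN 8.9.7 under the pre-9 encoding

#define SKETCH_CUDNN_VERSION_MIN(major, minor, patch)                                   \
  ((major) >= 9 ? SKETCH_CUDNN_VERSION >= ((major) * 10000 + (minor) * 100 + (patch))   \
                : SKETCH_CUDNN_VERSION >= ((major) * 1000 + (minor) * 100 + (patch)))

int main() {
  // 8.9.7 satisfies an 8.2.0 floor but not a 9.0.0 floor.
  std::printf("min 8.2.0: %d\n", SKETCH_CUDNN_VERSION_MIN(8, 2, 0) ? 1 : 0);  // prints 1
  std::printf("min 9.0.0: %d\n", SKETCH_CUDNN_VERSION_MIN(9, 0, 0) ? 1 : 0);  // prints 0
  return 0;
}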
@@ -135,7 +130,6 @@ class cudnnTypeWrapper<float> {
}
};
#if CUDNN_VERSION_MIN(6, 0, 0)
template <>
class cudnnTypeWrapper<int> {
public:
@@ -151,7 +145,6 @@ class cudnnTypeWrapper<int> {
return &v;
}
};
#endif // CUDNN_VERSION_MIN(6, 0, 0)
template <>
class cudnnTypeWrapper<double> {
@@ -37,29 +37,9 @@ class CudnnConvOpBase : public ConvPoolOpBase<CUDAContext> {
"The current padding scheme leads to unequal padding on the left "
"and right, which is not supported by cudnn.");
}
// dilated convolution supported by some algorithms in cuDNN v6
#if !(CUDNN_VERSION_MIN(6, 0, 0))
OPERATOR_NEEDS_FEATURE(
dilation_h() == 1 && dilation_w() == 1,
"The cudnn convolution does not support dilation yet.");
#endif
// dilated grouped convolution supported in cuDNN v7.1
#if !(CUDNN_VERSION_MIN(7, 1, 0))
if (group_ != 1) {
for (int dim = 0; dim < kernel_.size(); ++dim) {
OPERATOR_NEEDS_FEATURE(
dilation_[dim] == 1,
"When group is used, dilation should not be set at the same time.");
}
}
#endif
#if CUDNN_VERSION_MIN(7, 0, 0)
// verify TensorCore math is supported
enable_tensor_core_ &= TensorCoreAvailable();
#else
enable_tensor_core_ = false;
#endif
bool individual_force_algo = OperatorBase::HasArgument("force_algo_fwd") ||
OperatorBase::HasArgument("force_algo_dgrad") ||
@@ -108,11 +88,7 @@ class CudnnConvOpBase : public ConvPoolOpBase<CUDAContext> {
int H,
int W,
int D) {
#if CUDNN_VERSION_MIN(7, 0, 0)
const int CC = C;
#else
const int CC = C / group_;
#endif
switch (order_) {
case StorageOrder::NHWC:
if (size == 4) {
@@ -182,7 +158,6 @@ class CudnnConvOpBase : public ConvPoolOpBase<CUDAContext> {
int dilation_height = 0;
int dilation_width = 0;
#if CUDNN_VERSION_MIN(6, 0, 0)
CUDNN_ENFORCE(cudnnGetConvolution2dDescriptor(
input,
&pad_height,
@@ -193,19 +168,7 @@ class CudnnConvOpBase : public ConvPoolOpBase<CUDAContext> {
&dilation_width,
&mode,
&dataType));
#else
CUDNN_ENFORCE(cudnnGetConvolution2dDescriptor(
input,
&pad_height,
&pad_width,
&stride_height,
&stride_width,
&dilation_height,
&dilation_width,
&mode));
#endif
#if CUDNN_VERSION_MIN(6, 0, 0)
CUDNN_ENFORCE(cudnnSetConvolution2dDescriptor(
copy,
pad_height,
@@ -216,17 +179,6 @@ class CudnnConvOpBase : public ConvPoolOpBase<CUDAContext> {
dilation_width,
mode,
dataType));
#else
CUDNN_ENFORCE(cudnnSetConvolution2dDescriptor(
copy,
pad_height,
pad_width,
stride_height,
stride_width,
dilation_height,
dilation_width,
mode));
#endif
} else {
cudnnConvolutionMode_t mode;
cudnnDataType_t dataType;
@@ -278,7 +230,6 @@ class CudnnConvOpBase : public ConvPoolOpBase<CUDAContext> {
}
void SetConvDescFromArguments() {
#if CUDNN_VERSION_MIN(6, 0, 0)
if (kernel_.size() == 1 || kernel_.size() == 2) {
CUDNN_ENFORCE(cudnnSetConvolution2dDescriptor(
conv_desc_,
@@ -300,29 +251,6 @@ class CudnnConvOpBase : public ConvPoolOpBase<CUDAContext> {
CUDNN_CROSS_CORRELATION,
compute_type_));
}
#else
if (kernel_.size() == 2) {
CUDNN_ENFORCE(cudnnSetConvolution2dDescriptor(
conv_desc_,
pad_t(),
pad_l(),
stride_h(),
stride_w(),
1,
1,
CUDNN_CROSS_CORRELATION));
} else {
vector<int> ones(dilation_.size(), 1);
CUDNN_ENFORCE(cudnnSetConvolutionNdDescriptor(
conv_desc_,
kernel_.size(),
pads_.data(),
stride_.data(),
ones.data(),
CUDNN_CROSS_CORRELATION,
compute_type_));
}
#endif
}
void SetConvDescComputeType(
@@ -338,7 +266,6 @@ class CudnnConvOpBase : public ConvPoolOpBase<CUDAContext> {
int dilation_height = 0;
int dilation_width = 0;
#if CUDNN_VERSION_MIN(6, 0, 0)
CUDNN_ENFORCE(cudnnGetConvolution2dDescriptor(
conv_desc,
&pad_height,
@@ -349,19 +276,7 @@ class CudnnConvOpBase : public ConvPoolOpBase<CUDAContext> {
&dilation_width,
&mode,
&dataType));
#else
CUDNN_ENFORCE(cudnnGetConvolution2dDescriptor(
conv_desc,
&pad_height,
&pad_width,
&stride_height,
&stride_width,
&dilation_height,
&dilation_width,
&mode));
#endif
#if CUDNN_VERSION_MIN(6, 0, 0)
CUDNN_ENFORCE(cudnnSetConvolution2dDescriptor(
conv_desc,
pad_height,
@@ -372,17 +287,6 @@ class CudnnConvOpBase : public ConvPoolOpBase<CUDAContext> {
dilation_width,
mode,
math));
#else
CUDNN_ENFORCE(cudnnSetConvolution2dDescriptor(
conv_desc,
pad_height,
pad_width,
stride_height,
stride_width,
dilation_height,
dilation_width,
mode));
#endif
} else {
cudnnConvolutionMode_t mode;
cudnnDataType_t dataType;
@@ -576,9 +480,6 @@ bool CudnnConvOp::DoRunWithType() {
return true;
}
#if !CUDNN_VERSION_MIN(7, 0, 0)
int group_offset_filter = filter.numel() / group_;
#endif
// Set up the cudnn algorithms & workspace if necessary
bool input_changed = (X.sizes() != cudnn_input_dims_);
@@ -592,11 +493,7 @@ bool CudnnConvOp::DoRunWithType() {
if (filter_changed) {
cudnn_filter_dims_ = filter.sizes().vec();
if (kernel_.size() == 1 || kernel_.size() == 2) {
#if CUDNN_VERSION_MIN(7, 0, 0)
const int MM = M;
#else
const int MM = M / group_;
#endif
CUDNN_ENFORCE(cudnnSetFilter4dDescriptor(
filter_desc_,
cudnnTypeWrapper<T_W>::type,
@@ -607,12 +504,6 @@ bool CudnnConvOp::DoRunWithType() {
kernel_.size() == 1 ? 1 : kernel_w()));
} else {
vector<int> dims(filter.sizes().begin(), filter.sizes().end());
#if !CUDNN_VERSION_MIN(7, 0, 0)
// We only need to divide dims by group_ when CUDNN version < 7.0
// see CUDA group convolution doc: https://fburl.com/dgj6dvpd
order_ == StorageOrder::NCHW ? dims[1] /= group_
: dims[filter.ndim() - 1] /= group_;
#endif
CUDNN_ENFORCE(cudnnSetFilterNdDescriptor(
filter_desc_,
cudnnTypeWrapper<T_W>::type,
@@ -674,7 +565,6 @@ bool CudnnConvOp::DoRunWithType() {
compute_type_ = DetermineComputeTypeFromInput(X);
SetConvDescFromArguments();
#if CUDNN_VERSION_MIN(7, 0, 0)
if (enable_tensor_core_) {
CUDNN_ENFORCE(
cudnnSetConvolutionMathType(conv_desc_, CUDNN_TENSOR_OP_MATH));
@@ -682,7 +572,6 @@ bool CudnnConvOp::DoRunWithType() {
// enable cuDNN conv groups
CUDNN_CHECK(cudnnSetConvolutionGroupCount(conv_desc_, group_));
#endif
if (force_algo_[ALGO_FWD] >= 0) {
algo_ = (cudnnConvolutionFwdAlgo_t)force_algo_[ALGO_FWD];
@@ -808,7 +697,6 @@ bool CudnnConvOp::DoRunWithType() {
// Now, actually run the computation.
// Run directly through cuDNN if possible
#if CUDNN_VERSION_MIN(7, 0, 0)
cudnn_wrapper_.with_cudnn_state(cudnn_state_, [&](CuDNNState* state) {
CUDNN_ENFORCE(cudnnConvolutionForward(
state->cudnn_handle(),
@@ -825,27 +713,6 @@ bool CudnnConvOp::DoRunWithType() {
top_desc_,
Y->template mutable_data<T_Y>()));
});
#else
// otherwise manually run through groups
for (int i = 0; i < group_; ++i) {
cudnn_wrapper_.with_cudnn_state(cudnn_state_, [&](CuDNNState* state) {
CUDNN_ENFORCE(cudnnConvolutionForward(
state->cudnn_handle(),
cudnnTypeWrapper<T_X>::kOne(),
bottom_desc_,
X.template data<T_X>() + i * group_offset_X,
filter_desc_,
filter.template data<T_W>() + i * group_offset_filter,
conv_desc_,
algo_,
state->workspace().get(cudnn_ws_nbytes_),
cudnn_ws_nbytes_,
cudnnTypeWrapper<T_Y>::kZero(),
top_desc_,
Y->template mutable_data<T_Y>() + i * group_offset_Y));
});
}
#endif
// Bias
if (InputSize() == 3) {
auto& bias = Input(BIAS);
@@ -953,9 +820,6 @@ bool CudnnConvGradientOp::DoRunWithType() {
"If you set group, the number of output channels should be divisible "
"by group.");
#if !CUDNN_VERSION_MIN(7, 0, 0)
int group_offset_filter = filter.numel() / group_;
#endif
if (kernel_.size() == 1) {
ConvPoolOpBase<CUDAContext>::ComputePads({H});
} else if (kernel_.size() == 2) {
@@ -1003,11 +867,7 @@ bool CudnnConvGradientOp::DoRunWithType() {
if (filter_changed) {
cudnn_filter_dims_ = filter.sizes().vec();
if (kernel_.size() == 1 || kernel_.size() == 2) {
#if CUDNN_VERSION_MIN(7, 0, 0)
const int MM = M;
#else
const int MM = M / group_;
#endif
CUDNN_ENFORCE(cudnnSetFilter4dDescriptor(
filter_desc_,
cudnnTypeWrapper<T_W>::type,
@@ -1018,12 +878,6 @@ bool CudnnConvGradientOp::DoRunWithType() {
kernel_.size() == 1 ? 1 : kernel_w()));
} else {
vector<int> dims(filter.sizes().begin(), filter.sizes().end());
#if !CUDNN_VERSION_MIN(7, 0, 0)
// We only need to divide dims by group_ when CUDNN version < 7.0
// see CUDA group convolution doc: https://fburl.com/dgj6dvpd
order_ == StorageOrder::NCHW ? dims[1] /= group_
: dims[filter.ndim() - 1] /= group_;
#endif
CUDNN_ENFORCE(cudnnSetFilterNdDescriptor(
filter_desc_,
@@ -1091,7 +945,6 @@ bool CudnnConvGradientOp::DoRunWithType() {
DuplicateConvDesc(
conv_desc_, kernel_.size(), dilation_.size(), bwd_data_conv_desc_);
#if CUDNN_VERSION_MIN(7, 0, 0)
if (enable_tensor_core_) {
CUDNN_ENFORCE(cudnnSetConvolutionMathType(
bwd_filter_conv_desc_, CUDNN_TENSOR_OP_MATH));
@@ -1102,7 +955,6 @@ bool CudnnConvGradientOp::DoRunWithType() {
// set cuDNN groups if appropriate
CUDNN_CHECK(cudnnSetConvolutionGroupCount(bwd_filter_conv_desc_, group_));
CUDNN_CHECK(cudnnSetConvolutionGroupCount(bwd_data_conv_desc_, group_));
#endif
// Choose dW algorithm
if (force_algo_[ALGO_WGRAD] >= 0) {
@@ -1388,7 +1240,6 @@ bool CudnnConvGradientOp::DoRunWithType() {
dbias->template mutable_data<T_DB>()));
}
#if CUDNN_VERSION_MIN(7, 0, 0)
cudnn_wrapper_.with_cudnn_state(cudnn_state_, [&](CuDNNState* state) {
CUDNN_ENFORCE(cudnnConvolutionBackwardFilter(
state->cudnn_handle(),
@@ -1427,45 +1278,6 @@ bool CudnnConvGradientOp::DoRunWithType() {
dX->template mutable_data<T_DX>()));
}
});
#else
for (int i = 0; i < group_; ++i) {
cudnn_wrapper_.with_cudnn_state(cudnn_state_, [&](CuDNNState* state) {
CUDNN_ENFORCE(cudnnConvolutionBackwardFilter(
state->cudnn_handle(),
cudnnTypeWrapper<T_X>::kOne(),
bottom_desc_,
X.template data<T_X>() + i * group_offset_X,
top_desc_,
dY.template data<T_DY>() + i * group_offset_Y,
bwd_filter_conv_desc_,
bwd_filter_algo_,
state->workspace().get(cudnn_ws_nbytes_),
cudnn_ws_nbytes_,
cudnnTypeWrapper<T_DW>::kZero(),
filter_desc_,
dfilter->template mutable_data<T_DW>() + i * group_offset_filter));
if (OutputSize() == 3 || (no_bias_ && (OutputSize() == 2))) {
// Compute the gradient w.r.t. the input.
auto* dX = Output(no_bias_ ? BIAS_OR_INPUT_GRAD : INPUT_GRAD);
dX->ResizeLike(X);
CUDNN_ENFORCE(cudnnConvolutionBackwardData(
state->cudnn_handle(),
cudnnTypeWrapper<T_W>::kOne(),
filter_desc_,
filter.template data<T_W>() + i * group_offset_filter,
top_desc_,
dY.template data<T_DY>() + i * group_offset_Y,
bwd_data_conv_desc_,
bwd_data_algo_,
state->workspace().get(cudnn_ws_nbytes_),
cudnn_ws_nbytes_,
cudnnTypeWrapper<T_DX>::kZero(),
bottom_desc_,
dX->template mutable_data<T_DX>() + i * group_offset_X));
}
});
}
#endif
return true;
}
@@ -77,11 +77,7 @@ class CudnnConvTransposeOpBase : public ConvTransposeUnpoolBase<CUDAContext> {
const int H,
const int W,
cudnnTensorDescriptor_t* desc) const {
#if CUDNN_VERSION_MIN(7, 0, 0)
const int CC = C;
#else
const int CC = C / group_;
#endif
switch (order_) {
case StorageOrder::NCHW: {
CUDNN_ENFORCE(cudnnSetTensor4dDescriptorEx(
@@ -242,11 +238,7 @@ bool CudnnConvTransposeOp<T>::RunOnDevice() {
}
if (filter_changed) {
cudnn_filter_dims_ = filter.sizes().vec();
#if CUDNN_VERSION_MIN(7, 0, 0)
const int MM = M;
#else
const int MM = M / group_;
#endif
CUDNN_ENFORCE(cudnnSetFilter4dDescriptor(
filter_desc_,
cudnnTypeWrapper<T>::type,
@@ -292,7 +284,6 @@ bool CudnnConvTransposeOp<T>::RunOnDevice() {
"The current padding scheme leads to unequal padding on the left "
"and right, which is not supported by cudnn.");
// Set the convolution descriptor
#if CUDNN_VERSION_MIN(6, 0, 0)
CUDNN_ENFORCE(cudnnSetConvolution2dDescriptor(
conv_desc_,
pad_t(),
@@ -303,19 +294,7 @@ bool CudnnConvTransposeOp<T>::RunOnDevice() {
1,
CUDNN_CROSS_CORRELATION,
cudnnTypeWrapper<T>::type));
#else
CUDNN_ENFORCE(cudnnSetConvolution2dDescriptor(
conv_desc_,
pad_t(),
pad_l(),
stride_h(),
stride_w(),
1,
1,
CUDNN_CROSS_CORRELATION));
#endif
#if CUDNN_VERSION_MIN(7, 0, 0)
// enable TensorCore math if desired
enable_tensor_core_ &= TensorCoreAvailable();
if (enable_tensor_core_) {
@@ -324,7 +303,6 @@ bool CudnnConvTransposeOp<T>::RunOnDevice() {
}
// set cuDNN groups if appropriate
CUDNN_ENFORCE(cudnnSetConvolutionGroupCount(conv_desc_, group_));
#endif
if (force_algo_[ALGO_DGRAD] >= 0) {
bwd_data_algo_ = (cudnnConvolutionBwdDataAlgo_t)force_algo_[ALGO_DGRAD];
@@ -400,7 +378,6 @@ bool CudnnConvTransposeOp<T>::RunOnDevice() {
// Now, actually run the computation.
// Filter
#if CUDNN_VERSION_MIN(7, 0, 0)
cudnn_wrapper_.with_cudnn_state(cudnn_state_, [&](CuDNNState* state) {
CUDNN_ENFORCE(cudnnConvolutionBackwardData(
state->cudnn_handle(),
@@ -417,33 +394,6 @@ bool CudnnConvTransposeOp<T>::RunOnDevice() {
top_desc_,
Y_data));
});
#else
const int X_HxW = H * W;
const int Y_HxW = H_out * W_out;
const int group_offset_X =
order_ == StorageOrder::NCHW ? M / group_ * X_HxW : M / group_;
const int group_offset_Y =
order_ == StorageOrder::NCHW ? C / group_ * Y_HxW : C / group_;
const int group_offset_filter = filter.numel() / group_;
for (int i = 0; i < group_; ++i) {
cudnn_wrapper_.with_cudnn_state(cudnn_state_, [&](CuDNNState* state) {
CUDNN_ENFORCE(
cudnnConvolutionBackwardData(state->cudnn_handle(),
cudnnTypeWrapper<T>::kOne(),
filter_desc_,
filter_data + i * group_offset_filter,
bottom_desc_,
X_data + i * group_offset_X;
conv_desc_,
bwd_data_algo_,
state->workspace().get(cudnn_ws_nbytes_),
cudnn_ws_nbytes_,
cudnnTypeWrapper<T_DX>::kZero(),
top_desc_,
Y_data + i * group_offset_Y));
});
}
#endif
// Bias
if (InputSize() == 3) {
CUDNN_ENFORCE(cudnnAddTensor(
@@ -527,11 +477,7 @@ bool CudnnConvTransposeGradientOp<T>::RunOnDevice() {
}
if (filter_changed) {
cudnn_filter_dims_ = filter.sizes().vec();
#if CUDNN_VERSION_MIN(7, 0, 0)
const int MM = M;
#else
const int MM = M / group_;
#endif
CUDNN_ENFORCE(cudnnSetFilter4dDescriptor(
filter_desc_,
cudnnTypeWrapper<T>::type,
@@ -576,7 +522,6 @@ bool CudnnConvTransposeGradientOp<T>::RunOnDevice() {
pad_r(),
"The current padding scheme leads to unequal padding on the left "
"and right, which is not supported by cudnn.");
#if CUDNN_VERSION_MIN(6, 0, 0)
CUDNN_ENFORCE(cudnnSetConvolution2dDescriptor(
conv_desc_,
pad_t(),
@@ -587,18 +532,6 @@ bool CudnnConvTransposeGradientOp<T>::RunOnDevice() {
1,
CUDNN_CROSS_CORRELATION,
cudnnTypeWrapper<T>::type));
#else
CUDNN_ENFORCE(cudnnSetConvolution2dDescriptor(
conv_desc_,
pad_t(),
pad_l(),
stride_h(),
stride_w(),
1,
1,
CUDNN_CROSS_CORRELATION));
#endif
#if CUDNN_VERSION_MIN(7, 0, 0)
// enable TensorCore math if desired
enable_tensor_core_ &= TensorCoreAvailable();
if (enable_tensor_core_) {
@@ -607,7 +540,6 @@ bool CudnnConvTransposeGradientOp<T>::RunOnDevice() {
}
// set cuDNN groups if appropriate
CUDNN_CHECK(cudnnSetConvolutionGroupCount(conv_desc_, group_));
#endif
if (force_algo_[ALGO_WGRAD] >= 0) {
bwd_filter_algo_ =
(cudnnConvolutionBwdFilterAlgo_t)force_algo_[ALGO_WGRAD];
@@ -762,7 +694,6 @@ bool CudnnConvTransposeGradientOp<T>::RunOnDevice() {
dbias->template mutable_data<T>()));
}
#if CUDNN_VERSION_MIN(7, 0, 0)
cudnn_wrapper_.with_cudnn_state(cudnn_state_, [&](CuDNNState* state) {
CUDNN_ENFORCE(cudnnConvolutionBackwardFilter(
state->cudnn_handle(),
@@ -801,55 +732,6 @@ bool CudnnConvTransposeGradientOp<T>::RunOnDevice() {
dX->template mutable_data<T>()));
}
});
#else
const int X_HxW = H * W;
const int Y_HxW = H_out * W_out;
const int group_offset_X =
order_ == StorageOrder::NCHW ? M / group_ * X_HxW : M / group_;
const int group_offset_Y =
order_ == StorageOrder::NCHW ? C / group_ * Y_HxW : C / group_;
const int group_offset_filter = filter.numel() / group_;
for (int i = 0; i < group_; ++i) {
cudnn_wrapper_.with_cudnn_state(cudnn_state_, [&](CuDNNState* state) {
CUDNN_ENFORCE(cudnnConvolutionBackwardFilter(
state->cudnn_handle(),
cudnnTypeWrapper<T>::kOne(),
top_desc_,
dY.template data<T>() + i * group_offset_Y,
bottom_desc_,
X.template data<T>() + i * group_offset_X,
conv_desc_,
bwd_filter_algo_,
state->workspace().get(cudnn_ws_nbytes_),
cudnn_ws_nbytes_,
cudnnTypeWrapper<T>::kZero(),
filter_desc_,
dfilter->template mutable_data<T>() + i * group_offset_filter));
if (OutputSize() == 3 || (no_bias_ && (OutputSize() == 2))) {
// Compute the gradient w.r.t. the input.
auto* dX = Output(
no_bias_ ? BIAS_OR_INPUT_GRAD : INPUT_GRAD,
X.sizes(),
at::dtype<T>());
cudnn_wrapper_.with_cudnn_state(cudnn_state_, [&](CuDNNState* state) {
CUDNN_ENFORCE(cudnnConvolutionForward(
state->cudnn_handle(),
cudnnTypeWrapper<T>::kOne(),
top_desc_,
dY.template data<T>() + i * group_offset_Y,
filter_desc_,
filter.template data<T>() + i * group_offset_filter,
conv_desc_,
algo_,
state->workspace().get(cudnn_ws_nbytes_),
cudnn_ws_nbytes_,
cudnnTypeWrapper<T>::kZero(),
bottom_desc_,
dX->template mutable_data<T>() + i * group_offset_X));
});
}
}
#endif
return true;
}
@@ -5,10 +5,6 @@
namespace caffe2 {
// cudnnRestoreDropoutDescriptor is needed for correctness and
// doesn't exist prior to cuDNN v7
#if CUDNN_VERSION_MIN(7,0,0)
class CuDNNDropoutOp final : public Operator<CUDAContext> {
public:
USE_OPERATOR_FUNCTIONS(CUDAContext);
@@ -293,6 +289,4 @@ REGISTER_CUDNN_OPERATOR(Dropout, CuDNNDropoutOp);
REGISTER_CUDNN_OPERATOR(DropoutGrad, CuDNNDropoutGradientOp);
}
#endif
}; // namespace caffe2
@@ -14,7 +14,6 @@ static constexpr size_t kCONV_CUDNN_WORKSPACE_LIMIT_BYTES = 64 * 1024 * 1024;
// This does not have any performance implications, as we will always find the
// fastest algorithm; setting them to the right number of algorithms will enable
// us to best report the statistics when doing an exhaustive search, though.
#if CUDNN_VERSION_MIN(7, 0, 0)
// Note: Double each of these due to potential
// tensorcore + non-tensorcore versions
// which are treated as separate returned algos
@@ -24,11 +23,6 @@ static constexpr size_t kNUM_CUDNN_BWD_FILTER_ALGS =
2 * CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT;
static constexpr size_t kNUM_CUDNN_BWD_DATA_ALGS =
2 * CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT;
#else
static constexpr size_t kNUM_CUDNN_FWD_ALGS = 7;
static constexpr size_t kNUM_CUDNN_BWD_FILTER_ALGS = 4;
static constexpr size_t kNUM_CUDNN_BWD_DATA_ALGS = 5;
#endif
namespace {
template <typename ArrayOfcudnnConvolutionAlgoPerf_t>
@@ -434,11 +434,7 @@ struct CuDNNMaxPoolFunctor {
deterministic(op.GetSingleArgument<bool>("deterministic", false)) {}
cudnnPoolingMode_t GetPoolingMode() const {
#if CUDNN_VERSION_MIN(6, 0, 0)
return deterministic ? CUDNN_POOLING_MAX_DETERMINISTIC : CUDNN_POOLING_MAX;
#else
return CUDNN_POOLING_MAX;
#endif
}
template <typename T, StorageOrder kOrder>
@@ -98,7 +98,6 @@ void RecurrentBaseOp<T>::initialize(
// RNN setup
{
#if CUDNN_VERSION_MIN(7, 0, 0)
CUDNN_ENFORCE(cudnnSetRNNDescriptor_v6(
cudnn_wrapper_.inline_cudnn_handle(),
rnnDesc_,
@@ -110,17 +109,6 @@ void RecurrentBaseOp<T>::initialize(
rnnMode,
CUDNN_RNN_ALGO_STANDARD, // TODO: verify correctness / efficiency.
cudnnTypeWrapper<T>::type));
#else
CUDNN_ENFORCE(cudnnSetRNNDescriptor(
rnnDesc_,
hiddenSize,
numLayers,
dropoutDesc_,
rnnInput,
rnnDirection,
rnnMode,
cudnnTypeWrapper<T>::type));
#endif
}
// X setup
{
|
@ -10,7 +10,6 @@
|
||||
#include "caffe2/operators/spatial_batch_norm_op_impl.cuh"
|
||||
#include "caffe2/utils/math.h"
|
||||
|
||||
#if CUDNN_VERSION_MIN(5, 0, 0)
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
@ -63,16 +62,12 @@ class CuDNNSpatialBNOp final : public SpatialBNOp<CUDAContext> {
|
||||
CuDNNSpatialBNOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: SpatialBNOp<CUDAContext>(operator_def, ws),
|
||||
cudnn_wrapper_(&context_),
|
||||
#if CUDNN_VERSION_MIN(7, 0, 0)
|
||||
// TODO(T31829456): The new CUDNN_BATCHNORM_SPATIAL_PERSISTENT mode was
|
||||
// introduced in CuDNN 7 for performance optimization, but it results in
|
||||
// accuracy losses in convolution models such as ResNeXt-101 and
|
||||
// video R(2+1)D. We will fall back to the normal
|
||||
// CUDNN_BATCHNORM_SPATIAL for now
|
||||
mode_(CUDNN_BATCHNORM_SPATIAL) {
|
||||
#else
|
||||
mode_(CUDNN_BATCHNORM_SPATIAL) {
|
||||
#endif
|
||||
CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&data_desc_));
|
||||
CUDNN_ENFORCE(cudnnCreateTensorDescriptor(¶m_desc_));
|
||||
if (epsilon_ < CUDNN_BN_MIN_EPSILON) {
|
||||
@ -192,7 +187,6 @@ class CuDNNSpatialBNOp final : public SpatialBNOp<CUDAContext> {
|
||||
}
|
||||
const double alpha = static_cast<double>(1.0f - momentum_);
|
||||
|
||||
#if CUDNN_VERSION_MIN(8, 0, 0)
|
||||
// Currently not supporting CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION
|
||||
auto op = CUDNN_BATCHNORM_OPS_BN;
|
||||
|
||||
@ -250,26 +244,6 @@ class CuDNNSpatialBNOp final : public SpatialBNOp<CUDAContext> {
|
||||
state->workspace().get(reserve_size),
|
||||
reserve_size));
|
||||
});
|
||||
#else
|
||||
CUDNN_ENFORCE(cudnnBatchNormalizationForwardTraining(
|
||||
cudnn_wrapper_.inline_cudnn_handle(),
|
||||
mode_,
|
||||
cudnnTypeWrapper<T>::kOne(),
|
||||
cudnnTypeWrapper<T>::kZero(),
|
||||
data_desc_,
|
||||
X_data,
|
||||
data_desc_,
|
||||
Y_data,
|
||||
param_desc_,
|
||||
scale_data,
|
||||
bias_data,
|
||||
alpha,
|
||||
running_mean_data,
|
||||
running_var_data,
|
||||
epsilon_,
|
||||
saved_mean_data,
|
||||
saved_inv_std_data));
|
||||
#endif // CUDNN_VERSION_MIN(8, 0, 0)
|
||||
}
|
||||
return true;
|
||||
}
|
||||
@ -290,16 +264,12 @@ class CuDNNSpatialBNGradientOp final : public SpatialBNGradientOp<CUDAContext> {
|
||||
CuDNNSpatialBNGradientOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: SpatialBNGradientOp<CUDAContext>(operator_def, ws),
|
||||
cudnn_wrapper_(&context_),
|
||||
#if CUDNN_VERSION_MIN(7, 0, 0)
|
||||
// TODO(T31829456): The new CUDNN_BATCHNORM_SPATIAL_PERSISTENT mode was
|
||||
// introduced in CuDNN 7 for performance optimization, but it results in
|
||||
// accuracy losses in convolution models such as ResNeXt-101 and
|
||||
// video R(2+1)D. We will fall back to the normal
|
||||
// CUDNN_BATCHNORM_SPATIAL for now
|
||||
mode_(CUDNN_BATCHNORM_SPATIAL) {
|
||||
#else
|
||||
mode_(CUDNN_BATCHNORM_SPATIAL) {
|
||||
#endif
|
||||
CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&data_desc_));
|
||||
CUDNN_ENFORCE(cudnnCreateTensorDescriptor(¶m_desc_));
|
||||
if (epsilon_ < CUDNN_BN_MIN_EPSILON) {
|
||||
@ -375,7 +345,6 @@ class CuDNNSpatialBNGradientOp final : public SpatialBNGradientOp<CUDAContext> {
|
||||
data_desc_,
|
||||
param_desc_);
|
||||
}
|
||||
#if CUDNN_VERSION_MIN(8, 0, 0)
|
||||
// Currently not supporting CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION
|
||||
auto op = CUDNN_BATCHNORM_OPS_BN;
|
||||
|
||||
@ -439,28 +408,6 @@ class CuDNNSpatialBNGradientOp final : public SpatialBNGradientOp<CUDAContext> {
|
||||
state->workspace().get(reserve_size),
|
||||
reserve_size));
|
||||
});
|
||||
#else
|
||||
CUDNN_ENFORCE(cudnnBatchNormalizationBackward(
|
||||
cudnn_wrapper_.inline_cudnn_handle(),
|
||||
mode_,
|
||||
cudnnTypeWrapper<T>::kOne(),
|
||||
cudnnTypeWrapper<T>::kZero(),
|
||||
cudnnTypeWrapper<T>::kOne(),
|
||||
cudnnTypeWrapper<T>::kZero(),
|
||||
data_desc_,
|
||||
X_data,
|
||||
data_desc_,
|
||||
dY_data,
|
||||
data_desc_,
|
||||
dX_data,
|
||||
param_desc_,
|
||||
scale_data,
|
||||
dscale_data,
|
||||
dbias_data,
|
||||
epsilon_,
|
||||
saved_mean_data,
|
||||
saved_rstd_data));
|
||||
#endif // CUDNN_VERSION_MIN(8, 0, 0)
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -478,5 +425,3 @@ REGISTER_CUDNN_OPERATOR(SpatialBN, CuDNNSpatialBNOp);
|
||||
REGISTER_CUDNN_OPERATOR(SpatialBNGradient, CuDNNSpatialBNGradientOp);
|
||||
|
||||
} // namespace caffe2
|
||||
|
||||
#endif // CUDNN_VERSION_MIN(5, 0, 0)
|
||||
|
@@ -134,33 +134,6 @@ class CuDNNTransposeOp final : public Operator<CUDAContext> {
std::vector<std::int32_t> axes_;
};
#if !CUDNN_VERSION_MIN(6, 0, 0)
// CuDNN 5.1 does not have int support yet.
template <>
bool CuDNNTransposeOp::DoRunWithType<int>() {
const auto& X = Input(0);
const int ndim = X.dim();
if (axes_.empty()) {
axes_.resize(ndim);
std::iota(axes_.rbegin(), axes_.rend(), 0);
} else {
CAFFE_ENFORCE_EQ(axes_.size(), ndim);
}
std::vector<std::int64_t> X_dims = X.sizes().vec();
std::vector<std::int64_t> Y_dims(ndim);
for (int i = 0; i < ndim; ++i) {
Y_dims[i] = X_dims[axes_[i]];
}
auto* Y = Output(0, Y_dims, at::dtype<T>());
const T* X_data = X.template data<T>();
T* Y_data = Y->template mutable_data<T>();
math::Transpose<std::int64_t, T, CUDAContext>(
ndim, X_dims.data(), axes_.data(), X_data, Y_data, &context_);
return true;
}
#endif // !CUDNN_VERSION_MIN(6, 0, 0)
} // namespace
@@ -251,8 +251,8 @@ if(CAFFE2_USE_CUDNN)
"Cannot find cuDNN library. Turning the option off")
set(CAFFE2_USE_CUDNN OFF)
else()
if(CUDNN_VERSION VERSION_LESS "8.0.0")
message(FATAL_ERROR "PyTorch requires cuDNN 8 and above.")
if(CUDNN_VERSION VERSION_LESS "8.1.0")
message(FATAL_ERROR "PyTorch requires cuDNN 8.1 and above.")
endif()
endif()
@@ -35,8 +35,7 @@ from torch.testing._internal.common_utils import freeze_rng_state, run_tests, Te
IS_PPC, \
parametrize as parametrize_test, subtest, instantiate_parametrized_tests, \
skipIfTorchDynamo, gcIfJetson, set_default_dtype
from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, TEST_CUDNN_VERSION, \
PLATFORM_SUPPORTS_FLASH_ATTENTION
from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, PLATFORM_SUPPORTS_FLASH_ATTENTION
from torch.testing._internal.common_nn import NNTestCase, NewModuleTest, CriterionTest, \
module_tests, criterion_tests, loss_reference_fns, _create_basic_net, \
ctcloss_reference, new_module_tests, single_batch_reference_fn, _test_bfloat16_ops, _test_module_empty_input
@@ -4147,7 +4146,7 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
dtype = torch.double
self._test_RNN_cpu_vs_cudnn(0, dtype)
@unittest.skipIf(not (TEST_CUDNN and (TEST_CUDNN_VERSION if TEST_CUDNN_VERSION else 0) >= 5103), "needs cudnn >= 5.1")
@unittest.skipIf(not TEST_CUDNN, "needs cudnn")
def test_RNN_cpu_vs_cudnn_with_dropout(self):
# Because of dropout randomness, can only compare dropout=0 and dropout=1
self._test_RNN_cpu_vs_cudnn(1)
@@ -4206,8 +4205,7 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
warnings.simplefilter("always")
self.assertEqual(m(inp)[0].cpu(), out_expected[0])
@unittest.skipIf(not (TEST_CUDNN and (TEST_CUDNN_VERSION if TEST_CUDNN_VERSION else 0) >= 5103), "needs cudnn >= 5.1")
@unittest.skipIf(not TEST_CUDNN, "needs cudnn")
@set_default_dtype(torch.double)
def test_RNN_dropout(self):
# checking the assumption that cuDNN sticks dropout in between
@@ -4251,6 +4249,7 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
self.assertEqual(hy.data[0][0][0], 10)
self.assertEqual(hy.data[1][0][0], output_val)
@unittest.skipIf(not TEST_CUDNN, "needs cudnn")
@set_default_dtype(torch.double)
def test_error_RNN_seq_len_zero(self):
# checking error message when RNN has seq_len = 0
@@ -4279,7 +4278,7 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
# Check that backward does not cause a hard error
outs[0].sum().backward()
@unittest.skipIf(not (TEST_CUDNN and (TEST_CUDNN_VERSION if TEST_CUDNN_VERSION else 0) >= 5103), "needs cudnn >= 5.1")
@unittest.skipIf(not TEST_CUDNN, "needs cudnn")
def test_RNN_dropout_state(self):
for p in (0, 0.1234):
for train in (True, False):
@@ -4319,7 +4318,7 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
self.assertNotEqual(hy1, hy2)
self.assertNotEqual(hy1, hy3)
@unittest.skipIf(not (TEST_CUDNN and (TEST_CUDNN_VERSION if TEST_CUDNN_VERSION else 0) >= 5103), "needs cudnn >= 5.1")
@unittest.skipIf(not TEST_CUDNN, "needs cudnn")
@set_default_dtype(torch.double)
def test_RNN_change_dropout(self):
for train, cuda in product((True, False), repeat=2):