[cuDNN][RNN] cuDNN RNN supports BFloat16 inputs since 9.13 (#164411)

Enable the cuDNN path for BFloat16 RNN inputs when cuDNN is 9.13 or newer and the device is SM 8.0 (Ampere) or newer; the scalar-type check moves out of cudnn_is_acceptable and into the RNN and grid-sampler call sites.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164411
Approved by: https://github.com/Skylion007
eqy
2025-10-08 15:26:50 +00:00
committed by PyTorch MergeBot
parent 90c0825e2d
commit 0d39ecb2ce
8 changed files with 50 additions and 8 deletions
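
A minimal usage sketch of what this change enables, assuming a CUDA build of PyTorch with cuDNN 9.13+ and an SM 8.0 (Ampere) or newer GPU; the module and shapes are illustrative, not taken from the PR:

import torch
import torch.nn as nn

# Illustrative only: with cuDNN >= 9.13 and an SM 8.0+ GPU, BFloat16
# RNN/LSTM/GRU inputs can now take the cuDNN path instead of the
# non-cuDNN fallback.
lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=2,
               device='cuda', dtype=torch.bfloat16)
x = torch.randn(5, 4, 10, device='cuda', dtype=torch.bfloat16)
output, (h_n, c_n) = lstm(x)
print(output.dtype)  # torch.bfloat16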


@@ -326,6 +326,23 @@ bool CUDAHooks::supportsBFloat16ConvolutionWithCuDNNv8() const {
#endif
}
bool CUDAHooks::supportsBFloat16RNNWithCuDNN() const {
#if AT_CUDNN_ENABLED() && (CUDNN_VERSION >= 91300)
if (!hasCUDA()) {
return false;
}
cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
// Check for SM 8.0 (Ampere) or newer
if (prop->major >= 8) {
return true;
} else {
return false;
}
#else
return false;
#endif
}
long CUDAHooks::versionCuDNN() const {
#if AT_CUDNN_ENABLED()
return CUDNN_VERSION;
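
For reference, a hypothetical Python-side check that mirrors the gating in the new hook above; the helper name is made up here and is not part of this PR or of the torch API:

import torch

def bf16_rnn_can_use_cudnn() -> bool:
    # Hypothetical helper: mirrors CUDAHooks::supportsBFloat16RNNWithCuDNN().
    # cuDNN 9.13 reports its version as 91300, and the hook also requires an
    # SM 8.0 (Ampere) or newer device.
    if not (torch.cuda.is_available() and torch.backends.cudnn.is_available()):
        return False
    major, _minor = torch.cuda.get_device_capability()
    return torch.backends.cudnn.version() >= 91300 and major >= 8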


@@ -45,6 +45,7 @@ struct CUDAHooks : public at::CUDAHooksInterface {
bool supportsDilatedConvolutionWithCuDNN() const override;
bool supportsDepthwiseConvolutionWithCuDNN() const override;
bool supportsBFloat16ConvolutionWithCuDNNv8() const override;
bool supportsBFloat16RNNWithCuDNN() const override;
bool hasCUDART() const override;
long versionCUDART() const override;
long versionCuDNN() const override;


@@ -166,6 +166,10 @@ struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface {
return false;
}
virtual bool supportsBFloat16RNNWithCuDNN() const {
return false;
}
virtual long versionCuDNN() const {
TORCH_CHECK(false, "Cannot query cuDNN version without ATen_cuda library. ", CUDA_HELP);
}


@@ -93,6 +93,12 @@ inline bool cond_cudnn_grid_sampler(
const TensorBase& input,
const TensorBase& grid
) {
auto st = input.scalar_type();
if (!(st == kDouble || st == kFloat || st == kHalf))
return false;
st = grid.scalar_type();
if (!(st == kDouble || st == kFloat || st == kHalf))
return false;
return (
at::native::cudnn_is_acceptable(input) &&
at::native::cudnn_is_acceptable(grid) &&


@@ -108,6 +108,13 @@ bool use_mkldnn(const Tensor& input, TensorList params, TensorList hx) {
return false;
}
bool use_cudnn(const Tensor& t) {
bool acceptable = at::cudnn_is_acceptable(t);
auto st = t.scalar_type();
bool bfloat16_cond = st == kBFloat16 && at::detail::getCUDAHooks().supportsBFloat16RNNWithCuDNN();
return acceptable && (bfloat16_cond || st == kDouble || st == kFloat || st == kHalf);
}
template<typename T>
using pair_of = std::pair<T, T>;
@@ -1200,7 +1207,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _thnn_fused_lstm_cell_backwar
bool train, \
bool bidirectional, \
bool batch_first) { \
if (at::cudnn_is_acceptable(_input)) { \
if (use_cudnn(_input)) { \
Tensor output, hy; \
NAME##_cudnn_stub( \
_input.device().type(), \
@@ -1262,7 +1269,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _thnn_fused_lstm_cell_backwar
double dropout_p, \
bool train, \
bool bidirectional) { \
if (at::cudnn_is_acceptable(data)) { \
if (use_cudnn(data)) { \
Tensor output, hy; \
NAME##_packed_cudnn_stub( \
data.device().type(), \
@@ -1430,7 +1437,7 @@ std::tuple<Tensor, Tensor, Tensor> lstm(
TensorList _params, bool has_biases,
int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first) {
TORCH_CHECK(hx.size() == 2, "lstm expects two hidden states");
if (at::cudnn_is_acceptable(_input)) {
if (use_cudnn(_input)) {
Tensor output, hy, cy;
lstm_cudnn_stub(_input.device().type(), output, hy, cy, _input, hx, _params, has_biases,
num_layers, dropout_p, train, bidirectional, batch_first);
@@ -1491,7 +1498,7 @@ std::tuple<Tensor, Tensor, Tensor> lstm(
TensorList _params, bool has_biases,
int64_t num_layers, double dropout_p, bool train, bool bidirectional) {
TORCH_CHECK(hx.size() == 2, "lstm expects two hidden states");
if (at::cudnn_is_acceptable(data)) {
if (use_cudnn(data)) {
Tensor output, hy, cy;
lstm_packed_cudnn_stub(data.device().type(), output, hy, cy, data, batch_sizes, hx,
_params, has_biases, num_layers, dropout_p, train, bidirectional);
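
The packed-sequence overloads are switched to the same use_cudnn gate, so a BFloat16 packed LSTM can also hit the cuDNN path; a small illustrative sketch (same cuDNN/GPU assumptions as above):

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

lstm = nn.LSTM(8, 16, batch_first=True, device='cuda', dtype=torch.bfloat16)
x = torch.randn(3, 6, 8, device='cuda', dtype=torch.bfloat16)
lengths = torch.tensor([6, 4, 2])  # sequence lengths stay on CPU for packing
packed = pack_padded_sequence(x, lengths, batch_first=True)
packed_out, (h_n, c_n) = lstm(packed)
out, out_lengths = pad_packed_sequence(packed_out, batch_first=True)
print(out.shape, out.dtype)  # torch.Size([3, 6, 16]) torch.bfloat16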


@@ -91,9 +91,6 @@ bool cudnn_is_acceptable(const TensorBase& self) {
return false;
if (!self.is_cuda())
return false;
auto st = self.scalar_type();
if (!(st == kDouble || st == kFloat || st == kHalf))
return false;
if (!detail::getCUDAHooks().compiledWithCuDNN())
return false;
// cuDNN functions like grid_sampler returns CUDNN_STATUS_BAD_PARAM on empty


@@ -1222,7 +1222,7 @@ cudnnRNNAlgo_t get_algo(
}
cudnnDataType_t promote_rnn_math_type(cudnnDataType_t dtype) {
if (dtype == CUDNN_DATA_HALF) {
if (dtype == CUDNN_DATA_HALF || dtype == CUDNN_DATA_BFLOAT16) {
return CUDNN_DATA_FLOAT;
}
return dtype;


@@ -11224,6 +11224,16 @@ class TestNNDeviceType(NNTestCase):
        out_ref = m(inp_ref)
        self.assertEqual(out_ref, out)

    @onlyCUDA
    @dtypes(torch.half, torch.bfloat16)
    def test_cudnn_rnn(self, dtype):
        rnn = nn.RNN(10, 20, num_layers=2, device='cuda', dtype=dtype)
        input = torch.randn(5, 4, 10, device='cuda', dtype=dtype)
        hx = torch.randn(2, 4, 20, device='cuda', dtype=dtype)
        output = rnn(input, hx)
        output_ref = rnn.cpu()(input.cpu(), hx.cpu())
        self.assertEqual(tuple([i.cuda() for i in output_ref]), output, atol=5e-3, rtol=1e-3)

    @onlyCUDA
    @gcIfJetson
    def test_upsamplingNearest3d_launch_config(self, device):