Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[cuDNN][RNN] cuDNN RNN supports BFloat16 inputs since 9.13 (#164411)
seems to work

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164411
Approved by: https://github.com/Skylion007
@@ -326,6 +326,23 @@ bool CUDAHooks::supportsBFloat16ConvolutionWithCuDNNv8() const {
 #endif
 }
 
+bool CUDAHooks::supportsBFloat16RNNWithCuDNN() const {
+#if AT_CUDNN_ENABLED() && (CUDNN_VERSION >= 91300)
+  if (!hasCUDA()) {
+    return false;
+  }
+  cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
+  // BF16 RNNs require compute capability 8.0 (Ampere) or newer
+  if (prop->major >= 8) {
+    return true;
+  } else {
+    return false;
+  }
+#else
+  return false;
+#endif
+}
+
 long CUDAHooks::versionCuDNN() const {
 #if AT_CUDNN_ENABLED()
   return CUDNN_VERSION;
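The new hook reduces to two runtime conditions: a cuDNN runtime of at least 9.13 (version code 91300) and a GPU with compute capability 8.0 or newer. A minimal Python sketch of the same probe, using only public torch APIs (illustrative, not code from this PR):

import torch

def bf16_rnn_supported() -> bool:
    # Mirrors supportsBFloat16RNNWithCuDNN(): CUDA present, cuDNN >= 9.13,
    # and an Ampere-or-newer device (SM major version >= 8).
    if not torch.cuda.is_available():
        return False
    cudnn_version = torch.backends.cudnn.version()  # e.g. 91300 for 9.13
    if cudnn_version is None or cudnn_version < 91300:
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major >= 8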
@@ -45,6 +45,7 @@ struct CUDAHooks : public at::CUDAHooksInterface {
   bool supportsDilatedConvolutionWithCuDNN() const override;
   bool supportsDepthwiseConvolutionWithCuDNN() const override;
   bool supportsBFloat16ConvolutionWithCuDNNv8() const override;
+  bool supportsBFloat16RNNWithCuDNN() const override;
   bool hasCUDART() const override;
   long versionCUDART() const override;
   long versionCuDNN() const override;
@@ -166,6 +166,10 @@ struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface {
     return false;
   }
 
+  virtual bool supportsBFloat16RNNWithCuDNN() const {
+    return false;
+  }
+
   virtual long versionCuDNN() const {
     TORCH_CHECK(false, "Cannot query cuDNN version without ATen_cuda library. ", CUDA_HELP);
   }
@@ -93,6 +93,12 @@ inline bool cond_cudnn_grid_sampler(
     const TensorBase& input,
     const TensorBase& grid
 ) {
+  auto st = input.scalar_type();
+  if (!(st == kDouble || st == kFloat || st == kHalf))
+    return false;
+  st = grid.scalar_type();
+  if (!(st == kDouble || st == kFloat || st == kHalf))
+    return false;
   return (
       at::native::cudnn_is_acceptable(input) &&
       at::native::cudnn_is_acceptable(grid) &&
@@ -108,6 +108,13 @@ bool use_mkldnn(const Tensor& input, TensorList params, TensorList hx) {
   return false;
 }
 
+bool use_cudnn(const Tensor& t) {
+  bool acceptable = at::cudnn_is_acceptable(t);
+  auto st = t.scalar_type();
+  bool bfloat16_cond = st == kBFloat16 && at::detail::getCUDAHooks().supportsBFloat16RNNWithCuDNN();
+  return acceptable && (bfloat16_cond || st == kDouble || st == kFloat || st == kHalf);
+}
+
 template<typename T>
 using pair_of = std::pair<T, T>;
 
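use_cudnn() extends the old double/float/half whitelist: kBFloat16 is accepted only when the hook above reports RNN BF16 support. A rough Python paraphrase of the predicate (illustrative; bf16_rnn_ok stands in for the C++ hook, which has no direct Python binding):

import torch

def use_cudnn_rnn(t: torch.Tensor, bf16_rnn_ok: bool) -> bool:
    # at::cudnn_is_acceptable(t): roughly, a non-empty CUDA tensor on a build
    # with cuDNN compiled in and enabled.
    acceptable = t.is_cuda and t.numel() > 0 and torch.backends.cudnn.is_available()
    bfloat16_cond = t.dtype == torch.bfloat16 and bf16_rnn_ok
    return acceptable and (
        bfloat16_cond or t.dtype in (torch.float64, torch.float32, torch.float16)
    )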
@@ -1200,7 +1207,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _thnn_fused_lstm_cell_backwar
       bool train, \
       bool bidirectional, \
       bool batch_first) { \
-    if (at::cudnn_is_acceptable(_input)) { \
+    if (use_cudnn(_input)) { \
       Tensor output, hy; \
       NAME##_cudnn_stub( \
           _input.device().type(), \
@@ -1262,7 +1269,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _thnn_fused_lstm_cell_backwar
       double dropout_p, \
       bool train, \
       bool bidirectional) { \
-    if (at::cudnn_is_acceptable(data)) { \
+    if (use_cudnn(data)) { \
       Tensor output, hy; \
       NAME##_packed_cudnn_stub( \
           data.device().type(), \
@@ -1430,7 +1437,7 @@ std::tuple<Tensor, Tensor, Tensor> lstm(
     TensorList _params, bool has_biases,
     int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first) {
   TORCH_CHECK(hx.size() == 2, "lstm expects two hidden states");
-  if (at::cudnn_is_acceptable(_input)) {
+  if (use_cudnn(_input)) {
     Tensor output, hy, cy;
     lstm_cudnn_stub(_input.device().type(), output, hy, cy, _input, hx, _params, has_biases,
         num_layers, dropout_p, train, bidirectional, batch_first);
@@ -1491,7 +1498,7 @@ std::tuple<Tensor, Tensor, Tensor> lstm(
     TensorList _params, bool has_biases,
     int64_t num_layers, double dropout_p, bool train, bool bidirectional) {
   TORCH_CHECK(hx.size() == 2, "lstm expects two hidden states");
-  if (at::cudnn_is_acceptable(data)) {
+  if (use_cudnn(data)) {
     Tensor output, hy, cy;
     lstm_packed_cudnn_stub(data.device().type(), output, hy, cy, data, batch_sizes, hx,
         _params, has_biases, num_layers, dropout_p, train, bidirectional);
@@ -91,9 +91,6 @@ bool cudnn_is_acceptable(const TensorBase& self) {
     return false;
   if (!self.is_cuda())
     return false;
-  auto st = self.scalar_type();
-  if (!(st == kDouble || st == kFloat || st == kHalf))
-    return false;
   if (!detail::getCUDAHooks().compiledWithCuDNN())
     return false;
   // cuDNN functions like grid_sampler returns CUDNN_STATUS_BAD_PARAM on empty
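With the dtype whitelist removed from the generic cudnn_is_acceptable(), dtype policy now lives at each call site (cond_cudnn_grid_sampler and use_cudnn above), so BF16 can be allowed for RNNs without loosening other ops. A quick way to poke at the remaining generic check from Python, assuming a CUDA build with cuDNN (torch.backends.cudnn.is_acceptable is the pre-existing Python-side analogue, not code from this PR):

import torch

if torch.cuda.is_available():
    t = torch.randn(2, 2, device='cuda', dtype=torch.bfloat16)
    # Device/build acceptability only; per-op dtype gating happens in the callers.
    print(torch.backends.cudnn.is_acceptable(t))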
@@ -1222,7 +1222,7 @@ cudnnRNNAlgo_t get_algo(
 }
 
 cudnnDataType_t promote_rnn_math_type(cudnnDataType_t dtype) {
-  if (dtype == CUDNN_DATA_HALF) {
+  if (dtype == CUDNN_DATA_HALF || dtype == CUDNN_DATA_BFLOAT16) {
     return CUDNN_DATA_FLOAT;
   }
   return dtype;
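promote_rnn_math_type() picks the accumulation type: FP16 and now BF16 inputs run the RNN math in FP32, while other dtypes keep their own math type. The mapping, restated in Python for clarity (illustrative only):

import torch

def promote_rnn_math_type(dtype: torch.dtype) -> torch.dtype:
    # Reduced-precision inputs (half, bfloat16) accumulate in float32.
    if dtype in (torch.float16, torch.bfloat16):
        return torch.float32
    return dtype

assert promote_rnn_math_type(torch.bfloat16) is torch.float32
assert promote_rnn_math_type(torch.float64) is torch.float64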
@@ -11224,6 +11224,16 @@ class TestNNDeviceType(NNTestCase):
         out_ref = m(inp_ref)
         self.assertEqual(out_ref, out)
 
+    @onlyCUDA
+    @dtypes(torch.half, torch.bfloat16)
+    def test_cudnn_rnn(self, dtype):
+        rnn = nn.RNN(10, 20, num_layers=2, device='cuda', dtype=dtype)
+        input = torch.randn(5, 4, 10, device='cuda', dtype=dtype)
+        hx = torch.randn(2, 4, 20, device='cuda', dtype=dtype)
+        output = rnn(input, hx)
+        output_ref = rnn.cpu()(input.cpu(), hx.cpu())
+        self.assertEqual(tuple([i.cuda() for i in output_ref]), output, atol=5e-3, rtol=1e-3)
+
     @onlyCUDA
     @gcIfJetson
     def test_upsamplingNearest3d_launch_config(self, device):
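End to end, the change means nn.RNN (and the other cuDNN RNN modes) in bfloat16 on CUDA now takes the cuDNN path on qualifying setups (cuDNN >= 9.13, compute capability >= 8.0). A minimal usage sketch along the lines of the new test, with the same loose tolerance to match the reduced precision (illustrative; assumes a CUDA machine):

import torch
import torch.nn as nn

if torch.cuda.is_available():
    rnn = nn.RNN(10, 20, num_layers=2, device='cuda', dtype=torch.bfloat16)
    x = torch.randn(5, 4, 10, device='cuda', dtype=torch.bfloat16)
    h0 = torch.randn(2, 4, 20, device='cuda', dtype=torch.bfloat16)
    out, hn = rnn(x, h0)  # dispatches to cuDNN when supported
    # Compare against the CPU implementation at the same dtype.
    out_ref, hn_ref = rnn.cpu()(x.cpu(), h0.cpu())
    torch.testing.assert_close(out.cpu(), out_ref, atol=5e-3, rtol=1e-3)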