diff --git a/aten/src/ATen/native/cudnn/MHA.cpp b/aten/src/ATen/native/cudnn/MHA.cpp
index c2f7ce2ac2d5..1658ce34ca6c 100644
--- a/aten/src/ATen/native/cudnn/MHA.cpp
+++ b/aten/src/ATen/native/cudnn/MHA.cpp
@@ -482,7 +482,9 @@ auto build_graph(
   auto scaled_dot_product_flash_attention_options =
       fe::graph::SDPA_attributes()
           .set_name("CUDNN_SDPA")
-          .set_generate_stats(return_softmaxstats)
+          .set_is_inference(return_softmaxstats == false)
+          // TODO(eqy): switch to this API once cuDNN FE is upgraded
+          // .set_generate_stats(return_softmaxstats)
           .set_causal_mask(is_causal)
           .set_attn_scale(attn_scale);
   if (use_ragged_in_dense(q, k, v, o, attn_bias.has_value())) {
@@ -702,7 +704,9 @@ auto build_graph_nestedtensor(
   auto scaled_dot_product_flash_attention_options =
       fe::graph::SDPA_attributes()
           .set_name("CUDNN_SDPA_NESTEDTENSOR")
-          .set_generate_stats(return_softmaxstats)
+          .set_is_inference(return_softmaxstats == false)
+          // TODO(eqy): switch to this API once cuDNN FE is upgraded
+          // .set_generate_stats(return_softmaxstats)
           .set_causal_mask(is_causal)
           .set_attn_scale(attn_scale)
           .set_seq_len_q(SEQ_LEN_Q_)
diff --git a/test/test_transformers.py b/test/test_transformers.py
index b90b1ed86ef2..c58fe05d37be 100644
--- a/test/test_transformers.py
+++ b/test/test_transformers.py
@@ -2823,6 +2823,29 @@ class TestSDPACudaOnly(NNTestCase):
             for permute_order in permute_orders:
                 test_attention(SDPBackend.CUDNN_ATTENTION, list(permute_order) + [3])
 
+    @skipIfRocm
+    @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system")
+    def test_cudnn_attention_compiles(self):
+        q = torch.randn(2, 8, 1024, 128, dtype=torch.half, device='cuda', requires_grad=True)
+        grad = torch.randn_like(q)
+
+        @torch.compile()
+        def func():
+            with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.CUDNN_ATTENTION):
+                out = torch.nn.functional.scaled_dot_product_attention(q, q, q)
+                out.backward(grad)
+            return out
+
+        out = func()
+
+        q_cpu = q.float().cpu().detach().clone()
+        q_cpu.requires_grad = True
+        grad_cpu = grad.cpu().float()
+        out_cpu = torch.nn.functional.scaled_dot_product_attention(q_cpu, q_cpu, q_cpu)
+        out_cpu.backward(grad_cpu)
+        self.assertEqual(out, out_cpu.cuda().half(), atol=1e-3, rtol=1e-3)
+        self.assertEqual(q.grad, q_cpu.grad.cuda().half(), atol=7e-3, rtol=5e-3)
+
     @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
     @parametrize("mask_dim", [1, 2, 3, 4])
     def test_mem_efficient_attention_mask_variants(self, device, mask_dim: list[int]):
diff --git a/third_party/cudnn_frontend b/third_party/cudnn_frontend
index 1a7b4b78db44..f937055efc6d 160000
--- a/third_party/cudnn_frontend
+++ b/third_party/cudnn_frontend
@@ -1 +1 @@
-Subproject commit 1a7b4b78db44712fb9707d21cd2e3179f1fd88b8
+Subproject commit f937055efc6d414d11f4c6577e3977fe74f35fb6
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index 9f2c9d1c44d6..78b3964d7171 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -5710,7 +5710,7 @@ def meta__scaled_dot_product_cudnn_attention(
     res = alloc_with_matching_layout(query, res_shape)
 
     logsum_exp = torch.empty(
-        (B, H, S_Q),
+        (B, H, S_Q, 1),
         dtype=torch.float,
         device=query.device,
     )
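
For anyone who wants to exercise the compiled cuDNN SDPA path outside the test suite, below is a standalone sketch of the flow that the new `test_cudnn_attention_compiles` covers. It assumes a CUDA build where cuDNN attention is available (roughly what `PLATFORM_SUPPORTS_CUDNN_ATTENTION` guards in the test); the shapes and tolerances are illustrative, not normative.

```python
# Standalone sketch of the path covered by test_cudnn_attention_compiles.
# Assumes a CUDA build with cuDNN attention available; shapes/tolerances are illustrative.
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(2, 8, 1024, 128, dtype=torch.half, device="cuda", requires_grad=True)
grad = torch.randn_like(q)

@torch.compile()
def func():
    # Force the cuDNN backend so both forward and backward go through it.
    with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
        out = F.scaled_dot_product_attention(q, q, q)
        out.backward(grad)
    return out

out = func()

# Compare against a float32 CPU reference.
q_cpu = q.detach().float().cpu().requires_grad_()
out_cpu = F.scaled_dot_product_attention(q_cpu, q_cpu, q_cpu)
out_cpu.backward(grad.float().cpu())
torch.testing.assert_close(out, out_cpu.cuda().half(), atol=1e-3, rtol=1e-3)
torch.testing.assert_close(q.grad, q_cpu.grad.cuda().half(), atol=7e-3, rtol=5e-3)
```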