[cuDNN][SDPA][submodule] Roll-back cuDNN frontend upgrade, update Meta registration (#163104)

For https://github.com/pytorch/torchtitan/issues/1713

Also note that we will need to roll back the cuDNN frontend upgrade in 2.9, as it currently introduces a segmentation fault by assuming tensors have their strides and sizes populated at graph creation time: 1a7b4b78db/include/cudnn_frontend/node/sdpa_support_surface.h (L447)
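
For context, the affected path is cuDNN-backed SDPA under torch.compile. A minimal, hedged sketch of how to pin that backend follows; it mirrors the regression test added below, and the shapes/dtype are illustrative rather than requirements.

```python
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

# Illustrative shapes: (batch, heads, seq_len, head_dim), fp16 on CUDA.
q = torch.randn(2, 8, 1024, 128, dtype=torch.half, device="cuda", requires_grad=True)

@torch.compile
def cudnn_sdpa(q):
    # Restrict SDPA dispatch to the cuDNN attention backend for this call.
    with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
        return torch.nn.functional.scaled_dot_product_attention(q, q, q)

out = cudnn_sdpa(q)
out.backward(torch.randn_like(out))
```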

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163104
Approved by: https://github.com/drisspg
Authored by Eddie Yan on 2025-09-17 15:48:54 +00:00; committed by PyTorch MergeBot
parent 16475a829f
commit 9b7a8c4d05
4 changed files with 31 additions and 4 deletions


@@ -482,7 +482,9 @@ auto build_graph(
   auto scaled_dot_product_flash_attention_options =
       fe::graph::SDPA_attributes()
           .set_name("CUDNN_SDPA")
-          .set_generate_stats(return_softmaxstats)
+          .set_is_inference(return_softmaxstats == false)
+          // TODO(eqy): switch to this API once cuDNN FE is upgraded
+          // .set_generate_stats(return_softmaxstats)
           .set_causal_mask(is_causal)
           .set_attn_scale(attn_scale);
   if (use_ragged_in_dense(q, k, v, o, attn_bias.has_value())) {
@@ -702,7 +704,9 @@ auto build_graph_nestedtensor(
   auto scaled_dot_product_flash_attention_options =
       fe::graph::SDPA_attributes()
           .set_name("CUDNN_SDPA_NESTEDTENSOR")
-          .set_generate_stats(return_softmaxstats)
+          .set_is_inference(return_softmaxstats == false)
+          // TODO(eqy): switch to this API once cuDNN FE is upgraded
+          // .set_generate_stats(return_softmaxstats)
           .set_causal_mask(is_causal)
           .set_attn_scale(attn_scale)
           .set_seq_len_q(SEQ_LEN_Q_)


@@ -2823,6 +2823,29 @@ class TestSDPACudaOnly(NNTestCase):
         for permute_order in permute_orders:
             test_attention(SDPBackend.CUDNN_ATTENTION, list(permute_order) + [3])
 
+    @skipIfRocm
+    @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system")
+    def test_cudnn_attention_compiles(self):
+        q = torch.randn(2, 8, 1024, 128, dtype=torch.half, device='cuda', requires_grad=True)
+        grad = torch.randn_like(q)
+
+        @torch.compile()
+        def func():
+            with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.CUDNN_ATTENTION):
+                out = torch.nn.functional.scaled_dot_product_attention(q, q, q)
+            out.backward(grad)
+            return out
+
+        out = func()
+        q_cpu = q.float().cpu().detach().clone()
+        q_cpu.requires_grad = True
+        grad_cpu = grad.cpu().float()
+        out_cpu = torch.nn.functional.scaled_dot_product_attention(q_cpu, q_cpu, q_cpu)
+        out_cpu.backward(grad_cpu)
+
+        self.assertEqual(out, out_cpu.cuda().half(), atol=1e-3, rtol=1e-3)
+        self.assertEqual(q.grad, q_cpu.grad.cuda().half(), atol=7e-3, rtol=5e-3)
+
     @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
     @parametrize("mask_dim", [1, 2, 3, 4])
     def test_mem_efficient_attention_mask_variants(self, device, mask_dim: list[int]):


@@ -5710,7 +5710,7 @@ def meta__scaled_dot_product_cudnn_attention(
     res = alloc_with_matching_layout(query, res_shape)
     logsum_exp = torch.empty(
-        (B, H, S_Q),
+        (B, H, S_Q, 1),
         dtype=torch.float,
         device=query.device,
     )
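
For reference, a condensed sketch of what the corrected Meta registration allocates for the logsumexp tensor. The helper name and surrounding function body below are hypothetical; only the (B, H, S_Q, 1) shape comes from the hunk above, presumably so fake-tensor tracing under torch.compile matches the stats tensor the eager cuDNN kernel actually returns.

```python
import torch

# Hypothetical, simplified fragment modeled on the hunk above (not the full
# upstream meta function): the logsumexp buffer now carries a trailing
# singleton dimension, (B, H, S_Q, 1).
def meta_cudnn_sdpa_logsumexp(query: torch.Tensor) -> torch.Tensor:
    B, H, S_Q, _ = query.shape  # assumes a dense (B, H, S_Q, D) query layout
    return torch.empty(
        (B, H, S_Q, 1),
        dtype=torch.float,
        device=query.device,
    )
```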