[cuDNN][SDPA][submodule] Roll-back cuDNN frontend upgrade, update Meta registration (#163104)
For https://github.com/pytorch/torchtitan/issues/1713
Also note that we will need to roll back the cuDNN frontend upgrade in 2.9 as well, since it currently introduces a segmentation fault by assuming tensors have their strides and sizes populated at graph-creation time (see 1a7b4b78db/include/cudnn_frontend/node/sdpa_support_surface.h, L447).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163104
Approved by: https://github.com/drisspg
Committed by: PyTorch MergeBot
Parent: 16475a829f
Commit: 9b7a8c4d05
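For context, the user-facing path that this commit's new test exercises is forcing the cuDNN SDPA backend under torch.compile. Below is a minimal standalone sketch of that pattern; it is not part of the commit, the shapes are illustrative only, and it assumes a CUDA build with cuDNN attention support:

    import torch
    import torch.nn.functional as F
    from torch.nn.attention import SDPBackend, sdpa_kernel

    # Illustrative shapes only: (batch, heads, seq_len, head_dim).
    q = torch.randn(2, 8, 1024, 128, dtype=torch.half, device="cuda", requires_grad=True)
    grad = torch.randn_like(q)

    @torch.compile()
    def run():
        # Restrict SDPA dispatch to the cuDNN backend for this region.
        with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
            out = F.scaled_dot_product_attention(q, q, q)
        out.backward(grad)
        return out

    out = run()
    print(out.shape, q.grad.shape)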
@@ -482,7 +482,9 @@ auto build_graph(
   auto scaled_dot_product_flash_attention_options =
       fe::graph::SDPA_attributes()
           .set_name("CUDNN_SDPA")
-          .set_generate_stats(return_softmaxstats)
+          .set_is_inference(return_softmaxstats == false)
+          // TODO(eqy): switch to this API once cuDNN FE is upgraded
+          // .set_generate_stats(return_softmaxstats)
           .set_causal_mask(is_causal)
           .set_attn_scale(attn_scale);
   if (use_ragged_in_dense(q, k, v, o, attn_bias.has_value())) {
@@ -702,7 +704,9 @@ auto build_graph_nestedtensor(
   auto scaled_dot_product_flash_attention_options =
       fe::graph::SDPA_attributes()
           .set_name("CUDNN_SDPA_NESTEDTENSOR")
-          .set_generate_stats(return_softmaxstats)
+          .set_is_inference(return_softmaxstats == false)
+          // TODO(eqy): switch to this API once cuDNN FE is upgraded
+          // .set_generate_stats(return_softmaxstats)
           .set_causal_mask(is_causal)
           .set_attn_scale(attn_scale)
           .set_seq_len_q(SEQ_LEN_Q_)
@@ -2823,6 +2823,29 @@ class TestSDPACudaOnly(NNTestCase):
         for permute_order in permute_orders:
             test_attention(SDPBackend.CUDNN_ATTENTION, list(permute_order) + [3])
 
+    @skipIfRocm
+    @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system")
+    def test_cudnn_attention_compiles(self):
+        q = torch.randn(2, 8, 1024, 128, dtype=torch.half, device='cuda', requires_grad=True)
+        grad = torch.randn_like(q)
+
+        @torch.compile()
+        def func():
+            with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.CUDNN_ATTENTION):
+                out = torch.nn.functional.scaled_dot_product_attention(q, q, q)
+            out.backward(grad)
+            return out
+
+        out = func()
+
+        q_cpu = q.float().cpu().detach().clone()
+        q_cpu.requires_grad = True
+        grad_cpu = grad.cpu().float()
+        out_cpu = torch.nn.functional.scaled_dot_product_attention(q_cpu, q_cpu, q_cpu)
+        out_cpu.backward(grad_cpu)
+        self.assertEqual(out, out_cpu.cuda().half(), atol=1e-3, rtol=1e-3)
+        self.assertEqual(q.grad, q_cpu.grad.cuda().half(), atol=7e-3, rtol=5e-3)
+
     @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
     @parametrize("mask_dim", [1, 2, 3, 4])
     def test_mem_efficient_attention_mask_variants(self, device, mask_dim: list[int]):
Submodule third_party/cudnn_frontend updated: 1a7b4b78db...f937055efc
@@ -5710,7 +5710,7 @@ def meta__scaled_dot_product_cudnn_attention(
     res = alloc_with_matching_layout(query, res_shape)
 
     logsum_exp = torch.empty(
-        (B, H, S_Q),
+        (B, H, S_Q, 1),
         dtype=torch.float,
         device=query.device,
     )
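The Meta-registration change above presumably brings the fake-tensor shape of the cuDNN logsumexp output in line with what the rolled-back kernel produces: a trailing singleton dimension, (B, H, S_Q, 1) rather than (B, H, S_Q). A rough sketch of the shape relationship, with hypothetical sizes chosen only for illustration:

    import torch

    # Hypothetical sizes: query laid out as (B, H, S_Q, D).
    B, H, S_Q, D = 2, 8, 1024, 128
    query = torch.empty(B, H, S_Q, D, dtype=torch.half, device="meta")

    # With this change, the meta registration allocates the logsumexp buffer as
    # (B, H, S_Q, 1) instead of (B, H, S_Q).
    logsum_exp = torch.empty((B, H, S_Q, 1), dtype=torch.float, device=query.device)
    assert logsum_exp.shape == (B, H, S_Q, 1)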