[cuDNN][SDPA] cuDNN SDPA refactor/cleanup, nested tensor backward, test priority bump for sm90, sm100 (#149282)

cleanup tuple/tensor boilerplate in cuDNN SDPA, preparation for nested/ragged tensor backward
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149282
Approved by: https://github.com/drisspg
This commit is contained in:
eqy
2025-05-14 01:39:24 +00:00
committed by PyTorch MergeBot
parent 8521a690f7
commit 9386701b51
11 changed files with 999 additions and 425 deletions

View File

@ -6746,11 +6746,10 @@ torch.cuda.synchronize()
and check_cudnn
and (dtype == torch.float16 or dtype == torch.bfloat16)
):
with self.assertRaisesRegex(RuntimeError, "cuDNN SDPA Nested Tensor"):
with torch.nn.attention.sdpa_kernel(
torch.nn.attention.SDPBackend.CUDNN_ATTENTION
):
check_forward_backward()
with torch.nn.attention.sdpa_kernel(
torch.nn.attention.SDPBackend.CUDNN_ATTENTION
):
check_forward_backward()
@skipIfTorchDynamo("SDPA test compiles internally")
@unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")