mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 21:49:24 +08:00
[cuDNN][SDPA] cuDNN SDPA refactor/cleanup, nested tensor backward, test priority bump for sm90
, sm100
(#149282)
cleanup tuple/tensor boilerplate in cuDNN SDPA, preparation for nested/ragged tensor backward Pull Request resolved: https://github.com/pytorch/pytorch/pull/149282 Approved by: https://github.com/drisspg
This commit is contained in:
@ -6746,11 +6746,10 @@ torch.cuda.synchronize()
|
||||
and check_cudnn
|
||||
and (dtype == torch.float16 or dtype == torch.bfloat16)
|
||||
):
|
||||
with self.assertRaisesRegex(RuntimeError, "cuDNN SDPA Nested Tensor"):
|
||||
with torch.nn.attention.sdpa_kernel(
|
||||
torch.nn.attention.SDPBackend.CUDNN_ATTENTION
|
||||
):
|
||||
check_forward_backward()
|
||||
with torch.nn.attention.sdpa_kernel(
|
||||
torch.nn.attention.SDPBackend.CUDNN_ATTENTION
|
||||
):
|
||||
check_forward_backward()
|
||||
|
||||
@skipIfTorchDynamo("SDPA test compiles internally")
|
||||
@unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
|
||||
|
Reference in New Issue
Block a user