SDPA skip logic for ROCm (#160522)

Skips some tests for flash attention and memory-efficient attention when they are not supported by the hardware.
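As an illustration, here is a minimal sketch of the skip pattern this change applies. The PLATFORM_SUPPORTS_* flags come from torch.testing._internal.common_cuda (as in the diff below); the test class and test bodies are hypothetical placeholders, not the tests touched by this PR.

import unittest

import torch
import torch.nn.functional as F
from torch.testing._internal.common_cuda import (
    PLATFORM_SUPPORTS_FLASH_ATTENTION,
    PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
)


class SDPASkipSketch(unittest.TestCase):
    # Hypothetical test: skipped on archs (e.g. some ROCm GPUs) where SDPA
    # cannot be lowered to the flash-attention backend.
    @unittest.skipUnless(torch.cuda.is_available(), "requires CUDA/ROCm")
    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support flash SDPA"
    )
    def test_flash_sdpa(self):
        q = k = v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
        out = F.scaled_dot_product_attention(q, k, v)
        self.assertEqual(out.shape, q.shape)

    # Same pattern for the memory-efficient backend.
    @unittest.skipUnless(torch.cuda.is_available(), "requires CUDA/ROCm")
    @unittest.skipIf(
        not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Some archs don't support mem eff SDPA"
    )
    def test_mem_eff_sdpa(self):
        q = k = v = torch.randn(1, 8, 128, 64, device="cuda")
        out = F.scaled_dot_product_attention(q, k, v)
        self.assertEqual(out.shape, q.shape)


if __name__ == "__main__":
    unittest.main()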

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160522
Approved by: https://github.com/drisspg, https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
Author: AmdSampsa
Date: 2025-08-26 15:51:07 +00:00
Committed by: PyTorch MergeBot
Parent: a72803f1e3
Commit: f9df4ec2af
2 changed files with 17 additions and 0 deletions

@@ -38,7 +38,9 @@ from torch.testing import FileCheck
 from torch.testing._internal import common_utils
 from torch.testing._internal.common_cuda import (
     _get_torch_cuda_version,
+    PLATFORM_SUPPORTS_FLASH_ATTENTION,
     PLATFORM_SUPPORTS_FP8,
+    PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
     SM80OrLater,
     tf32_on_and_off,
 )
@@ -1451,6 +1453,12 @@ class AOTInductorTestsTemplate:
         self.check_model(Model(), example_inputs)
 
     @unittest.skipIf(not SM80OrLater, "bfloat16 only supported in sm80+")
+    @unittest.skipIf(
+        # for archs where this isn't lowered to flash attention, the math
+        # backend will be used and it doesn't work for bfloat16
+        not PLATFORM_SUPPORTS_FLASH_ATTENTION,
+        "Some archs don't support SDPA with bfloat16",
+    )
     def test_sdpa_2(self):
         class Model(torch.nn.Module):
             def __init__(self) -> None:
@@ -1723,6 +1731,9 @@ class AOTInductorTestsTemplate:
         self.check_model(Repro(), example_inputs, dynamic_shapes=spec)
 
     @skipIfXpu(msg="_scaled_dot_product_flash_attention is not supported on XPU yet")
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support flash SDPA"
+    )
     def test_fallback_kernel_with_symexpr_output(self):
         if self.device != GPU_TYPE:
             raise unittest.SkipTest("requires GPU")
@@ -4293,6 +4304,9 @@ class AOTInductorTestsTemplate:
             dynamic_shapes=dynamic_shapes,
         )
 
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Some archs don't support mem eff SDPA"
+    )
     def test_scaled_dot_product_efficient_attention(self):
         if self.device != GPU_TYPE:
             raise unittest.SkipTest("requires GPU")

@@ -11644,6 +11644,9 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
     @xfail_if_mps_unimplemented
     @expectedFailureXPU
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Some archs don't support mem eff SDPA"
+    )
     def test_scaled_dot_product_efficient_attention(self):
         if self.device == "cpu":
             raise unittest.SkipTest(f"requires {GPU_TYPE}")