SDPA skip logic for ROCm (#160522)
Skips some tests for flex and efficient attention if they are not supported by the hardware.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160522
Approved by: https://github.com/drisspg, https://github.com/jeffdaily
Co-authored-by: Jeff Daily <jeff.daily@amd.com>
Committed by: PyTorch MergeBot
Parent: a72803f1e3
Commit: f9df4ec2af
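The skip pattern this commit applies is summarized below as a minimal, standalone sketch. The flags PLATFORM_SUPPORTS_FLASH_ATTENTION, PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, and SM80OrLater come from torch.testing._internal.common_cuda, as in the diff that follows; the test class, tensor shapes, and test bodies here are hypothetical illustrations, not part of the commit.

# Minimal sketch of the hardware-gated skip pattern (hypothetical tests).
import unittest

import torch
from torch.testing._internal.common_cuda import (
    PLATFORM_SUPPORTS_FLASH_ATTENTION,
    PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
    SM80OrLater,
)


class SDPASkipExample(unittest.TestCase):
    @unittest.skipIf(not SM80OrLater, "bfloat16 only supported in sm80+")
    @unittest.skipIf(
        # without flash attention, the math backend is used and it doesn't
        # handle bfloat16 here
        not PLATFORM_SUPPORTS_FLASH_ATTENTION,
        "Some archs don't support SDPA with bfloat16",
    )
    def test_flash_sdpa_bf16(self):
        q = k = v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        self.assertEqual(out.shape, q.shape)

    @unittest.skipIf(
        not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
        "Some archs don't support mem eff SDPA",
    )
    def test_mem_eff_sdpa(self):
        q = k = v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        self.assertEqual(out.shape, q.shape)


if __name__ == "__main__":
    unittest.main()

The same decorators appear in the diff below, applied to the existing AOTInductor and Inductor tests.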
@@ -38,7 +38,9 @@ from torch.testing import FileCheck
from torch.testing._internal import common_utils
from torch.testing._internal.common_cuda import (
    _get_torch_cuda_version,
    PLATFORM_SUPPORTS_FLASH_ATTENTION,
    PLATFORM_SUPPORTS_FP8,
    PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
    SM80OrLater,
    tf32_on_and_off,
)

@@ -1451,6 +1453,12 @@ class AOTInductorTestsTemplate:
        self.check_model(Model(), example_inputs)

    @unittest.skipIf(not SM80OrLater, "bfloat16 only supported in sm80+")
    @unittest.skipIf(
        # for archs where this isn't lowered to flash attention, the math
        # backend will be used and it doesn't work for bfloat16
        not PLATFORM_SUPPORTS_FLASH_ATTENTION,
        "Some archs don't support SDPA with bfloat16",
    )
    def test_sdpa_2(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:

@@ -1723,6 +1731,9 @@ class AOTInductorTestsTemplate:
        self.check_model(Repro(), example_inputs, dynamic_shapes=spec)

    @skipIfXpu(msg="_scaled_dot_product_flash_attention is not supported on XPU yet")
    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support flash SDPA"
    )
    def test_fallback_kernel_with_symexpr_output(self):
        if self.device != GPU_TYPE:
            raise unittest.SkipTest("requires GPU")

@@ -4293,6 +4304,9 @@ class AOTInductorTestsTemplate:
            dynamic_shapes=dynamic_shapes,
        )

    @unittest.skipIf(
        not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Some archs don't support mem eff SDPA"
    )
    def test_scaled_dot_product_efficient_attention(self):
        if self.device != GPU_TYPE:
            raise unittest.SkipTest("requires GPU")

@@ -11644,6 +11644,9 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar

    @xfail_if_mps_unimplemented
    @expectedFailureXPU
    @unittest.skipIf(
        not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Some archs don't support mem eff SDPA"
    )
    def test_scaled_dot_product_efficient_attention(self):
        if self.device == "cpu":
            raise unittest.SkipTest(f"requires {GPU_TYPE}")