[ROCm][V0][Attention] Revert to the previous FA triton kernel (#18226)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Gregory Shtrasberg
2025-05-29 12:13:18 -04:00
committed by GitHub
parent da4b69d0b4
commit 1b7cfd5a36
3 changed files with 692 additions and 1081 deletions


@@ -770,8 +770,9 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                                   and layer._v_scale and layer._prob_scale
                                   and self.kv_cache_dtype == "fp8")
                 full_scales = (
-                    layer._q_scale, layer._k_scale, layer._v_scale,
-                    layer._prob_scale) if use_fp8_scales else None
+                    layer._q_scale.item(), layer._k_scale.item(),
+                    layer._v_scale.item(),
+                    layer._prob_scale.item()) if use_fp8_scales else None
                 self.triton_attn_func(
                     query,
                     key,
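The hunk above replaces the raw scale tensors with Python scalars obtained via .item() before the Triton attention call. A minimal sketch of that pattern, using made-up 0-dim tensors purely for illustration (q_scale, k_scale, v_scale, and prob_scale stand in for the layer._*_scale attributes):

import torch

# Hypothetical 0-dim scale tensors; the real values come from the
# quantized attention layer.
q_scale = torch.tensor(1.0)
k_scale = torch.tensor(0.5)
v_scale = torch.tensor(0.5)
prob_scale = torch.tensor(1.0)

use_fp8_scales = True

# .item() turns each 0-dim tensor into a plain Python float, so the
# kernel receives host-side scalars rather than device tensors.
full_scales = (
    q_scale.item(), k_scale.item(),
    v_scale.item(), prob_scale.item()) if use_fp8_scales else None

print(full_scales)  # (1.0, 0.5, 0.5, 1.0)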

File diff suppressed because it is too large


@@ -98,6 +98,12 @@ def with_amdsmi_context(fn):
     return wrapper
 
 
+@cache
+def on_gfx1x() -> bool:
+    GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
+    return any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"])
+
+
 @cache
 def on_mi250_mi300() -> bool:
     GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
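These helpers key off the GCN architecture string reported by the driver. A minimal sketch of the same detection pattern, assuming a ROCm build of PyTorch that exposes gcnArchName; the is_any_arch helper and the gating example are illustrative and not part of this change:

from functools import cache

import torch


@cache
def is_any_arch(prefixes: tuple[str, ...]) -> bool:
    # On ROCm, gcnArchName reports strings such as "gfx1100" or "gfx942".
    gpu_arch = torch.cuda.get_device_properties("cuda").gcnArchName
    return any(prefix in gpu_arch for prefix in prefixes)


# Example: pick a code path for RDNA (gfx11/gfx12) parts.
if is_any_arch(("gfx11", "gfx12")):
    pass  # e.g. fall back to a navi-friendly kernel configuration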