Mirror of https://github.com/vllm-project/vllm.git
[ROCm][V0][Attention] Revert to the previous FA triton kernel (#18226)
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
parent da4b69d0b4
commit 1b7cfd5a36
@@ -770,8 +770,9 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                                       and layer._v_scale and layer._prob_scale
                                       and self.kv_cache_dtype == "fp8")
                     full_scales = (
-                        layer._q_scale, layer._k_scale, layer._v_scale,
-                        layer._prob_scale) if use_fp8_scales else None
+                        layer._q_scale.item(), layer._k_scale.item(),
+                        layer._v_scale.item(),
+                        layer._prob_scale.item()) if use_fp8_scales else None
                     self.triton_attn_func(
                         query,
                         key,
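Note on the hunk above: the fp8 quantization scales are now handed to the Triton attention kernel as Python floats via .item() rather than as the 0-d scale tensors themselves, presumably because the restored kernel takes host-side scalars. A minimal sketch of that conversion; the helper name and the dummy tensors are illustrative, not taken from the commit:

import torch

def scales_as_floats(q_scale: torch.Tensor, k_scale: torch.Tensor,
                     v_scale: torch.Tensor, prob_scale: torch.Tensor):
    # Hypothetical helper: turn 0-d scale tensors (as stored on the layer)
    # into plain Python floats for a kernel that expects scalar arguments.
    return (q_scale.item(), k_scale.item(), v_scale.item(), prob_scale.item())

# Dummy 0-d tensors standing in for layer._q_scale, layer._k_scale, etc.
full_scales = scales_as_floats(torch.tensor(1.0), torch.tensor(1.0),
                               torch.tensor(1.0), torch.tensor(1.0))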
File diff suppressed because it is too large
@@ -98,6 +98,12 @@ def with_amdsmi_context(fn):
     return wrapper


+@cache
+def on_gfx1x() -> bool:
+    GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
+    return any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"])
+
+
 @cache
 def on_mi250_mi300() -> bool:
     GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
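The helpers added above key off the ROCm gcnArchName string (e.g. "gfx11"/"gfx12" for RDNA consumer parts, "gfx90a"/"gfx942" for MI250/MI300) to detect the GPU family. A minimal sketch of how such a check might be used to gate a test or kernel path; the skip marker and reason are assumptions for illustration, not part of the commit:

from functools import cache

import pytest
import torch


@cache
def on_gfx1x() -> bool:
    # Same idea as the helper above: cache the arch lookup and report
    # whether the device is a gfx11/gfx12 (RDNA) part.
    gpu_arch = torch.cuda.get_device_properties("cuda").gcnArchName
    return any(arch in gpu_arch for arch in ["gfx11", "gfx12"])


# Hypothetical usage: skip a Triton FA test on RDNA GPUs.
skip_on_gfx1x = pytest.mark.skipif(
    torch.cuda.is_available() and on_gfx1x(),
    reason="Triton FA kernel path is not exercised on gfx11/gfx12")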