[FP8][ROCm][Attention] Enable FP8 KV cache on ROCm for V1 (#17870)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Author: Gregory Shtrasberg
Date: 2025-05-11 03:58:45 -04:00
Committed by: GitHub
Parent: cd3edfc908
Commit: 06c0922a69
3 changed files with 17 additions and 8 deletions


@@ -9,6 +9,7 @@
 import torch
 from vllm import _custom_ops as ops
+from vllm.platforms import current_platform
 from vllm.platforms.rocm import use_rocm_custom_paged_attention
 from vllm.triton_utils import tl, triton
@@ -267,7 +268,7 @@ def chunked_prefill_paged_decode(
         assert value_cache.dtype == torch.uint8
         if kv_cache_dtype in ("fp8", "fp8_e4m3"):
-            target_dtype = torch.float8_e4m3fn
+            target_dtype = current_platform.fp8_dtype()
         elif kv_cache_dtype == "fp8_e5m2":
             target_dtype = torch.float8_e5m2
         else:
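
torch.float8_e4m3fn is the FP8 format implemented by NVIDIA GPUs; MI300-class AMD GPUs instead use the e4m3fnuz variant, which is why the target dtype now comes from current_platform.fp8_dtype(). A minimal sketch of the kind of dispatch that call performs (fp8_dtype_for_platform and its flags are hypothetical, not vLLM's implementation):

import torch

def fp8_dtype_for_platform(is_rocm: bool, uses_fnuz: bool) -> torch.dtype:
    # Hypothetical stand-in for current_platform.fp8_dtype(): return the
    # FP8 storage format the local hardware actually implements.
    if is_rocm and uses_fnuz:
        # MI300-class GPUs use the "fnuz" E4M3 variant
        # (different exponent bias, no inf, no negative zero).
        return torch.float8_e4m3fnuz
    return torch.float8_e4m3fn

# A uint8-backed KV cache block can then be reinterpreted in place:
target_dtype = fp8_dtype_for_platform(is_rocm=False, uses_fnuz=False)
kv_block = torch.zeros(16, 128, dtype=torch.uint8).view(target_dtype)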


@@ -1205,7 +1205,9 @@ class EngineArgs:
                 and not envs.is_set("VLLM_ATTENTION_BACKEND")
             ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
             supported = False
-            if fp8_attention and will_use_fa:
+            if current_platform.is_rocm():
+                supported = True
+            elif fp8_attention and will_use_fa:
                 from vllm.attention.utils.fa_utils import (
                     flash_attn_supports_fp8)
                 supported = flash_attn_supports_fp8()
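
In the V1 support oracle, an FP8 KV cache is now accepted unconditionally on ROCm (the Triton backend consumes it directly), while the CUDA path still defers to the FlashAttention build's capability probe. A condensed sketch of that decision follows; the helper name and its arguments are illustrative, not vLLM APIs:

from typing import Callable

def fp8_kv_cache_supported(is_rocm: bool,
                           fp8_attention: bool,
                           will_use_fa: bool,
                           flash_attn_supports_fp8: Callable[[], bool]) -> bool:
    # Condensed, hypothetical version of the EngineArgs check above.
    if is_rocm:
        # ROCm: no FlashAttention probe needed; the Triton unified
        # attention kernel handles the FP8 KV cache itself.
        return True
    if fp8_attention and will_use_fa:
        # CUDA: only supported when the FlashAttention build says so.
        return flash_attn_supports_fp8()
    return False

# Example: CUDA + a FlashAttention build without FP8 support -> unsupported.
print(fp8_kv_cache_supported(False, True, True, lambda: False))  # False
# Example: any configuration on ROCm -> supported.
print(fp8_kv_cache_supported(True, True, False, lambda: False))  # True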


@@ -9,6 +9,7 @@
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
 from vllm.attention.ops.triton_unified_attention import unified_attention
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 from vllm.v1.attention.backends.flash_attn import (
     FlashAttentionMetadata, FlashAttentionMetadataBuilder)
@@ -108,6 +109,8 @@ class TritonAttentionImpl(AttentionImpl):
                                       "are not implemented for "
                                       "TritonAttentionImpl")

+        self.fp8_dtype = current_platform.fp8_dtype()
+
     def forward(
         self,
         layer: torch.nn.Module,
@@ -161,11 +164,14 @@
             )

         if self.kv_cache_dtype.startswith("fp8"):
-            key_cache = key_cache.view(torch.float8_e4m3fn)
-            value_cache = value_cache.view(torch.float8_e4m3fn)
+            key_cache = key_cache.view(self.fp8_dtype)
+            value_cache = value_cache.view(self.fp8_dtype)
             num_tokens, num_heads, head_size = query.shape
             assert layer._q_scale == 1.0, \
                 "A non 1.0 q_scale is not currently supported."
-            query, _ = ops.scaled_fp8_quant(
-                query.reshape(
-                    (num_tokens, num_heads * head_size)).contiguous(),
+            if not current_platform.is_rocm():
+                # Skip Q quantization on ROCm, since dequantizing back to
+                # f32 in the attention kernel is not supported.
+                query, _ = ops.scaled_fp8_quant(
+                    query.reshape(
+                        (num_tokens, num_heads * head_size)).contiguous(),
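
Taken together, TritonAttentionImpl now reinterprets the uint8-backed KV cache as whatever FP8 format current_platform.fp8_dtype() reports, and only quantizes the query on non-ROCm platforms. A self-contained sketch of that pre-kernel preparation, with toy shapes and a naive clamp-and-cast standing in for ops.scaled_fp8_quant (both the helper and the shapes are illustrative, not the actual vLLM code path):

import torch

def prepare_fp8_inputs(query: torch.Tensor,
                       key_cache: torch.Tensor,
                       value_cache: torch.Tensor,
                       fp8_dtype: torch.dtype,
                       is_rocm: bool):
    # The KV cache is stored as raw uint8 bytes; .view() reinterprets it
    # (no copy) as the platform's FP8 format for the attention kernel.
    key_cache = key_cache.view(fp8_dtype)
    value_cache = value_cache.view(fp8_dtype)

    if not is_rocm:
        # Non-ROCm path: also quantize the query to FP8 (q_scale fixed at
        # 1.0). A naive clamp-and-cast is used here in place of
        # ops.scaled_fp8_quant.
        num_tokens, num_heads, head_size = query.shape
        q2d = query.reshape(num_tokens, num_heads * head_size).contiguous()
        info = torch.finfo(fp8_dtype)
        query = (q2d.float().clamp(info.min, info.max)
                 .to(fp8_dtype)
                 .view(num_tokens, num_heads, head_size))
    # ROCm path: the query keeps its original dtype, since the Triton
    # kernel cannot dequantize an FP8 query back to f32.
    return query, key_cache, value_cache

# Toy example: 16 tokens, 8 heads, head size 64, 8192-byte KV buffers.
q = torch.randn(16, 8, 64, dtype=torch.float16)
k = torch.zeros(8192, dtype=torch.uint8)
v = torch.zeros(8192, dtype=torch.uint8)
q8, k8, v8 = prepare_fp8_inputs(q, k, v, torch.float8_e4m3fn, is_rocm=False)
print(q8.dtype, k8.dtype)  # torch.float8_e4m3fn torch.float8_e4m3fn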