[FP8][ROCm][Attention] Enable FP8 KV cache on ROCm for V1 (#17870)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>

Author: Gregory Shtrasberg
Date:   2025-05-11 03:58:45 -04:00
Committed by: GitHub
Parent: cd3edfc908
Commit: 06c0922a69

3 changed files with 17 additions and 8 deletions


@@ -9,6 +9,7 @@
 import torch
 from vllm import _custom_ops as ops
+from vllm.platforms import current_platform
 from vllm.platforms.rocm import use_rocm_custom_paged_attention
 from vllm.triton_utils import tl, triton
@@ -267,7 +268,7 @@ def chunked_prefill_paged_decode(
         assert value_cache.dtype == torch.uint8
         if kv_cache_dtype in ("fp8", "fp8_e4m3"):
-            target_dtype = torch.float8_e4m3fn
+            target_dtype = current_platform.fp8_dtype()
         elif kv_cache_dtype == "fp8_e5m2":
            target_dtype = torch.float8_e5m2
         else:
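
Rather than hard-coding torch.float8_e4m3fn, the kernel now asks the platform for its native FP8 format via current_platform.fp8_dtype(). The sketch below only illustrates why that matters; the helper name pick_fp8_dtype and the is_rocm flag are made up for this example and are not vLLM's API. The assumption here is that NVIDIA hardware uses the OCP E4M3 encoding while the AMD GPUs targeted by this change use the FNUZ variant, so reinterpreting the uint8 cache with the wrong dtype would misread every element.

import torch

def pick_fp8_dtype(is_rocm: bool) -> torch.dtype:
    # Illustrative only: the hardware-native FP8 E4M3 variant differs
    # between vendors. torch.float8_e4m3fnuz (no negative zero, a single
    # NaN encoding) is assumed to be the ROCm-native variant here;
    # OCP torch.float8_e4m3fn is used elsewhere.
    return torch.float8_e4m3fnuz if is_rocm else torch.float8_e4m3fn

# The uint8 KV-cache buffer is reinterpreted, not converted, so picking
# the right dtype is what makes the stored bytes decode correctly.
raw_cache = torch.zeros(16, dtype=torch.uint8)
fp8_view = raw_cache.view(pick_fp8_dtype(is_rocm=True))
print(fp8_view.dtype)  # torch.float8_e4m3fnuz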


@@ -1205,7 +1205,9 @@ class EngineArgs:
             and not envs.is_set("VLLM_ATTENTION_BACKEND")
         ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
         supported = False
-        if fp8_attention and will_use_fa:
+        if current_platform.is_rocm():
+            supported = True
+        elif fp8_attention and will_use_fa:
             from vllm.attention.utils.fa_utils import (
                 flash_attn_supports_fp8)
             supported = flash_attn_supports_fp8()
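
The effect of the new branch: on ROCm the V1 engine now treats an FP8 KV cache as supported unconditionally (the Triton backend handles it), while the CUDA path keeps deferring to FlashAttention's FP8 capability check. A condensed restatement of that decision, with a hypothetical helper and boolean arguments standing in for the surrounding EngineArgs state:

def fp8_kv_cache_supported_on_v1(is_rocm: bool,
                                 fp8_attention: bool,
                                 will_use_fa: bool,
                                 fa_supports_fp8: bool) -> bool:
    # Illustrative restatement of the check above, not vLLM's actual API.
    if is_rocm:
        # ROCm runs the Triton attention backend on V1, which reads the
        # FP8 KV cache directly, so no FlashAttention check is needed.
        return True
    if fp8_attention and will_use_fa:
        # CUDA still depends on FlashAttention supporting FP8.
        return fa_supports_fp8
    return False

assert fp8_kv_cache_supported_on_v1(is_rocm=True, fp8_attention=True,
                                    will_use_fa=False, fa_supports_fp8=False)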


@@ -9,6 +9,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
 from vllm.attention.ops.triton_unified_attention import unified_attention
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 from vllm.v1.attention.backends.flash_attn import (
     FlashAttentionMetadata, FlashAttentionMetadataBuilder)
@@ -108,6 +109,8 @@ class TritonAttentionImpl(AttentionImpl):
                 "are not implemented for "
                 "TritonAttentionImpl")
 
+        self.fp8_dtype = current_platform.fp8_dtype()
+
     def forward(
         self,
         layer: torch.nn.Module,
@@ -161,15 +164,18 @@ class TritonAttentionImpl(AttentionImpl):
         )
 
         if self.kv_cache_dtype.startswith("fp8"):
-            key_cache = key_cache.view(torch.float8_e4m3fn)
-            value_cache = value_cache.view(torch.float8_e4m3fn)
+            key_cache = key_cache.view(self.fp8_dtype)
+            value_cache = value_cache.view(self.fp8_dtype)
             num_tokens, num_heads, head_size = query.shape
             assert layer._q_scale == 1.0, \
                 "A non 1.0 q_scale is not currently supported."
-            query, _ = ops.scaled_fp8_quant(
-                query.reshape(
-                    (num_tokens, num_heads * head_size)).contiguous(),
-                layer._q_scale)
+            if not current_platform.is_rocm():
+                # Skip Q quantization on ROCm, since dequantizing back to
+                # f32 in the attention kernel is not supported.
+                query, _ = ops.scaled_fp8_quant(
+                    query.reshape(
+                        (num_tokens, num_heads * head_size)).contiguous(),
+                    layer._q_scale)
             query = query.reshape((num_tokens, num_heads, head_size))
 
         use_local_attn = \
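
In the forward pass the key/value caches are always reinterpreted as the platform's FP8 dtype, but the query is only quantized on non-ROCm platforms; on ROCm it stays in its original precision because the Triton kernel cannot dequantize an FP8 query back to float32. Below is a minimal sketch of that branch, with a hypothetical helper in place of the inline code and a naive per-tensor cast standing in for ops.scaled_fp8_quant:

import torch

def maybe_quantize_query(query: torch.Tensor, q_scale: float,
                         is_rocm: bool,
                         fp8_dtype: torch.dtype) -> torch.Tensor:
    # Hypothetical helper mirroring the branch above, not vLLM code.
    if is_rocm:
        # Leave the query unquantized; only the KV cache is FP8 on ROCm.
        return query
    num_tokens, num_heads, head_size = query.shape
    q_2d = query.reshape(num_tokens, num_heads * head_size).contiguous()
    # Naive per-tensor quantization, used here only to illustrate what
    # ops.scaled_fp8_quant produces for the CUDA path.
    q_fp8 = (q_2d / q_scale).to(fp8_dtype)
    return q_fp8.reshape(num_tokens, num_heads, head_size)

q = torch.randn(4, 8, 64)
out = maybe_quantize_query(q, 1.0, is_rocm=True,
                           fp8_dtype=torch.float8_e4m3fnuz)
assert out.dtype == q.dtype  # query left untouched on the ROCm path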