[FP8][ROCm][Attention] Enable FP8 KV cache on ROCm for V1 (#17870)
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Commit 06c0922a69, parent cd3edfc908, committed via GitHub.

Changes to the chunked_prefill_paged_decode op:

@@ -9,6 +9,7 @@
 import torch
 
 from vllm import _custom_ops as ops
+from vllm.platforms import current_platform
 from vllm.platforms.rocm import use_rocm_custom_paged_attention
 from vllm.triton_utils import tl, triton
 

@@ -267,7 +268,7 @@ def chunked_prefill_paged_decode(
         assert value_cache.dtype == torch.uint8
 
         if kv_cache_dtype in ("fp8", "fp8_e4m3"):
-            target_dtype = torch.float8_e4m3fn
+            target_dtype = current_platform.fp8_dtype()
         elif kv_cache_dtype == "fp8_e5m2":
             target_dtype = torch.float8_e5m2
         else:
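
For context, current_platform.fp8_dtype() lets the kernel pick the FP8 representation native to the hardware instead of hard-coding torch.float8_e4m3fn. The sketch below illustrates the kind of dispatch such a helper performs; the function name native_fp8_dtype and the torch.version.hip check are illustrative assumptions, not vLLM's implementation.

import torch


def native_fp8_dtype() -> torch.dtype:
    """Pick an FP8 dtype by platform (illustrative sketch only)."""
    if torch.version.hip is not None:
        # ROCm builds of PyTorch report a HIP version; recent AMD GPUs
        # (e.g. MI300-class) use the fnuz FP8 encoding.
        return torch.float8_e4m3fnuz
    # CUDA/CPU default: OCP E4M3.
    return torch.float8_e4m3fn


# The KV cache is allocated as uint8 storage and later reinterpreted with
# this dtype, so the choice must match what the attention kernels expect.
target_dtype = native_fp8_dtype()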

Changes to the EngineArgs V1 support check:

@@ -1205,7 +1205,9 @@ class EngineArgs:
             and not envs.is_set("VLLM_ATTENTION_BACKEND")
         ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
         supported = False
-        if fp8_attention and will_use_fa:
+        if current_platform.is_rocm():
+            supported = True
+        elif fp8_attention and will_use_fa:
             from vllm.attention.utils.fa_utils import (
                 flash_attn_supports_fp8)
             supported = flash_attn_supports_fp8()
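
The effect of this change: with an FP8 KV cache, ROCm is now always treated as supported by the V1 engine, while CUDA still defers to FlashAttention's own capability probe. The standalone function below is a hedged restatement of that decision; the free-function framing and parameter names mirror the identifiers in the diff rather than vLLM's actual method signature.

def fp8_kv_cache_supported(is_rocm: bool, fp8_attention: bool,
                           will_use_fa: bool, flash_attn_supports_fp8) -> bool:
    """Restates the support check from the diff above (sketch only)."""
    supported = False
    if is_rocm:
        # ROCm: the Triton/paged-attention path handles the FP8 KV cache,
        # so no FlashAttention capability probe is needed.
        supported = True
    elif fp8_attention and will_use_fa:
        # CUDA: FP8 attention is only usable if the installed
        # FlashAttention build reports FP8 support.
        supported = flash_attn_supports_fp8()
    return supported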

Changes to the V1 Triton attention backend (TritonAttentionImpl):

@@ -9,6 +9,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
 from vllm.attention.ops.triton_unified_attention import unified_attention
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 from vllm.v1.attention.backends.flash_attn import (
     FlashAttentionMetadata, FlashAttentionMetadataBuilder)
 

@@ -108,6 +109,8 @@ class TritonAttentionImpl(AttentionImpl):
                                       "are not implemented for "
                                       "TritonAttentionImpl")
 
+        self.fp8_dtype = current_platform.fp8_dtype()
+
     def forward(
         self,
         layer: torch.nn.Module,
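
Caching current_platform.fp8_dtype() once in __init__ means forward() does not have to query the platform on every call, and the cached dtype is what the KV-cache views below reinterpret the uint8 storage as. The snippet below is a small self-contained demonstration of that zero-copy Tensor.view(dtype) reinterpretation; the shape and the float8_e4m3fn choice are arbitrary and not taken from vLLM's allocation code.

import torch

# KV caches are allocated as raw uint8 storage when an FP8 cache dtype is
# requested; the attention backend reinterprets them with the platform's
# FP8 dtype. uint8 and the float8 formats are all one byte wide, so
# Tensor.view(dtype) is a zero-copy reinterpretation, not a conversion.
raw_cache = torch.zeros(2, 16, 8, dtype=torch.uint8)  # arbitrary shape

fp8_view = raw_cache.view(torch.float8_e4m3fn)

assert fp8_view.shape == raw_cache.shape                # same shape
assert fp8_view.untyped_storage().data_ptr() == \
    raw_cache.untyped_storage().data_ptr()              # same memory, no copy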

@@ -161,15 +164,18 @@ class TritonAttentionImpl(AttentionImpl):
             )
 
         if self.kv_cache_dtype.startswith("fp8"):
-            key_cache = key_cache.view(torch.float8_e4m3fn)
-            value_cache = value_cache.view(torch.float8_e4m3fn)
+            key_cache = key_cache.view(self.fp8_dtype)
+            value_cache = value_cache.view(self.fp8_dtype)
             num_tokens, num_heads, head_size = query.shape
             assert layer._q_scale == 1.0, \
                 "A non 1.0 q_scale is not currently supported."
-            query, _ = ops.scaled_fp8_quant(
-                query.reshape(
-                    (num_tokens, num_heads * head_size)).contiguous(),
-                layer._q_scale)
+            if not current_platform.is_rocm():
+                # Skip Q quantization on ROCm, since dequantizing back to
+                # f32 in the attention kernel is not supported.
+                query, _ = ops.scaled_fp8_quant(
+                    query.reshape(
+                        (num_tokens, num_heads * head_size)).contiguous(),
+                    layer._q_scale)
             query = query.reshape((num_tokens, num_heads, head_size))
 
         use_local_attn = \
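
On the query side, the new branch leaves query in its original floating-point dtype on ROCm instead of quantizing it with ops.scaled_fp8_quant, since the ROCm kernel cannot dequantize Q back to f32. The sketch below isolates that branch without vLLM's custom op; the is_rocm flag and the plain cast standing in for ops.scaled_fp8_quant are illustrative assumptions.

import torch


def maybe_quantize_query(query: torch.Tensor, q_scale: float,
                         is_rocm: bool) -> torch.Tensor:
    """Sketch of the Q-quantization branch above, without vLLM's custom op."""
    if is_rocm:
        # ROCm path: keep Q in its original dtype (e.g. float16/bfloat16);
        # only the KV cache is stored in FP8.
        return query

    # CUDA path: flatten the head dimensions, quantize to FP8, then restore
    # the shape. A plain cast stands in for ops.scaled_fp8_quant here; the
    # real op also computes scales and saturates out-of-range values.
    num_tokens, num_heads, head_size = query.shape
    flat = query.reshape(num_tokens, num_heads * head_size).contiguous()
    q_fp8 = (flat / q_scale).to(torch.float8_e4m3fn)
    return q_fp8.reshape(num_tokens, num_heads, head_size)


# Example: a fake batch of 4 tokens with 2 heads of size 8.
q = torch.randn(4, 2, 8, dtype=torch.float16)
print(maybe_quantize_query(q, q_scale=1.0, is_rocm=True).dtype)   # torch.float16
print(maybe_quantize_query(q, q_scale=1.0, is_rocm=False).dtype)  # torch.float8_e4m3fn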