mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[flashinfer] [kernel] support for fp8 kv cache for trtllm prefill attention (#24197)
Signed-off-by: Xiaozhu <mxz297@gmail.com>
This commit is contained in:
@ -163,6 +163,7 @@ if TYPE_CHECKING:
|
||||
VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = False
|
||||
VLLM_ENABLE_RESPONSES_API_STORE: bool = False
|
||||
VLLM_USE_TRTLLM_ATTENTION: Optional[str] = None
|
||||
VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION: bool = False
|
||||
VLLM_HAS_FLASHINFER_CUBIN: bool = False
|
||||
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False
|
||||
VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False
|
||||
@ -1155,6 +1156,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"VLLM_USE_TRTLLM_ATTENTION":
|
||||
lambda: os.getenv("VLLM_USE_TRTLLM_ATTENTION", None),
|
||||
|
||||
# If set to 1, when we use fp8 kv, we do not quantize Q to fp8
|
||||
"VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION":
|
||||
lambda: bool(int(os.getenv("VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION", "0"))),
|
||||
|
||||
# If set, it means we pre-downloaded cubin files and flashinfer will
|
||||
# read the cubin files directly.
|
||||
"VLLM_HAS_FLASHINFER_CUBIN":
|
||||
@ -1310,6 +1315,7 @@ def compute_hash() -> str:
|
||||
"VLLM_USE_FLASHINFER_MOE_MXFP4_BF16",
|
||||
"VLLM_USE_CUDNN_PREFILL",
|
||||
"VLLM_USE_TRTLLM_ATTENTION",
|
||||
"VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION",
|
||||
"VLLM_ROCM_USE_AITER",
|
||||
"VLLM_ROCM_USE_AITER_PAGED_ATTN",
|
||||
"VLLM_ROCM_USE_AITER_LINEAR",
|
||||
|
@ -200,11 +200,6 @@ def use_trtllm_attention(
|
||||
logger.info_once("Using TRTLLM attention (query is quantized).")
|
||||
return True
|
||||
|
||||
# TRTLLM prefill attention does not support FP8 kv cache with
|
||||
# non-quantized query
|
||||
if is_prefill and kv_cache_dtype.startswith("fp8"):
|
||||
return False
|
||||
|
||||
# If sinks are being used, we must use TRTLLM attention as it's
|
||||
# the only backend that supports them
|
||||
if has_sinks:
|
||||
@ -353,6 +348,12 @@ def flashinfer_scaled_fp8_mm(
|
||||
return output
|
||||
|
||||
|
||||
@functools.cache
|
||||
def flashinfer_disable_q_quantization() -> bool:
|
||||
"""Cache result which only depends on the environment"""
|
||||
return envs.VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION
|
||||
|
||||
|
||||
__all__ = [
|
||||
"has_flashinfer",
|
||||
"flashinfer_trtllm_fp8_block_scale_moe",
|
||||
|
@ -25,7 +25,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils import cdiv, is_pin_memory_available
|
||||
from vllm.utils.flashinfer import (supports_trtllm_attention,
|
||||
from vllm.utils.flashinfer import (flashinfer_disable_q_quantization,
|
||||
supports_trtllm_attention,
|
||||
use_trtllm_attention)
|
||||
from vllm.v1.attention.backends.flash_attn import use_cascade_attention
|
||||
# yapf conflicts with isort for this block
|
||||
@ -48,8 +49,89 @@ FP4_DTYPE = torch.uint8
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class FlashInferBackend(AttentionBackend):
|
||||
@triton.jit
|
||||
def _trtllm_prefill_attn_kvfp8_dequant(
|
||||
kv_cache_ptr,
|
||||
block_tables_prefill_ptr,
|
||||
block_table_stride,
|
||||
mock_kv_cache_ptr,
|
||||
k_scale_ptr,
|
||||
v_scale_ptr,
|
||||
K_CACHE_STRIDE: tl.constexpr,
|
||||
KV_CACHE_STRIDE: tl.constexpr,
|
||||
):
|
||||
batch_idx = tl.program_id(0).to(tl.int64)
|
||||
mock_block_table_idx = tl.program_id(1).to(tl.int64)
|
||||
orig_page_num = tl.load(block_tables_prefill_ptr +
|
||||
batch_idx * block_table_stride +
|
||||
mock_block_table_idx).to(tl.int64)
|
||||
if orig_page_num <= 0:
|
||||
return
|
||||
dequant_dtype = mock_kv_cache_ptr.dtype.element_ty
|
||||
|
||||
# Dequantize K
|
||||
k_scale_val = tl.load(k_scale_ptr)
|
||||
offset = orig_page_num * KV_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE)
|
||||
fp8_vals = tl.load(kv_cache_ptr + offset)
|
||||
dequantized_vals = fp8_vals.to(tl.float32) * k_scale_val
|
||||
mock_cache_offset = (batch_idx * block_table_stride + mock_block_table_idx
|
||||
+ 1) * KV_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE)
|
||||
dequantized_vals = dequantized_vals.to(dequant_dtype)
|
||||
tl.store(mock_kv_cache_ptr + mock_cache_offset, dequantized_vals)
|
||||
|
||||
# Dequantize V
|
||||
v_scale_val = tl.load(v_scale_ptr)
|
||||
offset = (orig_page_num * KV_CACHE_STRIDE + K_CACHE_STRIDE +
|
||||
tl.arange(0, K_CACHE_STRIDE))
|
||||
fp8_vals = tl.load(kv_cache_ptr + offset)
|
||||
dequantized_vals = fp8_vals.to(tl.float32) * v_scale_val
|
||||
mock_cache_offset = (
|
||||
(batch_idx * block_table_stride + mock_block_table_idx + 1) *
|
||||
KV_CACHE_STRIDE + K_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE))
|
||||
dequantized_vals = dequantized_vals.to(dequant_dtype)
|
||||
tl.store(mock_kv_cache_ptr + mock_cache_offset, dequantized_vals)
|
||||
|
||||
|
||||
def trtllm_prefill_attn_kvfp8_dequant(
|
||||
kv_cache: torch.Tensor,
|
||||
block_tables_prefill: torch.Tensor,
|
||||
k_scale: torch.Tensor,
|
||||
v_scale: torch.Tensor,
|
||||
dequant_dtype: torch.dtype,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
batch_size, num_of_page_per_token = block_tables_prefill.shape
|
||||
s = kv_cache.shape
|
||||
assert s[1] == 2
|
||||
assert dequant_dtype in (torch.bfloat16, torch.float16)
|
||||
k_cache_stride = s[2] * s[3] * s[4]
|
||||
kv_cache_stride = k_cache_stride * s[1]
|
||||
new_s = (batch_size * num_of_page_per_token + 1, s[1], s[2], s[3], s[4])
|
||||
# mock kv cache contains just the pages needed by this prefill
|
||||
mock_kv_cache = torch.empty(new_s,
|
||||
dtype=dequant_dtype,
|
||||
device=kv_cache.device)
|
||||
# we simply sequentially index the pages needed by this prefill
|
||||
mock_block_table = torch.arange(
|
||||
start=1,
|
||||
end=batch_size * num_of_page_per_token + 1,
|
||||
dtype=torch.int32,
|
||||
device=block_tables_prefill.device,
|
||||
).reshape(batch_size, num_of_page_per_token)
|
||||
grid = (batch_size, num_of_page_per_token)
|
||||
_trtllm_prefill_attn_kvfp8_dequant[grid](
|
||||
kv_cache,
|
||||
block_tables_prefill,
|
||||
num_of_page_per_token,
|
||||
mock_kv_cache,
|
||||
k_scale,
|
||||
v_scale,
|
||||
k_cache_stride,
|
||||
kv_cache_stride,
|
||||
)
|
||||
return mock_kv_cache, mock_block_table
|
||||
|
||||
|
||||
class FlashInferBackend(AttentionBackend):
|
||||
accept_output_buffer: bool = True
|
||||
|
||||
@classmethod
|
||||
@ -122,7 +204,6 @@ class FlashInferBackend(AttentionBackend):
|
||||
|
||||
@dataclass
|
||||
class FlashInferMetadata:
|
||||
|
||||
num_actual_tokens: int # Number of tokens excluding padding.
|
||||
|
||||
# The data type of the query
|
||||
@ -175,8 +256,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
self.kv_cache_spec.block_size)
|
||||
max_num_reqs = vllm_config.scheduler_config.max_num_seqs
|
||||
max_num_pages = max_num_reqs * max_num_pages_per_req
|
||||
self.enable_cuda_graph = self.compilation_config.cudagraph_mode.\
|
||||
decode_mode() == CUDAGraphMode.FULL
|
||||
self.enable_cuda_graph = (self.compilation_config.cudagraph_mode.\
|
||||
decode_mode() == CUDAGraphMode.FULL)
|
||||
if self.enable_cuda_graph:
|
||||
# For full cudagraph capture, one `decode_wrapper` for each batch
|
||||
# size is needed for FlashInfer.
|
||||
@ -201,7 +282,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
assert self.kv_cache_spec.dtype == self.model_config.dtype
|
||||
self.kv_cache_dtype = self.kv_cache_spec.dtype
|
||||
|
||||
if supports_trtllm_attention()[0]:
|
||||
if supports_trtllm_attention()[0] and \
|
||||
not flashinfer_disable_q_quantization():
|
||||
self.q_data_type = self.kv_cache_dtype
|
||||
else:
|
||||
self.q_data_type = self.model_config.dtype
|
||||
@ -795,11 +877,29 @@ class FlashInferImpl(AttentionImpl):
|
||||
assert self.o_sf_scale is None
|
||||
out = output[num_decode_tokens:]
|
||||
|
||||
if attn_metadata.q_data_type != FP8_DTYPE \
|
||||
and self.kv_cache_dtype.startswith("fp8"):
|
||||
# TRTLLM prefill attention does not support BF16 Q
|
||||
# and fp8 kv cache. So to enable prefill attention
|
||||
# with fp8 kv cache, we can construct a mock block
|
||||
# and mock kv cache with BF16 KV involved in the prefill
|
||||
mock_kv_cache, mock_block_table = (
|
||||
trtllm_prefill_attn_kvfp8_dequant(
|
||||
kv_cache_permute,
|
||||
block_tables_prefill,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
attn_metadata.q_data_type,
|
||||
))
|
||||
else:
|
||||
mock_kv_cache = kv_cache_permute
|
||||
mock_block_table = block_tables_prefill
|
||||
|
||||
trtllm_batch_context_with_kv_cache(
|
||||
query=prefill_query,
|
||||
kv_cache=kv_cache_permute,
|
||||
kv_cache=mock_kv_cache,
|
||||
workspace_buffer=workspace_buffer,
|
||||
block_tables=block_tables_prefill,
|
||||
block_tables=mock_block_table,
|
||||
seq_lens=seq_lens_prefill,
|
||||
max_q_len=attn_metadata.max_q_len,
|
||||
max_kv_len=attn_metadata.max_seq_len,
|
||||
@ -837,7 +937,7 @@ class FlashInferImpl(AttentionImpl):
|
||||
decode_query = decode_query.contiguous()
|
||||
workspace_buffer = decode_wrapper._float_workspace_buffer
|
||||
block_tables_decode = attn_metadata.\
|
||||
block_table_tensor[:num_decode_tokens]
|
||||
block_table_tensor[:num_decode_tokens]
|
||||
seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens]
|
||||
|
||||
# This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
|
||||
|
Reference in New Issue
Block a user