[flashinfer] [kernel] support for fp8 kv cache for trtllm prefill attention (#24197)

Signed-off-by: Xiaozhu <mxz297@gmail.com>
Xiaozhu Meng
2025-09-11 14:20:09 -07:00
committed by GitHub
parent 074854b24f
commit e42af78b18
3 changed files with 121 additions and 14 deletions

View File

@@ -163,6 +163,7 @@ if TYPE_CHECKING:
VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = False
VLLM_ENABLE_RESPONSES_API_STORE: bool = False
VLLM_USE_TRTLLM_ATTENTION: Optional[str] = None
VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION: bool = False
VLLM_HAS_FLASHINFER_CUBIN: bool = False
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False
VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False
@@ -1155,6 +1156,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_TRTLLM_ATTENTION":
lambda: os.getenv("VLLM_USE_TRTLLM_ATTENTION", None),
# If set to 1, do not quantize Q to fp8 even when the KV cache is fp8
"VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION":
lambda: bool(int(os.getenv("VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION", "0"))),
# If set, cubin files have been pre-downloaded and flashinfer will
# read them directly.
"VLLM_HAS_FLASHINFER_CUBIN":
@@ -1310,6 +1315,7 @@ def compute_hash() -> str:
"VLLM_USE_FLASHINFER_MOE_MXFP4_BF16",
"VLLM_USE_CUDNN_PREFILL",
"VLLM_USE_TRTLLM_ATTENTION",
"VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION",
"VLLM_ROCM_USE_AITER",
"VLLM_ROCM_USE_AITER_PAGED_ATTN",
"VLLM_ROCM_USE_AITER_LINEAR",

View File

@@ -200,11 +200,6 @@ def use_trtllm_attention(
logger.info_once("Using TRTLLM attention (query is quantized).")
return True
# TRTLLM prefill attention does not support FP8 kv cache with
# non-quantized query
if is_prefill and kv_cache_dtype.startswith("fp8"):
return False
# If sinks are being used, we must use TRTLLM attention as it's
# the only backend that supports them
if has_sinks:
@@ -353,6 +348,12 @@ def flashinfer_scaled_fp8_mm(
return output
@functools.cache
def flashinfer_disable_q_quantization() -> bool:
"""Cache result which only depends on the environment"""
return envs.VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION
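
Note that @functools.cache means the environment is consulted only once per process; a standalone sketch of that behavior (the helper name here is a stand-in, not vLLM code):

import functools
import os

@functools.cache
def _disable_q_quant() -> bool:  # stand-in for flashinfer_disable_q_quantization
    return bool(int(os.getenv("VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION", "0")))

os.environ["VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION"] = "1"
assert _disable_q_quant()
os.environ["VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION"] = "0"
assert _disable_q_quant()  # cached: later changes to the env var are ignored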
__all__ = [
"has_flashinfer",
"flashinfer_trtllm_fp8_block_scale_moe",

View File

@@ -25,7 +25,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import cdiv, is_pin_memory_available
from vllm.utils.flashinfer import (supports_trtllm_attention,
from vllm.utils.flashinfer import (flashinfer_disable_q_quantization,
supports_trtllm_attention,
use_trtllm_attention)
from vllm.v1.attention.backends.flash_attn import use_cascade_attention
# yapf conflicts with isort for this block
@@ -48,8 +49,89 @@ FP4_DTYPE = torch.uint8
logger = init_logger(__name__)
@triton.jit
def _trtllm_prefill_attn_kvfp8_dequant(
kv_cache_ptr,
block_tables_prefill_ptr,
block_table_stride,
mock_kv_cache_ptr,
k_scale_ptr,
v_scale_ptr,
K_CACHE_STRIDE: tl.constexpr,
KV_CACHE_STRIDE: tl.constexpr,
):
batch_idx = tl.program_id(0).to(tl.int64)
mock_block_table_idx = tl.program_id(1).to(tl.int64)
orig_page_num = tl.load(block_tables_prefill_ptr +
batch_idx * block_table_stride +
mock_block_table_idx).to(tl.int64)
# page ids <= 0 mark padded block-table entries; nothing to dequantize
if orig_page_num <= 0:
return
dequant_dtype = mock_kv_cache_ptr.dtype.element_ty
# Dequantize K
k_scale_val = tl.load(k_scale_ptr)
offset = orig_page_num * KV_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE)
fp8_vals = tl.load(kv_cache_ptr + offset)
dequantized_vals = fp8_vals.to(tl.float32) * k_scale_val
mock_cache_offset = (batch_idx * block_table_stride + mock_block_table_idx
+ 1) * KV_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE)
dequantized_vals = dequantized_vals.to(dequant_dtype)
tl.store(mock_kv_cache_ptr + mock_cache_offset, dequantized_vals)
# Dequantize V
v_scale_val = tl.load(v_scale_ptr)
offset = (orig_page_num * KV_CACHE_STRIDE + K_CACHE_STRIDE +
tl.arange(0, K_CACHE_STRIDE))
fp8_vals = tl.load(kv_cache_ptr + offset)
dequantized_vals = fp8_vals.to(tl.float32) * v_scale_val
mock_cache_offset = (
(batch_idx * block_table_stride + mock_block_table_idx + 1) *
KV_CACHE_STRIDE + K_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE))
dequantized_vals = dequantized_vals.to(dequant_dtype)
tl.store(mock_kv_cache_ptr + mock_cache_offset, dequantized_vals)
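
A small worked example of the offset arithmetic in the kernel above, with made-up sizes (num_kv_heads=8, block_size=16, head_dim=64):

H, B, D = 8, 16, 64                       # hypothetical cache dimensions
K_CACHE_STRIDE = H * B * D                # 8192 elements per K (or V) plane
KV_CACHE_STRIDE = 2 * K_CACHE_STRIDE      # 16384 elements per page (K plane + V plane)

orig_page_num = 5                         # source page in the real fp8 cache
src_k = range(orig_page_num * KV_CACHE_STRIDE,
              orig_page_num * KV_CACHE_STRIDE + K_CACHE_STRIDE)
src_v = range(src_k.stop, src_k.stop + K_CACHE_STRIDE)

batch_idx, mock_block_table_idx, block_table_stride = 1, 2, 4
dst_page = batch_idx * block_table_stride + mock_block_table_idx + 1  # == 7
# K lands at [dst_page * KV_CACHE_STRIDE, + K_CACHE_STRIDE), V right after it,
# matching mock_block_table[1, 2] == 7 as built by the wrapper below.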
def trtllm_prefill_attn_kvfp8_dequant(
kv_cache: torch.Tensor,
block_tables_prefill: torch.Tensor,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
dequant_dtype: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor]:
batch_size, num_of_page_per_token = block_tables_prefill.shape
s = kv_cache.shape
# expected layout: (num_pages, 2, ...), one K plane and one V plane per page
assert s[1] == 2
assert dequant_dtype in (torch.bfloat16, torch.float16)
k_cache_stride = s[2] * s[3] * s[4]  # elements in one K (or V) plane
kv_cache_stride = k_cache_stride * s[1]  # elements in one full page (K + V)
new_s = (batch_size * num_of_page_per_token + 1, s[1], s[2], s[3], s[4])
# mock kv cache contains just the pages needed by this prefill
mock_kv_cache = torch.empty(new_s,
dtype=dequant_dtype,
device=kv_cache.device)
# the mock block table indexes those pages sequentially, starting at page 1
mock_block_table = torch.arange(
start=1,
end=batch_size * num_of_page_per_token + 1,
dtype=torch.int32,
device=block_tables_prefill.device,
).reshape(batch_size, num_of_page_per_token)
grid = (batch_size, num_of_page_per_token)
_trtllm_prefill_attn_kvfp8_dequant[grid](
kv_cache,
block_tables_prefill,
num_of_page_per_token,
mock_kv_cache,
k_scale,
v_scale,
k_cache_stride,
kv_cache_stride,
)
return mock_kv_cache, mock_block_table
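
For reference, a plain-PyTorch sketch equivalent to the wrapper above (illustrative only; the real path runs the Triton kernel, and the ref_kvfp8_dequant name is made up):

import torch

def ref_kvfp8_dequant(kv_cache, block_tables_prefill, k_scale, v_scale,
                      dequant_dtype):
    # Same outputs as trtllm_prefill_attn_kvfp8_dequant, computed with plain
    # tensor indexing. Padded block-table entries (page id <= 0) are skipped,
    # so their destination pages stay uninitialized, as in the Triton kernel.
    batch_size, pages_per_req = block_tables_prefill.shape
    num_pages, two, *rest = kv_cache.shape
    assert two == 2
    mock_kv_cache = torch.empty((batch_size * pages_per_req + 1, 2, *rest),
                                dtype=dequant_dtype, device=kv_cache.device)
    mock_block_table = torch.arange(
        1, batch_size * pages_per_req + 1, dtype=torch.int32,
        device=block_tables_prefill.device).reshape(batch_size, pages_per_req)
    src = block_tables_prefill.reshape(-1).long()
    dst = torch.arange(1, src.numel() + 1, device=kv_cache.device)
    valid = src > 0
    src, dst = src[valid], dst[valid]
    mock_kv_cache[dst, 0] = (kv_cache[src, 0].float() * k_scale).to(dequant_dtype)
    mock_kv_cache[dst, 1] = (kv_cache[src, 1].float() * v_scale).to(dequant_dtype)
    return mock_kv_cache, mock_block_table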
class FlashInferBackend(AttentionBackend):
accept_output_buffer: bool = True
@classmethod
@@ -122,7 +204,6 @@ class FlashInferBackend(AttentionBackend):
@dataclass
class FlashInferMetadata:
num_actual_tokens: int # Number of tokens excluding padding.
# The data type of the query
@@ -175,8 +256,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self.kv_cache_spec.block_size)
max_num_reqs = vllm_config.scheduler_config.max_num_seqs
max_num_pages = max_num_reqs * max_num_pages_per_req
self.enable_cuda_graph = self.compilation_config.cudagraph_mode.\
decode_mode() == CUDAGraphMode.FULL
self.enable_cuda_graph = (self.compilation_config.cudagraph_mode.\
decode_mode() == CUDAGraphMode.FULL)
if self.enable_cuda_graph:
# For full cudagraph capture, one `decode_wrapper` for each batch
# size is needed for FlashInfer.
@@ -201,7 +282,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
assert self.kv_cache_spec.dtype == self.model_config.dtype
self.kv_cache_dtype = self.kv_cache_spec.dtype
if supports_trtllm_attention()[0]:
if supports_trtllm_attention()[0] and \
not flashinfer_disable_q_quantization():
self.q_data_type = self.kv_cache_dtype
else:
self.q_data_type = self.model_config.dtype
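
Condensed into a hypothetical helper, the effect of the new check on the query dtype (the helper and values are illustrative, not vLLM API):

import torch

def select_q_dtype(supports_trtllm: bool, disable_q_quant: bool,
                   kv_cache_dtype: torch.dtype,
                   model_dtype: torch.dtype) -> torch.dtype:
    if supports_trtllm and not disable_q_quant:
        return kv_cache_dtype  # quantize Q to match the fp8 KV cache
    return model_dtype         # keep Q in bf16/fp16

# With the flag set, Q stays in the model dtype even though the cache is fp8,
# which is exactly the combination the prefill dequant path below serves.
assert select_q_dtype(True, True, torch.float8_e4m3fn, torch.bfloat16) is torch.bfloat16
assert select_q_dtype(True, False, torch.float8_e4m3fn, torch.bfloat16) is torch.float8_e4m3fn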
@@ -795,11 +877,29 @@ class FlashInferImpl(AttentionImpl):
assert self.o_sf_scale is None
out = output[num_decode_tokens:]
if attn_metadata.q_data_type != FP8_DTYPE \
and self.kv_cache_dtype.startswith("fp8"):
# TRTLLM prefill attention does not support a non-fp8 (BF16/FP16)
# query together with an fp8 kv cache. To still run prefill with an
# fp8 kv cache, build a mock block table and a mock kv cache that
# hold the dequantized KV pages touched by this prefill.
mock_kv_cache, mock_block_table = (
trtllm_prefill_attn_kvfp8_dequant(
kv_cache_permute,
block_tables_prefill,
layer._k_scale,
layer._v_scale,
attn_metadata.q_data_type,
))
else:
mock_kv_cache = kv_cache_permute
mock_block_table = block_tables_prefill
trtllm_batch_context_with_kv_cache(
query=prefill_query,
kv_cache=kv_cache_permute,
kv_cache=mock_kv_cache,
workspace_buffer=workspace_buffer,
block_tables=block_tables_prefill,
block_tables=mock_block_table,
seq_lens=seq_lens_prefill,
max_q_len=attn_metadata.max_q_len,
max_kv_len=attn_metadata.max_seq_len,
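
In summary, the prefill call now routes through one of three KV sources; a hypothetical sketch (not vLLM API) of the decision implied by this hunk:

def prefill_kv_source(q_is_fp8: bool, kv_cache_is_fp8: bool) -> str:
    if kv_cache_is_fp8 and q_is_fp8:
        # query already quantized: TRTLLM consumes the fp8 cache directly
        return "fp8 cache, no dequant"
    if kv_cache_is_fp8 and not q_is_fp8:
        # new in this commit: only the touched pages are dequantized into a
        # bf16/fp16 mock cache via trtllm_prefill_attn_kvfp8_dequant
        return "mock dequantized cache"
    # non-fp8 cache: used as-is
    return "native-dtype cache"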
@@ -837,7 +937,7 @@ class FlashInferImpl(AttentionImpl):
decode_query = decode_query.contiguous()
workspace_buffer = decode_wrapper._float_workspace_buffer
block_tables_decode = attn_metadata.\
block_table_tensor[:num_decode_tokens]
block_table_tensor[:num_decode_tokens]
seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens]
# This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
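
A sketch of how one might exercise the new path end to end (illustrative; assumes FlashInfer is installed and an fp8-capable GPU is available, and the model name is only a placeholder):

import os
# HND KV-cache layout is required by the FlashInfer TRTLLM decode path noted above.
os.environ["VLLM_KV_CACHE_LAYOUT"] = "HND"
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"
# Optional: keep Q in the model dtype so prefill takes the new dequant fallback.
os.environ["VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION"] = "1"

from vllm import LLM
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
          kv_cache_dtype="fp8")
print(llm.generate(["Hello, my name is"])[0].outputs[0].text)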