[Misc] Simplify FlashInfer attention metadata (#23585)

Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
Author: Woosuk Kwon
Date: 2025-08-25 15:42:29 -07:00
Committed by: GitHub
Parent: 7b6a837275
Commit: efc88cf64a


@@ -123,29 +123,9 @@ class FlashInferMetadata:
num_actual_tokens: int # Number of tokens excluding padding.
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
qo_indptr_cpu: torch.Tensor
# An example for paged_kv_indices, paged_kv_indptr:
# request 1, page indices [0, 5, 8]
# request 2, page indices [1, 6, 7]
# request 3, page indices [3, 4]
# paged_kv_indices is a concatenation of page indices of all requests:
# [0, 5, 8, 1, 6, 7, 3, 4]
# paged_kv_indptr is used to index into paged_kv_indices:
# [0, 3, 6, 8]
# The indptr of the paged kv cache, shape: [batch_size + 1] (CPU for plan)
paged_kv_indptr_cpu: torch.Tensor
# The page indices of the paged kv cache (on device for plan)
paged_kv_indices: torch.Tensor
# The number of entries in the last page of each request in
# the paged kv cache, shape: [batch_size] (CPU for plan)
paged_kv_last_page_len_cpu: torch.Tensor
# The data type of the query
q_data_type: torch.dtype
seq_lens_cpu: torch.Tensor
slot_mapping: torch.Tensor
# For flashinfer trtllm batch decode
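The comments above describe the CSR-style layout of these buffers with concrete numbers. As a minimal, self-contained sketch (the values and local names below are illustrative, not the builder's actual helpers), the same layout can be reproduced like this:

import torch

# Illustrative per-request data, taken from the comments above:
# query lengths [4, 6] and three requests' KV page tables.
query_lens = [4, 6]
page_indices_per_req = [[0, 5, 8], [1, 6, 7], [3, 4]]

# qo_indptr: exclusive prefix sum of query lengths -> [0, 4, 10].
qo_indptr = torch.zeros(len(query_lens) + 1, dtype=torch.int32)
qo_indptr[1:] = torch.cumsum(torch.tensor(query_lens, dtype=torch.int32), dim=0)

# paged_kv_indices: concatenation of all requests' page indices
# -> [0, 5, 8, 1, 6, 7, 3, 4].
paged_kv_indices = torch.tensor(
    [p for pages in page_indices_per_req for p in pages], dtype=torch.int32)

# paged_kv_indptr: exclusive prefix sum of per-request page counts
# -> [0, 3, 6, 8]; request i owns indices[indptr[i]:indptr[i + 1]].
page_counts = [len(pages) for pages in page_indices_per_req]
paged_kv_indptr = torch.zeros(len(page_counts) + 1, dtype=torch.int32)
paged_kv_indptr[1:] = torch.cumsum(
    torch.tensor(page_counts, dtype=torch.int32), dim=0)

print(qo_indptr.tolist())         # [0, 4, 10]
print(paged_kv_indices.tolist())  # [0, 5, 8, 1, 6, 7, 3, 4]
print(paged_kv_indptr.tolist())   # [0, 3, 6, 8]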
@@ -164,10 +144,6 @@ class FlashInferMetadata:
# For cascade attention (CPU for planning).
use_cascade: bool
shared_qo_indptr_cpu: Optional[torch.Tensor] = None
shared_kv_page_indptr_cpu: Optional[torch.Tensor] = None
shared_kv_page_indices_cpu: Optional[torch.Tensor] = None
shared_kv_last_page_len_cpu: Optional[torch.Tensor] = None
prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None
decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
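For cascade attention, the wrapper is planned with two levels of these buffers: one for the KV prefix shared by every request and one for each request's unique suffix, passed as [shared_x, x] pairs as in the plan() calls further down. A hedged illustration with made-up values (two requests with 5 query tokens each, a shared prefix filling pages 0 and 1, page size 16; the concrete numbers are assumptions, not taken from the code):

import torch

# Shared (prefix) level: all 10 query tokens attend to the same prefix pages.
shared_qo_indptr_cpu = torch.tensor([0, 10], dtype=torch.int32)
shared_kv_page_indptr_cpu = torch.tensor([0, 2], dtype=torch.int32)
shared_kv_page_indices_cpu = torch.tensor([0, 1], dtype=torch.int32)
shared_kv_last_page_len_cpu = torch.tensor([16], dtype=torch.int32)  # full page

# Unique (per-request) level: the same CSR layout as the regular path.
qo_indptr_cpu = torch.tensor([0, 5, 10], dtype=torch.int32)
paged_kv_indptr_cpu = torch.tensor([0, 2, 3], dtype=torch.int32)
paged_kv_indices = torch.tensor([4, 5, 7], dtype=torch.int32)
paged_kv_last_page_len_cpu = torch.tensor([3, 9], dtype=torch.int32)

# cascade_wrapper.plan() then receives these as two-element lists, e.g.
# [shared_qo_indptr_cpu, qo_indptr_cpu], [shared_kv_page_indptr_cpu,
# paged_kv_indptr_cpu], and so on, exactly as shown in the diff below.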
@@ -327,134 +303,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
2, self._get_workspace_buffer(), get_kv_cache_layout())
return self._cascade_wrapper
def _plan(self, attn_metadata: FlashInferMetadata):
if attn_metadata.use_cascade:
attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
attn_metadata.cascade_wrapper.plan(
[
attn_metadata.shared_qo_indptr_cpu,
attn_metadata.qo_indptr_cpu
],
[
attn_metadata.shared_kv_page_indptr_cpu,
attn_metadata.paged_kv_indptr_cpu
],
[
attn_metadata.shared_kv_page_indices_cpu,
attn_metadata.paged_kv_indices
],
[
attn_metadata.shared_kv_last_page_len_cpu,
attn_metadata.paged_kv_last_page_len_cpu
],
self.num_qo_heads,
self.num_kv_heads,
self.head_dim,
self.page_size,
causal=True,
sm_scale=self.global_hyperparameters.sm_scale,
window_left=self.global_hyperparameters.window_left,
logits_soft_cap=self.global_hyperparameters.logits_soft_cap,
q_data_type=self.q_data_type,
kv_data_type=self.kv_cache_dtype,
)
else:
# Regular attention (common case).
# Decodes are at the front and prefills are at the back,
# according to reorder_batch()
num_prefills = attn_metadata.num_prefills
num_decodes = attn_metadata.num_decodes
if num_prefills > 0:
# Decodes are first so prefills start after the last decode
prefill_start = num_decodes
attn_metadata.prefill_wrapper = self._get_prefill_wrapper()
assert attn_metadata.qo_indptr_cpu[prefill_start:].shape[
0] == num_prefills + 1
assert attn_metadata.paged_kv_indptr_cpu[prefill_start:].shape[
0] == num_prefills + 1
assert attn_metadata.paged_kv_last_page_len_cpu[
prefill_start:].shape[0] == num_prefills
# Since prefill_wrapper.run() will be called with
# query[num_decode_tokens:] we need to adjust the qo_indptr
# to be relative to the start of the prefill queries.
qo_indptr_cpu = attn_metadata.qo_indptr_cpu[
prefill_start:] - attn_metadata.qo_indptr_cpu[prefill_start]
paged_kv_indptr_cpu = attn_metadata.paged_kv_indptr_cpu[
prefill_start:]
if not attn_metadata.prefill_use_trtllm:
attn_metadata.prefill_wrapper.plan(
qo_indptr_cpu,
paged_kv_indptr_cpu,
attn_metadata.paged_kv_indices,
attn_metadata.
paged_kv_last_page_len_cpu[prefill_start:],
self.num_qo_heads,
self.num_kv_heads,
self.head_dim,
self.page_size,
causal=True,
sm_scale=self.global_hyperparameters.sm_scale,
window_left=self.global_hyperparameters.window_left,
logits_soft_cap=self.global_hyperparameters.
logits_soft_cap,
q_data_type=self.q_data_type,
kv_data_type=self.kv_cache_dtype,
)
else:
attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to(self.device)
attn_metadata.paged_kv_indptr_gpu = paged_kv_indptr_cpu.to(
self.device)
if num_decodes > 0:
pure_decode = num_prefills == 0
# Padding may be required for cudagraph replay.
use_cudagraph = (self.enable_cuda_graph and pure_decode and
num_decodes <= self._decode_cudagraph_max_bs)
if use_cudagraph:
num_input_tokens = (
self.vllm_config.pad_for_cudagraph(num_decodes))
# Carefully fill the padding region with reasonable values
# on CPU.
# Make sure paged_kv_indptr_cpu is not decreasing
self.paged_kv_indptr_cpu[1 + num_decodes:1 +
num_input_tokens].fill_(
attn_metadata.
paged_kv_indptr_cpu[-1])
# Fill the remaining paged_kv_last_page_len_cpu with 1.
# This is because flashinfer treats 0 as a full page
# instead of an empty one.
self.paged_kv_last_page_len_cpu[
num_decodes:num_input_tokens].fill_(1)
else:
num_input_tokens = num_decodes
attn_metadata.decode_wrapper = self._get_decode_wrapper(
num_input_tokens, use_cudagraph)
if not attn_metadata.decode_use_trtllm:
# Use the persistent buffer with the padded length, rather than
# the sliced view in attn_metadata that shares the same storage,
# when using cudagraph.
fast_plan_decode(
attn_metadata.decode_wrapper,
self.paged_kv_indptr_cpu[:num_input_tokens + 1],
attn_metadata.paged_kv_indices,
self.paged_kv_last_page_len_cpu[:num_input_tokens],
attn_metadata.seq_lens_cpu[:num_input_tokens],
self.num_qo_heads,
self.num_kv_heads,
self.head_dim,
self.page_size,
# Disable flashinfer's pos encoding and use vllm's rope.
pos_encoding_mode="NONE",
sm_scale=self.global_hyperparameters.sm_scale,
window_left=self.global_hyperparameters.window_left,
logits_soft_cap=self.global_hyperparameters.
logits_soft_cap,
q_data_type=self.q_data_type,
kv_data_type=self.kv_cache_dtype,
)
def build(self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
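Because decodes are ordered before prefills, the prefill slice of qo_indptr must be rebased so it is relative to query[num_decode_tokens:], which is what prefill_wrapper.run() receives. A small worked example with illustrative lengths (the assertion mirrors the shape checks above):

import torch

# Illustrative batch after reorder_batch(): 2 decode requests (1 token
# each) followed by 2 prefill requests (5 and 7 tokens).
num_decodes, num_prefills = 2, 2
qo_indptr_cpu = torch.tensor([0, 1, 2, 7, 14], dtype=torch.int32)

prefill_start = num_decodes
# The prefill slice has num_prefills + 1 entries but is still offset by
# the decode tokens: [2, 7, 14].
prefill_slice = qo_indptr_cpu[prefill_start:]
assert prefill_slice.shape[0] == num_prefills + 1

# Rebase to the start of the prefill queries: [0, 5, 12].
prefill_qo_indptr = prefill_slice - qo_indptr_cpu[prefill_start]
print(prefill_qo_indptr.tolist())  # [0, 5, 12]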
@@ -548,13 +396,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
attn_metadata = FlashInferMetadata(
num_actual_tokens=num_actual_tokens,
qo_indptr_cpu=common_attn_metadata.query_start_loc_cpu,
paged_kv_indptr_cpu=self.paged_kv_indptr_cpu[:1 + num_reqs],
paged_kv_indices=paged_kv_indices,
paged_kv_last_page_len_cpu=self.
paged_kv_last_page_len_cpu[:num_reqs],
q_data_type=self.q_data_type,
seq_lens_cpu=seq_lens_cpu,
slot_mapping=common_attn_metadata.slot_mapping,
max_q_len=max_q_len,
max_seq_len=max_seq_len,
@@ -567,14 +409,123 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
use_cascade=use_cascade,
shared_qo_indptr_cpu=shared_qo_indptr_cpu,
shared_kv_page_indptr_cpu=shared_kv_page_indptr_cpu,
shared_kv_page_indices_cpu=shared_kv_page_indices_cpu,
shared_kv_last_page_len_cpu=shared_kv_last_page_len_cpu,
)
self._plan(attn_metadata)
qo_indptr_cpu = common_attn_metadata.query_start_loc_cpu
paged_kv_indptr_cpu = self.paged_kv_indptr_cpu[:1 + num_reqs]
paged_kv_last_page_len_cpu = self.paged_kv_last_page_len_cpu[:num_reqs]
if attn_metadata.use_cascade:
attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
attn_metadata.cascade_wrapper.plan(
[shared_qo_indptr_cpu, qo_indptr_cpu],
[shared_kv_page_indptr_cpu, paged_kv_indptr_cpu],
[shared_kv_page_indices_cpu, paged_kv_indices],
[shared_kv_last_page_len_cpu, paged_kv_last_page_len_cpu],
self.num_qo_heads,
self.num_kv_heads,
self.head_dim,
self.page_size,
causal=True,
sm_scale=self.global_hyperparameters.sm_scale,
window_left=self.global_hyperparameters.window_left,
logits_soft_cap=self.global_hyperparameters.logits_soft_cap,
q_data_type=self.q_data_type,
kv_data_type=self.kv_cache_dtype,
)
else:
# Regular attention (common case).
# Decodes are at the front and prefills are at the back,
# according to reorder_batch()
num_prefills = attn_metadata.num_prefills
num_decodes = attn_metadata.num_decodes
if num_prefills > 0:
# Decodes are first so prefills start after the last decode
prefill_start = num_decodes
attn_metadata.prefill_wrapper = self._get_prefill_wrapper()
assert qo_indptr_cpu[prefill_start:].shape[
0] == num_prefills + 1
assert paged_kv_indptr_cpu[prefill_start:].shape[
0] == num_prefills + 1
assert paged_kv_last_page_len_cpu[prefill_start:].shape[
0] == num_prefills
# Since prefill_wrapper.run() will be called with
# query[num_decode_tokens:] we need to adjust the qo_indptr
# to be relative to the start of the prefill queries.
qo_indptr_cpu = qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[
prefill_start]
paged_kv_indptr_cpu = paged_kv_indptr_cpu[prefill_start:]
if not attn_metadata.prefill_use_trtllm:
attn_metadata.prefill_wrapper.plan(
qo_indptr_cpu,
paged_kv_indptr_cpu,
paged_kv_indices,
paged_kv_last_page_len_cpu[prefill_start:],
self.num_qo_heads,
self.num_kv_heads,
self.head_dim,
self.page_size,
causal=True,
sm_scale=self.global_hyperparameters.sm_scale,
window_left=self.global_hyperparameters.window_left,
logits_soft_cap=self.global_hyperparameters.
logits_soft_cap,
q_data_type=self.q_data_type,
kv_data_type=self.kv_cache_dtype,
)
else:
attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to(self.device)
attn_metadata.paged_kv_indptr_gpu = paged_kv_indptr_cpu.to(
self.device)
if num_decodes > 0:
pure_decode = num_prefills == 0
# Padding may be required for cudagraph replay.
use_cudagraph = (self.enable_cuda_graph and pure_decode and
num_decodes <= self._decode_cudagraph_max_bs)
if use_cudagraph:
num_input_tokens = (
self.vllm_config.pad_for_cudagraph(num_decodes))
# Carefully fill the padding region with reasonable values
# on CPU.
# Make sure paged_kv_indptr_cpu is not decreasing
self.paged_kv_indptr_cpu[1 + num_decodes:1 +
num_input_tokens].fill_(
paged_kv_indptr_cpu[-1])
# Fill the remaining paged_kv_last_page_len_cpu with 1.
# This is because flashinfer treats 0 as a full page
# instead of an empty one.
self.paged_kv_last_page_len_cpu[
num_decodes:num_input_tokens].fill_(1)
else:
num_input_tokens = num_decodes
attn_metadata.decode_wrapper = self._get_decode_wrapper(
num_input_tokens, use_cudagraph)
if not attn_metadata.decode_use_trtllm:
# Use the persistent buffer with the padded length, rather than
# the sliced view that shares the same storage, when using
# cudagraph.
fast_plan_decode(
attn_metadata.decode_wrapper,
self.paged_kv_indptr_cpu[:num_input_tokens + 1],
paged_kv_indices,
self.paged_kv_last_page_len_cpu[:num_input_tokens],
seq_lens_cpu[:num_input_tokens],
self.num_qo_heads,
self.num_kv_heads,
self.head_dim,
self.page_size,
# Disable flashinfer's pos encoding and use vllm's rope.
pos_encoding_mode="NONE",
sm_scale=self.global_hyperparameters.sm_scale,
window_left=self.global_hyperparameters.window_left,
logits_soft_cap=self.global_hyperparameters.
logits_soft_cap,
q_data_type=self.q_data_type,
kv_data_type=self.kv_cache_dtype,
)
return attn_metadata
def build_for_cudagraph_capture(
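For a pure-decode batch padded for cudagraph replay, the padded tail of the persistent CPU buffers must stay well-formed: paged_kv_indptr has to remain non-decreasing, and paged_kv_last_page_len must not contain 0, since FlashInfer treats 0 as a full page. A standalone sketch of that padding with illustrative values (plain tensors here stand in for the builder's persistent buffers):

import torch

# Illustrative pure-decode batch: 3 real requests padded to 4 slots,
# e.g. pad_for_cudagraph(3) -> 4.
num_decodes, num_input_tokens = 3, 4
paged_kv_indptr_cpu = torch.tensor([0, 3, 6, 8, 0], dtype=torch.int32)
paged_kv_last_page_len_cpu = torch.tensor([2, 1, 3, 0], dtype=torch.int32)

# Keep the indptr non-decreasing by repeating its last real value, so the
# padded request owns zero pages: [0, 3, 6, 8] -> [0, 3, 6, 8, 8].
paged_kv_indptr_cpu[1 + num_decodes:1 + num_input_tokens].fill_(
    paged_kv_indptr_cpu[num_decodes])
# Use 1 (not 0) for the padded last-page lengths: [2, 1, 3] -> [2, 1, 3, 1].
paged_kv_last_page_len_cpu[num_decodes:num_input_tokens].fill_(1)

print(paged_kv_indptr_cpu.tolist())         # [0, 3, 6, 8, 8]
print(paged_kv_last_page_len_cpu.tolist())  # [2, 1, 3, 1]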