Mirror of https://github.com/vllm-project/vllm.git
[Misc] Simplify FlashInfer attention metadata (#23585)
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
@@ -123,29 +123,9 @@ class FlashInferMetadata:

    num_actual_tokens: int  # Number of tokens excluding padding.

    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    qo_indptr_cpu: torch.Tensor
    # An example for paged_kv_indices, paged_kv_indptr:
    # request 1, page indices [0, 5, 8]
    # request 2, page indices [1, 6, 7]
    # request 3, page indices [3, 4]
    # paged_kv_indices is a concatenation of page indices of all requests:
    # [0, 5, 8, 1, 6, 7, 3, 4]
    # paged_kv_indptr is used to index into paged_kv_indices:
    # [0, 3, 6, 8]
    # The indptr of the paged kv cache, shape: [batch_size + 1] (CPU for plan)
    paged_kv_indptr_cpu: torch.Tensor
    # The page indices of the paged kv cache (on device for plan)
    paged_kv_indices: torch.Tensor
    # The number of entries in the last page of each request in
    # the paged kv cache, shape: [batch_size] (CPU for plan)
    paged_kv_last_page_len_cpu: torch.Tensor
    # The data type of the query
    q_data_type: torch.dtype

    seq_lens_cpu: torch.Tensor
    slot_mapping: torch.Tensor

    # For flashinfer trtllm batch decode
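The paged_kv_indices / paged_kv_indptr layout documented in the comments of the hunk above can be reproduced with a small standalone sketch. The page tables below are the hypothetical ones from the comment, not values taken from vLLM:

import torch

# Hypothetical per-request page tables, matching the comment above:
# request 1 -> pages [0, 5, 8], request 2 -> [1, 6, 7], request 3 -> [3, 4]
page_tables = [[0, 5, 8], [1, 6, 7], [3, 4]]

# paged_kv_indices is the concatenation of all page indices.
paged_kv_indices = torch.tensor([p for pages in page_tables for p in pages])
# -> tensor([0, 5, 8, 1, 6, 7, 3, 4])

# paged_kv_indptr holds cumulative page counts, so request i owns
# paged_kv_indices[paged_kv_indptr[i]:paged_kv_indptr[i + 1]].
page_counts = [len(pages) for pages in page_tables]
paged_kv_indptr = torch.cumsum(torch.tensor([0] + page_counts), dim=0)
# -> tensor([0, 3, 6, 8])

# Pages of request 2 (0-based index 1): tensor([1, 6, 7])
print(paged_kv_indices[paged_kv_indptr[1]:paged_kv_indptr[2]])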
@@ -164,10 +144,6 @@ class FlashInferMetadata:

    # For cascade attention (CPU for planning).
    use_cascade: bool
    shared_qo_indptr_cpu: Optional[torch.Tensor] = None
    shared_kv_page_indptr_cpu: Optional[torch.Tensor] = None
    shared_kv_page_indices_cpu: Optional[torch.Tensor] = None
    shared_kv_last_page_len_cpu: Optional[torch.Tensor] = None

    prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None
    decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
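For the cascade fields above: as the plan calls later in this diff show, cascade_wrapper.plan() receives each metadata array as a two-element list, shared-prefix entry first and per-request entries second (e.g. [shared_qo_indptr_cpu, qo_indptr_cpu]). A rough sketch of that pairing with made-up values, purely for illustration:

# Hypothetical values: a batch of 3 requests that all share one common
# prefix held in KV pages [0, 1]; each request then has its own suffix pages.
shared_kv_page_indptr_cpu = [0, 2]      # one shared "request" covering the prefix
shared_kv_page_indices_cpu = [0, 1]
shared_kv_last_page_len_cpu = [16]

paged_kv_indptr_cpu = [0, 1, 3, 4]      # per-request suffix pages
paged_kv_indices = [5, 6, 7, 8]
paged_kv_last_page_len_cpu = [4, 9, 2]

# The plan() calls in this diff pass [shared, per-request] pairs, e.g.:
kv_page_indptr_arg = [shared_kv_page_indptr_cpu, paged_kv_indptr_cpu]
kv_page_indices_arg = [shared_kv_page_indices_cpu, paged_kv_indices]
kv_last_page_len_arg = [shared_kv_last_page_len_cpu, paged_kv_last_page_len_cpu]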
@@ -327,134 +303,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                2, self._get_workspace_buffer(), get_kv_cache_layout())
        return self._cascade_wrapper

    def _plan(self, attn_metadata: FlashInferMetadata):
        if attn_metadata.use_cascade:
            attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
            attn_metadata.cascade_wrapper.plan(
                [
                    attn_metadata.shared_qo_indptr_cpu,
                    attn_metadata.qo_indptr_cpu
                ],
                [
                    attn_metadata.shared_kv_page_indptr_cpu,
                    attn_metadata.paged_kv_indptr_cpu
                ],
                [
                    attn_metadata.shared_kv_page_indices_cpu,
                    attn_metadata.paged_kv_indices
                ],
                [
                    attn_metadata.shared_kv_last_page_len_cpu,
                    attn_metadata.paged_kv_last_page_len_cpu
                ],
                self.num_qo_heads,
                self.num_kv_heads,
                self.head_dim,
                self.page_size,
                causal=True,
                sm_scale=self.global_hyperparameters.sm_scale,
                window_left=self.global_hyperparameters.window_left,
                logits_soft_cap=self.global_hyperparameters.logits_soft_cap,
                q_data_type=self.q_data_type,
                kv_data_type=self.kv_cache_dtype,
            )
        else:
            # Regular attention (common case).
            # Decodes are at the front and prefills are at the back,
            # according to reorder_batch()
            num_prefills = attn_metadata.num_prefills
            num_decodes = attn_metadata.num_decodes
            if num_prefills > 0:
                # Decodes are first so prefills start after the last decode
                prefill_start = num_decodes
                attn_metadata.prefill_wrapper = self._get_prefill_wrapper()
                assert attn_metadata.qo_indptr_cpu[prefill_start:].shape[
                    0] == num_prefills + 1
                assert attn_metadata.paged_kv_indptr_cpu[prefill_start:].shape[
                    0] == num_prefills + 1
                assert attn_metadata.paged_kv_last_page_len_cpu[
                    prefill_start:].shape[0] == num_prefills
                # Since prefill_wrapper.run() will be called with
                # query[num_decode_tokens:] we need to adjust the qo_indptr
                # to be relative to the start of the prefill queries.
                qo_indptr_cpu = attn_metadata.qo_indptr_cpu[
                    prefill_start:] - attn_metadata.qo_indptr_cpu[prefill_start]
                paged_kv_indptr_cpu = attn_metadata.paged_kv_indptr_cpu[
                    prefill_start:]
                if not attn_metadata.prefill_use_trtllm:
                    attn_metadata.prefill_wrapper.plan(
                        qo_indptr_cpu,
                        paged_kv_indptr_cpu,
                        attn_metadata.paged_kv_indices,
                        attn_metadata.
                        paged_kv_last_page_len_cpu[prefill_start:],
                        self.num_qo_heads,
                        self.num_kv_heads,
                        self.head_dim,
                        self.page_size,
                        causal=True,
                        sm_scale=self.global_hyperparameters.sm_scale,
                        window_left=self.global_hyperparameters.window_left,
                        logits_soft_cap=self.global_hyperparameters.
                        logits_soft_cap,
                        q_data_type=self.q_data_type,
                        kv_data_type=self.kv_cache_dtype,
                    )
                else:
                    attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to(self.device)
                    attn_metadata.paged_kv_indptr_gpu = paged_kv_indptr_cpu.to(
                        self.device)

            if num_decodes > 0:
                pure_decode = num_prefills == 0
                # possible required padding for cudagraph replay
                use_cudagraph = (self.enable_cuda_graph and pure_decode and
                                 num_decodes <= self._decode_cudagraph_max_bs)
                if use_cudagraph:
                    num_input_tokens = (
                        self.vllm_config.pad_for_cudagraph(num_decodes))
                    # Carefully fulfill the padding region with reasonable value
                    # on cpu.
                    # Make sure paged_kv_indptr_cpu is not decreasing
                    self.paged_kv_indptr_cpu[1 + num_decodes:1 +
                                             num_input_tokens].fill_(
                                                 attn_metadata.
                                                 paged_kv_indptr_cpu[-1])
                    # Fill the remaining paged_kv_last_page_len_cpu with 1.
                    # This is because flashinfer treats 0 as a full page
                    # instead of empty.
                    self.paged_kv_last_page_len_cpu[
                        num_decodes:num_input_tokens].fill_(1)

                else:
                    num_input_tokens = num_decodes

                attn_metadata.decode_wrapper = self._get_decode_wrapper(
                    num_input_tokens, use_cudagraph)
                if not attn_metadata.decode_use_trtllm:
                    # Use the persistent buffer with padding length,
                    # instead of the same address but chunked version
                    # in atten_metadata when using cudagraph.
                    fast_plan_decode(
                        attn_metadata.decode_wrapper,
                        self.paged_kv_indptr_cpu[:num_input_tokens + 1],
                        attn_metadata.paged_kv_indices,
                        self.paged_kv_last_page_len_cpu[:num_input_tokens],
                        attn_metadata.seq_lens_cpu[:num_input_tokens],
                        self.num_qo_heads,
                        self.num_kv_heads,
                        self.head_dim,
                        self.page_size,
                        # Disable flashinfer's pos encoding and use vllm's rope.
                        pos_encoding_mode="NONE",
                        sm_scale=self.global_hyperparameters.sm_scale,
                        window_left=self.global_hyperparameters.window_left,
                        logits_soft_cap=self.global_hyperparameters.
                        logits_soft_cap,
                        q_data_type=self.q_data_type,
                        kv_data_type=self.kv_cache_dtype,
                    )

    def build(self,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata,
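The qo_indptr rebasing commented in `_plan` above (and kept in `build()` below after this change) is easy to check numerically. A small sketch with made-up query lengths, assuming 3 decode requests followed by 2 prefills:

import torch

# Hypothetical query_start_loc: 3 decodes of 1 token each, then
# prefills of 4 and 6 tokens -> cumulative [0, 1, 2, 3, 7, 13].
qo_indptr_cpu = torch.tensor([0, 1, 2, 3, 7, 13])
num_decodes, num_prefills = 3, 2
prefill_start = num_decodes

# prefill_wrapper.run() only sees query[num_decode_tokens:], so the
# indptr must be rebased to start at 0 for the first prefill query.
prefill_qo_indptr = (qo_indptr_cpu[prefill_start:]
                     - qo_indptr_cpu[prefill_start])
print(prefill_qo_indptr)  # tensor([ 0,  4, 10]); length == num_prefills + 1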
@@ -548,13 +396,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):

        attn_metadata = FlashInferMetadata(
            num_actual_tokens=num_actual_tokens,
            qo_indptr_cpu=common_attn_metadata.query_start_loc_cpu,
            paged_kv_indptr_cpu=self.paged_kv_indptr_cpu[:1 + num_reqs],
            paged_kv_indices=paged_kv_indices,
            paged_kv_last_page_len_cpu=self.
            paged_kv_last_page_len_cpu[:num_reqs],
            q_data_type=self.q_data_type,
            seq_lens_cpu=seq_lens_cpu,
            slot_mapping=common_attn_metadata.slot_mapping,
            max_q_len=max_q_len,
            max_seq_len=max_seq_len,
@@ -567,14 +409,123 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            use_cascade=use_cascade,
            shared_qo_indptr_cpu=shared_qo_indptr_cpu,
            shared_kv_page_indptr_cpu=shared_kv_page_indptr_cpu,
            shared_kv_page_indices_cpu=shared_kv_page_indices_cpu,
            shared_kv_last_page_len_cpu=shared_kv_last_page_len_cpu,
        )

        self._plan(attn_metadata)
        qo_indptr_cpu = common_attn_metadata.query_start_loc_cpu
        paged_kv_indptr_cpu = self.paged_kv_indptr_cpu[:1 + num_reqs]
        paged_kv_last_page_len_cpu = self.paged_kv_last_page_len_cpu[:num_reqs]

        if attn_metadata.use_cascade:
            attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
            attn_metadata.cascade_wrapper.plan(
                [shared_qo_indptr_cpu, qo_indptr_cpu],
                [shared_kv_page_indptr_cpu, paged_kv_indptr_cpu],
                [shared_kv_page_indices_cpu, paged_kv_indices],
                [shared_kv_last_page_len_cpu, paged_kv_last_page_len_cpu],
                self.num_qo_heads,
                self.num_kv_heads,
                self.head_dim,
                self.page_size,
                causal=True,
                sm_scale=self.global_hyperparameters.sm_scale,
                window_left=self.global_hyperparameters.window_left,
                logits_soft_cap=self.global_hyperparameters.logits_soft_cap,
                q_data_type=self.q_data_type,
                kv_data_type=self.kv_cache_dtype,
            )
        else:
            # Regular attention (common case).
            # Decodes are at the front and prefills are at the back,
            # according to reorder_batch()
            num_prefills = attn_metadata.num_prefills
            num_decodes = attn_metadata.num_decodes
            if num_prefills > 0:
                # Decodes are first so prefills start after the last decode
                prefill_start = num_decodes
                attn_metadata.prefill_wrapper = self._get_prefill_wrapper()
                assert qo_indptr_cpu[prefill_start:].shape[
                    0] == num_prefills + 1
                assert paged_kv_indptr_cpu[prefill_start:].shape[
                    0] == num_prefills + 1
                assert paged_kv_last_page_len_cpu[prefill_start:].shape[
                    0] == num_prefills
                # Since prefill_wrapper.run() will be called with
                # query[num_decode_tokens:] we need to adjust the qo_indptr
                # to be relative to the start of the prefill queries.
                qo_indptr_cpu = qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[
                    prefill_start]
                paged_kv_indptr_cpu = paged_kv_indptr_cpu[prefill_start:]
                if not attn_metadata.prefill_use_trtllm:
                    attn_metadata.prefill_wrapper.plan(
                        qo_indptr_cpu,
                        paged_kv_indptr_cpu,
                        paged_kv_indices,
                        paged_kv_last_page_len_cpu[prefill_start:],
                        self.num_qo_heads,
                        self.num_kv_heads,
                        self.head_dim,
                        self.page_size,
                        causal=True,
                        sm_scale=self.global_hyperparameters.sm_scale,
                        window_left=self.global_hyperparameters.window_left,
                        logits_soft_cap=self.global_hyperparameters.
                        logits_soft_cap,
                        q_data_type=self.q_data_type,
                        kv_data_type=self.kv_cache_dtype,
                    )
                else:
                    attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to(self.device)
                    attn_metadata.paged_kv_indptr_gpu = paged_kv_indptr_cpu.to(
                        self.device)

            if num_decodes > 0:
                pure_decode = num_prefills == 0
                # possible required padding for cudagraph replay
                use_cudagraph = (self.enable_cuda_graph and pure_decode and
                                 num_decodes <= self._decode_cudagraph_max_bs)
                if use_cudagraph:
                    num_input_tokens = (
                        self.vllm_config.pad_for_cudagraph(num_decodes))
                    # Carefully fulfill the padding region with reasonable value
                    # on cpu.
                    # Make sure paged_kv_indptr_cpu is not decreasing
                    self.paged_kv_indptr_cpu[1 + num_decodes:1 +
                                             num_input_tokens].fill_(
                                                 paged_kv_indptr_cpu[-1])
                    # Fill the remaining paged_kv_last_page_len_cpu with 1.
                    # This is because flashinfer treats 0 as a full page
                    # instead of empty.
                    self.paged_kv_last_page_len_cpu[
                        num_decodes:num_input_tokens].fill_(1)

                else:
                    num_input_tokens = num_decodes

                attn_metadata.decode_wrapper = self._get_decode_wrapper(
                    num_input_tokens, use_cudagraph)
                if not attn_metadata.decode_use_trtllm:
                    # Use the persistent buffer with padding length,
                    # instead of the same address but chunked version
                    # in atten_metadata when using cudagraph.
                    fast_plan_decode(
                        attn_metadata.decode_wrapper,
                        self.paged_kv_indptr_cpu[:num_input_tokens + 1],
                        paged_kv_indices,
                        self.paged_kv_last_page_len_cpu[:num_input_tokens],
                        seq_lens_cpu[:num_input_tokens],
                        self.num_qo_heads,
                        self.num_kv_heads,
                        self.head_dim,
                        self.page_size,
                        # Disable flashinfer's pos encoding and use vllm's rope.
                        pos_encoding_mode="NONE",
                        sm_scale=self.global_hyperparameters.sm_scale,
                        window_left=self.global_hyperparameters.window_left,
                        logits_soft_cap=self.global_hyperparameters.
                        logits_soft_cap,
                        q_data_type=self.q_data_type,
                        kv_data_type=self.kv_cache_dtype,
                    )
        return attn_metadata
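The cudagraph padding logic in the decode branch above can also be checked in isolation. A minimal sketch with made-up buffer sizes: padded indptr slots repeat the last real value so they describe zero extra pages, and padded last-page lengths are set to 1 because FlashInfer treats 0 as a full page rather than an empty one.

import torch

# Hypothetical persistent CPU buffers for a max cudagraph batch of 8.
max_bs = 8
paged_kv_indptr_cpu = torch.zeros(max_bs + 1, dtype=torch.int32)
paged_kv_last_page_len_cpu = torch.ones(max_bs, dtype=torch.int32)

# 3 real decode requests owning 2, 1, and 3 KV pages; the cudagraph
# bucket pads the batch up to 6 entries.
num_decodes, num_input_tokens = 3, 6
paged_kv_indptr_cpu[1:1 + num_decodes] = torch.tensor([2, 3, 6], dtype=torch.int32)
paged_kv_last_page_len_cpu[:num_decodes] = torch.tensor([5, 16, 9], dtype=torch.int32)

# Pad: repeat the last real indptr value (non-decreasing, zero pages added)
# and use 1 for the padded last-page lengths (0 would mean "full page").
paged_kv_indptr_cpu[1 + num_decodes:1 + num_input_tokens].fill_(
    paged_kv_indptr_cpu[num_decodes])
paged_kv_last_page_len_cpu[num_decodes:num_input_tokens].fill_(1)

print(paged_kv_indptr_cpu[:1 + num_input_tokens])
# tensor([0, 2, 3, 6, 6, 6, 6], dtype=torch.int32)
print(paged_kv_last_page_len_cpu[:num_input_tokens])
# tensor([ 5, 16,  9,  1,  1,  1], dtype=torch.int32)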
    def build_for_cudagraph_capture(