[Misc] Use CpuGpuBuffer for FlashInfer metadata builder

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon
2025-08-27 03:18:30 -07:00
parent 6578e87365
commit 0dba2a36a9

View File

@ -39,6 +39,7 @@ from vllm.v1.attention.backends.utils import (AttentionCGSupport,
split_decodes_and_prefills)
# yapf: enable
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.utils import CpuGpuBuffer
FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
@ -215,34 +216,20 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self.global_hyperparameters = infer_global_hyperparameters(
get_per_layer_parameters(vllm_config, layer_names, FlashInferImpl))
# Preparing persistent buffers (device-side)
self.paged_kv_indptr = torch.zeros(max_num_reqs + 1,
dtype=torch.int32,
device=self.device)
self.paged_kv_indices = torch.zeros(
max_num_pages, # max num pages possible
dtype=torch.int32,
device=self.device)
self.paged_kv_last_page_len = torch.zeros(max_num_reqs,
dtype=torch.int32,
device=self.device)
# host-side buffer
# Preparing persistent buffers
pin_memory = is_pin_memory_available()
self.paged_kv_indptr_cpu = torch.zeros(max_num_reqs + 1,
dtype=torch.int32,
device="cpu",
pin_memory=pin_memory)
self.paged_kv_indptr_np = self.paged_kv_indptr_cpu.numpy()
self.paged_kv_indices_cpu = torch.zeros(max_num_pages,
dtype=torch.int32,
device="cpu",
pin_memory=pin_memory)
self.paged_kv_last_page_len_cpu = torch.zeros(max_num_reqs,
dtype=torch.int32,
device="cpu",
pin_memory=pin_memory)
self.paged_kv_last_page_len_np = (
self.paged_kv_last_page_len_cpu.numpy())
self.paged_kv_indptr = CpuGpuBuffer(max_num_reqs + 1,
dtype=torch.int32,
device=self.device,
pin_memory=pin_memory)
self.paged_kv_indices = CpuGpuBuffer(max_num_pages,
dtype=torch.int32,
device=self.device,
pin_memory=pin_memory)
self.paged_kv_last_page_len = CpuGpuBuffer(max_num_reqs,
dtype=torch.int32,
device=self.device,
pin_memory=pin_memory)
def _get_workspace_buffer(self):
if self._workspace_buffer is None:
@ -269,10 +256,10 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
if decode_wrapper is None:
if use_cudagraph:
paged_kv_indptr = self.paged_kv_indptr[:batch_size + 1]
paged_kv_indices = self.paged_kv_indices
paged_kv_last_page_len = self.paged_kv_last_page_len[:
batch_size]
paged_kv_indptr = self.paged_kv_indptr.gpu[:batch_size + 1]
paged_kv_indices = self.paged_kv_indices.gpu
paged_kv_last_page_len = (
self.paged_kv_last_page_len.gpu[:batch_size])
else:
paged_kv_indptr = None
paged_kv_indices = None
@ -355,15 +342,13 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
np.cumsum(
num_blocks_np,
dtype=np.int32,
out=self.paged_kv_indptr_np[1:num_reqs + 1],
out=self.paged_kv_indptr.np[1:num_reqs + 1],
)
paged_kv_indptr = self.paged_kv_indptr[:num_reqs + 1]
paged_kv_indptr.copy_(self.paged_kv_indptr_cpu[:num_reqs + 1],
non_blocking=True)
paged_kv_indptr = self.paged_kv_indptr.copy_to_gpu(num_reqs + 1)
# write self.paged_kv_indices inplace
num_actual_pages = num_blocks_np.sum().item()
paged_kv_indices = self.paged_kv_indices[:num_actual_pages]
paged_kv_indices = self.paged_kv_indices.gpu[:num_actual_pages]
_copy_page_indices_kernel[(num_reqs, )](
paged_kv_indices,
block_table_tensor,
@ -374,7 +359,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# write self.paged_kv_last_page_len_cpu inplace
paged_kv_last_page_len_np = seq_lens_np % page_size
self.paged_kv_last_page_len_np[:num_reqs] = np.where(
self.paged_kv_last_page_len.np[:num_reqs] = np.where(
paged_kv_last_page_len_np == 0,
page_size,
paged_kv_last_page_len_np,
@ -418,8 +403,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
)
qo_indptr_cpu = common_attn_metadata.query_start_loc_cpu
paged_kv_indptr_cpu = self.paged_kv_indptr_cpu[:1 + num_reqs]
paged_kv_last_page_len_cpu = self.paged_kv_last_page_len_cpu[:num_reqs]
paged_kv_indptr_cpu = self.paged_kv_indptr.cpu[:1 + num_reqs]
paged_kv_last_page_len_cpu = self.paged_kv_last_page_len.cpu[:num_reqs]
if attn_metadata.use_cascade:
attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
@ -495,14 +480,14 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# Carefully fulfill the padding region with reasonable value
# on cpu.
# Make sure paged_kv_indptr_cpu is not decreasing
self.paged_kv_indptr_cpu[1 + num_decodes:1 +
num_input_tokens].fill_(
paged_kv_indptr_cpu[-1])
self.paged_kv_indptr.np[1 + num_decodes:1 +
num_input_tokens].fill(
paged_kv_indptr_cpu[-1])
# Fill the remaining paged_kv_last_page_len_cpu with 1.
# This is because flashinfer treats 0 as a full page
# instead of empty.
self.paged_kv_last_page_len_cpu[
num_decodes:num_input_tokens].fill_(1)
self.paged_kv_last_page_len.np[
num_decodes:num_input_tokens].fill(1)
else:
num_input_tokens = num_decodes
@ -515,9 +500,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# in atten_metadata when using cudagraph.
fast_plan_decode(
attn_metadata.decode_wrapper,
self.paged_kv_indptr_cpu[:num_input_tokens + 1],
self.paged_kv_indptr.cpu[:num_input_tokens + 1],
paged_kv_indices,
self.paged_kv_last_page_len_cpu[:num_input_tokens],
self.paged_kv_last_page_len.cpu[:num_input_tokens],
seq_lens_cpu[:num_input_tokens],
self.num_qo_heads,
self.num_kv_heads,