[Refactor] Refactor persistent buffers with CpuGpuBuffer (#23515)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Author: Woosuk Kwon
Date: 2025-08-25 08:44:48 -07:00
Committed by: GitHub
Parent: a9082a4d14
Commit: 0ff902f3b4
2 changed files with 99 additions and 103 deletions
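
In short, each persistent buffer that previously existed as three separate objects (a GPU tensor, a pinned-CPU staging tensor, and its cached numpy view) is now a single CpuGpuBuffer with .gpu, .cpu, and .np attributes, created via the new _make_buffer helper. Call sites change as sketched below for positions, with n standing in for the relevant element count; input_ids, query_start_loc, seq_lens, and mrope_positions follow the same pattern (see the first file's diff):

    self.positions_np[:n]   ->  self.positions.np[:n]
    self.positions_cpu[:n]  ->  self.positions.cpu[:n]
    self.positions[:n]      ->  self.positions.gpu[:n]
    self.positions[:n].copy_(self.positions_cpu[:n], non_blocking=True)  ->  self.positions.copy_to_gpu(n)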

vllm/v1/worker/gpu_model_runner.py

@@ -83,8 +83,9 @@ from vllm.v1.worker.kv_connector_model_runner_mixin import (
KVConnectorModelRunnerMixin, KVConnectorOutput)
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
-from .utils import (AttentionGroup, MultiModalBudget, bind_kv_cache,
-gather_mm_placeholders, initialize_kv_cache_for_kv_sharing,
+from .utils import (AttentionGroup, CpuGpuBuffer, MultiModalBudget,
+bind_kv_cache, gather_mm_placeholders,
+initialize_kv_cache_for_kv_sharing,
sanity_check_mm_encoder_outputs, scatter_mm_placeholders)
if TYPE_CHECKING:
@@ -149,6 +150,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
parallel_config)
self.hidden_size = model_config.get_hidden_size()
self.attention_chunk_size = model_config.attention_chunk_size
+# Only relevant for models using ALiBi (e.g, MPT)
+self.use_alibi = check_use_alibi(model_config)
self.cascade_attn_enabled = not self.model_config.disable_cascade_attn
@@ -242,21 +245,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self._init_device_properties()
# Persistent buffers for CUDA graphs.
-self.input_ids = torch.zeros(self.max_num_tokens,
-dtype=torch.int32,
-device=self.device)
-self.positions = torch.zeros(self.max_num_tokens,
-dtype=torch.int64,
-device=self.device)
-self.query_start_loc = torch.zeros(self.max_num_reqs + 1,
-dtype=torch.int32,
-device=self.device)
-self.seq_lens = torch.zeros(self.max_num_reqs,
-dtype=torch.int32,
-device=self.device)
-# None in the first PP rank. The rest are set after load_model.
-self.intermediate_tensors: Optional[IntermediateTensors] = None
+self.input_ids = self._make_buffer(self.max_num_tokens,
+dtype=torch.int32)
+self.positions = self._make_buffer(self.max_num_tokens,
+dtype=torch.int64)
+self.query_start_loc = self._make_buffer(self.max_num_reqs + 1,
+dtype=torch.int32)
+self.seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
+self.inputs_embeds = torch.zeros(
+(self.max_num_tokens, self.hidden_size),
+dtype=self.dtype,
+device=self.device)
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
if self.uses_mrope:
@@ -270,23 +269,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# identical position IDs, making M-RoPE functionally equivalent to
# 1D-RoPE.
# See page 5 of https://arxiv.org/abs/2409.12191
-self.mrope_positions = torch.zeros((3, self.max_num_tokens + 1),
-dtype=torch.int64,
-device=self.device)
-self.mrope_positions_cpu = torch.zeros(
-(3, self.max_num_tokens + 1),
-dtype=torch.int64,
-device="cpu",
-pin_memory=self.pin_memory)
-self.mrope_positions_np = self.mrope_positions_cpu.numpy()
+self.mrope_positions = self._make_buffer(
+(3, self.max_num_tokens + 1), dtype=torch.int64)
-# Only relevant for models using ALiBi (e.g, MPT)
-self.use_alibi = check_use_alibi(model_config)
-self.inputs_embeds = torch.zeros(
-(self.max_num_tokens, self.hidden_size),
-dtype=self.dtype,
-device=self.device)
+# None in the first PP rank. The rest are set after load_model.
+self.intermediate_tensors: Optional[IntermediateTensors] = None
# OPTIMIZATION: Cache the tensors rather than creating them every step.
# Keep in int64 to avoid overflow with long context
@@ -294,28 +281,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.max_model_len,
self.max_num_tokens),
dtype=np.int64)
-# NOTE(woosuk): These tensors are "stateless", i.e., they are literally
-# a faster version of creating a new tensor every time. Thus, we should
-# not make any assumptions about the values in these tensors.
-self.input_ids_cpu = torch.zeros(self.max_num_tokens,
-dtype=torch.int32,
-device="cpu",
-pin_memory=self.pin_memory)
-self.positions_cpu = torch.zeros(self.max_num_tokens,
-dtype=torch.int64,
-device="cpu",
-pin_memory=self.pin_memory)
-self.positions_np = self.positions_cpu.numpy()
-self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
-dtype=torch.int32,
-device="cpu",
-pin_memory=self.pin_memory)
-self.query_start_loc_np = self.query_start_loc_cpu.numpy()
-self.seq_lens_cpu = torch.zeros(self.max_num_reqs,
-dtype=torch.int32,
-device="cpu",
-pin_memory=self.pin_memory)
-self.seq_lens_np = self.seq_lens_cpu.numpy()
# Layer pairings for cross-layer KV sharing.
# If an Attention layer `layer_name` is in the keys of this dict, it
@@ -352,6 +317,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self._draft_token_ids: Optional[Union[list[list[int]],
torch.Tensor]] = None
+    def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer:
+        return CpuGpuBuffer(*args,
+                            dtype=dtype,
+                            device=self.device,
+                            pin_memory=self.pin_memory)
def _init_model_kwargs(self, num_tokens: int):
model_kwargs = dict[str, Any]()
num_reqs = self.input_batch.num_reqs
@@ -376,7 +347,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if len(token_type_id_requests) == 0:
return model_kwargs
-seq_lens = self.seq_lens[:num_reqs]
+seq_lens = self.seq_lens.gpu[:num_reqs]
token_type_ids = []
for i in range(num_reqs):
@@ -719,7 +690,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_scheduled_tokens)
# Get positions.
-positions_np = self.positions_np[:total_num_scheduled_tokens]
+positions_np = self.positions.np[:total_num_scheduled_tokens]
np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
arange,
out=positions_np)
@@ -742,7 +713,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
0,
torch.from_numpy(token_indices),
-out=self.input_ids_cpu[:total_num_scheduled_tokens])
+out=self.input_ids.cpu[:total_num_scheduled_tokens])
self.input_batch.block_table.compute_slot_mapping(
req_indices, positions_np)
@@ -750,36 +721,33 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
total_num_scheduled_tokens)
# Prepare the attention metadata.
-self.query_start_loc_np[0] = 0
-self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
+self.query_start_loc.np[0] = 0
+self.query_start_loc.np[1:num_reqs + 1] = cu_num_tokens
# Note: pad query_start_loc to be non-decreasing, as kernels
# like FlashAttention requires that
-self.query_start_loc_np[num_reqs + 1:].fill(cu_num_tokens[-1])
-self.query_start_loc.copy_(self.query_start_loc_cpu, non_blocking=True)
-query_start_loc = self.query_start_loc[:num_reqs + 1]
+self.query_start_loc.np[num_reqs + 1:].fill(cu_num_tokens[-1])
+self.query_start_loc.copy_to_gpu()
+query_start_loc = self.query_start_loc.gpu[:num_reqs + 1]
-self.seq_lens_np[:num_reqs] = (
+self.seq_lens.np[:num_reqs] = (
self.input_batch.num_computed_tokens_cpu[:num_reqs] +
num_scheduled_tokens)
# Fill unused with 0 for full cuda graph mode.
-self.seq_lens_np[num_reqs:].fill(0)
-self.seq_lens.copy_(self.seq_lens_cpu, non_blocking=True)
-seq_lens = self.seq_lens[:num_reqs]
-max_seq_len = self.seq_lens_np[:num_reqs].max().item()
+self.seq_lens.np[num_reqs:].fill(0)
+self.seq_lens.copy_to_gpu()
+seq_lens = self.seq_lens.gpu[:num_reqs]
+max_seq_len = self.seq_lens.np[:num_reqs].max().item()
# Copy the tensors to the GPU.
-self.input_ids[:total_num_scheduled_tokens].copy_(
-self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
+self.input_ids.copy_to_gpu(total_num_scheduled_tokens)
if self.uses_mrope:
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
-self.mrope_positions[:, :total_num_scheduled_tokens].copy_(
-self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
+self.mrope_positions.gpu[:, :total_num_scheduled_tokens].copy_(
+self.mrope_positions.cpu[:, :total_num_scheduled_tokens],
non_blocking=True)
else:
# Common case (1D positions)
-self.positions[:total_num_scheduled_tokens].copy_(
-self.positions_cpu[:total_num_scheduled_tokens],
-non_blocking=True)
+self.positions.copy_to_gpu(total_num_scheduled_tokens)
use_spec_decode = len(
scheduler_output.scheduled_spec_decode_tokens) > 0
@@ -833,8 +801,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
attn_metadata: dict[str, Any] = {}
# Used in the below loop.
-query_start_loc_cpu = self.query_start_loc_cpu[:num_reqs + 1]
-seq_lens_cpu = self.seq_lens_cpu[:num_reqs]
+query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1]
+seq_lens_cpu = self.seq_lens.cpu[:num_reqs]
num_computed_tokens_cpu = (
self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
spec_decode_common_attn_metadata = None
@@ -1065,9 +1033,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
src_start = num_computed_tokens
src_end = num_computed_tokens + prompt_part_len
-self.mrope_positions_cpu[:, dst_start:dst_end] = \
-req.mrope_positions[:,src_start:src_end]
+self.mrope_positions.cpu[:, dst_start:dst_end] = (
+req.mrope_positions[:, src_start:src_end])
mrope_pos_ptr += prompt_part_len
if completion_part_len > 0:
@@ -1076,7 +1043,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
dst_end = mrope_pos_ptr + completion_part_len
MRotaryEmbedding.get_next_input_positions_tensor(
-out=self.mrope_positions_np,
+out=self.mrope_positions.np,
out_offset=dst_start,
mrope_position_delta=req.mrope_position_delta,
context_len=num_computed_tokens + prompt_part_len,
@@ -1140,7 +1107,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Compute the draft token ids.
# draft_token_indices: [ 1, 2, 3, 105, 106, 208]
-draft_token_ids = self.input_ids[logits_indices]
+draft_token_ids = self.input_ids.gpu[logits_indices]
draft_token_ids = draft_token_ids[target_logits_indices + 1]
metadata = SpecDecodeMetadata(
@@ -1471,7 +1438,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
pooling_metadata = self.input_batch.pooling_metadata
pooling_metadata.build_pooling_cursor(num_scheduled_tokens_np.tolist(),
device=hidden_states.device)
-seq_lens_cpu = self.seq_lens_cpu[:self.input_batch.num_reqs]
+seq_lens_cpu = self.seq_lens.cpu[:self.input_batch.num_reqs]
# Pooling models D2H & synchronize occurs in pooler.py:build_output
raw_pooler_output = self.model.pooler(
@@ -1550,7 +1517,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# embeddings), we always use embeddings (rather than token ids)
# as input to the multimodal model, even when the input is text.
inputs_embeds_scheduled = self.model.get_input_embeddings(
-input_ids=self.input_ids[:num_scheduled_tokens],
+input_ids=self.input_ids.gpu[:num_scheduled_tokens],
multimodal_embeddings=mm_embeds or None,
)
@@ -1569,13 +1536,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# While it is possible to use embeddings as input just like the
# multimodal models, it is not desirable for performance since
# then the embedding layer is not included in the CUDA graph.
-input_ids = self.input_ids[:num_input_tokens]
+input_ids = self.input_ids.gpu[:num_input_tokens]
inputs_embeds = None
model_kwargs = self._init_model_kwargs(num_input_tokens)
if self.uses_mrope:
-positions = self.mrope_positions[:, :num_input_tokens]
+positions = self.mrope_positions.gpu[:, :num_input_tokens]
else:
-positions = self.positions[:num_input_tokens]
+positions = self.positions.gpu[:num_input_tokens]
if get_pp_group().is_first_rank:
intermediate_tensors = None
@@ -1857,9 +1824,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if spec_decode_metadata is None:
# input_ids can be None for multimodal models.
-target_token_ids = self.input_ids[:num_scheduled_tokens]
+target_token_ids = self.input_ids.gpu[:num_scheduled_tokens]
# TODO(woosuk): Support M-RoPE.
-target_positions = self.positions[:num_scheduled_tokens]
+target_positions = self.positions.gpu[:num_scheduled_tokens]
if self.use_aux_hidden_state_outputs:
target_hidden_states = torch.cat(
[h[:num_scheduled_tokens] for h in aux_hidden_states],
@@ -1879,9 +1846,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.drafter.prepare_inputs(
common_attn_metadata, num_rejected_tokens_cpu)
-target_token_ids = self.input_ids[token_indices]
+target_token_ids = self.input_ids.gpu[token_indices]
# TODO(woosuk): Support M-RoPE.
-target_positions = self.positions[token_indices]
+target_positions = self.positions.gpu[token_indices]
if self.use_aux_hidden_state_outputs:
target_hidden_states = torch.cat(
[h[token_indices] for h in aux_hidden_states], dim=-1)
@@ -2123,7 +2090,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# If this is a partial request (i.e. chunked prefill),
# then there is prompt logprob generated for each index.
req_idx = self.input_batch.req_id_to_index[req_id]
-offset = self.query_start_loc_np[req_idx].item()
+offset = self.query_start_loc.np[req_idx].item()
prompt_hidden_states = hidden_states[offset:offset + num_logits]
logits = self.model.compute_logits(prompt_hidden_states, None)
@@ -2196,7 +2163,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
@functools.cache
def rand_input_ids() -> torch.Tensor:
return torch.randint_like(
-self.input_ids,
+self.input_ids.gpu,
low=0,
high=self.model_config.get_vocab_size(),
dtype=input_ids.dtype)
@@ -2313,18 +2280,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
attn_metadata = {}
# Make sure max_model_len is used at the graph capture time.
-self.seq_lens_np[:num_reqs] = self.max_model_len
-self.seq_lens_np[num_reqs:] = 0
-self.seq_lens.copy_(self.seq_lens_cpu, non_blocking=True)
+self.seq_lens.np[:num_reqs] = self.max_model_len
+self.seq_lens.np[num_reqs:] = 0
+self.seq_lens.copy_to_gpu()
for kv_cache_group_id, kv_cache_group_spec in enumerate(
self.kv_cache_config.kv_cache_groups):
common_attn_metadata = CommonAttentionMetadata(
-query_start_loc=self.query_start_loc[:num_reqs + 1],
-query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs +
+query_start_loc=self.query_start_loc.gpu[:num_reqs + 1],
+query_start_loc_cpu=self.query_start_loc.cpu[:num_reqs +
1],
-seq_lens=self.seq_lens[:num_reqs],
-seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
+seq_lens=self.seq_lens.gpu[:num_reqs],
+seq_lens_cpu=self.seq_lens.cpu[:num_reqs],
num_computed_tokens_cpu=self.input_batch.
num_computed_tokens_cpu_tensor[:num_reqs],
num_reqs=num_reqs,
@@ -2353,14 +2320,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
**self._dummy_mm_kwargs(num_reqs),
}
else:
-input_ids = self.input_ids[:num_tokens]
+input_ids = self.input_ids.gpu[:num_tokens]
inputs_embeds = None
model_kwargs = self._init_model_kwargs(num_tokens)
if self.uses_mrope:
-positions = self.mrope_positions[:, :num_tokens]
+positions = self.mrope_positions.gpu[:, :num_tokens]
else:
-positions = self.positions[:num_tokens]
+positions = self.positions.gpu[:num_tokens]
if get_pp_group().is_first_rank:
intermediate_tensors = None

vllm/v1/worker/utils.py

@@ -298,3 +298,32 @@ def bind_kv_cache(
for layer_name, kv_cache in kv_caches.items():
# NOTE: Use list because of v0 PP virtual engine.
forward_context[layer_name].kv_cache = [kv_cache]
+class CpuGpuBuffer:
+
+    def __init__(
+        self,
+        *args,
+        dtype: torch.dtype,
+        device: torch.device,
+        pin_memory: bool,
+    ):
+        self.cpu = torch.zeros(*args,
+                               dtype=dtype,
+                               device="cpu",
+                               pin_memory=pin_memory)
+        self.np = self.cpu.numpy()
+        self.gpu = self.cpu.to(device)
+
+    def copy_to_gpu(self, n: Optional[int] = None) -> torch.Tensor:
+        if n is None:
+            return self.gpu.copy_(self.cpu, non_blocking=True)
+        return self.gpu[:n].copy_(self.cpu[:n], non_blocking=True)
+
+    def copy_to_cpu(self, n: Optional[int] = None) -> torch.Tensor:
+        """NOTE: Because this method is non-blocking, explicit synchronization
+        is needed to ensure the data is copied to CPU."""
+        if n is None:
+            return self.cpu.copy_(self.gpu, non_blocking=True)
+        return self.cpu[:n].copy_(self.gpu[:n], non_blocking=True)
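
For reference, a minimal usage sketch of the new buffer (illustrative only: it assumes the CpuGpuBuffer class added above is in scope and that a CUDA device is available; the buffer size and the values are placeholders, not vLLM code):

    import torch

    buf = CpuGpuBuffer(8,
                       dtype=torch.int32,
                       device=torch.device("cuda"),
                       pin_memory=True)
    buf.np[:4] = [1, 2, 3, 4]   # stage values through the pinned-CPU numpy view
    buf.copy_to_gpu(4)          # async H2D copy of the first 4 elements
    gpu_view = buf.gpu[:4]      # GPU slice consumed by the model
    buf.copy_to_cpu(4)          # async D2H copy; synchronize before reading buf.np
    torch.cuda.synchronize()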