[Refactor] Refactor persistent buffers with CpuGpuBuffer (#23515)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
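Before this change, each persistent input buffer in GPUModelRunner was kept as three parallel attributes: a GPU tensor, a pinned-CPU twin (*_cpu), and a numpy view of the CPU tensor (*_np), with hand-written non-blocking host-to-device copies between them. The snippet below is illustrative only (plain torch, an arbitrary size of 1024) and sketches the pattern this commit removes; the real buffers are sized by max_num_tokens and appear in the diff that follows.

import torch

pin = torch.cuda.is_available()
device = "cuda" if pin else "cpu"

# One logical buffer used to require three attributes kept in sync by hand.
positions = torch.zeros(1024, dtype=torch.int64, device=device)
positions_cpu = torch.zeros(1024,
                            dtype=torch.int64,
                            device="cpu",
                            pin_memory=pin)
positions_np = positions_cpu.numpy()

# Per step: stage values through the numpy view, then copy to the GPU tensor.
positions_np[:8] = range(8)
positions[:8].copy_(positions_cpu[:8], non_blocking=True)

The new CpuGpuBuffer (added to vllm/v1/worker/utils.py at the end of this diff) bundles the three views behind .gpu, .cpu, and .np and adds copy_to_gpu()/copy_to_cpu() helpers, so the runner now writes self.positions.np[...], calls self.positions.copy_to_gpu(n), and slices self.positions.gpu[:n].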
vllm/v1/worker/gpu_model_runner.py
@@ -83,8 +83,9 @@ from vllm.v1.worker.kv_connector_model_runner_mixin import (
     KVConnectorModelRunnerMixin, KVConnectorOutput)
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 
-from .utils import (AttentionGroup, MultiModalBudget, bind_kv_cache,
-                    gather_mm_placeholders, initialize_kv_cache_for_kv_sharing,
+from .utils import (AttentionGroup, CpuGpuBuffer, MultiModalBudget,
+                    bind_kv_cache, gather_mm_placeholders,
+                    initialize_kv_cache_for_kv_sharing,
                     sanity_check_mm_encoder_outputs, scatter_mm_placeholders)
 
 if TYPE_CHECKING:
@@ -149,6 +150,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             parallel_config)
         self.hidden_size = model_config.get_hidden_size()
         self.attention_chunk_size = model_config.attention_chunk_size
+        # Only relevant for models using ALiBi (e.g, MPT)
+        self.use_alibi = check_use_alibi(model_config)
 
         self.cascade_attn_enabled = not self.model_config.disable_cascade_attn
 
@@ -242,21 +245,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self._init_device_properties()
 
         # Persistent buffers for CUDA graphs.
-        self.input_ids = torch.zeros(self.max_num_tokens,
-                                     dtype=torch.int32,
-                                     device=self.device)
-        self.positions = torch.zeros(self.max_num_tokens,
-                                     dtype=torch.int64,
-                                     device=self.device)
-        self.query_start_loc = torch.zeros(self.max_num_reqs + 1,
-                                           dtype=torch.int32,
-                                           device=self.device)
-        self.seq_lens = torch.zeros(self.max_num_reqs,
-                                    dtype=torch.int32,
-                                    device=self.device)
-
-        # None in the first PP rank. The rest are set after load_model.
-        self.intermediate_tensors: Optional[IntermediateTensors] = None
+        self.input_ids = self._make_buffer(self.max_num_tokens,
+                                           dtype=torch.int32)
+        self.positions = self._make_buffer(self.max_num_tokens,
+                                           dtype=torch.int64)
+        self.query_start_loc = self._make_buffer(self.max_num_reqs + 1,
+                                                 dtype=torch.int32)
+        self.seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
+        self.inputs_embeds = torch.zeros(
+            (self.max_num_tokens, self.hidden_size),
+            dtype=self.dtype,
+            device=self.device)
 
         # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
         if self.uses_mrope:
@@ -270,23 +269,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             # identical position IDs, making M-RoPE functionally equivalent to
             # 1D-RoPE.
             # See page 5 of https://arxiv.org/abs/2409.12191
-            self.mrope_positions = torch.zeros((3, self.max_num_tokens + 1),
-                                               dtype=torch.int64,
-                                               device=self.device)
-            self.mrope_positions_cpu = torch.zeros(
-                (3, self.max_num_tokens + 1),
-                dtype=torch.int64,
-                device="cpu",
-                pin_memory=self.pin_memory)
-            self.mrope_positions_np = self.mrope_positions_cpu.numpy()
+            self.mrope_positions = self._make_buffer(
+                (3, self.max_num_tokens + 1), dtype=torch.int64)
 
-        # Only relevant for models using ALiBi (e.g, MPT)
-        self.use_alibi = check_use_alibi(model_config)
-
-        self.inputs_embeds = torch.zeros(
-            (self.max_num_tokens, self.hidden_size),
-            dtype=self.dtype,
-            device=self.device)
+        # None in the first PP rank. The rest are set after load_model.
+        self.intermediate_tensors: Optional[IntermediateTensors] = None
 
         # OPTIMIZATION: Cache the tensors rather than creating them every step.
         # Keep in int64 to avoid overflow with long context
@@ -294,28 +281,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                                        self.max_model_len,
                                        self.max_num_tokens),
                                    dtype=np.int64)
-        # NOTE(woosuk): These tensors are "stateless", i.e., they are literally
-        # a faster version of creating a new tensor every time. Thus, we should
-        # not make any assumptions about the values in these tensors.
-        self.input_ids_cpu = torch.zeros(self.max_num_tokens,
-                                         dtype=torch.int32,
-                                         device="cpu",
-                                         pin_memory=self.pin_memory)
-        self.positions_cpu = torch.zeros(self.max_num_tokens,
-                                         dtype=torch.int64,
-                                         device="cpu",
-                                         pin_memory=self.pin_memory)
-        self.positions_np = self.positions_cpu.numpy()
-        self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
-                                               dtype=torch.int32,
-                                               device="cpu",
-                                               pin_memory=self.pin_memory)
-        self.query_start_loc_np = self.query_start_loc_cpu.numpy()
-        self.seq_lens_cpu = torch.zeros(self.max_num_reqs,
-                                        dtype=torch.int32,
-                                        device="cpu",
-                                        pin_memory=self.pin_memory)
-        self.seq_lens_np = self.seq_lens_cpu.numpy()
 
         # Layer pairings for cross-layer KV sharing.
         # If an Attention layer `layer_name` is in the keys of this dict, it
@@ -352,6 +317,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self._draft_token_ids: Optional[Union[list[list[int]],
                                               torch.Tensor]] = None
 
+    def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer:
+        return CpuGpuBuffer(*args,
+                            dtype=dtype,
+                            device=self.device,
+                            pin_memory=self.pin_memory)
+
     def _init_model_kwargs(self, num_tokens: int):
         model_kwargs = dict[str, Any]()
         num_reqs = self.input_batch.num_reqs
@@ -376,7 +347,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         if len(token_type_id_requests) == 0:
             return model_kwargs
 
-        seq_lens = self.seq_lens[:num_reqs]
+        seq_lens = self.seq_lens.gpu[:num_reqs]
         token_type_ids = []
 
         for i in range(num_reqs):
@@ -719,7 +690,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                                 num_scheduled_tokens)
 
         # Get positions.
-        positions_np = self.positions_np[:total_num_scheduled_tokens]
+        positions_np = self.positions.np[:total_num_scheduled_tokens]
         np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
                arange,
                out=positions_np)
@@ -742,7 +713,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
                            0,
                            torch.from_numpy(token_indices),
-                           out=self.input_ids_cpu[:total_num_scheduled_tokens])
+                           out=self.input_ids.cpu[:total_num_scheduled_tokens])
 
         self.input_batch.block_table.compute_slot_mapping(
             req_indices, positions_np)
@@ -750,36 +721,33 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             total_num_scheduled_tokens)
 
         # Prepare the attention metadata.
-        self.query_start_loc_np[0] = 0
-        self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
+        self.query_start_loc.np[0] = 0
+        self.query_start_loc.np[1:num_reqs + 1] = cu_num_tokens
         # Note: pad query_start_loc to be non-decreasing, as kernels
         # like FlashAttention requires that
-        self.query_start_loc_np[num_reqs + 1:].fill(cu_num_tokens[-1])
-        self.query_start_loc.copy_(self.query_start_loc_cpu, non_blocking=True)
-        query_start_loc = self.query_start_loc[:num_reqs + 1]
+        self.query_start_loc.np[num_reqs + 1:].fill(cu_num_tokens[-1])
+        self.query_start_loc.copy_to_gpu()
+        query_start_loc = self.query_start_loc.gpu[:num_reqs + 1]
 
-        self.seq_lens_np[:num_reqs] = (
+        self.seq_lens.np[:num_reqs] = (
             self.input_batch.num_computed_tokens_cpu[:num_reqs] +
             num_scheduled_tokens)
         # Fill unused with 0 for full cuda graph mode.
-        self.seq_lens_np[num_reqs:].fill(0)
-        self.seq_lens.copy_(self.seq_lens_cpu, non_blocking=True)
-        seq_lens = self.seq_lens[:num_reqs]
-        max_seq_len = self.seq_lens_np[:num_reqs].max().item()
+        self.seq_lens.np[num_reqs:].fill(0)
+        self.seq_lens.copy_to_gpu()
+        seq_lens = self.seq_lens.gpu[:num_reqs]
+        max_seq_len = self.seq_lens.np[:num_reqs].max().item()
 
         # Copy the tensors to the GPU.
-        self.input_ids[:total_num_scheduled_tokens].copy_(
-            self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
+        self.input_ids.copy_to_gpu(total_num_scheduled_tokens)
         if self.uses_mrope:
             # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
-            self.mrope_positions[:, :total_num_scheduled_tokens].copy_(
-                self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
+            self.mrope_positions.gpu[:, :total_num_scheduled_tokens].copy_(
+                self.mrope_positions.cpu[:, :total_num_scheduled_tokens],
                 non_blocking=True)
         else:
             # Common case (1D positions)
-            self.positions[:total_num_scheduled_tokens].copy_(
-                self.positions_cpu[:total_num_scheduled_tokens],
-                non_blocking=True)
+            self.positions.copy_to_gpu(total_num_scheduled_tokens)
 
         use_spec_decode = len(
             scheduler_output.scheduled_spec_decode_tokens) > 0
@@ -833,8 +801,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         attn_metadata: dict[str, Any] = {}
 
         # Used in the below loop.
-        query_start_loc_cpu = self.query_start_loc_cpu[:num_reqs + 1]
-        seq_lens_cpu = self.seq_lens_cpu[:num_reqs]
+        query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1]
+        seq_lens_cpu = self.seq_lens.cpu[:num_reqs]
         num_computed_tokens_cpu = (
             self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
         spec_decode_common_attn_metadata = None
@@ -1065,9 +1033,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 src_start = num_computed_tokens
                 src_end = num_computed_tokens + prompt_part_len
 
-                self.mrope_positions_cpu[:, dst_start:dst_end] = \
-                    req.mrope_positions[:,src_start:src_end]
-
+                self.mrope_positions.cpu[:, dst_start:dst_end] = (
+                    req.mrope_positions[:, src_start:src_end])
                 mrope_pos_ptr += prompt_part_len
 
             if completion_part_len > 0:
@@ -1076,7 +1043,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 dst_end = mrope_pos_ptr + completion_part_len
 
                 MRotaryEmbedding.get_next_input_positions_tensor(
-                    out=self.mrope_positions_np,
+                    out=self.mrope_positions.np,
                     out_offset=dst_start,
                     mrope_position_delta=req.mrope_position_delta,
                     context_len=num_computed_tokens + prompt_part_len,
@@ -1140,7 +1107,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
         # Compute the draft token ids.
         # draft_token_indices: [ 1, 2, 3, 105, 106, 208]
-        draft_token_ids = self.input_ids[logits_indices]
+        draft_token_ids = self.input_ids.gpu[logits_indices]
         draft_token_ids = draft_token_ids[target_logits_indices + 1]
 
         metadata = SpecDecodeMetadata(
@@ -1471,7 +1438,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         pooling_metadata = self.input_batch.pooling_metadata
         pooling_metadata.build_pooling_cursor(num_scheduled_tokens_np.tolist(),
                                               device=hidden_states.device)
-        seq_lens_cpu = self.seq_lens_cpu[:self.input_batch.num_reqs]
+        seq_lens_cpu = self.seq_lens.cpu[:self.input_batch.num_reqs]
 
         # Pooling models D2H & synchronize occurs in pooler.py:build_output
         raw_pooler_output = self.model.pooler(
@@ -1550,7 +1517,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             # embeddings), we always use embeddings (rather than token ids)
             # as input to the multimodal model, even when the input is text.
             inputs_embeds_scheduled = self.model.get_input_embeddings(
-                input_ids=self.input_ids[:num_scheduled_tokens],
+                input_ids=self.input_ids.gpu[:num_scheduled_tokens],
                 multimodal_embeddings=mm_embeds or None,
             )
 
@@ -1569,13 +1536,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             # While it is possible to use embeddings as input just like the
             # multimodal models, it is not desirable for performance since
             # then the embedding layer is not included in the CUDA graph.
-            input_ids = self.input_ids[:num_input_tokens]
+            input_ids = self.input_ids.gpu[:num_input_tokens]
             inputs_embeds = None
             model_kwargs = self._init_model_kwargs(num_input_tokens)
         if self.uses_mrope:
-            positions = self.mrope_positions[:, :num_input_tokens]
+            positions = self.mrope_positions.gpu[:, :num_input_tokens]
         else:
-            positions = self.positions[:num_input_tokens]
+            positions = self.positions.gpu[:num_input_tokens]
 
         if get_pp_group().is_first_rank:
             intermediate_tensors = None
@@ -1857,9 +1824,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
         if spec_decode_metadata is None:
             # input_ids can be None for multimodal models.
-            target_token_ids = self.input_ids[:num_scheduled_tokens]
+            target_token_ids = self.input_ids.gpu[:num_scheduled_tokens]
             # TODO(woosuk): Support M-RoPE.
-            target_positions = self.positions[:num_scheduled_tokens]
+            target_positions = self.positions.gpu[:num_scheduled_tokens]
             if self.use_aux_hidden_state_outputs:
                 target_hidden_states = torch.cat(
                     [h[:num_scheduled_tokens] for h in aux_hidden_states],
@@ -1879,9 +1846,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 self.drafter.prepare_inputs(
                 common_attn_metadata, num_rejected_tokens_cpu)
 
-            target_token_ids = self.input_ids[token_indices]
+            target_token_ids = self.input_ids.gpu[token_indices]
             # TODO(woosuk): Support M-RoPE.
-            target_positions = self.positions[token_indices]
+            target_positions = self.positions.gpu[token_indices]
             if self.use_aux_hidden_state_outputs:
                 target_hidden_states = torch.cat(
                     [h[token_indices] for h in aux_hidden_states], dim=-1)
@@ -2123,7 +2090,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             # If this is a partial request (i.e. chunked prefill),
             # then there is prompt logprob generated for each index.
             req_idx = self.input_batch.req_id_to_index[req_id]
-            offset = self.query_start_loc_np[req_idx].item()
+            offset = self.query_start_loc.np[req_idx].item()
             prompt_hidden_states = hidden_states[offset:offset + num_logits]
             logits = self.model.compute_logits(prompt_hidden_states, None)
 
@@ -2196,7 +2163,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         @functools.cache
         def rand_input_ids() -> torch.Tensor:
             return torch.randint_like(
-                self.input_ids,
+                self.input_ids.gpu,
                 low=0,
                 high=self.model_config.get_vocab_size(),
                 dtype=input_ids.dtype)
@@ -2313,18 +2280,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             attn_metadata = {}
 
             # Make sure max_model_len is used at the graph capture time.
-            self.seq_lens_np[:num_reqs] = self.max_model_len
-            self.seq_lens_np[num_reqs:] = 0
-            self.seq_lens.copy_(self.seq_lens_cpu, non_blocking=True)
+            self.seq_lens.np[:num_reqs] = self.max_model_len
+            self.seq_lens.np[num_reqs:] = 0
+            self.seq_lens.copy_to_gpu()
 
             for kv_cache_group_id, kv_cache_group_spec in enumerate(
                     self.kv_cache_config.kv_cache_groups):
                 common_attn_metadata = CommonAttentionMetadata(
-                    query_start_loc=self.query_start_loc[:num_reqs + 1],
-                    query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs +
+                    query_start_loc=self.query_start_loc.gpu[:num_reqs + 1],
+                    query_start_loc_cpu=self.query_start_loc.cpu[:num_reqs +
                                                                  1],
-                    seq_lens=self.seq_lens[:num_reqs],
-                    seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
+                    seq_lens=self.seq_lens.gpu[:num_reqs],
+                    seq_lens_cpu=self.seq_lens.cpu[:num_reqs],
                     num_computed_tokens_cpu=self.input_batch.
                     num_computed_tokens_cpu_tensor[:num_reqs],
                     num_reqs=num_reqs,
@@ -2353,14 +2320,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                     **self._dummy_mm_kwargs(num_reqs),
                 }
             else:
-                input_ids = self.input_ids[:num_tokens]
+                input_ids = self.input_ids.gpu[:num_tokens]
                 inputs_embeds = None
                 model_kwargs = self._init_model_kwargs(num_tokens)
 
             if self.uses_mrope:
-                positions = self.mrope_positions[:, :num_tokens]
+                positions = self.mrope_positions.gpu[:, :num_tokens]
             else:
-                positions = self.positions[:num_tokens]
+                positions = self.positions.gpu[:num_tokens]
 
             if get_pp_group().is_first_rank:
                 intermediate_tensors = None
vllm/v1/worker/utils.py
@@ -298,3 +298,32 @@ def bind_kv_cache(
     for layer_name, kv_cache in kv_caches.items():
         # NOTE: Use list because of v0 PP virtual engine.
         forward_context[layer_name].kv_cache = [kv_cache]
+
+
+class CpuGpuBuffer:
+
+    def __init__(
+        self,
+        *args,
+        dtype: torch.dtype,
+        device: torch.device,
+        pin_memory: bool,
+    ):
+        self.cpu = torch.zeros(*args,
+                               dtype=dtype,
+                               device="cpu",
+                               pin_memory=pin_memory)
+        self.np = self.cpu.numpy()
+        self.gpu = self.cpu.to(device)
+
+    def copy_to_gpu(self, n: Optional[int] = None) -> None:
+        if n is None:
+            return self.gpu.copy_(self.cpu, non_blocking=True)
+        return self.gpu[:n].copy_(self.cpu[:n], non_blocking=True)
+
+    def copy_to_cpu(self, n: Optional[int] = None) -> None:
+        """NOTE: Because this method is non-blocking, explicit synchronization
+        is needed to ensure the data is copied to CPU."""
+        if n is None:
+            return self.cpu.copy_(self.gpu, non_blocking=True)
+        return self.cpu[:n].copy_(self.gpu[:n], non_blocking=True)
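A minimal round trip through the new class, assuming the import path implied by the "from .utils import (...)" hunk above (vllm/v1/worker/utils.py); the size, dtype, and values below are arbitrary, and a CUDA device is optional:

import numpy as np
import torch

from vllm.v1.worker.utils import CpuGpuBuffer

use_cuda = torch.cuda.is_available()
buf = CpuGpuBuffer(1024,
                   dtype=torch.int32,
                   device=torch.device("cuda" if use_cuda else "cpu"),
                   pin_memory=use_cuda)

# Stage the first n elements on the host through the shared numpy view...
n = 16
buf.np[:n] = np.arange(n, dtype=np.int32)
# ...then issue one async H2D copy and hand the GPU slice to kernels.
buf.copy_to_gpu(n)
print(buf.gpu[:n])

Both copy helpers are non-blocking, so as the copy_to_cpu docstring notes, explicit synchronization (or a dependent op on the same stream) is needed before the destination side is read.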