Compare commits

...

2 Commits

Author       SHA1        Message        Date
Woosuk Kwon  c11d1e6781  optimize spec  2025-08-31 16:40:54 -07:00
                         Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Woosuk Kwon  e696f78e05  minor          2025-08-31 13:29:58 -07:00
                         Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2 changed files with 83 additions and 81 deletions

vllm/v1/worker/gpu_input_batch.py

@@ -27,7 +27,7 @@ class InputBatch:
     # batch_idx -> num_scheduled_tokens
     num_scheduled_tokens: np.ndarray
     total_num_tokens: int
-    max_num_tokens: int
+    max_query_len: int
     num_reqs: int
     attn_metadata: dict[str, Any]
@@ -91,3 +91,53 @@ def prepare_inputs(
     query_start_loc[num_reqs + 1:].fill(cu_num_tokens)
     # Fill unused with 0 for full cuda graph mode.
     seq_lens[num_reqs:].fill(0)
+
+
+def prepare_spec_decode(
+    # Inputs
+    query_start_loc: np.ndarray,  # [B + 1]
+    num_draft_tokens: np.ndarray,  # [B]
+    # Outputs
+    cu_num_draft_tokens: np.ndarray,  # [B]
+    logits_indices: np.ndarray,  # [N + B]
+    target_logits_indices: np.ndarray,  # [N]
+    bonus_logits_indices: np.ndarray,  # [B]
+) -> int:  # N
+    # Inputs:
+    # query_start_loc:       [  0,   4, 104, 107, 207, 209]
+    # num_draft_tokens:      [  3,   0,   2,   0,   1]
+    # Outputs:
+    # cu_num_draft_tokens:   [  3,   3,   5,   5,   6]
+    # logits_indices:        [  0,   1,   2,   3, 103, 104, 105, 106,
+    #                         206, 207, 208]
+    # target_logits_indices: [  0,   1,   2,   5,   6,   9]
+    # bonus_logits_indices:  [  3,   4,   7,   8,  10]
+    # return: 6 (total number of draft tokens)
+    cu_num_draft = 0
+    cu_num_sample = 0
+    num_reqs = num_draft_tokens.shape[0]
+    for i in range(num_reqs):
+        q_end_idx = query_start_loc[i + 1]
+        draft_len = num_draft_tokens[i]
+        # The last draft_len + 1 query tokens are used for sampling.
+        sample_len = draft_len + 1
+        sample_start_idx = cu_num_sample
+        sample_end_idx = sample_start_idx + sample_len
+        logits_indices[sample_start_idx:sample_end_idx] = np.arange(
+            q_end_idx - sample_len, q_end_idx)
+        # For each query, the first draft_len sampled tokens need target
+        # logits for rejection sampling; the (draft_len + 1)-th one is for
+        # the bonus token.
+        draft_start_idx = cu_num_draft
+        draft_end_idx = draft_start_idx + draft_len
+        target_logits_indices[draft_start_idx:draft_end_idx] = np.arange(
+            sample_start_idx, sample_end_idx - 1)
+        bonus_logits_indices[i] = sample_end_idx - 1
+        cu_num_draft += draft_len
+        cu_num_draft_tokens[i] = cu_num_draft
+        cu_num_sample += sample_len
+    return cu_num_draft
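
Note: the example in the comments can be checked end to end. Below is a minimal sketch (an editor's illustration, not part of the diff) that drives prepare_spec_decode with those inputs; the output arrays are preallocated to the [B], [N + B], [N], and [B] shapes annotated above, and int32 matches the torch.int32 buffers allocated in the model runner below.

import numpy as np

from vllm.v1.worker.gpu_input_batch import prepare_spec_decode

# B = 5 requests, N = 6 total draft tokens (values from the comments above).
query_start_loc = np.array([0, 4, 104, 107, 207, 209], dtype=np.int32)
num_draft_tokens = np.array([3, 0, 2, 0, 1], dtype=np.int32)
B = num_draft_tokens.shape[0]
N = int(num_draft_tokens.sum())

# Preallocate outputs; in the runner these are persistent CPU buffers.
cu_num_draft_tokens = np.empty(B, dtype=np.int32)
logits_indices = np.empty(N + B, dtype=np.int32)
target_logits_indices = np.empty(N, dtype=np.int32)
bonus_logits_indices = np.empty(B, dtype=np.int32)

n = prepare_spec_decode(query_start_loc, num_draft_tokens,
                        cu_num_draft_tokens, logits_indices,
                        target_logits_indices, bonus_logits_indices)
assert n == 6
assert cu_num_draft_tokens.tolist() == [3, 3, 5, 5, 6]
assert logits_indices.tolist() == [0, 1, 2, 3, 103, 104, 105, 106,
                                   206, 207, 208]
assert target_logits_indices.tolist() == [0, 1, 2, 5, 6, 9]
assert bonus_logits_indices.tolist() == [3, 4, 7, 8, 10]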

vllm/v1/worker/gpu_model_runner.py

@@ -77,7 +77,8 @@ from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
 from vllm.v1.utils import CpuGpuBuffer
 from vllm.v1.worker.gpu_block_table import BlockTables
-from vllm.v1.worker.gpu_input_batch import InputBatch, prepare_inputs
+from vllm.v1.worker.gpu_input_batch import (InputBatch, prepare_inputs,
+                                            prepare_spec_decode)
 from vllm.v1.worker.gpu_worker_states import RequestState
 from vllm.v1.worker.kv_connector_model_runner_mixin import (
     KVConnectorModelRunnerMixin, KVConnectorOutput)
@@ -241,6 +242,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             (self.max_num_tokens, self.hidden_size),
             dtype=self.dtype,
             device=self.device)
+        self.cu_num_draft_tokens = self._make_buffer(self.max_num_reqs,
+                                                     dtype=torch.int32)
+        self.spec_logits_indices = self._make_buffer(self.max_num_tokens +
+                                                     self.max_num_reqs,
+                                                     dtype=torch.int32)
+        self.target_logits_indices = self._make_buffer(self.max_num_tokens,
+                                                       dtype=torch.int32)
+        self.bonus_logits_indices = self._make_buffer(self.max_num_reqs,
+                                                      dtype=torch.int32)
 
         # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
         if self.uses_mrope:
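
Note: the buffer sizes mirror the shape annotations on prepare_spec_decode. A small illustration of the worst-case capacity math (the numeric values here are hypothetical, not from the diff):

# With at most B = max_num_reqs requests and N <= max_num_tokens draft
# tokens per step, one logits index is needed per sampled draft position
# plus one bonus position per request, hence N + B slots.
max_num_reqs = 1024      # B (illustrative)
max_num_tokens = 8192    # N (illustrative)
spec_logits_capacity = max_num_tokens + max_num_reqs  # matches the allocation above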
@@ -272,13 +282,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self.idx_mapping = self._make_buffer(self.max_num_reqs,
                                              dtype=torch.int32)
 
-        # OPTIMIZATION: Cache the tensors rather than creating them every step.
-        # Keep in int64 to avoid overflow with long context
-        self.arange_np = np.arange(max(self.max_num_reqs + 1,
-                                       self.max_model_len,
-                                       self.max_num_tokens),
-                                   dtype=np.int64)
-
         # Layer pairings for cross-layer KV sharing.
         # If an Attention layer `layer_name` is in the keys of this dict, it
         # means this layer will perform attention using the keys and values
@@ -529,26 +532,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             dummy_modality = mm_budget.get_modality_with_max_tokens()
         return self._get_mm_dummy_batch(dummy_modality, num_seqs)
 
-    def _get_cumsum_and_arange(
-        self,
-        num_tokens: np.ndarray,
-        cumsum_dtype: Optional[np.dtype] = None,
-    ) -> tuple[np.ndarray, np.ndarray]:
-        """Get the cumulative sum and batched arange of the given array.
-        # E.g., [2, 5, 3] -> ([2, 7, 10], [0, 1, 0, 1, 2, 3, 4, 0, 1, 2])
-        # Equivalent to but faster than:
-        # np.concatenate([np.arange(n) for n in num_tokens])
-        """
-        # Step 1. [2, 5, 3] -> [2, 7, 10]
-        cu_num_tokens = np.cumsum(num_tokens, dtype=cumsum_dtype)
-        total_num_tokens = cu_num_tokens[-1]
-        # Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7]
-        cumsums_offsets = np.repeat(cu_num_tokens - num_tokens, num_tokens)
-        # Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
-        arange = self.arange_np[:total_num_tokens] - cumsums_offsets
-        return cu_num_tokens, arange
-
     def _prepare_inputs(
         self,
         scheduler_output: "SchedulerOutput",
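
Note: the deleted helper's docstring example can be reproduced standalone. Here is a short sketch (an editor's illustration) of the cumsum-and-repeat trick it used, with a fresh np.arange standing in for the removed self.arange_np cache; the new per-request loop in prepare_spec_decode produces the same kind of indices one request at a time instead.

import numpy as np

def cumsum_and_arange(num_tokens: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    # Step 1. [2, 5, 3] -> [2, 7, 10]
    cu = np.cumsum(num_tokens)
    # Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7]
    offsets = np.repeat(cu - num_tokens, num_tokens)
    # Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
    return cu, np.arange(cu[-1]) - offsets

cu, arange = cumsum_and_arange(np.array([2, 5, 3]))
assert cu.tolist() == [2, 7, 10]
assert arange.tolist() == [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# Equivalent to, but faster than:
assert arange.tolist() == np.concatenate(
    [np.arange(n) for n in np.array([2, 5, 3])]).tolist()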
@@ -632,7 +615,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 if draft_token_ids:
                     num_draft_tokens[i] = len(draft_token_ids)
             spec_decode_metadata = self._calc_spec_decode_metadata(
-                num_draft_tokens, self.query_start_loc.np[1:num_reqs + 1])
+                num_draft_tokens)
             logits_indices = spec_decode_metadata.logits_indices
             logits_indices_padded = None
@@ -751,7 +734,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             idx_mapping_np=idx_mapping_np,
             num_reqs=num_reqs,
             total_num_tokens=total_num_scheduled_tokens,
-            max_num_tokens=max_num_scheduled_tokens,
+            max_query_len=max_num_scheduled_tokens,
             attn_metadata=attn_metadata,
             spec_decode_metadata=spec_decode_metadata,
             spec_decode_common_attn_metadata=spec_decode_common_attn_metadata,
@@ -905,55 +888,23 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
     def _calc_spec_decode_metadata(
         self,
         num_draft_tokens: np.ndarray,
-        cu_num_scheduled_tokens: np.ndarray,
     ) -> SpecDecodeMetadata:
-        # Inputs:
-        # cu_num_scheduled_tokens: [  4, 104, 107, 207, 209]
-        # num_draft_tokens:        [  3,   0,   2,   0,   1]
-        # Outputs:
-        # cu_num_draft_tokens:     [  3,   3,   5,   5,   6]
-        # logits_indices:          [  0,   1,   2,   3, 103, 104, 105, 106,
-        #                           206, 207, 208]
-        # target_logits_indices:   [  0,   1,   2,   5,   6,   9]
-        # bonus_logits_indices:    [  3,   4,   7,   8,  10]
+        num_reqs = num_draft_tokens.shape[0]
+        total_num_draft_tokens = prepare_spec_decode(
+            self.query_start_loc.np,
+            num_draft_tokens,
+            self.cu_num_draft_tokens.np,
+            self.spec_logits_indices.np,
+            self.target_logits_indices.np,
+            self.bonus_logits_indices.np,
+        )
 
-        # Compute the logits indices.
-        # [4, 1, 3, 1, 2]
-        num_sampled_tokens = num_draft_tokens + 1
-        # Step 1. cu_num_sampled_tokens: [4, 5, 8, 9, 11]
-        #         arange: [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
-        cu_num_sampled_tokens, arange = self._get_cumsum_and_arange(
-            num_sampled_tokens, cumsum_dtype=np.int32)
-        # Step 2. [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207]
-        logits_indices = np.repeat(
-            cu_num_scheduled_tokens - num_sampled_tokens, num_sampled_tokens)
-        # Step 3. [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
-        logits_indices += arange
-
-        # Compute the bonus logits indices.
-        bonus_logits_indices = cu_num_sampled_tokens - 1
-
-        # Compute the draft logits indices.
-        # cu_num_draft_tokens: [3, 3, 5, 5, 6]
-        # arange: [0, 1, 2, 0, 1, 0]
-        cu_num_draft_tokens, arange = self._get_cumsum_and_arange(
-            num_draft_tokens, cumsum_dtype=np.int32)
-        # [0, 0, 0, 5, 5, 9]
-        target_logits_indices = np.repeat(
-            cu_num_sampled_tokens - num_sampled_tokens, num_draft_tokens)
-        # [0, 1, 2, 5, 6, 9]
-        target_logits_indices += arange
-
-        # TODO: Optimize the CPU -> GPU copy.
-        cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to(
-            self.device, non_blocking=True)
-        logits_indices = torch.from_numpy(logits_indices).to(self.device,
-                                                             non_blocking=True)
-        target_logits_indices = torch.from_numpy(target_logits_indices).to(
-            self.device, non_blocking=True)
-        bonus_logits_indices = torch.from_numpy(bonus_logits_indices).to(
-            self.device, non_blocking=True)
+        cu_num_draft_tokens = self.cu_num_draft_tokens.copy_to_gpu(num_reqs)
+        logits_indices = self.spec_logits_indices.copy_to_gpu(
+            num_reqs + total_num_draft_tokens)
+        target_logits_indices = self.target_logits_indices.copy_to_gpu(
+            total_num_draft_tokens)
+        bonus_logits_indices = self.bonus_logits_indices.copy_to_gpu(num_reqs)
 
         # Compute the draft token ids.
         # draft_token_indices: [  1,   2,   3, 105, 106, 208]
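
Note: the rewritten method leans on the CpuGpuBuffer helper imported at the top of the file: prepare_spec_decode fills persistent, pinned CPU arrays in place, and each result is then moved with a single non-blocking host-to-device copy, replacing the per-step torch.from_numpy(...).to(self.device, non_blocking=True) allocations. A rough sketch of that pattern (an editor's approximation; the real CpuGpuBuffer in vllm.v1.utils may differ):

import numpy as np
import torch

class PinnedCpuGpuBuffer:
    # Hypothetical stand-in for vllm.v1.utils.CpuGpuBuffer.
    def __init__(self, size: int, dtype: torch.dtype, device: torch.device):
        self.cpu = torch.empty(size, dtype=dtype, pin_memory=True)
        self.np = self.cpu.numpy()  # zero-copy numpy view for CPU-side fills
        self.gpu = torch.empty(size, dtype=dtype, device=device)

    def copy_to_gpu(self, n: int) -> torch.Tensor:
        # Single non-blocking H2D copy of the first n elements.
        self.gpu[:n].copy_(self.cpu[:n], non_blocking=True)
        return self.gpu[:n]

Because the buffers persist across steps, the steady-state cost is the numpy loop plus four small async copies, with no per-step tensor allocation.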
@@ -1412,10 +1363,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             intermediate_tensors = self.sync_and_slice_intermediate_tensors(
                 num_input_tokens, intermediate_tensors, True)
 
-        uniform_decode = (input_batch.max_num_tokens
+        uniform_decode = (input_batch.max_query_len
                           == self.uniform_decode_query_len
                           and num_scheduled_tokens
-                          == input_batch.num_reqs * input_batch.max_num_tokens)
+                          == input_batch.num_reqs * input_batch.max_query_len)
         batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
                                            uniform_decode=uniform_decode)
         cudagraph_runtime_mode, batch_descriptor = \
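
Note: with the rename, the intent of the check reads directly: the batch is a uniform decode only if every request was scheduled for exactly uniform_decode_query_len tokens. A worked instance (illustrative numbers; names assumed from the surrounding code):

# 4 decode requests, each scheduled for 1 + 2 speculative tokens.
uniform_decode_query_len = 3
num_reqs = 4
max_query_len = 3            # input_batch.max_query_len
num_scheduled_tokens = 12    # total tokens scheduled this step

uniform_decode = (max_query_len == uniform_decode_query_len
                  and num_scheduled_tokens == num_reqs * max_query_len)
assert uniform_decode  # feeds the BatchDescriptor used for cudagraph dispatch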
@@ -1669,7 +1620,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             target_hidden_states = hidden_states[:num_scheduled_tokens]
         else:
             # TODO(woosuk): Refactor this.
-            num_draft_tokens = input_batch.spec_decode_metadata.num_draft_tokens
+            num_draft_tokens = (
+                input_batch.spec_decode_metadata.num_draft_tokens)
             num_rejected_tokens = [
                 n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
                 for i, n in enumerate(num_draft_tokens)
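
(The hunk is truncated here.) For the arithmetic in the comprehension: a request with n > 0 draft tokens can yield at most n + 1 sampled tokens (the accepted drafts plus one final token), so n + 1 - len(sampled_token_ids[i]) counts the rejected drafts. A hypothetical instance:

# n = 3 draft tokens; rejection sampling accepted 1 draft and emitted one
# more token, so 2 tokens were sampled in total.
n = 3
sampled_token_ids_i = [101, 17]
num_rejected = n + 1 - len(sampled_token_ids_i)  # 3 + 1 - 2 = 2
assert num_rejected == 2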