Compare commits


2 Commits

SHA1        Message  Date
fefed35cee  fix      2025-09-01 18:58:00 -07:00
            Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
901afda905  wip      2025-09-01 09:32:49 -07:00
            Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
4 changed files with 89 additions and 69 deletions

View File

@@ -842,6 +842,7 @@ class Scheduler(SchedulerInterface):
scheduler_output: SchedulerOutput,
model_runner_output: ModelRunnerOutput,
) -> dict[int, EngineCoreOutputs]:
num_sampled_tokens = model_runner_output.num_sampled_tokens
sampled_token_ids = model_runner_output.sampled_token_ids
logprobs = model_runner_output.logprobs
prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
@@ -849,6 +850,10 @@ class Scheduler(SchedulerInterface):
pooler_outputs = model_runner_output.pooler_output
num_nans_in_logits = model_runner_output.num_nans_in_logits
if sampled_token_ids is not None:
# Optimization: Avoid a .tolist() call for each request.
sampled_token_ids = sampled_token_ids.tolist()
outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
spec_decoding_stats: Optional[SpecDecodingStats] = None
@@ -867,14 +872,19 @@ class Scheduler(SchedulerInterface):
continue
req_index = model_runner_output.req_id_to_index[req_id]
generated_token_ids = sampled_token_ids[
req_index] if sampled_token_ids else []
generated_token_ids: list[int] = []
if sampled_token_ids is not None:
assert num_sampled_tokens is not None
num_sampled = num_sampled_tokens[req_index]
if num_sampled > 0:
generated_token_ids = sampled_token_ids[
req_index][:num_sampled]
scheduled_spec_token_ids = (
scheduler_output.scheduled_spec_decode_tokens.get(req_id))
if scheduled_spec_token_ids:
num_draft_tokens = len(scheduled_spec_token_ids)
num_accepted = len(generated_token_ids) - 1
num_accepted = num_sampled - 1
num_rejected = num_draft_tokens - num_accepted
# num_computed_tokens represents the number of tokens
# processed in the current step, considering scheduled
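Note: a minimal standalone sketch of the new slicing pattern above, with made-up values. One bulk .tolist() on the padded [num_reqs, max_num_sampled_tokens] array replaces a per-request conversion, and each request keeps only its first num_sampled_tokens[req_index] entries.

import numpy as np

# Padded sampled-token array; -1 marks unused slots in this sketch.
sampled_token_ids = np.array([[ 7, -1, -1],
                              [ 4,  5,  6],
                              [-1, -1, -1]])
num_sampled_tokens = np.array([1, 3, 0])

# One bulk conversion instead of a .tolist() call per request.
rows = sampled_token_ids.tolist()
for req_index, num_sampled in enumerate(num_sampled_tokens):
    generated_token_ids = rows[req_index][:num_sampled] if num_sampled > 0 else []
    print(req_index, generated_token_ids)
# 0 [7]
# 1 [4, 5, 6]
# 2 []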

View File

@@ -4,6 +4,7 @@
from dataclasses import dataclass
from typing import NamedTuple, Optional
import numpy as np
import torch
@@ -88,11 +89,12 @@ class ModelRunnerOutput:
# req_id -> index
req_id_to_index: dict[str, int]
# num_reqs x num_generated_tokens
# num_generated_tokens is the number of tokens
# generated in the current step. It can be different for
# each request due to speculative/jump decoding.
sampled_token_ids: list[list[int]]
# [num_reqs]
# Number of tokens sampled in the current step. Each request may generate
# a different number of tokens due to chunked prefilling and spec decoding.
num_sampled_tokens: Optional[np.ndarray]
# [num_reqs, max_num_sampled_tokens]
sampled_token_ids: Optional[np.ndarray]
# [num_reqs, max_num_logprobs + 1]
# [num_reqs, max_num_logprobs + 1]
@@ -123,10 +125,13 @@ class DraftTokenIds:
draft_token_ids: list[list[int]]
EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[],
req_id_to_index={},
sampled_token_ids=[],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
num_nans_in_logits=None)
EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
req_ids=[],
req_id_to_index={},
num_sampled_tokens=None,
sampled_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
num_nans_in_logits=None,
)
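Note: a hedged usage sketch of the two new fields (assuming the dataclass lives in vllm.v1.outputs, as this diff suggests; the values are made up). sampled_token_ids is padded to the longest per-request sample count, and num_sampled_tokens[i] says how many leading entries of row i are valid.

import numpy as np
from vllm.v1.outputs import ModelRunnerOutput

# Request "a" sampled two tokens this step, request "b" sampled one.
output = ModelRunnerOutput(
    req_ids=["a", "b"],
    req_id_to_index={"a": 0, "b": 1},
    num_sampled_tokens=np.array([2, 1], dtype=np.int32),   # [num_reqs]
    sampled_token_ids=np.array([[5, 6],                    # [num_reqs, max_num_sampled_tokens]
                                [9, -1]]),
    logprobs=None,
    prompt_logprobs_dict={},
    pooler_output=[],
    num_nans_in_logits=None,
)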

View File

@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import numpy as np
import torch
import torch.nn as nn
@@ -106,9 +107,9 @@ class RejectionSampler(nn.Module):
@staticmethod
def parse_output(
output_token_ids: torch.Tensor,
output_token_ids: np.ndarray,
vocab_size: int,
) -> list[list[int]]:
) -> np.ndarray:
"""Parse the output of the rejection sampler.
Args:
@@ -119,17 +120,14 @@ class RejectionSampler(nn.Module):
vocab_size: The size of the vocabulary.
Returns:
A list of lists of token IDs.
A Numpy array of the number of valid sampled tokens.
"""
output_token_ids_np = output_token_ids.cpu().numpy()
# Create mask for valid tokens.
valid_mask = ((output_token_ids_np != PLACEHOLDER_TOKEN_ID) &
(output_token_ids_np < vocab_size))
outputs = [
row[valid_mask[i]].tolist()
for i, row in enumerate(output_token_ids_np)
]
return outputs
valid_mask = ((output_token_ids != PLACEHOLDER_TOKEN_ID) &
(output_token_ids < vocab_size))
# Count the valid tokens up to the first invalid position (first False in valid_mask).
num_sampled_tokens = np.cumprod(valid_mask, axis=1).sum(axis=1)
return num_sampled_tokens
def rejection_sample(
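Note: a tiny self-contained check of the counting trick in the new parse_output above. cumprod along each row zeroes out everything after the first invalid position, so the row sum equals the number of leading valid tokens. PLACEHOLDER_TOKEN_ID is assumed to be -1 here and the token values are made up.

import numpy as np

PLACEHOLDER_TOKEN_ID = -1   # assumed value for this sketch
vocab_size = 100

output_token_ids = np.array([[ 3,  9, -1, -1],
                             [ 5, -1,  7,  2],    # entries after the first invalid slot are ignored
                             [-1, -1, -1, -1]])
valid_mask = ((output_token_ids != PLACEHOLDER_TOKEN_ID) &
              (output_token_ids < vocab_size))
num_sampled_tokens = np.cumprod(valid_mask, axis=1).sum(axis=1)
print(num_sampled_tokens)   # [2 1 0]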

View File

@@ -1456,7 +1456,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
return ModelRunnerOutput(
req_ids=self.input_batch.req_ids,
req_id_to_index=self.input_batch.req_id_to_index,
sampled_token_ids=[],
num_sampled_tokens=None,
sampled_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=pooler_output,
@@ -1665,23 +1666,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if envs.VLLM_COMPUTE_NANS_IN_LOGITS:
num_nans_in_logits = self._get_nans_in_logits(logits)
# TODO(woosuk): The following loop can be slow since it iterates over
# the requests one by one. Optimize.
discard_sampled_tokens_req_indices = []
for i, req_id in enumerate(self.input_batch.req_ids):
req_state = self.requests[req_id]
seq_len = (req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])
if seq_len < req_state.num_tokens:
# Post-processing for chunked prefill.
num_reqs = self.input_batch.num_reqs
chunked_prefilling = (
self.input_batch.num_computed_tokens_cpu[:num_reqs] +
num_scheduled_tokens_np
< self.input_batch.num_tokens_no_spec[:num_reqs])
if self.input_batch.generators:
chunked_prefill_indices = np.where(chunked_prefilling)[0]
for i in chunked_prefill_indices:
# Ignore the sampled token for partial prefills.
# Rewind the generator state as if the token was not sampled.
# This relies on cuda-specific torch-internal impl details
generator = self.input_batch.generators.get(i)
if generator is not None:
generator.set_offset(generator.get_offset() - 4)
# Record the index of the request that should not be sampled,
# so that we could clear the sampled tokens before returning.
discard_sampled_tokens_req_indices.append(i)
# NOTE: GPU -> CPU Sync happens here.
# Move as many CPU operations as possible before this sync point.
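Note: a worked example (made-up numbers) of the vectorized partial-prefill check introduced above. A request is still chunk-prefilling when its computed plus newly scheduled tokens do not yet cover the full prompt, and those are the rows whose generators get rewound and whose sampled tokens are later masked out.

import numpy as np

num_computed_tokens_cpu = np.array([100, 32, 7])
num_scheduled_tokens_np = np.array([ 28, 16, 1])
num_tokens_no_spec      = np.array([128, 64, 8])

chunked_prefilling = (num_computed_tokens_cpu + num_scheduled_tokens_np
                      < num_tokens_no_spec)
print(chunked_prefilling)   # [False  True False]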
@@ -1700,16 +1699,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
max_gen_len = sampled_token_ids.shape[-1]
if max_gen_len == 1:
# No spec decode tokens.
valid_sampled_token_ids = self._to_list(sampled_token_ids)
sampled_token_ids_np = self._to_numpy(sampled_token_ids)
num_sampled_tokens = (~chunked_prefilling).astype(np.int32)
else:
# Includes spec decode tokens.
valid_sampled_token_ids = self.rejection_sampler.parse_output(
sampled_token_ids,
sampled_token_ids_np = sampled_token_ids.cpu().numpy()
num_sampled_tokens = self.rejection_sampler.parse_output(
sampled_token_ids_np,
self.input_batch.vocab_size,
)
# Mask out the sampled tokens that should not be sampled.
for i in discard_sampled_tokens_req_indices:
valid_sampled_token_ids[i].clear()
num_sampled_tokens *= ~chunked_prefilling
# Cache the sampled tokens in the model runner, so that the scheduler
# doesn't need to send them back.
@@ -1717,9 +1716,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# the sampled tokens back, because there's no direct communication
# between the first-stage worker and the last-stage worker.
req_ids = self.input_batch.req_ids
for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
if not sampled_ids:
for req_idx in range(num_reqs):
num_sampled = num_sampled_tokens[req_idx]
if num_sampled == 0:
continue
sampled_ids = sampled_token_ids_np[req_idx][:num_sampled].tolist()
start_idx = self.input_batch.num_tokens_no_spec[req_idx]
end_idx = start_idx + len(sampled_ids)
@@ -1740,7 +1741,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
assert spec_decode_common_attn_metadata is not None
self._draft_token_ids = self.propose_draft_token_ids(
scheduler_output,
valid_sampled_token_ids,
num_sampled_tokens,
sampled_token_ids_np,
sampling_metadata,
hidden_states,
sample_hidden_states,
@@ -1754,7 +1756,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
return ModelRunnerOutput(
req_ids=self.input_batch.req_ids,
req_id_to_index=self.input_batch.req_id_to_index,
sampled_token_ids=valid_sampled_token_ids,
num_sampled_tokens=num_sampled_tokens,
sampled_token_ids=sampled_token_ids_np,
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
pooler_output=[],
@@ -1776,7 +1779,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def propose_draft_token_ids(
self,
scheduler_output: "SchedulerOutput",
sampled_token_ids: list[list[int]],
num_sampled_tokens: np.ndarray,
sampled_token_ids: np.ndarray,
sampling_metadata: SamplingMetadata,
hidden_states: torch.Tensor,
sample_hidden_states: torch.Tensor,
@@ -1788,19 +1792,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if self.speculative_config.method == "ngram":
assert isinstance(self.drafter, NgramProposer)
draft_token_ids = self.propose_ngram_draft_token_ids(
sampled_token_ids)
num_sampled_tokens)
elif self.speculative_config.method == "medusa":
assert isinstance(self.drafter, MedusaProposer)
if sample_hidden_states.shape[0] == len(sampled_token_ids):
if sample_hidden_states.shape[0] == len(num_sampled_tokens):
# The input to the target model does not include draft tokens.
hidden_states = sample_hidden_states
else:
indices = []
offset = 0
for num_draft, tokens in zip(
for num_draft, num_sampled in zip(
spec_decode_metadata.num_draft_tokens,
sampled_token_ids):
indices.append(offset + len(tokens) - 1)
num_sampled_tokens):
indices.append(offset + num_sampled - 1)
offset += num_draft + 1
indices = torch.tensor(indices, device=self.device)
hidden_states = sample_hidden_states[indices]
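Note: a small worked example (made-up counts) of the Medusa index computation above. When draft tokens were part of the target-model input, each request occupies num_draft + 1 consecutive positions, and the index picks the hidden state of that request's last sampled token.

num_draft_tokens   = [2, 3]    # per-request draft tokens fed to the target model
num_sampled_tokens = [1, 4]    # 1 = all drafts rejected, 4 = all drafts accepted + bonus

indices = []
offset = 0
for num_draft, num_sampled in zip(num_draft_tokens, num_sampled_tokens):
    indices.append(offset + num_sampled - 1)   # last sampled position of this request
    offset += num_draft + 1
print(indices)   # [0, 6]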
@@ -1813,11 +1817,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
assert isinstance(self.drafter, EagleProposer)
# TODO(woosuk): Refactor the loop.
req_ids = self.input_batch.req_ids
num_reqs = self.input_batch.num_reqs
next_token_ids: list[int] = []
for i, token_ids in enumerate(sampled_token_ids):
if token_ids:
for i in range(num_reqs):
num_sampled = num_sampled_tokens[i]
if num_sampled > 0:
# Common case.
next_token_id = token_ids[-1]
next_token_id = sampled_token_ids[i][num_sampled - 1]
else:
# Partial prefill (rare case).
# Get the next token id from the request state.
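Note: the common case of the per-request loop above could also be written as a NumPy gather; this is a hedged sketch with made-up values, not part of this change, and partial-prefill rows (num_sampled == 0) would still need the fallback to the request state.

import numpy as np

sampled_token_ids  = np.array([[11, 12, 13],
                               [21, -1, -1]])
num_sampled_tokens = np.array([3, 1])

# Last valid token per request (only meaningful where num_sampled_tokens > 0).
next_token_ids = sampled_token_ids[np.arange(len(num_sampled_tokens)),
                                   np.maximum(num_sampled_tokens - 1, 0)]
print(next_token_ids)   # [13 21]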
@@ -1844,13 +1850,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
target_hidden_states = hidden_states[:num_scheduled_tokens]
else:
# TODO(woosuk): Refactor this.
num_draft_tokens = spec_decode_metadata.num_draft_tokens
num_rejected_tokens = [
n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
for i, n in enumerate(num_draft_tokens)
]
num_rejected_tokens_cpu = torch.tensor(num_rejected_tokens,
dtype=torch.int32)
num_draft_tokens = np.asarray(
spec_decode_metadata.num_draft_tokens, dtype=np.int32)
num_accepted_tokens = num_sampled_tokens - 1
num_rejected_tokens = np.clip(num_draft_tokens - num_accepted_tokens,
a_min=0, a_max=None)
num_rejected_tokens_cpu = torch.from_numpy(num_rejected_tokens)
common_attn_metadata, token_indices =\
self.drafter.prepare_inputs(
common_attn_metadata, num_rejected_tokens_cpu)
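Note: a quick numeric check (made-up values) of the rejected-token arithmetic above. With the bonus token, num_sampled = num_accepted + 1, so num_rejected = num_draft - (num_sampled - 1), clipped at zero for requests that were not speculated. np.maximum is used here only as a portable spelling of that clip.

import numpy as np

num_draft_tokens   = np.array([3, 2, 0], dtype=np.int32)
num_sampled_tokens = np.array([2, 3, 1], dtype=np.int32)   # accepted + bonus

num_accepted_tokens = num_sampled_tokens - 1
num_rejected_tokens = np.maximum(num_draft_tokens - num_accepted_tokens, 0)
print(num_rejected_tokens)   # [2 0 0]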
@@ -1881,13 +1887,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def propose_ngram_draft_token_ids(
self,
sampled_token_ids: list[list[int]],
num_sampled_tokens: np.ndarray,
) -> list[list[int]]:
# TODO(woosuk): Optimize.
req_ids = self.input_batch.req_ids
num_reqs = self.input_batch.num_reqs
draft_token_ids: list[list[int]] = []
for i, sampled_ids in enumerate(sampled_token_ids):
num_sampled_ids = len(sampled_ids)
for i in range(num_reqs):
num_sampled_ids = num_sampled_tokens[i]
if not num_sampled_ids:
# Skip speculative decoding.
draft_token_ids.append([])
@@ -3267,7 +3274,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
return kv_cache_spec
def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]:
def _to_numpy(self, sampled_token_ids: torch.Tensor) -> np.ndarray:
# This is a short term mitigation for issue mentioned in
# https://github.com/vllm-project/vllm/issues/22754.
# `tolist` would trigger a CUDA stream-wide sync, which
@@ -3280,4 +3287,4 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
pinned.copy_(sampled_token_ids, non_blocking=True)
self.transfer_event.record()
self.transfer_event.synchronize()
return pinned.tolist()
return pinned.numpy()
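Note: a standalone sketch of the _to_numpy pattern above (hypothetical function name, CUDA assumed): copy the sampled-token tensor into pinned CPU memory asynchronously and wait on a dedicated event instead of letting .tolist() trigger a broader stream sync; the pinned buffer is then viewed as a NumPy array without an extra copy.

import numpy as np
import torch

def to_numpy_via_pinned(sampled_token_ids: torch.Tensor) -> np.ndarray:
    # The model runner keeps the pinned buffer and the event around between
    # steps; they are created inline here only to keep the sketch self-contained.
    pinned = torch.empty(sampled_token_ids.shape,
                         dtype=sampled_token_ids.dtype,
                         device="cpu",
                         pin_memory=True)
    transfer_event = torch.cuda.Event()
    pinned.copy_(sampled_token_ids, non_blocking=True)
    transfer_event.record()
    transfer_event.synchronize()
    return pinned.numpy()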