mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 23:03:52 +08:00
Compare commits
2 Commits
v0.11.0rc3
...
woosuk/sam
Author | SHA1 | Date | |
---|---|---|---|
fefed35cee | |||
901afda905 |
@ -842,6 +842,7 @@ class Scheduler(SchedulerInterface):
|
||||
scheduler_output: SchedulerOutput,
|
||||
model_runner_output: ModelRunnerOutput,
|
||||
) -> dict[int, EngineCoreOutputs]:
|
||||
num_sampled_tokens = model_runner_output.num_sampled_tokens
|
||||
sampled_token_ids = model_runner_output.sampled_token_ids
|
||||
logprobs = model_runner_output.logprobs
|
||||
prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
|
||||
@ -849,6 +850,10 @@ class Scheduler(SchedulerInterface):
|
||||
pooler_outputs = model_runner_output.pooler_output
|
||||
num_nans_in_logits = model_runner_output.num_nans_in_logits
|
||||
|
||||
if sampled_token_ids is not None:
|
||||
# Optimization: Avoid a .tolist() call for each request.
|
||||
sampled_token_ids = sampled_token_ids.tolist()
|
||||
|
||||
outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
|
||||
spec_decoding_stats: Optional[SpecDecodingStats] = None
|
||||
|
||||
@ -867,14 +872,19 @@ class Scheduler(SchedulerInterface):
|
||||
continue
|
||||
|
||||
req_index = model_runner_output.req_id_to_index[req_id]
|
||||
generated_token_ids = sampled_token_ids[
|
||||
req_index] if sampled_token_ids else []
|
||||
generated_token_ids: list[int] = []
|
||||
if sampled_token_ids is not None:
|
||||
assert num_sampled_tokens is not None
|
||||
num_sampled = num_sampled_tokens[req_index]
|
||||
if num_sampled > 0:
|
||||
generated_token_ids = sampled_token_ids[
|
||||
req_index][:num_sampled]
|
||||
|
||||
scheduled_spec_token_ids = (
|
||||
scheduler_output.scheduled_spec_decode_tokens.get(req_id))
|
||||
if scheduled_spec_token_ids:
|
||||
num_draft_tokens = len(scheduled_spec_token_ids)
|
||||
num_accepted = len(generated_token_ids) - 1
|
||||
num_accepted = num_sampled - 1
|
||||
num_rejected = num_draft_tokens - num_accepted
|
||||
# num_computed_tokens represents the number of tokens
|
||||
# processed in the current step, considering scheduled
|
||||
|
@ -4,6 +4,7 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import NamedTuple, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
@ -88,11 +89,12 @@ class ModelRunnerOutput:
|
||||
# req_id -> index
|
||||
req_id_to_index: dict[str, int]
|
||||
|
||||
# num_reqs x num_generated_tokens
|
||||
# num_generated_tokens is the number of tokens
|
||||
# generated in the current step. It can be different for
|
||||
# each request due to speculative/jump decoding.
|
||||
sampled_token_ids: list[list[int]]
|
||||
# [num_reqs]
|
||||
# Number of tokens sampled in the current step. Each request may generate
|
||||
# different number of tokens due to chunked prefilling and spec decoding.
|
||||
num_sampled_tokens: Optional[np.ndarray]
|
||||
# [num_reqs, max_num_sampled_tokens]
|
||||
sampled_token_ids: Optional[np.ndarray]
|
||||
|
||||
# [num_reqs, max_num_logprobs + 1]
|
||||
# [num_reqs, max_num_logprobs + 1]
|
||||
@ -123,10 +125,13 @@ class DraftTokenIds:
|
||||
draft_token_ids: list[list[int]]
|
||||
|
||||
|
||||
EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[],
|
||||
req_id_to_index={},
|
||||
sampled_token_ids=[],
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
num_nans_in_logits=None)
|
||||
EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
|
||||
req_ids=[],
|
||||
req_id_to_index={},
|
||||
num_sampled_tokens=None,
|
||||
sampled_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
num_nans_in_logits=None,
|
||||
)
|
||||
|
@ -2,6 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
@ -106,9 +107,9 @@ class RejectionSampler(nn.Module):
|
||||
|
||||
@staticmethod
|
||||
def parse_output(
|
||||
output_token_ids: torch.Tensor,
|
||||
output_token_ids: np.ndarray,
|
||||
vocab_size: int,
|
||||
) -> list[list[int]]:
|
||||
) -> np.ndarray:
|
||||
"""Parse the output of the rejection sampler.
|
||||
|
||||
Args:
|
||||
@ -119,17 +120,14 @@ class RejectionSampler(nn.Module):
|
||||
vocab_size: The size of the vocabulary.
|
||||
|
||||
Returns:
|
||||
A list of lists of token IDs.
|
||||
A Numpy array of the number of valid sampled tokens.
|
||||
"""
|
||||
output_token_ids_np = output_token_ids.cpu().numpy()
|
||||
# Create mask for valid tokens.
|
||||
valid_mask = ((output_token_ids_np != PLACEHOLDER_TOKEN_ID) &
|
||||
(output_token_ids_np < vocab_size))
|
||||
outputs = [
|
||||
row[valid_mask[i]].tolist()
|
||||
for i, row in enumerate(output_token_ids_np)
|
||||
]
|
||||
return outputs
|
||||
valid_mask = ((output_token_ids != PLACEHOLDER_TOKEN_ID) &
|
||||
(output_token_ids < vocab_size))
|
||||
# Get the number until the first valid_mask=False.
|
||||
num_sampled_tokens = np.cumprod(valid_mask, axis=1).sum(axis=1)
|
||||
return num_sampled_tokens
|
||||
|
||||
|
||||
def rejection_sample(
|
||||
|
@ -1456,7 +1456,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
return ModelRunnerOutput(
|
||||
req_ids=self.input_batch.req_ids,
|
||||
req_id_to_index=self.input_batch.req_id_to_index,
|
||||
sampled_token_ids=[],
|
||||
num_sampled_tokens=None,
|
||||
sampled_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=pooler_output,
|
||||
@ -1665,23 +1666,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
if envs.VLLM_COMPUTE_NANS_IN_LOGITS:
|
||||
num_nans_in_logits = self._get_nans_in_logits(logits)
|
||||
|
||||
# TODO(woosuk): The following loop can be slow since it iterates over
|
||||
# the requests one by one. Optimize.
|
||||
discard_sampled_tokens_req_indices = []
|
||||
for i, req_id in enumerate(self.input_batch.req_ids):
|
||||
req_state = self.requests[req_id]
|
||||
seq_len = (req_state.num_computed_tokens +
|
||||
scheduler_output.num_scheduled_tokens[req_id])
|
||||
if seq_len < req_state.num_tokens:
|
||||
# Post-processing for chunked prefill.
|
||||
num_reqs = self.input_batch.num_reqs
|
||||
chunked_prefilling = (
|
||||
self.input_batch.num_computed_tokens_cpu[:num_reqs] +
|
||||
num_scheduled_tokens_np
|
||||
< self.input_batch.num_tokens_no_spec[:num_reqs])
|
||||
if self.input_batch.generators:
|
||||
chunked_prefill_indices = np.where(chunked_prefilling)[0]
|
||||
for i in chunked_prefill_indices:
|
||||
# Ignore the sampled token for partial prefills.
|
||||
# Rewind the generator state as if the token was not sampled.
|
||||
# This relies on cuda-specific torch-internal impl details
|
||||
generator = self.input_batch.generators.get(i)
|
||||
if generator is not None:
|
||||
generator.set_offset(generator.get_offset() - 4)
|
||||
# Record the index of the request that should not be sampled,
|
||||
# so that we could clear the sampled tokens before returning.
|
||||
discard_sampled_tokens_req_indices.append(i)
|
||||
|
||||
# NOTE: GPU -> CPU Sync happens here.
|
||||
# Move as many CPU operations as possible before this sync point.
|
||||
@ -1700,16 +1699,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
max_gen_len = sampled_token_ids.shape[-1]
|
||||
if max_gen_len == 1:
|
||||
# No spec decode tokens.
|
||||
valid_sampled_token_ids = self._to_list(sampled_token_ids)
|
||||
sampled_token_ids_np = self._to_numpy(sampled_token_ids)
|
||||
num_sampled_tokens = (~chunked_prefilling).astype(np.int32)
|
||||
else:
|
||||
# Includes spec decode tokens.
|
||||
valid_sampled_token_ids = self.rejection_sampler.parse_output(
|
||||
sampled_token_ids,
|
||||
sampled_token_ids_np = sampled_token_ids.cpu().numpy()
|
||||
num_sampled_tokens = self.rejection_sampler.parse_output(
|
||||
sampled_token_ids_np,
|
||||
self.input_batch.vocab_size,
|
||||
)
|
||||
# Mask out the sampled tokens that should not be sampled.
|
||||
for i in discard_sampled_tokens_req_indices:
|
||||
valid_sampled_token_ids[i].clear()
|
||||
num_sampled_tokens *= ~chunked_prefilling
|
||||
|
||||
# Cache the sampled tokens in the model runner, so that the scheduler
|
||||
# doesn't need to send them back.
|
||||
@ -1717,9 +1716,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
# the sampled tokens back, because there's no direct communication
|
||||
# between the first-stage worker and the last-stage worker.
|
||||
req_ids = self.input_batch.req_ids
|
||||
for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
|
||||
if not sampled_ids:
|
||||
for req_idx in range(num_reqs):
|
||||
num_sampled = num_sampled_tokens[req_idx]
|
||||
if num_sampled == 0:
|
||||
continue
|
||||
sampled_ids = sampled_token_ids_np[req_idx][:num_sampled].tolist()
|
||||
|
||||
start_idx = self.input_batch.num_tokens_no_spec[req_idx]
|
||||
end_idx = start_idx + len(sampled_ids)
|
||||
@ -1740,7 +1741,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
assert spec_decode_common_attn_metadata is not None
|
||||
self._draft_token_ids = self.propose_draft_token_ids(
|
||||
scheduler_output,
|
||||
valid_sampled_token_ids,
|
||||
num_sampled_tokens,
|
||||
sampled_token_ids_np,
|
||||
sampling_metadata,
|
||||
hidden_states,
|
||||
sample_hidden_states,
|
||||
@ -1754,7 +1756,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
return ModelRunnerOutput(
|
||||
req_ids=self.input_batch.req_ids,
|
||||
req_id_to_index=self.input_batch.req_id_to_index,
|
||||
sampled_token_ids=valid_sampled_token_ids,
|
||||
num_sampled_tokens=num_sampled_tokens,
|
||||
sampled_token_ids=sampled_token_ids_np,
|
||||
logprobs=logprobs_lists,
|
||||
prompt_logprobs_dict=prompt_logprobs_dict,
|
||||
pooler_output=[],
|
||||
@ -1776,7 +1779,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
def propose_draft_token_ids(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
sampled_token_ids: list[list[int]],
|
||||
num_sampled_tokens: np.ndarray,
|
||||
sampled_token_ids: np.ndarray,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
hidden_states: torch.Tensor,
|
||||
sample_hidden_states: torch.Tensor,
|
||||
@ -1788,19 +1792,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
if self.speculative_config.method == "ngram":
|
||||
assert isinstance(self.drafter, NgramProposer)
|
||||
draft_token_ids = self.propose_ngram_draft_token_ids(
|
||||
sampled_token_ids)
|
||||
num_sampled_tokens)
|
||||
elif self.speculative_config.method == "medusa":
|
||||
assert isinstance(self.drafter, MedusaProposer)
|
||||
if sample_hidden_states.shape[0] == len(sampled_token_ids):
|
||||
if sample_hidden_states.shape[0] == len(num_sampled_tokens):
|
||||
# The input to the target model does not include draft tokens.
|
||||
hidden_states = sample_hidden_states
|
||||
else:
|
||||
indices = []
|
||||
offset = 0
|
||||
for num_draft, tokens in zip(
|
||||
for num_draft, num_sampled in zip(
|
||||
spec_decode_metadata.num_draft_tokens,
|
||||
sampled_token_ids):
|
||||
indices.append(offset + len(tokens) - 1)
|
||||
num_sampled_tokens):
|
||||
indices.append(offset + num_sampled - 1)
|
||||
offset += num_draft + 1
|
||||
indices = torch.tensor(indices, device=self.device)
|
||||
hidden_states = sample_hidden_states[indices]
|
||||
@ -1813,11 +1817,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
assert isinstance(self.drafter, EagleProposer)
|
||||
# TODO(woosuk): Refactor the loop.
|
||||
req_ids = self.input_batch.req_ids
|
||||
num_reqs = self.input_batch.num_reqs
|
||||
next_token_ids: list[int] = []
|
||||
for i, token_ids in enumerate(sampled_token_ids):
|
||||
if token_ids:
|
||||
for i in range(num_reqs):
|
||||
num_sampled = num_sampled_tokens[i]
|
||||
if num_sampled > 0:
|
||||
# Common case.
|
||||
next_token_id = token_ids[-1]
|
||||
next_token_id = sampled_token_ids[i][num_sampled - 1]
|
||||
else:
|
||||
# Partial prefill (rare case).
|
||||
# Get the next token id from the request state.
|
||||
@ -1844,13 +1850,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
target_hidden_states = hidden_states[:num_scheduled_tokens]
|
||||
else:
|
||||
# TODO(woosuk): Refactor this.
|
||||
num_draft_tokens = spec_decode_metadata.num_draft_tokens
|
||||
num_rejected_tokens = [
|
||||
n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
|
||||
for i, n in enumerate(num_draft_tokens)
|
||||
]
|
||||
num_rejected_tokens_cpu = torch.tensor(num_rejected_tokens,
|
||||
dtype=torch.int32)
|
||||
num_draft_tokens = np.asarray(
|
||||
spec_decode_metadata.num_draft_tokens, dtype=np.int32)
|
||||
num_accepted_tokens = num_sampled_tokens - 1
|
||||
num_rejected_tokens = np.clip(num_draft_tokens -
|
||||
num_accepted_tokens,
|
||||
a_min=0)
|
||||
num_rejected_tokens_cpu = torch.from_numpy(num_rejected_tokens)
|
||||
common_attn_metadata, token_indices =\
|
||||
self.drafter.prepare_inputs(
|
||||
common_attn_metadata, num_rejected_tokens_cpu)
|
||||
@ -1881,13 +1887,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
|
||||
def propose_ngram_draft_token_ids(
|
||||
self,
|
||||
sampled_token_ids: list[list[int]],
|
||||
num_sampled_tokens: np.ndarray,
|
||||
) -> list[list[int]]:
|
||||
# TODO(woosuk): Optimize.
|
||||
req_ids = self.input_batch.req_ids
|
||||
num_reqs = self.input_batch.num_reqs
|
||||
draft_token_ids: list[list[int]] = []
|
||||
for i, sampled_ids in enumerate(sampled_token_ids):
|
||||
num_sampled_ids = len(sampled_ids)
|
||||
for i in range(num_reqs):
|
||||
num_sampled_ids = num_sampled_tokens[i]
|
||||
if not num_sampled_ids:
|
||||
# Skip speculative decoding.
|
||||
draft_token_ids.append([])
|
||||
@ -3267,7 +3274,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
|
||||
return kv_cache_spec
|
||||
|
||||
def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]:
|
||||
def _to_numpy(self, sampled_token_ids: torch.Tensor) -> np.ndarray:
|
||||
# This is a short term mitigation for issue mentioned in
|
||||
# https://github.com/vllm-project/vllm/issues/22754.
|
||||
# `tolist` would trigger a cuda wise stream sync, which
|
||||
@ -3280,4 +3287,4 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
pinned.copy_(sampled_token_ids, non_blocking=True)
|
||||
self.transfer_event.record()
|
||||
self.transfer_event.synchronize()
|
||||
return pinned.tolist()
|
||||
return pinned.numpy()
|
||||
|
Reference in New Issue
Block a user