mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[V1][Spec Decode] Always use argmax for sampling draft tokens (#16899)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
@ -226,7 +226,7 @@ def rejection_sample(
|
||||
is_greedy,
|
||||
max_spec_len,
|
||||
vocab_size,
|
||||
IS_NGRAM=draft_probs is None,
|
||||
NO_DRAFT_PROBS=draft_probs is None,
|
||||
num_warps=1,
|
||||
)
|
||||
return output_token_ids
|
||||
@ -423,7 +423,7 @@ def sample_recovered_tokens(
|
||||
q,
|
||||
vocab_size,
|
||||
triton.next_power_of_2(vocab_size),
|
||||
IS_NGRAM=draft_probs is None,
|
||||
NO_DRAFT_PROBS=draft_probs is None,
|
||||
)
|
||||
return recovered_token_ids
|
||||
|
||||
@ -490,7 +490,7 @@ def rejection_random_sample_kernel(
|
||||
is_greedy_ptr, # [batch_size]
|
||||
max_spec_len,
|
||||
vocab_size,
|
||||
IS_NGRAM: tl.constexpr,
|
||||
NO_DRAFT_PROBS: tl.constexpr,
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
is_greedy = tl.load(is_greedy_ptr + req_idx)
|
||||
@ -509,7 +509,7 @@ def rejection_random_sample_kernel(
|
||||
for pos in range(num_draft_tokens):
|
||||
if not rejected:
|
||||
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
|
||||
if IS_NGRAM:
|
||||
if NO_DRAFT_PROBS:
|
||||
draft_prob = 1
|
||||
else:
|
||||
draft_prob = tl.load(draft_probs_ptr +
|
||||
@ -575,7 +575,7 @@ def sample_recovered_tokens_kernel(
|
||||
q_ptr, # [batch_size, vocab_size]
|
||||
vocab_size,
|
||||
PADDED_VOCAB_SIZE: tl.constexpr,
|
||||
IS_NGRAM: tl.constexpr,
|
||||
NO_DRAFT_PROBS: tl.constexpr,
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
if req_idx == 0:
|
||||
@ -591,7 +591,7 @@ def sample_recovered_tokens_kernel(
|
||||
return
|
||||
|
||||
vocab_offset = tl.arange(0, PADDED_VOCAB_SIZE)
|
||||
if IS_NGRAM:
|
||||
if NO_DRAFT_PROBS:
|
||||
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
|
||||
orig_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size +
|
||||
draft_token_id)
|
||||
@ -624,7 +624,7 @@ def sample_recovered_tokens_kernel(
|
||||
recovered_id = tl.argmax(prob / q, axis=-1)
|
||||
tl.store(output_token_ids_ptr + start_idx + pos, recovered_id)
|
||||
|
||||
if IS_NGRAM:
|
||||
if NO_DRAFT_PROBS:
|
||||
# Restore the original probability.
|
||||
tl.store(
|
||||
target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id,
|
||||
|
@ -51,7 +51,7 @@ class EagleProposer:
|
||||
# [batch_size, max_num_blocks_per_req]
|
||||
block_table: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> torch.Tensor:
|
||||
num_tokens = target_token_ids.shape[0]
|
||||
batch_size = next_token_ids.shape[0]
|
||||
last_token_indices = cu_num_tokens[1:] - 1
|
||||
@ -94,17 +94,15 @@ class EagleProposer:
|
||||
)
|
||||
sample_hidden_states = hidden_states[last_token_indices]
|
||||
logits = self.model.compute_logits(sample_hidden_states, None)
|
||||
draft_token_ids, draft_probs = compute_probs_and_sample_next_token(
|
||||
logits, sampling_metadata)
|
||||
draft_token_ids = logits.argmax(dim=-1)
|
||||
|
||||
# Early exit if there is only one draft token to be generated.
|
||||
if self.num_speculative_tokens == 1:
|
||||
# [batch_size, 1] and [batch_size, 1, vocab_size]
|
||||
return draft_token_ids.view(-1, 1), draft_probs.unsqueeze(dim=1)
|
||||
# [batch_size, 1]
|
||||
return draft_token_ids.view(-1, 1)
|
||||
|
||||
# Generate the remaining draft tokens.
|
||||
draft_token_ids_list = [draft_token_ids]
|
||||
draft_probs_list = [draft_probs]
|
||||
|
||||
positions = target_positions[last_token_indices]
|
||||
hidden_states = sample_hidden_states
|
||||
@ -159,16 +157,12 @@ class EagleProposer:
|
||||
positions=clamped_positions,
|
||||
)
|
||||
logits = self.model.compute_logits(hidden_states, None)
|
||||
draft_token_ids, probs = compute_probs_and_sample_next_token(
|
||||
logits, sampling_metadata)
|
||||
draft_token_ids = logits.argmax(dim=-1)
|
||||
draft_token_ids_list.append(draft_token_ids)
|
||||
draft_probs_list.append(probs)
|
||||
|
||||
# [batch_size, num_speculative_tokens]
|
||||
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
|
||||
# [batch_size, num_speculative_tokens, vocab_size]
|
||||
draft_probs = torch.stack(draft_probs_list, dim=1)
|
||||
return draft_token_ids, draft_probs
|
||||
return draft_token_ids
|
||||
|
||||
@staticmethod
|
||||
def prepare_inputs(
|
||||
@ -238,6 +232,10 @@ class EagleProposer:
|
||||
self.model.lm_head = target_model.lm_head
|
||||
|
||||
|
||||
# NOTE(woosuk): Currently, the below code is not used and we always use argmax
|
||||
# to sample the draft tokens. We will use this after we find a way to manage
|
||||
# the draft prob tensor.
|
||||
# Refer to https://github.com/vllm-project/vllm/pull/16899 for the details.
|
||||
# FIXME(woosuk): The logic here is duplicated with the main sampling code.
|
||||
# We should refactor this to reuse the same sampling implementation.
|
||||
def compute_probs_and_sample_next_token(
|
||||
|
@ -1230,7 +1230,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
target_hidden_states = hidden_states[token_indices]
|
||||
target_slot_mapping = attn_metadata.slot_mapping[token_indices]
|
||||
|
||||
draft_token_ids, draft_probs = self.drafter.propose(
|
||||
draft_token_ids = self.drafter.propose(
|
||||
target_token_ids=target_token_ids,
|
||||
target_positions=target_positions,
|
||||
target_hidden_states=target_hidden_states,
|
||||
@ -1241,9 +1241,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
sampling_metadata=sampling_metadata,
|
||||
)
|
||||
spec_token_ids = draft_token_ids.tolist()
|
||||
# TODO(woosuk): Cache draft_probs and use it for rejection sampling
|
||||
# in the next step.
|
||||
del draft_probs
|
||||
|
||||
# Clear KVConnector state after all KVs are generated.
|
||||
if has_kv_transfer_group():
|
||||
|
Reference in New Issue
Block a user