[V1][Spec Decode] Always use argmax for sampling draft tokens (#16899)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Woosuk Kwon
2025-04-23 14:57:43 -07:00
committed by GitHub
parent 32d4b669d0
commit 41fb013d29
3 changed files with 18 additions and 23 deletions
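
The gist of the change: instead of sampling each draft token from a probability distribution (and having to return and cache a draft-prob tensor), the EAGLE proposer now takes the argmax of the drafter's logits at every speculative step. A minimal sketch of that behavior in plain PyTorch; the helper name propose_greedy and the per-step logits list are illustrative assumptions, not vLLM APIs:

import torch

def propose_greedy(logits_per_step: list[torch.Tensor]) -> torch.Tensor:
    """Greedy (argmax) draft-token selection, one drafter step at a time.

    Each element of logits_per_step has shape [batch_size, vocab_size].
    Returns draft token ids of shape [batch_size, num_speculative_tokens];
    no [batch_size, num_speculative_tokens, vocab_size] probability tensor
    is materialized along the way.
    """
    draft_token_ids_list = [logits.argmax(dim=-1) for logits in logits_per_step]
    return torch.stack(draft_token_ids_list, dim=1)

if __name__ == "__main__":
    torch.manual_seed(0)
    per_step_logits = [torch.randn(4, 32_000) for _ in range(3)]  # batch of 4, 3 steps
    print(propose_greedy(per_step_logits).shape)  # torch.Size([4, 3])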

View File

@@ -226,7 +226,7 @@ def rejection_sample(
is_greedy,
max_spec_len,
vocab_size,
-IS_NGRAM=draft_probs is None,
+NO_DRAFT_PROBS=draft_probs is None,
num_warps=1,
)
return output_token_ids
@@ -423,7 +423,7 @@ def sample_recovered_tokens(
q,
vocab_size,
triton.next_power_of_2(vocab_size),
-IS_NGRAM=draft_probs is None,
+NO_DRAFT_PROBS=draft_probs is None,
)
return recovered_token_ids
@@ -490,7 +490,7 @@ def rejection_random_sample_kernel(
is_greedy_ptr, # [batch_size]
max_spec_len,
vocab_size,
-IS_NGRAM: tl.constexpr,
+NO_DRAFT_PROBS: tl.constexpr,
):
req_idx = tl.program_id(0)
is_greedy = tl.load(is_greedy_ptr + req_idx)
@@ -509,7 +509,7 @@ def rejection_random_sample_kernel(
for pos in range(num_draft_tokens):
if not rejected:
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
-if IS_NGRAM:
+if NO_DRAFT_PROBS:
draft_prob = 1
else:
draft_prob = tl.load(draft_probs_ptr +
@@ -575,7 +575,7 @@ def sample_recovered_tokens_kernel(
q_ptr, # [batch_size, vocab_size]
vocab_size,
PADDED_VOCAB_SIZE: tl.constexpr,
-IS_NGRAM: tl.constexpr,
+NO_DRAFT_PROBS: tl.constexpr,
):
req_idx = tl.program_id(0)
if req_idx == 0:
@@ -591,7 +591,7 @@ def sample_recovered_tokens_kernel(
return
vocab_offset = tl.arange(0, PADDED_VOCAB_SIZE)
-if IS_NGRAM:
+if NO_DRAFT_PROBS:
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
orig_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size +
draft_token_id)
@@ -624,7 +624,7 @@ def sample_recovered_tokens_kernel(
recovered_id = tl.argmax(prob / q, axis=-1)
tl.store(output_token_ids_ptr + start_idx + pos, recovered_id)
-if IS_NGRAM:
+if NO_DRAFT_PROBS:
# Restore the original probability.
tl.store(
target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id,
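
A note on the rename above: the flag now covers every drafter that supplies no probabilities (previously only the n-gram path, now also the argmax-based EAGLE path). In that case the kernel treats draft_prob as 1, so a drafted token is accepted whenever a uniform sample falls below the target model's probability for it. A rough PyTorch sketch of that acceptance rule, not the Triton kernel itself; the tensor names and shapes are illustrative assumptions:

import torch

def accept_without_draft_probs(
    target_probs: torch.Tensor,     # [num_draft_tokens, vocab_size]
    draft_token_ids: torch.Tensor,  # [num_draft_tokens], int64
    uniform_probs: torch.Tensor,    # [num_draft_tokens], samples from U(0, 1)
) -> torch.Tensor:
    """Acceptance mask when the drafter provides no probabilities.

    With draft_prob fixed at 1, the usual rejection-sampling test
    u < target_prob / draft_prob reduces to u < target_prob.
    """
    p_target = target_probs.gather(1, draft_token_ids.unsqueeze(1)).squeeze(1)
    return uniform_probs < p_target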

View File

@@ -51,7 +51,7 @@ class EagleProposer:
# [batch_size, max_num_blocks_per_req]
block_table: torch.Tensor,
sampling_metadata: SamplingMetadata,
-) -> tuple[torch.Tensor, torch.Tensor]:
+) -> torch.Tensor:
num_tokens = target_token_ids.shape[0]
batch_size = next_token_ids.shape[0]
last_token_indices = cu_num_tokens[1:] - 1
@@ -94,17 +94,15 @@ class EagleProposer:
)
sample_hidden_states = hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
-draft_token_ids, draft_probs = compute_probs_and_sample_next_token(
-logits, sampling_metadata)
+draft_token_ids = logits.argmax(dim=-1)
# Early exit if there is only one draft token to be generated.
if self.num_speculative_tokens == 1:
-# [batch_size, 1] and [batch_size, 1, vocab_size]
-return draft_token_ids.view(-1, 1), draft_probs.unsqueeze(dim=1)
+# [batch_size, 1]
+return draft_token_ids.view(-1, 1)
# Generate the remaining draft tokens.
draft_token_ids_list = [draft_token_ids]
-draft_probs_list = [draft_probs]
positions = target_positions[last_token_indices]
hidden_states = sample_hidden_states
@@ -159,16 +157,12 @@
positions=clamped_positions,
)
logits = self.model.compute_logits(hidden_states, None)
-draft_token_ids, probs = compute_probs_and_sample_next_token(
-logits, sampling_metadata)
+draft_token_ids = logits.argmax(dim=-1)
draft_token_ids_list.append(draft_token_ids)
-draft_probs_list.append(probs)
# [batch_size, num_speculative_tokens]
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
-# [batch_size, num_speculative_tokens, vocab_size]
-draft_probs = torch.stack(draft_probs_list, dim=1)
-return draft_token_ids, draft_probs
+return draft_token_ids
@staticmethod
def prepare_inputs(
@@ -238,6 +232,10 @@ class EagleProposer:
self.model.lm_head = target_model.lm_head
+# NOTE(woosuk): Currently, the below code is not used and we always use argmax
+# to sample the draft tokens. We will use this after we find a way to manage
+# the draft prob tensor.
+# Refer to https://github.com/vllm-project/vllm/pull/16899 for the details.
# FIXME(woosuk): The logic here is duplicated with the main sampling code.
# We should refactor this to reuse the same sampling implementation.
def compute_probs_and_sample_next_token(
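
The NOTE added above keeps compute_probs_and_sample_next_token around for later, once the draft-prob tensor can be managed; the removed code stacked per-step probabilities into a [batch_size, num_speculative_tokens, vocab_size] tensor that would have to survive until the next scheduling step. A back-of-the-envelope size estimate, with the batch size, speculation depth, vocabulary size, and dtype all being illustrative assumptions:

# Rough size of a [batch_size, num_speculative_tokens, vocab_size] draft-prob tensor.
batch_size = 256                # assumed number of in-flight requests
num_speculative_tokens = 4      # assumed speculation depth
vocab_size = 128_256            # assumed vocabulary size (Llama-3-sized)
bytes_per_element = 4           # assumed fp32 probabilities

size_gib = batch_size * num_speculative_tokens * vocab_size * bytes_per_element / 2**30
print(f"{size_gib:.2f} GiB")    # about 0.49 GiB held across steps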

View File

@@ -1230,7 +1230,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
target_hidden_states = hidden_states[token_indices]
target_slot_mapping = attn_metadata.slot_mapping[token_indices]
-draft_token_ids, draft_probs = self.drafter.propose(
+draft_token_ids = self.drafter.propose(
target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
@@ -1241,9 +1241,6 @@
sampling_metadata=sampling_metadata,
)
spec_token_ids = draft_token_ids.tolist()
-# TODO(woosuk): Cache draft_probs and use it for rejection sampling
-# in the next step.
-del draft_probs
# Clear KVConnector state after all KVs are generated.
if has_kv_transfer_group():