[V1][Spec Decode] Always use argmax for sampling draft tokens (#16899)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
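The change drops the sampled draft-probability path in the EAGLE proposer: draft tokens are now taken greedily with argmax over the logits, so no draft-probability tensor has to be kept around for rejection sampling. A minimal sketch of that sampling step with dummy tensors (illustration only, not the vLLM code):

import torch

# Hypothetical sizes for illustration.
batch_size, vocab_size = 4, 32_000
logits = torch.randn(batch_size, vocab_size)

# Greedy draft sampling: one token id per request, and no per-token
# probability tensor that would otherwise need to be cached for the
# rejection sampler.
draft_token_ids = logits.argmax(dim=-1)  # shape: [batch_size]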
@@ -226,7 +226,7 @@ def rejection_sample(
         is_greedy,
         max_spec_len,
         vocab_size,
-        IS_NGRAM=draft_probs is None,
+        NO_DRAFT_PROBS=draft_probs is None,
         num_warps=1,
     )
     return output_token_ids
@@ -423,7 +423,7 @@ def sample_recovered_tokens(
         q,
         vocab_size,
         triton.next_power_of_2(vocab_size),
-        IS_NGRAM=draft_probs is None,
+        NO_DRAFT_PROBS=draft_probs is None,
     )
     return recovered_token_ids

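The two launch sites above only rename the compile-time flag: it is derived from draft_probs is None, a condition that is no longer specific to the n-gram proposer once EAGLE drafts with argmax, so NO_DRAFT_PROBS describes it more accurately than IS_NGRAM. Because the flag is a tl.constexpr, Triton specializes the kernel per value and the branch costs nothing at run time. A toy illustration of that pattern (not the vLLM kernel):

import triton
import triton.language as tl


@triton.jit
def toy_kernel(x_ptr, out_ptr, NO_DRAFT_PROBS: tl.constexpr):
    # NO_DRAFT_PROBS is a constexpr, so this branch is resolved at compile
    # time and the generated kernel contains only one of the two paths.
    idx = tl.program_id(0)
    val = tl.load(x_ptr + idx)
    if NO_DRAFT_PROBS:
        val = val * 2.0
    tl.store(out_ptr + idx, val)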
@@ -490,7 +490,7 @@ def rejection_random_sample_kernel(
     is_greedy_ptr,  # [batch_size]
     max_spec_len,
     vocab_size,
-    IS_NGRAM: tl.constexpr,
+    NO_DRAFT_PROBS: tl.constexpr,
 ):
     req_idx = tl.program_id(0)
     is_greedy = tl.load(is_greedy_ptr + req_idx)
@@ -509,7 +509,7 @@ def rejection_random_sample_kernel(
     for pos in range(num_draft_tokens):
         if not rejected:
             draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
-            if IS_NGRAM:
+            if NO_DRAFT_PROBS:
                 draft_prob = 1
             else:
                 draft_prob = tl.load(draft_probs_ptr +
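With NO_DRAFT_PROBS the draft probability is fixed to 1, so the acceptance test degenerates to comparing a uniform sample against the target probability alone. A rough scalar sketch of that rule, assumed from the standard speculative-decoding acceptance criterion rather than copied from the Triton kernel:

def accept_draft(target_prob: float, draft_prob: float, uniform: float) -> bool:
    # Accept the draft token with probability min(1, target_prob / draft_prob).
    # When no draft probabilities exist, draft_prob is 1 and the test reduces
    # to uniform < target_prob.
    return uniform < target_prob / draft_prob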
@@ -575,7 +575,7 @@ def sample_recovered_tokens_kernel(
     q_ptr,  # [batch_size, vocab_size]
     vocab_size,
     PADDED_VOCAB_SIZE: tl.constexpr,
-    IS_NGRAM: tl.constexpr,
+    NO_DRAFT_PROBS: tl.constexpr,
 ):
     req_idx = tl.program_id(0)
     if req_idx == 0:
@@ -591,7 +591,7 @@ def sample_recovered_tokens_kernel(
         return

     vocab_offset = tl.arange(0, PADDED_VOCAB_SIZE)
-    if IS_NGRAM:
+    if NO_DRAFT_PROBS:
         draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
         orig_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size +
                             draft_token_id)
@@ -624,7 +624,7 @@ def sample_recovered_tokens_kernel(
     recovered_id = tl.argmax(prob / q, axis=-1)
     tl.store(output_token_ids_ptr + start_idx + pos, recovered_id)

-    if IS_NGRAM:
+    if NO_DRAFT_PROBS:
         # Restore the original probability.
         tl.store(
             target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id,
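The recovered token is drawn with recovered_id = tl.argmax(prob / q, axis=-1). Assuming q holds Exponential(1) noise per vocabulary entry (the exponential-race form of the Gumbel-max trick), argmax(prob / q) picks index i with probability proportional to prob[i], so the kernel samples from the adjusted distribution without an explicit multinomial draw. A small PyTorch sketch of the same trick:

import torch

prob = torch.tensor([0.1, 0.6, 0.3])          # some categorical distribution
q = torch.empty_like(prob).exponential_(1.0)  # Exponential(1) noise
recovered_id = torch.argmax(prob / q)         # distributed as Categorical(prob)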
@@ -51,7 +51,7 @@ class EagleProposer:
         # [batch_size, max_num_blocks_per_req]
         block_table: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
+    ) -> torch.Tensor:
         num_tokens = target_token_ids.shape[0]
         batch_size = next_token_ids.shape[0]
         last_token_indices = cu_num_tokens[1:] - 1
@@ -94,17 +94,15 @@ class EagleProposer:
         )
         sample_hidden_states = hidden_states[last_token_indices]
         logits = self.model.compute_logits(sample_hidden_states, None)
-        draft_token_ids, draft_probs = compute_probs_and_sample_next_token(
-            logits, sampling_metadata)
+        draft_token_ids = logits.argmax(dim=-1)

         # Early exit if there is only one draft token to be generated.
         if self.num_speculative_tokens == 1:
-            # [batch_size, 1] and [batch_size, 1, vocab_size]
-            return draft_token_ids.view(-1, 1), draft_probs.unsqueeze(dim=1)
+            # [batch_size, 1]
+            return draft_token_ids.view(-1, 1)

         # Generate the remaining draft tokens.
         draft_token_ids_list = [draft_token_ids]
-        draft_probs_list = [draft_probs]

         positions = target_positions[last_token_indices]
         hidden_states = sample_hidden_states
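A minimal shape check of the new single-step path, using dummy tensors rather than the proposer's real inputs:

import torch

batch_size, vocab_size = 3, 11
logits = torch.randn(batch_size, vocab_size)

draft_token_ids = logits.argmax(dim=-1)    # [batch_size]
# When num_speculative_tokens == 1, propose() now returns just this view,
# with no accompanying [batch_size, 1, vocab_size] probability tensor.
single_step = draft_token_ids.view(-1, 1)  # [batch_size, 1]
assert single_step.shape == (batch_size, 1)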
@@ -159,16 +157,12 @@ class EagleProposer:
                 positions=clamped_positions,
             )
             logits = self.model.compute_logits(hidden_states, None)
-            draft_token_ids, probs = compute_probs_and_sample_next_token(
-                logits, sampling_metadata)
+            draft_token_ids = logits.argmax(dim=-1)
             draft_token_ids_list.append(draft_token_ids)
-            draft_probs_list.append(probs)

         # [batch_size, num_speculative_tokens]
         draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
-        # [batch_size, num_speculative_tokens, vocab_size]
-        draft_probs = torch.stack(draft_probs_list, dim=1)
-        return draft_token_ids, draft_probs
+        return draft_token_ids

     @staticmethod
     def prepare_inputs(
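For the multi-step path, each iteration now appends only a [batch_size] tensor of greedy token ids, and the stack along dim=1 is the sole return value. A toy version with dummy data:

import torch

batch_size, num_speculative_tokens, vocab_size = 2, 3, 100
draft_token_ids_list = [
    torch.randint(0, vocab_size, (batch_size,))
    for _ in range(num_speculative_tokens)
]
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
assert draft_token_ids.shape == (batch_size, num_speculative_tokens)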
@@ -238,6 +232,10 @@ class EagleProposer:
         self.model.lm_head = target_model.lm_head


+# NOTE(woosuk): Currently, the below code is not used and we always use argmax
+# to sample the draft tokens. We will use this after we find a way to manage
+# the draft prob tensor.
+# Refer to https://github.com/vllm-project/vllm/pull/16899 for the details.
 # FIXME(woosuk): The logic here is duplicated with the main sampling code.
 # We should refactor this to reuse the same sampling implementation.
 def compute_probs_and_sample_next_token(
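The NOTE keeps compute_probs_and_sample_next_token in place because the blocker is managing the draft-probability tensor, not the sampling math itself. A back-of-the-envelope calculation with illustrative numbers (not taken from the PR) shows why that tensor is worth avoiding:

# float32 tensor of shape [batch_size, num_speculative_tokens, vocab_size]
batch_size = 256
num_speculative_tokens = 4
vocab_size = 128_000
bytes_per_elem = 4

size_gib = (batch_size * num_speculative_tokens * vocab_size
            * bytes_per_elem / 2**30)
print(f"{size_gib:.2f} GiB")  # roughly 0.49 GiB per drafting step here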
@@ -1230,7 +1230,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             target_hidden_states = hidden_states[token_indices]
             target_slot_mapping = attn_metadata.slot_mapping[token_indices]

-            draft_token_ids, draft_probs = self.drafter.propose(
+            draft_token_ids = self.drafter.propose(
                 target_token_ids=target_token_ids,
                 target_positions=target_positions,
                 target_hidden_states=target_hidden_states,
@@ -1241,9 +1241,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 sampling_metadata=sampling_metadata,
             )
             spec_token_ids = draft_token_ids.tolist()
-            # TODO(woosuk): Cache draft_probs and use it for rejection sampling
-            # in the next step.
-            del draft_probs

         # Clear KVConnector state after all KVs are generated.
         if has_kv_transfer_group():
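On the runner side, propose() now hands back a single tensor, which the runner converts to Python lists via .tolist(); the TODO about caching draft_probs and the matching del disappear along with the tensor. A condensed sketch of the new call-site shape with dummy values:

import torch

# [batch_size, num_speculative_tokens] of draft token ids
draft_token_ids = torch.tensor([[5, 9, 2],
                                [7, 7, 1]])
spec_token_ids = draft_token_ids.tolist()  # [[5, 9, 2], [7, 7, 1]]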