mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
fix
This commit is contained in:
@ -195,7 +195,7 @@ def rejection_sample(
|
||||
num_warps=1,
|
||||
)
|
||||
if sampling_metadata.all_greedy:
|
||||
return output_token_ids
|
||||
return output_token_ids, output_probs
|
||||
|
||||
# Generate uniform probabilities for rejection sampling.
|
||||
# [num_tokens]
|
||||
@ -475,8 +475,8 @@ def rejection_greedy_sample_kernel(
|
||||
if draft_token_id != target_argmax_id:
|
||||
# Reject.
|
||||
rejected = True
|
||||
tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
|
||||
not rejected)
|
||||
tl.store(output_probs_ptr + req_idx * (max_spec_len + 1) + pos,
|
||||
draft_token_id == target_argmax_id)
|
||||
|
||||
if not rejected:
|
||||
# If all tokens are accepted, append the bonus token.
|
||||
|
Reference in New Issue
Block a user