This commit is contained in:
LiuXiaoxuanPKU
2025-04-01 15:46:41 -07:00
parent 8fcd4d18e0
commit b484b79504

View File

@ -195,7 +195,7 @@ def rejection_sample(
num_warps=1,
)
if sampling_metadata.all_greedy:
return output_token_ids
return output_token_ids, output_probs
# Generate uniform probabilities for rejection sampling.
# [num_tokens]
@ -475,8 +475,8 @@ def rejection_greedy_sample_kernel(
if draft_token_id != target_argmax_id:
# Reject.
rejected = True
tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
not rejected)
tl.store(output_probs_ptr + req_idx * (max_spec_len + 1) + pos,
draft_token_id == target_argmax_id)
if not rejected:
# If all tokens are accepted, append the bonus token.