[bugfix] [spec-decoding] fix data race in sample_recovered_tokens_kernel (vLLM v1) (#23829)

Signed-off-by: He-Jingkai <he-jingkai@outlook.com>
Author: Jingkai He
Date: 2025-08-29 03:05:20 +08:00
Committed by: GitHub
Parent: 04d1dd7f4a
Commit: 57d4ede520

@@ -598,17 +598,10 @@ def sample_recovered_tokens_kernel(
     vocab_offset = tl.arange(0, PADDED_VOCAB_SIZE)
     if NO_DRAFT_PROBS:
         draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
-        orig_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size +
-                            draft_token_id)
-        # Temporarily zero out the probability of the draft token.
-        # This is essentially the same as target_prob - draft_prob, except that
-        # n-gram does not have draft_prob. We regard it as 1.
-        tl.store(
-            target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id,
-            0)
         prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size +
                        vocab_offset,
-                       mask=vocab_offset < vocab_size,
+                       mask=((vocab_offset < vocab_size) &
+                             (vocab_offset != draft_token_id)),
                        other=0)
     else:
         draft_prob = tl.load(draft_probs_ptr + (start_idx + pos) * vocab_size +
@@ -628,9 +621,3 @@ def sample_recovered_tokens_kernel(
                 other=float("-inf"))
     recovered_id = tl.argmax(prob / q, axis=-1)
     tl.store(output_token_ids_ptr + start_idx + pos, recovered_id)
-    if NO_DRAFT_PROBS:
-        # Restore the original probability.
-        tl.store(
-            target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id,
-            orig_prob)
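
The removed store/restore pair mutated target_probs in global memory, so a concurrent program could observe the temporarily zeroed entry before it was restored; the masked load achieves the same "treat the draft token's probability as zero" effect without any write. Below is a minimal standalone sketch of that pattern, not vLLM's actual kernel: masked_row_load_kernel, skip_token_ids_ptr, and the launch code are invented for illustration.

import torch
import triton
import triton.language as tl


@triton.jit
def masked_row_load_kernel(
    probs_ptr,           # [num_rows, vocab_size] float32 scores
    skip_token_ids_ptr,  # [num_rows] int32, token to treat as zero
    out_ptr,             # [num_rows] int32, argmax over the masked row
    vocab_size,
    PADDED_VOCAB_SIZE: tl.constexpr,
):
    row = tl.program_id(0)
    skip_id = tl.load(skip_token_ids_ptr + row)
    offs = tl.arange(0, PADDED_VOCAB_SIZE)
    # Read-only masked load: padding lanes and the skipped token both
    # come back as 0, so no store/restore round-trip on global memory
    # is needed and concurrent programs never see a mutated row.
    prob = tl.load(probs_ptr + row * vocab_size + offs,
                   mask=(offs < vocab_size) & (offs != skip_id),
                   other=0.0)
    tl.store(out_ptr + row, tl.argmax(prob, axis=-1).to(tl.int32))


num_rows, vocab = 4, 1000
probs = torch.rand(num_rows, vocab, device="cuda")
skip = torch.randint(0, vocab, (num_rows,), device="cuda", dtype=torch.int32)
out = torch.empty(num_rows, device="cuda", dtype=torch.int32)
masked_row_load_kernel[(num_rows,)](
    probs, skip, out, vocab,
    PADDED_VOCAB_SIZE=triton.next_power_of_2(vocab))

Because the sketch only ever reads probs_ptr, any number of programs can process rows concurrently without synchronization, which is the property this commit restores for target_probs.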