[bugfix] [spec-decoding] fix data race in sample_recovered_tokens_kernel (vLLM v1) (#23829)
Signed-off-by: He-Jingkai <he-jingkai@outlook.com>
This commit is contained in:
@ -598,17 +598,10 @@ def sample_recovered_tokens_kernel(
|
||||
vocab_offset = tl.arange(0, PADDED_VOCAB_SIZE)
|
||||
if NO_DRAFT_PROBS:
|
||||
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
|
||||
orig_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size +
|
||||
draft_token_id)
|
||||
# Temporarily zero out the probability of the draft token.
|
||||
# This is essentially the same as target_prob - draft_prob, except that
|
||||
# n-gram does not have draft_prob. We regard it as 1.
|
||||
tl.store(
|
||||
target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id,
|
||||
0)
|
||||
prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size +
|
||||
vocab_offset,
|
||||
mask=vocab_offset < vocab_size,
|
||||
mask=((vocab_offset < vocab_size) &
|
||||
(vocab_offset != draft_token_id)),
|
||||
other=0)
|
||||
else:
|
||||
draft_prob = tl.load(draft_probs_ptr + (start_idx + pos) * vocab_size +
|
||||
@ -628,9 +621,3 @@ def sample_recovered_tokens_kernel(
|
||||
other=float("-inf"))
|
||||
recovered_id = tl.argmax(prob / q, axis=-1)
|
||||
tl.store(output_token_ids_ptr + start_idx + pos, recovered_id)
|
||||
|
||||
if NO_DRAFT_PROBS:
|
||||
# Restore the original probability.
|
||||
tl.store(
|
||||
target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id,
|
||||
orig_prob)
|
||||
|
Reference in New Issue
Block a user