[Perf] Remove hardcoded num_warps=1 (#26183)

Signed-off-by: Corey Lowman <clowman1993@gmail.com>
Author: Corey Lowman
Date: 2025-10-03 16:38:50 -04:00 (committed by GitHub)
Commit: 0879736aab (parent a26917332f)

@@ -164,12 +164,12 @@ def rejection_sample(
     assert target_probs.shape == (num_tokens, vocab_size)
     # Create output buffer.
-    output_token_ids = torch.empty(
+    output_token_ids = torch.full(
         (batch_size, max_spec_len + 1),
+        PLACEHOLDER_TOKEN_ID,
         dtype=torch.int32,  # Consistent with SamplerOutput.sampled_token_ids.
         device=device,
     )
-    output_token_ids.fill_(PLACEHOLDER_TOKEN_ID)
     if sampling_metadata.all_greedy:
         is_greedy = None
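
The hunk above replaces the two-step torch.empty plus in-place fill_ with a single torch.full call that allocates and initializes the buffer in one expression; the resulting tensor is the same. A minimal standalone sketch of the two forms, with PLACEHOLDER_TOKEN_ID and the shape values assumed here purely for illustration:

import torch

PLACEHOLDER_TOKEN_ID = -1  # assumed sentinel value, for illustration only
batch_size, max_spec_len = 4, 3
device = "cuda" if torch.cuda.is_available() else "cpu"

# Before: allocate, then fill in place as a separate statement.
output_token_ids = torch.empty(
    (batch_size, max_spec_len + 1), dtype=torch.int32, device=device
)
output_token_ids.fill_(PLACEHOLDER_TOKEN_ID)

# After: allocate and initialize in one call.
output_token_ids = torch.full(
    (batch_size, max_spec_len + 1),
    PLACEHOLDER_TOKEN_ID,
    dtype=torch.int32,
    device=device,
)
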
@@ -186,7 +186,6 @@ def rejection_sample(
         bonus_token_ids,
         is_greedy,
         max_spec_len,
-        num_warps=1,
     )
     if sampling_metadata.all_greedy:
         return output_token_ids
@@ -227,7 +226,6 @@ def rejection_sample(
         max_spec_len,
         vocab_size,
         NO_DRAFT_PROBS=draft_probs is None,
-        num_warps=1,
     )
     return output_token_ids
@@ -329,7 +327,6 @@ def expand_batch_to_tokens(
         replace_from,
         replace_to,
         MAX_NUM_TOKENS=MAX_SPEC_LEN,  # To avoid recompilation.
-        num_warps=1,
     )
     return expanded_x
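
The remaining hunks drop the hardcoded num_warps=1 from the Triton kernel launches. num_warps controls how many warps (32 threads each) execute each program instance; pinning it to 1 caps every program at a single warp, while omitting it lets Triton fall back to its own default (typically 4), giving each block more threads to work with. A minimal sketch with a hypothetical copy kernel, not the kernels touched by this diff, showing where the launch keyword sits:

import torch
import triton
import triton.language as tl


@triton.jit
def copy_kernel(src_ptr, dst_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Each program instance copies one BLOCK_SIZE-wide chunk.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    vals = tl.load(src_ptr + offsets, mask=mask, other=0)
    tl.store(dst_ptr + offsets, vals, mask=mask)


def copy(src: torch.Tensor) -> torch.Tensor:
    dst = torch.empty_like(src)
    n = src.numel()
    grid = (triton.cdiv(n, 1024),)
    # Before: copy_kernel[grid](src, dst, n, BLOCK_SIZE=1024, num_warps=1)
    # After: omit num_warps and let Triton pick the per-block warp count.
    copy_kernel[grid](src, dst, n, BLOCK_SIZE=1024)
    return dst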