mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 23:03:52 +08:00
[Perf] Remove hardcoded num_warps=1 (#26183)
Signed-off-by: Corey Lowman <clowman1993@gmail.com>
This commit is contained in:
@ -164,12 +164,12 @@ def rejection_sample(
|
||||
assert target_probs.shape == (num_tokens, vocab_size)
|
||||
|
||||
# Create output buffer.
|
||||
output_token_ids = torch.empty(
|
||||
output_token_ids = torch.full(
|
||||
(batch_size, max_spec_len + 1),
|
||||
PLACEHOLDER_TOKEN_ID,
|
||||
dtype=torch.int32, # Consistent with SamplerOutput.sampled_token_ids.
|
||||
device=device,
|
||||
)
|
||||
output_token_ids.fill_(PLACEHOLDER_TOKEN_ID)
|
||||
|
||||
if sampling_metadata.all_greedy:
|
||||
is_greedy = None
|
||||
@ -186,7 +186,6 @@ def rejection_sample(
|
||||
bonus_token_ids,
|
||||
is_greedy,
|
||||
max_spec_len,
|
||||
num_warps=1,
|
||||
)
|
||||
if sampling_metadata.all_greedy:
|
||||
return output_token_ids
|
||||
@ -227,7 +226,6 @@ def rejection_sample(
|
||||
max_spec_len,
|
||||
vocab_size,
|
||||
NO_DRAFT_PROBS=draft_probs is None,
|
||||
num_warps=1,
|
||||
)
|
||||
return output_token_ids
|
||||
|
||||
@ -329,7 +327,6 @@ def expand_batch_to_tokens(
|
||||
replace_from,
|
||||
replace_to,
|
||||
MAX_NUM_TOKENS=MAX_SPEC_LEN, # To avoid recompilation.
|
||||
num_warps=1,
|
||||
)
|
||||
return expanded_x
|
||||
|
||||
|
Reference in New Issue
Block a user