mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Bugfix] Fix for inconsistent behaviour related to sampling and repetition penalties (#5639)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
@@ -386,18 +386,10 @@ class SamplingTensors:
|
||||
presence_penalties += [0] * prefill_len
|
||||
frequency_penalties += [0] * prefill_len
|
||||
repetition_penalties += [1] * prefill_len
|
||||
if do_penalties:
|
||||
prompt_tokens.extend([] for _ in range(prefill_len))
|
||||
output_tokens.extend([] for _ in range(prefill_len))
|
||||
|
||||
if seq_group.do_sample:
|
||||
sample_lens = len(seq_group.sample_indices)
|
||||
assert sample_lens == len(seq_ids)
|
||||
for seq_id in seq_ids:
|
||||
seq_data = seq_group.seq_data[seq_id]
|
||||
if do_penalties:
|
||||
prompt_tokens.append(seq_data.prompt_token_ids)
|
||||
output_tokens.append(seq_data.output_token_ids)
|
||||
temperatures += [temperature] * len(seq_ids)
|
||||
top_ps += [top_p] * len(seq_ids)
|
||||
top_ks += [top_k] * len(seq_ids)
|
||||
@@ -424,6 +416,20 @@ class SamplingTensors:
|
||||
sampling_seeds.append(seq_seeds)
|
||||
sample_indices.extend(seq_group.sample_indices)
|
||||
|
||||
if do_penalties:
|
||||
for seq_group in sampling_metadata.seq_groups:
|
||||
seq_ids = seq_group.seq_ids
|
||||
if (seq_group.is_prompt
|
||||
and sampling_params.prompt_logprobs is not None):
|
||||
prefill_len = len(seq_group.prompt_logprob_indices)
|
||||
prompt_tokens.extend([] for _ in range(prefill_len))
|
||||
output_tokens.extend([] for _ in range(prefill_len))
|
||||
if seq_group.do_sample:
|
||||
for seq_id in seq_ids:
|
||||
seq_data = seq_group.seq_data[seq_id]
|
||||
prompt_tokens.append(seq_data.prompt_token_ids)
|
||||
output_tokens.append(seq_data.output_token_ids)
|
||||
|
||||
sampling_tensors = SamplingTensors.from_lists(
|
||||
temperatures, top_ps, top_ks, min_ps, presence_penalties,
|
||||
frequency_penalties, repetition_penalties, sampling_seeds,
|
||||
|
Reference in New Issue
Block a user