mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Bugfix] Added test for sampling repetition penalty bug. (#5659)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
@ -631,3 +631,72 @@ def test_sampler_top_k_top_p(seed: int, device: str):
|
||||
hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float)
|
||||
assert torch.allclose(hf_probs, sample_probs, atol=1e-5)
|
||||
assert torch.equal(hf_probs.eq(0), sample_probs.eq(0))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_repetition_penalty_mixed(device: str):
    """Check that batch position does not change each request's token.

    A two-request batch mixes one greedy request that uses a repetition
    penalty with one seeded sampling request that does not. The batch is
    run in both orders, and the token produced for a given configuration
    must be the same whichever slot it occupied.
    """
    vocab_size = 8

    def run_batch(sampling_params: List[SamplingParams]):
        """Run the mock sampler on a 2-request batch; return the tokens."""
        seq_groups: List[SequenceGroupMetadata] = []
        lengths: List[int] = []
        for idx in range(2):
            seq_group = SequenceGroupMetadata(
                request_id=f"test_{idx}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=sampling_params[idx],
                block_tables={0: [1]},
            )
            seq_groups.append(seq_group)
            lengths.append(seq_groups[-1].seq_data[0].get_len())

        metadata = SamplingMetadata.prepare(
            seq_groups,
            lengths,
            query_lens=lengths,
            device=device,
            pin_memory=is_pin_memory_available())

        # Near-uniform logits with two slightly raised entries: token 1
        # (which also appears in the prompt) is the largest, token 5 (not
        # in the prompt) is the runner-up.
        logits = torch.full((2, vocab_size),
                            1e-2,
                            device=device,
                            dtype=torch.float16)
        logits[:, 5] = 1.1e-2
        logits[:, 1] = 1.2e-2

        mock_sampler = MockLogitsSampler(logits)
        outputs = mock_sampler(logits=logits, sampling_metadata=metadata)

        # First sample of every sequence group, in batch order.
        return [group.samples[0].output_token for group in outputs]

    # Greedy decoding (temperature 0) with a repetition penalty.
    greedy_with_penalty = SamplingParams(
        temperature=0.0,
        repetition_penalty=2.0,
    )

    # Seeded sampling restricted to top_k=1, no repetition penalty.
    seeded_without_penalty = SamplingParams(
        temperature=1.0,
        top_k=1,
        seed=42,
    )

    penalty_first = run_batch([greedy_with_penalty, seeded_without_penalty])
    penalty_last = run_batch([seeded_without_penalty, greedy_with_penalty])

    # The same configuration must yield the same token in either slot.
    assert penalty_first[0] == penalty_last[1]
    assert penalty_first[1] == penalty_last[0]
|
||||
|
Reference in New Issue
Block a user