Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 23:03:52 +08:00)
[Bugfix] Fixed prompt length for random dataset (#17408)
Signed-off-by: Mikhail Podvitskii <podvitskiymichael@gmail.com>
Committed by GitHub
Parent: edbf2d609e
Commit: dc47ba32f8
@@ -315,13 +315,15 @@ class RandomDataset(BenchmarkDataset):
         )
 
         vocab_size = tokenizer.vocab_size
+        num_special_tokens = tokenizer.num_special_tokens_to_add()
+        real_input_len = input_len - num_special_tokens
 
         prefix_token_ids = (np.random.randint(
             0, vocab_size, size=prefix_len).tolist() if prefix_len > 0 else [])
 
         # New sampling logic: [X * (1 - b), X * (1 + b)]
-        input_low = int(input_len * (1 - range_ratio))
-        input_high = int(input_len * (1 + range_ratio))
+        input_low = int(real_input_len * (1 - range_ratio))
+        input_high = int(real_input_len * (1 + range_ratio))
         output_low = int(output_len * (1 - range_ratio))
         output_high = int(output_len * (1 + range_ratio))
 
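The hunk above subtracts tokenizer.num_special_tokens_to_add() so that the sampled token budget plus any special tokens the tokenizer adds still totals input_len. A minimal sketch of that accounting, assuming the Hugging Face transformers package is installed (the gpt2 checkpoint is only an illustration):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
# Number of special tokens encode() would add around a single sequence.
# GPT2 adds none, so this prints 0; a BOS-prepending tokenizer
# (e.g. a Llama-style one) would report 1 instead.
print(tok.num_special_tokens_to_add())
# real_input_len = input_len - num_special_tokens then keeps the final
# prompt, special tokens included, at roughly input_len tokens.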
@@ -344,6 +346,17 @@ class RandomDataset(BenchmarkDataset):
                 vocab_size).tolist()
             token_sequence = prefix_token_ids + inner_seq
             prompt = tokenizer.decode(token_sequence)
+            # After decoding the prompt we have to encode and decode it again.
+            # This is done because in some cases N consecutive tokens
+            # give a string tokenized into != N number of tokens.
+            # For example for GPT2Tokenizer:
+            # [6880, 6881] -> ['Ġcalls', 'here'] ->
+            # [1650, 939, 486] -> ['Ġcall', 'sh', 'ere']
+            # To avoid an uncontrolled change of the prompt length,
+            # the encoded sequence is truncated before being decoded again.
+            re_encoded_sequence = tokenizer.encode(
+                prompt, add_special_tokens=False)[:input_lens[i]]
+            prompt = tokenizer.decode(re_encoded_sequence)
             total_input_len = prefix_len + int(input_lens[i])
             requests.append(
                 SampleRequest(
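The comment block in the hunk above is the heart of the fix: decoding random token ids and re-encoding the resulting string does not always round-trip to the same number of tokens. A minimal reproduction, assuming the Hugging Face transformers package and reusing the GPT2 example ids from the patch comment:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
ids = [6880, 6881]                  # ['Ġcalls', 'here'] per the comment
text = tok.decode(ids)
re_ids = tok.encode(text, add_special_tokens=False)
print(len(ids), len(re_ids))        # 2 vs. 3: the prompt grew on re-encode
# Truncating re_ids back to the sampled length, as the patch does,
# stops the prompt length from drifting upward uncontrolled.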