[Bugfix] Fixed prompt length for random dataset (#17408)

Signed-off-by: Mikhail Podvitskii <podvitskiymichael@gmail.com>
This commit is contained in:
Mikhail Podvitskii
2025-05-06 09:00:08 +02:00
committed by GitHub
parent edbf2d609e
commit dc47ba32f8

View File

@ -315,13 +315,15 @@ class RandomDataset(BenchmarkDataset):
) )
vocab_size = tokenizer.vocab_size vocab_size = tokenizer.vocab_size
num_special_tokens = tokenizer.num_special_tokens_to_add()
real_input_len = input_len - num_special_tokens
prefix_token_ids = (np.random.randint( prefix_token_ids = (np.random.randint(
0, vocab_size, size=prefix_len).tolist() if prefix_len > 0 else []) 0, vocab_size, size=prefix_len).tolist() if prefix_len > 0 else [])
# New sampling logic: [X * (1 - b), X * (1 + b)] # New sampling logic: [X * (1 - b), X * (1 + b)]
input_low = int(input_len * (1 - range_ratio)) input_low = int(real_input_len * (1 - range_ratio))
input_high = int(input_len * (1 + range_ratio)) input_high = int(real_input_len * (1 + range_ratio))
output_low = int(output_len * (1 - range_ratio)) output_low = int(output_len * (1 - range_ratio))
output_high = int(output_len * (1 + range_ratio)) output_high = int(output_len * (1 + range_ratio))
@ -344,6 +346,17 @@ class RandomDataset(BenchmarkDataset):
vocab_size).tolist() vocab_size).tolist()
token_sequence = prefix_token_ids + inner_seq token_sequence = prefix_token_ids + inner_seq
prompt = tokenizer.decode(token_sequence) prompt = tokenizer.decode(token_sequence)
# After decoding the prompt we have to encode and decode it again.
# This is done because in some cases N consecutive tokens
# give a string tokenized into != N number of tokens.
# For example for GPT2Tokenizer:
# [6880, 6881] -> ['Ġcalls', 'here'] ->
# [1650, 939, 486] -> ['Ġcall', 'sh', 'ere']
# To avoid uncontrolled change of the prompt length,
# the encoded sequence is truncated before being decode again.
re_encoded_sequence = tokenizer.encode(
prompt, add_special_tokens=False)[:input_lens[i]]
prompt = tokenizer.decode(re_encoded_sequence)
total_input_len = prefix_len + int(input_lens[i]) total_input_len = prefix_len + int(input_lens[i])
requests.append( requests.append(
SampleRequest( SampleRequest(