Fix random dataset mismatched token length with config. (#24937)
Signed-off-by: Weiliang Liu <weiliangl@nvidia.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roger Wang <hey@rogerw.io>
@@ -366,11 +366,67 @@ def process_video(video: Any) -> Mapping[str, Any]:
         f"Invalid video input {video}. Must be a string of local path/remote url, or a dictionary with raw video bytes in the form of `{{'bytes': raw_video_bytes}}`."  # noqa: E501
     )
 
 
+def gen_prompt_decode_to_target_len(
+    tokenizer: PreTrainedTokenizerBase,
+    token_sequence: list[int],
+    target_token_len: int,
+    max_retry: int = 10,
+    add_special_tokens: bool = False,
+    rng: Optional[np.random.Generator] = None,
+) -> tuple[str, list[int], int]:
+    """
+    Ensure decoded-then-encoded prompt length matches the target token length.
+
+    This function decodes an initial token sequence to text and re-encodes
+    it, iteratively adjusting the token sequence length to match a target.
+    This is necessary because some tokenizers do not guarantee a 1:1 mapping
+    between consecutive tokens and the decoded-then-encoded sequence length.
+    For example, for GPT2Tokenizer:
+    [6880, 6881] -> ['Ġcalls', 'here'] ->
+    [1650, 939, 486] -> ['Ġcall', 'sh', 'ere']
+
+    Returns a tuple of the final prompt string, the adjusted token sequence,
+    and the token count mismatch (0 if the target length was reached).
+    """
+    remain_num_try = max_retry
+    token_mismatch = 0
+    while True:
+        prompt = tokenizer.decode(token_sequence)
+        token_sequence = tokenizer.encode(
+            prompt, add_special_tokens=add_special_tokens
+        )
+        if remain_num_try <= 0:
+            if len(token_sequence) != target_token_len:
+                token_mismatch = len(token_sequence) - target_token_len
+            break
+
+        if len(token_sequence) == target_token_len:
+            break
+        elif len(token_sequence) < target_token_len:
+            if rng is not None:
+                extra_tokens = rng.integers(
+                    0,
+                    tokenizer.vocab_size,
+                    size=target_token_len - len(token_sequence),
+                ).tolist()
+            else:
+                extra_tokens = np.random.randint(
+                    0,
+                    tokenizer.vocab_size,
+                    size=target_token_len - len(token_sequence),
+                ).tolist()
+            token_sequence.extend(extra_tokens)
+        elif len(token_sequence) > target_token_len:
+            token_sequence = token_sequence[:target_token_len]
+
+        remain_num_try -= 1
+
+    return prompt, token_sequence, token_mismatch
+
+
 # -----------------------------------------------------------------------------
 # Random Dataset Implementation (Synthetic Data)
 # -----------------------------------------------------------------------------
 
 
 class RandomDataset(BenchmarkDataset):
     """
     Synthetic text-only dataset for serving/throughput benchmarks.
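The failure mode the new helper guards against is easy to reproduce outside of vLLM. A minimal sketch, assuming the `transformers` package and the public `gpt2` checkpoint, with token IDs taken from the docstring example above:

    # Decoding then re-encoding a token sequence need not round-trip.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    token_sequence = [6880, 6881]              # ['Ġcalls', 'here']
    prompt = tokenizer.decode(token_sequence)  # " callshere"
    re_encoded = tokenizer.encode(prompt, add_special_tokens=False)
    print(re_encoded)  # [1650, 939, 486], i.e. ['Ġcall', 'sh', 'ere']: 3 tokens from 2

Without the retry loop above, a benchmark that requested a 2-token prompt would silently serve a 3-token one.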
@@ -420,8 +476,9 @@ class RandomDataset(BenchmarkDataset):
         vocab_size = tokenizer.vocab_size
 
         requests = []
+        token_mismatch_total = 0
         for i in range(num_requests):
-            prompt, total_input_len = self.generate_token_sequence(
+            prompt, total_input_len, token_mismatch = self.generate_token_sequence(  # noqa: E501
                 tokenizer=tokenizer,
                 prefix_token_ids=prefix_token_ids,
                 prefix_len=prefix_len,
@@ -430,6 +487,7 @@ class RandomDataset(BenchmarkDataset):
                 offset=int(offsets[i]),
                 index=i,
             )
+            token_mismatch_total += token_mismatch
             requests.append(
                 SampleRequest(
                     prompt=prompt,
@@ -453,6 +511,18 @@ class RandomDataset(BenchmarkDataset):
                 )
             )
             requests = batch_requests
 
+        if token_mismatch_total != 0:
+            sign = "more" if token_mismatch_total > 0 else "fewer"
+            logger.warning(
+                "Across all generated prompts, there were %d %s tokens "
+                "than expected after decoding and re-encoding. This is "
+                "expected due to the imperfect nature of the sampling "
+                "procedure.",
+                abs(token_mismatch_total),
+                sign,
+            )
+
         return requests
 
     def get_prefix(
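Note that `token_mismatch` is signed: positive when re-encoding produced more tokens than requested, negative when fewer. Opposite per-request errors can therefore cancel in `token_mismatch_total`, and the warning reports only the net drift. A toy illustration of the aggregation, using hypothetical mismatch values rather than vLLM code:

    # Hypothetical per-request mismatches; +2 means two extra tokens, -1 one too few.
    mismatches = [+2, -1, -1, +3]
    token_mismatch_total = sum(mismatches)  # 3
    if token_mismatch_total != 0:
        sign = "more" if token_mismatch_total > 0 else "fewer"
        print(f"{abs(token_mismatch_total)} {sign} tokens than expected")
    # -> "3 more tokens than expected"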
@@ -530,7 +600,7 @@ class RandomDataset(BenchmarkDataset):
         input_len: int,
         offset: int,
         index: int,
-    ) -> tuple[str, int]:
+    ) -> tuple[str, int, int]:
         """
-        Returns (prompt, total_input_len).
+        Returns (prompt, total_input_len, token_mismatch).
@@ -549,15 +619,16 @@ class RandomDataset(BenchmarkDataset):
         token_sequence = prefix_token_ids + inner_seq
 
         # Decode, then re-encode and truncate to preserve token count invariants
-        prompt = tokenizer.decode(token_sequence)
         total_input_len = prefix_len + int(input_len)
-
-        re_encoded_sequence = tokenizer.encode(
-            prompt, add_special_tokens=False)[:total_input_len]
-        prompt = tokenizer.decode(re_encoded_sequence)
-        total_input_len = len(re_encoded_sequence)
-
-        return prompt, total_input_len
+        prompt, adjusted_token_sequence, token_mismatch = gen_prompt_decode_to_target_len(  # noqa: E501
+            tokenizer=tokenizer,
+            token_sequence=token_sequence,
+            target_token_len=total_input_len,
+            add_special_tokens=False,
+            rng=self._rng,
+        )
+        total_input_len = len(adjusted_token_sequence)
+        return prompt, total_input_len, token_mismatch
 
 
 # -----------------------------------------------------------------------------
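For reference, a standalone usage sketch of the helper, mirroring how `generate_token_sequence` now calls it. It assumes `numpy`, `transformers`, and the public `gpt2` checkpoint; the seed is arbitrary:

    import numpy as np
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    rng = np.random.default_rng(seed=0)  # any seed works; fixed for reproducibility

    target_len = 64
    token_sequence = rng.integers(0, tokenizer.vocab_size, size=target_len).tolist()
    prompt, adjusted, token_mismatch = gen_prompt_decode_to_target_len(
        tokenizer=tokenizer,
        token_sequence=token_sequence,
        target_token_len=target_len,
        add_special_tokens=False,
        rng=rng,
    )
    # token_mismatch is 0 unless max_retry (default 10) was exhausted.
    assert len(adjusted) == target_len + token_mismatch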
@@ -873,8 +944,9 @@ class RandomMultiModalDataset(RandomDataset):
         vocab_size = tokenizer.vocab_size
         # Add synthetic multimodal items to each request
         mm_requests = []
+        token_mismatch_total = 0
         for i in range(num_requests):
-            prompt, total_input_len = self.generate_token_sequence(
+            prompt, total_input_len, token_mismatch = self.generate_token_sequence(  # noqa: E501
                 tokenizer=tokenizer,
                 prefix_token_ids=prefix_token_ids,
                 prefix_len=prefix_len,
@@ -883,6 +955,7 @@ class RandomMultiModalDataset(RandomDataset):
                 offset=int(offsets[i]),
                 index=i,
             )
+            token_mismatch_total += token_mismatch
             # Get multimodal item iterator for a given request
             mm_item_iterator = self.get_mm_item_iterator(
                 min_num_mm_items,
@@ -918,6 +991,18 @@ class RandomMultiModalDataset(RandomDataset):
                 request_id=request_id_prefix + str(i),
             )
             mm_requests.append(sample_request)
 
+        if token_mismatch_total != 0:
+            sign = "more" if token_mismatch_total > 0 else "fewer"
+            logger.warning(
+                "Across all generated prompts, there were %d %s tokens "
+                "than expected after decoding and re-encoding. This is "
+                "expected due to the imperfect nature of the sampling "
+                "procedure.",
+                abs(token_mismatch_total),
+                sign,
+            )
+
         return mm_requests
 
 
 # -----------------------------------------------------------------------------
@@ -2694,27 +2779,23 @@ class PrefixRepetitionRandomDataset(BenchmarkDataset):
             # Generate random tokens
             tokens = np.random.randint(
                 0, vocab_size, size=target_length).tolist()
-            text = tokenizer.decode(tokens)
-            re_encoded = tokenizer.encode(text, add_special_tokens=False)
-
-            if len(re_encoded) == target_length:
-                return re_encoded
-            elif len(re_encoded) < target_length:
-                # Recursively generate additional consistent tokens
-                needed = target_length - len(re_encoded)
-                extra_tokens = _generate_exact_length_tokens(needed)
-                return re_encoded + extra_tokens
-            else:
-                # Truncate to target length
-                return re_encoded[:target_length]
+            _, adjusted_tokens, token_mismatch = gen_prompt_decode_to_target_len(  # noqa: E501
+                tokenizer=tokenizer,
+                token_sequence=tokens,
+                target_token_len=target_length,
+                add_special_tokens=False,
+            )
+            return adjusted_tokens, token_mismatch
 
         requests = []
+        token_mismatch_total = 0
         for _ in range(num_prefixes):
-            prefix_tokens = _generate_exact_length_tokens(prefix_len)
+            prefix_tokens, token_mismatch = _generate_exact_length_tokens(prefix_len)  # noqa: E501
+            token_mismatch_total += token_mismatch
 
             for _ in range(prompts_per_prefix):
-                suffix_tokens = _generate_exact_length_tokens(suffix_len)
-
+                suffix_tokens, token_mismatch = _generate_exact_length_tokens(suffix_len)  # noqa: E501
+                token_mismatch_total += token_mismatch
                 combined_tokens = prefix_tokens + suffix_tokens
                 prompt = tokenizer.decode(combined_tokens)
                 prompt_len = len(combined_tokens)
@@ -2726,6 +2807,16 @@ class PrefixRepetitionRandomDataset(BenchmarkDataset):
                     )
                 )
 
+        if token_mismatch_total != 0:
+            sign = "more" if token_mismatch_total > 0 else "fewer"
+            logger.warning(
+                "Across all generated prompts, there were %d %s tokens "
+                "than expected after decoding and re-encoding. This is "
+                "expected due to the imperfect nature of the sampling "
+                "procedure.",
+                abs(token_mismatch_total),
+                sign,
+            )
         random.shuffle(requests)
         return requests