Fix random dataset mismatched token length with config. (#24937)

Signed-off-by: Weiliang Liu <weiliangl@nvidia.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roger Wang <hey@rogerw.io>
weiliang
2025-09-28 16:23:44 +08:00
committed by GitHub
parent 0efd540dbc
commit f4e4088c99


@@ -366,11 +366,67 @@ def process_video(video: Any) -> Mapping[str, Any]:
f"Invalid video input {video}. Must be a string of local path/remote url, or a dictionary with raw video bytes in the form of `{{'bytes': raw_video_bytes}}`." # noqa: E501
)
def gen_prompt_decode_to_target_len(
tokenizer: PreTrainedTokenizerBase,
token_sequence: list[int],
target_token_len: int,
max_retry: int = 10,
add_special_tokens: bool = False,
rng: Optional[np.random.Generator] = None,
) -> tuple[str, list[int], int]:
"""
Ensure decoded-then-encoded prompt length matches the target token length.
This function decodes an initial token sequence to text and re-encodes it,
iteratively adjusting the token sequence length to match a target.
This is necessary because some tokenizers do not guarantee a 1:1 mapping
between consecutive tokens and the decoded-then-encoded sequence length.
For example, for GPT2Tokenizer:
[6880, 6881] -> ['Ġcalls', 'here'] ->
[1650, 939, 486] -> ['Ġcall', 'sh', 'ere']
Returns a tuple of the final prompt string, the adjusted token sequence,
and the signed token-count mismatch (0 if the target length was reached).
"""
remain_num_try = max_retry
token_mismatch = 0
while True:
prompt = tokenizer.decode(token_sequence)
token_sequence = tokenizer.encode(
prompt, add_special_tokens=add_special_tokens
)
if remain_num_try <= 0:
if len(token_sequence) != target_token_len:
token_mismatch = len(token_sequence) - target_token_len
break
if len(token_sequence) == target_token_len:
break
elif len(token_sequence) < target_token_len:
if rng is not None:
extra_tokens = rng.integers(
0,
tokenizer.vocab_size,
size=target_token_len - len(token_sequence),
).tolist()
else:
extra_tokens = np.random.randint(
0,
tokenizer.vocab_size,
size=target_token_len - len(token_sequence),
).tolist()
token_sequence.extend(extra_tokens)
elif len(token_sequence) > target_token_len:
token_sequence = token_sequence[:target_token_len]
remain_num_try -= 1
return prompt, token_sequence, token_mismatch
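# Illustrative sketch (not part of this commit; assumes `transformers` and
# `numpy` are installed) of the round-trip drift this helper corrects, using
# the GPT-2 example from the docstring:
#
#     import numpy as np
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("gpt2")
#     ids = [6880, 6881]            # 'Ġcalls' + 'here'
#     text = tokenizer.decode(ids)
#     tokenizer.encode(text, add_special_tokens=False)
#     # -> [1650, 939, 486]: three tokens where we started with two
#
#     prompt, adjusted, mismatch = gen_prompt_decode_to_target_len(
#         tokenizer=tokenizer,
#         token_sequence=ids,
#         target_token_len=2,
#         rng=np.random.default_rng(0),
#     )
#     assert len(adjusted) == 2 + mismatch  # mismatch is 0 unless retries ran out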
# -----------------------------------------------------------------------------
# Random Dataset Implementation (Synthetic Data)
# -----------------------------------------------------------------------------
class RandomDataset(BenchmarkDataset):
"""
Synthetic text-only dataset for serving/throughput benchmarks.
@@ -420,8 +476,9 @@ class RandomDataset(BenchmarkDataset):
vocab_size = tokenizer.vocab_size
requests = []
token_mismatch_total = 0
for i in range(num_requests):
prompt, total_input_len = self.generate_token_sequence(
prompt, total_input_len, token_mismatch = self.generate_token_sequence( # noqa: E501
tokenizer=tokenizer,
prefix_token_ids=prefix_token_ids,
prefix_len=prefix_len,
@@ -430,6 +487,7 @@ class RandomDataset(BenchmarkDataset):
offset=int(offsets[i]),
index=i,
)
token_mismatch_total += token_mismatch
requests.append(
SampleRequest(
prompt=prompt,
@@ -453,6 +511,18 @@ class RandomDataset(BenchmarkDataset):
)
)
requests = batch_requests
if token_mismatch_total != 0:
sign = "more" if token_mismatch_total > 0 else "fewer"
logger.warning(
"Across all generated prompts, there were %d %s tokens "
"than expected after decoding and re-encoding. This is "
"expected due to the imperfect nature of the sampling "
"procedure.",
abs(token_mismatch_total),
sign,
)
return requests
def get_prefix(
@@ -530,7 +600,7 @@ class RandomDataset(BenchmarkDataset):
input_len: int,
offset: int,
index: int,
) -> tuple[str, int]:
) -> tuple[str, int, int]:
"""
Returns (prompt, total_input_len, token_mismatch).
@@ -549,15 +619,16 @@ class RandomDataset(BenchmarkDataset):
token_sequence = prefix_token_ids + inner_seq
# Decode, then re-encode and truncate to preserve token count invariants
prompt = tokenizer.decode(token_sequence)
total_input_len = prefix_len + int(input_len)
re_encoded_sequence = tokenizer.encode(
prompt, add_special_tokens=False)[:total_input_len]
prompt = tokenizer.decode(re_encoded_sequence)
total_input_len = len(re_encoded_sequence)
return prompt, total_input_len
prompt, adjusted_token_sequence, token_mismatch = gen_prompt_decode_to_target_len( # noqa: E501
tokenizer=tokenizer,
token_sequence=token_sequence,
target_token_len=total_input_len,
add_special_tokens=False,
rng=self._rng,
)
total_input_len = len(adjusted_token_sequence)
return prompt, total_input_len, token_mismatch
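# Note on the replaced logic above, as a hedged sketch (hypothetical
# tokenizer behavior): plain truncation can only shorten, never pad, so the
# old path undershot whenever the decode/encode round trip merged tokens.
#
#     re_encoded = tokenizer.encode(prompt, add_special_tokens=False)
#     re_encoded = re_encoded[:total_input_len]
#     # if len(re_encoded) < total_input_len the request is silently shorter
#     # than configured; gen_prompt_decode_to_target_len instead pads with
#     # fresh random ids and retries up to max_retry times.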
# -----------------------------------------------------------------------------
@@ -873,8 +944,9 @@ class RandomMultiModalDataset(RandomDataset):
vocab_size = tokenizer.vocab_size
# Add synthetic multimodal items to each request
mm_requests = []
token_mismatch_total = 0
for i in range(num_requests):
prompt, total_input_len = self.generate_token_sequence(
prompt, total_input_len, token_mismatch = self.generate_token_sequence( # noqa: E501
tokenizer=tokenizer,
prefix_token_ids=prefix_token_ids,
prefix_len=prefix_len,
@@ -883,6 +955,7 @@ class RandomMultiModalDataset(RandomDataset):
offset=int(offsets[i]),
index=i,
)
token_mismatch_total += token_mismatch
# Get multimodal item iterator for a given request
mm_item_iterator = self.get_mm_item_iterator(
min_num_mm_items,
@@ -918,6 +991,18 @@ class RandomMultiModalDataset(RandomDataset):
request_id=request_id_prefix + str(i),
)
mm_requests.append(sample_request)
if token_mismatch_total != 0:
sign = "more" if token_mismatch_total > 0 else "fewer"
logger.warning(
"Across all generated prompts, there were %d %s tokens "
"than expected after decoding and re-encoding. This is "
"expected due to the imperfect nature of the sampling "
"procedure.",
abs(token_mismatch_total),
sign,
)
return mm_requests
# -----------------------------------------------------------------------------
@@ -2694,27 +2779,23 @@ class PrefixRepetitionRandomDataset(BenchmarkDataset):
# Generate random tokens
tokens = np.random.randint(
0, vocab_size, size=target_length).tolist()
text = tokenizer.decode(tokens)
re_encoded = tokenizer.encode(text, add_special_tokens=False)
if len(re_encoded) == target_length:
return re_encoded
elif len(re_encoded) < target_length:
# Recursively generate additional consistent tokens
needed = target_length - len(re_encoded)
extra_tokens = _generate_exact_length_tokens(needed)
return re_encoded + extra_tokens
else:
# Truncate to target length
return re_encoded[:target_length]
_, adjusted_tokens, token_mismatch = gen_prompt_decode_to_target_len( # noqa: E501
tokenizer=tokenizer,
token_sequence=tokens,
target_token_len=target_length,
add_special_tokens=False,
)
return adjusted_tokens, token_mismatch
requests = []
token_mismatch_total = 0
for _ in range(num_prefixes):
prefix_tokens, token_mismatch = _generate_exact_length_tokens(prefix_len)
token_mismatch_total += token_mismatch
for _ in range(prompts_per_prefix):
suffix_tokens = _generate_exact_length_tokens(suffix_len)
suffix_tokens, token_mismatch = _generate_exact_length_tokens(suffix_len) # noqa: E501
token_mismatch_total += token_mismatch
combined_tokens = prefix_tokens + suffix_tokens
prompt = tokenizer.decode(combined_tokens)
prompt_len = len(combined_tokens)
@@ -2726,6 +2807,16 @@ class PrefixRepetitionRandomDataset(BenchmarkDataset):
)
)
if token_mismatch_total != 0:
sign = "more" if token_mismatch_total > 0 else "fewer"
logger.warning(
"Across all generated prompts, there were %d %s tokens "
"than expected after decoding and re-encoding. This is "
"expected due to the imperfect nature of the sampling "
"procedure.",
abs(token_mismatch_total),
sign,
)
random.shuffle(requests)
return requests
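
As an end-to-end sanity check, a hypothetical snippet (the import path and the use of the GPT-2 tokenizer are assumptions, not part of this commit):

import numpy as np
from transformers import AutoTokenizer

# assumed import path for the benchmark dataset module
from benchmark_dataset import gen_prompt_decode_to_target_len

tokenizer = AutoTokenizer.from_pretrained("gpt2")
rng = np.random.default_rng(0)

for target in (8, 64, 256):
    seed_ids = rng.integers(0, tokenizer.vocab_size, size=target).tolist()
    _, adjusted, mismatch = gen_prompt_decode_to_target_len(
        tokenizer=tokenizer,
        token_sequence=seed_ids,
        target_token_len=target,
        rng=rng,
    )
    # len(adjusted) hits the target exactly unless max_retry is exhausted,
    # in which case mismatch carries the signed difference.
    assert len(adjusted) == target + mismatch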