[Bugfix] Fix erroneous randomly generated cases in bad word testing (#22170)

Signed-off-by: phantomlei <phantomlei3@gmail.com>
This commit is contained in:
phantomlei
2025-08-12 17:03:22 +08:00
committed by GitHub
parent 8d17fa633e
commit bc8372efc3

View File

@ -90,6 +90,27 @@ def _create_bad_words_token_ids(
return bad_words_token_ids
# Returns all last tokens of bad word sequences that share the same prefix
# as `given_prefix` (excluding the last token).
def _collect_suffixes_with_same_prefix(
given_prefix: list[int],
bad_words_token_ids: list[list[int]]) -> list[int]:
return [bwt[-1] for bwt in bad_words_token_ids if bwt[:-1] == given_prefix]
# generate a valid token id that is not in bad_words_token_ids
def _generate_valid_token_id(bad_words_token_ids: list[list[int]],
vocab_size: int) -> int:
forbidden_start_tokens = set()
for bad_word in bad_words_token_ids:
forbidden_start_tokens.add(bad_word[0])
# Get a safe token that's not in forbidden starts
safe_token_candidates = list(
set(range(vocab_size)) - forbidden_start_tokens)
# Pick a random safe token
return np.random.choice(safe_token_candidates)
def _update_output_token_ids_for_bad_words(
metadata: SamplingMetadata, vocab_size: int) -> dict[int, list[int]]:
bad_words_last_tokens = {}
@ -104,12 +125,17 @@ def _update_output_token_ids_for_bad_words(
prefix_length = len(bad_word_token_ids) - 1
has_bad_words = np.random.choice([True, False])
if has_bad_words:
output_token_ids[-prefix_length:] = bad_word_token_ids[:-1]
bad_words_last_token.append(bad_word_token_ids[-1])
prefix = bad_word_token_ids[:-1]
output_token_ids[-prefix_length:] = prefix
# Collect all last tokens from other bad words
# that share this prefix
bad_words_last_token.extend(
_collect_suffixes_with_same_prefix(
prefix, bad_words_token_ids))
break # Maximum one update to output_token_ids
else: # Make sure no accidental match to bad words
output_token_ids[-1] = (bad_word_token_ids[-2] +
1) % vocab_size
output_token_ids[-1] = _generate_valid_token_id(
bad_words_token_ids, vocab_size)
bad_words_last_tokens[batch_idx] = bad_words_last_token
return bad_words_last_tokens