Mirror of https://github.com/vllm-project/vllm.git
[Bugfix] Fix erroneous randomly generated cases in bad word testing (#22170)
Signed-off-by: phantomlei <phantomlei3@gmail.com>
@@ -90,6 +90,27 @@ def _create_bad_words_token_ids(
     return bad_words_token_ids
 
 
+# Returns all last tokens of bad word sequences that share the same prefix
+# as `given_prefix` (excluding the last token).
+def _collect_suffixes_with_same_prefix(
+        given_prefix: list[int],
+        bad_words_token_ids: list[list[int]]) -> list[int]:
+    return [bwt[-1] for bwt in bad_words_token_ids if bwt[:-1] == given_prefix]
+
+
+# Generate a valid token id that is not the start of any bad word.
+def _generate_valid_token_id(bad_words_token_ids: list[list[int]],
+                             vocab_size: int) -> int:
+    forbidden_start_tokens = set()
+    for bad_word in bad_words_token_ids:
+        forbidden_start_tokens.add(bad_word[0])
+    # Get a safe token that's not in the forbidden starts.
+    safe_token_candidates = list(
+        set(range(vocab_size)) - forbidden_start_tokens)
+    # Pick a random safe token.
+    return np.random.choice(safe_token_candidates)
+
+
 def _update_output_token_ids_for_bad_words(
         metadata: SamplingMetadata, vocab_size: int) -> dict[int, list[int]]:
     bad_words_last_tokens = {}
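A minimal sketch (not part of the commit) of how the two new helpers behave, assuming they are in scope; the token ids and vocab size below are invented for illustration:

import numpy as np

bad_words_token_ids = [[5, 6, 7], [5, 6, 8], [9]]

# Two bad words share the prefix [5, 6], so once that prefix has been
# planted in the output, both 7 and 8 are forbidden continuations and
# the test must expect both to be masked, not just one of them.
assert _collect_suffixes_with_same_prefix([5, 6], bad_words_token_ids) == [7, 8]

# The generated token never starts any bad word (5 or 9 here), so it
# cannot accidentally begin a bad-word match.
token = _generate_valid_token_id(bad_words_token_ids, vocab_size=100)
assert token not in {5, 9}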
@@ -104,12 +125,17 @@ def _update_output_token_ids_for_bad_words(
                 prefix_length = len(bad_word_token_ids) - 1
                 has_bad_words = np.random.choice([True, False])
                 if has_bad_words:
-                    output_token_ids[-prefix_length:] = bad_word_token_ids[:-1]
-                    bad_words_last_token.append(bad_word_token_ids[-1])
+                    prefix = bad_word_token_ids[:-1]
+                    output_token_ids[-prefix_length:] = prefix
+                    # Collect all last tokens from other bad words
+                    # that share this prefix.
+                    bad_words_last_token.extend(
+                        _collect_suffixes_with_same_prefix(
+                            prefix, bad_words_token_ids))
                     break  # Maximum one update to output_token_ids
                 else:  # Make sure no accidental match to bad words
-                    output_token_ids[-1] = (bad_word_token_ids[-2] +
-                                            1) % vocab_size
+                    output_token_ids[-1] = _generate_valid_token_id(
+                        bad_words_token_ids, vocab_size)
         bad_words_last_tokens[batch_idx] = bad_words_last_token
     return bad_words_last_tokens
 
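A standalone sketch of why the old else-branch was fragile (token ids again invented for illustration): incrementing the second-to-last token of the current bad word can yield a token that is itself the start of another bad word, so a test case meant to contain no bad word could still trigger masking.

bad_words_token_ids = [[7, 8], [8, 9]]
vocab_size = 100

# Old scheme, processing [7, 8] in the "no bad word" branch:
old_last = (bad_words_token_ids[0][-2] + 1) % vocab_size
assert old_last == 8  # 8 is the prefix of [8, 9] -> accidental bad-word match

# New scheme: every bad-word start token ({7, 8} here) is excluded up front.
new_last = _generate_valid_token_id(bad_words_token_ids, vocab_size)
assert new_last not in {7, 8}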