Files
vllm/vllm/logits_process.py
2025-10-12 09:51:31 -07:00

122 lines
4.3 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable, Sequence
from typing import TypeAlias
import torch
from vllm.transformers_utils.tokenizer import AnyTokenizer
LogitsProcessor: TypeAlias = (
Callable[[list[int], torch.Tensor], torch.Tensor]
| Callable[[list[int], list[int], torch.Tensor], torch.Tensor]
)
"""LogitsProcessor is a function that takes a list
of previously generated tokens, the logits tensor
for the next token and, optionally, prompt tokens as a
first argument, and returns a modified tensor of logits
to sample from."""
def get_bad_words_logits_processors(
    bad_words: list[str], tokenizer: AnyTokenizer
) -> list[LogitsProcessor]:
    """Build the logits processors that forbid generating any of ``bad_words``.

    Each word (with leading whitespace stripped) is tokenized twice, so that
    it is prohibited both at the beginning and in the middle of text (this is
    related to the tokenizer's ``add_prefix_space`` behaviour):

    * without a leading space. This variant is always kept.
    * with a single leading space. This variant is kept only when it has the
      same number of tokens as the plain variant but a different first token,
      i.e. the space produced a new word token. A space that is tokenized
      separately would change the length and is not kept.

    Words whose plain variant encodes to no tokens (e.g. empty or
    whitespace-only entries) are skipped. They have nothing to ban, and an
    empty token list would otherwise raise ``IndexError``.

    Args:
        bad_words: Words or phrases that must not be generated.
        tokenizer: Tokenizer used to turn each word into token ids.

    Returns:
        A one-element list holding a ``NoBadWordsLogitsProcessor`` built from
        all collected token-id sequences.
    """
    bad_words_ids: list[list[int]] = []

    for bad_word in bad_words:
        stripped_word = bad_word.lstrip()

        # Variant without a leading space.
        plain_ids = tokenizer.encode(text=stripped_word, add_special_tokens=False)
        if not plain_ids:
            continue
        bad_words_ids.append(plain_ids)

        # Variant with a leading space: keep it only if the prefix space
        # produces a new word token of the same length.
        spaced_ids = tokenizer.encode(
            text=" " + stripped_word, add_special_tokens=False
        )
        if (
            spaced_ids
            and spaced_ids[0] != plain_ids[0]
            and len(spaced_ids) == len(plain_ids)
        ):
            bad_words_ids.append(spaced_ids)

    return [NoBadWordsLogitsProcessor(bad_words_ids=bad_words_ids)]
class NoBadWordsLogitsProcessor:
    """Logits processor that prevents generation of given token-id sequences.

    Each entry of ``bad_words_ids`` is a sequence of token ids that must not
    appear in the output:

    * Single-token entries are banned unconditionally. This uses a static
      bias vector built lazily on the first call.
    * Multi-token entries ban only their last token, and only when the
      previously generated tokens end with the entry's remaining prefix.

    Empty entries have no token to ban and are ignored.
    """

    _SMALLEST_LOGIT = float("-inf")
    _NEUTRAL_LOGIT = 0.0

    def __init__(self, bad_words_ids: list[list[int]]):
        self.bad_words_ids = bad_words_ids
        # Static bias banning single-token bad words. It is created on the
        # first call, once vocab size and device are known from ``logits``.
        self.word_bias: torch.Tensor | None = None

    def __call__(
        self,
        past_tokens_ids: Sequence[int],
        logits: torch.FloatTensor,
    ) -> torch.Tensor:
        """Return a copy of ``logits`` with every currently banned token at -inf.

        Args:
            past_tokens_ids: Tokens generated so far for this sequence.
            logits: Next-token logits, indexed by token id along the first
                dimension (presumably a 1-D ``(vocab_size,)`` tensor — the
                bias vector is built with that shape; confirm with callers).

        Returns:
            A new tensor. The caller's ``logits`` is not modified.

        Raises:
            ValueError: On the first call, if any bad token id lies outside
                ``[0, vocab_size)``.
        """
        if self.word_bias is None:
            self._init_word_bias(logits=logits)

        # Last tokens of multi-token bad words whose prefix was just generated.
        banned_last_tokens: list[int] = []
        for bad_word_ids in self.bad_words_ids:
            # Single-token words are already covered by ``self.word_bias``;
            # empty entries have no token to ban.
            if len(bad_word_ids) <= 1:
                continue

            prefix_length = len(bad_word_ids) - 1
            # Not enough history yet for the prefix to have been generated.
            if prefix_length > len(past_tokens_ids):
                continue

            actual_prefix = past_tokens_ids[-prefix_length:]
            expected_prefix = bad_word_ids[:prefix_length]
            if tuple(actual_prefix) == tuple(expected_prefix):
                banned_last_tokens.append(bad_word_ids[-1])

        # ``+`` allocates a fresh tensor, so the in-place write below never
        # touches the caller's ``logits``.
        logits = logits + self.word_bias
        if banned_last_tokens:
            # A single indexed write instead of one tensor op per bad word.
            logits[banned_last_tokens] = self._SMALLEST_LOGIT
        return logits

    def _init_word_bias(self, logits: torch.FloatTensor) -> None:
        """Validate token ids and build ``self.word_bias`` from single-token words.

        The bias is a float32 vector of length ``logits.shape[-1]`` on
        ``logits.device``. It holds -inf at each single-token bad word and 0
        elsewhere.
        """
        # Code based on NoBadWordsLogitsProcessor and SequenceBiasLogitsProcessor # noqa: E501
        # from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py
        vocab_size = logits.shape[-1]
        self._check_token_ids_bounds(vocab_size=vocab_size)

        self.word_bias = torch.zeros(
            (vocab_size,), dtype=torch.float, device=logits.device
        )

        single_token_ids = [
            bad_word_ids[0]
            for bad_word_ids in self.bad_words_ids
            if len(bad_word_ids) == 1
        ]
        if single_token_ids:
            self.word_bias[single_token_ids] = self._SMALLEST_LOGIT

    def _check_token_ids_bounds(self, vocab_size: int) -> None:
        """Raise ``ValueError`` listing every bad token id outside ``[0, vocab_size)``."""
        invalid_token_ids = []
        for bad_word_ids in self.bad_words_ids:
            for token_id in bad_word_ids:
                if token_id < 0 or token_id >= vocab_size:
                    invalid_token_ids.append(token_id)

        if invalid_token_ids:
            raise ValueError(
                f"The model vocabulary size is {vocab_size},"
                f" but the following tokens"
                f" were specified as bad: {invalid_token_ids}."
                f" All token id values should be integers satisfying:"
                f" 0 <= token_id < {vocab_size}."
            )