# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""This example demonstrates instantiating vLLM with a custom logits processor
class object.

For a basic example of implementing a custom logits processor, see
the `DummyLogitsProcessor` implementation in `vllm/test_utils.py`.

For testing purposes, a dummy logits processor is employed which, if
`target_token` is passed as a keyword argument to `SamplingParams.extra_args`,
will mask out all tokens except `target_token`.

A batch is constructed with `temperature=0.0` and 50% of requests specifying
`target_token`, and for these requests - and *only* these requests - we
expect the `target_token` to be decoded in each step, yielding an output
similar to that shown below:

Generated Outputs:
------------------------------------------------------------
Prompt: 'Hello, my name is'
Output: " ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '"
------------------------------------------------------------
Prompt: 'The president of the United States is'
Output: " not a racist. He is a racist.\nHe's a racist because he"
------------------------------------------------------------
Prompt: 'The capital of France is'
Output: ' also also also also also also also also also also also also also
also also also'
------------------------------------------------------------
Prompt: 'The future of AI is'
Output: ' in the hands of the people.\n\nThe future of AI is in the'
------------------------------------------------------------
"""

from typing import Optional

import torch

from vllm import LLM, SamplingParams
from vllm.config import VllmConfig
from vllm.v1.sample.logits_processor import (
    BatchUpdate,
    LogitsProcessor,
    MoveDirectionality,
)


# Hypothetical custom logits processor
class DummyLogitsProcessor(LogitsProcessor):
    """Fake logits processor to support unit testing and examples"""

    def __init__(
        self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool
    ):
        # Maps a request's row index in the batch to the `target_token` id
        # that request specified via `SamplingParams.extra_args`.
        self.req_info: dict[int, int] = {}

    def is_argmax_invariant(self) -> bool:
        """Masking out all but one token can change the argmax, so this
        processor is not argmax-invariant (it does impact greedy sampling)."""
        return False
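
    # `update_state` keeps `req_info` in sync with vLLM's persistent batch:
    # entries in `batch_update.added` carry the request's batch index and its
    # SamplingParams (the remaining tuple fields are unused here), while
    # `removed` and `moved` report batch slots that were freed or reordered
    # between engine steps.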
    def update_state(self, batch_update: Optional[BatchUpdate]):
        if not batch_update:
            return

        # Process added requests.
        for index, params, _, _ in batch_update.added:
            assert params is not None
            if params.extra_args and (
                target_token := params.extra_args.get("target_token")
            ):
                self.req_info[index] = target_token

        if self.req_info:
            # Process removed requests.
            for index in batch_update.removed:
                self.req_info.pop(index, None)

            # Process moved requests, unidirectional move (a->b) and swap
            # (a<->b)
            for adx, bdx, direct in batch_update.moved:
                a_val = self.req_info.pop(adx, None)
                b_val = self.req_info.pop(bdx, None)
                if a_val is not None:
                    self.req_info[bdx] = a_val
                if direct == MoveDirectionality.SWAP and b_val is not None:
                    self.req_info[adx] = b_val
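
    # `apply` receives the raw logits for the whole batch (one row per
    # request) and may edit them before sampling; rows without a
    # `target_token` entry are left untouched.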
    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        if not self.req_info:
            return logits

        # Save target values before modification
        rows_list = list(self.req_info.keys())
        cols = torch.tensor(
            [self.req_info[i] for i in rows_list],
            dtype=torch.long,
            device=logits.device,
        )
        rows = torch.tensor(rows_list, dtype=torch.long, device=logits.device)
        values_to_keep = logits[rows, cols].clone()

        # Mask all but target tokens
        logits[rows] = float("-inf")
        logits[rows, cols] = values_to_keep
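
        # The logits tensor is modified in place; returning it matches the
        # LogitsProcessor interface used by the sampler.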
        return logits


# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a mixture of requests that do and don't use the dummy logits
# processor; requests 0 and 2 opt in via `target_token`.
sampling_params_list = [
    SamplingParams(temperature=0.0, extra_args={"target_token": 128}),
    SamplingParams(temperature=0.0),
    SamplingParams(temperature=0.0, extra_args={"target_token": 67}),
    SamplingParams(temperature=0.0),
]


def main():
    # Create an LLM.
    llm = LLM(
        model="facebook/opt-125m",
        logits_processors=[DummyLogitsProcessor],
    )
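    # Note: the logits processor *class* is passed (not an instance); the
    # engine is expected to instantiate it with the arguments accepted by
    # `DummyLogitsProcessor.__init__` above (config, device, pin-memory flag).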
    # Generate texts from the prompts.
    # The output is a list of RequestOutput objects
    # that contain the prompt, generated text, and other information.
    outputs = llm.generate(prompts, sampling_params_list)
    # Print the outputs.
    print("\nGenerated Outputs:\n" + "-" * 60)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}")
        print(f"Output: {generated_text!r}")
        print("-" * 60)


if __name__ == "__main__":
    main()