# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This example shows how to use vLLM for running offline inference with
the explicit/implicit prompt format on enc-dec LMMs for text generation.
"""

import time
from collections.abc import Sequence
from dataclasses import asdict
from typing import NamedTuple

from vllm import LLM, EngineArgs, PromptType, SamplingParams
from vllm.assets.audio import AudioAsset
from vllm.assets.image import ImageAsset
from vllm.multimodal.utils import fetch_image
from vllm.utils import FlexibleArgumentParser


class ModelRequestData(NamedTuple):
    engine_args: EngineArgs
    prompts: Sequence[PromptType]


def run_donut():
    engine_args = EngineArgs(
        model="naver-clova-ix/donut-base-finetuned-docvqa",
        max_num_seqs=2,
        limit_mm_per_prompt={"image": 1},
        dtype="float16",
        hf_overrides={"architectures": ["DonutForConditionalGeneration"]},
    )

    # The input image size for donut-base-finetuned-docvqa is 2560 x 1920,
    # and the patch_size is 4 x 4.
    # Therefore, the initial number of patches is:
    #   Height: 1920 / 4 = 480 patches
    #   Width: 2560 / 4 = 640 patches
    # The Swin model uses a staged downsampling approach,
    # defined by the "depths": [2, 2, 14, 2] configuration.
    # Before entering stages 2, 3, and 4, a "Patch Merging" operation is
    # performed, which halves the feature map's dimensions (dividing both
    # height and width by 2):
    #   Before Stage 2: 480 x 640 -> (480/2) x (640/2) = 240 x 320
    #   Before Stage 3: 240 x 320 -> (240/2) x (320/2) = 120 x 160
    #   Before Stage 4: 120 x 160 -> (120/2) x (160/2) = 60 x 80
    # Because vLLM needs to fill the image features with an encoder_prompt,
    # and the encoder_prompt will have `<pad>` tokens added when tokenized,
    # we need to construct an encoder_prompt with a length of 60 x 80 - 1 = 4799.
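    # Minimal sketch of the arithmetic above, using hard-coded values taken
    # from the comments (assumed Donut/Swin settings, not read from the model
    # config at runtime):
    image_height, image_width = 1920, 2560
    patch_size = 4
    num_patch_merges = 3  # one merge before each of stages 2, 3, and 4
    grid_h = image_height // patch_size // 2**num_patch_merges  # 60
    grid_w = image_width // patch_size // 2**num_patch_merges  # 80
    assert grid_h * grid_w - 1 == 4799
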
    prompts = [
        {
            "encoder_prompt": {
                "prompt": "".join(["$"] * 4799),
                "multi_modal_data": {
                    "image": fetch_image(
                        "https://huggingface.co/datasets/hf-internal-testing/example-documents/resolve/main/jpeg_images/0.jpg"  # noqa: E501
                    )
                },
            },
            "decoder_prompt": "<s_docvqa><s_question>What time is the coffee break?</s_question><s_answer>",  # noqa: E501
        },
    ]

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
    )


def run_florence2():
    engine_args = EngineArgs(
        model="microsoft/Florence-2-large",
        tokenizer="Isotr0py/Florence-2-tokenizer",
        max_num_seqs=8,
        trust_remote_code=True,
        limit_mm_per_prompt={"image": 1},
        dtype="half",
    )

    prompts = [
        {  # Implicit prompt with task token
            "prompt": "<DETAILED_CAPTION>",
            "multi_modal_data": {"image": ImageAsset("stop_sign").pil_image},
        },
        {  # Explicit encoder/decoder prompt
            "encoder_prompt": {
                "prompt": "Describe in detail what is shown in the image.",
                "multi_modal_data": {"image": ImageAsset("cherry_blossom").pil_image},
            },
            "decoder_prompt": "",
        },
    ]

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
    )


def run_mllama():
    engine_args = EngineArgs(
        model="meta-llama/Llama-3.2-11B-Vision-Instruct",
        max_model_len=8192,
        max_num_seqs=2,
        limit_mm_per_prompt={"image": 1},
        dtype="half",
    )

    prompts = [
        {  # Implicit prompt
            "prompt": "<|image|><|begin_of_text|>What is the content of this image?",  # noqa: E501
            "multi_modal_data": {
                "image": ImageAsset("stop_sign").pil_image,
            },
        },
        {  # Explicit prompt
            "encoder_prompt": {
                "prompt": "<|image|>",
                "multi_modal_data": {
                    "image": ImageAsset("stop_sign").pil_image,
                },
            },
            "decoder_prompt": "<|image|><|begin_of_text|>Please describe the image.",  # noqa: E501
        },
    ]

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
    )


def run_whisper():
    engine_args = EngineArgs(
        model="openai/whisper-large-v3-turbo",
        max_model_len=448,
        max_num_seqs=16,
        limit_mm_per_prompt={"audio": 1},
        dtype="half",
    )

    prompts = [
        {  # Implicit prompt
            "prompt": "<|startoftranscript|>",
            "multi_modal_data": {
                "audio": AudioAsset("mary_had_lamb").audio_and_sample_rate,
            },
        },
        {  # Explicit encoder/decoder prompt
            "encoder_prompt": {
                "prompt": "",
                "multi_modal_data": {
                    "audio": AudioAsset("winning_call").audio_and_sample_rate,
                },
            },
            "decoder_prompt": "<|startoftranscript|>",
        },
    ]

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
    )


model_example_map = {
    "donut": run_donut,
    "florence2": run_florence2,
    "mllama": run_mllama,
    "whisper": run_whisper,
}


def parse_args():
    parser = FlexibleArgumentParser(
        description="Demo on using vLLM for offline inference with "
        "encoder-decoder multimodal models for text generation"
    )
    parser.add_argument(
        "--model-type",
        "-m",
        type=str,
        default="mllama",
        choices=model_example_map.keys(),
        help='Huggingface "model_type".',
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Set the seed when initializing `vllm.LLM`.",
    )
    return parser.parse_args()


def main(args):
    model = args.model_type
    if model not in model_example_map:
        raise ValueError(f"Model type {model} is not supported.")

    req_data = model_example_map[model]()

    # Disable other modalities to save memory; per-model limits set in the
    # run_* functions take precedence over these zeroed defaults.
    default_limits = {"image": 0, "video": 0, "audio": 0}
    req_data.engine_args.limit_mm_per_prompt = default_limits | dict(
        req_data.engine_args.limit_mm_per_prompt or {}
    )

    engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
    llm = LLM(**engine_args)

    prompts = req_data.prompts

    # Create a sampling params object.
    sampling_params = SamplingParams(
        temperature=0,
        top_p=1.0,
        max_tokens=64,
        skip_special_tokens=False,
    )

    start = time.time()

    # Generate output tokens from the prompts. The output is a list of
    # RequestOutput objects that contain the prompt, generated
    # text, and other information.
    outputs = llm.generate(prompts, sampling_params)

    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Decoder prompt: {prompt!r}, Generated text: {generated_text!r}")

    duration = time.time() - start

    print("Duration:", duration)
    print("RPS:", len(prompts) / duration)


if __name__ == "__main__":
    args = parse_args()
    main(args)