Convert examples to ruff-format (#18400)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
@@ -17,7 +17,7 @@ repos:
- id: ruff
args: [--output-format, github, --fix]
- id: ruff-format
files: ^(.buildkite|benchmarks)/.*
files: ^(.buildkite|benchmarks|examples)/.*
- repo: https://github.com/codespell-project/codespell
rev: v2.4.1
hooks:
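For context, once the files pattern of the ruff-format hook includes examples, running pre-commit run --all-files reformats the examples tree, which is what produces the changes in the hunks below. A minimal sketch of the resulting hook block in .pre-commit-config.yaml (the repo URL and rev here are assumptions for illustration, not part of this hunk):

- repo: https://github.com/astral-sh/ruff-pre-commit
  rev: v0.11.7  # placeholder rev; use the rev already pinned in the real config
  hooks:
  - id: ruff
    args: [--output-format, github, --fix]
  - id: ruff-format
    files: ^(.buildkite|benchmarks|examples)/.*
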
@@ -1,11 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
"""
This example shows how to use vLLM for running offline inference
This example shows how to use vLLM for running offline inference
with the correct prompt format on audio language models.

For most models, the prompt format should follow corresponding examples
on HuggingFace model repository.
"""

import os
from dataclasses import asdict
from typing import NamedTuple, Optional
@@ -22,7 +23,7 @@ audio_assets = [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]
question_per_audio_count = {
0: "What is 1+1?",
1: "What is recited in the audio?",
2: "What sport and what nursery rhyme are referenced?"
2: "What sport and what nursery rhyme are referenced?",
}

@@ -72,8 +73,7 @@ def run_granite_speech(question: str, audio_count: int) -> ModelRequestData:
# MiniCPM-O
def run_minicpmo(question: str, audio_count: int) -> ModelRequestData:
model_name = "openbmb/MiniCPM-o-2_6"
tokenizer = AutoTokenizer.from_pretrained(model_name,
trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
engine_args = EngineArgs(
model=model_name,
trust_remote_code=True,
@@ -82,19 +82,18 @@ def run_minicpmo(question: str, audio_count: int) -> ModelRequestData:
limit_mm_per_prompt={"audio": audio_count},
)

stop_tokens = ['<|im_end|>', '<|endoftext|>']
stop_tokens = ["<|im_end|>", "<|endoftext|>"]
stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]

audio_placeholder = "(<audio>./</audio>)" * audio_count
audio_chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n<|spk_bos|><|spk|><|spk_eos|><|tts_bos|>' }}{% endif %}" # noqa: E501
messages = [{
'role': 'user',
'content': f'{audio_placeholder}\n{question}'
}]
prompt = tokenizer.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True,
chat_template=audio_chat_template)
messages = [{"role": "user", "content": f"{audio_placeholder}\n{question}"}]
prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
chat_template=audio_chat_template,
)

return ModelRequestData(
engine_args=engine_args,
@@ -113,7 +112,7 @@ def run_phi4mm(question: str, audio_count: int) -> ModelRequestData:
# Since the vision-lora and speech-lora co-exist with the base model,
# we have to manually specify the path of the lora weights.
speech_lora_path = os.path.join(model_path, "speech-lora")
placeholders = "".join([f"<|audio_{i+1}|>" for i in range(audio_count)])
placeholders = "".join([f"<|audio_{i + 1}|>" for i in range(audio_count)])

prompts = f"<|user|>{placeholders}{question}<|end|><|assistant|>"

@@ -145,15 +144,19 @@ def run_qwen2_audio(question: str, audio_count: int) -> ModelRequestData:
limit_mm_per_prompt={"audio": audio_count},
)

audio_in_prompt = "".join([
f"Audio {idx+1}: "
f"<|audio_bos|><|AUDIO|><|audio_eos|>\n" for idx in range(audio_count)
])
audio_in_prompt = "".join(
[
f"Audio {idx + 1}: <|audio_bos|><|AUDIO|><|audio_eos|>\n"
for idx in range(audio_count)
]
)

prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
"<|im_start|>user\n"
f"{audio_in_prompt}{question}<|im_end|>\n"
"<|im_start|>assistant\n")
prompt = (
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
"<|im_start|>user\n"
f"{audio_in_prompt}{question}<|im_end|>\n"
"<|im_start|>assistant\n"
)

return ModelRequestData(
engine_args=engine_args,
@@ -172,19 +175,22 @@ def run_qwen2_5_omni(question: str, audio_count: int):
limit_mm_per_prompt={"audio": audio_count},
)

audio_in_prompt = "".join([
"<|audio_bos|><|AUDIO|><|audio_eos|>\n" for idx in range(audio_count)
])
audio_in_prompt = "".join(
["<|audio_bos|><|AUDIO|><|audio_eos|>\n" for idx in range(audio_count)]
)

default_system = (
"You are Qwen, a virtual human developed by the Qwen Team, Alibaba "
"Group, capable of perceiving auditory and visual inputs, as well as "
"generating text and speech.")
"generating text and speech."
)

prompt = (f"<|im_start|>system\n{default_system}<|im_end|>\n"
"<|im_start|>user\n"
f"{audio_in_prompt}{question}<|im_end|>\n"
"<|im_start|>assistant\n")
prompt = (
f"<|im_start|>system\n{default_system}<|im_end|>\n"
"<|im_start|>user\n"
f"{audio_in_prompt}{question}<|im_end|>\n"
"<|im_start|>assistant\n"
)
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
@@ -196,13 +202,10 @@ def run_ultravox(question: str, audio_count: int) -> ModelRequestData:
model_name = "fixie-ai/ultravox-v0_5-llama-3_2-1b"

tokenizer = AutoTokenizer.from_pretrained(model_name)
messages = [{
'role': 'user',
'content': "<|audio|>\n" * audio_count + question
}]
prompt = tokenizer.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)
messages = [{"role": "user", "content": "<|audio|>\n" * audio_count + question}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)

engine_args = EngineArgs(
model=model_name,
@@ -220,8 +223,7 @@ def run_ultravox(question: str, audio_count: int) -> ModelRequestData:

# Whisper
def run_whisper(question: str, audio_count: int) -> ModelRequestData:
assert audio_count == 1, (
"Whisper only support single audio input per prompt")
assert audio_count == 1, "Whisper only support single audio input per prompt"
model_name = "openai/whisper-large-v3-turbo"

prompt = "<|startoftranscript|>"
@@ -252,27 +254,33 @@ model_example_map = {

def parse_args():
parser = FlexibleArgumentParser(
description='Demo on using vLLM for offline inference with '
'audio language models')
parser.add_argument('--model-type',
'-m',
type=str,
default="ultravox",
choices=model_example_map.keys(),
help='Huggingface "model_type".')
parser.add_argument('--num-prompts',
type=int,
default=1,
help='Number of prompts to run.')
parser.add_argument("--num-audios",
type=int,
default=1,
choices=[0, 1, 2],
help="Number of audio items per prompt.")
parser.add_argument("--seed",
type=int,
default=None,
help="Set the seed when initializing `vllm.LLM`.")
description="Demo on using vLLM for offline inference with "
"audio language models"
)
parser.add_argument(
"--model-type",
"-m",
type=str,
default="ultravox",
choices=model_example_map.keys(),
help='Huggingface "model_type".',
)
parser.add_argument(
"--num-prompts", type=int, default=1, help="Number of prompts to run."
)
parser.add_argument(
"--num-audios",
type=int,
default=1,
choices=[0, 1, 2],
help="Number of audio items per prompt.",
)
parser.add_argument(
"--seed",
type=int,
default=None,
help="Set the seed when initializing `vllm.LLM`.",
)

return parser.parse_args()

@@ -283,29 +291,30 @@ def main(args):
raise ValueError(f"Model type {model} is not supported.")

audio_count = args.num_audios
req_data = model_example_map[model](question_per_audio_count[audio_count],
audio_count)
req_data = model_example_map[model](
question_per_audio_count[audio_count], audio_count
)

# Disable other modalities to save memory
default_limits = {"image": 0, "video": 0, "audio": 0}
req_data.engine_args.limit_mm_per_prompt = default_limits | dict(
req_data.engine_args.limit_mm_per_prompt or {})
req_data.engine_args.limit_mm_per_prompt or {}
)

engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
llm = LLM(**engine_args)

# We set temperature to 0.2 so that outputs can be different
# even when all prompts are identical when running batch inference.
sampling_params = SamplingParams(temperature=0.2,
max_tokens=64,
stop_token_ids=req_data.stop_token_ids)
sampling_params = SamplingParams(
temperature=0.2, max_tokens=64, stop_token_ids=req_data.stop_token_ids
)

mm_data = {}
if audio_count > 0:
mm_data = {
"audio": [
asset.audio_and_sample_rate
for asset in audio_assets[:audio_count]
asset.audio_and_sample_rate for asset in audio_assets[:audio_count]
]
}

@@ -315,8 +324,9 @@ def main(args):
# Batch inference
inputs = [inputs] * args.num_prompts
# Add LoRA request if applicable
lora_request = (req_data.lora_requests *
args.num_prompts if req_data.lora_requests else None)
lora_request = (
req_data.lora_requests * args.num_prompts if req_data.lora_requests else None
)

outputs = llm.generate(
inputs,

@@ -16,13 +16,16 @@ but ask different questions.
Run:
python examples/offline_inference/automatic_prefix_caching.py
"""

import time

from vllm import LLM, SamplingParams

# ruff: noqa: E501
# A prompt containing a large markdown table. The table is randomly generated by GPT-4.
LONG_PROMPT = "You are a helpful assistant in recognizes the content of tables in markdown format. Here is a table as follows.\n# Table\n" + """
LONG_PROMPT = (
"You are a helpful assistant in recognizes the content of tables in markdown format. Here is a table as follows.\n# Table\n"
+ """
| ID | Name | Age | Occupation | Country | Email | Phone Number | Address |
|-----|---------------|-----|---------------|---------------|------------------------|----------------|------------------------------|
| 1 | John Doe | 29 | Engineer | USA | john.doe@example.com | 555-1234 | 123 Elm St, Springfield, IL |
@@ -56,6 +59,7 @@ LONG_PROMPT = "You are a helpful assistant in recognizes the content of tables i
| 29 | Amy White | 33 | Musician | New Zealand | amy.w@example.com | 555-5658 | 159 Maple St, Wellington, NZ |
| 30 | Ben Black | 38 | Chef | Ireland | ben.b@example.com | 555-7870 | 246 Fir St, Waterford, IE |
"""
)

def get_generation_time(llm, sampling_params, prompts):
@@ -72,7 +76,7 @@ def get_generation_time(llm, sampling_params, prompts):

def main():
# set enable_prefix_caching=True to enable APC
llm = LLM(model='lmsys/longchat-13b-16k', enable_prefix_caching=True)
llm = LLM(model="lmsys/longchat-13b-16k", enable_prefix_caching=True)

sampling_params = SamplingParams(temperature=0, max_tokens=100)

@@ -80,8 +84,8 @@ def main():
get_generation_time(
llm,
sampling_params,
LONG_PROMPT +
"Question: what is the age of John Doe? Your answer: The age of John Doe is ",
LONG_PROMPT
+ "Question: what is the age of John Doe? Your answer: The age of John Doe is ",
)

# Querying the age of Zack Blue
@@ -89,8 +93,8 @@ def main():
get_generation_time(
llm,
sampling_params,
LONG_PROMPT +
"Question: what is the age of Zack Blue? Your answer: The age of Zack Blue is ",
LONG_PROMPT
+ "Question: what is the age of Zack Blue? Your answer: The age of Zack Blue is ",
)

@@ -56,22 +56,12 @@ def main(args: dict):

# In this script, we demonstrate how to pass input to the chat method:
conversation = [
{
"role": "system",
"content": "You are a helpful assistant"
},
{"role": "system", "content": "You are a helpful assistant"},
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hello! How can I assist you today?"},
{
"role": "user",
"content": "Hello"
},
{
"role": "assistant",
"content": "Hello! How can I assist you today?"
},
{
"role": "user",
"content":
"Write an essay about the importance of higher education.",
"content": "Write an essay about the importance of higher education.",
},
]
outputs = llm.chat(conversation, sampling_params, use_tqdm=False)

@@ -10,9 +10,9 @@ def parse_args():
parser = FlexibleArgumentParser()
parser = EngineArgs.add_cli_args(parser)
# Set example specific arguments
parser.set_defaults(model="jason9693/Qwen2.5-1.5B-apeach",
task="classify",
enforce_eager=True)
parser.set_defaults(
model="jason9693/Qwen2.5-1.5B-apeach", task="classify", enforce_eager=True
)
return parser.parse_args()

@@ -36,10 +36,11 @@ def main(args: Namespace):
print("\nGenerated Outputs:\n" + "-" * 60)
for prompt, output in zip(prompts, outputs):
probs = output.outputs.probs
probs_trimmed = ((str(probs[:16])[:-1] +
", ...]") if len(probs) > 16 else probs)
print(f"Prompt: {prompt!r} \n"
f"Class Probabilities: {probs_trimmed} (size={len(probs)})")
probs_trimmed = (str(probs[:16])[:-1] + ", ...]") if len(probs) > 16 else probs
print(
f"Prompt: {prompt!r} \n"
f"Class Probabilities: {probs_trimmed} (size={len(probs)})"
)
print("-" * 60)

@@ -10,9 +10,9 @@ def parse_args():
parser = FlexibleArgumentParser()
parser = EngineArgs.add_cli_args(parser)
# Set example specific arguments
parser.set_defaults(model="intfloat/e5-mistral-7b-instruct",
task="embed",
enforce_eager=True)
parser.set_defaults(
model="intfloat/e5-mistral-7b-instruct", task="embed", enforce_eager=True
)
return parser.parse_args()

@@ -36,10 +36,10 @@ def main(args: Namespace):
print("\nGenerated Outputs:\n" + "-" * 60)
for prompt, output in zip(prompts, outputs):
embeds = output.outputs.embedding
embeds_trimmed = ((str(embeds[:16])[:-1] +
", ...]") if len(embeds) > 16 else embeds)
print(f"Prompt: {prompt!r} \n"
f"Embeddings: {embeds_trimmed} (size={len(embeds)})")
embeds_trimmed = (
(str(embeds[:16])[:-1] + ", ...]") if len(embeds) > 16 else embeds
)
print(f"Prompt: {prompt!r} \nEmbeddings: {embeds_trimmed} (size={len(embeds)})")
print("-" * 60)

@@ -10,9 +10,9 @@ def parse_args():
parser = FlexibleArgumentParser()
parser = EngineArgs.add_cli_args(parser)
# Set example specific arguments
parser.set_defaults(model="BAAI/bge-reranker-v2-m3",
task="score",
enforce_eager=True)
parser.set_defaults(
model="BAAI/bge-reranker-v2-m3", task="score", enforce_eager=True
)
return parser.parse_args()

@ -17,12 +17,14 @@ Ray Data provides functionality for:
|
||||
Learn more about Ray Data's LLM integration:
|
||||
https://docs.ray.io/en/latest/data/working-with-llms.html
|
||||
"""
|
||||
|
||||
import ray
|
||||
from packaging.version import Version
|
||||
from ray.data.llm import build_llm_processor, vLLMEngineProcessorConfig
|
||||
|
||||
assert Version(ray.__version__) >= Version(
|
||||
"2.44.1"), "Ray version must be at least 2.44.1"
|
||||
assert Version(ray.__version__) >= Version("2.44.1"), (
|
||||
"Ray version must be at least 2.44.1"
|
||||
)
|
||||
|
||||
# Uncomment to reduce clutter in stdout
|
||||
# ray.init(log_to_driver=False)
|
||||
@ -53,20 +55,18 @@ config = vLLMEngineProcessorConfig(
|
||||
vllm_processor = build_llm_processor(
|
||||
config,
|
||||
preprocess=lambda row: dict(
|
||||
messages=[{
|
||||
"role": "system",
|
||||
"content": "You are a bot that responds with haikus."
|
||||
}, {
|
||||
"role": "user",
|
||||
"content": row["text"]
|
||||
}],
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a bot that responds with haikus."},
|
||||
{"role": "user", "content": row["text"]},
|
||||
],
|
||||
sampling_params=dict(
|
||||
temperature=0.3,
|
||||
max_tokens=250,
|
||||
)),
|
||||
),
|
||||
),
|
||||
postprocess=lambda row: dict(
|
||||
answer=row["generated_text"],
|
||||
**row # This will return all the original columns in the dataset.
|
||||
**row, # This will return all the original columns in the dataset.
|
||||
),
|
||||
)
|
||||
|
||||
|
@ -50,87 +50,93 @@ model_name = "mistralai/Mistral-7B-Instruct-v0.3"
|
||||
# or any other mistral model with function calling ability
|
||||
|
||||
sampling_params = SamplingParams(max_tokens=8192, temperature=0.0)
|
||||
llm = LLM(model=model_name,
|
||||
tokenizer_mode="mistral",
|
||||
config_format="mistral",
|
||||
load_format="mistral")
|
||||
llm = LLM(
|
||||
model=model_name,
|
||||
tokenizer_mode="mistral",
|
||||
config_format="mistral",
|
||||
load_format="mistral",
|
||||
)
|
||||
|
||||
|
||||
def generate_random_id(length=9):
|
||||
characters = string.ascii_letters + string.digits
|
||||
random_id = ''.join(random.choice(characters) for _ in range(length))
|
||||
random_id = "".join(random.choice(characters) for _ in range(length))
|
||||
return random_id
|
||||
|
||||
|
||||
# simulate an API that can be called
|
||||
def get_current_weather(city: str, state: str, unit: 'str'):
|
||||
return (f"The weather in {city}, {state} is 85 degrees {unit}. It is "
|
||||
"partly cloudly, with highs in the 90's.")
|
||||
def get_current_weather(city: str, state: str, unit: "str"):
|
||||
return (
|
||||
f"The weather in {city}, {state} is 85 degrees {unit}. It is "
|
||||
"partly cloudly, with highs in the 90's."
|
||||
)
|
||||
|
||||
|
||||
tool_functions = {"get_current_weather": get_current_weather}
|
||||
|
||||
tools = [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type":
|
||||
"string",
|
||||
"description":
|
||||
"The city to find the weather for, e.g. 'San Francisco'"
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "The city to find the weather for, e.g. 'San Francisco'",
|
||||
},
|
||||
"state": {
|
||||
"type": "string",
|
||||
"description": "the two-letter abbreviation for the state that the city is"
|
||||
" in, e.g. 'CA' which would mean 'California'",
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"description": "The unit to fetch the temperature in",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
},
|
||||
},
|
||||
"state": {
|
||||
"type":
|
||||
"string",
|
||||
"description":
|
||||
"the two-letter abbreviation for the state that the city is"
|
||||
" in, e.g. 'CA' which would mean 'California'"
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"description": "The unit to fetch the temperature in",
|
||||
"enum": ["celsius", "fahrenheit"]
|
||||
}
|
||||
"required": ["city", "state", "unit"],
|
||||
},
|
||||
"required": ["city", "state", "unit"]
|
||||
}
|
||||
},
|
||||
}
|
||||
}]
|
||||
]
|
||||
|
||||
messages = [{
|
||||
"role":
|
||||
"user",
|
||||
"content":
|
||||
"Can you tell me what the temperate will be in Dallas, in fahrenheit?"
|
||||
}]
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Can you tell me what the temperate will be in Dallas, in fahrenheit?",
|
||||
}
|
||||
]
|
||||
|
||||
outputs = llm.chat(messages, sampling_params=sampling_params, tools=tools)
|
||||
output = outputs[0].outputs[0].text.strip()
|
||||
|
||||
# append the assistant message
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": output,
|
||||
})
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": output,
|
||||
}
|
||||
)
|
||||
|
||||
# let's now actually parse and execute the model's output simulating an API call by using the
|
||||
# above defined function
|
||||
tool_calls = json.loads(output)
|
||||
tool_answers = [
|
||||
tool_functions[call['name']](**call['arguments']) for call in tool_calls
|
||||
tool_functions[call["name"]](**call["arguments"]) for call in tool_calls
|
||||
]
|
||||
|
||||
# append the answer as a tool message and let the LLM give you an answer
|
||||
messages.append({
|
||||
"role": "tool",
|
||||
"content": "\n\n".join(tool_answers),
|
||||
"tool_call_id": generate_random_id(),
|
||||
})
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": "\n\n".join(tool_answers),
|
||||
"tool_call_id": generate_random_id(),
|
||||
}
|
||||
)
|
||||
|
||||
outputs = llm.chat(messages, sampling_params, tools=tools)
|
||||
|
||||
|
@ -27,6 +27,7 @@ Multi-node:
|
||||
--master-addr=10.99.48.128 \
|
||||
--master-port=13345
|
||||
"""
|
||||
|
||||
import os
|
||||
from time import sleep
|
||||
|
||||
@ -36,46 +37,46 @@ from vllm.utils import get_open_port
|
||||
|
||||
def parse_args():
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="Data Parallel Inference")
|
||||
parser.add_argument("--model",
|
||||
type=str,
|
||||
default="ibm-research/PowerMoE-3b",
|
||||
help="Model name or path")
|
||||
parser.add_argument("--dp-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="Data parallel size")
|
||||
parser.add_argument("--tp-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="Tensor parallel size")
|
||||
parser.add_argument("--node-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Total number of nodes")
|
||||
parser.add_argument("--node-rank",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Rank of the current node")
|
||||
parser.add_argument("--master-addr",
|
||||
type=str,
|
||||
default="",
|
||||
help="Master node IP address")
|
||||
parser.add_argument("--master-port",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Master node port")
|
||||
parser.add_argument("--enforce-eager",
|
||||
action='store_true',
|
||||
help="Enforce eager mode execution.")
|
||||
parser.add_argument("--trust-remote-code",
|
||||
action='store_true',
|
||||
help="Trust remote code.")
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="ibm-research/PowerMoE-3b",
|
||||
help="Model name or path",
|
||||
)
|
||||
parser.add_argument("--dp-size", type=int, default=2, help="Data parallel size")
|
||||
parser.add_argument("--tp-size", type=int, default=2, help="Tensor parallel size")
|
||||
parser.add_argument(
|
||||
"--node-size", type=int, default=1, help="Total number of nodes"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--node-rank", type=int, default=0, help="Rank of the current node"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--master-addr", type=str, default="", help="Master node IP address"
|
||||
)
|
||||
parser.add_argument("--master-port", type=int, default=0, help="Master node port")
|
||||
parser.add_argument(
|
||||
"--enforce-eager", action="store_true", help="Enforce eager mode execution."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--trust-remote-code", action="store_true", help="Trust remote code."
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
|
||||
dp_master_port, GPUs_per_dp_rank, enforce_eager, trust_remote_code):
|
||||
def main(
|
||||
model,
|
||||
dp_size,
|
||||
local_dp_rank,
|
||||
global_dp_rank,
|
||||
dp_master_ip,
|
||||
dp_master_port,
|
||||
GPUs_per_dp_rank,
|
||||
enforce_eager,
|
||||
trust_remote_code,
|
||||
):
|
||||
os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
|
||||
os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank)
|
||||
os.environ["VLLM_DP_SIZE"] = str(dp_size)
|
||||
@ -110,9 +111,9 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
|
||||
# since we are doing data parallel, every rank can have different
|
||||
# sampling params. here we set different max_tokens for different
|
||||
# ranks for demonstration.
|
||||
sampling_params = SamplingParams(temperature=0.8,
|
||||
top_p=0.95,
|
||||
max_tokens=[16, 20][global_dp_rank % 2])
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0.8, top_p=0.95, max_tokens=[16, 20][global_dp_rank % 2]
|
||||
)
|
||||
|
||||
# Create an LLM.
|
||||
llm = LLM(
|
||||
@ -130,15 +131,16 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
|
||||
break
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"DP rank {global_dp_rank}, Prompt: {prompt!r}, "
|
||||
f"Generated text: {generated_text!r}")
|
||||
print(
|
||||
f"DP rank {global_dp_rank}, Prompt: {prompt!r}, "
|
||||
f"Generated text: {generated_text!r}"
|
||||
)
|
||||
|
||||
# Give engines time to pause their processing loops before exiting.
|
||||
sleep(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
args = parse_args()
|
||||
|
||||
dp_size = args.dp_size
|
||||
@ -160,20 +162,29 @@ if __name__ == "__main__":
|
||||
|
||||
procs = []
|
||||
for local_dp_rank, global_dp_rank in enumerate(
|
||||
range(node_rank * dp_per_node, (node_rank + 1) * dp_per_node)):
|
||||
proc = Process(target=main,
|
||||
args=(args.model, dp_size, local_dp_rank,
|
||||
global_dp_rank, dp_master_ip, dp_master_port,
|
||||
tp_size, args.enforce_eager,
|
||||
args.trust_remote_code))
|
||||
range(node_rank * dp_per_node, (node_rank + 1) * dp_per_node)
|
||||
):
|
||||
proc = Process(
|
||||
target=main,
|
||||
args=(
|
||||
args.model,
|
||||
dp_size,
|
||||
local_dp_rank,
|
||||
global_dp_rank,
|
||||
dp_master_ip,
|
||||
dp_master_port,
|
||||
tp_size,
|
||||
args.enforce_eager,
|
||||
args.trust_remote_code,
|
||||
),
|
||||
)
|
||||
proc.start()
|
||||
procs.append(proc)
|
||||
exit_code = 0
|
||||
for proc in procs:
|
||||
proc.join(timeout=300)
|
||||
if proc.exitcode is None:
|
||||
print(f"Killing process {proc.pid} that "
|
||||
f"didn't stop within 5 minutes.")
|
||||
print(f"Killing process {proc.pid} that didn't stop within 5 minutes.")
|
||||
proc.kill()
|
||||
exit_code = 1
|
||||
elif proc.exitcode:
|
||||
|
@ -22,17 +22,18 @@ def main():
|
||||
prompts = read_prompts()
|
||||
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)
|
||||
|
||||
llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
|
||||
enforce_eager=True,
|
||||
gpu_memory_utilization=0.8,
|
||||
max_num_batched_tokens=64,
|
||||
max_num_seqs=16,
|
||||
kv_transfer_config=KVTransferConfig(
|
||||
kv_connector="SharedStorageConnector",
|
||||
kv_role="kv_both",
|
||||
kv_connector_extra_config={
|
||||
"shared_storage_path": "local_storage"
|
||||
})) #, max_model_len=2048, max_num_batched_tokens=2048)
|
||||
llm = LLM(
|
||||
model="meta-llama/Llama-3.2-1B-Instruct",
|
||||
enforce_eager=True,
|
||||
gpu_memory_utilization=0.8,
|
||||
max_num_batched_tokens=64,
|
||||
max_num_seqs=16,
|
||||
kv_transfer_config=KVTransferConfig(
|
||||
kv_connector="SharedStorageConnector",
|
||||
kv_role="kv_both",
|
||||
kv_connector_extra_config={"shared_storage_path": "local_storage"},
|
||||
),
|
||||
) # , max_model_len=2048, max_num_batched_tokens=2048)
|
||||
|
||||
# 1ST generation (prefill instance)
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
|
@ -20,15 +20,16 @@ def main():
|
||||
|
||||
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)
|
||||
|
||||
llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
|
||||
enforce_eager=True,
|
||||
gpu_memory_utilization=0.8,
|
||||
kv_transfer_config=KVTransferConfig(
|
||||
kv_connector="SharedStorageConnector",
|
||||
kv_role="kv_both",
|
||||
kv_connector_extra_config={
|
||||
"shared_storage_path": "local_storage"
|
||||
})) #, max_model_len=2048, max_num_batched_tokens=2048)
|
||||
llm = LLM(
|
||||
model="meta-llama/Llama-3.2-1B-Instruct",
|
||||
enforce_eager=True,
|
||||
gpu_memory_utilization=0.8,
|
||||
kv_transfer_config=KVTransferConfig(
|
||||
kv_connector="SharedStorageConnector",
|
||||
kv_role="kv_both",
|
||||
kv_connector_extra_config={"shared_storage_path": "local_storage"},
|
||||
),
|
||||
) # , max_model_len=2048, max_num_batched_tokens=2048)
|
||||
|
||||
# 1ST generation (prefill instance)
|
||||
outputs = llm.generate(
|
||||
|
@ -4,6 +4,7 @@ This file demonstrates the example usage of disaggregated prefilling
|
||||
We will launch 2 vllm instances (GPU 0 for prefill and GPU 1 for decode),
|
||||
and then transfer the KV cache between them.
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
from multiprocessing import Event, Process
|
||||
@ -32,17 +33,21 @@ def run_prefill(prefill_done):
|
||||
# This instance is the prefill node (kv_producer, rank 0).
|
||||
# The number of parallel instances for KV cache transfer is set to 2,
|
||||
# as required for PyNcclConnector.
|
||||
ktc = KVTransferConfig(kv_connector="PyNcclConnector",
|
||||
kv_role="kv_producer",
|
||||
kv_rank=0,
|
||||
kv_parallel_size=2)
|
||||
ktc = KVTransferConfig(
|
||||
kv_connector="PyNcclConnector",
|
||||
kv_role="kv_producer",
|
||||
kv_rank=0,
|
||||
kv_parallel_size=2,
|
||||
)
|
||||
|
||||
# Set GPU memory utilization to 0.8 for an A6000 GPU with 40GB
|
||||
# memory. You may need to adjust the value to fit your GPU.
|
||||
llm = LLM(model="meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
kv_transfer_config=ktc,
|
||||
max_model_len=2000,
|
||||
gpu_memory_utilization=0.8)
|
||||
llm = LLM(
|
||||
model="meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
kv_transfer_config=ktc,
|
||||
max_model_len=2000,
|
||||
gpu_memory_utilization=0.8,
|
||||
)
|
||||
|
||||
llm.generate(prompts, sampling_params)
|
||||
print("Prefill node is finished.")
|
||||
@ -72,17 +77,21 @@ def run_decode(prefill_done):
|
||||
# This instance is the decode node (kv_consumer, rank 1).
|
||||
# The number of parallel instances for KV cache transfer is set to 2,
|
||||
# as required for PyNcclConnector.
|
||||
ktc = KVTransferConfig(kv_connector="PyNcclConnector",
|
||||
kv_role="kv_consumer",
|
||||
kv_rank=1,
|
||||
kv_parallel_size=2)
|
||||
ktc = KVTransferConfig(
|
||||
kv_connector="PyNcclConnector",
|
||||
kv_role="kv_consumer",
|
||||
kv_rank=1,
|
||||
kv_parallel_size=2,
|
||||
)
|
||||
|
||||
# Set GPU memory utilization to 0.8 for an A6000 GPU with 40GB
|
||||
# memory. You may need to adjust the value to fit your GPU.
|
||||
llm = LLM(model="meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
kv_transfer_config=ktc,
|
||||
max_model_len=2000,
|
||||
gpu_memory_utilization=0.8)
|
||||
llm = LLM(
|
||||
model="meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
kv_transfer_config=ktc,
|
||||
max_model_len=2000,
|
||||
gpu_memory_utilization=0.8,
|
||||
)
|
||||
|
||||
# Wait for the producer to start the pipe
|
||||
print("Waiting for prefill node to finish...")
|
||||
@ -99,8 +108,8 @@ def run_decode(prefill_done):
|
||||
|
||||
def main():
|
||||
prefill_done = Event()
|
||||
prefill_process = Process(target=run_prefill, args=(prefill_done, ))
|
||||
decode_process = Process(target=run_decode, args=(prefill_done, ))
|
||||
prefill_process = Process(target=run_prefill, args=(prefill_done,))
|
||||
decode_process = Process(target=run_decode, args=(prefill_done,))
|
||||
|
||||
# Start prefill node
|
||||
prefill_process.start()
|
||||
|
@ -20,9 +20,7 @@ def load_prompts(dataset_path, num_prompts):
|
||||
print(f"Error reading dataset: {e}")
|
||||
return []
|
||||
else:
|
||||
prompts = [
|
||||
"The future of AI is", "The president of the United States is"
|
||||
]
|
||||
prompts = ["The future of AI is", "The president of the United States is"]
|
||||
|
||||
return prompts[:num_prompts]
|
||||
|
||||
@ -33,34 +31,32 @@ def parse_args():
|
||||
"--dataset",
|
||||
type=str,
|
||||
default="./examples/data/gsm8k.jsonl",
|
||||
help="downloaded from the eagle repo " \
|
||||
"https://github.com/SafeAILab/EAGLE/blob/main/eagle/data/"
|
||||
help="downloaded from the eagle repo "
|
||||
"https://github.com/SafeAILab/EAGLE/blob/main/eagle/data/",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--method", type=str, default="eagle", choices=["eagle", "eagle3"]
|
||||
)
|
||||
parser.add_argument("--method",
|
||||
type=str,
|
||||
default='eagle',
|
||||
choices=['eagle', 'eagle3'])
|
||||
parser.add_argument("--max_num_seqs", type=int, default=8)
|
||||
parser.add_argument("--num_prompts", type=int, default=80)
|
||||
parser.add_argument("--num_spec_tokens", type=int, default=2)
|
||||
parser.add_argument("--tp", type=int, default=1)
|
||||
parser.add_argument("--draft_tp", type=int, default=1)
|
||||
parser.add_argument("--enforce_eager", action='store_true')
|
||||
parser.add_argument("--enable_chunked_prefill", action='store_true')
|
||||
parser.add_argument("--enforce_eager", action="store_true")
|
||||
parser.add_argument("--enable_chunked_prefill", action="store_true")
|
||||
parser.add_argument("--max_num_batched_tokens", type=int, default=2048)
|
||||
parser.add_argument("--temp", type=float, default=0)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
args = parse_args()
|
||||
|
||||
model_dir = "meta-llama/Llama-3.1-8B-Instruct"
|
||||
|
||||
if args.method == 'eagle':
|
||||
if args.method == "eagle":
|
||||
eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
|
||||
elif args.method == 'eagle3':
|
||||
elif args.method == "eagle3":
|
||||
eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
|
||||
else:
|
||||
raise ValueError(f"unknown method: {args.method}")
|
||||
@ -72,11 +68,9 @@ def main():
|
||||
prompts = load_prompts(args.dataset, args.num_prompts)
|
||||
|
||||
prompt_ids = [
|
||||
tokenizer.apply_chat_template([{
|
||||
"role": "user",
|
||||
"content": prompt
|
||||
}],
|
||||
add_generation_prompt=True)
|
||||
tokenizer.apply_chat_template(
|
||||
[{"role": "user", "content": prompt}], add_generation_prompt=True
|
||||
)
|
||||
for prompt in prompts
|
||||
]
|
||||
|
||||
@ -102,8 +96,7 @@ def main():
|
||||
|
||||
sampling_params = SamplingParams(temperature=args.temp, max_tokens=256)
|
||||
|
||||
outputs = llm.generate(prompt_token_ids=prompt_ids,
|
||||
sampling_params=sampling_params)
|
||||
outputs = llm.generate(prompt_token_ids=prompt_ids, sampling_params=sampling_params)
|
||||
|
||||
# print the generated text
|
||||
for output in outputs:
|
||||
@ -120,19 +113,22 @@ def main():
|
||||
# accepted
|
||||
acceptance_counts = [0] * (args.num_spec_tokens + 1)
|
||||
for output in outputs:
|
||||
for step, count in enumerate(
|
||||
output.metrics.spec_token_acceptance_counts):
|
||||
for step, count in enumerate(output.metrics.spec_token_acceptance_counts):
|
||||
acceptance_counts[step] += count
|
||||
|
||||
print("-" * 50)
|
||||
print(f"mean acceptance length (including bonus tokens): \
|
||||
{1 + (sum(acceptance_counts) / acceptance_counts[0]):.2f}")
|
||||
print(
|
||||
f"mean acceptance length (including bonus tokens): \
|
||||
{1 + (sum(acceptance_counts) / acceptance_counts[0]):.2f}"
|
||||
)
|
||||
print("-" * 50)
|
||||
|
||||
# print acceptance at each token position
|
||||
for i in range(len(acceptance_counts)):
|
||||
print(f"acceptance at token {i}:"
|
||||
f"{acceptance_counts[i] / (acceptance_counts[0]):.2f}")
|
||||
print(
|
||||
f"acceptance at token {i}:"
|
||||
f"{acceptance_counts[i] / (acceptance_counts[0]):.2f}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -10,9 +10,9 @@ def parse_args():
|
||||
parser = FlexibleArgumentParser()
|
||||
parser = EngineArgs.add_cli_args(parser)
|
||||
# Set example specific arguments
|
||||
parser.set_defaults(model="jinaai/jina-embeddings-v3",
|
||||
task="embed",
|
||||
trust_remote_code=True)
|
||||
parser.set_defaults(
|
||||
model="jinaai/jina-embeddings-v3", task="embed", trust_remote_code=True
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@ -41,11 +41,14 @@ def main(args: Namespace):
|
||||
print("-" * 60)
|
||||
for prompt, output in zip(prompts, outputs):
|
||||
embeds = output.outputs.embedding
|
||||
embeds_trimmed = ((str(embeds[:16])[:-1] +
|
||||
", ...]") if len(embeds) > 16 else embeds)
|
||||
print(f"Prompt: {prompt!r} \n"
|
||||
f"Embeddings for text matching: {embeds_trimmed} "
|
||||
f"(size={len(embeds)})")
|
||||
embeds_trimmed = (
|
||||
(str(embeds[:16])[:-1] + ", ...]") if len(embeds) > 16 else embeds
|
||||
)
|
||||
print(
|
||||
f"Prompt: {prompt!r} \n"
|
||||
f"Embeddings for text matching: {embeds_trimmed} "
|
||||
f"(size={len(embeds)})"
|
||||
)
|
||||
print("-" * 60)
|
||||
|
||||
|
||||
|
@ -10,9 +10,9 @@ def parse_args():
|
||||
parser = FlexibleArgumentParser()
|
||||
parser = EngineArgs.add_cli_args(parser)
|
||||
# Set example specific arguments
|
||||
parser.set_defaults(model="jinaai/jina-embeddings-v3",
|
||||
task="embed",
|
||||
trust_remote_code=True)
|
||||
parser.set_defaults(
|
||||
model="jinaai/jina-embeddings-v3", task="embed", trust_remote_code=True
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@ -39,11 +39,10 @@ def main(args: Namespace):
|
||||
print("-" * 60)
|
||||
for prompt, output in zip(prompts, outputs):
|
||||
embeds = output.outputs.embedding
|
||||
embeds_trimmed = ((str(embeds[:16])[:-1] +
|
||||
", ...]") if len(embeds) > 16 else embeds)
|
||||
print(f"Prompt: {prompt!r} \n"
|
||||
f"Embeddings: {embeds_trimmed} "
|
||||
f"(size={len(embeds)})")
|
||||
embeds_trimmed = (
|
||||
(str(embeds[:16])[:-1] + ", ...]") if len(embeds) > 16 else embeds
|
||||
)
|
||||
print(f"Prompt: {prompt!r} \nEmbeddings: {embeds_trimmed} (size={len(embeds)})")
|
||||
print("-" * 60)
|
||||
|
||||
|
||||
|
@ -1,12 +1,16 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
'''
|
||||
"""
|
||||
Demonstrate prompting of text-to-text
|
||||
encoder/decoder models, specifically BART
|
||||
'''
|
||||
"""
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
|
||||
TokensPrompt, zip_enc_dec_prompts)
|
||||
from vllm.inputs import (
|
||||
ExplicitEncoderDecoderPrompt,
|
||||
TextPrompt,
|
||||
TokensPrompt,
|
||||
zip_enc_dec_prompts,
|
||||
)
|
||||
|
||||
|
||||
def create_prompts(tokenizer):
|
||||
@ -18,8 +22,9 @@ def create_prompts(tokenizer):
|
||||
# - Helpers for building prompts
|
||||
text_prompt_raw = "Hello, my name is"
|
||||
text_prompt = TextPrompt(prompt="The president of the United States is")
|
||||
tokens_prompt = TokensPrompt(prompt_token_ids=tokenizer.encode(
|
||||
prompt="The capital of France is"))
|
||||
tokens_prompt = TokensPrompt(
|
||||
prompt_token_ids=tokenizer.encode(prompt="The capital of France is")
|
||||
)
|
||||
# - Pass a single prompt to encoder/decoder model
|
||||
# (implicitly encoder input prompt);
|
||||
# decoder input prompt is assumed to be None
|
||||
@ -57,14 +62,19 @@ def create_prompts(tokenizer):
|
||||
# decoder prompts together into a list of ExplicitEncoderDecoderPrompt
|
||||
# instances
|
||||
zipped_prompt_list = zip_enc_dec_prompts(
|
||||
['An encoder prompt', 'Another encoder prompt'],
|
||||
['A decoder prompt', 'Another decoder prompt'])
|
||||
["An encoder prompt", "Another encoder prompt"],
|
||||
["A decoder prompt", "Another decoder prompt"],
|
||||
)
|
||||
|
||||
# - Let's put all of the above example prompts together into one list
|
||||
# which we will pass to the encoder/decoder LLM.
|
||||
return [
|
||||
single_text_prompt_raw, single_text_prompt, single_tokens_prompt,
|
||||
enc_dec_prompt1, enc_dec_prompt2, enc_dec_prompt3
|
||||
single_text_prompt_raw,
|
||||
single_text_prompt,
|
||||
single_tokens_prompt,
|
||||
enc_dec_prompt1,
|
||||
enc_dec_prompt2,
|
||||
enc_dec_prompt3,
|
||||
] + zipped_prompt_list
|
||||
|
||||
|
||||
@ -85,10 +95,12 @@ def print_outputs(outputs):
|
||||
prompt = output.prompt
|
||||
encoder_prompt = output.encoder_prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Output {i+1}:")
|
||||
print(f"Encoder prompt: {encoder_prompt!r}\n"
|
||||
f"Decoder prompt: {prompt!r}\n"
|
||||
f"Generated text: {generated_text!r}")
|
||||
print(f"Output {i + 1}:")
|
||||
print(
|
||||
f"Encoder prompt: {encoder_prompt!r}\n"
|
||||
f"Decoder prompt: {prompt!r}\n"
|
||||
f"Generated text: {generated_text!r}"
|
||||
)
|
||||
print("-" * 50)
|
||||
|
||||
|
||||
|
@ -3,6 +3,7 @@
|
||||
This example shows how to use vLLM for running offline inference with
|
||||
the explicit/implicit prompt format on enc-dec LMMs for text generation.
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import asdict
|
||||
@ -30,18 +31,14 @@ def run_florence2():
|
||||
)
|
||||
|
||||
prompts = [
|
||||
{ # implicit prompt with task token
|
||||
{ # implicit prompt with task token
|
||||
"prompt": "<DETAILED_CAPTION>",
|
||||
"multi_modal_data": {
|
||||
"image": ImageAsset("stop_sign").pil_image
|
||||
},
|
||||
"multi_modal_data": {"image": ImageAsset("stop_sign").pil_image},
|
||||
},
|
||||
{ # explicit encoder/decoder prompt
|
||||
{ # explicit encoder/decoder prompt
|
||||
"encoder_prompt": {
|
||||
"prompt": "Describe in detail what is shown in the image.",
|
||||
"multi_modal_data": {
|
||||
"image": ImageAsset("cherry_blossom").pil_image
|
||||
},
|
||||
"multi_modal_data": {"image": ImageAsset("cherry_blossom").pil_image},
|
||||
},
|
||||
"decoder_prompt": "",
|
||||
},
|
||||
@ -63,20 +60,20 @@ def run_mllama():
|
||||
)
|
||||
|
||||
prompts = [
|
||||
{ # Implicit prompt
|
||||
"prompt": "<|image|><|begin_of_text|>What is the content of this image?", # noqa: E501
|
||||
{ # Implicit prompt
|
||||
"prompt": "<|image|><|begin_of_text|>What is the content of this image?", # noqa: E501
|
||||
"multi_modal_data": {
|
||||
"image": ImageAsset("stop_sign").pil_image,
|
||||
},
|
||||
},
|
||||
{ # Explicit prompt
|
||||
{ # Explicit prompt
|
||||
"encoder_prompt": {
|
||||
"prompt": "<|image|>",
|
||||
"multi_modal_data": {
|
||||
"image": ImageAsset("stop_sign").pil_image,
|
||||
},
|
||||
},
|
||||
"decoder_prompt": "<|image|><|begin_of_text|>Please describe the image.", # noqa: E501
|
||||
"decoder_prompt": "<|image|><|begin_of_text|>Please describe the image.", # noqa: E501
|
||||
},
|
||||
]
|
||||
|
||||
@ -96,13 +93,13 @@ def run_whisper():
|
||||
)
|
||||
|
||||
prompts = [
|
||||
{ # Test implicit prompt
|
||||
{ # Test implicit prompt
|
||||
"prompt": "<|startoftranscript|>",
|
||||
"multi_modal_data": {
|
||||
"audio": AudioAsset("mary_had_lamb").audio_and_sample_rate,
|
||||
},
|
||||
},
|
||||
{ # Test explicit encoder/decoder prompt
|
||||
{ # Test explicit encoder/decoder prompt
|
||||
"encoder_prompt": {
|
||||
"prompt": "",
|
||||
"multi_modal_data": {
|
||||
@ -110,7 +107,7 @@ def run_whisper():
|
||||
},
|
||||
},
|
||||
"decoder_prompt": "<|startoftranscript|>",
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
return ModelRequestData(
|
||||
@ -128,18 +125,23 @@ model_example_map = {
|
||||
|
||||
def parse_args():
|
||||
parser = FlexibleArgumentParser(
|
||||
description='Demo on using vLLM for offline inference with '
|
||||
'vision language models for text generation')
|
||||
parser.add_argument('--model-type',
|
||||
'-m',
|
||||
type=str,
|
||||
default="mllama",
|
||||
choices=model_example_map.keys(),
|
||||
help='Huggingface "model_type".')
|
||||
parser.add_argument("--seed",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Set the seed when initializing `vllm.LLM`.")
|
||||
description="Demo on using vLLM for offline inference with "
|
||||
"vision language models for text generation"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-type",
|
||||
"-m",
|
||||
type=str,
|
||||
default="mllama",
|
||||
choices=model_example_map.keys(),
|
||||
help='Huggingface "model_type".',
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Set the seed when initializing `vllm.LLM`.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@ -153,7 +155,8 @@ def main(args):
|
||||
# Disable other modalities to save memory
|
||||
default_limits = {"image": 0, "video": 0, "audio": 0}
|
||||
req_data.engine_args.limit_mm_per_prompt = default_limits | dict(
|
||||
req_data.engine_args.limit_mm_per_prompt or {})
|
||||
req_data.engine_args.limit_mm_per_prompt or {}
|
||||
)
|
||||
|
||||
engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
|
||||
llm = LLM(**engine_args)
|
||||
@ -179,8 +182,7 @@ def main(args):
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Decoder prompt: {prompt!r}, "
|
||||
f"Generated text: {generated_text!r}")
|
||||
print(f"Decoder prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
|
||||
duration = time.time() - start
|
||||
|
||||
|
@ -3,6 +3,7 @@
|
||||
This file demonstrates using the `LLMEngine`
|
||||
for processing prompts with various sampling parameters.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
|
||||
@ -12,24 +13,26 @@ from vllm.utils import FlexibleArgumentParser
|
||||
def create_test_prompts() -> list[tuple[str, SamplingParams]]:
|
||||
"""Create a list of test prompts with their sampling parameters."""
|
||||
return [
|
||||
("A robot may not injure a human being",
|
||||
SamplingParams(temperature=0.0, logprobs=1, prompt_logprobs=1)),
|
||||
("To be or not to be,",
|
||||
SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2)),
|
||||
("What is the meaning of life?",
|
||||
SamplingParams(n=2,
|
||||
temperature=0.8,
|
||||
top_p=0.95,
|
||||
frequency_penalty=0.1)),
|
||||
(
|
||||
"A robot may not injure a human being",
|
||||
SamplingParams(temperature=0.0, logprobs=1, prompt_logprobs=1),
|
||||
),
|
||||
(
|
||||
"To be or not to be,",
|
||||
SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2),
|
||||
),
|
||||
(
|
||||
"What is the meaning of life?",
|
||||
SamplingParams(n=2, temperature=0.8, top_p=0.95, frequency_penalty=0.1),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def process_requests(engine: LLMEngine,
|
||||
test_prompts: list[tuple[str, SamplingParams]]):
|
||||
def process_requests(engine: LLMEngine, test_prompts: list[tuple[str, SamplingParams]]):
|
||||
"""Continuously process a list of prompts and handle the outputs."""
|
||||
request_id = 0
|
||||
|
||||
print('-' * 50)
|
||||
print("-" * 50)
|
||||
while test_prompts or engine.has_unfinished_requests():
|
||||
if test_prompts:
|
||||
prompt, sampling_params = test_prompts.pop(0)
|
||||
@ -41,7 +44,7 @@ def process_requests(engine: LLMEngine,
|
||||
for request_output in request_outputs:
|
||||
if request_output.finished:
|
||||
print(request_output)
|
||||
print('-' * 50)
|
||||
print("-" * 50)
|
||||
|
||||
|
||||
def initialize_engine(args: argparse.Namespace) -> LLMEngine:
|
||||
@ -52,7 +55,8 @@ def initialize_engine(args: argparse.Namespace) -> LLMEngine:
|
||||
|
||||
def parse_args():
|
||||
parser = FlexibleArgumentParser(
|
||||
description='Demo on using the LLMEngine class directly')
|
||||
description="Demo on using the LLMEngine class directly"
|
||||
)
|
||||
parser = EngineArgs.add_cli_args(parser)
|
||||
return parser.parse_args()
|
||||
|
||||
@ -64,6 +68,6 @@ def main(args: argparse.Namespace):
|
||||
process_requests(engine, test_prompts)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
main(args)
|
||||
|
@ -36,22 +36,21 @@ def parse_args():
|
||||
parser.set_defaults(load_format="sharded_state")
|
||||
|
||||
# Add validation arguments
|
||||
parser.add_argument("--prompt",
|
||||
type=str,
|
||||
default="Hello, world!",
|
||||
help="Prompt for validation")
|
||||
parser.add_argument("--max-tokens",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Maximum number of tokens to generate")
|
||||
parser.add_argument("--temperature",
|
||||
type=float,
|
||||
default=0.7,
|
||||
help="Sampling temperature")
|
||||
parser.add_argument("--top-p",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Top-p sampling parameter")
|
||||
parser.add_argument(
|
||||
"--prompt", type=str, default="Hello, world!", help="Prompt for validation"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-tokens",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Maximum number of tokens to generate",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temperature", type=float, default=0.7, help="Sampling temperature"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top-p", type=float, default=1.0, help="Top-p sampling parameter"
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
@ -60,8 +59,9 @@ def main():
|
||||
args = parse_args()
|
||||
engine_args = EngineArgs.from_cli_args(args)
|
||||
|
||||
print(f"Loading model from {engine_args.model} "
|
||||
f"using format {engine_args.load_format}")
|
||||
print(
|
||||
f"Loading model from {engine_args.model} using format {engine_args.load_format}"
|
||||
)
|
||||
print(f"Tensor parallel size: {engine_args.tensor_parallel_size}")
|
||||
|
||||
# Load the model using engine args
|
||||
@ -90,4 +90,4 @@ def main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
|
@ -17,50 +17,55 @@ from vllm.lora.request import LoRARequest
|
||||
|
||||
|
||||
def create_test_prompts(
|
||||
lora_path: str
|
||||
lora_path: str,
|
||||
) -> list[tuple[str, SamplingParams, Optional[LoRARequest]]]:
|
||||
return [
|
||||
# this is an example of using quantization without LoRA
|
||||
("My name is",
|
||||
SamplingParams(temperature=0.0,
|
||||
logprobs=1,
|
||||
prompt_logprobs=1,
|
||||
max_tokens=128), None),
|
||||
(
|
||||
"My name is",
|
||||
SamplingParams(
|
||||
temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128
|
||||
),
|
||||
None,
|
||||
),
|
||||
# the next three examples use quantization with LoRA
|
||||
("my name is",
|
||||
SamplingParams(temperature=0.0,
|
||||
logprobs=1,
|
||||
prompt_logprobs=1,
|
||||
max_tokens=128),
|
||||
LoRARequest("lora-test-1", 1, lora_path)),
|
||||
("The capital of USA is",
|
||||
SamplingParams(temperature=0.0,
|
||||
logprobs=1,
|
||||
prompt_logprobs=1,
|
||||
max_tokens=128),
|
||||
LoRARequest("lora-test-2", 1, lora_path)),
|
||||
("The capital of France is",
|
||||
SamplingParams(temperature=0.0,
|
||||
logprobs=1,
|
||||
prompt_logprobs=1,
|
||||
max_tokens=128),
|
||||
LoRARequest("lora-test-3", 1, lora_path)),
|
||||
(
|
||||
"my name is",
|
||||
SamplingParams(
|
||||
temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128
|
||||
),
|
||||
LoRARequest("lora-test-1", 1, lora_path),
|
||||
),
|
||||
(
|
||||
"The capital of USA is",
|
||||
SamplingParams(
|
||||
temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128
|
||||
),
|
||||
LoRARequest("lora-test-2", 1, lora_path),
|
||||
),
|
||||
(
|
||||
"The capital of France is",
|
||||
SamplingParams(
|
||||
temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128
|
||||
),
|
||||
LoRARequest("lora-test-3", 1, lora_path),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def process_requests(engine: LLMEngine,
|
||||
test_prompts: list[tuple[str, SamplingParams,
|
||||
Optional[LoRARequest]]]):
|
||||
def process_requests(
|
||||
engine: LLMEngine,
|
||||
test_prompts: list[tuple[str, SamplingParams, Optional[LoRARequest]]],
|
||||
):
|
||||
"""Continuously process a list of prompts and handle the outputs."""
|
||||
request_id = 0
|
||||
|
||||
while test_prompts or engine.has_unfinished_requests():
|
||||
if test_prompts:
|
||||
prompt, sampling_params, lora_request = test_prompts.pop(0)
|
||||
engine.add_request(str(request_id),
|
||||
prompt,
|
||||
sampling_params,
|
||||
lora_request=lora_request)
|
||||
engine.add_request(
|
||||
str(request_id), prompt, sampling_params, lora_request=lora_request
|
||||
)
|
||||
request_id += 1
|
||||
|
||||
request_outputs: list[RequestOutput] = engine.step()
|
||||
@ -71,15 +76,18 @@ def process_requests(engine: LLMEngine,
|
||||
print(f"Output: {request_output.outputs[0].text}")
|
||||
|
||||
|
||||
def initialize_engine(model: str, quantization: str,
|
||||
lora_repo: Optional[str]) -> LLMEngine:
|
||||
def initialize_engine(
|
||||
model: str, quantization: str, lora_repo: Optional[str]
|
||||
) -> LLMEngine:
|
||||
"""Initialize the LLMEngine."""
|
||||
|
||||
engine_args = EngineArgs(model=model,
|
||||
quantization=quantization,
|
||||
enable_lora=True,
|
||||
max_lora_rank=64,
|
||||
max_loras=4)
|
||||
engine_args = EngineArgs(
|
||||
model=model,
|
||||
quantization=quantization,
|
||||
enable_lora=True,
|
||||
max_lora_rank=64,
|
||||
max_loras=4,
|
||||
)
|
||||
return LLMEngine.from_engine_args(engine_args)
|
||||
|
||||
|
||||
@ -90,32 +98,30 @@ def main():
|
||||
# QLoRA (https://arxiv.org/abs/2305.14314)
|
||||
{
|
||||
"name": "qlora_inference_example",
|
||||
'model': "huggyllama/llama-7b",
|
||||
'quantization': "bitsandbytes",
|
||||
'lora_repo': 'timdettmers/qlora-flan-7b'
|
||||
"model": "huggyllama/llama-7b",
|
||||
"quantization": "bitsandbytes",
|
||||
"lora_repo": "timdettmers/qlora-flan-7b",
|
||||
},
|
||||
{
|
||||
"name": "AWQ_inference_with_lora_example",
|
||||
'model': 'TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ',
|
||||
'quantization': "awq",
|
||||
'lora_repo': 'jashing/tinyllama-colorist-lora'
|
||||
"model": "TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ",
|
||||
"quantization": "awq",
|
||||
"lora_repo": "jashing/tinyllama-colorist-lora",
|
||||
},
|
||||
{
|
||||
"name": "GPTQ_inference_with_lora_example",
|
||||
'model': 'TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ',
|
||||
'quantization': "gptq",
|
||||
'lora_repo': 'jashing/tinyllama-colorist-lora'
|
||||
}
|
||||
"model": "TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
|
||||
"quantization": "gptq",
|
||||
"lora_repo": "jashing/tinyllama-colorist-lora",
|
||||
},
|
||||
]
|
||||
|
||||
for test_config in test_configs:
|
||||
print(
|
||||
f"~~~~~~~~~~~~~~~~ Running: {test_config['name']} ~~~~~~~~~~~~~~~~"
|
||||
print(f"~~~~~~~~~~~~~~~~ Running: {test_config['name']} ~~~~~~~~~~~~~~~~")
|
||||
engine = initialize_engine(
|
||||
test_config["model"], test_config["quantization"], test_config["lora_repo"]
|
||||
)
|
||||
engine = initialize_engine(test_config['model'],
|
||||
test_config['quantization'],
|
||||
test_config['lora_repo'])
|
||||
lora_path = snapshot_download(repo_id=test_config['lora_repo'])
|
||||
lora_path = snapshot_download(repo_id=test_config["lora_repo"])
|
||||
test_prompts = create_test_prompts(lora_path)
|
||||
process_requests(engine, test_prompts)
|
||||
|
||||
@ -125,5 +131,5 @@ def main():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
@ -74,19 +74,10 @@ def run_simple_demo(args: argparse.Namespace):

messages = [
{
"role":
"user",
"role": "user",
"content": [
{
"type": "text",
"text": prompt
},
{
"type": "image_url",
"image_url": {
"url": image_url
}
},
{"type": "text", "text": prompt},
{"type": "image_url", "image_url": {"url": image_url}},
],
},
]
@ -121,25 +112,11 @@ def run_advanced_demo(args: argparse.Namespace):

messages = [
{
"role":
"user",
"role": "user",
"content": [
{
"type": "text",
"text": prompt
},
{
"type": "image_url",
"image_url": {
"url": url_1
}
},
{
"type": "image_url",
"image_url": {
"url": url_2
}
},
{"type": "text", "text": prompt},
{"type": "image_url", "image_url": {"url": url_1}},
{"type": "image_url", "image_url": {"url": url_2}},
],
},
{
@ -153,12 +130,7 @@ def run_advanced_demo(args: argparse.Namespace):
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": url_3
}
},
{"type": "image_url", "image_url": {"url": url_3}},
],
},
]
@ -171,7 +143,8 @@ def run_advanced_demo(args: argparse.Namespace):

def parse_args():
parser = argparse.ArgumentParser(
description="Run a demo in simple or advanced mode.")
description="Run a demo in simple or advanced mode."
)

parser.add_argument(
"mode",
@ -179,15 +152,18 @@ def parse_args():
help="Specify the demo mode: 'simple' or 'advanced'",
)

parser.add_argument('--format',
choices=["mistral", "hf"],
default="mistral",
help='Specify the format of the model to load.')
parser.add_argument(
"--format",
choices=["mistral", "hf"],
default="mistral",
help="Specify the format of the model to load.",
)

parser.add_argument(
'--disable-mm-preprocessor-cache',
action='store_true',
help='If True, disables caching of multi-modal preprocessor/mapper.')
"--disable-mm-preprocessor-cache",
action="store_true",
help="If True, disables caching of multi-modal preprocessor/mapper.",
)
return parser.parse_args()


@ -13,8 +13,9 @@ import time
from vllm import LLM, SamplingParams


def time_generation(llm: LLM, prompts: list[str],
sampling_params: SamplingParams, title: str):
def time_generation(
llm: LLM, prompts: list[str], sampling_params: SamplingParams, title: str
):
# Generate texts from the prompts. The output is a list of RequestOutput
# objects that contain the prompt, generated text, and other information.
# Warmup first
@ -25,8 +26,7 @@ def time_generation(llm: LLM, prompts: list[str],
end = time.time()
print("-" * 50)
print(title)
print("time: ",
(end - start) / sum(len(o.outputs[0].token_ids) for o in outputs))
print("time: ", (end - start) / sum(len(o.outputs[0].token_ids) for o in outputs))
# Print the outputs.
for output in outputs:
generated_text = output.outputs[0].text
@ -38,7 +38,8 @@ def main():
template = (
"Below is an instruction that describes a task. Write a response "
"that appropriately completes the request.\n\n### Instruction:\n{}"
"\n\n### Response:\n")
"\n\n### Response:\n"
)

# Sample prompts.
prompts = [

@ -15,7 +15,7 @@ from vllm.lora.request import LoRARequest


def create_test_prompts(
lora_path: str
lora_path: str,
) -> list[tuple[str, SamplingParams, Optional[LoRARequest]]]:
"""Create a list of test prompts with their sampling parameters.

@ -26,38 +26,49 @@ def create_test_prompts(
first adapter have finished.
"""
return [
("A robot may not injure a human being",
SamplingParams(temperature=0.0,
logprobs=1,
prompt_logprobs=1,
max_tokens=128), None),
("To be or not to be,",
SamplingParams(temperature=0.8,
top_k=5,
presence_penalty=0.2,
max_tokens=128), None),
(
"A robot may not injure a human being",
SamplingParams(
temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128
),
None,
),
(
"To be or not to be,",
SamplingParams(
temperature=0.8, top_k=5, presence_penalty=0.2, max_tokens=128
),
None,
),
(
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501
SamplingParams(temperature=0.0,
logprobs=1,
prompt_logprobs=1,
max_tokens=128,
stop_token_ids=[32003]),
LoRARequest("sql-lora", 1, lora_path)),
SamplingParams(
temperature=0.0,
logprobs=1,
prompt_logprobs=1,
max_tokens=128,
stop_token_ids=[32003],
),
LoRARequest("sql-lora", 1, lora_path),
),
(
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501
SamplingParams(temperature=0.0,
logprobs=1,
prompt_logprobs=1,
max_tokens=128,
stop_token_ids=[32003]),
LoRARequest("sql-lora2", 2, lora_path)),
SamplingParams(
temperature=0.0,
logprobs=1,
prompt_logprobs=1,
max_tokens=128,
stop_token_ids=[32003],
),
LoRARequest("sql-lora2", 2, lora_path),
),
]


def process_requests(engine: LLMEngine,
test_prompts: list[tuple[str, SamplingParams,
Optional[LoRARequest]]]):
def process_requests(
engine: LLMEngine,
test_prompts: list[tuple[str, SamplingParams, Optional[LoRARequest]]],
):
"""Continuously process a list of prompts and handle the outputs."""
request_id = 0

@ -65,10 +76,9 @@ def process_requests(engine: LLMEngine,
while test_prompts or engine.has_unfinished_requests():
if test_prompts:
prompt, sampling_params, lora_request = test_prompts.pop(0)
engine.add_request(str(request_id),
prompt,
sampling_params,
lora_request=lora_request)
engine.add_request(
str(request_id), prompt, sampling_params, lora_request=lora_request
)
request_id += 1

request_outputs: list[RequestOutput] = engine.step()
@ -88,12 +98,14 @@ def initialize_engine() -> LLMEngine:
# numbers will cause higher memory usage. If you know that all LoRAs will
# use the same rank, it is recommended to set this as low as possible.
# max_cpu_loras: controls the size of the CPU LoRA cache.
engine_args = EngineArgs(model="meta-llama/Llama-2-7b-hf",
enable_lora=True,
max_loras=1,
max_lora_rank=8,
max_cpu_loras=2,
max_num_seqs=256)
engine_args = EngineArgs(
model="meta-llama/Llama-2-7b-hf",
enable_lora=True,
max_loras=1,
max_lora_rank=8,
max_cpu_loras=2,
max_num_seqs=256,
)
return LLMEngine.from_engine_args(engine_args)


@ -105,5 +117,5 @@ def main():
process_requests(engine, test_prompts)


if __name__ == '__main__':
if __name__ == "__main__":
main()

@ -30,7 +30,8 @@ def main():
# The device argument can be either unspecified for automated detection,
# or explicitly assigned.
device="neuron",
tensor_parallel_size=2)
tensor_parallel_size=2,
)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
"""
This example shows how to run offline inference with an EAGLE speculative
This example shows how to run offline inference with an EAGLE speculative
decoding model on neuron. To use EAGLE speculative decoding, you must use
a draft model that is specifically fine-tuned for EAGLE speculation.
Additionally, to use EAGLE with NxD Inference, the draft model must include
@ -24,7 +24,7 @@ llm = LLM(
speculative_config={
"model": "/home/ubuntu/model_hf/Llama-3.1-70B-Instruct-EAGLE-Draft",
"num_speculative_tokens": 5,
"max_model_len": 2048
"max_model_len": 2048,
},
max_num_seqs=4,
# The max_model_len and block_size arguments are required to be same as
@ -40,7 +40,7 @@ llm = LLM(
tensor_parallel_size=32,
override_neuron_config={
"enable_eagle_speculation": True,
"enable_fused_speculation": True
"enable_fused_speculation": True,
},
)

@ -5,12 +5,12 @@ import os
from vllm import LLM, SamplingParams

# creates XLA hlo graphs for all the context length buckets.
os.environ['NEURON_CONTEXT_LENGTH_BUCKETS'] = "128,512,1024,2048"
os.environ["NEURON_CONTEXT_LENGTH_BUCKETS"] = "128,512,1024,2048"
# creates XLA hlo graphs for all the token gen buckets.
os.environ['NEURON_TOKEN_GEN_BUCKETS'] = "128,512,1024,2048"
os.environ["NEURON_TOKEN_GEN_BUCKETS"] = "128,512,1024,2048"
# Quantizes neuron model weight to int8 ,
# The default config for quantization is int8 dtype.
os.environ['NEURON_QUANT_DTYPE'] = "s8"
os.environ["NEURON_QUANT_DTYPE"] = "s8"

# Sample prompts.
prompts = [
@ -44,7 +44,8 @@ def main():
override_neuron_config={
"cast_logits_dtype": "bfloat16",
},
tensor_parallel_size=2)
tensor_parallel_size=2,
)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
"""
This example shows how to run offline inference with a speculative
This example shows how to run offline inference with a speculative
decoding model on neuron.
"""

@ -19,9 +19,9 @@ prompts = [
def config_buckets():
"""Configure context length and token gen buckets."""
# creates XLA hlo graphs for all the context length buckets.
os.environ['NEURON_CONTEXT_LENGTH_BUCKETS'] = "128,512,1024,2048"
os.environ["NEURON_CONTEXT_LENGTH_BUCKETS"] = "128,512,1024,2048"
# creates XLA hlo graphs for all the token gen buckets.
os.environ['NEURON_TOKEN_GEN_BUCKETS'] = "128,512,1024,2048"
os.environ["NEURON_TOKEN_GEN_BUCKETS"] = "128,512,1024,2048"


def initialize_model():
@ -31,7 +31,7 @@ def initialize_model():
speculative_config={
"model": "openlm-research/open_llama_3b",
"num_speculative_tokens": 4,
"max_model_len": 2048
"max_model_len": 2048,
},
max_num_seqs=4,
max_model_len=2048,
@ -60,5 +60,5 @@ def main():
process_requests(model, sampling_params)


if __name__ == '__main__':
if __name__ == "__main__":
main()

@ -16,7 +16,8 @@ prefix = (
"teaching role. They have 5 years of previous teaching experience "
"as an assistant teacher at a co-ed, public school with experience "
"in middle school math teaching. Based on these information, fulfill "
"the following paragraph: ")
"the following paragraph: "
)

# Sample prompts.
prompts = [
@ -58,9 +59,11 @@ def main():
cleanup_dist_env_and_memory()

# Create an LLM with prefix caching enabled.
prefix_cached_llm = LLM(model="facebook/opt-125m",
enable_prefix_caching=True,
gpu_memory_utilization=0.4)
prefix_cached_llm = LLM(
model="facebook/opt-125m",
enable_prefix_caching=True,
gpu_memory_utilization=0.4,
)

# Warmup so that the shared prompt's KV cache is computed.
prefix_cached_llm.generate(generating_prompts[0], sampling_params)
@ -81,10 +84,12 @@ def main():
print("-" * 50)

# Compare the results and display the speedup
generated_same = all([
regular_generated_texts[i] == cached_generated_texts[i]
for i in range(len(prompts))
])
generated_same = all(
[
regular_generated_texts[i] == cached_generated_texts[i]
for i in range(len(prompts))
]
)
print(f"Generated answers are the same: {generated_same}")


@ -16,7 +16,8 @@ The requirements for running this script are:
Run the example:
|
||||
python prithvi_geospatial_mae.py
|
||||
|
||||
""" # noqa: E501
|
||||
""" # noqa: E501
|
||||
|
||||
import argparse
|
||||
import datetime
|
||||
import os
|
||||
@ -110,77 +111,67 @@ model_config = """{
|
||||
|
||||
# Temporarily creating the "config.json" for the model.
|
||||
# This is going to disappear once the correct config.json is available on HF
|
||||
with open(os.path.join(os.path.dirname(__file__), "./model/config.json"),
|
||||
'w') as config_file:
|
||||
with open(
|
||||
os.path.join(os.path.dirname(__file__), "./model/config.json"), "w"
|
||||
) as config_file:
|
||||
config_file.write(model_config)
|
||||
|
||||
datamodule_config = {
|
||||
'bands': ['BLUE', 'GREEN', 'RED', 'NIR_NARROW', 'SWIR_1', 'SWIR_2'],
|
||||
'batch_size':
|
||||
16,
|
||||
'constant_scale':
|
||||
0.0001,
|
||||
'data_root':
|
||||
'/dccstor/geofm-finetuning/datasets/sen1floods11',
|
||||
'drop_last':
|
||||
True,
|
||||
'no_data_replace':
|
||||
0.0,
|
||||
'no_label_replace':
|
||||
-1,
|
||||
'num_workers':
|
||||
8,
|
||||
'test_transform': [
|
||||
albumentations.Resize(always_apply=False,
|
||||
height=448,
|
||||
interpolation=1,
|
||||
p=1,
|
||||
width=448),
|
||||
albumentations.pytorch.ToTensorV2(transpose_mask=False,
|
||||
always_apply=True,
|
||||
p=1.0)
|
||||
"bands": ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"],
|
||||
"batch_size": 16,
|
||||
"constant_scale": 0.0001,
|
||||
"data_root": "/dccstor/geofm-finetuning/datasets/sen1floods11",
|
||||
"drop_last": True,
|
||||
"no_data_replace": 0.0,
|
||||
"no_label_replace": -1,
|
||||
"num_workers": 8,
|
||||
"test_transform": [
|
||||
albumentations.Resize(
|
||||
always_apply=False, height=448, interpolation=1, p=1, width=448
|
||||
),
|
||||
albumentations.pytorch.ToTensorV2(
|
||||
transpose_mask=False, always_apply=True, p=1.0
|
||||
),
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
class PrithviMAE:
|
||||
|
||||
def __init__(self):
|
||||
print("Initializing PrithviMAE model")
|
||||
self.model = LLM(model=os.path.join(os.path.dirname(__file__),
|
||||
"./model"),
|
||||
skip_tokenizer_init=True,
|
||||
dtype="float32")
|
||||
self.model = LLM(
|
||||
model=os.path.join(os.path.dirname(__file__), "./model"),
|
||||
skip_tokenizer_init=True,
|
||||
dtype="float32",
|
||||
)
|
||||
|
||||
def run(self, input_data, location_coords):
|
||||
print("################ Running inference on vLLM ##############")
|
||||
# merge the inputs into one data structure
|
||||
mm_data = {
|
||||
"pixel_values":
|
||||
torch.empty(0) if input_data is None else input_data,
|
||||
"location_coords":
|
||||
torch.empty(0) if location_coords is None else location_coords
|
||||
"pixel_values": torch.empty(0) if input_data is None else input_data,
|
||||
"location_coords": torch.empty(0)
|
||||
if location_coords is None
|
||||
else location_coords,
|
||||
}
|
||||
|
||||
prompt = {"prompt_token_ids": [1], "multi_modal_data": mm_data}
|
||||
|
||||
outputs = self.model.encode(prompt, use_tqdm=False)
|
||||
print(
|
||||
"################ Inference done (it took seconds) ##############"
|
||||
)
|
||||
print("################ Inference done (it took seconds) ##############")
|
||||
|
||||
return outputs[0].outputs.data
|
||||
|
||||
|
||||
def generate_datamodule():
|
||||
datamodule = Sen1Floods11NonGeoDataModule(
|
||||
data_root=datamodule_config['data_root'],
|
||||
data_root=datamodule_config["data_root"],
|
||||
batch_size=datamodule_config["batch_size"],
|
||||
num_workers=datamodule_config["num_workers"],
|
||||
bands=datamodule_config["bands"],
|
||||
drop_last=datamodule_config["drop_last"],
|
||||
test_transform=datamodule_config["test_transform"
|
||||
""])
|
||||
test_transform=datamodule_config["test_transform"],
|
||||
)
|
||||
|
||||
return datamodule
|
||||
|
||||
@ -204,8 +195,7 @@ def process_channel_group(orig_img, channels):
|
||||
max_value = max(3000, np.percentile(orig_img[valid_mask], PERCENTILE))
|
||||
min_value = OFFSET
|
||||
|
||||
orig_img = torch.clamp((orig_img - min_value) / (max_value - min_value), 0,
|
||||
1)
|
||||
orig_img = torch.clamp((orig_img - min_value) / (max_value - min_value), 0, 1)
|
||||
|
||||
# No data as zeros
|
||||
orig_img[~valid_mask] = 0
|
||||
@ -300,18 +290,21 @@ def load_example(
|
||||
location_coords.append(coords)
|
||||
|
||||
try:
|
||||
match = re.search(r'(\d{7,8}T\d{6})', file)
|
||||
match = re.search(r"(\d{7,8}T\d{6})", file)
|
||||
if match:
|
||||
year = int(match.group(1)[:4])
|
||||
julian_day = match.group(1).split('T')[0][4:]
|
||||
julian_day = match.group(1).split("T")[0][4:]
|
||||
if len(julian_day) == 3:
|
||||
julian_day = int(julian_day)
|
||||
else:
|
||||
julian_day = datetime.datetime.strptime(
|
||||
julian_day, '%m%d').timetuple().tm_yday
|
||||
julian_day = (
|
||||
datetime.datetime.strptime(julian_day, "%m%d")
|
||||
.timetuple()
|
||||
.tm_yday
|
||||
)
|
||||
temporal_coords.append([year, julian_day])
|
||||
except Exception as e:
|
||||
print(f'Could not extract timestamp for {file} ({e})')
|
||||
print(f"Could not extract timestamp for {file} ({e})")
|
||||
|
||||
imgs = np.stack(imgs, axis=0) # num_frames, H, W, C
|
||||
imgs = np.moveaxis(imgs, -1, 0).astype("float32")
|
||||
@ -320,50 +313,44 @@ def load_example(
|
||||
return imgs, temporal_coords, location_coords, metas
|
||||
|
||||
|
||||
def run_model(input_data,
|
||||
temporal_coords,
|
||||
location_coords,
|
||||
model,
|
||||
datamodule,
|
||||
img_size,
|
||||
lightning_model=None):
|
||||
def run_model(
|
||||
input_data,
|
||||
temporal_coords,
|
||||
location_coords,
|
||||
model,
|
||||
datamodule,
|
||||
img_size,
|
||||
lightning_model=None,
|
||||
):
|
||||
# Reflect pad if not divisible by img_size
|
||||
original_h, original_w = input_data.shape[-2:]
|
||||
pad_h = (img_size - (original_h % img_size)) % img_size
|
||||
pad_w = (img_size - (original_w % img_size)) % img_size
|
||||
input_data = np.pad(input_data,
|
||||
((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)),
|
||||
mode="reflect")
|
||||
input_data = np.pad(
|
||||
input_data, ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)), mode="reflect"
|
||||
)
|
||||
|
||||
# Build sliding window
|
||||
batch_size = 1
|
||||
batch = torch.tensor(input_data, device="cpu")
|
||||
windows = (batch.unfold(3, img_size,
|
||||
img_size).unfold(4, img_size, img_size))
|
||||
windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size)
|
||||
h1, w1 = windows.shape[3:5]
|
||||
windows = rearrange(windows,
|
||||
"b c t h1 w1 h w -> (b h1 w1) c t h w",
|
||||
h=img_size,
|
||||
w=img_size)
|
||||
windows = rearrange(
|
||||
windows, "b c t h1 w1 h w -> (b h1 w1) c t h w", h=img_size, w=img_size
|
||||
)
|
||||
|
||||
# Split into batches if number of windows > batch_size
|
||||
num_batches = windows.shape[0] // batch_size if windows.shape[
|
||||
0] > batch_size else 1
|
||||
num_batches = windows.shape[0] // batch_size if windows.shape[0] > batch_size else 1
|
||||
windows = torch.tensor_split(windows, num_batches, dim=0)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device('cuda')
|
||||
else:
|
||||
device = torch.device('cpu')
|
||||
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
|
||||
if temporal_coords:
|
||||
temporal_coords = torch.tensor(temporal_coords,
|
||||
device=device).unsqueeze(0)
|
||||
temporal_coords = torch.tensor(temporal_coords, device=device).unsqueeze(0)
|
||||
else:
|
||||
temporal_coords = None
|
||||
if location_coords:
|
||||
location_coords = torch.tensor(location_coords[0],
|
||||
device=device).unsqueeze(0)
|
||||
location_coords = torch.tensor(location_coords[0], device=device).unsqueeze(0)
|
||||
else:
|
||||
location_coords = None
|
||||
|
||||
@ -371,26 +358,24 @@ def run_model(input_data,
|
||||
pred_imgs = []
|
||||
for x in windows:
|
||||
# Apply standardization
|
||||
x = datamodule.test_transform(
|
||||
image=x.squeeze().numpy().transpose(1, 2, 0))
|
||||
x = datamodule.aug(x)['image']
|
||||
x = datamodule.test_transform(image=x.squeeze().numpy().transpose(1, 2, 0))
|
||||
x = datamodule.aug(x)["image"]
|
||||
|
||||
with torch.no_grad():
|
||||
x = x.to(device)
|
||||
pred = model.run(x, location_coords=location_coords)
|
||||
if lightning_model:
|
||||
pred_lightning = lightning_model(
|
||||
x,
|
||||
temporal_coords=temporal_coords,
|
||||
location_coords=location_coords)
|
||||
x, temporal_coords=temporal_coords, location_coords=location_coords
|
||||
)
|
||||
pred_lightning = pred_lightning.output.detach().cpu()
|
||||
if not torch.equal(pred, pred_lightning):
|
||||
print("Inference output is not equal")
|
||||
y_hat = pred.argmax(dim=1)
|
||||
|
||||
y_hat = torch.nn.functional.interpolate(y_hat.unsqueeze(1).float(),
|
||||
size=img_size,
|
||||
mode="nearest")
|
||||
y_hat = torch.nn.functional.interpolate(
|
||||
y_hat.unsqueeze(1).float(), size=img_size, mode="nearest"
|
||||
)
|
||||
|
||||
pred_imgs.append(y_hat)
|
||||
|
||||
@ -437,8 +422,7 @@ def parse_args():
|
||||
default=[1, 2, 3, 8, 11, 12],
|
||||
type=int,
|
||||
nargs="+",
|
||||
help=
|
||||
"0-based indices of the six Prithvi channels to be selected from the "
|
||||
help="0-based indices of the six Prithvi channels to be selected from the "
|
||||
"input. By default selects [1,2,3,8,11,12] for S2L1C data.",
|
||||
)
|
||||
parser.add_argument(
|
||||
@ -478,17 +462,18 @@ def main(
|
||||
# Running model ------------------------------------------------------------
|
||||
|
||||
channels = [
|
||||
datamodule_config['bands'].index(b) for b in ["RED", "GREEN", "BLUE"]
|
||||
datamodule_config["bands"].index(b) for b in ["RED", "GREEN", "BLUE"]
|
||||
] # BGR -> RGB
|
||||
|
||||
pred = run_model(input_data, temporal_coords, location_coords, model_obj,
|
||||
datamodule, img_size)
|
||||
pred = run_model(
|
||||
input_data, temporal_coords, location_coords, model_obj, datamodule, img_size
|
||||
)
|
||||
|
||||
# Save pred
|
||||
meta_data.update(count=1, dtype="uint8", compress="lzw", nodata=0)
|
||||
pred_file = os.path.join(
|
||||
output_dir,
|
||||
f"pred_{os.path.splitext(os.path.basename(data_file))[0]}.tiff")
|
||||
output_dir, f"pred_{os.path.splitext(os.path.basename(data_file))[0]}.tiff"
|
||||
)
|
||||
save_geotiff(_convert_np_uint8(pred), pred_file, meta_data)
|
||||
|
||||
# Save image + pred
|
||||
@ -502,13 +487,13 @@ def main(
|
||||
channels=channels,
|
||||
)
|
||||
|
||||
pred[pred == 0.] = np.nan
|
||||
pred[pred == 0.0] = np.nan
|
||||
img_pred = rgb_orig * 0.7 + pred * 0.3
|
||||
img_pred[img_pred.isnan()] = rgb_orig[img_pred.isnan()]
|
||||
|
||||
img_pred_file = os.path.join(
|
||||
output_dir,
|
||||
f"rgb_pred_{os.path.splitext(os.path.basename(data_file))[0]}.tiff")
|
||||
output_dir, f"rgb_pred_{os.path.splitext(os.path.basename(data_file))[0]}.tiff"
|
||||
)
|
||||
save_geotiff(
|
||||
image=_convert_np_uint8(img_pred),
|
||||
output_path=img_pred_file,
|
||||
@ -518,8 +503,9 @@ def main(
|
||||
# Save image rgb
|
||||
if rgb_outputs:
|
||||
rgb_file = os.path.join(
|
||||
output_dir, "original_rgb_"
|
||||
f"{os.path.splitext(os.path.basename(data_file))[0]}.tiff")
|
||||
output_dir,
|
||||
f"original_rgb_{os.path.splitext(os.path.basename(data_file))[0]}.tiff",
|
||||
)
|
||||
save_geotiff(
|
||||
image=_convert_np_uint8(rgb_orig),
|
||||
output_path=rgb_file,
|
||||
@ -528,7 +514,6 @@ def main(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
args = parse_args()
|
||||
|
||||
main(**vars(args))
|
||||
|
@ -44,14 +44,17 @@ def get_dtype(dtype: str):
|
||||
|
||||
|
||||
OutputLen_NumReqs_Map: TypeAlias = dict[int, int]
|
||||
def compute_request_output_lengths(batch_size: int, step_requests: list[int]) \
|
||||
-> OutputLen_NumReqs_Map:
|
||||
|
||||
|
||||
def compute_request_output_lengths(
|
||||
batch_size: int, step_requests: list[int]
|
||||
) -> OutputLen_NumReqs_Map:
|
||||
"""
|
||||
Given the number of requests, batch_size, and the number of requests
|
||||
that each engine-step should process, step_requests, determine the
|
||||
output lengths of the requests such that step_request is honoured.
|
||||
|
||||
Example:
|
||||
Example:
|
||||
if batch size = 128 and step_request = [128, 128, 96, 64, 32, 1]
|
||||
then return,
|
||||
{2 : 32, 3 : 32, 4 : 32, 5 : 31, 6 : 1}, meaning,
|
||||
@ -100,17 +103,19 @@ def compute_request_output_lengths(batch_size: int, step_requests: list[int]) \
|
||||
output_length -= 1
|
||||
|
||||
# sanity checks.
|
||||
assert sum(ol_nr.values()) == batch_size, \
|
||||
("Number of requests in output-length assignment does not match "
|
||||
f"batch-size.\n batch size {batch_size} - "
|
||||
f"step requests {step_requests} - assignments {ol_nr}")
|
||||
assert sum(ol_nr.values()) == batch_size, (
|
||||
"Number of requests in output-length assignment does not match "
|
||||
f"batch-size.\n batch size {batch_size} - "
|
||||
f"step requests {step_requests} - assignments {ol_nr}"
|
||||
)
|
||||
|
||||
# Check that the output-length is in [1, num-steps]. Output length must be
|
||||
# at least 1 as all requests must participate in the prefill-step.
|
||||
assert all(ol >= 1 and ol <= num_steps for ol in ol_nr), \
|
||||
("Output lengths of requests should be in range "
|
||||
f"[1, num-engine-steps].\n batch size {batch_size} - "
|
||||
f"step requests {step_requests} - assignments {ol_nr}")
|
||||
assert all(ol >= 1 and ol <= num_steps for ol in ol_nr), (
|
||||
"Output lengths of requests should be in range "
|
||||
f"[1, num-engine-steps].\n batch size {batch_size} - "
|
||||
f"step requests {step_requests} - assignments {ol_nr}"
|
||||
)
|
||||
|
||||
return ol_nr
|
||||
|
||||
@ -131,7 +136,7 @@ def determine_requests_per_step(context: ProfileContext) -> list[int]:
|
||||
context: ProfileContext object.
|
||||
|
||||
Returns:
|
||||
list[int]: Number of requests to process for all engine-steps.
|
||||
list[int]: Number of requests to process for all engine-steps.
|
||||
output[i], contains the number of requests that the ith step
|
||||
should process.
|
||||
"""
|
||||
@ -140,10 +145,13 @@ def determine_requests_per_step(context: ProfileContext) -> list[int]:
|
||||
# that their output lengths must be equal to num_engine_steps.
|
||||
return [context.batch_size] * context.num_steps
|
||||
|
||||
assert context.complete_num_requests_per_step and \
|
||||
context.complete_num_requests_per_step > 0, \
|
||||
(f"Expected a positive complete_num_requests_per_step argument."
|
||||
f"Instead got {context.complete_num_requests_per_step}")
|
||||
assert (
|
||||
context.complete_num_requests_per_step
|
||||
and context.complete_num_requests_per_step > 0
|
||||
), (
|
||||
f"Expected a positive complete_num_requests_per_step argument."
|
||||
f"Instead got {context.complete_num_requests_per_step}"
|
||||
)
|
||||
|
||||
# We start dropping after the first decode step.
|
||||
step_requests = [
|
||||
@ -165,8 +173,9 @@ def determine_requests_per_step(context: ProfileContext) -> list[int]:
|
||||
return step_requests
|
||||
|
||||
|
||||
def run_profile(context: ProfileContext, csv_output: Optional[str],
|
||||
json_output: Optional[str]):
|
||||
def run_profile(
|
||||
context: ProfileContext, csv_output: Optional[str], json_output: Optional[str]
|
||||
):
|
||||
print("Run profile with:")
|
||||
for key, value in asdict(context).items():
|
||||
print(f" {key} = {value}")
|
||||
@ -174,7 +183,8 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
|
||||
requests_per_step: list[int] = determine_requests_per_step(context)
|
||||
|
||||
ol_nr: OutputLen_NumReqs_Map = compute_request_output_lengths(
|
||||
context.batch_size, requests_per_step)
|
||||
context.batch_size, requests_per_step
|
||||
)
|
||||
|
||||
num_steps_to_profile: int = len(requests_per_step)
|
||||
max_output_len: int = max(ol_nr.keys())
|
||||
@ -186,7 +196,8 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
|
||||
top_p=0.95,
|
||||
# max_tokens is set on a per-request basis.
|
||||
max_tokens=None,
|
||||
ignore_eos=True)
|
||||
ignore_eos=True,
|
||||
)
|
||||
|
||||
# Create LLM
|
||||
llm = LLM(**asdict(context.engine_args))
|
||||
@ -199,31 +210,37 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
|
||||
max_num_seqs = scheduler_config.max_num_seqs
|
||||
|
||||
if batch_size * prompt_len > max_num_batched_tokens:
|
||||
print(f"ERROR: chosen batch_size * prompt_len "
|
||||
f"({batch_size} * {prompt_len} = {batch_size * prompt_len}) is "
|
||||
f"larger than max_num_batched_tokens ({max_num_batched_tokens}) "
|
||||
f"and therefore cannot be run in a single profile step, please "
|
||||
f"choose a smaller batch size or prompt length, or increase "
|
||||
f"--max-num-batched-tokens")
|
||||
print(
|
||||
f"ERROR: chosen batch_size * prompt_len "
|
||||
f"({batch_size} * {prompt_len} = {batch_size * prompt_len}) is "
|
||||
f"larger than max_num_batched_tokens ({max_num_batched_tokens}) "
|
||||
f"and therefore cannot be run in a single profile step, please "
|
||||
f"choose a smaller batch size or prompt length, or increase "
|
||||
f"--max-num-batched-tokens"
|
||||
)
|
||||
sys.exit(-1)
|
||||
if batch_size > max_num_seqs:
|
||||
print(
|
||||
f"ERROR: chosen batch_size ({batch_size}) is larger than "
|
||||
f"max_num_seqs ({max_num_seqs}) and therefore cannot be run in a "
|
||||
f"single profile step, please choose a smaller batch size")
|
||||
f"single profile step, please choose a smaller batch size"
|
||||
)
|
||||
sys.exit(-1)
|
||||
print("llm.llm_engine.model_config.max_model_len: ",
|
||||
llm.llm_engine.model_config.max_model_len)
|
||||
print(
|
||||
"llm.llm_engine.model_config.max_model_len: ",
|
||||
llm.llm_engine.model_config.max_model_len,
|
||||
)
|
||||
if prompt_len + max_output_len > llm.llm_engine.model_config.max_model_len:
|
||||
print(f"ERROR: chosen prompt_len + max_output_len ({prompt_len} + "
|
||||
f"{max_output_len} = {prompt_len + max_output_len}) is larger "
|
||||
f"than the model's max_model_len ({max_model_len}), please "
|
||||
f"choose a smaller prompt_len or max_output_len, or increase "
|
||||
f"--max-model-len")
|
||||
print(
|
||||
f"ERROR: chosen prompt_len + max_output_len ({prompt_len} + "
|
||||
f"{max_output_len} = {prompt_len + max_output_len}) is larger "
|
||||
f"than the model's max_model_len ({max_model_len}), please "
|
||||
f"choose a smaller prompt_len or max_output_len, or increase "
|
||||
f"--max-model-len"
|
||||
)
|
||||
sys.exit(-1)
|
||||
|
||||
def add_requests():
|
||||
|
||||
def get_output_len_generator() -> Generator[int, Any, Any]:
|
||||
for output_len, num_reqs in ol_nr.items():
|
||||
for _ in range(num_reqs):
|
||||
@ -234,13 +251,15 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
|
||||
sampling_params.max_tokens = next(output_len_generator)
|
||||
assert isinstance(sampling_params.max_tokens, int)
|
||||
|
||||
prompt_token_ids = torch.randint(llm.get_tokenizer().vocab_size,
|
||||
size=(prompt_len, )).tolist()
|
||||
prompt_token_ids = torch.randint(
|
||||
llm.get_tokenizer().vocab_size, size=(prompt_len,)
|
||||
).tolist()
|
||||
|
||||
llm.llm_engine.add_request(
|
||||
request_id=f"seq{i}",
|
||||
prompt={'prompt_token_ids': prompt_token_ids},
|
||||
params=sampling_params)
|
||||
prompt={"prompt_token_ids": prompt_token_ids},
|
||||
params=sampling_params,
|
||||
)
|
||||
|
||||
def abort_requests():
|
||||
for i in range(batch_size):
|
||||
@ -261,10 +280,8 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
|
||||
|
||||
decode_profs = []
|
||||
for _ in tqdm.tqdm(range(num_steps_to_profile - 1)):
|
||||
num_running_seqs = llm.llm_engine.scheduler[
|
||||
0].get_num_unfinished_seq_groups()
|
||||
with layerwise_profile(
|
||||
num_running_seqs=num_running_seqs) as decode_prof:
|
||||
num_running_seqs = llm.llm_engine.scheduler[0].get_num_unfinished_seq_groups()
|
||||
with layerwise_profile(num_running_seqs=num_running_seqs) as decode_prof:
|
||||
llm.llm_engine.step()
|
||||
decode_profs.append(decode_prof)
|
||||
|
||||
@ -274,8 +291,7 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
|
||||
|
||||
LINE_WIDTH = 80
|
||||
print("=" * LINE_WIDTH)
|
||||
print(f"= Prefill Model Table "
|
||||
f"(prompt_len={prompt_len}, batch_size={batch_size})")
|
||||
print(f"= Prefill Model Table (prompt_len={prompt_len}, batch_size={batch_size})")
|
||||
print("=" * LINE_WIDTH)
|
||||
print()
|
||||
prefill_results.print_model_table()
|
||||
@ -283,16 +299,17 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
|
||||
if has_decode:
|
||||
print()
|
||||
print("=" * LINE_WIDTH)
|
||||
print(f"= First Decode Step Model Table "
|
||||
f"(prompt_len={prompt_len}, batch_size={batch_size})")
|
||||
print(
|
||||
f"= First Decode Step Model Table "
|
||||
f"(prompt_len={prompt_len}, batch_size={batch_size})"
|
||||
)
|
||||
print("=" * LINE_WIDTH)
|
||||
print()
|
||||
decode_results_list[0].print_model_table()
|
||||
|
||||
print()
|
||||
print("=" * LINE_WIDTH)
|
||||
print(f"= Prefill Summary Table "
|
||||
f"(prompt_len={prompt_len}, batch_size={batch_size})")
|
||||
print(f"= Prefill Summary Table (prompt_len={prompt_len}, batch_size={batch_size})")
|
||||
print("=" * LINE_WIDTH)
|
||||
print()
|
||||
prefill_results.print_summary_table()
|
||||
@ -300,25 +317,32 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
|
||||
if has_decode:
|
||||
print()
|
||||
print("=" * LINE_WIDTH)
|
||||
print(f"= First Decode Step Summary Table "
|
||||
f"(prompt_len={prompt_len}, batch_size={batch_size})")
|
||||
print(
|
||||
f"= First Decode Step Summary Table "
|
||||
f"(prompt_len={prompt_len}, batch_size={batch_size})"
|
||||
)
|
||||
print("=" * LINE_WIDTH)
|
||||
print()
|
||||
decode_results_list[0].print_summary_table()
|
||||
|
||||
if csv_output:
|
||||
csv_filename_base = csv_output[:-4] \
|
||||
if csv_output.endswith('.csv') else csv_output
|
||||
csv_filename_base = (
|
||||
csv_output[:-4] if csv_output.endswith(".csv") else csv_output
|
||||
)
|
||||
prefill_results.export_model_stats_table_csv(
|
||||
csv_filename_base + "_prefill_model_table.csv")
|
||||
csv_filename_base + "_prefill_model_table.csv"
|
||||
)
|
||||
prefill_results.export_summary_stats_table_csv(
|
||||
csv_filename_base + "_prefill_summary_table.csv")
|
||||
csv_filename_base + "_prefill_summary_table.csv"
|
||||
)
|
||||
|
||||
if has_decode:
|
||||
decode_results_list[0].export_model_stats_table_csv(\
|
||||
csv_filename_base + "_decode_model_table.csv")
|
||||
decode_results_list[0].export_model_stats_table_csv(
|
||||
csv_filename_base + "_decode_model_table.csv"
|
||||
)
|
||||
decode_results_list[0].export_summary_stats_table_csv(
|
||||
csv_filename_base + "_decode_summary_table.csv")
|
||||
csv_filename_base + "_decode_summary_table.csv"
|
||||
)
|
||||
|
||||
if json_output:
|
||||
cuda_devices = [
|
||||
@ -332,7 +356,7 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
|
||||
"torch_version": f"{torch.__version__}",
|
||||
"torch_cuda_version": f"{torch.version.cuda}",
|
||||
"cuda_devices": f"{cuda_devices}",
|
||||
**asdict(context)
|
||||
**asdict(context),
|
||||
},
|
||||
"prefill": prefill_results.convert_stats_to_dict(),
|
||||
}
|
||||
@ -342,8 +366,9 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
|
||||
json_dict[f"decode_{idx + 1}"] = dr.convert_stats_to_dict()
|
||||
|
||||
# Add .json to json_output filename if it doesn't exist already.
|
||||
json_output_file = json_output if json_output.endswith(
|
||||
'.json') else json_output + '.json'
|
||||
json_output_file = (
|
||||
json_output if json_output.endswith(".json") else json_output + ".json"
|
||||
)
|
||||
with open(json_output_file, "w+") as f:
|
||||
json.dump(json_dict, f, indent=2)
|
||||
pass
|
||||
@ -351,16 +376,21 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
|
||||
if context.save_chrome_traces_folder is not None:
|
||||
os.makedirs(context.save_chrome_traces_folder, exist_ok=True)
|
||||
prefill_prof.profiler.export_chrome_trace(
|
||||
context.save_chrome_traces_folder + "/prefill.json")
|
||||
context.save_chrome_traces_folder + "/prefill.json"
|
||||
)
|
||||
for idx, decode_prof in enumerate(decode_profs):
|
||||
decode_prof.profiler.export_chrome_trace(
|
||||
context.save_chrome_traces_folder + f"/decode_{idx + 1}.json")
|
||||
print("Traces saved as prefill.json and decode_1.json, etc."
|
||||
f" in folder {context.save_chrome_traces_folder}")
|
||||
context.save_chrome_traces_folder + f"/decode_{idx + 1}.json"
|
||||
)
|
||||
print(
|
||||
"Traces saved as prefill.json and decode_1.json, etc."
|
||||
f" in folder {context.save_chrome_traces_folder}"
|
||||
)
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = FlexibleArgumentParser(description="""
|
||||
parser = FlexibleArgumentParser(
|
||||
description="""
|
||||
Profile a model
|
||||
|
||||
example:
|
||||
@ -384,7 +414,8 @@ Profile a model
|
||||
--output-directory profile_breakdown --plot-metric pct_cuda_time
|
||||
```
|
||||
""",
|
||||
formatter_class=RawTextHelpFormatter)
|
||||
formatter_class=RawTextHelpFormatter,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--csv",
|
||||
type=str,
|
||||
@ -393,59 +424,68 @@ Profile a model
|
||||
"filename, will create <filename>_prefill_model_table.csv, "
|
||||
"<filename>_prefill_summary_table.csv, "
|
||||
"<filename>_decode_model_table.csv, and "
|
||||
"<filename>_decode_summary_table.csv")
|
||||
"<filename>_decode_summary_table.csv",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--json",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Export the results as a json file. This should be the filename")
|
||||
parser.add_argument("--save-chrome-traces-folder",
|
||||
type=str,
|
||||
help="Save chrome traces for the prefill and decode "
|
||||
"will save traces as prefill.json and decode_1.json, "
|
||||
"etc. inside this folder")
|
||||
help="Export the results as a json file. This should be the filename",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save-chrome-traces-folder",
|
||||
type=str,
|
||||
help="Save chrome traces for the prefill and decode "
|
||||
"will save traces as prefill.json and decode_1.json, "
|
||||
"etc. inside this folder",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt-len",
|
||||
type=int,
|
||||
default=PROMPT_LEN_DEFAULT,
|
||||
help=f"Length of the random prompt to use when profiling, all batched "
|
||||
f"requests use the same prompt_len, default={PROMPT_LEN_DEFAULT}")
|
||||
parser.add_argument("--batch-size",
|
||||
type=int,
|
||||
default=BATCH_SIZE_DEFAULT,
|
||||
help=f"Number of requests to run as a single batch, "
|
||||
f"default={BATCH_SIZE_DEFAULT}")
|
||||
f"requests use the same prompt_len, default={PROMPT_LEN_DEFAULT}",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch-size",
|
||||
type=int,
|
||||
default=BATCH_SIZE_DEFAULT,
|
||||
help=f"Number of requests to run as a single batch, "
|
||||
f"default={BATCH_SIZE_DEFAULT}",
|
||||
)
|
||||
|
||||
subparsers = parser.add_subparsers(dest="cmd")
|
||||
|
||||
run_num_steps_parser = subparsers.add_parser(
|
||||
"run_num_steps",
|
||||
help="This variation profiles n engine.step() invocations.")
|
||||
"run_num_steps", help="This variation profiles n engine.step() invocations."
|
||||
)
|
||||
run_num_steps_parser.add_argument(
|
||||
'-n',
|
||||
'--num-steps',
|
||||
"-n",
|
||||
"--num-steps",
|
||||
type=int,
|
||||
help="Number of engine steps to profile.\n"
|
||||
"Setting it to 1, profiles only the prefill step.\n"
|
||||
"Setting it to 2, profiles the prefill and first decode step\n"
|
||||
"Setting it to 3, profiles the prefill, 1st and 2nd decode steps\n"
|
||||
"and so on ...")
|
||||
"and so on ...",
|
||||
)
|
||||
|
||||
run_to_completion_parser = subparsers.add_parser(
|
||||
"run_to_completion",
|
||||
help="This variation profiles all the engine.step() invocations"
|
||||
"until the engine exhausts all submitted requests.")
|
||||
"until the engine exhausts all submitted requests.",
|
||||
)
|
||||
run_to_completion_parser.add_argument(
|
||||
'-n',
|
||||
'--complete-num-requests-per-step',
|
||||
"-n",
|
||||
"--complete-num-requests-per-step",
|
||||
type=int,
|
||||
help=
|
||||
"Complete complete_num_requests_per_step requests every decode step."
|
||||
help="Complete complete_num_requests_per_step requests every decode step."
|
||||
"For e.g., with batch_size 128 and complete_num_requests_per_step 32,"
|
||||
"the profiler is run for 6 engine steps, with the steps processing, "
|
||||
"128, 128, 96, 64, 32, 1 requests respectively.\n"
|
||||
"Note that we tack-on a one-request step at the end as it is often "
|
||||
"useful.")
|
||||
"useful.",
|
||||
)
|
||||
|
||||
EngineArgs.add_cli_args(parser)
|
||||
|
||||
@ -459,7 +499,8 @@ def main(args):
|
||||
k: v
|
||||
for k, v in vars(args).items()
|
||||
if k in inspect.signature(ProfileContext).parameters
|
||||
})
|
||||
},
|
||||
)
|
||||
run_profile(context, csv_output=args.csv, json_output=args.json)
|
||||
|
||||
|
||||
|
@ -31,18 +31,16 @@ def main(args: argparse.Namespace):
|
||||
max_tokens=args.output_len,
|
||||
)
|
||||
print(sampling_params)
|
||||
dummy_prompt_token_ids = np.random.randint(10000,
|
||||
size=(args.batch_size,
|
||||
args.input_len))
|
||||
dummy_prompts: list[PromptType] = [{
|
||||
"prompt_token_ids": batch
|
||||
} for batch in dummy_prompt_token_ids.tolist()]
|
||||
dummy_prompt_token_ids = np.random.randint(
|
||||
10000, size=(args.batch_size, args.input_len)
|
||||
)
|
||||
dummy_prompts: list[PromptType] = [
|
||||
{"prompt_token_ids": batch} for batch in dummy_prompt_token_ids.tolist()
|
||||
]
|
||||
|
||||
def run_to_completion():
|
||||
start_time = time.perf_counter()
|
||||
llm.generate(dummy_prompts,
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=False)
|
||||
llm.generate(dummy_prompts, sampling_params=sampling_params, use_tqdm=False)
|
||||
end_time = time.perf_counter()
|
||||
latency = end_time - start_time
|
||||
return latency
|
||||
@ -58,10 +56,9 @@ def main(args: argparse.Namespace):
|
||||
profile_dir = args.profile_result_dir
|
||||
print(f"Profiling (results will be saved to '{profile_dir}')...")
|
||||
# Enable tracing on server
|
||||
xp.trace_detached("localhost:9012",
|
||||
profile_dir,
|
||||
delay_ms=DELAY_MS,
|
||||
duration_ms=DURATION_MS)
|
||||
xp.trace_detached(
|
||||
"localhost:9012", profile_dir, delay_ms=DELAY_MS, duration_ms=DURATION_MS
|
||||
)
|
||||
if DELAY_MS == 0:
|
||||
time.sleep(1.0)
|
||||
profile_latencies = []
|
||||
@ -72,30 +69,36 @@ def main(args: argparse.Namespace):
|
||||
return
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser(
|
||||
description='Benchmark the latency of processing a single batch of '
|
||||
'requests till completion.')
|
||||
parser.add_argument('--input-len', type=int, default=32)
|
||||
parser.add_argument('--output-len', type=int, default=128)
|
||||
parser.add_argument('--batch-size', type=int, default=8)
|
||||
parser.add_argument('--num-iters-warmup',
|
||||
type=int,
|
||||
default=5,
|
||||
help='Number of iterations to run for warmup.')
|
||||
parser.add_argument('--num-iters',
|
||||
type=int,
|
||||
default=1,
|
||||
help='Number of iterations to run for profiling.')
|
||||
description="Benchmark the latency of processing a single batch of "
|
||||
"requests till completion."
|
||||
)
|
||||
parser.add_argument("--input-len", type=int, default=32)
|
||||
parser.add_argument("--output-len", type=int, default=128)
|
||||
parser.add_argument("--batch-size", type=int, default=8)
|
||||
parser.add_argument(
|
||||
'--profile-result-dir',
|
||||
"--num-iters-warmup",
|
||||
type=int,
|
||||
default=5,
|
||||
help="Number of iterations to run for warmup.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-iters",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of iterations to run for profiling.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--profile-result-dir",
|
||||
type=str,
|
||||
default="profiles",
|
||||
help=
|
||||
('path to save the pytorch profiler output. Can be visualized '
|
||||
'with ui.perfetto.dev or Tensorboard '
|
||||
'(https://cloud.google.com/tpu/docs/pytorch-xla-performance-profiling-tpu-vm).'
|
||||
))
|
||||
help=(
|
||||
"path to save the pytorch profiler output. Can be visualized "
|
||||
"with ui.perfetto.dev or Tensorboard "
|
||||
"(https://cloud.google.com/tpu/docs/pytorch-xla-performance-profiling-tpu-vm)."
|
||||
),
|
||||
)
|
||||
|
||||
parser = EngineArgs.add_cli_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
@ -18,8 +18,7 @@ Run:
|
||||
"""
|
||||
|
||||
import torch
|
||||
from transformers import (AutoModelForCausalLM, AutoTokenizer,
|
||||
PreTrainedTokenizer)
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer
|
||||
|
||||
from vllm import LLM
|
||||
|
||||
@ -32,27 +31,29 @@ def init_tokenizer_and_llm(model_name: str):
|
||||
return tokenizer, embedding_layer, llm
|
||||
|
||||
|
||||
def get_prompt_embeds(chat: list[dict[str,
|
||||
str]], tokenizer: PreTrainedTokenizer,
|
||||
embedding_layer: torch.nn.Module):
|
||||
token_ids = tokenizer.apply_chat_template(chat,
|
||||
add_generation_prompt=True,
|
||||
return_tensors='pt')
|
||||
def get_prompt_embeds(
|
||||
chat: list[dict[str, str]],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
embedding_layer: torch.nn.Module,
|
||||
):
|
||||
token_ids = tokenizer.apply_chat_template(
|
||||
chat, add_generation_prompt=True, return_tensors="pt"
|
||||
)
|
||||
prompt_embeds = embedding_layer(token_ids).squeeze(0)
|
||||
return prompt_embeds
|
||||
|
||||
|
||||
def single_prompt_inference(llm: LLM, tokenizer: PreTrainedTokenizer,
|
||||
embedding_layer: torch.nn.Module):
|
||||
chat = [{
|
||||
"role": "user",
|
||||
"content": "Please tell me about the capital of France."
|
||||
}]
|
||||
def single_prompt_inference(
|
||||
llm: LLM, tokenizer: PreTrainedTokenizer, embedding_layer: torch.nn.Module
|
||||
):
|
||||
chat = [{"role": "user", "content": "Please tell me about the capital of France."}]
|
||||
prompt_embeds = get_prompt_embeds(chat, tokenizer, embedding_layer)
|
||||
|
||||
outputs = llm.generate({
|
||||
"prompt_embeds": prompt_embeds,
|
||||
})
|
||||
outputs = llm.generate(
|
||||
{
|
||||
"prompt_embeds": prompt_embeds,
|
||||
}
|
||||
)
|
||||
|
||||
print("\n[Single Inference Output]")
|
||||
print("-" * 30)
|
||||
@ -61,34 +62,26 @@ def single_prompt_inference(llm: LLM, tokenizer: PreTrainedTokenizer,
|
||||
print("-" * 30)
|
||||
|
||||
|
||||
def batch_prompt_inference(llm: LLM, tokenizer: PreTrainedTokenizer,
|
||||
embedding_layer: torch.nn.Module):
|
||||
chats = [[{
|
||||
"role": "user",
|
||||
"content": "Please tell me about the capital of France."
|
||||
}],
|
||||
[{
|
||||
"role": "user",
|
||||
"content": "When is the day longest during the year?"
|
||||
}],
|
||||
[{
|
||||
"role": "user",
|
||||
"content": "Where is bigger, the moon or the sun?"
|
||||
}]]
|
||||
def batch_prompt_inference(
|
||||
llm: LLM, tokenizer: PreTrainedTokenizer, embedding_layer: torch.nn.Module
|
||||
):
|
||||
chats = [
|
||||
[{"role": "user", "content": "Please tell me about the capital of France."}],
|
||||
[{"role": "user", "content": "When is the day longest during the year?"}],
|
||||
[{"role": "user", "content": "Where is bigger, the moon or the sun?"}],
|
||||
]
|
||||
|
||||
prompt_embeds_list = [
|
||||
get_prompt_embeds(chat, tokenizer, embedding_layer) for chat in chats
|
||||
]
|
||||
|
||||
outputs = llm.generate([{
|
||||
"prompt_embeds": embeds
|
||||
} for embeds in prompt_embeds_list])
|
||||
outputs = llm.generate([{"prompt_embeds": embeds} for embeds in prompt_embeds_list])
|
||||
|
||||
print("\n[Batch Inference Outputs]")
|
||||
print("-" * 30)
|
||||
for i, o in enumerate(outputs):
|
||||
print(f"Q{i+1}: {chats[i][0]['content']}")
|
||||
print(f"A{i+1}: {o.outputs[0].text}\n")
|
||||
print(f"Q{i + 1}: {chats[i][0]['content']}")
|
||||
print(f"A{i + 1}: {o.outputs[0].text}\n")
|
||||
print("-" * 30)
|
||||
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""
|
||||
This example shows how to use vLLM for running offline inference
|
||||
This example shows how to use vLLM for running offline inference
|
||||
with the correct prompt format on Qwen2.5-Omni (thinker only).
|
||||
"""
|
||||
|
||||
@ -27,51 +27,55 @@ class QueryResult(NamedTuple):
|
||||
default_system = (
|
||||
"You are Qwen, a virtual human developed by the Qwen Team, Alibaba "
|
||||
"Group, capable of perceiving auditory and visual inputs, as well as "
|
||||
"generating text and speech.")
|
||||
"generating text and speech."
|
||||
)
|
||||
|
||||
|
||||
def get_mixed_modalities_query() -> QueryResult:
|
||||
question = ("What is recited in the audio? "
|
||||
"What is the content of this image? Why is this video funny?")
|
||||
prompt = (f"<|im_start|>system\n{default_system}<|im_end|>\n"
|
||||
"<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>"
|
||||
"<|vision_bos|><|IMAGE|><|vision_eos|>"
|
||||
"<|vision_bos|><|VIDEO|><|vision_eos|>"
|
||||
f"{question}<|im_end|>\n"
|
||||
f"<|im_start|>assistant\n")
|
||||
question = (
|
||||
"What is recited in the audio? "
|
||||
"What is the content of this image? Why is this video funny?"
|
||||
)
|
||||
prompt = (
|
||||
f"<|im_start|>system\n{default_system}<|im_end|>\n"
|
||||
"<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>"
|
||||
"<|vision_bos|><|IMAGE|><|vision_eos|>"
|
||||
"<|vision_bos|><|VIDEO|><|vision_eos|>"
|
||||
f"{question}<|im_end|>\n"
|
||||
f"<|im_start|>assistant\n"
|
||||
)
|
||||
return QueryResult(
|
||||
inputs={
|
||||
"prompt": prompt,
|
||||
"multi_modal_data": {
|
||||
"audio":
|
||||
AudioAsset("mary_had_lamb").audio_and_sample_rate,
|
||||
"image":
|
||||
convert_image_mode(
|
||||
ImageAsset("cherry_blossom").pil_image, "RGB"),
|
||||
"video":
|
||||
VideoAsset(name="baby_reading", num_frames=16).np_ndarrays,
|
||||
"audio": AudioAsset("mary_had_lamb").audio_and_sample_rate,
|
||||
"image": convert_image_mode(
|
||||
ImageAsset("cherry_blossom").pil_image, "RGB"
|
||||
),
|
||||
"video": VideoAsset(name="baby_reading", num_frames=16).np_ndarrays,
|
||||
},
|
||||
},
|
||||
limit_mm_per_prompt={
|
||||
"audio": 1,
|
||||
"image": 1,
|
||||
"video": 1
|
||||
},
|
||||
limit_mm_per_prompt={"audio": 1, "image": 1, "video": 1},
|
||||
)
|
||||
|
||||
|
||||
def get_use_audio_in_video_query() -> QueryResult:
|
||||
question = ("Describe the content of the video, "
|
||||
"then convert what the baby say into text.")
|
||||
prompt = (f"<|im_start|>system\n{default_system}<|im_end|>\n"
|
||||
"<|im_start|>user\n<|vision_bos|><|VIDEO|><|vision_eos|>"
|
||||
f"{question}<|im_end|>\n"
|
||||
f"<|im_start|>assistant\n")
|
||||
question = (
|
||||
"Describe the content of the video, then convert what the baby say into text."
|
||||
)
|
||||
prompt = (
|
||||
f"<|im_start|>system\n{default_system}<|im_end|>\n"
|
||||
"<|im_start|>user\n<|vision_bos|><|VIDEO|><|vision_eos|>"
|
||||
f"{question}<|im_end|>\n"
|
||||
f"<|im_start|>assistant\n"
|
||||
)
|
||||
asset = VideoAsset(name="baby_reading", num_frames=16)
|
||||
audio = asset.get_audio(sampling_rate=16000)
|
||||
assert not envs.VLLM_USE_V1, ("V1 does not support use_audio_in_video. "
|
||||
"Please launch this example with "
|
||||
"`VLLM_USE_V1=0`.")
|
||||
assert not envs.VLLM_USE_V1, (
|
||||
"V1 does not support use_audio_in_video. "
|
||||
"Please launch this example with "
|
||||
"`VLLM_USE_V1=0`."
|
||||
)
|
||||
return QueryResult(
|
||||
inputs={
|
||||
"prompt": prompt,
|
||||
@ -83,20 +87,19 @@ def get_use_audio_in_video_query() -> QueryResult:
|
||||
"use_audio_in_video": True,
|
||||
},
|
||||
},
|
||||
limit_mm_per_prompt={
|
||||
"audio": 1,
|
||||
"video": 1
|
||||
},
|
||||
limit_mm_per_prompt={"audio": 1, "video": 1},
|
||||
)
|
||||
|
||||
|
||||
def get_multi_audios_query() -> QueryResult:
|
||||
question = "Are these two audio clips the same?"
|
||||
prompt = (f"<|im_start|>system\n{default_system}<|im_end|>\n"
|
||||
"<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>"
|
||||
"<|audio_bos|><|AUDIO|><|audio_eos|>"
|
||||
f"{question}<|im_end|>\n"
|
||||
f"<|im_start|>assistant\n")
|
||||
prompt = (
|
||||
f"<|im_start|>system\n{default_system}<|im_end|>\n"
|
||||
"<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>"
|
||||
"<|audio_bos|><|AUDIO|><|audio_eos|>"
|
||||
f"{question}<|im_end|>\n"
|
||||
f"<|im_start|>assistant\n"
|
||||
)
|
||||
return QueryResult(
|
||||
inputs={
|
||||
"prompt": prompt,
|
||||
@ -124,18 +127,19 @@ def main(args):
|
||||
model_name = "Qwen/Qwen2.5-Omni-7B"
|
||||
query_result = query_map[args.query_type]()
|
||||
|
||||
llm = LLM(model=model_name,
|
||||
max_model_len=5632,
|
||||
max_num_seqs=5,
|
||||
limit_mm_per_prompt=query_result.limit_mm_per_prompt,
|
||||
seed=args.seed)
|
||||
llm = LLM(
|
||||
model=model_name,
|
||||
max_model_len=5632,
|
||||
max_num_seqs=5,
|
||||
limit_mm_per_prompt=query_result.limit_mm_per_prompt,
|
||||
seed=args.seed,
|
||||
)
|
||||
|
||||
# We set temperature to 0.2 so that outputs can be different
|
||||
# even when all prompts are identical when running batch inference.
|
||||
sampling_params = SamplingParams(temperature=0.2, max_tokens=64)
|
||||
|
||||
outputs = llm.generate(query_result.inputs,
|
||||
sampling_params=sampling_params)
|
||||
outputs = llm.generate(query_result.inputs, sampling_params=sampling_params)
|
||||
|
||||
for o in outputs:
|
||||
generated_text = o.outputs[0].text
|
||||
@ -144,18 +148,23 @@ def main(args):
|
||||
|
||||
def parse_args():
|
||||
parser = FlexibleArgumentParser(
|
||||
description='Demo on using vLLM for offline inference with '
|
||||
'audio language models')
|
||||
parser.add_argument('--query-type',
|
||||
'-q',
|
||||
type=str,
|
||||
default="mixed_modalities",
|
||||
choices=query_map.keys(),
|
||||
help='Query type.')
|
||||
parser.add_argument("--seed",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Set the seed when initializing `vllm.LLM`.")
|
||||
description="Demo on using vLLM for offline inference with "
|
||||
"audio language models"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--query-type",
|
||||
"-q",
|
||||
type=str,
|
||||
default="mixed_modalities",
|
||||
choices=query_map.keys(),
|
||||
help="Query type.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Set the seed when initializing `vllm.LLM`.",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
@ -17,10 +17,10 @@ def load_prompt() -> str:
|
||||
# https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/1m.txt
|
||||
|
||||
with urlopen(
|
||||
"https://qianwen-res.oss-cn-beijing.aliyuncs.com"
|
||||
"/Qwen2.5-1M/test-data/600k.txt",
|
||||
timeout=5) as response:
|
||||
prompt = response.read().decode('utf-8')
|
||||
"https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/600k.txt",
|
||||
timeout=5,
|
||||
) as response:
|
||||
prompt = response.read().decode("utf-8")
|
||||
return prompt
|
||||
|
||||
|
||||
@ -41,18 +41,22 @@ def process_requests(llm: LLM, prompts: list[str]) -> None:
for output in outputs:
prompt_token_ids = output.prompt_token_ids
generated_text = output.outputs[0].text
print(f"Prompt length: {len(prompt_token_ids)}, "
f"Generated text: {generated_text!r}")
print(
f"Prompt length: {len(prompt_token_ids)}, "
f"Generated text: {generated_text!r}"
)
# Create an LLM.
def initialize_engine() -> LLM:
llm = LLM(model="Qwen/Qwen2.5-7B-Instruct-1M",
max_model_len=1048576,
tensor_parallel_size=4,
enforce_eager=True,
enable_chunked_prefill=True,
max_num_batched_tokens=131072)
llm = LLM(
model="Qwen/Qwen2.5-7B-Instruct-1M",
max_model_len=1048576,
tensor_parallel_size=4,
enforce_eager=True,
enable_chunked_prefill=True,
max_num_batched_tokens=131072,
)
return llm
@ -62,5 +66,5 @@ def main():
process_requests(llm, [prompt])
if __name__ == '__main__':
if __name__ == "__main__":
main()
@ -12,6 +12,7 @@ inference instance. In practice, there could be multiple training instances
and multiple inference instances. For the full implementation, please refer
to the OpenRLHF framework.
"""
import os
import ray
@ -26,7 +27,6 @@ from vllm.utils import get_ip, get_open_port
class MyLLM(LLM):
def __init__(self, *args, **kwargs):
# a hack to make the script work.
# stop ray from manipulating CUDA_VISIBLE_DEVICES
@ -89,8 +89,7 @@ print("-" * 50)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}\n"
f"Generated text: {generated_text!r}")
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 50)
# set up the communication between the training process
@ -98,11 +97,13 @@ for output in outputs:
master_address = get_ip()
master_port = get_open_port()
handle = llm.collective_rpc.remote("init_weight_update_group",
args=(master_address, master_port, 1, 3))
handle = llm.collective_rpc.remote(
"init_weight_update_group", args=(master_address, master_port, 1, 3)
)
model_update_group = stateless_init_process_group(master_address, master_port,
0, 3, torch.device("cuda:0"))
model_update_group = stateless_init_process_group(
master_address, master_port, 0, 3, torch.device("cuda:0")
)
ray.get(handle)
# simulate training, modify the weights of the model.
@ -111,8 +112,7 @@ for name, p in train_model.named_parameters():
# sync weight from the training process to the inference engine.
for name, p in train_model.named_parameters():
handle = llm.collective_rpc.remote("update_weight",
args=(name, p.dtype, p.shape))
handle = llm.collective_rpc.remote("update_weight", args=(name, p.dtype, p.shape))
model_update_group.broadcast(p, src=0, stream=torch.cuda.current_stream())
ray.get(handle)
@ -126,6 +126,5 @@ print("-" * 50)
for output in outputs_updated:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}\n"
f"Generated text: {generated_text!r}")
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 50)
@ -9,6 +9,7 @@ The key points:
- Use cuda-ipc to pass tensors, since NCCL does not work when we have
|
||||
multiple processes on the same GPU.
|
||||
"""
import os
|
||||
|
||||
import ray
|
||||
@ -20,7 +21,6 @@ from vllm import LLM
|
||||
|
||||
|
||||
class MyLLM(LLM):
|
||||
|
||||
def __init__(self, *args, bundle_indices: list, **kwargs):
|
||||
# a hack to make the script work.
|
||||
# stop ray from manipulating CUDA_VISIBLE_DEVICES
|
||||
@ -29,17 +29,16 @@ class MyLLM(LLM):
|
||||
# every worker will use 0.4 GPU, so that we can schedule
|
||||
# 2 instances on the same GPUs.
|
||||
os.environ["VLLM_RAY_PER_WORKER_GPUS"] = "0.4"
|
||||
os.environ["VLLM_RAY_BUNDLE_INDICES"] = ",".join(
|
||||
map(str, bundle_indices))
|
||||
os.environ["VLLM_RAY_BUNDLE_INDICES"] = ",".join(map(str, bundle_indices))
|
||||
print(f"creating LLM with bundle_indices={bundle_indices}")
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class RayTrainingActor:
|
||||
|
||||
def __init__(self):
|
||||
# ray will set CUDA_VISIBLE_DEVICES to the assigned GPUs
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
self.model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
|
||||
self.model.to("cuda:0")
|
||||
for name, p in self.model.named_parameters():
|
||||
@ -48,6 +47,7 @@ class RayTrainingActor:
|
||||
# the argument for get_device_uuid is the index
|
||||
# of the GPU in the visible devices.
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
self.device_uuid = current_platform.get_device_uuid(0)
|
||||
|
||||
def report_device_id(self) -> str:
|
||||
@ -55,6 +55,7 @@ class RayTrainingActor:
|
||||
|
||||
def get_weight_ipc_handles(self):
|
||||
from torch.multiprocessing.reductions import reduce_tensor
|
||||
|
||||
data = {}
|
||||
for name, p in self.model.named_parameters():
|
||||
# the training actor might only have a subset of the weights
|
||||
@ -101,7 +102,7 @@ for bundle_index, training_actor in enumerate(training_actors):
|
||||
print(f"training actor {bundle_index} is on {device_id}")
|
||||
training_actor_device_ids.append(device_id)
|
||||
|
||||
for (i, bundle_indices) in enumerate([[0, 1], [2, 3]]):
|
||||
for i, bundle_indices in enumerate([[0, 1], [2, 3]]):
|
||||
# IMPORTANT: when creating vLLM instances, we need to
|
||||
# make sure there are no GPU activities on the target GPUs,
|
||||
# otherwise, they will interfere with the vLLM memory profiling,
|
||||
@ -128,7 +129,8 @@ for (i, bundle_indices) in enumerate([[0, 1], [2, 3]]):
|
||||
|
||||
for i, llm in enumerate(inference_engines):
|
||||
inference_engine_device_ids.append(
|
||||
ray.get(llm.collective_rpc.remote("report_device_id", args=tuple())))
|
||||
ray.get(llm.collective_rpc.remote("report_device_id", args=tuple()))
|
||||
)
|
||||
print(f"inference engine {i} is on {inference_engine_device_ids[-1]}")
|
||||
|
||||
# check the placement
|
||||
@ -147,9 +149,10 @@ for actor in training_actors:
|
||||
print("update the weights of the inference engines")
|
||||
for llm in inference_engines:
|
||||
ray.get(
|
||||
llm.collective_rpc.remote("update_weights_from_ipc_handles",
|
||||
args=(ipc_handles, )))
|
||||
llm.collective_rpc.remote(
|
||||
"update_weights_from_ipc_handles", args=(ipc_handles,)
|
||||
)
|
||||
)
|
||||
print("check if the weights are updated")
|
||||
for llm in inference_engines:
|
||||
assert ray.get(
|
||||
llm.collective_rpc.remote("check_weights_changed", args=tuple()))
|
||||
assert ray.get(llm.collective_rpc.remote("check_weights_changed", args=tuple()))
|
||||
|
@ -2,21 +2,20 @@
import torch
def stateless_init_process_group(master_address, master_port, rank, world_size,
device):
def stateless_init_process_group(master_address, master_port, rank, world_size, device):
"""
vLLM provides `StatelessProcessGroup` to create a process group
without considering the global process group in torch.distributed.
It is recommended to create `StatelessProcessGroup`, and then initialize
the data-plane communication (NCCL) between external (train processes)
the data-plane communication (NCCL) between external (train processes)
and vLLM workers.
"""
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.utils import StatelessProcessGroup
pg = StatelessProcessGroup.create(host=master_address,
port=master_port,
rank=rank,
world_size=world_size)
pg = StatelessProcessGroup.create(
host=master_address, port=master_port, rank=rank, world_size=world_size
)
pynccl = PyNcclCommunicator(pg, device=device)
return pynccl
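A minimal usage sketch of the helper above, based on the calls that appear in the RLHF example earlier in this commit; the engine actor `llm`, the Hugging Face `train_model`, and the one-trainer/two-worker layout (hence world_size=3 and rank_offset=1) are assumptions carried over from that example, not part of this file:

import torch
import ray
from vllm.utils import get_ip, get_open_port

master_address, master_port = get_ip(), get_open_port()
# vLLM workers join the group through the worker extension (ranks 1 and 2).
handle = llm.collective_rpc.remote(
    "init_weight_update_group", args=(master_address, master_port, 1, 3)
)
# The trainer joins the same group as rank 0.
group = stateless_init_process_group(
    master_address, master_port, 0, 3, torch.device("cuda:0")
)
ray.get(handle)

# Stream every updated parameter to the workers, which load it in `update_weight`.
for name, p in train_model.named_parameters():
    handle = llm.collective_rpc.remote("update_weight", args=(name, p.dtype, p.shape))
    group.broadcast(p, src=0, stream=torch.cuda.current_stream())
    ray.get(handle)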
@ -31,9 +30,11 @@ class WorkerExtension:
|
||||
should pass the full qualified name as `worker_extension_cls` argument.
|
||||
"""
|
||||
|
||||
def init_weight_update_group(self, master_address, master_port,
|
||||
rank_offset, world_size):
|
||||
def init_weight_update_group(
|
||||
self, master_address, master_port, rank_offset, world_size
|
||||
):
|
||||
from vllm.distributed.parallel_state import get_world_group
|
||||
|
||||
rank = get_world_group().rank + rank_offset
|
||||
self.model_update_group = stateless_init_process_group(
|
||||
master_address,
|
||||
@ -45,9 +46,9 @@ class WorkerExtension:
|
||||
|
||||
def update_weight(self, name, dtype, shape):
|
||||
weight = torch.empty(shape, dtype=dtype, device="cuda")
|
||||
self.model_update_group.broadcast(weight,
|
||||
src=0,
|
||||
stream=torch.cuda.current_stream())
|
||||
self.model_update_group.broadcast(
|
||||
weight, src=0, stream=torch.cuda.current_stream()
|
||||
)
|
||||
|
||||
self.model_runner.model.load_weights(weights=[(name, weight)])
|
||||
|
||||
@ -59,8 +60,7 @@ class WorkerExtension:
|
||||
"""
|
||||
weights_updated = True
|
||||
for name, p in self.model_runner.model.named_parameters():
|
||||
weights_updated = weights_updated and torch.allclose(
|
||||
p, torch.zeros_like(p))
|
||||
weights_updated = weights_updated and torch.allclose(p, torch.zeros_like(p))
|
||||
return weights_updated
|
||||
|
||||
|
||||
@ -76,6 +76,7 @@ class ColocateWorkerExtension:
|
||||
|
||||
def report_device_id(self) -> str:
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
self.device_uuid = current_platform.get_device_uuid(self.device.index)
|
||||
return self.device_uuid
|
||||
|
||||
@ -100,6 +101,5 @@ class ColocateWorkerExtension:
|
||||
"""
|
||||
weights_updated = True
|
||||
for name, p in self.model_runner.model.named_parameters():
|
||||
weights_updated = weights_updated and torch.allclose(
|
||||
p, torch.zeros_like(p))
|
||||
weights_updated = weights_updated and torch.allclose(p, torch.zeros_like(p))
|
||||
return weights_updated
|
||||
|
@ -21,6 +21,7 @@ llm = LLM(
|
||||
tensor_parallel_size=8,
|
||||
)
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
import os
|
||||
import shutil
|
||||
@ -33,18 +34,18 @@ from vllm.utils import FlexibleArgumentParser
|
||||
def parse_args():
|
||||
parser = FlexibleArgumentParser()
|
||||
EngineArgs.add_cli_args(parser)
|
||||
parser.add_argument("--output",
|
||||
"-o",
|
||||
required=True,
|
||||
type=str,
|
||||
help="path to output checkpoint")
|
||||
parser.add_argument("--file-pattern",
|
||||
type=str,
|
||||
help="string pattern of saved filenames")
|
||||
parser.add_argument("--max-file-size",
|
||||
type=str,
|
||||
default=5 * 1024**3,
|
||||
help="max size (in bytes) of each safetensors file")
|
||||
parser.add_argument(
|
||||
"--output", "-o", required=True, type=str, help="path to output checkpoint"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--file-pattern", type=str, help="string pattern of saved filenames"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-file-size",
|
||||
type=str,
|
||||
default=5 * 1024**3,
|
||||
help="max size (in bytes) of each safetensors file",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@ -68,23 +69,23 @@ def main(args):
|
||||
# For V1 engine, we need to use engine_core.save_sharded_state
|
||||
print("Using V1 engine save path")
|
||||
llm.llm_engine.engine_core.save_sharded_state(
|
||||
path=args.output,
|
||||
pattern=args.file_pattern,
|
||||
max_size=args.max_file_size)
|
||||
path=args.output, pattern=args.file_pattern, max_size=args.max_file_size
|
||||
)
|
||||
else:
|
||||
# For V0 engine
|
||||
print("Using V0 engine save path")
|
||||
model_executor = llm.llm_engine.model_executor
|
||||
model_executor.save_sharded_state(path=args.output,
|
||||
pattern=args.file_pattern,
|
||||
max_size=args.max_file_size)
|
||||
model_executor.save_sharded_state(
|
||||
path=args.output, pattern=args.file_pattern, max_size=args.max_file_size
|
||||
)
|
||||
|
||||
# Copy metadata files to output directory
|
||||
for file in os.listdir(model_path):
|
||||
if os.path.splitext(file)[1] not in (".bin", ".pt", ".safetensors"):
|
||||
if os.path.isdir(os.path.join(model_path, file)):
|
||||
shutil.copytree(os.path.join(model_path, file),
|
||||
os.path.join(args.output, file))
|
||||
shutil.copytree(
|
||||
os.path.join(model_path, file), os.path.join(args.output, file)
|
||||
)
|
||||
else:
|
||||
shutil.copy(os.path.join(model_path, file), args.output)
|
||||
|
||||
|
@ -1,9 +1,9 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""
|
||||
This file demonstrates the example usage of guided decoding
|
||||
to generate structured outputs using vLLM. It shows how to apply
|
||||
different guided decoding techniques such as Choice, Regex, JSON schema,
|
||||
and Grammar to produce structured and formatted results
|
||||
This file demonstrates the example usage of guided decoding
|
||||
to generate structured outputs using vLLM. It shows how to apply
|
||||
different guided decoding techniques such as Choice, Regex, JSON schema,
|
||||
and Grammar to produce structured and formatted results
|
||||
based on specific prompts.
|
||||
"""
@ -15,20 +15,20 @@ from vllm import LLM, SamplingParams
|
||||
from vllm.sampling_params import GuidedDecodingParams
|
||||
|
||||
# Guided decoding by Choice (list of possible options)
|
||||
guided_decoding_params_choice = GuidedDecodingParams(
|
||||
choice=["Positive", "Negative"])
|
||||
sampling_params_choice = SamplingParams(
|
||||
guided_decoding=guided_decoding_params_choice)
|
||||
guided_decoding_params_choice = GuidedDecodingParams(choice=["Positive", "Negative"])
|
||||
sampling_params_choice = SamplingParams(guided_decoding=guided_decoding_params_choice)
|
||||
prompt_choice = "Classify this sentiment: vLLM is wonderful!"
|
||||
|
||||
# Guided decoding by Regex
|
||||
guided_decoding_params_regex = GuidedDecodingParams(regex=r"\w+@\w+\.com\n")
|
||||
sampling_params_regex = SamplingParams(
|
||||
guided_decoding=guided_decoding_params_regex, stop=["\n"])
|
||||
guided_decoding=guided_decoding_params_regex, stop=["\n"]
|
||||
)
|
||||
prompt_regex = (
|
||||
"Generate an email address for Alan Turing, who works in Enigma."
|
||||
"End in .com and new line. Example result:"
|
||||
"alan.turing@enigma.com\n")
|
||||
"alan.turing@enigma.com\n"
|
||||
)
|
||||
|
||||
|
||||
# Guided decoding by JSON using Pydantic schema
|
||||
@ -47,10 +47,11 @@ class CarDescription(BaseModel):
|
||||
|
||||
json_schema = CarDescription.model_json_schema()
|
||||
guided_decoding_params_json = GuidedDecodingParams(json=json_schema)
|
||||
sampling_params_json = SamplingParams(
|
||||
guided_decoding=guided_decoding_params_json)
|
||||
prompt_json = ("Generate a JSON with the brand, model and car_type of"
|
||||
"the most iconic car from the 90's")
|
||||
sampling_params_json = SamplingParams(guided_decoding=guided_decoding_params_json)
|
||||
prompt_json = (
|
||||
"Generate a JSON with the brand, model and car_type of"
|
||||
"the most iconic car from the 90's"
|
||||
)
|
||||
|
||||
# Guided decoding by Grammar
|
||||
simplified_sql_grammar = """
|
||||
@ -61,12 +62,11 @@ table ::= "table_1 " | "table_2 "
|
||||
condition ::= column "= " number
|
||||
number ::= "1 " | "2 "
|
||||
"""
|
||||
guided_decoding_params_grammar = GuidedDecodingParams(
|
||||
grammar=simplified_sql_grammar)
|
||||
sampling_params_grammar = SamplingParams(
|
||||
guided_decoding=guided_decoding_params_grammar)
|
||||
prompt_grammar = ("Generate an SQL query to show the 'username' and 'email'"
|
||||
"from the 'users' table.")
|
||||
guided_decoding_params_grammar = GuidedDecodingParams(grammar=simplified_sql_grammar)
|
||||
sampling_params_grammar = SamplingParams(guided_decoding=guided_decoding_params_grammar)
|
||||
prompt_grammar = (
|
||||
"Generate an SQL query to show the 'username' and 'email'from the 'users' table."
|
||||
)
|
||||
|
||||
|
||||
def format_output(title: str, output: str):
|
||||
@ -90,8 +90,7 @@ def main():
|
||||
json_output = generate_output(prompt_json, sampling_params_json, llm)
|
||||
format_output("Guided decoding by JSON", json_output)
|
||||
|
||||
grammar_output = generate_output(prompt_grammar, sampling_params_grammar,
|
||||
llm)
|
||||
grammar_output = generate_output(prompt_grammar, sampling_params_grammar, llm)
|
||||
format_output("Guided decoding by Grammar", grammar_output)
|
||||
|
||||
|
||||
|
@ -45,8 +45,7 @@ if dist.get_rank() == 0:
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}\n"
|
||||
f"Generated text: {generated_text!r}\n")
|
||||
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}\n")
|
||||
print("-" * 50)
|
||||
"""
|
||||
Further tips:
|
||||
|
@ -20,10 +20,12 @@ sampling_params = SamplingParams(temperature=0, top_p=1.0, n=N, max_tokens=16)
|
||||
def main():
|
||||
# Set `enforce_eager=True` to avoid ahead-of-time compilation.
# In real workloads, `enforce_eager` should be `False`.
llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
|
||||
max_num_batched_tokens=64,
|
||||
max_num_seqs=4,
|
||||
max_model_len=128)
|
||||
llm = LLM(
|
||||
model="Qwen/Qwen2-1.5B-Instruct",
|
||||
max_num_batched_tokens=64,
|
||||
max_num_seqs=4,
|
||||
max_model_len=128,
|
||||
)
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
print("-" * 50)
|
||||
for output, answer in zip(outputs, answers):
|
||||
|
@ -6,6 +6,7 @@ the correct prompt format on vision language models for text generation.
|
||||
For most models, the prompt format should follow corresponding examples
|
||||
on HuggingFace model repository.
|
||||
"""
import os
|
||||
import random
|
||||
from contextlib import contextmanager
|
||||
@ -49,9 +50,13 @@ def run_aria(questions: list[str], modality: str) -> ModelRequestData:
|
||||
limit_mm_per_prompt={modality: 1},
|
||||
)
|
||||
|
||||
prompts = [(f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>{question}"
|
||||
"<|im_end|>\n<|im_start|>assistant\n")
|
||||
for question in questions]
|
||||
prompts = [
|
||||
(
|
||||
f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>{question}"
|
||||
"<|im_end|>\n<|im_start|>assistant\n"
|
||||
)
|
||||
for question in questions
|
||||
]
|
||||
|
||||
stop_token_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519]
|
||||
|
||||
@ -135,8 +140,7 @@ def run_deepseek_vl2(questions: list[str], modality: str) -> ModelRequestData:
|
||||
)
|
||||
|
||||
prompts = [
|
||||
f"<|User|>: <image>\n{question}\n\n<|Assistant|>:"
|
||||
for question in questions
|
||||
f"<|User|>: <image>\n{question}\n\n<|Assistant|>:" for question in questions
|
||||
]
|
||||
|
||||
return ModelRequestData(
|
||||
@ -198,9 +202,14 @@ def run_gemma3(questions: list[str], modality: str) -> ModelRequestData:
|
||||
limit_mm_per_prompt={modality: 1},
|
||||
)
|
||||
|
||||
prompts = [("<bos><start_of_turn>user\n"
|
||||
f"<start_of_image>{question}<end_of_turn>\n"
|
||||
"<start_of_turn>model\n") for question in questions]
|
||||
prompts = [
|
||||
(
|
||||
"<bos><start_of_turn>user\n"
|
||||
f"<start_of_image>{question}<end_of_turn>\n"
|
||||
"<start_of_turn>model\n"
|
||||
)
|
||||
for question in questions
|
||||
]
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
@ -225,7 +234,8 @@ def run_glm4v(questions: list[str], modality: str) -> ModelRequestData:
|
||||
|
||||
prompts = [
|
||||
f"<|user|>\n<|begin_of_image|><|endoftext|><|end_of_image|>\
|
||||
{question}<|assistant|>" for question in questions
|
||||
{question}<|assistant|>"
|
||||
for question in questions
|
||||
]
|
||||
|
||||
stop_token_ids = [151329, 151336, 151338]
|
||||
@ -250,15 +260,13 @@ def run_h2ovl(questions: list[str], modality: str) -> ModelRequestData:
|
||||
limit_mm_per_prompt={modality: 1},
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name,
|
||||
trust_remote_code=True)
|
||||
messages = [[{
|
||||
'role': 'user',
|
||||
'content': f"<image>\n{question}"
|
||||
}] for question in questions]
|
||||
prompts = tokenizer.apply_chat_template(messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
||||
messages = [
|
||||
[{"role": "user", "content": f"<image>\n{question}"}] for question in questions
|
||||
]
|
||||
prompts = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
|
||||
# Stop tokens for H2OVL-Mississippi
|
||||
# https://huggingface.co/h2oai/h2ovl-mississippi-800m
|
||||
@ -284,15 +292,14 @@ def run_idefics3(questions: list[str], modality: str) -> ModelRequestData:
|
||||
# if you are running out of memory, you can reduce the "longest_edge".
|
||||
# see: https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3#model-optimizations
|
||||
mm_processor_kwargs={
|
||||
"size": {
|
||||
"longest_edge": 3 * 364
|
||||
},
|
||||
"size": {"longest_edge": 3 * 364},
|
||||
},
|
||||
limit_mm_per_prompt={modality: 1},
|
||||
)
|
||||
prompts = [(
|
||||
f"<|begin_of_text|>User:<image>{question}<end_of_utterance>\nAssistant:"
|
||||
) for question in questions]
|
||||
prompts = [
|
||||
(f"<|begin_of_text|>User:<image>{question}<end_of_utterance>\nAssistant:")
|
||||
for question in questions
|
||||
]
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
@ -311,9 +318,7 @@ def run_smolvlm(questions: list[str], modality: str) -> ModelRequestData:
|
||||
max_num_seqs=2,
|
||||
enforce_eager=True,
|
||||
mm_processor_kwargs={
|
||||
"max_image_size": {
|
||||
"longest_edge": 384
|
||||
},
|
||||
"max_image_size": {"longest_edge": 384},
|
||||
},
|
||||
limit_mm_per_prompt={modality: 1},
|
||||
)
|
||||
@ -330,7 +335,6 @@ def run_smolvlm(questions: list[str], modality: str) -> ModelRequestData:
|
||||
|
||||
# InternVL
|
||||
def run_internvl(questions: list[str], modality: str) -> ModelRequestData:
|
||||
|
||||
model_name = "OpenGVLab/InternVL3-2B"
|
||||
|
||||
engine_args = EngineArgs(
|
||||
@ -345,15 +349,14 @@ def run_internvl(questions: list[str], modality: str) -> ModelRequestData:
|
||||
elif modality == "video":
|
||||
placeholder = "<video>"
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name,
|
||||
trust_remote_code=True)
|
||||
messages = [[{
|
||||
'role': 'user',
|
||||
'content': f"{placeholder}\n{question}"
|
||||
}] for question in questions]
|
||||
prompts = tokenizer.apply_chat_template(messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
||||
messages = [
|
||||
[{"role": "user", "content": f"{placeholder}\n{question}"}]
|
||||
for question in questions
|
||||
]
|
||||
prompts = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
|
||||
# Stop tokens for InternVL
|
||||
# models variants may have different stop tokens
|
||||
@ -361,9 +364,7 @@ def run_internvl(questions: list[str], modality: str) -> ModelRequestData:
|
||||
# https://huggingface.co/OpenGVLab/InternVL2-2B/blob/main/conversation.py
|
||||
stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
|
||||
stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
|
||||
stop_token_ids = [
|
||||
token_id for token_id in stop_token_ids if token_id is not None
|
||||
]
|
||||
stop_token_ids = [token_id for token_id in stop_token_ids if token_id is not None]
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
@ -379,7 +380,8 @@ def run_kimi_vl(questions: list[str], modality: str) -> ModelRequestData:
|
||||
prompts = [
|
||||
"<|im_user|>user<|im_middle|><|media_start|>image<|media_content|>"
|
||||
f"<|media_pad|><|media_end|>{question}<|im_end|>"
|
||||
"<|im_assistant|>assistant<|im_middle|>" for question in questions
|
||||
"<|im_assistant|>assistant<|im_middle|>"
|
||||
for question in questions
|
||||
]
|
||||
|
||||
engine_args = EngineArgs(
|
||||
@ -399,9 +401,7 @@ def run_kimi_vl(questions: list[str], modality: str) -> ModelRequestData:
|
||||
def run_llava(questions: list[str], modality: str) -> ModelRequestData:
|
||||
assert modality == "image"
|
||||
|
||||
prompts = [
|
||||
f"USER: <image>\n{question}\nASSISTANT:" for question in questions
|
||||
]
|
||||
prompts = [f"USER: <image>\n{question}\nASSISTANT:" for question in questions]
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model="llava-hf/llava-1.5-7b-hf",
|
||||
@ -434,13 +434,10 @@ def run_llava_next(questions: list[str], modality: str) -> ModelRequestData:
|
||||
|
||||
# LlaVA-NeXT-Video
|
||||
# Currently only support for video input
|
||||
def run_llava_next_video(questions: list[str],
|
||||
modality: str) -> ModelRequestData:
|
||||
def run_llava_next_video(questions: list[str], modality: str) -> ModelRequestData:
|
||||
assert modality == "video"
|
||||
|
||||
prompts = [
|
||||
f"USER: <video>\n{question} ASSISTANT:" for question in questions
|
||||
]
|
||||
prompts = [f"USER: <video>\n{question} ASSISTANT:" for question in questions]
|
||||
engine_args = EngineArgs(
|
||||
model="llava-hf/LLaVA-NeXT-Video-7B-hf",
|
||||
max_model_len=8192,
|
||||
@ -455,19 +452,19 @@ def run_llava_next_video(questions: list[str],
|
||||
|
||||
|
||||
# LLaVA-OneVision
|
||||
def run_llava_onevision(questions: list[str],
|
||||
modality: str) -> ModelRequestData:
|
||||
|
||||
def run_llava_onevision(questions: list[str], modality: str) -> ModelRequestData:
|
||||
if modality == "video":
|
||||
prompts = [
|
||||
f"<|im_start|>user <video>\n{question}<|im_end|> \
|
||||
<|im_start|>assistant\n" for question in questions
|
||||
<|im_start|>assistant\n"
|
||||
for question in questions
|
||||
]
|
||||
|
||||
elif modality == "image":
|
||||
prompts = [
|
||||
f"<|im_start|>user <image>\n{question}<|im_end|> \
|
||||
<|im_start|>assistant\n" for question in questions
|
||||
<|im_start|>assistant\n"
|
||||
for question in questions
|
||||
]
|
||||
|
||||
engine_args = EngineArgs(
|
||||
@ -486,11 +483,8 @@ def run_llava_onevision(questions: list[str],
|
||||
def run_mantis(questions: list[str], modality: str) -> ModelRequestData:
|
||||
assert modality == "image"
|
||||
|
||||
llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n' # noqa: E501
|
||||
prompts = [
|
||||
llama3_template.format(f"{question}\n<image>")
|
||||
for question in questions
|
||||
]
|
||||
llama3_template = "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" # noqa: E501
|
||||
prompts = [llama3_template.format(f"{question}\n<image>") for question in questions]
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model="TIGER-Lab/Mantis-8B-siglip-llama3",
|
||||
@ -530,8 +524,7 @@ def run_minicpmv_base(questions: list[str], modality: str, model_name):
|
||||
# 2.6: image, video
|
||||
# o2.6: image, video, audio
|
||||
# model_name = "openbmb/MiniCPM-o-2_6"
tokenizer = AutoTokenizer.from_pretrained(model_name,
|
||||
trust_remote_code=True)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
max_model_len=4096,
|
||||
@ -547,7 +540,7 @@ def run_minicpmv_base(questions: list[str], modality: str, model_name):
|
||||
# stop_token_ids = [tokenizer.eos_id, tokenizer.eot_id]
|
||||
|
||||
# 2.6 / o2.6
|
||||
stop_tokens = ['<|im_end|>', '<|endoftext|>']
|
||||
stop_tokens = ["<|im_end|>", "<|endoftext|>"]
|
||||
stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
|
||||
|
||||
modality_placeholder = {
|
||||
@ -557,12 +550,16 @@ def run_minicpmv_base(questions: list[str], modality: str, model_name):
|
||||
|
||||
prompts = [
|
||||
tokenizer.apply_chat_template(
|
||||
[{
|
||||
'role': 'user',
|
||||
'content': f"{modality_placeholder[modality]}\n{question}"
|
||||
}],
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"{modality_placeholder[modality]}\n{question}",
|
||||
}
|
||||
],
|
||||
tokenize=False,
|
||||
add_generation_prompt=True) for question in questions
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
for question in questions
|
||||
]
|
||||
|
||||
return ModelRequestData(
|
||||
@ -622,19 +619,18 @@ def run_mllama(questions: list[str], modality: str) -> ModelRequestData:
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
messages = [[{
|
||||
"role":
|
||||
"user",
|
||||
"content": [{
|
||||
"type": "image"
|
||||
}, {
|
||||
"type": "text",
|
||||
"text": question
|
||||
}]
|
||||
}] for question in questions]
|
||||
prompts = tokenizer.apply_chat_template(messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=False)
|
||||
messages = [
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "image"}, {"type": "text", "text": question}],
|
||||
}
|
||||
]
|
||||
for question in questions
|
||||
]
|
||||
prompts = tokenizer.apply_chat_template(
|
||||
messages, add_generation_prompt=True, tokenize=False
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
@ -657,19 +653,18 @@ def run_llama4(questions: list[str], modality: str) -> ModelRequestData:
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
messages = [[{
|
||||
"role":
|
||||
"user",
|
||||
"content": [{
|
||||
"type": "image"
|
||||
}, {
|
||||
"type": "text",
|
||||
"text": f"{question}"
|
||||
}]
|
||||
}] for question in questions]
|
||||
prompts = tokenizer.apply_chat_template(messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=False)
|
||||
messages = [
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "image"}, {"type": "text", "text": f"{question}"}],
|
||||
}
|
||||
]
|
||||
for question in questions
|
||||
]
|
||||
prompts = tokenizer.apply_chat_template(
|
||||
messages, add_generation_prompt=True, tokenize=False
|
||||
)
|
||||
stop_token_ids = None
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
@ -693,7 +688,8 @@ def run_molmo(questions: list[str], modality: str) -> ModelRequestData:
|
||||
|
||||
prompts = [
|
||||
f"<|im_start|>user <image>\n{question}<|im_end|> \
|
||||
<|im_start|>assistant\n" for question in questions
|
||||
<|im_start|>assistant\n"
|
||||
for question in questions
|
||||
]
|
||||
|
||||
return ModelRequestData(
|
||||
@ -717,15 +713,13 @@ def run_nvlm_d(questions: list[str], modality: str) -> ModelRequestData:
|
||||
limit_mm_per_prompt={modality: 1},
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name,
|
||||
trust_remote_code=True)
|
||||
messages = [[{
|
||||
'role': 'user',
|
||||
'content': f"<image>\n{question}"
|
||||
}] for question in questions]
|
||||
prompts = tokenizer.apply_chat_template(messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
||||
messages = [
|
||||
[{"role": "user", "content": f"<image>\n{question}"}] for question in questions
|
||||
]
|
||||
prompts = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
@ -748,15 +742,13 @@ def run_ovis(questions: list[str], modality: str) -> ModelRequestData:
|
||||
limit_mm_per_prompt={modality: 1},
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name,
|
||||
trust_remote_code=True)
|
||||
messages = [[{
|
||||
'role': 'user',
|
||||
'content': f"<image>\n{question}"
|
||||
}] for question in questions]
|
||||
prompts = tokenizer.apply_chat_template(messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
||||
messages = [
|
||||
[{"role": "user", "content": f"<image>\n{question}"}] for question in questions
|
||||
]
|
||||
prompts = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
@ -847,8 +839,7 @@ def run_phi4mm(questions: list[str], modality: str) -> ModelRequestData:
|
||||
# we have to manually specify the path of the lora weights.
|
||||
vision_lora_path = os.path.join(model_path, "vision-lora")
|
||||
prompts = [
|
||||
f"<|user|><|image_1|>{question}<|end|><|assistant|>"
|
||||
for question in questions
|
||||
f"<|user|><|image_1|>{question}<|end|><|assistant|>" for question in questions
|
||||
]
|
||||
engine_args = EngineArgs(
|
||||
model=model_path,
|
||||
@ -915,7 +906,6 @@ def run_qwen_vl(questions: list[str], modality: str) -> ModelRequestData:
|
||||
|
||||
# Qwen2-VL
|
||||
def run_qwen2_vl(questions: list[str], modality: str) -> ModelRequestData:
|
||||
|
||||
model_name = "Qwen/Qwen2-VL-7B-Instruct"
|
||||
|
||||
engine_args = EngineArgs(
|
||||
@ -936,10 +926,13 @@ def run_qwen2_vl(questions: list[str], modality: str) -> ModelRequestData:
|
||||
placeholder = "<|video_pad|>"
|
||||
|
||||
prompts = [
|
||||
("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
|
||||
f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
|
||||
f"{question}<|im_end|>\n"
|
||||
"<|im_start|>assistant\n") for question in questions
|
||||
(
|
||||
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
|
||||
f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
|
||||
f"{question}<|im_end|>\n"
|
||||
"<|im_start|>assistant\n"
|
||||
)
|
||||
for question in questions
|
||||
]
|
||||
|
||||
return ModelRequestData(
|
||||
@ -950,7 +943,6 @@ def run_qwen2_vl(questions: list[str], modality: str) -> ModelRequestData:
|
||||
|
||||
# Qwen2.5-VL
|
||||
def run_qwen2_5_vl(questions: list[str], modality: str) -> ModelRequestData:
|
||||
|
||||
model_name = "Qwen/Qwen2.5-VL-3B-Instruct"
|
||||
|
||||
engine_args = EngineArgs(
|
||||
@ -971,10 +963,13 @@ def run_qwen2_5_vl(questions: list[str], modality: str) -> ModelRequestData:
|
||||
placeholder = "<|video_pad|>"
|
||||
|
||||
prompts = [
|
||||
("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
|
||||
f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
|
||||
f"{question}<|im_end|>\n"
|
||||
"<|im_start|>assistant\n") for question in questions
|
||||
(
|
||||
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
|
||||
f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
|
||||
f"{question}<|im_end|>\n"
|
||||
"<|im_start|>assistant\n"
|
||||
)
|
||||
for question in questions
|
||||
]
|
||||
|
||||
return ModelRequestData(
|
||||
@ -1007,12 +1002,18 @@ def run_qwen2_5_omni(questions: list[str], modality: str):
|
||||
default_system = (
|
||||
"You are Qwen, a virtual human developed by the Qwen Team, Alibaba "
|
||||
"Group, capable of perceiving auditory and visual inputs, as well as "
|
||||
"generating text and speech.")
|
||||
"generating text and speech."
|
||||
)
|
||||
|
||||
prompts = [(f"<|im_start|>system\n{default_system}<|im_end|>\n"
|
||||
f"<|im_start|>user\n<|vision_bos|>{placeholder}<|vision_eos|>"
|
||||
f"{question}<|im_end|>\n"
|
||||
"<|im_start|>assistant\n") for question in questions]
|
||||
prompts = [
|
||||
(
|
||||
f"<|im_start|>system\n{default_system}<|im_end|>\n"
|
||||
f"<|im_start|>user\n<|vision_bos|>{placeholder}<|vision_eos|>"
|
||||
f"{question}<|im_end|>\n"
|
||||
"<|im_start|>assistant\n"
|
||||
)
|
||||
for question in questions
|
||||
]
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompts=prompts,
|
||||
@ -1032,15 +1033,13 @@ def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData:
|
||||
limit_mm_per_prompt={modality: 1},
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name,
|
||||
trust_remote_code=True)
|
||||
messages = [[{
|
||||
'role': 'user',
|
||||
'content': f"<image>\n{question}"
|
||||
}] for question in questions]
|
||||
prompts = tokenizer.apply_chat_template(messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
||||
messages = [
|
||||
[{"role": "user", "content": f"<image>\n{question}"}] for question in questions
|
||||
]
|
||||
prompts = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
|
||||
# Stop tokens for SkyworkR1V
|
||||
# https://huggingface.co/Skywork/Skywork-R1V-38B/blob/main/conversation.py
|
||||
@ -1104,8 +1103,7 @@ def get_multi_modal_input(args):
|
||||
"""
|
||||
if args.modality == "image":
|
||||
# Input image and question
|
||||
image = convert_image_mode(
|
||||
ImageAsset("cherry_blossom").pil_image, "RGB")
|
||||
image = convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB")
|
||||
img_questions = [
|
||||
"What is the content of this image?",
|
||||
"Describe the content of this image in detail.",
|
||||
@ -1120,8 +1118,7 @@ def get_multi_modal_input(args):
|
||||
|
||||
if args.modality == "video":
|
||||
# Input video and question
|
||||
video = VideoAsset(name="baby_reading",
|
||||
num_frames=args.num_frames).np_ndarrays
|
||||
video = VideoAsset(name="baby_reading", num_frames=args.num_frames).np_ndarrays
|
||||
vid_questions = ["Why is this video funny?"]
|
||||
|
||||
return {
|
||||
@ -1133,12 +1130,13 @@ def get_multi_modal_input(args):
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
def apply_image_repeat(image_repeat_prob, num_prompts, data,
|
||||
prompts: list[str], modality):
|
||||
"""Repeats images with provided probability of "image_repeat_prob".
|
||||
def apply_image_repeat(
|
||||
image_repeat_prob, num_prompts, data, prompts: list[str], modality
|
||||
):
|
||||
"""Repeats images with provided probability of "image_repeat_prob".
|
||||
Used to simulate hit/miss for the MM preprocessor cache.
|
||||
"""
|
||||
assert (image_repeat_prob <= 1.0 and image_repeat_prob >= 0)
|
||||
assert image_repeat_prob <= 1.0 and image_repeat_prob >= 0
|
||||
no_yes = [0, 1]
|
||||
probs = [1.0 - image_repeat_prob, image_repeat_prob]
|
||||
|
||||
@ -1153,12 +1151,12 @@ def apply_image_repeat(image_repeat_prob, num_prompts, data,
|
||||
new_val = (i // 256 // 256, i // 256, i % 256)
|
||||
cur_image.putpixel((0, 0), new_val)
|
||||
|
||||
inputs.append({
|
||||
"prompt": prompts[i % len(prompts)],
|
||||
"multi_modal_data": {
|
||||
modality: cur_image
|
||||
inputs.append(
|
||||
{
|
||||
"prompt": prompts[i % len(prompts)],
|
||||
"multi_modal_data": {modality: cur_image},
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
return inputs
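A quick usage sketch of `apply_image_repeat`, mirroring how the main path below calls it; the repeat probability of 0.5 and the batch of 8 prompts are illustrative values (not taken from the diff), and `data`/`prompts` are the image and prompt list prepared by the surrounding script:

# With image_repeat_prob=0.5, about half of the requests reuse the previous
# image object, simulating hits in the multimodal preprocessor cache.
example_inputs = apply_image_repeat(
    image_repeat_prob=0.5,
    num_prompts=8,
    data=data,
    prompts=prompts,
    modality="image",
)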
@ -1167,6 +1165,7 @@ def apply_image_repeat(image_repeat_prob, num_prompts, data,
|
||||
def time_counter(enable: bool):
|
||||
if enable:
|
||||
import time
|
||||
|
||||
start_time = time.time()
|
||||
yield
|
||||
elapsed_time = time.time() - start_time
|
||||
@ -1179,54 +1178,65 @@ def time_counter(enable: bool):
|
||||
|
||||
def parse_args():
|
||||
parser = FlexibleArgumentParser(
|
||||
description='Demo on using vLLM for offline inference with '
|
||||
'vision language models for text generation')
|
||||
parser.add_argument('--model-type',
|
||||
'-m',
|
||||
type=str,
|
||||
default="llava",
|
||||
choices=model_example_map.keys(),
|
||||
help='Huggingface "model_type".')
|
||||
parser.add_argument('--num-prompts',
|
||||
type=int,
|
||||
default=4,
|
||||
help='Number of prompts to run.')
|
||||
parser.add_argument('--modality',
|
||||
type=str,
|
||||
default="image",
|
||||
choices=['image', 'video'],
|
||||
help='Modality of the input.')
|
||||
parser.add_argument('--num-frames',
|
||||
type=int,
|
||||
default=16,
|
||||
help='Number of frames to extract from the video.')
|
||||
parser.add_argument("--seed",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Set the seed when initializing `vllm.LLM`.")
|
||||
description="Demo on using vLLM for offline inference with "
|
||||
"vision language models for text generation"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-type",
|
||||
"-m",
|
||||
type=str,
|
||||
default="llava",
|
||||
choices=model_example_map.keys(),
|
||||
help='Huggingface "model_type".',
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-prompts", type=int, default=4, help="Number of prompts to run."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--modality",
|
||||
type=str,
|
||||
default="image",
|
||||
choices=["image", "video"],
|
||||
help="Modality of the input.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-frames",
|
||||
type=int,
|
||||
default=16,
|
||||
help="Number of frames to extract from the video.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Set the seed when initializing `vllm.LLM`.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--image-repeat-prob',
|
||||
"--image-repeat-prob",
|
||||
type=float,
|
||||
default=None,
|
||||
help='Simulates the hit-ratio for multi-modal preprocessor cache'
|
||||
' (if enabled)')
|
||||
help="Simulates the hit-ratio for multi-modal preprocessor cache (if enabled)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--disable-mm-preprocessor-cache',
|
||||
action='store_true',
|
||||
help='If True, disables caching of multi-modal preprocessor/mapper.')
|
||||
"--disable-mm-preprocessor-cache",
|
||||
action="store_true",
|
||||
help="If True, disables caching of multi-modal preprocessor/mapper.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--time-generate',
|
||||
action='store_true',
|
||||
help='If True, then print the total generate() call time')
|
||||
"--time-generate",
|
||||
action="store_true",
|
||||
help="If True, then print the total generate() call time",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--use-different-prompt-per-request',
|
||||
action='store_true',
|
||||
help='If True, then use different prompt (with the same multi-modal '
|
||||
'data) for each request.')
|
||||
"--use-different-prompt-per-request",
|
||||
action="store_true",
|
||||
help="If True, then use different prompt (with the same multi-modal "
|
||||
"data) for each request.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@ -1245,7 +1255,8 @@ def main(args):
|
||||
# Disable other modalities to save memory
|
||||
default_limits = {"image": 0, "video": 0, "audio": 0}
|
||||
req_data.engine_args.limit_mm_per_prompt = default_limits | dict(
|
||||
req_data.engine_args.limit_mm_per_prompt or {})
|
||||
req_data.engine_args.limit_mm_per_prompt or {}
|
||||
)
|
||||
|
||||
engine_args = asdict(req_data.engine_args) | {
|
||||
"seed": args.seed,
|
||||
@ -1254,44 +1265,46 @@ def main(args):
|
||||
llm = LLM(**engine_args)
|
||||
|
||||
# Don't want to check the flag multiple times, so just hijack `prompts`.
|
||||
prompts = req_data.prompts if args.use_different_prompt_per_request else [
|
||||
req_data.prompts[0]
|
||||
]
|
||||
prompts = (
|
||||
req_data.prompts
|
||||
if args.use_different_prompt_per_request
|
||||
else [req_data.prompts[0]]
|
||||
)
|
||||
|
||||
# We set temperature to 0.2 so that outputs can be different
|
||||
# even when all prompts are identical when running batch inference.
|
||||
sampling_params = SamplingParams(temperature=0.2,
|
||||
max_tokens=64,
|
||||
stop_token_ids=req_data.stop_token_ids)
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0.2, max_tokens=64, stop_token_ids=req_data.stop_token_ids
|
||||
)
|
||||
|
||||
assert args.num_prompts > 0
|
||||
if args.num_prompts == 1:
|
||||
# Single inference
|
||||
inputs = {
|
||||
"prompt": prompts[0],
|
||||
"multi_modal_data": {
|
||||
modality: data
|
||||
},
|
||||
"multi_modal_data": {modality: data},
|
||||
}
|
||||
else:
|
||||
# Batch inference
|
||||
if args.image_repeat_prob is not None:
|
||||
# Repeat images with specified probability of "image_repeat_prob"
|
||||
inputs = apply_image_repeat(args.image_repeat_prob,
|
||||
args.num_prompts, data, prompts,
|
||||
modality)
|
||||
inputs = apply_image_repeat(
|
||||
args.image_repeat_prob, args.num_prompts, data, prompts, modality
|
||||
)
|
||||
else:
|
||||
# Use the same image for all prompts
|
||||
inputs = [{
|
||||
"prompt": prompts[i % len(prompts)],
|
||||
"multi_modal_data": {
|
||||
modality: data
|
||||
},
|
||||
} for i in range(args.num_prompts)]
|
||||
inputs = [
|
||||
{
|
||||
"prompt": prompts[i % len(prompts)],
|
||||
"multi_modal_data": {modality: data},
|
||||
}
|
||||
for i in range(args.num_prompts)
|
||||
]
|
||||
|
||||
# Add LoRA request if applicable
|
||||
lora_request = (req_data.lora_requests *
|
||||
args.num_prompts if req_data.lora_requests else None)
|
||||
lora_request = (
|
||||
req_data.lora_requests * args.num_prompts if req_data.lora_requests else None
|
||||
)
|
||||
|
||||
with time_counter(args.time_generate):
|
||||
outputs = llm.generate(
|
||||
|
@ -6,6 +6,7 @@ the correct prompt format on vision language models for multimodal embedding.
|
||||
For most models, the prompt format should follow corresponding examples
|
||||
on HuggingFace model repository.
|
||||
"""
from argparse import Namespace
|
||||
from dataclasses import asdict
|
||||
from typing import Literal, NamedTuple, Optional, TypedDict, Union, get_args
|
||||
@ -44,19 +45,17 @@ class ModelRequestData(NamedTuple):
|
||||
|
||||
|
||||
def run_e5_v(query: Query) -> ModelRequestData:
|
||||
llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n' # noqa: E501
|
||||
llama3_template = "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n" # noqa: E501
|
||||
|
||||
if query["modality"] == "text":
|
||||
text = query["text"]
|
||||
prompt = llama3_template.format(
|
||||
f"{text}\nSummary above sentence in one word: ")
|
||||
prompt = llama3_template.format(f"{text}\nSummary above sentence in one word: ")
|
||||
image = None
|
||||
elif query["modality"] == "image":
|
||||
prompt = llama3_template.format(
|
||||
"<image>\nSummary above image in one word: ")
|
||||
prompt = llama3_template.format("<image>\nSummary above image in one word: ")
|
||||
image = query["image"]
|
||||
else:
|
||||
modality = query['modality']
|
||||
modality = query["modality"]
|
||||
raise ValueError(f"Unsupported query modality: '{modality}'")
|
||||
|
||||
engine_args = EngineArgs(
|
||||
@ -83,10 +82,12 @@ def run_vlm2vec(query: Query) -> ModelRequestData:
|
||||
image = query["image"]
|
||||
elif query["modality"] == "text+image":
|
||||
text = query["text"]
|
||||
prompt = f"<|image_1|> Represent the given image with the following question: {text}" # noqa: E501
|
||||
prompt = (
|
||||
f"<|image_1|> Represent the given image with the following question: {text}" # noqa: E501
|
||||
)
|
||||
image = query["image"]
|
||||
else:
|
||||
modality = query['modality']
|
||||
modality = query["modality"]
|
||||
raise ValueError(f"Unsupported query modality: '{modality}'")
|
||||
|
||||
engine_args = EngineArgs(
|
||||
@ -136,7 +137,8 @@ def run_encode(model: str, modality: QueryModality, seed: Optional[int]):
|
||||
# Disable other modalities to save memory
|
||||
default_limits = {"image": 0, "video": 0, "audio": 0}
|
||||
req_data.engine_args.limit_mm_per_prompt = default_limits | dict(
|
||||
req_data.engine_args.limit_mm_per_prompt or {})
|
||||
req_data.engine_args.limit_mm_per_prompt or {}
|
||||
)
|
||||
|
||||
engine_args = asdict(req_data.engine_args) | {"seed": seed}
|
||||
llm = LLM(**engine_args)
|
||||
@ -145,10 +147,12 @@ def run_encode(model: str, modality: QueryModality, seed: Optional[int]):
|
||||
if req_data.image is not None:
|
||||
mm_data["image"] = req_data.image
|
||||
|
||||
outputs = llm.embed({
|
||||
"prompt": req_data.prompt,
|
||||
"multi_modal_data": mm_data,
|
||||
})
|
||||
outputs = llm.embed(
|
||||
{
|
||||
"prompt": req_data.prompt,
|
||||
"multi_modal_data": mm_data,
|
||||
}
|
||||
)
|
||||
|
||||
print("-" * 50)
|
||||
for output in outputs:
|
||||
@ -164,23 +168,30 @@ model_example_map = {
|
||||
|
||||
def parse_args():
|
||||
parser = FlexibleArgumentParser(
|
||||
description='Demo on using vLLM for offline inference with '
|
||||
'vision language models for multimodal embedding')
|
||||
parser.add_argument('--model-name',
|
||||
'-m',
|
||||
type=str,
|
||||
default="vlm2vec",
|
||||
choices=model_example_map.keys(),
|
||||
help='The name of the embedding model.')
|
||||
parser.add_argument('--modality',
|
||||
type=str,
|
||||
default="image",
|
||||
choices=get_args(QueryModality),
|
||||
help='Modality of the input.')
|
||||
parser.add_argument("--seed",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Set the seed when initializing `vllm.LLM`.")
|
||||
description="Demo on using vLLM for offline inference with "
|
||||
"vision language models for multimodal embedding"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-name",
|
||||
"-m",
|
||||
type=str,
|
||||
default="vlm2vec",
|
||||
choices=model_example_map.keys(),
|
||||
help="The name of the embedding model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--modality",
|
||||
type=str,
|
||||
default="image",
|
||||
choices=get_args(QueryModality),
|
||||
help="Modality of the input.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Set the seed when initializing `vllm.LLM`.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
|
@ -4,6 +4,7 @@ This example shows how to use vLLM for running offline inference with
|
||||
multi-image input on vision language models for text generation,
|
||||
using the chat template defined by the model.
|
||||
"""
import os
|
||||
from argparse import Namespace
|
||||
from dataclasses import asdict
|
||||
@ -59,8 +60,9 @@ def load_aria(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
limit_mm_per_prompt={"image": len(image_urls)},
|
||||
)
|
||||
placeholders = "<fim_prefix><|img|><fim_suffix>\n" * len(image_urls)
|
||||
prompt = (f"<|im_start|>user\n{placeholders}{question}<|im_end|>\n"
|
||||
"<|im_start|>assistant\n")
|
||||
prompt = (
|
||||
f"<|im_start|>user\n{placeholders}{question}<|im_end|>\n<|im_start|>assistant\n"
|
||||
)
|
||||
stop_token_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519]
|
||||
|
||||
return ModelRequestData(
|
||||
@ -81,23 +83,21 @@ def load_aya_vision(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
)
|
||||
|
||||
placeholders = [{"type": "image", "image": url} for url in image_urls]
|
||||
messages = [{
|
||||
"role":
|
||||
"user",
|
||||
"content": [
|
||||
*placeholders,
|
||||
{
|
||||
"type": "text",
|
||||
"text": question
|
||||
},
|
||||
],
|
||||
}]
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
*placeholders,
|
||||
{"type": "text", "text": question},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
processor = AutoProcessor.from_pretrained(model_name)
|
||||
|
||||
prompt = processor.apply_chat_template(messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True)
|
||||
prompt = processor.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
@ -106,8 +106,7 @@ def load_aya_vision(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
)
|
||||
|
||||
|
||||
def load_deepseek_vl2(question: str,
|
||||
image_urls: list[str]) -> ModelRequestData:
|
||||
def load_deepseek_vl2(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
model_name = "deepseek-ai/deepseek-vl2-tiny"
|
||||
|
||||
engine_args = EngineArgs(
|
||||
@ -118,8 +117,9 @@ def load_deepseek_vl2(question: str,
|
||||
limit_mm_per_prompt={"image": len(image_urls)},
|
||||
)
|
||||
|
||||
placeholder = "".join(f"image_{i}:<image>\n"
|
||||
for i, _ in enumerate(image_urls, start=1))
|
||||
placeholder = "".join(
|
||||
f"image_{i}:<image>\n" for i, _ in enumerate(image_urls, start=1)
|
||||
)
|
||||
prompt = f"<|User|>: {placeholder}{question}\n\n<|Assistant|>:"
|
||||
|
||||
return ModelRequestData(
|
||||
@ -140,23 +140,21 @@ def load_gemma3(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
)
|
||||
|
||||
placeholders = [{"type": "image", "image": url} for url in image_urls]
|
||||
messages = [{
|
||||
"role":
|
||||
"user",
|
||||
"content": [
|
||||
*placeholders,
|
||||
{
|
||||
"type": "text",
|
||||
"text": question
|
||||
},
|
||||
],
|
||||
}]
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
*placeholders,
|
||||
{"type": "text", "text": question},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
processor = AutoProcessor.from_pretrained(model_name)
|
||||
|
||||
prompt = processor.apply_chat_template(messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True)
|
||||
prompt = processor.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
@ -176,15 +174,15 @@ def load_h2ovl(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
mm_processor_kwargs={"max_dynamic_patch": 4},
|
||||
)
|
||||
|
||||
placeholders = "\n".join(f"Image-{i}: <image>\n"
|
||||
for i, _ in enumerate(image_urls, start=1))
|
||||
messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}]
|
||||
placeholders = "\n".join(
|
||||
f"Image-{i}: <image>\n" for i, _ in enumerate(image_urls, start=1)
|
||||
)
|
||||
messages = [{"role": "user", "content": f"{placeholders}\n{question}"}]
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name,
|
||||
trust_remote_code=True)
|
||||
prompt = tokenizer.apply_chat_template(messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
|
||||
# Stop tokens for H2OVL-Mississippi
|
||||
# https://huggingface.co/h2oai/h2ovl-mississippi-800m
|
||||
@ -211,14 +209,13 @@ def load_idefics3(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
# if you are running out of memory, you can reduce the "longest_edge".
|
||||
# see: https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3#model-optimizations
|
||||
mm_processor_kwargs={
|
||||
"size": {
|
||||
"longest_edge": 2 * 364
|
||||
},
|
||||
"size": {"longest_edge": 2 * 364},
|
||||
},
|
||||
)
|
||||
|
||||
placeholders = "\n".join(f"Image-{i}: <image>\n"
|
||||
for i, _ in enumerate(image_urls, start=1))
|
||||
placeholders = "\n".join(
|
||||
f"Image-{i}: <image>\n" for i, _ in enumerate(image_urls, start=1)
|
||||
)
|
||||
prompt = f"<|begin_of_text|>User:{placeholders}\n{question}<end_of_utterance>\nAssistant:" # noqa: E501
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
@ -238,15 +235,16 @@ def load_smolvlm(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
enforce_eager=True,
|
||||
limit_mm_per_prompt={"image": len(image_urls)},
|
||||
mm_processor_kwargs={
|
||||
"max_image_size": {
|
||||
"longest_edge": 384
|
||||
},
|
||||
"max_image_size": {"longest_edge": 384},
|
||||
},
|
||||
)
|
||||
|
||||
placeholders = "\n".join(f"Image-{i}: <image>\n"
|
||||
for i, _ in enumerate(image_urls, start=1))
|
||||
prompt = f"<|im_start|>User:{placeholders}\n{question}<end_of_utterance>\nAssistant:" # noqa: E501
|
||||
placeholders = "\n".join(
|
||||
f"Image-{i}: <image>\n" for i, _ in enumerate(image_urls, start=1)
|
||||
)
|
||||
prompt = (
|
||||
f"<|im_start|>User:{placeholders}\n{question}<end_of_utterance>\nAssistant:" # noqa: E501
|
||||
)
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompt=prompt,
|
||||
@ -265,15 +263,15 @@ def load_internvl(question: str, image_urls: list[str]) -> ModelRequestData:
mm_processor_kwargs={"max_dynamic_patch": 4},
)

placeholders = "\n".join(f"Image-{i}: <image>\n"
for i, _ in enumerate(image_urls, start=1))
messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}]
placeholders = "\n".join(
f"Image-{i}: <image>\n" for i, _ in enumerate(image_urls, start=1)
)
messages = [{"role": "user", "content": f"{placeholders}\n{question}"}]

tokenizer = AutoTokenizer.from_pretrained(model_name,
trust_remote_code=True)
prompt = tokenizer.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)

# Stop tokens for InternVL
# models variants may have different stop tokens
@ -301,23 +299,21 @@ def load_llama4(question: str, image_urls: list[str]) -> ModelRequestData:
)

placeholders = [{"type": "image", "image": url} for url in image_urls]
messages = [{
"role":
"user",
"content": [
*placeholders,
{
"type": "text",
"text": question
},
],
}]
messages = [
{
"role": "user",
"content": [
*placeholders,
{"type": "text", "text": question},
],
}
]

processor = AutoProcessor.from_pretrained(model_name)

prompt = processor.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)
prompt = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)

return ModelRequestData(
engine_args=engine_args,
@ -338,24 +334,21 @@ def load_kimi_vl(question: str, image_urls: list[str]) -> ModelRequestData:
)

placeholders = [{"type": "image", "image": url} for url in image_urls]
messages = [{
"role":
"user",
"content": [
*placeholders,
{
"type": "text",
"text": question
},
],
}]
messages = [
{
"role": "user",
"content": [
*placeholders,
{"type": "text", "text": question},
],
}
]

processor = AutoProcessor.from_pretrained(model_name,
trust_remote_code=True)
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)

prompt = processor.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)
prompt = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)

return ModelRequestData(
engine_args=engine_args,
@ -419,15 +412,15 @@ def load_nvlm_d(question: str, image_urls: list[str]) -> ModelRequestData:
mm_processor_kwargs={"max_dynamic_patch": 4},
)

placeholders = "\n".join(f"Image-{i}: <image>\n"
for i, _ in enumerate(image_urls, start=1))
messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}]
placeholders = "\n".join(
f"Image-{i}: <image>\n" for i, _ in enumerate(image_urls, start=1)
)
messages = [{"role": "user", "content": f"{placeholders}\n{question}"}]

tokenizer = AutoTokenizer.from_pretrained(model_name,
trust_remote_code=True)
prompt = tokenizer.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)

return ModelRequestData(
engine_args=engine_args,
@ -449,15 +442,15 @@ def load_ovis(question: str, image_urls: list[str]) -> ModelRequestData:
limit_mm_per_prompt={"image": len(image_urls)},
)

placeholders = "\n".join(f"Image-{i}: <image>\n"
for i, _ in enumerate(image_urls, start=1))
messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}]
placeholders = "\n".join(
f"Image-{i}: <image>\n" for i, _ in enumerate(image_urls, start=1)
)
messages = [{"role": "user", "content": f"{placeholders}\n{question}"}]

tokenizer = AutoTokenizer.from_pretrained(model_name,
trust_remote_code=True)
prompt = tokenizer.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)

return ModelRequestData(
engine_args=engine_args,
@ -509,8 +502,9 @@ def load_phi3v(question: str, image_urls: list[str]) -> ModelRequestData:
limit_mm_per_prompt={"image": len(image_urls)},
mm_processor_kwargs={"num_crops": 4},
)
placeholders = "\n".join(f"<|image_{i}|>"
for i, _ in enumerate(image_urls, start=1))
placeholders = "\n".join(
f"<|image_{i}|>" for i, _ in enumerate(image_urls, start=1)
)
prompt = f"<|user|>\n{placeholders}\n{question}<|end|>\n<|assistant|>\n"

return ModelRequestData(
@ -542,8 +536,7 @@ def load_phi4mm(question: str, image_urls: list[str]) -> ModelRequestData:
mm_processor_kwargs={"dynamic_hd": 4},
)

placeholders = "".join(f"<|image_{i}|>"
for i, _ in enumerate(image_urls, start=1))
placeholders = "".join(f"<|image_{i}|>" for i, _ in enumerate(image_urls, start=1))
prompt = f"<|user|>{placeholders}{question}<|end|><|assistant|>"

return ModelRequestData(
@ -554,8 +547,7 @@ def load_phi4mm(question: str, image_urls: list[str]) -> ModelRequestData:
)


def load_qwen_vl_chat(question: str,
image_urls: list[str]) -> ModelRequestData:
def load_qwen_vl_chat(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "Qwen/Qwen-VL-Chat"
engine_args = EngineArgs(
model=model_name,
@ -565,24 +557,26 @@ def load_qwen_vl_chat(question: str,
hf_overrides={"architectures": ["QwenVLForConditionalGeneration"]},
limit_mm_per_prompt={"image": len(image_urls)},
)
placeholders = "".join(f"Picture {i}: <img></img>\n"
for i, _ in enumerate(image_urls, start=1))
placeholders = "".join(
f"Picture {i}: <img></img>\n" for i, _ in enumerate(image_urls, start=1)
)

# This model does not have a chat_template attribute on its tokenizer,
# so we need to explicitly pass it. We use ChatML since it's used in the
# generation utils of the model:
# https://huggingface.co/Qwen/Qwen-VL-Chat/blob/main/qwen_generation_utils.py#L265
tokenizer = AutoTokenizer.from_pretrained(model_name,
trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# Copied from: https://huggingface.co/docs/transformers/main/en/chat_templating
chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" # noqa: E501

messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}]
prompt = tokenizer.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True,
chat_template=chat_template)
messages = [{"role": "user", "content": f"{placeholders}\n{question}"}]
prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
chat_template=chat_template,
)

stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>"]
stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
@ -600,9 +594,11 @@ def load_qwen2_vl(question: str, image_urls: list[str]) -> ModelRequestData:
try:
from qwen_vl_utils import process_vision_info
except ModuleNotFoundError:
print('WARNING: `qwen-vl-utils` not installed, input images will not '
'be automatically resized. You can enable this functionality by '
'`pip install qwen-vl-utils`.')
print(
"WARNING: `qwen-vl-utils` not installed, input images will not "
"be automatically resized. You can enable this functionality by "
"`pip install qwen-vl-utils`."
)
process_vision_info = None

model_name = "Qwen/Qwen2-VL-7B-Instruct"
@ -616,26 +612,22 @@ def load_qwen2_vl(question: str, image_urls: list[str]) -> ModelRequestData:
)

placeholders = [{"type": "image", "image": url} for url in image_urls]
messages = [{
"role": "system",
"content": "You are a helpful assistant."
}, {
"role":
"user",
"content": [
*placeholders,
{
"type": "text",
"text": question
},
],
}]
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{
"role": "user",
"content": [
*placeholders,
{"type": "text", "text": question},
],
},
]

processor = AutoProcessor.from_pretrained(model_name)

prompt = processor.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)
prompt = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)

if process_vision_info is None:
image_data = [fetch_image(url) for url in image_urls]
@ -653,9 +645,11 @@ def load_qwen2_5_vl(question: str, image_urls: list[str]) -> ModelRequestData:
try:
from qwen_vl_utils import process_vision_info
except ModuleNotFoundError:
print('WARNING: `qwen-vl-utils` not installed, input images will not '
'be automatically resized. You can enable this functionality by '
'`pip install qwen-vl-utils`.')
print(
"WARNING: `qwen-vl-utils` not installed, input images will not "
"be automatically resized. You can enable this functionality by "
"`pip install qwen-vl-utils`."
)
process_vision_info = None

model_name = "Qwen/Qwen2.5-VL-3B-Instruct"
@ -668,32 +662,27 @@ def load_qwen2_5_vl(question: str, image_urls: list[str]) -> ModelRequestData:
)

placeholders = [{"type": "image", "image": url} for url in image_urls]
messages = [{
"role": "system",
"content": "You are a helpful assistant."
}, {
"role":
"user",
"content": [
*placeholders,
{
"type": "text",
"text": question
},
],
}]
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{
"role": "user",
"content": [
*placeholders,
{"type": "text", "text": question},
],
},
]

processor = AutoProcessor.from_pretrained(model_name)

prompt = processor.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)
prompt = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)

if process_vision_info is None:
image_data = [fetch_image(url) for url in image_urls]
else:
image_data, _ = process_vision_info(messages,
return_video_kwargs=False)
image_data, _ = process_vision_info(messages, return_video_kwargs=False)

return ModelRequestData(
engine_args=engine_args,
@ -726,23 +715,20 @@ model_example_map = {
}


def run_generate(model, question: str, image_urls: list[str],
seed: Optional[int]):
def run_generate(model, question: str, image_urls: list[str], seed: Optional[int]):
req_data = model_example_map[model](question, image_urls)

engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
llm = LLM(**engine_args)

sampling_params = SamplingParams(temperature=0.0,
max_tokens=256,
stop_token_ids=req_data.stop_token_ids)
sampling_params = SamplingParams(
temperature=0.0, max_tokens=256, stop_token_ids=req_data.stop_token_ids
)

outputs = llm.generate(
{
"prompt": req_data.prompt,
"multi_modal_data": {
"image": req_data.image_data
},
"multi_modal_data": {"image": req_data.image_data},
},
sampling_params=sampling_params,
lora_request=req_data.lora_requests,
@ -755,38 +741,40 @@ def run_generate(model, question: str, image_urls: list[str],
print("-" * 50)


def run_chat(model: str, question: str, image_urls: list[str],
seed: Optional[int]):
def run_chat(model: str, question: str, image_urls: list[str], seed: Optional[int]):
req_data = model_example_map[model](question, image_urls)

# Disable other modalities to save memory
default_limits = {"image": 0, "video": 0, "audio": 0}
req_data.engine_args.limit_mm_per_prompt = default_limits | dict(
req_data.engine_args.limit_mm_per_prompt or {})
req_data.engine_args.limit_mm_per_prompt or {}
)

engine_args = asdict(req_data.engine_args) | {"seed": seed}
llm = LLM(**engine_args)

sampling_params = SamplingParams(temperature=0.0,
max_tokens=256,
stop_token_ids=req_data.stop_token_ids)
sampling_params = SamplingParams(
temperature=0.0, max_tokens=256, stop_token_ids=req_data.stop_token_ids
)
outputs = llm.chat(
[{
"role":
"user",
"content": [
{
"type": "text",
"text": question,
},
*({
"type": "image_url",
"image_url": {
"url": image_url
[
{
"role": "user",
"content": [
{
"type": "text",
"text": question,
},
} for image_url in image_urls),
],
}],
*(
{
"type": "image_url",
"image_url": {"url": image_url},
}
for image_url in image_urls
),
],
}
],
sampling_params=sampling_params,
chat_template=req_data.chat_template,
lora_request=req_data.lora_requests,
@ -801,32 +789,39 @@ def run_chat(model: str, question: str, image_urls: list[str],

def parse_args():
parser = FlexibleArgumentParser(
description='Demo on using vLLM for offline inference with '
'vision language models that support multi-image input for text '
'generation')
parser.add_argument('--model-type',
'-m',
type=str,
default="phi3_v",
choices=model_example_map.keys(),
help='Huggingface "model_type".')
parser.add_argument("--method",
type=str,
default="generate",
choices=["generate", "chat"],
help="The method to run in `vllm.LLM`.")
parser.add_argument("--seed",
type=int,
default=None,
help="Set the seed when initializing `vllm.LLM`.")
description="Demo on using vLLM for offline inference with "
"vision language models that support multi-image input for text "
"generation"
)
parser.add_argument(
"--model-type",
"-m",
type=str,
default="phi3_v",
choices=model_example_map.keys(),
help='Huggingface "model_type".',
)
parser.add_argument(
"--method",
type=str,
default="generate",
choices=["generate", "chat"],
help="The method to run in `vllm.LLM`.",
)
parser.add_argument(
"--seed",
type=int,
default=None,
help="Set the seed when initializing `vllm.LLM`.",
)
parser.add_argument(
"--num-images",
"-n",
type=int,
choices=list(range(1,
len(IMAGE_URLS) + 1)), # the max number of images
choices=list(range(1, len(IMAGE_URLS) + 1)), # the max number of images
default=2,
help="Number of images to use for the demo.")
help="Number of images to use for the demo.",
)
return parser.parse_args()


@ -835,7 +830,7 @@ def main(args: Namespace):
method = args.method
seed = args.seed

image_urls = IMAGE_URLS[:args.num_images]
image_urls = IMAGE_URLS[: args.num_images]

if method == "generate":
run_generate(model, QUESTION, image_urls, seed)
@ -17,16 +17,15 @@ import requests


def clear_line(n: int = 1) -> None:
LINE_UP = '\033[1A'
LINE_CLEAR = '\x1b[2K'
LINE_UP = "\033[1A"
LINE_CLEAR = "\x1b[2K"
for _ in range(n):
print(LINE_UP, end=LINE_CLEAR, flush=True)


def post_http_request(prompt: str,
api_url: str,
n: int = 1,
stream: bool = False) -> requests.Response:
def post_http_request(
prompt: str, api_url: str, n: int = 1, stream: bool = False
) -> requests.Response:
headers = {"User-Agent": "Test Client"}
pload = {
"prompt": prompt,
@ -35,17 +34,14 @@ def post_http_request(prompt: str,
"max_tokens": 16,
"stream": stream,
}
response = requests.post(api_url,
headers=headers,
json=pload,
stream=stream)
response = requests.post(api_url, headers=headers, json=pload, stream=stream)
return response


def get_streaming_response(response: requests.Response) -> Iterable[list[str]]:
for chunk in response.iter_lines(chunk_size=8192,
decode_unicode=False,
delimiter=b"\n"):
for chunk in response.iter_lines(
chunk_size=8192, decode_unicode=False, delimiter=b"\n"
):
if chunk:
data = json.loads(chunk.decode("utf-8"))
output = data["text"]
@ -6,6 +6,7 @@ Note that `pip install cohere` is needed to run this example.
|
||||
|
||||
run: vllm serve BAAI/bge-reranker-base
|
||||
"""
|
||||
|
||||
from typing import Union
|
||||
|
||||
import cohere
|
||||
@ -16,28 +17,28 @@ model = "BAAI/bge-reranker-base"
|
||||
query = "What is the capital of France?"
|
||||
|
||||
documents = [
|
||||
"The capital of France is Paris", "Reranking is fun!",
|
||||
"vLLM is an open-source framework for fast AI serving"
|
||||
"The capital of France is Paris",
|
||||
"Reranking is fun!",
|
||||
"vLLM is an open-source framework for fast AI serving",
|
||||
]
|
||||
|
||||
|
||||
def cohere_rerank(client: Union[Client, ClientV2], model: str, query: str,
|
||||
documents: list[str]) -> dict:
|
||||
def cohere_rerank(
|
||||
client: Union[Client, ClientV2], model: str, query: str, documents: list[str]
|
||||
) -> dict:
|
||||
return client.rerank(model=model, query=query, documents=documents)
|
||||
|
||||
|
||||
def main():
|
||||
# cohere v1 client
|
||||
cohere_v1 = cohere.Client(base_url="http://localhost:8000",
|
||||
api_key="sk-fake-key")
|
||||
cohere_v1 = cohere.Client(base_url="http://localhost:8000", api_key="sk-fake-key")
|
||||
rerank_v1_result = cohere_rerank(cohere_v1, model, query, documents)
|
||||
print("-" * 50)
|
||||
print("rerank_v1_result:\n", rerank_v1_result)
|
||||
print("-" * 50)
|
||||
|
||||
# or the v2
|
||||
cohere_v2 = cohere.ClientV2("sk-fake-key",
|
||||
base_url="http://localhost:8000")
|
||||
cohere_v2 = cohere.ClientV2("sk-fake-key", base_url="http://localhost:8000")
|
||||
rerank_v2_result = cohere_rerank(cohere_v2, model, query, documents)
|
||||
print("rerank_v2_result:\n", rerank_v2_result)
|
||||
print("-" * 50)
|
||||
|
@ -13,6 +13,7 @@ launch this proxy demo through:
|
||||
Note: This demo will be removed once the PDController implemented in PR 15343
|
||||
(https://github.com/vllm-project/vllm/pull/15343) supports XpYd.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import ipaddress
|
||||
import itertools
|
||||
@ -26,8 +27,7 @@ from typing import Callable, Optional
|
||||
import aiohttp
|
||||
import requests
|
||||
import uvicorn
|
||||
from fastapi import (APIRouter, Depends, FastAPI, Header, HTTPException,
|
||||
Request, status)
|
||||
from fastapi import APIRouter, Depends, FastAPI, Header, HTTPException, Request, status
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
|
||||
@ -36,24 +36,24 @@ logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
class SchedulingPolicy(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def schedule(self, cycler: itertools.cycle):
|
||||
raise NotImplementedError("Scheduling Proxy is not set.")
|
||||
|
||||
|
||||
class Proxy:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prefill_instances: list[str],
|
||||
decode_instances: list[str],
|
||||
model: str,
|
||||
scheduling_policy: SchedulingPolicy,
|
||||
custom_create_completion: Optional[Callable[[Request],
|
||||
StreamingResponse]] = None,
|
||||
custom_create_chat_completion: Optional[Callable[
|
||||
[Request], StreamingResponse]] = None,
|
||||
custom_create_completion: Optional[
|
||||
Callable[[Request], StreamingResponse]
|
||||
] = None,
|
||||
custom_create_chat_completion: Optional[
|
||||
Callable[[Request], StreamingResponse]
|
||||
] = None,
|
||||
):
|
||||
self.prefill_instances = prefill_instances
|
||||
self.decode_instances = decode_instances
|
||||
@ -68,30 +68,30 @@ class Proxy:
|
||||
|
||||
def setup_routes(self):
|
||||
self.router.post(
|
||||
"/v1/completions",
|
||||
dependencies=[
|
||||
Depends(self.validate_json_request)
|
||||
])(self.custom_create_completion if self.
|
||||
custom_create_completion else self.create_completion)
|
||||
"/v1/completions", dependencies=[Depends(self.validate_json_request)]
|
||||
)(
|
||||
self.custom_create_completion
|
||||
if self.custom_create_completion
|
||||
else self.create_completion
|
||||
)
|
||||
self.router.post(
|
||||
"/v1/chat/completions",
|
||||
dependencies=[
|
||||
Depends(self.validate_json_request)
|
||||
])(self.custom_create_chat_completion if self.
|
||||
custom_create_chat_completion else self.create_chat_completion)
|
||||
self.router.get("/status",
|
||||
response_class=JSONResponse)(self.get_status)
|
||||
self.router.post("/instances/add",
|
||||
dependencies=[Depends(self.api_key_authenticate)
|
||||
])(self.add_instance_endpoint)
|
||||
"/v1/chat/completions", dependencies=[Depends(self.validate_json_request)]
|
||||
)(
|
||||
self.custom_create_chat_completion
|
||||
if self.custom_create_chat_completion
|
||||
else self.create_chat_completion
|
||||
)
|
||||
self.router.get("/status", response_class=JSONResponse)(self.get_status)
|
||||
self.router.post(
|
||||
"/instances/add", dependencies=[Depends(self.api_key_authenticate)]
|
||||
)(self.add_instance_endpoint)
|
||||
|
||||
async def validate_json_request(self, raw_request: Request):
|
||||
content_type = raw_request.headers.get("content-type", "").lower()
|
||||
if content_type != "application/json":
|
||||
raise HTTPException(
|
||||
status_code=415,
|
||||
detail=
|
||||
"Unsupported Media Type: Only 'application/json' is allowed",
|
||||
detail="Unsupported Media Type: Only 'application/json' is allowed",
|
||||
)
|
||||
|
||||
def api_key_authenticate(self, x_api_key: str = Header(...)):
|
||||
@ -103,8 +103,7 @@ class Proxy:
|
||||
detail="Server configuration error.",
|
||||
)
|
||||
if x_api_key != expected_api_key:
|
||||
logger.warning("Unauthorized access attempt with API Key: %s",
|
||||
x_api_key)
|
||||
logger.warning("Unauthorized access attempt with API Key: %s", x_api_key)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Forbidden: Invalid API Key.",
|
||||
@ -113,8 +112,7 @@ class Proxy:
|
||||
async def validate_instance(self, instance: str) -> bool:
|
||||
url = f"http://{instance}/v1/models"
|
||||
try:
|
||||
async with aiohttp.ClientSession(
|
||||
timeout=AIOHTTP_TIMEOUT) as client:
|
||||
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as client:
|
||||
logger.info("Verifying %s ...", instance)
|
||||
async with client.get(url) as response:
|
||||
if response.status == 200:
|
||||
@ -122,12 +120,15 @@ class Proxy:
|
||||
if "data" in data and len(data["data"]) > 0:
|
||||
model_cur = data["data"][0].get("id", "")
|
||||
if model_cur == self.model:
|
||||
logger.info("Instance: %s could be added.",
|
||||
instance)
|
||||
logger.info("Instance: %s could be added.", instance)
|
||||
return True
|
||||
else:
|
||||
logger.warning("Mismatch model %s : %s != %s",
|
||||
instance, model_cur, self.model)
|
||||
logger.warning(
|
||||
"Mismatch model %s : %s != %s",
|
||||
instance,
|
||||
model_cur,
|
||||
self.model,
|
||||
)
|
||||
return False
|
||||
else:
|
||||
return False
|
||||
@ -147,48 +148,47 @@ class Proxy:
|
||||
instance_type = data.get("type")
|
||||
instance = data.get("instance")
|
||||
if instance_type not in ["prefill", "decode"]:
|
||||
raise HTTPException(status_code=400,
|
||||
detail="Invalid instance type.")
|
||||
raise HTTPException(status_code=400, detail="Invalid instance type.")
|
||||
if not instance or ":" not in instance:
|
||||
raise HTTPException(status_code=400,
|
||||
detail="Invalid instance format.")
|
||||
raise HTTPException(status_code=400, detail="Invalid instance format.")
|
||||
host, port_str = instance.split(":")
|
||||
try:
|
||||
if host != "localhost":
|
||||
ipaddress.ip_address(host)
|
||||
port = int(port_str)
|
||||
if not (0 < port < 65536):
|
||||
raise HTTPException(status_code=400,
|
||||
detail="Invalid port number.")
|
||||
raise HTTPException(status_code=400, detail="Invalid port number.")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400,
|
||||
detail="Invalid instance address.") from e
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Invalid instance address."
|
||||
) from e
|
||||
|
||||
is_valid = await self.validate_instance(instance)
|
||||
if not is_valid:
|
||||
raise HTTPException(status_code=400,
|
||||
detail="Instance validation failed.")
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Instance validation failed."
|
||||
)
|
||||
|
||||
if instance_type == "prefill":
|
||||
if instance not in self.prefill_instances:
|
||||
self.prefill_instances.append(instance)
|
||||
self.prefill_cycler = itertools.cycle(
|
||||
self.prefill_instances)
|
||||
self.prefill_cycler = itertools.cycle(self.prefill_instances)
|
||||
else:
|
||||
raise HTTPException(status_code=400,
|
||||
detail="Instance already exists.")
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Instance already exists."
|
||||
)
|
||||
else:
|
||||
if instance not in self.decode_instances:
|
||||
self.decode_instances.append(instance)
|
||||
self.decode_cycler = itertools.cycle(self.decode_instances)
|
||||
else:
|
||||
raise HTTPException(status_code=400,
|
||||
detail="Instance already exists.")
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Instance already exists."
|
||||
)
|
||||
|
||||
return JSONResponse(content={
|
||||
"message":
|
||||
f"Added {instance} to {instance_type}_instances."
|
||||
})
|
||||
return JSONResponse(
|
||||
content={"message": f"Added {instance} to {instance_type}_instances."}
|
||||
)
|
||||
except HTTPException as http_exc:
|
||||
raise http_exc
|
||||
except Exception as e:
|
||||
@ -197,16 +197,16 @@ class Proxy:
|
||||
|
||||
async def forward_request(self, url, data, use_chunked=True):
|
||||
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
|
||||
}
|
||||
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
|
||||
try:
|
||||
async with session.post(url=url, json=data,
|
||||
headers=headers) as response:
|
||||
async with session.post(
|
||||
url=url, json=data, headers=headers
|
||||
) as response:
|
||||
if 200 <= response.status < 300 or 400 <= response.status < 500: # noqa: E501
|
||||
if use_chunked:
|
||||
async for chunk_bytes in response.content.iter_chunked( # noqa: E501
|
||||
1024):
|
||||
1024
|
||||
):
|
||||
yield chunk_bytes
|
||||
else:
|
||||
content = await response.read()
|
||||
@ -217,20 +217,21 @@ class Proxy:
|
||||
error_content = json.loads(error_content)
|
||||
except json.JSONDecodeError:
|
||||
error_content = error_content
|
||||
logger.error("Request failed with status %s: %s",
|
||||
response.status, error_content)
|
||||
logger.error(
|
||||
"Request failed with status %s: %s",
|
||||
response.status,
|
||||
error_content,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=response.status,
|
||||
detail=
|
||||
f"Request failed with status {response.status}: "
|
||||
detail=f"Request failed with status {response.status}: "
|
||||
f"{error_content}",
|
||||
)
|
||||
except aiohttp.ClientError as e:
|
||||
logger.error("ClientError occurred: %s", str(e))
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=
|
||||
"Bad Gateway: Error communicating with upstream server.",
|
||||
detail="Bad Gateway: Error communicating with upstream server.",
|
||||
) from e
|
||||
except Exception as e:
|
||||
logger.error("Unexpected error: %s", str(e))
|
||||
@ -258,8 +259,8 @@ class Proxy:
|
||||
prefill_instance = self.schedule(self.prefill_cycler)
|
||||
try:
|
||||
async for _ in self.forward_request(
|
||||
f"http://{prefill_instance}/v1/completions",
|
||||
kv_prepare_request):
|
||||
f"http://{prefill_instance}/v1/completions", kv_prepare_request
|
||||
):
|
||||
continue
|
||||
except HTTPException as http_exc:
|
||||
self.remove_instance_endpoint("prefill", prefill_instance)
|
||||
@ -270,7 +271,8 @@ class Proxy:
|
||||
|
||||
try:
|
||||
generator = self.forward_request(
|
||||
f"http://{decode_instance}/v1/completions", request)
|
||||
f"http://{decode_instance}/v1/completions", request
|
||||
)
|
||||
except HTTPException as http_exc:
|
||||
self.remove_instance_endpoint("decode", decode_instance)
|
||||
raise http_exc
|
||||
@ -295,8 +297,8 @@ class Proxy:
|
||||
prefill_instance = self.schedule(self.prefill_cycler)
|
||||
try:
|
||||
async for _ in self.forward_request(
|
||||
f"http://{prefill_instance}/v1/chat/completions",
|
||||
kv_prepare_request):
|
||||
f"http://{prefill_instance}/v1/chat/completions", kv_prepare_request
|
||||
):
|
||||
continue
|
||||
except HTTPException as http_exc:
|
||||
self.remove_instance_endpoint("prefill", prefill_instance)
|
||||
@ -306,8 +308,8 @@ class Proxy:
|
||||
|
||||
try:
|
||||
generator = self.forward_request(
|
||||
"http://" + decode_instance + "/v1/chat/completions",
|
||||
request)
|
||||
"http://" + decode_instance + "/v1/chat/completions", request
|
||||
)
|
||||
except HTTPException as http_exc:
|
||||
self.remove_instance_endpoint("decode", decode_instance)
|
||||
raise http_exc
|
||||
@ -318,20 +320,20 @@ class Proxy:
|
||||
error_messages = [str(e) for e in exc_info if e]
|
||||
print("Error occurred in disagg proxy server")
|
||||
print(error_messages)
|
||||
return StreamingResponse(content=iter(error_messages),
|
||||
media_type="text/event-stream")
|
||||
return StreamingResponse(
|
||||
content=iter(error_messages), media_type="text/event-stream"
|
||||
)
|
||||
|
||||
def remove_instance_endpoint(self, instance_type, instance):
|
||||
if (instance_type == "decode" and instance in self.decode_instances):
|
||||
if instance_type == "decode" and instance in self.decode_instances:
|
||||
self.decode_instances.remove(instance)
|
||||
self.decode_cycler = itertools.cycle(self.decode_instances)
|
||||
if (instance_type == "prefill" and instance in self.decode_instances):
|
||||
if instance_type == "prefill" and instance in self.decode_instances:
|
||||
self.prefill_instances.remove(instance)
|
||||
self.prefill_cycler = itertools.cycle(self.decode_instances)
|
||||
|
||||
|
||||
class RoundRobinSchedulingPolicy(SchedulingPolicy):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@ -340,15 +342,12 @@ class RoundRobinSchedulingPolicy(SchedulingPolicy):
|
||||
|
||||
|
||||
class ProxyServer:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
args: argparse.Namespace,
|
||||
scheduling_policy: Optional[SchedulingPolicy] = None,
|
||||
create_completion: Optional[Callable[[Request],
|
||||
StreamingResponse]] = None,
|
||||
create_chat_completion: Optional[Callable[[Request],
|
||||
StreamingResponse]] = None,
|
||||
create_completion: Optional[Callable[[Request], StreamingResponse]] = None,
|
||||
create_chat_completion: Optional[Callable[[Request], StreamingResponse]] = None,
|
||||
):
|
||||
self.validate_parsed_serve_args(args)
|
||||
self.port = args.port
|
||||
@ -356,8 +355,11 @@ class ProxyServer:
|
||||
prefill_instances=[] if args.prefill is None else args.prefill,
|
||||
decode_instances=[] if args.decode is None else args.decode,
|
||||
model=args.model,
|
||||
scheduling_policy=(scheduling_policy if scheduling_policy
|
||||
is not None else RoundRobinSchedulingPolicy()),
|
||||
scheduling_policy=(
|
||||
scheduling_policy
|
||||
if scheduling_policy is not None
|
||||
else RoundRobinSchedulingPolicy()
|
||||
),
|
||||
custom_create_completion=create_completion,
|
||||
custom_create_chat_completion=create_chat_completion,
|
||||
)
|
||||
@ -382,11 +384,9 @@ class ProxyServer:
|
||||
ipaddress.ip_address(host)
|
||||
port = int(port)
|
||||
if not (0 < port < 65536):
|
||||
raise ValueError(
|
||||
f"Invalid port number in instance: {instance}")
|
||||
raise ValueError(f"Invalid port number in instance: {instance}")
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Invalid instance {instance}: {str(e)}") from e
|
||||
raise ValueError(f"Invalid instance {instance}: {str(e)}") from e
|
||||
|
||||
def verify_model_config(self, instances: list, model: str) -> None:
|
||||
model_suffix = model.split("/")[-1]
|
||||
@ -399,12 +399,14 @@ class ProxyServer:
|
||||
if model_cur_suffix != model_suffix:
|
||||
raise ValueError(
|
||||
f"{instance} serves a different model: "
|
||||
f"{model_cur} != {model}")
|
||||
f"{model_cur} != {model}"
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Cannot get model id from {instance}!")
|
||||
except requests.RequestException as e:
|
||||
raise ValueError(
|
||||
f"Error communicating with {instance}: {str(e)}") from e
|
||||
f"Error communicating with {instance}: {str(e)}"
|
||||
) from e
|
||||
|
||||
def run_server(self):
|
||||
app = FastAPI()
|
||||
@ -417,11 +419,7 @@ class ProxyServer:
|
||||
def parse_args():
|
||||
# Todo: allow more config
|
||||
parser = argparse.ArgumentParser("vLLM disaggregated proxy server.")
|
||||
parser.add_argument("--model",
|
||||
"-m",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Model name")
|
||||
parser.add_argument("--model", "-m", type=str, required=True, help="Model name")
|
||||
|
||||
parser.add_argument(
|
||||
"--prefill",
|
||||
|
@ -17,6 +17,7 @@ you can install it manually by following these steps:
|
||||
2. Rename the downloaded file to: frpc_linux_amd64_v0.3
|
||||
3. Move the file to this location: /home/user/.cache/huggingface/gradio/frpc
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
import gradio as gr
|
||||
@ -24,16 +25,12 @@ from openai import OpenAI
|
||||
|
||||
|
||||
def format_history_to_openai(history):
|
||||
history_openai_format = [{
|
||||
"role": "system",
|
||||
"content": "You are a great AI assistant."
|
||||
}]
|
||||
history_openai_format = [
|
||||
{"role": "system", "content": "You are a great AI assistant."}
|
||||
]
|
||||
for human, assistant in history:
|
||||
history_openai_format.append({"role": "user", "content": human})
|
||||
history_openai_format.append({
|
||||
"role": "assistant",
|
||||
"content": assistant
|
||||
})
|
||||
history_openai_format.append({"role": "assistant", "content": assistant})
|
||||
return history_openai_format
|
||||
|
||||
|
||||
@ -49,17 +46,17 @@ def predict(message, history, client, model_name, temp, stop_token_ids):
|
||||
temperature=temp,
|
||||
stream=True,
|
||||
extra_body={
|
||||
'repetition_penalty':
|
||||
1,
|
||||
'stop_token_ids':
|
||||
[int(id.strip())
|
||||
for id in stop_token_ids.split(',')] if stop_token_ids else []
|
||||
})
|
||||
"repetition_penalty": 1,
|
||||
"stop_token_ids": [int(id.strip()) for id in stop_token_ids.split(",")]
|
||||
if stop_token_ids
|
||||
else [],
|
||||
},
|
||||
)
|
||||
|
||||
# Collect all chunks and concatenate them into a full message
|
||||
full_message = ""
|
||||
for chunk in stream:
|
||||
full_message += (chunk.choices[0].delta.content or "")
|
||||
full_message += chunk.choices[0].delta.content or ""
|
||||
|
||||
# Return the full message as a single response
|
||||
return full_message
|
||||
@ -67,38 +64,34 @@ def predict(message, history, client, model_name, temp, stop_token_ids):
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Chatbot Interface with Customizable Parameters')
|
||||
parser.add_argument('--model-url',
|
||||
type=str,
|
||||
default='http://localhost:8000/v1',
|
||||
help='Model URL')
|
||||
parser.add_argument('-m',
|
||||
'--model',
|
||||
type=str,
|
||||
required=True,
|
||||
help='Model name for the chatbot')
|
||||
parser.add_argument('--temp',
|
||||
type=float,
|
||||
default=0.8,
|
||||
help='Temperature for text generation')
|
||||
parser.add_argument('--stop-token-ids',
|
||||
type=str,
|
||||
default='',
|
||||
help='Comma-separated stop token IDs')
|
||||
description="Chatbot Interface with Customizable Parameters"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-url", type=str, default="http://localhost:8000/v1", help="Model URL"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-m", "--model", type=str, required=True, help="Model name for the chatbot"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temp", type=float, default=0.8, help="Temperature for text generation"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--stop-token-ids", type=str, default="", help="Comma-separated stop token IDs"
|
||||
)
|
||||
parser.add_argument("--host", type=str, default=None)
|
||||
parser.add_argument("--port", type=int, default=8001)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def build_gradio_interface(client, model_name, temp, stop_token_ids):
|
||||
|
||||
def chat_predict(message, history):
|
||||
return predict(message, history, client, model_name, temp,
|
||||
stop_token_ids)
|
||||
return predict(message, history, client, model_name, temp, stop_token_ids)
|
||||
|
||||
return gr.ChatInterface(fn=chat_predict,
|
||||
title="Chatbot Interface",
|
||||
description="A simple chatbot powered by vLLM")
|
||||
return gr.ChatInterface(
|
||||
fn=chat_predict,
|
||||
title="Chatbot Interface",
|
||||
description="A simple chatbot powered by vLLM",
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
@ -113,12 +106,13 @@ def main():
|
||||
client = OpenAI(api_key=openai_api_key, base_url=openai_api_base)
|
||||
|
||||
# Define the Gradio chatbot interface using the predict function
|
||||
gradio_interface = build_gradio_interface(client, args.model, args.temp,
|
||||
args.stop_token_ids)
|
||||
gradio_interface = build_gradio_interface(
|
||||
client, args.model, args.temp, args.stop_token_ids
|
||||
)
|
||||
|
||||
gradio_interface.queue().launch(server_name=args.host,
|
||||
server_port=args.port,
|
||||
share=True)
|
||||
gradio_interface.queue().launch(
|
||||
server_name=args.host, server_port=args.port, share=True
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -17,6 +17,7 @@ you can install it manually by following these steps:
|
||||
2. Rename the downloaded file to: frpc_linux_amd64_v0.3
|
||||
3. Move the file to this location: /home/user/.cache/huggingface/gradio/frpc
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
|
||||
@ -31,14 +32,11 @@ def http_bot(prompt):
|
||||
"stream": True,
|
||||
"max_tokens": 128,
|
||||
}
|
||||
response = requests.post(args.model_url,
|
||||
headers=headers,
|
||||
json=pload,
|
||||
stream=True)
|
||||
response = requests.post(args.model_url, headers=headers, json=pload, stream=True)
|
||||
|
||||
for chunk in response.iter_lines(chunk_size=8192,
|
||||
decode_unicode=False,
|
||||
delimiter=b"\n"):
|
||||
for chunk in response.iter_lines(
|
||||
chunk_size=8192, decode_unicode=False, delimiter=b"\n"
|
||||
):
|
||||
if chunk:
|
||||
data = json.loads(chunk.decode("utf-8"))
|
||||
output = data["text"][0]
|
||||
@ -48,10 +46,10 @@ def http_bot(prompt):
|
||||
def build_demo():
|
||||
with gr.Blocks() as demo:
|
||||
gr.Markdown("# vLLM text completion demo\n")
|
||||
inputbox = gr.Textbox(label="Input",
|
||||
placeholder="Enter text and press ENTER")
|
||||
outputbox = gr.Textbox(label="Output",
|
||||
placeholder="Generated result from the model")
|
||||
inputbox = gr.Textbox(label="Input", placeholder="Enter text and press ENTER")
|
||||
outputbox = gr.Textbox(
|
||||
label="Output", placeholder="Generated result from the model"
|
||||
)
|
||||
inputbox.submit(http_bot, [inputbox], [outputbox])
|
||||
return demo
|
||||
|
||||
@ -60,17 +58,15 @@ def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--host", type=str, default=None)
|
||||
parser.add_argument("--port", type=int, default=8001)
|
||||
parser.add_argument("--model-url",
|
||||
type=str,
|
||||
default="http://localhost:8000/generate")
|
||||
parser.add_argument(
|
||||
"--model-url", type=str, default="http://localhost:8000/generate"
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main(args):
|
||||
demo = build_demo()
|
||||
demo.queue().launch(server_name=args.host,
|
||||
server_port=args.port,
|
||||
share=True)
|
||||
demo.queue().launch(server_name=args.host, server_port=args.port, share=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -5,6 +5,7 @@ Jina and Cohere https://jina.ai/reranker
|
||||
|
||||
run: vllm serve BAAI/bge-reranker-base
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
import requests
|
||||
@ -14,14 +15,13 @@ url = "http://127.0.0.1:8000/rerank"
|
||||
headers = {"accept": "application/json", "Content-Type": "application/json"}
|
||||
|
||||
data = {
|
||||
"model":
|
||||
"BAAI/bge-reranker-base",
|
||||
"query":
|
||||
"What is the capital of France?",
|
||||
"model": "BAAI/bge-reranker-base",
|
||||
"query": "What is the capital of France?",
|
||||
"documents": [
|
||||
"The capital of Brazil is Brasilia.",
|
||||
"The capital of France is Paris.", "Horses and cows are both animals"
|
||||
]
|
||||
"The capital of France is Paris.",
|
||||
"Horses and cows are both animals",
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
|
@ -9,17 +9,14 @@ from msgspec.msgpack import Decoder
|
||||
#
|
||||
# Types copied from vllm.distributed.kv_events
|
||||
#
|
||||
class EventBatch(msgspec.Struct, array_like=True, omit_defaults=True,
|
||||
gc=False):
|
||||
class EventBatch(msgspec.Struct, array_like=True, omit_defaults=True, gc=False):
|
||||
ts: float
|
||||
events: list[Any]
|
||||
|
||||
|
||||
class KVCacheEvent(msgspec.Struct,
|
||||
array_like=True,
|
||||
omit_defaults=True,
|
||||
gc=False,
|
||||
tag=True):
|
||||
class KVCacheEvent(
|
||||
msgspec.Struct, array_like=True, omit_defaults=True, gc=False, tag=True
|
||||
):
|
||||
"""Base class for all KV cache-related events"""
|
||||
|
||||
|
||||
@ -77,8 +74,9 @@ def main():
|
||||
|
||||
if last_seq >= 0 and seq > last_seq + 1:
|
||||
missed = seq - last_seq - 1
|
||||
print(f"Missed {missed} messages"
|
||||
f" (last: {last_seq}, current: {seq})")
|
||||
print(
|
||||
f"Missed {missed} messages (last: {last_seq}, current: {seq})"
|
||||
)
|
||||
|
||||
replay.send((last_seq + 1).to_bytes(8, "big"))
|
||||
|
||||
|
@ -12,26 +12,22 @@ from openai import OpenAI
|
||||
openai_api_key = "EMPTY"
|
||||
openai_api_base = "http://localhost:8000/v1"
|
||||
|
||||
messages = [{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant."
|
||||
}, {
|
||||
"role": "user",
|
||||
"content": "Who won the world series in 2020?"
|
||||
}, {
|
||||
"role": "assistant",
|
||||
"content": "The Los Angeles Dodgers won the World Series in 2020."
|
||||
}, {
|
||||
"role": "user",
|
||||
"content": "Where was it played?"
|
||||
}]
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Who won the world series in 2020?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "The Los Angeles Dodgers won the World Series in 2020.",
|
||||
},
|
||||
{"role": "user", "content": "Where was it played?"},
|
||||
]
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Client for vLLM API server")
|
||||
parser.add_argument("--stream",
|
||||
action="store_true",
|
||||
help="Enable streaming response")
|
||||
parser.add_argument(
|
||||
"--stream", action="store_true", help="Enable streaming response"
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
|
@ -43,7 +43,7 @@ def encode_base64_content_from_url(content_url: str) -> str:
|
||||
|
||||
with requests.get(content_url) as response:
|
||||
response.raise_for_status()
|
||||
result = base64.b64encode(response.content).decode('utf-8')
|
||||
result = base64.b64encode(response.content).decode("utf-8")
|
||||
|
||||
return result
|
||||
|
||||
@ -51,10 +51,7 @@ def encode_base64_content_from_url(content_url: str) -> str:
|
||||
# Text-only inference
|
||||
def run_text_only(model: str) -> None:
|
||||
chat_completion = client.chat.completions.create(
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "What's the capital of France?"
|
||||
}],
|
||||
messages=[{"role": "user", "content": "What's the capital of France?"}],
|
||||
model=model,
|
||||
max_completion_tokens=64,
|
||||
)
|
||||
@ -65,26 +62,21 @@ def run_text_only(model: str) -> None:
|
||||
|
||||
# Single-image input inference
|
||||
def run_single_image(model: str) -> None:
|
||||
|
||||
## Use image url in the payload
|
||||
image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
|
||||
chat_completion_from_url = client.chat.completions.create(
|
||||
messages=[{
|
||||
"role":
|
||||
"user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What's in this image?"
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What's in this image?"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": image_url},
|
||||
},
|
||||
},
|
||||
],
|
||||
}],
|
||||
],
|
||||
}
|
||||
],
|
||||
model=model,
|
||||
max_completion_tokens=64,
|
||||
)
|
||||
@ -95,22 +87,18 @@ def run_single_image(model: str) -> None:
|
||||
## Use base64 encoded image in the payload
|
||||
image_base64 = encode_base64_content_from_url(image_url)
|
||||
chat_completion_from_base64 = client.chat.completions.create(
|
||||
messages=[{
|
||||
"role":
|
||||
"user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What's in this image?"
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{image_base64}"
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What's in this image?"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/jpeg;base64,{image_base64}"},
|
||||
},
|
||||
},
|
||||
],
|
||||
}],
|
||||
],
|
||||
}
|
||||
],
|
||||
model=model,
|
||||
max_completion_tokens=64,
|
||||
)
|
||||
@ -124,28 +112,22 @@ def run_multi_image(model: str) -> None:
|
||||
image_url_duck = "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg"
|
||||
image_url_lion = "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg"
|
||||
chat_completion_from_url = client.chat.completions.create(
|
||||
messages=[{
|
||||
"role":
|
||||
"user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What are the animals in these images?"
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url_duck
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What are the animals in these images?"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": image_url_duck},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url_lion
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": image_url_lion},
|
||||
},
|
||||
},
|
||||
],
|
||||
}],
|
||||
],
|
||||
}
|
||||
],
|
||||
model=model,
|
||||
max_completion_tokens=64,
|
||||
)
|
||||
@ -161,22 +143,18 @@ def run_video(model: str) -> None:
|
||||
|
||||
## Use video url in the payload
|
||||
chat_completion_from_url = client.chat.completions.create(
|
||||
messages=[{
|
||||
"role":
|
||||
"user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What's in this video?"
|
||||
},
|
||||
{
|
||||
"type": "video_url",
|
||||
"video_url": {
|
||||
"url": video_url
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What's in this video?"},
|
||||
{
|
||||
"type": "video_url",
|
||||
"video_url": {"url": video_url},
|
||||
},
|
||||
},
|
||||
],
|
||||
}],
|
||||
],
|
||||
}
|
||||
],
|
||||
model=model,
|
||||
max_completion_tokens=64,
|
||||
)
|
||||
@ -186,22 +164,18 @@ def run_video(model: str) -> None:
|
||||
|
||||
## Use base64 encoded video in the payload
|
||||
chat_completion_from_base64 = client.chat.completions.create(
|
||||
messages=[{
|
||||
"role":
|
||||
"user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What's in this video?"
|
||||
},
|
||||
{
|
||||
"type": "video_url",
|
||||
"video_url": {
|
||||
"url": f"data:video/mp4;base64,{video_base64}"
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What's in this video?"},
|
||||
{
|
||||
"type": "video_url",
|
||||
"video_url": {"url": f"data:video/mp4;base64,{video_base64}"},
|
||||
},
|
||||
},
|
||||
],
|
||||
}],
|
||||
],
|
||||
}
|
||||
],
|
||||
model=model,
|
||||
max_completion_tokens=64,
|
||||
)
|
||||
@ -219,24 +193,22 @@ def run_audio(model: str) -> None:
|
||||
|
||||
# OpenAI-compatible schema (`input_audio`)
|
||||
chat_completion_from_base64 = client.chat.completions.create(
|
||||
messages=[{
|
||||
"role":
|
||||
"user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What's in this audio?"
|
||||
},
|
||||
{
|
||||
"type": "input_audio",
|
||||
"input_audio": {
|
||||
# Any format supported by librosa is supported
|
||||
"data": audio_base64,
|
||||
"format": "wav"
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What's in this audio?"},
|
||||
{
|
||||
"type": "input_audio",
|
||||
"input_audio": {
|
||||
# Any format supported by librosa is supported
|
||||
"data": audio_base64,
|
||||
"format": "wav",
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
}],
|
||||
],
|
||||
}
|
||||
],
|
||||
model=model,
|
||||
max_completion_tokens=64,
|
||||
)
|
||||
@ -246,23 +218,21 @@ def run_audio(model: str) -> None:
|
||||
|
||||
# HTTP URL
|
||||
chat_completion_from_url = client.chat.completions.create(
|
||||
messages=[{
|
||||
"role":
|
||||
"user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What's in this audio?"
|
||||
},
|
||||
{
|
||||
"type": "audio_url",
|
||||
"audio_url": {
|
||||
# Any format supported by librosa is supported
|
||||
"url": audio_url
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What's in this audio?"},
|
||||
{
|
||||
"type": "audio_url",
|
||||
"audio_url": {
|
||||
# Any format supported by librosa is supported
|
||||
"url": audio_url
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
}],
|
||||
],
|
||||
}
|
||||
],
|
||||
model=model,
|
||||
max_completion_tokens=64,
|
||||
)
|
||||
@ -272,23 +242,21 @@ def run_audio(model: str) -> None:
|
||||
|
||||
# base64 URL
|
||||
chat_completion_from_base64 = client.chat.completions.create(
|
||||
messages=[{
|
||||
"role":
|
||||
"user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What's in this audio?"
|
||||
},
|
||||
{
|
||||
"type": "audio_url",
|
||||
"audio_url": {
|
||||
# Any format supported by librosa is supported
|
||||
"url": f"data:audio/ogg;base64,{audio_base64}"
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What's in this audio?"},
|
||||
{
|
||||
"type": "audio_url",
|
||||
"audio_url": {
|
||||
# Any format supported by librosa is supported
|
||||
"url": f"data:audio/ogg;base64,{audio_base64}"
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
}],
|
||||
],
|
||||
}
|
||||
],
|
||||
model=model,
|
||||
max_completion_tokens=64,
|
||||
)
|
||||
@ -308,14 +276,17 @@ example_function_map = {
|
||||
|
||||
def parse_args():
|
||||
parser = FlexibleArgumentParser(
|
||||
description='Demo on using OpenAI client for online serving with '
|
||||
'multimodal language models served with vLLM.')
|
||||
parser.add_argument('--chat-type',
|
||||
'-c',
|
||||
type=str,
|
||||
default="single-image",
|
||||
choices=list(example_function_map.keys()),
|
||||
help='Conversation type with multimodal data.')
|
||||
description="Demo on using OpenAI client for online serving with "
|
||||
"multimodal language models served with vLLM."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--chat-type",
|
||||
"-c",
|
||||
type=str,
|
||||
default="single-image",
|
||||
choices=list(example_function_map.keys()),
|
||||
help="Conversation type with multimodal data.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
|
@ -16,6 +16,7 @@ vllm serve NousResearch/Hermes-2-Pro-Llama-3-8B \
|
||||
--chat-template examples/tool_chat_template_hermes.jinja \
|
||||
--enable-auto-tool-choice --tool-call-parser hermes
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
@ -25,55 +26,55 @@ from openai import OpenAI
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"

tools = [{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"city": {
"type":
"string",
"description":
"The city to find the weather for, e.g. 'San Francisco'"
},
"state": {
"type":
"string",
"description":
"the two-letter abbreviation for the state that the city is"
" in, e.g. 'CA' which would mean 'California'"
},
"unit": {
"type": "string",
"description": "The unit to fetch the temperature in",
"enum": ["celsius", "fahrenheit"]
}
properties = {
"city": {
"type": "string",
"description": "The city to find the weather for, e.g. 'San Francisco'",
},
"state": {
"type": "string",
"description": "the two-letter abbreviation for the state that the city is"
" in, e.g. 'CA' which would mean 'California'",
},
"unit": {
"type": "string",
"description": "The unit to fetch the temperature in",
"enum": ["celsius", "fahrenheit"],
},
}

tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": properties,
"required": ["city", "state", "unit"],
},
"required": ["city", "state", "unit"]
}
},
}
}]
]

messages = [{
"role": "user",
"content": "Hi! How are you doing today?"
}, {
"role": "assistant",
"content": "I'm doing well! How can I help you?"
}, {
"role":
"user",
"content":
"Can you tell me what the temperate will be in Dallas, in fahrenheit?"
}]
messages = [
{"role": "user", "content": "Hi! How are you doing today?"},
{"role": "assistant", "content": "I'm doing well! How can I help you?"},
{
"role": "user",
"content": (
"Can you tell me what the temperate will be in Dallas, in fahrenheit?"
),
},
]


def get_current_weather(city: str, state: str, unit: 'str'):
return ("The weather in Dallas, Texas is 85 degrees fahrenheit. It is "
"partly cloudly, with highs in the 90's.")
def get_current_weather(city: str, state: str, unit: "str"):
return (
"The weather in Dallas, Texas is 85 degrees fahrenheit. It is "
"partly cloudly, with highs in the 90's."
)


def handle_tool_calls_stream(
@ -82,10 +83,9 @@ def handle_tool_calls_stream(
model: str,
tools: list[dict[str, Any]],
) -> list[Any]:
tool_calls_stream = client.chat.completions.create(messages=messages,
model=model,
tools=tools,
stream=True)
tool_calls_stream = client.chat.completions.create(
messages=messages, model=model, tools=tools, stream=True
)
chunks = []
print("chunks: ")
for chunk in tool_calls_stream:
@ -106,8 +106,7 @@ def handle_tool_calls_arguments(chunks: list[Any]) -> list[str]:
tool_call = chunk.choices[0].delta.tool_calls[0]
if tool_call.index != tool_call_idx:
if tool_call_idx >= 0:
print(f"streamed tool call arguments: "
f"{arguments[tool_call_idx]}")
print(f"streamed tool call arguments: {arguments[tool_call_idx]}")
tool_call_idx = chunk.choices[0].delta.tool_calls[0].index
arguments.append("")
if tool_call.id:
@ -115,8 +114,7 @@ def handle_tool_calls_arguments(chunks: list[Any]) -> list[str]:

if tool_call.function:
if tool_call.function.name:
print(
f"streamed tool call name: {tool_call.function.name}")
print(f"streamed tool call name: {tool_call.function.name}")

if tool_call.function.arguments:
arguments[tool_call_idx] += tool_call.function.arguments
@ -136,9 +134,9 @@ def main():
models = client.models.list()
model = models.data[0].id

chat_completion = client.chat.completions.create(messages=messages,
model=model,
tools=tools)
chat_completion = client.chat.completions.create(
messages=messages, model=model, tools=tools
)

print("-" * 70)
print("Chat completion results:")
@ -158,10 +156,12 @@ def main():
print("-" * 70)

# Add tool call results to the conversation
messages.append({
"role": "assistant",
"tool_calls": chat_completion.choices[0].message.tool_calls
})
messages.append(
{
"role": "assistant",
"tool_calls": chat_completion.choices[0].message.tool_calls,
}
)

# Now, simulate a tool call
available_tools = {"get_current_weather": get_current_weather}
@ -172,17 +172,18 @@ def main():
args = json.loads(call.function.arguments)
result = tool_to_call(**args)
print("tool_to_call result: ", result)
messages.append({
"role": "tool",
"content": result,
"tool_call_id": call.id,
"name": call.function.name
})
messages.append(
{
"role": "tool",
"content": result,
"tool_call_id": call.id,
"name": call.function.name,
}
)

chat_completion_2 = client.chat.completions.create(messages=messages,
model=model,
tools=tools,
stream=False)
chat_completion_2 = client.chat.completions.create(
messages=messages, model=model, tools=tools, stream=False
)
print("Chat completion2 results:")
print(chat_completion_2)
print("-" * 70)

@ -28,18 +28,16 @@ tools = [
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type":
|
||||
"string",
|
||||
"description":
|
||||
"The city to find the weather for"
|
||||
"type": "string",
|
||||
"description": "The city to find the weather for"
|
||||
", e.g. 'San Francisco'",
|
||||
},
|
||||
"state": {
|
||||
"type":
|
||||
"string",
|
||||
"description":
|
||||
"the two-letter abbreviation for the state that the "
|
||||
"city is in, e.g. 'CA' which would mean 'California'",
|
||||
"type": "string",
|
||||
"description": (
|
||||
"the two-letter abbreviation for the state that the "
|
||||
"city is in, e.g. 'CA' which would mean 'California'"
|
||||
),
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
@ -60,22 +58,20 @@ tools = [
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type":
|
||||
"string",
|
||||
"description":
|
||||
"The city to get the forecast for, e.g. 'New York'",
|
||||
"type": "string",
|
||||
"description": (
|
||||
"The city to get the forecast for, e.g. 'New York'"
|
||||
),
|
||||
},
|
||||
"state": {
|
||||
"type":
|
||||
"string",
|
||||
"description":
|
||||
"The two-letter abbreviation for the state, e.g. 'NY'",
|
||||
"type": "string",
|
||||
"description": (
|
||||
"The two-letter abbreviation for the state, e.g. 'NY'"
|
||||
),
|
||||
},
|
||||
"days": {
|
||||
"type":
|
||||
"integer",
|
||||
"description":
|
||||
"Number of days to get the forecast for (1-7)",
|
||||
"type": "integer",
|
||||
"description": "Number of days to get the forecast for (1-7)",
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
@ -90,19 +86,11 @@ tools = [
|
||||
]
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "Hi! How are you doing today?"},
|
||||
{"role": "assistant", "content": "I'm doing well! How can I help you?"},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hi! How are you doing today?"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "I'm doing well! How can I help you?"
|
||||
},
|
||||
{
|
||||
"role":
|
||||
"user",
|
||||
"content":
|
||||
"Can you tell me what the current weather is in Dallas \
|
||||
"content": "Can you tell me what the current weather is in Dallas \
|
||||
and the forecast for the next 5 days, in fahrenheit?",
|
||||
},
|
||||
]
|
||||
@ -123,17 +111,16 @@ def main():
|
||||
model=model,
|
||||
tools=tools,
|
||||
tool_choice="required",
|
||||
stream=True # Enable streaming response
|
||||
stream=True, # Enable streaming response
|
||||
)
|
||||
|
||||
for chunk in chat_completion:
|
||||
if chunk.choices and chunk.choices[0].delta.tool_calls:
|
||||
print(chunk.choices[0].delta.tool_calls)
|
||||
|
||||
chat_completion = client.chat.completions.create(messages=messages,
|
||||
model=model,
|
||||
tools=tools,
|
||||
tool_choice="required")
|
||||
chat_completion = client.chat.completions.create(
|
||||
messages=messages, model=model, tools=tools, tool_choice="required"
|
||||
)
|
||||
|
||||
print(chat_completion.choices[0].message.tool_calls)
|
||||
|
||||
|
@ -20,10 +20,9 @@ openai_api_base = "http://localhost:8000/v1"
|
||||
def guided_choice_completion(client: OpenAI, model: str):
|
||||
completion = client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Classify this sentiment: vLLM is wonderful!"
|
||||
}],
|
||||
messages=[
|
||||
{"role": "user", "content": "Classify this sentiment: vLLM is wonderful!"}
|
||||
],
|
||||
extra_body={"guided_choice": ["positive", "negative"]},
|
||||
)
|
||||
return completion.choices[0].message.content
|
||||
@ -31,20 +30,21 @@ def guided_choice_completion(client: OpenAI, model: str):
|
||||
|
||||
# Guided decoding by Regex
|
||||
def guided_regex_completion(client: OpenAI, model: str):
|
||||
prompt = ("Generate an email address for Alan Turing, who works in Enigma."
|
||||
"End in .com and new line. Example result:"
|
||||
"alan.turing@enigma.com\n")
|
||||
prompt = (
|
||||
"Generate an email address for Alan Turing, who works in Enigma."
|
||||
"End in .com and new line. Example result:"
|
||||
"alan.turing@enigma.com\n"
|
||||
)
|
||||
|
||||
completion = client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
}],
|
||||
extra_body={
|
||||
"guided_regex": r"\w+@\w+\.com\n",
|
||||
"stop": ["\n"]
|
||||
},
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
}
|
||||
],
|
||||
extra_body={"guided_regex": r"\w+@\w+\.com\n", "stop": ["\n"]},
|
||||
)
|
||||
return completion.choices[0].message.content
|
||||
|
||||
@ -66,14 +66,18 @@ class CarDescription(BaseModel):
|
||||
def guided_json_completion(client: OpenAI, model: str):
|
||||
json_schema = CarDescription.model_json_schema()
|
||||
|
||||
prompt = ("Generate a JSON with the brand, model and car_type of"
|
||||
"the most iconic car from the 90's")
|
||||
prompt = (
|
||||
"Generate a JSON with the brand, model and car_type of"
|
||||
"the most iconic car from the 90's"
|
||||
)
|
||||
completion = client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
}],
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
}
|
||||
],
|
||||
extra_body={"guided_json": json_schema},
|
||||
)
|
||||
return completion.choices[0].message.content
|
||||
@ -95,14 +99,18 @@ def guided_grammar_completion(client: OpenAI, model: str):
|
||||
number ::= "1 " | "2 "
|
||||
"""
|
||||
|
||||
prompt = ("Generate an SQL query to show the 'username' and 'email'"
|
||||
"from the 'users' table.")
|
||||
prompt = (
|
||||
"Generate an SQL query to show the 'username' and 'email'"
|
||||
"from the 'users' table."
|
||||
)
|
||||
completion = client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
}],
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
}
|
||||
],
|
||||
extra_body={"guided_grammar": simplified_sql_grammar},
|
||||
)
|
||||
return completion.choices[0].message.content
|
||||
@ -110,19 +118,23 @@ def guided_grammar_completion(client: OpenAI, model: str):
|
||||
|
||||
# Extra backend options
|
||||
def extra_backend_options_completion(client: OpenAI, model: str):
|
||||
prompt = ("Generate an email address for Alan Turing, who works in Enigma."
|
||||
"End in .com and new line. Example result:"
|
||||
"alan.turing@enigma.com\n")
|
||||
prompt = (
|
||||
"Generate an email address for Alan Turing, who works in Enigma."
|
||||
"End in .com and new line. Example result:"
|
||||
"alan.turing@enigma.com\n"
|
||||
)
|
||||
|
||||
try:
|
||||
# The guided_decoding_disable_fallback option forces vLLM to use
|
||||
# xgrammar, so when it fails you get a 400 with the reason why
|
||||
completion = client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
}],
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
}
|
||||
],
|
||||
extra_body={
|
||||
"guided_regex": r"\w+@\w+\.com\n",
|
||||
"stop": ["\n"],
|
||||
|
@ -17,11 +17,10 @@ def main():
|
||||
api_key=openai_api_key,
|
||||
)
|
||||
|
||||
messages = [{
|
||||
"role":
|
||||
"user",
|
||||
"content":
|
||||
"""
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": """
|
||||
You have access to the following function to retrieve the weather in a city:
|
||||
|
||||
{
|
||||
@ -58,29 +57,28 @@ You are a helpful assistant.
|
||||
|
||||
Given the previous instructions, what is the weather in New York City, Boston,
|
||||
and San Francisco?
|
||||
"""
|
||||
}]
|
||||
""",
|
||||
}
|
||||
]
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model=client.models.list().data[0].id,
|
||||
messages=messages,
|
||||
response_format={
|
||||
"type":
|
||||
"structural_tag",
|
||||
"structures": [{
|
||||
"begin": "<function=get_weather>",
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"end": "</function>"
|
||||
}],
|
||||
"triggers": ["<function="]
|
||||
})
|
||||
"type": "structural_tag",
|
||||
"structures": [
|
||||
{
|
||||
"begin": "<function=get_weather>",
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {"city": {"type": "string"}},
|
||||
},
|
||||
"end": "</function>",
|
||||
}
|
||||
],
|
||||
"triggers": ["<function="],
|
||||
},
|
||||
)
|
||||
print(response)
|
||||
|
||||
|
||||
|
@ -27,21 +27,22 @@ openai_api_base = "http://localhost:8000/v1"
|
||||
|
||||
|
||||
def print_completion_details(completion):
|
||||
print("reasoning_content: ",
|
||||
completion.choices[0].message.reasoning_content)
|
||||
print("reasoning_content: ", completion.choices[0].message.reasoning_content)
|
||||
print("content: ", completion.choices[0].message.content)
|
||||
|
||||
|
||||
# Guided decoding by Regex
|
||||
def guided_regex_completion(client: OpenAI, model: str):
|
||||
prompt = ("What is the capital of France?")
|
||||
prompt = "What is the capital of France?"
|
||||
|
||||
completion = client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
}],
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
}
|
||||
],
|
||||
extra_body={
|
||||
"guided_regex": "(Paris|London)",
|
||||
},
|
||||
@ -57,13 +58,15 @@ class People(BaseModel):
|
||||
def guided_json_completion(client: OpenAI, model: str):
|
||||
json_schema = People.model_json_schema()
|
||||
|
||||
prompt = ("Generate a JSON with the name and age of one random person.")
|
||||
prompt = "Generate a JSON with the name and age of one random person."
|
||||
completion = client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
}],
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
}
|
||||
],
|
||||
extra_body={"guided_json": json_schema},
|
||||
)
|
||||
print_completion_details(completion)
|
||||
@ -86,14 +89,18 @@ class CarDescription(BaseModel):
|
||||
def guided_car_json_completion(client: OpenAI, model: str):
|
||||
json_schema = CarDescription.model_json_schema()
|
||||
|
||||
prompt = ("Generate a JSON with the brand, model and car_type of"
|
||||
"the most iconic car from the 90's")
|
||||
prompt = (
|
||||
"Generate a JSON with the brand, model and car_type of"
|
||||
"the most iconic car from the 90's"
|
||||
)
|
||||
completion = client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
}],
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
}
|
||||
],
|
||||
extra_body={"guided_json": json_schema},
|
||||
)
|
||||
print_completion_details(completion)
|
||||
@ -116,14 +123,18 @@ def guided_grammar_completion(client: OpenAI, model: str):
|
||||
"""
|
||||
|
||||
# This may be very slow https://github.com/vllm-project/vllm/issues/12122
|
||||
prompt = ("Generate an SQL query to show the 'username' and 'email'"
|
||||
"from the 'users' table.")
|
||||
prompt = (
|
||||
"Generate an SQL query to show the 'username' and 'email'"
|
||||
"from the 'users' table."
|
||||
)
|
||||
completion = client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
}],
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
}
|
||||
],
|
||||
extra_body={"guided_grammar": simplified_sql_grammar},
|
||||
)
|
||||
print_completion_details(completion)
|
||||
|
@ -20,9 +20,11 @@ from openai import OpenAI
|
||||
|
||||
|
||||
# Now, simulate a tool call
|
||||
def get_current_weather(city: str, state: str, unit: 'str'):
|
||||
return ("The weather in Dallas, Texas is 85 degrees fahrenheit. It is "
|
||||
"partly cloudly, with highs in the 90's.")
|
||||
def get_current_weather(city: str, state: str, unit: "str"):
|
||||
return (
|
||||
"The weather in Dallas, Texas is 85 degrees fahrenheit. It is "
|
||||
"partly cloudly, with highs in the 90's."
|
||||
)
|
||||
|
||||
|
||||
available_tools = {"get_current_weather": get_current_weather}
|
||||
@ -31,49 +33,47 @@ available_tools = {"get_current_weather": get_current_weather}
|
||||
openai_api_key = "EMPTY"
|
||||
openai_api_base = "http://localhost:8000/v1"
|
||||
|
||||
tools = [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type":
|
||||
"string",
|
||||
"description":
|
||||
"The city to find the weather for, e.g. 'San Francisco'"
|
||||
},
|
||||
"state": {
|
||||
"type":
|
||||
"string",
|
||||
"description":
|
||||
"the two-letter abbreviation for the state that the city is"
|
||||
" in, e.g. 'CA' which would mean 'California'"
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"description": "The unit to fetch the temperature in",
|
||||
"enum": ["celsius", "fahrenheit"]
|
||||
}
|
||||
properties = {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "The city to find the weather for, e.g. 'San Francisco'",
|
||||
},
|
||||
"state": {
|
||||
"type": "string",
|
||||
"description": "the two-letter abbreviation for the state that the city is"
|
||||
" in, e.g. 'CA' which would mean 'California'",
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"description": "The unit to fetch the temperature in",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
},
|
||||
}
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": ["city", "state", "unit"],
|
||||
},
|
||||
"required": ["city", "state", "unit"]
|
||||
}
|
||||
},
|
||||
}
|
||||
}]
|
||||
messages = [{
|
||||
"role": "user",
|
||||
"content": "Hi! How are you doing today?"
|
||||
}, {
|
||||
"role": "assistant",
|
||||
"content": "I'm doing well! How can I help you?"
|
||||
}, {
|
||||
"role":
|
||||
"user",
|
||||
"content":
|
||||
"Can you tell me what the temperate will be in Dallas, in fahrenheit?"
|
||||
}]
|
||||
]
|
||||
messages = [
|
||||
{"role": "user", "content": "Hi! How are you doing today?"},
|
||||
{"role": "assistant", "content": "I'm doing well! How can I help you?"},
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
"Can you tell me what the temperate will be in Dallas, in fahrenheit?"
|
||||
),
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def extract_reasoning_and_calls(chunks: list):
|
||||
@ -110,73 +110,55 @@ def main():
|
||||
models = client.models.list()
|
||||
model = models.data[0].id
|
||||
|
||||
print(
|
||||
"---------Full Generate With Automatic Function Calling-------------")
|
||||
tool_calls = client.chat.completions.create(messages=messages,
|
||||
model=model,
|
||||
tools=tools)
|
||||
print(
|
||||
f"reasoning_content: {tool_calls.choices[0].message.reasoning_content}"
|
||||
print("---------Full Generate With Automatic Function Calling-------------")
|
||||
tool_calls = client.chat.completions.create(
|
||||
messages=messages, model=model, tools=tools
|
||||
)
|
||||
print(f"function name: "
|
||||
f"{tool_calls.choices[0].message.tool_calls[0].function.name}")
|
||||
print(f"function arguments: "
|
||||
f"{tool_calls.choices[0].message.tool_calls[0].function.arguments}")
|
||||
|
||||
print(f"reasoning_content: {tool_calls.choices[0].message.reasoning_content}")
|
||||
print(f"function name: {tool_calls.choices[0].message.tool_calls[0].function.name}")
|
||||
print(
|
||||
"----------Stream Generate With Automatic Function Calling-----------")
|
||||
tool_calls_stream = client.chat.completions.create(messages=messages,
|
||||
model=model,
|
||||
tools=tools,
|
||||
stream=True)
|
||||
f"function arguments: "
|
||||
f"{tool_calls.choices[0].message.tool_calls[0].function.arguments}"
|
||||
)
|
||||
|
||||
print("----------Stream Generate With Automatic Function Calling-----------")
|
||||
tool_calls_stream = client.chat.completions.create(
|
||||
messages=messages, model=model, tools=tools, stream=True
|
||||
)
|
||||
|
||||
chunks = list(tool_calls_stream)
|
||||
|
||||
reasoning_content, arguments, function_names = extract_reasoning_and_calls(
|
||||
chunks)
|
||||
reasoning_content, arguments, function_names = extract_reasoning_and_calls(chunks)
|
||||
|
||||
print(f"reasoning_content: {reasoning_content}")
|
||||
print(f"function name: {function_names[0]}")
|
||||
print(f"function arguments: {arguments[0]}")
|
||||
|
||||
print(
|
||||
"----------Full Generate With Named Function Calling-----------------")
|
||||
tool_calls = client.chat.completions.create(messages=messages,
|
||||
model=model,
|
||||
tools=tools,
|
||||
tool_choice={
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name":
|
||||
"get_current_weather"
|
||||
}
|
||||
})
|
||||
print("----------Full Generate With Named Function Calling-----------------")
|
||||
tool_calls = client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=model,
|
||||
tools=tools,
|
||||
tool_choice={"type": "function", "function": {"name": "get_current_weather"}},
|
||||
)
|
||||
|
||||
tool_call = tool_calls.choices[0].message.tool_calls[0].function
|
||||
print(
|
||||
f"reasoning_content: {tool_calls.choices[0].message.reasoning_content}"
|
||||
)
|
||||
print(f"reasoning_content: {tool_calls.choices[0].message.reasoning_content}")
|
||||
print(f"function name: {tool_call.name}")
|
||||
print(f"function arguments: {tool_call.arguments}")
|
||||
print(
|
||||
"----------Stream Generate With Named Function Calling--------------")
|
||||
print("----------Stream Generate With Named Function Calling--------------")
|
||||
|
||||
tool_calls_stream = client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=model,
|
||||
tools=tools,
|
||||
tool_choice={
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather"
|
||||
}
|
||||
},
|
||||
stream=True)
|
||||
tool_choice={"type": "function", "function": {"name": "get_current_weather"}},
|
||||
stream=True,
|
||||
)
|
||||
|
||||
chunks = list(tool_calls_stream)
|
||||
|
||||
reasoning_content, arguments, function_names = extract_reasoning_and_calls(
|
||||
chunks)
|
||||
reasoning_content, arguments, function_names = extract_reasoning_and_calls(chunks)
|
||||
print(f"reasoning_content: {reasoning_content}")
|
||||
print(f"function name: {function_names[0]}")
|
||||
print(f"function arguments: {arguments[0]}")
|
||||
|
@ -45,12 +45,12 @@ def main():

# Round 2
messages.append({"role": "assistant", "content": content})
messages.append({
"role":
"user",
"content":
"How many Rs are there in the word 'strawberry'?",
})
messages.append(
{
"role": "user",
"content": "How many Rs are there in the word 'strawberry'?",
}
)
response = client.chat.completions.create(model=model, messages=messages)

reasoning_content = response.choices[0].message.reasoning_content

@ -43,9 +43,7 @@ def main():
|
||||
|
||||
# ruff: noqa: E501
|
||||
# For granite: add: `extra_body={"chat_template_kwargs": {"thinking": True}}`
|
||||
stream = client.chat.completions.create(model=model,
|
||||
messages=messages,
|
||||
stream=True)
|
||||
stream = client.chat.completions.create(model=model, messages=messages, stream=True)
|
||||
|
||||
print("client: Start streaming chat completions...")
|
||||
printed_reasoning_content = False
|
||||
|
@ -14,26 +14,17 @@ def vlm2vec():
|
||||
response = requests.post(
|
||||
"http://localhost:8000/v1/embeddings",
|
||||
json={
|
||||
"model":
|
||||
"TIGER-Lab/VLM2Vec-Full",
|
||||
"messages": [{
|
||||
"role":
|
||||
"user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Represent the given image."
|
||||
},
|
||||
],
|
||||
}],
|
||||
"encoding_format":
|
||||
"float",
|
||||
"model": "TIGER-Lab/VLM2Vec-Full",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image_url", "image_url": {"url": image_url}},
|
||||
{"type": "text", "text": "Represent the given image."},
|
||||
],
|
||||
}
|
||||
],
|
||||
"encoding_format": "float",
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
@ -45,19 +36,20 @@ def vlm2vec():
|
||||
def dse_qwen2_vl(inp: dict):
|
||||
# Embedding an Image
|
||||
if inp["type"] == "image":
|
||||
messages = [{
|
||||
"role":
|
||||
"user",
|
||||
"content": [{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": inp["image_url"],
|
||||
}
|
||||
}, {
|
||||
"type": "text",
|
||||
"text": "What is shown in this image?"
|
||||
}]
|
||||
}]
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": inp["image_url"],
|
||||
},
|
||||
},
|
||||
{"type": "text", "text": "What is shown in this image?"},
|
||||
],
|
||||
}
|
||||
]
|
||||
# Embedding a Text Query
|
||||
else:
|
||||
# MrLight/dse-qwen2-2b-mrl-v1 requires a placeholder image
|
||||
@ -66,23 +58,21 @@ def dse_qwen2_vl(inp: dict):
|
||||
image_placeholder = Image.new("RGB", (56, 56))
|
||||
image_placeholder.save(buffer, "png")
|
||||
buffer.seek(0)
|
||||
image_placeholder = base64.b64encode(buffer.read()).decode('utf-8')
|
||||
messages = [{
|
||||
"role":
|
||||
"user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{image_placeholder}",
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"Query: {inp['content']}"
|
||||
},
|
||||
]
|
||||
}]
|
||||
image_placeholder = base64.b64encode(buffer.read()).decode("utf-8")
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{image_placeholder}",
|
||||
},
|
||||
},
|
||||
{"type": "text", "text": f"Query: {inp['content']}"},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
response = requests.post(
|
||||
"http://localhost:8000/v1/embeddings",
|
||||
@ -101,12 +91,15 @@ def dse_qwen2_vl(inp: dict):
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
"Script to call a specified VLM through the API. Make sure to serve "
|
||||
"the model with --task embed before running this.")
|
||||
parser.add_argument("--model",
|
||||
type=str,
|
||||
choices=["vlm2vec", "dse_qwen2_vl"],
|
||||
required=True,
|
||||
help="Which model to call.")
|
||||
"the model with --task embed before running this."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
choices=["vlm2vec", "dse_qwen2_vl"],
|
||||
required=True,
|
||||
help="Which model to call.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@ -114,16 +107,20 @@ def main(args):
|
||||
if args.model == "vlm2vec":
|
||||
vlm2vec()
|
||||
elif args.model == "dse_qwen2_vl":
|
||||
dse_qwen2_vl({
|
||||
"type": "image",
|
||||
"image_url": image_url,
|
||||
})
|
||||
dse_qwen2_vl({
|
||||
"type": "text",
|
||||
"content": "What is the weather like today?",
|
||||
})
|
||||
dse_qwen2_vl(
|
||||
{
|
||||
"type": "image",
|
||||
"image_url": image_url,
|
||||
}
|
||||
)
|
||||
dse_qwen2_vl(
|
||||
{
|
||||
"type": "text",
|
||||
"content": "What is the weather like today?",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
main(args)
|
||||
|
@ -16,9 +16,7 @@ def parse_args():
|
||||
parse = argparse.ArgumentParser()
|
||||
parse.add_argument("--host", type=str, default="localhost")
|
||||
parse.add_argument("--port", type=int, default=8000)
|
||||
parse.add_argument("--model",
|
||||
type=str,
|
||||
default="jason9693/Qwen2.5-1.5B-apeach")
|
||||
parse.add_argument("--model", type=str, default="jason9693/Qwen2.5-1.5B-apeach")
|
||||
return parse.parse_args()
|
||||
|
||||
|
||||
|
@ -11,9 +11,9 @@ openai_api_base = "http://localhost:8000/v1"
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Client for vLLM API server")
|
||||
parser.add_argument("--stream",
|
||||
action="store_true",
|
||||
help="Enable streaming response")
|
||||
parser.add_argument(
|
||||
"--stream", action="store_true", help="Enable streaming response"
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@ -34,7 +34,8 @@ def main(args):
|
||||
echo=False,
|
||||
n=2,
|
||||
stream=args.stream,
|
||||
logprobs=3)
|
||||
logprobs=3,
|
||||
)
|
||||
|
||||
print("-" * 50)
|
||||
print("Completion results:")
|
||||
|
@ -4,6 +4,7 @@ Example online usage of Score API.

Run `vllm serve <model> --task score` to start up the server in vLLM.
"""

import argparse
import pprint

@ -38,9 +39,7 @@ def main(args):
|
||||
pprint.pprint(score_response.json())
|
||||
|
||||
text_1 = "What is the capital of France?"
|
||||
text_2 = [
|
||||
"The capital of Brazil is Brasilia.", "The capital of France is Paris."
|
||||
]
|
||||
text_2 = ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]
|
||||
prompt = {"model": model_name, "text_1": text_1, "text_2": text_2}
|
||||
score_response = post_http_request(prompt=prompt, api_url=api_url)
|
||||
print("\nPrompt when text_1 is string and text_2 is a list:")
|
||||
@ -48,12 +47,8 @@ def main(args):
|
||||
print("\nScore Response:")
|
||||
pprint.pprint(score_response.json())
|
||||
|
||||
text_1 = [
|
||||
"What is the capital of Brazil?", "What is the capital of France?"
|
||||
]
|
||||
text_2 = [
|
||||
"The capital of Brazil is Brasilia.", "The capital of France is Paris."
|
||||
]
|
||||
text_1 = ["What is the capital of Brazil?", "What is the capital of France?"]
|
||||
text_2 = ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]
|
||||
prompt = {"model": model_name, "text_1": text_1, "text_2": text_2}
|
||||
score_response = post_http_request(prompt=prompt, api_url=api_url)
|
||||
print("\nPrompt when text_1 and text_2 are both lists:")
|
||||
|
@ -21,7 +21,7 @@ def main():
|
||||
# ruff: noqa: E501
|
||||
input=[
|
||||
"Hello my name is",
|
||||
"The best thing about vLLM is that it supports many different models"
|
||||
"The best thing about vLLM is that it supports many different models",
|
||||
],
|
||||
model=model,
|
||||
)
|
||||
|
@ -5,6 +5,7 @@ Example online usage of Pooling API.
|
||||
Run `vllm serve <model> --task <embed|classify|reward|score>`
|
||||
to start up the server in vLLM.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import pprint
|
||||
|
||||
@ -21,9 +22,7 @@ def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--host", type=str, default="localhost")
|
||||
parser.add_argument("--port", type=int, default=8000)
|
||||
parser.add_argument("--model",
|
||||
type=str,
|
||||
default="jason9693/Qwen2.5-1.5B-apeach")
|
||||
parser.add_argument("--model", type=str, default="jason9693/Qwen2.5-1.5B-apeach")
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
@ -42,15 +41,13 @@ def main(args):
|
||||
|
||||
# Input like Chat API
|
||||
prompt = {
|
||||
"model":
|
||||
model_name,
|
||||
"messages": [{
|
||||
"role": "user",
|
||||
"content": [{
|
||||
"type": "text",
|
||||
"text": "vLLM is great!"
|
||||
}],
|
||||
}]
|
||||
"model": model_name,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": "vLLM is great!"}],
|
||||
}
|
||||
],
|
||||
}
|
||||
pooling_response = post_http_request(prompt=prompt, api_url=api_url)
|
||||
print("Pooling Response:")
|
||||
|
@ -7,8 +7,8 @@ from openai import OpenAI
|
||||
|
||||
from vllm.assets.audio import AudioAsset
|
||||
|
||||
mary_had_lamb = AudioAsset('mary_had_lamb').get_local_path()
|
||||
winning_call = AudioAsset('winning_call').get_local_path()
|
||||
mary_had_lamb = AudioAsset("mary_had_lamb").get_local_path()
|
||||
winning_call = AudioAsset("winning_call").get_local_path()
|
||||
|
||||
# Modify OpenAI's API key and API base to use vLLM's API server.
|
||||
openai_api_key = "EMPTY"
|
||||
@ -31,7 +31,8 @@ def sync_openai():
|
||||
extra_body=dict(
|
||||
seed=4419,
|
||||
repetition_penalty=1.3,
|
||||
))
|
||||
),
|
||||
)
|
||||
print("transcription result:", transcription.text)
|
||||
|
||||
|
||||
@ -42,33 +43,30 @@ sync_openai()
|
||||
async def stream_openai_response():
|
||||
data = {
|
||||
"language": "en",
|
||||
'stream': True,
|
||||
"stream": True,
|
||||
"model": "openai/whisper-large-v3",
|
||||
}
|
||||
url = openai_api_base + "/audio/transcriptions"
|
||||
headers = {"Authorization": f"Bearer {openai_api_key}"}
|
||||
print("transcription result:", end=' ')
|
||||
print("transcription result:", end=" ")
|
||||
async with httpx.AsyncClient() as client:
|
||||
with open(str(winning_call), "rb") as f:
|
||||
async with client.stream('POST',
|
||||
url,
|
||||
files={'file': f},
|
||||
data=data,
|
||||
headers=headers) as response:
|
||||
async with client.stream(
|
||||
"POST", url, files={"file": f}, data=data, headers=headers
|
||||
) as response:
|
||||
async for line in response.aiter_lines():
|
||||
# Each line is a JSON object prefixed with 'data: '
|
||||
if line:
|
||||
if line.startswith('data: '):
|
||||
line = line[len('data: '):]
|
||||
if line.startswith("data: "):
|
||||
line = line[len("data: ") :]
|
||||
# Last chunk, stream ends
|
||||
if line.strip() == '[DONE]':
|
||||
if line.strip() == "[DONE]":
|
||||
break
|
||||
# Parse the JSON response
|
||||
chunk = json.loads(line)
|
||||
# Extract and print the content
|
||||
content = chunk['choices'][0].get('delta',
|
||||
{}).get('content')
|
||||
print(content, end='')
|
||||
content = chunk["choices"][0].get("delta", {}).get("content")
|
||||
print(content, end="")
|
||||
|
||||
|
||||
# Run the asynchronous function
|
||||
|
@ -1,14 +1,11 @@
# SPDX-License-Identifier: Apache-2.0

import requests
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import (
OTLPSpanExporter)
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import (BatchSpanProcessor,
ConsoleSpanExporter)
from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter
from opentelemetry.trace import SpanKind, set_tracer_provider
from opentelemetry.trace.propagation.tracecontext import (
TraceContextTextMapPropagator)
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator

trace_provider = TracerProvider()
set_tracer_provider(trace_provider)

@ -26,6 +26,7 @@ Dependencies:
|
||||
- torch
|
||||
- openai
|
||||
"""
|
||||
|
||||
import base64
|
||||
import io
|
||||
|
||||
@ -44,17 +45,13 @@ def main():
|
||||
|
||||
# Transformers
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
|
||||
transformers_model = transformers.AutoModelForCausalLM.from_pretrained(
|
||||
model_name)
|
||||
transformers_model = transformers.AutoModelForCausalLM.from_pretrained(model_name)
|
||||
|
||||
# Refer to the HuggingFace repo for the correct format to use
|
||||
chat = [{
|
||||
"role": "user",
|
||||
"content": "Please tell me about the capital of France."
|
||||
}]
|
||||
token_ids = tokenizer.apply_chat_template(chat,
|
||||
add_generation_prompt=True,
|
||||
return_tensors='pt')
|
||||
chat = [{"role": "user", "content": "Please tell me about the capital of France."}]
|
||||
token_ids = tokenizer.apply_chat_template(
|
||||
chat, add_generation_prompt=True, return_tensors="pt"
|
||||
)
|
||||
|
||||
embedding_layer = transformers_model.get_input_embeddings()
|
||||
prompt_embeds = embedding_layer(token_ids).squeeze(0)
|
||||
@ -64,7 +61,7 @@ def main():
|
||||
torch.save(prompt_embeds, buffer)
|
||||
buffer.seek(0)
|
||||
binary_data = buffer.read()
|
||||
encoded_embeds = base64.b64encode(binary_data).decode('utf-8')
|
||||
encoded_embeds = base64.b64encode(binary_data).decode("utf-8")
|
||||
|
||||
completion = client.completions.create(
|
||||
model=model_name,
|
||||
@ -75,7 +72,8 @@ def main():
|
||||
temperature=0.0,
|
||||
# NOTE: The OpenAI client allows passing in extra JSON body via the
|
||||
# `extra_body` argument.
|
||||
extra_body={"prompt_embeds": encoded_embeds})
|
||||
extra_body={"prompt_embeds": encoded_embeds},
|
||||
)
|
||||
|
||||
print("-" * 30)
|
||||
print(completion.choices[0].text)
|
||||
|
@ -28,9 +28,7 @@ llm_config = LLMConfig(
|
||||
},
|
||||
# Change to the accelerator type of the node
|
||||
accelerator_type="H100",
|
||||
runtime_env={"env_vars": {
|
||||
"VLLM_USE_V1": "1"
|
||||
}},
|
||||
runtime_env={"env_vars": {"VLLM_USE_V1": "1"}},
|
||||
# Customize engine arguments as needed (e.g. vLLM engine kwargs)
|
||||
engine_kwargs={
|
||||
"tensor_parallel_size": 8,
|
||||
|
@ -55,7 +55,7 @@ def load_and_split_documents(config: dict[str, Any]):
|
||||
Load and split documents from web URL
|
||||
"""
|
||||
try:
|
||||
loader = WebBaseLoader(web_paths=(config["url"], ))
|
||||
loader = WebBaseLoader(web_paths=(config["url"],))
|
||||
docs = loader.load()
|
||||
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
@ -121,64 +121,71 @@ def create_qa_chain(retriever: Any, llm: ChatOpenAI, prompt: PromptTemplate):
|
||||
"""
|
||||
Set up question answering chain
|
||||
"""
|
||||
return ({
|
||||
"context": retriever | format_docs,
|
||||
"question": RunnablePassthrough(),
|
||||
}
|
||||
| prompt
|
||||
| llm
|
||||
| StrOutputParser())
|
||||
return (
|
||||
{
|
||||
"context": retriever | format_docs,
|
||||
"question": RunnablePassthrough(),
|
||||
}
|
||||
| prompt
|
||||
| llm
|
||||
| StrOutputParser()
|
||||
)
|
||||
|
||||
|
||||
def get_parser() -> argparse.ArgumentParser:
|
||||
"""
|
||||
Parse command line arguments
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description='RAG with vLLM and langchain')
|
||||
parser = argparse.ArgumentParser(description="RAG with vLLM and langchain")
|
||||
|
||||
# Add command line arguments
|
||||
parser.add_argument('--vllm-api-key',
|
||||
default="EMPTY",
|
||||
help='API key for vLLM compatible services')
|
||||
parser.add_argument('--vllm-embedding-endpoint',
|
||||
default="http://localhost:8000/v1",
|
||||
help='Base URL for embedding service')
|
||||
parser.add_argument('--vllm-chat-endpoint',
|
||||
default="http://localhost:8001/v1",
|
||||
help='Base URL for chat service')
|
||||
parser.add_argument('--uri',
|
||||
default="./milvus.db",
|
||||
help='URI for Milvus database')
|
||||
parser.add_argument(
|
||||
'--url',
|
||||
default=("https://docs.vllm.ai/en/latest/getting_started/"
|
||||
"quickstart.html"),
|
||||
help='URL of the document to process')
|
||||
parser.add_argument('--embedding-model',
|
||||
default="ssmits/Qwen2-7B-Instruct-embed-base",
|
||||
help='Model name for embeddings')
|
||||
parser.add_argument('--chat-model',
|
||||
default="qwen/Qwen1.5-0.5B-Chat",
|
||||
help='Model name for chat')
|
||||
parser.add_argument('-i',
|
||||
'--interactive',
|
||||
action='store_true',
|
||||
help='Enable interactive Q&A mode')
|
||||
parser.add_argument('-k',
|
||||
'--top-k',
|
||||
type=int,
|
||||
default=3,
|
||||
help='Number of top results to retrieve')
|
||||
parser.add_argument('-c',
|
||||
'--chunk-size',
|
||||
type=int,
|
||||
default=1000,
|
||||
help='Chunk size for document splitting')
|
||||
parser.add_argument('-o',
|
||||
'--chunk-overlap',
|
||||
type=int,
|
||||
default=200,
|
||||
help='Chunk overlap for document splitting')
|
||||
"--vllm-api-key", default="EMPTY", help="API key for vLLM compatible services"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vllm-embedding-endpoint",
|
||||
default="http://localhost:8000/v1",
|
||||
help="Base URL for embedding service",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vllm-chat-endpoint",
|
||||
default="http://localhost:8001/v1",
|
||||
help="Base URL for chat service",
|
||||
)
|
||||
parser.add_argument("--uri", default="./milvus.db", help="URI for Milvus database")
|
||||
parser.add_argument(
|
||||
"--url",
|
||||
default=("https://docs.vllm.ai/en/latest/getting_started/quickstart.html"),
|
||||
help="URL of the document to process",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--embedding-model",
|
||||
default="ssmits/Qwen2-7B-Instruct-embed-base",
|
||||
help="Model name for embeddings",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--chat-model", default="qwen/Qwen1.5-0.5B-Chat", help="Model name for chat"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-i", "--interactive", action="store_true", help="Enable interactive Q&A mode"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-k", "--top-k", type=int, default=3, help="Number of top results to retrieve"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-c",
|
||||
"--chunk-size",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Chunk size for document splitting",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
"--chunk-overlap",
|
||||
type=int,
|
||||
default=200,
|
||||
help="Chunk overlap for document splitting",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
@ -198,7 +205,7 @@ def init_config(args: Namespace):
|
||||
"url": args.url,
|
||||
"chunk_size": args.chunk_size,
|
||||
"chunk_overlap": args.chunk_overlap,
|
||||
"top_k": args.top_k
|
||||
"top_k": args.top_k,
|
||||
}
|
||||
|
||||
|
||||
@ -230,7 +237,7 @@ def main():
|
||||
|
||||
while True:
|
||||
question = input("\nPlease enter your question: ")
|
||||
if question.lower() in ['q', 'quit']:
|
||||
if question.lower() in ["q", "quit"]:
|
||||
print("\nThank you for using! Goodbye!")
|
||||
break
|
||||
|
||||
@ -238,7 +245,7 @@ def main():
|
||||
print(output)
|
||||
else:
|
||||
# Default single question mode
|
||||
question = ("How to install vLLM?")
|
||||
question = "How to install vLLM?"
|
||||
output = qa_chain.invoke(question)
|
||||
print("-" * 50)
|
||||
print(output)
|
||||
|
@ -35,6 +35,7 @@ Notes:
|
||||
- Default ports: 8000 (embedding), 8001 (chat)
|
||||
- First run may take time to download models
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from argparse import Namespace
|
||||
from typing import Any
|
||||
@ -59,7 +60,7 @@ def init_config(args: Namespace):
|
||||
"db_path": args.db_path,
|
||||
"chunk_size": args.chunk_size,
|
||||
"chunk_overlap": args.chunk_overlap,
|
||||
"top_k": args.top_k
|
||||
"top_k": args.top_k,
|
||||
}
|
||||
|
||||
|
||||
@ -117,52 +118,58 @@ def query_document(index: VectorStoreIndex, question: str, top_k: int):
|
||||
|
||||
def get_parser() -> argparse.ArgumentParser:
|
||||
"""Parse command line arguments"""
|
||||
parser = argparse.ArgumentParser(
|
||||
description='RAG with vLLM and LlamaIndex')
|
||||
parser = argparse.ArgumentParser(description="RAG with vLLM and LlamaIndex")
|
||||
|
||||
# Add command line arguments
|
||||
parser.add_argument(
|
||||
'--url',
|
||||
default=("https://docs.vllm.ai/en/latest/getting_started/"
|
||||
"quickstart.html"),
|
||||
help='URL of the document to process')
|
||||
parser.add_argument('--embedding-model',
|
||||
default="ssmits/Qwen2-7B-Instruct-embed-base",
|
||||
help='Model name for embeddings')
|
||||
parser.add_argument('--chat-model',
|
||||
default="qwen/Qwen1.5-0.5B-Chat",
|
||||
help='Model name for chat')
|
||||
parser.add_argument('--vllm-api-key',
|
||||
default="EMPTY",
|
||||
help='API key for vLLM compatible services')
|
||||
parser.add_argument('--embedding-endpoint',
|
||||
default="http://localhost:8000/v1",
|
||||
help='Base URL for embedding service')
|
||||
parser.add_argument('--chat-endpoint',
|
||||
default="http://localhost:8001/v1",
|
||||
help='Base URL for chat service')
|
||||
parser.add_argument('--db-path',
|
||||
default="./milvus_demo.db",
|
||||
help='Path to Milvus database')
|
||||
parser.add_argument('-i',
|
||||
'--interactive',
|
||||
action='store_true',
|
||||
help='Enable interactive Q&A mode')
|
||||
parser.add_argument('-c',
|
||||
'--chunk-size',
|
||||
type=int,
|
||||
default=1000,
|
||||
help='Chunk size for document splitting')
|
||||
parser.add_argument('-o',
|
||||
'--chunk-overlap',
|
||||
type=int,
|
||||
default=200,
|
||||
help='Chunk overlap for document splitting')
|
||||
parser.add_argument('-k',
|
||||
'--top-k',
|
||||
type=int,
|
||||
default=3,
|
||||
help='Number of top results to retrieve')
|
||||
"--url",
|
||||
default=("https://docs.vllm.ai/en/latest/getting_started/quickstart.html"),
|
||||
help="URL of the document to process",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--embedding-model",
|
||||
default="ssmits/Qwen2-7B-Instruct-embed-base",
|
||||
help="Model name for embeddings",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--chat-model", default="qwen/Qwen1.5-0.5B-Chat", help="Model name for chat"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vllm-api-key", default="EMPTY", help="API key for vLLM compatible services"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--embedding-endpoint",
|
||||
default="http://localhost:8000/v1",
|
||||
help="Base URL for embedding service",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--chat-endpoint",
|
||||
default="http://localhost:8001/v1",
|
||||
help="Base URL for chat service",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--db-path", default="./milvus_demo.db", help="Path to Milvus database"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-i", "--interactive", action="store_true", help="Enable interactive Q&A mode"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-c",
|
||||
"--chunk-size",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Chunk size for document splitting",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
"--chunk-overlap",
|
||||
type=int,
|
||||
default=200,
|
||||
help="Chunk overlap for document splitting",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-k", "--top-k", type=int, default=3, help="Number of top results to retrieve"
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
@ -193,7 +200,7 @@ def main():
|
||||
question = input("\nEnter your question: ")
|
||||
|
||||
# Check for exit command
|
||||
if question.lower() in ['quit', 'exit', 'q']:
|
||||
if question.lower() in ["quit", "exit", "q"]:
|
||||
print("Exiting interactive mode...")
|
||||
break
|
||||
|
||||
|
@ -26,6 +26,7 @@ Usage:
streamlit run streamlit_openai_chatbot_webserver.py \
--logger.level=debug
"""

import os
from datetime import datetime

@ -33,8 +34,8 @@ import streamlit as st
from openai import OpenAI

# Get command line arguments from environment variables
openai_api_key = os.getenv('VLLM_API_KEY', "EMPTY")
openai_api_base = os.getenv('VLLM_API_BASE', "http://localhost:8000/v1")
openai_api_key = os.getenv("VLLM_API_KEY", "EMPTY")
openai_api_base = os.getenv("VLLM_API_BASE", "http://localhost:8000/v1")

# Initialize session states for managing chat sessions
if "sessions" not in st.session_state:
@ -81,9 +82,9 @@ def get_llm_response(messages, model):
|
||||
Streaming response object or error message string
|
||||
"""
|
||||
try:
|
||||
response = client.chat.completions.create(model=model,
|
||||
messages=messages,
|
||||
stream=True)
|
||||
response = client.chat.completions.create(
|
||||
model=model, messages=messages, stream=True
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
st.error(f"Error details: {str(e)}")
|
||||
@ -92,8 +93,9 @@ def get_llm_response(messages, model):
|
||||
|
||||
# Sidebar - API Settings first
|
||||
st.sidebar.title("API Settings")
|
||||
new_api_base = st.sidebar.text_input("API Base URL:",
|
||||
value=st.session_state.api_base_url)
|
||||
new_api_base = st.sidebar.text_input(
|
||||
"API Base URL:", value=st.session_state.api_base_url
|
||||
)
|
||||
if new_api_base != st.session_state.api_base_url:
|
||||
st.session_state.api_base_url = new_api_base
|
||||
st.rerun()
|
||||
@ -109,16 +111,20 @@ if st.sidebar.button("New Session"):
|
||||
for session_id in sorted(st.session_state.sessions.keys(), reverse=True):
|
||||
# Mark the active session with a pinned button
|
||||
if session_id == st.session_state.active_session:
|
||||
st.sidebar.button(f"📍 {session_id}",
|
||||
key=session_id,
|
||||
type="primary",
|
||||
on_click=switch_to_chat_session,
|
||||
args=(session_id, ))
|
||||
st.sidebar.button(
|
||||
f"📍 {session_id}",
|
||||
key=session_id,
|
||||
type="primary",
|
||||
on_click=switch_to_chat_session,
|
||||
args=(session_id,),
|
||||
)
|
||||
else:
|
||||
st.sidebar.button(f"Session {session_id}",
|
||||
key=session_id,
|
||||
on_click=switch_to_chat_session,
|
||||
args=(session_id, ))
|
||||
st.sidebar.button(
|
||||
f"Session {session_id}",
|
||||
key=session_id,
|
||||
on_click=switch_to_chat_session,
|
||||
args=(session_id,),
|
||||
)
|
||||
|
||||
# Main interface
|
||||
st.title("vLLM Chat Assistant")
|
||||
@ -145,18 +151,18 @@ for message in st.session_state.messages:
|
||||
if prompt := st.chat_input("Type your message here..."):
|
||||
# Save user message to session
|
||||
st.session_state.messages.append({"role": "user", "content": prompt})
|
||||
st.session_state.sessions[
|
||||
st.session_state.current_session] = st.session_state.messages
|
||||
st.session_state.sessions[st.session_state.current_session] = (
|
||||
st.session_state.messages
|
||||
)
|
||||
|
||||
# Display user message
|
||||
with st.chat_message("user"):
|
||||
st.write(prompt)
|
||||
|
||||
# Prepare messages for llm
|
||||
messages_for_llm = [{
|
||||
"role": m["role"],
|
||||
"content": m["content"]
|
||||
} for m in st.session_state.messages]
|
||||
messages_for_llm = [
|
||||
{"role": m["role"], "content": m["content"]} for m in st.session_state.messages
|
||||
]
|
||||
|
||||
# Generate and display llm response
|
||||
with st.chat_message("assistant"):
|
||||
@ -179,7 +185,4 @@ if prompt := st.chat_input("Type your message here..."):
|
||||
message_placeholder.markdown(full_response)
|
||||
|
||||
# Save llm response to session history
|
||||
st.session_state.messages.append({
|
||||
"role": "assistant",
|
||||
"content": full_response
|
||||
})
|
||||
st.session_state.messages.append({"role": "assistant", "content": full_response})
|
||||
|
@ -16,10 +16,10 @@ def get_first_model(client: OpenAI) -> str:
|
||||
f"{client.base_url} with API key {client.api_key}. Check\n"
|
||||
"1. the server is running\n"
|
||||
"2. the server URL is correct\n"
|
||||
"3. the API key is correct") from e
|
||||
"3. the API key is correct"
|
||||
) from e
|
||||
|
||||
if len(models.data) == 0:
|
||||
raise RuntimeError(
|
||||
f"No models found on the vLLM server at {client.base_url}")
|
||||
raise RuntimeError(f"No models found on the vLLM server at {client.base_url}")
|
||||
|
||||
return models.data[0].id
|
||||
|
@ -20,6 +20,7 @@ Requirements: Linux, Python: 3.10 or higher, CUDA: 12.1
Learn more about LMCache environment setup, please refer to:
https://docs.lmcache.ai/getting_started/installation.html
"""

import argparse
import contextlib
import os
@ -49,8 +50,7 @@ def setup_environment_variables(vllm_version: str):
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def build_llm_with_lmcache(lmcache_connector: str, model: str,
|
||||
vllm_version: str):
|
||||
def build_llm_with_lmcache(lmcache_connector: str, model: str, vllm_version: str):
|
||||
ktc = KVTransferConfig(
|
||||
kv_connector=lmcache_connector,
|
||||
kv_role="kv_both",
|
||||
@ -97,18 +97,19 @@ def print_output(
|
||||
for output in outputs:
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Generated text: {generated_text!r}")
|
||||
print(f"Generation took {time.time() - start:.2f} seconds, "
|
||||
f"{req_str} request done.")
|
||||
print(f"Generation took {time.time() - start:.2f} seconds, {req_str} request done.")
|
||||
print("-" * 50)
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-v",
|
||||
"--version",
|
||||
choices=["v0", "v1"],
|
||||
default="v1",
|
||||
help="Specify vLLM version (default: v1)")
|
||||
parser.add_argument(
|
||||
"-v",
|
||||
"--version",
|
||||
choices=["v0", "v1"],
|
||||
default="v1",
|
||||
help="Specify vLLM version (default: v1)",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@ -125,7 +126,6 @@ def main():
|
||||
setup_environment_variables(args.version)
|
||||
|
||||
with build_llm_with_lmcache(lmcache_connector, model, args.version) as llm:
|
||||
|
||||
# This example script runs two requests with a shared prefix.
|
||||
# Define the shared prompt and specific prompts
|
||||
shared_prompt = "Hello, how are you?" * 1000
|
||||
@ -136,9 +136,7 @@ def main():
|
||||
shared_prompt + "Tell me a very long story",
|
||||
]
|
||||
|
||||
sampling_params = SamplingParams(temperature=0,
|
||||
top_p=0.95,
|
||||
max_tokens=10)
|
||||
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)
|
||||
|
||||
# Print the first output
|
||||
print_output(llm, first_prompt, sampling_params, "first")
|
||||
|
@ -4,12 +4,13 @@ This file demonstrates the example usage of disaggregated prefilling
|
||||
with LMCache.
|
||||
We will launch 2 vllm instances (GPU 0 for prefill and GPU 1 for decode),
|
||||
and launch an additional LMCache server.
|
||||
KV cache is transferred in the following manner:
|
||||
KV cache is transferred in the following manner:
|
||||
vLLM prefill node -> LMCache server -> vLLM decode node.
|
||||
|
||||
Note that `pip install lmcache` is needed to run this example.
|
||||
Learn more about LMCache in https://github.com/LMCache/LMCache.
|
||||
"""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import time
@ -49,19 +50,23 @@ def run_prefill(prefill_done, prompts):

sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)

ktc = KVTransferConfig(kv_connector="LMCacheConnector",
kv_role="kv_producer",
kv_rank=0,
kv_parallel_size=2)
ktc = KVTransferConfig(
kv_connector="LMCacheConnector",
kv_role="kv_producer",
kv_rank=0,
kv_parallel_size=2,
)
# Set GPU memory utilization to 0.8 for an A40 GPU with 40GB
# memory. Reduce the value if your GPU has less memory.
llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2",
kv_transfer_config=ktc,
max_model_len=8000,
gpu_memory_utilization=0.8,
enforce_eager=True)
llm = LLM(
model="mistralai/Mistral-7B-Instruct-v0.2",
kv_transfer_config=ktc,
max_model_len=8000,
gpu_memory_utilization=0.8,
enforce_eager=True,
)

#llm.generate(prompts, sampling_params)
# llm.generate(prompts, sampling_params)
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
generated_text = output.outputs[0].text
@ -79,17 +84,21 @@ def run_decode(prefill_done, prompts, timeout=1):

sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)

ktc = KVTransferConfig(kv_connector="LMCacheConnector",
kv_role="kv_consumer",
kv_rank=1,
kv_parallel_size=2)
ktc = KVTransferConfig(
kv_connector="LMCacheConnector",
kv_role="kv_consumer",
kv_rank=1,
kv_parallel_size=2,
)
# Set GPU memory utilization to 0.8 for an A40 GPU with 40GB
# of memory. Reduce the value if your GPU has less memory.
llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2",
kv_transfer_config=ktc,
max_model_len=8000,
gpu_memory_utilization=0.8,
enforce_eager=True)
llm = LLM(
model="mistralai/Mistral-7B-Instruct-v0.2",
kv_transfer_config=ktc,
max_model_len=8000,
gpu_memory_utilization=0.8,
enforce_eager=True,
)

print("Waiting for prefill node to finish...")
prefill_done.wait()
@ -105,10 +114,9 @@ def run_decode(prefill_done, prompts, timeout=1):


def run_lmcache_server(port):
server_proc = subprocess.Popen([
"python", "-m", "lmcache.experimental.server", "localhost",
str(port)
])
server_proc = subprocess.Popen(
["python", "-m", "lmcache.experimental.server", "localhost", str(port)]
)
return server_proc
@ -17,13 +17,17 @@ async def lifespan(app: FastAPI):
Lifespan context manager to handle startup and shutdown events.
"""
# Startup: Initialize clients
prefiller_base_url = f'http://{global_args.prefiller_host}:{global_args.prefiller_port}/v1'
decoder_base_url = f'http://{global_args.decoder_host}:{global_args.decoder_port}/v1'
prefiller_base_url = (
f"http://{global_args.prefiller_host}:{global_args.prefiller_port}/v1"
)
decoder_base_url = (
f"http://{global_args.decoder_host}:{global_args.decoder_port}/v1"
)

app.state.prefill_client = httpx.AsyncClient(timeout=None,
base_url=prefiller_base_url)
app.state.decode_client = httpx.AsyncClient(timeout=None,
base_url=decoder_base_url)
app.state.prefill_client = httpx.AsyncClient(
timeout=None, base_url=prefiller_base_url
)
app.state.decode_client = httpx.AsyncClient(timeout=None, base_url=decoder_base_url)

yield

@ -37,7 +41,6 @@ app = FastAPI(lifespan=lifespan)


class StatsCalculator:

def __init__(self):
self._stats = []
self._last_log_time = time.time()
@ -51,13 +54,18 @@ class StatsCalculator:
def _log_stats(self):
# Print average, median, and 99th percentile
np_arr = np.array(self._stats)
output_str = f"\nNum requests: {len(self._stats)}" + \
"\nPrefill node TTFT stats:" + \
f"\n - Average (ms): {np.mean(np_arr)}" + \
f"\n - Median (ms): {np.median(np_arr)}" + \
f"\n - 99th Percentile (ms): {np.percentile(np_arr, 99)}\n"
print("===============================", output_str,
"===============================")
output_str = (
f"\nNum requests: {len(self._stats)}"
+ "\nPrefill node TTFT stats:"
+ f"\n - Average (ms): {np.mean(np_arr)}"
+ f"\n - Median (ms): {np.median(np_arr)}"
+ f"\n - 99th Percentile (ms): {np.percentile(np_arr, 99)}\n"
)
print(
"===============================",
output_str,
"===============================",
)


stats_calculator = StatsCalculator()
@ -82,15 +90,16 @@ app.state.prefill_client = None
app.state.decode_client = None


async def send_request_to_service(client: httpx.AsyncClient, endpoint: str,
req_data: dict):
async def send_request_to_service(
client: httpx.AsyncClient, endpoint: str, req_data: dict
):
"""
Send a request to a service using a persistent client.
"""
req_data = req_data.copy()
req_data['max_tokens'] = 1
if 'max_completion_tokens' in req_data:
req_data['max_completion_tokens'] = 1
req_data["max_tokens"] = 1
if "max_completion_tokens" in req_data:
req_data["max_completion_tokens"] = 1

headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
response = await client.post(endpoint, json=req_data, headers=headers)
@ -98,14 +107,16 @@ async def send_request_to_service(client: httpx.AsyncClient, endpoint: str,
return response


async def stream_service_response(client: httpx.AsyncClient, endpoint: str,
req_data: dict):
async def stream_service_response(
client: httpx.AsyncClient, endpoint: str, req_data: dict
):
"""
Asynchronously stream the response from a service using a persistent client.
"""
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
async with client.stream("POST", endpoint, json=req_data,
headers=headers) as response:
async with client.stream(
"POST", endpoint, json=req_data, headers=headers
) as response:
response.raise_for_status()
async for chunk in response.aiter_bytes():
yield chunk
@ -121,28 +132,28 @@ async def handle_completions(request: Request):
req_data = await request.json()

# Send request to prefill service, ignore the response
await send_request_to_service(app.state.prefill_client, "/completions",
req_data)
await send_request_to_service(
app.state.prefill_client, "/completions", req_data
)

et = time.time()
stats_calculator.add(et - st)

# Stream response from decode service
async def generate_stream():
async for chunk in stream_service_response(app.state.decode_client,
"/completions",
req_data):
async for chunk in stream_service_response(
app.state.decode_client, "/completions", req_data
):
yield chunk

return StreamingResponse(generate_stream(),
media_type="text/event-stream")
return StreamingResponse(generate_stream(), media_type="text/event-stream")

except Exception as e:
import sys
import traceback

exc_info = sys.exc_info()
print("Error occurred in disagg prefill proxy server"
" - completions endpoint")
print("Error occurred in disagg prefill proxy server - completions endpoint")
print(e)
print("".join(traceback.format_exception(*exc_info)))
raise
@ -158,36 +169,39 @@ async def handle_chat_completions(request: Request):
req_data = await request.json()

# Send request to prefill service, ignore the response
await send_request_to_service(app.state.prefill_client,
"/chat/completions", req_data)
await send_request_to_service(
app.state.prefill_client, "/chat/completions", req_data
)

et = time.time()
stats_calculator.add(et - st)

# Stream response from decode service
async def generate_stream():
async for chunk in stream_service_response(app.state.decode_client,
"/chat/completions",
req_data):
async for chunk in stream_service_response(
app.state.decode_client, "/chat/completions", req_data
):
yield chunk

return StreamingResponse(generate_stream(),
media_type="text/event-stream")
return StreamingResponse(generate_stream(), media_type="text/event-stream")

except Exception as e:
import sys
import traceback

exc_info = sys.exc_info()
print("Error occurred in disagg prefill proxy server "
" - chat completions endpoint")
print(
"Error occurred in disagg prefill proxy server - chat completions endpoint"
)
print(e)
print("".join(traceback.format_exception(*exc_info)))
raise


if __name__ == '__main__':
if __name__ == "__main__":
global global_args
global_args = parse_args()

import uvicorn

uvicorn.run(app, host=global_args.host, port=global_args.port)
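For context, here is a hedged sketch of how a client might exercise this proxy once both vLLM servers and the script above are running. The host, port, and exact route prefix are assumptions (the argument parser and route decorators sit outside the hunks shown), and the request body is a standard OpenAI-style completions payload:

```python
# Hypothetical client call against the proxy, assumed to listen on localhost:8000
# and to expose the OpenAI-style /v1/completions route it forwards internally.
import asyncio

import httpx


async def demo():
    payload = {
        "model": "mistralai/Mistral-7B-Instruct-v0.2",
        "prompt": "Explain disaggregated prefill in one sentence.",
        "max_tokens": 64,
        "stream": True,
    }
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
        async with client.stream("POST", "/v1/completions", json=payload) as resp:
            resp.raise_for_status()
            # The proxy relays the decode node's SSE stream chunk by chunk.
            async for chunk in resp.aiter_bytes():
                print(chunk.decode(), end="")


if __name__ == "__main__":
    asyncio.run(demo())
```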
@ -3,13 +3,14 @@
This file demonstrates the example usage of remote KV cache sharing
with LMCache.
We will launch 2 vllm instances, and launch an additional LMCache server.
KV cache is transferred in the following manner:
KV cache is transferred in the following manner:
(1) vLLM instance 1 -> LMCache server (KV cache store).
(2) LMCache server -> vLLM instance 2 (KV cache reuse/retrieve).

Note that lmcache needs to be installed to run this example.
Learn more about LMCache in https://github.com/LMCache/LMCache.
"""

import os
import subprocess
import time
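As with the disaggregated-prefill example above, the `main()` orchestration sits outside these hunks. A minimal sketch under the same caveats (illustrative only; it assumes the helpers `run_store`, `run_retrieve`, and `run_lmcache_server` shown below, and that `run_store` sets the shared event when it finishes):

```python
# Illustrative only; the real example's main() may differ.
from multiprocessing import Event, Process


def main():
    prompts = ["San Francisco is a"]  # hypothetical shared prompt

    store_done = Event()
    lmcache_server = run_lmcache_server(8100)  # port number is an assumption

    # Instance 1 stores its KV cache in LMCache; instance 2 retrieves and reuses it.
    store = Process(target=run_store, args=(store_done, prompts))
    retrieve = Process(target=run_retrieve, args=(store_done, prompts))

    store.start()
    retrieve.start()
    store.join()
    retrieve.join()
    lmcache_server.terminate()
```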
@ -49,15 +50,16 @@ def run_store(store_done, prompts):

sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)

ktc = KVTransferConfig(kv_connector="LMCacheConnectorV1",
kv_role="kv_both")
ktc = KVTransferConfig(kv_connector="LMCacheConnectorV1", kv_role="kv_both")
# Set GPU memory utilization to 0.8 for an A40 GPU with 40GB
# memory. Reduce the value if your GPU has less memory.
llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2",
kv_transfer_config=ktc,
max_model_len=8000,
gpu_memory_utilization=0.8,
enforce_eager=True)
llm = LLM(
model="mistralai/Mistral-7B-Instruct-v0.2",
kv_transfer_config=ktc,
max_model_len=8000,
gpu_memory_utilization=0.8,
enforce_eager=True,
)

outputs = llm.generate(prompts, sampling_params)
for output in outputs:
@ -76,15 +78,16 @@ def run_retrieve(store_done, prompts, timeout=1):

sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)

ktc = KVTransferConfig(kv_connector="LMCacheConnectorV1",
kv_role="kv_both")
ktc = KVTransferConfig(kv_connector="LMCacheConnectorV1", kv_role="kv_both")
# Set GPU memory utilization to 0.8 for an A40 GPU with 40GB
# of memory. Reduce the value if your GPU has less memory.
llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2",
kv_transfer_config=ktc,
max_model_len=8000,
gpu_memory_utilization=0.8,
enforce_eager=True)
llm = LLM(
model="mistralai/Mistral-7B-Instruct-v0.2",
kv_transfer_config=ktc,
max_model_len=8000,
gpu_memory_utilization=0.8,
enforce_eager=True,
)

print("Waiting for KV cache store to finish...")
store_done.wait()
@ -100,10 +103,9 @@ def run_retrieve(store_done, prompts, timeout=1):


def run_lmcache_server(port):
server_proc = subprocess.Popen([
"python", "-m", "lmcache.experimental.server", "localhost",
str(port)
])
server_proc = subprocess.Popen(
["python", "-m", "lmcache.experimental.server", "localhost", str(port)]
)
return server_proc
@ -10,8 +10,11 @@ from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.lora.request import LoRARequest
from vllm.model_executor.model_loader.tensorizer import (
TensorizerArgs, TensorizerConfig, tensorize_lora_adapter,
tensorize_vllm_model)
TensorizerArgs,
TensorizerConfig,
tensorize_lora_adapter,
tensorize_vllm_model,
)
from vllm.utils import FlexibleArgumentParser

# yapf conflicts with isort for this docstring
examples/pyproject.toml (new file, 54 lines)
@ -0,0 +1,54 @@
# This local pyproject file is part of the migration from yapf to ruff format.
# It uses the same core rules as the main pyproject.toml file, but with the
# following differences:
# - ruff line length is overridden to 88
# - deprecated typing ignores (UP006, UP035) have been removed

[tool.ruff]
line-length = 88
exclude = [
# External file, leaving license intact
"examples/other/fp8/quantizer/quantize.py",
"vllm/vllm_flash_attn/flash_attn_interface.pyi"
]

[tool.ruff.lint.per-file-ignores]
"vllm/third_party/**" = ["ALL"]
"vllm/version.py" = ["F401"]
"vllm/_version.py" = ["ALL"]

[tool.ruff.lint]
select = [
# pycodestyle
"E",
# Pyflakes
"F",
# pyupgrade
"UP",
# flake8-bugbear
"B",
# flake8-simplify
"SIM",
# isort
"I",
# flake8-logging-format
"G",
]
ignore = [
# star imports
"F405", "F403",
# lambda expression assignment
"E731",
# Loop control variable not used within loop body
"B007",
# f-string format
"UP032",
# Can remove once 3.10+ is the minimum Python version
"UP007",
]

[tool.ruff.lint.isort]
known-first-party = ["vllm"]

[tool.ruff.format]
docstring-code-format = true
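To make the effect of this configuration concrete, here is a small, self-contained illustration (the helper is hypothetical, not vLLM code) of the rewrite ruff format applies at the 88-column line length, which is the pattern seen throughout the hunks in this commit:

```python
# Before (yapf style): continuation arguments aligned under the opening parenthesis.
#
# config = build_config(model="mistralai/Mistral-7B-Instruct-v0.2",
#                       max_model_len=8000,
#                       gpu_memory_utilization=0.8,
#                       enforce_eager=True)


def build_config(
    model: str,
    max_model_len: int,
    gpu_memory_utilization: float,
    enforce_eager: bool,
) -> dict:
    """Hypothetical helper, used only to keep the example self-contained."""
    return {
        "model": model,
        "max_model_len": max_model_len,
        "gpu_memory_utilization": gpu_memory_utilization,
        "enforce_eager": enforce_eager,
    }


# After ruff format: a call that exceeds 88 columns is exploded to one argument
# per line with a trailing comma and 4-space indentation.
config = build_config(
    model="mistralai/Mistral-7B-Instruct-v0.2",
    max_model_len=8000,
    gpu_memory_utilization=0.8,
    enforce_eager=True,
)
```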
@ -57,6 +57,7 @@ ignore_patterns = [
".buildkite/**",
"benchmarks/**",
"build/**",
"examples/**",
]

[tool.ruff]
@ -144,6 +145,7 @@ skip = "tests/models/fixtures/*,tests/prompts/*,benchmarks/sonnet.txt,tests/lora
skip_glob = [
".buildkite/*",
"benchmarks/*",
"examples/*",
]
use_parentheses = true
skip_gitignore = true