Compare commits

...

19 Commits

Author SHA1 Message Date
6d8d0a24c0 Add think chunk (#21333)
Signed-off-by: Julien Denize <julien.denize@mistral.ai>
2025-07-23 21:51:32 -07:00
11ef7a611e [BugFix] Set CUDA_VISIBLE_DEVICES before spawning the subprocesses (#21211)
Signed-off-by: Yinghai Lu <yinghai@thinkingmachines.ai>
Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: Rui Qiao <ruisearch42@gmail.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Rui Qiao <ruisearch42@gmail.com>
2025-07-23 21:44:04 -07:00
dc2f159f8a Dump input metadata on crash for async scheduling (#21258)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-07-23 21:10:30 -07:00
d5b981f8b1 [DP] Internal Load Balancing Per Node [one-pod-per-node] (#21238)
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
2025-07-23 20:57:32 -07:00
eec6942014 [BugFix] Fix KVConnector TP worker aggregation (#21473)
Signed-off-by: Nick Hill <nhill@redhat.com>
2025-07-23 20:56:49 -07:00
fd48d99ffd [BugFix]: Batch generation from prompt_embeds fails for long prompts (#21390)
Signed-off-by: KazusatoOko <kazusto.oko@sakana.ai>
Co-authored-by: KazusatoOko <kazusto.oko@sakana.ai>
2025-07-23 20:43:17 -07:00
f8c15c4efb [Bugfix] Fix example disagg_example_p2p_nccl_xpyd.sh zombie process (#21437)
Signed-off-by: David Chen <530634352@qq.com>
2025-07-23 20:42:11 -07:00
aa08a954f9 [Bugfix] Fix casing warning (#21468)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
2025-07-23 20:41:23 -07:00
13e4ee1dc3 [XPU][UT] increase intel xpu CI test scope (#21492)
Signed-off-by: Ma, Liangliang <liangliang.ma@intel.com>
2025-07-23 20:24:04 -07:00
772ce5af97 [Misc] Add dummy maverick test to CI (#21324)
Signed-off-by: Ming Yang <minos.future@gmail.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
2025-07-23 20:22:42 -07:00
63d92abb7c [Frontend] Set MAX_AUDIO_CLIP_FILESIZE_MB via env var instead of hardcoding (#21374)
Signed-off-by: Deven Labovitch <deven@videa.ai>
2025-07-23 20:22:19 -07:00
11599b0e1f feat(gguf_loader): accept HF repo paths & URLs for GGUF (#20793)
Signed-off-by: Hardik <hardikgupta1999@gmail.com>
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-07-23 20:21:02 -07:00
f3137cdd81 [Core] Freeze gc during cuda graph capture to speed up init (#21146)
Signed-off-by: Codex <codex@openai.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
2025-07-23 17:20:14 -07:00
82ec66f514 [V0 Deprecation] Remove Prompt Adapters (#20588)
Signed-off-by: mgoin <mgoin64@gmail.com>
2025-07-23 16:36:48 -07:00
78c13e30e1 [V1] Fix local chunked attention always disabled (#21419)
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
2025-07-23 15:59:30 -07:00
5c9b807b34 [Core] Add reload_weights RPC method (#20096)
Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
2025-07-23 14:24:52 -07:00
14bf19e39f [TPU][TEST] Fix the downloading issue in TPU v1 test 11. (#21418)
Signed-off-by: Qiliang Cui <derrhein@gmail.com>
2025-07-23 11:29:36 -07:00
4ac7713e32 Add test case for compiling multiple graphs (#21044)
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
2025-07-23 11:00:47 -07:00
8560a5b258 [Core][Model] PrithviMAE Enablement on vLLM v1 engine (#20577)
Signed-off-by: Christian Pinto <christian.pinto@ibm.com>
2025-07-23 11:00:23 -07:00
102 changed files with 2629 additions and 2124 deletions

View File

@ -62,7 +62,8 @@ echo "Results will be stored in: $RESULTS_DIR"
echo "--- Installing Python dependencies ---"
python3 -m pip install --progress-bar off git+https://github.com/thuml/depyf.git \
&& python3 -m pip install --progress-bar off pytest pytest-asyncio tpu-info \
&& python3 -m pip install --progress-bar off lm_eval[api]==0.4.4
&& python3 -m pip install --progress-bar off lm_eval[api]==0.4.4 \
&& python3 -m pip install --progress-bar off hf-transfer
echo "--- Python dependencies installed ---"
export VLLM_USE_V1=1
export VLLM_XLA_CHECK_RECOMPILATION=1
@ -150,7 +151,7 @@ run_and_track_test 9 "test_multimodal.py" \
run_and_track_test 10 "test_pallas.py" \
"python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py"
run_and_track_test 11 "test_struct_output_generate.py" \
"python3 -m pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py -k \"not test_structured_output_with_reasoning_matrices\""
"HF_HUB_DISABLE_XET=1 python3 -m pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py -k \"not test_structured_output_with_reasoning_matrices\""
run_and_track_test 12 "test_moe_pallas.py" \
"python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_moe_pallas.py"
run_and_track_test 13 "test_lora.py" \

View File

@ -31,4 +31,13 @@ docker run \
VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp
cd tests
pytest -v -s v1/core
pytest -v -s v1/engine
pytest -v -s v1/sample --ignore=v1/sample/test_logprobs.py --ignore=v1/sample/test_logprobs_e2e.py
pytest -v -s v1/worker --ignore=v1/worker/test_gpu_model_runner.py
pytest -v -s v1/structured_output
pytest -v -s v1/spec_decode --ignore=v1/spec_decode/test_max_len.py --ignore=v1/spec_decode/test_eagle.py
pytest -v -s v1/kv_connector/unit --ignore=v1/kv_connector/unit/test_multi_connector.py --ignore=v1/kv_connector/unit/test_nixl_connector.py
pytest -v -s v1/test_serial_utils.py
pytest -v -s v1/test_utils.py
pytest -v -s v1/test_metrics_reader.py
'

View File

@ -166,6 +166,7 @@ steps:
- tests/v1/test_async_llm_dp.py
- tests/v1/test_external_lb_dp.py
- tests/v1/test_internal_lb_dp.py
- tests/v1/test_hybrid_lb_dp.py
- tests/v1/engine/test_engine_core_client.py
commands:
# test with tp=2 and external_dp=2
@ -178,6 +179,7 @@ steps:
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_external_lb_dp.py
- TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/test_internal_lb_dp.py
- TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/test_hybrid_lb_dp.py
- pytest -v -s v1/engine/test_engine_core_client.py::test_kv_cache_events_dp
- pytest -v -s distributed/test_utils.py
- pytest -v -s compile/test_basic_correctness.py
@ -718,6 +720,7 @@ steps:
- VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
- VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/test_disagg.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown
- pytest -v -s models/multimodal/generation/test_maverick.py
- label: Plugin Tests (2 GPUs) # 40min
mirror_hardwares: [amdexperimental]

View File

@ -265,7 +265,7 @@ RUN if [ "$RUN_WHEEL_CHECK" = "true" ]; then \
#################### EXTENSION Build IMAGE ####################
#################### DEV IMAGE ####################
FROM base as dev
FROM base AS dev
ARG PIP_INDEX_URL UV_INDEX_URL
ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL

View File

@ -47,7 +47,7 @@ FROM vllm-base AS vllm-openai
# install additional dependencies for openai api server
RUN --mount=type=cache,target=/root/.cache/pip \
pip install accelerate hf_transfer pytest modelscope
pip install accelerate hf_transfer pytest pytest_asyncio lm_eval[api] modelscope
ENV VLLM_USAGE_SOURCE production-docker-image \
TRITON_XPU_PROFILE 1

View File

@ -14,7 +14,6 @@ API documentation for vLLM's configuration classes.
- [vllm.config.DeviceConfig][]
- [vllm.config.SpeculativeConfig][]
- [vllm.config.LoRAConfig][]
- [vllm.config.PromptAdapterConfig][]
- [vllm.config.MultiModalConfig][]
- [vllm.config.PoolerConfig][]
- [vllm.config.DecodingConfig][]

View File

@ -34,23 +34,22 @@ th:not(:first-child) {
}
</style>
| Feature | [CP][chunked-prefill] | [APC](automatic_prefix_caching.md) | [LoRA](lora.md) | <abbr title="Prompt Adapter">prmpt adptr</abbr> | [SD](spec_decode.md) | CUDA graph | <abbr title="Pooling Models">pooling</abbr> | <abbr title="Encoder-Decoder Models">enc-dec</abbr> | <abbr title="Logprobs">logP</abbr> | <abbr title="Prompt Logprobs">prmpt logP</abbr> | <abbr title="Async Output Processing">async output</abbr> | multi-step | <abbr title="Multimodal Inputs">mm</abbr> | best-of | beam-search |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Feature | [CP][chunked-prefill] | [APC](automatic_prefix_caching.md) | [LoRA](lora.md) | [SD](spec_decode.md) | CUDA graph | <abbr title="Pooling Models">pooling</abbr> | <abbr title="Encoder-Decoder Models">enc-dec</abbr> | <abbr title="Logprobs">logP</abbr> | <abbr title="Prompt Logprobs">prmpt logP</abbr> | <abbr title="Async Output Processing">async output</abbr> | multi-step | <abbr title="Multimodal Inputs">mm</abbr> | best-of | beam-search |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| [CP][chunked-prefill] | ✅ | | | | | | | | | | | | | | |
| [APC](automatic_prefix_caching.md) | ✅ | ✅ | | | | | | | | | | | | | |
| [LoRA](lora.md) | ✅ | ✅ | ✅ | | | | | | | | | | | | |
| <abbr title="Prompt Adapter">prmpt adptr</abbr> | ✅ | ✅ | ✅ | | | | | | | | | | | | |
| [SD](spec_decode.md) | ✅ | ✅ | | ✅ | ✅ | | | | | | | | | | |
| CUDA graph | | | | | | ✅ | | | | | | | | | |
| <abbr title="Pooling Models">pooling</abbr> | ❌ | ❌ | ❌ | ❌ | | | ✅ | | | | | | | | |
| <abbr title="Encoder-Decoder Models">enc-dec</abbr> | | [](gh-issue:7366) | | | [](gh-issue:7366) | | ✅ | ✅ | | | | | | | |
| <abbr title="Logprobs">logP</abbr> | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | | | | | | |
| <abbr title="Prompt Logprobs">prmpt logP</abbr> | ✅ | ✅ | ✅ | | ✅ | | ❌ | ✅ | ✅ | ✅ | | | | | |
| <abbr title="Async Output Processing">async output</abbr> | | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | | | | |
| multi-step | ❌ | ✅ | ❌ | ✅ | | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | | | |
| <abbr title="Multimodal Inputs">mm</abbr> | ✅ | [🟠](gh-pr:8348) | [🟠](gh-pr:4194) | | | ✅ | ✅ | ✅ | | ✅ | ✅ | ❔ | ✅ | | |
| best-of | ✅ | ✅ | ✅ | ✅ | [](gh-issue:6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [](gh-issue:7968) | | ✅ | |
| beam-search | ✅ | ✅ | ✅ | ✅ | [](gh-issue:6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [](gh-issue:7968) | ❔ | ✅ | ✅ |
| [SD](spec_decode.md) | ✅ | ✅ | | | | | | | | | | | | |
| CUDA graph | ✅ | ✅ | | ✅ | ✅ | | | | | | | | | |
| <abbr title="Pooling Models">pooling</abbr> | | | | | | ✅ | | | | | | | | |
| <abbr title="Encoder-Decoder Models">enc-dec</abbr> | ❌ | [](gh-issue:7366) | ❌ | [](gh-issue:7366) | | | ✅ | | | | | | | |
| <abbr title="Logprobs">logP</abbr> | | | | | | | ✅ | ✅ | | | | | | |
| <abbr title="Prompt Logprobs">prmpt logP</abbr> | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | | | | | | |
| <abbr title="Async Output Processing">async output</abbr> | ✅ | ✅ | ✅ | | ✅ | | ❌ | ✅ | ✅ | ✅ | | | | |
| multi-step | | ✅ | | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | | | | |
| <abbr title="Multimodal Inputs">mm</abbr> | ✅ | [🟠](gh-pr:8348) | [🟠](gh-pr:4194) | | ✅ | | ✅ | ✅ | ✅ | ✅ | ❔ | ✅ | | |
| best-of | ✅ | ✅ | ✅ | [](gh-issue:6137) | | | ✅ | ✅ | ✅ | | [](gh-issue:7968) | ✅ | ✅ | |
| beam-search | ✅ | ✅ | ✅ | [](gh-issue:6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [](gh-issue:7968) | | ✅ | |
[](){ #feature-x-hardware }
@ -59,10 +58,9 @@ th:not(:first-child) {
| Feature | Volta | Turing | Ampere | Ada | Hopper | CPU | AMD | TPU |
|-----------------------------------------------------------|---------------------|-----------|-----------|--------|------------|--------------------|--------|-----|
| [CP][chunked-prefill] | [](gh-issue:2729) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [APC](automatic_prefix_caching.md) | [](gh-issue:3687) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [LoRA](lora.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| <abbr title="Prompt Adapter">prmpt adptr</abbr> | ✅ | ✅ | ✅ | ✅ | ✅ | [](gh-issue:8475) | ✅ | ❌ |
| [SD](spec_decode.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| [APC](automatic_prefix_caching.md) | [](gh-issue:3687) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [LoRA](lora.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [SD](spec_decode.md) | ✅ | ✅ | ✅ | ✅ | ✅ | | ✅ | ❌ |
| CUDA graph | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ |
| <abbr title="Pooling Models">pooling</abbr> | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❔ | ❌ |
| <abbr title="Encoder-Decoder Models">enc-dec</abbr> | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |

View File

@ -351,6 +351,11 @@ you can use the [official OpenAI Python client](https://github.com/openai/openai
Code example: <gh-file:examples/online_serving/openai_transcription_client.py>
<!-- TODO: api enforced limits + uploading audios -->
#### API Enforced Limits
Set the maximum audio file size (in MB) that vLLM will accept via the
`VLLM_MAX_AUDIO_CLIP_FILESIZE_MB` environment variable. The default is 25 MB.
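For illustration, a minimal sketch of raising this limit when launching the server; the variable name comes from the change above, while the 100 MB value, the Whisper checkpoint, and the `vllm serve` invocation are assumptions for the example:

```python
import os
import subprocess

# Hypothetical example: allow audio uploads up to 100 MB (default is 25 MB).
# The variable must be set in the server's environment before vLLM starts.
env = dict(os.environ, VLLM_MAX_AUDIO_CLIP_FILESIZE_MB="100")
subprocess.run(
    ["vllm", "serve", "openai/whisper-large-v3"],  # assumed CLI and model
    env=env,
    check=True,
)
```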
#### Extra Parameters
The following [sampling parameters][sampling-params] are supported.

View File

@ -1,122 +1,27 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This is a demo script showing how to use the
PrithviGeospatialMAE model with vLLM
This script is based on: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11/blob/main/inference.py # noqa
Target model weights: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11/resolve/main/Prithvi-EO-V2-300M-TL-Sen1Floods11.pt # noqa
The requirements for running this script are:
- Installing [terratorch, albumentations, rasterio] in your python environment
- downloading the model weights in a 'model' folder local to the script
(temporary measure until the proper config.json file is uploaded to HF)
- download an input example image (India_900498_S2Hand.tif) and place it in
the same folder with the script (or specify with the --data_file argument)
Run the example:
python prithvi_geospatial_mae.py
""" # noqa: E501
import argparse
import datetime
import os
import re
from typing import Union
import albumentations
import numpy as np
import rasterio
import regex as re
import torch
from einops import rearrange
from terratorch.datamodules import Sen1Floods11NonGeoDataModule
from vllm import LLM
torch.set_default_dtype(torch.float16)
NO_DATA = -9999
NO_DATA_FLOAT = 0.0001
OFFSET = 0
PERCENTILE = 99
model_config = """{
"architectures": ["PrithviGeoSpatialMAE"],
"num_classes": 0,
"pretrained_cfg": {
"task_args": {
"task": "SemanticSegmentationTask",
"model_factory": "EncoderDecoderFactory",
"loss": "ce",
"ignore_index": -1,
"lr": 0.001,
"freeze_backbone": false,
"freeze_decoder": false,
"plot_on_val": 10,
"optimizer": "AdamW",
"scheduler": "CosineAnnealingLR"
},
"model_args": {
"backbone_pretrained": false,
"backbone": "prithvi_eo_v2_300_tl",
"decoder": "UperNetDecoder",
"decoder_channels": 256,
"decoder_scale_modules": true,
"num_classes": 2,
"rescale": true,
"backbone_bands": [
"BLUE",
"GREEN",
"RED",
"NIR_NARROW",
"SWIR_1",
"SWIR_2"
],
"head_dropout": 0.1,
"necks": [
{
"name": "SelectIndices",
"indices": [
5,
11,
17,
23
]
},
{
"name": "ReshapeTokensToImage"
}
]
},
"optimizer_params" : {
"lr": 5.0e-05,
"betas": [0.9, 0.999],
"eps": [1.0e-08],
"weight_decay": 0.05,
"amsgrad": false,
"maximize": false,
"capturable": false,
"differentiable": false
},
"scheduler_params" : {
"T_max": 50,
"eta_min": 0,
"last_epoch": -1,
"verbose": "deprecated"
}
},
"torch_dtype": "float32"
}
"""
# Temporarily creating the "config.json" for the model.
# This is going to disappear once the correct config.json is available on HF
with open(
os.path.join(os.path.dirname(__file__), "./model/config.json"), "w"
) as config_file:
config_file.write(model_config)
datamodule_config = {
"bands": ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"],
"batch_size": 16,
@ -138,28 +43,24 @@ datamodule_config = {
class PrithviMAE:
def __init__(self):
print("Initializing PrithviMAE model")
self.llm = LLM(
model=os.path.join(os.path.dirname(__file__), "./model"),
skip_tokenizer_init=True,
dtype="float32",
def __init__(self, model):
self.model = LLM(
model=model, skip_tokenizer_init=True, dtype="float16", enforce_eager=True
)
def run(self, input_data, location_coords):
print("################ Running inference on vLLM ##############")
# merge the inputs into one data structure
if input_data is not None and input_data.dtype == torch.float32:
input_data = input_data.to(torch.float16)
input_data = input_data[0]
mm_data = {
"pixel_values": torch.empty(0) if input_data is None else input_data,
"location_coords": torch.empty(0)
if location_coords is None
else location_coords,
"pixel_values": input_data,
"location_coords": location_coords,
}
prompt = {"prompt_token_ids": [1], "multi_modal_data": mm_data}
outputs = self.llm.encode(prompt, use_tqdm=False)
print("################ Inference done (it took seconds) ##############")
outputs = self.model.encode(prompt, use_tqdm=False)
return outputs[0].outputs.data
@ -181,11 +82,12 @@ def process_channel_group(orig_img, channels):
"""
Args:
orig_img: torch.Tensor representing original image (reference)
with shape = (bands, H, W).
with shape = (bands, H, W).
channels: list of indices representing RGB channels.
Returns:
torch.Tensor with shape (num_channels, height, width) for original image
torch.Tensor with shape (num_channels, height, width)
for original image
"""
orig_img = orig_img[channels, ...]
@ -260,10 +162,10 @@ def load_example(
Args:
file_paths: list of file paths .
mean: list containing mean values for each band in the images
in *file_paths*.
std: list containing std values for each band in the images
in *file_paths*.
mean: list containing mean values for each band in the
images in *file_paths*.
std: list containing std values for each band in the
images in *file_paths*.
Returns:
np.array containing created example
@ -308,7 +210,7 @@ def load_example(
print(f"Could not extract timestamp for {file} ({e})")
imgs = np.stack(imgs, axis=0) # num_frames, H, W, C
imgs = np.moveaxis(imgs, -1, 0).astype("float32")
imgs = np.moveaxis(imgs, -1, 0).astype("float32") # C, num_frames, H, W
imgs = np.expand_dims(imgs, axis=0) # add batch dim
return imgs, temporal_coords, location_coords, metas
@ -332,8 +234,10 @@ def run_model(
)
# Build sliding window
batch_size = 1
batch = torch.tensor(input_data, device="cpu")
# batch = torch.tensor(input_data, device="cpu")
batch = torch.tensor(input_data)
windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size)
h1, w1 = windows.shape[3:5]
windows = rearrange(
@ -344,18 +248,16 @@ def run_model(
num_batches = windows.shape[0] // batch_size if windows.shape[0] > batch_size else 1
windows = torch.tensor_split(windows, num_batches, dim=0)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
if temporal_coords:
temporal_coords = torch.tensor(temporal_coords, device=device).unsqueeze(0)
temporal_coords = torch.tensor(temporal_coords).unsqueeze(0)
else:
temporal_coords = None
if location_coords:
location_coords = torch.tensor(location_coords[0], device=device).unsqueeze(0)
location_coords = torch.tensor(location_coords[0]).unsqueeze(0)
else:
location_coords = None
# Run model
# Run Prithvi-EO-V2-300M-TL-Sen1Floods11
pred_imgs = []
for x in windows:
# Apply standardization
@ -363,15 +265,7 @@ def run_model(
x = datamodule.aug(x)["image"]
with torch.no_grad():
x = x.to(device)
pred = model.run(x, location_coords=location_coords)
if lightning_model:
pred_lightning = lightning_model(
x, temporal_coords=temporal_coords, location_coords=location_coords
)
pred_lightning = pred_lightning.output.detach().cpu()
if not torch.equal(pred, pred_lightning):
print("Inference output is not equal")
y_hat = pred.argmax(dim=1)
y_hat = torch.nn.functional.interpolate(
@ -403,52 +297,18 @@ def run_model(
return pred_imgs
def parse_args():
parser = argparse.ArgumentParser("MAE run inference", add_help=False)
parser.add_argument(
"--data_file",
type=str,
default="./India_900498_S2Hand.tif",
help="Path to the file.",
)
parser.add_argument(
"--output_dir",
type=str,
default="output",
help="Path to the directory where to save outputs.",
)
parser.add_argument(
"--input_indices",
default=[1, 2, 3, 8, 11, 12],
type=int,
nargs="+",
help="0-based indices of the six Prithvi channels to be selected from the "
"input. By default selects [1,2,3,8,11,12] for S2L1C data.",
)
parser.add_argument(
"--rgb_outputs",
action="store_true",
help="If present, output files will only contain RGB channels. "
"Otherwise, all bands will be saved.",
)
def main(
data_file: str,
model: str,
output_dir: str,
rgb_outputs: bool,
input_indices: list[int] = None,
):
os.makedirs(output_dir, exist_ok=True)
# Load model ---------------------------------------------------------------
model_obj = PrithviMAE()
model_obj = PrithviMAE(model=model)
datamodule = generate_datamodule()
img_size = 256 # Size of Sen1Floods11
# Loading data -------------------------------------------------------------
img_size = 512 # Size of Sen1Floods11
input_data, temporal_coords, location_coords, meta_data = load_example(
file_paths=[data_file],
@ -460,8 +320,6 @@ def main(
if input_data.mean() > 1:
input_data = input_data / 10000 # Convert to range 0-1
# Running model ------------------------------------------------------------
channels = [
datamodule_config["bands"].index(b) for b in ["RED", "GREEN", "BLUE"]
] # BGR -> RGB
@ -469,7 +327,6 @@ def main(
pred = run_model(
input_data, temporal_coords, location_coords, model_obj, datamodule, img_size
)
# Save pred
meta_data.update(count=1, dtype="uint8", compress="lzw", nodata=0)
pred_file = os.path.join(
@ -487,6 +344,7 @@ def main(
orig_img=torch.Tensor(input_data[0, :, 0, ...]),
channels=channels,
)
rgb_orig = rgb_orig.to(torch.float32)
pred[pred == 0.0] = np.nan
img_pred = rgb_orig * 0.7 + pred * 0.3
@ -503,9 +361,10 @@ def main(
# Save image rgb
if rgb_outputs:
name_suffix = os.path.splitext(os.path.basename(data_file))[0]
rgb_file = os.path.join(
output_dir,
f"original_rgb_{os.path.splitext(os.path.basename(data_file))[0]}.tiff",
f"original_rgb_{name_suffix}.tiff",
)
save_geotiff(
image=_convert_np_uint8(rgb_orig),
@ -515,6 +374,42 @@ def main(
if __name__ == "__main__":
args = parse_args()
parser = argparse.ArgumentParser("MAE run inference", add_help=False)
parser.add_argument(
"--data_file",
type=str,
default="./India_900498_S2Hand.tif",
help="Path to the file.",
)
parser.add_argument(
"--model",
type=str,
default="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
help="Path to a checkpoint file to load from.",
)
parser.add_argument(
"--output_dir",
type=str,
default="output",
help="Path to the directory where to save outputs.",
)
parser.add_argument(
"--input_indices",
default=[1, 2, 3, 8, 11, 12],
type=int,
nargs="+",
help="""
0-based indices of the six Prithvi channels to be selected from the input.
By default selects [1,2,3,8,11,12] for S2L1C data.
""",
)
parser.add_argument(
"--rgb_outputs",
action="store_true",
help="If present, output files will only contain RGB channels. "
"Otherwise, all bands will be saved.",
)
args = parser.parse_args()
main(**vars(args))

View File

@ -93,6 +93,7 @@ ensure_python_library_installed() {
cleanup() {
echo "Stopping everything…"
trap - INT TERM # prevent re-entrancy
pkill -9 -f "disagg_proxy_p2p_nccl_xpyd.py"
kill -- -$$ # negative PID == "this whole process-group"
wait # reap children so we don't leave zombies
exit 0

View File

@ -72,7 +72,6 @@ line-length = 80
"vllm/core/**/*.py" = ["UP006", "UP035"]
"vllm/engine/**/*.py" = ["UP006", "UP035"]
"vllm/executor/**/*.py" = ["UP006", "UP035"]
"vllm/prompt_adapter/**/*.py" = ["UP006", "UP035"]
"vllm/worker/**/*.py" = ["UP006", "UP035"]
# Python 3.8 typing - skip utils for ROCm
"vllm/utils/__init__.py" = ["UP006", "UP035"]

View File

@ -33,7 +33,7 @@ pyzmq >= 25.0.0
msgspec
gguf >= 0.13.0
importlib_metadata; python_version < '3.10'
mistral_common[opencv] >= 1.8.0
mistral_common[image,audio] >= 1.8.2
opencv-python-headless >= 4.11.0 # required for video IO
pyyaml
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12

View File

@ -23,7 +23,7 @@ jiwer # required for audio tests
timm # required for internvl test
transformers_stream_generator # required for qwen-vl test
matplotlib # required for qwen-vl test
mistral_common[opencv] >= 1.8.0 # required for voxtral test
mistral_common[image,audio] >= 1.8.2 # required for voxtral test
num2words # required for smolvlm test
opencv-python-headless >= 4.11.0 # required for video test
datamodel_code_generator # required for minicpm3 test

View File

@ -28,7 +28,7 @@ torchvision==0.22.1
transformers_stream_generator # required for qwen-vl test
mamba_ssm # required for plamo2 test
matplotlib # required for qwen-vl test
mistral_common[opencv] >= 1.8.0 # required for voxtral test
mistral_common[image,audio] >= 1.8.2 # required for voxtral test
num2words # required for smolvlm test
open_clip_torch==2.32.0 # Required for nemotron_vl test
opencv-python-headless >= 4.11.0 # required for video test
@ -54,3 +54,4 @@ runai-model-streamer==0.11.0
runai-model-streamer-s3==0.11.0
fastsafetensors>=0.1.10
pydantic>=2.10 # 2.9 leads to error on python 3.10
terratorch==1.1rc2 # required for PrithviMAE test

View File

@ -6,6 +6,10 @@ accelerate==1.0.1
# via
# lm-eval
# peft
aenum==3.1.16
# via lightly
affine==2.4.0
# via rasterio
aiohappyeyeballs==2.4.3
# via aiohttp
aiohttp==3.10.11
@ -21,8 +25,18 @@ aiosignal==1.3.1
# via
# aiohttp
# ray
albucore==0.0.16
# via terratorch
albumentations==1.4.6
# via terratorch
alembic==1.16.4
# via mlflow
annotated-types==0.7.0
# via pydantic
antlr4-python3-runtime==4.9.3
# via
# hydra-core
# omegaconf
anyio==4.6.2.post1
# via
# httpx
@ -34,10 +48,12 @@ arrow==1.3.0
attrs==24.2.0
# via
# aiohttp
# fiona
# hypothesis
# jsonlines
# jsonschema
# pytest-subtests
# rasterio
# referencing
audioread==3.0.1
# via librosa
@ -46,9 +62,13 @@ backoff==2.2.1
# -r requirements/test.in
# schemathesis
bitsandbytes==0.46.1
# via -r requirements/test.in
# via
# -r requirements/test.in
# lightning
black==24.10.0
# via datamodel-code-generator
blinker==1.9.0
# via flask
blobfile==3.0.0
# via -r requirements/test.in
bm25s==0.2.13
@ -64,11 +84,18 @@ bounded-pool-executor==0.0.3
buildkite-test-collector==0.1.9
# via -r requirements/test.in
cachetools==5.5.2
# via google-auth
# via
# google-auth
# mlflow-skinny
certifi==2024.8.30
# via
# fiona
# httpcore
# httpx
# lightly
# pyogrio
# pyproj
# rasterio
# requests
cffi==1.17.1
# via soundfile
@ -79,11 +106,28 @@ charset-normalizer==3.4.0
click==8.1.7
# via
# black
# click-plugins
# cligj
# fiona
# flask
# jiwer
# mlflow-skinny
# nltk
# rasterio
# ray
# schemathesis
# typer
# uvicorn
click-plugins==1.1.1.2
# via
# fiona
# rasterio
cligj==0.7.2
# via
# fiona
# rasterio
cloudpickle==3.1.1
# via mlflow-skinny
colorama==0.4.6
# via
# sacrebleu
@ -99,6 +143,8 @@ cupy-cuda12x==13.3.0
# via ray
cycler==0.12.1
# via matplotlib
databricks-sdk==0.59.0
# via mlflow-skinny
datamodel-code-generator==0.26.3
# via -r requirements/test.in
dataproperty==1.0.1
@ -122,13 +168,21 @@ distlib==0.3.9
# via virtualenv
dnspython==2.7.0
# via email-validator
docker==7.1.0
# via mlflow
docopt==0.6.2
# via num2words
einops==0.8.0
docstring-parser==0.17.0
# via jsonargparse
efficientnet-pytorch==0.7.1
# via segmentation-models-pytorch
einops==0.8.1
# via
# -r requirements/test.in
# encodec
# mamba-ssm
# terratorch
# torchgeo
# vector-quantize-pytorch
# vocos
einx==0.3.0
@ -141,6 +195,8 @@ eval-type-backport==0.2.2
# via mteb
evaluate==0.4.3
# via lm-eval
fastapi==0.116.1
# via mlflow-skinny
fastparquet==2024.11.0
# via genai-perf
fastrlock==0.8.2
@ -156,6 +212,10 @@ filelock==3.16.1
# torch
# transformers
# virtualenv
fiona==1.10.1
# via torchgeo
flask==3.1.1
# via mlflow
fonttools==4.54.1
# via matplotlib
fqdn==1.5.1
@ -173,6 +233,8 @@ fsspec==2024.9.0
# evaluate
# fastparquet
# huggingface-hub
# lightning
# pytorch-lightning
# torch
ftfy==6.3.1
# via open-clip-torch
@ -180,18 +242,41 @@ genai-perf==0.0.8
# via -r requirements/test.in
genson==1.3.0
# via datamodel-code-generator
geopandas==1.0.1
# via terratorch
gitdb==4.0.12
# via gitpython
gitpython==3.1.44
# via mlflow-skinny
google-api-core==2.24.2
# via opencensus
google-auth==2.40.2
# via google-api-core
# via
# databricks-sdk
# google-api-core
googleapis-common-protos==1.70.0
# via google-api-core
graphene==3.4.3
# via mlflow
graphql-core==3.2.6
# via hypothesis-graphql
# via
# graphene
# graphql-relay
# hypothesis-graphql
graphql-relay==3.2.0
# via graphene
greenlet==3.2.3
# via sqlalchemy
grpcio==1.71.0
# via ray
gunicorn==23.0.0
# via mlflow
h11==0.14.0
# via httpcore
# via
# httpcore
# uvicorn
h5py==3.13.0
# via terratorch
harfile==0.3.0
# via schemathesis
hf-xet==1.1.3
@ -204,7 +289,7 @@ httpx==0.27.2
# via
# -r requirements/test.in
# schemathesis
huggingface-hub==0.33.0
huggingface-hub==0.33.1
# via
# -r requirements/test.in
# accelerate
@ -212,13 +297,19 @@ huggingface-hub==0.33.0
# evaluate
# open-clip-torch
# peft
# segmentation-models-pytorch
# sentence-transformers
# terratorch
# timm
# tokenizers
# transformers
# vocos
humanize==4.11.0
# via runai-model-streamer
hydra-core==1.3.2
# via
# lightly
# lightning
hypothesis==6.131.0
# via
# hypothesis-graphql
@ -236,6 +327,14 @@ idna==3.10
# jsonschema
# requests
# yarl
imageio==2.37.0
# via scikit-image
importlib-metadata==8.7.0
# via
# mlflow-skinny
# opentelemetry-api
importlib-resources==6.5.2
# via typeshed-client
inflect==5.6.2
# via datamodel-code-generator
iniconfig==2.0.0
@ -244,9 +343,13 @@ isoduration==20.11.0
# via jsonschema
isort==5.13.2
# via datamodel-code-generator
itsdangerous==2.2.0
# via flask
jinja2==3.1.6
# via
# datamodel-code-generator
# flask
# mlflow
# torch
jiwer==3.0.5
# via -r requirements/test.in
@ -259,6 +362,10 @@ joblib==1.4.2
# librosa
# nltk
# scikit-learn
jsonargparse==4.35.0
# via
# lightning
# terratorch
jsonlines==4.0.0
# via lm-eval
jsonpointer==3.0.0
@ -277,12 +384,33 @@ kaleido==0.2.1
# via genai-perf
kiwisolver==1.4.7
# via matplotlib
kornia==0.8.1
# via torchgeo
kornia-rs==0.1.9
# via kornia
lazy-loader==0.4
# via librosa
# via
# librosa
# scikit-image
libnacl==2.1.0
# via tensorizer
librosa==0.10.2.post1
# via -r requirements/test.in
lightly==1.5.20
# via
# terratorch
# torchgeo
lightly-utils==0.0.2
# via lightly
lightning==2.5.1.post0
# via
# terratorch
# torchgeo
lightning-utilities==0.14.3
# via
# lightning
# pytorch-lightning
# torchmetrics
llvmlite==0.44.0
# via numba
lm-eval==0.4.8
@ -291,16 +419,27 @@ lxml==5.3.0
# via
# blobfile
# sacrebleu
mako==1.3.10
# via alembic
mamba-ssm==2.2.4
# via -r requirements/test.in
markdown==3.8.2
# via mlflow
markdown-it-py==3.0.0
# via rich
markupsafe==3.0.1
# via
# flask
# jinja2
# mako
# werkzeug
matplotlib==3.9.2
# via -r requirements/test.in
# via
# -r requirements/test.in
# lightning
# mlflow
# pycocotools
# torchgeo
mbstrdecoder==1.1.3
# via
# dataproperty
@ -308,8 +447,12 @@ mbstrdecoder==1.1.3
# typepy
mdurl==0.1.2
# via markdown-it-py
mistral-common==1.8.0
mistral-common==1.8.2
# via -r requirements/test.in
mlflow==2.22.0
# via terratorch
mlflow-skinny==2.22.0
# via mlflow
more-itertools==10.5.0
# via lm-eval
mpmath==1.3.0
@ -328,10 +471,14 @@ multiprocess==0.70.16
# via
# datasets
# evaluate
munch==4.0.0
# via pretrainedmodels
mypy-extensions==1.0.0
# via black
networkx==3.2.1
# via torch
# via
# scikit-image
# torch
ninja==1.11.1.3
# via mamba-ssm
nltk==3.9.1
@ -348,6 +495,8 @@ numpy==1.26.4
# via
# -r requirements/test.in
# accelerate
# albucore
# albumentations
# bitsandbytes
# bm25s
# contourpy
@ -358,9 +507,15 @@ numpy==1.26.4
# evaluate
# fastparquet
# genai-perf
# geopandas
# h5py
# imageio
# librosa
# lightly
# lightly-utils
# matplotlib
# mistral-common
# mlflow
# mteb
# numba
# numexpr
@ -368,18 +523,30 @@ numpy==1.26.4
# pandas
# patsy
# peft
# pycocotools
# pyogrio
# rasterio
# rioxarray
# rouge-score
# runai-model-streamer
# sacrebleu
# scikit-image
# scikit-learn
# scipy
# segmentation-models-pytorch
# shapely
# soxr
# statsmodels
# tensorboardx
# tensorizer
# tifffile
# torchgeo
# torchmetrics
# torchvision
# transformers
# tritonclient
# vocos
# xarray
nvidia-cublas-cu12==12.8.3.14
# via
# nvidia-cudnn-cu12
@ -417,6 +584,10 @@ nvidia-nvjitlink-cu12==12.8.61
# torch
nvidia-nvtx-cu12==12.8.55
# via torch
omegaconf==2.3.0
# via
# hydra-core
# lightning
open-clip-torch==2.32.0
# via -r requirements/test.in
opencensus==0.11.4
@ -426,7 +597,18 @@ opencensus-context==0.1.3
opencv-python-headless==4.11.0.86
# via
# -r requirements/test.in
# albucore
# albumentations
# mistral-common
opentelemetry-api==1.35.0
# via
# mlflow-skinny
# opentelemetry-sdk
# opentelemetry-semantic-conventions
opentelemetry-sdk==1.35.0
# via mlflow-skinny
opentelemetry-semantic-conventions==0.56b0
# via opentelemetry-sdk
packaging==24.2
# via
# accelerate
@ -435,26 +617,44 @@ packaging==24.2
# datasets
# evaluate
# fastparquet
# geopandas
# gunicorn
# huggingface-hub
# hydra-core
# kornia
# lazy-loader
# lightning
# lightning-utilities
# mamba-ssm
# matplotlib
# mlflow-skinny
# peft
# plotly
# pooch
# pyogrio
# pytest
# pytest-rerunfailures
# pytorch-lightning
# ray
# rioxarray
# scikit-image
# statsmodels
# tensorboardx
# torchmetrics
# transformers
# typepy
# xarray
pandas==2.2.3
# via
# datasets
# evaluate
# fastparquet
# genai-perf
# geopandas
# mlflow
# statsmodels
# torchgeo
# xarray
pathspec==0.12.1
# via black
pathvalidate==3.2.1
@ -468,9 +668,14 @@ peft==0.13.2
pillow==10.4.0
# via
# genai-perf
# imageio
# lightly-utils
# matplotlib
# mistral-common
# scikit-image
# segmentation-models-pytorch
# sentence-transformers
# torchgeo
# torchvision
platformdirs==4.3.6
# via
@ -489,6 +694,8 @@ portalocker==2.10.1
# via sacrebleu
pqdm==0.2.0
# via -r requirements/test.in
pretrainedmodels==0.7.4
# via segmentation-models-pytorch
prometheus-client==0.22.0
# via ray
propcache==0.2.0
@ -499,8 +706,10 @@ protobuf==5.28.3
# via
# google-api-core
# googleapis-common-protos
# mlflow-skinny
# proto-plus
# ray
# tensorboardx
# tensorizer
psutil==6.1.0
# via
@ -515,6 +724,7 @@ pyarrow==18.0.0
# via
# datasets
# genai-perf
# mlflow
pyasn1==0.6.1
# via
# pyasn1-modules
@ -523,6 +733,8 @@ pyasn1-modules==0.4.2
# via google-auth
pybind11==2.13.6
# via lm-eval
pycocotools==2.0.8
# via terratorch
pycountry==24.6.1
# via pydantic-extra-types
pycparser==2.22
@ -532,8 +744,12 @@ pycryptodomex==3.22.0
pydantic==2.11.5
# via
# -r requirements/test.in
# albumentations
# datamodel-code-generator
# fastapi
# lightly
# mistral-common
# mlflow-skinny
# mteb
# pydantic-extra-types
# ray
@ -543,15 +759,24 @@ pydantic-extra-types==2.10.5
# via mistral-common
pygments==2.18.0
# via rich
pyogrio==0.11.0
# via geopandas
pyparsing==3.2.0
# via matplotlib
# via
# matplotlib
# rasterio
pyproj==3.7.1
# via
# geopandas
# rioxarray
# torchgeo
pyrate-limiter==3.7.0
# via schemathesis
pystemmer==3.0.0
# via mteb
pytablewriter==1.2.0
# via lm-eval
pytest==8.3.3
pytest==8.3.5
# via
# -r requirements/test.in
# buildkite-test-collector
@ -564,6 +789,7 @@ pytest==8.3.3
# pytest-subtests
# pytest-timeout
# schemathesis
# terratorch
pytest-asyncio==0.24.0
# via -r requirements/test.in
pytest-forked==1.6.0
@ -578,15 +804,23 @@ pytest-subtests==0.14.1
# via schemathesis
pytest-timeout==2.3.1
# via -r requirements/test.in
python-box==7.3.2
# via terratorch
python-dateutil==2.9.0.post0
# via
# arrow
# botocore
# graphene
# lightly
# matplotlib
# pandas
# typepy
python-rapidjson==1.20
# via tritonclient
pytorch-lightning==2.5.2
# via
# lightly
# lightning
pytrec-eval-terrier==0.5.7
# via mteb
pytz==2024.2
@ -596,11 +830,17 @@ pytz==2024.2
pyyaml==6.0.2
# via
# accelerate
# albumentations
# datamodel-code-generator
# datasets
# genai-perf
# huggingface-hub
# jsonargparse
# lightning
# mlflow-skinny
# omegaconf
# peft
# pytorch-lightning
# ray
# responses
# schemathesis
@ -609,6 +849,11 @@ pyyaml==6.0.2
# vocos
rapidfuzz==3.12.1
# via jiwer
rasterio==1.4.3
# via
# rioxarray
# terratorch
# torchgeo
ray==2.43.0
# via -r requirements/test.in
redis==5.2.0
@ -627,12 +872,16 @@ regex==2024.9.11
requests==2.32.3
# via
# buildkite-test-collector
# databricks-sdk
# datasets
# docker
# evaluate
# google-api-core
# huggingface-hub
# lightly
# lm-eval
# mistral-common
# mlflow-skinny
# mteb
# pooch
# ray
@ -650,8 +899,11 @@ rfc3987==1.3.8
rich==13.9.4
# via
# genai-perf
# lightning
# mteb
# typer
rioxarray==0.19.0
# via terratorch
rouge-score==0.1.2
# via lm-eval
rpds-py==0.20.1
@ -660,6 +912,8 @@ rpds-py==0.20.1
# referencing
rsa==4.9.1
# via google-auth
rtree==1.4.0
# via torchgeo
runai-model-streamer==0.11.0
# via -r requirements/test.in
runai-model-streamer-s3==0.11.0
@ -677,21 +931,32 @@ safetensors==0.4.5
# transformers
schemathesis==3.39.15
# via -r requirements/test.in
scikit-image==0.25.2
# via albumentations
scikit-learn==1.5.2
# via
# albumentations
# librosa
# lm-eval
# mlflow
# mteb
# sentence-transformers
scipy==1.13.1
# via
# albumentations
# bm25s
# librosa
# mlflow
# mteb
# scikit-image
# scikit-learn
# sentence-transformers
# statsmodels
# vocos
segmentation-models-pytorch==0.4.0
# via
# terratorch
# torchgeo
sentence-transformers==3.2.1
# via
# -r requirements/test.in
@ -700,21 +965,30 @@ sentencepiece==0.2.0
# via mistral-common
setuptools==77.0.3
# via
# lightning-utilities
# mamba-ssm
# pytablewriter
# torch
# triton
shapely==2.1.1
# via
# geopandas
# torchgeo
shellingham==1.5.4
# via typer
six==1.16.0
# via
# junit-xml
# lightly
# opencensus
# python-dateutil
# rfc3339-validator
# rouge-score
# segmentation-models-pytorch
smart-open==7.1.0
# via ray
smmap==5.0.2
# via gitdb
sniffio==1.3.1
# via
# anyio
@ -725,12 +999,22 @@ soundfile==0.12.1
# via
# -r requirements/test.in
# librosa
# mistral-common
soxr==0.5.0.post1
# via librosa
# via
# librosa
# mistral-common
sqlalchemy==2.0.41
# via
# alembic
# mlflow
sqlitedict==2.1.0
# via lm-eval
sqlparse==0.5.3
# via mlflow-skinny
starlette==0.46.2
# via
# fastapi
# schemathesis
# starlette-testclient
starlette-testclient==0.4.1
@ -751,18 +1035,29 @@ tenacity==9.0.0
# via
# lm-eval
# plotly
tensorboardx==2.6.4
# via lightning
tensorizer==2.10.1
# via -r requirements/test.in
terratorch==1.1rc2
# via -r requirements/test.in
threadpoolctl==3.5.0
# via scikit-learn
tifffile==2025.3.30
# via
# scikit-image
# terratorch
tiktoken==0.7.0
# via
# lm-eval
# mistral-common
timm==1.0.11
timm==1.0.15
# via
# -r requirements/test.in
# open-clip-torch
# segmentation-models-pytorch
# terratorch
# torchgeo
tokenizers==0.21.1
# via
# -r requirements/test.in
@ -776,18 +1071,28 @@ torch==2.7.1+cu128
# -r requirements/test.in
# accelerate
# bitsandbytes
# efficientnet-pytorch
# encodec
# fastsafetensors
# kornia
# lightly
# lightning
# lm-eval
# mamba-ssm
# mteb
# open-clip-torch
# peft
# pretrainedmodels
# pytorch-lightning
# runai-model-streamer
# segmentation-models-pytorch
# sentence-transformers
# tensorizer
# terratorch
# timm
# torchaudio
# torchgeo
# torchmetrics
# torchvision
# vector-quantize-pytorch
# vocos
@ -796,22 +1101,40 @@ torchaudio==2.7.1+cu128
# -r requirements/test.in
# encodec
# vocos
torchgeo==0.7.0
# via terratorch
torchmetrics==1.7.4
# via
# lightning
# pytorch-lightning
# terratorch
# torchgeo
torchvision==0.22.1+cu128
# via
# -r requirements/test.in
# lightly
# open-clip-torch
# pretrainedmodels
# segmentation-models-pytorch
# terratorch
# timm
# torchgeo
tqdm==4.66.6
# via
# datasets
# evaluate
# huggingface-hub
# lightly
# lightning
# lm-eval
# mteb
# nltk
# open-clip-torch
# peft
# pqdm
# pretrainedmodels
# pytorch-lightning
# segmentation-models-pytorch
# sentence-transformers
# tqdm-multiprocess
# transformers
@ -843,18 +1166,34 @@ typer==0.15.2
# via fastsafetensors
types-python-dateutil==2.9.0.20241206
# via arrow
typeshed-client==2.8.2
# via jsonargparse
typing-extensions==4.12.2
# via
# albumentations
# alembic
# fastapi
# graphene
# huggingface-hub
# librosa
# lightning
# lightning-utilities
# mistral-common
# mlflow-skinny
# mteb
# opentelemetry-api
# opentelemetry-sdk
# opentelemetry-semantic-conventions
# pqdm
# pydantic
# pydantic-core
# pydantic-extra-types
# pytorch-lightning
# sqlalchemy
# torch
# torchgeo
# typer
# typeshed-client
# typing-inspection
typing-inspection==0.4.1
# via pydantic
@ -866,9 +1205,13 @@ urllib3==2.2.3
# via
# blobfile
# botocore
# docker
# lightly
# requests
# responses
# tritonclient
uvicorn==0.35.0
# via mlflow-skinny
vector-quantize-pytorch==1.21.2
# via -r requirements/test.in
virtualenv==20.31.2
@ -880,11 +1223,15 @@ wcwidth==0.2.13
webcolors==24.11.1
# via jsonschema
werkzeug==3.1.3
# via schemathesis
# via
# flask
# schemathesis
word2number==1.1
# via lm-eval
wrapt==1.17.2
# via smart-open
xarray==2025.7.1
# via rioxarray
xxhash==3.5.0
# via
# datasets
@ -893,5 +1240,7 @@ yarl==1.17.1
# via
# aiohttp
# schemathesis
zipp==3.23.0
# via importlib-metadata
zstandard==0.23.0
# via lm-eval

View File

@ -0,0 +1,350 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test (piecewise) compilation with a simple model where multiple submodules
are compiled and graph captured separately.
"""
import torch
from torch import nn
from torch.library import Library
from vllm.compilation.backends import set_model_tag
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import (ignore_torch_compile,
support_torch_compile)
from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
set_current_vllm_config)
from vllm.envs import VLLM_USE_V1
from vllm.forward_context import set_forward_context
from vllm.utils import direct_register_custom_op
# create a library to hold the custom op
silly_lib = Library("silly", "FRAGMENT") # noqa
BATCH_SIZE = 32
MLP_SIZE = 128
HIDDEN_SIZE = 1024
RANDOM_SEED = 0
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out: torch.Tensor) -> None:
out.copy_(q)
out += k
out += v
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out: torch.Tensor) -> None:
return
direct_register_custom_op(
op_name="attention",
op_func=silly_attention,
mutates_args=["out"],
fake_impl=silly_attention_fake,
target_lib=silly_lib,
)
@support_torch_compile
class ParentModel(nn.Module):
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = '',
**kwargs) -> None:
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x
class Attention(nn.Module):
def __init__(self, mlp_size: int, hidden_size: int) -> None:
super().__init__()
self.pre_attn = nn.Linear(mlp_size, hidden_size, bias=False)
self.post_attn = nn.Linear(hidden_size, mlp_size, bias=False)
self.rms_norm_weight = nn.Parameter(torch.ones(hidden_size))
# Initialize to same weights for testing
nn.init.xavier_normal_(
self.pre_attn.weight.data,
generator=torch.Generator().manual_seed(RANDOM_SEED),
gain=0.001)
nn.init.xavier_normal_(
self.post_attn.weight.data,
generator=torch.Generator().manual_seed(RANDOM_SEED),
gain=0.001)
def rms_norm_ref(self, x: torch.Tensor) -> torch.Tensor:
x_f32 = x.float()
return (x_f32 * torch.rsqrt(
torch.mean(x_f32.square(), dim=-1, keepdim=True) + 1e-6) *
self.rms_norm_weight).to(x.dtype)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.pre_attn(x)
x = self.rms_norm_ref(x)
attn_output = torch.empty_like(x)
torch.ops.silly.attention(x, x, x, attn_output)
x = attn_output
x = self.rms_norm_ref(x)
x = self.post_attn(x)
return x
@support_torch_compile
class CompiledAttention(nn.Module):
def __init__(self,
*,
mlp_size: int,
hidden_size: int,
vllm_config: VllmConfig,
prefix: str = '',
**kwargs) -> None:
super().__init__()
self.attn = Attention(mlp_size, hidden_size)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.attn(x)
@support_torch_compile
class CompiledAttentionTwo(CompiledAttention):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.attn(x) + x
@ignore_torch_compile
class SimpleModelWithTwoGraphs(ParentModel):
def __init__(self,
*,
mlp_size: int,
hidden_size: int,
vllm_config: VllmConfig,
prefix: str = '',
**kwargs) -> None:
super().__init__(vllm_config=vllm_config, prefix=prefix)
# Test will fail without set_model_tag here with error:
# "ValueError: too many values to unpack (expected 3)"
# This is because CompiledAttention and CompiledAttentionTwo
# have different implementations but the same torch.compile
# cache dir will be used as default prefix is 'model_tag'
with set_model_tag("attn_one"):
self.attn_one = CompiledAttention(
mlp_size=mlp_size,
hidden_size=hidden_size,
vllm_config=vllm_config,
prefix=f"{prefix}.attn_one",
)
with set_model_tag("attn_two"):
self.attn_two = CompiledAttentionTwo(
mlp_size=mlp_size,
hidden_size=hidden_size,
vllm_config=vllm_config,
prefix=f"{prefix}.attn_two",
)
self.hidden_states = torch.zeros((BATCH_SIZE, MLP_SIZE)).cuda()
def forward(self, x: torch.Tensor) -> torch.Tensor:
bsz = x.shape[0]
# CUDAGraph expects same tensor addresses for each run
self.hidden_states[:bsz].copy_(x)
x = self.attn_one(self.hidden_states[:bsz])
self.hidden_states[:bsz].copy_(x)
x = self.attn_two(self.hidden_states[:bsz])
return x
def test_ignore_torch_compile_decorator():
assert VLLM_USE_V1
# piecewise
vllm_config = VllmConfig(compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=True,
splitting_ops=["silly.attention"],
cudagraph_capture_sizes=[1, 2],
))
@support_torch_compile
class A(nn.Module):
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = '',
**kwargs) -> None:
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x + x
attn_output = torch.empty_like(x)
torch.ops.silly.attention(x, x, x, attn_output)
x = attn_output
x = x * 3
return x
@ignore_torch_compile
class B(A):
...
@support_torch_compile
class C(B):
...
with set_current_vllm_config(vllm_config):
mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda()
# A has support_torch_compile
with compilation_counter.expect(
num_graphs_seen=1,
num_piecewise_graphs_seen=3,
num_piecewise_capturable_graphs_seen=2,
num_backend_compilations=2,
num_cudagraph_captured=4,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
), set_forward_context({}, vllm_config=vllm_config):
# first run is for compile
mod_A(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
# run cudagraph captured sizes
mod_A(torch.randn(2, MLP_SIZE).cuda())
mod_A(torch.randn(1, MLP_SIZE).cuda())
with set_current_vllm_config(vllm_config):
mod_B = B(vllm_config=vllm_config, prefix='').eval().cuda()
# B's ignore_torch_compile should override A's support_torch_compile
with compilation_counter.expect(
num_graphs_seen=0,
num_piecewise_graphs_seen=0,
num_piecewise_capturable_graphs_seen=0,
num_backend_compilations=0,
num_cudagraph_captured=0,
), set_forward_context({}, vllm_config=vllm_config):
mod_B(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
mod_B(torch.randn(2, MLP_SIZE).cuda())
mod_B(torch.randn(1, MLP_SIZE).cuda())
with set_current_vllm_config(vllm_config):
mod_C = C(vllm_config=vllm_config, prefix='').eval().cuda()
# C's support_torch_compile should override B's ignore_torch_compile
with compilation_counter.expect(
num_graphs_seen=1,
num_piecewise_graphs_seen=3,
num_piecewise_capturable_graphs_seen=2,
num_backend_compilations=2,
num_cudagraph_captured=4,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
), set_forward_context({}, vllm_config=vllm_config):
mod_C(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
mod_C(torch.randn(2, MLP_SIZE).cuda())
mod_C(torch.randn(1, MLP_SIZE).cuda())
@torch.inference_mode
def run_model(vllm_config, model: nn.Module, inputs: torch.Tensor):
with set_forward_context({}, vllm_config=vllm_config):
# First run is for compile
model(inputs)
# Run CUDAGraph captured sizes
model(inputs[:2])
model(inputs[:1])
output = model(inputs[:2])
output = output.cpu()
return output.cpu()
def test_multi_graph_piecewise_compile_outputs_equal():
outputs = []
# piecewise compile
vllm_config = VllmConfig(compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=True,
splitting_ops=["silly.attention"],
cudagraph_capture_sizes=[1, 2],
))
with set_current_vllm_config(vllm_config):
model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE,
hidden_size=HIDDEN_SIZE,
vllm_config=vllm_config,
prefix='').eval().cuda()
# Pre-allocate memory for CUDAGraph which expects
# static tensor addresses
inputs = torch.randn(BATCH_SIZE, MLP_SIZE).cuda()
with compilation_counter.expect(
num_graphs_seen=2, # two graphs for the model
num_piecewise_graphs_seen=6,
# attn_one, attn_two each has 3 piecewise graphs
# (pre attn, post attn, silly_attention) each
num_piecewise_capturable_graphs_seen=4,
# attn_one, attn_two has pre attn and post attn each, total=4
num_backend_compilations=4, # num_piecewise_capturable_graphs_seen
num_cudagraph_captured=8,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
):
outputs.append(run_model(vllm_config, model, inputs))
# no compile or cudagraph
vllm_config = VllmConfig(compilation_config=CompilationConfig(
level=CompilationLevel.NO_COMPILATION, ))
with set_current_vllm_config(vllm_config):
model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE,
hidden_size=HIDDEN_SIZE,
vllm_config=vllm_config,
prefix='').eval().cuda()
with compilation_counter.expect(
num_graphs_seen=0,
num_piecewise_graphs_seen=0,
num_piecewise_capturable_graphs_seen=0,
num_backend_compilations=0,
num_cudagraph_captured=0,
):
outputs.append(run_model(vllm_config, model, inputs))
# piecewise compile without CUDA graph
vllm_config = VllmConfig(compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=False,
splitting_ops=["silly.attention"],
))
with set_current_vllm_config(vllm_config):
model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE,
hidden_size=HIDDEN_SIZE,
vllm_config=vllm_config,
prefix='').eval().cuda()
with compilation_counter.expect(
num_graphs_seen=2,
num_piecewise_graphs_seen=6,
num_piecewise_capturable_graphs_seen=4,
num_backend_compilations=4,
num_cudagraph_captured=0, # no cudagraph captured
):
outputs.append(run_model(vllm_config, model, inputs))
# Generally don't expect outputs with and without inductor
# to be bitwise equivalent
assert torch.allclose(outputs[0], outputs[1])
# Expect bitwise equivalence using inductor w/ and w/o cudagraph
assert torch.equal(outputs[0], outputs[2])

View File

@ -69,8 +69,9 @@ def run_test(more_args):
@pytest.mark.skipif(not current_platform.is_cuda()
and not current_platform.is_tpu(),
reason="V1 currently only supported on CUDA and TPU")
and not current_platform.is_tpu()
and not current_platform.is_xpu(),
reason="V1 currently only supported on CUDA, XPU and TPU")
def test_lm_eval_accuracy_v1_engine(monkeypatch: pytest.MonkeyPatch):
"""Run with the V1 Engine."""

View File

@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# imports for guided decoding tests
import json
import os
import shutil
from tempfile import TemporaryDirectory
from typing import Optional
@ -26,10 +27,6 @@ MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
# technically these adapters use a different base model,
# but we're not testing generation quality here
LORA_NAME = "typeof/zephyr-7b-beta-lora"
PA_NAME = "swapnilbp/llama_tweet_ptune"
# if PA_NAME changes, PA_NUM_VIRTUAL_TOKENS might also
# need to change to match the prompt adapter
PA_NUM_VIRTUAL_TOKENS = 8
GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"]
@ -56,13 +53,7 @@ def zephyr_lora_added_tokens_files(zephyr_lora_files):
@pytest.fixture(scope="module")
def zephyr_pa_files():
return snapshot_download(repo_id=PA_NAME)
@pytest.fixture(scope="module")
def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files,
zephyr_pa_files):
def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files):
return [
# use half precision for speed and memory savings in CI environment
"--dtype",
@ -81,15 +72,6 @@ def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files,
"64",
"--max-cpu-loras",
"2",
# pa config
"--enable-prompt-adapter",
"--prompt-adapters",
f"zephyr-pa={zephyr_pa_files}",
f"zephyr-pa2={zephyr_pa_files}",
"--max-prompt-adapters",
"2",
"--max-prompt-adapter-token",
"128",
]
@ -98,8 +80,19 @@ def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files,
def server(default_server_args, request):
if request.param:
default_server_args.append(request.param)
with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server:
yield remote_server
original_value = os.environ.get('VLLM_USE_V1')
os.environ['VLLM_USE_V1'] = '0'
try:
with RemoteOpenAIServer(MODEL_NAME,
default_server_args) as remote_server:
yield remote_server
finally:
# Restore original env value
if original_value is None:
os.environ.pop('VLLM_USE_V1', None)
else:
os.environ['VLLM_USE_V1'] = original_value
@pytest_asyncio.fixture
@ -110,14 +103,11 @@ async def client(server):
@pytest.mark.asyncio
@pytest.mark.parametrize(
# first test base model, then test loras, then test prompt adapters
"model_name,num_virtual_tokens",
[(MODEL_NAME, 0), ("zephyr-lora", 0), ("zephyr-lora2", 0),
("zephyr-pa", PA_NUM_VIRTUAL_TOKENS),
("zephyr-pa2", PA_NUM_VIRTUAL_TOKENS)],
# first test base model, then test loras
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
)
async def test_single_completion(client: openai.AsyncOpenAI, model_name: str,
num_virtual_tokens: int):
async def test_single_completion(client: openai.AsyncOpenAI, model_name: str):
completion = await client.completions.create(model=model_name,
prompt="Hello, my name is",
max_tokens=5,
@ -130,9 +120,7 @@ async def test_single_completion(client: openai.AsyncOpenAI, model_name: str,
assert len(choice.text) >= 5
assert choice.finish_reason == "length"
assert completion.usage == openai.types.CompletionUsage(
completion_tokens=5,
prompt_tokens=6 + num_virtual_tokens,
total_tokens=11 + num_virtual_tokens)
completion_tokens=5, prompt_tokens=6, total_tokens=11)
# test using token IDs
completion = await client.completions.create(
@ -175,9 +163,9 @@ async def test_added_lora_tokens_base_model(client: openai.AsyncOpenAI):
@pytest.mark.asyncio
@pytest.mark.parametrize(
# first test base model, then test loras, then test prompt adapters
# first test base model, then test loras
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-lora2", "zephyr-pa", "zephyr-pa2"],
[MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
)
async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str):
# test using token IDs
@ -194,9 +182,9 @@ async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str):
@pytest.mark.asyncio
@pytest.mark.parametrize(
# just test 1 lora and 1 pa hereafter
# just test 1 lora
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-pa"],
[MODEL_NAME, "zephyr-lora"],
)
async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str):
# test using token IDs
@ -217,7 +205,7 @@ async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str):
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-pa"],
[MODEL_NAME, "zephyr-lora"],
)
async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str):
# test using token IDs
@ -238,7 +226,7 @@ async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str):
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-pa"],
[MODEL_NAME, "zephyr-lora"],
)
async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
model_name: str):
@ -314,7 +302,7 @@ async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI,
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-pa"],
[MODEL_NAME, "zephyr-lora"],
)
async def test_completion_streaming(client: openai.AsyncOpenAI,
model_name: str):
@ -348,7 +336,7 @@ async def test_completion_streaming(client: openai.AsyncOpenAI,
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-pa"],
[MODEL_NAME, "zephyr-lora"],
)
async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str):
"""Streaming for parallel sampling.
@ -382,7 +370,7 @@ async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str):
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-pa"],
[MODEL_NAME, "zephyr-lora"],
)
async def test_completion_stream_options(client: openai.AsyncOpenAI,
model_name: str):
@ -519,7 +507,7 @@ async def test_completion_stream_options(client: openai.AsyncOpenAI,
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-pa"],
[MODEL_NAME, "zephyr-lora"],
)
async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str):
# test both text and token IDs

View File

@ -13,7 +13,6 @@ from ...utils import RemoteOpenAIServer
from .test_completion import default_server_args # noqa: F401
from .test_completion import zephyr_lora_added_tokens_files # noqa: F401
from .test_completion import zephyr_lora_files # noqa: F401
from .test_completion import zephyr_pa_files # noqa: F401
from .test_completion import MODEL_NAME

View File

@ -32,8 +32,7 @@ async def _async_serving_models_init() -> OpenAIServingModels:
serving_models = OpenAIServingModels(engine_client=mock_engine_client,
base_model_paths=BASE_MODEL_PATHS,
model_config=mock_model_config,
lora_modules=None,
prompt_adapters=None)
lora_modules=None)
await serving_models.init_static_loras()
return serving_models

View File

@ -6,6 +6,10 @@ from collections.abc import Mapping
from typing import Literal, Optional
import pytest
from mistral_common.tokens.tokenizers.base import (SpecialTokenPolicy,
SpecialTokens)
from mistral_common.tokens.tokenizers.tekken import (SpecialTokenInfo,
Tekkenizer)
from vllm.assets.audio import AudioAsset
from vllm.assets.image import ImageAsset
@ -21,6 +25,7 @@ from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.utils import (encode_audio_base64, encode_image_base64,
encode_video_base64)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from ..models.registry import HF_EXAMPLE_MODELS
from ..utils import VLLM_PATH
@ -1374,3 +1379,165 @@ def test_resolve_content_format_examples(template_path, expected_format):
)
assert resolved_format == expected_format
def test_parse_chat_messages_include_thinking_chunk(mistral_model_config,
mistral_tokenizer):
messages = [{
"role":
"system",
"content": [{
"type": "text",
"text": "You are a helpful assistant."
}, {
"type":
"thinking",
"closed":
True,
"thinking":
"Only return the answer when you are confident."
}]
}, {
"role": "user",
"content": "What is 2+2?"
}, {
"role":
"assistant",
"content": [{
"type": "text",
"text": "Let me think about it."
}, {
"type": "thinking",
"closed": True,
"thinking": "2+2 = 4"
}, {
"type": "text",
"text": "The answer is 4.",
}],
}]
conversation_with_thinking, _ = parse_chat_messages(
messages,
mistral_model_config,
mistral_tokenizer,
content_format="openai",
)
expected_conversation = [{
"role":
"system",
"content": [{
"type": "text",
"text": "You are a helpful assistant."
}, {
"type": "text",
"text": "Only return the answer when you are confident."
}],
}, {
"role":
"user",
"content": [{
"type": "text",
"text": "What is 2+2?"
}],
}, {
"role":
"assistant",
"content": [
{
"type": "text",
"text": "Let me think about it."
},
{
"type": "text",
"text": "2+2 = 4"
},
{
"type": "text",
"text": "The answer is 4."
},
]
}]
assert conversation_with_thinking == expected_conversation
def test_apply_mistral_chat_template_thinking_chunk():
# Moved import here to avoid yapf and isort conflicts
from vllm.entrypoints.chat_utils import apply_mistral_chat_template
messages = [{
"role":
"system",
"content": [{
"type": "text",
"text": "You are a helpful assistant."
}, {
"type":
"thinking",
"closed":
True,
"thinking":
"Only return the answer when you are confident."
}]
}, {
"role": "user",
"content": "What is 2+2?"
}, {
"role":
"assistant",
"content": [{
"type": "text",
"text": "Let me think about it."
}, {
"type": "thinking",
"closed": True,
"thinking": "2+2 = 4"
}, {
"type": "text",
"text": "The answer is 4.",
}],
}, {
"role": "user",
"content": "Thanks, what is 3+3?"
}]
# TODO(Julien): upon model release change to a tokenizer already configured.
# =================================================================
mistral_tokenizer = MistralTokenizer.from_pretrained(
"mistralai/Devstral-Small-2507")
assert isinstance(mistral_tokenizer.tokenizer, Tekkenizer)
# Add think special tokens to the tokenizer
mistral_tokenizer.tokenizer._all_special_tokens[35] = SpecialTokenInfo(
rank=35, is_control=True, token_str=SpecialTokens.begin_think.value)
mistral_tokenizer.tokenizer._all_special_tokens[36] = SpecialTokenInfo(
rank=36, is_control=True, token_str=SpecialTokens.end_think.value)
mistral_tokenizer.tokenizer._special_tokens_reverse_vocab = {
k: v
for k, v in
mistral_tokenizer.tokenizer._special_tokens_reverse_vocab.items()
if v not in {35, 36}
}
mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[
SpecialTokens.begin_think.value] = 35
mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[
SpecialTokens.end_think.value] = 36
mistral_tokenizer.instruct.BEGIN_THINK = 35
mistral_tokenizer.instruct.END_THINK = 36
# =================================================================
tokens_ids = apply_mistral_chat_template(mistral_tokenizer,
messages,
chat_template=None,
tools=None)
string_tokens = mistral_tokenizer.mistral.decode(
tokens_ids, special_token_policy=SpecialTokenPolicy.KEEP)
expected_tokens = (
r"<s>[SYSTEM_PROMPT]You are a helpful assistant.[THINK]Only return the"
r" answer when you are confident.[/THINK][/SYSTEM_PROMPT]"
r"[INST]What is 2+2?[/INST]"
r"Let me think about it.[THINK]2+2 = 4[/THINK]The answer is 4.</s>"
r"[INST]Thanks, what is 3+3?[/INST]")
assert string_tokens == expected_tokens

View File

@ -23,6 +23,8 @@ from transformers import (AutoConfig, AutoProcessor, AutoTokenizer,
from vllm import LLM, SamplingParams
from ....utils import multi_gpu_test
# Sample prompts for testing
PROMPTS: list[str] = [
"Hello, my name is",
@ -541,6 +543,7 @@ def run_reduced_model(model_path: str,
print("-" * 40)
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
"original_model_name,text_layers,num_experts,vision_layers,",
[("meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8", 4, 4, 2)])

View File

@ -0,0 +1,63 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from vllm.utils import set_default_torch_num_threads
from ....conftest import VllmRunner
def generate_test_mm_data():
mm_data = {
"pixel_values": torch.full((6, 512, 512), 1.0, dtype=torch.float16),
"location_coords": torch.full((1, 2), 1.0, dtype=torch.float16),
}
return mm_data
def _run_test(
vllm_runner: type[VllmRunner],
model: str,
) -> None:
prompt = [
{
# This model takes no text input
"prompt_token_ids": [1],
"multi_modal_data": generate_test_mm_data(),
} for _ in range(10)
]
with (
set_default_torch_num_threads(1),
vllm_runner(
model,
task="embed",
dtype=torch.float16,
enforce_eager=True,
skip_tokenizer_init=True,
# Limit the maximum number of sequences to avoid the
# test going OOM during the warmup run
max_num_seqs=32,
) as vllm_model,
):
vllm_model.encode(prompt)
MODELS = ["christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"]
@pytest.mark.core_model
@pytest.mark.parametrize("model", MODELS)
def test_models_image(
hf_runner,
vllm_runner,
image_assets,
model: str,
) -> None:
_run_test(
vllm_runner,
model,
)
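For reference, a rough equivalent of the helper above using the public LLM API directly; this is an editorial sketch (the keyword arguments simply mirror the vllm_runner call above and are not otherwise verified):
import torch

from vllm import LLM

def main() -> None:
    # Same dummy inputs as generate_test_mm_data() above.
    mm_data = {
        "pixel_values": torch.full((6, 512, 512), 1.0, dtype=torch.float16),
        "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16),
    }
    llm = LLM(
        model="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
        task="embed",
        dtype="float16",
        enforce_eager=True,
        skip_tokenizer_init=True,  # the model consumes no text input
        max_num_seqs=32,
    )
    # The token id is a placeholder; the model only uses the image tensors.
    outputs = llm.encode({
        "prompt_token_ids": [1],
        "multi_modal_data": mm_data,
    })
    print(f"Got {len(outputs)} embedding outputs")

if __name__ == "__main__":
    main()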

View File

@ -1,48 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import vllm
from vllm.prompt_adapter.request import PromptAdapterRequest
MODEL_PATH = "bigscience/bloomz-560m"
PA_PATH = 'stevhliu/bloomz-560m_PROMPT_TUNING_CAUSAL_LM'
def do_sample(llm, pa_name: str, pa_id: int):
prompts = [
"Tweet text : @nationalgridus I have no water and the bill is \
current and paid. Can you do something about this? Label : ",
"Tweet text : @nationalgridus Looks good thanks! Label : "
]
sampling_params = vllm.SamplingParams(temperature=0.0,
max_tokens=3,
stop_token_ids=[3])
outputs = llm.generate(prompts,
sampling_params,
prompt_adapter_request=PromptAdapterRequest(
pa_name, pa_id, PA_PATH, 8) if pa_id else None)
# Print the outputs.
generated_texts = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip()
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
return generated_texts
@pytest.mark.parametrize("enforce_eager", [True, False])
def test_twitter_prompt_adapter(enforce_eager: bool):
llm = vllm.LLM(MODEL_PATH,
enforce_eager=enforce_eager,
enable_prompt_adapter=True,
max_prompt_adapter_token=8)
expected_output = ['complaint', 'no complaint']
assert do_sample(llm, "twitter_pa", pa_id=1) == expected_output

View File

@ -1,56 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm import EngineArgs, LLMEngine, SamplingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
MODEL_PATH = "bigscience/bloomz-560m"
pa_path = 'stevhliu/bloomz-560m_PROMPT_TUNING_CAUSAL_LM'
pa_path2 = 'swapnilbp/angry_tweet_ptune'
def do_sample(engine):
prompts = [
("Tweet text: I have complaints! Label: ",
SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]),
PromptAdapterRequest("hate_speech", 1, pa_path2, 8)),
("Tweet text: I have no problems Label: ",
SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]),
PromptAdapterRequest("hate_speech2", 2, pa_path2, 8)),
("Tweet text: I have complaints! Label: ",
SamplingParams(temperature=0.0, max_tokens=3), None),
("Tweet text: I have no problems Label: ",
SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]),
PromptAdapterRequest("complain", 3, pa_path, 8)),
]
request_id = 0
results = set()
while prompts or engine.has_unfinished_requests():
if prompts:
prompt, sampling_params, pa_request = prompts.pop(0)
engine.add_request(str(request_id),
prompt,
sampling_params,
prompt_adapter_request=pa_request)
request_id += 1
request_outputs = engine.step()
for request_output in request_outputs:
if request_output.finished:
results.add(request_output.outputs[0].text)
return results
def test_multi_prompt_adapters():
engine_args = EngineArgs(model=MODEL_PATH,
max_prompt_adapters=3,
enable_prompt_adapter=True,
max_prompt_adapter_token=8)
engine = LLMEngine.from_engine_args(engine_args)
expected_output = {
' quot;I', 'hate speech', 'no complaint', 'not hate speech'
}
assert do_sample(engine) == expected_output

View File

@ -1,64 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from huggingface_hub import snapshot_download
from vllm import EngineArgs, LLMEngine, SamplingParams
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
MODEL_PATH = "meta-llama/Llama-2-7b-hf"
pa_path = snapshot_download(repo_id="swapnilbp/llama_tweet_ptune")
lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
def do_sample(engine):
prompt_text = "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]" # noqa: E501
# first prompt with a prompt adapter and second without adapter
prompts = [
(prompt_text,
SamplingParams(temperature=0.0, max_tokens=100,
stop=["[/assistant]"]),
PromptAdapterRequest("hate_speech", 1, pa_path,
8), LoRARequest("sql_test", 1, lora_path)),
(prompt_text,
SamplingParams(temperature=0.0, max_tokens=100,
stop=["[/assistant]"]), None,
LoRARequest("sql_test", 1, lora_path)),
]
request_id = 0
results = set()
while prompts or engine.has_unfinished_requests():
if prompts:
prompt, sampling_params, pa_request, lora_request = prompts.pop(0)
engine.add_request(str(request_id),
prompt,
sampling_params,
prompt_adapter_request=pa_request,
lora_request=lora_request)
request_id += 1
request_outputs = engine.step()
for request_output in request_outputs:
if request_output.finished:
results.add(request_output.outputs[0].text)
return results
def test_lora_prompt_adapter():
engine_args = EngineArgs(model=MODEL_PATH,
enable_prompt_adapter=True,
enable_lora=True,
max_num_seqs=60,
max_prompt_adapter_token=8)
engine = LLMEngine.from_engine_args(engine_args)
result = do_sample(engine)
expected_output = {
" SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' " # noqa: E501
}
assert result == expected_output

View File

@ -0,0 +1,341 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from mistral_common.tokens.tokenizers.base import SpecialTokens
from mistral_common.tokens.tokenizers.tekken import (SpecialTokenInfo,
Tekkenizer)
from tests.reasoning.utils import run_reasoning_extraction_mistral
from vllm.reasoning import ReasoningParser, ReasoningParserManager
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
parser_name = "mistral"
@pytest.fixture(scope="module")
def mistral_tokenizer():
# TODO(Julien): upon model release change to a tokenizer already configured.
# =================================================================
mistral_tokenizer = MistralTokenizer.from_pretrained(
"mistralai/Devstral-Small-2507")
assert isinstance(mistral_tokenizer.tokenizer, Tekkenizer)
# Add think special tokens to the tokenizer
mistral_tokenizer.tokenizer._all_special_tokens[35] = SpecialTokenInfo(
rank=35, is_control=True, token_str=SpecialTokens.begin_think.value)
mistral_tokenizer.tokenizer._all_special_tokens[36] = SpecialTokenInfo(
rank=36, is_control=True, token_str=SpecialTokens.end_think.value)
mistral_tokenizer.tokenizer._special_tokens_reverse_vocab = {
k: v
for k, v in
mistral_tokenizer.tokenizer._special_tokens_reverse_vocab.items()
if v not in {35, 36}
}
mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[
SpecialTokens.begin_think.value] = 35
mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[
SpecialTokens.end_think.value] = 36
mistral_tokenizer.instruct.BEGIN_THINK = 35
mistral_tokenizer.instruct.END_THINK = 36
# =================================================================
return mistral_tokenizer
SIMPLE_REASONING = {
"output": "This is a reasoning section[/THINK]This is the rest",
"reasoning_content": "This is a reasoning section",
"content": "This is the rest",
"is_reasoning_end": True,
}
COMPLETE_REASONING = {
"output": "This is a reasoning section[/THINK]",
"reasoning_content": "This is a reasoning section",
"content": None,
"is_reasoning_end": True,
}
NO_CONTENT = {
"output": "This is content",
"reasoning_content": "This is content",
"content": None,
"is_reasoning_end": False,
}
NO_REASONING_STREAMING = {
"output": "This is a reasoning section",
"reasoning_content": "This is a reasoning section",
"content": None,
"is_reasoning_end": False,
}
MULTIPLE_LINES = {
"output": "This\nThat[/THINK]This is the rest\nThat",
"reasoning_content": "This\nThat",
"content": "This is the rest\nThat",
"is_reasoning_end": True,
}
SHORTEST_REASONING_NO_STREAMING = {
"output": "[/THINK]This is the rest",
"reasoning_content": "",
"content": "This is the rest",
"is_reasoning_end": True,
}
SHORTEST_REASONING = {
"output": "[/THINK]This is the rest",
"reasoning_content": None,
"content": "This is the rest",
"is_reasoning_end": True,
}
REASONING_WITH_THINK = {
"output": "[THINK]This is a reasoning section[/THINK]This is the rest",
"reasoning_content": "This is a reasoning section",
"content": "This is the rest",
"is_reasoning_end": True,
}
COMPLETE_REASONING_WITH_THINK = {
"output": "[THINK]This is a reasoning section[/THINK]",
"reasoning_content": "This is a reasoning section",
"content": None,
"is_reasoning_end": True,
}
MULTIPLE_LINES_WITH_THINK = {
"output": "[THINK]This\nThat[/THINK]This is the rest\nThat",
"reasoning_content": "This\nThat",
"content": "This is the rest\nThat",
"is_reasoning_end": True,
}
SHORTEST_REASONING_NO_STREAMING_WITH_THINK = {
"output": "[/THINK]This is the rest",
"reasoning_content": "",
"content": "This is the rest",
"is_reasoning_end": True,
}
SHORTEST_REASONING_WITH_THINK = {
"output": "[/THINK]This is the rest",
"reasoning_content": None,
"content": "This is the rest",
"is_reasoning_end": True,
}
THINK_NO_END = {
"output": "[THINK]This is a reasoning section",
"reasoning_content": "This is a reasoning section",
"content": None,
"is_reasoning_end": False,
}
EMPTY = {
"output": "",
"reasoning_content": "",
"content": None,
"is_reasoning_end": False,
}
EMPTY_STREAMING = {
"output": "",
"reasoning_content": None,
"content": None,
"is_reasoning_end": False,
}
NEW_LINE = {
"output": "\n[THINK]This is a reasoning section[/THINK]\nThis is the rest",
"reasoning_content": "This is a reasoning section",
"content": "\nThis is the rest",
"is_reasoning_end": True,
}
# Streaming cannot handle new lines at the beginning of the output
# because we need to support [THINK]...[/THINK] and [/THINK]...
# We cannot know if the text before [THINK] is reasoning content
# or not.
NEW_LINE_STREAMING = {
"output": "\n[THINK]This is a reasoning section[/THINK]\nThis is the rest",
"reasoning_content": "\nThis is a reasoning section",
"content": "\nThis is the rest",
"is_reasoning_end": True,
}
TEST_CASES = [
pytest.param(
False,
SIMPLE_REASONING,
id="simple_reasoning",
),
pytest.param(
True,
SIMPLE_REASONING,
id="simple_reasoning_streaming",
),
pytest.param(
False,
COMPLETE_REASONING,
id="complete_reasoning",
),
pytest.param(
True,
COMPLETE_REASONING,
id="complete_reasoning_streaming",
),
pytest.param(
False,
NO_CONTENT,
id="no_content_token",
),
pytest.param(
True,
NO_REASONING_STREAMING,
id="no_reasoning_token_streaming",
),
pytest.param(
False,
MULTIPLE_LINES,
id="multiple_lines",
),
pytest.param(
True,
MULTIPLE_LINES,
id="multiple_lines_streaming",
),
pytest.param(
True,
SHORTEST_REASONING,
id="shortest",
),
pytest.param(
False,
SHORTEST_REASONING_NO_STREAMING,
id="shortest_streaming",
),
pytest.param(
False,
REASONING_WITH_THINK,
id="reasoning_with_think",
),
pytest.param(
True,
REASONING_WITH_THINK,
id="reasoning_with_think_streaming",
),
pytest.param(
False,
COMPLETE_REASONING_WITH_THINK,
id="complete_reasoning_with_think",
),
pytest.param(
True,
COMPLETE_REASONING_WITH_THINK,
id="complete_reasoning_with_think_streaming",
),
pytest.param(
False,
MULTIPLE_LINES_WITH_THINK,
id="multiple_lines_with_think",
),
pytest.param(
True,
MULTIPLE_LINES_WITH_THINK,
id="multiple_lines_with_think_streaming",
),
pytest.param(
False,
SHORTEST_REASONING_NO_STREAMING_WITH_THINK,
id="shortest_with_think",
),
pytest.param(
True,
SHORTEST_REASONING_WITH_THINK,
id="shortest_with_think_streaming",
),
pytest.param(
False,
THINK_NO_END,
id="think_no_end",
),
pytest.param(
True,
THINK_NO_END,
id="think_no_end_streaming",
),
pytest.param(
False,
EMPTY,
id="empty",
),
pytest.param(
True,
EMPTY_STREAMING,
id="empty_streaming",
),
pytest.param(
False,
NEW_LINE,
id="new_line",
),
pytest.param(
True,
NEW_LINE_STREAMING,
id="new_line_streaming",
),
]
@pytest.mark.parametrize("streaming, param_dict", TEST_CASES)
def test_mistral_reasoning(
streaming: bool,
param_dict: dict,
mistral_tokenizer: MistralTokenizer,
):
output = param_dict["output"]
index_think = output.find("[THINK]")
len_think = len("[THINK]")
index_end_think = output.find("[/THINK]")
len_end_think = len("[/THINK]")
# encode everything to tokens ids
output_tokens = []
if index_think != -1:
output_before_think = output[:index_think]
output_tokens += mistral_tokenizer.tokenizer.encode(
output_before_think, False, False)
output_tokens += [mistral_tokenizer.instruct.BEGIN_THINK]
if index_end_think != -1:
output_middle = output[index_think + len_think:index_end_think]
output_after_think = output[index_end_think + len_end_think:]
output_tokens += mistral_tokenizer.tokenizer.encode(
output_middle, False, False)
output_tokens += [mistral_tokenizer.instruct.END_THINK]
output_tokens += mistral_tokenizer.tokenizer.encode(
output_after_think, False, False)
else:
output_middle = output[index_think + len_think:]
output_tokens += mistral_tokenizer.tokenizer.encode(
output_middle, False, False)
elif index_end_think != -1:
output_before_think = output[:index_end_think]
output_after_think = output[index_end_think + len_end_think:]
output_tokens += mistral_tokenizer.tokenizer.encode(
output_before_think, False, False)
output_tokens += [mistral_tokenizer.instruct.END_THINK]
output_tokens += mistral_tokenizer.tokenizer.encode(
output_after_think, False, False)
else:
output_tokens += mistral_tokenizer.tokenizer.encode(
output, False, False)
parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(
parser_name)(mistral_tokenizer)
reasoning, content = run_reasoning_extraction_mistral(parser,
output_tokens,
streaming=streaming)
assert reasoning == param_dict["reasoning_content"]
assert content == param_dict["content"]
# Test is_reasoning_end
is_reasoning_end = parser.is_reasoning_end(output_tokens)
assert is_reasoning_end == param_dict["is_reasoning_end"]
# Test extract_content
if param_dict["content"] is not None:
content = parser.extract_content_ids(output_tokens)
assert content == mistral_tokenizer.tokenizer.encode(
param_dict["content"], bos=False, eos=False)
else:
content = parser.extract_content_ids(output_tokens)
assert content == []

View File

@ -6,6 +6,7 @@ from typing import Optional, Union
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage)
from vllm.reasoning import ReasoningParser
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
class StreamingReasoningReconstructor:
@ -54,6 +55,32 @@ def run_reasoning_extraction(
return reasoning, content
def run_reasoning_extraction_mistral(
reasoning_parser: ReasoningParser,
model_output: list[int],
request: Union[ChatCompletionRequest, None] = None,
streaming: bool = False,
) -> tuple[Optional[str], Optional[str]]:
assert isinstance(reasoning_parser.model_tokenizer,
MistralTokenizer), type(reasoning_parser.model_tokenizer)
if streaming:
reconstructor = run_reasoning_extraction_streaming_mistral(
reasoning_parser,
model_output,
request,
)
return (
reconstructor.reasoning_content,
reconstructor.other_content or None,
)
else:
str_output = reasoning_parser.model_tokenizer.convert_ids_to_tokens(
model_output)
reasoning, content = run_reasoning_extraction_nonstreaming(
reasoning_parser, str_output, request)
return reasoning, content
def run_reasoning_extraction_nonstreaming(
reasoning_parser: ReasoningParser,
model_output: list[str],
@ -94,3 +121,35 @@ def run_reasoning_extraction_streaming(
previous_text = current_text
previous_tokens = current_tokens
return reconstructor
def run_reasoning_extraction_streaming_mistral(
reasoning_parser: ReasoningParser,
model_deltas: list[int],
request: Union[ChatCompletionRequest, None] = None,
) -> StreamingReasoningReconstructor:
assert isinstance(reasoning_parser.model_tokenizer,
MistralTokenizer), type(reasoning_parser.model_tokenizer)
request = request or ChatCompletionRequest(messages=[], model="test-model")
reconstructor = StreamingReasoningReconstructor()
previous_text = ""
previous_tokens: list[int] = []
for model_delta in model_deltas:
token_delta = [model_delta]
delta = reasoning_parser.model_tokenizer.convert_ids_to_tokens(
[model_delta])[0]
current_text = previous_text + delta
current_tokens = previous_tokens + token_delta
delta_message = reasoning_parser.extract_reasoning_content_streaming(
previous_text,
current_text,
delta,
previous_tokens,
current_tokens,
token_delta,
)
if delta_message is not None:
reconstructor.append_delta(delta_message)
previous_text = current_text
previous_tokens = current_tokens
return reconstructor

View File

@ -565,8 +565,8 @@ def test_engine_core_proc_instantiation_cuda_empty(
from vllm.v1.engine.utils import EngineZmqAddresses
def mock_startup_handshake(self, handshake_socket, on_head_node,
parallel_config):
def mock_startup_handshake(self, handshake_socket, local_client,
headless, parallel_config):
return EngineZmqAddresses(inputs=["tcp://127.0.0.1:5555"],
outputs=["tcp://127.0.0.1:5556"],
coordinator_input=None,

View File

@ -0,0 +1,352 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import os
import threading
import time
from contextlib import AsyncExitStack
import openai # use the official client for correctness check
import pytest
import pytest_asyncio
from tests.utils import RemoteOpenAIServer
from tests.v1.test_utils import check_request_balancing
from vllm.platforms import Platform
MODEL_NAME = "ibm-research/PowerMoE-3b"
# Number of data parallel ranks for hybrid LB testing (4 total)
DP_SIZE = int(os.getenv("DP_SIZE", "4"))
# Default tensor parallel size to use
TP_SIZE = int(os.getenv("TP_SIZE", "1"))
# Number of nodes (2 nodes, each with 2 DP ranks)
NUM_NODES = 2
DP_SIZE_LOCAL = DP_SIZE // NUM_NODES # 2 ranks per node
class HybridLBServerManager:
"""Manages hybrid data parallel vLLM server instances where each node
runs a single logical API server that balances requests only to the
DP engines running on that same node."""
def __init__(self,
model_name: str,
dp_size: int,
api_server_count: int,
base_server_args: list,
dp_size_local: int = DP_SIZE_LOCAL,
tp_size: int = TP_SIZE):
self.model_name = model_name
self.dp_size = dp_size
self.dp_size_local = dp_size_local
self.tp_size = tp_size
self.api_server_count = api_server_count
self.base_server_args = base_server_args
self.servers: list[tuple[RemoteOpenAIServer, list[str]]] = []
self.server_threads: list[threading.Thread] = []
self.num_nodes = dp_size // dp_size_local
def __enter__(self) -> list[tuple[RemoteOpenAIServer, list[str]]]:
"""Start all server instances for hybrid LB mode."""
for node_id in range(self.num_nodes):
# Create server args for this specific node
server_args = self.base_server_args.copy()
# Calculate start rank for this node
start_rank = node_id * self.dp_size_local
# Add hybrid LB specific arguments
server_args.extend([
"--data-parallel-size",
str(self.dp_size),
"--data-parallel-size-local",
str(self.dp_size_local),
"--data-parallel-start-rank",
str(start_rank),
"--data-parallel-hybrid-lb", # Enable hybrid LB mode
"--tensor-parallel-size",
str(self.tp_size),
"--port",
str(8000 + node_id), # Different port for each node
"--api-server-count",
str(self.api_server_count),
"--data-parallel-address",
"127.0.0.1",
"--data-parallel-rpc-port",
"13345",
])
# Use a thread to start each server to allow parallel initialization
def start_server(node: int, sargs: list[str]):
try:
# Calculate GPU devices for this node
gpus_per_node = self.dp_size_local * self.tp_size
gpu_start = node * gpus_per_node
gpu_end = gpu_start + gpus_per_node
# Start the server
server = RemoteOpenAIServer(
self.model_name,
sargs,
auto_port=False,
env_dict={
"CUDA_VISIBLE_DEVICES":
",".join(
str(Platform.device_id_to_physical_device_id(
i)) for i in range(gpu_start, gpu_end))
})
server.__enter__()
print(f"Hybrid LB node {node} started successfully with "
f"{self.dp_size_local} local DP ranks and "
f"{self.api_server_count} API servers")
self.servers.append((server, sargs))
except Exception as e:
print(f"Failed to start hybrid LB node {node}: {e}")
raise
thread = threading.Thread(target=start_server,
args=(node_id, server_args))
thread.start()
self.server_threads.append(thread)
# Wait for all servers to start
for thread in self.server_threads:
thread.join()
# Give servers additional time to fully initialize and coordinate
time.sleep(3)
if len(self.servers) != self.num_nodes:
raise Exception("Servers failed to start")
return self.servers
def __exit__(self, exc_type, exc_val, exc_tb):
"""Stop all server instances."""
while self.servers:
try:
self.servers.pop()[0].__exit__(exc_type, exc_val, exc_tb)
except Exception as e:
print(f"Error stopping server: {e}")
@pytest.fixture(scope="module")
def default_server_args():
return [
# use half precision for speed and memory savings in CI environment
"--dtype",
"bfloat16",
"--max-model-len",
"2048",
"--max-num-seqs",
"128",
"--enforce-eager",
]
@pytest.fixture(scope="module", params=[1]) # Only 1 API server for now
def servers(request, default_server_args):
api_server_count = request.param
with HybridLBServerManager(MODEL_NAME, DP_SIZE, api_server_count,
default_server_args, DP_SIZE_LOCAL,
TP_SIZE) as server_list:
yield server_list
@pytest_asyncio.fixture
async def clients(servers: list[tuple[RemoteOpenAIServer, list[str]]]):
# Create a client for each node (each node has its own API endpoint)
async with AsyncExitStack() as stack:
yield [
await stack.enter_async_context(server.get_async_client())
for server, _ in servers
]
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
)
async def test_hybrid_lb_completion(clients: list[openai.AsyncOpenAI],
servers: list[tuple[RemoteOpenAIServer,
list[str]]],
model_name: str) -> None:
async def make_request(client: openai.AsyncOpenAI):
completion = await client.completions.create(
model=model_name,
prompt="Hello, my name is",
max_tokens=10,
temperature=1.0)
assert completion.id is not None
assert completion.choices is not None and len(completion.choices) == 1
choice = completion.choices[0]
# The exact number of tokens can vary slightly with temperature=1.0,
# so we check for a reasonable minimum length.
assert len(choice.text) >= 1
# Finish reason might not always be 'length' if the model finishes early
# or due to other reasons, especially with high temperature.
# So, we'll accept 'length' or 'stop'.
assert choice.finish_reason in ("length", "stop")
# Token counts can also vary, so we check they are positive.
assert completion.usage.completion_tokens > 0
assert completion.usage.prompt_tokens > 0
assert completion.usage.total_tokens > 0
return completion
# Test single request to each node
for i, client in enumerate(clients):
result = await make_request(client)
assert result is not None
print(
f"Hybrid LB node {i} handled single completion request successfully"
)
await asyncio.sleep(0.5)
# Send requests to all nodes - each should balance within its local DP ranks
num_requests_per_node = 25 # Total 50 requests across 2 nodes
all_tasks = []
for i, client in enumerate(clients):
tasks = [make_request(client) for _ in range(num_requests_per_node)]
all_tasks.extend(tasks)
results = await asyncio.gather(*all_tasks)
assert len(results) == num_requests_per_node * len(clients)
assert all(completion is not None for completion in results)
await asyncio.sleep(0.5)
# Second burst of requests
all_tasks = []
for i, client in enumerate(clients):
tasks = [make_request(client) for _ in range(num_requests_per_node)]
all_tasks.extend(tasks)
results = await asyncio.gather(*all_tasks)
assert len(results) == num_requests_per_node * len(clients)
assert all(completion is not None for completion in results)
_, server_args = servers[0]
api_server_count = (
server_args.count('--api-server-count')
and server_args[server_args.index('--api-server-count') + 1] or 1)
print(
f"Successfully completed hybrid LB test with {len(clients)} nodes "
f"({DP_SIZE_LOCAL} DP ranks each, API server count: {api_server_count})"
)
# Check request balancing within each node
for i, (server, _) in enumerate(servers):
print(f"Checking request balancing for node {i}")
check_request_balancing(server, DP_SIZE_LOCAL)
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
)
async def test_hybrid_lb_completion_streaming(clients: list[
openai.AsyncOpenAI], servers: list[tuple[RemoteOpenAIServer, list[str]]],
model_name: str) -> None:
prompt = "What is an LLM?"
async def make_streaming_request(client: openai.AsyncOpenAI):
# Perform a non-streaming request to get the expected full output
single_completion = await client.completions.create(
model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
)
single_output = single_completion.choices[0].text
# Perform the streaming request
stream = await client.completions.create(model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
stream=True)
chunks: list[str] = []
finish_reason_count = 0
last_chunk = None
async for chunk in stream:
chunks.append(chunk.choices[0].text)
if chunk.choices[0].finish_reason is not None:
finish_reason_count += 1
last_chunk = chunk # Keep track of the last chunk
# finish reason should only return in the last block for OpenAI API
assert finish_reason_count == 1, (
"Finish reason should appear exactly once.")
assert last_chunk is not None, (
"Stream should have yielded at least one chunk.")
assert last_chunk.choices[
0].finish_reason == "length", "Finish reason should be 'length'."
# Check that the combined text matches the non-streamed version.
assert "".join(
chunks
) == single_output, "Streamed output should match non-streamed output."
return True # Indicate success for this request
# Test single request to each node
for i, client in enumerate(clients):
result = await make_streaming_request(client)
assert result is not None
print(
f"Hybrid LB node {i} handled single streaming request successfully"
)
await asyncio.sleep(0.5)
# Send streaming requests to all nodes
num_requests_per_node = 25 # Total 50 requests across 2 nodes
all_tasks = []
for i, client in enumerate(clients):
tasks = [
make_streaming_request(client)
for _ in range(num_requests_per_node)
]
all_tasks.extend(tasks)
results = await asyncio.gather(*all_tasks)
assert len(results) == num_requests_per_node * len(clients)
assert all(results), "Not all streaming requests completed successfully."
await asyncio.sleep(0.5)
# Second burst of streaming requests
all_tasks = []
for i, client in enumerate(clients):
tasks = [
make_streaming_request(client)
for _ in range(num_requests_per_node)
]
all_tasks.extend(tasks)
results = await asyncio.gather(*all_tasks)
assert len(results) == num_requests_per_node * len(clients)
assert all(results), "Not all streaming requests completed successfully."
_, server_args = servers[0]
api_server_count = (
server_args.count('--api-server-count')
and server_args[server_args.index('--api-server-count') + 1] or 1)
print(f"Successfully completed hybrid LB streaming test with "
f"{len(clients)} nodes ({DP_SIZE_LOCAL} DP ranks each, "
f"API server count: {api_server_count})")
# Check request balancing within each node
for i, (server, _) in enumerate(servers):
print(f"Checking streaming request balancing for node {i}")
check_request_balancing(server, DP_SIZE_LOCAL)

View File

@ -460,11 +460,16 @@ def test_load_model_weights_inplace(dist_init, model_runner, model_runner_2):
{"load_config": {
"load_format": original_load_format
}})
model_runner_2.load_model() # Load real weights inplace
model_runner_2.reload_weights() # Load real weights inplace
assert str(model_runner.get_model().state_dict()) == str(
model_runner_2.get_model().state_dict())
def test_reload_weights_before_load_model(model_runner):
with pytest.raises(AssertionError):
model_runner.reload_weights()
def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
torch.set_default_dtype(torch.float16)
layer_0 = "model.layers.0.self_attn.attn"

View File

@ -31,6 +31,5 @@ run_mypy vllm/inputs
run_mypy vllm/lora
run_mypy vllm/model_executor
run_mypy vllm/plugins
run_mypy vllm/prompt_adapter
run_mypy vllm/worker
run_mypy vllm/v1

View File

@ -143,6 +143,8 @@ class Attention(nn.Module):
# the backends)
if envs.VLLM_USE_V1:
self.use_irope = extra_impl_args.pop("use_irope", False)
else:
self.use_irope = extra_impl_args.get("use_irope", False)
quant_method = quant_config.get_quant_method(
self, prefix=prefix) if quant_config else None
@ -177,7 +179,6 @@ class Attention(nn.Module):
kv_sharing_target_layer_name, **extra_impl_args)
self.backend = backend_name_to_enum(attn_backend.get_name())
self.dtype = dtype
self.use_irope = extra_impl_args.get("use_irope", False)
# For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
# torch.compile works by registering the attention as one giant

View File

@ -423,6 +423,12 @@ class InductorAdaptor(CompilerInterface):
if is_torch_equal_or_newer("2.6"):
stack.enter_context(
torch._inductor.config.patch(fx_graph_remote_cache=False))
# InductorAdaptor (unfortunately) requires AOTAutogradCache
# to be turned off to run. It will fail to acquire the hash_str
# and error if not.
# StandaloneInductorAdaptor (PyTorch 2.8+) fixes this problem.
stack.enter_context(
torch._functorch.config.patch(enable_autograd_cache=False))
stack.enter_context(
torch._functorch.config.patch(
enable_remote_autograd_cache=False))

View File

@ -20,9 +20,38 @@ from .monitor import start_monitoring_torch_compile
logger = init_logger(__name__)
IGNORE_COMPILE_KEY = "_ignore_compile_vllm"
_T = TypeVar("_T", bound=type[nn.Module])
def ignore_torch_compile(cls: _T) -> _T:
"""
A decorator to ignore the support_torch_compile decorator
on the class. This is useful when a parent class has
a support_torch_compile decorator, but we don't want to
compile the class `cls` that inherits the parent class.
This only ignores compiling the forward of the class the
decorator is applied to.
If the parent has ignore_torch_compile but the child has
support_torch_compile, the child will still be compiled.
If the class has one or more submodules
that have the support_torch_compile decorator applied, compilation will
not be ignored for those submodules.
"""
setattr(cls, IGNORE_COMPILE_KEY, True)
return cls
def _should_ignore_torch_compile(cls) -> bool:
"""
Check if the class should be ignored for torch.compile.
"""
return getattr(cls, IGNORE_COMPILE_KEY, False)
@overload
def support_torch_compile(
*,
@ -148,6 +177,8 @@ def _support_torch_compile(
old_init = cls.__init__
setattr(cls, IGNORE_COMPILE_KEY, False)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
self.vllm_config = vllm_config
@ -156,9 +187,11 @@ def _support_torch_compile(
self.do_not_compile = \
vllm_config.compilation_config.level in [
CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS
] or not supports_dynamo()
] or not supports_dynamo() or _should_ignore_torch_compile(
self.__class__)
if self.do_not_compile:
return
compilation_counter.num_models_seen += 1
TorchCompileWrapperWithCustomDispatcher.__init__(
self, compilation_level=vllm_config.compilation_config.level)
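An editorial sketch (hypothetical classes, not part of this diff) of the inheritance semantics described in the ignore_torch_compile docstring above: the parent opts into compilation, the child opts out, and a grandchild can opt back in.
import torch
from torch import nn

from vllm.compilation.decorators import (ignore_torch_compile,
                                         support_torch_compile)
from vllm.config import VllmConfig


@support_torch_compile
class Parent(nn.Module):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x


@ignore_torch_compile
class Child(Parent):
    # Inherits Parent's forward, but this class's forward is not compiled.
    pass


@support_torch_compile
class GrandChild(Child):
    # Compiled again, even though its parent class ignores compilation.

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1

# Instantiation requires a VllmConfig, e.g. Parent(vllm_config=VllmConfig()).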

View File

@ -651,6 +651,8 @@ class ModelConfig:
self.original_max_model_len = self.max_model_len
self.max_model_len = self.get_and_verify_max_len(self.max_model_len)
self.multimodal_config = self._init_multimodal_config()
self.model_supports_multimodal_raw_input = (
self.registry.supports_multimodal_raw_input(self.architectures))
if not self.skip_tokenizer_init:
self._verify_tokenizer_mode()
@ -1243,10 +1245,10 @@ class ModelConfig:
return self.get_hf_config_sliding_window()
def get_vocab_size(self) -> int:
return self.hf_text_config.vocab_size
return getattr(self.hf_text_config, "vocab_size", 0)
def get_hidden_size(self) -> int:
return self.hf_text_config.hidden_size
return getattr(self.hf_text_config, "hidden_size", 0)
@property
def is_deepseek_mla(self) -> bool:
@ -1906,8 +1908,16 @@ class ParallelConfig:
"""Backend to use for data parallel, either "mp" or "ray"."""
data_parallel_external_lb: bool = False
"""Whether to use "external" DP LB mode. Applies only to online serving
and when data_parallel_size > 0. Set implicitly when
data_parallel_rank is provided explicitly to vllm serve."""
and when data_parallel_size > 0. This is useful for a "one-pod-per-rank"
wide-EP setup in Kubernetes. Set implicitly when --data-parallel-rank
is provided explicitly to vllm serve."""
data_parallel_hybrid_lb: bool = False
"""Whether to use "hybrid" DP LB mode. Applies only to online serving
and when data_parallel_size > 0. Enables running an AsyncLLM
and API server on a "per-node" basis where vLLM load balances
between local data parallel ranks, but an external LB balances
between vLLM nodes/replicas. Set explicitly in conjunction with
--data-parallel-start-rank."""
enable_expert_parallel: bool = False
"""Use expert parallelism instead of tensor parallelism for MoE layers."""
enable_eplb: bool = False
@ -3141,59 +3151,6 @@ class LoRAConfig:
self.lora_dtype = getattr(torch, self.lora_dtype)
@config
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class PromptAdapterConfig:
"""Configuration for PromptAdapters."""
max_prompt_adapters: int = 1
"""Max number of PromptAdapters in a batch."""
max_prompt_adapter_token: int = 0
"""Max number of PromptAdapters tokens."""
max_cpu_prompt_adapters: Optional[int] = None
"""Maximum number of PromptAdapters to store in CPU memory. Must be >= than
`max_prompt_adapters`."""
prompt_adapter_dtype: Union[torch.dtype, str] = "auto"
"""Data type for PromptAdapter. If auto, will default to base model dtype.
"""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
# no factors to consider.
# this config will not affect the computation graph.
factors: list[Any] = []
hash_str = hashlib.md5(str(factors).encode(),
usedforsecurity=False).hexdigest()
return hash_str
def __post_init__(self):
if self.max_prompt_adapters < 1:
raise ValueError(f"max_prompt_adapters "
f"({self.max_prompt_adapters}) must be >= 1.")
if self.max_prompt_adapter_token == 0:
raise ValueError("max_prompt_adapter_token must be set.")
if self.max_cpu_prompt_adapters is None:
self.max_cpu_prompt_adapters = self.max_prompt_adapters
def verify_with_model_config(self, model_config: ModelConfig):
if self.prompt_adapter_dtype == "auto":
self.prompt_adapter_dtype = model_config.dtype
elif isinstance(self.prompt_adapter_dtype, str):
self.prompt_adapter_dtype = getattr(torch,
self.prompt_adapter_dtype)
@config
@dataclass
class MultiModalConfig:
@ -4400,8 +4357,6 @@ class VllmConfig:
"""Decoding configuration."""
observability_config: Optional[ObservabilityConfig] = None
"""Observability configuration."""
prompt_adapter_config: Optional[PromptAdapterConfig] = None
"""Prompt adapter configuration."""
quant_config: Optional[QuantizationConfig] = None
"""Quantization configuration."""
compilation_config: CompilationConfig = field(
@ -4498,10 +4453,6 @@ class VllmConfig:
vllm_factors.append(self.observability_config.compute_hash())
else:
vllm_factors.append("None")
if self.prompt_adapter_config:
vllm_factors.append(self.prompt_adapter_config.compute_hash())
else:
vllm_factors.append("None")
if self.quant_config:
pass # should be captured by model_config.quantization
if self.compilation_config:
@ -4609,9 +4560,6 @@ class VllmConfig:
if self.lora_config is not None:
self.lora_config.verify_with_cache_config(self.cache_config)
self.lora_config.verify_with_model_config(self.model_config)
if self.prompt_adapter_config is not None:
self.prompt_adapter_config.verify_with_model_config(
self.model_config)
if self.quant_config is None and self.model_config is not None:
self.quant_config = VllmConfig._get_quantization_config(

View File

@ -15,7 +15,6 @@ from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
SequenceGroupBase, SequenceGroupMetadata,
SequenceGroupMetadataDelta, SequenceStage,
@ -165,8 +164,6 @@ class SchedulerOutputs:
if self.num_loras > 0:
self._sort_by_lora_ids()
self.num_prompt_adapters: int = len(self.prompt_adapter_requests)
def is_empty(self) -> bool:
# NOTE: We do not consider the ignored sequence groups.
return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
@ -194,14 +191,6 @@ class SchedulerOutputs:
if g.seq_group.lora_request is not None
}
@property
def prompt_adapter_requests(self) -> Set[PromptAdapterRequest]:
return {
g.seq_group.prompt_adapter_request
for g in self.scheduled_seq_groups
if g.seq_group.prompt_adapter_request is not None
}
@dataclass
class SchedulerRunningOutputs:
@ -1648,7 +1637,6 @@ class Scheduler:
multi_modal_placeholders=(
seq_group.multi_modal_placeholders
if scheduler_outputs.num_prefill_groups > 0 else None),
prompt_adapter_request=seq_group.prompt_adapter_request,
)
else:
# When SPMD mode is enabled, we only send delta data except for

View File

@ -30,9 +30,9 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
LogprobsMode, LoRAConfig, ModelConfig, ModelDType,
ModelImpl, MultiModalConfig, ObservabilityConfig,
ParallelConfig, PoolerConfig, PrefixCachingHashAlgo,
PromptAdapterConfig, SchedulerConfig, SchedulerPolicy,
SpeculativeConfig, TaskOption, TokenizerMode,
VllmConfig, get_attr_docs, get_field)
SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
TaskOption, TokenizerMode, VllmConfig, get_attr_docs,
get_field)
from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform
from vllm.plugins import load_general_plugins
@ -295,9 +295,11 @@ class EngineArgs:
tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
data_parallel_size: int = ParallelConfig.data_parallel_size
data_parallel_rank: Optional[int] = None
data_parallel_start_rank: Optional[int] = None
data_parallel_size_local: Optional[int] = None
data_parallel_address: Optional[str] = None
data_parallel_rpc_port: Optional[int] = None
data_parallel_hybrid_lb: bool = False
data_parallel_backend: str = ParallelConfig.data_parallel_backend
enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
enable_eplb: bool = ParallelConfig.enable_eplb
@ -358,11 +360,6 @@ class EngineArgs:
max_cpu_loras: Optional[int] = LoRAConfig.max_cpu_loras
lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype
lora_extra_vocab_size: int = LoRAConfig.lora_extra_vocab_size
# PromptAdapter fields
enable_prompt_adapter: bool = False
max_prompt_adapters: int = PromptAdapterConfig.max_prompt_adapters
max_prompt_adapter_token: int = \
PromptAdapterConfig.max_prompt_adapter_token
num_scheduler_steps: int = SchedulerConfig.num_scheduler_steps
multi_step_stream_outputs: bool = SchedulerConfig.multi_step_stream_outputs
@ -437,6 +434,8 @@ class EngineArgs:
ParallelConfig.enable_multimodal_encoder_data_parallel
async_scheduling: bool = SchedulerConfig.async_scheduling
# DEPRECATED
enable_prompt_adapter: bool = False
def __post_init__(self):
# support `EngineArgs(compilation_config={...})`
@ -607,6 +606,11 @@ class EngineArgs:
type=int,
help='Data parallel rank of this instance. '
'When set, enables external load balancer mode.')
parallel_group.add_argument('--data-parallel-start-rank',
'-dpr',
type=int,
help='Starting data parallel rank '
'for secondary nodes.')
parallel_group.add_argument('--data-parallel-size-local',
'-dpl',
type=int,
@ -628,6 +632,9 @@ class EngineArgs:
default='mp',
help='Backend for data parallel, either '
'"mp" or "ray".')
parallel_group.add_argument(
"--data-parallel-hybrid-lb",
**parallel_kwargs["data_parallel_hybrid_lb"])
parallel_group.add_argument(
"--enable-expert-parallel",
**parallel_kwargs["enable_expert_parallel"])
@ -729,23 +736,6 @@ class EngineArgs:
lora_group.add_argument("--default-mm-loras",
**lora_kwargs["default_mm_loras"])
# PromptAdapter related configs
prompt_adapter_kwargs = get_kwargs(PromptAdapterConfig)
prompt_adapter_group = parser.add_argument_group(
title="PromptAdapterConfig",
description=PromptAdapterConfig.__doc__,
)
prompt_adapter_group.add_argument(
"--enable-prompt-adapter",
action=argparse.BooleanOptionalAction,
help="If True, enable handling of PromptAdapters.")
prompt_adapter_group.add_argument(
"--max-prompt-adapters",
**prompt_adapter_kwargs["max_prompt_adapters"])
prompt_adapter_group.add_argument(
"--max-prompt-adapter-token",
**prompt_adapter_kwargs["max_prompt_adapter_token"])
# Speculative arguments
speculative_group = parser.add_argument_group(
title="SpeculativeConfig",
@ -850,6 +840,12 @@ class EngineArgs:
parser.add_argument('--disable-log-stats',
action='store_true',
help='Disable logging statistics.')
parser.add_argument('--enable-prompt-adapter',
action='store_true',
deprecated=True,
help='[DEPRECATED] Prompt adapter has been '
'removed. Setting this flag to True or False'
' has no effect on vLLM behavior.')
return parser
@ -986,6 +982,7 @@ class EngineArgs:
def create_engine_config(
self,
usage_context: Optional[UsageContext] = None,
headless: bool = False,
) -> VllmConfig:
"""
Create the VllmConfig.
@ -1074,15 +1071,41 @@ class EngineArgs:
# but we should not do this here.
placement_group = ray.util.get_current_placement_group()
assert not headless or not self.data_parallel_hybrid_lb, (
"data_parallel_hybrid_lb is not applicable in "
"headless mode")
data_parallel_external_lb = self.data_parallel_rank is not None
# Local DP rank = 1, use pure-external LB.
if data_parallel_external_lb:
assert self.data_parallel_size_local in (1, None), (
"data_parallel_size_local must be 1 when data_parallel_rank "
"is set")
data_parallel_size_local = 1
# Use full external lb if we have local_size of 1.
self.data_parallel_hybrid_lb = False
elif self.data_parallel_size_local is not None:
data_parallel_size_local = self.data_parallel_size_local
if self.data_parallel_start_rank and not headless:
# Infer hybrid LB mode.
self.data_parallel_hybrid_lb = True
if self.data_parallel_hybrid_lb and data_parallel_size_local == 1:
# Use full external lb if we have local_size of 1.
data_parallel_external_lb = True
self.data_parallel_hybrid_lb = False
if data_parallel_size_local == self.data_parallel_size:
# Disable hybrid LB mode if set for a single node
self.data_parallel_hybrid_lb = False
self.data_parallel_rank = self.data_parallel_start_rank or 0
else:
assert not self.data_parallel_hybrid_lb, (
"data_parallel_size_local must be set to use "
"data_parallel_hybrid_lb.")
# Local DP size defaults to global DP size if not set.
data_parallel_size_local = self.data_parallel_size
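As an editorial worked example of the branch above (illustrative values, not from the diff): launching the second of two nodes with --data-parallel-size 4, --data-parallel-size-local 2 and --data-parallel-start-rank 2, not headless, takes the elif path, infers data_parallel_hybrid_lb=True, keeps it because the local size (2) differs from the global size (4), and sets data_parallel_rank to 2. A minimal sketch of the same settings:
from vllm.engine.arg_utils import EngineArgs

# Hypothetical node 1 of 2 (2 local DP ranks per node, 4 DP ranks total).
args = EngineArgs(
    model="ibm-research/PowerMoE-3b",  # model name borrowed from the test above
    data_parallel_size=4,
    data_parallel_size_local=2,
    data_parallel_start_rank=2,
)
# args.create_engine_config(headless=False) would then resolve to hybrid LB
# mode with data_parallel_rank = 2 for this node, per the logic shown above.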
@ -1139,6 +1162,7 @@ class EngineArgs:
data_parallel_master_ip=data_parallel_address,
data_parallel_rpc_port=data_parallel_rpc_port,
data_parallel_backend=self.data_parallel_backend,
data_parallel_hybrid_lb=self.data_parallel_hybrid_lb,
enable_expert_parallel=self.enable_expert_parallel,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.num_redundant_experts,
@ -1234,11 +1258,6 @@ class EngineArgs:
load_config = self.create_load_config()
prompt_adapter_config = PromptAdapterConfig(
max_prompt_adapters=self.max_prompt_adapters,
max_prompt_adapter_token=self.max_prompt_adapter_token) \
if self.enable_prompt_adapter else None
decoding_config = DecodingConfig(
backend=self.guided_decoding_backend,
disable_fallback=self.guided_decoding_disable_fallback,
@ -1266,7 +1285,6 @@ class EngineArgs:
load_config=load_config,
decoding_config=decoding_config,
observability_config=observability_config,
prompt_adapter_config=prompt_adapter_config,
compilation_config=self.compilation_config,
kv_transfer_config=self.kv_transfer_config,
kv_events_config=self.kv_events_config,
@ -1342,12 +1360,6 @@ class EngineArgs:
recommend_to_remove=False)
return False
# No Prompt Adapter so far.
if self.enable_prompt_adapter:
_raise_or_fallback(feature_name="--enable-prompt-adapter",
recommend_to_remove=False)
return False
# No text embedding inputs so far.
if self.enable_prompt_embeds:
_raise_or_fallback(feature_name="--enable-prompt-embeds",
@ -1469,7 +1481,6 @@ class EngineArgs:
if (is_gpu and not use_sliding_window and not use_spec_decode
and not self.enable_lora
and not self.enable_prompt_adapter
and model_config.runner_type != "pooling"):
self.enable_chunked_prefill = True
logger.warning(

View File

@ -29,7 +29,6 @@ from vllm.model_executor.guided_decoding import (
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import ExecuteModelRequest
from vllm.transformers_utils.tokenizer import AnyTokenizer
@ -435,7 +434,6 @@ class _AsyncLLMEngine(LLMEngine):
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None,
@ -468,7 +466,6 @@ class _AsyncLLMEngine(LLMEngine):
processed_inputs = await self.input_preprocessor.preprocess_async(
prompt,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
tokenization_kwargs=tokenization_kwargs,
)
@ -491,7 +488,6 @@ class _AsyncLLMEngine(LLMEngine):
params=params,
arrival_time=arrival_time,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers,
priority=priority,
)
@ -861,7 +857,6 @@ class AsyncLLMEngine(EngineClient):
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None,
@ -889,7 +884,6 @@ class AsyncLLMEngine(EngineClient):
arrival_time=arrival_time or time.time(),
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
data_parallel_rank=data_parallel_rank,
tokenization_kwargs=tokenization_kwargs,
@ -904,7 +898,6 @@ class AsyncLLMEngine(EngineClient):
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
) -> AsyncGenerator[RequestOutput, None]:
@ -922,8 +915,6 @@ class AsyncLLMEngine(EngineClient):
request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any.
trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request to use
for generation, if any.
priority: The priority of the request.
Only applicable with priority scheduling.
data_parallel_rank: The (global) data parallel rank that must
@ -983,7 +974,6 @@ class AsyncLLMEngine(EngineClient):
sampling_params,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
data_parallel_rank=data_parallel_rank,
):

View File

@ -44,7 +44,6 @@ from vllm.multimodal.processing import EncDecMultiModalProcessor
from vllm.outputs import (PoolingRequestOutput, RequestOutput,
RequestOutputFactory)
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup,
PoolingSequenceGroupOutput, Sequence, SequenceGroup,
@ -223,7 +222,6 @@ class LLMEngine:
self.load_config = vllm_config.load_config
self.decoding_config = vllm_config.decoding_config or DecodingConfig( # noqa
)
self.prompt_adapter_config = vllm_config.prompt_adapter_config # noqa
self.observability_config = vllm_config.observability_config or ObservabilityConfig( # noqa
)
@ -238,14 +236,14 @@ class LLMEngine:
self.log_stats = log_stats
self.use_cached_outputs = use_cached_outputs
if not self.model_config.skip_tokenizer_init:
self.tokenizer = self._init_tokenizer()
self.detokenizer = Detokenizer(self.tokenizer)
tokenizer_group = self.get_tokenizer_group()
else:
if self.model_config.skip_tokenizer_init:
self.tokenizer = None
self.detokenizer = None
tokenizer_group = None
else:
self.tokenizer = self._init_tokenizer()
self.detokenizer = Detokenizer(self.tokenizer)
tokenizer_group = self.get_tokenizer_group()
# Ensure that the function doesn't contain a reference to self,
# to avoid engine GC issues
@ -294,8 +292,6 @@ class LLMEngine:
# Feature flags
"enable_lora":
bool(self.lora_config),
"enable_prompt_adapter":
bool(self.prompt_adapter_config),
"enable_prefix_caching":
self.cache_config.enable_prefix_caching,
"enforce_eager":
@ -542,9 +538,6 @@ class LLMEngine:
self.lora_config.verify_with_model_config(self.model_config)
self.lora_config.verify_with_scheduler_config(
self.scheduler_config)
if self.prompt_adapter_config:
self.prompt_adapter_config.verify_with_model_config(
self.model_config)
def _add_processed_request(
self,
@ -553,7 +546,6 @@ class LLMEngine:
params: Union[SamplingParams, PoolingParams],
arrival_time: float,
lora_request: Optional[LoRARequest],
prompt_adapter_request: Optional[PromptAdapterRequest],
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
) -> Optional[SequenceGroup]:
@ -569,7 +561,6 @@ class LLMEngine:
arrival_time=arrival_time,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
)
return None
@ -583,11 +574,10 @@ class LLMEngine:
encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id,
lora_request, prompt_adapter_request)
lora_request)
encoder_seq = (None if encoder_inputs is None else Sequence(
seq_id, encoder_inputs, block_size, eos_token_id, lora_request,
prompt_adapter_request))
seq_id, encoder_inputs, block_size, eos_token_id, lora_request))
# Create a SequenceGroup based on SamplingParams or PoolingParams
if isinstance(params, SamplingParams):
@ -598,7 +588,6 @@ class LLMEngine:
arrival_time=arrival_time,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
encoder_seq=encoder_seq,
priority=priority)
elif isinstance(params, PoolingParams):
@ -608,7 +597,6 @@ class LLMEngine:
params,
arrival_time=arrival_time,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
encoder_seq=encoder_seq,
priority=priority)
else:
@ -637,7 +625,6 @@ class LLMEngine:
lora_request: Optional[LoRARequest] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> None:
"""Add a request to the engine's request pool.
@ -658,7 +645,6 @@ class LLMEngine:
the current monotonic time.
lora_request: The LoRA request to add.
trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: The prompt adapter request to add.
priority: The priority of the request.
Only applicable with priority scheduling.
@ -719,7 +705,6 @@ class LLMEngine:
prompt,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
)
self._add_processed_request(
@ -728,7 +713,6 @@ class LLMEngine:
params=params,
arrival_time=arrival_time,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers,
priority=priority,
)
@ -741,7 +725,6 @@ class LLMEngine:
arrival_time: float,
lora_request: Optional[LoRARequest],
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
encoder_seq: Optional[Sequence] = None,
priority: int = 0,
) -> SequenceGroup:
@ -769,17 +752,15 @@ class LLMEngine:
if self.vllm_config.speculative_config is not None:
draft_size = \
self.vllm_config.speculative_config.num_speculative_tokens + 1
seq_group = SequenceGroup(
request_id=request_id,
seqs=[seq],
arrival_time=arrival_time,
sampling_params=sampling_params,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
encoder_seq=encoder_seq,
priority=priority,
draft_size=draft_size)
seq_group = SequenceGroup(request_id=request_id,
seqs=[seq],
arrival_time=arrival_time,
sampling_params=sampling_params,
lora_request=lora_request,
trace_headers=trace_headers,
encoder_seq=encoder_seq,
priority=priority,
draft_size=draft_size)
return seq_group
@ -790,7 +771,6 @@ class LLMEngine:
pooling_params: PoolingParams,
arrival_time: float,
lora_request: Optional[LoRARequest],
prompt_adapter_request: Optional[PromptAdapterRequest],
encoder_seq: Optional[Sequence] = None,
priority: int = 0,
) -> SequenceGroup:
@ -798,15 +778,13 @@ class LLMEngine:
# Defensive copy of PoolingParams, which are used by the pooler
pooling_params = pooling_params.clone()
# Create the sequence group.
seq_group = SequenceGroup(
request_id=request_id,
seqs=[seq],
arrival_time=arrival_time,
lora_request=lora_request,
pooling_params=pooling_params,
prompt_adapter_request=prompt_adapter_request,
encoder_seq=encoder_seq,
priority=priority)
seq_group = SequenceGroup(request_id=request_id,
seqs=[seq],
arrival_time=arrival_time,
lora_request=lora_request,
pooling_params=pooling_params,
encoder_seq=encoder_seq,
priority=priority)
return seq_group
def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
@ -1834,16 +1812,6 @@ class LLMEngine:
def pin_lora(self, lora_id: int) -> bool:
return self.model_executor.pin_lora(lora_id)
def add_prompt_adapter(
self, prompt_adapter_request: PromptAdapterRequest) -> bool:
return self.model_executor.add_prompt_adapter(prompt_adapter_request)
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
return self.model_executor.remove_prompt_adapter(prompt_adapter_id)
def list_prompt_adapters(self) -> List[int]:
return self.model_executor.list_prompt_adapters()
def start_profile(self) -> None:
self.model_executor.start_profile()
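As a sanity check on the trimmed LLMEngine.add_request() signature, a minimal sketch (not part of the diff; illustrative model id, assumes vLLM is installed):

from vllm import EngineArgs, LLMEngine, SamplingParams

engine = LLMEngine.from_engine_args(EngineArgs(model="facebook/opt-125m"))
engine.add_request(
    request_id="0",
    prompt="The capital of France is",
    params=SamplingParams(max_tokens=8),
    # lora_request, trace_headers and priority are still accepted;
    # prompt_adapter_request has been removed.
)
while engine.has_unfinished_requests():
    for request_output in engine.step():
        if request_output.finished:
            print(request_output.outputs[0].text)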

View File

@ -10,7 +10,6 @@ from vllm import PoolingParams
from vllm.inputs import PromptType
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.utils import Device
@ -33,7 +32,6 @@ class RPCProcessRequest:
request_id: str
lora_request: Optional[LoRARequest] = None
trace_headers: Optional[Mapping[str, str]] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None
priority: int = 0
def __init__(
@ -43,7 +41,6 @@ class RPCProcessRequest:
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> None:
super().__init__()
@ -53,7 +50,6 @@ class RPCProcessRequest:
self.request_id = request_id
self.lora_request = lora_request
self.trace_headers = trace_headers
self.prompt_adapter_request = prompt_adapter_request
self.priority = priority

View File

@ -45,7 +45,6 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.utils import Device
@ -448,7 +447,6 @@ class MQLLMEngineClient(EngineClient):
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> AsyncGenerator[RequestOutput, None]:
"""Generate outputs for a request.
@ -465,8 +463,6 @@ class MQLLMEngineClient(EngineClient):
request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any.
trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request to use
for generation, if any.
priority: Priority of the request (lower means earlier handling).
Any priority other than 0 will lead to an error if the
scheduling policy is not "priority".
@ -474,8 +470,7 @@ class MQLLMEngineClient(EngineClient):
return cast(
AsyncGenerator[RequestOutput, None],
self._process_request(prompt, sampling_params, request_id,
lora_request, trace_headers,
prompt_adapter_request, priority))
lora_request, trace_headers, priority))
def encode(
self,
@ -521,7 +516,6 @@ class MQLLMEngineClient(EngineClient):
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
PoolingRequestOutput, None]]:
@ -575,7 +569,6 @@ class MQLLMEngineClient(EngineClient):
request_id=request_id,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
))

View File

@ -304,14 +304,12 @@ class MQLLMEngine:
self._send_outputs(rpc_err)
try:
self.engine.add_request(
request_id=request_id,
prompt=request.prompt,
params=request.params,
lora_request=request.lora_request,
trace_headers=request.trace_headers,
prompt_adapter_request=request.prompt_adapter_request,
priority=request.priority)
self.engine.add_request(request_id=request_id,
prompt=request.prompt,
params=request.params,
lora_request=request.lora_request,
trace_headers=request.trace_headers,
priority=request.priority)
if self.log_requests:
logger.info("Added request %s.", request.request_id)

View File

@ -16,7 +16,6 @@ from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import Device, collect_from_async_generator, random_uuid
@ -55,7 +54,6 @@ class EngineClient(ABC):
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> AsyncGenerator[RequestOutput, None]:
"""Generate outputs for a request."""

View File

@ -151,6 +151,27 @@ class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
video_url: Required[str]
class CustomThinkCompletionContentParam(TypedDict, total=False):
"""A Think Completion Content Param that accepts a plain text and a boolean.
Example:
{
"thinking": "I am thinking about the answer",
"closed": True,
"type": "thinking"
}
"""
thinking: Required[str]
"""The thinking content."""
closed: bool
"""Whether the thinking is closed."""
type: Required[Literal["thinking"]]
"""The thinking type."""
ChatCompletionContentPartParam: TypeAlias = Union[
OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam,
ChatCompletionContentPartInputAudioParam,
@ -159,7 +180,8 @@ ChatCompletionContentPartParam: TypeAlias = Union[
CustomChatCompletionContentSimpleImageParam,
ChatCompletionContentPartImageEmbedsParam,
CustomChatCompletionContentSimpleAudioParam,
CustomChatCompletionContentSimpleVideoParam, str]
CustomChatCompletionContentSimpleVideoParam, str,
CustomThinkCompletionContentParam]
class CustomChatCompletionMessageParam(TypedDict, total=False):
@ -938,6 +960,7 @@ _ImageEmbedsParser = partial(cast, ChatCompletionContentPartImageEmbedsParam)
_InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
_PILImageParser = partial(cast, CustomChatCompletionContentPILImageParam)
_ThinkParser = partial(cast, CustomThinkCompletionContentParam)
# Need to validate url objects
_ImageParser = TypeAdapter(ChatCompletionContentPartImageParam).validate_python
_AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam).validate_python
@ -954,6 +977,8 @@ MM_PARSER_MAP: dict[
] = {
"text":
lambda part: _TextParser(part).get("text", None),
"thinking":
lambda part: _ThinkParser(part).get("thinking", None),
"input_text":
lambda part: _TextParser(part).get("text", None),
"input_image":
@ -1100,7 +1125,7 @@ def _parse_chat_message_content_part(
"with empty / unparsable content.", part, part_type)
return None
if part_type in ("text", "input_text", "refusal"):
if part_type in ("text", "input_text", "refusal", "thinking"):
str_content = cast(str, content)
if wrap_dicts:
return {'type': 'text', 'text': str_content}
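A small sketch (not part of the diff) of the new content part as a client might send it; per the parser mapping above, only the "thinking" field is extracted and the part is then handled like plain text:

# Hypothetical chat payload illustrating the new "thinking" part;
# the "closed" flag is optional per the TypedDict above.
messages = [{
    "role": "assistant",
    "content": [{
        "type": "thinking",
        "thinking": "I am thinking about the answer",
        "closed": True,
    }],
}]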

View File

@ -45,11 +45,6 @@ class ServeSubcommand(CLISubcommand):
if args.headless or args.api_server_count < 1:
run_headless(args)
else:
if args.data_parallel_start_rank:
raise ValueError(
"data_parallel_start_rank is only applicable "
"in headless mode. "
"Add --headless flag to enable headless mode.")
if args.api_server_count > 1:
run_multi_api_server(args)
else:
@ -86,13 +81,14 @@ def run_headless(args: argparse.Namespace):
# Create the EngineConfig.
engine_args = vllm.AsyncEngineArgs.from_cli_args(args)
usage_context = UsageContext.OPENAI_API_SERVER
vllm_config = engine_args.create_engine_config(usage_context=usage_context)
vllm_config = engine_args.create_engine_config(usage_context=usage_context,
headless=True)
if not envs.VLLM_USE_V1:
raise ValueError("Headless mode is only supported for V1")
if engine_args.data_parallel_rank is not None:
raise ValueError("data_parallel_rank is not applicable in "
if engine_args.data_parallel_hybrid_lb:
raise ValueError("data_parallel_hybrid_lb is not applicable in "
"headless mode")
parallel_config = vllm_config.parallel_config
@ -122,7 +118,7 @@ def run_headless(args: argparse.Namespace):
engine_manager = CoreEngineProcManager(
target_fn=EngineCoreProc.run_engine_core,
local_engine_count=local_engine_count,
start_index=args.data_parallel_start_rank,
start_index=vllm_config.parallel_config.data_parallel_rank,
local_start_index=0,
vllm_config=vllm_config,
local_client=False,
@ -169,6 +165,11 @@ def run_multi_api_server(args: argparse.Namespace):
" api_server_count > 1")
model_config.disable_mm_preprocessor_cache = True
if vllm_config.parallel_config.data_parallel_hybrid_lb:
raise NotImplementedError(
"Hybrid load balancing with --api-server-count > 0"
"is not yet supported.")
executor_class = Executor.get_class(vllm_config)
log_stats = not engine_args.disable_log_stats

View File

@ -45,7 +45,6 @@ from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
PoolingRequestOutput, RequestOutput,
ScoringRequestOutput)
from vllm.pooling_params import PoolingParams, PoolingTask
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
RequestOutputKind, SamplingParams)
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
@ -314,7 +313,6 @@ class LLM:
*,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None,
) -> list[RequestOutput]:
@ -330,7 +328,6 @@ class LLM:
prompt_token_ids: Optional[list[int]] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None,
) -> list[RequestOutput]:
@ -346,7 +343,6 @@ class LLM:
prompt_token_ids: Optional[list[list[int]]] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None,
) -> list[RequestOutput]:
@ -363,7 +359,6 @@ class LLM:
prompt_token_ids: list[int],
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None,
) -> list[RequestOutput]:
@ -380,7 +375,6 @@ class LLM:
prompt_token_ids: list[list[int]],
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None,
) -> list[RequestOutput]:
@ -395,7 +389,6 @@ class LLM:
prompt_token_ids: Union[list[int], list[list[int]]],
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None,
) -> list[RequestOutput]:
@ -415,7 +408,6 @@ class LLM:
prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None,
priority: Optional[list[int]] = None,
@ -440,8 +432,6 @@ class LLM:
it is used to create the progress bar.
If `False`, no progress bar is created.
lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for
generation, if any.
priority: The priority of the requests, if any.
Only applicable when priority scheduling policy is enabled.
@ -507,7 +497,6 @@ class LLM:
params=sampling_params,
use_tqdm=use_tqdm,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
guided_options=guided_options_request,
tokenization_kwargs=tokenization_kwargs,
priority=priority,
@ -963,7 +952,6 @@ class LLM:
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
pooling_task: PoolingTask = "encode",
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[PoolingRequestOutput]:
@ -980,7 +968,6 @@ class LLM:
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
pooling_task: PoolingTask = "encode",
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[PoolingRequestOutput]:
@ -997,7 +984,6 @@ class LLM:
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
pooling_task: PoolingTask = "encode",
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[PoolingRequestOutput]:
@ -1015,7 +1001,6 @@ class LLM:
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
pooling_task: PoolingTask = "encode",
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[PoolingRequestOutput]:
@ -1033,7 +1018,6 @@ class LLM:
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
pooling_task: PoolingTask = "encode",
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[PoolingRequestOutput]:
@ -1049,7 +1033,6 @@ class LLM:
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
pooling_task: PoolingTask = "encode",
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[PoolingRequestOutput]:
@ -1070,7 +1053,6 @@ class LLM:
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
pooling_task: PoolingTask = "encode",
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[PoolingRequestOutput]:
@ -1092,8 +1074,6 @@ class LLM:
it is used to create the progress bar.
If `False`, no progress bar is created.
lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for
generation, if any.
pooling_task: Override the pooling task to use.
Returns:
@ -1150,7 +1130,6 @@ class LLM:
use_tqdm=use_tqdm,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
prompt_adapter_request=prompt_adapter_request,
)
outputs = self._run_engine(use_tqdm=use_tqdm)
@ -1167,7 +1146,6 @@ class LLM:
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> list[EmbeddingRequestOutput]:
"""
Generate an embedding vector for each prompt.
@ -1187,8 +1165,6 @@ class LLM:
it is used to create the progress bar.
If `False`, no progress bar is created.
lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for
generation, if any.
Returns:
A list of `EmbeddingRequestOutput` objects containing the
@ -1205,7 +1181,6 @@ class LLM:
use_tqdm=use_tqdm,
pooling_params=pooling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
pooling_task="embed",
)
@ -1218,7 +1193,6 @@ class LLM:
*,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> list[ClassificationRequestOutput]:
"""
Generate class logits for each prompt.
@ -1236,8 +1210,6 @@ class LLM:
it is used to create the progress bar.
If `False`, no progress bar is created.
lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for
generation, if any.
Returns:
A list of `ClassificationRequestOutput` objects containing the
@ -1253,7 +1225,6 @@ class LLM:
prompts,
use_tqdm=use_tqdm,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
pooling_task="classify",
)
@ -1267,7 +1238,6 @@ class LLM:
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> list[ScoringRequestOutput]:
encoded_output: list[PoolingRequestOutput] = self.encode(
@ -1275,7 +1245,6 @@ class LLM:
truncate_prompt_tokens=truncate_prompt_tokens,
use_tqdm=use_tqdm,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
pooling_task="embed",
)
@ -1303,7 +1272,6 @@ class LLM:
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> list[ScoringRequestOutput]:
if isinstance(tokenizer, MistralTokenizer):
@ -1361,7 +1329,6 @@ class LLM:
params=pooling_params,
use_tqdm=use_tqdm,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
)
outputs = self._run_engine(use_tqdm=use_tqdm)
@ -1381,7 +1348,6 @@ class LLM:
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> list[ScoringRequestOutput]:
"""Generate similarity scores for all pairs `<text,text_pair>` or
`<multi-modal data, multi-modal data pair>`.
@ -1412,8 +1378,6 @@ class LLM:
it is used to create the progress bar.
If `False`, no progress bar is created.
lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for
generation, if any.
Returns:
A list of `ScoringRequestOutput` objects containing the
@ -1504,8 +1468,7 @@ class LLM:
data_2, # type: ignore[arg-type]
truncate_prompt_tokens,
use_tqdm,
lora_request,
prompt_adapter_request)
lora_request)
else:
return self._embedding_score(
tokenizer,
@ -1513,8 +1476,7 @@ class LLM:
data_2, # type: ignore[arg-type]
truncate_prompt_tokens,
use_tqdm,
lora_request,
prompt_adapter_request)
lora_request)
def start_profile(self) -> None:
self.llm_engine.start_profile()
@ -1625,7 +1587,6 @@ class LLM:
*,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
prompt_adapter_request: Optional[PromptAdapterRequest],
tokenization_kwargs: Optional[dict[str, Any]] = None,
guided_options: Optional[GuidedDecodingRequest] = None,
priority: Optional[list[int]] = None,
@ -1671,7 +1632,6 @@ class LLM:
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request[i] if isinstance(
lora_request, Sequence) else lora_request,
prompt_adapter_request=prompt_adapter_request,
priority=priority[i] if priority else 0,
)
@ -1681,7 +1641,6 @@ class LLM:
params: Union[SamplingParams, PoolingParams],
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> None:
request_id = str(next(self.request_counter))
@ -1691,7 +1650,6 @@ class LLM:
params,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
)
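To illustrate the user-facing effect on the pooling entrypoints, a minimal sketch (not part of the diff; the embedding model id is illustrative and assumed to be supported):

from vllm import LLM

llm = LLM(model="BAAI/bge-small-en-v1.5", task="embed")  # illustrative model id
# lora_request is still accepted by embed/classify/score;
# prompt_adapter_request no longer exists on any of them.
outputs = llm.embed(["prompt adapters were removed from the offline API"])
print(len(outputs[0].outputs.embedding))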

View File

@ -8,7 +8,6 @@ import torch
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import BeamSearchParams, SamplingParams
logger = init_logger(__name__)
@ -30,7 +29,6 @@ class RequestLogger:
params: Optional[Union[SamplingParams, PoolingParams,
BeamSearchParams]],
lora_request: Optional[LoRARequest],
prompt_adapter_request: Optional[PromptAdapterRequest],
) -> None:
max_log_len = self.max_log_len
if max_log_len is not None:
@ -44,7 +42,6 @@ class RequestLogger:
"Received request %s: prompt: %r, "
"params: %s, prompt_token_ids: %s, "
"prompt_embeds shape: %s, "
"lora_request: %s, prompt_adapter_request: %s.", request_id,
prompt, params, prompt_token_ids,
"lora_request: %s.", request_id, prompt, params, prompt_token_ids,
prompt_embeds.shape if prompt_embeds is not None else None,
lora_request, prompt_adapter_request)
lora_request)

View File

@ -1620,7 +1620,6 @@ async def init_app_state(
model_config=model_config,
base_model_paths=base_model_paths,
lora_modules=lora_modules,
prompt_adapters=args.prompt_adapters,
)
await state.openai_serving_models.init_static_loras()
state.openai_serving_responses = OpenAIServingResponses(

View File

@ -20,8 +20,7 @@ from vllm.config import config
from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
validate_chat_template)
from vllm.entrypoints.openai.serving_models import (LoRAModulePath,
PromptAdapterPath)
from vllm.entrypoints.openai.serving_models import LoRAModulePath
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.logger import init_logger
from vllm.utils import FlexibleArgumentParser
@ -65,27 +64,6 @@ class LoRAParserAction(argparse.Action):
setattr(namespace, self.dest, lora_list)
class PromptAdapterParserAction(argparse.Action):
def __call__(
self,
parser: argparse.ArgumentParser,
namespace: argparse.Namespace,
values: Optional[Union[str, Sequence[str]]],
option_string: Optional[str] = None,
):
if values is None:
values = []
if isinstance(values, str):
raise TypeError("Expected values to be a list")
adapter_list: list[PromptAdapterPath] = []
for item in values:
name, path = item.split('=')
adapter_list.append(PromptAdapterPath(name, path))
setattr(namespace, self.dest, adapter_list)
@config
@dataclass
class FrontendArgs:
@ -115,9 +93,6 @@ class FrontendArgs:
or JSON list format. Example (old format): `'name=path'` Example (new
format): `{\"name\": \"name\", \"path\": \"lora_path\",
\"base_model_name\": \"id\"}`"""
prompt_adapters: Optional[list[PromptAdapterPath]] = None
"""Prompt adapter configurations in the format name=path. Multiple adapters
can be specified."""
chat_template: Optional[str] = None
"""The file path to the chat template, or the template in single-line form
for the specified model."""
@ -207,12 +182,6 @@ schema. Example: `[{"type": "text", "text": "Hello world!"}]`"""
frontend_kwargs["lora_modules"]["type"] = optional_type(str)
frontend_kwargs["lora_modules"]["action"] = LoRAParserAction
# Special case: Prompt adapters need custom parser action and
# optional_type(str)
frontend_kwargs["prompt_adapters"]["type"] = optional_type(str)
frontend_kwargs["prompt_adapters"][
"action"] = PromptAdapterParserAction
# Special case: Middleware needs append action
frontend_kwargs["middleware"]["action"] = "append"
frontend_kwargs["middleware"]["type"] = str
@ -253,13 +222,6 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
default=False,
help="Run in headless mode. See multi-node data parallel "
"documentation for more details.")
parser.add_argument(
"--data-parallel-start-rank",
"-dpr",
type=int,
default=0,
help="Starting data parallel rank for secondary nodes. "
"Requires --headless.")
parser.add_argument("--api-server-count",
"-asc",
type=int,
@ -288,9 +250,6 @@ def validate_parsed_serve_args(args: argparse.Namespace):
if args.enable_auto_tool_choice and not args.tool_call_parser:
raise TypeError("Error: --enable-auto-tool-choice requires "
"--tool-call-parser")
if args.enable_prompt_embeds and args.enable_prompt_adapter:
raise ValueError(
"Cannot use prompt embeds and prompt adapter at the same time.")
def log_non_default_args(args: argparse.Namespace):

View File

@ -337,7 +337,6 @@ async def main(args):
model_config=model_config,
base_model_paths=base_model_paths,
lora_modules=None,
prompt_adapters=None,
)
openai_serving_chat = OpenAIServingChat(
engine,

View File

@ -147,11 +147,8 @@ class OpenAIServingChat(OpenAIServing):
raise self.engine_client.dead_error
try:
(
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request,
supports_default_mm_loras=True)
lora_request = self._maybe_get_adapters(
request, supports_default_mm_loras=True)
model_name = self._get_model_name(request.model, lora_request)
@ -239,8 +236,7 @@ class OpenAIServingChat(OpenAIServing):
self._log_inputs(request_id,
request_prompts[i],
params=sampling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
lora_request=lora_request)
trace_headers = (None if raw_request is None else await
self._get_trace_headers(raw_request.headers))
@ -259,7 +255,6 @@ class OpenAIServingChat(OpenAIServing):
request_id,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=request.priority,
)

View File

@ -49,19 +49,11 @@ class ClassificationMixin(OpenAIServing):
return None
try:
(
ctx.lora_request,
ctx.prompt_adapter_request,
) = self._maybe_get_adapters(ctx.request)
ctx.lora_request = self._maybe_get_adapters(ctx.request)
ctx.tokenizer = await self.engine_client.get_tokenizer(
ctx.lora_request)
if ctx.prompt_adapter_request is not None:
raise NotImplementedError(
"Prompt adapter is not supported for classification models"
)
(
ctx.request_prompts,
ctx.engine_prompts,

View File

@ -121,10 +121,7 @@ class OpenAIServingCompletion(OpenAIServing):
raw_request.state.request_metadata = request_metadata
try:
(
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)
lora_request = self._maybe_get_adapters(request)
tokenizer = await self.engine_client.get_tokenizer(lora_request)
@ -197,7 +194,6 @@ class OpenAIServingCompletion(OpenAIServing):
request_prompts[i],
params=sampling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
)
trace_headers = (None if raw_request is None else await
@ -221,7 +217,6 @@ class OpenAIServingCompletion(OpenAIServing):
sampling_params,
request_id_item,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers,
priority=request.priority,
)

View File

@ -53,18 +53,11 @@ class EmbeddingMixin(OpenAIServing):
) -> Optional[ErrorResponse]:
ctx = cast(EmbeddingServeContext, ctx)
try:
(
ctx.lora_request,
ctx.prompt_adapter_request,
) = self._maybe_get_adapters(ctx.request)
ctx.lora_request = self._maybe_get_adapters(ctx.request)
tokenizer = await self.engine_client.get_tokenizer(ctx.lora_request
)
if ctx.prompt_adapter_request is not None:
raise NotImplementedError("Prompt adapter is not supported "
"for embedding models")
if isinstance(ctx.request, EmbeddingChatRequest):
(
_,

View File

@ -68,7 +68,6 @@ from vllm.multimodal import ( # noqa: F401 - Required to resolve Pydantic error
MultiModalDataDict)
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob, PromptLogprobs
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
@ -161,7 +160,6 @@ class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, BaseModel,
request_id: str
created_time: int = Field(default_factory=lambda: int(time.time()))
lora_request: Optional[LoRARequest] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None
# Shared across most requests
tokenizer: Optional[AnyTokenizer] = None
@ -343,12 +341,10 @@ class OpenAIServing:
return self.create_error_response(
"Request prompts not available")
self._log_inputs(
request_id_item,
ctx.request_prompts[i],
params=pooling_params,
lora_request=ctx.lora_request,
prompt_adapter_request=ctx.prompt_adapter_request)
self._log_inputs(request_id_item,
ctx.request_prompts[i],
params=pooling_params,
lora_request=ctx.lora_request)
# Mypy has an existing bug related to inferring the variance of
# TypedDicts with `builtins.enumerate`:
@ -450,11 +446,6 @@ class OpenAIServing:
if isinstance(load_result, ErrorResponse) and \
load_result.code == HTTPStatus.BAD_REQUEST.value:
error_response = load_result
if request.model in [
prompt_adapter.prompt_adapter_name
for prompt_adapter in self.models.prompt_adapter_requests
]:
return None
return error_response or self.create_error_response(
message=f"The model `{request.model}` does not exist.",
@ -489,25 +480,21 @@ class OpenAIServing:
self,
request: AnyRequest,
supports_default_mm_loras: bool = False,
) -> Union[tuple[None, None], tuple[LoRARequest, None], tuple[
None, PromptAdapterRequest]]:
) -> Optional[LoRARequest]:
if request.model in self.models.lora_requests:
return self.models.lora_requests[request.model], None
return self.models.lora_requests[request.model]
# Currently only support default modality specific loras
# if we have exactly one lora matched on the request.
if supports_default_mm_loras:
default_mm_lora = self._get_active_default_mm_loras(request)
if default_mm_lora is not None:
return default_mm_lora, None
return default_mm_lora
if self._is_model_supported(request.model):
return None, None
return None
for prompt_adapter in self.models.prompt_adapter_requests:
if request.model == prompt_adapter.prompt_adapter_name:
return None, prompt_adapter
# if _check_model has been called earlier, this will be unreachable
raise ValueError(f"The model `{request.model}` does not exist.")
@ -987,7 +974,6 @@ class OpenAIServing:
params: Optional[Union[SamplingParams, PoolingParams,
BeamSearchParams]],
lora_request: Optional[LoRARequest],
prompt_adapter_request: Optional[PromptAdapterRequest],
) -> None:
if self.request_logger is None:
return
@ -1009,7 +995,6 @@ class OpenAIServing:
prompt_embeds,
params=params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
)
async def _get_trace_headers(

View File

@ -1,8 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import pathlib
from asyncio import Lock
from collections import defaultdict
from dataclasses import dataclass
@ -19,7 +17,6 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse,
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.utils import AtomicCounter
logger = init_logger(__name__)
@ -31,12 +28,6 @@ class BaseModelPath:
model_path: str
@dataclass
class PromptAdapterPath:
name: str
local_path: str
@dataclass
class LoRAModulePath:
name: str
@ -60,7 +51,6 @@ class OpenAIServingModels:
base_model_paths: list[BaseModelPath],
*,
lora_modules: Optional[list[LoRAModulePath]] = None,
prompt_adapters: Optional[list[PromptAdapterPath]] = None,
):
super().__init__()
@ -81,20 +71,6 @@ class OpenAIServingModels:
LoRAResolverRegistry.get_resolver(lora_resolver_name))
self.lora_resolver_lock: dict[str, Lock] = defaultdict(Lock)
self.prompt_adapter_requests = []
if prompt_adapters is not None:
for i, prompt_adapter in enumerate(prompt_adapters, start=1):
with pathlib.Path(prompt_adapter.local_path,
"adapter_config.json").open() as f:
adapter_config = json.load(f)
num_virtual_tokens = adapter_config["num_virtual_tokens"]
self.prompt_adapter_requests.append(
PromptAdapterRequest(
prompt_adapter_name=prompt_adapter.name,
prompt_adapter_id=i,
prompt_adapter_local_path=prompt_adapter.local_path,
prompt_adapter_num_virtual_tokens=num_virtual_tokens))
async def init_static_loras(self):
"""Loads all static LoRA modules.
Raises if any fail to load"""
@ -141,14 +117,7 @@ class OpenAIServingModels:
permission=[ModelPermission()])
for lora in self.lora_requests.values()
]
prompt_adapter_cards = [
ModelCard(id=prompt_adapter.prompt_adapter_name,
root=self.base_model_paths[0].name,
permission=[ModelPermission()])
for prompt_adapter in self.prompt_adapter_requests
]
model_cards.extend(lora_cards)
model_cards.extend(prompt_adapter_cards)
return ModelList(data=model_cards)
async def load_lora_adapter(

View File

@ -94,17 +94,10 @@ class OpenAIServingPooling(OpenAIServing):
try:
truncate_prompt_tokens = _validate_truncation_size(
self.max_model_len, truncate_prompt_tokens)
(
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)
lora_request = self._maybe_get_adapters(request)
tokenizer = await self.engine_client.get_tokenizer(lora_request)
if prompt_adapter_request is not None:
raise NotImplementedError("Prompt adapter is not supported "
"for pooling models")
if isinstance(request, PoolingChatRequest):
(
_,
@ -153,8 +146,7 @@ class OpenAIServingPooling(OpenAIServing):
self._log_inputs(request_id_item,
request_prompts[i],
params=pooling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
lora_request=lora_request)
trace_headers = (None if raw_request is None else await
self._get_trace_headers(raw_request.headers))

View File

@ -133,10 +133,7 @@ class OpenAIServingResponses(OpenAIServing):
messages = self._construct_input_messages(request, prev_response)
try:
(
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)
lora_request = self._maybe_get_adapters(request)
model_name = self._get_model_name(request.model, lora_request)
tokenizer = await self.engine_client.get_tokenizer(lora_request)
@ -169,8 +166,7 @@ class OpenAIServingResponses(OpenAIServing):
self._log_inputs(request.request_id,
request_prompts[i],
params=sampling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
lora_request=lora_request)
trace_headers = (None if raw_request is None else await
self._get_trace_headers(raw_request.headers))
@ -181,7 +177,6 @@ class OpenAIServingResponses(OpenAIServing):
request.request_id,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=request.priority,
)
generators.append(generator)

View File

@ -27,7 +27,6 @@ from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import make_async, merge_async_iterators
@ -58,8 +57,6 @@ class ServingScores(OpenAIServing):
request_id: str,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[Union[LoRARequest, None]] = None,
prompt_adapter_request: Optional[Union[PromptAdapterRequest,
None]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
) -> Union[list[PoolingRequestOutput], ErrorResponse]:
input_texts = texts_1 + texts_2
@ -100,8 +97,7 @@ class ServingScores(OpenAIServing):
self._log_inputs(request_id_item,
input_texts[i],
params=pooling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
lora_request=lora_request)
generators.append(
self.engine_client.encode(
@ -176,8 +172,6 @@ class ServingScores(OpenAIServing):
request_id: str,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[Union[LoRARequest, None]] = None,
prompt_adapter_request: Optional[Union[PromptAdapterRequest,
None]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
) -> Union[list[PoolingRequestOutput], ErrorResponse]:
request_prompts: list[str] = []
@ -261,8 +255,7 @@ class ServingScores(OpenAIServing):
self._log_inputs(request_id_item,
request_prompts[i],
params=pooling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
lora_request=lora_request)
generator = self.engine_client.encode(
engine_prompt,
@ -295,14 +288,7 @@ class ServingScores(OpenAIServing):
raw_request: Optional[Request] = None,
truncate_prompt_tokens: Optional[int] = None,
) -> Union[list[PoolingRequestOutput], ErrorResponse]:
(
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)
if prompt_adapter_request is not None:
raise NotImplementedError("Prompt adapter is not supported "
"for scoring models")
lora_request = self._maybe_get_adapters(request)
tokenizer = await self.engine_client.get_tokenizer(lora_request)
@ -340,7 +326,6 @@ class ServingScores(OpenAIServing):
request_id=request_id,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers)
else:
@ -352,7 +337,6 @@ class ServingScores(OpenAIServing):
request_id=request_id,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers)
async def create_score(

View File

@ -60,10 +60,7 @@ class OpenAIServingTokenization(OpenAIServing):
request_id = f"tokn-{self._base_request_id(raw_request)}"
try:
(
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)
lora_request = self._maybe_get_adapters(request)
tokenizer = await self.engine_client.get_tokenizer(lora_request)
@ -104,11 +101,8 @@ class OpenAIServingTokenization(OpenAIServing):
self._log_inputs(request_id,
request_prompts[i],
params=None,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
lora_request=lora_request)
# Silently ignore prompt adapter since it does not affect
# tokenization (Unlike in Embeddings API where an error is raised)
if isinstance(engine_prompt,
dict) and "prompt_token_ids" in engine_prompt:
input_ids.extend(engine_prompt["prompt_token_ids"])
@ -133,21 +127,14 @@ class OpenAIServingTokenization(OpenAIServing):
request_id = f"tokn-{self._base_request_id(raw_request)}"
(
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)
lora_request = self._maybe_get_adapters(request)
tokenizer = await self.engine_client.get_tokenizer(lora_request)
self._log_inputs(request_id,
request.tokens,
params=None,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
# Silently ignore prompt adapter since it does not affect tokenization
# (Unlike in Embeddings API where an error is raised)
lora_request=lora_request)
prompt_input = await self._tokenize_prompt_input_async(
request,

View File

@ -11,6 +11,7 @@ from typing import Callable, Literal, Optional, TypeVar, Union, cast
import numpy as np
from fastapi import Request
import vllm.envs as envs
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
@ -38,10 +39,6 @@ T = TypeVar("T", bound=SpeechToTextResponse)
logger = init_logger(__name__)
# As per https://platform.openai.com/docs/guides/speech-to-text#overview.
# TODO configurable
MAX_AUDIO_CLIP_FILESIZE_MB = 25
class OpenAISpeechToText(OpenAIServing):
"""Base class for speech-to-text operations like transcription and
@ -70,6 +67,8 @@ class OpenAISpeechToText(OpenAIServing):
self.asr_config = self.model_cls.get_speech_to_text_config(
model_config, task_type)
self.max_audio_filesize_mb = envs.VLLM_MAX_AUDIO_CLIP_FILESIZE_MB
if self.default_sampling_params:
logger.info(
"Overwriting default completion sampling param with: %s",
@ -93,7 +92,7 @@ class OpenAISpeechToText(OpenAIServing):
lang = request.language or "en"
self.model_cls.validate_language(lang)
if len(audio_data) / 1024**2 > MAX_AUDIO_CLIP_FILESIZE_MB:
if len(audio_data) / 1024**2 > self.max_audio_filesize_mb:
raise ValueError("Maximum file size exceeded.")
with io.BytesIO(audio_data) as bytes_:
@ -150,19 +149,12 @@ class OpenAISpeechToText(OpenAIServing):
raw_request.state.request_metadata = request_metadata
try:
(
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)
lora_request = self._maybe_get_adapters(request)
if lora_request:
return self.create_error_response(
"Currently do not support LoRA for "
f"{self.task_type.title()}.")
if prompt_adapter_request:
return self.create_error_response(
f"Currently do not support PromptAdapter for "
f"{self.task_type.title()}.")
prompts, duration_s = await self._preprocess_speech_to_text(
request=request,
@ -188,8 +180,7 @@ class OpenAISpeechToText(OpenAIServing):
# It will not display special tokens like <|startoftranscript|>
request.prompt,
params=sampling_params,
lora_request=None,
prompt_adapter_request=None)
lora_request=None)
list_result_generator = [
self.engine_client.generate(

View File

@ -61,6 +61,7 @@ if TYPE_CHECKING:
VLLM_IMAGE_FETCH_TIMEOUT: int = 5
VLLM_VIDEO_FETCH_TIMEOUT: int = 30
VLLM_AUDIO_FETCH_TIMEOUT: int = 10
VLLM_MAX_AUDIO_CLIP_FILESIZE_MB: int = 25
VLLM_VIDEO_LOADER_BACKEND: str = "opencv"
VLLM_MM_INPUT_CACHE_GIB: int = 8
VLLM_TARGET_DEVICE: str = "cuda"
@ -140,6 +141,7 @@ if TYPE_CHECKING:
VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: Optional[int] = None
VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 120
VLLM_USE_CUDNN_PREFILL: bool = False
VLLM_ENABLE_CUDAGRAPH_GC: bool = False
VLLM_LOOPBACK_IP: str = ""
@ -518,6 +520,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_AUDIO_FETCH_TIMEOUT":
lambda: int(os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "10")),
# Maximum filesize in MB for a single audio file when processing
# speech-to-text requests. Files larger than this will be rejected.
# Default is 25 MB
"VLLM_MAX_AUDIO_CLIP_FILESIZE_MB":
lambda: int(os.getenv("VLLM_MAX_AUDIO_CLIP_FILESIZE_MB", "25")),
# Backend for Video IO
# - "opencv": Default backend that uses OpenCV stream buffered backend.
#
@ -968,6 +976,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_TRTLLM_DECODE_ATTENTION":
lambda: os.getenv("VLLM_USE_TRTLLM_DECODE_ATTENTION", None),
# Controls garbage collection during CUDA graph capture.
# If set to 0 (default), enables GC freezing to speed up capture time.
# If set to 1, allows GC to run during capture.
"VLLM_ENABLE_CUDAGRAPH_GC":
lambda: bool(int(os.getenv("VLLM_ENABLE_CUDAGRAPH_GC", "0"))),
# Used to force set up loopback IP
"VLLM_LOOPBACK_IP":
lambda: os.getenv("VLLM_LOOPBACK_IP", ""),

View File

@ -17,7 +17,6 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.pooling_params import PoolingTask
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest, PoolerOutput
from vllm.utils import make_async
from vllm.worker.worker_base import WorkerBase
@ -50,7 +49,6 @@ class ExecutorBase(ABC):
self.scheduler_config = vllm_config.scheduler_config
self.device_config = vllm_config.device_config
self.speculative_config = vllm_config.speculative_config
self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config
self._init_executor()
self.is_sleeping = False
@ -171,35 +169,6 @@ class ExecutorBase(ABC):
assert s == sets[0], "All workers should have the same LORAs."
return sets[0]
def add_prompt_adapter(
self, prompt_adapter_request: PromptAdapterRequest) -> bool:
assert prompt_adapter_request.prompt_adapter_id > 0, \
"prompt_adapter_id must be greater than 0."
return all(
self.collective_rpc("add_prompt_adapter",
args=(prompt_adapter_request, )))
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
assert prompt_adapter_id > 0, \
"prompt_adapter_id must be greater than 0."
return all(
self.collective_rpc("remove_prompt_adapter",
args=(prompt_adapter_id, )))
def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
assert prompt_adapter_id > 0, \
"prompt_adapter_id must be greater than 0."
return all(
self.collective_rpc("pin_prompt_adapter",
args=(prompt_adapter_id, )))
def list_prompt_adapters(self) -> Set[int]:
sets = self.collective_rpc("list_prompt_adapters")
for s in sets:
assert (s == sets[0]
), "All workers should have the same prompt adapters."
return sets[0]
def start_profile(self) -> None:
self.collective_rpc("start_profile")

View File

@ -13,7 +13,6 @@ from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
MultiModalInputs)
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
@ -168,18 +167,6 @@ class InputPreprocessor:
return decoder_input_ids
def _apply_prompt_adapter(
self,
prompt_token_ids: list[int],
prompt_adapter_request: Optional[PromptAdapterRequest],
) -> list[int]:
if prompt_adapter_request:
prompt_token_ids = (
[0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens
+ prompt_token_ids)
return prompt_token_ids
def _get_tokenization_kw(
self,
overrides: Optional[dict[str, Any]] = None,
@ -786,15 +773,10 @@ class InputPreprocessor:
def _build_decoder_only_llm_inputs(
self,
prompt_inputs: DecoderOnlyInputs,
prompt_adapter_request: Optional[PromptAdapterRequest],
) -> DecoderOnlyInputs:
if "prompt_token_ids" in prompt_inputs:
prompt_inputs = cast(Union[TokenInputs, MultiModalInputs],
prompt_inputs) # Needed for mypy
prompt_inputs["prompt_token_ids"] = self._apply_prompt_adapter(
prompt_inputs["prompt_token_ids"],
prompt_adapter_request=prompt_adapter_request,
)
return prompt_inputs
@ -803,7 +785,6 @@ class InputPreprocessor:
prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
return_mm_hashes: bool = False,
) -> DecoderOnlyInputs:
"""
@ -815,7 +796,6 @@ class InputPreprocessor:
* prompt: input prompt
* lora_request
* prompt_adapter_request
* return_mm_hashes
Returns:
@ -830,17 +810,13 @@ class InputPreprocessor:
return_mm_hashes=return_mm_hashes,
)
return self._build_decoder_only_llm_inputs(
prompt_comps,
prompt_adapter_request=prompt_adapter_request,
)
return self._build_decoder_only_llm_inputs(prompt_comps)
async def _process_decoder_only_prompt_async(
self,
prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
return_mm_hashes: bool = False,
) -> DecoderOnlyInputs:
"""
@ -854,17 +830,13 @@ class InputPreprocessor:
return_mm_hashes=return_mm_hashes,
)
return self._build_decoder_only_llm_inputs(
prompt_comps,
prompt_adapter_request=prompt_adapter_request,
)
return self._build_decoder_only_llm_inputs(prompt_comps)
def preprocess(
self,
prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
return_mm_hashes: bool = False,
) -> ProcessorInputs:
"""Preprocess the input prompt."""
@ -886,7 +858,6 @@ class InputPreprocessor:
prompt,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
return_mm_hashes=return_mm_hashes,
)
@ -895,7 +866,6 @@ class InputPreprocessor:
prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
return_mm_hashes: bool = False,
) -> ProcessorInputs:
"""
@ -919,6 +889,5 @@ class InputPreprocessor:
prompt,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
return_mm_hashes=return_mm_hashes,
)

View File

@ -6,6 +6,7 @@ from collections.abc import Generator
import gguf
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from transformers import AutoModelForCausalLM
from vllm.config import LoadConfig, ModelConfig, VllmConfig
@ -32,8 +33,18 @@ class GGUFModelLoader(BaseModelLoader):
def _prepare_weights(self, model_name_or_path: str):
if os.path.isfile(model_name_or_path):
return model_name_or_path
# for raw HTTPS link
if model_name_or_path.startswith(
("http://", "https://")) and model_name_or_path.endswith(".gguf"):
return hf_hub_download(url=model_name_or_path)
# repo id/filename.gguf
if "/" in model_name_or_path and model_name_or_path.endswith(".gguf"):
repo_id, filename = model_name_or_path.rsplit("/", 1)
return hf_hub_download(repo_id=repo_id, filename=filename)
else:
raise ValueError(f"{model_name_or_path} is not a file.")
raise ValueError(
f"Unrecognised GGUF reference: {model_name_or_path} "
"(expected local file, raw URL, or <repo_id>/<filename>.gguf)")
def _get_gguf_weights_map(self, model_config: ModelConfig):
"""

View File

@ -136,6 +136,40 @@ def supports_multimodal(
return getattr(model, "supports_multimodal", False)
@runtime_checkable
class SupportsMultiModalWithRawInput(SupportsMultiModal, Protocol):
"""The interface required for all multi-modal models."""
supports_multimodal_raw_input: ClassVar[Literal[True]] = True
"""
A flag that indicates this model supports multi-modal inputs and processes
them in their raw form rather than as embeddings.
Note:
There is no need to redefine this flag if this class is in the
MRO of your model class.
"""
@overload
def supports_multimodal_raw_input(
model: object) -> TypeIs[SupportsMultiModalWithRawInput]:
...
@overload
def supports_multimodal_raw_input(
model: type[object]) -> TypeIs[type[SupportsMultiModalWithRawInput]]:
...
def supports_multimodal_raw_input(
model: Union[type[object], object]
) -> Union[TypeIs[type[SupportsMultiModalWithRawInput]],
TypeIs[SupportsMultiModalWithRawInput]]:
return getattr(model, "supports_multimodal_raw_input", False)
@runtime_checkable
class SupportsScoreTemplate(Protocol):
"""The interface required for all models that support score template."""

View File

@ -16,6 +16,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only IBM/NASA Prithvi Geospatial model."""
from collections.abc import Iterable, Mapping, Sequence
from typing import Optional, Union
@ -27,13 +28,14 @@ from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import (AllPool, PoolerHead,
PoolerIdentity, SimplePooler)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import (IsAttentionFree,
SupportsMultiModal,
SupportsV0Only)
from vllm.model_executor.models.interfaces import (
IsAttentionFree, MultiModalEmbeddings, SupportsMultiModalWithRawInput)
from vllm.model_executor.models.utils import AutoWeightsLoader
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputs, MultiModalKwargs)
MultiModalFieldElem, MultiModalInputs,
MultiModalKwargs, MultiModalKwargsItem,
MultiModalSharedField, PlaceholderRange)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptUpdate)
@ -62,8 +64,9 @@ class PrithviGeoSpatialMAEInputBuilder(
# The size of pixel_values might change in the cases where we resize
# the input but never exceeds the dimensions below.
return {
"pixel_values": torch.full((1, 6, 512, 512), 1.0),
"location_coords": torch.full((1, 2), 1.0),
"pixel_values": torch.full((6, 512, 512), 1.0,
dtype=torch.float16),
"location_coords": torch.full((1, 2), 1.0, dtype=torch.float16),
}
@ -75,8 +78,10 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
location_coords=MultiModalFieldConfig.batched("image"),
pixel_values=MultiModalFieldConfig.shared(batch_size=1,
modality="image"),
location_coords=MultiModalFieldConfig.shared(batch_size=1,
modality="image"),
)
def _get_prompt_updates(
@ -99,23 +104,48 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
for k, v in mm_data.items():
mm_kwargs[k] = v
mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]}
# This model receives as input a multi-dimensional tensor representing
# a single image patch; it is therefore not split into multiple
# elements but treated as a single one.
# Hence, the decision of using a MultiModalSharedField.
# The expected shape is (num_channels, width, height).
# This model however allows the user to also submit multiple image
# patches as a batch, adding a further dimension to the above shape.
# At this stage we only support submitting one patch per request and
# batching is achieved via vLLM batching.
# TODO (christian-pinto): enable support for multi patch requests
# in tandem with vLLM batching.
multimodal_kwargs_items = [
MultiModalKwargsItem.from_elems([
MultiModalFieldElem(
modality="image",
key=key,
data=data,
field=MultiModalSharedField(1),
) for key, data in mm_kwargs.items()
])
]
return MultiModalInputs(
type="multimodal",
prompt=prompt,
prompt_token_ids=[1],
mm_kwargs=MultiModalKwargs(mm_kwargs),
mm_kwargs=MultiModalKwargs.from_items(multimodal_kwargs_items),
mm_hashes=None,
mm_placeholders={},
mm_placeholders=mm_placeholders,
)
@MULTIMODAL_REGISTRY.register_processor(
PrithviGeoSpatialMAEMultiModalProcessor,
info=PrithviGeoSpatialMAEProcessingInfo,
dummy_inputs=PrithviGeoSpatialMAEInputBuilder)
class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
SupportsV0Only):
dummy_inputs=PrithviGeoSpatialMAEInputBuilder,
)
class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree,
SupportsMultiModalWithRawInput):
"""Prithvi Masked Autoencoder"""
is_pooling_model = True
@ -128,10 +158,10 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
raise ValueError("Only image modality is supported")
def _instantiate_model(self, config: dict) -> Optional[nn.Module]:
# We might be able/need to support different tasks with this same model
if config["task_args"]["task"] == "SemanticSegmentationTask":
from terratorch.cli_tools import SemanticSegmentationTask
task = SemanticSegmentationTask(
config["model_args"],
config["task_args"]["model_factory"],
@ -144,7 +174,8 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
scheduler_hparams=config["scheduler_params"],
plot_on_val=config["task_args"]["plot_on_val"],
freeze_decoder=config["task_args"]["freeze_decoder"],
freeze_backbone=config["task_args"]["freeze_backbone"])
freeze_backbone=config["task_args"]["freeze_backbone"],
)
return task.model
else:
@ -168,12 +199,10 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
def _parse_and_validate_multimodal_data(
self, **kwargs) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
pixel_values = kwargs.pop("pixel_values", None)
if not isinstance(pixel_values, torch.Tensor):
raise ValueError(f"Incorrect type of pixel_values. "
f"Got type: {type(pixel_values)}")
pixel_values = torch.unbind(pixel_values, dim=0)[0]
location_coords = kwargs.pop("location_coords", None)
if not isinstance(location_coords, torch.Tensor):
@ -185,6 +214,17 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
return pixel_values, location_coords
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
# We do not really use any input tokens, so there are no embeddings
# to calculate. However, because the input prompt requires token ids,
# we pass a single token, and the size of the dummy embedding tensor
# must reflect that.
return torch.empty((input_ids.shape[0], 0))
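A minimal sketch of the contract this method satisfies (standalone, not tied to vLLM internals): the prompt carries a single dummy token id, so the returned tensor only has to match that token count in its first dimension while the hidden dimension stays zero, since no token embeddings are actually computed.

import torch

def dummy_input_embeddings(input_ids: torch.Tensor) -> torch.Tensor:
    # First dim tracks the number of (dummy) tokens; the hidden dim is empty.
    return torch.empty((input_ids.shape[0], 0))

assert dummy_input_embeddings(torch.tensor([1])).shape == (1, 0)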
def forward(
self,
input_ids: Optional[torch.Tensor],

View File

@ -22,8 +22,8 @@ from vllm.logger import init_logger
from .interfaces import (has_inner_state, has_noops, is_attention_free,
is_hybrid, supports_cross_encoding,
supports_multimodal, supports_pp,
supports_transcription, supports_v0_only)
supports_multimodal, supports_multimodal_raw_input,
supports_pp, supports_transcription, supports_v0_only)
from .interfaces_base import is_text_generation_model
logger = init_logger(__name__)
@ -287,6 +287,7 @@ class _ModelInfo:
is_pooling_model: bool
supports_cross_encoding: bool
supports_multimodal: bool
supports_multimodal_raw_input: bool
supports_pp: bool
has_inner_state: bool
is_attention_free: bool
@ -304,6 +305,7 @@ class _ModelInfo:
is_pooling_model=True, # Can convert any model into a pooling model
supports_cross_encoding=supports_cross_encoding(model),
supports_multimodal=supports_multimodal(model),
supports_multimodal_raw_input=supports_multimodal_raw_input(model),
supports_pp=supports_pp(model),
has_inner_state=has_inner_state(model),
is_attention_free=is_attention_free(model),
@ -573,6 +575,13 @@ class _ModelRegistry:
model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.supports_multimodal
def supports_multimodal_raw_input(
self,
architectures: Union[str, list[str]],
) -> bool:
model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.supports_multimodal_raw_input
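A hedged example of how a caller might consult the new flag through the registry singleton, assuming the architecture string registered by this PR and the usual ModelRegistry import path; the call mirrors the other supports_* helpers shown above.

from vllm.model_executor.models import ModelRegistry

# True only for architectures whose model class sets
# supports_multimodal_raw_input (e.g. the Prithvi MAE enabled in this PR).
if ModelRegistry.supports_multimodal_raw_input(["PrithviGeoSpatialMAE"]):
    print("raw multi-modal inputs are passed through without embedding lookup")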
def is_pp_supported_model(
self,
architectures: Union[str, list[str]],

View File

@ -266,7 +266,7 @@ class MultiModalRegistry:
if not model_config.is_multimodal_model:
raise ValueError(f"{model_config.model} is not a multimodal model")
if tokenizer is None:
if tokenizer is None and not model_config.skip_tokenizer_init:
tokenizer = cached_tokenizer_from_config(model_config)
if disable_cache is None:
mm_config = model_config.get_multimodal_config()

View File

@ -1,83 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Optional
import torch
from torch import nn
from vllm.adapter_commons.layers import AdapterMapping
from vllm.config import PromptAdapterConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
@dataclass
class PromptAdapterMapping(AdapterMapping):
pass
class VocabParallelEmbeddingWithPromptAdapter(nn.Module):
def __init__(self, base_layer: VocabParallelEmbedding) -> None:
super().__init__()
self.base_layer = base_layer
self.emb_layer = self.base_layer
if 'LoRA' in base_layer.__class__.__name__:
self.emb_layer = self.base_layer.base_layer
def create_prompt_adapter_weights(
self, prompt_adapter_config: PromptAdapterConfig):
self.embeddings_tensors = torch.zeros(
(
prompt_adapter_config.max_prompt_adapters,
prompt_adapter_config.max_prompt_adapter_token,
self.emb_layer.embedding_dim,
),
dtype=self.emb_layer.weight.dtype,
device=self.emb_layer.weight.device,
)
self.adapter_lengths = torch.zeros(
prompt_adapter_config.max_prompt_adapters,
dtype=torch.long,
device=self.emb_layer.weight.device)
self.indices_gpu: torch.Tensor
self.embedding_indices_gpu: torch.Tensor
def reset_prompt_adapter(self, index: int):
self.embeddings_tensors[index] = 0
def set_prompt_adapter(
self,
index: int,
adapter_model: Optional[torch.Tensor],
):
self.reset_prompt_adapter(index)
if adapter_model is not None:
length = adapter_model.shape[0]
self.embeddings_tensors[index, :length] = adapter_model
self.adapter_lengths[index] = length
def set_mapping(
self,
prompt_indices: torch.Tensor,
prompt_embedding_indices: torch.Tensor,
):
self.indices_gpu = prompt_indices.to(
device=self.emb_layer.weight.device)
self.embedding_indices_gpu = prompt_embedding_indices.to(
device=self.emb_layer.weight.device)
def forward(self, x: torch.Tensor) -> torch.Tensor:
hidden_states = self.base_layer(x)
if self.embedding_indices_gpu.ndim > 1:
valid_mask = self.indices_gpu != -1
gathered_embeddings = self.embeddings_tensors[
self.embedding_indices_gpu[:, 0],
self.embedding_indices_gpu[:, 1]]
# Update hidden states
hidden_states[valid_mask] = gathered_embeddings
return hidden_states

View File

@ -1,358 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
import math
from typing import Any, Callable, Dict, List, Optional, Type
import torch
from torch import nn
from vllm.adapter_commons.models import (AdapterLRUCache, AdapterModel,
AdapterModelManager)
from vllm.adapter_commons.utils import (add_adapter, deactivate_adapter,
get_adapter, list_adapters,
remove_adapter, set_adapter_mapping)
from vllm.config import PromptAdapterConfig
from vllm.prompt_adapter.layers import (
VocabParallelEmbeddingWithPromptAdapter) # yapf: disable
from vllm.prompt_adapter.layers import PromptAdapterMapping
from vllm.prompt_adapter.utils import load_peft_weights
logger = logging.getLogger(__name__)
_GLOBAL_PROMPT_ADAPTER_ID = 0
def get_prompt_adapter_id():
global _GLOBAL_PROMPT_ADAPTER_ID
_GLOBAL_PROMPT_ADAPTER_ID += 1
return _GLOBAL_PROMPT_ADAPTER_ID
def convert_to_embedding_indices(indices):
embedding_indices = []
count = 0
for value in indices:
if value == -1:
count = 0
else:
embedding_indices.append([value, count])
count += 1
return torch.tensor(embedding_indices)
def convert_mapping(
mapping: PromptAdapterMapping,
prompt_adapter_index_to_id: List[Optional[int]],
) -> torch.Tensor:
"""Converts PromptAdapterMapping to index tensors.
Args:
mapping: PromptAdapterMapping mapping rows in a
batch to PromptAdapter ids.
prompt_adapter_index_to_id: List mapping PromptAdapter
ids to PromptAdapter indices.
Returns:
pa_indices: Tensor of shape [batch_size] mapping batch rows to
PromptAdapter indices.
"""
id_to_index = {
id_: idx
for idx, id_ in enumerate(prompt_adapter_index_to_id)
if id_ is not None
}
pa_indices = ([
id_to_index.get(id_, -1) if id_ > 0 else -1
for id_ in mapping.index_mapping
])
pa_embedding_mapping = convert_to_embedding_indices(pa_indices)
pa_indices = torch.tensor(pa_indices)
return pa_indices, pa_embedding_mapping
class PromptAdapterModel(AdapterModel):
def __init__(self,
prompt_adapter_id=None,
num_virtual_tokens=None,
prompt_embedding=None) -> None:
self.id = prompt_adapter_id
self.prompt_embedding = prompt_embedding
self.num_virtual_tokens = num_virtual_tokens
@classmethod
def from_local_checkpoint(
cls,
adapter_model_path: str,
prompt_adapter_id: int,
num_virtual_tokens: int,
config: PromptAdapterConfig,
device: str = "cuda",
) -> "PromptAdapterModel":
if num_virtual_tokens > config.max_prompt_adapter_token:
raise ValueError(
f'num_virtual_tokens ({num_virtual_tokens}) should be <= '
f'max_prompt_adapter_token({config.max_prompt_adapter_token})')
adapters_weights = load_peft_weights(adapter_model_path, device)
prompt_embedding = adapters_weights["prompt_embeddings"].to(
config.prompt_adapter_dtype)
return cls(prompt_adapter_id, num_virtual_tokens, prompt_embedding)
class PromptAdapterModelManager(AdapterModelManager):
"""A manager that manages multiple Prompt Adapter models."""
def __init__(
self,
model: nn.Module,
max_num_seqs: int,
max_num_batched_tokens: int,
prompt_adapter_config: PromptAdapterConfig,
):
"""Create a PromptAdapterModel and adapter for a given model.
Args:
model: the model to be adapted.
max_num_seqs: the maximum number of sequences model can run in a
single batch.
max_num_batched_tokens: the maximum number of tokens model can run
in a single batch.
prompt_adapter_config: the PromptAdapter config,
"""
self.model: nn.Module = model
# Dict instead of a Set for compatibility with LRUCache.
self.prompt_adapter_index_to_id: List[
Optional[int]] = [None] * self.prompt_adapter_slots
self.max_num_seqs = max_num_seqs
self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
self.prompt_adapter_config = prompt_adapter_config
self.model.prompt_adapter_manager = self
self.adapter_type = 'PromptAdapter'
self.base_indices = torch.tensor([-1])
self.base_embedding_indices = torch.tensor([])
self.modules: Dict[str, nn.Module] = {}
self._create_prompt_adapter_modules()
self._last_mapping: Optional[PromptAdapterMapping] = None
@property
def prompt_adapter_slots(self) -> int:
return self.prompt_adapter_config.max_prompt_adapters
@property
def adapter_slots(self) -> int:
return self.prompt_adapter_slots
@property
def capacity(self) -> int:
return self.prompt_adapter_config.max_cpu_prompt_adapters
def activate_adapter(
self,
prompt_adapter_id: int,
) -> bool:
"""Move PromptAdapter into a GPU buffer
to be used in the forward pass."""
if prompt_adapter_id in self._active_adapters:
return False
first_free_slot = next(
((i, prompt_adapter_id) for i, prompt_adapter_id in enumerate(
self.prompt_adapter_index_to_id) if prompt_adapter_id is None),
None)
if first_free_slot is None:
raise ValueError("No free prompt_adapter slots")
index, _ = first_free_slot
self._active_adapters[prompt_adapter_id] = None
prompt_adapter_model = (self._registered_adapters[prompt_adapter_id])
logger.debug("Activating prompt_adapter. int id: %d, slot index: %d",
prompt_adapter_model.id, index)
self.prompt_adapter_index_to_id[index] = prompt_adapter_model.id
for _, v in self.modules.items():
v.set_prompt_adapter(index, prompt_adapter_model.prompt_embedding)
return True
def _deactivate_adapter(self, prompt_adapter_id: int):
try:
index = self.prompt_adapter_index_to_id.index(prompt_adapter_id)
self.prompt_adapter_index_to_id[index] = None
for _, v in self.modules.items():
v.reset_prompt_adapter(index)
except ValueError:
pass
def _add_adapter(self, prompt_adapter: PromptAdapterModel):
self._registered_adapters[prompt_adapter.id] = prompt_adapter
def _set_adapter_mapping(self, mapping: PromptAdapterMapping) -> None:
base_indices, base_embedding_indices = convert_mapping(
mapping, self.prompt_adapter_index_to_id)
for k, v in self.modules.items():
v.set_mapping(base_indices, base_embedding_indices)
def _create_prompt_adapter_modules(self):
for module_name, module in self.model.named_modules(
remove_duplicate=False):
if "VocabParallel" in module.__class__.__name__:
new_module = VocabParallelEmbeddingWithPromptAdapter(module)
new_module.create_prompt_adapter_weights(
self.prompt_adapter_config)
replaced_module = self.replace_submodule(
self.model, module_name, new_module)
self.register_module(module.__class__.__name__,
replaced_module)
replaced_module.set_mapping(self.base_indices,
self.base_embedding_indices)
break
def replace_submodule(self, model: nn.Module, module_name: str,
new_module: nn.Module) -> nn.Module:
"""Replace a submodule in a model with a new module."""
parent = model.get_submodule(".".join(module_name.split(".")[:-1]))
target_name = module_name.split(".")[-1]
setattr(parent, target_name, new_module)
return new_module
def register_module(self, module_name: str, module: nn.Module):
self.modules[module_name] = module
def pin_adapter(self, prompt_adapter_id: int) -> bool:
"""Pin a PromptAdapterModel in the manager cache."""
raise NotImplementedError(
"Pinning is not supported in PromptAdapterModelManager. "
"Use LRUCachePromptAdapterModelManager for pinning"
) # type: ignore
def remove_all_adapters(self):
"""Remove all PromptAdapterModel from the manager."""
self._registered_adapters.clear()
self.prompt_adapter_index_to_id = [None] * self.prompt_adapter_slots
self._active_adapters.clear()
def deactivate_adapter(self, adapter_id: int) -> bool:
return deactivate_adapter(adapter_id, self._active_adapters,
self._deactivate_adapter)
def add_adapter(self, adapter: PromptAdapterModel) -> bool:
return add_adapter(adapter, self._registered_adapters, self.capacity,
self._add_adapter)
def set_adapter_mapping(self, mapping: PromptAdapterMapping) -> None:
self._last_mapping = set_adapter_mapping(mapping, self._last_mapping,
self._set_adapter_mapping)
def remove_adapter(self, adapter_id: int) -> bool:
return remove_adapter(adapter_id, self._registered_adapters,
self.deactivate_adapter)
def list_adapters(self) -> Dict[int, Any]:
return list_adapters(self._registered_adapters)
def get_adapter(self, adapter_id: int) -> Optional[Any]:
return get_adapter(adapter_id, self._registered_adapters)
class PromptAdapterLRUCache(AdapterLRUCache[PromptAdapterModel]):
def __init__(self, capacity: int,
deactivate_prompt_adapter_fn: Callable[[int], bool]):
super().__init__(capacity, deactivate_prompt_adapter_fn)
class LRUCachePromptAdapterModelManager(PromptAdapterModelManager):
"""A model manager that manages multiple prompt_adapters with LRU cache."""
def __init__(
self,
model: nn.Module,
max_num_seqs: int,
max_num_batched_tokens: int,
prompt_adapter_config: PromptAdapterConfig,
):
self.prompt_adapter_config = prompt_adapter_config
super().__init__(model, max_num_seqs, max_num_batched_tokens,
prompt_adapter_config)
self._registered_adapters = PromptAdapterLRUCache(
self.capacity, self.deactivate_adapter)
self._active_adapters = PromptAdapterLRUCache(
self.prompt_adapter_slots, self._deactivate_adapter)
def list_adapters(self) -> Dict[int, PromptAdapterModel]:
"""List all registered PromptAdapterModel."""
return dict(self._registered_adapters.cache)
def add_adapter(self, prompt_adapter: PromptAdapterModel) -> bool:
"""Add a PromptAdapterModel to the manager."""
if prompt_adapter.id not in self._registered_adapters:
self._add_adapter(prompt_adapter)
was_added = True
else:
# We always touch to update the LRU cache order
self._registered_adapters.touch(prompt_adapter.id)
was_added = False
return was_added
def activate_adapter(
self,
prompt_adapter_id: int,
) -> bool:
if prompt_adapter_id not in self._active_adapters and len(
self._active_adapters) >= self.prompt_adapter_slots:
self._active_adapters.remove_oldest()
result = super().activate_adapter(prompt_adapter_id)
# We always touch to update the LRU cache order
self._active_adapters.touch(prompt_adapter_id)
return result
def remove_oldest_adapter(self) -> bool:
if len(self._registered_adapters) > 0:
self._registered_adapters.remove_oldest()
return True
return False
def pin_adapter(self, prompt_adapter_id: int) -> bool:
"""Pin a PromptAdapterModel in the manager cache."""
self._pin_prompt_adapter_in_cpu_cache(prompt_adapter_id)
self._pin_prompt_adapter_in_gpu_cache(prompt_adapter_id)
return True
def _pin_prompt_adapter_in_cpu_cache(self, prompt_adapter_id: int):
try:
self._registered_adapters.pin(prompt_adapter_id)
except ValueError as err:
raise ValueError(
"Pinning failed. "
f"Prompt Adapter {prompt_adapter_id} is not registered."
) from err
def _pin_prompt_adapter_in_gpu_cache(self, prompt_adapter_id: int):
if prompt_adapter_id not in self._active_adapters:
# move adapter to gpu if not already active
self.activate_adapter(prompt_adapter_id)
self._active_adapters.pin(prompt_adapter_id)
def create_prompt_adapter_manager(
model: nn.Module,
max_num_seqs: int,
max_num_batched_tokens: int,
prompt_adapter_config: PromptAdapterConfig,
prompt_adapter_manager_cls: Type[
PromptAdapterModelManager] = PromptAdapterModelManager,
**kwargs) -> PromptAdapterModelManager:
"""Create a PromptAdapterModel for a given model."""
prompt_adapter_manager = prompt_adapter_manager_cls(
model=model,
max_num_seqs=max_num_seqs,
max_num_batched_tokens=max_num_batched_tokens,
prompt_adapter_config=prompt_adapter_config,
**kwargs)
return prompt_adapter_manager

View File

@ -1,37 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import msgspec
from vllm.adapter_commons.request import AdapterRequest
class PromptAdapterRequest(
msgspec.Struct,
array_like=True, # type: ignore[call-arg]
omit_defaults=True, # type: ignore[call-arg]
frozen=True): # type: ignore[call-arg]
"""
Request for a Prompt adapter.
"""
__metaclass__ = AdapterRequest
prompt_adapter_name: str
prompt_adapter_id: int
prompt_adapter_local_path: str
prompt_adapter_num_virtual_tokens: int
def __hash__(self):
return super().__hash__()
@property
def adapter_id(self):
return self.prompt_adapter_id
@property
def name(self):
return self.prompt_adapter_name
@property
def local_path(self):
return self.prompt_adapter_local_path

View File

@ -1,98 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# code borrowed from: https://github.com/huggingface/peft/blob/v0.12.0/src/peft/utils/save_and_load.py#L420
import os
from typing import Optional
import torch
from huggingface_hub import file_exists, hf_hub_download
from huggingface_hub.utils import EntryNotFoundError
from safetensors.torch import load_file as safe_load_file
from vllm.platforms import current_platform
WEIGHTS_NAME = "adapter_model.bin"
SAFETENSORS_WEIGHTS_NAME = "adapter_model.safetensors"
# Get current device name based on available devices
def infer_device() -> str:
if current_platform.is_cuda_alike():
return "cuda"
return "cpu"
def load_peft_weights(model_id: str,
device: Optional[str] = None,
**hf_hub_download_kwargs) -> dict:
r"""
A helper method to load the PEFT weights from the HuggingFace Hub or locally
Args:
model_id (`str`):
The local path to the adapter weights or the name of the adapter to
load from the HuggingFace Hub.
device (`str`):
The device to load the weights onto.
hf_hub_download_kwargs (`dict`):
Additional arguments to pass to the `hf_hub_download` method when
loading from the HuggingFace Hub.
"""
path = (os.path.join(model_id, hf_hub_download_kwargs["subfolder"]) if
hf_hub_download_kwargs.get("subfolder") is not None else model_id)
if device is None:
device = infer_device()
if os.path.exists(os.path.join(path, SAFETENSORS_WEIGHTS_NAME)):
filename = os.path.join(path, SAFETENSORS_WEIGHTS_NAME)
use_safetensors = True
elif os.path.exists(os.path.join(path, WEIGHTS_NAME)):
filename = os.path.join(path, WEIGHTS_NAME)
use_safetensors = False
else:
token = hf_hub_download_kwargs.get("token")
if token is None:
token = hf_hub_download_kwargs.get("use_auth_token")
hub_filename = (os.path.join(hf_hub_download_kwargs["subfolder"],
SAFETENSORS_WEIGHTS_NAME)
if hf_hub_download_kwargs.get("subfolder") is not None
else SAFETENSORS_WEIGHTS_NAME)
has_remote_safetensors_file = file_exists(
repo_id=model_id,
filename=hub_filename,
revision=hf_hub_download_kwargs.get("revision"),
repo_type=hf_hub_download_kwargs.get("repo_type"),
token=token,
)
use_safetensors = has_remote_safetensors_file
if has_remote_safetensors_file:
# Priority 1: load safetensors weights
filename = hf_hub_download(
model_id,
SAFETENSORS_WEIGHTS_NAME,
**hf_hub_download_kwargs,
)
else:
try:
filename = hf_hub_download(model_id, WEIGHTS_NAME,
**hf_hub_download_kwargs)
except EntryNotFoundError:
raise ValueError( # noqa: B904
f"Can't find weights for {model_id} in {model_id} or \
in the Hugging Face Hub. "
f"Please check that the file {WEIGHTS_NAME} or \
{SAFETENSORS_WEIGHTS_NAME} is present at {model_id}.")
if use_safetensors:
adapters_weights = safe_load_file(filename, device=device)
else:
adapters_weights = torch.load(filename,
map_location=torch.device(device),
weights_only=True)
return adapters_weights

View File

@ -1,179 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
from typing import Any, Optional, Set, Type
import torch
from vllm.adapter_commons.utils import (add_adapter_worker,
apply_adapters_worker,
list_adapters_worker,
set_active_adapters_worker)
from vllm.adapter_commons.worker_manager import AbstractWorkerManager
from vllm.config import PromptAdapterConfig
from vllm.prompt_adapter.models import (LRUCachePromptAdapterModelManager,
PromptAdapterModel,
PromptAdapterModelManager,
create_prompt_adapter_manager)
from vllm.prompt_adapter.request import PromptAdapterRequest
logger = logging.getLogger(__name__)
class WorkerPromptAdapterManager(AbstractWorkerManager):
"""WorkerPromptAdapterManager that manages
prompt_adapter models on the worker side.
Every request, the requested prompt_adapters will be
loaded (unless they are already loaded),
and every other prompt_adapter will be unloaded."""
_manager_cls: Type[PromptAdapterModelManager] = PromptAdapterModelManager
def __init__(
self,
max_num_seqs: int,
max_num_batched_tokens: int,
device: torch.device,
prompt_adapter_config: PromptAdapterConfig,
prompt_adapter_model_cls: Type[PromptAdapterModel] = PromptAdapterModel
):
self._adapter_manager: PromptAdapterModelManager
self.max_num_seqs = max_num_seqs
self.max_num_batched_tokens = max_num_batched_tokens
self._prompt_adapter_model_cls = prompt_adapter_model_cls
self.prompt_adapter_config = prompt_adapter_config
super().__init__(device)
@property
def is_enabled(self) -> bool:
return True
def create_prompt_adapter_manager(
self,
model: torch.nn.Module,
) -> Any:
prompt_adapter_manager = create_prompt_adapter_manager(
model,
max_num_seqs=self.max_num_seqs,
max_num_batched_tokens=self.max_num_batched_tokens,
prompt_adapter_config=self.prompt_adapter_config,
prompt_adapter_manager_cls=self._manager_cls,
)
self._adapter_manager = prompt_adapter_manager
return prompt_adapter_manager.model
def _load_adapter(
self, prompt_adapter_request: PromptAdapterRequest
) -> PromptAdapterModel:
try:
prompt_adapter = (
self._prompt_adapter_model_cls.from_local_checkpoint(
prompt_adapter_request.prompt_adapter_local_path,
prompt_adapter_id=prompt_adapter_request.prompt_adapter_id,
num_virtual_tokens=prompt_adapter_request.
prompt_adapter_num_virtual_tokens,
config=self.prompt_adapter_config,
device=str(self.device),
))
except Exception as e:
raise RuntimeError(
f"Loading prompt_adapter "
f"{prompt_adapter_request.prompt_adapter_local_path}"
f" failed") from e
return prompt_adapter
def add_dummy_prompt_adapter(
self, prompt_adapter_request: PromptAdapterRequest) -> bool:
return True
def pin_adapter(self, adapter_id: int) -> bool:
return self._adapter_manager.pin_adapter(adapter_id)
def set_active_adapters(self, requests: Set[Any],
mapping: Optional[Any]) -> None:
set_active_adapters_worker(requests, mapping, self._apply_adapters,
self._adapter_manager.set_adapter_mapping)
def add_adapter(self, adapter_request: Any) -> bool:
return add_adapter_worker(adapter_request, self.list_adapters,
self._load_adapter,
self._adapter_manager.add_adapter,
self._adapter_manager.activate_adapter)
def _apply_adapters(self, adapter_requests: Set[Any]) -> None:
apply_adapters_worker(adapter_requests, self.list_adapters,
self._adapter_manager.adapter_slots,
self.remove_adapter, self.add_adapter)
def remove_adapter(self, adapter_id: int) -> bool:
return self._adapter_manager.remove_adapter(adapter_id)
def remove_all_adapters(self):
self._adapter_manager.remove_all_adapters()
def list_adapters(self) -> Set[int]:
return list_adapters_worker(self._adapter_manager.list_adapters)
class LRUCacheWorkerPromptAdapterManager(WorkerPromptAdapterManager):
"""WorkerPromptAdapterManager that manages
prompt_adapter models on the worker side.
Uses an LRU Cache. Every request, the requested
prompt_adapters will be loaded (unless they are already loaded)
and least recently used prompt_adapters will
be unloaded if the cache is above capacity."""
_prompt_adapter_manager_cls: Type[
LRUCachePromptAdapterModelManager] = LRUCachePromptAdapterModelManager
def create_prompt_adapter_manager(
self,
model: torch.nn.Module,
) -> Any:
prompt_adapter_manager = create_prompt_adapter_manager(
model,
max_num_seqs=self.max_num_seqs,
max_num_batched_tokens=self.max_num_batched_tokens,
prompt_adapter_config=self.prompt_adapter_config,
prompt_adapter_manager_cls=self._prompt_adapter_manager_cls)
self._adapter_manager: LRUCachePromptAdapterModelManager = (
prompt_adapter_manager)
return prompt_adapter_manager.model
def _apply_adapters(
self, prompt_adapter_requests: Set[PromptAdapterRequest]) -> None:
prompt_adapters_map = {
prompt_adapter_request.prompt_adapter_id: prompt_adapter_request
for prompt_adapter_request in prompt_adapter_requests
if prompt_adapter_request
}
if len(prompt_adapters_map
) > self._adapter_manager.prompt_adapter_slots:
raise RuntimeError(
f"Number of requested prompt_adapters "
f"({len(prompt_adapters_map)}) is greater "
"than the number of GPU prompt_adapter slots "
f"({self._adapter_manager.prompt_adapter_slots}).")
for prompt_adapter in prompt_adapters_map.values():
self.add_adapter(prompt_adapter)
def add_adapter(self,
prompt_adapter_request: PromptAdapterRequest) -> bool:
if prompt_adapter_request.prompt_adapter_id not in self.list_adapters(
):
# Remove before we load the new prompt_adapter to save memory
if len(self._adapter_manager) + 1 > self._adapter_manager.capacity:
self._adapter_manager.remove_oldest_adapter()
prompt_adapter = self._load_adapter(prompt_adapter_request)
loaded = self._adapter_manager.add_adapter(prompt_adapter)
else:
# If the prompt_adapter is already loaded, just touch it to
# update its position in the caches
loaded = self._adapter_manager.get_adapter(
prompt_adapter_request.prompt_adapter_id) is not None
self._adapter_manager.activate_adapter(
prompt_adapter_request.prompt_adapter_id)
return loaded

View File

@ -6,6 +6,7 @@ from .deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
from .glm4_moe_reasoning_parser import Glm4MoeModelReasoningParser
from .granite_reasoning_parser import GraniteReasoningParser
from .hunyuan_a13b_reasoning_parser import HunyuanA13BReasoningParser
from .mistral_reasoning_parser import MistralReasoningParser
from .qwen3_reasoning_parser import Qwen3ReasoningParser
__all__ = [
@ -16,4 +17,5 @@ __all__ = [
"HunyuanA13BReasoningParser",
"Qwen3ReasoningParser",
"Glm4MoeModelReasoningParser",
"MistralReasoningParser",
]

View File

@ -0,0 +1,47 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser, ReasoningParserManager
from vllm.reasoning.deepseek_r1_reasoning_parser import (
DeepSeekR1ReasoningParser)
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
logger = init_logger(__name__)
@ReasoningParserManager.register_module("mistral")
class MistralReasoningParser(DeepSeekR1ReasoningParser):
"""
Reasoning parser for Mistral models.
Mistral models use [THINK]...[/THINK] tokens to denote reasoning
text. This parser extracts the reasoning content from the model output.
"""
def __init__(self, tokenizer: MistralTokenizer):
if not isinstance(tokenizer, MistralTokenizer):
raise ValueError(
"The tokenizer must be an instance of MistralTokenizer.")
ReasoningParser.__init__(self, tokenizer)
if not self.model_tokenizer:
raise ValueError(
"The model tokenizer must be passed to the ReasoningParser "
"constructor during construction.")
from mistral_common.tokens.tokenizers.base import SpecialTokens
self.start_token = SpecialTokens.begin_think
self.end_token = SpecialTokens.end_think
self.start_token_id = tokenizer.tokenizer.get_control_token(
self.start_token)
self.end_token_id = tokenizer.tokenizer.get_control_token(
self.end_token)
if self.start_token_id is None or self.end_token_id is None:
raise RuntimeError(
"Mistral reasoning parser could not locate think start/end "
"tokens in the tokenizer!")

View File

@ -19,7 +19,6 @@ from vllm.inputs import SingletonInputs
from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalKwargs, MultiModalPlaceholderDict
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import RequestOutputKind, SamplingParams
VLLM_TOKEN_ID_ARRAY_TYPE = "l"
@ -458,7 +457,6 @@ class Sequence:
block size used by the block manager and cache engine.
eos_token_id: The end-of-sequence (EOS) token id recognized by this LLM.
lora_request: LoRA request.
prompt_adapter_request: Prompt Adapter request.
"""
def __init__(
@ -468,14 +466,12 @@ class Sequence:
block_size: int,
eos_token_id: Optional[int] = None,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None:
self.seq_id = seq_id
self.inputs = inputs
self.block_size = block_size
self.eos_token_id = eos_token_id
self.lora_request = lora_request
self.prompt_adapter_request = prompt_adapter_request
self.data = SequenceData.from_seqs(
self.prompt_token_ids,
@ -537,11 +533,6 @@ class Sequence:
def lora_int_id(self) -> int:
return self.lora_request.lora_int_id if self.lora_request else 0
@property
def prompt_adapter_id(self) -> int:
return self.prompt_adapter_request.prompt_adapter_id \
if self.prompt_adapter_request else 0
def get_output_text_to_return(self, buffer_length: int,
delta: bool) -> str:
"""If delta is True, only new text since the last call to
@ -601,12 +592,12 @@ class Sequence:
designed for prefix caching mode. The final sequence hash is determined
by applying token_ids from the sequence's blocks.
"""
if self.prompt_adapter_id == 0 and self.lora_int_id == 0:
if self.lora_int_id == 0:
return None
# NOTE: If there are additional factors influencing the block aside from
# token_ids, include them as input parameters to the hash.
return hash((self.prompt_adapter_id, self.lora_int_id))
return hash(self.lora_int_id)
def num_hashed_tokens_of_block(self, logical_idx: int):
return logical_idx * self.block_size + self.block_size
@ -707,7 +698,6 @@ class SequenceGroup:
encoder_seq: Optional, the single encoder sequence. Should be None
unless you are working with an encoder/decoder model.
trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request.
priority: User-defined priority of the request.
draft_size: The number of speculative tokens plus one from the target
model; equal to max number of tokens a step can generate
@ -725,7 +715,6 @@ class SequenceGroup:
pooled_data: Optional[torch.Tensor] = None,
encoder_seq: Optional[Sequence] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
draft_size: int = 1) -> None:
self.request_id = request_id
@ -747,7 +736,6 @@ class SequenceGroup:
self.state = SequenceGroupState()
self.pooling_params = pooling_params
self.pooled_data = pooled_data
self.prompt_adapter_request = prompt_adapter_request
self.encoder_seq = encoder_seq
self.trace_headers = trace_headers
self.priority = priority
@ -802,16 +790,6 @@ class SequenceGroup:
def lora_int_id(self) -> int:
return self.lora_request.lora_int_id if self.lora_request else 0
@property
def prompt_adapter_id(self) -> int:
return self.prompt_adapter_request.prompt_adapter_id \
if self.prompt_adapter_request else 0
@property
def prompt_adapter_num_virtual_tokens(self) -> int:
return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens\
if self.prompt_adapter_request else 0
def init_multi_step(self, num_steps: int) -> None:
self.state.num_steps = num_steps
self.state.current_step = 0
@ -1011,7 +989,6 @@ class SequenceGroupMetadata(
(SequenceGroup.encoder_seq). Should be None
unless you are working with an encoder/decoder
model.
prompt_adapter_request: Prompt Adapter request.
"""
request_id: str
@ -1030,7 +1007,6 @@ class SequenceGroupMetadata(
multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
encoder_seq_data: Optional[SequenceData] = None
cross_block_table: Optional[list[int]] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None
token_chunk_size: Optional[int] = None
### Stateful fields that are lazily defined. ###
@ -1052,16 +1028,6 @@ class SequenceGroupMetadata(
def lora_int_id(self) -> int:
return self.lora_request.lora_int_id if self.lora_request else 0
@property
def prompt_adapter_id(self) -> int:
return self.prompt_adapter_request.prompt_adapter_id \
if self.prompt_adapter_request else 0
@property
def prompt_adapter_num_virtual_tokens(self) -> int:
return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens \
if self.prompt_adapter_request else 0
# Multi-Step Chunked-Prefill property
@property
def is_single_step_prompt(self) -> bool:
@ -1525,7 +1491,6 @@ class ParallelSampleSequenceGroup(SequenceGroupBase):
pooled_data=seq_group.pooled_data,
encoder_seq=seq_group.encoder_seq,
trace_headers=seq_group.trace_headers,
prompt_adapter_request=seq_group.prompt_adapter_request,
priority=seq_group.priority,
)

View File

@ -145,6 +145,21 @@ def find_tokenizer_file(files: list[str]):
return matched_files[0]
def _aggregate_content(content: list) -> list[dict[str, Any]]:
aggregated_content: list[dict[str, Any]] = []
for chunk in content:
if chunk.get("type"
) == "text" and aggregated_content and aggregated_content[
-1].get("type") == "text":
aggregated_content[-1]["text"] += "\n\n" + chunk.get("text")
else:
aggregated_content.append(chunk)
if len(aggregated_content) == 1 and aggregated_content[0].get(
"type") == "text":
content = aggregated_content[0]["text"]
return content
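The behaviour of _aggregate_content can be restated with a small standalone example (the sample chunks and URL are invented): adjacent text chunks are merged with a blank line in between, non-text chunks are kept as-is, and if everything collapses to a single text chunk the caller receives a plain string instead of a list.

content = [
    {"type": "text", "text": "First paragraph."},
    {"type": "text", "text": "Second paragraph."},
    {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
]

merged: list[dict] = []
for chunk in content:
    if chunk.get("type") == "text" and merged and merged[-1].get("type") == "text":
        merged[-1]["text"] += "\n\n" + chunk["text"]
    else:
        merged.append(dict(chunk))

# The two text chunks are merged into one; the image chunk is untouched.
assert merged[0]["text"] == "First paragraph.\n\nSecond paragraph."
assert len(merged) == 2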
def make_mistral_chat_completion_request(
messages: list["ChatCompletionMessageParam"],
tools: Optional[list[dict[str,
@ -162,10 +177,10 @@ def make_mistral_chat_completion_request(
# Convert list text content to string
if message.get("role") in ("assistant", "tool"):
content = message.get("content")
content: Any = message.get("content")
if isinstance(content, list):
content = "\n".join(chunk.get("text") for chunk in content)
message["content"] = content
content = _aggregate_content(content)
message["content"] = content
# The Mistral client, in comparison to the OpenAI client, requires the
# "parameters" dict to be present, even if it's empty.
@ -465,6 +480,8 @@ class MistralTokenizer(TokenizerBase):
skip_special_tokens: bool = True,
) -> list[str]:
from mistral_common.tokens.tokenizers.base import SpecialTokens
from mistral_common.tokens.tokenizers.instruct import (
InstructTokenizerV13)
# TODO(Patrick) - potentially allow special tokens to not be skipped
assert (
@ -474,10 +491,18 @@ class MistralTokenizer(TokenizerBase):
assert self.is_tekken or self.is_spm, type(self.tokenizer)
if self.is_tekken:
# skip special tokens except tool call
ids = [
i for i in ids if i > self.tokenizer.num_special_tokens or i ==
# skip special tokens except tool call and think tokens
non_skip_special_tokens = {
self.tokenizer.get_control_token(SpecialTokens.tool_calls)
}
if isinstance(self.instruct, InstructTokenizerV13):
if self.instruct.BEGIN_THINK:
non_skip_special_tokens.add(self.instruct.BEGIN_THINK)
if self.instruct.END_THINK:
non_skip_special_tokens.add(self.instruct.END_THINK)
ids = [
i for i in ids if i > self.tokenizer.num_special_tokens
or i in non_skip_special_tokens
]
tokens = [self.tokenizer.id_to_piece(id) for id in ids]

View File

@ -128,10 +128,6 @@ STR_NOT_IMPL_ENC_DEC_BACKEND = ("XFormers and Flash-Attention are the only "
"backends currently supported with encoder/"
"decoder models.")
STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER = ("Prompt adapters are not "
"currently supported with encoder/"
"decoder models.")
# Efficiently import all enc/dec error strings
# rather than having to import all of the above
STR_NOT_IMPL_ENC_DEC_ERR_STRS = {
@ -145,7 +141,6 @@ STR_NOT_IMPL_ENC_DEC_ERR_STRS = {
"STR_NOT_IMPL_ENC_DEC_MM": STR_NOT_IMPL_ENC_DEC_MM,
"STR_NOT_IMPL_ENC_DEC_SPEC_DEC": STR_NOT_IMPL_ENC_DEC_SPEC_DEC,
"STR_NOT_IMPL_ENC_DEC_BACKEND": STR_NOT_IMPL_ENC_DEC_BACKEND,
"STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER": STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER,
}
# Constants related to forcing the attention backend selection

View File

@ -20,7 +20,6 @@ from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
@ -94,11 +93,14 @@ class AsyncLLM(EngineClient):
self.log_requests = log_requests
self.log_stats = log_stats
# Tokenizer (+ ensure liveness if running in another process).
self.tokenizer = init_tokenizer_from_configs(
model_config=vllm_config.model_config,
scheduler_config=vllm_config.scheduler_config,
lora_config=vllm_config.lora_config)
if self.model_config.skip_tokenizer_init:
self.tokenizer = None
else:
# Tokenizer (+ ensure liveness if running in another process).
self.tokenizer = init_tokenizer_from_configs(
model_config=vllm_config.model_config,
scheduler_config=vllm_config.scheduler_config,
lora_config=vllm_config.lora_config)
# Processor (converts Inputs --> EngineCoreRequests).
self.processor = Processor(
@ -125,7 +127,7 @@ class AsyncLLM(EngineClient):
if self.log_stats:
self.logger_manager = StatLoggerManager(
vllm_config=vllm_config,
engine_idxs=self.engine_core.engine_ranks,
engine_idxs=self.engine_core.engine_ranks_managed,
custom_stat_loggers=stat_loggers,
)
self.logger_manager.log_engine_initialized()
@ -218,7 +220,6 @@ class AsyncLLM(EngineClient):
lora_request: Optional[LoRARequest] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
) -> RequestOutputCollector:
@ -235,8 +236,7 @@ class AsyncLLM(EngineClient):
# Convert Input --> Request.
prompt_str, request = self.processor.process_inputs(
request_id, prompt, params, arrival_time, lora_request,
tokenization_kwargs, trace_headers, prompt_adapter_request,
priority, data_parallel_rank)
tokenization_kwargs, trace_headers, priority, data_parallel_rank)
if is_pooling or params.n == 1:
await self._add_request(request, prompt_str, None, 0, queue)
@ -280,7 +280,6 @@ class AsyncLLM(EngineClient):
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
) -> AsyncGenerator[RequestOutput, None]:
@ -311,7 +310,6 @@ class AsyncLLM(EngineClient):
sampling_params,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
data_parallel_rank=data_parallel_rank,
)
@ -525,6 +523,10 @@ class AsyncLLM(EngineClient):
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
if self.tokenizer is None:
raise ValueError("Unable to get tokenizer because "
"skip_tokenizer_init is True")
return self.tokenizer.get_lora_tokenizer(lora_request)
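The pattern is easy to see in isolation; a hedged sketch (class and attribute names invented) where skipping tokenizer initialization leaves the attribute as None and any later access raises instead of lazily loading a tokenizer.

from typing import Optional

class _TokenizerHolder:
    def __init__(self, skip_tokenizer_init: bool):
        self.tokenizer: Optional[object] = (None if skip_tokenizer_init
                                            else object())

    def get_tokenizer(self) -> object:
        if self.tokenizer is None:
            raise ValueError("Unable to get tokenizer because "
                             "skip_tokenizer_init is True")
        return self.tokenizer

_TokenizerHolder(skip_tokenizer_init=False).get_tokenizer()   # ok
# _TokenizerHolder(skip_tokenizer_init=True).get_tokenizer()  # raises ValueError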
async def is_tracing_enabled(self) -> bool:

View File

@ -61,11 +61,12 @@ class DPCoordinator:
host = parallel_config.data_parallel_master_ip
external_lb = parallel_config.data_parallel_external_lb
hybrid_lb = parallel_config.data_parallel_hybrid_lb
# Assume coordinator is colocated with front-end procs when not in
# external DP LB mode.
# either external or hybrid DP LB mode.
front_publish_address = get_engine_client_zmq_addr(
local_only=not external_lb, host=host)
local_only=not external_lb and not hybrid_lb, host=host)
local_only_eng = dp_size == parallel_config.data_parallel_size_local
back_publish_address = get_engine_client_zmq_addr(local_only_eng, host)

View File

@ -234,9 +234,14 @@ class EngineCore:
self.scheduler.finish_requests(request_ids,
RequestStatus.FINISHED_ABORTED)
def execute_model(self, scheduler_output: SchedulerOutput):
def execute_model_with_error_logging(
self,
model_fn: Callable[[SchedulerOutput], ModelRunnerOutput],
scheduler_output: SchedulerOutput,
) -> ModelRunnerOutput:
"""Execute the model and log detailed info on failure."""
try:
return self.model_executor.execute_model(scheduler_output)
return model_fn(scheduler_output)
except Exception as err:
# We do not want to catch BaseException here since we're only
# interested in dumping info when the exception is due to an
@ -259,7 +264,9 @@ class EngineCore:
if not self.scheduler.has_requests():
return {}, False
scheduler_output = self.scheduler.schedule()
model_output = self.execute_model(scheduler_output)
model_output = self.execute_model_with_error_logging(
self.model_executor.execute_model, # type: ignore
scheduler_output)
engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, model_output) # type: ignore
@ -306,8 +313,11 @@ class EngineCore:
# so we need more work.
if not scheduled_batch and not self.batch_queue.empty():
future, scheduler_output = self.batch_queue.get_nowait()
# Blocking until the first result is available.
model_output = future.result()
model_output = self.execute_model_with_error_logging(
lambda _: future.result(), scheduler_output)
self.batch_queue.task_done()
engine_core_outputs = (self.scheduler.update_from_output(
scheduler_output, model_output))
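The refactor above funnels both execution paths through one error-logging helper; a simplified, standalone rendering of the idea (the names and the dump hook are placeholders, not vLLM APIs) looks like this.

from typing import Callable, TypeVar

T = TypeVar("T")

def run_with_error_logging(model_fn: Callable[[object], T],
                           scheduler_output: object,
                           dump_fn: Callable[[object], None]) -> T:
    try:
        # Synchronous path: pass the execute_model callable directly.
        # Batch-queue path: pass `lambda _: future.result()` instead.
        return model_fn(scheduler_output)
    except Exception:
        dump_fn(scheduler_output)  # dump the in-flight batch metadata
        raise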
@ -467,13 +477,14 @@ class EngineCoreProc(EngineCore):
For DP>1 with internal loadbalancing this is with the shared front-end
process which may reside on a different node.
For DP>1 with external loadbalancing, two handshakes are performed:
For DP>1 with external or hybrid loadbalancing, two handshakes are
performed:
- With the rank 0 front-end process which retrieves the
DP Coordinator ZMQ addresses and DP process group address.
- With the colocated front-end process which retrieves the
client input/output socket addresses.
with the exception of the rank 0 engine itself which doesn't require
the second handshake.
with the exception of the rank 0 and colocated engines themselves which
don't require the second handshake.
Here, "front-end" process can mean the process containing the engine
core client (which is the API server process in the case the API
@ -482,15 +493,18 @@ class EngineCoreProc(EngineCore):
"""
input_ctx = zmq.Context()
is_local = local_client and client_handshake_address is None
headless = not local_client
handshake = self._perform_handshake(input_ctx, handshake_address,
identity, is_local, vllm_config,
identity, is_local, headless,
vllm_config,
vllm_config.parallel_config)
if client_handshake_address is None:
with handshake as addresses:
yield addresses
else:
assert local_client
local_handshake = self._perform_handshake(
input_ctx, client_handshake_address, identity, local_client,
input_ctx, client_handshake_address, identity, True, False,
vllm_config)
with handshake as addresses, local_handshake as client_addresses:
addresses.inputs = client_addresses.inputs
@ -507,6 +521,7 @@ class EngineCoreProc(EngineCore):
handshake_address: str,
identity: bytes,
local_client: bool,
headless: bool,
vllm_config: VllmConfig,
parallel_config_to_update: Optional[ParallelConfig] = None,
) -> Generator[EngineZmqAddresses, None, None]:
@ -518,6 +533,7 @@ class EngineCoreProc(EngineCore):
bind=False) as handshake_socket:
# Register engine with front-end.
addresses = self.startup_handshake(handshake_socket, local_client,
headless,
parallel_config_to_update)
yield addresses
@ -531,6 +547,7 @@ class EngineCoreProc(EngineCore):
msgspec.msgpack.encode({
"status": "READY",
"local": local_client,
"headless": headless,
"num_gpu_blocks": num_gpu_blocks,
"dp_stats_address": dp_stats_address,
}))
@ -539,6 +556,7 @@ class EngineCoreProc(EngineCore):
def startup_handshake(
handshake_socket: zmq.Socket,
local_client: bool,
headless: bool,
parallel_config: Optional[ParallelConfig] = None,
) -> EngineZmqAddresses:
@ -547,6 +565,7 @@ class EngineCoreProc(EngineCore):
msgspec.msgpack.encode({
"status": "HELLO",
"local": local_client,
"headless": headless,
}))
# Receive initialization message.
@ -891,22 +910,6 @@ class DPEngineCoreProc(EngineCoreProc):
logger.debug("Setting kv_transfer_config.engine_id to %s",
vllm_config.kv_transfer_config.engine_id)
from vllm.platforms import current_platform
device_control_env_var = current_platform.device_control_env_var
world_size = vllm_config.parallel_config.world_size
# Set CUDA_VISIBLE_DEVICES or equivalent.
try:
os.environ[device_control_env_var] = ",".join(
str(current_platform.device_id_to_physical_device_id(i))
for i in range(local_dp_rank *
world_size, (local_dp_rank + 1) * world_size))
except IndexError as e:
raise Exception(
f"Error setting {device_control_env_var}: "
f"local range: [{local_dp_rank * world_size}, "
f"{(local_dp_rank + 1) * world_size}) "
f"base value: \"{os.getenv(device_control_env_var)}\"") from e
self.dp_rank = dp_rank
self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()
@ -1069,14 +1072,41 @@ class DPEngineCoreActor(DPEngineCoreProc):
vllm_config.parallel_config.data_parallel_rank_local = \
local_dp_rank
# Ray sets CUDA_VISIBLE_DEVICES to empty string,
# we clean this up to be able to properly initialize
# data parallel groups.
del os.environ['CUDA_VISIBLE_DEVICES']
# Set CUDA_VISIBLE_DEVICES as early as possible in actor life cycle
# NOTE: in MP we set CUDA_VISIBLE_DEVICES at process creation time,
# and this cannot be done in the same way for Ray because:
# 1) Ray manages life cycle of all ray workers (including
# DPEngineCoreActor)
# 2) Ray sets CUDA_VISIBLE_DEVICES based on num_gpus configuration
# To bypass 2, we need to also set
# RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES, but vLLM workers created
# thereafter would have CUDA_VISIBLE_DEVICES set, which is sticky:
# https://github.com/ray-project/ray/blob/e752fc319ddedd9779a0989b6d3613909bad75c9/python/ray/_private/worker.py#L456 # noqa: E501
# But vLLM worker assumes visibility into all local GPUs, therefore
# this results in incorrect indexing into the GPU ID list.
self._set_cuda_visible_devices(vllm_config, local_dp_rank)
super().__init__(vllm_config, local_client, "", executor_class,
log_stats)
def _set_cuda_visible_devices(self, vllm_config: VllmConfig,
local_dp_rank: int):
from vllm.platforms import current_platform
device_control_env_var = current_platform.device_control_env_var
world_size = vllm_config.parallel_config.world_size
# Set CUDA_VISIBLE_DEVICES or equivalent.
try:
os.environ[device_control_env_var] = ",".join(
str(current_platform.device_id_to_physical_device_id(i))
for i in range(local_dp_rank *
world_size, (local_dp_rank + 1) * world_size))
except IndexError as e:
raise Exception(
f"Error setting {device_control_env_var}: "
f"local range: [{local_dp_rank * world_size}, "
f"{(local_dp_rank + 1) * world_size}) "
f"base value: \"{os.getenv(device_control_env_var)}\"") from e
def _decorate_logs(self):
pass

View File

@ -429,18 +429,23 @@ class MPClient(EngineCoreClient):
parallel_config = vllm_config.parallel_config
dp_size = parallel_config.data_parallel_size
dp_rank = parallel_config.data_parallel_rank
external_dp_lb = parallel_config.data_parallel_external_lb
dp_local_size = parallel_config.data_parallel_size_local
offline_mode = parallel_config.data_parallel_rank_local is not None
self.engine_ranks = ([dp_rank] if
(offline_mode or external_dp_lb) else list(
range(dp_size)))
# Client manages local+remote EngineCores in pure internal LB case.
# Client manages local EngineCores in hybrid and external LB case.
local_engines_only = (parallel_config.data_parallel_hybrid_lb
or parallel_config.data_parallel_external_lb)
num_ranks = dp_local_size if local_engines_only else dp_size
self.engine_ranks_managed = [dp_rank] if offline_mode else list(
range(dp_rank, dp_rank + num_ranks))
assert parallel_config.data_parallel_size_local <= len(
self.engine_ranks)
self.engine_ranks_managed)
# ZMQ identity of each engine that this client will talk to.
self.core_engines: list[EngineIdentity] = [
index.to_bytes(2, "little") for index in self.engine_ranks
rank.to_bytes(2, "little")
for rank in self.engine_ranks_managed
]
# Wait for ready messages from each engine on the input socket.
@ -895,6 +900,12 @@ class DPAsyncMPClient(AsyncMPClient):
return
assert self.stats_update_address is not None
assert len(self.engine_ranks_managed) > 0
# NOTE: the running and waiting counts from the Coordinator are
# global and include all EngineCores. This slice keeps just the
# counts for the cores managed by this client.
count_slice = slice(self.engine_ranks_managed[0],
self.engine_ranks_managed[-1] + 1)
async def run_engine_stats_update_task():
with make_zmq_socket(self.ctx, self.stats_update_address,
@ -959,7 +970,7 @@ class DPAsyncMPClient(AsyncMPClient):
counts, wave, running = msgspec.msgpack.decode(buf)
self.current_wave = wave
self.engines_running = running
self.lb_engines = counts
self.lb_engines = counts[count_slice]
resources.stats_update_task = asyncio.create_task(
run_engine_stats_update_task())
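The slicing above can be checked with a toy example (numbers invented): the coordinator publishes one count per global engine rank, and each client keeps only the contiguous range covering the ranks it manages.

global_counts = [3, 0, 7, 2]      # one entry per global DP rank
engine_ranks_managed = [2, 3]     # ranks owned by this front-end client

count_slice = slice(engine_ranks_managed[0], engine_ranks_managed[-1] + 1)
assert global_counts[count_slice] == [7, 2]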

View File

@ -17,7 +17,6 @@ from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import (
TokenizerGroup, init_tokenizer_from_configs)
@ -82,11 +81,14 @@ class LLMEngine:
self.dp_group = None
self.should_execute_dummy_batch = False
# Tokenizer (+ ensure liveness if running in another process).
self.tokenizer = init_tokenizer_from_configs(
model_config=vllm_config.model_config,
scheduler_config=vllm_config.scheduler_config,
lora_config=vllm_config.lora_config)
if self.model_config.skip_tokenizer_init:
self.tokenizer = None
else:
# Tokenizer (+ ensure liveness if running in another process).
self.tokenizer = init_tokenizer_from_configs(
model_config=vllm_config.model_config,
scheduler_config=vllm_config.scheduler_config,
lora_config=vllm_config.lora_config)
# Processor (convert Inputs --> EngineCoreRequests)
self.processor = Processor(vllm_config=vllm_config,
@ -189,7 +191,6 @@ class LLMEngine:
lora_request: Optional[LoRARequest] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> None:
# Validate the request_id type.
@ -200,8 +201,7 @@ class LLMEngine:
# Process raw inputs into the request.
prompt_str, request = self.processor.process_inputs(
request_id, prompt, params, arrival_time, lora_request,
tokenization_kwargs, trace_headers, prompt_adapter_request,
priority)
tokenization_kwargs, trace_headers, priority)
n = params.n if isinstance(params, SamplingParams) else 1

View File

@ -327,14 +327,16 @@ class OutputProcessor:
if request_id in self.request_states:
raise ValueError(f"Request id {request_id} already running.")
req_state = RequestState.from_new_request(
tokenizer=self.tokenizer.get_lora_tokenizer(request.lora_request),
request=request,
prompt=prompt,
parent_req=parent_req,
request_index=request_index,
queue=queue,
log_stats=self.log_stats)
tokenizer = None if not self.tokenizer else \
self.tokenizer.get_lora_tokenizer(request.lora_request)
req_state = RequestState.from_new_request(tokenizer=tokenizer,
request=request,
prompt=prompt,
parent_req=parent_req,
request_index=request_index,
queue=queue,
log_stats=self.log_stats)
self.request_states[request_id] = req_state
self.lora_states.add_request(req_state)
if parent_req:

View File

@ -16,7 +16,6 @@ from vllm.multimodal.inputs import PlaceholderRange
from vllm.multimodal.processing import EncDecMultiModalProcessor
from vllm.multimodal.utils import merge_and_sort_multimodal_metadata
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.v1.engine import EngineCoreRequest
@ -226,7 +225,6 @@ class Processor:
lora_request: Optional[LoRARequest] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
) -> tuple[Optional[str], EngineCoreRequest]:
@ -237,8 +235,6 @@ class Processor:
self._validate_params(params, lora_request)
if trace_headers is not None:
raise ValueError("V1 does not support tracing yet.")
if prompt_adapter_request is not None:
raise ValueError("V1 does not support prompt_adapter_request.")
data_parallel_size = self.vllm_config.parallel_config.data_parallel_size
if data_parallel_rank is not None and not (0 <= data_parallel_rank <
@ -253,12 +249,10 @@ class Processor:
# 1. Tokenize text prompt, with LoRA request if one exists.
# 2. For multimodal models with a merged preprocessor, preprocess
# multimodal data and expand prompt token ids accordingly.
# 3. Apply prompt adapter to prompt token ids if one exists.
processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
prompt,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
return_mm_hashes=self.use_hash,
)
from vllm.platforms import current_platform
@ -380,7 +374,6 @@ class Processor:
prompt_type: Literal["encoder", "decoder"],
):
model_config = self.model_config
tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
prompt_ids = prompt_inputs["prompt_token_ids"]
if not prompt_ids:
@ -389,9 +382,14 @@ class Processor:
else:
raise ValueError(f"The {prompt_type} prompt cannot be empty")
max_input_id = max(prompt_ids, default=0)
if max_input_id > tokenizer.max_token_id:
raise ValueError(f"Token id {max_input_id} is out of vocabulary")
if self.model_config.skip_tokenizer_init:
tokenizer = None
else:
tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
max_input_id = max(prompt_ids, default=0)
if max_input_id > tokenizer.max_token_id:
raise ValueError(
f"Token id {max_input_id} is out of vocabulary")
max_prompt_len = self.model_config.max_model_len
if len(prompt_ids) > max_prompt_len:
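
As a hedged illustration of the validation change above: with skip_tokenizer_init there is no max_token_id to compare against, so the out-of-vocabulary check is bypassed. FakeTokenizer and validate_prompt_ids below are stand-ins, not vLLM APIs:

from typing import Optional, Sequence

class FakeTokenizer:
    max_token_id = 100

def validate_prompt_ids(prompt_ids: Sequence[int],
                        tokenizer: Optional[FakeTokenizer]) -> None:
    if not prompt_ids:
        raise ValueError("The prompt cannot be empty")
    if tokenizer is not None:
        # Only possible when a tokenizer (and hence a vocab size) exists.
        max_input_id = max(prompt_ids, default=0)
        if max_input_id > tokenizer.max_token_id:
            raise ValueError(f"Token id {max_input_id} is out of vocabulary")

validate_prompt_ids([1, 2, 3], FakeTokenizer())   # ok
validate_prompt_ids([1, 200, 3], None)            # ok: check skipped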

View File

@ -10,12 +10,14 @@ from enum import Enum, auto
from multiprocessing import Process, connection
from multiprocessing.process import BaseProcess
from typing import TYPE_CHECKING, Callable, Optional, Union
from unittest.mock import patch
import msgspec
import zmq
from vllm.config import CacheConfig, ParallelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.ray.ray_env import get_env_vars_to_copy
from vllm.utils import get_mp_context, get_open_zmq_ipc_path, zmq_socket_ctx
from vllm.v1.engine.coordinator import DPCoordinator
@ -105,10 +107,13 @@ class CoreEngineProcManager:
"client_handshake_address"] = client_handshake_address
self.processes: list[BaseProcess] = []
local_dp_ranks = []
for index in range(local_engine_count):
local_index = local_start_index + index
global_index = start_index + index
# Start EngineCore in background process.
local_dp_ranks.append(local_index)
self.processes.append(
context.Process(target=target_fn,
name=f"EngineCore_{global_index}",
@ -118,9 +123,14 @@ class CoreEngineProcManager:
}))
self._finalizer = weakref.finalize(self, shutdown, self.processes)
data_parallel = vllm_config.parallel_config.data_parallel_size > 1
try:
for proc in self.processes:
proc.start()
for proc, local_dp_rank in zip(self.processes, local_dp_ranks):
with set_device_control_env_var(
vllm_config, local_dp_rank) if (
data_parallel) else contextlib.nullcontext():
proc.start()
finally:
# Kill other procs if not all are running.
if self.finished_procs():
@ -145,6 +155,30 @@ class CoreEngineProcManager:
}
@contextlib.contextmanager
def set_device_control_env_var(vllm_config: VllmConfig,
local_dp_rank: int) -> Iterator[None]:
"""
Temporarily set CUDA_VISIBLE_DEVICES or equivalent
for engine subprocess.
"""
world_size = vllm_config.parallel_config.world_size
evar = current_platform.device_control_env_var
try:
value = ",".join(
str(current_platform.device_id_to_physical_device_id(i))
for i in range(local_dp_rank * world_size, (local_dp_rank + 1) *
world_size))
except IndexError as e:
raise Exception(f"Error setting {evar}: "
f"local range: [{local_dp_rank * world_size}, "
f"{(local_dp_rank + 1) * world_size}) "
"base value: "
f"\"{os.getenv(evar)}\"") from e
with patch.dict(os.environ, values=((evar, value), )):
yield
class CoreEngineActorManager:
"""
Utility class to handle creation, readiness, and shutdown
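
The set_device_control_env_var helper above uses unittest.mock.patch.dict so the device-control variable is only set while the engine subprocess is started and is restored afterwards. A standalone sketch of that pattern, with a plain CUDA_VISIBLE_DEVICES default and simplified rank-to-device math (no physical-device remapping):

import contextlib
import os
from typing import Iterator
from unittest.mock import patch

@contextlib.contextmanager
def scoped_visible_devices(local_dp_rank: int,
                           world_size: int,
                           evar: str = "CUDA_VISIBLE_DEVICES") -> Iterator[None]:
    # Each data-parallel rank gets a contiguous slice of world_size devices.
    first = local_dp_rank * world_size
    value = ",".join(str(i) for i in range(first, first + world_size))
    # patch.dict restores the previous value (or absence) on exit, so the
    # parent process environment is left untouched after proc.start().
    with patch.dict(os.environ, values={evar: value}):
        yield

with scoped_visible_devices(local_dp_rank=1, world_size=2):
    print(os.environ["CUDA_VISIBLE_DEVICES"])  # "2,3" inside the block
print(os.environ.get("CUDA_VISIBLE_DEVICES"))  # restored afterwards
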
@ -215,10 +249,9 @@ class CoreEngineActorManager:
self.placement_group_is_local = []
refs = []
for index in range(dp_size):
local_index = local_dp_ranks[index]
for index, local_index, pg in zip(range(dp_size), local_dp_ranks,
placement_groups):
dp_vllm_config = copy.deepcopy(vllm_config)
pg = placement_groups[index]
dp_vllm_config.parallel_config.placement_group = pg
local_client = index < local_engine_count
actor = ray.remote(DPEngineCoreActor).options(
@ -264,7 +297,6 @@ class CoreEngineActorManager:
local_engine_count = \
vllm_config.parallel_config.data_parallel_size_local
nodes = list_nodes()
nodes = sorted(list_nodes(),
key=lambda node: node.node_ip != dp_master_ip)
assert nodes[0].node_ip == dp_master_ip, (
@ -544,7 +576,8 @@ def launch_core_engines(
local_start_index = parallel_config.data_parallel_rank_local
dp_rank = parallel_config.data_parallel_rank
host = parallel_config.data_parallel_master_ip
external_dp_lb = parallel_config.data_parallel_external_lb
local_engines_only = (parallel_config.data_parallel_hybrid_lb
or parallel_config.data_parallel_external_lb)
# In offline mode there is an LLM instance per DP rank and
# one core engine per LLM, see
@ -553,8 +586,8 @@ def launch_core_engines(
# client_local_only = True for cases where this front-end
# sends requests only to colocated engines.
client_local_only = offline_mode or external_dp_lb or (local_engine_count
== dp_size)
client_local_only = (offline_mode or local_engines_only
or (local_engine_count == dp_size))
# Set up input and output addresses.
addresses = EngineZmqAddresses(
@ -598,14 +631,27 @@ def launch_core_engines(
yield engine_actor_manager, coordinator, addresses
return
if offline_mode or (external_dp_lb and dp_rank > 0):
if offline_mode:
assert local_engine_count == 1
engines_to_handshake = [CoreEngine(index=dp_rank, local=True)]
else:
elif dp_rank == 0:
# Rank 0 holds Coordinator, so it handshakes with all Cores
# in both external dplb and internal dplb mode.
# Note this also covers the case where we have zero local engines
# and rank 0 is headless.
engines_to_handshake = [
CoreEngine(index=i, local=(i < local_engine_count))
for i in range(dp_size)
]
else:
# Rank > 0 handshakes with just the local cores it is managing.
assert local_engines_only, (
"Attempting to launch core_engines from dp_rank > 0, but "
"found internal DPLB, which is incompatible.")
engines_to_handshake = [
CoreEngine(index=i, local=True)
for i in range(dp_rank, dp_rank + local_engine_count)
]
# Whether the started engines will handshake only with co-located
# front-end processes. In external_dp_lb mode, ranks > 0 handshake with
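
To make the handshake branching above easier to follow, here is a hedged, pure-Python sketch of the same selection; Engine and engines_to_handshake are illustrative names, not vLLM's:

from dataclasses import dataclass

@dataclass
class Engine:
    index: int
    local: bool

def engines_to_handshake(offline_mode: bool, local_engines_only: bool,
                         dp_rank: int, dp_size: int,
                         local_engine_count: int) -> list[Engine]:
    if offline_mode:
        # One LLM instance per DP rank, one colocated engine each.
        return [Engine(index=dp_rank, local=True)]
    if dp_rank == 0:
        # Rank 0 hosts the coordinator and handshakes with every engine,
        # local or remote, in both internal and external LB modes.
        return [Engine(index=i, local=(i < local_engine_count))
                for i in range(dp_size)]
    # Ranks > 0 only manage their own colocated engines, which requires
    # external or hybrid load balancing.
    assert local_engines_only
    return [Engine(index=i, local=True)
            for i in range(dp_rank, dp_rank + local_engine_count)]

print(engines_to_handshake(False, True, dp_rank=2, dp_size=4,
                           local_engine_count=1))
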
@ -616,7 +662,7 @@ def launch_core_engines(
handshake_address = get_engine_client_zmq_addr(
handshake_local_only, host, parallel_config.data_parallel_rpc_port)
if external_dp_lb and dp_rank > 0:
if local_engines_only and dp_rank > 0:
assert not handshake_local_only
local_handshake_address = get_open_zmq_ipc_path()
client_handshake_address = local_handshake_address
@ -631,8 +677,6 @@ def launch_core_engines(
# Start local engines.
if local_engine_count:
# In server mode, start_index and local_start_index will
# both be 0.
local_engine_manager = CoreEngineProcManager(
EngineCoreProc.run_engine_core,
vllm_config=vllm_config,
@ -678,6 +722,9 @@ def wait_for_engine_startup(
poller = zmq.Poller()
poller.register(handshake_socket, zmq.POLLIN)
remote_should_be_headless = not parallel_config.data_parallel_hybrid_lb \
and not parallel_config.data_parallel_external_lb
if proc_manager is not None:
for sentinel in proc_manager.sentinels():
poller.register(sentinel, zmq.POLLIN)
@ -713,13 +760,24 @@ def wait_for_engine_startup(
raise RuntimeError(f"Message from engine with unexpected data "
f"parallel rank: {eng_index}")
msg = msgspec.msgpack.decode(ready_msg_bytes)
status, local = msg["status"], msg["local"]
status, local, headless = msg["status"], msg["local"], msg["headless"]
if local != engine.local:
raise RuntimeError(f"{status} message from "
f"{'local' if local else 'remote'} "
f"engine {eng_index}, expected it to be "
f"{'local' if engine.local else 'remote'}")
# Remote engines must be headless iff we aren't in external or hybrid dp lb mode.
if not local and headless != remote_should_be_headless:
if headless:
raise RuntimeError(f"Remote engine {eng_index} must not use "
f"--headless in external or hybrid dp lb "
f"mode")
else:
raise RuntimeError(f"Remote engine {eng_index} must use "
f"--headless unless in external or hybrid "
f"dp lb mode")
if status == "HELLO" and engine.state == CoreEngineState.NEW:
# Send init message with DP config info.

View File

@ -318,8 +318,6 @@ def report_usage_stats(
# Feature flags
"enable_lora":
bool(vllm_config.lora_config),
"enable_prompt_adapter":
bool(vllm_config.prompt_adapter_config),
"enable_prefix_caching":
vllm_config.cache_config.enable_prefix_caching,
"enforce_eager":

View File

@ -104,7 +104,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.parallel_config = vllm_config.parallel_config
self.scheduler_config = vllm_config.scheduler_config
self.speculative_config = vllm_config.speculative_config
self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config
from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
@ -126,6 +125,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.is_multimodal_model = model_config.is_multimodal_model
self.is_pooling_model = model_config.pooler_config is not None
self.model_supports_multimodal_raw_input = (
model_config.model_supports_multimodal_raw_input)
self.max_model_len = model_config.max_model_len
self.max_num_tokens = scheduler_config.max_num_batched_tokens
self.max_num_reqs = scheduler_config.max_num_seqs
@ -328,6 +329,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
Args:
scheduler_output: The scheduler output.
"""
# Attention-free models have zero kv_cache_groups. However, models
# like Mamba are also attention-free but use the kv_cache for
# keeping their internal state, which is why we check the number
# of kv_cache groups instead of relying solely on
# self.model_config.is_attention_free.
if len(self.kv_cache_config.kv_cache_groups) == 0:
return
self.attn_metadata_builders[0].reorder_batch(self.input_batch,
scheduler_output)
@ -565,6 +574,38 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Refresh batch metadata with any pending updates.
self.input_batch.refresh_metadata()
def _init_model_kwargs_for_multimodal_model(
self,
scheduler_output: Optional["SchedulerOutput"] = None,
num_reqs: int = -1,
) -> dict[str, Any]:
model_kwargs: dict[str, Any] = {}
if self.model_supports_multimodal_raw_input:
# This model requires the raw multimodal data in input.
if scheduler_output:
multi_modal_kwargs_list = []
for req in scheduler_output.scheduled_new_reqs:
req_mm_inputs = req.mm_inputs
if not isinstance(req_mm_inputs, list):
req_mm_inputs = list(req_mm_inputs)
multi_modal_kwargs_list.extend(req_mm_inputs)
multi_modal_kwargs = MultiModalKwargs.batch(
multi_modal_kwargs_list)
else:
# The only case where SchedulerOutput is None is for
# a dummy run, so generate some dummy data.
dummy_data = [
self.mm_registry.get_decoder_dummy_data(
model_config=self.model_config,
seq_len=1).multi_modal_data for i in range(num_reqs)
]
multi_modal_kwargs = MultiModalKwargs.batch(dummy_data)
model_kwargs.update(multi_modal_kwargs)
return model_kwargs
def _get_cumsum_and_arange(
self,
num_tokens: np.ndarray,
@ -1359,10 +1400,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# embeddings), we always use embeddings (rather than token ids)
# as input to the multimodal model, even when the input is text.
input_ids = self.input_ids[:num_scheduled_tokens]
model_kwargs = self._init_model_kwargs_for_multimodal_model(
scheduler_output=scheduler_output)
inputs_embeds = self.model.get_input_embeddings(
input_ids=input_ids,
multimodal_embeddings=mm_embeds or None,
)
# TODO(woosuk): Avoid the copy. Optimize.
self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds)
inputs_embeds = self.inputs_embeds[:num_input_tokens]
@ -1374,6 +1419,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# then the embedding layer is not included in the CUDA graph.
input_ids = self.input_ids[:num_input_tokens]
inputs_embeds = None
model_kwargs = {}
if self.uses_mrope:
positions = self.mrope_positions[:, :num_input_tokens]
else:
@ -1406,6 +1452,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
**MultiModalKwargs.as_kwargs(
model_kwargs,
device=self.device,
),
)
self.maybe_wait_for_kv_save()
@ -1822,17 +1872,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
with DeviceMemoryProfiler() as m:
time_before_load = time.perf_counter()
model_loader = get_model_loader(self.load_config)
if not hasattr(self, "model"):
logger.info("Loading model from scratch...")
self.model = model_loader.load_model(
vllm_config=self.vllm_config,
model_config=self.model_config)
else:
logger.info(
"Model was already initialized. Loading weights inplace..."
)
model_loader.load_weights(self.model,
model_config=self.model_config)
logger.info("Loading model from scratch...")
self.model = model_loader.load_model(
vllm_config=self.vllm_config, model_config=self.model_config)
if self.lora_config:
self.model = self.load_lora_model(self.model,
self.model_config,
@ -1865,6 +1907,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
rank_mapping,
)
def reload_weights(self) -> None:
assert getattr(self, "model", None) is not None, \
"Cannot reload weights before model is loaded."
model_loader = get_model_loader(self.load_config)
logger.info("Reloading weights inplace...")
model_loader.load_weights(self.model, model_config=self.model_config)
def save_tensorized_model(
self,
tensorizer_config: "TensorizerConfig",
@ -2084,11 +2133,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_scheduled_tokens):
model = self.model
if self.is_multimodal_model:
model_kwargs = self._init_model_kwargs_for_multimodal_model(
num_reqs=num_reqs)
input_ids = None
inputs_embeds = self.inputs_embeds[:num_tokens]
else:
input_ids = self.input_ids[:num_tokens]
inputs_embeds = None
model_kwargs = {}
if self.uses_mrope:
positions = self.mrope_positions[:, :num_tokens]
else:
@ -2117,7 +2170,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
**MultiModalKwargs.as_kwargs(
model_kwargs,
device=self.device,
),
)
if self.use_aux_hidden_state_outputs:
hidden_states, _ = outputs
else:
@ -2381,10 +2439,25 @@ class GPUModelRunner(LoRAModelRunnerMixin):
start_time = time.perf_counter()
start_free_gpu_memory = torch.cuda.mem_get_info()[0]
@contextmanager
def freeze_gc():
# Optimize garbage collection during CUDA graph capture.
# Clean up, then freeze all remaining objects from being included
# in future collections.
gc.collect()
should_freeze = not envs.VLLM_ENABLE_CUDAGRAPH_GC
if should_freeze:
gc.freeze()
try:
yield
finally:
if should_freeze:
gc.unfreeze()
# Trigger CUDA graph capture for specific shapes.
# Capture the large shapes first so that the smaller shapes
# can reuse the memory pool allocated for the large shapes.
with graph_capture(device=self.device):
with freeze_gc(), graph_capture(device=self.device):
full_cg = self.full_cuda_graph
# Only rank 0 should print progress bar during capture
compilation_cases = reversed(self.cudagraph_batch_sizes)
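
The freeze_gc context manager above collects once and then freezes the survivors so the cyclic collector does not rescan them during CUDA graph capture. A generic sketch of the pattern, with the VLLM_ENABLE_CUDAGRAPH_GC escape hatch replaced by a plain flag:

import contextlib
import gc

@contextlib.contextmanager
def freeze_gc(enable_gc_during_capture: bool = False):
    # Collect first so only long-lived objects remain, then move them to
    # the permanent generation; gc.unfreeze() puts them back afterwards.
    gc.collect()
    should_freeze = not enable_gc_during_capture
    if should_freeze:
        gc.freeze()
    try:
        yield
    finally:
        if should_freeze:
            gc.unfreeze()

with freeze_gc():
    # Expensive, allocation-heavy work (e.g. graph capture) goes here;
    # frozen objects are skipped by any collections that do run.
    junk = [bytearray(1024) for _ in range(1000)]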

View File

@ -4,6 +4,7 @@
import copy
import gc
import os
from contextlib import AbstractContextManager, nullcontext
from typing import TYPE_CHECKING, Any, Optional
import torch
@ -15,7 +16,8 @@ from vllm.config import VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment,
set_custom_all_reduce)
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
has_kv_transfer_group)
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@ -118,6 +120,21 @@ class Worker(WorkerBase):
buffer.data.copy_(self._sleep_saved_buffers[name].data)
self._sleep_saved_buffers = {}
def _maybe_get_memory_pool_context(self,
tag: str) -> AbstractContextManager:
if self.vllm_config.model_config.enable_sleep_mode:
from vllm.device_allocator.cumem import CuMemAllocator
allocator = CuMemAllocator.get_instance()
if tag == "weights":
assert allocator.get_current_usage() == 0, (
"Sleep mode can only be "
"used for one instance per process.")
context = allocator.use_memory_pool(tag=tag)
else:
context = nullcontext()
return context
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
self.cache_config.num_gpu_blocks = num_gpu_blocks
@ -179,24 +196,17 @@ class Worker(WorkerBase):
# FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
# to hijack tensor allocation.
def load_model(self) -> None:
if self.vllm_config.model_config.enable_sleep_mode:
from vllm.device_allocator.cumem import CuMemAllocator
allocator = CuMemAllocator.get_instance()
assert allocator.get_current_usage() == 0, (
"Sleep mode can only be "
"used for one instance per process.")
context = allocator.use_memory_pool(tag="weights")
else:
from contextlib import nullcontext
context = nullcontext()
eep_scale_up = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
with context:
with self._maybe_get_memory_pool_context(tag="weights"):
self.model_runner.load_model(eep_scale_up=eep_scale_up)
def update_config(self, overrides: dict[str, Any]) -> None:
self.model_runner.update_config(overrides)
def reload_weights(self) -> None:
with self._maybe_get_memory_pool_context(tag="weights"):
self.model_runner.reload_weights()
@torch.inference_mode()
def determine_available_memory(self) -> int:
"""Profiles the peak memory usage of the model to determine how much
@ -333,19 +343,20 @@ class Worker(WorkerBase):
assert isinstance(output, IntermediateTensors)
get_pp_group().send_tensor_dict(output.tensors,
all_gather_group=get_tp_group())
if not has_kv_transfer_group():
return None
# In case of PP with kv transfer, we need to pass through the
# finished_sending and finished_recving buffers.
empty_output = EMPTY_MODEL_RUNNER_OUTPUT
new_output = EMPTY_MODEL_RUNNER_OUTPUT
if output.finished_sending or output.finished_recving:
empty_output = copy.copy(empty_output)
empty_output.finished_sending = output.finished_sending
empty_output.finished_recving = output.finished_recving
output = empty_output
new_output = copy.copy(new_output)
new_output.finished_sending = output.finished_sending
new_output.finished_recving = output.finished_recving
output = new_output
assert isinstance(output, ModelRunnerOutput)
# return output only from the driver worker
return output if self.is_driver_worker else None
return output
def profile(self, is_start: bool = True):
if self.profiler is None:
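
The Worker changes above route both load_model and the new reload_weights through _maybe_get_memory_pool_context, which returns either a CuMemAllocator pool or a no-op context. A hedged sketch of that conditional-context pattern with a stand-in allocator:

from contextlib import AbstractContextManager, contextmanager, nullcontext

class FakeAllocator:
    @contextmanager
    def use_memory_pool(self, tag: str):
        print(f"entering pool for {tag!r}")
        yield
        print(f"leaving pool for {tag!r}")

def maybe_memory_pool(sleep_mode: bool, tag: str,
                      allocator: FakeAllocator) -> AbstractContextManager:
    # Only route allocations through the pool when sleep mode is enabled;
    # otherwise nullcontext() keeps the call sites unconditional.
    return allocator.use_memory_pool(tag=tag) if sleep_mode else nullcontext()

with maybe_memory_pool(sleep_mode=True, tag="weights",
                       allocator=FakeAllocator()):
    pass  # model_runner.load_model() / reload_weights() would run here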

View File

@ -114,7 +114,6 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self.original_parallel_config = original_parallel_config
self.scheduler_config = vllm_config.scheduler_config
self.speculative_config = vllm_config.speculative_config
self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config
self.device_config = vllm_config.device_config
@ -1174,16 +1173,10 @@ class TPUModelRunner(LoRAModelRunnerMixin):
mesh=self.mesh)
else:
model_loader = get_model_loader(self.load_config)
if not hasattr(self, "model"):
logger.info("Loading model from scratch...")
model = model_loader.load_model(
vllm_config=self.vllm_config,
model_config=self.model_config)
else:
logger.info("Model was already initialized. \
Loading weights inplace...")
model_loader.load_weights(
self.model, model_config=self.model_config)
logger.info("Loading model from scratch...")
model = model_loader.load_model(
vllm_config=self.vllm_config,
model_config=self.model_config)
except RuntimeError as e:
raise RuntimeError(
f"Unable to load model, a likely reason is the model is "
@ -1205,6 +1198,13 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self.model = model
self.sampler = TPUSampler()
def reload_weights(self) -> None:
assert getattr(self, "model", None) is not None, \
"Cannot reload weights before model is loaded."
model_loader = get_model_loader(self.load_config)
logger.info("Reloading weights inplace...")
model_loader.load_weights(self.model, model_config=self.model_config)
@torch.no_grad()
def _dummy_run(self, num_tokens: int, num_reqs: int,
num_blocks: int) -> None:
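
The GPU and TPU runners above both gain a reload_weights method that asserts the model is already loaded and then streams new weights into it in place. An illustrative torch-only sketch of the same idea, using load_state_dict rather than vLLM's model loader:

from typing import Optional

import torch
import torch.nn as nn

class TinyRunner:
    def __init__(self):
        self.model: Optional[nn.Module] = None

    def load_model(self) -> None:
        self.model = nn.Linear(4, 4)

    def reload_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
        assert self.model is not None, \
            "Cannot reload weights before model is loaded."
        # Copy tensors into the existing parameters; no re-allocation and
        # nothing built around self.model needs to be rebuilt.
        self.model.load_state_dict(state_dict)

runner = TinyRunner()
runner.load_model()
new_weights = {k: torch.zeros_like(v)
               for k, v in runner.model.state_dict().items()}
runner.reload_weights(new_weights)
print(runner.model.weight.abs().sum().item())  # 0.0 after the in-place reload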

View File

@ -62,7 +62,6 @@ class TPUWorker:
self.scheduler_config = vllm_config.scheduler_config
self.device_config = vllm_config.device_config
self.speculative_config = vllm_config.speculative_config
self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config
self.parallel_config.rank = rank
@ -265,6 +264,9 @@ class TPUWorker:
def update_config(self, overrides: dict[str, Any]) -> None:
self.model_runner.update_config(overrides)
def reload_weights(self) -> None:
self.model_runner.reload_weights()
def compile_or_warm_up_model(self) -> None:
if not self.model_config.enforce_eager:
self.model_runner.capture_model()

View File

@ -91,10 +91,9 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
'''
EncoderDecoderModelRunner constructor.
`lora_config` and `prompt_adapter_config` are
unused (since these features are not yet supported for encoder/decoder
models) but these arguments are present here for compatibility with
the base-class constructor.
`lora_config` is unused (since this feature is not yet supported
for encoder/decoder models) but the argument is present here for
compatibility with the base-class constructor.
'''
self._maybe_force_supported_attention_backend()

View File

@ -45,10 +45,6 @@ from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalKwargs, MultiModalPlaceholderMap,
MultiModalRegistry)
from vllm.prompt_adapter.layers import PromptAdapterMapping
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.prompt_adapter.worker_manager import (
LRUCacheWorkerPromptAdapterManager)
from vllm.sampling_params import SamplingParams
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.utils import (DeviceMemoryProfiler, GiB_bytes, PyObjectCache,
@ -95,8 +91,6 @@ class ModelInputForGPU(ModelRunnerInputBase):
lora_mapping: Optional["LoRAMapping"] = None
lora_requests: Optional[Set[LoRARequest]] = None
attn_metadata: Optional["AttentionMetadata"] = None
prompt_adapter_mapping: Optional[PromptAdapterMapping] = None
prompt_adapter_requests: Optional[Set[PromptAdapterRequest]] = None
multi_modal_kwargs: Optional[BatchedTensorInputs] = None
request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None
finished_requests_ids: Optional[List[str]] = None
@ -113,8 +107,6 @@ class ModelInputForGPU(ModelRunnerInputBase):
"lora_requests": self.lora_requests,
"lora_mapping": self.lora_mapping,
"multi_modal_kwargs": self.multi_modal_kwargs,
"prompt_adapter_mapping": self.prompt_adapter_mapping,
"prompt_adapter_requests": self.prompt_adapter_requests,
"virtual_engine": self.virtual_engine,
"request_ids_to_seq_ids": self.request_ids_to_seq_ids,
"finished_requests_ids": self.finished_requests_ids,
@ -164,8 +156,6 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
"lora_requests": self.lora_requests,
"lora_mapping": self.lora_mapping,
"multi_modal_kwargs": self.multi_modal_kwargs,
"prompt_adapter_mapping": self.prompt_adapter_mapping,
"prompt_adapter_requests": self.prompt_adapter_requests,
"virtual_engine": self.virtual_engine,
"request_ids_to_seq_ids": self.request_ids_to_seq_ids,
"finished_requests_ids": self.finished_requests_ids,
@ -212,8 +202,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.lora_index_mapping.clear() # type: ignore
self.lora_prompt_mapping.clear() # type: ignore
self.lora_requests.clear() # type: ignore
self.prompt_adapter_index_mapping.clear() # type: ignore
self.prompt_adapter_prompt_mapping.clear() # type: ignore
def __init__(
self,
@ -252,11 +240,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
lora_prompt_mapping: Optional[List[List[int]]] = None,
lora_requests: Optional[Set[LoRARequest]] = None,
# Prompt adapter inputs.
prompt_adapter_index_mapping: Optional[List[int]] = None,
prompt_adapter_prompt_mapping: Optional[List[int]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
# Multi-modal inputs.
multi_modal_kwargs: Optional[MultiModalKwargs] = None,
multi_modal_placeholder_maps: Optional[Dict[
@ -360,18 +343,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
else:
self.lora_requests.clear()
if prompt_adapter_index_mapping:
self.prompt_adapter_index_mapping = \
prompt_adapter_index_mapping
else:
self.prompt_adapter_index_mapping.clear()
if prompt_adapter_prompt_mapping:
self.prompt_adapter_prompt_mapping = \
prompt_adapter_prompt_mapping
else:
self.prompt_adapter_prompt_mapping.clear()
else:
self.input_tokens = input_tokens or []
self.inputs_embeds = inputs_embeds
@ -390,12 +361,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.lora_prompt_mapping = lora_prompt_mapping or []
self.lora_requests = lora_requests or set()
self.prompt_adapter_index_mapping = (
prompt_adapter_index_mapping or [])
self.prompt_adapter_prompt_mapping = (
prompt_adapter_prompt_mapping or [])
self.prompt_adapter_request = prompt_adapter_request
self.multi_modal_kwargs = multi_modal_kwargs
self.multi_modal_placeholder_maps = multi_modal_placeholder_maps
self.prefix_cache_hit = prefix_cache_hit
@ -485,7 +450,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
# Compute functions for each sequence group.
# WARNING: The order of the functions matters!
self.per_seq_group_compute_fns = [
self._compute_prompt_adapter_input,
self._compute_multi_modal_input,
]
@ -496,8 +460,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.sliding_window = self.runner.sliding_window
self.block_size = self.runner.block_size
self.enable_lora = self.runner.lora_config is not None
self.enable_prompt_adapter = (self.runner.prompt_adapter_config
is not None)
# Attention metadata inputs.
if self.attn_backend is not None:
@ -693,34 +655,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
else:
inter_data.lora_prompt_mapping.append([])
def _compute_prompt_adapter_input(
self, inter_data: InterDataForSeqGroup,
seq_group_metadata: SequenceGroupMetadata):
"""If prompt adapter is enabled, compute index and prompt mapping.
"""
# Note that when is_prompt=True, we expect only one sequence
# in the group.
if not self.enable_prompt_adapter:
return
prompt_adapter_id = seq_group_metadata.prompt_adapter_id
if prompt_adapter_id <= 0 or not inter_data.is_prompt:
return
# We expect only one sequence in the group when is_prompt=True.
assert inter_data.n_seqs == 1
query_len = inter_data.query_lens[0]
inter_data.prompt_adapter_request = (
seq_group_metadata.prompt_adapter_request)
num_tokens = seq_group_metadata.prompt_adapter_num_virtual_tokens
inter_data.prompt_adapter_index_mapping = [
prompt_adapter_id
] * num_tokens + [0] * (query_len - num_tokens)
inter_data.prompt_adapter_prompt_mapping = [prompt_adapter_id] * (
query_len if seq_group_metadata.sampling_params
and seq_group_metadata.sampling_params.prompt_logprobs else 1)
def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup,
seq_group_metadata: SequenceGroupMetadata):
"""If multi-modal data is given, add it to the input."""
@ -1009,29 +943,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
prompt_mapping=lora_prompt_mapping,
is_prefill=not self.decode_only))
# Prompt adapter data.
prompt_adapter_requests: Set[PromptAdapterRequest] = set()
prompt_adapter_mapping = None
if self.enable_prompt_adapter:
prompt_adapter_requests = set(
data.prompt_adapter_request for data in self.inter_data_list
if data.prompt_adapter_request is not None)
prompt_adapter_index_mapping = flatten_2d_lists([
inter_data.prompt_adapter_index_mapping
for inter_data in self.inter_data_list
])
if cuda_graph_pad_size:
prompt_adapter_index_mapping.extend(
itertools.repeat(0, cuda_graph_pad_size))
prompt_adapter_prompt_mapping = flatten_2d_lists([
inter_data.prompt_adapter_prompt_mapping
for inter_data in self.inter_data_list
])
prompt_adapter_mapping = PromptAdapterMapping(
prompt_adapter_index_mapping,
prompt_adapter_prompt_mapping,
)
# Multi-modal data.
multi_modal_kwargs_list = [
data.multi_modal_kwargs for data in self.inter_data_list
@ -1051,9 +962,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
lora_requests=lora_requests,
multi_modal_kwargs=multi_modal_kwargs,
request_ids_to_seq_ids=request_ids_to_seq_ids,
finished_requests_ids=self.finished_requests_ids,
prompt_adapter_mapping=prompt_adapter_mapping,
prompt_adapter_requests=prompt_adapter_requests)
finished_requests_ids=self.finished_requests_ids)
class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
@ -1148,7 +1057,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.model: nn.Module # Set after load_model
# Set after load_model.
self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
self.prompt_adapter_manager: LRUCacheWorkerPromptAdapterManager = None
self.sampler = get_sampler()
set_cpu_offload_max_bytes(
@ -1207,14 +1115,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
logger.info("Model loading took %.4f GiB and %.6f seconds",
self.model_memory_usage / GiB_bytes,
time_after_load - time_before_load)
if self.prompt_adapter_config:
self.prompt_adapter_manager = LRUCacheWorkerPromptAdapterManager(
self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens, self.device,
self.prompt_adapter_config)
self.model = (
self.prompt_adapter_manager.create_prompt_adapter_manager(
self.model))
if self.vllm_config.compilation_config.level ==\
CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
@ -1466,40 +1367,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.list_adapters()
def remove_all_prompt_adapters(self):
if not self.prompt_adapter_manager:
raise RuntimeError("PromptAdapter is not enabled.")
self.prompt_adapter_manager.remove_all_adapters()
def set_active_prompt_adapters(
self, prompt_adapter_requests: Set[PromptAdapterRequest],
prompt_adapter_mapping: PromptAdapterMapping) -> None:
if not self.prompt_adapter_manager:
raise RuntimeError("PromptAdapter is not enabled.")
self.prompt_adapter_manager.set_active_adapters(
prompt_adapter_requests, prompt_adapter_mapping)
def add_prompt_adapter(
self, prompt_adapter_request: PromptAdapterRequest) -> bool:
if not self.prompt_adapter_manager:
raise RuntimeError("PromptAdapter is not enabled.")
return self.prompt_adapter_manager.add_adapter(prompt_adapter_request)
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
if not self.prompt_adapter_manager:
raise RuntimeError("PromptAdapter is not enabled.")
return self.prompt_adapter_manager.remove_adapter(prompt_adapter_id)
def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
if not self.prompt_adapter_manager:
raise RuntimeError("PromptAdapter is not enabled.")
return self.prompt_adapter_manager.pin_adapter(prompt_adapter_id)
def list_prompt_adapters(self) -> Set[int]:
if not self.prompt_adapter_manager:
raise RuntimeError("PromptAdapter is not enabled.")
return self.prompt_adapter_manager.list_adapters()
@torch.inference_mode()
def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
"""Cuda graph capture a model.
@ -1609,13 +1476,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.set_active_loras(set([dummy_lora_request]),
lora_mapping)
if self.prompt_adapter_config:
prompt_adapter_mapping = PromptAdapterMapping(
[-1] * batch_size,
[-1] * batch_size,
)
self.set_active_prompt_adapters(
set(), prompt_adapter_mapping)
graph_runner = CUDAGraphRunner(
self.model, self.attn_backend.get_name(),
self.attn_state.graph_clone(batch_size),
@ -1776,13 +1636,6 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
self.set_active_loras(model_input.lora_requests,
model_input.lora_mapping)
if self.prompt_adapter_config:
assert model_input.prompt_adapter_requests is not None
assert model_input.prompt_adapter_mapping is not None
self.set_active_prompt_adapters(
model_input.prompt_adapter_requests,
model_input.prompt_adapter_mapping)
self.attn_state.begin_forward(model_input)
# Currently cuda graph is only supported by the decode phase.
@ -1932,24 +1785,32 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
if model_input.inputs_embeds is not None:
if self.is_driver_worker:
sampled = broadcast_tensor_dict(
{"token_ids": output.sampled_token_ids})
sampled_token_ids = []
valid_outputs = []
for sequence_group_output in output.outputs:
if len(sequence_group_output.samples) == 0:
continue
assert len(sequence_group_output.samples) == 1
valid_outputs.append(sequence_group_output)
sampled_token_ids.append(
sequence_group_output.samples[0].output_token)
sampled_token_ids = torch.tensor(sampled_token_ids).to(
self.device)
sampled_token_ids = broadcast_tensor_dict(
{"sampled_token_ids":
sampled_token_ids})["sampled_token_ids"]
else:
sampled = broadcast_tensor_dict()
if sampled["token_ids"] is not None:
sampled_token_embeds = self.model.get_input_embeddings(
sampled["token_ids"].squeeze(1))
sampled_token_ids = broadcast_tensor_dict(
)["sampled_token_ids"]
if len(sampled_token_ids) > 0:
sampled_token_embeds = \
self.model.get_input_embeddings(sampled_token_ids)
if self.is_driver_worker:
self.sampler.include_gpu_probs_tensor = \
orig_include_gpu_probs
output.sampled_token_embeds = sampled_token_embeds
for token_embed, sequence_group_output in zip(
output.sampled_token_embeds, output.outputs):
assert len(sequence_group_output.samples) == 1
sequence_group_output.samples[
0].output_embed = token_embed
for i, sequence_group_output in enumerate(valid_outputs):
sequence_group_output.samples[0].output_embed = \
sampled_token_embeds[i]
if not self.is_driver_worker:
return []
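
The prompt-embeds path above now drops sequence groups with no samples and pairs the broadcast token ids (and their embeddings) back to the surviving outputs by position. A small hedged sketch of that filter-and-zip bookkeeping with toy output types:

import torch

class Sample:
    def __init__(self, output_token: int):
        self.output_token = output_token
        self.output_embed = None

class SequenceGroupOutput:
    def __init__(self, samples):
        self.samples = samples

outputs = [SequenceGroupOutput([Sample(5)]),
           SequenceGroupOutput([]),            # no sample: must be skipped
           SequenceGroupOutput([Sample(9)])]

valid_outputs = [o for o in outputs if o.samples]
sampled_token_ids = torch.tensor(
    [o.samples[0].output_token for o in valid_outputs])

embedding = torch.nn.Embedding(16, 8)
sampled_token_embeds = embedding(sampled_token_ids)

# Attach each embedding to the sequence group it came from, by position.
for i, group in enumerate(valid_outputs):
    group.samples[0].output_embed = sampled_token_embeds[i]

print([o.samples[0].output_embed.shape if o.samples else None
       for o in outputs])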

View File

@ -190,7 +190,6 @@ class ModelRunnerBase(ABC, Generic[T]):
self.scheduler_config = vllm_config.scheduler_config
self.device_config = vllm_config.device_config
self.speculative_config = vllm_config.speculative_config
self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config
# Map of request_id -> generator used for seeded random sampling

View File

@ -288,9 +288,6 @@ class StatefulModelInput(BroadcastableModelInput):
assert fmi.lora_requests is not None
assert len(fmi.lora_requests) == 0
assert fmi.attn_metadata is not None
assert fmi.prompt_adapter_mapping is None
assert fmi.prompt_adapter_requests is not None
assert len(fmi.prompt_adapter_requests) == 0
assert fmi.multi_modal_kwargs is not None
assert len(fmi.multi_modal_kwargs) == 0

View File

@ -64,13 +64,6 @@ class PoolingModelRunner(
self.set_active_loras(model_input.lora_requests,
model_input.lora_mapping)
if self.prompt_adapter_config:
assert model_input.prompt_adapter_requests is not None
assert model_input.prompt_adapter_mapping is not None
self.set_active_prompt_adapters(
model_input.prompt_adapter_requests,
model_input.prompt_adapter_mapping)
# Currently cuda graph is only supported by the decode phase.
assert model_input.attn_metadata is not None
prefill_meta = model_input.attn_metadata.prefill_metadata

View File

@ -47,7 +47,3 @@ def assert_enc_dec_mr_supported_scenario(
if enc_dec_mr.scheduler_config.num_lookahead_slots > 0:
raise NotImplementedError(
STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_SPEC_DEC'])
if enc_dec_mr.prompt_adapter_config is not None:
raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_ERR_STRS[
'STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER'])

Some files were not shown because too many files have changed in this diff.