mirror of https://github.com/vllm-project/vllm.git
synced 2025-10-21 07:13:52 +08:00

Compare commits: debug-logg ... v0.10.0rc2 (19 commits)

Commit SHA1s:
6d8d0a24c0
11ef7a611e
dc2f159f8a
d5b981f8b1
eec6942014
fd48d99ffd
f8c15c4efb
aa08a954f9
13e4ee1dc3
772ce5af97
63d92abb7c
11599b0e1f
f3137cdd81
82ec66f514
78c13e30e1
5c9b807b34
14bf19e39f
4ac7713e32
8560a5b258
@@ -62,7 +62,8 @@ echo "Results will be stored in: $RESULTS_DIR"
echo "--- Installing Python dependencies ---"
python3 -m pip install --progress-bar off git+https://github.com/thuml/depyf.git \
  && python3 -m pip install --progress-bar off pytest pytest-asyncio tpu-info \
-  && python3 -m pip install --progress-bar off lm_eval[api]==0.4.4
+  && python3 -m pip install --progress-bar off lm_eval[api]==0.4.4 \
+  && python3 -m pip install --progress-bar off hf-transfer
echo "--- Python dependencies installed ---"
export VLLM_USE_V1=1
export VLLM_XLA_CHECK_RECOMPILATION=1

@@ -150,7 +151,7 @@ run_and_track_test 9 "test_multimodal.py" \
run_and_track_test 10 "test_pallas.py" \
  "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py"
run_and_track_test 11 "test_struct_output_generate.py" \
-  "python3 -m pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py -k \"not test_structured_output_with_reasoning_matrices\""
+  "HF_HUB_DISABLE_XET=1 python3 -m pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py -k \"not test_structured_output_with_reasoning_matrices\""
run_and_track_test 12 "test_moe_pallas.py" \
  "python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_moe_pallas.py"
run_and_track_test 13 "test_lora.py" \
@@ -31,4 +31,13 @@ docker run \
VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp
cd tests
pytest -v -s v1/core
pytest -v -s v1/engine
pytest -v -s v1/sample --ignore=v1/sample/test_logprobs.py --ignore=v1/sample/test_logprobs_e2e.py
pytest -v -s v1/worker --ignore=v1/worker/test_gpu_model_runner.py
pytest -v -s v1/structured_output
pytest -v -s v1/spec_decode --ignore=v1/spec_decode/test_max_len.py --ignore=v1/spec_decode/test_eagle.py
pytest -v -s v1/kv_connector/unit --ignore=v1/kv_connector/unit/test_multi_connector.py --ignore=v1/kv_connector/unit/test_nixl_connector.py
pytest -v -s v1/test_serial_utils.py
pytest -v -s v1/test_utils.py
pytest -v -s v1/test_metrics_reader.py
'
@@ -166,6 +166,7 @@ steps:
- tests/v1/test_async_llm_dp.py
- tests/v1/test_external_lb_dp.py
- tests/v1/test_internal_lb_dp.py
- tests/v1/test_hybrid_lb_dp.py
- tests/v1/engine/test_engine_core_client.py
commands:
# test with tp=2 and external_dp=2

@@ -178,6 +179,7 @@ steps:
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_external_lb_dp.py
- TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/test_internal_lb_dp.py
- TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/test_hybrid_lb_dp.py
- pytest -v -s v1/engine/test_engine_core_client.py::test_kv_cache_events_dp
- pytest -v -s distributed/test_utils.py
- pytest -v -s compile/test_basic_correctness.py

@@ -718,6 +720,7 @@ steps:
- VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
- VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/test_disagg.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown
- pytest -v -s models/multimodal/generation/test_maverick.py

- label: Plugin Tests (2 GPUs) # 40min
mirror_hardwares: [amdexperimental]
@@ -265,7 +265,7 @@ RUN if [ "$RUN_WHEEL_CHECK" = "true" ]; then \
#################### EXTENSION Build IMAGE ####################

#################### DEV IMAGE ####################
-FROM base as dev
+FROM base AS dev

ARG PIP_INDEX_URL UV_INDEX_URL
ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL
@@ -47,7 +47,7 @@ FROM vllm-base AS vllm-openai

# install additional dependencies for openai api server
RUN --mount=type=cache,target=/root/.cache/pip \
-  pip install accelerate hf_transfer pytest modelscope
+  pip install accelerate hf_transfer pytest pytest_asyncio lm_eval[api] modelscope

ENV VLLM_USAGE_SOURCE production-docker-image \
  TRITON_XPU_PROFILE 1
@@ -14,7 +14,6 @@ API documentation for vLLM's configuration classes.
- [vllm.config.DeviceConfig][]
- [vllm.config.SpeculativeConfig][]
- [vllm.config.LoRAConfig][]
- [vllm.config.PromptAdapterConfig][]
- [vllm.config.MultiModalConfig][]
- [vllm.config.PoolerConfig][]
- [vllm.config.DecodingConfig][]
@@ -34,23 +34,22 @@ th:not(:first-child) {
}
</style>

-| Feature | [CP][chunked-prefill] | [APC](automatic_prefix_caching.md) | [LoRA](lora.md) | <abbr title="Prompt Adapter">prmpt adptr</abbr> | [SD](spec_decode.md) | CUDA graph | <abbr title="Pooling Models">pooling</abbr> | <abbr title="Encoder-Decoder Models">enc-dec</abbr> | <abbr title="Logprobs">logP</abbr> | <abbr title="Prompt Logprobs">prmpt logP</abbr> | <abbr title="Async Output Processing">async output</abbr> | multi-step | <abbr title="Multimodal Inputs">mm</abbr> | best-of | beam-search |
-|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
+| Feature | [CP][chunked-prefill] | [APC](automatic_prefix_caching.md) | [LoRA](lora.md) | [SD](spec_decode.md) | CUDA graph | <abbr title="Pooling Models">pooling</abbr> | <abbr title="Encoder-Decoder Models">enc-dec</abbr> | <abbr title="Logprobs">logP</abbr> | <abbr title="Prompt Logprobs">prmpt logP</abbr> | <abbr title="Async Output Processing">async output</abbr> | multi-step | <abbr title="Multimodal Inputs">mm</abbr> | best-of | beam-search |
+|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| [CP][chunked-prefill] | ✅ | | | | | | | | | | | | | | |
| [APC](automatic_prefix_caching.md) | ✅ | ✅ | | | | | | | | | | | | | |
| [LoRA](lora.md) | ✅ | ✅ | ✅ | | | | | | | | | | | | |
-| <abbr title="Prompt Adapter">prmpt adptr</abbr> | ✅ | ✅ | ✅ | ✅ | | | | | | | | | | | |
-| [SD](spec_decode.md) | ✅ | ✅ | ❌ | ✅ | ✅ | | | | | | | | | | |
-| CUDA graph | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | | | | | | | |
-| <abbr title="Pooling Models">pooling</abbr> | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | | | | | | | | |
-| <abbr title="Encoder-Decoder Models">enc-dec</abbr> | ❌ | [❌](gh-issue:7366) | ❌ | ❌ | [❌](gh-issue:7366) | ✅ | ✅ | ✅ | | | | | | | |
-| <abbr title="Logprobs">logP</abbr> | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | | | | | | |
-| <abbr title="Prompt Logprobs">prmpt logP</abbr> | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | | | | | |
-| <abbr title="Async Output Processing">async output</abbr> | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | | | | |
-| multi-step | ❌ | ✅ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | | | |
-| <abbr title="Multimodal Inputs">mm</abbr> | ✅ | [🟠](gh-pr:8348) | [🟠](gh-pr:4194) | ❔ | ❔ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❔ | ✅ | | |
-| best-of | ✅ | ✅ | ✅ | ✅ | [❌](gh-issue:6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [❌](gh-issue:7968) | ✅ | ✅ | |
-| beam-search | ✅ | ✅ | ✅ | ✅ | [❌](gh-issue:6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [❌](gh-issue:7968) | ❔ | ✅ | ✅ |
+| [SD](spec_decode.md) | ✅ | ✅ | ❌ | ✅ | | | | | | | | | | |
+| CUDA graph | ✅ | ✅ | ✅ | ✅ | ✅ | | | | | | | | | |
+| <abbr title="Pooling Models">pooling</abbr> | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | | | | | | | | |
+| <abbr title="Encoder-Decoder Models">enc-dec</abbr> | ❌ | [❌](gh-issue:7366) | ❌ | [❌](gh-issue:7366) | ✅ | ✅ | ✅ | | | | | | | |
+| <abbr title="Logprobs">logP</abbr> | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | | | | | | |
+| <abbr title="Prompt Logprobs">prmpt logP</abbr> | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | | | | | |
+| <abbr title="Async Output Processing">async output</abbr> | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | | | | |
+| multi-step | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | | | |
+| <abbr title="Multimodal Inputs">mm</abbr> | ✅ | [🟠](gh-pr:8348) | [🟠](gh-pr:4194) | ❔ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❔ | ✅ | | |
+| best-of | ✅ | ✅ | ✅ | [❌](gh-issue:6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [❌](gh-issue:7968) | ✅ | ✅ | |
+| beam-search | ✅ | ✅ | ✅ | [❌](gh-issue:6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [❌](gh-issue:7968) | ❔ | ✅ | ✅ |

[](){ #feature-x-hardware }

@@ -59,10 +58,9 @@ th:not(:first-child) {
| Feature | Volta | Turing | Ampere | Ada | Hopper | CPU | AMD | TPU |
|-----------------------------------------------------------|---------------------|-----------|-----------|--------|------------|--------------------|--------|-----|
| [CP][chunked-prefill] | [❌](gh-issue:2729) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
-| [APC](automatic_prefix_caching.md) | [❌](gh-issue:3687) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
-| [LoRA](lora.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
-| <abbr title="Prompt Adapter">prmpt adptr</abbr> | ✅ | ✅ | ✅ | ✅ | ✅ | [❌](gh-issue:8475) | ✅ | ❌ |
-| [SD](spec_decode.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
+| [APC](automatic_prefix_caching.md) | [❌](gh-issue:3687) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
+| [LoRA](lora.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
+| [SD](spec_decode.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| CUDA graph | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ |
| <abbr title="Pooling Models">pooling</abbr> | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❔ | ❌ |
| <abbr title="Encoder-Decoder Models">enc-dec</abbr> | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
@@ -351,6 +351,11 @@ you can use the [official OpenAI Python client](https://github.com/openai/openai
Code example: <gh-file:examples/online_serving/openai_transcription_client.py>
<!-- TODO: api enforced limits + uploading audios -->

#### API Enforced Limits

Set the maximum audio file size (in MB) that vLLM will accept, via the
`VLLM_MAX_AUDIO_CLIP_FILESIZE_MB` environment variable. Default is 25 MB.
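For example, to raise the limit to 50 MB and exercise the endpoint (a minimal sketch; the Whisper model name, port, and `sample.wav` file are illustrative assumptions, not taken from this change):

```bash
# Assumes an OpenAI-compatible vLLM server; adjust model and audio file to your setup.
VLLM_MAX_AUDIO_CLIP_FILESIZE_MB=50 vllm serve openai/whisper-large-v3 &

# Once the server is up, send a multipart transcription request.
curl http://localhost:8000/v1/audio/transcriptions \
  -F model=openai/whisper-large-v3 \
  -F file=@sample.wav
```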
#### Extra Parameters

The following [sampling parameters][sampling-params] are supported.

@ -1,122 +1,27 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
This is a demo script showing how to use the
|
||||
PrithviGeospatialMAE model with vLLM
|
||||
This script is based on: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11/blob/main/inference.py # noqa
|
||||
|
||||
Target model weights: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11/resolve/main/Prithvi-EO-V2-300M-TL-Sen1Floods11.pt # noqa
|
||||
|
||||
The requirements for running this script are:
|
||||
- Installing [terratorch, albumentations, rasterio] in your python environment
|
||||
- downloading the model weights in a 'model' folder local to the script
|
||||
(temporary measure until the proper config.json file is uploaded to HF)
|
||||
- download an input example image (India_900498_S2Hand.tif) and place it in
|
||||
the same folder with the script (or specify with the --data_file argument)
|
||||
|
||||
Run the example:
|
||||
python prithvi_geospatial_mae.py
|
||||
|
||||
""" # noqa: E501
|
||||
|
||||
import argparse
|
||||
import datetime
|
||||
import os
|
||||
import re
|
||||
from typing import Union
|
||||
|
||||
import albumentations
|
||||
import numpy as np
|
||||
import rasterio
|
||||
import regex as re
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from terratorch.datamodules import Sen1Floods11NonGeoDataModule
|
||||
|
||||
from vllm import LLM
|
||||
|
||||
torch.set_default_dtype(torch.float16)
|
||||
|
||||
NO_DATA = -9999
|
||||
NO_DATA_FLOAT = 0.0001
|
||||
OFFSET = 0
|
||||
PERCENTILE = 99
|
||||
|
||||
model_config = """{
|
||||
"architectures": ["PrithviGeoSpatialMAE"],
|
||||
"num_classes": 0,
|
||||
"pretrained_cfg": {
|
||||
"task_args": {
|
||||
"task": "SemanticSegmentationTask",
|
||||
"model_factory": "EncoderDecoderFactory",
|
||||
"loss": "ce",
|
||||
"ignore_index": -1,
|
||||
"lr": 0.001,
|
||||
"freeze_backbone": false,
|
||||
"freeze_decoder": false,
|
||||
"plot_on_val": 10,
|
||||
"optimizer": "AdamW",
|
||||
"scheduler": "CosineAnnealingLR"
|
||||
},
|
||||
"model_args": {
|
||||
"backbone_pretrained": false,
|
||||
"backbone": "prithvi_eo_v2_300_tl",
|
||||
"decoder": "UperNetDecoder",
|
||||
"decoder_channels": 256,
|
||||
"decoder_scale_modules": true,
|
||||
"num_classes": 2,
|
||||
"rescale": true,
|
||||
"backbone_bands": [
|
||||
"BLUE",
|
||||
"GREEN",
|
||||
"RED",
|
||||
"NIR_NARROW",
|
||||
"SWIR_1",
|
||||
"SWIR_2"
|
||||
],
|
||||
"head_dropout": 0.1,
|
||||
"necks": [
|
||||
{
|
||||
"name": "SelectIndices",
|
||||
"indices": [
|
||||
5,
|
||||
11,
|
||||
17,
|
||||
23
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "ReshapeTokensToImage"
|
||||
}
|
||||
]
|
||||
},
|
||||
"optimizer_params" : {
|
||||
"lr": 5.0e-05,
|
||||
"betas": [0.9, 0.999],
|
||||
"eps": [1.0e-08],
|
||||
"weight_decay": 0.05,
|
||||
"amsgrad": false,
|
||||
"maximize": false,
|
||||
"capturable": false,
|
||||
"differentiable": false
|
||||
},
|
||||
"scheduler_params" : {
|
||||
"T_max": 50,
|
||||
"eta_min": 0,
|
||||
"last_epoch": -1,
|
||||
"verbose": "deprecated"
|
||||
}
|
||||
},
|
||||
|
||||
|
||||
"torch_dtype": "float32"
|
||||
}
|
||||
"""
|
||||
|
||||
# Temporarily creating the "config.json" for the model.
|
||||
# This is going to disappear once the correct config.json is available on HF
|
||||
with open(
|
||||
os.path.join(os.path.dirname(__file__), "./model/config.json"), "w"
|
||||
) as config_file:
|
||||
config_file.write(model_config)
|
||||
|
||||
datamodule_config = {
|
||||
"bands": ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"],
|
||||
"batch_size": 16,
|
||||
@ -138,28 +43,24 @@ datamodule_config = {
|
||||
|
||||
|
||||
class PrithviMAE:
|
||||
def __init__(self):
|
||||
print("Initializing PrithviMAE model")
|
||||
self.llm = LLM(
|
||||
model=os.path.join(os.path.dirname(__file__), "./model"),
|
||||
skip_tokenizer_init=True,
|
||||
dtype="float32",
|
||||
def __init__(self, model):
|
||||
self.model = LLM(
|
||||
model=model, skip_tokenizer_init=True, dtype="float16", enforce_eager=True
|
||||
)
|
||||
|
||||
def run(self, input_data, location_coords):
|
||||
print("################ Running inference on vLLM ##############")
|
||||
# merge the inputs into one data structure
|
||||
if input_data is not None and input_data.dtype == torch.float32:
|
||||
input_data = input_data.to(torch.float16)
|
||||
input_data = input_data[0]
|
||||
|
||||
mm_data = {
|
||||
"pixel_values": torch.empty(0) if input_data is None else input_data,
|
||||
"location_coords": torch.empty(0)
|
||||
if location_coords is None
|
||||
else location_coords,
|
||||
"pixel_values": input_data,
|
||||
"location_coords": location_coords,
|
||||
}
|
||||
|
||||
prompt = {"prompt_token_ids": [1], "multi_modal_data": mm_data}
|
||||
|
||||
outputs = self.llm.encode(prompt, use_tqdm=False)
|
||||
print("################ Inference done (it took seconds) ##############")
|
||||
outputs = self.model.encode(prompt, use_tqdm=False)
|
||||
|
||||
return outputs[0].outputs.data
|
||||
|
||||
@ -181,11 +82,12 @@ def process_channel_group(orig_img, channels):
|
||||
"""
|
||||
Args:
|
||||
orig_img: torch.Tensor representing original image (reference)
|
||||
with shape = (bands, H, W).
|
||||
with shape = (bands, H, W).
|
||||
channels: list of indices representing RGB channels.
|
||||
|
||||
Returns:
|
||||
torch.Tensor with shape (num_channels, height, width) for original image
|
||||
torch.Tensor with shape (num_channels, height, width)
|
||||
for original image
|
||||
"""
|
||||
|
||||
orig_img = orig_img[channels, ...]
|
||||
@ -260,10 +162,10 @@ def load_example(
|
||||
|
||||
Args:
|
||||
file_paths: list of file paths .
|
||||
mean: list containing mean values for each band in the images
|
||||
in *file_paths*.
|
||||
std: list containing std values for each band in the images
|
||||
in *file_paths*.
|
||||
mean: list containing mean values for each band in the
|
||||
images in *file_paths*.
|
||||
std: list containing std values for each band in the
|
||||
images in *file_paths*.
|
||||
|
||||
Returns:
|
||||
np.array containing created example
|
||||
@ -308,7 +210,7 @@ def load_example(
|
||||
print(f"Could not extract timestamp for {file} ({e})")
|
||||
|
||||
imgs = np.stack(imgs, axis=0) # num_frames, H, W, C
|
||||
imgs = np.moveaxis(imgs, -1, 0).astype("float32")
|
||||
imgs = np.moveaxis(imgs, -1, 0).astype("float32") # C, num_frames, H, W
|
||||
imgs = np.expand_dims(imgs, axis=0)  # add batch dim
|
||||
|
||||
return imgs, temporal_coords, location_coords, metas
|
||||
@ -332,8 +234,10 @@ def run_model(
|
||||
)
|
||||
|
||||
# Build sliding window
|
||||
|
||||
batch_size = 1
|
||||
batch = torch.tensor(input_data, device="cpu")
|
||||
# batch = torch.tensor(input_data, device="cpu")
|
||||
batch = torch.tensor(input_data)
|
||||
windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size)
|
||||
h1, w1 = windows.shape[3:5]
|
||||
windows = rearrange(
|
||||
@ -344,18 +248,16 @@ def run_model(
|
||||
num_batches = windows.shape[0] // batch_size if windows.shape[0] > batch_size else 1
|
||||
windows = torch.tensor_split(windows, num_batches, dim=0)
|
||||
|
||||
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
|
||||
if temporal_coords:
|
||||
temporal_coords = torch.tensor(temporal_coords, device=device).unsqueeze(0)
|
||||
temporal_coords = torch.tensor(temporal_coords).unsqueeze(0)
|
||||
else:
|
||||
temporal_coords = None
|
||||
if location_coords:
|
||||
location_coords = torch.tensor(location_coords[0], device=device).unsqueeze(0)
|
||||
location_coords = torch.tensor(location_coords[0]).unsqueeze(0)
|
||||
else:
|
||||
location_coords = None
|
||||
|
||||
# Run model
|
||||
# Run Prithvi-EO-V2-300M-TL-Sen1Floods11
|
||||
pred_imgs = []
|
||||
for x in windows:
|
||||
# Apply standardization
|
||||
@ -363,15 +265,7 @@ def run_model(
|
||||
x = datamodule.aug(x)["image"]
|
||||
|
||||
with torch.no_grad():
|
||||
x = x.to(device)
|
||||
pred = model.run(x, location_coords=location_coords)
|
||||
if lightning_model:
|
||||
pred_lightning = lightning_model(
|
||||
x, temporal_coords=temporal_coords, location_coords=location_coords
|
||||
)
|
||||
pred_lightning = pred_lightning.output.detach().cpu()
|
||||
if not torch.equal(pred, pred_lightning):
|
||||
print("Inference output is not equal")
|
||||
y_hat = pred.argmax(dim=1)
|
||||
|
||||
y_hat = torch.nn.functional.interpolate(
|
||||
@ -403,52 +297,18 @@ def run_model(
|
||||
return pred_imgs
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser("MAE run inference", add_help=False)
|
||||
|
||||
parser.add_argument(
|
||||
"--data_file",
|
||||
type=str,
|
||||
default="./India_900498_S2Hand.tif",
|
||||
help="Path to the file.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default="output",
|
||||
help="Path to the directory where to save outputs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input_indices",
|
||||
default=[1, 2, 3, 8, 11, 12],
|
||||
type=int,
|
||||
nargs="+",
|
||||
help="0-based indices of the six Prithvi channels to be selected from the "
|
||||
"input. By default selects [1,2,3,8,11,12] for S2L1C data.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rgb_outputs",
|
||||
action="store_true",
|
||||
help="If present, output files will only contain RGB channels. "
|
||||
"Otherwise, all bands will be saved.",
|
||||
)
|
||||
|
||||
|
||||
def main(
|
||||
data_file: str,
|
||||
model: str,
|
||||
output_dir: str,
|
||||
rgb_outputs: bool,
|
||||
input_indices: list[int] = None,
|
||||
):
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# Load model ---------------------------------------------------------------
|
||||
|
||||
model_obj = PrithviMAE()
|
||||
model_obj = PrithviMAE(model=model)
|
||||
datamodule = generate_datamodule()
|
||||
img_size = 256 # Size of Sen1Floods11
|
||||
|
||||
# Loading data -------------------------------------------------------------
|
||||
img_size = 512 # Size of Sen1Floods11
|
||||
|
||||
input_data, temporal_coords, location_coords, meta_data = load_example(
|
||||
file_paths=[data_file],
|
||||
@ -460,8 +320,6 @@ def main(
|
||||
if input_data.mean() > 1:
|
||||
input_data = input_data / 10000 # Convert to range 0-1
|
||||
|
||||
# Running model ------------------------------------------------------------
|
||||
|
||||
channels = [
|
||||
datamodule_config["bands"].index(b) for b in ["RED", "GREEN", "BLUE"]
|
||||
] # BGR -> RGB
|
||||
@ -469,7 +327,6 @@ def main(
|
||||
pred = run_model(
|
||||
input_data, temporal_coords, location_coords, model_obj, datamodule, img_size
|
||||
)
|
||||
|
||||
# Save pred
|
||||
meta_data.update(count=1, dtype="uint8", compress="lzw", nodata=0)
|
||||
pred_file = os.path.join(
|
||||
@ -487,6 +344,7 @@ def main(
|
||||
orig_img=torch.Tensor(input_data[0, :, 0, ...]),
|
||||
channels=channels,
|
||||
)
|
||||
rgb_orig = rgb_orig.to(torch.float32)
|
||||
|
||||
pred[pred == 0.0] = np.nan
|
||||
img_pred = rgb_orig * 0.7 + pred * 0.3
|
||||
@ -503,9 +361,10 @@ def main(
|
||||
|
||||
# Save image rgb
|
||||
if rgb_outputs:
|
||||
name_suffix = os.path.splitext(os.path.basename(data_file))[0]
|
||||
rgb_file = os.path.join(
|
||||
output_dir,
|
||||
f"original_rgb_{os.path.splitext(os.path.basename(data_file))[0]}.tiff",
|
||||
f"original_rgb_{name_suffix}.tiff",
|
||||
)
|
||||
save_geotiff(
|
||||
image=_convert_np_uint8(rgb_orig),
|
||||
@ -515,6 +374,42 @@ def main(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
parser = argparse.ArgumentParser("MAE run inference", add_help=False)
|
||||
|
||||
parser.add_argument(
|
||||
"--data_file",
|
||||
type=str,
|
||||
default="./India_900498_S2Hand.tif",
|
||||
help="Path to the file.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
|
||||
help="Path to a checkpoint file to load from.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default="output",
|
||||
help="Path to the directory where to save outputs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input_indices",
|
||||
default=[1, 2, 3, 8, 11, 12],
|
||||
type=int,
|
||||
nargs="+",
|
||||
help="""
|
||||
0-based indices of the six Prithvi channels to be selected from the input.
|
||||
By default selects [1,2,3,8,11,12] for S2L1C data.
|
||||
""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rgb_outputs",
|
||||
action="store_true",
|
||||
help="If present, output files will only contain RGB channels. "
|
||||
"Otherwise, all bands will be saved.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
main(**vars(args))
|
||||
|
@@ -93,6 +93,7 @@ ensure_python_library_installed() {
cleanup() {
    echo "Stopping everything…"
    trap - INT TERM          # prevent re-entrancy
    pkill -9 -f "disagg_proxy_p2p_nccl_xpyd.py"
    kill -- -$$              # negative PID == "this whole process-group"
    wait                     # reap children so we don't leave zombies
    exit 0
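For context, a cleanup handler like this only runs if the script registers it for the relevant signals; a minimal sketch of the usual wiring (the worker command shown here is illustrative, not taken from this change):

```bash
# Register the handler so Ctrl-C or a TERM signal tears everything down.
trap cleanup INT TERM

# Launch background workers, then block until they exit or a signal fires.
python3 disagg_proxy_p2p_nccl_xpyd.py &
wait
```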
@@ -72,7 +72,6 @@ line-length = 80
"vllm/core/**/*.py" = ["UP006", "UP035"]
"vllm/engine/**/*.py" = ["UP006", "UP035"]
"vllm/executor/**/*.py" = ["UP006", "UP035"]
"vllm/prompt_adapter/**/*.py" = ["UP006", "UP035"]
"vllm/worker/**/*.py" = ["UP006", "UP035"]
# Python 3.8 typing - skip utils for ROCm
"vllm/utils/__init__.py" = ["UP006", "UP035"]
@@ -33,7 +33,7 @@ pyzmq >= 25.0.0
msgspec
gguf >= 0.13.0
importlib_metadata; python_version < '3.10'
-mistral_common[opencv] >= 1.8.0
+mistral_common[image,audio] >= 1.8.2
opencv-python-headless >= 4.11.0 # required for video IO
pyyaml
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
@@ -23,7 +23,7 @@ jiwer # required for audio tests
timm # required for internvl test
transformers_stream_generator # required for qwen-vl test
matplotlib # required for qwen-vl test
-mistral_common[opencv] >= 1.8.0 # required for voxtral test
+mistral_common[image,audio] >= 1.8.2 # required for voxtral test
num2words # required for smolvlm test
opencv-python-headless >= 4.11.0 # required for video test
datamodel_code_generator # required for minicpm3 test
@@ -28,7 +28,7 @@ torchvision==0.22.1
transformers_stream_generator # required for qwen-vl test
mamba_ssm # required for plamo2 test
matplotlib # required for qwen-vl test
-mistral_common[opencv] >= 1.8.0 # required for voxtral test
+mistral_common[image,audio] >= 1.8.2 # required for voxtral test
num2words # required for smolvlm test
open_clip_torch==2.32.0 # Required for nemotron_vl test
opencv-python-headless >= 4.11.0 # required for video test

@@ -54,3 +54,4 @@ runai-model-streamer==0.11.0
runai-model-streamer-s3==0.11.0
fastsafetensors>=0.1.10
pydantic>=2.10 # 2.9 leads to error on python 3.10
+terratorch==1.1rc2 # required for PrithviMAE test
@ -6,6 +6,10 @@ accelerate==1.0.1
|
||||
# via
|
||||
# lm-eval
|
||||
# peft
|
||||
aenum==3.1.16
|
||||
# via lightly
|
||||
affine==2.4.0
|
||||
# via rasterio
|
||||
aiohappyeyeballs==2.4.3
|
||||
# via aiohttp
|
||||
aiohttp==3.10.11
|
||||
@ -21,8 +25,18 @@ aiosignal==1.3.1
|
||||
# via
|
||||
# aiohttp
|
||||
# ray
|
||||
albucore==0.0.16
|
||||
# via terratorch
|
||||
albumentations==1.4.6
|
||||
# via terratorch
|
||||
alembic==1.16.4
|
||||
# via mlflow
|
||||
annotated-types==0.7.0
|
||||
# via pydantic
|
||||
antlr4-python3-runtime==4.9.3
|
||||
# via
|
||||
# hydra-core
|
||||
# omegaconf
|
||||
anyio==4.6.2.post1
|
||||
# via
|
||||
# httpx
|
||||
@ -34,10 +48,12 @@ arrow==1.3.0
|
||||
attrs==24.2.0
|
||||
# via
|
||||
# aiohttp
|
||||
# fiona
|
||||
# hypothesis
|
||||
# jsonlines
|
||||
# jsonschema
|
||||
# pytest-subtests
|
||||
# rasterio
|
||||
# referencing
|
||||
audioread==3.0.1
|
||||
# via librosa
|
||||
@ -46,9 +62,13 @@ backoff==2.2.1
|
||||
# -r requirements/test.in
|
||||
# schemathesis
|
||||
bitsandbytes==0.46.1
|
||||
# via -r requirements/test.in
|
||||
# via
|
||||
# -r requirements/test.in
|
||||
# lightning
|
||||
black==24.10.0
|
||||
# via datamodel-code-generator
|
||||
blinker==1.9.0
|
||||
# via flask
|
||||
blobfile==3.0.0
|
||||
# via -r requirements/test.in
|
||||
bm25s==0.2.13
|
||||
@ -64,11 +84,18 @@ bounded-pool-executor==0.0.3
|
||||
buildkite-test-collector==0.1.9
|
||||
# via -r requirements/test.in
|
||||
cachetools==5.5.2
|
||||
# via google-auth
|
||||
# via
|
||||
# google-auth
|
||||
# mlflow-skinny
|
||||
certifi==2024.8.30
|
||||
# via
|
||||
# fiona
|
||||
# httpcore
|
||||
# httpx
|
||||
# lightly
|
||||
# pyogrio
|
||||
# pyproj
|
||||
# rasterio
|
||||
# requests
|
||||
cffi==1.17.1
|
||||
# via soundfile
|
||||
@ -79,11 +106,28 @@ charset-normalizer==3.4.0
|
||||
click==8.1.7
|
||||
# via
|
||||
# black
|
||||
# click-plugins
|
||||
# cligj
|
||||
# fiona
|
||||
# flask
|
||||
# jiwer
|
||||
# mlflow-skinny
|
||||
# nltk
|
||||
# rasterio
|
||||
# ray
|
||||
# schemathesis
|
||||
# typer
|
||||
# uvicorn
|
||||
click-plugins==1.1.1.2
|
||||
# via
|
||||
# fiona
|
||||
# rasterio
|
||||
cligj==0.7.2
|
||||
# via
|
||||
# fiona
|
||||
# rasterio
|
||||
cloudpickle==3.1.1
|
||||
# via mlflow-skinny
|
||||
colorama==0.4.6
|
||||
# via
|
||||
# sacrebleu
|
||||
@ -99,6 +143,8 @@ cupy-cuda12x==13.3.0
|
||||
# via ray
|
||||
cycler==0.12.1
|
||||
# via matplotlib
|
||||
databricks-sdk==0.59.0
|
||||
# via mlflow-skinny
|
||||
datamodel-code-generator==0.26.3
|
||||
# via -r requirements/test.in
|
||||
dataproperty==1.0.1
|
||||
@ -122,13 +168,21 @@ distlib==0.3.9
|
||||
# via virtualenv
|
||||
dnspython==2.7.0
|
||||
# via email-validator
|
||||
docker==7.1.0
|
||||
# via mlflow
|
||||
docopt==0.6.2
|
||||
# via num2words
|
||||
einops==0.8.0
|
||||
docstring-parser==0.17.0
|
||||
# via jsonargparse
|
||||
efficientnet-pytorch==0.7.1
|
||||
# via segmentation-models-pytorch
|
||||
einops==0.8.1
|
||||
# via
|
||||
# -r requirements/test.in
|
||||
# encodec
|
||||
# mamba-ssm
|
||||
# terratorch
|
||||
# torchgeo
|
||||
# vector-quantize-pytorch
|
||||
# vocos
|
||||
einx==0.3.0
|
||||
@ -141,6 +195,8 @@ eval-type-backport==0.2.2
|
||||
# via mteb
|
||||
evaluate==0.4.3
|
||||
# via lm-eval
|
||||
fastapi==0.116.1
|
||||
# via mlflow-skinny
|
||||
fastparquet==2024.11.0
|
||||
# via genai-perf
|
||||
fastrlock==0.8.2
|
||||
@ -156,6 +212,10 @@ filelock==3.16.1
|
||||
# torch
|
||||
# transformers
|
||||
# virtualenv
|
||||
fiona==1.10.1
|
||||
# via torchgeo
|
||||
flask==3.1.1
|
||||
# via mlflow
|
||||
fonttools==4.54.1
|
||||
# via matplotlib
|
||||
fqdn==1.5.1
|
||||
@ -173,6 +233,8 @@ fsspec==2024.9.0
|
||||
# evaluate
|
||||
# fastparquet
|
||||
# huggingface-hub
|
||||
# lightning
|
||||
# pytorch-lightning
|
||||
# torch
|
||||
ftfy==6.3.1
|
||||
# via open-clip-torch
|
||||
@ -180,18 +242,41 @@ genai-perf==0.0.8
|
||||
# via -r requirements/test.in
|
||||
genson==1.3.0
|
||||
# via datamodel-code-generator
|
||||
geopandas==1.0.1
|
||||
# via terratorch
|
||||
gitdb==4.0.12
|
||||
# via gitpython
|
||||
gitpython==3.1.44
|
||||
# via mlflow-skinny
|
||||
google-api-core==2.24.2
|
||||
# via opencensus
|
||||
google-auth==2.40.2
|
||||
# via google-api-core
|
||||
# via
|
||||
# databricks-sdk
|
||||
# google-api-core
|
||||
googleapis-common-protos==1.70.0
|
||||
# via google-api-core
|
||||
graphene==3.4.3
|
||||
# via mlflow
|
||||
graphql-core==3.2.6
|
||||
# via hypothesis-graphql
|
||||
# via
|
||||
# graphene
|
||||
# graphql-relay
|
||||
# hypothesis-graphql
|
||||
graphql-relay==3.2.0
|
||||
# via graphene
|
||||
greenlet==3.2.3
|
||||
# via sqlalchemy
|
||||
grpcio==1.71.0
|
||||
# via ray
|
||||
gunicorn==23.0.0
|
||||
# via mlflow
|
||||
h11==0.14.0
|
||||
# via httpcore
|
||||
# via
|
||||
# httpcore
|
||||
# uvicorn
|
||||
h5py==3.13.0
|
||||
# via terratorch
|
||||
harfile==0.3.0
|
||||
# via schemathesis
|
||||
hf-xet==1.1.3
|
||||
@ -204,7 +289,7 @@ httpx==0.27.2
|
||||
# via
|
||||
# -r requirements/test.in
|
||||
# schemathesis
|
||||
huggingface-hub==0.33.0
|
||||
huggingface-hub==0.33.1
|
||||
# via
|
||||
# -r requirements/test.in
|
||||
# accelerate
|
||||
@ -212,13 +297,19 @@ huggingface-hub==0.33.0
|
||||
# evaluate
|
||||
# open-clip-torch
|
||||
# peft
|
||||
# segmentation-models-pytorch
|
||||
# sentence-transformers
|
||||
# terratorch
|
||||
# timm
|
||||
# tokenizers
|
||||
# transformers
|
||||
# vocos
|
||||
humanize==4.11.0
|
||||
# via runai-model-streamer
|
||||
hydra-core==1.3.2
|
||||
# via
|
||||
# lightly
|
||||
# lightning
|
||||
hypothesis==6.131.0
|
||||
# via
|
||||
# hypothesis-graphql
|
||||
@ -236,6 +327,14 @@ idna==3.10
|
||||
# jsonschema
|
||||
# requests
|
||||
# yarl
|
||||
imageio==2.37.0
|
||||
# via scikit-image
|
||||
importlib-metadata==8.7.0
|
||||
# via
|
||||
# mlflow-skinny
|
||||
# opentelemetry-api
|
||||
importlib-resources==6.5.2
|
||||
# via typeshed-client
|
||||
inflect==5.6.2
|
||||
# via datamodel-code-generator
|
||||
iniconfig==2.0.0
|
||||
@ -244,9 +343,13 @@ isoduration==20.11.0
|
||||
# via jsonschema
|
||||
isort==5.13.2
|
||||
# via datamodel-code-generator
|
||||
itsdangerous==2.2.0
|
||||
# via flask
|
||||
jinja2==3.1.6
|
||||
# via
|
||||
# datamodel-code-generator
|
||||
# flask
|
||||
# mlflow
|
||||
# torch
|
||||
jiwer==3.0.5
|
||||
# via -r requirements/test.in
|
||||
@ -259,6 +362,10 @@ joblib==1.4.2
|
||||
# librosa
|
||||
# nltk
|
||||
# scikit-learn
|
||||
jsonargparse==4.35.0
|
||||
# via
|
||||
# lightning
|
||||
# terratorch
|
||||
jsonlines==4.0.0
|
||||
# via lm-eval
|
||||
jsonpointer==3.0.0
|
||||
@ -277,12 +384,33 @@ kaleido==0.2.1
|
||||
# via genai-perf
|
||||
kiwisolver==1.4.7
|
||||
# via matplotlib
|
||||
kornia==0.8.1
|
||||
# via torchgeo
|
||||
kornia-rs==0.1.9
|
||||
# via kornia
|
||||
lazy-loader==0.4
|
||||
# via librosa
|
||||
# via
|
||||
# librosa
|
||||
# scikit-image
|
||||
libnacl==2.1.0
|
||||
# via tensorizer
|
||||
librosa==0.10.2.post1
|
||||
# via -r requirements/test.in
|
||||
lightly==1.5.20
|
||||
# via
|
||||
# terratorch
|
||||
# torchgeo
|
||||
lightly-utils==0.0.2
|
||||
# via lightly
|
||||
lightning==2.5.1.post0
|
||||
# via
|
||||
# terratorch
|
||||
# torchgeo
|
||||
lightning-utilities==0.14.3
|
||||
# via
|
||||
# lightning
|
||||
# pytorch-lightning
|
||||
# torchmetrics
|
||||
llvmlite==0.44.0
|
||||
# via numba
|
||||
lm-eval==0.4.8
|
||||
@ -291,16 +419,27 @@ lxml==5.3.0
|
||||
# via
|
||||
# blobfile
|
||||
# sacrebleu
|
||||
mako==1.3.10
|
||||
# via alembic
|
||||
mamba-ssm==2.2.4
|
||||
# via -r requirements/test.in
|
||||
markdown==3.8.2
|
||||
# via mlflow
|
||||
markdown-it-py==3.0.0
|
||||
# via rich
|
||||
markupsafe==3.0.1
|
||||
# via
|
||||
# flask
|
||||
# jinja2
|
||||
# mako
|
||||
# werkzeug
|
||||
matplotlib==3.9.2
|
||||
# via -r requirements/test.in
|
||||
# via
|
||||
# -r requirements/test.in
|
||||
# lightning
|
||||
# mlflow
|
||||
# pycocotools
|
||||
# torchgeo
|
||||
mbstrdecoder==1.1.3
|
||||
# via
|
||||
# dataproperty
|
||||
@ -308,8 +447,12 @@ mbstrdecoder==1.1.3
|
||||
# typepy
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
mistral-common==1.8.0
|
||||
mistral-common==1.8.2
|
||||
# via -r requirements/test.in
|
||||
mlflow==2.22.0
|
||||
# via terratorch
|
||||
mlflow-skinny==2.22.0
|
||||
# via mlflow
|
||||
more-itertools==10.5.0
|
||||
# via lm-eval
|
||||
mpmath==1.3.0
|
||||
@ -328,10 +471,14 @@ multiprocess==0.70.16
|
||||
# via
|
||||
# datasets
|
||||
# evaluate
|
||||
munch==4.0.0
|
||||
# via pretrainedmodels
|
||||
mypy-extensions==1.0.0
|
||||
# via black
|
||||
networkx==3.2.1
|
||||
# via torch
|
||||
# via
|
||||
# scikit-image
|
||||
# torch
|
||||
ninja==1.11.1.3
|
||||
# via mamba-ssm
|
||||
nltk==3.9.1
|
||||
@ -348,6 +495,8 @@ numpy==1.26.4
|
||||
# via
|
||||
# -r requirements/test.in
|
||||
# accelerate
|
||||
# albucore
|
||||
# albumentations
|
||||
# bitsandbytes
|
||||
# bm25s
|
||||
# contourpy
|
||||
@ -358,9 +507,15 @@ numpy==1.26.4
|
||||
# evaluate
|
||||
# fastparquet
|
||||
# genai-perf
|
||||
# geopandas
|
||||
# h5py
|
||||
# imageio
|
||||
# librosa
|
||||
# lightly
|
||||
# lightly-utils
|
||||
# matplotlib
|
||||
# mistral-common
|
||||
# mlflow
|
||||
# mteb
|
||||
# numba
|
||||
# numexpr
|
||||
@ -368,18 +523,30 @@ numpy==1.26.4
|
||||
# pandas
|
||||
# patsy
|
||||
# peft
|
||||
# pycocotools
|
||||
# pyogrio
|
||||
# rasterio
|
||||
# rioxarray
|
||||
# rouge-score
|
||||
# runai-model-streamer
|
||||
# sacrebleu
|
||||
# scikit-image
|
||||
# scikit-learn
|
||||
# scipy
|
||||
# segmentation-models-pytorch
|
||||
# shapely
|
||||
# soxr
|
||||
# statsmodels
|
||||
# tensorboardx
|
||||
# tensorizer
|
||||
# tifffile
|
||||
# torchgeo
|
||||
# torchmetrics
|
||||
# torchvision
|
||||
# transformers
|
||||
# tritonclient
|
||||
# vocos
|
||||
# xarray
|
||||
nvidia-cublas-cu12==12.8.3.14
|
||||
# via
|
||||
# nvidia-cudnn-cu12
|
||||
@ -417,6 +584,10 @@ nvidia-nvjitlink-cu12==12.8.61
|
||||
# torch
|
||||
nvidia-nvtx-cu12==12.8.55
|
||||
# via torch
|
||||
omegaconf==2.3.0
|
||||
# via
|
||||
# hydra-core
|
||||
# lightning
|
||||
open-clip-torch==2.32.0
|
||||
# via -r requirements/test.in
|
||||
opencensus==0.11.4
|
||||
@ -426,7 +597,18 @@ opencensus-context==0.1.3
|
||||
opencv-python-headless==4.11.0.86
|
||||
# via
|
||||
# -r requirements/test.in
|
||||
# albucore
|
||||
# albumentations
|
||||
# mistral-common
|
||||
opentelemetry-api==1.35.0
|
||||
# via
|
||||
# mlflow-skinny
|
||||
# opentelemetry-sdk
|
||||
# opentelemetry-semantic-conventions
|
||||
opentelemetry-sdk==1.35.0
|
||||
# via mlflow-skinny
|
||||
opentelemetry-semantic-conventions==0.56b0
|
||||
# via opentelemetry-sdk
|
||||
packaging==24.2
|
||||
# via
|
||||
# accelerate
|
||||
@ -435,26 +617,44 @@ packaging==24.2
|
||||
# datasets
|
||||
# evaluate
|
||||
# fastparquet
|
||||
# geopandas
|
||||
# gunicorn
|
||||
# huggingface-hub
|
||||
# hydra-core
|
||||
# kornia
|
||||
# lazy-loader
|
||||
# lightning
|
||||
# lightning-utilities
|
||||
# mamba-ssm
|
||||
# matplotlib
|
||||
# mlflow-skinny
|
||||
# peft
|
||||
# plotly
|
||||
# pooch
|
||||
# pyogrio
|
||||
# pytest
|
||||
# pytest-rerunfailures
|
||||
# pytorch-lightning
|
||||
# ray
|
||||
# rioxarray
|
||||
# scikit-image
|
||||
# statsmodels
|
||||
# tensorboardx
|
||||
# torchmetrics
|
||||
# transformers
|
||||
# typepy
|
||||
# xarray
|
||||
pandas==2.2.3
|
||||
# via
|
||||
# datasets
|
||||
# evaluate
|
||||
# fastparquet
|
||||
# genai-perf
|
||||
# geopandas
|
||||
# mlflow
|
||||
# statsmodels
|
||||
# torchgeo
|
||||
# xarray
|
||||
pathspec==0.12.1
|
||||
# via black
|
||||
pathvalidate==3.2.1
|
||||
@ -468,9 +668,14 @@ peft==0.13.2
|
||||
pillow==10.4.0
|
||||
# via
|
||||
# genai-perf
|
||||
# imageio
|
||||
# lightly-utils
|
||||
# matplotlib
|
||||
# mistral-common
|
||||
# scikit-image
|
||||
# segmentation-models-pytorch
|
||||
# sentence-transformers
|
||||
# torchgeo
|
||||
# torchvision
|
||||
platformdirs==4.3.6
|
||||
# via
|
||||
@ -489,6 +694,8 @@ portalocker==2.10.1
|
||||
# via sacrebleu
|
||||
pqdm==0.2.0
|
||||
# via -r requirements/test.in
|
||||
pretrainedmodels==0.7.4
|
||||
# via segmentation-models-pytorch
|
||||
prometheus-client==0.22.0
|
||||
# via ray
|
||||
propcache==0.2.0
|
||||
@ -499,8 +706,10 @@ protobuf==5.28.3
|
||||
# via
|
||||
# google-api-core
|
||||
# googleapis-common-protos
|
||||
# mlflow-skinny
|
||||
# proto-plus
|
||||
# ray
|
||||
# tensorboardx
|
||||
# tensorizer
|
||||
psutil==6.1.0
|
||||
# via
|
||||
@ -515,6 +724,7 @@ pyarrow==18.0.0
|
||||
# via
|
||||
# datasets
|
||||
# genai-perf
|
||||
# mlflow
|
||||
pyasn1==0.6.1
|
||||
# via
|
||||
# pyasn1-modules
|
||||
@ -523,6 +733,8 @@ pyasn1-modules==0.4.2
|
||||
# via google-auth
|
||||
pybind11==2.13.6
|
||||
# via lm-eval
|
||||
pycocotools==2.0.8
|
||||
# via terratorch
|
||||
pycountry==24.6.1
|
||||
# via pydantic-extra-types
|
||||
pycparser==2.22
|
||||
@ -532,8 +744,12 @@ pycryptodomex==3.22.0
|
||||
pydantic==2.11.5
|
||||
# via
|
||||
# -r requirements/test.in
|
||||
# albumentations
|
||||
# datamodel-code-generator
|
||||
# fastapi
|
||||
# lightly
|
||||
# mistral-common
|
||||
# mlflow-skinny
|
||||
# mteb
|
||||
# pydantic-extra-types
|
||||
# ray
|
||||
@ -543,15 +759,24 @@ pydantic-extra-types==2.10.5
|
||||
# via mistral-common
|
||||
pygments==2.18.0
|
||||
# via rich
|
||||
pyogrio==0.11.0
|
||||
# via geopandas
|
||||
pyparsing==3.2.0
|
||||
# via matplotlib
|
||||
# via
|
||||
# matplotlib
|
||||
# rasterio
|
||||
pyproj==3.7.1
|
||||
# via
|
||||
# geopandas
|
||||
# rioxarray
|
||||
# torchgeo
|
||||
pyrate-limiter==3.7.0
|
||||
# via schemathesis
|
||||
pystemmer==3.0.0
|
||||
# via mteb
|
||||
pytablewriter==1.2.0
|
||||
# via lm-eval
|
||||
pytest==8.3.3
|
||||
pytest==8.3.5
|
||||
# via
|
||||
# -r requirements/test.in
|
||||
# buildkite-test-collector
|
||||
@ -564,6 +789,7 @@ pytest==8.3.3
|
||||
# pytest-subtests
|
||||
# pytest-timeout
|
||||
# schemathesis
|
||||
# terratorch
|
||||
pytest-asyncio==0.24.0
|
||||
# via -r requirements/test.in
|
||||
pytest-forked==1.6.0
|
||||
@ -578,15 +804,23 @@ pytest-subtests==0.14.1
|
||||
# via schemathesis
|
||||
pytest-timeout==2.3.1
|
||||
# via -r requirements/test.in
|
||||
python-box==7.3.2
|
||||
# via terratorch
|
||||
python-dateutil==2.9.0.post0
|
||||
# via
|
||||
# arrow
|
||||
# botocore
|
||||
# graphene
|
||||
# lightly
|
||||
# matplotlib
|
||||
# pandas
|
||||
# typepy
|
||||
python-rapidjson==1.20
|
||||
# via tritonclient
|
||||
pytorch-lightning==2.5.2
|
||||
# via
|
||||
# lightly
|
||||
# lightning
|
||||
pytrec-eval-terrier==0.5.7
|
||||
# via mteb
|
||||
pytz==2024.2
|
||||
@ -596,11 +830,17 @@ pytz==2024.2
|
||||
pyyaml==6.0.2
|
||||
# via
|
||||
# accelerate
|
||||
# albumentations
|
||||
# datamodel-code-generator
|
||||
# datasets
|
||||
# genai-perf
|
||||
# huggingface-hub
|
||||
# jsonargparse
|
||||
# lightning
|
||||
# mlflow-skinny
|
||||
# omegaconf
|
||||
# peft
|
||||
# pytorch-lightning
|
||||
# ray
|
||||
# responses
|
||||
# schemathesis
|
||||
@ -609,6 +849,11 @@ pyyaml==6.0.2
|
||||
# vocos
|
||||
rapidfuzz==3.12.1
|
||||
# via jiwer
|
||||
rasterio==1.4.3
|
||||
# via
|
||||
# rioxarray
|
||||
# terratorch
|
||||
# torchgeo
|
||||
ray==2.43.0
|
||||
# via -r requirements/test.in
|
||||
redis==5.2.0
|
||||
@ -627,12 +872,16 @@ regex==2024.9.11
|
||||
requests==2.32.3
|
||||
# via
|
||||
# buildkite-test-collector
|
||||
# databricks-sdk
|
||||
# datasets
|
||||
# docker
|
||||
# evaluate
|
||||
# google-api-core
|
||||
# huggingface-hub
|
||||
# lightly
|
||||
# lm-eval
|
||||
# mistral-common
|
||||
# mlflow-skinny
|
||||
# mteb
|
||||
# pooch
|
||||
# ray
|
||||
@ -650,8 +899,11 @@ rfc3987==1.3.8
|
||||
rich==13.9.4
|
||||
# via
|
||||
# genai-perf
|
||||
# lightning
|
||||
# mteb
|
||||
# typer
|
||||
rioxarray==0.19.0
|
||||
# via terratorch
|
||||
rouge-score==0.1.2
|
||||
# via lm-eval
|
||||
rpds-py==0.20.1
|
||||
@ -660,6 +912,8 @@ rpds-py==0.20.1
|
||||
# referencing
|
||||
rsa==4.9.1
|
||||
# via google-auth
|
||||
rtree==1.4.0
|
||||
# via torchgeo
|
||||
runai-model-streamer==0.11.0
|
||||
# via -r requirements/test.in
|
||||
runai-model-streamer-s3==0.11.0
|
||||
@ -677,21 +931,32 @@ safetensors==0.4.5
|
||||
# transformers
|
||||
schemathesis==3.39.15
|
||||
# via -r requirements/test.in
|
||||
scikit-image==0.25.2
|
||||
# via albumentations
|
||||
scikit-learn==1.5.2
|
||||
# via
|
||||
# albumentations
|
||||
# librosa
|
||||
# lm-eval
|
||||
# mlflow
|
||||
# mteb
|
||||
# sentence-transformers
|
||||
scipy==1.13.1
|
||||
# via
|
||||
# albumentations
|
||||
# bm25s
|
||||
# librosa
|
||||
# mlflow
|
||||
# mteb
|
||||
# scikit-image
|
||||
# scikit-learn
|
||||
# sentence-transformers
|
||||
# statsmodels
|
||||
# vocos
|
||||
segmentation-models-pytorch==0.4.0
|
||||
# via
|
||||
# terratorch
|
||||
# torchgeo
|
||||
sentence-transformers==3.2.1
|
||||
# via
|
||||
# -r requirements/test.in
|
||||
@ -700,21 +965,30 @@ sentencepiece==0.2.0
|
||||
# via mistral-common
|
||||
setuptools==77.0.3
|
||||
# via
|
||||
# lightning-utilities
|
||||
# mamba-ssm
|
||||
# pytablewriter
|
||||
# torch
|
||||
# triton
|
||||
shapely==2.1.1
|
||||
# via
|
||||
# geopandas
|
||||
# torchgeo
|
||||
shellingham==1.5.4
|
||||
# via typer
|
||||
six==1.16.0
|
||||
# via
|
||||
# junit-xml
|
||||
# lightly
|
||||
# opencensus
|
||||
# python-dateutil
|
||||
# rfc3339-validator
|
||||
# rouge-score
|
||||
# segmentation-models-pytorch
|
||||
smart-open==7.1.0
|
||||
# via ray
|
||||
smmap==5.0.2
|
||||
# via gitdb
|
||||
sniffio==1.3.1
|
||||
# via
|
||||
# anyio
|
||||
@ -725,12 +999,22 @@ soundfile==0.12.1
|
||||
# via
|
||||
# -r requirements/test.in
|
||||
# librosa
|
||||
# mistral-common
|
||||
soxr==0.5.0.post1
|
||||
# via librosa
|
||||
# via
|
||||
# librosa
|
||||
# mistral-common
|
||||
sqlalchemy==2.0.41
|
||||
# via
|
||||
# alembic
|
||||
# mlflow
|
||||
sqlitedict==2.1.0
|
||||
# via lm-eval
|
||||
sqlparse==0.5.3
|
||||
# via mlflow-skinny
|
||||
starlette==0.46.2
|
||||
# via
|
||||
# fastapi
|
||||
# schemathesis
|
||||
# starlette-testclient
|
||||
starlette-testclient==0.4.1
|
||||
@ -751,18 +1035,29 @@ tenacity==9.0.0
|
||||
# via
|
||||
# lm-eval
|
||||
# plotly
|
||||
tensorboardx==2.6.4
|
||||
# via lightning
|
||||
tensorizer==2.10.1
|
||||
# via -r requirements/test.in
|
||||
terratorch==1.1rc2
|
||||
# via -r requirements/test.in
|
||||
threadpoolctl==3.5.0
|
||||
# via scikit-learn
|
||||
tifffile==2025.3.30
|
||||
# via
|
||||
# scikit-image
|
||||
# terratorch
|
||||
tiktoken==0.7.0
|
||||
# via
|
||||
# lm-eval
|
||||
# mistral-common
|
||||
timm==1.0.11
|
||||
timm==1.0.15
|
||||
# via
|
||||
# -r requirements/test.in
|
||||
# open-clip-torch
|
||||
# segmentation-models-pytorch
|
||||
# terratorch
|
||||
# torchgeo
|
||||
tokenizers==0.21.1
|
||||
# via
|
||||
# -r requirements/test.in
|
||||
@ -776,18 +1071,28 @@ torch==2.7.1+cu128
|
||||
# -r requirements/test.in
|
||||
# accelerate
|
||||
# bitsandbytes
|
||||
# efficientnet-pytorch
|
||||
# encodec
|
||||
# fastsafetensors
|
||||
# kornia
|
||||
# lightly
|
||||
# lightning
|
||||
# lm-eval
|
||||
# mamba-ssm
|
||||
# mteb
|
||||
# open-clip-torch
|
||||
# peft
|
||||
# pretrainedmodels
|
||||
# pytorch-lightning
|
||||
# runai-model-streamer
|
||||
# segmentation-models-pytorch
|
||||
# sentence-transformers
|
||||
# tensorizer
|
||||
# terratorch
|
||||
# timm
|
||||
# torchaudio
|
||||
# torchgeo
|
||||
# torchmetrics
|
||||
# torchvision
|
||||
# vector-quantize-pytorch
|
||||
# vocos
|
||||
@ -796,22 +1101,40 @@ torchaudio==2.7.1+cu128
|
||||
# -r requirements/test.in
|
||||
# encodec
|
||||
# vocos
|
||||
torchgeo==0.7.0
|
||||
# via terratorch
|
||||
torchmetrics==1.7.4
|
||||
# via
|
||||
# lightning
|
||||
# pytorch-lightning
|
||||
# terratorch
|
||||
# torchgeo
|
||||
torchvision==0.22.1+cu128
|
||||
# via
|
||||
# -r requirements/test.in
|
||||
# lightly
|
||||
# open-clip-torch
|
||||
# pretrainedmodels
|
||||
# segmentation-models-pytorch
|
||||
# terratorch
|
||||
# timm
|
||||
# torchgeo
|
||||
tqdm==4.66.6
|
||||
# via
|
||||
# datasets
|
||||
# evaluate
|
||||
# huggingface-hub
|
||||
# lightly
|
||||
# lightning
|
||||
# lm-eval
|
||||
# mteb
|
||||
# nltk
|
||||
# open-clip-torch
|
||||
# peft
|
||||
# pqdm
|
||||
# pretrainedmodels
|
||||
# pytorch-lightning
|
||||
# segmentation-models-pytorch
|
||||
# sentence-transformers
|
||||
# tqdm-multiprocess
|
||||
# transformers
|
||||
@ -843,18 +1166,34 @@ typer==0.15.2
|
||||
# via fastsafetensors
|
||||
types-python-dateutil==2.9.0.20241206
|
||||
# via arrow
|
||||
typeshed-client==2.8.2
|
||||
# via jsonargparse
|
||||
typing-extensions==4.12.2
|
||||
# via
|
||||
# albumentations
|
||||
# alembic
|
||||
# fastapi
|
||||
# graphene
|
||||
# huggingface-hub
|
||||
# librosa
|
||||
# lightning
|
||||
# lightning-utilities
|
||||
# mistral-common
|
||||
# mlflow-skinny
|
||||
# mteb
|
||||
# opentelemetry-api
|
||||
# opentelemetry-sdk
|
||||
# opentelemetry-semantic-conventions
|
||||
# pqdm
|
||||
# pydantic
|
||||
# pydantic-core
|
||||
# pydantic-extra-types
|
||||
# pytorch-lightning
|
||||
# sqlalchemy
|
||||
# torch
|
||||
# torchgeo
|
||||
# typer
|
||||
# typeshed-client
|
||||
# typing-inspection
|
||||
typing-inspection==0.4.1
|
||||
# via pydantic
|
||||
@ -866,9 +1205,13 @@ urllib3==2.2.3
|
||||
# via
|
||||
# blobfile
|
||||
# botocore
|
||||
# docker
|
||||
# lightly
|
||||
# requests
|
||||
# responses
|
||||
# tritonclient
|
||||
uvicorn==0.35.0
|
||||
# via mlflow-skinny
|
||||
vector-quantize-pytorch==1.21.2
|
||||
# via -r requirements/test.in
|
||||
virtualenv==20.31.2
|
||||
@ -880,11 +1223,15 @@ wcwidth==0.2.13
|
||||
webcolors==24.11.1
|
||||
# via jsonschema
|
||||
werkzeug==3.1.3
|
||||
# via schemathesis
|
||||
# via
|
||||
# flask
|
||||
# schemathesis
|
||||
word2number==1.1
|
||||
# via lm-eval
|
||||
wrapt==1.17.2
|
||||
# via smart-open
|
||||
xarray==2025.7.1
|
||||
# via rioxarray
|
||||
xxhash==3.5.0
|
||||
# via
|
||||
# datasets
|
||||
@ -893,5 +1240,7 @@ yarl==1.17.1
|
||||
# via
|
||||
# aiohttp
|
||||
# schemathesis
|
||||
zipp==3.23.0
|
||||
# via importlib-metadata
|
||||
zstandard==0.23.0
|
||||
# via lm-eval
|
||||
|
tests/compile/piecewise/test_multiple_graphs.py (new file, 350 lines)
@@ -0,0 +1,350 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Test (piecewise) compilation with a simple model where multiple submodules
|
||||
are compiled and graph captured separately.
|
||||
"""
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.library import Library
|
||||
|
||||
from vllm.compilation.backends import set_model_tag
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.decorators import (ignore_torch_compile,
|
||||
support_torch_compile)
|
||||
from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
|
||||
set_current_vllm_config)
|
||||
from vllm.envs import VLLM_USE_V1
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
# create a library to hold the custom op
|
||||
silly_lib = Library("silly", "FRAGMENT") # noqa
|
||||
|
||||
BATCH_SIZE = 32
|
||||
MLP_SIZE = 128
|
||||
HIDDEN_SIZE = 1024
|
||||
RANDOM_SEED = 0
|
||||
|
||||
|
||||
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
out: torch.Tensor) -> None:
|
||||
out.copy_(q)
|
||||
out += k
|
||||
out += v
|
||||
|
||||
|
||||
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
out: torch.Tensor) -> None:
|
||||
return
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="attention",
|
||||
op_func=silly_attention,
|
||||
mutates_args=["out"],
|
||||
fake_impl=silly_attention_fake,
|
||||
target_lib=silly_lib,
|
||||
)
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class ParentModel(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = '',
|
||||
**kwargs) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
|
||||
def __init__(self, mlp_size: int, hidden_size: int) -> None:
|
||||
super().__init__()
|
||||
self.pre_attn = nn.Linear(mlp_size, hidden_size, bias=False)
|
||||
self.post_attn = nn.Linear(hidden_size, mlp_size, bias=False)
|
||||
self.rms_norm_weight = nn.Parameter(torch.ones(hidden_size))
|
||||
|
||||
# Initialize to same weights for testing
|
||||
nn.init.xavier_normal_(
|
||||
self.pre_attn.weight.data,
|
||||
generator=torch.Generator().manual_seed(RANDOM_SEED),
|
||||
gain=0.001)
|
||||
nn.init.xavier_normal_(
|
||||
self.post_attn.weight.data,
|
||||
generator=torch.Generator().manual_seed(RANDOM_SEED),
|
||||
gain=0.001)
|
||||
|
||||
def rms_norm_ref(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x_f32 = x.float()
|
||||
return (x_f32 * torch.rsqrt(
|
||||
torch.mean(x_f32.square(), dim=-1, keepdim=True) + 1e-6) *
|
||||
self.rms_norm_weight).to(x.dtype)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.pre_attn(x)
|
||||
x = self.rms_norm_ref(x)
|
||||
attn_output = torch.empty_like(x)
|
||||
torch.ops.silly.attention(x, x, x, attn_output)
|
||||
x = attn_output
|
||||
x = self.rms_norm_ref(x)
|
||||
x = self.post_attn(x)
|
||||
return x
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class CompiledAttention(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
mlp_size: int,
|
||||
hidden_size: int,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = '',
|
||||
**kwargs) -> None:
|
||||
super().__init__()
|
||||
self.attn = Attention(mlp_size, hidden_size)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.attn(x)
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class CompiledAttentionTwo(CompiledAttention):
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.attn(x) + x
|
||||
|
||||
|
||||
@ignore_torch_compile
|
||||
class SimpleModelWithTwoGraphs(ParentModel):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
mlp_size: int,
|
||||
hidden_size: int,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = '',
|
||||
**kwargs) -> None:
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
# Test will fail without set_model_tag here with error:
|
||||
# "ValueError: too many values to unpack (expected 3)"
|
||||
# This is because CompiledAttention and CompiledAttentionTwo
|
||||
# have different implementations but the same torch.compile
# cache dir will be used as the default prefix is 'model_tag'
|
||||
with set_model_tag("attn_one"):
|
||||
self.attn_one = CompiledAttention(
|
||||
mlp_size=mlp_size,
|
||||
hidden_size=hidden_size,
|
||||
vllm_config=vllm_config,
|
||||
prefix=f"{prefix}.attn_one",
|
||||
)
|
||||
with set_model_tag("attn_two"):
|
||||
self.attn_two = CompiledAttentionTwo(
            mlp_size=mlp_size,
            hidden_size=hidden_size,
            vllm_config=vllm_config,
            prefix=f"{prefix}.attn_two",
        )

        self.hidden_states = torch.zeros((BATCH_SIZE, MLP_SIZE)).cuda()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bsz = x.shape[0]
        # CUDAGraph expects same tensor addresses for each run
        self.hidden_states[:bsz].copy_(x)
        x = self.attn_one(self.hidden_states[:bsz])
        self.hidden_states[:bsz].copy_(x)
        x = self.attn_two(self.hidden_states[:bsz])
        return x


def test_ignore_torch_compile_decorator():
    assert VLLM_USE_V1

    # piecewise
    vllm_config = VllmConfig(compilation_config=CompilationConfig(
        level=CompilationLevel.PIECEWISE,
        use_cudagraph=True,
        splitting_ops=["silly.attention"],
        cudagraph_capture_sizes=[1, 2],
    ))

    @support_torch_compile
    class A(nn.Module):

        def __init__(self,
                     *,
                     vllm_config: VllmConfig,
                     prefix: str = '',
                     **kwargs) -> None:
            super().__init__()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = x + x
            attn_output = torch.empty_like(x)
            torch.ops.silly.attention(x, x, x, attn_output)
            x = attn_output
            x = x * 3
            return x

    @ignore_torch_compile
    class B(A):
        ...

    @support_torch_compile
    class C(B):
        ...

    with set_current_vllm_config(vllm_config):
        mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda()

    # A has support_torch_compile
    with compilation_counter.expect(
            num_graphs_seen=1,
            num_piecewise_graphs_seen=3,
            num_piecewise_capturable_graphs_seen=2,
            num_backend_compilations=2,
            num_cudagraph_captured=4,
            # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
    ), set_forward_context({}, vllm_config=vllm_config):
        # first run is for compile
        mod_A(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
        # run cudagraph captured sizes
        mod_A(torch.randn(2, MLP_SIZE).cuda())
        mod_A(torch.randn(1, MLP_SIZE).cuda())

    with set_current_vllm_config(vllm_config):
        mod_B = B(vllm_config=vllm_config, prefix='').eval().cuda()

    # B's ignore_torch_compile should override A's support_torch_compile
    with compilation_counter.expect(
            num_graphs_seen=0,
            num_piecewise_graphs_seen=0,
            num_piecewise_capturable_graphs_seen=0,
            num_backend_compilations=0,
            num_cudagraph_captured=0,
    ), set_forward_context({}, vllm_config=vllm_config):
        mod_B(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
        mod_B(torch.randn(2, MLP_SIZE).cuda())
        mod_B(torch.randn(1, MLP_SIZE).cuda())

    with set_current_vllm_config(vllm_config):
        mod_C = C(vllm_config=vllm_config, prefix='').eval().cuda()

    # C's support_torch_compile should override B's ignore_torch_compile
    with compilation_counter.expect(
            num_graphs_seen=1,
            num_piecewise_graphs_seen=3,
            num_piecewise_capturable_graphs_seen=2,
            num_backend_compilations=2,
            num_cudagraph_captured=4,
            # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
    ), set_forward_context({}, vllm_config=vllm_config):
        mod_C(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
        mod_C(torch.randn(2, MLP_SIZE).cuda())
        mod_C(torch.randn(1, MLP_SIZE).cuda())


@torch.inference_mode
def run_model(vllm_config, model: nn.Module, inputs: torch.Tensor):
    with set_forward_context({}, vllm_config=vllm_config):
        # First run is for compile
        model(inputs)

        # Run CUDAGraph captured sizes
        model(inputs[:2])
        model(inputs[:1])

        output = model(inputs[:2])

    output = output.cpu()
    return output.cpu()


def test_multi_graph_piecewise_compile_outputs_equal():
    outputs = []

    # piecewise compile
    vllm_config = VllmConfig(compilation_config=CompilationConfig(
        level=CompilationLevel.PIECEWISE,
        use_cudagraph=True,
        splitting_ops=["silly.attention"],
        cudagraph_capture_sizes=[1, 2],
    ))

    with set_current_vllm_config(vllm_config):
        model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE,
                                         hidden_size=HIDDEN_SIZE,
                                         vllm_config=vllm_config,
                                         prefix='').eval().cuda()

    # Pre-allocate memory for CUDAGraph which expects
    # static tensor addresses
    inputs = torch.randn(BATCH_SIZE, MLP_SIZE).cuda()

    with compilation_counter.expect(
            num_graphs_seen=2,  # two graphs for the model
            num_piecewise_graphs_seen=6,
            # attn_one, attn_two each has 3 piecewise graphs
            # (pre attn, post attn, silly_attention) each
            num_piecewise_capturable_graphs_seen=4,
            # attn_one, attn_two has pre attn and post attn each, total=4
            num_backend_compilations=4,  # num_piecewise_capturable_graphs_seen
            num_cudagraph_captured=8,
            # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
    ):
        outputs.append(run_model(vllm_config, model, inputs))

    # no compile or cudagraph
    vllm_config = VllmConfig(compilation_config=CompilationConfig(
        level=CompilationLevel.NO_COMPILATION, ))

    with set_current_vllm_config(vllm_config):
        model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE,
                                         hidden_size=HIDDEN_SIZE,
                                         vllm_config=vllm_config,
                                         prefix='').eval().cuda()

    with compilation_counter.expect(
            num_graphs_seen=0,
            num_piecewise_graphs_seen=0,
            num_piecewise_capturable_graphs_seen=0,
            num_backend_compilations=0,
            num_cudagraph_captured=0,
    ):
        outputs.append(run_model(vllm_config, model, inputs))

    # piecewise compile without CUDA graph
    vllm_config = VllmConfig(compilation_config=CompilationConfig(
        level=CompilationLevel.PIECEWISE,
        use_cudagraph=False,
        splitting_ops=["silly.attention"],
    ))

    with set_current_vllm_config(vllm_config):
        model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE,
                                         hidden_size=HIDDEN_SIZE,
                                         vllm_config=vllm_config,
                                         prefix='').eval().cuda()

    with compilation_counter.expect(
            num_graphs_seen=2,
            num_piecewise_graphs_seen=6,
            num_piecewise_capturable_graphs_seen=4,
            num_backend_compilations=4,
            num_cudagraph_captured=0,  # no cudagraph captured
    ):
        outputs.append(run_model(vllm_config, model, inputs))

    # Generally don't expect outputs with and without inductor
    # to be bitwise equivalent
    assert torch.allclose(outputs[0], outputs[1])

    # Expect bitwise equivalence using inductor w/ and w/o cudagraph
    assert torch.equal(outputs[0], outputs[2])
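A quick sanity check on the counter arithmetic asserted above: with cudagraph_capture_sizes=[1, 2] there are 2 capture sizes, so a single compiled module with 2 capturable piecewise graphs captures 2 x 2 = 4 CUDA graphs, and the two-submodule model with 4 capturable graphs captures 2 x 4 = 8, matching num_cudagraph_captured=4 and 8 in the tests.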
@ -69,8 +69,9 @@ def run_test(more_args):


@pytest.mark.skipif(not current_platform.is_cuda()
                    and not current_platform.is_tpu(),
                    reason="V1 currently only supported on CUDA and TPU")
                    and not current_platform.is_tpu()
                    and not current_platform.is_xpu(),
                    reason="V1 currently only supported on CUDA, XPU and TPU")
def test_lm_eval_accuracy_v1_engine(monkeypatch: pytest.MonkeyPatch):
    """Run with the V1 Engine."""

@ -2,6 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# imports for guided decoding tests
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Optional
|
||||
@ -26,10 +27,6 @@ MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
|
||||
# technically these adapters use a different base model,
|
||||
# but we're not testing generation quality here
|
||||
LORA_NAME = "typeof/zephyr-7b-beta-lora"
|
||||
PA_NAME = "swapnilbp/llama_tweet_ptune"
|
||||
# if PA_NAME changes, PA_NUM_VIRTUAL_TOKENS might also
|
||||
# need to change to match the prompt adapter
|
||||
PA_NUM_VIRTUAL_TOKENS = 8
|
||||
|
||||
GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"]
|
||||
|
||||
@ -56,13 +53,7 @@ def zephyr_lora_added_tokens_files(zephyr_lora_files):
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def zephyr_pa_files():
|
||||
return snapshot_download(repo_id=PA_NAME)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files,
|
||||
zephyr_pa_files):
|
||||
def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files):
|
||||
return [
|
||||
# use half precision for speed and memory savings in CI environment
|
||||
"--dtype",
|
||||
@ -81,15 +72,6 @@ def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files,
|
||||
"64",
|
||||
"--max-cpu-loras",
|
||||
"2",
|
||||
# pa config
|
||||
"--enable-prompt-adapter",
|
||||
"--prompt-adapters",
|
||||
f"zephyr-pa={zephyr_pa_files}",
|
||||
f"zephyr-pa2={zephyr_pa_files}",
|
||||
"--max-prompt-adapters",
|
||||
"2",
|
||||
"--max-prompt-adapter-token",
|
||||
"128",
|
||||
]
|
||||
|
||||
|
||||
@ -98,8 +80,19 @@ def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files,
|
||||
def server(default_server_args, request):
|
||||
if request.param:
|
||||
default_server_args.append(request.param)
|
||||
with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
original_value = os.environ.get('VLLM_USE_V1')
|
||||
os.environ['VLLM_USE_V1'] = '0'
|
||||
try:
|
||||
with RemoteOpenAIServer(MODEL_NAME,
|
||||
default_server_args) as remote_server:
|
||||
yield remote_server
|
||||
finally:
|
||||
# Restore original env value
|
||||
if original_value is None:
|
||||
os.environ.pop('VLLM_USE_V1', None)
|
||||
else:
|
||||
os.environ['VLLM_USE_V1'] = original_value
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
@ -110,14 +103,11 @@ async def client(server):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
# first test base model, then test loras, then test prompt adapters
|
||||
"model_name,num_virtual_tokens",
|
||||
[(MODEL_NAME, 0), ("zephyr-lora", 0), ("zephyr-lora2", 0),
|
||||
("zephyr-pa", PA_NUM_VIRTUAL_TOKENS),
|
||||
("zephyr-pa2", PA_NUM_VIRTUAL_TOKENS)],
|
||||
# first test base model, then test loras
|
||||
"model_name",
|
||||
[MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
|
||||
)
|
||||
async def test_single_completion(client: openai.AsyncOpenAI, model_name: str,
|
||||
num_virtual_tokens: int):
|
||||
async def test_single_completion(client: openai.AsyncOpenAI, model_name: str):
|
||||
completion = await client.completions.create(model=model_name,
|
||||
prompt="Hello, my name is",
|
||||
max_tokens=5,
|
||||
@ -130,9 +120,7 @@ async def test_single_completion(client: openai.AsyncOpenAI, model_name: str,
|
||||
assert len(choice.text) >= 5
|
||||
assert choice.finish_reason == "length"
|
||||
assert completion.usage == openai.types.CompletionUsage(
|
||||
completion_tokens=5,
|
||||
prompt_tokens=6 + num_virtual_tokens,
|
||||
total_tokens=11 + num_virtual_tokens)
|
||||
completion_tokens=5, prompt_tokens=6, total_tokens=11)
|
||||
|
||||
# test using token IDs
|
||||
completion = await client.completions.create(
|
||||
@ -175,9 +163,9 @@ async def test_added_lora_tokens_base_model(client: openai.AsyncOpenAI):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
# first test base model, then test loras, then test prompt adapters
|
||||
# first test base model, then test loras
|
||||
"model_name",
|
||||
[MODEL_NAME, "zephyr-lora", "zephyr-lora2", "zephyr-pa", "zephyr-pa2"],
|
||||
[MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
|
||||
)
|
||||
async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str):
|
||||
# test using token IDs
|
||||
@ -194,9 +182,9 @@ async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
# just test 1 lora and 1 pa hereafter
|
||||
# just test 1 lora
|
||||
"model_name",
|
||||
[MODEL_NAME, "zephyr-lora", "zephyr-pa"],
|
||||
[MODEL_NAME, "zephyr-lora"],
|
||||
)
|
||||
async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str):
|
||||
# test using token IDs
|
||||
@ -217,7 +205,7 @@ async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str):
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
[MODEL_NAME, "zephyr-lora", "zephyr-pa"],
|
||||
[MODEL_NAME, "zephyr-lora"],
|
||||
)
|
||||
async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str):
|
||||
# test using token IDs
|
||||
@ -238,7 +226,7 @@ async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str):
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
[MODEL_NAME, "zephyr-lora", "zephyr-pa"],
|
||||
[MODEL_NAME, "zephyr-lora"],
|
||||
)
|
||||
async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
|
||||
model_name: str):
|
||||
@ -314,7 +302,7 @@ async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI,
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
[MODEL_NAME, "zephyr-lora", "zephyr-pa"],
|
||||
[MODEL_NAME, "zephyr-lora"],
|
||||
)
|
||||
async def test_completion_streaming(client: openai.AsyncOpenAI,
|
||||
model_name: str):
|
||||
@ -348,7 +336,7 @@ async def test_completion_streaming(client: openai.AsyncOpenAI,
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
[MODEL_NAME, "zephyr-lora", "zephyr-pa"],
|
||||
[MODEL_NAME, "zephyr-lora"],
|
||||
)
|
||||
async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str):
|
||||
"""Streaming for parallel sampling.
|
||||
@ -382,7 +370,7 @@ async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str):
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
[MODEL_NAME, "zephyr-lora", "zephyr-pa"],
|
||||
[MODEL_NAME, "zephyr-lora"],
|
||||
)
|
||||
async def test_completion_stream_options(client: openai.AsyncOpenAI,
|
||||
model_name: str):
|
||||
@ -519,7 +507,7 @@ async def test_completion_stream_options(client: openai.AsyncOpenAI,
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
[MODEL_NAME, "zephyr-lora", "zephyr-pa"],
|
||||
[MODEL_NAME, "zephyr-lora"],
|
||||
)
|
||||
async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str):
|
||||
# test both text and token IDs
|
||||
|
@ -13,7 +13,6 @@ from ...utils import RemoteOpenAIServer
from .test_completion import default_server_args  # noqa: F401
from .test_completion import zephyr_lora_added_tokens_files  # noqa: F401
from .test_completion import zephyr_lora_files  # noqa: F401
from .test_completion import zephyr_pa_files  # noqa: F401
from .test_completion import MODEL_NAME

@ -32,8 +32,7 @@ async def _async_serving_models_init() -> OpenAIServingModels:
    serving_models = OpenAIServingModels(engine_client=mock_engine_client,
                                         base_model_paths=BASE_MODEL_PATHS,
                                         model_config=mock_model_config,
                                         lora_modules=None,
                                         prompt_adapters=None)
                                         lora_modules=None)
    await serving_models.init_static_loras()

    return serving_models
@ -6,6 +6,10 @@ from collections.abc import Mapping
|
||||
from typing import Literal, Optional
|
||||
|
||||
import pytest
|
||||
from mistral_common.tokens.tokenizers.base import (SpecialTokenPolicy,
|
||||
SpecialTokens)
|
||||
from mistral_common.tokens.tokenizers.tekken import (SpecialTokenInfo,
|
||||
Tekkenizer)
|
||||
|
||||
from vllm.assets.audio import AudioAsset
|
||||
from vllm.assets.image import ImageAsset
|
||||
@ -21,6 +25,7 @@ from vllm.multimodal import MultiModalDataDict
|
||||
from vllm.multimodal.utils import (encode_audio_base64, encode_image_base64,
|
||||
encode_video_base64)
|
||||
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
|
||||
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
|
||||
|
||||
from ..models.registry import HF_EXAMPLE_MODELS
|
||||
from ..utils import VLLM_PATH
|
||||
@ -1374,3 +1379,165 @@ def test_resolve_content_format_examples(template_path, expected_format):
|
||||
)
|
||||
|
||||
assert resolved_format == expected_format
|
||||
|
||||
|
||||
def test_parse_chat_messages_include_thinking_chunk(mistral_model_config,
|
||||
mistral_tokenizer):
|
||||
messages = [{
|
||||
"role":
|
||||
"system",
|
||||
"content": [{
|
||||
"type": "text",
|
||||
"text": "You are a helpful assistant."
|
||||
}, {
|
||||
"type":
|
||||
"thinking",
|
||||
"closed":
|
||||
True,
|
||||
"thinking":
|
||||
"Only return the answer when you are confident."
|
||||
}]
|
||||
}, {
|
||||
"role": "user",
|
||||
"content": "What is 2+2?"
|
||||
}, {
|
||||
"role":
|
||||
"assistant",
|
||||
"content": [{
|
||||
"type": "text",
|
||||
"text": "Let me think about it."
|
||||
}, {
|
||||
"type": "thinking",
|
||||
"closed": True,
|
||||
"thinking": "2+2 = 4"
|
||||
}, {
|
||||
"type": "text",
|
||||
"text": "The answer is 4.",
|
||||
}],
|
||||
}]
|
||||
|
||||
conversation_with_thinking, _ = parse_chat_messages(
|
||||
messages,
|
||||
mistral_model_config,
|
||||
mistral_tokenizer,
|
||||
content_format="openai",
|
||||
)
|
||||
|
||||
expected_conversation = [{
|
||||
"role":
|
||||
"system",
|
||||
"content": [{
|
||||
"type": "text",
|
||||
"text": "You are a helpful assistant."
|
||||
}, {
|
||||
"type": "text",
|
||||
"text": "Only return the answer when you are confident."
|
||||
}],
|
||||
}, {
|
||||
"role":
|
||||
"user",
|
||||
"content": [{
|
||||
"type": "text",
|
||||
"text": "What is 2+2?"
|
||||
}],
|
||||
}, {
|
||||
"role":
|
||||
"assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Let me think about it."
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "2+2 = 4"
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "The answer is 4."
|
||||
},
|
||||
]
|
||||
}]
|
||||
|
||||
assert conversation_with_thinking == expected_conversation
|
||||
|
||||
|
||||
def test_apply_mistral_chat_template_thinking_chunk():
|
||||
# Moved import here to avoid yapf and isort conflicts
|
||||
from vllm.entrypoints.chat_utils import apply_mistral_chat_template
|
||||
messages = [{
|
||||
"role":
|
||||
"system",
|
||||
"content": [{
|
||||
"type": "text",
|
||||
"text": "You are a helpful assistant."
|
||||
}, {
|
||||
"type":
|
||||
"thinking",
|
||||
"closed":
|
||||
True,
|
||||
"thinking":
|
||||
"Only return the answer when you are confident."
|
||||
}]
|
||||
}, {
|
||||
"role": "user",
|
||||
"content": "What is 2+2?"
|
||||
}, {
|
||||
"role":
|
||||
"assistant",
|
||||
"content": [{
|
||||
"type": "text",
|
||||
"text": "Let me think about it."
|
||||
}, {
|
||||
"type": "thinking",
|
||||
"closed": True,
|
||||
"thinking": "2+2 = 4"
|
||||
}, {
|
||||
"type": "text",
|
||||
"text": "The answer is 4.",
|
||||
}],
|
||||
}, {
|
||||
"role": "user",
|
||||
"content": "Thanks, what is 3+3?"
|
||||
}]
|
||||
|
||||
# TODO(Julien): upon model release change to a tokenizer already configured.
|
||||
# =================================================================
|
||||
mistral_tokenizer = MistralTokenizer.from_pretrained(
|
||||
"mistralai/Devstral-Small-2507")
|
||||
assert isinstance(mistral_tokenizer.tokenizer, Tekkenizer)
|
||||
# Add think special tokens to the tokenizer
|
||||
mistral_tokenizer.tokenizer._all_special_tokens[35] = SpecialTokenInfo(
|
||||
rank=35, is_control=True, token_str=SpecialTokens.begin_think.value)
|
||||
mistral_tokenizer.tokenizer._all_special_tokens[36] = SpecialTokenInfo(
|
||||
rank=36, is_control=True, token_str=SpecialTokens.end_think.value)
|
||||
mistral_tokenizer.tokenizer._special_tokens_reverse_vocab = {
|
||||
k: v
|
||||
for k, v in
|
||||
mistral_tokenizer.tokenizer._special_tokens_reverse_vocab.items()
|
||||
if v not in {35, 36}
|
||||
}
|
||||
mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[
|
||||
SpecialTokens.begin_think.value] = 35
|
||||
mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[
|
||||
SpecialTokens.end_think.value] = 36
|
||||
mistral_tokenizer.instruct.BEGIN_THINK = 35
|
||||
mistral_tokenizer.instruct.END_THINK = 36
|
||||
# =================================================================
|
||||
|
||||
tokens_ids = apply_mistral_chat_template(mistral_tokenizer,
|
||||
messages,
|
||||
chat_template=None,
|
||||
tools=None)
|
||||
|
||||
string_tokens = mistral_tokenizer.mistral.decode(
|
||||
tokens_ids, special_token_policy=SpecialTokenPolicy.KEEP)
|
||||
|
||||
expected_tokens = (
|
||||
r"<s>[SYSTEM_PROMPT]You are a helpful assistant.[THINK]Only return the"
|
||||
r" answer when you are confident.[/THINK][/SYSTEM_PROMPT]"
|
||||
r"[INST]What is 2+2?[/INST]"
|
||||
r"Let me think about it.[THINK]2+2 = 4[/THINK]The answer is 4.</s>"
|
||||
r"[INST]Thanks, what is 3+3?[/INST]")
|
||||
|
||||
assert string_tokens == expected_tokens
|
||||
|
@ -23,6 +23,8 @@ from transformers import (AutoConfig, AutoProcessor, AutoTokenizer,

from vllm import LLM, SamplingParams

from ....utils import multi_gpu_test

# Sample prompts for testing
PROMPTS: list[str] = [
    "Hello, my name is",
@ -541,6 +543,7 @@ def run_reduced_model(model_path: str,
    print("-" * 40)


@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
    "original_model_name,text_layers,num_experts,vision_layers,",
    [("meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8", 4, 4, 2)])
63
tests/models/multimodal/pooling/test_prithvi_mae.py
Normal file
@ -0,0 +1,63 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

from vllm.utils import set_default_torch_num_threads

from ....conftest import VllmRunner


def generate_test_mm_data():
    mm_data = {
        "pixel_values": torch.full((6, 512, 512), 1.0, dtype=torch.float16),
        "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16),
    }
    return mm_data


def _run_test(
    vllm_runner: type[VllmRunner],
    model: str,
) -> None:

    prompt = [
        {
            # This model deals with no text input
            "prompt_token_ids": [1],
            "multi_modal_data": generate_test_mm_data(),
        } for _ in range(10)
    ]

    with (
            set_default_torch_num_threads(1),
            vllm_runner(
                model,
                task="embed",
                dtype=torch.float16,
                enforce_eager=True,
                skip_tokenizer_init=True,
                # Limit the maximum number of sequences to avoid the
                # test going OOM during the warmup run
                max_num_seqs=32,
            ) as vllm_model,
    ):
        vllm_model.encode(prompt)


MODELS = ["christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"]


@pytest.mark.core_model
@pytest.mark.parametrize("model", MODELS)
def test_models_image(
    hf_runner,
    vllm_runner,
    image_assets,
    model: str,
) -> None:
    _run_test(
        vllm_runner,
        model,
    )
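The new test above drives the Prithvi-EO geospatial encoder purely through multimodal tensors, with no text prompt. As a rough standalone sketch of the same run outside the test harness (illustrative only; the model name, task, and input keys are taken from the test, everything else is an assumption):

    # Illustrative only: a standalone equivalent of the vllm_runner call above.
    import torch
    from vllm import LLM

    mm_data = {
        "pixel_values": torch.full((6, 512, 512), 1.0, dtype=torch.float16),
        "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16),
    }
    llm = LLM(model="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
              task="embed",
              dtype="float16",
              enforce_eager=True,
              skip_tokenizer_init=True,
              max_num_seqs=32)
    outputs = llm.encode({"prompt_token_ids": [1], "multi_modal_data": mm_data})
    print(len(outputs))  # one pooling output per prompt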
@ -1,48 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
|
||||
import vllm
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
|
||||
MODEL_PATH = "bigscience/bloomz-560m"
|
||||
PA_PATH = 'stevhliu/bloomz-560m_PROMPT_TUNING_CAUSAL_LM'
|
||||
|
||||
|
||||
def do_sample(llm, pa_name: str, pa_id: int):
|
||||
|
||||
prompts = [
|
||||
"Tweet text : @nationalgridus I have no water and the bill is \
|
||||
current and paid. Can you do something about this? Label : ",
|
||||
"Tweet text : @nationalgridus Looks good thanks! Label : "
|
||||
]
|
||||
sampling_params = vllm.SamplingParams(temperature=0.0,
|
||||
max_tokens=3,
|
||||
stop_token_ids=[3])
|
||||
|
||||
outputs = llm.generate(prompts,
|
||||
sampling_params,
|
||||
prompt_adapter_request=PromptAdapterRequest(
|
||||
pa_name, pa_id, PA_PATH, 8) if pa_id else None)
|
||||
|
||||
# Print the outputs.
|
||||
generated_texts = []
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text.strip()
|
||||
generated_texts.append(generated_text)
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
return generated_texts
|
||||
|
||||
|
||||
@pytest.mark.parametrize("enforce_eager", [True, False])
|
||||
def test_twitter_prompt_adapter(enforce_eager: bool):
|
||||
llm = vllm.LLM(MODEL_PATH,
|
||||
enforce_eager=enforce_eager,
|
||||
enable_prompt_adapter=True,
|
||||
max_prompt_adapter_token=8)
|
||||
|
||||
expected_output = ['complaint', 'no complaint']
|
||||
|
||||
assert do_sample(llm, "twitter_pa", pa_id=1) == expected_output
|
@ -1,56 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm import EngineArgs, LLMEngine, SamplingParams
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
|
||||
MODEL_PATH = "bigscience/bloomz-560m"
|
||||
pa_path = 'stevhliu/bloomz-560m_PROMPT_TUNING_CAUSAL_LM'
|
||||
pa_path2 = 'swapnilbp/angry_tweet_ptune'
|
||||
|
||||
|
||||
def do_sample(engine):
|
||||
|
||||
prompts = [
|
||||
("Tweet text: I have complaints! Label: ",
|
||||
SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]),
|
||||
PromptAdapterRequest("hate_speech", 1, pa_path2, 8)),
|
||||
("Tweet text: I have no problems Label: ",
|
||||
SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]),
|
||||
PromptAdapterRequest("hate_speech2", 2, pa_path2, 8)),
|
||||
("Tweet text: I have complaints! Label: ",
|
||||
SamplingParams(temperature=0.0, max_tokens=3), None),
|
||||
("Tweet text: I have no problems Label: ",
|
||||
SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]),
|
||||
PromptAdapterRequest("complain", 3, pa_path, 8)),
|
||||
]
|
||||
|
||||
request_id = 0
|
||||
results = set()
|
||||
while prompts or engine.has_unfinished_requests():
|
||||
if prompts:
|
||||
prompt, sampling_params, pa_request = prompts.pop(0)
|
||||
engine.add_request(str(request_id),
|
||||
prompt,
|
||||
sampling_params,
|
||||
prompt_adapter_request=pa_request)
|
||||
request_id += 1
|
||||
|
||||
request_outputs = engine.step()
|
||||
|
||||
for request_output in request_outputs:
|
||||
if request_output.finished:
|
||||
results.add(request_output.outputs[0].text)
|
||||
return results
|
||||
|
||||
|
||||
def test_multi_prompt_adapters():
|
||||
engine_args = EngineArgs(model=MODEL_PATH,
|
||||
max_prompt_adapters=3,
|
||||
enable_prompt_adapter=True,
|
||||
max_prompt_adapter_token=8)
|
||||
engine = LLMEngine.from_engine_args(engine_args)
|
||||
expected_output = {
|
||||
' quot;I', 'hate speech', 'no complaint', 'not hate speech'
|
||||
}
|
||||
assert do_sample(engine) == expected_output
|
@ -1,64 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from vllm import EngineArgs, LLMEngine, SamplingParams
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
|
||||
MODEL_PATH = "meta-llama/Llama-2-7b-hf"
|
||||
pa_path = snapshot_download(repo_id="swapnilbp/llama_tweet_ptune")
|
||||
lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
|
||||
|
||||
|
||||
def do_sample(engine):
|
||||
|
||||
prompt_text = "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]" # noqa: E501
|
||||
|
||||
# first prompt with a prompt adapter and second without adapter
|
||||
prompts = [
|
||||
(prompt_text,
|
||||
SamplingParams(temperature=0.0, max_tokens=100,
|
||||
stop=["[/assistant]"]),
|
||||
PromptAdapterRequest("hate_speech", 1, pa_path,
|
||||
8), LoRARequest("sql_test", 1, lora_path)),
|
||||
(prompt_text,
|
||||
SamplingParams(temperature=0.0, max_tokens=100,
|
||||
stop=["[/assistant]"]), None,
|
||||
LoRARequest("sql_test", 1, lora_path)),
|
||||
]
|
||||
|
||||
request_id = 0
|
||||
results = set()
|
||||
while prompts or engine.has_unfinished_requests():
|
||||
if prompts:
|
||||
prompt, sampling_params, pa_request, lora_request = prompts.pop(0)
|
||||
engine.add_request(str(request_id),
|
||||
prompt,
|
||||
sampling_params,
|
||||
prompt_adapter_request=pa_request,
|
||||
lora_request=lora_request)
|
||||
request_id += 1
|
||||
|
||||
request_outputs = engine.step()
|
||||
|
||||
for request_output in request_outputs:
|
||||
if request_output.finished:
|
||||
results.add(request_output.outputs[0].text)
|
||||
return results
|
||||
|
||||
|
||||
def test_lora_prompt_adapter():
|
||||
engine_args = EngineArgs(model=MODEL_PATH,
|
||||
enable_prompt_adapter=True,
|
||||
enable_lora=True,
|
||||
max_num_seqs=60,
|
||||
max_prompt_adapter_token=8)
|
||||
engine = LLMEngine.from_engine_args(engine_args)
|
||||
result = do_sample(engine)
|
||||
|
||||
expected_output = {
|
||||
" SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' " # noqa: E501
|
||||
}
|
||||
assert result == expected_output
|
341
tests/reasoning/test_mistral_reasoning_parser.py
Normal file
@ -0,0 +1,341 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
from mistral_common.tokens.tokenizers.base import SpecialTokens
|
||||
from mistral_common.tokens.tokenizers.tekken import (SpecialTokenInfo,
|
||||
Tekkenizer)
|
||||
|
||||
from tests.reasoning.utils import run_reasoning_extraction_mistral
|
||||
from vllm.reasoning import ReasoningParser, ReasoningParserManager
|
||||
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
|
||||
|
||||
parser_name = "mistral"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def mistral_tokenizer():
|
||||
# TODO(Julien): upon model release change to a tokenizer already configured.
|
||||
# =================================================================
|
||||
mistral_tokenizer = MistralTokenizer.from_pretrained(
|
||||
"mistralai/Devstral-Small-2507")
|
||||
assert isinstance(mistral_tokenizer.tokenizer, Tekkenizer)
|
||||
# Add think special tokens to the tokenizer
|
||||
mistral_tokenizer.tokenizer._all_special_tokens[35] = SpecialTokenInfo(
|
||||
rank=35, is_control=True, token_str=SpecialTokens.begin_think.value)
|
||||
mistral_tokenizer.tokenizer._all_special_tokens[36] = SpecialTokenInfo(
|
||||
rank=36, is_control=True, token_str=SpecialTokens.end_think.value)
|
||||
mistral_tokenizer.tokenizer._special_tokens_reverse_vocab = {
|
||||
k: v
|
||||
for k, v in
|
||||
mistral_tokenizer.tokenizer._special_tokens_reverse_vocab.items()
|
||||
if v not in {35, 36}
|
||||
}
|
||||
mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[
|
||||
SpecialTokens.begin_think.value] = 35
|
||||
mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[
|
||||
SpecialTokens.end_think.value] = 36
|
||||
mistral_tokenizer.instruct.BEGIN_THINK = 35
|
||||
mistral_tokenizer.instruct.END_THINK = 36
|
||||
# =================================================================
|
||||
return mistral_tokenizer
|
||||
|
||||
|
||||
SIMPLE_REASONING = {
|
||||
"output": "This is a reasoning section[/THINK]This is the rest",
|
||||
"reasoning_content": "This is a reasoning section",
|
||||
"content": "This is the rest",
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
COMPLETE_REASONING = {
|
||||
"output": "This is a reasoning section[/THINK]",
|
||||
"reasoning_content": "This is a reasoning section",
|
||||
"content": None,
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
NO_CONTENT = {
|
||||
"output": "This is content",
|
||||
"reasoning_content": "This is content",
|
||||
"content": None,
|
||||
"is_reasoning_end": False,
|
||||
}
|
||||
NO_REASONING_STREAMING = {
|
||||
"output": "This is a reasoning section",
|
||||
"reasoning_content": "This is a reasoning section",
|
||||
"content": None,
|
||||
"is_reasoning_end": False,
|
||||
}
|
||||
MULTIPLE_LINES = {
|
||||
"output": "This\nThat[/THINK]This is the rest\nThat",
|
||||
"reasoning_content": "This\nThat",
|
||||
"content": "This is the rest\nThat",
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
SHORTEST_REASONING_NO_STREAMING = {
|
||||
"output": "[/THINK]This is the rest",
|
||||
"reasoning_content": "",
|
||||
"content": "This is the rest",
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
SHORTEST_REASONING = {
|
||||
"output": "[/THINK]This is the rest",
|
||||
"reasoning_content": None,
|
||||
"content": "This is the rest",
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
REASONING_WITH_THINK = {
|
||||
"output": "[THINK]This is a reasoning section[/THINK]This is the rest",
|
||||
"reasoning_content": "This is a reasoning section",
|
||||
"content": "This is the rest",
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
COMPLETE_REASONING_WITH_THINK = {
|
||||
"output": "[THINK]This is a reasoning section[/THINK]",
|
||||
"reasoning_content": "This is a reasoning section",
|
||||
"content": None,
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
MULTIPLE_LINES_WITH_THINK = {
|
||||
"output": "[THINK]This\nThat[/THINK]This is the rest\nThat",
|
||||
"reasoning_content": "This\nThat",
|
||||
"content": "This is the rest\nThat",
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
SHORTEST_REASONING_NO_STREAMING_WITH_THINK = {
|
||||
"output": "[/THINK]This is the rest",
|
||||
"reasoning_content": "",
|
||||
"content": "This is the rest",
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
SHORTEST_REASONING_WITH_THINK = {
|
||||
"output": "[/THINK]This is the rest",
|
||||
"reasoning_content": None,
|
||||
"content": "This is the rest",
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
THINK_NO_END = {
|
||||
"output": "[THINK]This is a reasoning section",
|
||||
"reasoning_content": "This is a reasoning section",
|
||||
"content": None,
|
||||
"is_reasoning_end": False,
|
||||
}
|
||||
EMPTY = {
|
||||
"output": "",
|
||||
"reasoning_content": "",
|
||||
"content": None,
|
||||
"is_reasoning_end": False,
|
||||
}
|
||||
EMPTY_STREAMING = {
|
||||
"output": "",
|
||||
"reasoning_content": None,
|
||||
"content": None,
|
||||
"is_reasoning_end": False,
|
||||
}
|
||||
NEW_LINE = {
|
||||
"output": "\n[THINK]This is a reasoning section[/THINK]\nThis is the rest",
|
||||
"reasoning_content": "This is a reasoning section",
|
||||
"content": "\nThis is the rest",
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
# Streaming cannot handle new lines at the beginning of the output
|
||||
# because we need to support [THINK]...[/THINK] and [/THINK]...
|
||||
# We cannot know if the text before [THINK] is reasoning content
|
||||
# or not.
|
||||
NEW_LINE_STREAMING = {
|
||||
"output": "\n[THINK]This is a reasoning section[/THINK]\nThis is the rest",
|
||||
"reasoning_content": "\nThis is a reasoning section",
|
||||
"content": "\nThis is the rest",
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
|
||||
TEST_CASES = [
|
||||
pytest.param(
|
||||
False,
|
||||
SIMPLE_REASONING,
|
||||
id="simple_reasoning",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
SIMPLE_REASONING,
|
||||
id="simple_reasoning_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
COMPLETE_REASONING,
|
||||
id="complete_reasoning",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
COMPLETE_REASONING,
|
||||
id="complete_reasoning_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
NO_CONTENT,
|
||||
id="no_content_token",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
NO_REASONING_STREAMING,
|
||||
id="no_reasoning_token_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
MULTIPLE_LINES,
|
||||
id="multiple_lines",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
MULTIPLE_LINES,
|
||||
id="multiple_lines_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
SHORTEST_REASONING,
|
||||
id="shortest",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
SHORTEST_REASONING_NO_STREAMING,
|
||||
id="shortest_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
REASONING_WITH_THINK,
|
||||
id="reasoning_with_think",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
REASONING_WITH_THINK,
|
||||
id="reasoning_with_think_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
COMPLETE_REASONING_WITH_THINK,
|
||||
id="complete_reasoning_with_think",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
COMPLETE_REASONING_WITH_THINK,
|
||||
id="complete_reasoning_with_think_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
MULTIPLE_LINES_WITH_THINK,
|
||||
id="multiple_lines_with_think",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
MULTIPLE_LINES_WITH_THINK,
|
||||
id="multiple_lines_with_think_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
SHORTEST_REASONING_NO_STREAMING_WITH_THINK,
|
||||
id="shortest_with_think",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
SHORTEST_REASONING_WITH_THINK,
|
||||
id="shortest_with_think_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
THINK_NO_END,
|
||||
id="think_no_end",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
THINK_NO_END,
|
||||
id="think_no_end_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
EMPTY,
|
||||
id="empty",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
EMPTY_STREAMING,
|
||||
id="empty_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
NEW_LINE,
|
||||
id="new_line",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
NEW_LINE_STREAMING,
|
||||
id="new_line_streaming",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming, param_dict", TEST_CASES)
|
||||
def test_mistral_reasoning(
|
||||
streaming: bool,
|
||||
param_dict: dict,
|
||||
mistral_tokenizer: MistralTokenizer,
|
||||
):
|
||||
output = param_dict["output"]
|
||||
|
||||
index_think = output.find("[THINK]")
|
||||
len_think = len("[THINK]")
|
||||
index_end_think = output.find("[/THINK]")
|
||||
len_end_think = len("[/THINK]")
|
||||
|
||||
# encode everything to tokens ids
|
||||
output_tokens = []
|
||||
if index_think != -1:
|
||||
output_before_think = output[:index_think]
|
||||
output_tokens += mistral_tokenizer.tokenizer.encode(
|
||||
output_before_think, False, False)
|
||||
output_tokens += [mistral_tokenizer.instruct.BEGIN_THINK]
|
||||
|
||||
if index_end_think != -1:
|
||||
output_middle = output[index_think + len_think:index_end_think]
|
||||
output_after_think = output[index_end_think + len_end_think:]
|
||||
output_tokens += mistral_tokenizer.tokenizer.encode(
|
||||
output_middle, False, False)
|
||||
output_tokens += [mistral_tokenizer.instruct.END_THINK]
|
||||
output_tokens += mistral_tokenizer.tokenizer.encode(
|
||||
output_after_think, False, False)
|
||||
else:
|
||||
output_middle = output[index_think + len_think:]
|
||||
output_tokens += mistral_tokenizer.tokenizer.encode(
|
||||
output_middle, False, False)
|
||||
elif index_end_think != -1:
|
||||
output_before_think = output[:index_end_think]
|
||||
output_after_think = output[index_end_think + len_end_think:]
|
||||
output_tokens += mistral_tokenizer.tokenizer.encode(
|
||||
output_before_think, False, False)
|
||||
output_tokens += [mistral_tokenizer.instruct.END_THINK]
|
||||
output_tokens += mistral_tokenizer.tokenizer.encode(
|
||||
output_after_think, False, False)
|
||||
else:
|
||||
output_tokens += mistral_tokenizer.tokenizer.encode(
|
||||
output, False, False)
|
||||
|
||||
parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(
|
||||
parser_name)(mistral_tokenizer)
|
||||
|
||||
reasoning, content = run_reasoning_extraction_mistral(parser,
|
||||
output_tokens,
|
||||
streaming=streaming)
|
||||
|
||||
assert reasoning == param_dict["reasoning_content"]
|
||||
assert content == param_dict["content"]
|
||||
|
||||
# Test is_reasoning_end
|
||||
is_reasoning_end = parser.is_reasoning_end(output_tokens)
|
||||
assert is_reasoning_end == param_dict["is_reasoning_end"]
|
||||
|
||||
# Test extract_content
|
||||
if param_dict["content"] is not None:
|
||||
content = parser.extract_content_ids(output_tokens)
|
||||
assert content == mistral_tokenizer.tokenizer.encode(
|
||||
param_dict["content"], bos=False, eos=False)
|
||||
else:
|
||||
content = parser.extract_content_ids(output_tokens)
|
||||
assert content == []
|
@ -6,6 +6,7 @@ from typing import Optional, Union
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaMessage)
|
||||
from vllm.reasoning import ReasoningParser
|
||||
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
|
||||
|
||||
|
||||
class StreamingReasoningReconstructor:
|
||||
@ -54,6 +55,32 @@ def run_reasoning_extraction(
|
||||
return reasoning, content
|
||||
|
||||
|
||||
def run_reasoning_extraction_mistral(
|
||||
reasoning_parser: ReasoningParser,
|
||||
model_output: list[int],
|
||||
request: Union[ChatCompletionRequest, None] = None,
|
||||
streaming: bool = False,
|
||||
) -> tuple[Optional[str], Optional[str]]:
|
||||
assert isinstance(reasoning_parser.model_tokenizer,
|
||||
MistralTokenizer), type(reasoning_parser.model_tokenizer)
|
||||
if streaming:
|
||||
reconstructor = run_reasoning_extraction_streaming_mistral(
|
||||
reasoning_parser,
|
||||
model_output,
|
||||
request,
|
||||
)
|
||||
return (
|
||||
reconstructor.reasoning_content,
|
||||
reconstructor.other_content or None,
|
||||
)
|
||||
else:
|
||||
str_output = reasoning_parser.model_tokenizer.convert_ids_to_tokens(
|
||||
model_output)
|
||||
reasoning, content = run_reasoning_extraction_nonstreaming(
|
||||
reasoning_parser, str_output, request)
|
||||
return reasoning, content
|
||||
|
||||
|
||||
def run_reasoning_extraction_nonstreaming(
|
||||
reasoning_parser: ReasoningParser,
|
||||
model_output: list[str],
|
||||
@ -94,3 +121,35 @@ def run_reasoning_extraction_streaming(
|
||||
previous_text = current_text
|
||||
previous_tokens = current_tokens
|
||||
return reconstructor
|
||||
|
||||
|
||||
def run_reasoning_extraction_streaming_mistral(
|
||||
reasoning_parser: ReasoningParser,
|
||||
model_deltas: list[int],
|
||||
request: Union[ChatCompletionRequest, None] = None,
|
||||
) -> StreamingReasoningReconstructor:
|
||||
assert isinstance(reasoning_parser.model_tokenizer,
|
||||
MistralTokenizer), type(reasoning_parser.model_tokenizer)
|
||||
request = request or ChatCompletionRequest(messages=[], model="test-model")
|
||||
reconstructor = StreamingReasoningReconstructor()
|
||||
previous_text = ""
|
||||
previous_tokens: list[int] = []
|
||||
for model_delta in model_deltas:
|
||||
token_delta = [model_delta]
|
||||
delta = reasoning_parser.model_tokenizer.convert_ids_to_tokens(
|
||||
[model_delta])[0]
|
||||
current_text = previous_text + delta
|
||||
current_tokens = previous_tokens + token_delta
|
||||
delta_message = reasoning_parser.extract_reasoning_content_streaming(
|
||||
previous_text,
|
||||
current_text,
|
||||
delta,
|
||||
previous_tokens,
|
||||
current_tokens,
|
||||
token_delta,
|
||||
)
|
||||
if delta_message is not None:
|
||||
reconstructor.append_delta(delta_message)
|
||||
previous_text = current_text
|
||||
previous_tokens = current_tokens
|
||||
return reconstructor
|
||||
|
@ -565,8 +565,8 @@ def test_engine_core_proc_instantiation_cuda_empty(

    from vllm.v1.engine.utils import EngineZmqAddresses

    def mock_startup_handshake(self, handshake_socket, on_head_node,
                               parallel_config):
    def mock_startup_handshake(self, handshake_socket, local_client,
                               headless, parallel_config):
        return EngineZmqAddresses(inputs=["tcp://127.0.0.1:5555"],
                                  outputs=["tcp://127.0.0.1:5556"],
                                  coordinator_input=None,
352
tests/v1/test_hybrid_lb_dp.py
Normal file
@ -0,0 +1,352 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import asyncio
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from contextlib import AsyncExitStack
|
||||
|
||||
import openai # use the official client for correctness check
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from tests.v1.test_utils import check_request_balancing
|
||||
from vllm.platforms import Platform
|
||||
|
||||
MODEL_NAME = "ibm-research/PowerMoE-3b"
|
||||
|
||||
# Number of data parallel ranks for hybrid LB testing (4 total)
|
||||
DP_SIZE = int(os.getenv("DP_SIZE", "4"))
|
||||
# Default tensor parallel size to use
|
||||
TP_SIZE = int(os.getenv("TP_SIZE", "1"))
|
||||
|
||||
# Number of nodes (2 nodes, each with 2 DP ranks)
|
||||
NUM_NODES = 2
|
||||
DP_SIZE_LOCAL = DP_SIZE // NUM_NODES # 2 ranks per node
|
||||
|
||||
|
||||
class HybridLBServerManager:
|
||||
"""Manages hybrid data parallel vLLM server instances where each node
|
||||
runs a single logical API server that balances requests only to the
|
||||
DP engines running on that same node."""
|
||||
|
||||
def __init__(self,
|
||||
model_name: str,
|
||||
dp_size: int,
|
||||
api_server_count: int,
|
||||
base_server_args: list,
|
||||
dp_size_local: int = DP_SIZE_LOCAL,
|
||||
tp_size: int = TP_SIZE):
|
||||
self.model_name = model_name
|
||||
self.dp_size = dp_size
|
||||
self.dp_size_local = dp_size_local
|
||||
self.tp_size = tp_size
|
||||
self.api_server_count = api_server_count
|
||||
self.base_server_args = base_server_args
|
||||
self.servers: list[tuple[RemoteOpenAIServer, list[str]]] = []
|
||||
self.server_threads: list[threading.Thread] = []
|
||||
self.num_nodes = dp_size // dp_size_local
|
||||
|
||||
def __enter__(self) -> list[tuple[RemoteOpenAIServer, list[str]]]:
|
||||
"""Start all server instances for hybrid LB mode."""
|
||||
for node_id in range(self.num_nodes):
|
||||
# Create server args for this specific node
|
||||
server_args = self.base_server_args.copy()
|
||||
|
||||
# Calculate start rank for this node
|
||||
start_rank = node_id * self.dp_size_local
|
||||
|
||||
# Add hybrid LB specific arguments
|
||||
server_args.extend([
|
||||
"--data-parallel-size",
|
||||
str(self.dp_size),
|
||||
"--data-parallel-size-local",
|
||||
str(self.dp_size_local),
|
||||
"--data-parallel-start-rank",
|
||||
str(start_rank),
|
||||
"--data-parallel-hybrid-lb", # Enable hybrid LB mode
|
||||
"--tensor-parallel-size",
|
||||
str(self.tp_size),
|
||||
"--port",
|
||||
str(8000 + node_id), # Different port for each node
|
||||
"--api-server-count",
|
||||
str(self.api_server_count),
|
||||
"--data-parallel-address",
|
||||
"127.0.0.1",
|
||||
"--data-parallel-rpc-port",
|
||||
"13345",
|
||||
])
|
||||
|
||||
# Use a thread to start each server to allow parallel initialization
|
||||
def start_server(node: int, sargs: list[str]):
|
||||
try:
|
||||
# Calculate GPU devices for this node
|
||||
gpus_per_node = self.dp_size_local * self.tp_size
|
||||
gpu_start = node * gpus_per_node
|
||||
gpu_end = gpu_start + gpus_per_node
|
||||
|
||||
# Start the server
|
||||
server = RemoteOpenAIServer(
|
||||
self.model_name,
|
||||
sargs,
|
||||
auto_port=False,
|
||||
env_dict={
|
||||
"CUDA_VISIBLE_DEVICES":
|
||||
",".join(
|
||||
str(Platform.device_id_to_physical_device_id(
|
||||
i)) for i in range(gpu_start, gpu_end))
|
||||
})
|
||||
server.__enter__()
|
||||
print(f"Hybrid LB node {node} started successfully with "
|
||||
f"{self.dp_size_local} local DP ranks and "
|
||||
f"{self.api_server_count} API servers")
|
||||
self.servers.append((server, sargs))
|
||||
except Exception as e:
|
||||
print(f"Failed to start hybrid LB node {node}: {e}")
|
||||
raise
|
||||
|
||||
thread = threading.Thread(target=start_server,
|
||||
args=(node_id, server_args))
|
||||
thread.start()
|
||||
|
||||
self.server_threads.append(thread)
|
||||
|
||||
# Wait for all servers to start
|
||||
for thread in self.server_threads:
|
||||
thread.join()
|
||||
|
||||
# Give servers additional time to fully initialize and coordinate
|
||||
time.sleep(3)
|
||||
|
||||
if len(self.servers) != self.num_nodes:
|
||||
raise Exception("Servers failed to start")
|
||||
|
||||
return self.servers
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Stop all server instances."""
|
||||
while self.servers:
|
||||
try:
|
||||
self.servers.pop()[0].__exit__(exc_type, exc_val, exc_tb)
|
||||
except Exception as e:
|
||||
print(f"Error stopping server: {e}")
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def default_server_args():
|
||||
return [
|
||||
# use half precision for speed and memory savings in CI environment
|
||||
"--dtype",
|
||||
"bfloat16",
|
||||
"--max-model-len",
|
||||
"2048",
|
||||
"--max-num-seqs",
|
||||
"128",
|
||||
"--enforce-eager",
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", params=[1]) # Only 1 API server for now
|
||||
def servers(request, default_server_args):
|
||||
api_server_count = request.param
|
||||
with HybridLBServerManager(MODEL_NAME, DP_SIZE, api_server_count,
|
||||
default_server_args, DP_SIZE_LOCAL,
|
||||
TP_SIZE) as server_list:
|
||||
yield server_list
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def clients(servers: list[tuple[RemoteOpenAIServer, list[str]]]):
|
||||
# Create a client for each node (each node has its own API endpoint)
|
||||
async with AsyncExitStack() as stack:
|
||||
yield [
|
||||
await stack.enter_async_context(server.get_async_client())
|
||||
for server, _ in servers
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
[MODEL_NAME],
|
||||
)
|
||||
async def test_hybrid_lb_completion(clients: list[openai.AsyncOpenAI],
|
||||
servers: list[tuple[RemoteOpenAIServer,
|
||||
list[str]]],
|
||||
model_name: str) -> None:
|
||||
|
||||
async def make_request(client: openai.AsyncOpenAI):
|
||||
completion = await client.completions.create(
|
||||
model=model_name,
|
||||
prompt="Hello, my name is",
|
||||
max_tokens=10,
|
||||
temperature=1.0)
|
||||
|
||||
assert completion.id is not None
|
||||
assert completion.choices is not None and len(completion.choices) == 1
|
||||
|
||||
choice = completion.choices[0]
|
||||
# The exact number of tokens can vary slightly with temperature=1.0,
|
||||
# so we check for a reasonable minimum length.
|
||||
assert len(choice.text) >= 1
|
||||
# Finish reason might not always be 'length' if the model finishes early
|
||||
# or due to other reasons, especially with high temperature.
|
||||
# So, we'll accept 'length' or 'stop'.
|
||||
assert choice.finish_reason in ("length", "stop")
|
||||
|
||||
# Token counts can also vary, so we check they are positive.
|
||||
assert completion.usage.completion_tokens > 0
|
||||
assert completion.usage.prompt_tokens > 0
|
||||
assert completion.usage.total_tokens > 0
|
||||
return completion
|
||||
|
||||
# Test single request to each node
|
||||
for i, client in enumerate(clients):
|
||||
result = await make_request(client)
|
||||
assert result is not None
|
||||
print(
|
||||
f"Hybrid LB node {i} handled single completion request successfully"
|
||||
)
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Send requests to all nodes - each should balance within its local DP ranks
|
||||
num_requests_per_node = 25 # Total 50 requests across 2 nodes
|
||||
all_tasks = []
|
||||
|
||||
for i, client in enumerate(clients):
|
||||
tasks = [make_request(client) for _ in range(num_requests_per_node)]
|
||||
all_tasks.extend(tasks)
|
||||
|
||||
results = await asyncio.gather(*all_tasks)
|
||||
assert len(results) == num_requests_per_node * len(clients)
|
||||
assert all(completion is not None for completion in results)
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Second burst of requests
|
||||
all_tasks = []
|
||||
for i, client in enumerate(clients):
|
||||
tasks = [make_request(client) for _ in range(num_requests_per_node)]
|
||||
all_tasks.extend(tasks)
|
||||
|
||||
results = await asyncio.gather(*all_tasks)
|
||||
assert len(results) == num_requests_per_node * len(clients)
|
||||
assert all(completion is not None for completion in results)
|
||||
|
||||
_, server_args = servers[0]
|
||||
api_server_count = (
|
||||
server_args.count('--api-server-count')
|
||||
and server_args[server_args.index('--api-server-count') + 1] or 1)
|
||||
print(
|
||||
f"Successfully completed hybrid LB test with {len(clients)} nodes "
|
||||
f"({DP_SIZE_LOCAL} DP ranks each, API server count: {api_server_count})"
|
||||
)
|
||||
|
||||
# Check request balancing within each node
|
||||
for i, (server, _) in enumerate(servers):
|
||||
print(f"Checking request balancing for node {i}")
|
||||
check_request_balancing(server, DP_SIZE_LOCAL)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
[MODEL_NAME],
|
||||
)
|
||||
async def test_hybrid_lb_completion_streaming(clients: list[
|
||||
openai.AsyncOpenAI], servers: list[tuple[RemoteOpenAIServer, list[str]]],
|
||||
model_name: str) -> None:
|
||||
prompt = "What is an LLM?"
|
||||
|
||||
async def make_streaming_request(client: openai.AsyncOpenAI):
|
||||
# Perform a non-streaming request to get the expected full output
|
||||
single_completion = await client.completions.create(
|
||||
model=model_name,
|
||||
prompt=prompt,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
)
|
||||
single_output = single_completion.choices[0].text
|
||||
|
||||
# Perform the streaming request
|
||||
stream = await client.completions.create(model=model_name,
|
||||
prompt=prompt,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
stream=True)
|
||||
chunks: list[str] = []
|
||||
finish_reason_count = 0
|
||||
last_chunk = None
|
||||
async for chunk in stream:
|
||||
chunks.append(chunk.choices[0].text)
|
||||
if chunk.choices[0].finish_reason is not None:
|
||||
finish_reason_count += 1
|
||||
last_chunk = chunk # Keep track of the last chunk
|
||||
|
||||
# finish reason should only return in the last block for OpenAI API
|
||||
assert finish_reason_count == 1, (
|
||||
"Finish reason should appear exactly once.")
|
||||
assert last_chunk is not None, (
|
||||
"Stream should have yielded at least one chunk.")
|
||||
assert last_chunk.choices[
|
||||
0].finish_reason == "length", "Finish reason should be 'length'."
|
||||
# Check that the combined text matches the non-streamed version.
|
||||
assert "".join(
|
||||
chunks
|
||||
) == single_output, "Streamed output should match non-streamed output."
|
||||
return True # Indicate success for this request
|
||||
|
||||
# Test single request to each node
|
||||
for i, client in enumerate(clients):
|
||||
result = await make_streaming_request(client)
|
||||
assert result is not None
|
||||
print(
|
||||
f"Hybrid LB node {i} handled single streaming request successfully"
|
||||
)
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Send streaming requests to all nodes
|
||||
num_requests_per_node = 25 # Total 50 requests across 2 nodes
|
||||
all_tasks = []
|
||||
|
||||
for i, client in enumerate(clients):
|
||||
tasks = [
|
||||
make_streaming_request(client)
|
||||
for _ in range(num_requests_per_node)
|
||||
]
|
||||
all_tasks.extend(tasks)
|
||||
|
||||
results = await asyncio.gather(*all_tasks)
|
||||
assert len(results) == num_requests_per_node * len(clients)
|
||||
assert all(results), "Not all streaming requests completed successfully."
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Second burst of streaming requests
|
||||
all_tasks = []
|
||||
for i, client in enumerate(clients):
|
||||
tasks = [
|
||||
make_streaming_request(client)
|
||||
for _ in range(num_requests_per_node)
|
||||
]
|
||||
all_tasks.extend(tasks)
|
||||
|
||||
results = await asyncio.gather(*all_tasks)
|
||||
assert len(results) == num_requests_per_node * len(clients)
|
||||
assert all(results), "Not all streaming requests completed successfully."
|
||||
|
||||
_, server_args = servers[0]
|
||||
api_server_count = (
|
||||
server_args.count('--api-server-count')
|
||||
and server_args[server_args.index('--api-server-count') + 1] or 1)
|
||||
print(f"Successfully completed hybrid LB streaming test with "
|
||||
f"{len(clients)} nodes ({DP_SIZE_LOCAL} DP ranks each, "
|
||||
f"API server count: {api_server_count})")
|
||||
|
||||
# Check request balancing within each node
|
||||
for i, (server, _) in enumerate(servers):
|
||||
print(f"Checking streaming request balancing for node {i}")
|
||||
check_request_balancing(server, DP_SIZE_LOCAL)
|
@ -460,11 +460,16 @@ def test_load_model_weights_inplace(dist_init, model_runner, model_runner_2):
        {"load_config": {
            "load_format": original_load_format
        }})
    model_runner_2.load_model()  # Load real weights inplace
    model_runner_2.reload_weights()  # Load real weights inplace
    assert str(model_runner.get_model().state_dict()) == str(
        model_runner_2.get_model().state_dict())


def test_reload_weights_before_load_model(model_runner):
    with pytest.raises(AssertionError):
        model_runner.reload_weights()


def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
    torch.set_default_dtype(torch.float16)
    layer_0 = "model.layers.0.self_attn.attn"
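The hunk above swaps a second load_model() call for the new reload_weights() helper and adds a guard test; a minimal sketch of the intended call order, assuming the v1 GPU model runner API exercised by these tests:

    # Minimal sketch (assumption based on the tests above, not part of this diff).
    model_runner.load_model()      # build the model and load weights once
    model_runner.reload_weights()  # later, refresh weights in place
    # Calling reload_weights() before load_model() raises an AssertionError.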
@ -31,6 +31,5 @@ run_mypy vllm/inputs
run_mypy vllm/lora
run_mypy vllm/model_executor
run_mypy vllm/plugins
run_mypy vllm/prompt_adapter
run_mypy vllm/worker
run_mypy vllm/v1
@ -143,6 +143,8 @@ class Attention(nn.Module):
        # the backends)
        if envs.VLLM_USE_V1:
            self.use_irope = extra_impl_args.pop("use_irope", False)
        else:
            self.use_irope = extra_impl_args.get("use_irope", False)

        quant_method = quant_config.get_quant_method(
            self, prefix=prefix) if quant_config else None
@ -177,7 +179,6 @@ class Attention(nn.Module):
            kv_sharing_target_layer_name, **extra_impl_args)
        self.backend = backend_name_to_enum(attn_backend.get_name())
        self.dtype = dtype
        self.use_irope = extra_impl_args.get("use_irope", False)

        # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
        # torch.compile works by registering the attention as one giant
@ -423,6 +423,12 @@ class InductorAdaptor(CompilerInterface):
        if is_torch_equal_or_newer("2.6"):
            stack.enter_context(
                torch._inductor.config.patch(fx_graph_remote_cache=False))
            # InductorAdaptor (unfortunately) requires AOTAutogradCache
            # to be turned off to run. It will fail to acquire the hash_str
            # and error if not.
            # StandaloneInductorAdaptor (PyTorch 2.8+) fixes this problem.
            stack.enter_context(
                torch._functorch.config.patch(enable_autograd_cache=False))
            stack.enter_context(
                torch._functorch.config.patch(
                    enable_remote_autograd_cache=False))
@ -20,9 +20,38 @@ from .monitor import start_monitoring_torch_compile

logger = init_logger(__name__)

IGNORE_COMPILE_KEY = "_ignore_compile_vllm"

_T = TypeVar("_T", bound=type[nn.Module])


def ignore_torch_compile(cls: _T) -> _T:
"""
A decorator to ignore support_torch_compile decorator
on the class. This is useful when a parent class has
a support_torch_compile decorator, but we don't want to
compile the class `cls` that inherits the parent class.
This only ignores compiling the forward of the class the
decorator is applied to.

If the parent has ignore_torch_compile but the child has
support_torch_compile, the child will still be compiled.

If the class has one or more submodules
that have support_torch_compile decorator applied, compile will
not be ignored for those submodules.
"""
setattr(cls, IGNORE_COMPILE_KEY, True)
return cls


def _should_ignore_torch_compile(cls) -> bool:
"""
Check if the class should be ignored for torch.compile.
"""
return getattr(cls, IGNORE_COMPILE_KEY, False)
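Not part of the diff: a minimal sketch of how the new `ignore_torch_compile` decorator is meant to compose with the existing `support_torch_compile` decorator, following the docstring above. The class names are invented, the import path is assumed from the hunk context (`vllm/compilation/decorators.py`), and the bare-decorator usage is an assumption based on the overloads shown in the next hunk.

```python
import torch
import torch.nn as nn

# Assumed import path for the decorator pair added in this file.
from vllm.compilation.decorators import (ignore_torch_compile,
                                         support_torch_compile)


@support_torch_compile
class CompiledBlock(nn.Module):
    """Parent opts in to torch.compile via the existing decorator."""

    def __init__(self, *, vllm_config, prefix: str = "", **kwargs):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1


@ignore_torch_compile
class EagerBlock(CompiledBlock):
    """Child inherits the parent's wrapper, but IGNORE_COMPILE_KEY is set on
    this class, so _should_ignore_torch_compile() keeps its forward eager."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 2
```

A grandchild that re-applies `@support_torch_compile` would be compiled again, since the decorator resets `IGNORE_COMPILE_KEY` to `False` on the class it wraps (see the `setattr` in the `_support_torch_compile` hunk below).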
@overload
def support_torch_compile(
*,
@ -148,6 +177,8 @@ def _support_torch_compile(

old_init = cls.__init__

setattr(cls, IGNORE_COMPILE_KEY, False)

def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
self.vllm_config = vllm_config
@ -156,9 +187,11 @@ def _support_torch_compile(
self.do_not_compile = \
vllm_config.compilation_config.level in [
CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS
] or not supports_dynamo()
] or not supports_dynamo() or _should_ignore_torch_compile(
self.__class__)
if self.do_not_compile:
return

compilation_counter.num_models_seen += 1
TorchCompileWrapperWithCustomDispatcher.__init__(
self, compilation_level=vllm_config.compilation_config.level)
@ -651,6 +651,8 @@ class ModelConfig:
self.original_max_model_len = self.max_model_len
self.max_model_len = self.get_and_verify_max_len(self.max_model_len)
self.multimodal_config = self._init_multimodal_config()
self.model_supports_multimodal_raw_input = (
self.registry.supports_multimodal_raw_input(self.architectures))
if not self.skip_tokenizer_init:
self._verify_tokenizer_mode()

@ -1243,10 +1245,10 @@ class ModelConfig:
return self.get_hf_config_sliding_window()

def get_vocab_size(self) -> int:
return self.hf_text_config.vocab_size
return getattr(self.hf_text_config, "vocab_size", 0)

def get_hidden_size(self) -> int:
return self.hf_text_config.hidden_size
return getattr(self.hf_text_config, "hidden_size", 0)

@property
def is_deepseek_mla(self) -> bool:
@ -1906,8 +1908,16 @@ class ParallelConfig:
"""Backend to use for data parallel, either "mp" or "ray"."""
data_parallel_external_lb: bool = False
"""Whether to use "external" DP LB mode. Applies only to online serving
and when data_parallel_size > 0. Set implicitly when
data_parallel_rank is provided explicitly to vllm serve."""
and when data_parallel_size > 0. This is useful for a "one-pod-per-rank"
wide-EP setup in Kubernetes. Set implicitly when --data-parallel-rank
is provided explicitly to vllm serve."""
data_parallel_hybrid_lb: bool = False
"""Whether to use "hybrid" DP LB mode. Applies only to online serving
and when data_parallel_size > 0. Enables running an AsyncLLM
and API server on a "per-node" basis where vLLM load balances
between local data parallel ranks, but an external LB balances
between vLLM nodes/replicas. Set explicitly in conjunction with
--data-parallel-start-rank."""
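Not part of the diff: a self-contained sketch of the mode-selection rules that `create_engine_config()` applies further down in this diff (see the `arg_utils.py` hunk around `data_parallel_external_lb = self.data_parallel_rank is not None`). The helper function and its return tuple are illustrative only; the field and flag names follow the diff.

```python
def infer_dp_lb_mode(dp_size, dp_rank=None, dp_size_local=None,
                     dp_start_rank=None, headless=False, hybrid_lb=False):
    """Return (external_lb, hybrid_lb, size_local, start_rank)."""
    external_lb = dp_rank is not None
    if external_lb:
        # --data-parallel-rank given: one engine per rank, pure external LB.
        size_local, hybrid_lb = 1, False
    elif dp_size_local is not None:
        size_local = dp_size_local
        if dp_start_rank and not headless:
            hybrid_lb = True          # per-node API server + external LB
        if hybrid_lb and size_local == 1:
            external_lb, hybrid_lb = True, False
        if size_local == dp_size:
            hybrid_lb = False         # single node: plain internal LB
        dp_rank = dp_start_rank or 0
    else:
        assert not hybrid_lb
        size_local = dp_size
    return external_lb, hybrid_lb, size_local, dp_rank or 0


# Node 1 of a two-node, 8-rank wide-EP deployment (4 ranks per node):
print(infer_dp_lb_mode(8, dp_size_local=4, dp_start_rank=4))
# -> (False, True, 4, 4): hybrid LB, this node serves ranks 4-7.

# The rank-0 node passes --data-parallel-hybrid-lb explicitly instead:
print(infer_dp_lb_mode(8, dp_size_local=4, hybrid_lb=True))
# -> (False, True, 4, 0)
```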
enable_expert_parallel: bool = False
"""Use expert parallelism instead of tensor parallelism for MoE layers."""
enable_eplb: bool = False
@ -3141,59 +3151,6 @@ class LoRAConfig:
|
||||
self.lora_dtype = getattr(torch, self.lora_dtype)
|
||||
|
||||
|
||||
@config
|
||||
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
|
||||
class PromptAdapterConfig:
|
||||
"""Configuration for PromptAdapters."""
|
||||
|
||||
max_prompt_adapters: int = 1
|
||||
"""Max number of PromptAdapters in a batch."""
|
||||
max_prompt_adapter_token: int = 0
|
||||
"""Max number of PromptAdapters tokens."""
|
||||
max_cpu_prompt_adapters: Optional[int] = None
|
||||
"""Maximum number of PromptAdapters to store in CPU memory. Must be >= than
|
||||
`max_prompt_adapters`."""
|
||||
prompt_adapter_dtype: Union[torch.dtype, str] = "auto"
|
||||
"""Data type for PromptAdapter. If auto, will default to base model dtype.
|
||||
"""
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
WARNING: Whenever a new field is added to this config,
|
||||
ensure that it is included in the factors list if
|
||||
it affects the computation graph.
|
||||
|
||||
Provide a hash that uniquely identifies all the configs
|
||||
that affect the structure of the computation
|
||||
graph from input ids/embeddings to the final hidden states,
|
||||
excluding anything before input ids/embeddings and after
|
||||
the final hidden states.
|
||||
"""
|
||||
# no factors to consider.
|
||||
# this config will not affect the computation graph.
|
||||
factors: list[Any] = []
|
||||
hash_str = hashlib.md5(str(factors).encode(),
|
||||
usedforsecurity=False).hexdigest()
|
||||
return hash_str
|
||||
|
||||
def __post_init__(self):
|
||||
|
||||
if self.max_prompt_adapters < 1:
|
||||
raise ValueError(f"max_prompt_adapters "
|
||||
f"({self.max_prompt_adapters}) must be >= 1.")
|
||||
if self.max_prompt_adapter_token == 0:
|
||||
raise ValueError("max_prompt_adapter_token must be set.")
|
||||
if self.max_cpu_prompt_adapters is None:
|
||||
self.max_cpu_prompt_adapters = self.max_prompt_adapters
|
||||
|
||||
def verify_with_model_config(self, model_config: ModelConfig):
|
||||
if self.prompt_adapter_dtype == "auto":
|
||||
self.prompt_adapter_dtype = model_config.dtype
|
||||
elif isinstance(self.prompt_adapter_dtype, str):
|
||||
self.prompt_adapter_dtype = getattr(torch,
|
||||
self.prompt_adapter_dtype)
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class MultiModalConfig:
|
||||
@ -4400,8 +4357,6 @@ class VllmConfig:
|
||||
"""Decoding configuration."""
|
||||
observability_config: Optional[ObservabilityConfig] = None
|
||||
"""Observability configuration."""
|
||||
prompt_adapter_config: Optional[PromptAdapterConfig] = None
|
||||
"""Prompt adapter configuration."""
|
||||
quant_config: Optional[QuantizationConfig] = None
|
||||
"""Quantization configuration."""
|
||||
compilation_config: CompilationConfig = field(
|
||||
@ -4498,10 +4453,6 @@ class VllmConfig:
|
||||
vllm_factors.append(self.observability_config.compute_hash())
|
||||
else:
|
||||
vllm_factors.append("None")
|
||||
if self.prompt_adapter_config:
|
||||
vllm_factors.append(self.prompt_adapter_config.compute_hash())
|
||||
else:
|
||||
vllm_factors.append("None")
|
||||
if self.quant_config:
|
||||
pass # should be captured by model_config.quantization
|
||||
if self.compilation_config:
|
||||
@ -4609,9 +4560,6 @@ class VllmConfig:
|
||||
if self.lora_config is not None:
|
||||
self.lora_config.verify_with_cache_config(self.cache_config)
|
||||
self.lora_config.verify_with_model_config(self.model_config)
|
||||
if self.prompt_adapter_config is not None:
|
||||
self.prompt_adapter_config.verify_with_model_config(
|
||||
self.model_config)
|
||||
|
||||
if self.quant_config is None and self.model_config is not None:
|
||||
self.quant_config = VllmConfig._get_quantization_config(
|
||||
|
@ -15,7 +15,6 @@ from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
|
||||
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
|
||||
SequenceGroupBase, SequenceGroupMetadata,
|
||||
SequenceGroupMetadataDelta, SequenceStage,
|
||||
@ -165,8 +164,6 @@ class SchedulerOutputs:
|
||||
if self.num_loras > 0:
|
||||
self._sort_by_lora_ids()
|
||||
|
||||
self.num_prompt_adapters: int = len(self.prompt_adapter_requests)
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
# NOTE: We do not consider the ignored sequence groups.
|
||||
return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
|
||||
@ -194,14 +191,6 @@ class SchedulerOutputs:
|
||||
if g.seq_group.lora_request is not None
|
||||
}
|
||||
|
||||
@property
|
||||
def prompt_adapter_requests(self) -> Set[PromptAdapterRequest]:
|
||||
return {
|
||||
g.seq_group.prompt_adapter_request
|
||||
for g in self.scheduled_seq_groups
|
||||
if g.seq_group.prompt_adapter_request is not None
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class SchedulerRunningOutputs:
|
||||
@ -1648,7 +1637,6 @@ class Scheduler:
|
||||
multi_modal_placeholders=(
|
||||
seq_group.multi_modal_placeholders
|
||||
if scheduler_outputs.num_prefill_groups > 0 else None),
|
||||
prompt_adapter_request=seq_group.prompt_adapter_request,
|
||||
)
|
||||
else:
|
||||
# When SPMD mode is enabled, we only send delta data except for
|
||||
|
@ -30,9 +30,9 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
|
||||
LogprobsMode, LoRAConfig, ModelConfig, ModelDType,
|
||||
ModelImpl, MultiModalConfig, ObservabilityConfig,
|
||||
ParallelConfig, PoolerConfig, PrefixCachingHashAlgo,
|
||||
PromptAdapterConfig, SchedulerConfig, SchedulerPolicy,
|
||||
SpeculativeConfig, TaskOption, TokenizerMode,
|
||||
VllmConfig, get_attr_docs, get_field)
|
||||
SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
|
||||
TaskOption, TokenizerMode, VllmConfig, get_attr_docs,
|
||||
get_field)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import CpuArchEnum, current_platform
|
||||
from vllm.plugins import load_general_plugins
|
||||
@ -295,9 +295,11 @@ class EngineArgs:
|
||||
tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
|
||||
data_parallel_size: int = ParallelConfig.data_parallel_size
|
||||
data_parallel_rank: Optional[int] = None
|
||||
data_parallel_start_rank: Optional[int] = None
|
||||
data_parallel_size_local: Optional[int] = None
|
||||
data_parallel_address: Optional[str] = None
|
||||
data_parallel_rpc_port: Optional[int] = None
|
||||
data_parallel_hybrid_lb: bool = False
|
||||
data_parallel_backend: str = ParallelConfig.data_parallel_backend
|
||||
enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
|
||||
enable_eplb: bool = ParallelConfig.enable_eplb
|
||||
@ -358,11 +360,6 @@ class EngineArgs:
|
||||
max_cpu_loras: Optional[int] = LoRAConfig.max_cpu_loras
|
||||
lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype
|
||||
lora_extra_vocab_size: int = LoRAConfig.lora_extra_vocab_size
|
||||
# PromptAdapter fields
|
||||
enable_prompt_adapter: bool = False
|
||||
max_prompt_adapters: int = PromptAdapterConfig.max_prompt_adapters
|
||||
max_prompt_adapter_token: int = \
|
||||
PromptAdapterConfig.max_prompt_adapter_token
|
||||
|
||||
num_scheduler_steps: int = SchedulerConfig.num_scheduler_steps
|
||||
multi_step_stream_outputs: bool = SchedulerConfig.multi_step_stream_outputs
|
||||
@ -437,6 +434,8 @@ class EngineArgs:
|
||||
ParallelConfig.enable_multimodal_encoder_data_parallel
|
||||
|
||||
async_scheduling: bool = SchedulerConfig.async_scheduling
|
||||
# DEPRECATED
|
||||
enable_prompt_adapter: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
# support `EngineArgs(compilation_config={...})`
|
||||
@ -607,6 +606,11 @@ class EngineArgs:
|
||||
type=int,
|
||||
help='Data parallel rank of this instance. '
|
||||
'When set, enables external load balancer mode.')
|
||||
parallel_group.add_argument('--data-parallel-start-rank',
|
||||
'-dpr',
|
||||
type=int,
|
||||
help='Starting data parallel rank '
|
||||
'for secondary nodes.')
|
||||
parallel_group.add_argument('--data-parallel-size-local',
|
||||
'-dpl',
|
||||
type=int,
|
||||
@ -628,6 +632,9 @@ class EngineArgs:
|
||||
default='mp',
|
||||
help='Backend for data parallel, either '
|
||||
'"mp" or "ray".')
|
||||
parallel_group.add_argument(
|
||||
"--data-parallel-hybrid-lb",
|
||||
**parallel_kwargs["data_parallel_hybrid_lb"])
|
||||
parallel_group.add_argument(
|
||||
"--enable-expert-parallel",
|
||||
**parallel_kwargs["enable_expert_parallel"])
|
||||
@ -729,23 +736,6 @@ class EngineArgs:
|
||||
lora_group.add_argument("--default-mm-loras",
|
||||
**lora_kwargs["default_mm_loras"])
|
||||
|
||||
# PromptAdapter related configs
|
||||
prompt_adapter_kwargs = get_kwargs(PromptAdapterConfig)
|
||||
prompt_adapter_group = parser.add_argument_group(
|
||||
title="PromptAdapterConfig",
|
||||
description=PromptAdapterConfig.__doc__,
|
||||
)
|
||||
prompt_adapter_group.add_argument(
|
||||
"--enable-prompt-adapter",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="If True, enable handling of PromptAdapters.")
|
||||
prompt_adapter_group.add_argument(
|
||||
"--max-prompt-adapters",
|
||||
**prompt_adapter_kwargs["max_prompt_adapters"])
|
||||
prompt_adapter_group.add_argument(
|
||||
"--max-prompt-adapter-token",
|
||||
**prompt_adapter_kwargs["max_prompt_adapter_token"])
|
||||
|
||||
# Speculative arguments
|
||||
speculative_group = parser.add_argument_group(
|
||||
title="SpeculativeConfig",
|
||||
@ -850,6 +840,12 @@ class EngineArgs:
|
||||
parser.add_argument('--disable-log-stats',
|
||||
action='store_true',
|
||||
help='Disable logging statistics.')
|
||||
parser.add_argument('--enable-prompt-adapter',
|
||||
action='store_true',
|
||||
deprecated=True,
|
||||
help='[DEPRECATED] Prompt adapter has been '
|
||||
'removed. Setting this flag to True or False'
|
||||
' has no effect on vLLM behavior.')
|
||||
|
||||
return parser
|
||||
|
||||
@ -986,6 +982,7 @@ class EngineArgs:
|
||||
def create_engine_config(
|
||||
self,
|
||||
usage_context: Optional[UsageContext] = None,
|
||||
headless: bool = False,
|
||||
) -> VllmConfig:
|
||||
"""
|
||||
Create the VllmConfig.
|
||||
@ -1074,15 +1071,41 @@ class EngineArgs:
|
||||
# but we should not do this here.
|
||||
placement_group = ray.util.get_current_placement_group()
|
||||
|
||||
assert not headless or not self.data_parallel_hybrid_lb, (
|
||||
"data_parallel_hybrid_lb is not applicable in "
|
||||
"headless mode")
|
||||
|
||||
data_parallel_external_lb = self.data_parallel_rank is not None
|
||||
# Local DP rank = 1, use pure-external LB.
|
||||
if data_parallel_external_lb:
|
||||
assert self.data_parallel_size_local in (1, None), (
|
||||
"data_parallel_size_local must be 1 when data_parallel_rank "
|
||||
"is set")
|
||||
data_parallel_size_local = 1
|
||||
# Use full external lb if we have local_size of 1.
|
||||
self.data_parallel_hybrid_lb = False
|
||||
elif self.data_parallel_size_local is not None:
|
||||
data_parallel_size_local = self.data_parallel_size_local
|
||||
|
||||
if self.data_parallel_start_rank and not headless:
|
||||
# Infer hybrid LB mode.
|
||||
self.data_parallel_hybrid_lb = True
|
||||
|
||||
if self.data_parallel_hybrid_lb and data_parallel_size_local == 1:
|
||||
# Use full external lb if we have local_size of 1.
|
||||
data_parallel_external_lb = True
|
||||
self.data_parallel_hybrid_lb = False
|
||||
|
||||
if data_parallel_size_local == self.data_parallel_size:
|
||||
# Disable hybrid LB mode if set for a single node
|
||||
self.data_parallel_hybrid_lb = False
|
||||
|
||||
self.data_parallel_rank = self.data_parallel_start_rank or 0
|
||||
else:
|
||||
assert not self.data_parallel_hybrid_lb, (
|
||||
"data_parallel_size_local must be set to use "
|
||||
"data_parallel_hybrid_lb.")
|
||||
|
||||
# Local DP size defaults to global DP size if not set.
|
||||
data_parallel_size_local = self.data_parallel_size
|
||||
|
||||
@ -1139,6 +1162,7 @@ class EngineArgs:
|
||||
data_parallel_master_ip=data_parallel_address,
|
||||
data_parallel_rpc_port=data_parallel_rpc_port,
|
||||
data_parallel_backend=self.data_parallel_backend,
|
||||
data_parallel_hybrid_lb=self.data_parallel_hybrid_lb,
|
||||
enable_expert_parallel=self.enable_expert_parallel,
|
||||
enable_eplb=self.enable_eplb,
|
||||
num_redundant_experts=self.num_redundant_experts,
|
||||
@ -1234,11 +1258,6 @@ class EngineArgs:
|
||||
|
||||
load_config = self.create_load_config()
|
||||
|
||||
prompt_adapter_config = PromptAdapterConfig(
|
||||
max_prompt_adapters=self.max_prompt_adapters,
|
||||
max_prompt_adapter_token=self.max_prompt_adapter_token) \
|
||||
if self.enable_prompt_adapter else None
|
||||
|
||||
decoding_config = DecodingConfig(
|
||||
backend=self.guided_decoding_backend,
|
||||
disable_fallback=self.guided_decoding_disable_fallback,
|
||||
@ -1266,7 +1285,6 @@ class EngineArgs:
|
||||
load_config=load_config,
|
||||
decoding_config=decoding_config,
|
||||
observability_config=observability_config,
|
||||
prompt_adapter_config=prompt_adapter_config,
|
||||
compilation_config=self.compilation_config,
|
||||
kv_transfer_config=self.kv_transfer_config,
|
||||
kv_events_config=self.kv_events_config,
|
||||
@ -1342,12 +1360,6 @@ class EngineArgs:
|
||||
recommend_to_remove=False)
|
||||
return False
|
||||
|
||||
# No Prompt Adapter so far.
|
||||
if self.enable_prompt_adapter:
|
||||
_raise_or_fallback(feature_name="--enable-prompt-adapter",
|
||||
recommend_to_remove=False)
|
||||
return False
|
||||
|
||||
# No text embedding inputs so far.
|
||||
if self.enable_prompt_embeds:
|
||||
_raise_or_fallback(feature_name="--enable-prompt-embeds",
|
||||
@ -1469,7 +1481,6 @@ class EngineArgs:
|
||||
|
||||
if (is_gpu and not use_sliding_window and not use_spec_decode
|
||||
and not self.enable_lora
|
||||
and not self.enable_prompt_adapter
|
||||
and model_config.runner_type != "pooling"):
|
||||
self.enable_chunked_prefill = True
|
||||
logger.warning(
|
||||
|
@ -29,7 +29,6 @@ from vllm.model_executor.guided_decoding import (
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.outputs import PoolingRequestOutput, RequestOutput
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
@ -435,7 +434,6 @@ class _AsyncLLMEngine(LLMEngine):
|
||||
arrival_time: Optional[float] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
priority: int = 0,
|
||||
data_parallel_rank: Optional[int] = None,
|
||||
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
||||
@ -468,7 +466,6 @@ class _AsyncLLMEngine(LLMEngine):
|
||||
processed_inputs = await self.input_preprocessor.preprocess_async(
|
||||
prompt,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
)
|
||||
|
||||
@ -491,7 +488,6 @@ class _AsyncLLMEngine(LLMEngine):
|
||||
params=params,
|
||||
arrival_time=arrival_time,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=priority,
|
||||
)
|
||||
@ -861,7 +857,6 @@ class AsyncLLMEngine(EngineClient):
|
||||
arrival_time: Optional[float] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
priority: int = 0,
|
||||
data_parallel_rank: Optional[int] = None,
|
||||
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
||||
@ -889,7 +884,6 @@ class AsyncLLMEngine(EngineClient):
|
||||
arrival_time=arrival_time or time.time(),
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
priority=priority,
|
||||
data_parallel_rank=data_parallel_rank,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
@ -904,7 +898,6 @@ class AsyncLLMEngine(EngineClient):
|
||||
request_id: str,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
priority: int = 0,
|
||||
data_parallel_rank: Optional[int] = None,
|
||||
) -> AsyncGenerator[RequestOutput, None]:
|
||||
@ -922,8 +915,6 @@ class AsyncLLMEngine(EngineClient):
|
||||
request_id: The unique id of the request.
|
||||
lora_request: LoRA request to use for generation, if any.
|
||||
trace_headers: OpenTelemetry trace headers.
|
||||
prompt_adapter_request: Prompt Adapter request to use
|
||||
for generation, if any.
|
||||
priority: The priority of the request.
|
||||
Only applicable with priority scheduling.
|
||||
data_parallel_rank: The (global) data parallel rank that must
|
||||
@ -983,7 +974,6 @@ class AsyncLLMEngine(EngineClient):
|
||||
sampling_params,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
priority=priority,
|
||||
data_parallel_rank=data_parallel_rank,
|
||||
):
|
||||
|
@ -44,7 +44,6 @@ from vllm.multimodal.processing import EncDecMultiModalProcessor
|
||||
from vllm.outputs import (PoolingRequestOutput, RequestOutput,
|
||||
RequestOutputFactory)
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import RequestOutputKind, SamplingParams
|
||||
from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup,
|
||||
PoolingSequenceGroupOutput, Sequence, SequenceGroup,
|
||||
@ -223,7 +222,6 @@ class LLMEngine:
|
||||
self.load_config = vllm_config.load_config
|
||||
self.decoding_config = vllm_config.decoding_config or DecodingConfig( # noqa
|
||||
)
|
||||
self.prompt_adapter_config = vllm_config.prompt_adapter_config # noqa
|
||||
self.observability_config = vllm_config.observability_config or ObservabilityConfig( # noqa
|
||||
)
|
||||
|
||||
@ -238,14 +236,14 @@ class LLMEngine:
|
||||
self.log_stats = log_stats
|
||||
self.use_cached_outputs = use_cached_outputs
|
||||
|
||||
if not self.model_config.skip_tokenizer_init:
|
||||
self.tokenizer = self._init_tokenizer()
|
||||
self.detokenizer = Detokenizer(self.tokenizer)
|
||||
tokenizer_group = self.get_tokenizer_group()
|
||||
else:
|
||||
if self.model_config.skip_tokenizer_init:
|
||||
self.tokenizer = None
|
||||
self.detokenizer = None
|
||||
tokenizer_group = None
|
||||
else:
|
||||
self.tokenizer = self._init_tokenizer()
|
||||
self.detokenizer = Detokenizer(self.tokenizer)
|
||||
tokenizer_group = self.get_tokenizer_group()
|
||||
|
||||
# Ensure that the function doesn't contain a reference to self,
|
||||
# to avoid engine GC issues
|
||||
@ -294,8 +292,6 @@ class LLMEngine:
|
||||
# Feature flags
|
||||
"enable_lora":
|
||||
bool(self.lora_config),
|
||||
"enable_prompt_adapter":
|
||||
bool(self.prompt_adapter_config),
|
||||
"enable_prefix_caching":
|
||||
self.cache_config.enable_prefix_caching,
|
||||
"enforce_eager":
|
||||
@ -542,9 +538,6 @@ class LLMEngine:
|
||||
self.lora_config.verify_with_model_config(self.model_config)
|
||||
self.lora_config.verify_with_scheduler_config(
|
||||
self.scheduler_config)
|
||||
if self.prompt_adapter_config:
|
||||
self.prompt_adapter_config.verify_with_model_config(
|
||||
self.model_config)
|
||||
|
||||
def _add_processed_request(
|
||||
self,
|
||||
@ -553,7 +546,6 @@ class LLMEngine:
|
||||
params: Union[SamplingParams, PoolingParams],
|
||||
arrival_time: float,
|
||||
lora_request: Optional[LoRARequest],
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest],
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
priority: int = 0,
|
||||
) -> Optional[SequenceGroup]:
|
||||
@ -569,7 +561,6 @@ class LLMEngine:
|
||||
arrival_time=arrival_time,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
priority=priority,
|
||||
)
|
||||
return None
|
||||
@ -583,11 +574,10 @@ class LLMEngine:
|
||||
encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
|
||||
|
||||
seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id,
|
||||
lora_request, prompt_adapter_request)
|
||||
lora_request)
|
||||
|
||||
encoder_seq = (None if encoder_inputs is None else Sequence(
|
||||
seq_id, encoder_inputs, block_size, eos_token_id, lora_request,
|
||||
prompt_adapter_request))
|
||||
seq_id, encoder_inputs, block_size, eos_token_id, lora_request))
|
||||
|
||||
# Create a SequenceGroup based on SamplingParams or PoolingParams
|
||||
if isinstance(params, SamplingParams):
|
||||
@ -598,7 +588,6 @@ class LLMEngine:
|
||||
arrival_time=arrival_time,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
encoder_seq=encoder_seq,
|
||||
priority=priority)
|
||||
elif isinstance(params, PoolingParams):
|
||||
@ -608,7 +597,6 @@ class LLMEngine:
|
||||
params,
|
||||
arrival_time=arrival_time,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
encoder_seq=encoder_seq,
|
||||
priority=priority)
|
||||
else:
|
||||
@ -637,7 +625,6 @@ class LLMEngine:
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
priority: int = 0,
|
||||
) -> None:
|
||||
"""Add a request to the engine's request pool.
|
||||
@ -658,7 +645,6 @@ class LLMEngine:
|
||||
the current monotonic time.
|
||||
lora_request: The LoRA request to add.
|
||||
trace_headers: OpenTelemetry trace headers.
|
||||
prompt_adapter_request: The prompt adapter request to add.
|
||||
priority: The priority of the request.
|
||||
Only applicable with priority scheduling.
|
||||
|
||||
@ -719,7 +705,6 @@ class LLMEngine:
|
||||
prompt,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
)
|
||||
|
||||
self._add_processed_request(
|
||||
@ -728,7 +713,6 @@ class LLMEngine:
|
||||
params=params,
|
||||
arrival_time=arrival_time,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=priority,
|
||||
)
|
||||
@ -741,7 +725,6 @@ class LLMEngine:
|
||||
arrival_time: float,
|
||||
lora_request: Optional[LoRARequest],
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
encoder_seq: Optional[Sequence] = None,
|
||||
priority: int = 0,
|
||||
) -> SequenceGroup:
|
||||
@ -769,17 +752,15 @@ class LLMEngine:
|
||||
if self.vllm_config.speculative_config is not None:
|
||||
draft_size = \
|
||||
self.vllm_config.speculative_config.num_speculative_tokens + 1
|
||||
seq_group = SequenceGroup(
|
||||
request_id=request_id,
|
||||
seqs=[seq],
|
||||
arrival_time=arrival_time,
|
||||
sampling_params=sampling_params,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
encoder_seq=encoder_seq,
|
||||
priority=priority,
|
||||
draft_size=draft_size)
|
||||
seq_group = SequenceGroup(request_id=request_id,
|
||||
seqs=[seq],
|
||||
arrival_time=arrival_time,
|
||||
sampling_params=sampling_params,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
encoder_seq=encoder_seq,
|
||||
priority=priority,
|
||||
draft_size=draft_size)
|
||||
|
||||
return seq_group
|
||||
|
||||
@ -790,7 +771,6 @@ class LLMEngine:
|
||||
pooling_params: PoolingParams,
|
||||
arrival_time: float,
|
||||
lora_request: Optional[LoRARequest],
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest],
|
||||
encoder_seq: Optional[Sequence] = None,
|
||||
priority: int = 0,
|
||||
) -> SequenceGroup:
|
||||
@ -798,15 +778,13 @@ class LLMEngine:
|
||||
# Defensive copy of PoolingParams, which are used by the pooler
|
||||
pooling_params = pooling_params.clone()
|
||||
# Create the sequence group.
|
||||
seq_group = SequenceGroup(
|
||||
request_id=request_id,
|
||||
seqs=[seq],
|
||||
arrival_time=arrival_time,
|
||||
lora_request=lora_request,
|
||||
pooling_params=pooling_params,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
encoder_seq=encoder_seq,
|
||||
priority=priority)
|
||||
seq_group = SequenceGroup(request_id=request_id,
|
||||
seqs=[seq],
|
||||
arrival_time=arrival_time,
|
||||
lora_request=lora_request,
|
||||
pooling_params=pooling_params,
|
||||
encoder_seq=encoder_seq,
|
||||
priority=priority)
|
||||
return seq_group
|
||||
|
||||
def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
|
||||
@ -1834,16 +1812,6 @@ class LLMEngine:
|
||||
def pin_lora(self, lora_id: int) -> bool:
|
||||
return self.model_executor.pin_lora(lora_id)
|
||||
|
||||
def add_prompt_adapter(
|
||||
self, prompt_adapter_request: PromptAdapterRequest) -> bool:
|
||||
return self.model_executor.add_prompt_adapter(prompt_adapter_request)
|
||||
|
||||
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
|
||||
return self.model_executor.remove_prompt_adapter(prompt_adapter_id)
|
||||
|
||||
def list_prompt_adapters(self) -> List[int]:
|
||||
return self.model_executor.list_prompt_adapters()
|
||||
|
||||
def start_profile(self) -> None:
|
||||
self.model_executor.start_profile()
|
||||
|
||||
|
@ -10,7 +10,6 @@ from vllm import PoolingParams
|
||||
from vllm.inputs import PromptType
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import Device
|
||||
|
||||
@ -33,7 +32,6 @@ class RPCProcessRequest:
|
||||
request_id: str
|
||||
lora_request: Optional[LoRARequest] = None
|
||||
trace_headers: Optional[Mapping[str, str]] = None
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None
|
||||
priority: int = 0
|
||||
|
||||
def __init__(
|
||||
@ -43,7 +41,6 @@ class RPCProcessRequest:
|
||||
request_id: str,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
priority: int = 0,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@ -53,7 +50,6 @@ class RPCProcessRequest:
|
||||
self.request_id = request_id
|
||||
self.lora_request = lora_request
|
||||
self.trace_headers = trace_headers
|
||||
self.prompt_adapter_request = prompt_adapter_request
|
||||
self.priority = priority
|
||||
|
||||
|
||||
|
@ -45,7 +45,6 @@ from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.outputs import PoolingRequestOutput, RequestOutput
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
|
||||
from vllm.utils import Device
|
||||
@ -448,7 +447,6 @@ class MQLLMEngineClient(EngineClient):
|
||||
request_id: str,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
priority: int = 0,
|
||||
) -> AsyncGenerator[RequestOutput, None]:
|
||||
"""Generate outputs for a request.
|
||||
@ -465,8 +463,6 @@ class MQLLMEngineClient(EngineClient):
|
||||
request_id: The unique id of the request.
|
||||
lora_request: LoRA request to use for generation, if any.
|
||||
trace_headers: OpenTelemetry trace headers.
|
||||
prompt_adapter_request: Prompt Adapter request to use
|
||||
for generation, if any.
|
||||
priority: Priority of the request (lower means earlier handling).
|
||||
Any priority other than 0 will lead to an error if the
|
||||
scheduling policy is not "priority".
|
||||
@ -474,8 +470,7 @@ class MQLLMEngineClient(EngineClient):
|
||||
return cast(
|
||||
AsyncGenerator[RequestOutput, None],
|
||||
self._process_request(prompt, sampling_params, request_id,
|
||||
lora_request, trace_headers,
|
||||
prompt_adapter_request, priority))
|
||||
lora_request, trace_headers, priority))
|
||||
|
||||
def encode(
|
||||
self,
|
||||
@ -521,7 +516,6 @@ class MQLLMEngineClient(EngineClient):
|
||||
request_id: str,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
priority: int = 0,
|
||||
) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
|
||||
PoolingRequestOutput, None]]:
|
||||
@ -575,7 +569,6 @@ class MQLLMEngineClient(EngineClient):
|
||||
request_id=request_id,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
priority=priority,
|
||||
))
|
||||
|
||||
|
@ -304,14 +304,12 @@ class MQLLMEngine:
|
||||
self._send_outputs(rpc_err)
|
||||
|
||||
try:
|
||||
self.engine.add_request(
|
||||
request_id=request_id,
|
||||
prompt=request.prompt,
|
||||
params=request.params,
|
||||
lora_request=request.lora_request,
|
||||
trace_headers=request.trace_headers,
|
||||
prompt_adapter_request=request.prompt_adapter_request,
|
||||
priority=request.priority)
|
||||
self.engine.add_request(request_id=request_id,
|
||||
prompt=request.prompt,
|
||||
params=request.params,
|
||||
lora_request=request.lora_request,
|
||||
trace_headers=request.trace_headers,
|
||||
priority=request.priority)
|
||||
|
||||
if self.log_requests:
|
||||
logger.info("Added request %s.", request.request_id)
|
||||
|
@ -16,7 +16,6 @@ from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import Device, collect_from_async_generator, random_uuid
|
||||
@ -55,7 +54,6 @@ class EngineClient(ABC):
|
||||
request_id: str,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
priority: int = 0,
|
||||
) -> AsyncGenerator[RequestOutput, None]:
|
||||
"""Generate outputs for a request."""
|
||||
|
@ -151,6 +151,27 @@ class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
video_url: Required[str]


class CustomThinkCompletionContentParam(TypedDict, total=False):
"""A Think Completion Content Param that accepts a plain text and a boolean.

Example:
{
"thinking": "I am thinking about the answer",
"closed": True,
"type": "thinking"
}
"""

thinking: Required[str]
"""The thinking content."""

closed: bool
"""Whether the thinking is closed."""

type: Required[Literal["thinking"]]
"""The thinking type."""
ChatCompletionContentPartParam: TypeAlias = Union[
OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam,
ChatCompletionContentPartInputAudioParam,
@ -159,7 +180,8 @@ ChatCompletionContentPartParam: TypeAlias = Union[
CustomChatCompletionContentSimpleImageParam,
ChatCompletionContentPartImageEmbedsParam,
CustomChatCompletionContentSimpleAudioParam,
CustomChatCompletionContentSimpleVideoParam, str]
CustomChatCompletionContentSimpleVideoParam, str,
CustomThinkCompletionContentParam]


class CustomChatCompletionMessageParam(TypedDict, total=False):
@ -938,6 +960,7 @@ _ImageEmbedsParser = partial(cast, ChatCompletionContentPartImageEmbedsParam)
_InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
_PILImageParser = partial(cast, CustomChatCompletionContentPILImageParam)
_ThinkParser = partial(cast, CustomThinkCompletionContentParam)
# Need to validate url objects
_ImageParser = TypeAdapter(ChatCompletionContentPartImageParam).validate_python
_AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam).validate_python
@ -954,6 +977,8 @@ MM_PARSER_MAP: dict[
] = {
"text":
lambda part: _TextParser(part).get("text", None),
"thinking":
lambda part: _ThinkParser(part).get("thinking", None),
"input_text":
lambda part: _TextParser(part).get("text", None),
"input_image":
@ -1100,7 +1125,7 @@ def _parse_chat_message_content_part(
"with empty / unparsable content.", part, part_type)
return None

if part_type in ("text", "input_text", "refusal"):
if part_type in ("text", "input_text", "refusal", "thinking"):
str_content = cast(str, content)
if wrap_dicts:
return {'type': 'text', 'text': str_content}
@ -45,11 +45,6 @@ class ServeSubcommand(CLISubcommand):
if args.headless or args.api_server_count < 1:
run_headless(args)
else:
if args.data_parallel_start_rank:
raise ValueError(
"data_parallel_start_rank is only applicable "
"in headless mode. "
"Add --headless flag to enable headless mode.")
if args.api_server_count > 1:
run_multi_api_server(args)
else:
@ -86,13 +81,14 @@ def run_headless(args: argparse.Namespace):
# Create the EngineConfig.
engine_args = vllm.AsyncEngineArgs.from_cli_args(args)
usage_context = UsageContext.OPENAI_API_SERVER
vllm_config = engine_args.create_engine_config(usage_context=usage_context)
vllm_config = engine_args.create_engine_config(usage_context=usage_context,
headless=True)

if not envs.VLLM_USE_V1:
raise ValueError("Headless mode is only supported for V1")

if engine_args.data_parallel_rank is not None:
raise ValueError("data_parallel_rank is not applicable in "
if engine_args.data_parallel_hybrid_lb:
raise ValueError("data_parallel_hybrid_lb is not applicable in "
"headless mode")

parallel_config = vllm_config.parallel_config
@ -122,7 +118,7 @@ def run_headless(args: argparse.Namespace):
engine_manager = CoreEngineProcManager(
target_fn=EngineCoreProc.run_engine_core,
local_engine_count=local_engine_count,
start_index=args.data_parallel_start_rank,
start_index=vllm_config.parallel_config.data_parallel_rank,
local_start_index=0,
vllm_config=vllm_config,
local_client=False,
@ -169,6 +165,11 @@ def run_multi_api_server(args: argparse.Namespace):
" api_server_count > 1")
model_config.disable_mm_preprocessor_cache = True

if vllm_config.parallel_config.data_parallel_hybrid_lb:
raise NotImplementedError(
"Hybrid load balancing with --api-server-count > 0"
"is not yet supported.")

executor_class = Executor.get_class(vllm_config)
log_stats = not engine_args.disable_log_stats
@ -45,7 +45,6 @@ from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
|
||||
PoolingRequestOutput, RequestOutput,
|
||||
ScoringRequestOutput)
|
||||
from vllm.pooling_params import PoolingParams, PoolingTask
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
|
||||
RequestOutputKind, SamplingParams)
|
||||
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
|
||||
@ -314,7 +313,6 @@ class LLM:
|
||||
*,
|
||||
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
|
||||
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
guided_options_request: Optional[Union[LLMGuidedOptions,
|
||||
GuidedDecodingRequest]] = None,
|
||||
) -> list[RequestOutput]:
|
||||
@ -330,7 +328,6 @@ class LLM:
|
||||
prompt_token_ids: Optional[list[int]] = None,
|
||||
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
|
||||
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
guided_options_request: Optional[Union[LLMGuidedOptions,
|
||||
GuidedDecodingRequest]] = None,
|
||||
) -> list[RequestOutput]:
|
||||
@ -346,7 +343,6 @@ class LLM:
|
||||
prompt_token_ids: Optional[list[list[int]]] = None,
|
||||
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
|
||||
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
guided_options_request: Optional[Union[LLMGuidedOptions,
|
||||
GuidedDecodingRequest]] = None,
|
||||
) -> list[RequestOutput]:
|
||||
@ -363,7 +359,6 @@ class LLM:
|
||||
prompt_token_ids: list[int],
|
||||
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
|
||||
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
guided_options_request: Optional[Union[LLMGuidedOptions,
|
||||
GuidedDecodingRequest]] = None,
|
||||
) -> list[RequestOutput]:
|
||||
@ -380,7 +375,6 @@ class LLM:
|
||||
prompt_token_ids: list[list[int]],
|
||||
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
|
||||
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
guided_options_request: Optional[Union[LLMGuidedOptions,
|
||||
GuidedDecodingRequest]] = None,
|
||||
) -> list[RequestOutput]:
|
||||
@ -395,7 +389,6 @@ class LLM:
|
||||
prompt_token_ids: Union[list[int], list[list[int]]],
|
||||
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
|
||||
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
guided_options_request: Optional[Union[LLMGuidedOptions,
|
||||
GuidedDecodingRequest]] = None,
|
||||
) -> list[RequestOutput]:
|
||||
@ -415,7 +408,6 @@ class LLM:
|
||||
prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
|
||||
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
|
||||
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
guided_options_request: Optional[Union[LLMGuidedOptions,
|
||||
GuidedDecodingRequest]] = None,
|
||||
priority: Optional[list[int]] = None,
|
||||
@ -440,8 +432,6 @@ class LLM:
|
||||
it is used to create the progress bar.
|
||||
If `False`, no progress bar is created.
|
||||
lora_request: LoRA request to use for generation, if any.
|
||||
prompt_adapter_request: Prompt Adapter request to use for
|
||||
generation, if any.
|
||||
priority: The priority of the requests, if any.
|
||||
Only applicable when priority scheduling policy is enabled.
|
||||
|
||||
@ -507,7 +497,6 @@ class LLM:
|
||||
params=sampling_params,
|
||||
use_tqdm=use_tqdm,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
guided_options=guided_options_request,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
priority=priority,
|
||||
@ -963,7 +952,6 @@ class LLM:
|
||||
truncate_prompt_tokens: Optional[int] = None,
|
||||
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
|
||||
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
pooling_task: PoolingTask = "encode",
|
||||
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
||||
) -> list[PoolingRequestOutput]:
|
||||
@ -980,7 +968,6 @@ class LLM:
|
||||
truncate_prompt_tokens: Optional[int] = None,
|
||||
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
|
||||
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
pooling_task: PoolingTask = "encode",
|
||||
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
||||
) -> list[PoolingRequestOutput]:
|
||||
@ -997,7 +984,6 @@ class LLM:
|
||||
truncate_prompt_tokens: Optional[int] = None,
|
||||
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
|
||||
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
pooling_task: PoolingTask = "encode",
|
||||
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
||||
) -> list[PoolingRequestOutput]:
|
||||
@ -1015,7 +1001,6 @@ class LLM:
|
||||
truncate_prompt_tokens: Optional[int] = None,
|
||||
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
|
||||
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
pooling_task: PoolingTask = "encode",
|
||||
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
||||
) -> list[PoolingRequestOutput]:
|
||||
@ -1033,7 +1018,6 @@ class LLM:
|
||||
truncate_prompt_tokens: Optional[int] = None,
|
||||
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
|
||||
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
pooling_task: PoolingTask = "encode",
|
||||
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
||||
) -> list[PoolingRequestOutput]:
|
||||
@ -1049,7 +1033,6 @@ class LLM:
|
||||
truncate_prompt_tokens: Optional[int] = None,
|
||||
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
|
||||
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
pooling_task: PoolingTask = "encode",
|
||||
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
||||
) -> list[PoolingRequestOutput]:
|
||||
@ -1070,7 +1053,6 @@ class LLM:
|
||||
truncate_prompt_tokens: Optional[int] = None,
|
||||
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
|
||||
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
pooling_task: PoolingTask = "encode",
|
||||
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
||||
) -> list[PoolingRequestOutput]:
|
||||
@ -1092,8 +1074,6 @@ class LLM:
|
||||
it is used to create the progress bar.
|
||||
If `False`, no progress bar is created.
|
||||
lora_request: LoRA request to use for generation, if any.
|
||||
prompt_adapter_request: Prompt Adapter request to use for
|
||||
generation, if any.
|
||||
pooling_task: Override the pooling task to use.
|
||||
|
||||
Returns:
|
||||
@ -1150,7 +1130,6 @@ class LLM:
|
||||
use_tqdm=use_tqdm,
|
||||
lora_request=lora_request,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
)
|
||||
|
||||
outputs = self._run_engine(use_tqdm=use_tqdm)
|
||||
@ -1167,7 +1146,6 @@ class LLM:
|
||||
pooling_params: Optional[Union[PoolingParams,
|
||||
Sequence[PoolingParams]]] = None,
|
||||
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
) -> list[EmbeddingRequestOutput]:
|
||||
"""
|
||||
Generate an embedding vector for each prompt.
|
||||
@ -1187,8 +1165,6 @@ class LLM:
|
||||
it is used to create the progress bar.
|
||||
If `False`, no progress bar is created.
|
||||
lora_request: LoRA request to use for generation, if any.
|
||||
prompt_adapter_request: Prompt Adapter request to use for
|
||||
generation, if any.
|
||||
|
||||
Returns:
|
||||
A list of `EmbeddingRequestOutput` objects containing the
|
||||
@ -1205,7 +1181,6 @@ class LLM:
|
||||
use_tqdm=use_tqdm,
|
||||
pooling_params=pooling_params,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
pooling_task="embed",
|
||||
)
|
||||
|
||||
@ -1218,7 +1193,6 @@ class LLM:
|
||||
*,
|
||||
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
|
||||
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
) -> list[ClassificationRequestOutput]:
|
||||
"""
|
||||
Generate class logits for each prompt.
|
||||
@ -1236,8 +1210,6 @@ class LLM:
|
||||
it is used to create the progress bar.
|
||||
If `False`, no progress bar is created.
|
||||
lora_request: LoRA request to use for generation, if any.
|
||||
prompt_adapter_request: Prompt Adapter request to use for
|
||||
generation, if any.
|
||||
|
||||
Returns:
|
||||
A list of `ClassificationRequestOutput` objects containing the
|
||||
@ -1253,7 +1225,6 @@ class LLM:
|
||||
prompts,
|
||||
use_tqdm=use_tqdm,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
pooling_task="classify",
|
||||
)
|
||||
|
||||
@ -1267,7 +1238,6 @@ class LLM:
|
||||
truncate_prompt_tokens: Optional[int] = None,
|
||||
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
|
||||
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
) -> list[ScoringRequestOutput]:
|
||||
|
||||
encoded_output: list[PoolingRequestOutput] = self.encode(
|
||||
@ -1275,7 +1245,6 @@ class LLM:
|
||||
truncate_prompt_tokens=truncate_prompt_tokens,
|
||||
use_tqdm=use_tqdm,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
pooling_task="embed",
|
||||
)
|
||||
|
||||
@ -1303,7 +1272,6 @@ class LLM:
|
||||
truncate_prompt_tokens: Optional[int] = None,
|
||||
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
|
||||
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
) -> list[ScoringRequestOutput]:
|
||||
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
@ -1361,7 +1329,6 @@ class LLM:
|
||||
params=pooling_params,
|
||||
use_tqdm=use_tqdm,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
)
|
||||
|
||||
outputs = self._run_engine(use_tqdm=use_tqdm)
|
||||
@ -1381,7 +1348,6 @@ class LLM:
|
||||
truncate_prompt_tokens: Optional[int] = None,
|
||||
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
|
||||
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
) -> list[ScoringRequestOutput]:
|
||||
"""Generate similarity scores for all pairs `<text,text_pair>` or
|
||||
`<multi-modal data, multi-modal data pair>`.
|
||||
@ -1412,8 +1378,6 @@ class LLM:
|
||||
it is used to create the progress bar.
|
||||
If `False`, no progress bar is created.
|
||||
lora_request: LoRA request to use for generation, if any.
|
||||
prompt_adapter_request: Prompt Adapter request to use for
|
||||
generation, if any.
|
||||
|
||||
Returns:
|
||||
A list of `ScoringRequestOutput` objects containing the
|
||||
@ -1504,8 +1468,7 @@ class LLM:
|
||||
data_2, # type: ignore[arg-type]
|
||||
truncate_prompt_tokens,
|
||||
use_tqdm,
|
||||
lora_request,
|
||||
prompt_adapter_request)
|
||||
lora_request)
|
||||
else:
|
||||
return self._embedding_score(
|
||||
tokenizer,
|
||||
@ -1513,8 +1476,7 @@ class LLM:
|
||||
data_2, # type: ignore[arg-type]
|
||||
truncate_prompt_tokens,
|
||||
use_tqdm,
|
||||
lora_request,
|
||||
prompt_adapter_request)
|
||||
lora_request)
|
||||
|
||||
def start_profile(self) -> None:
|
||||
self.llm_engine.start_profile()
|
||||
@ -1625,7 +1587,6 @@ class LLM:
|
||||
*,
|
||||
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
|
||||
lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest],
|
||||
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
||||
guided_options: Optional[GuidedDecodingRequest] = None,
|
||||
priority: Optional[list[int]] = None,
|
||||
@ -1671,7 +1632,6 @@ class LLM:
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
lora_request=lora_request[i] if isinstance(
|
||||
lora_request, Sequence) else lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
priority=priority[i] if priority else 0,
|
||||
)
|
||||
|
||||
@ -1681,7 +1641,6 @@ class LLM:
|
||||
params: Union[SamplingParams, PoolingParams],
|
||||
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
priority: int = 0,
|
||||
) -> None:
|
||||
request_id = str(next(self.request_counter))
|
||||
@ -1691,7 +1650,6 @@ class LLM:
|
||||
params,
|
||||
lora_request=lora_request,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
priority=priority,
|
||||
)
|
||||
|
||||
|
@ -8,7 +8,6 @@ import torch
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -30,7 +29,6 @@ class RequestLogger:
|
||||
params: Optional[Union[SamplingParams, PoolingParams,
|
||||
BeamSearchParams]],
|
||||
lora_request: Optional[LoRARequest],
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest],
|
||||
) -> None:
|
||||
max_log_len = self.max_log_len
|
||||
if max_log_len is not None:
|
||||
@ -44,7 +42,6 @@ class RequestLogger:
|
||||
"Received request %s: prompt: %r, "
|
||||
"params: %s, prompt_token_ids: %s, "
|
||||
"prompt_embeds shape: %s, "
|
||||
"lora_request: %s, prompt_adapter_request: %s.", request_id,
|
||||
prompt, params, prompt_token_ids,
|
||||
"lora_request: %s.", request_id, prompt, params, prompt_token_ids,
|
||||
prompt_embeds.shape if prompt_embeds is not None else None,
|
||||
lora_request, prompt_adapter_request)
|
||||
lora_request)
|
||||
|
@ -1620,7 +1620,6 @@ async def init_app_state(
|
||||
model_config=model_config,
|
||||
base_model_paths=base_model_paths,
|
||||
lora_modules=lora_modules,
|
||||
prompt_adapters=args.prompt_adapters,
|
||||
)
|
||||
await state.openai_serving_models.init_static_loras()
|
||||
state.openai_serving_responses = OpenAIServingResponses(
|
||||
|
@ -20,8 +20,7 @@ from vllm.config import config
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
|
||||
from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
|
||||
validate_chat_template)
|
||||
from vllm.entrypoints.openai.serving_models import (LoRAModulePath,
|
||||
PromptAdapterPath)
|
||||
from vllm.entrypoints.openai.serving_models import LoRAModulePath
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
@ -65,27 +64,6 @@ class LoRAParserAction(argparse.Action):
|
||||
setattr(namespace, self.dest, lora_list)
|
||||
|
||||
|
||||
class PromptAdapterParserAction(argparse.Action):
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
parser: argparse.ArgumentParser,
|
||||
namespace: argparse.Namespace,
|
||||
values: Optional[Union[str, Sequence[str]]],
|
||||
option_string: Optional[str] = None,
|
||||
):
|
||||
if values is None:
|
||||
values = []
|
||||
if isinstance(values, str):
|
||||
raise TypeError("Expected values to be a list")
|
||||
|
||||
adapter_list: list[PromptAdapterPath] = []
|
||||
for item in values:
|
||||
name, path = item.split('=')
|
||||
adapter_list.append(PromptAdapterPath(name, path))
|
||||
setattr(namespace, self.dest, adapter_list)
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class FrontendArgs:
|
||||
@ -115,9 +93,6 @@ class FrontendArgs:
|
||||
or JSON list format. Example (old format): `'name=path'` Example (new
|
||||
format): `{\"name\": \"name\", \"path\": \"lora_path\",
|
||||
\"base_model_name\": \"id\"}`"""
|
||||
prompt_adapters: Optional[list[PromptAdapterPath]] = None
|
||||
"""Prompt adapter configurations in the format name=path. Multiple adapters
|
||||
can be specified."""
|
||||
chat_template: Optional[str] = None
|
||||
"""The file path to the chat template, or the template in single-line form
|
||||
for the specified model."""
|
||||
@ -207,12 +182,6 @@ schema. Example: `[{"type": "text", "text": "Hello world!"}]`"""
|
||||
frontend_kwargs["lora_modules"]["type"] = optional_type(str)
|
||||
frontend_kwargs["lora_modules"]["action"] = LoRAParserAction
|
||||
|
||||
# Special case: Prompt adapters need custom parser action and
|
||||
# optional_type(str)
|
||||
frontend_kwargs["prompt_adapters"]["type"] = optional_type(str)
|
||||
frontend_kwargs["prompt_adapters"][
|
||||
"action"] = PromptAdapterParserAction
|
||||
|
||||
# Special case: Middleware needs append action
|
||||
frontend_kwargs["middleware"]["action"] = "append"
|
||||
frontend_kwargs["middleware"]["type"] = str
|
||||
@ -253,13 +222,6 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
||||
default=False,
|
||||
help="Run in headless mode. See multi-node data parallel "
|
||||
"documentation for more details.")
|
||||
parser.add_argument(
|
||||
"--data-parallel-start-rank",
|
||||
"-dpr",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Starting data parallel rank for secondary nodes. "
|
||||
"Requires --headless.")
|
||||
parser.add_argument("--api-server-count",
|
||||
"-asc",
|
||||
type=int,
|
||||
@ -288,9 +250,6 @@ def validate_parsed_serve_args(args: argparse.Namespace):
|
||||
if args.enable_auto_tool_choice and not args.tool_call_parser:
|
||||
raise TypeError("Error: --enable-auto-tool-choice requires "
|
||||
"--tool-call-parser")
|
||||
if args.enable_prompt_embeds and args.enable_prompt_adapter:
|
||||
raise ValueError(
|
||||
"Cannot use prompt embeds and prompt adapter at the same time.")
|
||||
|
||||
|
||||
def log_non_default_args(args: argparse.Namespace):
|
||||
|
@ -337,7 +337,6 @@ async def main(args):
|
||||
model_config=model_config,
|
||||
base_model_paths=base_model_paths,
|
||||
lora_modules=None,
|
||||
prompt_adapters=None,
|
||||
)
|
||||
openai_serving_chat = OpenAIServingChat(
|
||||
engine,
|
||||
|
@ -147,11 +147,8 @@ class OpenAIServingChat(OpenAIServing):
|
||||
raise self.engine_client.dead_error
|
||||
|
||||
try:
|
||||
(
|
||||
lora_request,
|
||||
prompt_adapter_request,
|
||||
) = self._maybe_get_adapters(request,
|
||||
supports_default_mm_loras=True)
|
||||
lora_request = self._maybe_get_adapters(
|
||||
request, supports_default_mm_loras=True)
|
||||
|
||||
model_name = self._get_model_name(request.model, lora_request)
|
||||
|
||||
@ -239,8 +236,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
self._log_inputs(request_id,
|
||||
request_prompts[i],
|
||||
params=sampling_params,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request)
|
||||
lora_request=lora_request)
|
||||
|
||||
trace_headers = (None if raw_request is None else await
|
||||
self._get_trace_headers(raw_request.headers))
|
||||
@ -259,7 +255,6 @@ class OpenAIServingChat(OpenAIServing):
|
||||
request_id,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
priority=request.priority,
|
||||
)
|
||||
|
||||
|
@ -49,19 +49,11 @@ class ClassificationMixin(OpenAIServing):
|
||||
return None
|
||||
|
||||
try:
|
||||
(
|
||||
ctx.lora_request,
|
||||
ctx.prompt_adapter_request,
|
||||
) = self._maybe_get_adapters(ctx.request)
|
||||
ctx.lora_request = self._maybe_get_adapters(ctx.request)
|
||||
|
||||
ctx.tokenizer = await self.engine_client.get_tokenizer(
|
||||
ctx.lora_request)
|
||||
|
||||
if ctx.prompt_adapter_request is not None:
|
||||
raise NotImplementedError(
|
||||
"Prompt adapter is not supported for classification models"
|
||||
)
|
||||
|
||||
(
|
||||
ctx.request_prompts,
|
||||
ctx.engine_prompts,
|
||||
|
@ -121,10 +121,7 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
raw_request.state.request_metadata = request_metadata
|
||||
|
||||
try:
|
||||
(
|
||||
lora_request,
|
||||
prompt_adapter_request,
|
||||
) = self._maybe_get_adapters(request)
|
||||
lora_request = self._maybe_get_adapters(request)
|
||||
|
||||
tokenizer = await self.engine_client.get_tokenizer(lora_request)
|
||||
|
||||
@ -197,7 +194,6 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
request_prompts[i],
|
||||
params=sampling_params,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
)
|
||||
|
||||
trace_headers = (None if raw_request is None else await
|
||||
@ -221,7 +217,6 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
sampling_params,
|
||||
request_id_item,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=request.priority,
|
||||
)
|
||||
|
@ -53,18 +53,11 @@ class EmbeddingMixin(OpenAIServing):
|
||||
) -> Optional[ErrorResponse]:
|
||||
ctx = cast(EmbeddingServeContext, ctx)
|
||||
try:
|
||||
(
|
||||
ctx.lora_request,
|
||||
ctx.prompt_adapter_request,
|
||||
) = self._maybe_get_adapters(ctx.request)
|
||||
ctx.lora_request = self._maybe_get_adapters(ctx.request)
|
||||
|
||||
tokenizer = await self.engine_client.get_tokenizer(ctx.lora_request
|
||||
)
|
||||
|
||||
if ctx.prompt_adapter_request is not None:
|
||||
raise NotImplementedError("Prompt adapter is not supported "
|
||||
"for embedding models")
|
||||
|
||||
if isinstance(ctx.request, EmbeddingChatRequest):
|
||||
(
|
||||
_,
|
||||
|
@ -68,7 +68,6 @@ from vllm.multimodal import ( # noqa: F401 - Required to resolve Pydantic error
|
||||
MultiModalDataDict)
|
||||
from vllm.outputs import PoolingRequestOutput, RequestOutput
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
from vllm.sequence import Logprob, PromptLogprobs
|
||||
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
|
||||
@ -161,7 +160,6 @@ class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, BaseModel,
|
||||
request_id: str
|
||||
created_time: int = Field(default_factory=lambda: int(time.time()))
|
||||
lora_request: Optional[LoRARequest] = None
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None
|
||||
|
||||
# Shared across most requests
|
||||
tokenizer: Optional[AnyTokenizer] = None
|
||||
@ -343,12 +341,10 @@ class OpenAIServing:
|
||||
return self.create_error_response(
|
||||
"Request prompts not available")
|
||||
|
||||
self._log_inputs(
|
||||
request_id_item,
|
||||
ctx.request_prompts[i],
|
||||
params=pooling_params,
|
||||
lora_request=ctx.lora_request,
|
||||
prompt_adapter_request=ctx.prompt_adapter_request)
|
||||
self._log_inputs(request_id_item,
|
||||
ctx.request_prompts[i],
|
||||
params=pooling_params,
|
||||
lora_request=ctx.lora_request)
|
||||
|
||||
# Mypy has an existing bug related to inferring the variance of
|
||||
# TypedDicts with `builtins.enumerate`:
|
||||
@ -450,11 +446,6 @@ class OpenAIServing:
|
||||
if isinstance(load_result, ErrorResponse) and \
|
||||
load_result.code == HTTPStatus.BAD_REQUEST.value:
|
||||
error_response = load_result
|
||||
if request.model in [
|
||||
prompt_adapter.prompt_adapter_name
|
||||
for prompt_adapter in self.models.prompt_adapter_requests
|
||||
]:
|
||||
return None
|
||||
|
||||
return error_response or self.create_error_response(
|
||||
message=f"The model `{request.model}` does not exist.",
|
||||
@@ -489,25 +480,21 @@
self,
request: AnyRequest,
supports_default_mm_loras: bool = False,
) -> Union[tuple[None, None], tuple[LoRARequest, None], tuple[
None, PromptAdapterRequest]]:
) -> Optional[LoRARequest]:

if request.model in self.models.lora_requests:
return self.models.lora_requests[request.model], None
return self.models.lora_requests[request.model]

# Currently only support default modality specific loras
# if we have exactly one lora matched on the request.
if supports_default_mm_loras:
default_mm_lora = self._get_active_default_mm_loras(request)
if default_mm_lora is not None:
return default_mm_lora, None
return default_mm_lora

if self._is_model_supported(request.model):
return None, None
return None

for prompt_adapter in self.models.prompt_adapter_requests:
if request.model == prompt_adapter.prompt_adapter_name:
return None, prompt_adapter
# if _check_model has been called earlier, this will be unreachable
raise ValueError(f"The model `{request.model}` does not exist.")
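Note on the hunk above: `_maybe_get_adapters` now resolves only LoRA adapters and returns a single `Optional[LoRARequest]` instead of a `(lora_request, prompt_adapter_request)` tuple. A minimal sketch of the new call-site contract (hypothetical caller, not part of this diff):

    lora_request = self._maybe_get_adapters(request, supports_default_mm_loras=True)
    tokenizer = await self.engine_client.get_tokenizer(lora_request)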
@ -987,7 +974,6 @@ class OpenAIServing:
|
||||
params: Optional[Union[SamplingParams, PoolingParams,
|
||||
BeamSearchParams]],
|
||||
lora_request: Optional[LoRARequest],
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest],
|
||||
) -> None:
|
||||
if self.request_logger is None:
|
||||
return
|
||||
@ -1009,7 +995,6 @@ class OpenAIServing:
|
||||
prompt_embeds,
|
||||
params=params,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
)
|
||||
|
||||
async def _get_trace_headers(
|
||||
|
@ -1,8 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
import pathlib
|
||||
from asyncio import Lock
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
@ -19,7 +17,6 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse,
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.utils import AtomicCounter
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -31,12 +28,6 @@ class BaseModelPath:
|
||||
model_path: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptAdapterPath:
|
||||
name: str
|
||||
local_path: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoRAModulePath:
|
||||
name: str
|
||||
@ -60,7 +51,6 @@ class OpenAIServingModels:
|
||||
base_model_paths: list[BaseModelPath],
|
||||
*,
|
||||
lora_modules: Optional[list[LoRAModulePath]] = None,
|
||||
prompt_adapters: Optional[list[PromptAdapterPath]] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@ -81,20 +71,6 @@ class OpenAIServingModels:
|
||||
LoRAResolverRegistry.get_resolver(lora_resolver_name))
|
||||
self.lora_resolver_lock: dict[str, Lock] = defaultdict(Lock)
|
||||
|
||||
self.prompt_adapter_requests = []
|
||||
if prompt_adapters is not None:
|
||||
for i, prompt_adapter in enumerate(prompt_adapters, start=1):
|
||||
with pathlib.Path(prompt_adapter.local_path,
|
||||
"adapter_config.json").open() as f:
|
||||
adapter_config = json.load(f)
|
||||
num_virtual_tokens = adapter_config["num_virtual_tokens"]
|
||||
self.prompt_adapter_requests.append(
|
||||
PromptAdapterRequest(
|
||||
prompt_adapter_name=prompt_adapter.name,
|
||||
prompt_adapter_id=i,
|
||||
prompt_adapter_local_path=prompt_adapter.local_path,
|
||||
prompt_adapter_num_virtual_tokens=num_virtual_tokens))
|
||||
|
||||
async def init_static_loras(self):
|
||||
"""Loads all static LoRA modules.
|
||||
Raises if any fail to load"""
|
||||
@ -141,14 +117,7 @@ class OpenAIServingModels:
|
||||
permission=[ModelPermission()])
|
||||
for lora in self.lora_requests.values()
|
||||
]
|
||||
prompt_adapter_cards = [
|
||||
ModelCard(id=prompt_adapter.prompt_adapter_name,
|
||||
root=self.base_model_paths[0].name,
|
||||
permission=[ModelPermission()])
|
||||
for prompt_adapter in self.prompt_adapter_requests
|
||||
]
|
||||
model_cards.extend(lora_cards)
|
||||
model_cards.extend(prompt_adapter_cards)
|
||||
return ModelList(data=model_cards)
|
||||
|
||||
async def load_lora_adapter(
|
||||
|
@ -94,17 +94,10 @@ class OpenAIServingPooling(OpenAIServing):
|
||||
try:
|
||||
truncate_prompt_tokens = _validate_truncation_size(
|
||||
self.max_model_len, truncate_prompt_tokens)
|
||||
(
|
||||
lora_request,
|
||||
prompt_adapter_request,
|
||||
) = self._maybe_get_adapters(request)
|
||||
lora_request = self._maybe_get_adapters(request)
|
||||
|
||||
tokenizer = await self.engine_client.get_tokenizer(lora_request)
|
||||
|
||||
if prompt_adapter_request is not None:
|
||||
raise NotImplementedError("Prompt adapter is not supported "
|
||||
"for pooling models")
|
||||
|
||||
if isinstance(request, PoolingChatRequest):
|
||||
(
|
||||
_,
|
||||
@ -153,8 +146,7 @@ class OpenAIServingPooling(OpenAIServing):
|
||||
self._log_inputs(request_id_item,
|
||||
request_prompts[i],
|
||||
params=pooling_params,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request)
|
||||
lora_request=lora_request)
|
||||
|
||||
trace_headers = (None if raw_request is None else await
|
||||
self._get_trace_headers(raw_request.headers))
|
||||
|
@ -133,10 +133,7 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
messages = self._construct_input_messages(request, prev_response)
|
||||
|
||||
try:
|
||||
(
|
||||
lora_request,
|
||||
prompt_adapter_request,
|
||||
) = self._maybe_get_adapters(request)
|
||||
lora_request = self._maybe_get_adapters(request)
|
||||
model_name = self._get_model_name(request.model, lora_request)
|
||||
tokenizer = await self.engine_client.get_tokenizer(lora_request)
|
||||
|
||||
@ -169,8 +166,7 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
self._log_inputs(request.request_id,
|
||||
request_prompts[i],
|
||||
params=sampling_params,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request)
|
||||
lora_request=lora_request)
|
||||
|
||||
trace_headers = (None if raw_request is None else await
|
||||
self._get_trace_headers(raw_request.headers))
|
||||
@ -181,7 +177,6 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
request.request_id,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
priority=request.priority,
|
||||
)
|
||||
generators.append(generator)
|
||||
|
@ -27,7 +27,6 @@ from vllm.inputs.data import TokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||
from vllm.utils import make_async, merge_async_iterators
|
||||
|
||||
@ -58,8 +57,6 @@ class ServingScores(OpenAIServing):
|
||||
request_id: str,
|
||||
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
||||
lora_request: Optional[Union[LoRARequest, None]] = None,
|
||||
prompt_adapter_request: Optional[Union[PromptAdapterRequest,
|
||||
None]] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
) -> Union[list[PoolingRequestOutput], ErrorResponse]:
|
||||
input_texts = texts_1 + texts_2
|
||||
@ -100,8 +97,7 @@ class ServingScores(OpenAIServing):
|
||||
self._log_inputs(request_id_item,
|
||||
input_texts[i],
|
||||
params=pooling_params,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request)
|
||||
lora_request=lora_request)
|
||||
|
||||
generators.append(
|
||||
self.engine_client.encode(
|
||||
@ -176,8 +172,6 @@ class ServingScores(OpenAIServing):
|
||||
request_id: str,
|
||||
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
||||
lora_request: Optional[Union[LoRARequest, None]] = None,
|
||||
prompt_adapter_request: Optional[Union[PromptAdapterRequest,
|
||||
None]] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
) -> Union[list[PoolingRequestOutput], ErrorResponse]:
|
||||
request_prompts: list[str] = []
|
||||
@ -261,8 +255,7 @@ class ServingScores(OpenAIServing):
|
||||
self._log_inputs(request_id_item,
|
||||
request_prompts[i],
|
||||
params=pooling_params,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request)
|
||||
lora_request=lora_request)
|
||||
|
||||
generator = self.engine_client.encode(
|
||||
engine_prompt,
|
||||
@ -295,14 +288,7 @@ class ServingScores(OpenAIServing):
|
||||
raw_request: Optional[Request] = None,
|
||||
truncate_prompt_tokens: Optional[int] = None,
|
||||
) -> Union[list[PoolingRequestOutput], ErrorResponse]:
|
||||
(
|
||||
lora_request,
|
||||
prompt_adapter_request,
|
||||
) = self._maybe_get_adapters(request)
|
||||
|
||||
if prompt_adapter_request is not None:
|
||||
raise NotImplementedError("Prompt adapter is not supported "
|
||||
"for scoring models")
|
||||
lora_request = self._maybe_get_adapters(request)
|
||||
|
||||
tokenizer = await self.engine_client.get_tokenizer(lora_request)
|
||||
|
||||
@ -340,7 +326,6 @@ class ServingScores(OpenAIServing):
|
||||
request_id=request_id,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
trace_headers=trace_headers)
|
||||
|
||||
else:
|
||||
@ -352,7 +337,6 @@ class ServingScores(OpenAIServing):
|
||||
request_id=request_id,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
trace_headers=trace_headers)
|
||||
|
||||
async def create_score(
|
||||
|
@ -60,10 +60,7 @@ class OpenAIServingTokenization(OpenAIServing):
|
||||
request_id = f"tokn-{self._base_request_id(raw_request)}"
|
||||
|
||||
try:
|
||||
(
|
||||
lora_request,
|
||||
prompt_adapter_request,
|
||||
) = self._maybe_get_adapters(request)
|
||||
lora_request = self._maybe_get_adapters(request)
|
||||
|
||||
tokenizer = await self.engine_client.get_tokenizer(lora_request)
|
||||
|
||||
@ -104,11 +101,8 @@ class OpenAIServingTokenization(OpenAIServing):
|
||||
self._log_inputs(request_id,
|
||||
request_prompts[i],
|
||||
params=None,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request)
|
||||
lora_request=lora_request)
|
||||
|
||||
# Silently ignore prompt adapter since it does not affect
|
||||
# tokenization (Unlike in Embeddings API where an error is raised)
|
||||
if isinstance(engine_prompt,
|
||||
dict) and "prompt_token_ids" in engine_prompt:
|
||||
input_ids.extend(engine_prompt["prompt_token_ids"])
|
||||
@ -133,21 +127,14 @@ class OpenAIServingTokenization(OpenAIServing):
|
||||
|
||||
request_id = f"tokn-{self._base_request_id(raw_request)}"
|
||||
|
||||
(
|
||||
lora_request,
|
||||
prompt_adapter_request,
|
||||
) = self._maybe_get_adapters(request)
|
||||
lora_request = self._maybe_get_adapters(request)
|
||||
|
||||
tokenizer = await self.engine_client.get_tokenizer(lora_request)
|
||||
|
||||
self._log_inputs(request_id,
|
||||
request.tokens,
|
||||
params=None,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request)
|
||||
|
||||
# Silently ignore prompt adapter since it does not affect tokenization
|
||||
# (Unlike in Embeddings API where an error is raised)
|
||||
lora_request=lora_request)
|
||||
|
||||
prompt_input = await self._tokenize_prompt_input_async(
|
||||
request,
|
||||
|
@ -11,6 +11,7 @@ from typing import Callable, Literal, Optional, TypeVar, Union, cast
|
||||
import numpy as np
|
||||
from fastapi import Request
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
@ -38,10 +39,6 @@ T = TypeVar("T", bound=SpeechToTextResponse)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# As per https://platform.openai.com/docs/guides/speech-to-text#overview.
|
||||
# TODO configurable
|
||||
MAX_AUDIO_CLIP_FILESIZE_MB = 25
|
||||
|
||||
|
||||
class OpenAISpeechToText(OpenAIServing):
|
||||
"""Base class for speech-to-text operations like transcription and
|
||||
@ -70,6 +67,8 @@ class OpenAISpeechToText(OpenAIServing):
|
||||
self.asr_config = self.model_cls.get_speech_to_text_config(
|
||||
model_config, task_type)
|
||||
|
||||
self.max_audio_filesize_mb = envs.VLLM_MAX_AUDIO_CLIP_FILESIZE_MB
|
||||
|
||||
if self.default_sampling_params:
|
||||
logger.info(
|
||||
"Overwriting default completion sampling param with: %s",
|
||||
@ -93,7 +92,7 @@ class OpenAISpeechToText(OpenAIServing):
|
||||
lang = request.language or "en"
|
||||
self.model_cls.validate_language(lang)
|
||||
|
||||
if len(audio_data) / 1024**2 > MAX_AUDIO_CLIP_FILESIZE_MB:
|
||||
if len(audio_data) / 1024**2 > self.max_audio_filesize_mb:
|
||||
raise ValueError("Maximum file size exceeded.")
|
||||
|
||||
with io.BytesIO(audio_data) as bytes_:
|
||||
@ -150,19 +149,12 @@ class OpenAISpeechToText(OpenAIServing):
|
||||
raw_request.state.request_metadata = request_metadata
|
||||
|
||||
try:
|
||||
(
|
||||
lora_request,
|
||||
prompt_adapter_request,
|
||||
) = self._maybe_get_adapters(request)
|
||||
lora_request = self._maybe_get_adapters(request)
|
||||
|
||||
if lora_request:
|
||||
return self.create_error_response(
|
||||
"Currently do not support LoRA for "
|
||||
f"{self.task_type.title()}.")
|
||||
if prompt_adapter_request:
|
||||
return self.create_error_response(
|
||||
f"Currently do not support PromptAdapter for "
|
||||
f"{self.task_type.title()}.")
|
||||
|
||||
prompts, duration_s = await self._preprocess_speech_to_text(
|
||||
request=request,
|
||||
@ -188,8 +180,7 @@ class OpenAISpeechToText(OpenAIServing):
|
||||
# It will not display special tokens like <|startoftranscript|>
|
||||
request.prompt,
|
||||
params=sampling_params,
|
||||
lora_request=None,
|
||||
prompt_adapter_request=None)
|
||||
lora_request=None)
|
||||
|
||||
list_result_generator = [
|
||||
self.engine_client.generate(
14  vllm/envs.py
@@ -61,6 +61,7 @@ if TYPE_CHECKING:
VLLM_IMAGE_FETCH_TIMEOUT: int = 5
VLLM_VIDEO_FETCH_TIMEOUT: int = 30
VLLM_AUDIO_FETCH_TIMEOUT: int = 10
VLLM_MAX_AUDIO_CLIP_FILESIZE_MB: int = 25
VLLM_VIDEO_LOADER_BACKEND: str = "opencv"
VLLM_MM_INPUT_CACHE_GIB: int = 8
VLLM_TARGET_DEVICE: str = "cuda"
@@ -140,6 +141,7 @@ if TYPE_CHECKING:
VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: Optional[int] = None
VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 120
VLLM_USE_CUDNN_PREFILL: bool = False
VLLM_ENABLE_CUDAGRAPH_GC: bool = False
VLLM_LOOPBACK_IP: str = ""
@@ -518,6 +520,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_AUDIO_FETCH_TIMEOUT":
lambda: int(os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "10")),

# Maximum filesize in MB for a single audio file when processing
# speech-to-text requests. Files larger than this will be rejected.
# Default is 25 MB
"VLLM_MAX_AUDIO_CLIP_FILESIZE_MB":
lambda: int(os.getenv("VLLM_MAX_AUDIO_CLIP_FILESIZE_MB", "25")),

# Backend for Video IO
# - "opencv": Default backend that uses OpenCV stream buffered backend.
#
@@ -968,6 +976,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_TRTLLM_DECODE_ATTENTION":
lambda: os.getenv("VLLM_USE_TRTLLM_DECODE_ATTENTION", None),

# Controls garbage collection during CUDA graph capture.
# If set to 0 (default), enables GC freezing to speed up capture time.
# If set to 1, allows GC to run during capture.
"VLLM_ENABLE_CUDAGRAPH_GC":
lambda: bool(int(os.getenv("VLLM_ENABLE_CUDAGRAPH_GC", "0"))),

# Used to force set up loopback IP
"VLLM_LOOPBACK_IP":
lambda: os.getenv("VLLM_LOOPBACK_IP", ""),
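The new audio cap is evaluated lazily from the environment by the lambda above. A minimal sketch of raising it for a deployment (assumes the variable is set before vLLM first reads it; not part of this diff):

    import os
    os.environ["VLLM_MAX_AUDIO_CLIP_FILESIZE_MB"] = "50"

    import vllm.envs as envs
    assert envs.VLLM_MAX_AUDIO_CLIP_FILESIZE_MB == 50  # parsed as int by the lambda above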
@ -17,7 +17,6 @@ from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.pooling_params import PoolingTask
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sequence import ExecuteModelRequest, PoolerOutput
|
||||
from vllm.utils import make_async
|
||||
from vllm.worker.worker_base import WorkerBase
|
||||
@ -50,7 +49,6 @@ class ExecutorBase(ABC):
|
||||
self.scheduler_config = vllm_config.scheduler_config
|
||||
self.device_config = vllm_config.device_config
|
||||
self.speculative_config = vllm_config.speculative_config
|
||||
self.prompt_adapter_config = vllm_config.prompt_adapter_config
|
||||
self.observability_config = vllm_config.observability_config
|
||||
self._init_executor()
|
||||
self.is_sleeping = False
|
||||
@ -171,35 +169,6 @@ class ExecutorBase(ABC):
|
||||
assert s == sets[0], "All workers should have the same LORAs."
|
||||
return sets[0]
|
||||
|
||||
def add_prompt_adapter(
|
||||
self, prompt_adapter_request: PromptAdapterRequest) -> bool:
|
||||
assert prompt_adapter_request.prompt_adapter_id > 0, \
|
||||
"prompt_adapter_id must be greater than 0."
|
||||
return all(
|
||||
self.collective_rpc("add_prompt_adapter",
|
||||
args=(prompt_adapter_request, )))
|
||||
|
||||
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
|
||||
assert prompt_adapter_id > 0, \
|
||||
"prompt_adapter_id must be greater than 0."
|
||||
return all(
|
||||
self.collective_rpc("remove_prompt_adapter",
|
||||
args=(prompt_adapter_id, )))
|
||||
|
||||
def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
|
||||
assert prompt_adapter_id > 0, \
|
||||
"prompt_adapter_id must be greater than 0."
|
||||
return all(
|
||||
self.collective_rpc("pin_prompt_adapter",
|
||||
args=(prompt_adapter_id, )))
|
||||
|
||||
def list_prompt_adapters(self) -> Set[int]:
|
||||
sets = self.collective_rpc("list_prompt_adapters")
|
||||
for s in sets:
|
||||
assert (s == sets[0]
|
||||
), "All workers should have the same prompt adapters."
|
||||
return sets[0]
|
||||
|
||||
def start_profile(self) -> None:
|
||||
self.collective_rpc("start_profile")
|
||||
|
||||
|
@ -13,7 +13,6 @@ from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
|
||||
MultiModalInputs)
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
|
||||
|
||||
@ -168,18 +167,6 @@ class InputPreprocessor:
|
||||
|
||||
return decoder_input_ids
|
||||
|
||||
def _apply_prompt_adapter(
|
||||
self,
|
||||
prompt_token_ids: list[int],
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest],
|
||||
) -> list[int]:
|
||||
if prompt_adapter_request:
|
||||
prompt_token_ids = (
|
||||
[0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens
|
||||
+ prompt_token_ids)
|
||||
|
||||
return prompt_token_ids
|
||||
|
||||
def _get_tokenization_kw(
|
||||
self,
|
||||
overrides: Optional[dict[str, Any]] = None,
|
||||
@ -786,15 +773,10 @@ class InputPreprocessor:
|
||||
def _build_decoder_only_llm_inputs(
|
||||
self,
|
||||
prompt_inputs: DecoderOnlyInputs,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest],
|
||||
) -> DecoderOnlyInputs:
|
||||
if "prompt_token_ids" in prompt_inputs:
|
||||
prompt_inputs = cast(Union[TokenInputs, MultiModalInputs],
|
||||
prompt_inputs) # Needed for mypy
|
||||
prompt_inputs["prompt_token_ids"] = self._apply_prompt_adapter(
|
||||
prompt_inputs["prompt_token_ids"],
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
)
|
||||
|
||||
return prompt_inputs
|
||||
|
||||
@ -803,7 +785,6 @@ class InputPreprocessor:
|
||||
prompt: SingletonPrompt,
|
||||
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
return_mm_hashes: bool = False,
|
||||
) -> DecoderOnlyInputs:
|
||||
"""
|
||||
@ -815,7 +796,6 @@ class InputPreprocessor:
|
||||
|
||||
* prompt: input prompt
|
||||
* lora_request
|
||||
* prompt_adapter_request
|
||||
* return_mm_hashes
|
||||
|
||||
Returns:
|
||||
@ -830,17 +810,13 @@ class InputPreprocessor:
|
||||
return_mm_hashes=return_mm_hashes,
|
||||
)
|
||||
|
||||
return self._build_decoder_only_llm_inputs(
|
||||
prompt_comps,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
)
|
||||
return self._build_decoder_only_llm_inputs(prompt_comps)
|
||||
|
||||
async def _process_decoder_only_prompt_async(
|
||||
self,
|
||||
prompt: SingletonPrompt,
|
||||
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
return_mm_hashes: bool = False,
|
||||
) -> DecoderOnlyInputs:
|
||||
"""
|
||||
@ -854,17 +830,13 @@ class InputPreprocessor:
|
||||
return_mm_hashes=return_mm_hashes,
|
||||
)
|
||||
|
||||
return self._build_decoder_only_llm_inputs(
|
||||
prompt_comps,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
)
|
||||
return self._build_decoder_only_llm_inputs(prompt_comps)
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
prompt: PromptType,
|
||||
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
return_mm_hashes: bool = False,
|
||||
) -> ProcessorInputs:
|
||||
"""Preprocess the input prompt."""
|
||||
@ -886,7 +858,6 @@ class InputPreprocessor:
|
||||
prompt,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
return_mm_hashes=return_mm_hashes,
|
||||
)
|
||||
|
||||
@ -895,7 +866,6 @@ class InputPreprocessor:
|
||||
prompt: PromptType,
|
||||
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
return_mm_hashes: bool = False,
|
||||
) -> ProcessorInputs:
|
||||
"""
|
||||
@ -919,6 +889,5 @@ class InputPreprocessor:
|
||||
prompt,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
return_mm_hashes=return_mm_hashes,
|
||||
)
|
||||
|
@@ -6,6 +6,7 @@ from collections.abc import Generator
import gguf
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from transformers import AutoModelForCausalLM

from vllm.config import LoadConfig, ModelConfig, VllmConfig
@@ -32,8 +33,18 @@ class GGUFModelLoader(BaseModelLoader):
def _prepare_weights(self, model_name_or_path: str):
if os.path.isfile(model_name_or_path):
return model_name_or_path
# for raw HTTPS link
if model_name_or_path.startswith(
("http://", "https://")) and model_name_or_path.endswith(".gguf"):
return hf_hub_download(url=model_name_or_path)
# repo id/filename.gguf
if "/" in model_name_or_path and model_name_or_path.endswith(".gguf"):
repo_id, filename = model_name_or_path.rsplit("/", 1)
return hf_hub_download(repo_id=repo_id, filename=filename)
else:
raise ValueError(f"{model_name_or_path} is not a file.")
raise ValueError(
f"Unrecognised GGUF reference: {model_name_or_path} "
"(expected local file, raw URL, or <repo_id>/<filename>.gguf)")

def _get_gguf_weights_map(self, model_config: ModelConfig):
"""
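The rewritten `_prepare_weights` accepts a local file, a raw `.gguf` URL, or a `<repo_id>/<filename>.gguf` reference, in that order. A standalone sketch of the same resolution order (illustrative helper, not the vLLM implementation):

    import os

    def classify_gguf_ref(ref: str) -> str:
        # Same precedence as the hunk above: local file, then raw URL, then hub reference.
        if os.path.isfile(ref):
            return "local file"
        if ref.startswith(("http://", "https://")) and ref.endswith(".gguf"):
            return "raw URL"
        if "/" in ref and ref.endswith(".gguf"):
            repo_id, filename = ref.rsplit("/", 1)
            return f"hub download: repo_id={repo_id!r}, filename={filename!r}"
        raise ValueError(f"Unrecognised GGUF reference: {ref}")

    # classify_gguf_ref("org/model-GGUF/model.Q4_K_M.gguf")
    #   -> "hub download: repo_id='org/model-GGUF', filename='model.Q4_K_M.gguf'"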
@@ -136,6 +136,40 @@ def supports_multimodal(
return getattr(model, "supports_multimodal", False)


@runtime_checkable
class SupportsMultiModalWithRawInput(SupportsMultiModal, Protocol):
"""The interface required for all multi-modal models."""

supports_multimodal_raw_input: ClassVar[Literal[True]] = True
"""
A flag that indicates this model supports multi-modal inputs and processes
them in their raw form and not embeddings.

Note:
There is no need to redefine this flag if this class is in the
MRO of your model class.
"""


@overload
def supports_multimodal_raw_input(
model: object) -> TypeIs[SupportsMultiModalWithRawInput]:
...


@overload
def supports_multimodal_raw_input(
model: type[object]) -> TypeIs[type[SupportsMultiModalWithRawInput]]:
...


def supports_multimodal_raw_input(
model: Union[type[object], object]
) -> Union[TypeIs[type[SupportsMultiModalWithRawInput]],
TypeIs[SupportsMultiModalWithRawInput]]:
return getattr(model, "supports_multimodal_raw_input", False)


@runtime_checkable
class SupportsScoreTemplate(Protocol):
"""The interface required for all models that support score template."""
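`supports_multimodal_raw_input` is a plain `getattr` check, so it works on both classes and instances via the `@overload`s above. A small usage sketch (import path assumed from the registry hunk further below; the dummy class is hypothetical):

    from vllm.model_executor.models.interfaces import supports_multimodal_raw_input

    class _RawInputModel:
        supports_multimodal_raw_input = True

    assert supports_multimodal_raw_input(_RawInputModel)    # class-level check
    assert supports_multimodal_raw_input(_RawInputModel())  # instance-level check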
@ -16,6 +16,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only IBM/NASA Prithvi Geospatial model."""
|
||||
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from typing import Optional, Union
|
||||
|
||||
@ -27,13 +28,14 @@ from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.pooler import (AllPool, PoolerHead,
|
||||
PoolerIdentity, SimplePooler)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.interfaces import (IsAttentionFree,
|
||||
SupportsMultiModal,
|
||||
SupportsV0Only)
|
||||
from vllm.model_executor.models.interfaces import (
|
||||
IsAttentionFree, MultiModalEmbeddings, SupportsMultiModalWithRawInput)
|
||||
from vllm.model_executor.models.utils import AutoWeightsLoader
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||
MultiModalInputs, MultiModalKwargs)
|
||||
MultiModalFieldElem, MultiModalInputs,
|
||||
MultiModalKwargs, MultiModalKwargsItem,
|
||||
MultiModalSharedField, PlaceholderRange)
|
||||
from vllm.multimodal.parse import MultiModalDataItems
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, PromptUpdate)
|
||||
@ -62,8 +64,9 @@ class PrithviGeoSpatialMAEInputBuilder(
|
||||
# The size of pixel_values might change in the cases where we resize
|
||||
# the input but never exceeds the dimensions below.
|
||||
return {
|
||||
"pixel_values": torch.full((1, 6, 512, 512), 1.0),
|
||||
"location_coords": torch.full((1, 2), 1.0),
|
||||
"pixel_values": torch.full((6, 512, 512), 1.0,
|
||||
dtype=torch.float16),
|
||||
"location_coords": torch.full((1, 2), 1.0, dtype=torch.float16),
|
||||
}
|
||||
|
||||
|
||||
@ -75,8 +78,10 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
return dict(
|
||||
pixel_values=MultiModalFieldConfig.batched("image"),
|
||||
location_coords=MultiModalFieldConfig.batched("image"),
|
||||
pixel_values=MultiModalFieldConfig.shared(batch_size=1,
|
||||
modality="image"),
|
||||
location_coords=MultiModalFieldConfig.shared(batch_size=1,
|
||||
modality="image"),
|
||||
)
|
||||
|
||||
def _get_prompt_updates(
|
||||
@ -99,23 +104,48 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
|
||||
|
||||
for k, v in mm_data.items():
|
||||
mm_kwargs[k] = v
|
||||
mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]}
|
||||
|
||||
# This model receives in input a multi-dimensional tensor representing
|
||||
# a single image patch and therefore it is not to be split
|
||||
# into multiple elements, but rather to be considered a single one.
|
||||
# Hence, the decision of using a MultiModalSharedField.
|
||||
# The expected shape is (num_channels, width, height).
|
||||
|
||||
# This model however allows the user to also submit multiple image
|
||||
# patches as a batch, adding a further dimension to the above shape.
|
||||
# At this stage we only support submitting one patch per request and
|
||||
# batching is achieved via vLLM batching.
|
||||
# TODO (christian-pinto): enable support for multi patch requests
|
||||
# in tandem with vLLM batching.
|
||||
multimodal_kwargs_items = [
|
||||
MultiModalKwargsItem.from_elems([
|
||||
MultiModalFieldElem(
|
||||
modality="image",
|
||||
key=key,
|
||||
data=data,
|
||||
field=MultiModalSharedField(1),
|
||||
) for key, data in mm_kwargs.items()
|
||||
])
|
||||
]
|
||||
|
||||
return MultiModalInputs(
|
||||
type="multimodal",
|
||||
prompt=prompt,
|
||||
prompt_token_ids=[1],
|
||||
mm_kwargs=MultiModalKwargs(mm_kwargs),
|
||||
mm_kwargs=MultiModalKwargs.from_items(multimodal_kwargs_items),
|
||||
mm_hashes=None,
|
||||
mm_placeholders={},
|
||||
mm_placeholders=mm_placeholders,
|
||||
)
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
PrithviGeoSpatialMAEMultiModalProcessor,
|
||||
info=PrithviGeoSpatialMAEProcessingInfo,
|
||||
dummy_inputs=PrithviGeoSpatialMAEInputBuilder)
|
||||
class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
|
||||
SupportsV0Only):
|
||||
dummy_inputs=PrithviGeoSpatialMAEInputBuilder,
|
||||
)
|
||||
class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree,
|
||||
SupportsMultiModalWithRawInput):
|
||||
"""Prithvi Masked Autoencoder"""
|
||||
|
||||
is_pooling_model = True
|
||||
@ -128,10 +158,10 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
|
||||
raise ValueError("Only image modality is supported")
|
||||
|
||||
def _instantiate_model(self, config: dict) -> Optional[nn.Module]:
|
||||
|
||||
# We might be able/need to support different tasks with this same model
|
||||
if config["task_args"]["task"] == "SemanticSegmentationTask":
|
||||
from terratorch.cli_tools import SemanticSegmentationTask
|
||||
|
||||
task = SemanticSegmentationTask(
|
||||
config["model_args"],
|
||||
config["task_args"]["model_factory"],
|
||||
@ -144,7 +174,8 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
|
||||
scheduler_hparams=config["scheduler_params"],
|
||||
plot_on_val=config["task_args"]["plot_on_val"],
|
||||
freeze_decoder=config["task_args"]["freeze_decoder"],
|
||||
freeze_backbone=config["task_args"]["freeze_backbone"])
|
||||
freeze_backbone=config["task_args"]["freeze_backbone"],
|
||||
)
|
||||
|
||||
return task.model
|
||||
else:
|
||||
@ -168,12 +199,10 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
|
||||
|
||||
def _parse_and_validate_multimodal_data(
|
||||
self, **kwargs) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
|
||||
pixel_values = kwargs.pop("pixel_values", None)
|
||||
if not isinstance(pixel_values, torch.Tensor):
|
||||
raise ValueError(f"Incorrect type of pixel_values. "
|
||||
f"Got type: {type(pixel_values)}")
|
||||
pixel_values = torch.unbind(pixel_values, dim=0)[0]
|
||||
|
||||
location_coords = kwargs.pop("location_coords", None)
|
||||
if not isinstance(location_coords, torch.Tensor):
|
||||
@ -185,6 +214,17 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
|
||||
|
||||
return pixel_values, location_coords
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
) -> torch.Tensor:
|
||||
# We do not really use any input tokens and therefore no embeddings
|
||||
# to be calculated. However, due to the mandatory token ids in
|
||||
# the input prompt we pass one token and the size of the dummy
|
||||
# embedding tensors must reflect that.
|
||||
return torch.empty((input_ids.shape[0], 0))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
|
@ -22,8 +22,8 @@ from vllm.logger import init_logger
|
||||
|
||||
from .interfaces import (has_inner_state, has_noops, is_attention_free,
|
||||
is_hybrid, supports_cross_encoding,
|
||||
supports_multimodal, supports_pp,
|
||||
supports_transcription, supports_v0_only)
|
||||
supports_multimodal, supports_multimodal_raw_input,
|
||||
supports_pp, supports_transcription, supports_v0_only)
|
||||
from .interfaces_base import is_text_generation_model
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -287,6 +287,7 @@ class _ModelInfo:
|
||||
is_pooling_model: bool
|
||||
supports_cross_encoding: bool
|
||||
supports_multimodal: bool
|
||||
supports_multimodal_raw_input: bool
|
||||
supports_pp: bool
|
||||
has_inner_state: bool
|
||||
is_attention_free: bool
|
||||
@ -304,6 +305,7 @@ class _ModelInfo:
|
||||
is_pooling_model=True, # Can convert any model into a pooling model
|
||||
supports_cross_encoding=supports_cross_encoding(model),
|
||||
supports_multimodal=supports_multimodal(model),
|
||||
supports_multimodal_raw_input=supports_multimodal_raw_input(model),
|
||||
supports_pp=supports_pp(model),
|
||||
has_inner_state=has_inner_state(model),
|
||||
is_attention_free=is_attention_free(model),
|
||||
@ -573,6 +575,13 @@ class _ModelRegistry:
|
||||
model_cls, _ = self.inspect_model_cls(architectures)
|
||||
return model_cls.supports_multimodal
|
||||
|
||||
def supports_multimodal_raw_input(
|
||||
self,
|
||||
architectures: Union[str, list[str]],
|
||||
) -> bool:
|
||||
model_cls, _ = self.inspect_model_cls(architectures)
|
||||
return model_cls.supports_multimodal_raw_input
|
||||
|
||||
def is_pp_supported_model(
|
||||
self,
|
||||
architectures: Union[str, list[str]],
|
||||
|
@ -266,7 +266,7 @@ class MultiModalRegistry:
|
||||
if not model_config.is_multimodal_model:
|
||||
raise ValueError(f"{model_config.model} is not a multimodal model")
|
||||
|
||||
if tokenizer is None:
|
||||
if tokenizer is None and not model_config.skip_tokenizer_init:
|
||||
tokenizer = cached_tokenizer_from_config(model_config)
|
||||
if disable_cache is None:
|
||||
mm_config = model_config.get_multimodal_config()
|
||||
|
@ -1,83 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.adapter_commons.layers import AdapterMapping
|
||||
from vllm.config import PromptAdapterConfig
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptAdapterMapping(AdapterMapping):
|
||||
pass
|
||||
|
||||
|
||||
class VocabParallelEmbeddingWithPromptAdapter(nn.Module):
|
||||
|
||||
def __init__(self, base_layer: VocabParallelEmbedding) -> None:
|
||||
super().__init__()
|
||||
self.base_layer = base_layer
|
||||
self.emb_layer = self.base_layer
|
||||
if 'LoRA' in base_layer.__class__.__name__:
|
||||
self.emb_layer = self.base_layer.base_layer
|
||||
|
||||
def create_prompt_adapter_weights(
|
||||
self, prompt_adapter_config: PromptAdapterConfig):
|
||||
self.embeddings_tensors = torch.zeros(
|
||||
(
|
||||
prompt_adapter_config.max_prompt_adapters,
|
||||
prompt_adapter_config.max_prompt_adapter_token,
|
||||
self.emb_layer.embedding_dim,
|
||||
),
|
||||
dtype=self.emb_layer.weight.dtype,
|
||||
device=self.emb_layer.weight.device,
|
||||
)
|
||||
self.adapter_lengths = torch.zeros(
|
||||
prompt_adapter_config.max_prompt_adapters,
|
||||
dtype=torch.long,
|
||||
device=self.emb_layer.weight.device)
|
||||
|
||||
self.indices_gpu: torch.Tensor
|
||||
self.embedding_indices_gpu: torch.Tensor
|
||||
|
||||
def reset_prompt_adapter(self, index: int):
|
||||
self.embeddings_tensors[index] = 0
|
||||
|
||||
def set_prompt_adapter(
|
||||
self,
|
||||
index: int,
|
||||
adapter_model: Optional[torch.Tensor],
|
||||
):
|
||||
self.reset_prompt_adapter(index)
|
||||
if adapter_model is not None:
|
||||
length = adapter_model.shape[0]
|
||||
self.embeddings_tensors[index, :length] = adapter_model
|
||||
self.adapter_lengths[index] = length
|
||||
|
||||
def set_mapping(
|
||||
self,
|
||||
prompt_indices: torch.Tensor,
|
||||
prompt_embedding_indices: torch.Tensor,
|
||||
):
|
||||
self.indices_gpu = prompt_indices.to(
|
||||
device=self.emb_layer.weight.device)
|
||||
self.embedding_indices_gpu = prompt_embedding_indices.to(
|
||||
device=self.emb_layer.weight.device)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.base_layer(x)
|
||||
if self.embedding_indices_gpu.ndim > 1:
|
||||
valid_mask = self.indices_gpu != -1
|
||||
gathered_embeddings = self.embeddings_tensors[
|
||||
self.embedding_indices_gpu[:, 0],
|
||||
self.embedding_indices_gpu[:, 1]]
|
||||
|
||||
# Update hidden states
|
||||
hidden_states[valid_mask] = gathered_embeddings
|
||||
return hidden_states
|
@ -1,358 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import logging
|
||||
import math
|
||||
from typing import Any, Callable, Dict, List, Optional, Type
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.adapter_commons.models import (AdapterLRUCache, AdapterModel,
|
||||
AdapterModelManager)
|
||||
from vllm.adapter_commons.utils import (add_adapter, deactivate_adapter,
|
||||
get_adapter, list_adapters,
|
||||
remove_adapter, set_adapter_mapping)
|
||||
from vllm.config import PromptAdapterConfig
|
||||
from vllm.prompt_adapter.layers import (
|
||||
VocabParallelEmbeddingWithPromptAdapter) # yapf: disable
|
||||
from vllm.prompt_adapter.layers import PromptAdapterMapping
|
||||
from vllm.prompt_adapter.utils import load_peft_weights
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_GLOBAL_PROMPT_ADAPTER_ID = 0
|
||||
|
||||
|
||||
def get_prompt_adapter_id():
|
||||
global _GLOBAL_PROMPT_ADAPTER_ID
|
||||
_GLOBAL_PROMPT_ADAPTER_ID += 1
|
||||
return _GLOBAL_PROMPT_ADAPTER_ID
|
||||
|
||||
|
||||
def convert_to_embedding_indices(indices):
|
||||
embedding_indices = []
|
||||
count = 0
|
||||
|
||||
for value in indices:
|
||||
if value == -1:
|
||||
count = 0
|
||||
else:
|
||||
embedding_indices.append([value, count])
|
||||
count += 1
|
||||
|
||||
return torch.tensor(embedding_indices)
|
||||
|
||||
|
||||
def convert_mapping(
|
||||
mapping: PromptAdapterMapping,
|
||||
prompt_adapter_index_to_id: List[Optional[int]],
|
||||
) -> torch.Tensor:
|
||||
"""Converts PromptAdapterMapping to index tensors.
|
||||
|
||||
Args:
|
||||
mapping: PromptAdapterMapping mapping rows in a
|
||||
batch to PromptAdapter ids.
|
||||
prompt_adapter_index_to_id: List mapping PromptAdapter
|
||||
ids to PromptAdapter indices.
|
||||
|
||||
Returns:
|
||||
pa_indices: Tensor of shape [batch_size] mapping batch rows to
|
||||
PromptAdapter indices.
|
||||
"""
|
||||
id_to_index = {
|
||||
id_: idx
|
||||
for idx, id_ in enumerate(prompt_adapter_index_to_id)
|
||||
if id_ is not None
|
||||
}
|
||||
pa_indices = ([
|
||||
id_to_index.get(id_, -1) if id_ > 0 else -1
|
||||
for id_ in mapping.index_mapping
|
||||
])
|
||||
|
||||
pa_embedding_mapping = convert_to_embedding_indices(pa_indices)
|
||||
pa_indices = torch.tensor(pa_indices)
|
||||
return pa_indices, pa_embedding_mapping
|
||||
|
||||
|
||||
class PromptAdapterModel(AdapterModel):
|
||||
|
||||
def __init__(self,
|
||||
prompt_adapter_id=None,
|
||||
num_virtual_tokens=None,
|
||||
prompt_embedding=None) -> None:
|
||||
self.id = prompt_adapter_id
|
||||
self.prompt_embedding = prompt_embedding
|
||||
self.num_virtual_tokens = num_virtual_tokens
|
||||
|
||||
@classmethod
|
||||
def from_local_checkpoint(
|
||||
cls,
|
||||
adapter_model_path: str,
|
||||
prompt_adapter_id: int,
|
||||
num_virtual_tokens: int,
|
||||
config: PromptAdapterConfig,
|
||||
device: str = "cuda",
|
||||
) -> "PromptAdapterModel":
|
||||
|
||||
if num_virtual_tokens > config.max_prompt_adapter_token:
|
||||
raise ValueError(
|
||||
f'num_virtual_tokens ({num_virtual_tokens}) should be <= '
|
||||
f'max_prompt_adapter_token({config.max_prompt_adapter_token})')
|
||||
|
||||
adapters_weights = load_peft_weights(adapter_model_path, device)
|
||||
prompt_embedding = adapters_weights["prompt_embeddings"].to(
|
||||
config.prompt_adapter_dtype)
|
||||
|
||||
return cls(prompt_adapter_id, num_virtual_tokens, prompt_embedding)
|
||||
|
||||
|
||||
class PromptAdapterModelManager(AdapterModelManager):
|
||||
"""A manager that manages multiple Prompt Adapter models."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: nn.Module,
|
||||
max_num_seqs: int,
|
||||
max_num_batched_tokens: int,
|
||||
prompt_adapter_config: PromptAdapterConfig,
|
||||
):
|
||||
"""Create a PromptAdapterModel and adapter for a given model.
|
||||
|
||||
Args:
|
||||
model: the model to be adapted.
|
||||
max_num_seqs: the maximum number of sequences model can run in a
|
||||
single batch.
|
||||
max_num_batched_tokens: the maximum number of tokens model can run
|
||||
in a single batch.
|
||||
prompt_adapter_config: the PromptAdapter config,
|
||||
"""
|
||||
self.model: nn.Module = model
|
||||
# Dict instead of a Set for compatibility with LRUCache.
|
||||
self.prompt_adapter_index_to_id: List[
|
||||
Optional[int]] = [None] * self.prompt_adapter_slots
|
||||
self.max_num_seqs = max_num_seqs
|
||||
self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
|
||||
self.prompt_adapter_config = prompt_adapter_config
|
||||
self.model.prompt_adapter_manager = self
|
||||
self.adapter_type = 'PromptAdapter'
|
||||
|
||||
self.base_indices = torch.tensor([-1])
|
||||
self.base_embedding_indices = torch.tensor([])
|
||||
|
||||
self.modules: Dict[str, nn.Module] = {}
|
||||
self._create_prompt_adapter_modules()
|
||||
self._last_mapping: Optional[PromptAdapterMapping] = None
|
||||
|
||||
@property
|
||||
def prompt_adapter_slots(self) -> int:
|
||||
return self.prompt_adapter_config.max_prompt_adapters
|
||||
|
||||
@property
|
||||
def adapter_slots(self) -> int:
|
||||
return self.prompt_adapter_slots
|
||||
|
||||
@property
|
||||
def capacity(self) -> int:
|
||||
return self.prompt_adapter_config.max_cpu_prompt_adapters
|
||||
|
||||
def activate_adapter(
|
||||
self,
|
||||
prompt_adapter_id: int,
|
||||
) -> bool:
|
||||
"""Move PromptAdapter into a GPU buffer
|
||||
to be used in the forward pass."""
|
||||
if prompt_adapter_id in self._active_adapters:
|
||||
return False
|
||||
first_free_slot = next(
|
||||
((i, prompt_adapter_id) for i, prompt_adapter_id in enumerate(
|
||||
self.prompt_adapter_index_to_id) if prompt_adapter_id is None),
|
||||
None)
|
||||
if first_free_slot is None:
|
||||
raise ValueError("No free prompt_adapter slots")
|
||||
index, _ = first_free_slot
|
||||
self._active_adapters[prompt_adapter_id] = None
|
||||
prompt_adapter_model = (self._registered_adapters[prompt_adapter_id])
|
||||
logger.debug("Activating prompt_adapter. int id: %d, slot index: %d",
|
||||
prompt_adapter_model.id, index)
|
||||
self.prompt_adapter_index_to_id[index] = prompt_adapter_model.id
|
||||
for _, v in self.modules.items():
|
||||
v.set_prompt_adapter(index, prompt_adapter_model.prompt_embedding)
|
||||
return True
|
||||
|
||||
def _deactivate_adapter(self, prompt_adapter_id: int):
|
||||
try:
|
||||
index = self.prompt_adapter_index_to_id.index(prompt_adapter_id)
|
||||
self.prompt_adapter_index_to_id[index] = None
|
||||
for _, v in self.modules.items():
|
||||
v.reset_prompt_adapter(index)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
def _add_adapter(self, prompt_adapter: PromptAdapterModel):
|
||||
self._registered_adapters[prompt_adapter.id] = prompt_adapter
|
||||
|
||||
def _set_adapter_mapping(self, mapping: PromptAdapterMapping) -> None:
|
||||
base_indices, base_embedding_indices = convert_mapping(
|
||||
mapping, self.prompt_adapter_index_to_id)
|
||||
for k, v in self.modules.items():
|
||||
v.set_mapping(base_indices, base_embedding_indices)
|
||||
|
||||
def _create_prompt_adapter_modules(self):
|
||||
for module_name, module in self.model.named_modules(
|
||||
remove_duplicate=False):
|
||||
if "VocabParallel" in module.__class__.__name__:
|
||||
new_module = VocabParallelEmbeddingWithPromptAdapter(module)
|
                new_module.create_prompt_adapter_weights(
                    self.prompt_adapter_config)
                replaced_module = self.replace_submodule(
                    self.model, module_name, new_module)
                self.register_module(module.__class__.__name__,
                                     replaced_module)
                replaced_module.set_mapping(self.base_indices,
                                            self.base_embedding_indices)
                break

    def replace_submodule(self, model: nn.Module, module_name: str,
                          new_module: nn.Module) -> nn.Module:
        """Replace a submodule in a model with a new module."""
        parent = model.get_submodule(".".join(module_name.split(".")[:-1]))
        target_name = module_name.split(".")[-1]
        setattr(parent, target_name, new_module)
        return new_module

    def register_module(self, module_name: str, module: nn.Module):
        self.modules[module_name] = module

    def pin_adapter(self, prompt_adapter_id: int) -> bool:
        """Pin a PromptAdapterModel in the manager cache."""
        raise NotImplementedError(
            "Pinning is not supported in PromptAdapterModelManager. "
            "Use LRUCachePromptAdapterModelManager for pinning"
        )  # type: ignore

    def remove_all_adapters(self):
        """Remove all PromptAdapterModel from the manager."""
        self._registered_adapters.clear()
        self.prompt_adapter_index_to_id = [None] * self.prompt_adapter_slots
        self._active_adapters.clear()

    def deactivate_adapter(self, adapter_id: int) -> bool:
        return deactivate_adapter(adapter_id, self._active_adapters,
                                  self._deactivate_adapter)

    def add_adapter(self, adapter: PromptAdapterModel) -> bool:
        return add_adapter(adapter, self._registered_adapters, self.capacity,
                           self._add_adapter)

    def set_adapter_mapping(self, mapping: PromptAdapterMapping) -> None:
        self._last_mapping = set_adapter_mapping(mapping, self._last_mapping,
                                                 self._set_adapter_mapping)

    def remove_adapter(self, adapter_id: int) -> bool:
        return remove_adapter(adapter_id, self._registered_adapters,
                              self.deactivate_adapter)

    def list_adapters(self) -> Dict[int, Any]:
        return list_adapters(self._registered_adapters)

    def get_adapter(self, adapter_id: int) -> Optional[Any]:
        return get_adapter(adapter_id, self._registered_adapters)


class PromptAdapterLRUCache(AdapterLRUCache[PromptAdapterModel]):

    def __init__(self, capacity: int,
                 deactivate_prompt_adapter_fn: Callable[[int], bool]):
        super().__init__(capacity, deactivate_prompt_adapter_fn)


class LRUCachePromptAdapterModelManager(PromptAdapterModelManager):
    """A model manager that manages multiple prompt_adapters with LRU cache."""

    def __init__(
        self,
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        prompt_adapter_config: PromptAdapterConfig,
    ):
        self.prompt_adapter_config = prompt_adapter_config
        super().__init__(model, max_num_seqs, max_num_batched_tokens,
                         prompt_adapter_config)
        self._registered_adapters = PromptAdapterLRUCache(
            self.capacity, self.deactivate_adapter)
        self._active_adapters = PromptAdapterLRUCache(
            self.prompt_adapter_slots, self._deactivate_adapter)

    def list_adapters(self) -> Dict[int, PromptAdapterModel]:
        """List all registered PromptAdapterModel."""
        return dict(self._registered_adapters.cache)

    def add_adapter(self, prompt_adapter: PromptAdapterModel) -> bool:
        """Add a PromptAdapterModel to the manager."""
        if prompt_adapter.id not in self._registered_adapters:
            self._add_adapter(prompt_adapter)
            was_added = True
        else:
            # We always touch to update the LRU cache order
            self._registered_adapters.touch(prompt_adapter.id)
            was_added = False
        return was_added

    def activate_adapter(
        self,
        prompt_adapter_id: int,
    ) -> bool:
        if prompt_adapter_id not in self._active_adapters and len(
                self._active_adapters) >= self.prompt_adapter_slots:
            self._active_adapters.remove_oldest()
        result = super().activate_adapter(prompt_adapter_id)
        # We always touch to update the LRU cache order
        self._active_adapters.touch(prompt_adapter_id)
        return result

    def remove_oldest_adapter(self) -> bool:
        if len(self._registered_adapters) > 0:
            self._registered_adapters.remove_oldest()
            return True
        return False

    def pin_adapter(self, prompt_adapter_id: int) -> bool:
        """Pin a PromptAdapterModel in the manager cache."""
        self._pin_prompt_adapter_in_cpu_cache(prompt_adapter_id)
        self._pin_prompt_adapter_in_gpu_cache(prompt_adapter_id)
        return True

    def _pin_prompt_adapter_in_cpu_cache(self, prompt_adapter_id: int):
        try:
            self._registered_adapters.pin(prompt_adapter_id)
        except ValueError as err:
            raise ValueError(
                "Pinning failed. "
                f"Prompt Adapter {prompt_adapter_id} is not registered."
            ) from err

    def _pin_prompt_adapter_in_gpu_cache(self, prompt_adapter_id: int):
        if prompt_adapter_id not in self._active_adapters:
            # move adapter to gpu if not already active
            self.activate_adapter(prompt_adapter_id)
        self._active_adapters.pin(prompt_adapter_id)


def create_prompt_adapter_manager(
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        prompt_adapter_config: PromptAdapterConfig,
        prompt_adapter_manager_cls: Type[
            PromptAdapterModelManager] = PromptAdapterModelManager,
        **kwargs) -> PromptAdapterModelManager:
    """Create a PromptAdapterModel for a given model."""
    prompt_adapter_manager = prompt_adapter_manager_cls(
        model=model,
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        prompt_adapter_config=prompt_adapter_config,
        **kwargs)
    return prompt_adapter_manager
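The manager being removed above leans on an LRU cache with explicit touch, evict-oldest, and pin operations. The following is an illustrative sketch only, not vLLM's AdapterLRUCache: a minimal stand-in built on OrderedDict, with the class name and semantics assumed for demonstration.

from collections import OrderedDict
from typing import Callable, Optional


class TinyLRUCache:
    """Minimal LRU cache with pinning, mirroring touch/remove_oldest/pin."""

    def __init__(self, capacity: int, on_evict: Callable[[int], None]):
        self.capacity = capacity
        self.on_evict = on_evict
        self._data = OrderedDict()  # key -> cached adapter object
        self._pinned = set()

    def put(self, key: int, value: object) -> None:
        self._data[key] = value
        self._data.move_to_end(key)  # most recently used
        while len(self._data) > self.capacity:
            if self.remove_oldest() is None:
                break  # everything left is pinned

    def touch(self, key: int) -> None:
        # Mark as most recently used, as activate_adapter() does above.
        self._data.move_to_end(key)

    def pin(self, key: int) -> None:
        if key not in self._data:
            raise ValueError(f"{key} is not registered")
        self._pinned.add(key)

    def remove_oldest(self) -> Optional[int]:
        for key in self._data:
            if key not in self._pinned:
                self._data.pop(key)
                self.on_evict(key)
                return key
        return None


cache = TinyLRUCache(capacity=2, on_evict=lambda k: print("evicted", k))
cache.put(1, "adapter-1")
cache.put(2, "adapter-2")
cache.pin(1)
cache.put(3, "adapter-3")  # evicts 2, because 1 is pinned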
@ -1,37 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import msgspec

from vllm.adapter_commons.request import AdapterRequest


class PromptAdapterRequest(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True,  # type: ignore[call-arg]
        frozen=True):  # type: ignore[call-arg]
    """
    Request for a Prompt adapter.
    """
    __metaclass__ = AdapterRequest

    prompt_adapter_name: str
    prompt_adapter_id: int
    prompt_adapter_local_path: str
    prompt_adapter_num_virtual_tokens: int

    def __hash__(self):
        return super().__hash__()

    @property
    def adapter_id(self):
        return self.prompt_adapter_id

    @property
    def name(self):
        return self.prompt_adapter_name

    @property
    def local_path(self):
        return self.prompt_adapter_local_path
@ -1,98 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# code borrowed from: https://github.com/huggingface/peft/blob/v0.12.0/src/peft/utils/save_and_load.py#L420

import os
from typing import Optional

import torch
from huggingface_hub import file_exists, hf_hub_download
from huggingface_hub.utils import EntryNotFoundError
from safetensors.torch import load_file as safe_load_file

from vllm.platforms import current_platform

WEIGHTS_NAME = "adapter_model.bin"
SAFETENSORS_WEIGHTS_NAME = "adapter_model.safetensors"


# Get current device name based on available devices
def infer_device() -> str:
    if current_platform.is_cuda_alike():
        return "cuda"
    return "cpu"


def load_peft_weights(model_id: str,
                      device: Optional[str] = None,
                      **hf_hub_download_kwargs) -> dict:
    r"""
    A helper method to load the PEFT weights from the HuggingFace Hub or locally

    Args:
        model_id (`str`):
            The local path to the adapter weights or the name of the adapter to
            load from the HuggingFace Hub.
        device (`str`):
            The device to load the weights onto.
        hf_hub_download_kwargs (`dict`):
            Additional arguments to pass to the `hf_hub_download` method when
            loading from the HuggingFace Hub.
    """
    path = (os.path.join(model_id, hf_hub_download_kwargs["subfolder"]) if
            hf_hub_download_kwargs.get("subfolder") is not None else model_id)

    if device is None:
        device = infer_device()

    if os.path.exists(os.path.join(path, SAFETENSORS_WEIGHTS_NAME)):
        filename = os.path.join(path, SAFETENSORS_WEIGHTS_NAME)
        use_safetensors = True
    elif os.path.exists(os.path.join(path, WEIGHTS_NAME)):
        filename = os.path.join(path, WEIGHTS_NAME)
        use_safetensors = False
    else:
        token = hf_hub_download_kwargs.get("token")
        if token is None:
            token = hf_hub_download_kwargs.get("use_auth_token")

        hub_filename = (os.path.join(hf_hub_download_kwargs["subfolder"],
                                     SAFETENSORS_WEIGHTS_NAME)
                        if hf_hub_download_kwargs.get("subfolder") is not None
                        else SAFETENSORS_WEIGHTS_NAME)
        has_remote_safetensors_file = file_exists(
            repo_id=model_id,
            filename=hub_filename,
            revision=hf_hub_download_kwargs.get("revision"),
            repo_type=hf_hub_download_kwargs.get("repo_type"),
            token=token,
        )
        use_safetensors = has_remote_safetensors_file

        if has_remote_safetensors_file:
            # Priority 1: load safetensors weights
            filename = hf_hub_download(
                model_id,
                SAFETENSORS_WEIGHTS_NAME,
                **hf_hub_download_kwargs,
            )
        else:
            try:
                filename = hf_hub_download(model_id, WEIGHTS_NAME,
                                           **hf_hub_download_kwargs)
            except EntryNotFoundError:
                raise ValueError(  # noqa: B904
                    f"Can't find weights for {model_id} in {model_id} or \
                    in the Hugging Face Hub. "
                    f"Please check that the file {WEIGHTS_NAME} or \
                    {SAFETENSORS_WEIGHTS_NAME} is present at {model_id}.")

    if use_safetensors:
        adapters_weights = safe_load_file(filename, device=device)
    else:
        adapters_weights = torch.load(filename,
                                      map_location=torch.device(device),
                                      weights_only=True)

    return adapters_weights
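A hedged usage sketch of the local-path branch in the loader above: prefer the safetensors file when it exists, otherwise fall back to the torch .bin checkpoint. The directory name is hypothetical; only standard safetensors and torch APIs are used.

import os

import torch
from safetensors.torch import load_file as safe_load_file

adapter_dir = "/tmp/my_prompt_adapter"  # hypothetical local checkpoint
st_path = os.path.join(adapter_dir, "adapter_model.safetensors")
bin_path = os.path.join(adapter_dir, "adapter_model.bin")

if os.path.exists(st_path):
    weights = safe_load_file(st_path, device="cpu")
elif os.path.exists(bin_path):
    weights = torch.load(bin_path, map_location="cpu", weights_only=True)
else:
    raise FileNotFoundError(f"No adapter weights found under {adapter_dir}")

print({name: tuple(t.shape) for name, t in weights.items()})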
@ -1,179 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import logging
from typing import Any, Optional, Set, Type

import torch

from vllm.adapter_commons.utils import (add_adapter_worker,
                                        apply_adapters_worker,
                                        list_adapters_worker,
                                        set_active_adapters_worker)
from vllm.adapter_commons.worker_manager import AbstractWorkerManager
from vllm.config import PromptAdapterConfig
from vllm.prompt_adapter.models import (LRUCachePromptAdapterModelManager,
                                        PromptAdapterModel,
                                        PromptAdapterModelManager,
                                        create_prompt_adapter_manager)
from vllm.prompt_adapter.request import PromptAdapterRequest

logger = logging.getLogger(__name__)


class WorkerPromptAdapterManager(AbstractWorkerManager):
    """WorkerPromptAdapterManager that manages
    prompt_adapter models on the worker side.

    Every request, the requested prompt_adapters will be
    loaded (unless they are already loaded),
    and every other prompt_adapter will be unloaded."""

    _manager_cls: Type[PromptAdapterModelManager] = PromptAdapterModelManager

    def __init__(
        self,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        device: torch.device,
        prompt_adapter_config: PromptAdapterConfig,
        prompt_adapter_model_cls: Type[PromptAdapterModel] = PromptAdapterModel
    ):
        self._adapter_manager: PromptAdapterModelManager
        self.max_num_seqs = max_num_seqs
        self.max_num_batched_tokens = max_num_batched_tokens
        self._prompt_adapter_model_cls = prompt_adapter_model_cls
        self.prompt_adapter_config = prompt_adapter_config
        super().__init__(device)

    @property
    def is_enabled(self) -> bool:
        return True

    def create_prompt_adapter_manager(
        self,
        model: torch.nn.Module,
    ) -> Any:
        prompt_adapter_manager = create_prompt_adapter_manager(
            model,
            max_num_seqs=self.max_num_seqs,
            max_num_batched_tokens=self.max_num_batched_tokens,
            prompt_adapter_config=self.prompt_adapter_config,
            prompt_adapter_manager_cls=self._manager_cls,
        )
        self._adapter_manager = prompt_adapter_manager
        return prompt_adapter_manager.model

    def _load_adapter(
        self, prompt_adapter_request: PromptAdapterRequest
    ) -> PromptAdapterModel:
        try:
            prompt_adapter = (
                self._prompt_adapter_model_cls.from_local_checkpoint(
                    prompt_adapter_request.prompt_adapter_local_path,
                    prompt_adapter_id=prompt_adapter_request.prompt_adapter_id,
                    num_virtual_tokens=prompt_adapter_request.
                    prompt_adapter_num_virtual_tokens,
                    config=self.prompt_adapter_config,
                    device=str(self.device),
                ))
        except Exception as e:
            raise RuntimeError(
                f"Loading prompt_adapter "
                f"{prompt_adapter_request.prompt_adapter_local_path}"
                f" failed") from e
        return prompt_adapter

    def add_dummy_prompt_adapter(
            self, prompt_adapter_request: PromptAdapterRequest) -> bool:
        return True

    def pin_adapter(self, adapter_id: int) -> bool:
        return self._adapter_manager.pin_adapter(adapter_id)

    def set_active_adapters(self, requests: Set[Any],
                            mapping: Optional[Any]) -> None:
        set_active_adapters_worker(requests, mapping, self._apply_adapters,
                                   self._adapter_manager.set_adapter_mapping)

    def add_adapter(self, adapter_request: Any) -> bool:
        return add_adapter_worker(adapter_request, self.list_adapters,
                                  self._load_adapter,
                                  self._adapter_manager.add_adapter,
                                  self._adapter_manager.activate_adapter)

    def _apply_adapters(self, adapter_requests: Set[Any]) -> None:
        apply_adapters_worker(adapter_requests, self.list_adapters,
                              self._adapter_manager.adapter_slots,
                              self.remove_adapter, self.add_adapter)

    def remove_adapter(self, adapter_id: int) -> bool:
        return self._adapter_manager.remove_adapter(adapter_id)

    def remove_all_adapters(self):
        self._adapter_manager.remove_all_adapters()

    def list_adapters(self) -> Set[int]:
        return list_adapters_worker(self._adapter_manager.list_adapters)


class LRUCacheWorkerPromptAdapterManager(WorkerPromptAdapterManager):
    """WorkerPromptAdapterManager that manages
    prompt_adapter models on the worker side.

    Uses an LRU Cache. Every request, the requested
    prompt_adapters will be loaded (unless they are already loaded)
    and least recently used prompt_adapters will
    be unloaded if the cache is above capacity."""

    _prompt_adapter_manager_cls: Type[
        LRUCachePromptAdapterModelManager] = LRUCachePromptAdapterModelManager

    def create_prompt_adapter_manager(
        self,
        model: torch.nn.Module,
    ) -> Any:
        prompt_adapter_manager = create_prompt_adapter_manager(
            model,
            max_num_seqs=self.max_num_seqs,
            max_num_batched_tokens=self.max_num_batched_tokens,
            prompt_adapter_config=self.prompt_adapter_config,
            prompt_adapter_manager_cls=self._prompt_adapter_manager_cls)
        self._adapter_manager: LRUCachePromptAdapterModelManager = (
            prompt_adapter_manager)
        return prompt_adapter_manager.model

    def _apply_adapters(
            self, prompt_adapter_requests: Set[PromptAdapterRequest]) -> None:
        prompt_adapters_map = {
            prompt_adapter_request.prompt_adapter_id: prompt_adapter_request
            for prompt_adapter_request in prompt_adapter_requests
            if prompt_adapter_request
        }
        if len(prompt_adapters_map
               ) > self._adapter_manager.prompt_adapter_slots:
            raise RuntimeError(
                f"Number of requested prompt_adapters "
                f"({len(prompt_adapters_map)}) is greater "
                "than the number of GPU prompt_adapter slots "
                f"({self._adapter_manager.prompt_adapter_slots}).")
        for prompt_adapter in prompt_adapters_map.values():
            self.add_adapter(prompt_adapter)

    def add_adapter(self,
                    prompt_adapter_request: PromptAdapterRequest) -> bool:
        if prompt_adapter_request.prompt_adapter_id not in self.list_adapters(
        ):
            # Remove before we load the new prompt_adapter to save memory
            if len(self._adapter_manager) + 1 > self._adapter_manager.capacity:
                self._adapter_manager.remove_oldest_adapter()
            prompt_adapter = self._load_adapter(prompt_adapter_request)
            loaded = self._adapter_manager.add_adapter(prompt_adapter)
        else:
            # If the prompt_adapter is already loaded, just touch it to
            # update its position in the caches
            loaded = self._adapter_manager.get_adapter(
                prompt_adapter_request.prompt_adapter_id) is not None
        self._adapter_manager.activate_adapter(
            prompt_adapter_request.prompt_adapter_id)
        return loaded
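The docstrings in the deleted worker manager above describe a load-the-requested, unload-the-rest policy. A generic stand-in sketch of that reconciliation, with invented names and no vLLM APIs:

def reconcile(loaded: set, requested: set, slots: int) -> set:
    """Return the set of adapter ids that should remain loaded."""
    if len(requested) > slots:
        raise RuntimeError(
            f"Requested {len(requested)} adapters but only {slots} slots")
    # Keep requested ids already present, load the rest; anything not in
    # the result is implicitly unloaded by the caller.
    return (loaded & requested) | (requested - loaded)


print(reconcile(loaded={1, 2, 3}, requested={2, 4}, slots=8))  # {2, 4}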
@ -6,6 +6,7 @@ from .deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
from .glm4_moe_reasoning_parser import Glm4MoeModelReasoningParser
from .granite_reasoning_parser import GraniteReasoningParser
from .hunyuan_a13b_reasoning_parser import HunyuanA13BReasoningParser
from .mistral_reasoning_parser import MistralReasoningParser
from .qwen3_reasoning_parser import Qwen3ReasoningParser

__all__ = [
@ -16,4 +17,5 @@ __all__ = [
    "HunyuanA13BReasoningParser",
    "Qwen3ReasoningParser",
    "Glm4MoeModelReasoningParser",
    "MistralReasoningParser",
]

vllm/reasoning/mistral_reasoning_parser.py (new file, 47 lines)
@ -0,0 +1,47 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser, ReasoningParserManager
from vllm.reasoning.deepseek_r1_reasoning_parser import (
    DeepSeekR1ReasoningParser)
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer

logger = init_logger(__name__)


@ReasoningParserManager.register_module("mistral")
class MistralReasoningParser(DeepSeekR1ReasoningParser):
    """
    Reasoning parser for Mistral models.

    The Mistral models uses [THINK]...[/THINK] tokens to denote reasoning
    text. This parser extracts the reasoning content from the model output.
    """

    def __init__(self, tokenizer: MistralTokenizer):
        if not isinstance(tokenizer, MistralTokenizer):
            raise ValueError(
                "The tokenizer must be an instance of MistralTokenizer.")

        ReasoningParser.__init__(self, tokenizer)

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ReasoningParser "
                "constructor during construction.")

        from mistral_common.tokens.tokenizers.base import SpecialTokens

        self.start_token = SpecialTokens.begin_think
        self.end_token = SpecialTokens.end_think

        self.start_token_id = tokenizer.tokenizer.get_control_token(
            self.start_token)
        self.end_token_id = tokenizer.tokenizer.get_control_token(
            self.end_token)

        if self.start_token_id is None or self.end_token_id is None:
            raise RuntimeError(
                "Mistral reasoning parser could not locate think start/end "
                "tokens in the tokenizer!")
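A minimal sketch (assumed token strings, plain Python) of how a [THINK]...[/THINK]-style parser can split reasoning content from the final answer; the real parser above operates on token ids resolved through mistral_common control tokens rather than on raw text.

def split_reasoning(text: str,
                    start: str = "[THINK]",
                    end: str = "[/THINK]") -> tuple:
    """Return (reasoning, answer) if both markers are present."""
    if start in text and end in text:
        before, rest = text.split(start, 1)
        reasoning, answer = rest.split(end, 1)
        return reasoning.strip(), (before + answer).strip()
    return "", text.strip()


reasoning, answer = split_reasoning("[THINK]2+2 is 4[/THINK]The answer is 4.")
print(reasoning)  # 2+2 is 4
print(answer)     # The answer is 4.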
@ -19,7 +19,6 @@ from vllm.inputs import SingletonInputs
from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalKwargs, MultiModalPlaceholderDict
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import RequestOutputKind, SamplingParams

VLLM_TOKEN_ID_ARRAY_TYPE = "l"
@ -458,7 +457,6 @@ class Sequence:
            block size used by the block manager and cache engine.
        eos_token_id: The end-of-sequence (EOS) token id recognized by this LLM.
        lora_request: LoRA request.
        prompt_adapter_request: Prompt Adapter request.
    """

    def __init__(
@ -468,14 +466,12 @@ class Sequence:
        block_size: int,
        eos_token_id: Optional[int] = None,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> None:
        self.seq_id = seq_id
        self.inputs = inputs
        self.block_size = block_size
        self.eos_token_id = eos_token_id
        self.lora_request = lora_request
        self.prompt_adapter_request = prompt_adapter_request

        self.data = SequenceData.from_seqs(
            self.prompt_token_ids,
@ -537,11 +533,6 @@ class Sequence:
    def lora_int_id(self) -> int:
        return self.lora_request.lora_int_id if self.lora_request else 0

    @property
    def prompt_adapter_id(self) -> int:
        return self.prompt_adapter_request.prompt_adapter_id \
            if self.prompt_adapter_request else 0

    def get_output_text_to_return(self, buffer_length: int,
                                  delta: bool) -> str:
        """If delta is True, only new text since the last call to
@ -601,12 +592,12 @@ class Sequence:
        designed for prefix caching mode. The final sequence hash is determined
        by applying token_ids from the sequence's blocks.
        """
        if self.prompt_adapter_id == 0 and self.lora_int_id == 0:
        if self.lora_int_id == 0:
            return None

        # NOTE: If there are additional factors influencing the block aside from
        # token_ids, include them as input parameters to the hash.
        return hash((self.prompt_adapter_id, self.lora_int_id))
        return hash(self.lora_int_id)

    def num_hashed_tokens_of_block(self, logical_idx: int):
        return logical_idx * self.block_size + self.block_size
@ -707,7 +698,6 @@ class SequenceGroup:
        encoder_seq: Optional, the single encoder sequence. Should be None
            unless you are working with an encoder/decoder model.
        trace_headers: OpenTelemetry trace headers.
        prompt_adapter_request: Prompt Adapter request.
        priority: User-defined priority of the request.
        draft_size: The number of speculative tokens plus one from the target
            model; equal to max number of tokens a step can generate
@ -725,7 +715,6 @@ class SequenceGroup:
        pooled_data: Optional[torch.Tensor] = None,
        encoder_seq: Optional[Sequence] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
        draft_size: int = 1) -> None:
        self.request_id = request_id
@ -747,7 +736,6 @@ class SequenceGroup:
        self.state = SequenceGroupState()
        self.pooling_params = pooling_params
        self.pooled_data = pooled_data
        self.prompt_adapter_request = prompt_adapter_request
        self.encoder_seq = encoder_seq
        self.trace_headers = trace_headers
        self.priority = priority
@ -802,16 +790,6 @@ class SequenceGroup:
    def lora_int_id(self) -> int:
        return self.lora_request.lora_int_id if self.lora_request else 0

    @property
    def prompt_adapter_id(self) -> int:
        return self.prompt_adapter_request.prompt_adapter_id \
            if self.prompt_adapter_request else 0

    @property
    def prompt_adapter_num_virtual_tokens(self) -> int:
        return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens\
            if self.prompt_adapter_request else 0

    def init_multi_step(self, num_steps: int) -> None:
        self.state.num_steps = num_steps
        self.state.current_step = 0
@ -1011,7 +989,6 @@ class SequenceGroupMetadata(
            (SequenceGroup.encoder_seq). Should be None
            unless you are working with an encoder/decoder
            model.
        prompt_adapter_request: Prompt Adapter request.
    """

    request_id: str
@ -1030,7 +1007,6 @@ class SequenceGroupMetadata(
    multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
    encoder_seq_data: Optional[SequenceData] = None
    cross_block_table: Optional[list[int]] = None
    prompt_adapter_request: Optional[PromptAdapterRequest] = None
    token_chunk_size: Optional[int] = None

    ### Stateful fields that are lazily defined. ###
@ -1052,16 +1028,6 @@ class SequenceGroupMetadata(
    def lora_int_id(self) -> int:
        return self.lora_request.lora_int_id if self.lora_request else 0

    @property
    def prompt_adapter_id(self) -> int:
        return self.prompt_adapter_request.prompt_adapter_id \
            if self.prompt_adapter_request else 0

    @property
    def prompt_adapter_num_virtual_tokens(self) -> int:
        return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens \
            if self.prompt_adapter_request else 0

    # Multi-Step Chunked-Prefill property
    @property
    def is_single_step_prompt(self) -> bool:
@ -1525,7 +1491,6 @@ class ParallelSampleSequenceGroup(SequenceGroupBase):
            pooled_data=seq_group.pooled_data,
            encoder_seq=seq_group.encoder_seq,
            trace_headers=seq_group.trace_headers,
            prompt_adapter_request=seq_group.prompt_adapter_request,
            priority=seq_group.priority,
        )

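Hedged sketch of the extra-factor hashing change shown above: with prompt adapters gone, a sequence with no LoRA attached contributes no extra component to the block hash, while a LoRA id is still folded in so prefix-cache entries from different adapters never collide. Plain Python, no vLLM types assumed.

from typing import Optional


def extra_hash(lora_int_id: int) -> Optional[int]:
    if lora_int_id == 0:
        return None  # block hash depends on token ids only
    return hash(lora_int_id)  # adapter-specific component mixed in


print(extra_hash(0))  # None
print(extra_hash(7))  # per-adapter hash component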
@ -145,6 +145,21 @@ def find_tokenizer_file(files: list[str]):
    return matched_files[0]


def _aggregate_content(content: list) -> list[dict[str, Any]]:
    aggregated_content: list[dict[str, Any]] = []
    for chunk in content:
        if chunk.get("type"
                     ) == "text" and aggregated_content and aggregated_content[
                         -1].get("type") == "text":
            aggregated_content[-1]["text"] += "\n\n" + chunk.get("text")
        else:
            aggregated_content.append(chunk)
    if len(aggregated_content) == 1 and aggregated_content[0].get(
            "type") == "text":
        content = aggregated_content[0]["text"]
    return content


def make_mistral_chat_completion_request(
        messages: list["ChatCompletionMessageParam"],
        tools: Optional[list[dict[str,
@ -162,10 +177,10 @@ def make_mistral_chat_completion_request(

        # Convert list text content to string
        if message.get("role") in ("assistant", "tool"):
            content = message.get("content")
            content: Any = message.get("content")
            if isinstance(content, list):
                content = "\n".join(chunk.get("text") for chunk in content)
            message["content"] = content
                content = _aggregate_content(content)
            message["content"] = content

    # The Mistral client, in comparison to the OpenAI client, requires the
    # "parameters" dict to be present, even if it's empty.
@ -465,6 +480,8 @@ class MistralTokenizer(TokenizerBase):
        skip_special_tokens: bool = True,
    ) -> list[str]:
        from mistral_common.tokens.tokenizers.base import SpecialTokens
        from mistral_common.tokens.tokenizers.instruct import (
            InstructTokenizerV13)

        # TODO(Patrick) - potentially allow special tokens to not be skipped
        assert (
@ -474,10 +491,18 @@ class MistralTokenizer(TokenizerBase):
        assert self.is_tekken or self.is_spm, type(self.tokenizer)

        if self.is_tekken:
            # skip special tokens except tool call
            ids = [
                i for i in ids if i > self.tokenizer.num_special_tokens or i ==
            # skip special tokens except tool call and think tokens
            non_skip_special_tokens = {
                self.tokenizer.get_control_token(SpecialTokens.tool_calls)
            }
            if isinstance(self.instruct, InstructTokenizerV13):
                if self.instruct.BEGIN_THINK:
                    non_skip_special_tokens.add(self.instruct.BEGIN_THINK)
                if self.instruct.END_THINK:
                    non_skip_special_tokens.add(self.instruct.END_THINK)
            ids = [
                i for i in ids if i > self.tokenizer.num_special_tokens
                or i in non_skip_special_tokens
            ]

            tokens = [self.tokenizer.id_to_piece(id) for id in ids]
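Behavioural sketch of the _aggregate_content helper added above: consecutive "text" chunks are merged with a blank line, and a single remaining text chunk collapses to a plain string. The inputs here are made up; this is not the vLLM function itself.

chunks = [
    {"type": "text", "text": "First paragraph."},
    {"type": "text", "text": "Second paragraph."},
]

merged = []
for chunk in chunks:
    if chunk["type"] == "text" and merged and merged[-1]["type"] == "text":
        merged[-1]["text"] += "\n\n" + chunk["text"]
    else:
        merged.append(dict(chunk))

if len(merged) == 1 and merged[0]["type"] == "text":
    content = merged[0]["text"]
else:
    content = merged
print(content)  # "First paragraph.\n\nSecond paragraph."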
@ -128,10 +128,6 @@ STR_NOT_IMPL_ENC_DEC_BACKEND = ("XFormers and Flash-Attention are the only "
                                "backends currently supported with encoder/"
                                "decoder models.")

STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER = ("Prompt adapters are not "
                                       "currently supported with encoder/"
                                       "decoder models.")

# Efficiently import all enc/dec error strings
# rather than having to import all of the above
STR_NOT_IMPL_ENC_DEC_ERR_STRS = {
@ -145,7 +141,6 @@ STR_NOT_IMPL_ENC_DEC_ERR_STRS = {
    "STR_NOT_IMPL_ENC_DEC_MM": STR_NOT_IMPL_ENC_DEC_MM,
    "STR_NOT_IMPL_ENC_DEC_SPEC_DEC": STR_NOT_IMPL_ENC_DEC_SPEC_DEC,
    "STR_NOT_IMPL_ENC_DEC_BACKEND": STR_NOT_IMPL_ENC_DEC_BACKEND,
    "STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER": STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER,
}

# Constants related to forcing the attention backend selection

@ -20,7 +20,6 @@ from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.config import (
    maybe_register_config_serialize_by_value)
@ -94,11 +93,14 @@ class AsyncLLM(EngineClient):
        self.log_requests = log_requests
        self.log_stats = log_stats

        # Tokenizer (+ ensure liveness if running in another process).
        self.tokenizer = init_tokenizer_from_configs(
            model_config=vllm_config.model_config,
            scheduler_config=vllm_config.scheduler_config,
            lora_config=vllm_config.lora_config)
        if self.model_config.skip_tokenizer_init:
            self.tokenizer = None
        else:
            # Tokenizer (+ ensure liveness if running in another process).
            self.tokenizer = init_tokenizer_from_configs(
                model_config=vllm_config.model_config,
                scheduler_config=vllm_config.scheduler_config,
                lora_config=vllm_config.lora_config)

        # Processor (converts Inputs --> EngineCoreRequests).
        self.processor = Processor(
@ -125,7 +127,7 @@ class AsyncLLM(EngineClient):
        if self.log_stats:
            self.logger_manager = StatLoggerManager(
                vllm_config=vllm_config,
                engine_idxs=self.engine_core.engine_ranks,
                engine_idxs=self.engine_core.engine_ranks_managed,
                custom_stat_loggers=stat_loggers,
            )
            self.logger_manager.log_engine_initialized()
@ -218,7 +220,6 @@ class AsyncLLM(EngineClient):
        lora_request: Optional[LoRARequest] = None,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
        data_parallel_rank: Optional[int] = None,
    ) -> RequestOutputCollector:
@ -235,8 +236,7 @@ class AsyncLLM(EngineClient):
        # Convert Input --> Request.
        prompt_str, request = self.processor.process_inputs(
            request_id, prompt, params, arrival_time, lora_request,
            tokenization_kwargs, trace_headers, prompt_adapter_request,
            priority, data_parallel_rank)
            tokenization_kwargs, trace_headers, priority, data_parallel_rank)

        if is_pooling or params.n == 1:
            await self._add_request(request, prompt_str, None, 0, queue)
@ -280,7 +280,6 @@ class AsyncLLM(EngineClient):
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
        data_parallel_rank: Optional[int] = None,
    ) -> AsyncGenerator[RequestOutput, None]:
@ -311,7 +310,6 @@ class AsyncLLM(EngineClient):
            sampling_params,
            lora_request=lora_request,
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request,
            priority=priority,
            data_parallel_rank=data_parallel_rank,
        )
@ -525,6 +523,10 @@ class AsyncLLM(EngineClient):
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        if self.tokenizer is None:
            raise ValueError("Unable to get tokenizer because "
                             "skip_tokenizer_init is True")

        return self.tokenizer.get_lora_tokenizer(lora_request)

    async def is_tracing_enabled(self) -> bool:
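Minimal sketch of the guard pattern introduced above: when tokenizer initialization is skipped, keep the attribute as None and fail loudly only on access. Class and attribute names here are illustrative, not vLLM's.

from typing import Optional


class TokenizerHolder:

    def __init__(self, skip_tokenizer_init: bool):
        self.tokenizer: Optional[object] = (None if skip_tokenizer_init
                                            else object())

    def get_tokenizer(self) -> object:
        if self.tokenizer is None:
            raise ValueError("Unable to get tokenizer because "
                             "skip_tokenizer_init is True")
        return self.tokenizer


holder = TokenizerHolder(skip_tokenizer_init=True)
try:
    holder.get_tokenizer()
except ValueError as e:
    print(e)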
@ -61,11 +61,12 @@ class DPCoordinator:

        host = parallel_config.data_parallel_master_ip
        external_lb = parallel_config.data_parallel_external_lb
        hybrid_lb = parallel_config.data_parallel_hybrid_lb

        # Assume coordinator is colocated with front-end procs when not in
        # external DP LB mode.
        # either external or hybrid DP LB mode.
        front_publish_address = get_engine_client_zmq_addr(
            local_only=not external_lb, host=host)
            local_only=not external_lb and not hybrid_lb, host=host)

        local_only_eng = dp_size == parallel_config.data_parallel_size_local
        back_publish_address = get_engine_client_zmq_addr(local_only_eng, host)
@ -234,9 +234,14 @@ class EngineCore:
        self.scheduler.finish_requests(request_ids,
                                       RequestStatus.FINISHED_ABORTED)

    def execute_model(self, scheduler_output: SchedulerOutput):
    def execute_model_with_error_logging(
        self,
        model_fn: Callable[[SchedulerOutput], ModelRunnerOutput],
        scheduler_output: SchedulerOutput,
    ) -> ModelRunnerOutput:
        """Execute the model and log detailed info on failure."""
        try:
            return self.model_executor.execute_model(scheduler_output)
            return model_fn(scheduler_output)
        except Exception as err:
            # We do not want to catch BaseException here since we're only
            # interested in dumping info when the exception is due to an
@ -259,7 +264,9 @@ class EngineCore:
        if not self.scheduler.has_requests():
            return {}, False
        scheduler_output = self.scheduler.schedule()
        model_output = self.execute_model(scheduler_output)
        model_output = self.execute_model_with_error_logging(
            self.model_executor.execute_model,  # type: ignore
            scheduler_output)
        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, model_output)  # type: ignore

@ -306,8 +313,11 @@ class EngineCore:
        # so we need more work.
        if not scheduled_batch and not self.batch_queue.empty():
            future, scheduler_output = self.batch_queue.get_nowait()

            # Blocking until the first result is available.
            model_output = future.result()
            model_output = self.execute_model_with_error_logging(
                lambda _: future.result(), scheduler_output)

            self.batch_queue.task_done()
            engine_core_outputs = (self.scheduler.update_from_output(
                scheduler_output, model_output))
@ -467,13 +477,14 @@ class EngineCoreProc(EngineCore):
        For DP>1 with internal loadbalancing this is with the shared front-end
        process which may reside on a different node.

        For DP>1 with external loadbalancing, two handshakes are performed:
        For DP>1 with external or hybrid loadbalancing, two handshakes are
        performed:
        - With the rank 0 front-end process which retrieves the
          DP Coordinator ZMQ addresses and DP process group address.
        - With the colocated front-end process which retrieves the
          client input/output socket addresses.
        with the exception of the rank 0 engine itself which doesn't require
        the second handshake.
        with the exception of the rank 0 and colocated engines themselves which
        don't require the second handshake.

        Here, "front-end" process can mean the process containing the engine
        core client (which is the API server process in the case the API
@ -482,15 +493,18 @@ class EngineCoreProc(EngineCore):
        """
        input_ctx = zmq.Context()
        is_local = local_client and client_handshake_address is None
        headless = not local_client
        handshake = self._perform_handshake(input_ctx, handshake_address,
                                            identity, is_local, vllm_config,
                                            identity, is_local, headless,
                                            vllm_config,
                                            vllm_config.parallel_config)
        if client_handshake_address is None:
            with handshake as addresses:
                yield addresses
        else:
            assert local_client
            local_handshake = self._perform_handshake(
                input_ctx, client_handshake_address, identity, local_client,
                input_ctx, client_handshake_address, identity, True, False,
                vllm_config)
            with handshake as addresses, local_handshake as client_addresses:
                addresses.inputs = client_addresses.inputs
@ -507,6 +521,7 @@ class EngineCoreProc(EngineCore):
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        headless: bool,
        vllm_config: VllmConfig,
        parallel_config_to_update: Optional[ParallelConfig] = None,
    ) -> Generator[EngineZmqAddresses, None, None]:
@ -518,6 +533,7 @@ class EngineCoreProc(EngineCore):
                            bind=False) as handshake_socket:
            # Register engine with front-end.
            addresses = self.startup_handshake(handshake_socket, local_client,
                                               headless,
                                               parallel_config_to_update)
            yield addresses

@ -531,6 +547,7 @@ class EngineCoreProc(EngineCore):
                msgspec.msgpack.encode({
                    "status": "READY",
                    "local": local_client,
                    "headless": headless,
                    "num_gpu_blocks": num_gpu_blocks,
                    "dp_stats_address": dp_stats_address,
                }))
@ -539,6 +556,7 @@ class EngineCoreProc(EngineCore):
    def startup_handshake(
        handshake_socket: zmq.Socket,
        local_client: bool,
        headless: bool,
        parallel_config: Optional[ParallelConfig] = None,
    ) -> EngineZmqAddresses:

@ -547,6 +565,7 @@ class EngineCoreProc(EngineCore):
            msgspec.msgpack.encode({
                "status": "HELLO",
                "local": local_client,
                "headless": headless,
            }))

        # Receive initialization message.
@ -891,22 +910,6 @@ class DPEngineCoreProc(EngineCoreProc):
            logger.debug("Setting kv_transfer_config.engine_id to %s",
                         vllm_config.kv_transfer_config.engine_id)

        from vllm.platforms import current_platform
        device_control_env_var = current_platform.device_control_env_var
        world_size = vllm_config.parallel_config.world_size
        # Set CUDA_VISIBLE_DEVICES or equivalent.
        try:
            os.environ[device_control_env_var] = ",".join(
                str(current_platform.device_id_to_physical_device_id(i))
                for i in range(local_dp_rank *
                               world_size, (local_dp_rank + 1) * world_size))
        except IndexError as e:
            raise Exception(
                f"Error setting {device_control_env_var}: "
                f"local range: [{local_dp_rank * world_size}, "
                f"{(local_dp_rank + 1) * world_size}) "
                f"base value: \"{os.getenv(device_control_env_var)}\"") from e

        self.dp_rank = dp_rank
        self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()

@ -1069,14 +1072,41 @@ class DPEngineCoreActor(DPEngineCoreProc):
            vllm_config.parallel_config.data_parallel_rank_local = \
                local_dp_rank

        # Ray sets CUDA_VISIBLE_DEVICES to empty string,
        # we clean this up to be able to properly initialize
        # data parallel groups.
        del os.environ['CUDA_VISIBLE_DEVICES']
        # Set CUDA_VISIBLE_DEVICES as early as possible in actor life cycle
        # NOTE: in MP we set CUDA_VISIBLE_DEVICES at process creation time,
        # and this cannot be done in the same way for Ray because:
        # 1) Ray manages life cycle of all ray workers (including
        #    DPEngineCoreActor)
        # 2) Ray sets CUDA_VISIBLE_DEVICES based on num_gpus configuration
        # To bypass 2, we need to also set
        # RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES, but vLLM workers created
        # thereafter would have CUDA_VISIBLE_DEVICES set, which is sticky:
        # https://github.com/ray-project/ray/blob/e752fc319ddedd9779a0989b6d3613909bad75c9/python/ray/_private/worker.py#L456 # noqa: E501
        # But vLLM worker assumes visibility into all local GPUs, therefore
        # this results in incorrect indexing into the GPU ID list.
        self._set_cuda_visible_devices(vllm_config, local_dp_rank)

        super().__init__(vllm_config, local_client, "", executor_class,
                         log_stats)

    def _set_cuda_visible_devices(self, vllm_config: VllmConfig,
                                  local_dp_rank: int):
        from vllm.platforms import current_platform
        device_control_env_var = current_platform.device_control_env_var
        world_size = vllm_config.parallel_config.world_size
        # Set CUDA_VISIBLE_DEVICES or equivalent.
        try:
            os.environ[device_control_env_var] = ",".join(
                str(current_platform.device_id_to_physical_device_id(i))
                for i in range(local_dp_rank *
                               world_size, (local_dp_rank + 1) * world_size))
        except IndexError as e:
            raise Exception(
                f"Error setting {device_control_env_var}: "
                f"local range: [{local_dp_rank * world_size}, "
                f"{(local_dp_rank + 1) * world_size}) "
                f"base value: \"{os.getenv(device_control_env_var)}\"") from e

    def _decorate_logs(self):
        pass

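Generic sketch of the execute_model_with_error_logging pattern above: wrap any callable so that a failure is logged with its batch context before the original exception propagates. Logger name and dump format are assumptions, not vLLM's.

import logging
from typing import Callable, TypeVar

T = TypeVar("T")
R = TypeVar("R")
logger = logging.getLogger("engine")


def run_with_error_logging(fn: Callable[[T], R], batch: T) -> R:
    try:
        return fn(batch)
    except Exception:
        # Dump context for debugging, then re-raise the original error.
        logger.exception("Model execution failed for batch: %r", batch)
        raise


print(run_with_error_logging(lambda b: len(b), [1, 2, 3]))  # 3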
@ -429,18 +429,23 @@ class MPClient(EngineCoreClient):
        parallel_config = vllm_config.parallel_config
        dp_size = parallel_config.data_parallel_size
        dp_rank = parallel_config.data_parallel_rank
        external_dp_lb = parallel_config.data_parallel_external_lb

        dp_local_size = parallel_config.data_parallel_size_local
        offline_mode = parallel_config.data_parallel_rank_local is not None
        self.engine_ranks = ([dp_rank] if
                             (offline_mode or external_dp_lb) else list(
                                 range(dp_size)))
        # Client manages local+remote EngineCores in pure internal LB case.
        # Client manages local EngineCores in hybrid and external LB case.
        local_engines_only = (parallel_config.data_parallel_hybrid_lb
                              or parallel_config.data_parallel_external_lb)

        num_ranks = dp_local_size if local_engines_only else dp_size
        self.engine_ranks_managed = [dp_rank] if offline_mode else list(
            range(dp_rank, dp_rank + num_ranks))
        assert parallel_config.data_parallel_size_local <= len(
            self.engine_ranks)
            self.engine_ranks_managed)

        # ZMQ identity of each engine that this client will talk to.
        self.core_engines: list[EngineIdentity] = [
            index.to_bytes(2, "little") for index in self.engine_ranks
            rank.to_bytes(2, "little")
            for rank in self.engine_ranks_managed
        ]

        # Wait for ready messages from each engine on the input socket.
@ -895,6 +900,12 @@ class DPAsyncMPClient(AsyncMPClient):
            return

        assert self.stats_update_address is not None
        assert len(self.engine_ranks_managed) > 0
        # NOTE: running and waiting counts are all global from
        # the Coordinator include all global EngineCores. This
        # slice includes just the cores managed by this client.
        count_slice = slice(self.engine_ranks_managed[0],
                            self.engine_ranks_managed[-1] + 1)

        async def run_engine_stats_update_task():
            with make_zmq_socket(self.ctx, self.stats_update_address,
@ -959,7 +970,7 @@ class DPAsyncMPClient(AsyncMPClient):
                        counts, wave, running = msgspec.msgpack.decode(buf)
                        self.current_wave = wave
                        self.engines_running = running
                        self.lb_engines = counts
                        self.lb_engines = counts[count_slice]

        resources.stats_update_task = asyncio.create_task(
            run_engine_stats_update_task())
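Sketch of the count-slicing logic above: the coordinator publishes one queue-length entry per global engine rank, and each client keeps only the slice covering the ranks it manages. The numbers below are made up for illustration.

global_counts = [4, 9, 2, 7]      # one entry per DP rank, from the coordinator
engine_ranks_managed = [2, 3]     # ranks owned by this front-end

count_slice = slice(engine_ranks_managed[0], engine_ranks_managed[-1] + 1)
lb_engines = global_counts[count_slice]
print(lb_engines)  # [2, 7]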
@ -17,7 +17,6 @@ from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import (
    TokenizerGroup, init_tokenizer_from_configs)
@ -82,11 +81,14 @@ class LLMEngine:
        self.dp_group = None
        self.should_execute_dummy_batch = False

        # Tokenizer (+ ensure liveness if running in another process).
        self.tokenizer = init_tokenizer_from_configs(
            model_config=vllm_config.model_config,
            scheduler_config=vllm_config.scheduler_config,
            lora_config=vllm_config.lora_config)
        if self.model_config.skip_tokenizer_init:
            self.tokenizer = None
        else:
            # Tokenizer (+ ensure liveness if running in another process).
            self.tokenizer = init_tokenizer_from_configs(
                model_config=vllm_config.model_config,
                scheduler_config=vllm_config.scheduler_config,
                lora_config=vllm_config.lora_config)

        # Processor (convert Inputs --> EngineCoreRequests)
        self.processor = Processor(vllm_config=vllm_config,
@ -189,7 +191,6 @@ class LLMEngine:
        lora_request: Optional[LoRARequest] = None,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        # Validate the request_id type.
@ -200,8 +201,7 @@ class LLMEngine:
        # Process raw inputs into the request.
        prompt_str, request = self.processor.process_inputs(
            request_id, prompt, params, arrival_time, lora_request,
            tokenization_kwargs, trace_headers, prompt_adapter_request,
            priority)
            tokenization_kwargs, trace_headers, priority)

        n = params.n if isinstance(params, SamplingParams) else 1

@ -327,14 +327,16 @@ class OutputProcessor:
        if request_id in self.request_states:
            raise ValueError(f"Request id {request_id} already running.")

        req_state = RequestState.from_new_request(
            tokenizer=self.tokenizer.get_lora_tokenizer(request.lora_request),
            request=request,
            prompt=prompt,
            parent_req=parent_req,
            request_index=request_index,
            queue=queue,
            log_stats=self.log_stats)
        tokenizer = None if not self.tokenizer else \
            self.tokenizer.get_lora_tokenizer(request.lora_request)

        req_state = RequestState.from_new_request(tokenizer=tokenizer,
                                                  request=request,
                                                  prompt=prompt,
                                                  parent_req=parent_req,
                                                  request_index=request_index,
                                                  queue=queue,
                                                  log_stats=self.log_stats)
        self.request_states[request_id] = req_state
        self.lora_states.add_request(req_state)
        if parent_req:
|
||||
from vllm.multimodal.processing import EncDecMultiModalProcessor
|
||||
from vllm.multimodal.utils import merge_and_sort_multimodal_metadata
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
@ -226,7 +225,6 @@ class Processor:
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
priority: int = 0,
|
||||
data_parallel_rank: Optional[int] = None,
|
||||
) -> tuple[Optional[str], EngineCoreRequest]:
|
||||
@ -237,8 +235,6 @@ class Processor:
|
||||
self._validate_params(params, lora_request)
|
||||
if trace_headers is not None:
|
||||
raise ValueError("V1 does not support tracing yet.")
|
||||
if prompt_adapter_request is not None:
|
||||
raise ValueError("V1 does not support prompt_adapter_request.")
|
||||
|
||||
data_parallel_size = self.vllm_config.parallel_config.data_parallel_size
|
||||
if data_parallel_rank is not None and not (0 <= data_parallel_rank <
|
||||
@ -253,12 +249,10 @@ class Processor:
|
||||
# 1. Tokenize text prompt, with LoRA request if one exists.
|
||||
# 2. For multimodal models with a merged preprocessor, preprocess
|
||||
# multimodal data and expand prompt token ids accordingly.
|
||||
# 3. Apply prompt adapter to prompt token ids if one exists.
|
||||
processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
|
||||
prompt,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
return_mm_hashes=self.use_hash,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
@ -380,7 +374,6 @@ class Processor:
|
||||
prompt_type: Literal["encoder", "decoder"],
|
||||
):
|
||||
model_config = self.model_config
|
||||
tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
|
||||
|
||||
prompt_ids = prompt_inputs["prompt_token_ids"]
|
||||
if not prompt_ids:
|
||||
@ -389,9 +382,14 @@ class Processor:
|
||||
else:
|
||||
raise ValueError(f"The {prompt_type} prompt cannot be empty")
|
||||
|
||||
max_input_id = max(prompt_ids, default=0)
|
||||
if max_input_id > tokenizer.max_token_id:
|
||||
raise ValueError(f"Token id {max_input_id} is out of vocabulary")
|
||||
if self.model_config.skip_tokenizer_init:
|
||||
tokenizer = None
|
||||
else:
|
||||
tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
|
||||
max_input_id = max(prompt_ids, default=0)
|
||||
if max_input_id > tokenizer.max_token_id:
|
||||
raise ValueError(
|
||||
f"Token id {max_input_id} is out of vocabulary")
|
||||
|
||||
max_prompt_len = self.model_config.max_model_len
|
||||
if len(prompt_ids) > max_prompt_len:
|
||||
|
@ -10,12 +10,14 @@ from enum import Enum, auto
from multiprocessing import Process, connection
from multiprocessing.process import BaseProcess
from typing import TYPE_CHECKING, Callable, Optional, Union
from unittest.mock import patch

import msgspec
import zmq

from vllm.config import CacheConfig, ParallelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.ray.ray_env import get_env_vars_to_copy
from vllm.utils import get_mp_context, get_open_zmq_ipc_path, zmq_socket_ctx
from vllm.v1.engine.coordinator import DPCoordinator
@ -105,10 +107,13 @@ class CoreEngineProcManager:
                    "client_handshake_address"] = client_handshake_address

        self.processes: list[BaseProcess] = []
        local_dp_ranks = []
        for index in range(local_engine_count):
            local_index = local_start_index + index
            global_index = start_index + index

            # Start EngineCore in background process.
            local_dp_ranks.append(local_index)
            self.processes.append(
                context.Process(target=target_fn,
                                name=f"EngineCore_{global_index}",
@ -118,9 +123,14 @@ class CoreEngineProcManager:
                                }))

        self._finalizer = weakref.finalize(self, shutdown, self.processes)

        data_parallel = vllm_config.parallel_config.data_parallel_size > 1
        try:
            for proc in self.processes:
                proc.start()
            for proc, local_dp_rank in zip(self.processes, local_dp_ranks):
                with set_device_control_env_var(
                        vllm_config, local_dp_rank) if (
                            data_parallel) else contextlib.nullcontext():
                    proc.start()
        finally:
            # Kill other procs if not all are running.
            if self.finished_procs():
@ -145,6 +155,30 @@ class CoreEngineProcManager:
        }


@contextlib.contextmanager
def set_device_control_env_var(vllm_config: VllmConfig,
                               local_dp_rank: int) -> Iterator[None]:
    """
    Temporarily set CUDA_VISIBLE_DEVICES or equivalent
    for engine subprocess.
    """
    world_size = vllm_config.parallel_config.world_size
    evar = current_platform.device_control_env_var
    try:
        value = ",".join(
            str(current_platform.device_id_to_physical_device_id(i))
            for i in range(local_dp_rank * world_size, (local_dp_rank + 1) *
                           world_size))
    except IndexError as e:
        raise Exception(f"Error setting {evar}: "
                        f"local range: [{local_dp_rank * world_size}, "
                        f"{(local_dp_rank + 1) * world_size}) "
                        "base value: "
                        f"\"{os.getenv(evar)}\"") from e
    with patch.dict(os.environ, values=((evar, value), )):
        yield


class CoreEngineActorManager:
    """
    Utility class to handle creation, readiness, and shutdown
@ -215,10 +249,9 @@ class CoreEngineActorManager:

        self.placement_group_is_local = []
        refs = []
        for index in range(dp_size):
            local_index = local_dp_ranks[index]
        for index, local_index, pg in zip(range(dp_size), local_dp_ranks,
                                          placement_groups):
            dp_vllm_config = copy.deepcopy(vllm_config)
            pg = placement_groups[index]
            dp_vllm_config.parallel_config.placement_group = pg
            local_client = index < local_engine_count
            actor = ray.remote(DPEngineCoreActor).options(
@ -264,7 +297,6 @@ class CoreEngineActorManager:
        local_engine_count = \
            vllm_config.parallel_config.data_parallel_size_local

        nodes = list_nodes()
        nodes = sorted(list_nodes(),
                       key=lambda node: node.node_ip != dp_master_ip)
        assert nodes[0].node_ip == dp_master_ip, (
@ -544,7 +576,8 @@ def launch_core_engines(
    local_start_index = parallel_config.data_parallel_rank_local
    dp_rank = parallel_config.data_parallel_rank
    host = parallel_config.data_parallel_master_ip
    external_dp_lb = parallel_config.data_parallel_external_lb
    local_engines_only = (parallel_config.data_parallel_hybrid_lb
                          or parallel_config.data_parallel_external_lb)

    # In offline mode there is an LLM instance per DP rank and
    # one core engine per LLM, see
@ -553,8 +586,8 @@ def launch_core_engines(

    # client_local_only = True for cases where this front-end
    # sends requests only to colocated engines.
    client_local_only = offline_mode or external_dp_lb or (local_engine_count
                                                           == dp_size)
    client_local_only = (offline_mode or local_engines_only
                         or (local_engine_count == dp_size))

    # Set up input and output addresses.
    addresses = EngineZmqAddresses(
@ -598,14 +631,27 @@ def launch_core_engines(
        yield engine_actor_manager, coordinator, addresses
        return

    if offline_mode or (external_dp_lb and dp_rank > 0):
    if offline_mode:
        assert local_engine_count == 1
        engines_to_handshake = [CoreEngine(index=dp_rank, local=True)]
    else:
    elif dp_rank == 0:
        # Rank 0 holds Coordinator, so it handshakes with all Cores
        # in both external dplb and internal dplb mode.
        # Note this also covers the case where we have zero local engines
        # and rank 0 is headless.
        engines_to_handshake = [
            CoreEngine(index=i, local=(i < local_engine_count))
            for i in range(dp_size)
        ]
    else:
        # Rank > 0 handshakes with just the local cores it is managing.
        assert local_engines_only, (
            "Attempting to launch core_engines from dp_rank > 0, but "
            "found internal DPLB, which is incompatible.")
        engines_to_handshake = [
            CoreEngine(index=i, local=True)
            for i in range(dp_rank, dp_rank + local_engine_count)
        ]

    # Whether the started engines will handshake only with co-located
    # front-end processes. In external_dp_lb mode, ranks > 0 handshake with
@ -616,7 +662,7 @@ def launch_core_engines(
    handshake_address = get_engine_client_zmq_addr(
        handshake_local_only, host, parallel_config.data_parallel_rpc_port)

    if external_dp_lb and dp_rank > 0:
    if local_engines_only and dp_rank > 0:
        assert not handshake_local_only
        local_handshake_address = get_open_zmq_ipc_path()
        client_handshake_address = local_handshake_address
@ -631,8 +677,6 @@ def launch_core_engines(

    # Start local engines.
    if local_engine_count:
        # In server mode, start_index and local_start_index will
        # both be 0.
        local_engine_manager = CoreEngineProcManager(
            EngineCoreProc.run_engine_core,
            vllm_config=vllm_config,
@ -678,6 +722,9 @@ def wait_for_engine_startup(
    poller = zmq.Poller()
    poller.register(handshake_socket, zmq.POLLIN)

    remote_should_be_headless = not parallel_config.data_parallel_hybrid_lb \
        and not parallel_config.data_parallel_external_lb

    if proc_manager is not None:
        for sentinel in proc_manager.sentinels():
            poller.register(sentinel, zmq.POLLIN)
@ -713,13 +760,24 @@ def wait_for_engine_startup(
        raise RuntimeError(f"Message from engine with unexpected data "
                           f"parallel rank: {eng_index}")
    msg = msgspec.msgpack.decode(ready_msg_bytes)
    status, local = msg["status"], msg["local"]
    status, local, headless = msg["status"], msg["local"], msg["headless"]
    if local != engine.local:
        raise RuntimeError(f"{status} message from "
                           f"{'local' if local else 'remote'} "
                           f"engine {eng_index}, expected it to be "
                           f"{'local' if engine.local else 'remote'}")

    # Remote engines must be headless iff we aren't in hybrid dp lb mode.
    if not local and headless != remote_should_be_headless:
        if headless:
            raise RuntimeError(f"Remote engine {eng_index} must not use "
                               f"--headless in external or hybrid dp lb "
                               f"mode")
        else:
            raise RuntimeError(f"Remote engine {eng_index} must use "
                               f"--headless unless in external or hybrid "
                               f"dp lb mode")

    if status == "HELLO" and engine.state == CoreEngineState.NEW:

        # Send init message with DP config info.
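Illustrative stand-in for the set_device_control_env_var helper above: patch.dict keeps the environment change scoped to the child-process launch and restores the previous value afterwards. The variable name and rank arithmetic mirror the code, but the device ids printed here are invented.

import contextlib
import os
from typing import Iterator
from unittest.mock import patch


@contextlib.contextmanager
def scoped_visible_devices(local_dp_rank: int,
                           world_size: int) -> Iterator[None]:
    value = ",".join(
        str(i) for i in range(local_dp_rank * world_size,
                              (local_dp_rank + 1) * world_size))
    with patch.dict(os.environ, values={"CUDA_VISIBLE_DEVICES": value}):
        yield


with scoped_visible_devices(local_dp_rank=1, world_size=2):
    print(os.environ["CUDA_VISIBLE_DEVICES"])  # "2,3"
print(os.environ.get("CUDA_VISIBLE_DEVICES"))  # previous value restored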
@ -318,8 +318,6 @@ def report_usage_stats(
        # Feature flags
        "enable_lora":
        bool(vllm_config.lora_config),
        "enable_prompt_adapter":
        bool(vllm_config.prompt_adapter_config),
        "enable_prefix_caching":
        vllm_config.cache_config.enable_prefix_caching,
        "enforce_eager":
@ -104,7 +104,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.speculative_config = vllm_config.speculative_config
        self.prompt_adapter_config = vllm_config.prompt_adapter_config
        self.observability_config = vllm_config.observability_config

        from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
@ -126,6 +125,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):

        self.is_multimodal_model = model_config.is_multimodal_model
        self.is_pooling_model = model_config.pooler_config is not None
        self.model_supports_multimodal_raw_input = (
            model_config.model_supports_multimodal_raw_input)
        self.max_model_len = model_config.max_model_len
        self.max_num_tokens = scheduler_config.max_num_batched_tokens
        self.max_num_reqs = scheduler_config.max_num_seqs
@ -328,6 +329,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        Args:
            scheduler_output: The scheduler output.
        """
        # Attention-free models have zero kv_cache_groups, however models
        # like Mamba are also attention free but use the kv_cache for
        # keeping its internal state. This is why we check the number
        # of kv_cache groups instead of solely checking
        # for self.model_config.is_attention_free.
        if len(self.kv_cache_config.kv_cache_groups) == 0:
            return

        self.attn_metadata_builders[0].reorder_batch(self.input_batch,
                                                     scheduler_output)

@ -565,6 +574,38 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        # Refresh batch metadata with any pending updates.
        self.input_batch.refresh_metadata()

    def _init_model_kwargs_for_multimodal_model(
        self,
        scheduler_output: Optional["SchedulerOutput"] = None,
        num_reqs: int = -1,
    ) -> dict[str, Any]:

        model_kwargs: dict[str, Any] = {}
        if self.model_supports_multimodal_raw_input:
            # This model requires the raw multimodal data in input.
            if scheduler_output:
                multi_modal_kwargs_list = []
                for req in scheduler_output.scheduled_new_reqs:
                    req_mm_inputs = req.mm_inputs
                    if not isinstance(req_mm_inputs, list):
                        req_mm_inputs = list(req_mm_inputs)
                    multi_modal_kwargs_list.extend(req_mm_inputs)
                multi_modal_kwargs = MultiModalKwargs.batch(
                    multi_modal_kwargs_list)
            else:
                # The only case where SchedulerOutput is None is for
                # a dummy run, so let's get some dummy data.
                dummy_data = [
                    self.mm_registry.get_decoder_dummy_data(
                        model_config=self.model_config,
                        seq_len=1).multi_modal_data for i in range(num_reqs)
                ]
                multi_modal_kwargs = MultiModalKwargs.batch(dummy_data)

            model_kwargs.update(multi_modal_kwargs)

        return model_kwargs

    def _get_cumsum_and_arange(
        self,
        num_tokens: np.ndarray,
@ -1359,10 +1400,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
            # embeddings), we always use embeddings (rather than token ids)
            # as input to the multimodal model, even when the input is text.
            input_ids = self.input_ids[:num_scheduled_tokens]

            model_kwargs = self._init_model_kwargs_for_multimodal_model(
                scheduler_output=scheduler_output)
            inputs_embeds = self.model.get_input_embeddings(
                input_ids=input_ids,
                multimodal_embeddings=mm_embeds or None,
            )

            # TODO(woosuk): Avoid the copy. Optimize.
            self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds)
            inputs_embeds = self.inputs_embeds[:num_input_tokens]
@ -1374,6 +1419,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
            # then the embedding layer is not included in the CUDA graph.
            input_ids = self.input_ids[:num_input_tokens]
            inputs_embeds = None
            model_kwargs = {}
        if self.uses_mrope:
            positions = self.mrope_positions[:, :num_input_tokens]
        else:
@ -1406,6 +1452,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                positions=positions,
                intermediate_tensors=intermediate_tensors,
                inputs_embeds=inputs_embeds,
                **MultiModalKwargs.as_kwargs(
                    model_kwargs,
                    device=self.device,
                ),
            )

        self.maybe_wait_for_kv_save()
@ -1822,17 +1872,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        with DeviceMemoryProfiler() as m:
            time_before_load = time.perf_counter()
            model_loader = get_model_loader(self.load_config)
            if not hasattr(self, "model"):
                logger.info("Loading model from scratch...")
                self.model = model_loader.load_model(
                    vllm_config=self.vllm_config,
                    model_config=self.model_config)
            else:
                logger.info(
                    "Model was already initialized. Loading weights inplace..."
                )
                model_loader.load_weights(self.model,
                                          model_config=self.model_config)
            logger.info("Loading model from scratch...")
            self.model = model_loader.load_model(
                vllm_config=self.vllm_config, model_config=self.model_config)
            if self.lora_config:
                self.model = self.load_lora_model(self.model,
                                                  self.model_config,
@ -1865,6 +1907,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                rank_mapping,
            )

    def reload_weights(self) -> None:
        assert getattr(self, "model", None) is not None, \
            "Cannot reload weights before model is loaded."
        model_loader = get_model_loader(self.load_config)
        logger.info("Reloading weights inplace...")
        model_loader.load_weights(self.model, model_config=self.model_config)

    def save_tensorized_model(
        self,
        tensorizer_config: "TensorizerConfig",
@ -2084,11 +2133,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                num_scheduled_tokens):
            model = self.model
            if self.is_multimodal_model:
                model_kwargs = self._init_model_kwargs_for_multimodal_model(
                    num_reqs=num_reqs)
                input_ids = None
                inputs_embeds = self.inputs_embeds[:num_tokens]
            else:
                input_ids = self.input_ids[:num_tokens]
                inputs_embeds = None
                model_kwargs = {}

            if self.uses_mrope:
                positions = self.mrope_positions[:, :num_tokens]
            else:
@ -2117,7 +2170,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                    positions=positions,
                    intermediate_tensors=intermediate_tensors,
                    inputs_embeds=inputs_embeds,
                    **MultiModalKwargs.as_kwargs(
                        model_kwargs,
                        device=self.device,
                    ),
                )

            if self.use_aux_hidden_state_outputs:
                hidden_states, _ = outputs
            else:
@ -2381,10 +2439,25 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        start_time = time.perf_counter()
        start_free_gpu_memory = torch.cuda.mem_get_info()[0]

        @contextmanager
        def freeze_gc():
            # Optimize garbage collection during CUDA graph capture.
            # Clean up, then freeze all remaining objects from being included
            # in future collections.
            gc.collect()
            should_freeze = not envs.VLLM_ENABLE_CUDAGRAPH_GC
            if should_freeze:
                gc.freeze()
            try:
                yield
            finally:
                if should_freeze:
                    gc.unfreeze()

        # Trigger CUDA graph capture for specific shapes.
        # Capture the large shapes first so that the smaller shapes
        # can reuse the memory pool allocated for the large shapes.
        with graph_capture(device=self.device):
        with freeze_gc(), graph_capture(device=self.device):
            full_cg = self.full_cuda_graph
            # Only rank 0 should print progress bar during capture
            compilation_cases = reversed(self.cudagraph_batch_sizes)
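The freeze_gc helper relies only on the standard-library gc module, so the same pattern can be reproduced outside vLLM. A minimal sketch, assuming the VLLM_ENABLE_CUDAGRAPH_GC switch is read from the environment:

import gc
import os
import time
from contextlib import contextmanager

@contextmanager
def freeze_gc(enable_gc: bool = os.getenv("VLLM_ENABLE_CUDAGRAPH_GC") == "1"):
    # Collect once, then freeze the surviving objects so the cyclic garbage
    # collector skips them while the long capture loop runs.
    gc.collect()
    should_freeze = not enable_gc
    if should_freeze:
        gc.freeze()
    try:
        yield
    finally:
        if should_freeze:
            gc.unfreeze()

with freeze_gc():
    time.sleep(0.01)  # stands in for the CUDA graph capture loop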
@ -4,6 +4,7 @@
import copy
import gc
import os
from contextlib import AbstractContextManager, nullcontext
from typing import TYPE_CHECKING, Any, Optional

import torch
@ -15,7 +16,8 @@ from vllm.config import VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
                              init_distributed_environment,
                              set_custom_all_reduce)
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
                                          has_kv_transfer_group)
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@ -118,6 +120,21 @@ class Worker(WorkerBase):
                buffer.data.copy_(self._sleep_saved_buffers[name].data)
            self._sleep_saved_buffers = {}

    def _maybe_get_memory_pool_context(self,
                                       tag: str) -> AbstractContextManager:
        if self.vllm_config.model_config.enable_sleep_mode:
            from vllm.device_allocator.cumem import CuMemAllocator

            allocator = CuMemAllocator.get_instance()
            if tag == "weights":
                assert allocator.get_current_usage() == 0, (
                    "Sleep mode can only be "
                    "used for one instance per process.")
            context = allocator.use_memory_pool(tag=tag)
        else:
            context = nullcontext()
        return context

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        self.cache_config.num_gpu_blocks = num_gpu_blocks
@ -179,24 +196,17 @@ class Worker(WorkerBase):
    # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
    # to hijack tensor allocation.
    def load_model(self) -> None:
        if self.vllm_config.model_config.enable_sleep_mode:
            from vllm.device_allocator.cumem import CuMemAllocator

            allocator = CuMemAllocator.get_instance()
            assert allocator.get_current_usage() == 0, (
                "Sleep mode can only be "
                "used for one instance per process.")
            context = allocator.use_memory_pool(tag="weights")
        else:
            from contextlib import nullcontext
            context = nullcontext()
        eep_scale_up = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
        with context:
        with self._maybe_get_memory_pool_context(tag="weights"):
            self.model_runner.load_model(eep_scale_up=eep_scale_up)

    def update_config(self, overrides: dict[str, Any]) -> None:
        self.model_runner.update_config(overrides)

    def reload_weights(self) -> None:
        with self._maybe_get_memory_pool_context(tag="weights"):
            self.model_runner.reload_weights()

    @torch.inference_mode()
    def determine_available_memory(self) -> int:
        """Profiles the peak memory usage of the model to determine how much
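The refactor above funnels the sleep-mode decision through one helper that returns either a real memory-pool context or a no-op. A hedged sketch of that select-a-context pattern using only the standard library; mock_memory_pool is a placeholder for the CuMemAllocator pool, not a vLLM API:

from contextlib import AbstractContextManager, contextmanager, nullcontext

@contextmanager
def mock_memory_pool(tag: str):
    # Placeholder for CuMemAllocator.use_memory_pool(tag=...); real code
    # would route allocations through the CUDA memory pool here.
    print(f"entering pool for {tag!r}")
    yield
    print(f"leaving pool for {tag!r}")

def maybe_get_memory_pool_context(sleep_mode_enabled: bool,
                                  tag: str) -> AbstractContextManager:
    # Same shape as Worker._maybe_get_memory_pool_context: pick the real
    # pool when sleep mode is on, otherwise a no-op context.
    return mock_memory_pool(tag) if sleep_mode_enabled else nullcontext()

with maybe_get_memory_pool_context(sleep_mode_enabled=True, tag="weights"):
    pass  # load or reload weights here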
@ -333,19 +343,20 @@ class Worker(WorkerBase):
            assert isinstance(output, IntermediateTensors)
            get_pp_group().send_tensor_dict(output.tensors,
                                            all_gather_group=get_tp_group())
            if not has_kv_transfer_group():
                return None

            # In case of PP with kv transfer, we need to pass through the
            # finished_sending and finished_recving buffers.
            empty_output = EMPTY_MODEL_RUNNER_OUTPUT
            new_output = EMPTY_MODEL_RUNNER_OUTPUT
            if output.finished_sending or output.finished_recving:
                empty_output = copy.copy(empty_output)
                empty_output.finished_sending = output.finished_sending
                empty_output.finished_recving = output.finished_recving
            output = empty_output
                new_output = copy.copy(new_output)
                new_output.finished_sending = output.finished_sending
                new_output.finished_recving = output.finished_recving
            output = new_output

        assert isinstance(output, ModelRunnerOutput)
        # return output only from the driver worker
        return output if self.is_driver_worker else None
        return output

    def profile(self, is_start: bool = True):
        if self.profiler is None:
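The kv-transfer passthrough copies the shared EMPTY_MODEL_RUNNER_OUTPUT sentinel before setting the finished_sending/finished_recving fields, so the module-level constant is never mutated. A small sketch of that copy-before-mutate pattern with a hypothetical dataclass:

import copy
from dataclasses import dataclass
from typing import Optional

@dataclass
class FakeRunnerOutput:
    # Hypothetical stand-in for ModelRunnerOutput.
    finished_sending: Optional[set[str]] = None
    finished_recving: Optional[set[str]] = None

EMPTY_OUTPUT = FakeRunnerOutput()  # shared sentinel, must stay pristine

def passthrough(output: FakeRunnerOutput) -> FakeRunnerOutput:
    new_output = EMPTY_OUTPUT
    if output.finished_sending or output.finished_recving:
        # Copy before mutating so the module-level sentinel is untouched.
        new_output = copy.copy(new_output)
        new_output.finished_sending = output.finished_sending
        new_output.finished_recving = output.finished_recving
    return new_output

print(passthrough(FakeRunnerOutput(finished_sending={"req-1"})))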
@ -114,7 +114,6 @@ class TPUModelRunner(LoRAModelRunnerMixin):
        self.original_parallel_config = original_parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.speculative_config = vllm_config.speculative_config
        self.prompt_adapter_config = vllm_config.prompt_adapter_config
        self.observability_config = vllm_config.observability_config
        self.device_config = vllm_config.device_config

@ -1174,16 +1173,10 @@ class TPUModelRunner(LoRAModelRunnerMixin):
                    mesh=self.mesh)
            else:
                model_loader = get_model_loader(self.load_config)
                if not hasattr(self, "model"):
                    logger.info("Loading model from scratch...")
                    model = model_loader.load_model(
                        vllm_config=self.vllm_config,
                        model_config=self.model_config)
                else:
                    logger.info("Model was already initialized. \
                        Loading weights inplace...")
                    model_loader.load_weights(
                        self.model, model_config=self.model_config)
                logger.info("Loading model from scratch...")
                model = model_loader.load_model(
                    vllm_config=self.vllm_config,
                    model_config=self.model_config)
        except RuntimeError as e:
            raise RuntimeError(
                f"Unable to load model, a likely reason is the model is "
@ -1205,6 +1198,13 @@ class TPUModelRunner(LoRAModelRunnerMixin):
        self.model = model
        self.sampler = TPUSampler()

    def reload_weights(self) -> None:
        assert getattr(self, "model", None) is not None, \
            "Cannot reload weights before model is loaded."
        model_loader = get_model_loader(self.load_config)
        logger.info("Reloading weights inplace...")
        model_loader.load_weights(self.model, model_config=self.model_config)

    @torch.no_grad()
    def _dummy_run(self, num_tokens: int, num_reqs: int,
                   num_blocks: int) -> None:
@ -62,7 +62,6 @@ class TPUWorker:
        self.scheduler_config = vllm_config.scheduler_config
        self.device_config = vllm_config.device_config
        self.speculative_config = vllm_config.speculative_config
        self.prompt_adapter_config = vllm_config.prompt_adapter_config
        self.observability_config = vllm_config.observability_config

        self.parallel_config.rank = rank
@ -265,6 +264,9 @@ class TPUWorker:
    def update_config(self, overrides: dict[str, Any]) -> None:
        self.model_runner.update_config(overrides)

    def reload_weights(self) -> None:
        self.model_runner.reload_weights()

    def compile_or_warm_up_model(self) -> None:
        if not self.model_config.enforce_eager:
            self.model_runner.capture_model()
@ -91,10 +91,9 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
        '''
        EncoderDecoderModelRunner constructor.

        `lora_config` and `prompt_adapter_config` are
        unused (since these features are not yet supported for encoder/decoder
        models) but these arguments are present here for compatibility with
        the base-class constructor.
        `lora_config` is unused (since these features are not yet supported
        for encoder/decoder models) but these arguments are present here for
        compatibility with the base-class constructor.
        '''
        self._maybe_force_supported_attention_backend()

@ -45,10 +45,6 @@ from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                             MultiModalKwargs, MultiModalPlaceholderMap,
                             MultiModalRegistry)
from vllm.prompt_adapter.layers import PromptAdapterMapping
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.prompt_adapter.worker_manager import (
    LRUCacheWorkerPromptAdapterManager)
from vllm.sampling_params import SamplingParams
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.utils import (DeviceMemoryProfiler, GiB_bytes, PyObjectCache,
@ -95,8 +91,6 @@ class ModelInputForGPU(ModelRunnerInputBase):
    lora_mapping: Optional["LoRAMapping"] = None
    lora_requests: Optional[Set[LoRARequest]] = None
    attn_metadata: Optional["AttentionMetadata"] = None
    prompt_adapter_mapping: Optional[PromptAdapterMapping] = None
    prompt_adapter_requests: Optional[Set[PromptAdapterRequest]] = None
    multi_modal_kwargs: Optional[BatchedTensorInputs] = None
    request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None
    finished_requests_ids: Optional[List[str]] = None
@ -113,8 +107,6 @@ class ModelInputForGPU(ModelRunnerInputBase):
            "lora_requests": self.lora_requests,
            "lora_mapping": self.lora_mapping,
            "multi_modal_kwargs": self.multi_modal_kwargs,
            "prompt_adapter_mapping": self.prompt_adapter_mapping,
            "prompt_adapter_requests": self.prompt_adapter_requests,
            "virtual_engine": self.virtual_engine,
            "request_ids_to_seq_ids": self.request_ids_to_seq_ids,
            "finished_requests_ids": self.finished_requests_ids,
@ -164,8 +156,6 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
            "lora_requests": self.lora_requests,
            "lora_mapping": self.lora_mapping,
            "multi_modal_kwargs": self.multi_modal_kwargs,
            "prompt_adapter_mapping": self.prompt_adapter_mapping,
            "prompt_adapter_requests": self.prompt_adapter_requests,
            "virtual_engine": self.virtual_engine,
            "request_ids_to_seq_ids": self.request_ids_to_seq_ids,
            "finished_requests_ids": self.finished_requests_ids,
@ -212,8 +202,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
            self.lora_index_mapping.clear()  # type: ignore
            self.lora_prompt_mapping.clear()  # type: ignore
            self.lora_requests.clear()  # type: ignore
            self.prompt_adapter_index_mapping.clear()  # type: ignore
            self.prompt_adapter_prompt_mapping.clear()  # type: ignore

        def __init__(
            self,
@ -252,11 +240,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
            lora_prompt_mapping: Optional[List[List[int]]] = None,
            lora_requests: Optional[Set[LoRARequest]] = None,

            # Prompt adapter inputs.
            prompt_adapter_index_mapping: Optional[List[int]] = None,
            prompt_adapter_prompt_mapping: Optional[List[int]] = None,
            prompt_adapter_request: Optional[PromptAdapterRequest] = None,

            # Multi-modal inputs.
            multi_modal_kwargs: Optional[MultiModalKwargs] = None,
            multi_modal_placeholder_maps: Optional[Dict[
@ -360,18 +343,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
                    else:
                        self.lora_requests.clear()

                    if prompt_adapter_index_mapping:
                        self.prompt_adapter_index_mapping = \
                            prompt_adapter_index_mapping
                    else:
                        self.prompt_adapter_index_mapping.clear()

                    if prompt_adapter_prompt_mapping:
                        self.prompt_adapter_prompt_mapping = \
                            prompt_adapter_prompt_mapping
                    else:
                        self.prompt_adapter_prompt_mapping.clear()

            else:
                self.input_tokens = input_tokens or []
                self.inputs_embeds = inputs_embeds
@ -390,12 +361,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
                self.lora_prompt_mapping = lora_prompt_mapping or []
                self.lora_requests = lora_requests or set()

                self.prompt_adapter_index_mapping = (
                    prompt_adapter_index_mapping or [])
                self.prompt_adapter_prompt_mapping = (
                    prompt_adapter_prompt_mapping or [])

            self.prompt_adapter_request = prompt_adapter_request
            self.multi_modal_kwargs = multi_modal_kwargs
            self.multi_modal_placeholder_maps = multi_modal_placeholder_maps
            self.prefix_cache_hit = prefix_cache_hit
@ -485,7 +450,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
        # Compute functions for each sequence group.
        # WARNING: The order of the functions matters!
        self.per_seq_group_compute_fns = [
            self._compute_prompt_adapter_input,
            self._compute_multi_modal_input,
        ]

@ -496,8 +460,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
        self.sliding_window = self.runner.sliding_window
        self.block_size = self.runner.block_size
        self.enable_lora = self.runner.lora_config is not None
        self.enable_prompt_adapter = (self.runner.prompt_adapter_config
                                      is not None)

        # Attention metadata inputs.
        if self.attn_backend is not None:
@ -693,34 +655,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
        else:
            inter_data.lora_prompt_mapping.append([])

    def _compute_prompt_adapter_input(
            self, inter_data: InterDataForSeqGroup,
            seq_group_metadata: SequenceGroupMetadata):
        """If prompt adapter is enabled, compute index and prompt mapping.
        """
        # Note that when is_prompt=True, we expect only one sequence
        # in the group.
        if not self.enable_prompt_adapter:
            return

        prompt_adapter_id = seq_group_metadata.prompt_adapter_id
        if prompt_adapter_id <= 0 or not inter_data.is_prompt:
            return

        # We expect only one sequence in the group when is_prompt=True.
        assert inter_data.n_seqs == 1
        query_len = inter_data.query_lens[0]
        inter_data.prompt_adapter_request = (
            seq_group_metadata.prompt_adapter_request)

        num_tokens = seq_group_metadata.prompt_adapter_num_virtual_tokens
        inter_data.prompt_adapter_index_mapping = [
            prompt_adapter_id
        ] * num_tokens + [0] * (query_len - num_tokens)
        inter_data.prompt_adapter_prompt_mapping = [prompt_adapter_id] * (
            query_len if seq_group_metadata.sampling_params
            and seq_group_metadata.sampling_params.prompt_logprobs else 1)

    def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup,
                                   seq_group_metadata: SequenceGroupMetadata):
        """If multi-modal data is given, add it to the input."""
@ -1009,29 +943,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
                prompt_mapping=lora_prompt_mapping,
                is_prefill=not self.decode_only))

        # Prompt adapter data.
        prompt_adapter_requests: Set[PromptAdapterRequest] = set()
        prompt_adapter_mapping = None
        if self.enable_prompt_adapter:
            prompt_adapter_requests = set(
                data.prompt_adapter_request for data in self.inter_data_list
                if data.prompt_adapter_request is not None)
            prompt_adapter_index_mapping = flatten_2d_lists([
                inter_data.prompt_adapter_index_mapping
                for inter_data in self.inter_data_list
            ])
            if cuda_graph_pad_size:
                prompt_adapter_index_mapping.extend(
                    itertools.repeat(0, cuda_graph_pad_size))
            prompt_adapter_prompt_mapping = flatten_2d_lists([
                inter_data.prompt_adapter_prompt_mapping
                for inter_data in self.inter_data_list
            ])
            prompt_adapter_mapping = PromptAdapterMapping(
                prompt_adapter_index_mapping,
                prompt_adapter_prompt_mapping,
            )

        # Multi-modal data.
        multi_modal_kwargs_list = [
            data.multi_modal_kwargs for data in self.inter_data_list
@ -1051,9 +962,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
            lora_requests=lora_requests,
            multi_modal_kwargs=multi_modal_kwargs,
            request_ids_to_seq_ids=request_ids_to_seq_ids,
            finished_requests_ids=self.finished_requests_ids,
            prompt_adapter_mapping=prompt_adapter_mapping,
            prompt_adapter_requests=prompt_adapter_requests)
            finished_requests_ids=self.finished_requests_ids)


class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
@ -1148,7 +1057,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
        self.model: nn.Module  # Set after load_model
        # Set after load_model.
        self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
        self.prompt_adapter_manager: LRUCacheWorkerPromptAdapterManager = None
        self.sampler = get_sampler()

        set_cpu_offload_max_bytes(
@ -1207,14 +1115,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
        logger.info("Model loading took %.4f GiB and %.6f seconds",
                    self.model_memory_usage / GiB_bytes,
                    time_after_load - time_before_load)
        if self.prompt_adapter_config:
            self.prompt_adapter_manager = LRUCacheWorkerPromptAdapterManager(
                self.scheduler_config.max_num_seqs,
                self.scheduler_config.max_num_batched_tokens, self.device,
                self.prompt_adapter_config)
            self.model = (
                self.prompt_adapter_manager.create_prompt_adapter_manager(
                    self.model))


        if self.vllm_config.compilation_config.level ==\
            CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
@ -1466,40 +1367,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
            raise RuntimeError("LoRA is not enabled.")
        return self.lora_manager.list_adapters()

    def remove_all_prompt_adapters(self):
        if not self.prompt_adapter_manager:
            raise RuntimeError("PromptAdapter is not enabled.")
        self.prompt_adapter_manager.remove_all_adapters()

    def set_active_prompt_adapters(
            self, prompt_adapter_requests: Set[PromptAdapterRequest],
            prompt_adapter_mapping: PromptAdapterMapping) -> None:
        if not self.prompt_adapter_manager:
            raise RuntimeError("PromptAdapter is not enabled.")
        self.prompt_adapter_manager.set_active_adapters(
            prompt_adapter_requests, prompt_adapter_mapping)

    def add_prompt_adapter(
            self, prompt_adapter_request: PromptAdapterRequest) -> bool:
        if not self.prompt_adapter_manager:
            raise RuntimeError("PromptAdapter is not enabled.")
        return self.prompt_adapter_manager.add_adapter(prompt_adapter_request)

    def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        if not self.prompt_adapter_manager:
            raise RuntimeError("PromptAdapter is not enabled.")
        return self.prompt_adapter_manager.remove_adapter(prompt_adapter_id)

    def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        if not self.prompt_adapter_manager:
            raise RuntimeError("PromptAdapter is not enabled.")
        return self.prompt_adapter_manager.pin_adapter(prompt_adapter_id)

    def list_prompt_adapters(self) -> Set[int]:
        if not self.prompt_adapter_manager:
            raise RuntimeError("PromptAdapter is not enabled.")
        return self.prompt_adapter_manager.list_adapters()

    @torch.inference_mode()
    def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
        """Cuda graph capture a model.
@ -1609,13 +1476,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                    self.set_active_loras(set([dummy_lora_request]),
                                          lora_mapping)

                if self.prompt_adapter_config:
                    prompt_adapter_mapping = PromptAdapterMapping(
                        [-1] * batch_size,
                        [-1] * batch_size,
                    )
                    self.set_active_prompt_adapters(
                        set(), prompt_adapter_mapping)
                graph_runner = CUDAGraphRunner(
                    self.model, self.attn_backend.get_name(),
                    self.attn_state.graph_clone(batch_size),
@ -1776,13 +1636,6 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
            self.set_active_loras(model_input.lora_requests,
                                  model_input.lora_mapping)

        if self.prompt_adapter_config:
            assert model_input.prompt_adapter_requests is not None
            assert model_input.prompt_adapter_mapping is not None
            self.set_active_prompt_adapters(
                model_input.prompt_adapter_requests,
                model_input.prompt_adapter_mapping)

        self.attn_state.begin_forward(model_input)

        # Currently cuda graph is only supported by the decode phase.
@ -1932,24 +1785,32 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):

        if model_input.inputs_embeds is not None:
            if self.is_driver_worker:
                sampled = broadcast_tensor_dict(
                    {"token_ids": output.sampled_token_ids})
                sampled_token_ids = []
                valid_outputs = []
                for sequence_group_output in output.outputs:
                    if len(sequence_group_output.samples) == 0:
                        continue
                    assert len(sequence_group_output.samples) == 1
                    valid_outputs.append(sequence_group_output)
                    sampled_token_ids.append(
                        sequence_group_output.samples[0].output_token)
                sampled_token_ids = torch.tensor(sampled_token_ids).to(
                    self.device)
                sampled_token_ids = broadcast_tensor_dict(
                    {"sampled_token_ids":
                     sampled_token_ids})["sampled_token_ids"]
            else:
                sampled = broadcast_tensor_dict()
            if sampled["token_ids"] is not None:
                sampled_token_embeds = self.model.get_input_embeddings(
                    sampled["token_ids"].squeeze(1))
                sampled_token_ids = broadcast_tensor_dict(
                )["sampled_token_ids"]
            if len(sampled_token_ids) > 0:
                sampled_token_embeds = \
                    self.model.get_input_embeddings(sampled_token_ids)
                if self.is_driver_worker:
                    self.sampler.include_gpu_probs_tensor = \
                        orig_include_gpu_probs

                    output.sampled_token_embeds = sampled_token_embeds

                    for token_embed, sequence_group_output in zip(
                            output.sampled_token_embeds, output.outputs):
                        assert len(sequence_group_output.samples) == 1
                        sequence_group_output.samples[
                            0].output_embed = token_embed
                    for i, sequence_group_output in enumerate(valid_outputs):
                        sequence_group_output.samples[0].output_embed = \
                            sampled_token_embeds[i]

        if not self.is_driver_worker:
            return []
@ -190,7 +190,6 @@ class ModelRunnerBase(ABC, Generic[T]):
        self.scheduler_config = vllm_config.scheduler_config
        self.device_config = vllm_config.device_config
        self.speculative_config = vllm_config.speculative_config
        self.prompt_adapter_config = vllm_config.prompt_adapter_config
        self.observability_config = vllm_config.observability_config

        # Map of request_id -> generator used for seeded random sampling
@ -288,9 +288,6 @@ class StatefulModelInput(BroadcastableModelInput):
        assert fmi.lora_requests is not None
        assert len(fmi.lora_requests) == 0
        assert fmi.attn_metadata is not None
        assert fmi.prompt_adapter_mapping is None
        assert fmi.prompt_adapter_requests is not None
        assert len(fmi.prompt_adapter_requests) == 0
        assert fmi.multi_modal_kwargs is not None
        assert len(fmi.multi_modal_kwargs) == 0

@ -64,13 +64,6 @@ class PoolingModelRunner(
            self.set_active_loras(model_input.lora_requests,
                                  model_input.lora_mapping)

        if self.prompt_adapter_config:
            assert model_input.prompt_adapter_requests is not None
            assert model_input.prompt_adapter_mapping is not None
            self.set_active_prompt_adapters(
                model_input.prompt_adapter_requests,
                model_input.prompt_adapter_mapping)

        # Currently cuda graph is only supported by the decode phase.
        assert model_input.attn_metadata is not None
        prefill_meta = model_input.attn_metadata.prefill_metadata
@ -47,7 +47,3 @@ def assert_enc_dec_mr_supported_scenario(
    if enc_dec_mr.scheduler_config.num_lookahead_slots > 0:
        raise NotImplementedError(
            STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_SPEC_DEC'])

    if enc_dec_mr.prompt_adapter_config is not None:
        raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_ERR_STRS[
            'STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER'])