[Misc] IO Processor plugins for pooling models (#22820)

Signed-off-by: Christian Pinto <christian.pinto@ibm.com>
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Co-authored-by: Max de Bayser <mbayser@br.ibm.com>
This commit is contained in:
Christian Pinto
2025-09-01 07:07:12 +01:00
committed by GitHub
parent 437c3ce026
commit 1cb39dbcdd
25 changed files with 1183 additions and 43 deletions

View File

@ -769,6 +769,11 @@ steps:
- pytest -v -s plugins_tests/test_platform_plugins.py
- pip uninstall vllm_add_dummy_platform -y
# end platform plugin tests
# begin io_processor plugins test, all the code in between uses the prithvi_io_processor plugin
- pip install -e ./plugins/prithvi_io_processor_plugin
- pytest -v -s plugins_tests/test_io_processor_plugins.py
- pip uninstall prithvi_io_processor_plugin -y
# end io_processor plugins test
# other tests continue here:
- pytest -v -s plugins_tests/test_scheduler_plugins.py
- pip install -e ./plugins/vllm_add_dummy_model

View File

@ -0,0 +1,78 @@
# IO Processor Plugins
IO Processor plugins are a feature that allows pre and post processing of the model input and output for pooling models. The idea is that users are allowed to pass a custom input to vLLM that is converted into one or more model prompts and fed to the model `encode` method. One potential use-case of such plugins is that of using vLLM for generating multi-modal data. Say users feed an image to vLLM and get an image in output.
When performing an inference with IO Processor plugins, the prompt type is defined by the plugin and the same is valid for the final request output. vLLM does not perform any validation of input/output data, and it is up to the plugin to ensure the correct data is being fed to the model and returned to the user. As of now these plugins support only pooling models and can be triggerd via the `encode` method in `LLM` and `AsyncLLM`, or in online serving mode via the `/pooling` endpoint.
## Writing an IO Processor Plugin
IO Processor plugins implement the `IOProcessor` interface (<gh-file:vllm/plugins/io_processors/interface.py>):
```python
IOProcessorInput = TypeVar('IOProcessorInput')
IOProcessorOutput = TypeVar('IOProcessorOutput')
class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]):
def __init__(self, vllm_config: VllmConfig):
self.vllm_config = vllm_config
@abstractmethod
def pre_process(
self,
prompt: IOProcessorInput,
request_id: Optional[str] = None,
**kwargs,
) -> Union[PromptType, Sequence[PromptType]]:
raise NotImplementedError
async def pre_process_async(
self,
prompt: IOProcessorInput,
request_id: Optional[str] = None,
**kwargs,
) -> Union[PromptType, Sequence[PromptType]]:
return self.pre_process(prompt, request_id, **kwargs)
@abstractmethod
def post_process(self,
model_output: Sequence[PoolingRequestOutput],
request_id: Optional[str] = None,
**kwargs) -> IOProcessorOutput:
raise NotImplementedError
async def post_process_async(
self,
model_output: AsyncGenerator[tuple[int, PoolingRequestOutput]],
request_id: Optional[str] = None,
**kwargs,
) -> IOProcessorOutput:
collected_output = [item async for i, item in model_output]
return self.post_process(collected_output, request_id, **kwargs)
@abstractmethod
def parse_request(self, request: Any) -> IOProcessorInput:
raise NotImplementedError
@abstractmethod
def output_to_response(
self, plugin_output: IOProcessorOutput) -> IOProcessorResponse:
raise NotImplementedError
```
The `parse_request` method is used for validating the user prompt and converting it into the input expected by the `pre_process`/`pre_process_async` methods.
The `pre_process*` methods take the validated plugin input to generate vLLM's model prompts for regular inference.
The `post_process*` methods take `PoolingRequestOutput` objects as input and generate a custom plugin output.
The `output_to_response` method is used only for online serving and converts the plugin output to the `IOProcessorResponse` type that is then returned by the API Server. The implementation of the `/io_processor_pooling` serving endpoint is [here](../../vllm/entrypoints/openai/serving_pooling_with_io_plugin.py).
An example implementation of a plugin that enables generating geotiff images with the PrithviGeospatialMAE model is available [here](https://github.com/christian-pinto/prithvi_io_processor_plugin). Please, also refer to our [online](../../examples/online_serving/prithvi_geospatial_mae.py) and [offline](../../examples/offline_inference/prithvi_geospatial_mae_io_processor.py) inference examples.
## Using an IO Processor plugin
IO Processor plugins are loaded at engine startup and there are two methods for specifying the name of the plugin to be loaded:
1. Via vLLM's `EngineArgs`: setting the `io_processor_plugin` argument in the `EngineArgs` used to initialize the `AsyncLLM`. The same can be achieved by passing the `io_processor_plugin` argument to `LLM` in offline mode, or by passing the `--io-processor-plugin` argument in serving mode.
2. Via the model HF configuration: adding an `io_processor_plugin` field to the model config (config.json).
The order also determines method priority. i.e., setting the plugin name via `EngineArgs` will override any plugin name specified in the model HF config (config.json).

View File

@ -49,6 +49,8 @@ Every plugin has three parts:
- **Platform plugins** (with group name `vllm.platform_plugins`): The primary use case for these plugins is to register custom, out-of-the-tree platforms into vLLM. The plugin function should return `None` when the platform is not supported in the current environment, or the platform class's fully qualified name when the platform is supported.
- **IO Processor plugins** (with group name `vllm.io_processor_plugins`): The primary use case for these plugins is to register custom pre/post processing of the model prompt and model output for poling models. The plugin function returns the IOProcessor's class fully qualified name.
## Guidelines for Writing Plugins
- **Being re-entrant**: The function specified in the entry point should be re-entrant, meaning it can be called multiple times without causing issues. This is necessary because the function might be called multiple times in some processes.

View File

@ -0,0 +1,60 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import base64
import os
import torch
from vllm import LLM
from vllm.pooling_params import PoolingParams
# This example shows how to perform an offline inference that generates
# multimodal data. In this specific case this example will take a geotiff
# image as input, process it using the multimodal data processor, and
# perform inference.
# Reuirement - install plugin at:
# https://github.com/christian-pinto/prithvi_io_processor_plugin
def main():
torch.set_default_dtype(torch.float16)
image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/India_900498_S2Hand.tif" # noqa: E501
img_prompt = dict(
data=image_url,
data_format="url",
image_format="tiff",
out_data_format="b64_json",
)
llm = LLM(
model="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
skip_tokenizer_init=True,
trust_remote_code=True,
enforce_eager=True,
# Limit the maximum number of parallel requests
# to avoid the model going OOM.
# The maximum number depends on the available GPU memory
max_num_seqs=32,
io_processor_plugin="prithvi_to_tiff_india",
)
pooling_params = PoolingParams(task="encode", softmax=False)
pooler_output = llm.encode(
img_prompt,
pooling_params=pooling_params,
)
output = pooler_output[0].outputs
print(output)
decoded_data = base64.b64decode(output.data)
file_path = os.path.join(os.getcwd(), "offline_prediction.tiff")
with open(file_path, "wb") as f:
f.write(decoded_data)
print(f"Output file path: {file_path}")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,54 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import base64
import os
import requests
# This example shows how to perform an online inference that generates
# multimodal data. In this specific case this example will take a geotiff
# image as input, process it using the multimodal data processor, and
# perform inference.
# Reuirements :
# - install plugin at:
# https://github.com/christian-pinto/prithvi_io_processor_plugin
# - start vllm in serving mode with the below args
# --model='christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM'
# --task embed --trust-remote-code
# --skip-tokenizer-init --enforce-eager
# --io-processor-plugin prithvi_to_tiff_india
def main():
image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/India_900498_S2Hand.tif" # noqa: E501
server_endpoint = "http://localhost:8000/pooling"
request_payload_url = {
"data": {
"data": image_url,
"data_format": "url",
"image_format": "tiff",
"out_data_format": "b64_json",
},
"priority": 0,
"model": "christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
}
ret = requests.post(server_endpoint, json=request_payload_url)
print(f"response.status_code: {ret.status_code}")
print(f"response.reason:{ret.reason}")
response = ret.json()
decoded_image = base64.b64decode(response["data"]["data"])
out_path = os.path.join(os.getcwd(), "online_prediction.tiff")
with open(out_path, "wb") as f:
f.write(decoded_image)
if __name__ == "__main__":
main()

View File

@ -1120,6 +1120,9 @@ class VllmRunner:
return self.llm.llm_engine.collective_rpc(_apply_model)
def get_llm(self) -> LLM:
return self.llm
def __enter__(self):
return self

View File

@ -0,0 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
def register_prithvi_india():
return "prithvi_io_processor.prithvi_processor.PrithviMultimodalDataProcessorIndia" # noqa: E501
def register_prithvi_valencia():
return "prithvi_io_processor.prithvi_processor.PrithviMultimodalDataProcessorValencia" # noqa: E501

View File

@ -0,0 +1,449 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import base64
import datetime
import os
import tempfile
import urllib.request
from collections.abc import AsyncGenerator, Sequence
from typing import Any, Optional, Union
import albumentations
import numpy as np
import rasterio
import regex as re
import torch
from einops import rearrange
from terratorch.datamodules import Sen1Floods11NonGeoDataModule
from vllm.config import VllmConfig
from vllm.entrypoints.openai.protocol import (IOProcessorRequest,
IOProcessorResponse)
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput
from vllm.plugins.io_processors.interface import (IOProcessor,
IOProcessorInput,
IOProcessorOutput)
from .types import DataModuleConfig, ImagePrompt, ImageRequestOutput
logger = init_logger(__name__)
NO_DATA = -9999
NO_DATA_FLOAT = 0.0001
OFFSET = 0
PERCENTILE = 99
DEFAULT_INPUT_INDICES = [0, 1, 2, 3, 4, 5]
datamodule_config: DataModuleConfig = {
"bands": ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"],
"batch_size":
16,
"constant_scale":
0.0001,
"data_root":
"/dccstor/geofm-finetuning/datasets/sen1floods11",
"drop_last":
True,
"no_data_replace":
0.0,
"no_label_replace":
-1,
"num_workers":
8,
"test_transform": [
albumentations.Resize(always_apply=False,
height=448,
interpolation=1,
p=1,
width=448),
albumentations.pytorch.ToTensorV2(transpose_mask=False,
always_apply=True,
p=1.0),
],
}
def save_geotiff(image: torch.Tensor, meta: dict,
out_format: str) -> str | bytes:
"""Save multi-band image in Geotiff file.
Args:
image: np.ndarray with shape (bands, height, width)
output_path: path where to save the image
meta: dict with meta info.
"""
if out_format == "path":
# create temp file
file_path = os.path.join(os.getcwd(), "prediction.tiff")
with rasterio.open(file_path, "w", **meta) as dest:
for i in range(image.shape[0]):
dest.write(image[i, :, :], i + 1)
return file_path
elif out_format == "b64_json":
with tempfile.NamedTemporaryFile() as tmpfile:
with rasterio.open(tmpfile.name, "w", **meta) as dest:
for i in range(image.shape[0]):
dest.write(image[i, :, :], i + 1)
file_data = tmpfile.read()
return base64.b64encode(file_data)
else:
raise ValueError("Unknown output format")
def _convert_np_uint8(float_image: torch.Tensor):
image = float_image.numpy() * 255.0
image = image.astype(dtype=np.uint8)
return image
def read_geotiff(
file_path: Optional[str] = None,
path_type: Optional[str] = None,
file_data: Optional[bytes] = None,
) -> tuple[torch.Tensor, dict, tuple[float, float] | None]:
"""Read all bands from *file_path* and return image + meta info.
Args:
file_path: path to image file.
Returns:
np.ndarray with shape (bands, height, width)
meta info dict
"""
if all([x is None for x in [file_path, path_type, file_data]]):
raise Exception("All input fields to read_geotiff are None")
write_to_file: Optional[bytes] = None
path: Optional[str] = None
if file_data is not None:
# with tempfile.NamedTemporaryFile() as tmpfile:
# tmpfile.write(file_data)
# path = tmpfile.name
write_to_file = file_data
elif file_path is not None and path_type == "url":
resp = urllib.request.urlopen(file_path)
# with tempfile.NamedTemporaryFile() as tmpfile:
# tmpfile.write(resp.read())
# path = tmpfile.name
write_to_file = resp.read()
elif file_path is not None and path_type == "path":
path = file_path
elif file_path is not None and path_type == "b64_json":
image_data = base64.b64decode(file_path)
# with tempfile.NamedTemporaryFile() as tmpfile:
# tmpfile.write(image_data)
# path = tmpfile.name
write_to_file = image_data
else:
raise Exception("Wrong combination of parameters to read_geotiff")
with tempfile.NamedTemporaryFile() as tmpfile:
path_to_use = None
if write_to_file:
tmpfile.write(write_to_file)
path_to_use = tmpfile.name
elif path:
path_to_use = path
with rasterio.open(path_to_use) as src:
img = src.read()
meta = src.meta
try:
coords = src.lnglat()
except Exception:
# Cannot read coords
coords = None
return img, meta, coords
def load_image(
data: Union[list[str]],
path_type: str,
mean: Optional[list[float]] = None,
std: Optional[list[float]] = None,
indices: Optional[Union[list[int], None]] = None,
):
"""Build an input example by loading images in *file_paths*.
Args:
file_paths: list of file paths .
mean: list containing mean values for each band in the
images in *file_paths*.
std: list containing std values for each band in the
images in *file_paths*.
Returns:
np.array containing created example
list of meta info for each image in *file_paths*
"""
imgs = []
metas = []
temporal_coords = []
location_coords = []
for file in data:
# if isinstance(file, bytes):
# img, meta, coords = read_geotiff(file_data=file)
# else:
img, meta, coords = read_geotiff(file_path=file, path_type=path_type)
# Rescaling (don't normalize on nodata)
img = np.moveaxis(img, 0, -1) # channels last for rescaling
if indices is not None:
img = img[..., indices]
if mean is not None and std is not None:
img = np.where(img == NO_DATA, NO_DATA_FLOAT, (img - mean) / std)
imgs.append(img)
metas.append(meta)
if coords is not None:
location_coords.append(coords)
try:
match = re.search(r"(\d{7,8}T\d{6})", file)
if match:
year = int(match.group(1)[:4])
julian_day = match.group(1).split("T")[0][4:]
if len(julian_day) == 3:
julian_day = int(julian_day)
else:
julian_day = (datetime.datetime.strptime(
julian_day, "%m%d").timetuple().tm_yday)
temporal_coords.append([year, julian_day])
except Exception:
logger.exception("Could not extract timestamp for %s", file)
imgs = np.stack(imgs, axis=0) # num_frames, H, W, C
imgs = np.moveaxis(imgs, -1, 0).astype("float32") # C, num_frames, H, W
imgs = np.expand_dims(imgs, axis=0) # add batch di
return imgs, temporal_coords, location_coords, metas
class PrithviMultimodalDataProcessor(IOProcessor):
def __init__(self, vllm_config: VllmConfig):
super().__init__(vllm_config)
self.datamodule = Sen1Floods11NonGeoDataModule(
data_root=datamodule_config["data_root"],
batch_size=datamodule_config["batch_size"],
num_workers=datamodule_config["num_workers"],
bands=datamodule_config["bands"],
drop_last=datamodule_config["drop_last"],
test_transform=datamodule_config["test_transform"],
)
self.img_size = 512
self.h1 = 1
self.w1 = 1
self.original_h = 512
self.original_w = 512
self.batch_size = 1
self.meta_data = None
self.requests_cache: dict[str, dict[str, Any]] = {}
self.indices = DEFAULT_INPUT_INDICES
def parse_request(self, request: Any) -> IOProcessorInput:
if type(request) is dict:
image_prompt = ImagePrompt(**request)
return image_prompt
if isinstance(request, IOProcessorRequest):
if not hasattr(request, "data"):
raise ValueError(
"missing 'data' field in OpenAIBaseModel Request")
request_data = request.data
if type(request_data) is dict:
return ImagePrompt(**request_data)
else:
raise ValueError("Unable to parse the request data")
raise ValueError("Unable to parse request")
def output_to_response(
self, plugin_output: IOProcessorOutput) -> IOProcessorResponse:
return IOProcessorResponse(
request_id=plugin_output.request_id,
data=plugin_output,
)
def pre_process(
self,
prompt: IOProcessorInput,
request_id: Optional[str] = None,
**kwargs,
) -> Union[PromptType, Sequence[PromptType]]:
image_data = dict(prompt)
if request_id:
self.requests_cache[request_id] = {
"out_format": image_data["out_data_format"],
}
input_data, temporal_coords, location_coords, meta_data = load_image(
data=[image_data["data"]],
indices=self.indices,
path_type=image_data["data_format"],
)
self.meta_data = meta_data[0]
if input_data.mean() > 1:
input_data = input_data / 10000 # Convert to range 0-1
self.original_h, self.original_w = input_data.shape[-2:]
pad_h = (self.img_size -
(self.original_h % self.img_size)) % self.img_size
pad_w = (self.img_size -
(self.original_w % self.img_size)) % self.img_size
input_data = np.pad(
input_data,
((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)),
mode="reflect",
)
batch = torch.tensor(input_data)
windows = batch.unfold(3, self.img_size,
self.img_size).unfold(4, self.img_size,
self.img_size)
self.h1, self.w1 = windows.shape[3:5]
windows = rearrange(
windows,
"b c t h1 w1 h w -> (b h1 w1) c t h w",
h=self.img_size,
w=self.img_size,
)
# Split into batches if number of windows > batch_size
num_batches = (windows.shape[0] // self.batch_size
if windows.shape[0] > self.batch_size else 1)
windows = torch.tensor_split(windows, num_batches, dim=0)
if temporal_coords:
temporal_coords = torch.tensor(temporal_coords).unsqueeze(0)
else:
temporal_coords = None
if location_coords:
location_coords = torch.tensor(location_coords[0]).unsqueeze(0)
else:
location_coords = None
prompts = []
for window in windows:
# Apply standardization
window = self.datamodule.test_transform(
image=window.squeeze().numpy().transpose(1, 2, 0))
window = self.datamodule.aug(window)["image"]
prompts.append({
"prompt_token_ids": [1],
"multi_modal_data": {
"pixel_values": window.to(torch.float16)[0],
"location_coords": location_coords.to(torch.float16),
},
})
return prompts
async def pre_process_async(
self,
prompt: IOProcessorInput,
request_id: Optional[str] = None,
**kwargs,
) -> Union[PromptType, Sequence[PromptType]]:
return self.pre_process(prompt, request_id, **kwargs)
def post_process(
self,
model_output: Sequence[PoolingRequestOutput],
request_id: Optional[str] = None,
**kwargs,
) -> IOProcessorOutput:
pred_imgs_list = []
if request_id and (request_id in self.requests_cache):
out_format = self.requests_cache[request_id]["out_format"]
else:
out_format = "b64_json"
for output in model_output:
y_hat = output.outputs.data.argmax(dim=1)
pred = torch.nn.functional.interpolate(
y_hat.unsqueeze(1).float(),
size=self.img_size,
mode="nearest",
)
pred_imgs_list.append(pred)
pred_imgs: torch.Tensor = torch.concat(pred_imgs_list, dim=0)
# Build images from patches
pred_imgs = rearrange(
pred_imgs,
"(b h1 w1) c h w -> b c (h1 h) (w1 w)",
h=self.img_size,
w=self.img_size,
b=1,
c=1,
h1=self.h1,
w1=self.w1,
)
# Cut padded area back to original size
pred_imgs = pred_imgs[..., :self.original_h, :self.original_w]
# Squeeze (batch size 1)
pred_imgs = pred_imgs[0]
if not self.meta_data:
raise ValueError("No metadata available for the current task")
self.meta_data.update(count=1, dtype="uint8", compress="lzw", nodata=0)
out_data = save_geotiff(_convert_np_uint8(pred_imgs), self.meta_data,
out_format)
return ImageRequestOutput(type=out_format,
format="tiff",
data=out_data,
request_id=request_id)
async def post_process_async(
self,
model_output: AsyncGenerator[tuple[int, PoolingRequestOutput]],
request_id: Optional[str] = None,
**kwargs,
) -> IOProcessorOutput:
collected_output = [item async for i, item in model_output]
return self.post_process(collected_output, request_id, **kwargs)
class PrithviMultimodalDataProcessorIndia(PrithviMultimodalDataProcessor):
def __init__(self, vllm_config: VllmConfig):
super().__init__(vllm_config)
self.indices = [1, 2, 3, 8, 11, 12]
class PrithviMultimodalDataProcessorValencia(PrithviMultimodalDataProcessor):
def __init__(self, vllm_config: VllmConfig):
super().__init__(vllm_config)
self.indices = [0, 1, 2, 3, 4, 5]

View File

@ -0,0 +1,59 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Literal, Optional, TypedDict, Union
import albumentations
from pydantic import BaseModel
class DataModuleConfig(TypedDict):
bands: list[str]
batch_size: int
constant_scale: float
data_root: str
drop_last: bool
no_data_replace: float
no_label_replace: int
num_workers: int
test_transform: list[
albumentations.core.transforms_interface.BasicTransform]
class ImagePrompt(BaseModel):
data_format: Literal["b64_json", "bytes", "url"]
"""
This is the data type for the input image
"""
image_format: str
"""
This is the image format (e.g., jpeg, png, etc.)
"""
out_data_format: Literal["b64_json", "url"]
data: Any
"""
Input image data
"""
MultiModalPromptType = Union[ImagePrompt]
class ImageRequestOutput(BaseModel):
"""
The output data of an image request to vLLM.
Args:
type (str): The data content type [path, object]
format (str): The image format (e.g., jpeg, png, etc.)
data (Any): The resulting data.
"""
type: Literal["path", "b64_json"]
format: str
data: str
request_id: Optional[str] = None

View File

@ -0,0 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from setuptools import setup
setup(
name="prithvi_io_processor_plugin",
version="0.1",
packages=["prithvi_io_processor"],
entry_points={
"vllm.io_processor_plugins": [
"prithvi_to_tiff_india = prithvi_io_processor:register_prithvi_india", # noqa: E501
"prithvi_to_tiff_valencia = prithvi_io_processor:register_prithvi_valencia", # noqa: E501
]
},
)

View File

@ -1,12 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
"""
Since this module is V0 only, set VLLM_USE_V1=0 for
all tests in the module.
"""
monkeypatch.setenv('VLLM_USE_V1', '0')

View File

@ -0,0 +1,137 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import base64
import pytest
import requests
from tests.utils import RemoteOpenAIServer
from vllm.config import VllmConfig
from vllm.entrypoints.llm import LLM
from vllm.entrypoints.openai.protocol import IOProcessorResponse
from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams
MODEL_NAME = "christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"
image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff" # noqa: E501
def test_loading_missing_plugin():
vllm_config = VllmConfig()
with pytest.raises(ValueError):
get_io_processor(vllm_config, "wrong_plugin")
def test_loading_engine_with_wrong_plugin():
with pytest.raises(ValueError):
LLM(
model=MODEL_NAME,
skip_tokenizer_init=True,
trust_remote_code=True,
enforce_eager=True,
# Limit the maximum number of parallel requests
# to avoid the model going OOM in CI.
max_num_seqs=32,
io_processor_plugin="wrong_plugin",
)
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str):
img_prompt = dict(
data=image_url,
data_format="url",
image_format="tiff",
out_data_format="b64_json",
)
pooling_params = PoolingParams(task="encode", softmax=False)
with vllm_runner(
model_name,
runner="pooling",
skip_tokenizer_init=True,
trust_remote_code=True,
enforce_eager=True,
# Limit the maximum number of parallel requests
# to avoid the model going OOM in CI.
max_num_seqs=1,
io_processor_plugin="prithvi_to_tiff_valencia",
) as llm_runner:
pooler_output = llm_runner.get_llm().encode(
img_prompt,
pooling_params=pooling_params,
)
output = pooler_output[0].outputs
# verify the output is formatted as expected for this plugin
assert all(
hasattr(output, attr)
for attr in ["type", "format", "data", "request_id"])
# We just check that the output is a valid base64 string.
# Raises an exception and fails the test if the string is corrupted.
base64.b64decode(output.data)
@pytest.fixture(scope="module")
def server():
args = [
"--runner",
"pooling",
"--enforce-eager",
"--trust-remote-code",
"--skip-tokenizer-init",
# Limit the maximum number of parallel requests
# to avoid the model going OOM in CI.
"--max-num-seqs",
"32",
"--io-processor-plugin",
"prithvi_to_tiff_valencia"
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_prithvi_mae_plugin_online(
server: RemoteOpenAIServer,
model_name: str,
):
request_payload_url = {
"data": {
"data": image_url,
"data_format": "url",
"image_format": "tiff",
"out_data_format": "b64_json",
},
"priority": 0,
"model": model_name,
}
ret = requests.post(
server.url_for("pooling"),
json=request_payload_url,
)
response = ret.json()
# verify the request response is in the correct format
assert (parsed_response := IOProcessorResponse(**response))
# verify the output is formatted as expected for this plugin
plugin_data = parsed_response.data
assert all(
plugin_data.get(attr)
for attr in ["type", "format", "data", "request_id"])
# We just check that the output is a valid base64 string.
# Raises an exception and fails the test if the string is corrupted.
base64.b64decode(plugin_data["data"])

View File

@ -7,6 +7,15 @@ import torch
from vllm.plugins import load_general_plugins
@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
"""
Since this module is V0 only, set VLLM_USE_V1=0 for
all tests in the module.
"""
monkeypatch.setenv('VLLM_USE_V1', '0')
def test_platform_plugins():
# simulate workload by running an example
import runpy

View File

@ -501,6 +501,8 @@ class ModelConfig:
logits_processors: Optional[list[Union[str, type[LogitsProcessor]]]] = None
"""One or more logits processors' fully-qualified class names or class
definitions"""
io_processor_plugin: Optional[str] = None
"""IOProcessor plugin name to load at model startup"""
def compute_hash(self) -> str:
"""

View File

@ -364,6 +364,7 @@ class EngineArgs:
disable_mm_preprocessor_cache: bool = False # DEPRECATED
mm_processor_cache_gb: float = MultiModalConfig.mm_processor_cache_gb
mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode
io_processor_plugin: Optional[str] = None
skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling
# LoRA fields
enable_lora: bool = False
@ -577,6 +578,8 @@ class EngineArgs:
**model_kwargs["override_attention_dtype"])
model_group.add_argument("--logits-processors",
**model_kwargs["logits_processors"])
model_group.add_argument("--io-processor-plugin",
**model_kwargs["io_processor_plugin"])
# Model loading arguments
load_kwargs = get_kwargs(LoadConfig)
@ -993,6 +996,7 @@ class EngineArgs:
model_impl=self.model_impl,
override_attention_dtype=self.override_attention_dtype,
logits_processors=self.logits_processors,
io_processor_plugin=self.io_processor_plugin,
)
def validate_tensorizer_args(self):

View File

@ -15,6 +15,7 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors.interface import IOProcessor
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
@ -267,6 +268,9 @@ class EngineClient(ABC):
"""Get the appropriate tokenizer for the request"""
...
async def get_io_processor(self) -> IOProcessor:
raise NotImplementedError
@abstractmethod
async def is_tracing_enabled(self) -> bool:
...

View File

@ -37,13 +37,15 @@ from vllm.entrypoints.score_utils import (ScoreContentPartParam,
# yapf: enable
from vllm.entrypoints.utils import (_validate_truncation_size,
log_non_default_args)
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
from vllm.inputs import (DataPrompt, PromptType, SingletonPrompt, TextPrompt,
TokensPrompt)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
PoolingRequestOutput, RequestOutput,
ScoringRequestOutput)
from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import (BeamSearchParams, RequestOutputKind,
SamplingParams)
@ -284,6 +286,11 @@ class LLM:
self.supported_tasks = supported_tasks
# Load the Input/Output processor plugin if any
io_processor_plugin = self.llm_engine.model_config.io_processor_plugin
self.io_processor = get_io_processor(self.llm_engine.vllm_config,
io_processor_plugin)
def get_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
@ -833,7 +840,7 @@ class LLM:
def encode(
self,
prompts: Union[PromptType, Sequence[PromptType]],
prompts: Union[PromptType, Sequence[PromptType], DataPrompt],
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
*,
@ -915,6 +922,22 @@ class LLM:
if truncate_prompt_tokens is not None:
param.truncate_prompt_tokens = truncate_prompt_tokens
io_processor_prompt = False
if isinstance(prompts, dict) and "data" in prompts:
io_processor_prompt = True
if self.io_processor is None:
raise ValueError(
"No IOProcessor plugin installed. Please refer "
"to the documentation and to the "
"'prithvi_geospatial_mae_io_processor' "
"offline inference example for more details.")
# Validate the request data is valid for the loaded plugin
validated_prompt = self.io_processor.parse_request(prompts)
# obtain the actual model prompts from the pre-processor
prompts = self.io_processor.pre_process(prompt=validated_prompt)
self._validate_and_add_requests(
prompts=prompts,
params=pooling_params,
@ -923,8 +946,24 @@ class LLM:
)
outputs = self._run_engine(use_tqdm=use_tqdm)
return self.engine_class.validate_outputs(outputs,
PoolingRequestOutput)
model_outputs = self.engine_class.validate_outputs(
outputs, PoolingRequestOutput)
if io_processor_prompt:
# get the post-processed model outputs
assert self.io_processor is not None
processed_outputs = self.io_processor.post_process(
model_output=model_outputs)
return [
PoolingRequestOutput[Any](request_id="",
outputs=processed_outputs,
prompt_token_ids=[],
finished=True)
]
else:
return model_outputs
def embed(
self,

View File

@ -64,6 +64,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
EmbeddingRequest,
EmbeddingResponse, ErrorInfo,
ErrorResponse,
IOProcessorResponse,
LoadLoRAAdapterRequest,
PoolingRequest, PoolingResponse,
RerankRequest, RerankResponse,
@ -795,7 +796,7 @@ async def create_pooling(request: PoolingRequest, raw_request: Request):
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.error.code)
elif isinstance(generator, PoolingResponse):
elif isinstance(generator, (PoolingResponse, IOProcessorResponse)):
return JSONResponse(content=generator.model_dump())
assert_never(generator)
@ -1782,7 +1783,7 @@ async def init_app_state(
) if "generate" in supported_tasks else None
state.openai_serving_pooling = OpenAIServingPooling(
engine_client,
model_config,
vllm_config,
state.openai_serving_models,
request_logger=request_logger,
chat_template=resolved_chat_template,

View File

@ -6,7 +6,8 @@
import json
import time
from http import HTTPStatus
from typing import Annotated, Any, ClassVar, Literal, Optional, Union
from typing import (Annotated, Any, ClassVar, Generic, Literal, Optional,
TypeVar, Union)
import regex as re
import torch
@ -1405,7 +1406,46 @@ EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest]
PoolingCompletionRequest = EmbeddingCompletionRequest
PoolingChatRequest = EmbeddingChatRequest
PoolingRequest = Union[PoolingCompletionRequest, PoolingChatRequest]
T = TypeVar("T")
class IOProcessorRequest(OpenAIBaseModel, Generic[T]):
model: Optional[str] = None
priority: int = Field(default=0)
"""
The priority of the request (lower means earlier handling;
default: 0). Any priority other than 0 will raise an error
if the served model does not use priority scheduling.
"""
data: T
"""
When using plugins IOProcessor plugins, the actual input is processed
by the plugin itself. Hence, we use a generic type for the request data
"""
def to_pooling_params(self):
return PoolingParams(task="encode")
class IOProcessorResponse(OpenAIBaseModel, Generic[T]):
request_id: Optional[str] = None
"""
The request_id associated with this response
"""
created_at: int = Field(default_factory=lambda: int(time.time()))
data: T
"""
When using plugins IOProcessor plugins, the actual output is generated
by the plugin itself. Hence, we use a generic type for the response data
"""
PoolingRequest = Union[PoolingCompletionRequest, PoolingChatRequest,
IOProcessorRequest]
class ScoreRequest(OpenAIBaseModel):

View File

@ -49,9 +49,11 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
EmbeddingCompletionRequest,
EmbeddingRequest,
EmbeddingResponse, ErrorInfo,
ErrorResponse, PoolingResponse,
RerankRequest, ResponsesRequest,
ScoreRequest, ScoreResponse,
ErrorResponse,
IOProcessorRequest,
PoolingResponse, RerankRequest,
ResponsesRequest, ScoreRequest,
ScoreResponse,
TokenizeChatRequest,
TokenizeCompletionRequest,
TokenizeResponse,
@ -89,7 +91,7 @@ ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest,
TokenizeChatRequest]
SpeechToTextRequest = Union[TranscriptionRequest, TranslationRequest]
AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest, SpeechToTextRequest,
ResponsesRequest]
ResponsesRequest, IOProcessorRequest]
AnyResponse = Union[
CompletionResponse,

View File

@ -4,7 +4,7 @@
import asyncio
import base64
import time
from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, Sequence
from typing import Final, Literal, Optional, Union, cast
import jinja2
@ -13,19 +13,25 @@ import torch
from fastapi import Request
from typing_extensions import assert_never
from vllm.config import ModelConfig
from vllm.config import VllmConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger
# yapf: disable
from vllm.entrypoints.openai.protocol import (ErrorResponse,
IOProcessorRequest,
IOProcessorResponse,
PoolingChatRequest,
PoolingCompletionRequest,
PoolingRequest, PoolingResponse,
PoolingResponseData, UsageInfo)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
# yapf: enable
from vllm.entrypoints.openai.serving_engine import OpenAIServing, RequestPrompt
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.logger import init_logger
from vllm.outputs import PoolingOutput, PoolingRequestOutput
from vllm.plugins.io_processors import get_io_processor
from vllm.utils import merge_async_iterators
logger = init_logger(__name__)
@ -52,7 +58,7 @@ class OpenAIServingPooling(OpenAIServing):
def __init__(
self,
engine_client: EngineClient,
model_config: ModelConfig,
vllm_config: VllmConfig,
models: OpenAIServingModels,
*,
request_logger: Optional[RequestLogger],
@ -61,19 +67,21 @@ class OpenAIServingPooling(OpenAIServing):
log_error_stack: bool = False,
) -> None:
super().__init__(engine_client=engine_client,
model_config=model_config,
model_config=vllm_config.model_config,
models=models,
request_logger=request_logger,
log_error_stack=log_error_stack)
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
io_processor_plugin = self.model_config.io_processor_plugin
self.io_processor = get_io_processor(vllm_config, io_processor_plugin)
async def create_pooling(
self,
request: PoolingRequest,
raw_request: Optional[Request] = None,
) -> Union[PoolingResponse, ErrorResponse]:
) -> Union[PoolingResponse, IOProcessorResponse, ErrorResponse]:
"""
See https://platform.openai.com/docs/api-reference/embeddings/create
for the API specification. This API mimics the OpenAI Embedding API.
@ -82,20 +90,13 @@ class OpenAIServingPooling(OpenAIServing):
if error_check_ret is not None:
return error_check_ret
encoding_format = request.encoding_format
if request.dimensions is not None:
return self.create_error_response(
"dimensions is currently not supported")
model_name = self._get_model_name(request.model)
request_id = f"pool-{self._base_request_id(raw_request)}"
created_time = int(time.time())
truncate_prompt_tokens = request.truncate_prompt_tokens
is_io_processor_request = isinstance(request, IOProcessorRequest)
try:
truncate_prompt_tokens = _validate_truncation_size(
self.max_model_len, truncate_prompt_tokens)
lora_request = self._maybe_get_adapters(request)
if self.model_config.skip_tokenizer_init:
@ -104,7 +105,32 @@ class OpenAIServingPooling(OpenAIServing):
tokenizer = await self.engine_client.get_tokenizer(lora_request
)
if isinstance(request, PoolingChatRequest):
if getattr(request, "dimensions", None) is not None:
return self.create_error_response(
"dimensions is currently not supported")
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens",
None)
truncate_prompt_tokens = _validate_truncation_size(
self.max_model_len, truncate_prompt_tokens)
if is_io_processor_request:
if self.io_processor is None:
raise ValueError(
"No IOProcessor plugin installed. Please refer "
"to the documentation and to the "
"'prithvi_geospatial_mae_io_processor' "
"offline inference example for more details.")
validated_prompt = self.io_processor.parse_request(request)
engine_prompts = await self.io_processor.pre_process_async(
prompt=validated_prompt, request_id=request_id)
request_prompts: Sequence[RequestPrompt] = [
""
] * len(engine_prompts)
elif isinstance(request, PoolingChatRequest):
(
_,
request_prompts,
@ -122,7 +148,7 @@ class OpenAIServingPooling(OpenAIServing):
continue_final_message=False,
add_special_tokens=request.add_special_tokens,
)
else:
elif isinstance(request, PoolingCompletionRequest):
(request_prompts,
engine_prompts) = await self._preprocess_completion(
request,
@ -130,6 +156,9 @@ class OpenAIServingPooling(OpenAIServing):
request.input,
add_special_tokens=request.add_special_tokens,
)
else:
raise ValueError(
f"Unsupported request of type {type(request)}")
except (ValueError, TypeError, jinja2.TemplateError) as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
@ -171,6 +200,16 @@ class OpenAIServingPooling(OpenAIServing):
result_generator = merge_async_iterators(*generators)
if is_io_processor_request:
assert self.io_processor is not None
output = await self.io_processor.post_process_async(
model_output=result_generator,
request_id=request_id,
)
return self.io_processor.output_to_response(output)
assert isinstance(request,
(PoolingCompletionRequest, PoolingChatRequest))
num_prompts = len(engine_prompts)
# Non-streaming response
@ -190,7 +229,7 @@ class OpenAIServingPooling(OpenAIServing):
request_id,
created_time,
model_name,
encoding_format,
request.encoding_format,
)
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .data import (DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt,
from .data import (DataPrompt, DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt,
EncoderDecoderInputs, ExplicitEncoderDecoderPrompt,
ProcessorInputs, PromptType, SingletonInputs,
SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt,
@ -18,6 +18,7 @@ target model.
"""
__all__ = [
"DataPrompt",
"TextPrompt",
"TokensPrompt",
"PromptType",

View File

@ -95,6 +95,16 @@ class EmbedsPrompt(TypedDict):
"""
class DataPrompt(TypedDict):
"""Represents generic inputs handled by IO processor plugins."""
data: Any
"""The input data"""
data_format: str
"""The input data format"""
SingletonPrompt = Union[str, TextPrompt, TokensPrompt, EmbedsPrompt]
"""
Set of possible schemas for a single prompt:

View File

@ -0,0 +1,68 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import logging
from typing import Optional
from vllm.config import VllmConfig
from vllm.plugins import load_plugins_by_group
from vllm.plugins.io_processors.interface import IOProcessor
from vllm.utils import resolve_obj_by_qualname
logger = logging.getLogger(__name__)
def get_io_processor(
vllm_config: VllmConfig,
plugin_from_init: Optional[str] = None) -> IOProcessor | None:
# Input.Output processors are loaded as plugins under the
# 'vllm.io_processor_plugins' group. Similar to platform
# plugins, these plugins register a function that returns the class
# name for the processor to install.
if plugin_from_init:
model_plugin = plugin_from_init
else:
# A plugin can be specified via the model config
# Retrieve the model specific plugin if available
# This is using a custom field in the hf_config for the model
hf_config = vllm_config.model_config.hf_config.to_dict()
config_plugin = hf_config.get("io_processor_plugin")
model_plugin = config_plugin
if model_plugin is None:
logger.info("No IOProcessor plugins requested by the model")
return None
logger.debug("IOProcessor plugin to be loaded %s", model_plugin)
# Load all installed plugin in the group
multimodal_data_processor_plugins = \
load_plugins_by_group('vllm.io_processor_plugins')
loadable_plugins = {}
for name, func in multimodal_data_processor_plugins.items():
try:
assert callable(func)
processor_cls_qualname = func()
if processor_cls_qualname is not None:
loadable_plugins[name] = processor_cls_qualname
except Exception:
logger.warning("Failed to load plugin %s.", name, exc_info=True)
num_available_plugins = len(loadable_plugins.keys())
if num_available_plugins == 0:
raise ValueError("No IOProcessor plugins installed"
f" but one is required ({model_plugin}).")
if model_plugin not in loadable_plugins:
raise ValueError(
f"The model requires the '{model_plugin}' IO Processor plugin "
"but it is not installed. "
f"Available plugins: {list(loadable_plugins.keys())}")
activated_plugin_cls = loadable_plugins[model_plugin]
return resolve_obj_by_qualname(activated_plugin_cls)(vllm_config)

View File

@ -0,0 +1,62 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator, Sequence
from typing import Any, Generic, Optional, TypeVar, Union
from vllm.config import VllmConfig
from vllm.entrypoints.openai.protocol import IOProcessorResponse
from vllm.inputs.data import PromptType
from vllm.outputs import PoolingRequestOutput
IOProcessorInput = TypeVar('IOProcessorInput')
IOProcessorOutput = TypeVar('IOProcessorOutput')
class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]):
def __init__(self, vllm_config: VllmConfig):
self.vllm_config = vllm_config
@abstractmethod
def pre_process(
self,
prompt: IOProcessorInput,
request_id: Optional[str] = None,
**kwargs,
) -> Union[PromptType, Sequence[PromptType]]:
raise NotImplementedError
async def pre_process_async(
self,
prompt: IOProcessorInput,
request_id: Optional[str] = None,
**kwargs,
) -> Union[PromptType, Sequence[PromptType]]:
return self.pre_process(prompt, request_id, **kwargs)
@abstractmethod
def post_process(self,
model_output: Sequence[PoolingRequestOutput],
request_id: Optional[str] = None,
**kwargs) -> IOProcessorOutput:
raise NotImplementedError
async def post_process_async(
self,
model_output: AsyncGenerator[tuple[int, PoolingRequestOutput]],
request_id: Optional[str] = None,
**kwargs,
) -> IOProcessorOutput:
collected_output = [item async for i, item in model_output]
return self.post_process(collected_output, request_id, **kwargs)
@abstractmethod
def parse_request(self, request: Any) -> IOProcessorInput:
raise NotImplementedError
@abstractmethod
def output_to_response(
self, plugin_output: IOProcessorOutput) -> IOProcessorResponse:
raise NotImplementedError