mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)

[Misc] IO Processor plugins for pooling models (#22820)

Signed-off-by: Christian Pinto <christian.pinto@ibm.com>
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Co-authored-by: Max de Bayser <mbayser@br.ibm.com>
@@ -769,6 +769,11 @@ steps:
  - pytest -v -s plugins_tests/test_platform_plugins.py
  - pip uninstall vllm_add_dummy_platform -y
  # end platform plugin tests
  # begin io_processor plugins test, all the code in between uses the prithvi_io_processor plugin
  - pip install -e ./plugins/prithvi_io_processor_plugin
  - pytest -v -s plugins_tests/test_io_processor_plugins.py
  - pip uninstall prithvi_io_processor_plugin -y
  # end io_processor plugins test
  # other tests continue here:
  - pytest -v -s plugins_tests/test_scheduler_plugins.py
  - pip install -e ./plugins/vllm_add_dummy_model
docs/design/io_processor_plugins.md (new file, 78 lines)
@@ -0,0 +1,78 @@
# IO Processor Plugins

IO Processor plugins are a feature that allows pre- and post-processing of the model input and output for pooling models. The idea is that users can pass a custom input to vLLM that is converted into one or more model prompts and fed to the model's `encode` method. One potential use case of such plugins is using vLLM for generating multi-modal data: for example, a user feeds an image to vLLM and gets an image back as output.

When performing an inference with IO Processor plugins, the prompt type is defined by the plugin, and the same applies to the final request output. vLLM does not perform any validation of input/output data, and it is up to the plugin to ensure the correct data is being fed to the model and returned to the user. As of now these plugins support only pooling models and can be triggered via the `encode` method in `LLM` and `AsyncLLM`, or in online serving mode via the `/pooling` endpoint.

## Writing an IO Processor Plugin

IO Processor plugins implement the `IOProcessor` interface (<gh-file:vllm/plugins/io_processors/interface.py>):

```python
IOProcessorInput = TypeVar('IOProcessorInput')
IOProcessorOutput = TypeVar('IOProcessorOutput')


class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]):

    def __init__(self, vllm_config: VllmConfig):
        self.vllm_config = vllm_config

    @abstractmethod
    def pre_process(
        self,
        prompt: IOProcessorInput,
        request_id: Optional[str] = None,
        **kwargs,
    ) -> Union[PromptType, Sequence[PromptType]]:
        raise NotImplementedError

    async def pre_process_async(
        self,
        prompt: IOProcessorInput,
        request_id: Optional[str] = None,
        **kwargs,
    ) -> Union[PromptType, Sequence[PromptType]]:
        return self.pre_process(prompt, request_id, **kwargs)

    @abstractmethod
    def post_process(self,
                     model_output: Sequence[PoolingRequestOutput],
                     request_id: Optional[str] = None,
                     **kwargs) -> IOProcessorOutput:
        raise NotImplementedError

    async def post_process_async(
        self,
        model_output: AsyncGenerator[tuple[int, PoolingRequestOutput]],
        request_id: Optional[str] = None,
        **kwargs,
    ) -> IOProcessorOutput:
        collected_output = [item async for i, item in model_output]
        return self.post_process(collected_output, request_id, **kwargs)

    @abstractmethod
    def parse_request(self, request: Any) -> IOProcessorInput:
        raise NotImplementedError

    @abstractmethod
    def output_to_response(
            self, plugin_output: IOProcessorOutput) -> IOProcessorResponse:
        raise NotImplementedError
```

The `parse_request` method is used for validating the user prompt and converting it into the input expected by the `pre_process`/`pre_process_async` methods.
The `pre_process*` methods take the validated plugin input and generate vLLM's model prompts for regular inference.
The `post_process*` methods take `PoolingRequestOutput` objects as input and generate a custom plugin output.

The `output_to_response` method is used only for online serving and converts the plugin output to the `IOProcessorResponse` type that is then returned by the API Server. The implementation of the `/io_processor_pooling` serving endpoint is [here](../../vllm/entrypoints/openai/serving_pooling_with_io_plugin.py).
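To make the flow concrete, here is a minimal sketch of a custom plugin. All names in it (`MyIOProcessor`, the dict-based input and output types, the prompt contents) are illustrative assumptions, not part of vLLM; refer to the Prithvi plugin linked below for a complete, real implementation.

```python
from collections.abc import Sequence
from typing import Any, Optional, Union

from vllm.config import VllmConfig
from vllm.entrypoints.openai.protocol import IOProcessorResponse
from vllm.inputs.data import PromptType
from vllm.outputs import PoolingRequestOutput
from vllm.plugins.io_processors.interface import IOProcessor


class MyIOProcessor(IOProcessor[dict, dict]):

    def __init__(self, vllm_config: VllmConfig):
        super().__init__(vllm_config)

    def parse_request(self, request: Any) -> dict:
        # Validate the raw user request and convert it into the plugin's
        # own input type (a plain dict in this sketch).
        if isinstance(request, dict) and "data" in request:
            return request
        raise ValueError("Unable to parse request")

    def pre_process(
        self,
        prompt: dict,
        request_id: Optional[str] = None,
        **kwargs,
    ) -> Union[PromptType, Sequence[PromptType]]:
        # Turn the custom input into one or more regular vLLM prompts.
        return [{
            "prompt_token_ids": [1],
            "multi_modal_data": {"pixel_values": prompt["data"]},
        }]

    def post_process(
        self,
        model_output: Sequence[PoolingRequestOutput],
        request_id: Optional[str] = None,
        **kwargs,
    ) -> dict:
        # Convert the raw pooling outputs into the plugin's output type.
        return {
            "request_id": request_id,
            "data": [out.outputs.data for out in model_output],
        }

    def output_to_response(self, plugin_output: dict) -> IOProcessorResponse:
        # Only needed for online serving mode.
        return IOProcessorResponse(request_id=plugin_output["request_id"],
                                   data=plugin_output)
```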
An example implementation of a plugin that enables generating geotiff images with the PrithviGeospatialMAE model is available [here](https://github.com/christian-pinto/prithvi_io_processor_plugin). Please also refer to our [online](../../examples/online_serving/prithvi_geospatial_mae.py) and [offline](../../examples/offline_inference/prithvi_geospatial_mae_io_processor.py) inference examples.

## Using an IO Processor plugin

IO Processor plugins are loaded at engine startup and there are two methods for specifying the name of the plugin to be loaded:

1. Via vLLM's `EngineArgs`: setting the `io_processor_plugin` argument in the `EngineArgs` used to initialize the `AsyncLLM`. The same can be achieved by passing the `io_processor_plugin` argument to `LLM` in offline mode, or by passing the `--io-processor-plugin` argument in serving mode.
2. Via the model HF configuration: adding an `io_processor_plugin` field to the model config (`config.json`).

The order above also determines the priority: a plugin name set via `EngineArgs` overrides any plugin name specified in the model HF config (`config.json`). A minimal offline usage sketch is shown below.
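The following is a minimal sketch of method 1 in offline mode, assuming the example Prithvi plugin (which registers the `prithvi_to_tiff_india` entry point) is installed; see the full offline example for the complete workflow.

```python
from vllm import LLM

# The plugin name must match an installed entry point in the
# "vllm.io_processor_plugins" group; passing it here overrides any
# "io_processor_plugin" field in the model's config.json.
llm = LLM(
    model="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
    skip_tokenizer_init=True,
    trust_remote_code=True,
    io_processor_plugin="prithvi_to_tiff_india",
)
```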
@@ -49,6 +49,8 @@ Every plugin has three parts:

- **Platform plugins** (with group name `vllm.platform_plugins`): The primary use case for these plugins is to register custom, out-of-the-tree platforms into vLLM. The plugin function should return `None` when the platform is not supported in the current environment, or the platform class's fully qualified name when the platform is supported.

- **IO Processor plugins** (with group name `vllm.io_processor_plugins`): The primary use case for these plugins is to register custom pre/post processing of the model prompt and model output for pooling models. The plugin function returns the IOProcessor class's fully qualified name (see the registration sketch after this list).
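For illustration, a minimal registration sketch follows. The package name `my_io_processor`, its module path, and the class name are hypothetical; the entry-point group name is the real one used by vLLM, and the test plugin in this PR registers the Prithvi processors in the same way.

```python
# my_io_processor/__init__.py (hypothetical package)
def register_my_plugin():
    # Return the fully qualified name of the IOProcessor subclass.
    return "my_io_processor.processor.MyIOProcessor"
```

```python
# setup.py of the same hypothetical package
from setuptools import setup

setup(
    name="my_io_processor",
    version="0.1",
    packages=["my_io_processor"],
    entry_points={
        "vllm.io_processor_plugins": [
            # The entry point name ("my_plugin") is what users pass to vLLM
            # via io_processor_plugin / --io-processor-plugin.
            "my_plugin = my_io_processor:register_my_plugin",
        ]
    },
)
```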
## Guidelines for Writing Plugins

- **Being re-entrant**: The function specified in the entry point should be re-entrant, meaning it can be called multiple times without causing issues. This is necessary because the function might be called multiple times in some processes.
examples/offline_inference/prithvi_geospatial_mae_io_processor.py (new file)
@@ -0,0 +1,60 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import base64
import os

import torch

from vllm import LLM
from vllm.pooling_params import PoolingParams

# This example shows how to perform an offline inference that generates
# multimodal data. In this specific case this example will take a geotiff
# image as input, process it using the multimodal data processor, and
# perform inference.
# Requirement - install plugin at:
#   https://github.com/christian-pinto/prithvi_io_processor_plugin


def main():
    torch.set_default_dtype(torch.float16)
    image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/India_900498_S2Hand.tif"  # noqa: E501

    img_prompt = dict(
        data=image_url,
        data_format="url",
        image_format="tiff",
        out_data_format="b64_json",
    )

    llm = LLM(
        model="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
        skip_tokenizer_init=True,
        trust_remote_code=True,
        enforce_eager=True,
        # Limit the maximum number of parallel requests
        # to avoid the model going OOM.
        # The maximum number depends on the available GPU memory.
        max_num_seqs=32,
        io_processor_plugin="prithvi_to_tiff_india",
    )

    pooling_params = PoolingParams(task="encode", softmax=False)
    pooler_output = llm.encode(
        img_prompt,
        pooling_params=pooling_params,
    )
    output = pooler_output[0].outputs

    print(output)
    decoded_data = base64.b64decode(output.data)

    file_path = os.path.join(os.getcwd(), "offline_prediction.tiff")
    with open(file_path, "wb") as f:
        f.write(decoded_data)

    print(f"Output file path: {file_path}")


if __name__ == "__main__":
    main()
examples/online_serving/prithvi_geospatial_mae.py (new file, 54 lines)
@@ -0,0 +1,54 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import base64
import os

import requests

# This example shows how to perform an online inference that generates
# multimodal data. In this specific case this example will take a geotiff
# image as input, process it using the multimodal data processor, and
# perform inference.
# Requirements:
# - install plugin at:
#   https://github.com/christian-pinto/prithvi_io_processor_plugin
# - start vllm in serving mode with the below args
#   --model='christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM'
#   --task embed --trust-remote-code
#   --skip-tokenizer-init --enforce-eager
#   --io-processor-plugin prithvi_to_tiff_india


def main():
    image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/India_900498_S2Hand.tif"  # noqa: E501
    server_endpoint = "http://localhost:8000/pooling"

    request_payload_url = {
        "data": {
            "data": image_url,
            "data_format": "url",
            "image_format": "tiff",
            "out_data_format": "b64_json",
        },
        "priority": 0,
        "model": "christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
    }

    ret = requests.post(server_endpoint, json=request_payload_url)

    print(f"response.status_code: {ret.status_code}")
    print(f"response.reason: {ret.reason}")

    response = ret.json()

    decoded_image = base64.b64decode(response["data"]["data"])

    out_path = os.path.join(os.getcwd(), "online_prediction.tiff")

    with open(out_path, "wb") as f:
        f.write(decoded_image)


if __name__ == "__main__":
    main()
@@ -1120,6 +1120,9 @@ class VllmRunner:

        return self.llm.llm_engine.collective_rpc(_apply_model)

    def get_llm(self) -> LLM:
        return self.llm

    def __enter__(self):
        return self
@@ -0,0 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
def register_prithvi_india():
    return "prithvi_io_processor.prithvi_processor.PrithviMultimodalDataProcessorIndia"  # noqa: E501


def register_prithvi_valencia():
    return "prithvi_io_processor.prithvi_processor.PrithviMultimodalDataProcessorValencia"  # noqa: E501
@ -0,0 +1,449 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import datetime
|
||||
import os
|
||||
import tempfile
|
||||
import urllib.request
|
||||
from collections.abc import AsyncGenerator, Sequence
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import albumentations
|
||||
import numpy as np
|
||||
import rasterio
|
||||
import regex as re
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from terratorch.datamodules import Sen1Floods11NonGeoDataModule
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.entrypoints.openai.protocol import (IOProcessorRequest,
|
||||
IOProcessorResponse)
|
||||
from vllm.inputs.data import PromptType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import PoolingRequestOutput
|
||||
from vllm.plugins.io_processors.interface import (IOProcessor,
|
||||
IOProcessorInput,
|
||||
IOProcessorOutput)
|
||||
|
||||
from .types import DataModuleConfig, ImagePrompt, ImageRequestOutput
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
NO_DATA = -9999
|
||||
NO_DATA_FLOAT = 0.0001
|
||||
OFFSET = 0
|
||||
PERCENTILE = 99
|
||||
|
||||
DEFAULT_INPUT_INDICES = [0, 1, 2, 3, 4, 5]
|
||||
|
||||
datamodule_config: DataModuleConfig = {
|
||||
"bands": ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"],
|
||||
"batch_size":
|
||||
16,
|
||||
"constant_scale":
|
||||
0.0001,
|
||||
"data_root":
|
||||
"/dccstor/geofm-finetuning/datasets/sen1floods11",
|
||||
"drop_last":
|
||||
True,
|
||||
"no_data_replace":
|
||||
0.0,
|
||||
"no_label_replace":
|
||||
-1,
|
||||
"num_workers":
|
||||
8,
|
||||
"test_transform": [
|
||||
albumentations.Resize(always_apply=False,
|
||||
height=448,
|
||||
interpolation=1,
|
||||
p=1,
|
||||
width=448),
|
||||
albumentations.pytorch.ToTensorV2(transpose_mask=False,
|
||||
always_apply=True,
|
||||
p=1.0),
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def save_geotiff(image: torch.Tensor, meta: dict,
|
||||
out_format: str) -> str | bytes:
|
||||
"""Save multi-band image in Geotiff file.
|
||||
|
||||
Args:
|
||||
image: np.ndarray with shape (bands, height, width)
|
||||
output_path: path where to save the image
|
||||
meta: dict with meta info.
|
||||
"""
|
||||
if out_format == "path":
|
||||
# create temp file
|
||||
file_path = os.path.join(os.getcwd(), "prediction.tiff")
|
||||
with rasterio.open(file_path, "w", **meta) as dest:
|
||||
for i in range(image.shape[0]):
|
||||
dest.write(image[i, :, :], i + 1)
|
||||
|
||||
return file_path
|
||||
elif out_format == "b64_json":
|
||||
with tempfile.NamedTemporaryFile() as tmpfile:
|
||||
with rasterio.open(tmpfile.name, "w", **meta) as dest:
|
||||
for i in range(image.shape[0]):
|
||||
dest.write(image[i, :, :], i + 1)
|
||||
|
||||
file_data = tmpfile.read()
|
||||
return base64.b64encode(file_data)
|
||||
|
||||
else:
|
||||
raise ValueError("Unknown output format")
|
||||
|
||||
|
||||
def _convert_np_uint8(float_image: torch.Tensor):
|
||||
image = float_image.numpy() * 255.0
|
||||
image = image.astype(dtype=np.uint8)
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def read_geotiff(
|
||||
file_path: Optional[str] = None,
|
||||
path_type: Optional[str] = None,
|
||||
file_data: Optional[bytes] = None,
|
||||
) -> tuple[torch.Tensor, dict, tuple[float, float] | None]:
|
||||
"""Read all bands from *file_path* and return image + meta info.
|
||||
|
||||
Args:
|
||||
file_path: path to image file.
|
||||
|
||||
Returns:
|
||||
np.ndarray with shape (bands, height, width)
|
||||
meta info dict
|
||||
"""
|
||||
|
||||
if all([x is None for x in [file_path, path_type, file_data]]):
|
||||
raise Exception("All input fields to read_geotiff are None")
|
||||
write_to_file: Optional[bytes] = None
|
||||
path: Optional[str] = None
|
||||
if file_data is not None:
|
||||
# with tempfile.NamedTemporaryFile() as tmpfile:
|
||||
# tmpfile.write(file_data)
|
||||
# path = tmpfile.name
|
||||
|
||||
write_to_file = file_data
|
||||
elif file_path is not None and path_type == "url":
|
||||
resp = urllib.request.urlopen(file_path)
|
||||
# with tempfile.NamedTemporaryFile() as tmpfile:
|
||||
# tmpfile.write(resp.read())
|
||||
# path = tmpfile.name
|
||||
write_to_file = resp.read()
|
||||
elif file_path is not None and path_type == "path":
|
||||
path = file_path
|
||||
elif file_path is not None and path_type == "b64_json":
|
||||
image_data = base64.b64decode(file_path)
|
||||
# with tempfile.NamedTemporaryFile() as tmpfile:
|
||||
# tmpfile.write(image_data)
|
||||
# path = tmpfile.name
|
||||
write_to_file = image_data
|
||||
else:
|
||||
raise Exception("Wrong combination of parameters to read_geotiff")
|
||||
|
||||
with tempfile.NamedTemporaryFile() as tmpfile:
|
||||
path_to_use = None
|
||||
if write_to_file:
|
||||
tmpfile.write(write_to_file)
|
||||
path_to_use = tmpfile.name
|
||||
elif path:
|
||||
path_to_use = path
|
||||
|
||||
with rasterio.open(path_to_use) as src:
|
||||
img = src.read()
|
||||
meta = src.meta
|
||||
try:
|
||||
coords = src.lnglat()
|
||||
except Exception:
|
||||
# Cannot read coords
|
||||
coords = None
|
||||
|
||||
return img, meta, coords
|
||||
|
||||
|
||||
def load_image(
|
||||
data: Union[list[str]],
|
||||
path_type: str,
|
||||
mean: Optional[list[float]] = None,
|
||||
std: Optional[list[float]] = None,
|
||||
indices: Optional[Union[list[int], None]] = None,
|
||||
):
|
||||
"""Build an input example by loading images in *file_paths*.
|
||||
|
||||
Args:
|
||||
file_paths: list of file paths .
|
||||
mean: list containing mean values for each band in the
|
||||
images in *file_paths*.
|
||||
std: list containing std values for each band in the
|
||||
images in *file_paths*.
|
||||
|
||||
Returns:
|
||||
np.array containing created example
|
||||
list of meta info for each image in *file_paths*
|
||||
"""
|
||||
|
||||
imgs = []
|
||||
metas = []
|
||||
temporal_coords = []
|
||||
location_coords = []
|
||||
|
||||
for file in data:
|
||||
# if isinstance(file, bytes):
|
||||
# img, meta, coords = read_geotiff(file_data=file)
|
||||
# else:
|
||||
img, meta, coords = read_geotiff(file_path=file, path_type=path_type)
|
||||
# Rescaling (don't normalize on nodata)
|
||||
img = np.moveaxis(img, 0, -1) # channels last for rescaling
|
||||
if indices is not None:
|
||||
img = img[..., indices]
|
||||
if mean is not None and std is not None:
|
||||
img = np.where(img == NO_DATA, NO_DATA_FLOAT, (img - mean) / std)
|
||||
|
||||
imgs.append(img)
|
||||
metas.append(meta)
|
||||
if coords is not None:
|
||||
location_coords.append(coords)
|
||||
|
||||
try:
|
||||
match = re.search(r"(\d{7,8}T\d{6})", file)
|
||||
if match:
|
||||
year = int(match.group(1)[:4])
|
||||
julian_day = match.group(1).split("T")[0][4:]
|
||||
if len(julian_day) == 3:
|
||||
julian_day = int(julian_day)
|
||||
else:
|
||||
julian_day = (datetime.datetime.strptime(
|
||||
julian_day, "%m%d").timetuple().tm_yday)
|
||||
temporal_coords.append([year, julian_day])
|
||||
except Exception:
|
||||
logger.exception("Could not extract timestamp for %s", file)
|
||||
|
||||
imgs = np.stack(imgs, axis=0) # num_frames, H, W, C
|
||||
imgs = np.moveaxis(imgs, -1, 0).astype("float32") # C, num_frames, H, W
|
||||
imgs = np.expand_dims(imgs, axis=0)  # add batch dim
|
||||
|
||||
return imgs, temporal_coords, location_coords, metas
|
||||
|
||||
|
||||
class PrithviMultimodalDataProcessor(IOProcessor):
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig):
|
||||
|
||||
super().__init__(vllm_config)
|
||||
|
||||
self.datamodule = Sen1Floods11NonGeoDataModule(
|
||||
data_root=datamodule_config["data_root"],
|
||||
batch_size=datamodule_config["batch_size"],
|
||||
num_workers=datamodule_config["num_workers"],
|
||||
bands=datamodule_config["bands"],
|
||||
drop_last=datamodule_config["drop_last"],
|
||||
test_transform=datamodule_config["test_transform"],
|
||||
)
|
||||
self.img_size = 512
|
||||
self.h1 = 1
|
||||
self.w1 = 1
|
||||
self.original_h = 512
|
||||
self.original_w = 512
|
||||
self.batch_size = 1
|
||||
self.meta_data = None
|
||||
self.requests_cache: dict[str, dict[str, Any]] = {}
|
||||
self.indices = DEFAULT_INPUT_INDICES
|
||||
|
||||
def parse_request(self, request: Any) -> IOProcessorInput:
|
||||
if type(request) is dict:
|
||||
image_prompt = ImagePrompt(**request)
|
||||
return image_prompt
|
||||
if isinstance(request, IOProcessorRequest):
|
||||
if not hasattr(request, "data"):
|
||||
raise ValueError(
|
||||
"missing 'data' field in OpenAIBaseModel Request")
|
||||
|
||||
request_data = request.data
|
||||
|
||||
if type(request_data) is dict:
|
||||
return ImagePrompt(**request_data)
|
||||
else:
|
||||
raise ValueError("Unable to parse the request data")
|
||||
|
||||
raise ValueError("Unable to parse request")
|
||||
|
||||
def output_to_response(
|
||||
self, plugin_output: IOProcessorOutput) -> IOProcessorResponse:
|
||||
return IOProcessorResponse(
|
||||
request_id=plugin_output.request_id,
|
||||
data=plugin_output,
|
||||
)
|
||||
|
||||
def pre_process(
|
||||
self,
|
||||
prompt: IOProcessorInput,
|
||||
request_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> Union[PromptType, Sequence[PromptType]]:
|
||||
|
||||
image_data = dict(prompt)
|
||||
|
||||
if request_id:
|
||||
self.requests_cache[request_id] = {
|
||||
"out_format": image_data["out_data_format"],
|
||||
}
|
||||
|
||||
input_data, temporal_coords, location_coords, meta_data = load_image(
|
||||
data=[image_data["data"]],
|
||||
indices=self.indices,
|
||||
path_type=image_data["data_format"],
|
||||
)
|
||||
|
||||
self.meta_data = meta_data[0]
|
||||
|
||||
if input_data.mean() > 1:
|
||||
input_data = input_data / 10000 # Convert to range 0-1
|
||||
|
||||
self.original_h, self.original_w = input_data.shape[-2:]
|
||||
pad_h = (self.img_size -
|
||||
(self.original_h % self.img_size)) % self.img_size
|
||||
pad_w = (self.img_size -
|
||||
(self.original_w % self.img_size)) % self.img_size
|
||||
input_data = np.pad(
|
||||
input_data,
|
||||
((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)),
|
||||
mode="reflect",
|
||||
)
|
||||
|
||||
batch = torch.tensor(input_data)
|
||||
windows = batch.unfold(3, self.img_size,
|
||||
self.img_size).unfold(4, self.img_size,
|
||||
self.img_size)
|
||||
self.h1, self.w1 = windows.shape[3:5]
|
||||
windows = rearrange(
|
||||
windows,
|
||||
"b c t h1 w1 h w -> (b h1 w1) c t h w",
|
||||
h=self.img_size,
|
||||
w=self.img_size,
|
||||
)
|
||||
|
||||
# Split into batches if number of windows > batch_size
|
||||
num_batches = (windows.shape[0] // self.batch_size
|
||||
if windows.shape[0] > self.batch_size else 1)
|
||||
windows = torch.tensor_split(windows, num_batches, dim=0)
|
||||
|
||||
if temporal_coords:
|
||||
temporal_coords = torch.tensor(temporal_coords).unsqueeze(0)
|
||||
else:
|
||||
temporal_coords = None
|
||||
if location_coords:
|
||||
location_coords = torch.tensor(location_coords[0]).unsqueeze(0)
|
||||
else:
|
||||
location_coords = None
|
||||
|
||||
prompts = []
|
||||
for window in windows:
|
||||
# Apply standardization
|
||||
window = self.datamodule.test_transform(
|
||||
image=window.squeeze().numpy().transpose(1, 2, 0))
|
||||
window = self.datamodule.aug(window)["image"]
|
||||
prompts.append({
|
||||
"prompt_token_ids": [1],
|
||||
"multi_modal_data": {
|
||||
"pixel_values": window.to(torch.float16)[0],
|
||||
"location_coords": location_coords.to(torch.float16),
|
||||
},
|
||||
})
|
||||
|
||||
return prompts
|
||||
|
||||
async def pre_process_async(
|
||||
self,
|
||||
prompt: IOProcessorInput,
|
||||
request_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> Union[PromptType, Sequence[PromptType]]:
|
||||
return self.pre_process(prompt, request_id, **kwargs)
|
||||
|
||||
def post_process(
|
||||
self,
|
||||
model_output: Sequence[PoolingRequestOutput],
|
||||
request_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> IOProcessorOutput:
|
||||
|
||||
pred_imgs_list = []
|
||||
|
||||
if request_id and (request_id in self.requests_cache):
|
||||
out_format = self.requests_cache[request_id]["out_format"]
|
||||
else:
|
||||
out_format = "b64_json"
|
||||
|
||||
for output in model_output:
|
||||
y_hat = output.outputs.data.argmax(dim=1)
|
||||
pred = torch.nn.functional.interpolate(
|
||||
y_hat.unsqueeze(1).float(),
|
||||
size=self.img_size,
|
||||
mode="nearest",
|
||||
)
|
||||
pred_imgs_list.append(pred)
|
||||
|
||||
pred_imgs: torch.Tensor = torch.concat(pred_imgs_list, dim=0)
|
||||
|
||||
# Build images from patches
|
||||
pred_imgs = rearrange(
|
||||
pred_imgs,
|
||||
"(b h1 w1) c h w -> b c (h1 h) (w1 w)",
|
||||
h=self.img_size,
|
||||
w=self.img_size,
|
||||
b=1,
|
||||
c=1,
|
||||
h1=self.h1,
|
||||
w1=self.w1,
|
||||
)
|
||||
|
||||
# Cut padded area back to original size
|
||||
pred_imgs = pred_imgs[..., :self.original_h, :self.original_w]
|
||||
|
||||
# Squeeze (batch size 1)
|
||||
pred_imgs = pred_imgs[0]
|
||||
|
||||
if not self.meta_data:
|
||||
raise ValueError("No metadata available for the current task")
|
||||
self.meta_data.update(count=1, dtype="uint8", compress="lzw", nodata=0)
|
||||
out_data = save_geotiff(_convert_np_uint8(pred_imgs), self.meta_data,
|
||||
out_format)
|
||||
|
||||
return ImageRequestOutput(type=out_format,
|
||||
format="tiff",
|
||||
data=out_data,
|
||||
request_id=request_id)
|
||||
|
||||
async def post_process_async(
|
||||
self,
|
||||
model_output: AsyncGenerator[tuple[int, PoolingRequestOutput]],
|
||||
request_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> IOProcessorOutput:
|
||||
collected_output = [item async for i, item in model_output]
|
||||
return self.post_process(collected_output, request_id, **kwargs)
|
||||
|
||||
|
||||
class PrithviMultimodalDataProcessorIndia(PrithviMultimodalDataProcessor):
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig):
|
||||
|
||||
super().__init__(vllm_config)
|
||||
|
||||
self.indices = [1, 2, 3, 8, 11, 12]
|
||||
|
||||
|
||||
class PrithviMultimodalDataProcessorValencia(PrithviMultimodalDataProcessor):
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig):
|
||||
|
||||
super().__init__(vllm_config)
|
||||
|
||||
self.indices = [0, 1, 2, 3, 4, 5]
|
@ -0,0 +1,59 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Literal, Optional, TypedDict, Union
|
||||
|
||||
import albumentations
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class DataModuleConfig(TypedDict):
|
||||
bands: list[str]
|
||||
batch_size: int
|
||||
constant_scale: float
|
||||
data_root: str
|
||||
drop_last: bool
|
||||
no_data_replace: float
|
||||
no_label_replace: int
|
||||
num_workers: int
|
||||
test_transform: list[
|
||||
albumentations.core.transforms_interface.BasicTransform]
|
||||
|
||||
|
||||
class ImagePrompt(BaseModel):
|
||||
|
||||
data_format: Literal["b64_json", "bytes", "url"]
|
||||
"""
|
||||
This is the data type for the input image
|
||||
"""
|
||||
|
||||
image_format: str
|
||||
"""
|
||||
This is the image format (e.g., jpeg, png, etc.)
|
||||
"""
|
||||
|
||||
out_data_format: Literal["b64_json", "url"]
|
||||
|
||||
data: Any
|
||||
"""
|
||||
Input image data
|
||||
"""
|
||||
|
||||
|
||||
MultiModalPromptType = Union[ImagePrompt]
|
||||
|
||||
|
||||
class ImageRequestOutput(BaseModel):
|
||||
"""
|
||||
The output data of an image request to vLLM.
|
||||
|
||||
Args:
|
||||
type (str): The data content type [path, object]
|
||||
format (str): The image format (e.g., jpeg, png, etc.)
|
||||
data (Any): The resulting data.
|
||||
"""
|
||||
|
||||
type: Literal["path", "b64_json"]
|
||||
format: str
|
||||
data: str
|
||||
request_id: Optional[str] = None
|
tests/plugins/prithvi_io_processor_plugin/setup.py (new file, 16 lines)
@@ -0,0 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from setuptools import setup

setup(
    name="prithvi_io_processor_plugin",
    version="0.1",
    packages=["prithvi_io_processor"],
    entry_points={
        "vllm.io_processor_plugins": [
            "prithvi_to_tiff_india = prithvi_io_processor:register_prithvi_india",  # noqa: E501
            "prithvi_to_tiff_valencia = prithvi_io_processor:register_prithvi_valencia",  # noqa: E501
        ]
    },
)
@@ -1,12 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest


@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
    """
    Since this module is V0 only, set VLLM_USE_V1=0 for
    all tests in the module.
    """
    monkeypatch.setenv('VLLM_USE_V1', '0')
tests/plugins_tests/test_io_processor_plugins.py (new file, 137 lines)
@ -0,0 +1,137 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import base64
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.entrypoints.llm import LLM
|
||||
from vllm.entrypoints.openai.protocol import IOProcessorResponse
|
||||
from vllm.plugins.io_processors import get_io_processor
|
||||
from vllm.pooling_params import PoolingParams
|
||||
|
||||
MODEL_NAME = "christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"
|
||||
|
||||
image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff" # noqa: E501
|
||||
|
||||
|
||||
def test_loading_missing_plugin():
|
||||
vllm_config = VllmConfig()
|
||||
with pytest.raises(ValueError):
|
||||
get_io_processor(vllm_config, "wrong_plugin")
|
||||
|
||||
|
||||
def test_loading_engine_with_wrong_plugin():
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
LLM(
|
||||
model=MODEL_NAME,
|
||||
skip_tokenizer_init=True,
|
||||
trust_remote_code=True,
|
||||
enforce_eager=True,
|
||||
# Limit the maximum number of parallel requests
|
||||
# to avoid the model going OOM in CI.
|
||||
max_num_seqs=32,
|
||||
io_processor_plugin="wrong_plugin",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str):
|
||||
|
||||
img_prompt = dict(
|
||||
data=image_url,
|
||||
data_format="url",
|
||||
image_format="tiff",
|
||||
out_data_format="b64_json",
|
||||
)
|
||||
|
||||
pooling_params = PoolingParams(task="encode", softmax=False)
|
||||
|
||||
with vllm_runner(
|
||||
model_name,
|
||||
runner="pooling",
|
||||
skip_tokenizer_init=True,
|
||||
trust_remote_code=True,
|
||||
enforce_eager=True,
|
||||
# Limit the maximum number of parallel requests
|
||||
# to avoid the model going OOM in CI.
|
||||
max_num_seqs=1,
|
||||
io_processor_plugin="prithvi_to_tiff_valencia",
|
||||
) as llm_runner:
|
||||
pooler_output = llm_runner.get_llm().encode(
|
||||
img_prompt,
|
||||
pooling_params=pooling_params,
|
||||
)
|
||||
output = pooler_output[0].outputs
|
||||
|
||||
# verify the output is formatted as expected for this plugin
|
||||
assert all(
|
||||
hasattr(output, attr)
|
||||
for attr in ["type", "format", "data", "request_id"])
|
||||
|
||||
# We just check that the output is a valid base64 string.
|
||||
# Raises an exception and fails the test if the string is corrupted.
|
||||
base64.b64decode(output.data)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
args = [
|
||||
"--runner",
|
||||
"pooling",
|
||||
"--enforce-eager",
|
||||
"--trust-remote-code",
|
||||
"--skip-tokenizer-init",
|
||||
# Limit the maximum number of parallel requests
|
||||
# to avoid the model going OOM in CI.
|
||||
"--max-num-seqs",
|
||||
"32",
|
||||
"--io-processor-plugin",
|
||||
"prithvi_to_tiff_valencia"
|
||||
]
|
||||
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_prithvi_mae_plugin_online(
|
||||
server: RemoteOpenAIServer,
|
||||
model_name: str,
|
||||
):
|
||||
|
||||
request_payload_url = {
|
||||
"data": {
|
||||
"data": image_url,
|
||||
"data_format": "url",
|
||||
"image_format": "tiff",
|
||||
"out_data_format": "b64_json",
|
||||
},
|
||||
"priority": 0,
|
||||
"model": model_name,
|
||||
}
|
||||
|
||||
ret = requests.post(
|
||||
server.url_for("pooling"),
|
||||
json=request_payload_url,
|
||||
)
|
||||
|
||||
response = ret.json()
|
||||
|
||||
# verify the request response is in the correct format
|
||||
assert (parsed_response := IOProcessorResponse(**response))
|
||||
|
||||
# verify the output is formatted as expected for this plugin
|
||||
plugin_data = parsed_response.data
|
||||
|
||||
assert all(
|
||||
plugin_data.get(attr)
|
||||
for attr in ["type", "format", "data", "request_id"])
|
||||
|
||||
# We just check that the output is a valid base64 string.
|
||||
# Raises an exception and fails the test if the string is corrupted.
|
||||
base64.b64decode(plugin_data["data"])
|
@ -7,6 +7,15 @@ import torch
|
||||
from vllm.plugins import load_general_plugins
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def use_v0_only(monkeypatch):
|
||||
"""
|
||||
Since this module is V0 only, set VLLM_USE_V1=0 for
|
||||
all tests in the module.
|
||||
"""
|
||||
monkeypatch.setenv('VLLM_USE_V1', '0')
|
||||
|
||||
|
||||
def test_platform_plugins():
|
||||
# simulate workload by running an example
|
||||
import runpy
|
||||
|
@ -501,6 +501,8 @@ class ModelConfig:
|
||||
logits_processors: Optional[list[Union[str, type[LogitsProcessor]]]] = None
|
||||
"""One or more logits processors' fully-qualified class names or class
|
||||
definitions"""
|
||||
io_processor_plugin: Optional[str] = None
|
||||
"""IOProcessor plugin name to load at model startup"""
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
|
@ -364,6 +364,7 @@ class EngineArgs:
|
||||
disable_mm_preprocessor_cache: bool = False # DEPRECATED
|
||||
mm_processor_cache_gb: float = MultiModalConfig.mm_processor_cache_gb
|
||||
mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode
|
||||
io_processor_plugin: Optional[str] = None
|
||||
skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling
|
||||
# LoRA fields
|
||||
enable_lora: bool = False
|
||||
@ -577,6 +578,8 @@ class EngineArgs:
|
||||
**model_kwargs["override_attention_dtype"])
|
||||
model_group.add_argument("--logits-processors",
|
||||
**model_kwargs["logits_processors"])
|
||||
model_group.add_argument("--io-processor-plugin",
|
||||
**model_kwargs["io_processor_plugin"])
|
||||
|
||||
# Model loading arguments
|
||||
load_kwargs = get_kwargs(LoadConfig)
|
||||
@ -993,6 +996,7 @@ class EngineArgs:
|
||||
model_impl=self.model_impl,
|
||||
override_attention_dtype=self.override_attention_dtype,
|
||||
logits_processors=self.logits_processors,
|
||||
io_processor_plugin=self.io_processor_plugin,
|
||||
)
|
||||
|
||||
def validate_tensorizer_args(self):
|
||||
|
@ -15,6 +15,7 @@ from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
|
||||
from vllm.plugins.io_processors.interface import IOProcessor
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
@ -267,6 +268,9 @@ class EngineClient(ABC):
|
||||
"""Get the appropriate tokenizer for the request"""
|
||||
...
|
||||
|
||||
async def get_io_processor(self) -> IOProcessor:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def is_tracing_enabled(self) -> bool:
|
||||
...
|
||||
|
@ -37,13 +37,15 @@ from vllm.entrypoints.score_utils import (ScoreContentPartParam,
|
||||
# yapf: enable
|
||||
from vllm.entrypoints.utils import (_validate_truncation_size,
|
||||
log_non_default_args)
|
||||
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
|
||||
from vllm.inputs import (DataPrompt, PromptType, SingletonPrompt, TextPrompt,
|
||||
TokensPrompt)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
|
||||
PoolingRequestOutput, RequestOutput,
|
||||
ScoringRequestOutput)
|
||||
from vllm.plugins.io_processors import get_io_processor
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import (BeamSearchParams, RequestOutputKind,
|
||||
SamplingParams)
|
||||
@ -284,6 +286,11 @@ class LLM:
|
||||
|
||||
self.supported_tasks = supported_tasks
|
||||
|
||||
# Load the Input/Output processor plugin if any
|
||||
io_processor_plugin = self.llm_engine.model_config.io_processor_plugin
|
||||
self.io_processor = get_io_processor(self.llm_engine.vllm_config,
|
||||
io_processor_plugin)
|
||||
|
||||
def get_tokenizer(
|
||||
self,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
@ -833,7 +840,7 @@ class LLM:
|
||||
|
||||
def encode(
|
||||
self,
|
||||
prompts: Union[PromptType, Sequence[PromptType]],
|
||||
prompts: Union[PromptType, Sequence[PromptType], DataPrompt],
|
||||
pooling_params: Optional[Union[PoolingParams,
|
||||
Sequence[PoolingParams]]] = None,
|
||||
*,
|
||||
@ -915,6 +922,22 @@ class LLM:
|
||||
if truncate_prompt_tokens is not None:
|
||||
param.truncate_prompt_tokens = truncate_prompt_tokens
|
||||
|
||||
io_processor_prompt = False
|
||||
if isinstance(prompts, dict) and "data" in prompts:
|
||||
io_processor_prompt = True
|
||||
if self.io_processor is None:
|
||||
raise ValueError(
|
||||
"No IOProcessor plugin installed. Please refer "
|
||||
"to the documentation and to the "
|
||||
"'prithvi_geospatial_mae_io_processor' "
|
||||
"offline inference example for more details.")
|
||||
|
||||
# Validate the request data is valid for the loaded plugin
|
||||
validated_prompt = self.io_processor.parse_request(prompts)
|
||||
|
||||
# obtain the actual model prompts from the pre-processor
|
||||
prompts = self.io_processor.pre_process(prompt=validated_prompt)
|
||||
|
||||
self._validate_and_add_requests(
|
||||
prompts=prompts,
|
||||
params=pooling_params,
|
||||
@ -923,8 +946,24 @@ class LLM:
|
||||
)
|
||||
|
||||
outputs = self._run_engine(use_tqdm=use_tqdm)
|
||||
return self.engine_class.validate_outputs(outputs,
|
||||
PoolingRequestOutput)
|
||||
|
||||
model_outputs = self.engine_class.validate_outputs(
|
||||
outputs, PoolingRequestOutput)
|
||||
|
||||
if io_processor_prompt:
|
||||
# get the post-processed model outputs
|
||||
assert self.io_processor is not None
|
||||
processed_outputs = self.io_processor.post_process(
|
||||
model_output=model_outputs)
|
||||
|
||||
return [
|
||||
PoolingRequestOutput[Any](request_id="",
|
||||
outputs=processed_outputs,
|
||||
prompt_token_ids=[],
|
||||
finished=True)
|
||||
]
|
||||
else:
|
||||
return model_outputs
|
||||
|
||||
def embed(
|
||||
self,
|
||||
|
@ -64,6 +64,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
EmbeddingRequest,
|
||||
EmbeddingResponse, ErrorInfo,
|
||||
ErrorResponse,
|
||||
IOProcessorResponse,
|
||||
LoadLoRAAdapterRequest,
|
||||
PoolingRequest, PoolingResponse,
|
||||
RerankRequest, RerankResponse,
|
||||
@ -795,7 +796,7 @@ async def create_pooling(request: PoolingRequest, raw_request: Request):
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(content=generator.model_dump(),
|
||||
status_code=generator.error.code)
|
||||
elif isinstance(generator, PoolingResponse):
|
||||
elif isinstance(generator, (PoolingResponse, IOProcessorResponse)):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
assert_never(generator)
|
||||
@ -1782,7 +1783,7 @@ async def init_app_state(
|
||||
) if "generate" in supported_tasks else None
|
||||
state.openai_serving_pooling = OpenAIServingPooling(
|
||||
engine_client,
|
||||
model_config,
|
||||
vllm_config,
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
chat_template=resolved_chat_template,
|
||||
|
@ -6,7 +6,8 @@
|
||||
import json
|
||||
import time
|
||||
from http import HTTPStatus
|
||||
from typing import Annotated, Any, ClassVar, Literal, Optional, Union
|
||||
from typing import (Annotated, Any, ClassVar, Generic, Literal, Optional,
|
||||
TypeVar, Union)
|
||||
|
||||
import regex as re
|
||||
import torch
|
||||
@ -1405,7 +1406,46 @@ EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest]
|
||||
|
||||
PoolingCompletionRequest = EmbeddingCompletionRequest
|
||||
PoolingChatRequest = EmbeddingChatRequest
|
||||
PoolingRequest = Union[PoolingCompletionRequest, PoolingChatRequest]
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class IOProcessorRequest(OpenAIBaseModel, Generic[T]):
|
||||
model: Optional[str] = None
|
||||
|
||||
priority: int = Field(default=0)
|
||||
"""
|
||||
The priority of the request (lower means earlier handling;
|
||||
default: 0). Any priority other than 0 will raise an error
|
||||
if the served model does not use priority scheduling.
|
||||
"""
|
||||
data: T
|
||||
"""
|
||||
When using IOProcessor plugins, the actual input is processed
|
||||
by the plugin itself. Hence, we use a generic type for the request data
|
||||
"""
|
||||
|
||||
def to_pooling_params(self):
|
||||
return PoolingParams(task="encode")
|
||||
|
||||
|
||||
class IOProcessorResponse(OpenAIBaseModel, Generic[T]):
|
||||
|
||||
request_id: Optional[str] = None
|
||||
"""
|
||||
The request_id associated with this response
|
||||
"""
|
||||
created_at: int = Field(default_factory=lambda: int(time.time()))
|
||||
|
||||
data: T
|
||||
"""
|
||||
When using IOProcessor plugins, the actual output is generated
|
||||
by the plugin itself. Hence, we use a generic type for the response data
|
||||
"""
|
||||
|
||||
|
||||
PoolingRequest = Union[PoolingCompletionRequest, PoolingChatRequest,
|
||||
IOProcessorRequest]
|
||||
|
||||
|
||||
class ScoreRequest(OpenAIBaseModel):
|
||||
|
@ -49,9 +49,11 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
EmbeddingCompletionRequest,
|
||||
EmbeddingRequest,
|
||||
EmbeddingResponse, ErrorInfo,
|
||||
ErrorResponse, PoolingResponse,
|
||||
RerankRequest, ResponsesRequest,
|
||||
ScoreRequest, ScoreResponse,
|
||||
ErrorResponse,
|
||||
IOProcessorRequest,
|
||||
PoolingResponse, RerankRequest,
|
||||
ResponsesRequest, ScoreRequest,
|
||||
ScoreResponse,
|
||||
TokenizeChatRequest,
|
||||
TokenizeCompletionRequest,
|
||||
TokenizeResponse,
|
||||
@ -89,7 +91,7 @@ ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest,
|
||||
TokenizeChatRequest]
|
||||
SpeechToTextRequest = Union[TranscriptionRequest, TranslationRequest]
|
||||
AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest, SpeechToTextRequest,
|
||||
ResponsesRequest]
|
||||
ResponsesRequest, IOProcessorRequest]
|
||||
|
||||
AnyResponse = Union[
|
||||
CompletionResponse,
|
||||
|
@ -4,7 +4,7 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from collections.abc import AsyncGenerator, Sequence
|
||||
from typing import Final, Literal, Optional, Union, cast
|
||||
|
||||
import jinja2
|
||||
@ -13,19 +13,25 @@ import torch
|
||||
from fastapi import Request
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
# yapf: disable
|
||||
from vllm.entrypoints.openai.protocol import (ErrorResponse,
|
||||
IOProcessorRequest,
|
||||
IOProcessorResponse,
|
||||
PoolingChatRequest,
|
||||
PoolingCompletionRequest,
|
||||
PoolingRequest, PoolingResponse,
|
||||
PoolingResponseData, UsageInfo)
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
# yapf: enable
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing, RequestPrompt
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.entrypoints.utils import _validate_truncation_size
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import PoolingOutput, PoolingRequestOutput
|
||||
from vllm.plugins.io_processors import get_io_processor
|
||||
from vllm.utils import merge_async_iterators
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -52,7 +58,7 @@ class OpenAIServingPooling(OpenAIServing):
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
vllm_config: VllmConfig,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: Optional[RequestLogger],
|
||||
@ -61,19 +67,21 @@ class OpenAIServingPooling(OpenAIServing):
|
||||
log_error_stack: bool = False,
|
||||
) -> None:
|
||||
super().__init__(engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
model_config=vllm_config.model_config,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
log_error_stack=log_error_stack)
|
||||
|
||||
self.chat_template = chat_template
|
||||
self.chat_template_content_format: Final = chat_template_content_format
|
||||
io_processor_plugin = self.model_config.io_processor_plugin
|
||||
self.io_processor = get_io_processor(vllm_config, io_processor_plugin)
|
||||
|
||||
async def create_pooling(
|
||||
self,
|
||||
request: PoolingRequest,
|
||||
raw_request: Optional[Request] = None,
|
||||
) -> Union[PoolingResponse, ErrorResponse]:
|
||||
) -> Union[PoolingResponse, IOProcessorResponse, ErrorResponse]:
|
||||
"""
|
||||
See https://platform.openai.com/docs/api-reference/embeddings/create
|
||||
for the API specification. This API mimics the OpenAI Embedding API.
|
||||
@ -82,20 +90,13 @@ class OpenAIServingPooling(OpenAIServing):
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
encoding_format = request.encoding_format
|
||||
if request.dimensions is not None:
|
||||
return self.create_error_response(
|
||||
"dimensions is currently not supported")
|
||||
|
||||
model_name = self._get_model_name(request.model)
|
||||
|
||||
request_id = f"pool-{self._base_request_id(raw_request)}"
|
||||
created_time = int(time.time())
|
||||
|
||||
truncate_prompt_tokens = request.truncate_prompt_tokens
|
||||
|
||||
is_io_processor_request = isinstance(request, IOProcessorRequest)
|
||||
try:
|
||||
truncate_prompt_tokens = _validate_truncation_size(
|
||||
self.max_model_len, truncate_prompt_tokens)
|
||||
lora_request = self._maybe_get_adapters(request)
|
||||
|
||||
if self.model_config.skip_tokenizer_init:
|
||||
@ -104,7 +105,32 @@ class OpenAIServingPooling(OpenAIServing):
|
||||
tokenizer = await self.engine_client.get_tokenizer(lora_request
|
||||
)
|
||||
|
||||
if isinstance(request, PoolingChatRequest):
|
||||
if getattr(request, "dimensions", None) is not None:
|
||||
return self.create_error_response(
|
||||
"dimensions is currently not supported")
|
||||
|
||||
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens",
|
||||
None)
|
||||
truncate_prompt_tokens = _validate_truncation_size(
|
||||
self.max_model_len, truncate_prompt_tokens)
|
||||
|
||||
if is_io_processor_request:
|
||||
if self.io_processor is None:
|
||||
raise ValueError(
|
||||
"No IOProcessor plugin installed. Please refer "
|
||||
"to the documentation and to the "
|
||||
"'prithvi_geospatial_mae_io_processor' "
|
||||
"offline inference example for more details.")
|
||||
|
||||
validated_prompt = self.io_processor.parse_request(request)
|
||||
|
||||
engine_prompts = await self.io_processor.pre_process_async(
|
||||
prompt=validated_prompt, request_id=request_id)
|
||||
request_prompts: Sequence[RequestPrompt] = [
|
||||
""
|
||||
] * len(engine_prompts)
|
||||
|
||||
elif isinstance(request, PoolingChatRequest):
|
||||
(
|
||||
_,
|
||||
request_prompts,
|
||||
@ -122,7 +148,7 @@ class OpenAIServingPooling(OpenAIServing):
|
||||
continue_final_message=False,
|
||||
add_special_tokens=request.add_special_tokens,
|
||||
)
|
||||
else:
|
||||
elif isinstance(request, PoolingCompletionRequest):
|
||||
(request_prompts,
|
||||
engine_prompts) = await self._preprocess_completion(
|
||||
request,
|
||||
@ -130,6 +156,9 @@ class OpenAIServingPooling(OpenAIServing):
|
||||
request.input,
|
||||
add_special_tokens=request.add_special_tokens,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported request of type {type(request)}")
|
||||
except (ValueError, TypeError, jinja2.TemplateError) as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
@ -171,6 +200,16 @@ class OpenAIServingPooling(OpenAIServing):
|
||||
|
||||
result_generator = merge_async_iterators(*generators)
|
||||
|
||||
if is_io_processor_request:
|
||||
assert self.io_processor is not None
|
||||
output = await self.io_processor.post_process_async(
|
||||
model_output=result_generator,
|
||||
request_id=request_id,
|
||||
)
|
||||
return self.io_processor.output_to_response(output)
|
||||
|
||||
assert isinstance(request,
|
||||
(PoolingCompletionRequest, PoolingChatRequest))
|
||||
num_prompts = len(engine_prompts)
|
||||
|
||||
# Non-streaming response
|
||||
@ -190,7 +229,7 @@ class OpenAIServingPooling(OpenAIServing):
|
||||
request_id,
|
||||
created_time,
|
||||
model_name,
|
||||
encoding_format,
|
||||
request.encoding_format,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
return self.create_error_response("Client disconnected")
|
||||
|
@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from .data import (DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt,
|
||||
from .data import (DataPrompt, DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt,
|
||||
EncoderDecoderInputs, ExplicitEncoderDecoderPrompt,
|
||||
ProcessorInputs, PromptType, SingletonInputs,
|
||||
SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt,
|
||||
@ -18,6 +18,7 @@ target model.
|
||||
"""
|
||||
|
||||
__all__ = [
|
||||
"DataPrompt",
|
||||
"TextPrompt",
|
||||
"TokensPrompt",
|
||||
"PromptType",
|
||||
|
@ -95,6 +95,16 @@ class EmbedsPrompt(TypedDict):
|
||||
"""
|
||||
|
||||
|
||||
class DataPrompt(TypedDict):
|
||||
"""Represents generic inputs handled by IO processor plugins."""
|
||||
|
||||
data: Any
|
||||
"""The input data"""
|
||||
|
||||
data_format: str
|
||||
"""The input data format"""
|
||||
|
||||
|
||||
SingletonPrompt = Union[str, TextPrompt, TokensPrompt, EmbedsPrompt]
|
||||
"""
|
||||
Set of possible schemas for a single prompt:
|
||||
|
vllm/plugins/io_processors/__init__.py (new file, 68 lines)
@@ -0,0 +1,68 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from __future__ import annotations

import logging
from typing import Optional

from vllm.config import VllmConfig
from vllm.plugins import load_plugins_by_group
from vllm.plugins.io_processors.interface import IOProcessor
from vllm.utils import resolve_obj_by_qualname

logger = logging.getLogger(__name__)


def get_io_processor(
        vllm_config: VllmConfig,
        plugin_from_init: Optional[str] = None) -> IOProcessor | None:
    # Input/Output processors are loaded as plugins under the
    # 'vllm.io_processor_plugins' group. Similar to platform
    # plugins, these plugins register a function that returns the class
    # name for the processor to install.

    if plugin_from_init:
        model_plugin = plugin_from_init
    else:
        # A plugin can be specified via the model config.
        # Retrieve the model specific plugin if available.
        # This is using a custom field in the hf_config for the model.
        hf_config = vllm_config.model_config.hf_config.to_dict()
        config_plugin = hf_config.get("io_processor_plugin")
        model_plugin = config_plugin

    if model_plugin is None:
        logger.info("No IOProcessor plugins requested by the model")
        return None

    logger.debug("IOProcessor plugin to be loaded %s", model_plugin)

    # Load all installed plugins in the group
    multimodal_data_processor_plugins = \
        load_plugins_by_group('vllm.io_processor_plugins')

    loadable_plugins = {}
    for name, func in multimodal_data_processor_plugins.items():
        try:
            assert callable(func)
            processor_cls_qualname = func()
            if processor_cls_qualname is not None:
                loadable_plugins[name] = processor_cls_qualname
        except Exception:
            logger.warning("Failed to load plugin %s.", name, exc_info=True)

    num_available_plugins = len(loadable_plugins.keys())
    if num_available_plugins == 0:
        raise ValueError("No IOProcessor plugins installed"
                         f" but one is required ({model_plugin}).")

    if model_plugin not in loadable_plugins:
        raise ValueError(
            f"The model requires the '{model_plugin}' IO Processor plugin "
            "but it is not installed. "
            f"Available plugins: {list(loadable_plugins.keys())}")

    activated_plugin_cls = loadable_plugins[model_plugin]

    return resolve_obj_by_qualname(activated_plugin_cls)(vllm_config)
vllm/plugins/io_processors/interface.py (new file, 62 lines)
@@ -0,0 +1,62 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator, Sequence
from typing import Any, Generic, Optional, TypeVar, Union

from vllm.config import VllmConfig
from vllm.entrypoints.openai.protocol import IOProcessorResponse
from vllm.inputs.data import PromptType
from vllm.outputs import PoolingRequestOutput

IOProcessorInput = TypeVar('IOProcessorInput')
IOProcessorOutput = TypeVar('IOProcessorOutput')


class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]):

    def __init__(self, vllm_config: VllmConfig):
        self.vllm_config = vllm_config

    @abstractmethod
    def pre_process(
        self,
        prompt: IOProcessorInput,
        request_id: Optional[str] = None,
        **kwargs,
    ) -> Union[PromptType, Sequence[PromptType]]:
        raise NotImplementedError

    async def pre_process_async(
        self,
        prompt: IOProcessorInput,
        request_id: Optional[str] = None,
        **kwargs,
    ) -> Union[PromptType, Sequence[PromptType]]:
        return self.pre_process(prompt, request_id, **kwargs)

    @abstractmethod
    def post_process(self,
                     model_output: Sequence[PoolingRequestOutput],
                     request_id: Optional[str] = None,
                     **kwargs) -> IOProcessorOutput:
        raise NotImplementedError

    async def post_process_async(
        self,
        model_output: AsyncGenerator[tuple[int, PoolingRequestOutput]],
        request_id: Optional[str] = None,
        **kwargs,
    ) -> IOProcessorOutput:
        collected_output = [item async for i, item in model_output]
        return self.post_process(collected_output, request_id, **kwargs)

    @abstractmethod
    def parse_request(self, request: Any) -> IOProcessorInput:
        raise NotImplementedError

    @abstractmethod
    def output_to_response(
            self, plugin_output: IOProcessorOutput) -> IOProcessorResponse:
        raise NotImplementedError