[VLM] Remove image_input_type from VLM config (#5852)

Signed-off-by: Xiaowei Jiang <xwjiang2010@gmail.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
xwjiang2010
2024-07-02 00:57:09 -07:00
committed by GitHub
parent 2c37540aa6
commit 98d6682cd1
35 changed files with 329 additions and 751 deletions

View File

@ -8,10 +8,6 @@ set -o pipefail
# aws s3 sync s3://air-example-data-2/vllm_opensource_llava/ images/
mkdir -p images
cd images
wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/stop_sign_pixel_values.pt
wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/stop_sign_image_features.pt
wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/cherry_blossom_pixel_values.pt
wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/cherry_blossom_image_features.pt
wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/stop_sign.jpg
wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/cherry_blossom.jpg

View File

@ -1,13 +1,5 @@
sphinx == 6.2.1
sphinx-book-theme == 1.0.1
sphinx-copybutton == 0.5.2
myst-parser == 2.0.0
sphinx==6.2.1
sphinx-book-theme==1.0.1
sphinx-copybutton==0.5.2
myst-parser==2.0.0
sphinx-argparse
# packages to install to build the documentation
pydantic
-f https://download.pytorch.org/whl/cpu
torch
py-cpuinfo
transformers
openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args

View File

@ -9,8 +9,10 @@ vLLM provides experimental support for multi-modal models through the :mod:`vllm
which allows you to pass in multi-modal input alongside text and token prompts.
By default, vLLM models do not support multi-modal inputs. To enable multi-modal support for a model,
you must decorate the model class with :meth:`MULTIMODAL_REGISTRY.register_dummy_data <MultiModalRegistry.register_dummy_data>`,
as well as :meth:`MULTIMODAL_REGISTRY.register_input <MultiModalRegistry.register_input>` for each modality type to support.
you must decorate the model class with :meth:`InputRegistry.register_dummy_data <vllm.inputs.registry.InputRegistry.register_dummy_data>`,
as well as :meth:`MULTIMODAL_REGISTRY.register_input_mapper <MultiModalRegistry.register_input_mapper>` for each modality type to support.
# TODO: Add more instructions on how to do that once embeddings is in.
Module Contents
+++++++++++++++
@ -29,7 +31,7 @@ Registry
Base Classes
------------
.. autoclass:: vllm.multimodal.MultiModalData
.. autoclass:: vllm.multimodal.MultiModalDataDict
:members:
:show-inheritance:

View File

@ -36,7 +36,6 @@ To initialize a VLM, the aforementioned arguments must be passed to the ``LLM``
llm = LLM(
model="llava-hf/llava-1.5-7b-hf",
image_input_type="pixel_values",
image_token_id=32000,
image_input_shape="1,3,336,336",
image_feature_size=576,
@ -49,7 +48,12 @@ To initialize a VLM, the aforementioned arguments must be passed to the ``LLM``
To pass an image to the model, note the following in :class:`vllm.inputs.PromptStrictInputs`:
* ``prompt``: The prompt should have a number of ``<image>`` tokens equal to ``image_feature_size``.
* ``multi_modal_data``: This should be an instance of :class:`~vllm.multimodal.image.ImagePixelData` or :class:`~vllm.multimodal.image.ImageFeatureData`.
* ``multi_modal_data``: This is a dictionary that follows the schema defined in :class:`vllm.multimodal.MultiModalDataDict`.
.. note::
``multi_modal_data`` can accept keys and values beyond the builtin ones, as long as a customized plugin is registered through
:class:`vllm.multimodal.MULTIMODAL_REGISTRY`.
.. code-block:: python
@ -61,7 +65,7 @@ To pass an image to the model, note the following in :class:`vllm.inputs.PromptS
outputs = llm.generate({
"prompt": prompt,
"multi_modal_data": ImagePixelData(image),
"multi_modal_data": {"image": image},
})
for o in outputs:
@ -93,7 +97,6 @@ Below is an example on how to launch the same ``llava-hf/llava-1.5-7b-hf`` with
python -m vllm.entrypoints.openai.api_server \
--model llava-hf/llava-1.5-7b-hf \
--image-input-type pixel_values \
--image-token-id 32000 \
--image-input-shape 1,3,336,336 \
--image-feature-size 576 \

View File

@ -1,38 +1,32 @@
import argparse
import os
import subprocess
import torch
from PIL import Image
from vllm import LLM
from vllm.multimodal.image import ImageFeatureData, ImagePixelData
# The assets are located at `s3://air-example-data-2/vllm_opensource_llava/`.
# You can use `.buildkite/download-images.sh` to download them
def run_llava_pixel_values(*, disable_image_processor: bool = False):
def run_llava():
llm = LLM(
model="llava-hf/llava-1.5-7b-hf",
image_input_type="pixel_values",
image_token_id=32000,
image_input_shape="1,3,336,336",
image_feature_size=576,
disable_image_processor=disable_image_processor,
)
prompt = "<image>" * 576 + (
"\nUSER: What is the content of this image?\nASSISTANT:")
if disable_image_processor:
image = torch.load("images/stop_sign_pixel_values.pt")
else:
image = Image.open("images/stop_sign.jpg")
image = Image.open("images/stop_sign.jpg")
outputs = llm.generate({
"prompt": prompt,
"multi_modal_data": ImagePixelData(image),
"multi_modal_data": {
"image": image
},
})
for o in outputs:
@ -40,45 +34,11 @@ def run_llava_pixel_values(*, disable_image_processor: bool = False):
print(generated_text)
def run_llava_image_features():
llm = LLM(
model="llava-hf/llava-1.5-7b-hf",
image_input_type="image_features",
image_token_id=32000,
image_input_shape="1,576,1024",
image_feature_size=576,
)
prompt = "<image>" * 576 + (
"\nUSER: What is the content of this image?\nASSISTANT:")
image: torch.Tensor = torch.load("images/stop_sign_image_features.pt")
outputs = llm.generate({
"prompt": prompt,
"multi_modal_data": ImageFeatureData(image),
})
for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)
def main(args):
if args.type == "pixel_values":
run_llava_pixel_values()
else:
run_llava_image_features()
def main():
run_llava()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Demo on Llava")
parser.add_argument("--type",
type=str,
choices=["pixel_values", "image_features"],
default="pixel_values",
help="image input type")
args = parser.parse_args()
# Download from s3
s3_bucket_path = "s3://air-example-data-2/vllm_opensource_llava/"
local_directory = "images"
@ -95,4 +55,4 @@ if __name__ == "__main__":
local_directory,
"--no-sign-request",
])
main(args)
main()

View File

@ -4,35 +4,44 @@ import requests
from PIL import Image
from vllm import LLM, SamplingParams
from vllm.multimodal.image import ImagePixelData
# Dynamic image input is currently not supported and therefore
# a fixed image input shape and its corresponding feature size is required.
# See https://github.com/vllm-project/vllm/pull/4199 for the complete
# configuration matrix.
llm = LLM(
model="llava-hf/llava-v1.6-mistral-7b-hf",
image_input_type="pixel_values",
image_token_id=32000,
image_input_shape="1,3,336,336",
image_feature_size=1176,
)
prompt = "[INST] " + "<image>" * 1176 + "\nWhat is shown in this image? [/INST]"
url = "https://h2o-release.s3.amazonaws.com/h2ogpt/bigben.jpg"
image = Image.open(BytesIO(requests.get(url).content))
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=100)
def run_llava_next():
llm = LLM(
model="llava-hf/llava-v1.6-mistral-7b-hf",
image_token_id=32000,
image_input_shape="1,3,336,336",
image_feature_size=1176,
)
outputs = llm.generate(
{
"prompt": prompt,
"multi_modal_data": ImagePixelData(image),
},
sampling_params=sampling_params)
prompt = "[INST] " + "<image>" * 1176 + (
"\nWhat is shown in this image? [/INST]")
url = "https://h2o-release.s3.amazonaws.com/h2ogpt/bigben.jpg"
image = Image.open(BytesIO(requests.get(url).content))
sampling_params = SamplingParams(temperature=0.8,
top_p=0.95,
max_tokens=100)
generated_text = ""
for o in outputs:
generated_text += o.outputs[0].text
outputs = llm.generate(
{
"prompt": prompt,
"multi_modal_data": {
"image": image
}
},
sampling_params=sampling_params)
print(f"LLM output:{generated_text}")
generated_text = ""
for o in outputs:
generated_text += o.outputs[0].text
print(f"LLM output:{generated_text}")
if __name__ == "__main__":
run_llava_next()

View File

@ -3,7 +3,6 @@
Launch the vLLM server with the following command:
python -m vllm.entrypoints.openai.api_server \
--model llava-hf/llava-1.5-7b-hf \
--image-input-type pixel_values \
--image-token-id 32000 \
--image-input-shape 1,3,336,336 \
--image-feature-size 576 \

View File

@ -4,7 +4,6 @@ import subprocess
from PIL import Image
from vllm import LLM, SamplingParams
from vllm.multimodal.image import ImagePixelData
def run_phi3v():
@ -17,7 +16,6 @@ def run_phi3v():
llm = LLM(
model=model_path,
trust_remote_code=True,
image_input_type="pixel_values",
image_token_id=32044,
image_input_shape="1,3,1008,1344",
image_feature_size=1921,
@ -35,7 +33,9 @@ def run_phi3v():
outputs = llm.generate(
{
"prompt": prompt,
"multi_modal_data": ImagePixelData(image),
"multi_modal_data": {
"image": image
},
},
sampling_params=sampling_params)
for o in outputs:

View File

@ -17,20 +17,18 @@ from transformers import (AutoModelForCausalLM, AutoModelForVision2Seq,
AutoTokenizer, BatchEncoding)
from vllm import LLM, SamplingParams
from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
from vllm.config import TokenizerPoolConfig
from vllm.distributed import (destroy_distributed_environment,
destroy_model_parallel)
from vllm.inputs import TextPrompt
from vllm.logger import init_logger
if TYPE_CHECKING:
from vllm.multimodal import MultiModalData
else:
# it will call torch.cuda.device_count()
MultiModalData = None
from vllm.sequence import SampleLogprobs
from vllm.utils import cuda_device_count_stateless, is_cpu
if TYPE_CHECKING:
# it will call torch.cuda.device_count()
from vllm.multimodal import MultiModalDataDict
logger = init_logger(__name__)
_TEST_DIR = os.path.dirname(__file__)
@ -51,14 +49,6 @@ def _read_prompts(filename: str) -> List[str]:
class ImageAsset:
name: Literal["stop_sign", "cherry_blossom"]
@cached_property
def pixel_values(self) -> torch.Tensor:
return torch.load(_IMAGE_DIR / f"{self.name}_pixel_values.pt")
@cached_property
def image_features(self) -> torch.Tensor:
return torch.load(_IMAGE_DIR / f"{self.name}_image_features.pt")
@cached_property
def pil_image(self) -> Image.Image:
return Image.open(_IMAGE_DIR / f"{self.name}.jpg")
@ -66,20 +56,8 @@ class ImageAsset:
def for_hf(self) -> Image.Image:
return self.pil_image
def for_vllm(self, vision_config: VisionLanguageConfig) -> MultiModalData:
# don't put this import at the top level
# it will call torch.cuda.device_count()
from vllm.multimodal.image import ImageFeatureData # noqa: F401
from vllm.multimodal.image import ImagePixelData
image_input_type = vision_config.image_input_type
ImageInputType = VisionLanguageConfig.ImageInputType
if image_input_type == ImageInputType.IMAGE_FEATURES:
return ImageFeatureData(self.image_features)
if image_input_type == ImageInputType.PIXEL_VALUES:
return ImagePixelData(self.pil_image)
raise NotImplementedError
def for_vllm(self) -> Dict[str, Any]:
return {"image": self.pil_image}
class _ImageAssetPrompts(TypedDict):
@ -453,7 +431,7 @@ class VllmRunner:
self,
prompts: List[str],
sampling_params: SamplingParams,
images: Optional[List[MultiModalData]] = None,
images: Optional[List["MultiModalDataDict"]] = None,
) -> List[Tuple[List[List[int]], List[str]]]:
if images is not None:
assert len(prompts) == len(images)
@ -502,7 +480,7 @@ class VllmRunner:
self,
prompts: List[str],
max_tokens: int,
images: Optional[List[MultiModalData]] = None,
images: Optional[List["MultiModalDataDict"]] = None,
) -> List[Tuple[List[int], str]]:
greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
outputs = self.generate(prompts, greedy_params, images=images)

View File

@ -39,8 +39,6 @@ def server():
"--max-model-len",
"4096",
"--enforce-eager",
"--image-input-type",
"pixel_values",
"--image-token-id",
"32000",
"--image-input-shape",

View File

@ -25,17 +25,11 @@ def iter_llava_configs(model_name: str):
}
for (h, w), f in image_hw_to_feature_size.items():
for input_type, input_shape in [
(VisionLanguageConfig.ImageInputType.PIXEL_VALUES, (1, 3, h, w)),
(VisionLanguageConfig.ImageInputType.IMAGE_FEATURES, (1, f, 1024)),
]:
yield (model_name,
VisionLanguageConfig(image_input_type=input_type,
image_feature_size=f,
image_token_id=32000,
image_input_shape=input_shape,
image_processor=model_name,
image_processor_revision=None))
input_shape = (1, 3, h, w)
yield (model_name,
VisionLanguageConfig(image_feature_size=f,
image_token_id=32000,
image_input_shape=input_shape))
model_and_vl_config = [
@ -81,8 +75,8 @@ def run_test(
All the image fixtures for the test is under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalData objects and corresponding
vision language config as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding vision language config as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
@ -104,7 +98,7 @@ def run_test(
# NOTE: `asset.for_vllm` will call `torch.cuda.device_count()`
# we must put it inside the vllm_runner context manager
# i.e. after creating vLLM instance.
vllm_images = [asset.for_vllm(vlm_config) for asset in image_assets]
vllm_images = [asset.for_vllm() for asset in image_assets]
vllm_image_prompts = [
p.replace("<image>", "<image>" * vlm_config.image_feature_size)

View File

@ -33,16 +33,13 @@ def iter_llava_next_configs(model_name: str):
}
for (h, w), f in image_hw_to_feature_size.items():
for input_type, input_shape in [
(VisionLanguageConfig.ImageInputType.PIXEL_VALUES, (1, 3, h, w)),
]:
yield (model_name,
VisionLanguageConfig(image_input_type=input_type,
image_feature_size=f,
image_token_id=32000,
image_input_shape=input_shape,
image_processor=model_name,
image_processor_revision=None))
input_shape = (1, 3, h, w)
yield (model_name,
VisionLanguageConfig(
image_feature_size=f,
image_token_id=32000,
image_input_shape=input_shape,
))
model_and_vl_config = [
@ -85,14 +82,14 @@ def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
All the image fixtures for the test is under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalData objects and corresponding
vision language config as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding vision language config as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
model_id, vlm_config = model_and_config
hf_images = [asset.for_hf() for asset in image_assets]
vllm_images = [asset.for_vllm(vlm_config) for asset in image_assets]
vllm_images = [asset.for_vllm() for asset in image_assets]
with hf_runner(model_id, dtype=dtype, is_vision_model=True) as hf_model:
hf_outputs = hf_model.generate_greedy(HF_IMAGE_PROMPTS,

View File

@ -27,16 +27,11 @@ def iter_phi3v_configs(model_name: str):
}
for (h, w), f in image_hw_to_feature_size.items():
for input_type, input_shape in [
(VisionLanguageConfig.ImageInputType.PIXEL_VALUES, (1, 3, h, w)),
]:
yield (model_name,
VisionLanguageConfig(image_input_type=input_type,
image_feature_size=f,
image_token_id=32044,
image_input_shape=input_shape,
image_processor=model_name,
image_processor_revision=None))
input_shape = (1, 3, h, w)
yield (model_name,
VisionLanguageConfig(image_feature_size=f,
image_token_id=32044,
image_input_shape=input_shape))
model_and_vl_config = [
@ -89,8 +84,8 @@ def run_test(
All the image fixtures for the test is under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalData objects and corresponding
vision language config as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding vision language config as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
@ -113,7 +108,7 @@ def run_test(
# we must put it inside the vllm_runner context manager
# i.e. after creating vLLM instance.
vllm_images = [asset.for_vllm(vlm_config) for asset in image_assets]
vllm_images = [asset.for_vllm() for asset in image_assets]
vllm_image_prompts = [
p.replace("<|image_1|>",

View File

@ -2,9 +2,8 @@ import numpy as np
import pytest
from transformers import CLIPImageProcessor, LlavaNextImageProcessor
from vllm.config import ModelConfig, VisionLanguageConfig
from vllm.config import ModelConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import ImagePixelData
from ..conftest import _STR_DTYPE_TO_TORCH_DTYPE
@ -12,7 +11,6 @@ from ..conftest import _STR_DTYPE_TO_TORCH_DTYPE
@pytest.mark.parametrize("dtype", ["half", "float"])
def test_clip_image_processor(image_assets, dtype):
MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
IMAGE_HEIGHT = IMAGE_WIDTH = 560
hf_processor = CLIPImageProcessor.from_pretrained(MODEL_NAME)
assert isinstance(hf_processor, CLIPImageProcessor)
@ -25,14 +23,6 @@ def test_clip_image_processor(image_assets, dtype):
seed=0,
dtype=dtype,
revision=None,
multimodal_config=VisionLanguageConfig(
image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
image_token_id=32000,
image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH),
image_feature_size=576,
image_processor=MODEL_NAME,
image_processor_revision=None,
),
)
for asset in image_assets:
@ -42,7 +32,7 @@ def test_clip_image_processor(image_assets, dtype):
).to(dtype=_STR_DTYPE_TO_TORCH_DTYPE[dtype])
vllm_result = MULTIMODAL_REGISTRY.map_input(
model_config,
ImagePixelData(asset.pil_image),
{"image": asset.pil_image},
)
assert hf_result.keys() == vllm_result.keys()
@ -60,7 +50,6 @@ def test_clip_image_processor(image_assets, dtype):
@pytest.mark.parametrize("dtype", ["half", "float"])
def test_llava_next_image_processor(image_assets, dtype):
MODEL_NAME = "llava-hf/llava-v1.6-34b-hf"
IMAGE_HEIGHT = IMAGE_WIDTH = 560
hf_processor = LlavaNextImageProcessor.from_pretrained(MODEL_NAME)
assert isinstance(hf_processor, LlavaNextImageProcessor)
@ -73,14 +62,6 @@ def test_llava_next_image_processor(image_assets, dtype):
seed=0,
dtype=dtype,
revision=None,
multimodal_config=VisionLanguageConfig(
image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
image_token_id=64000,
image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH),
image_feature_size=2928,
image_processor=MODEL_NAME,
image_processor_revision=None,
),
)
for asset in image_assets:
@ -90,7 +71,7 @@ def test_llava_next_image_processor(image_assets, dtype):
).to(dtype=_STR_DTYPE_TO_TORCH_DTYPE[dtype])
vllm_result = MULTIMODAL_REGISTRY.map_input(
model_config,
ImagePixelData(asset.pil_image),
{"image": asset.pil_image},
)
assert hf_result.keys() == vllm_result.keys()
@ -107,7 +88,6 @@ def test_llava_next_image_processor(image_assets, dtype):
@pytest.mark.parametrize("dtype", ["float"])
def test_image_pixel_types(image_assets, dtype):
MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
IMAGE_HEIGHT = IMAGE_WIDTH = 560
model_config = ModelConfig(
model=MODEL_NAME,
@ -117,23 +97,15 @@ def test_image_pixel_types(image_assets, dtype):
seed=0,
dtype=dtype,
revision=None,
multimodal_config=VisionLanguageConfig(
image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
image_token_id=32000,
image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH),
image_feature_size=576,
image_processor=MODEL_NAME,
image_processor_revision=None,
))
)
for asset in image_assets:
image_result = MULTIMODAL_REGISTRY.map_input(
model_config,
ImagePixelData(asset.pil_image),
{"image": asset.pil_image},
)
tensor_result = MULTIMODAL_REGISTRY.map_input(
model_config,
ImagePixelData(asset.pixel_values),
{"image": asset.pil_image},
)
assert image_result.keys() == tensor_result.keys()

View File

@ -11,7 +11,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.lora.request import LoRARequest
from vllm.model_executor.utils import set_random_seed
from vllm.multimodal import MultiModalData
from vllm.multimodal import MultiModalDataDict
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import Logprob
@ -91,7 +91,7 @@ class AsyncLLM:
prompt_token_ids: Optional[List[List[int]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
multi_modal_data: Optional[MultiModalDataDict] = None,
) -> List[RequestOutput]:
if prompts is None:

View File

@ -1,20 +0,0 @@
import pytest
from transformers.image_processing_utils import BaseImageProcessor
from vllm.transformers_utils.image_processor import get_image_processor
IMAGE_PROCESSOR_NAMES = [
"llava-hf/llava-1.5-7b-hf",
"llava-hf/llava-v1.6-34b-hf",
]
@pytest.mark.parametrize("processor_name", IMAGE_PROCESSOR_NAMES)
def test_image_processor_revision(processor_name: str):
# Assume that "main" branch always exists
image_processor = get_image_processor(processor_name, revision="main")
assert isinstance(image_processor, BaseImageProcessor)
# Assume that "never" branch always does not exist
with pytest.raises(OSError, match='not a valid git identifier'):
get_image_processor(processor_name, revision="never")

View File

@ -1250,28 +1250,11 @@ class LoRAConfig:
raise ValueError("LoRA is not supported with chunked prefill yet.")
# TODO: To be replaced by MultiModalConfig.
@dataclass
class VisionLanguageConfig:
"""Configs the input data format and how models should run for
vision language models."""
class ImageInputType(enum.Enum):
"""Image input type into the vision language model.
An image roughly goes through the following transformation:
Raw image --> pixel values --> image features --> image embeddings.
The difference between different image input types is where the
image encoder (pixel values --> image features) is run.
Different image input types also correspond to different tensor shapes.
For example, for Llava, PIXEL_VALUES: (1, 3, 336, 336).
IMAGE_FEATURES: (1, 576, 1024).
"""
PIXEL_VALUES = enum.auto()
IMAGE_FEATURES = enum.auto()
image_input_type: ImageInputType
# The input id corresponding to image token.
image_token_id: int
# Used for running `run_prefill_max_token`.
@ -1279,19 +1262,6 @@ class VisionLanguageConfig:
# worst case scenario (biggest supported resolution).
image_input_shape: tuple
image_feature_size: int
# The image processor to load from HuggingFace
image_processor: Optional[str]
image_processor_revision: Optional[str]
@classmethod
def get_image_input_enum_type(cls, value: str) -> ImageInputType:
"""Get the image input type from a string."""
try:
return cls.ImageInputType[value.upper()]
except KeyError as e:
raise ValueError(f"{value} is not a valid choice. "
f"Expecting to choose from "
f"{[x.name for x in cls.ImageInputType]}.") from e
#TODO(ywang96): make this a cached property once we refactor the
# VisionLanguageConfig class.
@ -1318,8 +1288,6 @@ class VisionLanguageConfig:
else:
result[f.name] = value
result["disable_image_processor"] = self.image_processor is None
return result

View File

@ -1,7 +1,6 @@
import argparse
import dataclasses
import json
import warnings
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
@ -80,13 +79,9 @@ class EngineArgs:
preemption_mode: Optional[str] = None
# Related to Vision-language models such as llava
image_input_type: Optional[str] = None
image_token_id: Optional[int] = None
image_input_shape: Optional[str] = None
image_feature_size: Optional[int] = None
image_processor: Optional[str] = None
image_processor_revision: Optional[str] = None
disable_image_processor: bool = False
scheduler_delay_factor: float = 0.0
enable_chunked_prefill: bool = False
@ -114,14 +109,6 @@ class EngineArgs:
@staticmethod
def add_cli_args_for_vlm(
parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parser.add_argument('--image-input-type',
type=nullable_str,
default=None,
choices=[
t.name.lower()
for t in VisionLanguageConfig.ImageInputType
],
help=('The image input type passed into vLLM.'))
parser.add_argument('--image-token-id',
type=int,
default=None,
@ -137,24 +124,6 @@ class EngineArgs:
type=int,
default=None,
help=('The image feature size along the context dimension.'))
parser.add_argument(
'--image-processor',
type=str,
default=EngineArgs.image_processor,
help='Name or path of the huggingface image processor to use. '
'If unspecified, model name or path will be used.')
parser.add_argument(
'--image-processor-revision',
type=str,
default=None,
help='Revision of the huggingface image processor version to use. '
'It can be a branch name, a tag name, or a commit id. '
'If unspecified, will use the default version.')
parser.add_argument(
'--disable-image-processor',
action='store_true',
help='Disables the use of image processor, even if one is defined '
'for the model on huggingface.')
return parser
@ -679,33 +648,16 @@ class EngineArgs:
raise ValueError(
"BitsAndBytes load format and QLoRA adapter only support "
f"'bitsandbytes' quantization, but got {self.quantization}")
if self.image_input_type:
if (not self.image_token_id or not self.image_input_shape
or not self.image_feature_size):
if self.image_token_id is not None:
if (not self.image_input_shape or not self.image_feature_size):
raise ValueError(
'Specify `image_token_id`, `image_input_shape` and '
'`image_feature_size` together with `image_input_type`.')
if self.image_processor is None:
self.image_processor = self.model
if self.disable_image_processor:
if self.image_processor != self.model:
warnings.warn(
"You've specified an image processor "
f"({self.image_processor}) but also disabled "
"it via `--disable-image-processor`.",
stacklevel=2)
self.image_processor = None
'Specify `image_input_shape` and '
'`image_feature_size` together with `image_token_id`.')
vision_language_config = VisionLanguageConfig(
image_input_type=VisionLanguageConfig.
get_image_input_enum_type(self.image_input_type),
image_token_id=self.image_token_id,
image_input_shape=str_to_int_tuple(self.image_input_shape),
image_feature_size=self.image_feature_size,
image_processor=self.image_processor,
image_processor_revision=self.image_processor_revision,
)
else:
vision_language_config = None

View File

@ -213,15 +213,6 @@ if __name__ == "__main__":
engine_args = AsyncEngineArgs.from_cli_args(args)
# Enforce pixel values as image input type for vision language models
# when serving with API server
if engine_args.image_input_type is not None and \
engine_args.image_input_type.upper() != "PIXEL_VALUES":
raise ValueError(
f"Invalid image_input_type: {engine_args.image_input_type}. "
"Only --image-input-type 'pixel_values' is supported for serving "
"vision language models with the vLLM API server.")
engine = AsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.OPENAI_API_SERVER)

View File

@ -26,7 +26,7 @@ from vllm.inputs import PromptInputs
from vllm.logger import init_logger
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
from vllm.multimodal.image import ImagePixelData
from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.utils import (async_get_and_parse_image,
get_full_image_text_prompt)
from vllm.outputs import RequestOutput
@ -47,7 +47,7 @@ class ConversationMessage(TypedDict):
@dataclass(frozen=True)
class ChatMessageParseResult:
messages: List[ConversationMessage]
image_futures: List[Awaitable[ImagePixelData]] = field(
mm_futures: List[Awaitable[MultiModalDataDict]] = field(
default_factory=list)
@ -103,7 +103,7 @@ class OpenAIServingChat(OpenAIServing):
parts: Iterable[ChatCompletionContentPartParam],
) -> ChatMessageParseResult:
texts: List[str] = []
image_futures: List[Awaitable[ImagePixelData]] = []
mm_futures: List[Awaitable[MultiModalDataDict]] = []
vlm_config: Optional[VisionLanguageConfig] = getattr(
self.engine.engine, "vision_language_config", None)
@ -113,39 +113,34 @@ class OpenAIServingChat(OpenAIServing):
part_type = part["type"]
if part_type == "text":
text = cast(ChatCompletionContentPartTextParam, part)["text"]
texts.append(text)
elif part_type == "image_url":
if vlm_config is None:
raise ValueError(
"'image_url' input is not supported as the loaded "
"model is not multimodal.")
assert self.tokenizer is not None
image_url = cast(ChatCompletionContentPartImageParam,
part)["image_url"]
elif len(image_futures) == 0:
assert self.tokenizer is not None
image_url = cast(ChatCompletionContentPartImageParam,
part)["image_url"]
if image_url.get("detail", "auto") != "auto":
logger.warning(
"'image_url.detail' is currently not supported and "
"will be ignored.")
if image_url.get("detail", "auto") != "auto":
logger.warning(
"'image_url.detail' is currently not supported and "
"will be ignored.")
image_future = async_get_and_parse_image(image_url["url"])
image_futures.append(image_future)
else:
raise NotImplementedError(
"Multiple 'image_url' input is currently not supported."
)
mm_future = async_get_and_parse_image(image_url["url"])
mm_futures.append(mm_future)
else:
raise NotImplementedError(f"Unknown part type: {part_type}")
text_prompt = "\n".join(texts)
if vlm_config is not None and len(image_futures):
if vlm_config is not None and len(mm_futures):
assert len(
mm_futures
) == 1, "Multiple 'image_url' input is currently not supported."
(image_token_prompt,
image_token_str) = vlm_config.get_image_token_text(self.tokenizer)
@ -171,8 +166,7 @@ class OpenAIServingChat(OpenAIServing):
else:
messages = [ConversationMessage(role=role, content=text_prompt)]
return ChatMessageParseResult(messages=messages,
image_futures=image_futures)
return ChatMessageParseResult(messages=messages, mm_futures=mm_futures)
def _parse_chat_message_content(
self,
@ -182,10 +176,10 @@ class OpenAIServingChat(OpenAIServing):
content = message.get("content")
if content is None:
return ChatMessageParseResult(messages=[], image_futures=[])
return ChatMessageParseResult(messages=[], mm_futures=[])
if isinstance(content, str):
messages = [ConversationMessage(role=role, content=content)]
return ChatMessageParseResult(messages=messages, image_futures=[])
return ChatMessageParseResult(messages=messages, mm_futures=[])
return self._parse_chat_message_content_parts(role, content)
@ -210,13 +204,13 @@ class OpenAIServingChat(OpenAIServing):
try:
conversation: List[ConversationMessage] = []
image_futures: List[Awaitable[ImagePixelData]] = []
mm_futures: List[Awaitable[MultiModalDataDict]] = []
for msg in request.messages:
chat_parsed_result = self._parse_chat_message_content(msg)
conversation.extend(chat_parsed_result.messages)
image_futures.extend(chat_parsed_result.image_futures)
mm_futures.extend(chat_parsed_result.mm_futures)
tool_dicts = None if request.tools is None else [
tool.model_dump() for tool in request.tools
@ -235,15 +229,14 @@ class OpenAIServingChat(OpenAIServing):
logger.error("Error in applying chat template from request: %s", e)
return self.create_error_response(str(e))
# Fetch image data
image_data: Optional[ImagePixelData] = None
mm_data: Optional[MultiModalDataDict] = None
try:
if len(image_futures):
# since we support only single image currently
assert len(image_futures) == 1
image_data = await image_futures[0]
if len(mm_futures):
# since we support only single mm data currently
assert len(mm_futures) == 1
mm_data = await mm_futures[0]
except Exception as e:
logger.error("Error in loading image data: %s", e)
logger.error("Error in loading multi-modal data: %s", e)
return self.create_error_response(str(e))
request_id = f"cmpl-{random_uuid()}"
@ -274,8 +267,8 @@ class OpenAIServingChat(OpenAIServing):
"prompt": prompt_text,
"prompt_token_ids": prompt_ids,
}
if image_data is not None:
inputs["multi_modal_data"] = image_data
if mm_data is not None:
inputs["multi_modal_data"] = mm_data
is_tracing_enabled = await self.engine.is_tracing_enabled()
trace_headers = None

View File

@ -4,7 +4,7 @@ from typing import (TYPE_CHECKING, List, Literal, Optional, Sequence,
from typing_extensions import NotRequired
if TYPE_CHECKING:
from vllm.multimodal import MultiModalData
from vllm.multimodal import MultiModalDataDict
class ParsedText(TypedDict):
@ -72,7 +72,7 @@ class TextPrompt(TypedDict):
prompt: str
"""The input text to be tokenized before passing to the model."""
multi_modal_data: NotRequired["MultiModalData"]
multi_modal_data: NotRequired["MultiModalDataDict"]
"""
Optional multi-modal data to pass to the model,
if the model supports it.
@ -85,7 +85,7 @@ class TokensPrompt(TypedDict):
prompt_token_ids: List[int]
"""A list of token IDs to pass to the model."""
multi_modal_data: NotRequired["MultiModalData"]
multi_modal_data: NotRequired["MultiModalDataDict"]
"""
Optional multi-modal data to pass to the model,
if the model supports it.
@ -103,7 +103,7 @@ class TextTokensPrompt(TypedDict):
prompt_token_ids: List[int]
"""The token IDs of the prompt."""
multi_modal_data: NotRequired["MultiModalData"]
multi_modal_data: NotRequired["MultiModalDataDict"]
"""
Optional multi-modal data to pass to the model,
if the model supports it.
@ -128,7 +128,6 @@ class LLMInputs(TypedDict):
The inputs in :class:`~vllm.LLMEngine` before they are
passed to the model executor.
"""
prompt_token_ids: List[int]
"""The token IDs of the prompt."""
@ -137,7 +136,7 @@ class LLMInputs(TypedDict):
The original prompt text corresponding to the token IDs, if available.
"""
multi_modal_data: NotRequired[Optional["MultiModalData"]]
multi_modal_data: NotRequired[Optional["MultiModalDataDict"]]
"""
Optional multi-modal data to pass to the model,
if the model supports it.

View File

@ -12,7 +12,7 @@ from .data import LLMInputs
if TYPE_CHECKING:
from vllm.config import ModelConfig, VisionLanguageConfig
from vllm.multimodal import MultiModalData
from vllm.multimodal import MultiModalDataDict
from vllm.sequence import SequenceData
logger = init_logger(__name__)
@ -66,7 +66,8 @@ class InputContext:
N = TypeVar("N", bound=Type[nn.Module])
DummyDataFactory = Callable[[InputContext, int],
Tuple["SequenceData", Optional["MultiModalData"]]]
Tuple["SequenceData",
Optional["MultiModalDataDict"]]]
"""
Create dummy data to be inputted into the model.
@ -94,7 +95,7 @@ class InputRegistry:
self,
ctx: InputContext,
seq_len: int,
) -> Tuple["SequenceData", Optional["MultiModalData"]]:
) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]:
"""
The default dummy data factory represents the longest possible text
that can be inputted to the model.

View File

@ -84,9 +84,8 @@ def _get_model_initialization_kwargs(
if supports_vision(model_class):
if vlm_config is None:
raise ValueError("Provide `image_input_type` and other vision "
"related configurations through LLM entrypoint "
"or engine arguments.")
raise ValueError("Provide vision related configurations "
"through LLM entrypoint or engine arguments.")
extra_kwargs["vlm_config"] = vlm_config

View File

@ -12,7 +12,6 @@ from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal.image import ImageFeatureData, ImagePixelData
from vllm.sequence import SequenceData
@ -49,7 +48,7 @@ def dummy_seq_data_for_clip(
return SequenceData(token_ids)
def dummy_pixel_data_for_clip(
def dummy_image_for_clip(
hf_config: CLIPVisionConfig,
*,
image_width_override: Optional[int] = None,
@ -62,22 +61,7 @@ def dummy_pixel_data_for_clip(
height = image_height_override
image = Image.new("RGB", (width, height), color=0)
return ImagePixelData(image)
def dummy_feature_data_for_clip(
hf_config: CLIPVisionConfig,
*,
image_feature_size_override: Optional[int] = None,
):
if image_feature_size_override is None:
image_feature_size = get_clip_image_feature_size(hf_config)
else:
image_feature_size = image_feature_size_override
values = torch.zeros((1, image_feature_size, hf_config.hidden_size),
dtype=torch.float16)
return ImageFeatureData(values)
return {"image": image}
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa

View File

@ -1,4 +1,4 @@
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
import torch
import torch.nn as nn
@ -17,11 +17,10 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalData
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import SamplerOutput
from .clip import (dummy_feature_data_for_clip, dummy_pixel_data_for_clip,
dummy_seq_data_for_clip)
from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
from .interfaces import SupportsVision
_KEYS_TO_MODIFY_MAPPING = {
@ -76,17 +75,10 @@ class LlavaImagePixelInputs(TypedDict):
"""Shape: (batch_size, num_channels, height, width)"""
class LlavaImageFeatureInputs(TypedDict):
type: Literal["image_features"]
data: torch.Tensor
"""Shape: (batch_size, image_feature_size, hidden_size)"""
LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageFeatureInputs]
LlavaImageInputs = LlavaImagePixelInputs
def dummy_data_for_llava(ctx: InputContext, seq_len: int):
multimodal_config = ctx.get_multimodal_config()
hf_config = ctx.get_hf_config(LlavaConfig)
vision_config = hf_config.vision_config
@ -97,22 +89,14 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int):
image_token_id=hf_config.image_token_index,
)
image_input_type = multimodal_config.image_input_type
ImageInputType = VisionLanguageConfig.ImageInputType
mm_data: MultiModalData
if image_input_type == ImageInputType.PIXEL_VALUES:
mm_data = dummy_pixel_data_for_clip(vision_config)
elif image_input_type == ImageInputType.IMAGE_FEATURES:
mm_data = dummy_feature_data_for_clip(vision_config)
mm_data = dummy_image_for_clip(vision_config)
return seq_data, mm_data
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
@MULTIMODAL_REGISTRY.register_image_feature_input_mapper()
@MULTIMODAL_REGISTRY.register_image_pixel_input_mapper()
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
class LlavaForConditionalGeneration(nn.Module, SupportsVision):
@ -126,11 +110,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
self.config = config
self.vlm_config = vlm_config
if self.vlm_config.image_input_type == (
VisionLanguageConfig.ImageInputType.PIXEL_VALUES):
self.vision_tower = CLIPVisionModel(config.vision_config)
else:
self.vision_tower = None
# TODO: Optionally initializes this for supporting embeddings.
self.vision_tower = CLIPVisionModel(config.vision_config)
self.multi_modal_projector = LlavaMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size,
@ -165,44 +146,18 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[LlavaImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_features = kwargs.pop("image_features", None)
expected_input_type = self.vlm_config.image_input_type
ImageInputType = VisionLanguageConfig.ImageInputType
if pixel_values is None:
return None
if expected_input_type == ImageInputType.PIXEL_VALUES:
if image_features is not None:
raise ValueError(
"Expected pixel values but got image features")
if pixel_values is None:
return None
if not isinstance(pixel_values, torch.Tensor):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
if not isinstance(pixel_values, torch.Tensor):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
return LlavaImagePixelInputs(
type="pixel_values",
data=self._validate_image_data(pixel_values),
)
if expected_input_type == ImageInputType.IMAGE_FEATURES:
if pixel_values is not None:
raise ValueError(
"Expected image features but got pixel values")
if image_features is None:
return None
if not isinstance(image_features, torch.Tensor):
raise ValueError("Incorrect type of image features. "
f"Got type: {type(image_features)}")
return LlavaImageFeatureInputs(
type="image_features",
data=self._validate_image_data(image_features),
)
return None
return LlavaImagePixelInputs(
type="pixel_values",
data=self._validate_image_data(pixel_values),
)
def _select_image_features(self, image_features: torch.Tensor, *,
strategy: str) -> torch.Tensor:
@ -237,12 +192,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
def _process_image_input(self,
image_input: LlavaImageInputs) -> torch.Tensor:
if image_input["type"] == "pixel_values":
assert self.vision_tower is not None
image_features = self._process_image_pixels(image_input)
else:
image_features = image_input["data"]
assert self.vision_tower is not None
image_features = self._process_image_pixels(image_input)
return self.multi_modal_projector(image_features)
def forward(
@ -273,25 +224,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
This way, the `positions` and `attn_metadata` are consistent
with the `input_ids`.
This model has two modes of image inputs:
`PIXEL_VALUES` and `IMAGE_FEATURES`.
Args:
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
pixel_values: The pixels in each input image.
Expects a batch with shape `[1, 3, 336, 336]`.
(Only applicable to `PIXEL_VALUES` mode)
image_features: The image features for each input image outputted by
the vision tower before passing to the multi-modal projector.
Expects a batch with shape `[1, 576, 1024]`.
(Only applicable to `IMAGE_FEATURES` mode)
See also:
Each input maps to huggingface implementation, as follows:
- `pixel_values`: https://github.com/huggingface/transformers/blob/v4.41.1/src/transformers/models/llava/modeling_llava.py#L360
- `image_features`: https://github.com/huggingface/transformers/blob/v4.41.1/src/transformers/models/llava/modeling_llava.py#L437
"""
image_input = self._parse_and_validate_image_input(**kwargs)

View File

@ -1,8 +1,8 @@
from typing import (Dict, Iterable, List, Literal, Optional, Tuple, TypedDict,
Union)
from typing import Dict, Iterable, List, Literal, Optional, Tuple, TypedDict
import torch
import torch.nn as nn
from PIL import Image
from transformers import CLIPVisionConfig, LlavaNextConfig
from transformers.models.llava_next.modeling_llava_next import (
get_anyres_image_grid_shape, unpad_image)
@ -21,12 +21,11 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalData
from vllm.multimodal.image import ImagePixelData
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import SamplerOutput
from .clip import (dummy_feature_data_for_clip, dummy_pixel_data_for_clip,
dummy_seq_data_for_clip, get_clip_patch_grid_length)
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
get_clip_patch_grid_length)
from .interfaces import SupportsVision
from .llava import LlavaMultiModalProjector, merge_vision_embeddings
@ -47,17 +46,7 @@ class LlavaNextImagePixelInputs(TypedDict):
"""Shape: (batch_size, 2)"""
class LlavaNextImageFeatureInputs(TypedDict):
type: Literal["image_features"]
data: torch.Tensor
"""Shape: (batch_size, 1 + num_patches, image_feature_size, hidden_size)"""
image_sizes: NotRequired[torch.Tensor]
"""Shape: (batch_size, 2)"""
LlavaNextImageInputs = Union[LlavaNextImagePixelInputs,
LlavaNextImageFeatureInputs]
LlavaNextImageInputs = LlavaNextImagePixelInputs
def _get_llava_next_num_unpadded_features(
@ -138,20 +127,11 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
image_feature_size_override=image_feature_size,
)
image_input_type = multimodal_config.image_input_type
ImageInputType = VisionLanguageConfig.ImageInputType
mm_data: MultiModalData
if image_input_type == ImageInputType.PIXEL_VALUES:
mm_data = dummy_pixel_data_for_clip(
vision_config,
image_width_override=dummy_width,
image_height_override=dummy_height,
)
elif image_input_type == ImageInputType.IMAGE_FEATURES:
mm_data = dummy_feature_data_for_clip(
vision_config,
image_feature_size_override=image_feature_size,
)
mm_data = dummy_image_for_clip(
vision_config,
image_width_override=dummy_width,
image_height_override=dummy_height,
)
return seq_data, mm_data
@ -159,32 +139,26 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
raise NotImplementedError(msg)
def _pixel_mapper(ctx: InputContext,
data: ImagePixelData) -> Dict[str, torch.Tensor]:
image = data.image
def _pixel_mapper(ctx: InputContext, image: object) -> Dict[str, torch.Tensor]:
if isinstance(image, torch.Tensor):
pixel_values = image.to(ctx.model_config.dtype)
batch_size, _, _, h, w = pixel_values.shape
image_sizes = torch.tensor([(w, h) for _ in range(batch_size)])
if isinstance(image, Image.Image):
return {"pixel_values": pixel_values, "image_sizes": image_sizes}
# Temporary patch before dynamic number of image tokens is supported
_, _, h, w = ctx.get_multimodal_config().image_input_shape
if (w, h) != (image.width, image.height):
logger.warning(
"Dynamic image shape is currently not supported. "
"Resizing input image to (%d, %d).", w, h)
# Temporary patch before dynamic number of image tokens is supported
_, _, h, w = ctx.get_multimodal_config().image_input_shape
if (w, h) != (image.width, image.height):
logger.warning(
"Dynamic image shape is currently not supported. "
"Resizing input image to (%d, %d).", w, h)
image = image.resize((w, h))
data.image = image.resize((w, h))
return MULTIMODAL_REGISTRY._get_plugin("image") \
._default_input_mapper(ctx, image)
return MULTIMODAL_REGISTRY._get_plugin_for_data_type(ImagePixelData) \
._default_input_mapper(ctx, data)
raise TypeError(f"Invalid type for 'image': {type(image)}")
@MULTIMODAL_REGISTRY.register_image_feature_input_mapper()
@MULTIMODAL_REGISTRY.register_image_pixel_input_mapper(_pixel_mapper)
@MULTIMODAL_REGISTRY.register_image_input_mapper(_pixel_mapper)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
@ -198,11 +172,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
self.config = config
self.vlm_config = vlm_config
if self.vlm_config.image_input_type == (
VisionLanguageConfig.ImageInputType.PIXEL_VALUES):
self.vision_tower = CLIPVisionModel(config=config.vision_config)
else:
raise TypeError("Image features are not supported by LLaVA-NeXT")
self.vision_tower = CLIPVisionModel(config=config.vision_config)
self.multi_modal_projector = LlavaMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size,
@ -255,36 +225,23 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
self, **kwargs: object) -> Optional[LlavaNextImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_sizes = kwargs.pop("image_sizes", None)
image_features = kwargs.pop("image_features", None)
expected_input_type = self.vlm_config.image_input_type
ImageInputType = VisionLanguageConfig.ImageInputType
if pixel_values is None or image_sizes is None:
return None
if expected_input_type == ImageInputType.PIXEL_VALUES:
if image_features is not None:
raise ValueError(
"Expected pixel values but got image features")
if pixel_values is None:
return None
if not isinstance(pixel_values, torch.Tensor):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
if not isinstance(pixel_values, torch.Tensor):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
if not isinstance(image_sizes, torch.Tensor):
raise ValueError("Incorrect type of image sizes. "
f"Got type: {type(image_sizes)}")
if not isinstance(image_sizes, torch.Tensor):
raise ValueError("Incorrect type of image sizes. "
f"Got type: {type(image_sizes)}")
return LlavaNextImagePixelInputs(
type="pixel_values",
data=self._validate_image_pixels(pixel_values),
image_sizes=self._validate_image_sizes(image_sizes),
)
assert expected_input_type != ImageInputType.IMAGE_FEATURES, (
"Failed to validate this at initialization time")
return None
return LlavaNextImagePixelInputs(
type="pixel_values",
data=self._validate_image_pixels(pixel_values),
image_sizes=self._validate_image_sizes(image_sizes),
)
def _select_image_features(self, image_features: torch.Tensor, *,
strategy: str) -> torch.Tensor:
@ -391,11 +348,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
def _process_image_input(
self, image_input: LlavaNextImageInputs) -> torch.Tensor:
if image_input["type"] == "pixel_values":
assert self.vision_tower is not None
image_features = self._process_image_pixels(image_input)
else:
image_features = image_input["data"]
assert self.vision_tower is not None
image_features = self._process_image_pixels(image_input)
patch_embeddings = self.multi_modal_projector(image_features)

View File

@ -35,10 +35,9 @@ from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import ImagePixelData
from vllm.sequence import SamplerOutput
from .clip import dummy_pixel_data_for_clip, dummy_seq_data_for_clip
from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
from .interfaces import SupportsVision
logger = init_logger(__name__)
@ -286,7 +285,7 @@ def dummy_data_for_phi3v(ctx: InputContext, seq_len: int):
image_token_id=32044,
image_feature_size_override=image_feature_size,
)
mm_data = dummy_pixel_data_for_clip(
mm_data = dummy_image_for_clip(
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
image_width_override=dummy_width,
image_height_override=dummy_height,
@ -331,8 +330,7 @@ def _calc_hd_transform_size(*, width: int, height: int, hd_num: int = 16):
def _image_processor(ctx: InputContext,
data: ImagePixelData) -> Dict[str, torch.Tensor]:
image = data.image
image: object) -> Dict[str, torch.Tensor]:
if isinstance(image, Image.Image):
# Temporary patch before dynamic number of image tokens is supported
@ -343,13 +341,14 @@ def _image_processor(ctx: InputContext,
"Dynamic image shape is currently not supported. "
"Resizing input image to (%d, %d).", w, h)
data.image = image.resize((w, h))
image = image.resize((w, h))
return MULTIMODAL_REGISTRY._get_plugin_for_data_type(ImagePixelData) \
._default_input_mapper(ctx, data)
return MULTIMODAL_REGISTRY._get_plugin("image") \
._default_input_mapper(ctx, image)
raise TypeError(f"Invalid type for 'image': {type(image)}")
@MULTIMODAL_REGISTRY.register_image_pixel_input_mapper(_image_processor)
@MULTIMODAL_REGISTRY.register_image_input_mapper(_image_processor)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi3v)
class Phi3VForCausalLM(nn.Module, SupportsVision):
@ -375,14 +374,6 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
pixel_values = kwargs.pop("pixel_values", None)
image_sizes = kwargs.pop("image_sizes", None)
expected_input_type = self.vlm_config.image_input_type
ImageInputType = VisionLanguageConfig.ImageInputType
if expected_input_type != ImageInputType.PIXEL_VALUES:
raise ValueError(
f"Unexpected image input type: {expected_input_type}."
"Phi3v only support pixel_values input currently.")
if pixel_values is not None and image_sizes is not None:
return Phi3VImagePixelInputs(type="pixel_values",
data=pixel_values,

View File

@ -1,4 +1,4 @@
from .base import MultiModalData, MultiModalPlugin
from .base import MultiModalDataDict, MultiModalPlugin
from .registry import MultiModalRegistry
MULTIMODAL_REGISTRY = MultiModalRegistry()
@ -11,6 +11,8 @@ See also:
"""
__all__ = [
"MultiModalData", "MultiModalPlugin", "MULTIMODAL_REGISTRY",
"MultiModalRegistry"
"MultiModalPlugin",
"MULTIMODAL_REGISTRY",
"MultiModalRegistry",
"MultiModalDataDict",
]

View File

@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from typing import (TYPE_CHECKING, Callable, Dict, Generic, Optional, Type,
TypeVar)
from typing import (TYPE_CHECKING, Any, Callable, Dict, Optional, Type,
TypedDict, TypeVar, Union)
from vllm.config import ModelConfig
from vllm.inputs import InputContext
@ -8,38 +8,35 @@ from vllm.logger import init_logger
if TYPE_CHECKING:
import torch
from PIL import Image
from torch import nn
logger = init_logger(__name__)
class MultiModalData:
"""
Base class that contains multi-modal data.
To add a new modality, add a new file under ``multimodal`` directory.
In this new file, subclass :class:`~MultiModalData` and
:class:`~MultiModalPlugin`.
Finally, register the new plugin to
:const:`vllm.multimodal.MULTIMODAL_REGISTRY`.
This enables models to call :meth:`MultiModalRegistry.map_input` for
the new modality.
"""
pass
D = TypeVar("D", bound=MultiModalData)
N = TypeVar("N", bound=Type["nn.Module"])
MultiModalInputMapper = Callable[[InputContext, D], Dict[str, "torch.Tensor"]]
class MultiModalDataBuiltins(TypedDict, total=False):
image: "Image.Image"
MultiModalDataDict = Union[MultiModalDataBuiltins, Dict[str, Any]]
"""
A dictionary containing an item for each modality type to input.
The data belonging to each modality is converted into keyword arguments
to the model by the corresponding mapper. By default, the mapper of
the corresponding plugin with the same modality key is applied.
"""
MultiModalInputMapper = Callable[[InputContext, object], Dict[str,
"torch.Tensor"]]
"""Return a dictionary to be passed as keyword arguments to
:meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers
and processors in HuggingFace Transformers."""
class MultiModalPlugin(ABC, Generic[D]):
class MultiModalPlugin(ABC):
"""
Base class that defines data processing logic for a specific modality.
@ -52,19 +49,18 @@ class MultiModalPlugin(ABC, Generic[D]):
def __init__(self) -> None:
self._input_mappers: Dict[Type["nn.Module"],
MultiModalInputMapper[D]] = {}
MultiModalInputMapper] = {}
@abstractmethod
def get_data_type(self) -> Type[D]:
def get_data_key(self) -> str:
"""
Get the modality (subclass of :class:`~MultiModalData`) served by
this plugin.
Get the data key corresponding to the modality.
"""
raise NotImplementedError
@abstractmethod
def _default_input_mapper(self, ctx: InputContext,
data: D) -> Dict[str, "torch.Tensor"]:
data: object) -> Dict[str, "torch.Tensor"]:
"""Return a dictionary to be passed as keyword arguments to
:meth:`~torch.nn.Module.forward`. This is similar in concept to
tokenizers and processors in HuggingFace Transformers.
@ -73,11 +69,10 @@ class MultiModalPlugin(ABC, Generic[D]):
def register_input_mapper(
self,
mapper: Optional[MultiModalInputMapper[D]] = None,
mapper: Optional[MultiModalInputMapper] = None,
):
"""
Register an input mapper to a model class.
When the model receives input data that matches the modality served by
this plugin (see :meth:`get_data_type`), the provided function is
invoked to transform the data into a dictionary of model inputs.
@ -102,11 +97,13 @@ class MultiModalPlugin(ABC, Generic[D]):
return wrapper
def map_input(self, model_config: ModelConfig,
data: D) -> Dict[str, "torch.Tensor"]:
data: object) -> Dict[str, "torch.Tensor"]:
"""
Apply an input mapper to a :class:`~MultiModalData` instance passed
Apply an input mapper to a data passed
to the model, transforming the data into a dictionary of model inputs.
If the data is not something that the mapper expects, throws TypeError.
The model is identified by ``model_config``.
TODO: Add guide [ref: PR #5276]

View File

@ -1,5 +1,5 @@
from functools import lru_cache
from typing import Dict, Type, Union
from typing import Dict
import torch
from PIL import Image
@ -9,105 +9,36 @@ from vllm.inputs.registry import InputContext
from vllm.logger import init_logger
from vllm.transformers_utils.image_processor import get_image_processor
from .base import MultiModalData, MultiModalPlugin
from .base import MultiModalPlugin
logger = init_logger(__name__)
cached_get_image_processor = lru_cache(get_image_processor)
class ImagePixelData(MultiModalData):
"""
The pixel data of an image. Can be one of:
class ImagePlugin(MultiModalPlugin):
- :class:`PIL.Image.Image`: An image object. Requires that a HuggingFace
processor is available to the model.
- :class:`torch.Tensor`: The raw pixel data which is passed to the model
without additional pre-processing.
"""
def __init__(self, image: Union[Image.Image, torch.Tensor]) -> None:
if isinstance(image, Image.Image):
# So that this class can be created inside the Image context manager
image.load()
self.image = image
def __repr__(self) -> str:
image = self.image
if isinstance(image, Image.Image):
return f"{type(self).__name__}(image={image})"
return (f"{type(self).__name__}(image=torch.Tensor(shape="
f"{image.shape}, dtype={image.dtype}))")
class ImagePixelPlugin(MultiModalPlugin[ImagePixelData]):
def get_data_type(self) -> Type[ImagePixelData]:
return ImagePixelData
def get_data_key(self) -> str:
return "image"
def _get_hf_image_processor(self, model_config: ModelConfig):
vlm_config = model_config.multimodal_config
if vlm_config is None or vlm_config.image_processor is None:
return None
return cached_get_image_processor(
vlm_config.image_processor,
trust_remote_code=model_config.trust_remote_code,
revision=vlm_config.image_processor_revision,
)
model_config.model,
trust_remote_code=model_config.trust_remote_code)
def _default_input_mapper(self, ctx: InputContext,
data: ImagePixelData) -> Dict[str, torch.Tensor]:
data: object) -> Dict[str, torch.Tensor]:
model_config = ctx.model_config
image = data.image
if isinstance(image, Image.Image):
if isinstance(data, Image.Image):
image_processor = self._get_hf_image_processor(model_config)
if image_processor is None:
raise RuntimeError("No HuggingFace processor is available"
"to process the image object")
try:
return image_processor.preprocess(image, return_tensors="pt") \
return image_processor.preprocess(data, return_tensors="pt") \
.to(model_config.dtype).data
except Exception:
logger.error("Failed to process image (%s)", image)
logger.error("Failed to process image (%s)", data)
raise
elif isinstance(image, torch.Tensor):
pixel_values = image.to(model_config.dtype)
return {"pixel_values": pixel_values}
raise TypeError(f"Invalid image type: {type(image)}")
class ImageFeatureData(MultiModalData):
"""
The feature vector of an image, passed directly to the model.
This should be the output of the vision tower.
"""
def __init__(self, image_features: torch.Tensor) -> None:
self.image_features = image_features
def __repr__(self) -> str:
image_features = self.image_features
return (f"{type(self).__name__}(image_features=torch.Tensor(shape="
f"{image_features.shape}, dtype={image_features.dtype}))")
class ImageFeaturePlugin(MultiModalPlugin[ImageFeatureData]):
def get_data_type(self) -> Type[ImageFeatureData]:
return ImageFeatureData
def _default_input_mapper(
self, ctx: InputContext,
data: ImageFeatureData) -> Dict[str, torch.Tensor]:
model_config = ctx.model_config
image_features = data.image_features.to(model_config.dtype)
return {"image_features": image_features}
raise TypeError(f"Invalid type for 'image': {type(data)}")

View File

@ -1,18 +1,16 @@
import functools
from typing import Any, Optional, Sequence, Type, TypeVar
from typing import Optional, Sequence, Type, TypeVar
from torch import nn
from vllm.config import ModelConfig
from vllm.logger import init_logger
from .base import MultiModalData, MultiModalInputMapper, MultiModalPlugin
from .image import (ImageFeatureData, ImageFeaturePlugin, ImagePixelData,
ImagePixelPlugin)
from .base import MultiModalDataDict, MultiModalInputMapper, MultiModalPlugin
from .image import ImagePlugin
logger = init_logger(__name__)
D = TypeVar("D", bound=MultiModalData)
N = TypeVar("N", bound=Type[nn.Module])
@ -20,81 +18,91 @@ class MultiModalRegistry:
"""
A registry to dispatch data processing
according to its modality and the target model.
The registry handles both external and internal data input.
"""
DEFAULT_PLUGINS = (ImageFeaturePlugin(), ImagePixelPlugin())
DEFAULT_PLUGINS = (ImagePlugin(), )
def __init__(
self,
*,
plugins: Sequence[MultiModalPlugin[Any]] = DEFAULT_PLUGINS,
) -> None:
self._plugins_by_data_type = {p.get_data_type(): p for p in plugins}
self,
*,
plugins: Sequence[MultiModalPlugin] = DEFAULT_PLUGINS) -> None:
self._plugins = {p.get_data_key(): p for p in plugins}
def register_plugin(self, plugin: MultiModalPlugin[Any]) -> None:
data_type = plugin.get_data_type()
def register_plugin(self, plugin: MultiModalPlugin) -> None:
data_type_key = plugin.get_data_key()
if data_type in self._plugins_by_data_type:
if data_type_key in self._plugins:
logger.warning(
"A plugin is already registered for data type %s, "
"and will be overwritten by the new plugin %s.", data_type,
"and will be overwritten by the new plugin %s.", data_type_key,
plugin)
self._plugins_by_data_type[data_type] = plugin
self._plugins[data_type_key] = plugin
def _get_plugin_for_data_type(self, data_type: Type[MultiModalData]):
for typ in data_type.mro():
plugin = self._plugins_by_data_type.get(typ)
if plugin is not None:
return plugin
def _get_plugin(self, data_type_key: str):
plugin = self._plugins.get(data_type_key)
if plugin is not None:
return plugin
msg = f"Unknown multi-modal data type: {data_type}"
msg = f"Unknown multi-modal data type: {data_type_key}"
raise NotImplementedError(msg)
def register_image_input_mapper(
self,
mapper: Optional[MultiModalInputMapper] = None,
):
"""
Register an input mapper for image data to a model class.
See :meth:`MultiModalPlugin.register_input_mapper` for more details.
"""
return self.register_input_mapper("image", mapper)
def _process_input(self, key: str, value: object,
model_config: ModelConfig):
plugin = self._plugins.get(key)
if plugin:
return plugin.map_input(model_config, value)
msg = f"Unknown multi-modal data type: {key}"
raise NotImplementedError(msg)
def register_input_mapper(
self,
data_type: Type[D],
mapper: Optional[MultiModalInputMapper[D]] = None,
data_type: str,
mapper: Optional[MultiModalInputMapper] = None,
):
"""
Register an input mapper for a specific modality to a model class.
See :meth:`MultiModalPlugin.register_input_mapper` for more details.
"""
return self._get_plugin_for_data_type(data_type) \
.register_input_mapper(mapper)
plugin = self._plugins.get(data_type)
if not plugin:
msg = f"Unknown multi-modal data type: {data_type}"
raise NotImplementedError(msg)
return plugin.register_input_mapper(mapper)
def register_image_pixel_input_mapper(
self,
mapper: Optional[MultiModalInputMapper[ImagePixelData]] = None,
):
def register_image_input(self,
mapper: Optional[MultiModalInputMapper] = None):
"""
Register an input mapper for image pixel data to a model class.
See :meth:`MultiModalPlugin.register_input_mapper` for more details.
"""
return self.register_input_mapper(ImagePixelData, mapper)
return self.register_input_mapper("image", mapper)
def register_image_feature_input_mapper(
self,
mapper: Optional[MultiModalInputMapper[ImageFeatureData]] = None,
):
def map_input(self, model_config: ModelConfig, data: MultiModalDataDict):
"""
Register an input mapper for image feature data to a model class.
See :meth:`MultiModalPlugin.register_input_mapper` for more details.
"""
return self.register_input_mapper(ImageFeatureData, mapper)
def map_input(self, model_config: ModelConfig, data: MultiModalData):
"""
Apply an input mapper to a :class:`~MultiModalData` instance passed
to the model.
Apply an input mapper to the data passed to the model.
See :meth:`MultiModalPlugin.map_input` for more details.
"""
return self._get_plugin_for_data_type(type(data)) \
.map_input(model_config, data)
result_list = [
self._process_input(k, v, model_config) for k, v in data.items()
]
return {k: v for d in result_list for k, v in d.items()}
def create_input_mapper(self, model_config: ModelConfig):
"""

View File

@ -8,7 +8,7 @@ from PIL import Image
from vllm.config import ModelConfig
from vllm.envs import VLLM_IMAGE_FETCH_TIMEOUT
from vllm.multimodal.image import ImagePixelData
from vllm.multimodal.base import MultiModalDataDict
class ImageFetchAiohttp:
@ -53,14 +53,10 @@ class ImageFetchAiohttp:
"Invalid 'image_url': A valid 'image_url' must start "
"with either 'data:image' or 'http'.")
image.load()
return image
async def async_get_and_parse_image(image_url: str) -> ImagePixelData:
with await ImageFetchAiohttp.fetch_image(image_url) as image:
return ImagePixelData(image)
def encode_image_base64(image: Image.Image, format: str = 'JPEG') -> str:
"""Encode a pillow image to base64 format."""
@ -91,3 +87,8 @@ def get_full_image_text_prompt(image_prompt: str, text_prompt: str,
raise ValueError(
f"Unsupported model type: {config.hf_config.model_type}")
return full_prompt
async def async_get_and_parse_image(image_url: str) -> MultiModalDataDict:
image = await ImageFetchAiohttp.fetch_image(image_url)
return {"image": image}

View File

@ -14,7 +14,7 @@ from vllm.sampling_params import SamplingParams
if TYPE_CHECKING:
from vllm.inputs import LLMInputs
from vllm.multimodal import MultiModalData
from vllm.multimodal import MultiModalDataDict
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
@ -280,8 +280,8 @@ class Sequence:
return self.inputs["prompt_token_ids"]
@property
def multi_modal_data(self) -> Optional["MultiModalData"]:
return self.inputs.get("multi_modal_data")
def multi_modal_data(self) -> "MultiModalDataDict":
return self.inputs.get("multi_modal_data") or {}
@property
def lora_int_id(self) -> int:
@ -457,7 +457,7 @@ class SequenceGroup:
return next(iter(self.seqs_dict.values())).prompt_token_ids
@property
def multi_modal_data(self) -> Optional["MultiModalData"]:
def multi_modal_data(self) -> Optional["MultiModalDataDict"]:
# All sequences in the group should have the same multi-modal data.
# We use the multi-modal data of an arbitrary sequence.
return next(iter(self.seqs_dict.values())).multi_modal_data
@ -639,7 +639,7 @@ class SequenceGroupMetadata:
lora_request: Optional[LoRARequest] = None,
computed_block_nums: Optional[List[int]] = None,
state: Optional[SequenceGroupState] = None,
multi_modal_data: Optional["MultiModalData"] = None,
multi_modal_data: Optional["MultiModalDataDict"] = None,
encoder_seq_data: Optional[SequenceData] = None,
cross_block_table: Optional[List[int]] = None,
) -> None:

View File

@ -1,5 +1,3 @@
from typing import Optional
from transformers import AutoImageProcessor
from transformers.image_processing_utils import BaseImageProcessor
@ -12,7 +10,6 @@ def get_image_processor(
processor_name: str,
*args,
trust_remote_code: bool = False,
revision: Optional[str] = None,
**kwargs,
) -> BaseImageProcessor:
"""Gets an image processor for the given model name via HuggingFace."""
@ -21,7 +18,6 @@ def get_image_processor(
processor_name,
*args,
trust_remote_code=trust_remote_code,
revision=revision,
**kwargs)
except ValueError as e:
# If the error pertains to the processor class not existing or not

View File

@ -504,7 +504,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
is not None else 1))
mm_data = seq_group_metadata.multi_modal_data
if mm_data is not None:
if mm_data:
# Process multi-modal data
mm_kwargs = self.multi_modal_input_mapper(mm_data)
for k, v in mm_kwargs.items():