Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[Model] support input image embedding for minicpmv (#9237)
@@ -378,7 +378,7 @@ Text Generation
     - ✅︎
   * - :code:`MiniCPMV`
     - MiniCPM-V
-    - Image\ :sup:`+`
+    - Image\ :sup:`E+`
     - :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc.
     - ✅︎
     - ✅︎

@@ -57,12 +57,19 @@ To pass an image to the model, note the following in :class:`vllm.inputs.PromptT
     print(generated_text)
 
 # Inference with image embeddings as input with additional parameters
-# Specifically, we are conducting a trial run of Qwen2VL with the new input format, as the model utilizes additional parameters for calculating positional encoding.
-image_embeds = torch.load(...) # torch.Tensor of shape (1, image_feature_size, hidden_size of LM)
-image_grid_thw = torch.load(...) # torch.Tensor of shape (1, 3)
+# Specifically, we are conducting a trial run of Qwen2VL and MiniCPM-V with the new input format, which utilizes additional parameters.
+mm_data = {}
+
+image_embeds = torch.load(...) # torch.Tensor of shape (num_images, image_feature_size, hidden_size of LM)
+# For Qwen2VL, image_grid_thw is needed to calculate positional encoding.
 mm_data['image'] = {
     "image_embeds": image_embeds,
-    "image_grid_thw": image_grid_thw,
+    "image_grid_thw": torch.load(...) # torch.Tensor of shape (1, 3),
 }
+# For MiniCPM-V, image_size_list is needed to calculate details of the sliced image.
+mm_data['image'] = {
+    "image_embeds": image_embeds,
+    "image_size_list": [image.size] # list of image sizes
+}
 outputs = llm.generate({
     "prompt": prompt,

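To make the MiniCPM-V path in the revised snippet concrete, here is a minimal end-to-end sketch. The model name, prompt wording, and the file the embeddings are loaded from are illustrative assumptions; only the placeholder pattern and the multi_modal_data layout come from this commit.

    # Minimal sketch: feeding pre-computed image embeddings to MiniCPM-V.
    # Assumptions: model name, prompt text, and the .pt file are illustrative.
    import torch
    from PIL import Image
    from vllm import LLM

    llm = LLM(model="openbmb/MiniCPM-V-2_6", trust_remote_code=True)

    image = Image.open("example.jpg")
    # Pre-computed embeddings with shape
    # (num_images, image_feature_size, hidden_size of the LM).
    image_embeds = torch.load("example_image_embeds.pt")

    outputs = llm.generate({
        # MiniCPM-V marks image positions with the (<image>./</image>) tag.
        "prompt": "(<image>./</image>)\nWhat is shown in this image?",
        "multi_modal_data": {
            "image": {
                "image_embeds": image_embeds,
                # Original image sizes are still needed to lay out the slices.
                "image_size_list": [image.size],
            },
        },
    })
    print(outputs[0].outputs[0].text)
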
@@ -24,8 +24,8 @@
 import math
 import re
 from functools import partial
-from typing import (Any, Callable, Iterable, List, Mapping, Optional, Tuple,
-                    TypedDict)
+from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional,
+                    Tuple, TypedDict, Union)
 
 import torch
 import torch.types

@@ -65,10 +65,12 @@ _KEYS_TO_MODIFY_MAPPING = {
     "llm.lm_head": "lm_head",
 }
 
+RawImageType = Union[Image.Image, torch.Tensor]
+
 
-class MiniCPMVImageInput(TypedDict):
+class MiniCPMVRawImageInput(TypedDict):
     """Input mapper input with auxiliary data for computing image bounds."""
-    image: Image.Image
+    image: RawImageType
 
     # Image bounds token ids in 0-dim scaler tensor.
     im_start_id: torch.Tensor

@@ -78,7 +80,8 @@ class MiniCPMVImageInput(TypedDict):
 
 
 class MiniCPMVImagePixelInputs(TypedDict):
-    pixel_values: List[torch.Tensor]
+    type: Literal["pixel_values"]
+    data: List[torch.Tensor]
     """
     Shape: `(batch_size * num_images, num_channels, height, width)`
 

@@ -101,6 +104,27 @@ class MiniCPMVImagePixelInputs(TypedDict):
     """
 
 
+class MiniCPMVImageEmbeddingInputs(TypedDict):
+    type: Literal["image_embeds"]
+    data: torch.Tensor
+    """
+    Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
+
+    `hidden_size` must match the hidden size of language model backbone.
+    instead of a batched tensor.
+    """
+
+    image_bounds: torch.Tensor
+    """
+    Shape: `(batch_size * num_images, 2)`
+
+    This should be in `(start, stop)` format.
+    """
+
+
+MiniCPMVImageInputs = Union[MiniCPMVImagePixelInputs,
+                            MiniCPMVImageEmbeddingInputs]
+
 DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)
 
 
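The `type` field added to both TypedDicts is the discriminator that later code (for example `get_embedding` below) branches on. A small self-contained sketch of that pattern, with dummy tensors whose shapes are illustrative only:

    # Sketch of discriminating the two input variants by their "type" tag.
    # Tensors and shapes here are dummies, not real model data.
    import torch

    pixel_inputs = {
        "type": "pixel_values",
        "data": [torch.zeros(3, 448, 448)],       # one image's pixels
        "image_bounds": torch.tensor([[3, 7]]),
        "tgt_sizes": torch.tensor([[32, 32]]),
    }
    embedding_inputs = {
        "type": "image_embeds",
        "data": torch.zeros(1, 64, 4096),          # pre-computed features
        "image_bounds": torch.tensor([[3, 7]]),
    }

    for image_inputs in (pixel_inputs, embedding_inputs):
        if image_inputs["type"] == "image_embeds":
            print("skip the vision tower, reuse", image_inputs["data"].shape)
        else:
            print("run the vision tower on", len(image_inputs["data"]), "image(s)")
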
@@ -194,22 +218,22 @@ class Resampler2_5(BaseResampler):
 
 
 def _build_image_input(ctx: InputContext,
-                       image: Image.Image) -> MiniCPMVImageInput:
+                       image: RawImageType) -> MiniCPMVRawImageInput:
     tokenizer = cached_get_tokenizer(
         ctx.model_config.tokenizer,
         trust_remote_code=ctx.model_config.trust_remote_code)
     if hasattr(tokenizer, "slice_start_id"):
-        return MiniCPMVImageInput(
+        return MiniCPMVRawImageInput(
             image=image,
             im_start_id=torch.tensor(tokenizer.im_start_id),
             im_end_id=torch.tensor(tokenizer.im_end_id),
             slice_start_id=torch.tensor(tokenizer.slice_start_id),
             slice_end_id=torch.tensor(tokenizer.slice_end_id))
     else:
-        return MiniCPMVImageInput(image=image,
-                                  im_start_id=torch.tensor(
-                                      tokenizer.im_start_id),
-                                  im_end_id=torch.tensor(tokenizer.im_end_id))
+        return MiniCPMVRawImageInput(
+            image=image,
+            im_start_id=torch.tensor(tokenizer.im_start_id),
+            im_end_id=torch.tensor(tokenizer.im_end_id))
 
 
 def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:

@@ -280,20 +304,25 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
 
     pattern = "(<image>./</image>)"
     images = multi_modal_data["image"]
-    if isinstance(images, Image.Image):
-        images = [images]
     image_tags = re.findall(pattern, prompt)
 
     if len(image_tags) == 0:
         new_token_ids = token_ids
         new_prompt = prompt
     else:
+        if isinstance(images, dict):
+            image_size_list = images.get("image_size_list")
+            images = [images.get("image_embeds")]
+        else:
+            if isinstance(images, Image.Image):
+                images = [images]
+            image_size_list = [image.size for image in images]
+
         text_chunks = prompt.split(pattern)
         new_prompt_chunks: List[str] = []
-        for i in range(len(images)):
+        for i in range(len(image_size_list)):
             new_prompt_chunks += [
                 text_chunks[i],
-                get_placeholder(images[i].size, i)
+                get_placeholder(image_size_list[i], i)
             ]
         new_prompt_chunks.append(text_chunks[-1])
         new_prompt = "".join(new_prompt_chunks)

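The rewritten loop keys the placeholder expansion on `image_size_list` rather than on the images themselves, so it behaves the same whether the input was PIL images or an embeddings dict. A stand-alone sketch of that expansion follows; `get_placeholder` here is a stand-in for the real helper defined elsewhere in this file:

    # Sketch of the placeholder expansion driven by image_size_list.
    # get_placeholder is a stand-in; the real helper builds MiniCPM-V's
    # image/slice token sequence from the image size and index.
    def get_placeholder(image_size, idx):
        return f"<placeholder_{idx}:{image_size[0]}x{image_size[1]}>"

    pattern = "(<image>./</image>)"
    prompt = "(<image>./</image>)\nDescribe the scene. (<image>./</image>)\nNow compare."
    image_size_list = [(640, 480), (800, 600)]  # from PIL images or user-supplied

    text_chunks = prompt.split(pattern)
    new_prompt_chunks = []
    for i in range(len(image_size_list)):
        new_prompt_chunks += [text_chunks[i], get_placeholder(image_size_list[i], i)]
    new_prompt_chunks.append(text_chunks[-1])
    new_prompt = "".join(new_prompt_chunks)
    print(new_prompt)
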
@@ -323,9 +352,15 @@ def input_mapper_for_minicpmv(ctx: InputContext, data: object):
     if not isinstance(data, list):
         raise ValueError(
             "Image input must be list of MiniCPMVImageInput, got (%s)", data)
-    batch_data = image_processor \
-        .preprocess([img["image"] for img in data], return_tensors="pt") \
-        .data
+
+    if len(data) > 0 and isinstance(data[0]['image'], torch.Tensor):
+        batch_data = {
+            "image_embeds": data[0]['image'],
+        }
+    else:
+        batch_data = image_processor \
+            .preprocess([img["image"] for img in data], return_tensors="pt") \
+            .data
 
     if len(data) > 0:
         batch_data["im_start_id"] = data[0]["im_start_id"]

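So the mapper now yields one of two batch_data layouts: a bare embeddings dict when the raw image slot already holds a tensor, or the usual image-processor output otherwise. A rough sketch of the two outcomes, with illustrative shapes only:

    # Rough sketch of the two batch_data layouts the mapper can produce.
    # Shapes are illustrative, not taken from a real model run.
    import torch

    # Embedding path: the HuggingFace image processor is bypassed.
    batch_data_embeds = {"image_embeds": torch.zeros(1, 64, 4096)}

    # Pixel path: preprocess() supplies pixel values and target sizes.
    batch_data_pixels = {
        "pixel_values": [torch.zeros(3, 448, 448)],
        "tgt_sizes": torch.tensor([[32, 32]]),
    }

    # In both cases the im_start_id / im_end_id (and, when present, the
    # slice) token-id tensors from the first item are attached afterwards,
    # as the unchanged context lines of this hunk show.
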
@@ -380,7 +415,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
     def get_embedding(
         self,
         input_ids: torch.Tensor,
-        image_inputs: Optional[MiniCPMVImagePixelInputs],
+        image_inputs: Optional[MiniCPMVImageInputs],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         vlm_embedding: torch.Tensor = self.llm.embed_tokens(input_ids)
         if hasattr(self.config, "scale_emb"):

@@ -389,7 +424,12 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
         if image_inputs is None:  # No image
             vision_hidden_states = torch.tensor([], device=input_ids.device)
         else:
-            vision_hidden_states = self.get_vision_hidden_states(image_inputs)
+            if image_inputs["type"] == "image_embeds":
+                vision_hidden_states = (image_inputs["data"].type(
+                    vlm_embedding.dtype).to(vlm_embedding.device))
+            else:
+                vision_hidden_states = self.get_vision_hidden_states(
+                    image_inputs)
 
             # See NOTE in _parse_and_validate_inputs
             image_bounds = image_inputs["image_bounds"]

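For context, the code that follows this hunk (unchanged, hence not shown) uses those `(start, stop)` bounds to write the vision states over the placeholder positions in `vlm_embedding`. A simplified stand-alone sketch of that idea, with illustrative shapes and without the exact indexing the model uses:

    # Simplified sketch: scattering vision states into token embeddings
    # using (start, stop) bounds. Shapes and values are illustrative.
    import torch

    hidden_size = 8
    vlm_embedding = torch.zeros(10, hidden_size)           # 10 prompt tokens
    vision_hidden_states = torch.ones(1, 4, hidden_size)   # one image, 4 vision tokens
    image_bounds = torch.tensor([[3, 7]])                  # image occupies slots 3..6

    flat_vision = vision_hidden_states.reshape(-1, hidden_size)
    offset = 0
    for start, stop in image_bounds.tolist():
        vlm_embedding[start:stop] = flat_vision[offset:offset + (stop - start)]
        offset += stop - start
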
@@ -440,9 +480,23 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
         self,
         input_ids: torch.Tensor,
         **kwargs: object,
-    ) -> Optional[MiniCPMVImagePixelInputs]:
+    ) -> Optional[MiniCPMVImageInputs]:
         pixel_values = kwargs.pop("pixel_values", [])
         tgt_sizes = kwargs.pop("tgt_sizes", [])
+        im_start_id = kwargs.pop("im_start_id", None)
+        im_end_id = kwargs.pop("im_end_id", None)
+        slice_start_id = kwargs.pop("slice_start_id", None)
+        slice_end_id = kwargs.pop("slice_end_id", None)
+        image_embeds = kwargs.pop("image_embeds", None)
+
+        if image_embeds is not None:
+            return MiniCPMVImageEmbeddingInputs(
+                image_bounds=self._get_image_bounds(input_ids, im_start_id,
+                                                    im_end_id, slice_start_id,
+                                                    slice_end_id),
+                data=image_embeds,
+                type="image_embeds",
+            )
 
         if not isinstance(pixel_values, (torch.Tensor, list)):
             raise ValueError("Incorrect type of pixel values. "

@@ -477,10 +531,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
         if len(pixel_values_flat) == 0:
             return None
 
-        im_start_id = kwargs.pop("im_start_id", None)
-        im_end_id = kwargs.pop("im_end_id", None)
-        slice_start_id = kwargs.pop("slice_start_id", None)
-        slice_end_id = kwargs.pop("slice_end_id", None)
         if im_start_id is None:
             return None
 

@@ -488,8 +538,9 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
             image_bounds=self._get_image_bounds(input_ids, im_start_id,
                                                 im_end_id, slice_start_id,
                                                 slice_end_id),
-            pixel_values=pixel_values_flat,
+            data=pixel_values_flat,
             tgt_sizes=torch.stack(tgt_sizes_flat),
+            type="pixel_values",
         )
 
     def forward(

@@ -610,8 +661,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
     ) -> torch.Tensor:
         raise NotImplementedError
 
-    def get_vision_hidden_states(
-            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
+    def get_vision_hidden_states(self,
+                                 data: MiniCPMVImageInputs) -> torch.Tensor:
         raise NotImplementedError
 
     def is_default_weight_loading(self, name: str) -> bool:

@@ -705,9 +756,9 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
             res.append(self.resampler(vision_embedding, tgt_size))
         return torch.vstack(res)
 
-    def get_vision_hidden_states(
-            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
-        pixel_values = data["pixel_values"]
+    def get_vision_hidden_states(self,
+                                 data: MiniCPMVImageInputs) -> torch.Tensor:
+        pixel_values = data["data"]
 
         return self.get_vision_embedding(pixel_values)
 

@@ -793,9 +844,9 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
         vision_embedding = self.resampler(vision_embedding, tgt_sizes)
         return vision_embedding
 
-    def get_vision_hidden_states(
-            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
-        pixel_values = data["pixel_values"]
+    def get_vision_hidden_states(self,
+                                 data: MiniCPMVImageInputs) -> torch.Tensor:
+        pixel_values = data["data"]
         tgt_sizes = data["tgt_sizes"]
 
         device = self.vpm.embeddings.position_embedding.weight.device

@@ -909,9 +960,9 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
         )
         return vision_embedding
 
-    def get_vision_hidden_states(
-            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
-        pixel_values = data["pixel_values"]
+    def get_vision_hidden_states(self,
+                                 data: MiniCPMVImageInputs) -> torch.Tensor:
+        pixel_values = data["data"]
         tgt_sizes = data["tgt_sizes"]
 
         device = self.vpm.embeddings.position_embedding.weight.device