[Model] support input image embedding for minicpmv (#9237)

Author: whyiug
Date: 2024-10-10 23:00:47 +08:00
Committed by: GitHub
Parent: 07c11cf4d4
Commit: 04de9057ab
3 changed files with 101 additions and 43 deletions


@@ -378,7 +378,7 @@ Text Generation
- ✅︎
* - :code:`MiniCPMV`
- MiniCPM-V
- Image\ :sup:`+`
- Image\ :sup:`E+`
- :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc.
- ✅︎
- ✅︎
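The superscript change from `+` to `E+` follows the legend used elsewhere in this table: `E` marks models that also accept pre-computed embeddings in place of the raw image input, and `+` marks support for multiple items per text prompt.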


@@ -57,12 +57,19 @@ To pass an image to the model, note the following in :class:`vllm.inputs.PromptT
print(generated_text)
# Inference with image embeddings as input with additional parameters
# Specifically, we are conducting a trial run of Qwen2VL with the new input format, as the model utilizes additional parameters for calculating positional encoding.
image_embeds = torch.load(...) # torch.Tensor of shape (1, image_feature_size, hidden_size of LM)
image_grid_thw = torch.load(...) # torch.Tensor of shape (1, 3)
# Specifically, we are conducting a trial run of Qwen2VL and MiniCPM-V with the new input format, which requires additional parameters.
mm_data = {}
image_embeds = torch.load(...) # torch.Tensor of shape (num_images, image_feature_size, hidden_size of LM)
# For Qwen2VL, image_grid_thw is needed to calculate positional encoding.
mm_data['image'] = {
"image_embeds": image_embeds,
"image_grid_thw": image_grid_thw,
"image_grid_thw": torch.load(...) # torch.Tensor of shape (1, 3),
}
# For MiniCPM-V, image_size_list is needed to calculate details of the sliced image.
mm_data['image'] = {
"image_embeds": image_embeds,
"image_size_list": [image.size] # list of image sizes
}
outputs = llm.generate({
"prompt": prompt,

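Putting the new MiniCPM-V branch together, a minimal end-to-end sketch might look like the following. The checkpoint tag, prompt wording, and embedding file are illustrative assumptions; only the `image_embeds` and `image_size_list` keys come from this commit.

import torch
from PIL import Image
from vllm import LLM, SamplingParams

# Assumed checkpoint and prompt; adjust to your setup.
llm = LLM(model="openbmb/MiniCPM-V-2_6", trust_remote_code=True)
prompt = "(<image>./</image>)\nWhat is shown in this picture?"

image = Image.open("example.jpg")       # hypothetical source image
image_embeds = torch.load("embeds.pt")  # hypothetical file; shape
# (num_images, image_feature_size, hidden_size of the LM)

outputs = llm.generate({
    "prompt": prompt,
    "multi_modal_data": {
        "image": {
            "image_embeds": image_embeds,
            # MiniCPM-V needs the original (width, height) of each image
            # to reconstruct the sliced-image placeholder layout.
            "image_size_list": [image.size],
        },
    },
}, SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)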

@@ -24,8 +24,8 @@
import math
import re
from functools import partial
from typing import (Any, Callable, Iterable, List, Mapping, Optional, Tuple,
TypedDict)
from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional,
Tuple, TypedDict, Union)
import torch
import torch.types
@@ -65,10 +65,12 @@ _KEYS_TO_MODIFY_MAPPING = {
"llm.lm_head": "lm_head",
}
RawImageType = Union[Image.Image, torch.Tensor]
class MiniCPMVImageInput(TypedDict):
class MiniCPMVRawImageInput(TypedDict):
"""Input mapper input with auxiliary data for computing image bounds."""
image: Image.Image
image: RawImageType
# Image-bound token ids as 0-dim scalar tensors.
im_start_id: torch.Tensor
@@ -78,7 +80,8 @@ class MiniCPMVImageInput(TypedDict):
class MiniCPMVImagePixelInputs(TypedDict):
pixel_values: List[torch.Tensor]
type: Literal["pixel_values"]
data: List[torch.Tensor]
"""
Shape: `(batch_size * num_images, num_channels, height, width)`
@@ -101,6 +104,27 @@
"""
class MiniCPMVImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: torch.Tensor
"""
Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of the language model backbone.
"""
image_bounds: torch.Tensor
"""
Shape: `(batch_size * num_images, 2)`
This should be in `(start, stop)` format.
"""
MiniCPMVImageInputs = Union[MiniCPMVImagePixelInputs,
MiniCPMVImageEmbeddingInputs]
DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)
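`MiniCPMVImageInputs` is a tagged union: the literal `type` key tells downstream code which layout `data` carries. A standalone sketch of the dispatch pattern (the two TypedDicts above assumed in scope):

def summarize_image_inputs(inputs: MiniCPMVImageInputs) -> str:
    # Branch on the "type" discriminant, as get_embedding does further down.
    if inputs["type"] == "image_embeds":
        return f"precomputed features with shape {tuple(inputs['data'].shape)}"
    return f"{len(inputs['data'])} raw image tensor(s) still to be encoded"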
@@ -194,22 +218,22 @@ class Resampler2_5(BaseResampler):
def _build_image_input(ctx: InputContext,
image: Image.Image) -> MiniCPMVImageInput:
image: RawImageType) -> MiniCPMVRawImageInput:
tokenizer = cached_get_tokenizer(
ctx.model_config.tokenizer,
trust_remote_code=ctx.model_config.trust_remote_code)
if hasattr(tokenizer, "slice_start_id"):
return MiniCPMVImageInput(
return MiniCPMVRawImageInput(
image=image,
im_start_id=torch.tensor(tokenizer.im_start_id),
im_end_id=torch.tensor(tokenizer.im_end_id),
slice_start_id=torch.tensor(tokenizer.slice_start_id),
slice_end_id=torch.tensor(tokenizer.slice_end_id))
else:
return MiniCPMVImageInput(image=image,
im_start_id=torch.tensor(
tokenizer.im_start_id),
im_end_id=torch.tensor(tokenizer.im_end_id))
return MiniCPMVRawImageInput(
image=image,
im_start_id=torch.tensor(tokenizer.im_start_id),
im_end_id=torch.tensor(tokenizer.im_end_id))
def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
@@ -280,20 +304,25 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
pattern = "(<image>./</image>)"
images = multi_modal_data["image"]
if isinstance(images, Image.Image):
images = [images]
image_tags = re.findall(pattern, prompt)
if len(image_tags) == 0:
new_token_ids = token_ids
new_prompt = prompt
else:
if isinstance(images, dict):
image_size_list = images.get("image_size_list")
images = [images.get("image_embeds")]
else:
if isinstance(images, Image.Image):
images = [images]
image_size_list = [image.size for image in images]
text_chunks = prompt.split(pattern)
new_prompt_chunks: List[str] = []
for i in range(len(images)):
for i in range(len(image_size_list)):
new_prompt_chunks += [
text_chunks[i],
get_placeholder(images[i].size, i)
get_placeholder(image_size_list[i], i)
]
new_prompt_chunks.append(text_chunks[-1])
new_prompt = "".join(new_prompt_chunks)
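The rewrite above is a plain split-and-interleave on the literal image tag. The same transformation in isolation (placeholder text simplified; the real get_placeholder expands into the model's image/slice token layout per size and version):

pattern = "(<image>./</image>)"
prompt = "Compare (<image>./</image>) and (<image>./</image>), please."
image_size_list = [(448, 448), (672, 448)]  # illustrative (width, height) sizes

text_chunks = prompt.split(pattern)
new_prompt_chunks = []
for i, size in enumerate(image_size_list):
    # Stand-in for get_placeholder(size, i).
    new_prompt_chunks += [text_chunks[i], f"<image_{i}:{size[0]}x{size[1]}>"]
new_prompt_chunks.append(text_chunks[-1])
new_prompt = "".join(new_prompt_chunks)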
@@ -323,9 +352,15 @@ def input_mapper_for_minicpmv(ctx: InputContext, data: object):
if not isinstance(data, list):
raise ValueError(
    f"Image input must be a list of MiniCPMVRawImageInput, got {data}")
batch_data = image_processor \
.preprocess([img["image"] for img in data], return_tensors="pt") \
.data
if len(data) > 0 and isinstance(data[0]['image'], torch.Tensor):
batch_data = {
"image_embeds": data[0]['image'],
}
else:
batch_data = image_processor \
.preprocess([img["image"] for img in data], return_tensors="pt") \
.data
if len(data) > 0:
batch_data["im_start_id"] = data[0]["im_start_id"]
@@ -380,7 +415,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
def get_embedding(
self,
input_ids: torch.Tensor,
image_inputs: Optional[MiniCPMVImagePixelInputs],
image_inputs: Optional[MiniCPMVImageInputs],
) -> Tuple[torch.Tensor, torch.Tensor]:
vlm_embedding: torch.Tensor = self.llm.embed_tokens(input_ids)
if hasattr(self.config, "scale_emb"):
@@ -389,7 +424,12 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
if image_inputs is None: # No image
vision_hidden_states = torch.tensor([], device=input_ids.device)
else:
vision_hidden_states = self.get_vision_hidden_states(image_inputs)
if image_inputs["type"] == "image_embeds":
vision_hidden_states = (image_inputs["data"].type(
vlm_embedding.dtype).to(vlm_embedding.device))
else:
vision_hidden_states = self.get_vision_hidden_states(
image_inputs)
# See NOTE in _parse_and_validate_inputs
image_bounds = image_inputs["image_bounds"]
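In both branches, `image_bounds` then drives the merge of vision features into the token embeddings. A self-contained sketch of that indexing with toy values (bounds are half-open `(start, stop)` ranges expanded with torch.arange; the model applies an equivalent scatter):

import torch

hidden_size = 8
vlm_embedding = torch.zeros(9, hidden_size)         # 9 text-token embeddings
vision_hidden_states = torch.randn(4, hidden_size)  # features for 4 slots

# One image occupying token positions 3..6.
image_bounds = torch.tensor([[3, 7]])
image_indices = torch.cat(
    [torch.arange(start, stop) for start, stop in image_bounds.tolist()])

# Overwrite the placeholder rows with the vision features.
vlm_embedding[image_indices] = vision_hidden_states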
@@ -440,9 +480,23 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
self,
input_ids: torch.Tensor,
**kwargs: object,
) -> Optional[MiniCPMVImagePixelInputs]:
) -> Optional[MiniCPMVImageInputs]:
pixel_values = kwargs.pop("pixel_values", [])
tgt_sizes = kwargs.pop("tgt_sizes", [])
im_start_id = kwargs.pop("im_start_id", None)
im_end_id = kwargs.pop("im_end_id", None)
slice_start_id = kwargs.pop("slice_start_id", None)
slice_end_id = kwargs.pop("slice_end_id", None)
image_embeds = kwargs.pop("image_embeds", None)
if image_embeds is not None:
return MiniCPMVImageEmbeddingInputs(
image_bounds=self._get_image_bounds(input_ids, im_start_id,
im_end_id, slice_start_id,
slice_end_id),
data=image_embeds,
type="image_embeds",
)
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
@@ -477,10 +531,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
if len(pixel_values_flat) == 0:
return None
im_start_id = kwargs.pop("im_start_id", None)
im_end_id = kwargs.pop("im_end_id", None)
slice_start_id = kwargs.pop("slice_start_id", None)
slice_end_id = kwargs.pop("slice_end_id", None)
if im_start_id is None:
return None
@@ -488,8 +538,9 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
image_bounds=self._get_image_bounds(input_ids, im_start_id,
im_end_id, slice_start_id,
slice_end_id),
pixel_values=pixel_values_flat,
data=pixel_values_flat,
tgt_sizes=torch.stack(tgt_sizes_flat),
type="pixel_values",
)
def forward(
@@ -610,8 +661,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
) -> torch.Tensor:
raise NotImplementedError
def get_vision_hidden_states(
self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
def get_vision_hidden_states(self,
data: MiniCPMVImageInputs) -> torch.Tensor:
raise NotImplementedError
def is_default_weight_loading(self, name: str) -> bool:
@@ -705,9 +756,9 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
res.append(self.resampler(vision_embedding, tgt_size))
return torch.vstack(res)
def get_vision_hidden_states(
self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
pixel_values = data["pixel_values"]
def get_vision_hidden_states(self,
data: MiniCPMVImageInputs) -> torch.Tensor:
pixel_values = data["data"]
return self.get_vision_embedding(pixel_values)
@@ -793,9 +844,9 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
vision_embedding = self.resampler(vision_embedding, tgt_sizes)
return vision_embedding
def get_vision_hidden_states(
self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
pixel_values = data["pixel_values"]
def get_vision_hidden_states(self,
data: MiniCPMVImageInputs) -> torch.Tensor:
pixel_values = data["data"]
tgt_sizes = data["tgt_sizes"]
device = self.vpm.embeddings.position_embedding.weight.device
@@ -909,9 +960,9 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
)
return vision_embedding
def get_vision_hidden_states(
self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
pixel_values = data["pixel_values"]
def get_vision_hidden_states(self,
data: MiniCPMVImageInputs) -> torch.Tensor:
pixel_values = data["data"]
tgt_sizes = data["tgt_sizes"]
device = self.vpm.embeddings.position_embedding.weight.device