[Model] MiniCPM-V/O supports V1 (#15487)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung (committed by GitHub)
Date: 2025-03-27 21:07:29 +08:00
Commit: ac5bc615b0 (parent 8063dfc61a)
4 changed files with 573 additions and 594 deletions

docs/source/models/supported_models.md

@ -836,14 +836,14 @@ See [this page](#generative-models) for more information on how to use generativ
* `openbmb/MiniCPM-o-2_6`, etc.
* ✅︎
* ✅︎
*
* ✅︎
- * `MiniCPMV`
* MiniCPM-V
* T + I<sup>E+</sup> + V<sup>E+</sup>
* `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, etc.
* ✅︎
* ✅︎
*
* ✅︎
- * `MllamaForConditionalGeneration`
* Llama 3.2
* T + I<sup>+</sup>

vllm/model_executor/models/minicpmo.py

@ -23,8 +23,8 @@
# limitations under the License.
"""Inference-only MiniCPM-O model compatible with HuggingFace weights."""
from collections.abc import Iterable, Mapping, Sequence
from typing import (Any, Callable, Dict, Literal, Optional, Set, Tuple,
TypedDict, Union)
from typing import (Any, Callable, Literal, Optional, Set, Tuple, TypedDict,
Union)
import torch
from torch import nn
@ -42,8 +42,6 @@ from vllm.multimodal.parse import (AudioItem, AudioProcessorItems,
MultiModalDataParser)
from vllm.multimodal.processing import PromptReplacement, PromptUpdate
from vllm.multimodal.profiling import ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import flatten_2d_lists
from .minicpmv import (MiniCPMV2_6, MiniCPMVDummyInputsBuilder,
MiniCPMVMultiModalDataParser,
@ -51,13 +49,14 @@ from .minicpmv import (MiniCPMV2_6, MiniCPMVDummyInputsBuilder,
_minicpmv_field_config)
from .utils import (AutoWeightsLoader, cast_overflow_tensors, flatten_bn,
maybe_prefix)
from .vision import scatter_patch_features
CPU_DEVICE = torch.device("cpu")
class MiniCPMOAudioFeatureInputs(TypedDict):
type: Literal["audio_features"]
audio_features: torch.Tensor
audio_features: Union[torch.Tensor, list[torch.Tensor]]
"""
Shape: `(batch_size * num_audios * num_slices, num_channels, length)`
Slice here means chunk. Audio that is too long will be split into slices,
@ -65,37 +64,40 @@ class MiniCPMOAudioFeatureInputs(TypedDict):
Padding is used, therefore `audio_features` is a `torch.Tensor`.
"""
audio_feature_lens: torch.Tensor
audio_feature_lens: Union[torch.Tensor, list[torch.Tensor]]
"""
Shape: `(batch_size * num_audios * num_slices)`
Shape: `(batch_size * num_audios, num_slices)`
This should be the feature length of each audio slice,
which equals `audio_features.shape[-1]`.
"""
audio_bounds: torch.Tensor
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
Shape: `(batch_size * num_audios * num_slices, 2)`
A boolean mask indicating which audio embeddings correspond
to patch tokens.
This should be in `(start, stop)` format.
Shape: `(batch_size * num_audios, num_embeds)`
"""
class MiniCPMOAudioEmbeddingInputs(TypedDict):
type: Literal["audio_embeds"]
audio_embeds: torch.Tensor
audio_embeds: Union[torch.Tensor, list[torch.Tensor]]
"""
Shape: `(batch_size * num_images * num_slices, hidden_size)`
Shape: `(batch_size * num_audios, num_slices, hidden_size)`
`hidden_size` must match the hidden size of the language model backbone.
Length of each slice may vary, so pass it as a list
instead of a batched tensor.
"""
audio_bounds: torch.Tensor
"""
Shape: `(batch_size * num_audios * num_slices, 2)`
This should be in `(start, stop)` format.
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which audio embeddings correspond
to patch tokens.
Shape: `(batch_size * num_audios, num_embeds)`
"""
@ -104,11 +106,16 @@ MiniCPMOAudioInputs = Union[MiniCPMOAudioFeatureInputs,
def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]):
audio_features = hf_inputs.get("audio_features", torch.empty(0))
num_audios = len(audio_features)
return dict(
**_minicpmv_field_config(hf_inputs),
audio_features=MultiModalFieldConfig.batched("audio"),
audio_feature_lens=MultiModalFieldConfig.batched("audio"),
audio_embeds=MultiModalFieldConfig.batched("audio"),
audio_embed_is_patch=MultiModalFieldConfig.batched("audio"),
audio_token_id=MultiModalFieldConfig.shared("audio", num_audios),
)
@ -149,7 +156,7 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
audio_pattern = "(<audio>./</audio>)"
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None, "video": None, "audio": None}
return {**super().get_supported_mm_limits(), "audio": None}
def get_mm_max_tokens_per_item(
self,
@ -157,11 +164,25 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {
"image": self.get_max_image_tokens(),
"audio": self.get_max_audio_tokens(),
"video": self.get_max_video_tokens(seq_len),
**super().get_mm_max_tokens_per_item(seq_len, mm_counts),
"audio":
self.get_max_audio_tokens(),
}
def get_audio_placeholder(
self,
audio_lens: int,
chunk_input: bool = True,
chunk_length: int = 1,
) -> str:
hf_processor = self.get_hf_processor()
return hf_processor.get_audio_placeholder(
audio_lens,
chunk_input=chunk_input,
chunk_length=chunk_length,
)
def get_default_audio_pool_step(self) -> int:
return 2
@ -197,12 +218,8 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
max_videos = mm_config.get_limit_per_prompt("video")
max_audios = mm_config.get_limit_per_prompt("audio")
# count <image_idx></image_idx> tokens
# which are not in get_max_image_tokens
max_image_tokens = self.get_max_image_tokens(
) * max_images + 4 * max_images
max_audio_tokens = self.get_max_audio_tokens(
) * max_audios + 2 * max_audios
max_image_tokens = self.get_max_image_tokens() * max_images
max_audio_tokens = self.get_max_audio_tokens() * max_audios
max_total_frames = self.get_max_video_frames(seq_len -
max_image_tokens -
max_audio_tokens)
@ -224,20 +241,20 @@ class MiniCPMODummyInputsBuilder(
processor_inputs = super().get_dummy_processor_inputs(
seq_len, mm_counts)
mm_data = {
"image":
processor_inputs.mm_data["image"],
"video":
processor_inputs.mm_data["video"],
audio_prompt_texts = self.info.audio_pattern * num_audios
audio_mm_data = {
"audio":
self._get_dummy_audios(length=audio_len, num_audios=num_audios)
}
audio_prompt_texts = self.info.audio_pattern * num_audios
return ProcessorInputs(prompt_text=processor_inputs.prompt_text + \
audio_prompt_texts,
mm_data=mm_data)
return ProcessorInputs(
prompt_text=processor_inputs.prompt_text + audio_prompt_texts,
mm_data={
**processor_inputs.mm_data,
**audio_mm_data,
},
)
class MiniCPMOMultiModalProcessor(
@ -247,22 +264,17 @@ class MiniCPMOMultiModalProcessor(
return MiniCPMOMultiModalDataParser(
target_sr=self.info.get_default_audio_sampling_rate())
def get_audio_prompt_texts(self,
audio_lens: int,
chunk_input: bool = True,
chunk_length: int = 1) -> str:
return self.info.get_hf_processor().get_audio_placeholder(
audio_lens, chunk_input, chunk_length)
def get_special_tokens(self) -> Dict[str, torch.Tensor]:
tokenizer = self.info.get_tokenizer()
special_tokens = super().get_special_tokens()
if hasattr(tokenizer, "audio_start_id"):
special_tokens["audio_start_id"] = torch.tensor(
tokenizer.audio_start_id)
special_tokens["audio_end_id"] = torch.tensor(
tokenizer.audio_end_id)
return special_tokens
def get_audio_prompt_texts(
self,
audio_lens: int,
chunk_input: bool = True,
chunk_length: int = 1,
) -> str:
return self.info.get_audio_placeholder(
audio_lens,
chunk_input=chunk_input,
chunk_length=chunk_length,
)
def process_audios(
self,
@ -274,32 +286,65 @@ class MiniCPMOMultiModalProcessor(
parsed_audios = (self._get_data_parser().parse_mm_data({
"audio": audios
}).get_items("audio", AudioProcessorItems))
}).get_items("audio",
(MiniCPMOAudioEmbeddingItems, AudioProcessorItems)))
audio_inputs = self._base_call_hf_processor(
prompts=[self.info.audio_pattern] * len(parsed_audios),
mm_data={"audios": [[audio] for audio in parsed_audios]},
mm_kwargs={
**mm_kwargs, "chunk_input": True
},
out_keys={"audio_features", "audio_feature_lens"},
)
if isinstance(parsed_audios, MiniCPMOAudioEmbeddingItems):
audio_inputs = {}
# Avoid padding since we need the output for each audio to be
# independent of other audios for the cache to work correctly
unpadded_audio_features = [
feat[:, :feature_len] for feat, feature_len in zip(
audio_inputs["audio_features"],
audio_inputs["audio_feature_lens"],
audio_lens = [
self.info.get_audio_len_by_num_chunks(
sum(map(len,
parsed_audios.get(i)["audio_embeds"])))
for i in range(len(parsed_audios))
]
else:
audio_inputs = self._base_call_hf_processor(
prompts=[self.info.audio_pattern] * len(parsed_audios),
mm_data={"audios": [[audio] for audio in parsed_audios]},
mm_kwargs={
**mm_kwargs,
"chunk_input": True,
},
out_keys={"audio_features", "audio_feature_lens"},
)
# Avoid padding since we need the output for each audio to be
# independent of other audios for the cache to work correctly
unpadded_audio_features = [
feat[:, :feature_len] for feat, feature_len in zip(
audio_inputs["audio_features"],
audio_inputs["audio_feature_lens"],
)
]
audio_inputs["audio_features"] = unpadded_audio_features
audio_lens = [
parsed_audios.get_audio_length(i)
for i in range(len(parsed_audios))
]
audio_repl_features = [
self.get_audio_prompt_texts(audio_len) for audio_len in audio_lens
]
audio_inputs["audio_features"] = unpadded_audio_features
tokenizer = self.info.get_tokenizer()
audio_repls_feature_tokens = [
tokenizer.encode(audio_repl, add_special_tokens=False)
for audio_repl in audio_repl_features
]
embed_is_patch = [
self.get_embed_is_patch(audio_repl_tokens)
for audio_repl_tokens in audio_repls_feature_tokens
]
audio_inputs["audio_embed_is_patch"] = embed_is_patch
unk_token_id = tokenizer.get_vocab()["<unk>"]
audio_inputs["audio_token_id"] = torch.tensor(unk_token_id)
return audio_inputs
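    # --- Editor's illustrative sketch (not part of this commit or the vLLM API) ---
    # The unpadding step above slices each padded feature tensor back to its true
    # length, so every audio's output is independent of how the batch was padded.
    import torch

    padded = torch.randn(2, 80, 120)                  # batch padded to length 120
    lens = [120, 95]                                  # true length per audio
    unpadded = [feat[:, :n] for feat, n in zip(padded, lens)]
    print([u.shape for u in unpadded])                # [torch.Size([80, 120]), torch.Size([80, 95])]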
def get_placeholder_match_pattern(self) -> str:
return r"\(<(image|video|audio)>./</\1>\)"
def process_mm_inputs(
self,
mm_data: Mapping[str, object],
@ -331,8 +376,7 @@ class MiniCPMOMultiModalProcessor(
if isinstance(audios, MiniCPMOAudioEmbeddingItems):
single_audio_embeds = audios.get(item_idx)["audio_embeds"]
audio_len = self.info.get_audio_len_by_num_chunks(
sum(chunk_embeds.shape[0]
for chunk_embeds in single_audio_embeds))
sum(map(len, single_audio_embeds)))
else:
audio_len = audios.get_audio_length(item_idx)
@ -514,6 +558,8 @@ class MiniCPMO(MiniCPMV2_6):
self.apm = self.init_audio_module(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "apm"))
self.audio_token_id = None
def init_audio_module(self, *, vllm_config: VllmConfig, prefix: str = ""):
# Do not use parameters temporarily
audio_config = self.config.audio_config
@ -563,18 +609,30 @@ class MiniCPMO(MiniCPMV2_6):
return input_lengths_after_cnn, input_lengths_after_pooling
# Copied from HF repo of MiniCPM-o-2_6,
# designed for batched inputs and outputs
def get_audio_hidden_states(self, data: MiniCPMOAudioInputs,
chunk_length: int) -> list[torch.Tensor]:
wavforms = data.get(
"audio_features",
[]) # (bs, 80, frames) or [], multi audios need filled in advance
audio_feature_lens_raw = [data.get("audio_feature_lens",
[])] # list, [[x1, x2], [y1], [z1]]
def get_audio_hidden_states(
self, data: MiniCPMOAudioFeatureInputs) -> list[torch.Tensor]:
chunk_length = self.config.audio_chunk_length
if len(wavforms) == 0:
return []
# (bs, 80, frames) or []; multiple audios must be filled (padded) in advance
wavforms_raw = data["audio_features"]
if isinstance(wavforms_raw, list):
B = len(wavforms_raw)
C = wavforms_raw[0].shape[-2]
L = max(item.shape[-1] for item in wavforms_raw)
device = wavforms_raw[0].device
dtype = wavforms_raw[0].dtype
wavforms = torch.zeros((B, C, L), dtype=dtype, device=device)
for i, wavforms_item in enumerate(wavforms_raw):
L_item = wavforms_item.shape[-1]
wavforms[i, ..., :L_item] = wavforms_item
else:
wavforms = wavforms_raw
# list, [[x1, x2], [y1], [z1]]
audio_feature_lens_raw = data["audio_feature_lens"]
if isinstance(audio_feature_lens_raw, torch.Tensor):
audio_feature_lens_raw = audio_feature_lens_raw.unbind(0)
audio_feature_lens = torch.hstack(audio_feature_lens_raw)
batch_size, _, max_mel_seq_len = wavforms.shape
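# --- Editor's illustrative sketch (not part of this commit or the vLLM API) ---
# A minimal stand-alone version of the zero-padding collation above: a list of
# per-audio mel features of shape (C, L_i) is packed into one (B, C, L_max)
# tensor so the audio encoder can run on a single batch.
import torch

def pad_mel_features(feats: list[torch.Tensor]) -> torch.Tensor:
    B = len(feats)
    C = feats[0].shape[-2]
    L = max(f.shape[-1] for f in feats)
    out = torch.zeros((B, C, L), dtype=feats[0].dtype, device=feats[0].device)
    for i, f in enumerate(feats):
        out[i, ..., :f.shape[-1]] = f             # right-pad with zeros
    return out

print(pad_mel_features([torch.randn(80, 120), torch.randn(80, 95)]).shape)
# torch.Size([2, 80, 120])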
@ -625,159 +683,104 @@ class MiniCPMO(MiniCPMV2_6):
num_audio_tokens = feature_lens_after_pooling
final_audio_embeds = []
final_audio_embeds = list[torch.Tensor]()
idx = 0
for i in range(len(audio_feature_lens_raw)):
target_audio_embeds = []
target_audio_embeds_lst = list[torch.Tensor]()
for _ in range(len(audio_feature_lens_raw[i])):
target_audio_embeds.append(
target_audio_embeds_lst.append(
audio_embeds[idx, :num_audio_tokens[idx], :])
idx += 1
final_audio_embeds.append(target_audio_embeds)
final_audio_embeds.append(torch.cat(target_audio_embeds_lst))
return final_audio_embeds
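# --- Editor's illustrative sketch (not part of this commit or the vLLM API) ---
# The loop above trims each slice's embeddings to its valid token count and
# concatenates the slices that belong to the same audio, yielding one tensor
# per audio instead of a nested list of per-slice tensors.
import torch

def regroup_audio_embeds(
    audio_embeds: torch.Tensor,       # (total_slices, max_tokens, hidden)
    num_audio_tokens: list[int],      # valid tokens per slice
    slices_per_audio: list[int],      # number of slices per audio
) -> list[torch.Tensor]:
    out, idx = [], 0
    for n_slices in slices_per_audio:
        parts = []
        for _ in range(n_slices):
            parts.append(audio_embeds[idx, :num_audio_tokens[idx], :])
            idx += 1
        out.append(torch.cat(parts))  # (sum of valid tokens, hidden)
    return out

embeds = torch.randn(3, 10, 4)        # 3 slices, up to 10 tokens each
print([t.shape for t in regroup_audio_embeds(embeds, [7, 10, 5], [2, 1])])
# [torch.Size([17, 4]), torch.Size([5, 4])]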
def get_embedding_with_audios(self, vlm_embedding: torch.Tensor,
audio_inputs: MiniCPMOAudioInputs,
chunk_length: int) -> torch.Tensor:
device, dtype = vlm_embedding.device, vlm_embedding.dtype
if audio_inputs["type"] == "audio_embeds":
audio_embeddings = [
item.to(device=device, dtype=dtype)
for item in audio_inputs["audio_embeds"]
]
else:
audio_embeddings = self.get_audio_hidden_states(
audio_inputs, chunk_length)[0]
if audio_embeddings is None or len(audio_embeddings) == 0:
return vlm_embedding
audio_bounds = audio_inputs["audio_bounds"]
if self.config.chunk_input:
audio_embs = torch.cat(audio_embeddings, dim=0).to(device=device,
dtype=dtype)
audio_start_pos = 0
for bound in audio_bounds:
audio_len = bound[1] - bound[0]
vlm_embedding[bound[0]:bound[1]] = audio_embs[
audio_start_pos:audio_start_pos + audio_len, :]
audio_start_pos += audio_len
else:
for embs, bound in zip(audio_embeddings, audio_bounds):
audio_indices = torch.arange(bound[0],
bound[1],
dtype=torch.long).to(device)
if embs.shape[0] != len(audio_indices):
raise ValueError(
"Shape mismatch: Trying to assign embeddings "
f"of shape {embs.shape} "
f"to input indices of length {len(audio_indices)}")
vlm_embedding[audio_indices] = embs.to(dtype)
return vlm_embedding
def _get_audio_bounds(self, input_ids: torch.Tensor,
audio_start_id: torch.Tensor,
audio_end_id: torch.Tensor) -> torch.Tensor:
audio_start_tokens, = torch.where(input_ids == audio_start_id[0])
audio_start_tokens += 1
audio_end_tokens, = torch.where(input_ids == audio_end_id[0])
valid_audio_nums = max(len(audio_start_tokens), len(audio_end_tokens))
return torch.hstack([
audio_start_tokens[:valid_audio_nums].unsqueeze(-1),
audio_end_tokens[:valid_audio_nums].unsqueeze(-1)
])
def _parse_and_validate_audio_inputs(
self, input_ids: torch.Tensor,
**kwargs: object) -> Optional[MiniCPMOAudioInputs]:
def _parse_and_validate_audio_input(
self, **kwargs: object) -> Optional[MiniCPMOAudioInputs]:
audio_features = kwargs.pop("audio_features", None)
audio_embeds = kwargs.pop("audio_embeds", None)
if audio_features is None and audio_embeds is None:
return None
audio_start_id = kwargs.pop("audio_start_id")
if not isinstance(audio_start_id, torch.Tensor):
raise ValueError("Incorrect type of audio_start_id. "
f"Got type: {type(audio_start_id)}")
audio_token_id = kwargs.pop("audio_token_id")
if audio_token_id is not None:
assert isinstance(audio_token_id, torch.Tensor)
self.mm_token_ids.add(audio_token_id.flatten().unique().item())
audio_end_id = kwargs.pop("audio_end_id")
if not isinstance(audio_end_id, torch.Tensor):
raise ValueError("Incorrect type of audio_end_id. "
f"Got type: {type(audio_end_id)}")
audio_embed_is_patch = kwargs.pop("audio_embed_is_patch")
if not isinstance(audio_embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio_embed_is_patch. "
f"Got type: {type(audio_embed_is_patch)}")
audio_embed_is_patch = flatten_bn(audio_embed_is_patch)
if audio_embeds is not None:
if not isinstance(audio_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio_embeds. "
f"Got type: {type(audio_embeds)}")
audio_embeds_flat = flatten_bn(audio_embeds)
return MiniCPMOAudioEmbeddingInputs(
type="audio_embeds",
audio_embeds=flatten_bn(flatten_2d_lists(audio_embeds),
concat=True),
audio_bounds=self._get_audio_bounds(input_ids, audio_start_id,
audio_end_id),
audio_embeds=audio_embeds_flat,
embed_is_patch=audio_embed_is_patch,
)
if audio_features is not None:
if not isinstance(audio_features, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio_features. "
f"Got type: {type(audio_features)}")
if not isinstance(audio_features, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio_features. "
f"Got type: {type(audio_features)}")
audio_feature_lens = kwargs.pop("audio_feature_lens")
if not isinstance(audio_feature_lens, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio_feature_lens. "
f"Got type: {type(audio_feature_lens)}")
audio_feature_lens = kwargs.pop("audio_feature_lens")
if not isinstance(audio_feature_lens, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio_feature_lens. "
f"Got type: {type(audio_feature_lens)}")
return MiniCPMOAudioFeatureInputs(
type="audio_features",
audio_features=flatten_bn(audio_features, concat=True),
audio_feature_lens=flatten_bn(
flatten_2d_lists(audio_feature_lens), concat=True),
audio_bounds=self._get_audio_bounds(input_ids, audio_start_id,
audio_end_id),
)
audio_features_flat = flatten_bn(audio_features)
audio_feature_lens_flat = flatten_bn(audio_feature_lens)
raise AssertionError("This line should be unreachable.")
def _parse_and_validate_inputs(self, input_ids: torch.Tensor,
**kwargs: object):
image_inputs = self._parse_and_validate_image_inputs(
input_ids, **kwargs)
if not any("audio" in key for key in kwargs):
return image_inputs, None
audio_inputs = self._parse_and_validate_audio_inputs(
input_ids, **kwargs)
return image_inputs, audio_inputs
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: Any,
) -> torch.Tensor:
if intermediate_tensors is not None:
vlm_embeddings = None
else:
image_inputs, audio_inputs = \
self._parse_and_validate_inputs(input_ids, **kwargs)
vlm_embeddings = self.get_embedding_with_vision(
input_ids, image_inputs)
if audio_inputs is not None:
vlm_embeddings = self.get_embedding_with_audios(
vlm_embeddings, audio_inputs,
self.config.audio_chunk_length)
# always pass the input via `inputs_embeds`
# to make sure the computation graph is consistent
# for `torch.compile` integration
input_ids = None
output = self.llm.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=vlm_embeddings,
return MiniCPMOAudioFeatureInputs(
type="audio_features",
audio_features=audio_features_flat,
audio_feature_lens=audio_feature_lens_flat,
embed_is_patch=audio_embed_is_patch,
)
return output
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
modalities = super()._parse_and_validate_multimodal_inputs(**kwargs)
# Preserve the order of modalities if there are multiple of them
# from the order of kwargs.
for input_key in kwargs:
if input_key in ("audio_features",
"audio_embeds") and "audios" not in modalities:
modalities["audios"] = self._parse_and_validate_audio_input(
**kwargs)
return modalities
def _process_audio_input(
self,
audio_input: MiniCPMOAudioInputs,
) -> Union[torch.Tensor, list[torch.Tensor]]:
if audio_input["type"] == "audio_embeds":
return audio_input["audio_embeds"]
return self.get_audio_hidden_states(audio_input)
def _process_multimodal_inputs(self, modalities: dict):
multimodal_embeddings = super()._process_multimodal_inputs(modalities)
for modality in modalities:
if modality == "audios":
audio_input = modalities["audios"]
audio_features = self._process_audio_input(audio_input)
multimodal_embeddings += tuple(
scatter_patch_features(
audio_features,
audio_input["embed_is_patch"],
))
return multimodal_embeddings

vllm/model_executor/models/minicpmv.py

@ -23,17 +23,15 @@
# limitations under the License.
"""Inference-only MiniCPM-V model compatible with HuggingFace weights."""
import math
import re
from collections import defaultdict
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property, partial
from typing import (Any, Callable, Dict, List, Literal, Optional, Set, Tuple,
TypedDict, Union)
from typing import (Any, Callable, Literal, Optional, Set, Tuple, TypedDict,
Union)
import numpy as np
import torch
import torch.types
from PIL import Image
from torch import nn
from transformers import BatchFeature, PretrainedConfig
from typing_extensions import TypeVar
@ -50,9 +48,7 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputs, NestedTensors,
PlaceholderRange)
from vllm.multimodal.inputs import MultiModalFieldConfig, NestedTensors
from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem,
ImageProcessorItems, ImageSize,
ModalityData, ModalityDataItems,
@ -67,13 +63,11 @@ from vllm.sequence import IntermediateTensors
from vllm.utils import flatten_2d_lists
from .idefics2_vision_model import Idefics2VisionTransformer
from .interfaces import (SupportsLoRA, SupportsMultiModal, SupportsPP,
SupportsV0Only)
from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix
CPU_DEVICE = torch.device("cpu")
RawImageType = Union[Image.Image, torch.Tensor]
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
merge_multimodal_embeddings)
from .vision import scatter_patch_features, select_patch_features
class MiniCPMVImagePixelInputs(TypedDict):
@ -86,13 +80,6 @@ class MiniCPMVImagePixelInputs(TypedDict):
instead of a batched tensor.
"""
image_bounds: torch.Tensor
"""
Shape: `(batch_size * num_images * num_slices, 2)`
This should be in `(start, stop)` format.
"""
tgt_sizes: torch.Tensor
"""
Shape: `(batch_size * num_images * num_slices, 2)`
@ -100,23 +87,34 @@ class MiniCPMVImagePixelInputs(TypedDict):
This should be in `(height, width)` format.
"""
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
num_slices: torch.Tensor
"""Shape: `(batch_size * num_images)`"""
class MiniCPMVImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
image_embeds: torch.Tensor
image_embeds: Union[torch.Tensor, list[torch.Tensor]]
"""
Shape: `(batch_size * num_images * num_slices,
image_feature_size, hidden_size)`
Shape: `(batch_size * num_images, num_slices, hidden_size)`
`hidden_size` must match the hidden size of the language model backbone.
Embeddings may be passed as a list instead of a batched tensor.
"""
image_bounds: torch.Tensor
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
Shape: `(batch_size * num_images * num_slices, 2)`
A boolean mask indicating which image embeddings correspond
to patch tokens.
This should be in `(start, stop)` format.
Shape: `(batch_size * num_images, num_embeds)`
"""
@ -233,15 +231,25 @@ def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
def _minicpmv_field_config(hf_inputs: Mapping[str, torch.Tensor]):
pixel_values = hf_inputs.get("pixel_values", torch.empty(0))
num_images = len(pixel_values)
video_pixel_values = hf_inputs.get("video_pixel_values", torch.empty(0))
num_videos = len(video_pixel_values)
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
image_sizes=MultiModalFieldConfig.batched("image"),
tgt_sizes=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
video_pixel_values=MultiModalFieldConfig.batched("video"),
video_image_sizes=MultiModalFieldConfig.batched("video"),
video_tgt_sizes=MultiModalFieldConfig.batched("video"),
video_embeds=MultiModalFieldConfig.batched("video"),
video_embed_is_patch=MultiModalFieldConfig.batched("video"),
image_token_id=MultiModalFieldConfig.shared("image", num_images),
video_token_id=MultiModalFieldConfig.shared("video", num_videos),
)
@ -348,10 +356,11 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
return get_version_by_config(self.get_hf_config())
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
mm_limits = {"image": None}
if self.get_model_version() == (2, 6):
return {"image": None, "video": None}
else:
return {"image": None}
mm_limits["video"] = None
return mm_limits
def get_mm_max_tokens_per_item(
self,
@ -361,70 +370,79 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
mm_max_tokens = {"image": self.get_max_image_tokens()}
if self.get_model_version() == (2, 6):
mm_max_tokens["video"] = self.get_max_video_tokens(seq_len)
return mm_max_tokens
def get_slice_image_placeholder(
self,
image_size: ImageSize,
# For MiniCPM V/O 2.6
image_idx: int = 0,
max_slice_nums: Optional[int] = None,
use_image_id: bool = True,
) -> str:
image_processor = self.get_image_processor()
version = self.get_model_version()
if version == (2, 0) or version == (2, 5):
return image_processor.get_slice_image_placeholder(image_size)
return image_processor.get_slice_image_placeholder(
image_size,
image_idx=image_idx,
max_slice_nums=max_slice_nums,
use_image_id=use_image_id,
)
def get_num_image_tokens(
self,
image_size: ImageSize,
max_slice_nums: Optional[int] = None,
use_image_id: bool = True,
) -> int:
tokenizer = self.get_tokenizer()
image_placeholders = self.get_slice_image_placeholder(
image_size,
max_slice_nums=max_slice_nums,
use_image_id=use_image_id,
)
image_token_ids = tokenizer.encode(image_placeholders,
add_special_tokens=False)
return len(image_token_ids)
def get_max_image_tokens(self) -> int:
image_size = self.get_image_size_with_most_features()
return self.get_num_image_tokens(image_size)
def get_image_max_slice_num(self) -> int:
return getattr(self.get_hf_config(), "max_slice_num", 9)
def get_image_size_with_most_features(self) -> ImageSize:
image_size = getattr(self.get_hf_config(), "image_size", 448)
max_slice_num = self.get_image_max_slice_num()
return ImageSize(width=image_size, height=image_size * max_slice_num)
def get_max_video_frame_tokens(self) -> int:
frame_size = self.get_video_frame_size_with_most_features()
return self.get_num_image_tokens(frame_size,
self.get_video_max_slice_num())
return self.get_num_image_tokens(
frame_size,
max_slice_nums=self.get_video_max_slice_num(),
use_image_id=False,
)
def get_max_video_tokens(self, seq_len: int) -> int:
return self.get_max_video_frame_tokens(
) * self.get_num_frames_with_most_features(seq_len)
def get_slice_query_num(self) -> int:
hf_config = self.get_hf_config()
query_num = getattr(hf_config, "query_num", 64)
return query_num
def get_max_slice_num(self) -> int:
hf_config = self.get_hf_config()
max_slice_num = getattr(hf_config, "max_slice_num", 9)
return max_slice_num
def get_sliced_grid(self, image_size: ImageSize,
max_slice_num: int) -> Tuple[int, int]:
if self.get_model_version() == (2, 6):
slice_grid = self.get_image_processor().get_sliced_grid(
image_size, max_slice_num)
else:
slice_grid = self.get_image_processor().get_sliced_grid(image_size)
return slice_grid
def get_num_image_tokens(self, image_size: ImageSize,
max_slice_num: int) -> int:
slice_grid = self.get_sliced_grid(image_size, max_slice_num)
num_tokens = self.get_slice_query_num(
) + 2 # <image>(<unk> * query_num)</image>
if slice_grid is not None:
if self.get_model_version() == (2, 6):
num_additional_tokens = 0
else:
# <slice><image>(<unk> * query_num)</image></slice>
num_additional_tokens = 2
num_tokens += ((self.get_slice_query_num() + 2) \
* slice_grid[0] * slice_grid[1]) \
+ slice_grid[1] - 1 + num_additional_tokens
return num_tokens
def get_image_slice_nums(self, image_size: torch.Tensor,
max_slice_nums: int) -> int:
grid = self.get_sliced_grid(image_size, max_slice_nums)
return 1 if grid is None else grid[0] * grid[1] + 1
def get_max_image_tokens(self) -> int:
image_size = self.get_image_size_with_most_features()
return self.get_num_image_tokens(image_size, self.get_max_slice_num())
def get_image_size_with_most_features(self) -> ImageSize:
# Result in the max possible feature size (h:w = 9:1)
return self.get_default_image_sizes(self.get_max_slice_num())
def get_video_max_slice_num(self) -> int:
return 1
def get_video_frame_size_with_most_features(self) -> ImageSize:
return self.get_default_image_sizes(self.get_video_max_slice_num())
image_size = getattr(self.get_hf_config(), "image_size", 448)
max_slice_num = self.get_video_max_slice_num()
return ImageSize(width=image_size, height=image_size * max_slice_num)
def get_max_video_frames(self, max_tokens: int) -> int:
num_frame_tokens = self.get_max_video_frame_tokens()
@ -436,10 +454,7 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
max_images = mm_config.get_limit_per_prompt("image")
max_videos = mm_config.get_limit_per_prompt("video")
# count <image_idx></image_idx> tokens
# which are not in get_max_image_tokens
max_image_tokens = self.get_max_image_tokens(
) * max_images + 4 * max_images
max_image_tokens = self.get_max_image_tokens() * max_images
max_total_frames = self.get_max_video_frames(seq_len -
max_image_tokens)
@ -447,10 +462,6 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
return num_frames
def get_default_image_sizes(self, num_slices: int) -> ImageSize:
image_size = getattr(self.get_hf_config(), "image_size", 448)
return ImageSize(width=image_size, height=image_size * num_slices)
_I = TypeVar("_I",
bound=MiniCPMVProcessingInfo,
@ -499,42 +510,30 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
def _get_data_parser(self) -> MultiModalDataParser:
return MiniCPMVMultiModalDataParser()
def get_slice_image_placeholder(self, image_size: ImageSize,
**kwargs) -> str:
image_processor = self.info.get_image_processor()
version = self.info.get_model_version()
if version == (2, 0) or version == (2, 5):
return image_processor.get_slice_image_placeholder(image_size)
return image_processor.get_slice_image_placeholder(
image_size, **kwargs)
def get_image_prompt_texts(self,
image_size: ImageSize,
image_idx: int = 0) -> str:
return self.get_slice_image_placeholder(image_size,
image_idx=image_idx)
return self.info.get_slice_image_placeholder(
image_size,
image_idx=image_idx,
)
def get_video_prompt_texts(self, image_size: ImageSize,
num_frames: int) -> str:
return self.get_slice_image_placeholder(
return self.info.get_slice_image_placeholder(
image_size=image_size,
image_idx=0,
max_slice_nums=self.info.get_video_max_slice_num(),
use_image_id=False,
) * num_frames
def get_special_tokens(self) -> Dict[str, torch.Tensor]:
def get_embed_is_patch(
self,
input_ids: list[int],
) -> torch.Tensor:
tokenizer = self.info.get_tokenizer()
special_tokens = {
"im_start_id": tokenizer.im_start_id,
"im_end_id": tokenizer.im_end_id,
}
if hasattr(tokenizer, "slice_start_id"):
special_tokens["slice_start_id"] = tokenizer.slice_start_id
special_tokens["slice_end_id"] = tokenizer.slice_end_id
return {k: torch.tensor(v) for k, v in special_tokens.items()}
unk_token_id = tokenizer.get_vocab()["<unk>"]
return torch.tensor(input_ids) == unk_token_id
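    # --- Editor's illustrative sketch (not part of this commit or the vLLM API) ---
    # get_embed_is_patch above marks which tokens of an expanded placeholder string
    # are actual patch slots (<unk>) versus structural tokens such as <image> or
    # slice separators. A toy version with made-up token ids:
    import torch

    def build_embed_is_patch(placeholder_token_ids: list[int],
                             unk_token_id: int) -> torch.Tensor:
        # True where a patch embedding should be scattered in, False elsewhere.
        return torch.tensor(placeholder_token_ids) == unk_token_id

    # e.g. "<image><unk><unk></image>" tokenized to [5, 0, 0, 6] with <unk> id 0
    print(build_embed_is_patch([5, 0, 0, 6], unk_token_id=0))
    # tensor([False,  True,  True, False])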
def process_images(
self,
@ -546,14 +545,43 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
parsed_images = (self._get_data_parser().parse_mm_data({
"image": images
}).get_items("image", ImageProcessorItems))
}).get_items("image",
(MiniCPMVImageEmbeddingItems, ImageProcessorItems)))
return self._base_call_hf_processor(
prompts=[self.info.image_pattern] * len(parsed_images),
mm_data={"images": [[image] for image in parsed_images]},
mm_kwargs=mm_kwargs,
out_keys={"pixel_values", "image_sizes", "tgt_sizes"},
)
if isinstance(parsed_images, MiniCPMVImageEmbeddingItems):
image_inputs = {}
else:
image_inputs = self._base_call_hf_processor(
prompts=[self.info.image_pattern] * len(parsed_images),
mm_data={"images": [[image] for image in parsed_images]},
mm_kwargs=mm_kwargs,
out_keys={"pixel_values", "image_sizes", "tgt_sizes"},
)
image_sizes = [
parsed_images.get_image_size(i) for i in range(len(parsed_images))
]
image_repl_features = [
self.get_image_prompt_texts(size, idx)
for idx, size in enumerate(image_sizes)
]
tokenizer = self.info.get_tokenizer()
image_repls_feature_tokens = [
tokenizer.encode(image_repl, add_special_tokens=False)
for image_repl in image_repl_features
]
embed_is_patch = [
self.get_embed_is_patch(image_repl_tokens)
for image_repl_tokens in image_repls_feature_tokens
]
image_inputs["embed_is_patch"] = embed_is_patch
unk_token_id = tokenizer.get_vocab()["<unk>"]
image_inputs["image_token_id"] = torch.tensor(unk_token_id)
return image_inputs
def process_videos(
self,
@ -565,25 +593,55 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
parsed_videos = (self._get_data_parser().parse_mm_data({
"video": videos
}).get_items("video", VideoProcessorItems))
}).get_items("video",
(MiniCPMVVideoEmbeddingItems, VideoProcessorItems)))
max_slice_num = self.info.get_video_max_slice_num()
if isinstance(parsed_videos, MiniCPMVVideoEmbeddingItems):
video_inputs = {}
else:
video_inputs = self._base_call_hf_processor(
prompts=[
self.info.image_pattern * len(video)
for video in parsed_videos
],
mm_data={"images": list(parsed_videos)},
mm_kwargs={
**mm_kwargs,
"max_slice_nums":
self.info.get_video_max_slice_num(),
},
out_keys={"pixel_values", "image_sizes", "tgt_sizes"},
)
video_inputs = self._base_call_hf_processor(
prompts=[
self.info.image_pattern * len(video) for video in parsed_videos
],
mm_data={"images": list(parsed_videos)},
mm_kwargs={
**mm_kwargs, "max_slice_nums": max_slice_num
},
out_keys={"pixel_values", "image_sizes", "tgt_sizes"},
)
frame_sizes = [
parsed_videos.get_frame_size(i) for i in range(len(parsed_videos))
]
num_frames = [
parsed_videos.get_num_frames(i) for i in range(len(parsed_videos))
]
video_repl_features = [
self.get_video_prompt_texts(size, nframes)
for size, nframes in zip(frame_sizes, num_frames)
]
return {f"video_{k}": v for k, v in video_inputs.items()}
tokenizer = self.info.get_tokenizer()
video_repls_feature_tokens = [
tokenizer.encode(video_repl, add_special_tokens=False)
for video_repl in video_repl_features
]
def get_placeholder_match_pattern(self) -> str:
return r"\(<(image|video)>./</\1>\)"
embed_is_patch = [
self.get_embed_is_patch(video_repl_tokens)
for video_repl_tokens in video_repls_feature_tokens
]
video_inputs["embed_is_patch"] = embed_is_patch
video_inputs = {f"video_{k}": v for k, v in video_inputs.items()}
unk_token_id = tokenizer.get_vocab()["<unk>"]
video_inputs["video_token_id"] = torch.tensor(unk_token_id)
return video_inputs
def process_mm_inputs(
self,
@ -602,7 +660,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
mm_kwargs: Mapping[str, object],
*,
out_keys: set[str],
) -> Mapping[str, NestedTensors]:
) -> dict[str, NestedTensors]:
# This processor supports zipping prompt and mm_data together
if self.info.get_model_version() == (2, 6):
inputs = super()._call_hf_processor(
@ -635,14 +693,13 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
# Do not support combination inputs of images and videos for now
# Try to handle interleaved multimodal data
tokenizer = self.info.get_tokenizer()
input_ids = torch.tensor([tokenizer.encode(prompt)])
mm_inputs = self.process_mm_inputs(mm_data, mm_kwargs)
return BatchFeature({
"input_ids":
torch.tensor([tokenizer.encode(prompt)]),
"input_ids": input_ids,
**mm_inputs,
})
@ -701,39 +758,8 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
) -> Mapping[str, MultiModalFieldConfig]:
return _minicpmv_field_config(hf_inputs)
def apply(
self,
prompt: Union[str, List[int]],
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
return_mm_hashes: bool = False,
) -> MultiModalInputs:
if isinstance(prompt, list):
prompt = self.info.get_tokenizer().decode(prompt)
matches = re.findall(self.get_placeholder_match_pattern(), prompt)
mm_orders = {
f"{modality}_orders":
torch.tensor(
[index for index, m in enumerate(matches) if m == modality])
for modality in self.info.get_supported_mm_limits()
}
result = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
return_mm_hashes)
# Exclude <image_id>x</image_id> from placeholders
if "image" in result["mm_placeholders"] and \
self.info.get_model_version() == (2, 6):
result["mm_placeholders"]["image"] = [
PlaceholderRange(offset=p["offset"] + 3 + idx // 10,
length=p["length"] - 3 - idx // 10)
for idx, p in enumerate(result["mm_placeholders"]["image"])
]
result["mm_kwargs"].update(**mm_orders)
result["mm_kwargs"].update(**self.get_special_tokens())
return result
class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP,
SupportsV0Only):
class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
"""
The abstract class of MiniCPMV can only be inherited, but cannot be
instantiated.
@ -767,6 +793,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP,
prefix=maybe_prefix(
prefix, "resampler"))
self.mm_token_ids = set[int]()
self.make_empty_intermediate_tensors = (
self.llm.make_empty_intermediate_tensors)
@ -777,233 +804,191 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP,
return get_sampler()
def get_embedding_with_vision(
def _parse_and_validate_vision_input(
self,
input_ids: torch.Tensor,
image_inputs: Optional[MiniCPMVImageInputs],
) -> torch.Tensor:
vlm_embedding: torch.Tensor = self.llm.get_input_embeddings(input_ids)
if image_inputs is None:
return vlm_embedding
if image_inputs["type"] == "image_embeds":
vision_hidden_states = image_inputs["image_embeds"].to(
device=vlm_embedding.device,
dtype=vlm_embedding.dtype,
)
else:
vision_hidden_states = self.get_vision_hidden_states(image_inputs)
# See NOTE in _parse_and_validate_inputs
image_bounds = image_inputs["image_bounds"]
if len(image_bounds) > 0:
image_indices = torch.stack([
torch.arange(start, end, dtype=torch.long)
for start, end in image_bounds.tolist()
]).to(vlm_embedding.device)
vlm_embedding.scatter_(
0,
image_indices.view(-1, 1).repeat(1, vlm_embedding.shape[-1]),
vision_hidden_states.view(-1, vision_hidden_states.shape[-1]),
)
return vlm_embedding
def _get_image_bounds(
self,
input_ids: torch.Tensor,
im_start_id: torch.Tensor,
im_end_id: torch.Tensor,
slice_start_id: Optional[torch.Tensor] = None,
slice_end_id: Optional[torch.Tensor] = None) -> torch.Tensor:
# All the images in the batch should share the same special image
# bound token ids.
start_cond = input_ids == im_start_id[0]
end_cond = input_ids == im_end_id[0]
if slice_start_id is not None:
start_cond |= (input_ids == slice_start_id[0])
end_cond |= (input_ids == slice_end_id[0])
image_start_tokens, = torch.where(start_cond)
image_start_tokens += 1
image_end_tokens, = torch.where(end_cond)
valid_image_nums = max(len(image_start_tokens), len(image_end_tokens))
if valid_image_nums == 0:
return torch.zeros((0, 2), device=input_ids.device)
return torch.hstack([
image_start_tokens[:valid_image_nums].unsqueeze(-1),
image_end_tokens[:valid_image_nums].unsqueeze(-1),
])
def _parse_and_validate_image_inputs(
self,
input_ids: torch.Tensor,
modality: str,
**kwargs: object,
) -> Optional[MiniCPMVImageInputs]:
image_keys = {"pixel_values", "tgt_sizes"}
pixel_data = {
"image": {
key: kwargs.pop(key, None)
for key in image_keys
},
"video": {
key: kwargs.pop("video_" + key, None)
for key in image_keys
}
}
embed_data = {
"image": kwargs.pop("image_embeds", None),
"video": kwargs.pop("video_embeds", None),
}
pixel_values = kwargs.pop("pixel_values", None)
image_embeds = kwargs.pop("image_embeds", None)
all_pixel_data = [
v for vs in pixel_data.values() for v in vs.values()
if v is not None
]
all_embed_data = [v for v in embed_data.values() if v is not None]
if len(all_pixel_data) == 0 and len(all_embed_data) == 0:
if pixel_values is None and image_embeds is None:
return None
im_start_id = kwargs.pop("im_start_id")
if not isinstance(im_start_id, torch.Tensor):
raise ValueError("Incorrect type of im_start_id. "
f"Got type: {type(im_start_id)}")
image_token_id = kwargs.pop("image_token_id")
if image_token_id is not None:
assert isinstance(image_token_id, torch.Tensor)
self.mm_token_ids.add(image_token_id.flatten().unique().item())
im_end_id = kwargs.pop("im_end_id")
if not isinstance(im_end_id, torch.Tensor):
raise ValueError("Incorrect type of im_end_id. "
f"Got type: {type(im_end_id)}")
embed_is_patch = kwargs.pop("embed_is_patch")
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError(
f"Incorrect type of embed_is_patch for {modality=}. "
f"Got type: {type(embed_is_patch)}")
slice_start_id = kwargs.pop("slice_start_id", None)
if slice_start_id is not None and not isinstance(
slice_start_id, torch.Tensor):
raise ValueError("Incorrect type of slice_start_id. "
f"Got type: {type(slice_start_id)}")
embed_is_patch = flatten_bn(embed_is_patch)
slice_end_id = kwargs.pop("slice_end_id", None)
if slice_end_id is not None and not isinstance(slice_end_id,
torch.Tensor):
raise ValueError("Incorrect type of slice_end_id. "
f"Got type: {type(slice_end_id)}")
if image_embeds is not None:
if not isinstance(image_embeds, (torch.Tensor, list)):
raise ValueError(
f"Incorrect type of image_embeds for {modality=}. "
f"Got type: {type(image_embeds)}")
if len(all_embed_data) > 0:
if len(all_embed_data) > 1:
raise ValueError("Incorrect inputs for vision embeddings. "
"Image embeds and video embeds can not "
"exist simultaneously.")
vision_embeds, = all_embed_data
if not isinstance(vision_embeds, (torch.Tensor, list)):
raise ValueError(f"Incorrect type of vision_embeds. "
f"Got type: {type(vision_embeds)}")
image_embeds_flat = flatten_bn(image_embeds)
return MiniCPMVImageEmbeddingInputs(
type="image_embeds",
image_embeds=flatten_bn(flatten_2d_lists(vision_embeds),
concat=True),
image_bounds=self._get_image_bounds(input_ids, im_start_id,
im_end_id, slice_start_id,
slice_end_id),
image_embeds=image_embeds_flat,
embed_is_patch=embed_is_patch,
)
order_data = dict[str, Union[torch.Tensor, list[torch.Tensor]]]()
for modality in ("image", "video"):
modality_orders = kwargs.pop(f"{modality}_orders", None)
if modality_orders is not None:
if not isinstance(modality_orders, (torch.Tensor, list)):
raise ValueError(f"Incorrect type of {modality}_orders. "
f"Got type: {type(modality_orders)}")
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError(
f"Incorrect type of pixel_values for {modality=}. "
f"Got type: {type(pixel_values)}")
order_data[modality] = modality_orders
tgt_sizes = kwargs.pop("tgt_sizes")
if not isinstance(tgt_sizes, (torch.Tensor, list)):
raise ValueError(f"Incorrect type of tgt_sizes for {modality=}. "
f"Got type: {type(tgt_sizes)}")
batch_sizes = {
modality: len(modality_orders)
for modality, modality_orders in order_data.items()
}
unique_batch_sizes = set(batch_sizes.values())
assert len(unique_batch_sizes) == 1, (
f"Found inconsistent batch sizes: {batch_sizes}")
batch_size, = unique_batch_sizes
num_slices = [[len(p) for p in ps] for ps in pixel_values]
num_slices_flat = flatten_bn(torch.tensor(num_slices))
pixel_values_flat = list[torch.Tensor]()
tgt_sizes_flat = list[torch.Tensor]()
for b in range(batch_size):
mm_orders_b = [(idx_b.item(), modality)
for modality, modality_orders in order_data.items()
for idx_b in modality_orders[b]]
pixel_values_flat = flatten_bn(flatten_2d_lists(pixel_values))
tgt_sizes_flat = flatten_bn(flatten_2d_lists(tgt_sizes), concat=True)
for _, modality in sorted(mm_orders_b, key=lambda x: x[0]):
modality_pixel_data = pixel_data[modality]
modality_pixel_values = modality_pixel_data["pixel_values"]
if not isinstance(modality_pixel_values, (torch.Tensor, list)):
raise ValueError(
f"Incorrect type of pixel_values for {modality=}. "
f"Got type: {type(modality_pixel_values)}")
modality_tgt_sizes = modality_pixel_data["tgt_sizes"]
if not isinstance(modality_tgt_sizes, (torch.Tensor, list)):
raise ValueError(
f"Incorrect type of tgt_sizes for {modality=}. "
f"Got type: {type(modality_tgt_sizes)}")
pixel_values_flat += flatten_2d_lists(modality_pixel_values[b])
tgt_sizes_flat += flatten_2d_lists(modality_tgt_sizes[b])
# NOTE: Input IDs do not contain image tokens during memory profiling,
# so we allow them to be empty
if len(pixel_values_flat) != len(tgt_sizes_flat):
raise ValueError("Inconsistent flattened lengths, found: "
f"{len(pixel_values_flat)} vs. "
f"{len(tgt_sizes_flat)}")
if len(pixel_values_flat) == 0:
return None
return MiniCPMVImagePixelInputs(
type="pixel_values",
pixel_values=pixel_values_flat,
tgt_sizes=torch.stack(tgt_sizes_flat),
image_bounds=self._get_image_bounds(input_ids, im_start_id,
im_end_id, slice_start_id,
slice_end_id),
tgt_sizes=tgt_sizes_flat,
embed_is_patch=embed_is_patch,
num_slices=num_slices_flat,
)
def _parse_and_validate_inputs(self, input_ids: torch.Tensor,
**kwargs: object):
return self._parse_and_validate_image_inputs(input_ids, **kwargs)
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
modalities = {}
# Preserve the order of modalities if there are multiple of them
# from the order of kwargs.
for input_key in kwargs:
if input_key in ("pixel_values",
"image_embeds") and "images" not in modalities:
modalities["images"] = self._parse_and_validate_vision_input(
"images", **kwargs)
if input_key in ("video_pixel_values",
"video_embeds") and "videos" not in modalities:
def _image_key(video_key: str):
if video_key == "video_token_id":
return "image_token_id"
return video_key.removeprefix("video_")
modalities["videos"] = self._parse_and_validate_vision_input(
"videos", **{
_image_key(k): v
for k, v in kwargs.items()
})
return modalities
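    # --- Editor's illustrative sketch (not part of this commit or the vLLM API) ---
    # The loop above relies on Python's guarantee that **kwargs preserves insertion
    # order, so interleaved modalities are processed in the order they were passed:
    def collect(**kwargs: object) -> dict[str, object]:
        modalities = {}
        for key in kwargs:
            if key == "pixel_values" and "images" not in modalities:
                modalities["images"] = kwargs[key]
            if key == "video_pixel_values" and "videos" not in modalities:
                modalities["videos"] = kwargs[key]
        return modalities

    print(list(collect(video_pixel_values=1, pixel_values=2)))  # ['videos', 'images']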
def _process_vision_input(
self,
image_input: MiniCPMVImageInputs,
) -> Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor, ...]]:
if image_input["type"] == "image_embeds":
return image_input["image_embeds"]
image_features_flat = self.get_vision_hidden_states(image_input)
# Reconstruct the batch dimension
return image_features_flat.split(image_input["num_slices"].tolist())
def _process_multimodal_inputs(self, modalities: dict):
# The resulting multimodal_embeddings is a tuple of tensors, with each
# tensor corresponding to a multimodal data item (image or video).
multimodal_embeddings: tuple[torch.Tensor, ...] = ()
# NOTE: It is important to iterate over the keys in this dictionary
# to preserve the order of the modalities.
for modality in modalities:
if modality == "images":
image_input = modalities["images"]
image_features = self._process_vision_input(image_input)
multimodal_embeddings += tuple(
scatter_patch_features(
image_features,
image_input["embed_is_patch"],
))
if modality == "videos":
video_input = modalities["videos"]
video_features = self._process_vision_input(video_input)
multimodal_embeddings += tuple(
scatter_patch_features(
video_features,
video_input["embed_is_patch"],
))
return multimodal_embeddings
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities:
return None
return self._process_multimodal_inputs(modalities)
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.llm.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
assert len(self.mm_token_ids) > 0
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
select_patch_features(multimodal_embeddings),
list(self.mm_token_ids),
)
return inputs_embeds
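    # --- Editor's illustrative sketch (not part of this commit or the vLLM API) ---
    # A toy version of the scatter-by-mask idea behind scatter_patch_features /
    # merge_multimodal_embeddings: patch features are written only into the
    # placeholder positions flagged by embed_is_patch.
    import torch

    def scatter_by_mask(patch_feats: torch.Tensor,
                        embed_is_patch: torch.Tensor) -> torch.Tensor:
        out = torch.zeros(embed_is_patch.numel(), patch_feats.shape[-1],
                          dtype=patch_feats.dtype)
        out[embed_is_patch] = patch_feats   # one feature row per True position
        return out

    mask = torch.tensor([False, True, True, False])
    print(scatter_by_mask(torch.randn(2, 8), mask).shape)  # torch.Size([4, 8])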
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: Any,
) -> torch.Tensor:
if intermediate_tensors is not None:
vlm_embeddings = None
else:
image_inputs = \
self._parse_and_validate_inputs(input_ids, **kwargs)
vlm_embeddings = self.get_embedding_with_vision(
input_ids, image_inputs)
inputs_embeds = None
# always pass the input via `inputs_embeds`
# to make sure the computation graph is consistent
# for `torch.compile` integration
input_ids = None
# NOTE: In v1, inputs_embeds is always generated at model runner from
# `get_multimodal_embeddings` and `get_input_embeddings`, this
# condition is only for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
output = self.llm.model(
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
input_ids = None
hidden_states = self.llm.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=vlm_embeddings,
inputs_embeds=inputs_embeds,
)
return output
return hidden_states
def compute_logits(
self,
@ -1105,9 +1090,6 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
return model
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_tokens(input_ids)
def init_resampler(self,
embed_dim: int,
vision_dim: int,

vllm/model_executor/models/molmo.py

@ -92,8 +92,8 @@ class MolmoImageInputs(TypedDict):
Shape: `(batch_size * num_images, num_embeds)`
"""
num_crops: Union[torch.Tensor, list[torch.Tensor]]
"""Shape: `(batch_size, num_images)`"""
num_crops: torch.Tensor
"""Shape: `(batch_size * num_images)`"""
@dataclass
@ -1492,6 +1492,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
self.img_patch_id = img_patch_id.flatten().unique().item()
embed_is_patch = flatten_bn(embed_is_patch)
num_crops = flatten_bn(num_crops, concat=True)
return MolmoImageInputs(
images=images,
@ -1510,31 +1511,24 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
feat_is_patch = image_input["feat_is_patch"]
num_crops = image_input["num_crops"]
if isinstance(images, list):
# Call the vision backbone on the whole batch at once
images_flat = flatten_bn(images, concat=True)
image_masks_flat = (None if image_masks is None else flatten_bn(
image_masks, concat=True))
# Call the vision backbone on the whole batch at once
images_flat = flatten_bn(images, concat=True)
image_masks_flat = (None if image_masks is None else flatten_bn(
image_masks, concat=True))
feat_is_patch_flat = flatten_bn(feat_is_patch, concat=True)
image_features_flat = self.vision_backbone(
images=images_flat.unsqueeze(0),
image_masks=(None if image_masks_flat is None else
image_masks_flat.unsqueeze(0)),
).squeeze(0)
# Reconstruct the batch dimension
num_crops_per_image = [nc.sum().item() for nc in num_crops]
image_features = image_features_flat.split(num_crops_per_image)
else:
image_features = self.vision_backbone(
images=images,
image_masks=image_masks,
)
image_features_flat = self.vision_backbone(
images=images_flat.unsqueeze(0),
image_masks=(None if image_masks_flat is None else
image_masks_flat.unsqueeze(0)),
).squeeze(0)
# Only the features corresponding to patch tokens are relevant
return [
feats[f_is_patch]
for feats, f_is_patch in zip(image_features, feat_is_patch)
feats[f_is_patch] for feats, f_is_patch in zip(
image_features_flat.split(num_crops.tolist()),
feat_is_patch_flat.split(num_crops.tolist()),
)
]
def get_multimodal_embeddings(
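# --- Editor's illustrative sketch (not part of this commit or the vLLM API) ---
# The Molmo change above follows a flatten-then-split pattern: concatenate all
# crops, run the vision backbone once, then split the output back into
# per-image chunks using the per-image crop counts.
import torch

def toy_backbone(x: torch.Tensor) -> torch.Tensor:
    return x.mean(dim=-1, keepdim=True)           # stand-in feature extractor

crops_per_image = [torch.randn(3, 16), torch.randn(2, 16)]   # 3 and 2 crops
num_crops = torch.tensor([c.shape[0] for c in crops_per_image])

flat = torch.cat(crops_per_image, dim=0)          # (5, 16): one batched call
per_image = toy_backbone(flat).split(num_crops.tolist())
print([f.shape for f in per_image])               # [torch.Size([3, 1]), torch.Size([2, 1])]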