mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Model] MiniCPM-V/O supports V1 (#15487)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@ -836,14 +836,14 @@ See [this page](#generative-models) for more information on how to use generativ
|
||||
* `openbmb/MiniCPM-o-2_6`, etc.
|
||||
* ✅︎
|
||||
* ✅︎
|
||||
*
|
||||
* ✅︎
|
||||
- * `MiniCPMV`
|
||||
* MiniCPM-V
|
||||
* T + I<sup>E+</sup> + V<sup>E+</sup>
|
||||
* `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, etc.
|
||||
* ✅︎
|
||||
* ✅︎
|
||||
*
|
||||
* ✅︎
|
||||
- * `MllamaForConditionalGeneration`
|
||||
* Llama 3.2
|
||||
* T + I<sup>+</sup>
|
||||
|
@ -23,8 +23,8 @@
|
||||
# limitations under the License.
|
||||
"""Inference-only MiniCPM-O model compatible with HuggingFace weights."""
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from typing import (Any, Callable, Dict, Literal, Optional, Set, Tuple,
|
||||
TypedDict, Union)
|
||||
from typing import (Any, Callable, Literal, Optional, Set, Tuple, TypedDict,
|
||||
Union)
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -42,8 +42,6 @@ from vllm.multimodal.parse import (AudioItem, AudioProcessorItems,
|
||||
MultiModalDataParser)
|
||||
from vllm.multimodal.processing import PromptReplacement, PromptUpdate
|
||||
from vllm.multimodal.profiling import ProcessorInputs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import flatten_2d_lists
|
||||
|
||||
from .minicpmv import (MiniCPMV2_6, MiniCPMVDummyInputsBuilder,
|
||||
MiniCPMVMultiModalDataParser,
|
||||
@ -51,13 +49,14 @@ from .minicpmv import (MiniCPMV2_6, MiniCPMVDummyInputsBuilder,
|
||||
_minicpmv_field_config)
|
||||
from .utils import (AutoWeightsLoader, cast_overflow_tensors, flatten_bn,
|
||||
maybe_prefix)
|
||||
from .vision import scatter_patch_features
|
||||
|
||||
CPU_DEVICE = torch.device("cpu")
|
||||
|
||||
|
||||
class MiniCPMOAudioFeatureInputs(TypedDict):
|
||||
type: Literal["audio_features"]
|
||||
audio_features: torch.Tensor
|
||||
audio_features: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""
|
||||
Shape: `(batch_size * num_audios * num_slices, num_channels, length)`
|
||||
Slice here means chunk. Audio that is too long will be split into slices,
|
||||
@ -65,37 +64,40 @@ class MiniCPMOAudioFeatureInputs(TypedDict):
|
||||
Padding is used therefore `audio_features` is `torch.Tensor`.
|
||||
"""
|
||||
|
||||
audio_feature_lens: torch.Tensor
|
||||
audio_feature_lens: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""
|
||||
Shape: `(batch_size * num_audios * num_slices)`
|
||||
Shape: `(batch_size * num_audios, num_slices)`
|
||||
|
||||
This should be feature length of each audio slice,
|
||||
which equals to `audio_features.shape[-1]`
|
||||
"""
|
||||
|
||||
audio_bounds: torch.Tensor
|
||||
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""
|
||||
Shape: `(batch_size * num_audios * num_slices, 2)`
|
||||
A boolean mask indicating which audio embeddings correspond
|
||||
to patch tokens.
|
||||
|
||||
This should be in `(start, stop)` format.
|
||||
Shape: `(batch_size * num_audios, num_embeds)`
|
||||
"""
|
||||
|
||||
|
||||
class MiniCPMOAudioEmbeddingInputs(TypedDict):
|
||||
type: Literal["audio_embeds"]
|
||||
audio_embeds: torch.Tensor
|
||||
audio_embeds: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""
|
||||
Shape: `(batch_size * num_images * num_slices, hidden_size)`
|
||||
Shape: `(batch_size * num_audios, num_slices, hidden_size)`
|
||||
|
||||
`hidden_size` must match the hidden size of language model backbone.
|
||||
instead of a batched tensor.
|
||||
Length of each slice may vary, so pass it as a list.
|
||||
"""
|
||||
audio_bounds: torch.Tensor
|
||||
"""
|
||||
Shape: `(batch_size * num_audios * num_slices, 2)`
|
||||
|
||||
This should be in `(start, stop)` format.
|
||||
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""
|
||||
A boolean mask indicating which audio embeddings correspond
|
||||
to patch tokens.
|
||||
|
||||
Shape: `(batch_size * num_audios, num_embeds)`
|
||||
"""
|
||||
|
||||
|
||||
@ -104,11 +106,16 @@ MiniCPMOAudioInputs = Union[MiniCPMOAudioFeatureInputs,
|
||||
|
||||
|
||||
def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]):
|
||||
audio_features = hf_inputs.get("audio_features", torch.empty(0))
|
||||
num_audios = len(audio_features)
|
||||
|
||||
return dict(
|
||||
**_minicpmv_field_config(hf_inputs),
|
||||
audio_features=MultiModalFieldConfig.batched("audio"),
|
||||
audio_feature_lens=MultiModalFieldConfig.batched("audio"),
|
||||
audio_embeds=MultiModalFieldConfig.batched("audio"),
|
||||
audio_embed_is_patch=MultiModalFieldConfig.batched("audio"),
|
||||
audio_token_id=MultiModalFieldConfig.shared("audio", num_audios),
|
||||
)
|
||||
|
||||
|
||||
@ -149,7 +156,7 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
|
||||
audio_pattern = "(<audio>./</audio>)"
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
return {"image": None, "video": None, "audio": None}
|
||||
return {**super().get_supported_mm_limits(), "audio": None}
|
||||
|
||||
def get_mm_max_tokens_per_item(
|
||||
self,
|
||||
@ -157,11 +164,25 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> Mapping[str, int]:
|
||||
return {
|
||||
"image": self.get_max_image_tokens(),
|
||||
"audio": self.get_max_audio_tokens(),
|
||||
"video": self.get_max_video_tokens(seq_len),
|
||||
**super().get_mm_max_tokens_per_item(seq_len, mm_counts),
|
||||
"audio":
|
||||
self.get_max_audio_tokens(),
|
||||
}
|
||||
|
||||
def get_audio_placeholder(
|
||||
self,
|
||||
audio_lens: int,
|
||||
chunk_input: bool = True,
|
||||
chunk_length: int = 1,
|
||||
) -> str:
|
||||
hf_processor = self.get_hf_processor()
|
||||
|
||||
return hf_processor.get_audio_placeholder(
|
||||
audio_lens,
|
||||
chunk_input=chunk_input,
|
||||
chunk_length=chunk_length,
|
||||
)
|
||||
|
||||
def get_default_audio_pool_step(self) -> int:
|
||||
return 2
|
||||
|
||||
@ -197,12 +218,8 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
|
||||
max_videos = mm_config.get_limit_per_prompt("video")
|
||||
max_audios = mm_config.get_limit_per_prompt("audio")
|
||||
|
||||
# count <image_idx></image_idx> tokens
|
||||
# which are not in get_max_image_tokens
|
||||
max_image_tokens = self.get_max_image_tokens(
|
||||
) * max_images + 4 * max_images
|
||||
max_audio_tokens = self.get_max_audio_tokens(
|
||||
) * max_audios + 2 * max_audios
|
||||
max_image_tokens = self.get_max_image_tokens() * max_images
|
||||
max_audio_tokens = self.get_max_audio_tokens() * max_audios
|
||||
max_total_frames = self.get_max_video_frames(seq_len -
|
||||
max_image_tokens -
|
||||
max_audio_tokens)
|
||||
@ -224,20 +241,20 @@ class MiniCPMODummyInputsBuilder(
|
||||
|
||||
processor_inputs = super().get_dummy_processor_inputs(
|
||||
seq_len, mm_counts)
|
||||
mm_data = {
|
||||
"image":
|
||||
processor_inputs.mm_data["image"],
|
||||
"video":
|
||||
processor_inputs.mm_data["video"],
|
||||
|
||||
audio_prompt_texts = self.info.audio_pattern * num_audios
|
||||
audio_mm_data = {
|
||||
"audio":
|
||||
self._get_dummy_audios(length=audio_len, num_audios=num_audios)
|
||||
}
|
||||
|
||||
audio_prompt_texts = self.info.audio_pattern * num_audios
|
||||
|
||||
return ProcessorInputs(prompt_text=processor_inputs.prompt_text + \
|
||||
audio_prompt_texts,
|
||||
mm_data=mm_data)
|
||||
return ProcessorInputs(
|
||||
prompt_text=processor_inputs.prompt_text + audio_prompt_texts,
|
||||
mm_data={
|
||||
**processor_inputs.mm_data,
|
||||
**audio_mm_data,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class MiniCPMOMultiModalProcessor(
|
||||
@ -247,22 +264,17 @@ class MiniCPMOMultiModalProcessor(
|
||||
return MiniCPMOMultiModalDataParser(
|
||||
target_sr=self.info.get_default_audio_sampling_rate())
|
||||
|
||||
def get_audio_prompt_texts(self,
|
||||
audio_lens: int,
|
||||
chunk_input: bool = True,
|
||||
chunk_length: int = 1) -> str:
|
||||
return self.info.get_hf_processor().get_audio_placeholder(
|
||||
audio_lens, chunk_input, chunk_length)
|
||||
|
||||
def get_special_tokens(self) -> Dict[str, torch.Tensor]:
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
special_tokens = super().get_special_tokens()
|
||||
if hasattr(tokenizer, "audio_start_id"):
|
||||
special_tokens["audio_start_id"] = torch.tensor(
|
||||
tokenizer.audio_start_id)
|
||||
special_tokens["audio_end_id"] = torch.tensor(
|
||||
tokenizer.audio_end_id)
|
||||
return special_tokens
|
||||
def get_audio_prompt_texts(
|
||||
self,
|
||||
audio_lens: int,
|
||||
chunk_input: bool = True,
|
||||
chunk_length: int = 1,
|
||||
) -> str:
|
||||
return self.info.get_audio_placeholder(
|
||||
audio_lens,
|
||||
chunk_input=chunk_input,
|
||||
chunk_length=chunk_length,
|
||||
)
|
||||
|
||||
def process_audios(
|
||||
self,
|
||||
@ -274,32 +286,65 @@ class MiniCPMOMultiModalProcessor(
|
||||
|
||||
parsed_audios = (self._get_data_parser().parse_mm_data({
|
||||
"audio": audios
|
||||
}).get_items("audio", AudioProcessorItems))
|
||||
}).get_items("audio",
|
||||
(MiniCPMOAudioEmbeddingItems, AudioProcessorItems)))
|
||||
|
||||
audio_inputs = self._base_call_hf_processor(
|
||||
prompts=[self.info.audio_pattern] * len(parsed_audios),
|
||||
mm_data={"audios": [[audio] for audio in parsed_audios]},
|
||||
mm_kwargs={
|
||||
**mm_kwargs, "chunk_input": True
|
||||
},
|
||||
out_keys={"audio_features", "audio_feature_lens"},
|
||||
)
|
||||
if isinstance(parsed_audios, MiniCPMOAudioEmbeddingItems):
|
||||
audio_inputs = {}
|
||||
|
||||
# Avoid padding since we need the output for each audio to be
|
||||
# independent of other audios for the cache to work correctly
|
||||
unpadded_audio_features = [
|
||||
feat[:, :feature_len] for feat, feature_len in zip(
|
||||
audio_inputs["audio_features"],
|
||||
audio_inputs["audio_feature_lens"],
|
||||
audio_lens = [
|
||||
self.info.get_audio_len_by_num_chunks(
|
||||
sum(map(len,
|
||||
parsed_audios.get(i)["audio_embeds"])))
|
||||
for i in range(len(parsed_audios))
|
||||
]
|
||||
else:
|
||||
audio_inputs = self._base_call_hf_processor(
|
||||
prompts=[self.info.audio_pattern] * len(parsed_audios),
|
||||
mm_data={"audios": [[audio] for audio in parsed_audios]},
|
||||
mm_kwargs={
|
||||
**mm_kwargs,
|
||||
"chunk_input": True,
|
||||
},
|
||||
out_keys={"audio_features", "audio_feature_lens"},
|
||||
)
|
||||
|
||||
# Avoid padding since we need the output for each audio to be
|
||||
# independent of other audios for the cache to work correctly
|
||||
unpadded_audio_features = [
|
||||
feat[:, :feature_len] for feat, feature_len in zip(
|
||||
audio_inputs["audio_features"],
|
||||
audio_inputs["audio_feature_lens"],
|
||||
)
|
||||
]
|
||||
audio_inputs["audio_features"] = unpadded_audio_features
|
||||
|
||||
audio_lens = [
|
||||
parsed_audios.get_audio_length(i)
|
||||
for i in range(len(parsed_audios))
|
||||
]
|
||||
|
||||
audio_repl_features = [
|
||||
self.get_audio_prompt_texts(audio_len) for audio_len in audio_lens
|
||||
]
|
||||
audio_inputs["audio_features"] = unpadded_audio_features
|
||||
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
audio_repls_feature_tokens = [
|
||||
tokenizer.encode(audio_repl, add_special_tokens=False)
|
||||
for audio_repl in audio_repl_features
|
||||
]
|
||||
|
||||
embed_is_patch = [
|
||||
self.get_embed_is_patch(audio_repl_tokens)
|
||||
for audio_repl_tokens in audio_repls_feature_tokens
|
||||
]
|
||||
audio_inputs["audio_embed_is_patch"] = embed_is_patch
|
||||
|
||||
unk_token_id = tokenizer.get_vocab()["<unk>"]
|
||||
audio_inputs["audio_token_id"] = torch.tensor(unk_token_id)
|
||||
|
||||
return audio_inputs
|
||||
|
||||
def get_placeholder_match_pattern(self) -> str:
|
||||
return r"\(<(image|video|audio)>./</\1>\)"
|
||||
|
||||
def process_mm_inputs(
|
||||
self,
|
||||
mm_data: Mapping[str, object],
|
||||
@ -331,8 +376,7 @@ class MiniCPMOMultiModalProcessor(
|
||||
if isinstance(audios, MiniCPMOAudioEmbeddingItems):
|
||||
single_audio_embeds = audios.get(item_idx)["audio_embeds"]
|
||||
audio_len = self.info.get_audio_len_by_num_chunks(
|
||||
sum(chunk_embeds.shape[0]
|
||||
for chunk_embeds in single_audio_embeds))
|
||||
sum(map(len, single_audio_embeds)))
|
||||
else:
|
||||
audio_len = audios.get_audio_length(item_idx)
|
||||
|
||||
@ -514,6 +558,8 @@ class MiniCPMO(MiniCPMV2_6):
|
||||
self.apm = self.init_audio_module(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "apm"))
|
||||
|
||||
self.audio_token_id = None
|
||||
|
||||
def init_audio_module(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
# Do not use parameters temporarily
|
||||
audio_config = self.config.audio_config
|
||||
@ -563,18 +609,30 @@ class MiniCPMO(MiniCPMV2_6):
|
||||
|
||||
return input_lengths_after_cnn, input_lengths_after_pooling
|
||||
|
||||
# Copied from HF repo of MiniCPM-o-2_6,
|
||||
# designed for batched inputs and outputs
|
||||
def get_audio_hidden_states(self, data: MiniCPMOAudioInputs,
|
||||
chunk_length: int) -> list[torch.Tensor]:
|
||||
wavforms = data.get(
|
||||
"audio_features",
|
||||
[]) # (bs, 80, frames) or [], multi audios need filled in advance
|
||||
audio_feature_lens_raw = [data.get("audio_feature_lens",
|
||||
[])] # list, [[x1, x2], [y1], [z1]]
|
||||
def get_audio_hidden_states(
|
||||
self, data: MiniCPMOAudioFeatureInputs) -> list[torch.Tensor]:
|
||||
chunk_length = self.config.audio_chunk_length
|
||||
|
||||
if len(wavforms) == 0:
|
||||
return []
|
||||
# (bs, 80, frames) or [], multi audios need filled in advance
|
||||
wavforms_raw = data["audio_features"]
|
||||
if isinstance(wavforms_raw, list):
|
||||
B = len(wavforms_raw)
|
||||
C = wavforms_raw[0].shape[-2]
|
||||
L = max(item.shape[-1] for item in wavforms_raw)
|
||||
device = wavforms_raw[0].device
|
||||
dtype = wavforms_raw[0].dtype
|
||||
|
||||
wavforms = torch.zeros((B, C, L), dtype=dtype, device=device)
|
||||
for i, wavforms_item in enumerate(wavforms_raw):
|
||||
L_item = wavforms_item.shape[-1]
|
||||
wavforms[i, ..., :L_item] = wavforms_item
|
||||
else:
|
||||
wavforms = wavforms_raw
|
||||
|
||||
# list, [[x1, x2], [y1], [z1]]
|
||||
audio_feature_lens_raw = data["audio_feature_lens"]
|
||||
if isinstance(audio_feature_lens_raw, torch.Tensor):
|
||||
audio_feature_lens_raw = audio_feature_lens_raw.unbind(0)
|
||||
|
||||
audio_feature_lens = torch.hstack(audio_feature_lens_raw)
|
||||
batch_size, _, max_mel_seq_len = wavforms.shape
|
||||
@ -625,159 +683,104 @@ class MiniCPMO(MiniCPMV2_6):
|
||||
|
||||
num_audio_tokens = feature_lens_after_pooling
|
||||
|
||||
final_audio_embeds = []
|
||||
final_audio_embeds = list[torch.Tensor]()
|
||||
idx = 0
|
||||
for i in range(len(audio_feature_lens_raw)):
|
||||
target_audio_embeds = []
|
||||
target_audio_embeds_lst = list[torch.Tensor]()
|
||||
for _ in range(len(audio_feature_lens_raw[i])):
|
||||
target_audio_embeds.append(
|
||||
target_audio_embeds_lst.append(
|
||||
audio_embeds[idx, :num_audio_tokens[idx], :])
|
||||
idx += 1
|
||||
final_audio_embeds.append(target_audio_embeds)
|
||||
|
||||
final_audio_embeds.append(torch.cat(target_audio_embeds_lst))
|
||||
|
||||
return final_audio_embeds
|
||||
|
||||
def get_embedding_with_audios(self, vlm_embedding: torch.Tensor,
|
||||
audio_inputs: MiniCPMOAudioInputs,
|
||||
chunk_length: int) -> torch.Tensor:
|
||||
device, dtype = vlm_embedding.device, vlm_embedding.dtype
|
||||
if audio_inputs["type"] == "audio_embeds":
|
||||
audio_embeddings = [
|
||||
item.to(device=device, dtype=dtype)
|
||||
for item in audio_inputs["audio_embeds"]
|
||||
]
|
||||
else:
|
||||
audio_embeddings = self.get_audio_hidden_states(
|
||||
audio_inputs, chunk_length)[0]
|
||||
if audio_embeddings is None or len(audio_embeddings) == 0:
|
||||
return vlm_embedding
|
||||
audio_bounds = audio_inputs["audio_bounds"]
|
||||
if self.config.chunk_input:
|
||||
audio_embs = torch.cat(audio_embeddings, dim=0).to(device=device,
|
||||
dtype=dtype)
|
||||
audio_start_pos = 0
|
||||
for bound in audio_bounds:
|
||||
audio_len = bound[1] - bound[0]
|
||||
vlm_embedding[bound[0]:bound[1]] = audio_embs[
|
||||
audio_start_pos:audio_start_pos + audio_len, :]
|
||||
audio_start_pos += audio_len
|
||||
else:
|
||||
for embs, bound in zip(audio_embeddings, audio_bounds):
|
||||
audio_indices = torch.arange(bound[0],
|
||||
bound[1],
|
||||
dtype=torch.long).to(device)
|
||||
|
||||
if embs.shape[0] != len(audio_indices):
|
||||
raise ValueError(
|
||||
"Shape mismatch: Trying to assign embeddings "
|
||||
f"of shape {embs.shape} "
|
||||
f"to input indices of length {len(audio_indices)}")
|
||||
vlm_embedding[audio_indices] = embs.to(dtype)
|
||||
return vlm_embedding
|
||||
|
||||
def _get_audio_bounds(self, input_ids: torch.Tensor,
|
||||
audio_start_id: torch.Tensor,
|
||||
audio_end_id: torch.Tensor) -> torch.Tensor:
|
||||
audio_start_tokens, = torch.where(input_ids == audio_start_id[0])
|
||||
audio_start_tokens += 1
|
||||
audio_end_tokens, = torch.where(input_ids == audio_end_id[0])
|
||||
valid_audio_nums = max(len(audio_start_tokens), len(audio_end_tokens))
|
||||
return torch.hstack([
|
||||
audio_start_tokens[:valid_audio_nums].unsqueeze(-1),
|
||||
audio_end_tokens[:valid_audio_nums].unsqueeze(-1)
|
||||
])
|
||||
|
||||
def _parse_and_validate_audio_inputs(
|
||||
self, input_ids: torch.Tensor,
|
||||
**kwargs: object) -> Optional[MiniCPMOAudioInputs]:
|
||||
def _parse_and_validate_audio_input(
|
||||
self, **kwargs: object) -> Optional[MiniCPMOAudioInputs]:
|
||||
audio_features = kwargs.pop("audio_features", None)
|
||||
audio_embeds = kwargs.pop("audio_embeds", None)
|
||||
|
||||
if audio_features is None and audio_embeds is None:
|
||||
return None
|
||||
|
||||
audio_start_id = kwargs.pop("audio_start_id")
|
||||
if not isinstance(audio_start_id, torch.Tensor):
|
||||
raise ValueError("Incorrect type of audio_start_id. "
|
||||
f"Got type: {type(audio_start_id)}")
|
||||
audio_token_id = kwargs.pop("audio_token_id")
|
||||
if audio_token_id is not None:
|
||||
assert isinstance(audio_token_id, torch.Tensor)
|
||||
self.mm_token_ids.add(audio_token_id.flatten().unique().item())
|
||||
|
||||
audio_end_id = kwargs.pop("audio_end_id")
|
||||
if not isinstance(audio_end_id, torch.Tensor):
|
||||
raise ValueError("Incorrect type of audio_end_id. "
|
||||
f"Got type: {type(audio_end_id)}")
|
||||
audio_embed_is_patch = kwargs.pop("audio_embed_is_patch")
|
||||
if not isinstance(audio_embed_is_patch, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of audio_embed_is_patch. "
|
||||
f"Got type: {type(audio_embed_is_patch)}")
|
||||
|
||||
audio_embed_is_patch = flatten_bn(audio_embed_is_patch)
|
||||
|
||||
if audio_embeds is not None:
|
||||
if not isinstance(audio_embeds, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of audio_embeds. "
|
||||
f"Got type: {type(audio_embeds)}")
|
||||
|
||||
audio_embeds_flat = flatten_bn(audio_embeds)
|
||||
|
||||
return MiniCPMOAudioEmbeddingInputs(
|
||||
type="audio_embeds",
|
||||
audio_embeds=flatten_bn(flatten_2d_lists(audio_embeds),
|
||||
concat=True),
|
||||
audio_bounds=self._get_audio_bounds(input_ids, audio_start_id,
|
||||
audio_end_id),
|
||||
audio_embeds=audio_embeds_flat,
|
||||
embed_is_patch=audio_embed_is_patch,
|
||||
)
|
||||
|
||||
if audio_features is not None:
|
||||
if not isinstance(audio_features, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of audio_features. "
|
||||
f"Got type: {type(audio_features)}")
|
||||
if not isinstance(audio_features, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of audio_features. "
|
||||
f"Got type: {type(audio_features)}")
|
||||
|
||||
audio_feature_lens = kwargs.pop("audio_feature_lens")
|
||||
if not isinstance(audio_feature_lens, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of audio_feature_lens. "
|
||||
f"Got type: {type(audio_feature_lens)}")
|
||||
audio_feature_lens = kwargs.pop("audio_feature_lens")
|
||||
if not isinstance(audio_feature_lens, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of audio_feature_lens. "
|
||||
f"Got type: {type(audio_feature_lens)}")
|
||||
|
||||
return MiniCPMOAudioFeatureInputs(
|
||||
type="audio_features",
|
||||
audio_features=flatten_bn(audio_features, concat=True),
|
||||
audio_feature_lens=flatten_bn(
|
||||
flatten_2d_lists(audio_feature_lens), concat=True),
|
||||
audio_bounds=self._get_audio_bounds(input_ids, audio_start_id,
|
||||
audio_end_id),
|
||||
)
|
||||
audio_features_flat = flatten_bn(audio_features)
|
||||
audio_feature_lens_flat = flatten_bn(audio_feature_lens)
|
||||
|
||||
raise AssertionError("This line should be unreachable.")
|
||||
|
||||
def _parse_and_validate_inputs(self, input_ids: torch.Tensor,
|
||||
**kwargs: object):
|
||||
image_inputs = self._parse_and_validate_image_inputs(
|
||||
input_ids, **kwargs)
|
||||
if not any("audio" in key for key in kwargs):
|
||||
return image_inputs, None
|
||||
audio_inputs = self._parse_and_validate_audio_inputs(
|
||||
input_ids, **kwargs)
|
||||
return image_inputs, audio_inputs
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
**kwargs: Any,
|
||||
) -> torch.Tensor:
|
||||
if intermediate_tensors is not None:
|
||||
vlm_embeddings = None
|
||||
else:
|
||||
image_inputs, audio_inputs = \
|
||||
self._parse_and_validate_inputs(input_ids, **kwargs)
|
||||
vlm_embeddings = self.get_embedding_with_vision(
|
||||
input_ids, image_inputs)
|
||||
|
||||
if audio_inputs is not None:
|
||||
vlm_embeddings = self.get_embedding_with_audios(
|
||||
vlm_embeddings, audio_inputs,
|
||||
self.config.audio_chunk_length)
|
||||
|
||||
# always pass the input via `inputs_embeds`
|
||||
# to make sure the computation graph is consistent
|
||||
# for `torch.compile` integration
|
||||
input_ids = None
|
||||
|
||||
output = self.llm.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=vlm_embeddings,
|
||||
return MiniCPMOAudioFeatureInputs(
|
||||
type="audio_features",
|
||||
audio_features=audio_features_flat,
|
||||
audio_feature_lens=audio_feature_lens_flat,
|
||||
embed_is_patch=audio_embed_is_patch,
|
||||
)
|
||||
return output
|
||||
|
||||
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
|
||||
modalities = super()._parse_and_validate_multimodal_inputs(**kwargs)
|
||||
|
||||
# Preserve the order of modalities if there are multiple of them
|
||||
# from the order of kwargs.
|
||||
for input_key in kwargs:
|
||||
if input_key in ("audio_features",
|
||||
"audio_embeds") and "audios" not in modalities:
|
||||
modalities["audios"] = self._parse_and_validate_audio_input(
|
||||
**kwargs)
|
||||
|
||||
return modalities
|
||||
|
||||
def _process_audio_input(
|
||||
self,
|
||||
audio_input: MiniCPMOAudioInputs,
|
||||
) -> Union[torch.Tensor, list[torch.Tensor]]:
|
||||
if audio_input["type"] == "audio_embeds":
|
||||
return audio_input["audio_embeds"]
|
||||
|
||||
return self.get_audio_hidden_states(audio_input)
|
||||
|
||||
def _process_multimodal_inputs(self, modalities: dict):
|
||||
multimodal_embeddings = super()._process_multimodal_inputs(modalities)
|
||||
|
||||
for modality in modalities:
|
||||
if modality == "audios":
|
||||
audio_input = modalities["audios"]
|
||||
audio_features = self._process_audio_input(audio_input)
|
||||
multimodal_embeddings += tuple(
|
||||
scatter_patch_features(
|
||||
audio_features,
|
||||
audio_input["embed_is_patch"],
|
||||
))
|
||||
|
||||
return multimodal_embeddings
|
||||
|
@ -23,17 +23,15 @@
|
||||
# limitations under the License.
|
||||
"""Inference-only MiniCPM-V model compatible with HuggingFace weights."""
|
||||
import math
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from functools import cached_property, partial
|
||||
from typing import (Any, Callable, Dict, List, Literal, Optional, Set, Tuple,
|
||||
TypedDict, Union)
|
||||
from typing import (Any, Callable, Literal, Optional, Set, Tuple, TypedDict,
|
||||
Union)
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.types
|
||||
from PIL import Image
|
||||
from torch import nn
|
||||
from transformers import BatchFeature, PretrainedConfig
|
||||
from typing_extensions import TypeVar
|
||||
@ -50,9 +48,7 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
|
||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||
MultiModalInputs, NestedTensors,
|
||||
PlaceholderRange)
|
||||
from vllm.multimodal.inputs import MultiModalFieldConfig, NestedTensors
|
||||
from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem,
|
||||
ImageProcessorItems, ImageSize,
|
||||
ModalityData, ModalityDataItems,
|
||||
@ -67,13 +63,11 @@ from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import flatten_2d_lists
|
||||
|
||||
from .idefics2_vision_model import Idefics2VisionTransformer
|
||||
from .interfaces import (SupportsLoRA, SupportsMultiModal, SupportsPP,
|
||||
SupportsV0Only)
|
||||
from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix
|
||||
|
||||
CPU_DEVICE = torch.device("cpu")
|
||||
|
||||
RawImageType = Union[Image.Image, torch.Tensor]
|
||||
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||
SupportsMultiModal, SupportsPP)
|
||||
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
|
||||
merge_multimodal_embeddings)
|
||||
from .vision import scatter_patch_features, select_patch_features
|
||||
|
||||
|
||||
class MiniCPMVImagePixelInputs(TypedDict):
|
||||
@ -86,13 +80,6 @@ class MiniCPMVImagePixelInputs(TypedDict):
|
||||
instead of a batched tensor.
|
||||
"""
|
||||
|
||||
image_bounds: torch.Tensor
|
||||
"""
|
||||
Shape: `(batch_size * num_images * num_slices, 2)`
|
||||
|
||||
This should be in `(start, stop)` format.
|
||||
"""
|
||||
|
||||
tgt_sizes: torch.Tensor
|
||||
"""
|
||||
Shape: `(batch_size * num_images * num_slices, 2)`
|
||||
@ -100,23 +87,34 @@ class MiniCPMVImagePixelInputs(TypedDict):
|
||||
This should be in `(height, width)` format.
|
||||
"""
|
||||
|
||||
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""
|
||||
A boolean mask indicating which image embeddings correspond
|
||||
to patch tokens.
|
||||
|
||||
Shape: `(batch_size * num_images, num_embeds)`
|
||||
"""
|
||||
|
||||
num_slices: torch.Tensor
|
||||
"""Shape: `(batch_size * num_images)`"""
|
||||
|
||||
|
||||
class MiniCPMVImageEmbeddingInputs(TypedDict):
|
||||
type: Literal["image_embeds"]
|
||||
image_embeds: torch.Tensor
|
||||
image_embeds: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""
|
||||
Shape: `(batch_size * num_images * num_slices,
|
||||
image_feature_size, hidden_size)`
|
||||
Shape: `(batch_size * num_images, num_slices, hidden_size)`
|
||||
|
||||
`hidden_size` must match the hidden size of language model backbone.
|
||||
instead of a batched tensor.
|
||||
"""
|
||||
|
||||
image_bounds: torch.Tensor
|
||||
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""
|
||||
Shape: `(batch_size * num_images * num_slices, 2)`
|
||||
A boolean mask indicating which image embeddings correspond
|
||||
to patch tokens.
|
||||
|
||||
This should be in `(start, stop)` format.
|
||||
Shape: `(batch_size * num_images, num_embeds)`
|
||||
"""
|
||||
|
||||
|
||||
@ -233,15 +231,25 @@ def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
|
||||
|
||||
|
||||
def _minicpmv_field_config(hf_inputs: Mapping[str, torch.Tensor]):
|
||||
pixel_values = hf_inputs.get("pixel_values", torch.empty(0))
|
||||
num_images = len(pixel_values)
|
||||
|
||||
video_pixel_values = hf_inputs.get("video_pixel_values", torch.empty(0))
|
||||
num_videos = len(video_pixel_values)
|
||||
|
||||
return dict(
|
||||
pixel_values=MultiModalFieldConfig.batched("image"),
|
||||
image_sizes=MultiModalFieldConfig.batched("image"),
|
||||
tgt_sizes=MultiModalFieldConfig.batched("image"),
|
||||
image_embeds=MultiModalFieldConfig.batched("image"),
|
||||
embed_is_patch=MultiModalFieldConfig.batched("image"),
|
||||
video_pixel_values=MultiModalFieldConfig.batched("video"),
|
||||
video_image_sizes=MultiModalFieldConfig.batched("video"),
|
||||
video_tgt_sizes=MultiModalFieldConfig.batched("video"),
|
||||
video_embeds=MultiModalFieldConfig.batched("video"),
|
||||
video_embed_is_patch=MultiModalFieldConfig.batched("video"),
|
||||
image_token_id=MultiModalFieldConfig.shared("image", num_images),
|
||||
video_token_id=MultiModalFieldConfig.shared("video", num_videos),
|
||||
)
|
||||
|
||||
|
||||
@ -348,10 +356,11 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
|
||||
return get_version_by_config(self.get_hf_config())
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
mm_limits = {"image": None}
|
||||
if self.get_model_version() == (2, 6):
|
||||
return {"image": None, "video": None}
|
||||
else:
|
||||
return {"image": None}
|
||||
mm_limits["video"] = None
|
||||
|
||||
return mm_limits
|
||||
|
||||
def get_mm_max_tokens_per_item(
|
||||
self,
|
||||
@ -361,70 +370,79 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
|
||||
mm_max_tokens = {"image": self.get_max_image_tokens()}
|
||||
if self.get_model_version() == (2, 6):
|
||||
mm_max_tokens["video"] = self.get_max_video_tokens(seq_len)
|
||||
|
||||
return mm_max_tokens
|
||||
|
||||
def get_slice_image_placeholder(
|
||||
self,
|
||||
image_size: ImageSize,
|
||||
# For MiniCPM V/O 2.6
|
||||
image_idx: int = 0,
|
||||
max_slice_nums: Optional[int] = None,
|
||||
use_image_id: bool = True,
|
||||
) -> str:
|
||||
image_processor = self.get_image_processor()
|
||||
version = self.get_model_version()
|
||||
|
||||
if version == (2, 0) or version == (2, 5):
|
||||
return image_processor.get_slice_image_placeholder(image_size)
|
||||
|
||||
return image_processor.get_slice_image_placeholder(
|
||||
image_size,
|
||||
image_idx=image_idx,
|
||||
max_slice_nums=max_slice_nums,
|
||||
use_image_id=use_image_id,
|
||||
)
|
||||
|
||||
def get_num_image_tokens(
|
||||
self,
|
||||
image_size: ImageSize,
|
||||
max_slice_nums: Optional[int] = None,
|
||||
use_image_id: bool = True,
|
||||
) -> int:
|
||||
tokenizer = self.get_tokenizer()
|
||||
image_placeholders = self.get_slice_image_placeholder(
|
||||
image_size,
|
||||
max_slice_nums=max_slice_nums,
|
||||
use_image_id=use_image_id,
|
||||
)
|
||||
image_token_ids = tokenizer.encode(image_placeholders,
|
||||
add_special_tokens=False)
|
||||
|
||||
return len(image_token_ids)
|
||||
|
||||
def get_max_image_tokens(self) -> int:
|
||||
image_size = self.get_image_size_with_most_features()
|
||||
return self.get_num_image_tokens(image_size)
|
||||
|
||||
def get_image_max_slice_num(self) -> int:
|
||||
return getattr(self.get_hf_config(), "max_slice_num", 9)
|
||||
|
||||
def get_image_size_with_most_features(self) -> ImageSize:
|
||||
image_size = getattr(self.get_hf_config(), "image_size", 448)
|
||||
max_slice_num = self.get_image_max_slice_num()
|
||||
return ImageSize(width=image_size, height=image_size * max_slice_num)
|
||||
|
||||
def get_max_video_frame_tokens(self) -> int:
|
||||
frame_size = self.get_video_frame_size_with_most_features()
|
||||
return self.get_num_image_tokens(frame_size,
|
||||
self.get_video_max_slice_num())
|
||||
|
||||
return self.get_num_image_tokens(
|
||||
frame_size,
|
||||
max_slice_nums=self.get_video_max_slice_num(),
|
||||
use_image_id=False,
|
||||
)
|
||||
|
||||
def get_max_video_tokens(self, seq_len: int) -> int:
|
||||
return self.get_max_video_frame_tokens(
|
||||
) * self.get_num_frames_with_most_features(seq_len)
|
||||
|
||||
def get_slice_query_num(self) -> int:
|
||||
hf_config = self.get_hf_config()
|
||||
query_num = getattr(hf_config, "query_num", 64)
|
||||
return query_num
|
||||
|
||||
def get_max_slice_num(self) -> int:
|
||||
hf_config = self.get_hf_config()
|
||||
max_slice_num = getattr(hf_config, "max_slice_num", 9)
|
||||
return max_slice_num
|
||||
|
||||
def get_sliced_grid(self, image_size: ImageSize,
|
||||
max_slice_num: int) -> Tuple[int, int]:
|
||||
if self.get_model_version() == (2, 6):
|
||||
slice_grid = self.get_image_processor().get_sliced_grid(
|
||||
image_size, max_slice_num)
|
||||
else:
|
||||
slice_grid = self.get_image_processor().get_sliced_grid(image_size)
|
||||
return slice_grid
|
||||
|
||||
def get_num_image_tokens(self, image_size: ImageSize,
|
||||
max_slice_num: int) -> int:
|
||||
slice_grid = self.get_sliced_grid(image_size, max_slice_num)
|
||||
num_tokens = self.get_slice_query_num(
|
||||
) + 2 # <image>(<unk> * query_num)</image>
|
||||
if slice_grid is not None:
|
||||
if self.get_model_version() == (2, 6):
|
||||
num_additional_tokens = 0
|
||||
else:
|
||||
# <slice><image>(<unk> * query_num)</image></slice>
|
||||
num_additional_tokens = 2
|
||||
num_tokens += ((self.get_slice_query_num() + 2) \
|
||||
* slice_grid[0] * slice_grid[1]) \
|
||||
+ slice_grid[1] - 1 + num_additional_tokens
|
||||
return num_tokens
|
||||
|
||||
def get_image_slice_nums(self, image_size: torch.Tensor,
|
||||
max_slice_nums: int) -> int:
|
||||
grid = self.get_sliced_grid(image_size, max_slice_nums)
|
||||
return 1 if grid is None else grid[0] * grid[1] + 1
|
||||
|
||||
def get_max_image_tokens(self) -> int:
|
||||
image_size = self.get_image_size_with_most_features()
|
||||
return self.get_num_image_tokens(image_size, self.get_max_slice_num())
|
||||
|
||||
def get_image_size_with_most_features(self) -> ImageSize:
|
||||
# Result in the max possible feature size (h:w = 9:1)
|
||||
return self.get_default_image_sizes(self.get_max_slice_num())
|
||||
|
||||
def get_video_max_slice_num(self) -> int:
|
||||
return 1
|
||||
|
||||
def get_video_frame_size_with_most_features(self) -> ImageSize:
|
||||
return self.get_default_image_sizes(self.get_video_max_slice_num())
|
||||
image_size = getattr(self.get_hf_config(), "image_size", 448)
|
||||
max_slice_num = self.get_video_max_slice_num()
|
||||
return ImageSize(width=image_size, height=image_size * max_slice_num)
|
||||
|
||||
def get_max_video_frames(self, max_tokens: int) -> int:
|
||||
num_frame_tokens = self.get_max_video_frame_tokens()
|
||||
@ -436,10 +454,7 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
|
||||
max_images = mm_config.get_limit_per_prompt("image")
|
||||
max_videos = mm_config.get_limit_per_prompt("video")
|
||||
|
||||
# count <image_idx></image_idx> tokens
|
||||
# which are not in get_max_image_tokens
|
||||
max_image_tokens = self.get_max_image_tokens(
|
||||
) * max_images + 4 * max_images
|
||||
max_image_tokens = self.get_max_image_tokens() * max_images
|
||||
max_total_frames = self.get_max_video_frames(seq_len -
|
||||
max_image_tokens)
|
||||
|
||||
@ -447,10 +462,6 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
|
||||
|
||||
return num_frames
|
||||
|
||||
def get_default_image_sizes(self, num_slices: int) -> ImageSize:
|
||||
image_size = getattr(self.get_hf_config(), "image_size", 448)
|
||||
return ImageSize(width=image_size, height=image_size * num_slices)
|
||||
|
||||
|
||||
_I = TypeVar("_I",
|
||||
bound=MiniCPMVProcessingInfo,
|
||||
@ -499,42 +510,30 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
||||
def _get_data_parser(self) -> MultiModalDataParser:
|
||||
return MiniCPMVMultiModalDataParser()
|
||||
|
||||
def get_slice_image_placeholder(self, image_size: ImageSize,
|
||||
**kwargs) -> str:
|
||||
image_processor = self.info.get_image_processor()
|
||||
version = self.info.get_model_version()
|
||||
if version == (2, 0) or version == (2, 5):
|
||||
return image_processor.get_slice_image_placeholder(image_size)
|
||||
return image_processor.get_slice_image_placeholder(
|
||||
image_size, **kwargs)
|
||||
|
||||
def get_image_prompt_texts(self,
|
||||
image_size: ImageSize,
|
||||
image_idx: int = 0) -> str:
|
||||
return self.get_slice_image_placeholder(image_size,
|
||||
image_idx=image_idx)
|
||||
return self.info.get_slice_image_placeholder(
|
||||
image_size,
|
||||
image_idx=image_idx,
|
||||
)
|
||||
|
||||
def get_video_prompt_texts(self, image_size: ImageSize,
|
||||
num_frames: int) -> str:
|
||||
return self.get_slice_image_placeholder(
|
||||
return self.info.get_slice_image_placeholder(
|
||||
image_size=image_size,
|
||||
image_idx=0,
|
||||
max_slice_nums=self.info.get_video_max_slice_num(),
|
||||
use_image_id=False,
|
||||
) * num_frames
|
||||
|
||||
def get_special_tokens(self) -> Dict[str, torch.Tensor]:
|
||||
def get_embed_is_patch(
|
||||
self,
|
||||
input_ids: list[int],
|
||||
) -> torch.Tensor:
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
|
||||
special_tokens = {
|
||||
"im_start_id": tokenizer.im_start_id,
|
||||
"im_end_id": tokenizer.im_end_id,
|
||||
}
|
||||
if hasattr(tokenizer, "slice_start_id"):
|
||||
special_tokens["slice_start_id"] = tokenizer.slice_start_id
|
||||
special_tokens["slice_end_id"] = tokenizer.slice_end_id
|
||||
|
||||
return {k: torch.tensor(v) for k, v in special_tokens.items()}
|
||||
unk_token_id = tokenizer.get_vocab()["<unk>"]
|
||||
return torch.tensor(input_ids) == unk_token_id
|
||||
|
||||
def process_images(
|
||||
self,
|
||||
@ -546,14 +545,43 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
||||
|
||||
parsed_images = (self._get_data_parser().parse_mm_data({
|
||||
"image": images
|
||||
}).get_items("image", ImageProcessorItems))
|
||||
}).get_items("image",
|
||||
(MiniCPMVImageEmbeddingItems, ImageProcessorItems)))
|
||||
|
||||
return self._base_call_hf_processor(
|
||||
prompts=[self.info.image_pattern] * len(parsed_images),
|
||||
mm_data={"images": [[image] for image in parsed_images]},
|
||||
mm_kwargs=mm_kwargs,
|
||||
out_keys={"pixel_values", "image_sizes", "tgt_sizes"},
|
||||
)
|
||||
if isinstance(parsed_images, MiniCPMVImageEmbeddingItems):
|
||||
image_inputs = {}
|
||||
else:
|
||||
image_inputs = self._base_call_hf_processor(
|
||||
prompts=[self.info.image_pattern] * len(parsed_images),
|
||||
mm_data={"images": [[image] for image in parsed_images]},
|
||||
mm_kwargs=mm_kwargs,
|
||||
out_keys={"pixel_values", "image_sizes", "tgt_sizes"},
|
||||
)
|
||||
|
||||
image_sizes = [
|
||||
parsed_images.get_image_size(i) for i in range(len(parsed_images))
|
||||
]
|
||||
image_repl_features = [
|
||||
self.get_image_prompt_texts(size, idx)
|
||||
for idx, size in enumerate(image_sizes)
|
||||
]
|
||||
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
image_repls_feature_tokens = [
|
||||
tokenizer.encode(image_repl, add_special_tokens=False)
|
||||
for image_repl in image_repl_features
|
||||
]
|
||||
|
||||
embed_is_patch = [
|
||||
self.get_embed_is_patch(image_repl_tokens)
|
||||
for image_repl_tokens in image_repls_feature_tokens
|
||||
]
|
||||
image_inputs["embed_is_patch"] = embed_is_patch
|
||||
|
||||
unk_token_id = tokenizer.get_vocab()["<unk>"]
|
||||
image_inputs["image_token_id"] = torch.tensor(unk_token_id)
|
||||
|
||||
return image_inputs
|
||||
|
||||
def process_videos(
|
||||
self,
|
||||
@ -565,25 +593,55 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
||||
|
||||
parsed_videos = (self._get_data_parser().parse_mm_data({
|
||||
"video": videos
|
||||
}).get_items("video", VideoProcessorItems))
|
||||
}).get_items("video",
|
||||
(MiniCPMVVideoEmbeddingItems, VideoProcessorItems)))
|
||||
|
||||
max_slice_num = self.info.get_video_max_slice_num()
|
||||
if isinstance(parsed_videos, MiniCPMVVideoEmbeddingItems):
|
||||
video_inputs = {}
|
||||
else:
|
||||
video_inputs = self._base_call_hf_processor(
|
||||
prompts=[
|
||||
self.info.image_pattern * len(video)
|
||||
for video in parsed_videos
|
||||
],
|
||||
mm_data={"images": list(parsed_videos)},
|
||||
mm_kwargs={
|
||||
**mm_kwargs,
|
||||
"max_slice_nums":
|
||||
self.info.get_video_max_slice_num(),
|
||||
},
|
||||
out_keys={"pixel_values", "image_sizes", "tgt_sizes"},
|
||||
)
|
||||
|
||||
video_inputs = self._base_call_hf_processor(
|
||||
prompts=[
|
||||
self.info.image_pattern * len(video) for video in parsed_videos
|
||||
],
|
||||
mm_data={"images": list(parsed_videos)},
|
||||
mm_kwargs={
|
||||
**mm_kwargs, "max_slice_nums": max_slice_num
|
||||
},
|
||||
out_keys={"pixel_values", "image_sizes", "tgt_sizes"},
|
||||
)
|
||||
frame_sizes = [
|
||||
parsed_videos.get_frame_size(i) for i in range(len(parsed_videos))
|
||||
]
|
||||
num_frames = [
|
||||
parsed_videos.get_num_frames(i) for i in range(len(parsed_videos))
|
||||
]
|
||||
video_repl_features = [
|
||||
self.get_video_prompt_texts(size, nframes)
|
||||
for size, nframes in zip(frame_sizes, num_frames)
|
||||
]
|
||||
|
||||
return {f"video_{k}": v for k, v in video_inputs.items()}
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
video_repls_feature_tokens = [
|
||||
tokenizer.encode(video_repl, add_special_tokens=False)
|
||||
for video_repl in video_repl_features
|
||||
]
|
||||
|
||||
def get_placeholder_match_pattern(self) -> str:
|
||||
return r"\(<(image|video)>./</\1>\)"
|
||||
embed_is_patch = [
|
||||
self.get_embed_is_patch(video_repl_tokens)
|
||||
for video_repl_tokens in video_repls_feature_tokens
|
||||
]
|
||||
video_inputs["embed_is_patch"] = embed_is_patch
|
||||
|
||||
video_inputs = {f"video_{k}": v for k, v in video_inputs.items()}
|
||||
|
||||
unk_token_id = tokenizer.get_vocab()["<unk>"]
|
||||
video_inputs["video_token_id"] = torch.tensor(unk_token_id)
|
||||
|
||||
return video_inputs
|
||||
|
||||
def process_mm_inputs(
|
||||
self,
|
||||
@ -602,7 +660,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
||||
mm_kwargs: Mapping[str, object],
|
||||
*,
|
||||
out_keys: set[str],
|
||||
) -> Mapping[str, NestedTensors]:
|
||||
) -> dict[str, NestedTensors]:
|
||||
# This processor supports zipping prompt and mm_data together
|
||||
if self.info.get_model_version() == (2, 6):
|
||||
inputs = super()._call_hf_processor(
|
||||
@ -635,14 +693,13 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
||||
mm_data: Mapping[str, object],
|
||||
mm_kwargs: Mapping[str, object],
|
||||
) -> BatchFeature:
|
||||
# Do not support combination inputs of images and videos for now
|
||||
# Try to handle interleaved multimodal data
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
|
||||
input_ids = torch.tensor([tokenizer.encode(prompt)])
|
||||
mm_inputs = self.process_mm_inputs(mm_data, mm_kwargs)
|
||||
|
||||
return BatchFeature({
|
||||
"input_ids":
|
||||
torch.tensor([tokenizer.encode(prompt)]),
|
||||
"input_ids": input_ids,
|
||||
**mm_inputs,
|
||||
})
|
||||
|
||||
@ -701,39 +758,8 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
return _minicpmv_field_config(hf_inputs)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
prompt: Union[str, List[int]],
|
||||
mm_data: MultiModalDataDict,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
return_mm_hashes: bool = False,
|
||||
) -> MultiModalInputs:
|
||||
if isinstance(prompt, list):
|
||||
prompt = self.info.get_tokenizer().decode(prompt)
|
||||
matches = re.findall(self.get_placeholder_match_pattern(), prompt)
|
||||
mm_orders = {
|
||||
f"{modality}_orders":
|
||||
torch.tensor(
|
||||
[index for index, m in enumerate(matches) if m == modality])
|
||||
for modality in self.info.get_supported_mm_limits()
|
||||
}
|
||||
result = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
|
||||
return_mm_hashes)
|
||||
# Exclude <image_id>x</image_id> from placeholders
|
||||
if "image" in result["mm_placeholders"] and \
|
||||
self.info.get_model_version() == (2, 6):
|
||||
result["mm_placeholders"]["image"] = [
|
||||
PlaceholderRange(offset=p["offset"] + 3 + idx // 10,
|
||||
length=p["length"] - 3 - idx // 10)
|
||||
for idx, p in enumerate(result["mm_placeholders"]["image"])
|
||||
]
|
||||
result["mm_kwargs"].update(**mm_orders)
|
||||
result["mm_kwargs"].update(**self.get_special_tokens())
|
||||
return result
|
||||
|
||||
|
||||
class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
SupportsV0Only):
|
||||
class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
"""
|
||||
The abstract class of MiniCPMV can only be inherited, but cannot be
|
||||
instantiated.
|
||||
@ -767,6 +793,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "resampler"))
|
||||
|
||||
self.mm_token_ids = set[int]()
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.llm.make_empty_intermediate_tensors)
|
||||
|
||||
@ -777,233 +804,191 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
|
||||
return get_sampler()
|
||||
|
||||
def get_embedding_with_vision(
|
||||
def _parse_and_validate_vision_input(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
image_inputs: Optional[MiniCPMVImageInputs],
|
||||
) -> torch.Tensor:
|
||||
vlm_embedding: torch.Tensor = self.llm.get_input_embeddings(input_ids)
|
||||
|
||||
if image_inputs is None:
|
||||
return vlm_embedding
|
||||
|
||||
if image_inputs["type"] == "image_embeds":
|
||||
vision_hidden_states = image_inputs["image_embeds"].to(
|
||||
device=vlm_embedding.device,
|
||||
dtype=vlm_embedding.dtype,
|
||||
)
|
||||
else:
|
||||
vision_hidden_states = self.get_vision_hidden_states(image_inputs)
|
||||
|
||||
# See NOTE in _parse_and_validate_inputs
|
||||
image_bounds = image_inputs["image_bounds"]
|
||||
if len(image_bounds) > 0:
|
||||
image_indices = torch.stack([
|
||||
torch.arange(start, end, dtype=torch.long)
|
||||
for start, end in image_bounds.tolist()
|
||||
]).to(vlm_embedding.device)
|
||||
|
||||
vlm_embedding.scatter_(
|
||||
0,
|
||||
image_indices.view(-1, 1).repeat(1, vlm_embedding.shape[-1]),
|
||||
vision_hidden_states.view(-1, vision_hidden_states.shape[-1]),
|
||||
)
|
||||
|
||||
return vlm_embedding
|
||||
|
||||
def _get_image_bounds(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
im_start_id: torch.Tensor,
|
||||
im_end_id: torch.Tensor,
|
||||
slice_start_id: Optional[torch.Tensor] = None,
|
||||
slice_end_id: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
# All the images in the batch should share the same special image
|
||||
# bound token ids.
|
||||
start_cond = input_ids == im_start_id[0]
|
||||
end_cond = input_ids == im_end_id[0]
|
||||
if slice_start_id is not None:
|
||||
start_cond |= (input_ids == slice_start_id[0])
|
||||
end_cond |= (input_ids == slice_end_id[0])
|
||||
|
||||
image_start_tokens, = torch.where(start_cond)
|
||||
image_start_tokens += 1
|
||||
image_end_tokens, = torch.where(end_cond)
|
||||
valid_image_nums = max(len(image_start_tokens), len(image_end_tokens))
|
||||
|
||||
if valid_image_nums == 0:
|
||||
return torch.zeros((0, 2), device=input_ids.device)
|
||||
|
||||
return torch.hstack([
|
||||
image_start_tokens[:valid_image_nums].unsqueeze(-1),
|
||||
image_end_tokens[:valid_image_nums].unsqueeze(-1),
|
||||
])
|
||||
|
||||
def _parse_and_validate_image_inputs(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
modality: str,
|
||||
**kwargs: object,
|
||||
) -> Optional[MiniCPMVImageInputs]:
|
||||
image_keys = {"pixel_values", "tgt_sizes"}
|
||||
pixel_data = {
|
||||
"image": {
|
||||
key: kwargs.pop(key, None)
|
||||
for key in image_keys
|
||||
},
|
||||
"video": {
|
||||
key: kwargs.pop("video_" + key, None)
|
||||
for key in image_keys
|
||||
}
|
||||
}
|
||||
embed_data = {
|
||||
"image": kwargs.pop("image_embeds", None),
|
||||
"video": kwargs.pop("video_embeds", None),
|
||||
}
|
||||
pixel_values = kwargs.pop("pixel_values", None)
|
||||
image_embeds = kwargs.pop("image_embeds", None)
|
||||
|
||||
all_pixel_data = [
|
||||
v for vs in pixel_data.values() for v in vs.values()
|
||||
if v is not None
|
||||
]
|
||||
all_embed_data = [v for v in embed_data.values() if v is not None]
|
||||
if len(all_pixel_data) == 0 and len(all_embed_data) == 0:
|
||||
if pixel_values is None and image_embeds is None:
|
||||
return None
|
||||
|
||||
im_start_id = kwargs.pop("im_start_id")
|
||||
if not isinstance(im_start_id, torch.Tensor):
|
||||
raise ValueError("Incorrect type of im_start_id. "
|
||||
f"Got type: {type(im_start_id)}")
|
||||
image_token_id = kwargs.pop("image_token_id")
|
||||
if image_token_id is not None:
|
||||
assert isinstance(image_token_id, torch.Tensor)
|
||||
self.mm_token_ids.add(image_token_id.flatten().unique().item())
|
||||
|
||||
im_end_id = kwargs.pop("im_end_id")
|
||||
if not isinstance(im_end_id, torch.Tensor):
|
||||
raise ValueError("Incorrect type of im_end_id. "
|
||||
f"Got type: {type(im_end_id)}")
|
||||
embed_is_patch = kwargs.pop("embed_is_patch")
|
||||
if not isinstance(embed_is_patch, (torch.Tensor, list)):
|
||||
raise ValueError(
|
||||
f"Incorrect type of embed_is_patch for {modality=}. "
|
||||
f"Got type: {type(embed_is_patch)}")
|
||||
|
||||
slice_start_id = kwargs.pop("slice_start_id", None)
|
||||
if slice_start_id is not None and not isinstance(
|
||||
slice_start_id, torch.Tensor):
|
||||
raise ValueError("Incorrect type of slice_start_id. "
|
||||
f"Got type: {type(slice_start_id)}")
|
||||
embed_is_patch = flatten_bn(embed_is_patch)
|
||||
|
||||
slice_end_id = kwargs.pop("slice_end_id", None)
|
||||
if slice_end_id is not None and not isinstance(slice_end_id,
|
||||
torch.Tensor):
|
||||
raise ValueError("Incorrect type of slice_end_id. "
|
||||
f"Got type: {type(slice_end_id)}")
|
||||
if image_embeds is not None:
|
||||
if not isinstance(image_embeds, (torch.Tensor, list)):
|
||||
raise ValueError(
|
||||
f"Incorrect type of image_embeds for {modality=}. "
|
||||
f"Got type: {type(image_embeds)}")
|
||||
|
||||
if len(all_embed_data) > 0:
|
||||
if len(all_embed_data) > 1:
|
||||
raise ValueError("Incorrect inputs for vision embeddings. "
|
||||
"Image embeds and video embeds can not "
|
||||
"exist simultaneously.")
|
||||
|
||||
vision_embeds, = all_embed_data
|
||||
if not isinstance(vision_embeds, (torch.Tensor, list)):
|
||||
raise ValueError(f"Incorrect type of vision_embeds. "
|
||||
f"Got type: {type(vision_embeds)}")
|
||||
image_embeds_flat = flatten_bn(image_embeds)
|
||||
|
||||
return MiniCPMVImageEmbeddingInputs(
|
||||
type="image_embeds",
|
||||
image_embeds=flatten_bn(flatten_2d_lists(vision_embeds),
|
||||
concat=True),
|
||||
image_bounds=self._get_image_bounds(input_ids, im_start_id,
|
||||
im_end_id, slice_start_id,
|
||||
slice_end_id),
|
||||
image_embeds=image_embeds_flat,
|
||||
embed_is_patch=embed_is_patch,
|
||||
)
|
||||
|
||||
order_data = dict[str, Union[torch.Tensor, list[torch.Tensor]]]()
|
||||
for modality in ("image", "video"):
|
||||
modality_orders = kwargs.pop(f"{modality}_orders", None)
|
||||
if modality_orders is not None:
|
||||
if not isinstance(modality_orders, (torch.Tensor, list)):
|
||||
raise ValueError(f"Incorrect type of {modality}_orders. "
|
||||
f"Got type: {type(modality_orders)}")
|
||||
if not isinstance(pixel_values, (torch.Tensor, list)):
|
||||
raise ValueError(
|
||||
f"Incorrect type of pixel_values for {modality=}. "
|
||||
f"Got type: {type(pixel_values)}")
|
||||
|
||||
order_data[modality] = modality_orders
|
||||
tgt_sizes = kwargs.pop("tgt_sizes")
|
||||
if not isinstance(tgt_sizes, (torch.Tensor, list)):
|
||||
raise ValueError(f"Incorrect type of tgt_sizes for {modality=}. "
|
||||
f"Got type: {type(tgt_sizes)}")
|
||||
|
||||
batch_sizes = {
|
||||
modality: len(modality_orders)
|
||||
for modality, modality_orders in order_data.items()
|
||||
}
|
||||
unique_batch_sizes = set(batch_sizes.values())
|
||||
assert len(unique_batch_sizes) == 1, (
|
||||
f"Found inconsistent batch sizes: {batch_sizes}")
|
||||
batch_size, = unique_batch_sizes
|
||||
num_slices = [[len(p) for p in ps] for ps in pixel_values]
|
||||
num_slices_flat = flatten_bn(torch.tensor(num_slices))
|
||||
|
||||
pixel_values_flat = list[torch.Tensor]()
|
||||
tgt_sizes_flat = list[torch.Tensor]()
|
||||
for b in range(batch_size):
|
||||
mm_orders_b = [(idx_b.item(), modality)
|
||||
for modality, modality_orders in order_data.items()
|
||||
for idx_b in modality_orders[b]]
|
||||
pixel_values_flat = flatten_bn(flatten_2d_lists(pixel_values))
|
||||
tgt_sizes_flat = flatten_bn(flatten_2d_lists(tgt_sizes), concat=True)
|
||||
|
||||
for _, modality in sorted(mm_orders_b, key=lambda x: x[0]):
|
||||
modality_pixel_data = pixel_data[modality]
|
||||
|
||||
modality_pixel_values = modality_pixel_data["pixel_values"]
|
||||
if not isinstance(modality_pixel_values, (torch.Tensor, list)):
|
||||
raise ValueError(
|
||||
f"Incorrect type of pixel_values for {modality=}. "
|
||||
f"Got type: {type(modality_pixel_values)}")
|
||||
|
||||
modality_tgt_sizes = modality_pixel_data["tgt_sizes"]
|
||||
if not isinstance(modality_tgt_sizes, (torch.Tensor, list)):
|
||||
raise ValueError(
|
||||
f"Incorrect type of tgt_sizes for {modality=}. "
|
||||
f"Got type: {type(modality_tgt_sizes)}")
|
||||
|
||||
pixel_values_flat += flatten_2d_lists(modality_pixel_values[b])
|
||||
tgt_sizes_flat += flatten_2d_lists(modality_tgt_sizes[b])
|
||||
|
||||
# NOTE: Input IDs does not contain image tokens during memory profiling,
|
||||
# so we allow it to be empty
|
||||
if len(pixel_values_flat) != len(tgt_sizes_flat):
|
||||
raise ValueError("Inconsistent flattened lengths, found: "
|
||||
f"{len(pixel_values_flat)} vs. "
|
||||
f"{len(tgt_sizes_flat)}")
|
||||
|
||||
if len(pixel_values_flat) == 0:
|
||||
return None
|
||||
|
||||
return MiniCPMVImagePixelInputs(
|
||||
type="pixel_values",
|
||||
pixel_values=pixel_values_flat,
|
||||
tgt_sizes=torch.stack(tgt_sizes_flat),
|
||||
image_bounds=self._get_image_bounds(input_ids, im_start_id,
|
||||
im_end_id, slice_start_id,
|
||||
slice_end_id),
|
||||
tgt_sizes=tgt_sizes_flat,
|
||||
embed_is_patch=embed_is_patch,
|
||||
num_slices=num_slices_flat,
|
||||
)
|
||||
|
||||
def _parse_and_validate_inputs(self, input_ids: torch.Tensor,
|
||||
**kwargs: object):
|
||||
return self._parse_and_validate_image_inputs(input_ids, **kwargs)
|
||||
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
|
||||
modalities = {}
|
||||
|
||||
# Preserve the order of modalities if there are multiple of them
|
||||
# from the order of kwargs.
|
||||
for input_key in kwargs:
|
||||
if input_key in ("pixel_values",
|
||||
"image_embeds") and "images" not in modalities:
|
||||
modalities["images"] = self._parse_and_validate_vision_input(
|
||||
"images", **kwargs)
|
||||
if input_key in ("video_pixel_values",
|
||||
"video_embeds") and "videos" not in modalities:
|
||||
|
||||
def _image_key(video_key: str):
|
||||
if video_key == "video_token_id":
|
||||
return "image_token_id"
|
||||
|
||||
return video_key.removeprefix("video_")
|
||||
|
||||
modalities["videos"] = self._parse_and_validate_vision_input(
|
||||
"videos", **{
|
||||
_image_key(k): v
|
||||
for k, v in kwargs.items()
|
||||
})
|
||||
|
||||
return modalities
|
||||
|
||||
def _process_vision_input(
|
||||
self,
|
||||
image_input: MiniCPMVImageInputs,
|
||||
) -> Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor, ...]]:
|
||||
if image_input["type"] == "image_embeds":
|
||||
return image_input["image_embeds"]
|
||||
|
||||
image_features_flat = self.get_vision_hidden_states(image_input)
|
||||
|
||||
# Reconstruct the batch dimension
|
||||
return image_features_flat.split(image_input["num_slices"].tolist())
|
||||
|
||||
def _process_multimodal_inputs(self, modalities: dict):
|
||||
# The result multimodal_embeddings is tuple of tensors, with each
|
||||
# tensor correspoending to a multimodal data item (image or video).
|
||||
multimodal_embeddings: tuple[torch.Tensor, ...] = ()
|
||||
|
||||
# NOTE: It is important to iterate over the keys in this dictionary
|
||||
# to preserve the order of the modalities.
|
||||
for modality in modalities:
|
||||
if modality == "images":
|
||||
image_input = modalities["images"]
|
||||
image_features = self._process_vision_input(image_input)
|
||||
multimodal_embeddings += tuple(
|
||||
scatter_patch_features(
|
||||
image_features,
|
||||
image_input["embed_is_patch"],
|
||||
))
|
||||
if modality == "videos":
|
||||
video_input = modalities["videos"]
|
||||
video_features = self._process_vision_input(video_input)
|
||||
multimodal_embeddings += tuple(
|
||||
scatter_patch_features(
|
||||
video_features,
|
||||
video_input["embed_is_patch"],
|
||||
))
|
||||
|
||||
return multimodal_embeddings
|
||||
|
||||
def get_multimodal_embeddings(
|
||||
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
|
||||
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
|
||||
if not modalities:
|
||||
return None
|
||||
|
||||
return self._process_multimodal_inputs(modalities)
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.llm.get_input_embeddings(input_ids)
|
||||
if multimodal_embeddings is not None:
|
||||
assert len(self.mm_token_ids) > 0
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
select_patch_features(multimodal_embeddings),
|
||||
list(self.mm_token_ids),
|
||||
)
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs: Any,
|
||||
) -> torch.Tensor:
|
||||
if intermediate_tensors is not None:
|
||||
vlm_embeddings = None
|
||||
else:
|
||||
image_inputs = \
|
||||
self._parse_and_validate_inputs(input_ids, **kwargs)
|
||||
vlm_embeddings = self.get_embedding_with_vision(
|
||||
input_ids, image_inputs)
|
||||
inputs_embeds = None
|
||||
|
||||
# always pass the input via `inputs_embeds`
|
||||
# to make sure the computation graph is consistent
|
||||
# for `torch.compile` integration
|
||||
input_ids = None
|
||||
# NOTE: In v1, inputs_embeds is always generated at model runner from
|
||||
# `get_multimodal_embeddings` and `get_input_embeddings`, this
|
||||
# condition is only for v0 compatibility.
|
||||
elif inputs_embeds is None:
|
||||
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||
|
||||
output = self.llm.model(
|
||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
||||
vision_embeddings)
|
||||
input_ids = None
|
||||
|
||||
hidden_states = self.llm.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=vlm_embeddings,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
return output
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
@ -1105,9 +1090,6 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
|
||||
|
||||
return model
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_tokens(input_ids)
|
||||
|
||||
def init_resampler(self,
|
||||
embed_dim: int,
|
||||
vision_dim: int,
|
||||
|
@ -92,8 +92,8 @@ class MolmoImageInputs(TypedDict):
|
||||
Shape: `(batch_size * num_images, num_embeds)`
|
||||
"""
|
||||
|
||||
num_crops: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""Shape: `(batch_size, num_images)`"""
|
||||
num_crops: torch.Tensor
|
||||
"""Shape: `(batch_size * num_images)`"""
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -1492,6 +1492,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
||||
self.img_patch_id = img_patch_id.flatten().unique().item()
|
||||
|
||||
embed_is_patch = flatten_bn(embed_is_patch)
|
||||
num_crops = flatten_bn(num_crops, concat=True)
|
||||
|
||||
return MolmoImageInputs(
|
||||
images=images,
|
||||
@ -1510,31 +1511,24 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
||||
feat_is_patch = image_input["feat_is_patch"]
|
||||
num_crops = image_input["num_crops"]
|
||||
|
||||
if isinstance(images, list):
|
||||
# Call the vision backbone on the whole batch at once
|
||||
images_flat = flatten_bn(images, concat=True)
|
||||
image_masks_flat = (None if image_masks is None else flatten_bn(
|
||||
image_masks, concat=True))
|
||||
# Call the vision backbone on the whole batch at once
|
||||
images_flat = flatten_bn(images, concat=True)
|
||||
image_masks_flat = (None if image_masks is None else flatten_bn(
|
||||
image_masks, concat=True))
|
||||
feat_is_patch_flat = flatten_bn(feat_is_patch, concat=True)
|
||||
|
||||
image_features_flat = self.vision_backbone(
|
||||
images=images_flat.unsqueeze(0),
|
||||
image_masks=(None if image_masks_flat is None else
|
||||
image_masks_flat.unsqueeze(0)),
|
||||
).squeeze(0)
|
||||
|
||||
# Reconstruct the batch dimension
|
||||
num_crops_per_image = [nc.sum().item() for nc in num_crops]
|
||||
image_features = image_features_flat.split(num_crops_per_image)
|
||||
else:
|
||||
image_features = self.vision_backbone(
|
||||
images=images,
|
||||
image_masks=image_masks,
|
||||
)
|
||||
image_features_flat = self.vision_backbone(
|
||||
images=images_flat.unsqueeze(0),
|
||||
image_masks=(None if image_masks_flat is None else
|
||||
image_masks_flat.unsqueeze(0)),
|
||||
).squeeze(0)
|
||||
|
||||
# Only the features corresponding to patch tokens are relevant
|
||||
return [
|
||||
feats[f_is_patch]
|
||||
for feats, f_is_patch in zip(image_features, feat_is_patch)
|
||||
feats[f_is_patch] for feats, f_is_patch in zip(
|
||||
image_features_flat.split(num_crops.tolist()),
|
||||
feat_is_patch_flat.split(num_crops.tolist()),
|
||||
)
|
||||
]
|
||||
|
||||
def get_multimodal_embeddings(
|
||||
|
Reference in New Issue
Block a user