[VLM] Check required fields before initializing field config in DictEmbeddingItems (#13380)

Author: Cyrus Leung
Date: 2025-02-17 17:36:07 +08:00
Committer: GitHub
Parent: 238dfc8ac3
Commit: 7b623fca0b

5 changed files with 35 additions and 22 deletions

View File

@@ -184,8 +184,8 @@ llm = LLM("openbmb/MiniCPM-V-2_6", trust_remote_code=True, limit_mm_per_prompt={
 mm_data = {
     "image": {
         "image_embeds": image_embeds,
-        # image_size_list is needed to calculate details of the sliced image.
-        "image_size_list": [image.size for image in images], # list of image sizes
+        # image_sizes is needed to calculate details of the sliced image.
+        "image_sizes": [image.size for image in images], # list of image sizes
     }
 }
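
The key rename lines up with the required_fields that MiniCPMVImageEmbeddingItems now declares (see the minicpmv.py hunks below): a dict still using the old image_size_list key no longer satisfies the required set. A standalone sketch of that check, with placeholder values and no vLLM imports:

required_fields = {"image_embeds", "image_sizes"}

old_style = {"image_embeds": ..., "image_size_list": ...}  # pre-rename key
new_style = {"image_embeds": ..., "image_sizes": ...}      # documented key

for candidate in (old_style, new_style):
    missing = required_fields - candidate.keys()
    print(missing or "ok")  # {'image_sizes'} for old_style, "ok" for new_style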

View File

@@ -23,8 +23,8 @@
 # limitations under the License.
 """Inference-only MiniCPM-O model compatible with HuggingFace weights."""
 from functools import partial
-from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set,
-                    Tuple, TypedDict, Union)
+from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
+                    Optional, Set, Tuple, TypedDict, Union)

 import torch
 from torch import nn

@@ -122,13 +122,16 @@ class MiniCPMOAudioEmbeddingItems(DictEmbeddingItems):
     def __init__(
         self,
         data: Mapping[str, torch.Tensor],
-        fields_config: Mapping[str, MultiModalFieldConfig],
+        fields_factory: Callable[
+            [Mapping[str, torch.Tensor]],
+            Mapping[str, MultiModalFieldConfig],
+        ],
     ) -> None:
         super().__init__(
             data,
             modality="image",
-            fields_config=fields_config,
             required_fields={"audio_embeds"},
+            fields_factory=fields_factory,
         )

@@ -141,7 +144,7 @@ class MiniCPMOMultiModalDataParser(MiniCPMVMultiModalDataParser):
         if isinstance(data, dict):
             return MiniCPMOAudioEmbeddingItems(
                 data,
-                fields_config=_minicpmo_field_config(data),
+                fields_factory=_minicpmo_field_config,
             )

         return super()._parse_audio_data(data)
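
Note the call-site change: the parser used to pass the result of _minicpmo_field_config(data); it now passes the function itself, so DictEmbeddingItems decides when (and whether) to invoke it. A toy illustration of passing a callable instead of its result (names are illustrative stand-ins, not vLLM code):

def make_field_config(data):
    # stand-in for a factory like _minicpmo_field_config; a real factory
    # may inspect or index into the data it receives
    return {key: f"field-config({key})" for key in data}

data = {"audio_embeds": [0.1, 0.2]}

eager = make_field_config(data)   # old style: evaluated at the call site
deferred = make_field_config      # new style: invoked later, after validation
print(eager == deferred(data))    # True; only the evaluation point differs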

View File

@@ -255,13 +255,16 @@ class MiniCPMVImageEmbeddingItems(DictEmbeddingItems):
     def __init__(
         self,
         data: Mapping[str, torch.Tensor],
-        fields_config: Mapping[str, MultiModalFieldConfig],
+        fields_factory: Callable[
+            [Mapping[str, torch.Tensor]],
+            Mapping[str, MultiModalFieldConfig],
+        ],
     ) -> None:
         super().__init__(
             data,
             modality="image",
-            fields_config=fields_config,
             required_fields={"image_embeds", "image_sizes"},
+            fields_factory=fields_factory,
         )

     def get_image_size(self, index: int) -> ImageSize:

@@ -274,13 +277,16 @@ class MiniCPMVVideoEmbeddingItems(DictEmbeddingItems):
     def __init__(
         self,
         data: Mapping[str, torch.Tensor],
-        fields_config: Mapping[str, MultiModalFieldConfig],
+        fields_factory: Callable[
+            [Mapping[str, torch.Tensor]],
+            Mapping[str, MultiModalFieldConfig],
+        ],
     ) -> None:
         super().__init__(
             data,
             modality="video",
-            fields_config=fields_config,
             required_fields={"video_embeds", "video_image_sizes"},
+            fields_factory=fields_factory,
         )

     def get_frame_size(self, index: int) -> ImageSize:

@@ -300,7 +306,7 @@ class MiniCPMVMultiModalDataParser(MultiModalDataParser):
         if isinstance(data, dict):
             return MiniCPMVImageEmbeddingItems(
                 data,
-                fields_config=_minicpmv_field_config(data),
+                fields_factory=_minicpmv_field_config,
             )

         return super()._parse_image_data(data)

@@ -312,7 +318,7 @@ class MiniCPMVMultiModalDataParser(MultiModalDataParser):
         if isinstance(data, dict):
             return MiniCPMVVideoEmbeddingItems(
                 data,
-                fields_config=_minicpmv_field_config(data),
+                fields_factory=_minicpmv_field_config,
             )

         return super()._parse_video_data(data)
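
Both embedding-item classes share the single _minicpmv_field_config factory; only their required_fields differ, and that set is what distinguishes a valid image dict from a valid video dict. A standalone sketch with placeholder values:

image_required = {"image_embeds", "image_sizes"}
video_required = {"video_embeds", "video_image_sizes"}

video_data = {"video_embeds": ..., "video_image_sizes": ...}  # placeholders

print(video_required - video_data.keys())  # set() -> valid as a video item
print(image_required - video_data.keys())  # {'image_embeds', 'image_sizes'}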

View File

@@ -691,8 +691,8 @@ class Qwen2VLMultiModalDataParser(MultiModalDataParser):
             return DictEmbeddingItems(
                 data,
                 modality="image",
-                fields_config=_qwen2vl_field_config(data),
                 required_fields={"image_embeds", "image_grid_thw"},
+                fields_factory=_qwen2vl_field_config,
             )

         return super()._parse_image_data(data)

@@ -705,8 +705,8 @@ class Qwen2VLMultiModalDataParser(MultiModalDataParser):
             return DictEmbeddingItems(
                 data,
                 modality="video",
-                fields_config=_qwen2vl_field_config(data),
                 required_fields={"video_embeds", "video_grid_thw"},
+                fields_factory=_qwen2vl_field_config,
             )

         return super()._parse_video_data(data)
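
With required_fields now checked before the field config is built, a malformed Qwen2-VL embedding dict fails with a readable ValueError. A sketch of that failure mode (the exact message lives in parse.py below; this wording is illustrative):

required_fields = {"image_embeds", "image_grid_thw"}
data = {"image_embeds": ...}  # "image_grid_thw" was forgotten

try:
    missing = required_fields - data.keys()
    if missing:
        raise ValueError(f"data is missing required fields: {missing}")
except ValueError as exc:
    print(exc)  # data is missing required fields: {'image_grid_thw'}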

View File

@@ -125,17 +125,14 @@ class DictEmbeddingItems(ModalityDataItems[Mapping[str, torch.Tensor],
         self,
         data: Mapping[str, torch.Tensor],
         modality: str,
-        fields_config: Mapping[str, MultiModalFieldConfig],
         required_fields: set[str],
+        fields_factory: Callable[
+            [Mapping[str, torch.Tensor]],
+            Mapping[str, MultiModalFieldConfig],
+        ],
     ) -> None:
         super().__init__(data, modality)

-        missing_required_fields = required_fields - fields_config.keys()
-        if missing_required_fields:
-            fields = set(fields_config.keys())
-            msg = f"{required_fields=} should be a subset of {fields=}"
-            raise ValueError(msg)
-
         missing_required_data_keys = required_fields - data.keys()
         if missing_required_data_keys:
             data_keys = set(data.keys())

@@ -143,6 +140,13 @@ class DictEmbeddingItems(ModalityDataItems[Mapping[str, torch.Tensor],
                    f"but only found the following keys: {data_keys}")
             raise ValueError(msg)

+        fields_config = fields_factory(data)
+        missing_required_fields = required_fields - fields_config.keys()
+        if missing_required_fields:
+            fields = set(fields_config.keys())
+            msg = f"{required_fields=} should be a subset of {fields=}"
+            raise ValueError(msg)
+
         self.fields_config = fields_config
         self.required_fields = required_fields
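
This is the core of the change: the data keys are validated first, and only then is the factory invoked to build the field config. A plausible motivation, shown with toy stand-ins rather than the real vLLM classes: if a factory indexes into the dict it receives, calling it on unvalidated input surfaces a raw KeyError where the user should instead see the readable ValueError.

def fields_factory(data):
    # toy stand-in for a factory like _qwen2vl_field_config;
    # it touches a specific key directly
    return {key: ("batched", len(data["image_embeds"])) for key in data}

required_fields = {"image_embeds", "image_sizes"}
bad_data = {"image_sizes": [(448, 448)]}  # "image_embeds" is missing

try:
    fields_config = fields_factory(bad_data)  # old order: factory runs first
except KeyError as exc:
    print("old order:", repr(exc))  # old order: KeyError('image_embeds')

missing = required_fields - bad_data.keys()  # new order: validate data first
if missing:
    print(f"new order: ValueError, data lacks {missing}")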