[Model] Use merge_by_field_config for MM models (M-N) (#26710)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung
2025-10-14 01:27:01 +08:00
committed by GitHub
parent e3b90c1ba2
commit afc47e4de7
11 changed files with 127 additions and 331 deletions

@ -631,8 +631,11 @@ class InternS1ForConditionalGeneration(
)
image_token_id = kwargs["image_token_id"]
assert isinstance(image_token_id, torch.Tensor)
self.img_context_token_id = image_token_id.flatten().unique().item()
if isinstance(image_token_id, torch.Tensor):
image_token_id = image_token_id.flatten().unique().item()
assert isinstance(image_token_id, int)
self.img_context_token_id = image_token_id
if pixel_values is not None:
h, w = self.config.vision_config.image_size
@ -665,8 +668,11 @@ class InternS1ForConditionalGeneration(
)
video_token_id = kwargs["video_token_id"]
assert isinstance(video_token_id, torch.Tensor)
self.video_context_token_id = video_token_id.flatten().unique().item()
if isinstance(video_token_id, torch.Tensor):
video_token_id = video_token_id.flatten().unique().item()
assert isinstance(video_token_id, int)
self.video_context_token_id = video_token_id
if pixel_values_flat_video is not None:
h, w = self.config.vision_config.image_size
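
With merge_by_field_config, scalar fields such as image_token_id and video_token_id are no longer force-wrapped into batched tensors, so they may reach the model as plain Python ints; the hunks above accept both forms. A minimal standalone sketch of that normalization pattern (the helper name and the id value are placeholders, not part of this change):

import torch

def normalize_token_id(token_id: torch.Tensor | int) -> int:
    # Hypothetical helper mirroring the pattern above: accept either a tensor
    # (legacy batched field) or a plain int (merge_by_field_config path).
    if isinstance(token_id, torch.Tensor):
        # Every batch entry carries the same id, so collapse to one scalar.
        token_id = token_id.flatten().unique().item()
    assert isinstance(token_id, int)
    return token_id

print(normalize_token_id(torch.tensor([[42, 42]])))  # 42 (placeholder id)
print(normalize_token_id(42))                        # 42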

@ -1232,8 +1232,11 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA)
)
image_token_id = kwargs["image_token_id"]
assert isinstance(image_token_id, torch.Tensor)
self.img_context_token_id = image_token_id.flatten().unique().item()
if isinstance(image_token_id, torch.Tensor):
image_token_id = image_token_id.flatten().unique().item()
assert isinstance(image_token_id, int)
self.img_context_token_id = image_token_id
if pixel_values_flat is not None:
expected_h = expected_w = self.config.vision_config.image_size
@ -1265,8 +1268,11 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA)
)
video_token_id = kwargs["video_token_id"]
assert isinstance(video_token_id, torch.Tensor)
self.video_context_token_id = video_token_id.flatten().unique().item()
if isinstance(video_token_id, torch.Tensor):
video_token_id = video_token_id.flatten().unique().item()
assert isinstance(video_token_id, int)
self.video_context_token_id = video_token_id
if pixel_values_flat_video is not None:
expected_h = expected_w = self.config.vision_config.image_size

@ -26,7 +26,7 @@
import collections
import collections.abc
from collections.abc import Callable, Iterable, Mapping, Sequence
from typing import Any, TypeAlias, TypedDict, cast
from typing import Annotated, Any, TypeAlias, cast
import numpy as np
import torch
@ -62,6 +62,7 @@ from vllm.multimodal.processing import (
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.midashenglm import DashengConfig
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
@ -508,11 +509,16 @@ class AudioProjectorSubsample(nn.Module):
# === Audio Inputs === #
class MiDashengLMAudioInputs(TypedDict):
input_values: torch.Tensor
"""Shape: `(num_audios, num_sampling_points)`"""
audio_length: torch.Tensor
"""Shape: `(num_audios, 1)`"""
class MiDashengLMAudioInputs(TensorSchema):
"""
Dimensions:
- bn: Batch size * number of audios
- p: Number of sampling points
"""
input_values: Annotated[torch.Tensor, TensorShape("n", "p")]
audio_length: Annotated[torch.Tensor, TensorShape("n")]
class MiDashengLMProcessingInfo(BaseProcessingInfo):
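
The TypedDict-to-TensorSchema switch above declares the expected tensor layout with Annotated plus TensorShape instead of docstring comments. A minimal sketch, assuming (as the surrounding vLLM code suggests) that shape consistency is checked when the schema object is constructed; the class name and sizes here are illustrative:

import torch
from typing import Annotated
from vllm.utils.tensor_schema import TensorSchema, TensorShape

class AudioInputsSketch(TensorSchema):
    """
    Dimensions:
        - n: Number of audio items
        - p: Number of sampling points
    """
    input_values: Annotated[torch.Tensor, TensorShape("n", "p")]
    audio_length: Annotated[torch.Tensor, TensorShape("n")]

# Both tensors must agree on the symbolic "n" dimension; a mismatch is
# assumed to raise here rather than deep inside the audio encoder.
inputs = AudioInputsSketch(
    input_values=torch.zeros(2, 16000),
    audio_length=torch.tensor([16000, 8000]),
)
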
@ -676,6 +682,8 @@ class MiDashengLMMultiModalProcessor(
dummy_inputs=MiDashengLMDummyInputsBuilder,
)
class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
merge_by_field_config = True
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@ -728,26 +736,6 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
self.decoder.make_empty_intermediate_tensors
)
def _validate_and_reshape_mm_tensor(
self, mm_input: object, name: str
) -> torch.Tensor:
if not isinstance(mm_input, (torch.Tensor, list)):
raise ValueError(f"Incorrect type of {name}. Got type: {type(mm_input)}")
if isinstance(mm_input, torch.Tensor):
return mm_input.reshape(-1, *mm_input.shape[2:])
if name == "input_values":
max_length = max(tensor.shape[1] for tensor in mm_input)
padded_mm_input = [
torch.nn.functional.pad(tensor, (0, max_length - tensor.shape[1]))
if tensor.shape[1] < max_length
else tensor
for tensor in mm_input
]
return torch.concat(padded_mm_input)
return torch.concat(mm_input)
def _parse_and_validate_audio_input(
self, **kwargs: object
) -> MiDashengLMAudioInputs | None:
@ -756,16 +744,11 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
if input_values is None:
return None
input_values = self._validate_and_reshape_mm_tensor(
input_values, "input_values"
)
audio_length = self._validate_and_reshape_mm_tensor(
audio_length, "audio_length"
)
if not isinstance(input_values, (torch.Tensor, list)):
raise ValueError(
"Incorrect type of audio input features. "
f"Got type: {type(input_values)}"
if isinstance(input_values, list):
input_values = torch.nn.utils.rnn.pad_sequence(
input_values,
batch_first=True,
)
return MiDashengLMAudioInputs(
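
The hunk above drops the hand-rolled max-length padding in favour of torch.nn.utils.rnn.pad_sequence, which is what the removed loop implemented manually. A self-contained sketch of the equivalence (waveform lengths are made up):

import torch
from torch.nn.utils.rnn import pad_sequence

# Two variable-length waveforms, as handed over per audio item once
# merge_by_field_config delivers a list of per-item tensors.
waveforms = [torch.randn(16000), torch.randn(12000)]

# pad_sequence right-pads each item to the longest one and stacks them
# into a (num_audios, max_len) batch, matching the old F.pad loop.
batch = pad_sequence(waveforms, batch_first=True)
print(batch.shape)  # torch.Size([2, 16000])
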
@ -773,7 +756,10 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
audio_length=audio_length,
)
def _process_audio_input(self, audio_input: MiDashengLMAudioInputs) -> torch.Tensor:
def _process_audio_input(
self,
audio_input: MiDashengLMAudioInputs,
) -> tuple[torch.Tensor, ...]:
# Process audio through encoder and projector
input_values = audio_input["input_values"]
audio_length = audio_input["audio_length"]
@ -783,17 +769,13 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
audio_embeddings = audio_embeddings.to(audio_input["input_values"].dtype)
batch_size, max_audio_tokens, embed_dim = audio_embeddings.shape
audio_length_np = (
audio_length.cpu().numpy()
if isinstance(audio_length, torch.Tensor)
else audio_length
)
audio_output_lengths = [
max(1, calculate_mel_frames_dasheng(int(length))) # at least one frame
for length in audio_length_np
for length in audio_length.tolist()
]
audio_output_lengths = torch.tensor(audio_output_lengths).to(
audio_embeddings.device
audio_output_lengths = torch.tensor(
audio_output_lengths,
device=audio_embeddings.device,
)
audio_feature_mask = torch.arange(
@ -826,14 +808,6 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
) -> torch.Tensor | IntermediateTensors:
if intermediate_tensors is not None:
inputs_embeds = None
elif inputs_embeds is None:
multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(
input_ids,
multimodal_embeddings,
is_multimodal=input_ids == self.config.audio_token_id,
)
input_ids = None
return self.decoder.model(
input_ids,

@ -71,7 +71,7 @@ from .minicpmv import (
MiniCPMVProcessingInfo,
_minicpmv_field_config,
)
from .utils import AutoWeightsLoader, cast_overflow_tensors, flatten_bn, maybe_prefix
from .utils import AutoWeightsLoader, cast_overflow_tensors, maybe_prefix
CPU_DEVICE = torch.device("cpu")
@ -132,15 +132,11 @@ MiniCPMOAudioInputs: TypeAlias = (
def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]):
audio_features = hf_inputs.get("audio_features", torch.empty(0))
num_audios = len(audio_features)
return dict(
**_minicpmv_field_config(hf_inputs),
audio_features=MultiModalFieldConfig.batched("audio"),
audio_feature_lens=MultiModalFieldConfig.batched("audio"),
audio_embeds=MultiModalFieldConfig.batched("audio"),
audio_token_id=MultiModalFieldConfig.shared("audio", num_audios),
)
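
Dropping audio_token_id from the field config works because merge_by_field_config only needs to know how each processor output maps onto individual items; token ids are no longer carried along as shared tensors. A hedged sketch of the field kinds this PR relies on (import path per my reading of the vLLM tree; the tensors are placeholders):

import torch
from vllm.multimodal.inputs import MultiModalFieldConfig

def example_field_config(hf_inputs):
    image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
    return dict(
        # One entry per image: the leading dim indexes images directly.
        image_sizes=MultiModalFieldConfig.batched("image"),
        # All patches of all images concatenated along dim 0; the per-image
        # patch counts say how to slice the flat tensor back apart.
        pixel_values_flat=MultiModalFieldConfig.flat_from_sizes(
            "image", image_num_patches
        ),
    )
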
@ -332,10 +328,6 @@ class MiniCPMOMultiModalProcessor(MiniCPMVMultiModalProcessor[MiniCPMOProcessing
]
audio_inputs["audio_features"] = unpadded_audio_features
tokenizer = self.info.get_tokenizer()
unk_token_id = tokenizer.get_vocab()["<unk>"]
audio_inputs["audio_token_id"] = torch.tensor(unk_token_id)
return audio_inputs
def process_mm_inputs(
@ -436,12 +428,10 @@ class MiniCPMWhisperEncoderLayer(nn.Module):
attention_mask: torch.Tensor,
) -> torch.Tensor:
residual = hidden_states
past_key_values = None
hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states, attn_weights, past_key_values = self.self_attn(
hidden_states, _ = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
past_key_value=past_key_values,
)
hidden_states = nn.functional.dropout(
hidden_states, p=self.dropout, training=self.training
@ -567,8 +557,6 @@ class MiniCPMO(MiniCPMV2_6):
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "apm")
)
self.audio_token_id = None
def init_audio_module(self, *, vllm_config: VllmConfig, prefix: str = ""):
# Do not use parameters temporarily
audio_config = self.config.audio_config
@ -731,43 +719,18 @@ class MiniCPMO(MiniCPMV2_6):
if audio_features is None and audio_embeds is None:
return None
audio_token_id = kwargs.pop("audio_token_id")
if audio_token_id is not None:
assert isinstance(audio_token_id, torch.Tensor)
self.mm_token_ids.add(audio_token_id.flatten().unique().item())
if audio_embeds is not None:
if not isinstance(audio_embeds, (torch.Tensor, list)):
raise ValueError(
f"Incorrect type of audio_embeds. Got type: {type(audio_embeds)}"
)
audio_embeds_flat = flatten_bn(audio_embeds)
return MiniCPMOAudioEmbeddingInputs(
type="audio_embeds",
audio_embeds=audio_embeds_flat,
)
if not isinstance(audio_features, (torch.Tensor, list)):
raise ValueError(
f"Incorrect type of audio_features. Got type: {type(audio_features)}"
audio_embeds=audio_embeds,
)
audio_feature_lens = kwargs.pop("audio_feature_lens")
if not isinstance(audio_feature_lens, (torch.Tensor, list)):
raise ValueError(
"Incorrect type of audio_feature_lens. "
f"Got type: {type(audio_feature_lens)}"
)
audio_features_flat = flatten_bn(audio_features)
audio_feature_lens_flat = flatten_bn(audio_feature_lens)
return MiniCPMOAudioFeatureInputs(
type="audio_features",
audio_features=audio_features_flat,
audio_feature_lens=audio_feature_lens_flat,
audio_features=audio_features,
audio_feature_lens=audio_feature_lens,
)
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:

@ -114,7 +114,7 @@ class MiniCPMVImagePixelInputs(TensorSchema):
type: Literal["pixel_values"] = "pixel_values"
# Note that the image size may vary, so we pass it as a list instead of a
# Note that the patch size may vary, so we pass it as a list instead of a
# batched tensor.
pixel_values: Annotated[
list[torch.Tensor],
@ -453,12 +453,6 @@ def get_version_by_config(config: PretrainedConfig) -> tuple[int, ...]:
def _minicpmv_field_config(hf_inputs: Mapping[str, torch.Tensor]):
pixel_values = hf_inputs.get("pixel_values", torch.empty(0))
num_images = len(pixel_values)
video_pixel_values = hf_inputs.get("video_pixel_values", torch.empty(0))
num_videos = len(video_pixel_values)
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
image_sizes=MultiModalFieldConfig.batched("image"),
@ -468,8 +462,6 @@ def _minicpmv_field_config(hf_inputs: Mapping[str, torch.Tensor]):
video_image_sizes=MultiModalFieldConfig.batched("video"),
video_tgt_sizes=MultiModalFieldConfig.batched("video"),
video_embeds=MultiModalFieldConfig.batched("video"),
image_token_id=MultiModalFieldConfig.shared("image", num_images),
video_token_id=MultiModalFieldConfig.shared("video", num_videos),
)
@ -792,10 +784,6 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
out_keys={"pixel_values", "image_sizes", "tgt_sizes"},
)
tokenizer = self.info.get_tokenizer()
unk_token_id = tokenizer.get_vocab()["<unk>"]
image_inputs["image_token_id"] = torch.tensor(unk_token_id)
return image_inputs
def process_videos(
@ -831,10 +819,6 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
video_inputs = {f"video_{k}": v for k, v in video_inputs.items()}
tokenizer = self.info.get_tokenizer()
unk_token_id = tokenizer.get_vocab()["<unk>"]
video_inputs["video_token_id"] = torch.tensor(unk_token_id)
return video_inputs
def process_mm_inputs(
@ -1021,6 +1005,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
instantiated.
"""
merge_by_field_config = True
supports_encoder_tp_data = True
@classmethod
@ -1066,7 +1052,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
prefix=maybe_prefix(prefix, "resampler"),
)
self.mm_token_ids = set[int]()
self.make_empty_intermediate_tensors = self.llm.make_empty_intermediate_tensors
def _parse_and_validate_vision_input(
@ -1080,43 +1065,17 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
if pixel_values is None and image_embeds is None:
return None
image_token_id = kwargs.pop("image_token_id")
if image_token_id is not None:
assert isinstance(image_token_id, torch.Tensor)
self.mm_token_ids.add(image_token_id.flatten().unique().item())
if image_embeds is not None:
if not isinstance(image_embeds, (torch.Tensor, list)):
raise ValueError(
f"Incorrect type of image_embeds for {modality=}. "
f"Got type: {type(image_embeds)}"
)
image_embeds_flat = flatten_bn(image_embeds)
return MiniCPMVImageEmbeddingInputs(
type="image_embeds",
image_embeds=image_embeds_flat,
)
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError(
f"Incorrect type of pixel_values for {modality=}. "
f"Got type: {type(pixel_values)}"
image_embeds=image_embeds,
)
tgt_sizes = kwargs.pop("tgt_sizes")
if not isinstance(tgt_sizes, (torch.Tensor, list)):
raise ValueError(
f"Incorrect type of tgt_sizes for {modality=}. "
f"Got type: {type(tgt_sizes)}"
)
num_slices = [[len(p) for p in ps] for ps in pixel_values]
num_slices_flat = flatten_bn(torch.tensor(num_slices))
pixel_values_flat = flatten_bn(flatten_2d_lists(pixel_values))
tgt_sizes_flat = flatten_bn(flatten_2d_lists(tgt_sizes), concat=True)
num_slices_flat = torch.tensor([len(ps) for ps in pixel_values])
pixel_values_flat = flatten_bn(pixel_values)
tgt_sizes_flat = flatten_bn(tgt_sizes, concat=True)
return MiniCPMVImagePixelInputs(
type="pixel_values",
@ -1142,15 +1101,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
input_key in ("video_pixel_values", "video_embeds")
and "videos" not in modalities
):
def _image_key(video_key: str):
if video_key == "video_token_id":
return "image_token_id"
return video_key.removeprefix("video_")
modalities["videos"] = self._parse_and_validate_vision_input(
"videos", **{_image_key(k): v for k, v in kwargs.items()}
"videos", **{k.removeprefix("video_"): v for k, v in kwargs.items()}
)
return modalities
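
With merge_by_field_config = True the model already receives per-item tensors, so the extra flatten_bn/flatten_2d_lists level disappears and slice counts can be read with a plain len(). A small sketch under the assumption that pixel_values arrives as one stacked tensor per image:

import torch

# Hypothetical per-image inputs: each entry stacks the slices of one image.
pixel_values = [torch.randn(3, 3, 14, 14), torch.randn(5, 3, 14, 14)]

# Counting slices per image is now a plain len(), matching the simplified
# num_slices_flat computation above.
num_slices = torch.tensor([len(ps) for ps in pixel_values])
print(num_slices)  # tensor([3, 5])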

@ -71,7 +71,7 @@ from .interfaces import (
SupportsPP,
)
from .llama4 import Llama4ForCausalLM
from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix
from .utils import AutoWeightsLoader, maybe_prefix
from .vision import run_dp_sharded_vision_model
@ -86,7 +86,7 @@ class Llama4ImagePatchInputs(TensorSchema):
type: Literal["pixel_values"] = "pixel_values"
flat_data: Annotated[
pixel_values: Annotated[
torch.Tensor,
TensorShape("total_num_chunks", "num_channels", "image_size", "image_size"),
]
@ -96,7 +96,7 @@ class Llama4ImagePatchInputs(TensorSchema):
The number of total patches for each image in the batch.
This is used to split the embeddings which has the first two dimensions
flattened just like `flat_data`.
flattened just like `pixel_values`.
"""
aspect_ratios: Annotated[torch.Tensor, TensorShape("batch_size", 2)]
@ -725,6 +725,8 @@ class Mllama4DummyInputsBuilder(BaseDummyInputsBuilder[Mllama4ProcessingInfo]):
class Llama4ForConditionalGeneration(
nn.Module, SupportsMultiModal, SupportsPP, SupportsEagle3
):
merge_by_field_config = True
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"],
@ -798,17 +800,12 @@ class Llama4ForConditionalGeneration(
if pixel_values is None:
return None
# num_images x num_chunks, channel, image_size, image_size
# TODO: confirm handling for variable lengths
flat_pixel_values = flatten_bn(pixel_values, concat=True)
patches_per_image = flatten_bn(kwargs.pop("patches_per_image"))
patches_per_image = kwargs.pop("patches_per_image")
aspect_ratios = kwargs.pop("aspect_ratios")
if aspect_ratios.ndim == 3:
aspect_ratios = aspect_ratios.squeeze(1)
return Llama4ImagePatchInputs(
type="pixel_values",
flat_data=flat_pixel_values,
pixel_values=pixel_values,
patches_per_image=patches_per_image,
aspect_ratios=aspect_ratios,
)
@ -817,16 +814,16 @@ class Llama4ForConditionalGeneration(
self, image_input: Llama4ImagePatchInputs
) -> MultiModalEmbeddings:
assert self.vision_model and self.multi_modal_projector
flat_data = image_input["flat_data"]
pixel_values = image_input["pixel_values"]
patches_per_image = image_input["patches_per_image"].tolist()
# shard image input
if self.use_data_parallel:
vision_embeddings_flat = run_dp_sharded_vision_model(
flat_data, self.vision_model
pixel_values, self.vision_model
)
else:
vision_embeddings_flat = self.vision_model(flat_data)
vision_embeddings_flat = self.vision_model(pixel_values)
vision_embeddings_flat = self.multi_modal_projector(vision_embeddings_flat)
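
After the rename, pixel_values is already the flat (total_num_chunks, C, H, W) tensor, and patches_per_image records how many chunks each image contributed so the flat vision embeddings can be regrouped per image downstream. A sketch of that split (shapes invented):

import torch

total_chunks, hidden_size = 8, 4
vision_embeddings_flat = torch.randn(total_chunks, hidden_size)

# Chunks contributed by each of the three images in this batch.
patches_per_image = torch.tensor([2, 5, 1])

# torch.split with a size list recovers the per-image embedding groups.
per_image = vision_embeddings_flat.split(patches_per_image.tolist())
print([tuple(t.shape) for t in per_image])  # [(2, 4), (5, 4), (1, 4)]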

@ -75,7 +75,6 @@ from .interfaces import (
from .utils import (
AutoWeightsLoader,
WeightsMapper,
flatten_bn,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory,
make_layers,
@ -97,28 +96,19 @@ class MolmoImageInputs(TensorSchema):
"""
Dimensions:
- bn: Batch size * number of images
- nc: Number of crops (dynamic)
- bnc: Batch size * number of images * number of crops (dynamic)
- np: Number of patches
- tp: Token sequence positions
- pd: Patch dimension
"""
images: Annotated[
torch.Tensor | list[torch.Tensor],
TensorShape("bn", "nc", "np", "pd", dynamic_dims={"nc"}),
]
# Number of crops may vary per batch and image, so pass it as a list.
images: Annotated[torch.Tensor, TensorShape("bnc", "np", "pd")]
image_masks: Annotated[
torch.Tensor | list[torch.Tensor] | None,
TensorShape("bn", "nc", "np", dynamic_dims={"nc"}),
]
image_masks: Annotated[torch.Tensor | None, TensorShape("bnc", "np")]
image_input_idx: Annotated[torch.Tensor, TensorShape("bnc", "tp")]
"""An index tensor that maps image features to their corresponding patch tokens."""
image_input_idx: Annotated[
torch.Tensor | list[torch.Tensor],
TensorShape("bn", "nc", "tp", dynamic_dims={"nc"}),
]
# An index tensor that maps image features to their corresponding patch tokens.
num_crops: Annotated[torch.Tensor, TensorShape("bn")]
@ -1363,6 +1353,8 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
class MolmoForCausalLM(
nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, SupportsQuant
):
merge_by_field_config = True
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={
# vision backbone mapping
@ -1451,18 +1443,12 @@ class MolmoForCausalLM(
if images is None:
return None
if not isinstance(num_crops, (torch.Tensor, list)):
raise ValueError(
f"Incorrect type of num_crops. Got type: {type(num_crops)}"
)
num_crops = flatten_bn(num_crops, concat=True)
img_patch_id = kwargs.pop("img_patch_id", None)
if not isinstance(img_patch_id, torch.Tensor):
raise ValueError(
f"Incorrect type of img_patch_id. Got type: {type(img_patch_id)}"
)
self.img_patch_id = img_patch_id.flatten().unique().item()
if isinstance(img_patch_id, torch.Tensor):
img_patch_id = img_patch_id.item()
assert isinstance(img_patch_id, int)
self.img_patch_id = img_patch_id
return MolmoImageInputs(
images=images,
@ -1481,17 +1467,9 @@ class MolmoForCausalLM(
num_crops = image_input["num_crops"]
# Call the vision backbone on the whole batch at once
images_flat = flatten_bn(images, concat=True)
image_masks_flat = (
None if image_masks is None else flatten_bn(image_masks, concat=True)
)
image_input_idx_flat = flatten_bn(image_input_idx, concat=True)
image_features_flat = self.vision_backbone(
images=images_flat.unsqueeze(0),
image_masks=(
None if image_masks_flat is None else image_masks_flat.unsqueeze(0)
),
image_features = self.vision_backbone(
images=images.unsqueeze(0),
image_masks=None if image_masks is None else image_masks.unsqueeze(0),
).squeeze(0)
# Only the features corresponding to patch tokens are relevant
@ -1499,8 +1477,8 @@ class MolmoForCausalLM(
results = []
num_crops_list = num_crops.tolist()
for feats, img_idx in zip(
image_features_flat.split(num_crops_list),
image_input_idx_flat.split(num_crops_list),
image_features.split(num_crops_list),
image_input_idx.split(num_crops_list),
):
is_valid = img_idx >= 0
valid_img_idx = img_idx[is_valid]
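
Molmo's inputs now arrive pre-flattened over batch and crops ("bnc"), so the vision backbone runs once over all crops, num_crops regroups the result per image, and image_input_idx marks which positions actually map to patch tokens. A standalone sketch of that regrouping and masking (all values are made up):

import torch

# Crop-level features: 3 crops in total (2 from image A, 1 from image B),
# 4 token positions per crop, embedding size 2.
image_features = torch.randn(3, 4, 2)
# Token index per position; -1 marks padding with no patch token behind it.
image_input_idx = torch.tensor([[0, 1, -1, -1],
                                [2, 3, 4, -1],
                                [0, 1, 2, 3]])
num_crops = [2, 1]

for feats, img_idx in zip(image_features.split(num_crops),
                          image_input_idx.split(num_crops)):
    is_valid = img_idx >= 0
    # Keep only features that map to a real patch token in the prompt.
    print(feats[is_valid].shape)  # (num_valid_tokens, 2)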

@ -11,7 +11,7 @@ import copy
import warnings
from abc import ABC, abstractmethod
from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Any, Literal, TypeAlias, TypedDict, TypeVar
from typing import Annotated, Any, Literal, TypeAlias, TypeVar
import numpy.typing as npt
import torch
@ -40,7 +40,6 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.nemotron_h import NemotronHForCausalLM
from vllm.model_executor.models.radio import RadioModel
from vllm.model_executor.models.utils import (
flatten_bn,
init_vllm_registered_model,
maybe_prefix,
)
@ -96,31 +95,35 @@ MAX_FRAMES = 16
DEFAULT_NUM_TILES = 12
class NanoNemotronVLImagePixelInputs(TypedDict):
class NanoNemotronVLImagePixelInputs(TensorSchema):
"""
Dimensions:
- bn: Batch size * number of images
- bnp: Batch size * number of images * (1 + num_patches)
- c: Number of channels (3)
- h: Height of each image patch
- w: Width of each image patch
"""
type: Literal["pixel_values"]
pixel_values_flat: torch.Tensor
pixel_values_flat: Annotated[torch.Tensor, TensorShape("bnp", 3, "h", "w")]
num_patches: Annotated[torch.Tensor, TensorShape("bn")]
class NanoNemotronVLImageEmbeddingInputs(TensorSchema):
"""
Shape:
`(batch_size * num_images * (1 + num_patches), num_channels, height, width)`
Dimensions:
- n: Number of images
- f: Total image feature size
- h: Hidden size (must match the hidden size of language model backbone)
"""
num_patches: torch.Tensor
"""Shape: `(batch_size * num_images)`"""
class NanoNemotronVLImageEmbeddinInputs(TypedDict):
type: Literal["image_embeds"]
data: torch.Tensor | list[torch.Tensor]
"""
A tensor of shape `(num_images, total_image_feature_size, hidden_size)`
or a list of tensors of shape `(total_image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
"""
data: Annotated[torch.Tensor | list[torch.Tensor], TensorShape("n", "f", "h")]
NanoNemotronVLImageInputs: TypeAlias = (
NanoNemotronVLImagePixelInputs | NanoNemotronVLImageEmbeddinInputs
NanoNemotronVLImagePixelInputs | NanoNemotronVLImageEmbeddingInputs
)
@ -710,37 +713,12 @@ class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo):
class NanoNemotronBaseVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
"""Basic image-only MultiModalProcessor for InternVL-style models."""
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
processed_outputs = super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
)
hf_processor = self.info.get_hf_processor(**mm_kwargs)
image_token_id = hf_processor.image_token_id
# Since there may be extra tokens in the feature placeholders,
# we need to pass the image token ID to the model to select the
# tokens to merge from the vision encoder outputs
processed_outputs["image_token_id"] = torch.tensor(image_token_id)
return processed_outputs
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
num_images = len(image_num_patches)
return dict(
pixel_values_flat=MultiModalFieldConfig.flat_from_sizes(
@ -748,7 +726,6 @@ class NanoNemotronBaseVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
),
image_num_patches=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
image_token_id=MultiModalFieldConfig.shared("image", num_images),
)
def _get_prompt_updates(
@ -814,25 +791,6 @@ class NanoNemotronVLMultiModalProcessor(
):
"""MultiModalProcessor extended for video support"""
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
processed_outputs = super()._call_hf_processor(
prompt, mm_data, mm_kwargs, tok_kwargs
)
hf_processor = self.info.get_hf_processor(**mm_kwargs)
if (
self.info.supports_video
and (video_token_id := hf_processor.video_token_id) is not None
):
processed_outputs["video_token_id"] = torch.tensor(video_token_id)
return processed_outputs
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
@ -841,13 +799,12 @@ class NanoNemotronVLMultiModalProcessor(
image_fields = super()._get_mm_fields_config(hf_inputs, hf_processor_mm_kwargs)
if self.info.supports_video:
video_num_patches = hf_inputs.get("video_num_patches", torch.empty(0))
num_videos = len(video_num_patches)
video_fields = dict(
pixel_values_flat_video=MultiModalFieldConfig.flat_from_sizes(
"video", video_num_patches
),
video_num_patches=MultiModalFieldConfig.batched("video"),
video_token_id=MultiModalFieldConfig.shared("video", num_videos),
)
else:
video_fields = {}
@ -999,6 +956,8 @@ class NanoNemotronVLDummyInputsBuilder(
class NemotronH_Nano_VL_V2(
nn.Module, HasInnerState, IsHybrid, SupportsMultiModal, SupportsMultiModalPruning
):
merge_by_field_config = True
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
if modality.startswith("image"):
@ -1051,8 +1010,6 @@ class NemotronH_Nano_VL_V2(
)
self.mlp1 = self.mlp1.to(self.language_model.config.torch_dtype)
self.img_context_token_id = None
self.video_context_token_id = None
self.config = config
self.model_config = vllm_config.model_config
@ -1106,37 +1063,12 @@ class NemotronH_Nano_VL_V2(
return None
if image_embeds is not None:
if not isinstance(image_embeds, (torch.Tensor, list)):
raise ValueError(
"Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}"
)
return NanoNemotronVLImageEmbeddinInputs(
return NanoNemotronVLImageEmbeddingInputs(
type="image_embeds",
data=flatten_bn(image_embeds),
data=image_embeds,
)
image_token_id = kwargs["image_token_id"]
assert isinstance(image_token_id, torch.Tensor)
self.img_context_token_id = image_token_id.flatten().unique().item()
if pixel_values_flat is not None:
if not isinstance(pixel_values_flat, (torch.Tensor, list)):
raise ValueError(
"Incorrect type of pixel values. "
f"Got type: {type(pixel_values_flat)}"
)
if not isinstance(image_num_patches, (torch.Tensor, list)):
raise ValueError(
"Incorrect type of image_num_patches. "
f"Got type: {type(image_num_patches)}"
)
pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
image_num_patches = flatten_bn(image_num_patches, concat=True)
return NanoNemotronVLImagePixelInputs(
type="pixel_values",
pixel_values_flat=pixel_values_flat,
@ -1285,28 +1217,10 @@ class NemotronH_Nano_VL_V2(
if video_embeds is not None:
return NanoNemotronVLVideoEmbeddingInputs(
type="video_embeds",
data=flatten_bn(video_embeds),
data=video_embeds,
)
video_token_id = kwargs["video_token_id"]
assert isinstance(video_token_id, torch.Tensor)
self.video_context_token_id = video_token_id.flatten().unique().item()
if pixel_values_flat_video is not None:
if not isinstance(pixel_values_flat_video, (torch.Tensor, list)):
raise ValueError(
"Incorrect type of pixel values. "
f"Got type: {type(pixel_values_flat_video)}"
)
if not isinstance(video_num_patches, (torch.Tensor, list)):
raise ValueError(
"Incorrect type of image_num_patches. "
f"Got type: {type(video_num_patches)}"
)
pixel_values_flat_video = flatten_bn(pixel_values_flat_video, concat=True)
video_num_patches = flatten_bn(video_num_patches, concat=True)
expected_h = expected_w = self.config.force_image_size
resolve_bindings = {"h": expected_h, "w": expected_w}
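
Following on from the schema sketch further up, the pixel-input schema here also uses resolve_bindings so the symbolic h/w dimensions are pinned to the configured image size; my reading is that the bindings are passed to the schema constructor and checked against the incoming tensor, so treat the construction below as an assumption rather than a confirmed API:

import torch
from typing import Annotated
from vllm.utils.tensor_schema import TensorSchema, TensorShape

class PixelInputsSketch(TensorSchema):
    """
    Dimensions:
        - bnp: Total number of patches across the batch
        - bn: Batch size * number of images
        - h, w: Patch height and width
    """
    pixel_values_flat: Annotated[torch.Tensor, TensorShape("bnp", 3, "h", "w")]
    num_patches: Annotated[torch.Tensor, TensorShape("bn")]

# Assumed usage: resolve_bindings fixes the symbolic dims to concrete sizes,
# so a tensor with the wrong patch resolution is rejected up front.
inputs = PixelInputsSketch(
    pixel_values_flat=torch.zeros(7, 3, 512, 512),
    num_patches=torch.tensor([3, 4]),
    resolve_bindings={"h": 512, "w": 512},
)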

@ -496,8 +496,11 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor
)
image_token_id = kwargs["image_token_id"]
assert isinstance(image_token_id, torch.Tensor)
self.img_context_token_id = image_token_id.flatten().unique().item()
if isinstance(image_token_id, torch.Tensor):
image_token_id = image_token_id.flatten().unique().item()
assert isinstance(image_token_id, int)
self.img_context_token_id = image_token_id
if pixel_values_flat is not None:
return InternVLImagePixelInputs(

@ -814,8 +814,11 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
)
image_token_id = kwargs["image_token_id"]
assert isinstance(image_token_id, torch.Tensor)
self.img_context_token_id = image_token_id.flatten().unique().item()
if isinstance(image_token_id, torch.Tensor):
image_token_id = image_token_id.flatten().unique().item()
assert isinstance(image_token_id, int)
self.img_context_token_id = image_token_id
if pixel_values_flat is not None:
return SkyworkR1VImagePixelInputs(

@ -432,7 +432,7 @@ def group_mm_kwargs_by_modality(
if device is not None:
mm_kwargs_group = json_map_leaves(
lambda x: x.to(device=device),
lambda x: x.to(device=device) if isinstance(x, torch.Tensor) else x,
mm_kwargs_group,
)
else:
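
The guard added here is presumably needed because, with merge_by_field_config, a kwargs group can contain non-tensor leaves, and calling .to(device=...) on those would fail. The real code routes this through json_map_leaves; a minimal standalone sketch of the same guard over a flat group (field names are placeholders):

import torch

def move_group_to_device(mm_kwargs_group: dict, device: torch.device) -> dict:
    # Only tensors are moved; non-tensor leaves pass through untouched.
    return {
        key: value.to(device=device) if isinstance(value, torch.Tensor) else value
        for key, value in mm_kwargs_group.items()
    }

group = {
    "pixel_values": torch.zeros(2, 3, 4, 4),
    "some_scalar_field": 42,  # placeholder non-tensor leaf
}
moved = move_group_to_device(group, torch.device("cpu"))
print(type(moved["some_scalar_field"]))  # <class 'int'>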