Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[Model] Use merge_by_field_config for MM models (M-N) (#26710)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
@@ -631,8 +631,11 @@ class InternS1ForConditionalGeneration(
             )

         image_token_id = kwargs["image_token_id"]
-        assert isinstance(image_token_id, torch.Tensor)
-        self.img_context_token_id = image_token_id.flatten().unique().item()
+        if isinstance(image_token_id, torch.Tensor):
+            image_token_id = image_token_id.flatten().unique().item()
+
+        assert isinstance(image_token_id, int)
+        self.img_context_token_id = image_token_id

         if pixel_values is not None:
             h, w = self.config.vision_config.image_size
@@ -665,8 +668,11 @@ class InternS1ForConditionalGeneration(
             )

         video_token_id = kwargs["video_token_id"]
-        assert isinstance(video_token_id, torch.Tensor)
-        self.video_context_token_id = video_token_id.flatten().unique().item()
+        if isinstance(video_token_id, torch.Tensor):
+            video_token_id = video_token_id.flatten().unique().item()
+
+        assert isinstance(video_token_id, int)
+        self.video_context_token_id = video_token_id

         if pixel_values_flat_video is not None:
             h, w = self.config.vision_config.image_size
@@ -1232,8 +1232,11 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA)
             )

         image_token_id = kwargs["image_token_id"]
-        assert isinstance(image_token_id, torch.Tensor)
-        self.img_context_token_id = image_token_id.flatten().unique().item()
+        if isinstance(image_token_id, torch.Tensor):
+            image_token_id = image_token_id.flatten().unique().item()
+
+        assert isinstance(image_token_id, int)
+        self.img_context_token_id = image_token_id

         if pixel_values_flat is not None:
             expected_h = expected_w = self.config.vision_config.image_size
@@ -1265,8 +1268,11 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA)
             )

         video_token_id = kwargs["video_token_id"]
-        assert isinstance(video_token_id, torch.Tensor)
-        self.video_context_token_id = video_token_id.flatten().unique().item()
+        if isinstance(video_token_id, torch.Tensor):
+            video_token_id = video_token_id.flatten().unique().item()
+
+        assert isinstance(video_token_id, int)
+        self.video_context_token_id = video_token_id

         if pixel_values_flat_video is not None:
             expected_h = expected_w = self.config.vision_config.image_size
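The same token-id normalization recurs across the InternS1, InternVL, Llama-Nemotron and SkyworkR1V hunks: the processor-provided `image_token_id` / `video_token_id` may now arrive either as a tensor or as a plain int, so it is collapsed to an int before being stored. A minimal standalone sketch of that pattern (the `normalize_token_id` helper name is illustrative, not from the diff; the token id value is arbitrary):

import torch


def normalize_token_id(token_id: int | torch.Tensor) -> int:
    """Collapse a (possibly repeated) token-id tensor to a single Python int."""
    if isinstance(token_id, torch.Tensor):
        # All entries hold the same id, so flatten/unique yields one element.
        token_id = token_id.flatten().unique().item()

    assert isinstance(token_id, int)
    return token_id


# Works for both input forms seen in the hunks above.
assert normalize_token_id(torch.tensor([[42, 42]])) == 42
assert normalize_token_id(42) == 42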
@@ -26,7 +26,7 @@
 import collections
 import collections.abc
 from collections.abc import Callable, Iterable, Mapping, Sequence
-from typing import Any, TypeAlias, TypedDict, cast
+from typing import Annotated, Any, TypeAlias, cast

 import numpy as np
 import torch
@@ -62,6 +62,7 @@ from vllm.multimodal.processing import (
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.midashenglm import DashengConfig
+from vllm.utils.tensor_schema import TensorSchema, TensorShape

 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
@@ -508,11 +509,16 @@ class AudioProjectorSubsample(nn.Module):


 # === Audio Inputs === #
-class MiDashengLMAudioInputs(TypedDict):
-    input_values: torch.Tensor
-    """Shape: `(num_audios, num_sampling_points)`"""
-    audio_length: torch.Tensor
-    """Shape: `(num_audios, 1)`"""
+class MiDashengLMAudioInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size * number of audios
+        - p: Number of sampling points
+    """
+
+    input_values: Annotated[torch.Tensor, TensorShape("n", "p")]
+    audio_length: Annotated[torch.Tensor, TensorShape("n")]


 class MiDashengLMProcessingInfo(BaseProcessingInfo):
@@ -676,6 +682,8 @@ class MiDashengLMMultiModalProcessor(
     dummy_inputs=MiDashengLMDummyInputsBuilder,
 )
 class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
+    merge_by_field_config = True
+
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -728,26 +736,6 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
             self.decoder.make_empty_intermediate_tensors
         )

-    def _validate_and_reshape_mm_tensor(
-        self, mm_input: object, name: str
-    ) -> torch.Tensor:
-        if not isinstance(mm_input, (torch.Tensor, list)):
-            raise ValueError(f"Incorrect type of {name}. Got type: {type(mm_input)}")
-        if isinstance(mm_input, torch.Tensor):
-            return mm_input.reshape(-1, *mm_input.shape[2:])
-
-        if name == "input_values":
-            max_length = max(tensor.shape[1] for tensor in mm_input)
-            padded_mm_input = [
-                torch.nn.functional.pad(tensor, (0, max_length - tensor.shape[1]))
-                if tensor.shape[1] < max_length
-                else tensor
-                for tensor in mm_input
-            ]
-            return torch.concat(padded_mm_input)
-
-        return torch.concat(mm_input)
-
     def _parse_and_validate_audio_input(
         self, **kwargs: object
     ) -> MiDashengLMAudioInputs | None:
@@ -756,16 +744,11 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):

         if input_values is None:
             return None
-        input_values = self._validate_and_reshape_mm_tensor(
-            input_values, "input_values"
-        )
-        audio_length = self._validate_and_reshape_mm_tensor(
-            audio_length, "audio_length"
-        )
-        if not isinstance(input_values, (torch.Tensor, list)):
-            raise ValueError(
-                "Incorrect type of audio input features. "
-                f"Got type: {type(input_values)}"
-            )
+
+        if isinstance(input_values, list):
+            input_values = torch.nn.utils.rnn.pad_sequence(
+                input_values,
+                batch_first=True,
+            )

         return MiDashengLMAudioInputs(
@@ -773,7 +756,10 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
             audio_length=audio_length,
         )

-    def _process_audio_input(self, audio_input: MiDashengLMAudioInputs) -> torch.Tensor:
+    def _process_audio_input(
+        self,
+        audio_input: MiDashengLMAudioInputs,
+    ) -> tuple[torch.Tensor, ...]:
         # Process audio through encoder and projector
         input_values = audio_input["input_values"]
         audio_length = audio_input["audio_length"]
@@ -783,17 +769,13 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
         audio_embeddings = audio_embeddings.to(audio_input["input_values"].dtype)
         batch_size, max_audio_tokens, embed_dim = audio_embeddings.shape

-        audio_length_np = (
-            audio_length.cpu().numpy()
-            if isinstance(audio_length, torch.Tensor)
-            else audio_length
-        )
         audio_output_lengths = [
             max(1, calculate_mel_frames_dasheng(int(length)))  # at least one frame
-            for length in audio_length_np
+            for length in audio_length.tolist()
         ]
-        audio_output_lengths = torch.tensor(audio_output_lengths).to(
-            audio_embeddings.device
+        audio_output_lengths = torch.tensor(
+            audio_output_lengths,
+            device=audio_embeddings.device,
         )

         audio_feature_mask = torch.arange(
@@ -826,14 +808,6 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
     ) -> torch.Tensor | IntermediateTensors:
         if intermediate_tensors is not None:
             inputs_embeds = None
-        elif inputs_embeds is None:
-            multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
-            inputs_embeds = self.get_input_embeddings(
-                input_ids,
-                multimodal_embeddings,
-                is_multimodal=input_ids == self.config.audio_token_id,
-            )
-            input_ids = None

         return self.decoder.model(
             input_ids,
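In the MiDashengLM hunks, the custom `_validate_and_reshape_mm_tensor` padding loop is replaced by `torch.nn.utils.rnn.pad_sequence` when the audio inputs arrive as a list of per-item tensors. A small self-contained sketch of that padding step (clip lengths are illustrative):

import torch

# Three audio clips of different lengths, each a 1-D waveform tensor.
clips = [torch.randn(16000), torch.randn(12000), torch.randn(8000)]

# pad_sequence right-pads every clip with zeros to the longest length and
# stacks them into a single (num_audios, max_len) batch tensor.
batch = torch.nn.utils.rnn.pad_sequence(clips, batch_first=True)

assert batch.shape == (3, 16000)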
@@ -71,7 +71,7 @@ from .minicpmv import (
     MiniCPMVProcessingInfo,
     _minicpmv_field_config,
 )
-from .utils import AutoWeightsLoader, cast_overflow_tensors, flatten_bn, maybe_prefix
+from .utils import AutoWeightsLoader, cast_overflow_tensors, maybe_prefix

 CPU_DEVICE = torch.device("cpu")

@@ -132,15 +132,11 @@ MiniCPMOAudioInputs: TypeAlias = (


 def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]):
-    audio_features = hf_inputs.get("audio_features", torch.empty(0))
-    num_audios = len(audio_features)
-
     return dict(
         **_minicpmv_field_config(hf_inputs),
         audio_features=MultiModalFieldConfig.batched("audio"),
         audio_feature_lens=MultiModalFieldConfig.batched("audio"),
         audio_embeds=MultiModalFieldConfig.batched("audio"),
-        audio_token_id=MultiModalFieldConfig.shared("audio", num_audios),
     )


@@ -332,10 +328,6 @@ class MiniCPMOMultiModalProcessor(MiniCPMVMultiModalProcessor[MiniCPMOProcessing
         ]
         audio_inputs["audio_features"] = unpadded_audio_features

-        tokenizer = self.info.get_tokenizer()
-        unk_token_id = tokenizer.get_vocab()["<unk>"]
-        audio_inputs["audio_token_id"] = torch.tensor(unk_token_id)
-
         return audio_inputs

     def process_mm_inputs(
@@ -436,12 +428,10 @@ class MiniCPMWhisperEncoderLayer(nn.Module):
         attention_mask: torch.Tensor,
     ) -> torch.Tensor:
         residual = hidden_states
-        past_key_values = None
         hidden_states = self.self_attn_layer_norm(hidden_states)
-        hidden_states, attn_weights, past_key_values = self.self_attn(
+        hidden_states, _ = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
-            past_key_value=past_key_values,
         )
         hidden_states = nn.functional.dropout(
             hidden_states, p=self.dropout, training=self.training
@@ -567,8 +557,6 @@ class MiniCPMO(MiniCPMV2_6):
             vllm_config=vllm_config, prefix=maybe_prefix(prefix, "apm")
         )

-        self.audio_token_id = None
-
     def init_audio_module(self, *, vllm_config: VllmConfig, prefix: str = ""):
         # Do not use parameters temporarily
         audio_config = self.config.audio_config
@@ -731,43 +719,18 @@ class MiniCPMO(MiniCPMV2_6):
         if audio_features is None and audio_embeds is None:
             return None

-        audio_token_id = kwargs.pop("audio_token_id")
-        if audio_token_id is not None:
-            assert isinstance(audio_token_id, torch.Tensor)
-            self.mm_token_ids.add(audio_token_id.flatten().unique().item())
-
         if audio_embeds is not None:
-            if not isinstance(audio_embeds, (torch.Tensor, list)):
-                raise ValueError(
-                    f"Incorrect type of audio_embeds. Got type: {type(audio_embeds)}"
-                )
-
-            audio_embeds_flat = flatten_bn(audio_embeds)
-
             return MiniCPMOAudioEmbeddingInputs(
                 type="audio_embeds",
-                audio_embeds=audio_embeds_flat,
+                audio_embeds=audio_embeds,
             )

-        if not isinstance(audio_features, (torch.Tensor, list)):
-            raise ValueError(
-                f"Incorrect type of audio_features. Got type: {type(audio_features)}"
-            )
-
         audio_feature_lens = kwargs.pop("audio_feature_lens")
-        if not isinstance(audio_feature_lens, (torch.Tensor, list)):
-            raise ValueError(
-                "Incorrect type of audio_feature_lens. "
-                f"Got type: {type(audio_feature_lens)}"
-            )
-
-        audio_features_flat = flatten_bn(audio_features)
-        audio_feature_lens_flat = flatten_bn(audio_feature_lens)

         return MiniCPMOAudioFeatureInputs(
             type="audio_features",
-            audio_features=audio_features_flat,
-            audio_feature_lens=audio_feature_lens_flat,
+            audio_features=audio_features,
+            audio_feature_lens=audio_feature_lens,
         )

     def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
@@ -114,7 +114,7 @@ class MiniCPMVImagePixelInputs(TensorSchema):

     type: Literal["pixel_values"] = "pixel_values"

-    # Note that the image size may vary, so we pass it as a list instead of a
+    # Note that the patch size may vary, so we pass it as a list instead of a
     # batched tensor.
     pixel_values: Annotated[
         list[torch.Tensor],
@@ -453,12 +453,6 @@ def get_version_by_config(config: PretrainedConfig) -> tuple[int, ...]:


 def _minicpmv_field_config(hf_inputs: Mapping[str, torch.Tensor]):
-    pixel_values = hf_inputs.get("pixel_values", torch.empty(0))
-    num_images = len(pixel_values)
-
-    video_pixel_values = hf_inputs.get("video_pixel_values", torch.empty(0))
-    num_videos = len(video_pixel_values)
-
     return dict(
         pixel_values=MultiModalFieldConfig.batched("image"),
         image_sizes=MultiModalFieldConfig.batched("image"),
@@ -468,8 +462,6 @@ def _minicpmv_field_config(hf_inputs: Mapping[str, torch.Tensor]):
         video_image_sizes=MultiModalFieldConfig.batched("video"),
         video_tgt_sizes=MultiModalFieldConfig.batched("video"),
         video_embeds=MultiModalFieldConfig.batched("video"),
-        image_token_id=MultiModalFieldConfig.shared("image", num_images),
-        video_token_id=MultiModalFieldConfig.shared("video", num_videos),
     )


@@ -792,10 +784,6 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
             out_keys={"pixel_values", "image_sizes", "tgt_sizes"},
         )

-        tokenizer = self.info.get_tokenizer()
-        unk_token_id = tokenizer.get_vocab()["<unk>"]
-        image_inputs["image_token_id"] = torch.tensor(unk_token_id)
-
         return image_inputs

     def process_videos(
@@ -831,10 +819,6 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):

         video_inputs = {f"video_{k}": v for k, v in video_inputs.items()}

-        tokenizer = self.info.get_tokenizer()
-        unk_token_id = tokenizer.get_vocab()["<unk>"]
-        video_inputs["video_token_id"] = torch.tensor(unk_token_id)
-
         return video_inputs

     def process_mm_inputs(
@@ -1021,6 +1005,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
     instantiated.
     """

+    merge_by_field_config = True
+
     supports_encoder_tp_data = True

     @classmethod
@@ -1066,7 +1052,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
             prefix=maybe_prefix(prefix, "resampler"),
         )

-        self.mm_token_ids = set[int]()
         self.make_empty_intermediate_tensors = self.llm.make_empty_intermediate_tensors

     def _parse_and_validate_vision_input(
@@ -1080,43 +1065,17 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
         if pixel_values is None and image_embeds is None:
             return None

-        image_token_id = kwargs.pop("image_token_id")
-        if image_token_id is not None:
-            assert isinstance(image_token_id, torch.Tensor)
-            self.mm_token_ids.add(image_token_id.flatten().unique().item())
-
         if image_embeds is not None:
-            if not isinstance(image_embeds, (torch.Tensor, list)):
-                raise ValueError(
-                    f"Incorrect type of image_embeds for {modality=}. "
-                    f"Got type: {type(image_embeds)}"
-                )
-
-            image_embeds_flat = flatten_bn(image_embeds)
-
             return MiniCPMVImageEmbeddingInputs(
                 type="image_embeds",
-                image_embeds=image_embeds_flat,
+                image_embeds=image_embeds,
             )

-        if not isinstance(pixel_values, (torch.Tensor, list)):
-            raise ValueError(
-                f"Incorrect type of pixel_values for {modality=}. "
-                f"Got type: {type(pixel_values)}"
-            )
-
         tgt_sizes = kwargs.pop("tgt_sizes")
-        if not isinstance(tgt_sizes, (torch.Tensor, list)):
-            raise ValueError(
-                f"Incorrect type of tgt_sizes for {modality=}. "
-                f"Got type: {type(tgt_sizes)}"
-            )
-
-        num_slices = [[len(p) for p in ps] for ps in pixel_values]
-        num_slices_flat = flatten_bn(torch.tensor(num_slices))
-
-        pixel_values_flat = flatten_bn(flatten_2d_lists(pixel_values))
-        tgt_sizes_flat = flatten_bn(flatten_2d_lists(tgt_sizes), concat=True)
+
+        num_slices_flat = torch.tensor([len(ps) for ps in pixel_values])
+        pixel_values_flat = flatten_bn(pixel_values)
+        tgt_sizes_flat = flatten_bn(tgt_sizes, concat=True)

         return MiniCPMVImagePixelInputs(
             type="pixel_values",
@@ -1142,15 +1101,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
                 input_key in ("video_pixel_values", "video_embeds")
                 and "videos" not in modalities
             ):
-
-                def _image_key(video_key: str):
-                    if video_key == "video_token_id":
-                        return "image_token_id"
-
-                    return video_key.removeprefix("video_")
-
                 modalities["videos"] = self._parse_and_validate_vision_input(
-                    "videos", **{_image_key(k): v for k, v in kwargs.items()}
+                    "videos", **{k.removeprefix("video_"): v for k, v in kwargs.items()}
                 )

         return modalities
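The MiniCPM-V pixel-value hunk replaces the nested `flatten_2d_lists`/`flatten_bn` handling with a per-image slice count plus a flat collection of slice tensors. A rough standalone sketch of that bookkeeping, using plain Python and torch instead of the vLLM helpers (the list-of-lists layout is an assumption for illustration):

import torch

# One entry per image; each image was cut into a variable number of slices.
pixel_values = [
    [torch.randn(3, 14, 14) for _ in range(2)],  # image 0 -> 2 slices
    [torch.randn(3, 14, 14) for _ in range(5)],  # image 1 -> 5 slices
]

# Record how many slices each image contributed ...
num_slices_flat = torch.tensor([len(ps) for ps in pixel_values])

# ... and keep the slices themselves as one flat list, so they can later be
# regrouped per image with split-style logic.
pixel_values_flat = [p for ps in pixel_values for p in ps]

assert num_slices_flat.tolist() == [2, 5]
assert len(pixel_values_flat) == 7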
@@ -71,7 +71,7 @@ from .interfaces import (
     SupportsPP,
 )
 from .llama4 import Llama4ForCausalLM
-from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix
+from .utils import AutoWeightsLoader, maybe_prefix
 from .vision import run_dp_sharded_vision_model


@@ -86,7 +86,7 @@ class Llama4ImagePatchInputs(TensorSchema):

     type: Literal["pixel_values"] = "pixel_values"

-    flat_data: Annotated[
+    pixel_values: Annotated[
         torch.Tensor,
         TensorShape("total_num_chunks", "num_channels", "image_size", "image_size"),
     ]
@@ -96,7 +96,7 @@ class Llama4ImagePatchInputs(TensorSchema):
     The number of total patches for each image in the batch.

     This is used to split the embeddings which has the first two dimensions
-    flattened just like `flat_data`.
+    flattened just like `pixel_values`.
     """

     aspect_ratios: Annotated[torch.Tensor, TensorShape("batch_size", 2)]
@@ -725,6 +725,8 @@ class Mllama4DummyInputsBuilder(BaseDummyInputsBuilder[Mllama4ProcessingInfo]):
 class Llama4ForConditionalGeneration(
     nn.Module, SupportsMultiModal, SupportsPP, SupportsEagle3
 ):
+    merge_by_field_config = True
+
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
         "gate_up_proj": ["gate_proj", "up_proj"],
@@ -798,17 +800,12 @@ class Llama4ForConditionalGeneration(
         if pixel_values is None:
             return None

-        # num_images x num_chunks, channel, image_size, image_size
-        # TODO: confirm handling for variable lengths
-        flat_pixel_values = flatten_bn(pixel_values, concat=True)
-        patches_per_image = flatten_bn(kwargs.pop("patches_per_image"))
+        patches_per_image = kwargs.pop("patches_per_image")
         aspect_ratios = kwargs.pop("aspect_ratios")
-        if aspect_ratios.ndim == 3:
-            aspect_ratios = aspect_ratios.squeeze(1)

         return Llama4ImagePatchInputs(
             type="pixel_values",
-            flat_data=flat_pixel_values,
+            pixel_values=pixel_values,
             patches_per_image=patches_per_image,
             aspect_ratios=aspect_ratios,
         )
@@ -817,16 +814,16 @@ class Llama4ForConditionalGeneration(
         self, image_input: Llama4ImagePatchInputs
     ) -> MultiModalEmbeddings:
         assert self.vision_model and self.multi_modal_projector
-        flat_data = image_input["flat_data"]
+        pixel_values = image_input["pixel_values"]
         patches_per_image = image_input["patches_per_image"].tolist()

         # shard image input
         if self.use_data_parallel:
             vision_embeddings_flat = run_dp_sharded_vision_model(
-                flat_data, self.vision_model
+                pixel_values, self.vision_model
             )
         else:
-            vision_embeddings_flat = self.vision_model(flat_data)
+            vision_embeddings_flat = self.vision_model(pixel_values)

         vision_embeddings_flat = self.multi_modal_projector(vision_embeddings_flat)
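The renamed `pixel_values` field in `Llama4ImagePatchInputs` stays flattened across images, and `patches_per_image` is what lets the flat vision embeddings be split back into per-image chunks (as the docstring in the hunk above notes). A minimal illustration with dummy tensors (sizes are made up):

import torch

# Flattened embeddings for 2 images: 3 + 5 = 8 patch chunks of dim 16.
vision_embeddings_flat = torch.randn(8, 16)
patches_per_image = torch.tensor([3, 5])

# Split the flat tensor back into one tensor per image.
per_image = vision_embeddings_flat.split(patches_per_image.tolist())

assert [t.shape[0] for t in per_image] == [3, 5]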
@@ -75,7 +75,6 @@ from .interfaces import (
 from .utils import (
     AutoWeightsLoader,
     WeightsMapper,
-    flatten_bn,
     is_pp_missing_parameter,
     make_empty_intermediate_tensors_factory,
     make_layers,
@@ -97,28 +96,19 @@ class MolmoImageInputs(TensorSchema):
     """
     Dimensions:
         - bn: Batch size * number of images
-        - nc: Number of crops (dynamic)
+        - bnc: Batch size * number of images * number of crops (dynamic)
         - np: Number of patches
        - tp: Token sequence positions
         - pd: Patch dimension
     """

-    images: Annotated[
-        torch.Tensor | list[torch.Tensor],
-        TensorShape("bn", "nc", "np", "pd", dynamic_dims={"nc"}),
-    ]
-    # Number of crops may vary per batch and image, so pass it as a list.
+    images: Annotated[torch.Tensor, TensorShape("bnc", "np", "pd")]

-    image_masks: Annotated[
-        torch.Tensor | list[torch.Tensor] | None,
-        TensorShape("bn", "nc", "np", dynamic_dims={"nc"}),
-    ]
+    image_masks: Annotated[torch.Tensor | None, TensorShape("bnc", "np")]

-    image_input_idx: Annotated[
-        torch.Tensor | list[torch.Tensor],
-        TensorShape("bn", "nc", "tp", dynamic_dims={"nc"}),
-    ]
-    # An index tensor that maps image features to their corresponding patch tokens.
+    image_input_idx: Annotated[torch.Tensor, TensorShape("bnc", "tp")]
+    """An index tensor that maps image features to their corresponding patch tokens."""

     num_crops: Annotated[torch.Tensor, TensorShape("bn")]

@@ -1363,6 +1353,8 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
 class MolmoForCausalLM(
     nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, SupportsQuant
 ):
+    merge_by_field_config = True
+
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_substr={
             # vision backbone mapping
@@ -1451,18 +1443,12 @@ class MolmoForCausalLM(
         if images is None:
             return None

-        if not isinstance(num_crops, (torch.Tensor, list)):
-            raise ValueError(
-                f"Incorrect type of num_crops. Got type: {type(num_crops)}"
-            )
-        num_crops = flatten_bn(num_crops, concat=True)
-
         img_patch_id = kwargs.pop("img_patch_id", None)
-        if not isinstance(img_patch_id, torch.Tensor):
-            raise ValueError(
-                f"Incorrect type of img_patch_id. Got type: {type(img_patch_id)}"
-            )
-        self.img_patch_id = img_patch_id.flatten().unique().item()
+        if isinstance(img_patch_id, torch.Tensor):
+            img_patch_id = img_patch_id.item()
+
+        assert isinstance(img_patch_id, int)
+        self.img_patch_id = img_patch_id

         return MolmoImageInputs(
             images=images,
@@ -1481,17 +1467,9 @@ class MolmoForCausalLM(
         num_crops = image_input["num_crops"]

         # Call the vision backbone on the whole batch at once
-        images_flat = flatten_bn(images, concat=True)
-        image_masks_flat = (
-            None if image_masks is None else flatten_bn(image_masks, concat=True)
-        )
-        image_input_idx_flat = flatten_bn(image_input_idx, concat=True)
-
-        image_features_flat = self.vision_backbone(
-            images=images_flat.unsqueeze(0),
-            image_masks=(
-                None if image_masks_flat is None else image_masks_flat.unsqueeze(0)
-            ),
+        image_features = self.vision_backbone(
+            images=images.unsqueeze(0),
+            image_masks=None if image_masks is None else image_masks.unsqueeze(0),
         ).squeeze(0)

         # Only the features corresponding to patch tokens are relevant
@@ -1499,8 +1477,8 @@ class MolmoForCausalLM(
         results = []
         num_crops_list = num_crops.tolist()
         for feats, img_idx in zip(
-            image_features_flat.split(num_crops_list),
-            image_input_idx_flat.split(num_crops_list),
+            image_features.split(num_crops_list),
+            image_input_idx.split(num_crops_list),
         ):
             is_valid = img_idx >= 0
             valid_img_idx = img_idx[is_valid]
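In the Molmo hunks, the already-merged `image_features` and `image_input_idx` tensors are split per image by `num_crops`, and only positions with a non-negative index are kept. A small sketch of that selection (all sizes and index values are made up):

import torch

num_crops = torch.tensor([1, 2])        # image 0 -> 1 crop, image 1 -> 2 crops
image_features = torch.randn(3, 4, 8)   # (total crops, patches per crop, dim)
image_input_idx = torch.tensor([        # -1 marks padding positions
    [0, 1, 2, -1],
    [0, 1, -1, -1],
    [2, 3, 4, 5],
])

num_crops_list = num_crops.tolist()
for feats, img_idx in zip(
    image_features.split(num_crops_list),
    image_input_idx.split(num_crops_list),
):
    is_valid = img_idx >= 0
    # Keep only the features that map onto real patch tokens.
    valid_feats = feats[is_valid]
    assert valid_feats.shape[-1] == 8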
@@ -11,7 +11,7 @@ import copy
 import warnings
 from abc import ABC, abstractmethod
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Annotated, Any, Literal, TypeAlias, TypedDict, TypeVar
+from typing import Annotated, Any, Literal, TypeAlias, TypeVar

 import numpy.typing as npt
 import torch
@@ -40,7 +40,6 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.models.nemotron_h import NemotronHForCausalLM
 from vllm.model_executor.models.radio import RadioModel
 from vllm.model_executor.models.utils import (
-    flatten_bn,
     init_vllm_registered_model,
     maybe_prefix,
 )
@@ -96,31 +95,35 @@ MAX_FRAMES = 16
 DEFAULT_NUM_TILES = 12


-class NanoNemotronVLImagePixelInputs(TypedDict):
-    type: Literal["pixel_values"]
-    pixel_values_flat: torch.Tensor
-    """
-    Shape:
-    `(batch_size * num_images * (1 + num_patches), num_channels, height, width)`
-    """
-
-    num_patches: torch.Tensor
-    """Shape: `(batch_size * num_images)`"""
-
-
-class NanoNemotronVLImageEmbeddinInputs(TypedDict):
-    type: Literal["image_embeds"]
-    data: torch.Tensor | list[torch.Tensor]
-    """
-    A tensor of shape `(num_images, total_image_feature_size, hidden_size)`
-    or a list of tensors of shape `(total_image_feature_size, hidden_size)`
-
-    `hidden_size` must match the hidden size of language model backbone.
-    """
+class NanoNemotronVLImagePixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size * number of images
+        - bnp: Batch size * number of images * (1 + num_patches)
+        - c: Number of channels (3)
+        - h: Height of each image patch
+        - w: Width of each image patch
+    """
+
+    type: Literal["pixel_values"]
+    pixel_values_flat: Annotated[torch.Tensor, TensorShape("bnp", 3, "h", "w")]
+    num_patches: Annotated[torch.Tensor, TensorShape("bn")]
+
+
+class NanoNemotronVLImageEmbeddingInputs(TensorSchema):
+    """
+    Dimensions:
+        - n: Number of images
+        - f: Total image feature size
+        - h: Hidden size (must match the hidden size of language model backbone)
+    """
+
+    type: Literal["image_embeds"]
+    data: Annotated[torch.Tensor | list[torch.Tensor], TensorShape("n", "f", "h")]


 NanoNemotronVLImageInputs: TypeAlias = (
-    NanoNemotronVLImagePixelInputs | NanoNemotronVLImageEmbeddinInputs
+    NanoNemotronVLImagePixelInputs | NanoNemotronVLImageEmbeddingInputs
 )


@@ -710,37 +713,12 @@ class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo):
 class NanoNemotronBaseVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
     """Basic image-only MultiModalProcessor for InternVL-style models."""

-    def _call_hf_processor(
-        self,
-        prompt: str,
-        mm_data: Mapping[str, object],
-        mm_kwargs: Mapping[str, object],
-        tok_kwargs: Mapping[str, object],
-    ) -> BatchFeature:
-        processed_outputs = super()._call_hf_processor(
-            prompt=prompt,
-            mm_data=mm_data,
-            mm_kwargs=mm_kwargs,
-            tok_kwargs=tok_kwargs,
-        )
-
-        hf_processor = self.info.get_hf_processor(**mm_kwargs)
-        image_token_id = hf_processor.image_token_id
-
-        # Since there may be extra tokens in the feature placeholders,
-        # we need to pass the image token ID to the model to select the
-        # tokens to merge from the vision encoder outputs
-        processed_outputs["image_token_id"] = torch.tensor(image_token_id)
-
-        return processed_outputs
-
     def _get_mm_fields_config(
         self,
         hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
         image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
-        num_images = len(image_num_patches)

         return dict(
             pixel_values_flat=MultiModalFieldConfig.flat_from_sizes(
@@ -748,7 +726,6 @@ class NanoNemotronBaseVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
             ),
             image_num_patches=MultiModalFieldConfig.batched("image"),
             image_embeds=MultiModalFieldConfig.batched("image"),
-            image_token_id=MultiModalFieldConfig.shared("image", num_images),
         )

     def _get_prompt_updates(
@@ -814,25 +791,6 @@ class NanoNemotronVLMultiModalProcessor(
 ):
     """MultiModalProcessor extended for video support"""

-    def _call_hf_processor(
-        self,
-        prompt: str,
-        mm_data: Mapping[str, object],
-        mm_kwargs: Mapping[str, object],
-        tok_kwargs: Mapping[str, object],
-    ) -> BatchFeature:
-        processed_outputs = super()._call_hf_processor(
-            prompt, mm_data, mm_kwargs, tok_kwargs
-        )
-
-        hf_processor = self.info.get_hf_processor(**mm_kwargs)
-        if (
-            self.info.supports_video
-            and (video_token_id := hf_processor.video_token_id) is not None
-        ):
-            processed_outputs["video_token_id"] = torch.tensor(video_token_id)
-        return processed_outputs
-
     def _get_mm_fields_config(
         self,
         hf_inputs: BatchFeature,
@@ -841,13 +799,12 @@ class NanoNemotronVLMultiModalProcessor(
         image_fields = super()._get_mm_fields_config(hf_inputs, hf_processor_mm_kwargs)
         if self.info.supports_video:
             video_num_patches = hf_inputs.get("video_num_patches", torch.empty(0))
-            num_videos = len(video_num_patches)
+
             video_fields = dict(
                 pixel_values_flat_video=MultiModalFieldConfig.flat_from_sizes(
                     "video", video_num_patches
                 ),
                 video_num_patches=MultiModalFieldConfig.batched("video"),
-                video_token_id=MultiModalFieldConfig.shared("video", num_videos),
             )
         else:
             video_fields = {}
@@ -999,6 +956,8 @@ class NanoNemotronVLDummyInputsBuilder(
 class NemotronH_Nano_VL_V2(
     nn.Module, HasInnerState, IsHybrid, SupportsMultiModal, SupportsMultiModalPruning
 ):
+    merge_by_field_config = True
+
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> str | None:
         if modality.startswith("image"):
@@ -1051,8 +1010,6 @@ class NemotronH_Nano_VL_V2(
         )
         self.mlp1 = self.mlp1.to(self.language_model.config.torch_dtype)

-        self.img_context_token_id = None
-        self.video_context_token_id = None
         self.config = config
         self.model_config = vllm_config.model_config

@@ -1106,37 +1063,12 @@ class NemotronH_Nano_VL_V2(
             return None

         if image_embeds is not None:
-            if not isinstance(image_embeds, (torch.Tensor, list)):
-                raise ValueError(
-                    "Incorrect type of image embeddings. "
-                    f"Got type: {type(image_embeds)}"
-                )
-
-            return NanoNemotronVLImageEmbeddinInputs(
+            return NanoNemotronVLImageEmbeddingInputs(
                 type="image_embeds",
-                data=flatten_bn(image_embeds),
+                data=image_embeds,
             )

-        image_token_id = kwargs["image_token_id"]
-        assert isinstance(image_token_id, torch.Tensor)
-        self.img_context_token_id = image_token_id.flatten().unique().item()
-
         if pixel_values_flat is not None:
-            if not isinstance(pixel_values_flat, (torch.Tensor, list)):
-                raise ValueError(
-                    "Incorrect type of pixel values. "
-                    f"Got type: {type(pixel_values_flat)}"
-                )
-
-            if not isinstance(image_num_patches, (torch.Tensor, list)):
-                raise ValueError(
-                    "Incorrect type of image_num_patches. "
-                    f"Got type: {type(image_num_patches)}"
-                )
-
-            pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
-            image_num_patches = flatten_bn(image_num_patches, concat=True)
-
             return NanoNemotronVLImagePixelInputs(
                 type="pixel_values",
                 pixel_values_flat=pixel_values_flat,
@@ -1285,28 +1217,10 @@ class NemotronH_Nano_VL_V2(
         if video_embeds is not None:
             return NanoNemotronVLVideoEmbeddingInputs(
                 type="video_embeds",
-                data=flatten_bn(video_embeds),
+                data=video_embeds,
             )

-        video_token_id = kwargs["video_token_id"]
-        assert isinstance(video_token_id, torch.Tensor)
-        self.video_context_token_id = video_token_id.flatten().unique().item()
-
         if pixel_values_flat_video is not None:
-            if not isinstance(pixel_values_flat_video, (torch.Tensor, list)):
-                raise ValueError(
-                    "Incorrect type of pixel values. "
-                    f"Got type: {type(pixel_values_flat_video)}"
-                )
-
-            if not isinstance(video_num_patches, (torch.Tensor, list)):
-                raise ValueError(
-                    "Incorrect type of image_num_patches. "
-                    f"Got type: {type(video_num_patches)}"
-                )
-
-            pixel_values_flat_video = flatten_bn(pixel_values_flat_video, concat=True)
-            video_num_patches = flatten_bn(video_num_patches, concat=True)
             expected_h = expected_w = self.config.force_image_size
             resolve_bindings = {"h": expected_h, "w": expected_w}

@@ -496,8 +496,11 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor
         )

         image_token_id = kwargs["image_token_id"]
-        assert isinstance(image_token_id, torch.Tensor)
-        self.img_context_token_id = image_token_id.flatten().unique().item()
+        if isinstance(image_token_id, torch.Tensor):
+            image_token_id = image_token_id.flatten().unique().item()
+
+        assert isinstance(image_token_id, int)
+        self.img_context_token_id = image_token_id

         if pixel_values_flat is not None:
             return InternVLImagePixelInputs(
@@ -814,8 +814,11 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
         )

         image_token_id = kwargs["image_token_id"]
-        assert isinstance(image_token_id, torch.Tensor)
-        self.img_context_token_id = image_token_id.flatten().unique().item()
+        if isinstance(image_token_id, torch.Tensor):
+            image_token_id = image_token_id.flatten().unique().item()
+
+        assert isinstance(image_token_id, int)
+        self.img_context_token_id = image_token_id

         if pixel_values_flat is not None:
             return SkyworkR1VImagePixelInputs(
@@ -432,7 +432,7 @@ def group_mm_kwargs_by_modality(

         if device is not None:
             mm_kwargs_group = json_map_leaves(
-                lambda x: x.to(device=device),
+                lambda x: x.to(device=device) if isinstance(x, torch.Tensor) else x,
                 mm_kwargs_group,
             )
         else:
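The last hunk makes the device move a no-op for non-tensor leaves that can now appear in the grouped multimodal kwargs. A hedged sketch of the guarded mapping, using a local `map_leaves` stand-in rather than vLLM's `json_map_leaves`:

import torch


def map_leaves(fn, obj):
    """Apply fn to every leaf of a nested dict/list structure (local stand-in)."""
    if isinstance(obj, dict):
        return {k: map_leaves(fn, v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [map_leaves(fn, v) for v in obj]
    return fn(obj)


device = torch.device("cpu")
mm_kwargs_group = {"pixel_values": torch.randn(2, 3), "some_scalar": 42}

# Only tensors are moved; other leaves pass through unchanged.
moved = map_leaves(
    lambda x: x.to(device=device) if isinstance(x, torch.Tensor) else x,
    mm_kwargs_group,
)

assert moved["some_scalar"] == 42
assert moved["pixel_values"].device == device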