diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 19ce8c0672..35a5fa0c2e 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -638,7 +638,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `LlavaNextVideoForConditionalGeneration` | LLaVA-NeXT-Video | T + V | `llava-hf/LLaVA-NeXT-Video-7B-hf`, etc. | | ✅︎ | ✅︎ | | `LlavaOnevisionForConditionalGeneration` | LLaVA-Onevision | T + I+ + V+ | `llava-hf/llava-onevision-qwen2-7b-ov-hf`, `llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc. | | ✅︎ | ✅︎ | | `MiniCPMO` | MiniCPM-O | T + IE+ + VE+ + AE+ | `openbmb/MiniCPM-o-2_6`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `MiniCPMV` | MiniCPM-V | T + IE+ + VE+ | `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, `openbmb/MiniCPM-V-4`, etc. | ✅︎ | | ✅︎ | +| `MiniCPMV` | MiniCPM-V | T + IE+ + VE+ | `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, `openbmb/MiniCPM-V-4`, `openbmb/MiniCPM-V-4_5`, etc. | ✅︎ | | ✅︎ | | `MiniMaxVL01ForConditionalGeneration` | MiniMax-VL | T + IE+ | `MiniMaxAI/MiniMax-VL-01`, etc. | | ✅︎ | ✅︎ | | `Mistral3ForConditionalGeneration` | Mistral3 (HF Transformers) | T + I+ | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, etc. | ✅︎ | ✅︎ | ✅︎ | | `MllamaForConditionalGeneration` | Llama 3.2 | T + I+ | `meta-llama/Llama-3.2-90B-Vision-Instruct`, `meta-llama/Llama-3.2-11B-Vision`, etc. | | | | diff --git a/tests/models/registry.py b/tests/models/registry.py index f2c09d3e84..ee546e7af8 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -451,7 +451,7 @@ _MULTIMODAL_EXAMPLE_MODELS = { "MiniCPMO": _HfExamplesInfo("openbmb/MiniCPM-o-2_6", trust_remote_code=True), "MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-Llama3-V-2_5", - extras={"2.6": "openbmb/MiniCPM-V-2_6", "4.0": "openbmb/MiniCPM-V-4"}, # noqa: E501 + extras={"2.6": "openbmb/MiniCPM-V-2_6", "4.0": "openbmb/MiniCPM-V-4", "4.5": "openbmb/MiniCPM-V-4_5"}, # noqa: E501 trust_remote_code=True), "MiniMaxVL01ForConditionalGeneration": _HfExamplesInfo("MiniMaxAI/MiniMax-VL-01", # noqa: E501 trust_remote_code=True, diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index c22d871ab2..2d785c30fd 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -27,12 +27,14 @@ import math from collections import defaultdict from collections.abc import Iterable, Mapping, Sequence from functools import partial +from itertools import chain from typing import Annotated, Any, Callable, Literal, Optional, Union import numpy as np import torch import torch.types from torch import nn +from torch.nn.init import trunc_normal_ from transformers import BatchFeature, PretrainedConfig from typing_extensions import TypeVar @@ -47,10 +49,11 @@ from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.models.minicpm import MiniCPMForCausalLM from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM +from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems +from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - NestedTensors) + MultiModalKwargsItems, NestedTensors) from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem, ImageProcessorItems, ImageSize, ModalityData, ModalityDataItems, @@ -218,6 +221,187 @@ class Resampler2_5(BaseResampler): return x +class Resampler4_5(Resampler2_5): + + def __init__(self, + num_queries: int, + embed_dim: int, + num_heads: int, + kv_dim: Optional[int] = None, + norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, + max_size: tuple[int, int] = (70, 70), + max_temporal_size: int = 36000, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: + super().__init__(num_queries, + embed_dim, + num_heads, + kv_dim, + norm_layer, + max_size, + quant_config=quant_config, + prefix=prefix) + + trunc_normal_(self.query, std=.02) + self.max_temporal_size = max_temporal_size + self._set_temporal_pos_cache(self.max_temporal_size) + self.apply(self._init_weights) + + def get_1d_sincos_pos_embed_from_temporal_size(self, embed_dim: int, + pos: np.ndarray): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float32) + omega /= embed_dim / 2. + omega = 1. / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + def _set_temporal_pos_cache(self, + max_temporal_size: int, + device: torch.types.Device = "cpu") -> None: + temporal_size = np.arange(max_temporal_size, dtype=np.float32) + pos_embed = torch.from_numpy( + self.get_1d_sincos_pos_embed_from_temporal_size( + self.embed_dim, temporal_size)).float().to(device) + self.register_buffer("temporal_pos_embed", pos_embed, persistent=False) + + def _adjust_temporal_pos_cache(self, + max_temporal_size: int, + device: torch.types.Device = "cpu"): + if max_temporal_size > self.max_temporal_size: + self.max_temporal_size = max_temporal_size + self._set_temporal_pos_cache(self.max_temporal_size, device) + + def _init_weights(self, m: Union[nn.Linear, nn.LayerNorm]): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward( + self, + x: torch.Tensor, + tgt_sizes: torch.Tensor, + # temporal_ids for high refresh rate videos + temporal_ids=None + ) -> torch.Tensor: + assert x.shape[0] == tgt_sizes.shape[0] + bs = x.shape[0] + + device = x.device + dtype = x.dtype + + patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1] + + self._adjust_pos_cache(tgt_sizes, device=device) + + temporal_pos_emb = False + temporal_ids_flatten = None + if temporal_ids is not None: + # example: [[-1], [-1], [2, 6, 9]] + temporal_ids_flatten = list(chain.from_iterable(temporal_ids)) + max_temporal_size = max(temporal_ids_flatten, default=0) + if max_temporal_size > -1: + temporal_pos_emb = True + if max_temporal_size > self.max_temporal_size: + self._adjust_temporal_pos_cache(max_temporal_size, device) + + max_patch_len = patch_len.max().item() + assert isinstance(max_patch_len, int) + + key_padding_mask = torch.zeros((bs, max_patch_len), + dtype=torch.bool, + device=device) + + x, _ = self.kv_proj(x) # B * L * D + x = self.ln_kv(x).permute(1, 0, 2) # L * B * D + q = self.ln_q(self.query) # Q * D + + pos_embed_2d = [] + pos_embed_temporal = [] + for i in range(bs): + tgt_h, tgt_w = tgt_sizes[i] + if temporal_pos_emb: + if temporal_ids_flatten[i] == -1: + pos_embed_temporal.append( + torch.zeros(self.embed_dim, dtype=dtype, + device=device)) + else: + pos_embed_temporal.append(self.temporal_pos_embed[ + temporal_ids_flatten[i]].to(dtype)) # D + + pos_embed_2d.append(self.pos_embed[:tgt_h, :tgt_w, :].reshape( + (tgt_h * tgt_w, -1)).to(dtype)) # patches * D + key_padding_mask[i, patch_len[i]:] = True + + pos_embed_2d = torch.nn.utils.rnn.pad_sequence( + pos_embed_2d, batch_first=True, + padding_value=0.0).permute(1, 0, 2) # BLD => L * B * D + + k = x + v = x + pos_embed_2d + if pos_embed_temporal: + k += torch.stack(pos_embed_temporal, dim=0) + bs = len(temporal_ids) + merge_k = [] + merge_v = [] + merge_key_padding_mask = [] + + start = 0 + for tp in temporal_ids: + end = start + len(tp) + # L * (end-start) * D -> (end-start) * L * D + # -> 1 * L*(end-start) * D + merge_k.append(k[:, start:end, :].permute(1, 0, 2).reshape( + -1, self.embed_dim)) + merge_v.append(v[:, start:end, :].permute(1, 0, 2).reshape( + -1, self.embed_dim)) + merge_key_padding_mask.append( + key_padding_mask[start:end, :].reshape(-1, 1)) + + start = end + + k = torch.nn.utils.rnn.pad_sequence(merge_k, + batch_first=True, + padding_value=0.0).permute( + 1, 0, 2) # L*(end-start) + v = torch.nn.utils.rnn.pad_sequence(merge_v, + batch_first=True, + padding_value=0.0).permute( + 1, 0, 2) # L*(end-start) + key_padding_mask = torch.nn.utils.rnn.pad_sequence( + merge_key_padding_mask, batch_first=True, + padding_value=True).squeeze(-1) + + out = self.attn( + self._repeat(q, bs), # Q * B * D + k, # L * B * D + L * B * D + v, + key_padding_mask=key_padding_mask, + )[0] + # out: Q * B * D + x = out.permute(1, 0, 2) # B * Q * D + + x = self.ln_post(x) + x = x @ self.proj + return x + + def get_version_by_config(config: PretrainedConfig) -> tuple[int, ...]: version_float = getattr(config, "version", None) @@ -354,9 +538,7 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: mm_limits = {"image": None} - if self.get_model_version() == (2, - 6) or self.get_model_version() == (4, - 0): + if self.get_model_version() in {(2, 6), (4, 0), (4, 5)}: mm_limits["video"] = None return mm_limits @@ -637,8 +819,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): out_keys: set[str], ) -> dict[str, NestedTensors]: # This processor supports zipping prompt and mm_data together - if self.info.get_model_version() == ( - 2, 6) or self.info.get_model_version() == (4, 0): + if self.info.get_model_version() in {(2, 6), (4, 0), (4, 5)}: inputs = super()._call_hf_processor( prompt=prompts, # type: ignore mm_data=mm_data, @@ -816,7 +997,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): # and config class self.config = config self.multimodal_config = multimodal_config - self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" self.version = get_version_by_config(self.config) self.llm = self.init_llm(vllm_config=vllm_config, @@ -1364,11 +1544,9 @@ class MiniCPMV4_0(MiniCPMVBaseModel, SupportsLoRA): prefix: str = "", ) -> nn.Module: quant_config = self._maybe_ignore_quant_config(quant_config) - model = Idefics2VisionTransformer( - config.vision_config, - quant_config=quant_config, - prefix=prefix, - use_data_parallel=self.use_data_parallel) + model = Idefics2VisionTransformer(config.vision_config, + quant_config=quant_config, + prefix=prefix) if self.config.drop_vision_last_layer: model.encoder.layers = model.encoder.layers[:-1] return model @@ -1436,11 +1614,121 @@ class MiniCPMV4_0(MiniCPMVBaseModel, SupportsLoRA): return loader.load_weights(weights) +class MiniCPMV4_5(MiniCPMVBaseModel, SupportsLoRA): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + assert self.version == (4, 5) + + def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig): + if isinstance(quant_config, (AWQConfig, AWQMarlinConfig)): + return None + return quant_config + + def init_llm( + self, + vllm_config: VllmConfig, + prefix: str = "", + ) -> nn.Module: + return Qwen3ForCausalLM(vllm_config=vllm_config, prefix=prefix) + + def init_vision_module( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> nn.Module: + quant_config = self._maybe_ignore_quant_config(quant_config) + model = Idefics2VisionTransformer(config.vision_config, + quant_config=quant_config, + prefix=prefix) + if self.config.drop_vision_last_layer: + model.encoder.layers = model.encoder.layers[:-1] + return model + + def init_resampler( + self, + embed_dim: int, + vision_dim: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> nn.Module: + quant_config = self._maybe_ignore_quant_config(quant_config) + with set_default_torch_dtype(torch.float16): + # The resampler in 4.0 remains consistent with the one in 2.5/2.6. + resampler = Resampler4_5(num_queries=self.config.query_num, + embed_dim=embed_dim, + num_heads=embed_dim // 128, + kv_dim=vision_dim, + quant_config=quant_config, + prefix=prefix) + + return resampler.to(device=current_platform.device_type, + dtype=torch.get_default_dtype()) + + def get_vision_hidden_states( + self, data: MiniCPMVImagePixelInputs) -> torch.Tensor: + pixel_values = data["pixel_values"] + tgt_sizes = data["tgt_sizes"] + temporal_ids = data.get('temporal_ids', None) + + B = len(pixel_values) + P = pixel_values[0].shape[-2] + L = max(item.shape[-1] for item in pixel_values) + device = pixel_values[0].device + dtype = pixel_values[0].dtype + + all_pixel_values = torch.zeros((B, 3, P, L), + dtype=dtype, + device=device) + all_temporal_ids = None if temporal_ids is None else flatten_2d_lists( + temporal_ids) + for i, pixel_values_item in enumerate(pixel_values): + L_item = pixel_values_item.shape[-1] + all_pixel_values[i, ..., :L_item] = pixel_values_item + + num_patches = tgt_sizes.prod(-1) + max_patches = num_patches.max().item() + assert isinstance(max_patches, int) + + patch_attn_mask = torch.zeros((B, max_patches), + dtype=torch.bool, + device=device) + for i, num_patches_item in enumerate(num_patches): + patch_attn_mask[i, :num_patches_item] = True + + vision_embedding = self.vpm( + all_pixel_values, + patch_attention_mask=patch_attn_mask.unsqueeze(1), + tgt_sizes=tgt_sizes, + ) + + return self.resampler(vision_embedding, tgt_sizes, all_temporal_ids) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self, + skip_prefixes=["apm.", "audio", "tts"]) + return loader.load_weights(weights) + + _SUPPORT_VERSION = { (2, 0): MiniCPMV2_0, (2, 5): MiniCPMV2_5, (2, 6): MiniCPMV2_6, (4, 0): MiniCPMV4_0, + (4, 5): MiniCPMV4_5, } diff --git a/vllm/transformers_utils/chat_templates/registry.py b/vllm/transformers_utils/chat_templates/registry.py index e0ef7f0999..d09c5fa924 100644 --- a/vllm/transformers_utils/chat_templates/registry.py +++ b/vllm/transformers_utils/chat_templates/registry.py @@ -20,6 +20,16 @@ def _get_qwen_chat_template_fallback( return CHAT_TEMPLATES_DIR / "template_basic.jinja" +def _get_minicpmv_chat_template_fallback( + tokenizer_name_or_path: str) -> Optional[Path]: + # MiniCPM-V-4.5 version uses a dedicated template + if "4.5" in tokenizer_name_or_path or "4_5" in tokenizer_name_or_path: + return CHAT_TEMPLATES_DIR / "template_minicpmv45.jinja" + + # Other versions use chatml template + return CHAT_TEMPLATES_DIR / "template_chatml.jinja" + + # yapf: disable _MODEL_TYPE_TO_CHAT_TEMPLATE_FALLBACK: dict[str, ChatTemplatePath] = { "blip-2": CHAT_TEMPLATES_DIR / "template_blip2.jinja", @@ -27,6 +37,7 @@ _MODEL_TYPE_TO_CHAT_TEMPLATE_FALLBACK: dict[str, ChatTemplatePath] = { "deepseek_vl_v2": CHAT_TEMPLATES_DIR / "template_deepseek_vl2.jinja", "florence2": CHAT_TEMPLATES_DIR / "template_basic.jinja", "fuyu": CHAT_TEMPLATES_DIR / "template_fuyu.jinja", + "minicpmv": _get_minicpmv_chat_template_fallback, "paligemma": CHAT_TEMPLATES_DIR / "template_basic.jinja", "qwen": _get_qwen_chat_template_fallback, } diff --git a/vllm/transformers_utils/chat_templates/template_minicpmv45.jinja b/vllm/transformers_utils/chat_templates/template_minicpmv45.jinja new file mode 100644 index 0000000000..661ebd1cf5 --- /dev/null +++ b/vllm/transformers_utils/chat_templates/template_minicpmv45.jinja @@ -0,0 +1,93 @@ +{%- set enable_thinking = enable_thinking | default(false) %} +{%- if tools %} + {{- '<|im_start|>system\n' }} + {%- if messages[0].role == 'system' %} + {{- messages[0].content + '\n\n' }} + {%- endif %} + {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n" }} + {%- for tool in tools %} + {{- "\n" }} + {{- tool | tojson }} + {%- endfor %} + {{- "\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n<|im_end|>\n" }} +{%- else %} + {%- if messages[0].role == 'system' %} + {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }} + {%- endif %} +{%- endif %} + +{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %} +{%- for message in messages[::-1] %} + {%- set index = (messages|length - 1) - loop.index0 %} + {%- if ns.multi_step_tool and message.role == "user" and not(message.content.startswith('') and message.content.endswith('')) %} + {%- set ns.multi_step_tool = false %} + {%- set ns.last_query_index = index %} + {%- endif %} +{%- endfor %} + +{%- for message in messages %} + {%- if (message.role == "user") or (message.role == "system" and not loop.first) %} + {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }} + {%- elif message.role == "assistant" %} + {%- set content = message.content %} + {%- set reasoning_content = '' %} + {%- if message.reasoning_content is defined and message.reasoning_content is not none %} + {%- set reasoning_content = message.reasoning_content %} + {%- else %} + {%- if '' in message.content %} + {%- set content = message.content.split('')[-1].lstrip('\n') %} + {%- set reasoning_content = message.content.split('')[0].rstrip('\n').split('')[-1].lstrip('\n') %} + {%- endif %} + {%- endif %} + {%- if loop.index0 > ns.last_query_index %} + {%- if loop.last or (not loop.last and reasoning_content) %} + {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content.strip('\n') + '\n\n\n' + content.lstrip('\n') }} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + content }} + {%- endif %} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + content }} + {%- endif %} + + {%- if message.tool_calls %} + {%- for tool_call in message.tool_calls %} + {%- if (loop.first and content) or (not loop.first) %} + {{- '\n' }} + {%- endif %} + {%- if tool_call.function %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '\n{"name": "' }} + {{- tool_call.name }} + {{- '", "arguments": ' }} + {%- if tool_call.arguments is string %} + {{- tool_call.arguments }} + {%- else %} + {{- tool_call.arguments | tojson }} + {%- endif %} + {{- '}\n' }} + {%- endfor %} + {%- endif %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %} + {{- '<|im_start|>user' }} + {%- endif %} + {{- '\n\n' }} + {{- message.content }} + {{- '\n' }} + {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- endif %} +{%- endfor %} + +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} + {%- if enable_thinking is defined and enable_thinking is false %} + {{- '\n\n\n\n' }} + {%- endif %} + {%- if enable_thinking is defined and enable_thinking is true %} + {{- '\n' }} + {%- endif %} +{%- endif %} \ No newline at end of file