[Bugfix] Fix broken Minimax-01-VL model (#22116)

Signed-off-by: Isotr0py <2037008807@qq.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Isotr0py
2025-08-19 16:49:29 +08:00
committed by GitHub
parent 31436e8b4f
commit 31fd3265c8
3 changed files with 123 additions and 32 deletions

View File

@@ -889,6 +889,39 @@ def run_minicpmv(questions: list[str], modality: str) -> ModelRequestData:
    return run_minicpmv_base(questions, modality, "openbmb/MiniCPM-V-2_6")


def run_minimax_vl_01(questions: list[str], modality: str) -> ModelRequestData:
    assert modality == "image"

    model_name = "MiniMaxAI/MiniMax-VL-01"
    engine_args = EngineArgs(
        model=model_name,
        max_num_seqs=2,
        limit_mm_per_prompt={modality: 1},
        trust_remote_code=True,
        tensor_parallel_size=8,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    messages = [
        [
            {
                "role": "user",
                "content": [{"type": "image"}, {"type": "text", "text": question}],
            }
        ]
        for question in questions
    ]
    prompts = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=False
    )

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
    )


# Mistral-3 HF-format
def run_mistral3(questions: list[str], modality: str) -> ModelRequestData:
    assert modality == "image"
@@ -1539,6 +1572,7 @@ model_example_map = {
"mantis": run_mantis,
"minicpmo": run_minicpmo,
"minicpmv": run_minicpmv,
"minimax_vl_01": run_minimax_vl_01,
"mistral3": run_mistral3,
"mllama": run_mllama,
"molmo": run_molmo,

View File

@@ -30,7 +30,6 @@ from ..utils import dummy_hf_overrides
ARCH_TO_SKIP = {
    "MolmoForCausalLM": "incompatible requirements",
    "MiniMaxVL01ForConditionalGeneration": "broken model",
}
ARCH_NEEDS_EXTRAS = [
    "InternVLChatModel",

View File

@@ -1,11 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Mapping
from typing import Literal, Optional, TypedDict, Union, cast
from typing import Annotated, Literal, Optional, Union, cast
import torch
import torch.nn as nn
from transformers import BatchFeature, PretrainedConfig
from transformers.models.llava_next.modeling_llava_next import (
    get_anyres_image_grid_shape, unpad_image)
from vllm.config import VllmConfig
from vllm.model_executor.layers.activation import get_act_fn
@@ -17,6 +19,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig
from vllm.sequence import IntermediateTensors
from vllm.utils.jsontree import json_map_leaves
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .clip import CLIPVisionModel
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
@@ -29,24 +32,36 @@ from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
                    maybe_prefix, merge_multimodal_embeddings)


class MiniMaxVL01ImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    pixel_values: torch.Tensor
class MiniMaxVL01ImagePixelInputs(TensorSchema):
    """
    Shape: `(batch_size * num_images, num_channels, height, width)`
    Dimensions:
        - bn: Batch size * number of images
        - np: Number of patches + 1
        - c: Number of channels (3)
        - h: Height
        - w: Width

    Note that `height` or `width` may be different per batch and image,
    Note that `num_patches` may be different per batch and image,
    in which case the data is passed as a list instead of a batched tensor.
    """
    type: Literal["pixel_values"] = "pixel_values"

    pixel_values: Annotated[
        Union[torch.Tensor, list[torch.Tensor]],
        TensorShape("bn", "np", 3, "h", "w", dynamic_dims={"np", "h", "w"})]

    image_sizes: Annotated[Optional[torch.Tensor], TensorShape("bn", 2)]
    # This should be in `(height, width)` format.


class MiniMaxVL01ImageEmbeddingInputs(TypedDict):
    type: Literal["image_embeds"]
    data: torch.Tensor
    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    """
class MiniMaxVL01ImageEmbeddingInputs(TensorSchema):
    """
    Dimensions:
        - bn: Batch size * number of images
        - ifs: Image feature size
        - hs: Hidden size (must match language model backbone)
    """
    type: Literal["image_embeds"] = "image_embeds"

    data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")]


MiniMaxVL01ImageInputs = Union[MiniMaxVL01ImagePixelInputs,
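Aside: the Annotated/TensorShape declarations above replace the hand-rolled _validate_pixel_values check that is deleted further down. As a rough standalone illustration of the shape contract they express (this is not vLLM's TensorSchema implementation, just the rule it encodes), consider:

import torch


def check_pixel_inputs(pixel_values, image_sizes: torch.Tensor) -> None:
    # pixel_values: (bn, np, 3, h, w); np/h/w may vary per image, so a list
    # of per-image (np, 3, h, w) tensors is also accepted (dynamic_dims).
    # image_sizes: (bn, 2), each row being (height, width).
    per_image = list(pixel_values)
    assert image_sizes.shape == (len(per_image), 2)
    for item in per_image:
        num_patches, channels, height, width = item.shape
        assert channels == 3  # the only fixed dimension in the annotation


# Example: two images with different patch counts pass the check.
check_pixel_inputs(
    [torch.rand(5, 3, 336, 336), torch.rand(9, 3, 336, 336)],
    torch.tensor([[480, 640], [768, 1024]]),
)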
@@ -141,6 +156,7 @@ class MiniMaxVL01MultiModalProcessor(
    ) -> Mapping[str, MultiModalFieldConfig]:
        return {
            "pixel_values": MultiModalFieldConfig.batched("image"),
            "image_sizes": MultiModalFieldConfig.batched("image"),
            "image_embeds": MultiModalFieldConfig.batched("image"),
        }
@@ -239,7 +255,7 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(pixel_values)
        image_features = tuple(vision_tower(p) for p in pixel_values)

        def select_features(leaf: torch.Tensor):
            return self._select_image_features(
@@ -252,6 +268,56 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
            json_map_leaves(select_features, image_features),
        )

    # adapted from https://huggingface.co/MiniMaxAI/MiniMax-VL-01/blob/main/modeling_minimax_vl_01.py#L616-L631
    def pack_image_features(self, image_features: list[torch.Tensor],
                            image_sizes: torch.Tensor):
        new_image_features = []
        for image_idx, image_feature in enumerate(image_features):
            if image_feature.shape[0] > 1:
                base_image_feature = image_feature[0]
                image_feature = image_feature[1:]
                height = width = (self.config.vision_config.image_size //
                                  self.config.vision_config.patch_size)
                if height * width != base_image_feature.shape[0]:
                    raise ValueError(
                        "The number of patches is not consistent with "
                        "the image size.")

                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                    image_sizes[image_idx],
                    self.config.image_grid_pinpoints,
                    self.config.vision_config.image_size,
                )
                image_feature = image_feature.view(num_patch_height,
                                                   num_patch_width, height,
                                                   width, -1)
                image_feature = image_feature.permute(4, 0, 2, 1,
                                                      3).contiguous()
                image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                image_feature = unpad_image(image_feature,
                                            image_sizes[image_idx])
                image_feature = torch.cat(
                    (
                        image_feature,
                        self.image_newline[:, None, None].expand(
                            *image_feature.shape[:-1], 1).to(
                                image_feature.dtype),
                    ),
                    dim=-1,
                )
                image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                image_feature = torch.cat((base_image_feature, image_feature),
                                          dim=0)
            else:
                image_feature = image_feature[0]
                image_feature = torch.cat(
                    (image_feature,
                     self.image_newline[None].to(image_feature)),
                    dim=0)
            new_image_features.append(image_feature)
        return new_image_features
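To make the tile reshaping above easier to follow, here is a rough standalone walk-through of the tensor shapes with made-up numbers (a 24x24 token grid per tile and a 2x2 anyres tile layout; none of these values come from this commit):

import torch

hidden, height, width = 8, 24, 24          # e.g. image_size // patch_size = 336 // 14
num_patch_height, num_patch_width = 2, 2   # grid from get_anyres_image_grid_shape

# Non-base tiles, each a flat sequence of height * width vision tokens.
tiles = torch.randn(num_patch_height * num_patch_width, height * width, hidden)

x = tiles.view(num_patch_height, num_patch_width, height, width, hidden)
x = x.permute(4, 0, 2, 1, 3).contiguous()  # (hidden, nph, h, npw, w)
x = x.flatten(1, 2).flatten(2, 3)          # (hidden, 48, 48) spatial map
print(x.shape)                             # torch.Size([8, 48, 48])

# pack_image_features then unpads this map to the original aspect ratio,
# appends the image_newline column, flattens it back into a token sequence,
# and concatenates it after the base tile's height * width tokens.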
    def _process_image_pixels(
        self,
        inputs: MiniMaxVL01ImagePixelInputs,
@@ -259,7 +325,6 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
        assert self.vision_tower is not None
        pixel_values = inputs["pixel_values"]
        return self._image_pixels_to_features(self.vision_tower, pixel_values)

    def _process_image_input(
@@ -281,38 +346,31 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
        image_embeds = self.multi_modal_projector(torch.cat(image_features))
        image_embeds = torch.split(image_embeds, feature_sizes)
        return image_embeds

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)
        actual_dims = tuple(data.shape[1:])
        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
            raise ValueError(
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")
        return data

        image_sizes = image_input.get("image_sizes")
        return self.pack_image_features(image_embeds, image_sizes)

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[MiniMaxVL01ImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
        if pixel_values is not None and image_sizes is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")
            if not isinstance(image_sizes, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image sizes. "
                                 f"Got type: {type(image_sizes)}")

            return MiniMaxVL01ImagePixelInputs(
                type="pixel_values",
                pixel_values=self._validate_pixel_values(
                    flatten_bn(pixel_values, concat=True)),
                pixel_values=flatten_bn(pixel_values),
                image_sizes=flatten_bn(image_sizes, concat=True),
            )

        if image_embeds is not None: