[Bugfix] Fix broken Minimax-01-VL model (#22116)
Signed-off-by: Isotr0py <2037008807@qq.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
@@ -889,6 +889,39 @@ def run_minicpmv(questions: list[str], modality: str) -> ModelRequestData:
     return run_minicpmv_base(questions, modality, "openbmb/MiniCPM-V-2_6")


+def run_minimax_vl_01(questions: list[str], modality: str) -> ModelRequestData:
+    assert modality == "image"
+
+    model_name = "MiniMaxAI/MiniMax-VL-01"
+
+    engine_args = EngineArgs(
+        model=model_name,
+        max_num_seqs=2,
+        limit_mm_per_prompt={modality: 1},
+        trust_remote_code=True,
+        tensor_parallel_size=8,
+    )
+
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    messages = [
+        [
+            {
+                "role": "user",
+                "content": [{"type": "image"}, {"type": "text", "text": question}],
+            }
+        ]
+        for question in questions
+    ]
+    prompts = tokenizer.apply_chat_template(
+        messages, add_generation_prompt=True, tokenize=False
+    )
+
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompts=prompts,
+    )
+
+
 # Mistral-3 HF-format
 def run_mistral3(questions: list[str], modality: str) -> ModelRequestData:
     assert modality == "image"
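For reference, a minimal sketch of how the `ModelRequestData` built above is typically consumed offline. This is not part of the commit; the image path, sampling settings, and the `asdict(...)`-to-`LLM` step follow the usual pattern of vLLM's vision-language examples and are assumptions here.

```python
from dataclasses import asdict

from PIL import Image
from vllm import LLM, SamplingParams

req_data = run_minimax_vl_01(["What is in this image?"], "image")

# Build the engine from the EngineArgs dataclass returned by the helper.
llm = LLM(**asdict(req_data.engine_args))

# Pair each rendered chat prompt with its image via multi_modal_data.
image = Image.open("example.jpg").convert("RGB")
outputs = llm.generate(
    {"prompt": req_data.prompts[0], "multi_modal_data": {"image": image}},
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)
```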
@@ -1539,6 +1572,7 @@ model_example_map = {
     "mantis": run_mantis,
     "minicpmo": run_minicpmo,
     "minicpmv": run_minicpmv,
+    "minimax_vl_01": run_minimax_vl_01,
     "mistral3": run_mistral3,
     "mllama": run_mllama,
     "molmo": run_molmo,
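The new registry entry makes the model selectable by name in the example script. A minimal, hypothetical dispatch through the map (the question text is a placeholder):

```python
# Hypothetical lookup mirroring how the example script resolves the model name.
questions = ["Describe the image."]
req_data = model_example_map["minimax_vl_01"](questions, "image")
print(req_data.engine_args.model)  # MiniMaxAI/MiniMax-VL-01
```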
@@ -30,7 +30,6 @@ from ..utils import dummy_hf_overrides

 ARCH_TO_SKIP = {
     "MolmoForCausalLM": "incompatible requirements",
-    "MiniMaxVL01ForConditionalGeneration": "broken model",
 }
 ARCH_NEEDS_EXTRAS = [
     "InternVLChatModel",
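With the architecture dropped from `ARCH_TO_SKIP`, the test covers it again. A hedged sketch of how such a skip map is typically consulted; the helper name `maybe_skip_arch` is made up for illustration:

```python
import pytest

def maybe_skip_arch(model_arch: str) -> None:
    # Skip architectures listed in ARCH_TO_SKIP, reporting the recorded reason.
    if model_arch in ARCH_TO_SKIP:
        pytest.skip(f"{model_arch}: {ARCH_TO_SKIP[model_arch]}")
```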
@@ -1,11 +1,13 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Iterable, Mapping
-from typing import Literal, Optional, TypedDict, Union, cast
+from typing import Annotated, Literal, Optional, Union, cast

 import torch
 import torch.nn as nn
 from transformers import BatchFeature, PretrainedConfig
+from transformers.models.llava_next.modeling_llava_next import (
+    get_anyres_image_grid_shape, unpad_image)

 from vllm.config import VllmConfig
 from vllm.model_executor.layers.activation import get_act_fn
@@ -17,6 +19,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalFieldConfig
 from vllm.sequence import IntermediateTensors
 from vllm.utils.jsontree import json_map_leaves
+from vllm.utils.tensor_schema import TensorSchema, TensorShape

 from .clip import CLIPVisionModel
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
@@ -29,24 +32,36 @@ from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
                     maybe_prefix, merge_multimodal_embeddings)


-class MiniMaxVL01ImagePixelInputs(TypedDict):
-    type: Literal["pixel_values"]
-    pixel_values: torch.Tensor
-    """
-    Shape: `(batch_size * num_images, num_channels, height, width)`
-
-    Note that `height` or `width` may be different per batch and image,
-    in which case the data is passed as a list instead of a batched tensor.
-    """
+class MiniMaxVL01ImagePixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size * number of images
+        - np: Number of patches + 1
+        - c: Number of channels (3)
+        - h: Height
+        - w: Width
+
+    Note that `num_patches` may be different per batch and image,
+    in which case the data is passed as a list instead of a batched tensor.
+    """
+    type: Literal["pixel_values"] = "pixel_values"
+    pixel_values: Annotated[
+        Union[torch.Tensor, list[torch.Tensor]],
+        TensorShape("bn", "np", 3, "h", "w", dynamic_dims={"np", "h", "w"})]
+
+    image_sizes: Annotated[Optional[torch.Tensor], TensorShape("bn", 2)]
+    # This should be in `(height, width)` format.


-class MiniMaxVL01ImageEmbeddingInputs(TypedDict):
-    type: Literal["image_embeds"]
-    data: torch.Tensor
-    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
-
-    `hidden_size` must match the hidden size of language model backbone.
-    """
+class MiniMaxVL01ImageEmbeddingInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size * number of images
+        - ifs: Image feature size
+        - hs: Hidden size (must match language model backbone)
+    """
+    type: Literal["image_embeds"] = "image_embeds"
+    data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")]


 MiniMaxVL01ImageInputs = Union[MiniMaxVL01ImagePixelInputs,
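As a rough illustration of what the new schema accepts, here is a hand-built input with toy shapes, assuming `TensorSchema` subclasses check their fields against the `TensorShape` annotations when constructed (which is how `_parse_and_validate_image_input` below builds them); the crop counts and image sizes are arbitrary:

```python
import torch

# Module path assumed from the file touched in this diff.
from vllm.model_executor.models.minimax_vl_01 import MiniMaxVL01ImagePixelInputs

# Two images with different numbers of anyres crops, hence a list rather than
# a single batched tensor (np is declared as a dynamic dimension).
inputs = MiniMaxVL01ImagePixelInputs(
    pixel_values=[
        torch.zeros(5, 3, 336, 336),   # image 0: base crop + 4 tiles
        torch.zeros(9, 3, 336, 336),   # image 1: base crop + 8 tiles
    ],
    image_sizes=torch.tensor([[448, 336],
                              [672, 672]]),  # (height, width) per image
)
```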
@@ -141,6 +156,7 @@ class MiniMaxVL01MultiModalProcessor(
     ) -> Mapping[str, MultiModalFieldConfig]:
         return {
             "pixel_values": MultiModalFieldConfig.batched("image"),
+            "image_sizes": MultiModalFieldConfig.batched("image"),
             "image_embeds": MultiModalFieldConfig.batched("image"),
         }

@@ -239,7 +255,7 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
     ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
         # NOTE: we skip the step to select the vision feature layer since
         # this is already done inside the vision tower
-        image_features = vision_tower(pixel_values)
+        image_features = tuple(vision_tower(p) for p in pixel_values)

         def select_features(leaf: torch.Tensor):
             return self._select_image_features(
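With anyres preprocessing, `pixel_values` arrives as a list of per-image tensors whose crop counts differ, so they cannot be stacked into one batch; the loop runs the tower once per image. A self-contained sketch with a stand-in tower (the 576-token/1024-dim figures and crop counts are illustrative only):

```python
import torch

def fake_vision_tower(p: torch.Tensor) -> torch.Tensor:
    # Stand-in for CLIPVisionModel: (num_crops, 3, H, W) -> (num_crops, tokens, hidden)
    return torch.zeros(p.shape[0], 576, 1024)

pixel_values = [torch.zeros(5, 3, 336, 336),   # image 0: 5 crops
                torch.zeros(9, 3, 336, 336)]   # image 1: 9 crops
# torch.stack(pixel_values) would fail on the mismatched first dims,
# hence one forward pass per image:
image_features = tuple(fake_vision_tower(p) for p in pixel_values)
print([tuple(f.shape) for f in image_features])  # [(5, 576, 1024), (9, 576, 1024)]
```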
@@ -252,6 +268,56 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
             json_map_leaves(select_features, image_features),
         )

+    # adapted from https://huggingface.co/MiniMaxAI/MiniMax-VL-01/blob/main/modeling_minimax_vl_01.py#L616-L631
+    def pack_image_features(self, image_features: list[torch.Tensor],
+                            image_sizes: torch.Tensor):
+        new_image_features = []
+        for image_idx, image_feature in enumerate(image_features):
+            if image_feature.shape[0] > 1:
+                base_image_feature = image_feature[0]
+                image_feature = image_feature[1:]
+                height = width = (self.config.vision_config.image_size //
+                                  self.config.vision_config.patch_size)
+                if height * width != base_image_feature.shape[0]:
+                    raise ValueError(
+                        "The number of patches is not consistent with "
+                        "the image size.")
+                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
+                    image_sizes[image_idx],
+                    self.config.image_grid_pinpoints,
+                    self.config.vision_config.image_size,
+                )
+
+                image_feature = image_feature.view(num_patch_height,
+                                                   num_patch_width, height,
+                                                   width, -1)
+                image_feature = image_feature.permute(4, 0, 2, 1,
+                                                      3).contiguous()
+                image_feature = image_feature.flatten(1, 2).flatten(2, 3)
+                image_feature = unpad_image(image_feature,
+                                            image_sizes[image_idx])
+
+                image_feature = torch.cat(
+                    (
+                        image_feature,
+                        self.image_newline[:, None, None].expand(
+                            *image_feature.shape[:-1], 1).to(
+                                image_feature.dtype),
+                    ),
+                    dim=-1,
+                )
+                image_feature = image_feature.flatten(1, 2).transpose(0, 1)
+                image_feature = torch.cat((base_image_feature, image_feature),
+                                          dim=0)
+            else:
+                image_feature = image_feature[0]
+                image_feature = torch.cat(
+                    (image_feature,
+                     self.image_newline[None].to(image_feature)),
+                    dim=0)
+            new_image_features.append(image_feature)
+        return new_image_features
+
     def _process_image_pixels(
         self,
         inputs: MiniMaxVL01ImagePixelInputs,
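To make the reshuffling in `pack_image_features` concrete, here is a standalone shape walk-through with toy numbers: a 2x2 anyres grid of crops, a 24x24 patch grid per crop (as with a 336 px tower and 14 px patches), and hidden size 8. `unpad_image` is skipped since it only crops padding, and a zero tensor stands in for the learned `image_newline` parameter.

```python
import torch

num_patch_height, num_patch_width, height, width, hidden = 2, 2, 24, 24, 8
# Per-crop features, excluding the base image feature: (num_crops, tokens, hidden)
patches = torch.zeros(num_patch_height * num_patch_width, height * width, hidden)

x = patches.view(num_patch_height, num_patch_width, height, width, hidden)
x = x.permute(4, 0, 2, 1, 3).contiguous()   # (hidden, nh, h, nw, w)
x = x.flatten(1, 2).flatten(2, 3)           # (hidden, nh*h, nw*w) = (8, 48, 48)

# Append a newline embedding along the width axis, then flatten back to tokens.
newline = torch.zeros(hidden)               # stands in for self.image_newline
x = torch.cat((x, newline[:, None, None].expand(*x.shape[:-1], 1)), dim=-1)
tokens = x.flatten(1, 2).transpose(0, 1)    # (48 * 49, 8) = (2352, 8)
print(tokens.shape)
```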
@@ -259,7 +325,6 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
         assert self.vision_tower is not None

         pixel_values = inputs["pixel_values"]
-
         return self._image_pixels_to_features(self.vision_tower, pixel_values)

     def _process_image_input(
@@ -281,38 +346,31 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,

         image_embeds = self.multi_modal_projector(torch.cat(image_features))
         image_embeds = torch.split(image_embeds, feature_sizes)
-        return image_embeds
-
-    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
-        h = w = self.config.vision_config.image_size
-        expected_dims = (3, h, w)
-        actual_dims = tuple(data.shape[1:])
-
-        if actual_dims != expected_dims:
-            expected_expr = ("batch_size", *map(str, expected_dims))
-            raise ValueError(
-                f"The expected shape of pixel values is {expected_expr}. "
-                f"You supplied {tuple(data.shape)}.")
-
-        return data
+        image_sizes = image_input.get("image_sizes")
+        return self.pack_image_features(image_embeds, image_sizes)

     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[MiniMaxVL01ImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
+        image_sizes = kwargs.pop("image_sizes", None)
         image_embeds = kwargs.pop("image_embeds", None)

         if pixel_values is None and image_embeds is None:
             return None

-        if pixel_values is not None:
+        if pixel_values is not None and image_sizes is not None:
             if not isinstance(pixel_values, (torch.Tensor, list)):
                 raise ValueError("Incorrect type of pixel values. "
                                  f"Got type: {type(pixel_values)}")

+            if not isinstance(image_sizes, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of image sizes. "
+                                 f"Got type: {type(image_sizes)}")
+
             return MiniMaxVL01ImagePixelInputs(
                 type="pixel_values",
-                pixel_values=self._validate_pixel_values(
-                    flatten_bn(pixel_values, concat=True)),
+                pixel_values=flatten_bn(pixel_values),
+                image_sizes=flatten_bn(image_sizes, concat=True),
             )

         if image_embeds is not None: