[Bugfix] Clean up MiniMax-VL and fix processing (#17354)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Date: 2025-04-29 20:42:16 +08:00
Committed by: GitHub
Parent: 890f104cdf
Commit: 00ee37efa2
4 changed files with 38 additions and 283 deletions

View File

@@ -979,6 +979,13 @@ See [this page](#generative-models) for more information on how to use generative
* ✅︎
* ✅︎
* ✅︎
- * `MiniMaxVL01ForConditionalGeneration`
* MiniMax-VL
* T + I<sup>E+</sup>
* `MiniMaxAI/MiniMax-VL-01`, etc.
*
* ✅︎
* ✅︎
- * `Mistral3ForConditionalGeneration`
* Mistral3
* T + I<sup>+</sup>
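The new documentation row above advertises text-plus-image input ("T + I<sup>E+</sup>") for `MiniMaxVL01ForConditionalGeneration`. For illustration only, a minimal offline-inference sketch using vLLM's `LLM` API follows; the "<image>" placeholder string, `tensor_parallel_size`, and `limit_mm_per_prompt` values are assumptions and are not taken from this commit.

from PIL import Image
from vllm import LLM, SamplingParams

# Assumed engine settings; adjust to your hardware and the model's chat template.
llm = LLM(
    model="MiniMaxAI/MiniMax-VL-01",
    trust_remote_code=True,
    tensor_parallel_size=8,            # assumption, not from this commit
    limit_mm_per_prompt={"image": 2},  # assumption, not from this commit
)

image = Image.open("example.jpg")
outputs = llm.generate(
    {
        # "<image>" is a hypothetical placeholder; use the model's real image token.
        "prompt": "<image>\nDescribe this image.",
        "multi_modal_data": {"image": image},
    },
    SamplingParams(max_tokens=128),
)
print(outputs[0].outputs[0].text)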

View File

@@ -270,6 +270,7 @@ def _test_processing_correctness_mistral(
"openbmb/MiniCPM-Llama3-V-2_5",
"openbmb/MiniCPM-o-2_6",
"openbmb/MiniCPM-V-2_6",
"MiniMaxAI/MiniMax-VL-01",
"allenai/Molmo-7B-D-0924",
"allenai/Molmo-7B-O-0924",
"nvidia/NVLM-D-72B",

View File

@@ -12,7 +12,6 @@ from ...utils import build_model_context
@pytest.mark.parametrize("model_id", ["MiniMaxAI/MiniMax-VL-01"])
# yapf: enable
@pytest.mark.parametrize("num_imgs", [1, 2])
def test_processor_override(
image_assets: _ImageAssets,

View File

@@ -1,52 +1,32 @@
# SPDX-License-Identifier: Apache-2.0
from collections.abc import Iterable, Mapping
from typing import Literal, Optional, Set, Tuple, TypedDict, Union, cast
from abc import abstractmethod
from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass
from typing import (Final, Literal, Optional, Protocol, Set, Tuple, TypedDict,
TypeVar, Union, cast)
import numpy as np
import torch
import torch.nn as nn
from transformers import BatchFeature, CLIPVisionConfig, PretrainedConfig
from transformers.image_processing_utils import select_best_resolution
from transformers import BatchFeature
from vllm.config import VllmConfig
from vllm.jsontree import json_map_leaves
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.minimax_vl_01 import MiniMaxVL01Config
from .clip import CLIPVisionModel
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .llava import (BaseLlavaMultiModalProcessor, LlavaDummyInputsBuilder,
init_vision_tower_for_llava)
from .llava_next import LlavaNextProcessingInfo
from .pixtral import PixtralHFVisionModel
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
from .vision import get_vision_encoder_info
logger = init_logger(__name__)
# For dummy input only
@dataclass
class MaxImageTokenMeta:
width: int = 1024
height: int = 1024
class MiniMaxVL01ImagePixelInputs(TypedDict):
@@ -69,66 +49,8 @@ class MiniMaxVL01ImageEmbeddingInputs(TypedDict):
"""
def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
if not isinstance(grid_pinpoints, list):
raise TypeError("grid_pinpoints should be a list of tuples or lists")
# ! VERY IMPORTANT: if image_size is a tensor, it must be converted to a
# tuple, otherwise the calculation below will be wrong
if not isinstance(image_size, (list, tuple)):
if not isinstance(image_size, (torch.Tensor, np.ndarray)):
raise TypeError("image_size invalid type " +
f"{type(image_size)} with value {image_size}")
image_size = image_size.tolist()
best_resolution = select_best_resolution(image_size, grid_pinpoints)
height, width = best_resolution
num_patches = 0
# consider changing this to ceil(height/patch_size) * ceil(width/patch_size) + 1
for i in range(0, height, patch_size):
for j in range(0, width, patch_size):
num_patches += 1
# add the base patch
num_patches += 1
return num_patches
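The nested loop above counts one patch per `patch_size` step along each dimension plus one base patch, which is exactly the closed form hinted at in the inline comment. A standalone sketch (not part of this diff) checking the equivalence on illustrative values:

import math

def num_patches_closed_form(height: int, width: int, patch_size: int) -> int:
    return math.ceil(height / patch_size) * math.ceil(width / patch_size) + 1

def num_patches_loop(height: int, width: int, patch_size: int) -> int:
    count = 1  # the extra base patch
    for _ in range(0, height, patch_size):
        for _ in range(0, width, patch_size):
            count += 1
    return count

# 672 x 1008 at patch size 336 -> 2 * 3 tiles + 1 base patch = 7
assert num_patches_closed_form(672, 1008, 336) == num_patches_loop(672, 1008, 336) == 7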
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
if not isinstance(grid_pinpoints, list):
raise TypeError("grid_pinpoints should be a list of tuples or lists")
# ! VERY IMPORTANT: if image_size is a tensor,
# it must be converted to a tuple,
# otherwise the calculation below will be wrong
if not isinstance(image_size, (list, tuple)):
if not isinstance(image_size, (torch.Tensor, np.ndarray)):
raise TypeError(
"image_size has invalid type " +
f"{type(image_size)}; " +
"it should be a list, tuple, np.ndarray or torch.Tensor")
image_size = image_size.tolist()
height, width = select_best_resolution(image_size, grid_pinpoints)
return height // patch_size, width // patch_size
def unpad_image(tensor, original_size):
original_height, original_width = original_size
current_height, current_width = tensor.shape[1:]
original_aspect_ratio = original_width / original_height
current_aspect_ratio = current_width / current_height
if original_aspect_ratio > current_aspect_ratio:
new_height = int(original_height * current_width) // original_width
padding = (current_height - new_height) // 2
unpadded_tensor = tensor[:, padding:current_height - padding, :]
else:
new_width = int(original_width * current_height) // original_height
padding = (current_width - new_width) // 2
unpadded_tensor = tensor[:, :, padding:current_width - padding]
return unpadded_tensor
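For reference, unpad_image strips the letterbox padding introduced when the original aspect ratio differs from the resized feature map. A standalone worked sketch (shapes are illustrative, not from this commit): an original 200 x 400 (H x W) image mapped onto a square (C, 48, 48) feature map keeps only the middle 24 rows.

import torch

# new_height = 200 * 48 // 400 = 24, padding = (48 - 24) // 2 = 12,
# so rows 12..35 are kept and the columns are left untouched.
features = torch.zeros(3, 48, 48)
unpadded = features[:, 12:48 - 12, :]
assert unpadded.shape == (3, 24, 48)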
MiniMaxVL01ImageInputs = Union[MiniMaxVL01ImagePixelInputs,
MiniMaxVL01ImageEmbeddingInputs]
class MiniMaxVL01MultiModalProjector(nn.Module):
@@ -161,144 +83,29 @@ class MiniMaxVL01MultiModalProjector(nn.Module):
return hidden_states
class MiniMaxVL01LikeConfig(Protocol):
vision_config: Final[PretrainedConfig]
image_token_index: Final[int]
vision_feature_select_strategy: Final[str]
vision_feature_layer: Final[Union[int, list[int]]]
class MiniMaxVL01DummyInputsBuilder(LlavaDummyInputsBuilder):
pass
class MiniMaxVL01LikeProcessor(Protocol):
image_token: Final[str]
_I = TypeVar("_I", bound=BaseProcessingInfo)
class MiniMaxVL01DummyInputsBuilder(BaseDummyInputsBuilder[_I]):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
processor = self.info.get_hf_processor()
image_token = processor.image_token
return image_token * num_images
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
return {
"image":
self._get_dummy_images(width=MaxImageTokenMeta.width,
height=MaxImageTokenMeta.height,
num_images=num_images)
}
class MiniMaxVL01ProcessingInfo(BaseProcessingInfo):
class MiniMaxVL01ProcessingInfo(LlavaNextProcessingInfo):
def get_hf_config(self):
return self.ctx.get_hf_config(MiniMaxVL01Config)
def get_hf_processor(self, **kwargs: object):
hf_processor = self.ctx.get_hf_processor(**kwargs)
image_processor = hf_processor.image_processor
image_processor.anyres_preprocess = (
image_processor.anyres_for_vllm_preprocess)
return hf_processor
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_vision_encoder_info(self):
return get_vision_encoder_info(self.get_hf_config())
def _apply_feature_select_strategy(
self,
strategy: str,
encoder_num_image_tokens: int,
) -> int:
if strategy == "default":
return encoder_num_image_tokens - 1
if strategy == "full":
return encoder_num_image_tokens
msg = f"Unexpected feature select strategy: {strategy!r}"
raise NotImplementedError(msg)
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
) -> int:
hf_config = self.get_hf_config()
vision_encoder_info = self.get_vision_encoder_info()
return self._apply_feature_select_strategy(
hf_config.vision_feature_select_strategy,
vision_encoder_info.get_num_image_tokens(
image_width=image_width,
image_height=image_height,
),
)
def get_image_size_with_most_features(self) -> ImageSize:
vision_encoder_info = self.get_vision_encoder_info()
width = height = vision_encoder_info.get_image_size()
return ImageSize(width=width, height=height)
def get_max_image_tokens(self) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
)
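To make the token accounting above concrete: a CLIP-style encoder emits (image_size // patch_size) ** 2 patch tokens plus one class token, and the "default" feature-select strategy drops the class token while "full" keeps it. A standalone sketch with assumed encoder numbers (336 px images, 14 px patches), not values read from the MiniMax-VL config:

def encoder_num_image_tokens(image_size: int, patch_size: int) -> int:
    return (image_size // patch_size) ** 2 + 1  # patch tokens + CLS token

def apply_feature_select_strategy(strategy: str, n: int) -> int:
    return n - 1 if strategy == "default" else n

n = encoder_num_image_tokens(image_size=336, patch_size=14)  # 577
assert apply_feature_select_strategy("default", n) == 576
assert apply_feature_select_strategy("full", n) == 577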
class BaseMiniMaxVL01MultiModalProcessor(BaseMultiModalProcessor[_I]):
# Copied from BaseMultiModalProcessor
@abstractmethod
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
raise NotImplementedError
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> Sequence[PromptUpdate]:
hf_config = self.info.get_hf_config()
image_token_id = hf_config.image_token_index
def get_replacement(item_idx: int):
images = mm_items.get_items(
"image", (ImageEmbeddingItems, ImageProcessorItems))
if isinstance(images, ImageEmbeddingItems):
num_image_tokens = images.get_feature_size(item_idx)
else:
image_size = images.get_image_size(item_idx)
num_image_tokens = self.info.get_num_image_tokens(
image_width=image_size.width,
image_height=image_size.height,
)
return [image_token_id] * num_image_tokens
return [
PromptReplacement(
modality="image",
target=[image_token_id],
replacement=get_replacement,
),
]
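The PromptReplacement above expands each single image placeholder token in the tokenized prompt into num_image_tokens copies, so every image feature gets its own token position to be overwritten by the merged embeddings. A toy sketch of that expansion (the token id and count are made up):

IMAGE_TOKEN_ID = 200_000   # hypothetical placeholder id
prompt_ids = [1, 5, IMAGE_TOKEN_ID, 9]
num_image_tokens = 576     # e.g. the value returned by get_num_image_tokens(...)

expanded = []
for tok in prompt_ids:
    expanded.extend([tok] * num_image_tokens if tok == IMAGE_TOKEN_ID else [tok])
assert expanded.count(IMAGE_TOKEN_ID) == 576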
class MiniMaxVL01MultiModalProcessor(
BaseMiniMaxVL01MultiModalProcessor[MiniMaxVL01ProcessingInfo]):
BaseLlavaMultiModalProcessor[MiniMaxVL01ProcessingInfo]):
def _call_hf_processor(
self,
@@ -314,10 +121,9 @@ class MiniMaxVL01MultiModalProcessor(
pixel_values = processed_outputs.get("pixel_values")
if pixel_values is not None:
# Avoid padding since we need the output for each image to be
# independent of other images for the cache to work correctly
image_sizes = processed_outputs["image_sizes"]
min_len = min(len(pixel_values), len(image_sizes))
pixel_values = pixel_values[:min_len]
image_sizes = image_sizes[:min_len]
assert len(pixel_values) == len(image_sizes)
processed_outputs["pixel_values"] = [
@@ -337,65 +143,6 @@ class MiniMaxVL01MultiModalProcessor(
}
def _get_num_hidden_layers(hf_config: MiniMaxVL01LikeConfig) -> int:
"""Determine the number of hidden layers to initialize up to in the
visual encoder.
Args:
hf_config: Model config with vision feature layer(s).
"""
feature_layers = hf_config.vision_feature_layer
num_hidden_layers = hf_config.vision_config.num_hidden_layers
# If we have one feature layer, initialize up to that layer
if isinstance(feature_layers, int):
return _get_layer_index(feature_layers, num_hidden_layers)
# If we have multiple feature layers, initialize up to the deepest one
elif isinstance(feature_layers, (list, tuple)):
return max(
_get_layer_index(idx, num_hidden_layers) for idx in feature_layers)
raise TypeError(f"vision_layer_feature type: {type(feature_layers)}"
" is not supported")
def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
"""Given a signed vision feature layer, get the number of hidden layers
needed to leverage it.
Args:
feature_layer_index: Index of a required layer in the visual encoder.
num_hidden_layers: The total number of hidden layers in the visual
encoder.
"""
if feature_layer_index < 0:
return num_hidden_layers + feature_layer_index + 1
return feature_layer_index
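As the docstring notes, negative feature-layer indices count from the end of the encoder. A quick standalone sketch (not part of this diff) of how _get_layer_index resolves them for a 24-layer encoder:

def layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
    if feature_layer_index < 0:
        return num_hidden_layers + feature_layer_index + 1
    return feature_layer_index

assert layer_index(-1, 24) == 24   # last layer -> initialize all 24 layers
assert layer_index(-2, 24) == 23   # second-to-last layer -> 23 layers suffice
assert layer_index(12, 24) == 12   # non-negative indices pass through unchanged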
def init_vision_tower_for_MiniMaxVL01(
hf_config: MiniMaxVL01LikeConfig,
quant_config: Optional[QuantizationConfig],
*,
require_post_norm: Optional[bool] = None,
prefix: str = "",
) -> Union[CLIPVisionModel, SiglipVisionModel, PixtralHFVisionModel]:
vision_config = hf_config.vision_config
# Initialize the vision tower only up to the deepest required feature layer
num_hidden_layers = _get_num_hidden_layers(hf_config)
if isinstance(vision_config, CLIPVisionConfig):
return CLIPVisionModel(
vision_config,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers,
require_post_norm=require_post_norm,
prefix=prefix,
)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
@MULTIMODAL_REGISTRY.register_processor(
MiniMaxVL01MultiModalProcessor,
info=MiniMaxVL01ProcessingInfo,
@@ -419,7 +166,7 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
self.multimodal_config = multimodal_config
# TODO: Optionally initialize this to support embeddings.
self.vision_tower = init_vision_tower_for_MiniMaxVL01(
self.vision_tower = init_vision_tower_for_llava(
config,
quant_config,
require_post_norm=False,
@@ -476,7 +223,8 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
def _image_pixels_to_features(
self,
vision_tower: Union[CLIPVisionModel],
vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
PixtralHFVisionModel],
pixel_values: Union[torch.Tensor, list[torch.Tensor]],
) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
# NOTE: we skip the step to select the vision feature layer since
@@ -496,7 +244,7 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
def _process_image_pixels(
self,
inputs: Union[MiniMaxVL01ImagePixelInputs],
inputs: MiniMaxVL01ImagePixelInputs,
) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
assert self.vision_tower is not None
@@ -506,7 +254,7 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
def _process_image_input(
self,
image_input: MiniMaxVL01ImagePixelInputs,
image_input: MiniMaxVL01ImageInputs,
) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
if image_input["type"] == "image_embeds":
return image_input["data"]
@@ -539,7 +287,7 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
return data
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[MiniMaxVL01ImagePixelInputs]:
self, **kwargs: object) -> Optional[MiniMaxVL01ImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_embeds = kwargs.pop("image_embeds", None)