mirror of
				https://github.com/vllm-project/vllm.git
				synced 2025-10-20 23:03:52 +08:00 
			
		
		
		
	[Bugfix] Clean up MiniMax-VL and fix processing (#17354)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
		| @ -979,6 +979,13 @@ See [this page](#generative-models) for more information on how to use generativ | ||||
|   * ✅︎ | ||||
|   * ✅︎ | ||||
|   * ✅︎ | ||||
| - * `MiniMaxVL01ForConditionalGeneration` | ||||
|   * MiniMax-VL | ||||
|   * T + I<sup>E+</sup> | ||||
|   * `MiniMaxAI/MiniMax-VL-01`, etc. | ||||
|   * | ||||
|   * ✅︎ | ||||
|   * ✅︎ | ||||
| - * `Mistral3ForConditionalGeneration` | ||||
|   * Mistral3 | ||||
|   * T + I<sup>+</sup> | ||||
|  | ||||
| @ -270,6 +270,7 @@ def _test_processing_correctness_mistral( | ||||
|     "openbmb/MiniCPM-Llama3-V-2_5", | ||||
|     "openbmb/MiniCPM-o-2_6", | ||||
|     "openbmb/MiniCPM-V-2_6", | ||||
|     "MiniMaxAI/MiniMax-VL-01", | ||||
|     "allenai/Molmo-7B-D-0924", | ||||
|     "allenai/Molmo-7B-O-0924", | ||||
|     "nvidia/NVLM-D-72B", | ||||
|  | ||||
| @ -12,7 +12,6 @@ from ...utils import build_model_context | ||||
|  | ||||
|  | ||||
| @pytest.mark.parametrize("model_id", ["MiniMaxAI/MiniMax-VL-01"]) | ||||
| # yapf: enable | ||||
| @pytest.mark.parametrize("num_imgs", [1, 2]) | ||||
| def test_processor_override( | ||||
|     image_assets: _ImageAssets, | ||||
|  | ||||
| @ -1,52 +1,32 @@ | ||||
| # SPDX-License-Identifier: Apache-2.0 | ||||
| from collections.abc import Iterable, Mapping | ||||
| from typing import Literal, Optional, Set, Tuple, TypedDict, Union, cast | ||||
|  | ||||
| from abc import abstractmethod | ||||
| from collections.abc import Iterable, Mapping, Sequence | ||||
| from dataclasses import dataclass | ||||
| from typing import (Final, Literal, Optional, Protocol, Set, Tuple, TypedDict, | ||||
|                     TypeVar, Union, cast) | ||||
|  | ||||
| import numpy as np | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from transformers import BatchFeature, CLIPVisionConfig, PretrainedConfig | ||||
| from transformers.image_processing_utils import select_best_resolution | ||||
| from transformers import BatchFeature | ||||
|  | ||||
| from vllm.config import VllmConfig | ||||
| from vllm.jsontree import json_map_leaves | ||||
| from vllm.logger import init_logger | ||||
| from vllm.model_executor.layers.activation import get_act_fn | ||||
| from vllm.model_executor.layers.linear import (ColumnParallelLinear, | ||||
|                                                RowParallelLinear) | ||||
| from vllm.model_executor.layers.quantization import QuantizationConfig | ||||
| from vllm.model_executor.sampling_metadata import SamplingMetadata | ||||
| from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict | ||||
| from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs | ||||
| from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, | ||||
|                                    ImageSize, MultiModalDataItems) | ||||
| from vllm.multimodal.processing import (BaseMultiModalProcessor, | ||||
|                                         BaseProcessingInfo, PromptReplacement, | ||||
|                                         PromptUpdate) | ||||
| from vllm.multimodal.profiling import BaseDummyInputsBuilder | ||||
| from vllm.multimodal import MULTIMODAL_REGISTRY | ||||
| from vllm.multimodal.inputs import MultiModalFieldConfig | ||||
| from vllm.sequence import IntermediateTensors | ||||
| from vllm.transformers_utils.configs.minimax_vl_01 import MiniMaxVL01Config | ||||
|  | ||||
| from .clip import CLIPVisionModel | ||||
| from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP | ||||
| from .llava import (BaseLlavaMultiModalProcessor, LlavaDummyInputsBuilder, | ||||
|                     init_vision_tower_for_llava) | ||||
| from .llava_next import LlavaNextProcessingInfo | ||||
| from .pixtral import PixtralHFVisionModel | ||||
| from .siglip import SiglipVisionModel | ||||
| from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, | ||||
|                     maybe_prefix, merge_multimodal_embeddings) | ||||
| from .vision import get_vision_encoder_info | ||||
|  | ||||
| logger = init_logger(__name__) | ||||
|  | ||||
|  | ||||
| # For dummy input only | ||||
@dataclass
class MaxImageTokenMeta:
    """Image dimensions used when generating dummy image inputs.

    The defaults are read directly from the class (``MaxImageTokenMeta.width``
    / ``MaxImageTokenMeta.height``) by the dummy-input builder.
    """

    # Width of a dummy image.
    width: int = 1024
    # Height of a dummy image.
    height: int = 1024
|  | ||||
|  | ||||
| class MiniMaxVL01ImagePixelInputs(TypedDict): | ||||
| @ -69,66 +49,8 @@ class MiniMaxVL01ImageEmbeddingInputs(TypedDict): | ||||
|     """ | ||||
|  | ||||
|  | ||||
def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
    """Count the patches an image occupies at its best-fit grid resolution.

    The resolution is chosen by ``select_best_resolution`` and then tiled
    with a stride of ``patch_size`` along each axis; one extra "base" patch
    is added on top of the tiles.

    Args:
        image_size: Size of the image as a ``list``/``tuple``, or as a
            ``torch.Tensor``/``np.ndarray`` (converted with ``tolist()``).
        grid_pinpoints: List of candidate resolutions forwarded to
            ``select_best_resolution``.
        patch_size: Stride used to tile the selected resolution.

    Returns:
        Number of tiles covering the selected resolution, plus one.

    Raises:
        TypeError: If ``grid_pinpoints`` is not a list, or ``image_size`` is
            none of list, tuple, tensor or ndarray.
        ValueError: If ``patch_size`` is zero (raised by ``range``).
    """
    if not isinstance(grid_pinpoints, list):
        raise TypeError("grid_pinpoints should be a list of tuples or lists")

    # IMPORTANT: a tensor/ndarray size must be converted to plain Python
    # numbers first; the original author notes that skipping the conversion
    # leads to a wrong calculation.
    if not isinstance(image_size, (list, tuple)):
        if not isinstance(image_size, (torch.Tensor, np.ndarray)):
            raise TypeError("image_size invalid type " +
                            f"{type(image_size)} with value {image_size}")
        image_size = image_size.tolist()

    height, width = select_best_resolution(image_size, grid_pinpoints)

    # len(range(0, n, step)) is the number of tile origins along one axis,
    # i.e. exactly what the former nested counting loop produced, in O(1).
    rows = len(range(0, height, patch_size))
    cols = len(range(0, width, patch_size))

    # The extra 1 accounts for the base patch.
    return rows * cols + 1
|  | ||||
|  | ||||
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    """Return the patch-grid shape of the best-fit resolution for an image.

    Args:
        image_size: Size of the image as a ``list``/``tuple``, or as a
            ``torch.Tensor``/``np.ndarray`` (converted with ``tolist()``).
        grid_pinpoints: List of candidate resolutions forwarded to
            ``select_best_resolution``.
        patch_size: Edge length used to floor-divide the selected resolution.

    Returns:
        ``(height // patch_size, width // patch_size)`` for the resolution
        picked by ``select_best_resolution``.

    Raises:
        TypeError: If ``grid_pinpoints`` is not a list, or ``image_size`` is
            none of list, tuple, tensor or ndarray.
    """
    if not isinstance(grid_pinpoints, list):
        raise TypeError("grid_pinpoints should be a list of tuples or lists")

    # IMPORTANT: tensor/ndarray sizes are turned into plain Python numbers
    # before use; the original author notes that otherwise the calculation
    # comes out wrong.
    if isinstance(image_size, (list, tuple)):
        size = image_size
    elif isinstance(image_size, (torch.Tensor, np.ndarray)):
        size = image_size.tolist()
    else:
        raise TypeError(
            "image_size invalid type " +
            f"{type(image_size)} not valid, " +
            "should be either list, tuple, np.ndarray or tensor")

    best_height, best_width = select_best_resolution(size, grid_pinpoints)
    grid_rows = best_height // patch_size
    grid_cols = best_width // patch_size
    return grid_rows, grid_cols
|  | ||||
|  | ||||
def unpad_image(tensor, original_size):
    """Slice a padded image tensor back towards its original aspect ratio.

    The last two dimensions of ``tensor`` are treated as (height, width);
    the leading dimension is left untouched. Depending on which aspect ratio
    is larger, an equal number of rows (top and bottom) or columns (left and
    right) is sliced off.

    Args:
        tensor: 3-D tensor indexed as ``[:, height, width]``.
        original_size: ``(height, width)`` of the image before padding. May
            be a list/tuple, or a ``torch.Tensor``/``np.ndarray``, which is
            converted to plain Python numbers first.

    Returns:
        A slice (view) of ``tensor`` with the padding rows or columns removed.
    """
    # Same normalisation the sibling size helpers apply: do the arithmetic
    # below on plain Python numbers rather than on tensor/array scalars.
    if isinstance(original_size, (torch.Tensor, np.ndarray)):
        original_size = original_size.tolist()

    original_height, original_width = original_size
    current_height, current_width = tensor.shape[1:]

    original_aspect_ratio = original_width / original_height
    current_aspect_ratio = current_width / current_height

    if original_aspect_ratio > current_aspect_ratio:
        # Original is relatively wider: the padding sits above and below.
        new_height = int(original_height * current_width) // original_width
        padding = (current_height - new_height) // 2
        return tensor[:, padding:current_height - padding, :]

    # Otherwise the padding (if any) sits on the left and right.
    new_width = int(original_width * current_height) // original_height
    padding = (current_width - new_width) // 2
    return tensor[:, :, padding:current_width - padding]
| MiniMaxVL01ImageInputs = Union[MiniMaxVL01ImagePixelInputs, | ||||
|                                MiniMaxVL01ImageEmbeddingInputs] | ||||
|  | ||||
|  | ||||
| class MiniMaxVL01MultiModalProjector(nn.Module): | ||||
| @ -161,144 +83,29 @@ class MiniMaxVL01MultiModalProjector(nn.Module): | ||||
|         return hidden_states | ||||
|  | ||||
|  | ||||
| class MiniMaxVL01LikeConfig(Protocol): | ||||
|     vision_config: Final[PretrainedConfig] | ||||
|     image_token_index: Final[int] | ||||
|     vision_feature_select_strategy: Final[str] | ||||
|     vision_feature_layer: Final[Union[int, list[int]]] | ||||
| class MiniMaxVL01DummyInputsBuilder(LlavaDummyInputsBuilder): | ||||
|     pass | ||||
|  | ||||
|  | ||||
| class MiniMaxVL01LikeProcessor(Protocol): | ||||
|     image_token: Final[str] | ||||
|  | ||||
|  | ||||
| _I = TypeVar("_I", bound=BaseProcessingInfo) | ||||
|  | ||||
|  | ||||
class MiniMaxVL01DummyInputsBuilder(BaseDummyInputsBuilder[_I]):
    """Produces dummy prompt text and dummy images for MiniMax-VL-01."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return the HF processor's image token repeated once per image.

        Args:
            mm_counts: Per-modality item counts; only ``"image"`` is read
                (defaulting to 0 when absent).
        """
        image_count = mm_counts.get("image", 0)
        hf_processor = self.info.get_hf_processor()
        return hf_processor.image_token * image_count

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Return dummy images sized according to ``MaxImageTokenMeta``.

        Args:
            seq_len: Not used by this implementation.
            mm_counts: Per-modality item counts; only ``"image"`` is read
                (defaulting to 0 when absent).
        """
        image_count = mm_counts.get("image", 0)
        dummy_images = self._get_dummy_images(
            width=MaxImageTokenMeta.width,
            height=MaxImageTokenMeta.height,
            num_images=image_count,
        )
        return {"image": dummy_images}
|  | ||||
|  | ||||
| class MiniMaxVL01ProcessingInfo(BaseProcessingInfo): | ||||
| class MiniMaxVL01ProcessingInfo(LlavaNextProcessingInfo): | ||||
|  | ||||
|     def get_hf_config(self): | ||||
|         return self.ctx.get_hf_config(MiniMaxVL01Config) | ||||
|  | ||||
|     def get_hf_processor(self, **kwargs: object): | ||||
|         hf_processor = self.ctx.get_hf_processor(**kwargs) | ||||
|         image_processor = hf_processor.image_processor | ||||
|         image_processor.anyres_preprocess = ( | ||||
|             image_processor.anyres_for_vllm_preprocess) | ||||
|  | ||||
|         return hf_processor | ||||
|  | ||||
|     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: | ||||
|         return {"image": None} | ||||
|  | ||||
|     def get_vision_encoder_info(self): | ||||
|         return get_vision_encoder_info(self.get_hf_config()) | ||||
|  | ||||
|     def _apply_feature_select_strategy( | ||||
|         self, | ||||
|         strategy: str, | ||||
|         encoder_num_image_tokens: int, | ||||
|     ) -> int: | ||||
|         if strategy == "default": | ||||
|             return encoder_num_image_tokens - 1 | ||||
|         if strategy == "full": | ||||
|             return encoder_num_image_tokens | ||||
|  | ||||
|         msg = f"Unexpected feature select strategy: {strategy!r}" | ||||
|         raise NotImplementedError(msg) | ||||
|  | ||||
|     def get_num_image_tokens( | ||||
|         self, | ||||
|         *, | ||||
|         image_width: int, | ||||
|         image_height: int, | ||||
|     ) -> int: | ||||
|         hf_config = self.get_hf_config() | ||||
|         vision_encoder_info = self.get_vision_encoder_info() | ||||
|  | ||||
|         return self._apply_feature_select_strategy( | ||||
|             hf_config.vision_feature_select_strategy, | ||||
|             vision_encoder_info.get_num_image_tokens( | ||||
|                 image_width=image_width, | ||||
|                 image_height=image_height, | ||||
|             ), | ||||
|         ) | ||||
|  | ||||
|     def get_image_size_with_most_features(self) -> ImageSize: | ||||
|         vision_encoder_info = self.get_vision_encoder_info() | ||||
|         width = height = vision_encoder_info.get_image_size() | ||||
|         return ImageSize(width=width, height=height) | ||||
|  | ||||
|     def get_max_image_tokens(self) -> int: | ||||
|         target_width, target_height = self.get_image_size_with_most_features() | ||||
|  | ||||
|         return self.get_num_image_tokens( | ||||
|             image_width=target_width, | ||||
|             image_height=target_height, | ||||
|         ) | ||||
|  | ||||
|  | ||||
class BaseMiniMaxVL01MultiModalProcessor(BaseMultiModalProcessor[_I]):
    """Base multimodal processor that expands each image token placeholder."""

    # Copied from BaseMultiModalProcessor
    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Left abstract: subclasses must provide the field configuration."""
        raise NotImplementedError

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        """Build the prompt replacement for the image token.

        A single image token id (``hf_config.image_token_index``) is the
        target; it is replaced by that id repeated once per image feature
        token of the corresponding item.
        """
        token_id = self.info.get_hf_config().image_token_index

        def _expand(item_idx: int):
            items = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems))

            # Pre-computed embeddings report their own feature size.
            if isinstance(items, ImageEmbeddingItems):
                return [token_id] * items.get_feature_size(item_idx)

            # Raw images: derive the token count from the image size.
            size = items.get_image_size(item_idx)
            token_count = self.info.get_num_image_tokens(
                image_width=size.width,
                image_height=size.height,
            )
            return [token_id] * token_count

        image_replacement = PromptReplacement(
            modality="image",
            target=[token_id],
            replacement=_expand,
        )
        return [image_replacement]
|  | ||||
|  | ||||
| class MiniMaxVL01MultiModalProcessor( | ||||
|         BaseMiniMaxVL01MultiModalProcessor[MiniMaxVL01ProcessingInfo]): | ||||
|         BaseLlavaMultiModalProcessor[MiniMaxVL01ProcessingInfo]): | ||||
|  | ||||
|     def _call_hf_processor( | ||||
|         self, | ||||
| @ -314,10 +121,9 @@ class MiniMaxVL01MultiModalProcessor( | ||||
|  | ||||
|         pixel_values = processed_outputs.get("pixel_values") | ||||
|         if pixel_values is not None: | ||||
|             # Avoid padding since we need the output for each image to be | ||||
|             # independent of other images for the cache to work correctly | ||||
|             image_sizes = processed_outputs["image_sizes"] | ||||
|             min_len = min(len(pixel_values), len(image_sizes)) | ||||
|             pixel_values = pixel_values[:min_len] | ||||
|             image_sizes = image_sizes[:min_len] | ||||
|             assert len(pixel_values) == len(image_sizes) | ||||
|  | ||||
|             processed_outputs["pixel_values"] = [ | ||||
| @ -337,65 +143,6 @@ class MiniMaxVL01MultiModalProcessor( | ||||
|         } | ||||
|  | ||||
|  | ||||
def _get_num_hidden_layers(hf_config: MiniMaxVL01LikeConfig) -> int:
    """Determine the number of hidden layers to initialize up to in the
    visual encoder.

    Args:
        hf_config: Model config with vision feature layer(s).

    Returns:
        The layer count needed for ``hf_config.vision_feature_layer``: the
        resolved index for a single int, or the largest resolved index for a
        list/tuple of ints.

    Raises:
        TypeError: If ``vision_feature_layer`` is neither an int nor a
            list/tuple.
    """
    feature_layers = hf_config.vision_feature_layer
    num_hidden_layers = hf_config.vision_config.num_hidden_layers
    # If we have one feature layer, initialize up to that layer
    if isinstance(feature_layers, int):
        return _get_layer_index(feature_layers, num_hidden_layers)
    # If we have multiple feature layers, initialize up to the deepest one
    if isinstance(feature_layers, (list, tuple)):
        return max(
            _get_layer_index(idx, num_hidden_layers) for idx in feature_layers)
    # Name the attribute as it is actually spelled on the config so the
    # message can be matched against the user's configuration.
    raise TypeError(f"vision_feature_layer type: {type(feature_layers)}"
                    " is not supported")
|  | ||||
|  | ||||
| def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int: | ||||
|     """Given a signed vision feature layer, get the number of hidden layers | ||||
|     needed to leverage it. | ||||
|  | ||||
|     Args: | ||||
|         feature_layer_index: Index of a required layer in the visual encoder. | ||||
|         num_hidden_layers: The total number of hidden layers in the visual | ||||
|             encoder. | ||||
|     """ | ||||
|     if feature_layer_index < 0: | ||||
|         return num_hidden_layers + feature_layer_index + 1 | ||||
|     return feature_layer_index | ||||
|  | ||||
|  | ||||
def init_vision_tower_for_MiniMaxVL01(
    hf_config: MiniMaxVL01LikeConfig,
    quant_config: Optional[QuantizationConfig],
    *,
    require_post_norm: Optional[bool] = None,
    prefix: str = "",
) -> Union[CLIPVisionModel, SiglipVisionModel, PixtralHFVisionModel]:
    """Construct the vision tower described by ``hf_config.vision_config``.

    Only ``CLIPVisionConfig`` is handled here, even though the return
    annotation also lists the Siglip and Pixtral models.

    Args:
        hf_config: Model config providing ``vision_config`` and the vision
            feature layer(s).
        quant_config: Quantization config forwarded to the vision model.
        require_post_norm: Forwarded to the vision model unchanged.
        prefix: Forwarded to the vision model unchanged.

    Raises:
        NotImplementedError: If the vision config is not a
            ``CLIPVisionConfig``.
    """
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the deepest required feature layer
    layers_to_build = _get_num_hidden_layers(hf_config)

    if not isinstance(vision_config, CLIPVisionConfig):
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    return CLIPVisionModel(
        vision_config,
        quant_config=quant_config,
        num_hidden_layers_override=layers_to_build,
        require_post_norm=require_post_norm,
        prefix=prefix,
    )
|  | ||||
|  | ||||
| @MULTIMODAL_REGISTRY.register_processor( | ||||
|     MiniMaxVL01MultiModalProcessor, | ||||
|     info=MiniMaxVL01ProcessingInfo, | ||||
| @ -419,7 +166,7 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, | ||||
|         self.multimodal_config = multimodal_config | ||||
|  | ||||
|         # TODO: Optionally initializes this for supporting embeddings. | ||||
|         self.vision_tower = init_vision_tower_for_MiniMaxVL01( | ||||
|         self.vision_tower = init_vision_tower_for_llava( | ||||
|             config, | ||||
|             quant_config, | ||||
|             require_post_norm=False, | ||||
| @ -476,7 +223,8 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, | ||||
|  | ||||
|     def _image_pixels_to_features( | ||||
|         self, | ||||
|         vision_tower: Union[CLIPVisionModel], | ||||
|         vision_tower: Union[CLIPVisionModel, SiglipVisionModel, | ||||
|                             PixtralHFVisionModel], | ||||
|         pixel_values: Union[torch.Tensor, list[torch.Tensor]], | ||||
|     ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: | ||||
|         # NOTE: we skip the step to select the vision feature layer since | ||||
| @ -496,7 +244,7 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, | ||||
|  | ||||
|     def _process_image_pixels( | ||||
|         self, | ||||
|         inputs: Union[MiniMaxVL01ImagePixelInputs], | ||||
|         inputs: MiniMaxVL01ImagePixelInputs, | ||||
|     ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: | ||||
|         assert self.vision_tower is not None | ||||
|  | ||||
| @ -506,7 +254,7 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, | ||||
|  | ||||
|     def _process_image_input( | ||||
|         self, | ||||
|         image_input: MiniMaxVL01ImagePixelInputs, | ||||
|         image_input: MiniMaxVL01ImageInputs, | ||||
|     ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: | ||||
|         if image_input["type"] == "image_embeds": | ||||
|             return image_input["data"] | ||||
| @ -539,7 +287,7 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, | ||||
|         return data | ||||
|  | ||||
|     def _parse_and_validate_image_input( | ||||
|             self, **kwargs: object) -> Optional[MiniMaxVL01ImagePixelInputs]: | ||||
|             self, **kwargs: object) -> Optional[MiniMaxVL01ImageInputs]: | ||||
|         pixel_values = kwargs.pop("pixel_values", None) | ||||
|         image_embeds = kwargs.pop("image_embeds", None) | ||||
|  | ||||
|  | ||||
		Reference in New Issue
	
	Block a user