mirror of
				https://github.com/vllm-project/vllm.git
				synced 2025-10-20 23:03:52 +08:00 
			
		
		
		
	Compare commits
	
		
			1 Commits
		
	
	
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 0405645a6c | 
| @ -530,6 +530,39 @@ def run_qwen2_vl(question: str, modality: str): | |||||||
|     return llm, prompt, stop_token_ids |     return llm, prompt, stop_token_ids | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # Qwen2-VL | ||||||
|  | def run_qwen2_5_vl(question: str, modality: str): | ||||||
|  |  | ||||||
|  |     model_name = "Qwen/Qwen2.5-VL-3B-Instruct" | ||||||
|  |  | ||||||
|  |     llm = LLM( | ||||||
|  |         model=model_name, | ||||||
|  |         max_model_len=4096, | ||||||
|  |         max_num_seqs=5, | ||||||
|  |         mm_processor_kwargs={ | ||||||
|  |             "min_pixels": 28 * 28, | ||||||
|  |             "max_pixels": 256 * 28 * 28, | ||||||
|  |         }, | ||||||
|  |         disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, | ||||||
|  |         limit_mm_per_prompt={ | ||||||
|  |             "image": 1, | ||||||
|  |             "video": 0 | ||||||
|  |         }, | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     if modality == "image": | ||||||
|  |         placeholder = "<|image_pad|>" | ||||||
|  |     elif modality == "video": | ||||||
|  |         placeholder = "<|video_pad|>" | ||||||
|  |  | ||||||
|  |     prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" | ||||||
|  |               f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>" | ||||||
|  |               f"{question}<|im_end|>\n" | ||||||
|  |               "<|im_start|>assistant\n") | ||||||
|  |     stop_token_ids = None | ||||||
|  |     return llm, prompt, stop_token_ids | ||||||
|  |  | ||||||
|  |  | ||||||
| model_example_map = { | model_example_map = { | ||||||
|     "aria": run_aria, |     "aria": run_aria, | ||||||
|     "blip-2": run_blip2, |     "blip-2": run_blip2, | ||||||
| @ -556,6 +589,7 @@ model_example_map = { | |||||||
|     "pixtral_hf": run_pixtral_hf, |     "pixtral_hf": run_pixtral_hf, | ||||||
|     "qwen_vl": run_qwen_vl, |     "qwen_vl": run_qwen_vl, | ||||||
|     "qwen2_vl": run_qwen2_vl, |     "qwen2_vl": run_qwen2_vl, | ||||||
|  |     "qwen2_5_vl": run_qwen2_5_vl, | ||||||
| } | } | ||||||
|  |  | ||||||
|  |  | ||||||
|  | |||||||
							
								
								
									
										748
									
								
								vllm/model_executor/models/qwen2_5_vl.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										748
									
								
								vllm/model_executor/models/qwen2_5_vl.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,748 @@ | |||||||
|  | # Adapted from | ||||||
|  | # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py | ||||||
|  | # Copyright 2025 The vLLM team. | ||||||
|  | # Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. | ||||||
|  | # | ||||||
|  | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX | ||||||
|  | # and OPT implementations in this library. It has been modified from its | ||||||
|  | # original forms to accommodate minor architectural differences compared | ||||||
|  | # to GPT-NeoX and OPT used by the Meta AI team that trained the model. | ||||||
|  | # | ||||||
|  | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | # you may not use this file except in compliance with the License. | ||||||
|  | # You may obtain a copy of the License at | ||||||
|  | # | ||||||
|  | #     http://www.apache.org/licenses/LICENSE-2.0 | ||||||
|  | # | ||||||
|  | # Unless required by applicable law or agreed to in writing, software | ||||||
|  | # distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | # See the License for the specific language governing permissions and | ||||||
|  | # limitations under the License. | ||||||
|  | """Inference-only Qwen2-VL model compatible with HuggingFace weights.""" | ||||||
|  | from functools import cached_property, partial | ||||||
|  | from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional, | ||||||
|  |                     Set, Tuple, Type, TypedDict, Union) | ||||||
|  |  | ||||||
|  | import torch | ||||||
|  | import torch.nn as nn | ||||||
|  | import torch.nn.functional as F | ||||||
|  | from einops import rearrange, repeat | ||||||
|  | from transformers import BatchFeature | ||||||
|  | from transformers.models.qwen2_5_vl import (Qwen2_5_VLImageProcessor, | ||||||
|  |                                             Qwen2_5_VLProcessor) | ||||||
|  | from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import ( | ||||||
|  |     Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig) | ||||||
|  | from transformers.models.qwen2_5_vl.image_processing_qwen2_5_vl import ( | ||||||
|  |     smart_resize) | ||||||
|  | from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( | ||||||
|  |     Qwen2_5_VisionTransformerPretrainedModel) | ||||||
|  |  | ||||||
|  | from vllm.attention import AttentionMetadata | ||||||
|  | from vllm.config import VllmConfig | ||||||
|  | from vllm.distributed import parallel_state, tensor_model_parallel_all_gather | ||||||
|  | from vllm.distributed import utils as dist_utils | ||||||
|  | from vllm.logger import init_logger | ||||||
|  | from vllm.model_executor import SamplingMetadata | ||||||
|  | from vllm.model_executor.layers.activation import QuickGELU | ||||||
|  | from vllm.model_executor.layers.linear import (ColumnParallelLinear, | ||||||
|  |                                                RowParallelLinear) | ||||||
|  | from vllm.model_executor.layers.quantization import QuantizationConfig | ||||||
|  | from vllm.model_executor.layers.quantization.gptq import GPTQConfig | ||||||
|  | from vllm.model_executor.layers.quantization.gptq_marlin import ( | ||||||
|  |     GPTQMarlinConfig) | ||||||
|  | from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler | ||||||
|  | from vllm.model_executor.model_loader.weight_utils import default_weight_loader | ||||||
|  | from vllm.model_executor.models.module_mapping import MultiModelKeys | ||||||
|  | from vllm.multimodal import MULTIMODAL_REGISTRY | ||||||
|  | from vllm.multimodal.inputs import (ImageItem, ModalityData, | ||||||
|  |                                     MultiModalFieldConfig, MultiModalKwargs, | ||||||
|  |                                     VideoItem) | ||||||
|  | from vllm.multimodal.parse import (ImageSize, ModalityDataItems, | ||||||
|  |                                    MultiModalDataItems, MultiModalDataParser) | ||||||
|  | from vllm.multimodal.processing import (BaseMultiModalProcessor, | ||||||
|  |                                         BaseProcessingInfo, PromptReplacement) | ||||||
|  | from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs | ||||||
|  | from vllm.platforms import _Backend | ||||||
|  | from vllm.sequence import IntermediateTensors | ||||||
|  | from vllm.transformers_utils.config import uses_mrope | ||||||
|  |  | ||||||
|  | from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP | ||||||
|  | from .utils import (AutoWeightsLoader, WeightsMapper, | ||||||
|  |                     init_vllm_registered_model, maybe_prefix, | ||||||
|  |                     merge_multimodal_embeddings) | ||||||
|  | from .vision import get_vit_attn_backend | ||||||
|  |  | ||||||
|  | logger = init_logger(__name__) | ||||||
|  |  | ||||||
|  | # For profile run | ||||||
|  | _MAX_FRAMES_PER_VIDEO = 16 | ||||||
|  |  | ||||||
|  | # === Vision Inputs === # | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Qwen2_5_VLImagePixelInputs(TypedDict): | ||||||
|  |     type: Literal["pixel_values"] | ||||||
|  |     pixel_values: torch.Tensor | ||||||
|  |     """Shape: | ||||||
|  |     `(num_patches, num_channels * patch_size * patch_size)` | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |     image_grid_thw: torch.Tensor | ||||||
|  |     """Shape: `(num_images, 3)` | ||||||
|  |     This should be in `(grid_t, grid_h, grid_w)` format. | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Qwen2_5_VLVideoPixelInputs(TypedDict): | ||||||
|  |     type: Literal["pixel_values_videos"] | ||||||
|  |     pixel_values_videos: torch.Tensor | ||||||
|  |     """Shape: | ||||||
|  |     `(num_patches, | ||||||
|  |       num_channels * temporal_patch_size * patch_size * patch_size)` | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |     video_grid_thw: torch.Tensor | ||||||
|  |     """Shape: `(num_videos, 3)` | ||||||
|  |  | ||||||
|  |     This should be in `(grid_t, grid_h, grid_w)` format. | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Qwen2_5_VLProcessingInfo(BaseProcessingInfo): | ||||||
|  |  | ||||||
|  |     def get_hf_config(self): | ||||||
|  |         return self.ctx.get_hf_config(Qwen2_5_VLConfig) | ||||||
|  |  | ||||||
|  |     def get_hf_processor( | ||||||
|  |         self, | ||||||
|  |         *, | ||||||
|  |         min_pixels: Optional[int] = None, | ||||||
|  |         max_pixels: Optional[int] = None, | ||||||
|  |     ) -> Qwen2_5_VLProcessor: | ||||||
|  |         hf_processor = self.ctx.get_hf_processor(Qwen2_5_VLProcessor) | ||||||
|  |         image_processor = hf_processor.image_processor  # type: ignore | ||||||
|  |         assert isinstance(image_processor, Qwen2_5_VLImageProcessor) | ||||||
|  |  | ||||||
|  |         if min_pixels: | ||||||
|  |             image_processor.min_pixels = min_pixels | ||||||
|  |         if max_pixels: | ||||||
|  |             image_processor.max_pixels = max_pixels | ||||||
|  |         if max_pixels or min_pixels: | ||||||
|  |             image_processor.size = { | ||||||
|  |                 "min_pixels": image_processor.min_pixels, | ||||||
|  |                 "max_pixels": image_processor.max_pixels, | ||||||
|  |             } | ||||||
|  |  | ||||||
|  |         return hf_processor | ||||||
|  |  | ||||||
|  |     def get_image_processor( | ||||||
|  |         self, | ||||||
|  |         *, | ||||||
|  |         min_pixels: Optional[int] = None, | ||||||
|  |         max_pixels: Optional[int] = None, | ||||||
|  |     ): | ||||||
|  |         hf_processor = self.get_hf_processor(min_pixels=min_pixels, | ||||||
|  |                                              max_pixels=max_pixels) | ||||||
|  |         image_processor = hf_processor.image_processor  # type: ignore | ||||||
|  |         assert isinstance(image_processor, Qwen2_5_VLImageProcessor) | ||||||
|  |         return image_processor | ||||||
|  |  | ||||||
|  |     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: | ||||||
|  |         return {"image": None, "video": None} | ||||||
|  |  | ||||||
|  |     def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: | ||||||
|  |         return { | ||||||
|  |             "image": self.get_max_image_tokens(), | ||||||
|  |             "video": self.get_max_video_tokens(seq_len), | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |     def _get_vision_info( | ||||||
|  |         self, | ||||||
|  |         *, | ||||||
|  |         image_width: int, | ||||||
|  |         image_height: int, | ||||||
|  |         num_frames: int = 1, | ||||||
|  |         do_resize: bool = True, | ||||||
|  |         image_processor: Optional[Qwen2_5_VLImageProcessor], | ||||||
|  |     ) -> tuple[ImageSize, int]: | ||||||
|  |         if image_processor is None: | ||||||
|  |             image_processor = self.get_image_processor() | ||||||
|  |  | ||||||
|  |         hf_config = self.get_hf_config() | ||||||
|  |         vision_config = hf_config.vision_config | ||||||
|  |         patch_size = vision_config.patch_size | ||||||
|  |         merge_size = vision_config.spatial_merge_size | ||||||
|  |         temporal_patch_size = vision_config.temporal_patch_size | ||||||
|  |  | ||||||
|  |         if do_resize: | ||||||
|  |             resized_height, resized_width = smart_resize( | ||||||
|  |                 height=image_height, | ||||||
|  |                 width=image_width, | ||||||
|  |                 factor=patch_size * merge_size, | ||||||
|  |                 min_pixels=image_processor.min_pixels, | ||||||
|  |                 max_pixels=image_processor.max_pixels, | ||||||
|  |             ) | ||||||
|  |             preprocessed_size = ImageSize(width=resized_width, | ||||||
|  |                                           height=resized_height) | ||||||
|  |         else: | ||||||
|  |             preprocessed_size = ImageSize(width=image_width, | ||||||
|  |                                           height=image_height) | ||||||
|  |  | ||||||
|  |         grid_t = max(num_frames // temporal_patch_size, 1) | ||||||
|  |         grid_h = preprocessed_size.height // patch_size | ||||||
|  |         grid_w = preprocessed_size.width // patch_size | ||||||
|  |  | ||||||
|  |         num_patches = grid_t * grid_h * grid_w | ||||||
|  |         num_vision_tokens = num_patches // (merge_size**2) | ||||||
|  |  | ||||||
|  |         return preprocessed_size, num_vision_tokens | ||||||
|  |  | ||||||
|  |     def get_num_image_tokens( | ||||||
|  |         self, | ||||||
|  |         *, | ||||||
|  |         image_width: int, | ||||||
|  |         image_height: int, | ||||||
|  |         image_processor: Optional[Qwen2_5_VLImageProcessor], | ||||||
|  |     ) -> int: | ||||||
|  |         _, num_image_tokens = self._get_vision_info( | ||||||
|  |             image_width=image_width, | ||||||
|  |             image_height=image_height, | ||||||
|  |             image_processor=image_processor, | ||||||
|  |         ) | ||||||
|  |         return num_image_tokens | ||||||
|  |  | ||||||
|  |     def get_num_video_tokens( | ||||||
|  |         self, | ||||||
|  |         *, | ||||||
|  |         image_width: int, | ||||||
|  |         image_height: int, | ||||||
|  |         num_frames: int, | ||||||
|  |         image_processor: Optional[Qwen2_5_VLImageProcessor], | ||||||
|  |     ) -> int: | ||||||
|  |         _, num_video_tokens = self._get_vision_info( | ||||||
|  |             image_width=image_width, | ||||||
|  |             image_height=image_height, | ||||||
|  |             num_frames=num_frames, | ||||||
|  |             image_processor=image_processor, | ||||||
|  |         ) | ||||||
|  |         return num_video_tokens | ||||||
|  |  | ||||||
|  |     def get_image_size_with_most_features(self) -> ImageSize: | ||||||
|  |         max_image_size, _ = self._get_vision_info( | ||||||
|  |             image_width=9999999, | ||||||
|  |             image_height=9999999, | ||||||
|  |             image_processor=None, | ||||||
|  |         ) | ||||||
|  |         return max_image_size | ||||||
|  |  | ||||||
|  |     def get_max_image_tokens(self) -> int: | ||||||
|  |         target_width, target_height = self.get_image_size_with_most_features() | ||||||
|  |  | ||||||
|  |         return self.get_num_image_tokens( | ||||||
|  |             image_width=target_width, | ||||||
|  |             image_height=target_height, | ||||||
|  |             image_processor=None, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def _get_max_video_frames(self, max_tokens: int) -> int: | ||||||
|  |         target_width, target_height = self.get_image_size_with_most_features() | ||||||
|  |  | ||||||
|  |         num_frames = 0 | ||||||
|  |  | ||||||
|  |         while True: | ||||||
|  |             next_num_frames = num_frames + 1 | ||||||
|  |             next_max_tokens = self.get_num_video_tokens( | ||||||
|  |                 image_width=target_width, | ||||||
|  |                 image_height=target_height, | ||||||
|  |                 num_frames=next_num_frames, | ||||||
|  |                 image_processor=None, | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |             if next_max_tokens > max_tokens: | ||||||
|  |                 break | ||||||
|  |  | ||||||
|  |             num_frames = next_num_frames | ||||||
|  |  | ||||||
|  |         return num_frames | ||||||
|  |  | ||||||
|  |     def get_num_frames_with_most_features(self, seq_len: int) -> int: | ||||||
|  |         mm_config = self.ctx.get_mm_config() | ||||||
|  |         max_images = mm_config.limit_per_prompt.get("image", 1) | ||||||
|  |         max_videos = mm_config.limit_per_prompt.get("video", 1) | ||||||
|  |  | ||||||
|  |         max_image_tokens = self.get_max_image_tokens() * max_images | ||||||
|  |         max_total_frames = self._get_max_video_frames(seq_len - | ||||||
|  |                                                       max_image_tokens) | ||||||
|  |         num_frames = min(max(max_total_frames // max(max_videos, 1), 1), | ||||||
|  |                          _MAX_FRAMES_PER_VIDEO) | ||||||
|  |  | ||||||
|  |         # Temporary workaround for https://github.com/huggingface/transformers/issues/35412 | ||||||
|  |         if num_frames > 1 and num_frames % 2 == 1: | ||||||
|  |             num_frames += 1 | ||||||
|  |  | ||||||
|  |         return num_frames | ||||||
|  |  | ||||||
|  |     def get_max_video_tokens(self, seq_len: int) -> int: | ||||||
|  |         target_width, target_height = self.get_image_size_with_most_features() | ||||||
|  |  | ||||||
|  |         return self.get_num_video_tokens( | ||||||
|  |             image_width=target_width, | ||||||
|  |             image_height=target_height, | ||||||
|  |             num_frames=self.get_num_frames_with_most_features(seq_len), | ||||||
|  |             image_processor=None, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Qwen2_5_VLDummyInputsBuilder( | ||||||
|  |         BaseDummyInputsBuilder[Qwen2_5_VLProcessingInfo]): | ||||||
|  |  | ||||||
|  |     def get_dummy_processor_inputs( | ||||||
|  |         self, | ||||||
|  |         seq_len: int, | ||||||
|  |         mm_counts: Mapping[str, int], | ||||||
|  |     ) -> ProcessorInputs: | ||||||
|  |         num_images = mm_counts.get("image", 0) | ||||||
|  |         num_videos = mm_counts.get("video", 0) | ||||||
|  |  | ||||||
|  |         hf_processor = self.info.get_hf_processor() | ||||||
|  |         image_token: str = hf_processor.image_token | ||||||
|  |         video_token: str = hf_processor.video_token | ||||||
|  |  | ||||||
|  |         target_width, target_height = \ | ||||||
|  |             self.info.get_image_size_with_most_features() | ||||||
|  |         target_num_frames = \ | ||||||
|  |             self.info.get_num_frames_with_most_features(seq_len) | ||||||
|  |  | ||||||
|  |         mm_data = { | ||||||
|  |             "image": | ||||||
|  |             self._get_dummy_images(width=target_width, | ||||||
|  |                                    height=target_height, | ||||||
|  |                                    num_images=num_images), | ||||||
|  |             "video": | ||||||
|  |             self._get_dummy_videos( | ||||||
|  |                 width=target_width, | ||||||
|  |                 height=target_height, | ||||||
|  |                 num_frames=target_num_frames, | ||||||
|  |                 num_videos=num_videos, | ||||||
|  |             ) | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         return ProcessorInputs( | ||||||
|  |             prompt_text=image_token * num_images + video_token * num_videos, | ||||||
|  |             mm_data=mm_data, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Qwen2_5_VLMultiModalProcessor( | ||||||
|  |         BaseMultiModalProcessor[Qwen2_5_VLProcessingInfo]): | ||||||
|  |  | ||||||
|  |     def _get_prompt_replacements( | ||||||
|  |         self, | ||||||
|  |         mm_items: MultiModalDataItems, | ||||||
|  |         hf_processor_mm_kwargs: Mapping[str, Any], | ||||||
|  |         out_mm_kwargs: MultiModalKwargs, | ||||||
|  |     ) -> list[PromptReplacement]: | ||||||
|  |         hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) | ||||||
|  |         image_processor = self.info.get_image_processor( | ||||||
|  |             **hf_processor_mm_kwargs) | ||||||
|  |         tokenizer = self.info.get_tokenizer() | ||||||
|  |         vocab = tokenizer.get_vocab() | ||||||
|  |  | ||||||
|  |         # NOTE: Only Qwen2_5_VLProcessor in transformers 4.47.0 has | ||||||
|  |         # image_token and video_token registered | ||||||
|  |         placeholder = { | ||||||
|  |             "image": vocab[hf_processor.image_token], | ||||||
|  |             "video": vocab[hf_processor.video_token], | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         merge_length = image_processor.merge_size**2 | ||||||
|  |  | ||||||
|  |         def get_replacement_Qwen2_5_VL(item_idx: int, modality: str): | ||||||
|  |             grid_thw = out_mm_kwargs[f"{modality}_grid_thw"][item_idx] | ||||||
|  |             assert isinstance(grid_thw, torch.Tensor) | ||||||
|  |  | ||||||
|  |             num_tokens = int(grid_thw.prod()) // merge_length | ||||||
|  |             return [placeholder[modality]] * num_tokens | ||||||
|  |  | ||||||
|  |         return [ | ||||||
|  |             PromptReplacement( | ||||||
|  |                 modality=modality, | ||||||
|  |                 target=[placeholder[modality]], | ||||||
|  |                 replacement=partial(get_replacement_Qwen2_5_VL, | ||||||
|  |                                     modality=modality), | ||||||
|  |             ) for modality in ("image", "video") | ||||||
|  |         ] | ||||||
|  |  | ||||||
|  |     def _get_mm_fields_config( | ||||||
|  |         self, | ||||||
|  |         hf_inputs: BatchFeature, | ||||||
|  |         hf_processor_mm_kwargs: Mapping[str, object], | ||||||
|  |     ) -> Mapping[str, MultiModalFieldConfig]: | ||||||
|  |         image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3))) | ||||||
|  |         image_slice_idxs = [0] + image_grid_thw.prod(-1).cumsum_(0).tolist() | ||||||
|  |         image_slices = [ | ||||||
|  |             slice(image_slice_idxs[i], image_slice_idxs[i + 1]) | ||||||
|  |             for i in range(len(image_grid_thw)) | ||||||
|  |         ] | ||||||
|  |  | ||||||
|  |         video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3))) | ||||||
|  |         video_slice_idxs = [0] + video_grid_thw.prod(-1).cumsum_(0).tolist() | ||||||
|  |         video_slices = [ | ||||||
|  |             slice(video_slice_idxs[i], video_slice_idxs[i + 1]) | ||||||
|  |             for i in range(len(video_grid_thw)) | ||||||
|  |         ] | ||||||
|  |  | ||||||
|  |         return dict( | ||||||
|  |             pixel_values=MultiModalFieldConfig.flat("image", image_slices), | ||||||
|  |             image_grid_thw=MultiModalFieldConfig.batched("image"), | ||||||
|  |             pixel_values_videos=MultiModalFieldConfig.flat( | ||||||
|  |                 "video", video_slices), | ||||||
|  |             video_grid_thw=MultiModalFieldConfig.batched("video"), | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @MULTIMODAL_REGISTRY.register_processor( | ||||||
|  |     Qwen2_5_VLMultiModalProcessor, | ||||||
|  |     info=Qwen2_5_VLProcessingInfo, | ||||||
|  |     dummy_inputs=Qwen2_5_VLDummyInputsBuilder) | ||||||
|  | class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, | ||||||
|  |                                          SupportsLoRA, SupportsPP): | ||||||
|  |     packed_modules_mapping = { | ||||||
|  |         "qkv_proj": [ | ||||||
|  |             "q_proj", | ||||||
|  |             "k_proj", | ||||||
|  |             "v_proj", | ||||||
|  |         ], | ||||||
|  |         "gate_up_proj": [ | ||||||
|  |             "gate_proj", | ||||||
|  |             "up_proj", | ||||||
|  |         ], | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     # LoRA specific attributes | ||||||
|  |     supported_lora_modules = [ | ||||||
|  |         "qkv_proj", | ||||||
|  |         "o_proj", | ||||||
|  |         "gate_up_proj", | ||||||
|  |         "down_proj", | ||||||
|  |         # vision tower | ||||||
|  |         "qkv", | ||||||
|  |         "attn.proj",  # Distinguish patch_embed.proj | ||||||
|  |         "fc1", | ||||||
|  |         "fc2", | ||||||
|  |         # projector | ||||||
|  |         "mlp.0", | ||||||
|  |         "mlp.2" | ||||||
|  |     ] | ||||||
|  |     embedding_modules = {} | ||||||
|  |     embedding_padding_modules = [] | ||||||
|  |  | ||||||
|  |     # To ensure correct weight loading and mapping. | ||||||
|  |     hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={ | ||||||
|  |         "lm_head.": "language_model.lm_head.", | ||||||
|  |         "model.": "language_model.model.", | ||||||
|  |     }) | ||||||
|  |  | ||||||
|  |     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): | ||||||
|  |         super().__init__() | ||||||
|  |         config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config | ||||||
|  |         quant_config = vllm_config.quant_config | ||||||
|  |         multimodal_config = vllm_config.model_config.multimodal_config | ||||||
|  |  | ||||||
|  |         self.config = config | ||||||
|  |         self.multimodal_config = multimodal_config | ||||||
|  |  | ||||||
|  |         self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config( | ||||||
|  |             config.vision_config) | ||||||
|  |  | ||||||
|  |         self.language_model = init_vllm_registered_model( | ||||||
|  |             vllm_config=vllm_config, | ||||||
|  |             prefix=maybe_prefix(prefix, "language_model"), | ||||||
|  |             architectures=["Qwen2ForCausalLM"], | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         self.make_empty_intermediate_tensors = ( | ||||||
|  |             self.language_model.make_empty_intermediate_tensors) | ||||||
|  |  | ||||||
|  |     @cached_property | ||||||
|  |     def sampler(self): | ||||||
|  |         if hasattr(self.language_model, "sampler"): | ||||||
|  |             return self.language_model.sampler | ||||||
|  |  | ||||||
|  |         return get_sampler() | ||||||
|  |  | ||||||
|  |     def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig): | ||||||
|  |         # GPTQ configs do not have a list of ignored modules, however AutoGPTQ | ||||||
|  |         # seems to avoid vision encoder sections for some models. | ||||||
|  |         # See: https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4 | ||||||
|  |         if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)): | ||||||
|  |             return None | ||||||
|  |         return quant_config | ||||||
|  |  | ||||||
|  |     def _validate_and_reshape_mm_tensor(self, mm_input: object, | ||||||
|  |                                         name: str) -> torch.Tensor: | ||||||
|  |         if not isinstance(mm_input, (torch.Tensor, list)): | ||||||
|  |             raise ValueError(f"Incorrect type of {name}. " | ||||||
|  |                              f"Got type: {type(mm_input)}") | ||||||
|  |         if isinstance(mm_input, torch.Tensor): | ||||||
|  |             if mm_input.ndim == 2: | ||||||
|  |                 return mm_input | ||||||
|  |             if mm_input.ndim != 3: | ||||||
|  |                 raise ValueError(f"{name} should be 2D or batched 3D tensor. " | ||||||
|  |                                  f"Got ndim: {mm_input.ndim} " | ||||||
|  |                                  f"(shape={mm_input.shape})") | ||||||
|  |             return torch.concat(list(mm_input)) | ||||||
|  |         else: | ||||||
|  |             return torch.concat(mm_input) | ||||||
|  |  | ||||||
|  |     def _parse_and_validate_image_input( | ||||||
|  |             self, **kwargs: object) -> Optional[Qwen2_5_VLImagePixelInputs]: | ||||||
|  |         pixel_values = kwargs.pop("pixel_values", None) | ||||||
|  |         image_grid_thw = kwargs.pop("image_grid_thw", None) | ||||||
|  |  | ||||||
|  |         if pixel_values is None: | ||||||
|  |             return None | ||||||
|  |  | ||||||
|  |         if pixel_values is not None: | ||||||
|  |             pixel_values = self._validate_and_reshape_mm_tensor( | ||||||
|  |                 pixel_values, "image pixel values") | ||||||
|  |             image_grid_thw = self._validate_and_reshape_mm_tensor( | ||||||
|  |                 image_grid_thw, "image grid_thw") | ||||||
|  |  | ||||||
|  |             if not isinstance(pixel_values, (torch.Tensor, list)): | ||||||
|  |                 raise ValueError("Incorrect type of image pixel values. " | ||||||
|  |                                  f"Got type: {type(pixel_values)}") | ||||||
|  |  | ||||||
|  |             return Qwen2_5_VLImagePixelInputs(type="pixel_values", | ||||||
|  |                                               pixel_values=pixel_values, | ||||||
|  |                                               image_grid_thw=image_grid_thw) | ||||||
|  |         raise | ||||||
|  |  | ||||||
|  |     def _parse_and_validate_video_input( | ||||||
|  |             self, **kwargs: object) -> Optional[Qwen2_5_VLVideoPixelInputs]: | ||||||
|  |         pixel_values_videos = kwargs.pop("pixel_values_videos", None) | ||||||
|  |         video_grid_thw = kwargs.pop("video_grid_thw", None) | ||||||
|  |  | ||||||
|  |         if pixel_values_videos is None: | ||||||
|  |             return None | ||||||
|  |  | ||||||
|  |         if pixel_values_videos is not None: | ||||||
|  |             pixel_values_videos = self._validate_and_reshape_mm_tensor( | ||||||
|  |                 pixel_values_videos, "video pixel values") | ||||||
|  |             video_grid_thw = self._validate_and_reshape_mm_tensor( | ||||||
|  |                 video_grid_thw, "video grid_thw") | ||||||
|  |  | ||||||
|  |             return Qwen2_5_VLVideoPixelInputs( | ||||||
|  |                 type="pixel_values_videos", | ||||||
|  |                 pixel_values_videos=pixel_values_videos, | ||||||
|  |                 video_grid_thw=video_grid_thw, | ||||||
|  |             ) | ||||||
|  |         raise | ||||||
|  |  | ||||||
|  |     def _process_image_input( | ||||||
|  |             self, image_input: Qwen2_5_VLImagePixelInputs | ||||||
|  |     ) -> tuple[torch.Tensor, ...]: | ||||||
|  |  | ||||||
|  |         grid_thw = image_input["image_grid_thw"] | ||||||
|  |         assert grid_thw.ndim == 2 | ||||||
|  |  | ||||||
|  |         pixel_values = image_input["pixel_values"].type(self.visual.dtype) | ||||||
|  |         image_embeds = self.visual(pixel_values, grid_thw=grid_thw) | ||||||
|  |  | ||||||
|  |         # Split concatenated embeddings for each image item. | ||||||
|  |         merge_size = self.visual.spatial_merge_size | ||||||
|  |         sizes = grid_thw.prod(-1) // merge_size // merge_size | ||||||
|  |  | ||||||
|  |         return image_embeds.split(sizes.tolist()) | ||||||
|  |  | ||||||
|  |     def _process_video_input( | ||||||
|  |             self, video_input: Qwen2_5_VLVideoPixelInputs | ||||||
|  |     ) -> tuple[torch.Tensor, ...]: | ||||||
|  |  | ||||||
|  |         grid_thw = video_input["video_grid_thw"] | ||||||
|  |         assert grid_thw.ndim == 2 | ||||||
|  |  | ||||||
|  |         pixel_values_videos = video_input["pixel_values_videos"].type( | ||||||
|  |             self.visual.dtype) | ||||||
|  |         video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw) | ||||||
|  |  | ||||||
|  |         # Split concatenated embeddings for each video item. | ||||||
|  |         merge_size = self.visual.spatial_merge_size | ||||||
|  |         sizes = grid_thw.prod(-1) // merge_size // merge_size | ||||||
|  |  | ||||||
|  |         return video_embeds.split(sizes.tolist()) | ||||||
|  |  | ||||||
|  |     def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: | ||||||
|  |         modalities = {} | ||||||
|  |  | ||||||
|  |         # Preserve the order of modalities if there are multiple of them | ||||||
|  |         # from the order of kwargs. | ||||||
|  |         for input_key in kwargs: | ||||||
|  |             if input_key in ("pixel_values", | ||||||
|  |                              "image_embeds") and "images" not in modalities: | ||||||
|  |                 modalities["images"] = self._parse_and_validate_image_input( | ||||||
|  |                     **kwargs) | ||||||
|  |             if input_key in ("pixel_values_videos", | ||||||
|  |                              "video_embeds") and "videos" not in modalities: | ||||||
|  |                 modalities["videos"] = self._parse_and_validate_video_input( | ||||||
|  |                     **kwargs) | ||||||
|  |  | ||||||
|  |         return modalities | ||||||
|  |  | ||||||
|  |     def get_multimodal_embeddings( | ||||||
|  |             self, **kwargs) -> Optional[tuple[torch.Tensor, ...]]: | ||||||
|  |  | ||||||
|  |         modalities = self._parse_and_validate_multimodal_inputs(**kwargs) | ||||||
|  |         if not modalities: | ||||||
|  |             return None | ||||||
|  |  | ||||||
|  |         # The result multimodal_embeddings is tuple of tensors, with each | ||||||
|  |         # tensor correspoending to a multimodal data item (image or video). | ||||||
|  |         multimodal_embeddings: tuple[torch.Tensor, ...] = () | ||||||
|  |  | ||||||
|  |         # NOTE: It is important to iterate over the keys in this dictionary | ||||||
|  |         # to preserve the order of the modalities. | ||||||
|  |         for modality in modalities: | ||||||
|  |             if modality == "images": | ||||||
|  |                 image_input = modalities["images"] | ||||||
|  |                 vision_embeddings = self._process_image_input(image_input) | ||||||
|  |                 multimodal_embeddings += vision_embeddings | ||||||
|  |             if modality == "videos": | ||||||
|  |                 video_input = modalities["videos"] | ||||||
|  |                 video_embeddings = self._process_video_input(video_input) | ||||||
|  |                 multimodal_embeddings += video_embeddings | ||||||
|  |  | ||||||
|  |         return multimodal_embeddings | ||||||
|  |  | ||||||
|  |     def get_input_embeddings( | ||||||
|  |         self, | ||||||
|  |         input_ids: torch.Tensor, | ||||||
|  |         multimodal_embeddings: Optional[tuple[torch.Tensor, ...]] = None, | ||||||
|  |     ) -> torch.Tensor: | ||||||
|  |         inputs_embeds = self.language_model.get_input_embeddings(input_ids) | ||||||
|  |         if multimodal_embeddings is not None: | ||||||
|  |             inputs_embeds = merge_multimodal_embeddings( | ||||||
|  |                 input_ids, inputs_embeds, multimodal_embeddings, | ||||||
|  |                 [self.config.image_token_id, self.config.video_token_id]) | ||||||
|  |         return inputs_embeds | ||||||
|  |  | ||||||
|  |     def get_input_embeddings_v0( | ||||||
|  |         self, | ||||||
|  |         input_ids: torch.Tensor, | ||||||
|  |         image_input: Optional[tuple[torch.Tensor, ...]] = None, | ||||||
|  |         video_input: Optional[tuple[torch.Tensor, ...]] = None, | ||||||
|  |     ) -> torch.Tensor: | ||||||
|  |  | ||||||
|  |         inputs_embeds = self.get_input_embeddings(input_ids) | ||||||
|  |         if image_input is not None: | ||||||
|  |             image_embeds = self._process_image_input(image_input) | ||||||
|  |             inputs_embeds = merge_multimodal_embeddings( | ||||||
|  |                 input_ids, | ||||||
|  |                 inputs_embeds, | ||||||
|  |                 image_embeds, | ||||||
|  |                 placeholder_token_id=self.config.image_token_id, | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |         if video_input is not None: | ||||||
|  |             video_embeds = self._process_video_input(video_input) | ||||||
|  |             inputs_embeds = merge_multimodal_embeddings( | ||||||
|  |                 input_ids, | ||||||
|  |                 inputs_embeds, | ||||||
|  |                 video_embeds, | ||||||
|  |                 placeholder_token_id=self.config.video_token_id, | ||||||
|  |             ) | ||||||
|  |         return inputs_embeds | ||||||
|  |  | ||||||
|  |     def forward( | ||||||
|  |         self, | ||||||
|  |         input_ids: torch.Tensor, | ||||||
|  |         positions: torch.Tensor, | ||||||
|  |         kv_caches: List[torch.Tensor], | ||||||
|  |         attn_metadata: AttentionMetadata, | ||||||
|  |         intermediate_tensors: Optional[IntermediateTensors] = None, | ||||||
|  |         inputs_embeds: Optional[torch.Tensor] = None, | ||||||
|  |         **kwargs: object, | ||||||
|  |     ) -> Union[torch.Tensor, IntermediateTensors]: | ||||||
|  |         """Run forward pass for Qwen2.5-VL. | ||||||
|  |  | ||||||
|  |         Args: | ||||||
|  |             input_ids: Flattened (concatenated) input_ids corresponding to a | ||||||
|  |                 batch. | ||||||
|  |             positions: Flattened (concatenated) position ids corresponding to a | ||||||
|  |                 batch. | ||||||
|  |                 **NOTE**: If mrope is enabled (default setting for Qwen2-VL | ||||||
|  |                 opensource models), the shape will be `(3, seq_len)`, | ||||||
|  |                 otherwise it will be `(seq_len,). | ||||||
|  |             pixel_values: Pixel values to be fed to a model. | ||||||
|  |                 `None` if no images are passed. | ||||||
|  |             image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM. | ||||||
|  |                 `None` if no images are passed. | ||||||
|  |             pixel_values_videos: Pixel values of videos to be fed to a model. | ||||||
|  |                 `None` if no videos are passed. | ||||||
|  |             video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM. | ||||||
|  |                 `None` if no videos are passed. | ||||||
|  |         """ | ||||||
|  |  | ||||||
|  |         if intermediate_tensors is not None: | ||||||
|  |             inputs_embeds = None | ||||||
|  |  | ||||||
|  |         # NOTE: In v1, inputs_embeds is always generated at model runner from | ||||||
|  |         # `get_multimodal_embeddings` and `get_input_embeddings`, this | ||||||
|  |         # condition is only for v0 compatibility. | ||||||
|  |         elif inputs_embeds is None: | ||||||
|  |             image_input = self._parse_and_validate_image_input(**kwargs) | ||||||
|  |             video_input = self._parse_and_validate_video_input(**kwargs) | ||||||
|  |  | ||||||
|  |             if image_input is None and video_input is None: | ||||||
|  |                 inputs_embeds = None | ||||||
|  |             else: | ||||||
|  |                 if uses_mrope(self.config): | ||||||
|  |                     assert positions.ndim == 2 and positions.size(0) == 3, ( | ||||||
|  |                         "multimodal section rotary embedding requires " | ||||||
|  |                         f"(3, seq_len) positions, but got {positions.size()}") | ||||||
|  |                 inputs_embeds = self.get_input_embeddings_v0( | ||||||
|  |                     input_ids, | ||||||
|  |                     image_input=image_input, | ||||||
|  |                     video_input=video_input) | ||||||
|  |                 input_ids = None | ||||||
|  |  | ||||||
|  |         hidden_states = self.language_model.model( | ||||||
|  |             input_ids=input_ids, | ||||||
|  |             positions=positions, | ||||||
|  |             kv_caches=kv_caches, | ||||||
|  |             attn_metadata=attn_metadata, | ||||||
|  |             intermediate_tensors=intermediate_tensors, | ||||||
|  |             inputs_embeds=inputs_embeds, | ||||||
|  |         ) | ||||||
|  |         return hidden_states | ||||||
|  |  | ||||||
|  |     def compute_logits( | ||||||
|  |         self, | ||||||
|  |         hidden_states: torch.Tensor, | ||||||
|  |         sampling_metadata: SamplingMetadata, | ||||||
|  |     ) -> Optional[torch.Tensor]: | ||||||
|  |         return self.language_model.compute_logits(hidden_states, | ||||||
|  |                                                   sampling_metadata) | ||||||
|  |  | ||||||
|  |     def sample( | ||||||
|  |         self, | ||||||
|  |         logits: torch.Tensor, | ||||||
|  |         sampling_metadata: SamplingMetadata, | ||||||
|  |     ) -> Optional[SamplerOutput]: | ||||||
|  |         return self.language_model.sample(logits, sampling_metadata) | ||||||
|  |  | ||||||
|  |     def load_weights(self, weights: Iterable[Tuple[str, | ||||||
|  |                                                    torch.Tensor]]) -> Set[str]: | ||||||
|  |  | ||||||
|  |         loader = AutoWeightsLoader(self) | ||||||
|  |         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) | ||||||
|  |  | ||||||
|  |     def get_mm_mapping(self) -> MultiModelKeys: | ||||||
|  |         """ | ||||||
|  |         Get the module prefix in multimodal models | ||||||
|  |         """ | ||||||
|  |         return MultiModelKeys.from_string_field( | ||||||
|  |             language_model="language_model", | ||||||
|  |             connector="visual.", | ||||||
|  |             tower_model="visual.merger.") | ||||||
| @ -171,6 +171,7 @@ _MULTIMODAL_MODELS = { | |||||||
|     "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501 |     "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501 | ||||||
|     "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), |     "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), | ||||||
|     "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501 |     "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501 | ||||||
|  |     "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),  # noqa: E501 | ||||||
|     "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),  # noqa: E501 |     "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),  # noqa: E501 | ||||||
|     "UltravoxModel": ("ultravox", "UltravoxModel"), |     "UltravoxModel": ("ultravox", "UltravoxModel"), | ||||||
|     # [Encoder-decoder] |     # [Encoder-decoder] | ||||||
|  | |||||||
		Reference in New Issue
	
	Block a user
	