diff --git a/tests/models/multimodal/generation/test_common.py b/tests/models/multimodal/generation/test_common.py
index 656a6d3bd7..c57ccd62fe 100644
--- a/tests/models/multimodal/generation/test_common.py
+++ b/tests/models/multimodal/generation/test_common.py
@@ -222,8 +222,7 @@ VLM_TEST_SETTINGS = {
         vllm_runner_kwargs={
             "model_impl": "transformers",
         },
-        # FIXME: Investigate mrope issue
-        marks=[large_gpu_mark(min_gb=32), pytest.mark.skip(reason="Mrope issue")],
+        marks=[large_gpu_mark(min_gb=32)],
     ),
     #### Extended model tests
     "aria": VLMTestInfo(
diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py
index 25cd40a939..5b92aa97ea 100644
--- a/vllm/model_executor/models/transformers.py
+++ b/vllm/model_executor/models/transformers.py
@@ -79,7 +79,6 @@ from .utils import (
     AutoWeightsLoader,
     PPMissingLayer,
     WeightsMapper,
-    flatten_bn,
     make_empty_intermediate_tensors_factory,
     maybe_prefix,
 )
@@ -347,12 +346,12 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):

     def _get_mm_fields_config(
         self,
-        hf_inputs,
-        hf_processor_mm_kwargs,
-        num_image_patches: torch.Tensor = None,
-    ):
+        hf_inputs: BatchFeature,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, MultiModalFieldConfig]:
         # HF Processors always return a mask but vLLM doesn't need it
         hf_inputs.pop("attention_mask", None)
+        num_image_patches = hf_inputs.get("num_image_patches")
         mm_fields = {
             key: MultiModalFieldConfig.flat_from_sizes("image", num_image_patches)
             for key in hf_inputs
@@ -360,41 +359,24 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
         mm_fields["image_embeds"] = MultiModalFieldConfig.flat_from_sizes(
             "image", num_image_patches
         )
+
+        # Keep these as batched, as they always have batch size as first dim
+        mm_fields["image_grid_thw"] = MultiModalFieldConfig.batched("image")
+        mm_fields["video_grid_thw"] = MultiModalFieldConfig.batched("image")
         mm_fields["num_image_patches"] = MultiModalFieldConfig.batched("image")
         return mm_fields

-    def _apply_hf_processor_text_mm(
+    def _get_hf_mm_data(
         self,
-        prompt_text: str,
         mm_items: MultiModalDataItems,
-        hf_processor_mm_kwargs: Mapping[str, object],
-        tokenization_kwargs: Mapping[str, object],
-    ) -> tuple[list[int], BatchFeature, bool]:
+    ) -> tuple[Mapping[str, object], Mapping[str, object]]:
         """
-        Apply the HF processor on the prompt text and multi-modal data
-        together.
-
-        In addition, return whether prompt replacements have been applied.
+        In contrast to the base class, this method always adds
+        `return_mm_token_type_ids` to the processor data
         """
-        processor_data, passthrough_data = self._get_hf_mm_data(mm_items)
+        processor_data, passthrough_data = super()._get_hf_mm_data(mm_items)
         processor_data["return_mm_token_type_ids"] = True
-
-        processed_data = self._call_hf_processor(
-            prompt=prompt_text,
-            mm_data=processor_data,
-            mm_kwargs=hf_processor_mm_kwargs,
-            tok_kwargs=tokenization_kwargs,
-        )
-        processed_data.update(passthrough_data)
-
-        (prompt_ids,) = processed_data.pop("input_ids").tolist()
-        mm_token_type_ids = (
-            processed_data.pop("mm_token_type_ids")
-            if "mm_token_type_ids" in processed_data
-            else processed_data.pop("token_type_ids")
-        )  # for gemma3 only
-
-        return prompt_ids, processed_data, mm_token_type_ids
+        return processor_data, passthrough_data

     def apply(
         self,
@@ -421,18 +403,28 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
             # into string
             prompt = hf_processor.decode(prompt)

-        (prompt_ids, processed_data, mm_token_type_ids) = (
-            self._apply_hf_processor_text_mm(
-                prompt_text=prompt,
-                mm_items=mm_items,
-                hf_processor_mm_kwargs=hf_processor_mm_kwargs,
-                tokenization_kwargs=tokenization_kwargs,
-            )
+        # Bypass cached processor and always apply to the full set of mm inputs
+        # NOTE: we can't just set caching=False because base class method
+        # transforms outputs to `MultiModalKwargs` which is not going to
+        # work for Transformers. We have a lot of logic tied to
+        # `mm_tokens_per_modality` below
+        prompt_ids, processed_data, _ = self._apply_hf_processor_text_mm(
+            prompt_text=prompt,
+            mm_items=mm_items,
+            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
+            tokenization_kwargs=tokenization_kwargs,
         )

-        # HF processor will return `mm_token_type_ids` from which
-        # we can infer mm_placeholders. Until then hardcode to make code run
-        # Below tested on Llava. Prompts and `mm_token_type_ids` are always bs=1
+        # For gemma3 we check `token_type_ids` as the key
+        token_type_key = (
+            "mm_token_type_ids"
+            if "mm_token_type_ids" in processed_data
+            else "token_type_ids"
+        )
+        mm_token_type_ids = processed_data.pop(token_type_key)
+
+        # We can infer vLLM style placeholder from token type ids, if we split
+        # it for each input `mm_data`.
         mm_positions = torch.where(mm_token_type_ids == 1)[1]
         images = mm_items.get_items("image", ImageProcessorItems)
         multimodal_config = self.info.ctx.model_config.multimodal_config
@@ -462,17 +454,12 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
         ]
         mm_placeholders = {"image": ranges}

-        num_image_patches = (
-            torch.tensor(mm_tokens_per_modality["num_image_patches"])
-            if "num_image_patches" in mm_tokens_per_modality
-            else None
+        processed_data["num_image_patches"] = torch.tensor(
+            mm_tokens_per_modality["num_image_patches"]
         )
-        processed_data["num_image_patches"] = num_image_patches
         mm_kwargs = MultiModalKwargsItems.from_hf_inputs(
             processed_data,
-            self._get_mm_fields_config(
-                processed_data, hf_processor_mm_kwargs, num_image_patches
-            ),
+            self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs),
         )

         # Use overrides if provided; fallback to data-dependent hashing.
@@ -531,8 +518,6 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
             self.ignore_unexpected_suffixes.append(".bias")

         # Set correct attn and init on "meta" to delay allocating GPU tensors
-        # TODO: @raushan, use the public `model.set_attn_implementation()`
-        # method once its checks are fixed in Transformers.
         self.text_config._attn_implementation = "vllm"
         with init_on_device_without_buffers("meta"):
             self.model: PreTrainedModel = AutoModel.from_config(
@@ -844,17 +829,6 @@ class TransformersForCausalLM(TransformersBase):
         return logits


-def flatten_and_concat(x: list[torch.Tensor]) -> torch.Tensor:
-    """Flatten until a list of tensors can be concatenated then do concat"""
-
-    def _can_concat(x: list[torch.Tensor]):
-        return len(set(map(lambda _x: _x.shape[1:], x))) == 1
-
-    if _can_concat(x):
-        return torch.concat(x)
-    return flatten_and_concat(flatten_bn(x))
-
-
 @MULTIMODAL_REGISTRY.register_processor(
     MultiModalProcessor,
     info=MultiModalProcessingInfo,
@@ -935,9 +909,6 @@ class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal):
             vision_embeddings = self.model.get_image_features(pixel_values, **kwargs)

             if isinstance(vision_embeddings, torch.Tensor):
-                if isinstance(num_image_patches, list):
-                    num_image_patches = torch.cat(num_image_patches)
-
                 if vision_embeddings.ndim == 2:
                     vision_embeddings = vision_embeddings.unsqueeze(0)
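
For reference, the rewritten `apply()` above infers vLLM-style placeholder ranges from the HF processor's `mm_token_type_ids` instead of hardcoding them. Below is a minimal standalone sketch of that idea, assuming batch size 1 and a known per-image placeholder-token count; the helper name and the plain-dict return value are illustrative only, not vLLM or Transformers APIs.

# Standalone sketch (not vLLM code): recover per-image placeholder ranges
# from HF-style token type ids, where 1 marks multimodal (image) tokens.
import torch


def infer_placeholder_ranges(
    mm_token_type_ids: torch.Tensor,  # shape (1, seq_len); 1 = image token
    tokens_per_image: list[int],      # e.g. taken from `mm_tokens_per_modality`
) -> list[dict[str, int]]:
    """Return one {"offset", "length"} entry per image, in prompt order."""
    # Positions of all multimodal tokens in the (batch size 1) prompt
    mm_positions = torch.where(mm_token_type_ids == 1)[1]

    ranges = []
    start = 0
    for num_tokens in tokens_per_image:
        # Consecutive chunk of positions belonging to this image
        chunk = mm_positions[start : start + num_tokens]
        ranges.append({"offset": int(chunk[0]), "length": num_tokens})
        start += num_tokens
    return ranges


# Example: a 10-token prompt containing two images of 3 placeholder tokens each
token_type_ids = torch.tensor([[0, 1, 1, 1, 0, 0, 1, 1, 1, 0]])
print(infer_placeholder_ranges(token_type_ids, [3, 3]))
# [{'offset': 1, 'length': 3}, {'offset': 6, 'length': 3}]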