[Bugfix] Fix mrope in Transformers Backend (#26087)

Signed-off-by: raushan <raushan@huggingface.co>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Raushan Turganbay
Date: 2025-10-06 13:40:50 +02:00 (committed by GitHub)
Parent: 0340f45553
Commit: ab5e7d93f4
2 changed files with 38 additions and 68 deletions

File 1 of 2:

@@ -222,8 +222,7 @@ VLM_TEST_SETTINGS = {
         vllm_runner_kwargs={
             "model_impl": "transformers",
         },
-        # FIXME: Investigate mrope issue
-        marks=[large_gpu_mark(min_gb=32), pytest.mark.skip(reason="Mrope issue")],
+        marks=[large_gpu_mark(min_gb=32)],
     ),
     #### Extended model tests
     "aria": VLMTestInfo(

File 2 of 2:

@@ -79,7 +79,6 @@ from .utils import (
     AutoWeightsLoader,
     PPMissingLayer,
     WeightsMapper,
-    flatten_bn,
     make_empty_intermediate_tensors_factory,
     maybe_prefix,
 )
@@ -347,12 +346,12 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
     def _get_mm_fields_config(
         self,
-        hf_inputs,
-        hf_processor_mm_kwargs,
-        num_image_patches: torch.Tensor = None,
-    ):
+        hf_inputs: BatchFeature,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, MultiModalFieldConfig]:
         # HF Processors always return a mask but vLLM doesn't need it
         hf_inputs.pop("attention_mask", None)
+        num_image_patches = hf_inputs.get("num_image_patches")
         mm_fields = {
             key: MultiModalFieldConfig.flat_from_sizes("image", num_image_patches)
             for key in hf_inputs
@@ -360,41 +359,24 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
         mm_fields["image_embeds"] = MultiModalFieldConfig.flat_from_sizes(
             "image", num_image_patches
         )
+        # Keep these as batched, as they always have batch size as first dim
+        mm_fields["image_grid_thw"] = MultiModalFieldConfig.batched("image")
+        mm_fields["video_grid_thw"] = MultiModalFieldConfig.batched("image")
         mm_fields["num_image_patches"] = MultiModalFieldConfig.batched("image")
         return mm_fields
-    def _apply_hf_processor_text_mm(
+    def _get_hf_mm_data(
         self,
-        prompt_text: str,
         mm_items: MultiModalDataItems,
-        hf_processor_mm_kwargs: Mapping[str, object],
-        tokenization_kwargs: Mapping[str, object],
-    ) -> tuple[list[int], BatchFeature, bool]:
+    ) -> tuple[Mapping[str, object], Mapping[str, object]]:
         """
-        Apply the HF processor on the prompt text and multi-modal data
-        together.
-        In addition, return whether prompt replacements have been applied.
+        In contrast to the base class, this method always adds
+        `return_mm_token_type_ids` to the processor data
         """
-        processor_data, passthrough_data = self._get_hf_mm_data(mm_items)
+        processor_data, passthrough_data = super()._get_hf_mm_data(mm_items)
         processor_data["return_mm_token_type_ids"] = True
-        processed_data = self._call_hf_processor(
-            prompt=prompt_text,
-            mm_data=processor_data,
-            mm_kwargs=hf_processor_mm_kwargs,
-            tok_kwargs=tokenization_kwargs,
-        )
-        processed_data.update(passthrough_data)
-        (prompt_ids,) = processed_data.pop("input_ids").tolist()
-        mm_token_type_ids = (
-            processed_data.pop("mm_token_type_ids")
-            if "mm_token_type_ids" in processed_data
-            else processed_data.pop("token_type_ids")
-        )  # for gemma3 only
-        return prompt_ids, processed_data, mm_token_type_ids
+        return processor_data, passthrough_data
     def apply(
         self,
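Note on the field-config change above: the mrope fix hinges on the difference between `flat_from_sizes` and `batched` fields. Per-patch tensors are concatenated along the first dimension and have to be re-split per image using `num_image_patches`, whereas `image_grid_thw` / `video_grid_thw` already have the batch size as their first dimension and must stay batched so mrope sees one grid per image. A minimal plain-torch sketch of that distinction, with made-up shapes (not the vLLM field-config implementation itself):

```python
import torch

# Hypothetical patch counts for two images.
num_image_patches = torch.tensor([4, 9])

# "flat" fields: all patches concatenated along dim 0, re-split per image
# by num_image_patches (the semantics flat_from_sizes("image", ...) encodes).
pixel_values = torch.randn(13, 3, 14, 14)  # 4 + 9 patches
per_image = torch.split(pixel_values, num_image_patches.tolist())
assert [t.shape[0] for t in per_image] == [4, 9]

# "batched" fields: the batch size is already the first dimension, so each
# row belongs to one image (why image_grid_thw / video_grid_thw stay batched).
image_grid_thw = torch.tensor([[1, 2, 2], [1, 3, 3]])  # one (t, h, w) per image
assert image_grid_thw.shape[0] == num_image_patches.numel()
```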
@@ -421,18 +403,28 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
             # into string
             prompt = hf_processor.decode(prompt)
-        (prompt_ids, processed_data, mm_token_type_ids) = (
-            self._apply_hf_processor_text_mm(
-                prompt_text=prompt,
-                mm_items=mm_items,
-                hf_processor_mm_kwargs=hf_processor_mm_kwargs,
-                tokenization_kwargs=tokenization_kwargs,
-            )
-        )
+        # Bypass cached processor and always apply to the full set of mm inputs
+        # NOTE: we can't just set caching=False because base class method
+        # transforms outputs to `MultiModalKwargs` which is not going to
+        # work for Transformers. We have a lot of logic tied to
+        # `mm_tokens_per_modality` below
+        prompt_ids, processed_data, _ = self._apply_hf_processor_text_mm(
+            prompt_text=prompt,
+            mm_items=mm_items,
+            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
+            tokenization_kwargs=tokenization_kwargs,
+        )
         # HF processor will return `mm_token_type_ids` from which
         # we can infer mm_placeholders. Until then hardcode to make code run
         # Below tested on Llava. Prompts and `mm_token_type_ids` are always bs=1
+        # For gemma3 we check `token_type_ids` as the key
+        token_type_key = (
+            "mm_token_type_ids"
+            if "mm_token_type_ids" in processed_data
+            else "token_type_ids"
+        )
+        mm_token_type_ids = processed_data.pop(token_type_key)
         # We can infer vLLM style placeholder from token type ids, if we split
         # it for each input `mm_data`.
         mm_positions = torch.where(mm_token_type_ids == 1)[1]
         images = mm_items.get_items("image", ImageProcessorItems)
         multimodal_config = self.info.ctx.model_config.multimodal_config
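Note on the placeholder derivation above: with batch size 1, `torch.where(mm_token_type_ids == 1)[1]` yields the column index of every multimodal token, and splitting those indices per image gives the placeholder ranges. A toy sketch of that pattern; the token layout and per-image split sizes below are invented for illustration:

```python
import torch

# Toy token type ids for a single prompt (bs=1): 1 marks multimodal tokens.
mm_token_type_ids = torch.tensor([[0, 0, 1, 1, 1, 0, 1, 1, 0]])

# Same pattern as in apply(): take the column indices of all mm tokens.
mm_positions = torch.where(mm_token_type_ids == 1)[1]
print(mm_positions)  # tensor([2, 3, 4, 6, 7])

# Splitting these positions by tokens-per-image then yields one contiguous
# PlaceholderRange-style (offset, length) pair per image.
chunks = torch.split(mm_positions, [3, 2])
ranges = [(int(c[0]), len(c)) for c in chunks]
print(ranges)  # [(2, 3), (6, 2)]
```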
@@ -462,17 +454,12 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
         ]
         mm_placeholders = {"image": ranges}
-        num_image_patches = (
-            torch.tensor(mm_tokens_per_modality["num_image_patches"])
-            if "num_image_patches" in mm_tokens_per_modality
-            else None
+        processed_data["num_image_patches"] = torch.tensor(
+            mm_tokens_per_modality["num_image_patches"]
         )
-        processed_data["num_image_patches"] = num_image_patches
         mm_kwargs = MultiModalKwargsItems.from_hf_inputs(
             processed_data,
-            self._get_mm_fields_config(
-                processed_data, hf_processor_mm_kwargs, num_image_patches
-            ),
+            self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs),
         )
         # Use overrides if provided; fallback to data-dependent hashing.
@@ -531,8 +518,6 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
         self.ignore_unexpected_suffixes.append(".bias")
         # Set correct attn and init on "meta" to delay allocating GPU tensors
-        # TODO: @raushan, use the public `model.set_attn_implementation()`
-        # method once its checks are fixed in Transformers.
         self.text_config._attn_implementation = "vllm"
         with init_on_device_without_buffers("meta"):
             self.model: PreTrainedModel = AutoModel.from_config(
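Note on the "meta" initialization above: building the module tree on the meta device records parameter shapes and dtypes without allocating real storage, which is what delays GPU allocation until vLLM loads the weights. A rough sketch of the general idea using a tiny config; vLLM's `init_on_device_without_buffers` helper additionally keeps buffers off the meta device, which this sketch does not reproduce:

```python
import torch
from transformers import AutoModel, LlamaConfig

# Tiny illustrative config so nothing is downloaded.
config = LlamaConfig(
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
)

# Creating the model under the "meta" device builds the module tree
# but allocates no real tensor storage.
with torch.device("meta"):
    model = AutoModel.from_config(config)

assert next(model.parameters()).is_meta  # structure exists, weights do not
```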
@@ -844,17 +829,6 @@ class TransformersForCausalLM(TransformersBase):
         return logits
-def flatten_and_concat(x: list[torch.Tensor]) -> torch.Tensor:
-    """Flatten until a list of tensors can be concatenated then do concat"""
-    def _can_concat(x: list[torch.Tensor]):
-        return len(set(map(lambda _x: _x.shape[1:], x))) == 1
-    if _can_concat(x):
-        return torch.concat(x)
-    return flatten_and_concat(flatten_bn(x))
 @MULTIMODAL_REGISTRY.register_processor(
     MultiModalProcessor,
     info=MultiModalProcessingInfo,
@@ -935,9 +909,6 @@ class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal):
         vision_embeddings = self.model.get_image_features(pixel_values, **kwargs)
         if isinstance(vision_embeddings, torch.Tensor):
-            if isinstance(num_image_patches, list):
-                num_image_patches = torch.cat(num_image_patches)
             if vision_embeddings.ndim == 2:
                 vision_embeddings = vision_embeddings.unsqueeze(0)
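Note on the lines above: since `num_image_patches` now arrives as a batched tensor field rather than a list, the `torch.cat` branch is dead code and is removed; what remains is normalizing the shape of the vision backbone output. A toy sketch of that shape handling (the final per-image split is an assumption about the subsequent, unshown lines):

```python
import torch

hidden_size = 32
num_image_patches = torch.tensor([4, 9])           # batched tensor, no longer a list
vision_embeddings = torch.randn(13, hidden_size)   # 4 + 9 patch embeddings, flat

# Give a flat (num_patches, hidden) output a leading batch dimension,
# mirroring the unsqueeze in the shown lines.
if vision_embeddings.ndim == 2:
    vision_embeddings = vision_embeddings.unsqueeze(0)

# Split back per image by num_image_patches (assumed follow-up step).
per_image = torch.split(vision_embeddings.squeeze(0), num_image_patches.tolist())
assert [t.shape[0] for t in per_image] == [4, 9]
```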