[CI/Build][Bugfix] Ensure compatibility with transformers 4.52 (#18678)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Date: 2025-06-04 19:49:20 +08:00
Committed by: GitHub
Parent commit: 35cf32df30
Commit: 01dc9a76db
13 changed files with 82 additions and 47 deletions

View File

@@ -34,7 +34,7 @@ opencv-python-headless >= 4.11.0 # required for video test
datamodel_code_generator # required for minicpm3 test
lm-eval[api]==0.4.8 # required for model evaluation test
mteb>=1.38.11, <2 # required for mteb test
-transformers==4.51.3
+transformers==4.52.4
tokenizers==0.21.1
huggingface-hub[hf_xet]>=0.30.0 # Required for Xet downloads.
schemathesis>=3.39.15 # Required for openai schema test.
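The bumped pin above only takes effect when the test environment is rebuilt. A minimal sanity check, not part of this commit, that a test session could run to confirm the installed transformers matches the new pin:

from packaging.version import Version

import transformers

# Assumes the environment was built from the requirements file above.
assert Version(transformers.__version__) == Version("4.52.4"), (
    f"expected transformers 4.52.4, got {transformers.__version__}")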

View File

@@ -794,7 +794,7 @@ tqdm==4.66.6
# transformers
tqdm-multiprocess==0.0.11
# via lm-eval
-transformers==4.51.3
+transformers==4.52.4
# via
# -r requirements/test.in
# genai-perf

View File

@@ -226,6 +226,8 @@ VLM_TEST_SETTINGS = {
img_idx_to_prompt=lambda idx: "",
auto_cls=AutoModelForImageTextToText,
vllm_output_post_proc=model_utils.blip2_vllm_to_hf_output,
+# FIXME: https://github.com/huggingface/transformers/pull/38510
+marks=[pytest.mark.skip("Model is broken")],
),
"chameleon": VLMTestInfo(
models=["facebook/chameleon-7b"],
@@ -281,10 +283,10 @@ VLM_TEST_SETTINGS = {
multi_image_prompt="<start_of_image><start_of_image>Describe the two images in detail.", # noqa: E501
max_model_len=4096,
max_num_seqs=2,
-dtype="bfloat16",
auto_cls=AutoModelForImageTextToText,
vllm_runner_kwargs={"mm_processor_kwargs": {"do_pan_and_scan": True}},
patch_hf_runner=model_utils.gemma3_patch_hf_runner,
+num_logprobs=10,
),
"glm4v": VLMTestInfo(
models=["THUDM/glm-4v-9b"],
@@ -337,7 +339,8 @@ VLM_TEST_SETTINGS = {
models=[
"OpenGVLab/InternVL2-1B",
"OpenGVLab/InternVL2-2B",
-"OpenGVLab/Mono-InternVL-2B",
+# FIXME: Config cannot be loaded in transformers 4.52
+# "OpenGVLab/Mono-InternVL-2B",
],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501
@@ -568,6 +571,8 @@ VLM_TEST_SETTINGS = {
max_num_seqs=2,
vllm_output_post_proc=model_utils.qwen_vllm_to_hf_output,
prompt_path_encoder=model_utils.qwen_prompt_path_encoder,
+# FIXME: https://github.com/huggingface/transformers/issues/38358
+marks=[pytest.mark.skip("Model initialization fails")],
),
"qwen2_vl": VLMTestInfo(
models=["Qwen/Qwen2-VL-2B-Instruct"],

View File

@@ -100,6 +100,8 @@ def run_test(
)
+# FIXME: https://github.com/huggingface/transformers/issues/38358
+@pytest.mark.skip("Model initialization fails")
@pytest.mark.core_model
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize(

View File

@@ -29,7 +29,7 @@ def vllm_to_hf_output(
return output_ids, hf_output_str, out_logprobs
-MODEL_NAME = "ibm-granite/granite-speech-3.3-8b"
+MODEL_NAME = "ibm-granite/granite-speech-3.3-2b"
# Audio lora co-exists directly in the model directory, but
# currently still needs to be passed directly to vLLM.
audio_lora_path = MODEL_NAME

View File

@@ -122,6 +122,10 @@ def run_test(
for prompts, images, audios in inputs
]
+# This error occurs inside `get_peft_model`
+# FIXME: https://huggingface.co/microsoft/Phi-4-multimodal-instruct/discussions/75
+pytest.skip("HF impl is not compatible with current transformers")
hf_model_kwargs = {"_attn_implementation": "sdpa"}
with hf_runner(model, dtype=dtype,
model_kwargs=hf_model_kwargs) as hf_model:

View File

@@ -10,11 +10,12 @@ from typing import Optional, Union
import numpy as np
import numpy.typing as npt
import pytest
import regex as re
import torch
from PIL.Image import Image
from transformers import (AutoConfig, AutoTokenizer, BatchFeature,
-GenerationConfig)
+GenerationConfig, GenerationMixin)
from vllm.sequence import SampleLogprobs
from vllm.transformers_utils.tokenizer import patch_padding_side
@@ -324,6 +325,16 @@ def gemma3_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
hf_model.processor = processor
+orig_generate = hf_model.model.generate
+def _generate(self, *args, **kwargs):
+# FIXME: https://github.com/huggingface/transformers/issues/38333
+kwargs["disable_compile"] = True
+return orig_generate(*args, **kwargs)
+hf_model.model.generate = types.MethodType(_generate, hf_model.model)
return hf_model
@@ -610,6 +621,11 @@ def _internvl_generate(
if getattr(self, "use_visual_token_mask", False):
visual_token_mask = selected.reshape(B, N, 1).to(input_embeds.dtype)
forward_kwargs["visual_token_mask"] = visual_token_mask
+# e.g. InternVL2-2B
+if not isinstance(self.language_model, GenerationMixin):
+pytest.skip("HF impl is not compatible with current transformers")
outputs = self.language_model.generate(
**forward_kwargs,
**generate_kwargs,
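Both patches in this file rebind a bound method at runtime via `types.MethodType`. A minimal standalone sketch of that pattern (`DummyModel` is a stand-in, not the HF runner's model):

import types

class DummyModel:
    def generate(self, **kwargs):
        return kwargs

model = DummyModel()
orig_generate = model.generate  # capture the original bound method

def _generate(self, **kwargs):
    # Inject the extra kwarg on every call, mirroring the disable_compile patch.
    kwargs.setdefault("disable_compile", True)
    return orig_generate(**kwargs)

model.generate = types.MethodType(_generate, model)
assert model.generate()["disable_compile"] is True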

View File

@@ -245,7 +245,7 @@ def _test_processing_correctness_one(
"adept/fuyu-8b",
"google/gemma-3-4b-it",
"THUDM/glm-4v-9b",
-"ibm-granite/granite-speech-3.3-8b",
+"ibm-granite/granite-speech-3.3-2b",
"h2oai/h2ovl-mississippi-800m",
"OpenGVLab/InternVL2-1B",
"OpenGVLab/InternVL3-1B",

View File

@@ -160,17 +160,12 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"), # noqa: E501
"FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"),
"FalconH1ForCausalLM":_HfExamplesInfo("tiiuae/Falcon-H1-1.5B-Instruct",
-is_available_online=False,
-min_transformers_version="4.52.2"),
+min_transformers_version="4.53"),
"GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"),
"Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"),
"Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"),
"GlmForCausalLM": _HfExamplesInfo("THUDM/glm-4-9b-chat-hf"),
-"Glm4ForCausalLM": _HfExamplesInfo(
-"THUDM/GLM-4-32B-0414",
-is_available_online=False,
-min_transformers_version="4.52.dev0"
-),
+"Glm4ForCausalLM": _HfExamplesInfo("THUDM/GLM-4-9B-0414"),
"GPT2LMHeadModel": _HfExamplesInfo("openai-community/gpt2",
{"alias": "gpt2"}),
"GPTBigCodeForCausalLM": _HfExamplesInfo("bigcode/starcoder",
@@ -181,8 +176,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
{"1b": "EleutherAI/pythia-1.4b"}),
"GraniteForCausalLM": _HfExamplesInfo("ibm/PowerLM-3b"),
"GraniteMoeForCausalLM": _HfExamplesInfo("ibm/PowerMoE-3b"),
-"GraniteMoeHybridForCausalLM": _HfExamplesInfo("ibm-granite/granite-4.0-tiny-preview", # noqa: E501
-min_transformers_version="4.52.0"), # noqa: E501
+"GraniteMoeHybridForCausalLM": _HfExamplesInfo("ibm-granite/granite-4.0-tiny-preview"), # noqa: E501
"GraniteMoeSharedForCausalLM": _HfExamplesInfo("ibm-research/moe-7b-1b-active-shared-experts"), # noqa: E501
"Grok1ModelForCausalLM": _HfExamplesInfo("hpcai-tech/grok-1",
trust_remote_code=True),
@@ -203,8 +197,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"LLaMAForCausalLM": _HfExamplesInfo("decapoda-research/llama-7b-hf",
is_available_online=False),
"MambaForCausalLM": _HfExamplesInfo("state-spaces/mamba-130m-hf"),
-"Mamba2ForCausalLM": _HfExamplesInfo("mistralai/Mamba-Codestral-7B-v0.1",
-is_available_online=False),
+"Mamba2ForCausalLM": _HfExamplesInfo("mistralai/Mamba-Codestral-7B-v0.1"),
"FalconMambaForCausalLM": _HfExamplesInfo("tiiuae/falcon-mamba-7b-instruct"), # noqa: E501
"MiniCPMForCausalLM": _HfExamplesInfo("openbmb/MiniCPM-2B-sft-bf16",
trust_remote_code=True),
@@ -243,10 +236,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"Qwen2MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen1.5-MoE-A2.7B-Chat"),
"Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"),
"Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"),
-"RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b",
-is_available_online=False),
+"RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"),
"StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b", # noqa: E501
-is_available_online=False),
+v0_only=True),
"StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t",
v0_only=True),
"Starcoder2ForCausalLM": _HfExamplesInfo("bigcode/starcoder2-3b"),
@@ -256,7 +248,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"TeleFLMForCausalLM": _HfExamplesInfo("CofeAI/FLM-2-52B-Instruct-2407",
trust_remote_code=True),
"XverseForCausalLM": _HfExamplesInfo("xverse/XVERSE-7B-Chat",
-is_available_online=False,
+tokenizer="meta-llama/Llama-2-7b",
trust_remote_code=True),
"Zamba2ForCausalLM": _HfExamplesInfo("Zyphra/Zamba2-7B-instruct"),
"MiMoForCausalLM": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL",
@@ -275,8 +267,7 @@ _EMBEDDING_EXAMPLE_MODELS = {
trust_remote_code=True),
"GteNewModel": _HfExamplesInfo("Alibaba-NLP/gte-base-en-v1.5",
trust_remote_code=True,
-hf_overrides={"architectures":
-["GteNewModel"]}),
+hf_overrides={"architectures": ["GteNewModel"]}), # noqa: E501
"InternLM2ForRewardModel": _HfExamplesInfo("internlm/internlm2-1_8b-reward",
trust_remote_code=True),
"JambaForSequenceClassification": _HfExamplesInfo("ai21labs/Jamba-tiny-reward-dev"), # noqa: E501
@@ -298,10 +289,8 @@ _EMBEDDING_EXAMPLE_MODELS = {
"Phi3VForCausalLM": _HfExamplesInfo("TIGER-Lab/VLM2Vec-Full",
trust_remote_code=True),
"Qwen2VLForConditionalGeneration": _HfExamplesInfo("MrLight/dse-qwen2-2b-mrl-v1"), # noqa: E501
-# The model on Huggingface is currently being updated,
-# hence I temporarily mark it as not available online
-"PrithviGeoSpatialMAE": _HfExamplesInfo("ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11", # noqa: E501
-is_available_online=False),
+"PrithviGeoSpatialMAE": _HfExamplesInfo("ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11", # noqa: E501
+is_available_online=False), # noqa: E501
}
_CROSS_ENCODER_EXAMPLE_MODELS = {
@@ -327,8 +316,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501
"FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"),
"Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it"),
-"GraniteSpeechForConditionalGeneration": _HfExamplesInfo("ibm-granite/granite-speech-3.3-8b", # noqa: E501
-min_transformers_version="4.52.0"), # noqa: E501
+"GraniteSpeechForConditionalGeneration": _HfExamplesInfo("ibm-granite/granite-speech-3.3-2b"), # noqa: E501
"GLM4VForCausalLM": _HfExamplesInfo("THUDM/glm-4v-9b",
trust_remote_code=True,
hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501
@@ -347,7 +335,6 @@ _MULTIMODAL_EXAMPLE_MODELS = {
trust_remote_code=True,
v0_only=True),
"Llama4ForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct", # noqa: E501
-min_transformers_version="4.51",
max_model_len=10240),
"LlavaForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-1.5-7b-hf",
extras={"mistral": "mistral-community/pixtral-12b", # noqa: E501
@@ -360,8 +347,6 @@ _MULTIMODAL_EXAMPLE_MODELS = {
transformers_version_reason="HF model is not compatible.", # noqa: E501
hf_overrides={"architectures": ["MantisForConditionalGeneration"]}), # noqa: E501
"MiniCPMO": _HfExamplesInfo("openbmb/MiniCPM-o-2_6",
-max_transformers_version="4.48",
-transformers_version_reason="Use of deprecated imports which have been removed.", # noqa: E501
trust_remote_code=True),
"MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-Llama3-V-2_5",
extras={"2.6": "openbmb/MiniCPM-V-2_6"}, # noqa: E501
@@ -399,10 +384,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"Qwen2AudioForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-Audio-7B-Instruct"), # noqa: E501
"Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"), # noqa: E501
"Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct"), # noqa: E501
-"Qwen2_5OmniModel": _HfExamplesInfo("Qwen/Qwen2.5-Omni-3B",
-min_transformers_version="4.52"),
-"Qwen2_5OmniForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-Omni-7B-AWQ", # noqa: E501
-min_transformers_version="4.52"),
+"Qwen2_5OmniModel": _HfExamplesInfo("Qwen/Qwen2.5-Omni-3B"),
+"Qwen2_5OmniForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-Omni-7B-AWQ"), # noqa: E501
"SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B"),
"SmolVLMForConditionalGeneration": _HfExamplesInfo("HuggingFaceTB/SmolVLM2-2.2B-Instruct"), # noqa: E501
"UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b", # noqa: E501
@@ -413,8 +396,8 @@
# Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer
# Therefore, we borrow the BartTokenizer from the original Bart model
"Florence2ForConditionalGeneration": _HfExamplesInfo("microsoft/Florence-2-base", # noqa: E501
-tokenizer="Isotr0py/Florence-2-tokenizer",
-trust_remote_code=True,), # noqa: E501
+tokenizer="Isotr0py/Florence-2-tokenizer", # noqa: E501
+trust_remote_code=True), # noqa: E501
"MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"), # noqa: E501
"WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"), # noqa: E501
}
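Most of the registry edits above tighten or drop `min_transformers_version` and `is_available_online` flags. A rough sketch of how such a version gate can be evaluated (the helper below is illustrative; the real check lives inside `_HfExamplesInfo`):

from typing import Optional

from packaging.version import Version

import transformers

def transformers_version_ok(min_version: Optional[str] = None,
                            max_version: Optional[str] = None) -> bool:
    """Return True when the installed transformers satisfies the bounds."""
    installed = Version(transformers.__version__)
    if min_version is not None and installed < Version(min_version):
        return False
    if max_version is not None and installed > Version(max_version):
        return False
    return True

assert transformers_version_ok(min_version="4.52")      # True on 4.52.4
assert not transformers_version_ok(min_version="4.53")  # gated until 4.53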

View File

@@ -21,6 +21,10 @@ def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
model_info.check_available_online(on_fail="skip")
model_info.check_transformers_version(on_fail="skip")
+# FIXME: Possible memory leak in the previous tests?
+if model_arch == "GraniteSpeechForConditionalGeneration":
+pytest.skip("Avoid OOM")
# Avoid OOM and reduce initialization time by only using 1 layer
def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig:
hf_config.update(model_info.hf_overrides)
@@ -41,6 +45,13 @@ def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
"num_hidden_layers": 1,
})
+# e.g.: ibm-granite/granite-speech-3.3-2b
+if hasattr(hf_config, "encoder_config"):
+hf_config.encoder_config.update({
+"num_layers": 1,
+"num_hidden_layers": 1,
+})
return hf_config
# Avoid calling model.forward()
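For context, the new block slots into an `hf_overrides` hook shaped roughly like the sketch below (assuming a transformers `PretrainedConfig`; the surrounding test wiring is omitted):

from transformers import PretrainedConfig

def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig:
    # Shrink the language model itself.
    hf_config.get_text_config().update({
        "num_layers": 1,
        "num_hidden_layers": 1,
    })
    # Also shrink a nested speech/vision encoder when present,
    # e.g. ibm-granite/granite-speech-3.3-2b.
    if hasattr(hf_config, "encoder_config"):
        hf_config.encoder_config.update({
            "num_layers": 1,
            "num_hidden_layers": 1,
        })
    return hf_config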

View File

@@ -3139,6 +3139,8 @@ def _find_dtype(
config_dtype = getattr(config.get_text_config(), "torch_dtype", None)
if config_dtype is None and hasattr(config, "vision_config"):
config_dtype = getattr(config.vision_config, "torch_dtype", None)
+if config_dtype is None and hasattr(config, "encoder_config"):
+config_dtype = getattr(config.encoder_config, "torch_dtype", None)
# Try to read the dtype of the weights if they are in safetensors format
if config_dtype is None:
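A self-contained illustration of the fallback order after this change, text config first, then `vision_config`, then the newly added `encoder_config` (the config object below is a stand-in, not vLLM's):

from types import SimpleNamespace

import torch

config = SimpleNamespace(
    get_text_config=lambda: SimpleNamespace(torch_dtype=None),
    encoder_config=SimpleNamespace(torch_dtype=torch.bfloat16),
)

config_dtype = getattr(config.get_text_config(), "torch_dtype", None)
if config_dtype is None and hasattr(config, "vision_config"):
    config_dtype = getattr(config.vision_config, "torch_dtype", None)
if config_dtype is None and hasattr(config, "encoder_config"):
    config_dtype = getattr(config.encoder_config, "torch_dtype", None)

assert config_dtype is torch.bfloat16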

View File

@@ -111,7 +111,13 @@ class AyaVisionProcessingInfo(BaseProcessingInfo):
return self.ctx.get_hf_config(AyaVisionConfig)
def get_hf_processor(self, **kwargs: object) -> AyaVisionProcessor:
-return self.ctx.get_hf_processor(AyaVisionProcessor, **kwargs)
+processor = self.ctx.get_hf_processor(AyaVisionProcessor, **kwargs)
+# Temporary workaround since this processor has multiple image tokens
+# See https://github.com/huggingface/transformers/issues/38350
+processor._check_special_mm_tokens = lambda *args, **kwargs: None
+return processor
def get_image_processor(self) -> GotOcr2ImageProcessor:
return self.get_hf_processor().image_processor
@@ -188,9 +194,7 @@ class AyaVisionMultiModalProcessor(
image_processor = hf_processor.image_processor
# HF processor pops the `num_patches` kwarg, which is needed by vLLM
-if (images :=
-mm_data.get("images")) is not None and '<image>' in prompt:
-assert isinstance(images, list)
+if (images := mm_data.get("images")) is not None:
parsed_images = (self._get_data_parser().parse_mm_data({
"image":
images
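The `get_hf_processor` change above neutralizes a validation hook on the returned processor instance. The same no-op patch pattern in isolation (the class below is a stand-in for `AyaVisionProcessor`, not the real one):

class FakeProcessor:
    def _check_special_mm_tokens(self, *args, **kwargs):
        # Stand-in for the real check, which rejects prompts whose image
        # token counts do not line up on transformers 4.52.
        raise ValueError("mismatched special tokens")

processor = FakeProcessor()
# Patch the hook on this instance only; other processor instances keep it.
processor._check_special_mm_tokens = lambda *args, **kwargs: None
processor._check_special_mm_tokens([], {}, modalities=["image"])  # no error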

View File

@@ -22,8 +22,8 @@ from typing import Literal, Optional, TypedDict, Union
import torch
from torch import nn
-from transformers import (BatchFeature, Idefics3Config, Idefics3ImageProcessor,
-Idefics3Processor)
+from transformers import (AddedToken, BatchFeature, Idefics3Config,
+Idefics3ImageProcessor, Idefics3Processor)
from vllm.config import VllmConfig
from vllm.model_executor.layers.linear import ReplicatedLinear
@@ -199,13 +199,21 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
return grid_w * grid_h + 1
+# TODO: Remove after requiring transformers>=4.52
+def _get_content(self, token: Union[AddedToken, str]) -> str:
+if isinstance(token, str):
+return token
+return token.content
def _get_image_token(
self,
processor: Optional[Idefics3Processor]) -> tuple[str, str, str]:
if processor is None:
processor = self.get_hf_processor()
-image_token = processor.image_token.content
-fake_image_token = processor.fake_image_token.content
+image_token = self._get_content(processor.image_token)
+fake_image_token = self._get_content(processor.fake_image_token)
global_image_token = processor.global_image_tag
return image_token, fake_image_token, global_image_token
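A hedged, standalone version of the `_get_content` shim added above: depending on the installed transformers version, the processor's image tokens arrive either as plain strings or as `AddedToken` objects, and both need to collapse to the raw token text:

from typing import Union

from transformers import AddedToken

def get_content(token: Union[AddedToken, str]) -> str:
    if isinstance(token, str):
        return token
    return token.content

assert get_content("<fake_token_around_image>") == "<fake_token_around_image>"
assert get_content(AddedToken("<image>")) == "<image>"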