Update transformers to v4.55 (#21931)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Isotr0py <2037008807@qq.com>
Signed-off-by: isotr0py <2037008807@qq.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Harry Mellor
2025-08-06 06:56:14 +01:00
committed by GitHub
parent 6e20924350
commit 796bae07c5
13 changed files with 235 additions and 39 deletions

View File

@ -7,7 +7,7 @@ requests >= 2.26.0
tqdm
blake3
py-cpuinfo
transformers >= 4.53.2
transformers >= 4.55.0
huggingface-hub[hf_xet] >= 0.33.0 # Required for Xet downloads.
tokenizers >= 0.21.1 # Required for fast incremental detokenization.
protobuf # Required by LlamaTokenizer.

View File

@ -35,7 +35,7 @@ opencv-python-headless >= 4.11.0 # required for video test
datamodel_code_generator # required for minicpm3 test
lm-eval[api]==0.4.8 # required for model evaluation test
mteb[bm25s]>=1.38.11, <2 # required for mteb test
transformers==4.53.2
transformers==4.55.0
tokenizers==0.21.1
huggingface-hub[hf_xet]>=0.33.0 # Required for Xet downloads.
schemathesis>=3.39.15 # Required for openai schema test.

View File

@ -214,7 +214,7 @@ fiona==1.10.1
# via torchgeo
flask==3.1.1
# via mlflow
fonttools==4.54.1
fonttools==4.55.0
# via matplotlib
fqdn==1.5.1
# via jsonschema
@ -286,7 +286,7 @@ httpx==0.27.2
# via
# -r requirements/test.in
# schemathesis
huggingface-hub==0.33.1
huggingface-hub==0.34.3
# via
# -r requirements/test.in
# accelerate
@ -1148,7 +1148,7 @@ tqdm==4.66.6
# transformers
tqdm-multiprocess==0.0.11
# via lm-eval
transformers==4.53.2
transformers==4.55.0
# via
# -r requirements/test.in
# genai-perf

View File

@ -337,6 +337,10 @@ VLM_TEST_SETTINGS = {
vllm_output_post_proc=model_utils.fuyu_vllm_to_hf_output,
num_logprobs=10,
image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
# FIXME(Isotr0py): This model is broken in Transformers v4.54.1, we
# should enable this again after the fix is released:
# https://github.com/huggingface/transformers/pull/39915
marks=[pytest.mark.skip("HF model is broken")],
),
"gemma3": VLMTestInfo(
models=["google/gemma-3-4b-it"],

View File

@ -179,8 +179,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
min_transformers_version="4.54"),
"Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"), # noqa: E501
"FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"),
"FalconH1ForCausalLM":_HfExamplesInfo("tiiuae/Falcon-H1-0.5B-Base",
min_transformers_version="4.53"),
"FalconH1ForCausalLM":_HfExamplesInfo("tiiuae/Falcon-H1-0.5B-Base"),
"GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"),
"Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"),
"Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"),
@ -223,7 +222,10 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
trust_remote_code=True),
"JAISLMHeadModel": _HfExamplesInfo("inceptionai/jais-13b-chat"),
"JambaForCausalLM": _HfExamplesInfo("ai21labs/AI21-Jamba-1.5-Mini",
extras={"tiny": "ai21labs/Jamba-tiny-dev"}), # noqa: E501
extras={
"tiny": "ai21labs/Jamba-tiny-dev",
"random": "ai21labs/Jamba-tiny-random", # noqa: E501
}),
"LlamaForCausalLM": _HfExamplesInfo("meta-llama/Llama-3.2-1B-Instruct",
extras={"guard": "meta-llama/Llama-Guard-3-1B", # noqa: E501
"hermes": "NousResearch/Hermes-3-Llama-3.1-8B", # noqa: E501
@ -239,8 +241,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
trust_remote_code=True),
"MiniCPM3ForCausalLM": _HfExamplesInfo("openbmb/MiniCPM3-4B",
trust_remote_code=True),
"MiniMaxForCausalLM": _HfExamplesInfo("MiniMaxAI/MiniMax-Text-01-hf",
min_transformers_version="4.53"),
"MiniMaxForCausalLM": _HfExamplesInfo("MiniMaxAI/MiniMax-Text-01-hf"),
"MiniMaxText01ForCausalLM": _HfExamplesInfo("MiniMaxAI/MiniMax-Text-01",
trust_remote_code=True,
revision="a59aa9cbc53b9fb8742ca4e9e1531b9802b6fdc3"), # noqa: E501
@ -272,6 +273,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct",
trust_remote_code=True),
"Plamo2ForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b",
max_transformers_version="4.53",
transformers_version_reason="vLLM impl inherits PreTrainedModel and clashes with get_input_embeddings", # noqa: E501
trust_remote_code=True),
"QWenLMHeadModel": _HfExamplesInfo("Qwen/Qwen-7B-Chat",
trust_remote_code=True),
@ -299,8 +302,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"Zamba2ForCausalLM": _HfExamplesInfo("Zyphra/Zamba2-7B-instruct"),
"MiMoForCausalLM": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL",
trust_remote_code=True),
"Dots1ForCausalLM": _HfExamplesInfo("rednote-hilab/dots.llm1.inst",
min_transformers_version="4.53"),
"Dots1ForCausalLM": _HfExamplesInfo("rednote-hilab/dots.llm1.inst"),
# [Encoder-decoder]
"BartModel": _HfExamplesInfo("facebook/bart-base"),
"BartForConditionalGeneration": _HfExamplesInfo("facebook/bart-large-cnn"),
@ -326,8 +328,12 @@ _EMBEDDING_EXAMPLE_MODELS = {
"NomicBertModel": _HfExamplesInfo("nomic-ai/nomic-embed-text-v2-moe",
trust_remote_code=True, v0_only=True), # noqa: E501
"Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"),
"Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B"),
"Qwen2ForProcessRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-PRM-7B"),
"Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B",
max_transformers_version="4.53",
transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers"), # noqa: E501
"Qwen2ForProcessRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-PRM-7B",
max_transformers_version="4.53",
transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers"), # noqa: E501
"RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2", v0_only=True), # noqa: E501
"RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1", v0_only=True), # noqa: E501
"XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small", v0_only=True), # noqa: E501

View File

@ -9,6 +9,8 @@ import pytest
from tests.quantization.utils import is_quant_method_supported
from ..models.registry import HF_EXAMPLE_MODELS
MODELS = ["ai21labs/Jamba-tiny-random", "pfnet/plamo-2-1b"]
@ -25,6 +27,8 @@ def test_model_experts_int8_startup(
dtype: str,
max_tokens: int,
) -> None:
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_transformers_version(on_fail="skip")
with vllm_runner(model, dtype=dtype,
quantization="experts_int8") as vllm_model:

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol,
from typing import (TYPE_CHECKING, Any, ClassVar, Literal, Optional, Protocol,
Union, overload, runtime_checkable)
import torch
@ -14,6 +14,10 @@ if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import Pooler
from vllm.model_executor.sampling_metadata import SamplingMetadata
else:
VllmConfig = Any
Pooler = Any
SamplingMetadata = Any
logger = init_logger(__name__)
@ -34,7 +38,7 @@ class VllmModel(Protocol[T_co]):
def __init__(
self,
vllm_config: "VllmConfig",
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
...
@ -96,7 +100,7 @@ class VllmModelForTextGeneration(VllmModel[T], Protocol[T]):
def compute_logits(
self,
hidden_states: T,
sampling_metadata: "SamplingMetadata",
sampling_metadata: SamplingMetadata,
) -> Optional[T]:
"""Return `None` if TP rank > 0."""
...
@ -140,7 +144,7 @@ class VllmModelForPooling(VllmModel[T_co], Protocol[T_co]):
MRO of your model class.
"""
pooler: "Pooler"
pooler: Pooler
"""The pooler is only called on TP rank 0."""

View File

@ -1395,11 +1395,12 @@ class Tarsier2Processor(Qwen2VLProcessor):
**kwargs,
):
self.image_processor = Tarsier2ImageProcessor(**vision_config)
super().__init__(image_processor=self.image_processor,
tokenizer=tokenizer,
video_processor=Qwen2VLVideoProcessor(),
chat_template=None,
**kwargs)
super().__init__(
image_processor=self.image_processor,
tokenizer=tokenizer,
video_processor=Qwen2VLVideoProcessor(**vision_config),
chat_template=None,
**kwargs)
class Tarsier2ProcessingInfo(Qwen2VLProcessingInfo):

View File

@ -90,7 +90,7 @@ def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module):
def replace_linear_class(
linear: nn.Linear, style: Literal["colwise", "rowwise"],
quant_config: QuantizationConfig
) -> Union[ColumnParallelLinear, RowParallelLinear]:
) -> Union[ColumnParallelLinear, RowParallelLinear, ReplicatedLinear]:
"""
Replace nn.Linear with one of vLLM's tensor parallel linear classes.
@ -445,7 +445,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
# Set correct attn and init on "meta" to delay allocating GPU tensors
# TODO: @raushan, use the public `model.set_attn_implementation()`
# method after v4.54.0 is released
# method once its checks are fixed in Transformers.
self.text_config._attn_implementation = "vllm"
with init_on_device_without_buffers("meta"), config_override:
self.model: PreTrainedModel = AutoModel.from_config(
@ -520,7 +520,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
for i in range(len(layers)):
if start_layer <= i and i < end_layer:
continue
layers[i] = PPMissingLayer(return_tuple=True)
layers[i] = PPMissingLayer()
# Layers after module list
for name in pp_plan[module_list_idx + 1:]:
@ -533,14 +533,16 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
Apply the model's tensor parallelization plan.
Currently only supports linear layers.
"""
if not self.model.supports_tp_plan:
if self.tp_size <= 1:
return
tp_plan = getattr(self.model.config, "base_model_tp_plan", None) or {}
if not tp_plan and self.tp_size > 1:
raise ValueError(
f"{type(self.model)} does not support tensor parallel yet!")
tp_plan = self.model._tp_plan
# Some weight loaders expect linear layers to inherit from vLLM's
# LinearBase class, so we set a default style which causes any
# unspecified linear layers to be replaced with ReplicatedLinear
tp_plan[".*"] = "replicated"
def _tensor_parallel(module: nn.Module, prefix: str = ""):
for child_name, child_module in module.named_children():
@ -552,6 +554,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
child_module, style, self.quant_config)
setattr(module, child_name, new_module)
log_replacement(qual_name, child_module, new_module)
break
else:
_tensor_parallel(child_module, prefix=qual_name)

View File

@ -534,16 +534,10 @@ class PPMissingLayer(torch.nn.Identity):
def __init__(self, *args, **kwargs):
super().__init__()
self.return_tuple = kwargs.get("return_tuple", False)
def forward(self, *args, **kwargs):
"""
Return the first arg from args or the first value from kwargs.
Wraps the input in a tuple if `self.return_tuple` is True.
"""
input = args[0] if args else next(iter(kwargs.values()))
return (input, ) if self.return_tuple else input
"""Return the first arg from args or the first value from kwargs."""
return args[0] if args else next(iter(kwargs.values()))
_CPU_OFFLOAD_BYTES = 0

View File

@ -35,7 +35,8 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, DeepseekVLV2Config,
MllamaConfig, MLPSpeculatorConfig,
Nemotron_Nano_VL_Config,
NemotronConfig, NVLM_D_Config,
RWConfig, SpeculatorsConfig,
OvisConfig, RWConfig,
SpeculatorsConfig,
Step3TextConfig, Step3VLConfig,
UltravoxConfig)
# yapf: enable
@ -85,6 +86,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = {
"speculators": SpeculatorsConfig,
"nemotron": NemotronConfig,
"NVLM_D": NVLM_D_Config,
"ovis": OvisConfig,
"ultravox": UltravoxConfig,
"step3_vl": Step3VLConfig,
"step3_text": Step3TextConfig,

View File

@ -24,6 +24,7 @@ from vllm.transformers_utils.configs.nemotron import NemotronConfig
from vllm.transformers_utils.configs.nemotron_h import NemotronHConfig
from vllm.transformers_utils.configs.nemotron_vl import Nemotron_Nano_VL_Config
from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config
from vllm.transformers_utils.configs.ovis import OvisConfig
from vllm.transformers_utils.configs.speculators.base import SpeculatorsConfig
from vllm.transformers_utils.configs.step3_vl import (Step3TextConfig,
Step3VisionEncoderConfig,
@ -45,6 +46,7 @@ __all__ = [
"NemotronHConfig",
"Nemotron_Nano_VL_Config",
"NVLM_D_Config",
"OvisConfig",
"SpeculatorsConfig",
"UltravoxConfig",
"Step3VLConfig",

View File

@ -0,0 +1,176 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# yapf: disable
# ruff: noqa: E501
# adapted from https://huggingface.co/AIDC-AI/Ovis2-1B/blob/main/configuration_aimv2.py
# and https://huggingface.co/AIDC-AI/Ovis2-1B/blob/main/configuration_ovis.py
# Ovis Config with AimV2 config registration removed for Transformers compatibility
from typing import Any, Optional, Union
from transformers import AutoConfig, PretrainedConfig
class AIMv2Config(PretrainedConfig):
"""This is the configuration class to store the configuration of an [`AIMv2Model`].
Instantiating a configuration with the defaults will yield a similar configuration
to that of the [apple/aimv2-large-patch14-224](https://huggingface.co/apple/aimv2-large-patch14-224).
Args:
hidden_size: Dimension of the hidden representations.
intermediate_size: Dimension of the SwiGLU representations.
num_hidden_layers: Number of hidden layers in the Transformer.
num_attention_heads: Number of attention heads for each attention layer
in the Transformer.
num_channels: Number of input channels.
image_size: Image size.
patch_size: Patch size.
rms_norm_eps: Epsilon value used for the RMS normalization layer.
attention_dropout: Dropout ratio for attention probabilities.
projection_dropout: Dropout ratio for the projection layer after the attention.
qkv_bias: Whether to add a bias to the queries, keys and values.
use_bias: Whether to add a bias in the feed-forward and projection layers.
kwargs: Keyword arguments for the [`PretrainedConfig`].
"""
model_type: str = "aimv2"
def __init__(
self,
hidden_size: int = 1024,
intermediate_size: int = 2816,
num_hidden_layers: int = 24,
num_attention_heads: int = 8,
num_channels: int = 3,
image_size: int = 224,
patch_size: int = 14,
rms_norm_eps: float = 1e-5,
attention_dropout: float = 0.0,
projection_dropout: float = 0.0,
qkv_bias: bool = False,
use_bias: bool = False,
**kwargs: Any,
):
super().__init__(**kwargs)
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_channels = num_channels
self.patch_size = patch_size
self.image_size = image_size
self.attention_dropout = attention_dropout
self.rms_norm_eps = rms_norm_eps
self.projection_dropout = projection_dropout
self.qkv_bias = qkv_bias
self.use_bias = use_bias
# ----------------------------------------------------------------------
# Visual Tokenizer Configuration
# ----------------------------------------------------------------------
class BaseVisualTokenizerConfig(PretrainedConfig):
def __init__(self,
vocab_size=16384,
tokenize_function="softmax",
tau=1.0,
depths=None,
drop_cls_token=False,
backbone_config: Optional[Union[PretrainedConfig,
dict]] = None,
hidden_stride: int = 1,
**kwargs):
super().__init__(**kwargs)
self.vocab_size = vocab_size
self.tokenize_function = tokenize_function
self.tau = tau
if isinstance(depths, str):
depths = [int(x) for x in depths.split('|')]
self.depths = depths
self.backbone_kwargs = dict[str, Any]()
self.drop_cls_token = drop_cls_token
if backbone_config is not None:
assert isinstance(backbone_config, (PretrainedConfig, dict)), \
f"expect `backbone_config` to be instance of PretrainedConfig or dict, but got {type(backbone_config)} type"
if not isinstance(backbone_config, PretrainedConfig):
model_type = backbone_config['model_type']
if model_type != "aimv2":
backbone_config.pop('model_type')
backbone_config = AutoConfig.for_model(model_type, **backbone_config)
else:
backbone_config = AIMv2Config(**backbone_config)
self.backbone_config = backbone_config
self.hidden_stride = hidden_stride
class Aimv2VisualTokenizerConfig(BaseVisualTokenizerConfig):
model_type = "aimv2_visual_tokenizer"
def __init__(self, **kwargs):
super().__init__(**kwargs)
if self.drop_cls_token:
self.drop_cls_token = False
if self.depths:
assert len(self.depths) == 1
self.backbone_kwargs['num_hidden_layers'] = self.depths[0]
class SiglipVisualTokenizerConfig(BaseVisualTokenizerConfig):
model_type = "siglip_visual_tokenizer"
def __init__(self, **kwargs):
super().__init__(**kwargs)
if self.drop_cls_token:
self.drop_cls_token = False
if self.depths:
assert len(self.depths) == 1
self.backbone_kwargs['num_hidden_layers'] = self.depths[0]
AutoConfig.register("siglip_visual_tokenizer", SiglipVisualTokenizerConfig)
AutoConfig.register("aimv2_visual_tokenizer", Aimv2VisualTokenizerConfig)
# ----------------------------------------------------------------------
# Ovis Configuration
# ----------------------------------------------------------------------
class OvisConfig(PretrainedConfig):
model_type = "ovis"
def __init__(self,
llm_config: Optional[Union[PretrainedConfig, dict]] = None,
visual_tokenizer_config: Optional[Union[PretrainedConfig,
dict]] = None,
multimodal_max_length=8192,
hidden_size=None,
conversation_formatter_class=None,
llm_attn_implementation=None,
disable_tie_weight=False,
**kwargs):
super().__init__(**kwargs)
if llm_config is not None:
assert isinstance(llm_config, (PretrainedConfig, dict)), \
f"expect `llm_config` to be instance of PretrainedConfig or dict, but got {type(llm_config)} type"
if not isinstance(llm_config, PretrainedConfig):
model_type = llm_config['model_type']
llm_config.pop('model_type')
llm_config = AutoConfig.for_model(model_type, **llm_config)
# map llm_config to text_config
self.text_config = llm_config
if visual_tokenizer_config is not None:
assert isinstance(visual_tokenizer_config, (PretrainedConfig, dict)), \
f"expect `visual_tokenizer_config` to be instance of PretrainedConfig or dict, but got {type(visual_tokenizer_config)} type"
if not isinstance(visual_tokenizer_config, PretrainedConfig):
model_type = visual_tokenizer_config['model_type']
visual_tokenizer_config.pop('model_type')
visual_tokenizer_config = AutoConfig.for_model(
model_type, **visual_tokenizer_config)
self.visual_tokenizer_config = visual_tokenizer_config
self.multimodal_max_length = multimodal_max_length
self.hidden_size = hidden_size
self.conversation_formatter_class = conversation_formatter_class
self.llm_attn_implementation = llm_attn_implementation
self.disable_tie_weight = disable_tie_weight