[New Model]: Support GteNewModelForSequenceClassification (#23524)

Signed-off-by: wang.yuqi <noooop@126.com>
wang.yuqi
2025-08-28 15:36:42 +08:00
committed by GitHub
parent 186aced5ff
commit 11a7fafaa8
13 changed files with 157 additions and 76 deletions

View File

@ -497,6 +497,7 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) API
|--------------|--------|-------------------|----------------------|---------------------------|---------------------|
| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | | | ✅︎ |
| `GemmaForSequenceClassification` | Gemma-based | `BAAI/bge-reranker-v2-gemma` (see note), etc. | ✅︎ | ✅︎ | ✅︎ |
| `GteNewForSequenceClassification` | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-reranker-base`, etc. | | | ✅︎ |
| `Qwen2ForSequenceClassification` | Qwen2-based | `mixedbread-ai/mxbai-rerank-base-v2` (see note), etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen3ForSequenceClassification` | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. | ✅︎ | ✅︎ | ✅︎ |
| `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | | | ✅︎ |
@ -513,6 +514,9 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) API
vllm serve BAAI/bge-reranker-v2-gemma --hf_overrides '{"architectures": ["GemmaForSequenceClassification"],"classifier_from_token": ["Yes"],"method": "no_post_processing"}'
```
!!! note
The second-generation GTE model (mGTE-TRM) is named `NewForSequenceClassification`. Because that name is too generic, you should set `--hf-overrides '{"architectures": ["GteNewForSequenceClassification"]}'` to select the `GteNewForSequenceClassification` architecture.
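For example, a minimal offline sketch using the `LLM.score` API named above (the `runner="pooling"` argument mirrors the test runners in this PR; adjust to your vLLM version):

```python
from vllm import LLM

# The checkpoint declares the generic `NewForSequenceClassification`
# architecture, so override it explicitly.
llm = LLM(model="Alibaba-NLP/gte-multilingual-reranker-base",
          runner="pooling",
          trust_remote_code=True,
          hf_overrides={"architectures": ["GteNewForSequenceClassification"]})

scores = llm.score("What is the capital of France?",
                   ["Paris is the capital of France.",
                    "The Nile is the longest river in Africa."])
```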
!!! note
Load the official `mxbai-rerank-v2` checkpoints by overriding the architecture, as sketched below.
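The exact invocation falls outside this hunk; a sketch assembled from the `mxbai_rerank_hf_overrides` dict added in the test file below (assumed equivalent, not copied from the docs):

```python
from vllm import LLM

llm = LLM(model="mixedbread-ai/mxbai-rerank-base-v2",
          runner="pooling",
          hf_overrides={
              "architectures": ["Qwen2ForSequenceClassification"],
              "classifier_from_token": ["0", "1"],
              "method": "from_2_way_softmax",
          })
```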

View File

@ -456,11 +456,10 @@ class HfRunner:
# output is final logits
all_inputs = self.get_inputs(prompts)
outputs = []
problem_type = getattr(self.config, "problem_type", "")
for inputs in all_inputs:
output = self.model(**self.wrap_device(inputs))
problem_type = getattr(self.config, "problem_type", "")
if problem_type == "regression":
logits = output.logits[0].tolist()
elif problem_type == "multi_label_classification":
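For reference, a self-contained sketch of the `problem_type` branches this code relies on (the names follow the `transformers` loss conventions; the sigmoid/softmax split for the non-regression branches is my reading, not copied from this file):

```python
import torch

def postprocess_logits(logits: torch.Tensor, problem_type: str) -> list[float]:
    if problem_type == "regression":
        return logits.tolist()  # raw scores, e.g. reranker relevance
    if problem_type == "multi_label_classification":
        return logits.sigmoid().tolist()  # independent per-label probabilities
    return logits.softmax(dim=-1).tolist()  # single-label distribution
```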

View File

@ -51,6 +51,9 @@ def correctness_test_embed_models(hf_runner,
vllm_extra_kwargs = vllm_extra_kwargs or {}
vllm_extra_kwargs["dtype"] = model_info.dtype
if model_info.hf_overrides is not None:
vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
with vllm_runner(model_info.name,
runner="pooling",
max_model_len=None,

View File

@ -172,6 +172,9 @@ def mteb_test_embed_models(hf_runner,
vllm_extra_kwargs = vllm_extra_kwargs or {}
vllm_extra_kwargs["dtype"] = model_info.dtype
if model_info.hf_overrides is not None:
vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
with vllm_runner(model_info.name,
runner="pooling",
max_model_len=None,
@ -284,6 +287,9 @@ def mteb_test_rerank_models(hf_runner,
vllm_extra_kwargs = vllm_extra_kwargs or {}
vllm_extra_kwargs["dtype"] = model_info.dtype
if model_info.hf_overrides is not None:
vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
with vllm_runner(model_info.name,
runner="pooling",
max_model_len=None,
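All three helpers now pick up overrides from the model metadata instead of per-test kwargs. For context, dict-style `hf_overrides` are merged into the Hugging Face config before the model is built; a simplified sketch of the effect (not vLLM's actual code):

```python
from transformers import AutoConfig

hf_config = AutoConfig.from_pretrained(
    "Alibaba-NLP/gte-multilingual-reranker-base", trust_remote_code=True)
overrides = {"architectures": ["GteNewForSequenceClassification"]}
for key, value in overrides.items():  # roughly what vLLM does with a dict
    setattr(hf_config, key, value)
assert hf_config.architectures == ["GteNewForSequenceClassification"]
```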

View File

@ -13,7 +13,14 @@ from .mteb_utils import VllmMtebEncoder, mteb_test_rerank_models
RERANK_MODELS = [
LASTPoolingRerankModelInfo("BAAI/bge-reranker-v2-gemma",
architecture="GemmaForSequenceClassification"),
architecture="GemmaForSequenceClassification",
hf_overrides={
"architectures":
["GemmaForSequenceClassification"],
"classifier_from_token": ["Yes"],
"method":
"no_post_processing",
}),
]
PROMPT = "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'." # noqa: E501
@ -119,22 +126,9 @@ class GemmaMtebEncoder(VllmMtebEncoder):
@pytest.mark.parametrize("model_info", RERANK_MODELS)
def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo,
monkeypatch) -> None:
monkeypatch.setenv("VLLM_USE_V1", "0")
assert model_info.architecture == "GemmaForSequenceClassification"
vllm_extra_kwargs: dict[str, Any] = {
"hf_overrides": {
"architectures": ["GemmaForSequenceClassification"],
"classifier_from_token": ["Yes"],
"method": "no_post_processing",
}
}
def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None:
mteb_test_rerank_models(GemmaRerankerHfRunner,
vllm_runner,
model_info,
vllm_extra_kwargs,
vllm_mteb_encoder=GemmaMtebEncoder)

View File

@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import pytest
@ -33,12 +32,15 @@ MODELS = [
########### NewModel
CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-multilingual-base",
architecture="GteNewModel",
hf_overrides={"architectures": ["GteNewModel"]},
enable_test=True),
CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-base-en-v1.5",
architecture="GteNewModel",
hf_overrides={"architectures": ["GteNewModel"]},
enable_test=True),
CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-large-en-v1.5",
architecture="GteNewModel",
hf_overrides={"architectures": ["GteNewModel"]},
enable_test=True),
########### Qwen2ForCausalLM
LASTPoolingEmbedModelInfo("Alibaba-NLP/gte-Qwen2-1.5B-instruct",
@ -60,11 +62,16 @@ MODELS = [
]
RERANK_MODELS = [
# classifier_pooling: mean
CLSPoolingRerankModelInfo(
# classifier_pooling: mean
"Alibaba-NLP/gte-reranker-modernbert-base",
architecture="ModernBertForSequenceClassification",
enable_test=True),
CLSPoolingRerankModelInfo(
"Alibaba-NLP/gte-multilingual-reranker-base",
architecture="GteNewForSequenceClassification",
hf_overrides={"architectures": ["GteNewForSequenceClassification"]},
enable_test=True),
]
@ -75,12 +82,7 @@ def test_embed_models_mteb(hf_runner, vllm_runner,
check_transformers_version(model_info.name,
max_transformers_version="4.53.2")
vllm_extra_kwargs: dict[str, Any] = {}
if model_info.architecture == "GteNewModel":
vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]}
mteb_test_embed_models(hf_runner, vllm_runner, model_info,
vllm_extra_kwargs)
mteb_test_embed_models(hf_runner, vllm_runner, model_info)
@pytest.mark.parametrize("model_info", MODELS)
@ -91,12 +93,8 @@ def test_embed_models_correctness(hf_runner, vllm_runner,
check_transformers_version(model_info.name,
max_transformers_version="4.53.2")
vllm_extra_kwargs: dict[str, Any] = {}
if model_info.architecture == "GteNewModel":
vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]}
correctness_test_embed_models(hf_runner, vllm_runner, model_info,
example_prompts, vllm_extra_kwargs)
example_prompts)
@pytest.mark.parametrize("model_info", RERANK_MODELS)

View File

@ -10,12 +10,20 @@ from tests.conftest import HfRunner
from ...utils import LASTPoolingRerankModelInfo, RerankModelInfo
from .mteb_utils import mteb_test_rerank_models
mxbai_rerank_hf_overrides = {
"architectures": ["Qwen2ForSequenceClassification"],
"classifier_from_token": ["0", "1"],
"method": "from_2_way_softmax",
}
RERANK_MODELS = [
LASTPoolingRerankModelInfo("mixedbread-ai/mxbai-rerank-base-v2",
architecture="Qwen2ForSequenceClassification",
hf_overrides=mxbai_rerank_hf_overrides,
enable_test=True),
LASTPoolingRerankModelInfo("mixedbread-ai/mxbai-rerank-large-v2",
architecture="Qwen2ForSequenceClassification",
hf_overrides=mxbai_rerank_hf_overrides,
enable_test=False)
]
@ -71,13 +79,4 @@ class MxbaiRerankerHfRunner(HfRunner):
@pytest.mark.parametrize("model_info", RERANK_MODELS)
def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None:
vllm_extra_kwargs: dict[str, Any] = {}
if model_info.architecture == "Qwen2ForSequenceClassification":
vllm_extra_kwargs["hf_overrides"] = {
"architectures": ["Qwen2ForSequenceClassification"],
"classifier_from_token": ["0", "1"],
"method": "from_2_way_softmax",
}
mteb_test_rerank_models(MxbaiRerankerHfRunner, vllm_runner, model_info,
vllm_extra_kwargs)
mteb_test_rerank_models(MxbaiRerankerHfRunner, vllm_runner, model_info)

View File

@ -11,12 +11,20 @@ from tests.utils import multi_gpu_test
from ...utils import LASTPoolingRerankModelInfo, RerankModelInfo
from .mteb_utils import mteb_test_rerank_models
qwen3_reranker_hf_overrides = {
"architectures": ["Qwen3ForSequenceClassification"],
"classifier_from_token": ["no", "yes"],
"is_original_qwen3_reranker": True,
}
RERANK_MODELS = [
LASTPoolingRerankModelInfo("Qwen/Qwen3-Reranker-0.6B",
architecture="Qwen3ForSequenceClassification",
hf_overrides=qwen3_reranker_hf_overrides,
enable_test=True),
LASTPoolingRerankModelInfo("Qwen/Qwen3-Reranker-4B",
architecture="Qwen3ForSequenceClassification",
hf_overrides=qwen3_reranker_hf_overrides,
enable_test=False)
]
@ -74,18 +82,7 @@ class Qwen3RerankerHfRunner(HfRunner):
@pytest.mark.parametrize("model_info", RERANK_MODELS)
def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None:
assert model_info.architecture == "Qwen3ForSequenceClassification"
vllm_extra_kwargs: dict[str, Any] = {
"hf_overrides": {
"architectures": ["Qwen3ForSequenceClassification"],
"classifier_from_token": ["no", "yes"],
"is_original_qwen3_reranker": True,
}
}
mteb_test_rerank_models(Qwen3RerankerHfRunner, vllm_runner, model_info,
vllm_extra_kwargs)
mteb_test_rerank_models(Qwen3RerankerHfRunner, vllm_runner, model_info)
@pytest.mark.parametrize("model_info", RERANK_MODELS)
@ -96,11 +93,6 @@ def test_rerank_models_mteb_tp(vllm_runner,
assert model_info.architecture == "Qwen3ForSequenceClassification"
vllm_extra_kwargs: dict[str, Any] = {
"hf_overrides": {
"architectures": ["Qwen3ForSequenceClassification"],
"classifier_from_token": ["no", "yes"],
"is_original_qwen3_reranker": True,
},
"tensor_parallel_size": 2,
}

View File

@ -365,6 +365,10 @@ _SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = {
# [Cross-encoder]
"BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2", v0_only=True), # noqa: E501
"GteNewForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-multilingual-reranker-base", # noqa: E501
trust_remote_code=True,
hf_overrides={
"architectures": ["GteNewForSequenceClassification"]}),# noqa: E501
"ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base", v0_only=True), # noqa: E501
"RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base", v0_only=True), # noqa: E501
"XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3", v0_only=True), # noqa: E501

View File

@ -3,7 +3,8 @@
import warnings
from collections.abc import Sequence
from typing import Any, NamedTuple, Optional, Union
from dataclasses import dataclass
from typing import Any, Optional, Union
import torch
import torch.nn.functional as F
@ -339,36 +340,43 @@ def softmax(data):
return F.softmax(data, dim=-1)
class EmbedModelInfo(NamedTuple):
@dataclass
class ModelInfo:
name: str
is_matryoshka: bool = False
matryoshka_dimensions: Optional[list[int]] = None
architecture: str = ""
dtype: str = "auto"
hf_overrides: Optional[dict[str, Any]] = None
default_pooling_type: str = ""
enable_test: bool = True
@dataclass
class EmbedModelInfo(ModelInfo):
is_matryoshka: bool = False
matryoshka_dimensions: Optional[list[int]] = None
@dataclass
class CLSPoolingEmbedModelInfo(EmbedModelInfo):
default_pooling_type: str = "CLS"
@dataclass
class LASTPoolingEmbedModelInfo(EmbedModelInfo):
default_pooling_type: str = "LAST"
class RerankModelInfo(NamedTuple):
name: str
architecture: str = ""
dtype: str = "auto"
default_pooling_type: str = ""
enable_test: bool = True
@dataclass
class RerankModelInfo(ModelInfo):
pass
@dataclass
class CLSPoolingRerankModelInfo(RerankModelInfo):
default_pooling_type: str = "CLS"
@dataclass
class LASTPoolingRerankModelInfo(RerankModelInfo):
default_pooling_type: str = "LAST"
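`NamedTuple` does not support adding fields through inheritance, which is why the old `EmbedModelInfo` and `RerankModelInfo` duplicated their common fields; the dataclass hierarchy moves them, including the new `hf_overrides`, into the shared `ModelInfo` base. Construction in the test files keeps the same shape:

```python
info = CLSPoolingRerankModelInfo(
    "Alibaba-NLP/gte-multilingual-reranker-base",
    architecture="GteNewForSequenceClassification",
    hf_overrides={"architectures": ["GteNewForSequenceClassification"]},
)
assert info.default_pooling_type == "CLS" and info.enable_test
```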

View File

@ -27,12 +27,15 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import WeightsMapper
from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper,
maybe_prefix)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsQuant
from ..layers.pooler import ClassifierPooler, DispatchPooler, Pooler
from .bert import BertPooler
from .interfaces import SupportsCrossEncoding, SupportsQuant
from .interfaces_base import default_pooling_type
@ -406,9 +409,14 @@ class BertWithRopeEncoder(nn.Module):
class BertWithRope(nn.Module, SupportsQuant):
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = "",
add_pooling_layer: bool = False):
super().__init__()
self.vllm_config = vllm_config
self.add_pooling_layer = add_pooling_layer
self.config = vllm_config.model_config.hf_config
self.embeddings = BertWithRopeEmbedding(self.config)
self.encoder = BertWithRopeEncoder(
@ -416,6 +424,7 @@ class BertWithRope(nn.Module, SupportsQuant):
bias=getattr(self.config, "bias", True),
rotary_kwargs=self.config.rotary_kwargs,
prefix=f"{prefix}.encoder")
self.pooler = BertPooler(self.config) if add_pooling_layer else None
def forward(
self,
@ -448,7 +457,7 @@ class BertWithRope(nn.Module, SupportsQuant):
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "pooler" in name:
if not self.add_pooling_layer and "pooler" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
@ -508,8 +517,8 @@ class GteNewModel(BertWithRope):
"attention.o_proj": "attn.out_proj",
})
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs):
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
# In GteNewModel, only gate_up_proj has no bias.
# Workaround borrowed from vllm/model_executor/models/glm.py
@ -614,3 +623,65 @@ class JinaRobertaModel(BertWithRope):
torch.Tensor]]) -> set[str]:
weights = self.jina_merge_lora_weights(weights)
return super().load_weights(weights)
@default_pooling_type("CLS")
class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding):
is_pooling_model = True
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.new = GteNewModel(vllm_config=vllm_config,
prefix=prefix,
add_pooling_layer=True)
self.classifier = RowParallelLinear(config.hidden_size,
config.num_labels,
input_is_parallel=False,
bias=True,
quant_config=quant_config,
prefix=maybe_prefix(
prefix, "classifier"),
return_bias=False)
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler({
"encode":
Pooler.for_encode(pooler_config),
"classify":
ClassifierPooler(
pooling=self.new.pooler,
classifier=self.classifier,
act_fn=ClassifierPooler.act_fn_for_seq_cls(
vllm_config.model_config),
),
"score":
ClassifierPooler(
pooling=self.new.pooler,
classifier=self.classifier,
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
vllm_config.model_config),
),
})
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self)
loaded_params = loader.load_weights(weights)
return loaded_params
def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return self.new(input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors)
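A stripped-down sketch of what the `classify`/`score` paths compute: CLS pooling through the `BertPooler` dense+tanh, then the linear classifier (simplified; sigmoid assumed for `num_labels == 1`, and vLLM's parallel layers and batching are omitted):

```python
import torch
import torch.nn as nn

class CrossEncoderHeadSketch(nn.Module):
    def __init__(self, hidden_size: int, num_labels: int = 1):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)  # BertPooler-style
        self.activation = nn.Tanh()
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        cls = hidden_states[:, 0]  # CLS pooling, per @default_pooling_type("CLS")
        pooled = self.activation(self.dense(cls))
        return torch.sigmoid(self.classifier(pooled))  # cross-encoder score
```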

View File

@ -406,6 +406,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
"GteModel": SnowflakeGteNewModelConfig,
"GteNewModel": GteNewModelConfig,
"GteNewForSequenceClassification": GteNewModelConfig,
"NomicBertModel": NomicBertModelConfig,
"Qwen2ForProcessRewardModel": Qwen2ForProcessRewardModelConfig,
"Qwen2ForRewardModel": Qwen2ForRewardModelConfig,

View File

@ -191,12 +191,14 @@ _EMBEDDING_MODELS = {
_CROSS_ENCODER_MODELS = {
"BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
"GteNewForSequenceClassification": ("bert_with_rope",
"GteNewForSequenceClassification"),
"ModernBertForSequenceClassification": ("modernbert",
"ModernBertForSequenceClassification"),
"RobertaForSequenceClassification": ("roberta",
"RobertaForSequenceClassification"),
"XLMRobertaForSequenceClassification": ("roberta",
"RobertaForSequenceClassification"),
"ModernBertForSequenceClassification": ("modernbert",
"ModernBertForSequenceClassification"),
# [Auto-converted (see adapters.py)]
"JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"), # noqa: E501,
}
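Each registry entry is a `(module, class)` pair resolved lazily; a sketch of the lookup (simplified; vLLM's registry adds caching and subprocess inspection on top):

```python
import importlib

module_name, class_name = ("bert_with_rope", "GteNewForSequenceClassification")
module = importlib.import_module(f"vllm.model_executor.models.{module_name}")
model_cls = getattr(module, class_name)
```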