[New Model]: Support GteNewModelForSequenceClassification (#23524)
Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
@ -497,6 +497,7 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A
|
||||
|--------------|--------|-------------------|----------------------|---------------------------|---------------------|
|
||||
| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | | | ✅︎ |
|
||||
| `GemmaForSequenceClassification` | Gemma-based | `BAAI/bge-reranker-v2-gemma` (see note), etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `GteNewForSequenceClassification` | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-reranker-base`, etc. | | | ✅︎ |
|
||||
| `Qwen2ForSequenceClassification` | Qwen2-based | `mixedbread-ai/mxbai-rerank-base-v2` (see note), etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Qwen3ForSequenceClassification` | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | | | ✅︎ |
|
||||
@ -513,6 +514,9 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A
|
||||
vllm serve BAAI/bge-reranker-v2-gemma --hf_overrides '{"architectures": ["GemmaForSequenceClassification"],"classifier_from_token": ["Yes"],"method": "no_post_processing"}'
|
||||
```
|
||||
|
||||
!!! note
|
||||
The second-generation GTE model (mGTE-TRM) is named `NewForSequenceClassification`. The name `NewForSequenceClassification` is too generic, you should set `--hf-overrides '{"architectures": ["GteNewForSequenceClassification"]}'` to specify the use of the `GteNewForSequenceClassification` architecture.
|
||||
|
||||
!!! note
|
||||
Load the official original `mxbai-rerank-v2` by using the following command.
|
||||
|
||||
|
@ -456,11 +456,10 @@ class HfRunner:
|
||||
# output is final logits
|
||||
all_inputs = self.get_inputs(prompts)
|
||||
outputs = []
|
||||
problem_type = getattr(self.config, "problem_type", "")
|
||||
|
||||
for inputs in all_inputs:
|
||||
output = self.model(**self.wrap_device(inputs))
|
||||
|
||||
problem_type = getattr(self.config, "problem_type", "")
|
||||
|
||||
if problem_type == "regression":
|
||||
logits = output.logits[0].tolist()
|
||||
elif problem_type == "multi_label_classification":
|
||||
|
@ -51,6 +51,9 @@ def correctness_test_embed_models(hf_runner,
|
||||
vllm_extra_kwargs = vllm_extra_kwargs or {}
|
||||
vllm_extra_kwargs["dtype"] = model_info.dtype
|
||||
|
||||
if model_info.hf_overrides is not None:
|
||||
vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
|
||||
|
||||
with vllm_runner(model_info.name,
|
||||
runner="pooling",
|
||||
max_model_len=None,
|
||||
|
@ -172,6 +172,9 @@ def mteb_test_embed_models(hf_runner,
|
||||
vllm_extra_kwargs = vllm_extra_kwargs or {}
|
||||
vllm_extra_kwargs["dtype"] = model_info.dtype
|
||||
|
||||
if model_info.hf_overrides is not None:
|
||||
vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
|
||||
|
||||
with vllm_runner(model_info.name,
|
||||
runner="pooling",
|
||||
max_model_len=None,
|
||||
@ -284,6 +287,9 @@ def mteb_test_rerank_models(hf_runner,
|
||||
vllm_extra_kwargs = vllm_extra_kwargs or {}
|
||||
vllm_extra_kwargs["dtype"] = model_info.dtype
|
||||
|
||||
if model_info.hf_overrides is not None:
|
||||
vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
|
||||
|
||||
with vllm_runner(model_info.name,
|
||||
runner="pooling",
|
||||
max_model_len=None,
|
||||
|
@ -13,7 +13,14 @@ from .mteb_utils import VllmMtebEncoder, mteb_test_rerank_models
|
||||
|
||||
RERANK_MODELS = [
|
||||
LASTPoolingRerankModelInfo("BAAI/bge-reranker-v2-gemma",
|
||||
architecture="GemmaForSequenceClassification"),
|
||||
architecture="GemmaForSequenceClassification",
|
||||
hf_overrides={
|
||||
"architectures":
|
||||
["GemmaForSequenceClassification"],
|
||||
"classifier_from_token": ["Yes"],
|
||||
"method":
|
||||
"no_post_processing",
|
||||
}),
|
||||
]
|
||||
|
||||
PROMPT = "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'." # noqa: E501
|
||||
@ -119,22 +126,9 @@ class GemmaMtebEncoder(VllmMtebEncoder):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_info", RERANK_MODELS)
|
||||
def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo,
|
||||
monkeypatch) -> None:
|
||||
monkeypatch.setenv("VLLM_USE_V1", "0")
|
||||
|
||||
assert model_info.architecture == "GemmaForSequenceClassification"
|
||||
|
||||
vllm_extra_kwargs: dict[str, Any] = {
|
||||
"hf_overrides": {
|
||||
"architectures": ["GemmaForSequenceClassification"],
|
||||
"classifier_from_token": ["Yes"],
|
||||
"method": "no_post_processing",
|
||||
}
|
||||
}
|
||||
def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None:
|
||||
|
||||
mteb_test_rerank_models(GemmaRerankerHfRunner,
|
||||
vllm_runner,
|
||||
model_info,
|
||||
vllm_extra_kwargs,
|
||||
vllm_mteb_encoder=GemmaMtebEncoder)
|
||||
|
@ -1,6 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
@ -33,12 +32,15 @@ MODELS = [
|
||||
########### NewModel
|
||||
CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-multilingual-base",
|
||||
architecture="GteNewModel",
|
||||
hf_overrides={"architectures": ["GteNewModel"]},
|
||||
enable_test=True),
|
||||
CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-base-en-v1.5",
|
||||
architecture="GteNewModel",
|
||||
hf_overrides={"architectures": ["GteNewModel"]},
|
||||
enable_test=True),
|
||||
CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-large-en-v1.5",
|
||||
architecture="GteNewModel",
|
||||
hf_overrides={"architectures": ["GteNewModel"]},
|
||||
enable_test=True),
|
||||
########### Qwen2ForCausalLM
|
||||
LASTPoolingEmbedModelInfo("Alibaba-NLP/gte-Qwen2-1.5B-instruct",
|
||||
@ -60,11 +62,16 @@ MODELS = [
|
||||
]
|
||||
|
||||
RERANK_MODELS = [
|
||||
# classifier_pooling: mean
|
||||
CLSPoolingRerankModelInfo(
|
||||
# classifier_pooling: mean
|
||||
"Alibaba-NLP/gte-reranker-modernbert-base",
|
||||
architecture="ModernBertForSequenceClassification",
|
||||
enable_test=True),
|
||||
CLSPoolingRerankModelInfo(
|
||||
"Alibaba-NLP/gte-multilingual-reranker-base",
|
||||
architecture="GteNewForSequenceClassification",
|
||||
hf_overrides={"architectures": ["GteNewForSequenceClassification"]},
|
||||
enable_test=True),
|
||||
]
|
||||
|
||||
|
||||
@ -75,12 +82,7 @@ def test_embed_models_mteb(hf_runner, vllm_runner,
|
||||
check_transformers_version(model_info.name,
|
||||
max_transformers_version="4.53.2")
|
||||
|
||||
vllm_extra_kwargs: dict[str, Any] = {}
|
||||
if model_info.architecture == "GteNewModel":
|
||||
vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]}
|
||||
|
||||
mteb_test_embed_models(hf_runner, vllm_runner, model_info,
|
||||
vllm_extra_kwargs)
|
||||
mteb_test_embed_models(hf_runner, vllm_runner, model_info)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_info", MODELS)
|
||||
@ -91,12 +93,8 @@ def test_embed_models_correctness(hf_runner, vllm_runner,
|
||||
check_transformers_version(model_info.name,
|
||||
max_transformers_version="4.53.2")
|
||||
|
||||
vllm_extra_kwargs: dict[str, Any] = {}
|
||||
if model_info.architecture == "GteNewModel":
|
||||
vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]}
|
||||
|
||||
correctness_test_embed_models(hf_runner, vllm_runner, model_info,
|
||||
example_prompts, vllm_extra_kwargs)
|
||||
example_prompts)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_info", RERANK_MODELS)
|
||||
|
@ -10,12 +10,20 @@ from tests.conftest import HfRunner
|
||||
from ...utils import LASTPoolingRerankModelInfo, RerankModelInfo
|
||||
from .mteb_utils import mteb_test_rerank_models
|
||||
|
||||
mxbai_rerank_hf_overrides = {
|
||||
"architectures": ["Qwen2ForSequenceClassification"],
|
||||
"classifier_from_token": ["0", "1"],
|
||||
"method": "from_2_way_softmax",
|
||||
}
|
||||
|
||||
RERANK_MODELS = [
|
||||
LASTPoolingRerankModelInfo("mixedbread-ai/mxbai-rerank-base-v2",
|
||||
architecture="Qwen2ForSequenceClassification",
|
||||
hf_overrides=mxbai_rerank_hf_overrides,
|
||||
enable_test=True),
|
||||
LASTPoolingRerankModelInfo("mixedbread-ai/mxbai-rerank-large-v2",
|
||||
architecture="Qwen2ForSequenceClassification",
|
||||
hf_overrides=mxbai_rerank_hf_overrides,
|
||||
enable_test=False)
|
||||
]
|
||||
|
||||
@ -71,13 +79,4 @@ class MxbaiRerankerHfRunner(HfRunner):
|
||||
|
||||
@pytest.mark.parametrize("model_info", RERANK_MODELS)
|
||||
def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None:
|
||||
vllm_extra_kwargs: dict[str, Any] = {}
|
||||
if model_info.architecture == "Qwen2ForSequenceClassification":
|
||||
vllm_extra_kwargs["hf_overrides"] = {
|
||||
"architectures": ["Qwen2ForSequenceClassification"],
|
||||
"classifier_from_token": ["0", "1"],
|
||||
"method": "from_2_way_softmax",
|
||||
}
|
||||
|
||||
mteb_test_rerank_models(MxbaiRerankerHfRunner, vllm_runner, model_info,
|
||||
vllm_extra_kwargs)
|
||||
mteb_test_rerank_models(MxbaiRerankerHfRunner, vllm_runner, model_info)
|
||||
|
@ -11,12 +11,20 @@ from tests.utils import multi_gpu_test
|
||||
from ...utils import LASTPoolingRerankModelInfo, RerankModelInfo
|
||||
from .mteb_utils import mteb_test_rerank_models
|
||||
|
||||
qwen3_reranker_hf_overrides = {
|
||||
"architectures": ["Qwen3ForSequenceClassification"],
|
||||
"classifier_from_token": ["no", "yes"],
|
||||
"is_original_qwen3_reranker": True,
|
||||
}
|
||||
|
||||
RERANK_MODELS = [
|
||||
LASTPoolingRerankModelInfo("Qwen/Qwen3-Reranker-0.6B",
|
||||
architecture="Qwen3ForSequenceClassification",
|
||||
hf_overrides=qwen3_reranker_hf_overrides,
|
||||
enable_test=True),
|
||||
LASTPoolingRerankModelInfo("Qwen/Qwen3-Reranker-4B",
|
||||
architecture="Qwen3ForSequenceClassification",
|
||||
hf_overrides=qwen3_reranker_hf_overrides,
|
||||
enable_test=False)
|
||||
]
|
||||
|
||||
@ -74,18 +82,7 @@ class Qwen3RerankerHfRunner(HfRunner):
|
||||
@pytest.mark.parametrize("model_info", RERANK_MODELS)
|
||||
def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None:
|
||||
|
||||
assert model_info.architecture == "Qwen3ForSequenceClassification"
|
||||
|
||||
vllm_extra_kwargs: dict[str, Any] = {
|
||||
"hf_overrides": {
|
||||
"architectures": ["Qwen3ForSequenceClassification"],
|
||||
"classifier_from_token": ["no", "yes"],
|
||||
"is_original_qwen3_reranker": True,
|
||||
}
|
||||
}
|
||||
|
||||
mteb_test_rerank_models(Qwen3RerankerHfRunner, vllm_runner, model_info,
|
||||
vllm_extra_kwargs)
|
||||
mteb_test_rerank_models(Qwen3RerankerHfRunner, vllm_runner, model_info)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_info", RERANK_MODELS)
|
||||
@ -96,11 +93,6 @@ def test_rerank_models_mteb_tp(vllm_runner,
|
||||
assert model_info.architecture == "Qwen3ForSequenceClassification"
|
||||
|
||||
vllm_extra_kwargs: dict[str, Any] = {
|
||||
"hf_overrides": {
|
||||
"architectures": ["Qwen3ForSequenceClassification"],
|
||||
"classifier_from_token": ["no", "yes"],
|
||||
"is_original_qwen3_reranker": True,
|
||||
},
|
||||
"tensor_parallel_size": 2,
|
||||
}
|
||||
|
||||
|
@ -365,6 +365,10 @@ _SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = {
|
||||
|
||||
# [Cross-encoder]
|
||||
"BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2", v0_only=True), # noqa: E501
|
||||
"GteNewForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-multilingual-reranker-base", # noqa: E501
|
||||
trust_remote_code=True,
|
||||
hf_overrides={
|
||||
"architectures": ["GteNewForSequenceClassification"]}),# noqa: E501
|
||||
"ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base", v0_only=True), # noqa: E501
|
||||
"RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base", v0_only=True), # noqa: E501
|
||||
"XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3", v0_only=True), # noqa: E501
|
||||
|
@ -3,7 +3,8 @@
|
||||
|
||||
import warnings
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, NamedTuple, Optional, Union
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@ -339,36 +340,43 @@ def softmax(data):
|
||||
return F.softmax(data, dim=-1)
|
||||
|
||||
|
||||
class EmbedModelInfo(NamedTuple):
|
||||
@dataclass
|
||||
class ModelInfo:
|
||||
name: str
|
||||
is_matryoshka: bool = False
|
||||
matryoshka_dimensions: Optional[list[int]] = None
|
||||
architecture: str = ""
|
||||
dtype: str = "auto"
|
||||
hf_overrides: Optional[dict[str, Any]] = None
|
||||
default_pooling_type: str = ""
|
||||
enable_test: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmbedModelInfo(ModelInfo):
|
||||
is_matryoshka: bool = False
|
||||
matryoshka_dimensions: Optional[list[int]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class CLSPoolingEmbedModelInfo(EmbedModelInfo):
|
||||
default_pooling_type: str = "CLS"
|
||||
|
||||
|
||||
@dataclass
|
||||
class LASTPoolingEmbedModelInfo(EmbedModelInfo):
|
||||
default_pooling_type: str = "LAST"
|
||||
|
||||
|
||||
class RerankModelInfo(NamedTuple):
|
||||
name: str
|
||||
architecture: str = ""
|
||||
dtype: str = "auto"
|
||||
default_pooling_type: str = ""
|
||||
enable_test: bool = True
|
||||
@dataclass
|
||||
class RerankModelInfo(ModelInfo):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class CLSPoolingRerankModelInfo(RerankModelInfo):
|
||||
default_pooling_type: str = "CLS"
|
||||
|
||||
|
||||
@dataclass
|
||||
class LASTPoolingRerankModelInfo(RerankModelInfo):
|
||||
default_pooling_type: str = "LAST"
|
||||
|
||||
|
@ -27,12 +27,15 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.utils import WeightsMapper
|
||||
from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper,
|
||||
maybe_prefix)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsQuant
|
||||
from ..layers.pooler import ClassifierPooler, DispatchPooler, Pooler
|
||||
from .bert import BertPooler
|
||||
from .interfaces import SupportsCrossEncoding, SupportsQuant
|
||||
from .interfaces_base import default_pooling_type
|
||||
|
||||
|
||||
@ -406,9 +409,14 @@ class BertWithRopeEncoder(nn.Module):
|
||||
class BertWithRope(nn.Module, SupportsQuant):
|
||||
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
def __init__(self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
add_pooling_layer: bool = False):
|
||||
super().__init__()
|
||||
self.vllm_config = vllm_config
|
||||
self.add_pooling_layer = add_pooling_layer
|
||||
self.config = vllm_config.model_config.hf_config
|
||||
self.embeddings = BertWithRopeEmbedding(self.config)
|
||||
self.encoder = BertWithRopeEncoder(
|
||||
@ -416,6 +424,7 @@ class BertWithRope(nn.Module, SupportsQuant):
|
||||
bias=getattr(self.config, "bias", True),
|
||||
rotary_kwargs=self.config.rotary_kwargs,
|
||||
prefix=f"{prefix}.encoder")
|
||||
self.pooler = BertPooler(self.config) if add_pooling_layer else None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -448,7 +457,7 @@ class BertWithRope(nn.Module, SupportsQuant):
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if "pooler" in name:
|
||||
if not self.add_pooling_layer and "pooler" in name:
|
||||
continue
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
@ -508,8 +517,8 @@ class GteNewModel(BertWithRope):
|
||||
"attention.o_proj": "attn.out_proj",
|
||||
})
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs):
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
|
||||
|
||||
# GteNewModel only gate_up_proj does not have bias.
|
||||
# Hack method learned from vllm/model_executor/models/glm.py
|
||||
@ -614,3 +623,65 @@ class JinaRobertaModel(BertWithRope):
|
||||
torch.Tensor]]) -> set[str]:
|
||||
weights = self.jina_merge_lora_weights(weights)
|
||||
return super().load_weights(weights)
|
||||
|
||||
|
||||
@default_pooling_type("CLS")
|
||||
class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding):
|
||||
is_pooling_model = True
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.new = GteNewModel(vllm_config=vllm_config,
|
||||
prefix=prefix,
|
||||
add_pooling_layer=True)
|
||||
self.classifier = RowParallelLinear(config.hidden_size,
|
||||
config.num_labels,
|
||||
input_is_parallel=False,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "classifier"),
|
||||
return_bias=False)
|
||||
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
|
||||
self.pooler = DispatchPooler({
|
||||
"encode":
|
||||
Pooler.for_encode(pooler_config),
|
||||
"classify":
|
||||
ClassifierPooler(
|
||||
pooling=self.new.pooler,
|
||||
classifier=self.classifier,
|
||||
act_fn=ClassifierPooler.act_fn_for_seq_cls(
|
||||
vllm_config.model_config),
|
||||
),
|
||||
"score":
|
||||
ClassifierPooler(
|
||||
pooling=self.new.pooler,
|
||||
classifier=self.classifier,
|
||||
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
|
||||
vllm_config.model_config),
|
||||
),
|
||||
})
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
loader = AutoWeightsLoader(self)
|
||||
loaded_params = loader.load_weights(weights)
|
||||
return loaded_params
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
return self.new(input_ids=input_ids,
|
||||
positions=positions,
|
||||
inputs_embeds=inputs_embeds,
|
||||
intermediate_tensors=intermediate_tensors)
|
||||
|
@ -406,6 +406,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
|
||||
MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
|
||||
"GteModel": SnowflakeGteNewModelConfig,
|
||||
"GteNewModel": GteNewModelConfig,
|
||||
"GteNewForSequenceClassification": GteNewModelConfig,
|
||||
"NomicBertModel": NomicBertModelConfig,
|
||||
"Qwen2ForProcessRewardModel": Qwen2ForProcessRewardModelConfig,
|
||||
"Qwen2ForRewardModel": Qwen2ForRewardModelConfig,
|
||||
|
@ -191,12 +191,14 @@ _EMBEDDING_MODELS = {
|
||||
|
||||
_CROSS_ENCODER_MODELS = {
|
||||
"BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
|
||||
"GteNewForSequenceClassification": ("bert_with_rope",
|
||||
"GteNewForSequenceClassification"),
|
||||
"ModernBertForSequenceClassification": ("modernbert",
|
||||
"ModernBertForSequenceClassification"),
|
||||
"RobertaForSequenceClassification": ("roberta",
|
||||
"RobertaForSequenceClassification"),
|
||||
"XLMRobertaForSequenceClassification": ("roberta",
|
||||
"RobertaForSequenceClassification"),
|
||||
"ModernBertForSequenceClassification": ("modernbert",
|
||||
"ModernBertForSequenceClassification"),
|
||||
# [Auto-converted (see adapters.py)]
|
||||
"JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"), # noqa: E501,
|
||||
}
|
||||
|
Reference in New Issue
Block a user