[Model][2/N] Automatic conversion of CrossEncoding model (#19978)

Signed-off-by: wang.yuqi <noooop@126.com>
Author: wang.yuqi
Date: 2025-07-03 21:59:23 +08:00
Committed by: GitHub
Parent: 1819fbda63
Commit: 6f1229f91d
16 changed files with 199 additions and 92 deletions


@ -471,7 +471,7 @@ Specified using `--task classify`.
| `JambaForSequenceClassification` | Jamba | `ai21labs/Jamba-tiny-reward-dev`, etc. | ✅︎ | ✅︎ | |
| `GPT2ForSequenceClassification` | GPT2 | `nie3e/sentiment-polish-gpt2-small` | | | ✅︎ |
If your model is not in the above list, we will try to automatically convert the model using
[as_classification_model][vllm.model_executor.models.adapters.as_classification_model]. By default, the class probabilities are extracted from the softmaxed hidden state corresponding to the last token.
[as_seq_cls_model][vllm.model_executor.models.adapters.as_seq_cls_model]. By default, the class probabilities are extracted from the softmaxed hidden state corresponding to the last token.
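If the conversion applies, such a checkpoint can be used for offline classification roughly as in the sketch below (using the example model listed above; output field names may vary slightly across vLLM versions):

```python
from vllm import LLM

# Minimal sketch: jason9693/Qwen2.5-1.5B-apeach is the example
# sequence-classification model referenced elsewhere in these docs.
llm = LLM(model="jason9693/Qwen2.5-1.5B-apeach", task="classify")

(output,) = llm.classify(["vLLM is wonderful!"])
print(output.outputs.probs)  # per-class probabilities after softmax
```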
#### Sentence Pair Scoring


@ -426,7 +426,7 @@ Code example: <gh-file:examples/online_serving/openai_pooling_client.py>
Our Classification API directly supports Hugging Face sequence-classification models such as [ai21labs/Jamba-tiny-reward-dev](https://huggingface.co/ai21labs/Jamba-tiny-reward-dev) and [jason9693/Qwen2.5-1.5B-apeach](https://huggingface.co/jason9693/Qwen2.5-1.5B-apeach).
We automatically wrap any other transformer via `as_classification_model()`, which pools on the last token, attaches a `RowParallelLinear` head, and applies a softmax to produce per-class probabilities.
We automatically wrap any other transformer via `as_seq_cls_model()`, which pools on the last token, attaches a `RowParallelLinear` head, and applies a softmax to produce per-class probabilities.
Code example: <gh-file:examples/online_serving/openai_classification_client.py>
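For reference, a bare-bones request against the Classification API might look like this sketch (assuming a server launched with the apeach model above; see the linked client example for the canonical version):

```python
# Sketch of a raw HTTP call to the Classification API; the exact
# response schema may differ between vLLM versions.
import requests

response = requests.post(
    "http://localhost:8000/classify",
    json={
        "model": "jason9693/Qwen2.5-1.5B-apeach",
        "input": ["Loved the movie, would watch again."],
    },
)
print(response.json())
```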


@ -6,19 +6,16 @@ import pytest
# yapf conflicts with isort for this block
# yapf: disable
from tests.models.language.pooling.mteb_utils import (MTEB_RERANK_LANGS,
MTEB_RERANK_TASKS,
MTEB_RERANK_TOL,
RerankClientMtebEncoder,
ScoreClientMtebEncoder,
run_mteb_rerank)
from tests.models.language.pooling.mteb_utils import (
MTEB_RERANK_LANGS, MTEB_RERANK_TASKS, MTEB_RERANK_TOL,
RerankClientMtebEncoder, ScoreClientMtebEncoder,
mteb_test_rerank_models_hf, run_mteb_rerank)
# yapf: enable
from tests.utils import RemoteOpenAIServer
os.environ["VLLM_LOGGING_LEVEL"] = "WARNING"
MODEL_NAME = "cross-encoder/ms-marco-MiniLM-L-6-v2"
MAIN_SCORE = 0.33437
@pytest.fixture(scope="module")
@ -31,12 +28,19 @@ def server():
yield remote_server
def test_mteb_score(server):
@pytest.fixture(scope="module")
def st_main_score(hf_runner):
# The main score depends on the version of the dependency,
# so we need to recalculate it every time.
main_score, st_dtype = mteb_test_rerank_models_hf(hf_runner, MODEL_NAME)
return main_score
def test_mteb_score(server, st_main_score):
url = server.url_for("score")
encoder = ScoreClientMtebEncoder(MODEL_NAME, url)
vllm_main_score = run_mteb_rerank(encoder, MTEB_RERANK_TASKS,
MTEB_RERANK_LANGS)
st_main_score = MAIN_SCORE
print("VLLM main score: ", vllm_main_score)
print("SentenceTransformer main score: ", st_main_score)
@ -45,12 +49,11 @@ def test_mteb_score(server):
assert st_main_score == pytest.approx(vllm_main_score, abs=MTEB_RERANK_TOL)
def test_mteb_rerank(server):
def test_mteb_rerank(server, st_main_score):
url = server.url_for("rerank")
encoder = RerankClientMtebEncoder(MODEL_NAME, url)
vllm_main_score = run_mteb_rerank(encoder, MTEB_RERANK_TASKS,
MTEB_RERANK_LANGS)
st_main_score = MAIN_SCORE
print("VLLM main score: ", vllm_main_score)
print("SentenceTransformer main score: ", st_main_score)


@ -234,6 +234,35 @@ def run_mteb_rerank(cross_encoder, tasks, languages):
return main_score
def mteb_test_rerank_models_hf(hf_runner, model_name, hf_model_callback=None):
with hf_runner(model_name, is_cross_encoder=True,
dtype="float32") as hf_model:
original_predict = hf_model.predict
def _predict(
sentences: list[tuple[str, str,
Optional[str]]], # query, corpus, prompt
*args,
**kwargs,
):
# Both vllm and st remove the prompt, so the comparison is fair.
prompts = [(s[0], s[1]) for s in sentences]
return original_predict(prompts, *args, **kwargs, batch_size=8)
hf_model.predict = _predict
hf_model.original_predict = original_predict
if hf_model_callback is not None:
hf_model_callback(hf_model)
st_main_score = run_mteb_rerank(hf_model,
tasks=MTEB_RERANK_TASKS,
languages=MTEB_RERANK_LANGS)
st_dtype = next(hf_model.model.model.parameters()).dtype
return st_main_score, st_dtype
def mteb_test_rerank_models(hf_runner,
vllm_runner,
model_info: RerankModelInfo,
@ -264,31 +293,8 @@ def mteb_test_rerank_models(hf_runner,
languages=MTEB_RERANK_LANGS)
vllm_dtype = model_config.dtype
with hf_runner(model_info.name, is_cross_encoder=True,
dtype="float32") as hf_model:
original_predict = hf_model.predict
def _predict(
sentences: list[tuple[str, str,
Optional[str]]], # query, corpus, prompt
*args,
**kwargs,
):
# vllm and st both remove the prompt, fair comparison.
prompts = [(s[0], s[1]) for s in sentences]
return original_predict(prompts, *args, **kwargs, batch_size=8)
hf_model.predict = _predict
hf_model.original_predict = original_predict
if hf_model_callback is not None:
hf_model_callback(hf_model)
st_main_score = run_mteb_rerank(hf_model,
tasks=MTEB_RERANK_TASKS,
languages=MTEB_RERANK_LANGS)
st_dtype = next(hf_model.model.model.parameters()).dtype
st_main_score, st_dtype = mteb_test_rerank_models_hf(
hf_runner, model_info.name, hf_model_callback)
print("VLLM:", vllm_dtype, vllm_main_score)
print("SentenceTransformers:", st_dtype, st_main_score)


@ -9,9 +9,9 @@ import torch.cuda
from vllm.model_executor.models import (is_pooling_model,
is_text_generation_model,
supports_multimodal)
from vllm.model_executor.models.adapters import (as_classification_model,
as_embedding_model,
as_reward_model)
from vllm.model_executor.models.adapters import (as_embedding_model,
as_reward_model,
as_seq_cls_model)
from vllm.model_executor.models.registry import (_MULTIMODAL_MODELS,
_SPECULATIVE_DECODING_MODELS,
_TEXT_GENERATION_MODELS,
@ -38,7 +38,7 @@ def test_registry_imports(model_arch):
assert is_text_generation_model(model_cls)
# All vLLM models should be convertible to a pooling model
assert is_pooling_model(as_classification_model(model_cls))
assert is_pooling_model(as_seq_cls_model(model_cls))
assert is_pooling_model(as_embedding_model(model_cls))
assert is_pooling_model(as_reward_model(model_cls))


@ -52,7 +52,7 @@ def test_get_field():
("distilbert/distilgpt2", "generate", "generate"),
("intfloat/multilingual-e5-small", "pooling", "embed"),
("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify"),
("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "score"),
("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "classify"),
("Qwen/Qwen2.5-Math-RM-72B", "pooling", "reward"),
("openai/whisper-small", "transcription", "transcription"),
],
@ -72,6 +72,32 @@ def test_auto_task(model_id, expected_runner_type, expected_task):
assert config.task == expected_task
@pytest.mark.parametrize(
("model_id", "expected_runner_type", "expected_task"),
[
("distilbert/distilgpt2", "pooling", "embed"),
("intfloat/multilingual-e5-small", "pooling", "embed"),
("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify"),
("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "classify"),
("Qwen/Qwen2.5-Math-RM-72B", "pooling", "embed"),
("openai/whisper-small", "pooling", "embed"),
],
)
def test_score_task(model_id, expected_runner_type, expected_task):
config = ModelConfig(
model_id,
task="score",
tokenizer=model_id,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="float16",
)
assert config.runner_type == expected_runner_type
assert config.task == expected_task
@pytest.mark.parametrize(("model_id", "bad_task"), [
("Qwen/Qwen2.5-Math-RM-72B", "generate"),
])


@ -93,14 +93,14 @@ ConfigT = TypeVar("ConfigT", bound=ConfigType)
TaskOption = Literal["auto", "generate", "embedding", "embed", "classify",
"score", "reward", "transcription"]
_ResolvedTask = Literal["generate", "embed", "classify", "score", "reward",
"draft", "transcription"]
_ResolvedTask = Literal["generate", "embed", "classify", "reward", "draft",
"transcription"]
RunnerType = Literal["generate", "pooling", "draft", "transcription"]
_RUNNER_TASKS: dict[RunnerType, list[_ResolvedTask]] = {
"generate": ["generate"],
"pooling": ["embed", "classify", "score", "reward"],
"pooling": ["embed", "classify", "reward"],
"draft": ["draft"],
"transcription": ["transcription"],
}
@ -777,7 +777,7 @@ class ModelConfig:
if get_pooling_config(model_id, self.revision):
return "embed"
if self.registry.is_cross_encoder_model(architectures):
return "score"
return "classify"
if self.registry.is_transcription_model(architectures):
return "transcription"
@ -841,14 +841,24 @@ class ModelConfig:
"This model supports multiple tasks: %s. "
"Defaulting to '%s'.", supported_tasks, selected_task)
else:
# Aliases
if task_option == "embedding":
msg = ("The 'embedding' task has been renamed to "
"'embed', please use the new name. The old name "
"will be removed in v1.0.")
warnings.warn(msg, DeprecationWarning, stacklevel=2)
if task_option == "score":
if not runner_support["pooling"]:
msg = (f"This model does not support the '{task_option}' "
f"task. Supported tasks: {supported_tasks}")
raise ValueError(msg)
if self.registry.is_cross_encoder_model(architectures):
task_option = "classify"
else:
task_option = "embed"
else:
# Aliases
if task_option == "embedding":
msg = ("The 'embedding' task has been renamed to "
"'embed', please use the new name. The old name "
"will be removed in v1.0.")
warnings.warn(msg, DeprecationWarning, stacklevel=2)
task_option = "embed"
task_option = "embed"
if task_option not in supported_tasks:
msg = (


@ -1289,9 +1289,13 @@ class LLM:
raise ValueError(" ".join(messages))
if self.llm_engine.model_config.task not in ("embed", "score"):
raise ValueError(
"Score API is only enabled for `--task embed or --task score`")
if self.llm_engine.model_config.task not in ("embed", "classify"):
raise ValueError("Score API is only enabled for "
"`--task embed or --task classify`.")
if (self.llm_engine.model_config.task == "classify"
and self.llm_engine.model_config.hf_config.num_labels != 1):
raise ValueError("Score API is only enabled for num_labels == 1.")
# the tokenizer for models such as
# "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing


@ -1311,24 +1311,27 @@ async def init_app_state(
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
) if model_config.task == "embed" else None
state.openai_serving_scores = ServingScores(
engine_client,
model_config,
state.openai_serving_models,
request_logger=request_logger) if model_config.task in (
"score", "embed", "pooling") else None
state.openai_serving_classification = ServingClassification(
engine_client,
model_config,
state.openai_serving_models,
request_logger=request_logger,
) if model_config.task == "classify" else None
enable_serving_reranking = (model_config.task == "classify" and getattr(
model_config.hf_config, "num_labels", 0) == 1)
state.jinaai_serving_reranking = ServingScores(
engine_client,
model_config,
state.openai_serving_models,
request_logger=request_logger
) if model_config.task == "score" else None
request_logger=request_logger) if enable_serving_reranking else None
state.openai_serving_scores = ServingScores(
engine_client,
model_config,
state.openai_serving_models,
request_logger=request_logger) if (
model_config.task == "embed" or enable_serving_reranking) else None
state.openai_serving_tokenization = OpenAIServingTokenization(
engine_client,
model_config,


@ -357,12 +357,16 @@ async def main(args):
chat_template=None,
chat_template_content_format="auto",
) if model_config.task == "embed" else None
enable_serving_reranking = (model_config.task == "classify" and getattr(
model_config.hf_config, "num_labels", 0) == 1)
openai_serving_scores = (ServingScores(
engine,
model_config,
openai_serving_models,
request_logger=request_logger,
) if model_config.task == "score" else None)
) if (model_config.task == "embed" or enable_serving_reranking) else None)
tracker = BatchProgressTracker()
logger.info("Reading batch from %s...", args.input_file)


@ -285,6 +285,7 @@ class PoolerHead(nn.Module):
else:
pooled_data = pooled_data.to(torch.float32)
# for matryoshka representation
if isinstance(pooling_metadata, V0PoolingMetadata):
dimensions_list = [
pooling_param.dimensions
@ -299,10 +300,16 @@ class PoolerHead(nn.Module):
if any(d is not None for d in dimensions_list):
# change the output dimension
assert len(pooled_data) == len(dimensions_list)
pooled_data = [
vecs if d is None else vecs[..., :d]
for vecs, d in zip(pooled_data, dimensions_list)
]
if len(set(dimensions_list)) == 1 and not isinstance(
pooled_data, list):
# if all dimensions are the same
d = dimensions_list[0]
pooled_data = pooled_data[..., :d]
else:
pooled_data = [
vecs if d is None else vecs[..., :d]
for vecs, d in zip(pooled_data, dimensions_list)
]
if self.normalize:
if isinstance(pooled_data, list):
@ -325,6 +332,10 @@ class PoolerHead(nn.Module):
else:
pooled_data = F.sigmoid(pooled_data)
# shape:
# classify (& score) -> (batch_size, num_classes)
# embed -> (batch_size, embedding_dim) or list(embedding_dim)
# (batch_size, dimensions) or list(dimensions) if using MRL
return pooled_data
@ -419,7 +430,6 @@ class ClassifierPooler(nn.Module):
offset += prompt_len
pooled_data.append(pooled_data_i)
offset = 0
pooled_data_lst = []
for pooled_data_i in pooled_data:
@ -436,7 +446,8 @@ class ClassifierPooler(nn.Module):
# apply classifier once on the full batch if possible
pooled_output = self.classifier(pooled_output)
scores = self.default_activation_function(pooled_output).squeeze(-1)
# shape: (batch_size, num_labels)
scores = self.default_activation_function(pooled_output)
pooled_outputs = [PoolingSequenceGroupOutput(data) for data in scores]
return PoolerOutput(outputs=pooled_outputs)


@ -21,8 +21,7 @@ from vllm.model_executor.layers.linear import QKVCrossParallelLinear
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.adapters import (as_classification_model,
as_embedding_model,
from vllm.model_executor.models.adapters import (as_embedding_model,
as_reward_model)
from vllm.model_executor.models.interfaces import SupportsQuant
from vllm.utils import is_pin_memory_available
@ -246,7 +245,9 @@ def get_model_architecture(
if model_config.task == "embed":
model_cls = as_embedding_model(model_cls)
elif model_config.task == "classify":
model_cls = as_classification_model(model_cls)
# Cannot automatically convert with as_seq_cls_model here;
# doing so would create a circular reference through is_cross_encoder_model.
pass
elif model_config.task == "reward":
model_cls = as_reward_model(model_cls)


@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, Optional, TypeVar
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
import torch
import torch.nn as nn
@ -145,9 +145,9 @@ def as_embedding_model(cls: _T) -> _T:
return ModelForEmbedding # type: ignore
def as_classification_model(cls: _T) -> _T:
def as_seq_cls_model(cls: _T) -> _T:
"""
Subclass an existing vLLM model to support classification.
Subclass an existing vLLM model to support classify and score tasks.
By default, the class probabilities are extracted from the softmaxed
hidden state corresponding to the last token.
@ -164,7 +164,9 @@ def as_classification_model(cls: _T) -> _T:
# Lazy import
from vllm.config import VllmConfig
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.model_executor.layers.pooler import PoolingType
from vllm.model_executor.layers.pooler import PoolerOutput, PoolingType
from vllm.model_executor.models.interfaces import SupportsCrossEncoding
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors
from .utils import maybe_prefix
@ -176,7 +178,8 @@ def as_classification_model(cls: _T) -> _T:
default_softmax=True,
)
class ModelForClassification(ModelForPooling):
class ModelForSequenceClassification(ModelForPooling,
SupportsCrossEncoding):
def __init__(
self,
@ -190,6 +193,10 @@ def as_classification_model(cls: _T) -> _T:
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.task = vllm_config.model_config.task
self.pooling_type = (
vllm_config.model_config.pooler_config.pooling_type)
self.score = RowParallelLinear(config.hidden_size,
config.num_labels,
quant_config=quant_config,
@ -205,17 +212,41 @@ def as_classification_model(cls: _T) -> _T:
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden_states = super().forward(input_ids, positions,
intermediate_tensors,
inputs_embeds)
logits, _ = self.score(hidden_states)
return logits
return super().forward(input_ids, positions, intermediate_tensors,
inputs_embeds)
def pooler(
self,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
def get_logits(hidden_states):
if isinstance(hidden_states, list):
logits = [self.score(state)[0] for state in hidden_states]
else:
logits, _ = self.score(hidden_states)
return logits
if self.pooling_type == PoolingType.ALL:
logits = get_logits(hidden_states)
return self._pooler(logits, pooling_metadata)
else:
hidden_states = self._pooler.extract_states(
hidden_states, pooling_metadata)
logits = get_logits(hidden_states)
pooled_data = self._pooler.head(logits, pooling_metadata)
pooled_outputs = [
self._pooler.build_output(data) for data in pooled_data
]
return PoolerOutput(outputs=pooled_outputs)
ModelForClassification.__name__ = \
_get_pooling_model_name(cls.__name__, "ForClassification")
ModelForSequenceClassification.__name__ = \
_get_pooling_model_name(cls.__name__, "ForSequenceClassification")
return ModelForClassification # type: ignore
return ModelForSequenceClassification # type: ignore
def as_reward_model(cls: _T) -> _T:


@ -50,6 +50,7 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .adapters import as_seq_cls_model
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
is_pp_missing_parameter,
@ -495,3 +496,6 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights)
Qwen2ForSequenceClassification = as_seq_cls_model(Qwen2ForCausalLM)
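Because of this line, checkpoints that declare `Qwen2ForSequenceClassification` as their architecture load through the auto-converted class; a sketch with a hypothetical checkpoint path:

```python
from vllm import LLM

# "/path/to/qwen2-seq-cls" is a hypothetical local checkpoint whose
# config.json lists "Qwen2ForSequenceClassification" in its architectures.
llm = LLM(model="/path/to/qwen2-seq-cls", task="classify")
print(llm.classify(["example input"])[0].outputs.probs)
```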


@ -164,8 +164,6 @@ _EMBEDDING_MODELS = {
"LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
"Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501
# [Auto-converted (see adapters.py)]
"Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForCausalLM"),
# Technically PrithviGeoSpatialMAE is a model that works on images, both in
# input and output. I am adding it here because it piggy-backs on embedding
# models for the time being.
@ -180,7 +178,9 @@ _CROSS_ENCODER_MODELS = {
"RobertaForSequenceClassification"),
"ModernBertForSequenceClassification": ("modernbert",
"ModernBertForSequenceClassification"),
"Qwen3ForSequenceClassification": ("qwen3", "Qwen3ForSequenceClassification"), # noqa: E501
# [Auto-converted (see adapters.py)]
"Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForSequenceClassification"), # noqa: E501
"Qwen3ForSequenceClassification": ("qwen3", "Qwen3ForSequenceClassification"), # noqa: E501
}
_MULTIMODAL_MODELS = {


@ -453,6 +453,7 @@ class ClassificationOutput:
@staticmethod
def from_base(pooling_output: PoolingOutput):
# pooling_output shape: (num_classes)
pooled_data = pooling_output.data
if pooled_data.ndim != 1:
raise ValueError("pooled_data should be a 1-D probability vector")
@ -490,7 +491,10 @@ class ScoringOutput:
@staticmethod
def from_base(pooling_output: PoolingOutput):
pooled_data = pooling_output.data
# pooling_output shape:
# classify task: (num_classes,) with num_classes == 1
# embed task: a scalar value
pooled_data = pooling_output.data.squeeze()
if pooled_data.ndim != 0:
raise ValueError("pooled_data should be a scalar score")