[Model][2/N] Automatic conversion of CrossEncoding model (#19978)
Signed-off-by: wang.yuqi <noooop@126.com>
@@ -471,7 +471,7 @@ Specified using `--task classify`.
 | `JambaForSequenceClassification` | Jamba | `ai21labs/Jamba-tiny-reward-dev`, etc. | ✅︎ | ✅︎ | |
 | `GPT2ForSequenceClassification` | GPT2 | `nie3e/sentiment-polish-gpt2-small` | | | ✅︎ |
 
 If your model is not in the above list, we will try to automatically convert the model using
-[as_classification_model][vllm.model_executor.models.adapters.as_classification_model]. By default, the class probabilities are extracted from the softmaxed hidden state corresponding to the last token.
+[as_seq_cls_model][vllm.model_executor.models.adapters.as_seq_cls_model]. By default, the class probabilities are extracted from the softmaxed hidden state corresponding to the last token.
 
 #### Sentence Pair Scoring
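The auto-conversion described in the docs change above can be exercised offline; a minimal sketch, assuming a local GPU and an illustrative model choice (`LLM.classify` is the offline counterpart of the Classification API):

```python
# Minimal offline sketch (model choice illustrative): a sequence-
# classification checkpoint is loaded through the auto-conversion path.
from vllm import LLM

llm = LLM(model="jason9693/Qwen2.5-1.5B-apeach", task="classify")
(output,) = llm.classify("vLLM is wonderful!")
print(output.outputs.probs)  # softmaxed per-class probabilities
```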
@@ -426,7 +426,7 @@ Code example: <gh-file:examples/online_serving/openai_pooling_client.py>
 
 Our Classification API directly supports Hugging Face sequence-classification models such as [ai21labs/Jamba-tiny-reward-dev](https://huggingface.co/ai21labs/Jamba-tiny-reward-dev) and [jason9693/Qwen2.5-1.5B-apeach](https://huggingface.co/jason9693/Qwen2.5-1.5B-apeach).
 
-We automatically wrap any other transformer via `as_classification_model()`, which pools on the last token, attaches a `RowParallelLinear` head, and applies a softmax to produce per-class probabilities.
+We automatically wrap any other transformer via `as_seq_cls_model()`, which pools on the last token, attaches a `RowParallelLinear` head, and applies a softmax to produce per-class probabilities.
 
 Code example: <gh-file:examples/online_serving/openai_classification_client.py>
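A minimal online sketch of that API, assuming a local server started with `vllm serve jason9693/Qwen2.5-1.5B-apeach` (see the linked client example for the full version):

```python
# Hypothetical request against a locally served classification model.
import requests

response = requests.post("http://localhost:8000/classify", json={
    "model": "jason9693/Qwen2.5-1.5B-apeach",
    "input": ["vLLM is wonderful!", "This is terrible."],
})
print(response.json())  # per-input class probabilities
```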
@@ -6,19 +6,16 @@ import pytest
 
 # yapf conflicts with isort for this block
 # yapf: disable
-from tests.models.language.pooling.mteb_utils import (MTEB_RERANK_LANGS,
-                                                      MTEB_RERANK_TASKS,
-                                                      MTEB_RERANK_TOL,
-                                                      RerankClientMtebEncoder,
-                                                      ScoreClientMtebEncoder,
-                                                      run_mteb_rerank)
+from tests.models.language.pooling.mteb_utils import (
+    MTEB_RERANK_LANGS, MTEB_RERANK_TASKS, MTEB_RERANK_TOL,
+    RerankClientMtebEncoder, ScoreClientMtebEncoder,
+    mteb_test_rerank_models_hf, run_mteb_rerank)
 # yapf: enable
 from tests.utils import RemoteOpenAIServer
 
 os.environ["VLLM_LOGGING_LEVEL"] = "WARNING"
 
 MODEL_NAME = "cross-encoder/ms-marco-MiniLM-L-6-v2"
-MAIN_SCORE = 0.33437
 
 
 @pytest.fixture(scope="module")
@@ -31,12 +28,19 @@ def server():
         yield remote_server
 
 
-def test_mteb_score(server):
+@pytest.fixture(scope="module")
+def st_main_score(hf_runner):
+    # The main score depends on the version of the dependency,
+    # so we need to recalculate it every time.
+    main_score, st_dtype = mteb_test_rerank_models_hf(hf_runner, MODEL_NAME)
+    return main_score
+
+
+def test_mteb_score(server, st_main_score):
     url = server.url_for("score")
     encoder = ScoreClientMtebEncoder(MODEL_NAME, url)
     vllm_main_score = run_mteb_rerank(encoder, MTEB_RERANK_TASKS,
                                       MTEB_RERANK_LANGS)
-    st_main_score = MAIN_SCORE
 
     print("VLLM main score: ", vllm_main_score)
     print("SentenceTransformer main score: ", st_main_score)
@@ -45,12 +49,11 @@ def test_mteb_score(server):
     assert st_main_score == pytest.approx(vllm_main_score, abs=MTEB_RERANK_TOL)
 
 
-def test_mteb_rerank(server):
+def test_mteb_rerank(server, st_main_score):
     url = server.url_for("rerank")
     encoder = RerankClientMtebEncoder(MODEL_NAME, url)
     vllm_main_score = run_mteb_rerank(encoder, MTEB_RERANK_TASKS,
                                       MTEB_RERANK_LANGS)
-    st_main_score = MAIN_SCORE
 
     print("VLLM main score: ", vllm_main_score)
     print("SentenceTransformer main score: ", st_main_score)
@@ -234,6 +234,35 @@ def run_mteb_rerank(cross_encoder, tasks, languages):
     return main_score
 
 
+def mteb_test_rerank_models_hf(hf_runner, model_name, hf_model_callback=None):
+    with hf_runner(model_name, is_cross_encoder=True,
+                   dtype="float32") as hf_model:
+
+        original_predict = hf_model.predict
+
+        def _predict(
+            sentences: list[tuple[str, str,
+                                  Optional[str]]],  # query, corpus, prompt
+            *args,
+            **kwargs,
+        ):
+            # vllm and st both remove the prompt, fair comparison.
+            prompts = [(s[0], s[1]) for s in sentences]
+            return original_predict(prompts, *args, **kwargs, batch_size=8)
+
+        hf_model.predict = _predict
+        hf_model.original_predict = original_predict
+
+        if hf_model_callback is not None:
+            hf_model_callback(hf_model)
+
+        st_main_score = run_mteb_rerank(hf_model,
+                                        tasks=MTEB_RERANK_TASKS,
+                                        languages=MTEB_RERANK_LANGS)
+        st_dtype = next(hf_model.model.model.parameters()).dtype
+
+    return st_main_score, st_dtype
+
+
 def mteb_test_rerank_models(hf_runner,
                             vllm_runner,
                             model_info: RerankModelInfo,
@@ -264,31 +293,8 @@ def mteb_test_rerank_models(hf_runner,
                                           languages=MTEB_RERANK_LANGS)
         vllm_dtype = model_config.dtype
 
-    with hf_runner(model_info.name, is_cross_encoder=True,
-                   dtype="float32") as hf_model:
-
-        original_predict = hf_model.predict
-
-        def _predict(
-            sentences: list[tuple[str, str,
-                                  Optional[str]]],  # query, corpus, prompt
-            *args,
-            **kwargs,
-        ):
-            # vllm and st both remove the prompt, fair comparison.
-            prompts = [(s[0], s[1]) for s in sentences]
-            return original_predict(prompts, *args, **kwargs, batch_size=8)
-
-        hf_model.predict = _predict
-        hf_model.original_predict = original_predict
-
-        if hf_model_callback is not None:
-            hf_model_callback(hf_model)
-
-        st_main_score = run_mteb_rerank(hf_model,
-                                        tasks=MTEB_RERANK_TASKS,
-                                        languages=MTEB_RERANK_LANGS)
-        st_dtype = next(hf_model.model.model.parameters()).dtype
+    st_main_score, st_dtype = mteb_test_rerank_models_hf(
+        hf_runner, model_info.name, hf_model_callback)
 
     print("VLLM:", vllm_dtype, vllm_main_score)
     print("SentenceTransformers:", st_dtype, st_main_score)
@@ -9,9 +9,9 @@ import torch.cuda
 from vllm.model_executor.models import (is_pooling_model,
                                         is_text_generation_model,
                                         supports_multimodal)
-from vllm.model_executor.models.adapters import (as_classification_model,
-                                                 as_embedding_model,
-                                                 as_reward_model)
+from vllm.model_executor.models.adapters import (as_embedding_model,
                                                  as_reward_model,
+                                                 as_seq_cls_model)
 from vllm.model_executor.models.registry import (_MULTIMODAL_MODELS,
                                                  _SPECULATIVE_DECODING_MODELS,
                                                  _TEXT_GENERATION_MODELS,
@@ -38,7 +38,7 @@ def test_registry_imports(model_arch):
         assert is_text_generation_model(model_cls)
 
     # All vLLM models should be convertible to a pooling model
-    assert is_pooling_model(as_classification_model(model_cls))
+    assert is_pooling_model(as_seq_cls_model(model_cls))
     assert is_pooling_model(as_embedding_model(model_cls))
     assert is_pooling_model(as_reward_model(model_cls))
@@ -52,7 +52,7 @@ def test_get_field():
         ("distilbert/distilgpt2", "generate", "generate"),
         ("intfloat/multilingual-e5-small", "pooling", "embed"),
         ("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify"),
-        ("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "score"),
+        ("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "classify"),
         ("Qwen/Qwen2.5-Math-RM-72B", "pooling", "reward"),
         ("openai/whisper-small", "transcription", "transcription"),
     ],
@@ -72,6 +72,32 @@ def test_auto_task(model_id, expected_runner_type, expected_task):
     assert config.task == expected_task
 
 
+@pytest.mark.parametrize(
+    ("model_id", "expected_runner_type", "expected_task"),
+    [
+        ("distilbert/distilgpt2", "pooling", "embed"),
+        ("intfloat/multilingual-e5-small", "pooling", "embed"),
+        ("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify"),
+        ("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "classify"),
+        ("Qwen/Qwen2.5-Math-RM-72B", "pooling", "embed"),
+        ("openai/whisper-small", "pooling", "embed"),
+    ],
+)
+def test_score_task(model_id, expected_runner_type, expected_task):
+    config = ModelConfig(
+        model_id,
+        task="score",
+        tokenizer=model_id,
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        seed=0,
+        dtype="float16",
+    )
+
+    assert config.runner_type == expected_runner_type
+    assert config.task == expected_task
+
+
 @pytest.mark.parametrize(("model_id", "bad_task"), [
     ("Qwen/Qwen2.5-Math-RM-72B", "generate"),
 ])
@@ -93,14 +93,14 @@ ConfigT = TypeVar("ConfigT", bound=ConfigType)
 TaskOption = Literal["auto", "generate", "embedding", "embed", "classify",
                      "score", "reward", "transcription"]
 
-_ResolvedTask = Literal["generate", "embed", "classify", "score", "reward",
-                        "draft", "transcription"]
+_ResolvedTask = Literal["generate", "embed", "classify", "reward", "draft",
+                        "transcription"]
 
 RunnerType = Literal["generate", "pooling", "draft", "transcription"]
 
 _RUNNER_TASKS: dict[RunnerType, list[_ResolvedTask]] = {
     "generate": ["generate"],
-    "pooling": ["embed", "classify", "score", "reward"],
+    "pooling": ["embed", "classify", "reward"],
     "draft": ["draft"],
     "transcription": ["transcription"],
 }
@@ -777,7 +777,7 @@ class ModelConfig:
             if get_pooling_config(model_id, self.revision):
                 return "embed"
             if self.registry.is_cross_encoder_model(architectures):
-                return "score"
+                return "classify"
             if self.registry.is_transcription_model(architectures):
                 return "transcription"
@@ -841,14 +841,24 @@ class ModelConfig:
                 "This model supports multiple tasks: %s. "
                 "Defaulting to '%s'.", supported_tasks, selected_task)
         else:
-            # Aliases
-            if task_option == "embedding":
-                msg = ("The 'embedding' task has been renamed to "
-                       "'embed', please use the new name. The old name "
-                       "will be removed in v1.0.")
-                warnings.warn(msg, DeprecationWarning, stacklevel=2)
+            if task_option == "score":
+                if not runner_support["pooling"]:
+                    msg = (f"This model does not support the '{task_option}' "
+                           f"task. Supported tasks: {supported_tasks}")
+                    raise ValueError(msg)
+                if self.registry.is_cross_encoder_model(architectures):
+                    task_option = "classify"
+                else:
+                    task_option = "embed"
+            else:
+                # Aliases
+                if task_option == "embedding":
+                    msg = ("The 'embedding' task has been renamed to "
+                           "'embed', please use the new name. The old name "
+                           "will be removed in v1.0.")
+                    warnings.warn(msg, DeprecationWarning, stacklevel=2)
 
-                task_option = "embed"
+                    task_option = "embed"
 
         if task_option not in supported_tasks:
             msg = (
@@ -1289,9 +1289,13 @@ class LLM:
 
             raise ValueError(" ".join(messages))
 
-        if self.llm_engine.model_config.task not in ("embed", "score"):
-            raise ValueError(
-                "Score API is only enabled for `--task embed or --task score`")
+        if self.llm_engine.model_config.task not in ("embed", "classify"):
+            raise ValueError("Score API is only enabled for "
+                             "`--task embed or --task classify`.")
+
+        if (self.llm_engine.model_config.task == "classify"
+                and self.llm_engine.model_config.hf_config.num_labels != 1):
+            raise ValueError("Score API is only enabled for num_labels == 1.")
 
         # the tokenizer for models such as
         # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
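A minimal offline sketch of what the new checks permit: a cross-encoder now resolves to `task="classify"` with `num_labels == 1`, and `LLM.score` remains available for it (model choice illustrative):

```python
# Offline sketch: score query/document pairs with an auto-converted
# cross-encoder; each pair yields one scalar relevance score.
from vllm import LLM

llm = LLM(model="cross-encoder/ms-marco-MiniLM-L-6-v2", task="score")
outputs = llm.score("What is the capital of France?",
                    ["Paris is the capital of France.",
                     "The Eiffel Tower is in Paris."])
for output in outputs:
    print(output.outputs.score)
```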
@@ -1311,24 +1311,27 @@ async def init_app_state(
         chat_template=resolved_chat_template,
         chat_template_content_format=args.chat_template_content_format,
     ) if model_config.task == "embed" else None
-    state.openai_serving_scores = ServingScores(
-        engine_client,
-        model_config,
-        state.openai_serving_models,
-        request_logger=request_logger) if model_config.task in (
-            "score", "embed", "pooling") else None
     state.openai_serving_classification = ServingClassification(
         engine_client,
         model_config,
         state.openai_serving_models,
         request_logger=request_logger,
     ) if model_config.task == "classify" else None
+
+    enable_serving_reranking = (model_config.task == "classify" and getattr(
+        model_config.hf_config, "num_labels", 0) == 1)
     state.jinaai_serving_reranking = ServingScores(
         engine_client,
         model_config,
         state.openai_serving_models,
-        request_logger=request_logger
-    ) if model_config.task == "score" else None
+        request_logger=request_logger) if enable_serving_reranking else None
+    state.openai_serving_scores = ServingScores(
+        engine_client,
+        model_config,
+        state.openai_serving_models,
+        request_logger=request_logger) if (
+            model_config.task == "embed" or enable_serving_reranking) else None
 
     state.openai_serving_tokenization = OpenAIServingTokenization(
         engine_client,
         model_config,
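A client-side sketch of the serving change: with a single-label classification model, both the score and rerank endpoints are now served (assuming a local `vllm serve cross-encoder/ms-marco-MiniLM-L-6-v2` instance):

```python
# Hypothetical requests against the /score and /rerank endpoints.
import requests

base_url = "http://localhost:8000"
print(requests.post(f"{base_url}/score", json={
    "model": "cross-encoder/ms-marco-MiniLM-L-6-v2",
    "text_1": "What is the capital of France?",
    "text_2": ["Paris is the capital of France.",
               "The Eiffel Tower is in Paris."],
}).json())

print(requests.post(f"{base_url}/rerank", json={
    "model": "cross-encoder/ms-marco-MiniLM-L-6-v2",
    "query": "What is the capital of France?",
    "documents": ["Paris is the capital of France.",
                  "The Eiffel Tower is in Paris."],
}).json())
```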
@@ -357,12 +357,16 @@ async def main(args):
         chat_template=None,
         chat_template_content_format="auto",
     ) if model_config.task == "embed" else None
+
+    enable_serving_reranking = (model_config.task == "classify" and getattr(
+        model_config.hf_config, "num_labels", 0) == 1)
+
     openai_serving_scores = (ServingScores(
         engine,
         model_config,
         openai_serving_models,
         request_logger=request_logger,
-    ) if model_config.task == "score" else None)
+    ) if (model_config.task == "embed" or enable_serving_reranking) else None)
 
     tracker = BatchProgressTracker()
     logger.info("Reading batch from %s...", args.input_file)
@@ -285,6 +285,7 @@ class PoolerHead(nn.Module):
         else:
             pooled_data = pooled_data.to(torch.float32)
 
+        # for matryoshka representation
         if isinstance(pooling_metadata, V0PoolingMetadata):
             dimensions_list = [
                 pooling_param.dimensions
@@ -299,10 +300,16 @@ class PoolerHead(nn.Module):
         if any(d is not None for d in dimensions_list):
             # change the output dimension
             assert len(pooled_data) == len(dimensions_list)
-            pooled_data = [
-                vecs if d is None else vecs[..., :d]
-                for vecs, d in zip(pooled_data, dimensions_list)
-            ]
+            if len(set(dimensions_list)) == 1 and not isinstance(
+                    pooled_data, list):
+                # if all dimensions are the same
+                d = dimensions_list[0]
+                pooled_data = pooled_data[..., :d]
+            else:
+                pooled_data = [
+                    vecs if d is None else vecs[..., :d]
+                    for vecs, d in zip(pooled_data, dimensions_list)
+                ]
 
         if self.normalize:
             if isinstance(pooled_data, list):
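A standalone sketch of the fast path added above: when every request asks for the same matryoshka dimension, the batched tensor is sliced once instead of row by row (tensor shapes illustrative):

```python
# Identical matryoshka dimensions allow a single tensor slice
# instead of building a per-row list.
import torch

pooled_data = torch.randn(4, 1024)       # (batch_size, embedding_dim)
dimensions_list = [256, 256, 256, 256]   # one requested dimension per row

if len(set(dimensions_list)) == 1 and not isinstance(pooled_data, list):
    # if all dimensions are the same
    pooled_data = pooled_data[..., :dimensions_list[0]]
else:
    pooled_data = [
        vecs if d is None else vecs[..., :d]
        for vecs, d in zip(pooled_data, dimensions_list)
    ]

print(pooled_data.shape)  # torch.Size([4, 256])
```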
@@ -325,6 +332,10 @@ class PoolerHead(nn.Module):
             else:
                 pooled_data = F.sigmoid(pooled_data)
 
+        # shape:
+        # classify (& score) -> (batch_size, num_classes)
+        # embed -> (batch_size, embedding_dim) or list(embedding_dim)
+        #          (batch_size, dimensions) or list(dimensions) if using MRL
         return pooled_data
@@ -419,7 +430,6 @@ class ClassifierPooler(nn.Module):
             offset += prompt_len
             pooled_data.append(pooled_data_i)
 
-        offset = 0
         pooled_data_lst = []
         for pooled_data_i in pooled_data:
@@ -436,7 +446,8 @@ class ClassifierPooler(nn.Module):
         # apply classifier once on the full batch if possible
         pooled_output = self.classifier(pooled_output)
 
-        scores = self.default_activation_function(pooled_output).squeeze(-1)
+        # shape: (batch_size, num_labels)
+        scores = self.default_activation_function(pooled_output)
 
         pooled_outputs = [PoolingSequenceGroupOutput(data) for data in scores]
         return PoolerOutput(outputs=pooled_outputs)
@@ -21,8 +21,7 @@ from vllm.model_executor.layers.linear import QKVCrossParallelLinear
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.models import ModelRegistry
-from vllm.model_executor.models.adapters import (as_classification_model,
-                                                 as_embedding_model,
+from vllm.model_executor.models.adapters import (as_embedding_model,
                                                  as_reward_model)
 from vllm.model_executor.models.interfaces import SupportsQuant
 from vllm.utils import is_pin_memory_available
@@ -246,7 +245,9 @@ def get_model_architecture(
     if model_config.task == "embed":
         model_cls = as_embedding_model(model_cls)
     elif model_config.task == "classify":
-        model_cls = as_classification_model(model_cls)
+        # Cannot automatically run as_seq_cls_model here;
+        # it would cause a circular reference with is_cross_encoder_model.
+        pass
     elif model_config.task == "reward":
         model_cls = as_reward_model(model_cls)
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from collections.abc import Iterable
-from typing import TYPE_CHECKING, Any, Optional, TypeVar
+from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
 
 import torch
 import torch.nn as nn
@@ -145,9 +145,9 @@ def as_embedding_model(cls: _T) -> _T:
     return ModelForEmbedding  # type: ignore
 
 
-def as_classification_model(cls: _T) -> _T:
+def as_seq_cls_model(cls: _T) -> _T:
     """
-    Subclass an existing vLLM model to support classification.
+    Subclass an existing vLLM model to support classify and score tasks.
 
     By default, the class probabilities are extracted from the softmaxed
     hidden state corresponding to the last token.
@@ -164,7 +164,9 @@ def as_seq_cls_model(cls: _T) -> _T:
     # Lazy import
     from vllm.config import VllmConfig
     from vllm.model_executor.layers.linear import RowParallelLinear
-    from vllm.model_executor.layers.pooler import PoolingType
+    from vllm.model_executor.layers.pooler import PoolerOutput, PoolingType
+    from vllm.model_executor.models.interfaces import SupportsCrossEncoding
+    from vllm.model_executor.pooling_metadata import PoolingMetadata
     from vllm.sequence import IntermediateTensors
 
     from .utils import maybe_prefix
@@ -176,7 +178,8 @@ def as_seq_cls_model(cls: _T) -> _T:
         default_softmax=True,
     )
 
-    class ModelForClassification(ModelForPooling):
+    class ModelForSequenceClassification(ModelForPooling,
+                                         SupportsCrossEncoding):
 
         def __init__(
             self,
@@ -190,6 +193,10 @@ def as_seq_cls_model(cls: _T) -> _T:
             config = vllm_config.model_config.hf_config
             quant_config = vllm_config.quant_config
 
+            self.task = vllm_config.model_config.task
+            self.pooling_type = (
+                vllm_config.model_config.pooler_config.pooling_type)
+
             self.score = RowParallelLinear(config.hidden_size,
                                            config.num_labels,
                                            quant_config=quant_config,
@@ -205,17 +212,41 @@ def as_seq_cls_model(cls: _T) -> _T:
             intermediate_tensors: Optional[IntermediateTensors] = None,
             inputs_embeds: Optional[torch.Tensor] = None,
         ) -> torch.Tensor:
-            hidden_states = super().forward(input_ids, positions,
-                                            intermediate_tensors,
-                                            inputs_embeds)
-            logits, _ = self.score(hidden_states)
-            return logits
+            return super().forward(input_ids, positions, intermediate_tensors,
+                                   inputs_embeds)
+
+        def pooler(
+            self,
+            hidden_states: Union[torch.Tensor, list[torch.Tensor]],
+            pooling_metadata: PoolingMetadata,
+        ) -> PoolerOutput:
+
+            def get_logits(hidden_states):
+                if isinstance(hidden_states, list):
+                    logits = [self.score(state)[0] for state in hidden_states]
+                else:
+                    logits, _ = self.score(hidden_states)
+                return logits
+
+            if self.pooling_type == PoolingType.ALL:
+                logits = get_logits(hidden_states)
+                return self._pooler(logits, pooling_metadata)
+            else:
+                hidden_states = self._pooler.extract_states(
+                    hidden_states, pooling_metadata)
+                logits = get_logits(hidden_states)
+                pooled_data = self._pooler.head(logits, pooling_metadata)
+
+                pooled_outputs = [
+                    self._pooler.build_output(data) for data in pooled_data
+                ]
+                return PoolerOutput(outputs=pooled_outputs)
 
-    ModelForClassification.__name__ = \
-        _get_pooling_model_name(cls.__name__, "ForClassification")
+    ModelForSequenceClassification.__name__ = \
+        _get_pooling_model_name(cls.__name__, "ForSequenceClassification")
 
-    return ModelForClassification  # type: ignore
+    return ModelForSequenceClassification  # type: ignore
 
 
 def as_reward_model(cls: _T) -> _T:
@@ -50,6 +50,7 @@ from vllm.model_executor.model_loader.weight_utils import (
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
+from .adapters import as_seq_cls_model
 from .interfaces import SupportsLoRA, SupportsPP
 from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
                     is_pp_missing_parameter,
@@ -495,3 +496,6 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
                 if self.config.tie_word_embeddings else None),
         )
         return loader.load_weights(weights)
+
+
+Qwen2ForSequenceClassification = as_seq_cls_model(Qwen2ForCausalLM)
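A sketch of the opt-in pattern this hunk establishes: any registered causal LM can expose an auto-converted sequence-classification variant the same way (the `Qwen2` wrap below mirrors the line added above; the assertion relies on `is_pooling_model` shown earlier in this commit):

```python
# Wrap a causal LM class and check the converted class is a pooling model.
from vllm.model_executor.models import is_pooling_model
from vllm.model_executor.models.adapters import as_seq_cls_model
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM

Qwen2ForSequenceClassification = as_seq_cls_model(Qwen2ForCausalLM)
assert is_pooling_model(Qwen2ForSequenceClassification)
print(Qwen2ForSequenceClassification.__name__)
```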
@@ -164,8 +164,6 @@ _EMBEDDING_MODELS = {
     "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
     "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
     "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
-    # [Auto-converted (see adapters.py)]
-    "Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForCausalLM"),
     # Technically PrithviGeoSpatialMAE is a model that works on images, both in
     # input and output. I am adding it here because it piggy-backs on embedding
     # models for the time being.
@@ -180,7 +178,9 @@ _CROSS_ENCODER_MODELS = {
                                          "RobertaForSequenceClassification"),
     "ModernBertForSequenceClassification": ("modernbert",
                                             "ModernBertForSequenceClassification"),
-    "Qwen3ForSequenceClassification": ("qwen3", "Qwen3ForSequenceClassification"),  # noqa: E501
+    # [Auto-converted (see adapters.py)]
+    "Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForSequenceClassification"),  # noqa: E501
+    "Qwen3ForSequenceClassification": ("qwen3", "Qwen3ForSequenceClassification"),  # noqa: E501
 }
 
 _MULTIMODAL_MODELS = {
@@ -453,6 +453,7 @@ class ClassificationOutput:
 
     @staticmethod
     def from_base(pooling_output: PoolingOutput):
+        # pooling_output shape: (num_classes)
        pooled_data = pooling_output.data
        if pooled_data.ndim != 1:
            raise ValueError("pooled_data should be a 1-D probability vector")
@@ -490,7 +491,10 @@ class ScoringOutput:
 
     @staticmethod
     def from_base(pooling_output: PoolingOutput):
-        pooled_data = pooling_output.data
+        # pooling_output shape:
+        # classify task: (num_classes) num_classes == 1
+        # embed task: a scalar value
+        pooled_data = pooling_output.data.squeeze()
         if pooled_data.ndim != 0:
             raise ValueError("pooled_data should be a scalar score")