[Bugfix] Add use_cross_encoder flag to use correct activation in ClassifierPooler (#20527)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Date: 2025-07-07 05:01:48 +08:00
Committed by: GitHub
Parent: 9528e3a05e
Commit: c18b3b8e8b
8 changed files with 56 additions and 41 deletions
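In short: ClassifierPooler previously chose its activation function once at construction time from config.task, so a pooling model could not apply different activations to scoring and plain classification requests. After this change every PoolingParams carries a use_cross_encoder flag, the score and rerank entrypoints set it to True, and ClassifierPooler resolves the activation per request: the cross-encoder activation derived from the HF config for scoring, and Sigmoid (num_labels == 1) or Softmax otherwise for classification.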


@@ -1204,7 +1204,7 @@ class LLM:
         input_pairs = [(t1, t2) for t1, t2 in zip(text_1, text_2)]
-        pooling_params = PoolingParams()
+        pooling_params = PoolingParams(use_cross_encoder=True)
         tokenization_kwargs: dict[str, Any] = {}
         _validate_truncation_size(self.llm_engine.model_config.max_model_len,
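For orientation, a hedged usage sketch of the offline scoring path that runs through this hunk; the checkpoint name is only an example of a cross-encoder sequence-classification model:

from vllm import LLM

llm = LLM(model="cross-encoder/ms-marco-MiniLM-L-6-v2", task="score")

# LLM.score() now builds PoolingParams(use_cross_encoder=True) internally,
# so the cross-encoder activation is applied to the classifier output.
outputs = llm.score("What is the capital of France?",
                    ["The capital of France is Paris.",
                     "The Eiffel Tower is in Paris."])
for output in outputs:
    print(output.outputs.score)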


@@ -1156,8 +1156,9 @@ class ScoreRequest(OpenAIBaseModel):
     # --8<-- [end:score-extra-params]

-    def to_pooling_params(self):
-        return PoolingParams(additional_data=self.additional_data)
+    def to_pooling_params(self, *, use_cross_encoder: bool = False):
+        return PoolingParams(use_cross_encoder=use_cross_encoder,
+                             additional_data=self.additional_data)

 class RerankRequest(OpenAIBaseModel):
@@ -1182,8 +1183,9 @@ class RerankRequest(OpenAIBaseModel):
     # --8<-- [end:rerank-extra-params]

-    def to_pooling_params(self):
-        return PoolingParams(additional_data=self.additional_data)
+    def to_pooling_params(self, *, use_cross_encoder: bool = False):
+        return PoolingParams(use_cross_encoder=use_cross_encoder,
+                             additional_data=self.additional_data)

 class RerankDocument(BaseModel):
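A minimal sketch of the new keyword-only flag, assuming text_1 and text_2 are the only fields that must be set on the request; it is the serving layer, not the client, that decides whether the cross-encoder path applies:

from vllm.entrypoints.openai.protocol import ScoreRequest

request = ScoreRequest(text_1="what is panda?",
                       text_2="The giant panda is a bear native to China.")

# Defaults to False, so plain classification requests keep the old behavior.
pooling_params = request.to_pooling_params(use_cross_encoder=True)
print(pooling_params.use_cross_encoder)  # True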


@@ -25,9 +25,7 @@ from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
-                                               PreTrainedTokenizer,
-                                               PreTrainedTokenizerFast)
+from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
 from vllm.utils import make_async, merge_async_iterators

 logger = init_logger(__name__)
@@ -50,7 +48,7 @@ class ServingScores(OpenAIServing):
     async def _embedding_score(
         self,
-        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+        tokenizer: AnyTokenizer,
         texts_1: list[str],
         texts_2: list[str],
         request: Union[RerankRequest, ScoreRequest],
@@ -141,7 +139,7 @@ class ServingScores(OpenAIServing):
     async def _cross_encoding_score(
         self,
-        tokenizer: Union[AnyTokenizer],
+        tokenizer: AnyTokenizer,
         texts_1: list[str],
         texts_2: list[str],
         request: Union[RerankRequest, ScoreRequest],
@@ -190,7 +188,7 @@ class ServingScores(OpenAIServing):
         # Schedule the request and get the result generator.
         generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
-        pooling_params = request.to_pooling_params()
+        pooling_params = request.to_pooling_params(use_cross_encoder=True)
         for i, engine_prompt in enumerate(engine_prompts):
             request_id_item = f"{request_id}-{i}"
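For reference, a hedged client-side sketch of the score endpoint served by this class, which now always requests the cross-encoder activation; the URL, port, and model name are assumptions for a default local deployment:

import requests

response = requests.post(
    "http://localhost:8000/score",
    json={
        "model": "cross-encoder/ms-marco-MiniLM-L-6-v2",
        "text_1": "What is the capital of France?",
        "text_2": ["The capital of France is Paris."],
    },
)
# Each entry carries the activated classifier score for one text pair.
print(response.json())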


@@ -15,6 +15,7 @@ from vllm.model_executor.pooling_metadata import ( # noqa: E501
 from vllm.model_executor.pooling_metadata import PoolingTensors
 from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
 from vllm.transformers_utils.config import (
+    get_classification_activation_function,
     get_cross_encoder_activation_function)
 from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata
@@ -388,15 +389,14 @@ class ClassifierPooler(nn.Module):
         self.classifier = classifier
         self.pooler = pooler

-        if config.task == "score":
-            self.default_activation_function = \
-                get_cross_encoder_activation_function(config.hf_config)
-        elif config.task == "classify":
-            self.default_activation_function = nn.Sigmoid() \
-                if config.hf_config.num_labels == 1 else nn.Softmax()
-        else:
-            raise NotImplementedError(f"task={config.task!r} is not supported"
-                                      " with the classification pooler")
+        self.classification_act_fn = get_classification_activation_function(
+            config.hf_config)
+        self.cross_encoder_act_fn = get_cross_encoder_activation_function(
+            config.hf_config)
+
+    def _get_act_fn(self, use_cross_encoder: bool):
+        return (self.cross_encoder_act_fn
+                if use_cross_encoder else self.classification_act_fn)

     def get_prompt_lens(
         self,
@@ -446,8 +446,28 @@ class ClassifierPooler(nn.Module):
         # apply classifier once on the full batch if possible
         pooled_output = self.classifier(pooled_output)

-        # shape: (batch_size, num_labels)
-        scores = self.default_activation_function(pooled_output)
+        if isinstance(pooling_metadata, V0PoolingMetadata):
+            use_cross_encoder_list = [
+                pooling_param.use_cross_encoder
+                for _, pooling_param in pooling_metadata.seq_groups
+            ]
+        else:
+            use_cross_encoder_list = [
+                pooling_param.use_cross_encoder
+                for pooling_param in pooling_metadata.pooling_params
+            ]
+
+        # shape of scores: (batch_size, num_labels)
+        if all(use_cross_encoder == use_cross_encoder_list[0]
+               for use_cross_encoder in use_cross_encoder_list):
+            act_fn = self._get_act_fn(use_cross_encoder_list[0])
+            scores = act_fn(pooled_output)
+        else:
+            scores = torch.stack([
+                self._get_act_fn(use_cross_encoder)(vecs)
+                for use_cross_encoder, vecs in zip(use_cross_encoder_list,
+                                                   pooled_output)
+            ])

         pooled_outputs = [PoolingSequenceGroupOutput(data) for data in scores]
         return PoolerOutput(outputs=pooled_outputs)
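A standalone sketch of the per-request activation selection introduced above; the tensors and flags are invented, and the two modules stand in for cross_encoder_act_fn and classification_act_fn:

import torch
import torch.nn as nn

cross_encoder_act_fn = nn.Sigmoid()         # e.g. resolved from the HF config
classification_act_fn = nn.Softmax(dim=-1)  # num_labels > 1 case

def get_act_fn(use_cross_encoder: bool) -> nn.Module:
    return cross_encoder_act_fn if use_cross_encoder else classification_act_fn

pooled_output = torch.randn(3, 2)           # (batch_size, num_labels)
use_cross_encoder_list = [True, False, True]

if all(flag == use_cross_encoder_list[0] for flag in use_cross_encoder_list):
    # Homogeneous batch: one activation over the whole tensor.
    scores = get_act_fn(use_cross_encoder_list[0])(pooled_output)
else:
    # Mixed batch: choose the activation row by row, then restack.
    scores = torch.stack([
        get_act_fn(flag)(vecs)
        for flag, vecs in zip(use_cross_encoder_list, pooled_output)
    ])

print(scores.shape)  # torch.Size([3, 2])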


@@ -25,8 +25,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.sequence import IntermediateTensors, PoolerOutput
-from vllm.transformers_utils.config import (
-    get_cross_encoder_activation_function)

 from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only
 from .utils import WeightsMapper, maybe_prefix
@@ -462,9 +460,6 @@ class BertForSequenceClassification(nn.Module, SupportsV0Only,
         super().__init__()
         config = vllm_config.model_config.hf_config

-        self.default_activation_function = \
-            get_cross_encoder_activation_function(config)
-
         self.num_labels = config.num_labels
         self.bert = BertModel(vllm_config=vllm_config,
                               prefix=maybe_prefix(prefix, "bert"),


@@ -18,8 +18,6 @@ from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel
 from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix
 from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.sequence import IntermediateTensors, PoolerOutput
-from vllm.transformers_utils.config import (
-    get_cross_encoder_activation_function)

 from .bert_with_rope import BertWithRope, JinaRobertaModel
 from .interfaces import SupportsCrossEncoding, SupportsV0Only
@@ -178,9 +176,6 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
         super().__init__()
         config = vllm_config.model_config.hf_config

-        self.default_activation_function = \
-            get_cross_encoder_activation_function(config)
-
         self.num_labels = config.num_labels
         self.roberta = BertModel(vllm_config=vllm_config,
                                  prefix=maybe_prefix(prefix, "bert"),


@@ -24,12 +24,14 @@ class PoolingParams(
     """

     dimensions: Optional[int] = None
+    use_cross_encoder: bool = False
     additional_data: Optional[Any] = None
     output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY

     def clone(self) -> "PoolingParams":
         """Returns a deep copy of the PoolingParams instance."""
         return PoolingParams(dimensions=self.dimensions,
+                             use_cross_encoder=self.use_cross_encoder,
                              additional_data=self.additional_data)

     def verify(self, model_config: "ModelConfig") -> None:
@@ -54,6 +56,7 @@ class PoolingParams(
     def __repr__(self) -> str:
         return (f"PoolingParams("
                 f"dimensions={self.dimensions}, "
+                f"use_cross_encoder={self.use_cross_encoder}, "
                 f"additional_metadata={self.additional_data})")

     def __post_init__(self) -> None:
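A quick sketch of the new field on PoolingParams; as the hunks above show, clone() and __repr__ now carry the flag as well:

from vllm.pooling_params import PoolingParams

params = PoolingParams(use_cross_encoder=True)
copy = params.clone()

print(copy.use_cross_encoder)  # True, the flag survives clone()
print(params)                  # repr now includes use_cross_encoder=True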


@@ -866,24 +866,26 @@ def try_get_generation_config(
     return None


+def get_classification_activation_function(config: PretrainedConfig):
+    return nn.Sigmoid() if config.num_labels == 1 else nn.Softmax()
+
+
 def get_cross_encoder_activation_function(config: PretrainedConfig):
     function_name: Optional[str] = None
-    if hasattr(config, "sentence_transformers") and "activation_fn" in \
-            config.sentence_transformers:
+    if (hasattr(config, "sentence_transformers")
+            and "activation_fn" in config.sentence_transformers):
         function_name = config.sentence_transformers["activation_fn"]
     elif (hasattr(config, "sbert_ce_default_activation_function")
           and config.sbert_ce_default_activation_function is not None):
         function_name = config.sbert_ce_default_activation_function

     if function_name is not None:
-        assert function_name.startswith("torch.nn.modules."), \
-            "Loading of activation functions is restricted to " \
-            "torch.nn.modules for security reasons"
+        assert function_name.startswith("torch.nn.modules."), (
+            "Loading of activation functions is restricted to "
+            "torch.nn.modules for security reasons")
         return resolve_obj_by_qualname(function_name)()
-    else:
-        return nn.Sigmoid() if config.num_labels == 1 else nn.Identity()
+
+    return nn.Sigmoid() if config.num_labels == 1 else nn.Identity()


 def try_get_safetensors_metadata(
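A hedged sketch of how the two helpers behave, using bare PretrainedConfig objects in place of real model configs:

from transformers import PretrainedConfig

from vllm.transformers_utils.config import (
    get_classification_activation_function,
    get_cross_encoder_activation_function)

cfg = PretrainedConfig(num_labels=2)
print(get_classification_activation_function(cfg))  # Softmax (num_labels != 1)
print(get_cross_encoder_activation_function(cfg))   # Identity() fallback

# sentence-transformers style override; only torch.nn.modules.* may be loaded.
cfg = PretrainedConfig(
    num_labels=1,
    sentence_transformers={"activation_fn": "torch.nn.modules.linear.Identity"},
)
print(get_cross_encoder_activation_function(cfg))   # Identity()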