mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Bugfix] Add use_cross_encoder
flag to use correct activation in ClassifierPooler
(#20527)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@ -1204,7 +1204,7 @@ class LLM:
|
||||
|
||||
input_pairs = [(t1, t2) for t1, t2 in zip(text_1, text_2)]
|
||||
|
||||
pooling_params = PoolingParams()
|
||||
pooling_params = PoolingParams(use_cross_encoder=True)
|
||||
|
||||
tokenization_kwargs: dict[str, Any] = {}
|
||||
_validate_truncation_size(self.llm_engine.model_config.max_model_len,
|
||||
|
@ -1156,8 +1156,9 @@ class ScoreRequest(OpenAIBaseModel):
|
||||
|
||||
# --8<-- [end:score-extra-params]
|
||||
|
||||
def to_pooling_params(self):
|
||||
return PoolingParams(additional_data=self.additional_data)
|
||||
def to_pooling_params(self, *, use_cross_encoder: bool = False):
|
||||
return PoolingParams(use_cross_encoder=use_cross_encoder,
|
||||
additional_data=self.additional_data)
|
||||
|
||||
|
||||
class RerankRequest(OpenAIBaseModel):
|
||||
@ -1182,8 +1183,9 @@ class RerankRequest(OpenAIBaseModel):
|
||||
|
||||
# --8<-- [end:rerank-extra-params]
|
||||
|
||||
def to_pooling_params(self):
|
||||
return PoolingParams(additional_data=self.additional_data)
|
||||
def to_pooling_params(self, *, use_cross_encoder: bool = False):
|
||||
return PoolingParams(use_cross_encoder=use_cross_encoder,
|
||||
additional_data=self.additional_data)
|
||||
|
||||
|
||||
class RerankDocument(BaseModel):
|
||||
|
@ -25,9 +25,7 @@ from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
|
||||
PreTrainedTokenizer,
|
||||
PreTrainedTokenizerFast)
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||
from vllm.utils import make_async, merge_async_iterators
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -50,7 +48,7 @@ class ServingScores(OpenAIServing):
|
||||
|
||||
async def _embedding_score(
|
||||
self,
|
||||
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
|
||||
tokenizer: AnyTokenizer,
|
||||
texts_1: list[str],
|
||||
texts_2: list[str],
|
||||
request: Union[RerankRequest, ScoreRequest],
|
||||
@ -141,7 +139,7 @@ class ServingScores(OpenAIServing):
|
||||
|
||||
async def _cross_encoding_score(
|
||||
self,
|
||||
tokenizer: Union[AnyTokenizer],
|
||||
tokenizer: AnyTokenizer,
|
||||
texts_1: list[str],
|
||||
texts_2: list[str],
|
||||
request: Union[RerankRequest, ScoreRequest],
|
||||
@ -190,7 +188,7 @@ class ServingScores(OpenAIServing):
|
||||
# Schedule the request and get the result generator.
|
||||
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
|
||||
|
||||
pooling_params = request.to_pooling_params()
|
||||
pooling_params = request.to_pooling_params(use_cross_encoder=True)
|
||||
|
||||
for i, engine_prompt in enumerate(engine_prompts):
|
||||
request_id_item = f"{request_id}-{i}"
|
||||
|
@ -15,6 +15,7 @@ from vllm.model_executor.pooling_metadata import ( # noqa: E501
|
||||
from vllm.model_executor.pooling_metadata import PoolingTensors
|
||||
from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
|
||||
from vllm.transformers_utils.config import (
|
||||
get_classification_activation_function,
|
||||
get_cross_encoder_activation_function)
|
||||
from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata
|
||||
|
||||
@ -388,15 +389,14 @@ class ClassifierPooler(nn.Module):
|
||||
self.classifier = classifier
|
||||
self.pooler = pooler
|
||||
|
||||
if config.task == "score":
|
||||
self.default_activation_function = \
|
||||
get_cross_encoder_activation_function(config.hf_config)
|
||||
elif config.task == "classify":
|
||||
self.default_activation_function = nn.Sigmoid() \
|
||||
if config.hf_config.num_labels == 1 else nn.Softmax()
|
||||
else:
|
||||
raise NotImplementedError(f"task={config.task!r} is not supported"
|
||||
" with the classification pooler")
|
||||
self.classification_act_fn = get_classification_activation_function(
|
||||
config.hf_config)
|
||||
self.cross_encoder_act_fn = get_cross_encoder_activation_function(
|
||||
config.hf_config)
|
||||
|
||||
def _get_act_fn(self, use_cross_encoder: bool):
|
||||
return (self.cross_encoder_act_fn
|
||||
if use_cross_encoder else self.classification_act_fn)
|
||||
|
||||
def get_prompt_lens(
|
||||
self,
|
||||
@ -446,8 +446,28 @@ class ClassifierPooler(nn.Module):
|
||||
# apply classifier once on the full batch if possible
|
||||
pooled_output = self.classifier(pooled_output)
|
||||
|
||||
# shape: (batch_size, num_labels)
|
||||
scores = self.default_activation_function(pooled_output)
|
||||
if isinstance(pooling_metadata, V0PoolingMetadata):
|
||||
use_cross_encoder_list = [
|
||||
pooling_param.use_cross_encoder
|
||||
for _, pooling_param in pooling_metadata.seq_groups
|
||||
]
|
||||
else:
|
||||
use_cross_encoder_list = [
|
||||
pooling_param.use_cross_encoder
|
||||
for pooling_param in pooling_metadata.pooling_params
|
||||
]
|
||||
|
||||
# shape of scores: (batch_size, num_labels)
|
||||
if all(use_cross_encoder == use_cross_encoder_list[0]
|
||||
for use_cross_encoder in use_cross_encoder_list):
|
||||
act_fn = self._get_act_fn(use_cross_encoder_list[0])
|
||||
scores = act_fn(pooled_output)
|
||||
else:
|
||||
scores = torch.stack([
|
||||
self._get_act_fn(use_cross_encoder)(vecs)
|
||||
for use_cross_encoder, vecs in zip(use_cross_encoder_list,
|
||||
pooled_output)
|
||||
])
|
||||
|
||||
pooled_outputs = [PoolingSequenceGroupOutput(data) for data in scores]
|
||||
return PoolerOutput(outputs=pooled_outputs)
|
||||
|
@ -25,8 +25,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||
from vllm.transformers_utils.config import (
|
||||
get_cross_encoder_activation_function)
|
||||
|
||||
from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only
|
||||
from .utils import WeightsMapper, maybe_prefix
|
||||
@ -462,9 +460,6 @@ class BertForSequenceClassification(nn.Module, SupportsV0Only,
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
|
||||
self.default_activation_function = \
|
||||
get_cross_encoder_activation_function(config)
|
||||
|
||||
self.num_labels = config.num_labels
|
||||
self.bert = BertModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "bert"),
|
||||
|
@ -18,8 +18,6 @@ from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel
|
||||
from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||
from vllm.transformers_utils.config import (
|
||||
get_cross_encoder_activation_function)
|
||||
|
||||
from .bert_with_rope import BertWithRope, JinaRobertaModel
|
||||
from .interfaces import SupportsCrossEncoding, SupportsV0Only
|
||||
@ -178,9 +176,6 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
|
||||
self.default_activation_function = \
|
||||
get_cross_encoder_activation_function(config)
|
||||
|
||||
self.num_labels = config.num_labels
|
||||
self.roberta = BertModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "bert"),
|
||||
|
@ -24,12 +24,14 @@ class PoolingParams(
|
||||
"""
|
||||
|
||||
dimensions: Optional[int] = None
|
||||
use_cross_encoder: bool = False
|
||||
additional_data: Optional[Any] = None
|
||||
output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY
|
||||
|
||||
def clone(self) -> "PoolingParams":
|
||||
"""Returns a deep copy of the PoolingParams instance."""
|
||||
return PoolingParams(dimensions=self.dimensions,
|
||||
use_cross_encoder=self.use_cross_encoder,
|
||||
additional_data=self.additional_data)
|
||||
|
||||
def verify(self, model_config: "ModelConfig") -> None:
|
||||
@ -54,6 +56,7 @@ class PoolingParams(
|
||||
def __repr__(self) -> str:
|
||||
return (f"PoolingParams("
|
||||
f"dimensions={self.dimensions}, "
|
||||
f"use_cross_encoder={self.use_cross_encoder}, "
|
||||
f"additional_metadata={self.additional_data})")
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
|
@ -866,24 +866,26 @@ def try_get_generation_config(
|
||||
return None
|
||||
|
||||
|
||||
def get_classification_activation_function(config: PretrainedConfig):
|
||||
return nn.Sigmoid() if config.num_labels == 1 else nn.Softmax()
|
||||
|
||||
|
||||
def get_cross_encoder_activation_function(config: PretrainedConfig):
|
||||
|
||||
function_name: Optional[str] = None
|
||||
if hasattr(config, "sentence_transformers") and "activation_fn" in \
|
||||
config.sentence_transformers:
|
||||
if (hasattr(config, "sentence_transformers")
|
||||
and "activation_fn" in config.sentence_transformers):
|
||||
function_name = config.sentence_transformers["activation_fn"]
|
||||
|
||||
elif (hasattr(config, "sbert_ce_default_activation_function")
|
||||
and config.sbert_ce_default_activation_function is not None):
|
||||
function_name = config.sbert_ce_default_activation_function
|
||||
|
||||
if function_name is not None:
|
||||
assert function_name.startswith("torch.nn.modules."), \
|
||||
"Loading of activation functions is restricted to " \
|
||||
"torch.nn.modules for security reasons"
|
||||
assert function_name.startswith("torch.nn.modules."), (
|
||||
"Loading of activation functions is restricted to "
|
||||
"torch.nn.modules for security reasons")
|
||||
return resolve_obj_by_qualname(function_name)()
|
||||
else:
|
||||
return nn.Sigmoid() if config.num_labels == 1 else nn.Identity()
|
||||
|
||||
return nn.Sigmoid() if config.num_labels == 1 else nn.Identity()
|
||||
|
||||
|
||||
def try_get_safetensors_metadata(
|
||||
|
Reference in New Issue
Block a user