[Model] Improve Pooling Model (#25149)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Jee Jee Li
2025-09-18 19:04:21 +08:00
committed by GitHub
parent cc935fdd7e
commit 37970105fe
2 changed files with 7 additions and 6 deletions

View File

@ -12,8 +12,9 @@ import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig
from vllm.config import ModelConfig, PoolerConfig
from vllm.config import ModelConfig, PoolerConfig, get_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.models.adapters import _load_st_projector
from vllm.pooling_params import PoolingParams
from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
from vllm.tasks import PoolingTask
@ -377,7 +378,6 @@ class PoolerClassify(PoolerActivation):
super().__init__()
if static_num_labels:
from vllm.config import get_current_vllm_config
vllm_config = get_current_vllm_config()
self.num_labels = getattr(vllm_config.model_config.hf_config,
"num_labels", 0)
@ -427,8 +427,6 @@ class EmbeddingPoolerHead(PoolerHead):
super().__init__(activation=PoolerNormalize())
# Load ST projector if available
from vllm.config import get_current_vllm_config
from vllm.model_executor.models.adapters import _load_st_projector
vllm_config = get_current_vllm_config()
self.projector: Optional[nn.Module] = _load_st_projector(
@ -489,7 +487,6 @@ class RewardPoolerHead(PoolerHead):
def __init__(self) -> None:
super().__init__(activation=PoolerClassify(static_num_labels=False))
from vllm.config import get_current_vllm_config
vllm_config = get_current_vllm_config()
self.head_dtype = vllm_config.model_config.head_dtype
@ -638,7 +635,6 @@ class ClassifierPooler(Pooler):
) -> None:
super().__init__()
from vllm.config import get_current_vllm_config
vllm_config = get_current_vllm_config()
self.pooling = pooling
@ -730,3 +726,7 @@ class DispatchPooler(Pooler):
offset += num_items
return PoolerOutput(outputs)
def extra_repr(self) -> str:
    """Return a one-line summary naming the tasks this object supports.

    The text is ``supported_task=`` followed by the formatted result of
    ``self.get_supported_tasks()``.
    """
    return f"supported_task={self.get_supported_tasks()}"

View File

@ -3151,6 +3151,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
model = cast(VllmModelForPooling, self.get_model())
dummy_pooling_params = PoolingParams(task=task)
dummy_pooling_params.verify(task=task, model_config=self.model_config)
to_update = model.pooler.get_pooling_updates(task)
to_update.apply(dummy_pooling_params)