mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Model] Improve Pooling Model (#25149)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
@ -12,8 +12,9 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config import ModelConfig, PoolerConfig
|
||||
from vllm.config import ModelConfig, PoolerConfig, get_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.models.adapters import _load_st_projector
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
|
||||
from vllm.tasks import PoolingTask
|
||||
@ -377,7 +378,6 @@ class PoolerClassify(PoolerActivation):
|
||||
super().__init__()
|
||||
|
||||
if static_num_labels:
|
||||
from vllm.config import get_current_vllm_config
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.num_labels = getattr(vllm_config.model_config.hf_config,
|
||||
"num_labels", 0)
|
||||
@ -427,8 +427,6 @@ class EmbeddingPoolerHead(PoolerHead):
|
||||
super().__init__(activation=PoolerNormalize())
|
||||
|
||||
# Load ST projector if available
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.model_executor.models.adapters import _load_st_projector
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.projector: Optional[nn.Module] = _load_st_projector(
|
||||
@ -489,7 +487,6 @@ class RewardPoolerHead(PoolerHead):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(activation=PoolerClassify(static_num_labels=False))
|
||||
|
||||
from vllm.config import get_current_vllm_config
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.head_dtype = vllm_config.model_config.head_dtype
|
||||
|
||||
@ -638,7 +635,6 @@ class ClassifierPooler(Pooler):
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
from vllm.config import get_current_vllm_config
|
||||
vllm_config = get_current_vllm_config()
|
||||
|
||||
self.pooling = pooling
|
||||
@ -730,3 +726,7 @@ class DispatchPooler(Pooler):
|
||||
offset += num_items
|
||||
|
||||
return PoolerOutput(outputs)
|
||||
|
||||
def extra_repr(self) -> str:
    """Return a one-line summary of the tasks this object supports.

    The text is ``supported_task=`` followed by the formatted result of
    ``self.get_supported_tasks()``.
    """
    return f"supported_task={self.get_supported_tasks()}"
|
||||
|
@ -3151,6 +3151,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
|
||||
model = cast(VllmModelForPooling, self.get_model())
|
||||
dummy_pooling_params = PoolingParams(task=task)
|
||||
dummy_pooling_params.verify(task=task, model_config=self.model_config)
|
||||
to_update = model.pooler.get_pooling_updates(task)
|
||||
to_update.apply(dummy_pooling_params)
|
||||
|
||||
|
Reference in New Issue
Block a user