[Model][1/N] Support multiple poolers at model level (#21227)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-07-21 17:22:21 +08:00
committed by GitHub
parent 378d33c392
commit 042af0c8d3
22 changed files with 549 additions and 413 deletions

View File

@ -11,26 +11,51 @@ before returning them.
As shown in the [Compatibility Matrix](../features/compatibility_matrix.md), most vLLM features are not applicable to
pooling models as they only work on the generation or decode stage, so performance may not improve as much.

-For pooling models, we support the following `--task` options.
-The selected option sets the default pooler used to extract the final hidden states:
+If the model doesn't implement this interface, you can set `--task` which tells vLLM
+to convert the model into a pooling model.

-| Task                            | Pooling Type | Normalization | Softmax |
-|---------------------------------|--------------|---------------|---------|
-| Embedding (`embed`)             | `LAST`       | ✅︎            | ❌      |
-| Classification (`classify`)     | `LAST`       | ❌            | ✅︎      |
-| Sentence Pair Scoring (`score`) | \*           | \*            | \*      |
+| `--task`   | Model type           | Supported pooling tasks       |
+|------------|----------------------|-------------------------------|
+| `embed`    | Embedding model      | `encode`, `embed`             |
+| `classify` | Classification model | `encode`, `classify`, `score` |
+| `reward`   | Reward model         | `encode`                      |

-\*The default pooler is always defined by the model.
-
-!!! note
-    If the model's implementation in vLLM defines its own pooler, the default pooler is set to that instead of the one specified in this table.
+## Pooling Tasks
+
+In vLLM, we define the following pooling tasks and corresponding APIs:
+
+| Task       | APIs               |
+|------------|--------------------|
+| `encode`   | `encode`           |
+| `embed`    | `embed`, `score`\* |
+| `classify` | `classify`         |
+| `score`    | `score`            |
+
+\*The `score` API falls back to `embed` task if the model does not support `score` task.
+
+Each pooling model in vLLM supports one or more of these tasks according to [Pooler.get_supported_tasks][vllm.model_executor.layers.Pooler.get_supported_tasks].
+
+By default, the pooler assigned to each task has the following attributes:
+
+| Task       | Pooling Type | Normalization | Softmax |
+|------------|--------------|---------------|---------|
+| `encode`   | `ALL`        | ❌            | ❌      |
+| `embed`    | `LAST`       | ✅︎            | ❌      |
+| `classify` | `LAST`       | ❌            | ✅︎      |
+
+These defaults may be overridden by the model's implementation in vLLM.

When loading [Sentence Transformers](https://huggingface.co/sentence-transformers) models,
-we attempt to override the default pooler based on its Sentence Transformers configuration file (`modules.json`).
+we attempt to override the defaults based on its Sentence Transformers configuration file (`modules.json`),
+which takes priority over the model's defaults.

!!! tip
-    You can customize the model's pooling method via the `--override-pooler-config` option,
-    which takes priority over both the model's and Sentence Transformers's defaults.
+    You can further customize this via the `--override-pooler-config` option,
+    which takes priority over both the model's and Sentence Transformers's defaults.
+
+!!! note
+    The above configuration may be disregarded if the model's implementation in vLLM defines its own pooler
+    that is not based on [PoolerConfig][vllm.config.PoolerConfig].
## Offline Inference
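
The pooling tasks above map directly onto the offline `LLM` entry points. A minimal sketch of the `embed` path (the model name is illustrative; any embedding model loaded with `task="embed"` behaves the same):

```python
from vllm import LLM

# task="embed" wires up the {"encode", "embed"} poolers described above.
llm = LLM(model="BAAI/bge-base-en-v1.5", task="embed")

(output,) = llm.embed("The quick brown fox jumps over the lazy dog")
print(len(output.outputs.embedding))  # dimensionality of the pooled vector
```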

View File

@ -144,7 +144,7 @@ def test_quantization(
"model",
["jason9693/Qwen2.5-1.5B-apeach"],
)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("dtype", ["float"])
def test_classify(
hf_runner,
vllm_runner,

View File

@ -8,7 +8,7 @@ import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
from vllm.model_executor.models.gemma2 import Gemma2Model
from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix
from vllm.sequence import IntermediateTensors
@ -26,12 +26,13 @@ class MyGemma2Embedding(nn.Module):
self.model = Gemma2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.pooler = Pooler.from_config_with_defaults(
vllm_config.model_config.pooler_config,
pooling_type=PoolingType.LAST,
normalize=True,
softmax=False,
)
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler({
"encode": Pooler.for_encode(pooler_config),
"embed": Pooler.for_embed(pooler_config),
})
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
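
For out-of-tree pooling models, the migration shown above is mechanical: the single `Pooler.from_config_with_defaults(...)` call becomes a `DispatchPooler` keyed by task. A sketch of the resulting behavior, assuming `model` is a loaded instance of the plugin class:

```python
# Each dispatch key becomes a supported task; absent tasks are rejected.
assert model.pooler.get_supported_tasks() == {"encode", "embed"}

# The constructor validates each entry against its sub-pooler, so a
# mismatched mapping fails fast, e.g. this raises ValueError because
# Pooler.for_encode only supports the "encode" task:
# DispatchPooler({"embed": Pooler.for_encode(pooler_config)})
```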

View File

@ -94,7 +94,7 @@ ConfigT = TypeVar("ConfigT", bound=ConfigType)
TaskOption = Literal["auto", "generate", "embedding", "embed", "classify",
"score", "reward", "transcription", "draft"]
_ResolvedTask = Literal["generate", "transcription", "pooling", "embed",
_ResolvedTask = Literal["generate", "transcription", "encode", "embed",
"classify", "reward", "draft"]
RunnerOption = Literal["auto", "generate", "pooling", "draft"]
@ -103,7 +103,7 @@ RunnerType = Literal["generate", "pooling", "draft"]
_RUNNER_TASKS: dict[RunnerType, list[_ResolvedTask]] = {
"generate": ["generate", "transcription"],
"pooling": ["pooling", "embed", "classify", "reward"],
"pooling": ["encode", "embed", "classify", "reward"],
"draft": [],
}
@ -579,7 +579,7 @@ class ModelConfig:
# user-selected task
if runner_type == "pooling" and self.task == "auto":
selected_task = all_supported_tasks[runner_type][-1]
assert selected_task != "pooling"
assert selected_task != "encode"
self.task = selected_task
self.supported_runner_types = supported_runner_types
self.runner_type = runner_type
@ -884,7 +884,7 @@ class ModelConfig:
supported_tasks = list[_ResolvedTask]()
if registry.is_pooling_model(architectures):
supported_tasks.append("pooling")
supported_tasks.append("encode")
# For now, users must specify the task (other than "pooling")
# to use for pooling models

View File

@ -1668,7 +1668,7 @@ async def init_app_state(
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
) if "pooling" in model_config.supported_tasks else None
) if "encode" in model_config.supported_tasks else None
state.openai_serving_embedding = OpenAIServingEmbedding(
engine_client,
model_config,

View File

@ -1,15 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Mapping, Set
from dataclasses import dataclass
from enum import IntEnum
from itertools import groupby
from typing import Callable, Optional, TypeVar, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig
from typing_extensions import assert_never
from vllm.config import ModelConfig, PoolerConfig
from vllm.model_executor.pooling_metadata import ( # noqa: E501
@ -21,6 +22,10 @@ from vllm.utils import resolve_obj_by_qualname
from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata
PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata]
PoolingFn = Callable[
[Union[torch.Tensor, list[torch.Tensor]], PoolingMetadata],
Union[torch.Tensor, list[torch.Tensor]]]
ClassifierFn = Callable[[torch.Tensor], torch.Tensor]
class PoolingType(IntEnum):
@ -79,37 +84,81 @@ class Pooler(nn.Module, ABC):
"""The interface required for all poolers used in pooling models in vLLM."""
@staticmethod
def from_config_with_defaults(
def for_encode(
pooler_config: PoolerConfig,
pooling_type: PoolingType,
normalize: bool,
softmax: bool,
step_tag_id: Optional[int] = None,
returned_token_ids: Optional[list[int]] = None,
) -> "Pooler":
*,
default_pooling_type: PoolingType = PoolingType.ALL,
default_normalize: bool = False,
default_softmax: bool = False,
default_step_tag_id: Optional[int] = None,
default_returned_token_ids: Optional[list[int]] = None,
):
resolved_config = ResolvedPoolingConfig.from_config_with_defaults(
pooler_config=pooler_config,
pooling_type=pooling_type,
normalize=normalize,
softmax=softmax,
step_tag_id=step_tag_id,
returned_token_ids=returned_token_ids,
pooling_type=default_pooling_type,
normalize=default_normalize,
softmax=default_softmax,
step_tag_id=default_step_tag_id,
returned_token_ids=default_returned_token_ids,
)
if pooling_type == PoolingType.STEP:
if resolved_config.pooling_type == PoolingType.STEP:
return StepPooler.from_config(resolved_config)
return SimplePooler.from_config(resolved_config)
def get_pooling_updates(
self,
task: PoolingTask,
) -> Optional[PoolingParamsUpdate]:
@staticmethod
def for_embed(
pooler_config: PoolerConfig,
*,
default_pooling_type: PoolingType = PoolingType.LAST,
default_normalize: bool = True,
default_softmax: bool = False,
):
resolved_config = ResolvedPoolingConfig.from_config_with_defaults(
pooler_config=pooler_config,
pooling_type=default_pooling_type,
normalize=default_normalize,
softmax=default_softmax,
)
return SimplePooler.from_config(resolved_config)
@staticmethod
def for_classify(
pooler_config: PoolerConfig,
classifier: Optional[ClassifierFn],
*,
default_pooling_type: PoolingType = PoolingType.LAST,
default_normalize: bool = False,
default_softmax: bool = True,
):
resolved_config = ResolvedPoolingConfig.from_config_with_defaults(
pooler_config=pooler_config,
pooling_type=default_pooling_type,
normalize=default_normalize,
softmax=default_softmax,
)
base_pooler = SimplePooler.from_config(resolved_config)
if classifier is None:
return base_pooler
return ClassifierPooler(
pooling=base_pooler.pooling,
classifier=classifier,
act_fn=base_pooler.head.activation,
)
@abstractmethod
def get_supported_tasks(self) -> Set[PoolingTask]:
"""Determine which pooling tasks are supported."""
raise NotImplementedError
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
"""
Construct the pooling parameters to use for a task,
or `None` if the task is not supported.
Construct the updated pooling parameters to use for a supported task.
"""
return None
return PoolingParamsUpdate()
@abstractmethod
def forward(
@ -127,9 +176,8 @@ def get_prompt_lens(
if isinstance(pooling_metadata, V1PoolingMetadata):
return pooling_metadata.prompt_lens
assert isinstance(hidden_states, torch.Tensor)
return PoolingTensors.from_pooling_metadata(
pooling_metadata, hidden_states.device).prompt_lens
pooling_metadata, hidden_states[0].device).prompt_lens
def get_prompt_token_ids(
@ -149,6 +197,21 @@ def get_prompt_token_ids(
]
def get_tasks(pooling_metadata: PoolingMetadata) -> list[PoolingTask]:
if isinstance(pooling_metadata, V0PoolingMetadata):
pooling_params = [p for _, p in pooling_metadata.seq_groups]
else:
pooling_params = pooling_metadata.pooling_params
tasks: list[PoolingTask] = [
task for pooling_param in pooling_params
if (task := pooling_param.task) is not None
]
assert len(pooling_params) == len(tasks)
return tasks
def get_classification_activation_function(config: PretrainedConfig):
return PoolerClassify()
@ -172,7 +235,8 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
return PoolerScore()
def build_output(all_data: torch.Tensor) -> PoolerOutput:
def build_output(
all_data: Union[torch.Tensor, list[torch.Tensor]], ) -> PoolerOutput:
all_outputs = [PoolingSequenceGroupOutput(data) for data in all_data]
return PoolerOutput(outputs=all_outputs)
@ -193,12 +257,12 @@ class PoolingMethod(nn.Module, ABC):
raise NotImplementedError(f"Unsupported method: {pooling_type}")
@abstractmethod
def get_pooling_updates(
self,
task: PoolingTask,
) -> Optional[PoolingParamsUpdate]:
def get_supported_tasks(self) -> Set[PoolingTask]:
raise NotImplementedError
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
return PoolingParamsUpdate()
@abstractmethod
def forward_one(
self,
@ -237,16 +301,8 @@ class PoolingMethod(nn.Module, ABC):
class CLSPool(PoolingMethod):
def get_pooling_updates(
self,
task: PoolingTask,
) -> Optional[PoolingParamsUpdate]:
# The equalities are split up to keep mypy happy
if (task == "encode" or task == "embed" or task == "classify"
or task == "score"):
return PoolingParamsUpdate()
assert_never(task)
def get_supported_tasks(self) -> Set[PoolingTask]:
return {"encode", "embed", "classify", "score"}
def forward_one(
self,
@ -270,16 +326,8 @@ class CLSPool(PoolingMethod):
class LastPool(PoolingMethod):
def get_pooling_updates(
self,
task: PoolingTask,
) -> Optional[PoolingParamsUpdate]:
# The equalities are split up to keep mypy happy
if (task == "encode" or task == "embed" or task == "classify"
or task == "score"):
return PoolingParamsUpdate()
assert_never(task)
def get_supported_tasks(self) -> Set[PoolingTask]:
return {"encode", "embed", "classify", "score"}
def forward_one(
self,
@ -299,18 +347,8 @@ class LastPool(PoolingMethod):
class AllPool(PoolingMethod):
def get_pooling_updates(
self,
task: PoolingTask,
) -> Optional[PoolingParamsUpdate]:
if task == "encode":
return PoolingParamsUpdate()
# The equalities are split up to keep mypy happy
if task == "embed" or task == "classify" or task == "score":
return None
assert_never(task)
def get_supported_tasks(self) -> Set[PoolingTask]:
return {"encode"}
def forward_one(
self,
@ -327,28 +365,13 @@ class AllPool(PoolingMethod):
hidden_states: torch.Tensor,
prompt_lens: torch.Tensor,
) -> Union[list[torch.Tensor], torch.Tensor]:
offset = 0
pooled_data = list[torch.Tensor]()
for prompt_len in prompt_lens:
pooled_data.append(hidden_states[offset:offset + prompt_len])
offset += prompt_len
return pooled_data
return list(hidden_states.split_with_sizes(prompt_lens.tolist()))
class MeanPool(PoolingMethod):
def get_pooling_updates(
self,
task: PoolingTask,
) -> Optional[PoolingParamsUpdate]:
# The equalities are split up to keep mypy happy
if (task == "encode" or task == "embed" or task == "classify"
or task == "score"):
return PoolingParamsUpdate()
assert_never(task)
def get_supported_tasks(self) -> Set[PoolingTask]:
return {"encode", "embed", "classify", "score"}
def forward_one(
self,
@ -529,24 +552,6 @@ class SimplePooler(Pooler):
3. Returns structured results as `PoolerOutput`.
"""
@classmethod
def from_config_with_defaults( # type: ignore[override]
cls,
pooler_config: PoolerConfig,
pooling_type: PoolingType,
normalize: bool,
softmax: bool,
) -> "SimplePooler":
resolved_config = ResolvedPoolingConfig.from_config_with_defaults(
pooler_config=pooler_config,
pooling_type=pooling_type,
normalize=normalize,
softmax=softmax,
)
assert resolved_config.pooling_type != PoolingType.STEP
return cls.from_config(resolved_config)
@classmethod
def from_config(
cls,
@ -563,10 +568,10 @@ class SimplePooler(Pooler):
self.pooling = pooling
self.head = head
def get_pooling_updates(
self,
task: PoolingTask,
) -> Optional[PoolingParamsUpdate]:
def get_supported_tasks(self) -> Set[PoolingTask]:
return self.pooling.get_supported_tasks()
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
return self.pooling.get_pooling_updates(task)
def forward(
@ -627,18 +632,11 @@ class StepPooler(Pooler):
return pooled_data
def get_pooling_updates(
self,
task: PoolingTask,
) -> Optional[PoolingParamsUpdate]:
if task == "encode":
return PoolingParamsUpdate(requires_token_ids=True)
def get_supported_tasks(self) -> Set[PoolingTask]:
return {"encode"}
# The equalities are split up to keep mypy happy
if task == "embed" or task == "classify" or task == "score":
return None
assert_never(task)
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
return PoolingParamsUpdate(requires_token_ids=True)
def forward(
self,
@ -650,68 +648,43 @@ class StepPooler(Pooler):
return build_output(pooled_data)
PoolingFn = Callable[
[Union[torch.Tensor, list[torch.Tensor]], PoolingMetadata],
Union[torch.Tensor, list[torch.Tensor]]]
ClassifierFn = Callable[[torch.Tensor], torch.Tensor]
class ClassifierPooler(nn.Module):
class ClassifierPooler(Pooler):
"""A pooling layer for classification tasks.
This layer does the following:
1. Applies a classification layer to the hidden states.
2. Optionally applies a pooler layer.
3. Applies an activation function to the output. In the case of
classification models it is either sigmoid or softmax. In the
case of scoring models, the same behavior is configuration
dependent, as in the sentence-transformers library.
3. Applies an activation function to the output.
"""
@staticmethod
def act_fn_for_seq_cls(config: ModelConfig):
return get_classification_activation_function(config.hf_config)
@staticmethod
def act_fn_for_cross_encoder(config: ModelConfig):
return get_cross_encoder_activation_function(config.hf_config)
def __init__(
self,
config: ModelConfig,
pooling: PoolingFn,
classifier: ClassifierFn,
act_fn: Optional[PoolerActivation] = None,
act_fn: PoolerActivation,
) -> None:
super().__init__()
self.pooling = pooling
self.classifier = classifier
self.act_fn = act_fn
self.classification_act_fn = get_classification_activation_function(
config.hf_config) if act_fn is None else act_fn
self.cross_encoder_act_fn = get_cross_encoder_activation_function(
config.hf_config) if act_fn is None else act_fn
def _get_act_fn(self, task: PoolingTask):
if task == "encode" or task == "classify":
return self.classification_act_fn
if task == "score":
return self.cross_encoder_act_fn
raise ValueError(f"Unsupported task: {task!r}")
def get_pooling_updates(
self,
task: PoolingTask,
) -> Optional[PoolingParamsUpdate]:
# The equalities are split up to keep mypy happy
if task == "encode" or task == "classify" or task == "score":
return PoolingParamsUpdate()
if task == "embed":
return None
assert_never(task)
def get_supported_tasks(self) -> Set[PoolingTask]:
return {"classify", "score"}
def forward(
self,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
"""Pools sentence pair scores from the hidden_states."""
pooled_data = self.pooling(hidden_states, pooling_metadata)
# apply classifier once on the full batch if possible
@ -722,28 +695,59 @@ class ClassifierPooler(nn.Module):
else:
pooled_output = [self.classifier(data) for data in pooled_data]
task_list: list[PoolingTask]
if isinstance(pooling_metadata, V0PoolingMetadata):
task_list = [
task for _, pooling_param in pooling_metadata.seq_groups
if (task := pooling_param.task) is not None
]
else:
task_list = [
task for pooling_param in pooling_metadata.pooling_params
if (task := pooling_param.task) is not None
]
assert len(task_list) == len(pooled_output)
# shape of scores: (batch_size, num_labels)
if len(set(task_list)) <= 1:
act_fn = self._get_act_fn(task_list[0])
scores = act_fn(pooled_output)
else:
scores = torch.stack([
self._get_act_fn(task)(vecs)
for task, vecs in zip(task_list, pooled_output)
])
scores = self.act_fn(pooled_output)
return build_output(scores)
class DispatchPooler(Pooler):
"""Dispatches calls to a sub-pooler based on the pooling task."""
def __init__(self, poolers_by_task: Mapping[PoolingTask, Pooler]) -> None:
super().__init__()
for task, pooler in poolers_by_task.items():
if task not in pooler.get_supported_tasks():
raise ValueError(
f"{pooler=} does not support {task=}. "
f"Supported tasks: {pooler.get_supported_tasks()}")
self.poolers_by_task = poolers_by_task
def get_supported_tasks(self) -> Set[PoolingTask]:
return set(self.poolers_by_task)
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
return self.poolers_by_task[task].get_pooling_updates(task)
def forward(
self,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
poolers_by_task = self.poolers_by_task
if isinstance(hidden_states, list):
hidden_states_lst = hidden_states
else:
prompt_lens = get_prompt_lens(hidden_states, pooling_metadata)
hidden_states_lst = list(hidden_states.split(prompt_lens.tolist()))
outputs = list[PoolingSequenceGroupOutput]()
offset = 0
for task, group in groupby(get_tasks(pooling_metadata)):
if not (pooler := poolers_by_task.get(task)):
raise ValueError(
f"Unsupported task: {task} "
f"Supported tasks: {self.get_supported_tasks()}")
num_items = len(list(group))
group_output: PoolerOutput = pooler(
hidden_states_lst[offset:offset + num_items],
pooling_metadata[offset:offset + num_items],
)
outputs.extend(group_output.outputs)
offset += num_items
return PoolerOutput(outputs)
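
One subtlety in `DispatchPooler.forward`: `itertools.groupby` merges only *consecutive* equal keys, so requests sharing a task are batched into a single sub-pooler call only when the scheduler keeps them adjacent. A standalone illustration of that behavior:

```python
from itertools import groupby

tasks = ["embed", "embed", "classify", "embed"]
print([(task, len(list(group))) for task, group in groupby(tasks)])
# [('embed', 2), ('classify', 1), ('embed', 1)] -- the trailing "embed"
# triggers a second dispatch to the embed pooler; results stay correct,
# only the batching granularity changes.
```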

View File

@ -13,7 +13,6 @@ from .interfaces_base import VllmModelForPooling, is_pooling_model
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import PoolingType
_T = TypeVar("_T", bound=type[nn.Module])
@ -34,16 +33,8 @@ def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str:
return model_name + pooling_suffix
def _create_pooling_model_cls(
orig_cls: _T,
*,
default_pooling_type: "PoolingType",
default_normalize: bool,
default_softmax: bool,
) -> _T:
def _create_pooling_model_cls(orig_cls: _T) -> _T:
# Lazy import
from vllm.model_executor.layers.pooler import Pooler
from .utils import AutoWeightsLoader, WeightsMapper
class ModelForPooling(orig_cls, VllmModelForPooling):
@ -71,15 +62,7 @@ def _create_pooling_model_cls(
self._init_pooler(vllm_config, prefix=prefix)
def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=default_pooling_type,
normalize=default_normalize,
softmax=default_softmax,
)
raise NotImplementedError
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
# TODO: Support uninitialized params tracking
@ -132,14 +115,20 @@ def as_embedding_model(cls: _T) -> _T:
return cls
# Lazy import
from vllm.model_executor.layers.pooler import PoolingType
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
class ModelForEmbedding(_create_pooling_model_cls(cls)):
def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler(
{
"encode": Pooler.for_encode(pooler_config),
"embed": Pooler.for_embed(pooler_config),
}, )
ModelForEmbedding = _create_pooling_model_cls(
cls,
default_pooling_type=PoolingType.LAST,
default_normalize=True,
default_softmax=False,
)
ModelForEmbedding.__name__ = \
_get_pooling_model_name(cls.__name__, "ForEmbedding")
@ -165,20 +154,14 @@ def as_seq_cls_model(cls: _T) -> _T:
# Lazy import
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.model_executor.layers.pooler import (ClassifierPooler,
PoolingType, SimplePooler)
DispatchPooler, Pooler,
PoolingMethod, PoolingType)
from vllm.model_executor.models.interfaces import SupportsCrossEncoding
from vllm.sequence import IntermediateTensors
from .utils import maybe_prefix
ModelForPooling = _create_pooling_model_cls(
cls,
default_pooling_type=PoolingType.LAST,
default_normalize=False,
default_softmax=True,
)
class ModelForSequenceClassification(ModelForPooling,
class ModelForSequenceClassification(_create_pooling_model_cls(cls),
SupportsCrossEncoding):
def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
@ -198,19 +181,28 @@ def as_seq_cls_model(cls: _T) -> _T:
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
pooler = SimplePooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.LAST,
normalize=False,
softmax=True,
)
pooling_type_str = pooler_config.pooling_type
pooling_type = (PoolingType.LAST if pooling_type_str is None else
PoolingType[pooling_type_str])
self.pooler = ClassifierPooler(
vllm_config.model_config,
pooling=pooler.pooling,
classifier=self._classifier,
act_fn=pooler.head.activation,
)
self.pooler = DispatchPooler({
"encode":
Pooler.for_encode(pooler_config),
"classify":
ClassifierPooler(
pooling=PoolingMethod.from_pooling_type(pooling_type),
classifier=self._classifier,
act_fn=ClassifierPooler.act_fn_for_seq_cls(
vllm_config.model_config),
),
"score":
ClassifierPooler(
pooling=PoolingMethod.from_pooling_type(pooling_type),
classifier=self._classifier,
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
vllm_config.model_config),
),
})
def _classifier(self, x: torch.Tensor):
x, _ = self.score(x.float())
@ -259,14 +251,16 @@ def as_reward_model(cls: _T) -> _T:
return cls
# Lazy import
from vllm.model_executor.layers.pooler import PoolingType
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
ModelForReward = _create_pooling_model_cls(
cls,
default_pooling_type=PoolingType.ALL,
default_normalize=False,
default_softmax=False,
)
class ModelForReward(_create_pooling_model_cls(cls)):
def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler(
{"encode": Pooler.for_encode(pooler_config)}, )
ModelForReward.__name__ = \
_get_pooling_model_name(cls.__name__, "ForReward")
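
Taken together, the adapters now defer all pooler construction to `_init_pooler` overrides. A usage sketch (the wrapped class is illustrative; vLLM's model registry applies these adapters automatically):

```python
from vllm.model_executor.models.adapters import (as_embedding_model,
                                                 as_reward_model)
from vllm.model_executor.models.llama import LlamaForCausalLM

# Each adapter subclasses the original model and installs a task-keyed
# DispatchPooler: {"encode", "embed"} here, {"encode"} for rewards.
LlamaForEmbedding = as_embedding_model(LlamaForCausalLM)
LlamaForReward = as_reward_model(LlamaForCausalLM)
```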

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from collections.abc import Iterable, Set
from typing import Optional, Union
import torch
@ -17,7 +17,8 @@ from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.pooler import (ClassifierPooler, Pooler,
from vllm.model_executor.layers.pooler import (ClassifierPooler,
DispatchPooler, Pooler,
PoolingMethod,
PoolingParamsUpdate,
PoolingType)
@ -92,20 +93,29 @@ class BertPooler(Pooler):
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
def get_pooling_updates(
self,
task: PoolingTask,
) -> Optional[PoolingParamsUpdate]:
def get_supported_tasks(self) -> Set[PoolingTask]:
return self.pooling.get_supported_tasks()
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
return self.pooling.get_pooling_updates(task)
def _head(self, pooled_output: torch.Tensor):
pooled_output = self.dense(pooled_output)
pooled_output = self.activation(pooled_output)
return pooled_output
def forward(
self,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> Union[torch.Tensor, list[torch.Tensor]]:
pooled_output = self.pooling(hidden_states, pooling_metadata)
pooled_output = self.dense(pooled_output)
pooled_output = self.activation(pooled_output)
if isinstance(pooled_output, list):
pooled_output = [self._head(output) for output in pooled_output]
else:
pooled_output = self._head(pooled_output)
return pooled_output
@ -333,18 +343,19 @@ class BertModel(nn.Module, SupportsQuant):
packed_modules_mapping = {"qkv_proj": ["query", "key", "value"]}
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = "",
embedding_class: type = BertEmbedding,
add_pooling_layer: bool = False):
def __init__(
self,
*,
vllm_config: VllmConfig,
prefix: str = "",
embedding_class: type[nn.Module] = BertEmbedding,
) -> None:
super().__init__()
config = vllm_config.model_config.hf_config
self.embeddings = embedding_class(config)
self.encoder = BertEncoder(vllm_config=vllm_config,
prefix=f"{prefix}.encoder")
self.pooler = BertPooler(config) if add_pooling_layer else None
def forward(
self,
@ -366,8 +377,7 @@ class BertModel(nn.Module, SupportsQuant):
token_type_ids=token_type_ids)
return self.encoder(hidden_states)
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
def _load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "query", "q"),
@ -395,10 +405,43 @@ class BertModel(nn.Module, SupportsQuant):
if name in params_dict:
other_weights.append((name, loaded_weight))
loader = AutoWeightsLoader(
self,
skip_prefixes=(["pooler."] if self.pooler is None else []),
return other_weights, loaded_stacked_params
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
other_weights, loaded_stacked_params = self._load_weights(weights)
loader = AutoWeightsLoader(self, skip_prefixes=["pooler."])
loaded_params = loader.load_weights(other_weights)
loaded_params.update(loaded_stacked_params)
return loaded_params
class BertPoolingModel(BertModel):
is_pooling_model = True
def __init__(
self,
*,
vllm_config: VllmConfig,
prefix: str = "",
embedding_class: type[nn.Module] = BertEmbedding,
) -> None:
super().__init__(
vllm_config=vllm_config,
prefix=prefix,
embedding_class=embedding_class,
)
config = vllm_config.model_config.hf_config
self.pooler = BertPooler(config)
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
other_weights, loaded_stacked_params = self._load_weights(weights)
loader = AutoWeightsLoader(self)
loaded_params = loader.load_weights(other_weights)
loaded_params.update(loaded_stacked_params)
return loaded_params
@ -421,6 +464,8 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
super().__init__()
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.model = self._build_model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.pooler = self._build_pooler(pooler_config)
@ -456,10 +501,15 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
embedding_class=BertEmbedding)
def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
return Pooler.from_config_with_defaults(pooler_config,
pooling_type=PoolingType.CLS,
normalize=True,
softmax=False)
return DispatchPooler({
"encode":
Pooler.for_encode(pooler_config),
"embed":
Pooler.for_embed(
pooler_config,
default_pooling_type=PoolingType.CLS,
),
})
class BertForSequenceClassification(nn.Module, SupportsV0Only,
@ -481,16 +531,32 @@ class BertForSequenceClassification(nn.Module, SupportsV0Only,
config = vllm_config.model_config.hf_config
self.num_labels = config.num_labels
self.bert = BertModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "bert"),
embedding_class=BertEmbedding,
add_pooling_layer=True)
self.bert = BertPoolingModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "bert"),
embedding_class=BertEmbedding)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self.pooler = ClassifierPooler(
vllm_config.model_config,
pooling=self.bert.pooler,
classifier=self.classifier,
)
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler({
"encode":
Pooler.for_encode(pooler_config),
"classify":
ClassifierPooler(
pooling=self.bert.pooler,
classifier=self.classifier,
act_fn=ClassifierPooler.act_fn_for_seq_cls(
vllm_config.model_config),
),
"score":
ClassifierPooler(
pooling=self.bert.pooler,
classifier=self.classifier,
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
vllm_config.model_config),
),
})
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self)
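
With separate `classify` and `score` entries registered, a BERT-style cross-encoder serves the score API without any activation-function switching at request time. A usage sketch (model and task choice illustrative):

```python
from vllm import LLM

llm = LLM(model="cross-encoder/ms-marco-MiniLM-L-6-v2", task="score")

(output,) = llm.score("What is the capital of France?",
                      "The capital of France is Paris.")
print(output.outputs.score)  # relevance score from the "score" pooler
```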

View File

@ -43,7 +43,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from ..layers.pooler import Pooler, PoolingType
from ..layers.pooler import DispatchPooler, Pooler
from .interfaces import SupportsPP
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
@ -339,12 +339,16 @@ class GPT2ForSequenceClassification(nn.Module):
self.transformer = GPT2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "gpt2"))
self.score = nn.Linear(config.n_embd, config.num_labels, bias=False)
pooler_config = vllm_config.model_config.pooler_config
self.pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.LAST,
normalize=False,
softmax=True)
assert pooler_config is not None
self.pooler = DispatchPooler({
"encode":
Pooler.for_encode(pooler_config),
"classify":
Pooler.for_classify(pooler_config, classifier=None),
})
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self)
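
Note `classifier=None` above: per the `for_classify` factory earlier in this commit, a missing classifier function returns the bare `SimplePooler` (pooling plus activation), which fits GPT-2 since its `score` projection runs in the model's own forward pass. Schematically, assuming `pooler_config` and the imports from pooler.py:

```python
pooler = Pooler.for_classify(pooler_config, classifier=None)
# No classifier fn -> the base pooler is returned unchanged rather than
# being wrapped in a ClassifierPooler.
assert type(pooler).__name__ == "SimplePooler"
```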

View File

@ -1,17 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Set
from typing import Optional, Union
import numpy as np
import torch
import torch.nn as nn
from typing_extensions import assert_never
from vllm.config import ModelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.pooler import (Pooler, PoolerHead,
PoolerNormalize,
from vllm.model_executor.layers.pooler import (DispatchPooler, Pooler,
PoolerHead, PoolerNormalize,
PoolingParamsUpdate,
build_output, get_prompt_lens,
get_prompt_token_ids)
@ -135,18 +134,11 @@ class GritLMMeanPool(nn.Module):
return instruction_len
def get_pooling_updates(
self,
task: PoolingTask,
) -> Optional[PoolingParamsUpdate]:
# The equalities are split up to keep mypy happy
if task == "encode" or task == "embed":
return PoolingParamsUpdate(requires_token_ids=True)
def get_supported_tasks(self) -> Set[PoolingTask]:
return {"encode", "embed"}
if task == "classify" or task == "score":
return None
assert_never(task)
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
return PoolingParamsUpdate(requires_token_ids=True)
def forward_one(
self,
@ -207,10 +199,10 @@ class GritLMPooler(Pooler):
self.pooling = GritLMMeanPool(model_config)
self.head = PoolerHead(PoolerNormalize())
def get_pooling_updates(
self,
task: PoolingTask,
) -> Optional[PoolingParamsUpdate]:
def get_supported_tasks(self) -> Set[PoolingTask]:
return self.pooling.get_supported_tasks()
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
return self.pooling.get_pooling_updates(task)
def forward(
@ -262,4 +254,11 @@ class GritLM(LlamaForCausalLM, SupportsV0Only):
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
self.pooler = GritLMPooler(vllm_config.model_config)
pooler_config = vllm_config.model_config.pooler_config
if pooler_config is not None:
self.pooler = DispatchPooler({
"encode":
Pooler.for_encode(pooler_config),
"embed":
GritLMPooler(vllm_config.model_config),
})

View File

@ -22,7 +22,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -429,12 +429,10 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
)
pooler_config = vllm_config.model_config.pooler_config
self.pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.ALL,
normalize=False,
softmax=False,
)
assert pooler_config is not None
self.pooler = DispatchPooler(
{"encode": Pooler.for_encode(pooler_config)}, )
def forward(
self,

View File

@ -19,8 +19,8 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
from vllm.model_executor.layers.pooler import (ClassifierPooler, PoolingType,
SimplePooler)
from vllm.model_executor.layers.pooler import (DispatchPooler, Pooler,
PoolingType)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
@ -584,16 +584,15 @@ class JambaForSequenceClassification(JambaForCausalLM):
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
pooler = SimplePooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.LAST,
normalize=False,
softmax=False,
)
self.pooler = ClassifierPooler(
vllm_config.model_config,
pooling=pooler.pooling,
classifier=self.score,
act_fn=pooler.head.activation,
)
self.pooler = DispatchPooler({
"encode":
Pooler.for_encode(pooler_config),
"classify":
Pooler.for_classify(
pooler_config,
classifier=self.score,
default_pooling_type=PoolingType.LAST,
default_normalize=False,
default_softmax=False,
),
})

View File

@ -12,7 +12,7 @@ from vllm.inputs import TokensPrompt
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors
@ -96,11 +96,17 @@ class JinaVLForSequenceClassification(Qwen2VLForConditionalGeneration,
self.score = JinaVLScorer(config)
self.pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.LAST,
normalize=False,
softmax=True)
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler({
"encode":
Pooler.for_encode(pooler_config),
"classify":
Pooler.for_classify(pooler_config, classifier=None),
"score":
Pooler.for_classify(pooler_config, classifier=None),
})
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from collections.abc import Iterable, Set
from typing import Optional, Union
import torch
@ -13,7 +13,8 @@ from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.linear import (QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.pooler import (ClassifierPooler, Pooler,
from vllm.model_executor.layers.pooler import (ClassifierPooler,
DispatchPooler, Pooler,
PoolingMethod,
PoolingParamsUpdate,
PoolingType)
@ -271,19 +272,27 @@ class ModernBertPooler(Pooler):
eps=config.norm_eps,
bias=config.norm_bias)
def get_pooling_updates(
self,
task: PoolingTask,
) -> Optional[PoolingParamsUpdate]:
def get_supported_tasks(self) -> Set[PoolingTask]:
return self.pooling.get_supported_tasks()
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
return self.pooling.get_pooling_updates(task)
def _head(self, pooled_output: torch.Tensor):
return self.norm(self.act(self.dense(pooled_output)))
def forward(
self,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> Union[torch.Tensor, list[torch.Tensor]]:
pooled_output = self.pooling(hidden_states, pooling_metadata)
pooled_output = self.norm(self.act(self.dense(pooled_output)))
if isinstance(pooled_output, list):
pooled_output = [self._head(output) for output in pooled_output]
else:
pooled_output = self._head(pooled_output)
return pooled_output
@ -299,11 +308,28 @@ class ModernBertForSequenceClassification(nn.Module, SupportsV0Only,
self.model = ModernBertModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "modernbert"))
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self.pooler = ClassifierPooler(
vllm_config.model_config,
pooling=ModernBertPooler(config),
classifier=self.classifier,
)
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler({
"encode":
Pooler.for_encode(pooler_config),
"classify":
ClassifierPooler(
pooling=ModernBertPooler(config),
classifier=self.classifier,
act_fn=ClassifierPooler.act_fn_for_seq_cls(
vllm_config.model_config),
),
"score":
ClassifierPooler(
pooling=ModernBertPooler(config),
classifier=self.classifier,
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
vllm_config.model_config),
),
})
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):

View File

@ -15,7 +15,8 @@ from torch import nn
from vllm.config import VllmConfig
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.pooler import Pooler, PoolingType, SimplePooler
from vllm.model_executor.layers.pooler import (DispatchPooler, Pooler,
PoolingType)
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
@ -26,7 +27,7 @@ from .utils import AutoWeightsLoader, maybe_prefix
class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
is_pooling_model = True
pooler: SimplePooler
pooler: Pooler
packed_modules_mapping = {
"qkv_proj": [
@ -94,12 +95,12 @@ class Qwen2ForRewardModel(Qwen2RewardBaseModel):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
vllm_config.model_config.hf_config.num_labels = 1
super().__init__(vllm_config=vllm_config, prefix=prefix)
pooler_config = vllm_config.model_config.pooler_config
self.pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.ALL,
normalize=False,
softmax=False)
assert pooler_config is not None
self.pooler = DispatchPooler(
{"encode": Pooler.for_encode(pooler_config)}, )
class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
@ -107,11 +108,17 @@ class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
vllm_config.model_config.hf_config.num_labels = 2
super().__init__(vllm_config=vllm_config, prefix=prefix)
pooler_config = vllm_config.model_config.pooler_config
self.pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.STEP,
normalize=False,
softmax=True,
step_tag_id=151651,
)
assert pooler_config is not None
self.pooler = DispatchPooler({
"encode":
Pooler.for_encode(
pooler_config,
default_pooling_type=PoolingType.STEP,
default_normalize=False,
default_softmax=True,
default_step_tag_id=151651,
)
})
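
The process reward model routes its STEP pooler through the `for_encode` defaults. Conceptually, STEP pooling keeps only the hidden states at positions of a designated step-tag token, yielding one pooled vector (and, after softmax, one score) per reasoning step. A simplified single-sequence sketch; the real `StepPooler` additionally handles `returned_token_ids` filtering:

```python
import torch

def step_pool_one(hidden_states: torch.Tensor,  # (seq_len, num_labels)
                  token_ids: torch.Tensor,      # (seq_len,)
                  step_tag_id: int = 151651) -> torch.Tensor:
    # Keep only the rows where the step-tag token occurs.
    return hidden_states[token_ids == step_tag_id]
```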

View File

@ -9,7 +9,8 @@ from torch import nn
from transformers import RobertaConfig
from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import ClassifierPooler, CLSPool
from vllm.model_executor.layers.pooler import (ClassifierPooler, CLSPool,
DispatchPooler, Pooler)
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel
@ -63,16 +64,10 @@ class RobertaEmbedding(nn.Module):
# References:
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
pos_list = []
token_list = []
offset = 0
for seq_len in seq_lens:
pos_list.append(position_ids[offset:offset + seq_len])
token_list.append(input_ids[offset:offset + seq_len])
offset += seq_len
seq_lens_list = seq_lens.tolist()
new_pos_list = []
for positions, tokens in zip(pos_list, token_list):
for positions, tokens in zip(position_ids.split(seq_lens_list),
input_ids.split(seq_lens_list)):
# Verify assumption that incoming position are
# always a sequence from 0 to N.
expected_pos = torch.arange(positions.size()[0],
@ -184,15 +179,30 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
self.num_labels = config.num_labels
self.roberta = BertModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "bert"),
embedding_class=RobertaEmbedding,
add_pooling_layer=False)
embedding_class=RobertaEmbedding)
self.classifier = RobertaClassificationHead(config)
self.pooler = ClassifierPooler(
vllm_config.model_config,
pooling=CLSPool(),
classifier=self.classifier,
)
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler({
"encode":
Pooler.for_encode(pooler_config),
"classify":
ClassifierPooler(
pooling=CLSPool(),
classifier=self.classifier,
act_fn=ClassifierPooler.act_fn_for_seq_cls(
vllm_config.model_config),
),
"score":
ClassifierPooler(
pooling=CLSPool(),
classifier=self.classifier,
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
vllm_config.model_config),
),
})
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self)

View File

@ -38,6 +38,13 @@ class PoolingMetadata:
f"seq_data={self.seq_data}, "
f"prompt_lens={self.prompt_lens})")
def __getitem__(self, indices: slice):
return PoolingMetadata(
seq_groups=self.seq_groups[indices],
seq_data=dict(list(self.seq_data.items())[indices]),
prompt_lens=self.prompt_lens[indices],
)
@dataclass
class PoolingTensors:

View File

@ -15,3 +15,11 @@ class PoolingMetadata:
prompt_lens: torch.Tensor
prompt_token_ids: Optional[torch.Tensor]
pooling_params: list[PoolingParams]
def __getitem__(self, indices: slice):
return PoolingMetadata(
prompt_lens=self.prompt_lens[indices],
prompt_token_ids=None if self.prompt_token_ids is None else
self.prompt_token_ids[indices],
pooling_params=self.pooling_params[indices],
)
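
Both `__getitem__` additions exist so `DispatchPooler` can hand each task group its own metadata window (`pooling_metadata[offset:offset + num_items]` in pooler.py above). A schematic check, assuming a populated V1 metadata object for a batch of five requests:

```python
sub = pooling_metadata[2:5]          # metadata for requests 2..4 only
assert len(sub.pooling_params) == 3  # all fields sliced consistently
```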

View File

@ -5,7 +5,7 @@ import copy
import gc
import time
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Optional, Union, cast, get_args
from typing import TYPE_CHECKING, Any, Optional, Union, cast
import numpy as np
import torch
@ -415,15 +415,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
generator = None
if pooling_params:
assert pooling_params.task is not None, (
assert (task := pooling_params.task) is not None, (
"You did not set `task` in the API")
model = cast(VllmModelForPooling, self.model)
to_update = (model.pooler.get_pooling_updates(
pooling_params.task))
assert to_update is not None, (
f"{pooling_params.task=} is not supported by the model")
to_update = model.pooler.get_pooling_updates(task)
to_update.apply(pooling_params)
self.requests[req_id] = CachedRequestState(
@ -1122,10 +1118,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if not is_pooling_model(model):
return []
return [
task for task in get_args(PoolingTask)
if model.pooler.get_pooling_updates(task)
]
return list(model.pooler.get_supported_tasks())
def apply_grammar_bitmask(
self,
@ -2247,7 +2240,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
dummy_pooling_params = PoolingParams(task=dummy_task)
to_update = model.pooler.get_pooling_updates(dummy_task)
assert to_update is not None
to_update.apply(dummy_pooling_params)
dummy_metadata = PoolingMetadata(

View File

@ -3,7 +3,7 @@
import bisect
import gc
import time
from typing import TYPE_CHECKING, Any, Optional, cast, get_args
from typing import TYPE_CHECKING, Any, Optional, cast
from unittest.mock import patch
import numpy as np
@ -491,10 +491,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
if not is_pooling_model(model):
return []
return [
task for task in get_args(PoolingTask)
if model.pooler.get_pooling_updates(task)
]
return list(model.pooler.get_supported_tasks())
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
"""

View File

@ -4,7 +4,7 @@
import dataclasses
from abc import ABC, abstractmethod
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
TypeVar, get_args)
TypeVar)
import torch
import torch.nn as nn
@ -230,10 +230,7 @@ class ModelRunnerBase(ABC, Generic[T]):
if not is_pooling_model(model):
return []
return [
task for task in get_args(PoolingTask)
if model.pooler.get_pooling_updates(task)
]
return list(model.pooler.get_supported_tasks())
def execute_model(
self,

View File

@ -199,15 +199,11 @@ class PoolingModelRunner(
pooling_params = seq_group_metadata.pooling_params
assert pooling_params is not None
assert pooling_params.task is not None, (
assert (task := pooling_params.task) is not None, (
"You did not set `task` in the API")
to_update = (cast(VllmModelForPooling,
self.model).pooler.get_pooling_updates(
pooling_params.task))
assert to_update is not None, (
f"{pooling_params.task=} is not supported by the model")
model = cast(VllmModelForPooling, self.model)
to_update = model.pooler.get_pooling_updates(task)
to_update.apply(pooling_params)
seq_groups.append((seq_ids, pooling_params))