Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[Model][1/N] Support multiple poolers at model level (#21227)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
@@ -11,26 +11,51 @@ before returning them.
 As shown in the [Compatibility Matrix](../features/compatibility_matrix.md), most vLLM features are not applicable to
 pooling models as they only work on the generation or decode stage, so performance may not improve as much.
 
-For pooling models, we support the following `--task` options.
-The selected option sets the default pooler used to extract the final hidden states:
+If the model doesn't implement this interface, you can set `--task` which tells vLLM
+to convert the model into a pooling model.
 
-| Task                            | Pooling Type | Normalization | Softmax |
-|---------------------------------|--------------|---------------|---------|
-| Embedding (`embed`)             | `LAST`       | ✅︎            | ❌      |
-| Classification (`classify`)     | `LAST`       | ❌            | ✅︎      |
-| Sentence Pair Scoring (`score`) | \*           | \*            | \*      |
+| `--task`   | Model type           | Supported pooling tasks       |
+|------------|----------------------|-------------------------------|
+| `embed`    | Embedding model      | `encode`, `embed`             |
+| `classify` | Classification model | `encode`, `classify`, `score` |
+| `reward`   | Reward model         | `encode`                      |
 
-\*The default pooler is always defined by the model.
+## Pooling Tasks
 
-!!! note
-    If the model's implementation in vLLM defines its own pooler, the default pooler is set to that instead of the one specified in this table.
+In vLLM, we define the following pooling tasks and corresponding APIs:
+
+| Task       | APIs               |
+|------------|--------------------|
+| `encode`   | `encode`           |
+| `embed`    | `embed`, `score`\* |
+| `classify` | `classify`         |
+| `score`    | `score`            |
+
+\*The `score` API falls back to `embed` task if the model does not support `score` task.
+
+Each pooling model in vLLM supports one or more of these tasks according to [Pooler.get_supported_tasks][vllm.model_executor.layers.Pooler.get_supported_tasks].
+
+By default, the pooler assigned to each task has the following attributes:
+
+| Task       | Pooling Type | Normalization | Softmax |
+|------------|--------------|---------------|---------|
+| `encode`   | `ALL`        | ❌            | ❌      |
+| `embed`    | `LAST`       | ✅︎            | ❌      |
+| `classify` | `LAST`       | ❌            | ✅︎      |
+
+These defaults may be overridden by the model's implementation in vLLM.
 
 When loading [Sentence Transformers](https://huggingface.co/sentence-transformers) models,
-we attempt to override the default pooler based on its Sentence Transformers configuration file (`modules.json`).
+we attempt to override the defaults based on its Sentence Transformers configuration file (`modules.json`),
+which takes priority over the model's defaults.
 
 !!! tip
-    You can customize the model's pooling method via the `--override-pooler-config` option,
-    which takes priority over both the model's and Sentence Transformers's defaults.
+    You can further customize this via the `--override-pooler-config` option,
+    which takes priority over both the model's and Sentence Transformers's defaults.
+
+!!! note
+    The above configuration may be disregarded if the model's implementation in vLLM defines its own pooler
+    that is not based on [PoolerConfig][vllm.config.PoolerConfig].
 
 ## Offline Inference
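Editor's note: to make the task/API mapping above concrete, here is a minimal offline sketch (not part of this commit). The model name is an arbitrary example, and `override_pooler_config` is the offline counterpart of the `--override-pooler-config` flag mentioned in the tip.

```python
from vllm import LLM
from vllm.config import PoolerConfig

# An embedding model with `task="embed"` registers the `encode` and
# `embed` pooling tasks shown in the table above.
llm = LLM(model="BAAI/bge-base-en-v1.5", task="embed")
(output,) = llm.embed("Hello, my name is")
print(len(output.outputs.embedding))  # size of the pooled vector

# Overrides take priority over both the model's defaults and any
# Sentence Transformers configuration (see the tip above).
llm_mean = LLM(
    model="BAAI/bge-base-en-v1.5",
    task="embed",
    override_pooler_config=PoolerConfig(pooling_type="MEAN"),
)
```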
@@ -144,7 +144,7 @@ def test_quantization(
     "model",
     ["jason9693/Qwen2.5-1.5B-apeach"],
 )
-@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("dtype", ["float"])
 def test_classify(
     hf_runner,
     vllm_runner,
@@ -8,7 +8,7 @@ import torch
 import torch.nn as nn
 
 from vllm.config import VllmConfig
-from vllm.model_executor.layers.pooler import Pooler, PoolingType
+from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
 from vllm.model_executor.models.gemma2 import Gemma2Model
 from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix
 from vllm.sequence import IntermediateTensors
@@ -26,12 +26,13 @@ class MyGemma2Embedding(nn.Module):
         self.model = Gemma2Model(vllm_config=vllm_config,
                                  prefix=maybe_prefix(prefix, "model"))
 
-        self.pooler = Pooler.from_config_with_defaults(
-            vllm_config.model_config.pooler_config,
-            pooling_type=PoolingType.LAST,
-            normalize=True,
-            softmax=False,
-        )
+        pooler_config = vllm_config.model_config.pooler_config
+        assert pooler_config is not None
+
+        self.pooler = DispatchPooler({
+            "encode": Pooler.for_encode(pooler_config),
+            "embed": Pooler.for_embed(pooler_config),
+        })
 
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
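Editor's note: a small, self-contained sketch (not from this commit) of the dispatch contract this test model now relies on. Tasks outside the registered mapping are rejected, mirroring the `ValueError` raised by `DispatchPooler`; the string values are hypothetical stand-ins for the real pooler objects.

```python
# Hypothetical stand-ins; the real objects come from
# Pooler.for_encode(...) and Pooler.for_embed(...).
poolers_by_task = {"encode": "encode-pooler", "embed": "embed-pooler"}

def dispatch(task: str) -> str:
    # Same lookup DispatchPooler.forward performs per task group.
    if not (pooler := poolers_by_task.get(task)):
        raise ValueError(f"Unsupported task: {task} "
                         f"Supported tasks: {set(poolers_by_task)}")
    return pooler

assert dispatch("embed") == "embed-pooler"
# dispatch("classify") would raise, since no classifier was registered.
```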
@@ -94,7 +94,7 @@ ConfigT = TypeVar("ConfigT", bound=ConfigType)
 TaskOption = Literal["auto", "generate", "embedding", "embed", "classify",
                      "score", "reward", "transcription", "draft"]
 
-_ResolvedTask = Literal["generate", "transcription", "pooling", "embed",
+_ResolvedTask = Literal["generate", "transcription", "encode", "embed",
                         "classify", "reward", "draft"]
 
 RunnerOption = Literal["auto", "generate", "pooling", "draft"]
@@ -103,7 +103,7 @@ RunnerType = Literal["generate", "pooling", "draft"]
 
 _RUNNER_TASKS: dict[RunnerType, list[_ResolvedTask]] = {
     "generate": ["generate", "transcription"],
-    "pooling": ["pooling", "embed", "classify", "reward"],
+    "pooling": ["encode", "embed", "classify", "reward"],
     "draft": [],
 }
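Editor's note: with `"pooling"` renamed to `"encode"`, `--task auto` on a pooling model picks the last entry of the model's supported-task list and is asserted never to be the bare `encode` task (see the `ModelConfig` hunk below). A toy version of that selection, under the assumption that the list is ordered from generic to specific:

```python
def resolve_auto_task(supported_pooling_tasks: list[str]) -> str:
    # Mirrors the ModelConfig hunk below: take the last supported task.
    selected = supported_pooling_tasks[-1]
    assert selected != "encode"  # "encode" alone is never auto-selected
    return selected

# A model registering ["encode", "embed"] resolves `--task auto` to "embed".
assert resolve_auto_task(["encode", "embed"]) == "embed"
```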
@@ -579,7 +579,7 @@ class ModelConfig:
             # user-selected task
             if runner_type == "pooling" and self.task == "auto":
                 selected_task = all_supported_tasks[runner_type][-1]
-                assert selected_task != "pooling"
+                assert selected_task != "encode"
                 self.task = selected_task
         self.supported_runner_types = supported_runner_types
         self.runner_type = runner_type
@@ -884,7 +884,7 @@ class ModelConfig:
 
         supported_tasks = list[_ResolvedTask]()
         if registry.is_pooling_model(architectures):
-            supported_tasks.append("pooling")
+            supported_tasks.append("encode")
 
             # For now, users must specify the task (other than "pooling")
             # to use for pooling models
@@ -1668,7 +1668,7 @@ async def init_app_state(
         request_logger=request_logger,
         chat_template=resolved_chat_template,
         chat_template_content_format=args.chat_template_content_format,
-    ) if "pooling" in model_config.supported_tasks else None
+    ) if "encode" in model_config.supported_tasks else None
     state.openai_serving_embedding = OpenAIServingEmbedding(
         engine_client,
         model_config,
@@ -1,15 +1,16 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from abc import ABC, abstractmethod
+from collections.abc import Mapping, Set
 from dataclasses import dataclass
 from enum import IntEnum
+from itertools import groupby
 from typing import Callable, Optional, TypeVar, Union
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from transformers import PretrainedConfig
-from typing_extensions import assert_never
 
 from vllm.config import ModelConfig, PoolerConfig
 from vllm.model_executor.pooling_metadata import (  # noqa: E501
@@ -21,6 +22,10 @@ from vllm.utils import resolve_obj_by_qualname
 from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata
 
 PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata]
+PoolingFn = Callable[
+    [Union[torch.Tensor, list[torch.Tensor]], PoolingMetadata],
+    Union[torch.Tensor, list[torch.Tensor]]]
+ClassifierFn = Callable[[torch.Tensor], torch.Tensor]
 
 
 class PoolingType(IntEnum):
@@ -79,37 +84,81 @@ class Pooler(nn.Module, ABC):
     """The interface required for all poolers used in pooling models in vLLM."""
 
     @staticmethod
-    def from_config_with_defaults(
+    def for_encode(
         pooler_config: PoolerConfig,
-        pooling_type: PoolingType,
-        normalize: bool,
-        softmax: bool,
-        step_tag_id: Optional[int] = None,
-        returned_token_ids: Optional[list[int]] = None,
-    ) -> "Pooler":
+        *,
+        default_pooling_type: PoolingType = PoolingType.ALL,
+        default_normalize: bool = False,
+        default_softmax: bool = False,
+        default_step_tag_id: Optional[int] = None,
+        default_returned_token_ids: Optional[list[int]] = None,
+    ):
         resolved_config = ResolvedPoolingConfig.from_config_with_defaults(
             pooler_config=pooler_config,
-            pooling_type=pooling_type,
-            normalize=normalize,
-            softmax=softmax,
-            step_tag_id=step_tag_id,
-            returned_token_ids=returned_token_ids,
+            pooling_type=default_pooling_type,
+            normalize=default_normalize,
+            softmax=default_softmax,
+            step_tag_id=default_step_tag_id,
+            returned_token_ids=default_returned_token_ids,
         )
 
-        if pooling_type == PoolingType.STEP:
+        if resolved_config.pooling_type == PoolingType.STEP:
             return StepPooler.from_config(resolved_config)
 
         return SimplePooler.from_config(resolved_config)
 
-    def get_pooling_updates(
-        self,
-        task: PoolingTask,
-    ) -> Optional[PoolingParamsUpdate]:
-        """
-        Construct the pooling parameters to use for a task,
-        or `None` if the task is not supported.
-        """
-        return None
+    @staticmethod
+    def for_embed(
+        pooler_config: PoolerConfig,
+        *,
+        default_pooling_type: PoolingType = PoolingType.LAST,
+        default_normalize: bool = True,
+        default_softmax: bool = False,
+    ):
+        resolved_config = ResolvedPoolingConfig.from_config_with_defaults(
+            pooler_config=pooler_config,
+            pooling_type=default_pooling_type,
+            normalize=default_normalize,
+            softmax=default_softmax,
+        )
+
+        return SimplePooler.from_config(resolved_config)
+
+    @staticmethod
+    def for_classify(
+        pooler_config: PoolerConfig,
+        classifier: Optional[ClassifierFn],
+        *,
+        default_pooling_type: PoolingType = PoolingType.LAST,
+        default_normalize: bool = False,
+        default_softmax: bool = True,
+    ):
+        resolved_config = ResolvedPoolingConfig.from_config_with_defaults(
+            pooler_config=pooler_config,
+            pooling_type=default_pooling_type,
+            normalize=default_normalize,
+            softmax=default_softmax,
+        )
+        base_pooler = SimplePooler.from_config(resolved_config)
+        if classifier is None:
+            return base_pooler
+
+        return ClassifierPooler(
+            pooling=base_pooler.pooling,
+            classifier=classifier,
+            act_fn=base_pooler.head.activation,
+        )
+
+    @abstractmethod
+    def get_supported_tasks(self) -> Set[PoolingTask]:
+        """Determine which pooling tasks are supported."""
+        raise NotImplementedError
+
+    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
+        """
+        Construct the updated pooling parameters to use for a supported task.
+        """
+        return PoolingParamsUpdate()
 
     @abstractmethod
     def forward(
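Editor's note: the `default_*` parameter names make the precedence explicit. Fields set in the user's `PoolerConfig` win, and the defaults only fill unset fields. A self-contained sketch of that resolution rule (hypothetical helper; the real logic lives in `ResolvedPoolingConfig.from_config_with_defaults`):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class PoolerConfigSketch:
    # Mirrors the PoolerConfig fields relevant here; None means "unset".
    pooling_type: Optional[str] = None
    normalize: Optional[bool] = None
    softmax: Optional[bool] = None

def resolve(config: PoolerConfigSketch, *, default_pooling_type: str,
            default_normalize: bool, default_softmax: bool) -> tuple:
    # User-provided values take priority; defaults fill the gaps.
    return (
        default_pooling_type if config.pooling_type is None else config.pooling_type,
        default_normalize if config.normalize is None else config.normalize,
        default_softmax if config.softmax is None else config.softmax,
    )

# An empty config falls back entirely to the per-task defaults...
assert resolve(PoolerConfigSketch(), default_pooling_type="LAST",
               default_normalize=True, default_softmax=False) == ("LAST", True, False)
# ...while an explicit override (e.g. --override-pooler-config) wins.
assert resolve(PoolerConfigSketch(pooling_type="MEAN"), default_pooling_type="LAST",
               default_normalize=True, default_softmax=False) == ("MEAN", True, False)
```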
@@ -127,9 +176,8 @@ def get_prompt_lens(
     if isinstance(pooling_metadata, V1PoolingMetadata):
         return pooling_metadata.prompt_lens
 
-    assert isinstance(hidden_states, torch.Tensor)
     return PoolingTensors.from_pooling_metadata(
-        pooling_metadata, hidden_states.device).prompt_lens
+        pooling_metadata, hidden_states[0].device).prompt_lens
 
 
 def get_prompt_token_ids(
@@ -149,6 +197,21 @@ def get_prompt_token_ids(
     ]
 
 
+def get_tasks(pooling_metadata: PoolingMetadata) -> list[PoolingTask]:
+    if isinstance(pooling_metadata, V0PoolingMetadata):
+        pooling_params = [p for _, p in pooling_metadata.seq_groups]
+    else:
+        pooling_params = pooling_metadata.pooling_params
+
+    tasks: list[PoolingTask] = [
+        task for pooling_param in pooling_params
+        if (task := pooling_param.task) is not None
+    ]
+    assert len(pooling_params) == len(tasks)
+
+    return tasks
+
+
 def get_classification_activation_function(config: PretrainedConfig):
     return PoolerClassify()
@@ -172,7 +235,8 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
     return PoolerScore()
 
 
-def build_output(all_data: torch.Tensor) -> PoolerOutput:
+def build_output(
+        all_data: Union[torch.Tensor, list[torch.Tensor]], ) -> PoolerOutput:
     all_outputs = [PoolingSequenceGroupOutput(data) for data in all_data]
     return PoolerOutput(outputs=all_outputs)
@@ -193,12 +257,12 @@ class PoolingMethod(nn.Module, ABC):
         raise NotImplementedError(f"Unsupported method: {pooling_type}")
 
     @abstractmethod
-    def get_pooling_updates(
-        self,
-        task: PoolingTask,
-    ) -> Optional[PoolingParamsUpdate]:
+    def get_supported_tasks(self) -> Set[PoolingTask]:
         raise NotImplementedError
 
+    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
+        return PoolingParamsUpdate()
+
     @abstractmethod
     def forward_one(
         self,
@@ -237,16 +301,8 @@ class PoolingMethod(nn.Module, ABC):
 
 class CLSPool(PoolingMethod):
 
-    def get_pooling_updates(
-        self,
-        task: PoolingTask,
-    ) -> Optional[PoolingParamsUpdate]:
-        # The equalities are split up to keep mypy happy
-        if (task == "encode" or task == "embed" or task == "classify"
-                or task == "score"):
-            return PoolingParamsUpdate()
-
-        assert_never(task)
+    def get_supported_tasks(self) -> Set[PoolingTask]:
+        return {"encode", "embed", "classify", "score"}
 
     def forward_one(
         self,
@@ -270,16 +326,8 @@ class CLSPool(PoolingMethod):
 
 class LastPool(PoolingMethod):
 
-    def get_pooling_updates(
-        self,
-        task: PoolingTask,
-    ) -> Optional[PoolingParamsUpdate]:
-        # The equalities are split up to keep mypy happy
-        if (task == "encode" or task == "embed" or task == "classify"
-                or task == "score"):
-            return PoolingParamsUpdate()
-
-        assert_never(task)
+    def get_supported_tasks(self) -> Set[PoolingTask]:
+        return {"encode", "embed", "classify", "score"}
 
     def forward_one(
         self,
@@ -299,18 +347,8 @@ class LastPool(PoolingMethod):
 
 class AllPool(PoolingMethod):
 
-    def get_pooling_updates(
-        self,
-        task: PoolingTask,
-    ) -> Optional[PoolingParamsUpdate]:
-        if task == "encode":
-            return PoolingParamsUpdate()
-
-        # The equalities are split up to keep mypy happy
-        if task == "embed" or task == "classify" or task == "score":
-            return None
-
-        assert_never(task)
+    def get_supported_tasks(self) -> Set[PoolingTask]:
+        return {"encode"}
 
     def forward_one(
         self,
@@ -327,28 +365,13 @@ class AllPool(PoolingMethod):
         hidden_states: torch.Tensor,
         prompt_lens: torch.Tensor,
     ) -> Union[list[torch.Tensor], torch.Tensor]:
-        offset = 0
-        pooled_data = list[torch.Tensor]()
-
-        for prompt_len in prompt_lens:
-            pooled_data.append(hidden_states[offset:offset + prompt_len])
-            offset += prompt_len
-
-        return pooled_data
+        return list(hidden_states.split_with_sizes(prompt_lens.tolist()))
 
 
 class MeanPool(PoolingMethod):
 
-    def get_pooling_updates(
-        self,
-        task: PoolingTask,
-    ) -> Optional[PoolingParamsUpdate]:
-        # The equalities are split up to keep mypy happy
-        if (task == "encode" or task == "embed" or task == "classify"
-                or task == "score"):
-            return PoolingParamsUpdate()
-
-        assert_never(task)
+    def get_supported_tasks(self) -> Set[PoolingTask]:
+        return {"encode", "embed", "classify", "score"}
 
     def forward_one(
         self,
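Editor's note: the rewritten `forward_all` replaces the manual offset loop with `torch.Tensor.split_with_sizes`, which yields the same per-prompt views over the flattened token batch. A quick equivalence check (illustrative only, with made-up sizes):

```python
import torch

hidden_states = torch.arange(18).reshape(6, 3)  # 6 tokens, hidden size 3
prompt_lens = torch.tensor([2, 1, 3])

# Old loop
offset, pooled_loop = 0, []
for prompt_len in prompt_lens:
    pooled_loop.append(hidden_states[offset:offset + prompt_len])
    offset += prompt_len

# New one-liner
pooled_split = list(hidden_states.split_with_sizes(prompt_lens.tolist()))

assert all(torch.equal(a, b) for a, b in zip(pooled_loop, pooled_split))
```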
@@ -529,24 +552,6 @@ class SimplePooler(Pooler):
     3. Returns structured results as `PoolerOutput`.
     """
 
-    @classmethod
-    def from_config_with_defaults(  # type: ignore[override]
-        cls,
-        pooler_config: PoolerConfig,
-        pooling_type: PoolingType,
-        normalize: bool,
-        softmax: bool,
-    ) -> "SimplePooler":
-        resolved_config = ResolvedPoolingConfig.from_config_with_defaults(
-            pooler_config=pooler_config,
-            pooling_type=pooling_type,
-            normalize=normalize,
-            softmax=softmax,
-        )
-        assert resolved_config.pooling_type != PoolingType.STEP
-
-        return cls.from_config(resolved_config)
-
     @classmethod
     def from_config(
         cls,
@@ -563,10 +568,10 @@ class SimplePooler(Pooler):
         self.pooling = pooling
         self.head = head
 
-    def get_pooling_updates(
-        self,
-        task: PoolingTask,
-    ) -> Optional[PoolingParamsUpdate]:
+    def get_supported_tasks(self) -> Set[PoolingTask]:
+        return self.pooling.get_supported_tasks()
+
+    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
         return self.pooling.get_pooling_updates(task)
 
     def forward(
@@ -627,18 +632,11 @@ class StepPooler(Pooler):
 
         return pooled_data
 
-    def get_pooling_updates(
-        self,
-        task: PoolingTask,
-    ) -> Optional[PoolingParamsUpdate]:
-        if task == "encode":
-            return PoolingParamsUpdate(requires_token_ids=True)
-
-        # The equalities are split up to keep mypy happy
-        if task == "embed" or task == "classify" or task == "score":
-            return None
-
-        assert_never(task)
+    def get_supported_tasks(self) -> Set[PoolingTask]:
+        return {"encode"}
+
+    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
+        return PoolingParamsUpdate(requires_token_ids=True)
 
     def forward(
         self,
@@ -650,68 +648,43 @@ class StepPooler(Pooler):
         return build_output(pooled_data)
 
 
-PoolingFn = Callable[
-    [Union[torch.Tensor, list[torch.Tensor]], PoolingMetadata],
-    Union[torch.Tensor, list[torch.Tensor]]]
-ClassifierFn = Callable[[torch.Tensor], torch.Tensor]
-
-
-class ClassifierPooler(nn.Module):
+class ClassifierPooler(Pooler):
     """A pooling layer for classification tasks.
 
     This layer does the following:
     1. Applies a classification layer to the hidden states.
     2. Optionally applies a pooler layer.
-    3. Applies an activation function to the output. In the case of
-       classification models it is either sigmoid or softmax. In the
-       case of scoring models, the same behavior is configuration
-       dependent, as in the sentence-transformers library.
+    3. Applies an activation function to the output.
     """
 
+    @staticmethod
+    def act_fn_for_seq_cls(config: ModelConfig):
+        return get_classification_activation_function(config.hf_config)
+
+    @staticmethod
+    def act_fn_for_cross_encoder(config: ModelConfig):
+        return get_cross_encoder_activation_function(config.hf_config)
+
     def __init__(
         self,
-        config: ModelConfig,
         pooling: PoolingFn,
         classifier: ClassifierFn,
-        act_fn: Optional[PoolerActivation] = None,
+        act_fn: PoolerActivation,
     ) -> None:
         super().__init__()
 
         self.pooling = pooling
         self.classifier = classifier
-
-        self.classification_act_fn = get_classification_activation_function(
-            config.hf_config) if act_fn is None else act_fn
-        self.cross_encoder_act_fn = get_cross_encoder_activation_function(
-            config.hf_config) if act_fn is None else act_fn
-
-    def _get_act_fn(self, task: PoolingTask):
-        if task == "encode" or task == "classify":
-            return self.classification_act_fn
-        if task == "score":
-            return self.cross_encoder_act_fn
-
-        raise ValueError(f"Unsupported task: {task!r}")
-
-    def get_pooling_updates(
-        self,
-        task: PoolingTask,
-    ) -> Optional[PoolingParamsUpdate]:
-        # The equalities are split up to keep mypy happy
-        if task == "encode" or task == "classify" or task == "score":
-            return PoolingParamsUpdate()
-
-        if task == "embed":
-            return None
-
-        assert_never(task)
+        self.act_fn = act_fn
+
+    def get_supported_tasks(self) -> Set[PoolingTask]:
+        return {"classify", "score"}
 
     def forward(
         self,
         hidden_states: Union[torch.Tensor, list[torch.Tensor]],
         pooling_metadata: PoolingMetadata,
     ) -> PoolerOutput:
         """Pools sentence pair scores from the hidden_states."""
         pooled_data = self.pooling(hidden_states, pooling_metadata)
 
         # apply classifier once on the full batch if possible
@@ -722,28 +695,59 @@ class ClassifierPooler(nn.Module):
         else:
             pooled_output = [self.classifier(data) for data in pooled_data]
 
-        task_list: list[PoolingTask]
-        if isinstance(pooling_metadata, V0PoolingMetadata):
-            task_list = [
-                task for _, pooling_param in pooling_metadata.seq_groups
-                if (task := pooling_param.task) is not None
-            ]
-        else:
-            task_list = [
-                task for pooling_param in pooling_metadata.pooling_params
-                if (task := pooling_param.task) is not None
-            ]
-
-        assert len(task_list) == len(pooled_output)
-
         # shape of scores: (batch_size, num_labels)
-        if len(set(task_list)) <= 1:
-            act_fn = self._get_act_fn(task_list[0])
-            scores = act_fn(pooled_output)
-        else:
-            scores = torch.stack([
-                self._get_act_fn(task)(vecs)
-                for task, vecs in zip(task_list, pooled_output)
-            ])
+        scores = self.act_fn(pooled_output)
 
         return build_output(scores)
 
 
+class DispatchPooler(Pooler):
+    """Dispatches calls to a sub-pooler based on the pooling task."""
+
+    def __init__(self, poolers_by_task: Mapping[PoolingTask, Pooler]) -> None:
+        super().__init__()
+
+        for task, pooler in poolers_by_task.items():
+            if task not in pooler.get_supported_tasks():
+                raise ValueError(
+                    f"{pooler=} does not support {task=}. "
+                    f"Supported tasks: {pooler.get_supported_tasks()}")
+
+        self.poolers_by_task = poolers_by_task
+
+    def get_supported_tasks(self) -> Set[PoolingTask]:
+        return set(self.poolers_by_task)
+
+    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
+        return self.poolers_by_task[task].get_pooling_updates(task)
+
+    def forward(
+        self,
+        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
+        pooling_metadata: PoolingMetadata,
+    ) -> PoolerOutput:
+        poolers_by_task = self.poolers_by_task
+
+        if isinstance(hidden_states, list):
+            hidden_states_lst = hidden_states
+        else:
+            prompt_lens = get_prompt_lens(hidden_states, pooling_metadata)
+            hidden_states_lst = list(hidden_states.split(prompt_lens.tolist()))
+
+        outputs = list[PoolingSequenceGroupOutput]()
+        offset = 0
+        for task, group in groupby(get_tasks(pooling_metadata)):
+            if not (pooler := poolers_by_task.get(task)):
+                raise ValueError(
+                    f"Unsupported task: {task} "
+                    f"Supported tasks: {self.get_supported_tasks()}")
+
+            num_items = len(list(group))
+            group_output: PoolerOutput = pooler(
+                hidden_states_lst[offset:offset + num_items],
+                pooling_metadata[offset:offset + num_items],
+            )
+
+            outputs.extend(group_output.outputs)
+            offset += num_items
+
+        return PoolerOutput(outputs)
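Editor's note: `DispatchPooler.forward` relies on `itertools.groupby`, which only merges adjacent equal keys, so the caller is expected to hand over requests already grouped by task. A self-contained illustration of that contract:

```python
from itertools import groupby

tasks = ["encode", "encode", "embed", "embed", "embed"]

offset = 0
for task, group in groupby(tasks):
    num_items = len(list(group))
    # Each sub-pooler would receive hidden_states_lst[offset:offset+num_items]
    # and pooling_metadata[offset:offset+num_items] here.
    print(task, list(range(offset, offset + num_items)))
    offset += num_items

# If tasks were interleaved, e.g. ["embed", "encode", "embed"], groupby
# would yield three runs and the "embed" sub-pooler would be called twice.
```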
@@ -13,7 +13,6 @@ from .interfaces_base import VllmModelForPooling, is_pooling_model
 
 if TYPE_CHECKING:
     from vllm.config import VllmConfig
-    from vllm.model_executor.layers.pooler import PoolingType
 
 _T = TypeVar("_T", bound=type[nn.Module])
@@ -34,16 +33,8 @@ def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str:
     return model_name + pooling_suffix
 
 
-def _create_pooling_model_cls(
-    orig_cls: _T,
-    *,
-    default_pooling_type: "PoolingType",
-    default_normalize: bool,
-    default_softmax: bool,
-) -> _T:
+def _create_pooling_model_cls(orig_cls: _T) -> _T:
     # Lazy import
-    from vllm.model_executor.layers.pooler import Pooler
-
     from .utils import AutoWeightsLoader, WeightsMapper
 
     class ModelForPooling(orig_cls, VllmModelForPooling):
@@ -71,15 +62,7 @@ def _create_pooling_model_cls(
             self._init_pooler(vllm_config, prefix=prefix)
 
         def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
-            pooler_config = vllm_config.model_config.pooler_config
-            assert pooler_config is not None
-
-            self.pooler = Pooler.from_config_with_defaults(
-                pooler_config,
-                pooling_type=default_pooling_type,
-                normalize=default_normalize,
-                softmax=default_softmax,
-            )
+            raise NotImplementedError
 
         def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
             # TODO: Support uninitialized params tracking
@@ -132,14 +115,20 @@ def as_embedding_model(cls: _T) -> _T:
         return cls
 
     # Lazy import
-    from vllm.model_executor.layers.pooler import PoolingType
+    from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
 
-    ModelForEmbedding = _create_pooling_model_cls(
-        cls,
-        default_pooling_type=PoolingType.LAST,
-        default_normalize=True,
-        default_softmax=False,
-    )
+    class ModelForEmbedding(_create_pooling_model_cls(cls)):
+
+        def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
+            pooler_config = vllm_config.model_config.pooler_config
+            assert pooler_config is not None
+
+            self.pooler = DispatchPooler(
+                {
+                    "encode": Pooler.for_encode(pooler_config),
+                    "embed": Pooler.for_embed(pooler_config),
+                }, )
+
     ModelForEmbedding.__name__ = \
         _get_pooling_model_name(cls.__name__, "ForEmbedding")
@@ -165,20 +154,14 @@ def as_seq_cls_model(cls: _T) -> _T:
     # Lazy import
     from vllm.model_executor.layers.linear import RowParallelLinear
     from vllm.model_executor.layers.pooler import (ClassifierPooler,
-                                                   PoolingType, SimplePooler)
+                                                   DispatchPooler, Pooler,
+                                                   PoolingMethod, PoolingType)
     from vllm.model_executor.models.interfaces import SupportsCrossEncoding
     from vllm.sequence import IntermediateTensors
 
     from .utils import maybe_prefix
 
-    ModelForPooling = _create_pooling_model_cls(
-        cls,
-        default_pooling_type=PoolingType.LAST,
-        default_normalize=False,
-        default_softmax=True,
-    )
-
-    class ModelForSequenceClassification(ModelForPooling,
+    class ModelForSequenceClassification(_create_pooling_model_cls(cls),
                                          SupportsCrossEncoding):
 
         def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
@@ -198,19 +181,28 @@ def as_seq_cls_model(cls: _T) -> _T:
             pooler_config = vllm_config.model_config.pooler_config
             assert pooler_config is not None
 
-            pooler = SimplePooler.from_config_with_defaults(
-                pooler_config,
-                pooling_type=PoolingType.LAST,
-                normalize=False,
-                softmax=True,
-            )
-
-            self.pooler = ClassifierPooler(
-                vllm_config.model_config,
-                pooling=pooler.pooling,
-                classifier=self._classifier,
-                act_fn=pooler.head.activation,
-            )
+            pooling_type_str = pooler_config.pooling_type
+            pooling_type = (PoolingType.LAST if pooling_type_str is None else
+                            PoolingType[pooling_type_str])
+
+            self.pooler = DispatchPooler({
+                "encode":
+                Pooler.for_encode(pooler_config),
+                "classify":
+                ClassifierPooler(
+                    pooling=PoolingMethod.from_pooling_type(pooling_type),
+                    classifier=self._classifier,
+                    act_fn=ClassifierPooler.act_fn_for_seq_cls(
+                        vllm_config.model_config),
+                ),
+                "score":
+                ClassifierPooler(
+                    pooling=PoolingMethod.from_pooling_type(pooling_type),
+                    classifier=self._classifier,
+                    act_fn=ClassifierPooler.act_fn_for_cross_encoder(
+                        vllm_config.model_config),
+                ),
+            })
 
         def _classifier(self, x: torch.Tensor):
             x, _ = self.score(x.float())
@@ -259,14 +251,16 @@ def as_reward_model(cls: _T) -> _T:
         return cls
 
    # Lazy import
-    from vllm.model_executor.layers.pooler import PoolingType
+    from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
 
-    ModelForReward = _create_pooling_model_cls(
-        cls,
-        default_pooling_type=PoolingType.ALL,
-        default_normalize=False,
-        default_softmax=False,
-    )
+    class ModelForReward(_create_pooling_model_cls(cls)):
+
+        def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
+            pooler_config = vllm_config.model_config.pooler_config
+            assert pooler_config is not None
+
+            self.pooler = DispatchPooler(
+                {"encode": Pooler.for_encode(pooler_config)}, )
 
     ModelForReward.__name__ = \
         _get_pooling_model_name(cls.__name__, "ForReward")
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from collections.abc import Iterable
+from collections.abc import Iterable, Set
 from typing import Optional, Union
 
 import torch
@@ -17,7 +17,8 @@ from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
-from vllm.model_executor.layers.pooler import (ClassifierPooler, Pooler,
+from vllm.model_executor.layers.pooler import (ClassifierPooler,
+                                               DispatchPooler, Pooler,
                                                PoolingMethod,
                                                PoolingParamsUpdate,
                                                PoolingType)
@@ -92,20 +93,29 @@ class BertPooler(Pooler):
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
         self.activation = nn.Tanh()
 
-    def get_pooling_updates(
-        self,
-        task: PoolingTask,
-    ) -> Optional[PoolingParamsUpdate]:
+    def get_supported_tasks(self) -> Set[PoolingTask]:
+        return self.pooling.get_supported_tasks()
+
+    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
         return self.pooling.get_pooling_updates(task)
 
+    def _head(self, pooled_output: torch.Tensor):
+        pooled_output = self.dense(pooled_output)
+        pooled_output = self.activation(pooled_output)
+        return pooled_output
+
     def forward(
         self,
         hidden_states: Union[torch.Tensor, list[torch.Tensor]],
         pooling_metadata: PoolingMetadata,
     ) -> Union[torch.Tensor, list[torch.Tensor]]:
         pooled_output = self.pooling(hidden_states, pooling_metadata)
-        pooled_output = self.dense(pooled_output)
-        pooled_output = self.activation(pooled_output)
+
+        if isinstance(pooled_output, list):
+            pooled_output = [self._head(output) for output in pooled_output]
+        else:
+            pooled_output = self._head(pooled_output)
 
         return pooled_output
@@ -333,18 +343,19 @@ class BertModel(nn.Module, SupportsQuant):
 
     packed_modules_mapping = {"qkv_proj": ["query", "key", "value"]}
 
-    def __init__(self,
-                 *,
-                 vllm_config: VllmConfig,
-                 prefix: str = "",
-                 embedding_class: type = BertEmbedding,
-                 add_pooling_layer: bool = False):
+    def __init__(
+        self,
+        *,
+        vllm_config: VllmConfig,
+        prefix: str = "",
+        embedding_class: type[nn.Module] = BertEmbedding,
+    ) -> None:
         super().__init__()
 
         config = vllm_config.model_config.hf_config
         self.embeddings = embedding_class(config)
         self.encoder = BertEncoder(vllm_config=vllm_config,
                                    prefix=f"{prefix}.encoder")
-        self.pooler = BertPooler(config) if add_pooling_layer else None
 
     def forward(
         self,
@@ -366,8 +377,7 @@ class BertModel(nn.Module, SupportsQuant):
                                           token_type_ids=token_type_ids)
         return self.encoder(hidden_states)
 
-    def load_weights(self, weights: Iterable[tuple[str,
-                                                   torch.Tensor]]) -> set[str]:
+    def _load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "query", "q"),
@@ -395,10 +405,43 @@ class BertModel(nn.Module, SupportsQuant):
             if name in params_dict:
                 other_weights.append((name, loaded_weight))
 
-        loader = AutoWeightsLoader(
-            self,
-            skip_prefixes=(["pooler."] if self.pooler is None else []),
+        return other_weights, loaded_stacked_params
+
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        other_weights, loaded_stacked_params = self._load_weights(weights)
+
+        loader = AutoWeightsLoader(self, skip_prefixes=["pooler."])
         loaded_params = loader.load_weights(other_weights)
         loaded_params.update(loaded_stacked_params)
         return loaded_params
+
+
+class BertPoolingModel(BertModel):
+
+    is_pooling_model = True
+
+    def __init__(
+        self,
+        *,
+        vllm_config: VllmConfig,
+        prefix: str = "",
+        embedding_class: type[nn.Module] = BertEmbedding,
+    ) -> None:
+        super().__init__(
+            vllm_config=vllm_config,
+            prefix=prefix,
+            embedding_class=embedding_class,
+        )
+
+        config = vllm_config.model_config.hf_config
+        self.pooler = BertPooler(config)
+
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        other_weights, loaded_stacked_params = self._load_weights(weights)
+
+        loader = AutoWeightsLoader(self)
+        loaded_params = loader.load_weights(other_weights)
+        loaded_params.update(loaded_stacked_params)
+        return loaded_params
@@ -421,6 +464,8 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
         super().__init__()
 
         pooler_config = vllm_config.model_config.pooler_config
+        assert pooler_config is not None
+
         self.model = self._build_model(vllm_config=vllm_config,
                                        prefix=maybe_prefix(prefix, "model"))
         self.pooler = self._build_pooler(pooler_config)
@@ -456,10 +501,15 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
                          embedding_class=BertEmbedding)
 
     def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
-        return Pooler.from_config_with_defaults(pooler_config,
-                                                pooling_type=PoolingType.CLS,
-                                                normalize=True,
-                                                softmax=False)
+        return DispatchPooler({
+            "encode":
+            Pooler.for_encode(pooler_config),
+            "embed":
+            Pooler.for_embed(
+                pooler_config,
+                default_pooling_type=PoolingType.CLS,
+            ),
+        })
class BertForSequenceClassification(nn.Module, SupportsV0Only,
|
||||
@ -481,16 +531,32 @@ class BertForSequenceClassification(nn.Module, SupportsV0Only,
|
||||
config = vllm_config.model_config.hf_config
|
||||
|
||||
self.num_labels = config.num_labels
|
||||
self.bert = BertModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "bert"),
|
||||
embedding_class=BertEmbedding,
|
||||
add_pooling_layer=True)
|
||||
self.bert = BertPoolingModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "bert"),
|
||||
embedding_class=BertEmbedding)
|
||||
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
||||
self.pooler = ClassifierPooler(
|
||||
vllm_config.model_config,
|
||||
pooling=self.bert.pooler,
|
||||
classifier=self.classifier,
|
||||
)
|
||||
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
|
||||
self.pooler = DispatchPooler({
|
||||
"encode":
|
||||
Pooler.for_encode(pooler_config),
|
||||
"classify":
|
||||
ClassifierPooler(
|
||||
pooling=self.bert.pooler,
|
||||
classifier=self.classifier,
|
||||
act_fn=ClassifierPooler.act_fn_for_seq_cls(
|
||||
vllm_config.model_config),
|
||||
),
|
||||
"score":
|
||||
ClassifierPooler(
|
||||
pooling=self.bert.pooler,
|
||||
classifier=self.classifier,
|
||||
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
|
||||
vllm_config.model_config),
|
||||
),
|
||||
})
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
loader = AutoWeightsLoader(self)
|
||||
|
@@ -43,7 +43,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
-from ..layers.pooler import Pooler, PoolingType
+from ..layers.pooler import DispatchPooler, Pooler
 from .interfaces import SupportsPP
 from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
@@ -339,12 +339,16 @@ class GPT2ForSequenceClassification(nn.Module):
         self.transformer = GPT2Model(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "gpt2"))
         self.score = nn.Linear(config.n_embd, config.num_labels, bias=False)
+
         pooler_config = vllm_config.model_config.pooler_config
-        self.pooler = Pooler.from_config_with_defaults(
-            pooler_config,
-            pooling_type=PoolingType.LAST,
-            normalize=False,
-            softmax=True)
+        assert pooler_config is not None
+
+        self.pooler = DispatchPooler({
+            "encode":
+            Pooler.for_encode(pooler_config),
+            "classify":
+            Pooler.for_classify(pooler_config, classifier=None),
+        })
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         loader = AutoWeightsLoader(self)
@@ -1,17 +1,16 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+from collections.abc import Set
 from typing import Optional, Union
 
 import numpy as np
 import torch
 import torch.nn as nn
-from typing_extensions import assert_never
 
 from vllm.config import ModelConfig, VllmConfig
 from vllm.logger import init_logger
-from vllm.model_executor.layers.pooler import (Pooler, PoolerHead,
-                                               PoolerNormalize,
+from vllm.model_executor.layers.pooler import (DispatchPooler, Pooler,
+                                               PoolerHead, PoolerNormalize,
                                                PoolingParamsUpdate,
                                                build_output, get_prompt_lens,
                                                get_prompt_token_ids)
@@ -135,18 +134,11 @@ class GritLMMeanPool(nn.Module):
 
         return instruction_len
 
-    def get_pooling_updates(
-        self,
-        task: PoolingTask,
-    ) -> Optional[PoolingParamsUpdate]:
-        # The equalities are split up to keep mypy happy
-        if task == "encode" or task == "embed":
-            return PoolingParamsUpdate(requires_token_ids=True)
-
-        if task == "classify" or task == "score":
-            return None
-
-        assert_never(task)
+    def get_supported_tasks(self) -> Set[PoolingTask]:
+        return {"encode", "embed"}
+
+    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
+        return PoolingParamsUpdate(requires_token_ids=True)
 
     def forward_one(
         self,
@@ -207,10 +199,10 @@ class GritLMPooler(Pooler):
         self.pooling = GritLMMeanPool(model_config)
         self.head = PoolerHead(PoolerNormalize())
 
-    def get_pooling_updates(
-        self,
-        task: PoolingTask,
-    ) -> Optional[PoolingParamsUpdate]:
+    def get_supported_tasks(self) -> Set[PoolingTask]:
+        return self.pooling.get_supported_tasks()
+
+    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
         return self.pooling.get_pooling_updates(task)
 
     def forward(
@@ -262,4 +254,11 @@ class GritLM(LlamaForCausalLM, SupportsV0Only):
 
         super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
 
-        self.pooler = GritLMPooler(vllm_config.model_config)
+        pooler_config = vllm_config.model_config.pooler_config
+        if pooler_config is not None:
+            self.pooler = DispatchPooler({
+                "encode":
+                Pooler.for_encode(pooler_config),
+                "embed":
+                GritLMPooler(vllm_config.model_config),
+            })
@@ -22,7 +22,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.pooler import Pooler, PoolingType
+from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -429,12 +429,10 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
         )
 
         pooler_config = vllm_config.model_config.pooler_config
-        self.pooler = Pooler.from_config_with_defaults(
-            pooler_config,
-            pooling_type=PoolingType.ALL,
-            normalize=False,
-            softmax=False,
-        )
+        assert pooler_config is not None
+
+        self.pooler = DispatchPooler(
+            {"encode": Pooler.for_encode(pooler_config)}, )
 
     def forward(
         self,
@@ -19,8 +19,8 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
-from vllm.model_executor.layers.pooler import (ClassifierPooler, PoolingType,
-                                               SimplePooler)
+from vllm.model_executor.layers.pooler import (DispatchPooler, Pooler,
+                                               PoolingType)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
@@ -584,16 +584,15 @@ class JambaForSequenceClassification(JambaForCausalLM):
         pooler_config = vllm_config.model_config.pooler_config
         assert pooler_config is not None
 
-        pooler = SimplePooler.from_config_with_defaults(
-            pooler_config,
-            pooling_type=PoolingType.LAST,
-            normalize=False,
-            softmax=False,
-        )
-
-        self.pooler = ClassifierPooler(
-            vllm_config.model_config,
-            pooling=pooler.pooling,
-            classifier=self.score,
-            act_fn=pooler.head.activation,
-        )
+        self.pooler = DispatchPooler({
+            "encode":
+            Pooler.for_encode(pooler_config),
+            "classify":
+            Pooler.for_classify(
+                pooler_config,
+                classifier=self.score,
+                default_pooling_type=PoolingType.LAST,
+                default_normalize=False,
+                default_softmax=False,
+            ),
+        })
@@ -12,7 +12,7 @@ from vllm.inputs import TokensPrompt
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
-from vllm.model_executor.layers.pooler import Pooler, PoolingType
+from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.sequence import IntermediateTensors
 
@@ -96,11 +96,17 @@ class JinaVLForSequenceClassification(Qwen2VLForConditionalGeneration,
 
         self.score = JinaVLScorer(config)
 
-        self.pooler = Pooler.from_config_with_defaults(
-            pooler_config,
-            pooling_type=PoolingType.LAST,
-            normalize=False,
-            softmax=True)
+        pooler_config = vllm_config.model_config.pooler_config
+        assert pooler_config is not None
+
+        self.pooler = DispatchPooler({
+            "encode":
+            Pooler.for_encode(pooler_config),
+            "classify":
+            Pooler.for_classify(pooler_config, classifier=None),
+            "score":
+            Pooler.for_classify(pooler_config, classifier=None),
+        })
 
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from collections.abc import Iterable
+from collections.abc import Iterable, Set
 from typing import Optional, Union
 
 import torch
@@ -13,7 +13,8 @@ from vllm.config import VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.linear import (QKVParallelLinear,
                                                RowParallelLinear)
-from vllm.model_executor.layers.pooler import (ClassifierPooler, Pooler,
+from vllm.model_executor.layers.pooler import (ClassifierPooler,
+                                               DispatchPooler, Pooler,
                                                PoolingMethod,
                                                PoolingParamsUpdate,
                                                PoolingType)
@@ -271,19 +272,27 @@ class ModernBertPooler(Pooler):
                                  eps=config.norm_eps,
                                  bias=config.norm_bias)
 
-    def get_pooling_updates(
-        self,
-        task: PoolingTask,
-    ) -> Optional[PoolingParamsUpdate]:
+    def get_supported_tasks(self) -> Set[PoolingTask]:
+        return self.pooling.get_supported_tasks()
+
+    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
         return self.pooling.get_pooling_updates(task)
 
+    def _head(self, pooled_output: torch.Tensor):
+        return self.norm(self.act(self.dense(pooled_output)))
+
     def forward(
         self,
         hidden_states: Union[torch.Tensor, list[torch.Tensor]],
         pooling_metadata: PoolingMetadata,
     ) -> Union[torch.Tensor, list[torch.Tensor]]:
         pooled_output = self.pooling(hidden_states, pooling_metadata)
-        pooled_output = self.norm(self.act(self.dense(pooled_output)))
+
+        if isinstance(pooled_output, list):
+            pooled_output = [self._head(output) for output in pooled_output]
+        else:
+            pooled_output = self._head(pooled_output)
 
         return pooled_output
||||
@ -299,11 +308,28 @@ class ModernBertForSequenceClassification(nn.Module, SupportsV0Only,
|
||||
self.model = ModernBertModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "modernbert"))
|
||||
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
||||
self.pooler = ClassifierPooler(
|
||||
vllm_config.model_config,
|
||||
pooling=ModernBertPooler(config),
|
||||
classifier=self.classifier,
|
||||
)
|
||||
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
|
||||
self.pooler = DispatchPooler({
|
||||
"encode":
|
||||
Pooler.for_encode(pooler_config),
|
||||
"classify":
|
||||
ClassifierPooler(
|
||||
pooling=ModernBertPooler(config),
|
||||
classifier=self.classifier,
|
||||
act_fn=ClassifierPooler.act_fn_for_seq_cls(
|
||||
vllm_config.model_config),
|
||||
),
|
||||
"score":
|
||||
ClassifierPooler(
|
||||
pooling=ModernBertPooler(config),
|
||||
classifier=self.classifier,
|
||||
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
|
||||
vllm_config.model_config),
|
||||
),
|
||||
})
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
|
||||
|
@@ -15,7 +15,8 @@ from torch import nn
 from vllm.config import VllmConfig
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
-from vllm.model_executor.layers.pooler import Pooler, PoolingType, SimplePooler
+from vllm.model_executor.layers.pooler import (DispatchPooler, Pooler,
+                                               PoolingType)
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsLoRA, SupportsPP
@@ -26,7 +27,7 @@ from .utils import AutoWeightsLoader, maybe_prefix
 class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
 
     is_pooling_model = True
-    pooler: SimplePooler
+    pooler: Pooler
 
     packed_modules_mapping = {
         "qkv_proj": [
@@ -94,12 +95,12 @@ class Qwen2ForRewardModel(Qwen2RewardBaseModel):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         vllm_config.model_config.hf_config.num_labels = 1
         super().__init__(vllm_config=vllm_config, prefix=prefix)
 
         pooler_config = vllm_config.model_config.pooler_config
-        self.pooler = Pooler.from_config_with_defaults(
-            pooler_config,
-            pooling_type=PoolingType.ALL,
-            normalize=False,
-            softmax=False)
+        assert pooler_config is not None
+
+        self.pooler = DispatchPooler(
+            {"encode": Pooler.for_encode(pooler_config)}, )
 
 
 class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
@@ -107,11 +108,17 @@ class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         vllm_config.model_config.hf_config.num_labels = 2
         super().__init__(vllm_config=vllm_config, prefix=prefix)
 
         pooler_config = vllm_config.model_config.pooler_config
-        self.pooler = Pooler.from_config_with_defaults(
-            pooler_config,
-            pooling_type=PoolingType.STEP,
-            normalize=False,
-            softmax=True,
-            step_tag_id=151651,
-        )
+        assert pooler_config is not None
+
+        self.pooler = DispatchPooler({
+            "encode":
+            Pooler.for_encode(
+                pooler_config,
+                default_pooling_type=PoolingType.STEP,
+                default_normalize=False,
+                default_softmax=True,
+                default_step_tag_id=151651,
+            )
+        })
@@ -9,7 +9,8 @@ from torch import nn
 from transformers import RobertaConfig
 
 from vllm.config import VllmConfig
-from vllm.model_executor.layers.pooler import ClassifierPooler, CLSPool
+from vllm.model_executor.layers.pooler import (ClassifierPooler, CLSPool,
+                                               DispatchPooler, Pooler)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel
@@ -63,16 +64,10 @@ class RobertaEmbedding(nn.Module):
         # References:
         # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
         # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
-        pos_list = []
-        token_list = []
-        offset = 0
-        for seq_len in seq_lens:
-            pos_list.append(position_ids[offset:offset + seq_len])
-            token_list.append(input_ids[offset:offset + seq_len])
-            offset += seq_len
-
+        seq_lens_list = seq_lens.tolist()
         new_pos_list = []
-        for positions, tokens in zip(pos_list, token_list):
+        for positions, tokens in zip(position_ids.split(seq_lens_list),
+                                     input_ids.split(seq_lens_list)):
             # Verify assumption that incoming position are
             # always a sequence from 0 to N.
             expected_pos = torch.arange(positions.size()[0],
@@ -184,15 +179,30 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
         self.num_labels = config.num_labels
         self.roberta = BertModel(vllm_config=vllm_config,
                                  prefix=maybe_prefix(prefix, "bert"),
-                                 embedding_class=RobertaEmbedding,
-                                 add_pooling_layer=False)
+                                 embedding_class=RobertaEmbedding)
         self.classifier = RobertaClassificationHead(config)
 
-        self.pooler = ClassifierPooler(
-            vllm_config.model_config,
-            pooling=CLSPool(),
-            classifier=self.classifier,
-        )
+        pooler_config = vllm_config.model_config.pooler_config
+        assert pooler_config is not None
+
+        self.pooler = DispatchPooler({
+            "encode":
+            Pooler.for_encode(pooler_config),
+            "classify":
+            ClassifierPooler(
+                pooling=CLSPool(),
+                classifier=self.classifier,
+                act_fn=ClassifierPooler.act_fn_for_seq_cls(
+                    vllm_config.model_config),
+            ),
+            "score":
+            ClassifierPooler(
+                pooling=CLSPool(),
+                classifier=self.classifier,
+                act_fn=ClassifierPooler.act_fn_for_cross_encoder(
+                    vllm_config.model_config),
+            ),
+        })
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         loader = AutoWeightsLoader(self)
@@ -38,6 +38,13 @@ class PoolingMetadata:
                 f"seq_data={self.seq_data}, "
                 f"prompt_lens={self.prompt_lens})")
 
+    def __getitem__(self, indices: slice):
+        return PoolingMetadata(
+            seq_groups=self.seq_groups[indices],
+            seq_data=dict(list(self.seq_data.items())[indices]),
+            prompt_lens=self.prompt_lens[indices],
+        )
+
 
 @dataclass
 class PoolingTensors:
@@ -15,3 +15,11 @@ class PoolingMetadata:
     prompt_lens: torch.Tensor
     prompt_token_ids: Optional[torch.Tensor]
     pooling_params: list[PoolingParams]
+
+    def __getitem__(self, indices: slice):
+        return PoolingMetadata(
+            prompt_lens=self.prompt_lens[indices],
+            prompt_token_ids=None if self.prompt_token_ids is None else
+            self.prompt_token_ids[indices],
+            pooling_params=self.pooling_params[indices],
+        )
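Editor's note: these `__getitem__` additions are what let `DispatchPooler` hand each sub-pooler a per-task slice of the batch metadata, aligned with the matching slice of hidden states. A simplified, runnable stand-in showing the slicing rule (field values are made up):

```python
from dataclasses import dataclass
from typing import Optional

import torch

# Simplified stand-in for the V1 PoolingMetadata above (same rule).
@dataclass
class MetadataSketch:
    prompt_lens: torch.Tensor
    prompt_token_ids: Optional[torch.Tensor]
    pooling_params: list

    def __getitem__(self, indices: slice):
        return MetadataSketch(
            prompt_lens=self.prompt_lens[indices],
            prompt_token_ids=None if self.prompt_token_ids is None else
            self.prompt_token_ids[indices],
            pooling_params=self.pooling_params[indices],
        )

meta = MetadataSketch(torch.tensor([3, 5, 2]), None, ["a", "b", "c"])
head = meta[0:2]  # metadata for a task group covering the first two requests
assert head.prompt_lens.tolist() == [3, 5] and head.pooling_params == ["a", "b"]
```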
@@ -5,7 +5,7 @@ import copy
 import gc
 import time
 from contextlib import contextmanager
-from typing import TYPE_CHECKING, Any, Optional, Union, cast, get_args
+from typing import TYPE_CHECKING, Any, Optional, Union, cast
 
 import numpy as np
 import torch
@@ -415,15 +415,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 generator = None
 
             if pooling_params:
-                assert pooling_params.task is not None, (
+                assert (task := pooling_params.task) is not None, (
                     "You did not set `task` in the API")
 
                 model = cast(VllmModelForPooling, self.model)
-                to_update = (model.pooler.get_pooling_updates(
-                    pooling_params.task))
-                assert to_update is not None, (
-                    f"{pooling_params.task=} is not supported by the model")
-
+                to_update = model.pooler.get_pooling_updates(task)
                 to_update.apply(pooling_params)
 
             self.requests[req_id] = CachedRequestState(
@@ -1122,10 +1118,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         if not is_pooling_model(model):
             return []
 
-        return [
-            task for task in get_args(PoolingTask)
-            if model.pooler.get_pooling_updates(task)
-        ]
+        return list(model.pooler.get_supported_tasks())
 
     def apply_grammar_bitmask(
         self,
@@ -2247,7 +2240,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         dummy_pooling_params = PoolingParams(task=dummy_task)
 
         to_update = model.pooler.get_pooling_updates(dummy_task)
-        assert to_update is not None
         to_update.apply(dummy_pooling_params)
 
         dummy_metadata = PoolingMetadata(
|
@ -3,7 +3,7 @@
|
||||
import bisect
|
||||
import gc
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, Optional, cast, get_args
|
||||
from typing import TYPE_CHECKING, Any, Optional, cast
|
||||
from unittest.mock import patch
|
||||
|
||||
import numpy as np
|
||||
@ -491,10 +491,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
|
||||
if not is_pooling_model(model):
|
||||
return []
|
||||
|
||||
return [
|
||||
task for task in get_args(PoolingTask)
|
||||
if model.pooler.get_pooling_updates(task)
|
||||
]
|
||||
return list(model.pooler.get_supported_tasks())
|
||||
|
||||
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
|
||||
"""
|
||||
|
@@ -4,7 +4,7 @@
 import dataclasses
 from abc import ABC, abstractmethod
 from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
-                    TypeVar, get_args)
+                    TypeVar)
 
 import torch
 import torch.nn as nn
@@ -230,10 +230,7 @@ class ModelRunnerBase(ABC, Generic[T]):
         if not is_pooling_model(model):
             return []
 
-        return [
-            task for task in get_args(PoolingTask)
-            if model.pooler.get_pooling_updates(task)
-        ]
+        return list(model.pooler.get_supported_tasks())
 
     def execute_model(
         self,
@@ -199,15 +199,11 @@ class PoolingModelRunner(
 
                 pooling_params = seq_group_metadata.pooling_params
                 assert pooling_params is not None
-                assert pooling_params.task is not None, (
+                assert (task := pooling_params.task) is not None, (
                     "You did not set `task` in the API")
 
-                to_update = (cast(VllmModelForPooling,
-                                  self.model).pooler.get_pooling_updates(
-                                      pooling_params.task))
-                assert to_update is not None, (
-                    f"{pooling_params.task=} is not supported by the model")
-
+                model = cast(VllmModelForPooling, self.model)
+                to_update = model.pooler.get_pooling_updates(task)
                 to_update.apply(pooling_params)
 
                 seq_groups.append((seq_ids, pooling_params))