mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[Model][2/N] Improve all pooling task | Support multi-vector retrieval (#25370)
Signed-off-by: wang.yuqi <noooop@126.com>
@@ -26,6 +26,12 @@ python examples/offline_inference/pooling/embed_jina_embeddings_v3.py
 python examples/offline_inference/pooling/embed_matryoshka_fy.py
 ```
+
+## Multi vector retrieval usage
+
+```bash
+python examples/offline_inference/pooling/multi_vector_retrieval.py
+```
 
 ## Named Entity Recognition (NER) usage
 
 ```bash
examples/offline_inference/pooling/multi_vector_retrieval.py (new file, 56 lines)
@@ -0,0 +1,56 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from argparse import Namespace

from vllm import LLM, EngineArgs
from vllm.utils import FlexibleArgumentParser


def parse_args():
    parser = FlexibleArgumentParser()
    parser = EngineArgs.add_cli_args(parser)
    # Set example specific arguments
    parser.set_defaults(
        model="BAAI/bge-m3",
        runner="pooling",
        enforce_eager=True,
    )
    return parser.parse_args()


def main(args: Namespace):
    # Sample prompts.
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    # Create an LLM.
    # You should pass runner="pooling" for embedding models
    llm = LLM(**vars(args))

    # Generate embedding. The output is a list of EmbeddingRequestOutputs.
    outputs = llm.embed(prompts)

    # Print the outputs.
    print("\nGenerated Outputs:\n" + "-" * 60)
    for prompt, output in zip(prompts, outputs):
        embeds = output.outputs.embedding
        print(len(embeds))

    # Generate embedding for each token. The output is a list of PoolingRequestOutput.
    outputs = llm.encode(prompts, pooling_task="token_embed")

    # Print the outputs.
    print("\nGenerated Outputs:\n" + "-" * 60)
    for prompt, output in zip(prompts, outputs):
        multi_vector = output.outputs.data
        print(multi_vector.shape)


if __name__ == "__main__":
    args = parse_args()
    main(args)
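For context on how these per-token outputs get used: multi-vector retrieval scores a query against a document by comparing token embeddings directly, rather than comparing one pooled vector per text. A minimal late-interaction (ColBERT-style MaxSim) scorer is sketched below; the function is illustrative and not part of this PR, and it assumes both tensors come from `llm.encode(..., pooling_task="token_embed")` with normalization enabled.

```python
import torch

def maxsim_score(query_vecs: torch.Tensor, doc_vecs: torch.Tensor) -> float:
    # query_vecs: [n_query_tokens, dim], doc_vecs: [n_doc_tokens, dim]
    # Rows are assumed L2-normalized, so the matmul yields cosine similarities.
    sim = query_vecs @ doc_vecs.T
    # For each query token, take its best-matching doc token, then sum.
    return sim.max(dim=1).values.sum().item()
```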
@@ -40,7 +40,7 @@ def main():
         model_impl="terratorch",
     )
 
-    pooling_params = PoolingParams(task="encode", softmax=False)
+    pooling_params = PoolingParams(task="token_classify", activation=False)
     pooler_output = llm.encode(
         img_prompt,
         pooling_params=pooling_params,
@@ -18,6 +18,12 @@ python examples/online_serving/pooling/embedding_embed_dtype_client.py
 python examples/online_serving/pooling/jinaai_rerank_client.py
 ```
+
+## Multi vector retrieval usage
+
+```bash
+python examples/online_serving/pooling/multi_vector_retrieval_client.py
+```
 
 ## Named Entity Recognition (NER) usage
 
 ```bash
examples/online_serving/pooling/multi_vector_retrieval_client.py (new file, 54 lines)
@@ -0,0 +1,54 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""
Example online usage of Pooling API for multi vector retrieval.

Run `vllm serve <model> --runner pooling`
to start up the server in vLLM. e.g.

vllm serve BAAI/bge-m3
"""

import argparse

import requests
import torch


def post_http_request(prompt: dict, api_url: str) -> requests.Response:
    headers = {"User-Agent": "Test Client"}
    response = requests.post(api_url, headers=headers, json=prompt)
    return response


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--model", type=str, default="BAAI/bge-m3")

    return parser.parse_args()


def main(args):
    api_url = f"http://{args.host}:{args.port}/pooling"
    model_name = args.model

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    prompt = {"model": model_name, "input": prompts}

    pooling_response = post_http_request(prompt=prompt, api_url=api_url)
    for output in pooling_response.json()["data"]:
        multi_vector = torch.tensor(output["data"])
        print(multi_vector.shape)


if __name__ == "__main__":
    args = parse_args()
    main(args)
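To try the client end to end, start the server as the docstring suggests (`vllm serve BAAI/bge-m3`) and then run this script; each printed shape should be `[n_tokens, dim]` for the corresponding prompt, matching the offline `token_embed` output above.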
@@ -1011,8 +1011,12 @@ class VllmRunner:
         req_outputs = self.llm.embed(inputs, *args, **kwargs)
         return [req_output.outputs.embedding for req_output in req_outputs]
 
-    def encode(self, prompts: list[str]) -> list[list[float]]:
-        req_outputs = self.llm.encode(prompts)
+    def token_embed(self, prompts: list[str]) -> list[list[float]]:
+        req_outputs = self.llm.encode(prompts, pooling_task="token_embed")
         return [req_output.outputs.data for req_output in req_outputs]
 
+    def token_classify(self, prompts: list[str]) -> list[list[float]]:
+        req_outputs = self.llm.encode(prompts, pooling_task="token_classify")
+        return [req_output.outputs.data for req_output in req_outputs]
+
     def reward(self, prompts: list[str]) -> list[list[float]]:
@@ -63,7 +63,7 @@ def test_encode_api(llm: LLM):
     # chunked prefill does not support all pooling
     err_msg = "pooling_task must be one of.+"
     with pytest.raises(ValueError, match=err_msg):
-        llm.encode(prompts, use_tqdm=False)
+        llm.encode(prompts, pooling_task="token_classify", use_tqdm=False)
 
 
 def test_score_api(llm: LLM):
@@ -35,6 +35,13 @@ def llm():
     cleanup_dist_env_and_memory()
 
 
+@pytest.mark.skip_global_cleanup
+def test_encode_api(llm: LLM):
+    outputs = llm.encode(prompts, pooling_task="token_embed", use_tqdm=False)
+    multi_vector = outputs[0].outputs.data
+    assert multi_vector.shape == (11, 384)
+
+
 def test_pooling_params(llm: LLM):
     def get_outputs(normalize):
         outputs = llm.embed(
@@ -57,20 +57,24 @@ def test_multiple_pooling_params(llm: LLM):
     ]
 
     # Multiple PoolingParams should be matched with each prompt
-    outputs = llm.encode(PROMPTS, pooling_params=pooling_params)
+    outputs = llm.encode(PROMPTS, pooling_params=pooling_params, pooling_task="embed")
     assert len(PROMPTS) == len(outputs)
 
     # Exception raised, if the size of params does not match the size of prompts
     with pytest.raises(ValueError):
-        outputs = llm.encode(PROMPTS, pooling_params=pooling_params[:3])
+        outputs = llm.encode(
+            PROMPTS, pooling_params=pooling_params[:3], pooling_task="embed"
+        )
 
     # Single PoolingParams should be applied to every prompt
     single_pooling_params = PoolingParams()
-    outputs = llm.encode(PROMPTS, pooling_params=single_pooling_params)
+    outputs = llm.encode(
+        PROMPTS, pooling_params=single_pooling_params, pooling_task="embed"
+    )
     assert len(PROMPTS) == len(outputs)
 
     # pooling_params is None, default params should be applied
-    outputs = llm.encode(PROMPTS, pooling_params=None)
+    outputs = llm.encode(PROMPTS, pooling_params=None, pooling_task="embed")
     assert len(PROMPTS) == len(outputs)
@@ -36,22 +36,23 @@ def llm():
     cleanup_dist_env_and_memory()
 
 
 @pytest.mark.skip_global_cleanup
 def test_pooling_params(llm: LLM):
-    def get_outputs(softmax):
+    def get_outputs(activation):
         outputs = llm.reward(
-            prompts, pooling_params=PoolingParams(softmax=softmax), use_tqdm=False
+            prompts, pooling_params=PoolingParams(activation=activation), use_tqdm=False
         )
         return torch.cat([x.outputs.data for x in outputs])
 
-    default = get_outputs(softmax=None)
-    w_softmax = get_outputs(softmax=True)
-    wo_softmax = get_outputs(softmax=False)
+    default = get_outputs(activation=None)
+    w_activation = get_outputs(activation=True)
+    wo_activation = get_outputs(activation=False)
 
-    assert torch.allclose(default, w_softmax, atol=1e-2), "Default should use softmax."
-    assert not torch.allclose(w_softmax, wo_softmax, atol=1e-2), (
-        "wo_softmax should not use softmax."
+    assert torch.allclose(default, w_activation, atol=1e-2), (
+        "Default should use activation."
     )
-    assert torch.allclose(softmax(wo_softmax), w_softmax, atol=1e-2), (
-        "w_softmax should be close to softmax(wo_softmax)."
+    assert not torch.allclose(w_activation, wo_activation, atol=1e-2), (
+        "wo_activation should not use activation."
     )
+    assert torch.allclose(softmax(wo_activation), w_activation, atol=1e-2), (
+        "w_activation should be close to activation(wo_activation)."
+    )
@@ -17,6 +17,7 @@ from tests.utils import RemoteOpenAIServer
 from vllm.entrypoints.openai.protocol import (
     EMBED_DTYPE_TO_TORCH_DTYPE,
     EmbeddingResponse,
+    PoolingResponse,
 )
 from vllm.transformers_utils.tokenizer import get_tokenizer
 
@@ -509,3 +510,20 @@ async def test_normalize(server: RemoteOpenAIServer, model_name: str):
     assert torch.allclose(w_normal, F.normalize(wo_normal, p=2, dim=-1), atol=1e-2), (
         "w_normal should be close to normal(wo_normal)."
     )
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_pooling(server: RemoteOpenAIServer, model_name: str):
+    input_text = ["The chef prepared a delicious meal."]
+
+    response = requests.post(
+        server.url_for("pooling"),
+        json={"model": model_name, "input": input_text, "encoding_format": "float"},
+    )
+
+    poolings = PoolingResponse.model_validate(response.json())
+
+    assert len(poolings.data) == 1
+    assert len(poolings.data[0].data) == 11
+    assert len(poolings.data[0].data[0]) == 384
@@ -7,7 +7,7 @@ import torch
 import torch.nn.functional as F
 
 from tests.utils import RemoteOpenAIServer
-from vllm.entrypoints.openai.protocol import RerankResponse
+from vllm.entrypoints.openai.protocol import PoolingResponse, RerankResponse
 
 MODEL_NAME = "BAAI/bge-reranker-base"
 DTYPE = "bfloat16"
@@ -159,3 +159,20 @@ async def test_activation(server: RemoteOpenAIServer, model_name: str):
     assert torch.allclose(F.sigmoid(wo_activation), w_activation, atol=1e-2), (
         "w_activation should be close to activation(wo_activation)."
     )
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_pooling(server: RemoteOpenAIServer, model_name: str):
+    input_text = ["The chef prepared a delicious meal."]
+
+    response = requests.post(
+        server.url_for("pooling"),
+        json={"model": model_name, "input": input_text, "encoding_format": "float"},
+    )
+
+    poolings = PoolingResponse.model_validate(response.json())
+
+    assert len(poolings.data) == 1
+    assert len(poolings.data[0].data) == 11
+    assert len(poolings.data[0].data[0]) == 1
tests/models/language/pooling/test_multi_vector_retrieval.py (new file, 45 lines)
@@ -0,0 +1,45 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from transformers import AutoModel

from tests.models.utils import check_embeddings_close


@pytest.mark.parametrize(
    "model",
    ["BAAI/bge-m3"],
)
@pytest.mark.parametrize("dtype", ["half"])
@torch.inference_mode
def test_embed_models(hf_runner, vllm_runner, example_prompts, model: str, dtype: str):
    with vllm_runner(
        model,
        runner="pooling",
        max_model_len=None,
    ) as vllm_model:
        vllm_outputs = vllm_model.token_embed(example_prompts)

    with hf_runner(
        model,
        auto_cls=AutoModel,
    ) as hf_model:
        tokenizer = hf_model.tokenizer
        hf_outputs = []
        for prompt in example_prompts:
            inputs = tokenizer([prompt], return_tensors="pt")
            inputs = hf_model.wrap_device(inputs)
            output = hf_model.model(**inputs)
            embedding = output.last_hidden_state[0].float()
            # normal
            hf_outputs.append(embedding.cpu())

    for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
        check_embeddings_close(
            embeddings_0_lst=hf_output,
            embeddings_1_lst=vllm_output,
            name_0="hf",
            name_1="vllm",
            tol=1e-2,
        )
@@ -93,7 +93,7 @@ def test_embed_models_using_normalize(
     ],
 )
 @pytest.mark.parametrize("dtype", ["half"])
-def test_reward_models_using_softmax(
+def test_reward_models_using_activation(
     hf_runner,
     vllm_runner,
     example_prompts,
@@ -104,22 +104,64 @@
         model,
         max_model_len=1024,
         dtype=dtype,
-        pooler_config=PoolerConfig(softmax=False),
+        pooler_config=PoolerConfig(activation=False),
     ) as vllm_model:
-        wo_softmax = vllm_model.encode(example_prompts)
+        wo_activation = vllm_model.reward(example_prompts)
 
     with vllm_runner(
-        model, max_model_len=1024, dtype=dtype, pooler_config=PoolerConfig(softmax=True)
+        model,
+        max_model_len=1024,
+        dtype=dtype,
+        pooler_config=PoolerConfig(activation=True),
     ) as vllm_model:
-        w_softmax = vllm_model.encode(example_prompts)
+        w_activation = vllm_model.reward(example_prompts)
 
-    for wo, w in zip(wo_softmax, w_softmax):
+    for wo, w in zip(wo_activation, w_activation):
         wo = torch.tensor(wo)
         w = torch.tensor(w)
 
         assert not torch.allclose(wo, w, atol=1e-2), (
-            "pooler_config softmax is not working"
+            "pooler_config activation is not working"
         )
         assert torch.allclose(softmax(wo), w, atol=1e-2), (
-            "w_softmax should be close to softmax(wo_softmax)."
+            "w_activation should be close to activation(wo_activation)."
         )
+
+
+@pytest.mark.parametrize(
+    "model",
+    [
+        "intfloat/multilingual-e5-small",
+    ],
+)
+@pytest.mark.parametrize("dtype", ["half"])
+def test_multi_vector_retrieval_models_using_normalize(
+    hf_runner,
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+) -> None:
+    with vllm_runner(
+        model,
+        max_model_len=512,
+        dtype=dtype,
+        pooler_config=PoolerConfig(normalize=False),
+    ) as vllm_model:
+        wo_normalize = vllm_model.token_embed(example_prompts)
+
+    with vllm_runner(
+        model,
+        max_model_len=512,
+        dtype=dtype,
+        pooler_config=PoolerConfig(normalize=True),
+    ) as vllm_model:
+        w_normalize = vllm_model.token_embed(example_prompts)
+
+    for wo, w in zip(wo_normalize, w_normalize):
+        assert not torch.allclose(wo, w, atol=1e-2), (
+            "pooler_config normalize is not working"
+        )
+        assert torch.allclose(F.normalize(wo, p=2, dim=-1), w, atol=1e-2), (
+            "w_normal should be close to normal(wo_normal)."
+        )
@@ -19,7 +19,7 @@ def test_bert_models(
     dtype: str,
 ) -> None:
     with vllm_runner(model, max_model_len=None, dtype=dtype) as vllm_model:
-        vllm_outputs = vllm_model.encode(example_prompts)
+        vllm_outputs = vllm_model.token_classify(example_prompts)
 
     with hf_runner(
         model, dtype=dtype, auto_cls=AutoModelForTokenClassification
@@ -50,7 +50,7 @@ def test_modernbert_models(
     dtype: str,
 ) -> None:
     with vllm_runner(model, max_model_len=None, dtype=dtype) as vllm_model:
-        vllm_outputs = vllm_model.encode(example_prompts)
+        vllm_outputs = vllm_model.token_classify(example_prompts)
 
     with hf_runner(
         model, dtype=dtype, auto_cls=AutoModelForTokenClassification
@@ -39,7 +39,7 @@ def _run_test(
         max_num_seqs=32,
         default_torch_num_threads=1,
     ) as vllm_model:
-        vllm_model.encode(prompt)
+        vllm_model.llm.encode(prompt, pooling_task="token_classify")
 
 
 MODELS = ["mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11"]
@@ -30,7 +30,7 @@ class MyGemma2Embedding(nn.Module):
 
         self.pooler = DispatchPooler(
             {
-                "encode": Pooler.for_encode(pooler_config),
+                "token_embed": Pooler.for_token_embed(pooler_config),
                 "embed": Pooler.for_embed(pooler_config),
             }
         )
@@ -93,7 +93,7 @@ def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str):
         out_data_format="b64_json",
     )
 
-    pooling_params = PoolingParams(task="encode", softmax=False)
+    pooling_params = PoolingParams(activation=False)
 
     with vllm_runner(
         model_name,
@@ -108,8 +108,7 @@ def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str):
         io_processor_plugin="prithvi_to_tiff",
     ) as llm_runner:
         pooler_output = llm_runner.get_llm().encode(
-            img_prompt,
-            pooling_params=pooling_params,
+            img_prompt, pooling_params=pooling_params, pooling_task="token_classify"
         )
         output = pooler_output[0].outputs
@@ -1,10 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from dataclasses import dataclass
+
 import pytest
 
 from tests.models.utils import EmbedModelInfo
 from vllm import PoolingParams
-from vllm.config import ModelConfig
+from vllm.config import ModelConfig, PoolerConfig
 
 EMBEDDING_MODELS = [
     EmbedModelInfo("intfloat/multilingual-e5-small", is_matryoshka=False),
@@ -15,6 +17,15 @@ EMBEDDING_MODELS = [
     ),
 ]
 
+classify_parameters = ["activation"]
+embed_parameters = ["dimensions", "normalize"]
+step_pooling_parameters = ["step_tag_id", "returned_token_ids"]
+
+
+@dataclass()
+class MockModelConfig:
+    pooler_config: PoolerConfig
+
 
 def test_task():
     pooling_params = PoolingParams()
@@ -24,25 +35,27 @@ def test_task():
     pooling_params.verify(task="score")
 
     with pytest.raises(ValueError):
-        pooling_params.verify(task="encode")
+        pooling_params.verify(task="classify")
 
 
 def test_embed():
     task = "embed"
+    model_config = MockModelConfig(pooler_config=PoolerConfig(pooling_type="CLS"))
 
     pooling_params = PoolingParams(normalize=None)
-    pooling_params.verify(task=task)
+    pooling_params.verify(task=task, model_config=model_config)
 
     pooling_params = PoolingParams(normalize=True)
-    pooling_params.verify(task=task)
+    pooling_params.verify(task=task, model_config=model_config)
 
     pooling_params = PoolingParams(normalize=False)
-    pooling_params.verify(task=task)
+    pooling_params.verify(task=task, model_config=model_config)
 
-    invalid_parameters = ["activation", "softmax"]
+    invalid_parameters = classify_parameters + step_pooling_parameters
     for p in invalid_parameters:
         with pytest.raises(ValueError):
             pooling_params = PoolingParams(**{p: True})
-            pooling_params.verify(task=task)
+            pooling_params.verify(task=task, model_config=model_config)
 
 
 @pytest.mark.parametrize("model_info", EMBEDDING_MODELS)
@@ -73,35 +86,71 @@ def test_embed_dimensions(model_info: EmbedModelInfo):
 
 @pytest.mark.parametrize("task", ["score", "classify"])
 def test_classify(task):
+    model_config = MockModelConfig(pooler_config=PoolerConfig(pooling_type="CLS"))
+
     pooling_params = PoolingParams(activation=None)
-    pooling_params.verify(task=task)
+    pooling_params.verify(task=task, model_config=model_config)
 
     pooling_params = PoolingParams(activation=True)
-    pooling_params.verify(task=task)
+    pooling_params.verify(task=task, model_config=model_config)
 
     pooling_params = PoolingParams(activation=False)
-    pooling_params.verify(task=task)
+    pooling_params.verify(task=task, model_config=model_config)
 
-    invalid_parameters = ["dimensions", "normalize", "softmax"]
+    invalid_parameters = embed_parameters + step_pooling_parameters
     for p in invalid_parameters:
         with pytest.raises(ValueError):
             pooling_params = PoolingParams(**{p: True})
-            pooling_params.verify(task=task)
+            pooling_params.verify(task=task, model_config=model_config)
 
 
-def test_encode():
-    task = "encode"
-    pooling_params = PoolingParams(softmax=None)
-    pooling_params.verify(task=task)
+@pytest.mark.parametrize("pooling_type", ["ALL", "STEP"])
+def test_token_embed(pooling_type: str):
+    task = "token_embed"
+    model_config = MockModelConfig(
+        pooler_config=PoolerConfig(pooling_type=pooling_type)
+    )
 
-    pooling_params = PoolingParams(softmax=True)
-    pooling_params.verify(task=task)
+    pooling_params = PoolingParams(normalize=None)
+    pooling_params.verify(task=task, model_config=model_config)
 
-    pooling_params = PoolingParams(softmax=False)
-    pooling_params.verify(task=task)
+    pooling_params = PoolingParams(normalize=True)
+    pooling_params.verify(task=task, model_config=model_config)
+
+    pooling_params = PoolingParams(normalize=False)
+    pooling_params.verify(task=task, model_config=model_config)
+
+    invalid_parameters = classify_parameters
+    if pooling_type != "STEP":
+        invalid_parameters = classify_parameters + step_pooling_parameters
 
-    invalid_parameters = ["dimensions", "normalize", "activation"]
     for p in invalid_parameters:
         with pytest.raises(ValueError):
             pooling_params = PoolingParams(**{p: True})
-            pooling_params.verify(task=task)
+            pooling_params.verify(task=task, model_config=model_config)
+
+
+@pytest.mark.parametrize("pooling_type", ["ALL", "STEP"])
+def test_token_classify(pooling_type: str):
+    task = "token_classify"
+    model_config = MockModelConfig(
+        pooler_config=PoolerConfig(pooling_type=pooling_type)
+    )
+
+    pooling_params = PoolingParams(activation=None)
+    pooling_params.verify(task=task, model_config=model_config)
+
+    pooling_params = PoolingParams(activation=True)
+    pooling_params.verify(task=task, model_config=model_config)
+
+    pooling_params = PoolingParams(activation=False)
+    pooling_params.verify(task=task, model_config=model_config)
+
+    invalid_parameters = embed_parameters
+    if pooling_type != "STEP":
+        invalid_parameters = embed_parameters + step_pooling_parameters
+
+    for p in invalid_parameters:
+        with pytest.raises(ValueError):
+            pooling_params = PoolingParams(**{p: True})
+            pooling_params.verify(task=task, model_config=model_config)
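For reference, the validity matrix these tests encode (derived from the asserts above; "STEP only" means the step parameters are accepted only when the pooler's `pooling_type` is `"STEP"`):

| task | valid parameters | rejected parameters |
|---|---|---|
| `embed` | `dimensions`, `normalize` | `activation`, `step_tag_id`, `returned_token_ids` |
| `classify` / `score` | `activation` | `dimensions`, `normalize`, `step_tag_id`, `returned_token_ids` |
| `token_embed` | `normalize` (+ step params, STEP only) | `activation` (+ step params otherwise) |
| `token_classify` | `activation` (+ step params, STEP only) | `dimensions`, `normalize` (+ step params otherwise) |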
@@ -951,7 +951,7 @@ class LLM:
         truncate_prompt_tokens: int | None = None,
         use_tqdm: bool | Callable[..., tqdm] = True,
         lora_request: list[LoRARequest] | LoRARequest | None = None,
-        pooling_task: PoolingTask = "encode",
+        pooling_task: PoolingTask | None = None,
         tokenization_kwargs: dict[str, Any] | None = None,
     ) -> list[PoolingRequestOutput]:
         """Apply pooling to the hidden states corresponding to the input
@@ -986,25 +986,24 @@ class LLM:
             instead pass them via the `inputs` parameter.
         """
 
-        if self.supported_tasks == ["encode"] and pooling_task is None:
-            pooling_task = "encode"
+        error_str = (
+            "pooling_task required for `LLM.encode`\n"
+            "Please use one of the more specific methods or set the "
+            "pooling_task when using `LLM.encode`:\n"
+            " - For embeddings, use `LLM.embed(...)` "
+            'or `pooling_task="embed"`.\n'
+            " - For classification logits, use `LLM.classify(...)` "
+            'or `pooling_task="classify"`.\n'
+            " - For similarity scores, use `LLM.score(...)`.\n"
+            " - For rewards, use `LLM.reward(...)` "
+            'or `pooling_task="token_classify"`\n'
+            " - For token classification, "
+            'use `pooling_task="token_classify"`\n'
+            ' - For multi-vector retrieval, use `pooling_task="token_embed"`'
+        )
 
         if pooling_task is None:
-            pooling_task = "embed" if "embed" in self.supported_tasks else "encode"
-
-            logger.warning_once(
-                "`LLM.encode` is currently using `pooling_task = %s`.\n"
-                "Please use one of the more specific methods or set the "
-                "task directly when using `LLM.encode`:\n"
-                " - For embeddings, use `LLM.embed(...)` "
-                'or `pooling_task="embed"`.\n'
-                " - For classification logits, use `LLM.classify(...)` "
-                'or `pooling_task="classify"`.\n'
-                " - For rewards, use `LLM.reward(...)` "
-                'or `pooling_task="reward"`\n'
-                " - For similarity scores, use `LLM.score(...)`.",
-                pooling_task,
-            )
+            raise ValueError(error_str)
 
         model_config = self.model_config
         runner_type = model_config.runner_type
@@ -1206,7 +1205,7 @@ class LLM:
             lora_request=lora_request,
             pooling_params=pooling_params,
             truncate_prompt_tokens=truncate_prompt_tokens,
-            pooling_task="encode",
+            pooling_task="token_classify",
         )
 
     def _embedding_score(
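The practical upshot of this hunk: a bare `LLM.encode(prompts)` now raises instead of silently picking a default task. A minimal sketch of the explicit-task call pattern (model name illustrative, mirroring the examples in this PR):

```python
from vllm import LLM

llm = LLM(model="BAAI/bge-m3", runner="pooling")
prompts = ["The capital of France is"]

# Sequence-level embedding: one vector per prompt.
embedding = llm.embed(prompts)[0].outputs.embedding

# Token-level embeddings for multi-vector retrieval: one vector per token.
outputs = llm.encode(prompts, pooling_task="token_embed")
print(outputs[0].outputs.data.shape)  # [n_tokens, dim]
```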
@@ -1748,16 +1748,19 @@ async def init_app_state(
         else None
     )
     state.openai_serving_pooling = (
-        OpenAIServingPooling(
-            engine_client,
-            state.openai_serving_models,
-            request_logger=request_logger,
-            chat_template=resolved_chat_template,
-            chat_template_content_format=args.chat_template_content_format,
-            trust_request_chat_template=args.trust_request_chat_template,
-            log_error_stack=args.log_error_stack,
+        (
+            OpenAIServingPooling(
+                engine_client,
+                state.openai_serving_models,
+                supported_tasks=supported_tasks,
+                request_logger=request_logger,
+                chat_template=resolved_chat_template,
+                chat_template_content_format=args.chat_template_content_format,
+                trust_request_chat_template=args.trust_request_chat_template,
+                log_error_stack=args.log_error_stack,
+            )
         )
-        if "encode" in supported_tasks
+        if ("token_embed" in supported_tasks or "token_classify" in supported_tasks)
         else None
     )
     state.openai_serving_embedding = (
@@ -1682,7 +1682,7 @@ class IOProcessorRequest(OpenAIBaseModel, Generic[T]):
     When using plugins IOProcessor plugins, the actual input is processed
     by the plugin itself. Hence, we use a generic type for the request data
     """
-    softmax: bool = True
+    activation: bool = False
 
     embed_dtype: str = Field(
         default="float32",
@@ -1693,7 +1693,7 @@ class IOProcessorRequest(OpenAIBaseModel, Generic[T]):
     )
 
     def to_pooling_params(self):
-        return PoolingParams(task="encode", softmax=self.softmax)
+        return PoolingParams(task="token_classify", activation=self.activation)
 
 
 class IOProcessorResponse(OpenAIBaseModel, Generic[T]):
@@ -35,6 +35,7 @@ from vllm.entrypoints.renderer import RenderConfig
 from vllm.entrypoints.utils import _validate_truncation_size
 from vllm.logger import init_logger
 from vllm.outputs import PoolingOutput, PoolingRequestOutput
+from vllm.tasks import SupportedTask
 from vllm.utils import merge_async_iterators
 
 logger = init_logger(__name__)
@@ -62,6 +63,7 @@ class OpenAIServingPooling(OpenAIServing):
         engine_client: EngineClient,
         models: OpenAIServingModels,
         *,
+        supported_tasks: tuple[SupportedTask, ...],
         request_logger: RequestLogger | None,
         chat_template: str | None,
         chat_template_content_format: ChatTemplateContentFormatOption,
@@ -75,6 +77,7 @@ class OpenAIServingPooling(OpenAIServing):
             log_error_stack=log_error_stack,
         )
 
+        self.supported_tasks = supported_tasks
         self.chat_template = chat_template
         self.chat_template_content_format: Final = chat_template_content_format
         self.trust_request_chat_template = trust_request_chat_template
@@ -178,8 +181,17 @@ class OpenAIServingPooling(OpenAIServing):
         try:
             pooling_params = request.to_pooling_params()
 
+        if "token_embed" in self.supported_tasks:
+            pooling_task = "token_embed"
+        elif "token_classify" in self.supported_tasks:
+            pooling_task = "token_classify"
+        else:
+            return self.create_error_response(
+                f"pooling_task must be one of {self.supported_tasks}."
+            )
+
         try:
-            pooling_params.verify("encode", self.model_config)
+            pooling_params.verify(pooling_task, self.model_config)
         except ValueError as e:
             return self.create_error_response(str(e))
@@ -64,66 +64,6 @@ class PoolingParamsUpdate:
         params.requires_token_ids = self.requires_token_ids
 
 
-class Pooler(nn.Module, ABC):
-    """The interface required for all poolers used in pooling models in vLLM."""
-
-    @staticmethod
-    def for_encode(pooler_config: PoolerConfig):
-        if pooler_config.pooling_type == "STEP":
-            return StepPooler()
-
-        resolved_config = ResolvedPoolingConfig(
-            task="encode", pooling_type=PoolingType.ALL
-        )
-
-        return SimplePooler.from_config(resolved_config)
-
-    @staticmethod
-    def for_embed(pooler_config: PoolerConfig):
-        resolved_config = ResolvedPoolingConfig.from_config(
-            task="embed",
-            pooler_config=pooler_config,
-        )
-
-        return SimplePooler.from_config(resolved_config)
-
-    @staticmethod
-    def for_classify(
-        pooler_config: PoolerConfig,
-        classifier: ClassifierFn | None,
-    ):
-        resolved_config = ResolvedPoolingConfig.from_config(
-            task="classify",
-            pooler_config=pooler_config,
-        )
-
-        pooling = PoolingMethod.from_pooling_type(resolved_config.pooling_type)
-
-        return ClassifierPooler(
-            pooling=pooling,
-            classifier=classifier,
-        )
-
-    @abstractmethod
-    def get_supported_tasks(self) -> Set[PoolingTask]:
-        """Determine which pooling tasks are supported."""
-        raise NotImplementedError
-
-    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
-        """
-        Construct the updated pooling parameters to use for a supported task.
-        """
-        return PoolingParamsUpdate()
-
-    @abstractmethod
-    def forward(
-        self,
-        hidden_states: list[torch.Tensor] | torch.Tensor,
-        pooling_metadata: PoolingMetadata,
-    ) -> PoolerOutput:
-        raise NotImplementedError
-
-
 def get_prompt_lens(
     hidden_states: torch.Tensor | list[torch.Tensor],
     pooling_metadata: PoolingMetadata,
@@ -237,7 +177,7 @@ class PoolingMethod(nn.Module, ABC):
 
 class CLSPool(PoolingMethod):
     def get_supported_tasks(self) -> Set[PoolingTask]:
-        return {"encode", "embed", "classify", "score"}
+        return {"token_embed", "token_classify", "embed", "classify", "score"}
 
     def forward_all(
         self,
@@ -253,7 +193,7 @@ class CLSPool(PoolingMethod):
 
 class LastPool(PoolingMethod):
     def get_supported_tasks(self) -> Set[PoolingTask]:
-        return {"encode", "embed", "classify", "score"}
+        return {"token_embed", "token_classify", "embed", "classify", "score"}
 
     def forward_all(
         self,
@@ -265,7 +205,7 @@ class LastPool(PoolingMethod):
 
 class AllPool(PoolingMethod):
     def get_supported_tasks(self) -> Set[PoolingTask]:
-        return {"encode"}
+        return {"token_embed", "token_classify"}
 
     def forward_all(
         self,
@@ -284,7 +224,7 @@ class AllPool(PoolingMethod):
 
 class MeanPool(PoolingMethod):
     def get_supported_tasks(self) -> Set[PoolingTask]:
-        return {"encode", "embed", "classify", "score"}
+        return {"token_embed", "token_classify", "embed", "classify", "score"}
 
     def forward_all(
         self,
@@ -398,6 +338,82 @@ class LambdaPoolerActivation(PoolerActivation):
         return self.fn(pooled_data)
 
 
+class Pooler(nn.Module, ABC):
+    """The interface required for all poolers used in pooling models in vLLM."""
+
+    @staticmethod
+    def for_token_embed(pooler_config: PoolerConfig):
+        head = TokenEmbeddingPoolerHead()
+
+        if pooler_config.pooling_type == "STEP":
+            return StepPooler(head=head)
+
+        return AllPooler(head=head)
+
+    @staticmethod
+    def for_token_classify(
+        pooler_config: PoolerConfig,
+        classifier: ClassifierFn | None = None,
+        act_fn: PoolerActivation | str | None = None,
+    ):
+        head = TokenClassifierPoolerHead(classifier=classifier, act_fn=act_fn)
+
+        if pooler_config.pooling_type == "STEP":
+            return StepPooler(head=head)
+
+        return AllPooler(head=head)
+
+    @staticmethod
+    def for_embed(pooler_config: PoolerConfig):
+        resolved_config = ResolvedPoolingConfig.from_config(
+            task="embed",
+            pooler_config=pooler_config,
+        )
+
+        pooling = PoolingMethod.from_pooling_type(resolved_config.pooling_type)
+        head = EmbeddingPoolerHead()
+
+        return SimplePooler(pooling=pooling, head=head)
+
+    @staticmethod
+    def for_classify(
+        pooler_config: PoolerConfig,
+        classifier: ClassifierFn | None,
+        act_fn: PoolerActivation | str | None = None,
+    ):
+        resolved_config = ResolvedPoolingConfig.from_config(
+            task="classify",
+            pooler_config=pooler_config,
+        )
+
+        pooling = PoolingMethod.from_pooling_type(resolved_config.pooling_type)
+
+        return ClassifierPooler(
+            pooling=pooling,
+            classifier=classifier,
+            act_fn=act_fn,
+        )
+
+    @abstractmethod
+    def get_supported_tasks(self) -> Set[PoolingTask]:
+        """Determine which pooling tasks are supported."""
+        raise NotImplementedError
+
+    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
+        """
+        Construct the updated pooling parameters to use for a supported task.
+        """
+        return PoolingParamsUpdate()
+
+    @abstractmethod
+    def forward(
+        self,
+        hidden_states: list[torch.Tensor] | torch.Tensor,
+        pooling_metadata: PoolingMetadata,
+    ) -> PoolerOutput:
+        raise NotImplementedError
+
+
 class PoolerHead(nn.Module):
     def __init__(self, activation: PoolerActivation) -> None:
         super().__init__()
@@ -416,7 +432,6 @@ class EmbeddingPoolerHead(PoolerHead):
         super().__init__(activation=PoolerNormalize())
 
         # Load ST projector if available
-
         vllm_config = get_current_vllm_config()
         self.projector: nn.Module | None = (
             _load_st_projector(vllm_config.model_config) if vllm_config else None
@@ -471,39 +486,6 @@ class EmbeddingPoolerHead(PoolerHead):
         return pooled_data
 
 
-class RewardPoolerHead(PoolerHead):
-    def __init__(self) -> None:
-        super().__init__(activation=PoolerClassify(static_num_labels=False))
-
-        vllm_config = get_current_vllm_config()
-        self.head_dtype = vllm_config.model_config.head_dtype
-
-    def forward(
-        self,
-        pooled_data: list[torch.Tensor] | torch.Tensor,
-        pooling_metadata: PoolingMetadata,
-    ):
-        if isinstance(pooled_data, list):
-            pooled_data = [p.to(self.head_dtype) for p in pooled_data]
-        else:
-            pooled_data = pooled_data.to(self.head_dtype)
-
-        pooling_params = get_pooling_params(pooling_metadata)
-
-        # for softmax
-        flags = [p.softmax for p in pooling_params]
-        if len(set(flags)) == 1:
-            if flags[0]:
-                pooled_data = self.activation(pooled_data)
-        else:
-            pooled_data = [
-                self.activation(vecs) if f else vecs
-                for vecs, f in zip(pooled_data, flags)
-            ]
-
-        return pooled_data
-
-
 class SimplePooler(Pooler):
     """A layer that pools specific information from hidden states.
@@ -513,20 +495,6 @@ class SimplePooler(Pooler):
     3. Returns structured results as `PoolerOutput`.
     """
 
-    @classmethod
-    def from_config(
-        cls,
-        pooler_config: ResolvedPoolingConfig,
-    ) -> "SimplePooler":
-        pooling = PoolingMethod.from_pooling_type(pooler_config.pooling_type)
-        if pooler_config.task == "embed":
-            head = EmbeddingPoolerHead()
-        elif pooler_config.task == "encode":
-            head = RewardPoolerHead()
-        else:
-            raise NotImplementedError(f"Unknown task: {pooler_config.task}")
-        return cls(pooling, head)
-
     def __init__(self, pooling: PoolingMethod, head: PoolerHead) -> None:
         super().__init__()
 
@@ -549,58 +517,6 @@ class SimplePooler(Pooler):
         return pooled_data
 
 
-class StepPooler(Pooler):
-    def __init__(
-        self,
-    ) -> None:
-        super().__init__()
-
-        self.pooling = AllPool()
-        self.head = RewardPoolerHead()
-
-    def extract_states(
-        self,
-        hidden_states: torch.Tensor | list[torch.Tensor],
-        pooling_metadata: PoolingMetadata,
-    ) -> list[torch.Tensor] | torch.Tensor:
-        pooled_data_lst = self.pooling(hidden_states, pooling_metadata)
-        prompt_token_ids = get_prompt_token_ids(pooling_metadata)
-
-        pooled_data = list[torch.Tensor]()
-
-        pooling_params = get_pooling_params(pooling_metadata)
-
-        for data, token_id, pooling_param in zip(
-            pooled_data_lst, prompt_token_ids, pooling_params
-        ):
-            step_tag_id = pooling_param.step_tag_id
-            returned_token_ids = pooling_param.returned_token_ids
-
-            if returned_token_ids is not None and len(returned_token_ids) > 0:
-                data = data[:, returned_token_ids]
-
-            if step_tag_id is not None:
-                data = data[token_id == step_tag_id]
-            pooled_data.append(data)
-
-        return pooled_data
-
-    def get_supported_tasks(self) -> Set[PoolingTask]:
-        return {"encode"}
-
-    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
-        return PoolingParamsUpdate(requires_token_ids=True)
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor | list[torch.Tensor],
-        pooling_metadata: PoolingMetadata,
-    ) -> PoolerOutput:
-        pooled_data = self.extract_states(hidden_states, pooling_metadata)
-        pooled_data = self.head(pooled_data, pooling_metadata)
-        return pooled_data
-
-
 class ClassifierPooler(Pooler):
     """A pooling layer for classification tasks.
@@ -611,26 +527,46 @@ class ClassifierPooler(Pooler):
     """
 
     @staticmethod
-    def act_fn_for_seq_cls(config: ModelConfig):
-        return get_classification_activation_function(config.hf_config)
+    def act_fn_for_seq_cls(model_config: ModelConfig):
+        return get_classification_activation_function(model_config.hf_config)
 
     @staticmethod
-    def act_fn_for_cross_encoder(config: ModelConfig):
-        return get_cross_encoder_activation_function(config.hf_config)
+    def act_fn_for_cross_encoder(model_config: ModelConfig):
+        return get_cross_encoder_activation_function(model_config.hf_config)
+
+    @staticmethod
+    def resolve_act_fn(
+        model_config: ModelConfig,
+        static_num_labels: bool = True,
+        act_fn: PoolerActivation | str | None = None,
+    ):
+        if isinstance(act_fn, str):
+            if act_fn == "classify":
+                return ClassifierPooler.act_fn_for_seq_cls(model_config)
+            elif act_fn == "score":
+                return ClassifierPooler.act_fn_for_cross_encoder(model_config)
+            else:
+                raise ValueError(f"act_fn [{act_fn=}] not supported.")
+        elif act_fn is None:
+            return PoolerClassify(static_num_labels=static_num_labels)
+        else:
+            assert callable(act_fn)
+            return act_fn
 
     def __init__(
         self,
         pooling: PoolingFn,
         classifier: ClassifierFn | None,
-        act_fn: PoolerActivation | None = None,
+        act_fn: PoolerActivation | str | None = None,
     ) -> None:
         super().__init__()
 
         vllm_config = get_current_vllm_config()
 
         self.pooling = pooling
         self.classifier = classifier
-        self.act_fn = act_fn or PoolerClassify()
+        self.act_fn = self.resolve_act_fn(
+            vllm_config.model_config, static_num_labels=True, act_fn=act_fn
+        )
         self.logit_bias: float | None = (
             vllm_config.model_config.pooler_config.logit_bias
         )
@@ -672,6 +608,150 @@ class ClassifierPooler(Pooler):
         return scores
 
 
+class TokenEmbeddingPoolerHead(EmbeddingPoolerHead):
+    def forward(
+        self, pooled_data: torch.Tensor, pooling_param: PoolingParams
+    ) -> torch.Tensor:
+        pooled_data = pooled_data.to(self.head_dtype)
+        # pooled_data shape: [n_tokens, hidden_dimension]
+
+        # Apply ST projector
+        if self.projector is not None:
+            pooled_data = self.projector(pooled_data)
+            # pooled_data shape: [n_tokens, embedding_dimension]
+
+        # for matryoshka representation
+        pooled_data = pooled_data[..., : pooling_param.dimensions]
+
+        # for normalize
+        if pooling_param.normalize:
+            pooled_data = self.activation(pooled_data)
+
+        # pooled_data shape: [n_tokens, embedding_dimension]
+        return pooled_data
+
+
+class TokenClassifierPoolerHead(nn.Module):
+    def __init__(
+        self,
+        classifier: ClassifierFn | None,
+        act_fn: PoolerActivation | str | None = None,
+    ) -> None:
+        super().__init__()
+        vllm_config = get_current_vllm_config()
+
+        self.classifier = classifier
+        self.act_fn = ClassifierPooler.resolve_act_fn(
+            vllm_config.model_config, static_num_labels=False, act_fn=act_fn
+        )
+        self.logit_bias: float | None = (
+            vllm_config.model_config.pooler_config.logit_bias
+        )
+        self.head_dtype = vllm_config.model_config.head_dtype
+
+    def get_supported_tasks(self) -> Set[PoolingTask]:
+        return {"token_classify"}
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        pooling_param: PoolingParams,
+    ) -> torch.Tensor:
+        hidden_states = hidden_states.to(self.head_dtype)
+        # hidden_states shape: [n_token, hidden_size]
+
+        if self.classifier is not None:
+            scores = self.classifier(hidden_states)
+        else:
+            scores = hidden_states
+        # scores shape: [n_token, num_labels]
+
+        if self.logit_bias is not None:
+            scores -= self.logit_bias
+
+        if pooling_param.activation:
+            scores = self.act_fn(scores)
+
+        # scores shape: [n_token, num_labels]
+        return scores
+
+
+class AllPooler(Pooler):
+    def __init__(self, head: nn.Module | PoolerHead) -> None:
+        super().__init__()
+
+        self.pooling = AllPool()
+        self.head = head
+
+    def get_supported_tasks(self) -> Set[PoolingTask]:
+        return {"token_embed", "token_classify"}
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        pooling_metadata: PoolingMetadata,
+    ) -> PoolerOutput:
+        pooled_data = self.pooling(hidden_states, pooling_metadata)
+        pooling_params = get_pooling_params(pooling_metadata)
+        assert len(pooled_data) == len(pooling_params)
+
+        pooled_data = [self.head(d, p) for d, p in zip(pooled_data, pooling_params)]
+        return pooled_data
+
+
+class StepPooler(Pooler):
+    def __init__(self, head: nn.Module | PoolerHead) -> None:
+        super().__init__()
+
+        self.pooling = AllPool()
+        self.head = head
+
+    def extract_states(
+        self,
+        hidden_states: torch.Tensor | list[torch.Tensor],
+        pooling_metadata: PoolingMetadata,
+    ) -> torch.Tensor | list[torch.Tensor]:
+        pooled_data_lst = self.pooling(hidden_states, pooling_metadata)
+        prompt_token_ids = get_prompt_token_ids(pooling_metadata)
+
+        pooled_data = list[torch.Tensor]()
+
+        pooling_params = get_pooling_params(pooling_metadata)
+
+        for data, token_id, pooling_param in zip(
+            pooled_data_lst, prompt_token_ids, pooling_params
+        ):
+            step_tag_id = pooling_param.step_tag_id
+            returned_token_ids = pooling_param.returned_token_ids
+
+            if returned_token_ids is not None and len(returned_token_ids) > 0:
+                data = data[:, returned_token_ids]
+
+            if step_tag_id is not None:
+                data = data[token_id == step_tag_id]
+            pooled_data.append(data)
+
+        return pooled_data
+
+    def get_supported_tasks(self) -> Set[PoolingTask]:
+        return {"token_embed", "token_classify"}
+
+    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
+        return PoolingParamsUpdate(requires_token_ids=True)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor | list[torch.Tensor],
+        pooling_metadata: PoolingMetadata,
+    ) -> PoolerOutput:
+        pooled_data = self.extract_states(hidden_states, pooling_metadata)
+        pooling_params = get_pooling_params(pooling_metadata)
+        assert len(pooled_data) == len(pooling_params)
+
+        pooled_data = [self.head(d, p) for d, p in zip(pooled_data, pooling_params)]
+        return pooled_data
+
+
 class DispatchPooler(Pooler):
     """Dispatches calls to a sub-pooler based on the pooling task."""
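The STEP path above keeps only selected token positions and label columns before the head runs. A self-contained sketch of that selection logic (tensors invented for illustration, mirroring `StepPooler.extract_states`):

```python
import torch

data = torch.randn(7, 4)                     # [n_tokens, num_labels] for one prompt
token_ids = torch.tensor([5, 9, 5, 9, 9, 5, 2])
step_tag_id = 9                              # keep only "step tag" positions
returned_token_ids = [0, 2]                  # optionally restrict label columns

data = data[:, returned_token_ids]           # column selection first -> [7, 2]
data = data[token_ids == step_tag_id]        # then keep step-tag rows -> [3, 2]
print(data.shape)                            # torch.Size([3, 2])
```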
@@ -250,7 +250,7 @@ def as_embedding_model(cls: _T) -> _T:
 
         self.pooler = DispatchPooler(
             {
-                "encode": Pooler.for_encode(pooler_config),
+                "token_embed": Pooler.for_token_embed(pooler_config),
                 "embed": Pooler.for_embed(pooler_config),
             },
         )
@@ -279,11 +279,8 @@ def as_seq_cls_model(cls: _T) -> _T:
     # Lazy import
     from vllm.model_executor.layers.linear import ReplicatedLinear
     from vllm.model_executor.layers.pooler import (
-        ClassifierPooler,
         DispatchPooler,
         Pooler,
-        PoolingMethod,
-        PoolingType,
     )
     from vllm.model_executor.models.interfaces import SupportsCrossEncoding
     from vllm.sequence import IntermediateTensors
@@ -302,42 +299,29 @@ def as_seq_cls_model(cls: _T) -> _T:
             model_config.hidden_size,
             config.num_labels,
             bias=False,
-            params_dtype=torch.float32,
+            params_dtype=vllm_config.model_config.head_dtype,
             quant_config=quant_config,
+            return_bias=False,
             prefix=maybe_prefix(prefix, "score"),
         )
 
         pooler_config = vllm_config.model_config.pooler_config
         assert pooler_config is not None
 
-        pooling_type_str = pooler_config.pooling_type
-        assert pooling_type_str is not None
-        pooling_type = PoolingType[pooling_type_str]
-
         self.pooler = DispatchPooler(
             {
-                "encode": Pooler.for_encode(pooler_config),
-                "classify": ClassifierPooler(
-                    pooling=PoolingMethod.from_pooling_type(pooling_type),
-                    classifier=self._classifier,
-                    act_fn=ClassifierPooler.act_fn_for_seq_cls(
-                        vllm_config.model_config
-                    ),
+                "token_classify": Pooler.for_token_classify(
+                    pooler_config, classifier=self.score
                 ),
-                "score": ClassifierPooler(
-                    pooling=PoolingMethod.from_pooling_type(pooling_type),
-                    classifier=self._classifier,
-                    act_fn=ClassifierPooler.act_fn_for_cross_encoder(
-                        vllm_config.model_config
-                    ),
+                "classify": Pooler.for_classify(
+                    pooler_config, classifier=self.score, act_fn="classify"
+                ),
+                "score": Pooler.for_classify(
+                    pooler_config, classifier=self.score, act_fn="score"
                 ),
             }
         )
 
-    def _classifier(self, x: torch.Tensor):
-        x, _ = self.score(x.float())
-        return x
-
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -393,7 +377,11 @@ def as_reward_model(cls: _T) -> _T:
         assert pooler_config is not None
 
         self.pooler = DispatchPooler(
-            {"encode": Pooler.for_encode(pooler_config)},
+            {
+                "token_classify": Pooler.for_token_classify(
+                    pooler_config=pooler_config
+                )
+            }
         )
 
     ModelForReward.__name__ = _get_pooling_model_name(cls.__name__, "ForReward")
@@ -521,7 +521,7 @@ class BertEmbeddingModel(nn.Module, SupportsQuant):
     def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
         return DispatchPooler(
             {
-                "encode": Pooler.for_encode(pooler_config),
+                "token_embed": Pooler.for_token_embed(pooler_config),
                 "embed": Pooler.for_embed(pooler_config),
             }
         )
@@ -724,7 +724,7 @@ class BertSpladeSparseEmbeddingModel(BertEmbeddingModel):
 
         return DispatchPooler(
             {
-                "encode": Pooler.for_encode(pooler_config),
+                "token_embed": Pooler.for_token_embed(pooler_config),
                 "embed": SPLADESparsePooler(
                     mlm_head=self.mlm_head,
                     cls_token_id=cls_id,
@@ -821,20 +821,16 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQu
 
         self.pooler = DispatchPooler(
             {
-                "encode": Pooler.for_encode(pooler_config),
+                "token_classify": Pooler.for_token_classify(
+                    pooler_config, classifier=self.classifier
+                ),
                 "classify": ClassifierPooler(
                     pooling=self.bert.pooler,
                     classifier=self.classifier,
-                    act_fn=ClassifierPooler.act_fn_for_seq_cls(
-                        vllm_config.model_config
-                    ),
+                    act_fn="classify",
                 ),
                 "score": ClassifierPooler(
-                    pooling=self.bert.pooler,
-                    classifier=self.classifier,
-                    act_fn=ClassifierPooler.act_fn_for_cross_encoder(
-                        vllm_config.model_config
-                    ),
+                    pooling=self.bert.pooler, classifier=self.classifier, act_fn="score"
                 ),
             }
         )
@@ -891,7 +887,9 @@ class BertForTokenClassification(nn.Module):
 
         self.pooler = DispatchPooler(
             {
-                "encode": Pooler.for_encode(pooler_config),
+                "token_classify": Pooler.for_token_classify(
+                    pooler_config=pooler_config
+                ),
             }
         )
@@ -695,20 +695,16 @@ class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding):
 
         self.pooler = DispatchPooler(
             {
-                "encode": Pooler.for_encode(pooler_config),
+                "token_classify": Pooler.for_token_classify(
+                    pooler_config, classifier=self.classifier
+                ),
                 "classify": ClassifierPooler(
                     pooling=self.new.pooler,
                     classifier=self.classifier,
-                    act_fn=ClassifierPooler.act_fn_for_seq_cls(
-                        vllm_config.model_config
-                    ),
+                    act_fn="classify",
                 ),
                 "score": ClassifierPooler(
-                    pooling=self.new.pooler,
-                    classifier=self.classifier,
-                    act_fn=ClassifierPooler.act_fn_for_cross_encoder(
-                        vllm_config.model_config
-                    ),
+                    pooling=self.new.pooler, classifier=self.classifier, act_fn="score"
                 ),
             }
         )
@@ -837,7 +837,7 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
 
         self.pooler = DispatchPooler(
             {
-                "encode": Pooler.for_encode(pooler_config),
+                "token_embed": Pooler.for_token_embed(pooler_config),
                 "embed": Pooler.for_embed(pooler_config),
             }
         )
@@ -353,8 +353,15 @@ class GPT2ForSequenceClassification(nn.Module, SupportsCrossEncoding):
 
         self.pooler = DispatchPooler(
             {
-                "encode": Pooler.for_encode(pooler_config),
-                "classify": Pooler.for_classify(pooler_config, classifier=self.score),
+                "token_classify": Pooler.for_token_classify(
+                    pooler_config, classifier=self.score
+                ),
+                "classify": Pooler.for_classify(
+                    pooler_config, classifier=self.score, act_fn="classify"
+                ),
+                "score": Pooler.for_classify(
+                    pooler_config, classifier=self.score, act_fn="score"
+                ),
             }
         )
@@ -239,7 +239,7 @@ class GritLM(LlamaForCausalLM):
         if pooler_config is not None:
             self.pooler = DispatchPooler(
                 {
-                    "encode": Pooler.for_encode(pooler_config),
+                    "token_embed": Pooler.for_token_embed(pooler_config),
                     "embed": GritLMPooler(vllm_config.model_config),
                 }
             )
@@ -444,7 +444,7 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
         assert pooler_config is not None
 
         self.pooler = DispatchPooler(
-            {"encode": Pooler.for_encode(pooler_config)},
+            {"token_classify": Pooler.for_token_classify(pooler_config)}
         )
 
     def forward(
@@ -604,10 +604,14 @@ class JambaForSequenceClassification(JambaForCausalLM):
 
         self.pooler = DispatchPooler(
             {
-                "encode": Pooler.for_encode(pooler_config),
+                "token_classify": Pooler.for_token_classify(
+                    pooler_config, classifier=self.score
+                ),
                 "classify": Pooler.for_classify(
-                    pooler_config,
-                    classifier=self.score,
+                    pooler_config, classifier=self.score, act_fn="classify"
+                ),
+                "score": Pooler.for_classify(
+                    pooler_config, classifier=self.score, act_fn="score"
                 ),
             }
         )
@@ -97,9 +97,15 @@ class JinaVLForSequenceClassification(
         self.score = JinaVLScorer(vllm_config.model_config)
         self.pooler = DispatchPooler(
             {
-                "encode": Pooler.for_encode(pooler_config),
-                "classify": Pooler.for_classify(pooler_config, classifier=self.score),
-                "score": Pooler.for_classify(pooler_config, classifier=self.score),
+                "token_classify": Pooler.for_token_classify(
+                    pooler_config, classifier=self.score
+                ),
+                "classify": Pooler.for_classify(
+                    pooler_config, classifier=self.score, act_fn="classify"
+                ),
+                "score": Pooler.for_classify(
+                    pooler_config, classifier=self.score, act_fn="score"
+                ),
             }
         )
@@ -322,20 +322,14 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
 
         self.pooler = DispatchPooler(
             {
-                "encode": Pooler.for_encode(pooler_config),
+                "token_classify": Pooler.for_token_classify(
+                    pooler_config, classifier=self.classifier
+                ),
                 "classify": ClassifierPooler(
-                    pooling=self.pooling,
-                    classifier=self.classifier,
-                    act_fn=ClassifierPooler.act_fn_for_seq_cls(
-                        vllm_config.model_config
-                    ),
+                    pooling=self.pooling, classifier=self.classifier, act_fn="classify"
                 ),
                 "score": ClassifierPooler(
-                    pooling=self.pooling,
-                    classifier=self.classifier,
-                    act_fn=ClassifierPooler.act_fn_for_cross_encoder(
-                        vllm_config.model_config
-                    ),
+                    pooling=self.pooling, classifier=self.classifier, act_fn="score"
                 ),
             }
         )
@@ -421,7 +415,9 @@ class ModernBertForTokenClassification(nn.Module):
 
         self.pooler = DispatchPooler(
             {
-                "encode": Pooler.for_encode(pooler_config),
+                "token_classify": Pooler.for_token_classify(
+                    pooler_config=pooler_config
+                ),
             }
         )
|
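With `ModernBertForTokenClassification` dispatching on `token_classify`, token-level heads such as NER can be queried as below (hedged sketch; the checkpoint is illustrative):

```python
from vllm import LLM

llm = LLM(model="boltuix/NeuroBERT-NER", runner="pooling")
out = llm.encode(["Barack Obama visited Paris"], pooling_task="token_classify")
logits = out[0].outputs.data   # shape: (num_tokens, num_labels)
print(logits.argmax(dim=-1))   # one label id per token
```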
@@ -107,7 +107,7 @@ class Qwen2ForRewardModel(Qwen2RewardBaseModel):
         assert pooler_config is not None

         self.pooler = DispatchPooler(
-            {"encode": Pooler.for_encode(pooler_config)},
+            {"token_classify": Pooler.for_token_classify(pooler_config)}
         )

@@ -120,4 +120,6 @@ class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
         pooler_config = vllm_config.model_config.pooler_config
         assert pooler_config is not None

-        self.pooler = DispatchPooler({"encode": Pooler.for_encode(pooler_config)})
+        self.pooler = DispatchPooler(
+            {"token_classify": Pooler.for_token_classify(pooler_config)}
+        )
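Reward and process-reward models follow the same move from `encode` to `token_classify`. A hedged sketch (model name and prompt are illustrative):

```python
from vllm import LLM

llm = LLM(model="Qwen/Qwen2.5-Math-PRM-7B", runner="pooling",
          trust_remote_code=True)
out = llm.encode(["Step 1: compute 2 + 2.\nStep 2: the answer is 4."],
                 pooling_task="token_classify")
print(out[0].outputs.data.shape)  # one reward vector per token
```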
@@ -105,15 +105,7 @@ class RobertaClassificationHead(nn.Module):

 @default_pooling_type("CLS")
 class RobertaEmbeddingModel(BertEmbeddingModel):
-    """A model that uses Roberta to provide embedding functionalities.
-
-    This class encapsulates the BertModel and provides an interface for
-    embedding operations and customized pooling functions.
-
-    Attributes:
-        model: An instance of BertModel used for forward operations.
-        _pooler: An instance of Pooler used for pooling operations.
-    """
+    """A model that uses Roberta to provide embedding functionalities."""

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__(vllm_config=vllm_config, prefix=prefix)
@@ -212,20 +204,14 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):

         self.pooler = DispatchPooler(
             {
-                "encode": Pooler.for_encode(pooler_config),
+                "token_classify": Pooler.for_token_classify(
+                    pooler_config=pooler_config, classifier=self.classifier
+                ),
                 "classify": ClassifierPooler(
-                    pooling=CLSPool(),
-                    classifier=self.classifier,
-                    act_fn=ClassifierPooler.act_fn_for_seq_cls(
-                        vllm_config.model_config
-                    ),
+                    pooling=CLSPool(), classifier=self.classifier, act_fn="classify"
                 ),
                 "score": ClassifierPooler(
-                    pooling=CLSPool(),
-                    classifier=self.classifier,
-                    act_fn=ClassifierPooler.act_fn_for_cross_encoder(
-                        vllm_config.model_config
-                    ),
+                    pooling=CLSPool(), classifier=self.classifier, act_fn="score"
                 ),
             }
         )
@@ -250,7 +250,7 @@ class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
         assert pooler_config is not None

         self.pooler = DispatchPooler(
-            {"encode": Pooler.for_encode(pooler_config)},
+            {"token_classify": Pooler.for_token_classify(pooler_config)}
         )

     def get_input_embeddings(
@@ -135,7 +135,7 @@ class TransformersEmbeddingModel(TransformersPoolingBase):

         self.pooler = DispatchPooler(
             {
-                "encode": Pooler.for_encode(pooler_config),
+                "token_embed": Pooler.for_token_embed(pooler_config),
                 "embed": Pooler.for_embed(pooler_config),
             }
         )
@@ -190,20 +190,14 @@ class TransformersForSequenceClassification(TransformersPoolingBase):

         self.pooler = DispatchPooler(
             {
-                "encode": Pooler.for_encode(pooler_config),
+                "token_classify": Pooler.for_token_classify(
+                    pooler_config, classifier=self.classifier
+                ),
                 "classify": ClassifierPooler(
-                    pooling=CLSPool(),
-                    classifier=self.classifier,
-                    act_fn=ClassifierPooler.act_fn_for_seq_cls(
-                        vllm_config.model_config
-                    ),
+                    pooling=CLSPool(), classifier=self.classifier, act_fn="classify"
                 ),
                 "score": ClassifierPooler(
-                    pooling=CLSPool(),
-                    classifier=self.classifier,
-                    act_fn=ClassifierPooler.act_fn_for_cross_encoder(
-                        vllm_config.model_config
-                    ),
+                    pooling=CLSPool(), classifier=self.classifier, act_fn="score"
                 ),
             }
         )
@@ -10,7 +10,7 @@ from vllm.sampling_params import RequestOutputKind
 from vllm.tasks import PoolingTask

 if TYPE_CHECKING:
-    from vllm.config import ModelConfig
+    from vllm.config import ModelConfig, PoolerConfig


 class PoolingParams(
@@ -30,7 +30,6 @@ class PoolingParams(
             if model support matryoshka representation.
         activation: Whether to apply activation function to
             the classification outputs.
-        softmax: Whether to apply softmax to the reward outputs.
     """

     # --8<-- [start:common-pooling-params]
@@ -48,32 +47,19 @@ class PoolingParams(
     activation: bool | None = None
     # --8<-- [end:classification-pooling-params]

-    ## for reward models
-    softmax: bool | None = None
     ## for step pooling models
     step_tag_id: int | None = None
     returned_token_ids: list[int] | None = None

+    ## Internal use only
     task: PoolingTask | None = None
     """Internal use only."""

     requires_token_ids: bool = False
     """Internal use only."""

     extra_kwargs: dict[str, Any] | None = None
     """Internal use only."""

     output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY

     @property
     def all_parameters(self) -> list[str]:
-        return [
-            "dimensions",
-            "normalize",
-            "activation",
-            "softmax",
-            "step_tag_id",
-            "returned_token_ids",
-        ]
+        return ["dimensions", "normalize", "activation"]

     @property
     def valid_parameters(self):
@@ -81,7 +67,8 @@ class PoolingParams(
             "embed": ["dimensions", "normalize"],
             "classify": ["activation"],
             "score": ["activation"],
-            "encode": ["softmax", "step_tag_id", "returned_token_ids"],
+            "token_embed": ["dimensions", "normalize"],
+            "token_classify": ["activation"],
         }

     def clone(self) -> "PoolingParams":
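For reference, how the reworked parameter sets line up with the new tasks; this sketch only constructs parameters, full validation still runs inside `verify()`:

```python
from vllm import PoolingParams

# Token-level tasks reuse the request-level parameter sets: "token_embed"
# accepts the embedding knobs, "token_classify" the classification knob.
embed_like = PoolingParams(normalize=False)      # valid for embed / token_embed
classify_like = PoolingParams(activation=False)  # valid for classify / token_classify

print(embed_like.valid_parameters["token_embed"])        # ['dimensions', 'normalize']
print(classify_like.valid_parameters["token_classify"])  # ['activation']
```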
@@ -100,7 +87,6 @@ class PoolingParams(
         # NOTE: Task validation needs to done against the model instance,
         # which is not available in model config. So, it's not included
         # in this method
-
         self._merge_default_parameters(model_config)
         self._set_default_parameters(model_config)
         self._verify_valid_parameters()
@@ -125,8 +111,34 @@ class PoolingParams(
             if getattr(self, k, None) is None:
                 setattr(self, k, getattr(pooler_config, k))

+        self._verify_step_pooling(pooler_config, valid_parameters)
+
+    def _verify_step_pooling(
+        self, pooler_config: "PoolerConfig", valid_parameters: list[str]
+    ):
+        step_pooling_parameters = ["step_tag_id", "returned_token_ids"]
+        if pooler_config.pooling_type != "STEP":
+            invalid_parameters = []
+            for k in step_pooling_parameters:
+                if getattr(self, k, None) is not None:
+                    invalid_parameters.append(k)
+
+            if invalid_parameters:
+                raise ValueError(
+                    f"Task {self.task} only supports {valid_parameters} "
+                    f"parameters, does not support "
+                    f"{invalid_parameters} parameters"
+                )
+        else:
+            for k in step_pooling_parameters:
+                if getattr(pooler_config, k, None) is None:
+                    continue
+
+                if getattr(self, k, None) is None:
+                    setattr(self, k, getattr(pooler_config, k))
+
     def _set_default_parameters(self, model_config: Optional["ModelConfig"]):
-        if self.task == "embed":
+        if self.task in ["embed", "token_embed"]:
             if self.normalize is None:
                 self.normalize = True
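In plain terms: step-pooling parameters are rejected unless the model's pooler actually uses STEP pooling; otherwise unset values are backfilled from the pooler config. A self-contained replica of that rule, for illustration only:

```python
from types import SimpleNamespace

def verify_step_pooling(params, pooler_config, valid_parameters):
    # Mirrors PoolingParams._verify_step_pooling from the diff above.
    step_pooling_parameters = ["step_tag_id", "returned_token_ids"]
    if pooler_config.pooling_type != "STEP":
        invalid = [k for k in step_pooling_parameters
                   if getattr(params, k, None) is not None]
        if invalid:
            raise ValueError(
                f"Task {params.task} only supports {valid_parameters} "
                f"parameters, does not support {invalid} parameters"
            )
    else:
        for k in step_pooling_parameters:
            if getattr(pooler_config, k, None) is None:
                continue
            if getattr(params, k, None) is None:
                setattr(params, k, getattr(pooler_config, k))

params = SimpleNamespace(task="token_classify", step_tag_id=5,
                         returned_token_ids=None)
config = SimpleNamespace(pooling_type="ALL", step_tag_id=None,
                         returned_token_ids=None)
try:
    verify_step_pooling(params, config, ["activation"])
except ValueError as e:
    print(e)  # step_tag_id is rejected without STEP pooling
```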
@@ -150,13 +162,9 @@ class PoolingParams(
             elif self.dimensions < 1:
                 raise ValueError("Dimensions must be greater than 0")

-        elif self.task in ["classify", "score"]:
+        elif self.task in ["classify", "score", "token_classify"]:
             if self.activation is None:
                 self.activation = True

-        elif self.task == "encode":
-            if self.softmax is None:
-                self.softmax = True
-
         else:
             raise ValueError(f"Unknown pooling task: {self.task}")
@@ -185,7 +193,6 @@ class PoolingParams(
             f"normalize={self.normalize}, "
             f"dimensions={self.dimensions}, "
             f"activation={self.activation}, "
-            f"softmax={self.softmax}, "
             f"step_tag_id={self.step_tag_id}, "
             f"returned_token_ids={self.returned_token_ids}, "
             f"requires_token_ids={self.requires_token_ids}, "
@@ -5,7 +5,7 @@ from typing import Literal, get_args
 GenerationTask = Literal["generate", "transcription"]
 GENERATION_TASKS = get_args(GenerationTask)

-PoolingTask = Literal["encode", "embed", "classify", "score"]
+PoolingTask = Literal["embed", "classify", "score", "token_embed", "token_classify"]
 POOLING_TASKS = get_args(PoolingTask)

 SupportedTask = Literal[GenerationTask, PoolingTask]
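A quick sanity check of the new task vocabulary (mirrors the `Literal` above; `encode` is gone, replaced by the token-level variants):

```python
from vllm.tasks import POOLING_TASKS

print(POOLING_TASKS)
# ('embed', 'classify', 'score', 'token_embed', 'token_classify')
```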
@@ -1926,15 +1926,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

         supported_tasks = list(model.pooler.get_supported_tasks())

-        if (
-            self.scheduler_config.chunked_prefill_enabled
-            and "encode" in supported_tasks
-        ):
-            supported_tasks.remove("encode")
+        if self.scheduler_config.chunked_prefill_enabled:
+            if "token_embed" in supported_tasks:
+                supported_tasks.remove("token_embed")
+            if "token_classify" in supported_tasks:
+                supported_tasks.remove("token_classify")

             logger.debug_once(
                 "Chunked prefill is not supported with "
-                "encode task which using ALL pooling. "
+                "token_embed and token_classify tasks "
+                "which using ALL pooling. "
                 "Please turn off chunked prefill by "
                 "`--no-enable-chunked-prefill` before using it."
             )
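Since the token-level tasks use ALL pooling, they are dropped whenever chunked prefill is on. A hedged offline sketch of keeping them available (equivalent to passing `--no-enable-chunked-prefill` when serving):

```python
from vllm import LLM

llm = LLM(model="BAAI/bge-m3", runner="pooling",
          enable_chunked_prefill=False)
out = llm.encode(["multi-vector retrieval needs per-token output"],
                 pooling_task="token_embed")
print(out[0].outputs.data.shape)
```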