diff --git a/tests/entrypoints/llm/test_classify.py b/tests/entrypoints/llm/test_classify.py index abdce8935e..71e76abcb7 100644 --- a/tests/entrypoints/llm/test_classify.py +++ b/tests/entrypoints/llm/test_classify.py @@ -65,3 +65,9 @@ def test_pooling_params(llm: LLM): assert torch.allclose( softmax(wo_activation), w_activation, atol=1e-2 ), "w_activation should be close to activation(wo_activation)." + + +def test_encode_api(llm: LLM): + err_msg = "pooling_task must be one of.+" + with pytest.raises(ValueError, match=err_msg): + llm.encode(prompts, use_tqdm=False) diff --git a/tests/entrypoints/openai/test_classification.py b/tests/entrypoints/openai/test_classification.py index 886267c211..30078fe902 100644 --- a/tests/entrypoints/openai/test_classification.py +++ b/tests/entrypoints/openai/test_classification.py @@ -211,3 +211,18 @@ async def test_activation(server: RemoteOpenAIServer, model_name: str): assert torch.allclose( F.softmax(wo_activation, dim=-1), w_activation, atol=1e-2 ), "w_activation should be close to activation(wo_activation)." + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +def test_pooling(server: RemoteOpenAIServer, model_name: str): + # pooling api uses ALL pooling, which does not support chunked prefill. + response = requests.post( + server.url_for("pooling"), + json={ + "model": model_name, + "input": "test", + "encoding_format": "float" + }, + ) + assert response.json()["error"]["type"] == "BadRequestError" diff --git a/tests/models/language/pooling/mteb_utils.py b/tests/models/language/pooling/mteb_utils.py index 77aaddb4f5..d024c76ddd 100644 --- a/tests/models/language/pooling/mteb_utils.py +++ b/tests/models/language/pooling/mteb_utils.py @@ -177,9 +177,12 @@ def mteb_test_embed_models(hf_runner, max_model_len=None, **vllm_extra_kwargs) as vllm_model: + model_config = vllm_model.llm.llm_engine.model_config + if model_info.architecture: - assert (model_info.architecture - in vllm_model.llm.llm_engine.model_config.architectures) + assert model_info.architecture in model_config.architectures + assert (model_config._model_info.default_pooling_type == + model_info.default_pooling_type) vllm_main_score = run_mteb_embed_task(VllmMtebEncoder(vllm_model), MTEB_EMBED_TASKS) @@ -286,7 +289,12 @@ def mteb_test_rerank_models(hf_runner, **vllm_extra_kwargs) as vllm_model: model_config = vllm_model.llm.llm_engine.model_config + + if model_info.architecture: + assert (model_info.architecture in model_config.architectures) assert model_config.hf_config.num_labels == 1 + assert (model_config._model_info.default_pooling_type == + model_info.default_pooling_type) vllm_main_score = run_mteb_rerank(vllm_mteb_encoder(vllm_model), tasks=MTEB_RERANK_TASKS, diff --git a/tests/models/language/pooling/test_auto_prefix_cache_support.py b/tests/models/language/pooling/test_auto_prefix_cache_support.py new file mode 100644 index 0000000000..15e24c59d1 --- /dev/null +++ b/tests/models/language/pooling/test_auto_prefix_cache_support.py @@ -0,0 +1,93 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest +import torch +from transformers import AutoModelForSequenceClassification + +from tests.models.language.pooling.embed_utils import ( + run_embedding_correctness_test) + + +@pytest.mark.parametrize( + "model", + ["jason9693/Qwen2.5-1.5B-apeach"], +) +@pytest.mark.parametrize("dtype", ["half"]) +def test_classify_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + 
dtype: str, +) -> None: + + example_prompts = example_prompts * 2 + + with vllm_runner(model, + max_model_len=512, + dtype=dtype, + enable_prefix_caching=True) as vllm_model: + cache_config = vllm_model.llm.llm_engine.cache_config + assert cache_config.enable_prefix_caching + vllm_outputs = vllm_model.classify(example_prompts) + + with hf_runner(model, + dtype=dtype, + auto_cls=AutoModelForSequenceClassification) as hf_model: + hf_outputs = hf_model.classify(example_prompts) + + for hf_output, vllm_output in zip(hf_outputs, vllm_outputs): + hf_output = torch.tensor(hf_output) + vllm_output = torch.tensor(vllm_output) + + assert torch.allclose(hf_output, vllm_output, + 1e-3 if dtype == "float" else 1e-2) + + +@pytest.mark.parametrize( + "model", + ["Qwen/Qwen3-Embedding-0.6B"], +) +@pytest.mark.parametrize("dtype", ["half"]) +def test_embed_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, +): + example_prompts = [str(s).strip() for s in example_prompts] * 2 + + with vllm_runner( + model, + runner="pooling", + max_model_len=None, + enable_prefix_caching=True, + ) as vllm_model: + cache_config = vllm_model.llm.llm_engine.cache_config + assert cache_config.enable_prefix_caching + vllm_outputs = vllm_model.embed(example_prompts) + + with hf_runner( + model, + is_sentence_transformer=True, + ) as hf_model: + run_embedding_correctness_test(hf_model, example_prompts, vllm_outputs) + + +@pytest.mark.parametrize( + "model", + [ + "intfloat/e5-small", + "Alibaba-NLP/gte-Qwen2-1.5B-instruct", # is_causal == False + "papluca/xlm-roberta-base-language-detection", + ]) +@pytest.mark.parametrize("dtype", ["half"]) +def test_non_causal_models(hf_runner, vllm_runner, example_prompts, model: str, + dtype: str) -> None: + with vllm_runner(model, + max_model_len=512, + dtype=dtype, + enable_prefix_caching=True) as vllm_model: + cache_config = vllm_model.llm.llm_engine.cache_config + assert not cache_config.enable_prefix_caching diff --git a/tests/models/language/pooling/test_baai.py b/tests/models/language/pooling/test_baai.py index 64a8f25220..6fbe0e82d7 100644 --- a/tests/models/language/pooling/test_baai.py +++ b/tests/models/language/pooling/test_baai.py @@ -2,73 +2,78 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import pytest -from ...utils import EmbedModelInfo, RerankModelInfo +from ...utils import (CLSPoolingEmbedModelInfo, CLSPoolingRerankModelInfo, + EmbedModelInfo, LASTPoolingEmbedModelInfo, + RerankModelInfo) from .embed_utils import correctness_test_embed_models from .mteb_utils import mteb_test_embed_models, mteb_test_rerank_models MODELS = [ ########## BertModel - EmbedModelInfo("BAAI/bge-base-en", - architecture="BertModel", - enable_test=True), - EmbedModelInfo("BAAI/bge-base-zh", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("BAAI/bge-small-en", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("BAAI/bge-small-zh", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("BAAI/bge-large-en", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("BAAI/bge-large-zh", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("BAAI/bge-large-zh-noinstruct", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("BAAI/bge-base-en-v1.5", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("BAAI/bge-base-zh-v1.5", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("BAAI/bge-small-en-v1.5", - architecture="BertModel", - 
enable_test=False), - EmbedModelInfo("BAAI/bge-small-zh-v1.5", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("BAAI/bge-large-en-v1.5", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("BAAI/bge-large-zh-v1.5", - architecture="BertModel", - enable_test=False), + CLSPoolingEmbedModelInfo("BAAI/bge-base-en", + architecture="BertModel", + enable_test=True), + CLSPoolingEmbedModelInfo("BAAI/bge-base-zh", + architecture="BertModel", + enable_test=False), + CLSPoolingEmbedModelInfo("BAAI/bge-small-en", + architecture="BertModel", + enable_test=False), + CLSPoolingEmbedModelInfo("BAAI/bge-small-zh", + architecture="BertModel", + enable_test=False), + CLSPoolingEmbedModelInfo("BAAI/bge-large-en", + architecture="BertModel", + enable_test=False), + CLSPoolingEmbedModelInfo("BAAI/bge-large-zh", + architecture="BertModel", + enable_test=False), + CLSPoolingEmbedModelInfo("BAAI/bge-large-zh-noinstruct", + architecture="BertModel", + enable_test=False), + CLSPoolingEmbedModelInfo("BAAI/bge-base-en-v1.5", + architecture="BertModel", + enable_test=False), + CLSPoolingEmbedModelInfo("BAAI/bge-base-zh-v1.5", + architecture="BertModel", + enable_test=False), + CLSPoolingEmbedModelInfo("BAAI/bge-small-en-v1.5", + architecture="BertModel", + enable_test=False), + CLSPoolingEmbedModelInfo("BAAI/bge-small-zh-v1.5", + architecture="BertModel", + enable_test=False), + CLSPoolingEmbedModelInfo("BAAI/bge-large-en-v1.5", + architecture="BertModel", + enable_test=False), + CLSPoolingEmbedModelInfo("BAAI/bge-large-zh-v1.5", + architecture="BertModel", + enable_test=False), ########## XLMRobertaModel - EmbedModelInfo("BAAI/bge-m3", - architecture="XLMRobertaModel", - enable_test=True), + CLSPoolingEmbedModelInfo("BAAI/bge-m3", + architecture="XLMRobertaModel", + enable_test=True), ########## Qwen2Model - EmbedModelInfo("BAAI/bge-code-v1", - architecture="Qwen2Model", - dtype="float32", - enable_test=True), + LASTPoolingEmbedModelInfo("BAAI/bge-code-v1", + architecture="Qwen2Model", + dtype="float32", + enable_test=True), ] RERANK_MODELS = [ ########## XLMRobertaForSequenceClassification - RerankModelInfo("BAAI/bge-reranker-base", - architecture="XLMRobertaForSequenceClassification", - enable_test=True), - RerankModelInfo("BAAI/bge-reranker-large", - architecture="XLMRobertaForSequenceClassification", - enable_test=False), - RerankModelInfo("BAAI/bge-reranker-v2-m3", - architecture="XLMRobertaForSequenceClassification", - enable_test=False) + CLSPoolingRerankModelInfo( + "BAAI/bge-reranker-base", + architecture="XLMRobertaForSequenceClassification", + enable_test=True), + CLSPoolingRerankModelInfo( + "BAAI/bge-reranker-large", + architecture="XLMRobertaForSequenceClassification", + enable_test=False), + CLSPoolingRerankModelInfo( + "BAAI/bge-reranker-v2-m3", + architecture="XLMRobertaForSequenceClassification", + enable_test=False) ] diff --git a/tests/models/language/pooling/test_bge_reranker_v2_gemma.py b/tests/models/language/pooling/test_bge_reranker_v2_gemma.py index 7fa9485dbc..206524d7ca 100644 --- a/tests/models/language/pooling/test_bge_reranker_v2_gemma.py +++ b/tests/models/language/pooling/test_bge_reranker_v2_gemma.py @@ -8,12 +8,12 @@ import torch from tests.conftest import HfRunner -from .mteb_utils import (RerankModelInfo, VllmMtebEncoder, - mteb_test_rerank_models) +from ...utils import LASTPoolingRerankModelInfo, RerankModelInfo +from .mteb_utils import VllmMtebEncoder, mteb_test_rerank_models RERANK_MODELS = [ - RerankModelInfo("BAAI/bge-reranker-v2-gemma", - 
architecture="GemmaForSequenceClassification"), + LASTPoolingRerankModelInfo("BAAI/bge-reranker-v2-gemma", + architecture="GemmaForSequenceClassification"), ] PROMPT = "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'." # noqa: E501 diff --git a/tests/models/language/pooling/test_cross_encoder.py b/tests/models/language/pooling/test_cross_encoder.py index 9a33063d7b..8c1bc5779b 100644 --- a/tests/models/language/pooling/test_cross_encoder.py +++ b/tests/models/language/pooling/test_cross_encoder.py @@ -2,13 +2,15 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import pytest -from .mteb_utils import RerankModelInfo, mteb_test_rerank_models +from ...utils import (CLSPoolingRerankModelInfo, LASTPoolingRerankModelInfo, + RerankModelInfo) +from .mteb_utils import mteb_test_rerank_models RERANK_MODELS = [ - RerankModelInfo("cross-encoder/ms-marco-TinyBERT-L-2-v2", - architecture="BertForSequenceClassification"), - RerankModelInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls", - architecture="Qwen3ForSequenceClassification") + CLSPoolingRerankModelInfo("cross-encoder/ms-marco-TinyBERT-L-2-v2", + architecture="BertForSequenceClassification"), + LASTPoolingRerankModelInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls", + architecture="Qwen3ForSequenceClassification") ] diff --git a/tests/models/language/pooling/test_gte.py b/tests/models/language/pooling/test_gte.py index 48a0cd64fe..5a5fdfbb21 100644 --- a/tests/models/language/pooling/test_gte.py +++ b/tests/models/language/pooling/test_gte.py @@ -4,57 +4,58 @@ from typing import Any import pytest -from ...utils import check_transformers_version -from .embed_utils import EmbedModelInfo, correctness_test_embed_models +from ...utils import (CLSPoolingEmbedModelInfo, EmbedModelInfo, + LASTPoolingEmbedModelInfo, check_transformers_version) +from .embed_utils import correctness_test_embed_models from .mteb_utils import mteb_test_embed_models MODELS = [ ########## BertModel - EmbedModelInfo("thenlper/gte-large", - architecture="BertModel", - enable_test=True), - EmbedModelInfo("thenlper/gte-base", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("thenlper/gte-small", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("thenlper/gte-large-zh", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("thenlper/gte-base-zh", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("thenlper/gte-small-zh", - architecture="BertModel", - enable_test=False), + CLSPoolingEmbedModelInfo("thenlper/gte-large", + architecture="BertModel", + enable_test=True), + CLSPoolingEmbedModelInfo("thenlper/gte-base", + architecture="BertModel", + enable_test=False), + CLSPoolingEmbedModelInfo("thenlper/gte-small", + architecture="BertModel", + enable_test=False), + CLSPoolingEmbedModelInfo("thenlper/gte-large-zh", + architecture="BertModel", + enable_test=False), + CLSPoolingEmbedModelInfo("thenlper/gte-base-zh", + architecture="BertModel", + enable_test=False), + CLSPoolingEmbedModelInfo("thenlper/gte-small-zh", + architecture="BertModel", + enable_test=False), ########### NewModel - EmbedModelInfo("Alibaba-NLP/gte-multilingual-base", - architecture="GteNewModel", - enable_test=True), - EmbedModelInfo("Alibaba-NLP/gte-base-en-v1.5", - architecture="GteNewModel", - enable_test=True), - EmbedModelInfo("Alibaba-NLP/gte-large-en-v1.5", - architecture="GteNewModel", - enable_test=True), + 
CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-multilingual-base", + architecture="GteNewModel", + enable_test=True), + CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-base-en-v1.5", + architecture="GteNewModel", + enable_test=True), + CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-large-en-v1.5", + architecture="GteNewModel", + enable_test=True), ########### Qwen2ForCausalLM - EmbedModelInfo("Alibaba-NLP/gte-Qwen2-1.5B-instruct", - architecture="Qwen2ForCausalLM", - enable_test=True), + LASTPoolingEmbedModelInfo("Alibaba-NLP/gte-Qwen2-1.5B-instruct", + architecture="Qwen2ForCausalLM", + enable_test=True), ########## ModernBertModel - EmbedModelInfo("Alibaba-NLP/gte-modernbert-base", - architecture="ModernBertModel", - enable_test=True), + CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-modernbert-base", + architecture="ModernBertModel", + enable_test=True), ########## Qwen3ForCausalLM - EmbedModelInfo("Qwen/Qwen3-Embedding-0.6B", - architecture="Qwen3ForCausalLM", - dtype="float32", - enable_test=True), - EmbedModelInfo("Qwen/Qwen3-Embedding-4B", - architecture="Qwen3ForCausalLM", - dtype="float32", - enable_test=False), + LASTPoolingEmbedModelInfo("Qwen/Qwen3-Embedding-0.6B", + architecture="Qwen3ForCausalLM", + dtype="float32", + enable_test=True), + LASTPoolingEmbedModelInfo("Qwen/Qwen3-Embedding-4B", + architecture="Qwen3ForCausalLM", + dtype="float32", + enable_test=False), ] diff --git a/tests/models/language/pooling/test_intfloat.py b/tests/models/language/pooling/test_intfloat.py index d899aaada2..e48bdbe940 100644 --- a/tests/models/language/pooling/test_intfloat.py +++ b/tests/models/language/pooling/test_intfloat.py @@ -2,34 +2,34 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import pytest -from ...utils import EmbedModelInfo +from ...utils import CLSPoolingEmbedModelInfo, EmbedModelInfo from .embed_utils import correctness_test_embed_models from .mteb_utils import mteb_test_embed_models MODELS = [ ########## BertModel - EmbedModelInfo("intfloat/e5-small", - architecture="BertModel", - enable_test=True), - EmbedModelInfo("intfloat/e5-base", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("intfloat/e5-large", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("intfloat/multilingual-e5-small", - architecture="BertModel", - enable_test=False), + CLSPoolingEmbedModelInfo("intfloat/e5-small", + architecture="BertModel", + enable_test=True), + CLSPoolingEmbedModelInfo("intfloat/e5-base", + architecture="BertModel", + enable_test=False), + CLSPoolingEmbedModelInfo("intfloat/e5-large", + architecture="BertModel", + enable_test=False), + CLSPoolingEmbedModelInfo("intfloat/multilingual-e5-small", + architecture="BertModel", + enable_test=False), ########## XLMRobertaModel - EmbedModelInfo("intfloat/multilingual-e5-base", - architecture="XLMRobertaModel", - enable_test=True), - EmbedModelInfo("intfloat/multilingual-e5-large", - architecture="XLMRobertaModel", - enable_test=False), - EmbedModelInfo("intfloat/multilingual-e5-large-instruct", - architecture="XLMRobertaModel", - enable_test=False), + CLSPoolingEmbedModelInfo("intfloat/multilingual-e5-base", + architecture="XLMRobertaModel", + enable_test=True), + CLSPoolingEmbedModelInfo("intfloat/multilingual-e5-large", + architecture="XLMRobertaModel", + enable_test=False), + CLSPoolingEmbedModelInfo("intfloat/multilingual-e5-large-instruct", + architecture="XLMRobertaModel", + enable_test=False), ] diff --git a/tests/models/language/pooling/test_jina.py b/tests/models/language/pooling/test_jina.py 
index 59b634428c..37c5bdc97d 100644 --- a/tests/models/language/pooling/test_jina.py +++ b/tests/models/language/pooling/test_jina.py @@ -6,20 +6,22 @@ import pytest from vllm import PoolingParams -from ...utils import EmbedModelInfo, RerankModelInfo +from ...utils import (CLSPoolingEmbedModelInfo, CLSPoolingRerankModelInfo, + EmbedModelInfo, RerankModelInfo) from .embed_utils import (check_embeddings_close, correctness_test_embed_models, matryoshka_fy) from .mteb_utils import mteb_test_embed_models, mteb_test_rerank_models EMBEDDING_MODELS = [ - EmbedModelInfo("jinaai/jina-embeddings-v3", - architecture="XLMRobertaModel", - is_matryoshka=True) + CLSPoolingEmbedModelInfo("jinaai/jina-embeddings-v3", + architecture="XLMRobertaModel", + is_matryoshka=True) ] RERANK_MODELS = [ - RerankModelInfo("jinaai/jina-reranker-v2-base-multilingual", - architecture="XLMRobertaForSequenceClassification") + CLSPoolingRerankModelInfo( + "jinaai/jina-reranker-v2-base-multilingual", + architecture="XLMRobertaForSequenceClassification") ] diff --git a/tests/models/language/pooling/test_mxbai_rerank.py b/tests/models/language/pooling/test_mxbai_rerank.py index e74c58744d..480bd5e456 100644 --- a/tests/models/language/pooling/test_mxbai_rerank.py +++ b/tests/models/language/pooling/test_mxbai_rerank.py @@ -7,15 +7,16 @@ import torch from tests.conftest import HfRunner -from .mteb_utils import RerankModelInfo, mteb_test_rerank_models +from ...utils import LASTPoolingRerankModelInfo, RerankModelInfo +from .mteb_utils import mteb_test_rerank_models RERANK_MODELS = [ - RerankModelInfo("mixedbread-ai/mxbai-rerank-base-v2", - architecture="Qwen2ForSequenceClassification", - enable_test=True), - RerankModelInfo("mixedbread-ai/mxbai-rerank-large-v2", - architecture="Qwen2ForSequenceClassification", - enable_test=False) + LASTPoolingRerankModelInfo("mixedbread-ai/mxbai-rerank-base-v2", + architecture="Qwen2ForSequenceClassification", + enable_test=True), + LASTPoolingRerankModelInfo("mixedbread-ai/mxbai-rerank-large-v2", + architecture="Qwen2ForSequenceClassification", + enable_test=False) ] diff --git a/tests/models/language/pooling/test_nomic.py b/tests/models/language/pooling/test_nomic.py index e16ec239a3..2d05958e9b 100644 --- a/tests/models/language/pooling/test_nomic.py +++ b/tests/models/language/pooling/test_nomic.py @@ -3,22 +3,23 @@ import pytest -from .embed_utils import EmbedModelInfo, correctness_test_embed_models +from ...utils import CLSPoolingEmbedModelInfo, EmbedModelInfo +from .embed_utils import correctness_test_embed_models from .mteb_utils import mteb_test_embed_models MODELS = [ - EmbedModelInfo("nomic-ai/nomic-embed-text-v1", - architecture="NomicBertModel", - enable_test=True), - EmbedModelInfo("nomic-ai/nomic-embed-text-v1.5", - architecture="NomicBertModel", - enable_test=False), - EmbedModelInfo("nomic-ai/CodeRankEmbed", - architecture="NomicBertModel", - enable_test=False), - EmbedModelInfo("nomic-ai/nomic-embed-text-v2-moe", - architecture="NomicBertModel", - enable_test=True) + CLSPoolingEmbedModelInfo("nomic-ai/nomic-embed-text-v1", + architecture="NomicBertModel", + enable_test=True), + CLSPoolingEmbedModelInfo("nomic-ai/nomic-embed-text-v1.5", + architecture="NomicBertModel", + enable_test=False), + CLSPoolingEmbedModelInfo("nomic-ai/CodeRankEmbed", + architecture="NomicBertModel", + enable_test=False), + CLSPoolingEmbedModelInfo("nomic-ai/nomic-embed-text-v2-moe", + architecture="NomicBertModel", + enable_test=True) ] diff --git a/tests/models/language/pooling/test_qwen3_reranker.py 
b/tests/models/language/pooling/test_qwen3_reranker.py index 68e96f3270..37f5566a33 100644 --- a/tests/models/language/pooling/test_qwen3_reranker.py +++ b/tests/models/language/pooling/test_qwen3_reranker.py @@ -8,15 +8,16 @@ import torch from tests.conftest import HfRunner from tests.utils import multi_gpu_test -from .mteb_utils import RerankModelInfo, mteb_test_rerank_models +from ...utils import LASTPoolingRerankModelInfo, RerankModelInfo +from .mteb_utils import mteb_test_rerank_models RERANK_MODELS = [ - RerankModelInfo("Qwen/Qwen3-Reranker-0.6B", - architecture="Qwen3ForSequenceClassification", - enable_test=True), - RerankModelInfo("Qwen/Qwen3-Reranker-4B", - architecture="Qwen3ForSequenceClassification", - enable_test=False) + LASTPoolingRerankModelInfo("Qwen/Qwen3-Reranker-0.6B", + architecture="Qwen3ForSequenceClassification", + enable_test=True), + LASTPoolingRerankModelInfo("Qwen/Qwen3-Reranker-4B", + architecture="Qwen3ForSequenceClassification", + enable_test=False) ] diff --git a/tests/models/language/pooling/test_snowflake_arctic_embed.py b/tests/models/language/pooling/test_snowflake_arctic_embed.py index d6b5dbd083..585fa0e683 100644 --- a/tests/models/language/pooling/test_snowflake_arctic_embed.py +++ b/tests/models/language/pooling/test_snowflake_arctic_embed.py @@ -3,42 +3,43 @@ import pytest -from .embed_utils import EmbedModelInfo, correctness_test_embed_models +from ...utils import CLSPoolingEmbedModelInfo, EmbedModelInfo +from .embed_utils import correctness_test_embed_models from .mteb_utils import mteb_test_embed_models MODELS = [ - EmbedModelInfo("Snowflake/snowflake-arctic-embed-xs", - is_matryoshka=False, - architecture="BertModel", - enable_test=True), - EmbedModelInfo("Snowflake/snowflake-arctic-embed-s", - is_matryoshka=False, - architecture="BertModel", - enable_test=False), - EmbedModelInfo("Snowflake/snowflake-arctic-embed-m", - is_matryoshka=False, - architecture="BertModel", - enable_test=False), - EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-long", - is_matryoshka=False, - architecture="NomicBertModel", - enable_test=True), - EmbedModelInfo("Snowflake/snowflake-arctic-embed-l", - is_matryoshka=False, - architecture="BertModel", - enable_test=False), - EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v1.5", - is_matryoshka=True, - architecture="BertModel", - enable_test=True), - EmbedModelInfo("Snowflake/snowflake-arctic-embed-l-v2.0", - is_matryoshka=True, - architecture="XLMRobertaModel", - enable_test=True), - EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v2.0", - is_matryoshka=True, - architecture="GteModel", - enable_test=True), + CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-xs", + is_matryoshka=False, + architecture="BertModel", + enable_test=True), + CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-s", + is_matryoshka=False, + architecture="BertModel", + enable_test=False), + CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-m", + is_matryoshka=False, + architecture="BertModel", + enable_test=False), + CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-m-long", + is_matryoshka=False, + architecture="NomicBertModel", + enable_test=True), + CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-l", + is_matryoshka=False, + architecture="BertModel", + enable_test=False), + CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v1.5", + is_matryoshka=True, + architecture="BertModel", + enable_test=True), + CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-l-v2.0", + 
is_matryoshka=True, + architecture="XLMRobertaModel", + enable_test=True), + CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v2.0", + is_matryoshka=True, + architecture="GteModel", + enable_test=True), ] diff --git a/tests/models/utils.py b/tests/models/utils.py index 11ddf45c8e..84aeb927c5 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -345,16 +345,34 @@ class EmbedModelInfo(NamedTuple): matryoshka_dimensions: Optional[list[int]] = None architecture: str = "" dtype: str = "auto" + default_pooling_type: str = "" enable_test: bool = True +class CLSPoolingEmbedModelInfo(EmbedModelInfo): + default_pooling_type: str = "CLS" + + +class LASTPoolingEmbedModelInfo(EmbedModelInfo): + default_pooling_type: str = "LAST" + + class RerankModelInfo(NamedTuple): name: str architecture: str = "" dtype: str = "auto" + default_pooling_type: str = "" enable_test: bool = True +class CLSPoolingRerankModelInfo(RerankModelInfo): + default_pooling_type: str = "CLS" + + +class LASTPoolingRerankModelInfo(RerankModelInfo): + default_pooling_type: str = "LAST" + + def dummy_hf_overrides( hf_config: PretrainedConfig, *, diff --git a/tests/test_config.py b/tests/test_config.py index 19b1b74e42..957771a422 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -227,6 +227,20 @@ def test_get_pooling_config_from_args(): assert asdict(pooling_config) == asdict(override_pooler_config) +@pytest.mark.parametrize( + ("model_id", "default_pooling_type", "pooling_type"), + [ + ("tomaarsen/Qwen3-Reranker-0.6B-seq-cls", "LAST", "LAST"), # LLM + ("intfloat/e5-small", "CLS", "MEAN"), # BertModel + ("Qwen/Qwen2.5-Math-RM-72B", "ALL", "ALL"), # reward + ("Qwen/Qwen2.5-Math-PRM-7B", "STEP", "STEP") # step reward + ]) +def test_default_pooling_type(model_id, default_pooling_type, pooling_type): + model_config = ModelConfig(model_id) + assert model_config._model_info.default_pooling_type == default_pooling_type + assert model_config.pooler_config.pooling_type == pooling_type + + @pytest.mark.skipif(current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm.") def test_get_bert_tokenization_sentence_transformer_config(): diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index 700d29f956..03ab034c62 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -871,6 +871,10 @@ class ModelConfig: if getattr(pooler_config, k) is None: setattr(pooler_config, k, v) + default_pooling_type = self._model_info.default_pooling_type + if pooler_config.pooling_type is None: + pooler_config.pooling_type = default_pooling_type + return pooler_config return None @@ -3844,6 +3848,10 @@ class VllmConfig: disable_chunked_prefill_reasons.append( "Only \"last\" pooling supports chunked " "prefill and prefix caching; disabling both.") + elif not getattr(self.model_config.hf_config, "is_causal", True): + disable_chunked_prefill_reasons.append( + "Only models using causal attention support chunked " + "prefill and prefix caching; disabling both.") if disable_chunked_prefill_reasons: for reason in disable_chunked_prefill_reasons: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 4767201617..41a6da709b 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1600,11 +1600,10 @@ class EngineArgs: else: pooling_type = model_config.pooler_config.pooling_type - - # TODO: when encoder models are supported we'll have to - # check for causal attention here.
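This arg_utils.py hunk replaces the removed TODO with an actual causal-attention check (the added lines continue below). A condensed restatement of the predicate it installs, as a hypothetical standalone helper for illustration, not the actual vLLM function:

```python
from typing import Any, Optional


def incremental_prefill_supported(pooling_type: Optional[str],
                                  hf_config: Any) -> bool:
    """Sketch of the gate added in arg_utils.py: incremental prefill is
    only safe when pooling reads just the last token of a causal model,
    so earlier tokens' KV-cache entries stay valid as the prompt grows."""
    # HF configs that do not define `is_causal` are treated as causal.
    is_causal = getattr(hf_config, "is_causal", True)
    return (pooling_type is not None and pooling_type.lower() == "last"
            and is_causal)
```

This is the same condition `VllmConfig` reports on in the config/__init__.py hunk above: non-LAST pooling and non-causal attention each independently disable chunked prefill and prefix caching.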
- incremental_prefill_supported = (pooling_type is not None and - pooling_type.lower() == "last") + is_causal = getattr(model_config.hf_config, "is_causal", True) + incremental_prefill_supported = (pooling_type is not None + and pooling_type.lower() == "last" + and is_causal) action = "Enabling" if \ incremental_prefill_supported else "Disabling" diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 4014a961c6..915f14a29b 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1100,6 +1100,10 @@ class LLM: "Try passing `--runner pooling` to use the model as a " "pooling model.") + if pooling_task not in self.supported_tasks: + raise ValueError( + f"pooling_task must be one of {self.supported_tasks}.") + if prompt_token_ids is not None: parsed_prompts = self._convert_v1_inputs( prompts=cast(Optional[Union[str, list[str]]], prompts), diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index 0f2e58eb9b..e2162e5cbf 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -44,15 +44,14 @@ class ResolvedPoolingConfig: task: PoolingTask @classmethod - def from_config_with_defaults( + def from_config( cls, task: PoolingTask, pooler_config: PoolerConfig, - pooling_type: PoolingType, ) -> "ResolvedPoolingConfig": + assert pooler_config.pooling_type is not None return cls(task=task, - pooling_type=PoolingType[pooler_config.pooling_type] - if pooler_config.pooling_type is not None else pooling_type) + pooling_type=PoolingType[pooler_config.pooling_type]) @dataclass(frozen=True) @@ -68,32 +67,20 @@ class Pooler(nn.Module, ABC): """The interface required for all poolers used in pooling models in vLLM.""" @staticmethod - def for_encode( - pooler_config: PoolerConfig, - *, - default_pooling_type: PoolingType = PoolingType.ALL, - ): - resolved_config = ResolvedPoolingConfig.from_config_with_defaults( - task="encode", - pooler_config=pooler_config, - pooling_type=default_pooling_type, - ) - - if resolved_config.pooling_type == PoolingType.STEP: + def for_encode(pooler_config: PoolerConfig): + if pooler_config.pooling_type == "STEP": return StepPooler() + resolved_config = ResolvedPoolingConfig(task="encode", + pooling_type=PoolingType.ALL) + return SimplePooler.from_config(resolved_config) @staticmethod - def for_embed( - pooler_config: PoolerConfig, - *, - default_pooling_type: PoolingType = PoolingType.LAST, - ): - resolved_config = ResolvedPoolingConfig.from_config_with_defaults( + def for_embed(pooler_config: PoolerConfig): + resolved_config = ResolvedPoolingConfig.from_config( task="embed", pooler_config=pooler_config, - pooling_type=default_pooling_type, ) return SimplePooler.from_config(resolved_config) @@ -102,13 +89,10 @@ class Pooler(nn.Module, ABC): def for_classify( pooler_config: PoolerConfig, classifier: Optional[ClassifierFn], - *, - default_pooling_type: PoolingType = PoolingType.LAST, ): - resolved_config = ResolvedPoolingConfig.from_config_with_defaults( + resolved_config = ResolvedPoolingConfig.from_config( task="classify", pooler_config=pooler_config, - pooling_type=default_pooling_type, ) pooling = PoolingMethod.from_pooling_type(resolved_config.pooling_type) diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py index 867de2c68b..1dbe70f84a 100644 --- a/vllm/model_executor/models/adapters.py +++ b/vllm/model_executor/models/adapters.py @@ -182,8 +182,8 @@ def as_seq_cls_model(cls: _T) -> _T: assert pooler_config is not None pooling_type_str = 
pooler_config.pooling_type - pooling_type = (PoolingType.LAST if pooling_type_str is None else - PoolingType[pooling_type_str]) + assert pooling_type_str is not None + pooling_type = PoolingType[pooling_type_str] self.pooler = DispatchPooler({ "encode": diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index 3d5d5d505b..6638f06f98 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -28,7 +28,8 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.sequence import IntermediateTensors from vllm.tasks import PoolingTask -from .interfaces import SupportsCrossEncoding, SupportsQuant +from .interfaces import (SupportsCrossEncoding, SupportsQuant, + default_pooling_type) from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix @@ -327,6 +328,7 @@ class BertOutput(nn.Module): @support_torch_compile +@default_pooling_type("CLS") class BertModel(nn.Module, SupportsQuant): is_pooling_model = True @@ -401,6 +403,7 @@ class BertModel(nn.Module, SupportsQuant): return loaded_params +@default_pooling_type("ALL") class BertPoolingModel(BertModel): is_pooling_model = True @@ -431,6 +434,7 @@ class BertPoolingModel(BertModel): return loaded_params +@default_pooling_type("CLS") class BertEmbeddingModel(nn.Module, SupportsQuant): """A model that uses Bert to provide embedding functionalities. @@ -486,13 +490,8 @@ class BertEmbeddingModel(nn.Module, SupportsQuant): def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler: return DispatchPooler({ - "encode": - Pooler.for_encode(pooler_config), - "embed": - Pooler.for_embed( - pooler_config, - default_pooling_type=PoolingType.CLS, - ), + "encode": Pooler.for_encode(pooler_config), + "embed": Pooler.for_embed(pooler_config), }) @@ -541,6 +540,7 @@ def _decode_token_type_ids(input_ids: torch.Tensor) -> torch.Tensor: return token_type_ids +@default_pooling_type("CLS") class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQuant): """A model that uses Bert to provide embedding functionalities. 
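The `@default_pooling_type("CLS")` markers applied to the BERT classes above are plain class decorators; the helpers themselves are added in the interfaces.py hunk further down. A self-contained sketch of the pattern, using a hypothetical toy class:

```python
def default_pooling_type(pooling_type: str):
    """Class decorator recording a model's default pooling type."""

    def func(model: type) -> type:
        model.default_pooling_type = pooling_type
        return model

    return func


def get_default_pooling_type(model) -> str:
    # Classes that never opted in fall back to LAST pooling.
    return getattr(model, "default_pooling_type", "LAST")


@default_pooling_type("CLS")
class ToyEncoderModel:  # hypothetical stand-in for BertModel
    pass


assert get_default_pooling_type(ToyEncoderModel) == "CLS"
assert get_default_pooling_type(object) == "LAST"
```

Because the attribute lives on the class itself, the registry can populate `_ModelInfo` from it without instantiating the model.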
diff --git a/vllm/model_executor/models/bert_with_rope.py b/vllm/model_executor/models/bert_with_rope.py index 050f18f16e..e18b7b7ffa 100644 --- a/vllm/model_executor/models/bert_with_rope.py +++ b/vllm/model_executor/models/bert_with_rope.py @@ -27,7 +27,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.interfaces import SupportsQuant +from vllm.model_executor.models.interfaces import (SupportsQuant, + default_pooling_type) from vllm.model_executor.models.utils import WeightsMapper from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform @@ -401,6 +402,7 @@ class BertWithRopeEncoder(nn.Module): @support_torch_compile +@default_pooling_type("CLS") class BertWithRope(nn.Module, SupportsQuant): hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index b6d9877cd0..46caf3fce4 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -641,6 +641,20 @@ def supports_cross_encoding( return is_pooling_model(model) and _supports_cross_encoding(model) +def default_pooling_type(pooling_type: str) -> object: + """Class decorator that sets a pooling model's default pooling type.""" + + def func(model: object): + model.default_pooling_type = pooling_type + return model + + return func + + +def get_default_pooling_type(model: Union[type[object], object]) -> str: + return getattr(model, "default_pooling_type", "LAST") + + class SupportsQuant: """The interface required for all models that support quantization.""" diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index d29779a35e..d0c4bf5450 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -31,7 +31,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA, SupportsPP +from .interfaces import SupportsLoRA, SupportsPP, default_pooling_type from .utils import (is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -401,6 +401,7 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): return loaded_params +@default_pooling_type("ALL") class InternLM2ForRewardModel(InternLM2ForCausalLM): is_pooling_model = True diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index c1033aff07..fbd310121a 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -22,8 +22,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer from vllm.model_executor.layers.mamba.mamba_utils import ( MambaStateShapeCalculator) -from vllm.model_executor.layers.pooler import (DispatchPooler, Pooler, - PoolingType) +from vllm.model_executor.layers.pooler import DispatchPooler, Pooler from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) @@ -604,6 +603,5 @@ class
JambaForSequenceClassification(JambaForCausalLM): Pooler.for_classify( pooler_config, classifier=self.score, - default_pooling_type=PoolingType.LAST, ), }) diff --git a/vllm/model_executor/models/modernbert.py b/vllm/model_executor/models/modernbert.py index 761fce815e..2c3bdd1c93 100644 --- a/vllm/model_executor/models/modernbert.py +++ b/vllm/model_executor/models/modernbert.py @@ -26,7 +26,8 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.sequence import IntermediateTensors from vllm.tasks import PoolingTask -from .interfaces import SupportsCrossEncoding, SupportsV0Only +from .interfaces import (SupportsCrossEncoding, SupportsV0Only, + default_pooling_type) from .utils import WeightsMapper, maybe_prefix @@ -201,6 +202,7 @@ class ModernBertEncoderLayer(nn.Module): @support_torch_compile +@default_pooling_type("CLS") class ModernBertModel(nn.Module): hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={"layers.": "encoder_layer.layers."}) @@ -264,7 +266,6 @@ class ModernBertPooler(Pooler): self.pooling = PoolingMethod.from_pooling_type(pooling_type) self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.classifier_bias) - self.pooling_type = config.classifier_pooling self.act = nn.GELU() self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, @@ -294,6 +295,7 @@ class ModernBertPooler(Pooler): return pooled_output +@default_pooling_type("CLS") class ModernBertForSequenceClassification(nn.Module, SupportsV0Only, SupportsCrossEncoding): diff --git a/vllm/model_executor/models/qwen2_rm.py b/vllm/model_executor/models/qwen2_rm.py index 9b6b70c75c..e0a30e04c6 100644 --- a/vllm/model_executor/models/qwen2_rm.py +++ b/vllm/model_executor/models/qwen2_rm.py @@ -15,11 +15,10 @@ from torch import nn from vllm.config import VllmConfig from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) -from vllm.model_executor.layers.pooler import (DispatchPooler, Pooler, - PoolingType) +from vllm.model_executor.layers.pooler import DispatchPooler, Pooler from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA, SupportsPP +from .interfaces import SupportsLoRA, SupportsPP, default_pooling_type from .qwen2 import Qwen2Model from .utils import AutoWeightsLoader, maybe_prefix @@ -90,6 +89,7 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP): return loader.load_weights(weights) +@default_pooling_type("ALL") class Qwen2ForRewardModel(Qwen2RewardBaseModel): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -103,6 +103,7 @@ class Qwen2ForRewardModel(Qwen2RewardBaseModel): {"encode": Pooler.for_encode(pooler_config)}, ) +@default_pooling_type("STEP") class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -112,10 +113,5 @@ class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel): pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None - self.pooler = DispatchPooler({ - "encode": - Pooler.for_encode( - pooler_config, - default_pooling_type=PoolingType.STEP, - ) - }) + self.pooler = DispatchPooler( + {"encode": Pooler.for_encode(pooler_config)}) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index aca3d84f00..1b0c902c5e 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -25,8 +25,8 @@ from vllm.logger import init_logger from vllm.transformers_utils.dynamic_module import ( 
try_get_class_from_dynamic_module) -from .interfaces import (has_inner_state, has_noops, is_attention_free, - is_hybrid, supports_cross_encoding, +from .interfaces import (get_default_pooling_type, has_inner_state, has_noops, + is_attention_free, is_hybrid, supports_cross_encoding, supports_multimodal, supports_multimodal_raw_input, supports_pp, supports_transcription, supports_v0_only) from .interfaces_base import is_pooling_model, is_text_generation_model @@ -305,6 +305,7 @@ class _ModelInfo: architecture: str is_text_generation_model: bool is_pooling_model: bool + default_pooling_type: str supports_cross_encoding: bool supports_multimodal: bool supports_multimodal_raw_input: bool @@ -323,6 +324,7 @@ architecture=model.__name__, is_text_generation_model=is_text_generation_model(model), is_pooling_model=is_pooling_model(model), + default_pooling_type=get_default_pooling_type(model), supports_cross_encoding=supports_cross_encoding(model), supports_multimodal=supports_multimodal(model), supports_multimodal_raw_input=supports_multimodal_raw_input(model), diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index 005b917982..32a4a2c9a2 100644 --- a/vllm/model_executor/models/roberta.py +++ b/vllm/model_executor/models/roberta.py @@ -23,7 +23,7 @@ from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper, from vllm.sequence import IntermediateTensors from .bert_with_rope import BertWithRope, JinaRobertaModel -from .interfaces import SupportsCrossEncoding +from .interfaces import SupportsCrossEncoding, default_pooling_type class RobertaEmbedding(nn.Module): @@ -86,6 +86,7 @@ class RobertaClassificationHead(nn.Module): return x +@default_pooling_type("CLS") class RobertaEmbeddingModel(BertEmbeddingModel): """A model that uses Roberta to provide embedding functionalities. @@ -149,6 +150,7 @@ class RobertaEmbeddingModel(BertEmbeddingModel): return loader.load_weights(weights_list, mapper=mapper) +@default_pooling_type("CLS") class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding): """A model that uses Roberta to provide embedding functionalities. diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 3cde7c6e96..045a06d927 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1272,7 +1272,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if not is_pooling_model(model): return [] - return list(model.pooler.get_supported_tasks()) + supported_tasks = list(model.pooler.get_supported_tasks()) + + if (self.scheduler_config.chunked_prefill_enabled + and "encode" in supported_tasks): + supported_tasks.remove("encode") + + logger.info_once("Chunked prefill is not supported with " + "the encode task, which uses ALL pooling. " + "Please turn off chunked prefill via " + "`--no-enable-chunked-prefill` before using it.") + + return supported_tasks def get_supported_tasks(self) -> tuple[SupportedTask, ...]: tasks = list[SupportedTask]()
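Taken together: each pooling model class now declares its default pooling type, the registry records it in `_ModelInfo`, `ModelConfig` falls back to it when no pooling type is configured, and the V1 model runner drops the `encode` task when chunked prefill is on. A usage sketch of the resolution chain, mirroring `test_default_pooling_type` above (expected values come from that test; assumes the Hugging Face config for `intfloat/e5-small` can be fetched):

```python
from vllm.config import ModelConfig

config = ModelConfig("intfloat/e5-small")

# The registry captured BertModel's class-level default ("CLS")...
assert config._model_info.default_pooling_type == "CLS"

# ...but it is only a fallback: e5-small ships a sentence-transformers
# pooling config, so the resolved pooling type is still MEAN.
assert config.pooler_config.pooling_type == "MEAN"
```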