[V0 Deprecation] Remove pooling model support in V0 (#23434)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
commit 2554b27baa (parent 934bebf192), committed by GitHub
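With this change, pooling models (embedding, reward, and cross-encoder runners) are served only by the V1 engine; the V0 entry points in the hunks below now raise `NotImplementedError("Pooling models are not supported in vLLM V0")`. A rough usage sketch of the remaining path (illustrative only; the model name and prompt are taken from the tests touched in this diff, and output field names may differ between vLLM versions):

```python
# Illustrative sketch, not part of this diff: pooling/embedding requests are
# handled by the V1 engine only after this change.
from vllm import LLM

llm = LLM(model="intfloat/e5-mistral-7b-instruct")  # embedding model used in the PP tests
outputs = llm.embed(["The chef prepared a delicious meal."])  # prompt reused from the pooling tests
print(len(outputs[0].outputs.embedding))  # embedding dimensionality
```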
@ -118,6 +118,8 @@ class PPTestSettings:
multi_node_only: bool = False,
load_format: Optional[str] = None,
):
vllm_major_versions = ["1"] if runner == "pooling" else ["0"]

return PPTestSettings(
parallel_setups=[
ParallelSetup(tp_size=tp_base,

@ -126,7 +128,7 @@ class PPTestSettings:
chunked_prefill=False),
],
distributed_backends=["mp"],
vllm_major_versions=["0"],
vllm_major_versions=vllm_major_versions,
runner=runner,
test_options=PPTestOptions(multi_node_only=multi_node_only,
load_format=load_format),

@ -213,7 +215,9 @@ TEXT_GENERATION_MODELS = {
EMBEDDING_MODELS = { # type: ignore[var-annotated]
# [Text-only]
"intfloat/e5-mistral-7b-instruct": PPTestSettings.fast(runner="pooling"),
"BAAI/bge-multilingual-gemma2": PPTestSettings.fast(runner="pooling"),
# TODO: re-enable when https://github.com/vllm-project/vllm/issues/23883
# is fixed
#"BAAI/bge-multilingual-gemma2": PPTestSettings.fast(runner="pooling"),
"Qwen/Qwen2.5-Math-RM-72B": PPTestSettings.fast(
load_format="dummy", runner="pooling"
),
@ -16,14 +16,6 @@ MODEL_NAME = "jason9693/Qwen2.5-1.5B-apeach"
prompts = ["The chef prepared a delicious meal."]


@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


@pytest.fixture(scope="module")
def llm():
# pytest caches the fixture so we use weakref.proxy to
@ -27,14 +27,6 @@ TOKEN_IDS = [
]


@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


@pytest.fixture(scope="module")
def llm():
# pytest caches the fixture so we use weakref.proxy to
@ -16,14 +16,6 @@ MODEL_NAME = "internlm/internlm2-1_8b-reward"
prompts = ["The chef prepared a delicious meal."]


@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


@pytest.fixture(scope="module")
def llm():
# pytest caches the fixture so we use weakref.proxy to
@ -14,14 +14,6 @@ from ...models.utils import softmax
MODEL_NAME = "tomaarsen/Qwen3-Reranker-0.6B-seq-cls"


@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


@pytest.fixture(scope="module")
def llm():
# pytest caches the fixture so we use weakref.proxy to
@ -32,15 +32,16 @@ MODEL_CONFIGS = [
"tensor_parallel_size": 1,
"tokenizer_mode": "mistral",
},
{
"model": "sentence-transformers/all-MiniLM-L12-v2",
"enforce_eager": True,
"gpu_memory_utilization": 0.20,
"max_model_len": 64,
"max_num_batched_tokens": 64,
"max_num_seqs": 64,
"tensor_parallel_size": 1,
},
# TODO: re-enable once these tests are run with V1
# {
# "model": "sentence-transformers/all-MiniLM-L12-v2",
# "enforce_eager": True,
# "gpu_memory_utilization": 0.20,
# "max_model_len": 64,
# "max_num_batched_tokens": 64,
# "max_num_seqs": 64,
# "tensor_parallel_size": 1,
# },
]
@ -24,14 +24,6 @@ DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' +
DTYPE = "bfloat16"


@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


@pytest.fixture(scope="module")
def server():
args = [
@ -14,14 +14,6 @@ MODEL_NAME = "BAAI/bge-reranker-base"
DTYPE = "bfloat16"


@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


@pytest.fixture(scope="module")
def server():
args = ["--enforce-eager", "--max-model-len", "100", "--dtype", DTYPE]
@ -12,15 +12,6 @@ from vllm.entrypoints.openai.protocol import ScoreResponse

from ...utils import RemoteOpenAIServer


@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


MODELS = [
{
"name": "BAAI/bge-reranker-v2-m3",
@ -10,14 +10,6 @@ from vllm.platforms import current_platform
from ...utils import check_embeddings_close, check_transformers_version


@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


@pytest.mark.parametrize(
"model",
[

@ -32,21 +24,15 @@ def v1(run_with_both_engines):
"intfloat/e5-mistral-7b-instruct",
# CPU v1 doesn't support sliding window
marks=[pytest.mark.core_model]),
# the qwen models interfere with each other (see PR
# https://github.com/vllm-project/vllm/pull/18720).
# To avoid this problem, for now we skip v0 since it will be
# deprecated anyway.
pytest.param("ssmits/Qwen2-7B-Instruct-embed-base",
marks=[pytest.mark.skip_v0, pytest.mark.cpu_model]),
marks=[pytest.mark.cpu_model]),
# [Encoder-only]
pytest.param("BAAI/bge-base-en-v1.5", marks=[pytest.mark.core_model]),
pytest.param("sentence-transformers/all-MiniLM-L12-v2"),
pytest.param("intfloat/multilingual-e5-small"),
pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct",
marks=[pytest.mark.skip_v1]),
pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct"),
# [Cross-Encoder]
pytest.param("sentence-transformers/stsb-roberta-base-v2",
marks=[pytest.mark.skip_v1]),
pytest.param("sentence-transformers/stsb-roberta-base-v2"),
],
)
def test_models(
@ -13,14 +13,6 @@ from ....conftest import HfRunner
from ...utils import check_transformers_version


@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


@pytest.fixture
def math_step_prompts():
# ruff: noqa: E501
@ -23,15 +23,6 @@ TEXTS_2 = [
"The capital of Germany is Berlin.",
]


@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


DTYPE = "half"
@ -323,8 +323,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {

_EMBEDDING_EXAMPLE_MODELS = {
# [Text-only]
"BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5", v0_only=True),
"Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2", v0_only=True), # noqa: E501
"BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"),
"Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"), # noqa: E501
"GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"),
"GteModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-v2.0",
trust_remote_code=True),

@ -337,9 +337,9 @@ _EMBEDDING_EXAMPLE_MODELS = {
"LlamaModel": _HfExamplesInfo("llama", is_available_online=False),
"MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"),
"ModernBertModel": _HfExamplesInfo("Alibaba-NLP/gte-modernbert-base",
trust_remote_code=True, v0_only=True),
trust_remote_code=True),
"NomicBertModel": _HfExamplesInfo("nomic-ai/nomic-embed-text-v2-moe",
trust_remote_code=True, v0_only=True), # noqa: E501
trust_remote_code=True), # noqa: E501
"Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"),
"Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B",
max_transformers_version="4.53",

@ -347,9 +347,9 @@ _EMBEDDING_EXAMPLE_MODELS = {
"Qwen2ForProcessRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-PRM-7B",
max_transformers_version="4.53",
transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers"), # noqa: E501
"RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2", v0_only=True), # noqa: E501
"RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1", v0_only=True), # noqa: E501
"XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small", v0_only=True), # noqa: E501
"RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2"), # noqa: E501
"RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1"), # noqa: E501
"XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small"), # noqa: E501
# [Multimodal]
"LlavaNextForConditionalGeneration": _HfExamplesInfo("royokong/e5-v"),
"Phi3VForCausalLM": _HfExamplesInfo("TIGER-Lab/VLM2Vec-Full",

@ -364,20 +364,19 @@ _SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = {
"GPT2ForSequenceClassification": _HfExamplesInfo("nie3e/sentiment-polish-gpt2-small"), # noqa: E501

# [Cross-encoder]
"BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2", v0_only=True), # noqa: E501
"BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2"), # noqa: E501
"GteNewForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-multilingual-reranker-base", # noqa: E501
trust_remote_code=True,
hf_overrides={
"architectures": ["GteNewForSequenceClassification"]}),# noqa: E501
"ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base", v0_only=True), # noqa: E501
"RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base", v0_only=True), # noqa: E501
"XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3", v0_only=True), # noqa: E501
"ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base"), # noqa: E501
"RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base"), # noqa: E501
"XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3"), # noqa: E501
}

_AUTOMATIC_CONVERTED_MODELS = {
# Use as_seq_cls_model for automatic conversion
"GemmaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-gemma", # noqa: E501
v0_only=True,
hf_overrides={"architectures": ["GemmaForSequenceClassification"], # noqa: E501
"classifier_from_token": ["Yes"], # noqa: E501
"method": "no_post_processing"}), # noqa: E501
@ -9,10 +9,7 @@ from vllm.attention import AttentionMetadata, AttentionMetadataBuilder
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.utils import CommonAttentionState
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
from vllm.worker.pooling_model_runner import (
ModelInputForGPUWithPoolingMetadata)


class MockAttentionBackend(AttentionBackend):

@ -114,54 +111,3 @@ def test_model_runner_input():
assert (received_model_input.sampling_metadata.selected_token_indices ==
sampling_metadata.selected_token_indices)
assert received_model_input.sampling_metadata.seq_groups is None


def test_embedding_model_runner_input():
pooling_metadata = PoolingMetadata(
seq_groups=[[0]],
seq_data={},
prompt_lens=[1],
)
attn_metadata = AttentionMetadata(
num_prefills=1,
num_prefill_tokens=2,
num_decode_tokens=3,
slot_mapping=torch.zeros(1),
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
)
model_input = ModelInputForGPUWithPoolingMetadata(
input_tokens=torch.ones(10),
input_positions=torch.ones(10),
pooling_metadata=pooling_metadata,
attn_metadata=attn_metadata)

assert isinstance(model_input, ModelInputForGPUWithPoolingMetadata)

# Test round trip serialization.
tensor_dict = model_input.as_broadcastable_tensor_dict()
attn_backend = MockAttentionBackend()
received_model_input = (
ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict(
tensor_dict, attn_backend=attn_backend))
# Check that received copy has correct values.
assert isinstance(received_model_input,
ModelInputForGPUWithPoolingMetadata)
assert received_model_input.input_tokens is not None
assert (
received_model_input.input_tokens == model_input.input_tokens).all()
assert received_model_input.input_positions is not None
assert (received_model_input.input_positions == model_input.input_positions
).all()
assert received_model_input.multi_modal_kwargs is None
assert (received_model_input.multi_modal_kwargs ==
model_input.multi_modal_kwargs)
assert received_model_input.lora_requests is None
assert received_model_input.lora_requests == model_input.lora_requests
assert received_model_input.lora_mapping is None
assert received_model_input.lora_mapping == model_input.lora_mapping
for field in dataclasses.fields(AttentionMetadata):
assert getattr(received_model_input.attn_metadata, field.name,
None) == getattr(attn_metadata, field.name, None)
# Pooling metadata is not broadcast.
assert received_model_input.pooling_metadata is None
@ -1591,7 +1591,6 @@ class Scheduler:
encoder_seq_data=encoder_seq_data,
cross_block_table=cross_block_table,
state=seq_group.state,
token_type_ids=seq_group.token_type_ids,
# `multi_modal_data` will only be present for the 1st comm
# between engine and worker.
# the subsequent comms can still use delta, but
@ -1566,8 +1566,7 @@ class EngineArgs:
use_spec_decode = self.speculative_config is not None

if (is_gpu and not use_sliding_window and not use_spec_decode
and not self.enable_lora
and model_config.runner_type != "pooling"):
and not self.enable_lora):
self.enable_chunked_prefill = True
logger.warning(
"Chunked prefill is enabled by default for models "

@ -1585,10 +1584,6 @@ class EngineArgs:
"OOM during the initial memory profiling phase, or result "
"in low performance due to small KV cache size. Consider "
"setting --max-model-len to a smaller value.", max_model_len)
elif (self.enable_chunked_prefill
and model_config.runner_type == "pooling"):
msg = "Chunked prefill is not supported for pooling models"
raise ValueError(msg)

# if using prefix caching, we must set a hash algo
if self.enable_prefix_caching:
@ -72,8 +72,8 @@ STOP_ITERATION = Exception() # Sentinel


class AsyncStream:
"""A stream of RequestOutputs or PoolingRequestOutputs for a request
that can be iterated over asynchronously via an async generator."""
"""A stream of RequestOutputs for a request that can be iterated over
asynchronously via an async generator."""

def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
self.request_id = request_id

@ -81,8 +81,7 @@ class AsyncStream:
self._queue: asyncio.Queue = asyncio.Queue()
self._finished = False

def put(self, item: Union[RequestOutput, PoolingRequestOutput,
Exception]) -> None:
def put(self, item: Union[RequestOutput, Exception]) -> None:
if not self._finished:
self._queue.put_nowait(item)

@ -99,9 +98,7 @@ class AsyncStream:
def finished(self) -> bool:
return self._finished

async def generator(
self
) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
async def generator(self) -> AsyncGenerator[RequestOutput, None]:
try:
while True:
result = await self._queue.get()

@ -151,8 +148,7 @@ class RequestTracker:
self.abort_request(rid, exception=exc)

def process_request_output(self,
request_output: Union[RequestOutput,
PoolingRequestOutput],
request_output: RequestOutput,
*,
verbose: bool = False) -> None:
"""Process a request output from the engine."""

@ -261,9 +257,7 @@ class _AsyncLLMEngine(LLMEngine):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

async def step_async(
self, virtual_engine: int
) -> List[Union[RequestOutput, PoolingRequestOutput]]:
async def step_async(self, virtual_engine: int) -> List[RequestOutput]:
"""Performs one decoding iteration and returns newly generated results.
The workers are ran asynchronously if possible.

@ -405,7 +399,7 @@ class _AsyncLLMEngine(LLMEngine):
self,
request_id: str,
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
params: SamplingParams,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,

@ -779,14 +773,14 @@ class AsyncLLMEngine(EngineClient):
self,
request_id: str,
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
params: SamplingParams,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
) -> AsyncGenerator[RequestOutput, None]:
if not self.is_running:
if self.start_engine_loop:
self.start_background_loop()

@ -908,7 +902,7 @@ class AsyncLLMEngine(EngineClient):
await self.abort(request_id)
raise

async def encode(
def encode(
self,
prompt: PromptType,
pooling_params: PoolingParams,
@ -918,85 +912,8 @@ class AsyncLLMEngine(EngineClient):
priority: int = 0,
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> AsyncGenerator[PoolingRequestOutput, None]:
"""Generate outputs for a request from a pooling model.

Generate outputs for a request. This method is a coroutine. It adds the
request into the waiting queue of the LLMEngine and streams the outputs
from the LLMEngine to the caller.

Args:
prompt: The prompt to the LLM. See
[`PromptType`][vllm.inputs.PromptType] for more details about
the format of each input.
pooling_params: The pooling parameters of the request.
request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any.
trace_headers: OpenTelemetry trace headers.
priority: The priority of the request.
Only applicable with priority scheduling.

Yields:
The output `PoolingRequestOutput` objects from the LLMEngine
for the request.

Details:
- If the engine is not running, start the background loop,
which iteratively invokes
[`vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`][]
to process the waiting requests.
- Add the request to the engine's `RequestTracker`.
On the next background loop, this request will be sent to
the underlying engine.
Also, a corresponding `AsyncStream` will be created.
- Wait for the request outputs from `AsyncStream` and yield them.

Example:
```
# Please refer to entrypoints/api_server.py for
# the complete example.

# initialize the engine and the example input
# note that engine_args here is AsyncEngineArgs instance
engine = AsyncLLMEngine.from_engine_args(engine_args)
example_input = {
"input": "What is LLM?",
"request_id": 0,
}

# start the generation
results_generator = engine.encode(
example_input["input"],
PoolingParams(),
example_input["request_id"])

# get the results
final_output = None
async for request_output in results_generator:
if await request.is_disconnected():
# Abort the request if the client disconnects.
await engine.abort(request_id)
# Return or raise an error
...
final_output = request_output

# Process and return the final output
...
```
"""
try:
async for output in await self.add_request(
request_id,
prompt,
pooling_params,
lora_request=lora_request,
trace_headers=trace_headers,
priority=priority,
tokenization_kwargs=tokenization_kwargs,
):
yield LLMEngine.validate_output(output, PoolingRequestOutput)
except asyncio.CancelledError:
await self.abort(request_id)
raise
raise NotImplementedError(
"Pooling models are not supported in vLLM V0")

async def abort(self, request_id: Union[str, Iterable[str]]) -> None:
"""Abort a request.

@ -1104,8 +1021,8 @@ class AsyncLLMEngine(EngineClient):
async def is_sleeping(self) -> bool:
return self.engine.is_sleeping()

async def add_lora(self, lora_request: LoRARequest) -> None:
self.engine.add_lora(lora_request)
async def add_lora(self, lora_request: LoRARequest) -> bool:
return self.engine.add_lora(lora_request)

async def collective_rpc(self,
method: str,
@ -40,12 +40,11 @@ from vllm.multimodal.cache import processor_only_cache_from_config
from vllm.multimodal.processing import EncDecMultiModalProcessor
from vllm.outputs import (PoolingRequestOutput, RequestOutput,
RequestOutputFactory)
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup,
PoolingSequenceGroupOutput, Sequence, SequenceGroup,
SequenceGroupBase, SequenceGroupMetadata,
SequenceGroupOutput, SequenceStatus)
Sequence, SequenceGroup, SequenceGroupBase,
SequenceGroupMetadata, SequenceGroupOutput,
SequenceStatus)
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
init_tracer)
from vllm.transformers_utils.detokenizer import Detokenizer

@ -93,8 +92,7 @@ class SchedulerContext:

def __init__(self) -> None:
self.output_queue: Deque[OutputData] = deque()
self.request_outputs: List[Union[RequestOutput,
PoolingRequestOutput]] = []
self.request_outputs: List[RequestOutput] = []
self.seq_group_metadata_list: Optional[
List[SequenceGroupMetadata]] = None
self.scheduler_outputs: Optional[SchedulerOutputs] = None

@ -261,8 +259,7 @@ class LLMEngine:

self.model_executor = executor_class(vllm_config=vllm_config)

if self.model_config.runner_type != "pooling":
self._initialize_kv_caches()
self._initialize_kv_caches()

# If usage stat is enabled, collect relevant info.
if is_usage_stats_enabled():

@ -541,7 +538,7 @@ class LLMEngine:
self,
request_id: str,
processed_inputs: ProcessorInputs,
params: Union[SamplingParams, PoolingParams],
params: SamplingParams,
arrival_time: float,
lora_request: Optional[LoRARequest],
trace_headers: Optional[Mapping[str, str]] = None,

@ -577,7 +574,7 @@ class LLMEngine:
encoder_seq = (None if encoder_inputs is None else Sequence(
seq_id, encoder_inputs, block_size, eos_token_id, lora_request))

# Create a SequenceGroup based on SamplingParams or PoolingParams
# Create a SequenceGroup based on SamplingParams
if isinstance(params, SamplingParams):
seq_group = self._create_sequence_group_with_sampling(
request_id,

@ -588,18 +585,8 @@ class LLMEngine:
trace_headers=trace_headers,
encoder_seq=encoder_seq,
priority=priority)
elif isinstance(params, PoolingParams):
seq_group = self._create_sequence_group_with_pooling(
request_id,
seq,
params,
arrival_time=arrival_time,
lora_request=lora_request,
encoder_seq=encoder_seq,
priority=priority)
else:
raise ValueError(
"Either SamplingParams or PoolingParams must be provided.")
raise ValueError("SamplingParams must be provided.")

# Add the sequence group to the scheduler with least unfinished seqs.
costs = [

@ -618,7 +605,7 @@ class LLMEngine:
self,
request_id: str,
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
params: SamplingParams,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None,

@ -636,9 +623,8 @@ class LLMEngine:
prompt: The prompt to the LLM. See
[PromptType][vllm.inputs.PromptType]
for more details about the format of each input.
params: Parameters for sampling or pooling.
params: Parameters for sampling.
[SamplingParams][vllm.SamplingParams] for text generation.
[PoolingParams][vllm.PoolingParams] for pooling.
arrival_time: The arrival time of the request. If None, we use
the current monotonic time.
lora_request: The LoRA request to add.
@ -760,29 +746,6 @@ class LLMEngine:

return seq_group

def _create_sequence_group_with_pooling(
self,
request_id: str,
seq: Sequence,
pooling_params: PoolingParams,
arrival_time: float,
lora_request: Optional[LoRARequest],
encoder_seq: Optional[Sequence] = None,
priority: int = 0,
) -> SequenceGroup:
"""Creates a SequenceGroup with PoolingParams."""
# Defensive copy of PoolingParams, which are used by the pooler
pooling_params = pooling_params.clone()
# Create the sequence group.
seq_group = SequenceGroup(request_id=request_id,
seqs=[seq],
arrival_time=arrival_time,
lora_request=lora_request,
pooling_params=pooling_params,
encoder_seq=encoder_seq,
priority=priority)
return seq_group

def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
"""Aborts a request(s) with the given ID.

@ -856,18 +819,6 @@ class LLMEngine:
success = success and scheduler.reset_prefix_cache(device)
return success

@staticmethod
def _process_sequence_group_outputs(
seq_group: SequenceGroup,
outputs: List[PoolingSequenceGroupOutput],
) -> None:
seq_group.pooled_data = outputs[0].data

for seq in seq_group.get_seqs():
seq.status = SequenceStatus.FINISHED_STOPPED

return

def _process_model_outputs(self,
ctx: SchedulerContext,
request_id: Optional[str] = None) -> None:

@ -962,13 +913,10 @@ class LLMEngine:
seq_group.metrics.model_execute_time = (
o.model_execute_time)

if self.model_config.runner_type == "pooling":
self._process_sequence_group_outputs(seq_group, output)
else:
self.output_processor.process_prompt_logprob(seq_group, output)
if seq_group_meta.do_sample:
self.output_processor.process_outputs(
seq_group, output, is_async)
self.output_processor.process_prompt_logprob(seq_group, output)
if seq_group_meta.do_sample:
self.output_processor.process_outputs(seq_group, output,
is_async)

if seq_group.is_finished():
finished_now.append(i)

@ -1090,7 +1038,7 @@ class LLMEngine:
seq.append_token_id(sample.output_token, sample.logprobs,
sample.output_embed)

def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
def step(self) -> List[RequestOutput]:
"""Performs one decoding iteration and returns newly generated results.

<figure markdown="span">
@ -120,6 +120,7 @@ class RPCLoadAdapterRequest:
@dataclass
class RPCAdapterLoadedResponse:
request_id: str
lora_loaded: bool


RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest,
@ -6,7 +6,7 @@ import copy
import pickle
from contextlib import contextmanager, suppress
from typing import (Any, AsyncGenerator, Dict, Iterable, Iterator, List,
Mapping, Optional, Union, cast)
Mapping, Optional, Union)

import cloudpickle
import psutil

@ -477,10 +477,8 @@ class MQLLMEngineClient(EngineClient):
Any priority other than 0 will lead to an error if the
scheduling policy is not "priority".
"""
return cast(
AsyncGenerator[RequestOutput, None],
self._process_request(prompt, sampling_params, request_id,
lora_request, trace_headers, priority))
return self._process_request(prompt, sampling_params, request_id,
lora_request, trace_headers, priority)

def encode(
self,

@ -490,45 +488,20 @@ class MQLLMEngineClient(EngineClient):
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> AsyncGenerator[PoolingRequestOutput, None]:
"""Generate outputs for a request from a pooling model.

Generate outputs for a request. This method is a coroutine. It adds the
request into the waiting queue of the LLMEngine and streams the outputs
from the LLMEngine to the caller.

Args:
prompt: The prompt to the LLM. See
[`PromptType`][vllm.inputs.PromptType] for more details about
the format of each input.
pooling_params: The pooling parameters of the request.
request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any.
trace_headers: OpenTelemetry trace headers.

Yields:
The output `PoolingRequestOutput` objects from the LLMEngine
for the request.
"""
return cast(
AsyncGenerator[PoolingRequestOutput, None],
self._process_request(prompt,
pooling_params,
request_id,
lora_request,
trace_headers,
priority=priority))
raise NotImplementedError(
"Pooling models are not supported in vLLM V0")

async def _process_request(
self,
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
params: SamplingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
PoolingRequestOutput, None]]:
) -> AsyncGenerator[RequestOutput, None]:
"""Send an RPCGenerateRequest to the RPCServer and stream responses."""

# If already dead, error out.

@ -547,7 +520,7 @@ class MQLLMEngineClient(EngineClient):
try:
# 2) Detach logits processors so that they can be pickled
# separately (may require cloudpickle which is slower)
if isinstance(params, SamplingParams) and params.logits_processors:
if params.logits_processors:
# Defensive shallow copy
params = copy.copy(params)
logits_processors = params.logits_processors

@ -646,13 +619,14 @@ class MQLLMEngineClient(EngineClient):
raise request_output
return request_output.is_sleeping

async def add_lora(self, lora_request: LoRARequest) -> None:
async def add_lora(self, lora_request: LoRARequest) -> bool:
"""Load a new LoRA adapter into the engine for future requests."""
# Uses the same I/O as generate requests
request = RPCLoadAdapterRequest(lora_request)

# Create output queue for this request.
queue: asyncio.Queue[Union[None, BaseException]] = asyncio.Queue()
queue: asyncio.Queue[Union[
BaseException, RPCAdapterLoadedResponse]] = asyncio.Queue()
self.output_queues[request.request_id] = queue

# Send the request

@ -666,3 +640,4 @@ class MQLLMEngineClient(EngineClient):
# Raise on error, otherwise happily return None
if isinstance(request_output, BaseException):
raise request_output
return request_output.lora_loaded
@ -347,7 +347,7 @@ class MQLLMEngine:

def _handle_load_adapter_request(self, request: RPCLoadAdapterRequest):
try:
self.engine.add_lora(request.lora_request)
lora_loaded = self.engine.add_lora(request.lora_request)
except BaseException as e:
# Send back an error if the adater fails to load
rpc_err = RPCError(request_id=request.request_id,

@ -357,7 +357,8 @@ class MQLLMEngine:
return
# Otherwise, send back the successful load message
self._send_outputs(
RPCAdapterLoadedResponse(request_id=request.request_id))
RPCAdapterLoadedResponse(request_id=request.request_id,
lora_loaded=lora_loaded))

def _handle_is_sleeping_request(self, request: RPCIsSleepingRequest):
is_sleeping = self.is_sleeping()
@ -3,7 +3,7 @@

import asyncio
from abc import ABC, abstractmethod
from typing import AsyncGenerator, Iterable, Mapping, Optional, Union
from typing import Any, AsyncGenerator, Iterable, Mapping, Optional, Union

from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
from vllm.config import DecodingConfig, ModelConfig, VllmConfig

@ -224,6 +224,7 @@ class EngineClient(ABC):
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> AsyncGenerator[PoolingRequestOutput, None]:
"""Generate outputs for a request from a pooling model."""
...

@ -320,7 +321,7 @@ class EngineClient(ABC):
...

@abstractmethod
async def add_lora(self, lora_request: LoRARequest) -> None:
async def add_lora(self, lora_request: LoRARequest) -> bool:
"""Load a new LoRA adapter into the engine for future requests."""
...
@ -1156,8 +1156,7 @@ class LLM:
tokenization_kwargs=tokenization_kwargs,
)

if envs.VLLM_USE_V1 and (token_type_ids := engine_prompt.pop(
"token_type_ids", None)):
if (token_type_ids := engine_prompt.pop("token_type_ids", None)):
params = pooling_params.clone()
compressed = compress_token_type_ids(token_type_ids)
params.extra_kwargs = {"compressed_token_type_ids": compressed}
@ -7,7 +7,6 @@ from typing import Any, Optional, Union

from fastapi import Request

from vllm import envs
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger

@ -229,8 +228,7 @@ class ServingScores(OpenAIServing):
params=default_pooling_params,
lora_request=lora_request)

if envs.VLLM_USE_V1 and (token_type_ids := engine_prompt.pop(
"token_type_ids", None)):
if (token_type_ids := engine_prompt.pop("token_type_ids", None)):
pooling_params = default_pooling_params.clone()
compressed = compress_token_type_ids(token_type_ids)
pooling_params.extra_kwargs = {
@ -174,9 +174,6 @@ class TokenInputs(TypedDict):
prompt_token_ids: list[int]
"""The token IDs of the prompt."""

token_type_ids: NotRequired[list[int]]
"""The token type IDs of the prompt."""

prompt: NotRequired[str]
"""
The original prompt text corresponding to the token IDs, if available.

@ -190,7 +187,6 @@ class TokenInputs(TypedDict):

def token_inputs(
prompt_token_ids: list[int],
token_type_ids: Optional[list[int]] = None,
prompt: Optional[str] = None,
cache_salt: Optional[str] = None,
) -> TokenInputs:

@ -200,8 +196,6 @@ def token_inputs(

if prompt is not None:
inputs["prompt"] = prompt
if token_type_ids is not None:
inputs["token_type_ids"] = token_type_ids
if cache_salt is not None:
inputs["cache_salt"] = cache_salt
@ -355,7 +355,6 @@ class InputPreprocessor:
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> Union[TokenInputs, MultiModalInputs]:
prompt_token_ids = parsed_content["prompt_token_ids"]
token_type_ids = parsed_content.get("token_type_ids")

inputs: Union[TokenInputs, MultiModalInputs]
if multi_modal_data := parsed_content.get("multi_modal_data"):

@ -368,10 +367,7 @@ class InputPreprocessor:
mm_hash_overrides=mm_hash_overrides,
)
else:
inputs = token_inputs(
prompt_token_ids=prompt_token_ids,
token_type_ids=token_type_ids,
)
inputs = token_inputs(prompt_token_ids=prompt_token_ids)

if cache_salt := parsed_content.get("cache_salt"):
inputs["cache_salt"] = cache_salt

@ -387,7 +383,6 @@ class InputPreprocessor:
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> Union[TokenInputs, MultiModalInputs]:
prompt_token_ids = parsed_content["prompt_token_ids"]
token_type_ids = parsed_content.get("token_type_ids")

inputs: Union[TokenInputs, MultiModalInputs]
if multi_modal_data := parsed_content.get("multi_modal_data"):

@ -400,10 +395,7 @@ class InputPreprocessor:
mm_hash_overrides=mm_hash_overrides,
)
else:
inputs = token_inputs(
prompt_token_ids=prompt_token_ids,
token_type_ids=token_type_ids,
)
inputs = token_inputs(prompt_token_ids=prompt_token_ids, )

if cache_salt := parsed_content.get("cache_salt"):
inputs["cache_salt"] = cache_salt
@ -13,17 +13,12 @@ import torch.nn.functional as F
from transformers import PretrainedConfig

from vllm.config import ModelConfig, PoolerConfig
from vllm.model_executor.pooling_metadata import ( # noqa: E501
PoolingMetadata as V0PoolingMetadata)
from vllm.model_executor.pooling_metadata import PoolingTensors
from vllm.pooling_params import PoolingParams
from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
from vllm.tasks import PoolingTask
from vllm.utils import current_stream, resolve_obj_by_qualname
from vllm.v1.pool.metadata import PoolingCursor
from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata
from vllm.v1.pool.metadata import PoolingCursor, PoolingMetadata

PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata]
PoolingFn = Callable[
[Union[torch.Tensor, list[torch.Tensor]], PoolingMetadata],
Union[torch.Tensor, list[torch.Tensor]]]

@ -127,36 +122,23 @@ def get_prompt_lens(
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> torch.Tensor:
if isinstance(pooling_metadata, V1PoolingMetadata):
return pooling_metadata.prompt_lens

return PoolingTensors.from_pooling_metadata(
pooling_metadata, hidden_states[0].device).prompt_lens
return pooling_metadata.prompt_lens


def get_prompt_token_ids(
pooling_metadata: PoolingMetadata) -> list[torch.Tensor]:
if isinstance(pooling_metadata, V1PoolingMetadata):
assert pooling_metadata.prompt_token_ids is not None, (
"Please set `requires_token_ids=True` in `get_pooling_updates`")

return [
pooling_metadata.prompt_token_ids[i, :num]
for i, num in enumerate(pooling_metadata.prompt_lens)
]
assert pooling_metadata.prompt_token_ids is not None, (
"Please set `requires_token_ids=True` in `get_pooling_updates`")

return [
torch.tensor(seq_data_i.prompt_token_ids)
for seq_data_i in pooling_metadata.seq_data.values()
pooling_metadata.prompt_token_ids[i, :num]
for i, num in enumerate(pooling_metadata.prompt_lens)
]


def get_pooling_params(
pooling_metadata: PoolingMetadata) -> list[PoolingParams]:
if isinstance(pooling_metadata, V0PoolingMetadata):
pooling_params = [p for _, p in pooling_metadata.seq_groups]
else:
pooling_params = pooling_metadata.pooling_params
pooling_params = pooling_metadata.pooling_params
return pooling_params
@ -24,9 +24,9 @@ from vllm.model_executor.layers.pooler import (ClassifierPooler,
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors
from vllm.tasks import PoolingTask
from vllm.v1.pool.metadata import PoolingMetadata

from .interfaces import SupportsCrossEncoding, SupportsQuant
from .interfaces_base import default_pooling_type
@ -15,10 +15,10 @@ from vllm.model_executor.layers.pooler import (DispatchPooler, Pooler,
build_output, get_prompt_lens,
get_prompt_token_ids)
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import PoolerOutput
from vllm.tasks import PoolingTask
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from vllm.v1.pool.metadata import PoolingMetadata

from .interfaces_base import default_pooling_type
@ -22,9 +22,9 @@ from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors
from vllm.tasks import PoolingTask
from vllm.v1.pool.metadata import PoolingMetadata

from .interfaces import SupportsCrossEncoding
from .interfaces_base import default_pooling_type
@ -1,90 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
from typing import Any, Optional

import torch

from vllm.pooling_params import PoolingParams
from vllm.utils import is_pin_memory_available
from vllm.v1.pool.metadata import PoolingCursor, build_pooling_cursor


class PoolingMetadata:
"""Metadata for pooling operations in the Pooler layer.

This class holds the necessary information for pooling operations,
providing context for how to perform pooling and other related operations.

Attributes:
seq_groups: List of (seq_ids, pooling_params).
seq_data: A mapping of sequence ID to additional sequence data.
prompt_lens: List of the lengths of each prompt.
"""

def __init__(
self,
seq_groups: list[tuple[list[int], PoolingParams]],
seq_data: dict[int, Any], # Specific data related to sequences
prompt_lens: list[int],
pooling_cursor: Optional[PoolingCursor] = None) -> None:
self.seq_groups = seq_groups
self.seq_data = seq_data
self.prompt_lens = prompt_lens
self.pooling_cursor: Optional[PoolingCursor] = pooling_cursor

def __repr__(self) -> str:
return ("PoolingMetadata("
f"seq_groups={self.seq_groups}, "
f"seq_data={self.seq_data}, "
f"prompt_lens={self.prompt_lens})")

def __getitem__(self, indices: slice):
return PoolingMetadata(
seq_groups=self.seq_groups[indices],
seq_data=dict(list(self.seq_data.items())[indices]),
prompt_lens=self.prompt_lens[indices],
pooling_cursor=None
if self.pooling_cursor is None else self.pooling_cursor[indices],
)

def build_pooling_cursor(self, num_scheduled_tokens: list[int],
device: torch.device):
prompt_lens = torch.tensor(self.prompt_lens, device="cpu")
self.pooling_cursor = build_pooling_cursor(num_scheduled_tokens,
prompt_lens,
device=device)


@dataclass
class PoolingTensors:
"""Tensors for pooling."""

prompt_lens: torch.Tensor

@classmethod
def from_pooling_metadata(
cls,
pooling_metadata: "PoolingMetadata",
device: torch.device,
) -> "PoolingTensors":
"""
Create PoolingTensors from PoolingMetadata.

Args:
pooling_metadata: PoolingMetadata instance to convert.
device: Device to store the tensors.
"""
# Convert prompt lengths to tensor
pin_memory = is_pin_memory_available()

prompt_lens_t = torch.tensor(
pooling_metadata.prompt_lens,
device="cpu",
dtype=torch.long,
pin_memory=pin_memory,
)

return cls(prompt_lens=prompt_lens_t.to(device=device,
non_blocking=True), )
@ -913,9 +913,6 @@ class MultiModalInputs(TypedDict):
prompt_token_ids: list[int]
"""The processed token IDs which includes placeholder tokens."""

token_type_ids: NotRequired[list[int]]
"""The token type IDs of the prompt."""

mm_kwargs: MultiModalKwargsOptionalItems
"""Keyword arguments to be directly passed to the model after batching."""

@ -946,6 +943,3 @@ class MultiModalEncDecInputs(MultiModalInputs):

encoder_prompt_token_ids: list[int]
"""The processed token IDs of the encoder prompt."""

encoder_token_type_ids: NotRequired[list[int]]
"""The token type IDs of the encoder prompt."""
@ -508,12 +508,6 @@ class Sequence:
return [0] * len(self.inputs["prompt_embeds"])
return self.inputs["prompt_token_ids"]

@property
def token_type_ids(self) -> list[int]:
if self.inputs["type"] == "embeds":
return []
return self.inputs.get("token_type_ids", [])

@property
def multi_modal_data(self) -> MultiModalKwargs:
if self.inputs["type"] == "multimodal":

@ -765,10 +759,6 @@ class SequenceGroup:
return (self.encoder_seq.prompt_token_ids
if self.encoder_seq is not None else None)

@property
def token_type_ids(self) -> Optional[list[int]]:
return self.first_seq.token_type_ids

@property
def multi_modal_data(self) -> MultiModalKwargs:
if self.first_seq.multi_modal_data:

@ -972,7 +962,6 @@ class SequenceGroupMetadata(
computed_block_nums: Optional[list[int]] = None
state: Optional[SequenceGroupState] = msgspec.field(
default_factory=lambda: SequenceGroupState())
token_type_ids: Optional[list[int]] = None
multi_modal_data: Optional[MultiModalKwargs] = None
multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
encoder_seq_data: Optional[SequenceData] = None
@ -24,8 +24,7 @@ from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
MultiModalRegistry)
from vllm.platforms import _Backend
from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, PoolerOutput,
SequenceGroupMetadata)
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.utils import STR_NOT_IMPL_ENC_DEC_BACKEND, make_tensor_with_pad
from vllm.worker.model_runner import (GPUModelRunnerBase,
ModelInputForGPUBuilder,

@ -161,7 +160,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
) -> Optional[List[PoolerOutput]]:
) -> Optional[List[SamplerOutput]]:
if num_steps > 1:
raise ValueError("num_steps > 1 is not supported in "
"EncoderDecoderModelRunner")
@ -86,7 +86,6 @@ class ModelInputForGPU(ModelRunnerInputBase):
input_tokens: Optional[torch.Tensor] = None
inputs_embeds: Optional[torch.Tensor] = None
input_positions: Optional[torch.Tensor] = None
token_types: Optional[torch.Tensor] = None
seq_lens: Optional[List[int]] = None
query_lens: Optional[List[int]] = None
lora_mapping: Optional["LoRAMapping"] = None

@ -192,7 +191,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.input_tokens[0].clear() # type: ignore
self.inputs_embeds = None # type: ignore
self.input_positions[0].clear() # type: ignore
self.token_types[0].clear() # type: ignore
self.mrope_input_positions = None # type: ignore
self.seq_lens[0] = 0 # type: ignore
self.orig_seq_lens[0] = 0 # type: ignore

@ -219,7 +217,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
input_tokens: Optional[List[List[int]]] = None,
inputs_embeds: Optional[torch.Tensor] = None,
input_positions: Optional[List[List[int]]] = None,
token_types: Optional[List[List[int]]] = None,
mrope_input_positions: Optional[List[List[List[int]]]] = None,

# The sequence length (may be capped to the sliding window).

@ -284,12 +281,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
for seq_id in range(len(self.seq_ids)):
self.input_positions[seq_id].clear()

if token_types:
self.token_types = token_types
else:
for seq_id in range(len(self.seq_ids)):
self.token_types[seq_id].clear()

self.mrope_input_positions = None

if seq_lens:

@ -348,7 +339,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.input_tokens = input_tokens or []
self.inputs_embeds = inputs_embeds
self.input_positions = input_positions or []
self.token_types = token_types or []
self.mrope_input_positions = mrope_input_positions or None
self.seq_lens = seq_lens or []
self.orig_seq_lens = orig_seq_lens or []

@ -376,7 +366,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):

self.input_tokens = [[] for _ in range(self.n_seqs)]
self.input_positions = [[] for _ in range(self.n_seqs)]
self.token_types = [[] for _ in range(self.n_seqs)]
self.mrope_input_positions = None
self.seq_lens = [0] * self.n_seqs
self.orig_seq_lens = [0] * self.n_seqs

@ -400,7 +389,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
f"inputs_embeds.shape="
f"{getattr(self.inputs_embeds, 'shape', None)}, "
f"input_positions={self.input_positions}, "
f"token_types={self.token_types}, "
f"mrope_input_positions={self.mrope_input_positions}, "
f"seq_lens={self.seq_lens}, "
f"orig_seq_lens={self.orig_seq_lens}, "
@ -522,8 +510,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
||||
prompt_embeds = seq_data.get_token_embeddings(
|
||||
)[context_len:seq_len]
|
||||
|
||||
token_types = seq_group_metadata.token_type_ids
|
||||
|
||||
inter_data.seq_lens[seq_idx] = seq_len
|
||||
inter_data.orig_seq_lens[seq_idx] = seq_len
|
||||
inter_data.prompt_lens[seq_idx] = seq_data.get_prompt_len()
|
||||
@ -531,8 +517,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
||||
inter_data.input_tokens[seq_idx].extend(tokens)
|
||||
inter_data.inputs_embeds = prompt_embeds
|
||||
inter_data.input_positions[seq_idx].extend(range(context_len, seq_len))
|
||||
inter_data.token_types[seq_idx].extend(
|
||||
token_types if token_types else [])
|
||||
inter_data.query_lens[seq_idx] = seq_len - context_len
|
||||
|
||||
if seq_data.mrope_position_delta is not None:
|
||||
@ -590,8 +574,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
||||
seq_idx][uncomputed_start:]
|
||||
inter_data.input_positions[seq_idx] = inter_data.input_positions[
|
||||
seq_idx][uncomputed_start:]
|
||||
inter_data.token_types[seq_idx] = inter_data.token_types[seq_idx][
|
||||
uncomputed_start:]
|
||||
context_len = prefix_cache_len
|
||||
|
||||
inter_data.context_lens[seq_idx] = context_len
|
||||
@ -606,8 +588,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
||||
seq_idx][-1:]
|
||||
inter_data.input_positions[seq_idx] = inter_data.input_positions[
|
||||
seq_idx][-1:]
|
||||
inter_data.token_types[seq_idx] = inter_data.token_types[seq_idx][
|
||||
-1:]
|
||||
inter_data.query_lens[seq_idx] = 1
|
||||
inter_data.context_lens[seq_idx] = inter_data.seq_lens[seq_idx] - 1
|
||||
|
||||
@ -802,12 +782,9 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
||||
# Combine and flatten intermediate data.
|
||||
input_tokens = list[int]()
|
||||
inputs_embeds_list = list[torch.Tensor]()
|
||||
token_types = list[int]()
|
||||
for inter_data in self.inter_data_list:
|
||||
for cur_input_tokens in inter_data.input_tokens:
|
||||
input_tokens.extend(cur_input_tokens)
|
||||
for cur_token_types in inter_data.token_types:
|
||||
token_types.extend(cur_token_types)
|
||||
if inter_data.inputs_embeds is not None:
|
||||
inputs_embeds_list.append(
|
||||
inter_data.inputs_embeds.to(
|
||||
@ -890,11 +867,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
||||
self.runner.device,
|
||||
self.runner.pin_memory)
|
||||
|
||||
token_types_tensor = async_tensor_h2d(token_types, torch.long,
|
||||
self.runner.device,
|
||||
self.runner.pin_memory) \
|
||||
if token_types else None
|
||||
|
||||
if mrope_input_positions is not None:
|
||||
for idx in range(3):
|
||||
mrope_input_positions[idx].extend(
|
||||
@ -951,7 +923,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
||||
input_tokens=input_tokens_tensor,
|
||||
inputs_embeds=inputs_embeds,
|
||||
input_positions=input_positions_tensor,
|
||||
token_types=token_types_tensor,
|
||||
attn_metadata=attn_metadata,
|
||||
seq_lens=seq_lens,
|
||||
query_lens=query_lens,
|
||||
|
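Note on the async_tensor_h2d calls in the hunk above: the removed token_types_tensor line followed the same pinned-memory host-to-device pattern as the tensors that remain. A minimal sketch of that pattern (illustrative only, not vLLM's actual helper; names and signature are assumptions):

# Sketch of the async host-to-device copy pattern behind async_tensor_h2d.
from typing import Optional
import torch

def h2d_async(values: list[int], dtype: torch.dtype,
              device: torch.device, pin_memory: bool) -> Optional[torch.Tensor]:
    if not values:  # mirrors the old `if token_types else None` guard
        return None
    cpu = torch.tensor(values, dtype=dtype, pin_memory=pin_memory)
    # Pinned host memory allows the copy to overlap with GPU compute.
    return cpu.to(device, non_blocking=True)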
@@ -13,10 +13,9 @@ from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.models.interfaces import supports_transcription
-from vllm.model_executor.models.interfaces_base import (
-    is_pooling_model, is_text_generation_model)
+from vllm.model_executor.models.interfaces_base import is_text_generation_model
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
-from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
+from vllm.tasks import GenerationTask, SupportedTask

if TYPE_CHECKING:
    from vllm.attention import AttentionMetadata

@@ -241,20 +240,11 @@ class ModelRunnerBase(ABC, Generic[T]):

        return supported_tasks

-   def get_supported_pooling_tasks(self) -> list[PoolingTask]:
-       model = self.get_model()
-       if not is_pooling_model(model):
-           return []
-
-       return list(model.pooler.get_supported_tasks())
-
    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        tasks = list[SupportedTask]()

        if self.model_config.runner_type == "generate":
            tasks.extend(self.get_supported_generation_tasks())
-       if self.model_config.runner_type == "pooling":
-           tasks.extend(self.get_supported_pooling_tasks())

        return tuple(tasks)
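For reference, the capability probe the deleted method performed can be written as a standalone helper. This is a sketch assembled from the removed lines above; the free-function wrapper is illustrative, while is_pooling_model and the model's pooler attribute are the real vLLM interfaces the removed code used:

# Sketch of the pooling-capability probe removed from the V0 base runner.
from vllm.model_executor.models.interfaces_base import is_pooling_model
from vllm.tasks import PoolingTask

def supported_pooling_tasks(model: object) -> list[PoolingTask]:
    # A model advertises pooling support through its pooler; otherwise
    # it exposes no pooling tasks.
    if not is_pooling_model(model):
        return []
    return list(model.pooler.get_supported_tasks())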
@@ -1,222 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-import dataclasses
-from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast
-
-import torch
-
-from vllm.config import VllmConfig
-from vllm.distributed import get_pp_group
-from vllm.forward_context import set_forward_context
-from vllm.logger import init_logger
-from vllm.model_executor.models.interfaces_base import VllmModelForPooling
-from vllm.model_executor.pooling_metadata import PoolingMetadata
-from vllm.multimodal import MultiModalKwargs
-from vllm.pooling_params import PoolingParams
-from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData,
-                           SequenceGroupMetadata)
-from vllm.worker.model_runner import (GPUModelRunnerBase, ModelInputForGPU,
-                                      ModelInputForGPUBuilder)
-
-logger = init_logger(__name__)
-
-
-@dataclasses.dataclass(frozen=True)
-class ModelInputForGPUWithPoolingMetadata(ModelInputForGPU):
-    """
-    Used by the PoolingModelRunner.
-    """
-    pooling_metadata: Optional["PoolingMetadata"] = None
-
-
-class PoolingModelRunner(
-        GPUModelRunnerBase[ModelInputForGPUWithPoolingMetadata]):
-    _model_input_cls: Type[ModelInputForGPUWithPoolingMetadata] = (
-        ModelInputForGPUWithPoolingMetadata)
-    _builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder
-
-    def __init__(
-        self,
-        vllm_config: VllmConfig,
-        kv_cache_dtype: Optional[str] = "auto",
-        is_driver_worker: bool = False,
-    ):
-        super().__init__(vllm_config=vllm_config,
-                         kv_cache_dtype=kv_cache_dtype,
-                         is_driver_worker=is_driver_worker)
-
-    @torch.inference_mode()
-    def execute_model(
-        self,
-        model_input: ModelInputForGPUWithPoolingMetadata,
-        kv_caches: List[torch.Tensor],
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        num_steps: int = 1,
-    ) -> Optional[Union[List[PoolerOutput], IntermediateTensors]]:
-        if num_steps > 1:
-            raise ValueError(
-                "PoolingModelRunner does not support multi-step execution.")
-
-        if self.lora_config:
-            assert model_input.lora_requests is not None
-            assert model_input.lora_mapping is not None
-            self.set_active_loras(model_input.lora_requests,
-                                  model_input.lora_mapping)
-
-        # Currently cuda graph is only supported by the decode phase.
-        assert model_input.attn_metadata is not None
-        prefill_meta = model_input.attn_metadata.prefill_metadata
-        decode_meta = model_input.attn_metadata.decode_metadata
-        virtual_engine = model_input.virtual_engine
-        # Pooling models are (ab-)used also to integrate non text models that
-        # are not autoregressive (PrithviGeosaptialMAE).
-        # These model might not use attention and do not really have a prefill
-        # and decode phase. The model input is processed in one shot and both
-        # decode_metadata and prefill_metadata would be None for such models.
-        # See the PlaceholderAttentionMetadata class.
-        # TODO: Figure out if cuda_graph is of any use for these models and
-        # explore how to leverage it.
-        if (prefill_meta is None and decode_meta is not None
-                and decode_meta.use_cuda_graph):
-            if model_input.inputs_embeds is None:
-                assert model_input.input_tokens is not None
-                graph_batch_size = model_input.input_tokens.shape[0]
-                model_executable = (
-                    self.graph_runners[model_input.virtual_engine][(
-                        graph_batch_size, False)])
-            else:
-                graph_batch_size = model_input.inputs_embeds.shape[0]
-                model_executable = (
-                    self.graph_runners[model_input.virtual_engine][(
-                        graph_batch_size, True)])
-        else:
-            model_executable = self.model
-
-        multi_modal_kwargs = model_input.multi_modal_kwargs or {}
-        seqlen_agnostic_kwargs = {
-            "finished_requests_ids": model_input.finished_requests_ids,
-            "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
-        } if self.has_inner_state else {}
-        if (self.observability_config is not None
-                and self.observability_config.collect_model_forward_time):
-            model_forward_start = torch.cuda.Event(enable_timing=True)
-            model_forward_end = torch.cuda.Event(enable_timing=True)
-            model_forward_start.record()
-
-        cross_enc_kwargs = {}
-        if model_input.token_types is not None:
-            cross_enc_kwargs["token_type_ids"] = model_input.token_types
-
-        with set_forward_context(model_input.attn_metadata, self.vllm_config,
-                                 virtual_engine):
-            hidden_or_intermediate_states = model_executable(
-                input_ids=model_input.input_tokens,
-                positions=model_input.input_positions,
-                intermediate_tensors=intermediate_tensors,
-                **MultiModalKwargs.as_kwargs(
-                    multi_modal_kwargs,
-                    device=self.device,
-                ),
-                **cross_enc_kwargs,
-                **seqlen_agnostic_kwargs,
-            )
-
-        if (self.observability_config is not None
-                and self.observability_config.collect_model_forward_time):
-            model_forward_end.record()
-
-        # Only perform pooling in the last pipeline stage.
-        if not get_pp_group().is_last_rank:
-            if (self.is_driver_worker
-                    and hidden_or_intermediate_states is not None
-                    and isinstance(hidden_or_intermediate_states,
-                                   IntermediateTensors)
-                    and self.observability_config is not None
-                    and self.observability_config.collect_model_forward_time):
-                model_forward_end.synchronize()
-                model_forward_time = model_forward_start.elapsed_time(
-                    model_forward_end)
-                orig_model_forward_time = 0.0
-                if intermediate_tensors is not None:
-                    orig_model_forward_time = intermediate_tensors.tensors.get(
-                        "model_forward_time", torch.tensor(0.0)).item()
-                hidden_or_intermediate_states.tensors["model_forward_time"] = (
-                    torch.tensor(model_forward_time + orig_model_forward_time))
-            return hidden_or_intermediate_states
-
-        # Only perform pooling in the driver worker.
-        if not self.is_driver_worker:
-            return []
-
-        pooling_metadata = model_input.pooling_metadata
-        assert pooling_metadata is not None
-
-        pooling_metadata.build_pooling_cursor(
-            num_scheduled_tokens=pooling_metadata.prompt_lens,
-            device=hidden_or_intermediate_states.device)
-
-        return [
-            self.model.pooler(hidden_states=hidden_or_intermediate_states,
-                              pooling_metadata=pooling_metadata)
-        ]
-
-    def make_model_input_from_broadcasted_tensor_dict(
-            self,
-            tensor_dict: Dict[str,
-                              Any]) -> ModelInputForGPUWithPoolingMetadata:
-        return ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict(
-            tensor_dict,
-            attn_backend=self.attn_backend,
-        )
-
-    def prepare_model_input(
-        self,
-        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
-        virtual_engine: int = 0,
-        finished_requests_ids: Optional[List[str]] = None
-    ) -> ModelInputForGPUWithPoolingMetadata:
-        assert seq_group_metadata_list is not None
-        model_input = self._prepare_model_input_tensors(
-            seq_group_metadata_list, finished_requests_ids)
-        # Prepare PoolingMetadata.
-        assert model_input.seq_lens is not None
-        pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
-                                                 model_input.seq_lens)
-
-        return dataclasses.replace(model_input,
-                                   pooling_metadata=pooling_metadata)
-
-    def _prepare_pooling(
-        self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        prompt_lens: List[int],
-    ) -> PoolingMetadata:
-        """Prepare PoolingMetadata for the sequence group metadata list."""
-        seq_groups: List[Tuple[List[int], PoolingParams]] = []
-        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
-            seq_ids = list(seq_group_metadata.seq_data.keys())
-
-            pooling_params = seq_group_metadata.pooling_params
-            assert pooling_params is not None
-
-            task = pooling_params.task
-            assert task is not None, "You did not set `task` in the API"
-
-            model = cast(VllmModelForPooling, self.model)
-            to_update = model.pooler.get_pooling_updates(task)
-            to_update.apply(pooling_params)
-
-            seq_groups.append((seq_ids, pooling_params))
-
-        seq_data: Dict[int, SequenceData] = {}
-        for seq_group_metadata in seq_group_metadata_list:
-            seq_data.update(seq_group_metadata.seq_data)
-
-        pooling_metadata = PoolingMetadata(
-            seq_groups=seq_groups,
-            seq_data=seq_data,
-            prompt_lens=prompt_lens,
-        )
-
-        return pooling_metadata
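With this module removed, pooling models no longer have a V0 GPU runner and are served by the V1 engine instead. A minimal usage sketch for an embedding model under the remaining path; the model name and prompt are illustrative examples, not taken from this diff:

# Minimal sketch: an embedding (pooling) model now runs on the V1 engine
# rather than the deleted V0 PoolingModelRunner.
from vllm import LLM

llm = LLM(model="intfloat/e5-mistral-7b-instruct", enforce_eager=True)
outputs = llm.embed(["vLLM is a fast inference engine."])
for output in outputs:
    print(len(output.outputs.embedding))  # embedding dimensionality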
@@ -30,7 +30,6 @@ from vllm.utils import (GiB_bytes, MemorySnapshot, bind_kv_cache,
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
-from vllm.worker.pooling_model_runner import PoolingModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
                                     WorkerInput)

@@ -83,9 +82,7 @@ class Worker(LocalOrDistributedWorkerBase):
            else {"return_hidden_states": True}

        ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
-       if model_config.runner_type == "pooling":
-           ModelRunnerClass = PoolingModelRunner
-       elif self.model_config.is_encoder_decoder:
+       if self.model_config.is_encoder_decoder:
            ModelRunnerClass = EncoderDecoderModelRunner
        self.model_runner: GPUModelRunnerBase = ModelRunnerClass(
            vllm_config=self.vllm_config,

@@ -99,7 +96,6 @@ class Worker(LocalOrDistributedWorkerBase):
        # Uninitialized cache engine. Will be initialized by
        # initialize_cache.
        self.cache_engine: List[CacheEngine]
-       # Initialize gpu_cache as pooling models don't initialize kv_caches
        self.gpu_cache: Optional[List[List[torch.Tensor]]] = None
        self._seq_group_metadata_cache: Dict[str, SequenceGroupMetadata] = {}