Feature/vllm/input embedding completion api (#17590)

Signed-off-by: Andrew Sansom <andrew@protopia.ai>
Signed-off-by: Nan2018 <nan@protopia.ai>
Co-authored-by: 临景 <linjing.yx@alibaba-inc.com>
Co-authored-by: Bryce1010 <bryceyx@gmail.com>
Co-authored-by: Andrew Sansom <andrew@protopia.ai>
Co-authored-by: Andrew Sansom <qthequartermasterman@gmail.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
Nan Qin
2025-05-18 22:18:05 -05:00
committed by GitHub
parent 9da1095daf
commit 221cfc2fea
10 changed files with 637 additions and 40 deletions

View File

@ -119,6 +119,7 @@ serving/offline_inference
serving/openai_compatible_server
serving/serve_args
serving/multimodal_inputs
serving/prompt_embeds
serving/distributed_serving
serving/metrics
serving/engine_args

View File

@ -0,0 +1,142 @@
# Prompt Embedding Inputs
This page teaches you how to pass prompt embedding inputs to vLLM.
## What are prompt embeddings?
The traditional flow of text data for a Large Language Model goes from text to token ids (via a tokenizer), then from token ids to prompt embeddings. For a traditional decoder-only model (such as meta-llama/Llama-3.1-8B-Instruct), the conversion from token ids to prompt embeddings is a look-up into a learned embedding matrix, but the model is not limited to processing only the embeddings that correspond to its token vocabulary.
:::{note}
Prompt embeddings are currently only supported in the v0 engine.
:::
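Concretely, for such a model the token-id-to-embedding step is just an index into the learned embedding matrix. A minimal sketch of that equivalence with Hugging Face Transformers (the model name and prompt text are placeholders, mirroring the examples below):
```python
# Sketch: converting token ids to prompt embeddings is a lookup into the
# model's learned embedding matrix. Model name and prompt are placeholders.
import torch
import transformers

model_name = "meta-llama/Llama-3.2-1B-Instruct"
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
model = transformers.AutoModelForCausalLM.from_pretrained(model_name)

token_ids = tokenizer("Paris is the capital of", return_tensors="pt").input_ids[0]
embedding_layer = model.get_input_embeddings()   # nn.Embedding(vocab_size, hidden_size)

prompt_embeds = embedding_layer(token_ids)       # (sequence_length, hidden_size)
assert torch.equal(prompt_embeds, embedding_layer.weight[token_ids])
```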
## Offline Inference
To input prompt embeddings directly, follow this schema in {class}`vllm.inputs.EmbedsPrompt`:
- `prompt_embeds`: A torch tensor representing a sequence of prompt/token embeddings, with shape (sequence_length, hidden_size), where sequence_length is the number of token embeddings and hidden_size is the hidden size (embedding size) of the model. See the sketch below for a minimal example.
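The sketch below makes the expected shape concrete by passing a synthetic tensor; the sequence length of 16 is arbitrary, and in practice the tensor would come from an embedding layer or another upstream component:
```python
# Minimal sketch of the EmbedsPrompt schema using a synthetic tensor.
# The sequence length (16) is arbitrary; hidden_size must match the model.
import torch
import transformers
from vllm import LLM

model_name = "meta-llama/Llama-3.2-1B-Instruct"
hidden_size = transformers.AutoConfig.from_pretrained(model_name).hidden_size

llm = LLM(model=model_name, enable_prompt_embeds=True)

# (sequence_length, hidden_size) float tensor; dtype should be a float type
prompt_embeds = torch.randn(16, hidden_size)

outputs = llm.generate({"prompt_embeds": prompt_embeds})
print(outputs[0].outputs[0].text)
```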
### Hugging Face Transformers Inputs
You can pass prompt embeddings from Hugging Face Transformers models to the `'prompt_embeds'` field of the prompt embedding dictionary, as shown in the following examples:
```python
from vllm import LLM
import transformers
model_name = "meta-llama/Llama-3.2-1B-Instruct"
# Transformers
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
transformers_model = transformers.AutoModelForCausalLM.from_pretrained(model_name)
llm = LLM(model=model_name, enable_prompt_embeds=True)
# Refer to the HuggingFace repo for the correct format to use
chat = [{"role": "user", "content": "Please tell me about the capital of France."}]
token_ids = tokenizer.apply_chat_template(chat, add_generation_prompt=True, return_tensors='pt')
embedding_layer = transformers_model.get_input_embeddings()
prompt_embeds = embedding_layer(token_ids).squeeze(0)
# Single prompt inference
outputs = llm.generate({
"prompt_embeds": prompt_embeds,
})
for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)
# Batch inference
chats = [
[{"role": "user", "content": "Please tell me about the capital of France."}],
[{"role": "user", "content": "When is the day longest during the year?"}],
[{"role": "user", "content": "Where is bigger, the moon or the sun?"}]
]
token_ids_list = [
tokenizer.apply_chat_template(chat, add_generation_prompt=True, return_tensors='pt') for chat in chats
]
prompt_embeds_list = [embedding_layer(token_ids).squeeze(0) for token_ids in token_ids_list]
outputs = llm.generate(
[
{
"prompt_embeds": prompt_embeds,
} for prompt_embeds in prompt_embeds_list
]
)
for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)
```
## Online Serving
Our OpenAI-compatible server accepts prompt embeddings inputs via the [Completions API](https://platform.openai.com/docs/api-reference/completions). Prompt embeddings are supplied via a new `'prompt_embeds'` key in the JSON request body.
When a mixture of `'prompt_embeds'` and `'prompt'` inputs is provided in a single request, the outputs for the prompt embeds are always returned first.
Prompt embeddings are passed as base64-encoded torch tensors.
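Because this is just an extra field in the completion request's JSON body, you can also hit the endpoint without the OpenAI client. The following is a minimal sketch using `requests`; the server URL, model name, and random dummy tensor are assumptions for illustration only:
```python
# Minimal sketch of the raw wire format: the request carries a base64-encoded
# torch tensor under the "prompt_embeds" key. Server URL, model name, and the
# random tensor are assumptions for illustration only.
import base64
import io

import requests
import torch
import transformers

model_name = "meta-llama/Llama-3.2-1B-Instruct"
hidden_size = transformers.AutoConfig.from_pretrained(model_name).hidden_size

# (sequence_length, hidden_size); real embeddings would come from an embedding layer
dummy_embeds = torch.randn(8, hidden_size)

buffer = io.BytesIO()
torch.save(dummy_embeds, buffer)
encoded_embeds = base64.b64encode(buffer.getvalue()).decode("utf-8")

response = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": model_name,
        "prompt_embeds": encoded_embeds,  # or a list of encoded tensors for a batch
        "max_tokens": 5,
        "temperature": 0.0,
    },
)
print(response.json()["choices"][0]["text"])
```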
### Transformers Inputs via OpenAI Client
First, launch the OpenAI-compatible server:
```bash
vllm serve meta-llama/Llama-3.2-1B-Instruct --task generate \
--max-model-len 4096 --enable-prompt-embeds
```
Then, you can use the OpenAI client as follows:
```python
from openai import OpenAI
import transformers
import torch
import base64
import io
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"
client = OpenAI(
api_key=openai_api_key,
base_url=openai_api_base,
)
model_name = "meta-llama/Llama-3.2-1B-Instruct"
# Transformers
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
transformers_model = transformers.AutoModelForCausalLM.from_pretrained(model_name)
# Refer to the HuggingFace repo for the correct format to use
chat = [{"role": "user", "content": "Please tell me about the capital of France."}]
token_ids = tokenizer.apply_chat_template(chat, add_generation_prompt=True, return_tensors='pt')
embedding_layer = transformers_model.get_input_embeddings()
prompt_embeds = embedding_layer(token_ids).squeeze(0)
# Prompt embeddings
buffer = io.BytesIO()
torch.save(prompt_embeds, buffer)
buffer.seek(0)
binary_data = buffer.read()
encoded_embeds = base64.b64encode(binary_data).decode('utf-8')
completion = client.completions.create(
model=model_name,
# NOTE: The OpenAI client does not allow `None` as an input to
# `prompt`. Use an empty string if you have no text prompts.
prompt="",
max_tokens=5,
temperature=0.0,
# NOTE: The OpenAI client allows passing in extra JSON body via the
# `extra_body` argument.
extra_body={"prompt_embeds": encoded_embeds}
)
print(completion.choices[0].text)
```
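On the server side, each encoded tensor is decoded with `torch.load(..., weights_only=True)` and must be a 2-D float tensor. You can mirror that check locally to sanity-check a payload before sending it; a sketch, continuing from `encoded_embeds` in the example above:
```python
# Sanity-check sketch: decode the payload the same way the server does.
# `encoded_embeds` is the base64 string produced in the example above.
import base64
import io

import torch

decoded = torch.load(io.BytesIO(base64.b64decode(encoded_embeds)),
                     weights_only=True)
if decoded.dim() > 2:          # the server squeezes a leading batch dimension
    decoded = decoded.squeeze(0)
assert decoded.dim() == 2      # (sequence_length, hidden_size)
assert decoded.is_floating_point()
print(decoded.shape)
```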

View File

@ -0,0 +1,257 @@
# SPDX-License-Identifier: Apache-2.0
import base64
import io
import shutil
from tempfile import TemporaryDirectory
import openai # use the official client for correctness check
import pytest
import pytest_asyncio
import torch
# downloading lora to test lora requests
from huggingface_hub import snapshot_download
from openai import BadRequestError
from transformers import AutoConfig, AutoTokenizer
from ...utils import RemoteOpenAIServer
# any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
LORA_NAME = "typeof/zephyr-7b-beta-lora"
CONFIG = AutoConfig.from_pretrained(MODEL_NAME)
@pytest.fixture(scope="module")
def zephyr_lora_files():
return snapshot_download(repo_id=LORA_NAME)
@pytest.fixture(scope="module")
def zephyr_lora_added_tokens_files(zephyr_lora_files):
tmp_dir = TemporaryDirectory()
tmp_model_dir = f"{tmp_dir.name}/zephyr"
shutil.copytree(zephyr_lora_files, tmp_model_dir)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Copy tokenizer to adapter and add some unique tokens
# 32000, 32001, 32002
added = tokenizer.add_tokens(["vllm1", "vllm2", "vllm3"],
special_tokens=True)
assert added == 3
tokenizer.save_pretrained(tmp_model_dir)
yield tmp_model_dir
tmp_dir.cleanup()
@pytest.fixture(scope="module")
def default_server_args(
zephyr_lora_files,
zephyr_lora_added_tokens_files,
) -> list[str]:
return [
# use half precision for speed and memory savings in CI environment
"--dtype",
"bfloat16",
"--max-model-len",
"8192",
"--max-num-seqs",
"128",
"--enforce-eager",
# Prompt Embeds server args
"--enable-prompt-embeds",
"--no-enable-chunked-prefill",
]
@pytest.fixture(scope="module",
params=["", "--disable-frontend-multiprocessing"])
def server_with_prompt_embeds(default_server_args, request):
if request.param:
default_server_args.append(request.param)
with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server:
yield remote_server
@pytest_asyncio.fixture
async def client_with_prompt_embeds(server_with_prompt_embeds):
async with server_with_prompt_embeds.get_async_client() as async_client:
yield async_client
def create_dummy_embeds(num_tokens: int = 5) -> str:
"""Create dummy embeddings and return them as base64 encoded string."""
dummy_embeds = torch.randn(num_tokens, CONFIG.hidden_size)
buffer = io.BytesIO()
torch.save(dummy_embeds, buffer)
return base64.b64encode(buffer.getvalue()).decode('utf-8')
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_completions_with_prompt_embeds(
client_with_prompt_embeds: openai.AsyncOpenAI, model_name: str):
# Test case: Single prompt embeds input
encoded_embeds = create_dummy_embeds()
completion = await client_with_prompt_embeds.completions.create(
model=model_name,
prompt="", # Add empty prompt as required parameter
max_tokens=5,
temperature=0.0,
extra_body={"prompt_embeds": encoded_embeds})
assert len(completion.choices[0].text) >= 1
assert completion.choices[0].prompt_logprobs is None
# Test case: batch completion with prompt_embeds
encoded_embeds2 = create_dummy_embeds()
completion = await client_with_prompt_embeds.completions.create(
model=model_name,
prompt="", # Add empty prompt as required parameter
max_tokens=5,
temperature=0.0,
extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]})
assert len(completion.choices) == 2
assert len(completion.choices[0].text) >= 1
assert len(completion.choices[1].text) >= 1
# Test case: streaming with prompt_embeds
encoded_embeds = create_dummy_embeds()
single_completion = await client_with_prompt_embeds.completions.create(
model=model_name,
prompt="", # Add empty prompt as required parameter
max_tokens=5,
temperature=0.0,
extra_body={"prompt_embeds": encoded_embeds})
single_output = single_completion.choices[0].text
stream = await client_with_prompt_embeds.completions.create(
model=model_name,
prompt="", # Add empty prompt as required parameter
max_tokens=5,
temperature=0.0,
stream=True,
extra_body={"prompt_embeds": encoded_embeds})
chunks = []
finish_reason_count = 0
async for chunk in stream:
chunks.append(chunk.choices[0].text)
if chunk.choices[0].finish_reason is not None:
finish_reason_count += 1
assert finish_reason_count == 1
assert chunk.choices[0].finish_reason == "length"
assert chunk.choices[0].text
assert "".join(chunks) == single_output
# Test case: batch streaming with prompt_embeds
encoded_embeds2 = create_dummy_embeds()
stream = await client_with_prompt_embeds.completions.create(
model=model_name,
prompt="", # Add empty prompt as required parameter
max_tokens=5,
temperature=0.0,
stream=True,
extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]})
chunks_stream_embeds: list[list[str]] = [[], []]
finish_reason_count = 0
async for chunk in stream:
chunks_stream_embeds[chunk.choices[0].index].append(
chunk.choices[0].text)
if chunk.choices[0].finish_reason is not None:
finish_reason_count += 1
assert finish_reason_count == 2
assert chunk.choices[0].finish_reason == "length"
assert chunk.choices[0].text
assert len(chunks_stream_embeds[0]) > 0
assert len(chunks_stream_embeds[1]) > 0
# Test case: mixed text and prompt_embeds
encoded_embeds = create_dummy_embeds()
completion_mixed = await client_with_prompt_embeds.completions.create(
model=model_name,
prompt="This is a prompt",
max_tokens=5,
temperature=0.0,
extra_body={"prompt_embeds": encoded_embeds})
assert len(completion_mixed.choices) == 2
completion_text_only = await client_with_prompt_embeds.completions.create(
model=model_name,
prompt="This is a prompt",
max_tokens=5,
temperature=0.0,
)
completion_embeds_only = await client_with_prompt_embeds.completions.create(
model=model_name,
prompt="",
max_tokens=5,
temperature=0.0,
extra_body={"prompt_embeds": encoded_embeds})
# Embeddings responses should be handled first
assert completion_mixed.choices[0].text == completion_embeds_only.choices[
0].text
assert completion_mixed.choices[1].text == completion_text_only.choices[
0].text
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_completions_errors_with_prompt_embeds(
client_with_prompt_embeds: openai.AsyncOpenAI, model_name: str):
# Test error case: invalid prompt_embeds
with pytest.raises(BadRequestError):
await client_with_prompt_embeds.completions.create(
prompt="",
model=model_name,
max_tokens=5,
temperature=0.0,
extra_body={"prompt_embeds": "invalid_base64"})
@pytest.mark.asyncio
@pytest.mark.parametrize("logprobs_arg", [1, 0])
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_completions_with_logprobs_and_prompt_embeds(
client_with_prompt_embeds: openai.AsyncOpenAI, logprobs_arg: int,
model_name: str):
# Test case: Logprobs using prompt_embeds
encoded_embeds = create_dummy_embeds()
completion = await client_with_prompt_embeds.completions.create(
model=model_name,
prompt="", # Add empty prompt as required parameter
max_tokens=5,
temperature=0.0,
echo=False,
logprobs=logprobs_arg,
extra_body={"prompt_embeds": encoded_embeds})
logprobs = completion.choices[0].logprobs
assert logprobs is not None
assert len(logprobs.text_offset) == 5
assert len(logprobs.token_logprobs) == 5
assert len(logprobs.top_logprobs) == 5
for top_logprobs in logprobs.top_logprobs[1:]:
assert max(logprobs_arg, 1) <= len(top_logprobs) <= logprobs_arg + 1
assert len(logprobs.tokens) == 5
# Test case: Log probs with batch completion and prompt_embeds
encoded_embeds2 = create_dummy_embeds()
completion = await client_with_prompt_embeds.completions.create(
model=model_name,
prompt="", # Add empty prompt as required parameter
max_tokens=5,
temperature=0.0,
echo=False,
logprobs=logprobs_arg,
extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]})
assert len(completion.choices) == 2
for choice in completion.choices:
logprobs = choice.logprobs
assert logprobs is not None
assert len(logprobs.text_offset) == 5
assert len(logprobs.token_logprobs) == 5
assert len(logprobs.top_logprobs) == 5
for top_logprobs in logprobs.top_logprobs[1:]:
assert max(logprobs_arg,
1) <= len(top_logprobs) <= logprobs_arg + 1
assert len(logprobs.tokens) == 5

View File

@ -2,6 +2,8 @@
from typing import Optional, Union
import torch
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
@ -23,6 +25,7 @@ class RequestLogger:
request_id: str,
prompt: Optional[str],
prompt_token_ids: Optional[list[int]],
prompt_embeds: Optional[torch.Tensor],
params: Optional[Union[SamplingParams, PoolingParams,
BeamSearchParams]],
lora_request: Optional[LoRARequest],
@ -39,6 +42,8 @@ class RequestLogger:
logger.info(
"Received request %s: prompt: %r, "
"params: %s, prompt_token_ids: %s, "
"prompt_embeds shape: %s, "
"lora_request: %s, prompt_adapter_request: %s.", request_id,
prompt, params, prompt_token_ids, lora_request,
prompt_adapter_request)
prompt, params, prompt_token_ids,
prompt_embeds.shape if prompt_embeds is not None else None,
lora_request, prompt_adapter_request)

View File

@ -286,6 +286,9 @@ def validate_parsed_serve_args(args: argparse.Namespace):
if args.enable_auto_tool_choice and not args.tool_call_parser:
raise TypeError("Error: --enable-auto-tool-choice requires "
"--tool-call-parser")
if args.enable_prompt_embeds and args.enable_prompt_adapter:
raise ValueError(
"Cannot use prompt embeds and prompt adapter at the same time.")
def log_non_default_args(args: argparse.Namespace):

View File

@ -745,7 +745,8 @@ class CompletionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/completions/create
model: Optional[str] = None
prompt: Union[list[int], list[list[int]], str, list[str]]
prompt: Optional[Union[list[int], list[list[int]], str, list[str]]] = None
prompt_embeds: Optional[Union[bytes, list[bytes]]] = None
best_of: Optional[int] = None
echo: Optional[bool] = False
frequency_penalty: Optional[float] = 0.0
@ -1025,6 +1026,14 @@ class CompletionRequest(OpenAIBaseModel):
return data
@model_validator(mode="before")
@classmethod
def validate_prompt_and_prompt_embeds(cls, data):
if data.get("prompt") is None and data.get("prompt_embeds") is None:
raise ValueError(
"At least one of `prompt` or `prompt_embeds` must be set.")
return data
class EmbeddingCompletionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation

View File

@ -8,6 +8,7 @@ from typing import Optional, Union, cast
import jinja2
from fastapi import Request
from typing_extensions import assert_never
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
@ -25,8 +26,11 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
UsageInfo)
# yapf: enable
from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
clamp_prompt_logprobs)
clamp_prompt_logprobs,
is_text_tokens_prompt)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.inputs.data import (EmbedsPrompt, TokensPrompt, is_embeds_prompt,
is_tokens_prompt)
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams, SamplingParams
@ -90,6 +94,10 @@ class OpenAIServingCompletion(OpenAIServing):
return self.create_error_response(
"suffix is not currently supported")
if request.echo and request.prompt_embeds is not None:
return self.create_error_response(
"Echo is unsupported with prompt embeds.")
request_id = f"cmpl-{self._base_request_id(raw_request)}"
created_time = int(time.time())
@ -130,8 +138,24 @@ class OpenAIServingCompletion(OpenAIServing):
try:
for i, engine_prompt in enumerate(engine_prompts):
sampling_params: Union[SamplingParams, BeamSearchParams]
default_max_tokens = self.max_model_len - len(
engine_prompt["prompt_token_ids"])
# Mypy does not infer that engine_prompt will have only one of
# "prompt_token_ids" or "prompt_embeds" defined, and both of
# these as Union[object, the expected type], where it infers
# object if engine_prompt is a subclass of one of the
# typeddicts that defines both keys. Worse, because of
# https://github.com/python/mypy/issues/8586, mypy does not
# infer the type of engine_prompt correctly because of the
# enumerate. So we need an unnecessary cast here.
engine_prompt = cast(Union[EmbedsPrompt, TokensPrompt],
engine_prompt)
if is_embeds_prompt(engine_prompt):
input_length = len(engine_prompt["prompt_embeds"])
elif is_tokens_prompt(engine_prompt):
input_length = len(engine_prompt["prompt_token_ids"])
else:
assert_never(engine_prompt)
default_max_tokens = self.max_model_len - input_length
if request.use_beam_search:
sampling_params = request.to_beam_search_params(
default_max_tokens, self.default_sampling_params)
@ -152,6 +176,11 @@ class OpenAIServingCompletion(OpenAIServing):
trace_headers = (None if raw_request is None else await
self._get_trace_headers(raw_request.headers))
# Mypy inconsistently requires this second cast in different
# environments. It shouldn't be necessary (redundant from above)
# but pre-commit in CI fails without it.
engine_prompt = cast(Union[EmbedsPrompt, TokensPrompt],
engine_prompt)
if isinstance(sampling_params, BeamSearchParams):
generator = self.engine_client.beam_search(
prompt=engine_prompt,
@ -211,7 +240,11 @@ class OpenAIServingCompletion(OpenAIServing):
# We did not pass it into vLLM engine to avoid being redundant
# with the inputs token IDs
if final_res.prompt is None:
final_res.prompt = request_prompts[i]["prompt"]
request_prompt = request_prompts[i]
if is_text_tokens_prompt(request_prompt):
final_res.prompt = request_prompt["prompt"]
else:
final_res.prompt = None
final_res_batch_checked = cast(list[RequestOutput],
final_res_batch)
@ -276,8 +309,8 @@ class OpenAIServingCompletion(OpenAIServing):
prompt_text = res.prompt
# Prompt details are excluded from later streamed outputs
if res.prompt_token_ids is not None:
num_prompt_tokens[prompt_idx] = len(res.prompt_token_ids)
if prompt_token_ids is not None:
num_prompt_tokens[prompt_idx] = len(prompt_token_ids)
delta_token_ids: GenericSequence[int]
out_logprobs: Optional[GenericSequence[Optional[dict[

View File

@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
import base64
import io
import json
import sys
import time
@ -8,11 +9,18 @@ from collections.abc import (AsyncGenerator, Iterable, Iterator, Mapping,
from concurrent.futures.thread import ThreadPoolExecutor
from http import HTTPStatus
from typing import (Annotated, Any, Callable, ClassVar, Generic, Optional,
TypeVar, Union)
TypeVar, Union, cast, overload)
import torch
from fastapi import Request
from pydantic import BaseModel, ConfigDict, Field
from starlette.datastructures import Headers
from typing_extensions import TypeIs
if sys.version_info >= (3, 12):
from typing import TypedDict
else:
from typing_extensions import TypedDict
@ -53,7 +61,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.tool_parsers import ToolParser
# yapf: enable
from vllm.inputs import TokensPrompt
from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@ -100,7 +109,22 @@ class TextTokensPrompt(TypedDict):
prompt_token_ids: list[int]
RequestPrompt = Union[list[int], str, TextTokensPrompt]
class EmbedsPrompt(TypedDict):
prompt_embeds: torch.Tensor
RequestPrompt = Union[list[int], str, TextTokensPrompt, EmbedsPrompt]
def is_text_tokens_prompt(prompt: RequestPrompt) -> TypeIs[TextTokensPrompt]:
return (isinstance(prompt, dict) and "prompt_token_ids" in prompt
and "prompt_embeds" not in prompt)
def is_embeds_prompt(prompt: RequestPrompt) -> TypeIs[EmbedsPrompt]:
return (isinstance(prompt, dict) and "prompt_token_ids" not in prompt
and "prompt_embeds" in prompt)
RequestT = TypeVar("RequestT", bound=AnyRequest)
@ -112,8 +136,9 @@ class RequestProcessingMixin(BaseModel):
"""
request_prompts: Optional[Sequence[RequestPrompt]] = \
Field(default_factory=list)
engine_prompts: Optional[list[TokensPrompt]] = \
Field(default_factory=list)
engine_prompts: Optional[Union[list[EngineTokensPrompt],
list[EngineEmbedsPrompt]]] = Field(
default_factory=list)
model_config = ConfigDict(arbitrary_types_allowed=True)
@ -311,6 +336,12 @@ class OpenAIServing:
lora_request=ctx.lora_request,
prompt_adapter_request=ctx.prompt_adapter_request)
# Mypy has an existing bug related to inferring the variance of
# TypedDicts with `builtins.enumerate`:
# https://github.com/python/mypy/issues/8586#issuecomment-2867698435
engine_prompt = cast(
Union[EngineTokensPrompt, EngineEmbedsPrompt],
engine_prompt)
generator = self.engine_client.encode(
engine_prompt,
pooling_params,
@ -596,10 +627,11 @@ class OpenAIServing:
self,
request: AnyRequest,
tokenizer: AnyTokenizer,
input_or_inputs: Union[str, list[str], list[int], list[list[int]]],
input_or_inputs: Optional[Union[str, list[str], list[int],
list[list[int]]]],
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
add_special_tokens: bool = True,
) -> list[TextTokensPrompt]:
) -> tuple[list[TextTokensPrompt], list[EmbedsPrompt]]:
"""
Tokenize/detokenize depending on the input format.
@ -607,11 +639,25 @@ class OpenAIServing:
, each input can be a string or array of tokens. Note that each request
can pass one or more inputs.
"""
inputs_embeds = list[EmbedsPrompt]()
inputs_text = list[TextTokensPrompt]()
if (isinstance(request, CompletionRequest)
and request.prompt_embeds is not None):
inputs_embeds.extend(
self._load_prompt_embeds(request.prompt_embeds,
truncate_prompt_tokens))
# Empty prompts are okay as long as there are prompt embeddings
if input_or_inputs is None or (inputs_embeds
and input_or_inputs == ""):
return [], inputs_embeds
# Although our type checking is based on mypy,
# VSCode Pyright extension should still work properly
# "is True" is required for Pyright to perform type narrowing
# "is False" is required for Pyright to perform type narrowing
# See: https://github.com/microsoft/pyright/issues/7672
return [
inputs_text.extend([
self._normalize_prompt_text_to_input(
request,
tokenizer,
@ -625,29 +671,88 @@ class OpenAIServing:
prompt_ids=prompt_input["content"],
truncate_prompt_tokens=truncate_prompt_tokens)
for prompt_input in parse_and_batch_prompt(input_or_inputs)
]
])
return inputs_text, inputs_embeds
@overload
async def _preprocess_completion(
self,
request: Union[DetokenizeRequest, EmbeddingCompletionRequest,
RerankRequest, ClassificationRequest, ScoreRequest,
TokenizeCompletionRequest],
tokenizer: AnyTokenizer,
input_or_inputs: Union[str, list[str], list[int], list[list[int]]],
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = ...,
add_special_tokens: bool = ...,
) -> tuple[list[TextTokensPrompt], list[EngineTokensPrompt]]:
...
@overload
async def _preprocess_completion(
self,
request: CompletionRequest,
tokenizer: AnyTokenizer,
input_or_inputs: Optional[Union[str, list[str], list[int],
list[list[int]]]],
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = ...,
add_special_tokens: bool = ...,
) -> tuple[list[Union[TextTokensPrompt, EmbedsPrompt]], list[Union[
EngineTokensPrompt, EngineEmbedsPrompt]]]:
...
async def _preprocess_completion(
self,
request: CompletionLikeRequest,
tokenizer: AnyTokenizer,
input_or_inputs: Union[str, list[str], list[int], list[list[int]]],
input_or_inputs: Optional[Union[str, list[str], list[int],
list[list[int]]]],
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
add_special_tokens: bool = True,
) -> tuple[list[TextTokensPrompt], list[TokensPrompt]]:
request_prompts = await self._tokenize_prompt_input_or_inputs_async(
request,
tokenizer,
input_or_inputs,
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=add_special_tokens,
)
) -> tuple[Union[list[TextTokensPrompt], list[Union[
TextTokensPrompt, EmbedsPrompt]]], Union[
list[EngineTokensPrompt], list[Union[EngineTokensPrompt,
EngineEmbedsPrompt]]]]:
if not isinstance(request,
CompletionRequest) and input_or_inputs is None:
raise ValueError(
"Prompt embeds with non-completion requests is not"
" currently supported.")
engine_prompts = [
TokensPrompt(prompt_token_ids=request_prompt["prompt_token_ids"])
for request_prompt in request_prompts
(request_prompts_text, request_prompts_embeds
) = await self._tokenize_prompt_input_or_inputs_async(
request,
tokenizer,
input_or_inputs,
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=add_special_tokens,
)
engine_prompts_text = [
EngineTokensPrompt(
prompt_token_ids=request_prompt_text["prompt_token_ids"])
for request_prompt_text in request_prompts_text
]
# This check is equivalent to simply checking if
# `request_prompts_embeds` is empty, but it's difficult to propagate
# overloads to the private helper functions to enable this check.
# This overload is needed because only TextPrompts are allowed for
# non-completion requests and if we don't add the overload here,
# everywhere this function is used outside of serving_completion will
# need logic asserting that only text prompts are in the request.
if not isinstance(request,
CompletionRequest) and input_or_inputs is not None:
return request_prompts_text, engine_prompts_text
engine_prompts_embeds = [
EngineEmbedsPrompt(
prompt_embeds=request_prompt_embeds["prompt_embeds"])
for request_prompt_embeds in request_prompts_embeds
]
request_prompts = request_prompts_embeds + request_prompts_text
engine_prompts = engine_prompts_embeds + engine_prompts_text
return request_prompts, engine_prompts
async def _preprocess_chat(
@ -666,7 +771,7 @@ class OpenAIServing:
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
add_special_tokens: bool = False,
) -> tuple[list[ConversationMessage], Sequence[RequestPrompt],
list[TokensPrompt]]:
list[EngineTokensPrompt]]:
model_config = self.model_config
resolved_content_format = resolve_chat_template_content_format(
@ -739,7 +844,7 @@ class OpenAIServing:
prompt=tokenizer.decode(request_prompt),
prompt_token_ids=request_prompt)
engine_prompt = TokensPrompt(
engine_prompt = EngineTokensPrompt(
prompt_token_ids=prompt_inputs["prompt_token_ids"])
if mm_data is not None:
engine_prompt["multi_modal_data"] = mm_data
@ -751,6 +856,35 @@ class OpenAIServing:
return conversation, [request_prompt], [engine_prompt]
def _load_prompt_embeds(
self,
prompt_embeds: Optional[Union[bytes, list[bytes]]],
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
) -> list[EmbedsPrompt]:
def _load_and_validate_embed(embed: bytes) -> EmbedsPrompt:
tensor = torch.load(io.BytesIO(base64.b64decode(embed)),
weights_only=True)
assert isinstance(
tensor,
(torch.FloatTensor, torch.BFloat16Tensor, torch.HalfTensor))
if tensor.dim() > 2:
tensor = tensor.squeeze(0)
assert tensor.dim() == 2
if truncate_prompt_tokens is not None:
tensor = tensor[-truncate_prompt_tokens:]
return {"prompt_embeds": tensor}
if prompt_embeds:
if isinstance(prompt_embeds, list):
return [
_load_and_validate_embed(embed) for embed in prompt_embeds
]
else:
return [_load_and_validate_embed(prompt_embeds)]
else:
return []
def _log_inputs(
self,
request_id: str,
@ -762,13 +896,13 @@ class OpenAIServing:
) -> None:
if self.request_logger is None:
return
prompt, prompt_token_ids, prompt_embeds = None, None, None
if isinstance(inputs, str):
prompt = inputs
prompt_token_ids = None
elif isinstance(inputs, list):
prompt = None
prompt_token_ids = inputs
elif 'prompt_embeds' in inputs:
prompt_embeds = inputs.get("prompt_embeds")
else:
prompt = inputs["prompt"]
prompt_token_ids = inputs["prompt_token_ids"]
@ -777,6 +911,7 @@ class OpenAIServing:
request_id,
prompt,
prompt_token_ids,
prompt_embeds,
params=params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,

View File

@ -106,8 +106,9 @@ class OpenAIServingTokenization(OpenAIServing):
# Silently ignore prompt adapter since it does not affect
# tokenization (Unlike in Embeddings API where an error is raised)
input_ids.extend(engine_prompt["prompt_token_ids"])
if isinstance(engine_prompt,
dict) and "prompt_token_ids" in engine_prompt:
input_ids.extend(engine_prompt["prompt_token_ids"])
return TokenizeResponse(tokens=input_ids,
count=len(input_ids),

View File

@ -3,7 +3,7 @@ from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, Generic, Literal, Optional, Union, cast
import torch
from typing_extensions import NotRequired, TypedDict, TypeVar
from typing_extensions import NotRequired, TypedDict, TypeIs, TypeVar
if TYPE_CHECKING:
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalInputs
@ -98,6 +98,17 @@ where the decoder-prompt is not specified explicitly, or
more than one prompt, i.e. {class}`ExplicitEncoderDecoderPrompt`
"""
def is_tokens_prompt(prompt: SingletonPrompt) -> TypeIs[TokensPrompt]:
return (isinstance(prompt, dict) and "prompt_token_ids" in prompt
and "prompt_embeds" not in prompt)
def is_embeds_prompt(prompt: SingletonPrompt) -> TypeIs[EmbedsPrompt]:
return (isinstance(prompt, dict) and "prompt_token_ids" not in prompt
and "prompt_embeds" in prompt)
_T1_co = TypeVar("_T1_co",
bound=SingletonPrompt,
default=SingletonPrompt,