Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 23:03:52 +08:00)
Feature/vllm/input embedding completion api (#17590)
Signed-off-by: Andrew Sansom <andrew@protopia.ai>
Signed-off-by: Nan2018 <nan@protopia.ai>
Co-authored-by: 临景 <linjing.yx@alibaba-inc.com>
Co-authored-by: Bryce1010 <bryceyx@gmail.com>
Co-authored-by: Andrew Sansom <andrew@protopia.ai>
Co-authored-by: Andrew Sansom <qthequartermasterman@gmail.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
@@ -119,6 +119,7 @@ serving/offline_inference
serving/openai_compatible_server
serving/serve_args
serving/multimodal_inputs
serving/prompt_embeds
serving/distributed_serving
serving/metrics
serving/engine_args
docs/source/serving/prompt_embeds.md (new file, 142 lines)
@@ -0,0 +1,142 @@
# Prompt Embedding Inputs

This page teaches you how to pass prompt embedding inputs to vLLM.

## What are prompt embeddings?

The traditional flow of text data for a Large Language Model goes from text to token ids (via a tokenizer) then from token ids to prompt embeddings. For a traditional decoder-only model (such as meta-llama/Llama-3.1-8B-Instruct), this step of converting token ids to prompt embeddings happens via a look-up from a learned embedding matrix, but the model is not limited to processing only the embeddings corresponding to its token vocabulary.
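
As a toy illustration of that look-up (this uses a small, randomly initialized `torch.nn.Embedding` rather than a real model's learned matrix, so the values are meaningless and only the shapes matter):

```python
import torch

# Toy "vocabulary" of 128 tokens, each mapped to a 16-dimensional embedding.
# A real model uses its learned embedding matrix and a much larger hidden size.
embedding_matrix = torch.nn.Embedding(num_embeddings=128, embedding_dim=16)

token_ids = torch.tensor([5, 42, 7, 99])      # (sequence_length,)
prompt_embeds = embedding_matrix(token_ids)   # (sequence_length, hidden_size)

print(prompt_embeds.shape)  # torch.Size([4, 16])
```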
:::{note}
Prompt embeddings are currently only supported in the v0 engine.
:::

## Offline Inference

To input prompt embeddings, follow this schema in {class}`vllm.inputs.EmbedsPrompt`:

- `prompt_embeds`: A torch tensor representing a sequence of prompt/token embeddings. It has the shape `(sequence_length, hidden_size)`, where `sequence_length` is the number of token embeddings and `hidden_size` is the hidden size (embedding size) of the model.
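
As a minimal sketch of that schema (the random tensor below only demonstrates the expected shape, mirroring the dummy embeddings used in vLLM's own tests; meaningful generations require embeddings taken from the model's embedding matrix, as in the Transformers example that follows):

```python
import torch
import transformers
from vllm import LLM

model_name = "meta-llama/Llama-3.2-1B-Instruct"
config = transformers.AutoConfig.from_pretrained(model_name)

llm = LLM(model=model_name, enable_prompt_embeds=True)

# A (sequence_length, hidden_size) tensor. Random values only illustrate the
# expected shape; real inputs come from the model's own embedding layer.
prompt_embeds = torch.randn(8, config.hidden_size)

outputs = llm.generate({"prompt_embeds": prompt_embeds})
print(outputs[0].outputs[0].text)
```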
### Hugging Face Transformers Inputs

You can pass prompt embeddings from Hugging Face Transformers models to the `'prompt_embeds'` field of the prompt embedding dictionary, as shown in the following examples:

```python
from vllm import LLM
import transformers

model_name = "meta-llama/Llama-3.2-1B-Instruct"

# Transformers
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
transformers_model = transformers.AutoModelForCausalLM.from_pretrained(model_name)
embedding_layer = transformers_model.get_input_embeddings()

llm = LLM(model=model_name, enable_prompt_embeds=True)

# Refer to the HuggingFace repo for the correct format to use
chat = [{"role": "user", "content": "Please tell me about the capital of France."}]
token_ids = tokenizer.apply_chat_template(chat, add_generation_prompt=True, return_tensors='pt')

prompt_embeds = embedding_layer(token_ids).squeeze(0)

# Single prompt inference
outputs = llm.generate({
    "prompt_embeds": prompt_embeds,
})

for o in outputs:
    generated_text = o.outputs[0].text
    print(generated_text)

# Batch inference

chats = [
    [{"role": "user", "content": "Please tell me about the capital of France."}],
    [{"role": "user", "content": "When is the day longest during the year?"}],
    [{"role": "user", "content": "Which is bigger, the moon or the sun?"}]
]

token_ids_list = [
    tokenizer.apply_chat_template(chat, add_generation_prompt=True, return_tensors='pt') for chat in chats
]
prompt_embeds_list = [embedding_layer(token_ids).squeeze(0) for token_ids in token_ids_list]

outputs = llm.generate(
    [
        {
            "prompt_embeds": prompt_embeds,
        } for prompt_embeds in prompt_embeds_list
    ]
)

for o in outputs:
    generated_text = o.outputs[0].text
    print(generated_text)
```

## Online Serving

Our OpenAI-compatible server accepts prompt embedding inputs via the [Completions API](https://platform.openai.com/docs/api-reference/completions). Prompt embedding inputs are added via a new `'prompt_embeds'` key in the JSON request body.

When a mixture of `'prompt_embeds'` and `'prompt'` inputs is provided in a single request, the completions for the prompt embeds are always returned first.

Prompt embeddings are passed in as base64-encoded torch tensors.
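
The snippet below is a minimal sketch of what that encoding implies on both sides: the client serializes a tensor with `torch.save` and base64-encodes it, and the server reverses the process with `torch.load(..., weights_only=True)` (as the `_load_prompt_embeds` helper added in this change does). The helper functions here are purely illustrative:

```python
import base64
import io

import torch


def encode_prompt_embeds(tensor: torch.Tensor) -> str:
    """Serialize a (sequence_length, hidden_size) tensor to a base64 string."""
    buffer = io.BytesIO()
    torch.save(tensor, buffer)
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


def decode_prompt_embeds(encoded: str) -> torch.Tensor:
    """Mirror of what the server does when it receives `prompt_embeds`."""
    tensor = torch.load(io.BytesIO(base64.b64decode(encoded)), weights_only=True)
    # The server squeezes a leading batch dimension and expects a 2-D tensor.
    if tensor.dim() > 2:
        tensor = tensor.squeeze(0)
    assert tensor.dim() == 2
    return tensor


# Round-trip example
embeds = torch.randn(5, 2048)
assert torch.equal(decode_prompt_embeds(encode_prompt_embeds(embeds)), embeds)
```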
### Transformers Inputs via OpenAI Client

First, launch the OpenAI-compatible server:

```bash
vllm serve meta-llama/Llama-3.2-1B-Instruct --task generate \
  --max-model-len 4096 --enable-prompt-embeds
```

Then, you can use the OpenAI client as follows:

```python
from openai import OpenAI
import transformers
import torch
import io
import base64

openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"

client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)

model_name = "meta-llama/Llama-3.2-1B-Instruct"

# Transformers
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
transformers_model = transformers.AutoModelForCausalLM.from_pretrained(model_name)
embedding_layer = transformers_model.get_input_embeddings()

# Refer to the HuggingFace repo for the correct format to use
chat = [{"role": "user", "content": "Please tell me about the capital of France."}]
token_ids = tokenizer.apply_chat_template(chat, add_generation_prompt=True, return_tensors='pt')

prompt_embeds = embedding_layer(token_ids).squeeze(0)

# Prompt embeddings
buffer = io.BytesIO()
torch.save(prompt_embeds, buffer)
buffer.seek(0)
binary_data = buffer.read()
encoded_embeds = base64.b64encode(binary_data).decode('utf-8')

completion = client.completions.create(
    model=model_name,
    # NOTE: The OpenAI client does not allow `None` as an input to
    # `prompt`. Use an empty string if you have no text prompts.
    prompt="",
    max_tokens=5,
    temperature=0.0,
    # NOTE: The OpenAI client allows passing in extra JSON body via the
    # `extra_body` argument.
    extra_body={"prompt_embeds": encoded_embeds}
)

print(completion.choices[0].text)
```
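
If you are not using the OpenAI client, the same request can be sent with any HTTP library. The sketch below is illustrative only: it assumes the server launched above, uses the third-party `requests` package, and hard-codes a hidden size of 2048 (replace it with your model's actual hidden size); it posts to the standard OpenAI-compatible completions route:

```python
import base64
import io

import requests
import torch

# Any (sequence_length, hidden_size) tensor, serialized as in the example above.
dummy = torch.randn(5, 2048)  # 2048 is assumed here; use your model's hidden size
buffer = io.BytesIO()
torch.save(dummy, buffer)
encoded_embeds = base64.b64encode(buffer.getvalue()).decode("utf-8")

response = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "meta-llama/Llama-3.2-1B-Instruct",
        "prompt": "",          # empty text prompt; the embeddings carry the input
        "max_tokens": 5,
        "temperature": 0.0,
        "prompt_embeds": encoded_embeds,
    },
)
response.raise_for_status()
print(response.json()["choices"][0]["text"])
```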
tests/entrypoints/openai/test_completion_with_prompt_embeds.py (new file, 257 lines)
@@ -0,0 +1,257 @@
# SPDX-License-Identifier: Apache-2.0

import base64
import io
import shutil
from tempfile import TemporaryDirectory

import openai  # use the official client for correctness check
import pytest
import pytest_asyncio
import torch
# downloading lora to test lora requests
from huggingface_hub import snapshot_download
from openai import BadRequestError
from transformers import AutoConfig, AutoTokenizer

from ...utils import RemoteOpenAIServer

# any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
LORA_NAME = "typeof/zephyr-7b-beta-lora"

CONFIG = AutoConfig.from_pretrained(MODEL_NAME)


@pytest.fixture(scope="module")
def zephyr_lora_files():
    return snapshot_download(repo_id=LORA_NAME)


@pytest.fixture(scope="module")
def zephyr_lora_added_tokens_files(zephyr_lora_files):
    tmp_dir = TemporaryDirectory()
    tmp_model_dir = f"{tmp_dir.name}/zephyr"
    shutil.copytree(zephyr_lora_files, tmp_model_dir)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    # Copy tokenizer to adapter and add some unique tokens
    # 32000, 32001, 32002
    added = tokenizer.add_tokens(["vllm1", "vllm2", "vllm3"],
                                 special_tokens=True)
    assert added == 3
    tokenizer.save_pretrained(tmp_model_dir)
    yield tmp_model_dir
    tmp_dir.cleanup()


@pytest.fixture(scope="module")
def default_server_args(
    zephyr_lora_files,
    zephyr_lora_added_tokens_files,
) -> list[str]:
    return [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "8192",
        "--max-num-seqs",
        "128",
        "--enforce-eager",
        # Prompt Embeds server args
        "--enable-prompt-embeds",
        "--no-enable-chunked-prefill",
    ]


@pytest.fixture(scope="module",
                params=["", "--disable-frontend-multiprocessing"])
def server_with_prompt_embeds(default_server_args, request):
    if request.param:
        default_server_args.append(request.param)

    with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server:
        yield remote_server


@pytest_asyncio.fixture
async def client_with_prompt_embeds(server_with_prompt_embeds):
    async with server_with_prompt_embeds.get_async_client() as async_client:
        yield async_client


def create_dummy_embeds(num_tokens: int = 5) -> str:
    """Create dummy embeddings and return them as base64 encoded string."""
    dummy_embeds = torch.randn(num_tokens, CONFIG.hidden_size)
    buffer = io.BytesIO()
    torch.save(dummy_embeds, buffer)
    return base64.b64encode(buffer.getvalue()).decode('utf-8')


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_completions_with_prompt_embeds(
        client_with_prompt_embeds: openai.AsyncOpenAI, model_name: str):
    # Test case: Single prompt embeds input
    encoded_embeds = create_dummy_embeds()
    completion = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",  # Add empty prompt as required parameter
        max_tokens=5,
        temperature=0.0,
        extra_body={"prompt_embeds": encoded_embeds})
    assert len(completion.choices[0].text) >= 1
    assert completion.choices[0].prompt_logprobs is None

    # Test case: batch completion with prompt_embeds
    encoded_embeds2 = create_dummy_embeds()
    completion = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",  # Add empty prompt as required parameter
        max_tokens=5,
        temperature=0.0,
        extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]})
    assert len(completion.choices) == 2
    assert len(completion.choices[0].text) >= 1
    assert len(completion.choices[1].text) >= 1

    # Test case: streaming with prompt_embeds
    encoded_embeds = create_dummy_embeds()
    single_completion = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",  # Add empty prompt as required parameter
        max_tokens=5,
        temperature=0.0,
        extra_body={"prompt_embeds": encoded_embeds})
    single_output = single_completion.choices[0].text

    stream = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",  # Add empty prompt as required parameter
        max_tokens=5,
        temperature=0.0,
        stream=True,
        extra_body={"prompt_embeds": encoded_embeds})
    chunks = []
    finish_reason_count = 0
    async for chunk in stream:
        chunks.append(chunk.choices[0].text)
        if chunk.choices[0].finish_reason is not None:
            finish_reason_count += 1
    assert finish_reason_count == 1
    assert chunk.choices[0].finish_reason == "length"
    assert chunk.choices[0].text
    assert "".join(chunks) == single_output

    # Test case: batch streaming with prompt_embeds
    encoded_embeds2 = create_dummy_embeds()
    stream = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",  # Add empty prompt as required parameter
        max_tokens=5,
        temperature=0.0,
        stream=True,
        extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]})
    chunks_stream_embeds: list[list[str]] = [[], []]
    finish_reason_count = 0
    async for chunk in stream:
        chunks_stream_embeds[chunk.choices[0].index].append(
            chunk.choices[0].text)
        if chunk.choices[0].finish_reason is not None:
            finish_reason_count += 1
    assert finish_reason_count == 2
    assert chunk.choices[0].finish_reason == "length"
    assert chunk.choices[0].text
    assert len(chunks_stream_embeds[0]) > 0
    assert len(chunks_stream_embeds[1]) > 0

    # Test case: mixed text and prompt_embeds
    encoded_embeds = create_dummy_embeds()
    completion_mixed = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="This is a prompt",
        max_tokens=5,
        temperature=0.0,
        extra_body={"prompt_embeds": encoded_embeds})
    assert len(completion_mixed.choices) == 2
    completion_text_only = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="This is a prompt",
        max_tokens=5,
        temperature=0.0,
    )
    completion_embeds_only = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",
        max_tokens=5,
        temperature=0.0,
        extra_body={"prompt_embeds": encoded_embeds})
    # Embeddings responses should be handled first
    assert completion_mixed.choices[0].text == completion_embeds_only.choices[
        0].text
    assert completion_mixed.choices[1].text == completion_text_only.choices[
        0].text


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_completions_errors_with_prompt_embeds(
        client_with_prompt_embeds: openai.AsyncOpenAI, model_name: str):
    # Test error case: invalid prompt_embeds
    with pytest.raises(BadRequestError):
        await client_with_prompt_embeds.completions.create(
            prompt="",
            model=model_name,
            max_tokens=5,
            temperature=0.0,
            extra_body={"prompt_embeds": "invalid_base64"})


@pytest.mark.asyncio
@pytest.mark.parametrize("logprobs_arg", [1, 0])
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_completions_with_logprobs_and_prompt_embeds(
        client_with_prompt_embeds: openai.AsyncOpenAI, logprobs_arg: int,
        model_name: str):
    # Test case: Logprobs using prompt_embeds
    encoded_embeds = create_dummy_embeds()
    completion = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",  # Add empty prompt as required parameter
        max_tokens=5,
        temperature=0.0,
        echo=False,
        logprobs=logprobs_arg,
        extra_body={"prompt_embeds": encoded_embeds})

    logprobs = completion.choices[0].logprobs
    assert logprobs is not None
    assert len(logprobs.text_offset) == 5
    assert len(logprobs.token_logprobs) == 5
    assert len(logprobs.top_logprobs) == 5
    for top_logprobs in logprobs.top_logprobs[1:]:
        assert max(logprobs_arg, 1) <= len(top_logprobs) <= logprobs_arg + 1
    assert len(logprobs.tokens) == 5

    # Test case: Log probs with batch completion and prompt_embeds
    encoded_embeds2 = create_dummy_embeds()
    completion = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",  # Add empty prompt as required parameter
        max_tokens=5,
        temperature=0.0,
        echo=False,
        logprobs=logprobs_arg,
        extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]})

    assert len(completion.choices) == 2
    for choice in completion.choices:
        logprobs = choice.logprobs
        assert logprobs is not None
        assert len(logprobs.text_offset) == 5
        assert len(logprobs.token_logprobs) == 5
        assert len(logprobs.top_logprobs) == 5
        for top_logprobs in logprobs.top_logprobs[1:]:
            assert max(logprobs_arg,
                       1) <= len(top_logprobs) <= logprobs_arg + 1
        assert len(logprobs.tokens) == 5
@@ -2,6 +2,8 @@

from typing import Optional, Union

import torch

from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
@@ -23,6 +25,7 @@ class RequestLogger:
        request_id: str,
        prompt: Optional[str],
        prompt_token_ids: Optional[list[int]],
        prompt_embeds: Optional[torch.Tensor],
        params: Optional[Union[SamplingParams, PoolingParams,
                               BeamSearchParams]],
        lora_request: Optional[LoRARequest],
@@ -39,6 +42,8 @@ class RequestLogger:
        logger.info(
            "Received request %s: prompt: %r, "
            "params: %s, prompt_token_ids: %s, "
            "prompt_embeds shape: %s, "
            "lora_request: %s, prompt_adapter_request: %s.", request_id,
            prompt, params, prompt_token_ids, lora_request,
            prompt_adapter_request)
            prompt, params, prompt_token_ids,
            prompt_embeds.shape if prompt_embeds is not None else None,
            lora_request, prompt_adapter_request)
@@ -286,6 +286,9 @@ def validate_parsed_serve_args(args: argparse.Namespace):
    if args.enable_auto_tool_choice and not args.tool_call_parser:
        raise TypeError("Error: --enable-auto-tool-choice requires "
                        "--tool-call-parser")
    if args.enable_prompt_embeds and args.enable_prompt_adapter:
        raise ValueError(
            "Cannot use prompt embeds and prompt adapter at the same time.")


def log_non_default_args(args: argparse.Namespace):
@@ -745,7 +745,8 @@ class CompletionRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/completions/create
    model: Optional[str] = None
    prompt: Union[list[int], list[list[int]], str, list[str]]
    prompt: Optional[Union[list[int], list[list[int]], str, list[str]]] = None
    prompt_embeds: Optional[Union[bytes, list[bytes]]] = None
    best_of: Optional[int] = None
    echo: Optional[bool] = False
    frequency_penalty: Optional[float] = 0.0
@@ -1025,6 +1026,14 @@ class CompletionRequest(OpenAIBaseModel):

        return data

    @model_validator(mode="before")
    @classmethod
    def validate_prompt_and_prompt_embeds(cls, data):
        if data.get("prompt") is None and data.get("prompt_embeds") is None:
            raise ValueError(
                "At least one of `prompt` or `prompt_embeds` must be set.")
        return data


class EmbeddingCompletionRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
@@ -8,6 +8,7 @@ from typing import Optional, Union, cast

import jinja2
from fastapi import Request
from typing_extensions import assert_never

from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
@@ -25,8 +26,11 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
                                              UsageInfo)
# yapf: enable
from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
                                                    clamp_prompt_logprobs)
                                                    clamp_prompt_logprobs,
                                                    is_text_tokens_prompt)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.inputs.data import (EmbedsPrompt, TokensPrompt, is_embeds_prompt,
                              is_tokens_prompt)
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams, SamplingParams
@@ -90,6 +94,10 @@ class OpenAIServingCompletion(OpenAIServing):
            return self.create_error_response(
                "suffix is not currently supported")

        if request.echo and request.prompt_embeds is not None:
            return self.create_error_response(
                "Echo is unsupported with prompt embeds.")

        request_id = f"cmpl-{self._base_request_id(raw_request)}"
        created_time = int(time.time())

@@ -130,8 +138,24 @@ class OpenAIServingCompletion(OpenAIServing):
        try:
            for i, engine_prompt in enumerate(engine_prompts):
                sampling_params: Union[SamplingParams, BeamSearchParams]
                default_max_tokens = self.max_model_len - len(
                    engine_prompt["prompt_token_ids"])
                # Mypy does not infer that engine_prompt will have only one of
                # "prompt_token_ids" or "prompt_embeds" defined, and both of
                # these as Union[object, the expected type], where it infers
                # object if engine_prompt is a subclass of one of the
                # typeddicts that defines both keys. Worse, because of
                # https://github.com/python/mypy/issues/8586, mypy does not
                # infer the type of engine_prompt correctly because of the
                # enumerate. So we need an unnecessary cast here.
                engine_prompt = cast(Union[EmbedsPrompt, TokensPrompt],
                                     engine_prompt)
                if is_embeds_prompt(engine_prompt):
                    input_length = len(engine_prompt["prompt_embeds"])
                elif is_tokens_prompt(engine_prompt):
                    input_length = len(engine_prompt["prompt_token_ids"])
                else:
                    assert_never(engine_prompt)
                default_max_tokens = self.max_model_len - input_length

                if request.use_beam_search:
                    sampling_params = request.to_beam_search_params(
                        default_max_tokens, self.default_sampling_params)
@@ -152,6 +176,11 @@ class OpenAIServingCompletion(OpenAIServing):
                trace_headers = (None if raw_request is None else await
                                 self._get_trace_headers(raw_request.headers))

                # Mypy inconsistently requires this second cast in different
                # environments. It shouldn't be necessary (redundant from above)
                # but pre-commit in CI fails without it.
                engine_prompt = cast(Union[EmbedsPrompt, TokensPrompt],
                                     engine_prompt)
                if isinstance(sampling_params, BeamSearchParams):
                    generator = self.engine_client.beam_search(
                        prompt=engine_prompt,
@@ -211,7 +240,11 @@ class OpenAIServingCompletion(OpenAIServing):
                # We did not pass it into vLLM engine to avoid being redundant
                # with the inputs token IDs
                if final_res.prompt is None:
                    final_res.prompt = request_prompts[i]["prompt"]
                request_prompt = request_prompts[i]
                if is_text_tokens_prompt(request_prompt):
                    final_res.prompt = request_prompt["prompt"]
                else:
                    final_res.prompt = None

            final_res_batch_checked = cast(list[RequestOutput],
                                           final_res_batch)
@@ -276,8 +309,8 @@ class OpenAIServingCompletion(OpenAIServing):
                prompt_text = res.prompt

                # Prompt details are excluded from later streamed outputs
                if res.prompt_token_ids is not None:
                    num_prompt_tokens[prompt_idx] = len(res.prompt_token_ids)
                if prompt_token_ids is not None:
                    num_prompt_tokens[prompt_idx] = len(prompt_token_ids)

                delta_token_ids: GenericSequence[int]
                out_logprobs: Optional[GenericSequence[Optional[dict[
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

import base64
import io
import json
import sys
import time
@@ -8,11 +9,18 @@ from collections.abc import (AsyncGenerator, Iterable, Iterator, Mapping,
from concurrent.futures.thread import ThreadPoolExecutor
from http import HTTPStatus
from typing import (Annotated, Any, Callable, ClassVar, Generic, Optional,
                    TypeVar, Union)
                    TypeVar, Union, cast, overload)

import torch
from fastapi import Request
from pydantic import BaseModel, ConfigDict, Field
from starlette.datastructures import Headers
from typing_extensions import TypeIs

if sys.version_info >= (3, 12):
    from typing import TypedDict
else:
    from typing_extensions import TypedDict

if sys.version_info >= (3, 12):
    from typing import TypedDict
@@ -53,7 +61,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.tool_parsers import ToolParser
# yapf: enable
from vllm.inputs import TokensPrompt
from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@@ -100,7 +109,22 @@ class TextTokensPrompt(TypedDict):
    prompt_token_ids: list[int]


RequestPrompt = Union[list[int], str, TextTokensPrompt]


class EmbedsPrompt(TypedDict):
    prompt_embeds: torch.Tensor


RequestPrompt = Union[list[int], str, TextTokensPrompt, EmbedsPrompt]


def is_text_tokens_prompt(prompt: RequestPrompt) -> TypeIs[TextTokensPrompt]:
    return (isinstance(prompt, dict) and "prompt_token_ids" in prompt
            and "prompt_embeds" not in prompt)


def is_embeds_prompt(prompt: RequestPrompt) -> TypeIs[EmbedsPrompt]:
    return (isinstance(prompt, dict) and "prompt_token_ids" not in prompt
            and "prompt_embeds" in prompt)


RequestT = TypeVar("RequestT", bound=AnyRequest)
@@ -112,8 +136,9 @@ class RequestProcessingMixin(BaseModel):
    """
    request_prompts: Optional[Sequence[RequestPrompt]] = \
        Field(default_factory=list)
    engine_prompts: Optional[list[TokensPrompt]] = \
        Field(default_factory=list)
    engine_prompts: Optional[Union[list[EngineTokensPrompt],
                                   list[EngineEmbedsPrompt]]] = Field(
                                       default_factory=list)

    model_config = ConfigDict(arbitrary_types_allowed=True)

@@ -311,6 +336,12 @@ class OpenAIServing:
                lora_request=ctx.lora_request,
                prompt_adapter_request=ctx.prompt_adapter_request)

            # Mypy has an existing bug related to inferring the variance of
            # TypedDicts with `builtins.enumerate`:
            # https://github.com/python/mypy/issues/8586#issuecomment-2867698435
            engine_prompt = cast(
                Union[EngineTokensPrompt, EngineEmbedsPrompt],
                engine_prompt)
            generator = self.engine_client.encode(
                engine_prompt,
                pooling_params,
@@ -596,10 +627,11 @@ class OpenAIServing:
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        input_or_inputs: Union[str, list[str], list[int], list[list[int]]],
        input_or_inputs: Optional[Union[str, list[str], list[int],
                                        list[list[int]]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
        add_special_tokens: bool = True,
    ) -> list[TextTokensPrompt]:
    ) -> tuple[list[TextTokensPrompt], list[EmbedsPrompt]]:
        """
        Tokenize/detokenize depending on the input format.

@@ -607,11 +639,25 @@ class OpenAIServing:
        , each input can be a string or array of tokens. Note that each request
        can pass one or more inputs.
        """
        inputs_embeds = list[EmbedsPrompt]()
        inputs_text = list[TextTokensPrompt]()

        if (isinstance(request, CompletionRequest)
                and request.prompt_embeds is not None):
            inputs_embeds.extend(
                self._load_prompt_embeds(request.prompt_embeds,
                                         truncate_prompt_tokens))

        # Empty prompts are okay as long as there are prompt embeddings
        if input_or_inputs is None or (inputs_embeds
                                       and input_or_inputs == ""):
            return [], inputs_embeds

        # Although our type checking is based on mypy,
        # VSCode Pyright extension should still work properly
        # "is True" is required for Pyright to perform type narrowing
        # "is False" is required for Pyright to perform type narrowing
        # See: https://github.com/microsoft/pyright/issues/7672
        return [
        inputs_text.extend([
            self._normalize_prompt_text_to_input(
                request,
                tokenizer,
@@ -625,29 +671,88 @@ class OpenAIServing:
                prompt_ids=prompt_input["content"],
                truncate_prompt_tokens=truncate_prompt_tokens)
            for prompt_input in parse_and_batch_prompt(input_or_inputs)
        ]
        ])

        return inputs_text, inputs_embeds

    @overload
    async def _preprocess_completion(
        self,
        request: Union[DetokenizeRequest, EmbeddingCompletionRequest,
                       RerankRequest, ClassificationRequest, ScoreRequest,
                       TokenizeCompletionRequest],
        tokenizer: AnyTokenizer,
        input_or_inputs: Union[str, list[str], list[int], list[list[int]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = ...,
        add_special_tokens: bool = ...,
    ) -> tuple[list[TextTokensPrompt], list[EngineTokensPrompt]]:
        ...

    @overload
    async def _preprocess_completion(
        self,
        request: CompletionRequest,
        tokenizer: AnyTokenizer,
        input_or_inputs: Optional[Union[str, list[str], list[int],
                                        list[list[int]]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = ...,
        add_special_tokens: bool = ...,
    ) -> tuple[list[Union[TextTokensPrompt, EmbedsPrompt]], list[Union[
            EngineTokensPrompt, EngineEmbedsPrompt]]]:
        ...

    async def _preprocess_completion(
        self,
        request: CompletionLikeRequest,
        tokenizer: AnyTokenizer,
        input_or_inputs: Union[str, list[str], list[int], list[list[int]]],
        input_or_inputs: Optional[Union[str, list[str], list[int],
                                        list[list[int]]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
        add_special_tokens: bool = True,
    ) -> tuple[list[TextTokensPrompt], list[TokensPrompt]]:
        request_prompts = await self._tokenize_prompt_input_or_inputs_async(
            request,
            tokenizer,
            input_or_inputs,
            truncate_prompt_tokens=truncate_prompt_tokens,
            add_special_tokens=add_special_tokens,
        )
    ) -> tuple[Union[list[TextTokensPrompt], list[Union[
            TextTokensPrompt, EmbedsPrompt]]], Union[
                list[EngineTokensPrompt], list[Union[EngineTokensPrompt,
                                                     EngineEmbedsPrompt]]]]:
        if not isinstance(request,
                          CompletionRequest) and input_or_inputs is None:
            raise ValueError(
                "Prompt embeds with non-completion requests is not"
                " currently supported.")

        engine_prompts = [
            TokensPrompt(prompt_token_ids=request_prompt["prompt_token_ids"])
            for request_prompt in request_prompts
        (request_prompts_text, request_prompts_embeds
         ) = await self._tokenize_prompt_input_or_inputs_async(
             request,
             tokenizer,
             input_or_inputs,
             truncate_prompt_tokens=truncate_prompt_tokens,
             add_special_tokens=add_special_tokens,
         )

        engine_prompts_text = [
            EngineTokensPrompt(
                prompt_token_ids=request_prompt_text["prompt_token_ids"])
            for request_prompt_text in request_prompts_text
        ]

        # This check is equivalent to simply checking if
        # `request_prompts_embeds` is empty, but it's difficult to propagate
        # overloads to the private helper functions to enable this check.
        # This overload is needed because only TextPrompts are allowed for
        # non-completion requests and if we don't add the overload here,
        # everywhere this function is used outside of serving_completion will
        # need logic asserting that only text prompts are in the request.
        if not isinstance(request,
                          CompletionRequest) and input_or_inputs is not None:
            return request_prompts_text, engine_prompts_text

        engine_prompts_embeds = [
            EngineEmbedsPrompt(
                prompt_embeds=request_prompt_embeds["prompt_embeds"])
            for request_prompt_embeds in request_prompts_embeds
        ]

        request_prompts = request_prompts_embeds + request_prompts_text
        engine_prompts = engine_prompts_embeds + engine_prompts_text
        return request_prompts, engine_prompts

    async def _preprocess_chat(
@@ -666,7 +771,7 @@ class OpenAIServing:
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = False,
    ) -> tuple[list[ConversationMessage], Sequence[RequestPrompt],
               list[TokensPrompt]]:
               list[EngineTokensPrompt]]:
        model_config = self.model_config

        resolved_content_format = resolve_chat_template_content_format(
@@ -739,7 +844,7 @@ class OpenAIServing:
                prompt=tokenizer.decode(request_prompt),
                prompt_token_ids=request_prompt)

        engine_prompt = TokensPrompt(
        engine_prompt = EngineTokensPrompt(
            prompt_token_ids=prompt_inputs["prompt_token_ids"])
        if mm_data is not None:
            engine_prompt["multi_modal_data"] = mm_data
@@ -751,6 +856,35 @@ class OpenAIServing:

        return conversation, [request_prompt], [engine_prompt]

    def _load_prompt_embeds(
        self,
        prompt_embeds: Optional[Union[bytes, list[bytes]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
    ) -> list[EmbedsPrompt]:

        def _load_and_validate_embed(embed: bytes) -> EmbedsPrompt:
            tensor = torch.load(io.BytesIO(base64.b64decode(embed)),
                                weights_only=True)
            assert isinstance(
                tensor,
                (torch.FloatTensor, torch.BFloat16Tensor, torch.HalfTensor))
            if tensor.dim() > 2:
                tensor = tensor.squeeze(0)
            assert tensor.dim() == 2
            if truncate_prompt_tokens is not None:
                tensor = tensor[-truncate_prompt_tokens:]
            return {"prompt_embeds": tensor}

        if prompt_embeds:
            if isinstance(prompt_embeds, list):
                return [
                    _load_and_validate_embed(embed) for embed in prompt_embeds
                ]
            else:
                return [_load_and_validate_embed(prompt_embeds)]
        else:
            return []

    def _log_inputs(
        self,
        request_id: str,
@@ -762,13 +896,13 @@ class OpenAIServing:
    ) -> None:
        if self.request_logger is None:
            return

        prompt, prompt_token_ids, prompt_embeds = None, None, None
        if isinstance(inputs, str):
            prompt = inputs
            prompt_token_ids = None
        elif isinstance(inputs, list):
            prompt = None
            prompt_token_ids = inputs
        elif 'prompt_embeds' in inputs:
            prompt_embeds = inputs.get("prompt_embeds")
        else:
            prompt = inputs["prompt"]
            prompt_token_ids = inputs["prompt_token_ids"]
@@ -777,6 +911,7 @@ class OpenAIServing:
            request_id,
            prompt,
            prompt_token_ids,
            prompt_embeds,
            params=params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
|
@ -106,8 +106,9 @@ class OpenAIServingTokenization(OpenAIServing):
|
||||
|
||||
# Silently ignore prompt adapter since it does not affect
|
||||
# tokenization (Unlike in Embeddings API where an error is raised)
|
||||
|
||||
input_ids.extend(engine_prompt["prompt_token_ids"])
|
||||
if isinstance(engine_prompt,
|
||||
dict) and "prompt_token_ids" in engine_prompt:
|
||||
input_ids.extend(engine_prompt["prompt_token_ids"])
|
||||
|
||||
return TokenizeResponse(tokens=input_ids,
|
||||
count=len(input_ids),
|
||||
|
@@ -3,7 +3,7 @@ from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, Generic, Literal, Optional, Union, cast

import torch
from typing_extensions import NotRequired, TypedDict, TypeVar
from typing_extensions import NotRequired, TypedDict, TypeIs, TypeVar

if TYPE_CHECKING:
    from vllm.multimodal.inputs import MultiModalDataDict, MultiModalInputs
@@ -98,6 +98,17 @@ where the decoder-prompt is not specified explicitly, or
more than one prompt, i.e. {class}`ExplicitEncoderDecoderPrompt`
"""


def is_tokens_prompt(prompt: SingletonPrompt) -> TypeIs[TokensPrompt]:
    return (isinstance(prompt, dict) and "prompt_token_ids" in prompt
            and "prompt_embeds" not in prompt)


def is_embeds_prompt(prompt: SingletonPrompt) -> TypeIs[EmbedsPrompt]:
    return (isinstance(prompt, dict) and "prompt_token_ids" not in prompt
            and "prompt_embeds" in prompt)


_T1_co = TypeVar("_T1_co",
                 bound=SingletonPrompt,
                 default=SingletonPrompt,