diff --git a/docs/source/index.md b/docs/source/index.md
index 0470a43a95..7e5b73c968 100644
--- a/docs/source/index.md
+++ b/docs/source/index.md
@@ -119,6 +119,7 @@ serving/offline_inference
 serving/openai_compatible_server
 serving/serve_args
 serving/multimodal_inputs
+serving/prompt_embeds
 serving/distributed_serving
 serving/metrics
 serving/engine_args
diff --git a/docs/source/serving/prompt_embeds.md b/docs/source/serving/prompt_embeds.md
new file mode 100644
index 0000000000..483ca16648
--- /dev/null
+++ b/docs/source/serving/prompt_embeds.md
@@ -0,0 +1,142 @@
+# Prompt Embedding Inputs
+
+This page teaches you how to pass prompt embedding inputs to vLLM.
+
+## What are prompt embeddings?
+
+The traditional flow of text data for a Large Language Model goes from text to token ids (via a tokenizer), then from token ids to prompt embeddings. For a traditional decoder-only model (such as meta-llama/Llama-3.1-8B-Instruct), this step of converting token ids to prompt embeddings happens via a look-up into a learned embedding matrix, but the model is not limited to processing only the embeddings corresponding to its token vocabulary.
+
+:::{note}
+Prompt embeddings are currently only supported in the v0 engine.
+:::
+
+## Offline Inference
+
+To input prompt embeddings, follow this schema in {class}`vllm.inputs.EmbedsPrompt`:
+
+- `prompt_embeds`: A torch tensor representing a sequence of prompt/token embeddings. It has the shape (sequence_length, hidden_size), where sequence_length is the number of token embeddings and hidden_size is the hidden size (embedding size) of the model.
+
+### Hugging Face Transformers Inputs
+
+You can pass prompt embeddings from Hugging Face Transformers models to the `'prompt_embeds'` field of the prompt embedding dictionary, as shown in the following examples:
+
+```python
+from vllm import LLM
+import transformers
+
+model_name = "meta-llama/Llama-3.2-1B-Instruct"
+
+# Transformers
+tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
+transformers_model = transformers.AutoModelForCausalLM.from_pretrained(model_name)
+# The model's input embedding layer maps token ids to prompt embeddings
+embedding_layer = transformers_model.get_input_embeddings()
+
+llm = LLM(model=model_name, enable_prompt_embeds=True)
+
+# Refer to the HuggingFace repo for the correct format to use
+chat = [{"role": "user", "content": "Please tell me about the capital of France."}]
+token_ids = tokenizer.apply_chat_template(chat, add_generation_prompt=True, return_tensors='pt')
+
+prompt_embeds = embedding_layer(token_ids).squeeze(0)
+
+# Single prompt inference
+outputs = llm.generate({
+    "prompt_embeds": prompt_embeds,
+})
+
+for o in outputs:
+    generated_text = o.outputs[0].text
+    print(generated_text)
+
+# Batch inference
+
+chats = [
+    [{"role": "user", "content": "Please tell me about the capital of France."}],
+    [{"role": "user", "content": "When is the day longest during the year?"}],
+    [{"role": "user", "content": "Which is bigger, the moon or the sun?"}]
+]
+
+token_ids_list = [
+    tokenizer.apply_chat_template(chat, add_generation_prompt=True, return_tensors='pt') for chat in chats
+]
+prompt_embeds_list = [embedding_layer(token_ids).squeeze(0) for token_ids in token_ids_list]
+
+outputs = llm.generate(
+    [
+        {
+            "prompt_embeds": prompt_embeds,
+        } for prompt_embeds in prompt_embeds_list
+    ]
+)
+
+for o in outputs:
+    generated_text = o.outputs[0].text
+    print(generated_text)
+```
+
+## Online Serving
+
+Our OpenAI-compatible server accepts prompt embedding inputs via the [Completions API](https://platform.openai.com/docs/api-reference/completions). Prompt embedding inputs are passed via a new `'prompt_embeds'` key in the JSON request body.
+
+When a mixture of `'prompt_embeds'` and `'prompt'` inputs is provided in a single request, the completions for the prompt embeds are always returned first.
+
+Prompt embeddings are passed in as base64-encoded torch tensors.
+
+### Transformers Inputs via OpenAI Client
+
+First, launch the OpenAI-compatible server:
+
+```bash
+vllm serve meta-llama/Llama-3.2-1B-Instruct --task generate \
+  --max-model-len 4096 --enable-prompt-embeds
+```
+
+Then, you can use the OpenAI client as follows:
+
+```python
+from openai import OpenAI
+import base64
+import io
+import transformers
+import torch
+
+openai_api_key = "EMPTY"
+openai_api_base = "http://localhost:8000/v1"
+
+client = OpenAI(
+    api_key=openai_api_key,
+    base_url=openai_api_base,
+)
+
+model_name = "meta-llama/Llama-3.2-1B-Instruct"
+
+# Transformers
+tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
+transformers_model = transformers.AutoModelForCausalLM.from_pretrained(model_name)
+# The model's input embedding layer maps token ids to prompt embeddings
+embedding_layer = transformers_model.get_input_embeddings()
+
+
+# Refer to the HuggingFace repo for the correct format to use
+chat = [{"role": "user", "content": "Please tell me about the capital of France."}]
+token_ids = tokenizer.apply_chat_template(chat, add_generation_prompt=True, return_tensors='pt')
+
+prompt_embeds = embedding_layer(token_ids).squeeze(0)
+
+# Encode the prompt embeddings as a base64 string
+buffer = io.BytesIO()
+torch.save(prompt_embeds, buffer)
+buffer.seek(0)
+binary_data = buffer.read()
+encoded_embeds = base64.b64encode(binary_data).decode('utf-8')
+
+
+completion = client.completions.create(
+    model=model_name,
+    # NOTE: The OpenAI client does not allow `None` as an input to
+    # `prompt`. Use an empty string if you have no text prompts.
+    prompt="",
+    max_tokens=5,
+    temperature=0.0,
+    # NOTE: The OpenAI client allows passing in extra JSON body via the
+    # `extra_body` argument.
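+    # NOTE: `prompt_embeds` also accepts a list of base64-encoded tensors
+    # (e.g. {"prompt_embeds": [encoded_embeds, encoded_embeds_2]}), which
+    # requests a batch of completions in a single call.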
+ extra_body={"prompt_embeds": encoded_embeds} +) + +print(completion.choices[0].text) +``` diff --git a/tests/entrypoints/openai/test_completion_with_prompt_embeds.py b/tests/entrypoints/openai/test_completion_with_prompt_embeds.py new file mode 100644 index 0000000000..b7ee3e33c2 --- /dev/null +++ b/tests/entrypoints/openai/test_completion_with_prompt_embeds.py @@ -0,0 +1,257 @@ +# SPDX-License-Identifier: Apache-2.0 + +import base64 +import io +import shutil +from tempfile import TemporaryDirectory + +import openai # use the official client for correctness check +import pytest +import pytest_asyncio +import torch +# downloading lora to test lora requests +from huggingface_hub import snapshot_download +from openai import BadRequestError +from transformers import AutoConfig, AutoTokenizer + +from ...utils import RemoteOpenAIServer + +# any model with a chat template should work here +MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" +LORA_NAME = "typeof/zephyr-7b-beta-lora" + +CONFIG = AutoConfig.from_pretrained(MODEL_NAME) + + +@pytest.fixture(scope="module") +def zephyr_lora_files(): + return snapshot_download(repo_id=LORA_NAME) + + +@pytest.fixture(scope="module") +def zephyr_lora_added_tokens_files(zephyr_lora_files): + tmp_dir = TemporaryDirectory() + tmp_model_dir = f"{tmp_dir.name}/zephyr" + shutil.copytree(zephyr_lora_files, tmp_model_dir) + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + # Copy tokenizer to adapter and add some unique tokens + # 32000, 32001, 32002 + added = tokenizer.add_tokens(["vllm1", "vllm2", "vllm3"], + special_tokens=True) + assert added == 3 + tokenizer.save_pretrained(tmp_model_dir) + yield tmp_model_dir + tmp_dir.cleanup() + + +@pytest.fixture(scope="module") +def default_server_args( + zephyr_lora_files, + zephyr_lora_added_tokens_files, +) -> list[str]: + return [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "bfloat16", + "--max-model-len", + "8192", + "--max-num-seqs", + "128", + "--enforce-eager", + # Prompt Embeds server args + "--enable-prompt-embeds", + "--no-enable-chunked-prefill", + ] + + +@pytest.fixture(scope="module", + params=["", "--disable-frontend-multiprocessing"]) +def server_with_prompt_embeds(default_server_args, request): + if request.param: + default_server_args.append(request.param) + + with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def client_with_prompt_embeds(server_with_prompt_embeds): + async with server_with_prompt_embeds.get_async_client() as async_client: + yield async_client + + +def create_dummy_embeds(num_tokens: int = 5) -> str: + """Create dummy embeddings and return them as base64 encoded string.""" + dummy_embeds = torch.randn(num_tokens, CONFIG.hidden_size) + buffer = io.BytesIO() + torch.save(dummy_embeds, buffer) + return base64.b64encode(buffer.getvalue()).decode('utf-8') + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_completions_with_prompt_embeds( + client_with_prompt_embeds: openai.AsyncOpenAI, model_name: str): + # Test case: Single prompt embeds input + encoded_embeds = create_dummy_embeds() + completion = await client_with_prompt_embeds.completions.create( + model=model_name, + prompt="", # Add empty prompt as required parameter + max_tokens=5, + temperature=0.0, + extra_body={"prompt_embeds": encoded_embeds}) + assert len(completion.choices[0].text) >= 1 + assert completion.choices[0].prompt_logprobs is None + + # Test case: batch 
completion with prompt_embeds + encoded_embeds2 = create_dummy_embeds() + completion = await client_with_prompt_embeds.completions.create( + model=model_name, + prompt="", # Add empty prompt as required parameter + max_tokens=5, + temperature=0.0, + extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]}) + assert len(completion.choices) == 2 + assert len(completion.choices[0].text) >= 1 + assert len(completion.choices[1].text) >= 1 + + # Test case: streaming with prompt_embeds + encoded_embeds = create_dummy_embeds() + single_completion = await client_with_prompt_embeds.completions.create( + model=model_name, + prompt="", # Add empty prompt as required parameter + max_tokens=5, + temperature=0.0, + extra_body={"prompt_embeds": encoded_embeds}) + single_output = single_completion.choices[0].text + + stream = await client_with_prompt_embeds.completions.create( + model=model_name, + prompt="", # Add empty prompt as required parameter + max_tokens=5, + temperature=0.0, + stream=True, + extra_body={"prompt_embeds": encoded_embeds}) + chunks = [] + finish_reason_count = 0 + async for chunk in stream: + chunks.append(chunk.choices[0].text) + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + assert finish_reason_count == 1 + assert chunk.choices[0].finish_reason == "length" + assert chunk.choices[0].text + assert "".join(chunks) == single_output + + # Test case: batch streaming with prompt_embeds + encoded_embeds2 = create_dummy_embeds() + stream = await client_with_prompt_embeds.completions.create( + model=model_name, + prompt="", # Add empty prompt as required parameter + max_tokens=5, + temperature=0.0, + stream=True, + extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]}) + chunks_stream_embeds: list[list[str]] = [[], []] + finish_reason_count = 0 + async for chunk in stream: + chunks_stream_embeds[chunk.choices[0].index].append( + chunk.choices[0].text) + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + assert finish_reason_count == 2 + assert chunk.choices[0].finish_reason == "length" + assert chunk.choices[0].text + assert len(chunks_stream_embeds[0]) > 0 + assert len(chunks_stream_embeds[1]) > 0 + + # Test case: mixed text and prompt_embeds + encoded_embeds = create_dummy_embeds() + completion_mixed = await client_with_prompt_embeds.completions.create( + model=model_name, + prompt="This is a prompt", + max_tokens=5, + temperature=0.0, + extra_body={"prompt_embeds": encoded_embeds}) + assert len(completion.choices) == 2 + completion_text_only = await client_with_prompt_embeds.completions.create( + model=model_name, + prompt="This is a prompt", + max_tokens=5, + temperature=0.0, + ) + completion_embeds_only = await client_with_prompt_embeds.completions.create( + model=model_name, + prompt="", + max_tokens=5, + temperature=0.0, + extra_body={"prompt_embeds": encoded_embeds}) + # Embeddings responses should be handled first + assert completion_mixed.choices[0].text == completion_embeds_only.choices[ + 0].text + assert completion_mixed.choices[1].text == completion_text_only.choices[ + 0].text + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_completions_errors_with_prompt_embeds( + client_with_prompt_embeds: openai.AsyncOpenAI, model_name: str): + # Test error case: invalid prompt_embeds + with pytest.raises(BadRequestError): + await client_with_prompt_embeds.completions.create( + prompt="", + model=model_name, + max_tokens=5, + temperature=0.0, + extra_body={"prompt_embeds": 
"invalid_base64"}) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("logprobs_arg", [1, 0]) +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_completions_with_logprobs_and_prompt_embeds( + client_with_prompt_embeds: openai.AsyncOpenAI, logprobs_arg: int, + model_name: str): + # Test case: Logprobs using prompt_embeds + encoded_embeds = create_dummy_embeds() + completion = await client_with_prompt_embeds.completions.create( + model=model_name, + prompt="", # Add empty prompt as required parameter + max_tokens=5, + temperature=0.0, + echo=False, + logprobs=logprobs_arg, + extra_body={"prompt_embeds": encoded_embeds}) + + logprobs = completion.choices[0].logprobs + assert logprobs is not None + assert len(logprobs.text_offset) == 5 + assert len(logprobs.token_logprobs) == 5 + assert len(logprobs.top_logprobs) == 5 + for top_logprobs in logprobs.top_logprobs[1:]: + assert max(logprobs_arg, 1) <= len(top_logprobs) <= logprobs_arg + 1 + assert len(logprobs.tokens) == 5 + + # Test case: Log probs with batch completion and prompt_embeds + encoded_embeds2 = create_dummy_embeds() + completion = await client_with_prompt_embeds.completions.create( + model=model_name, + prompt="", # Add empty prompt as required parameter + max_tokens=5, + temperature=0.0, + echo=False, + logprobs=logprobs_arg, + extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]}) + + assert len(completion.choices) == 2 + for choice in completion.choices: + logprobs = choice.logprobs + assert logprobs is not None + assert len(logprobs.text_offset) == 5 + assert len(logprobs.token_logprobs) == 5 + assert len(logprobs.top_logprobs) == 5 + for top_logprobs in logprobs.top_logprobs[1:]: + assert max(logprobs_arg, + 1) <= len(top_logprobs) <= logprobs_arg + 1 + assert len(logprobs.tokens) == 5 diff --git a/vllm/entrypoints/logger.py b/vllm/entrypoints/logger.py index ea5759152a..d4655dd5e6 100644 --- a/vllm/entrypoints/logger.py +++ b/vllm/entrypoints/logger.py @@ -2,6 +2,8 @@ from typing import Optional, Union +import torch + from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.pooling_params import PoolingParams @@ -23,6 +25,7 @@ class RequestLogger: request_id: str, prompt: Optional[str], prompt_token_ids: Optional[list[int]], + prompt_embeds: Optional[torch.Tensor], params: Optional[Union[SamplingParams, PoolingParams, BeamSearchParams]], lora_request: Optional[LoRARequest], @@ -39,6 +42,8 @@ class RequestLogger: logger.info( "Received request %s: prompt: %r, " "params: %s, prompt_token_ids: %s, " + "prompt_embeds shape: %s, " "lora_request: %s, prompt_adapter_request: %s.", request_id, - prompt, params, prompt_token_ids, lora_request, - prompt_adapter_request) + prompt, params, prompt_token_ids, + prompt_embeds.shape if prompt_embeds is not None else None, + lora_request, prompt_adapter_request) diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index d8cec22021..d01af5e422 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -286,6 +286,9 @@ def validate_parsed_serve_args(args: argparse.Namespace): if args.enable_auto_tool_choice and not args.tool_call_parser: raise TypeError("Error: --enable-auto-tool-choice requires " "--tool-call-parser") + if args.enable_prompt_embeds and args.enable_prompt_adapter: + raise ValueError( + "Cannot use prompt embeds and prompt adapter at the same time.") def log_non_default_args(args: argparse.Namespace): diff --git a/vllm/entrypoints/openai/protocol.py 
b/vllm/entrypoints/openai/protocol.py index cd6ee36701..5ab2356a08 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -745,7 +745,8 @@ class CompletionRequest(OpenAIBaseModel): # Ordered by official OpenAI API documentation # https://platform.openai.com/docs/api-reference/completions/create model: Optional[str] = None - prompt: Union[list[int], list[list[int]], str, list[str]] + prompt: Optional[Union[list[int], list[list[int]], str, list[str]]] = None + prompt_embeds: Optional[Union[bytes, list[bytes]]] = None best_of: Optional[int] = None echo: Optional[bool] = False frequency_penalty: Optional[float] = 0.0 @@ -1025,6 +1026,14 @@ class CompletionRequest(OpenAIBaseModel): return data + @model_validator(mode="before") + @classmethod + def validate_prompt_and_prompt_embeds(cls, data): + if data.get("prompt") is None and data.get("prompt_embeds") is None: + raise ValueError( + "At least one of `prompt` or `prompt_embeds` must be set.") + return data + class EmbeddingCompletionRequest(OpenAIBaseModel): # Ordered by official OpenAI API documentation diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 0b3bdf7d48..7beaae287d 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -8,6 +8,7 @@ from typing import Optional, Union, cast import jinja2 from fastapi import Request +from typing_extensions import assert_never from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient @@ -25,8 +26,11 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs, UsageInfo) # yapf: enable from vllm.entrypoints.openai.serving_engine import (OpenAIServing, - clamp_prompt_logprobs) + clamp_prompt_logprobs, + is_text_tokens_prompt) from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.inputs.data import (EmbedsPrompt, TokensPrompt, is_embeds_prompt, + is_tokens_prompt) from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.sampling_params import BeamSearchParams, SamplingParams @@ -90,6 +94,10 @@ class OpenAIServingCompletion(OpenAIServing): return self.create_error_response( "suffix is not currently supported") + if request.echo and request.prompt_embeds is not None: + return self.create_error_response( + "Echo is unsupported with prompt embeds.") + request_id = f"cmpl-{self._base_request_id(raw_request)}" created_time = int(time.time()) @@ -130,8 +138,24 @@ class OpenAIServingCompletion(OpenAIServing): try: for i, engine_prompt in enumerate(engine_prompts): sampling_params: Union[SamplingParams, BeamSearchParams] - default_max_tokens = self.max_model_len - len( - engine_prompt["prompt_token_ids"]) + # Mypy does not infer that engine_prompt will have only one of + # "prompt_token_ids" or "prompt_embeds" defined, and both of + # these as Union[object, the expected type], where it infers + # object if engine_prompt is a subclass of one of the + # typeddicts that defines both keys. Worse, because of + # https://github.com/python/mypy/issues/8586, mypy does not + # infer the type of engine_prompt correctly because of the + # enumerate. So we need an unnecessary cast here. 
+ engine_prompt = cast(Union[EmbedsPrompt, TokensPrompt], + engine_prompt) + if is_embeds_prompt(engine_prompt): + input_length = len(engine_prompt["prompt_embeds"]) + elif is_tokens_prompt(engine_prompt): + input_length = len(engine_prompt["prompt_token_ids"]) + else: + assert_never(engine_prompt) + default_max_tokens = self.max_model_len - input_length + if request.use_beam_search: sampling_params = request.to_beam_search_params( default_max_tokens, self.default_sampling_params) @@ -152,6 +176,11 @@ class OpenAIServingCompletion(OpenAIServing): trace_headers = (None if raw_request is None else await self._get_trace_headers(raw_request.headers)) + # Mypy inconsistently requires this second cast in different + # environments. It shouldn't be necessary (redundant from above) + # but pre-commit in CI fails without it. + engine_prompt = cast(Union[EmbedsPrompt, TokensPrompt], + engine_prompt) if isinstance(sampling_params, BeamSearchParams): generator = self.engine_client.beam_search( prompt=engine_prompt, @@ -211,7 +240,11 @@ class OpenAIServingCompletion(OpenAIServing): # We did not pass it into vLLM engine to avoid being redundant # with the inputs token IDs if final_res.prompt is None: - final_res.prompt = request_prompts[i]["prompt"] + request_prompt = request_prompts[i] + if is_text_tokens_prompt(request_prompt): + final_res.prompt = request_prompt["prompt"] + else: + final_res.prompt = None final_res_batch_checked = cast(list[RequestOutput], final_res_batch) @@ -276,8 +309,8 @@ class OpenAIServingCompletion(OpenAIServing): prompt_text = res.prompt # Prompt details are excluded from later streamed outputs - if res.prompt_token_ids is not None: - num_prompt_tokens[prompt_idx] = len(res.prompt_token_ids) + if prompt_token_ids is not None: + num_prompt_tokens[prompt_idx] = len(prompt_token_ids) delta_token_ids: GenericSequence[int] out_logprobs: Optional[GenericSequence[Optional[dict[ diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index f9eebde371..93de9f3a5c 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 - +import base64 +import io import json import sys import time @@ -8,11 +9,18 @@ from collections.abc import (AsyncGenerator, Iterable, Iterator, Mapping, from concurrent.futures.thread import ThreadPoolExecutor from http import HTTPStatus from typing import (Annotated, Any, Callable, ClassVar, Generic, Optional, - TypeVar, Union) + TypeVar, Union, cast, overload) +import torch from fastapi import Request from pydantic import BaseModel, ConfigDict, Field from starlette.datastructures import Headers +from typing_extensions import TypeIs + +if sys.version_info >= (3, 12): + from typing import TypedDict +else: + from typing_extensions import TypedDict if sys.version_info >= (3, 12): from typing import TypedDict @@ -53,7 +61,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.tool_parsers import ToolParser # yapf: enable -from vllm.inputs import TokensPrompt +from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt +from vllm.inputs.data import TokensPrompt as EngineTokensPrompt from vllm.inputs.parse import parse_and_batch_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -100,7 +109,22 @@ class TextTokensPrompt(TypedDict): prompt_token_ids: list[int] -RequestPrompt = 
Union[list[int], str, TextTokensPrompt] +class EmbedsPrompt(TypedDict): + prompt_embeds: torch.Tensor + + +RequestPrompt = Union[list[int], str, TextTokensPrompt, EmbedsPrompt] + + +def is_text_tokens_prompt(prompt: RequestPrompt) -> TypeIs[TextTokensPrompt]: + return (isinstance(prompt, dict) and "prompt_token_ids" in prompt + and "prompt_embeds" not in prompt) + + +def is_embeds_prompt(prompt: RequestPrompt) -> TypeIs[EmbedsPrompt]: + return (isinstance(prompt, dict) and "prompt_token_ids" not in prompt + and "prompt_embeds" in prompt) + RequestT = TypeVar("RequestT", bound=AnyRequest) @@ -112,8 +136,9 @@ class RequestProcessingMixin(BaseModel): """ request_prompts: Optional[Sequence[RequestPrompt]] = \ Field(default_factory=list) - engine_prompts: Optional[list[TokensPrompt]] = \ - Field(default_factory=list) + engine_prompts: Optional[Union[list[EngineTokensPrompt], + list[EngineEmbedsPrompt]]] = Field( + default_factory=list) model_config = ConfigDict(arbitrary_types_allowed=True) @@ -311,6 +336,12 @@ class OpenAIServing: lora_request=ctx.lora_request, prompt_adapter_request=ctx.prompt_adapter_request) + # Mypy has an existing bug related to inferring the variance of + # TypedDicts with `builtins.enumerate`: + # https://github.com/python/mypy/issues/8586#issuecomment-2867698435 + engine_prompt = cast( + Union[EngineTokensPrompt, EngineEmbedsPrompt], + engine_prompt) generator = self.engine_client.encode( engine_prompt, pooling_params, @@ -596,10 +627,11 @@ class OpenAIServing: self, request: AnyRequest, tokenizer: AnyTokenizer, - input_or_inputs: Union[str, list[str], list[int], list[list[int]]], + input_or_inputs: Optional[Union[str, list[str], list[int], + list[list[int]]]], truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None, add_special_tokens: bool = True, - ) -> list[TextTokensPrompt]: + ) -> tuple[list[TextTokensPrompt], list[EmbedsPrompt]]: """ Tokenize/detokenize depending on the input format. @@ -607,11 +639,25 @@ class OpenAIServing: , each input can be a string or array of tokens. Note that each request can pass one or more inputs. 
""" + inputs_embeds = list[EmbedsPrompt]() + inputs_text = list[TextTokensPrompt]() + + if (isinstance(request, CompletionRequest) + and request.prompt_embeds is not None): + inputs_embeds.extend( + self._load_prompt_embeds(request.prompt_embeds, + truncate_prompt_tokens)) + + # Empty prompts are okay as long as there are prompt embeddings + if input_or_inputs is None or (inputs_embeds + and input_or_inputs == ""): + return [], inputs_embeds + # Although our type checking is based on mypy, # VSCode Pyright extension should still work properly - # "is True" is required for Pyright to perform type narrowing + # "is False" is required for Pyright to perform type narrowing # See: https://github.com/microsoft/pyright/issues/7672 - return [ + inputs_text.extend([ self._normalize_prompt_text_to_input( request, tokenizer, @@ -625,29 +671,88 @@ class OpenAIServing: prompt_ids=prompt_input["content"], truncate_prompt_tokens=truncate_prompt_tokens) for prompt_input in parse_and_batch_prompt(input_or_inputs) - ] + ]) + + return inputs_text, inputs_embeds + + @overload + async def _preprocess_completion( + self, + request: Union[DetokenizeRequest, EmbeddingCompletionRequest, + RerankRequest, ClassificationRequest, ScoreRequest, + TokenizeCompletionRequest], + tokenizer: AnyTokenizer, + input_or_inputs: Union[str, list[str], list[int], list[list[int]]], + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = ..., + add_special_tokens: bool = ..., + ) -> tuple[list[TextTokensPrompt], list[EngineTokensPrompt]]: + ... + + @overload + async def _preprocess_completion( + self, + request: CompletionRequest, + tokenizer: AnyTokenizer, + input_or_inputs: Optional[Union[str, list[str], list[int], + list[list[int]]]], + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = ..., + add_special_tokens: bool = ..., + ) -> tuple[list[Union[TextTokensPrompt, EmbedsPrompt]], list[Union[ + EngineTokensPrompt, EngineEmbedsPrompt]]]: + ... 
async def _preprocess_completion( self, request: CompletionLikeRequest, tokenizer: AnyTokenizer, - input_or_inputs: Union[str, list[str], list[int], list[list[int]]], + input_or_inputs: Optional[Union[str, list[str], list[int], + list[list[int]]]], truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None, add_special_tokens: bool = True, - ) -> tuple[list[TextTokensPrompt], list[TokensPrompt]]: - request_prompts = await self._tokenize_prompt_input_or_inputs_async( - request, - tokenizer, - input_or_inputs, - truncate_prompt_tokens=truncate_prompt_tokens, - add_special_tokens=add_special_tokens, - ) + ) -> tuple[Union[list[TextTokensPrompt], list[Union[ + TextTokensPrompt, EmbedsPrompt]]], Union[ + list[EngineTokensPrompt], list[Union[EngineTokensPrompt, + EngineEmbedsPrompt]]]]: + if not isinstance(request, + CompletionRequest) and input_or_inputs is None: + raise ValueError( + "Prompt embeds with non-completion requests is not" + " currently supported.") - engine_prompts = [ - TokensPrompt(prompt_token_ids=request_prompt["prompt_token_ids"]) - for request_prompt in request_prompts + (request_prompts_text, request_prompts_embeds + ) = await self._tokenize_prompt_input_or_inputs_async( + request, + tokenizer, + input_or_inputs, + truncate_prompt_tokens=truncate_prompt_tokens, + add_special_tokens=add_special_tokens, + ) + + engine_prompts_text = [ + EngineTokensPrompt( + prompt_token_ids=request_prompt_text["prompt_token_ids"]) + for request_prompt_text in request_prompts_text ] + # This check is equivalent to simply checking if + # `request_prompts_embeds` is empty, but it's difficult to propagate + # overloads to the private helper functions to enable this check. + # This overload is needed because only TextPrompts are allowed for + # non-completion requests and if we don't add the overload here, + # everywhere this function is used outside of serving_completion will + # need logic asserting that only text prompts are in the request. 
+ if not isinstance(request, + CompletionRequest) and input_or_inputs is not None: + return request_prompts_text, engine_prompts_text + + engine_prompts_embeds = [ + EngineEmbedsPrompt( + prompt_embeds=request_prompt_embeds["prompt_embeds"]) + for request_prompt_embeds in request_prompts_embeds + ] + + request_prompts = request_prompts_embeds + request_prompts_text + engine_prompts = engine_prompts_embeds + engine_prompts_text return request_prompts, engine_prompts async def _preprocess_chat( @@ -666,7 +771,7 @@ class OpenAIServing: truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, add_special_tokens: bool = False, ) -> tuple[list[ConversationMessage], Sequence[RequestPrompt], - list[TokensPrompt]]: + list[EngineTokensPrompt]]: model_config = self.model_config resolved_content_format = resolve_chat_template_content_format( @@ -739,7 +844,7 @@ class OpenAIServing: prompt=tokenizer.decode(request_prompt), prompt_token_ids=request_prompt) - engine_prompt = TokensPrompt( + engine_prompt = EngineTokensPrompt( prompt_token_ids=prompt_inputs["prompt_token_ids"]) if mm_data is not None: engine_prompt["multi_modal_data"] = mm_data @@ -751,6 +856,35 @@ class OpenAIServing: return conversation, [request_prompt], [engine_prompt] + def _load_prompt_embeds( + self, + prompt_embeds: Optional[Union[bytes, list[bytes]]], + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None + ) -> list[EmbedsPrompt]: + + def _load_and_validate_embed(embed: bytes) -> EmbedsPrompt: + tensor = torch.load(io.BytesIO(base64.b64decode(embed)), + weights_only=True) + assert isinstance( + tensor, + (torch.FloatTensor, torch.BFloat16Tensor, torch.HalfTensor)) + if tensor.dim() > 2: + tensor = tensor.squeeze(0) + assert tensor.dim() == 2 + if truncate_prompt_tokens is not None: + tensor = tensor[-truncate_prompt_tokens:] + return {"prompt_embeds": tensor} + + if prompt_embeds: + if isinstance(prompt_embeds, list): + return [ + _load_and_validate_embed(embed) for embed in prompt_embeds + ] + else: + return [_load_and_validate_embed(prompt_embeds)] + else: + return [] + def _log_inputs( self, request_id: str, @@ -762,13 +896,13 @@ class OpenAIServing: ) -> None: if self.request_logger is None: return - + prompt, prompt_token_ids, prompt_embeds = None, None, None if isinstance(inputs, str): prompt = inputs - prompt_token_ids = None elif isinstance(inputs, list): - prompt = None prompt_token_ids = inputs + elif 'prompt_embeds' in inputs: + prompt_embeds = inputs.get("prompt_embeds") else: prompt = inputs["prompt"] prompt_token_ids = inputs["prompt_token_ids"] @@ -777,6 +911,7 @@ class OpenAIServing: request_id, prompt, prompt_token_ids, + prompt_embeds, params=params, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index 349e0ac9e6..5ef1a486d8 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -106,8 +106,9 @@ class OpenAIServingTokenization(OpenAIServing): # Silently ignore prompt adapter since it does not affect # tokenization (Unlike in Embeddings API where an error is raised) - - input_ids.extend(engine_prompt["prompt_token_ids"]) + if isinstance(engine_prompt, + dict) and "prompt_token_ids" in engine_prompt: + input_ids.extend(engine_prompt["prompt_token_ids"]) return TokenizeResponse(tokens=input_ids, count=len(input_ids), diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 
c83ab73b61..3b58ec47d5 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -3,7 +3,7 @@ from collections.abc import Iterable from typing import TYPE_CHECKING, Any, Generic, Literal, Optional, Union, cast import torch -from typing_extensions import NotRequired, TypedDict, TypeVar +from typing_extensions import NotRequired, TypedDict, TypeIs, TypeVar if TYPE_CHECKING: from vllm.multimodal.inputs import MultiModalDataDict, MultiModalInputs @@ -98,6 +98,17 @@ where the decoder-prompt is not specified explicitly, or more than one prompt, i.e. {class}`ExplicitEncoderDecoderPrompt` """ + +def is_tokens_prompt(prompt: SingletonPrompt) -> TypeIs[TokensPrompt]: + return (isinstance(prompt, dict) and "prompt_token_ids" in prompt + and "prompt_embeds" not in prompt) + + +def is_embeds_prompt(prompt: SingletonPrompt) -> TypeIs[EmbedsPrompt]: + return (isinstance(prompt, dict) and "prompt_token_ids" not in prompt + and "prompt_embeds" in prompt) + + _T1_co = TypeVar("_T1_co", bound=SingletonPrompt, default=SingletonPrompt,