[Frontend][1/N] Improve all pooling tasks | Support FP16 Embedding Base64 (still uses fp32 by default) (#26414)

Signed-off-by: wang.yuqi <noooop@126.com>
Co-authored-by: Maximilien de Bayser <maxdebayser@gmail.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
wang.yuqi
2025-10-14 03:06:43 +08:00
committed by GitHub
parent 89342ce4c0
commit d2a7938582
8 changed files with 312 additions and 30 deletions

View File

@@ -6,6 +6,12 @@
python examples/online_serving/pooling/cohere_rerank_client.py
```
## Embedding embed_dtype usage
```bash
python examples/online_serving/pooling/embedding_embed_dtype_client.py
```
## Jinaai rerank usage
```bash

View File

@@ -0,0 +1,59 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Example Python client for embedding API using vLLM API server
NOTE:
start a supported embeddings model server with `vllm serve`, e.g.
vllm serve intfloat/e5-small
"""
import argparse
import base64
import requests
import torch
from vllm.entrypoints.openai.protocol import EMBED_DTYPE_TO_TORCH_DTYPE
def post_http_request(prompt: dict, api_url: str) -> requests.Response:
headers = {"User-Agent": "Test Client"}
response = requests.post(api_url, headers=headers, json=prompt)
return response
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--model", type=str, default="intfloat/e5-small")
return parser.parse_args()
def main(args):
api_url = f"http://{args.host}:{args.port}/v1/embeddings"
model_name = args.model
for embed_dtype, torch_dtype in EMBED_DTYPE_TO_TORCH_DTYPE.items():
prompt = {
"model": model_name,
"input": "vLLM is great!",
"encoding_format": "base64",
"embed_dtype": embed_dtype,
}
response = post_http_request(prompt=prompt, api_url=api_url)
embedding = []
for data in response.json()["data"]:
embedding.append(
torch.frombuffer(
base64.b64decode(data["embedding"]), dtype=torch_dtype
).to(torch.float32)
)
embedding = torch.cat(embedding)
print(embed_dtype, embedding.shape)
if __name__ == "__main__":
args = parse_args()
main(args)
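
Not part of the commit, but for context on why the narrower dtypes are worth exposing: the sketch below (assuming the 384-dimensional `intfloat/e5-small` embedding used above) compares the base64 payload size for each supported `embed_dtype`.

```python
import base64

import torch

from vllm.entrypoints.openai.protocol import EMBED_DTYPE_TO_TORCH_DTYPE

dim = 384  # embedding size of intfloat/e5-small (illustrative assumption)
for name, dtype in EMBED_DTYPE_TO_TORCH_DTYPE.items():
    # Serialize a dummy embedding the same way the server does: cast, then take raw bytes.
    raw = torch.zeros(dim, dtype=torch.float32).to(dtype).view(torch.uint8).numpy().tobytes()
    print(f"{name}: {len(base64.b64encode(raw))} base64 characters")
```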

View File

@@ -14,7 +14,10 @@ import torch.nn.functional as F
from tests.models.language.pooling.embed_utils import run_embedding_correctness_test
from tests.models.utils import check_embeddings_close
from tests.utils import RemoteOpenAIServer
from vllm.entrypoints.openai.protocol import EmbeddingResponse
from vllm.entrypoints.openai.protocol import (
EMBED_DTYPE_TO_TORCH_DTYPE,
EmbeddingResponse,
)
from vllm.transformers_utils.tokenizer import get_tokenizer
MODEL_NAME = "intfloat/multilingual-e5-small"
@@ -244,6 +247,75 @@ async def test_batch_base64_embedding(
run_embedding_correctness_test(hf_model, input_texts, default_data)
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_base64_embed_dtype(
hf_model, server: RemoteOpenAIServer, client: openai.AsyncOpenAI, model_name: str
):
input_texts = [
"The best thing about vLLM is that it supports many different models",
]
responses_float = await client.embeddings.create(
input=input_texts, model=model_name, encoding_format="float"
)
float_data = [d.embedding for d in responses_float.data]
for embed_dtype, torch_dtype in EMBED_DTYPE_TO_TORCH_DTYPE.items():
responses_base64 = requests.post(
server.url_for("/v1/embeddings"),
json={
"model": model_name,
"input": input_texts,
"encoding_format": "base64",
"embed_dtype": embed_dtype,
},
)
base64_data = []
for data in responses_base64.json()["data"]:
base64_data.append(
torch.frombuffer(base64.b64decode(data["embedding"]), dtype=torch_dtype)
.to(torch.float32)
.tolist()
)
check_embeddings_close(
embeddings_0_lst=float_data,
embeddings_1_lst=base64_data,
name_0="float_data",
name_1="base64_data",
tol=1e-2,
)
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_base64_embed_dtype_not_supported(
hf_model, server: RemoteOpenAIServer, model_name: str
):
input_texts = [
"The best thing about vLLM is that it supports many different models",
]
bad_embed_dtype = "bad_embed_dtype"
responses_base64 = requests.post(
server.url_for("/v1/embeddings"),
json={
"model": model_name,
"input": input_texts,
"encoding_format": "base64",
"embed_dtype": bad_embed_dtype,
},
)
assert responses_base64.status_code == 400
assert responses_base64.json()["error"]["message"].startswith(
f"embed_dtype={bad_embed_dtype!r} is not supported."
)
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_single_embedding_truncation(client: openai.AsyncOpenAI, model_name: str):

View File

@@ -6,10 +6,11 @@ import base64
import numpy as np
import pytest
import requests
import torch
from tests.models.utils import check_embeddings_close
from tests.utils import RemoteOpenAIServer
from vllm.entrypoints.openai.protocol import PoolingResponse
from vllm.entrypoints.openai.protocol import EMBED_DTYPE_TO_TORCH_DTYPE, PoolingResponse
from vllm.transformers_utils.tokenizer import get_tokenizer
MODEL_NAME = "internlm/internlm2-1_8b-reward"
@@ -248,6 +249,80 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer, model_name: str)
)
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_base64_embed_dtype(server: RemoteOpenAIServer, model_name: str):
input_texts = [
"The best thing about vLLM is that it supports many different models",
]
url = server.url_for("pooling")
float_response = requests.post(
url,
json={
"model": model_name,
"input": input_texts,
"encoding_format": "float",
},
)
responses_float = PoolingResponse.model_validate(float_response.json())
float_data = [np.array(d.data).squeeze(-1).tolist() for d in responses_float.data]
for embed_dtype, torch_dtype in EMBED_DTYPE_TO_TORCH_DTYPE.items():
responses_base64 = requests.post(
url,
json={
"model": model_name,
"input": input_texts,
"encoding_format": "base64",
"embed_dtype": embed_dtype,
},
)
base64_data = []
for data in responses_base64.json()["data"]:
base64_data.append(
torch.frombuffer(base64.b64decode(data["data"]), dtype=torch_dtype)
.to(torch.float32)
.tolist()
)
check_embeddings_close(
embeddings_0_lst=float_data,
embeddings_1_lst=base64_data,
name_0="float_data",
name_1="base64_data",
tol=1e-2,
)
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_base64_embed_dtype_not_supported(
server: RemoteOpenAIServer, model_name: str
):
input_texts = [
"The best thing about vLLM is that it supports many different models",
]
bad_embed_dtype = "bad_embed_dtype"
responses_base64 = requests.post(
server.url_for("pooling"),
json={
"model": model_name,
"input": input_texts,
"encoding_format": "base64",
"embed_dtype": bad_embed_dtype,
},
)
assert responses_base64.status_code == 400
assert responses_base64.json()["error"]["message"].startswith(
f"embed_dtype={bad_embed_dtype!r} is not supported."
)
@pytest.mark.asyncio
async def test_invocations(server: RemoteOpenAIServer):
input_texts = [

View File

@@ -83,6 +83,18 @@ from vllm.sampling_params import (
)
from vllm.utils import random_uuid, resolve_obj_by_qualname
EMBED_DTYPE_TO_TORCH_DTYPE = {
"float32": torch.float32,
"float16": torch.float16,
"bfloat16": torch.bfloat16,
# Note: fp8 is used here only as a storage format for the serialized
# embedding; no fp8 computation is performed, and the conversion happens
# on the CPU. It is unclear whether every platform's CPU supports these
# formats, so breakage is possible on some platforms.
"fp8_e4m3": torch.float8_e4m3fn,
"fp8_e5m2": torch.float8_e5m2,
}
logger = init_logger(__name__)
_LONG_INFO = torch.iinfo(torch.long)
@@ -1517,8 +1529,17 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
"through out the inference process and return in response."
),
)
normalize: bool | None = None
normalize: bool | None = Field(
default=None,
description="Whether to normalize the embeddings outputs. Default is True.",
)
embed_dtype: str = Field(
default="float32",
description=(
"What dtype to use for base64 encoding. Default to using "
"float32 for base64 encoding to match the OpenAI python client behavior."
),
)
# --8<-- [end:embedding-extra-params]
def to_pooling_params(self):
@@ -1594,7 +1615,17 @@ class EmbeddingChatRequest(OpenAIBaseModel):
"through out the inference process and return in response."
),
)
normalize: bool | None = None
normalize: bool | None = Field(
default=None,
description="Whether to normalize the embeddings outputs. Default is True.",
)
embed_dtype: str = Field(
default="float32",
description=(
"Which dtype to use for base64 encoding. Defaults to float32 "
"to match OpenAI API."
),
)
# --8<-- [end:chat-embedding-extra-params]
@model_validator(mode="before")
@@ -1639,6 +1670,14 @@ class IOProcessorRequest(OpenAIBaseModel, Generic[T]):
"""
softmax: bool = True
embed_dtype: str = Field(
default="float32",
description=(
"What dtype to use for base64 encoding. Default to using "
"float32 for base64 encoding to match the OpenAI python client behavior."
),
)
def to_pooling_params(self):
return PoolingParams(task="encode", softmax=self.softmax)
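
As the comment in EMBED_DTYPE_TO_TORCH_DTYPE above notes, the fp8 entries are purely a serialization format. A minimal sketch (illustrative only, not part of this commit) of what that means in practice: the cast happens on the CPU solely to shrink the payload, and the values must be widened again before any arithmetic.

```python
import torch

x = torch.randn(4, dtype=torch.float32)

# Cast to fp8 only to obtain a compact byte representation (CPU-side, storage only).
stored = x.to(torch.float8_e4m3fn)

# No computation is done in fp8; convert back to float32 before using the values.
restored = stored.to(torch.float32)

# The only cost is quantization error from the 8-bit format.
print((x - restored).abs().max())
```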

View File

@@ -1,19 +1,18 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import base64
from collections.abc import AsyncGenerator, Mapping
from typing import Any, Final, Literal, cast
from typing import Any, Final, cast
import numpy as np
import torch
from fastapi import Request
from typing_extensions import assert_never, override
from typing_extensions import override
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (
EMBED_DTYPE_TO_TORCH_DTYPE,
EmbeddingChatRequest,
EmbeddingCompletionRequest,
EmbeddingRequest,
@@ -29,11 +28,11 @@ from vllm.entrypoints.openai.serving_engine import (
TextTokensPrompt,
)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.utils import encoding_pooling_output
from vllm.entrypoints.renderer import RenderConfig
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.logger import init_logger
from vllm.outputs import (
EmbeddingOutput,
EmbeddingRequestOutput,
PoolingOutput,
PoolingRequestOutput,
@@ -45,21 +44,6 @@ from vllm.utils import chunk_list
logger = init_logger(__name__)
def _get_embedding(
output: EmbeddingOutput,
encoding_format: Literal["float", "base64"],
) -> list[float] | str:
if encoding_format == "float":
return output.embedding
elif encoding_format == "base64":
# Force to use float32 for base64 encoding
# to match the OpenAI python client behavior
embedding_bytes = np.array(output.embedding, dtype="float32").tobytes()
return base64.b64encode(embedding_bytes).decode("utf-8")
assert_never(encoding_format)
class EmbeddingMixin(OpenAIServing):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -83,6 +67,12 @@ class EmbeddingMixin(OpenAIServing):
) -> ErrorResponse | None:
ctx = cast(EmbeddingServeContext, ctx)
try:
if ctx.request.embed_dtype not in EMBED_DTYPE_TO_TORCH_DTYPE:
return self.create_error_response(
f"embed_dtype={ctx.request.embed_dtype!r} is not supported. "
f"Supported types: {EMBED_DTYPE_TO_TORCH_DTYPE.keys()}"
)
ctx.lora_request = self._maybe_get_adapters(ctx.request)
tokenizer = await self.engine_client.get_tokenizer()
@@ -137,12 +127,10 @@ class EmbeddingMixin(OpenAIServing):
final_res_batch_checked = cast(list[PoolingRequestOutput], ctx.final_res_batch)
for idx, final_res in enumerate(final_res_batch_checked):
embedding_res = EmbeddingRequestOutput.from_base(final_res)
item = EmbeddingResponseData(
index=idx,
embedding=_get_embedding(
embedding_res.outputs, ctx.request.encoding_format
embedding=encoding_pooling_output(
final_res, ctx.request.encoding_format, ctx.request.embed_dtype
),
)
prompt_token_ids = final_res.prompt_token_ids

View File

@@ -17,6 +17,7 @@ from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (
EMBED_DTYPE_TO_TORCH_DTYPE,
ErrorResponse,
IOProcessorRequest,
IOProcessorResponse,
@@ -29,6 +30,7 @@ from vllm.entrypoints.openai.protocol import (
)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.utils import encoding_pooling_output
from vllm.entrypoints.renderer import RenderConfig
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.logger import init_logger
@@ -90,6 +92,12 @@ class OpenAIServingPooling(OpenAIServing):
if error_check_ret is not None:
return error_check_ret
if request.embed_dtype not in EMBED_DTYPE_TO_TORCH_DTYPE:
return self.create_error_response(
f"embed_dtype={request.embed_dtype!r} is not supported. "
f"Supported types: {EMBED_DTYPE_TO_TORCH_DTYPE.keys()}"
)
model_name = self.models.model_name()
request_id = f"pool-{self._base_request_id(raw_request)}"
@@ -235,6 +243,7 @@
created_time,
model_name,
request.encoding_format,
request.embed_dtype,
)
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
@@ -251,6 +260,7 @@
created_time: int,
model_name: str,
encoding_format: Literal["float", "base64"],
embed_dtype: str,
) -> PoolingResponse:
items: list[PoolingResponseData] = []
num_prompt_tokens = 0
@@ -258,7 +268,7 @@
for idx, final_res in enumerate(final_res_batch):
item = PoolingResponseData(
index=idx,
data=_get_data(final_res.outputs, encoding_format),
data=encoding_pooling_output(final_res, encoding_format, embed_dtype),
)
prompt_token_ids = final_res.prompt_token_ids

View File

@@ -0,0 +1,33 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import base64
from typing import Literal
import torch
from typing_extensions import assert_never
from vllm import PoolingRequestOutput
from vllm.entrypoints.openai.protocol import EMBED_DTYPE_TO_TORCH_DTYPE
def encoding_pooling_output(
output: PoolingRequestOutput,
encoding_format: Literal["float", "base64"],
embed_dtype: str,
) -> list[float] | str:
if encoding_format == "float":
return output.outputs.data.tolist()
elif encoding_format == "base64":
assert embed_dtype in EMBED_DTYPE_TO_TORCH_DTYPE
torch_dtype = EMBED_DTYPE_TO_TORCH_DTYPE[embed_dtype]
embedding_bytes = (
output.outputs.data.to(torch_dtype)
.flatten()
.contiguous()
.view(torch.uint8)
.numpy()
.tobytes()
)
return base64.b64encode(embedding_bytes).decode("utf-8")
assert_never(encoding_format)
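
To show how the two ends line up, here is a round-trip sketch (illustrative only) pairing the byte layout produced above with the `torch.frombuffer` decoding used by the example client and tests. The payload carries no shape or dtype metadata, so the client must decode with the same `embed_dtype` it sent in the request.

```python
import base64

import torch

data = torch.randn(8, dtype=torch.float32)  # stand-in for output.outputs.data
dtype = torch.float16                        # as if the request set embed_dtype="float16"

# Server side: same byte layout as encoding_pooling_output.
encoded = base64.b64encode(
    data.to(dtype).flatten().contiguous().view(torch.uint8).numpy().tobytes()
).decode("utf-8")

# Client side: decode with the matching dtype, then widen to float32.
decoded = torch.frombuffer(base64.b64decode(encoded), dtype=dtype).to(torch.float32)
assert torch.allclose(data, decoded, atol=1e-2)
```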