[Benchmarks] add benchmark for embedding models (#23000)

Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
Author: Jiangyun Zhu
Date: 2025-08-26 14:57:08 +08:00
Committed by: GitHub
parent 7d67a9d9f9
commit 3ecbb14b81
3 changed files with 274 additions and 107 deletions

Changed file 1 of 3 — benchmark dataset sampling

@ -73,7 +73,7 @@ class SampleRequest:
Represents a single inference request for benchmarking.
"""
prompt: Union[str, Any]
prompt: Union[str, list[str]]
prompt_len: int
expected_output_len: int
multi_modal_data: Optional[
@ -409,6 +409,7 @@ class RandomDataset(BenchmarkDataset):
range_ratio: float = DEFAULT_RANGE_RATIO,
input_len: int = DEFAULT_INPUT_LEN,
output_len: int = DEFAULT_OUTPUT_LEN,
batchsize: int = 1,
**kwargs,
) -> list[SampleRequest]:
@ -439,6 +440,21 @@ class RandomDataset(BenchmarkDataset):
request_id=request_id_prefix + str(i),
)
)
# only used for embeddings benchmark.
if batchsize > 1:
batch_requests = []
# Create batched requests
for i in range(0, num_requests, batchsize):
batch = requests[i : i + batchsize]
batch_requests.append(
SampleRequest(
prompt=[req.prompt for req in batch],
prompt_len=sum(req.prompt_len for req in batch),
expected_output_len=0,
request_id=request_id_prefix + str(i // batchsize),
)
)
requests = batch_requests
return requests
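
The branch above folds every `batchsize` consecutive requests into one SampleRequest whose prompt becomes a list of strings and whose prompt_len is the sum over the batch, with expected_output_len fixed to 0 since embeddings produce no generated tokens. A minimal standalone sketch of the same chunking pattern, using a stand-in request type so it runs on its own (illustrative only, not part of this commit):

    from dataclasses import dataclass

    @dataclass
    class _Req:  # stand-in for SampleRequest
        prompt: object
        prompt_len: int
        expected_output_len: int
        request_id: str

    def batch_for_embeddings(requests: list[_Req], batchsize: int,
                             request_id_prefix: str = "") -> list[_Req]:
        # Group every `batchsize` consecutive requests into one batched request.
        batched = []
        for i in range(0, len(requests), batchsize):
            chunk = requests[i:i + batchsize]
            batched.append(_Req(
                prompt=[r.prompt for r in chunk],            # list of prompts
                prompt_len=sum(r.prompt_len for r in chunk),  # total tokens in batch
                expected_output_len=0,                        # embeddings generate no tokens
                request_id=request_id_prefix + str(i // batchsize),
            ))
        return batched

Summing prompt_len per batch keeps the total-input-token and token-throughput numbers meaningful even though each batched request goes out as a single API call.
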
def get_prefix(
@ -475,8 +491,8 @@ class RandomDataset(BenchmarkDataset):
input_high = math.ceil(real_input_len * (1 + range_ratio))
output_low = math.floor(output_len * (1 - range_ratio))
output_high = math.ceil(output_len * (1 + range_ratio))
# Ensure the lower bound for output length is at least 1 to
# prevent sampling 0 tokens.
# Ensure the lower bound for output length is at least 1 to
# prevent sampling 0 tokens.
output_low = max(output_low, 1)
if input_low > input_high:
@ -506,7 +522,6 @@ class RandomDataset(BenchmarkDataset):
size=num_requests)
return input_lens, output_lens, offsets
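
For reference, the clamped bounds above work out as follows for typical values; the numbers are illustrative only, and `real_input_len` is simply set to `input_len` in this sketch rather than derived as in the full implementation:

    import math

    input_len, output_len, range_ratio = 1024, 128, 0.25
    real_input_len = input_len  # simplification for this sketch

    input_low = math.floor(real_input_len * (1 - range_ratio))       # 768
    input_high = math.ceil(real_input_len * (1 + range_ratio))       # 1280
    output_low = max(math.floor(output_len * (1 - range_ratio)), 1)  # 96
    output_high = math.ceil(output_len * (1 + range_ratio))          # 160

    # Per-request lengths are then sampled from these [low, high] ranges.
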
def generate_token_sequence(
self,
*,
@ -1105,6 +1120,13 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
"context length sampled from [input_len * (1 - range_ratio), "
"input_len * (1 + range_ratio)]."),
)
random_group.add_argument(
"--random-batch-size",
type=int,
default=1,
help=("Batch size for random sampling. "
"Only used for embeddings benchmark."),
)
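
With the flag wired in, an embeddings benchmark run looks roughly like the following. Apart from --random-batch-size (introduced here) and the openai-embeddings endpoint type (registered in the second changed file below), the remaining flag spellings and the model name are assumptions that may differ between vLLM versions; check `vllm bench serve --help`.

    # Server side (example embedding model):
    vllm serve intfloat/e5-mistral-7b-instruct

    # Client side: random prompts, 8 prompts per embeddings request:
    vllm bench serve \
        --model intfloat/e5-mistral-7b-instruct \
        --endpoint-type openai-embeddings \
        --endpoint /v1/embeddings \
        --dataset-name random \
        --num-prompts 128 \
        --random-input-len 128 \
        --random-batch-size 8
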
# random multimodal dataset options
random_mm_group = parser.add_argument_group(
@ -1196,8 +1218,6 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
),
)
hf_group = parser.add_argument_group("hf dataset options")
hf_group.add_argument("--hf-subset",
type=str,
@ -1348,22 +1368,24 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
else:
# For datasets that follow a similar structure, use a mapping.
dataset_mapping = {
"sharegpt":
lambda: ShareGPTDataset(random_seed=args.seed,
dataset_path=args.dataset_path).sample(
tokenizer=tokenizer,
num_requests=args.num_prompts,
output_len=args.sharegpt_output_len,
request_id_prefix=args.request_id_prefix,
),
"burstgpt":
lambda: BurstGPTDataset(random_seed=args.seed,
dataset_path=args.dataset_path).
sample(tokenizer=tokenizer, num_requests=args.num_prompts,
request_id_prefix=args.request_id_prefix,),
"random":
lambda: RandomDataset(random_seed=args.seed,
dataset_path=args.dataset_path).sample(
"sharegpt": lambda: ShareGPTDataset(
random_seed=args.seed, dataset_path=args.dataset_path
).sample(
tokenizer=tokenizer,
num_requests=args.num_prompts,
output_len=args.sharegpt_output_len,
request_id_prefix=args.request_id_prefix,
),
"burstgpt": lambda: BurstGPTDataset(
random_seed=args.seed, dataset_path=args.dataset_path
).sample(
tokenizer=tokenizer,
num_requests=args.num_prompts,
request_id_prefix=args.request_id_prefix,
),
"random": lambda: RandomDataset(
random_seed=args.seed, dataset_path=args.dataset_path
).sample(
tokenizer=tokenizer,
num_requests=args.num_prompts,
prefix_len=args.random_prefix_len,
@ -1371,6 +1393,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
output_len=args.random_output_len,
range_ratio=args.random_range_ratio,
request_id_prefix=args.request_id_prefix,
batchsize=args.random_batch_size,
),
"random-mm":
lambda: RandomMultiModalDataset(

Changed file 2 of 3 — backend request functions

@ -69,8 +69,8 @@ async def async_request_openai_completions(
), "OpenAI Completions API URL must end with 'completions' or 'profile'."
payload = {
"model": request_func_input.model_name \
if request_func_input.model_name else request_func_input.model,
"model": request_func_input.model_name
if request_func_input.model_name else request_func_input.model,
"prompt": request_func_input.prompt,
"temperature": 0.0,
"repetition_penalty": 1.0,
@ -135,7 +135,7 @@ async def async_request_openai_completions(
# Decoding phase
else:
output.itl.append(timestamp -
most_recent_timestamp)
most_recent_timestamp)
most_recent_timestamp = timestamp
generated_text += text or ""
@ -254,7 +254,7 @@ async def async_request_openai_chat_completions(
# Decoding phase
else:
output.itl.append(timestamp -
most_recent_timestamp)
most_recent_timestamp)
generated_text += content or ""
elif usage := data.get("usage"):
@ -394,12 +394,61 @@ async def async_request_openai_audio(
return output
async def async_request_openai_embeddings(
request_func_input: RequestFuncInput,
session: aiohttp.ClientSession,
pbar: Optional[tqdm] = None,
):
api_url = request_func_input.api_url
assert api_url.endswith(
"embeddings"
), "OpenAI Embeddings API URL must end with 'embeddings'."
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
}
payload = {
"model": request_func_input.model,
"input": request_func_input.prompt,
}
output = RequestFuncOutput()
st = time.perf_counter()
try:
async with session.post(
url=api_url,
headers=headers,
json=payload
) as response:
if response.status == 200:
output.latency = time.perf_counter() - st
data = await response.json()
output.success = True
output.generated_text = ""
output.prompt_len = data.get(
"usage", {}).get(
"prompt_tokens", 0)
else:
output.success = False
output.error = response.reason or ""
except Exception as e:
output.success = False
output.error = str(e)
if pbar:
pbar.update(1)
return output
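
The helper above issues a plain OpenAI-style embeddings request and records only end-to-end latency and the served prompt-token count. A self-contained sketch of the same call using aiohttp directly; the URL and model name are placeholders:

    import asyncio
    import os

    import aiohttp

    async def embed_once(api_url: str, model: str, prompts: list[str]) -> int:
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY', '')}",
        }
        payload = {"model": model, "input": prompts}
        async with aiohttp.ClientSession() as session:
            async with session.post(api_url, headers=headers, json=payload) as resp:
                resp.raise_for_status()
                data = await resp.json()
                # The benchmark reads the served prompt token count from `usage`.
                return data.get("usage", {}).get("prompt_tokens", 0)

    # Example (assumes a vLLM server with an embedding model on localhost:8000):
    # asyncio.run(embed_once("http://localhost:8000/v1/embeddings",
    #                        "intfloat/e5-mistral-7b-instruct",
    #                        ["hello world", "benchmarking embeddings"]))
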
# TODO: Add more request functions for different API protocols.
ASYNC_REQUEST_FUNCS = {
"vllm": async_request_openai_completions,
"openai": async_request_openai_completions,
"openai-chat": async_request_openai_chat_completions,
"openai-audio": async_request_openai_audio,
"openai-embeddings": async_request_openai_embeddings,
}
OPENAI_COMPATIBLE_BACKENDS = [

Changed file 3 of 3 — online serving benchmark

@ -4,7 +4,7 @@ r"""Benchmark online serving throughput.
On the server side, run one of the following commands
to launch the vLLM OpenAI API server:
vllm serve <your_model> <engine arguments>
vllm serve <your_model> <engine arguments>
On the client side, run:
vllm bench serve \
@ -26,6 +26,7 @@ import warnings
from collections.abc import AsyncGenerator, Iterable
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import Any, Literal, Optional
import aiohttp
@ -46,6 +47,11 @@ from vllm.transformers_utils.tokenizer import get_tokenizer
MILLISECONDS_TO_SECONDS_CONVERSION = 1000
class TaskType(Enum):
GENERATION = "generation"
EMBEDDING = "embedding"
@dataclass
class BenchmarkMetrics:
completed: int
@ -75,6 +81,16 @@ class BenchmarkMetrics:
std_e2el_ms: float
percentiles_e2el_ms: list[tuple[float, float]]
@dataclass
class EmbedBenchmarkMetrics:
completed: int
total_input: int
request_throughput: float
total_token_throughput: float
mean_e2el_ms: float
std_e2el_ms: float
median_e2el_ms: float
percentiles_e2el_ms: list[tuple[float, float]]
def _get_current_request_rate(
ramp_up_strategy: Optional[Literal["linear", "exponential"]],
@ -146,11 +162,11 @@ async def get_request(
delay_ts = []
for request_index, request in enumerate(input_requests):
current_request_rate = _get_current_request_rate(ramp_up_strategy,
ramp_up_start_rps,
ramp_up_end_rps,
request_index,
total_requests,
request_rate)
ramp_up_start_rps,
ramp_up_end_rps,
request_index,
total_requests,
request_rate)
request_rates.append(current_request_rate)
if current_request_rate == float("inf"):
delay_ts.append(0)
@ -160,7 +176,7 @@ async def get_request(
# Sample the request interval from the gamma distribution.
# If burstiness is 1, it follows exponential distribution.
delay_ts.append(np.random.gamma(shape=burstiness, scale=theta))
# Calculate the cumulative delay time from the first sent out requests.
for i in range(1, len(delay_ts)):
delay_ts[i] += delay_ts[i - 1]
@ -170,11 +186,11 @@ async def get_request(
# logic would re-scale delay time to ensure the final delay_ts
# align with target_total_delay_s.
#
# NOTE: If we simply accumulate the random delta values
# from the gamma distribution, their sum would have 1-2% gap
# NOTE: If we simply accumulate the random delta values
# from the gamma distribution, their sum would have 1-2% gap
# from target_total_delay_s. The purpose of the following logic is to
# close the gap for stabilizing the throughput data
# from different random seeds.
# close the gap for stabilizing the throughput data
# from different random seeds.
target_total_delay_s = total_requests / request_rate
normalize_factor = target_total_delay_s / delay_ts[-1]
delay_ts = [delay * normalize_factor for delay in delay_ts]
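
The NOTE above is the key subtlety: the raw gamma-distributed gaps only approximate the requested rate, so the cumulative delays are rescaled to land exactly on total_requests / request_rate. A small numerical sketch of that correction; theta follows the usual choice that makes the mean gap 1 / request_rate (an assumption, not shown in this excerpt):

    import numpy as np

    request_rate = 10.0   # target requests per second
    burstiness = 1.0      # shape 1.0 -> exponential inter-arrival times
    total_requests = 1000

    theta = 1.0 / (request_rate * burstiness)
    gaps = np.random.gamma(shape=burstiness, scale=theta, size=total_requests)
    delay_ts = np.cumsum(gaps)

    # Without correction the final cumulative delay is usually off by a percent or two.
    target_total_delay_s = total_requests / request_rate
    delay_ts *= target_total_delay_s / delay_ts[-1]

    assert abs(delay_ts[-1] - target_total_delay_s) < 1e-9
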
@ -189,6 +205,51 @@ async def get_request(
yield request, request_rates[request_index]
def calculate_metrics_for_embeddings(
outputs: list[RequestFuncOutput],
dur_s: float,
selected_percentiles: list[float]
) -> EmbedBenchmarkMetrics:
"""Calculate the metrics for the embedding requests.
Args:
outputs: The outputs of the requests.
dur_s: The duration of the benchmark.
selected_percentiles: The percentiles to select.
Returns:
The calculated benchmark metrics.
"""
total_input = 0
completed = 0
e2els: list[float] = []
for i in range(len(outputs)):
if outputs[i].success:
e2els.append(outputs[i].latency)
completed += 1
total_input += outputs[i].prompt_len
if completed == 0:
warnings.warn(
"All requests failed. This is likely due to a misconfiguration "
"on the benchmark arguments.",
stacklevel=2)
metrics = EmbedBenchmarkMetrics(
completed=completed,
total_input=total_input,
request_throughput=completed / dur_s,
total_token_throughput=total_input / dur_s,
mean_e2el_ms=np.mean(e2els or 0) * 1000,
std_e2el_ms=np.std(e2els or 0) * 1000,
median_e2el_ms=np.median(e2els or 0) * 1000,
percentiles_e2el_ms=[
(p, np.percentile(e2els or 0, p) * 1000)
for p in selected_percentiles
],
)
return metrics
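
A quick way to sanity-check the function above is to feed it stand-in outputs; SimpleNamespace is used so the sketch does not depend on the exact RequestFuncOutput constructor, and all numbers are illustrative:

    from types import SimpleNamespace

    fake_outputs = [
        SimpleNamespace(success=True, latency=0.12, prompt_len=128),
        SimpleNamespace(success=True, latency=0.18, prompt_len=256),
        SimpleNamespace(success=False, latency=0.0, prompt_len=0),
    ]
    m = calculate_metrics_for_embeddings(
        outputs=fake_outputs,
        dur_s=2.0,
        selected_percentiles=[50.0, 99.0],
    )
    # completed == 2, total_input == 384,
    # request_throughput == 1.0 req/s, total_token_throughput == 192.0 tok/s
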
def calculate_metrics(
input_requests: list[SampleRequest],
outputs: list[RequestFuncOutput],
@ -334,8 +395,16 @@ async def benchmark(
ramp_up_end_rps: Optional[int] = None,
ready_check_timeout_sec: int = 600,
):
task_type = (
TaskType.EMBEDDING
if api_url.endswith("/v1/embeddings")
else TaskType.GENERATION
)
if endpoint_type in ASYNC_REQUEST_FUNCS:
request_func = ASYNC_REQUEST_FUNCS[endpoint_type]
if task_type == TaskType.EMBEDDING:
request_func = ASYNC_REQUEST_FUNCS["openai-embeddings"]
else:
request_func = ASYNC_REQUEST_FUNCS[endpoint_type]
else:
raise ValueError(f"Unknown endpoint_type: {endpoint_type}")
@ -421,8 +490,8 @@ async def benchmark(
if profile_output.success:
print("Profiler started")
distribution = ("Poisson process" if burstiness == 1.0
else "Gamma distribution")
distribution = ("Poisson process" if burstiness == 1.0
else "Gamma distribution")
if ramp_up_strategy is not None:
print(f"Traffic ramp-up strategy: {ramp_up_strategy}.")
@ -449,7 +518,7 @@ async def benchmark(
session=session,
pbar=pbar)
async with semaphore:
return await request_func(request_func_input=request_func_input,
return await request_func(request_func_input=request_func_input,
session=session,
pbar=pbar)
@ -513,14 +582,22 @@ async def benchmark(
benchmark_duration = time.perf_counter() - benchmark_start_time
metrics, actual_output_lens = calculate_metrics(
input_requests=input_requests,
outputs=outputs,
dur_s=benchmark_duration,
tokenizer=tokenizer,
selected_percentiles=selected_percentiles,
goodput_config_dict=goodput_config_dict,
)
if task_type == TaskType.GENERATION:
metrics, actual_output_lens = calculate_metrics(
input_requests=input_requests,
outputs=outputs,
dur_s=benchmark_duration,
tokenizer=tokenizer,
selected_percentiles=selected_percentiles,
goodput_config_dict=goodput_config_dict,
)
else:
metrics = calculate_metrics_for_embeddings(
outputs=outputs,
dur_s=benchmark_duration,
selected_percentiles=selected_percentiles,
)
actual_output_lens = 0
print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='='))
print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
@ -529,39 +606,55 @@ async def benchmark(
max_concurrency))
if request_rate != float('inf'):
print("{:<40} {:<10.2f}".format("Request rate configured (RPS):",
request_rate ))
request_rate))
print("{:<40} {:<10.2f}".format("Benchmark duration (s):",
benchmark_duration))
print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
print("{:<40} {:<10}".format("Total generated tokens:",
metrics.total_output))
if isinstance(metrics, BenchmarkMetrics):
print("{:<40} {:<10}".format(
"Total generated tokens:", metrics.total_output))
print("{:<40} {:<10.2f}".format("Request throughput (req/s):",
metrics.request_throughput))
if goodput_config_dict:
print("{:<40} {:<10.2f}".format("Request goodput (req/s):",
metrics.request_goodput))
print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
metrics.output_throughput))
if isinstance(metrics, BenchmarkMetrics):
print(
"{:<40} {:<10.2f}".format(
"Output token throughput (tok/s):", metrics.output_throughput
)
)
print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):",
metrics.total_token_throughput))
result = {
"duration": benchmark_duration,
"completed": metrics.completed,
"total_input_tokens": metrics.total_input,
"total_output_tokens": metrics.total_output,
"request_throughput": metrics.request_throughput,
"request_goodput":
metrics.request_goodput if goodput_config_dict else None,
"output_throughput": metrics.output_throughput,
"total_token_throughput": metrics.total_token_throughput,
"input_lens": [output.prompt_len for output in outputs],
"output_lens": actual_output_lens,
"ttfts": [output.ttft for output in outputs],
"itls": [output.itl for output in outputs],
"generated_texts": [output.generated_text for output in outputs],
"errors": [output.error for output in outputs],
}
if isinstance(metrics, BenchmarkMetrics):
result = {
"duration": benchmark_duration,
"completed": metrics.completed,
"total_input_tokens": metrics.total_input,
"total_output_tokens": metrics.total_output,
"request_throughput": metrics.request_throughput,
"request_goodput":
metrics.request_goodput if goodput_config_dict else None,
"output_throughput": metrics.output_throughput,
"total_token_throughput": metrics.total_token_throughput,
"input_lens": [output.prompt_len for output in outputs],
"output_lens": actual_output_lens,
"ttfts": [output.ttft for output in outputs],
"itls": [output.itl for output in outputs],
"generated_texts": [output.generated_text for output in outputs],
"errors": [output.error for output in outputs],
}
else:
result = {
"duration": benchmark_duration,
"completed": metrics.completed,
"total_input_tokens": metrics.total_input,
"request_throughput": metrics.request_throughput,
"total_token_throughput": metrics.total_token_throughput,
"input_lens": [output.prompt_len for output in outputs],
"errors": [output.error for output in outputs],
}
if rps_change_events:
result["rps_change_events"] = rps_change_events
@ -598,10 +691,11 @@ async def benchmark(
value))
result[f"p{p_word}_{metric_attribute_name}_ms"] = value
process_one_metric("ttft", "TTFT", "Time to First Token")
process_one_metric("tpot", "TPOT",
"Time per Output Token (excl. 1st token)")
process_one_metric("itl", "ITL", "Inter-token Latency")
if task_type == TaskType.GENERATION:
process_one_metric("ttft", "TTFT", "Time to First Token")
process_one_metric(
"tpot", "TPOT", "Time per Output Token (excl. 1st token)")
process_one_metric("itl", "ITL", "Inter-token Latency")
process_one_metric("e2el", "E2EL", "End-to-end Latency")
print("=" * 50)
@ -732,7 +826,8 @@ def add_cli_args(parser: argparse.ArgumentParser):
"initiated, this argument will control how many are actually allowed "
"to execute at a time. This means that when used in combination, the "
"actual request rate may be lower than specified with --request-rate, "
"if the server is not processing requests fast enough to keep up.")
"if the server is not processing requests fast enough to keep up.",
)
parser.add_argument(
"--model",
@ -743,8 +838,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--tokenizer",
type=str,
help=
"Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
help="Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
)
parser.add_argument("--use-beam-search", action="store_true")
parser.add_argument(
@ -968,6 +1062,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
def main(args: argparse.Namespace) -> dict[str, Any]:
return asyncio.run(main_async(args))
async def main_async(args: argparse.Namespace) -> dict[str, Any]:
print(args)
random.seed(args.seed)
@ -1046,32 +1141,32 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
gc.freeze()
benchmark_result = await benchmark(
endpoint_type=args.endpoint_type,
api_url=api_url,
base_url=base_url,
model_id=model_id,
model_name=model_name,
tokenizer=tokenizer,
input_requests=input_requests,
logprobs=args.logprobs,
request_rate=args.request_rate,
burstiness=args.burstiness,
disable_tqdm=args.disable_tqdm,
profile=args.profile,
selected_percentile_metrics=args.percentile_metrics.split(","),
selected_percentiles=[
float(p) for p in args.metric_percentiles.split(",")
],
ignore_eos=args.ignore_eos,
goodput_config_dict=goodput_config_dict,
max_concurrency=args.max_concurrency,
lora_modules=args.lora_modules,
extra_body=sampling_params,
ramp_up_strategy=args.ramp_up_strategy,
ramp_up_start_rps=args.ramp_up_start_rps,
ramp_up_end_rps=args.ramp_up_end_rps,
ready_check_timeout_sec=args.ready_check_timeout_sec,
)
endpoint_type=args.endpoint_type,
api_url=api_url,
base_url=base_url,
model_id=model_id,
model_name=model_name,
tokenizer=tokenizer,
input_requests=input_requests,
logprobs=args.logprobs,
request_rate=args.request_rate,
burstiness=args.burstiness,
disable_tqdm=args.disable_tqdm,
profile=args.profile,
selected_percentile_metrics=args.percentile_metrics.split(","),
selected_percentiles=[
float(p) for p in args.metric_percentiles.split(",")
],
ignore_eos=args.ignore_eos,
goodput_config_dict=goodput_config_dict,
max_concurrency=args.max_concurrency,
lora_modules=args.lora_modules,
extra_body=sampling_params,
ramp_up_strategy=args.ramp_up_strategy,
ramp_up_start_rps=args.ramp_up_start_rps,
ramp_up_end_rps=args.ramp_up_end_rps,
ready_check_timeout_sec=args.ready_check_timeout_sec,
)
# Save config and results to json
result_json: dict[str, Any] = {}
@ -1098,7 +1193,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
# Traffic
result_json["request_rate"] = (args.request_rate if args.request_rate
< float("inf") else "inf")
< float("inf") else "inf")
result_json["burstiness"] = args.burstiness
result_json["max_concurrency"] = args.max_concurrency
@ -1132,7 +1227,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
if args.max_concurrency is not None else "")
label = label or endpoint_type
if args.ramp_up_strategy is not None:
file_name = f"{label}-ramp-up-{args.ramp_up_strategy}-{args.ramp_up_start_rps}qps-{args.ramp_up_end_rps}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa
file_name = f"{label}-ramp-up-{args.ramp_up_strategy}-{args.ramp_up_start_rps}qps-{args.ramp_up_end_rps}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa
else:
file_name = f"{label}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa
if args.result_filename:
@ -1149,4 +1244,4 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
json.dump(result_json, outfile)
save_to_pytorch_benchmark_format(args, result_json, file_name)
return result_json
return result_json