[benchmark] add peak throughput metrics and plot (#23867)

Signed-off-by: simon-mo <simon.mo@hey.com>
Author: Simon Mo
Committed by: GitHub
Date: 2025-09-17 22:30:02 -07:00
Parent: b7433ca1a4
Commit: a904ea78ea
2 changed files with 134 additions and 69 deletions


@@ -89,6 +89,7 @@ class RequestFuncOutput:
tpot: float = 0.0 # avg next-token latencies
prompt_len: int = 0
error: str = ""
start_time: float = 0.0
async def async_request_openai_completions(
@@ -140,6 +141,7 @@ async def async_request_openai_completions(
generated_text = ""
st = time.perf_counter()
output.start_time = st
most_recent_timestamp = st
try:
async with session.post(url=api_url, json=payload,
@@ -272,6 +274,7 @@ async def async_request_openai_chat_completions(
generated_text = ""
ttft = 0.0
st = time.perf_counter()
output.start_time = st
most_recent_timestamp = st
try:
async with session.post(url=api_url, json=payload,
@@ -396,6 +399,7 @@ async def async_request_openai_audio(
generated_text = ""
ttft = 0.0
st = time.perf_counter()
output.start_time = st
most_recent_timestamp = st
try:
async with session.post(url=api_url,
@@ -475,6 +479,7 @@ async def async_request_openai_embeddings(
output = RequestFuncOutput()
st = time.perf_counter()
output.start_time = st
try:
async with session.post(
url=api_url,


@@ -18,9 +18,11 @@ On the client side, run:
import argparse
import asyncio
import gc
import importlib.util
import json
import os
import random
import shutil
import time
import warnings
from collections.abc import AsyncGenerator, Iterable
@@ -46,6 +48,9 @@ from vllm.transformers_utils.tokenizer import get_tokenizer
MILLISECONDS_TO_SECONDS_CONVERSION = 1000
TERM_PLOTLIB_AVAILABLE = ((importlib.util.find_spec("termplotlib") is not None)
and (shutil.which("gnuplot") is not None))
class TaskType(Enum):
GENERATION = "generation"
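The new TERM_PLOTLIB_AVAILABLE flag above gates the terminal plots added to calculate_metrics later in this diff: plotting needs both the termplotlib package and a gnuplot binary on the PATH. A minimal sketch of how the flag is meant to be used, reusing only the plot calls that appear further down (the helper name is illustrative):

# Sketch: gate terminal plotting on termplotlib + gnuplot being installed.
import importlib.util
import shutil

import numpy as np

TERM_PLOTLIB_AVAILABLE = ((importlib.util.find_spec("termplotlib") is not None)
                          and (shutil.which("gnuplot") is not None))

def plot_tokens_per_second(tokens_per_second: np.ndarray) -> None:
    # Illustrative helper, not part of this diff.
    if not TERM_PLOTLIB_AVAILABLE:
        print("tip: install termplotlib and gnuplot to plot the metrics")
        return
    import termplotlib as tpl
    fig = tpl.figure()
    fig.plot(np.arange(len(tokens_per_second)),
             tokens_per_second,
             title="Output tokens per second")
    fig.show()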
@@ -80,18 +85,23 @@ class BenchmarkMetrics:
median_e2el_ms: float
std_e2el_ms: float
percentiles_e2el_ms: list[tuple[float, float]]
# Max output tokens per second and concurrent requests at that peak
max_output_tokens_per_s: float
max_concurrent_requests: int
@dataclass
class EmbedBenchmarkMetrics:
completed: int
total_input: int
request_throughput: float
total_token_throughput :float
total_token_throughput: float
mean_e2el_ms: float
std_e2el_ms: float
median_e2el_ms: float
percentiles_e2el_ms: float
def _get_current_request_rate(
ramp_up_strategy: Optional[Literal["linear", "exponential"]],
ramp_up_start_rps: Optional[int],
@@ -150,8 +160,8 @@ async def get_request(
assert burstiness > 0, (
f"A positive burstiness factor is expected, but given {burstiness}.")
# Convert to list to get length for ramp-up calculations
if isinstance(input_requests, Iterable) and not isinstance(
input_requests, list):
if isinstance(input_requests,
Iterable) and not isinstance(input_requests, list):
input_requests = list(input_requests)
total_requests = len(input_requests)
@@ -161,12 +171,9 @@ async def get_request(
request_rates = []
delay_ts = []
for request_index, request in enumerate(input_requests):
current_request_rate = _get_current_request_rate(ramp_up_strategy,
ramp_up_start_rps,
ramp_up_end_rps,
request_index,
total_requests,
request_rate)
current_request_rate = _get_current_request_rate(
ramp_up_strategy, ramp_up_start_rps, ramp_up_end_rps,
request_index, total_requests, request_rate)
request_rates.append(current_request_rate)
if current_request_rate == float("inf"):
delay_ts.append(0)
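The hunk above only reflows the call site; the body of _get_current_request_rate is not part of this diff. For orientation, a hedged sketch of how a linear or exponential ramp from --ramp-up-start-rps to --ramp-up-end-rps could be interpolated per request (illustrative, not the benchmark's actual implementation):

# Illustrative only: one plausible interpolation between start and end RPS.
def illustrative_ramp_up_rate(strategy, start_rps, end_rps, request_index,
                              total_requests, default_rate):
    if strategy is None:
        return default_rate
    progress = request_index / max(total_requests - 1, 1)
    if strategy == "linear":
        return start_rps + (end_rps - start_rps) * progress
    if strategy == "exponential":
        # Geometric interpolation; assumes start_rps > 0.
        return start_rps * (end_rps / start_rps) ** progress
    raise ValueError(f"Unknown ramp-up strategy: {strategy}")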
@@ -206,10 +213,8 @@ async def get_request(
def calculate_metrics_for_embeddings(
outputs: list[RequestFuncOutput],
dur_s: float,
selected_percentiles: list[float]
) -> EmbedBenchmarkMetrics:
outputs: list[RequestFuncOutput], dur_s: float,
selected_percentiles: list[float]) -> EmbedBenchmarkMetrics:
"""Calculate the metrics for the embedding requests.
Args:
@@ -242,10 +247,8 @@ def calculate_metrics_for_embeddings(
mean_e2el_ms=np.mean(e2els or 0) * 1000,
std_e2el_ms=np.std(e2els or 0) * 1000,
median_e2el_ms=np.median(e2els or 0) * 1000,
percentiles_e2el_ms=[
(p, np.percentile(e2els or 0, p) * 1000)
for p in selected_percentiles
],
percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000)
for p in selected_percentiles],
)
return metrics
@@ -336,6 +339,67 @@ def calculate_metrics(
"All requests failed. This is likely due to a misconfiguration "
"on the benchmark arguments.",
stacklevel=2)
# Calculate max output tokens per second metric
max_output_tokens_per_s = 0.0
max_concurrent_requests = 0
# Find the time range across all successful requests
successful_outputs = [output for output in outputs if output.success]
if successful_outputs:
min_start_time = min(output.start_time
for output in successful_outputs)
max_end_time = max(output.start_time + output.latency
for output in successful_outputs)
# Create second buckets (ceiling to ensure we capture all time)
duration_seconds = int(np.ceil(max_end_time - min_start_time)) + 1
tokens_per_second = np.zeros(duration_seconds)
concurrent_requests_per_second = np.zeros(duration_seconds)
for i, output in enumerate(successful_outputs):
# Calculate token generation timestamp using
# start_time, ttft, and itl
token_times = [output.start_time + output.ttft]
current_time = token_times[0]
for itl_value in output.itl:
current_time += itl_value
token_times.append(current_time)
# Add tokens to second buckets
for token_time in token_times:
second_bucket = int(token_time - min_start_time)
if 0 <= second_bucket < duration_seconds:
tokens_per_second[second_bucket] += 1
# Track concurrent requests for each second this request was active
request_start_second = int(output.start_time - min_start_time)
request_end_second = int((output.start_time + output.latency) -
min_start_time)
for second in range(request_start_second, request_end_second + 1):
concurrent_requests_per_second[second] += 1
# Find the maximum tokens per second and corresponding
# concurrent requests
if len(tokens_per_second) > 0:
max_output_tokens_per_s = float(np.max(tokens_per_second))
max_concurrent_requests = int(
np.max(concurrent_requests_per_second))
if TERM_PLOTLIB_AVAILABLE:
import termplotlib as tpl
fig = tpl.figure()
fig.plot(np.arange(len(tokens_per_second)),
tokens_per_second,
title="Output tokens per second")
fig.plot(np.arange(len(concurrent_requests_per_second)),
concurrent_requests_per_second,
title="Concurrent requests per second")
fig.show()
else:
print("tip: install termplotlib and gnuplot to plot the metrics")
metrics = BenchmarkMetrics(
completed=completed,
total_input=total_input,
@@ -365,6 +429,8 @@ def calculate_metrics(
median_e2el_ms=np.median(e2els or 0) * 1000,
percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000)
for p in selected_percentiles],
max_output_tokens_per_s=max_output_tokens_per_s,
max_concurrent_requests=max_concurrent_requests,
)
return metrics, actual_output_lens
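The block added above is the core of this change: each output token's timestamp is reconstructed from start_time + ttft + the cumulative inter-token latencies, tokens are dropped into one-second buckets, and per-second request concurrency is tracked alongside; the two maxima become max_output_tokens_per_s and max_concurrent_requests. A self-contained sketch of the same bucketing, with standalone names rather than the benchmark's own types:

# Standalone sketch of the per-second bucketing used above (illustrative only).
from dataclasses import dataclass, field

import numpy as np

@dataclass
class FinishedRequest:
    start_time: float   # perf_counter() when the request was sent
    ttft: float         # time to first token, in seconds
    itl: list[float] = field(default_factory=list)  # inter-token latencies
    latency: float = 0.0  # end-to-end latency, in seconds

def peak_throughput(requests: list[FinishedRequest]) -> tuple[float, int]:
    """Return (max output tokens/s, max concurrent requests in any one second)."""
    if not requests:
        return 0.0, 0
    min_start = min(r.start_time for r in requests)
    max_end = max(r.start_time + r.latency for r in requests)
    n_buckets = int(np.ceil(max_end - min_start)) + 1
    tokens = np.zeros(n_buckets)
    concurrent = np.zeros(n_buckets)
    for r in requests:
        # Token timestamps: first token at ttft, then one per inter-token latency.
        t = r.start_time + r.ttft
        token_times = [t]
        for gap in r.itl:
            t += gap
            token_times.append(t)
        for tt in token_times:
            bucket = int(tt - min_start)
            if 0 <= bucket < n_buckets:
                tokens[bucket] += 1
        # Mark every second during which the request was active.
        start_s = int(r.start_time - min_start)
        end_s = int(r.start_time + r.latency - min_start)
        for s in range(start_s, end_s + 1):
            concurrent[s] += 1
    return float(tokens.max()), int(concurrent.max())

# Example: two overlapping requests, three tokens each.
reqs = [
    FinishedRequest(start_time=0.0, ttft=0.2, itl=[0.1, 0.1], latency=0.4),
    FinishedRequest(start_time=0.1, ttft=0.3, itl=[0.2, 0.2], latency=0.8),
]
print(peak_throughput(reqs))  # -> (6.0, 2): all six tokens land in second 0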
@@ -396,11 +462,8 @@ async def benchmark(
ramp_up_end_rps: Optional[int] = None,
ready_check_timeout_sec: int = 600,
):
task_type = (
TaskType.EMBEDDING
if api_url.endswith("/v1/embeddings")
else TaskType.GENERATION
)
task_type = (TaskType.EMBEDDING if api_url.endswith("/v1/embeddings") else
TaskType.GENERATION)
if endpoint_type in ASYNC_REQUEST_FUNCS:
if task_type == TaskType.EMBEDDING:
request_func = ASYNC_REQUEST_FUNCS["openai-embeddings"]
@@ -435,14 +498,10 @@ async def benchmark(
input_requests[0].multi_modal_data,
)
assert (
test_mm_content is None
or isinstance(test_mm_content, dict)
or (
isinstance(test_mm_content, list)
and all(isinstance(item, dict) for item in test_mm_content)
)
), "multi_modal_data must be a dict or list[dict]"
assert (test_mm_content is None or isinstance(test_mm_content, dict)
or (isinstance(test_mm_content, list)
and all(isinstance(item, dict) for item in test_mm_content))
), "multi_modal_data must be a dict or list[dict]"
test_input = RequestFuncInput(
model=model_id,
model_name=model_name,
@@ -488,13 +547,13 @@ async def benchmark(
ignore_eos=ignore_eos,
extra_headers=extra_headers,
extra_body=extra_body)
profile_output = await request_func(
request_func_input=profile_input, session=session)
profile_output = await request_func(request_func_input=profile_input,
session=session)
if profile_output.success:
print("Profiler started")
distribution = ("Poisson process" if burstiness == 1.0
else "Gamma distribution")
distribution = ("Poisson process"
if burstiness == 1.0 else "Gamma distribution")
if ramp_up_strategy is not None:
print(f"Traffic ramp-up strategy: {ramp_up_strategy}.")
@@ -562,18 +621,20 @@ async def benchmark(
req_lora_module = next(lora_modules)
req_model_id, req_model_name = req_lora_module, req_lora_module
request_func_input = RequestFuncInput(model=req_model_id,
model_name=req_model_name,
prompt=prompt,
api_url=api_url,
prompt_len=prompt_len,
output_len=output_len,
logprobs=logprobs,
multi_modal_content=mm_content,
ignore_eos=ignore_eos,
extra_headers=extra_headers,
extra_body=extra_body,
request_id=request_id,)
request_func_input = RequestFuncInput(
model=req_model_id,
model_name=req_model_name,
prompt=prompt,
api_url=api_url,
prompt_len=prompt_len,
output_len=output_len,
logprobs=logprobs,
multi_modal_content=mm_content,
ignore_eos=ignore_eos,
extra_headers=extra_headers,
extra_body=extra_body,
request_id=request_id,
)
tasks.append(
asyncio.create_task(
limited_request_func(request_func_input=request_func_input,
@@ -615,19 +676,21 @@ async def benchmark(
benchmark_duration))
print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
if isinstance(metrics, BenchmarkMetrics):
print("{:<40} {:<10}".format(
"Total generated tokens:", metrics.total_output))
print("{:<40} {:<10}".format("Total generated tokens:",
metrics.total_output))
print("{:<40} {:<10.2f}".format("Request throughput (req/s):",
metrics.request_throughput))
if goodput_config_dict:
print("{:<40} {:<10.2f}".format("Request goodput (req/s):",
metrics.request_goodput))
if isinstance(metrics, BenchmarkMetrics):
print(
"{:<40} {:<10.2f}".format(
"Output token throughput (tok/s):", metrics.output_throughput
)
)
print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
metrics.output_throughput))
print("{:<40} {:<10.2f}".format(
"Peak output token throughput (tok/s):",
metrics.max_output_tokens_per_s))
print("{:<40} {:<10.2f}".format("Peak concurrent requests:",
metrics.max_concurrent_requests))
print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):",
metrics.total_token_throughput))
@@ -648,6 +711,8 @@ async def benchmark(
"itls": [output.itl for output in outputs],
"generated_texts": [output.generated_text for output in outputs],
"errors": [output.error for output in outputs],
"max_output_tokens_per_s": metrics.max_output_tokens_per_s,
"max_concurrent_requests": metrics.max_concurrent_requests,
}
else:
result = {
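With the two new keys in the result dictionary, downstream tooling can read the peak numbers alongside the existing per-request lists. A small sketch of consuming them from a saved result file (the file name is illustrative and assumes result saving was enabled):

# Illustrative only: reading the new peak metrics from a saved benchmark result.
import json

with open("benchmark_result.json") as f:  # illustrative path
    result = json.load(f)

print("Peak output token throughput (tok/s):", result["max_output_tokens_per_s"])
print("Peak concurrent requests:", result["max_concurrent_requests"])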
@@ -697,8 +762,8 @@ async def benchmark(
if task_type == TaskType.GENERATION:
process_one_metric("ttft", "TTFT", "Time to First Token")
process_one_metric(
"tpot", "TPOT", "Time per Output Token (excl. 1st token)")
process_one_metric("tpot", "TPOT",
"Time per Output Token (excl. 1st token)")
process_one_metric("itl", "ITL", "Inter-token Latency")
process_one_metric("e2el", "E2EL", "End-to-end Latency")
@@ -714,8 +779,8 @@ async def benchmark(
output_len=test_output_len,
logprobs=logprobs,
)
profile_output = await request_func(
request_func_input=profile_input, session=session)
profile_output = await request_func(request_func_input=profile_input,
session=session)
if profile_output.success:
print("Profiler stopped")
@@ -851,7 +916,8 @@ def add_cli_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--tokenizer",
type=str,
help="Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
help=
"Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
)
parser.add_argument("--use-beam-search", action="store_true")
parser.add_argument(
@@ -982,7 +1048,6 @@ def add_cli_args(parser: argparse.ArgumentParser):
help="Specify the prefix of request id.",
)
sampling_group = parser.add_argument_group("sampling parameters")
sampling_group.add_argument(
"--top-p",
@@ -1047,8 +1112,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
help="The ramp-up strategy. This would be used to "
"ramp up the request rate from initial RPS to final "
"RPS rate (specified by --ramp-up-start-rps and "
"--ramp-up-end-rps.) over the duration of the benchmark."
)
"--ramp-up-end-rps.) over the duration of the benchmark.")
parser.add_argument(
"--ramp-up-start-rps",
type=int,
@@ -1087,13 +1151,11 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
raise ValueError(
"When using ramp-up, do not specify --request-rate. "
"The request rate will be controlled by ramp-up parameters. "
"Please remove the --request-rate argument."
)
"Please remove the --request-rate argument.")
if args.ramp_up_start_rps is None or args.ramp_up_end_rps is None:
raise ValueError(
"When using --ramp-up-strategy, both --ramp-up-start-rps and "
"--ramp-up-end-rps must be specified"
)
"--ramp-up-end-rps must be specified")
if args.ramp_up_start_rps < 0 or args.ramp_up_end_rps < 0:
raise ValueError("Ramp-up start and end RPS must be non-negative")
if args.ramp_up_start_rps > args.ramp_up_end_rps:
@@ -1127,8 +1189,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
headers[kvstring[0].strip()] = kvstring[1].strip()
else:
raise ValueError(
"Invalid header format. Please use KEY=VALUE format."
)
"Invalid header format. Please use KEY=VALUE format.")
tokenizer = get_tokenizer(tokenizer_id,
tokenizer_mode=tokenizer_mode,
@@ -1215,8 +1276,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
result_json[kvstring[0].strip()] = kvstring[1].strip()
else:
raise ValueError(
"Invalid metadata format. Please use KEY=VALUE format."
)
"Invalid metadata format. Please use KEY=VALUE format.")
# Traffic
result_json["request_rate"] = (args.request_rate if args.request_rate