[Benchmarks] add benchmark for embedding models (#23000)
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
@@ -73,7 +73,7 @@ class SampleRequest:
Represents a single inference request for benchmarking.
"""

prompt: Union[str, Any]
prompt: Union[str, list[str]]
prompt_len: int
expected_output_len: int
multi_modal_data: Optional[
@@ -409,6 +409,7 @@ class RandomDataset(BenchmarkDataset):
range_ratio: float = DEFAULT_RANGE_RATIO,
input_len: int = DEFAULT_INPUT_LEN,
output_len: int = DEFAULT_OUTPUT_LEN,
batchsize: int = 1,
**kwargs,
) -> list[SampleRequest]:

@@ -439,6 +440,21 @@ class RandomDataset(BenchmarkDataset):
request_id=request_id_prefix + str(i),
)
)
# only used for embeddings benchmark.
if batchsize > 1:
batch_requests = []
# Create batched requests
for i in range(0, num_requests, batchsize):
batch = requests[i : i + batchsize]
batch_requests.append(
SampleRequest(
prompt=[req.prompt for req in batch],
prompt_len=sum(req.prompt_len for req in batch),
expected_output_len=0,
request_id=request_id_prefix + str(i // batchsize),
)
)
requests = batch_requests
return requests

def get_prefix(
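
Note on the batching block above: when the batch size is greater than 1, consecutive sampled requests are folded into one SampleRequest whose prompt is a list of prompts, whose prompt_len is the sum of the members, and whose expected_output_len is 0 (embedding endpoints generate no tokens). A minimal, self-contained sketch of that grouping; the Req class below is a trimmed stand-in for SampleRequest, illustrative only and not part of this change:

from dataclasses import dataclass
from typing import Union

@dataclass
class Req:
    # Trimmed stand-in for SampleRequest: only the fields the batching touches.
    prompt: Union[str, list[str]]
    prompt_len: int
    expected_output_len: int

def batch_for_embeddings(requests: list[Req], batchsize: int) -> list[Req]:
    # Same grouping as the hunk above: slice into chunks of `batchsize`,
    # merge prompts into a list and sum the prompt lengths.
    batched = []
    for i in range(0, len(requests), batchsize):
        chunk = requests[i:i + batchsize]
        batched.append(Req(prompt=[r.prompt for r in chunk],
                           prompt_len=sum(r.prompt_len for r in chunk),
                           expected_output_len=0))
    return batched

reqs = [Req(prompt=f"p{i}", prompt_len=10 + i, expected_output_len=0) for i in range(5)]
for b in batch_for_embeddings(reqs, batchsize=2):
    print(b.prompt, b.prompt_len)  # ['p0', 'p1'] 21 / ['p2', 'p3'] 25 / ['p4'] 14
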
@@ -475,8 +491,8 @@ class RandomDataset(BenchmarkDataset):
input_high = math.ceil(real_input_len * (1 + range_ratio))
output_low = math.floor(output_len * (1 - range_ratio))
output_high = math.ceil(output_len * (1 + range_ratio))
# Ensure the lower bound for output length is at least 1 to
# prevent sampling 0 tokens.
# Ensure the lower bound for output length is at least 1 to
# prevent sampling 0 tokens.
output_low = max(output_low, 1)

if input_low > input_high:
@@ -506,7 +522,6 @@ class RandomDataset(BenchmarkDataset):
size=num_requests)
return input_lens, output_lens, offsets


def generate_token_sequence(
self,
*,
@@ -1105,6 +1120,13 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
"context length sampled from [input_len * (1 - range_ratio), "
"input_len * (1 + range_ratio)]."),
)
random_group.add_argument(
"--random-batch-size",
type=int,
default=1,
help=("Batch size for random sampling. "
"Only used for embeddings benchmark."),
)

# random multimodal dataset options
random_mm_group = parser.add_argument_group(
@@ -1196,8 +1218,6 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
),
)



hf_group = parser.add_argument_group("hf dataset options")
hf_group.add_argument("--hf-subset",
type=str,
@@ -1348,22 +1368,24 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
else:
# For datasets that follow a similar structure, use a mapping.
dataset_mapping = {
"sharegpt":
lambda: ShareGPTDataset(random_seed=args.seed,
dataset_path=args.dataset_path).sample(
tokenizer=tokenizer,
num_requests=args.num_prompts,
output_len=args.sharegpt_output_len,
request_id_prefix=args.request_id_prefix,
),
"burstgpt":
lambda: BurstGPTDataset(random_seed=args.seed,
dataset_path=args.dataset_path).
sample(tokenizer=tokenizer, num_requests=args.num_prompts,
request_id_prefix=args.request_id_prefix,),
"random":
lambda: RandomDataset(random_seed=args.seed,
dataset_path=args.dataset_path).sample(
"sharegpt": lambda: ShareGPTDataset(
random_seed=args.seed, dataset_path=args.dataset_path
).sample(
tokenizer=tokenizer,
num_requests=args.num_prompts,
output_len=args.sharegpt_output_len,
request_id_prefix=args.request_id_prefix,
),
"burstgpt": lambda: BurstGPTDataset(
random_seed=args.seed, dataset_path=args.dataset_path
).sample(
tokenizer=tokenizer,
num_requests=args.num_prompts,
request_id_prefix=args.request_id_prefix,
),
"random": lambda: RandomDataset(
random_seed=args.seed, dataset_path=args.dataset_path
).sample(
tokenizer=tokenizer,
num_requests=args.num_prompts,
prefix_len=args.random_prefix_len,
@@ -1371,6 +1393,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
output_len=args.random_output_len,
range_ratio=args.random_range_ratio,
request_id_prefix=args.request_id_prefix,
batchsize=args.random_batch_size,
),
"random-mm":
lambda: RandomMultiModalDataset(

@@ -69,8 +69,8 @@ async def async_request_openai_completions(
), "OpenAI Completions API URL must end with 'completions' or 'profile'."

payload = {
"model": request_func_input.model_name \
if request_func_input.model_name else request_func_input.model,
"model": request_func_input.model_name
if request_func_input.model_name else request_func_input.model,
"prompt": request_func_input.prompt,
"temperature": 0.0,
"repetition_penalty": 1.0,
@@ -135,7 +135,7 @@ async def async_request_openai_completions(
# Decoding phase
else:
output.itl.append(timestamp -
most_recent_timestamp)
most_recent_timestamp)

most_recent_timestamp = timestamp
generated_text += text or ""
@@ -254,7 +254,7 @@ async def async_request_openai_chat_completions(
# Decoding phase
else:
output.itl.append(timestamp -
most_recent_timestamp)
most_recent_timestamp)

generated_text += content or ""
elif usage := data.get("usage"):
@@ -394,12 +394,61 @@ async def async_request_openai_audio(
return output


async def async_request_openai_embeddings(
request_func_input: RequestFuncInput,
session: aiohttp.ClientSession,
pbar: Optional[tqdm] = None,
):
api_url = request_func_input.api_url
assert api_url.endswith(
"embeddings"
), "OpenAI Embeddings API URL must end with 'embeddings'."

headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
}

payload = {
"model": request_func_input.model,
"input": request_func_input.prompt,
}

output = RequestFuncOutput()
st = time.perf_counter()
try:
async with session.post(
url=api_url,
headers=headers,
json=payload
) as response:
if response.status == 200:
output.latency = time.perf_counter() - st
data = await response.json()
output.success = True
output.generated_text = ""
output.prompt_len = data.get(
"usage", {}).get(
"prompt_tokens", 0)
else:
output.success = False
output.error = response.reason or ""
except Exception as e:
output.success = False
output.error = str(e)

if pbar:
pbar.update(1)
return output


# TODO: Add more request functions for different API protocols.
ASYNC_REQUEST_FUNCS = {
"vllm": async_request_openai_completions,
"openai": async_request_openai_completions,
"openai-chat": async_request_openai_chat_completions,
"openai-audio": async_request_openai_audio,
"openai-embeddings": async_request_openai_embeddings,
}

OPENAI_COMPATIBLE_BACKENDS = [
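
The new request function above records only end-to-end latency and the prompt_tokens count from the usage block, since an embeddings response has no streamed tokens. A standalone sketch of the same round trip without the benchmark's RequestFuncInput/RequestFuncOutput plumbing; the server URL and model name are placeholders, not values taken from this change:

import asyncio
import os
import time

import aiohttp

async def time_embedding_request(prompts: list[str]) -> None:
    # Placeholders: point these at the server and embedding model under test.
    api_url = "http://localhost:8000/v1/embeddings"
    payload = {"model": "my-embedding-model", "input": prompts}
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY', '')}",
    }
    async with aiohttp.ClientSession() as session:
        st = time.perf_counter()
        async with session.post(url=api_url, headers=headers,
                                json=payload) as response:
            if response.status == 200:
                data = await response.json()
                latency = time.perf_counter() - st
                prompt_tokens = data.get("usage", {}).get("prompt_tokens", 0)
                # The same two numbers the benchmark keeps per request.
                print(f"ok: {prompt_tokens} prompt tokens, {latency:.3f}s")
            else:
                print(f"failed: {response.status} {response.reason}")

asyncio.run(time_embedding_request(["first text", "second text"]))
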

@@ -4,7 +4,7 @@ r"""Benchmark online serving throughput.

On the server side, run one of the following commands
to launch the vLLM OpenAI API server:
vllm serve <your_model> <engine arguments>
vllm serve <your_model> <engine arguments>

On the client side, run:
vllm bench serve \
@@ -26,6 +26,7 @@ import warnings
from collections.abc import AsyncGenerator, Iterable
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import Any, Literal, Optional

import aiohttp
@@ -46,6 +47,11 @@ from vllm.transformers_utils.tokenizer import get_tokenizer
MILLISECONDS_TO_SECONDS_CONVERSION = 1000


class TaskType(Enum):
GENERATION = "generation"
EMBEDDING = "embedding"


@dataclass
class BenchmarkMetrics:
completed: int
@@ -75,6 +81,16 @@ class BenchmarkMetrics:
std_e2el_ms: float
percentiles_e2el_ms: list[tuple[float, float]]

@dataclass
class EmbedBenchmarkMetrics:
completed: int
total_input: int
request_throughput: float
total_token_throughput: float
mean_e2el_ms: float
std_e2el_ms: float
median_e2el_ms: float
percentiles_e2el_ms: list[tuple[float, float]]

def _get_current_request_rate(
ramp_up_strategy: Optional[Literal["linear", "exponential"]],
@@ -146,11 +162,11 @@ async def get_request(
delay_ts = []
for request_index, request in enumerate(input_requests):
current_request_rate = _get_current_request_rate(ramp_up_strategy,
ramp_up_start_rps,
ramp_up_end_rps,
request_index,
total_requests,
request_rate)
ramp_up_start_rps,
ramp_up_end_rps,
request_index,
total_requests,
request_rate)
request_rates.append(current_request_rate)
if current_request_rate == float("inf"):
delay_ts.append(0)
@@ -160,7 +176,7 @@ async def get_request(
# Sample the request interval from the gamma distribution.
# If burstiness is 1, it follows exponential distribution.
delay_ts.append(np.random.gamma(shape=burstiness, scale=theta))


# Calculate the cumulative delay time from the first sent out requests.
for i in range(1, len(delay_ts)):
delay_ts[i] += delay_ts[i - 1]
@@ -170,11 +186,11 @@ async def get_request(
# logic would re-scale delay time to ensure the final delay_ts
# align with target_total_delay_s.
#
# NOTE: If we simply accumulate the random delta values
# from the gamma distribution, their sum would have 1-2% gap
# NOTE: If we simply accumulate the random delta values
# from the gamma distribution, their sum would have 1-2% gap
# from target_total_delay_s. The purpose of the following logic is to
# close the gap for stabilizing the throughput data
# from different random seeds.
# close the gap for stabilizing the throughput data
# from different random seeds.
target_total_delay_s = total_requests / request_rate
normalize_factor = target_total_delay_s / delay_ts[-1]
delay_ts = [delay * normalize_factor for delay in delay_ts]
@@ -189,6 +205,51 @@ async def get_request(
yield request, request_rates[request_index]


def calculate_metrics_for_embeddings(
outputs: list[RequestFuncOutput],
dur_s: float,
selected_percentiles: list[float]
) -> EmbedBenchmarkMetrics:
"""Calculate the metrics for the embedding requests.

Args:
outputs: The outputs of the requests.
dur_s: The duration of the benchmark.
selected_percentiles: The percentiles to select.

Returns:
The calculated benchmark metrics.
"""
total_input = 0
completed = 0
e2els: list[float] = []
for i in range(len(outputs)):
if outputs[i].success:
e2els.append(outputs[i].latency)
completed += 1
total_input += outputs[i].prompt_len

if completed == 0:
warnings.warn(
"All requests failed. This is likely due to a misconfiguration "
"on the benchmark arguments.",
stacklevel=2)
metrics = EmbedBenchmarkMetrics(
completed=completed,
total_input=total_input,
request_throughput=completed / dur_s,
total_token_throughput=total_input / dur_s,
mean_e2el_ms=np.mean(e2els or 0) * 1000,
std_e2el_ms=np.std(e2els or 0) * 1000,
median_e2el_ms=np.median(e2els or 0) * 1000,
percentiles_e2el_ms=[
(p, np.percentile(e2els or 0, p) * 1000)
for p in selected_percentiles
],
)
return metrics


def calculate_metrics(
input_requests: list[SampleRequest],
outputs: list[RequestFuncOutput],
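
Because embeddings report no output tokens, the aggregation above reduces to summary statistics over per-request end-to-end latencies plus two throughput ratios. A small worked example with made-up numbers (five successful requests of 512 prompt tokens each, over a 2-second window):

import numpy as np

e2els = [0.031, 0.034, 0.029, 0.040, 0.036]  # per-request latencies in seconds
dur_s = 2.0
completed = len(e2els)
total_input = 5 * 512

print(f"request throughput: {completed / dur_s:.1f} req/s")        # 2.5 req/s
print(f"total token throughput: {total_input / dur_s:.0f} tok/s")  # 1280 tok/s
print(f"mean E2EL: {np.mean(e2els) * 1000:.1f} ms")                 # 34.0 ms
print(f"p99 E2EL: {np.percentile(e2els, 99) * 1000:.1f} ms")
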
@@ -334,8 +395,16 @@ async def benchmark(
ramp_up_end_rps: Optional[int] = None,
ready_check_timeout_sec: int = 600,
):
task_type = (
TaskType.EMBEDDING
if api_url.endswith("/v1/embeddings")
else TaskType.GENERATION
)
if endpoint_type in ASYNC_REQUEST_FUNCS:
request_func = ASYNC_REQUEST_FUNCS[endpoint_type]
if task_type == TaskType.EMBEDDING:
request_func = ASYNC_REQUEST_FUNCS["openai-embeddings"]
else:
request_func = ASYNC_REQUEST_FUNCS[endpoint_type]
else:
raise ValueError(f"Unknown endpoint_type: {endpoint_type}")
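
The embedding path is selected from the URL suffix rather than a dedicated flag: any api_url ending in /v1/embeddings is routed to the openai-embeddings request function, whatever endpoint type was configured. A standalone sketch of that selection rule; REQUEST_FUNCS below is a trimmed stand-in for the real ASYNC_REQUEST_FUNCS table:

from enum import Enum

class TaskType(Enum):
    GENERATION = "generation"
    EMBEDDING = "embedding"

# Trimmed stand-in for the real ASYNC_REQUEST_FUNCS registry.
REQUEST_FUNCS = {"openai": "async_request_openai_completions",
                 "openai-embeddings": "async_request_openai_embeddings"}

def pick_request_func(endpoint_type: str, api_url: str) -> str:
    task = (TaskType.EMBEDDING if api_url.endswith("/v1/embeddings")
            else TaskType.GENERATION)
    if endpoint_type not in REQUEST_FUNCS:
        raise ValueError(f"Unknown endpoint_type: {endpoint_type}")
    key = "openai-embeddings" if task is TaskType.EMBEDDING else endpoint_type
    return REQUEST_FUNCS[key]

print(pick_request_func("openai", "http://localhost:8000/v1/completions"))
print(pick_request_func("openai", "http://localhost:8000/v1/embeddings"))
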

@@ -421,8 +490,8 @@ async def benchmark(
if profile_output.success:
print("Profiler started")

distribution = ("Poisson process" if burstiness == 1.0
else "Gamma distribution")
distribution = ("Poisson process" if burstiness == 1.0
else "Gamma distribution")

if ramp_up_strategy is not None:
print(f"Traffic ramp-up strategy: {ramp_up_strategy}.")
@@ -449,7 +518,7 @@ async def benchmark(
session=session,
pbar=pbar)
async with semaphore:
return await request_func(request_func_input=request_func_input,
return await request_func(request_func_input=request_func_input,
session=session,
pbar=pbar)

@@ -513,14 +582,22 @@ async def benchmark(

benchmark_duration = time.perf_counter() - benchmark_start_time

metrics, actual_output_lens = calculate_metrics(
input_requests=input_requests,
outputs=outputs,
dur_s=benchmark_duration,
tokenizer=tokenizer,
selected_percentiles=selected_percentiles,
goodput_config_dict=goodput_config_dict,
)
if task_type == TaskType.GENERATION:
metrics, actual_output_lens = calculate_metrics(
input_requests=input_requests,
outputs=outputs,
dur_s=benchmark_duration,
tokenizer=tokenizer,
selected_percentiles=selected_percentiles,
goodput_config_dict=goodput_config_dict,
)
else:
metrics = calculate_metrics_for_embeddings(
outputs=outputs,
dur_s=benchmark_duration,
selected_percentiles=selected_percentiles,
)
actual_output_lens = 0

print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='='))
print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
@@ -529,39 +606,55 @@ async def benchmark(
max_concurrency))
if request_rate != float('inf'):
print("{:<40} {:<10.2f}".format("Request rate configured (RPS):",
request_rate ))
request_rate))
print("{:<40} {:<10.2f}".format("Benchmark duration (s):",
benchmark_duration))
print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
print("{:<40} {:<10}".format("Total generated tokens:",
metrics.total_output))
if isinstance(metrics, BenchmarkMetrics):
print("{:<40} {:<10}".format(
"Total generated tokens:", metrics.total_output))
print("{:<40} {:<10.2f}".format("Request throughput (req/s):",
metrics.request_throughput))
if goodput_config_dict:
print("{:<40} {:<10.2f}".format("Request goodput (req/s):",
metrics.request_goodput))
print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
metrics.output_throughput))
if isinstance(metrics, BenchmarkMetrics):
print(
"{:<40} {:<10.2f}".format(
"Output token throughput (tok/s):", metrics.output_throughput
)
)
print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):",
metrics.total_token_throughput))

result = {
"duration": benchmark_duration,
"completed": metrics.completed,
"total_input_tokens": metrics.total_input,
"total_output_tokens": metrics.total_output,
"request_throughput": metrics.request_throughput,
"request_goodput":
metrics.request_goodput if goodput_config_dict else None,
"output_throughput": metrics.output_throughput,
"total_token_throughput": metrics.total_token_throughput,
"input_lens": [output.prompt_len for output in outputs],
"output_lens": actual_output_lens,
"ttfts": [output.ttft for output in outputs],
"itls": [output.itl for output in outputs],
"generated_texts": [output.generated_text for output in outputs],
"errors": [output.error for output in outputs],
}
if isinstance(metrics, BenchmarkMetrics):
result = {
"duration": benchmark_duration,
"completed": metrics.completed,
"total_input_tokens": metrics.total_input,
"total_output_tokens": metrics.total_output,
"request_throughput": metrics.request_throughput,
"request_goodput":
metrics.request_goodput if goodput_config_dict else None,
"output_throughput": metrics.output_throughput,
"total_token_throughput": metrics.total_token_throughput,
"input_lens": [output.prompt_len for output in outputs],
"output_lens": actual_output_lens,
"ttfts": [output.ttft for output in outputs],
"itls": [output.itl for output in outputs],
"generated_texts": [output.generated_text for output in outputs],
"errors": [output.error for output in outputs],
}
else:
result = {
"duration": benchmark_duration,
"completed": metrics.completed,
"total_input_tokens": metrics.total_input,
"request_throughput": metrics.request_throughput,
"total_token_throughput": metrics.total_token_throughput,
"input_lens": [output.prompt_len for output in outputs],
"errors": [output.error for output in outputs],
}

if rps_change_events:
result["rps_change_events"] = rps_change_events
@@ -598,10 +691,11 @@ async def benchmark(
value))
result[f"p{p_word}_{metric_attribute_name}_ms"] = value

process_one_metric("ttft", "TTFT", "Time to First Token")
process_one_metric("tpot", "TPOT",
"Time per Output Token (excl. 1st token)")
process_one_metric("itl", "ITL", "Inter-token Latency")
if task_type == TaskType.GENERATION:
process_one_metric("ttft", "TTFT", "Time to First Token")
process_one_metric(
"tpot", "TPOT", "Time per Output Token (excl. 1st token)")
process_one_metric("itl", "ITL", "Inter-token Latency")
process_one_metric("e2el", "E2EL", "End-to-end Latency")

print("=" * 50)
@@ -732,7 +826,8 @@ def add_cli_args(parser: argparse.ArgumentParser):
"initiated, this argument will control how many are actually allowed "
"to execute at a time. This means that when used in combination, the "
"actual request rate may be lower than specified with --request-rate, "
"if the server is not processing requests fast enough to keep up.")
"if the server is not processing requests fast enough to keep up.",
)

parser.add_argument(
"--model",
@@ -743,8 +838,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--tokenizer",
type=str,
help=
"Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
help="Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
)
parser.add_argument("--use-beam-search", action="store_true")
parser.add_argument(
@@ -968,6 +1062,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
def main(args: argparse.Namespace) -> dict[str, Any]:
return asyncio.run(main_async(args))


async def main_async(args: argparse.Namespace) -> dict[str, Any]:
print(args)
random.seed(args.seed)
@@ -1046,32 +1141,32 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
gc.freeze()

benchmark_result = await benchmark(
endpoint_type=args.endpoint_type,
api_url=api_url,
base_url=base_url,
model_id=model_id,
model_name=model_name,
tokenizer=tokenizer,
input_requests=input_requests,
logprobs=args.logprobs,
request_rate=args.request_rate,
burstiness=args.burstiness,
disable_tqdm=args.disable_tqdm,
profile=args.profile,
selected_percentile_metrics=args.percentile_metrics.split(","),
selected_percentiles=[
float(p) for p in args.metric_percentiles.split(",")
],
ignore_eos=args.ignore_eos,
goodput_config_dict=goodput_config_dict,
max_concurrency=args.max_concurrency,
lora_modules=args.lora_modules,
extra_body=sampling_params,
ramp_up_strategy=args.ramp_up_strategy,
ramp_up_start_rps=args.ramp_up_start_rps,
ramp_up_end_rps=args.ramp_up_end_rps,
ready_check_timeout_sec=args.ready_check_timeout_sec,
)
endpoint_type=args.endpoint_type,
api_url=api_url,
base_url=base_url,
model_id=model_id,
model_name=model_name,
tokenizer=tokenizer,
input_requests=input_requests,
logprobs=args.logprobs,
request_rate=args.request_rate,
burstiness=args.burstiness,
disable_tqdm=args.disable_tqdm,
profile=args.profile,
selected_percentile_metrics=args.percentile_metrics.split(","),
selected_percentiles=[
float(p) for p in args.metric_percentiles.split(",")
],
ignore_eos=args.ignore_eos,
goodput_config_dict=goodput_config_dict,
max_concurrency=args.max_concurrency,
lora_modules=args.lora_modules,
extra_body=sampling_params,
ramp_up_strategy=args.ramp_up_strategy,
ramp_up_start_rps=args.ramp_up_start_rps,
ramp_up_end_rps=args.ramp_up_end_rps,
ready_check_timeout_sec=args.ready_check_timeout_sec,
)

# Save config and results to json
result_json: dict[str, Any] = {}
@@ -1098,7 +1193,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:

# Traffic
result_json["request_rate"] = (args.request_rate if args.request_rate
< float("inf") else "inf")
< float("inf") else "inf")
result_json["burstiness"] = args.burstiness
result_json["max_concurrency"] = args.max_concurrency

@@ -1132,7 +1227,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
if args.max_concurrency is not None else "")
label = label or endpoint_type
if args.ramp_up_strategy is not None:
file_name = f"{label}-ramp-up-{args.ramp_up_strategy}-{args.ramp_up_start_rps}qps-{args.ramp_up_end_rps}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa
file_name = f"{label}-ramp-up-{args.ramp_up_strategy}-{args.ramp_up_start_rps}qps-{args.ramp_up_end_rps}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa
else:
file_name = f"{label}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa
if args.result_filename:
@@ -1149,4 +1244,4 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
json.dump(result_json, outfile)
save_to_pytorch_benchmark_format(args, result_json, file_name)

return result_json
return result_json