Mirror of https://github.com/vllm-project/vllm-ascend.git (synced 2025-10-20 13:43:53 +08:00)
[Benchmark] Refactor perf script to use benchmark cli (#1524)
### What this PR does / why we need it?
Since the `vllm bench` CLI is now mature enough for production use (it supports more datasets), we no longer need to copy vLLM's benchmark code. With vLLM installed, we can use the benchmark CLI directly.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
CI passed

---------

Signed-off-by: wangli <wangli858794774@gmail.com>
@@ -1,5 +1,5 @@
# Introduction

This document outlines the benchmarking methodology for vllm-ascend, aimed at evaluating the performance under a variety of workloads. The primary goal is to help developers assess whether their pull requests improve or degrade vllm-ascend's performance. To maintain alignment with vLLM, we use the [benchmark](https://github.com/vllm-project/vllm/tree/main/benchmarks) script provided by the vllm project.
This document outlines the benchmarking methodology for vllm-ascend, aimed at evaluating the performance under a variety of workloads. The primary goal is to help developers assess whether their pull requests improve or degrade vllm-ascend's performance.

# Overview

**Benchmarking Coverage**: We measure latency, throughput, and fixed-QPS serving on the Atlas800I A2 (see [quick_start](../docs/source/quick_start.md) for the list of supported devices), with different models (coming soon).
@@ -7,21 +7,21 @@ This document outlines the benchmarking methodology for vllm-ascend, aimed at ev

- Input length: 32 tokens.
- Output length: 128 tokens.
- Batch size: fixed (8).
- Models: Meta-Llama-3.1-8B-Instruct, Qwen2.5-7B-Instruct.
- Models: Qwen2.5-7B-Instruct, Qwen3-8B.
- Evaluation metrics: end-to-end latency (mean, median, p99).

- Throughput tests
- Input length: randomly sample 200 prompts from the ShareGPT dataset (with a fixed random seed).
- Output length: the corresponding output length of these 200 prompts.
- Batch size: dynamically determined by vllm to achieve maximum throughput.
- Models: Meta-Llama-3.1-8B-Instruct, Qwen2.5-7B-Instruct.
- Models: Qwen2.5-VL-7B-Instruct, Qwen2.5-7B-Instruct, Qwen3-8B.
- Evaluation metrics: throughput.
- Serving tests
- Input length: randomly sample 200 prompts from the ShareGPT dataset (with a fixed random seed).
- Output length: the corresponding output length of these 200 prompts.
- Batch size: dynamically determined by vllm and the arrival pattern of the requests.
- **Average QPS (query per second)**: 1, 4, 16 and inf. QPS = inf means all requests arrive at once. For other QPS values, the arrival time of each query is determined by a random Poisson process (with a fixed random seed).
- Models: Meta-Llama-3.1-8B-Instruct, Qwen2.5-7B-Instruct.
- Models: Qwen2.5-VL-7B-Instruct, Qwen2.5-7B-Instruct, Qwen3-8B.
- Evaluation metrics: throughput, TTFT (time to first token, with mean, median and p99), ITL (inter-token latency, with mean, median and p99).

**Benchmarking Duration**: about 800 seconds per model.
@@ -38,20 +38,129 @@ Before running the benchmarks, ensure the following:
pip install -r benchmarks/requirements-bench.txt
```

- For performance benchmarks, it is recommended to set the [load-format](https://github.com/vllm-project/vllm-ascend/blob/5897dc5bbe321ca90c26225d0d70bff24061d04b/benchmarks/tests/latency-tests.json#L7) to `dummy`. This constructs random weights for the given model instead of downloading them from the internet, which greatly reduces benchmark time. Feel free to add your own models and parameters to the JSON to run your customized benchmarks.
- For performance benchmarks, it is recommended to set the [load-format](https://github.com/vllm-project/vllm-ascend/blob/5897dc5bbe321ca90c26225d0d70bff24061d04b/benchmarks/tests/latency-tests.json#L7) to `dummy`. This constructs random weights for the given model instead of downloading them from the internet, which greatly reduces benchmark time.
- If you want to run customized benchmarks, feel free to add your own models and parameters to the [JSON](https://github.com/vllm-project/vllm-ascend/tree/main/benchmarks/tests) files. Let's take `Qwen2.5-VL-7B-Instruct` as an example:

```json
[
    {
        "test_name": "serving_qwen2_5vl_7B_tp1",
        "qps_list": [
            1,
            4,
            16,
            "inf"
        ],
        "server_parameters": {
            "model": "Qwen/Qwen2.5-VL-7B-Instruct",
            "tensor_parallel_size": 1,
            "swap_space": 16,
            "disable_log_stats": "",
            "disable_log_requests": "",
            "trust_remote_code": "",
            "max_model_len": 16384
        },
        "client_parameters": {
            "model": "Qwen/Qwen2.5-VL-7B-Instruct",
            "backend": "openai-chat",
            "dataset_name": "hf",
            "hf_split": "train",
            "endpoint": "/v1/chat/completions",
            "dataset_path": "lmarena-ai/vision-arena-bench-v0.1",
            "num_prompts": 200
        }
    }
]
```

This JSON is parsed by the benchmark script into server parameters and client parameters. The configuration defines a test case named `serving_qwen2_5vl_7B_tp1`, designed to evaluate the performance of the `Qwen/Qwen2.5-VL-7B-Instruct` model under different request rates. The test includes both server and client parameters; for more details on the parameters, see the vllm benchmark [cli](https://github.com/vllm-project/vllm/tree/main/vllm/benchmarks). A rough CLI equivalent of this configuration is sketched after the parameter breakdown below.

- **Test Overview**
  - Test Name: serving_qwen2_5vl_7B_tp1
  - Queries Per Second (QPS): The test is run at four different QPS levels: 1, 4, 16, and inf (infinite load, typically used for stress testing).

- Server Parameters
  - Model: Qwen/Qwen2.5-VL-7B-Instruct
  - Tensor Parallelism: 1 (no model parallelism is used; the model runs on a single device or node)
  - Swap Space: 16 GB (used to handle memory overflow by swapping to disk)
  - disable_log_stats: disables logging of performance statistics.
  - disable_log_requests: disables logging of individual requests.
  - Trust Remote Code: enabled (allows execution of model-specific custom code)
  - Max Model Length: 16,384 tokens (maximum context length supported by the model)

- Client Parameters
  - Model: Qwen/Qwen2.5-VL-7B-Instruct (same as the server)
  - Backend: openai-chat (the client uses the OpenAI-compatible chat API format)
  - Dataset Source: Hugging Face (hf)
  - Dataset Split: train
  - Endpoint: /v1/chat/completions (the REST API endpoint to which chat requests are sent)
  - Dataset Path: lmarena-ai/vision-arena-bench-v0.1 (the benchmark dataset used for evaluation, hosted on Hugging Face)
  - Number of Prompts: 200 (the total number of prompts used during the test)
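
For orientation, the configuration above corresponds roughly to the following pair of commands. This is a minimal sketch that mirrors the "Online serving" example later in this guide; exact flag spellings depend on your vllm version, and underscores in the JSON keys become dashes on the command line:

```shell
# Server side, from "server_parameters"
vllm serve Qwen/Qwen2.5-VL-7B-Instruct \
--tensor-parallel-size 1 --swap-space 16 \
--max-model-len 16384 --trust-remote-code \
--disable-log-stats --disable-log-requests

# Client side, from "client_parameters"; repeat once per value in "qps_list"
vllm bench serve --model Qwen/Qwen2.5-VL-7B-Instruct \
--endpoint-type "openai-chat" --dataset-name hf \
--hf-split train --endpoint "/v1/chat/completions" \
--dataset-path "lmarena-ai/vision-arena-bench-v0.1" \
--num-prompts 200 \
--request-rate 1
```
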
## Run benchmarks

### Use benchmark script

The provided scripts automatically execute performance tests for serving, throughput, and latency. To start the benchmarking process, run the following command in the vllm-ascend root directory:

```
bash benchmarks/scripts/run-performance-benchmarks.sh
```

Once the script completes, you can find the results in the `benchmarks/results` folder. The output files may resemble the following:

```
|-- latency_llama8B_tp1.json
|-- serving_llama8B_tp1_sharegpt_qps_1.json
|-- serving_llama8B_tp1_sharegpt_qps_16.json
|-- serving_llama8B_tp1_sharegpt_qps_4.json
|-- serving_llama8B_tp1_sharegpt_qps_inf.json
|-- throughput_llama8B_tp1.json
.
|-- serving_qwen2_5_7B_tp1_qps_1.json
|-- serving_qwen2_5_7B_tp1_qps_16.json
|-- serving_qwen2_5_7B_tp1_qps_4.json
|-- serving_qwen2_5_7B_tp1_qps_inf.json
|-- latency_qwen2_5_7B_tp1.json
|-- throughput_qwen2_5_7B_tp1.json
```

These files contain detailed benchmarking results for further analysis.
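
Each result file is a flat JSON document, so a quick look at the headline numbers only needs `jq`. This is a sketch; the exact field names (for example `request_throughput` or `mean_ttft_ms`) depend on the vllm version that produced the files:

```shell
# Headline serving metrics for one QPS level
jq '{request_throughput, mean_ttft_ms, median_ttft_ms, p99_ttft_ms}' \
benchmarks/results/serving_qwen2_5_7B_tp1_qps_16.json

# Full throughput result
jq '.' benchmarks/results/throughput_qwen2_5_7B_tp1.json
```
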

### Use benchmark cli

For more flexible and customized use, a benchmark cli is also provided for running online/offline benchmarks.
Similarly, let's take the `Qwen2.5-VL-7B-Instruct` benchmark as an example:

#### Online serving

1. Launch the server:

```shell
vllm serve Qwen2.5-VL-7B-Instruct --max-model-len 16789
```

2. Run performance tests using the cli:

```shell
vllm bench serve --model Qwen2.5-VL-7B-Instruct \
--endpoint-type "openai-chat" --dataset-name hf \
--hf-split train --endpoint "/v1/chat/completions" \
--dataset-path "lmarena-ai/vision-arena-bench-v0.1" \
--num-prompts 200 \
--request-rate 16
```
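
To persist the client results in the same layout that the benchmark script produces, the serve benchmark accepts the result flags used by `run-performance-benchmarks.sh`. A sketch, with the file name chosen only as an example:

```shell
vllm bench serve --model Qwen2.5-VL-7B-Instruct \
--endpoint-type "openai-chat" --dataset-name hf \
--hf-split train --endpoint "/v1/chat/completions" \
--dataset-path "lmarena-ai/vision-arena-bench-v0.1" \
--num-prompts 200 --request-rate 16 \
--save-result --result-dir results \
--result-filename serving_qwen2_5vl_7B_tp1_qps_16.json
```
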

#### Offline

- **Throughput**

```shell
vllm bench throughput --output-json results/throughput_qwen2_5_7B_tp1.json \
--model Qwen/Qwen2.5-7B-Instruct --tensor-parallel-size 1 --load-format dummy \
--dataset-path /github/home/.cache/datasets/ShareGPT_V3_unfiltered_cleaned_split.json \
--num-prompts 200 --backend vllm
```

  (The ShareGPT file used here can be downloaded as shown in the note after these examples.)

- **Latency**

```shell
vllm bench latency --output-json results/latency_qwen2_5_7B_tp1.json \
--model Qwen/Qwen2.5-7B-Instruct --tensor-parallel-size 1 \
--load-format dummy --num-iters-warmup 5 --num-iters 15
```
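
The `--dataset-path` used by the throughput example above points at the ShareGPT file cached by CI. If you do not have it locally, the dataset can be downloaded once and reused; the URL below is the commonly used Hugging Face copy and is an assumption rather than something this repository pins:

```shell
# One-time download of the ShareGPT dataset into a local cache directory
mkdir -p ~/.cache/datasets
wget -O ~/.cache/datasets/ShareGPT_V3_unfiltered_cleaned_split.json \
https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
```
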

@@ -1,3 +1,4 @@
import os
from argparse import ArgumentParser

import libcst as cst
@@ -44,16 +45,22 @@ class StreamingFalseTransformer(cst.CSTTransformer):


def patch_file(path):
    with open(path, "r", encoding="utf-8") as f:
    abs_path = os.path.abspath(path)

    if not os.path.exists(abs_path):
        print(f"File not found: {abs_path}")
        return

    with open(abs_path, "r", encoding="utf-8") as f:
        source = f.read()

    module = cst.parse_module(source)
    modified = module.visit(StreamingFalseTransformer())

    with open(path, "w", encoding="utf-8") as f:
    with open(abs_path, "w", encoding="utf-8") as f:
        f.write(modified.code)

    print(f"Patched: {path}")
    print(f"Patched: {abs_path}")


if __name__ == '__main__':
@@ -61,8 +68,10 @@ if __name__ == '__main__':
        description=
        "Patch benchmark_dataset.py to set streaming=False in load_dataset calls"
    )
    parser.add_argument("--path",
                        type=str,
                        help="Path to the benchmark_dataset.py file")
    parser.add_argument(
        "--path",
        type=str,
        default="/vllm-workspace/vllm/vllm/benchmarks/datasets.py",
        help="Path to the benchmark_dataset.py file")
    args = parser.parse_args()
    patch_file(args.path)

@@ -54,13 +54,20 @@ json2args() {
}

wait_for_server() {
  # wait for vllm server to start
  # return 1 if vllm server crashes
  timeout 1200 bash -c '
    until curl -s -X GET localhost:8000/health; do
      echo "Waiting for vllm server to start..."
      sleep 1
    done' && return 0 || return 1
  local waited=0
  local timeout_sec=1200

  while (( waited < timeout_sec )); do
    if curl -s -X GET localhost:8000/health > /dev/null; then
      return 0
    fi
    echo "Waiting for vllm server to start..."
    sleep 1
    ((waited++))
  done

  echo "Timeout waiting for server"
  return 1
}

get_cur_npu_id() {
@@ -114,7 +121,7 @@ run_latency_tests() {
    latency_params=$(echo "$params" | jq -r '.parameters')
    latency_args=$(json2args "$latency_params")

    latency_command="python3 vllm_benchmarks/benchmark_latency.py \
    latency_command="vllm bench latency \
      --output-json $RESULTS_FOLDER/${test_name}.json \
      $latency_args"

@@ -157,7 +164,7 @@ run_throughput_tests() {
    throughput_params=$(echo "$params" | jq -r '.parameters')
    throughput_args=$(json2args "$throughput_params")

    throughput_command="python3 vllm_benchmarks/benchmark_throughput.py \
    throughput_command="vllm bench throughput \
      --output-json $RESULTS_FOLDER/${test_name}.json \
      $throughput_args"

@@ -243,7 +250,7 @@ run_serving_tests() {

      new_test_name=$test_name"_qps_"$qps

      client_command="python3 vllm_benchmarks/benchmark_serving.py \
      client_command="vllm bench serve \
        --save-result \
        --result-dir $RESULTS_FOLDER \
        --result-filename ${new_test_name}.json \
@@ -271,17 +278,11 @@ cleanup_on_error() {
  rm -rf $RESULTS_FOLDER
}

get_benchmarks_scripts() {
  git clone -b main --depth=1 https://github.com/vllm-project/vllm.git && \
    mv vllm/benchmarks vllm_benchmarks
  rm -rf ./vllm
}

main() {

  START_TIME=$(date +%s)
  check_npus

  python3 benchmarks/scripts/patch_benchmark_dataset.py

  # dependencies
  (which wget && which curl) || (apt-get update && apt-get install -y wget curl)
  (which jq) || (apt-get update && apt-get -y install jq)
@@ -298,8 +299,6 @@ main() {

  # prepare for benchmarking
  cd benchmarks || exit 1
  get_benchmarks_scripts
  python3 scripts/patch_benchmark_dataset.py --path vllm_benchmarks/benchmark_dataset.py
  trap cleanup EXIT

  QUICK_BENCHMARK_ROOT=./

@@ -18,7 +18,7 @@
        },
        "client_parameters": {
            "model": "Qwen/Qwen2.5-VL-7B-Instruct",
            "backend": "openai-chat",
            "endpoint_type": "openai-chat",
            "dataset_name": "hf",
            "hf_split": "train",
            "endpoint": "/v1/chat/completions",
@@ -44,7 +44,7 @@
        },
        "client_parameters": {
            "model": "Qwen/Qwen3-8B",
            "backend": "vllm",
            "endpoint_type": "vllm",
            "dataset_name": "sharegpt",
            "dataset_path": "/github/home/.cache/datasets/ShareGPT_V3_unfiltered_cleaned_split.json",
            "num_prompts": 200
@@ -68,7 +68,7 @@
        },
        "client_parameters": {
            "model": "Qwen/Qwen2.5-7B-Instruct",
            "backend": "vllm",
            "endpoint_type": "vllm",
            "dataset_name": "sharegpt",
            "dataset_path": "/github/home/.cache/datasets/ShareGPT_V3_unfiltered_cleaned_split.json",
            "num_prompts": 200