mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Misc][Benchmark] Add support for CustomDataset (#18511)
This commit is contained in:
@ -64,6 +64,12 @@ become available.
|
||||
<td style="text-align: center;">✅</td>
|
||||
<td><code>lmms-lab/LLaVA-OneVision-Data</code>, <code>Aeala/ShareGPT_Vicuna_unfiltered</code></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><strong>Custom</strong></td>
|
||||
<td style="text-align: center;">✅</td>
|
||||
<td style="text-align: center;">✅</td>
|
||||
<td>Local file: <code>data.jsonl</code></td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
@ -124,6 +130,38 @@ P99 ITL (ms): 8.39
|
||||
==================================================
|
||||
```
|
||||
|
||||
### Custom Dataset
|
||||
If the dataset you want to benchmark is not supported yet in vLLM, even then you can benchmark on it using `CustomDataset`. Your data needs to be in `.jsonl` format and needs to have "prompt" field per entry, e.g., data.jsonl
|
||||
|
||||
```
|
||||
{"prompt": "What is the capital of India?"}
|
||||
{"prompt": "What is the capital of Iran?"}
|
||||
{"prompt": "What is the capital of China?"}
|
||||
```
|
||||
|
||||
```bash
|
||||
# start server
|
||||
VLLM_USE_V1=1 vllm serve meta-llama/Llama-3.1-8B-Instruct --disable-log-requests
|
||||
```
|
||||
|
||||
```bash
|
||||
# run benchmarking script
|
||||
python3 benchmarks/benchmark_serving.py --port 9001 --save-result --save-detailed \
|
||||
--backend vllm \
|
||||
--model meta-llama/Llama-3.1-8B-Instruct \
|
||||
--endpoint /v1/completions \
|
||||
--dataset-name custom \
|
||||
--dataset-path <path-to-your-data-jsonl> \
|
||||
--custom-skip-chat-template \
|
||||
--num-prompts 80 \
|
||||
--max-concurrency 1 \
|
||||
--temperature=0.3 \
|
||||
--top-p=0.75 \
|
||||
--result-dir "./log/"
|
||||
```
|
||||
|
||||
You can skip applying chat template if your data already has it by using `--custom-skip-chat-template`.
|
||||
|
||||
### VisionArena Benchmark for Vision Language Models
|
||||
|
||||
```bash
|
||||
@ -203,6 +241,16 @@ python3 vllm/benchmarks/benchmark_serving.py \
|
||||
--seed 42
|
||||
```
|
||||
|
||||
**`philschmid/mt-bench`**
|
||||
|
||||
``` bash
|
||||
python3 vllm/benchmarks/benchmark_serving.py \
|
||||
--model Qwen/QwQ-32B \
|
||||
--dataset-name hf \
|
||||
--dataset-path philschmid/mt-bench \
|
||||
--num-prompts 80
|
||||
```
|
||||
|
||||
### Running With Sampling Parameters
|
||||
|
||||
When using OpenAI-compatible backends such as `vllm`, optional sampling
|
||||
|
@ -9,9 +9,6 @@ generation. Supported dataset types include:
|
||||
- BurstGPT
|
||||
- HuggingFace
|
||||
- VisionArena
|
||||
|
||||
TODO: Implement CustomDataset to parse a JSON file and convert its contents into
|
||||
SampleRequest instances, similar to the approach used in ShareGPT.
|
||||
"""
|
||||
|
||||
import base64
|
||||
@ -442,6 +439,97 @@ class ShareGPTDataset(BenchmarkDataset):
|
||||
return samples
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Custom Dataset Implementation
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class CustomDataset(BenchmarkDataset):
|
||||
"""
|
||||
Implements the Custom dataset. Loads data from a JSONL file and generates
|
||||
sample requests based on conversation turns. E.g.,
|
||||
```
|
||||
{"prompt": "What is the capital of India?"}
|
||||
{"prompt": "What is the capital of Iran?"}
|
||||
{"prompt": "What is the capital of China?"}
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.load_data()
|
||||
|
||||
def load_data(self) -> None:
|
||||
if self.dataset_path is None:
|
||||
raise ValueError("dataset_path must be provided for loading data.")
|
||||
|
||||
# self.data will be a list of dictionaries
|
||||
# e.g., [{"prompt": "What is the capital of India?"}, ...]
|
||||
# This will be the standardized format which load_data()
|
||||
# has to convert into depending on the filetype of dataset_path.
|
||||
# sample() will assume this standardized format of self.data
|
||||
self.data = []
|
||||
|
||||
# Load the JSONL file
|
||||
if self.dataset_path.endswith(".jsonl"):
|
||||
jsonl_data = pd.read_json(path_or_buf=self.dataset_path, lines=True)
|
||||
|
||||
# check if the JSONL file has a 'prompt' column
|
||||
if "prompt" not in jsonl_data.columns:
|
||||
raise ValueError("JSONL file must contain a 'prompt' column.")
|
||||
|
||||
# Convert each row to a dictionary and append to self.data
|
||||
# This will convert the DataFrame to a list of dictionaries
|
||||
# where each dictionary corresponds to a row in the DataFrame.
|
||||
# This is the standardized format we want for self.data
|
||||
for _, row in jsonl_data.iterrows():
|
||||
self.data.append(row.to_dict())
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Only JSONL format is supported for CustomDataset."
|
||||
)
|
||||
|
||||
random.seed(self.random_seed)
|
||||
random.shuffle(self.data)
|
||||
|
||||
def sample(
|
||||
self,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
num_requests: int,
|
||||
lora_path: Optional[str] = None,
|
||||
max_loras: Optional[int] = None,
|
||||
output_len: Optional[int] = None,
|
||||
enable_multimodal_chat: bool = False,
|
||||
skip_chat_template: bool = False,
|
||||
**kwargs,
|
||||
) -> list:
|
||||
sampled_requests = []
|
||||
for item in self.data:
|
||||
if len(sampled_requests) >= num_requests:
|
||||
break
|
||||
prompt = item["prompt"]
|
||||
|
||||
# apply template
|
||||
if not skip_chat_template:
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
[{"role": "user", "content": prompt}],
|
||||
add_generation_prompt=True,
|
||||
tokenize=False,
|
||||
)
|
||||
|
||||
prompt_len = len(tokenizer(prompt).input_ids)
|
||||
sampled_requests.append(
|
||||
SampleRequest(
|
||||
prompt=prompt,
|
||||
prompt_len=prompt_len,
|
||||
expected_output_len=output_len,
|
||||
)
|
||||
)
|
||||
self.maybe_oversample_requests(sampled_requests, num_requests)
|
||||
|
||||
return sampled_requests
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Sonnet Dataset Implementation
|
||||
# -----------------------------------------------------------------------------
|
||||
|
@ -60,6 +60,7 @@ from benchmark_dataset import (
|
||||
ASRDataset,
|
||||
BurstGPTDataset,
|
||||
ConversationDataset,
|
||||
CustomDataset,
|
||||
HuggingFaceDataset,
|
||||
InstructCoderDataset,
|
||||
MTBenchDataset,
|
||||
@ -627,7 +628,16 @@ def main(args: argparse.Namespace):
|
||||
"'--dataset-path' if required."
|
||||
)
|
||||
|
||||
if args.dataset_name == "sonnet":
|
||||
if args.dataset_name == "custom":
|
||||
dataset = CustomDataset(dataset_path=args.dataset_path)
|
||||
input_requests = dataset.sample(
|
||||
num_requests=args.num_prompts,
|
||||
tokenizer=tokenizer,
|
||||
output_len=args.custom_output_len,
|
||||
skip_chat_template=args.custom_skip_chat_template,
|
||||
)
|
||||
|
||||
elif args.dataset_name == "sonnet":
|
||||
dataset = SonnetDataset(dataset_path=args.dataset_path)
|
||||
# For the "sonnet" dataset, formatting depends on the backend.
|
||||
if args.backend == "openai-chat":
|
||||
@ -838,6 +848,8 @@ def main(args: argparse.Namespace):
|
||||
]:
|
||||
if field in result_json:
|
||||
del result_json[field]
|
||||
if field in benchmark_result:
|
||||
del benchmark_result[field]
|
||||
|
||||
# Save to file
|
||||
base_model_id = model_id.split("/")[-1]
|
||||
@ -850,6 +862,7 @@ def main(args: argparse.Namespace):
|
||||
if args.result_filename:
|
||||
file_name = args.result_filename
|
||||
if args.result_dir:
|
||||
os.makedirs(args.result_dir, exist_ok=True)
|
||||
file_name = os.path.join(args.result_dir, file_name)
|
||||
with open(
|
||||
file_name, mode="a+" if args.append_result else "w", encoding="utf-8"
|
||||
@ -890,7 +903,7 @@ if __name__ == "__main__":
|
||||
"--dataset-name",
|
||||
type=str,
|
||||
default="sharegpt",
|
||||
choices=["sharegpt", "burstgpt", "sonnet", "random", "hf"],
|
||||
choices=["sharegpt", "burstgpt", "sonnet", "random", "hf", "custom"],
|
||||
help="Name of the dataset to benchmark on.",
|
||||
)
|
||||
parser.add_argument(
|
||||
@ -1060,6 +1073,19 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
# group for dataset specific arguments
|
||||
custom_group = parser.add_argument_group("custom dataset options")
|
||||
custom_group.add_argument(
|
||||
"--custom-output-len",
|
||||
type=int,
|
||||
default=256,
|
||||
help="Number of output tokens per request, used only for custom dataset.",
|
||||
)
|
||||
custom_group.add_argument(
|
||||
"--custom-skip-chat-template",
|
||||
action="store_true",
|
||||
help="Skip applying chat template to prompt, used only for custom dataset.",
|
||||
)
|
||||
|
||||
sonnet_group = parser.add_argument_group("sonnet dataset options")
|
||||
sonnet_group.add_argument(
|
||||
"--sonnet-input-len",
|
||||
|
@ -9,9 +9,6 @@ generation. Supported dataset types include:
|
||||
- BurstGPT
|
||||
- HuggingFace
|
||||
- VisionArena
|
||||
|
||||
TODO: Implement CustomDataset to parse a JSON file and convert its contents into
|
||||
SampleRequest instances, similar to the approach used in ShareGPT.
|
||||
"""
|
||||
import base64
|
||||
import io
|
||||
@ -26,6 +23,7 @@ from io import BytesIO
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from PIL import Image
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
@ -443,6 +441,99 @@ class ShareGPTDataset(BenchmarkDataset):
|
||||
return samples
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Custom Dataset Implementation
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class CustomDataset(BenchmarkDataset):
|
||||
"""
|
||||
Implements the Custom dataset. Loads data from a JSONL file and generates
|
||||
sample requests based on conversation turns. E.g.,
|
||||
```
|
||||
{"prompt": "What is the capital of India?"}
|
||||
{"prompt": "What is the capital of Iran?"}
|
||||
{"prompt": "What is the capital of China?"}
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.load_data()
|
||||
|
||||
def load_data(self) -> None:
|
||||
if self.dataset_path is None:
|
||||
raise ValueError("dataset_path must be provided for loading data.")
|
||||
|
||||
# self.data will be a list of dictionaries
|
||||
# e.g., [{"prompt": "What is the capital of India?"}, ...]
|
||||
# This will be the standardized format which load_data()
|
||||
# has to convert into depending on the filetype of dataset_path.
|
||||
# sample() will assume this standardized format of self.data
|
||||
self.data = []
|
||||
|
||||
# Load the JSONL file
|
||||
if self.dataset_path.endswith(".jsonl"):
|
||||
jsonl_data = pd.read_json(path_or_buf=self.dataset_path,
|
||||
lines=True)
|
||||
|
||||
# check if the JSONL file has a 'prompt' column
|
||||
if "prompt" not in jsonl_data.columns:
|
||||
raise ValueError("JSONL file must contain a 'prompt' column.")
|
||||
|
||||
# Convert each row to a dictionary and append to self.data
|
||||
# This will convert the DataFrame to a list of dictionaries
|
||||
# where each dictionary corresponds to a row in the DataFrame.
|
||||
# This is the standardized format we want for self.data
|
||||
for _, row in jsonl_data.iterrows():
|
||||
self.data.append(row.to_dict())
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Only JSONL format is supported for CustomDataset.")
|
||||
|
||||
random.seed(self.random_seed)
|
||||
random.shuffle(self.data)
|
||||
|
||||
def sample(
|
||||
self,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
num_requests: int,
|
||||
lora_path: Optional[str] = None,
|
||||
max_loras: Optional[int] = None,
|
||||
output_len: Optional[int] = None,
|
||||
enable_multimodal_chat: bool = False,
|
||||
skip_chat_template: bool = False,
|
||||
**kwargs,
|
||||
) -> list:
|
||||
sampled_requests = []
|
||||
for item in self.data:
|
||||
if len(sampled_requests) >= num_requests:
|
||||
break
|
||||
prompt = item["prompt"]
|
||||
|
||||
# apply template
|
||||
if not skip_chat_template:
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
[{
|
||||
"role": "user",
|
||||
"content": prompt
|
||||
}],
|
||||
add_generation_prompt=True,
|
||||
tokenize=False,
|
||||
)
|
||||
|
||||
prompt_len = len(tokenizer(prompt).input_ids)
|
||||
sampled_requests.append(
|
||||
SampleRequest(
|
||||
prompt=prompt,
|
||||
prompt_len=prompt_len,
|
||||
expected_output_len=output_len,
|
||||
))
|
||||
self.maybe_oversample_requests(sampled_requests, num_requests)
|
||||
|
||||
return sampled_requests
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Sonnet Dataset Implementation
|
||||
# -----------------------------------------------------------------------------
|
||||
|
@ -1110,6 +1110,8 @@ def main(args: argparse.Namespace):
|
||||
]:
|
||||
if field in result_json:
|
||||
del result_json[field]
|
||||
if field in benchmark_result:
|
||||
del benchmark_result[field]
|
||||
|
||||
# Save to file
|
||||
base_model_id = model_id.split("/")[-1]
|
||||
@ -1120,6 +1122,7 @@ def main(args: argparse.Namespace):
|
||||
if args.result_filename:
|
||||
file_name = args.result_filename
|
||||
if args.result_dir:
|
||||
os.makedirs(args.result_dir, exist_ok=True)
|
||||
file_name = os.path.join(args.result_dir, file_name)
|
||||
with open(file_name,
|
||||
mode="a+" if args.append_result else "w",
|
||||
|
Reference in New Issue
Block a user