[Misc][Benchmark] Add support for CustomDataset (#18511)

This commit is contained in:
Ekagra Ranjan
2025-05-31 15:07:38 -04:00
committed by GitHub
parent 20079c6e36
commit bbfa0c61d1
5 changed files with 264 additions and 8 deletions

View File

@ -64,6 +64,12 @@ become available.
<td style="text-align: center;"></td> <td style="text-align: center;"></td>
<td><code>lmms-lab/LLaVA-OneVision-Data</code>, <code>Aeala/ShareGPT_Vicuna_unfiltered</code></td> <td><code>lmms-lab/LLaVA-OneVision-Data</code>, <code>Aeala/ShareGPT_Vicuna_unfiltered</code></td>
</tr> </tr>
<tr>
<td><strong>Custom</strong></td>
<td style="text-align: center;"></td>
<td style="text-align: center;"></td>
<td>Local file: <code>data.jsonl</code></td>
</tr>
</tbody> </tbody>
</table> </table>
@ -124,6 +130,38 @@ P99 ITL (ms): 8.39
================================================== ==================================================
``` ```
### Custom Dataset
If the dataset you want to benchmark is not supported yet in vLLM, even then you can benchmark on it using `CustomDataset`. Your data needs to be in `.jsonl` format and needs to have "prompt" field per entry, e.g., data.jsonl
```
{"prompt": "What is the capital of India?"}
{"prompt": "What is the capital of Iran?"}
{"prompt": "What is the capital of China?"}
```
```bash
# start server
VLLM_USE_V1=1 vllm serve meta-llama/Llama-3.1-8B-Instruct --disable-log-requests
```
```bash
# run benchmarking script
python3 benchmarks/benchmark_serving.py --port 9001 --save-result --save-detailed \
--backend vllm \
--model meta-llama/Llama-3.1-8B-Instruct \
--endpoint /v1/completions \
--dataset-name custom \
--dataset-path <path-to-your-data-jsonl> \
--custom-skip-chat-template \
--num-prompts 80 \
--max-concurrency 1 \
--temperature=0.3 \
--top-p=0.75 \
--result-dir "./log/"
```
You can skip applying chat template if your data already has it by using `--custom-skip-chat-template`.
### VisionArena Benchmark for Vision Language Models ### VisionArena Benchmark for Vision Language Models
```bash ```bash
@ -203,6 +241,16 @@ python3 vllm/benchmarks/benchmark_serving.py \
--seed 42 --seed 42
``` ```
**`philschmid/mt-bench`**
``` bash
python3 vllm/benchmarks/benchmark_serving.py \
--model Qwen/QwQ-32B \
--dataset-name hf \
--dataset-path philschmid/mt-bench \
--num-prompts 80
```
### Running With Sampling Parameters ### Running With Sampling Parameters
When using OpenAI-compatible backends such as `vllm`, optional sampling When using OpenAI-compatible backends such as `vllm`, optional sampling

View File

@ -9,9 +9,6 @@ generation. Supported dataset types include:
- BurstGPT - BurstGPT
- HuggingFace - HuggingFace
- VisionArena - VisionArena
TODO: Implement CustomDataset to parse a JSON file and convert its contents into
SampleRequest instances, similar to the approach used in ShareGPT.
""" """
import base64 import base64
@ -442,6 +439,97 @@ class ShareGPTDataset(BenchmarkDataset):
return samples return samples
# -----------------------------------------------------------------------------
# Custom Dataset Implementation
# -----------------------------------------------------------------------------
class CustomDataset(BenchmarkDataset):
"""
Implements the Custom dataset. Loads data from a JSONL file and generates
sample requests based on conversation turns. E.g.,
```
{"prompt": "What is the capital of India?"}
{"prompt": "What is the capital of Iran?"}
{"prompt": "What is the capital of China?"}
```
"""
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
self.load_data()
def load_data(self) -> None:
if self.dataset_path is None:
raise ValueError("dataset_path must be provided for loading data.")
# self.data will be a list of dictionaries
# e.g., [{"prompt": "What is the capital of India?"}, ...]
# This will be the standardized format which load_data()
# has to convert into depending on the filetype of dataset_path.
# sample() will assume this standardized format of self.data
self.data = []
# Load the JSONL file
if self.dataset_path.endswith(".jsonl"):
jsonl_data = pd.read_json(path_or_buf=self.dataset_path, lines=True)
# check if the JSONL file has a 'prompt' column
if "prompt" not in jsonl_data.columns:
raise ValueError("JSONL file must contain a 'prompt' column.")
# Convert each row to a dictionary and append to self.data
# This will convert the DataFrame to a list of dictionaries
# where each dictionary corresponds to a row in the DataFrame.
# This is the standardized format we want for self.data
for _, row in jsonl_data.iterrows():
self.data.append(row.to_dict())
else:
raise NotImplementedError(
"Only JSONL format is supported for CustomDataset."
)
random.seed(self.random_seed)
random.shuffle(self.data)
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
num_requests: int,
lora_path: Optional[str] = None,
max_loras: Optional[int] = None,
output_len: Optional[int] = None,
enable_multimodal_chat: bool = False,
skip_chat_template: bool = False,
**kwargs,
) -> list:
sampled_requests = []
for item in self.data:
if len(sampled_requests) >= num_requests:
break
prompt = item["prompt"]
# apply template
if not skip_chat_template:
prompt = tokenizer.apply_chat_template(
[{"role": "user", "content": prompt}],
add_generation_prompt=True,
tokenize=False,
)
prompt_len = len(tokenizer(prompt).input_ids)
sampled_requests.append(
SampleRequest(
prompt=prompt,
prompt_len=prompt_len,
expected_output_len=output_len,
)
)
self.maybe_oversample_requests(sampled_requests, num_requests)
return sampled_requests
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Sonnet Dataset Implementation # Sonnet Dataset Implementation
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------

View File

@ -60,6 +60,7 @@ from benchmark_dataset import (
ASRDataset, ASRDataset,
BurstGPTDataset, BurstGPTDataset,
ConversationDataset, ConversationDataset,
CustomDataset,
HuggingFaceDataset, HuggingFaceDataset,
InstructCoderDataset, InstructCoderDataset,
MTBenchDataset, MTBenchDataset,
@ -627,7 +628,16 @@ def main(args: argparse.Namespace):
"'--dataset-path' if required." "'--dataset-path' if required."
) )
if args.dataset_name == "sonnet": if args.dataset_name == "custom":
dataset = CustomDataset(dataset_path=args.dataset_path)
input_requests = dataset.sample(
num_requests=args.num_prompts,
tokenizer=tokenizer,
output_len=args.custom_output_len,
skip_chat_template=args.custom_skip_chat_template,
)
elif args.dataset_name == "sonnet":
dataset = SonnetDataset(dataset_path=args.dataset_path) dataset = SonnetDataset(dataset_path=args.dataset_path)
# For the "sonnet" dataset, formatting depends on the backend. # For the "sonnet" dataset, formatting depends on the backend.
if args.backend == "openai-chat": if args.backend == "openai-chat":
@ -838,6 +848,8 @@ def main(args: argparse.Namespace):
]: ]:
if field in result_json: if field in result_json:
del result_json[field] del result_json[field]
if field in benchmark_result:
del benchmark_result[field]
# Save to file # Save to file
base_model_id = model_id.split("/")[-1] base_model_id = model_id.split("/")[-1]
@ -850,6 +862,7 @@ def main(args: argparse.Namespace):
if args.result_filename: if args.result_filename:
file_name = args.result_filename file_name = args.result_filename
if args.result_dir: if args.result_dir:
os.makedirs(args.result_dir, exist_ok=True)
file_name = os.path.join(args.result_dir, file_name) file_name = os.path.join(args.result_dir, file_name)
with open( with open(
file_name, mode="a+" if args.append_result else "w", encoding="utf-8" file_name, mode="a+" if args.append_result else "w", encoding="utf-8"
@ -890,7 +903,7 @@ if __name__ == "__main__":
"--dataset-name", "--dataset-name",
type=str, type=str,
default="sharegpt", default="sharegpt",
choices=["sharegpt", "burstgpt", "sonnet", "random", "hf"], choices=["sharegpt", "burstgpt", "sonnet", "random", "hf", "custom"],
help="Name of the dataset to benchmark on.", help="Name of the dataset to benchmark on.",
) )
parser.add_argument( parser.add_argument(
@ -1060,6 +1073,19 @@ if __name__ == "__main__":
) )
# group for dataset specific arguments # group for dataset specific arguments
custom_group = parser.add_argument_group("custom dataset options")
custom_group.add_argument(
"--custom-output-len",
type=int,
default=256,
help="Number of output tokens per request, used only for custom dataset.",
)
custom_group.add_argument(
"--custom-skip-chat-template",
action="store_true",
help="Skip applying chat template to prompt, used only for custom dataset.",
)
sonnet_group = parser.add_argument_group("sonnet dataset options") sonnet_group = parser.add_argument_group("sonnet dataset options")
sonnet_group.add_argument( sonnet_group.add_argument(
"--sonnet-input-len", "--sonnet-input-len",

View File

@ -9,9 +9,6 @@ generation. Supported dataset types include:
- BurstGPT - BurstGPT
- HuggingFace - HuggingFace
- VisionArena - VisionArena
TODO: Implement CustomDataset to parse a JSON file and convert its contents into
SampleRequest instances, similar to the approach used in ShareGPT.
""" """
import base64 import base64
import io import io
@ -26,6 +23,7 @@ from io import BytesIO
from typing import Any, Callable, Optional, Union from typing import Any, Callable, Optional, Union
import numpy as np import numpy as np
import pandas as pd
from PIL import Image from PIL import Image
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
@ -443,6 +441,99 @@ class ShareGPTDataset(BenchmarkDataset):
return samples return samples
# -----------------------------------------------------------------------------
# Custom Dataset Implementation
# -----------------------------------------------------------------------------
class CustomDataset(BenchmarkDataset):
"""
Implements the Custom dataset. Loads data from a JSONL file and generates
sample requests based on conversation turns. E.g.,
```
{"prompt": "What is the capital of India?"}
{"prompt": "What is the capital of Iran?"}
{"prompt": "What is the capital of China?"}
```
"""
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
self.load_data()
def load_data(self) -> None:
if self.dataset_path is None:
raise ValueError("dataset_path must be provided for loading data.")
# self.data will be a list of dictionaries
# e.g., [{"prompt": "What is the capital of India?"}, ...]
# This will be the standardized format which load_data()
# has to convert into depending on the filetype of dataset_path.
# sample() will assume this standardized format of self.data
self.data = []
# Load the JSONL file
if self.dataset_path.endswith(".jsonl"):
jsonl_data = pd.read_json(path_or_buf=self.dataset_path,
lines=True)
# check if the JSONL file has a 'prompt' column
if "prompt" not in jsonl_data.columns:
raise ValueError("JSONL file must contain a 'prompt' column.")
# Convert each row to a dictionary and append to self.data
# This will convert the DataFrame to a list of dictionaries
# where each dictionary corresponds to a row in the DataFrame.
# This is the standardized format we want for self.data
for _, row in jsonl_data.iterrows():
self.data.append(row.to_dict())
else:
raise NotImplementedError(
"Only JSONL format is supported for CustomDataset.")
random.seed(self.random_seed)
random.shuffle(self.data)
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
num_requests: int,
lora_path: Optional[str] = None,
max_loras: Optional[int] = None,
output_len: Optional[int] = None,
enable_multimodal_chat: bool = False,
skip_chat_template: bool = False,
**kwargs,
) -> list:
sampled_requests = []
for item in self.data:
if len(sampled_requests) >= num_requests:
break
prompt = item["prompt"]
# apply template
if not skip_chat_template:
prompt = tokenizer.apply_chat_template(
[{
"role": "user",
"content": prompt
}],
add_generation_prompt=True,
tokenize=False,
)
prompt_len = len(tokenizer(prompt).input_ids)
sampled_requests.append(
SampleRequest(
prompt=prompt,
prompt_len=prompt_len,
expected_output_len=output_len,
))
self.maybe_oversample_requests(sampled_requests, num_requests)
return sampled_requests
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Sonnet Dataset Implementation # Sonnet Dataset Implementation
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------

View File

@ -1110,6 +1110,8 @@ def main(args: argparse.Namespace):
]: ]:
if field in result_json: if field in result_json:
del result_json[field] del result_json[field]
if field in benchmark_result:
del benchmark_result[field]
# Save to file # Save to file
base_model_id = model_id.split("/")[-1] base_model_id = model_id.split("/")[-1]
@ -1120,6 +1122,7 @@ def main(args: argparse.Namespace):
if args.result_filename: if args.result_filename:
file_name = args.result_filename file_name = args.result_filename
if args.result_dir: if args.result_dir:
os.makedirs(args.result_dir, exist_ok=True)
file_name = os.path.join(args.result_dir, file_name) file_name = os.path.join(args.result_dir, file_name)
with open(file_name, with open(file_name,
mode="a+" if args.append_result else "w", mode="a+" if args.append_result else "w",