Feature/benchmark/random mm data/images (#23119)

Signed-off-by: breno.skuk <breno.skuk@hcompany.ai>

commit 0cb7b065c3 (parent 2da02dd0d8), committed by GitHub
@@ -59,6 +59,12 @@ become available.
      <td style="text-align: center;">✅</td>
      <td><code>synthetic</code></td>
    </tr>
    <tr>
      <td><strong>RandomMultiModal (Image/Video)</strong></td>
      <td style="text-align: center;">🟡</td>
      <td style="text-align: center;">🚧</td>
      <td><code>synthetic</code></td>
    </tr>
    <tr>
      <td><strong>Prefix Repetition</strong></td>
      <td style="text-align: center;">✅</td>
@@ -722,4 +728,75 @@ python benchmarks/benchmark_serving.py \
    --endpoint /v1/chat/completion
```

### Synthetic Random Images (random-mm)

Generate synthetic image inputs alongside random text prompts to stress-test vision models without external datasets.

Notes:

- Works only with the online benchmark via the OpenAI backend (`--backend openai-chat`) and the `/v1/chat/completions` endpoint.
- Video sampling is not yet implemented.

Start the server (example):

```bash
vllm serve Qwen/Qwen2.5-VL-3B-Instruct \
    --dtype bfloat16 \
    --max-model-len 16384 \
    --limit-mm-per-prompt '{"image": 3, "video": 0}' \
    --mm-processor-kwargs max_pixels=1003520
```
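
To sanity-check the setup before benchmarking, you can send one synthetic image through the chat endpoint yourself. This is an illustrative sketch, not part of the benchmark tooling; it assumes the server above is running on `localhost:8000`:

```python
import base64
import io

from openai import OpenAI
from PIL import Image

# Build a small synthetic image and encode it as a data URL,
# the same format random-mm attaches to requests.
buf = io.BytesIO()
Image.new("RGB", (224, 224), (128, 64, 32)).save(buf, format="JPEG")
url = "data:image/jpeg;base64," + base64.b64encode(buf.getvalue()).decode()

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
out = client.chat.completions.create(
    model="Qwen/Qwen2.5-VL-3B-Instruct",
    messages=[{"role": "user", "content": [
        {"type": "text", "text": "Describe the image."},
        {"type": "image_url", "image_url": {"url": url}},
    ]}],
    max_tokens=40,
)
print(out.choices[0].message.content)
```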

Run the benchmark. It is recommended to pass `--ignore-eos` to simulate real responses; the output size can be set via `--random-output-len`.

Example 1: a fixed number of items and a single image resolution, enforcing generation of approximately 40 tokens:

```bash
vllm bench serve \
    --backend openai-chat \
    --model Qwen/Qwen2.5-VL-3B-Instruct \
    --endpoint /v1/chat/completions \
    --dataset-name random-mm \
    --num-prompts 100 \
    --max-concurrency 10 \
    --random-prefix-len 25 \
    --random-input-len 300 \
    --random-output-len 40 \
    --random-range-ratio 0.2 \
    --random-mm-base-items-per-request 2 \
    --random-mm-limit-mm-per-prompt '{"image": 3, "video": 0}' \
    --random-mm-bucket-config '{(224, 224, 1): 1.0}' \
    --request-rate inf \
    --ignore-eos \
    --seed 42
```

The number of items per request can be varied, and multiple image buckets can be mixed:

```bash
    --random-mm-base-items-per-request 2 \
    --random-mm-num-mm-items-range-ratio 0.5 \
    --random-mm-limit-mm-per-prompt '{"image": 4, "video": 0}' \
    --random-mm-bucket-config '{(256, 256, 1): 0.7, (720, 1280, 1): 0.3}' \
```

With these values, each request carries between 1 and 3 items (see the sketch after the flag list below).

Flags specific to `random-mm`:

- `--random-mm-base-items-per-request`: base number of multimodal items per request.
- `--random-mm-num-mm-items-range-ratio`: vary the item count uniformly in the closed integer range [floor(n·(1−r)), ceil(n·(1+r))]. Set r=0 to keep it fixed; r=1 allows 0 items.
- `--random-mm-limit-mm-per-prompt`: per-modality hard caps, e.g. '{"image": 3, "video": 0}'.
- `--random-mm-bucket-config`: dict mapping (H, W, T) → probability. Entries with probability 0 are removed; remaining probabilities are renormalized to sum to 1. Use T=1 for images. Any T>1 denotes video (video sampling is not yet supported).
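
For intuition, the item-count range logic described above can be sketched as follows. This is an illustrative snippet, not the benchmark's actual API; `n`, `r`, and `limits` stand in for the three flags above:

```python
import math

def mm_item_count_range(n: int, r: float, limits: dict[str, int]) -> tuple[int, int]:
    """Closed integer range of items per request, clamped to the modality caps."""
    lo = max(0, math.floor(n * (1 - r)))
    hi = min(sum(limits.values()), math.ceil(n * (1 + r)))
    return lo, hi

# base 2 items, ratio 0.5, image cap 4 -> k is sampled uniformly from [1, 3]
print(mm_item_count_range(2, 0.5, {"image": 4, "video": 0}))  # (1, 3)
```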

Behavioral notes:

- If the requested base item count cannot be satisfied under the provided per-prompt limits, the tool raises an error rather than silently clamping.

How sampling works:

- Determine the per-request item count k by sampling uniformly from the integer range defined by `--random-mm-base-items-per-request` and `--random-mm-num-mm-items-range-ratio`, then clamp k to at most the sum of per-modality limits.
- For each of the k items, sample a bucket (H, W, T) according to the normalized probabilities in `--random-mm-bucket-config` (sketched below), while tracking how many items of each modality have been added.
- If a modality (e.g., image) reaches its limit from `--random-mm-limit-mm-per-prompt`, all buckets of that modality are excluded and the remaining bucket probabilities are renormalized before continuing. This is an edge case; it can be avoided by setting `--random-mm-limit-mm-per-prompt` to a large number, though doing so may then trigger errors against the engine config `--limit-mm-per-prompt`.
- The resulting request contains synthetic image data in `multi_modal_data` (OpenAI Chat format). When `random-mm` is used with the OpenAI Chat backend, prompts remain text and MM content is attached via `multi_modal_data`.
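
The capped, renormalized bucket sampling can be condensed into a few lines. This is a simplified sketch of the behavior described above; the names are illustrative, not the implementation's API:

```python
import numpy as np

def sample_mm_items(k, buckets, limits, rng):
    """Pick k (H, W, T) bucket configs; T == 1 means image, T > 1 means video."""
    counts = {m: 0 for m in limits}
    chosen, live = [], dict(buckets)
    while len(chosen) < k and live:
        keys = list(live)
        probs = np.array(list(live.values()), dtype=float)
        cfg = keys[rng.choice(len(keys), p=probs / probs.sum())]
        modality = "image" if cfg[2] == 1 else "video"
        if counts[modality] < limits[modality]:
            counts[modality] += 1
            chosen.append(cfg)
        else:
            # Cap reached: drop this modality's buckets and renormalize.
            live = {c: p for c, p in live.items()
                    if ("image" if c[2] == 1 else "video") != modality}
    return chosen

rng = np.random.default_rng(42)
print(sample_mm_items(2, {(256, 256, 1): 0.7, (720, 1280, 1): 0.3},
                      {"image": 4, "video": 0}, rng))
```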

</details>

tests/benchmarks/test_random_dataset.py (new file, 344 lines)

@@ -0,0 +1,344 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
from typing import Any, NamedTuple, Optional, cast

import numpy as np
import pytest
from transformers import AutoTokenizer, PreTrainedTokenizerBase

from vllm.benchmarks.datasets import (RandomDataset, RandomMultiModalDataset,
                                      SampleRequest)


@pytest.fixture(scope="session")
def hf_tokenizer() -> PreTrainedTokenizerBase:
    # Use a small, commonly available tokenizer
    return AutoTokenizer.from_pretrained("gpt2")


class Params(NamedTuple):
    num_requests: int
    prefix_len: int
    range_ratio: float
    input_len: int
    output_len: int


@pytest.fixture(scope="session")
def random_dataset_params() -> Params:
    return Params(num_requests=16,
                  prefix_len=7,
                  range_ratio=0.3,
                  input_len=50,
                  output_len=20)


def _fingerprint_sample(req: SampleRequest) -> tuple[str, int, int]:
    """Project a SampleRequest into a comparable tuple."""
    return (req.prompt, req.prompt_len, req.expected_output_len)


def _collect_samples(dataset: RandomDataset,
                     tokenizer: PreTrainedTokenizerBase,
                     num_requests: int = 16,
                     prefix_len: int = 7,
                     range_ratio: float = 0.3,
                     input_len: int = 50,
                     output_len: int = 20) -> list[tuple[str, int, int]]:
    samples = dataset.sample(
        tokenizer=tokenizer,
        num_requests=num_requests,
        prefix_len=prefix_len,
        range_ratio=range_ratio,
        input_len=input_len,
        output_len=output_len,
    )
    return [_fingerprint_sample(s) for s in samples]


@pytest.mark.benchmark
def test_random_dataset_same_seed(
        hf_tokenizer: PreTrainedTokenizerBase,
        random_dataset_params: Params) -> None:
    """Same seed should yield identical outputs, even if global RNGs change.

    This guards against accidental reliance on Python's random or np.random
    in RandomDataset after moving to numpy.default_rng.
    """
    p = random_dataset_params
    common_seed = 123
    dataset_a = RandomDataset(random_seed=common_seed)
    dataset_b = RandomDataset(random_seed=common_seed)
    a = _collect_samples(dataset_a,
                         hf_tokenizer,
                         num_requests=p.num_requests,
                         prefix_len=p.prefix_len,
                         range_ratio=p.range_ratio,
                         input_len=p.input_len,
                         output_len=p.output_len)

    # Perturb global RNG state to ensure isolation
    random.seed(999)
    _ = [random.random() for _ in range(100)]
    np.random.seed(888)
    _ = [np.random.random() for _ in range(100)]

    b = _collect_samples(dataset_b,
                         hf_tokenizer,
                         num_requests=p.num_requests,
                         prefix_len=p.prefix_len,
                         range_ratio=p.range_ratio,
                         input_len=p.input_len,
                         output_len=p.output_len)
    assert a == b


@pytest.mark.benchmark
def test_random_dataset_different_seeds(
        hf_tokenizer: PreTrainedTokenizerBase,
        random_dataset_params: Params) -> None:
    """Different seeds should change outputs with overwhelming likelihood."""
    p = random_dataset_params
    seed_a = 0
    dataset_a = RandomDataset(random_seed=seed_a)
    a = _collect_samples(dataset_a,
                         hf_tokenizer,
                         num_requests=p.num_requests,
                         prefix_len=p.prefix_len,
                         range_ratio=p.range_ratio,
                         input_len=p.input_len,
                         output_len=p.output_len)

    seed_b = 999
    dataset_b = RandomDataset(random_seed=seed_b)
    # Perturb global RNG with same seed as dataset_a to ensure isolation
    random.seed(seed_a)
    np.random.seed(seed_a)
    b = _collect_samples(dataset_b,
                         hf_tokenizer,
                         num_requests=p.num_requests,
                         prefix_len=p.prefix_len,
                         range_ratio=p.range_ratio,
                         input_len=p.input_len,
                         output_len=p.output_len)
    assert a != b


# -----------------------------
# RandomMultiModalDataset tests
# -----------------------------

def _mm_fingerprint_sample(
    req: SampleRequest,
) -> tuple[str, int, int, int, list[str]]:
    """Create a compact fingerprint for multimodal samples.

    Includes:
    - prompt string
    - prompt_len
    - expected_output_len
    - count of multimodal items
    - per-item type and URL prefix (e.g., 'data:image/jpeg;base64,')
    """
    items = req.multi_modal_data or []
    item_prefixes: list[str] = []
    for it in items:
        if isinstance(it, dict) and it.get("type") == "image_url":
            url = it.get("image_url", {}).get("url", "")
            # Only keep a short identifying prefix to avoid huge strings
            item_prefixes.append(f"image:{url[:22]}")
        elif isinstance(it, dict) and it.get("type") == "video_url":
            url = it.get("video_url", {}).get("url", "")
            item_prefixes.append(f"video:{url[:22]}")
        else:
            item_prefixes.append("unknown:")
    return (req.prompt, req.prompt_len, req.expected_output_len, len(items),
            item_prefixes)


def _collect_mm_samples(
    dataset: RandomMultiModalDataset,
    tokenizer: PreTrainedTokenizerBase,
    *,
    num_requests: int = 8,
    prefix_len: int = 3,
    range_ratio: float = 0.0,
    input_len: int = 20,
    output_len: int = 5,
    base_items_per_request: int = 2,
    num_mm_items_range_ratio: float = 0.0,
    limit_mm_per_prompt: Optional[dict[str, int]] = None,
    bucket_config: Optional[dict[tuple[int, int, int], float]] = None,
    enable_multimodal_chat: bool = False,
) -> list[SampleRequest]:
    if limit_mm_per_prompt is None:
        limit_mm_per_prompt = {"image": 5, "video": 0}
    if bucket_config is None:
        bucket_config = {(32, 32, 1): 0.5, (52, 64, 1): 0.5}
    return dataset.sample(
        tokenizer=tokenizer,
        num_requests=num_requests,
        prefix_len=prefix_len,
        range_ratio=range_ratio,
        input_len=input_len,
        output_len=output_len,
        base_items_per_request=base_items_per_request,
        num_mm_items_range_ratio=num_mm_items_range_ratio,
        limit_mm_per_prompt=limit_mm_per_prompt,
        bucket_config=bucket_config,
        enable_multimodal_chat=enable_multimodal_chat,
    )


@pytest.mark.benchmark
def test_random_mm_same_seed(hf_tokenizer: PreTrainedTokenizerBase) -> None:
    seed = 42
    ds_a = RandomMultiModalDataset(random_seed=seed)
    ds_b = RandomMultiModalDataset(random_seed=seed)
    a = _collect_mm_samples(ds_a, hf_tokenizer)
    b = _collect_mm_samples(ds_b, hf_tokenizer)
    fa = [_mm_fingerprint_sample(s) for s in a]
    fb = [_mm_fingerprint_sample(s) for s in b]
    assert fa == fb


@pytest.mark.benchmark
def test_random_mm_different_seeds(
    hf_tokenizer: PreTrainedTokenizerBase,
) -> None:
    ds_a = RandomMultiModalDataset(random_seed=0)
    ds_b = RandomMultiModalDataset(random_seed=999)
    a = _collect_mm_samples(ds_a, hf_tokenizer)
    b = _collect_mm_samples(ds_b, hf_tokenizer)
    fa = [_mm_fingerprint_sample(s) for s in a]
    fb = [_mm_fingerprint_sample(s) for s in b]
    assert fa != fb


@pytest.mark.benchmark
def test_random_mm_respects_limits(
    hf_tokenizer: PreTrainedTokenizerBase,
) -> None:
    ds = RandomMultiModalDataset(random_seed=0)
    # Requesting 3 items with a per-prompt limit of 1 should error per current
    # design (dataset refuses to silently clamp below the requested baseline).
    with pytest.raises(ValueError):
        _collect_mm_samples(
            ds,
            hf_tokenizer,
            num_requests=12,
            base_items_per_request=3,
            num_mm_items_range_ratio=0.0,
            limit_mm_per_prompt={"image": 1, "video": 0},
            bucket_config={(32, 32, 1): 1.0},
        )


@pytest.mark.benchmark
def test_random_mm_zero_prob_entries_are_removed(
    hf_tokenizer: PreTrainedTokenizerBase,
) -> None:
    ds = RandomMultiModalDataset(random_seed=0)
    # Second bucket has zero probability and should be ignored after
    # normalization
    samples = _collect_mm_samples(
        ds,
        hf_tokenizer,
        num_requests=6,
        base_items_per_request=2,
        num_mm_items_range_ratio=0.0,
        limit_mm_per_prompt={"image": 10, "video": 0},
        bucket_config={(32, 32, 1): 1.0, (52, 64, 1): 0.0},
    )
    for s in samples:
        assert isinstance(s.multi_modal_data, list)
        typed_mm = cast(list[dict[str, Any]], s.multi_modal_data)
        for it in typed_mm:
            assert it.get("type") == "image_url"


@pytest.mark.benchmark
def test_random_mm_zero_items(hf_tokenizer: PreTrainedTokenizerBase) -> None:
    ds = RandomMultiModalDataset(random_seed=0)
    samples = _collect_mm_samples(
        ds,
        hf_tokenizer,
        num_requests=5,
        base_items_per_request=0,
        num_mm_items_range_ratio=0.0,
        limit_mm_per_prompt={"image": 5, "video": 0},
        bucket_config={(32, 32, 1): 1.0},
    )
    for s in samples:
        assert s.multi_modal_data == []


@pytest.mark.benchmark
def test_random_mm_num_items_per_prompt(
        hf_tokenizer: PreTrainedTokenizerBase) -> None:
    ds = RandomMultiModalDataset(random_seed=0)
    # Fixed number of images per prompt:
    # set num_mm_items_range_ratio to 0.0
    # TODO: modify video values when video sampling is implemented
    samples_fixed_items = _collect_mm_samples(
        ds,
        hf_tokenizer,
        num_requests=5,
        base_items_per_request=3,
        num_mm_items_range_ratio=0.0,
        limit_mm_per_prompt={"image": 3, "video": 0},
        bucket_config={(32, 32, 1): 1.0},
    )
    # Must have 5 requests, each with 3 mm items per prompt
    assert len(samples_fixed_items) == 5
    for s in samples_fixed_items:
        mm_data = cast(list[dict[str, Any]], s.multi_modal_data)
        assert len(mm_data) == 3
        for it in mm_data:
            assert it.get("type") == "image_url"


@pytest.mark.benchmark
def test_random_mm_bucket_config_not_mutated(
    hf_tokenizer: PreTrainedTokenizerBase,
) -> None:
    ds = RandomMultiModalDataset(random_seed=0)
    # This bucket config is not normalized to sum to 1
    # and has more buckets than requested images
    original = {(32, 32, 1): 0.2, (52, 64, 1): 6, (25, 64, 1): 3}
    # Keep a snapshot to compare after sampling
    snapshot = dict(original)

    _ = _collect_mm_samples(
        ds,
        hf_tokenizer,
        num_requests=4,
        base_items_per_request=1,
        num_mm_items_range_ratio=0.0,
        limit_mm_per_prompt={"image": 1, "video": 0},
        bucket_config=original,
    )

    # Ensure the original dict content is unchanged
    assert original == snapshot

    # Vary the number of mm items per prompt:
    # set num_mm_items_range_ratio to 0.5
    samples_varying_items = _collect_mm_samples(
        ds,
        hf_tokenizer,
        num_requests=5,
        base_items_per_request=2,
        num_mm_items_range_ratio=0.5,
        limit_mm_per_prompt={"image": 4, "video": 0},
        bucket_config={(32, 32, 1): 1.0},
    )
    # Must have 5 requests, each with at most 4 mm items per prompt
    # but at least 1 mm item per prompt
    assert len(samples_varying_items) == 5
    for s in samples_varying_items:
        mm_data = cast(list[dict[str, Any]], s.multi_modal_data)
        assert len(mm_data) <= 4
        assert len(mm_data) >= 1
        for it in mm_data:
            assert it.get("type") == "image_url"

vllm/benchmarks/datasets.py

@@ -11,18 +11,21 @@ generation. Supported dataset types include:
- HuggingFace
- VisionArena
"""
import ast
import base64
import io
import json
import logging
import math
import random
from abc import ABC, abstractmethod
from collections.abc import Iterator, Mapping
from contextlib import suppress
from copy import deepcopy
from dataclasses import dataclass
from functools import cache
from io import BytesIO
from typing import Any, Callable, Optional, Union, cast

import numpy as np
from PIL import Image

@@ -114,7 +117,9 @@ class BenchmarkDataset(ABC):
    def apply_multimodal_chat_transformation(
            self,
            prompt: str,
            mm_content: Optional[
                Union[MultiModalDataDict, dict, list[dict]]
            ] = None) -> list[dict]:
        """
        Transform a prompt and optional multimodal content into a chat format.
        This method is used for chat models that expect a specific conversation
@@ -122,7 +127,15 @@ class BenchmarkDataset(ABC):
        """
        content = [{"text": prompt, "type": "text"}]
        if mm_content is not None:
            if isinstance(mm_content, list):
                content.extend(cast(list[dict[str, Any]], mm_content))
            elif isinstance(mm_content, dict):
                content.append(mm_content)
            else:
                raise TypeError(
                    "Could not process multimodal content of type: " +
                    f"{type(mm_content)}"
                )
        return [{"role": "user", "content": content}]

    def load_data(self) -> None:

@@ -362,90 +375,536 @@ def process_video(video: Any) -> Mapping[str, Any]:

class RandomDataset(BenchmarkDataset):
    """
    Synthetic text-only dataset for serving/throughput benchmarks.

    Strategy:
    - Sample input/output token lengths per request from integer-uniform
      ranges around configured means (controlled by range_ratio).
    - Prepend a fixed random prefix of length prefix_len.
    - Generate the remaining tokens as a reproducible sequence:
      (offset + index + arange(input_len)) % vocab_size.
    - Decode then re-encode/truncate to ensure prompt token counts match.
    - Uses numpy.default_rng seeded with random_seed for reproducible sampling.
    """
    # Default values copied from benchmark_serving.py for the random dataset.
    DEFAULT_PREFIX_LEN = 0
    DEFAULT_RANGE_RATIO = 0.0
    DEFAULT_INPUT_LEN = 1024
    DEFAULT_OUTPUT_LEN = 128

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        # Use numpy's default_rng for deterministic sampling.
        # Do not use random.seed() or np.random.seed() elsewhere in this
        # class; this keeps the RNG isolated from global RNG state.
        self._rng = np.random.default_rng(self.random_seed)

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        request_id_prefix: str = "",
        prefix_len: int = DEFAULT_PREFIX_LEN,
        range_ratio: float = DEFAULT_RANGE_RATIO,
        input_len: int = DEFAULT_INPUT_LEN,
        output_len: int = DEFAULT_OUTPUT_LEN,
        **kwargs,
    ) -> list[SampleRequest]:
        input_lens, output_lens, offsets = self.get_sampling_params(
            num_requests, range_ratio, input_len, output_len, tokenizer
        )

        # Generate prefix once
        prefix_token_ids = self.get_prefix(tokenizer, prefix_len)
        vocab_size = tokenizer.vocab_size

        requests = []
        for i in range(num_requests):
            prompt, total_input_len = self.generate_token_sequence(
                tokenizer=tokenizer,
                prefix_token_ids=prefix_token_ids,
                prefix_len=prefix_len,
                vocab_size=vocab_size,
                input_len=int(input_lens[i]),
                offset=int(offsets[i]),
                index=i,
            )
            requests.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=total_input_len,
                    expected_output_len=int(output_lens[i]),
                    request_id=request_id_prefix + str(i),
                )
            )
        return requests

    def get_prefix(
        self, tokenizer: PreTrainedTokenizerBase, prefix_len: int
    ) -> list[int]:
        """
        Get the prefix for the dataset.
        """
        return (
            self._rng.integers(
                0, tokenizer.vocab_size, size=prefix_len).tolist()
            if prefix_len > 0
            else []
        )

    def get_sampling_params(
        self,
        num_requests: int,
        range_ratio: float,
        input_len: int,
        output_len: int,
        tokenizer: PreTrainedTokenizerBase,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Get the sampling parameters for the dataset.
        """
        # Enforce range_ratio < 1
        if not (0.0 <= range_ratio < 1.0):
            raise ValueError("range_ratio must be in [0, 1).")
        num_special_tokens = int(tokenizer.num_special_tokens_to_add())
        real_input_len = max(0, int(input_len) - num_special_tokens)
        # Bounds use floor for low and ceil for high
        input_low = math.floor(real_input_len * (1 - range_ratio))
        input_high = math.ceil(real_input_len * (1 + range_ratio))
        output_low = math.floor(output_len * (1 - range_ratio))
        output_high = math.ceil(output_len * (1 + range_ratio))
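        # e.g., with input_len=50, range_ratio=0.3 and 2 special tokens:
        # real_input_len = 48, so input lengths are drawn from
        # [floor(48 * 0.7), ceil(48 * 1.3)] = [33, 63].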
        # Ensure the lower bound for output length is at least 1 to
        # prevent sampling 0 tokens.
        output_low = max(output_low, 1)

        if input_low > input_high:
            raise ValueError(
                "Invalid input sampling interval: "
                f"low={input_low} > high={input_high}"
            )
        if output_low > output_high:
            raise ValueError(
                "Invalid output sampling interval: "
                f"low={output_low} > high={output_high}"
            )

        logger.info(
            "Sampling input_len from [%s, %s] and output_len from [%s, %s]",
            input_low,
            input_high,
            output_low,
            output_high,
        )

        input_lens = self._rng.integers(input_low, input_high + 1,
                                        size=num_requests)
        output_lens = self._rng.integers(output_low, output_high + 1,
                                         size=num_requests)
        offsets = self._rng.integers(0, tokenizer.vocab_size,
                                     size=num_requests)
        return input_lens, output_lens, offsets

    def generate_token_sequence(
        self,
        *,
        tokenizer: PreTrainedTokenizerBase,
        prefix_token_ids: list[int],
        prefix_len: int,
        vocab_size: int,
        input_len: int,
        offset: int,
        index: int,
    ) -> tuple[str, int]:
        """
        Returns (prompt, total_input_len).

        NOTE: After decoding the prompt we have to encode and decode it again.
        This is done because in some cases N consecutive tokens
        give a string tokenized into != N number of tokens.
        For example for GPT2Tokenizer:
        [6880, 6881] -> ['Ġcalls', 'here'] ->
        [1650, 939, 486] -> ['Ġcall', 'sh', 'ere']
        To avoid an uncontrolled change of the prompt length,
        the encoded sequence is truncated before being decoded again.
        """
        # Build the inner sequence by sampling sequentially from the vocab
        inner_seq = ((offset + index + np.arange(input_len))
                     % vocab_size).tolist()
        token_sequence = prefix_token_ids + inner_seq

        # Decode, then re-encode and truncate to preserve token count
        # invariants
        prompt = tokenizer.decode(token_sequence)
        total_input_len = prefix_len + int(input_len)

        re_encoded_sequence = tokenizer.encode(
            prompt, add_special_tokens=False)[:total_input_len]
        prompt = tokenizer.decode(re_encoded_sequence)
        total_input_len = len(re_encoded_sequence)

        return prompt, total_input_len


# -----------------------------------------------------------------------------
# MultiModalDataset Implementation
# -----------------------------------------------------------------------------

class RandomMultiModalDataset(RandomDataset):
    """
    Synthetic multimodal dataset (text + images) that extends RandomDataset.

    Status:
    - Images: supported via synthetic RGB data.
    - Video: not yet supported (TODO: implement video generation method).
    - Audio: not yet supported.

    Sampling overview:
    1) The number of items per request is sampled uniformly from the integer
       range [floor(n·(1−r)), ceil(n·(1+r))], where n is the base count and
       r is `num_mm_items_range_ratio` in [0, 1]. r=0 keeps it fixed; r=1
       allows 0 items. The maximum is further clamped to the sum of
       per-modality limits.
    2) Each item's modality and shape is sampled from `bucket_config`, a dict
       mapping (height, width, num_frames) → probability. We treat
       `num_frames`=1 as image and `num_frames` > 1 as video.
       Entries with zero probability are removed and the rest are
       renormalized to sum to 1.
    3) Per-modality hard caps are enforced via `limit_mm_per_prompt`.
       When a modality reaches its cap, all of its buckets are excluded and
       the remaining probabilities are renormalized.

    Example bucket configuration:
        {(256, 256, 1): 0.5, (720, 1280, 1): 0.4, (720, 1280, 16): 0.1}
    - Two image buckets (`num_frames`=1) and one video bucket
      (`num_frames`=16).
    OBS.: Only image sampling is supported for now.
    """

    IS_MULTIMODAL = True
    # NOTE: video sampling is WIP. Setting it to 0.
    DEFAULT_LIMIT_MM_PER_PROMPT = {"image": 255, "video": 0}

    DEFAULT_BASE_ITEMS_PER_REQUEST = 1
    DEFAULT_NUM_MM_ITEMS_RANGE_RATIO = 0.0
    DEFAULT_MM_ITEM_BUCKET_CONFIG = {
        (256, 256, 1): 0.5,
        (720, 1280, 1): 0.5,
        (720, 1280, 16): 0.0,
    }
    DEFAULT_ENABLE_MULTIMODAL_CHAT = False

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)

    def generate_synthetic_image(self, width: int, height: int) -> Image.Image:
        """Generate a synthetic PIL image with random RGB values.

        NOTE: iid pixel sampling results in worst-case compression
        (good for stressing I/O), but very unlike real photos.
        We could consider a “low-freq” mode (e.g., noise blur)
        to emulate network realism instead of max stress.
        """
        random_pixels = self._rng.integers(
            0,
            256,
            (height, width, 3),
            dtype=np.uint8,
        )
        return Image.fromarray(random_pixels)

    def generate_synthetic_video(self, width: int,
                                 height: int,
                                 num_frames: int) -> Any:
        """Generate synthetic video with random values.

        TODO: Finish this method.
        """
        raise NotImplementedError("Video sampling is WIP.")

    def map_config_to_modality(self, config: tuple[int, int, int]) -> str:
        """Map the configuration to the modality."""
        if config[-1] == 1:
            return "image"
        elif config[-1] > 1:
            return "video"
        else:
            raise ValueError(f"Invalid multimodal item configuration: "
                             f"{config}")

    def normalize_bucket_config(
        self, bucket_config: dict[tuple[int, int, int], float]
    ) -> dict[tuple[int, int, int], float]:
        """
        Remove zero probability entries
        and normalize the bucket config to sum to 1.
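
        Illustrative example (values chosen here for illustration):
        {(32, 32, 1): 0.2, (52, 64, 1): 0.0, (64, 64, 1): 0.6}
        -> {(32, 32, 1): 0.25, (64, 64, 1): 0.75}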
        """
        # Raise an error if any value is negative
        if any(v < 0 for v in bucket_config.values()):
            raise ValueError("Bucket config values must be non-negative.")
        # Remove zero probability entries
        bucket_config = {k: v for k, v in bucket_config.items() if v > 0}
        # If the bucket config is empty, raise an error
        if not bucket_config:
            raise ValueError("Got invalid bucket config. "
                             "Bucket config values must be non-zero.")
        # Normalize the remaining bucket config to sum to 1
        total = sum(bucket_config.values())
        return {k: v / total for k, v in bucket_config.items()}

    def generate_mm_item(
        self,
        mm_item_config: tuple[int, int, int],
    ) -> Mapping[str, Any]:
        """
        Create synthetic images and videos and
        apply process_image/process_video respectively.
        This follows the OpenAI API chat completions:
        https://github.com/openai/openai-python
        """
        if self.map_config_to_modality(mm_item_config) == "image":
            return process_image(self.generate_synthetic_image(
                mm_item_config[1],
                mm_item_config[0]))
        elif self.map_config_to_modality(mm_item_config) == "video":
            return process_video(self.generate_synthetic_video(
                mm_item_config[1],
                mm_item_config[0],
                mm_item_config[2]))
        else:
            raise ValueError(f"Invalid multimodal item configuration: "
                             f"{mm_item_config}")

    def get_mm_item_sampling_params(
        self,
        base_items_per_request: int,
        num_mm_items_range_ratio: float,
        limit_mm_per_prompt: dict[str, int],
        bucket_config: dict[tuple[int, int, int], float],
    ) -> tuple[int, int, dict[str, int], dict[tuple[int, int, int], float]]:
        """
        Get the sampling parameters for the multimodal items.
        """
        # Enforce num_mm_items_range_ratio <= 1
        if not (0.0 <= num_mm_items_range_ratio <= 1.0):
            raise ValueError("num_mm_items_range_ratio must be in [0, 1].")

        # Ensure modalities to sample are in limit_mm_per_prompt
        for k in bucket_config:
            # Get the modality from the bucket config
            modality = self.map_config_to_modality(k)
            if modality not in limit_mm_per_prompt:
                raise ValueError(f"Modality {modality} is not in "
                                 f"limit_mm_per_prompt: "
                                 f"{limit_mm_per_prompt.keys()}")

        # Remove zero probability entries
        # and normalize bucket config to sum to 1
        bucket_config = self.normalize_bucket_config(bucket_config)
        logger.info(
            "Normalized bucket config: %s", bucket_config,
        )
        # Only consider limits per prompt for modalities in the bucket config
        allowed_modalities = {self.map_config_to_modality(cfg)
                              for cfg in bucket_config}
        limit_mm_per_prompt = {
            k: v for k, v in limit_mm_per_prompt.items()
            if k in allowed_modalities}
        if not limit_mm_per_prompt:
            raise ValueError("No valid limits for modalities present in "
                             "bucket_config.")

        logger.info(
            "Updated mm-limit-per-prompt: %s", limit_mm_per_prompt,
        )

        # Get the max and min number of mm items and ensure the max is at
        # most the sum of limit_mm_per_prompt over all modalities
        max_num_mm_items = min(
            sum(limit_mm_per_prompt.values()),
            math.ceil(base_items_per_request * (1 + num_mm_items_range_ratio))
        )
        # Ensure the min number of mm items is at least 0
        min_num_mm_items = max(
            0,
            math.floor(base_items_per_request * (1 - num_mm_items_range_ratio))
        )
        # Raise an error if the min exceeds the max
        if min_num_mm_items > max_num_mm_items:
            raise ValueError(f"Min num mm items is greater than max mm items: "
                             f"{min_num_mm_items} > {max_num_mm_items}")

        logger.info(
            "Sampling number of multimodal items from [%s, %s]",
            min_num_mm_items, max_num_mm_items,
        )

        return (
            min_num_mm_items,
            max_num_mm_items,
            limit_mm_per_prompt,
            bucket_config,
        )

    def get_mm_item_iterator(
        self,
        min_num_mm_items: int,
        max_num_mm_items: int,
        bucket_config: dict[tuple[int, int, int], float],
        limit_mm_per_prompt: dict[str, int],
    ) -> Iterator[tuple[int, int, int]]:
        """
        Iterator over the multimodal items for each request,
        whose size is between min_num_mm_items and max_num_mm_items.

        Loop over the bucket config and sample a multimodal item.
        Loop until the number of multimodal items sampled is equal to
        request_num_mm_items or the limit of multimodal items per prompt
        for all modalities is reached.

        Note:
        - This function operates on a per-request shallow copy of
          `bucket_config` (tuple->float). The original dict passed to
          `sample` is not mutated. If this ever changes, a test
          is implemented and will fail.
        """
        # Get the number of multimodal items to sample
        request_num_mm_items = int(
            self._rng.integers(min_num_mm_items, max_num_mm_items + 1)
        )
        # If request_num_mm_items is 0, yield an empty iterator
        if request_num_mm_items == 0:
            return
        # Initialize modality counters
        modality_counter = {self.map_config_to_modality(k): 0
                            for k in bucket_config}
        # Copy the bucket config to avoid modifying the original
        bucket_config_copy = bucket_config.copy()
        # Loop over the number of multimodal items to sample
        while sum(modality_counter.values()) < request_num_mm_items:
            # Sample a multimodal item config
            mm_item_config = self._rng.choice(
                list(bucket_config_copy.keys()),
                p=list(bucket_config_copy.values()))
            modality = self.map_config_to_modality(mm_item_config)
            # Check that the modality count is below the limit per prompt
            if modality_counter[modality] < limit_mm_per_prompt[modality]:
                modality_counter[modality] += 1
                yield mm_item_config
            else:
                # If the counter has reached the limit per prompt,
                # set all multimodal items of this modality to 0
                for k in bucket_config_copy:
                    if self.map_config_to_modality(k) == modality:
                        bucket_config_copy[k] = 0
                # If all configs are 0, break the loop.
                # This should not happen as request_num_mm_items is at most
                # the sum of limit_mm_per_prompt over all modalities.
                if all(v == 0 for v in bucket_config_copy.values()):
                    logger.warning("Exhausted all multimodal items "
                                   "of modality %s",
                                   modality)
                    break
                # Renormalize the bucket config
                bucket_config_copy = self.normalize_bucket_config(
                    bucket_config_copy)

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        request_id_prefix: str = "",
        prefix_len: int = RandomDataset.DEFAULT_PREFIX_LEN,
        range_ratio: float = RandomDataset.DEFAULT_RANGE_RATIO,
        input_len: int = RandomDataset.DEFAULT_INPUT_LEN,
        output_len: int = RandomDataset.DEFAULT_OUTPUT_LEN,
        limit_mm_per_prompt: dict[str, int] = DEFAULT_LIMIT_MM_PER_PROMPT,
        base_items_per_request: int = DEFAULT_BASE_ITEMS_PER_REQUEST,
        num_mm_items_range_ratio: float = DEFAULT_NUM_MM_ITEMS_RANGE_RATIO,
        bucket_config: dict[tuple[int, int, int], float] =
        DEFAULT_MM_ITEM_BUCKET_CONFIG,
        enable_multimodal_chat: bool = DEFAULT_ENABLE_MULTIMODAL_CHAT,
        **kwargs,
    ) -> list[SampleRequest]:
        # NOTE: Video sampling is WIP. Raise an error if video is in the
        # bucket config with non-zero probability.
        if any(self.map_config_to_modality(cfg) == "video" and p > 0
               for cfg, p in bucket_config.items()):
            raise NotImplementedError("Video sampling not implemented; "
                                      "set its probability to 0.")

        # Get the sampling parameters for the dataset
        input_lens, output_lens, offsets = self.get_sampling_params(
            num_requests, range_ratio, input_len, output_len, tokenizer
        )

        (
            min_num_mm_items,
            max_num_mm_items,
            limit_mm_per_prompt,
            bucket_config,
        ) = self.get_mm_item_sampling_params(
            base_items_per_request,
            num_mm_items_range_ratio,
            limit_mm_per_prompt,
            bucket_config,
        )

        # Generate prefix once
        prefix_token_ids = self.get_prefix(tokenizer, prefix_len)
        vocab_size = tokenizer.vocab_size
        # Add synthetic multimodal items to each request
        mm_requests = []
        for i in range(num_requests):
            prompt, total_input_len = self.generate_token_sequence(
                tokenizer=tokenizer,
                prefix_token_ids=prefix_token_ids,
                prefix_len=prefix_len,
                vocab_size=vocab_size,
                input_len=int(input_lens[i]),
                offset=int(offsets[i]),
                index=i,
            )
            # Get the multimodal item iterator for a given request
            mm_item_iterator = self.get_mm_item_iterator(
                min_num_mm_items,
                max_num_mm_items,
                bucket_config,
                limit_mm_per_prompt,
            )

            mm_content = cast(list[dict[str, Any]], [
                self.generate_mm_item(mm_item_config)
                for mm_item_config in mm_item_iterator
            ])

            if enable_multimodal_chat:
                # NOTE: For now this option is only provided for completeness
                # given that the serve.py benchmark currently does not use it.
                mm_chat_prompt: Any = prompt
                mm_chat_prompt = self.apply_multimodal_chat_transformation(
                    prompt, mm_content)
                sample_request = SampleRequest(
                    prompt=mm_chat_prompt,
                    prompt_len=total_input_len,
                    expected_output_len=int(output_lens[i]),
                    multi_modal_data=None,
                    request_id=request_id_prefix + str(i),
                )
            else:
                sample_request = SampleRequest(
                    prompt=prompt,
                    prompt_len=total_input_len,
                    expected_output_len=int(output_lens[i]),
                    multi_modal_data=mm_content,
                    request_id=request_id_prefix + str(i),
                )
            mm_requests.append(sample_request)
        return mm_requests


# -----------------------------------------------------------------------------
# ShareGPT Dataset Implementation

@@ -545,8 +1004,8 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
        type=str,
        default="random",
        choices=[
            "sharegpt", "burstgpt", "sonnet", "random", "random-mm", "hf",
            "custom", "prefix_repetition"
        ],
        help="Name of the dataset to benchmark on.",
    )

@@ -647,6 +1106,98 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
                  "input_len * (1 + range_ratio)]."),
    )

    # random multimodal dataset options
    random_mm_group = parser.add_argument_group(
        "random multimodal dataset options extended from random dataset")
    random_mm_group.add_argument(
        "--random-mm-base-items-per-request",
        type=int,
        default=RandomMultiModalDataset.DEFAULT_BASE_ITEMS_PER_REQUEST,
        help=(
            "Base number of multimodal items per request for random-mm. "
            "The actual per-request count is sampled around this base using "
            "--random-mm-num-mm-items-range-ratio."
        ),
    )
    random_mm_group.add_argument(
        "--random-mm-num-mm-items-range-ratio",
        type=float,
        default=RandomMultiModalDataset.DEFAULT_NUM_MM_ITEMS_RANGE_RATIO,
        help=(
            "Range ratio r in [0, 1] for sampling the number of items per "
            "request. We sample uniformly from the closed integer range "
            "[floor(n*(1-r)), ceil(n*(1+r))] "
            "where n is the base items per request. "
            "r=0 keeps it fixed; r=1 allows 0 items. The maximum is clamped "
            "to the sum of per-modality limits from "
            "--random-mm-limit-mm-per-prompt. "
            "An error is raised if the computed min exceeds the max."
        ),
    )
    random_mm_group.add_argument(
        "--random-mm-limit-mm-per-prompt",
        type=json.loads,
        default=RandomMultiModalDataset.DEFAULT_LIMIT_MM_PER_PROMPT,
        help=(
            "Per-modality hard caps for items attached per request, e.g. "
            "'{\"image\": 3, \"video\": 0}'. The sampled per-request item "
            "count is clamped to the sum of these limits. When a modality "
            "reaches its cap, its buckets are excluded and probabilities "
            "are renormalized. "
            "OBS.: Only image sampling is supported for now."
        ),
    )

    def _parse_mm_bucket_config(v: object) -> dict[tuple[int, int, int], float]:
        # If already a dict (e.g., programmatic call), normalize keys
        def normalize(d: dict) -> dict[tuple[int, int, int], float]:
            out: dict[tuple[int, int, int], float] = {}
            for k, val in d.items():
                key = k
                if isinstance(key, str):
                    with suppress(Exception):
                        key = ast.literal_eval(key)
                if not (isinstance(key, tuple) and len(key) == 3
                        and all(isinstance(x, int) for x in key)):
                    raise ValueError(
                        f"Invalid bucket key {k!r}. Expected tuple (H, W, T)."
                    )
                out[(int(key[0]), int(key[1]), int(key[2]))] = float(val)
            return out

        if isinstance(v, dict):
            return normalize(v)
        if isinstance(v, str):
            # Python literal (supports tuple keys)
            parsed = ast.literal_eval(v)
            if not isinstance(parsed, dict):
                raise ValueError("Bucket config must parse to a dict.")
            return normalize(parsed)
        raise ValueError("Unsupported value for --random-mm-bucket-config.")
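
    # Illustrative (hypothetical call, shown here for clarity):
    # _parse_mm_bucket_config("{(256, 256, 1): 0.5, (720, 1280, 1): 0.5}")
    # -> {(256, 256, 1): 0.5, (720, 1280, 1): 0.5}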

    random_mm_group.add_argument(
        "--random-mm-bucket-config",
        type=_parse_mm_bucket_config,
        default=RandomMultiModalDataset.DEFAULT_MM_ITEM_BUCKET_CONFIG,
        help=(
            "The bucket config is a dictionary mapping a multimodal item "
            "sampling configuration to a probability. "
            "It currently allows for 2 modalities: images and videos. "
            "A bucket key is a tuple of (height, width, num_frames); "
            "the value is the probability of sampling that specific item. "
            "Example: "
            "--random-mm-bucket-config "
            "{(256, 256, 1): 0.5, (720, 1280, 1): 0.4, (720, 1280, 16): 0.10} "
            "First item: images with resolution 256x256 w.p. 0.5. "
            "Second item: images with resolution 720x1280 w.p. 0.4. "
            "Third item: videos with resolution 720x1280 and 16 frames "
            "w.p. 0.1. "
            "OBS.: If the probabilities do not sum to 1, they are normalized. "
            "OBS bis.: Only image sampling is supported for now."
        ),
    )

    hf_group = parser.add_argument_group("hf dataset options")
    hf_group.add_argument("--hf-subset",
                          type=str,

@@ -821,6 +1372,22 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
            range_ratio=args.random_range_ratio,
            request_id_prefix=args.request_id_prefix,
        ),
        "random-mm":
        lambda: RandomMultiModalDataset(
            random_seed=args.seed, dataset_path=args.dataset_path
        ).sample(
            tokenizer=tokenizer,
            num_requests=args.num_prompts,
            prefix_len=args.random_prefix_len,
            range_ratio=args.random_range_ratio,
            input_len=args.random_input_len,
            output_len=args.random_output_len,
            base_items_per_request=args.random_mm_base_items_per_request,
            limit_mm_per_prompt=args.random_mm_limit_mm_per_prompt,
            num_mm_items_range_ratio=args.random_mm_num_mm_items_range_ratio,
            bucket_config=args.random_mm_bucket_config,
            request_id_prefix=args.request_id_prefix,
        ),
        "prefix_repetition":
        lambda: PrefixRepetitionRandomDataset(
            random_seed=args.seed, dataset_path=args.dataset_path

@@ -836,6 +1403,13 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
    }

    try:
        # Enforce endpoint compatibility for multimodal datasets.
        if args.dataset_name == "random-mm" and args.endpoint_type not in [
                "openai-chat"]:
            raise ValueError(
                "Multi-modal content (images) is only supported on "
                "'openai-chat' backend."
            )
        input_requests = dataset_mapping[args.dataset_name]()
    except KeyError as err:
        raise ValueError(f"Unknown dataset: {args.dataset_name}") from err