# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
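"""Benchmark vLLM's n-gram speculative decoding proposer.

Measures draft-proposal latency either per request (default) or through the
batched GPUModelRunner path (--batched). Example command lines are given at
the bottom of this file.
"""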
import gc
import time
from unittest import mock

import numpy as np
from tabulate import tabulate

from benchmark_utils import TimeCollector
from vllm.config import (
    CacheConfig,
    DeviceConfig,
    LoadConfig,
    ModelConfig,
    ParallelConfig,
    SchedulerConfig,
    SpeculativeConfig,
    VllmConfig,
)
from vllm.platforms import current_platform
from vllm.utils import FlexibleArgumentParser
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_model_runner import GPUModelRunner


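# Benchmark NgramProposer.propose in isolation: the proposer is called once
# per request on random token ids, sweeping over each --max-ngram value.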
def benchmark_propose(args):
    rows = []
    for max_ngram in args.max_ngram:
        collector = TimeCollector(TimeCollector.US)

        model_config = ModelConfig(
            model="facebook/opt-125m",
            task="generate",
            max_model_len=args.num_token + args.num_spec_token,
            tokenizer="facebook/opt-125m",
            tokenizer_mode="auto",
            dtype="auto",
            seed=None,
            trust_remote_code=False,
        )
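        # Note: the n-gram proposer only needs the config; "ngram" drafting
        # does pattern matching on token ids and loads no model weights.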
        proposer = NgramProposer(
            vllm_config=VllmConfig(
                model_config=model_config,
                speculative_config=SpeculativeConfig(
                    prompt_lookup_min=args.min_ngram,
                    prompt_lookup_max=max_ngram,
                    num_speculative_tokens=args.num_spec_token,
                    method="ngram",
                ),
            )
        )

        # Warm up
        proposer.propose(np.random.randint(0, 20, (args.num_token,)))

        gc.collect()
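        # Each timed sample covers proposing for all args.num_req requests,
        # so "Avg (us)" and "Max (us)" below are per-batch figures.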
        for _ in range(args.num_iteration):
            tokens = np.random.randint(0, 20, (args.num_req, args.num_token))
            with collector:
                for i in range(args.num_req):
                    proposer.propose(tokens[i, :])
        rows.append(
            [args.num_req, args.num_token, args.min_ngram, max_ngram]
            + collector.dump_avg_max()
        )

    print(
        tabulate(
            rows,
            headers=[
                "# Request",
                "# Token",
                "Min Ngram",
                "Max Ngram",
                "Avg (us)",
                "Max (us)",
            ],
            tablefmt="grid",
            floatfmt=".3f",
        )
    )


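# Benchmark the batched path: drive the proposer through
# GPUModelRunner.drafter so the measured time also includes preparing the
# proposer inputs from the InputBatch (see --batched).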
def benchmark_batched_propose(args):
    NUM_SPECULATIVE_TOKENS_NGRAM = 10
    PROMPT_LOOKUP_MIN = 5
    PROMPT_LOOKUP_MAX = 15
    MAX_MODEL_LEN = int(1e7)
    DEVICE = current_platform.device_type

    model_config = ModelConfig(model="facebook/opt-125m", runner="generate")

    speculative_config = SpeculativeConfig(
        target_model_config=model_config,
        target_parallel_config=ParallelConfig(),
        method="ngram",
        num_speculative_tokens=NUM_SPECULATIVE_TOKENS_NGRAM,
        prompt_lookup_max=PROMPT_LOOKUP_MAX,
        prompt_lookup_min=PROMPT_LOOKUP_MIN,
    )

    vllm_config = VllmConfig(
        model_config=model_config,
        cache_config=CacheConfig(),
        speculative_config=speculative_config,
        device_config=DeviceConfig(device=current_platform.device_type),
        parallel_config=ParallelConfig(),
        load_config=LoadConfig(),
        scheduler_config=SchedulerConfig(),
    )

    # Monkey patch vllm.v1.worker.gpu_model_runner.get_pp_group so that
    # GPUModelRunner can be constructed without initializing a distributed
    # process group.
    mock_pp_group = mock.MagicMock()
    mock_pp_group.world_size = 1
    with mock.patch(
        "vllm.v1.worker.gpu_model_runner.get_pp_group", return_value=mock_pp_group
    ):
        runner = GPUModelRunner(vllm_config, DEVICE)

        # Hack the max model len so very long synthetic prompts
        # (e.g. --num-token 1000000) are not rejected.
        runner.max_model_len = MAX_MODEL_LEN
        runner.drafter.max_model_len = MAX_MODEL_LEN

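        # Build a dummy InputBatch filled with random token ids to stand in
        # for real scheduler state.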
        dummy_input_batch = InputBatch(
            max_num_reqs=args.num_req,
            max_model_len=MAX_MODEL_LEN,
            max_num_batched_tokens=args.num_req * args.num_token,
            device=DEVICE,
            pin_memory=False,
            vocab_size=256000,
            block_sizes=[16],
        )
        dummy_input_batch._req_ids = [str(req_id) for req_id in range(args.num_req)]
        dummy_input_batch.spec_decode_unsupported_reqs = ()
        dummy_input_batch.num_tokens_no_spec = [args.num_token] * args.num_req
        dummy_input_batch.token_ids_cpu = np.random.randint(
            0, 20, (args.num_req, args.num_token)
        )

        runner.input_batch = dummy_input_batch

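        # One newly sampled token id per request, mimicking the output of a
        # single decode step.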
        sampled_token_ids = [[0]] * args.num_req

print("Starting benchmark")
|
|
# first run is warmup so ignore it
|
|
for _ in range(args.num_iteration):
|
|
start = time.time()
|
|
runner.drafter.propose(
|
|
sampled_token_ids,
|
|
dummy_input_batch.req_ids,
|
|
dummy_input_batch.num_tokens_no_spec,
|
|
dummy_input_batch.token_ids_cpu,
|
|
dummy_input_batch.spec_decode_unsupported_reqs,
|
|
)
|
|
end = time.time()
|
|
print(f"Iteration time (s): {end - start}")
|
|
|
|
|
|
def invoke_main() -> None:
    parser = FlexibleArgumentParser(
        description="Benchmark the performance of N-gram speculative decode drafting"
    )
    parser.add_argument(
        "--batched", action="store_true", help="consider time to prepare batch"
    )  # noqa: E501
    parser.add_argument(
        "--num-iteration",
        type=int,
        default=100,
        help="Number of iterations to run to stabilize final data readings",
    )
    parser.add_argument(
        "--num-req", type=int, default=128, help="Number of requests in the batch"
    )
    parser.add_argument(
        "--num-token", type=int, default=1500, help="Number of tokens for each request"
    )
    parser.add_argument(
        "--min-ngram",
        type=int,
        default=3,
        help="Minimum n-gram to match",
    )
    parser.add_argument(
        "--max-ngram",
        type=int,
        nargs="*",
        default=[5, 7, 10, 15, 20],
        help="Maximum n-gram to match",
    )
    parser.add_argument(
        "--num-spec-token",
        type=int,
        default=3,
        help="Number of speculative tokens to generate",
    )
    args = parser.parse_args()

    if not args.batched:
        benchmark_propose(args)
    else:
        benchmark_batched_propose(args)


"""
|
|
# Example command lines:
|
|
# time python3 benchmarks/benchmark_ngram_proposer.py
|
|
# time python3 benchmarks/benchmark_ngram_proposer.py --batched --num-iteration 4 --num-token 1000000 --num-req 128
|
|
""" # noqa: E501
|
|
if __name__ == "__main__":
|
|
invoke_main() # pragma: no cover
|