# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
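"""Benchmark the drafting speed of vLLM's n-gram speculative decoding proposer.

For each value passed to --max-ngram, the script builds an NgramProposer, feeds
it batches of random token ids, and reports the average and maximum time per
batch in microseconds.
"""
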
import gc

import numpy as np
from tabulate import tabulate

from benchmark_utils import TimeCollector
from vllm.config import ModelConfig, SpeculativeConfig, VllmConfig
from vllm.utils import FlexibleArgumentParser
from vllm.v1.spec_decode.ngram_proposer import NgramProposer


def main(args):
    rows = []
    for max_ngram in args.max_ngram:
        collector = TimeCollector(TimeCollector.US)

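        # The n-gram (prompt-lookup) drafter only needs configuration, not model
        # weights, so a small placeholder model config is sufficient.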
        model_config = ModelConfig(
            model="facebook/opt-125m",
            task="generate",
            max_model_len=args.num_token + args.num_spec_token,
            tokenizer="facebook/opt-125m",
            tokenizer_mode="auto",
            dtype="auto",
            seed=None,
            trust_remote_code=False,
        )
        proposer = NgramProposer(
            vllm_config=VllmConfig(
                model_config=model_config,
                speculative_config=SpeculativeConfig(
                    prompt_lookup_min=args.min_ngram,
                    prompt_lookup_max=max_ngram,
                    num_speculative_tokens=args.num_spec_token,
                    method="ngram",
                ),
            )
        )

        # Warm up so one-time setup cost is excluded from the timed iterations.
        proposer.propose(np.random.randint(0, 20, (args.num_token,)))

        # Collect garbage up front so it is less likely to run inside the timed loop.
        gc.collect()
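        # Time the full batch: one propose() call per request inside each
        # collector context, repeated for --num-iteration iterations.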
        for _ in range(args.num_iteration):
            tokens = np.random.randint(0, 20, (args.num_req, args.num_token))
            with collector:
                for i in range(args.num_req):
                    proposer.propose(tokens[i, :])
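        # One table row per max_ngram setting: batch size, token count, n-gram
        # bounds, then the average and maximum batch latency from the collector.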
        rows.append(
            [args.num_req, args.num_token, args.min_ngram, max_ngram]
            + collector.dump_avg_max()
        )

    print(
        tabulate(
            rows,
            headers=[
                "# Request",
                "# Token",
                "Min Ngram",
                "Max Ngram",
                "Avg (us)",
                "Max (us)",
            ],
            tablefmt="grid",
            floatfmt=".3f",
        )
    )


def invoke_main() -> None:
    parser = FlexibleArgumentParser(
        description="Benchmark the performance of N-gram speculative decode drafting"
    )
    parser.add_argument(
        "--num-iteration",
        type=int,
        default=100,
        help="Number of iterations to run to stabilize the final measurements",
    )
    parser.add_argument(
        "--num-req", type=int, default=128, help="Number of requests in the batch"
    )
    parser.add_argument(
        "--num-token", type=int, default=1500, help="Number of tokens for each request"
    )
    parser.add_argument(
        "--min-ngram",
        type=int,
        default=3,
        help="Minimum n-gram to match",
    )
    parser.add_argument(
        "--max-ngram",
        type=int,
        nargs="*",
        default=[5, 7, 10, 15, 20],
        help="Maximum n-gram to match",
    )
    parser.add_argument(
        "--num-spec-token",
        type=int,
        default=3,
        help="Number of speculative tokens to generate",
    )
    args = parser.parse_args()
    main(args)


if __name__ == "__main__":
    invoke_main()  # pragma: no cover