Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 23:03:52 +08:00)

Compare commits: v0.10.1...bench-late (7 commits)
| SHA1 |
|---|
| af985d70bf |
| b484b79504 |
| 8fcd4d18e0 |
| 50e2788383 |
| f0ca3a6142 |
| 528088392e |
| 9030400353 |
@@ -5,18 +5,21 @@ import argparse
import dataclasses
import json
import os
import random
import time
from pathlib import Path
from typing import Any, Optional
from typing import Any, Optional, Union

import numpy as np
import torch
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
from benchmark_utils import (convert_to_pytorch_benchmark_format, get_requests,
                             validate_dataset, write_to_json)
from tqdm import tqdm
from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import PromptType
from vllm.inputs import TextPrompt, TokensPrompt
from vllm.sampling_params import BeamSearchParams
from vllm.utils import FlexibleArgumentParser

@@ -48,28 +51,34 @@ def main(args: argparse.Namespace):

    sampling_params = SamplingParams(
        n=args.n,
        temperature=1.0,
        temperature=0,
        top_p=1.0,
        ignore_eos=True,
        max_tokens=args.output_len,
        detokenize=not args.disable_detokenize,
    )
    print(sampling_params)
    dummy_prompt_token_ids = np.random.randint(10000,
                                               size=(args.batch_size,
                                                     args.input_len))
    dummy_prompts: list[PromptType] = [{
        "prompt_token_ids": batch
    } for batch in dummy_prompt_token_ids.tolist()]

    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer, trust_remote_code=args.trust_remote_code)
    requests = get_requests(args.batch_size, args, tokenizer)
    prompts: list[Union[TextPrompt, TokensPrompt]] = []
    for request in requests:
        prompts.append(
            TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"],
                         multi_modal_data=request.multi_modal_data)
            if "prompt_token_ids" in request.prompt else \
            TextPrompt(prompt=request.prompt,
                       multi_modal_data=request.multi_modal_data))

    def llm_generate():
        if not args.use_beam_search:
            llm.generate(dummy_prompts,
            llm.generate(prompts,
                         sampling_params=sampling_params,
                         use_tqdm=False)
        else:
            llm.beam_search(
                dummy_prompts,
                prompts,
                BeamSearchParams(
                    beam_width=args.n,
                    max_tokens=args.output_len,
@@ -180,7 +189,44 @@ if __name__ == "__main__":
        help=("Do not detokenize responses (i.e. do not include "
              "detokenization time in the latency measurement)"),
    )
    parser.add_argument(
        "--dataset-name",
        type=str,
        choices=["sharegpt", "random", "sonnet", "burstgpt", "hf"],
        help="Name of the dataset to benchmark on.",
        default="sharegpt")
    # random dataset
    parser.add_argument(
        "--random-range-ratio",
        type=float,
        default=None,
        help="Range of sampled ratio of input/output length, "
        "used only for RandomDataSet.",
    )
    parser.add_argument("--dataset-path",
                        type=str,
                        default=None,
                        help="Path to the dataset")

    # LoRA
    parser.add_argument(
        "--lora-path",
        type=str,
        default=None,
        help="Path to the lora adapters to use. This can be an absolute path, "
        "a relative path, or a Hugging Face model identifier.")

    parser.add_argument("--prefix-len",
                        type=int,
                        default=None,
                        help="Number of prefix tokens per request."
                        "This is for the RandomDataset and SonnetDataset")

    parser = EngineArgs.add_cli_args(parser)
    args = parser.parse_args()
    if args.tokenizer is None:
        args.tokenizer = args.model
    args.backend = "vllm"
    validate_dataset(args)
    random.seed(0)
    main(args)
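The two hunks above appear to come from the latency benchmark script: the hard-coded random token IDs (`dummy_prompts`) give way to requests sampled through the shared `get_requests` helper, and sampling switches from `temperature=1.0` to greedy `temperature=0`. The conditional expression that turns each sampled request into a prompt is fairly dense; the sketch below is an equivalent, more explicit reading of it, assuming each returned `SampleRequest` exposes `.prompt` and `.multi_modal_data` exactly as used in the diff (this is not the code on the branch).

```python
from typing import Union

from vllm.inputs import TextPrompt, TokensPrompt


def to_prompts(requests) -> list[Union[TextPrompt, TokensPrompt]]:
    """Convert sampled benchmark requests into vLLM prompt objects."""
    prompts: list[Union[TextPrompt, TokensPrompt]] = []
    for request in requests:
        if "prompt_token_ids" in request.prompt:
            # Pre-tokenized request: pass the token IDs through unchanged.
            prompts.append(
                TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"],
                             multi_modal_data=request.multi_modal_data))
        else:
            # Plain-text request: let vLLM tokenize it.
            prompts.append(
                TextPrompt(prompt=request.prompt,
                           multi_modal_data=request.multi_modal_data))
    return prompts
```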
@@ -11,11 +11,9 @@ from typing import Any, Optional, Union

import torch
import uvloop
from benchmark_dataset import (BurstGPTDataset, ConversationDataset,
                               InstructCoderDataset, RandomDataset,
                               SampleRequest, ShareGPTDataset, SonnetDataset,
                               VisionArenaDataset)
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
from benchmark_dataset import SampleRequest
from benchmark_utils import (convert_to_pytorch_benchmark_format, get_requests,
                             validate_dataset, write_to_json)
from tqdm import tqdm
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          PreTrainedTokenizerBase)
@@ -287,59 +285,6 @@ def save_to_pytorch_benchmark_format(args: argparse.Namespace,
    write_to_json(pt_file, pt_records)


def get_requests(args, tokenizer):
    # Common parameters for all dataset types.
    common_kwargs = {
        "dataset_path": args.dataset_path,
        "random_seed": args.seed,
    }
    sample_kwargs = {
        "tokenizer": tokenizer,
        "lora_path": args.lora_path,
        "max_loras": args.max_loras,
        "num_requests": args.num_prompts,
        "input_len": args.input_len,
        "output_len": args.output_len,
    }

    if args.dataset_path is None or args.dataset_name == "random":
        sample_kwargs["range_ratio"] = args.random_range_ratio
        sample_kwargs["prefix_len"] = args.prefix_len
        dataset_cls = RandomDataset
    elif args.dataset_name == "sharegpt":
        dataset_cls = ShareGPTDataset
        if args.backend == "vllm-chat":
            sample_kwargs["enable_multimodal_chat"] = True
    elif args.dataset_name == "sonnet":
        assert tokenizer.chat_template or tokenizer.default_chat_template, (
            "Tokenizer/model must have chat template for sonnet dataset.")
        dataset_cls = SonnetDataset
        sample_kwargs["prefix_len"] = args.prefix_len
        sample_kwargs["return_prompt_formatted"] = True
    elif args.dataset_name == "burstgpt":
        dataset_cls = BurstGPTDataset
    elif args.dataset_name == "hf":
        if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = VisionArenaDataset
            common_kwargs['dataset_subset'] = None
            common_kwargs['dataset_split'] = "train"
            sample_kwargs["enable_multimodal_chat"] = True
        elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = InstructCoderDataset
            common_kwargs['dataset_split'] = "train"
        elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = ConversationDataset
            common_kwargs['dataset_subset'] = args.hf_subset
            common_kwargs['dataset_split'] = args.hf_split
            sample_kwargs["enable_multimodal_chat"] = True

    else:
        raise ValueError(f"Unknown dataset name: {args.dataset_name}")
    # Remove None values
    sample_kwargs = {k: v for k, v in sample_kwargs.items() if v is not None}
    return dataset_cls(**common_kwargs).sample(**sample_kwargs)


def main(args: argparse.Namespace):
    if args.seed is None:
        args.seed = 0
@@ -348,7 +293,7 @@ def main(args: argparse.Namespace):
    # Sample the requests.
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer, trust_remote_code=args.trust_remote_code)
    requests = get_requests(args, tokenizer)
    requests = get_requests(args.num_prompts, args, tokenizer)
    is_multi_modal = any(request.multi_modal_data is not None
                         for request in requests)
    request_outputs: Optional[list[RequestOutput]] = None
@@ -449,47 +394,8 @@ def validate_args(args):
    if args.backend not in valid_backends:
        raise ValueError(f"Unsupported backend: {args.backend}")

    # === Dataset Configuration ===
    if not args.dataset and not args.dataset_path:
        print(
            "When dataset path is not set, it will default to random dataset")
        args.dataset_name = 'random'
        if args.input_len is None:
            raise ValueError("input_len must be provided for a random dataset")

    # === Dataset Name Specific Checks ===
    # --hf-subset and --hf-split: only used
    # when dataset_name is 'hf'
    if args.dataset_name != "hf" and (
            getattr(args, "hf_subset", None) is not None
            or getattr(args, "hf_split", None) is not None):
        warnings.warn("--hf-subset and --hf-split will be ignored \
                since --dataset-name is not 'hf'.",
                      stacklevel=2)
    elif args.dataset_name == "hf":
        if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
            assert args.backend == "vllm-chat", "VisionArenaDataset needs to use vllm-chat as the backend." #noqa: E501
        elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
            assert args.backend == "vllm", "InstructCoder dataset needs to use vllm as the backend." #noqa: E501
        elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
            assert args.backend == "vllm-chat", "ConversationDataset needs to use vllm-chat as the backend." #noqa: E501
        else:
            raise ValueError(
                f"{args.dataset_path} is not supported by hf dataset.")

    # --random-range-ratio: only used when dataset_name is 'random'
    if args.dataset_name != 'random' and args.random_range_ratio is not None:
        warnings.warn("--random-range-ratio will be ignored since \
                --dataset-name is not 'random'.",
                      stacklevel=2)

    # --prefix-len: only used when dataset_name is 'random', 'sonnet', or not
    # set.
    if args.dataset_name not in {"random", "sonnet", None
                                 } and args.prefix_len is not None:
        warnings.warn("--prefix-len will be ignored since --dataset-name\
                is not 'random', 'sonnet', or not set.",
                      stacklevel=2)
    # === Dataset Validation ===
    validate_dataset(args)

    # === LoRA Settings ===
    if getattr(args, "enable_lora", False) and args.backend != "vllm":
@@ -529,14 +435,6 @@ if __name__ == "__main__":
        choices=["sharegpt", "random", "sonnet", "burstgpt", "hf"],
        help="Name of the dataset to benchmark on.",
        default="sharegpt")
    parser.add_argument(
        "--dataset",
        type=str,
        default=None,
        help="Path to the ShareGPT dataset, will be deprecated in\
            the next release. The dataset is expected to "
        "be a json in form of list[dict[..., conversations: "
        "list[dict[..., value: <prompt_or_response>]]]]")
    parser.add_argument("--dataset-path",
                        type=str,
                        default=None,
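This section appears to be the throughput benchmark script. Its private `get_requests` and the dataset checks inside `validate_args` are deleted here and replaced by the shared `benchmark_utils.get_requests` / `validate_dataset` helpers (shown in the next section), with the number of requests now passed explicitly. A condensed sketch of the resulting shared call pattern, assuming the argparse namespace carries the fields referenced in these hunks:

```python
from transformers import AutoTokenizer

from benchmark_utils import get_requests, validate_dataset


def sample_requests(num_requests: int, args):
    """num_requests is args.batch_size in the latency script and
    args.num_prompts in the throughput script."""
    # The latency script calls validate_dataset() right after arg parsing;
    # the throughput script calls it from validate_args().
    validate_dataset(args)
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer, trust_remote_code=args.trust_remote_code)
    return get_requests(num_requests, args, tokenizer)
```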
@@ -4,8 +4,14 @@ import argparse
import json
import math
import os
import warnings
from typing import Any

from benchmark_dataset import (BurstGPTDataset, ConversationDataset,
                               InstructCoderDataset, RandomDataset,
                               SampleRequest, ShareGPTDataset, SonnetDataset,
                               VisionArenaDataset)


def convert_to_pytorch_benchmark_format(args: argparse.Namespace,
                                        metrics: dict[str, list],
@@ -67,3 +73,113 @@ class InfEncoder(json.JSONEncoder):
def write_to_json(filename: str, records: list) -> None:
    with open(filename, "w") as f:
        json.dump(records, f, cls=InfEncoder)


def get_requests(num_requests: int, args: argparse.Namespace,
                 tokenizer: Any) -> list[SampleRequest]:
    """
    Sample the requests for the benchmark.
    """
    # Common parameters for all dataset types.
    common_kwargs = {
        "dataset_path": args.dataset_path,
        "random_seed": args.seed,
    }
    sample_kwargs = {
        "tokenizer": tokenizer,
        "lora_path": args.lora_path,
        "max_loras": args.max_loras,
        "num_requests": num_requests,
        "input_len": args.input_len,
        "output_len": args.output_len,
    }

    if args.dataset_path is None or args.dataset_name == "random":
        sample_kwargs["range_ratio"] = args.random_range_ratio
        sample_kwargs["prefix_len"] = args.prefix_len
        dataset_cls = RandomDataset
    elif args.dataset_name == "sharegpt":
        dataset_cls = ShareGPTDataset
        if getattr(args, "backend", False) and args.backend == "vllm-chat":
            sample_kwargs["enable_multimodal_chat"] = True
    elif args.dataset_name == "sonnet":
        assert tokenizer.chat_template or tokenizer.default_chat_template, (
            "Tokenizer/model must have chat template for sonnet dataset.")
        dataset_cls = SonnetDataset
        sample_kwargs["prefix_len"] = args.prefix_len
        sample_kwargs["return_prompt_formatted"] = True
    elif args.dataset_name == "burstgpt":
        dataset_cls = BurstGPTDataset
    elif args.dataset_name == "hf":
        if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = VisionArenaDataset
            common_kwargs['dataset_subset'] = None
            common_kwargs['dataset_split'] = "train"
            sample_kwargs["enable_multimodal_chat"] = True
        elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = InstructCoderDataset
            common_kwargs['dataset_split'] = "train"
        elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = ConversationDataset
            common_kwargs['dataset_subset'] = args.hf_subset
            common_kwargs['dataset_split'] = args.hf_split
            sample_kwargs["enable_multimodal_chat"] = True

    else:
        raise ValueError(f"Unknown dataset name: {args.dataset_name}")
    # Remove None values
    sample_kwargs = {k: v for k, v in sample_kwargs.items() if v is not None}
    return dataset_cls(**common_kwargs).sample(**sample_kwargs)


def validate_dataset(args: argparse.Namespace, ):
    """
    Validate the dataset arguments.
    """
    # === Dataset Configuration ===
    if not args.dataset_path:
        print(
            "When dataset path is not set, it will default to random dataset")
        args.dataset_name = 'random'
        if args.input_len is None:
            raise ValueError("input_len must be provided for a random dataset")

    # === Dataset Name Specific Checks ===
    # --hf-subset and --hf-split: only used
    # when dataset_name is 'hf'
    if args.dataset_name != "hf" and (
            getattr(args, "hf_subset", None) is not None
            or getattr(args, "hf_split", None) is not None):
        warnings.warn("--hf-subset and --hf-split will be ignored \
                since --dataset-name is not 'hf'.",
                      stacklevel=2)
    elif args.dataset_name == "hf":
        if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
            assert getattr(
                args, 'backend', None
            ) and args.backend == "vllm-chat", "VisionArenaDataset needs to use vllm-chat as the backend." #noqa: E501
        elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
            assert getattr(
                args, 'backend', None
            ) and args.backend == "vllm", "InstructCoder dataset needs to use vllm as the backend." #noqa: E501
        elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
            assert getattr(
                args, 'backend', None
            ) and args.backend == "vllm-chat", "ConversationDataset needs to use vllm-chat as the backend." #noqa: E501
        else:
            raise ValueError(
                f"{args.dataset_path} is not supported by hf dataset.")

    # --random-range-ratio: only used when dataset_name is 'random'
    if args.dataset_name != 'random' and args.random_range_ratio is not None:
        warnings.warn("--random-range-ratio will be ignored since \
                --dataset-name is not 'random'.",
                      stacklevel=2)

    # --prefix-len: only used when dataset_name is 'random', 'sonnet', or not
    # set.
    if args.dataset_name not in {"random", "sonnet", None
                                 } and args.prefix_len is not None:
        warnings.warn("--prefix-len will be ignored since --dataset-name\
                is not 'random', 'sonnet', or not set.",
                      stacklevel=2)
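The hunks above add the shared helpers to the benchmark utilities module: `get_requests` dispatches to the matching `benchmark_dataset` class (random, ShareGPT, sonnet, BurstGPT, or HF-hosted datasets) and drops unset keyword arguments before sampling, while `validate_dataset` centralizes the dataset/flag consistency checks. A minimal usage sketch with hypothetical values, following the random-dataset path (no dataset path set, so `validate_dataset` falls back to the random dataset and requires `input_len`):

```python
import argparse

from transformers import AutoTokenizer

from benchmark_utils import get_requests, validate_dataset

# Hypothetical namespace; only the fields read by the helpers above are set.
args = argparse.Namespace(
    dataset_name="random", dataset_path=None, seed=0,
    input_len=128, output_len=64,
    lora_path=None, max_loras=None,
    random_range_ratio=None, prefix_len=None,
    backend="vllm",
)

validate_dataset(args)  # prints the fallback notice, keeps dataset_name="random"
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")  # any tokenizer
requests = get_requests(8, args, tokenizer)  # -> list[SampleRequest] of length 8
```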
@@ -81,6 +81,7 @@ class RejectionSampler(nn.Module):
        Returns:
            output_token_ids (torch.Tensor):
                A tensor containing the final output token IDs.
            acceptance_rate: min(p, q)
        '''
        assert metadata.max_spec_len <= MAX_SPEC_LEN
        # [num_tokens, vocab_size]
@@ -92,7 +93,7 @@ class RejectionSampler(nn.Module):
            sampling_metadata,
        )

        output_token_ids = rejection_sample(
        output_token_ids, output_probs = rejection_sample(
            metadata.draft_token_ids,
            metadata.num_draft_tokens,
            metadata.max_spec_len,
@@ -102,7 +103,9 @@ class RejectionSampler(nn.Module):
            bonus_token_ids,
            sampling_metadata,
        )
        return output_token_ids
        mask = output_probs != PLACEHOLDER_TOKEN_ID
        acceptance_rate = output_probs[mask].mean()
        return output_token_ids, acceptance_rate

    @staticmethod
    def parse_output(
@@ -170,6 +173,8 @@ def rejection_sample(
        device=device,
    )
    output_token_ids.fill_(PLACEHOLDER_TOKEN_ID)
    output_probs = torch.empty_like(output_token_ids, dtype=torch.float32)
    output_probs.fill_(PLACEHOLDER_TOKEN_ID)

    if sampling_metadata.all_greedy:
        is_greedy = None
@@ -180,6 +185,7 @@ def rejection_sample(
        target_argmax = target_probs.argmax(dim=-1)
        rejection_greedy_sample_kernel[(batch_size, )](
            output_token_ids,
            output_probs,
            cu_num_draft_tokens,
            draft_token_ids,
            target_argmax,
@@ -189,7 +195,7 @@ def rejection_sample(
            num_warps=1,
        )
        if sampling_metadata.all_greedy:
            return output_token_ids
            return output_token_ids, output_probs

    # Generate uniform probabilities for rejection sampling.
    # [num_tokens]
@@ -216,6 +222,7 @@ def rejection_sample(
    # Rejection sampling for random sampling requests.
    rejection_random_sample_kernel[(batch_size, )](
        output_token_ids,
        output_probs,
        cu_num_draft_tokens,
        draft_token_ids,
        draft_probs,
@@ -229,7 +236,7 @@ def rejection_sample(
        IS_NGRAM=draft_probs is None,
        num_warps=1,
    )
    return output_token_ids
    return output_token_ids, output_probs


def compute_probs(
@@ -432,6 +439,7 @@ def sample_recovered_tokens(
@triton.jit(do_not_specialize=["max_spec_len"])
def rejection_greedy_sample_kernel(
    output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
    output_probs_ptr,  # [batch_size, max_spec_len + 1]
    cu_num_draft_tokens_ptr,  # [batch_size]
    draft_token_ids_ptr,  # [num_tokens]
    target_argmax_ptr,  # [num_tokens]
@@ -459,14 +467,16 @@ def rejection_greedy_sample_kernel(

    rejected = False
    for pos in range(num_draft_tokens):
        draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
        target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos)
        if not rejected:
            draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
            target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos)
            tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
                     target_argmax_id)
            if draft_token_id != target_argmax_id:
                # Reject.
                rejected = True
            tl.store(output_probs_ptr + req_idx * (max_spec_len + 1) + pos,
                     draft_token_id == target_argmax_id)

    if not rejected:
        # If all tokens are accepted, append the bonus token.
@@ -480,6 +490,7 @@ def rejection_greedy_sample_kernel(
@triton.jit(do_not_specialize=["max_spec_len"])
def rejection_random_sample_kernel(
    output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
    output_probs_ptr,  # [batch_size, max_spec_len + 1]
    cu_num_draft_tokens_ptr,  # [batch_size]
    draft_token_ids_ptr,  # [num_tokens]
    draft_probs_ptr,  # [num_tokens, vocab_size] or None
@@ -507,17 +518,16 @@ def rejection_random_sample_kernel(

    rejected = False
    for pos in range(num_draft_tokens):
        draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
        if IS_NGRAM:
            draft_prob = 1
        else:
            draft_prob = tl.load(draft_probs_ptr +
                                 (start_idx + pos) * vocab_size +
                                 draft_token_id)
        target_prob = tl.load(target_probs_ptr +
                              (start_idx + pos) * vocab_size + draft_token_id)
        if not rejected:
            draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
            if IS_NGRAM:
                draft_prob = 1
            else:
                draft_prob = tl.load(draft_probs_ptr +
                                     (start_idx + pos) * vocab_size +
                                     draft_token_id)
            target_prob = tl.load(target_probs_ptr +
                                  (start_idx + pos) * vocab_size +
                                  draft_token_id)
            uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
            # NOTE(woosuk): While the draft probability should never be 0,
            # we check it to avoid NaNs. If it happens to be 0, we reject.
@@ -530,6 +540,8 @@ def rejection_random_sample_kernel(
            token_id = tl.load(recovered_token_ids_ptr + start_idx + pos)
            tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
                     token_id)
            tl.store(output_probs_ptr + req_idx * (max_spec_len + 1) + pos,
                     min(draft_prob, target_prob))

    if not rejected:
        # If all tokens are accepted, append the bonus token.
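The rejection-sampler hunks above thread a new `output_probs` buffer through both Triton kernels: the greedy path stores a 0/1 indicator (`draft_token_id == target_argmax_id`) and the random path stores `min(draft_prob, target_prob)` for every draft position it visits, leaving the placeholder fill value elsewhere. The forward pass then averages the non-placeholder entries and returns that mean as `acceptance_rate` alongside the token IDs. A minimal sketch of that final reduction, assuming `PLACEHOLDER_TOKEN_ID` is -1 (the constant is defined earlier in the same file and not shown in this diff):

```python
import torch

PLACEHOLDER_TOKEN_ID = -1  # assumed fill value, for illustration only

# [batch_size, max_spec_len + 1]: the kernels wrote min(p, q) (or a 0/1
# indicator on the greedy path) at the positions they visited and left the
# placeholder everywhere else.
output_probs = torch.tensor([
    [0.9, 0.4, PLACEHOLDER_TOKEN_ID],
    [1.0, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID],
], dtype=torch.float32)

mask = output_probs != PLACEHOLDER_TOKEN_ID
acceptance_rate = output_probs[mask].mean()
print(acceptance_rate)  # tensor(0.7667) == (0.9 + 0.4 + 1.0) / 3
```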
vllm/v1/spec_decode/auto_tuner.py (new file, 133 lines)
@@ -0,0 +1,133 @@
# SPDX-License-Identifier: Apache-2.0
from vllm.v1.worker.gpu_input_batch import CachedRequestState


class AutoTuner:

    def __init__(self):
        # Some tracking metrics
        # for the auto-tuning process.
        # metrics specific to ngram_proposer.
        self.step_cnt = 0
        self.match_cnt = 0
        self.total_cnt = 0
        self.past_acceptance_rates = []
        self.past_match_ratios = []

        # config
        self.update_interval = 100
        self.window_size = 10000
        self.c_kv_load = 0.1
        self.c_computation = 0.2
        self.c_overhead = 0.3

        # some cached values
        self.last_verified_len = 0

    def get_verified_len(self, batch_size: int, match_cnt: int,
                         num_kv_tokens: int, max_draft_len: int) -> int:
        if self.step_cnt % self.update_interval != 0:
            return self.last_verified_len

        best_verified_len = 0
        max_goodput = -1.0
        for i in range(max_draft_len):
            cur_goodput, draft_time, target_time = self._predict_goodput(
                batch_size, match_cnt, num_kv_tokens, i)
            # print(f"Goodput for proposal len {i}: {cur_goodput}")
            if cur_goodput > max_goodput:
                max_goodput = cur_goodput
                best_verified_len = i
            else:
                break

        self.last_verified_len = best_verified_len
        return best_verified_len

    def adjust_draft_len(self, req_states: dict[str, CachedRequestState],
                         draft_token_ids: list[list[int]]):
        """
        Adjust the draft length based on the verified length.
        """

        # Calculate parameters used for goodput prediction.
        num_kv_tokens = 0
        for req_id in req_states:
            num_kv_tokens += req_states[req_id].num_tokens
        batch_size = len(draft_token_ids)
        match_cnt = 0
        max_draft_len = 0

        for i in range(batch_size):
            if len(draft_token_ids[i]) == 0:
                continue
            match_cnt += 1
            max_draft_len = max(max_draft_len, len(draft_token_ids[i]))
        self.total_cnt += batch_size
        self.match_cnt += match_cnt
        self.past_match_ratios.append(match_cnt * 1.0 / (batch_size))

        return draft_token_ids
        # Use goodput prediction to get the verified length.
        verified_len = self.get_verified_len(batch_size, match_cnt,
                                             num_kv_tokens, max_draft_len)

        draft_token_ids = [draft[:verified_len] for draft in draft_token_ids]
        return draft_token_ids

    def update_stats(self, acceptance_rate: float):
        self.step_cnt += 1
        if self.step_cnt % 20 == 0:
            print(
                f"Step {self.step_cnt}: "
                f"Last acceptance rate: {acceptance_rate:.2f}",
                f"Last match ratio: {self.past_match_ratios[-1]:.2f}",
                f"Global acceptance rate: {self.acceptance_rate:.2f}",
                "Global match ratio:",
                f"{self.match_cnt / (self.total_cnt + 1e-5):.2f}",
            )

        self.past_acceptance_rates.append(acceptance_rate)

    @property
    def acceptance_rate(self):
        window_acceptance_rates = self.past_acceptance_rates[-self.
                                                             window_size:]
        return sum(window_acceptance_rates) / len(window_acceptance_rates)

    def _predict_goodput(self, batch_size: int, match_cnt: int,
                         num_kv_tokens: int,
                         verified_len: int) -> tuple[float, float, float]:
        """
        Predict the goodput for a given verified length.
        """
        generated_len = self._predict_generated_len(batch_size, match_cnt,
                                                    verified_len)
        draft_time = self._predict_draft_time()
        target_time = self._predict_target_time(batch_size, match_cnt,
                                                num_kv_tokens, verified_len)
        batch_time = draft_time + target_time
        return generated_len / batch_time, draft_time, target_time

    def _predict_generated_len(self, batch_size: int, match_cnt: int,
                               verified_len: int):
        spec_gen_len = float((1 - self.acceptance_rate**(verified_len + 1)) /
                             (1 - self.acceptance_rate))
        non_spec_gen_len = batch_size - match_cnt
        return spec_gen_len + non_spec_gen_len

    def _predict_draft_time(self):
        # TODO: We need to benchmark and model this.
        return 0

    def _predict_target_time(self, batch_size: int, match_cnt: int,
                             num_kv_tokens: int, verified_len: int):
        kv_load_time = num_kv_tokens * self.c_kv_load

        # Computation time
        # +1 for the input token.
        num_batched_tokens = match_cnt * (verified_len + 1) + (batch_size -
                                                               match_cnt)
        computation_time = num_batched_tokens * self.c_computation

        return kv_load_time + computation_time + self.c_overhead
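The new `AutoTuner` chooses a `verified_len` by maximizing a simple goodput model: expected tokens generated in the step divided by predicted step time. `_predict_generated_len` uses the geometric-series estimate `(1 - alpha**(k + 1)) / (1 - alpha)` for speculated tokens (with `alpha` the windowed acceptance rate) plus one token per request without a draft; `_predict_target_time` is a linear model with hard-coded placeholder constants, and draft time is modeled as zero for now. Note also the early `return draft_token_ids` in `adjust_draft_len`, which currently short-circuits the goodput-based truncation below it, so that part looks like disabled work in progress. The snippet below only re-evaluates the same formulas for a few candidate lengths with made-up numbers to show the shape of the trade-off; every value is an illustrative assumption.

```python
# Illustrative numbers only; the constants mirror AutoTuner.__init__ above.
alpha = 0.7                     # assumed windowed acceptance rate
batch_size, match_cnt = 8, 6    # 6 of 8 requests got an ngram match
num_kv_tokens = 4096
c_kv_load, c_computation, c_overhead = 0.1, 0.2, 0.3


def predict_goodput(verified_len: int) -> float:
    # Mirrors _predict_generated_len.
    spec_gen_len = (1 - alpha**(verified_len + 1)) / (1 - alpha)
    generated_len = spec_gen_len + (batch_size - match_cnt)
    # Mirrors _predict_target_time; draft time is modeled as 0.
    num_batched_tokens = match_cnt * (verified_len + 1) + (batch_size - match_cnt)
    target_time = (num_kv_tokens * c_kv_load +
                   num_batched_tokens * c_computation + c_overhead)
    return generated_len / target_time


# get_verified_len keeps increasing the candidate length while goodput
# improves and stops at the first decrease.
for k in range(4):
    print(k, round(predict_goodput(k), 5))
```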
@@ -34,6 +34,7 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
                             ModelRunnerOutput)
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import RejectionSampler
from vllm.v1.spec_decode.auto_tuner import AutoTuner
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.spec_decode.utils import is_spec_decode_supported
@@ -156,6 +157,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        self.use_spec_decode = False
        if self.speculative_config:
            self.use_spec_decode = True
            self.auto_tuner = AutoTuner()
            assert self.speculative_config.method == "ngram", \
                "Currently, only ngram spec decode is supported in V1."
            if get_pp_group().is_last_rank:
@@ -1087,13 +1089,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
            # separate storage from the original `logits` tensor. Therefore,
            # it is safe to update `target_logits` in place.
            target_logits = logits[spec_decode_metadata.target_logits_indices]
            output_token_ids = self.rejection_sampler(
            output_token_ids, acceptance_rate = self.rejection_sampler(
                spec_decode_metadata,
                None,  # draft_probs
                target_logits,
                bonus_token_ids,
                sampling_metadata,
            )
            self.auto_tuner.update_stats(acceptance_rate)
            sampler_output.sampled_token_ids = output_token_ids

        # TODO(woosuk): The following loop can be slow since it iterates over
@@ -1191,6 +1194,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                draft_token_ids.append([])
            else:
                draft_token_ids.append(drafter_output.tolist())

        draft_token_ids = self.auto_tuner.adjust_draft_len(
            self.requests, draft_token_ids)
        return draft_token_ids

    def load_model(self) -> None:
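Finally, the model-runner hunks wire the tuner into the V1 ngram speculative-decoding step: the rejection sampler's new `acceptance_rate` feeds `update_stats`, and the drafter's proposals pass through `adjust_draft_len` before being returned. A compressed, hypothetical sketch of that control flow (the function below and its argument names are invented for illustration; the real logic lives inside `GPUModelRunner`):

```python
def spec_decode_step(runner, spec_decode_metadata, target_logits,
                     bonus_token_ids, sampling_metadata, drafter_outputs):
    # 1) Verification: the rejection sampler now also reports an acceptance rate.
    output_token_ids, acceptance_rate = runner.rejection_sampler(
        spec_decode_metadata,
        None,  # draft_probs
        target_logits,
        bonus_token_ids,
        sampling_metadata,
    )
    runner.auto_tuner.update_stats(acceptance_rate)

    # 2) Drafting: empty proposals stay empty; the tuner may trim the rest.
    draft_token_ids = [[] if out is None else out.tolist()
                       for out in drafter_outputs]
    draft_token_ids = runner.auto_tuner.adjust_draft_len(runner.requests,
                                                         draft_token_ids)
    return output_token_ids, draft_token_ids
```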