Compare commits

...

7 Commits

SHA1 Message Date
af985d70bf change to greedy 2025-04-01 15:53:26 -07:00
b484b79504 fix 2025-04-01 15:46:41 -07:00
8fcd4d18e0 minor 2025-04-01 13:51:04 -07:00
50e2788383 dsd draft 2025-04-01 13:33:07 -07:00
f0ca3a6142 minor 2025-03-31 20:05:48 -07:00
528088392e minor 2025-03-31 15:27:06 -07:00
9030400353 add datasets to benchmark_latency 2025-03-31 15:25:08 -07:00
6 changed files with 348 additions and 137 deletions

View File: benchmarks/benchmark_latency.py

@@ -5,18 +5,21 @@ import argparse
import dataclasses
import json
import os
import random
import time
from pathlib import Path
from typing import Any, Optional
from typing import Any, Optional, Union
import numpy as np
import torch
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
from benchmark_utils import (convert_to_pytorch_benchmark_format, get_requests,
validate_dataset, write_to_json)
from tqdm import tqdm
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import PromptType
from vllm.inputs import TextPrompt, TokensPrompt
from vllm.sampling_params import BeamSearchParams
from vllm.utils import FlexibleArgumentParser
@@ -48,28 +51,34 @@ def main(args: argparse.Namespace):
sampling_params = SamplingParams(
n=args.n,
temperature=1.0,
temperature=0,
top_p=1.0,
ignore_eos=True,
max_tokens=args.output_len,
detokenize=not args.disable_detokenize,
)
print(sampling_params)
dummy_prompt_token_ids = np.random.randint(10000,
size=(args.batch_size,
args.input_len))
dummy_prompts: list[PromptType] = [{
"prompt_token_ids": batch
} for batch in dummy_prompt_token_ids.tolist()]
tokenizer = AutoTokenizer.from_pretrained(
args.tokenizer, trust_remote_code=args.trust_remote_code)
requests = get_requests(args.batch_size, args, tokenizer)
prompts: list[Union[TextPrompt, TokensPrompt]] = []
for request in requests:
prompts.append(
TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"],
multi_modal_data=request.multi_modal_data)
if "prompt_token_ids" in request.prompt else \
TextPrompt(prompt=request.prompt,
multi_modal_data=request.multi_modal_data))
def llm_generate():
if not args.use_beam_search:
llm.generate(dummy_prompts,
llm.generate(prompts,
sampling_params=sampling_params,
use_tqdm=False)
else:
llm.beam_search(
dummy_prompts,
prompts,
BeamSearchParams(
beam_width=args.n,
max_tokens=args.output_len,
@@ -180,7 +189,44 @@ if __name__ == "__main__":
help=("Do not detokenize responses (i.e. do not include "
"detokenization time in the latency measurement)"),
)
parser.add_argument(
"--dataset-name",
type=str,
choices=["sharegpt", "random", "sonnet", "burstgpt", "hf"],
help="Name of the dataset to benchmark on.",
default="sharegpt")
# random dataset
parser.add_argument(
"--random-range-ratio",
type=float,
default=None,
help="Range of sampled ratio of input/output length, "
"used only for RandomDataSet.",
)
parser.add_argument("--dataset-path",
type=str,
default=None,
help="Path to the dataset")
# LoRA
parser.add_argument(
"--lora-path",
type=str,
default=None,
help="Path to the lora adapters to use. This can be an absolute path, "
"a relative path, or a Hugging Face model identifier.")
parser.add_argument("--prefix-len",
type=int,
default=None,
help="Number of prefix tokens per request."
"This is for the RandomDataset and SonnetDataset")
parser = EngineArgs.add_cli_args(parser)
args = parser.parse_args()
if args.tokenizer is None:
args.tokenizer = args.model
args.backend = "vllm"
validate_dataset(args)
random.seed(0)
main(args)
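Note: benchmark_latency.py now samples real dataset requests through the shared get_requests() helper instead of random dummy token IDs, and wraps each request as a TokensPrompt (pre-tokenized) or TextPrompt (raw text). A minimal sketch of that dispatch, with a hypothetical make_prompt helper that is not part of the diff:

    from vllm.inputs import TextPrompt, TokensPrompt

    def make_prompt(request):
        # Pre-tokenized requests carry IDs under "prompt_token_ids";
        # anything else is treated as raw text.
        if isinstance(request.prompt, dict) and "prompt_token_ids" in request.prompt:
            return TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"],
                                multi_modal_data=request.multi_modal_data)
        return TextPrompt(prompt=request.prompt,
                          multi_modal_data=request.multi_modal_data)

    prompts = [make_prompt(r) for r in requests]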

View File: benchmarks/benchmark_throughput.py

@@ -11,11 +11,9 @@ from typing import Any, Optional, Union
import torch
import uvloop
from benchmark_dataset import (BurstGPTDataset, ConversationDataset,
InstructCoderDataset, RandomDataset,
SampleRequest, ShareGPTDataset, SonnetDataset,
VisionArenaDataset)
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
from benchmark_dataset import SampleRequest
from benchmark_utils import (convert_to_pytorch_benchmark_format, get_requests,
validate_dataset, write_to_json)
from tqdm import tqdm
from transformers import (AutoModelForCausalLM, AutoTokenizer,
PreTrainedTokenizerBase)
@@ -287,59 +285,6 @@ def save_to_pytorch_benchmark_format(args: argparse.Namespace,
write_to_json(pt_file, pt_records)
def get_requests(args, tokenizer):
# Common parameters for all dataset types.
common_kwargs = {
"dataset_path": args.dataset_path,
"random_seed": args.seed,
}
sample_kwargs = {
"tokenizer": tokenizer,
"lora_path": args.lora_path,
"max_loras": args.max_loras,
"num_requests": args.num_prompts,
"input_len": args.input_len,
"output_len": args.output_len,
}
if args.dataset_path is None or args.dataset_name == "random":
sample_kwargs["range_ratio"] = args.random_range_ratio
sample_kwargs["prefix_len"] = args.prefix_len
dataset_cls = RandomDataset
elif args.dataset_name == "sharegpt":
dataset_cls = ShareGPTDataset
if args.backend == "vllm-chat":
sample_kwargs["enable_multimodal_chat"] = True
elif args.dataset_name == "sonnet":
assert tokenizer.chat_template or tokenizer.default_chat_template, (
"Tokenizer/model must have chat template for sonnet dataset.")
dataset_cls = SonnetDataset
sample_kwargs["prefix_len"] = args.prefix_len
sample_kwargs["return_prompt_formatted"] = True
elif args.dataset_name == "burstgpt":
dataset_cls = BurstGPTDataset
elif args.dataset_name == "hf":
if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
dataset_cls = VisionArenaDataset
common_kwargs['dataset_subset'] = None
common_kwargs['dataset_split'] = "train"
sample_kwargs["enable_multimodal_chat"] = True
elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
dataset_cls = InstructCoderDataset
common_kwargs['dataset_split'] = "train"
elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
dataset_cls = ConversationDataset
common_kwargs['dataset_subset'] = args.hf_subset
common_kwargs['dataset_split'] = args.hf_split
sample_kwargs["enable_multimodal_chat"] = True
else:
raise ValueError(f"Unknown dataset name: {args.dataset_name}")
# Remove None values
sample_kwargs = {k: v for k, v in sample_kwargs.items() if v is not None}
return dataset_cls(**common_kwargs).sample(**sample_kwargs)
def main(args: argparse.Namespace):
if args.seed is None:
args.seed = 0
@@ -348,7 +293,7 @@ def main(args: argparse.Namespace):
# Sample the requests.
tokenizer = AutoTokenizer.from_pretrained(
args.tokenizer, trust_remote_code=args.trust_remote_code)
requests = get_requests(args, tokenizer)
requests = get_requests(args.num_prompts, args, tokenizer)
is_multi_modal = any(request.multi_modal_data is not None
for request in requests)
request_outputs: Optional[list[RequestOutput]] = None
@@ -449,47 +394,8 @@ def validate_args(args):
if args.backend not in valid_backends:
raise ValueError(f"Unsupported backend: {args.backend}")
# === Dataset Configuration ===
if not args.dataset and not args.dataset_path:
print(
"When dataset path is not set, it will default to random dataset")
args.dataset_name = 'random'
if args.input_len is None:
raise ValueError("input_len must be provided for a random dataset")
# === Dataset Name Specific Checks ===
# --hf-subset and --hf-split: only used
# when dataset_name is 'hf'
if args.dataset_name != "hf" and (
getattr(args, "hf_subset", None) is not None
or getattr(args, "hf_split", None) is not None):
warnings.warn("--hf-subset and --hf-split will be ignored \
since --dataset-name is not 'hf'.",
stacklevel=2)
elif args.dataset_name == "hf":
if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
assert args.backend == "vllm-chat", "VisionArenaDataset needs to use vllm-chat as the backend." #noqa: E501
elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
assert args.backend == "vllm", "InstructCoder dataset needs to use vllm as the backend." #noqa: E501
elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
assert args.backend == "vllm-chat", "ConversationDataset needs to use vllm-chat as the backend." #noqa: E501
else:
raise ValueError(
f"{args.dataset_path} is not supported by hf dataset.")
# --random-range-ratio: only used when dataset_name is 'random'
if args.dataset_name != 'random' and args.random_range_ratio is not None:
warnings.warn("--random-range-ratio will be ignored since \
--dataset-name is not 'random'.",
stacklevel=2)
# --prefix-len: only used when dataset_name is 'random', 'sonnet', or not
# set.
if args.dataset_name not in {"random", "sonnet", None
} and args.prefix_len is not None:
warnings.warn("--prefix-len will be ignored since --dataset-name\
is not 'random', 'sonnet', or not set.",
stacklevel=2)
# === Dataset Validation ===
validate_dataset(args)
# === LoRA Settings ===
if getattr(args, "enable_lora", False) and args.backend != "vllm":
@@ -529,14 +435,6 @@ if __name__ == "__main__":
choices=["sharegpt", "random", "sonnet", "burstgpt", "hf"],
help="Name of the dataset to benchmark on.",
default="sharegpt")
parser.add_argument(
"--dataset",
type=str,
default=None,
help="Path to the ShareGPT dataset, will be deprecated in\
the next release. The dataset is expected to "
"be a json in form of list[dict[..., conversations: "
"list[dict[..., value: <prompt_or_response>]]]]")
parser.add_argument("--dataset-path",
type=str,
default=None,
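Note: with get_requests() and validate_dataset() moved into benchmark_utils.py (next file), both benchmark scripts share a single call site and only the request count differs: benchmark_latency.py passes args.batch_size, benchmark_throughput.py passes args.num_prompts. A minimal sketch of the shared flow, assuming an already-parsed args namespace:

    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer, trust_remote_code=args.trust_remote_code)
    validate_dataset(args)  # may rewrite args.dataset_name to "random"
    requests = get_requests(args.num_prompts, args, tokenizer)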

View File: benchmarks/benchmark_utils.py

@@ -4,8 +4,14 @@ import argparse
import json
import math
import os
import warnings
from typing import Any
from benchmark_dataset import (BurstGPTDataset, ConversationDataset,
InstructCoderDataset, RandomDataset,
SampleRequest, ShareGPTDataset, SonnetDataset,
VisionArenaDataset)
def convert_to_pytorch_benchmark_format(args: argparse.Namespace,
metrics: dict[str, list],
@@ -67,3 +73,113 @@ class InfEncoder(json.JSONEncoder):
def write_to_json(filename: str, records: list) -> None:
with open(filename, "w") as f:
json.dump(records, f, cls=InfEncoder)
def get_requests(num_requests: int, args: argparse.Namespace,
tokenizer: Any) -> list[SampleRequest]:
"""
Sample the requests for the benchmark.
"""
# Common parameters for all dataset types.
common_kwargs = {
"dataset_path": args.dataset_path,
"random_seed": args.seed,
}
sample_kwargs = {
"tokenizer": tokenizer,
"lora_path": args.lora_path,
"max_loras": args.max_loras,
"num_requests": num_requests,
"input_len": args.input_len,
"output_len": args.output_len,
}
if args.dataset_path is None or args.dataset_name == "random":
sample_kwargs["range_ratio"] = args.random_range_ratio
sample_kwargs["prefix_len"] = args.prefix_len
dataset_cls = RandomDataset
elif args.dataset_name == "sharegpt":
dataset_cls = ShareGPTDataset
if getattr(args, "backend", False) and args.backend == "vllm-chat":
sample_kwargs["enable_multimodal_chat"] = True
elif args.dataset_name == "sonnet":
assert tokenizer.chat_template or tokenizer.default_chat_template, (
"Tokenizer/model must have chat template for sonnet dataset.")
dataset_cls = SonnetDataset
sample_kwargs["prefix_len"] = args.prefix_len
sample_kwargs["return_prompt_formatted"] = True
elif args.dataset_name == "burstgpt":
dataset_cls = BurstGPTDataset
elif args.dataset_name == "hf":
if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
dataset_cls = VisionArenaDataset
common_kwargs['dataset_subset'] = None
common_kwargs['dataset_split'] = "train"
sample_kwargs["enable_multimodal_chat"] = True
elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
dataset_cls = InstructCoderDataset
common_kwargs['dataset_split'] = "train"
elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
dataset_cls = ConversationDataset
common_kwargs['dataset_subset'] = args.hf_subset
common_kwargs['dataset_split'] = args.hf_split
sample_kwargs["enable_multimodal_chat"] = True
else:
raise ValueError(f"Unknown dataset name: {args.dataset_name}")
# Remove None values
sample_kwargs = {k: v for k, v in sample_kwargs.items() if v is not None}
return dataset_cls(**common_kwargs).sample(**sample_kwargs)
def validate_dataset(args: argparse.Namespace):
"""
Validate the dataset arguments.
"""
# === Dataset Configuration ===
if not args.dataset_path:
print(
"When dataset path is not set, it will default to random dataset")
args.dataset_name = 'random'
if args.input_len is None:
raise ValueError("input_len must be provided for a random dataset")
# === Dataset Name Specific Checks ===
# --hf-subset and --hf-split: only used
# when dataset_name is 'hf'
if args.dataset_name != "hf" and (
getattr(args, "hf_subset", None) is not None
or getattr(args, "hf_split", None) is not None):
warnings.warn("--hf-subset and --hf-split will be ignored \
since --dataset-name is not 'hf'.",
stacklevel=2)
elif args.dataset_name == "hf":
if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
assert getattr(
args, 'backend', None
) and args.backend == "vllm-chat", "VisionArenaDataset needs to use vllm-chat as the backend." #noqa: E501
elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
assert getattr(
args, 'backend', None
) and args.backend == "vllm", "InstructCoder dataset needs to use vllm as the backend." #noqa: E501
elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
assert getattr(
args, 'backend', None
) and args.backend == "vllm-chat", "ConversationDataset needs to use vllm-chat as the backend." #noqa: E501
else:
raise ValueError(
f"{args.dataset_path} is not supported by hf dataset.")
# --random-range-ratio: only used when dataset_name is 'random'
if args.dataset_name != 'random' and args.random_range_ratio is not None:
warnings.warn("--random-range-ratio will be ignored since \
--dataset-name is not 'random'.",
stacklevel=2)
# --prefix-len: only used when dataset_name is 'random', 'sonnet', or not
# set.
if args.dataset_name not in {"random", "sonnet", None
} and args.prefix_len is not None:
warnings.warn("--prefix-len will be ignored since --dataset-name\
is not 'random', 'sonnet', or not set.",
stacklevel=2)
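The rules above can be exercised directly; a small sketch, assuming an argparse.Namespace carrying the benchmark's fields (values are illustrative):

    import argparse

    args = argparse.Namespace(
        dataset_path=None, dataset_name="sharegpt", input_len=128,
        random_range_ratio=None, prefix_len=None,
        hf_subset=None, hf_split=None)

    validate_dataset(args)
    # dataset_path is unset, so dataset_name is rewritten to "random";
    # had input_len also been None, the call would raise ValueError.
    assert args.dataset_name == "random"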

View File: vllm/v1/sample/rejection_sampler.py

@@ -81,6 +81,7 @@ class RejectionSampler(nn.Module):
Returns:
output_token_ids (torch.Tensor):
A tensor containing the final output token IDs.
acceptance_rate (torch.Tensor):
    The mean of min(p, q) over the draft token positions, i.e. the
    average probability of a draft token being accepted.
'''
assert metadata.max_spec_len <= MAX_SPEC_LEN
# [num_tokens, vocab_size]
@@ -92,7 +93,7 @@
sampling_metadata,
)
output_token_ids = rejection_sample(
output_token_ids, output_probs = rejection_sample(
metadata.draft_token_ids,
metadata.num_draft_tokens,
metadata.max_spec_len,
@@ -102,7 +103,9 @@
bonus_token_ids,
sampling_metadata,
)
return output_token_ids
mask = output_probs != PLACEHOLDER_TOKEN_ID
acceptance_rate = output_probs[mask].mean()
return output_token_ids, acceptance_rate
@staticmethod
def parse_output(
@@ -170,6 +173,8 @@ def rejection_sample(
device=device,
)
output_token_ids.fill_(PLACEHOLDER_TOKEN_ID)
output_probs = torch.empty_like(output_token_ids, dtype=torch.float32)
output_probs.fill_(PLACEHOLDER_TOKEN_ID)
if sampling_metadata.all_greedy:
is_greedy = None
@@ -180,6 +185,7 @@
target_argmax = target_probs.argmax(dim=-1)
rejection_greedy_sample_kernel[(batch_size, )](
output_token_ids,
output_probs,
cu_num_draft_tokens,
draft_token_ids,
target_argmax,
@@ -189,7 +195,7 @@
num_warps=1,
)
if sampling_metadata.all_greedy:
return output_token_ids
return output_token_ids, output_probs
# Generate uniform probabilities for rejection sampling.
# [num_tokens]
@@ -216,6 +222,7 @@
# Rejection sampling for random sampling requests.
rejection_random_sample_kernel[(batch_size, )](
output_token_ids,
output_probs,
cu_num_draft_tokens,
draft_token_ids,
draft_probs,
@@ -229,7 +236,7 @@
IS_NGRAM=draft_probs is None,
num_warps=1,
)
return output_token_ids
return output_token_ids, output_probs
def compute_probs(
@@ -432,6 +439,7 @@ def sample_recovered_tokens(
@triton.jit(do_not_specialize=["max_spec_len"])
def rejection_greedy_sample_kernel(
output_token_ids_ptr, # [batch_size, max_spec_len + 1]
output_probs_ptr, # [batch_size, max_spec_len + 1]
cu_num_draft_tokens_ptr, # [batch_size]
draft_token_ids_ptr, # [num_tokens]
target_argmax_ptr, # [num_tokens]
@@ -459,14 +467,16 @@ def rejection_greedy_sample_kernel(
rejected = False
for pos in range(num_draft_tokens):
if not rejected:
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos)
if not rejected:
tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
target_argmax_id)
if draft_token_id != target_argmax_id:
# Reject.
rejected = True
tl.store(output_probs_ptr + req_idx * (max_spec_len + 1) + pos,
draft_token_id == target_argmax_id)
if not rejected:
# If all tokens are accepted, append the bonus token.
@@ -480,6 +490,7 @@
@triton.jit(do_not_specialize=["max_spec_len"])
def rejection_random_sample_kernel(
output_token_ids_ptr, # [batch_size, max_spec_len + 1]
output_probs_ptr, # [batch_size, max_spec_len + 1]
cu_num_draft_tokens_ptr, # [batch_size]
draft_token_ids_ptr, # [num_tokens]
draft_probs_ptr, # [num_tokens, vocab_size] or None
@@ -507,7 +518,6 @@ def rejection_random_sample_kernel(
rejected = False
for pos in range(num_draft_tokens):
if not rejected:
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
if IS_NGRAM:
draft_prob = 1
@@ -516,8 +526,8 @@
(start_idx + pos) * vocab_size +
draft_token_id)
target_prob = tl.load(target_probs_ptr +
(start_idx + pos) * vocab_size +
draft_token_id)
(start_idx + pos) * vocab_size + draft_token_id)
if not rejected:
uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
# NOTE(woosuk): While the draft probability should never be 0,
# we check it to avoid NaNs. If it happens to be 0, we reject.
@@ -530,6 +540,8 @@
token_id = tl.load(recovered_token_ids_ptr + start_idx + pos)
tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
token_id)
tl.store(output_probs_ptr + req_idx * (max_spec_len + 1) + pos,
min(draft_prob, target_prob))
if not rejected:
# If all tokens are accepted, append the bonus token.
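Taken together, the kernels record a per-position acceptance probability: the random path stores min(draft_prob, target_prob), the all-greedy path stores the 0/1 indicator draft_token_id == target_argmax_id, and slots the kernels never write keep the placeholder fill. A tensor-level sketch of the reduction done in forward(), assuming PLACEHOLDER_TOKEN_ID is -1:

    import torch

    PLACEHOLDER_TOKEN_ID = -1  # assumed sentinel, matching the fill value above

    # output_probs: [batch_size, max_spec_len + 1]; -1 marks unwritten slots,
    # everything else is min(p, q) (or a 0/1 match indicator for greedy).
    output_probs = torch.tensor([[0.9, 0.4, -1.0],
                                 [1.0, -1.0, -1.0]])
    mask = output_probs != PLACEHOLDER_TOKEN_ID
    acceptance_rate = output_probs[mask].mean()  # (0.9 + 0.4 + 1.0) / 3 ~= 0.77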

View File: vllm/v1/spec_decode/auto_tuner.py

@@ -0,0 +1,133 @@
# SPDX-License-Identifier: Apache-2.0
from vllm.v1.worker.gpu_input_batch import CachedRequestState
class AutoTuner:
def __init__(self):
# Some tracking metrics
# for the auto-tuning process.
# metrics specific to ngram_proposer.
self.step_cnt = 0
self.match_cnt = 0
self.total_cnt = 0
self.past_acceptance_rates = []
self.past_match_ratios = []
# config
self.update_interval = 100
self.window_size = 10000
self.c_kv_load = 0.1
self.c_computation = 0.2
self.c_overhead = 0.3
# some cached values
self.last_verified_len = 0
def get_verified_len(self, batch_size: int, match_cnt: int,
num_kv_tokens: int, max_draft_len: int) -> int:
if self.step_cnt % self.update_interval != 0:
return self.last_verified_len
best_verified_len = 0
max_goodput = -1.0
for i in range(max_draft_len):
cur_goodput, draft_time, target_time = self._predict_goodput(
batch_size, match_cnt, num_kv_tokens, i)
# print(f"Goodput for proposal len {i}: {cur_goodput}")
if cur_goodput > max_goodput:
max_goodput = cur_goodput
best_verified_len = i
else:
break
self.last_verified_len = best_verified_len
return best_verified_len
def adjust_draft_len(self, req_states: dict[str, CachedRequestState],
draft_token_ids: list[list[int]]):
"""
Adjust the draft length based on the verified length.
"""
# Calculate parameters used for goodput prediction.
num_kv_tokens = 0
for req_id in req_states:
num_kv_tokens += req_states[req_id].num_tokens
batch_size = len(draft_token_ids)
match_cnt = 0
max_draft_len = 0
for i in range(batch_size):
if len(draft_token_ids[i]) == 0:
continue
match_cnt += 1
max_draft_len = max(max_draft_len, len(draft_token_ids[i]))
self.total_cnt += batch_size
self.match_cnt += match_cnt
self.past_match_ratios.append(match_cnt * 1.0 / (batch_size))
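# NOTE: as written, the early return below bypasses the goodput-based
# truncation that follows, so draft lengths pass through unchanged
# (possibly intentional while this is a draft).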
return draft_token_ids
# Use goodput prediction to get the verified length.
verified_len = self.get_verified_len(batch_size, match_cnt,
num_kv_tokens, max_draft_len)
draft_token_ids = [draft[:verified_len] for draft in draft_token_ids]
return draft_token_ids
def update_stats(self, acceptance_rate: float):
self.step_cnt += 1
if self.step_cnt % 20 == 0:
print(
f"Step {self.step_cnt}: "
f"Last acceptance rate: {acceptance_rate:.2f}",
f"Last match ratio: {self.past_match_ratios[-1]:.2f}",
f"Global acceptance rate: {self.acceptance_rate:.2f}",
"Global match ratio:",
f"{self.match_cnt / (self.total_cnt + 1e-5):.2f}",
)
self.past_acceptance_rates.append(acceptance_rate)
@property
def acceptance_rate(self):
window_acceptance_rates = self.past_acceptance_rates[-self.window_size:]
return sum(window_acceptance_rates) / len(window_acceptance_rates)
def _predict_goodput(self, batch_size: int, match_cnt: int,
num_kv_tokens: int,
verified_len: int) -> tuple[float, float, float]:
"""
Predict the goodput for a given verified length.
"""
generated_len = self._predict_generated_len(batch_size, match_cnt,
verified_len)
draft_time = self._predict_draft_time()
target_time = self._predict_target_time(batch_size, match_cnt,
num_kv_tokens, verified_len)
batch_time = draft_time + target_time
return generated_len / batch_time, draft_time, target_time
def _predict_generated_len(self, batch_size: int, match_cnt: int,
verified_len: int):
spec_gen_len = float((1 - self.acceptance_rate**(verified_len + 1)) /
(1 - self.acceptance_rate))
non_spec_gen_len = batch_size - match_cnt
return spec_gen_len + non_spec_gen_len
def _predict_draft_time(self):
# TODO: We need to benchmark and model this.
return 0
def _predict_target_time(self, batch_size: int, match_cnt: int,
num_kv_tokens: int, verified_len: int):
kv_load_time = num_kv_tokens * self.c_kv_load
# Computation time
# +1 for the input token.
num_batched_tokens = match_cnt * (verified_len + 1) + (batch_size -
match_cnt)
computation_time = num_batched_tokens * self.c_computation
return kv_load_time + computation_time + self.c_overhead
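The goodput model above is a closed-form estimate: with acceptance rate a, a request that verifies k draft tokens yields 1 + a + ... + a^k = (1 - a^(k+1)) / (1 - a) expected tokens (undefined at a == 1), while step time is modeled linearly in KV tokens loaded and tokens batched. A worked sketch with illustrative values (it mirrors _predict_generated_len, which adds the series once rather than once per matched request):

    a = 0.8                                    # acceptance rate
    c_kv_load, c_comp, c_over = 0.1, 0.2, 0.3  # cost constants from __init__
    batch_size, match_cnt, num_kv_tokens = 4, 3, 1024

    for k in range(4):  # candidate verified lengths
        spec = (1 - a ** (k + 1)) / (1 - a)    # 1 + a + ... + a^k
        gen_len = spec + (batch_size - match_cnt)
        tokens = match_cnt * (k + 1) + (batch_size - match_cnt)
        step_time = num_kv_tokens * c_kv_load + tokens * c_comp + c_over
        print(k, gen_len / step_time)  # get_verified_len keeps the argmax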

View File: vllm/v1/worker/gpu_model_runner.py

@@ -34,6 +34,7 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
ModelRunnerOutput)
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import RejectionSampler
from vllm.v1.spec_decode.auto_tuner import AutoTuner
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.spec_decode.utils import is_spec_decode_supported
@@ -156,6 +157,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.use_spec_decode = False
if self.speculative_config:
self.use_spec_decode = True
self.auto_tuner = AutoTuner()
assert self.speculative_config.method == "ngram", \
"Currently, only ngram spec decode is supported in V1."
if get_pp_group().is_last_rank:
@@ -1087,13 +1089,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# separate storage from the original `logits` tensor. Therefore,
# it is safe to update `target_logits` in place.
target_logits = logits[spec_decode_metadata.target_logits_indices]
output_token_ids = self.rejection_sampler(
output_token_ids, acceptance_rate = self.rejection_sampler(
spec_decode_metadata,
None, # draft_probs
target_logits,
bonus_token_ids,
sampling_metadata,
)
self.auto_tuner.update_stats(acceptance_rate)
sampler_output.sampled_token_ids = output_token_ids
# TODO(woosuk): The following loop can be slow since it iterates over
@@ -1191,6 +1194,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
draft_token_ids.append([])
else:
draft_token_ids.append(drafter_output.tolist())
draft_token_ids = self.auto_tuner.adjust_draft_len(
self.requests, draft_token_ids)
return draft_token_ids
def load_model(self) -> None:
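The runner wires the tuner in at exactly two points: update_stats() consumes the measured acceptance rate right after rejection sampling, and adjust_draft_len() filters the proposals before they are scheduled. A minimal control-flow sketch of one step, with hypothetical drafter/batch stand-ins for the runner's internals:

    # Hypothetical objects standing in for the runner's internals.
    tuner = AutoTuner()
    for _ in range(num_steps):
        draft_token_ids = drafter.propose(batch)  # list[list[int]]
        draft_token_ids = tuner.adjust_draft_len(requests, draft_token_ids)
        output_token_ids, acceptance_rate = rejection_sampler(
            spec_decode_metadata, None, target_logits,
            bonus_token_ids, sampling_metadata)
        tuner.update_stats(float(acceptance_rate))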