mirror of https://github.com/vllm-project/vllm.git
synced 2025-10-20 23:03:52 +08:00

Compare commits — 5 commits:
- bc3b20f81f
- 54be44ee74
- 2815bd6143
- 17bccecb1c
- c335930d75
@@ -189,6 +189,9 @@ class BenchmarkDataset(ABC):
        """
        if len(requests) < num_requests:
            random.seed(self.random_seed)
            logger.info("Current number of requests: %d", len(requests))
            logger.info("Oversampled requests to reach %d total samples.",
                        num_requests)
            additional = random.choices(requests,
                                        k=num_requests - len(requests))
            requests.extend(additional)
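For context, the oversampling step above simply repeats already-collected samples (drawn with replacement) until the target count is reached. A minimal standalone sketch, not the benchmark code itself:

import random

def maybe_oversample(requests: list, num_requests: int, seed: int = 0) -> None:
    # If fewer samples were collected than requested, repeat existing ones
    # (sampled with replacement) until the target count is reached.
    if len(requests) < num_requests:
        random.seed(seed)
        additional = random.choices(requests, k=num_requests - len(requests))
        requests.extend(additional)

# Example: 3 collected prompts oversampled to 5.
reqs = ["a", "b", "c"]
maybe_oversample(reqs, 5)
assert len(reqs) == 5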
@@ -402,6 +405,13 @@ class ShareGPTDataset(BenchmarkDataset):
                entry["conversations"][1]["value"],
            )

            prompt = tokenizer.apply_chat_template([{
                "role": "user",
                "content": prompt
            }],
                                                   add_generation_prompt=True,
                                                   tokenize=False)

            lora_request, tokenizer = self.get_random_lora_request(
                tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path)
            prompt_ids = tokenizer(prompt).input_ids
@@ -760,6 +770,14 @@ class InstructCoderDataset(HuggingFaceDataset):
            if len(sampled_requests) >= num_requests:
                break
            prompt = f"{item['instruction']}:\n{item['input']}"

            prompt = tokenizer.apply_chat_template([{
                "role": "user",
                "content": prompt
            }],
                                                   add_generation_prompt=True,
                                                   tokenize=False)

            prompt_len = len(tokenizer(prompt).input_ids)
            sampled_requests.append(
                SampleRequest(
@@ -793,11 +811,18 @@ class AIMODataset(HuggingFaceDataset):
        sampled_requests = []
        dynamic_output = output_len is None

-       for item in self.data:
+       for i, item in enumerate(self.data):
            if len(sampled_requests) >= num_requests:
                break
            prompt, completion = item['problem'], item["solution"]

            prompt = tokenizer.apply_chat_template([{
                "role": "user",
                "content": prompt
            }],
                                                   add_generation_prompt=True,
                                                   tokenize=False)

            prompt_ids = tokenizer(prompt).input_ids
            completion_ids = tokenizer(completion).input_ids
            prompt_len = len(prompt_ids)
@@ -895,3 +920,103 @@ class ASRDataset(HuggingFaceDataset):
                    " what Whisper supports.", skipped)
        self.maybe_oversample_requests(sampled_requests, num_requests)
        return sampled_requests


class MTBenchDataset(HuggingFaceDataset):
    """
    MT-Bench Dataset.
    https://huggingface.co/datasets/philschmid/mt-bench

    We create a single turn dataset for MT-Bench.
    This is similar to Spec decoding benchmark setup in vLLM
    https://github.com/vllm-project/vllm/blob/9d98ab5ec/examples/offline_inference/eagle.py#L14-L18
    """  # noqa: E501

    DEFAULT_OUTPUT_LEN = 256  # avg len used in SD bench in vLLM
    SUPPORTED_DATASET_PATHS = {
        "philschmid/mt-bench",
    }

    def sample(self,
               tokenizer: PreTrainedTokenizerBase,
               num_requests: int,
               output_len: Optional[int] = None,
               enable_multimodal_chat: bool = False,
               **kwargs) -> list:
        output_len = (output_len
                      if output_len is not None else self.DEFAULT_OUTPUT_LEN)
        sampled_requests = []

        for item in self.data:
            if len(sampled_requests) >= num_requests:
                break
            prompt = item['turns'][0]

            # apply template
            prompt = tokenizer.apply_chat_template([{
                "role": "user",
                "content": prompt
            }],
                                                   add_generation_prompt=True,
                                                   tokenize=False)

            prompt_len = len(tokenizer(prompt).input_ids)
            sampled_requests.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=prompt_len,
                    expected_output_len=output_len,
                ))
        self.maybe_oversample_requests(sampled_requests, num_requests)
        return sampled_requests
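Each of the new sample() implementations wraps the raw prompt in the model's chat template before measuring its length. A small illustrative sketch of that call (the model name is only an example and may require access):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
# Returns the fully formatted prompt string (chat headers plus the
# assistant generation header), ready to be tokenized and benchmarked.
text = tok.apply_chat_template([{"role": "user", "content": "Hello!"}],
                               add_generation_prompt=True,
                               tokenize=False)
prompt_len = len(tok(text).input_ids)
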
class CNNDailyMailDataset(HuggingFaceDataset):
    """
    CNN/DailyMail Dataset.
    https://huggingface.co/datasets/abisee/cnn_dailymail

    We create a single turn summarization dataset from the article field,
    similar to the Spec decoding benchmark setup in vLLM
    https://github.com/vllm-project/vllm/blob/9d98ab5ec/examples/offline_inference/eagle.py#L14-L18
    """  # noqa: E501

    DEFAULT_OUTPUT_LEN = 256  # avg len used in SD bench in vLLM
    SUPPORTED_DATASET_PATHS = {
        "abisee/cnn_dailymail",
    }

    def sample(self,
               tokenizer: PreTrainedTokenizerBase,
               num_requests: int,
               output_len: Optional[int] = None,
               enable_multimodal_chat: bool = False,
               **kwargs) -> list:
        output_len = (output_len
                      if output_len is not None else self.DEFAULT_OUTPUT_LEN)
        sampled_requests = []

        for item in self.data:
            if len(sampled_requests) >= num_requests:
                break
            instruction = "Could you summarize the following article, " \
                          "please reuse text from the article if possible: "
            prompt = instruction + item['article']

            # apply template
            prompt = tokenizer.apply_chat_template([{
                "role": "user",
                "content": prompt
            }],
                                                   add_generation_prompt=True,
                                                   tokenize=False)

            prompt_len = len(tokenizer(prompt).input_ids)
            sampled_requests.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=prompt_len,
                    expected_output_len=output_len,
                ))
        self.maybe_oversample_requests(sampled_requests, num_requests)
        return sampled_requests
@@ -12,7 +12,8 @@ from typing import Any, Optional, Union
import torch
import uvloop
from benchmark_dataset import (AIMODataset, BurstGPTDataset,
-                              ConversationDataset, InstructCoderDataset,
+                              CNNDailyMailDataset, ConversationDataset,
+                              InstructCoderDataset, MTBenchDataset,
                               RandomDataset, SampleRequest, ShareGPTDataset,
                               SonnetDataset, VisionArenaDataset)
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
@@ -57,9 +58,9 @@ def run_vllm(
        sampling_params.append(
            SamplingParams(
                n=n,
-               temperature=1.0,
+               temperature=0,
                top_p=1.0,
-               ignore_eos=True,
+               ignore_eos=False,
                max_tokens=request.expected_output_len,
                detokenize=not disable_detokenize,
            ))
@@ -123,9 +124,9 @@ def run_vllm_chat(
        sampling_params.append(
            SamplingParams(
                n=n,
-               temperature=1.0,
+               temperature=0,
                top_p=1.0,
-               ignore_eos=True,
+               ignore_eos=False,
                max_tokens=request.expected_output_len,
                detokenize=not disable_detokenize,
            ))
@@ -167,9 +168,9 @@ async def run_vllm_async(
        sampling_params.append(
            SamplingParams(
                n=n,
-               temperature=1.0,
+               temperature=0,
                top_p=1.0,
-               ignore_eos=True,
+               ignore_eos=False,
                max_tokens=request.expected_output_len,
                detokenize=not disable_detokenize,
            ))
@@ -339,6 +340,14 @@ def get_requests(args, tokenizer):
            dataset_cls = AIMODataset
            common_kwargs['dataset_subset'] = None
            common_kwargs['dataset_split'] = "train"
        elif args.dataset_path in MTBenchDataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = MTBenchDataset
            common_kwargs['dataset_subset'] = None
            common_kwargs['dataset_split'] = "train"
        elif args.dataset_path in CNNDailyMailDataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = CNNDailyMailDataset
            common_kwargs['dataset_subset'] = '3.0.0'
            common_kwargs['dataset_split'] = "train"
    else:
        raise ValueError(f"Unknown dataset name: {args.dataset_name}")
    # Remove None values
@@ -477,8 +486,11 @@ def validate_args(args):
            VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys()
            | ConversationDataset.SUPPORTED_DATASET_PATHS):
        assert args.backend == "vllm-chat", f"{args.dataset_path} needs to use vllm-chat as the backend."  #noqa: E501
-    elif args.dataset_path in (InstructCoderDataset.SUPPORTED_DATASET_PATHS
-                               | AIMODataset.SUPPORTED_DATASET_PATHS):
+    elif args.dataset_path in (
+            InstructCoderDataset.SUPPORTED_DATASET_PATHS
+            | AIMODataset.SUPPORTED_DATASET_PATHS
+            | MTBenchDataset.SUPPORTED_DATASET_PATHS
+            | CNNDailyMailDataset.SUPPORTED_DATASET_PATHS):
        assert args.backend == "vllm", f"{args.dataset_path} needs to use vllm as the backend."  #noqa: E501
    else:
        raise ValueError(
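The SamplingParams change above switches the throughput runs to greedy decoding and honours EOS, which keeps draft-token acceptance deterministic and output lengths realistic. A hedged sketch of constructing equivalent parameters outside the benchmark (values mirror the diff; max_tokens is only an example):

from vllm import SamplingParams

# Greedy decoding (temperature=0) and honouring EOS keeps acceptance
# statistics comparable across speculative-decoding methods.
params = SamplingParams(
    n=1,
    temperature=0,      # greedy instead of temperature=1.0 sampling
    top_p=1.0,
    ignore_eos=False,   # stop at EOS instead of forcing max_tokens
    max_tokens=512,
)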
benchmarks/run.sh (new file, 213 lines)
@@ -0,0 +1,213 @@
# python benchmarks/benchmark_throughput.py \
# --model meta-llama/Meta-Llama-3-8B-Instruct \
# --dataset-name sonnet \
# --dataset-path /data/lily/batch-sd/benchmarks/sonnet.txt \
# --prefix-len 0 \
# --output-len 512 \
# --num-prompts 200 \
# --speculative_config '{"method": "ngram", "num_speculative_tokens": 20, "prompt_lookup_min": 2, "prompt_lookup_max": 5}'

# python benchmarks/benchmark_throughput.py \
# --model meta-llama/Meta-Llama-3-8B-Instruct \
# --dataset-name sharegpt \
# --dataset-path /data/lily/ShareGPT_V3_unfiltered_cleaned_split.json \
# --prefix-len 0 \
# --output-len 512 \
# --num-prompts 200 \
# --speculative_config '{"method": "ngram", "num_speculative_tokens": 20, "prompt_lookup_min": 2, "prompt_lookup_max": 5}'

# python benchmarks/benchmark_throughput.py \
# --model meta-llama/Meta-Llama-3-8B-Instruct \
# --dataset-name hf \
# --dataset-path likaixin/InstructCoder \
# --prefix-len 0 \
# --output-len 512 \
# --num-prompts 200 \
# --speculative_config '{"method": "ngram", "num_speculative_tokens": 20, "prompt_lookup_min": 2, "prompt_lookup_max": 5}'

# python benchmarks/benchmark_throughput.py \
# --model meta-llama/Meta-Llama-3-8B-Instruct \
# --dataset-name sonnet \
# --dataset-path /data/lily/batch-sd/benchmarks/sonnet.txt \
# --prefix-len 0 \
# --output-len 512 \
# --num-prompts 200 \
# --speculative_config '{"method": "eagle", "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B", "num_speculative_tokens": 20}'

# python benchmarks/benchmark_throughput.py \
# --model meta-llama/Meta-Llama-3-8B-Instruct \
# --dataset-name sharegpt \
# --dataset-path /data/lily/ShareGPT_V3_unfiltered_cleaned_split.json \
# --prefix-len 0 \
# --output-len 512 \
# --num-prompts 200 \
# --speculative_config '{"method": "eagle", "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B", "num_speculative_tokens": 20}'

# python benchmarks/benchmark_throughput.py \
# --model meta-llama/Meta-Llama-3-8B-Instruct \
# --dataset-name hf \
# --dataset-path likaixin/InstructCoder \
# --prefix-len 0 \
# --output-len 512 \
# --num-prompts 200 \
# --speculative_config '{"method": "eagle", "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B", "num_speculative_tokens": 20}'

# python benchmarks/benchmark_throughput.py \
# --model meta-llama/Meta-Llama-3.1-8B-Instruct \
# --dataset-name hf \
# --dataset-path likaixin/InstructCoder \
# --prefix-len 0 \
# --output-len 512 \
# --num-prompts 200 \
# --speculative_config '{"method": "eagle3", "model": "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", "num_speculative_tokens": 20}'

python benchmarks/benchmark_throughput.py \
--model meta-llama/Meta-Llama-3.1-8B-Instruct \
--dataset-name hf \
--dataset-path philschmid/mt-bench \
--prefix-len 0 \
--output-len 512 \
--num-prompts 200 \
--speculative_config '{"method": "eagle3", "model": "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", "num_speculative_tokens": 20}'

# python benchmarks/benchmark_throughput.py \
# --model meta-llama/Meta-Llama-3.1-8B-Instruct \
# --dataset-name sharegpt \
# --dataset-path /data/lily/ShareGPT_V3_unfiltered_cleaned_split.json \
# --prefix-len 0 \
# --output-len 512 \
# --num-prompts 200 \
# --speculative_config '{"method": "eagle3", "model": "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", "num_speculative_tokens": 20}'

# python benchmarks/benchmark_throughput.py \
# --model meta-llama/Meta-Llama-3.1-8B-Instruct \
# --dataset-name sonnet \
# --dataset-path /data/lily/batch-sd/benchmarks/sonnet.txt \
# --prefix-len 0 \
# --output-len 512 \
# --num-prompts 200 \
# --speculative_config '{"method": "eagle3", "model": "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", "num_speculative_tokens": 20}'

# python benchmarks/benchmark_throughput.py \
# --model meta-llama/Meta-Llama-3.1-8B-Instruct \
# --dataset-name hf \
# --dataset-path likaixin/InstructCoder \
# --prefix-len 0 \
# --output-len 512 \
# --num-prompts 200 \
# --speculative_config '{"method": "eagle", "model": "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", "num_speculative_tokens": 20}'

# python benchmarks/benchmark_throughput.py \
# --model meta-llama/Meta-Llama-3.1-8B-Instruct \
# --dataset-name sharegpt \
# --dataset-path /data/lily/ShareGPT_V3_unfiltered_cleaned_split.json \
# --prefix-len 0 \
# --output-len 512 \
# --num-prompts 200 \
# --speculative_config '{"method": "eagle", "model": "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", "num_speculative_tokens": 20}'

# python benchmarks/benchmark_throughput.py \
# --model meta-llama/Meta-Llama-3.1-8B-Instruct \
# --dataset-name hf \
# --dataset-path likaixin/InstructCoder \
# --prefix-len 0 \
# --output-len 512 \
# --num-prompts 200 \
# --speculative_config '{"method": "ngram", "num_speculative_tokens": 20, "prompt_lookup_min": 2, "prompt_lookup_max": 5}'

# python benchmarks/benchmark_throughput.py \
# --model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
# --dataset-name hf \
# --dataset-path AI-MO/aimo-validation-aime \
# --prefix-len 0 \
# --output-len 1024 \
# --num-prompts 90 \
# --speculative_config '{"method": "eagle3", "num_speculative_tokens": 20, "model": "yuhuili/EAGLE3-DeepSeek-R1-Distill-LLaMA-8B"}'

# python benchmarks/benchmark_throughput.py \
# --model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
# --dataset-name hf \
# --dataset-path AI-MO/aimo-validation-aime \
# --prefix-len 0 \
# --output-len 1024 \
# --num-prompts 90 \
# --speculative_config '{"method": "ngram", "num_speculative_tokens": 20, "prompt_lookup_min": 2, "prompt_lookup_max": 5}'

# python benchmarks/benchmark_throughput.py \
# --model meta-llama/Meta-Llama-3.1-8B-Instruct \
# --dataset-name sharegpt \
# --dataset-path /data/lily/ShareGPT_V3_unfiltered_cleaned_split.json \
# --prefix-len 0 \
# --output-len 512 \
# --num-prompts 200 \
# --speculative_config '{"method": "ngram", "num_speculative_tokens": 20, "prompt_lookup_min": 2, "prompt_lookup_max": 5}'

# python benchmarks/benchmark_throughput.py \
# --model meta-llama/Meta-Llama-3.1-8B-Instruct \
# --dataset-name hf \
# --dataset-path philschmid/mt-bench \
# --prefix-len 0 \
# --output-len 512 \
# --num-prompts 200 \
# --speculative_config '{"method": "ngram", "num_speculative_tokens": 20, "prompt_lookup_min": 2, "prompt_lookup_max": 5}'

# python benchmarks/benchmark_throughput.py \
# --model meta-llama/Meta-Llama-3.1-8B-Instruct \
# --dataset-name hf \
# --dataset-path philschmid/mt-bench \
# --prefix-len 0 \
# --output-len 512 \
# --num-prompts 200 \
# --speculative_config '{"method": "eagle", "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B", "num_speculative_tokens": 20}'

# python benchmarks/benchmark_throughput.py \
# --model meta-llama/Meta-Llama-3.1-8B-Instruct \
# --dataset-name hf \
# --dataset-path abisee/cnn_dailymail \
# --prefix-len 0 \
# --output-len 512 \
# --num-prompts 200 \
# --speculative_config '{"method": "eagle", "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B", "num_speculative_tokens": 20}'

# python benchmarks/benchmark_throughput.py \
# --model meta-llama/Meta-Llama-3.1-8B-Instruct \
# --dataset-name hf \
# --dataset-path abisee/cnn_dailymail \
# --prefix-len 0 \
# --output-len 512 \
# --num-prompts 200 \
# --speculative_config '{"method": "ngram", "num_speculative_tokens": 20, "prompt_lookup_min": 2, "prompt_lookup_max": 5}'

# python benchmarks/benchmark_throughput.py \
# --model meta-llama/Meta-Llama-3.1-8B-Instruct \
# --dataset-name hf \
# --dataset-path philschmid/mt-bench \
# --prefix-len 0 \
# --output-len 512 \
# --num-prompts 10 \
# --speculative_config '{"method": "eagle3", "model": "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", "num_speculative_tokens": 20}'

# python benchmarks/benchmark_throughput.py \
# --model meta-llama/Meta-Llama-3.1-8B-Instruct \
# --dataset-name hf \
# --dataset-path abisee/cnn_dailymail \
# --prefix-len 0 \
# --output-len 512 \
# --num-prompts 200 \
# --speculative_config '{"method": "ngram", "num_speculative_tokens": 20, "prompt_lookup_min": 2, "prompt_lookup_max": 5}'
benchmarks/visualize/common.py (new file, 63 lines)
@@ -0,0 +1,63 @@
import json
from dataclasses import dataclass

MODEL_TO_NAMES = {
    "r1-distill-llama-8B": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
    "llama3-8B": "meta-llama/Meta-Llama-3-8B-Instruct",
    "llama3.1-8B": "meta-llama/Llama-3.1-8B-Instruct",
    "llama3.1-70B": "meta-llama/Llama-3.1-70B-Instruct",
}


@dataclass
class AccStats:
    lens: list[int]
    probs: list[float] = None
    entropies: list[float] = None

    def __post_init__(self):
        if self.probs is not None:
            assert len(self.lens) == len(self.probs), "Length of lens and probs must match"
        if self.entropies is not None:
            assert len(self.lens) == len(self.entropies), "Length of lens and entropies must match"

        # remove the prefill accepted lens
        self.lens = self.lens[1:]

        # remove the last proposed tokens
        if self.probs:
            self.probs = self.probs[:-1]
        if self.entropies:
            self.entropies = self.entropies[:-1]

    @property
    def length(self):
        return len(self.lens)


# def cleanup(acc_stats: AccStats) ->
#     # Remove the prefill phase
#     data = data[1:]
#     # Cap the maximum value to 10
#     data = [min(x, 10) for x in data]
#     return data


def load_data(datapath, tokenizer, verbose=False):
    acceptance_stats = []
    with open(datapath, "r") as f:
        lines = f.readlines()
        for line in lines:
            data = json.loads(line)
            stat = AccStats(
                lens=data['acc']['acc_len'],
                probs=data['acc'].get('acc_prob', None),
                entropies=data['acc'].get('acc_entropy', None)
            )
            acceptance_stats.append(stat)
            if verbose:
                print("Input:", tokenizer.decode(data['prompt_token_ids']))
                print("Output:", tokenizer.decode(data['generated_token_ids']))
                print("=============================================")

    max_length = max(stats.length for stats in acceptance_stats)

    print(f"Load {len(acceptance_stats)} with max length {max_length}")
    return acceptance_stats
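A hedged usage sketch of load_data above: it computes the mean accepted length per request from an acceptance_stats.jsonl file written by the scheduler change further below. The path and model key are examples only.

from transformers import AutoTokenizer
from common import MODEL_TO_NAMES, load_data

# Example path/key; point this at wherever the scheduler wrote its JSONL.
tokenizer = AutoTokenizer.from_pretrained(MODEL_TO_NAMES["llama3.1-8B"])
stats = load_data("acceptance_stats.jsonl", tokenizer)

# Mean number of accepted tokens per decoding step, per request.
mean_acc = [sum(s.lens) / max(len(s.lens), 1) for s in stats]
print(f"overall mean accepted length: {sum(mean_acc) / len(mean_acc):.2f}")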
benchmarks/visualize/vis_acc.py (new file, 108 lines)
@@ -0,0 +1,108 @@
import json
import seaborn as sns
import matplotlib.pyplot as plt
from transformers import AutoTokenizer
from .common import MODEL_TO_NAMES, load_data
import requests
import os
from pathlib import Path


class AcceptanceStatsClient:
    """Client for fetching and processing acceptance statistics data."""

    def __init__(self, model_name, method, dataset, data_path=None):
        """Initialize the client with model and dataset info."""
        self.model_name = model_name
        self.method = method
        self.dataset = dataset

        if data_path is None:
            self.data_path = f"/data/lily/batch-sd/data/{model_name}/{method}_{dataset}_acceptance_stats.jsonl"
        else:
            self.data_path = data_path

        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_TO_NAMES[model_name], use_fast=False)
        self.acceptance_stats = None

    def load_data(self):
        """Load the acceptance statistics from file."""
        self.acceptance_stats = load_data(self.data_path, self.tokenizer)
        return self.acceptance_stats

    def plot_heatmap(self, output_dir="figures"):
        """Plot the acceptance statistics as a heatmap."""
        if self.acceptance_stats is None:
            self.load_data()

        fig, ax = plt.subplots(figsize=(12, 8))
        sns.heatmap(self.acceptance_stats, cmap="YlGnBu")
        plt.xlabel("Position")
        plt.ylabel("Request ID")

        # Add Y-axis labels on the right
        ax2 = ax.twinx()
        ax2.set_ylim(ax.get_ylim())
        ax2.set_yticks([])
        ax2.set_ylabel("# of Accepted Tokens", labelpad=10)

        plt.title(f"Acceptance Statistics: {self.model_name} - {self.method} - {self.dataset}")
        plt.tight_layout()

        # Create output directory if it doesn't exist
        output_path = Path(output_dir) / self.model_name
        os.makedirs(output_path, exist_ok=True)

        output_file = output_path / f"{self.method}_{self.dataset}_acceptance_stats.pdf"
        plt.savefig(output_file)
        print(f"Saved heatmap to {output_file}")
        return fig

    def get_summary_stats(self):
        """Get summary statistics about the acceptance data."""
        if self.acceptance_stats is None:
            self.load_data()

        # Calculate average acceptance rate for each position
        avg_by_position = [sum(col)/len(col) for col in zip(*self.acceptance_stats) if sum(1 for v in col if v >= 0) > 0]

        # Calculate average acceptance rate for each request
        avg_by_request = [sum(row)/len(row) for row in self.acceptance_stats]

        return {
            "total_requests": len(self.acceptance_stats),
            "max_position": len(avg_by_position),
            "avg_acceptance_rate": sum(avg_by_request)/len(avg_by_request),
            "avg_by_position": avg_by_position,
            "avg_by_request": avg_by_request
        }


# Example model configuration
model = "llama3.1-8B"
# model = "r1-distill-llama-8B"
method = "eagle3"
dataset = "mtbench"
# dataset = "aime"
# method = "ngram"
# dataset = "cnndailymail"
# datapath = f"/data/lily/batch-sd/data/{model}/{method}_{dataset}_acceptance_stats.jsonl"
datapath = "acceptance_stats.jsonl"
tokenizer = AutoTokenizer.from_pretrained(MODEL_TO_NAMES[model], use_fast=False)


if __name__ == "__main__":
    # Use the client instead of directly loading data
    client = AcceptanceStatsClient(model, method, dataset, datapath)
    acceptance_stats = client.load_data()

    # Get summary statistics
    summary = client.get_summary_stats()
    print("Summary Statistics:")
    print(f"Total Requests: {summary['total_requests']}")
    print(f"Max Position: {summary['max_position']}")
    print(f"Average Acceptance Rate: {summary['avg_acceptance_rate']:.2f}")

    # Create heatmap visualization
    plot_heatmap = False
    if plot_heatmap:
        client.plot_heatmap()
benchmarks/visualize/vis_acc_diff.py (new file, 69 lines)
@@ -0,0 +1,69 @@
import json
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

model = "llama3.1-8B"
dataset = "instructcode"
method1 = "ngram"
method2 = "eagle3"


def get_datapath(method):
    datapath = f"/data/lily/batch-sd/data/{model}/{method}_{dataset}_acceptance_stats.jsonl"
    return datapath


def cleanup(data):
    # Remove the prefill phase
    data = data[1:]
    # Cap the maximum value to 10
    data = [min(x, 10) for x in data]
    return data


def load_data(datapath):
    acceptance_stats = {}
    with open(datapath, "r") as f:
        lines = f.readlines()
        for line in lines:
            data = json.loads(line)
            key = hash(tuple(data['prompt_token_ids']))
            acceptance_stats[key] = cleanup(data['acc'])
    # Pad the acceptance stats to the same length
    max_length = max(len(stats) for k, stats in acceptance_stats.items())

    for key in acceptance_stats:
        acceptance_stats[key] += [-2] * (max_length - len(acceptance_stats[key]))

    print(f"Load {len(acceptance_stats)} with max length {max_length} from {datapath}")
    return acceptance_stats


def diff(acceptance_stats1, acceptance_stats2):
    diff = {}
    for key in acceptance_stats1:
        if key in acceptance_stats2:
            diff[key] = [a - b for a, b in zip(acceptance_stats1[key], acceptance_stats2[key])]
    return diff


datapath_1 = get_datapath(method1)
datapath_2 = get_datapath(method2)
acceptance_stats_1 = load_data(datapath_1)
acceptance_stats_2 = load_data(datapath_2)
acceptance_stats_diff = diff(acceptance_stats_1, acceptance_stats_2)

acceptance_stats = list(acceptance_stats_diff.values())


fig, ax = plt.subplots()
colors = ["red", "white", "blue"]
custom_cmap = LinearSegmentedColormap.from_list("custom", colors, N=256)
sns.heatmap(acceptance_stats, cmap=custom_cmap, center=0)
plt.xlabel("Position")
plt.ylabel("Request ID")
# Add Y-axis labels on the right
ax2 = ax.twinx()
ax2.set_ylim(ax.get_ylim())  # Match y-axis range
ax2.set_yticks([])  # Remove right tick marks if undesired
ax2.set_ylabel("# of Accepted Tokens", labelpad=10)  # Set right y-axis label
plt.title(f"Diff between {method2} - {method1} acceptance stats for {dataset}")

plt.tight_layout()
plt.savefig(f"figures/{model}/diff_{method2}_{method1}_{dataset}_acceptance_stats.pdf")
benchmarks/visualize/vis_prob_entropy.py (new file, 38 lines)
@@ -0,0 +1,38 @@
from transformers import AutoTokenizer
from common import MODEL_TO_NAMES, load_data
import matplotlib.pyplot as plt


def plot_prob_entropy(acceptance_stats,
                      output_path):

    acc_probs = []
    rej_probs = []
    for stat in acceptance_stats:
        for i, acc_len in enumerate(stat.lens):
            acc_probs.extend(stat.probs[i][:acc_len-1])
            rej_probs.extend(stat.probs[i][acc_len-1:])

    fig, ax = plt.subplots(figsize=(12, 8))
    plt.hist(acc_probs, bins=100, alpha=0.5,
             label='Accepted Probabilities', color='green')
    plt.hist(rej_probs, bins=100, alpha=0.5,
             label='Rejected Probabilities', color='red')
    plt.xlabel('Probability')
    plt.ylabel('Frequency')
    plt.title('Distribution of Accepted and Rejected Probabilities')
    plt.legend()
    plt.tight_layout()
    plt.savefig(output_path)


if __name__ == "__main__":
    datapath = "/data/lily/sd-benchmark-paper/batch-sd/acceptance_stats.jsonl"
    model = "llama3.1-8B"
    tokenizer = AutoTokenizer.from_pretrained(MODEL_TO_NAMES[model],
                                              use_fast=False)
    acceptance_stats = load_data(datapath, tokenizer)
    plot_prob_entropy(acceptance_stats, output_path="prob_entropy_figures")
@@ -28,6 +28,7 @@ from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.spec_decode.metrics import SpecDecodingStats
from vllm.v1.structured_output import StructuredOutputManager
import json

logger = init_logger(__name__)
@@ -632,6 +633,7 @@ class Scheduler(SchedulerInterface):
        logprobs = model_runner_output.logprobs
        prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
        num_scheduled_tokens = scheduler_output.num_scheduled_tokens
        self.acceptance_stats = model_runner_output.acceptance_stats

        new_running: list[Request] = []
        outputs: list[EngineCoreOutput] = []
@@ -789,6 +791,18 @@ class Scheduler(SchedulerInterface):
            self._free_request(request)

    def _free_request(self, request: Request) -> None:
        req_id = request.request_id
        data = self.acceptance_stats.pop(req_id)
        with open('acceptance_stats.jsonl', 'a') as f:
            f.write(json.dumps({
                "id": req_id,
                "acc": data,
                "prompt_token_ids": request.prompt_token_ids,
                "generated_token_ids": request.output_token_ids._x
            }))
            f.write('\n')

        assert request.is_finished()
        self.kv_cache_manager.free(request)
        self.kv_cache_manager.free_block_hashes(request)
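For reference, a small sketch of consuming the JSONL records appended by the patched _free_request above; the field names follow the diff, the values are whatever the run produced:

import json

# Each line holds: request id, the acc_len / acc_prob / acc_entropy lists,
# and the prompt / generated token ids for that finished request.
with open("acceptance_stats.jsonl") as f:
    for line in f:
        rec = json.loads(line)
        acc_len = rec["acc"]["acc_len"]  # accepted tokens per decoding step
        print(rec["id"], sum(acc_len) / max(len(acc_len), 1))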
@@ -100,6 +100,8 @@ class ModelRunnerOutput:
    # [prompt_len]
    prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]]

    acceptance_stats: Optional[dict[str, list]] = None


EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
    req_ids=[],
@@ -1,8 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
import torch
import torch.nn as nn
import torch.nn.functional as F
import triton
import triton.language as tl
from torch.distributions import Categorical

from vllm.config import VllmConfig, set_current_vllm_config
from vllm.forward_context import set_forward_context
@@ -98,12 +100,23 @@ class EagleProposer:
        )
        sample_hidden_states = hidden_states_logits[last_token_indices]
        logits = self.model.compute_logits(sample_hidden_states, None)

        all_draft_probs = []
        all_draft_entropy = []

        probs = F.softmax(logits, dim=-1, dtype=torch.float32)
        draft_token_ids = logits.argmax(dim=-1)
        # Get the probabilities of the draft tokens.
        draft_probs = probs.gather(dim=1, index=draft_token_ids.unsqueeze(1))
        dist = Categorical(logits=logits)
        entropy = dist.entropy().unsqueeze(-1)  # [batch_size, 1]
        all_draft_probs.append(draft_probs)
        all_draft_entropy.append(entropy)

        # Early exit if there is only one draft token to be generated.
        if self.num_speculative_tokens == 1:
            # [batch_size, 1]
-           return draft_token_ids.view(-1, 1)
+           return draft_token_ids.view(-1,
+                                       1), all_draft_probs, all_draft_entropy

        # Generate the remaining draft tokens.
        draft_token_ids_list = [draft_token_ids]
@@ -164,9 +177,17 @@ class EagleProposer:
            draft_token_ids = logits.argmax(dim=-1)
            draft_token_ids_list.append(draft_token_ids)

            probs = F.softmax(logits, dim=-1, dtype=torch.float32)
            draft_probs = probs.gather(dim=1,
                                       index=draft_token_ids.unsqueeze(1))
            dist = Categorical(logits=logits)
            entropy = dist.entropy().unsqueeze(-1)  # [batch_size, 1]
            all_draft_probs.append(draft_probs)
            all_draft_entropy.append(entropy)

        # [batch_size, num_speculative_tokens]
        draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
-       return draft_token_ids
+       return draft_token_ids, all_draft_probs, all_draft_entropy

    @staticmethod
    def prepare_inputs(
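The pattern added in both hunks above (greedy draft token, its probability under the draft distribution, and the distribution entropy) can be shown standalone. A sketch of the arithmetic only, not the proposer itself:

import torch
import torch.nn.functional as F
from torch.distributions import Categorical

logits = torch.randn(4, 32000)            # [batch_size, vocab_size], dummy values
probs = F.softmax(logits, dim=-1, dtype=torch.float32)
draft_token_ids = logits.argmax(dim=-1)   # greedy draft token per request
draft_probs = probs.gather(dim=1, index=draft_token_ids.unsqueeze(1))  # [batch, 1]
entropy = Categorical(logits=logits).entropy().unsqueeze(-1)           # [batch, 1]
# High draft probability / low entropy suggests the draft token is more
# likely to be accepted by the target model.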
@@ -282,6 +282,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                                          pin_memory=self.pin_memory)
        self.seq_lens_np = self.seq_lens_cpu.numpy()

        self.acceptance_stats = {}

    def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
        """Update the cached states and the persistent batch with the scheduler
        output.
@@ -1187,6 +1189,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
            sampled_token_ids,
            self.input_batch.vocab_size,
        )
        for i, token_ids in enumerate(valid_sampled_token_ids):
            req_id = self.input_batch.req_ids[i]
            if req_id not in self.acceptance_stats:
                self.acceptance_stats[req_id] = {
                    'acc_len': [],
                    'acc_prob': [],
                    'acc_entropy': [],
                }
            self.acceptance_stats[req_id]['acc_len'].append(len(token_ids))
        # Force 1 generated token per request.
        for i, token_ids in enumerate(valid_sampled_token_ids):
            valid_sampled_token_ids[i] = token_ids[:1]

        # Mask out the sampled tokens that should not be sampled.
        for i in discard_sampled_tokens_req_indices:
            valid_sampled_token_ids[i].clear()
@@ -1262,7 +1277,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):

        if self.use_aux_hidden_state_outputs:
            target_hidden_states = torch.cat(target_hidden_states, dim=-1)
-       draft_token_ids = self.drafter.propose(
+       draft_token_ids, draft_probs, draft_entropy = self.drafter.propose(
            target_token_ids=target_token_ids,
            target_positions=target_positions,
            target_hidden_states=target_hidden_states,
@@ -1274,6 +1289,22 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        )
        spec_token_ids = draft_token_ids.tolist()

        for req_id in self.input_batch.req_ids:
            if req_id not in self.acceptance_stats:
                self.acceptance_stats[req_id] = {
                    'acc_len': [],
                    'acc_prob': [],
                    'acc_entropy': [],
                }
            req_index = self.input_batch.req_id_to_index[req_id]
            step_probs, step_entropy = [], []
            for prob, entropy in zip(draft_probs, draft_entropy):
                step_probs.append(prob[req_index].item())
                step_entropy.append(entropy[req_index].item())

            self.acceptance_stats[req_id]['acc_prob'].append(step_probs)
            self.acceptance_stats[req_id]['acc_entropy'].append(step_entropy)

        # Clear KVConnector state after all KVs are generated.
        if has_kv_transfer_group():
            get_kv_transfer_group().clear_connector_metadata()
@@ -1285,6 +1316,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
            spec_token_ids=spec_token_ids,
            logprobs=logprobs_lists,
            prompt_logprobs_dict=prompt_logprobs_dict,
            acceptance_stats=self.acceptance_stats,
        )

    def generate_draft_token_ids(