FEAT Text generation benchmark (#2525)

Similar to #2395, this benchmark serves to compare different PEFT
methods on an equal basis. This time, the goal is to measure metrics
related to text generation, most notably speed and memory usage. The
results should be easy to reproduce and compare.

The actual experimental settings and results have yet to be added.
Author: VED
Date: 2025-08-07 13:47:32 +05:30
Committed by: GitHub
Parent: d7194f869a
Commit: ec5a1c67b0
11 changed files with 1331 additions and 0 deletions

View File

@@ -0,0 +1,179 @@
# PEFT Benchmarking Suite
This directory contains a comprehensive benchmarking framework for Parameter-Efficient Fine-Tuning (PEFT) methods. For text generation, the suite measures inference performance, memory usage, and other key metrics across different PEFT configurations.
## Base Model Inference Caching
The benchmarking suite uses a separate script, `run_base.py`, to measure base model inference times and save the results for reuse. Run it once per model configuration to avoid redundant computation and to ensure consistent baseline metrics for all PEFT experiments.
**Usage:**
```bash
python run_base.py
```
This caches the base model inference results for the configured model. Subsequent runs of `run.py` automatically load these cached results.
## Overview
The benchmarking suite provides:
- **Inference time measurement** across different prompt categories
- **Memory usage during inference** (RAM and GPU)
- **Parameter efficiency metrics** (trainable vs total parameters)
- **Time per token analysis** for fair comparison across different generation lengths
- **Structured result logging** with detailed metadata
## Architecture
The suite follows a clean separation between:
1. **Default benchmark configuration** - shared settings for consistent comparison
2. **Individual adapter configurations** - PEFT-specific parameters for each experiment
This ensures that all experiments are comparable while allowing flexibility in adapter parameters.
## Quick Start
### Running a Single Experiment
Base model results must be cached first (see "Base Model Inference Caching" above); `run.py` raises an error if they are missing. Then run an experiment:
```bash
# From the peft_bench directory
python run_base.py              # once per model configuration
python run.py experiments/lora/lora_r8 --verbose
```
## Configuration Structure
The benchmarking suite uses a hierarchical configuration system:
1. **Default benchmark parameters** (`default_benchmark_params.json`) - Base configuration shared by all experiments
2. **Experiment-specific overrides** (`benchmark_params.json` in each experiment) - Optional overrides for specific experiments
3. **Adapter configuration** (`adapter_config.json` in each experiment) - PEFT method parameters
This structure ensures consistent comparison while allowing flexibility where needed.
### Default Configuration (`default_benchmark_params.json`)
Contains shared benchmark settings that apply to all experiments. Here are the key configuration fields:
- `model_id`: The Hugging Face model ID to use as the base model (e.g., "facebook/opt-350m")
- `dtype`: Model precision ("float16", "float32", or "bfloat16")
- `seed`: Random seed for reproducibility
- `max_new_tokens`: Maximum number of tokens to generate during inference
- `num_inference_runs`: Number of inference runs per prompt for statistical reliability
- `use_4bit`: Whether to use 4-bit quantization (bool)
- `use_8bit`: Whether to use 8-bit quantization (bool)
- `category_generation_params`: Optional per-category generation overrides (e.g., a different `max_new_tokens` for short, medium, and long prompts)
Each experiment can override these settings by providing its own `benchmark_params.json` file.
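To make the precedence concrete, here is a minimal sketch of how the configuration is resolved, mirroring `BenchmarkConfig.from_json` and `merge_from_dict` in `utils.py` (it assumes you run it from the benchmark directory; the experiment path is just an example):
```python
import json
import os

from utils import BenchmarkConfig

experiment_path = "experiments/lora/lora_r8"  # example experiment

# 1. Load the shared defaults.
config = BenchmarkConfig.from_json("default_benchmark_params.json")

# 2. Merge optional per-experiment overrides, if present.
override_path = os.path.join(experiment_path, "benchmark_params.json")
if os.path.exists(override_path):
    with open(override_path) as f:
        config.merge_from_dict(json.load(f))

print(config.to_dict())
```
In the actual runner, `validate_experiment_path` in `utils.py` performs this merge and also derives the experiment name from the directory layout.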
### Experiment Structure
Each experiment directory should contain:
1. `adapter_config.json`: PEFT adapter configuration. For details on available parameters and their meanings, refer to the [PEFT documentation](https://huggingface.co/docs/peft/main/en/developer_guides/adapters).
2. (Optional) `benchmark_params.json`: Override specific benchmark parameters for this experiment.
Example directory structure:
```
experiments/
└── lora/
├── lora_r8/ # LoRA rank 8 experiment
│ ├── adapter_config.json # PEFT adapter configuration
│ └── benchmark_params.json # Optional benchmark overrides
└── lora_r16/ # LoRA rank 16 experiment
└── adapter_config.json
```
### Experiment-Specific Overrides Example
If an experiment needs different benchmark settings, create `benchmark_params.json`:
```json
{
"_comment": "Override settings for this specific experiment",
"max_new_tokens": 50,
"num_inference_runs": 15,
"num_prompt_samples": 2
}
```
These parameters override the defaults from `default_benchmark_params.json`. The defaults themselves should generally not be changed, so that results from individual experiments remain comparable.
### Create a New Experiment Adapter Configuration
To create a new experiment, follow these steps:
1. **Create the experiment directory**
```bash
mkdir -p experiments/lora/lora_r8
```
2. **Generate the adapter configuration programmatically**
Use the PEFT library to create and save your adapter config:
```python
from peft import LoraConfig
config = LoraConfig(
lora_alpha=16,
lora_dropout=0.1,
r=8,
target_modules=["q_proj", "v_proj"],
task_type="CAUSAL_LM"
)
config.save_pretrained("experiments/lora/lora_r8")
```
This will create an `adapter_config.json` in your experiment directory. Adjust parameters as needed for your experiment.
3. **(Optional) Add benchmark overrides**
If you need to override default benchmark settings, create a `benchmark_params.json` in the same directory.
4. **Run the benchmark**
```bash
python run.py experiments/lora/lora_r8 --verbose
```
## Prompt Categories
The benchmark automatically runs across all prompt categories for consistent comparison:
- **short** - Brief prompts (1-2 sentences)
- **medium** - Moderate length prompts (paragraph-level)
- **long** - Extended prompts (multiple paragraphs)
Results are tracked separately for each category, allowing analysis of how different PEFT methods perform across varying input lengths.
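The prompts themselves live in `configs/prompts.json` and can be inspected (or edited) directly; a minimal sketch, assuming it is run from the benchmark directory:
```python
import json

# Inspect the bundled prompt categories (configs/prompts.json).
with open("configs/prompts.json") as f:
    prompts = json.load(f)

for category, items in prompts.items():
    print(f"{category}: {len(items)} prompts, e.g. {items[0][:60]!r}...")
```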
## Results Structure
Results are saved in a structured JSON format with three main sections:
### `run_info`
- Execution metadata (timestamp, duration, status)
- Hardware information (GPU type, CUDA version, etc.)
- Error information (if applicable)
- PEFT and benchmark configurations
### `generation_info`
- Memory usage logs at different stages
- Per-category metrics (inference time, time per token, etc.)
- Overall aggregated metrics
- Individual sample results for detailed analysis
### `meta_info`
- Model information (ID, PEFT method)
- Parameter counts (adapter, total, ratio)
- Model size information (base model, adapter)
- System and package information
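Result files can be consumed directly as JSON. A small sketch of reading one (the filename is hypothetical; successful runs on the `main` branch are written to `results/`, others to `temporary_results/` or `cancelled_results/`):
```python
import json

# Hypothetical result file; the actual name is derived from the experiment path.
with open("results/lora--lora_r8.json") as f:
    result = json.load(f)

print("status:", result["run_info"]["status"])
print("peak GPU memory (MB):", result["generation_info"]["memory"]["peak_gpu_memory_mb"])
for category, data in result["generation_info"]["by_category"].items():
    metrics = data["metrics"]
    print(f"{category}: {metrics['inference_time']:.4f}s "
          f"({metrics['inference_overhead_pct']:.1f}% over base)")
```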
## Key Metrics
### Inference Performance
- **Inference Time**: Total time for generation per category
- **Time Per Token**: Normalized time accounting for different generation lengths
- **Inference Overhead**: Percentage increase compared to base model
### Memory Usage
- **Peak GPU Memory**: Maximum GPU memory during benchmark
- **Peak RAM Memory**: Maximum RAM usage
- **Memory Logs**: Detailed tracking at each stage
### Parameter Efficiency
- **Adapter Parameters**: Number of parameters in the PEFT adapter
- **Parameter Ratio**: Percentage of total model parameters that are in the adapter
- **Adapter Size**: Memory footprint of the adapter in MB
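For clarity, the two derived metrics above are computed as follows (mirroring the calculations in `run.py`):
```python
def inference_overhead_pct(peft_time: float, base_time: float) -> float:
    """Percentage increase of PEFT inference time over the cached base model time."""
    return (peft_time - base_time) / base_time * 100


def param_ratio(trainable_params: int, total_params: int) -> float:
    """Fraction of all model parameters that belong to the adapter."""
    return trainable_params / total_params if total_params > 0 else 0.0
```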

View File

@@ -0,0 +1,23 @@
{
"short": [
"Explain quantum computing in one paragraph.",
"Write a haiku about machine learning.",
"What's the difference between supervised and unsupervised learning?",
"Define parameter-efficient fine-tuning in one sentence.",
"List three applications of natural language processing."
],
"medium": [
"Explain the concept of low-rank adaptation (LoRA) for large language models. Include its benefits and limitations.",
"Compare and contrast prompt tuning and prefix tuning approaches for adapting large language models.",
"What are the key differences between full fine-tuning and parameter-efficient methods? Explain with examples.",
"Describe the process of quantization for neural networks and how it affects model size and inference speed.",
"Explain how sparse expert models like Mixture of Experts work and their advantages over dense models."
],
"long": [
"Analyze the evolution of parameter-efficient fine-tuning methods from 2020 to present. Include a detailed comparison of at least five different approaches, their theoretical foundations, and practical implications for deploying large language models.",
"Provide a comprehensive tutorial on implementing LoRA for a transformer-based language model. Include code examples, hyperparameter selection guidance, and best practices for training and deployment.",
"Compare the computational efficiency, parameter count, and performance characteristics of different PEFT methods (LoRA, Prefix Tuning, Prompt Tuning, IA3, AdaLoRA) across various downstream tasks. Include a discussion of when each method is most appropriate.",
"Explain the mathematical foundations of various parameter-efficient fine-tuning techniques. Discuss how each technique modifies the original neural network architecture and the optimization challenges involved.",
"Discuss the ethical implications of parameter-efficient fine-tuning methods in democratizing access to large language models. Include considerations about computational resources, environmental impact, and accessibility for researchers in resource-constrained settings."
]
}

View File

@@ -0,0 +1,119 @@
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Data handling utilities for PEFT benchmarking.
"""
import json
import os
from typing import Optional
from transformers import PreTrainedTokenizer
from utils import BenchmarkConfig
DEFAULT_PROMPTS_PATH = os.path.join(os.path.dirname(__file__), "configs", "prompts.json")
def load_test_prompts(config: BenchmarkConfig) -> dict[str, list[str]]:
"""
Load prompts from JSON file.
Args:
config: Configuration containing prompts file path
Returns:
dictionary with prompts by category
"""
prompts_file = getattr(config, "prompts_file", DEFAULT_PROMPTS_PATH)
with open(prompts_file) as f:
prompts = json.load(f)
return prompts
def truncate_prompt_for_model(
prompt: str,
tokenizer: PreTrainedTokenizer,
max_length: Optional[int] = None,
reserve_output_tokens: int = 50,
) -> str:
"""
Truncate a prompt to fit within the model's context window.
Args:
prompt: Input prompt
tokenizer: Model tokenizer
max_length: Maximum sequence length (if None, uses model's max_length)
reserve_output_tokens: Number of tokens to reserve for response
Returns:
Truncated prompt
"""
if max_length is None:
if hasattr(tokenizer, "model_max_length"):
max_length = tokenizer.model_max_length
else:
max_length = 2048
max_prompt_length = max_length - reserve_output_tokens
input_ids = tokenizer.encode(prompt, return_tensors="pt")[0]
if len(input_ids) <= max_prompt_length:
return prompt
truncated_ids = input_ids[:max_prompt_length]
truncated_prompt = tokenizer.decode(truncated_ids, skip_special_tokens=True)
return truncated_prompt
def prepare_benchmark_prompts(
config: BenchmarkConfig,
tokenizer: PreTrainedTokenizer,
max_input_length: Optional[int] = None,
seed: int = 42,
) -> dict[str, list[str]]:
"""
Prepare prompts for benchmarking, ensuring appropriate length and variety.
Always returns all prompt categories for consistent benchmarking.
Args:
config: Benchmark configuration
tokenizer: Model tokenizer
max_input_length: Maximum input length (overrides model default if provided)
seed: Random seed (kept for backwards compatibility)
Returns:
Dictionary with processed prompts by category (all categories included)
"""
all_prompts = load_test_prompts(config)
processed_prompts = {}
for category, prompts in all_prompts.items():
truncated_prompts = [
truncate_prompt_for_model(
prompt,
tokenizer,
max_length=max_input_length,
reserve_output_tokens=getattr(config, "reserve_output_tokens", 50),
)
for prompt in prompts
]
processed_prompts[category] = truncated_prompts
return processed_prompts

View File

@@ -0,0 +1,12 @@
{
"model_id": "meta-llama/Llama-3.2-3B",
"dtype": "float16",
"seed": 42,
"num_inference_runs": 10,
"max_new_tokens": 20,
"category_generation_params": {
"short": {"max_new_tokens": 20},
"medium": {"max_new_tokens": 50},
"long": {"max_new_tokens": 100}
}
}

View File

@@ -0,0 +1,17 @@
{
"base_model_name_or_path": null,
"bias": "none",
"fan_in_fan_out": false,
"inference_mode": false,
"init_lora_weights": true,
"lora_alpha": 16,
"lora_dropout": 0.1,
"modules_to_save": null,
"peft_type": "LORA",
"r": 8,
"target_modules": [
"q_proj",
"v_proj"
],
"task_type": "CAUSAL_LM"
}

View File

@@ -0,0 +1,355 @@
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Main entry point to run the experiments. Contains general setup and the proper inference code.
"""
import argparse
import gc
import json
import os
import sys
import time
from typing import Optional
import bitsandbytes
import torch
import transformers
from data import prepare_benchmark_prompts
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, set_seed
from utils import (
BenchmarkConfig,
BenchmarkResult,
BenchmarkStatus,
get_memory_usage,
init_cuda,
log_results,
validate_experiment_path,
)
import peft
from peft import PeftConfig, get_peft_model
def load_base_results(model_id: str) -> Optional[dict]:
"""Load base model results if they exist."""
base_results_dir = os.path.join(os.path.dirname(__file__), "base_results")
model_name = model_id.replace("/", "_").replace("-", "_")
filename = f"base_{model_name}.json"
filepath = os.path.join(base_results_dir, filename)
if os.path.exists(filepath):
with open(filepath) as f:
return json.load(f)
return None
def measure_inference_time(model, tokenizer, prompts, max_new_tokens, num_runs, print_fn, category_generation_params):
"""Measure inference time for each prompt category."""
inference_times = {}
time_per_token = {}
generated_tokens = {}
individual_samples = {}
for category, category_prompts in prompts.items():
print_fn(f"\nMeasuring inference time for {category} prompts...")
category_times = []
category_tokens = []
category_time_per_token = []
category_samples = []
for prompt in category_prompts:
prompt_times = []
prompt_tokens = []
prompt_time_per_token = []
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
cat_max_new_tokens = category_generation_params.get(category, {}).get("max_new_tokens", max_new_tokens)
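            # Setting min_new_tokens == max_new_tokens below forces a fixed generation length,
            # so timings are comparable across runs and prompts within a category.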
for _ in range(num_runs):
start_time = time.perf_counter()
outputs = model.generate(
**inputs,
max_new_tokens=cat_max_new_tokens,
min_new_tokens=cat_max_new_tokens,
pad_token_id=tokenizer.pad_token_id,
)
end_time = time.perf_counter()
# Calculate metrics
inference_time = end_time - start_time
num_tokens = len(outputs[0]) - len(inputs["input_ids"][0])
time_per_token_val = inference_time / num_tokens if num_tokens > 0 else 0
prompt_times.append(inference_time)
prompt_tokens.append(num_tokens)
prompt_time_per_token.append(time_per_token_val)
# Calculate averages for this prompt
avg_time = sum(prompt_times) / len(prompt_times)
avg_tokens = sum(prompt_tokens) / len(prompt_tokens)
avg_time_per_token = sum(prompt_time_per_token) / len(prompt_time_per_token)
sample_result = {
"inference_time": avg_time,
"generated_tokens": avg_tokens,
"time_per_token": avg_time_per_token,
"individual_runs": [
{"inference_time": t, "generated_tokens": tok, "time_per_token": tpt}
for t, tok, tpt in zip(prompt_times, prompt_tokens, prompt_time_per_token)
],
}
category_samples.append(sample_result)
category_times.append(avg_time)
category_tokens.append(avg_tokens)
category_time_per_token.append(avg_time_per_token)
if category_times:
avg_category_time = sum(category_times) / len(category_times)
avg_category_tokens = sum(category_tokens) / len(category_tokens)
avg_category_time_per_token = sum(category_time_per_token) / len(category_time_per_token)
inference_times[category] = avg_category_time
generated_tokens[category] = avg_category_tokens
time_per_token[category] = avg_category_time_per_token
individual_samples[category] = category_samples
return {
"inference_times": inference_times,
"time_per_token": time_per_token,
"generated_tokens": generated_tokens,
"individual_samples": individual_samples,
}
def run_benchmark(
benchmark_config: BenchmarkConfig, experiment_name: str, experiment_path: str, print_fn=print
) -> BenchmarkResult:
"""Run benchmarks for the specified PEFT method configuration."""
result = BenchmarkResult(
experiment_name=experiment_name,
status=BenchmarkStatus.RUNNING,
model_id=benchmark_config.model_id,
)
result.save()
start_time = time.perf_counter()
e_main_benchmark: Optional[Exception] = None
try:
print_fn("Initializing CUDA...")
gpu_allocated_init, gpu_reserved_init = init_cuda()
set_seed(benchmark_config.seed)
print_fn(f"Loading base model: {benchmark_config.model_id}")
tokenizer = AutoTokenizer.from_pretrained(benchmark_config.model_id)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
model_kwargs = {
"device_map": "auto" if torch.cuda.is_available() else None,
}
if benchmark_config.dtype == "float32":
model_kwargs["torch_dtype"] = torch.float32
elif benchmark_config.dtype == "float16":
model_kwargs["torch_dtype"] = torch.float16
elif benchmark_config.dtype == "bfloat16":
model_kwargs["torch_dtype"] = torch.bfloat16
else:
raise ValueError(f"Unsupported dtype: {benchmark_config.dtype}")
if benchmark_config.use_8bit:
model_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True
)
elif benchmark_config.use_4bit:
model_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=model_kwargs.get("torch_dtype", torch.float16),
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
)
base_model = AutoModelForCausalLM.from_pretrained(benchmark_config.model_id, **model_kwargs)
base_results = load_base_results(benchmark_config.model_id)
print_fn("Preparing benchmark prompts...")
prompts = prepare_benchmark_prompts(
config=benchmark_config,
tokenizer=tokenizer,
max_input_length=None,
seed=benchmark_config.seed,
)
if base_results:
print_fn("Using cached base model results...")
base_inference_times = base_results["inference_results"]
else:
raise FileNotFoundError(
"No cached base results found. Please run `python run_base.py` first to generate base model results."
)
try:
print_fn(f"Loading PEFT config from {experiment_path}")
peft_config = PeftConfig.from_pretrained(experiment_path)
print_fn(f"Loaded PEFT config: {peft_config.peft_type}, with parameters: {vars(peft_config)}")
model = get_peft_model(base_model, peft_config)
        except Exception as exc:
            error_msg = f"Error loading PEFT config: {str(exc)}"
            print_fn(error_msg)
            del base_model
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            # Re-raise so the outer handler records the real error instead of a
            # NameError from the never-assigned `model` below.
            raise
ram, gpu_allocated, gpu_reserved = get_memory_usage()
result.add_memory_log("peft_model_loaded", ram, gpu_allocated, gpu_reserved)
# Calculate PEFT model metrics
trainable_params = model.get_nb_trainable_parameters()[0]
total_params = sum(p.numel() for p in model.parameters())
base_params = sum(p.numel() for p in model.base_model.parameters())
dtype_bytes = 2 if benchmark_config.dtype in ["float16", "bfloat16"] else 4
adapter_size_mb = trainable_params * dtype_bytes / (1024 * 1024)
base_model_size_mb = base_params * dtype_bytes / (1024 * 1024)
param_ratio = trainable_params / total_params if total_params > 0 else 0
result.update_meta_info(
param_counts={
"base_params": base_params,
"trainable_params": trainable_params,
"total_params": total_params,
"param_ratio": param_ratio,
},
size_info={"base_model_size_mb": base_model_size_mb, "adapter_size_mb": adapter_size_mb},
package_info={
"transformers-version": transformers.__version__,
"peft-version": peft.__version__,
"bitsandbytes-version": bitsandbytes.__version__ if hasattr(bitsandbytes, "__version__") else None,
},
)
print_fn("Measuring PEFT model inference times...")
peft_inference_times = measure_inference_time(
model,
tokenizer,
prompts,
max_new_tokens=benchmark_config.max_new_tokens,
num_runs=benchmark_config.num_inference_runs,
print_fn=print_fn,
category_generation_params=benchmark_config.category_generation_params,
)
# Calculate inference overhead for each category
inference_overhead = {
k: (peft_inference_times["inference_times"][k] - base_inference_times["inference_times"][k])
/ base_inference_times["inference_times"][k]
* 100
for k in base_inference_times["inference_times"]
}
for category in prompts:
category_metrics = {
"inference_time": peft_inference_times["inference_times"][category],
"base_inference_time": base_inference_times["inference_times"][category],
"inference_overhead_pct": inference_overhead[category],
"time_per_token": peft_inference_times["time_per_token"][category],
"generated_tokens": peft_inference_times["generated_tokens"][category],
}
result.add_metrics_for_category(
category, category_metrics, individual_samples=peft_inference_times["individual_samples"][category]
)
result.update_generation_info(
memory_data={
"peak_gpu_memory_mb": max(
(log["gpu_allocated_mb"] for log in result.generation_info["memory"]["memory_logs"]), default=0
),
"peak_ram_memory_mb": max(
(log["ram_mb"] for log in result.generation_info["memory"]["memory_logs"]), default=0
),
}
)
ram, gpu_allocated, gpu_reserved = get_memory_usage()
result.add_memory_log("benchmark_complete", ram, gpu_allocated, gpu_reserved)
result.status = BenchmarkStatus.SUCCESS
except Exception as exc:
print_fn(f"Benchmark failed with error: {exc}")
result.status = BenchmarkStatus.FAILED
e_main_benchmark = exc
end_time = time.perf_counter()
error_message = str(e_main_benchmark) if e_main_benchmark is not None else None
peft_config_dict = peft_config.to_dict() if "peft_config" in locals() else None
if peft_config_dict:
for key, value in peft_config_dict.items():
if isinstance(value, set):
peft_config_dict[key] = list(value)
result.update_run_info(
duration=end_time - start_time,
status=result.status,
error=error_message,
peft_config=peft_config_dict,
benchmark_config=benchmark_config.to_dict(),
)
return result
def main() -> Optional[int]:
"""Main entry point for the benchmark runner."""
parser = argparse.ArgumentParser(description="Run PEFT method benchmarks")
parser.add_argument("experiment_path", help="Path to experiment directory")
parser.add_argument("--verbose", "-v", action="store_true", help="Enable verbose output")
args = parser.parse_args()
print_fn = print if args.verbose else lambda *args, **kwargs: None
experiment_path = args.experiment_path
allowed_root = os.path.abspath(os.path.join(os.path.dirname(__file__)))
abs_experiment_path = os.path.abspath(experiment_path)
if not abs_experiment_path.startswith(allowed_root):
print(f"Experiment path must be inside {allowed_root}, got: {abs_experiment_path}. Skipping execution.")
return 0
if not os.path.exists(abs_experiment_path):
print(f"Experiment path not found: {abs_experiment_path}. Skipping execution.")
return 0
experiment_path = abs_experiment_path
experiment_name, benchmark_config = validate_experiment_path(experiment_path)
print_fn(f"Running benchmark for experiment: {experiment_name}")
result = run_benchmark(
benchmark_config=benchmark_config,
experiment_name=experiment_name,
experiment_path=experiment_path,
print_fn=print_fn,
)
log_results(experiment_name, result, print_fn=print)
if __name__ == "__main__":
sys.exit(main())

View File

@@ -0,0 +1,184 @@
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
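"""
Run the base model benchmark once and cache its results for reuse by `run.py`.
"""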
import argparse
import json
import os
import sys
import time
import torch
from data import prepare_benchmark_prompts
from run import measure_inference_time
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, set_seed
from utils import (
BenchmarkConfig,
get_memory_usage,
init_cuda,
)
def run_base_model_benchmark(benchmark_config: BenchmarkConfig, print_fn=print) -> dict:
"""Run benchmark for base model only and return results."""
print_fn(f"Running base model benchmark for: {benchmark_config.model_id}")
print_fn("Initializing CUDA...")
init_cuda()
set_seed(benchmark_config.seed)
print_fn(f"Loading base model: {benchmark_config.model_id}")
tokenizer = AutoTokenizer.from_pretrained(benchmark_config.model_id)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
model_kwargs = {
"device_map": "auto" if torch.cuda.is_available() else None,
}
if benchmark_config.dtype == "float32":
model_kwargs["torch_dtype"] = torch.float32
elif benchmark_config.dtype == "float16":
model_kwargs["torch_dtype"] = torch.float16
elif benchmark_config.dtype == "bfloat16":
model_kwargs["torch_dtype"] = torch.bfloat16
if benchmark_config.use_8bit:
model_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True
)
elif benchmark_config.use_4bit:
model_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=model_kwargs.get("torch_dtype", torch.float16),
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
)
model = AutoModelForCausalLM.from_pretrained(benchmark_config.model_id, **model_kwargs)
ram, gpu_allocated, gpu_reserved = get_memory_usage()
print_fn(f"Memory after model load - RAM: {ram:.2f}MB, GPU: {gpu_allocated:.2f}MB")
print_fn("Preparing benchmark prompts...")
prompts = prepare_benchmark_prompts(
        config=benchmark_config,
tokenizer=tokenizer,
max_input_length=None,
seed=benchmark_config.seed,
)
# Measure base model inference for each prompt category
print_fn("Measuring base model inference times...")
base_inference_results = measure_inference_time(
model,
tokenizer,
prompts,
max_new_tokens=benchmark_config.max_new_tokens,
num_runs=benchmark_config.num_inference_runs,
print_fn=print_fn,
category_generation_params=benchmark_config.category_generation_params,
)
result = {
"model_id": benchmark_config.model_id,
"benchmark_config": benchmark_config.to_dict(),
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
"inference_results": base_inference_results,
"memory_info": {
"ram_mb": ram,
"gpu_allocated_mb": gpu_allocated,
"gpu_reserved_mb": gpu_reserved,
},
}
return result
def save_base_results(result: dict, model_id: str) -> str:
"""Save base model results with a filename based on model and config."""
base_results_dir = os.path.join(os.path.dirname(__file__), "base_results")
os.makedirs(base_results_dir, exist_ok=True)
model_name = model_id.replace("/", "_").replace("-", "_")
filename = f"base_{model_name}.json"
filepath = os.path.join(base_results_dir, filename)
with open(filepath, "w") as f:
json.dump(result, f, indent=2)
return filepath
def main():
"""Main entry point for the base model benchmark runner."""
parser = argparse.ArgumentParser(description="Run base model benchmarks")
parser.add_argument("--verbose", "-v", action="store_true", help="Enable verbose output")
parser.add_argument("--force", "-f", action="store_true", help="Force re-run even if results exist")
args = parser.parse_args()
print_fn = print if args.verbose else lambda *args, **kwargs: None
default_config_path = os.path.join(os.path.dirname(__file__), "default_benchmark_params.json")
benchmark_config = BenchmarkConfig.from_json(default_config_path)
model_name = benchmark_config.model_id.replace("/", "_").replace("-", "_")
base_results_dir = os.path.join(os.path.dirname(__file__), "base_results")
filename = f"base_{model_name}.json"
filepath = os.path.join(base_results_dir, filename)
if os.path.exists(filepath) and not args.force:
print(f"Base results already exist at: {filepath}")
print("Use --force to re-run the benchmark")
return 0
print_fn(f"Running base model benchmark for: {benchmark_config.model_id}")
result = run_base_model_benchmark(benchmark_config, print_fn=print_fn)
saved_path = save_base_results(result, benchmark_config.model_id)
print(f"Base model results saved to: {saved_path}")
print("\nBase Model Benchmark Summary:")
print(f"Model: {result['model_id']}")
print(
f"Memory Usage - RAM: {result['memory_info']['ram_mb']:.2f}MB, GPU: {result['memory_info']['gpu_allocated_mb']:.2f}MB"
)
print("\nInference Times by Category:")
for category, time_val in result["inference_results"]["inference_times"].items():
time_per_token = result["inference_results"]["time_per_token"][category]
tokens = result["inference_results"]["generated_tokens"][category]
print(f" {category}: {time_val:.4f}s ({time_per_token:.6f}s/token, {tokens:.1f} tokens)")
return 0
if __name__ == "__main__":
sys.exit(main())

View File

@@ -0,0 +1,442 @@
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utilities for PEFT benchmarking.
"""
import datetime
import json
import os
import platform
import subprocess
from dataclasses import asdict, dataclass, field
from enum import Enum
from typing import Any, Callable, Optional
import psutil
import torch
FILE_NAME_BENCHMARK_PARAMS = "benchmark_params.json"
FILE_NAME_DEFAULT_CONFIG = "default_benchmark_params.json"
RESULT_PATH = os.path.join(os.path.dirname(__file__), "results")
RESULT_PATH_TEMP = os.path.join(os.path.dirname(__file__), "temporary_results")
RESULT_PATH_CANCELLED = os.path.join(os.path.dirname(__file__), "cancelled_results")
class BenchmarkStatus(Enum):
"""Status of a benchmark run."""
SUCCESS = "success"
FAILED = "failed"
CANCELLED = "cancelled"
RUNNING = "running"
@dataclass
class BenchmarkResult:
"""Container for benchmark results."""
experiment_name: str
status: BenchmarkStatus
model_id: str
run_info: dict = field(default_factory=dict)
generation_info: dict = field(default_factory=dict)
meta_info: dict = field(default_factory=dict)
def __post_init__(self):
"""Initialize structured data format."""
self.run_info = {
"timestamp": datetime.datetime.now(tz=datetime.timezone.utc).isoformat(),
"duration": 0.0,
"status": self.status.value,
"hardware": {
"num_gpus": torch.cuda.device_count() if torch.cuda.is_available() else 0,
"gpu_type": torch.cuda.get_device_name(0) if torch.cuda.is_available() else "N/A",
"cuda_version": torch.version.cuda if torch.cuda.is_available() else "N/A",
"pytorch_version": torch.__version__,
},
}
self.meta_info = {
"model_id": self.model_id,
"parameters": {
"base_params": 0,
"trainable_params": 0,
"total_params": 0,
"param_ratio": 0.0,
},
"model_size": {
"base_model_size_mb": 0.0,
"adapter_size_mb": 0.0,
},
"package_info": {
"transformers-version": None,
"transformers-commit-hash": None,
"peft-version": None,
"peft-commit-hash": None,
"datasets-version": None,
"datasets-commit-hash": None,
"bitsandbytes-version": None,
"bitsandbytes-commit-hash": None,
"torch-version": torch.__version__,
"torch-commit-hash": None,
},
"system_info": {
"system": platform.system(),
"release": platform.release(),
"version": platform.version(),
"machine": platform.machine(),
"processor": platform.processor(),
"gpu": torch.cuda.get_device_name(0) if torch.cuda.is_available() else "N/A",
},
}
self.generation_info = {
"memory": {
"peak_gpu_memory_mb": 0.0,
"peak_ram_memory_mb": 0.0,
"memory_logs": [],
},
"by_category": {},
"overall": {},
}
def update_meta_info(self, param_counts: dict, size_info: dict, package_info: Optional[dict] = None):
"""Update model metadata information."""
self.meta_info["parameters"].update(param_counts)
self.meta_info["model_size"].update(size_info)
if package_info:
self.meta_info["package_info"].update(package_info)
def update_generation_info(self, memory_data: Optional[dict] = None, performance_metrics: Optional[dict] = None):
"""Update generation performance information, primarily for memory and high-level performance."""
if memory_data:
self.generation_info["memory"].update(memory_data)
if performance_metrics: # For things like overall tokens/sec if calculated
self.generation_info.update(performance_metrics)
def add_memory_log(self, stage: str, ram_mb: float, gpu_allocated_mb: float, gpu_reserved_mb: float):
"""Add a memory usage log entry to generation_info."""
self.generation_info["memory"]["memory_logs"].append(
{
"stage": stage,
"ram_mb": ram_mb,
"gpu_allocated_mb": gpu_allocated_mb,
"gpu_reserved_mb": gpu_reserved_mb,
}
)
    def add_metrics_for_category(self, category: str, metrics: dict, individual_samples: Optional[list] = None):
"""Add metrics for a specific prompt category under generation_info."""
category_data = {"metrics": metrics, "samples": individual_samples if individual_samples is not None else []}
self.generation_info["by_category"][category] = category_data
def update_run_info(
self,
duration: float,
status: BenchmarkStatus,
error: Optional[str] = None,
peft_config: Optional[dict] = None,
benchmark_config: Optional[dict] = None,
):
"""Update run information."""
self.run_info["duration"] = duration
self.run_info["status"] = status.value
if error:
self.run_info["error"] = error
if peft_config:
self.run_info["peft_config"] = peft_config
if benchmark_config:
self.run_info["benchmark_config"] = benchmark_config
def compute_overall_metrics(self):
"""Compute overall metrics across all categories within generation_info."""
if not self.generation_info["by_category"]:
return
categories = self.generation_info["by_category"]
key_metrics = [
"inference_time",
"base_inference_time",
"inference_overhead_pct",
"time_per_token",
"generated_tokens",
]
for metric in key_metrics:
values = []
for category_data in categories.values():
if "metrics" in category_data and metric in category_data["metrics"]:
values.append(category_data["metrics"][metric])
if values:
self.generation_info["overall"][metric] = sum(values) / len(values)
def to_dict(self) -> dict[str, Any]:
"""Convert result to dictionary."""
self.compute_overall_metrics()
return {
"run_info": self.run_info,
"generation_info": self.generation_info,
"meta_info": self.meta_info,
}
def save(self, path: Optional[str] = None):
"""Save result to JSON file."""
if path is None:
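            # Route the output by status and git branch: only successful runs on the
            # `main` branch go to results/; everything else is kept separate.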
peft_branch = get_peft_branch()
if self.status == BenchmarkStatus.CANCELLED:
base_path = RESULT_PATH_CANCELLED
elif peft_branch != "main":
base_path = RESULT_PATH_TEMP
elif self.status == BenchmarkStatus.SUCCESS:
base_path = RESULT_PATH
elif self.status == BenchmarkStatus.FAILED:
base_path = RESULT_PATH_CANCELLED
else:
base_path = RESULT_PATH_TEMP
filename = f"{self.experiment_name}.json"
path = os.path.join(base_path, filename)
os.makedirs(os.path.dirname(path), exist_ok=True)
with open(path, "w") as f:
json.dump(self.to_dict(), f, indent=2)
return path
@dataclass
class BenchmarkConfig:
"""Configuration for benchmarking PEFT methods."""
model_id: str
seed: int
num_inference_runs: int
max_new_tokens: int
dtype: str = "float16"
use_4bit: bool = False
use_8bit: bool = False
category_generation_params: Optional[dict] = None
def __post_init__(self) -> None:
"""Validate configuration."""
if not isinstance(self.model_id, str):
raise ValueError(f"Invalid model_id: {self.model_id}")
if self.seed < 0:
raise ValueError(f"Invalid seed: {self.seed}")
if self.num_inference_runs <= 0:
raise ValueError(f"Invalid num_inference_runs: {self.num_inference_runs}")
if self.max_new_tokens <= 0:
raise ValueError(f"Invalid max_new_tokens: {self.max_new_tokens}")
@classmethod
def from_dict(cls, config_dict: dict) -> "BenchmarkConfig":
"""Create config from dictionary."""
valid_keys = set(cls.__dataclass_fields__.keys())
filtered_dict = {k: v for k, v in config_dict.items() if k in valid_keys}
return cls(**filtered_dict)
@classmethod
def from_json(cls, json_path: str) -> "BenchmarkConfig":
"""Load config from JSON file."""
with open(json_path) as f:
config_dict = json.load(f)
return cls.from_dict(config_dict)
def to_dict(self) -> dict[str, Any]:
"""Convert config to dictionary."""
result = asdict(self)
return result
def save(self, path: str) -> None:
"""Save config to JSON file."""
with open(path, "w") as f:
json.dump(self.to_dict(), f, indent=2)
def merge_from_dict(self, config_dict: dict) -> None:
"""Merge settings from a dictionary into this config object.
Keys in config_dict will override existing attributes.
"""
for key, value in config_dict.items():
if hasattr(self, key):
setattr(self, key, value)
def validate_experiment_path(path: str) -> tuple[str, "BenchmarkConfig"]:
"""Validate experiment path, load and merge configs, and return them."""
if not os.path.exists(path):
raise FileNotFoundError(f"Experiment path not found: {path}")
path_parts = os.path.normpath(path).split(os.sep)
try:
experiments_idx = path_parts.index("experiments")
except ValueError:
experiment_name = os.path.basename(path.rstrip(os.sep))
else:
if experiments_idx + 1 < len(path_parts):
method_name = path_parts[experiments_idx + 1]
remaining_parts = path_parts[experiments_idx + 2 :]
if remaining_parts:
remaining_name = "-".join(remaining_parts)
experiment_name = f"{method_name}--{remaining_name}"
else:
experiment_name = method_name
else:
experiment_name = os.path.basename(path.rstrip(os.sep))
default_config_path = os.path.join(os.path.dirname(__file__), FILE_NAME_DEFAULT_CONFIG)
experiment_benchmark_params_path = os.path.join(path, FILE_NAME_BENCHMARK_PARAMS)
if not os.path.exists(default_config_path):
raise FileNotFoundError(f"Default configuration file not found: {default_config_path}. This is required.")
benchmark_config = BenchmarkConfig.from_json(default_config_path)
print(f"Loaded default configuration from {default_config_path}")
if os.path.exists(experiment_benchmark_params_path):
with open(experiment_benchmark_params_path) as f:
experiment_specific_params = json.load(f)
benchmark_config.merge_from_dict(experiment_specific_params)
print(f"Loaded and merged experiment-specific parameters from {experiment_benchmark_params_path}")
else:
print(f"No {FILE_NAME_BENCHMARK_PARAMS} found in {path}. Using only default configuration.")
return experiment_name, benchmark_config
def get_memory_usage() -> tuple[float, float, float]:
"""Get current memory usage (RAM and GPU)."""
process = psutil.Process(os.getpid())
ram_usage_bytes = process.memory_info().rss
ram_usage_mb = ram_usage_bytes / (1024 * 1024)
if torch.cuda.is_available():
gpu_allocated = torch.cuda.memory_allocated()
gpu_reserved = torch.cuda.memory_reserved()
gpu_allocated_mb = gpu_allocated / (1024 * 1024)
gpu_reserved_mb = gpu_reserved / (1024 * 1024)
else:
gpu_allocated_mb = 0.0
gpu_reserved_mb = 0.0
return ram_usage_mb, gpu_allocated_mb, gpu_reserved_mb
def init_cuda() -> tuple[float, float]:
"""Initialize CUDA and return initial memory usage."""
if torch.cuda.is_available():
torch.cuda.init()
torch.cuda.empty_cache()
_, gpu_allocated, gpu_reserved = get_memory_usage()
return gpu_allocated, gpu_reserved
return 0.0, 0.0
def get_model_size_mb(model: torch.nn.Module, dtype_bytes: int = 4) -> float:
"""Calculate model size in MB."""
return sum(p.numel() * dtype_bytes for p in model.parameters()) / (1024 * 1024)
def get_peft_branch() -> str:
repo_root = os.path.dirname(__file__)
return subprocess.check_output("git rev-parse --abbrev-ref HEAD".split(), cwd=repo_root).decode().strip()
def log_results(
experiment_name: str,
benchmark_result: BenchmarkResult,
print_fn: Callable = print,
) -> None:
"""Log benchmark results to console."""
print_fn("\n" + "=" * 50)
print_fn(f"Benchmark Results: {experiment_name}")
print_fn("=" * 50)
print_fn(f"Status: {benchmark_result.run_info.get('status', 'N/A')}")
print_fn(f"Duration: {benchmark_result.run_info.get('duration', 0):.2f} seconds")
if benchmark_result.run_info.get("status") != BenchmarkStatus.SUCCESS.value:
print_fn(f"Error: {benchmark_result.run_info.get('error', 'Unknown error')}")
print_fn("=" * 50)
return
print_fn("\nModel Information:")
print_fn(f" Base Model: {benchmark_result.meta_info.get('model_id', 'N/A')}")
print_fn("\nParameter Counts:")
params = benchmark_result.meta_info.get("parameters", {})
print_fn(f" Base Parameters: {params.get('base_params', 0):,}")
print_fn(f" Trainable Parameters: {params.get('trainable_params', 0):,}")
print_fn(f" Parameter Ratio: {params.get('param_ratio', 0):.5%}")
print_fn("\nModel Size:")
size_info = benchmark_result.meta_info.get("model_size", {})
print_fn(f" Base Model: {size_info.get('base_model_size_mb', 0):.2f} MB")
print_fn(f" Adapter: {size_info.get('adapter_size_mb', 0):.2f} MB")
print_fn("\nMemory Usage (from generation_info):")
memory_data = benchmark_result.generation_info.get("memory", {})
print_fn(f" Peak GPU Memory: {memory_data.get('peak_gpu_memory_mb', 0):.2f} MB")
print_fn(f" Peak RAM Memory: {memory_data.get('peak_ram_memory_mb', 0):.2f} MB")
print_fn("\nDetailed Metrics (from generation_info.by_category):")
if benchmark_result.generation_info.get("by_category"):
for category, cat_data in benchmark_result.generation_info["by_category"].items():
print_fn(f" Category: {category}")
metrics = cat_data.get("metrics", {})
print_fn(f" Inference Time: {metrics.get('inference_time', 0):.4f} seconds")
print_fn(f" Base Inference Time: {metrics.get('base_inference_time', 0):.4f} seconds")
print_fn(f" Inference Overhead: {metrics.get('inference_overhead_pct', 0):.2f}%")
print_fn(f" Time Per Token: {metrics.get('time_per_token', 0):.6f} seconds/token")
print_fn(f" Generated Tokens: {metrics.get('generated_tokens', 0):.1f}")
samples = cat_data.get("samples", [])
if samples:
print_fn(f" Number of Samples: {len(samples)}")
print_fn(
f" Average Generated Tokens: {sum(s.get('generated_tokens', 0) for s in samples) / len(samples):.1f}"
)
else:
print_fn(" No per-category metrics available.")
benchmark_result.compute_overall_metrics()
print_fn("\nOverall Metrics (from generation_info.overall):")
overall = benchmark_result.generation_info.get("overall")
if overall:
print_fn(f" Inference Time: {overall.get('inference_time', 0):.4f} seconds")
print_fn(f" Base Inference Time: {overall.get('base_inference_time', 0):.4f} seconds")
print_fn(f" Inference Overhead: {overall.get('inference_overhead_pct', 0):.2f}%")
print_fn(f" Time Per Token: {overall.get('time_per_token', 0):.6f} seconds/token")
print_fn(f" Generated Tokens: {overall.get('generated_tokens', 0):.1f}")
else:
print_fn(" No overall metrics computed.")
print_fn("\nSaved results to:", benchmark_result.save())
print_fn("=" * 50)