Benchmarking V2: framework impl (#40486)

* Start revamping benchmarking

* Start refactoring benchmarking

* Use Pandas for CSV

* import fix

* Remove benchmark files

* Remove sample data

* Address review comments

* Benchmarking v2

* Fix llama bench parameters

* Working checkpoint

* Readme touchups

* Remove unnecessary test

* Massage the framework a bit

* Small cleanup

* Remove unnecessary flushes

* Remove references to mock benchmark

* Take commit ID from CLI

* Address review comments

* Use Events for thread comms

* Tiny renaming
This commit is contained in:
Ákos Hadnagy
2025-09-03 22:26:32 +02:00
committed by GitHub
parent 459c1fa47a
commit f22ec7f174
7 changed files with 1851 additions and 0 deletions

1
benchmark_v2/.gitignore vendored Normal file
View File

@ -0,0 +1 @@
benchmark_results/

98
benchmark_v2/README.md Normal file
View File

@ -0,0 +1,98 @@
# Benchmarking v2
A comprehensive benchmarking framework for transformer models that supports multiple execution modes (eager, compiled, kernelized), detailed performance metrics collection, and structured output format.
## Quick Start
### Running All Benchmarks
```bash
# Run all benchmarks with default settings
python run_benchmarks.py
# Specify output directory
python run_benchmarks.py --output-dir my_results
# Run with custom parameters
python run_benchmarks.py \
--warmup-iterations 5 \
--measurement-iterations 10 \
--num-tokens-to-generate 200
```
### Running Specific Benchmarks
```bash
# Include only specific benchmarks
python run_benchmarks.py --include llama
# Exclude specific benchmarks
python run_benchmarks.py --exclude old_benchmark
## Output Format
Results are saved as JSON files with the following structure:
```json
{
"model_name": "llama_2_7b",
"benchmark_scenarios": [
{
"scenario_name": "eager_variant",
"metadata": {
"timestamp": "2025-01-XX...",
"commit_id": "abc123...",
"hardware_info": {
"gpu_name": "NVIDIA A100",
"gpu_memory_total": 40960,
"cpu_count": 64
},
"config": {
"variant": "eager",
"warmup_iterations": 3,
"measurement_iterations": 5
}
},
"measurements": {
"latency": {
"mean": 2.45,
"median": 2.43,
"std": 0.12,
"min": 2.31,
"max": 2.67,
"p95": 2.61,
"p99": 2.65
},
"time_to_first_token": {
"mean": 0.15,
"std": 0.02
},
"tokens_per_second": {
"mean": 87.3,
"unit": "tokens/sec"
}
},
"gpu_metrics": {
"gpu_utilization_mean": 85.2,
"gpu_memory_used_mean": 12450
}
}
]
}
```
### Debug Mode
```bash
python run_benchmarks.py --log-level DEBUG
```
## Contributing
To add new benchmarks:
1. Create a new file in `benches/`
2. Implement the `ModelBenchmark` interface
3. Add a runner function (`run_<benchmark_name>` or `run_benchmark`)
4. run_benchmarks.py

View File

@ -0,0 +1 @@
# Benchmark implementations directory

View File

@ -0,0 +1,156 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import logging
from typing import Dict, Any, List
from benchmark_framework import ModelBenchmark
import torch
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "1"
torch.set_float32_matmul_precision("high")
class LLaMABenchmark(ModelBenchmark):
"""Simplified LLaMA model benchmark implementation using the ModelBenchmark base class."""
def __init__(self, logger: logging.Logger):
super().__init__(logger)
self._default_prompt = "Why dogs are so cute?" # Custom prompt for LLaMA
def get_scenario_configs(self) -> List[Dict[str, Any]]:
"""
Get LLaMA-specific scenario configurations.
Returns:
List of scenario configuration dictionaries
"""
return [
# Eager variants
{"variant": "eager", "compile_mode": None, "use_cache": True, "description": "Eager execution with cache"},
# Compiled variants
{"variant": "compiled", "compile_mode": "max-autotune", "use_cache": True, "description": "Compiled with max autotune"},
# Kernelized variant (if available)
{"variant": "kernelized", "compile_mode": "max-autotune", "use_cache": True, "description": "Kernelized execution"},
]
def _is_kernelization_available(self) -> bool:
"""Check if kernelization is available for LLaMA."""
try:
from kernels import Mode, kernelize
return True
except ImportError:
self.logger.debug("Kernelization not available: kernels module not found")
return False
def get_default_generation_config(self) -> Dict[str, Any]:
"""Get LLaMA-specific generation configuration."""
return {
"do_sample": False,
"top_p": 1.0,
"temperature": 1.0,
"repetition_penalty": 1.0,
"max_new_tokens": None, # Will be set per scenario
}
def get_model_init_kwargs(self, config) -> Dict[str, Any]:
"""Get LLaMA-specific model initialization kwargs."""
from benchmark_framework import BenchmarkConfig
return {
"torch_dtype": getattr(torch, config.torch_dtype),
"attn_implementation": config.attn_implementation,
"use_cache": True,
}
def get_default_torch_dtype(self) -> str:
"""Get default torch dtype for LLaMA."""
return "float16" # LLaMA works well with float16
def get_default_device(self) -> str:
"""Get default device for LLaMA."""
return "cuda" # LLaMA prefers CUDA
def run_llama(logger, output_dir, **kwargs):
"""
Run LLaMA benchmark with the given configuration.
Args:
logger: Logger instance
output_dir: Output directory for results
**kwargs: Additional configuration options
Returns:
Path to output file if successful
"""
from benchmark_framework import BenchmarkRunner
# Extract parameters with defaults
model_id = kwargs.get('model_id', 'meta-llama/Llama-2-7b-hf')
warmup_iterations = kwargs.get('warmup_iterations', 3)
measurement_iterations = kwargs.get('measurement_iterations', 5)
num_tokens_to_generate = kwargs.get('num_tokens_to_generate', 100)
include_sdpa_variants = kwargs.get('include_sdpa_variants', True)
device = kwargs.get('device', 'cuda')
torch_dtype = kwargs.get('torch_dtype', 'float16')
batch_size = kwargs.get('batch_size', 1)
commit_id = kwargs.get('commit_id', None)
logger.info(f"Starting LLaMA benchmark for model: {model_id}")
logger.info(f"Configuration: warmup={warmup_iterations}, measurement={measurement_iterations}, tokens={num_tokens_to_generate}")
try:
# Create benchmark instance
benchmark = LLaMABenchmark(logger)
# Create scenarios
scenarios = benchmark.create_scenarios(
model_id=model_id,
warmup_iterations=warmup_iterations,
measurement_iterations=measurement_iterations,
num_tokens_to_generate=num_tokens_to_generate,
include_sdpa_variants=include_sdpa_variants,
device=device,
torch_dtype=torch_dtype,
batch_size=batch_size
)
logger.info(f"Created {len(scenarios)} benchmark scenarios")
# Create runner and execute benchmarks
runner = BenchmarkRunner(logger, output_dir)
results = runner.run_benchmark(benchmark, scenarios, commit_id=commit_id)
if not results:
logger.warning("No successful benchmark results")
return None
# Save results
model_name = model_id.split('/')[-1] # Extract model name from ID
output_file = runner.save_results(model_name, results)
logger.info(f"LLaMA benchmark completed successfully. Results saved to: {output_file}")
return output_file
except Exception as e:
logger.error(f"LLaMA benchmark failed: {e}")
import traceback
logger.debug(traceback.format_exc())
raise

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,6 @@
numpy>=1.21.0
psutil>=5.8.0
gpustat>=1.0.0
torch>=2.0.0
transformers>=4.30.0
datasets>=2.10.0

385
benchmark_v2/run_benchmarks.py Executable file
View File

@ -0,0 +1,385 @@
#!/usr/bin/env python3
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Top-level benchmarking script that automatically discovers and runs all benchmarks
in the ./benches directory, organizing outputs into model-specific subfolders.
"""
import argparse
import importlib.util
import logging
import os
import sys
import json
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Any, Optional
def setup_logging(log_level: str = "INFO", enable_file_logging: bool = False) -> logging.Logger:
"""Setup logging configuration."""
numeric_level = getattr(logging, log_level.upper(), None)
if not isinstance(numeric_level, int):
raise ValueError(f'Invalid log level: {log_level}')
handlers = [logging.StreamHandler(sys.stdout)]
if enable_file_logging:
handlers.append(
logging.FileHandler(f'benchmark_run_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log')
)
logging.basicConfig(
level=numeric_level,
format='[%(levelname)s - %(asctime)s] %(name)s: %(message)s',
handlers=handlers
)
return logging.getLogger(__name__)
def discover_benchmarks(benches_dir: str) -> List[Dict[str, Any]]:
"""
Discover all benchmark modules in the benches directory.
Returns:
List of dictionaries containing benchmark module info
"""
benchmarks = []
benches_path = Path(benches_dir)
if not benches_path.exists():
raise FileNotFoundError(f"Benches directory not found: {benches_dir}")
for py_file in benches_path.glob("*.py"):
if py_file.name.startswith("__"):
continue
module_name = py_file.stem
try:
# Import the module
spec = importlib.util.spec_from_file_location(module_name, py_file)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
# Check if it has a benchmark runner function
if hasattr(module, f'run_{module_name}'):
benchmarks.append({
'name': module_name,
'path': str(py_file),
'module': module,
'runner_function': getattr(module, f'run_{module_name}')
})
elif hasattr(module, 'run_benchmark'):
benchmarks.append({
'name': module_name,
'path': str(py_file),
'module': module,
'runner_function': getattr(module, 'run_benchmark')
})
else:
logging.warning(f"No runner function found in {py_file}")
except Exception as e:
logging.error(f"Failed to import {py_file}: {e}")
return benchmarks
def run_single_benchmark(
benchmark_info: Dict[str, Any],
output_dir: str,
logger: logging.Logger,
**kwargs
) -> Optional[str]:
"""
Run a single benchmark and return the output file path.
Args:
benchmark_info: Dictionary containing benchmark module info
output_dir: Base output directory
logger: Logger instance
**kwargs: Additional arguments to pass to the benchmark
Returns:
Path to the output file if successful, None otherwise
"""
benchmark_name = benchmark_info['name']
runner_func = benchmark_info['runner_function']
logger.info(f"Running benchmark: {benchmark_name}")
try:
# Check function signature to determine what arguments to pass
import inspect
sig = inspect.signature(runner_func)
# Prepare arguments based on function signature
func_kwargs = {
'logger': logger,
'output_dir': output_dir
}
# Add other kwargs if the function accepts them
for param_name in sig.parameters:
if param_name in kwargs:
func_kwargs[param_name] = kwargs[param_name]
# Filter kwargs to only include parameters the function accepts
# If function has **kwargs, include all provided kwargs
has_var_kwargs = any(param.kind == param.VAR_KEYWORD for param in sig.parameters.values())
if has_var_kwargs:
valid_kwargs = {**func_kwargs, **kwargs}
else:
valid_kwargs = {k: v for k, v in func_kwargs.items()
if k in sig.parameters}
# Run the benchmark
result = runner_func(**valid_kwargs)
if isinstance(result, str):
# Function returned a file path
return result
else:
logger.info(f"Benchmark {benchmark_name} completed successfully")
return "completed"
except Exception as e:
logger.error(f"Benchmark {benchmark_name} failed: {e}")
import traceback
logger.debug(traceback.format_exc())
return None
def generate_summary_report(
output_dir: str,
benchmark_results: Dict[str, Any],
logger: logging.Logger
) -> str:
"""Generate a summary report of all benchmark runs."""
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
summary_file = os.path.join(output_dir, f"benchmark_summary_{timestamp}.json")
summary_data = {
"run_metadata": {
"timestamp": datetime.utcnow().isoformat(),
"total_benchmarks": len(benchmark_results),
"successful_benchmarks": len([r for r in benchmark_results.values() if r is not None]),
"failed_benchmarks": len([r for r in benchmark_results.values() if r is None])
},
"benchmark_results": benchmark_results,
"output_directory": output_dir
}
with open(summary_file, 'w') as f:
json.dump(summary_data, f, indent=2, default=str)
logger.info(f"Summary report saved to: {summary_file}")
return summary_file
def main():
"""Main entry point for the benchmarking script."""
parser = argparse.ArgumentParser(
description="Run all benchmarks in the ./benches directory"
)
parser.add_argument(
"--output-dir",
type=str,
default="benchmark_results",
help="Base output directory for benchmark results (default: benchmark_results)"
)
parser.add_argument(
"--benches-dir",
type=str,
default="./benches",
help="Directory containing benchmark implementations (default: ./benches)"
)
parser.add_argument(
"--log-level",
type=str,
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
default="INFO",
help="Logging level (default: INFO)"
)
parser.add_argument(
"--model-id",
type=str,
help="Specific model ID to benchmark (if supported by benchmarks)"
)
parser.add_argument(
"--warmup-iterations",
type=int,
default=3,
help="Number of warmup iterations (default: 3)"
)
parser.add_argument(
"--measurement-iterations",
type=int,
default=5,
help="Number of measurement iterations (default: 5)"
)
parser.add_argument(
"--num-tokens-to-generate",
type=int,
default=100,
help="Number of tokens to generate in benchmarks (default: 100)"
)
parser.add_argument(
"--include",
type=str,
nargs="*",
help="Only run benchmarks matching these names"
)
parser.add_argument(
"--exclude",
type=str,
nargs="*",
help="Exclude benchmarks matching these names"
)
parser.add_argument(
"--enable-mock",
action="store_true",
help="Enable mock benchmark (skipped by default)"
)
parser.add_argument(
"--enable-file-logging",
action="store_true",
help="Enable file logging (disabled by default)"
)
parser.add_argument(
"--commit-id",
type=str,
help="Git commit ID for metadata (if not provided, will auto-detect from git)"
)
args = parser.parse_args()
# Setup logging
logger = setup_logging(args.log_level, args.enable_file_logging)
logger.info("Starting benchmark discovery and execution")
logger.info(f"Output directory: {args.output_dir}")
logger.info(f"Benches directory: {args.benches_dir}")
# Create output directory
os.makedirs(args.output_dir, exist_ok=True)
try:
# Discover benchmarks
benchmarks = discover_benchmarks(args.benches_dir)
logger.info(f"Discovered {len(benchmarks)} benchmark(s): {[b['name'] for b in benchmarks]}")
if not benchmarks:
logger.warning("No benchmarks found!")
return 1
# Filter benchmarks based on include/exclude
filtered_benchmarks = benchmarks
if args.include:
filtered_benchmarks = [b for b in filtered_benchmarks
if any(pattern in b['name'] for pattern in args.include)]
logger.info(f"Filtered to include: {[b['name'] for b in filtered_benchmarks]}")
if args.exclude:
filtered_benchmarks = [b for b in filtered_benchmarks
if not any(pattern in b['name'] for pattern in args.exclude)]
logger.info(f"After exclusion: {[b['name'] for b in filtered_benchmarks]}")
if not filtered_benchmarks:
logger.warning("No benchmarks remaining after filtering!")
return 1
# Prepare common kwargs for benchmarks
benchmark_kwargs = {
'warmup_iterations': args.warmup_iterations,
'measurement_iterations': args.measurement_iterations,
'num_tokens_to_generate': args.num_tokens_to_generate
}
if args.model_id:
benchmark_kwargs['model_id'] = args.model_id
# Add enable_mock flag for mock benchmark
benchmark_kwargs['enable_mock'] = args.enable_mock
# Add commit_id if provided
if args.commit_id:
benchmark_kwargs['commit_id'] = args.commit_id
# Run benchmarks
benchmark_results = {}
successful_count = 0
for benchmark_info in filtered_benchmarks:
result = run_single_benchmark(
benchmark_info,
args.output_dir,
logger,
**benchmark_kwargs
)
benchmark_results[benchmark_info['name']] = result
if result is not None:
successful_count += 1
# Generate summary report
summary_file = generate_summary_report(args.output_dir, benchmark_results, logger)
# Final summary
total_benchmarks = len(filtered_benchmarks)
failed_count = total_benchmarks - successful_count
logger.info("=" * 60)
logger.info("BENCHMARK RUN SUMMARY")
logger.info("=" * 60)
logger.info(f"Total benchmarks: {total_benchmarks}")
logger.info(f"Successful: {successful_count}")
logger.info(f"Failed: {failed_count}")
logger.info(f"Output directory: {args.output_dir}")
logger.info(f"Summary report: {summary_file}")
if failed_count > 0:
logger.warning(f"{failed_count} benchmark(s) failed. Check logs for details.")
return 1
else:
logger.info("All benchmarks completed successfully!")
return 0
except Exception as e:
logger.error(f"Benchmark run failed: {e}")
import traceback
logger.debug(traceback.format_exc())
return 1
if __name__ == "__main__":
sys.exit(main())