mirror of
https://github.com/huggingface/transformers.git
synced 2025-10-20 17:13:56 +08:00
Benchmarking V2: framework impl (#40486)
* Start revamping benchmarking * Start refactoring benchmarking * Use Pandas for CSV * import fix * Remove benchmark files * Remove sample data * Address review comments * Benchmarking v2 * Fix llama bench parameters * Working checkpoint * Readme touchups * Remove unnecessary test * Massage the framework a bit * Small cleanup * Remove unnecessary flushes * Remove references to mock benchmark * Take commit ID from CLI * Address review comments * Use Events for thread comms * Tiny renaming
This commit is contained in:
1
benchmark_v2/.gitignore
vendored
Normal file
1
benchmark_v2/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
||||
benchmark_results/
|
98
benchmark_v2/README.md
Normal file
98
benchmark_v2/README.md
Normal file
@ -0,0 +1,98 @@
|
||||
# Benchmarking v2
|
||||
|
||||
A comprehensive benchmarking framework for transformer models that supports multiple execution modes (eager, compiled, kernelized), detailed performance metrics collection, and structured output format.
|
||||
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Running All Benchmarks
|
||||
|
||||
```bash
|
||||
# Run all benchmarks with default settings
|
||||
python run_benchmarks.py
|
||||
|
||||
# Specify output directory
|
||||
python run_benchmarks.py --output-dir my_results
|
||||
|
||||
# Run with custom parameters
|
||||
python run_benchmarks.py \
|
||||
--warmup-iterations 5 \
|
||||
--measurement-iterations 10 \
|
||||
--num-tokens-to-generate 200
|
||||
```
|
||||
|
||||
### Running Specific Benchmarks
|
||||
|
||||
```bash
|
||||
# Include only specific benchmarks
|
||||
python run_benchmarks.py --include llama
|
||||
|
||||
# Exclude specific benchmarks
|
||||
python run_benchmarks.py --exclude old_benchmark
|
||||
|
||||
## Output Format
|
||||
|
||||
Results are saved as JSON files with the following structure:
|
||||
|
||||
```json
|
||||
{
|
||||
"model_name": "llama_2_7b",
|
||||
"benchmark_scenarios": [
|
||||
{
|
||||
"scenario_name": "eager_variant",
|
||||
"metadata": {
|
||||
"timestamp": "2025-01-XX...",
|
||||
"commit_id": "abc123...",
|
||||
"hardware_info": {
|
||||
"gpu_name": "NVIDIA A100",
|
||||
"gpu_memory_total": 40960,
|
||||
"cpu_count": 64
|
||||
},
|
||||
"config": {
|
||||
"variant": "eager",
|
||||
"warmup_iterations": 3,
|
||||
"measurement_iterations": 5
|
||||
}
|
||||
},
|
||||
"measurements": {
|
||||
"latency": {
|
||||
"mean": 2.45,
|
||||
"median": 2.43,
|
||||
"std": 0.12,
|
||||
"min": 2.31,
|
||||
"max": 2.67,
|
||||
"p95": 2.61,
|
||||
"p99": 2.65
|
||||
},
|
||||
"time_to_first_token": {
|
||||
"mean": 0.15,
|
||||
"std": 0.02
|
||||
},
|
||||
"tokens_per_second": {
|
||||
"mean": 87.3,
|
||||
"unit": "tokens/sec"
|
||||
}
|
||||
},
|
||||
"gpu_metrics": {
|
||||
"gpu_utilization_mean": 85.2,
|
||||
"gpu_memory_used_mean": 12450
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### Debug Mode
|
||||
|
||||
```bash
|
||||
python run_benchmarks.py --log-level DEBUG
|
||||
```
|
||||
|
||||
## Contributing
|
||||
|
||||
To add new benchmarks:
|
||||
|
||||
1. Create a new file in `benches/`
|
||||
2. Implement the `ModelBenchmark` interface
|
||||
3. Add a runner function (`run_<benchmark_name>` or `run_benchmark`)
|
||||
4. run_benchmarks.py
|
1
benchmark_v2/benches/__init__.py
Normal file
1
benchmark_v2/benches/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
# Benchmark implementations directory
|
156
benchmark_v2/benches/llama.py
Normal file
156
benchmark_v2/benches/llama.py
Normal file
@ -0,0 +1,156 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import logging
|
||||
from typing import Dict, Any, List
|
||||
|
||||
from benchmark_framework import ModelBenchmark
|
||||
|
||||
import torch
|
||||
|
||||
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "1"
|
||||
torch.set_float32_matmul_precision("high")
|
||||
|
||||
class LLaMABenchmark(ModelBenchmark):
|
||||
"""Simplified LLaMA model benchmark implementation using the ModelBenchmark base class."""
|
||||
|
||||
def __init__(self, logger: logging.Logger):
|
||||
super().__init__(logger)
|
||||
self._default_prompt = "Why dogs are so cute?" # Custom prompt for LLaMA
|
||||
|
||||
|
||||
|
||||
def get_scenario_configs(self) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get LLaMA-specific scenario configurations.
|
||||
|
||||
Returns:
|
||||
List of scenario configuration dictionaries
|
||||
"""
|
||||
return [
|
||||
# Eager variants
|
||||
{"variant": "eager", "compile_mode": None, "use_cache": True, "description": "Eager execution with cache"},
|
||||
|
||||
# Compiled variants
|
||||
{"variant": "compiled", "compile_mode": "max-autotune", "use_cache": True, "description": "Compiled with max autotune"},
|
||||
|
||||
# Kernelized variant (if available)
|
||||
{"variant": "kernelized", "compile_mode": "max-autotune", "use_cache": True, "description": "Kernelized execution"},
|
||||
]
|
||||
|
||||
def _is_kernelization_available(self) -> bool:
|
||||
"""Check if kernelization is available for LLaMA."""
|
||||
try:
|
||||
from kernels import Mode, kernelize
|
||||
return True
|
||||
except ImportError:
|
||||
self.logger.debug("Kernelization not available: kernels module not found")
|
||||
return False
|
||||
|
||||
def get_default_generation_config(self) -> Dict[str, Any]:
|
||||
"""Get LLaMA-specific generation configuration."""
|
||||
return {
|
||||
"do_sample": False,
|
||||
"top_p": 1.0,
|
||||
"temperature": 1.0,
|
||||
"repetition_penalty": 1.0,
|
||||
"max_new_tokens": None, # Will be set per scenario
|
||||
}
|
||||
|
||||
def get_model_init_kwargs(self, config) -> Dict[str, Any]:
|
||||
"""Get LLaMA-specific model initialization kwargs."""
|
||||
from benchmark_framework import BenchmarkConfig
|
||||
return {
|
||||
"torch_dtype": getattr(torch, config.torch_dtype),
|
||||
"attn_implementation": config.attn_implementation,
|
||||
"use_cache": True,
|
||||
}
|
||||
|
||||
def get_default_torch_dtype(self) -> str:
|
||||
"""Get default torch dtype for LLaMA."""
|
||||
return "float16" # LLaMA works well with float16
|
||||
|
||||
def get_default_device(self) -> str:
|
||||
"""Get default device for LLaMA."""
|
||||
return "cuda" # LLaMA prefers CUDA
|
||||
|
||||
|
||||
def run_llama(logger, output_dir, **kwargs):
|
||||
"""
|
||||
Run LLaMA benchmark with the given configuration.
|
||||
|
||||
Args:
|
||||
logger: Logger instance
|
||||
output_dir: Output directory for results
|
||||
**kwargs: Additional configuration options
|
||||
|
||||
Returns:
|
||||
Path to output file if successful
|
||||
"""
|
||||
from benchmark_framework import BenchmarkRunner
|
||||
|
||||
# Extract parameters with defaults
|
||||
model_id = kwargs.get('model_id', 'meta-llama/Llama-2-7b-hf')
|
||||
warmup_iterations = kwargs.get('warmup_iterations', 3)
|
||||
measurement_iterations = kwargs.get('measurement_iterations', 5)
|
||||
num_tokens_to_generate = kwargs.get('num_tokens_to_generate', 100)
|
||||
include_sdpa_variants = kwargs.get('include_sdpa_variants', True)
|
||||
device = kwargs.get('device', 'cuda')
|
||||
torch_dtype = kwargs.get('torch_dtype', 'float16')
|
||||
batch_size = kwargs.get('batch_size', 1)
|
||||
commit_id = kwargs.get('commit_id', None)
|
||||
|
||||
logger.info(f"Starting LLaMA benchmark for model: {model_id}")
|
||||
logger.info(f"Configuration: warmup={warmup_iterations}, measurement={measurement_iterations}, tokens={num_tokens_to_generate}")
|
||||
|
||||
try:
|
||||
# Create benchmark instance
|
||||
benchmark = LLaMABenchmark(logger)
|
||||
|
||||
# Create scenarios
|
||||
scenarios = benchmark.create_scenarios(
|
||||
model_id=model_id,
|
||||
warmup_iterations=warmup_iterations,
|
||||
measurement_iterations=measurement_iterations,
|
||||
num_tokens_to_generate=num_tokens_to_generate,
|
||||
include_sdpa_variants=include_sdpa_variants,
|
||||
device=device,
|
||||
torch_dtype=torch_dtype,
|
||||
batch_size=batch_size
|
||||
)
|
||||
|
||||
logger.info(f"Created {len(scenarios)} benchmark scenarios")
|
||||
|
||||
# Create runner and execute benchmarks
|
||||
runner = BenchmarkRunner(logger, output_dir)
|
||||
results = runner.run_benchmark(benchmark, scenarios, commit_id=commit_id)
|
||||
|
||||
if not results:
|
||||
logger.warning("No successful benchmark results")
|
||||
return None
|
||||
|
||||
# Save results
|
||||
model_name = model_id.split('/')[-1] # Extract model name from ID
|
||||
output_file = runner.save_results(model_name, results)
|
||||
|
||||
logger.info(f"LLaMA benchmark completed successfully. Results saved to: {output_file}")
|
||||
return output_file
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"LLaMA benchmark failed: {e}")
|
||||
import traceback
|
||||
logger.debug(traceback.format_exc())
|
||||
raise
|
1204
benchmark_v2/benchmark_framework.py
Normal file
1204
benchmark_v2/benchmark_framework.py
Normal file
File diff suppressed because it is too large
Load Diff
6
benchmark_v2/requirements.txt
Normal file
6
benchmark_v2/requirements.txt
Normal file
@ -0,0 +1,6 @@
|
||||
numpy>=1.21.0
|
||||
psutil>=5.8.0
|
||||
gpustat>=1.0.0
|
||||
torch>=2.0.0
|
||||
transformers>=4.30.0
|
||||
datasets>=2.10.0
|
385
benchmark_v2/run_benchmarks.py
Executable file
385
benchmark_v2/run_benchmarks.py
Executable file
@ -0,0 +1,385 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Top-level benchmarking script that automatically discovers and runs all benchmarks
|
||||
in the ./benches directory, organizing outputs into model-specific subfolders.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import importlib.util
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Any, Optional
|
||||
|
||||
|
||||
def setup_logging(log_level: str = "INFO", enable_file_logging: bool = False) -> logging.Logger:
|
||||
"""Setup logging configuration."""
|
||||
numeric_level = getattr(logging, log_level.upper(), None)
|
||||
if not isinstance(numeric_level, int):
|
||||
raise ValueError(f'Invalid log level: {log_level}')
|
||||
|
||||
handlers = [logging.StreamHandler(sys.stdout)]
|
||||
|
||||
if enable_file_logging:
|
||||
handlers.append(
|
||||
logging.FileHandler(f'benchmark_run_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log')
|
||||
)
|
||||
|
||||
logging.basicConfig(
|
||||
level=numeric_level,
|
||||
format='[%(levelname)s - %(asctime)s] %(name)s: %(message)s',
|
||||
handlers=handlers
|
||||
)
|
||||
|
||||
return logging.getLogger(__name__)
|
||||
|
||||
|
||||
def discover_benchmarks(benches_dir: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Discover all benchmark modules in the benches directory.
|
||||
|
||||
Returns:
|
||||
List of dictionaries containing benchmark module info
|
||||
"""
|
||||
benchmarks = []
|
||||
benches_path = Path(benches_dir)
|
||||
|
||||
if not benches_path.exists():
|
||||
raise FileNotFoundError(f"Benches directory not found: {benches_dir}")
|
||||
|
||||
for py_file in benches_path.glob("*.py"):
|
||||
if py_file.name.startswith("__"):
|
||||
continue
|
||||
|
||||
module_name = py_file.stem
|
||||
|
||||
try:
|
||||
# Import the module
|
||||
spec = importlib.util.spec_from_file_location(module_name, py_file)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
# Check if it has a benchmark runner function
|
||||
if hasattr(module, f'run_{module_name}'):
|
||||
benchmarks.append({
|
||||
'name': module_name,
|
||||
'path': str(py_file),
|
||||
'module': module,
|
||||
'runner_function': getattr(module, f'run_{module_name}')
|
||||
})
|
||||
elif hasattr(module, 'run_benchmark'):
|
||||
benchmarks.append({
|
||||
'name': module_name,
|
||||
'path': str(py_file),
|
||||
'module': module,
|
||||
'runner_function': getattr(module, 'run_benchmark')
|
||||
})
|
||||
else:
|
||||
logging.warning(f"No runner function found in {py_file}")
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to import {py_file}: {e}")
|
||||
|
||||
return benchmarks
|
||||
|
||||
|
||||
def run_single_benchmark(
|
||||
benchmark_info: Dict[str, Any],
|
||||
output_dir: str,
|
||||
logger: logging.Logger,
|
||||
**kwargs
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Run a single benchmark and return the output file path.
|
||||
|
||||
Args:
|
||||
benchmark_info: Dictionary containing benchmark module info
|
||||
output_dir: Base output directory
|
||||
logger: Logger instance
|
||||
**kwargs: Additional arguments to pass to the benchmark
|
||||
|
||||
Returns:
|
||||
Path to the output file if successful, None otherwise
|
||||
"""
|
||||
benchmark_name = benchmark_info['name']
|
||||
runner_func = benchmark_info['runner_function']
|
||||
|
||||
logger.info(f"Running benchmark: {benchmark_name}")
|
||||
|
||||
try:
|
||||
# Check function signature to determine what arguments to pass
|
||||
import inspect
|
||||
sig = inspect.signature(runner_func)
|
||||
|
||||
# Prepare arguments based on function signature
|
||||
func_kwargs = {
|
||||
'logger': logger,
|
||||
'output_dir': output_dir
|
||||
}
|
||||
|
||||
# Add other kwargs if the function accepts them
|
||||
for param_name in sig.parameters:
|
||||
if param_name in kwargs:
|
||||
func_kwargs[param_name] = kwargs[param_name]
|
||||
|
||||
# Filter kwargs to only include parameters the function accepts
|
||||
# If function has **kwargs, include all provided kwargs
|
||||
has_var_kwargs = any(param.kind == param.VAR_KEYWORD for param in sig.parameters.values())
|
||||
if has_var_kwargs:
|
||||
valid_kwargs = {**func_kwargs, **kwargs}
|
||||
else:
|
||||
valid_kwargs = {k: v for k, v in func_kwargs.items()
|
||||
if k in sig.parameters}
|
||||
|
||||
# Run the benchmark
|
||||
result = runner_func(**valid_kwargs)
|
||||
|
||||
if isinstance(result, str):
|
||||
# Function returned a file path
|
||||
return result
|
||||
else:
|
||||
logger.info(f"Benchmark {benchmark_name} completed successfully")
|
||||
return "completed"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Benchmark {benchmark_name} failed: {e}")
|
||||
import traceback
|
||||
logger.debug(traceback.format_exc())
|
||||
return None
|
||||
|
||||
|
||||
def generate_summary_report(
|
||||
output_dir: str,
|
||||
benchmark_results: Dict[str, Any],
|
||||
logger: logging.Logger
|
||||
) -> str:
|
||||
"""Generate a summary report of all benchmark runs."""
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
summary_file = os.path.join(output_dir, f"benchmark_summary_{timestamp}.json")
|
||||
|
||||
summary_data = {
|
||||
"run_metadata": {
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"total_benchmarks": len(benchmark_results),
|
||||
"successful_benchmarks": len([r for r in benchmark_results.values() if r is not None]),
|
||||
"failed_benchmarks": len([r for r in benchmark_results.values() if r is None])
|
||||
},
|
||||
"benchmark_results": benchmark_results,
|
||||
"output_directory": output_dir
|
||||
}
|
||||
|
||||
with open(summary_file, 'w') as f:
|
||||
json.dump(summary_data, f, indent=2, default=str)
|
||||
|
||||
logger.info(f"Summary report saved to: {summary_file}")
|
||||
return summary_file
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point for the benchmarking script."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Run all benchmarks in the ./benches directory"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=str,
|
||||
default="benchmark_results",
|
||||
help="Base output directory for benchmark results (default: benchmark_results)"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--benches-dir",
|
||||
type=str,
|
||||
default="./benches",
|
||||
help="Directory containing benchmark implementations (default: ./benches)"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--log-level",
|
||||
type=str,
|
||||
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
|
||||
default="INFO",
|
||||
help="Logging level (default: INFO)"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model-id",
|
||||
type=str,
|
||||
help="Specific model ID to benchmark (if supported by benchmarks)"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--warmup-iterations",
|
||||
type=int,
|
||||
default=3,
|
||||
help="Number of warmup iterations (default: 3)"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--measurement-iterations",
|
||||
type=int,
|
||||
default=5,
|
||||
help="Number of measurement iterations (default: 5)"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-tokens-to-generate",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Number of tokens to generate in benchmarks (default: 100)"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--include",
|
||||
type=str,
|
||||
nargs="*",
|
||||
help="Only run benchmarks matching these names"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exclude",
|
||||
type=str,
|
||||
nargs="*",
|
||||
help="Exclude benchmarks matching these names"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--enable-mock",
|
||||
action="store_true",
|
||||
help="Enable mock benchmark (skipped by default)"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--enable-file-logging",
|
||||
action="store_true",
|
||||
help="Enable file logging (disabled by default)"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--commit-id",
|
||||
type=str,
|
||||
help="Git commit ID for metadata (if not provided, will auto-detect from git)"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Setup logging
|
||||
logger = setup_logging(args.log_level, args.enable_file_logging)
|
||||
|
||||
logger.info("Starting benchmark discovery and execution")
|
||||
logger.info(f"Output directory: {args.output_dir}")
|
||||
logger.info(f"Benches directory: {args.benches_dir}")
|
||||
|
||||
# Create output directory
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
try:
|
||||
# Discover benchmarks
|
||||
benchmarks = discover_benchmarks(args.benches_dir)
|
||||
logger.info(f"Discovered {len(benchmarks)} benchmark(s): {[b['name'] for b in benchmarks]}")
|
||||
|
||||
if not benchmarks:
|
||||
logger.warning("No benchmarks found!")
|
||||
return 1
|
||||
|
||||
# Filter benchmarks based on include/exclude
|
||||
filtered_benchmarks = benchmarks
|
||||
|
||||
if args.include:
|
||||
filtered_benchmarks = [b for b in filtered_benchmarks
|
||||
if any(pattern in b['name'] for pattern in args.include)]
|
||||
logger.info(f"Filtered to include: {[b['name'] for b in filtered_benchmarks]}")
|
||||
|
||||
if args.exclude:
|
||||
filtered_benchmarks = [b for b in filtered_benchmarks
|
||||
if not any(pattern in b['name'] for pattern in args.exclude)]
|
||||
logger.info(f"After exclusion: {[b['name'] for b in filtered_benchmarks]}")
|
||||
|
||||
if not filtered_benchmarks:
|
||||
logger.warning("No benchmarks remaining after filtering!")
|
||||
return 1
|
||||
|
||||
# Prepare common kwargs for benchmarks
|
||||
benchmark_kwargs = {
|
||||
'warmup_iterations': args.warmup_iterations,
|
||||
'measurement_iterations': args.measurement_iterations,
|
||||
'num_tokens_to_generate': args.num_tokens_to_generate
|
||||
}
|
||||
|
||||
if args.model_id:
|
||||
benchmark_kwargs['model_id'] = args.model_id
|
||||
|
||||
# Add enable_mock flag for mock benchmark
|
||||
benchmark_kwargs['enable_mock'] = args.enable_mock
|
||||
|
||||
# Add commit_id if provided
|
||||
if args.commit_id:
|
||||
benchmark_kwargs['commit_id'] = args.commit_id
|
||||
|
||||
# Run benchmarks
|
||||
benchmark_results = {}
|
||||
successful_count = 0
|
||||
|
||||
for benchmark_info in filtered_benchmarks:
|
||||
result = run_single_benchmark(
|
||||
benchmark_info,
|
||||
args.output_dir,
|
||||
logger,
|
||||
**benchmark_kwargs
|
||||
)
|
||||
|
||||
benchmark_results[benchmark_info['name']] = result
|
||||
|
||||
if result is not None:
|
||||
successful_count += 1
|
||||
|
||||
# Generate summary report
|
||||
summary_file = generate_summary_report(args.output_dir, benchmark_results, logger)
|
||||
|
||||
# Final summary
|
||||
total_benchmarks = len(filtered_benchmarks)
|
||||
failed_count = total_benchmarks - successful_count
|
||||
|
||||
logger.info("=" * 60)
|
||||
logger.info("BENCHMARK RUN SUMMARY")
|
||||
logger.info("=" * 60)
|
||||
logger.info(f"Total benchmarks: {total_benchmarks}")
|
||||
logger.info(f"Successful: {successful_count}")
|
||||
logger.info(f"Failed: {failed_count}")
|
||||
logger.info(f"Output directory: {args.output_dir}")
|
||||
logger.info(f"Summary report: {summary_file}")
|
||||
|
||||
if failed_count > 0:
|
||||
logger.warning(f"{failed_count} benchmark(s) failed. Check logs for details.")
|
||||
return 1
|
||||
else:
|
||||
logger.info("All benchmarks completed successfully!")
|
||||
return 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Benchmark run failed: {e}")
|
||||
import traceback
|
||||
logger.debug(traceback.format_exc())
|
||||
return 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
Reference in New Issue
Block a user