Mirror of https://github.com/huggingface/transformers.git (synced 2025-10-20 17:13:56 +08:00)
Benchmark overhaul (#41408)
* Big refactor, still classes to move around and script to re-complexify
* Move to streamer, isolate benches, propagate num tokens
* Some refacto
* Added compile mode to name
* Re-order
* Move to dt_tokens
* Better format
* Fix and disable use_cache by default
* Fixed compile and SDPA backend default
* Refactor results format
* Added default compile mode
* Always use cache
* Fixed cache and added flex
* Plan for missing modules
* Experiments: no cg and shuffle
* Disable compile for FA
* Remove wall time, add sweep mode, get git commit
* Review compliance, start
* Apply suggestions from code review
  Co-authored-by: Luc Georges <McPatate@users.noreply.github.com>
* Update benchmark_v2/framework/benchmark_runner.py
  Co-authored-by: Luc Georges <McPatate@users.noreply.github.com>
* Disable workflow
* Pretty print
* Added some pretty names to have pretty logs
* Review n2 compliance (end?)
* Style and end of PR

---------

Co-authored-by: Luc Georges <McPatate@users.noreply.github.com>

.github/workflows/benchmark.yml | 5 (vendored)

@@ -1,10 +1,7 @@
 name: Self-hosted runner (benchmark)

 on:
-  push:
-    branches: [main]
-  pull_request:
-    types: [ opened, labeled, reopened, synchronize ]
+  workflow_dispatch:

 concurrency:
   group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}

.github/workflows/benchmark_v2.yml | 32 (vendored)

@@ -1,35 +1,7 @@
 name: Benchmark v2 Framework

 on:
-  workflow_call:
-    inputs:
-      runner:
-        description: 'GH Actions runner group to use'
-        required: true
-        type: string
-      container_image:
-        description: 'Docker image to use'
-        required: true
-        type: string
-      container_options:
-        description: 'Container options to use'
-        required: true
-        type: string
-      commit_sha:
-        description: 'Commit SHA to benchmark'
-        required: false
-        type: string
-        default: ''
-      run_id:
-        description: 'Custom run ID for organizing results (auto-generated if not provided)'
-        required: false
-        type: string
-        default: ''
-      benchmark_repo_id:
-        description: 'HuggingFace Dataset to upload results to (e.g., "org/benchmark-results")'
-        required: false
-        type: string
-        default: ''
+  workflow_dispatch:

 env:
   HF_HOME: /mnt/cache

@@ -82,4 +54,4 @@ jobs:
           --token '${{ secrets.TRANSFORMERS_CI_RESULTS_UPLOAD_TOKEN }}' \
           --log-level INFO
         env:
           HF_TOKEN: ${{ secrets.HF_HUB_READ_TOKEN }}

@@ -1,11 +1,7 @@
 name: Benchmark v2 Scheduled Runner - A10 Single-GPU

 on:
-  schedule:
-    # Run daily at 16:30 UTC
-    - cron: "30 16 * * *"
-  pull_request:
-    types: [ opened, labeled, reopened, synchronize ]
+  workflow_dispatch:

 jobs:
   benchmark-v2-default:

@@ -18,4 +14,4 @@ jobs:
       commit_sha: ${{ github.sha }}
       run_id: ${{ github.run_id }}
       benchmark_repo_id: hf-internal-testing/transformers-daily-benchmarks
     secrets: inherit

@@ -1,11 +1,7 @@
 name: Benchmark v2 Scheduled Runner - MI325 Single-GPU

 on:
-  schedule:
-    # Run daily at 16:30 UTC
-    - cron: "30 16 * * *"
-  pull_request:
-    types: [ opened, labeled, reopened, synchronize ]
+  workflow_dispatch:

 jobs:
   benchmark-v2-default:

@@ -18,4 +14,4 @@ jobs:
       commit_sha: ${{ github.sha }}
       run_id: ${{ github.run_id }}
       benchmark_repo_id: hf-internal-testing/transformers-daily-benchmarks
     secrets: inherit

benchmark_v2/.gitignore | 3 (vendored)

@@ -1 +1,2 @@
 benchmark_results/
+benchmark_results_profiles/

(deleted file)

@@ -1 +0,0 @@
# Benchmark implementations directory

(deleted file)

@@ -1,165 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
from typing import Any

import torch
from benchmark_framework import ModelBenchmark


os.environ["TOKENIZERS_PARALLELISM"] = "1"
torch.set_float32_matmul_precision("high")


class LLaMABenchmark(ModelBenchmark):
    """Simplified LLaMA model benchmark implementation using the ModelBenchmark base class."""

    def __init__(self, logger: logging.Logger):
        super().__init__(logger)
        self._default_prompt = "Why dogs are so cute?"  # Custom prompt for LLaMA

    def get_scenario_configs(self) -> list[dict[str, Any]]:
        """
        Get LLaMA-specific scenario configurations.

        Returns:
            List of scenario configuration dictionaries
        """
        return [
            # Eager variants
            {"variant": "eager", "compile_mode": None, "use_cache": True, "description": "Eager execution with cache"},
            # Compiled variants
            {
                "variant": "compiled",
                "compile_mode": "max-autotune",
                "use_cache": True,
                "description": "Compiled with max autotune",
            },
            # Kernelized variant (if available)
            {
                "variant": "kernelized",
                "compile_mode": "max-autotune",
                "use_cache": True,
                "description": "Kernelized execution",
            },
        ]

    def _is_kernelization_available(self) -> bool:
        """Check if kernelization is available for LLaMA."""
        try:
            from kernels import Mode, kernelize  # noqa: F401

            return True
        except ImportError:
            self.logger.debug("Kernelization not available: kernels module not found")
            return False

    def get_default_generation_config(self) -> dict[str, Any]:
        """Get LLaMA-specific generation configuration."""
        return {
            "do_sample": False,
            "top_p": 1.0,
            "temperature": 1.0,
            "repetition_penalty": 1.0,
            "max_new_tokens": None,  # Will be set per scenario
        }

    def get_model_init_kwargs(self, config) -> dict[str, Any]:
        """Get LLaMA-specific model initialization kwargs."""
        return {
            "torch_dtype": getattr(torch, config.torch_dtype),
            "attn_implementation": config.attn_implementation,
            "use_cache": True,
        }

    def get_default_torch_dtype(self) -> str:
        """Get default torch dtype for LLaMA."""
        return "float16"  # LLaMA works well with float16

    def get_default_device(self) -> str:
        """Get default device for LLaMA."""
        return "cuda"  # LLaMA prefers CUDA


def run_llama(logger, output_dir, **kwargs):
    """
    Run LLaMA benchmark with the given configuration.

    Args:
        logger: Logger instance
        output_dir: Output directory for results
        **kwargs: Additional configuration options

    Returns:
        Path to output file if successful
    """
    from benchmark_framework import BenchmarkRunner

    # Extract parameters with defaults
    model_id = kwargs.get("model_id", "meta-llama/Llama-2-7b-hf")
    warmup_iterations = kwargs.get("warmup_iterations", 3)
    measurement_iterations = kwargs.get("measurement_iterations", 5)
    num_tokens_to_generate = kwargs.get("num_tokens_to_generate", 100)
    include_sdpa_variants = kwargs.get("include_sdpa_variants", True)
    device = kwargs.get("device", "cuda")
    torch_dtype = kwargs.get("torch_dtype", "float16")
    batch_size = kwargs.get("batch_size", 1)
    commit_id = kwargs.get("commit_id")

    logger.info(f"Starting LLaMA benchmark for model: {model_id}")
    logger.info(
        f"Configuration: warmup={warmup_iterations}, measurement={measurement_iterations}, tokens={num_tokens_to_generate}"
    )

    try:
        # Create benchmark instance
        benchmark = LLaMABenchmark(logger)

        # Create scenarios
        scenarios = benchmark.create_scenarios(
            model_id=model_id,
            warmup_iterations=warmup_iterations,
            measurement_iterations=measurement_iterations,
            num_tokens_to_generate=num_tokens_to_generate,
            include_sdpa_variants=include_sdpa_variants,
            device=device,
            torch_dtype=torch_dtype,
            batch_size=batch_size,
        )

        logger.info(f"Created {len(scenarios)} benchmark scenarios")

        # Create runner and execute benchmarks
        runner = BenchmarkRunner(logger, output_dir)
        results = runner.run_benchmark(benchmark, scenarios, commit_id=commit_id)

        if not results:
            logger.warning("No successful benchmark results")
            return None

        # Save results
        model_name = model_id.split("/")[-1]  # Extract model name from ID
        output_file = runner.save_results(model_name, results)

        logger.info(f"LLaMA benchmark completed successfully. Results saved to: {output_file}")
        return output_file

    except Exception as e:
        logger.error(f"LLaMA benchmark failed: {e}")
        import traceback

        logger.debug(traceback.format_exc())
        raise

(File diff suppressed because it is too large.)

benchmark_v2/framework/benchmark_config.py | 218 (new file)

@@ -0,0 +1,218 @@
import hashlib
import json
import logging
from typing import Any, Optional


KERNELIZATION_AVAILABLE = False
try:
    from kernels import Mode, kernelize  # noqa: F401

    KERNELIZATION_AVAILABLE = True
except ImportError:
    pass

logger = logging.getLogger(__name__)


class BenchmarkConfig:
    """Configuration for a single benchmark scenario."""

    def __init__(
        self,
        warmup_iterations: int = 5,
        measurement_iterations: int = 20,
        gpu_monitoring: bool = False,  # False by default because it slows down the benchmark by a lot
        batch_size: int = 1,
        sequence_length: int = 128,
        num_tokens_to_generate: int = 128,
        attn_implementation: str = "eager",
        sdpa_backend: Optional[str] = None,
        compile_mode: Optional[str] = None,
        compile_options: Optional[dict[str, Any]] = None,
        kernelize: bool = False,
        name: Optional[str] = None,
        skip_validity_check: bool = False,
    ) -> None:
        # Benchmark parameters
        self.warmup_iterations = warmup_iterations
        self.measurement_iterations = measurement_iterations
        self.gpu_monitoring = gpu_monitoring
        # Input parameters
        self.batch_size = batch_size
        self.sequence_length = sequence_length
        self.num_tokens_to_generate = num_tokens_to_generate
        # Generation parameters
        self.attn_implementation = attn_implementation
        self.sdpa_backend = sdpa_backend
        # Optimization parameters
        self.compile_mode = compile_mode
        self.compile_options = compile_options if compile_options is not None else {}
        self.kernelize = kernelize
        # Constant parameters
        self.dtype = "torch.bfloat16"
        self.device = "cuda"

        self.check_validity(skip_validity_check)
        self.name = name if name is not None else self.infer_name()

    def check_validity(self, skip_validity_check: bool = False) -> None:
        if skip_validity_check:
            return
        # Flash attention does not support compile mode, so we turn it off  # FIXME: it would be better to support it
        is_fa = self.attn_implementation == "flash_attention_2"
        is_fa |= self.attn_implementation == "sdpa" and self.sdpa_backend == "flash_attention"
        if is_fa:
            logger.warning("Flash attention does not support compile mode. Turning off compile mode.")
            self.compile_mode = None

    @property
    def hash(self) -> str:
        return hashlib.sha256(json.dumps(self.to_dict()).encode()).hexdigest()

    def infer_name(self, compact: bool = True) -> str:
        """Infer a human-readable name for the benchmark config, either compact or verbose."""
        if compact:
            iter_str = f"w{self.warmup_iterations}_i{self.measurement_iterations}"
            gpu_monitor_str = "monitored" if self.gpu_monitoring else "unmonitored"
            dimensions_str = f"b{self.batch_size}_s{self.sequence_length}_n{self.num_tokens_to_generate}"
            attn_code = self.attn_implementation
            attn_code += f"_{self.sdpa_backend}" if self.attn_implementation == "sdpa" else ""
            compile_str = f"compiled_{self.compile_mode}" if self.compile_mode is not None else "uncompiled"
            kernelize_str = "kernelized" if self.kernelize else "unkernelized"
            sep = "-"
        else:
            iter_str = f"{self.warmup_iterations} warmup, {self.measurement_iterations} iterations"
            gpu_monitor_str = ("with" if self.gpu_monitoring else "no") + " GPU monitoring"
            dimensions_str = f"batch size {self.batch_size}, sequence length {self.sequence_length}, {self.num_tokens_to_generate} generated tokens"
            attn_code = f"{self.attn_implementation} attention"
            attn_code += f" with {self.sdpa_backend} backend" if self.attn_implementation == "sdpa" else ""
            compile_str = "compiled" if self.compile_mode is not None else "not compiled"
            kernelize_str = "kernelized" if self.kernelize else "not kernelized"
            sep = ", "
        return sep.join([iter_str, gpu_monitor_str, dimensions_str, attn_code, compile_str, kernelize_str])

    def to_dict(self) -> dict[str, Any]:
        return {
            "name": self.name,
            "warmup_iterations": self.warmup_iterations,
            "measurement_iterations": self.measurement_iterations,
            "gpu_monitoring": self.gpu_monitoring,
            "batch_size": self.batch_size,
            "sequence_length": self.sequence_length,
            "num_tokens_to_generate": self.num_tokens_to_generate,
            "attn_implementation": self.attn_implementation,
            "sdpa_backend": self.sdpa_backend,
            "compile_mode": self.compile_mode,
            "compile_options": self.compile_options,
            "kernelize": self.kernelize,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any], skip_validity_check: bool = False) -> "BenchmarkConfig":
        return cls(
            warmup_iterations=data.get("warmup_iterations", 5),
            measurement_iterations=data.get("measurement_iterations", 20),
            gpu_monitoring=data.get("gpu_monitoring", False),
            batch_size=data.get("batch_size", 1),
            sequence_length=data.get("sequence_length", 128),
            num_tokens_to_generate=data.get("num_tokens_to_generate", 128),
            attn_implementation=data.get("attn_implementation", "eager"),
            sdpa_backend=data.get("sdpa_backend"),
            compile_mode=data.get("compile_mode"),
            compile_options=data.get("compile_options"),
            kernelize=data.get("kernelize", False),
            name=data.get("name"),
            skip_validity_check=skip_validity_check,
        )


def cross_generate_configs(
    attn_impl_and_sdpa_backend: list[tuple[str, Optional[str]]],
    compiled_mode: list[Optional[str]],
    kernelized: list[bool],
    warmup_iterations: int = 5,
    measurement_iterations: int = 20,
    batch_size: int = 1,
    sequence_length: int = 128,
    num_tokens_to_generate: int = 128,
    gpu_monitoring: bool = False,  # this slows down the benchmark by a lot so we disable it by default
) -> list[BenchmarkConfig]:
    # Create kwargs common to all configs
    kwargs = {
        "warmup_iterations": warmup_iterations,
        "measurement_iterations": measurement_iterations,
        "batch_size": batch_size,
        "sequence_length": sequence_length,
        "num_tokens_to_generate": num_tokens_to_generate,
        "gpu_monitoring": gpu_monitoring,
    }
    # Cross-generate all combinations of attn_implementation, compiled_mode, and kernelized
    configs = []
    for attn_implementation, sdpa_backend in list(dict.fromkeys(attn_impl_and_sdpa_backend)):
        for cm in list(dict.fromkeys(compiled_mode)):
            for kernelize_on in list(dict.fromkeys(kernelized)):
                config = BenchmarkConfig(
                    attn_implementation=attn_implementation,
                    sdpa_backend=sdpa_backend,
                    compile_mode=cm,
                    kernelize=kernelize_on,
                    **kwargs,
                )
                configs.append(config)
    return configs


def generate_all_configs(
    warmup_iterations: int = 5,
    measurement_iterations: int = 20,
    batch_size: int = 1,
    sequence_length: int = 128,
    num_tokens_to_generate: int = 128,
    gpu_monitoring: bool = False,
) -> list[BenchmarkConfig]:
    all_attn_implementations = [
        ("flash_attention_2", None),
        ("eager", None),
        ("sdpa", "math"),
        ("sdpa", "flash_attention"),
        ("flex_attention", None),
    ]
    return cross_generate_configs(
        attn_impl_and_sdpa_backend=all_attn_implementations,
        compiled_mode=[None, "default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"],
        kernelized=[False, KERNELIZATION_AVAILABLE],
        warmup_iterations=warmup_iterations,
        measurement_iterations=measurement_iterations,
        batch_size=batch_size,
        sequence_length=sequence_length,
        num_tokens_to_generate=num_tokens_to_generate,
        gpu_monitoring=gpu_monitoring,
    )


def generate_default_configs(
    warmup_iterations: int = 5,
    measurement_iterations: int = 20,
    batch_size: int = 1,
    sequence_length: int = 128,
    num_tokens_to_generate: int = 128,
    gpu_monitoring: bool = False,
) -> list[BenchmarkConfig]:
    all_attn_implementations = [
        ("flash_attention_2", None),
        ("eager", None),
        ("sdpa", "math"),
        ("sdpa", "flash_attention"),  # note: this one can fail with compile because of attn mask
    ]
    return cross_generate_configs(
        attn_impl_and_sdpa_backend=all_attn_implementations,
        compiled_mode=[None, "max-autotune"],
        kernelized=[False, KERNELIZATION_AVAILABLE],
        warmup_iterations=warmup_iterations,
        measurement_iterations=measurement_iterations,
        batch_size=batch_size,
        sequence_length=sequence_length,
        num_tokens_to_generate=num_tokens_to_generate,
        gpu_monitoring=gpu_monitoring,
    )
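
As an aside, not part of the diff above: a minimal sketch of how these helpers compose, assuming it is run from the benchmark_v2/ directory so that the framework package is importable (as the new run_benchmarks.py does below). cross_generate_configs takes the cross-product of attention implementation, compile mode and kernelization, and infer_name gives each combination a readable label.

# Illustrative only: build a small sweep of configurations with the helpers above.
from framework.benchmark_config import BenchmarkConfig, cross_generate_configs

configs = cross_generate_configs(
    attn_impl_and_sdpa_backend=[("eager", None), ("sdpa", "math")],
    compiled_mode=[None, "max-autotune"],
    kernelized=[False],
    warmup_iterations=2,
    measurement_iterations=5,
    num_tokens_to_generate=64,
)
for cfg in configs:
    # Compact names look like "w2_i5-unmonitored-b1_s128_n64-eager-uncompiled-unkernelized"
    print(cfg.name, cfg.hash[:8])

# A config round-trips through its dict form, which is what gets saved to JSON:
cfg = BenchmarkConfig(attn_implementation="sdpa", sdpa_backend="math", compile_mode="default")
assert BenchmarkConfig.from_dict(cfg.to_dict()).hash == cfg.hash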

benchmark_v2/framework/benchmark_runner.py | 388 (new file)

@@ -0,0 +1,388 @@
import gc
import json
import logging
import os
import pathlib
import re
import time
from contextlib import nullcontext
from datetime import datetime
from queue import Queue
from typing import Any, Optional

import torch
from tqdm import trange

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    CompileConfig,
    GenerationConfig,
    GenerationMixin,
)
from transformers.generation.streamers import BaseStreamer

from .benchmark_config import BenchmarkConfig
from .data_classes import BenchmarkMetadata, BenchmarkResult, GPURawMetrics, pretty_print_dict
from .hardware_metrics import GPUMonitor


try:
    from kernels import Mode, kernelize  # noqa: F401
except ImportError:
    kernelize = None
    Mode = None


DEFAULT_PROMPT = "\n".join([
    "The French Revolution was a period of political and societal change in France that began with the Estates General of 1789 and ended with the Coup of 18 Brumaire on 9 November 1799.",
    "Many of the revolution's ideas are considered fundamental principles of liberal democracy, and its values remain central to modern French political discourse.",
    "It was caused by a combination of social, political, and economic factors which the existing regime proved unable to manage.",
    "Financial crisis and widespread social distress led to the convocation of the Estates General in May 1789, its first meeting since 1614.",
    "The representatives of the Third Estate broke away and re-constituted themselves as a National Assembly in June.",
    "The Storming of the Bastille in Paris on 14 July led to a series of radical measures by the Assembly, including the abolition of feudalism, state control over the Catholic Church in France, and issuing the Declaration of the Rights of Man and of the Citizen.",
    "The next three years were dominated by a struggle for political control.",
    "King Louis XVI's attempted flight to Varennes in June 1791 further discredited the monarchy, and military defeats after the outbreak of the French Revolutionary Wars in April 1792 led to the insurrection of 10 August 1792.",
    "As a result, the monarchy was replaced by the French First Republic in September, followed by the execution of Louis XVI himself in January 1793.",
    "After another revolt in June 1793, the constitution was suspended, and political power passed from the National Convention to the Committee of Public Safety, dominated by radical Jacobins led by Maximilien Robespierre.",
    "About 16,000 people were sentenced by the Revolutionary Tribunal and executed in the Reign of Terror, which ended in July 1794 with the Thermidorian Reaction.",
    "Weakened by external threats and internal opposition, the Committee of Public Safety was replaced in November 1795 by the Directory.",
    "Its instability ended in the coup of 18 Brumaire and the establishment of the Consulate, with Napoleon Bonaparte as First Consul.",
])  # fmt: skip


def compact_json_numeric_arrays(data: dict):
    # Match arrays that contain only numbers (ints/floats), whitespace, commas, and newlines
    pattern = r"\[\s*\n\s*((?:\d+(?:\.\d+)?\s*,\s*)*\d+(?:\.\d+)?)\s*\n\s*\]"

    def replace_numeric_array(match):
        # Get the array content
        content = match.group(1)
        # Remove extra whitespace but keep commas
        compact_content = re.sub(r"\s+", " ", content).strip()
        return f"[{compact_content}]"

    return re.sub(pattern, replace_numeric_array, json.dumps(data, indent=4, default=str), flags=re.DOTALL)


def get_git_revision() -> str:
    base_path = pathlib.Path(__file__).parent.parent.parent
    git_dir = base_path / ".git"
    with (git_dir / "HEAD").open("r") as head:
        ref = head.readline().split(" ")[-1].strip()
    with (git_dir / ref).open("r") as git_hash:
        return git_hash.readline().strip()


def get_sdpa_backend(backend_name: Optional[str]) -> Optional[torch.nn.attention.SDPBackend]:
    """Get the SDPA backend enum from string name."""
    if backend_name is None:
        return None
    try:
        backend_map = {
            "math": torch.nn.attention.SDPBackend.MATH,
            "flash_attention": torch.nn.attention.SDPBackend.FLASH_ATTENTION,
            "efficient_attention": torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
            "cudnn_attention": torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
        }
        return backend_map.get(backend_name.lower())
    except AttributeError:
        # torch.nn.attention.SDPBackend not available in older torch versions
        return None


def flush_memory():
    """Flush GPU memory and run garbage collection."""
    gc.collect()
    # Dynamo resets
    torch._dynamo.reset()
    torch._dynamo.reset_code_caches()
    if hasattr(torch._inductor, "codecache"):
        # Clear FX graph cache
        if hasattr(torch._inductor.codecache, "FxGraphCache"):
            torch._inductor.codecache.FxGraphCache.clear()
        # Clear PyCodeCache
        if hasattr(torch._inductor.codecache, "PyCodeCache"):
            torch._inductor.codecache.PyCodeCache.cache_clear()
        # Clear TritonFuture cache (for async compilation)
        if hasattr(torch._inductor.codecache, "TritonFuture"):
            if hasattr(torch._inductor.codecache.TritonFuture, "_compile_cache"):
                torch._inductor.codecache.TritonFuture._compile_cache.clear()
    # Clear CUDA cache
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.reset_max_memory_allocated()
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.synchronize()
    gc.collect()


class BenchmarkStreamer(BaseStreamer):
    def __init__(self, **kwargs) -> None:
        self.timestamps = []
        self.text_queue = Queue()

    def put(self, value):
        """Receives tokens and logs the timestamp of the generation."""
        self.timestamps.append(time.perf_counter())

    def end(self):
        self.timestamps.append(time.perf_counter())

    def __iter__(self):
        return self

    def __next__(self):
        value = self.text_queue.get(timeout=self.timeout)
        if value == self.stop_signal:
            raise StopIteration()
        else:
            return value


class BenchmarkRunner:
    """Main benchmark runner that coordinates benchmark execution."""

    def __init__(
        self, logger: logging.Logger, output_dir: str = "benchmark_results", commit_id: Optional[str] = None
    ) -> None:
        # Those stay constant for the whole run
        self.logger = logger
        self.output_dir = output_dir
        self.commit_id = get_git_revision() if commit_id is None else commit_id
        os.makedirs(self.output_dir, exist_ok=True)
        self.profile_dir = None
        # Attributes that are reset for each model
        self._setup_for = ""
        # Attributes that are reset for each run
        self.model: Optional[GenerationMixin] = None

    def cleanup(self) -> None:
        del self.model
        self.model = None
        flush_memory()

    def setup_one_run(self, model_id: str, config: BenchmarkConfig) -> None:
        # Some attributes only need to be set once per model
        if self._setup_for != model_id:
            self.tokenizer = AutoTokenizer.from_pretrained(model_id)
            # We set the EOS token to the padding token for open-ended generation
            self.tokenizer.eos_token = self.tokenizer.pad_token
            self._setup_for = model_id

        # Prepare inputs
        self.inputs = self.tokenizer(
            [DEFAULT_PROMPT for _ in range(config.batch_size)],
            return_tensors="pt",
            max_length=config.sequence_length,
            truncation=True,
            return_attention_mask=True,
        ).to(config.device)
        self.inputs["use_cache"] = True

        # Prepare generation config
        gen_config = GenerationConfig(
            do_sample=False, top_p=1.0, temperature=1.0, max_new_tokens=config.num_tokens_to_generate
        )

        # Prepare compile config
        if config.compile_mode is not None:
            gen_config.compile_config = CompileConfig(mode=config.compile_mode, options=config.compile_options)
            gen_config.cache_implementation = "static"

        # Load model
        self.logger.debug(f"Loading model {model_id} on device {config.device}...")
        dtype = getattr(torch, config.dtype.removeprefix("torch."))
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id, dtype=dtype, attn_implementation=config.attn_implementation, generation_config=gen_config
        )
        self.model = self.model.eval().to(config.device)

        # Kernelize the model if needed
        if config.kernelize:
            self.model = kernelize(self.model, mode=Mode.INFERENCE)

    def run_one_benchmark(self, model_id: str, config: BenchmarkConfig, num_tokens_to_profile: int = 0) -> None:
        sdpa_ctx = nullcontext()
        if config.attn_implementation == "sdpa":
            sdpa_backend = get_sdpa_backend(config.sdpa_backend)
            sdpa_ctx = torch.nn.attention.sdpa_kernel(sdpa_backend)

        with sdpa_ctx, torch.no_grad():
            self.logger.info(f"Running benchmark scenario: {config.name}")

            # Quick validation: try one measurement first to see if this scenario works
            flush_memory()
            e2e_latency, token_generation_times, decoded_output, gpu_metrics = self.time_generate(
                max_new_tokens=1, gpu_monitor=None
            )
            if e2e_latency < 0:
                self.logger.warning(f"Skipping config {config.name}: {e2e_latency = } (no GPU monitoring)")
                return None

            # Warmup runs
            self.logger.info(f"Warming up with {config.warmup_iterations} iterations...")
            for _ in trange(config.warmup_iterations):
                _ = self.time_generate(max_new_tokens=config.num_tokens_to_generate)
            self.logger.info("Warmup over.")

            # Measurement runs
            result = BenchmarkResult()
            self.logger.info(f"Benchmarking with {config.measurement_iterations} iterations.")
            for _ in trange(config.measurement_iterations):
                e2e_latency, token_generation_times, decoded_output, gpu_metrics = self.time_generate(
                    max_new_tokens=config.num_tokens_to_generate,
                    gpu_monitor=(GPUMonitor(logger=self.logger) if config.gpu_monitoring else None),
                )
                result.accumulate(e2e_latency, token_generation_times, decoded_output, gpu_metrics)
            self.logger.info("Benchmarking done. Cleaning up.")

            # Profile if needed
            if num_tokens_to_profile > 0:
                self.profile_generate(num_tokens_to_profile, config.name)

        return {
            "metadata": BenchmarkMetadata(model_id=model_id, commit_id=self.commit_id),
            "measurements": result,
            "config": config,
        }

    def time_generate(
        self,
        max_new_tokens: int,
        gpu_monitor: Optional[GPUMonitor] = None,
    ) -> tuple[float, list[float], str, Optional[GPURawMetrics]]:
        """Time the latency of a call to model.generate() with the given (inputs) and (max_new_tokens)."""
        # Prepare gpu monitoring if needed
        if gpu_monitor is not None:
            gpu_monitor.start()
        # Prepare streamer
        streamer = BenchmarkStreamer()
        # Generate and time
        wall_time_0 = time.perf_counter()
        outputs = self.model.generate(
            **self.inputs,
            max_new_tokens=max_new_tokens,
            streamer=streamer,
        )
        wall_time_1 = time.perf_counter()
        # Stop gpu monitoring if needed
        gpu_metrics = gpu_monitor.stop_and_collect() if gpu_monitor is not None else None
        # Check if generation had the right number of tokens
        input_tokens = self.inputs["input_ids"].size(-1)
        batch_size, output_tokens = outputs.shape
        new_tokens = output_tokens - input_tokens
        if new_tokens != max_new_tokens:
            raise RuntimeError(f"Generated {new_tokens} tokens, expected {max_new_tokens}")
        # Decode outputs
        decoded_output = self.tokenizer.decode(outputs[0, input_tokens:], skip_special_tokens=True)
        # Compute intermediate quantities
        e2e_latency = wall_time_1 - wall_time_0
        token_generation_times = [t - wall_time_0 for t in streamer.timestamps[1:]]
        return e2e_latency, token_generation_times, decoded_output, gpu_metrics

    def profile_generate(self, num_tokens_to_profile: int, config_name: str) -> None:
        """Profile the latency of a call to model.generate() with the given (inputs) and (max_new_tokens)."""
        profiler = torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
            record_shapes=True,
        )
        with profiler as prof:
            _ = self.model.generate(
                **self.inputs,
                max_new_tokens=num_tokens_to_profile,
            )
        if self.profile_dir is None:
            self.profile_dir = self.output_dir + "_profiles"
            os.makedirs(self.profile_dir, exist_ok=True)
        prof.export_chrome_trace(f"{self.profile_dir}/{config_name}.json")

    def run_benchmarks(
        self,
        model_id: str,
        benchmark_configs: list[BenchmarkConfig],
        num_tokens_to_profile: int = 0,
        pretty_print_summary: bool = True,
    ) -> dict[str, Any]:
        all_results = {}
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        start_time = time.perf_counter()

        n_configs = len(benchmark_configs)
        for i, config in enumerate(benchmark_configs):
            # Handle SDPA backend if not determined by the config (needs to be done before skipping duplicates)
            if config.attn_implementation == "sdpa" and config.sdpa_backend is None:
                default_backend = "flash_attention"  # FIXME: torch has a _cur_sdpa_kernel_backends but it fails
                self.logger.warning(f"No SDPA backend provided, using {default_backend} instead.")
                config.sdpa_backend = default_backend

            # Skip if already run
            if config.hash in all_results:
                self.logger.info(f"Skipping duplicate config {config.name} for model {model_id} ({i + 1}/{n_configs})")
                continue

            # Otherwise, run the benchmark
            self.setup_one_run(model_id, config)
            self.logger.info(
                f"Running benchmark of model {model_id} with scenario: {config.name} ({i + 1}/{n_configs})"
            )

            # Launch benchmark in a try/except block to avoid stopping the whole run if one benchmark fails
            try:
                results = self.run_one_benchmark(model_id, config, num_tokens_to_profile)
                if results is not None:
                    all_results[config.hash] = results

            except Exception as e:
                self.logger.error(f"Error running with scenario: {config.name}:\n{repr(e)}")
            # Cleanup model and save results
            self.cleanup()
            self.save_results(model_id, all_results, timestamp=timestamp)

        if pretty_print_summary:
            print()
            print("=" * 100)
            print(f"Finished benchmarks in {time.perf_counter() - start_time:.2f} seconds")
            print(f"Total number of benchmarks: {len(all_results)}")
            if len(all_results) > 0:
                print("First run metadata:")
                first_key = list(all_results.keys())[0]
                first_metadata = all_results[first_key]["metadata"].to_dict()
                hardware_info = first_metadata.pop("hardware_info")
                pretty_print_dict(first_metadata | hardware_info, tabs=1)
            for value in all_results.values():
                print("=" * 100)
                print(f"Config: {value['config'].infer_name(compact=False)}\n")
                value["measurements"].pprint(tabs=1)
            print("=" * 100)

        return all_results

    def save_results(self, model_name: str, results: dict, timestamp: str = "") -> str:
        """Save benchmark results to JSON file."""
        # Create model-specific subdirectory
        model_name = model_name.replace("/", "_")
        model_dir = os.path.join(self.output_dir, model_name)
        os.makedirs(model_dir, exist_ok=True)

        # Create filename with timestamp
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") if not timestamp else timestamp
        filename = f"{model_name}_benchmark_{timestamp}.json"
        filepath = os.path.join(model_dir, filename)

        # Convert results to dict
        converted_results = {}
        for cfg_hash in results.keys():
            converted_results[cfg_hash] = {
                "metadata": results[cfg_hash]["metadata"].to_dict(),
                "measurements": results[cfg_hash]["measurements"].to_dict(),
                "config": results[cfg_hash]["config"].to_dict(),
            }

        # Save to JSON file
        with open(filepath, "w") as f:
            f.write(compact_json_numeric_arrays(converted_results))

        self.logger.info(f"Results saved to {filepath}")
        return filepath
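
For orientation only, not part of the commit: a sketch of a driver that ties the config and runner modules together. It assumes a CUDA GPU, the benchmark_v2 requirements, and access to some causal LM checkpoint; the model id below is just an example, and the real entry point is the updated run_benchmarks.py shown at the end of this diff.

# Illustrative only: a minimal driver for the new framework.
import logging

from framework.benchmark_config import generate_default_configs
from framework.benchmark_runner import BenchmarkRunner

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("benchmark_v2")

configs = generate_default_configs(warmup_iterations=2, measurement_iterations=5, num_tokens_to_generate=64)
runner = BenchmarkRunner(logger, output_dir="benchmark_results")
results = runner.run_benchmarks(
    model_id="meta-llama/Llama-2-7b-hf",  # example checkpoint, not prescribed by the framework
    benchmark_configs=configs,
    num_tokens_to_profile=0,  # set > 0 to also export a Chrome trace per config
    pretty_print_summary=True,
)
# Results are keyed by config hash and have already been written by save_results()
# to benchmark_results/<model>/<model>_benchmark_<timestamp>.json during the run.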

benchmark_v2/framework/data_classes.py | 152 (new file)

@@ -0,0 +1,152 @@
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Optional, Union

import numpy as np

from .hardware_metrics import GPURawMetrics, HardwareInfo


def compute_basic_statistics(measurements: list[float]) -> dict[str, float]:
    return {
        "avg": np.mean(measurements),
        "std": np.std(measurements),
        "min": np.min(measurements),
        "med": np.median(measurements),
        "max": np.max(measurements),
        "p95": np.percentile(measurements, 95),
    }


def add_unit_to_duration(stats: dict[str, float]) -> dict[str, str]:
    for key in list(stats.keys()):
        value = stats[key]
        if value > 3600:
            stats[key] = f"{(value / 3600):.2f}hr"
        elif value > 60:
            stats[key] = f"{(value / 60):.2f}min"
        elif value > 1:
            stats[key] = f"{value:.2f}s"
        elif value > 1e-3:
            stats[key] = f"{(value * 1e3):.2f}ms"
        elif value > 1e-6:
            stats[key] = f"{(value * 1e6):.2f}us"
        else:
            stats[key] = f"{(value * 1e9):.2f}ns"
    return stats


def equalize_lengths_and_collate(stats: list[dict[str, str]]) -> list[str]:
    keys = ["avg", "std", "min", "med", "max", "p95"]
    for key in keys:
        max_length = max(len(stat[key]) for stat in stats)
        for stat in stats:
            stat[key] = stat[key].ljust(max_length, " ")
    return [" ".join([f"{key}={stat[key]}" for key in keys]) for stat in stats]


def pretty_print_dict(data: dict[str, Any], tabs: int = 0) -> None:
    max_key_length = max([len(key) for key in data.keys()])
    for key, value in data.items():
        tabs_str = " " * tabs
        padded_key = key.ljust(max_key_length + 1, ".")
        print(f"{tabs_str}{padded_key}: {value}")


@dataclass
class BenchmarkMetadata:
    """Metadata collected for each benchmark run."""

    model_id: str
    timestamp: str
    commit_id: str
    hardware_info: HardwareInfo

    def __init__(self, model_id: str, commit_id: str):
        self.model_id = model_id
        self.timestamp = datetime.utcnow().isoformat()
        self.commit_id = commit_id
        self.hardware_info = HardwareInfo()

    def to_dict(self) -> dict[str, Any]:
        return {
            "timestamp": self.timestamp,
            "commit_id": self.commit_id,
            "hardware_info": self.hardware_info.to_dict(),
        }


class BenchmarkResult:
    """Result from a series of benchmark runs."""

    def __init__(self) -> None:
        self.e2e_latency = []
        self.token_generation_times = []  # time at which each token was generated (relative to start of the generation)
        self.decoded_outputs = []
        self.gpu_metrics = []

    def accumulate(
        self,
        e2e_latency: float,
        token_generation_times: list[float],
        decoded_output: str,
        gpu_metrics: Optional[GPURawMetrics],
    ) -> None:
        self.e2e_latency.append(e2e_latency)
        self.token_generation_times.append(token_generation_times)
        self.decoded_outputs.append(decoded_output)
        self.gpu_metrics.append(gpu_metrics)

    def to_dict(self) -> dict[str, Union[None, int, float]]:
        # Save GPU metrics as None if it contains only None values
        if all(gm is None for gm in self.gpu_metrics):
            gpu_metrics = None
        else:
            gpu_metrics = [gm.to_dict() for gm in self.gpu_metrics]
        return {
            "e2e_latency": self.e2e_latency,
            "token_generation_times": self.token_generation_times,
            "decoded_outputs": self.decoded_outputs,
            "gpu_metrics": gpu_metrics,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Union[None, int, float]]) -> "BenchmarkResult":
        # Handle GPU metrics, which is saved as None if it contains only None values
        if data["gpu_metrics"] is None:
            gpu_metrics = [None for _ in range(len(data["e2e_latency"]))]
        else:
            gpu_metrics = [GPURawMetrics.from_dict(gm) for gm in data["gpu_metrics"]]
        # Create a new instance and accumulate the data
        new_instance = cls()
        for i in range(len(data["e2e_latency"])):
            new_instance.accumulate(
                e2e_latency=data["e2e_latency"][i],
                token_generation_times=data["token_generation_times"][i],
                decoded_output=data["decoded_output"][i],
                gpu_metrics=gpu_metrics[i],
            )
        return new_instance

    def get_measured_ttft(self) -> list[float]:
        return [dt[0] for dt in self.token_generation_times if len(dt) > 0]

    def get_measured_itl(self) -> list[float]:
        return [(dt[-1] - dt[0]) / (len(dt) - 1) for dt in self.token_generation_times if len(dt) > 1]

    def pprint(self, tabs: int = 0) -> None:
        collated_stats = equalize_lengths_and_collate(
            [
                add_unit_to_duration(compute_basic_statistics(self.e2e_latency)),
                add_unit_to_duration(compute_basic_statistics(self.get_measured_ttft())),
                add_unit_to_duration(compute_basic_statistics(self.get_measured_itl())),
            ]
        )
        pretty_print_dict(
            {
                "E2E Latency": collated_stats[0],
                "Time to First Token": collated_stats[1],
                "Inter-Token Latency": collated_stats[2],
            },
            tabs=tabs,
        )
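
To illustrate how these helpers behave (again, not part of the commit), a BenchmarkResult can be fed synthetic timings; no GPU is needed, only the benchmark_v2 requirements so the module and its imports resolve.

# Illustrative only: exercise BenchmarkResult with synthetic timings.
from framework.data_classes import BenchmarkResult, add_unit_to_duration, compute_basic_statistics

result = BenchmarkResult()
# Pretend two runs each generated 4 tokens; times are seconds since generation start.
result.accumulate(0.40, [0.10, 0.20, 0.30, 0.40], "some decoded text", gpu_metrics=None)
result.accumulate(0.44, [0.12, 0.22, 0.33, 0.44], "some decoded text", gpu_metrics=None)

print(result.get_measured_ttft())  # first timestamp of each run -> [0.10, 0.12]
print(result.get_measured_itl())   # (last - first) / (n - 1) per run -> [0.10, ~0.107]
print(add_unit_to_duration(compute_basic_statistics(result.e2e_latency)))  # e.g. {'avg': '420.00ms', ...}
result.pprint(tabs=1)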

benchmark_v2/framework/hardware_metrics.py | 172 (new file)

@@ -0,0 +1,172 @@
import json
import logging
import subprocess
import sys
import threading
import time
from dataclasses import dataclass
from enum import Enum
from logging import Logger
from typing import Optional, Union

import gpustat
import psutil
import torch


# Data class to hold the hardware information
def get_device_name_and_memory_total() -> tuple[str, float]:
    """Returns the name and memory total of GPU 0."""
    device_name = torch.cuda.get_device_properties(0).name
    device_memory_total = torch.cuda.get_device_properties(0).total_memory / 1024**3
    return device_name, device_memory_total


class HardwareInfo:
    """A class to hold information about the hardware."""

    def __init__(self) -> None:
        # Retrieve GPU stats
        try:
            self.gpu_name, self.gpu_memory_total_gb = get_device_name_and_memory_total()
        except Exception:
            self.gpu_name, self.gpu_memory_total_gb = None, None
        # Retrieve python, torch and CUDA version
        self.python_version = f"{sys.version.split()[0]}"
        self.torch_version = torch.__version__
        if hasattr(torch, "cuda") and torch.cuda.is_available():
            self.cuda_version = torch.version.cuda
        else:
            self.cuda_version = None
        # Retrieve general hardware information
        self.cpu_count = psutil.cpu_count()
        self.memory_total_mb = int(psutil.virtual_memory().total / (1024 * 1024))

    def to_dict(self) -> dict[str, Union[None, int, float, str]]:
        return {
            "gpu_name": self.gpu_name,
            "gpu_memory_total_gb": self.gpu_memory_total_gb,
            "python_version": self.python_version,
            "torch_version": self.torch_version,
        }


# Functions to get information about the GPU
def get_amd_gpu_stats() -> tuple[int, float]:
    """Returns the utilization and memory used of an AMD GPU, both in percent"""
    rocm_smi_output = subprocess.check_output(["rocm-smi", "--json", "--showuse", "--showmeminfo", "VRAM"])
    gpu_stats = json.loads(rocm_smi_output.decode("utf-8"))
    gpu_stats = [
        (card_id, stats["GPU use (%)"], stats["VRAM Total Used Memory (B)"]) for card_id, stats in gpu_stats.items()
    ]
    gpu_stats.sort(key=lambda x: x[1], reverse=True)
    return int(gpu_stats[0][1]), float(gpu_stats[0][2]) / 1024**3


def get_nvidia_gpu_stats() -> tuple[int, float]:
    """Returns the utilization and memory used of an NVIDIA GPU, both in percent"""
    gpu_stats = gpustat.GPUStatCollection.new_query()
    gpu_stats = gpu_stats[0]
    return int(gpu_stats["utilization.gpu"]), float(gpu_stats["memory.used"]) / 1024**3


class GPUStatsCollector:
    """A class to get statistics about the GPU. It serves as a wrapper that holds the GPU total memory and its name,
    which is used to call the right function to get the utilization and memory used."""

    def __init__(self) -> None:
        self.device_name, self.device_memory_total = get_device_name_and_memory_total()
        # Monkey patch the get_utilization_and_memory_used method based on the GPU type
        if "amd" in self.device_name.lower():
            self.get_utilization_and_memory_used = get_amd_gpu_stats
        elif "nvidia" in self.device_name.lower():
            self.get_utilization_and_memory_used = get_nvidia_gpu_stats
        else:
            raise RuntimeError(f"Unsupported GPU: {self.device_name}")

    def get_measurements(self) -> tuple[int, float]:
        """Get the utilization and memory used of the GPU, both in percent"""
        raise NotImplementedError("This method is meant to be monkey patched during __init__")


# Simple data classes to hold the raw GPU metrics
class GPUMonitoringStatus(Enum):
    """Status of GPU monitoring."""

    SUCCESS = "success"
    FAILED = "failed"
    NO_GPUS_AVAILABLE = "no_gpus_available"
    NO_SAMPLES_COLLECTED = "no_samples_collected"


@dataclass
class GPURawMetrics:
    """Raw values for GPU utilization and memory used."""

    utilization: list[float]  # in percent
    memory_used: list[float]  # in GB
    timestamps: list[float]  # in seconds
    timestamp_0: float  # in seconds
    monitoring_status: GPUMonitoringStatus

    def to_dict(self) -> dict[str, Union[None, int, float, str]]:
        return {
            "utilization": self.utilization,
            "memory_used": self.memory_used,
            "timestamps": self.timestamps,
            "timestamp_0": self.timestamp_0,
            "monitoring_status": self.monitoring_status.value,
        }


# Main class, used to monitor the GPU utilization during benchmark execution
class GPUMonitor:
    """Monitor GPU utilization during benchmark execution."""

    def __init__(self, sample_interval_sec: float = 0.1, logger: Optional[Logger] = None):
        self.sample_interval_sec = sample_interval_sec
        self.logger = logger if logger is not None else logging.getLogger(__name__)

        self.num_available_gpus = torch.cuda.device_count()
        if self.num_available_gpus == 0:
            raise RuntimeError("No GPUs detected by torch.cuda.device_count().")
        self.gpu_stats_getter = GPUStatsCollector()

    def start(self):
        """Start monitoring GPU metrics."""
        # Clear the stop event to enable monitoring
        self.stop_event = threading.Event()
        self.gpu_utilization = []
        self.gpu_memory_used = []
        self.timestamps = []
        self.thread = threading.Thread(target=self._monitor_loop)
        self.thread.start()
        self.logger.debug("GPU monitoring started")

    def stop_and_collect(self) -> GPURawMetrics:
        """Stop monitoring and return collected metrics."""
        self.stop_event.set()
        self.thread.join()
        if self.gpu_utilization:
            timestamp_0 = self.timestamps[0]
            metrics = GPURawMetrics(
                utilization=self.gpu_utilization,
                memory_used=self.gpu_memory_used,
                timestamps=[t - timestamp_0 for t in self.timestamps],
                timestamp_0=timestamp_0,
                monitoring_status=GPUMonitoringStatus.SUCCESS,
            )
            self.logger.debug(f"GPU monitoring completed: {len(self.gpu_utilization)} samples collected")
        else:
            metrics = GPURawMetrics(monitoring_status=GPUMonitoringStatus.NO_SAMPLES_COLLECTED)
        return metrics

    def _monitor_loop(self):
        """Background monitoring loop using threading.Event for communication."""
        while not self.stop_event.is_set():
            utilization, memory_used = self.gpu_stats_getter.get_utilization_and_memory_used()
            self.gpu_utilization.append(utilization)
            self.gpu_memory_used.append(memory_used)
            self.timestamps.append(time.time())
            if self.stop_event.wait(timeout=self.sample_interval_sec):
                break
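
A small usage sketch, not from the commit, of the monitoring loop above: start the sampler, run an arbitrary GPU workload, then collect the raw metrics. It assumes a visible NVIDIA or AMD GPU (the GPUStatsCollector constructor raises otherwise).

# Illustrative only: sample GPU utilization and memory around a throwaway workload.
import time

import torch

from framework.hardware_metrics import GPUMonitor

monitor = GPUMonitor(sample_interval_sec=0.05)  # sample every 50 ms in a background thread
monitor.start()

x = torch.randn(4096, 4096, device="cuda")
for _ in range(50):
    _ = x @ x  # stand-in workload to keep the GPU busy
torch.cuda.synchronize()
time.sleep(0.2)  # let the sampler take a few more readings

metrics = monitor.stop_and_collect()
print(metrics.monitoring_status, f"{len(metrics.utilization)} samples")
print(metrics.memory_used[:5])  # memory used (GB) for the first few samples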
@ -19,477 +19,93 @@ in the ./benches directory, organizing outputs into model-specific subfolders.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import importlib.util
|
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import random
|
||||||
import sys
|
import sys
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
|
||||||
from pathlib import Path
|
from framework.benchmark_config import BenchmarkConfig, generate_all_configs
|
||||||
from typing import Any, Optional
|
from framework.benchmark_runner import BenchmarkRunner
|
||||||
|
|
||||||
|
|
||||||
-def setup_logging(log_level: str = "INFO", enable_file_logging: bool = False) -> logging.Logger:
-    """Setup logging configuration."""
-    numeric_level = getattr(logging, log_level.upper(), None)
-    if not isinstance(numeric_level, int):
-        raise ValueError(f"Invalid log level: {log_level}")
-
-    handlers = [logging.StreamHandler(sys.stdout)]
-    if enable_file_logging:
-        handlers.append(logging.FileHandler(f"benchmark_run_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"))
-
-    logging.basicConfig(
-        level=numeric_level, format="[%(levelname)s - %(asctime)s] %(name)s: %(message)s", handlers=handlers
-    )
-
-    return logging.getLogger(__name__)
+if __name__ == "__main__":
+    # Parse arguments
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--output-dir", type=str, default="benchmark_results", help="Output dir for benchmark results")
+    parser.add_argument("--log-level", type=str, choices=["DEBUG", "INFO", "WARNING", "ERROR"], default="INFO")
+    parser.add_argument("--model-id", type=str, help="Specific model ID to benchmark (if supported by benchmarks)")
+
+    parser.add_argument("--warmup", type=int, default=5, help="Number of warmup iterations")
+    parser.add_argument("--iterations", type=int, default=20, help="Number of measurement iterations")
+
+    parser.add_argument("--batch-size", "-b", type=int, nargs="+", help="Batch size")
+    parser.add_argument("--sequence-length", "-s", type=int, nargs="+", help="Sequence length")
+    parser.add_argument("--num-tokens-to-generate", "-n", type=int, nargs="+", help="Number of tokens to generate")
+
+    parser.add_argument("--num-tokens-to-profile", "-p", type=int, default=0, help="Number of tokens to profile")
+
+    parser.add_argument("--commit-id", type=str, help="Git commit ID (if not provided, will auto-detect from git)")
+    args = parser.parse_args()
+
+    # Setup logging
+    benchmark_run_uuid = str(uuid.uuid4())[:8]
+    numeric_level = getattr(logging, args.log_level.upper())
+
+    handlers = [logging.StreamHandler(sys.stdout)]
+    logging.basicConfig(
+        level=numeric_level, format="[%(levelname)s - %(asctime)s] %(name)s: %(message)s", handlers=handlers
+    )
+    logger = logging.getLogger("benchmark_v2")

-def discover_benchmarks(benches_dir: str) -> list[dict[str, Any]]:
-    """
-    Discover all benchmark modules in the benches directory.
-
-    Returns:
-        List of dictionaries containing benchmark module info
-    """
-    benchmarks = []
-    benches_path = Path(benches_dir)
-
-    if not benches_path.exists():
-        raise FileNotFoundError(f"Benches directory not found: {benches_dir}")
-
-    for py_file in benches_path.glob("*.py"):
-        if py_file.name.startswith("__"):
-            continue
-
-        module_name = py_file.stem
-
-        try:
-            # Import the module
-            spec = importlib.util.spec_from_file_location(module_name, py_file)
-            module = importlib.util.module_from_spec(spec)
-            spec.loader.exec_module(module)
-
-            # Check if it has a benchmark runner function
-            if hasattr(module, f"run_{module_name}"):
-                benchmarks.append(
-                    {
-                        "name": module_name,
-                        "path": str(py_file),
-                        "module": module,
-                        "runner_function": getattr(module, f"run_{module_name}"),
-                    }
-                )
-            elif hasattr(module, "run_benchmark"):
-                benchmarks.append(
-                    {
-                        "name": module_name,
-                        "path": str(py_file),
-                        "module": module,
-                        "runner_function": getattr(module, "run_benchmark"),
-                    }
-                )
-            else:
-                logging.warning(f"No runner function found in {py_file}")
-
-        except Exception as e:
-            logging.error(f"Failed to import {py_file}: {e}")
-
-    return benchmarks
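The removed discovery helper relied on a naming convention: a module in ./benches was picked up if it exposed either a run_<module_name> function or a generic run_benchmark function. A minimal bench module satisfying that convention might have looked like the sketch below; the file name and body are hypothetical and only illustrate the convention.

# benches/llama.py (hypothetical example of the old discovery convention)
import logging


def run_llama(logger: logging.Logger, output_dir: str, **kwargs) -> str:
    """Run the benchmark and return the path of the file holding its results."""
    logger.info(f"Running with kwargs: {kwargs}")
    output_file = f"{output_dir}/llama_results.json"
    # ... measure here, then write the results to output_file ...
    return output_file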
-def run_single_benchmark(
-    benchmark_info: dict[str, Any], output_dir: str, logger: logging.Logger, **kwargs
-) -> Optional[str]:
-    """
-    Run a single benchmark and return the output file path.
-
-    Args:
-        benchmark_info: Dictionary containing benchmark module info
-        output_dir: Base output directory
-        logger: Logger instance
-        **kwargs: Additional arguments to pass to the benchmark
-
-    Returns:
-        Path to the output file if successful, None otherwise
-    """
-    benchmark_name = benchmark_info["name"]
-    runner_func = benchmark_info["runner_function"]
-
-    logger.info(f"Running benchmark: {benchmark_name}")
-
-    try:
-        # Check function signature to determine what arguments to pass
-        import inspect
-
-        sig = inspect.signature(runner_func)
-
-        # Prepare arguments based on function signature
-        func_kwargs = {"logger": logger, "output_dir": output_dir}
-
-        # Add other kwargs if the function accepts them
-        for param_name in sig.parameters:
-            if param_name in kwargs:
-                func_kwargs[param_name] = kwargs[param_name]
-
-        # Filter kwargs to only include parameters the function accepts
-        # If function has **kwargs, include all provided kwargs
-        has_var_kwargs = any(param.kind == param.VAR_KEYWORD for param in sig.parameters.values())
-        if has_var_kwargs:
-            valid_kwargs = {**func_kwargs, **kwargs}
-        else:
-            valid_kwargs = {k: v for k, v in func_kwargs.items() if k in sig.parameters}
-
-        # Run the benchmark
-        result = runner_func(**valid_kwargs)
-
-        if isinstance(result, str):
-            # Function returned a file path
-            return result
-        else:
-            logger.info(f"Benchmark {benchmark_name} completed successfully")
-            return "completed"
-
-    except Exception as e:
-        logger.error(f"Benchmark {benchmark_name} failed: {e}")
-        import traceback
-
-        logger.debug(traceback.format_exc())
-        return None
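The kwargs-filtering logic in the removed runner is a generic pattern: inspect the callable's signature and forward only the keyword arguments it can accept, or everything if it takes **kwargs. A self-contained sketch of the same idea, with hypothetical function names, is shown below.

# Standalone sketch of the signature-filtering idea (not part of the diff):
import inspect


def call_with_supported_kwargs(func, **kwargs):
    """Call func with only the keyword arguments its signature accepts."""
    sig = inspect.signature(func)
    accepts_var_kwargs = any(p.kind == p.VAR_KEYWORD for p in sig.parameters.values())
    if accepts_var_kwargs:
        return func(**kwargs)
    return func(**{k: v for k, v in kwargs.items() if k in sig.parameters})


def example_runner(model_id, warmup_iterations=3):
    return (model_id, warmup_iterations)


print(call_with_supported_kwargs(example_runner, model_id="gpt2", warmup_iterations=5, unused_flag=True))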
-def generate_summary_report(
-    output_dir: str,
-    benchmark_results: dict[str, Any],
-    logger: logging.Logger,
-    benchmark_run_uuid: Optional[str] = None,
-) -> str:
-    """Generate a summary report of all benchmark runs."""
-    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-    summary_file = os.path.join(output_dir, f"benchmark_summary_{timestamp}.json")
-
-    summary_data = {
-        "run_metadata": {
-            "timestamp": datetime.utcnow().isoformat(),
-            "benchmark_run_uuid": benchmark_run_uuid,
-            "total_benchmarks": len(benchmark_results),
-            "successful_benchmarks": len([r for r in benchmark_results.values() if r is not None]),
-            "failed_benchmarks": len([r for r in benchmark_results.values() if r is None]),
-        },
-        "benchmark_results": benchmark_results,
-        "output_directory": output_dir,
-    }
-
-    with open(summary_file, "w") as f:
-        json.dump(summary_data, f, indent=2, default=str)
-
-    logger.info(f"Summary report saved to: {summary_file}")
-    return summary_file
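For reference, the summary written by this removed helper was a small JSON document shaped roughly as sketched below; all values are placeholders chosen to illustrate the schema, not real results.

# Illustrative shape of benchmark_summary_<timestamp>.json (placeholder values only):
# {
#   "run_metadata": {
#     "timestamp": "2025-01-01T00:00:00",
#     "benchmark_run_uuid": "ab12cd34",
#     "total_benchmarks": 2,
#     "successful_benchmarks": 1,
#     "failed_benchmarks": 1
#   },
#   "benchmark_results": {"bench_a": "completed", "bench_b": null},
#   "output_directory": "benchmark_results"
# }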
-def upload_results_to_hf_dataset(
-    output_dir: str,
-    summary_file: str,
-    dataset_name: str,
-    run_id: Optional[str] = None,
-    token: Optional[str] = None,
-    logger: Optional[logging.Logger] = None,
-) -> Optional[str]:
-    """
-    Upload benchmark results to a HuggingFace Dataset.
-    Based on upload_collated_report() from utils/collated_reports.py
-    Args:
-        output_dir: Local output directory containing results
-        summary_file: Path to the summary file
-        dataset_name: Name of the HuggingFace dataset to upload to
-        run_id: Unique run identifier (if None, will generate one)
-        token: HuggingFace token for authentication (if None, will use environment variables)
-        logger: Logger instance
-    Returns:
-        The run_id used for the upload, None if upload failed
-    """
-    if logger is None:
-        logger = logging.getLogger(__name__)
-
-    import os
-
-    from huggingface_hub import HfApi
-
-    api = HfApi()
-
-    if run_id is None:
-        github_run_number = os.getenv("GITHUB_RUN_NUMBER")
-        github_run_id = os.getenv("GITHUB_RUN_ID")
-        if github_run_number and github_run_id:
-            run_id = f"{github_run_number}-{github_run_id}"
-
-    date_folder = datetime.now().strftime("%Y-%m-%d")
-
-    github_event_name = os.getenv("GITHUB_EVENT_NAME")
-    if github_event_name != "schedule":
-        # Non-scheduled runs go under a runs subfolder
-        repo_path = f"{date_folder}/runs/{run_id}/benchmark_results"
-    else:
-        # Scheduled runs go directly under the date
-        repo_path = f"{date_folder}/{run_id}/benchmark_results"
-
-    logger.info(f"Uploading benchmark results to dataset '{dataset_name}' at path '{repo_path}'")
-
-    try:
-        # Upload all files in the output directory
-        from pathlib import Path
-
-        output_path = Path(output_dir)
-
-        for file_path in output_path.rglob("*"):
-            if file_path.is_file():
-                # Calculate relative path from output_dir
-                relative_path = file_path.relative_to(output_path)
-                path_in_repo = f"{repo_path}/{relative_path}"
-
-                logger.debug(f"Uploading {file_path} to {path_in_repo}")
-
-                api.upload_file(
-                    path_or_fileobj=str(file_path),
-                    path_in_repo=path_in_repo,
-                    repo_id=dataset_name,
-                    repo_type="dataset",
-                    token=token,
-                    commit_message=f"Upload benchmark results for run {run_id}",
-                )
-
-        logger.info(
-            f"Successfully uploaded results to: https://huggingface.co/datasets/{dataset_name}/tree/main/{repo_path}"
-        )
-
-        return run_id
-
-    except Exception as upload_error:
-        logger.error(f"Failed to upload results: {upload_error}")
-        import traceback
-
-        logger.debug(traceback.format_exc())
-        return None
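The only non-obvious behaviour of the removed uploader was its repository layout: scheduled CI runs landed directly under the date folder, while any other trigger was nested under a runs/ subfolder. A worked example of the two layouts, using placeholder values for the date and run ID:

# Placeholder values, shown only to illustrate the two layouts:
date_folder, run_id = "2025-01-01", "42-1234567890"
scheduled_path = f"{date_folder}/{run_id}/benchmark_results"    # "2025-01-01/42-1234567890/benchmark_results"
other_path = f"{date_folder}/runs/{run_id}/benchmark_results"   # "2025-01-01/runs/42-1234567890/benchmark_results"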
-def main():
-    """Main entry point for the benchmarking script."""
-    # Generate a unique UUID for this benchmark run
-    benchmark_run_uuid = str(uuid.uuid4())[:8]
-
-    parser = argparse.ArgumentParser(
-        description="Run all benchmarks in the ./benches directory",
-        epilog="""
-Examples:
-  # Run all available benchmarks
-  python3 run_benchmarks.py
-
-  # Run with specific model and upload to HuggingFace Dataset
-  python3 run_benchmarks.py --model-id meta-llama/Llama-2-7b-hf --upload-to-hf username/benchmark-results
-
-  # Run with custom run ID and upload to HuggingFace Dataset
-  python3 run_benchmarks.py --run-id experiment_v1 --upload-to-hf org/benchmarks
-
-  # Run only specific benchmarks with file logging
-  python3 run_benchmarks.py --include llama --enable-file-logging
-        """,  # noqa: W293
-        formatter_class=argparse.RawDescriptionHelpFormatter,
-    )
-
-    parser.add_argument(
-        "--output-dir",
-        type=str,
-        default="benchmark_results",
-        help="Base output directory for benchmark results (default: benchmark_results)",
-    )
-    parser.add_argument(
-        "--benches-dir",
-        type=str,
-        default="./benches",
-        help="Directory containing benchmark implementations (default: ./benches)",
-    )
-    parser.add_argument(
-        "--log-level",
-        type=str,
-        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
-        default="INFO",
-        help="Logging level (default: INFO)",
-    )
-    parser.add_argument("--model-id", type=str, help="Specific model ID to benchmark (if supported by benchmarks)")
-    parser.add_argument("--warmup-iterations", type=int, default=3, help="Number of warmup iterations (default: 3)")
-    parser.add_argument(
-        "--measurement-iterations", type=int, default=5, help="Number of measurement iterations (default: 5)"
-    )
-    parser.add_argument(
-        "--num-tokens-to-generate",
-        type=int,
-        default=100,
-        help="Number of tokens to generate in benchmarks (default: 100)",
-    )
-    parser.add_argument("--include", type=str, nargs="*", help="Only run benchmarks matching these names")
-    parser.add_argument("--exclude", type=str, nargs="*", help="Exclude benchmarks matching these names")
-    parser.add_argument("--enable-file-logging", action="store_true", help="Enable file logging (disabled by default)")
-    parser.add_argument(
-        "--commit-id", type=str, help="Git commit ID for metadata (if not provided, will auto-detect from git)"
-    )
-    parser.add_argument(
-        "--push-to-hub",
-        type=str,
-        help="Upload results to HuggingFace Dataset (provide dataset name, e.g., 'username/benchmark-results')",
-    )
-    parser.add_argument(
-        "--run-id", type=str, help="Custom run ID for organizing results (if not provided, will generate a unique ID)"
-    )
-    parser.add_argument(
-        "--token",
-        type=str,
-        help="HuggingFace token for dataset uploads (if not provided, will use HF_TOKEN environment variable)",
-    )
-
-    args = parser.parse_args()
-
-    # Setup logging
-    logger = setup_logging(args.log_level, args.enable_file_logging)
-
logger.info("Starting benchmark discovery and execution")
|
logger.info("Starting benchmark discovery and execution")
|
||||||
logger.info(f"Benchmark run UUID: {benchmark_run_uuid}")
|
logger.info(f"Benchmark run UUID: {benchmark_run_uuid}")
|
||||||
logger.info(f"Output directory: {args.output_dir}")
|
logger.info(f"Output directory: {args.output_dir}")
|
||||||
logger.info(f"Benches directory: {args.benches_dir}")
|
|
||||||
|
|
||||||
-    # Create output directory
-    os.makedirs(args.output_dir, exist_ok=True)
-
-    try:
-        # Discover benchmarks
-        benchmarks = discover_benchmarks(args.benches_dir)
-        logger.info(f"Discovered {len(benchmarks)} benchmark(s): {[b['name'] for b in benchmarks]}")
-
-        if not benchmarks:
-            logger.warning("No benchmarks found!")
-            return 1
-
-        # Filter benchmarks based on include/exclude
-        filtered_benchmarks = benchmarks
-
-        if args.include:
-            filtered_benchmarks = [
-                b for b in filtered_benchmarks if any(pattern in b["name"] for pattern in args.include)
-            ]
-            logger.info(f"Filtered to include: {[b['name'] for b in filtered_benchmarks]}")
-
-        if args.exclude:
-            filtered_benchmarks = [
-                b for b in filtered_benchmarks if not any(pattern in b["name"] for pattern in args.exclude)
-            ]
-            logger.info(f"After exclusion: {[b['name'] for b in filtered_benchmarks]}")
-
-        if not filtered_benchmarks:
-            logger.warning("No benchmarks remaining after filtering!")
-            return 1
-
-        # Prepare common kwargs for benchmarks
-        benchmark_kwargs = {
-            "warmup_iterations": args.warmup_iterations,
-            "measurement_iterations": args.measurement_iterations,
-            "num_tokens_to_generate": args.num_tokens_to_generate,
-        }
-
-        if args.model_id:
-            benchmark_kwargs["model_id"] = args.model_id
-
-        # Add commit_id if provided
-        if args.commit_id:
-            benchmark_kwargs["commit_id"] = args.commit_id
-
-        # Run benchmarks
-        benchmark_results = {}
-        successful_count = 0
-
-        for benchmark_info in filtered_benchmarks:
-            result = run_single_benchmark(benchmark_info, args.output_dir, logger, **benchmark_kwargs)
-
-            benchmark_results[benchmark_info["name"]] = result
-
-            if result is not None:
-                successful_count += 1
-
-        # Generate summary report
-        summary_file = generate_summary_report(args.output_dir, benchmark_results, logger, benchmark_run_uuid)
-
-        # Upload results to HuggingFace Dataset if requested
-        upload_run_id = None
-        if args.push_to_hub:
-            logger.info("=" * 60)
-            logger.info("UPLOADING TO HUGGINGFACE DATASET")
-            logger.info("=" * 60)
-            # Use provided run_id or fallback to benchmark run UUID
-            effective_run_id = args.run_id or benchmark_run_uuid
-            upload_run_id = upload_results_to_hf_dataset(
-                output_dir=args.output_dir,
-                summary_file=summary_file,
-                dataset_name=args.push_to_hub,
-                run_id=effective_run_id,
-                token=args.token,
-                logger=logger,
-            )
-            if upload_run_id:
-                logger.info(f"Upload completed with run ID: {upload_run_id}")
-            else:
-                logger.warning("Upload failed - continuing with local results")
-
-        # Final summary
-        total_benchmarks = len(filtered_benchmarks)
-        failed_count = total_benchmarks - successful_count
-
-        logger.info("=" * 60)
-        logger.info("BENCHMARK RUN SUMMARY")
-        logger.info("=" * 60)
-        logger.info(f"Total benchmarks: {total_benchmarks}")
-        logger.info(f"Successful: {successful_count}")
-        logger.info(f"Failed: {failed_count}")
-        logger.info(f"Output directory: {args.output_dir}")
-        logger.info(f"Summary report: {summary_file}")
-
-        if args.push_to_hub:
-            if upload_run_id:
-                logger.info(f"HuggingFace Dataset: {args.push_to_hub}")
-                logger.info(f"Run ID: {upload_run_id}")
-                logger.info(
-                    f"View results: https://huggingface.co/datasets/{args.push_to_hub}/tree/main/{datetime.now().strftime('%Y-%m-%d')}/runs/{upload_run_id}"
-                )
-            else:
-                logger.warning("Upload to HuggingFace Dataset failed")
-
-        if failed_count > 0:
-            logger.warning(f"{failed_count} benchmark(s) failed. Check logs for details.")
-            return 1
-        else:
-            logger.info("All benchmarks completed successfully!")
-            return 0
-
-    except Exception as e:
-        logger.error(f"Benchmark run failed: {e}")
-        import traceback
-
-        logger.debug(traceback.format_exc())
-        return 1
-
-
-if __name__ == "__main__":
-    sys.exit(main())
+    # Error out if one of the arguments is not provided
+    if len(args.batch_size) * len(args.sequence_length) * len(args.num_tokens_to_generate) == 0:
+        raise ValueError(
+            "At least one of the arguments --batch-size, --sequence-length, or --num-tokens-to-generate is required"
+        )
+
+    # If there is only one (batch_size, sequence_length, num_tokens_to_generate), we benchmark across configs
+    elif len(args.batch_size) * len(args.sequence_length) * len(args.num_tokens_to_generate) == 1:
+        benchmark_configs = generate_all_configs(
+            warmup_iterations=args.warmup,
+            measurement_iterations=args.iterations,
+            batch_size=args.batch_size[0],
+            sequence_length=args.sequence_length[0],
+            num_tokens_to_generate=args.num_tokens_to_generate[0],
+        )
+        random.shuffle(benchmark_configs)
+
+    # Otherwise, we benchmark across all combinations of dimensions
+    else:
+        kwargs = {
+            "warmup_iterations": args.warmup,
+            "measurement_iterations": args.iterations,
+            "gpu_monitoring": False,
+            "batch_size": args.batch_size[0],
+            "sequence_length": args.sequence_length[0],
+            "num_tokens_to_generate": args.num_tokens_to_generate[0],
+            "attn_implementation": "flex_attention",
+            "sdpa_backend": None,
+            "compile_mode": "default",
+            "kernelize": False,
+        }
+        benchmark_configs = []
+        for num_tokens_to_generate in args.num_tokens_to_generate:
+            for sequence_length in args.sequence_length:
+                for batch_size in args.batch_size:
+                    kwargs["batch_size"] = batch_size
+                    kwargs["sequence_length"] = sequence_length
+                    kwargs["num_tokens_to_generate"] = num_tokens_to_generate
+                    benchmark_configs.append(BenchmarkConfig(**kwargs))
+
+    runner = BenchmarkRunner(logger, args.output_dir, args.commit_id)
+    results = runner.run_benchmarks(
+        args.model_id,
+        benchmark_configs[:3],
+        args.num_tokens_to_profile,
+        pretty_print_summary=True,
+    )
+    # runner.save_results(args.model_id, results)
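In the new entry point, passing several values to -b, -s and -n expands into one BenchmarkConfig per combination, i.e. a cross product of the three lists; the nested loops above are equivalent to the short sketch below (an illustration with made-up values, not code from the diff).

# Equivalent cross-product expansion (illustrative only):
from itertools import product

batch_sizes, sequence_lengths, num_tokens = [1, 4], [128, 512], [64]
combinations = list(product(num_tokens, sequence_lengths, batch_sizes))
# -> 4 combinations, one per (num_tokens_to_generate, sequence_length, batch_size) triple
print(len(combinations), combinations)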