Small changes to benchmarking script (#41662)

commit f7c33abab3 (parent 9839d57a02)
Author: Rémi Ouazan
Committed by: GitHub
Date: 2025-10-16 17:25:49 +02:00

4 changed files with 88 additions and 77 deletions

File 1 of 4 — framework.benchmark_config (BenchmarkConfig and the config generators)

@@ -104,7 +104,7 @@ class BenchmarkConfig:
             "attn_implementation": self.attn_implementation,
             "sdpa_backend": self.sdpa_backend,
             "compile_mode": self.compile_mode,
-            "compile_options": self.compile_options,
+            "compile_options": self.compile_options | {},  # to avoid inplace modification of the original dict
             "kernelize": self.kernelize,
         }
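
The `| {}` union above returns a new dict, so callers that mutate the serialized config cannot modify the original compile_options in place. A minimal sketch of the idiom, with hypothetical values:

compile_options = {"mode": "default"}
serialized = compile_options | {}  # shallow copy, equivalent to dict(compile_options) (Python 3.9+)
serialized["mode"] = "max-autotune"
print(compile_options["mode"])  # still "default"
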
@@ -191,7 +191,7 @@ def generate_all_configs(
     )


-def generate_default_configs(
+def generate_main_configs(
     warmup_iterations: int = 5,
     measurement_iterations: int = 20,
     batch_size: int = 1,
@@ -199,20 +199,17 @@ def generate_default_configs(
     num_tokens_to_generate: int = 128,
     gpu_monitoring: bool = False,
 ) -> list[BenchmarkConfig]:
-    all_attn_implementations = [
-        ("flash_attention_2", None),
-        ("eager", None),
-        ("sdpa", "math"),
-        ("sdpa", "flash_attention"),  # note: this one can fail with compile because of attn mask
-    ]
-    return cross_generate_configs(
-        attn_impl_and_sdpa_backend=all_attn_implementations,
-        compiled_mode=[None, "max-autotune"],
-        kernelized=[False, KERNELIZATION_AVAILABLE],
-        warmup_iterations=warmup_iterations,
-        measurement_iterations=measurement_iterations,
-        batch_size=batch_size,
-        sequence_length=sequence_length,
-        num_tokens_to_generate=num_tokens_to_generate,
-        gpu_monitoring=gpu_monitoring,
-    )
+    # Create kwargs common to all configs
+    kwargs = {
+        "warmup_iterations": warmup_iterations,
+        "measurement_iterations": measurement_iterations,
+        "batch_size": batch_size,
+        "sequence_length": sequence_length,
+        "num_tokens_to_generate": num_tokens_to_generate,
+        "gpu_monitoring": gpu_monitoring,
+    }
+    return [  # TODO: test max-autotune instead of default
+        BenchmarkConfig(attn_implementation="flex_attention", compile_mode="default", **kwargs),
+        BenchmarkConfig(attn_implementation="eager", compile_mode="default", **kwargs),
+        BenchmarkConfig(attn_implementation="flash_attention_2", **kwargs),
+    ]
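
generate_main_configs now returns three hand-picked configurations that share the same measurement settings, instead of the cross-product previously built by cross_generate_configs. A usage sketch, assuming the framework package is importable from the benchmark suite root:

from framework.benchmark_config import generate_main_configs

configs = generate_main_configs(batch_size=1, sequence_length=128, num_tokens_to_generate=128)
for cfg in configs:
    # expected variants: compiled flex_attention, compiled eager, uncompiled flash_attention_2
    print(cfg.infer_name(compact=False))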

File 2 of 4 — framework.benchmark_runner (BenchmarkRunner)

@@ -144,11 +144,11 @@ class BenchmarkStreamer(BaseStreamer):
 class BenchmarkRunner:
     """Main benchmark runner that coordinates benchmark execution."""

-    def __init__(
-        self, logger: logging.Logger, output_dir: str = "benchmark_results", commit_id: str | None = None
-    ) -> None:
+    def __init__(self, logger: logging.Logger, output_dir: str | None = None, commit_id: str | None = None) -> None:
         # Those stay constant for the whole run
         self.logger = logger
+        if output_dir is None:
+            output_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "benchmark_results")
         self.output_dir = output_dir
         self.commit_id = get_git_revision() if commit_id is None else commit_id
         os.makedirs(self.output_dir, exist_ok=True)
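
The default output directory is now resolved relative to the module location rather than the current working directory. A small sketch of how the two dirname() calls walk up to the suite root, using a hypothetical path:

import os

module_file = "/repo/benchmarks/framework/benchmark_runner.py"  # hypothetical location of this module
default_output_dir = os.path.join(os.path.dirname(os.path.dirname(module_file)), "benchmark_results")
print(default_output_dir)  # /repo/benchmarks/benchmark_results
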
@@ -214,7 +214,7 @@ class BenchmarkRunner:
         # Quick validation: try one measurement first to see if this scenario works
         flush_memory()
-        e2e_latency, token_generation_times, decoded_output, gpu_metrics = self.time_generate(
+        e2e_latency, token_generation_times, shape_and_decoded_output, gpu_metrics = self.time_generate(
             max_new_tokens=1, gpu_monitor=None
         )
         if e2e_latency < 0:
@@ -231,11 +231,11 @@ class BenchmarkRunner:
         result = BenchmarkResult()
         self.logger.info(f"Benchmarking with {config.measurement_iterations} iterations.")
         for _ in trange(config.measurement_iterations):
-            e2e_latency, token_generation_times, decoded_output, gpu_metrics = self.time_generate(
+            e2e_latency, token_generation_times, shape_and_decoded_output, gpu_metrics = self.time_generate(
                 max_new_tokens=config.num_tokens_to_generate,
                 gpu_monitor=(GPUMonitor(logger=self.logger) if config.gpu_monitoring else None),
             )
-            result.accumulate(e2e_latency, token_generation_times, decoded_output, gpu_metrics)
+            result.accumulate(e2e_latency, token_generation_times, shape_and_decoded_output, gpu_metrics)
         self.logger.info("Benchmarking done. Cleaning up.")

         # Profile if needed
@@ -277,10 +277,11 @@ class BenchmarkRunner:
             raise RuntimeError(f"Generated {new_tokens} tokens, expected {max_new_tokens}")
         # Decode outputs
         decoded_output = self.tokenizer.decode(outputs[0, input_tokens:], skip_special_tokens=True)
+        shape_and_decoded_output = f"{tuple(outputs.shape)} | {decoded_output}"
         # Compute intermediate quantities
         e2e_latency = wall_time_1 - wall_time_0
         token_generation_times = [t - wall_time_0 for t in streamer.timestamps[1:]]
-        return e2e_latency, token_generation_times, decoded_output, gpu_metrics
+        return e2e_latency, token_generation_times, shape_and_decoded_output, gpu_metrics

     def profile_generate(self, num_tokens_to_profile: int, config_name: str) -> None:
         """Profile the latency of a call to model.generate() with the given (inputs) and (max_new_tokens)."""
@@ -351,10 +352,10 @@ class BenchmarkRunner:
             first_metadata = all_results[first_key]["metadata"].to_dict()
             hardware_info = first_metadata.pop("hardware_info")
             pretty_print_dict(first_metadata | hardware_info, tabs=1)
-            for value in all_results.values():
+            for result in all_results.values():
                 print("=" * 100)
-                print(f"Config: {value['config'].infer_name(compact=False)}\n")
-                value["measurements"].pprint(tabs=1)
+                print(f"Config: {result['config'].infer_name(compact=False)}\n")
+                result["measurements"].pprint(batch_size=result["config"].batch_size, tabs=1)
                 print("=" * 100)
         return all_results

File 3 of 4 — the benchmark results module (BenchmarkResult)

@@ -82,19 +82,19 @@ class BenchmarkResult:
     def __init__(self) -> None:
         self.e2e_latency = []
         self.token_generation_times = []  # time at which each token was generated (relative to start of the generation)
-        self.decoded_outputs = []
+        self.shape_and_decoded_outputs = []
         self.gpu_metrics = []

     def accumulate(
         self,
         e2e_latency: float,
         token_generation_times: list[float],
-        decoded_output: str,
+        shape_and_decoded_output: str,
         gpu_metrics: GPURawMetrics | None,
     ) -> None:
         self.e2e_latency.append(e2e_latency)
         self.token_generation_times.append(token_generation_times)
-        self.decoded_outputs.append(decoded_output)
+        self.shape_and_decoded_outputs.append(shape_and_decoded_output)
         self.gpu_metrics.append(gpu_metrics)

     def to_dict(self) -> dict[str, None | int | float]:
@@ -106,7 +106,7 @@ class BenchmarkResult:
         return {
             "e2e_latency": self.e2e_latency,
             "token_generation_times": self.token_generation_times,
-            "decoded_outputs": self.decoded_outputs,
+            "shape_and_decoded_outputs": self.shape_and_decoded_outputs,
             "gpu_metrics": gpu_metrics,
         }
@@ -123,7 +123,7 @@ class BenchmarkResult:
             new_instance.accumulate(
                 e2e_latency=data["e2e_latency"][i],
                 token_generation_times=data["token_generation_times"][i],
-                decoded_output=data["decoded_output"][i],
+                shape_and_decoded_output=data["shape_and_decoded_outputs"][i],
                 gpu_metrics=gpu_metrics[i],
             )
         return new_instance
@@ -134,19 +134,27 @@ class BenchmarkResult:
     def get_measured_itl(self) -> list[float]:
         return [(dt[-1] - dt[0]) / (len(dt) - 1) for dt in self.token_generation_times if len(dt) > 1]

-    def pprint(self, tabs: int = 0) -> None:
-        collated_stats = equalize_lengths_and_collate(
-            [
-                add_unit_to_duration(compute_basic_statistics(self.e2e_latency)),
-                add_unit_to_duration(compute_basic_statistics(self.get_measured_ttft())),
-                add_unit_to_duration(compute_basic_statistics(self.get_measured_itl())),
-            ]
-        )
-        pretty_print_dict(
-            {
-                "E2E Latency": collated_stats[0],
-                "Time to First Token": collated_stats[1],
-                "Inter-Token Latency": collated_stats[2],
-            },
-            tabs=tabs,
-        )
+    def get_throughput(self, batch_size: int) -> list[float]:
+        return [
+            batch_size * len(dt) / e2e_latency
+            for e2e_latency, dt in zip(self.e2e_latency, self.token_generation_times)
+        ]
+
+    def pprint(self, batch_size: int = 0, tabs: int = 0) -> None:
+        stats_to_collate = [
+            add_unit_to_duration(compute_basic_statistics(self.e2e_latency)),
+            add_unit_to_duration(compute_basic_statistics(self.get_measured_ttft())),
+            add_unit_to_duration(compute_basic_statistics(self.get_measured_itl())),
+        ]
+        if batch_size > 0:
+            throughput_stats = compute_basic_statistics(self.get_throughput(batch_size))
+            stats_to_collate.append({key: f"{value:.2f}tok/s" for key, value in throughput_stats.items()})
+        collated_stats = equalize_lengths_and_collate(stats_to_collate)
+        dict_to_pprint = {
+            "E2E Latency": collated_stats[0],
+            "Time to First Token": collated_stats[1],
+            "Inter-Token Latency": collated_stats[2],
+        }
+        if batch_size > 0:
+            dict_to_pprint["Throughput"] = collated_stats[3]
+        pretty_print_dict(dict_to_pprint, tabs=tabs)
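
The new throughput figure is batch_size * generated_tokens / end-to-end latency, reported in tokens per second and only shown when a positive batch_size is passed to pprint. A worked example with made-up numbers:

batch_size = 2
token_generation_times = [0.05 * i for i in range(1, 101)]  # 100 token timestamps per sequence (hypothetical)
e2e_latency = 5.0  # seconds (hypothetical)
throughput = batch_size * len(token_generation_times) / e2e_latency
print(f"{throughput:.2f}tok/s")  # 40.00tok/s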

File 4 of 4 — the top-level benchmark script (CLI parsing and config selection)

@@ -20,28 +20,28 @@ in the ./benches directory, organizing outputs into model-specific subfolders.
 import argparse
 import logging
-import random
 import sys
 import uuid

-from framework.benchmark_config import BenchmarkConfig, generate_all_configs
+from framework.benchmark_config import BenchmarkConfig, generate_all_configs, generate_main_configs
 from framework.benchmark_runner import BenchmarkRunner


 if __name__ == "__main__":
     # Parse arguments
     parser = argparse.ArgumentParser()
-    parser.add_argument("--output-dir", type=str, default="benchmark_results", help="Output dir for benchmark results")
+    parser.add_argument("--output-dir", type=str, default=None, help="Output dir for benchmark results")
     parser.add_argument("--log-level", type=str, choices=["DEBUG", "INFO", "WARNING", "ERROR"], default="INFO")
     parser.add_argument("--model-id", type=str, help="Specific model ID to benchmark (if supported by benchmarks)")

-    parser.add_argument("--warmup", type=int, default=5, help="Number of warmup iterations")
-    parser.add_argument("--iterations", type=int, default=20, help="Number of measurement iterations")
+    parser.add_argument("--warmup", type=int, default=3, help="Number of warmup iterations")
+    parser.add_argument("--iterations", type=int, default=10, help="Number of measurement iterations")

     parser.add_argument("--batch-size", "-b", type=int, nargs="+", help="Batch size")
     parser.add_argument("--sequence-length", "-s", type=int, nargs="+", help="Sequence length")
     parser.add_argument("--num-tokens-to-generate", "-n", type=int, nargs="+", help="Number of tokens to generate")

+    parser.add_argument("--cross-generate", action="store_true", help="Cross-generate all combinations of configs")
     parser.add_argument("--num-tokens-to-profile", "-p", type=int, default=0, help="Number of tokens to profile")
     parser.add_argument("--commit-id", type=str, help="Git commit ID (if not provided, will auto-detect from git)")
@@ -69,42 +69,47 @@ if __name__ == "__main__":
     # If there is only one (batch_size, sequence_length, num_tokens_to_generate), we benchmark across configs
     elif len(args.batch_size) * len(args.sequence_length) * len(args.num_tokens_to_generate) == 1:
-        benchmark_configs = generate_all_configs(
-            warmup_iterations=args.warmup,
-            measurement_iterations=args.iterations,
-            batch_size=args.batch_size[0],
-            sequence_length=args.sequence_length[0],
-            num_tokens_to_generate=args.num_tokens_to_generate[0],
-        )
-        random.shuffle(benchmark_configs)
+        if args.cross_generate:
+            benchmark_configs = generate_all_configs(
+                warmup_iterations=args.warmup,
+                measurement_iterations=args.iterations,
+                batch_size=args.batch_size[0],
+                sequence_length=args.sequence_length[0],
+                num_tokens_to_generate=args.num_tokens_to_generate[0],
+            )
+        else:
+            benchmark_configs = generate_main_configs(
+                warmup_iterations=args.warmup,
+                measurement_iterations=args.iterations,
+                batch_size=args.batch_size[0],
+                sequence_length=args.sequence_length[0],
+                num_tokens_to_generate=args.num_tokens_to_generate[0],
+            )

     # Otherwise, we benchmark across all combinations of dimensions
     else:
-        kwargs = {
-            "warmup_iterations": args.warmup,
-            "measurement_iterations": args.iterations,
-            "gpu_monitoring": False,
-            "batch_size": args.batch_size[0],
-            "sequence_length": args.sequence_length[0],
-            "num_tokens_to_generate": args.num_tokens_to_generate[0],
-            "attn_implementation": "flex_attention",
-            "sdpa_backend": None,
-            "compile_mode": "default",
-            "kernelize": False,
-        }
+        main_config = generate_main_configs(
+            warmup_iterations=args.warmup,
+            measurement_iterations=args.iterations,
+            batch_size=args.batch_size[0],
+            sequence_length=args.sequence_length[0],
+            num_tokens_to_generate=args.num_tokens_to_generate[0],
+        )[0]
         benchmark_configs = []
         for num_tokens_to_generate in args.num_tokens_to_generate:
             for sequence_length in args.sequence_length:
                 for batch_size in args.batch_size:
-                    kwargs["batch_size"] = batch_size
-                    kwargs["sequence_length"] = sequence_length
-                    kwargs["num_tokens_to_generate"] = num_tokens_to_generate
-                    benchmark_configs.append(BenchmarkConfig(**kwargs))
+                    cfg_dict = main_config.to_dict()
+                    cfg_dict["batch_size"] = batch_size
+                    cfg_dict["sequence_length"] = sequence_length
+                    cfg_dict["num_tokens_to_generate"] = num_tokens_to_generate
+                    cfg_dict.pop("name")
+                    benchmark_configs.append(BenchmarkConfig.from_dict(cfg_dict))

     runner = BenchmarkRunner(logger, args.output_dir, args.commit_id)
     results = runner.run_benchmarks(
         args.model_id,
-        benchmark_configs[:3],
+        benchmark_configs,
         args.num_tokens_to_profile,
         pretty_print_summary=True,
     )
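
In the final branch the script now clones the first config returned by generate_main_configs across every (batch_size, sequence_length, num_tokens_to_generate) combination, rather than rebuilding raw kwargs. A sketch of the resulting grid size, with hypothetical CLI values:

from itertools import product

batch_sizes = [1, 4]            # e.g. --batch-size 1 4 (hypothetical)
sequence_lengths = [128, 2048]  # e.g. --sequence-length 128 2048 (hypothetical)
num_tokens = [128]              # e.g. --num-tokens-to-generate 128 (hypothetical)

grid = list(product(num_tokens, sequence_lengths, batch_sizes))
print(len(grid))  # 4 configs, all sharing the main config's attention/compile settings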