Mirror of https://github.com/huggingface/transformers.git
Small changes to benchmarking script (#41662)

@@ -104,7 +104,7 @@ class BenchmarkConfig:
             "attn_implementation": self.attn_implementation,
             "sdpa_backend": self.sdpa_backend,
             "compile_mode": self.compile_mode,
-            "compile_options": self.compile_options,
+            "compile_options": self.compile_options | {},  # to avoid inplace modification of the original dict
             "kernelize": self.kernelize,
         }
 
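
Note: the `| {}` above relies on dict union (Python 3.9+), which builds a new dict, so callers that mutate the exported value cannot alter the config's own `compile_options`. A minimal sketch of the behaviour (names are illustrative, not from the commit):

    # Dict union returns a brand-new (shallow-copied) dict, leaving the operand untouched.
    compile_options = {"mode": "default"}
    exported = compile_options | {}          # same keys and values, different dict object
    exported["mode"] = "max-autotune"
    assert compile_options["mode"] == "default"   # the original config dict is unchanged
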
@@ -191,7 +191,7 @@ def generate_all_configs(
     )
 
 
-def generate_default_configs(
+def generate_main_configs(
     warmup_iterations: int = 5,
     measurement_iterations: int = 20,
     batch_size: int = 1,
@@ -199,20 +199,17 @@ def generate_default_configs(
     num_tokens_to_generate: int = 128,
     gpu_monitoring: bool = False,
 ) -> list[BenchmarkConfig]:
-    all_attn_implementations = [
-        ("flash_attention_2", None),
-        ("eager", None),
-        ("sdpa", "math"),
-        ("sdpa", "flash_attention"),  # note: this one can fail with compile because of attn mask
-    ]
-    return cross_generate_configs(
-        attn_impl_and_sdpa_backend=all_attn_implementations,
-        compiled_mode=[None, "max-autotune"],
-        kernelized=[False, KERNELIZATION_AVAILABLE],
-        warmup_iterations=warmup_iterations,
-        measurement_iterations=measurement_iterations,
-        batch_size=batch_size,
-        sequence_length=sequence_length,
-        num_tokens_to_generate=num_tokens_to_generate,
-        gpu_monitoring=gpu_monitoring,
-    )
+    # Create kwargs common to all configs
+    kwargs = {
+        "warmup_iterations": warmup_iterations,
+        "measurement_iterations": measurement_iterations,
+        "batch_size": batch_size,
+        "sequence_length": sequence_length,
+        "num_tokens_to_generate": num_tokens_to_generate,
+        "gpu_monitoring": gpu_monitoring,
+    }
+    return [  # TODO: test max-autotune instead of default
+        BenchmarkConfig(attn_implementation="flex_attention", compile_mode="default", **kwargs),
+        BenchmarkConfig(attn_implementation="eager", compile_mode="default", **kwargs),
+        BenchmarkConfig(attn_implementation="flash_attention_2", **kwargs),
+    ]
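
Note: with the change above, `generate_main_configs` returns exactly three configs (compiled flex_attention, compiled eager, and flash_attention_2) that share the same iteration and shape settings. A hedged usage sketch; the argument values below are examples only:

    from framework.benchmark_config import generate_main_configs

    configs = generate_main_configs(batch_size=1, sequence_length=128, num_tokens_to_generate=128)
    assert len(configs) == 3   # flex_attention+compile, eager+compile, flash_attention_2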
|
@@ -144,11 +144,11 @@ class BenchmarkStreamer(BaseStreamer):
 class BenchmarkRunner:
     """Main benchmark runner that coordinates benchmark execution."""
 
-    def __init__(
-        self, logger: logging.Logger, output_dir: str = "benchmark_results", commit_id: str | None = None
-    ) -> None:
+    def __init__(self, logger: logging.Logger, output_dir: str | None = None, commit_id: str | None = None) -> None:
         # Those stay constant for the whole run
         self.logger = logger
+        if output_dir is None:
+            output_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "benchmark_results")
         self.output_dir = output_dir
         self.commit_id = get_git_revision() if commit_id is None else commit_id
         os.makedirs(self.output_dir, exist_ok=True)
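
Note: with `output_dir` now optional, the default resolves relative to the runner module instead of the current working directory. A worked example of the path arithmetic, using a purely illustrative file location:

    import os

    # Assume (hypothetically) the runner lives at /repo/benchmark_v2/framework/benchmark_runner.py.
    runner_file = "/repo/benchmark_v2/framework/benchmark_runner.py"
    default_dir = os.path.join(os.path.dirname(os.path.dirname(runner_file)), "benchmark_results")
    print(default_dir)   # -> /repo/benchmark_v2/benchmark_results (one level above the framework package)
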
@@ -214,7 +214,7 @@ class BenchmarkRunner:
 
         # Quick validation: try one measurement first to see if this scenario works
         flush_memory()
-        e2e_latency, token_generation_times, decoded_output, gpu_metrics = self.time_generate(
+        e2e_latency, token_generation_times, shape_and_decoded_output, gpu_metrics = self.time_generate(
             max_new_tokens=1, gpu_monitor=None
         )
         if e2e_latency < 0:
@@ -231,11 +231,11 @@ class BenchmarkRunner:
         result = BenchmarkResult()
         self.logger.info(f"Benchmarking with {config.measurement_iterations} iterations.")
         for _ in trange(config.measurement_iterations):
-            e2e_latency, token_generation_times, decoded_output, gpu_metrics = self.time_generate(
+            e2e_latency, token_generation_times, shape_and_decoded_output, gpu_metrics = self.time_generate(
                 max_new_tokens=config.num_tokens_to_generate,
                 gpu_monitor=(GPUMonitor(logger=self.logger) if config.gpu_monitoring else None),
             )
-            result.accumulate(e2e_latency, token_generation_times, decoded_output, gpu_metrics)
+            result.accumulate(e2e_latency, token_generation_times, shape_and_decoded_output, gpu_metrics)
         self.logger.info("Benchmarking done. Cleaning up.")
 
         # Profile if needed
@@ -277,10 +277,11 @@ class BenchmarkRunner:
             raise RuntimeError(f"Generated {new_tokens} tokens, expected {max_new_tokens}")
         # Decode outputs
         decoded_output = self.tokenizer.decode(outputs[0, input_tokens:], skip_special_tokens=True)
+        shape_and_decoded_output = f"{tuple(outputs.shape)} | {decoded_output}"
         # Compute intermediate quantities
         e2e_latency = wall_time_1 - wall_time_0
         token_generation_times = [t - wall_time_0 for t in streamer.timestamps[1:]]
-        return e2e_latency, token_generation_times, decoded_output, gpu_metrics
+        return e2e_latency, token_generation_times, shape_and_decoded_output, gpu_metrics
 
     def profile_generate(self, num_tokens_to_profile: int, config_name: str) -> None:
         """Profile the latency of a call to model.generate() with the given (inputs) and (max_new_tokens)."""
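
Note: the stored string now prepends the output tensor shape to the decoded text, which makes padding or batch-size mismatches easy to spot when reading results. Purely illustrative values:

    # Illustration of the "(shape) | text" format introduced above; numbers and text are made up.
    outputs_shape = (1, 145)                  # (batch_size, prompt + generated tokens)
    decoded_output = "The quick brown fox..."
    shape_and_decoded_output = f"{outputs_shape} | {decoded_output}"
    # -> "(1, 145) | The quick brown fox..."
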
@@ -351,10 +352,10 @@ class BenchmarkRunner:
         first_metadata = all_results[first_key]["metadata"].to_dict()
         hardware_info = first_metadata.pop("hardware_info")
         pretty_print_dict(first_metadata | hardware_info, tabs=1)
-        for value in all_results.values():
+        for result in all_results.values():
             print("=" * 100)
-            print(f"Config: {value['config'].infer_name(compact=False)}\n")
-            value["measurements"].pprint(tabs=1)
+            print(f"Config: {result['config'].infer_name(compact=False)}\n")
+            result["measurements"].pprint(batch_size=result["config"].batch_size, tabs=1)
             print("=" * 100)
 
         return all_results
|
@@ -82,19 +82,19 @@ class BenchmarkResult:
     def __init__(self) -> None:
         self.e2e_latency = []
         self.token_generation_times = []  # time at which each token was generated (relative to start of the generation)
-        self.decoded_outputs = []
+        self.shape_and_decoded_outputs = []
         self.gpu_metrics = []
 
     def accumulate(
         self,
         e2e_latency: float,
         token_generation_times: list[float],
-        decoded_output: str,
+        shape_and_decoded_output: str,
         gpu_metrics: GPURawMetrics | None,
     ) -> None:
         self.e2e_latency.append(e2e_latency)
         self.token_generation_times.append(token_generation_times)
-        self.decoded_outputs.append(decoded_output)
+        self.shape_and_decoded_outputs.append(shape_and_decoded_output)
         self.gpu_metrics.append(gpu_metrics)
 
     def to_dict(self) -> dict[str, None | int | float]:
@@ -106,7 +106,7 @@ class BenchmarkResult:
         return {
             "e2e_latency": self.e2e_latency,
             "token_generation_times": self.token_generation_times,
-            "decoded_outputs": self.decoded_outputs,
+            "shape_and_decoded_outputs": self.shape_and_decoded_outputs,
             "gpu_metrics": gpu_metrics,
         }
 
@@ -123,7 +123,7 @@ class BenchmarkResult:
             new_instance.accumulate(
                 e2e_latency=data["e2e_latency"][i],
                 token_generation_times=data["token_generation_times"][i],
-                decoded_output=data["decoded_output"][i],
+                shape_and_decoded_output=data["shape_and_decoded_outputs"][i],
                 gpu_metrics=gpu_metrics[i],
             )
         return new_instance
@@ -134,19 +134,27 @@ class BenchmarkResult:
     def get_measured_itl(self) -> list[float]:
         return [(dt[-1] - dt[0]) / (len(dt) - 1) for dt in self.token_generation_times if len(dt) > 1]
 
-    def pprint(self, tabs: int = 0) -> None:
-        collated_stats = equalize_lengths_and_collate(
-            [
-                add_unit_to_duration(compute_basic_statistics(self.e2e_latency)),
-                add_unit_to_duration(compute_basic_statistics(self.get_measured_ttft())),
-                add_unit_to_duration(compute_basic_statistics(self.get_measured_itl())),
-            ]
-        )
-        pretty_print_dict(
-            {
-                "E2E Latency": collated_stats[0],
-                "Time to First Token": collated_stats[1],
-                "Inter-Token Latency": collated_stats[2],
-            },
-            tabs=tabs,
-        )
+    def get_throughput(self, batch_size: int) -> float:
+        return [
+            batch_size * len(dt) / e2e_latency
+            for e2e_latency, dt in zip(self.e2e_latency, self.token_generation_times)
+        ]
+
+    def pprint(self, batch_size: int = 0, tabs: int = 0) -> None:
+        stats_to_collate = [
+            add_unit_to_duration(compute_basic_statistics(self.e2e_latency)),
+            add_unit_to_duration(compute_basic_statistics(self.get_measured_ttft())),
+            add_unit_to_duration(compute_basic_statistics(self.get_measured_itl())),
+        ]
+        if batch_size > 0:
+            throughput_stats = compute_basic_statistics(self.get_throughput(batch_size))
+            stats_to_collate.append({key: f"{value:.2f}tok/s" for key, value in throughput_stats.items()})
+        collated_stats = equalize_lengths_and_collate(stats_to_collate)
+        dict_to_pprint = {
+            "E2E Latency": collated_stats[0],
+            "Time to First Token": collated_stats[1],
+            "Inter-Token Latency": collated_stats[2],
+        }
+        if batch_size > 0:
+            dict_to_pprint["Throughput"] = collated_stats[3]
+        pretty_print_dict(dict_to_pprint, tabs=tabs)
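
Note: the new `get_throughput` reports, for each measurement, how many tokens per second the whole batch produced: `batch_size * len(token_generation_times) / e2e_latency`. A small worked example with made-up numbers:

    # One measurement: 4 token timestamps recorded, a batch of 2 sequences, 0.12 s end-to-end.
    batch_size = 2
    token_generation_times = [0.05, 0.07, 0.09, 0.11]
    e2e_latency = 0.12
    throughput = batch_size * len(token_generation_times) / e2e_latency
    print(f"{throughput:.2f}tok/s")   # -> 66.67tok/s, formatted the same way pprint does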
|
@@ -20,28 +20,28 @@ in the ./benches directory, organizing outputs into model-specific subfolders.
 
 import argparse
 import logging
-import random
 import sys
 import uuid
 
-from framework.benchmark_config import BenchmarkConfig, generate_all_configs
+from framework.benchmark_config import BenchmarkConfig, generate_all_configs, generate_main_configs
 from framework.benchmark_runner import BenchmarkRunner
 
 
 if __name__ == "__main__":
     # Parse arguments
     parser = argparse.ArgumentParser()
-    parser.add_argument("--output-dir", type=str, default="benchmark_results", help="Output dir for benchmark results")
+    parser.add_argument("--output-dir", type=str, default=None, help="Output dir for benchmark results")
     parser.add_argument("--log-level", type=str, choices=["DEBUG", "INFO", "WARNING", "ERROR"], default="INFO")
     parser.add_argument("--model-id", type=str, help="Specific model ID to benchmark (if supported by benchmarks)")
 
-    parser.add_argument("--warmup", type=int, default=5, help="Number of warmup iterations")
-    parser.add_argument("--iterations", type=int, default=20, help="Number of measurement iterations")
+    parser.add_argument("--warmup", type=int, default=3, help="Number of warmup iterations")
+    parser.add_argument("--iterations", type=int, default=10, help="Number of measurement iterations")
 
     parser.add_argument("--batch-size", "-b", type=int, nargs="+", help="Batch size")
    parser.add_argument("--sequence-length", "-s", type=int, nargs="+", help="Sequence length")
     parser.add_argument("--num-tokens-to-generate", "-n", type=int, nargs="+", help="Number of tokens to generate")
 
+    parser.add_argument("--cross-generate", action="store_true", help="Cross-generate all combinations of configs")
     parser.add_argument("--num-tokens-to-profile", "-p", type=int, default=0, help="Number of tokens to profile")
 
     parser.add_argument("--commit-id", type=str, help="Git commit ID (if not provided, will auto-detect from git)")
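
Note: the parser changes shift the defaults toward a quicker run (3 warmup and 10 measurement iterations, output directory resolved by the runner) and add an opt-in `--cross-generate` switch. A minimal, self-contained sketch of how the new flag parses; the invocation is an example only:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--warmup", type=int, default=3)
    parser.add_argument("--iterations", type=int, default=10)
    parser.add_argument("--cross-generate", action="store_true")

    args = parser.parse_args(["--cross-generate"])              # e.g. run_benchmarks.py --cross-generate
    print(args.warmup, args.iterations, args.cross_generate)    # -> 3 10 True
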
@@ -69,42 +69,47 @@ if __name__ == "__main__":
 
     # If there is only one (batch_size, sequence_length, num_tokens_to_generate), we benchmark across configs
     elif len(args.batch_size) * len(args.sequence_length) * len(args.num_tokens_to_generate) == 1:
-        benchmark_configs = generate_all_configs(
-            warmup_iterations=args.warmup,
-            measurement_iterations=args.iterations,
-            batch_size=args.batch_size[0],
-            sequence_length=args.sequence_length[0],
-            num_tokens_to_generate=args.num_tokens_to_generate[0],
-        )
-        random.shuffle(benchmark_configs)
+        if args.cross_generate:
+            benchmark_configs = generate_all_configs(
+                warmup_iterations=args.warmup,
+                measurement_iterations=args.iterations,
+                batch_size=args.batch_size[0],
+                sequence_length=args.sequence_length[0],
+                num_tokens_to_generate=args.num_tokens_to_generate[0],
+            )
+        else:
+            benchmark_configs = generate_main_configs(
+                warmup_iterations=args.warmup,
+                measurement_iterations=args.iterations,
+                batch_size=args.batch_size[0],
+                sequence_length=args.sequence_length[0],
+                num_tokens_to_generate=args.num_tokens_to_generate[0],
+            )
 
     # Otherwise, we benchmark across all combinations of dimensions
     else:
-        kwargs = {
-            "warmup_iterations": args.warmup,
-            "measurement_iterations": args.iterations,
-            "gpu_monitoring": False,
-            "batch_size": args.batch_size[0],
-            "sequence_length": args.sequence_length[0],
-            "num_tokens_to_generate": args.num_tokens_to_generate[0],
-            "attn_implementation": "flex_attention",
-            "sdpa_backend": None,
-            "compile_mode": "default",
-            "kernelize": False,
-        }
+        main_config = generate_main_configs(
+            warmup_iterations=args.warmup,
+            measurement_iterations=args.iterations,
+            batch_size=args.batch_size[0],
+            sequence_length=args.sequence_length[0],
+            num_tokens_to_generate=args.num_tokens_to_generate[0],
+        )[0]
         benchmark_configs = []
         for num_tokens_to_generate in args.num_tokens_to_generate:
             for sequence_length in args.sequence_length:
                 for batch_size in args.batch_size:
-                    kwargs["batch_size"] = batch_size
-                    kwargs["sequence_length"] = sequence_length
-                    kwargs["num_tokens_to_generate"] = num_tokens_to_generate
-                    benchmark_configs.append(BenchmarkConfig(**kwargs))
+                    cfg_dict = main_config.to_dict()
+                    cfg_dict["batch_size"] = batch_size
+                    cfg_dict["sequence_length"] = sequence_length
+                    cfg_dict["num_tokens_to_generate"] = num_tokens_to_generate
+                    cfg_dict.pop("name")
+                    benchmark_configs.append(BenchmarkConfig.from_dict(cfg_dict))
 
     runner = BenchmarkRunner(logger, args.output_dir, args.commit_id)
     results = runner.run_benchmarks(
         args.model_id,
-        benchmark_configs[:3],
+        benchmark_configs,
         args.num_tokens_to_profile,
         pretty_print_summary=True,
     )
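
Note: in the dimension sweep above, each config is now derived from the single main config by a serialize-tweak-rebuild round trip rather than a hand-maintained kwargs dict. A simplified sketch of that pattern, assuming (as the `pop` above implies) that `to_dict` also emits a `name` key that must be dropped before rebuilding:

    # Hypothetical, simplified version of the loop body above; values are examples only.
    cfg_dict = main_config.to_dict()
    cfg_dict.pop("name")                            # drop the stored name so it does not describe stale dimensions
    cfg_dict["batch_size"] = 4                      # override only the swept dimensions...
    cfg_dict["sequence_length"] = 2048
    new_cfg = BenchmarkConfig.from_dict(cfg_dict)   # ...and rebuild a fresh config object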
|