Mirror of https://github.com/huggingface/transformers.git (synced 2025-11-13 08:04:53 +08:00)
* feat(ci): add continuous batching to benchmarks
* refactor(ci): PR comments
* refactor(cb): when stopping, block by default
* fix(benchmarks): `stream` -> `streaming`
* fix(benchmarks): invalid configuration when cb has attn_impl == sdpa
* tests(cb): fix attn impl
* fix(benchmarks): update `get_throughput` formula
* fix(benchmarks): prevent version conflicts and ensure proper cleanup in continuous batching (#42063)
  * Initial plan
  * fix(benchmarks): ensure proper cleanup and remove transformers from requirements
    - Remove transformers from benchmark_v2/requirements.txt to prevent version conflicts
    - Add a try-finally block to ensure ContinuousBatchingManager.stop() is always called
    - This fixes the TypeError about the unexpected 'streaming' argument and prevents OOM from improper cleanup
* fix(benchmarks): raise the exception on failure instead of ignoring it (we catch the exception later on, and raising it here helps debugging because it will be logged)
* test(cb): comment out failing tests for now (added a `FIXME` mark)
* fix(benchmarks): revert `finally` removal but keep raising the exception
* test(cb): fix missing `require_read_token` import
* refactor(benchmarks): error if no benchmarks were run
* refactor(benchmarks): change default lvls of cb bench config

Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com>
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: McPatate <9112841+McPatate@users.noreply.github.com>
161 lines
6.3 KiB
Python
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any

import numpy as np

from .hardware_metrics import GPURawMetrics, HardwareInfo
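
# `HardwareInfo` and `GPURawMetrics` come from the sibling `hardware_metrics` module; this file
# only uses their `to_dict()` / `from_dict()` (de)serialization helpers and the `HardwareInfo()`
# constructor. (Comment added for orientation; not part of the original file.)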


def compute_basic_statistics(measurements: list[float]) -> dict[str, float]:
    return {
        "avg": np.mean(measurements),
        "std": np.std(measurements),
        "min": np.min(measurements),
        "med": np.median(measurements),
        "max": np.max(measurements),
        "p95": np.percentile(measurements, 95),
    }


def add_unit_to_duration(stats: dict[str, float]) -> dict[str, str]:
    for key in list(stats.keys()):
        value = stats[key]
        if value > 3600:
            stats[key] = f"{(value / 3600):.2f}hr"
        elif value > 60:
            stats[key] = f"{(value / 60):.2f}min"
        elif value > 1:
            stats[key] = f"{value:.2f}s"
        elif value > 1e-3:
            stats[key] = f"{(value * 1e3):.2f}ms"
        elif value > 1e-6:
            stats[key] = f"{(value * 1e6):.2f}us"
        else:
            stats[key] = f"{(value * 1e9):.2f}ns"
    return stats
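# Illustrative note (not part of the original file): `add_unit_to_duration` converts each value,
# given in seconds, to a human-readable unit (hr, min, s, ms, us, ns) based on its magnitude,
# so 0.5375 s is rendered as "537.50ms" and 95.0 s as "1.58min".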


def equalize_lengths_and_collate(stats: dict[str, dict[str, str]]) -> dict[str, str]:
    """Note: this operation is destructive, as it updates the values in place before returning a new, correctly formatted dict."""
    keys = ["avg", "std", "min", "med", "max", "p95"]
    for key in keys:
        max_length = max(len(stat[key]) for stat in stats.values())
        for stat in stats.values():
            stat[key] = stat[key].ljust(max_length, " ")
    return {name: " ".join([f"{key}={stat[key]}" for key in keys]) for name, stat in stats.items()}


def pretty_print_dict(data: dict[str, str], tabs: int = 0) -> None:
    max_key_length = max([len(key) for key in data.keys()])
    for key, value in data.items():
        tabs_str = " " * tabs
        padded_key = key.ljust(max_key_length + 1, ".")
        print(f"{tabs_str}{padded_key}: {value}")


@dataclass
class BenchmarkMetadata:
    """Metadata collected for each benchmark run."""

    model_id: str
    timestamp: str
    branch_name: str
    commit_id: str
    commit_message: str
    hardware_info: HardwareInfo

    def __init__(self, model_id: str, commit_id: str, branch_name: str = "main", commit_message: str = "") -> None:
        self.model_id = model_id
        self.timestamp = datetime.now(timezone.utc).isoformat()
        self.branch_name = branch_name
        self.commit_id = commit_id
        self.commit_message = commit_message
        self.hardware_info = HardwareInfo()

    def to_dict(self) -> dict[str, Any]:
        return {
            "model_id": self.model_id,
            "timestamp": self.timestamp,
            "branch_name": self.branch_name,
            "commit_id": self.commit_id,
            "commit_message": self.commit_message,
            "hardware_info": self.hardware_info.to_dict(),
        }
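

# Illustrative sketch (not part of the original file): metadata is typically serialized next to
# the results, e.g. as JSON; the model id and commit id below are hypothetical placeholders.
def _example_metadata_to_json() -> str:
    import json

    metadata = BenchmarkMetadata(model_id="some-org/some-model", commit_id="abc1234", branch_name="main")
    return json.dumps(metadata.to_dict(), indent=2)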


class BenchmarkResult:
    """Result from a series of benchmark runs."""

    def __init__(self) -> None:
        self.e2e_latency = []
        self.token_generation_times = []  # time at which each token was generated (relative to the start of the generation)
        self.shape_and_decoded_outputs = []
        self.gpu_metrics = []

    def accumulate(
        self,
        e2e_latency: float,
        token_generation_times: list[float],
        shape_and_decoded_output: str,
        gpu_metrics: GPURawMetrics | None,
    ) -> None:
        self.e2e_latency.append(e2e_latency)
        self.token_generation_times.append(token_generation_times)
        self.shape_and_decoded_outputs.append(shape_and_decoded_output)
        self.gpu_metrics.append(gpu_metrics)

    def to_dict(self) -> dict[str, None | int | float]:
        # Save GPU metrics as None if the list contains only None values
        if all(gm is None for gm in self.gpu_metrics):
            gpu_metrics = None
        else:
            gpu_metrics = [gm.to_dict() for gm in self.gpu_metrics]
        return {
            "e2e_latency": self.e2e_latency,
            "token_generation_times": self.token_generation_times,
            "shape_and_decoded_outputs": self.shape_and_decoded_outputs,
            "gpu_metrics": gpu_metrics,
        }

    @classmethod
    def from_dict(cls, data: dict[str, None | int | float]) -> "BenchmarkResult":
        # Handle GPU metrics, which are saved as None if they contain only None values
        if data["gpu_metrics"] is None:
            gpu_metrics = [None for _ in range(len(data["e2e_latency"]))]
        else:
            gpu_metrics = [GPURawMetrics.from_dict(gm) for gm in data["gpu_metrics"]]
        # Create a new instance and accumulate the data
        new_instance = cls()
        for i in range(len(data["e2e_latency"])):
            new_instance.accumulate(
                e2e_latency=data["e2e_latency"][i],
                token_generation_times=data["token_generation_times"][i],
                shape_and_decoded_output=data["shape_and_decoded_outputs"][i],
                gpu_metrics=gpu_metrics[i],
            )
        return new_instance
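
    # Illustrative note (not part of the original file): `to_dict` and `from_dict` round-trip, so a
    # result can be persisted (e.g. via `json.dump`) and later reloaded for comparison:
    #     restored = BenchmarkResult.from_dict(json.loads(json.dumps(result.to_dict())))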

    def get_measured_ttft(self) -> list[float]:
        return [dt[0] for dt in self.token_generation_times if len(dt) > 0]

    def get_measured_itl(self) -> list[float]:
        return [(dt[-1] - dt[0]) / (len(dt) - 1) for dt in self.token_generation_times if len(dt) > 1]

    def get_throughput(self, total_generated_tokens: int) -> list[float]:
        return [total_generated_tokens / e2e_latency for e2e_latency in self.e2e_latency]
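
    # Notes on the three derived metrics above (added for illustration, not in the original file):
    # - TTFT (time to first token) is the timestamp of the first generated token of each run.
    # - ITL (inter-token latency) averages the gaps between consecutive tokens: for token times
    #   [0.2, 0.3, 0.4, 0.5], ITL = (0.5 - 0.2) / (4 - 1) = 0.1 s.
    # - Throughput divides the total number of generated tokens by each run's end-to-end latency;
    #   `pprint` below passes `batch_size * num_generated_tokens` as that total.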

    def pprint(self, batch_size: int = 0, num_generated_tokens: int = 0, tabs: int = 0) -> None:
        measurements = {
            "E2E Latency": add_unit_to_duration(compute_basic_statistics(self.e2e_latency)),
            "Time to First Token": add_unit_to_duration(compute_basic_statistics(self.get_measured_ttft())),
        }
        itl_values = self.get_measured_itl()
        if len(itl_values) > 0:
            measurements["Inter-Token Latency"] = add_unit_to_duration(compute_basic_statistics(itl_values))
        if batch_size > 0:
            throughput_stats = compute_basic_statistics(self.get_throughput(batch_size * num_generated_tokens))
            measurements["Throughput"] = {key: f"{value:.2f}tok/s" for key, value in throughput_stats.items()}
        dict_to_pprint = equalize_lengths_and_collate(measurements)
        pretty_print_dict(dict_to_pprint, tabs=tabs)
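

# Illustrative usage sketch (not part of the original file): a hypothetical benchmark loop times
# each run, accumulates the measurements, and finally prints the formatted summary.
def _example_collect_and_report() -> None:
    result = BenchmarkResult()
    # Hypothetical timings for two runs with batch_size=1 and 4 generated tokens each.
    result.accumulate(
        e2e_latency=0.52,
        token_generation_times=[0.20, 0.30, 0.41, 0.50],
        shape_and_decoded_output="(1, 4) 'hello world'",
        gpu_metrics=None,
    )
    result.accumulate(
        e2e_latency=0.48,
        token_generation_times=[0.18, 0.28, 0.38, 0.47],
        shape_and_decoded_output="(1, 4) 'hello there'",
        gpu_metrics=None,
    )
    result.pprint(batch_size=1, num_generated_tokens=4, tabs=1)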