Mirror of https://github.com/huggingface/transformers.git (synced 2025-10-20 17:13:56 +08:00)
Benchmarking improvements (#39768)
* Start revamping benchmarking
* Start refactoring benchmarking
* Use Pandas for CSV
* import fix
* Remove benchmark files
* Remove sample data
* Address review comments
benchmark/.gitignore (vendored, new file, +1)
@@ -0,0 +1 @@
benchmark_results/
benchmark/benches/llama.py (new file, +345)
@@ -0,0 +1,345 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from logging import Logger
import os
from threading import Event, Thread
from time import perf_counter, sleep
from typing import Optional
import sys

# Add the parent directory to Python path to import benchmarks_entrypoint
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from benchmarks_entrypoint import MetricsRecorder

import gpustat
import psutil
import psycopg2

# Optional heavy ML dependencies - only required when actually running the benchmark
try:
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, StaticCache

    TRANSFORMERS_AVAILABLE = True
except ImportError:
    TRANSFORMERS_AVAILABLE = False
    torch = None
    AutoModelForCausalLM = None
    AutoTokenizer = None
    GenerationConfig = None
    StaticCache = None

os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "1"

# Only set torch precision if torch is available
if TRANSFORMERS_AVAILABLE:
    torch.set_float32_matmul_precision("high")


def collect_metrics(benchmark_id, continue_metric_collection, metrics_recorder):
    p = psutil.Process(os.getpid())
    while not continue_metric_collection.is_set():
        with p.oneshot():
            cpu_util = p.cpu_percent()
            mem_megabytes = p.memory_info().rss / (1024 * 1024)
        gpu_stats = gpustat.GPUStatCollection.new_query()
        gpu_util = gpu_stats[0]["utilization.gpu"]
        gpu_mem_megabytes = gpu_stats[0]["memory.used"]
        metrics_recorder.collect_device_measurements(
            benchmark_id, cpu_util, mem_megabytes, gpu_util, gpu_mem_megabytes
        )
        sleep(0.01)


def run_benchmark(
    logger: Logger, repository: str, branch: str, commit_id: str, commit_msg: str, metrics_recorder=None, num_tokens_to_generate=100
):
    # Check if required ML dependencies are available
    if not TRANSFORMERS_AVAILABLE:
        logger.error("Transformers and torch are required to run the LLaMA benchmark. Please install them with:")
        logger.error("pip install torch transformers")
        logger.error("Skipping LLaMA benchmark due to missing dependencies.")
        return

    continue_metric_collection = Event()
    metrics_thread = None
    model_id = "meta-llama/Llama-2-7b-hf"

    # If no metrics_recorder is provided, create one for backward compatibility
    if metrics_recorder is None:
        try:
            metrics_recorder = MetricsRecorder(
                psycopg2.connect("dbname=metrics"), logger, repository, branch, commit_id, commit_msg, True
            )
            should_close_recorder = True
        except Exception as e:
            logger.error(f"Failed to create metrics recorder: {e}")
            return
    else:
        should_close_recorder = False

    try:
        gpu_stats = gpustat.GPUStatCollection.new_query()
        gpu_name = gpu_stats[0]["name"]
        benchmark_id = metrics_recorder.initialise_benchmark({"gpu_name": gpu_name, "model_id": model_id})
        logger.info(f"running benchmark #{benchmark_id} on {gpu_name} for {model_id}")
        metrics_thread = Thread(
            target=collect_metrics,
            args=[benchmark_id, continue_metric_collection, metrics_recorder],
        )
        metrics_thread.start()
        logger.info("started background thread to fetch device metrics")

        os.environ["TOKENIZERS_PARALLELISM"] = "false"  # silence warnings when compiling

        device = "cuda"

        logger.info("downloading weights")
        # This is to avoid counting download in model load time measurement
        model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)
        gen_config = GenerationConfig(do_sample=False, top_p=1, temperature=1)
        logger.info("loading model")
        start = perf_counter()
        model = AutoModelForCausalLM.from_pretrained(
            model_id, torch_dtype=torch.float16, generation_config=gen_config
        ).eval()
        model.to(device)
        torch.cuda.synchronize()
        end = perf_counter()
        model_load_time = end - start
        logger.info(f"loaded model in: {model_load_time}s")

        tokenizer = AutoTokenizer.from_pretrained(model_id)

        prompt = "Why dogs are so cute?"
        inputs = tokenizer(prompt, return_tensors="pt").to(device)

        # Specify the max length (including both the prompt and the response)
        # When calling `generate` with `cache_implementation="static"` later, this is also used to create a `StaticCache` object
        # with sequence length = `max_length`. The longer the more you will re-use it
        seq_length = inputs["input_ids"].shape[1]
        model.generation_config.max_length = seq_length + num_tokens_to_generate
        batch_size = inputs["input_ids"].shape[0]

        # Copied from the gpt-fast repo
        def multinomial_sample_one_no_sync(probs_sort):  # Does multinomial sampling without a cuda synchronization
            q = torch.empty_like(probs_sort).exponential_(1)
            return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)

        def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
            logits = logits / max(temperature, 1e-5)

            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                pivot = v.select(-1, -1).unsqueeze(-1)
                logits = torch.where(logits < pivot, -float("Inf"), logits)
            probs = torch.nn.functional.softmax(logits, dim=-1)
            return probs

        def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
            probs = logits_to_probs(logits[0, -1], temperature, top_k)
            idx_next = multinomial_sample_one_no_sync(probs)
            return idx_next, probs

        # First eager forward pass
        logger.info("running first eager forward pass")
        start = perf_counter()
        outputs = model(**inputs)
        torch.cuda.synchronize()
        end = perf_counter()
        first_eager_fwd_pass_time = end - start
        logger.info(f"completed first eager forward pass in: {first_eager_fwd_pass_time}s")

        # Second eager forward pass (should be faster)
        logger.info("running second eager forward pass")
        start = perf_counter()
        outputs = model(**inputs)
        torch.cuda.synchronize()
        end = perf_counter()
        second_eager_fwd_pass_time = end - start
        logger.info(f"completed second eager forward pass in: {second_eager_fwd_pass_time}s")

        # First eager generation
        logger.info("running first eager generation")
        start = perf_counter()
        output = model.generate(**inputs)
        torch.cuda.synchronize()
        end = perf_counter()
        first_eager_generate_time = end - start
        logger.info(f"completed first eager generation in: {first_eager_generate_time}s")
        logger.info(f"generated: {tokenizer.batch_decode(output.cpu().tolist())}")

        # Second eager generation (should be faster)
        logger.info("running second eager generation")
        start = perf_counter()
        output = model.generate(**inputs)
        torch.cuda.synchronize()
        end = perf_counter()
        second_eager_generate_time = end - start
        logger.info(f"completed second eager generation in: {second_eager_generate_time}s")
        logger.info(f"generated: {tokenizer.batch_decode(output.cpu().tolist())}")

        logger.info("running generation timing loop")

        input_pos = torch.arange(0, seq_length, device=device)
        inputs = inputs["input_ids"]

        start = perf_counter()
        with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH):
            logits = model(inputs, position_ids=input_pos).logits
            next_token, probs = sample(logits, temperature=0.6, top_k=5)
        torch.cuda.synchronize()
        end = perf_counter()
        time_to_first_token = end - start

        input_pos = torch.tensor([seq_length], device=device, dtype=torch.int)
        next_token = next_token.clone()
        start = perf_counter()
        with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH):
            logits = model(next_token, position_ids=input_pos).logits
            next_token, probs = sample(logits, temperature=0.6, top_k=5)
        torch.cuda.synchronize()
        end = perf_counter()
        time_to_second_token = end - start

        input_pos = torch.tensor([seq_length + 1], device=device, dtype=torch.int)
        next_token = next_token.clone()
        start = perf_counter()
        with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH):
            logits = model(next_token, position_ids=input_pos).logits
            next_token, probs = sample(logits, temperature=0.6, top_k=5)
        torch.cuda.synchronize()
        end = perf_counter()
        time_to_third_token = end - start

        logger.info("running longer generation timing loop")

        total_time = 0
        for i in range(20):
            input_pos = torch.tensor([seq_length + 2 + i], device=device, dtype=torch.int)
            next_token = next_token.clone()
            start = perf_counter()
            with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH):
                logits = model(next_token, position_ids=input_pos).logits
                next_token, probs = sample(logits, temperature=0.6, top_k=5)
            torch.cuda.synchronize()
            end = perf_counter()
            total_time += end - start

        mean_time_to_next_token = total_time / 20

        logger.info("running compilation benchmarks")

        # Now compile the model
        model = torch.compile(model, mode="max-autotune", fullgraph=True)

        # StaticCache for generation
        with torch.device(device):
            model.setup_caches(max_batch_size=batch_size, max_seq_len=seq_length + num_tokens_to_generate)

        input_pos = torch.arange(0, seq_length, device=device)
        inputs = tokenizer(prompt, return_tensors="pt").to(device)["input_ids"]

        logger.info("compiling model")

        model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, generation_config=gen_config)
        model.to(device)
        model = torch.compile(model, mode="max-autotune", fullgraph=True)

        past_key_values = StaticCache(
            model.config,
            max_batch_size=batch_size,
            device=device,
            dtype=torch.float16,
            max_cache_len=seq_length + 128,
        )
        # 1st call
        start = perf_counter()
        output = model.generate(**inputs, past_key_values=past_key_values)
        end = perf_counter()
        first_compile_generate_time = end - start
        logger.info(f"completed first compile generation in: {first_compile_generate_time}s")
        logger.info(f"generated: {tokenizer.batch_decode(output.cpu().tolist())}")

        past_key_values = StaticCache(
            model.config,
            max_batch_size=batch_size,
            device=device,
            dtype=torch.float16,
            max_cache_len=seq_length + 128,
        )
        # 2nd call
        start = perf_counter()
        output = model.generate(**inputs, past_key_values=past_key_values)
        end = perf_counter()
        second_compile_generate_time = end - start
        logger.info(f"completed second compile generation in: {second_compile_generate_time}s")
        logger.info(f"generated: {tokenizer.batch_decode(output.cpu().tolist())}")

        past_key_values = StaticCache(
            model.config,
            max_batch_size=batch_size,
            device=device,
            dtype=torch.float16,
            max_cache_len=seq_length + 128,
        )
        # 3rd call
        start = perf_counter()
        output = model.generate(**inputs, past_key_values=past_key_values)
        end = perf_counter()
        third_compile_generate_time = end - start
        logger.info(f"completed third compile generation in: {third_compile_generate_time}s")
        logger.info(f"generated: {tokenizer.batch_decode(output.cpu().tolist())}")

        past_key_values = StaticCache(
            model.config,
            max_batch_size=batch_size,
            device=device,
            dtype=torch.float16,
            max_cache_len=seq_length + 128,
        )
        # 4th call
        start = perf_counter()
        output = model.generate(**inputs, past_key_values=past_key_values)
        end = perf_counter()
        fourth_compile_generate_time = end - start
        logger.info(f"completed fourth compile generation in: {fourth_compile_generate_time}s")
        logger.info(f"generated: {tokenizer.batch_decode(output.cpu().tolist())}")

        metrics_recorder.collect_model_measurements(
            benchmark_id,
            {
                "model_load_time": model_load_time,
                "first_eager_forward_pass_time_secs": first_eager_fwd_pass_time,
                "second_eager_forward_pass_time_secs": second_eager_fwd_pass_time,
                "first_eager_generate_time_secs": first_eager_generate_time,
                "second_eager_generate_time_secs": second_eager_generate_time,
                "time_to_first_token_secs": time_to_first_token,
                "time_to_second_token_secs": time_to_second_token,
                "time_to_third_token_secs": time_to_third_token,
                "time_to_next_token_mean_secs": mean_time_to_next_token,
                "first_compile_generate_time_secs": first_compile_generate_time,
                "second_compile_generate_time_secs": second_compile_generate_time,
                "third_compile_generate_time_secs": third_compile_generate_time,
                "fourth_compile_generate_time_secs": fourth_compile_generate_time,
            },
        )
    except Exception as e:
        logger.error(f"Caught exception: {e}")
    continue_metric_collection.set()
    if metrics_thread is not None:
        metrics_thread.join()

    # Only close the recorder if we created it locally
    if should_close_recorder:
        metrics_recorder.close()
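For orientation, a minimal sketch of driving this benchmark directly with the CSV-only recorder mode introduced in benchmarks_entrypoint.py below, with no PostgreSQL available. The import path, logger name, and the repository/commit values are illustrative placeholders, not part of this commit:

import logging

from benchmarks_entrypoint import MetricsRecorder
from benches.llama import run_benchmark  # assumes the benchmark/ directory is on sys.path

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("llama-benchmark")

# connection=None puts the recorder in CSV-only mode; the trailing True enables CSV collection.
recorder = MetricsRecorder(None, logger, "huggingface/transformers", "main", "abc1234", "local test run", True)
run_benchmark(logger, "huggingface/transformers", "main", "abc1234", "local test run", metrics_recorder=recorder)
recorder.export_to_csv("benchmark_results")
recorder.close()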
benchmark/benchmarks_entrypoint.py (modified)
@@ -1,15 +1,35 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import argparse
 import importlib.util
 import logging
 import os
 import sys
-from typing import Dict, Tuple
+import json
+import uuid
+from datetime import datetime
+from typing import Dict, Tuple, Optional, List
 
-from psycopg2.extensions import register_adapter
-from psycopg2.extras import Json
+import pandas as pd
 
-
-register_adapter(dict, Json)
+try:
+    from psycopg2.extensions import register_adapter
+    from psycopg2.extras import Json
+
+    register_adapter(dict, Json)
+    PSYCOPG2_AVAILABLE = True
+except ImportError:
+    PSYCOPG2_AVAILABLE = False
 
 
 class ImportModuleException(Exception):
@@ -18,61 +38,239 @@ class ImportModuleException(Exception):
 
 
 class MetricsRecorder:
     def __init__(
-        self, connection, logger: logging.Logger, repository: str, branch: str, commit_id: str, commit_msg: str
+        self, connection, logger: logging.Logger, repository: str, branch: str, commit_id: str, commit_msg: str,
+        collect_csv_data: bool = True
     ):
         self.conn = connection
-        self.conn.autocommit = True
+        self.use_database = connection is not None
+        if self.use_database:
+            self.conn.autocommit = True
         self.logger = logger
         self.repository = repository
         self.branch = branch
         self.commit_id = commit_id
         self.commit_msg = commit_msg
+        self.collect_csv_data = collect_csv_data
+
+        # For CSV export - store all data in pandas DataFrames (only if CSV collection is enabled)
+        if self.collect_csv_data:
+            # Initialize empty DataFrames with proper schemas
+            self.benchmarks_df = pd.DataFrame(columns=[
+                'benchmark_id', 'repository', 'branch', 'commit_id', 'commit_message',
+                'metadata', 'created_at'
+            ])
+            self.device_measurements_df = pd.DataFrame(columns=[
+                'benchmark_id', 'cpu_util', 'mem_megabytes', 'gpu_util',
+                'gpu_mem_megabytes', 'time'
+            ])
+            self.model_measurements_df = pd.DataFrame(columns=[
+                'benchmark_id', 'time', 'model_load_time', 'first_eager_forward_pass_time_secs',
+                'second_eager_forward_pass_time_secs', 'first_eager_generate_time_secs',
+                'second_eager_generate_time_secs', 'time_to_first_token_secs',
+                'time_to_second_token_secs', 'time_to_third_token_secs',
+                'time_to_next_token_mean_secs', 'first_compile_generate_time_secs',
+                'second_compile_generate_time_secs', 'third_compile_generate_time_secs',
+                'fourth_compile_generate_time_secs'
+            ])
+        else:
+            self.benchmarks_df = None
+            self.device_measurements_df = None
+            self.model_measurements_df = None
 
-    def initialise_benchmark(self, metadata: dict[str, str]) -> int:
+    def initialise_benchmark(self, metadata: dict[str, str]) -> str:
         """
-        Creates a new benchmark, returns the benchmark id
+        Creates a new benchmark, returns the benchmark id (UUID)
         """
-        # gpu_name: str, model_id: str
-        with self.conn.cursor() as cur:
-            cur.execute(
-                "INSERT INTO benchmarks (repository, branch, commit_id, commit_message, metadata) VALUES (%s, %s, %s, %s, %s) RETURNING benchmark_id",
-                (self.repository, self.branch, self.commit_id, self.commit_msg, metadata),
-            )
-            benchmark_id = cur.fetchone()[0]
-            logger.debug(f"initialised benchmark #{benchmark_id}")
-            return benchmark_id
+        # Generate a unique UUID for this benchmark
+        benchmark_id = str(uuid.uuid4())
+
+        if self.use_database:
+            with self.conn.cursor() as cur:
+                cur.execute(
+                    "INSERT INTO benchmarks (benchmark_id, repository, branch, commit_id, commit_message, metadata) VALUES (%s, %s, %s, %s, %s, %s)",
+                    (benchmark_id, self.repository, self.branch, self.commit_id, self.commit_msg, metadata),
+                )
+            self.logger.debug(f"initialised benchmark #{benchmark_id}")
+
+        # Store benchmark data for CSV export (if enabled)
+        if self.collect_csv_data:
+            # Add row to pandas DataFrame
+            new_row = pd.DataFrame([{
+                'benchmark_id': benchmark_id,
+                'repository': self.repository,
+                'branch': self.branch,
+                'commit_id': self.commit_id,
+                'commit_message': self.commit_msg,
+                'metadata': json.dumps(metadata),
+                'created_at': datetime.utcnow().isoformat()
+            }])
+            self.benchmarks_df = pd.concat([self.benchmarks_df, new_row], ignore_index=True)
+
+        mode_info = []
+        if self.use_database:
+            mode_info.append("database")
+        if self.collect_csv_data:
+            mode_info.append("CSV")
+        mode_str = " + ".join(mode_info) if mode_info else "no storage"
+
+        self.logger.debug(f"initialised benchmark #{benchmark_id} ({mode_str} mode)")
+        return benchmark_id
 
-    def collect_device_measurements(self, benchmark_id: int, cpu_util, mem_megabytes, gpu_util, gpu_mem_megabytes):
+    def collect_device_measurements(self, benchmark_id: str, cpu_util, mem_megabytes, gpu_util, gpu_mem_megabytes):
         """
         Collect device metrics, such as CPU & GPU usage. These are "static", as in you cannot pass arbitrary arguments to the function.
         """
-        with self.conn.cursor() as cur:
-            cur.execute(
-                "INSERT INTO device_measurements (benchmark_id, cpu_util, mem_megabytes, gpu_util, gpu_mem_megabytes) VALUES (%s, %s, %s, %s, %s)",
-                (benchmark_id, cpu_util, mem_megabytes, gpu_util, gpu_mem_megabytes),
-            )
+        # Store device measurements for CSV export (if enabled)
+        if self.collect_csv_data:
+            # Add row to pandas DataFrame
+            new_row = pd.DataFrame([{
+                'benchmark_id': benchmark_id,
+                'cpu_util': cpu_util,
+                'mem_megabytes': mem_megabytes,
+                'gpu_util': gpu_util,
+                'gpu_mem_megabytes': gpu_mem_megabytes,
+                'time': datetime.utcnow().isoformat()
+            }])
+            self.device_measurements_df = pd.concat([self.device_measurements_df, new_row], ignore_index=True)
+
+        # Store in database if available
+        if self.use_database:
+            with self.conn.cursor() as cur:
+                cur.execute(
+                    "INSERT INTO device_measurements (benchmark_id, cpu_util, mem_megabytes, gpu_util, gpu_mem_megabytes) VALUES (%s, %s, %s, %s, %s)",
+                    (benchmark_id, cpu_util, mem_megabytes, gpu_util, gpu_mem_megabytes),
+                )
+
         self.logger.debug(
-            f"inserted device measurements for benchmark #{benchmark_id} [CPU util: {cpu_util}, mem MBs: {mem_megabytes}, GPU util: {gpu_util}, GPU mem MBs: {gpu_mem_megabytes}]"
+            f"collected device measurements for benchmark #{benchmark_id} [CPU util: {cpu_util}, mem MBs: {mem_megabytes}, GPU util: {gpu_util}, GPU mem MBs: {gpu_mem_megabytes}]"
         )
 
-    def collect_model_measurements(self, benchmark_id: int, measurements: dict[str, float]):
-        with self.conn.cursor() as cur:
-            cur.execute(
-                """
-                INSERT INTO model_measurements (
-                    benchmark_id,
-                    measurements
-                ) VALUES (%s, %s)
-                """,
-                (
-                    benchmark_id,
-                    measurements,
-                ),
-            )
-        self.logger.debug(f"inserted model measurements for benchmark #{benchmark_id}: {measurements}")
+    def collect_model_measurements(self, benchmark_id: str, measurements: dict[str, float]):
+        # Store model measurements for CSV export (if enabled)
+        if self.collect_csv_data:
+            # Add row to pandas DataFrame with flattened measurements
+            row_data = {
+                'benchmark_id': benchmark_id,
+                'time': datetime.utcnow().isoformat()
+            }
+            # Flatten the measurements dict into the row
+            row_data.update(measurements)
+
+            new_row = pd.DataFrame([row_data])
+            self.model_measurements_df = pd.concat([self.model_measurements_df, new_row], ignore_index=True)
+
+        # Store in database if available
+        if self.use_database:
+            with self.conn.cursor() as cur:
+                cur.execute(
+                    """
+                    INSERT INTO model_measurements (
+                        benchmark_id,
+                        measurements
+                    ) VALUES (%s, %s)
+                    """,
+                    (
+                        benchmark_id,
+                        measurements,
+                    ),
+                )
+
+        self.logger.debug(f"collected model measurements for benchmark #{benchmark_id}: {measurements}")
+
+    def export_to_csv(self, output_dir: str = "benchmark_results"):
+        """
+        Export all collected data to CSV files using pandas DataFrames
+        """
+        if not self.collect_csv_data:
+            self.logger.warning("CSV data collection is disabled - no CSV files will be generated")
+            return
+
+        if not os.path.exists(output_dir):
+            os.makedirs(output_dir)
+            self.logger.info(f"Created output directory: {output_dir}")
+
+        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+        files_created = []
+
+        # Export using pandas DataFrames
+        self._export_pandas_data(output_dir, timestamp, files_created)
+
+        self.logger.info(f"CSV export complete! Created {len(files_created)} files in {output_dir}")
+
+    def _export_pandas_data(self, output_dir: str, timestamp: str, files_created: list):
+        """
+        Export CSV files using pandas DataFrames
+        """
+        # Export benchmarks
+        benchmarks_file = os.path.join(output_dir, f"benchmarks_{timestamp}.csv")
+        self.benchmarks_df.to_csv(benchmarks_file, index=False)
+        files_created.append(benchmarks_file)
+        self.logger.info(f"Exported {len(self.benchmarks_df)} benchmark records to {benchmarks_file}")
+
+        # Export device measurements
+        device_file = os.path.join(output_dir, f"device_measurements_{timestamp}.csv")
+        self.device_measurements_df.to_csv(device_file, index=False)
+        files_created.append(device_file)
+        self.logger.info(f"Exported {len(self.device_measurements_df)} device measurement records to {device_file}")
+
+        # Export model measurements (already flattened)
+        model_file = os.path.join(output_dir, f"model_measurements_{timestamp}.csv")
+        self.model_measurements_df.to_csv(model_file, index=False)
+        files_created.append(model_file)
+        self.logger.info(f"Exported {len(self.model_measurements_df)} model measurement records to {model_file}")
+
+        # Create comprehensive summary using pandas operations
+        summary_file = os.path.join(output_dir, f"benchmark_summary_{timestamp}.csv")
+        self._create_summary(summary_file)
+        files_created.append(summary_file)
+
+    def _create_summary(self, summary_file: str):
+        """
+        Create a comprehensive summary CSV using pandas operations
+        """
+        if len(self.benchmarks_df) == 0:
+            # Create empty summary file
+            summary_df = pd.DataFrame()
+            summary_df.to_csv(summary_file, index=False)
+            self.logger.info(f"Created empty benchmark summary at {summary_file}")
+            return
+
+        # Start with benchmarks as the base
+        summary_df = self.benchmarks_df.copy()
+
+        # Add model measurements (join on benchmark_id)
+        if len(self.model_measurements_df) > 0:
+            # Drop 'time' column from model measurements to avoid conflicts
+            model_df = self.model_measurements_df.drop(columns=['time'], errors='ignore')
+            summary_df = summary_df.merge(model_df, on='benchmark_id', how='left')
+
+        # Calculate device measurement aggregates using pandas groupby
+        if len(self.device_measurements_df) > 0:
+            device_agg = self.device_measurements_df.groupby('benchmark_id').agg({
+                'cpu_util': ['mean', 'max', 'std', 'count'],
+                'mem_megabytes': ['mean', 'max', 'std'],
+                'gpu_util': ['mean', 'max', 'std'],
+                'gpu_mem_megabytes': ['mean', 'max', 'std']
+            }).round(3)
+
+            # Flatten column names
+            device_agg.columns = [f"{col[0]}_{col[1]}" for col in device_agg.columns]
+            device_agg = device_agg.reset_index()
+
+            # Rename count column to be more descriptive
+            if 'cpu_util_count' in device_agg.columns:
+                device_agg = device_agg.rename(columns={'cpu_util_count': 'device_measurement_count'})
+
+            # Merge with summary
+            summary_df = summary_df.merge(device_agg, on='benchmark_id', how='left')
+
+        # Export the comprehensive summary
+        summary_df.to_csv(summary_file, index=False)
+        self.logger.info(f"Created comprehensive benchmark summary with {len(summary_df)} records at {summary_file}")
 
     def close(self):
-        self.conn.close()
+        if self.use_database and self.conn:
+            self.conn.close()
 
 
 logger = logging.getLogger(__name__)
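Once export_to_csv has run, the summary written by _create_summary can be inspected with pandas. A small sketch; the timestamped file name and the exact columns present depend on the run:

import glob

import pandas as pd

# Pick the most recent summary produced by export_to_csv / _create_summary.
summary_path = sorted(glob.glob("benchmark_results/benchmark_summary_*.csv"))[-1]
summary = pd.read_csv(summary_path)

# One row per benchmark: repository/commit metadata, the flattened model measurements,
# and device aggregates such as cpu_util_mean, gpu_util_max or gpu_mem_megabytes_mean.
print(summary[["benchmark_id", "model_load_time", "time_to_first_token_secs", "gpu_util_mean"]])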
@@ -85,7 +283,7 @@ handler.setFormatter(formatter)
 logger.addHandler(handler)
 
 
-def parse_arguments() -> tuple[str, str, str, str]:
+def parse_arguments() -> tuple[str, str, str, str, bool, str]:
     """
     Parse command line arguments for the benchmarking CLI.
     """
@@ -114,10 +312,27 @@ def parse_arguments() -> tuple[str, str, str, str]:
         type=str,
         help="The commit message associated with the commit, truncated to 70 characters.",
     )
 
+    parser.add_argument(
+        "--csv",
+        action="store_true",
+        default=False,
+        help="Enable CSV output files generation."
+    )
+
+    parser.add_argument(
+        "--csv-output-dir",
+        type=str,
+        default="benchmark_results",
+        help="Directory for CSV output files (default: benchmark_results)."
+    )
+
     args = parser.parse_args()
 
-    return args.repository, args.branch, args.commit_id, args.commit_msg
+    # CSV is disabled by default, only enabled when --csv is used
+    generate_csv = args.csv
+
+    return args.repository, args.branch, args.commit_id, args.commit_msg, generate_csv, args.csv_output_dir
 
 
 def import_from_path(module_name, file_path):
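For illustration, the behaviour of the two new options in isolation; this is a stand-alone sketch, not the full parser, and the other arguments handled by parse_arguments are unchanged and omitted here:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--csv", action="store_true", default=False, help="Enable CSV output files generation.")
parser.add_argument("--csv-output-dir", type=str, default="benchmark_results",
                    help="Directory for CSV output files (default: benchmark_results).")

print(parser.parse_args([]))                                    # csv=False, csv_output_dir='benchmark_results'
print(parser.parse_args(["--csv", "--csv-output-dir", "out"]))  # csv=True, csv_output_dir='out'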
@@ -131,22 +346,124 @@ def import_from_path(module_name, file_path):
         raise ImportModuleException(f"failed to load python module: {e}")
 
 
+def create_database_connection():
+    """
+    Try to create a database connection. Returns None if connection fails.
+    """
+    if not PSYCOPG2_AVAILABLE:
+        logger.warning("psycopg2 not available - running in CSV-only mode")
+        return None
+
+    try:
+        import psycopg2
+
+        conn = psycopg2.connect("dbname=metrics")
+        logger.info("Successfully connected to database")
+        return conn
+    except Exception as e:
+        logger.warning(f"Failed to connect to database: {e}. Running in CSV-only mode")
+        return None
+
+
+def create_global_metrics_recorder(repository: str, branch: str, commit_id: str, commit_msg: str,
+                                   generate_csv: bool = False) -> MetricsRecorder:
+    """
+    Create a global metrics recorder that will be used across all benchmarks.
+    """
+    connection = create_database_connection()
+    recorder = MetricsRecorder(connection, logger, repository, branch, commit_id, commit_msg, generate_csv)
+
+    # Log the storage mode
+    storage_modes = []
+    if connection is not None:
+        storage_modes.append("database")
+    if generate_csv:
+        storage_modes.append("CSV")
+
+    if not storage_modes:
+        logger.warning("Running benchmarks with NO data storage (no database connection, CSV disabled)")
+        logger.warning("Use --csv flag to enable CSV output when database is unavailable")
+    else:
+        logger.info(f"Running benchmarks with: {' + '.join(storage_modes)} storage")
+
+    return recorder
+
+
 if __name__ == "__main__":
     benchmarks_folder_path = os.path.dirname(os.path.realpath(__file__))
+    benches_folder_path = os.path.join(benchmarks_folder_path, "benches")
 
-    repository, branch, commit_id, commit_msg = parse_arguments()
+    repository, branch, commit_id, commit_msg, generate_csv, csv_output_dir = parse_arguments()
 
-    for entry in os.scandir(benchmarks_folder_path):
-        try:
-            if not entry.name.endswith(".py"):
-                continue
-            if entry.path == __file__:
-                continue
-            logger.debug(f"loading: {entry.name}")
-            module = import_from_path(entry.name.split(".")[0], entry.path)
-            logger.info(f"running benchmarks in: {entry.name}")
-            module.run_benchmark(logger, repository, branch, commit_id, commit_msg)
-        except ImportModuleException as e:
-            logger.error(e)
-        except Exception as e:
-            logger.error(f"error running benchmarks for {entry.name}: {e}")
+    # Create a global metrics recorder
+    global_metrics_recorder = create_global_metrics_recorder(repository, branch, commit_id, commit_msg, generate_csv)
+
+    successful_benchmarks = 0
+    failed_benchmarks = 0
+
+    # Automatically discover all benchmark modules in benches/ folder
+    benchmark_modules = []
+
+    if os.path.exists(benches_folder_path):
+        logger.debug(f"Scanning for benchmarks in: {benches_folder_path}")
+        for entry in os.scandir(benches_folder_path):
+            if not entry.name.endswith(".py"):
+                continue
+            if entry.name.startswith("__"):  # Skip __init__.py, __pycache__, etc.
+                continue
+
+            # Check if the file has a run_benchmark function
+            try:
+                logger.debug(f"checking if benches/{entry.name} has run_benchmark function")
+                module = import_from_path(entry.name.split(".")[0], entry.path)
+                if hasattr(module, 'run_benchmark'):
+                    benchmark_modules.append(entry.name)
+                    logger.debug(f"discovered benchmark: {entry.name}")
+                else:
+                    logger.debug(f"skipping {entry.name} - no run_benchmark function found")
+            except Exception as e:
+                logger.debug(f"failed to check benches/{entry.name}: {e}")
+    else:
+        logger.warning(f"Benches directory not found: {benches_folder_path}")
+
+    if benchmark_modules:
+        logger.info(f"Discovered {len(benchmark_modules)} benchmark(s): {benchmark_modules}")
+    else:
+        logger.warning("No benchmark modules found in benches/ directory")
+
+    for module_name in benchmark_modules:
+        module_path = os.path.join(benches_folder_path, module_name)
+        try:
+            logger.debug(f"loading: {module_name}")
+            module = import_from_path(module_name.split(".")[0], module_path)
+            logger.info(f"running benchmarks in: {module_name}")
+
+            # Check if the module has an updated run_benchmark function that accepts metrics_recorder
+            try:
+                # Try the new signature first
+                module.run_benchmark(logger, repository, branch, commit_id, commit_msg, global_metrics_recorder)
+            except TypeError:
+                # Fall back to the old signature for backward compatibility
+                logger.warning(f"Module {module_name} using old run_benchmark signature - database connection will be created per module")
+                module.run_benchmark(logger, repository, branch, commit_id, commit_msg)
+
+            successful_benchmarks += 1
+        except ImportModuleException as e:
+            logger.error(e)
+            failed_benchmarks += 1
+        except Exception as e:
+            logger.error(f"error running benchmarks for {module_name}: {e}")
+            failed_benchmarks += 1
+
+    # Export CSV results at the end (if enabled)
+    try:
+        if generate_csv:
+            global_metrics_recorder.export_to_csv(csv_output_dir)
+            logger.info(f"CSV reports have been generated and saved to the {csv_output_dir} directory")
+        else:
+            logger.info("CSV generation disabled - no CSV files created (use --csv to enable)")
+
+        logger.info(f"Benchmark run completed. Successful: {successful_benchmarks}, Failed: {failed_benchmarks}")
+    except Exception as e:
+        logger.error(f"Failed to export CSV results: {e}")
+    finally:
+        global_metrics_recorder.close()
@@ -1,34 +0,0 @@ (deleted file)
CREATE TABLE IF NOT EXISTS benchmarks (
  benchmark_id SERIAL PRIMARY KEY,
  repository VARCHAR(255),
  branch VARCHAR(255),
  commit_id VARCHAR(72),
  commit_message VARCHAR(70),
  metadata jsonb,
  created_at timestamp without time zone NOT NULL DEFAULT (current_timestamp AT TIME ZONE 'UTC')
);

CREATE INDEX IF NOT EXISTS benchmarks_benchmark_id_idx ON benchmarks (benchmark_id);

CREATE INDEX IF NOT EXISTS benchmarks_branch_idx ON benchmarks (branch);

CREATE TABLE IF NOT EXISTS device_measurements (
  measurement_id SERIAL PRIMARY KEY,
  benchmark_id int REFERENCES benchmarks (benchmark_id),
  cpu_util double precision,
  mem_megabytes double precision,
  gpu_util double precision,
  gpu_mem_megabytes double precision,
  time timestamp without time zone NOT NULL DEFAULT (current_timestamp AT TIME ZONE 'UTC')
);

CREATE INDEX IF NOT EXISTS device_measurements_branch_idx ON device_measurements (benchmark_id);

CREATE TABLE IF NOT EXISTS model_measurements (
  measurement_id SERIAL PRIMARY KEY,
  benchmark_id int REFERENCES benchmarks (benchmark_id),
  measurements jsonb,
  time timestamp without time zone NOT NULL DEFAULT (current_timestamp AT TIME ZONE 'UTC')
);

CREATE INDEX IF NOT EXISTS model_measurements_branch_idx ON model_measurements (benchmark_id);
@@ -1,346 +0,0 @@ (deleted file)
from logging import Logger
import os
from threading import Event, Thread
from time import perf_counter, sleep
from typing import Optional
from benchmarks_entrypoint import MetricsRecorder
import gpustat
import psutil
import psycopg2
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, StaticCache


os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "1"
torch.set_float32_matmul_precision("high")


def collect_metrics(benchmark_id, continue_metric_collection, metrics_recorder):
    p = psutil.Process(os.getpid())
    while not continue_metric_collection.is_set():
        with p.oneshot():
            cpu_util = p.cpu_percent()
            mem_megabytes = p.memory_info().rss / (1024 * 1024)
        gpu_stats = gpustat.GPUStatCollection.new_query()
        gpu_util = gpu_stats[0]["utilization.gpu"]
        gpu_mem_megabytes = gpu_stats[0]["memory.used"]
        metrics_recorder.collect_device_measurements(
            benchmark_id, cpu_util, mem_megabytes, gpu_util, gpu_mem_megabytes
        )
        sleep(0.01)


def run_benchmark(
    logger: Logger, repository: str, branch: str, commit_id: str, commit_msg: str, num_tokens_to_generate=100
):
    continue_metric_collection = Event()
    metrics_thread = None
    model_id = "meta-llama/Llama-2-7b-hf"
    metrics_recorder = MetricsRecorder(
        psycopg2.connect("dbname=metrics"), logger, repository, branch, commit_id, commit_msg
    )
    try:
        gpu_stats = gpustat.GPUStatCollection.new_query()
        gpu_name = gpu_stats[0]["name"]
        benchmark_id = metrics_recorder.initialise_benchmark({"gpu_name": gpu_name, "model_id": model_id})
        logger.info(f"running benchmark #{benchmark_id} on {gpu_name} for {model_id}")
        metrics_thread = Thread(
            target=collect_metrics,
            args=[benchmark_id, continue_metric_collection, metrics_recorder],
        )
        metrics_thread.start()
        logger.info("started background thread to fetch device metrics")

        os.environ["TOKENIZERS_PARALLELISM"] = "false"  # silence warnings when compiling

        device = "cuda"

        logger.info("downloading weights")
        # This is to avoid counting download in model load time measurement
        model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)
        gen_config = GenerationConfig(do_sample=False, top_p=1, temperature=1)
        logger.info("loading model")
        start = perf_counter()
        model = AutoModelForCausalLM.from_pretrained(
            model_id, torch_dtype=torch.float16, generation_config=gen_config
        ).eval()
        model.to(device)
        torch.cuda.synchronize()
        end = perf_counter()
        model_load_time = end - start
        logger.info(f"loaded model in: {model_load_time}s")

        tokenizer = AutoTokenizer.from_pretrained(model_id)

        prompt = "Why dogs are so cute?"
        inputs = tokenizer(prompt, return_tensors="pt").to(device)

        # Specify the max length (including both the prompt and the response)
        # When calling `generate` with `cache_implementation="static"` later, this is also used to create a `StaticCache` object
        # with sequence length = `max_length`. The longer the more you will re-use it
        seq_length = inputs["input_ids"].shape[1]
        model.generation_config.max_length = seq_length + num_tokens_to_generate
        batch_size = inputs["input_ids"].shape[0]

        # Copied from the gpt-fast repo
        def multinomial_sample_one_no_sync(probs_sort):  # Does multinomial sampling without a cuda synchronization
            q = torch.empty_like(probs_sort).exponential_(1)
            return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)

        def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
            logits = logits / max(temperature, 1e-5)

            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                pivot = v.select(-1, -1).unsqueeze(-1)
                logits = torch.where(logits < pivot, -float("Inf"), logits)
            probs = torch.nn.functional.softmax(logits, dim=-1)
            return probs

        def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
            probs = logits_to_probs(logits[:, -1], temperature, top_k)
            idx_next = multinomial_sample_one_no_sync(probs)
            return idx_next, probs

        def decode_one_token(model, cur_token, cache_position, past_key_values):
            logits = model(
                cur_token,
                cache_position=cache_position,
                past_key_values=past_key_values,
                return_dict=False,
                use_cache=True,
            )[0]
            new_token = sample(logits, temperature=0.6, top_k=5)[0]
            return new_token

        #########
        # Eager #
        #########
        with torch.no_grad():
            past_key_values = StaticCache(
                model.config,
                max_batch_size=batch_size,
                device=device,
                dtype=torch.float16,
                max_cache_len=seq_length + num_tokens_to_generate,
            )
            cache_position = torch.arange(seq_length, device=device)
            start = perf_counter()
            model(
                **inputs,
                cache_position=cache_position,
                past_key_values=past_key_values,
                return_dict=False,
                use_cache=True,
            )
            end = perf_counter()
            first_eager_fwd_pass_time = end - start
            logger.info(f"completed first eager fwd pass in: {first_eager_fwd_pass_time}s")
            start = perf_counter()
            output = model.generate(**inputs, do_sample=False)
            end = perf_counter()
            first_eager_generate_time = end - start
            logger.info(f"completed first eager generation in: {first_eager_generate_time}s")
            logger.info(f"generated: {tokenizer.batch_decode(output.cpu().tolist())}")

            past_key_values = StaticCache(
                model.config,
                max_batch_size=batch_size,
                device=device,
                dtype=torch.float16,
                max_cache_len=seq_length + num_tokens_to_generate,
            )
            cache_position = torch.arange(seq_length, device=device)
            start = perf_counter()
            model(
                **inputs,
                cache_position=cache_position,
                past_key_values=past_key_values,
                return_dict=False,
                use_cache=True,
            )
            end = perf_counter()
            second_eager_fwd_pass_time = end - start
            logger.info(f"completed second eager fwd pass in: {second_eager_fwd_pass_time}s")
            start = perf_counter()
            model.generate(**inputs, do_sample=False)
            end = perf_counter()
            second_eager_generate_time = end - start
            logger.info(f"completed second eager generation in: {second_eager_generate_time}s")
            logger.info(f"generated: {tokenizer.batch_decode(output.cpu().tolist())}")

        torch.compiler.reset()

        ################
        # Forward pass #
        ################

        # `torch.compile(model, ...)` is not recommended as you compile callbacks
        # and full generate. We recommend compiling only the forward for now.
        # "reduce-overhead" will use cudagraphs.
        generated_ids = torch.zeros(
            (batch_size, num_tokens_to_generate + seq_length), dtype=torch.int, device=device
        )

        generated_ids[:, :seq_length] = inputs["input_ids"]
        decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True)
        # model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
        # TODO use decode_one_token(model, input_id.clone(), cache_position) for verification
        past_key_values = StaticCache(
            model.config,
            max_batch_size=batch_size,
            device=device,
            dtype=torch.float16,
            max_cache_len=seq_length + num_tokens_to_generate + 10,
        )
        cache_position = torch.arange(seq_length, device=device)
        all_generated_tokens = []
        ### First compile, prefill
        start = perf_counter()
        next_token = decode_one_token(
            model, inputs["input_ids"], cache_position=cache_position, past_key_values=past_key_values
        )
        torch.cuda.synchronize()
        end = perf_counter()
        time_to_first_token = end - start
        logger.info(f"completed first compile generation in: {time_to_first_token}s")
        cache_position += 1
        all_generated_tokens += next_token.tolist()

        cache_position = torch.tensor([seq_length], device=device)
        ### First compile, decoding
        start = perf_counter()
        next_token = decode_one_token(
            model, next_token.clone(), cache_position=cache_position, past_key_values=past_key_values
        )
        torch.cuda.synchronize()
        end = perf_counter()
        time_to_second_token = end - start
        logger.info(f"completed second compile generation in: {time_to_second_token}s")
        cache_position += 1
        all_generated_tokens += next_token.tolist()

        ### Second compile, decoding
        start = perf_counter()
        next_token = decode_one_token(
            model, next_token.clone(), cache_position=cache_position, past_key_values=past_key_values
        )
        torch.cuda.synchronize()
        end = perf_counter()
        time_to_third_token = end - start
        logger.info(f"completed third compile forward in: {time_to_third_token}s")
        cache_position += 1
        all_generated_tokens += next_token.tolist()

        ### Using cuda graphs decoding

        start = perf_counter()
        for _ in range(1, num_tokens_to_generate):
            all_generated_tokens += next_token.tolist()
            next_token = decode_one_token(
                model, next_token.clone(), cache_position=cache_position, past_key_values=past_key_values
            )
            cache_position += 1
        torch.cuda.synchronize()
        end = perf_counter()
        mean_time_to_next_token = (end - start) / num_tokens_to_generate
        logger.info(f"completed next compile generation in: {mean_time_to_next_token}s")
        logger.info(f"generated: {tokenizer.batch_decode(all_generated_tokens)}")

        ####################
        # Generate compile #
        ####################
        torch.compiler.reset()
        # we will not compile full generate as it' s to intensive, tho we measure full forward!

        past_key_values = StaticCache(
            model.config,
            max_batch_size=batch_size,
            device=device,
            dtype=torch.float16,
            max_cache_len=seq_length + 128,
        )

        # 1st call
        start = perf_counter()
        output = model.generate(**inputs, past_key_values=past_key_values)
        torch.cuda.synchronize()
        end = perf_counter()
        first_compile_generate_time = end - start
        logger.info(f"completed first compile generation in: {first_compile_generate_time}s")
        logger.info(f"generated: {tokenizer.batch_decode(output.cpu().tolist())}")

        past_key_values = StaticCache(
            model.config,
            max_batch_size=batch_size,
            device=device,
            dtype=torch.float16,
            max_cache_len=seq_length + 128,
        )
        # 2nd call
        start = perf_counter()
        output = model.generate(**inputs, past_key_values=past_key_values)
        torch.cuda.synchronize()
        end = perf_counter()
        second_compile_generate_time = end - start
        logger.info(f"completed second compile generation in: {second_compile_generate_time}s")
        logger.info(f"generated: {tokenizer.batch_decode(output.cpu().tolist())}")

        past_key_values = StaticCache(
            model.config,
            max_batch_size=batch_size,
            device=device,
            dtype=torch.float16,
            max_cache_len=seq_length + 128,
        )

        # 3rd call
        start = perf_counter()
        output = model.generate(**inputs, past_key_values=past_key_values)
        end = perf_counter()
        third_compile_generate_time = end - start
        logger.info(f"completed third compile generation in: {third_compile_generate_time}s")
        logger.info(f"generated: {tokenizer.batch_decode(output.cpu().tolist())}")

        past_key_values = StaticCache(
            model.config,
            max_batch_size=batch_size,
            device=device,
            dtype=torch.float16,
            max_cache_len=seq_length + 128,
        )
        # 4th call
        start = perf_counter()
        output = model.generate(**inputs, past_key_values=past_key_values)
        end = perf_counter()
        fourth_compile_generate_time = end - start
        logger.info(f"completed fourth compile generation in: {fourth_compile_generate_time}s")
        logger.info(f"generated: {tokenizer.batch_decode(output.cpu().tolist())}")

        metrics_recorder.collect_model_measurements(
            benchmark_id,
            {
                "model_load_time": model_load_time,
                "first_eager_forward_pass_time_secs": first_eager_fwd_pass_time,
                "second_eager_forward_pass_time_secs": second_eager_fwd_pass_time,
                "first_eager_generate_time_secs": first_eager_generate_time,
                "second_eager_generate_time_secs": second_eager_generate_time,
                "time_to_first_token_secs": time_to_first_token,
                "time_to_second_token_secs": time_to_second_token,
                "time_to_third_token_secs": time_to_third_token,
                "time_to_next_token_mean_secs": mean_time_to_next_token,
                "first_compile_generate_time_secs": first_compile_generate_time,
                "second_compile_generate_time_secs": second_compile_generate_time,
                "third_compile_generate_time_secs": third_compile_generate_time,
                "fourth_compile_generate_time_secs": fourth_compile_generate_time,
            },
        )
    except Exception as e:
        logger.error(f"Caught exception: {e}")
    continue_metric_collection.set()
    if metrics_thread is not None:
        metrics_thread.join()
    metrics_recorder.close()
@@ -2,4 +2,5 @@ gpustat==1.1.1
 psutil==6.0.0
 psycopg2==2.9.9
 torch>=2.4.0
 hf_transfer
+pandas>=1.5.0
benchmark/utils/init_db.sql (new file, +0)