Files
pytorch/benchmarks/instruction_counts/applications/ci.py
2025-06-17 13:12:18 +00:00

86 lines
2.4 KiB
Python

"""Collect instruction counts for continuous integration."""
# mypy: ignore-errors
import argparse
import hashlib
import json
import time
from typing import Union
from core.expand import materialize
from definitions.standard import BENCHMARKS
from execution.runner import Runner
from execution.work import WorkOrder
# Number of times each benchmark is replicated to measure run-to-run variance.
REPEATS = 5
# Per-work-order limits passed to WorkOrder below.
TIMEOUT = 600 # Seconds
RETRIES = 2
# VERSION is stamped into the output JSON; it is forced to -1 (not comparable)
# for subset runs or when the benchmark set no longer matches MD5.
VERSION = 0
# Expected MD5 digest of the ordered work-order keys; a mismatch means the
# benchmark definitions changed and historical comparisons are invalid.
MD5 = "4d55e8abf881ad38bb617a96714c1296"
def main(argv: list[str]) -> None:
    """Run the instruction-count benchmark suite and emit results as JSON.

    Args:
        argv: Command-line arguments (excluding the program name).
            --destination: Path where the JSON results are written. If omitted,
                the run is treated as a local debug run.
            --subset: Run only the first ten benchmarks (useful locally, since
                the full suite takes on the order of an hour end to end).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--destination", type=str, default=None)
    parser.add_argument("--subset", action="store_true")
    args = parser.parse_args(argv)

    t0 = int(time.time())
    version = VERSION
    benchmarks = materialize(BENCHMARKS)

    # Useful for local development, since e2e time for the full suite is O(1 hour)
    in_debug_mode = args.subset or args.destination is None
    if args.subset:
        # A subset run is never comparable to a full run; mark it invalid.
        version = -1
        benchmarks = benchmarks[:10]

    work_orders = tuple(
        WorkOrder(label, autolabels, timer_args, timeout=TIMEOUT, retries=RETRIES)
        for label, autolabels, timer_args in benchmarks * REPEATS
    )

    # dict preserves insertion order, so this dedupes the REPEATS-fold
    # replicated work orders while keeping a stable key order for the digest.
    keys = tuple({str(work_order): None for work_order in work_orders}.keys())
    md5 = hashlib.md5(usedforsecurity=False)
    for key in keys:
        md5.update(key.encode("utf-8"))

    # Warn early, since collection takes a long time.
    if md5.hexdigest() != MD5 and not args.subset:
        version = -1
        print(f"WARNING: Expected {MD5}, got {md5.hexdigest()} instead")

    results = Runner(work_orders, cadence=30.0).run()

    # TODO: Annotate with TypedDict when 3.8 is the minimum supported version.
    grouped_results: dict[str, dict[str, list[Union[float, int]]]] = {
        key: {"times": [], "counts": []} for key in keys
    }

    # Merge the REPEATS replications of each work order under a single key.
    for work_order, r in results.items():
        key = str(work_order)
        grouped_results[key]["times"].extend(r.wall_times)
        grouped_results[key]["counts"].extend(r.instructions)

    final_results = {
        "version": version,
        "md5": md5.hexdigest(),
        "start_time": t0,
        "end_time": int(time.time()),
        "values": grouped_results,
    }

    if args.destination:
        with open(args.destination, "w") as f:
            json.dump(final_results, f)

    if in_debug_mode:
        result_str = json.dumps(final_results)
        print(f"{result_str[:30]} ... {result_str[-30:]}\n")
        import pdb
        pdb.set_trace()