mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156077 Approved by: https://github.com/Skylion007, https://github.com/malfet ghstack dependencies: #156069
86 lines
2.4 KiB
Python
86 lines
2.4 KiB
Python
"""Collect instruction counts for continuous integration."""
|
|
|
|
# mypy: ignore-errors
|
|
|
|
import argparse
|
|
import hashlib
|
|
import json
|
|
import time
|
|
from typing import Union
|
|
|
|
from core.expand import materialize
|
|
from definitions.standard import BENCHMARKS
|
|
from execution.runner import Runner
|
|
from execution.work import WorkOrder
|
|
|
|
|
|
# Collection settings consumed by `main` when it builds the work orders.
# The materialized benchmark sequence is repeated this many times.
REPEATS: int = 5
# Passed as `timeout=` to every WorkOrder (seconds).
TIMEOUT: int = 600
# Passed as `retries=` to every WorkOrder.
RETRIES: int = 2

# Reported in the "version" field of the results; `main` reports -1 instead
# for `--subset` runs or when the MD5 check below fails.
VERSION: int = 0
# Expected md5 hex digest of the de-duplicated `str(work_order)` keys for a
# full (non-subset) run.
MD5: str = "4d55e8abf881ad38bb617a96714c1296"
|
|
|
|
|
|
def main(argv: list[str]) -> None:
    """Collect wall times and instruction counts and report them as JSON.

    Args:
        argv: Command-line arguments. Recognized flags are
            ``--destination PATH`` (file the JSON results are written to)
            and ``--subset`` (run only the first 10 materialized benchmarks).

    The run is treated as a debug run when ``--subset`` is passed or no
    destination is given; in that case a short preview of the JSON is printed
    and the function drops into ``pdb`` before returning.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--destination", type=str, default=None)
    cli.add_argument("--subset", action="store_true")
    options = cli.parse_args(argv)

    start_time = int(time.time())
    benchmarks = materialize(BENCHMARKS)

    # Useful for local development, since e2e time for the full suite is O(1 hour)
    in_debug_mode = options.subset or options.destination is None
    if options.subset:
        benchmarks = benchmarks[:10]

    work_orders = tuple(
        WorkOrder(label, autolabels, timer_args, timeout=TIMEOUT, retries=RETRIES)
        for label, autolabels, timer_args in benchmarks * REPEATS
    )

    # Repeats produce identical string keys; keep each once, in first-seen order.
    unique_keys = tuple(dict.fromkeys(str(order) for order in work_orders))
    fingerprint = hashlib.md5(
        "".join(unique_keys).encode("utf-8"), usedforsecurity=False
    ).hexdigest()

    # Warn early, since collection takes a long time.
    fingerprint_mismatch = fingerprint != MD5 and not options.subset
    if fingerprint_mismatch:
        print(f"WARNING: Expected {MD5}, got {fingerprint} instead")

    # Subset runs and runs with an unexpected fingerprint are reported as -1.
    version = -1 if (options.subset or fingerprint_mismatch) else VERSION

    results = Runner(work_orders, cadence=30.0).run()

    # TODO: Annotate with TypedDict when 3.8 is the minimum supported version.
    grouped_results: dict[str, dict[str, list[Union[float, int]]]] = {}
    for unique_key in unique_keys:
        grouped_results[unique_key] = {"times": [], "counts": []}

    # Merge the measurements of every repeat into the bucket for its key.
    for order, measurement in results.items():
        bucket = grouped_results[str(order)]
        bucket["times"].extend(measurement.wall_times)
        bucket["counts"].extend(measurement.instructions)

    final_results = {
        "version": version,
        "md5": fingerprint,
        "start_time": start_time,
        "end_time": int(time.time()),
        "values": grouped_results,
    }

    if options.destination:
        with open(options.destination, "w") as out_file:
            json.dump(final_results, out_file)

    if not in_debug_mode:
        return

    # Debug runs: show the head and tail of the payload, then hand control
    # to the debugger so the results can be inspected interactively.
    result_str = json.dumps(final_results)
    print(f"{result_str[:30]} ... {result_str[-30:]}\n")
    import pdb

    pdb.set_trace()
|