pytorch/benchmarks/dynamo/cachebench.py
Boyuan Feng 90b4e130d6 [Benchmark] cleanup torchbench models (#164816)
Prune models from the TorchInductor dashboard to reduce CI cost. This PR prunes torchbench models according to the [doc](https://docs.google.com/document/d/1nLPNNAU-_M9Clx9FMrJ1ycdPxe-xRA54olPnsFzdpoU/edit?tab=t.0), removing the timm and huggingface models from the torchbench suite.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164816
Approved by: https://github.com/anijain2305, https://github.com/seemethere, https://github.com/huydhn, https://github.com/malfet
2025-10-09 00:31:25 +00:00


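"""Benchmark cold-start vs. warm-start torch.compile (inductor) compile times.

For a fixed set of TorchBench and HuggingFace models, each model is compiled
twice per repeat: once with a fresh inductor cache (cold) and once reusing the
cache populated by the first run (warm). Compile times and the warm-vs-cold
speedup are written to a JSON file.

Example:
    python cachebench.py --benchmark torchbench --mode training --device cuda \
        --output results.json
"""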
import argparse
import dataclasses
import json
import logging
import os
import subprocess
import sys
import tempfile
from typing import Callable
from torch._inductor.utils import fresh_cache

logger: logging.Logger = logging.getLogger(__name__)

TIMEOUT: int = 2000

# Keep in sync with .ci/pytorch/test.sh
TORCHBENCH_MODELS: list[str] = [
"nanogpt",
"BERT_pytorch",
"resnet50",
"moco",
"llama",
]

HUGGINGFACE_MODELS: list[str] = [
"AllenaiLongformerBase",
"BertForMaskedLM",
"GPT2ForSequenceClassification",
]


@dataclasses.dataclass
class RunResult:
model: str
mode: str # inference or training
benchmark: str
dynamic: bool
device: str # cuda or cpu
    cold_compile_s: list[float]  # one measurement per --repeat iteration
    warm_compile_s: list[float]  # one measurement per --repeat iteration
    speedup_pct: float


def get_compile_time(file: tempfile._TemporaryFileWrapper) -> float:
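    """Parse the compilation_latency value from the benchmark's CSV output file."""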
lines = file.readlines()
# Decode from byte string, remove new lines, parse csv
lines = [line.decode("utf-8").strip().split(",") for line in lines]
compilation_time_idx = lines[0].index("compilation_latency")
compilation_time = lines[1][compilation_time_idx]
    return float(compilation_time)


def _run_torchbench_from_args(
cmd_args: argparse.Namespace,
model: str,
args: list[str],
) -> tuple[list[float], list[float]]:
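    """Run the benchmark subprocess `cmd_args.repeat` times, returning per-repeat
    cold-start (fresh cache) and warm-start (cache reuse) compile times."""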
cold_compile_time: list[float] = []
warm_compile_time: list[float] = []
for _ in range(cmd_args.repeat):
with fresh_cache():
env = os.environ.copy()
with tempfile.NamedTemporaryFile(suffix=".csv") as file:
args.append("--output=" + file.name)
logger.info(f"Performing cold-start run for {model}") # noqa: G004
subprocess.check_call(args, timeout=TIMEOUT, env=env)
cold_compile_time.append(get_compile_time(file))
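            # Drop this run's --output flag so the warm run can append its own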
            args.pop()

            with tempfile.NamedTemporaryFile(suffix=".csv") as file:
args.append("--output=" + file.name)
logger.info(f"Performing warm-start run for {model}") # noqa: G004
subprocess.check_call(args, timeout=TIMEOUT, env=env)
warm_compile_time.append(get_compile_time(file))
    return cold_compile_time, warm_compile_time


MODE_ARGS_DICT = {
"inference": ["--inference", "--bfloat16"],
"training": ["--training", "--amp"],
}

BENCHMARK_FILE = {
"torchbench": "torchbench.py",
"huggingface": "huggingface.py",
}


def _run_torchbench_model(
cmd_args: argparse.Namespace,
results: list[RunResult],
model: str,
) -> None:
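    """Benchmark one model and append a RunResult to `results`.

    Failures are logged and swallowed so the remaining models still run."""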
cur_file = os.path.abspath(__file__)
torchbench_file = os.path.join(
os.path.dirname(cur_file), BENCHMARK_FILE[cmd_args.benchmark]
)
assert os.path.exists(torchbench_file), (
f"Torchbench does not exist at {torchbench_file}"
)
dynamic = cmd_args.dynamic
dynamic_args = ["--dynamic-shapes", "--dynamic-batch-only"] if dynamic else []
args = (
[
sys.executable,
torchbench_file,
f"--only={model}",
"--repeat=1",
"--performance",
"--backend=inductor",
f"--device={cmd_args.device}",
]
+ MODE_ARGS_DICT[cmd_args.mode]
+ dynamic_args
)
logger.info(f"Command: {args}") # noqa: G004
try:
cold_compile_t, warm_compile_t = _run_torchbench_from_args(
cmd_args, model, args
)
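        # Percent reduction in compile time when the cache is warm vs. cold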
speedup_pct = (1 - (sum(warm_compile_t) / sum(cold_compile_t))) * 100
results.append(
RunResult(
model=model,
mode=cmd_args.mode,
benchmark=cmd_args.benchmark,
dynamic=dynamic,
device=cmd_args.device,
cold_compile_s=cold_compile_t,
warm_compile_s=warm_compile_t,
speedup_pct=speedup_pct,
)
)
except Exception:
logger.info("fail", exc_info=True)
        return None


def _write_results_to_json(
cmd_args: argparse.Namespace,
results: list[RunResult],
) -> None:
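    """Write one JSON record per (result, metric) pair to cmd_args.output."""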
if len(results) == 0:
# do not write empty results
return
records = []
for result in results:
for metric_name, value in [
("Cold compile time (s)", result.cold_compile_s),
("Warm compile time (s)", result.warm_compile_s),
("Speedup (%)", [result.speedup_pct]),
]:
records.append(
{
"benchmark": {
"name": "TorchCache Benchmark",
"mode": result.mode,
"extra_info": {
"is_dynamic": result.dynamic,
"device": result.device,
},
},
"model": {
"name": result.model,
"backend": "inductor",
"origins": [result.benchmark],
},
"metric": {
"name": metric_name,
"type": "OSS model",
"benchmark_values": value,
},
}
)
with open(cmd_args.output, "w") as f:
        json.dump(records, f)


def parse_cmd_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Run a TorchCache benchmark.")
parser.add_argument(
"-m",
"--model",
help="Name of the model to run",
)
parser.add_argument(
"--dynamic",
action="store_true",
help="Whether to run with dynamic enabled",
)
parser.add_argument(
"--benchmark",
choices=("torchbench", "huggingface"),
required=True,
help="Name of benchmark suite to run",
)
parser.add_argument(
"--mode",
choices=("inference", "training"),
default="training",
)
parser.add_argument(
"--device",
default="cuda",
choices=("cuda", "cpu"),
)
parser.add_argument(
"--output",
required=True,
help="The output filename (json)",
)
parser.add_argument(
"--repeat",
type=int,
default=1,
choices=range(1, 10),
help="Number of times to repeat the compilation (reduce noise)",
)
args, _ = parser.parse_known_args()
    return args


Dispatch_fn_t = Callable[[argparse.Namespace, list[RunResult], str], None]


def main() -> None:
cmd_args = parse_cmd_args()
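    # Both suites share the same runner; only the benchmark script and model list differ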
dispatcher: dict[str, tuple[Dispatch_fn_t, list[str]]] = {
"torchbench": (_run_torchbench_model, TORCHBENCH_MODELS),
"huggingface": (_run_torchbench_model, HUGGINGFACE_MODELS),
}
fn, models = dispatcher[cmd_args.benchmark]
if cmd_args.model is not None:
models = [cmd_args.model]
results: list[RunResult] = []
for model in models:
fn(cmd_args, results, model)
    _write_results_to_json(cmd_args, results)


if __name__ == "__main__":
main()