Update the operator benchmarking to benchmark using torch.compile (#161394)

This pull request enhances the PyTorch operator benchmarking suite by adding support for benchmarking with `torch.compile`, in addition to the existing Eager and JIT modes. It also adds peak memory measurement for the forward/backward pass, improves the JSON output format consumed by the dashboard for reporting, and introduces additional CLI options. The new CLI flags are:

- Added a `--use-compile` CLI argument and the corresponding logic to run benchmarks with `torch.compile`; the flag is mutually exclusive with `--use-jit`
- Added a `--benchmark-name` argument for customizing the benchmark name in the output
- Updated the default value of `--output-json-for-dashboard` to `benchmark-results.json` for a more predictable output file name

Sample command to run a single operator:
`python -m pt.mm_test --use-compile`
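A combined invocation that exercises the new flags (the benchmark name and output path below are only examples) could look like:
`python -m pt.mm_test --use-compile --benchmark-name mm_compile_benchmark --output-json-for-dashboard benchmark-results.json`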
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161394
Approved by: https://github.com/jbschlosser
jainapurva
2025-09-09 18:17:32 +00:00
committed by PyTorch MergeBot
parent 82f1eb9b03
commit af60398c3a
3 changed files with 154 additions and 29 deletions

View File

@@ -4,6 +4,7 @@ import csv
import functools
import json
import os
import platform
import timeit
from collections import namedtuple
from dataclasses import asdict, dataclass
@@ -191,6 +192,11 @@ class BenchmarkRunner:
self.predefined_minimum_secs = 1
self.max_iters = 1e6
self.use_jit = args.use_jit
self.use_compile = args.use_compile
if self.use_jit and self.use_compile:
raise ValueError(
"use_jit and use_compile are mutually exclusive, please specify one."
)
self.num_runs = args.num_runs
self.print_per_iter = False
self.output_csv = args.output_csv
@@ -222,7 +228,7 @@ class BenchmarkRunner:
if self.args.operators:
print(f"# {self.args.operators}")
def _print_perf_result(self, reported_run_time_us, test_case):
def _print_perf_result(self, results, test_case):
if self.args.report_aibench:
# Output for AIBench
# Print out per iteration execution time instead of avg time
@@ -236,12 +242,14 @@ class BenchmarkRunner:
"type": test_name,
"metric": "latency",
"unit": "us",
"value": str(reported_run_time_us[run]),
"value": str(results["reported_run_time_us"[run]]),
}
)
)
else:
print(f"# Mode: {'JIT' if self.use_jit else 'Eager'}")
print(
f"# Mode: {'JIT' if self.use_jit else 'Compile' if self.use_compile else 'Eager'}"
)
print(
f"# Name: {test_case.test_config.test_name}\n# Input: {test_case.test_config.input_config}"
)
@@ -250,25 +258,33 @@ class BenchmarkRunner:
if self.num_runs > 1:
for run in range(self.num_runs):
print(
f"Run: {run}, {mode} Execution Time (us) : {reported_run_time_us[run]:.3f}"
f"Run: {run}, {mode} Execution Time (us) : {results['reported_run_time_us'][run]:.3f}"
)
print()
else:
print(f"{mode} Execution Time (us) : {reported_run_time_us[0]:.3f}\n")
print(
f"{mode} Execution Time (us) : {results['reported_run_time_us'][0]:.3f}"
)
print(f"Peak Memory (KB) : {results['peak_memory']}\n")
def _perf_result_to_dict(self, reported_run_time_us, test_case):
def _perf_result_to_dict(self, results, test_case):
"""This function is the parallel of _print_perf_result, which instead of
writing information to terminal, returns a dictionary.
"""
if self.args.report_aibench:
return {}
out = {
"test_name": test_case.test_config.test_name,
"input_config": test_case.test_config.input_config,
"mode": "JIT" if self.use_jit else "Eager",
"runtime": (
"JIT" if self.use_jit else "Compile" if self.use_compile else "Eager"
),
"run": "Backward" if test_case.test_config.run_backward else "Forward",
"latency": round(reported_run_time_us[0], 3),
"latency": round(results["reported_run_time_us"][0], 3),
"latency unit": "us",
"peak memory": results["peak_memory"],
"memory unit": "KB",
}
# parsing test_case.test_config.input_config, adding it as entries to the 'out' dictionary
@@ -330,6 +346,8 @@ class BenchmarkRunner:
func = test_case.run_forward
if self.use_jit:
func = test_case.run_jit_forward
if self.use_compile:
func = test_case.run_compile_forward
forward_time = timeit.timeit(
functools.partial(func, iters, print_per_iter, cuda_sync), number=1
)
@@ -346,7 +364,7 @@ class BenchmarkRunner:
)
return backward_time
def _measure_time(self, launch_test, test_case, iters, print_per_iter):
def _measure_metrics(self, launch_test, test_case, iters, print_per_iter):
"""
This function executes the operator for <iters> iterations and then looks at the time.
If the result is not significant, the number of iterations is increased before rerunning.
@@ -354,8 +372,20 @@ class BenchmarkRunner:
"""
curr_test_total_time = 0
time_trace = []
peak_memory = 0
sample_input = next(iter(test_case.op_bench.inputs.values()))
device = sample_input.device
device_module = torch.get_device_module(device.type)
# TODO: add support for cpu memory measurement
while True:
if hasattr(device_module, "reset_peak_memory_stats"):
device_module.reset_peak_memory_stats(device)
run_time_sec = launch_test(test_case, iters, print_per_iter)
if hasattr(device_module, "synchronize"):
device_module.synchronize(device)
# Memory measurement process
if hasattr(device_module, "max_memory_allocated"):
peak_memory = device_module.max_memory_allocated(device)
curr_test_total_time += run_time_sec
# Analyze time after each run to decide if the result is stable
results_are_significant = self._iteration_result_is_significant(
@@ -369,7 +399,13 @@ class BenchmarkRunner:
time_trace.append(report_run_time)
# Print out the time spent in each epoch in ms
if self.args.report_aibench:
mode = "JIT" if self.use_jit else "Eager"
mode = (
"JIT"
if self.use_jit
else "Compile"
if self.use_compile
else "Eager"
)
test_name = "_".join(
[test_case.framework, test_case.test_config.test_name, mode]
)
@@ -381,7 +417,7 @@ class BenchmarkRunner:
"metric": "latency",
"unit": "ms",
"value": str(report_run_time / 1e3),
}
},
)
)
if results_are_significant:
@@ -391,7 +427,7 @@ class BenchmarkRunner:
# iteration count, and run the benchmark again...
iters = self._predict_num_iter_needed(iters)
reported_run_time_us = np.percentile(np.array(time_trace), 50)
return reported_run_time_us
return reported_run_time_us, peak_memory / 1024
def _check_keep(self, test_flag, cmd_flag):
return cmd_flag is None or test_flag == cmd_flag
@@ -478,6 +514,7 @@ class BenchmarkRunner:
self,
perf_list,
output_file,
benchmark_name="PyTorch operator benchmark",
):
"""
Write the result into JSON format, so that it can be uploaded to the benchmark database
@@ -495,8 +532,10 @@ class BenchmarkRunner:
input_config = perf_item.get("input_config", "")
run_type = perf_item.get("run")
latency = perf_item.get("latency", 0)
dtype = "float32" # default
peak_memory = perf_item.get("peak memory", 0)
device = perf_item.get("device", "unknown")
dtype = perf_item.get("dtype", "torch.float").split(".")[1]
runtime = perf_item.get("runtime", None)
# Extract mode based on run_type
mode = None
@@ -505,6 +544,22 @@ class BenchmarkRunner:
elif run_type == "Backward":
mode = "training"
# Extract use_compile from the runtime field
if runtime == "Compile":
use_compile = True
elif runtime == "Eager":
use_compile = False
else:
use_compile = None
device_arch = (
torch.cuda.get_device_name(0)
if device == "cuda"
else platform.processor()
if device == "cpu"
else "unknown"
)
# Create the record
@dataclass
class BenchmarkInfo:
@@ -532,12 +587,18 @@ class BenchmarkRunner:
model: ModelInfo
metric: MetricInfo
record = BenchmarkRecord(
# Add record for latency
record_latency = BenchmarkRecord(
benchmark=BenchmarkInfo(
name="PyTorch operator benchmark",
name=benchmark_name,
mode=mode,
dtype=dtype,
extra_info={"input_config": input_config},
extra_info={
"input_config": input_config,
"device": device,
"arch": device_arch,
"use_compile": use_compile,
},
),
model=ModelInfo(
name=test_name, type="micro-benchmark", origins=["pytorch"]
@@ -549,8 +610,17 @@ class BenchmarkRunner:
target_value=None,
),
)
records.append(asdict(record_latency))
records.append(asdict(record))
# Add record for peak memory
record_memory = copy.deepcopy(record_latency)
record_memory.metric = MetricInfo(
name="peak memory",
unit="KB",
benchmark_values=[peak_memory],
target_value=None,
)
records.append(asdict(record_memory))
# Write all records to the output file
with open(output_file, "w", encoding="utf-8") as f:
@@ -566,6 +636,7 @@ class BenchmarkRunner:
"tag",
"run_backward",
"Execution Time",
"Peak Memory (KB)",
]
if self.args.output_json or self.args.output_json_for_dashboard:
@@ -603,13 +674,16 @@ class BenchmarkRunner:
test_case, self.args.warmup_iterations, print_per_iter=False
)
# Actual Execution
reported_time = [
self._measure_time(
results = [
self._measure_metrics(
launch_func, test_case, self.iters, self.print_per_iter
)
for _ in range(self.num_runs)
]
self._print_perf_result(reported_time, test_case)
result_dict = dict()
result_dict["reported_run_time_us"] = [r[0] for r in results]
result_dict["peak_memory"] = results[0][1]
self._print_perf_result(results=result_dict, test_case=test_case)
# output results to csv
self._output_csv(
@@ -625,16 +699,17 @@ class BenchmarkRunner:
),
test_case.test_config.tag,
test_case.test_config.run_backward,
reported_time[0],
result_dict["reported_run_time_us"][0],
result_dict["peak_memory"],
],
)
if self.args.output_json or self.args.output_json_for_dashboard:
perf_list.append(
self._perf_result_to_dict(reported_time, test_case)
)
perf_list.append(self._perf_result_to_dict(result_dict, test_case))
if self.args.output_json_for_dashboard:
self._output_json(perf_list, self.args.output_json_for_dashboard)
self._output_json(
perf_list, self.args.output_json_for_dashboard, self.args.benchmark_name
)
if self.args.output_json:
with open(self.args.output_json, "w") as f:
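The peak-memory measurement added in `_measure_metrics` above follows the generic device-module accessor pattern rather than hard-coding CUDA calls. A minimal standalone sketch of that pattern (with a placeholder `run_benchmark` callable; not the benchmark's own helper) is:

```python
import torch

def measure_peak_memory_kb(run_benchmark, device: torch.device) -> float:
    # Sketch of the accessor pattern used in _measure_metrics; run_benchmark is a placeholder.
    device_module = torch.get_device_module(device.type)   # e.g. torch.cuda for a "cuda" device
    if hasattr(device_module, "reset_peak_memory_stats"):
        device_module.reset_peak_memory_stats(device)       # clear the allocator's high-water mark
    run_benchmark()                                          # run the op under test
    if hasattr(device_module, "synchronize"):
        device_module.synchronize(device)                    # wait for asynchronous kernels
    peak_bytes = 0
    if hasattr(device_module, "max_memory_allocated"):       # CPU has no such hook yet (see TODO above)
        peak_bytes = device_module.max_memory_allocated(device)
    return peak_bytes / 1024                                 # reported in KB, matching the CSV/JSON output
```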

View File

@@ -4,6 +4,15 @@ import time
import torch
# Import the C++ extension to register the _consume operator
try:
import benchmark_cpp_extension # noqa: F401
except ImportError as err:
# If the extension isn't built, the script must raise an error
raise ImportError(
"Failed to import C++ extension, please build it using \ncd pt_extension \npython -m pip install ."
) from err
"""PyTorch performance microbenchmarks.
This module contains PyTorch-specific functionalities for performance
@@ -71,6 +80,16 @@ class TorchBenchmarkBase(torch.nn.Module):
for _ in range(iters):
torch.ops.operator_benchmark._consume(self.forward_impl())
def forward_impl_eager(self):
# Supplies the inputs to the forward function, which is called in both
# the eager and compile modes of local runs
return self.forward(*self.get_inputs())
def forward_consume_eager(self, iters: int):
# Eager version of forward_consume without decorators (compilation handled by torch.compile)
for _ in range(iters):
torch.ops.operator_benchmark._consume(self.forward_impl_eager())
def module_name(self):
"""this is used to label the operator being benchmarked"""
if self.user_given_name:
@@ -117,18 +136,32 @@ class PyTorchOperatorTestCase:
self.framework = "PyTorch"
self.time_series = []
self._jit_forward_graph = None
self._compile_forward_graph = None
def _generate_jit_forward_graph(self):
"""generate a graph for the forward function via scripting"""
scripted_op_bench = torch.jit.script(self.op_bench)
return scripted_op_bench.forward_consume
def _generate_compile_forward_graph(self):
"""generate a compiled graph for the forward function via torch.compile"""
compiled_forward_consume = torch.compile(
self.op_bench.forward_consume_eager, backend="inductor"
)
return compiled_forward_consume
def run_jit_forward(self, num_runs, print_per_iter=False, cuda_sync=False):
"""Run the forward path of an op with JIT mode"""
if self._jit_forward_graph is None:
self._jit_forward_graph = self._generate_jit_forward_graph()
self._jit_forward_graph(num_runs)
def run_compile_forward(self, num_runs, print_per_iter=False, cuda_sync=False):
"""Run the forward path of an op with compile mode"""
if self._compile_forward_graph is None:
self._compile_forward_graph = self._generate_compile_forward_graph()
self._compile_forward_graph(num_runs)
def _print_per_iter(self):
# print last 50 values
length = min(len(self.time_series), 50)
@@ -150,14 +183,14 @@ class PyTorchOperatorTestCase:
if print_per_iter:
for _ in range(num_runs):
start_time = time.time()
self.output = self.op_bench.forward_impl()
self.output = self.op_bench.forward_impl_eager()
if cuda_sync:
torch.cuda.synchronize(torch.cuda.current_device())
end_time = time.time()
self.time_series.append((end_time - start_time) * 1e3)
else:
for _ in range(num_runs):
self.output = self.op_bench.forward_impl()
self.output = self.op_bench.forward_impl_eager()
if cuda_sync:
torch.cuda.synchronize(torch.cuda.current_device())
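The compile path added above lazily wraps the eager consume loop with `torch.compile` and caches the compiled callable across runs. A reduced sketch of that wrapping with a toy operator (illustrative only; the real benchmark routes outputs through `torch.ops.operator_benchmark._consume`) is:

```python
import torch

class ToyBench:
    # Stand-in for a TorchBenchmarkBase subclass with a trivial matmul forward.
    def __init__(self):
        self.a = torch.randn(64, 64)
        self.b = torch.randn(64, 64)

    def forward_impl_eager(self):
        return self.a @ self.b

    def forward_consume_eager(self, iters: int):
        # Plain eager loop; graph capture happens when this method is wrapped below.
        for _ in range(iters):
            self.forward_impl_eager()

bench = ToyBench()
compiled_consume = torch.compile(bench.forward_consume_eager, backend="inductor")
compiled_consume(10)  # first call compiles; subsequent calls reuse the cached graph
```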

View File

@@ -62,6 +62,13 @@ def parse_args():
default=None,
)
parser.add_argument(
"--benchmark-name",
"--benchmark_name",
help="Name of the benchmark to store results to",
default="PyTorch operator benchmark",
)
parser.add_argument(
"--list-tests",
"--list_tests",
@@ -135,6 +142,16 @@ def parse_args():
help="Run operators with PyTorch JIT mode",
)
parser.add_argument(
"--use-compile",
"--use_compile",
type=benchmark_utils.str2bool,
nargs="?",
const=True,
default=False,
help="Run operators with PyTorch Compile mode",
)
parser.add_argument(
"--forward-only",
"--forward_only",
@@ -162,7 +179,7 @@ def parse_args():
"--output-json-for-dashboard",
"--output_json_for_dashboard",
help="Save results in JSON format for display on the OSS dashboard",
default="False",
default="benchmark-results.json",
)
args, _ = parser.parse_known_args()
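`--use-compile` follows the same optional-value pattern as `--use-jit`: the bare flag enables the mode via `const=True`, an explicit value is converted by `benchmark_utils.str2bool`, and omitting the flag keeps the default of `False`. A self-contained sketch of that argparse behaviour (with a stand-in `str2bool`, since the real helper lives in `benchmark_utils`) is:

```python
import argparse

def str2bool(v):
    # Stand-in for benchmark_utils.str2bool; assumed to accept the usual truthy strings.
    return str(v).lower() in ("yes", "true", "t", "1")

parser = argparse.ArgumentParser()
parser.add_argument(
    "--use-compile", type=str2bool, nargs="?", const=True, default=False
)

print(parser.parse_args([]).use_compile)                          # False: flag omitted
print(parser.parse_args(["--use-compile"]).use_compile)           # True: bare flag uses const
print(parser.parse_args(["--use-compile", "false"]).use_compile)  # False: value parsed by str2bool
```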