mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[cutlass backend] Add more logs for cutlass backend benchmark (#150639)
Goal is to have a way to compare if a change make it better or worse. ``` Average edge over aten (max(-edge, 0), higher is better): triton: 8.596507086950552 (from 6 valid values) triton_persistent_tma: 9.517193693923307 (from 6 valid values) cutlass_lvl_default: 3.3234737908691785 (from 6 valid values) cutlass_lvl_1111: 7.088173348313991 (from 6 valid values) cutlass_lvl_2222: 7.291869722320318 (from 6 valid values) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/150639 Approved by: https://github.com/ColinPeppler
This commit is contained in:
committed by
PyTorch MergeBot
parent
48b4bc1640
commit
5a51de5ab1
@ -4,6 +4,7 @@ import os
|
||||
os.environ["TORCH_LOGS"] = "inductor"
|
||||
|
||||
import itertools
|
||||
import logging
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from collections import defaultdict
|
||||
@ -18,6 +19,9 @@ import torch
|
||||
from torch._inductor import config as inductor_config
|
||||
|
||||
|
||||
log: logging.Logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
inductor_config.autotune_num_choices_displayed = None
|
||||
# force autotuning, but reuse compilation artifacts
|
||||
inductor_config.autotune_local_cache = False
|
||||
@ -30,6 +34,7 @@ UNITS = {
|
||||
"forward_time": " (us)",
|
||||
"compilation_time": " (s)",
|
||||
}
|
||||
PERF_OVER_ATEN_STR: str = "perf_over_aten (%)"
|
||||
|
||||
OP_NAMES = ["mm"]
|
||||
|
||||
@ -191,7 +196,7 @@ def run_single_experiment_group(
|
||||
try:
|
||||
_ = compiled_op(A, B)
|
||||
except Exception as e:
|
||||
print(f"Benchmark config {config.name()} failed: {e}")
|
||||
log.warning(f"Benchmark config {config.name()} failed: {e}") # noqa: G004
|
||||
results.append(
|
||||
ExperimentResults(
|
||||
name=config.name(),
|
||||
@ -275,10 +280,9 @@ def generate_experiment_configs(
|
||||
return configs
|
||||
|
||||
|
||||
def tabulate_group_results(results: list[ExperimentResults]):
|
||||
def calculate_table_data(results: list[ExperimentResults]) -> dict:
|
||||
table_data = defaultdict(list)
|
||||
aten_perf: Optional[float] = None
|
||||
perf_over_aten_str: str = "perf_over_aten (%)"
|
||||
|
||||
for experiment_result in results:
|
||||
for key, value in experiment_result.asdict().items():
|
||||
@ -287,24 +291,43 @@ def tabulate_group_results(results: list[ExperimentResults]):
|
||||
|
||||
if experiment_result.name == "aten":
|
||||
aten_perf = experiment_result.forward_time
|
||||
table_data[perf_over_aten_str].append("NA")
|
||||
table_data[PERF_OVER_ATEN_STR].append("NA")
|
||||
elif aten_perf is not None:
|
||||
perf_over_aten = (
|
||||
(experiment_result.forward_time - aten_perf) / aten_perf * 100
|
||||
)
|
||||
table_data[perf_over_aten_str].append(perf_over_aten)
|
||||
table_data[PERF_OVER_ATEN_STR].append(perf_over_aten)
|
||||
else:
|
||||
# fallback in case aten is not in experiment group
|
||||
table_data[perf_over_aten_str].append("NA")
|
||||
table_data[PERF_OVER_ATEN_STR].append("NA")
|
||||
|
||||
return tabulate(table_data, headers="keys", tablefmt="pretty", floatfmt=".3f")
|
||||
return table_data
|
||||
|
||||
|
||||
def print_results(experiment_groups: list[ExperimentGroup]):
|
||||
edge_over_aten = defaultdict(list)
|
||||
for experiment_group in experiment_groups:
|
||||
group_config_name = experiment_group.config.name()
|
||||
print(f"\nExperiment group: {group_config_name}")
|
||||
print(tabulate_group_results(experiment_group.results))
|
||||
|
||||
table_data = calculate_table_data(experiment_group.results)
|
||||
for name, edge in zip(table_data["name"], table_data[PERF_OVER_ATEN_STR]):
|
||||
edge_over_aten[name].append(edge)
|
||||
print(tabulate(table_data, headers="keys", tablefmt="pretty", floatfmt=".3f"))
|
||||
|
||||
if "aten" in edge_over_aten:
|
||||
print("\nAverage edge over aten (max(-edge, 0), higher is better):")
|
||||
for name in edge_over_aten:
|
||||
if name != "aten":
|
||||
# calculate average edge over aten, but need to exclude inf
|
||||
values = [
|
||||
max(-v, 0.0)
|
||||
for v in edge_over_aten[name]
|
||||
if v != float("inf") and v != "NA"
|
||||
]
|
||||
valid_count = len(values)
|
||||
average_edge = sum(values) / valid_count if values else "No valid data"
|
||||
print(f"{name}: {average_edge} (from {valid_count} valid values)")
|
||||
|
||||
|
||||
def main():
|
||||
|
Reference in New Issue
Block a user