[cutlass backend] add teraflops and increase rep for benchmark script (#154944)

Differential Revision: [D75840023](https://our.internmc.facebook.com/intern/diff/D75840023/)

I think I will continue to use do_bench for now.
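
For context, the new teraflops column is just the op's FLOP count divided by the measured forward time. A minimal sketch of that arithmetic (the shape and timing below are made-up illustration values, not measurements from this PR):

    # Hypothetical numbers, only to illustrate the TFLOPS column.
    M = N = K = 4096
    flops = 2 * M * N * K                    # mm FLOP count, ~1.37e11
    forward_time_us = 500.0                  # forward time from do_bench, in microseconds
    teraflops = flops / (forward_time_us * 1e-6) / 1e12
    print(f"{teraflops:.1f} TFLOPS")         # ~274.9 TFLOPS

If do_bench here is triton.testing.do_bench, warmup and rep are time budgets in milliseconds, so warmup=100, rep=10000 spends roughly 10 seconds of timed runs per candidate instead of the default ~0.1 s.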
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154944
Approved by: https://github.com/mlazos
Author:    henrylhtsang
Date:      2025-06-04 14:26:12 -07:00
Committer: PyTorch MergeBot
Parent:    be2ab96347
Commit:    2481c4b2ea


@@ -1,4 +1,5 @@
 import os
+import sys
 
 os.environ["TORCH_LOGS"] = "inductor"
@@ -32,6 +33,7 @@ inductor_config.autotune_local_cache = False
 UNITS = {
     "name": "",
     "forward_time": " (us)",
+    "teraflops": " (TFLOPS)",
     "compilation_time": " (s)",
 }
 PERF_OVER_ATEN_STR: str = "perf_over_aten (%)"
@@ -75,7 +77,7 @@ CUTLASS_INSTANTIATION_LEVELS = [
 def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
-    return do_bench(lambda: func(*args, **kwargs)) * 1e3
+    return do_bench(lambda: func(*args, **kwargs), warmup=100, rep=10000) * 1e3
 
 
 @dataclass(frozen=True, kw_only=True)
@@ -162,6 +164,7 @@ class ExperimentGroupConfig:
 class ExperimentResults:
     name: str
     forward_time: float
+    teraflops: float
     compilation_time: float
 
     def asdict(self):
@@ -211,7 +214,10 @@ def run_single_experiment_group(
     for config in group_config.experiments:
         torch._dynamo.reset()
         torch._inductor.utils.clear_inductor_caches()
-        compiled_op = torch.compile(op, fullgraph=True, options=config.to_options())
+        compiled_op = torch.compile(
+            op,
+            options=config.to_options(),
+        )
 
         start_time = time.perf_counter()
         try:
@@ -227,6 +233,7 @@
                 ExperimentResults(
                     name=config.name(),
                     forward_time=float("inf"),
+                    teraflops=0.0,
                     compilation_time=float("inf"),
                 )
             )
@@ -238,10 +245,18 @@
             *inputs,
         )
+        flops = calculate_flops(
+            group_config.op_name,
+            group_config.shape,
+            group_config.batch_size,
+        )
+        teraflops = flops / (forward_time * 1e-6) / 1e12
+
         results.append(
             ExperimentResults(
                 name=config.name(),
                 forward_time=forward_time,
+                teraflops=teraflops,
                 compilation_time=compilation_time,
             )
         )
@@ -336,6 +351,20 @@ def calculate_table_data(results: list[ExperimentResults]) -> dict:
     return table_data
 
 
+def calculate_flops(op_name: str, shape: tuple[int, int, int], batch_size: int) -> int:
+    """
+    Calculate the number of floating point operations based on operation type and shape.
+    """
+    M, N, K = shape
+    if op_name == "bmm":
+        return 2 * batch_size * M * N * K
+    elif op_name == "addmm":
+        return 2 * M * N * K + M * N
+    else:
+        return 2 * M * N * K
+
+
 def get_printable_results(experiment_groups: list[ExperimentGroup]) -> list[str]:
     edge_over_aten = defaultdict(list)
     output = []
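
A quick sanity check of the FLOP formulas above, assuming the calculate_flops helper added in this hunk (shapes are arbitrary illustration values):

    # mm ignores batch_size, bmm scales by it, addmm adds M*N for the bias add.
    assert calculate_flops("mm", (4096, 4096, 4096), 1) == 2 * 4096**3
    assert calculate_flops("bmm", (1024, 1024, 1024), 8) == 8 * 2 * 1024**3
    assert calculate_flops("addmm", (2048, 2048, 2048), 1) == 2 * 2048**3 + 2048 * 2048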
@@ -390,8 +419,10 @@ def main():
         results.append(
             ExperimentGroup(config=group_config, results=group_results),
         )
-        log.info(f"\nINTERMEDIATE results: {i}/{len(configs)}")  # noqa: G004
-        log.info(get_printable_results(results))
+        sys.stderr.write(
+            f"\nINTERMEDIATE results: {i + 1}/{len(configs)} \n"
+            + get_printable_results(results)
+        )
 
     print("\nFINAL results...")
     print(get_printable_results(results))