pytorch/benchmarks/inductor_backends/cutlass.py
Henry Tsang 17518007b2 [cutlass backend] Benchmark compared to aten and triton (#148347)
Benchmark for the cutlass backend, compared against the aten and triton backends.

```
python benchmarks/inductor_backends/cutlass.py
```
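For each (op, shape, dtype) group, the script compiles the op once per backend, with max-autotune restricted to that backend, and then times the compiled kernel. A minimal sketch of a single experiment (the option keys are the same inductor config names the script passes; the shape and dtype are one point from the sweep):

```
import torch

A = torch.randn(1024, 1024, dtype=torch.float16, device="cuda")
B = torch.randn(1024, 1024, dtype=torch.float16, device="cuda")

# compile torch.mm with autotuning restricted to a single GEMM backend
compiled_mm = torch.compile(
    torch.mm,
    fullgraph=True,
    options={
        "max_autotune": True,
        "max_autotune_gemm_backends": "CUTLASS",  # or "ATEN" / "TRITON"
    },
)

out = compiled_mm(A, B)  # the first call pays the autotuning/compilation cost
```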

Test Plan:
```
Experiment group: mm (1024x1024, 1024x1024) torch.float16
+-----------------------+--------------------+----------------------+---------------------+
|         name          | forward_time (us)  | compilation_time (s) | perf_over_aten (%)  |
+-----------------------+--------------------+----------------------+---------------------+
|         aten          | 12.759539298713207 |  2.7271360370796174  |         NA          |
|        triton         | 10.573655366897583 |  1.8661278090439737  | -17.131370346859384 |
| triton_persistent_tma | 10.884030722081661 |  0.5315794269554317  | -14.698873781600327 |
|  cutlass_lvl_default  | 13.09632882475853  |  0.5520401500398293  | 2.6395116481931873  |
|   cutlass_lvl_1111    | 11.05172373354435  |  0.569593315012753   | -13.384617776451302 |
|   cutlass_lvl_2222    | 11.371277272701263 |  133.58984916994814  | -10.880189272601317 |
+-----------------------+--------------------+----------------------+---------------------+

Experiment group: mm (1024x1024, 1024x1024) torch.bfloat16
+-----------------------+--------------------+----------------------+---------------------+
|         name          | forward_time (us)  | compilation_time (s) | perf_over_aten (%)  |
+-----------------------+--------------------+----------------------+---------------------+
|         aten          | 14.472318813204765 |  1.5445372510002926  |         NA          |
|        triton         | 10.568295605480671 |  16.583424195996486  | -26.975796056689987 |
| triton_persistent_tma | 10.45411266386509  |  5.830657540936954   | -27.764770809729562 |
|  cutlass_lvl_default  | 12.742593884468079 |  28.994930602959357  | -11.951954286402668 |
|   cutlass_lvl_1111    | 11.522261425852776 |  79.85037935699802   | -20.38413764531163  |
|   cutlass_lvl_2222    | 10.993581265211105 |  132.86601971101481  | -24.037181552548486 |
+-----------------------+--------------------+----------------------+---------------------+

Experiment group: mm (2048x2048, 2048x2048) torch.float16
+-----------------------+--------------------+----------------------+---------------------+
|         name          | forward_time (us)  | compilation_time (s) | perf_over_aten (%)  |
+-----------------------+--------------------+----------------------+---------------------+
|         aten          | 30.700622126460075 |  2.225986961973831   |         NA          |
|        triton         | 29.17378954589367  |  38.571991189033724  |  -4.97329524553989  |
| triton_persistent_tma | 29.642896726727486 |   7.2848734309664    | -3.4452897904663744 |
|  cutlass_lvl_default  | 29.514770954847336 |  29.819900761009194  | -3.8626291243482167 |
|   cutlass_lvl_1111    | 29.411429539322853 |  23.82907024596352   |  -4.19923929172139  |
|   cutlass_lvl_2222    | 29.57325428724289  |  134.31008586101234  | -3.672133530628152  |
+-----------------------+--------------------+----------------------+---------------------+

Experiment group: mm (2048x2048, 2048x2048) torch.bfloat16
+-----------------------+--------------------+----------------------+--------------------+
|         name          | forward_time (us)  | compilation_time (s) | perf_over_aten (%) |
+-----------------------+--------------------+----------------------+--------------------+
|         aten          | 30.858177691698074 |  1.181898436974734   |         NA         |
|        triton         | 28.630023822188377 |  39.24473957403097   | -7.220626868414034 |
| triton_persistent_tma | 28.641965240240097 |  5.275042273919098   | -7.181929126210897 |
|  cutlass_lvl_default  | 29.16003204882145  |  29.934022572939284  | -5.503065216107967 |
|   cutlass_lvl_1111    | 28.79570797085762  |  23.948012012057006  | -6.683705504085324 |
|   cutlass_lvl_2222    | 29.02756631374359  |  136.25560767308343  | -5.932337924306467 |
+-----------------------+--------------------+----------------------+--------------------+

Experiment group: mm (8192x8192, 8192x8192) torch.float16
+-----------------------+--------------------+----------------------+--------------------+
|         name          | forward_time (us)  | compilation_time (s) | perf_over_aten (%) |
+-----------------------+--------------------+----------------------+--------------------+
|         aten          | 1456.143856048584  |  1.020197194069624   |         NA         |
|        triton         | 1708.2737684249878 |  5.766509635956027   | 17.31490410985819  |
| triton_persistent_tma | 1476.485013961792  |  7.455113030038774   | 1.3969195302177155 |
|  cutlass_lvl_default  | 1583.3594799041748 |  50.408804678940214  | 8.736473620182366  |
|   cutlass_lvl_1111    | 1636.4418268203735 |  82.82403108896688   | 12.381879030898025 |
|   cutlass_lvl_2222    | 1507.5665712356567 |  260.03901409788523  | 3.531430975962381  |
+-----------------------+--------------------+----------------------+--------------------+

Experiment group: mm (8192x8192, 8192x8192) torch.bfloat16
+-----------------------+--------------------+----------------------+--------------------+
|         name          | forward_time (us)  | compilation_time (s) | perf_over_aten (%) |
+-----------------------+--------------------+----------------------+--------------------+
|         aten          | 1382.230520248413  |  1.2586536260787398  |         NA         |
|        triton         | 1646.9683647155762 |  5.442052865982987   | 19.15294450447995  |
| triton_persistent_tma | 1423.9195585250854 |  6.515797697938979   | 3.016069871556595  |
|  cutlass_lvl_default  | 1500.9030103683472 |  51.36402789200656   |  8.58557877152115  |
|   cutlass_lvl_1111    | 1446.9740390777588 |  30.65435610699933   | 4.683988515729638  |
|   cutlass_lvl_2222    | 1419.661521911621  |  205.1948991640238   | 2.7080144096717635 |
+-----------------------+--------------------+----------------------+--------------------+
```
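The perf_over_aten column is the relative change in forward time versus the aten baseline within the same experiment group, so negative values mean faster than aten. Checking the first float16 row by hand (formula from `tabulate_group_results` below):

```
# perf_over_aten = (backend_time - aten_time) / aten_time * 100
aten_time = 12.759539298713207    # aten forward_time (us)
triton_time = 10.573655366897583  # triton forward_time (us)
print((triton_time - aten_time) / aten_time * 100)  # ~ -17.131, matches the table
```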

Differential Revision: D70147589

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148347
Approved by: https://github.com/drisspg, https://github.com/chenyang78
2025-03-04 01:45:36 +00:00


import os

# set before importing torch so TORCH_LOGS takes effect
os.environ["TORCH_LOGS"] = "inductor"

import itertools
import time
from abc import abstractmethod
from collections import defaultdict
from dataclasses import asdict, dataclass, field
from typing import Any, Callable, Optional

from tabulate import tabulate
from tqdm import tqdm
from triton.testing import do_bench

import torch
from torch._inductor import config as inductor_config

inductor_config.autotune_num_choices_displayed = None
# force autotuning, but reuse compilation artifacts
inductor_config.autotune_local_cache = False
# uncomment for better debugging
# inductor_config.force_disable_caches = True

UNITS = {
    "name": "",
    "forward_time": " (us)",
    "compilation_time": " (s)",
}

OP_NAMES = ["mm"]

SHAPES = [
    # M, N, K
    (1024, 1024, 1024),
    (2048, 2048, 2048),
    (8192, 8192, 8192),
]

DTYPES = [
    torch.float16,
    torch.bfloat16,
]

# triton knobs
ENABLE_PERSISTENT_TMA_MATMULS = [
    False,
    True,
]

# cutlass knobs
CUTLASS_INSTANTIATION_LEVELS = [
    "0",
    "1111",
    "2222",
    # not ready yet
    # "3333",
]
def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
    # do_bench reports milliseconds; convert to microseconds
    return do_bench(lambda: func(*args, **kwargs)) * 1e3


@dataclass(frozen=True, kw_only=True)
class ExperimentConfig:
    autotune_fallback_to_aten: bool = False
    max_autotune: bool = True
    coordinate_descent_tuning: bool = True
    max_autotune_gemm_backends: str = "ATEN"

    @abstractmethod
    def name(self) -> str:
        pass

    def to_options(self) -> dict[str, Any]:
        return {
            "autotune_fallback_to_aten": self.autotune_fallback_to_aten,
            "max_autotune": self.max_autotune,
            "coordinate_descent_tuning": self.coordinate_descent_tuning,
            "max_autotune_gemm_backends": self.max_autotune_gemm_backends,
        }
@dataclass(frozen=True, kw_only=True)
class AtenExperimentConfig(ExperimentConfig):
    def name(self) -> str:
        return "aten"


@dataclass(frozen=True, kw_only=True)
class CutlassExperimentConfig(ExperimentConfig):
    cutlass_instantiation_level: str

    def name(self) -> str:
        level_name = (
            self.cutlass_instantiation_level
            if self.cutlass_instantiation_level != "0"
            else "default"
        )
        return f"cutlass_lvl_{level_name}"

    def to_options(self) -> dict[str, Any]:
        return {
            **super().to_options(),
            "cuda.cutlass_instantiation_level": self.cutlass_instantiation_level,
        }


@dataclass(frozen=True, kw_only=True)
class TritonExperimentConfig(ExperimentConfig):
    enable_persistent_tma_matmul: bool = False

    def name(self) -> str:
        if self.enable_persistent_tma_matmul:
            return "triton_persistent_tma"
        else:
            return "triton"

    def to_options(self) -> dict[str, Any]:
        return {
            **super().to_options(),
            "triton.enable_persistent_tma_matmul": self.enable_persistent_tma_matmul,
        }


@dataclass(frozen=True, kw_only=True)
class ExperimentGroupConfig:
    op_name: str
    shape: tuple[int, int, int]
    dtype: torch.dtype

    experiments: list[ExperimentConfig] = field(default_factory=list)

    def name(self) -> str:
        M, N, K = self.shape
        sizes = f"({M}x{K}, {K}x{N})"
        return f"{self.op_name} {sizes} {self.dtype}"


@dataclass(frozen=True, kw_only=True)
class ExperimentResults:
    name: str
    forward_time: float
    compilation_time: float

    def asdict(self):
        return asdict(self)


@dataclass(frozen=True, kw_only=True)
class ExperimentGroup:
    config: ExperimentGroupConfig
    results: list[ExperimentResults] = field(default_factory=list)
def get_inputs(
    config: ExperimentGroupConfig,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    op_name = config.op_name
    M, N, K = config.shape
    dtype = config.dtype
    device = torch.device("cuda")

    if op_name == "mm":
        A = torch.randn(M, K, dtype=dtype, device=device)
        B = torch.randn(K, N, dtype=dtype, device=device)
        C = None
        return A, B, C
    else:
        raise ValueError(f"Unknown op {op_name}")


def run_single_experiment_group(
    group_config: ExperimentGroupConfig,
) -> list[ExperimentResults]:
    A, B, C = get_inputs(group_config)
    op = getattr(torch, group_config.op_name)

    results = []
    for config in group_config.experiments:
        torch._dynamo.reset()
        torch._inductor.utils.clear_inductor_caches()
        compiled_op = torch.compile(op, fullgraph=True, options=config.to_options())

        # torch.compile is lazy: the first call triggers autotuning and codegen,
        # so timing it measures compilation_time
        start_time = time.perf_counter()
        _ = compiled_op(A, B)
        compilation_time = time.perf_counter() - start_time

        forward_time = benchmark_torch_function_in_microseconds(
            compiled_op,
            A,
            B,
        )

        results.append(
            ExperimentResults(
                name=config.name(),
                forward_time=forward_time,
                compilation_time=compilation_time,
            )
        )

    return results
def generate_experiment_groups(
    op_names: list[str],
    shapes: list[tuple[int, int, int]],
    dtypes: list[torch.dtype],
    enable_persistent_tma_matmuls: list[bool],
    cutlass_instantiation_levels: list[str],
) -> list[ExperimentGroupConfig]:
    groups = []
    for op_name, shape, dtype in itertools.product(op_names, shapes, dtypes):
        group = ExperimentGroupConfig(
            op_name=op_name,
            shape=shape,
            dtype=dtype,
        )
        experiments = generate_experiment_configs(
            enable_persistent_tma_matmuls, cutlass_instantiation_levels
        )
        group.experiments.extend(experiments)
        groups.append(group)

    return groups


def generate_experiment_configs(
    enable_persistent_tma_matmuls: list[bool], cutlass_instantiation_levels: list[str]
) -> list[ExperimentConfig]:
    configs = []

    # add aten configs
    configs.append(
        AtenExperimentConfig(
            max_autotune_gemm_backends="ATEN",
        )
    )

    # add triton configs
    for enable_persistent_tma_matmul in enable_persistent_tma_matmuls:
        configs.append(
            TritonExperimentConfig(
                max_autotune_gemm_backends="TRITON",
                enable_persistent_tma_matmul=enable_persistent_tma_matmul,
            )
        )

    # add cutlass configs
    for cutlass_instantiation_level in cutlass_instantiation_levels:
        configs.append(
            CutlassExperimentConfig(
                max_autotune_gemm_backends="CUTLASS",
                cutlass_instantiation_level=cutlass_instantiation_level,
            )
        )

    return configs
def tabulate_group_results(results: list[ExperimentResults]):
    table_data = defaultdict(list)
    aten_perf: Optional[float] = None
    perf_over_aten_str: str = "perf_over_aten (%)"

    for experiment_result in results:
        for key, value in experiment_result.asdict().items():
            assert key in UNITS, f"Unknown key {key}"
            table_data[key + UNITS[key]].append(value)

        if experiment_result.name == "aten":
            aten_perf = experiment_result.forward_time
            table_data[perf_over_aten_str].append("NA")
        elif aten_perf is not None:
            perf_over_aten = (
                (experiment_result.forward_time - aten_perf) / aten_perf * 100
            )
            table_data[perf_over_aten_str].append(perf_over_aten)
        else:
            # fallback in case aten is not in experiment group
            table_data[perf_over_aten_str].append("NA")

    return tabulate(table_data, headers="keys", tablefmt="pretty", floatfmt=".3f")


def print_results(experiment_groups: list[ExperimentGroup]):
    for experiment_group in experiment_groups:
        group_config_name = experiment_group.config.name()
        print(f"\nExperiment group: {group_config_name}")
        print(tabulate_group_results(experiment_group.results))


def main():
    seed = 123
    torch.manual_seed(seed)

    results = []
    for group_config in tqdm(
        generate_experiment_groups(
            OP_NAMES,
            SHAPES,
            DTYPES,
            ENABLE_PERSISTENT_TMA_MATMULS,
            CUTLASS_INSTANTIATION_LEVELS,
        )
    ):
        results.append(
            ExperimentGroup(
                config=group_config, results=run_single_experiment_group(group_config)
            ),
        )

    print_results(results)


if __name__ == "__main__":
    main()