[cutlass backend] Benchmark compared to aten and triton (#148347)
Benchmark for cutlass backend.

```
python benchmarks/inductor_backends/cutlass.py
```

Test Plan:

```
Experiment group: mm (1024x1024, 1024x1024) torch.float16
+-----------------------+--------------------+----------------------+---------------------+
| name                  | forward_time (us)  | compilation_time (s) | perf_over_aten (%)  |
+-----------------------+--------------------+----------------------+---------------------+
| aten                  | 12.759539298713207 | 2.7271360370796174   | NA                  |
| triton                | 10.573655366897583 | 1.8661278090439737   | -17.131370346859384 |
| triton_persistent_tma | 10.884030722081661 | 0.5315794269554317   | -14.698873781600327 |
| cutlass_lvl_default   | 13.09632882475853  | 0.5520401500398293   | 2.6395116481931873  |
| cutlass_lvl_1111      | 11.05172373354435  | 0.569593315012753    | -13.384617776451302 |
| cutlass_lvl_2222      | 11.371277272701263 | 133.58984916994814   | -10.880189272601317 |
+-----------------------+--------------------+----------------------+---------------------+

Experiment group: mm (1024x1024, 1024x1024) torch.bfloat16
+-----------------------+--------------------+----------------------+---------------------+
| name                  | forward_time (us)  | compilation_time (s) | perf_over_aten (%)  |
+-----------------------+--------------------+----------------------+---------------------+
| aten                  | 14.472318813204765 | 1.5445372510002926   | NA                  |
| triton                | 10.568295605480671 | 16.583424195996486   | -26.975796056689987 |
| triton_persistent_tma | 10.45411266386509  | 5.830657540936954    | -27.764770809729562 |
| cutlass_lvl_default   | 12.742593884468079 | 28.994930602959357   | -11.951954286402668 |
| cutlass_lvl_1111      | 11.522261425852776 | 79.85037935699802    | -20.38413764531163  |
| cutlass_lvl_2222      | 10.993581265211105 | 132.86601971101481   | -24.037181552548486 |
+-----------------------+--------------------+----------------------+---------------------+

Experiment group: mm (2048x2048, 2048x2048) torch.float16
+-----------------------+--------------------+----------------------+---------------------+
| name                  | forward_time (us)  | compilation_time (s) | perf_over_aten (%)  |
+-----------------------+--------------------+----------------------+---------------------+
| aten                  | 30.700622126460075 | 2.225986961973831    | NA                  |
| triton                | 29.17378954589367  | 38.571991189033724   | -4.97329524553989   |
| triton_persistent_tma | 29.642896726727486 | 7.2848734309664      | -3.4452897904663744 |
| cutlass_lvl_default   | 29.514770954847336 | 29.819900761009194   | -3.8626291243482167 |
| cutlass_lvl_1111      | 29.411429539322853 | 23.82907024596352    | -4.19923929172139   |
| cutlass_lvl_2222      | 29.57325428724289  | 134.31008586101234   | -3.672133530628152  |
+-----------------------+--------------------+----------------------+---------------------+

Experiment group: mm (2048x2048, 2048x2048) torch.bfloat16
+-----------------------+--------------------+----------------------+--------------------+
| name                  | forward_time (us)  | compilation_time (s) | perf_over_aten (%) |
+-----------------------+--------------------+----------------------+--------------------+
| aten                  | 30.858177691698074 | 1.181898436974734    | NA                 |
| triton                | 28.630023822188377 | 39.24473957403097    | -7.220626868414034 |
| triton_persistent_tma | 28.641965240240097 | 5.275042273919098    | -7.181929126210897 |
| cutlass_lvl_default   | 29.16003204882145  | 29.934022572939284   | -5.503065216107967 |
| cutlass_lvl_1111      | 28.79570797085762  | 23.948012012057006   | -6.683705504085324 |
| cutlass_lvl_2222      | 29.02756631374359  | 136.25560767308343   | -5.932337924306467 |
+-----------------------+--------------------+----------------------+--------------------+

Experiment group: mm (8192x8192, 8192x8192) torch.float16
+-----------------------+--------------------+----------------------+--------------------+
| name                  | forward_time (us)  | compilation_time (s) | perf_over_aten (%) |
+-----------------------+--------------------+----------------------+--------------------+
| aten                  | 1456.143856048584  | 1.020197194069624    | NA                 |
| triton                | 1708.2737684249878 | 5.766509635956027    | 17.31490410985819  |
| triton_persistent_tma | 1476.485013961792  | 7.455113030038774    | 1.3969195302177155 |
| cutlass_lvl_default   | 1583.3594799041748 | 50.408804678940214   | 8.736473620182366  |
| cutlass_lvl_1111      | 1636.4418268203735 | 82.82403108896688    | 12.381879030898025 |
| cutlass_lvl_2222      | 1507.5665712356567 | 260.03901409788523   | 3.531430975962381  |
+-----------------------+--------------------+----------------------+--------------------+

Experiment group: mm (8192x8192, 8192x8192) torch.bfloat16
+-----------------------+--------------------+----------------------+--------------------+
| name                  | forward_time (us)  | compilation_time (s) | perf_over_aten (%) |
+-----------------------+--------------------+----------------------+--------------------+
| aten                  | 1382.230520248413  | 1.2586536260787398   | NA                 |
| triton                | 1646.9683647155762 | 5.442052865982987    | 19.15294450447995  |
| triton_persistent_tma | 1423.9195585250854 | 6.515797697938979    | 3.016069871556595  |
| cutlass_lvl_default   | 1500.9030103683472 | 51.36402789200656    | 8.58557877152115   |
| cutlass_lvl_1111      | 1446.9740390777588 | 30.65435610699933    | 4.683988515729638  |
| cutlass_lvl_2222      | 1419.661521911621  | 205.1948991640238    | 2.7080144096717635 |
+-----------------------+--------------------+----------------------+--------------------+
```

Differential Revision: D70147589

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148347
Approved by: https://github.com/drisspg, https://github.com/chenyang78
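For reference, the perf_over_aten (%) column is the relative change in forward time versus the aten baseline within the same experiment group; negative values are faster than aten. A minimal sketch of the arithmetic, checked against the first torch.float16 group:

```python
# perf_over_aten = (forward_time - aten_forward_time) / aten_forward_time * 100
aten_us = 12.759539298713207    # aten forward_time (us)
triton_us = 10.573655366897583  # triton forward_time (us)
print((triton_us - aten_us) / aten_us * 100)  # -17.1313..., matching the table
```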
Committed by: PyTorch MergeBot
Parent: c21dc11a17
Commit: 17518007b2

benchmarks/inductor_backends/cutlass.py (new file, 322 lines)
@@ -0,0 +1,322 @@
```python
import os

os.environ["TORCH_LOGS"] = "inductor"

import itertools
import time
from abc import abstractmethod
from collections import defaultdict
from dataclasses import asdict, dataclass, field
from typing import Any, Callable, Optional

from tabulate import tabulate
from tqdm import tqdm
from triton.testing import do_bench

import torch
from torch._inductor import config as inductor_config


inductor_config.autotune_num_choices_displayed = None
# force autotuning, but reuse compilation artifacts
inductor_config.autotune_local_cache = False
# uncomment for better debugging
# inductor_config.force_disable_caches = True


UNITS = {
    "name": "",
    "forward_time": " (us)",
    "compilation_time": " (s)",
}

OP_NAMES = ["mm"]

SHAPES = [
    # M, N, K
    (1024, 1024, 1024),
    (2048, 2048, 2048),
    (8192, 8192, 8192),
]

DTYPES = [
    torch.float16,
    torch.bfloat16,
]

# triton knobs
ENABLE_PERSISTENT_TMA_MATMULS = [
    False,
    True,
]

# cutlass knobs
CUTLASS_INSTANTIATION_LEVELS = [
    "0",
    "1111",
    "2222",
    # not ready yet
    # "3333",
]


def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
    return do_bench(lambda: func(*args, **kwargs)) * 1e3


@dataclass(frozen=True, kw_only=True)
class ExperimentConfig:
    autotune_fallback_to_aten: bool = False
    max_autotune: bool = True
    coordinate_descent_tuning: bool = True
    max_autotune_gemm_backends: str = "ATEN"

    @abstractmethod
    def name(self) -> str:
        pass

    def to_options(self) -> dict[str, Any]:
        return {
            "autotune_fallback_to_aten": self.autotune_fallback_to_aten,
            "max_autotune": self.max_autotune,
            "coordinate_descent_tuning": self.coordinate_descent_tuning,
            "max_autotune_gemm_backends": self.max_autotune_gemm_backends,
        }


@dataclass(frozen=True, kw_only=True)
class AtenExperimentConfig(ExperimentConfig):
    def name(self) -> str:
        return "aten"


@dataclass(frozen=True, kw_only=True)
class CutlassExperimentConfig(ExperimentConfig):
    cutlass_instantiation_level: str

    def name(self) -> str:
        level_name = (
            self.cutlass_instantiation_level
            if self.cutlass_instantiation_level != "0"
            else "default"
        )
        return f"cutlass_lvl_{level_name}"

    def to_options(self) -> dict[str, Any]:
        return {
            **super().to_options(),
            "cuda.cutlass_instantiation_level": self.cutlass_instantiation_level,
        }


@dataclass(frozen=True, kw_only=True)
class TritonExperimentConfig(ExperimentConfig):
    enable_persistent_tma_matmul: bool = False

    def name(self) -> str:
        if self.enable_persistent_tma_matmul:
            return "triton_persistent_tma"
        else:
            return "triton"

    def to_options(self) -> dict[str, Any]:
        return {
            **super().to_options(),
            "triton.enable_persistent_tma_matmul": self.enable_persistent_tma_matmul,
        }


@dataclass(frozen=True, kw_only=True)
class ExperimentGroupConfig:
    op_name: str
    shape: tuple[int, int, int]
    dtype: torch.dtype

    experiments: list[ExperimentConfig] = field(default_factory=list)

    def name(self) -> str:
        M, N, K = self.shape
        sizes = f"({M}x{K}, {K}x{N})"
        return f"{self.op_name} {sizes} {self.dtype}"


@dataclass(frozen=True, kw_only=True)
class ExperimentResults:
    name: str
    forward_time: float
    compilation_time: float

    def asdict(self):
        return asdict(self)


@dataclass(frozen=True, kw_only=True)
class ExperimentGroup:
    config: ExperimentGroupConfig
    results: list[ExperimentResults] = field(default_factory=list)


def get_inputs(
    config: ExperimentGroupConfig,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    op_name = config.op_name
    M, N, K = config.shape
    dtype = config.dtype
    device = torch.device("cuda")

    if op_name == "mm":
        A = torch.randn(M, K, dtype=dtype, device=device)
        B = torch.randn(K, N, dtype=dtype, device=device)
        C = None
        return A, B, C
    else:
        raise ValueError(f"Unknown op {op_name}")


def run_single_experiment_group(
    group_config: ExperimentGroupConfig,
) -> list[ExperimentResults]:
    A, B, C = get_inputs(group_config)
    op = getattr(torch, group_config.op_name)

    results = []

    for config in group_config.experiments:
        torch._dynamo.reset()
        torch._inductor.utils.clear_inductor_caches()
        compiled_op = torch.compile(op, fullgraph=True, options=config.to_options())

        start_time = time.perf_counter()
        _ = compiled_op(A, B)
        compilation_time = time.perf_counter() - start_time

        forward_time = benchmark_torch_function_in_microseconds(
            compiled_op,
            A,
            B,
        )

        results.append(
            ExperimentResults(
                name=config.name(),
                forward_time=forward_time,
                compilation_time=compilation_time,
            )
        )

    return results


def generate_experiment_groups(
    op_names: list[str],
    shapes: list[tuple[int, int, int]],
    dtypes: list[torch.dtype],
    enable_persistent_tma_matmuls: list[bool],
    cutlass_instantiation_levels: list[str],
) -> list[ExperimentGroupConfig]:
    groups = []
    for op_name, shape, dtype in itertools.product(op_names, shapes, dtypes):
        group = ExperimentGroupConfig(
            op_name=op_name,
            shape=shape,
            dtype=dtype,
        )
        experiments = generate_experiment_configs(
            enable_persistent_tma_matmuls, cutlass_instantiation_levels
        )
        group.experiments.extend(experiments)
        groups.append(group)

    return groups


def generate_experiment_configs(
    enable_persistent_tma_matmuls: list[bool], cutlass_instantiation_levels: list[str]
) -> list[ExperimentConfig]:
    configs = []

    # add aten configs
    configs.append(
        AtenExperimentConfig(
            max_autotune_gemm_backends="ATEN",
        )
    )

    # add triton configs
    for enable_persistent_tma_matmul in enable_persistent_tma_matmuls:
        configs.append(
            TritonExperimentConfig(
                max_autotune_gemm_backends="TRITON",
                enable_persistent_tma_matmul=enable_persistent_tma_matmul,
            )
        )

    # add cutlass configs
    for cutlass_instantiation_level in cutlass_instantiation_levels:
        configs.append(
            CutlassExperimentConfig(
                max_autotune_gemm_backends="CUTLASS",
                cutlass_instantiation_level=cutlass_instantiation_level,
            )
        )

    return configs


def tabulate_group_results(results: list[ExperimentResults]):
    table_data = defaultdict(list)
    aten_perf: Optional[float] = None
    perf_over_aten_str: str = "perf_over_aten (%)"

    for experiment_result in results:
        for key, value in experiment_result.asdict().items():
            assert key in UNITS, f"Unknown key {key}"
            table_data[key + UNITS[key]].append(value)

        if experiment_result.name == "aten":
            aten_perf = experiment_result.forward_time
            table_data[perf_over_aten_str].append("NA")
        elif aten_perf is not None:
            perf_over_aten = (
                (experiment_result.forward_time - aten_perf) / aten_perf * 100
            )
            table_data[perf_over_aten_str].append(perf_over_aten)
        else:
            # fallback in case aten is not in experiment group
            table_data[perf_over_aten_str].append("NA")

    return tabulate(table_data, headers="keys", tablefmt="pretty", floatfmt=".3f")


def print_results(experiment_groups: list[ExperimentGroup]):
    for experiment_group in experiment_groups:
        group_config_name = experiment_group.config.name()
        print(f"\nExperiment group: {group_config_name}")
        print(tabulate_group_results(experiment_group.results))


def main():
    seed = 123
    torch.manual_seed(seed)
    results = []
    for group_config in tqdm(
        generate_experiment_groups(
            OP_NAMES,
            SHAPES,
            DTYPES,
            ENABLE_PERSISTENT_TMA_MATMULS,
            CUTLASS_INSTANTIATION_LEVELS,
        )
    ):
        results.append(
            ExperimentGroup(
                config=group_config, results=run_single_experiment_group(group_config)
            ),
        )

    print_results(results)


if __name__ == "__main__":
    main()
```
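The input generator above only handles mm, but it already returns an Optional third tensor, which leaves room for bias-taking ops. A hypothetical sketch, not part of this PR (the helper name and the addmm op are assumptions), of what an addmm input builder could look like:

```python
import torch

# Hypothetical sketch, not part of this PR: inputs for torch.addmm.
# torch.addmm(C, A, B) computes C + A @ B, so the bias C has shape (M, N).
def get_addmm_inputs(
    shape: tuple[int, int, int], dtype: torch.dtype
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    M, N, K = shape  # same (M, N, K) convention as SHAPES above
    device = torch.device("cuda")
    A = torch.randn(M, K, dtype=dtype, device=device)
    B = torch.randn(K, N, dtype=dtype, device=device)
    C = torch.randn(M, N, dtype=dtype, device=device)  # bias term
    return A, B, C
```

run_single_experiment_group would also need to pass the bias through (torch.addmm takes it as the first positional argument) instead of calling compiled_op(A, B) unconditionally.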