mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[cutlass backend] switch layout for cutlass backend benchmark (#149009)
```
python benchmarks/inductor_backends/cutlass.py
```

logs:
```
Experiment group: mm (1024x1024, 1024x1024) torch.float16
+-----------------------+--------------------+----------------------+---------------------+
| name                  | forward_time (us)  | compilation_time (s) | perf_over_aten (%)  |
+-----------------------+--------------------+----------------------+---------------------+
| aten                  | 13.059554621577263 | 1.580178506206721    | NA                  |
| triton                | 10.245470330119133 | 0.04118620231747627  | -21.54808776410064  |
| triton_persistent_tma | 10.388538241386414 | 0.04225084185600281  | -20.45258400908819  |
| cutlass_lvl_default   | 12.882896699011326 | 231.14990583620965   | -1.3527101626732294 |
| cutlass_lvl_1111      | 11.362981051206589 | 126.41650272067636   | -12.99105229490415  |
| cutlass_lvl_2222      | 11.107578873634338 | 555.8380545829423    | -14.946725248331441 |
+-----------------------+--------------------+----------------------+---------------------+

Experiment group: mm (1024x1024, 1024x1024) torch.bfloat16
+-----------------------+--------------------+----------------------+---------------------+
| name                  | forward_time (us)  | compilation_time (s) | perf_over_aten (%)  |
+-----------------------+--------------------+----------------------+---------------------+
| aten                  | 14.037585817277431 | 0.21587548777461052  | NA                  |
| triton                | 10.571777820587158 | 78.15654796129093    | -24.68948750735019  |
| triton_persistent_tma | 10.761583223938942 | 1.3195342738181353   | -23.337364672110443 |
| cutlass_lvl_default   | 12.872588820755482 | 237.0100042372942    | -8.299126443010406  |
| cutlass_lvl_1111      | 11.08622644096613  | 137.55013868492097   | -21.02469338195443  |
| cutlass_lvl_2222      | 11.044904589653015 | 551.265836935956     | -21.319059178545007 |
+-----------------------+--------------------+----------------------+---------------------+

Experiment group: mm (2048x2048, 2048x2048) torch.float16
+-----------------------+--------------------+----------------------+---------------------+
| name                  | forward_time (us)  | compilation_time (s) | perf_over_aten (%)  |
+-----------------------+--------------------+----------------------+---------------------+
| aten                  | 30.483894050121307 | 0.27990864124149084  | NA                  |
| triton                | 29.567627236247063 | 99.87172158574685    | -3.005740711366232  |
| triton_persistent_tma | 29.66325916349888  | 1.3695051120594144   | -2.692027748401006  |
| cutlass_lvl_default   | 29.82821688055992  | 72.61214569816366    | -2.150897022812533  |
| cutlass_lvl_1111      | 29.476772993803024 | 67.7428645719774     | -3.303780857728953  |
| cutlass_lvl_2222      | 30.113255605101585 | 233.84051702311262   | -1.2158500630212203 |
+-----------------------+--------------------+----------------------+---------------------+

Experiment group: mm (2048x2048, 2048x2048) torch.bfloat16
+-----------------------+--------------------+----------------------+---------------------+
| name                  | forward_time (us)  | compilation_time (s) | perf_over_aten (%)  |
+-----------------------+--------------------+----------------------+---------------------+
| aten                  | 30.58255836367607  | 0.058386584743857384 | NA                  |
| triton                | 29.799651354551315 | 100.18178300186992   | -2.559978795150901  |
| triton_persistent_tma | 29.362043365836143 | 1.534341821912676    | -3.990885861562106  |
| cutlass_lvl_default   | 29.4346883893013   | 73.68858492700383    | -3.7533484305817093 |
| cutlass_lvl_1111      | 29.164200648665428 | 75.44329373072833    | -4.637799421958348  |
| cutlass_lvl_2222      | 29.13798950612545  | 227.33327346481383   | -4.7235056020244    |
+-----------------------+--------------------+----------------------+---------------------+

Experiment group: mm (8192x8192, 8192x8192) torch.float16
+-----------------------+--------------------+----------------------+--------------------+
| name                  | forward_time (us)  | compilation_time (s) | perf_over_aten (%) |
+-----------------------+--------------------+----------------------+--------------------+
| aten                  | 1656.6237211227417 | 0.0549461180344224   | NA                 |
| triton                | 1892.8285837173462 | 2.3174119112081826   | 14.258208401997386 |
| triton_persistent_tma | 1665.332317352295  | 2.7922237082384527   | 0.525683419747917  |
| cutlass_lvl_default   | 1705.5492401123047 | 108.31571159465238   | 2.9533272019312116 |
| cutlass_lvl_1111      | 1714.9059772491455 | 17.64627545280382    | 3.518134829489478  |
| cutlass_lvl_2222      | 1680.4152727127075 | 306.9972395859659    | 1.4361469829637354 |
+-----------------------+--------------------+----------------------+--------------------+

Experiment group: mm (8192x8192, 8192x8192) torch.bfloat16
+-----------------------+--------------------+----------------------+--------------------+
| name                  | forward_time (us)  | compilation_time (s) | perf_over_aten (%) |
+-----------------------+--------------------+----------------------+--------------------+
| aten                  | 1621.416687965393  | 0.06300561130046844  | NA                 |
| triton                | 1782.3902368545532 | 2.318530729971826    | 9.927956834535548  |
| triton_persistent_tma | 1586.0934257507324 | 2.7931175641715527   | -2.178543151605614 |
| cutlass_lvl_default   | 1657.4617624282837 | 43.31810224894434    | 2.2230605328307784 |
| cutlass_lvl_1111      | 1641.5367126464844 | 17.648567833006382   | 1.2408916739557292 |
| cutlass_lvl_2222      | 1645.8417177200317 | 249.33647010894492   | 1.5064005407078918 |
+-----------------------+--------------------+----------------------+--------------------+
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149009
Approved by: https://github.com/chenyang78, https://github.com/jingsh
This commit is contained in:
Committed by: PyTorch MergeBot
Parent: 4a12777ffe
Commit: f2d43d866c
@@ -167,7 +167,7 @@ def get_inputs(
|
||||
|
||||
if op_name == "mm":
|
||||
A = torch.randn(M, K, dtype=dtype, device=device)
|
||||
- B = torch.randn(K, N, dtype=dtype, device=device)
|
||||
+ B = torch.randn(N, K, dtype=dtype, device=device).t()
|
||||
C = None
|
||||
return A, B, C
|
||||
else:
|
||||
|
Reference in New Issue
Block a user