[cutlass backend] switch layout for cutlass backend benchmark (#149009)

```
python benchmarks/inductor_backends/cutlass.py
```

logs:
```
Experiment group: mm (1024x1024, 1024x1024) torch.float16
+-----------------------+--------------------+----------------------+---------------------+
|         name          | forward_time (us)  | compilation_time (s) | perf_over_aten (%)  |
+-----------------------+--------------------+----------------------+---------------------+
|         aten          | 13.059554621577263 |  1.580178506206721   |         NA          |
|        triton         | 10.245470330119133 | 0.04118620231747627  | -21.54808776410064  |
| triton_persistent_tma | 10.388538241386414 | 0.04225084185600281  | -20.45258400908819  |
|  cutlass_lvl_default  | 12.882896699011326 |  231.14990583620965  | -1.3527101626732294 |
|   cutlass_lvl_1111    | 11.362981051206589 |  126.41650272067636  | -12.99105229490415  |
|   cutlass_lvl_2222    | 11.107578873634338 |  555.8380545829423   | -14.946725248331441 |
+-----------------------+--------------------+----------------------+---------------------+

Experiment group: mm (1024x1024, 1024x1024) torch.bfloat16
+-----------------------+--------------------+----------------------+---------------------+
|         name          | forward_time (us)  | compilation_time (s) | perf_over_aten (%)  |
+-----------------------+--------------------+----------------------+---------------------+
|         aten          | 14.037585817277431 | 0.21587548777461052  |         NA          |
|        triton         | 10.571777820587158 |  78.15654796129093   | -24.68948750735019  |
| triton_persistent_tma | 10.761583223938942 |  1.3195342738181353  | -23.337364672110443 |
|  cutlass_lvl_default  | 12.872588820755482 |  237.0100042372942   | -8.299126443010406  |
|   cutlass_lvl_1111    | 11.08622644096613  |  137.55013868492097  | -21.02469338195443  |
|   cutlass_lvl_2222    | 11.044904589653015 |   551.265836935956   | -21.319059178545007 |
+-----------------------+--------------------+----------------------+---------------------+

Experiment group: mm (2048x2048, 2048x2048) torch.float16
+-----------------------+--------------------+----------------------+---------------------+
|         name          | forward_time (us)  | compilation_time (s) | perf_over_aten (%)  |
+-----------------------+--------------------+----------------------+---------------------+
|         aten          | 30.483894050121307 | 0.27990864124149084  |         NA          |
|        triton         | 29.567627236247063 |  99.87172158574685   | -3.005740711366232  |
| triton_persistent_tma | 29.66325916349888  |  1.3695051120594144  | -2.692027748401006  |
|  cutlass_lvl_default  | 29.82821688055992  |  72.61214569816366   | -2.150897022812533  |
|   cutlass_lvl_1111    | 29.476772993803024 |   67.7428645719774   | -3.303780857728953  |
|   cutlass_lvl_2222    | 30.113255605101585 |  233.84051702311262  | -1.2158500630212203 |
+-----------------------+--------------------+----------------------+---------------------+

Experiment group: mm (2048x2048, 2048x2048) torch.bfloat16
+-----------------------+--------------------+----------------------+---------------------+
|         name          | forward_time (us)  | compilation_time (s) | perf_over_aten (%)  |
+-----------------------+--------------------+----------------------+---------------------+
|         aten          | 30.58255836367607  | 0.058386584743857384 |         NA          |
|        triton         | 29.799651354551315 |  100.18178300186992  | -2.559978795150901  |
| triton_persistent_tma | 29.362043365836143 |  1.534341821912676   | -3.990885861562106  |
|  cutlass_lvl_default  |  29.4346883893013  |  73.68858492700383   | -3.7533484305817093 |
|   cutlass_lvl_1111    | 29.164200648665428 |  75.44329373072833   | -4.637799421958348  |
|   cutlass_lvl_2222    | 29.13798950612545  |  227.33327346481383  |  -4.7235056020244   |
+-----------------------+--------------------+----------------------+---------------------+

Experiment group: mm (8192x8192, 8192x8192) torch.float16
+-----------------------+--------------------+----------------------+--------------------+
|         name          | forward_time (us)  | compilation_time (s) | perf_over_aten (%) |
+-----------------------+--------------------+----------------------+--------------------+
|         aten          | 1656.6237211227417 |  0.0549461180344224  |         NA         |
|        triton         | 1892.8285837173462 |  2.3174119112081826  | 14.258208401997386 |
| triton_persistent_tma | 1665.332317352295  |  2.7922237082384527  | 0.525683419747917  |
|  cutlass_lvl_default  | 1705.5492401123047 |  108.31571159465238  | 2.9533272019312116 |
|   cutlass_lvl_1111    | 1714.9059772491455 |  17.64627545280382   | 3.518134829489478  |
|   cutlass_lvl_2222    | 1680.4152727127075 |  306.9972395859659   | 1.4361469829637354 |
+-----------------------+--------------------+----------------------+--------------------+

Experiment group: mm (8192x8192, 8192x8192) torch.bfloat16
+-----------------------+--------------------+----------------------+--------------------+
|         name          | forward_time (us)  | compilation_time (s) | perf_over_aten (%) |
+-----------------------+--------------------+----------------------+--------------------+
|         aten          | 1621.416687965393  | 0.06300561130046844  |         NA         |
|        triton         | 1782.3902368545532 |  2.318530729971826   | 9.927956834535548  |
| triton_persistent_tma | 1586.0934257507324 |  2.7931175641715527  | -2.178543151605614 |
|  cutlass_lvl_default  | 1657.4617624282837 |  43.31810224894434   | 2.2230605328307784 |
|   cutlass_lvl_1111    | 1641.5367126464844 |  17.648567833006382  | 1.2408916739557292 |
|   cutlass_lvl_2222    | 1645.8417177200317 |  249.33647010894492  | 1.5064005407078918 |
+-----------------------+--------------------+----------------------+--------------------+
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149009
Approved by: https://github.com/chenyang78, https://github.com/jingsh
This commit is contained in:
henrylhtsang
2025-03-11 16:18:29 -07:00
committed by PyTorch MergeBot
parent 4a12777ffe
commit f2d43d866c

View File

@ -167,7 +167,7 @@ def get_inputs(
if op_name == "mm":
A = torch.randn(M, K, dtype=dtype, device=device)
B = torch.randn(K, N, dtype=dtype, device=device)
B = torch.randn(N, K, dtype=dtype, device=device).t()
C = None
return A, B, C
else: