[cutlass backend] add fp8 to cutlass benchmark script (#155507)

Summary:
Add fp8 (torch.float8_e4m3fn, benchmarked through _scaled_mm with rowwise scaling) to the cutlass benchmark script.

Right now the fp8 path only supports fast_accum (USE_FAST_ACCUM = True).

Test Plan:
```
Experiment group: _scaled_mm (8192x8192, 8192x8192) torch.float8_e4m3fn
+-----------------------+--------------------+--------------------+----------------------+--------------------+
|         name          | forward_time (us)  | teraflops (TFLOPS) | compilation_time (s) | perf_over_aten (%) |
+-----------------------+--------------------+--------------------+----------------------+--------------------+
|         aten          | 967.1226739883423  | 1136.8895149998868 |  1.219131228979677   |         NA         |
|        triton         | 1764.6185159683228 |  623.08743664783   |  20.373826419003308  | 82.46067054670186  |
| triton_persistent_tma | 1769.0335512161255 | 621.5323768280928  |  20.48663099599071   | 82.91718297956578  |
|  cutlass_lvl_default  | 790.5075550079346  | 1390.8932568835019 |  13.788519630907103  | -18.26191482535096 |
|   cutlass_lvl_3332    | 803.7384748458862  | 1367.996757884245  |  226.81587297911756  | -16.89384434227684 |
+-----------------------+--------------------+--------------------+----------------------+--------------------+
```
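
For context, a minimal sketch of the _scaled_mm call this benchmark path exercises, assuming a recent fp8-capable PyTorch build and GPU; shapes are illustrative, and it reuses the same _quantize_rowwise helper the script imports:

```
import torch
from torch.testing._internal.inductor_utils import _quantize_rowwise

M, N, K = 8192, 8192, 8192
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
w = torch.randn(N, K, dtype=torch.bfloat16, device="cuda")

# Rowwise scaling: one scale per row, so scale_a is (M, 1) and
# scale_b, after the transpose, is (1, N).
x_fp8, x_inverse_scale = _quantize_rowwise(x, torch.float8_e4m3fn)
w_fp8, w_inverse_scale = _quantize_rowwise(w, torch.float8_e4m3fn)

out = torch._scaled_mm(
    x_fp8,
    w_fp8.t(),
    x_inverse_scale,
    w_inverse_scale.t(),
    out_dtype=torch.bfloat16,
    use_fast_accum=True,  # the only mode the script exercises right now
)
```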

Rollback Plan:

Differential Revision: D76310809

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155507
Approved by: https://github.com/ColinPeppler
Author: Henry Tsang
Date: 2025-06-13 05:11:15 +00:00
Committed by: PyTorch MergeBot
Parent: 2ba930d4ce
Commit: b878ca0c91


@@ -18,6 +18,7 @@ from triton.testing import do_bench
 import torch
 from torch._inductor import config as inductor_config
+from torch.testing._internal.inductor_utils import _quantize_rowwise

 log: logging.Logger = logging.getLogger(__name__)
@@ -29,6 +30,7 @@ inductor_config.autotune_local_cache = False
 # uncomment for better debugging
 # inductor_config.force_disable_caches = True
+USE_FAST_ACCUM = True

 UNITS = {
     "name": "",
@@ -40,8 +42,9 @@ PERF_OVER_ATEN_STR: str = "perf_over_aten (%)"
 OP_NAMES = [
     "mm",
-    "addmm",
-    "bmm",
+    # "addmm",
+    # "bmm",
+    # "_scaled_mm",
 ]

 SHAPES = [
@@ -59,6 +62,7 @@ BATCH_SIZES = [
 DTYPES = [
     torch.float16,
     torch.bfloat16,
+    # torch.float8_e4m3fn,
 ]

 # triton knobs
@@ -72,7 +76,8 @@ CUTLASS_INSTANTIATION_LEVELS = [
     "0",
     # "1111",
     # "2222",
-    "3333",
+    "3332",
+    # "9992",
 ]
@@ -199,6 +204,34 @@ def get_inputs(
         A = torch.randn(batch_size, M, K, dtype=dtype, device=device)
         B = torch.randn(batch_size, N, K, dtype=dtype, device=device).permute(0, 2, 1)
         return A, B
+    elif op_name == "_scaled_mm":
+        # For _scaled_mm, we only support fp8e4m3 with rowwise scaling
+        if dtype != torch.float8_e4m3fn:
+            raise ValueError(f"_scaled_mm only supports fp8e4m3, got {dtype}")
+
+        # Create input tensors in bfloat16 first, then quantize to fp8
+        input_dtype = torch.bfloat16
+        x = torch.randn(M, K, dtype=input_dtype, device=device)
+        w = torch.randn(N, K, dtype=input_dtype, device=device)
+
+        # Quantize using rowwise scaling
+        w_fp8, w_inverse_scale = _quantize_rowwise(w, dtype)
+        w_t_fp8 = w_fp8.t()
+        w_inverse_scale = w_inverse_scale.t()  # scale_b should be (1, N)
+        x_fp8, x_inverse_scale = _quantize_rowwise(x, dtype)
+
+        # Return inputs for _scaled_mm:
+        # (input, weight_t, scale_a, scale_b, bias, out, out_dtype, use_fast_accum)
+        return (
+            x_fp8,
+            w_t_fp8,
+            x_inverse_scale,
+            w_inverse_scale,
+            None,
+            None,
+            torch.bfloat16,
+            USE_FAST_ACCUM,
+        )
     else:
         raise ValueError(f"Unknown op {op_name}")
@@ -361,6 +394,8 @@ def calculate_flops(op_name: str, shape: tuple[int, int, int], batch_size: int)
         return 2 * batch_size * M * N * K
     elif op_name == "addmm":
         return 2 * M * N * K + M * N
+    elif op_name == "_scaled_mm":
+        return 2 * M * N * K
     else:
         return 2 * M * N * K
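
As a sanity check, the derived columns in the Test Plan table follow from calculate_flops and the measured forward time. A sketch of the relationship; the helper names here are hypothetical, not taken from the script:

```
def teraflops(flops: int, forward_time_us: float) -> float:
    # FLOPs per second, expressed in TFLOPS
    return flops / (forward_time_us * 1e-6) / 1e12

def perf_over_aten_pct(time_us: float, aten_time_us: float) -> float:
    # positive = slower than aten, negative = faster
    return (time_us - aten_time_us) / aten_time_us * 100

# Against the aten and cutlass_lvl_default rows of the table above:
# teraflops(2 * 8192**3, 967.1226739883423)                 -> ~1136.9 TFLOPS
# perf_over_aten_pct(790.5075550079346, 967.1226739883423)  -> ~-18.26
```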