[cutlass backend] add fp8 to cutlass benchmark script (#155507)
Summary: Add fp8. Right now FP8 only allows fast_accum.

Test Plan:
```
Experiment group: _scaled_mm (8192x8192, 8192x8192) torch.float8_e4m3fn
+-----------------------+--------------------+--------------------+----------------------+--------------------+
| name                  | forward_time (us)  | teraflops (TFLOPS) | compilation_time (s) | perf_over_aten (%) |
+-----------------------+--------------------+--------------------+----------------------+--------------------+
| aten                  | 967.1226739883423  | 1136.8895149998868 | 1.219131228979677    | NA                 |
| triton                | 1764.6185159683228 | 623.08743664783    | 20.373826419003308   | 82.46067054670186  |
| triton_persistent_tma | 1769.0335512161255 | 621.5323768280928  | 20.48663099599071    | 82.91718297956578  |
| cutlass_lvl_default   | 790.5075550079346  | 1390.8932568835019 | 13.788519630907103   | -18.26191482535096 |
| cutlass_lvl_3332      | 803.7384748458862  | 1367.996757884245  | 226.81587297911756   | -16.89384434227684 |
+-----------------------+--------------------+--------------------+----------------------+--------------------+
```

Differential Revision: D76310809

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155507
Approved by: https://github.com/ColinPeppler
Committed by: PyTorch MergeBot
Parent: 2ba930d4ce
Commit: b878ca0c91
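As a sanity check, the teraflops column follows directly from the forward time: a square 8192 matmul performs 2 * M * N * K floating-point operations. Quick arithmetic against the aten row (not part of the benchmark script, just a check on the table above):

```python
# Reproduce the teraflops column from the forward_time column.
M = N = K = 8192
flops = 2 * M * N * K                 # ~1.0995e12 FLOPs per matmul
forward_time_us = 967.1226739883423   # aten row in the table above

tflops = flops / (forward_time_us * 1e-6) / 1e12
print(f"{tflops:.1f}")                # ~1136.9, matching the table
```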
```diff
@@ -18,6 +18,7 @@ from triton.testing import do_bench
 import torch
 from torch._inductor import config as inductor_config
+from torch.testing._internal.inductor_utils import _quantize_rowwise


 log: logging.Logger = logging.getLogger(__name__)
```
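`_quantize_rowwise` is a PyTorch-internal test utility. For readers without the internals handy, here is a minimal sketch of the usual rowwise max-abs recipe, assuming it returns the fp8 tensor plus a float32 inverse (dequantization) scale, which is how the benchmark uses it below. This is a hypothetical stand-in, not the actual implementation:

```python
import torch

def quantize_rowwise_sketch(x: torch.Tensor, dtype: torch.dtype):
    # Hypothetical stand-in for
    # torch.testing._internal.inductor_utils._quantize_rowwise.
    fp8_max = torch.finfo(dtype).max  # 448.0 for torch.float8_e4m3fn
    # One scale per row, sized so the row's max |value| maps to fp8_max.
    inverse_scale = x.abs().amax(dim=-1, keepdim=True).float() / fp8_max
    x_fp8 = (x / inverse_scale).clamp(-fp8_max, fp8_max).to(dtype)
    # x is recovered (approximately) as x_fp8.float() * inverse_scale.
    return x_fp8, inverse_scale
```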
```diff
@@ -29,6 +30,7 @@ inductor_config.autotune_local_cache = False
 # uncomment for better debugging
 # inductor_config.force_disable_caches = True

+USE_FAST_ACCUM = True

 UNITS = {
     "name": "",
```
```diff
@@ -40,8 +42,9 @@ PERF_OVER_ATEN_STR: str = "perf_over_aten (%)"

 OP_NAMES = [
     "mm",
-    "addmm",
-    "bmm",
+    # "addmm",
+    # "bmm",
+    # "_scaled_mm",
 ]

 SHAPES = [
```
```diff
@@ -59,6 +62,7 @@ BATCH_SIZES = [
 DTYPES = [
     torch.float16,
     torch.bfloat16,
+    # torch.float8_e4m3fn,
 ]

 # triton knobs
```
```diff
@@ -72,7 +76,8 @@ CUTLASS_INSTANTIATION_LEVELS = [
     "0",
     # "1111",
     # "2222",
     "3333",
+    "3332",
     # "9992",
 ]
```
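The instantiation level controls how many CUTLASS template configurations inductor generates for autotuning; the test-plan table shows level 3332 spending about 227 s compiling versus about 14 s at the default, for a similar final matmul time. A sketch of how these knobs are typically wired up when driving the benchmark through torch.compile; the exact attribute names are assumptions and should be checked against torch._inductor.config in your build:

```python
from torch._inductor import config as inductor_config

# Autotune GEMMs with the CUTLASS backend only
# (attribute names assumed; verify against your PyTorch build).
inductor_config.max_autotune_gemm_backends = "CUTLASS"
inductor_config.cuda.cutlass_instantiation_level = "3332"
```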
```diff
@@ -199,6 +204,34 @@ def get_inputs(
         A = torch.randn(batch_size, M, K, dtype=dtype, device=device)
         B = torch.randn(batch_size, N, K, dtype=dtype, device=device).permute(0, 2, 1)
         return A, B
+    elif op_name == "_scaled_mm":
+        # For _scaled_mm, we only support fp8e4m3 with rowwise scaling
+        if dtype != torch.float8_e4m3fn:
+            raise ValueError(f"_scaled_mm only supports fp8e4m3, got {dtype}")
+
+        # Create input tensors in bfloat16 first, then quantize to fp8
+        input_dtype = torch.bfloat16
+        x = torch.randn(M, K, dtype=input_dtype, device=device)
+        w = torch.randn(N, K, dtype=input_dtype, device=device)
+
+        # Quantize using rowwise scaling
+        w_fp8, w_inverse_scale = _quantize_rowwise(w, dtype)
+        w_t_fp8 = w_fp8.t()
+        w_inverse_scale = w_inverse_scale.t()  # scale_b should be (1, N)
+
+        x_fp8, x_inverse_scale = _quantize_rowwise(x, dtype)
+
+        # Return inputs for _scaled_mm: (input, weight_t, scale_a, scale_b, bias, out, out_dtype, use_fast_accum)
+        return (
+            x_fp8,
+            w_t_fp8,
+            x_inverse_scale,
+            w_inverse_scale,
+            None,
+            None,
+            torch.bfloat16,
+            USE_FAST_ACCUM,
+        )
     else:
         raise ValueError(f"Unknown op {op_name}")
```
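The tuple above lines up positionally with torch._scaled_mm, so the harness can splat it straight into the op. A minimal eager-mode usage sketch, assuming an fp8-capable GPU (SM89+); note that torch._scaled_mm is a private op and its signature may shift between releases:

```python
import torch
from torch.testing._internal.inductor_utils import _quantize_rowwise

M = N = K = 8192
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
w = torch.randn(N, K, dtype=torch.bfloat16, device="cuda")

x_fp8, x_inverse_scale = _quantize_rowwise(x, torch.float8_e4m3fn)
w_fp8, w_inverse_scale = _quantize_rowwise(w, torch.float8_e4m3fn)

out = torch._scaled_mm(
    x_fp8,
    w_fp8.t(),                    # weight transposed to (K, N), as in get_inputs
    scale_a=x_inverse_scale,      # (M, 1) rowwise scales
    scale_b=w_inverse_scale.t(),  # (1, N)
    bias=None,
    out_dtype=torch.bfloat16,
    use_fast_accum=True,          # the only supported mode for this fp8 path right now
)
print(out.shape, out.dtype)       # torch.Size([8192, 8192]) torch.bfloat16
```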
```diff
@@ -361,6 +394,8 @@ def calculate_flops(op_name: str, shape: tuple[int, int, int], batch_size: int)
         return 2 * batch_size * M * N * K
     elif op_name == "addmm":
         return 2 * M * N * K + M * N
+    elif op_name == "_scaled_mm":
+        return 2 * M * N * K
     else:
         return 2 * M * N * K
```
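Counting _scaled_mm as a plain 2 * M * N * K matmul is reasonable: rowwise dequantization adds one multiply per output element, which is O(M * N) and negligible next to the multiply-accumulates. For the 8192 shape in the test plan:

```python
M = N = K = 8192
mac_flops = 2 * M * N * K        # multiply-accumulate work
scale_flops = M * N              # one dequant multiply per output element
print(scale_flops / mac_flops)   # 1 / (2 * K) ~= 6.1e-05
```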