[Perf] Small optimizations for silu_mul_fp8_quant_deep_gemm (#23265)
Signed-off-by: mgoin <mgoin64@gmail.com>
benchmarks/kernels/benchmark_silu_mul_fp8_quant.py (new file, 77 lines added)
@@ -0,0 +1,77 @@
+#!/usr/bin/env python3
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import time
+
+import torch
+
+from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
+    silu_mul_fp8_quant_deep_gemm,
+)
+from vllm.platforms import current_platform
+
+
+def benchmark(E, T, H, G=128, runs=50):
+    current_platform.seed_everything(42)
+    y = torch.randn((E, T, 2 * H), dtype=torch.bfloat16, device="cuda")
+    tokens_per_expert = torch.randint(
+        T // 2, T, size=(E,), dtype=torch.int32, device="cuda"
+    )
+
+    # Warmup
+    for _ in range(10):
+        silu_mul_fp8_quant_deep_gemm(y, tokens_per_expert, group_size=G)
+    torch.cuda.synchronize()
+
+    # Benchmark
+    torch.cuda.synchronize()
+    start = time.perf_counter()
+    for _ in range(runs):
+        silu_mul_fp8_quant_deep_gemm(y, tokens_per_expert, group_size=G)
+    torch.cuda.synchronize()
+
+    avg_time = (time.perf_counter() - start) / runs * 1000
+
+    # Calculate actual work done (only count valid tokens)
+    actual_tokens = tokens_per_expert.sum().item()
+    actual_elements = actual_tokens * H
+
+    # GFLOPS: operations per element = exp + 3 muls + 1 div + quantization ops ≈ 8 ops
+    ops_per_element = 8
+    total_ops = actual_elements * ops_per_element
+    gflops = total_ops / (avg_time / 1000) / 1e9
+
+    # Memory bandwidth: bfloat16 inputs (2 bytes), fp8 output (1 byte), scales (4 bytes)
+    input_bytes = actual_tokens * 2 * H * 2  # 2*H bfloat16 inputs
+    output_bytes = actual_tokens * H * 1  # H fp8 outputs
+    scale_bytes = actual_tokens * (H // G) * 4  # scales in float32
+    total_bytes = input_bytes + output_bytes + scale_bytes
+    memory_bw = total_bytes / (avg_time / 1000) / 1e9
+
+    return avg_time, gflops, memory_bw
+
+
+configs = [
+    (8, 32, 1024),
+    (16, 64, 2048),
+    (32, 128, 4096),
+    # DeepSeekV3 Configs
+    (256, 16, 7168),
+    (256, 32, 7168),
+    (256, 64, 7168),
+    (256, 128, 7168),
+    (256, 256, 7168),
+    (256, 512, 7168),
+    (256, 1024, 7168),
+]
+
+print(f"GPU: {torch.cuda.get_device_name()}")
+print(f"{'Config':<20} {'Time(ms)':<10} {'GFLOPS':<10} {'GB/s':<10}")
+print("-" * 50)
+
+for E, T, H in configs:
+    try:
+        time_ms, gflops, gbps = benchmark(E, T, H)
+        print(f"E={E:3d},T={T:4d},H={H:4d} {time_ms:8.3f} {gflops:8.1f} {gbps:8.1f}")
+    except Exception:
+        print(f"E={E:3d},T={T:4d},H={H:4d} FAILED")
@@ -24,7 +24,7 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, group_size, seed):
     current_platform.seed_everything(seed)

     # Input tensor of shape (E, T, 2*H)
-    y = torch.randn((E, T, 2 * H), dtype=torch.float32, device="cuda")
+    y = torch.randn((E, T, 2 * H), dtype=torch.bfloat16, device="cuda")
     tokens_per_expert = torch.randint(
         low=0,
         high=T,
@@ -74,7 +74,7 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, group_size, seed):
         y_se = y_s[e]
         y_qe = y_q[e]

-        torch.testing.assert_close(y_se[:nt], ref_s[:nt])
+        torch.testing.assert_close(y_se[:nt], ref_s[:nt], atol=1e-4, rtol=1e-2)
         torch.testing.assert_close(
             y_qe[:nt].to(torch.float32),
             ref_q[:nt].to(torch.float32),
@@ -70,53 +70,51 @@ def _silu_mul_fp8_quant_deep_gemm(
     # number of valid tokens for this expert
     n_tokens = tl.load(counts_ptr + e * stride_counts_e).to(tl.int64)

-    cols = tl.arange(0, BLOCK)
-    cols = cols.to(tl.int64)
-    mask_h = cols < BLOCK
+    cols = tl.arange(0, BLOCK).to(tl.int64)
+    mask = cols < BLOCK
+
+    base_input_offset = e * stride_i_e + g * GROUP_SIZE * stride_i_h
+    base_gate_offset = base_input_offset + cols * stride_i_h
+    base_up_offset = base_input_offset + H * stride_i_h + cols * stride_i_h
+    base_yq_offset = (e * stride_yq_e + g * GROUP_SIZE * stride_yq_h +
+                      cols * stride_yq_h)
+    base_ys_offset = e * stride_ys_e + g * stride_ys_g

     for t in tl.range(0, n_tokens, num_stages=NUM_STAGES):
-        base_i_offset = (e * stride_i_e + t * stride_i_t +
-                         g * GROUP_SIZE * stride_i_h)
-        base_yq_offset = (e * stride_yq_e + t * stride_yq_t +
-                          g * GROUP_SIZE * stride_yq_h)
-        base_ys_offset = e * stride_ys_e + t * stride_ys_t + g * stride_ys_g
-
-        mask = mask_h
-        x = tl.load(input_ptr + base_i_offset + cols * stride_i_h,
-                    mask=mask,
-                    other=0.0).to(tl.float32)
-        y2 = tl.load(input_ptr + base_i_offset + H * stride_i_h +
-                     cols * stride_i_h,
-                     mask=mask,
-                     other=0.0)
+        gate = tl.load(input_ptr + base_gate_offset + t * stride_i_t,
+                       mask=mask,
+                       other=0.0).to(tl.float32)
+        up = tl.load(input_ptr + base_up_offset + t * stride_i_t,
+                     mask=mask,
+                     other=0.0)

-        x = x * (1.0 / (1.0 + tl.exp(-x)))
-        y = x * y2
+        gate = gate * (1.0 / (1.0 + tl.exp(-gate)))
+        y = gate * up

-        y_s = tl.maximum(tl.max(tl.abs(y)), eps) / fp8_max
-        if use_ue8m0:
-            y_s = tl.exp2(tl.ceil(tl.log2(y_s)))
-
+        _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
+        scale_raw = _absmax / fp8_max
+        y_s = tl.math.exp2(tl.ceil(
+            tl.log2(scale_raw))) if use_ue8m0 else scale_raw
         y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

-        tl.store(y_q_ptr + base_yq_offset + cols * stride_yq_h, y_q, mask=mask)
-        tl.store(y_s_ptr + base_ys_offset, y_s)
+        tl.store(y_q_ptr + base_yq_offset + t * stride_yq_t, y_q, mask=mask)
+        tl.store(y_s_ptr + base_ys_offset + t * stride_ys_t, y_s)


 def silu_mul_fp8_quant_deep_gemm(
-    y: torch.Tensor,  # (E, T, 2*H) float32
+    y: torch.Tensor,  # (E, T, 2*H)
     tokens_per_expert: torch.Tensor,  # (E,) number of valid tokens per expert
     group_size: int = 128,
     eps: float = 1e-10,
-):
+) -> tuple[torch.Tensor, torch.Tensor]:
     """Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales

     y has shape (E, T, 2*H). The first half of the last dimension is
     silu-activated, multiplied by the second half, then quantized into FP8.

     Returns `(y_q, y_s)` where
-    * `y_q` is the FP8 tensor of shape `(E, T, H)`, same layout as `y[..., :H]`.
-    * `y_s` has shape `(E, T, H // group_size)` and strides `(T*G, 1, T)`
+    * `y_q`: FP8 tensor, shape (E, T, H), same layout as y[..., :H]
+    * `y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T)
     """
     assert y.ndim == 3, "y must be (E, T, 2*H)"
     E, T, H2 = y.shape
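
Note: the hunk above changes the Triton kernel behind silu_mul_fp8_quant_deep_gemm (imported in the benchmark from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe). In the new scale computation, the use_ue8m0 path, gated by is_blackwell_deep_gemm_e8m0_used(), rounds each group scale up to a power of two. A tiny illustration of that rounding, restating the kernel's exp2(ceil(log2(x))) formula in plain Python (not part of the commit):

    import math

    def ue8m0_round(scale_raw: float) -> float:
        # exp2(ceil(log2(x))): round the scale up to the nearest power of two
        return 2.0 ** math.ceil(math.log2(scale_raw))

    print(ue8m0_round(0.0123))  # 0.015625 == 2**-6
    print(ue8m0_round(0.25))    # 0.25, already a power of two
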
@@ -148,7 +146,7 @@ def silu_mul_fp8_quant_deep_gemm(

     stride_cnt_e = tokens_per_expert.stride()[0]

-    # static grid over experts and H-groups.
+    # Static grid over experts and H-groups.
     # A loop inside the kernel handles the token dim
     grid = (E * G, )

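
Note on the grid above: the launch is a single 1-D grid over experts and H-groups, and the token dimension is handled by the loop inside the kernel. A rough size calculation (not part of the commit) for the DeepSeekV3-style shapes used in the benchmark:

    E, H, group_size = 256, 7168, 128
    G = H // group_size   # 56 groups along H
    grid = (E * G,)       # (14336,) programs; each loops over its expert's tokens
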
@@ -178,7 +176,7 @@ def silu_mul_fp8_quant_deep_gemm(
         fp8_max,
         is_blackwell_deep_gemm_e8m0_used(),
         BLOCK=group_size,
-        NUM_STAGES=8,
+        NUM_STAGES=4,
         num_warps=1,
     )

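
Note: a minimal usage sketch of the updated entry point, with shapes and dtypes taken from the benchmark and the revised docstring (not part of the commit; the exact FP8 dtype of y_q is platform dependent):

    import torch

    from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
        silu_mul_fp8_quant_deep_gemm,
    )

    E, T, H, G = 8, 32, 1024, 128
    y = torch.randn((E, T, 2 * H), dtype=torch.bfloat16, device="cuda")
    tokens_per_expert = torch.randint(
        T // 2, T, size=(E,), dtype=torch.int32, device="cuda"
    )

    # FP8-quantized silu(gate) * up, plus one float32 scale per 128-wide group
    y_q, y_s = silu_mul_fp8_quant_deep_gemm(y, tokens_per_expert, group_size=G)
    assert y_q.shape == (E, T, H)
    assert y_s.shape == (E, T, H // G)
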