[EP+DP] Optimize the little operations in the DeepGEMM + DeepEP low latency case (#19885)

Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Signed-off-by: Tyler Michael Smith <tysmith@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Author: Tyler Michael Smith
Date: 2025-06-23 14:07:47 -04:00
Commit: 68aaeb3749 (parent: c3649e4fee)
3 changed files with 263 additions and 18 deletions

File 1 of 3 (new test file)

@@ -0,0 +1,83 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
    silu_mul_fp8_quant_deep_gemm)
from vllm.platforms import current_platform

# (E, T, H, group_size, seed)
CASES = [
    (1, 1, 128, 64, 0),
    (1, 4, 128, 128, 0),
    (2, 4, 256, 128, 0),
    (32, 64, 256, 128, 0),
    (17, 31, 768, 128, 0),
]


@pytest.mark.parametrize("E,T,H,group_size,seed", CASES)
@torch.inference_mode()
def test_silu_mul_fp8_quant_deep_gemm(E, T, H, group_size, seed):
    current_platform.seed_everything(seed)

    # Input tensor of shape (E, T, 2*H)
    y = torch.randn((E, T, 2 * H), dtype=torch.float32, device="cuda")
    tokens_per_expert = torch.randint(
        low=0,
        high=T,
        size=(E, ),
        dtype=torch.int32,
        device="cuda",
    )

    # Run the Triton kernel
    y_q, y_s = silu_mul_fp8_quant_deep_gemm(y,
                                            tokens_per_expert,
                                            group_size=group_size,
                                            eps=1e-10)

    # Reference implementation
    fp8_info = torch.finfo(torch.float8_e4m3fn)
    fp8_max = fp8_info.max
    fp8_min = fp8_info.min
    eps = 1e-10

    # Compute silu activation and elementwise multiplication
    y1 = y[..., :H]
    y2 = y[..., H:]
    silu_x = y1 * torch.sigmoid(y1)
    merged = silu_x * y2

    # Compute reference scales and quantized output, skipping padded tokens
    for e in range(E):
        nt = tokens_per_expert[e].item()
        ref_s = torch.empty((T, H // group_size),
                            dtype=torch.float32,
                            device="cuda")
        ref_q = torch.empty((T, H), dtype=torch.float8_e4m3fn, device="cuda")
        for t in range(nt):
            data = merged[e, t]
            data_grp = data.view(H // group_size, group_size)
            amax = data_grp.abs().amax(dim=1).clamp(min=eps)
            scale = amax / fp8_max
            scaled = data / scale.repeat_interleave(group_size)
            clamped = scaled.clamp(fp8_min, fp8_max)
            q = clamped.to(torch.float8_e4m3fn)
            ref_s[t] = scale
            ref_q[t] = q

        y_se = y_s[e]
        y_qe = y_q[e]
        torch.testing.assert_close(y_se[:nt], ref_s[:nt])
        torch.testing.assert_close(
            y_qe[:nt].to(torch.float32),
            ref_q[:nt].to(torch.float32),
            atol=2,
            rtol=2e-1,
        )
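For readers skimming the test: both the reference loop above and the new Triton kernel reduce to a per-group scale of amax / fp8_max. A standalone sketch on a single made-up group (illustration only, not part of the diff):

    import torch

    group = torch.tensor([0.5, -3.0, 1.25, 0.0])           # one group of `group_size` values
    fp8_max = torch.finfo(torch.float8_e4m3fn).max          # 448.0 for float8_e4m3fn
    scale = group.abs().amax().clamp(min=1e-10) / fp8_max   # per-group scale
    q = (group / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    dequant = q.to(torch.float32) * scale                   # approximate round-trip of `group`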

File 2 of 3

@@ -6,14 +6,179 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.utils import (
    _resize_cache, per_token_group_quant_fp8)
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.triton_utils import tl, triton

logger = init_logger(__name__)

has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None


@triton.jit
def _silu_mul_fp8_quant_deep_gemm(
        # Pointers ------------------------------------------------------------
        input_ptr,  # 16-bit activations (E, T, 2*H)
        y_q_ptr,  # fp8 quantized activations (E, T, H)
        y_s_ptr,  # 16-bit scales (E, T, G)
        counts_ptr,  # int32 num tokens per expert (E)

        # Sizes ---------------------------------------------------------------
        H: tl.constexpr,  # hidden dimension (per output)
        GROUP_SIZE: tl.constexpr,  # elements per group (usually 128)

        # Strides for input (elements) ---------------------------------------
        stride_i_e,
        stride_i_t,
        stride_i_h,

        # Strides for y_q (elements) -----------------------------------------
        stride_yq_e,
        stride_yq_t,
        stride_yq_h,

        # Strides for y_s (elements) -----------------------------------------
        stride_ys_e,
        stride_ys_t,
        stride_ys_g,

        # Stride for counts (elements)
        stride_counts_e,

        # Numeric params ------------------------------------------------------
        eps: tl.constexpr,
        fp8_min: tl.constexpr,
        fp8_max: tl.constexpr,

        # Meta ---------------------------------------------------------------
        BLOCK: tl.constexpr,
):
    G = H // GROUP_SIZE

    # map program id -> (e, g)
    pid = tl.program_id(0)
    e = pid // G
    g = pid % G

    e = e.to(tl.int64)
    g = g.to(tl.int64)

    # number of valid tokens for this expert
    n_tokens = tl.load(counts_ptr + e * stride_counts_e).to(tl.int64)

    cols = tl.arange(0, BLOCK)
    cols = cols.to(tl.int64)
    mask_h = cols < BLOCK

    t = tl.zeros([], tl.int64)
    while t < n_tokens:
        base_i_offset = (e * stride_i_e + t * stride_i_t +
                         g * GROUP_SIZE * stride_i_h)
        base_yq_offset = (e * stride_yq_e + t * stride_yq_t +
                          g * GROUP_SIZE * stride_yq_h)
        base_ys_offset = e * stride_ys_e + t * stride_ys_t + g * stride_ys_g

        mask = mask_h
        x = tl.load(input_ptr + base_i_offset + cols * stride_i_h,
                    mask=mask,
                    other=0.0).to(tl.float32)
        y2 = tl.load(input_ptr + base_i_offset + H * stride_i_h +
                     cols * stride_i_h,
                     mask=mask,
                     other=0.0).to(tl.float32)

        x = x * (1.0 / (1.0 + tl.exp(-x)))
        y = x * y2

        _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
        y_s = _absmax / fp8_max
        y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

        tl.store(y_q_ptr + base_yq_offset + cols * stride_yq_h, y_q, mask=mask)
        tl.store(y_s_ptr + base_ys_offset, y_s)

        t += 1
def silu_mul_fp8_quant_deep_gemm(
    y: torch.Tensor,  # (E, T, 2*H) float32
    tokens_per_expert: torch.Tensor,  # (E,) number of valid tokens per expert
    group_size: int = 128,
    eps: float = 1e-10,
):
    """Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales

    y has shape (E, T, 2*H). The first half of the last dimension is
    silu-activated, multiplied by the second half, then quantized into FP8.

    Returns `(y_q, y_s)` where
    * `y_q` is the FP8 tensor of shape `(E, T, H)`, same layout as `y[..., :H]`.
    * `y_s` has shape `(E, T, H // group_size)` and strides `(T*G, 1, T)`.
    """
    assert y.ndim == 3, "y must be (E, T, 2*H)"
    E, T, H2 = y.shape
    assert H2 % 2 == 0, "last dim of y must be even (2*H)"
    H = H2 // 2
    G = H // group_size
    assert H % group_size == 0, "H must be divisible by group_size"
    assert tokens_per_expert.ndim == 1 and tokens_per_expert.shape[0] == E, \
        "tokens_per_expert must be shape (E,)"
    tokens_per_expert = tokens_per_expert.to(device=y.device,
                                             dtype=torch.int32)

    # allocate outputs
    fp8_dtype = torch.float8_e4m3fn
    y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device)

    # strides (elements)
    stride_i_e, stride_i_t, stride_i_h = y.stride()
    stride_yq_e, stride_yq_t, stride_yq_h = y_q.stride()

    # desired scale strides (elements): (T*G, 1, T)
    stride_ys_e = T * G
    stride_ys_t = 1
    stride_ys_g = T
    y_s = torch.empty_strided((E, T, G),
                              (stride_ys_e, stride_ys_t, stride_ys_g),
                              dtype=torch.float32,
                              device=y.device)

    stride_cnt_e = tokens_per_expert.stride()[0]

    # static grid over experts and H-groups.
    # A loop inside the kernel handles the token dim
    grid = (E * G, )

    f_info = torch.finfo(fp8_dtype)
    fp8_max = f_info.max
    fp8_min = f_info.min

    _silu_mul_fp8_quant_deep_gemm[grid](
        y,
        y_q,
        y_s,
        tokens_per_expert,
        H,
        group_size,
        stride_i_e,
        stride_i_t,
        stride_i_h,
        stride_yq_e,
        stride_yq_t,
        stride_yq_h,
        stride_ys_e,
        stride_ys_t,
        stride_ys_g,
        stride_cnt_e,
        eps,
        fp8_min,
        fp8_max,
        BLOCK=group_size,
        num_warps=4,
    )

    return y_q, y_s
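As a quick orientation for the new helper, a minimal usage sketch (shapes follow the docstring above; the concrete sizes and token counts are illustrative, not part of the diff):

    import torch
    from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
        silu_mul_fp8_quant_deep_gemm)

    E, T, H = 2, 8, 256  # experts, max tokens per expert, hidden size per output
    y = torch.randn((E, T, 2 * H), dtype=torch.float32, device="cuda")
    tokens_per_expert = torch.tensor([5, 8], dtype=torch.int32, device="cuda")

    y_q, y_s = silu_mul_fp8_quant_deep_gemm(y, tokens_per_expert, group_size=128)
    assert y_q.shape == (E, T, H)          # float8_e4m3fn activations
    assert y_s.shape == (E, T, H // 128)   # float32 group scales, strides (T*G, 1, T)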
class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):

    # The Deep Gemm kernels only support block size of 128
@@ -96,7 +261,6 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
            hidden_states, w1, w2, topk_ids)

        workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N))
        workspace2 = _resize_cache(workspace2, (E, max_num_tokens, N // 2))

        # (from deepgemm docs) : A value hint (which is a value on CPU)
        # for the M expectation of each batch, correctly setting this value
@@ -109,19 +273,9 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
                                                 masked_m=expert_num_tokens,
                                                 expected_m=expected_m)

        # TODO (varun) [Optimization]: Use a batched version of activation.
        # Similarly for the quant below.
        self.activation(activation, workspace2, workspace1.view(-1, N))

        w2_hidden_size = workspace2.size(-1)
        workspace2 = workspace2.view(-1, w2_hidden_size)

        a2q_scale: Optional[torch.Tensor] = None
        a2q, a2q_scale = per_token_group_quant_fp8(workspace2,
                                                   self.block_shape[1],
                                                   column_major_scales=False)
        a2q = a2q.view(E, max_num_tokens, -1)
        a2q_scale = a2q_scale.view(E, max_num_tokens, -1)

        assert expert_num_tokens is not None
        a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm(workspace1,
                                                      expert_num_tokens)

        dg.m_grouped_gemm_fp8_fp8_bf16_nt_masked((a2q, a2q_scale),
                                                 (w2, w2_scale),

File 3 of 3

@@ -45,7 +45,8 @@ if current_platform.is_cuda_alike():
    from .pplx_prepare_finalize import PplxPrepareAndFinalize
    if has_deepep:
        from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize
        from .deepep_ll_prepare_finalize import DeepEPLLPrepareAndFinalize
        from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SIZE,
                                                 DeepEPLLPrepareAndFinalize)
else:
    fused_experts = None  # type: ignore
    FusedMoEPermuteExpertsUnpermute = None  # type: ignore
@@ -377,6 +378,13 @@ class FusedMoEMethodBase(QuantizeMethodBase):
                all2all_manager.world_size)
            handle = all2all_manager.get_handle(all_to_all_args)

            # Note : We may want to use FP8 dispatch even otherwise just to
            # reduce datamovement
            assert act_quant_block_size is not None
            use_fp8_dispatch = (quant_dtype == current_platform.fp8_dtype()
                                and act_quant_block_size[1]
                                == DEEPEP_QUANT_BLOCK_SIZE)

            # Note (varun): Whether to use FP8 dispatch or not needs some
            # profiling. Turning it off for now.
            prepare_finalize = DeepEPLLPrepareAndFinalize(
@@ -386,7 +394,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
                max_tokens_per_rank=moe.max_num_tokens,
                quant_dtype=quant_dtype,
                block_shape=act_quant_block_size,
                use_fp8_dispatch=False,
                use_fp8_dispatch=use_fp8_dispatch,
            )
        self.topk_indices_dtype = None
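For concreteness, the new dispatch decision evaluates as in the following standalone sketch (the block shape and the DEEPEP_QUANT_BLOCK_SIZE value used here are assumptions for the example, not taken from the diff):

    import torch

    DEEPEP_QUANT_BLOCK_SIZE = 128          # assumed value of the imported constant
    quant_dtype = torch.float8_e4m3fn      # assumed to match current_platform.fp8_dtype()
    act_quant_block_size = [128, 128]      # assumed activation quant block shape

    # FP8 dispatch is enabled only when the activation quant group size
    # matches DeepEP's low-latency quant block size.
    use_fp8_dispatch = (quant_dtype == torch.float8_e4m3fn
                        and act_quant_block_size[1] == DEEPEP_QUANT_BLOCK_SIZE)
    assert use_fp8_dispatch is True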