mirror of https://github.com/vllm-project/vllm.git
[Kernel][Performance] Tweak MoE Batched silu_mul_fp8_quant_deep_gemm kernel (#21193)
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Commit dcc6cfb991 (parent dd572c0ab3), committed via GitHub.

This change replaces the kernel's data-dependent `while` loop over tokens with a `tl.range` loop that the Triton compiler can software-pipeline across `NUM_STAGES` stages, and retunes the launch configuration from `num_warps=4` to `num_warps=1` with `NUM_STAGES=8`.
```diff
@@ -55,6 +55,7 @@ def _silu_mul_fp8_quant_deep_gemm(
     # Meta ---------------------------------------------------------------
     BLOCK: tl.constexpr,
+    NUM_STAGES: tl.constexpr,
 ):
     G = H // GROUP_SIZE
 
```
```diff
@@ -73,8 +74,7 @@ def _silu_mul_fp8_quant_deep_gemm(
     cols = cols.to(tl.int64)
     mask_h = cols < BLOCK
 
-    t = tl.zeros([], tl.int64)
-    while t < n_tokens:
+    for t in tl.range(0, n_tokens, num_stages=NUM_STAGES):
         base_i_offset = (e * stride_i_e + t * stride_i_t +
                          g * GROUP_SIZE * stride_i_h)
         base_yq_offset = (e * stride_yq_e + t * stride_yq_t +
```
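The old `while` loop advances `t` with a data-dependent update, which Triton's pipeliner cannot transform; `tl.range(..., num_stages=...)` tells the compiler to software-pipeline the loop body, overlapping the global-memory loads of upcoming iterations with the compute and stores of the current one. A minimal, self-contained sketch of the same pattern, assuming a recent Triton release that supports `tl.range` (the `copy_kernel` name, shapes, and `NUM_STAGES=4` here are illustrative, not from this commit):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def copy_kernel(x_ptr, y_ptr, n_iters,
                BLOCK: tl.constexpr, NUM_STAGES: tl.constexpr):
    cols = tl.arange(0, BLOCK)
    # tl.range(..., num_stages=...) marks this loop for software pipelining:
    # the compiler may issue the load for iteration t+1 while iteration t
    # is still being processed, hiding global-memory latency.
    for t in tl.range(0, n_iters, num_stages=NUM_STAGES):
        x = tl.load(x_ptr + t * BLOCK + cols)
        tl.store(y_ptr + t * BLOCK + cols, x)


x = torch.randn(8 * 128, device="cuda")
y = torch.empty_like(x)
copy_kernel[(1,)](x, y, 8, BLOCK=128, NUM_STAGES=4)
assert torch.equal(x, y)
```

Each pipeline stage buffers one iteration's `BLOCK` elements, so a larger `NUM_STAGES` hides more latency at the cost of per-program register and shared-memory use.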
```diff
@@ -102,8 +102,6 @@ def _silu_mul_fp8_quant_deep_gemm(
         tl.store(y_q_ptr + base_yq_offset + cols * stride_yq_h, y_q, mask=mask)
         tl.store(y_s_ptr + base_ys_offset, y_s)
 
-        t += 1
-
 
 def silu_mul_fp8_quant_deep_gemm(
     y: torch.Tensor,  # (E, T, 2*H) float32
```
```diff
@@ -180,7 +178,8 @@ def silu_mul_fp8_quant_deep_gemm(
         fp8_max,
         is_blackwell_deep_gemm_used(),
         BLOCK=group_size,
-        num_warps=4,
+        NUM_STAGES=8,
+        num_warps=1,
     )
 
     return y_q, y_s
```
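One plausible reading of the retune, not stated in the commit itself: once the token loop is pipelined, a single warp per program can keep the memory pipeline busy, and dropping from `num_warps=4` to `num_warps=1` cuts per-program resource use so more programs can run concurrently per SM. A hedged micro-benchmark sketch for comparing launch configurations, reusing the toy `copy_kernel` and the `x`, `y` tensors from the sketch above (the commit reports no numbers; `triton.testing.do_bench` returns a latency estimate in milliseconds):

```python
import triton.testing


def bench(num_stages: int, num_warps: int) -> float:
    # Time the toy copy_kernel under a given (NUM_STAGES, num_warps)
    # launch configuration.
    return triton.testing.do_bench(
        lambda: copy_kernel[(1,)](
            x, y, 8, BLOCK=128,
            NUM_STAGES=num_stages, num_warps=num_warps,
        )
    )


print(f"num_warps=4, NUM_STAGES=1: {bench(1, 4):.3f} ms")
print(f"num_warps=1, NUM_STAGES=8: {bench(8, 1):.3f} ms")
```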