[Kernel][Performance] Tweak MoE Batched silu_mul_fp8_quant_deep_gemm kernel (#21193)

Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Varun Sundar Rabindranath
2025-07-19 11:39:51 +05:30
committed by GitHub
parent dd572c0ab3
commit dcc6cfb991


@@ -55,6 +55,7 @@ def _silu_mul_fp8_quant_deep_gemm(
     # Meta ---------------------------------------------------------------
     BLOCK: tl.constexpr,
+    NUM_STAGES: tl.constexpr,
 ):
     G = H // GROUP_SIZE
@@ -73,8 +74,7 @@ def _silu_mul_fp8_quant_deep_gemm(
     cols = cols.to(tl.int64)
     mask_h = cols < BLOCK
 
-    t = tl.zeros([], tl.int64)
-    while t < n_tokens:
+    for t in tl.range(0, n_tokens, num_stages=NUM_STAGES):
         base_i_offset = (e * stride_i_e + t * stride_i_t +
                          g * GROUP_SIZE * stride_i_h)
         base_yq_offset = (e * stride_yq_e + t * stride_yq_t +
@@ -102,8 +102,6 @@ def _silu_mul_fp8_quant_deep_gemm(
         tl.store(y_q_ptr + base_yq_offset + cols * stride_yq_h, y_q, mask=mask)
         tl.store(y_s_ptr + base_ys_offset, y_s)
-
-        t += 1
 
 
 def silu_mul_fp8_quant_deep_gemm(
     y: torch.Tensor,  # (E, T, 2*H) float32
@@ -180,7 +178,8 @@ def silu_mul_fp8_quant_deep_gemm(
         fp8_max,
         is_blackwell_deep_gemm_used(),
         BLOCK=group_size,
-        num_warps=4,
+        NUM_STAGES=8,
+        num_warps=1,
     )
 
     return y_q, y_s
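
For context on the change: the commit replaces a scalar `while t < n_tokens` loop over tokens with `tl.range(0, n_tokens, num_stages=NUM_STAGES)`, which lets the Triton compiler software-pipeline the loop body so loads for later token iterations overlap with compute and stores of earlier ones. The sketch below is a minimal standalone illustration of that pattern, not the vLLM kernel itself; the kernel name, the row-copy workload, and the tensor shapes are all made up for the example, while the `NUM_STAGES=8` / `num_warps=1` launch values mirror the ones chosen in this commit.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _pipelined_row_copy(
    x_ptr, y_ptr,
    n_rows, n_cols,
    row_stride,
    BLOCK: tl.constexpr,
    NUM_STAGES: tl.constexpr,
):
    cols = tl.arange(0, BLOCK)
    mask = cols < n_cols
    # tl.range with num_stages asks the compiler to software-pipeline the
    # loop, keeping the loads of upcoming iterations in flight while the
    # current iteration's store completes -- the same pattern the diff
    # switches the per-token loop to.
    for t in tl.range(0, n_rows, num_stages=NUM_STAGES):
        offs = t * row_stride + cols
        vals = tl.load(x_ptr + offs, mask=mask)
        tl.store(y_ptr + offs, vals, mask=mask)


def pipelined_row_copy(x: torch.Tensor) -> torch.Tensor:
    """Copy a 2D CUDA tensor row by row through the pipelined kernel."""
    y = torch.empty_like(x)
    n_rows, n_cols = x.shape
    _pipelined_row_copy[(1,)](
        x, y, n_rows, n_cols, x.stride(0),
        BLOCK=triton.next_power_of_2(n_cols),
        NUM_STAGES=8,  # mirrors the NUM_STAGES=8 chosen in the commit
        num_warps=1,   # mirrors the num_warps=1 chosen in the commit
    )
    return y
```

With a single warp per program and the work serialized over tokens, overlap across loop iterations is what keeps memory traffic in flight, which is presumably why `num_warps` drops from 4 to 1 alongside `NUM_STAGES=8` in the launch configuration.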