[BugFix] : Fix Batched DeepGemm Experts (#19515)

Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
This commit is contained in:
Varun Sundar Rabindranath
2025-06-12 22:43:02 -04:00
committed by GitHub
parent e6aab5de29
commit e3b12667d4
9 changed files with 52 additions and 32 deletions

View File

@ -47,15 +47,21 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
N: int,
K: int,
topk: int,
num_experts: int,
global_num_experts: int,
local_num_experts: int,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
assert a.dim() == 2
num_dp = self.world_size // self.dp_size
# FIXME (varun): We should be able to dispatch only from the leader
# DP ranks in the case of TP > 1. At the moment, all the Ranks
# end up sending their tokens. This needs to be fixed.
num_dispatchers = self.world_size
num_experts = local_num_experts
max_num_tokens = a.size(
0) if self.max_num_tokens is None else self.max_num_tokens
workspace13 = (num_experts, max_num_tokens * num_dp, max(K, N))
workspace2 = (num_experts, max_num_tokens * num_dp, (N // 2))
output = (num_experts, max_num_tokens * num_dp, K)
workspace13 = (num_experts, max_num_tokens * num_dispatchers,
max(K, N))
workspace2 = (num_experts, max_num_tokens * num_dispatchers, (N // 2))
output = (num_experts, max_num_tokens * num_dispatchers, K)
return (workspace13, workspace2, output, a.dtype)
def apply(
@ -84,9 +90,6 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
a1q = hidden_states
_, N, K = w1.size()
if global_num_experts == -1:
global_num_experts = w1.size(0)
assert w2.size(1) == K
E, max_num_tokens, N, K, top_k_num = mk._moe_problem_size(

View File

@ -81,18 +81,19 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
N: int,
K: int,
topk: int,
num_experts: int,
global_num_experts: int,
local_num_experts: int,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
# Note: the deep gemm workspaces are strictly larger than the triton
# workspaces so we can be pessimistic here and allocate for DeepGemm
# even if we fall back to triton later, e.g. if expert maps are set.
if self.allow_deep_gemm and self.batched_deep_gemm_experts is not None:
return self.batched_deep_gemm_experts.workspace_shapes(
a, aq, M, N, K, topk, num_experts)
a, aq, M, N, K, topk, global_num_experts, local_num_experts)
else:
assert self.batched_triton_experts is not None
return self.batched_triton_experts.workspace_shapes(
a, aq, M, N, K, topk, num_experts)
a, aq, M, N, K, topk, global_num_experts, local_num_experts)
def apply(
self,

View File

@ -230,7 +230,8 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
N: int,
K: int,
topk: int,
num_experts: int,
global_num_experts: int,
local_num_experts: int,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
workspace1: tuple[int, ...] = ()
workspace2: tuple[int, ...] = ()

View File

@ -74,15 +74,12 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
return True
def workspace_shapes(
self,
a: torch.Tensor,
aq: torch.Tensor,
M: int,
N: int,
K: int,
topk: int,
num_experts: int,
self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int,
topk: int, global_num_experts: int, local_num_experts: int
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
# We use global_num_experts due to how moe_align_block_size handles
# expert_maps.
num_experts = global_num_experts
block_m = self.block_shape[0]
M_sum = (M * topk) + num_experts * (block_m - 1)
M_sum = round_up(M_sum, block_m)

View File

@ -521,10 +521,12 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
N: int,
K: int,
topk: int,
num_experts: int,
global_num_experts: int,
local_num_experts: int,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
assert a.dim() == 2
num_dp = self.world_size // self.dp_size
num_dp = self.dp_size
num_experts = local_num_experts
workspace13 = (num_experts, self.max_num_tokens * num_dp, K)
workspace2 = (self.max_num_tokens * num_dp, N)
return (workspace13, workspace2, workspace13, a.dtype)
@ -624,10 +626,12 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
N: int,
K: int,
topk: int,
num_experts: int,
global_num_experts: int,
local_num_experts: int,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
assert a.dim() == 2
num_dp = self.world_size // self.dp_size
num_experts = local_num_experts
max_num_tokens = a.size(
0) if self.max_num_tokens is None else self.max_num_tokens
workspace13 = (num_experts, max_num_tokens * num_dp, max(K, N))

View File

@ -1553,7 +1553,8 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
N: int,
K: int,
topk: int,
num_experts: int,
global_num_experts: int,
local_num_experts: int,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
workspace1 = (M, topk, max(N * 2, K))
workspace2 = (M, topk, N)

View File

@ -194,7 +194,8 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
N: int,
K: int,
topk: int,
num_experts: int,
global_num_experts: int,
local_num_experts: int,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
"""
Compute the shapes for the temporary and final outputs of the two gemms
@ -372,8 +373,9 @@ class FusedMoEModularKernel(torch.nn.Module):
a1 = hidden_states
output = a1 if inplace else torch.zeros_like(a1)
local_num_experts = w1.size(0)
if global_num_experts == -1:
global_num_experts = w1.size(0)
global_num_experts = local_num_experts
(a1q, a1q_scale, expert_num_tokens, _expert_topk_ids,
_expert_topk_weights) = self.prepare_finalize.prepare(
@ -408,16 +410,19 @@ class FusedMoEModularKernel(torch.nn.Module):
if num_chunks == 1:
(workspace13_shape, workspace2_shape, fused_out_shape,
workspace_dtype) = self.fused_experts.workspace_shapes(
a1, a1q, M, N, K, top_k, global_num_experts)
a1, a1q, M, N, K, top_k, global_num_experts,
local_num_experts)
else:
# Use the full M to get the final output shape.
_, _, fused_out_shape, _ = (
self.fused_experts.workspace_shapes(
a1, a1q, M, N, K, top_k, global_num_experts))
a1, a1q, M, N, K, top_k, global_num_experts,
local_num_experts))
# Use the CHUNK_SIZE to get the workspace shapes.
workspace13_shape, workspace2_shape, _, workspace_dtype = (
self.fused_experts.workspace_shapes(
a1, a1q, CHUNK_SIZE, N, K, top_k, global_num_experts))
a1, a1q, CHUNK_SIZE, N, K, top_k, global_num_experts,
local_num_experts))
# We can reuse the memory between cache1 and cache3 because by the
# time we need cache3, we're done with cache1.

View File

@ -159,6 +159,12 @@ def moe_align_block_size(
Aligns the token distribution across experts to be compatible with block
size for matrix multiplication.
Note: In the case of expert_parallel, moe_align_block_size initially
considers all experts as valid and aligns all tokens appropriately.
Before the function returns it marks the experts_ids that are not in
the current GPU rank as -1 so the MoE matmuls could skip those blocks.
This requires the num_experts input arg to be the num global experts.
Parameters:
- topk_ids: A tensor of shape [total_tokens, top_k] representing the
top-k expert indices for each token.

View File

@ -48,7 +48,8 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
N: int,
K: int,
topk: int,
num_experts: int,
global_num_experts: int,
local_num_experts: int,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
# Note: the deep gemm workspaces are strictly larger than the triton
# workspaces so we can be pessimistic here and allocate for DeepGemm
@ -56,10 +57,11 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
if self.allow_deep_gemm and _valid_deep_gemm_shape(M, N, K):
assert self.deep_gemm_expert is not None
return self.deep_gemm_expert.workspace_shapes(
a, aq, M, N, K, topk, num_experts)
a, aq, M, N, K, topk, global_num_experts, local_num_experts)
else:
return self.triton_expert.workspace_shapes(a, aq, M, N, K, topk,
num_experts)
global_num_experts,
local_num_experts)
def apply(
self,