mirror of https://github.com/vllm-project/vllm.git
[BugFix] : Fix Batched DeepGemm Experts (#19515)
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Commit e3b12667d4 (parent e6aab5de29), committed by GitHub.
@@ -47,15 +47,21 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         N: int,
         K: int,
         topk: int,
-        num_experts: int,
+        global_num_experts: int,
+        local_num_experts: int,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         assert a.dim() == 2
-        num_dp = self.world_size // self.dp_size
+        # FIXME (varun): We should be able to dispatch only from the leader
+        # DP ranks in the case of TP > 1. At the moment, all the Ranks
+        # end up sending their tokens. This needs to be fixed.
+        num_dispatchers = self.world_size
+        num_experts = local_num_experts
         max_num_tokens = a.size(
             0) if self.max_num_tokens is None else self.max_num_tokens
-        workspace13 = (num_experts, max_num_tokens * num_dp, max(K, N))
-        workspace2 = (num_experts, max_num_tokens * num_dp, (N // 2))
-        output = (num_experts, max_num_tokens * num_dp, K)
+        workspace13 = (num_experts, max_num_tokens * num_dispatchers,
+                       max(K, N))
+        workspace2 = (num_experts, max_num_tokens * num_dispatchers, (N // 2))
+        output = (num_experts, max_num_tokens * num_dispatchers, K)
         return (workspace13, workspace2, output, a.dtype)
 
     def apply(
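The sizing change above is the core of the fix: with TP > 1 every rank currently dispatches its tokens (see the FIXME), so the per-expert receive buffers must be sized for world_size dispatchers rather than world_size // dp_size. A minimal standalone sketch of that sizing; batched_workspace_shapes is a hypothetical helper mirroring the hunk, not vLLM code:

def batched_workspace_shapes(max_num_tokens: int, world_size: int,
                             local_num_experts: int, N: int, K: int):
    # Every rank sends its tokens, so buffers hold max_num_tokens rows per
    # dispatching rank (pessimistic until leader-only dispatch is implemented).
    num_dispatchers = world_size
    rows = max_num_tokens * num_dispatchers
    workspace13 = (local_num_experts, rows, max(K, N))  # wide enough for either gemm
    workspace2 = (local_num_experts, rows, N // 2)      # half-width buffer, as above
    output = (local_num_experts, rows, K)
    return workspace13, workspace2, output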
@@ -84,9 +90,6 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         a1q = hidden_states
         _, N, K = w1.size()
 
-        if global_num_experts == -1:
-            global_num_experts = w1.size(0)
-
         assert w2.size(1) == K
 
         E, max_num_tokens, N, K, top_k_num = mk._moe_problem_size(
@@ -81,18 +81,19 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         N: int,
         K: int,
         topk: int,
-        num_experts: int,
+        global_num_experts: int,
+        local_num_experts: int,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         # Note: the deep gemm workspaces are strictly larger than the triton
         # workspaces so we can be pessimistic here and allocate for DeepGemm
         # even if we fall back to triton later, e.g. if expert maps are set.
         if self.allow_deep_gemm and self.batched_deep_gemm_experts is not None:
             return self.batched_deep_gemm_experts.workspace_shapes(
-                a, aq, M, N, K, topk, num_experts)
+                a, aq, M, N, K, topk, global_num_experts, local_num_experts)
         else:
             assert self.batched_triton_experts is not None
             return self.batched_triton_experts.workspace_shapes(
-                a, aq, M, N, K, topk, num_experts)
+                a, aq, M, N, K, topk, global_num_experts, local_num_experts)
 
     def apply(
         self,
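The Note above relies on the DeepGemm workspaces being at least as large as the Triton ones, so the wrapper can report one pessimistic set of shapes and still fall back to Triton later. A toy sketch of that idea; pick_workspace_shapes is illustrative, not vLLM code:

def pick_workspace_shapes(deep_gemm_shapes, triton_shapes,
                          deep_gemm_available: bool):
    # Both arguments are (workspace13, workspace2, output, dtype) tuples as
    # returned by workspace_shapes(). When DeepGemm might run, report its
    # (larger) shapes so a later fallback to Triton still fits the buffers.
    return deep_gemm_shapes if deep_gemm_available else triton_shapes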
@@ -230,7 +230,8 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
         N: int,
         K: int,
         topk: int,
-        num_experts: int,
+        global_num_experts: int,
+        local_num_experts: int,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         workspace1: tuple[int, ...] = ()
         workspace2: tuple[int, ...] = ()
@@ -74,15 +74,12 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         return True
 
     def workspace_shapes(
-        self,
-        a: torch.Tensor,
-        aq: torch.Tensor,
-        M: int,
-        N: int,
-        K: int,
-        topk: int,
-        num_experts: int,
+        self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int,
+        topk: int, global_num_experts: int, local_num_experts: int
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
+        # We use global_num_experts due to how moe_align_block_size handles
+        # expert_maps.
+        num_experts = global_num_experts
         block_m = self.block_shape[0]
         M_sum = (M * topk) + num_experts * (block_m - 1)
         M_sum = round_up(M_sum, block_m)
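The M_sum arithmetic above sizes the workspace for the worst-case block padding: each of the global experts can force up to block_m - 1 extra rows, and the total is then rounded up to a multiple of block_m. A worked example with assumed values, not taken from the diff:

M, topk, num_experts, block_m = 100, 6, 8, 128
M_sum = (M * topk) + num_experts * (block_m - 1)       # 600 + 8 * 127 = 1616
M_sum = ((M_sum + block_m - 1) // block_m) * block_m   # round_up -> 1664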
@@ -521,10 +521,12 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
         N: int,
         K: int,
         topk: int,
-        num_experts: int,
+        global_num_experts: int,
+        local_num_experts: int,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         assert a.dim() == 2
-        num_dp = self.world_size // self.dp_size
+        num_dp = self.dp_size
+        num_experts = local_num_experts
         workspace13 = (num_experts, self.max_num_tokens * num_dp, K)
         workspace2 = (self.max_num_tokens * num_dp, N)
         return (workspace13, workspace2, workspace13, a.dtype)
@@ -624,10 +626,12 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
         N: int,
         K: int,
         topk: int,
-        num_experts: int,
+        global_num_experts: int,
+        local_num_experts: int,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         assert a.dim() == 2
         num_dp = self.world_size // self.dp_size
+        num_experts = local_num_experts
         max_num_tokens = a.size(
             0) if self.max_num_tokens is None else self.max_num_tokens
         workspace13 = (num_experts, max_num_tokens * num_dp, max(K, N))
@@ -1553,7 +1553,8 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
         N: int,
         K: int,
         topk: int,
-        num_experts: int,
+        global_num_experts: int,
+        local_num_experts: int,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         workspace1 = (M, topk, max(N * 2, K))
         workspace2 = (M, topk, N)
@@ -194,7 +194,8 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
         N: int,
         K: int,
         topk: int,
-        num_experts: int,
+        global_num_experts: int,
+        local_num_experts: int,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         """
         Compute the shapes for the temporary and final outputs of the two gemms
@@ -372,8 +373,9 @@ class FusedMoEModularKernel(torch.nn.Module):
         a1 = hidden_states
         output = a1 if inplace else torch.zeros_like(a1)
 
+        local_num_experts = w1.size(0)
         if global_num_experts == -1:
-            global_num_experts = w1.size(0)
+            global_num_experts = local_num_experts
 
         (a1q, a1q_scale, expert_num_tokens, _expert_topk_ids,
          _expert_topk_weights) = self.prepare_finalize.prepare(
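The two added lines above make the local/global distinction explicit: w1 holds only this rank's experts, so its leading dimension is local_num_experts, and global_num_experts only defaults to that value when the -1 sentinel is passed (no expert-parallel mapping). A toy illustration with assumed weight shapes:

import torch

ep_size, experts_per_rank, hidden_out, hidden_in = 4, 8, 256, 128
w1_local = torch.empty(experts_per_rank, hidden_out, hidden_in)  # this rank's slice
local_num_experts = w1_local.size(0)              # 8
global_num_experts = ep_size * experts_per_rank   # 32 across all EP ranks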
@@ -408,16 +410,19 @@
         if num_chunks == 1:
             (workspace13_shape, workspace2_shape, fused_out_shape,
              workspace_dtype) = self.fused_experts.workspace_shapes(
-                 a1, a1q, M, N, K, top_k, global_num_experts)
+                 a1, a1q, M, N, K, top_k, global_num_experts,
+                 local_num_experts)
         else:
             # Use the full M to get the final output shape.
             _, _, fused_out_shape, _ = (
                 self.fused_experts.workspace_shapes(
-                    a1, a1q, M, N, K, top_k, global_num_experts))
+                    a1, a1q, M, N, K, top_k, global_num_experts,
+                    local_num_experts))
             # Use the CHUNK_SIZE to get the workspace shapes.
             workspace13_shape, workspace2_shape, _, workspace_dtype = (
                 self.fused_experts.workspace_shapes(
-                    a1, a1q, CHUNK_SIZE, N, K, top_k, global_num_experts))
+                    a1, a1q, CHUNK_SIZE, N, K, top_k, global_num_experts,
+                    local_num_experts))
 
         # We can reuse the memory between cache1 and cache3 because by the
         # time we need cache3, we're done with cache1.
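The chunked branch above queries workspace_shapes twice: once with the full M so the fused output covers every token, and once with CHUNK_SIZE so the scratch workspaces only need to hold one chunk. A simplified sketch of that pattern; chunked_shapes and its one-argument callable are illustrative, not the vLLM signature:

def chunked_shapes(workspace_shapes, M: int, chunk_size: int):
    # workspace_shapes(m) -> (workspace13, workspace2, output, dtype)
    _, _, fused_out_shape, _ = workspace_shapes(M)               # full-size output
    ws13, ws2, _, dtype = workspace_shapes(min(M, chunk_size))   # per-chunk scratch
    return ws13, ws2, fused_out_shape, dtype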
@@ -159,6 +159,12 @@ def moe_align_block_size(
     Aligns the token distribution across experts to be compatible with block
     size for matrix multiplication.
 
+    Note: In the case of expert_parallel, moe_align_block_size initially
+    considers all experts as valid and aligns all tokens appropriately.
+    Before the function returns it marks the experts_ids that are not in
+    the current GPU rank as -1 so the MoE matmuls could skip those blocks.
+    This requires the num_experts input arg to be the num global experts.
+
     Parameters:
     - topk_ids: A tensor of shape [total_tokens, top_k] representing the
         top-k expert indices for each token.
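The docstring note above describes the expert-parallel behaviour: alignment is done over all global experts, and block expert ids owned by other ranks are then set to -1 so the MoE matmuls skip those blocks. A toy PyTorch illustration of that masking step (not the actual kernel):

import torch

def mark_non_local_experts(expert_ids: torch.Tensor,
                           local_expert_start: int,
                           local_expert_end: int) -> torch.Tensor:
    # Keep expert ids owned by this rank; replace the rest with -1 so the
    # corresponding blocks can be skipped.
    is_local = (expert_ids >= local_expert_start) & (expert_ids < local_expert_end)
    return torch.where(is_local, expert_ids, torch.full_like(expert_ids, -1))

# A rank owning experts [4, 8) turns [0, 5, 7, 3] into [-1, 5, 7, -1].
print(mark_non_local_experts(torch.tensor([0, 5, 7, 3]), 4, 8))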
@@ -48,7 +48,8 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         N: int,
         K: int,
         topk: int,
-        num_experts: int,
+        global_num_experts: int,
+        local_num_experts: int,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         # Note: the deep gemm workspaces are strictly larger than the triton
         # workspaces so we can be pessimistic here and allocate for DeepGemm
@@ -56,10 +57,11 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         if self.allow_deep_gemm and _valid_deep_gemm_shape(M, N, K):
             assert self.deep_gemm_expert is not None
             return self.deep_gemm_expert.workspace_shapes(
-                a, aq, M, N, K, topk, num_experts)
+                a, aq, M, N, K, topk, global_num_experts, local_num_experts)
         else:
             return self.triton_expert.workspace_shapes(a, aq, M, N, K, topk,
-                                                       num_experts)
+                                                       global_num_experts,
+                                                       local_num_experts)
 
     def apply(
         self,