[Fix] Deepseek V3 expert bias routing (#41647)

* [Fix] Deepseek V3 expert bias routing

* [Fix] fix-copies

* [Fix] Run make style
This commit is contained in:
Fabian Joswig
2025-10-16 16:04:48 +02:00
committed by GitHub
parent 1fb3fc4db0
commit 8725ce10ed
4 changed files with 20 additions and 12 deletions

View File

@@ -176,9 +176,11 @@ class DeepseekV3MoE(nn.Module):
def route_tokens_to_experts(self, router_logits):
router_logits = router_logits.sigmoid()
router_logits = router_logits + self.gate.e_score_correction_bias
router_logits_for_choice = router_logits + self.gate.e_score_correction_bias
group_scores = (
router_logits.view(-1, self.n_group, self.n_routed_experts // self.n_group).topk(2, dim=-1)[0].sum(dim=-1)
router_logits_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
.topk(2, dim=-1)[0]
.sum(dim=-1)
)
group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
group_mask = torch.zeros_like(group_scores)
@@ -188,7 +190,7 @@ class DeepseekV3MoE(nn.Module):
.expand(-1, self.n_group, self.n_routed_experts // self.n_group)
.reshape(-1, self.n_routed_experts)
)
scores_for_choice = router_logits.masked_fill(~score_mask.bool(), 0.0)
scores_for_choice = router_logits_for_choice.masked_fill(~score_mask.bool(), 0.0)
topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
topk_weights = router_logits.gather(1, topk_indices)
if self.norm_topk_prob:

View File

@@ -132,9 +132,11 @@ class DeepseekV3MoE(nn.Module):
def route_tokens_to_experts(self, router_logits):
router_logits = router_logits.sigmoid()
router_logits = router_logits + self.gate.e_score_correction_bias
router_logits_for_choice = router_logits + self.gate.e_score_correction_bias
group_scores = (
router_logits.view(-1, self.n_group, self.n_routed_experts // self.n_group).topk(2, dim=-1)[0].sum(dim=-1)
router_logits_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
.topk(2, dim=-1)[0]
.sum(dim=-1)
)
group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
group_mask = torch.zeros_like(group_scores)
@@ -144,7 +146,7 @@ class DeepseekV3MoE(nn.Module):
.expand(-1, self.n_group, self.n_routed_experts // self.n_group)
.reshape(-1, self.n_routed_experts)
)
scores_for_choice = router_logits.masked_fill(~score_mask.bool(), 0.0)
scores_for_choice = router_logits_for_choice.masked_fill(~score_mask.bool(), 0.0)
topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
topk_weights = router_logits.gather(1, topk_indices)
if self.norm_topk_prob:

View File

@@ -319,9 +319,11 @@ class Glm4MoeMoE(nn.Module):
def route_tokens_to_experts(self, router_logits):
router_logits = router_logits.sigmoid()
router_logits = router_logits + self.gate.e_score_correction_bias
router_logits_for_choice = router_logits + self.gate.e_score_correction_bias
group_scores = (
router_logits.view(-1, self.n_group, self.n_routed_experts // self.n_group).topk(2, dim=-1)[0].sum(dim=-1)
router_logits_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
.topk(2, dim=-1)[0]
.sum(dim=-1)
)
group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
group_mask = torch.zeros_like(group_scores)
@@ -331,7 +333,7 @@ class Glm4MoeMoE(nn.Module):
.expand(-1, self.n_group, self.n_routed_experts // self.n_group)
.reshape(-1, self.n_routed_experts)
)
scores_for_choice = router_logits.masked_fill(~score_mask.bool(), 0.0)
scores_for_choice = router_logits_for_choice.masked_fill(~score_mask.bool(), 0.0)
topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
topk_weights = router_logits.gather(1, topk_indices)
if self.norm_topk_prob:

View File

@@ -376,9 +376,11 @@ class Glm4vMoeTextMoE(nn.Module):
def route_tokens_to_experts(self, router_logits):
router_logits = router_logits.sigmoid()
router_logits = router_logits + self.gate.e_score_correction_bias
router_logits_for_choice = router_logits + self.gate.e_score_correction_bias
group_scores = (
router_logits.view(-1, self.n_group, self.n_routed_experts // self.n_group).topk(2, dim=-1)[0].sum(dim=-1)
router_logits_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
.topk(2, dim=-1)[0]
.sum(dim=-1)
)
group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
group_mask = torch.zeros_like(group_scores)
@@ -388,7 +390,7 @@ class Glm4vMoeTextMoE(nn.Module):
.expand(-1, self.n_group, self.n_routed_experts // self.n_group)
.reshape(-1, self.n_routed_experts)
)
scores_for_choice = router_logits.masked_fill(~score_mask.bool(), 0.0)
scores_for_choice = router_logits_for_choice.masked_fill(~score_mask.bool(), 0.0)
topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
topk_weights = router_logits.gather(1, topk_indices)
if self.norm_topk_prob: