Mirror of https://github.com/huggingface/transformers.git, synced 2025-10-20 17:13:56 +08:00.
[Fix] Deepseek V3 expert bias routing (#41647)
* [Fix] Deepseek V3 expert bias routing
* [Fix] fix-copies
* [Fix] Run make style
@@ -176,9 +176,11 @@ class DeepseekV3MoE(nn.Module):

     def route_tokens_to_experts(self, router_logits):
         router_logits = router_logits.sigmoid()
-        router_logits = router_logits + self.gate.e_score_correction_bias
+        router_logits_for_choice = router_logits + self.gate.e_score_correction_bias
         group_scores = (
-            router_logits.view(-1, self.n_group, self.n_routed_experts // self.n_group).topk(2, dim=-1)[0].sum(dim=-1)
+            router_logits_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
+            .topk(2, dim=-1)[0]
+            .sum(dim=-1)
         )
         group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
         group_mask = torch.zeros_like(group_scores)
@@ -188,7 +190,7 @@ class DeepseekV3MoE(nn.Module):
             .expand(-1, self.n_group, self.n_routed_experts // self.n_group)
             .reshape(-1, self.n_routed_experts)
         )
-        scores_for_choice = router_logits.masked_fill(~score_mask.bool(), 0.0)
+        scores_for_choice = router_logits_for_choice.masked_fill(~score_mask.bool(), 0.0)
         topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
         topk_weights = router_logits.gather(1, topk_indices)
         if self.norm_topk_prob:
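What changed: the pre-fix code wrote the bias-corrected scores back into `router_logits`, so `e_score_correction_bias` leaked into the gating weights gathered at the end of the method. In DeepSeek-V3's auxiliary-loss-free routing, the correction bias is meant to influence only which experts get selected; the weights applied to the chosen experts' outputs should come from the unbiased sigmoid scores. The patch keeps the biased scores in a separate `router_logits_for_choice` tensor used for group selection and top-k choice, while the final `gather` still reads the unbiased `router_logits`. Below is a minimal standalone sketch of the corrected routing with the module state passed as explicit arguments; the normalization epsilon and `routed_scaling_factor` handling are assumptions for completeness, not part of this patch.

import torch

def route_tokens_to_experts(router_logits, e_score_correction_bias,
                            n_group, topk_group, n_routed_experts, top_k,
                            norm_topk_prob=True, routed_scaling_factor=1.0):
    scores = router_logits.sigmoid()
    # The bias-corrected scores are used only to *choose* experts ...
    scores_for_choice = scores + e_score_correction_bias
    group_scores = (
        scores_for_choice.view(-1, n_group, n_routed_experts // n_group)
        .topk(2, dim=-1)[0]
        .sum(dim=-1)
    )
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1]
    group_mask = torch.zeros_like(group_scores)
    group_mask.scatter_(1, group_idx, 1)
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(-1, n_group, n_routed_experts // n_group)
        .reshape(-1, n_routed_experts)
    )
    masked_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
    topk_indices = torch.topk(masked_choice, k=top_k, dim=-1, sorted=False)[1]
    # ... while the gating weights come from the unbiased sigmoid scores.
    topk_weights = scores.gather(1, topk_indices)
    if norm_topk_prob:
        # Epsilon assumed for numerical safety, not copied from the patch.
        topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-20)
    return topk_indices, topk_weights * routed_scaling_factor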
@@ -132,9 +132,11 @@ class DeepseekV3MoE(nn.Module):

     def route_tokens_to_experts(self, router_logits):
         router_logits = router_logits.sigmoid()
-        router_logits = router_logits + self.gate.e_score_correction_bias
+        router_logits_for_choice = router_logits + self.gate.e_score_correction_bias
         group_scores = (
-            router_logits.view(-1, self.n_group, self.n_routed_experts // self.n_group).topk(2, dim=-1)[0].sum(dim=-1)
+            router_logits_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
+            .topk(2, dim=-1)[0]
+            .sum(dim=-1)
         )
         group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
         group_mask = torch.zeros_like(group_scores)
@@ -144,7 +146,7 @@ class DeepseekV3MoE(nn.Module):
             .expand(-1, self.n_group, self.n_routed_experts // self.n_group)
             .reshape(-1, self.n_routed_experts)
         )
-        scores_for_choice = router_logits.masked_fill(~score_mask.bool(), 0.0)
+        scores_for_choice = router_logits_for_choice.masked_fill(~score_mask.bool(), 0.0)
         topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
         topk_weights = router_logits.gather(1, topk_indices)
         if self.norm_topk_prob:
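Per the commit message, `make fix-copies` propagates the identical change into the copied Glm4Moe and Glm4vMoe routing code below. For context on the machinery both hunks in each file feed into, here is a toy walk-through of the group-limited selection with hypothetical sizes (1 token, 4 groups of 2 experts); this is an illustration, not code from the repo.

import torch

# Bias-corrected scores for one token over 8 routed experts.
scores_for_choice = torch.tensor([[0.9, 0.1, 0.8, 0.7, 0.2, 0.1, 0.6, 0.5]])
n_group, per_group, topk_group = 4, 2, 2

# Score each group by the sum of its top-2 expert scores:
# groups -> [1.0, 1.5, 0.3, 1.1], so groups 1 and 3 survive with topk_group=2.
group_scores = (
    scores_for_choice.view(-1, n_group, per_group).topk(2, dim=-1)[0].sum(dim=-1)
)
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1]
group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)
score_mask = group_mask.unsqueeze(-1).expand(-1, n_group, per_group).reshape(1, -1)

# Experts outside the winning groups are zeroed before the final top-k:
# masked -> [[0.0, 0.0, 0.8, 0.7, 0.0, 0.0, 0.6, 0.5]]
masked = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)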
@@ -319,9 +319,11 @@ class Glm4MoeMoE(nn.Module):

     def route_tokens_to_experts(self, router_logits):
         router_logits = router_logits.sigmoid()
-        router_logits = router_logits + self.gate.e_score_correction_bias
+        router_logits_for_choice = router_logits + self.gate.e_score_correction_bias
         group_scores = (
-            router_logits.view(-1, self.n_group, self.n_routed_experts // self.n_group).topk(2, dim=-1)[0].sum(dim=-1)
+            router_logits_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
+            .topk(2, dim=-1)[0]
+            .sum(dim=-1)
         )
         group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
         group_mask = torch.zeros_like(group_scores)
@@ -331,7 +333,7 @@ class Glm4MoeMoE(nn.Module):
             .expand(-1, self.n_group, self.n_routed_experts // self.n_group)
             .reshape(-1, self.n_routed_experts)
         )
-        scores_for_choice = router_logits.masked_fill(~score_mask.bool(), 0.0)
+        scores_for_choice = router_logits_for_choice.masked_fill(~score_mask.bool(), 0.0)
         topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
         topk_weights = router_logits.gather(1, topk_indices)
         if self.norm_topk_prob:
@@ -376,9 +376,11 @@ class Glm4vMoeTextMoE(nn.Module):

     def route_tokens_to_experts(self, router_logits):
         router_logits = router_logits.sigmoid()
-        router_logits = router_logits + self.gate.e_score_correction_bias
+        router_logits_for_choice = router_logits + self.gate.e_score_correction_bias
         group_scores = (
-            router_logits.view(-1, self.n_group, self.n_routed_experts // self.n_group).topk(2, dim=-1)[0].sum(dim=-1)
+            router_logits_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
+            .topk(2, dim=-1)[0]
+            .sum(dim=-1)
        )
         group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
         group_mask = torch.zeros_like(group_scores)
@@ -388,7 +390,7 @@ class Glm4vMoeTextMoE(nn.Module):
             .expand(-1, self.n_group, self.n_routed_experts // self.n_group)
             .reshape(-1, self.n_routed_experts)
         )
-        scores_for_choice = router_logits.masked_fill(~score_mask.bool(), 0.0)
+        scores_for_choice = router_logits_for_choice.masked_fill(~score_mask.bool(), 0.0)
         topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
         topk_weights = router_logits.gather(1, topk_indices)
         if self.norm_topk_prob:
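A quick sanity check against the `route_tokens_to_experts` sketch above, with normalization disabled so the raw gathered weights are visible (hypothetical sizes and an exaggerated bias value):

import torch  # plus the route_tokens_to_experts sketch from earlier

torch.manual_seed(0)
logits = torch.randn(4, 8)              # 4 tokens, 8 routed experts
bias = torch.full((8,), 10.0)           # exaggerated correction bias
idx, w = route_tokens_to_experts(
    logits, bias, n_group=4, topk_group=2,
    n_routed_experts=8, top_k=2, norm_topk_prob=False,
)
# Weights come from sigmoid(logits) and stay in (0, 1); the pre-fix code
# would have gathered the bias-shifted scores instead (all > 10 here).
assert w.min() > 0 and w.max() < 1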