fix moe routing_weights (#39581)

* fix moe routing_weights

* fix ernie4_5_moe routing_weights

* fix integration test

---------

Co-authored-by: llbdyiu66 <llbdyiu66@users.noreply.github.com>
Co-authored-by: Vasqu <antonprogamer@gmail.com>
Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>
Author: llbdyiu66
Date: 2025-07-23 19:20:23 +08:00
Committed by: GitHub
Parent: 623ab01039
Commit: a62f65a989
3 changed files with 5 additions and 11 deletions

View File

@@ -339,12 +339,9 @@ class Ernie4_5_MoESparseMoeBlock(nn.Module):
         # router_logits: (batch * sequence_length, n_experts)
         router_logits = self.gate(hidden_states.float())
 
-        # NOTE: we are using the original code base at
-        # https://github.com/PaddlePaddle/Paddle/blob/9b40438ce0f6d76b4f08a7837dd1e28b26cf8ee6/python/paddle/incubate/nn/functional/moe_gate_dispatch.py#L109-L116
-        # this might differ from the remote version regarding the bias (see `Ernie4_5_MoEStatics`)
         routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
-        routing_weights = self.moe_statics(routing_weights)
-        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
+        _, selected_experts = torch.topk(self.moe_statics(routing_weights), self.top_k, dim=-1)
+        routing_weights = torch.gather(routing_weights, dim=-1, index=selected_experts)
         routing_weights = routing_weights / torch.clamp(
             routing_weights.sum(dim=-1, keepdim=True), min=self.norm_min
         )
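The change above keeps the bias-corrected scores from `Ernie4_5_MoEStatics` for expert selection only, while the mixing weights are now gathered from the raw softmax probabilities. Below is a minimal standalone sketch of that routing path, not the actual module: the per-expert correction bias is a hypothetical stand-in for what `moe_statics` applies.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_tokens, num_experts, top_k, norm_min = 4, 8, 2, 1e-12

# gate output, shape (num_tokens, num_experts)
router_logits = torch.randn(num_tokens, num_experts)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)

# hypothetical stand-in for the per-expert correction applied by `Ernie4_5_MoEStatics`
e_score_correction_bias = torch.randn(num_experts)

# bias-corrected scores pick the experts ...
_, selected_experts = torch.topk(routing_weights + e_score_correction_bias, top_k, dim=-1)
# ... but the mixing weights are the raw gate probabilities at those indices
routing_weights = torch.gather(routing_weights, dim=-1, index=selected_experts)
routing_weights = routing_weights / torch.clamp(
    routing_weights.sum(dim=-1, keepdim=True), min=norm_min
)

print(selected_experts)          # (num_tokens, top_k) expert ids
print(routing_weights.sum(-1))   # each row normalizes to 1 over the selected experts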

View File

@@ -150,12 +150,9 @@ class Ernie4_5_MoESparseMoeBlock(nn.Module):
         # router_logits: (batch * sequence_length, n_experts)
         router_logits = self.gate(hidden_states.float())
 
-        # NOTE: we are using the original code base at
-        # https://github.com/PaddlePaddle/Paddle/blob/9b40438ce0f6d76b4f08a7837dd1e28b26cf8ee6/python/paddle/incubate/nn/functional/moe_gate_dispatch.py#L109-L116
-        # this might differ from the remote version regarding the bias (see `Ernie4_5_MoEStatics`)
         routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
-        routing_weights = self.moe_statics(routing_weights)
-        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
+        _, selected_experts = torch.topk(self.moe_statics(routing_weights), self.top_k, dim=-1)
+        routing_weights = torch.gather(routing_weights, dim=-1, index=selected_experts)
         routing_weights = routing_weights / torch.clamp(
             routing_weights.sum(dim=-1, keepdim=True), min=self.norm_min
         )
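The modular file receives the identical fix. To illustrate why it changes model outputs (and hence the integration expectation below), here is a hedged single-token comparison of the old and new weighting; the bias values are illustrative only, not taken from the model.

import torch
import torch.nn.functional as F

probs = F.softmax(torch.tensor([[2.0, 1.0, 0.5, 0.1]]), dim=1)  # gate probabilities, 1 token x 4 experts
bias = torch.tensor([-1.0, 0.5, 0.0, 0.0])                      # illustrative correction bias
top_k = 2

# old behaviour: the top-k *biased* scores were reused as mixing weights
old_vals, old_idx = torch.topk(probs + bias, top_k, dim=-1)
old_weights = old_vals / old_vals.sum(dim=-1, keepdim=True)

# new behaviour: the biased scores only select experts; raw probabilities are gathered
_, new_idx = torch.topk(probs + bias, top_k, dim=-1)
new_weights = torch.gather(probs, -1, new_idx)
new_weights = new_weights / new_weights.sum(dim=-1, keepdim=True)

print(old_idx, old_weights)  # experts [1, 2], weights skewed by the bias (~0.85 / 0.15)
print(new_idx, new_weights)  # same experts, but weights follow the gate (~0.62 / 0.38)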

View File

@@ -181,7 +181,7 @@ class Ernie4_5_MoEIntegrationTest(unittest.TestCase):
     @require_bitsandbytes
     @slow
     def test_model_21b_a3b_generation(self):
-        EXPECTED_TEXT_COMPLETION = "User: Hey, are you conscious? Can you talk to me?\nAssistant: Yes, I am conscious and I can communicate with you. How can I assist you with any questions or information you need?" # fmt: skip
+        EXPECTED_TEXT_COMPLETION = "User: Hey, are you conscious? Can you talk to me?\nAssistant: I don't have consciousness in the way humans do. I'm a text-based AI created to process and generate responses based on patterns in data." # fmt: skip
         model = self.get_model()
         tokenizer = AutoTokenizer.from_pretrained("baidu/ERNIE-4.5-21B-A3B-PT", revision="refs/pr/11")
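Only the expected completion changes in this test; the surrounding generation check is untouched. For context, a check like this typically continues roughly as sketched below, reusing the `model` and `tokenizer` from the lines above; the prompt format, `max_new_tokens`, and greedy decoding are assumptions, not the test's actual code.

# rough continuation sketch, not the actual test body
prompt = "User: Hey, are you conscious? Can you talk to me?\nAssistant:"   # assumed prompt format
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
output_ids = model.generate(**inputs, max_new_tokens=64, do_sample=False)  # greedy decoding assumed
text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
# the test would then compare against the expected string, e.g.
# self.assertEqual(EXPECTED_TEXT_COMPLETION, text)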