[Model] Fix Ernie4.5MoE e_score_correction_bias parameter (#21586)

Signed-off-by: zhouchong <zhouchong03@baidu.com>
Co-authored-by: zhouchong <zhouchong03@baidu.com>
This commit is contained in:
xyxinyang
2025-07-25 21:02:53 +08:00
committed by GitHub
parent f3a683b7c9
commit c72f049cb4

View File

@ -123,14 +123,19 @@ class Ernie4_5_MoeMoE(nn.Module):
quant_config=None,
prefix=f"{prefix}.gate")
self.experts = FusedMoE(num_experts=config.moe_num_experts,
top_k=config.moe_k,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=True,
quant_config=quant_config,
prefix=f"{prefix}.experts")
self.gate.e_score_correction_bias = nn.Parameter(
torch.empty(config.moe_num_experts))
self.experts = FusedMoE(
num_experts=config.moe_num_experts,
top_k=config.moe_k,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=True,
quant_config=quant_config,
prefix=f"{prefix}.experts",
e_score_correction_bias=self.gate.e_score_correction_bias)
if self.moe_num_shared_experts is not None:
intermediate_size = (config.moe_intermediate_size *
@ -459,6 +464,10 @@ class Ernie4_5_MoeModel(nn.Module):
if "mtp" in name:
continue
if "e_score_correction_bias" in name:
name = name.replace("moe_statics", "gate")
loaded_weight = loaded_weight.squeeze(0)
for (param_name, weight_name, shard_id) in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name: