JetMoe Fix jetmoe after #40132 (#41324)

* update

* up
Arthur authored on 2025-10-04 11:02:13 +02:00, committed by GitHub
parent 1bc75db9bd
commit e11a00a16f
2 changed files with 5 additions and 4 deletions

src/transformers/models/jetmoe/modeling_jetmoe.py

@@ -490,10 +490,10 @@ class JetMoeDecoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: JetMoeConfig, layer_idx: Optional[int] = None):
         super().__init__()
         self.hidden_size = config.hidden_size
-        self.self_attn = JetMoeAttention(config, layer_idx)
         self.mlp = JetMoeMoE(config)
         self.input_layernorm = JetMoeRMSNorm(config.hidden_size)
         self.post_attention_layernorm = JetMoeRMSNorm(config.hidden_size)
+        self.self_attention = JetMoeAttention(config, layer_idx)
 
     @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
     def forward(
@@ -510,7 +510,7 @@ class JetMoeDecoderLayer(GradientCheckpointingLayer):
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
         # Self Attention
-        hidden_states, _, _ = self.self_attn(
+        hidden_states, _, _ = self.self_attention(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
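A minimal sketch of what this rename means for downstream code, assuming a transformers install that already contains this commit; the default JetMoeConfig is used only to build a single layer for introspection, not a trained model:

# Sketch, not part of the commit: the decoder layer now exposes its attention
# module as `self_attention` instead of `self_attn`.
from transformers import JetMoeConfig
from transformers.models.jetmoe.modeling_jetmoe import JetMoeDecoderLayer

config = JetMoeConfig()                          # default, untrained configuration
layer = JetMoeDecoderLayer(config, layer_idx=0)

assert hasattr(layer, "self_attention")          # renamed attribute
assert not hasattr(layer, "self_attn")           # old name no longer exists

Any external code or forward hooks that reached into layer.self_attn will need the same rename.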

src/transformers/models/jetmoe/modular_jetmoe.py

@@ -374,9 +374,10 @@ class JetMoeDecoderLayer(LlamaDecoderLayer):
     def __init__(self, config: JetMoeConfig, layer_idx: Optional[int] = None):
         super().__init__(config, layer_idx)
         self.input_layernorm = JetMoeRMSNorm(config.hidden_size)
-        self.self_attn = JetMoeAttention(config, layer_idx)
+        self.self_attention = JetMoeAttention(config, layer_idx)
         self.post_attention_layernorm = JetMoeRMSNorm(config.hidden_size)
         self.mlp = JetMoeMoE(config)
+        del self.self_attn
 
     def forward(
         self,
@@ -392,7 +393,7 @@ class JetMoeDecoderLayer(LlamaDecoderLayer):
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
         # Self Attention
-        hidden_states, _, _ = self.self_attn(
+        hidden_states, _, _ = self.self_attention(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
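For reference, a generic sketch of the pattern the modular file uses above: the subclass lets the parent __init__ create self_attn, registers the attention module under the new name, and then deletes the inherited attribute so only self_attention remains. ParentLayer and ChildLayer here are illustrative stand-ins, not the actual LlamaDecoderLayer / JetMoeDecoderLayer classes.

# Illustrative stand-ins only: ParentLayer plays the role of LlamaDecoderLayer,
# ChildLayer the role of the modular JetMoeDecoderLayer. The point is the
# rename-then-delete pattern, not the real layer internals.
import torch.nn as nn


class ParentLayer(nn.Module):
    def __init__(self, hidden_size: int):
        super().__init__()
        self.self_attn = nn.Linear(hidden_size, hidden_size)  # submodule name set by the parent


class ChildLayer(ParentLayer):
    def __init__(self, hidden_size: int):
        super().__init__(hidden_size)
        self.self_attention = nn.Linear(hidden_size, hidden_size)  # register under the new name
        del self.self_attn  # nn.Module.__delattr__ drops the inherited submodule


layer = ChildLayer(8)
print(hasattr(layer, "self_attention"), hasattr(layer, "self_attn"))  # True False
print([name for name, _ in layer.named_modules() if name])            # ['self_attention']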