mirror of
https://github.com/huggingface/transformers.git
synced 2025-10-20 17:13:56 +08:00
@ -490,10 +490,10 @@ class JetMoeDecoderLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config: JetMoeConfig, layer_idx: Optional[int] = None):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.self_attn = JetMoeAttention(config, layer_idx)
|
||||
self.mlp = JetMoeMoE(config)
|
||||
self.input_layernorm = JetMoeRMSNorm(config.hidden_size)
|
||||
self.post_attention_layernorm = JetMoeRMSNorm(config.hidden_size)
|
||||
self.self_attention = JetMoeAttention(config, layer_idx)
|
||||
|
||||
@deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
|
||||
def forward(
|
||||
@ -510,7 +510,7 @@ class JetMoeDecoderLayer(GradientCheckpointingLayer):
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
# Self Attention
|
||||
hidden_states, _, _ = self.self_attn(
|
||||
hidden_states, _, _ = self.self_attention(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
|
@ -374,9 +374,10 @@ class JetMoeDecoderLayer(LlamaDecoderLayer):
|
||||
def __init__(self, config: JetMoeConfig, layer_idx: Optional[int] = None):
|
||||
super().__init__(config, layer_idx)
|
||||
self.input_layernorm = JetMoeRMSNorm(config.hidden_size)
|
||||
self.self_attn = JetMoeAttention(config, layer_idx)
|
||||
self.self_attention = JetMoeAttention(config, layer_idx)
|
||||
self.post_attention_layernorm = JetMoeRMSNorm(config.hidden_size)
|
||||
self.mlp = JetMoeMoE(config)
|
||||
del self.self_attn
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -392,7 +393,7 @@ class JetMoeDecoderLayer(LlamaDecoderLayer):
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
# Self Attention
|
||||
hidden_states, _, _ = self.self_attn(
|
||||
hidden_states, _, _ = self.self_attention(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
|
Reference in New Issue
Block a user