Fix stablelm (#3038)

Roy committed on 2024-02-27 10:31:10 +08:00 (committed by GitHub)
parent c1c0d00b88
commit 4dd6416faf
2 changed files with 11 additions and 6 deletions


@@ -43,6 +43,7 @@ _MODELS = {
     "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
     "RWForCausalLM": ("falcon", "FalconForCausalLM"),
     "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
+    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
 }
 # Models not supported by ROCm.
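With both keys registered, checkpoints that declare the legacy StableLMEpochForCausalLM architecture and those that declare the transformers-native StableLmForCausalLM resolve to the same vLLM model class. A minimal sketch of that lookup, assuming a registry shaped like _MODELS above (the resolver function itself is illustrative, not vLLM's actual loader code):

from transformers import PretrainedConfig

_MODELS = {
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
}

def resolve_model_cls(config: PretrainedConfig) -> tuple[str, str]:
    # Try each architecture string the checkpoint advertises; both the
    # legacy and the transformers-native name map to the same class.
    for arch in getattr(config, "architectures", None) or []:
        if arch in _MODELS:
            return _MODELS[arch]
    raise ValueError(f"Unsupported architectures: {config.architectures}")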


@@ -94,7 +94,9 @@ class StablelmAttention(nn.Module):
             1, self.total_num_key_value_heads // tp_size)
         self.head_dim = self.hidden_size // self.total_num_heads
         self.max_position_embeddings = config.max_position_embeddings
-        self.rotary_ndims = int(self.head_dim * self.config.rope_pct)
+        rope_pct = getattr(config, "rope_pct",
+                           getattr(config, "partial_rotary_factor", 1))
+        self.rotary_ndims = int(self.head_dim * rope_pct)
         self.scaling = self.head_dim**-0.5
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_key_value_heads * self.head_dim
@@ -114,7 +116,6 @@ class StablelmAttention(nn.Module):
                                         self.hidden_size,
                                         bias=False,
                                         linear_method=linear_method)
-        self.rotary_ndims = int(self.head_dim * self.config.rope_pct)
         self.rotary_emb = get_rope(
             self.head_dim,
             rotary_dim=self.rotary_ndims,
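The motivation for the change above: older StableLM-Epoch configs expose the partial-rotary fraction as rope_pct, while the transformers-native StableLmConfig calls it partial_rotary_factor; the nested getattr accepts either name and falls back to 1 (full rotary). The second hunk also drops a duplicate assignment of self.rotary_ndims that still read the possibly-absent config.rope_pct. A self-contained sketch of the fallback, using stand-in config objects rather than the real config classes:

from types import SimpleNamespace

def rotary_ndims(config, head_dim: int) -> int:
    # Prefer the legacy "rope_pct" attribute, then the transformers-native
    # "partial_rotary_factor", then default to 1 (rotate the full head_dim).
    rope_pct = getattr(config, "rope_pct",
                       getattr(config, "partial_rotary_factor", 1))
    return int(head_dim * rope_pct)

old_cfg = SimpleNamespace(rope_pct=0.25)               # StableLM-Epoch style
new_cfg = SimpleNamespace(partial_rotary_factor=0.25)  # StableLmConfig style
assert rotary_ndims(old_cfg, 64) == rotary_ndims(new_cfg, 64) == 16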
@@ -152,10 +153,11 @@ class StablelmDecoderLayer(nn.Module):
         super().__init__()
         self.self_attn = StablelmAttention(config)
         self.mlp = StablelmMLP(config, linear_method)
-        self.input_layernorm = nn.LayerNorm(config.hidden_size,
-                                            eps=config.norm_eps)
+        norm_eps = getattr(config, "norm_eps",
+                           getattr(config, "layer_norm_eps", 1e-05))
+        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
         self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
-                                                     eps=config.norm_eps)
+                                                     eps=norm_eps)

     def forward(
         self,
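The same rename affects the LayerNorm epsilon: legacy configs call it norm_eps, the transformers-native StableLmConfig calls it layer_norm_eps, and 1e-05 is the fallback when neither is present. An illustrative helper showing the pattern, again with a stand-in config object:

import torch.nn as nn
from types import SimpleNamespace

def make_layernorm(config, hidden_size: int) -> nn.LayerNorm:
    # Accept either attribute name; 1e-05 matches the usual LayerNorm default.
    norm_eps = getattr(config, "norm_eps",
                       getattr(config, "layer_norm_eps", 1e-05))
    return nn.LayerNorm(hidden_size, eps=norm_eps)

ln = make_layernorm(SimpleNamespace(layer_norm_eps=1e-06), 2048)
assert ln.eps == 1e-06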
@@ -199,7 +201,9 @@ class StableLMEpochModel(nn.Module):
             StablelmDecoderLayer(config, linear_method)
             for _ in range(config.num_hidden_layers)
         ])
-        self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
+        norm_eps = getattr(config, "norm_eps",
+                           getattr(config, "layer_norm_eps", 1e-05))
+        self.norm = nn.LayerNorm(config.hidden_size, eps=norm_eps)

     def forward(
         self,
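To check which attribute names a given checkpoint actually carries, one can inspect its config directly. The snippet below assumes a transformers version that ships the native StableLm port (>= 4.38):

from transformers import StableLmConfig

cfg = StableLmConfig()  # defaults of the transformers-native config
# The native config exposes partial_rotary_factor / layer_norm_eps,
# not the legacy rope_pct / norm_eps names.
print(hasattr(cfg, "rope_pct"))                        # False
print(cfg.partial_rotary_factor, cfg.layer_norm_eps)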