Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
Fix stablelm (#3038)
@@ -43,6 +43,7 @@ _MODELS = {
     "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
     "RWForCausalLM": ("falcon", "FalconForCausalLM"),
     "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
+    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
 }

 # Models not supported by ROCm.
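With the added key, the legacy StableLM-Epoch architecture name and the renamed upstream one both map to the same ("stablelm", "StablelmForCausalLM") entry, so checkpoints whose config lists either spelling resolve to the same implementation. A minimal sketch of that lookup, using a stand-in _MODELS dict rather than vLLM's real loading code:

# Minimal sketch: both architecture names resolve to the same module/class pair.
# `_MODELS` here is a stand-in copy of the two relevant registry entries from
# the hunk above, not vLLM's actual registry module.
_MODELS = {
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
}

for arch in ("StableLMEpochForCausalLM", "StableLmForCausalLM"):
    module_name, class_name = _MODELS[arch]
    print(arch, "->", module_name, class_name)  # both print: stablelm StablelmForCausalLM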
@@ -94,7 +94,9 @@ class StablelmAttention(nn.Module):
             1, self.total_num_key_value_heads // tp_size)
         self.head_dim = self.hidden_size // self.total_num_heads
         self.max_position_embeddings = config.max_position_embeddings
-        self.rotary_ndims = int(self.head_dim * self.config.rope_pct)
+        rope_pct = getattr(config, "rope_pct",
+                           getattr(config, "partial_rotary_factor", 1))
+        self.rotary_ndims = int(self.head_dim * rope_pct)
         self.scaling = self.head_dim**-0.5
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_key_value_heads * self.head_dim
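The nested getattr makes the rotary dimension work for both config flavors: older StableLM-Epoch configs expose rope_pct, newer configs name the same value partial_rotary_factor, and if neither is present the factor defaults to 1 (full rotary dimension). A sketch of that fallback chain using toy config objects (not real Hugging Face config classes):

from types import SimpleNamespace

# Sketch of the fallback chain from the hunk above.
def rotary_ndims(config, head_dim: int) -> int:
    rope_pct = getattr(config, "rope_pct",
                       getattr(config, "partial_rotary_factor", 1))
    return int(head_dim * rope_pct)

legacy = SimpleNamespace(rope_pct=0.25)                  # old attribute name
upstream = SimpleNamespace(partial_rotary_factor=0.25)   # renamed attribute
bare = SimpleNamespace()                                 # neither set -> full rotary dim

print(rotary_ndims(legacy, 64), rotary_ndims(upstream, 64), rotary_ndims(bare, 64))
# 16 16 64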
@@ -114,7 +116,6 @@ class StablelmAttention(nn.Module):
             self.hidden_size,
             bias=False,
             linear_method=linear_method)
-        self.rotary_ndims = int(self.head_dim * self.config.rope_pct)
         self.rotary_emb = get_rope(
             self.head_dim,
             rotary_dim=self.rotary_ndims,
@@ -152,10 +153,11 @@ class StablelmDecoderLayer(nn.Module):
         super().__init__()
         self.self_attn = StablelmAttention(config)
         self.mlp = StablelmMLP(config, linear_method)
-        self.input_layernorm = nn.LayerNorm(config.hidden_size,
-                                            eps=config.norm_eps)
+        norm_eps = getattr(config, "norm_eps",
+                           getattr(config, "layer_norm_eps", 1e-05))
+        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
         self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
-                                                     eps=config.norm_eps)
+                                                     eps=norm_eps)

     def forward(
         self,
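The layer-norm epsilon follows the same pattern: older configs call it norm_eps, newer ones layer_norm_eps, with 1e-05 as the final fallback; both input_layernorm and post_attention_layernorm then share the resolved value. A toy illustration, again with stand-in config objects rather than the real config classes:

from types import SimpleNamespace

def resolve_norm_eps(config) -> float:
    # Same fallback chain as in the hunk above.
    return getattr(config, "norm_eps",
                   getattr(config, "layer_norm_eps", 1e-05))

print(resolve_norm_eps(SimpleNamespace(norm_eps=1e-06)))        # 1e-06
print(resolve_norm_eps(SimpleNamespace(layer_norm_eps=1e-05)))  # 1e-05
print(resolve_norm_eps(SimpleNamespace()))                      # 1e-05 (default)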
@@ -199,7 +201,9 @@ class StableLMEpochModel(nn.Module):
             StablelmDecoderLayer(config, linear_method)
             for _ in range(config.num_hidden_layers)
         ])
-        self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
+        norm_eps = getattr(config, "norm_eps",
+                           getattr(config, "layer_norm_eps", 1e-05))
+        self.norm = nn.LayerNorm(config.hidden_size, eps=norm_eps)

     def forward(
         self,