Mirror of https://github.com/vllm-project/vllm.git, synced 2025-10-20 14:53:52 +08:00
[Minor][Models] Pass partial_rotary_factor parameter to rope (#17266)
Signed-off-by: evian <eviantai@u.nus.edu>
Co-authored-by: evian <eviantai@u.nus.edu>
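The change is the same in each model touched: the call site stops pre-scaling rotary_dim and instead forwards partial_rotary_factor to get_rope. A minimal sketch of the arithmetic that moves; the numbers are illustrative, and the assumption that get_rope now derives the effective rotary dimension from the factor is not shown in this diff:

    # Sketch only, not vllm's API. Hypothetical Phi-style values.
    head_dim = 128
    partial_rotary_factor = 0.5

    # Before: each attention module shrank the rotary dimension itself
    # and handed the result to get_rope as rotary_dim.
    rotary_dim_before = int(partial_rotary_factor * head_dim)

    # After: the module passes rotary_dim=head_dim plus the factor, and
    # get_rope is presumably expected to apply the same scaling internally.
    rotary_dim_after = int(head_dim * partial_rotary_factor)

    assert rotary_dim_before == rotary_dim_after == 64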
@@ -130,8 +130,8 @@ class LlamaAttention(nn.Module):
         self.head_dim = getattr(config, "head_dim",
                                 self.hidden_size // self.total_num_heads)
         # Phi models introduced a partial_rotary_factor parameter in the config
-        partial_rotary_factor = getattr(config, "partial_rotary_factor", 1)
-        self.rotary_dim = int(partial_rotary_factor * self.head_dim)
+        self.partial_rotary_factor = getattr(config, "partial_rotary_factor",
+                                             1)
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
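The default of 1 in the getattr lookup keeps models whose configs have no partial_rotary_factor on full rotary embeddings. A small sketch of that fallback, using stand-in objects rather than real transformers configs:

    from types import SimpleNamespace

    # Hypothetical stand-ins for model configs.
    llama_like = SimpleNamespace()                         # no partial_rotary_factor
    phi_like = SimpleNamespace(partial_rotary_factor=0.5)

    assert getattr(llama_like, "partial_rotary_factor", 1) == 1    # full rotary
    assert getattr(phi_like, "partial_rotary_factor", 1) == 0.5    # partial rotary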
@@ -163,11 +163,12 @@ class LlamaAttention(nn.Module):
 
         self.rotary_emb = get_rope(
             self.head_dim,
-            rotary_dim=self.rotary_dim,
+            rotary_dim=self.head_dim,
             max_position=max_position_embeddings,
             base=rope_theta,
             rope_scaling=rope_scaling,
             is_neox_style=is_neox_style,
+            partial_rotary_factor=self.partial_rotary_factor,
         )
 
         if hasattr(config, "interleaved_sliding_window"):
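The get_rope side of the change is not part of this diff; the new keyword presumably reduces the rotary dimension before the embedding is built, which is why rotary_dim can now simply be head_dim. A hedged sketch of that assumed behavior, with get_rope_sketch as a hypothetical stand-in rather than vllm's real get_rope:

    def get_rope_sketch(head_size: int,
                        rotary_dim: int,
                        max_position: int,
                        base: float,
                        partial_rotary_factor: float = 1.0) -> dict:
        """Stand-in illustrating the assumed handling of partial_rotary_factor."""
        if partial_rotary_factor < 1.0:
            # Only the leading fraction of each head is rotated.
            rotary_dim = int(rotary_dim * partial_rotary_factor)
        return {"head_size": head_size, "rotary_dim": rotary_dim,
                "max_position": max_position, "base": base}

    # With a factor of 0.5, a 64-wide head ends up with a 32-wide rotary slice,
    # matching what the old call sites computed by hand.
    assert get_rope_sketch(64, 64, 2048, 10000.0,
                           partial_rotary_factor=0.5)["rotary_dim"] == 32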
@@ -115,9 +115,10 @@ class PersimmonAttention(nn.Module):
 
         self.rotary_emb = get_rope(
             self.head_dim,
-            rotary_dim=int(self.partial_rotary_factor * self.head_dim),
+            rotary_dim=self.head_dim,
             max_position=self.max_position_embeddings,
             base=self.rope_theta,
+            partial_rotary_factor=self.partial_rotary_factor,
         )
         self.scaling = self.head_dim**-0.5
         self.attn = Attention(self.num_heads,
@@ -104,9 +104,8 @@ class StablelmAttention(nn.Module):
             1, self.total_num_key_value_heads // tp_size)
         self.head_dim = self.hidden_size // self.total_num_heads
         self.max_position_embeddings = config.max_position_embeddings
-        rope_pct = getattr(config, "rope_pct",
-                           getattr(config, "partial_rotary_factor", 1))
-        self.rotary_ndims = int(self.head_dim * rope_pct)
+        self.partial_rotary_factor = getattr(
+            config, "rope_pct", getattr(config, "partial_rotary_factor", 1))
         self.scaling = self.head_dim**-0.5
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_key_value_heads * self.head_dim
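StableLM configs expose this fraction under the older name rope_pct, so the lookup prefers rope_pct, then partial_rotary_factor, and otherwise falls back to rotating the full head dimension. A small sketch of that chain with stand-in configs (values illustrative):

    from types import SimpleNamespace

    def resolve_factor(config) -> float:
        # Mirrors the lookup in the diff above.
        return getattr(config, "rope_pct",
                       getattr(config, "partial_rotary_factor", 1))

    assert resolve_factor(SimpleNamespace(rope_pct=0.25)) == 0.25
    assert resolve_factor(SimpleNamespace(partial_rotary_factor=0.5)) == 0.5
    assert resolve_factor(SimpleNamespace()) == 1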
@@ -130,9 +129,10 @@ class StablelmAttention(nn.Module):
                                         prefix=f"{prefix}.o_proj")
         self.rotary_emb = get_rope(
             self.head_dim,
-            rotary_dim=self.rotary_ndims,
+            rotary_dim=self.head_dim,
             max_position=self.config.max_position_embeddings,
             base=self.config.rope_theta,
+            partial_rotary_factor=self.partial_rotary_factor,
         )
         self.attn = Attention(self.num_heads,
                               self.head_dim,