mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[model] Reduce medusa weight (#10454)
Signed-off-by: skylee-01 <497627264@qq.com>
This commit is contained in:
@ -61,14 +61,25 @@ class Medusa(nn.Module):
|
||||
self.truncated_vocab_size = config.truncated_vocab_size
|
||||
self.unpadded_vocab_size = self.truncated_vocab_size
|
||||
|
||||
self.lm_heads = nn.ModuleList([
|
||||
ParallelLMHead(
|
||||
if getattr(config, "original_lm_head", False):
|
||||
self.lm_head = ParallelLMHead(
|
||||
self.unpadded_vocab_size,
|
||||
config.hidden_size,
|
||||
org_num_embeddings=self.truncated_vocab_size,
|
||||
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
|
||||
) for _ in range(self.config.num_heads)
|
||||
])
|
||||
)
|
||||
self.lm_heads = [
|
||||
self.lm_head for _ in range(self.config.num_heads)
|
||||
]
|
||||
else:
|
||||
self.lm_heads = nn.ModuleList([
|
||||
ParallelLMHead(
|
||||
self.unpadded_vocab_size,
|
||||
config.hidden_size,
|
||||
org_num_embeddings=self.truncated_vocab_size,
|
||||
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
|
||||
) for _ in range(self.config.num_heads)
|
||||
])
|
||||
|
||||
logit_scale = getattr(config, "logit_scale", 1.0)
|
||||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||
@ -172,6 +183,9 @@ class Medusa(nn.Module):
|
||||
requires_grad=False)
|
||||
elif name in params_dict:
|
||||
weights_map[name] = loaded_weight
|
||||
elif (getattr(self.config, "original_lm_head", False)
|
||||
and name == "lm_heads.0.weight"):
|
||||
weights_map["lm_head.weight"] = loaded_weight
|
||||
|
||||
for name, loaded_weight in weights_map.items():
|
||||
if "lm_head" in name and self.token_map is not None and\
|
||||
|
Reference in New Issue
Block a user