[model] Reduce medusa weight (#10454)

Signed-off-by: skylee-01 <497627264@qq.com>
This commit is contained in:
Sky Lee
2024-11-20 14:05:55 +08:00
committed by GitHub
parent ed701ca963
commit 343041c4c4

View File

@ -61,14 +61,25 @@ class Medusa(nn.Module):
self.truncated_vocab_size = config.truncated_vocab_size
self.unpadded_vocab_size = self.truncated_vocab_size
self.lm_heads = nn.ModuleList([
ParallelLMHead(
if getattr(config, "original_lm_head", False):
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=self.truncated_vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
) for _ in range(self.config.num_heads)
])
)
self.lm_heads = [
self.lm_head for _ in range(self.config.num_heads)
]
else:
self.lm_heads = nn.ModuleList([
ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=self.truncated_vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
) for _ in range(self.config.num_heads)
])
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
@ -172,6 +183,9 @@ class Medusa(nn.Module):
requires_grad=False)
elif name in params_dict:
weights_map[name] = loaded_weight
elif (getattr(self.config, "original_lm_head", False)
and name == "lm_heads.0.weight"):
weights_map["lm_head.weight"] = loaded_weight
for name, loaded_weight in weights_map.items():
if "lm_head" in name and self.token_map is not None and\