mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 23:03:52 +08:00
[Bugfix] Fix broken Florence-2 model (#23426)
Signed-off-by: 汪志鹏 <wangzhipeng628@gmail.com> Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn> Co-authored-by: 汪志鹏 <wangzhipeng628@gmail.com>
This commit is contained in:
@ -647,7 +647,8 @@ class Florence2LanguageModel(nn.Module):
|
|||||||
|
|
||||||
encoder_hidden_states = None
|
encoder_hidden_states = None
|
||||||
|
|
||||||
if inputs_embeds is not None or encoder_input_ids.numel() > 0:
|
if ((inputs_embeds is not None and inputs_embeds.numel() > 0)
|
||||||
|
or encoder_input_ids.numel() > 0):
|
||||||
# Run encoder attention if a non-zero number of encoder tokens
|
# Run encoder attention if a non-zero number of encoder tokens
|
||||||
# are provided as input
|
# are provided as input
|
||||||
encoder_hidden_states = self.encoder(input_ids=encoder_input_ids,
|
encoder_hidden_states = self.encoder(input_ids=encoder_input_ids,
|
||||||
@ -681,6 +682,8 @@ class Florence2LanguageForConditionalGeneration(nn.Module, SupportsV0Only):
|
|||||||
self.lm_head = BartParallelLMHead(self.vocab_size,
|
self.lm_head = BartParallelLMHead(self.vocab_size,
|
||||||
config.d_model,
|
config.d_model,
|
||||||
embed_scale=embed_scale)
|
embed_scale=embed_scale)
|
||||||
|
if self.config.tie_word_embeddings:
|
||||||
|
self.lm_head.tie_weights(self.model.shared)
|
||||||
|
|
||||||
self.logits_processor = LogitsProcessor(self.vocab_size,
|
self.logits_processor = LogitsProcessor(self.vocab_size,
|
||||||
config.vocab_size)
|
config.vocab_size)
|
||||||
@ -749,7 +752,8 @@ class Florence2LanguageForConditionalGeneration(nn.Module, SupportsV0Only):
|
|||||||
else:
|
else:
|
||||||
if "final_logits_bias" in name:
|
if "final_logits_bias" in name:
|
||||||
continue
|
continue
|
||||||
if self.config.tie_word_embeddings and "embed_tokens" in name:
|
if self.config.tie_word_embeddings and ("embed_tokens" in name
|
||||||
|
or "lm_head" in name):
|
||||||
continue
|
continue
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = getattr(param, "weight_loader",
|
weight_loader = getattr(param, "weight_loader",
|
||||||
|
Reference in New Issue
Block a user