[Misc] Move Llama 4 projector call into encoder execution (#16201)

This commit is contained in:
Roger Wang
2025-04-07 14:02:05 -07:00
committed by GitHub
parent 090c856d76
commit ed636d99ca

View File

@@ -760,6 +760,8 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
flat_data = image_input["flat_data"]
patches_per_image = image_input["patches_per_image"].tolist()
vision_embeddings_flat = self.vision_model(flat_data)
vision_embeddings_flat = self.multi_modal_projector(
vision_embeddings_flat)
return vision_embeddings_flat.split(patches_per_image, dim=0)
def get_multimodal_embeddings(self,
@@ -791,10 +793,9 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
multimodal_embeddings = torch.cat(multimodal_embeddings)
mm_embeddings = self.multi_modal_projector(multimodal_embeddings)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, select_patch_features(mm_embeddings),
input_ids, inputs_embeds,
select_patch_features(multimodal_embeddings),
self.config.image_token_index)
return inputs_embeds