[Bugfix][v1] fixed llava-hf/llava-1.5-7b-hf is broken on V1 (#14554)

Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Chauncey
2025-03-11 02:24:51 +08:00
committed by GitHub
parent bc2d4473bf
commit 92b0ce2ac7

View File

@ -783,15 +783,19 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
if image_input is None:
return None
vision_embeddings = self._process_image_input(image_input)
if kwargs.get("v0_path", False):
if kwargs.get("v0_path", False) or \
image_input.get("feat_is_patch") is None or \
image_input.get("embed_is_patch") is None:
# The path is used for pixtral (V0 only) and llava (V0/V1)
return vision_embeddings
else:
nested_emb = [
self._get_mm_embeds(*args) for args in zip(
vision_embeddings, image_input["feat_is_patch"],
image_input["num_crops"], image_input["embed_is_patch"])
]
return flatten_2d_lists(nested_emb)
nested_emb = [
self._get_mm_embeds(*args) for args in zip(
vision_embeddings, image_input["feat_is_patch"],
image_input["num_crops"], image_input["embed_is_patch"])
]
return flatten_2d_lists(nested_emb)
def get_input_embeddings(
self,