[Bugfix] Loosen type check to avoid errors in V1 (#15021)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-03-18 20:54:40 +08:00
committed by simon-mo
parent 54e084f7fb
commit be13281d4b
9 changed files with 28 additions and 37 deletions

View File

@ -25,7 +25,7 @@ from vllm.sequence import IntermediateTensors
from .blip import BlipVisionModel
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
# We use this internally as placeholders since there is no image token
@ -565,12 +565,11 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
return None
if pixel_values is not None:
if not isinstance(pixel_values, torch.Tensor):
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
# Remove the N dimension until multiple images are supported.
pixel_values = pixel_values.squeeze(1)
pixel_values = flatten_bn(pixel_values, concat=True)
return Blip2ImagePixelInputs(
type="pixel_values",
@ -578,12 +577,11 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
)
if image_embeds is not None:
if not isinstance(image_embeds, torch.Tensor):
if not isinstance(image_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
# Remove the N dimension until multiple images are supported.
image_embeds = image_embeds.squeeze(1)
image_embeds = flatten_bn(image_embeds, concat=True)
return Blip2ImageEmbeddingInputs(
type="image_embeds",

View File

@ -39,7 +39,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (is_pp_missing_parameter,
from .utils import (flatten_bn, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix, merge_multimodal_embeddings)
@ -972,12 +972,11 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
if pixel_values is None:
return None
if not isinstance(pixel_values, torch.Tensor):
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
# Remove the N dimension until multiple images are supported.
pixel_values = pixel_values.squeeze(1)
pixel_values = flatten_bn(pixel_values, concat=True)
return ChameleonImagePixelInputs(
type="pixel_values",

View File

@ -478,7 +478,7 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
flatten_bn(images_spatial_crop, concat=True)))
if image_embeds is not None:
if not isinstance(image_embeds, torch.Tensor):
if not isinstance(image_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")

View File

@ -578,7 +578,7 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
pixel_values = kwargs.pop("pixel_values", None)
if pixel_values is not None:
if not isinstance(pixel_values, torch.Tensor):
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")

View File

@ -838,7 +838,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
return None
if image_embeds is not None:
if not isinstance(image_embeds, torch.Tensor):
if not isinstance(image_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
@ -856,7 +856,9 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values_flat)}")
assert isinstance(image_num_patches, (torch.Tensor, list))
if not isinstance(image_num_patches, (torch.Tensor, list)):
raise ValueError("Incorrect type of image_num_patches. "
f"Got type: {type(pixel_values_flat)}")
return InternVLImagePixelInputs(
type="pixel_values",

View File

@ -349,21 +349,18 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
List[b, Tensor(nb_frames, nb_channels, height, width)]
}
"""
pixel_values = kwargs.pop("pixel_values_videos", None)
pixel_values_videos = kwargs.pop("pixel_values_videos", None)
if pixel_values is None:
if pixel_values_videos is None:
return None
if not (is_list_of(pixel_values,
(torch.Tensor)) # different shape videos
or isinstance(pixel_values,
torch.Tensor)): # same shape videos
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
if not isinstance(pixel_values_videos, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel_values_videos. "
f"Got type: {type(pixel_values_videos)}")
return LlavaNextVideoPixelInputs(
type="pixel_values_videos",
data=pixel_values,
data=pixel_values_videos,
)
def _select_image_features(self, image_features: torch.Tensor, *,

View File

@ -574,10 +574,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
if pixel_values_videos is None:
return None
if not (is_list_of(pixel_values_videos,
torch.Tensor) # different shape videos
or isinstance(pixel_values_videos,
torch.Tensor)): # same shape videos
if not isinstance(pixel_values_videos, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel_values_videos. "
f"Got type: {type(pixel_values_videos)}")

View File

@ -23,7 +23,7 @@ from vllm.sequence import IntermediateTensors
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
from .vision import get_vision_encoder_info
@ -270,12 +270,11 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
return None
if pixel_values is not None:
if not isinstance(pixel_values, torch.Tensor):
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
# Remove the N dimension until multiple images are supported.
pixel_values = pixel_values.squeeze(1)
pixel_values = flatten_bn(pixel_values, concat=True)
return PaliGemmaImagePixelInputs(
type="pixel_values",
@ -287,8 +286,7 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
# Remove the N dimension until multiple images are supported.
image_embeds = image_embeds.squeeze(1)
image_embeds = flatten_bn(image_embeds, concat=True)
return PaliGemmaImageEmbeddingInputs(
type="image_embeds",

View File

@ -711,7 +711,7 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA,
image_embeds = kwargs.pop("image_embeds", None)
if pixel_values is not None:
if not isinstance(pixel_values, torch.Tensor):
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
@ -722,13 +722,13 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA,
)
if image_embeds is not None:
if not isinstance(image_embeds, torch.Tensor):
if not isinstance(image_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
return QwenImageEmbeddingInputs(
type="image_embeds",
data=flatten_bn(image_embeds),
data=flatten_bn(image_embeds, concat=True),
)
return None