diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py
index e0d95758a8..c584dce793 100644
--- a/examples/offline_inference/vision_language_multi_image.py
+++ b/examples/offline_inference/vision_language_multi_image.py
@@ -548,7 +548,7 @@ def load_keye_vl1_5(question: str, image_urls: list[str]) -> ModelRequestData:
     engine_args = EngineArgs(
         model=model_name,
         trust_remote_code=True,
-        max_model_len=8192,
+        max_model_len=32768,
         max_num_seqs=5,
         limit_mm_per_prompt={"image": len(image_urls)},
     )
diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py
index 567793e9b7..5e69c10b40 100644
--- a/vllm/model_executor/models/idefics3.py
+++ b/vllm/model_executor/models/idefics3.py
@@ -53,7 +53,7 @@ from .idefics2_vision_model import (
 # yapf: enable
 from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
 from .llama import LlamaModel
-from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix
+from .utils import AutoWeightsLoader, maybe_prefix


 class Idefics3ImagePixelInputs(TensorSchema):
@@ -67,7 +67,7 @@ class Idefics3ImagePixelInputs(TensorSchema):
     """
     type: Literal["pixel_values"]
     pixel_values: Annotated[torch.Tensor, TensorShape("bnp", 3, "h", "w")]
-    pixel_attention_mask: torch.Tensor
+    pixel_attention_mask: Annotated[torch.Tensor, TensorShape("bnp", "h", "w")]
     num_patches: Annotated[torch.Tensor, TensorShape("bn")]


@@ -569,6 +569,8 @@ class Idefics3Model(nn.Module):
                                         dummy_inputs=Idefics3DummyInputsBuilder)
 class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
                                        SupportsLoRA):
+    merge_by_field_config = True
+
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -621,37 +623,21 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
             return None

         if image_embeds is not None:
-            if not isinstance(image_embeds, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of image embeddings. "
-                                 f"Got type: {type(image_embeds)}")
-
             return Idefics3ImageEmbeddingInputs(
                 type="image_embeds",
-                data=flatten_bn(image_embeds, concat=True),
+                data=image_embeds,
             )

         if pixel_values is not None:
-            if not isinstance(pixel_values, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of pixel values. "
-                                 f"Got type: {type(pixel_values)}")
-
             pixel_attention_mask = kwargs.pop("pixel_attention_mask")
-            if not isinstance(pixel_attention_mask, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of pixel_attention_mask. "
-                                 f"Got type: {type(pixel_attention_mask)}")
-
             num_patches = kwargs.pop("num_patches")
-            if not isinstance(num_patches, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of num_patches. "
-                                 f"Got type: {type(num_patches)}")
-
             expected_h = expected_w = self.config.vision_config.image_size
+
             return Idefics3ImagePixelInputs(
                 type="pixel_values",
-                pixel_values=flatten_bn(pixel_values, concat=True),
-                pixel_attention_mask=flatten_bn(pixel_attention_mask,
-                                                concat=True),
-                num_patches=flatten_bn(num_patches, concat=True),
+                pixel_values=pixel_values,
+                pixel_attention_mask=pixel_attention_mask,
+                num_patches=num_patches,
                 resolve_bindings={
                     "h": expected_h,
                     "w": expected_w
diff --git a/vllm/model_executor/models/keye.py b/vllm/model_executor/models/keye.py
index 20f705cca8..dda24bb784 100644
--- a/vllm/model_executor/models/keye.py
+++ b/vllm/model_executor/models/keye.py
@@ -30,7 +30,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.models.module_mapping import MultiModelKeys
-from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
+from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (ImageItem, ModalityData,
                                     MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalKwargsItems, VideoItem)
@@ -42,7 +42,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
-from vllm.utils import is_list_of
 from vllm.utils.tensor_schema import TensorSchema, TensorShape

 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
@@ -100,8 +99,7 @@ def smart_resize(
 class KeyeImagePixelInputs(TensorSchema):
     """
     Dimensions:
-        - b: Batch size
-        - np: Number of patches
+        - bnp: Batch size * Number of patches
         - c: Number of channels
         - ps: Patch size
         - ni: Number of images
@@ -110,7 +108,7 @@ class KeyeImagePixelInputs(TensorSchema):
     type: Literal["pixel_values"]
     pixel_values: Annotated[
         torch.Tensor,
-        TensorShape("b", "np", 3, "ps", "ps", dynamic_dims={"np"})]
+        TensorShape("bnp", 3, "ps", "ps", dynamic_dims={"bnp"})]
     image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]


@@ -134,8 +132,7 @@ KeyeImageInputs = Union[KeyeImagePixelInputs, KeyeImageEmbeddingInputs]
 class KeyeVideoPixelInputs(TensorSchema):
     """
     Dimensions:
-        - b: Batch size
-        - np: Number of patches
+        - bnp: Batch size * Number of patches
         - c: Number of channels
         - ps: Patch size
         - ni: Number of images
@@ -144,7 +141,7 @@ class KeyeVideoPixelInputs(TensorSchema):
     type: Literal["pixel_values_videos"]
     pixel_values_videos: Annotated[
         torch.Tensor,
-        TensorShape("b", "np", 3, "ps", "ps", dynamic_dims={"np"})]
+        TensorShape("bnp", 3, "ps", "ps", dynamic_dims={"bnp"})]
     video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]


@@ -1258,6 +1255,8 @@ class KeyeMultiModalProcessor(BaseMultiModalProcessor[KeyeProcessingInfo]):


 class BaseKeyeModule(nn.Module):
+    merge_by_field_config = True
+
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -1524,28 +1523,6 @@ class KeyeForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
                          prefix: str = "") -> nn.Module:
         return Projector(text_config, vision_config, quant_config, prefix)

-    def _validate_and_reshape_mm_tensor(
-            self, mm_input: NestedTensors,
-            name: str) -> Union[torch.Tensor, list[torch.Tensor]]:
-        if not isinstance(mm_input, (torch.Tensor, list)):
-            raise ValueError(f"Incorrect type of {name}. "
-                             f"Got type: {type(mm_input)}")
-        if isinstance(mm_input, torch.Tensor):
-            if mm_input.ndim == 2:
-                return mm_input
-            if mm_input.ndim == 5:
-                return mm_input
-            if mm_input.ndim != 3:
-                raise ValueError(f"{name} should be 2D or batched 3D tensor. "
-                                 f"Got ndim: {mm_input.ndim} "
-                                 f"(shape={mm_input.shape})")
-            return mm_input.reshape(-1, mm_input.shape[-1])
-        elif is_list_of(mm_input, torch.Tensor):
-            if all(p.dim() == 4 for p in mm_input) or all(p.dim() == 2
-                                                          for p in mm_input):
-                return mm_input
-            return torch.concat(mm_input)
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[KeyeImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
@@ -1556,11 +1533,6 @@ class KeyeForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
             return None

         if pixel_values is not None:
-            pixel_values = self._validate_and_reshape_mm_tensor(
-                pixel_values, "image pixel values")
-            image_grid_thw = self._validate_and_reshape_mm_tensor(
-                image_grid_thw, "image grid_thw")
-
             return KeyeImagePixelInputs(
                 type="pixel_values",
                 pixel_values=pixel_values,
@@ -1568,11 +1540,6 @@ class KeyeForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
             )

         if image_embeds is not None:
-            image_embeds = self._validate_and_reshape_mm_tensor(
-                image_embeds, "image embeds")
-            image_grid_thw = self._validate_and_reshape_mm_tensor(
-                image_grid_thw, "image grid_thw")
-
             return KeyeImageEmbeddingInputs(
                 type="image_embeds",
                 image_embeds=image_embeds,
@@ -1589,13 +1556,6 @@ class KeyeForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
             return None

         if pixel_values_videos is not None:
-            pixel_values_videos = self._validate_and_reshape_mm_tensor(
-                pixel_values_videos,
-                "video pixel values",
-            )
-            video_grid_thw = self._validate_and_reshape_mm_tensor(
-                video_grid_thw, "video grid_thw")
-
             return KeyeVideoPixelInputs(
                 type="pixel_values_videos",
                 pixel_values_videos=pixel_values_videos,
@@ -1603,11 +1563,6 @@ class KeyeForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
             )

         if video_embeds is not None:
-            video_embeds = self._validate_and_reshape_mm_tensor(
-                video_embeds, "video embeds")
-            video_grid_thw = self._validate_and_reshape_mm_tensor(
-                video_grid_thw, "video grid_thw")
-
             return KeyeVideoEmbeddingInputs(
                 type="video_embeds",
                 video_embeds=video_embeds,
diff --git a/vllm/model_executor/models/keye_vl1_5.py b/vllm/model_executor/models/keye_vl1_5.py
index 93a3bf5f98..6e34230878 100644
--- a/vllm/model_executor/models/keye_vl1_5.py
+++ b/vllm/model_executor/models/keye_vl1_5.py
@@ -18,7 +18,7 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
+from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (ImageItem, ModalityData,
                                     MultiModalFieldConfig,
                                     MultiModalKwargsItems, VideoItem)
@@ -100,8 +100,7 @@ def get_num_patches(grid_thw: torch.Tensor,
 class KeyeVL1_5ImagePixelInputs(TensorSchema):
     """
     Dimensions:
-        - b: Batch size
-        - np: Number of patches
+        - bnp: Batch size * Number of patches
         - c: Number of channels
         - ps: Patch size
         - ni: Number of images
@@ -111,7 +110,7 @@ class KeyeVL1_5ImagePixelInputs(TensorSchema):

     pixel_values: Annotated[
         torch.Tensor,
-        TensorShape("np", 3, "ps", "ps", dynamic_dims={"np"})]
+        TensorShape("bnp", 3, "ps", "ps", dynamic_dims={"bnp"})]
     image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]


@@ -137,8 +136,7 @@ KeyeVL1_5ImageInputs = Union[KeyeVL1_5ImagePixelInputs,
 class KeyeVL1_5VideoPixelInputs(TensorSchema):
     """
     Dimensions:
-        - b: Batch size
-        - np: Number of patches
+        - bnp: Batch size * Number of patches
         - c: Number of channels
         - ps: Patch size
         - ni: Number of images
@@ -147,7 +145,7 @@ class KeyeVL1_5VideoPixelInputs(TensorSchema):
     type: Literal["pixel_values_videos"]
     pixel_values_videos: Annotated[
         torch.Tensor,
-        TensorShape("np", 3, "ps", "ps", dynamic_dims={"np"})]
+        TensorShape("bnp", 3, "ps", "ps", dynamic_dims={"bnp"})]
     video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]
     num_frames: torch.Tensor

@@ -483,24 +481,6 @@ class KeyeVL1_5ForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
         self.merge_size = config.vision_config.spatial_merge_size
         super().__init__(vllm_config=vllm_config, prefix=prefix)

-    def _validate_and_reshape_mm_tensor(self, mm_input: NestedTensors,
-                                        expected_dim: int, name: str):
-        if not isinstance(mm_input, (torch.Tensor, list)):
-            raise ValueError(f"Incorrect type of {name}. "
-                             f"Got type: {type(mm_input)}")
-        if isinstance(mm_input, torch.Tensor):
-            if mm_input.ndim == expected_dim:
-                return mm_input
-            elif mm_input.ndim == expected_dim + 1:
-                return mm_input.reshape(-1, *mm_input.shape[2:])
-            else:
-                raise ValueError(
-                    f"{name} should be {expected_dim}D or "
-                    f"batched {expected_dim}D tensor."
-                    f"Got ndim: {mm_input.ndim} (shape={mm_input.shape})")
-        else:
-            return torch.concat(mm_input)
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[KeyeVL1_5ImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
@@ -511,11 +491,6 @@ class KeyeVL1_5ForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
             return None

         if pixel_values is not None:
-            pixel_values = self._validate_and_reshape_mm_tensor(
-                pixel_values, expected_dim=4, name="image pixel values")
-            image_grid_thw = self._validate_and_reshape_mm_tensor(
-                image_grid_thw, expected_dim=2, name="image grid_thw")
-
             return KeyeVL1_5ImagePixelInputs(
                 type="pixel_values",
                 pixel_values=pixel_values,
@@ -523,11 +498,6 @@ class KeyeVL1_5ForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
             )

         if image_embeds is not None:
-            image_embeds = self._validate_and_reshape_mm_tensor(
-                image_embeds, expected_dim=2, name="image embeds")
-            image_grid_thw = self._validate_and_reshape_mm_tensor(
-                image_grid_thw, expected_dim=2, name="image grid_thw")
-
             return KeyeVL1_5ImageEmbeddingInputs(
                 type="image_embeds",
                 image_embeds=image_embeds,
@@ -545,17 +515,6 @@ class KeyeVL1_5ForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
             return None

         if pixel_values_videos is not None:
-            pixel_values_videos = self._validate_and_reshape_mm_tensor(
-                pixel_values_videos,
-                expected_dim=4,
-                name="video pixel values",
-            )
-            video_grid_thw = self._validate_and_reshape_mm_tensor(
-                video_grid_thw, expected_dim=2, name="video grid_thw")
-
-            num_frames = self._validate_and_reshape_mm_tensor(
-                num_frames, expected_dim=1, name="video num frames")
-
             return KeyeVL1_5VideoPixelInputs(
                 type="pixel_values_videos",
                 pixel_values_videos=pixel_values_videos,
@@ -563,11 +522,6 @@ class KeyeVL1_5ForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
                 num_frames=num_frames)

         if video_embeds is not None:
-            video_embeds = self._validate_and_reshape_mm_tensor(
-                video_embeds, expected_dim=2, name="video embeds")
-            video_grid_thw = self._validate_and_reshape_mm_tensor(
-                video_grid_thw, expected_dim=2, name="video grid_thw")
-
             return KeyeVL1_5VideoEmbeddingInputs(type="video_embeds",
                                                  video_embeds=video_embeds,
                                                  video_grid_thw=video_grid_thw,
diff --git a/vllm/model_executor/models/kimi_vl.py b/vllm/model_executor/models/kimi_vl.py
index a47bdd2f5a..60404376f2 100644
--- a/vllm/model_executor/models/kimi_vl.py
+++ b/vllm/model_executor/models/kimi_vl.py
@@ -283,6 +283,7 @@ class KimiVLMultiModalProcessor(BaseMultiModalProcessor[KimiVLProcessingInfo]):
                                         dummy_inputs=KimiVLDummyInputsBuilder)
 class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal,
                                      SupportsPP):
+    merge_by_field_config = True

     supports_encoder_tp_data = True

@@ -342,23 +343,6 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal,
                                                 config.vocab_size, logit_scale)
         self.media_placeholder: int = self.config.media_placeholder_token_id

-    # ref: qwen2_vl.py
-    def _validate_and_reshape_mm_tensor(self, mm_input: object,
-                                        name: str) -> torch.Tensor:
-        if not isinstance(mm_input, (torch.Tensor, list)):
-            raise ValueError(f"Incorrect type of {name}. "
-                             f"Got type: {type(mm_input)}")
-        if isinstance(mm_input, torch.Tensor):
-            if mm_input.ndim == 2:
-                return mm_input
-            if mm_input.ndim != 3:
-                raise ValueError(f"{name} should be 2D or batched 3D tensor. "
-                                 f"Got ndim: {mm_input.ndim} "
-                                 f"(shape={mm_input.shape})")
-            return mm_input.reshape(-1, mm_input.shape[-1])
-        else:
-            return torch.concat(mm_input)
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[KimiVLImageInputs]:
         # image input type must be pixel values now
@@ -368,21 +352,6 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal,
         if pixel_values is None:
             return None

-        image_grid_hws = self._validate_and_reshape_mm_tensor(
-            image_grid_hws, "image grid hws")
-        # pixel_values may have complex shapes
-        num_channels = 3
-        patch_size = self.config.vision_config.patch_size
-        if isinstance(pixel_values, list):
-            pixel_values = torch.cat([
-                x.reshape(-1, num_channels, patch_size, patch_size)
-                for x in pixel_values
-            ])
-        else:
-            pixel_values = pixel_values.reshape(-1, num_channels, patch_size,
-                                                patch_size)
-        pixel_values = pixel_values.to(self.vision_tower.dtype)
-
         return KimiVLImagePixelInputs(
             type="pixel_values",
             pixel_values=pixel_values,
diff --git a/vllm/utils/tensor_schema.py b/vllm/utils/tensor_schema.py
index 44688467b8..81daca7dfb 100644
--- a/vllm/utils/tensor_schema.py
+++ b/vllm/utils/tensor_schema.py
@@ -164,7 +164,9 @@ class TensorSchema:

         if len(actual_shape) != len(expected_shape):
             raise ValueError(f"{field_name} has rank {len(actual_shape)} "
-                             f"but expected {len(expected_shape)}")
+                             f"but expected {len(expected_shape)}. "
+                             f"Expected shape: {expected_shape}, "
+                             f"but got {actual_shape}")

         for i, dim in enumerate(expected_shape):
             if dim in dynamic_dims:
@@ -172,7 +174,9 @@
             elif isinstance(dim, int):
                 if actual_shape[i] != dim:
                     raise ValueError(f"{field_name} dim[{i}] expected "
-                                     f"{dim}, got {actual_shape[i]}")
+                                     f"{dim}, got {actual_shape[i]}. "
+                                     f"Expected shape: {expected_shape}, "
+                                     f"but got {actual_shape}")
             elif isinstance(dim, str):
                 if dim in shape_env:
                     if actual_shape[i] != shape_env[dim]: