Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 23:03:52 +08:00)
[Model] Use merge_by_field_config for MM models (H-L) (#26230)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
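
The hunks below apply one pattern across the Idefics3, Keye, Keye-VL1.5, and Kimi-VL models: expected tensor shapes are declared on TensorSchema input classes, and merge_by_field_config = True is set on the model so that multimodal kwargs arrive already merged per field. A minimal sketch of that pattern follows (illustrative class names, not part of vLLM; assumes a vLLM tree that provides vllm.utils.tensor_schema, with shape checks assumed to run when the schema object is constructed):

# Minimal sketch of the pattern this commit adopts. ExampleImagePixelInputs
# and ExampleModel are hypothetical; the manual isinstance()/flatten_bn()
# validation removed in the hunks below becomes unnecessary under this scheme.
from typing import Annotated, Literal

import torch
import torch.nn as nn

from vllm.utils.tensor_schema import TensorSchema, TensorShape


class ExampleImagePixelInputs(TensorSchema):
    """
    Dimensions:
        - bnp: Batch size * Number of patches
        - ps: Patch size
    """
    type: Literal["pixel_values"]
    pixel_values: Annotated[
        torch.Tensor,
        TensorShape("bnp", 3, "ps", "ps", dynamic_dims={"bnp"})]


class ExampleModel(nn.Module):
    # Opt in to pre-merged multimodal kwargs (the flag added in this commit).
    merge_by_field_config = True

    def _parse_and_validate_image_input(self, **kwargs: object):
        pixel_values = kwargs.pop("pixel_values", None)
        if pixel_values is None:
            return None
        # The schema performs the type/shape checks previously done by hand.
        return ExampleImagePixelInputs(type="pixel_values",
                                       pixel_values=pixel_values)
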
@@ -548,7 +548,7 @@ def load_keye_vl1_5(question: str, image_urls: list[str]) -> ModelRequestData:
     engine_args = EngineArgs(
         model=model_name,
         trust_remote_code=True,
-        max_model_len=8192,
+        max_model_len=32768,
         max_num_seqs=5,
         limit_mm_per_prompt={"image": len(image_urls)},
     )
@@ -53,7 +53,7 @@ from .idefics2_vision_model import (
 # yapf: enable
 from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
 from .llama import LlamaModel
-from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix
+from .utils import AutoWeightsLoader, maybe_prefix


 class Idefics3ImagePixelInputs(TensorSchema):
@@ -67,7 +67,7 @@ class Idefics3ImagePixelInputs(TensorSchema):
     """
     type: Literal["pixel_values"]
     pixel_values: Annotated[torch.Tensor, TensorShape("bnp", 3, "h", "w")]
-    pixel_attention_mask: torch.Tensor
+    pixel_attention_mask: Annotated[torch.Tensor, TensorShape("bnp", "h", "w")]
     num_patches: Annotated[torch.Tensor, TensorShape("bn")]

@@ -569,6 +569,8 @@ class Idefics3Model(nn.Module):
                                         dummy_inputs=Idefics3DummyInputsBuilder)
 class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
                                        SupportsLoRA):
+    merge_by_field_config = True
+
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -621,37 +623,21 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
             return None

         if image_embeds is not None:
-            if not isinstance(image_embeds, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of image embeddings. "
-                                 f"Got type: {type(image_embeds)}")
-
             return Idefics3ImageEmbeddingInputs(
                 type="image_embeds",
-                data=flatten_bn(image_embeds, concat=True),
+                data=image_embeds,
             )

         if pixel_values is not None:
-            if not isinstance(pixel_values, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of pixel values. "
-                                 f"Got type: {type(pixel_values)}")
-
             pixel_attention_mask = kwargs.pop("pixel_attention_mask")
-            if not isinstance(pixel_attention_mask, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of pixel_attention_mask. "
-                                 f"Got type: {type(pixel_attention_mask)}")
-
             num_patches = kwargs.pop("num_patches")
-            if not isinstance(num_patches, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of num_patches. "
-                                 f"Got type: {type(num_patches)}")

             expected_h = expected_w = self.config.vision_config.image_size

             return Idefics3ImagePixelInputs(
                 type="pixel_values",
-                pixel_values=flatten_bn(pixel_values, concat=True),
-                pixel_attention_mask=flatten_bn(pixel_attention_mask,
-                                                concat=True),
-                num_patches=flatten_bn(num_patches, concat=True),
+                pixel_values=pixel_values,
+                pixel_attention_mask=pixel_attention_mask,
+                num_patches=num_patches,
                 resolve_bindings={
                     "h": expected_h,
                     "w": expected_w
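
The resolve_bindings mapping passed above ties the symbolic "h"/"w" dimensions to the configured image size, so a wrongly sized input is rejected by the schema itself rather than by hand-written checks. A minimal sketch (hypothetical class and sizes; vllm.utils.tensor_schema assumed available, checks assumed to run at construction):

from typing import Annotated, Literal

import torch

from vllm.utils.tensor_schema import TensorSchema, TensorShape


class ExamplePatchInputs(TensorSchema):  # hypothetical, for illustration only
    type: Literal["pixel_values"]
    pixel_values: Annotated[torch.Tensor, TensorShape("bnp", 3, "h", "w")]


# "h" and "w" are bound to 384 here, so a (N, 3, 384, 384) tensor passes while
# a (N, 3, 336, 336) tensor is expected to raise a ValueError.
ExamplePatchInputs(
    type="pixel_values",
    pixel_values=torch.zeros(2, 3, 384, 384),
    resolve_bindings={"h": 384, "w": 384},
)
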
@@ -30,7 +30,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.models.module_mapping import MultiModelKeys
-from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
+from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (ImageItem, ModalityData,
                                     MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalKwargsItems, VideoItem)
@@ -42,7 +42,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
-from vllm.utils import is_list_of
 from vllm.utils.tensor_schema import TensorSchema, TensorShape

 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
@@ -100,8 +99,7 @@ def smart_resize(
 class KeyeImagePixelInputs(TensorSchema):
     """
     Dimensions:
-        - b: Batch size
-        - np: Number of patches
+        - bnp: Batch size * Number of patches
         - c: Number of channels
         - ps: Patch size
         - ni: Number of images
@@ -110,7 +108,7 @@ class KeyeImagePixelInputs(TensorSchema):
     type: Literal["pixel_values"]
     pixel_values: Annotated[
         torch.Tensor,
-        TensorShape("b", "np", 3, "ps", "ps", dynamic_dims={"np"})]
+        TensorShape("bnp", 3, "ps", "ps", dynamic_dims={"bnp"})]
     image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]

@@ -134,8 +132,7 @@ KeyeImageInputs = Union[KeyeImagePixelInputs, KeyeImageEmbeddingInputs]
 class KeyeVideoPixelInputs(TensorSchema):
     """
     Dimensions:
-        - b: Batch size
-        - np: Number of patches
+        - bnp: Batch size * Number of patches
         - c: Number of channels
         - ps: Patch size
         - ni: Number of images
@@ -144,7 +141,7 @@ class KeyeVideoPixelInputs(TensorSchema):
     type: Literal["pixel_values_videos"]
     pixel_values_videos: Annotated[
         torch.Tensor,
-        TensorShape("b", "np", 3, "ps", "ps", dynamic_dims={"np"})]
+        TensorShape("bnp", 3, "ps", "ps", dynamic_dims={"bnp"})]
     video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]

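
The new "bnp" axis above replaces the former ("b", "np") pair: patches from all items are stacked along one leading dimension whose length is left unconstrained via dynamic_dims. A small sketch of that convention (hypothetical class; vllm.utils.tensor_schema assumed available):

from typing import Annotated, Literal

import torch

from vllm.utils.tensor_schema import TensorSchema, TensorShape


class ExampleVideoPixelInputs(TensorSchema):  # hypothetical example
    type: Literal["pixel_values_videos"]
    pixel_values_videos: Annotated[
        torch.Tensor,
        TensorShape("bnp", 3, "ps", "ps", dynamic_dims={"bnp"})]


# Any total patch count is accepted along the first axis; the two "ps" dims
# are expected to agree with each other.
ExampleVideoPixelInputs(type="pixel_values_videos",
                        pixel_values_videos=torch.zeros(84, 3, 14, 14))
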
@@ -1258,6 +1255,8 @@ class KeyeMultiModalProcessor(BaseMultiModalProcessor[KeyeProcessingInfo]):


 class BaseKeyeModule(nn.Module):
+    merge_by_field_config = True
+
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -1524,28 +1523,6 @@ class KeyeForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
                         prefix: str = "") -> nn.Module:
         return Projector(text_config, vision_config, quant_config, prefix)

-    def _validate_and_reshape_mm_tensor(
-            self, mm_input: NestedTensors,
-            name: str) -> Union[torch.Tensor, list[torch.Tensor]]:
-        if not isinstance(mm_input, (torch.Tensor, list)):
-            raise ValueError(f"Incorrect type of {name}. "
-                             f"Got type: {type(mm_input)}")
-        if isinstance(mm_input, torch.Tensor):
-            if mm_input.ndim == 2:
-                return mm_input
-            if mm_input.ndim == 5:
-                return mm_input
-            if mm_input.ndim != 3:
-                raise ValueError(f"{name} should be 2D or batched 3D tensor. "
-                                 f"Got ndim: {mm_input.ndim} "
-                                 f"(shape={mm_input.shape})")
-            return mm_input.reshape(-1, mm_input.shape[-1])
-        elif is_list_of(mm_input, torch.Tensor):
-            if all(p.dim() == 4 for p in mm_input) or all(p.dim() == 2
-                                                          for p in mm_input):
-                return mm_input
-        return torch.concat(mm_input)
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[KeyeImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
@@ -1556,11 +1533,6 @@ class KeyeForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
             return None

         if pixel_values is not None:
-            pixel_values = self._validate_and_reshape_mm_tensor(
-                pixel_values, "image pixel values")
-            image_grid_thw = self._validate_and_reshape_mm_tensor(
-                image_grid_thw, "image grid_thw")
-
             return KeyeImagePixelInputs(
                 type="pixel_values",
                 pixel_values=pixel_values,
@@ -1568,11 +1540,6 @@ class KeyeForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
             )

         if image_embeds is not None:
-            image_embeds = self._validate_and_reshape_mm_tensor(
-                image_embeds, "image embeds")
-            image_grid_thw = self._validate_and_reshape_mm_tensor(
-                image_grid_thw, "image grid_thw")
-
             return KeyeImageEmbeddingInputs(
                 type="image_embeds",
                 image_embeds=image_embeds,
@@ -1589,13 +1556,6 @@ class KeyeForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
             return None

         if pixel_values_videos is not None:
-            pixel_values_videos = self._validate_and_reshape_mm_tensor(
-                pixel_values_videos,
-                "video pixel values",
-            )
-            video_grid_thw = self._validate_and_reshape_mm_tensor(
-                video_grid_thw, "video grid_thw")
-
             return KeyeVideoPixelInputs(
                 type="pixel_values_videos",
                 pixel_values_videos=pixel_values_videos,
@@ -1603,11 +1563,6 @@ class KeyeForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
             )

         if video_embeds is not None:
-            video_embeds = self._validate_and_reshape_mm_tensor(
-                video_embeds, "video embeds")
-            video_grid_thw = self._validate_and_reshape_mm_tensor(
-                video_grid_thw, "video grid_thw")
-
             return KeyeVideoEmbeddingInputs(
                 type="video_embeds",
                 video_embeds=video_embeds,
@@ -18,7 +18,7 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
+from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (ImageItem, ModalityData,
                                     MultiModalFieldConfig,
                                     MultiModalKwargsItems, VideoItem)
@@ -100,8 +100,7 @@ def get_num_patches(grid_thw: torch.Tensor,
 class KeyeVL1_5ImagePixelInputs(TensorSchema):
     """
     Dimensions:
-        - b: Batch size
-        - np: Number of patches
+        - bnp: Batch size * Number of patches
         - c: Number of channels
         - ps: Patch size
         - ni: Number of images
@@ -111,7 +110,7 @@ class KeyeVL1_5ImagePixelInputs(TensorSchema):

     pixel_values: Annotated[
         torch.Tensor,
-        TensorShape("np", 3, "ps", "ps", dynamic_dims={"np"})]
+        TensorShape("bnp", 3, "ps", "ps", dynamic_dims={"bnp"})]

     image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]

@@ -137,8 +136,7 @@ KeyeVL1_5ImageInputs = Union[KeyeVL1_5ImagePixelInputs,
 class KeyeVL1_5VideoPixelInputs(TensorSchema):
     """
     Dimensions:
-        - b: Batch size
-        - np: Number of patches
+        - bnp: Batch size * Number of patches
         - c: Number of channels
         - ps: Patch size
         - ni: Number of images
@@ -147,7 +145,7 @@ class KeyeVL1_5VideoPixelInputs(TensorSchema):
     type: Literal["pixel_values_videos"]
     pixel_values_videos: Annotated[
         torch.Tensor,
-        TensorShape("np", 3, "ps", "ps", dynamic_dims={"np"})]
+        TensorShape("bnp", 3, "ps", "ps", dynamic_dims={"bnp"})]
     video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]

     num_frames: torch.Tensor
@@ -483,24 +481,6 @@ class KeyeVL1_5ForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
         self.merge_size = config.vision_config.spatial_merge_size
         super().__init__(vllm_config=vllm_config, prefix=prefix)

-    def _validate_and_reshape_mm_tensor(self, mm_input: NestedTensors,
-                                        expected_dim: int, name: str):
-        if not isinstance(mm_input, (torch.Tensor, list)):
-            raise ValueError(f"Incorrect type of {name}. "
-                             f"Got type: {type(mm_input)}")
-        if isinstance(mm_input, torch.Tensor):
-            if mm_input.ndim == expected_dim:
-                return mm_input
-            elif mm_input.ndim == expected_dim + 1:
-                return mm_input.reshape(-1, *mm_input.shape[2:])
-            else:
-                raise ValueError(
-                    f"{name} should be {expected_dim}D or "
-                    f"batched {expected_dim}D tensor."
-                    f"Got ndim: {mm_input.ndim} (shape={mm_input.shape})")
-        else:
-            return torch.concat(mm_input)
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[KeyeVL1_5ImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
@@ -511,11 +491,6 @@ class KeyeVL1_5ForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
             return None

         if pixel_values is not None:
-            pixel_values = self._validate_and_reshape_mm_tensor(
-                pixel_values, expected_dim=4, name="image pixel values")
-            image_grid_thw = self._validate_and_reshape_mm_tensor(
-                image_grid_thw, expected_dim=2, name="image grid_thw")
-
             return KeyeVL1_5ImagePixelInputs(
                 type="pixel_values",
                 pixel_values=pixel_values,
@@ -523,11 +498,6 @@ class KeyeVL1_5ForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
             )

         if image_embeds is not None:
-            image_embeds = self._validate_and_reshape_mm_tensor(
-                image_embeds, expected_dim=2, name="image embeds")
-            image_grid_thw = self._validate_and_reshape_mm_tensor(
-                image_grid_thw, expected_dim=2, name="image grid_thw")
-
             return KeyeVL1_5ImageEmbeddingInputs(
                 type="image_embeds",
                 image_embeds=image_embeds,
@@ -545,17 +515,6 @@ class KeyeVL1_5ForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
             return None

         if pixel_values_videos is not None:
-            pixel_values_videos = self._validate_and_reshape_mm_tensor(
-                pixel_values_videos,
-                expected_dim=4,
-                name="video pixel values",
-            )
-            video_grid_thw = self._validate_and_reshape_mm_tensor(
-                video_grid_thw, expected_dim=2, name="video grid_thw")
-
-            num_frames = self._validate_and_reshape_mm_tensor(
-                num_frames, expected_dim=1, name="video num frames")
-
             return KeyeVL1_5VideoPixelInputs(
                 type="pixel_values_videos",
                 pixel_values_videos=pixel_values_videos,
@@ -563,11 +522,6 @@ class KeyeVL1_5ForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
                                            num_frames=num_frames)

         if video_embeds is not None:
-            video_embeds = self._validate_and_reshape_mm_tensor(
-                video_embeds, expected_dim=2, name="video embeds")
-            video_grid_thw = self._validate_and_reshape_mm_tensor(
-                video_grid_thw, expected_dim=2, name="video grid_thw")
-
             return KeyeVL1_5VideoEmbeddingInputs(type="video_embeds",
                                                  video_embeds=video_embeds,
                                                  video_grid_thw=video_grid_thw,
@@ -283,6 +283,7 @@ class KimiVLMultiModalProcessor(BaseMultiModalProcessor[KimiVLProcessingInfo]):
                                         dummy_inputs=KimiVLDummyInputsBuilder)
 class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal,
                                      SupportsPP):
+    merge_by_field_config = True

     supports_encoder_tp_data = True

@@ -342,23 +343,6 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal,
                                                 config.vocab_size, logit_scale)
         self.media_placeholder: int = self.config.media_placeholder_token_id

-    # ref: qwen2_vl.py
-    def _validate_and_reshape_mm_tensor(self, mm_input: object,
-                                        name: str) -> torch.Tensor:
-        if not isinstance(mm_input, (torch.Tensor, list)):
-            raise ValueError(f"Incorrect type of {name}. "
-                             f"Got type: {type(mm_input)}")
-        if isinstance(mm_input, torch.Tensor):
-            if mm_input.ndim == 2:
-                return mm_input
-            if mm_input.ndim != 3:
-                raise ValueError(f"{name} should be 2D or batched 3D tensor. "
-                                 f"Got ndim: {mm_input.ndim} "
-                                 f"(shape={mm_input.shape})")
-            return mm_input.reshape(-1, mm_input.shape[-1])
-        else:
-            return torch.concat(mm_input)
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[KimiVLImageInputs]:
         # image input type must be pixel values now
@@ -368,21 +352,6 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal,
         if pixel_values is None:
             return None

-        image_grid_hws = self._validate_and_reshape_mm_tensor(
-            image_grid_hws, "image grid hws")
-        # pixel_values may have complex shapes
-        num_channels = 3
-        patch_size = self.config.vision_config.patch_size
-        if isinstance(pixel_values, list):
-            pixel_values = torch.cat([
-                x.reshape(-1, num_channels, patch_size, patch_size)
-                for x in pixel_values
-            ])
-        else:
-            pixel_values = pixel_values.reshape(-1, num_channels, patch_size,
-                                                patch_size)
-        pixel_values = pixel_values.to(self.vision_tower.dtype)
-
         return KimiVLImagePixelInputs(
             type="pixel_values",
             pixel_values=pixel_values,
@@ -164,7 +164,9 @@ class TensorSchema:

                 if len(actual_shape) != len(expected_shape):
                     raise ValueError(f"{field_name} has rank {len(actual_shape)} "
-                                     f"but expected {len(expected_shape)}")
+                                     f"but expected {len(expected_shape)}. "
+                                     f"Expected shape: {expected_shape}, "
+                                     f"but got {actual_shape}")

                 for i, dim in enumerate(expected_shape):
                     if dim in dynamic_dims:
@@ -172,7 +174,9 @@ class TensorSchema:
                 elif isinstance(dim, int):
                     if actual_shape[i] != dim:
                         raise ValueError(f"{field_name} dim[{i}] expected "
-                                         f"{dim}, got {actual_shape[i]}")
+                                         f"{dim}, got {actual_shape[i]}. "
+                                         f"Expected shape: {expected_shape}, "
+                                         f"but got {actual_shape}")
                 elif isinstance(dim, str):
                     if dim in shape_env:
                         if actual_shape[i] != shape_env[dim]:
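
With the added context, a rank or dimension mismatch now reports both the expected and the actual shape. A rough sketch of how such a failure might read (hypothetical class; the exact wording follows the format strings above):

from typing import Annotated, Literal

import torch

from vllm.utils.tensor_schema import TensorSchema, TensorShape


class ExampleInputs(TensorSchema):  # hypothetical, for illustration only
    type: Literal["pixel_values"]
    pixel_values: Annotated[torch.Tensor, TensorShape("n", 3, "h", "w")]


try:
    # Rank 3 instead of the expected rank 4.
    ExampleInputs(type="pixel_values", pixel_values=torch.zeros(3, 32, 32))
except ValueError as err:
    # Expected to resemble:
    #   pixel_values has rank 3 but expected 4. Expected shape:
    #   ('n', 3, 'h', 'w'), but got (3, 32, 32)
    print(err)
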