mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Model] Use merge_by_field_config for MM models (O-P) (#26776)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@ -56,7 +56,6 @@ from vllm.multimodal.processing import (
|
||||
)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import is_list_of
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .clip import CLIPVisionModel
|
||||
@ -70,7 +69,6 @@ from .utils import (
|
||||
AutoWeightsLoader,
|
||||
WeightsMapper,
|
||||
_merge_multimodal_embeddings,
|
||||
flatten_bn,
|
||||
init_vllm_registered_model,
|
||||
maybe_prefix,
|
||||
)
|
||||
@ -564,6 +562,8 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
|
||||
dummy_inputs=Phi3VDummyInputsBuilder,
|
||||
)
|
||||
class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant):
|
||||
merge_by_field_config = True
|
||||
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_prefix={
|
||||
"model.vision_embed_tokens.wte": "embed_tokens",
|
||||
@ -631,8 +631,8 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant)
|
||||
if pixel_values is not None:
|
||||
return Phi3VImagePixelInputs(
|
||||
type="pixel_values",
|
||||
pixel_values=flatten_bn(pixel_values),
|
||||
image_sizes=flatten_bn(image_sizes, concat=True),
|
||||
pixel_values=pixel_values,
|
||||
image_sizes=image_sizes,
|
||||
resolve_bindings={
|
||||
"h": CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size,
|
||||
"w": CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size,
|
||||
@ -642,7 +642,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant)
|
||||
if image_embeds is not None:
|
||||
return Phi3VImageEmbeddingInputs(
|
||||
type="image_embeds",
|
||||
data=flatten_bn(image_embeds),
|
||||
data=image_embeds,
|
||||
)
|
||||
|
||||
raise AssertionError("This line should be unreachable.")
|
||||
@ -652,19 +652,10 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant)
|
||||
image_input: Phi3VImageInputs,
|
||||
) -> torch.Tensor:
|
||||
if image_input["type"] == "image_embeds":
|
||||
image_data = image_input["data"]
|
||||
if is_list_of(image_data, torch.Tensor):
|
||||
# it's already a list of tensors
|
||||
return image_data
|
||||
if len(image_data.shape) == 3:
|
||||
# 3D tensor
|
||||
return list(torch.unbind(image_data, dim=0))
|
||||
raise ValueError(
|
||||
"We expect batched 2D tensors; "
|
||||
"this can be either a list of 2D tensors or a single 3D tensor."
|
||||
)
|
||||
return image_input["data"]
|
||||
|
||||
assert self.vision_embed_tokens is not None
|
||||
|
||||
image_embeds = self.vision_embed_tokens(
|
||||
image_input["pixel_values"], image_input["image_sizes"]
|
||||
)
|
||||
|
@ -64,7 +64,6 @@ from vllm.multimodal.processing import (
|
||||
)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import is_list_of
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .idefics2_vision_model import Idefics2VisionTransformer
|
||||
@ -72,7 +71,6 @@ from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
|
||||
from .utils import (
|
||||
AutoWeightsLoader,
|
||||
WeightsMapper,
|
||||
flatten_bn,
|
||||
init_vllm_registered_model,
|
||||
maybe_prefix,
|
||||
)
|
||||
@ -672,7 +670,7 @@ class Phi4MMImagePixelInputs(TensorSchema):
|
||||
|
||||
type: Literal["pixel_values"]
|
||||
|
||||
data: Annotated[
|
||||
pixel_values: Annotated[
|
||||
torch.Tensor | list[torch.Tensor],
|
||||
TensorShape(
|
||||
"bn", "p", 3, "h", "w", dynamic_dims={"p"}
|
||||
@ -721,7 +719,7 @@ class Phi4MMAudioFeatureInputs(TensorSchema):
|
||||
|
||||
type: Literal["audio_features"]
|
||||
|
||||
data: Annotated[
|
||||
audio_features: Annotated[
|
||||
torch.Tensor | list[torch.Tensor],
|
||||
TensorShape("bn", "t", 80, dynamic_dims={"t"}),
|
||||
]
|
||||
@ -1189,6 +1187,8 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
|
||||
Implements the Phi-4-multimodal-instruct model in vLLM.
|
||||
"""
|
||||
|
||||
merge_by_field_config = True
|
||||
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"qkv_proj",
|
||||
@ -1273,7 +1273,8 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
|
||||
|
||||
if audio_features is not None:
|
||||
return Phi4MMAudioFeatureInputs(
|
||||
type="audio_features", data=flatten_bn(audio_features)
|
||||
type="audio_features",
|
||||
audio_features=audio_features,
|
||||
)
|
||||
|
||||
if audio_embeds is not None:
|
||||
@ -1298,7 +1299,7 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
|
||||
if audio_input["type"] == "audio_embeds":
|
||||
return audio_input["data"]
|
||||
|
||||
audio_features = audio_input["data"]
|
||||
audio_features = audio_input["audio_features"]
|
||||
# (e.g. multiple examples) and the second dim is the multi-audio dim
|
||||
# (e.g. multiple audios in the same example)
|
||||
|
||||
@ -1315,8 +1316,8 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object
|
||||
) -> Phi4MMImagePixelInputs | None:
|
||||
image_pixel_values: NestedTensors = kwargs.get("image_pixel_values")
|
||||
if image_pixel_values is None:
|
||||
pixel_values = kwargs.get("image_pixel_values")
|
||||
if pixel_values is None:
|
||||
return None
|
||||
|
||||
image_sizes = kwargs.get("image_sizes")
|
||||
@ -1328,52 +1329,9 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
|
||||
and num_img_tokens is not None
|
||||
), "Missing image inputs"
|
||||
|
||||
if is_list_of(image_pixel_values, torch.Tensor):
|
||||
assert all(p.dim() == 5 for p in image_pixel_values), (
|
||||
"Incorrect image inputs"
|
||||
)
|
||||
# list len is batch_size.
|
||||
# each tensor has dimension: num_img_per_example, num_hd_patches,
|
||||
# channels, height, width.
|
||||
# need to pad along num_hd_patches.
|
||||
# mask size num_img_per_prompt, num_hd_patches, feat_h, heat_w.
|
||||
image_pixel_values = cat_with_pad(image_pixel_values, dim=0)
|
||||
elif isinstance(image_pixel_values, torch.Tensor):
|
||||
# dimension: batch_size, num_img_per_example, num_hd_patches,
|
||||
# channels, height, width.
|
||||
# we flatten first 2 dims to make it a single large batch for
|
||||
# SigLIP Encoder.
|
||||
assert image_pixel_values.dim() == 6, "Incorrect image inputs"
|
||||
image_pixel_values = image_pixel_values.flatten(0, 1)
|
||||
else:
|
||||
raise ValueError("Incorrect image_pixel_values inputs")
|
||||
|
||||
if isinstance(image_attention_mask, list):
|
||||
image_attention_mask = cat_with_pad(image_attention_mask, dim=0)
|
||||
elif isinstance(image_attention_mask, torch.Tensor):
|
||||
image_attention_mask = image_attention_mask.flatten(0, 1)
|
||||
else:
|
||||
raise ValueError("Incorrect image_attention_mask inputs")
|
||||
|
||||
if isinstance(image_sizes, list):
|
||||
image_sizes = torch.cat(image_sizes, dim=0)
|
||||
elif isinstance(image_sizes, torch.Tensor):
|
||||
image_sizes = image_sizes.flatten(0, 1)
|
||||
else:
|
||||
raise ValueError("Incorrect image_sizes inputs")
|
||||
|
||||
if isinstance(num_img_tokens, list):
|
||||
num_img_tokens = [
|
||||
n for num_tensor in num_img_tokens for n in num_tensor.tolist()
|
||||
]
|
||||
elif isinstance(num_img_tokens, torch.Tensor):
|
||||
num_img_tokens = num_img_tokens.flatten(0, 1).tolist()
|
||||
else:
|
||||
raise ValueError("Incorrect num_img_tokens inputs")
|
||||
|
||||
return Phi4MMImagePixelInputs(
|
||||
type="pixel_values",
|
||||
data=image_pixel_values,
|
||||
pixel_values=pixel_values,
|
||||
image_sizes=image_sizes,
|
||||
image_attention_mask=image_attention_mask,
|
||||
num_img_tokens=num_img_tokens,
|
||||
@ -1405,7 +1363,7 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
|
||||
image_embeds = image_input["image_embeds"].type(self.visual.dtype)
|
||||
else:
|
||||
dtype = next(self.image_embed.parameters()).dtype
|
||||
pixel_values = image_input["data"].to(dtype)
|
||||
pixel_values = image_input["pixel_values"].to(dtype)
|
||||
image_sizes = image_input["image_sizes"]
|
||||
image_attention_mask = image_input["image_attention_mask"]
|
||||
image_embeds = self.image_embed(
|
||||
|
@ -50,13 +50,12 @@ from vllm.multimodal.processing import (
|
||||
)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import is_list_of
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .idefics2_vision_model import Idefics2VisionTransformer
|
||||
from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
|
||||
from .phi4mm_audio import AudioEmbedding
|
||||
from .utils import AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix
|
||||
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
|
||||
|
||||
# <|endoftext10|> (see vocab.json in hf model)
|
||||
_IMAGE_PLACEHOLDER_TOKEN_ID = 200010
|
||||
@ -467,7 +466,7 @@ class Phi4MMImagePixelInputs(TensorSchema):
|
||||
|
||||
type: Literal["pixel_values"]
|
||||
|
||||
data: Annotated[
|
||||
pixel_values: Annotated[
|
||||
torch.Tensor | list[torch.Tensor],
|
||||
TensorShape(
|
||||
"bn", "p", 3, "h", "w", dynamic_dims={"p"}
|
||||
@ -499,7 +498,7 @@ class Phi4MMAudioFeatureInputs(TensorSchema):
|
||||
|
||||
type: Literal["audio_features"]
|
||||
|
||||
data: Annotated[
|
||||
audio_features: Annotated[
|
||||
torch.Tensor | list[torch.Tensor],
|
||||
TensorShape("bn", "t", 80, dynamic_dims={"t"}),
|
||||
]
|
||||
@ -986,6 +985,8 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
|
||||
Implements the Phi-4-multimodal-instruct model in vLLM.
|
||||
"""
|
||||
|
||||
merge_by_field_config = True
|
||||
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"qkv_proj",
|
||||
@ -1094,7 +1095,8 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
|
||||
|
||||
if audio_features is not None:
|
||||
return Phi4MMAudioFeatureInputs(
|
||||
type="audio_features", data=flatten_bn(audio_features)
|
||||
type="audio_features",
|
||||
audio_features=audio_features,
|
||||
)
|
||||
|
||||
if audio_embeds is not None:
|
||||
@ -1119,7 +1121,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
|
||||
if audio_input["type"] == "audio_embeds":
|
||||
return audio_input["data"]
|
||||
|
||||
audio_features = audio_input["data"]
|
||||
audio_features = audio_input["audio_features"]
|
||||
# (e.g. multiple examples) and the second dim is the multi-audio dim
|
||||
# (e.g. multiple audios in the same example)
|
||||
|
||||
@ -1136,8 +1138,8 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object
|
||||
) -> Phi4MMImagePixelInputs | None:
|
||||
input_image_embeds: NestedTensors = kwargs.get("input_image_embeds")
|
||||
if input_image_embeds is None:
|
||||
pixel_values = kwargs.get("input_image_embeds")
|
||||
if pixel_values is None:
|
||||
return None
|
||||
|
||||
image_sizes = kwargs.get("image_sizes")
|
||||
@ -1149,52 +1151,9 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
|
||||
and num_img_tokens is not None
|
||||
), "Missing image inputs"
|
||||
|
||||
if is_list_of(input_image_embeds, torch.Tensor):
|
||||
assert all(p.dim() == 5 for p in input_image_embeds), (
|
||||
"Incorrect image inputs"
|
||||
)
|
||||
# list len is batch_size.
|
||||
# each tensor has dimension: num_img_per_example, num_hd_patches,
|
||||
# channels, height, width.
|
||||
# need to pad along num_hd_patches.
|
||||
# mask size num_img_per_prompt, num_hd_patches, feat_h, heat_w.
|
||||
input_image_embeds = cat_with_pad(input_image_embeds, dim=0)
|
||||
elif isinstance(input_image_embeds, torch.Tensor):
|
||||
# dimension: batch_size, num_img_per_example, num_hd_patches,
|
||||
# channels, height, width.
|
||||
# we flatten first 2 dims to make it a single large batch for
|
||||
# SigLIP Encoder.
|
||||
assert input_image_embeds.dim() == 6, "Incorrect image inputs"
|
||||
input_image_embeds = input_image_embeds.flatten(0, 1)
|
||||
else:
|
||||
raise ValueError("Incorrect input_image_embeds inputs")
|
||||
|
||||
if isinstance(image_attention_mask, list):
|
||||
image_attention_mask = cat_with_pad(image_attention_mask, dim=0)
|
||||
elif isinstance(image_attention_mask, torch.Tensor):
|
||||
image_attention_mask = image_attention_mask.flatten(0, 1)
|
||||
else:
|
||||
raise ValueError("Incorrect image_attention_mask inputs")
|
||||
|
||||
if isinstance(image_sizes, list):
|
||||
image_sizes = torch.cat(image_sizes, dim=0)
|
||||
elif isinstance(image_sizes, torch.Tensor):
|
||||
image_sizes = image_sizes.flatten(0, 1)
|
||||
else:
|
||||
raise ValueError("Incorrect image_sizes inputs")
|
||||
|
||||
if isinstance(num_img_tokens, list):
|
||||
num_img_tokens = [
|
||||
n for num_tensor in num_img_tokens for n in num_tensor.tolist()
|
||||
]
|
||||
elif isinstance(num_img_tokens, torch.Tensor):
|
||||
num_img_tokens = num_img_tokens.flatten(0, 1).tolist()
|
||||
else:
|
||||
raise ValueError("Incorrect num_img_tokens inputs")
|
||||
|
||||
return Phi4MMImagePixelInputs(
|
||||
type="pixel_values",
|
||||
data=input_image_embeds,
|
||||
pixel_values=pixel_values,
|
||||
image_sizes=image_sizes,
|
||||
image_attention_mask=image_attention_mask,
|
||||
num_img_tokens=num_img_tokens,
|
||||
@ -1223,7 +1182,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
|
||||
self, image_input: Phi4MMImagePixelInputs
|
||||
) -> list[torch.Tensor]:
|
||||
dtype = next(self.vision_encoder.parameters()).dtype
|
||||
pixel_values = image_input["data"].to(dtype)
|
||||
pixel_values = image_input["pixel_values"].to(dtype)
|
||||
image_sizes = image_input["image_sizes"]
|
||||
image_attention_mask = image_input["image_attention_mask"]
|
||||
image_embeds = self.vision_encoder(
|
||||
|
Reference in New Issue
Block a user