[Model] Use merge_by_field_config for MM models (Llava family) (#26280)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-10-06 17:45:26 +08:00
committed by GitHub
parent 391612e78b
commit 19a00eb210
9 changed files with 155 additions and 229 deletions

View File

@ -371,6 +371,115 @@ def load_internvl(question: str, image_urls: list[str]) -> ModelRequestData:
)
def load_keye_vl(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "Kwai-Keye/Keye-VL-8B-Preview"
engine_args = EngineArgs(
model=model_name,
trust_remote_code=True,
max_model_len=8192,
max_num_seqs=5,
limit_mm_per_prompt={"image": len(image_urls)},
)
placeholders = [{"type": "image", "image": url} for url in image_urls]
messages = [
{
"role": "user",
"content": [
*placeholders,
{"type": "text", "text": question},
],
},
]
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
prompt = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_data = [fetch_image(url) for url in image_urls]
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
image_data=image_data,
)
def load_keye_vl1_5(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "Kwai-Keye/Keye-VL-1_5-8B"
engine_args = EngineArgs(
model=model_name,
trust_remote_code=True,
max_model_len=32768,
max_num_seqs=5,
limit_mm_per_prompt={"image": len(image_urls)},
)
placeholders = [{"type": "image", "image": url} for url in image_urls]
messages = [
{
"role": "user",
"content": [
*placeholders,
{"type": "text", "text": question},
],
},
]
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
prompt = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_data = [fetch_image(url) for url in image_urls]
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
image_data=image_data,
)
def load_kimi_vl(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "moonshotai/Kimi-VL-A3B-Instruct"
engine_args = EngineArgs(
model=model_name,
trust_remote_code=True,
max_model_len=4096,
max_num_seqs=4,
limit_mm_per_prompt={"image": len(image_urls)},
)
placeholders = [{"type": "image", "image": url} for url in image_urls]
messages = [
{
"role": "user",
"content": [
*placeholders,
{"type": "text", "text": question},
],
}
]
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
prompt = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
image_data=[fetch_image(url) for url in image_urls],
)
def load_llama4(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
@ -505,115 +614,6 @@ def load_llava_onevision(question: str, image_urls: list[str]) -> ModelRequestDa
)
def load_keye_vl(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "Kwai-Keye/Keye-VL-8B-Preview"
engine_args = EngineArgs(
model=model_name,
trust_remote_code=True,
max_model_len=8192,
max_num_seqs=5,
limit_mm_per_prompt={"image": len(image_urls)},
)
placeholders = [{"type": "image", "image": url} for url in image_urls]
messages = [
{
"role": "user",
"content": [
*placeholders,
{"type": "text", "text": question},
],
},
]
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
prompt = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_data = [fetch_image(url) for url in image_urls]
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
image_data=image_data,
)
def load_keye_vl1_5(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "Kwai-Keye/Keye-VL-1_5-8B"
engine_args = EngineArgs(
model=model_name,
trust_remote_code=True,
max_model_len=32768,
max_num_seqs=5,
limit_mm_per_prompt={"image": len(image_urls)},
)
placeholders = [{"type": "image", "image": url} for url in image_urls]
messages = [
{
"role": "user",
"content": [
*placeholders,
{"type": "text", "text": question},
],
},
]
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
prompt = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_data = [fetch_image(url) for url in image_urls]
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
image_data=image_data,
)
def load_kimi_vl(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "moonshotai/Kimi-VL-A3B-Instruct"
engine_args = EngineArgs(
model=model_name,
trust_remote_code=True,
max_model_len=4096,
max_num_seqs=4,
limit_mm_per_prompt={"image": len(image_urls)},
)
placeholders = [{"type": "image", "image": url} for url in image_urls]
messages = [
{
"role": "user",
"content": [
*placeholders,
{"type": "text", "text": question},
],
}
]
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
prompt = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
image_data=[fetch_image(url) for url in image_urls],
)
def load_mistral3(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"

View File

@ -57,7 +57,6 @@ from .siglip import SiglipVisionModel
from .utils import (
AutoWeightsLoader,
WeightsMapper,
flatten_bn,
init_vllm_registered_model,
maybe_prefix,
)
@ -507,6 +506,8 @@ def init_vision_tower_for_llava(
dummy_inputs=LlavaDummyInputsBuilder,
)
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
merge_by_field_config = True
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"],
@ -592,37 +593,26 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
return None
if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError(
f"Incorrect type of pixel values. Got type: {type(pixel_values)}"
)
if self.config.vision_config.model_type == "pixtral":
return PixtralHFImagePixelInputs(
type="pixel_values_pixtral",
pixel_values=flatten_bn(pixel_values),
pixel_values=pixel_values,
)
expected_h = expected_w = self.config.vision_config.image_size
return LlavaImagePixelInputs(
type="pixel_values",
pixel_values=flatten_bn(pixel_values, concat=True),
pixel_values=pixel_values,
resolve_bindings={"h": expected_h, "w": expected_w},
)
if image_embeds is not None:
if not isinstance(image_embeds, (torch.Tensor, list)):
raise ValueError(
"Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}"
)
if self.config.vision_config.model_type == "pixtral":
raise ValueError("Pixtral-HF does not support image_embeds.")
return LlavaImageEmbeddingInputs(
type="image_embeds",
data=flatten_bn(image_embeds, concat=True),
data=image_embeds,
)
raise AssertionError("This line should be unreachable.")

View File

@ -34,7 +34,6 @@ from .siglip import SiglipVisionModel
from .utils import (
AutoWeightsLoader,
WeightsMapper,
flatten_bn,
init_vllm_registered_model,
maybe_prefix,
)
@ -222,6 +221,8 @@ class LlavaNextMultiModalProcessor(
dummy_inputs=LlavaDummyInputsBuilder,
)
class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
merge_by_field_config = True
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
# mapping for new names in checkpoint saved after transformers v4.52
@ -302,21 +303,11 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
return None
if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError(
f"Incorrect type of pixel values. Got type: {type(pixel_values)}"
)
if not isinstance(image_sizes, (torch.Tensor, list)):
raise ValueError(
f"Incorrect type of image sizes. Got type: {type(image_sizes)}"
)
expected_h = expected_w = self.config.vision_config.image_size
return LlavaNextImagePixelInputs(
type="pixel_values",
pixel_values=flatten_bn(pixel_values),
image_sizes=flatten_bn(image_sizes, concat=True),
pixel_values=pixel_values,
image_sizes=image_sizes,
resolve_bindings={
"h": expected_h,
"w": expected_w,
@ -324,14 +315,9 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
)
if image_embeds is not None:
if not isinstance(image_embeds, torch.Tensor):
raise ValueError(
f"Incorrect type of image embeds. Got type: {type(image_embeds)}"
)
return LlavaNextImageEmbeddingInputs(
type="image_embeds",
data=flatten_bn(image_embeds),
data=image_embeds,
)
raise AssertionError("This line should be unreachable.")

View File

@ -51,14 +51,13 @@ from .vision import get_vision_encoder_info
class LlavaNextVideoPixelInputs(TensorSchema):
"""
Dimensions:
- bs: Batch size
- nv: Number of videos
- nf: Number of frames
- nc: Number of channels (3)
- bn: Batch size * number of videos
- f: Number of frames
- c: Number of channels (3)
- h: Height of each frame
- w: Width of each frame
Note that `num_frames` may be different for each batch, in which case
Note that `f` may be different for each batch, in which case
the data is passed as a list instead of a batched tensor.
Note that it only supports one video input for one batch.
@ -66,9 +65,9 @@ class LlavaNextVideoPixelInputs(TensorSchema):
type: Literal["pixel_values_videos"] = "pixel_values_videos"
data: Annotated[
pixel_values_videos: Annotated[
Union[torch.Tensor, list[torch.Tensor]],
TensorShape("bs", "nv", "nf", 3, "h", "w"),
TensorShape("bn", "f", 3, "h", "w", dynamic_dims={"f"}),
]
@ -300,6 +299,8 @@ class LlavaNextMultiModalProjector(nn.Module):
dummy_inputs=LlavaNextVideoDummyInputsBuilder,
)
class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
merge_by_field_config = True
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
# mapping for new names in checkpoint saved after transformers v4.52
@ -371,7 +372,7 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
expected_h = expected_w = self.config.vision_config.image_size
return LlavaNextVideoPixelInputs(
type="pixel_values_videos",
data=pixel_values_videos,
pixel_values_videos=pixel_values_videos,
resolve_bindings={
"h": expected_h,
"w": expected_w,
@ -396,19 +397,15 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
def _process_video_pixels(self, inputs: LlavaNextVideoPixelInputs):
assert self.vision_tower is not None
video_pixels = inputs["data"]
video_pixels = inputs["pixel_values_videos"]
if isinstance(video_pixels, torch.Tensor):
# TODO: support multiple videos per input
b, num_videos, num_frames, c, h, w = video_pixels.shape
assert num_videos == 1
stacked_pixels = video_pixels.view(b * num_videos * num_frames, c, h, w)
bn, f, c, h, w = video_pixels.shape
stacked_pixels = video_pixels.view(bn * f, c, h, w)
stacked_embeddings = self._video_pixels_to_features(
self.vision_tower, stacked_pixels
)
embeds = stacked_embeddings.view(
b, num_frames, *stacked_embeddings.shape[1:]
)
embeds = stacked_embeddings.view(bn, f, *stacked_embeddings.shape[1:])
elif is_list_of(video_pixels, torch.Tensor):
frames_per_videos = [v.shape[0] for v in video_pixels]

View File

@ -44,7 +44,6 @@ from .siglip import SiglipVisionModel
from .utils import (
AutoWeightsLoader,
WeightsMapper,
flatten_bn,
init_vllm_registered_model,
maybe_prefix,
)
@ -62,7 +61,7 @@ class LlavaOnevisionVideoPixelInputs(TensorSchema):
- h: Height
- w: Width
Note that `num_videos` may be different for each batch, and 'num_frames'
Note that `f` may be different for each batch, and 'num_frames'
may be different for each video, in which case the data is passed as a
list instead of a batched tensor.
"""
@ -480,6 +479,8 @@ class LlavaOnevisionMultiModalProjector(nn.Module):
dummy_inputs=LlavaOnevisionDummyInputsBuilder,
)
class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
merge_by_field_config = True
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
# mapping for new names in checkpoint saved after transformers v4.52
@ -539,20 +540,10 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
return None
if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError(
f"Incorrect type of pixel values. Got type: {type(pixel_values)}"
)
if not isinstance(image_sizes, (torch.Tensor, list)):
raise ValueError(
f"Incorrect type of image sizes. Got type: {type(image_sizes)}"
)
return LlavaOnevisionImagePixelInputs(
type="pixel_values",
pixel_values=flatten_bn(pixel_values),
image_sizes=flatten_bn(image_sizes, concat=True),
pixel_values=pixel_values,
image_sizes=image_sizes,
resolve_bindings={
"h": self.config.vision_config.image_size,
"w": self.config.vision_config.image_size,
@ -560,14 +551,9 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
)
if image_embeds is not None:
if not isinstance(image_embeds, torch.Tensor):
raise ValueError(
f"Incorrect type of image embeds. Got type: {type(image_embeds)}"
)
return LlavaOnevisionImageEmbeddingInputs(
type="image_embeds",
data=flatten_bn(image_embeds),
data=image_embeds,
)
raise AssertionError("This line should be unreachable.")
@ -586,15 +572,9 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
if pixel_values_videos is None:
return None
if not isinstance(pixel_values_videos, (torch.Tensor, list)):
raise ValueError(
"Incorrect type of pixel_values_videos. "
f"Got type: {type(pixel_values_videos)}"
)
return LlavaOnevisionVideoPixelInputs(
type="pixel_values_videos",
pixel_values_videos=flatten_bn(pixel_values_videos),
pixel_values_videos=pixel_values_videos,
resolve_bindings={
"h": self.config.vision_config.image_size,
"w": self.config.vision_config.image_size,

View File

@ -32,7 +32,6 @@ from .pixtral import PixtralHFVisionModel
from .siglip import SiglipVisionModel
from .utils import (
AutoWeightsLoader,
flatten_bn,
init_vllm_registered_model,
maybe_prefix,
)
@ -180,6 +179,8 @@ class MiniMaxVL01MultiModalProcessor(
dummy_inputs=MiniMaxVL01DummyInputsBuilder,
)
class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
merge_by_field_config = True
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"],
@ -338,32 +339,16 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, Support
return None
if pixel_values is not None and image_sizes is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError(
f"Incorrect type of pixel values. Got type: {type(pixel_values)}"
)
if not isinstance(image_sizes, (torch.Tensor, list)):
raise ValueError(
f"Incorrect type of image sizes. Got type: {type(image_sizes)}"
)
return MiniMaxVL01ImagePixelInputs(
type="pixel_values",
pixel_values=flatten_bn(pixel_values),
image_sizes=flatten_bn(image_sizes, concat=True),
pixel_values=pixel_values,
image_sizes=image_sizes,
)
if image_embeds is not None:
if not isinstance(image_embeds, (torch.Tensor, list)):
raise ValueError(
"Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}"
)
return MiniMaxVL01ImageEmbeddingInputs(
type="image_embeds",
data=flatten_bn(image_embeds, concat=True),
data=image_embeds,
)
raise AssertionError("This line should be unreachable.")

View File

@ -52,7 +52,6 @@ from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
from .utils import (
AutoWeightsLoader,
WeightsMapper,
flatten_bn,
init_vllm_registered_model,
maybe_prefix,
)
@ -424,6 +423,8 @@ def init_vision_tower_for_llava(
class Mistral3ForConditionalGeneration(
nn.Module, SupportsLoRA, SupportsMultiModal, SupportsPP
):
merge_by_field_config = True
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"],
@ -510,15 +511,9 @@ class Mistral3ForConditionalGeneration(
if pixel_values is None and image_embeds is None:
return None
assert pixel_values is not None
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError(
f"Incorrect type of pixel values. Got type: {type(pixel_values)}"
)
return Mistral3ImagePixelInputs(
type="pixel_values_pixtral",
pixel_values=flatten_bn(pixel_values),
pixel_values=pixel_values,
)
def _process_image_input(

View File

@ -64,7 +64,7 @@ from vllm.transformers_utils.tokenizer import (
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import flatten_bn, init_vllm_registered_model, maybe_prefix
from .utils import init_vllm_registered_model, maybe_prefix
from .vision import (
VisionEncoderInfo,
VisionFeatureSelectStrategy,
@ -365,6 +365,8 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo])
dummy_inputs=PixtralDummyInputsBuilder,
)
class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
merge_by_field_config = True
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
if modality.startswith("image"):
@ -424,7 +426,7 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
return PixtralImagePixelInputs(
type="pixel_values",
images=flatten_bn(images),
images=images,
)
def _process_image_input(

View File

@ -49,7 +49,6 @@ from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .siglip import SiglipVisionModel
from .utils import (
AutoWeightsLoader,
flatten_bn,
init_vllm_registered_model,
maybe_prefix,
)
@ -404,6 +403,8 @@ def init_vision_tower_for_tarsier(
dummy_inputs=TarsierDummyInputsBuilder,
)
class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
merge_by_field_config = True
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"],
@ -467,25 +468,15 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
return None
if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError(
f"Incorrect type of pixel values. Got type: {type(pixel_values)}"
)
return TarsierImagePixelInputs(
type="pixel_values",
pixel_values=flatten_bn(pixel_values, concat=True),
pixel_values=pixel_values,
)
if image_embeds is not None:
if not isinstance(image_embeds, (torch.Tensor, list)):
raise ValueError(
"Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}"
)
return TarsierImageEmbeddingInputs(
type="image_embeds",
data=flatten_bn(image_embeds, concat=True),
data=image_embeds,
)
raise AssertionError("This line should be unreachable.")