[Model] Adding support for MiniCPM-V (#4087)

Alphi
2024-07-25 11:59:30 +08:00
committed by GitHub
parent 5689e256ba
commit 9e169a4c61
11 changed files with 942 additions and 18 deletions

View File

@@ -40,6 +40,8 @@ Registry
Base Classes
------------
.. autodata:: vllm.multimodal.NestedTensors
.. autodata:: vllm.multimodal.BatchedTensors
.. autoclass:: vllm.multimodal.MultiModalDataBuiltins

View File

@@ -206,6 +206,10 @@ Vision Language Models
- Phi-3-Vision
- :code:`microsoft/Phi-3-vision-128k-instruct`, etc.
-
* - :code:`MiniCPMV`
- MiniCPM-V
- :code:`openbmb/MiniCPM-V-2`, :code:`openbmb/MiniCPM-Llama3-V-2_5`, etc.
-
If your model uses one of the above model architectures, you can seamlessly run your model with vLLM.
Otherwise, please refer to :ref:`Adding a New Model <adding_a_new_model>` and :ref:`Enabling Multimodal Inputs <enabling_multimodal_inputs>`

View File

@@ -0,0 +1,53 @@
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset
# 2.0
# MODEL_NAME = "HwwwH/MiniCPM-V-2"
# 2.5
MODEL_NAME = "openbmb/MiniCPM-Llama3-V-2_5"
image = ImageAsset("stop_sign").pil_image.convert("RGB")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
llm = LLM(model=MODEL_NAME,
gpu_memory_utilization=1,
trust_remote_code=True,
max_model_len=4096)
messages = [{
'role':
'user',
'content':
'(<image>./</image>)\n' + "What's the content of the image?"
}]
prompt = tokenizer.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)
# 2.0
# stop_token_ids = [tokenizer.eos_id]
# 2.5
stop_token_ids = [tokenizer.eos_id, tokenizer.eot_id]
sampling_params = SamplingParams(
stop_token_ids=stop_token_ids,
# temperature=0.7,
# top_p=0.8,
# top_k=100,
# seed=3472,
max_tokens=1024,
# min_tokens=150,
temperature=0,
use_beam_search=True,
# length_penalty=1.2,
best_of=3)
outputs = llm.generate({
"prompt": prompt,
"multi_modal_data": {
"image": image
}
},
sampling_params=sampling_params)
print(outputs[0].outputs[0].text)
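For reference, llm.generate also accepts a list of request dicts, so the example above extends naturally to batched multimodal requests. A minimal sketch, reusing the prompt, image and sampling_params objects defined above and duplicating the same image purely for illustration:

batch_outputs = llm.generate(
    [
        {
            "prompt": prompt,
            "multi_modal_data": {"image": image},
        },
        {
            "prompt": prompt,
            "multi_modal_data": {"image": image},
        },
    ],
    sampling_params=sampling_params)
for output in batch_outputs:
    print(output.outputs[0].text)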

View File

@@ -11,7 +11,7 @@ import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from transformers import (AutoModelForCausalLM, AutoModelForVision2Seq,
AutoTokenizer, BatchEncoding)
AutoTokenizer, BatchEncoding, BatchFeature)
from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset
@@ -133,7 +133,7 @@ def image_assets() -> _ImageAssets:
return IMAGE_ASSETS
_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding)
_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature)
class HfRunner:
@@ -339,7 +339,6 @@ class HfRunner:
processor_kwargs["images"] = images[i]
inputs = self.processor(**processor_kwargs)
input_ids = inputs.input_ids
output = self.model.generate(
**self.wrap_device(inputs),
@@ -381,7 +380,7 @@ class HfRunner:
all_logprobs.append(seq_logprobs_lst)
seq_ids = output.sequences[0]
output_len = seq_ids.shape[0] - input_ids.shape[1]
output_len = len(seq_logprobs_lst)
output_ids = seq_ids[-output_len:]
all_output_ids.append(output_ids.tolist())
all_output_strs.append(self.tokenizer.decode(output_ids))
@@ -514,10 +513,12 @@ class VllmRunner:
max_tokens: int,
num_logprobs: int,
images: Optional[List[Image.Image]] = None,
stop_token_ids: Optional[List[int]] = None,
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
greedy_logprobs_params = SamplingParams(temperature=0.0,
max_tokens=max_tokens,
logprobs=num_logprobs)
logprobs=num_logprobs,
stop_token_ids=stop_token_ids)
outputs = self.generate_w_logprobs(prompts,
greedy_logprobs_params,
images=images)

View File

@@ -0,0 +1,163 @@
from collections import UserDict
from typing import List, Optional, Tuple, Type
import pytest
import torch
import torch.types
from transformers import BatchFeature
from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
from .utils import check_logprobs_close
pytestmark = pytest.mark.vlm
# The image token is placed before "user" on purpose so that the test can pass
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"stop_sign":
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" \
"(<image>./</image>)\nWhat's the content of the image?<|eot_id|>" \
"<|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501
"cherry_blossom":
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" \
"(<image>./</image>)\nWhat is the season?<|eot_id|>" \
"<|start_header_id|>assistant<|end_header_id|>\n\n"
})
models = ["openbmb/MiniCPM-Llama3-V-2_5"]
def trunc_hf_output(hf_output: Tuple[List[int], str,
Optional[SampleLogprobs]]):
output_ids, output_str, out_logprobs = hf_output
if output_str.endswith("<|eot_id|>"):
output_str = output_str.split("<|eot_id|>")[0]
return output_ids, output_str, out_logprobs
target_dtype = "half"
def run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
image_assets: _ImageAssets,
model: str,
*,
size_factors: List[float],
dtype: str,
max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
"""Inference result should be the same between hf and vllm.
All the image fixtures for the test is under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding vision language config as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
images = [asset.pil_image for asset in image_assets]
inputs_per_image = [(
[prompt for _ in size_factors],
[rescale_image_size(image, factor) for factor in size_factors],
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
# NOTE: take care of the order: run vLLM first, and then run HF.
# vLLM needs a fresh process without prior CUDA initialization.
# If we run HF first, CUDA will already be initialized, which breaks
# the multiprocessing backend with the fork method (the default method).
# max_model_len should be greater than image_feature_size
with vllm_runner(model,
max_model_len=4096,
max_num_seqs=1,
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True) as vllm_model:
tokenizer = vllm_model.model.get_tokenizer()
stop_token_ids = [tokenizer.eos_id, tokenizer.eot_id]
vllm_outputs_per_image = [
vllm_model.generate_greedy_logprobs(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=vllm_images,
stop_token_ids=stop_token_ids)
for prompts, vllm_images in inputs_per_image
]
with hf_runner(model, dtype=dtype) as hf_model, torch.no_grad():
class NestedInputs(UserDict):
def __init__(self, model_inputs: BatchFeature):
super().__init__({"model_inputs": model_inputs})
self.model_inputs = model_inputs
def to(self, device: torch.types.Device):
return NestedInputs(self.model_inputs.to(device))
hf_processor = hf_model.processor
hf_model.processor = lambda **kw: NestedInputs(
hf_processor(**kw) # type: ignore
)
hf_outputs_per_image = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=hf_images,
tokenizer=tokenizer)
for prompts, hf_images in inputs_per_image
]
for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
vllm_outputs_per_image):
check_logprobs_close(
outputs_0_lst=[
trunc_hf_output(hf_output) for hf_output in hf_outputs
],
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
"size_factors",
[
# No image
[],
# Single-scale
[1.0],
# Single-scale, batched
[1.0, 1.0, 1.0],
# Multi-scale
[0.25, 0.5, 1.0],
],
)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
dtype: str, max_tokens: int, num_logprobs: int) -> None:
run_test(
hf_runner,
vllm_runner,
image_assets,
model,
size_factors=size_factors,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)
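The hard-coded Llama-3 chat markers in HF_IMAGE_PROMPTS can equivalently be produced through the model tokenizer's chat template, which is how the example script earlier in this commit builds its prompt. A rough sketch, assuming the stock MiniCPM-Llama3-V-2_5 chat template:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("openbmb/MiniCPM-Llama3-V-2_5",
                                          trust_remote_code=True)
messages = [{
    "role": "user",
    "content": "(<image>./</image>)\nWhat's the content of the image?"
}]
# Expected to yield the same "<|begin_of_text|><|start_header_id|>user..."
# string that HF_IMAGE_PROMPTS spells out by hand.
print(tokenizer.apply_chat_template(messages,
                                    tokenize=False,
                                    add_generation_prompt=True))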

View File

@@ -50,6 +50,7 @@ _GENERATION_MODELS = {
"MptForCausalLM": ("mpt", "MPTForCausalLM"),
"MPTForCausalLM": ("mpt", "MPTForCausalLM"),
"MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
"MiniCPMV": ("minicpmv", "MiniCPMV"),
"OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
"OPTForCausalLM": ("opt", "OPTForCausalLM"),
"OrionForCausalLM": ("orion", "OrionForCausalLM"),

View File

@@ -418,9 +418,11 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
input_embeds: Optional[torch.Tensor] = None
) -> Union[torch.Tensor, IntermediateTensors]:
model_output = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
attn_metadata, intermediate_tensors,
input_embeds)
return model_output
def compute_logits(self, hidden_states: torch.Tensor,

View File

@@ -463,10 +463,11 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
input_embeds: Optional[torch.Tensor] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata)
attn_metadata, input_embeds)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,

View File

@@ -0,0 +1,682 @@
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only MiniCPM-V-2 model compatible with HuggingFace weights."""
import math
import re
from functools import partial
from typing import Iterable, List, Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn
from torch.nn.init import trunc_normal_
from transformers.configuration_utils import PretrainedConfig
from transformers.models.idefics2.modeling_idefics2 import (
Idefics2VisionTransformer)
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsVision
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.models.minicpm import MiniCPMForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import (cached_get_image_processor,
cached_get_tokenizer)
from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData
_KEYS_TO_MODIFY_MAPPING = {
"language_model.lm_head": "lm_head",
"language_model.model": "language_model",
}
def get_abs_pos(abs_pos, tgt_size):
# abs_pos: L, C
# tgt_size: (H, W)
# return: M, C
src_size = int(math.sqrt(abs_pos.size(0)))
# tgt_size = int(math.sqrt(tgt_size))
dtype = abs_pos.dtype
return F.interpolate(
abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
size=(tgt_size[0], tgt_size[1]),
mode="bicubic",
align_corners=False,
).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype)
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
def get_2d_sincos_pos_embed(embed_dim,
grid_size,
cls_token=False,
version=2.0):
"""
grid_size: int of the grid height and width
return:
pos_embed: [grid_size*grid_size, embed_dim] or
[1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
if isinstance(grid_size, int):
grid_h_size, grid_w_size = grid_size, grid_size
else:
grid_h_size, grid_w_size = grid_size[0], grid_size[1]
grid_h = np.arange(grid_h_size, dtype=np.float32)
grid_w = np.arange(grid_w_size, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
if version == 2.0:
grid = grid.reshape([2, 1, grid_h_size, grid_w_size])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version)
if cls_token:
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed],
axis=0)
else:
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version=2.0):
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(
embed_dim // 2, grid[0], version) # (H*W, D/2) or (H, W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(
embed_dim // 2, grid[1], version) # (H*W, D/2) or (H, W, D/2)
if version == 2.0:
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
else:
emb = np.concatenate([emb_h, emb_w], axis=-1) # (H, W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, version=2.0):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,) / (H, W)
out: (M, D) / (H, W, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float32)
omega /= embed_dim / 2.
omega = 1. / 10000**omega # (D/2,)
if version == 2.0:
pos = pos.reshape(-1) # (M,)
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
else:
out = np.einsum('hw,d->hwd', pos, omega) # (H, W, D/2), outer product
emb_sin = np.sin(out) # (H, W, D/2)
emb_cos = np.cos(out) # (H, W, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=-1) # (H, W, D)
return emb
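# Illustrative shape check for the helpers above (values assumed for
# demonstration): with embed_dim=64,
#   get_2d_sincos_pos_embed(64, 4, version=2.0)      -> ndarray of shape (16, 64)
#   get_2d_sincos_pos_embed(64, (3, 5), version=2.5) -> ndarray of shape (3, 5, 64)
# i.e. version 2.0 flattens the grid, while later versions keep the (H, W)
# layout so the Resampler can crop it per image via tgt_sizes.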
class Resampler(nn.Module):
"""
A 2D perceiver-resampler network with a single cross-attention layer over
(grid_size**2) learnable queries and 2D sincos position embeddings.
Outputs:
A tensor with the shape of (grid_size**2, embed_dim)
"""
default_norm_layer = partial(nn.LayerNorm, eps=1e-6)
def __init__(self,
num_queries,
grid_size,
embed_dim,
num_heads,
kv_dim=None,
norm_layer=default_norm_layer,
adaptive=False,
max_size=(70, 70),
version=2.0):
super().__init__()
self.version = version
if self.version == 2.0:
self.num_queries = grid_size**2
else:
self.num_queries = num_queries
self.max_size = max_size
self.embed_dim = embed_dim
self.num_heads = num_heads
self.adaptive = adaptive
self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
trunc_normal_(self.query, std=.02)
if kv_dim is not None and kv_dim != embed_dim:
self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False)
else:
self.kv_proj = nn.Identity()
self.attn = nn.MultiheadAttention(embed_dim, num_heads)
self.ln_q = norm_layer(embed_dim)
self.ln_kv = norm_layer(embed_dim)
self.ln_post = norm_layer(embed_dim)
self.proj = nn.Parameter(
(embed_dim**-0.5) * torch.randn(embed_dim, embed_dim))
if self.version == 2.0:
self.pos_embed = nn.Parameter(
torch.from_numpy(
get_2d_sincos_pos_embed(
embed_dim, grid_size,
version=self.version)).float()).requires_grad_(False)
else:
self._set_2d_pos_cache(self.max_size)
self.apply(self._init_weights)
def _set_2d_pos_cache(self, max_size, device='cpu'):
pos_embed = torch.from_numpy(
get_2d_sincos_pos_embed(self.embed_dim,
max_size,
version=self.version)).float().to(device)
self.register_buffer("pos_embed", pos_embed, persistent=False)
def _adjust_pos_cache(self, tgt_sizes, device):
max_h = torch.max(tgt_sizes[:, 0])
max_w = torch.max(tgt_sizes[:, 1])
if max_h > self.max_size[0] or max_w > self.max_size[1]:
self.max_size = [
max(max_h, self.max_size[0]),
max(max_w, self.max_size[1])
]
self._set_2d_pos_cache(self.max_size, device)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward_2_5(self, x, tgt_sizes=None):
assert x.shape[0] == tgt_sizes.shape[0]
bs = x.shape[0]
device = x.device
dtype = x.dtype
patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1]
self._adjust_pos_cache(tgt_sizes, device=device)
max_patch_len = torch.max(patch_len)
key_padding_mask = torch.zeros((bs, max_patch_len),
dtype=torch.bool,
device=device)
pos_embed = []
for i in range(bs):
tgt_h, tgt_w = tgt_sizes[i]
pos_embed.append(self.pos_embed[:tgt_h, :tgt_w, :].reshape(
(tgt_h * tgt_w, -1)).to(dtype)) # patches * D
key_padding_mask[i, patch_len[i]:] = True
pos_embed = torch.nn.utils.rnn.pad_sequence(pos_embed,
batch_first=True,
padding_value=0.0).permute(
1, 0,
2) # BLD => L * B * D
x = self.kv_proj(x) # B * L * D
x = self.ln_kv(x).permute(1, 0, 2) # L * B * D
q = self.ln_q(self.query) # Q * D
out = self.attn(
self._repeat(q, bs), # Q * B * D
x + pos_embed, # L * B * D + L * B * D
x,
key_padding_mask=key_padding_mask)[0]
# out: Q * B * D
x = out.permute(1, 0, 2) # B * Q * D
x = self.ln_post(x)
x = x @ self.proj
return x
def forward_2(self, x, tgt_sizes=None, attn_mask=None):
if self.adaptive:
pos_embed = torch.Tensor(
get_2d_sincos_pos_embed(self.embed_dim,
tgt_sizes)).float().to(device=x.device,
dtype=x.dtype)
else:
pos_embed = get_abs_pos(self.pos_embed, tgt_sizes)
x = self.kv_proj(x)
x = self.ln_kv(x).permute(1, 0, 2)
N = x.shape[1]
q = self.ln_q(self.query)
out = self.attn(self._repeat(q, N) + self.pos_embed.unsqueeze(1),
x + pos_embed.unsqueeze(1),
x,
attn_mask=attn_mask)[0]
x = out.permute(1, 0, 2)
x = self.ln_post(x)
x = x @ self.proj
return x
def forward(self, x, tgt_sizes=None, attn_mask=None):
if self.version == 2.0:
return self.forward_2(x, tgt_sizes=tgt_sizes, attn_mask=attn_mask)
else:
return self.forward_2_5(x, tgt_sizes=tgt_sizes)
def _repeat(self, query, N: int):
return query.unsqueeze(1).repeat(1, N, 1)
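# Rough usage sketch (sizes are assumptions for illustration): with version
# 2.0, query_num=64 (grid_size=8) and embed_dim matching the LLM hidden size,
# the resampler maps vision features of shape (B, L, kv_dim) to a fixed
# (B, 64, embed_dim); for version 2.5 the query count comes from num_queries
# and per-image padding is handled through tgt_sizes / key_padding_mask.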
def get_max_minicpmv_image_tokens(ctx: InputContext):
hf_config = ctx.get_hf_config(PretrainedConfig)
return getattr(hf_config, "query_num", 64)
def dummy_seq_data_for_minicpmv(seq_len: int):
token_ids = [0] * seq_len
return SequenceData(token_ids)
def dummy_image_for_minicpmv(hf_config):
width = height = hf_config.image_size
image = Image.new("RGB", (width, height), color=0)
return {"image": image}
def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int):
hf_config = ctx.get_hf_config(PretrainedConfig)
# image_feature_size = get_max_minicpmv_image_tokens(ctx)
seq_data = dummy_seq_data_for_minicpmv(seq_len)
mm_data = dummy_image_for_minicpmv(hf_config)
return seq_data, mm_data
def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
multi_modal_data = llm_inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs
model_config = ctx.model_config
tokenizer = cached_get_tokenizer(model_config.tokenizer,
trust_remote_code=True)
prompt = llm_inputs.get("prompt")
if prompt is None:
token_ids = llm_inputs.get("prompt_token_ids")
prompt = tokenizer.decode(token_ids)
image_processor = cached_get_image_processor(model_config.tokenizer)
pattern = "(<image>./</image>)"
image = multi_modal_data["image"]
image_tags = re.findall(pattern, prompt)
assert len(image_tags) <= 1
text_chunks = prompt.split(pattern)
new_prompt = text_chunks[0] \
+ image_processor.get_slice_image_placeholder(image.size) \
+ text_chunks[1]
new_token_ids = tokenizer.encode(new_prompt)
llm_inputs = LLMInputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data)
return llm_inputs
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_minicpmv_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_minicpmv)
@INPUT_REGISTRY.register_input_processor(input_processor_for_minicpmv)
class MiniCPMV(nn.Module, SupportsVision):
def __init__(
self,
config,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.multimodal_config = multimodal_config
self.version = float(self.config.version)
self.llm = self.init_llm(config, cache_config, quant_config)
self.vpm = self.init_vision_module()
param_dtype = torch.get_default_dtype()
self.vpm.to(dtype=param_dtype)
self.vision_dim = self.vpm.embed_dim if self.version == 2.0 \
else self.vpm.embeddings.embed_dim
self.embed_dim = self.llm.config.hidden_size
self.resampler = self.init_resampler(self.embed_dim, self.vision_dim)
self.resampler.to(device="cuda", dtype=param_dtype)
self.sampler = Sampler()
def init_llm(self, config, cache_config, quant_config):
if self.version == 2.0:
return MiniCPMForCausalLM(config,
cache_config=cache_config,
quant_config=quant_config)
else:
return LlamaForCausalLM(config,
cache_config=cache_config,
quant_config=quant_config)
def init_vision_module(self):
if self.version == 2.0:
try:
import timm
except ImportError:
raise ImportError(
'Please install timm==0.9.10') from ImportError
default_dtype = torch.get_default_dtype()
torch.set_default_dtype(torch.float16)
model = timm.create_model('vit_so400m_patch14_siglip_384.webli',
pretrained=False,
num_classes=0,
dynamic_img_size=True,
dynamic_img_pad=True)
torch.set_default_dtype(default_dtype)
if isinstance(model, timm.models.VisionTransformer
) and model.attn_pool is not None:
model.attn_pool = torch.nn.Identity()
if self.config.drop_vision_last_layer:
model.blocks = model.blocks[:-1]
else:
model = Idefics2VisionTransformer(self.config.vision_config)
if self.config.drop_vision_last_layer:
model.encoder.layers = model.encoder.layers[:-1]
return model
def init_resampler(self, embed_dim, vision_dim):
default_dtype = torch.get_default_dtype()
torch.set_default_dtype(torch.float16)
if self.version == 2.0:
resampler = Resampler(grid_size=int(
math.sqrt(self.config.query_num)),
num_queries=None,
embed_dim=embed_dim,
num_heads=embed_dim // 128,
kv_dim=vision_dim,
adaptive=True,
version=self.version)
else:
resampler = Resampler(num_queries=self.config.query_num,
grid_size=None,
embed_dim=embed_dim,
num_heads=embed_dim // 128,
kv_dim=vision_dim,
adaptive=True,
version=self.version)
torch.set_default_dtype(default_dtype)
return resampler
def get_vision_embedding(self,
pixel_values,
patch_attn_mask=None,
tgt_sizes=None,
version=2.0):
if version == 2.0:
res = []
dtype = self.vpm.pos_embed.data.dtype
for pixel_value in pixel_values:
# V2.0 start
H, W = pixel_value[0].shape[-2:]
tgt_size = (math.ceil(H / self.vpm.patch_embed.patch_size[0]),
math.ceil(W / self.vpm.patch_embed.patch_size[0]))
# V2.0 end
vision_embedding = self.vpm.forward_features(
pixel_value.unsqueeze(0).type(dtype))
if hasattr(self.vpm, 'num_prefix_tokens'
) and self.vpm.num_prefix_tokens > 0:
vision_embedding = vision_embedding[:, self.vpm.
num_prefix_tokens:]
res.append(self.resampler(vision_embedding, tgt_size))
return torch.vstack(res)
else:
vision_embedding = self.vpm(
pixel_values.type(dtype),
patch_attention_mask=patch_attn_mask).last_hidden_state
vision_embedding = self.resampler(vision_embedding, tgt_sizes)
def get_image_bounds(self, input_ids):
tokenizer = cached_get_tokenizer(self.config._name_or_path,
trust_remote_code=True)
im_start_token_id = tokenizer.im_start_id
im_end_token_id = tokenizer.im_end_id
image_start_tokens = torch.where(input_ids == im_start_token_id)[0]
image_start_tokens += 1
image_end_tokens = torch.where(input_ids == im_end_token_id)[0]
valid_image_nums = min(len(image_start_tokens), len(image_end_tokens))
if valid_image_nums == 0:
return []
image_bound = torch.hstack([
image_start_tokens[:valid_image_nums].unsqueeze(-1),
image_end_tokens[:valid_image_nums].unsqueeze(-1),
])
return image_bound
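# Worked example (token ids are hypothetical): with im_start_id=5,
# im_end_id=6 and input_ids = [1, 5, 90, 91, 92, 6, 2], the start index 1 is
# shifted to 2 and the end index is 5, so the returned bound is [[2, 5]],
# i.e. the span of placeholder tokens that get_embedding later overwrites
# with the resampled vision features.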
def get_vision_hidden_states(self, data):
if "vision_hidden_states" not in data:
pixel_values = data["pixel_values"]
tgt_sizes = data["tgt_sizes"]
vision_hidden_states = []
if self.version == 2.0:
if pixel_values is not None and len(pixel_values) > 0:
vision_hidden_states = self.get_vision_embedding(
pixel_values)
else:
vision_hidden_states = torch.tensor([]).to(
data["input_ids"].device)
else:
device = self.vpm.embeddings.position_embedding.weight.device
dtype = self.vpm.embeddings.position_embedding.weight.dtype
all_pixel_values = [
i.flatten(end_dim=1).permute(1, 0) for i in pixel_values
]
if all_pixel_values:
tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32)
max_patches = torch.max(tgt_sizes[:, 0] * tgt_sizes[:, 1])
all_pixel_values = torch.nn.utils.rnn.pad_sequence(
all_pixel_values, batch_first=True, padding_value=0.0)
B, L, _ = all_pixel_values.shape
all_pixel_values = all_pixel_values.permute(
0, 2, 1).reshape(B, 3, -1, L)
patch_attn_mask = torch.zeros((B, 1, max_patches),
dtype=torch.bool,
device=device)
for i in range(B):
patch_attn_mask[i, :tgt_sizes[i][0] *
tgt_sizes[i][1]] = True
vision_embedding = self.vpm(
all_pixel_values.type(dtype),
patch_attention_mask=patch_attn_mask).last_hidden_state
vision_hidden_states = self.resampler(
vision_embedding, tgt_sizes)
else: # no image
dummy_feature = []
vision_hidden_states = dummy_feature
else:
vision_hidden_states = data["vision_hidden_states"]
return vision_hidden_states
def get_embedding(self, data):
input_ids = data["input_ids"]
vision_hidden_states = self.get_vision_hidden_states(data)
if vision_hidden_states is not None and len(vision_hidden_states) > 0:
image_bounds = self.get_image_bounds(input_ids)
else:
image_bounds = []
if hasattr(self.llm.config, 'scale_emb'):
vlm_embedding = self.llm.model.embed_tokens(
input_ids) * self.llm.config.scale_emb
else:
vlm_embedding = self.llm.model.embed_tokens(input_ids)
vision_hidden_states = [
i.type(vlm_embedding.dtype) if isinstance(i, torch.Tensor) else i
for i in vision_hidden_states
]
if len(vision_hidden_states) > 0 and len(image_bounds) > 0:
vision_hidden_states = torch.cat(vision_hidden_states, dim=0)
image_indices = torch.stack([
torch.arange(r[0], r[1], dtype=torch.long)
for r in image_bounds
]).to(vlm_embedding.device)
vlm_embedding.scatter_(
0,
image_indices.view(-1, 1).repeat(1, vlm_embedding.shape[-1]),
vision_hidden_states.view(-1, vision_hidden_states.shape[-1]))
return vlm_embedding, vision_hidden_states
def process_multimodal_inputs(self, inputs):
pixel_values = []
tgt_sizes = []
for b in range(len(inputs["pixel_values"])):
pixel_values += inputs["pixel_values"][b]
tgt_sizes += inputs["tgt_sizes"][b]
return {
"pixel_values": pixel_values,
"input_ids": inputs["input_ids"],
"tgt_sizes": tgt_sizes
}
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object,
):
inputs = {
"pixel_values": kwargs.pop("pixel_values", []),
"input_ids": input_ids,
"tgt_sizes": kwargs.pop("tgt_sizes", None),
}
inputs = self.process_multimodal_inputs(inputs)
vlm_embeddings, vision_hidden_states = self.get_embedding(inputs)
output = self.llm(input_ids=None,
positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors,
input_embeds=vlm_embeddings)
return output
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
return self.llm.compute_logits(hidden_states, sampling_metadata)
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.llm.sample(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
# for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
# if key_to_modify in name:
# name = name.replace(key_to_modify, new_key)
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
use_default_weight_loading = False
if "vpm" in name or 'resampler' in name:
# We only do sharding for language model and
# not vision model for now.
use_default_weight_loading = True
else:
for (param_name, weight_name,
shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
param = params_dict[name.replace(weight_name, param_name)]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
use_default_weight_loading = True
if use_default_weight_loading:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
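In load_weights, stacked_params_mapping folds the separate q/k/v and gate/up checkpoint tensors into vLLM's fused parameters, while vision ("vpm") and resampler weights are loaded unsharded. A hedged illustration of the implied renaming (the weight names below are representative, not taken from a real checkpoint):

# checkpoint tensor name                          -> target param, shard id
# "llm.model.layers.0.self_attn.q_proj.weight"    -> "...qkv_proj.weight", "q"
# "llm.model.layers.0.self_attn.k_proj.weight"    -> "...qkv_proj.weight", "k"
# "llm.model.layers.0.mlp.gate_proj.weight"       -> "...gate_up_proj.weight", 0
# "vpm.encoder.layers.0.self_attn.q_proj.weight"  -> loaded as-is via
#                                                    default_weight_loader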

View File

@@ -1,5 +1,5 @@
from .base import (BatchedTensors, MultiModalDataBuiltins, MultiModalDataDict,
MultiModalInputs, MultiModalPlugin)
MultiModalInputs, MultiModalPlugin, NestedTensors)
from .registry import MultiModalRegistry
MULTIMODAL_REGISTRY = MultiModalRegistry()
@@ -17,6 +17,7 @@ __all__ = [
"MultiModalDataDict",
"MultiModalInputs",
"MultiModalPlugin",
"NestedTensors",
"MULTIMODAL_REGISTRY",
"MultiModalRegistry",
]

View File

@@ -2,7 +2,7 @@ import sys
from abc import ABC, abstractmethod
from collections import UserDict, defaultdict
from typing import (Any, Callable, Dict, List, Optional, Type, TypedDict,
TypeVar, Union)
TypeVar, Union, cast)
import torch
import torch.types
@@ -15,10 +15,17 @@ from vllm.logger import init_logger
logger = init_logger(__name__)
BatchedTensors = Union[torch.Tensor, List[torch.Tensor]]
NestedTensors = Union[List[torch.Tensor], torch.Tensor]
"""
Use a list instead of a tensor if the dimensions of each element do not match.
Currently only supports up to a singly nested list of tensors.
"""
BatchedTensors = Union[List[NestedTensors], NestedTensors]
"""
If each input tensor in the batch has the same size, this is a single batched
tensor; otherwise, this is a list of tensors with one element per batch.
tensor; otherwise, this is a list of :class:`NestedTensors` with one element
per item in the batch.
"""
if sys.version_info < (3, 9):
@@ -27,7 +34,7 @@ if sys.version_info < (3, 9):
pass
else:
class _MultiModalInputsBase(UserDict[str, torch.Tensor]):
class _MultiModalInputsBase(UserDict[str, NestedTensors]):
pass
@@ -39,19 +46,26 @@ class MultiModalInputs(_MultiModalInputsBase):
@staticmethod
def try_concat(
tensors: List[torch.Tensor],
tensors: List[NestedTensors],
*,
device: torch.types.Device,
) -> BatchedTensors:
unbatched_shape = tensors[0].shape[1:]
# may be list rather than tensors
if isinstance(tensors[0], list):
return [[t.to(device=device) for t in tensor[0]]
for tensor in tensors]
for tensor in tensors:
tensors_ = cast(List[torch.Tensor], tensors)
unbatched_shape = tensors_[0].shape[1:]
for tensor in tensors_:
if tensor.shape[1:] != unbatched_shape:
return [
tensor.squeeze(0).to(device=device) for tensor in tensors
tensor.squeeze(0).to(device=device) for tensor in tensors_
]
return torch.cat(tensors, dim=0).to(device=device)
return torch.cat(tensors_, dim=0).to(device=device)
@staticmethod
def batch(
@@ -64,7 +78,7 @@ class MultiModalInputs(_MultiModalInputsBase):
keys = inputs_list[0].keys()
item_lists: Dict[str, List[torch.Tensor]] = defaultdict(list)
item_lists: Dict[str, List[NestedTensors]] = defaultdict(list)
for inputs in inputs_list:
if inputs.keys() != keys: