Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[Model] Refactoring of MiniCPM-V and add MiniCPM-o-2.6 support for vLLM (#12069)
Signed-off-by: hzh <hezhihui_thu@163.com> Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Signed-off-by: shaochangxu.scx <shaochangxu.scx@antgroup.com> Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Signed-off-by: NickLucche <nlucches@redhat.com> Signed-off-by: Isotr0py <2037008807@qq.com> Signed-off-by: Roger Wang <ywang@roblox.com> Signed-off-by: Rafael Vasquez <rafvasq21@gmail.com> Signed-off-by: Akshat Tripathi <akshat@krai.ai> Signed-off-by: Oleg Mosalov <oleg@krai.ai> Signed-off-by: Jee Jee Li <pandaleefree@gmail.com> Signed-off-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com> Signed-off-by: Yida Wu <yidawu@alumni.cmu.edu> Signed-off-by: Chenguang Li <757486878@qq.com> Signed-off-by: youkaichao <youkaichao@gmail.com> Signed-off-by: Alex-Brooks <Alex.brooks@ibm.com> Signed-off-by: Chen Zhang <zhangch99@outlook.com> Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Signed-off-by: Shanshan Shen <467638484@qq.com> Signed-off-by: elijah <f1renze.142857@gmail.com> Signed-off-by: Yikun <yikunkero@gmail.com> Signed-off-by: mgoin <michael@neuralmagic.com> Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Signed-off-by: Konrad Zawora <kzawora@habana.ai> Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com> Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com> Signed-off-by: Rui Qiao <ruisearch42@gmail.com> Co-authored-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Co-authored-by: shaochangxu <85155497+shaochangxu@users.noreply.github.com> Co-authored-by: shaochangxu.scx <shaochangxu.scx@antgroup.com> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk> Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com> Co-authored-by: sixgod <evethwillbeok@outlook.com> Co-authored-by: Isotr0py <2037008807@qq.com> Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com> Co-authored-by: Rafael Vasquez <rafvasq21@gmail.com> Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Co-authored-by: Akshat Tripathi <Akshat.tripathi6568@gmail.com> Co-authored-by: Oleg Mosalov <oleg@krai.ai> Co-authored-by: Jee Jee Li <pandaleefree@gmail.com> Co-authored-by: Avshalom Manevich <12231371+avshalomman@users.noreply.github.com> Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Co-authored-by: Yangcheng Li <liyangcheng.lyc@alibaba-inc.com> Co-authored-by: Siyuan Li <94890248+liaoyanqing666@users.noreply.github.com> Co-authored-by: Concurrensee <yida.wu@amd.com> Co-authored-by: Chenguang Li <757486878@qq.com> Co-authored-by: youkaichao <youkaichao@gmail.com> Co-authored-by: Alex Brooks <alex.brooks@ibm.com> Co-authored-by: Chen Zhang <zhangch99@outlook.com> Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: Shanshan Shen <467638484@qq.com> Co-authored-by: elijah <30852919+e1ijah1@users.noreply.github.com> Co-authored-by: Yikun Jiang <yikunkero@gmail.com> Co-authored-by: Steve Luo <36296769+SunflowerAries@users.noreply.github.com> Co-authored-by: mgoin <michael@neuralmagic.com> Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Co-authored-by: Konrad Zawora <kzawora@habana.ai> Co-authored-by: TJian <tunjian1996@gmail.com> Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com> Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com> Co-authored-by: maang-h <55082429+maang-h@users.noreply.github.com> Co-authored-by: Elfie Guo <164945471+elfiegg@users.noreply.github.com> Co-authored-by: Rui Qiao 
<161574667+ruisearch42@users.noreply.github.com> Co-authored-by: Roger Wang <ywang@roblox.com>
@@ -693,9 +693,16 @@ See [this page](#generative-models) for more information on how to use generative models.
   *
   * ✅︎
   * ✅︎
+- * `MiniCPMO`
+  * MiniCPM-O
+  * T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>E+</sup>
+  * `openbmb/MiniCPM-o-2_6`, etc.
+  * ✅︎
+  * ✅︎
+  *
 - * `MiniCPMV`
   * MiniCPM-V
-  * T + I<sup>E+</sup>
+  * T + I<sup>E+</sup> + V<sup>E+</sup>
   * `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, etc.
   * ✅︎
   * ✅︎
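For reference (not part of this commit), a minimal offline-inference sketch for the new MiniCPM-O row; the prompt format, stop tokens, and engine arguments follow the example scripts changed below, and the image path is a placeholder:

# Hypothetical usage sketch for MiniCPM-o-2_6 with text + image input.
from PIL import Image
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

model_name = "openbmb/MiniCPM-o-2_6"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
llm = LLM(model=model_name, trust_remote_code=True, max_model_len=4096)

messages = [{"role": "user", "content": "(<image>./</image>)\nWhat is in this image?"}]
prompt = tokenizer.apply_chat_template(messages,
                                       tokenize=False,
                                       add_generation_prompt=True)
stop_token_ids = [
    tokenizer.convert_tokens_to_ids(t) for t in ('<|im_end|>', '<|endoftext|>')
]

outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {"image": Image.open("example.jpg")},
    },
    SamplingParams(temperature=0.0, max_tokens=64, stop_token_ids=stop_token_ids))
print(outputs[0].outputs[0].text)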
@@ -67,7 +67,37 @@ def run_qwen2_audio(question: str, audio_count: int):
     return llm, prompt, stop_token_ids
 
 
-model_example_map = {"ultravox": run_ultravox, "qwen2_audio": run_qwen2_audio}
+def run_minicpmo(question: str, audio_count: int):
+    model_name = "openbmb/MiniCPM-o-2_6"
+    tokenizer = AutoTokenizer.from_pretrained(model_name,
+                                              trust_remote_code=True)
+    llm = LLM(model=model_name,
+              trust_remote_code=True,
+              max_model_len=4096,
+              max_num_seqs=5,
+              limit_mm_per_prompt={"audio": audio_count})
+
+    stop_tokens = ['<|im_end|>', '<|endoftext|>']
+    stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
+
+    audio_placeholder = "(<audio>./</audio>)" * audio_count
+    audio_chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n<|spk_bos|><|spk|><|spk_eos|><|tts_bos|>' }}{% endif %}"  # noqa: E501
+    messages = [{
+        'role': 'user',
+        'content': f'{audio_placeholder}\n{question}'
+    }]
+    prompt = tokenizer.apply_chat_template(messages,
+                                           tokenize=False,
+                                           add_generation_prompt=True,
+                                           chat_template=audio_chat_template)
+    return llm, prompt, stop_token_ids
+
+
+model_example_map = {
+    "ultravox": run_ultravox,
+    "qwen2_audio": run_qwen2_audio,
+    "minicpmo": run_minicpmo
+}
 
 
 def main(args):
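A hedged sketch (not part of this commit) of how the surrounding script consumes the new entry: the returned llm, prompt, and stop_token_ids are combined with (waveform, sample_rate) tuples passed as "audio" multi-modal data.

# Usage sketch, assuming run_minicpmo() from the hunk above is in scope.
from vllm import SamplingParams
from vllm.assets.audio import AudioAsset

audio_count = 1
llm, prompt, stop_token_ids = run_minicpmo(
    question="What is recited in the audio?", audio_count=audio_count)

# AudioAsset provides a (numpy waveform, sample rate) tuple for testing.
audio_and_sr = AudioAsset("mary_had_lamb").audio_and_sample_rate
outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {"audio": [audio_and_sr] * audio_count},
    },
    SamplingParams(temperature=0.2,
                   max_tokens=64,
                   stop_token_ids=stop_token_ids))
print(outputs[0].outputs[0].text)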
@@ -265,8 +265,9 @@ def run_mantis(question: str, modality: str):
 
 
 # MiniCPM-V
-def run_minicpmv(question: str, modality: str):
-    assert modality == "image"
+def run_minicpmv_base(question: str, modality: str, model_name):
+    assert modality in ["image", "video"]
+    # If you want to use `MiniCPM-o-2_6` with audio inputs, check `audio_language.py`  # noqa
 
     # 2.0
     # The official repo doesn't work yet, so we need to use a fork for now
@@ -277,7 +278,15 @@ def run_minicpmv(question: str, modality: str):
     # model_name = "openbmb/MiniCPM-Llama3-V-2_5"
 
     # 2.6
-    model_name = "openbmb/MiniCPM-V-2_6"
+    # model_name = "openbmb/MiniCPM-V-2_6"
+    # o2.6
+
+    # modality supports
+    # 2.0: image
+    # 2.5: image
+    # 2.6: image, video
+    # o2.6: image, video, audio
+    # model_name = "openbmb/MiniCPM-o-2_6"
     tokenizer = AutoTokenizer.from_pretrained(model_name,
                                               trust_remote_code=True)
     llm = LLM(
@@ -294,13 +303,18 @@ def run_minicpmv(question: str, modality: str):
     # 2.5
     # stop_token_ids = [tokenizer.eos_id, tokenizer.eot_id]
 
-    # 2.6
+    # 2.6 / o2.6
     stop_tokens = ['<|im_end|>', '<|endoftext|>']
     stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
 
+    modality_placeholder = {
+        "image": "(<image>./</image>)",
+        "video": "(<video>./</video>)",
+    }
+
     messages = [{
         'role': 'user',
-        'content': f'(<image>./</image>)\n{question}'
+        'content': f'{modality_placeholder[modality]}\n{question}'
     }]
     prompt = tokenizer.apply_chat_template(messages,
                                            tokenize=False,
@@ -308,6 +322,14 @@ def run_minicpmv(question: str, modality: str):
     return llm, prompt, stop_token_ids
 
 
+def run_minicpmo(question: str, modality: str):
+    return run_minicpmv_base(question, modality, "openbmb/MiniCPM-o-2_6")
+
+
+def run_minicpmv(question: str, modality: str):
+    return run_minicpmv_base(question, modality, "openbmb/MiniCPM-V-2_6")
+
+
 # LLama 3.2
 def run_mllama(question: str, modality: str):
     assert modality == "image"
@@ -523,6 +545,7 @@ model_example_map = {
     "llava-next-video": run_llava_next_video,
     "llava-onevision": run_llava_onevision,
     "mantis": run_mantis,
+    "minicpmo": run_minicpmo,
     "minicpmv": run_minicpmv,
     "mllama": run_mllama,
     "molmo": run_molmo,
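A similar hedged sketch for the refactored vision helpers above, exercising the newly supported video modality; the zero-filled frame array is a stand-in for real decoded frames.

# Usage sketch, assuming run_minicpmv() from the hunks above is in scope.
import numpy as np
from vllm import SamplingParams

llm, prompt, stop_token_ids = run_minicpmv("Describe this video.", "video")

# Placeholder input: 16 decoded RGB frames; replace with frames from a real clip.
frames = np.zeros((16, 448, 448, 3), dtype=np.uint8)

outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"video": frames}},
    SamplingParams(temperature=0.0, max_tokens=64, stop_token_ids=stop_token_ids))
print(outputs[0].outputs[0].text)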
@@ -4,5 +4,6 @@
 # Dependencies for CPUs
 torch==2.5.1+cpu; platform_machine != "ppc64le" and platform_machine != "aarch64" and platform_system != "Darwin"
 torch==2.5.1; platform_machine == "aarch64" or platform_system == "Darwin"
+torchaudio; platform_machine != "ppc64le" # required for the image processor of minicpm-o-2_6, this must be updated alongside torch
 torchvision; platform_machine != "ppc64le" # required for the image processor of phi3v, this must be updated alongside torch
 datasets # for benchmark scripts
@@ -5,6 +5,7 @@
 ray[default] >= 2.9
 nvidia-ml-py >= 12.560.30 # for pynvml package
 torch == 2.5.1
+torchaudio==2.5.1
 # These must be updated alongside torch
 torchvision == 0.20.1 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
 xformers == 0.0.28.post3; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch 2.5.1
@@ -12,6 +12,8 @@ decord # required for video tests
 einops # required for MPT, qwen-vl and Mamba
 httpx
 librosa # required for audio tests
+vector_quantize_pytorch # required for minicpmo_26 test
+vocos # required for minicpmo_26 test
 peft
 pqdm
 ray[adag]==2.40.0
@@ -19,6 +21,7 @@ sentence-transformers # required for embedding tests
 soundfile # required for audio tests
 timm # required for internvl test
 torch==2.5.1
+torchaudio==2.5.1
 transformers_stream_generator # required for qwen-vl test
 matplotlib # required for qwen-vl test
 mistral_common[opencv] >= 1.5.0 # required for pixtral test
@@ -106,9 +106,17 @@ dnspython==2.7.0
 docutils==0.16
     # via awscli
 einops==0.8.0
-    # via -r requirements-test.in
+    # via
+    #   -r requirements-test.in
+    #   encodec
+    #   vector-quantize-pytorch
+    #   vocos
+einx==0.3.0
+    # via vector-quantize-pytorch
 email-validator==2.2.0
     # via pydantic
+encodec==0.1.1
+    # via vocos
 evaluate==0.4.3
     # via lm-eval
 fastparquet==2024.11.0
@@ -125,6 +133,8 @@ filelock==3.16.1
     #   triton
 fonttools==4.54.1
     # via matplotlib
+frozendict==2.4.6
+    # via einx
 frozenlist==1.5.0
     # via
     #   aiohttp
@@ -159,6 +169,7 @@ huggingface-hub==0.26.2
     #   timm
     #   tokenizers
     #   transformers
+    #   vocos
 idna==3.10
     # via
     #   anyio
@@ -261,6 +272,8 @@ numpy==1.26.4
     #   cupy-cuda12x
     #   datasets
     #   decord
+    #   einx
+    #   encodec
     #   evaluate
     #   fastparquet
     #   genai-perf
@@ -283,6 +296,7 @@ numpy==1.26.4
     #   torchvision
     #   transformers
     #   tritonclient
+    #   vocos
 nvidia-cublas-cu12==12.4.5.8
     # via
     #   nvidia-cudnn-cu12
@@ -455,6 +469,7 @@ pyyaml==6.0.2
     #   responses
     #   timm
     #   transformers
+    #   vocos
 ray[adag]==2.40.0
     # via -r requirements-test.in
 redis==5.2.0
@@ -517,6 +532,7 @@ scipy==1.13.1
     #   scikit-learn
     #   sentence-transformers
     #   statsmodels
+    #   vocos
 sentence-transformers==3.2.1
     # via -r requirements-test.in
 sentencepiece==0.2.0
@@ -540,7 +556,9 @@ sqlitedict==2.1.0
 statsmodels==0.14.4
     # via genai-perf
 sympy==1.13.1
-    # via torch
+    # via
+    #   einx
+    #   torch
 tabledata==1.3.3
     # via pytablewriter
 tabulate==0.9.0
@@ -568,12 +586,21 @@ torch==2.5.1
     #   -r requirements-test.in
     #   accelerate
     #   bitsandbytes
+    #   encodec
     #   lm-eval
     #   peft
     #   sentence-transformers
     #   tensorizer
     #   timm
+    #   torchaudio
     #   torchvision
+    #   vector-quantize-pytorch
+    #   vocos
+torchaudio==2.5.1
+    # via
+    #   -r requirements-test.in
+    #   encodec
+    #   vocos
 torchvision==0.20.1
     # via timm
 tqdm==4.66.6
@@ -584,6 +611,7 @@ tqdm==4.66.6
     #   lm-eval
     #   nltk
     #   peft
+    #   pqdm
     #   sentence-transformers
     #   tqdm-multiprocess
     #   transformers
@@ -615,6 +643,7 @@ typing-extensions==4.12.2
     #   huggingface-hub
     #   librosa
     #   mistral-common
+    #   pqdm
     #   pydantic
     #   pydantic-core
     #   torch
@@ -626,6 +655,10 @@ urllib3==2.2.3
     #   requests
     #   responses
     #   tritonclient
+vector-quantize-pytorch==1.21.2
+    # via -r requirements-test.in
+vocos==0.1.0
+    # via -r requirements-test.in
 word2number==1.1
     # via lm-eval
 xxhash==3.5.0
@@ -350,6 +350,20 @@ VLM_TEST_SETTINGS = {
         postprocess_inputs=model_utils.wrap_inputs_post_processor,
         hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
     ),
+    "minicpmo_26": VLMTestInfo(
+        models=["openbmb/MiniCPM-o-2_6"],
+        test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
+        prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",  # noqa: E501
+        img_idx_to_prompt=lambda idx: "(<image>./</image>)\n",
+        max_model_len=4096,
+        max_num_seqs=2,
+        get_stop_token_ids=lambda tok: tok.convert_tokens_to_ids(['<|im_end|>', '<|endoftext|>']),  # noqa: E501
+        postprocess_inputs=model_utils.ignore_inputs_post_processor(
+            "image_sizes"
+        ),
+        hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
+        patch_hf_runner=model_utils.minicpmo_patch_hf_runner
+    ),
     "minicpmv_26": VLMTestInfo(
         models=["openbmb/MiniCPM-V-2_6"],
         test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
@@ -497,6 +497,17 @@ def mantis_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
     return hf_model
 
 
+def minicpmo_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
+    orig_generate = hf_model.model.generate
+
+    def _generate(self, *args, **kwargs):
+        return orig_generate(*args, decode_text=False, **kwargs)
+
+    hf_model.model.generate = types.MethodType(_generate, hf_model.model)
+
+    return hf_model
+
+
 def _generate_greedy_logprobs_limit(
     self,
     prompts: List[str],
@@ -152,6 +152,8 @@ def _test_processing_correctness(
     "llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
     "TIGER-Lab/Mantis-8B-siglip-llama3",
     "mistral-community/pixtral-12b",
+    "openbmb/MiniCPM-o-2_6",
+    "openbmb/MiniCPM-V-2_6",
     "Qwen/Qwen-VL-Chat",
     "Qwen/Qwen2-VL-2B-Instruct",
     "Qwen/Qwen2-Audio-7B-Instruct",
@@ -245,7 +245,9 @@ _MULTIMODAL_EXAMPLE_MODELS = {
     "LlavaOnevisionForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-onevision-qwen2-0.5b-ov-hf"),  # noqa: E501
     "MantisForConditionalGeneration": _HfExamplesInfo("TIGER-Lab/Mantis-8B-siglip-llama3",  # noqa: E501
                                                       hf_overrides={"architectures": ["MantisForConditionalGeneration"]}),  # noqa: E501
-    "MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-Llama3-V-2_5",
+    "MiniCPMO": _HfExamplesInfo("openbmb/MiniCPM-o-2_6",
+                                trust_remote_code=True),
+    "MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-V-2_6",
                                 trust_remote_code=True),
     "MolmoForCausalLM": _HfExamplesInfo("allenai/Molmo-7B-D-0924",
                                         trust_remote_code=True),
@@ -392,7 +392,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
             if model_type == "phi3_v":
                 # Workaround since this token is not defined in the tokenizer
                 return f"<|image_{current_count}|>"
-            if model_type == "minicpmv":
+            if model_type in ("minicpmo", "minicpmv"):
                 return "(<image>./</image>)"
             if model_type in ("blip-2", "chatglm", "fuyu", "paligemma",
                               "pixtral"):
@@ -424,10 +424,14 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
             if model_type == "qwen2_audio":
                 return (f"Audio {current_count}: "
                         f"<|audio_bos|><|AUDIO|><|audio_eos|>")
+            if model_type == "minicpmo":
+                return "(<audio>./</audio>)"
             raise TypeError(f"Unknown model type: {model_type}")
         elif modality == "video":
             if model_type == "qwen2_vl":
                 return "<|vision_start|><|video_pad|><|vision_end|>"
+            if model_type in ("minicpmo", "minicpmv"):
+                return "(<video>./</video>)"
             if model_type.startswith("llava"):
                 return self._cached_token_str(self._tokenizer,
                                               hf_config.video_token_index)
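For context, a hedged illustration (not from this diff) of where these placeholder strings end up: when an OpenAI-style chat request for a minicpmo model carries an audio part, the tracker substitutes "(<audio>./</audio>)" into the text before the chat template is applied.

# Hypothetical request body; vLLM's chat utils map the audio part to the
# "(<audio>./</audio>)" placeholder for model_type == "minicpmo".
messages = [{
    "role": "user",
    "content": [
        {"type": "audio_url", "audio_url": {"url": "https://example.com/clip.wav"}},
        {"type": "text", "text": "What is said in this audio?"},
    ],
}]
# After substitution, the text handed to the chat template reads roughly:
# "(<audio>./</audio>)\nWhat is said in this audio?"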
vllm/model_executor/models/minicpmo.py (new file, 811 lines)
@@ -0,0 +1,811 @@
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only MiniCPM-O model compatible with HuggingFace weights."""
from functools import partial
from itertools import accumulate
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set,
                    Tuple, TypedDict, Union)

import torch
import torch.types
from torch import nn
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.whisper.modeling_whisper import (
    ACT2FN, WHISPER_ATTENTION_CLASSES, WhisperConfig, WhisperEncoder)

from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.inputs import MultiModalFieldConfig
from vllm.multimodal.parse import (ModalityData, ModalityDataItems,
                                   MultiModalDataItems, MultiModalDataParser,
                                   VideoItem)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        PromptReplacement)
from vllm.multimodal.profiling import ProcessorInputs
from vllm.sequence import IntermediateTensors

from .minicpmv import (MiniCPMV2_6, MiniCPMVDummyInputsBuilder,
                       MiniCPMVEmbeddingItems, MiniCPMVMultiModalDataParser,
                       MiniCPMVMultiModalProcessor, MiniCPMVProcessingInfo)
from .utils import AutoWeightsLoader, maybe_prefix

CPU_DEVICE = torch.device("cpu")

MiniCPMOEmbeddingItems = MiniCPMVEmbeddingItems


class MiniCPMOAudioFeatureInputs(TypedDict):
    type: Literal["audio_features"]
    data: torch.Tensor
    """
    Shape: `(batch_size * num_audios * num_slices, num_channels, length)`
    Slice here means chunk. Audio that is too long will be split into slices,
    which is the same as image.
    Padding is used therefore `data` is `torch.Tensor`.
    """

    audio_feature_lens: torch.Tensor
    """
    Shape: `(batch_size * num_audios * num_slices)`

    This should be feature length of each audio slice,
    which equals to `data.shape[-1]`
    """

    audio_bounds: torch.Tensor
    """
    Shape: `(batch_size * num_audios * num_slices, 2)`

    This should be in `(start, stop)` format.
    """


class MiniCPMOAudioEmbeddingInputs(TypedDict):
    type: Literal["audio_embeds"]
    data: List[torch.Tensor]
    """
    Shape: `(batch_size * num_images * num_slices, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    instead of a batched tensor.
    Length of each slice may vary, so pass it as a list.
    """
    audio_bounds: torch.Tensor
    """
    Shape: `(batch_size * num_audios * num_slices, 2)`

    This should be in `(start, stop)` format.
    """


MiniCPMOAudioInputs = Union[MiniCPMOAudioFeatureInputs,
                            MiniCPMOAudioEmbeddingInputs]


class MiniCPMOAudioEmbeddingItems(MiniCPMOEmbeddingItems):

    def __init__(self, data: Dict) -> None:
        super().__init__(data, "audio")
        audio_embeds = self.data.get("audio_embeds", None)
        if audio_embeds is None:
            raise ValueError("Incorrect type of video_embeds",
                             "Got type: None")
        self.data["audio_embeds"] = audio_embeds

    def get(self, index: int) -> object:
        return self.data["audio_embeds"][index]


class MiniCPMOMultiModalDataParser(MiniCPMVMultiModalDataParser):

    def _parse_audio_data(
        self,
        data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]],
    ) -> ModalityDataItems[Any, Any]:
        if isinstance(data, dict):
            return MiniCPMOAudioEmbeddingItems(data)
        return super()._parse_audio_data(data)


class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
    audio_pattern = "(<audio>./</audio>)"

    def get_supported_mm_modalities(self) -> List[str]:
        return ["image", "video", "audio"]

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        return {"image": None, "video": None, "audio": None}

    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
        return {
            "image": self.get_max_image_tokens(),
            "audio": self.get_max_audio_tokens(),
            "video": self.get_max_video_tokens(seq_len)
        }

    def get_default_audio_pool_step(self) -> int:
        return 2

    def get_default_audio_sampling_rate(self) -> int:
        return 16000

    def get_chunk_length(self) -> int:
        return self.get_hf_config().audio_chunk_length

    def get_max_audio_tokens_per_chunk(self) -> int:
        pool_step = self.get_default_audio_pool_step()
        fbank_feat_in_chunk = 100
        cnn_feat_in_chunk = (fbank_feat_in_chunk - 1) // 2 + 1
        num_audio_tokens = (cnn_feat_in_chunk - pool_step) // pool_step + 1
        return num_audio_tokens + 2  # <audio>(<unk>*N)</audio>

    def get_max_audio_chunks_with_most_features(self) -> int:
        return 30

    def get_audio_len_by_num_chunks(self, num_chunks: int) -> int:
        sampling_rate = self.get_default_audio_sampling_rate()
        # exclude <audio> </audio>
        num_tokens_per_chunk = self.get_max_audio_tokens_per_chunk() - 2
        return int(num_chunks * sampling_rate / num_tokens_per_chunk) + 1

    def get_num_frames_with_most_features(self, seq_len: int) -> int:
        mm_config = self.ctx.get_mm_config()
        max_images = mm_config.limit_per_prompt.get("image", 1)
        max_videos = mm_config.limit_per_prompt.get("video", 1)
        max_audios = mm_config.limit_per_prompt.get("audio", 1)

        # count <image_idx></image_idx> tokens
        # which are not in get_max_image_tokens
        max_image_tokens = self.get_max_image_tokens(
        ) * max_images + 4 * max_images
        max_audio_tokens = self.get_max_audio_tokens(
        ) * max_audios + 2 * max_audios
        max_total_frames = self.get_max_video_frames(seq_len -
                                                     max_image_tokens -
                                                     max_audio_tokens)

        num_frames = max(max_total_frames // max(max_videos, 1), 1)

        return num_frames


class MiniCPMODummyInputsBuilder(MiniCPMVDummyInputsBuilder):

    def get_dummy_processor_inputs(
            self, seq_len: int, mm_counts: Mapping[str,
                                                   int]) -> ProcessorInputs:
        num_audios = mm_counts.get("audio", 0)
        audio_len = self.info.get_max_audio_chunks_with_most_features() * \
            self.info.get_default_audio_sampling_rate()

        processor_inputs = super().get_dummy_processor_inputs(
            seq_len, mm_counts)
        mm_data = {
            "image":
            processor_inputs.mm_data["image"],
            "video":
            processor_inputs.mm_data["video"],
            "audio":
            self._get_dummy_audios(length=audio_len, num_audios=num_audios)
        }

        audio_prompt_texts = self.info.audio_pattern * num_audios

        return ProcessorInputs(prompt_text=processor_inputs.prompt_text + \
                               audio_prompt_texts,
                               mm_data=mm_data)


class MiniCPMOMultiModalProcessor(
        MiniCPMVMultiModalProcessor,
        BaseMultiModalProcessor[MiniCPMOProcessingInfo]):

    def _get_data_parser(self) -> MultiModalDataParser:
        return MiniCPMOMultiModalDataParser(
            target_sr=self.info.get_default_audio_sampling_rate())

    def get_audio_prompt_texts(self,
                               audio_lens: int,
                               chunk_input: bool = True,
                               chunk_length: int = 1) -> str:
        return self.info.get_hf_processor().get_audio_placeholder(
            audio_lens, chunk_input, chunk_length)

    def get_special_tokens(self) -> Dict[str, torch.Tensor]:
        tokenizer = self.info.get_tokenizer()
        special_tokens = super().get_special_tokens()
        if hasattr(tokenizer, "audio_start_id"):
            special_tokens["audio_start_id"] = torch.tensor(
                tokenizer.audio_start_id)
            special_tokens["audio_end_id"] = torch.tensor(
                tokenizer.audio_end_id)
        return special_tokens

    def process_audios(self, mm_data: Mapping[str, object],
                       mm_kwargs: Mapping[str, object]) -> Dict[str, object]:
        audios = mm_data.pop("audios", [])
        audio_embeds = mm_data.pop("audio_embeds", [])
        if isinstance(audios, (list, torch.Tensor)) and len(audios) > 0:
            audio_outputs = {
                "audio_lens": [],
                "audio_features": [],
                "audio_feature_lens": [],
                "audio_num_segments": []
            }
            for audio in audios:
                single_audio_outputs = super().call_base_hf_processor(
                    prompt=self.info.audio_pattern,
                    mm_data={
                        "audios": audio,
                        "chunk_input": True
                    },
                    mm_kwargs=mm_kwargs)
                audio_outputs["audio_lens"].append(len(audio))
                audio_outputs["audio_features"].append(
                    single_audio_outputs["audio_features"])
                audio_outputs["audio_num_segments"].append(
                    len(single_audio_outputs["audio_feature_lens"][0]))
                audio_outputs["audio_feature_lens"] += \
                    single_audio_outputs["audio_feature_lens"]
            audio_outputs["audio_features"] = [
                audio_feature for single_audio_features in \
                audio_outputs["audio_features"]
                for audio_feature in single_audio_features
            ]
            audio_outputs["audio_feature_lens"] = torch.cat(
                audio_outputs["audio_feature_lens"])
        elif len(audio_embeds):
            audio_outputs = {
                "audio_lens": [
                    self.info.get_audio_len_by_num_chunks(
                        sum(chunk_embeds.shape[0]
                            for chunk_embeds in single_audio_embeds))
                    for single_audio_embeds in audio_embeds
                ],
                "audio_embeds": [
                    chunk_embeds for single_audio_embeds in audio_embeds
                    for chunk_embeds in single_audio_embeds
                ],
                "audio_num_segments": [
                    len(single_audio_embeds)
                    for single_audio_embeds in audio_embeds
                ]
            }
        else:
            audio_outputs = {}
        return audio_outputs

    def get_placeholder_match_pattern(self) -> str:
        return r"\(<(image|video|audio)>./</\1>\)"

    def get_placeholder_split_pattern(self) -> str:
        return r"\(<(?:image|video|audio)>./</(?:image|video|audio)>\)"

    def process_mm_inputs(self, mm_data, mm_kwargs) -> object:
        return {
            "image": self.process_images(mm_data, mm_kwargs),
            "video": self.process_videos(mm_data, mm_kwargs),
            "audio": self.process_audios(mm_data, mm_kwargs)
        }

    def get_modality_num_counter(self, modality: str) -> str:
        if modality == "audio":
            return "audio_lens"
        return super().get_modality_num_counter(modality)

    def get_num_slices_by_modality(self, inputs: Dict[str, object],
                                   modality: str, index: int) -> int:
        if modality == "audio":
            return inputs["audio"]["audio_num_segments"][index]
        return super().get_num_slices_by_modality(inputs, modality, index)

    def get_prompt_texts_by_modality(self, inputs: Dict[str, object],
                                     modality: str, index: int) -> str:
        if modality == "audio":
            return self.get_audio_prompt_texts(
                inputs["audio"]["audio_lens"][index])
        return super().get_prompt_texts_by_modality(inputs, modality, index)

    def _get_prompt_replacements(
            self, mm_items: MultiModalDataItems,
            hf_processor_mm_kwargs: Mapping[str, Any],
            out_mm_kwargs: MultiModalKwargs) -> List[PromptReplacement]:
        placeholder = {
            "image": self.info.image_pattern,
            "video": self.info.video_pattern,
            "audio": self.info.audio_pattern
        }

        def get_replacement_minicpmv(item_idx: int, modality: str):
            if modality == "image":
                return self.get_image_prompt_texts(
                    mm_items["image"].get_image_size(item_idx), item_idx)
            elif modality == "video":
                return self.get_video_prompt_texts(
                    mm_items["video"].get_frame_size(item_idx),
                    mm_items["video"].get_num_frames(item_idx))
            else:  # audio
                if isinstance(mm_items["audio"], MiniCPMOAudioEmbeddingItems):
                    single_audio_embeds = mm_items["audio"].get(item_idx)
                    audio_len = self.info.get_audio_len_by_num_chunks(
                        sum(chunk_embeds.shape[0]
                            for chunk_embeds in single_audio_embeds))
                    return self.get_audio_prompt_texts(audio_len)
                return self.get_audio_prompt_texts(
                    len(mm_items["audio"].get(item_idx)))

        return [
            PromptReplacement(modality=modality,
                              target=placeholder[modality],
                              replacement=partial(get_replacement_minicpmv,
                                                  modality=modality))
            for modality in ("image", "video", "audio")
        ]

    def _get_mm_fields_config(
        self,
        hf_inputs,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:

        def get_slices(num_slices: List[int]) -> List[int]:
            slice_indices = [0] + list(accumulate(num_slices))
            slices = [(slice_indices[i], slice_indices[i + 1])
                      for i in range(len(num_slices))]
            return [slice(*slice_item) for slice_item in slices]

        audio_slices = get_slices(
            hf_inputs.get("audio_num_slices", torch.empty(0)))
        return dict(
            **super()._get_mm_fields_config(hf_inputs, hf_processor_mm_kwargs),
            audio_features=MultiModalFieldConfig.flat("audio", audio_slices),
            audio_feature_lens=MultiModalFieldConfig.flat(
                "audio", audio_slices),
            audio_num_slices=MultiModalFieldConfig.batched("audio"),
            audio_orders_in_mm_data=MultiModalFieldConfig.batched("audio"),
            audio_embeds=MultiModalFieldConfig.flat("audio", audio_slices))


class MultiModalProjector(nn.Module):

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.linear1 = nn.Linear(in_features=in_dim,
                                 out_features=out_dim,
                                 bias=True)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(in_features=out_dim,
                                 out_features=out_dim,
                                 bias=True)

    def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
        hidden_states = self.relu(self.linear1(audio_features))
        hidden_states = self.linear2(hidden_states)
        return hidden_states


class MiniCPMWhisperEncoderLayer(nn.Module):

    def __init__(self, config: WhisperConfig, layer_idx: int = None):
        super().__init__()
        self.embed_dim = config.d_model
        self.self_attn = WHISPER_ATTENTION_CLASSES[
            config._attn_implementation](
                embed_dim=self.embed_dim,
                num_heads=config.encoder_attention_heads,
                dropout=config.attention_dropout,
                config=config,
                layer_idx=layer_idx,
            )
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        residual = hidden_states
        past_key_values = None
        hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states, attn_weights, past_key_values = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            past_key_value=past_key_values,
        )
        hidden_states = nn.functional.dropout(hidden_states,
                                              p=self.dropout,
                                              training=self.training)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states,
                                              p=self.activation_dropout,
                                              training=self.training)
        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states,
                                              p=self.dropout,
                                              training=self.training)
        hidden_states = residual + hidden_states

        if hidden_states.dtype == torch.float16 and (
                torch.isinf(hidden_states).any()
                or torch.isnan(hidden_states).any()):
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states,
                                        min=-clamp_value,
                                        max=clamp_value)

        outputs = (hidden_states, )

        return outputs


class MiniCPMWhisperEncoder(WhisperEncoder):

    def __init__(self, config: WhisperConfig):
        super().__init__(config)
        self.layers = nn.ModuleList([
            MiniCPMWhisperEncoderLayer(config, layer_idx=i)
            for i in range(config.encoder_layers)
        ])

    def forward(
        self,
        input_features: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> BaseModelOutputWithPast:
        # Ignore copy
        input_features = input_features.to(dtype=self.conv1.weight.dtype,
                                           device=self.conv1.weight.device)

        inputs_embeds = nn.functional.gelu(self.conv1(input_features))
        inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))

        inputs_embeds = inputs_embeds.permute(0, 2, 1)

        embed_pos = self.embed_positions.weight

        embed_pos = embed_pos[:inputs_embeds.shape[1], :]

        hidden_states = inputs_embeds + embed_pos
        hidden_states = nn.functional.dropout(hidden_states,
                                              p=self.dropout,
                                              training=self.training)

        encoder_states = ()

        for idx, encoder_layer in enumerate(self.layers):
            encoder_states = encoder_states + (hidden_states, )
            to_drop = False
            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:  # skip the layer
                    to_drop = True

            # Ignore copy
            if to_drop:
                layer_outputs = (None, None)
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                )

                hidden_states = layer_outputs[0]

        hidden_states = self.layer_norm(hidden_states)
        encoder_states = encoder_states + (hidden_states, )

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            hidden_states=encoder_states,
        )


@MULTIMODAL_REGISTRY.register_processor(
    MiniCPMOMultiModalProcessor,
    info=MiniCPMOProcessingInfo,
    dummy_inputs=MiniCPMODummyInputsBuilder)
class MiniCPMO(MiniCPMV2_6):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        self.apm = self.init_audio_module(vllm_config=vllm_config,
                                          prefix=maybe_prefix(prefix, "apm"))

    def init_audio_module(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # Do not use parameters temporarily
        audio_config = self.config.audio_config
        model = MiniCPMWhisperEncoder(audio_config)
        audio_output_dim = int(audio_config.encoder_ffn_dim // 4)
        self.audio_avg_pooler = \
            nn.AvgPool1d(self.config.audio_pool_step,
                         stride=self.config.audio_pool_step)
        self.audio_projection_layer = \
            MultiModalProjector(in_dim=audio_output_dim,out_dim=self.embed_dim)
        self.audio_encoder_layer = -1
        return model

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        loader = AutoWeightsLoader(self, skip_prefixes=["tts"])
        return loader.load_weights(weights)

    def subsequent_chunk_mask(
        self,
        size: int,
        chunk_size: int,
        num_left_chunks: int = -1,
        device: torch.device = CPU_DEVICE,
        num_lookhead: int = 0,
    ) -> torch.Tensor:
        ret = torch.zeros(size, size, device=device, dtype=torch.bool)
        for i in range(size):
            if num_left_chunks < 0:
                start = 0
            else:
                start = max((i // chunk_size - num_left_chunks) * chunk_size,
                            0)
            ending = min((i // chunk_size + 1) * chunk_size + num_lookhead,
                         size)
            ret[i, start:ending] = True
        return ret

    def _get_feat_extract_output_lengths(self,
                                         input_lengths: torch.LongTensor):
        input_lengths_after_cnn = (input_lengths - 1) // 2 + 1
        input_lengths_after_pooling = (
            input_lengths_after_cnn -
            self.config.audio_pool_step) // self.config.audio_pool_step + 1
        input_lengths_after_pooling = input_lengths_after_pooling.to(
            dtype=torch.int32)

        return input_lengths_after_cnn, input_lengths_after_pooling

    # Copied from HF repo of MiniCPM-o-2_6,
    # designed for batched inputs and outputs
    def get_audio_hidden_states(self, data: MiniCPMOAudioInputs,
                                chunk_length: int) -> torch.Tensor:
        wavforms = data.get(
            "data",
            [])  # (bs, 80, frames) or [], multi audios need filled in advance
        audio_feature_lens_raw = [data.get("audio_feature_lens",
                                           [])]  # list, [[x1, x2], [y1], [z1]]

        # exist audio
        if len(wavforms) > 0:
            audio_feature_lens = torch.hstack(audio_feature_lens_raw)
            batch_size, _, max_mel_seq_len = wavforms.shape
            max_seq_len = (max_mel_seq_len - 1) // 2 + 1

            # Create a sequence tensor of shape (batch_size, max_seq_len)
            seq_range = (torch.arange(
                0,
                max_seq_len,
                dtype=audio_feature_lens.dtype,
                device=audio_feature_lens.device).unsqueeze(0).expand(
                    batch_size, max_seq_len))
            lengths_expand = audio_feature_lens.unsqueeze(1).expand(
                batch_size, max_seq_len)
            # Create mask
            padding_mask = seq_range >= lengths_expand  # 1 for padded values

            audio_attention_mask_ = padding_mask.view(
                batch_size, 1, 1, max_seq_len).expand(batch_size, 1,
                                                      max_seq_len, max_seq_len)
            audio_attention_mask = audio_attention_mask_.to(
                dtype=self.apm.conv1.weight.dtype,
                device=self.apm.conv1.weight.device)

            if chunk_length > 0:
                chunk_num_frame = int(chunk_length * 50)
                chunk_mask = self.subsequent_chunk_mask(
                    size=max_seq_len,
                    chunk_size=chunk_num_frame,
                    num_left_chunks=-1,
                    device=audio_attention_mask_.device,
                )
                audio_attention_mask_ = torch.logical_or(
                    audio_attention_mask_, torch.logical_not(chunk_mask))

            audio_attention_mask[audio_attention_mask_] = float("-inf")
            audio_states = self.apm(
                wavforms, attention_mask=audio_attention_mask).hidden_states[
                    self.audio_encoder_layer]
            audio_embeds = self.audio_projection_layer(audio_states)

            audio_embeds = audio_embeds.transpose(1, 2)
            audio_embeds = self.audio_avg_pooler(audio_embeds)
            audio_embeds = audio_embeds.transpose(1, 2)

            _, feature_lens_after_pooling = \
                self._get_feat_extract_output_lengths(audio_feature_lens)

            num_audio_tokens = feature_lens_after_pooling

            final_audio_embeds = []
            idx = 0
            for i in range(len(audio_feature_lens_raw)):
                target_audio_embeds = []
                for _ in range(len(audio_feature_lens_raw[i])):
                    target_audio_embeds.append(
                        audio_embeds[idx, :num_audio_tokens[idx], :])
                    idx += 1
                final_audio_embeds.append(target_audio_embeds)
            return final_audio_embeds
        else:
            return []

    def get_embedding_with_audios(self, vlm_embedding: torch.Tensor,
                                  audio_inputs: Optional[MiniCPMOAudioInputs],
                                  chunk_length: int) -> torch.Tensor:
        device, dtype = vlm_embedding.device, vlm_embedding.dtype
        if audio_inputs["type"] == "audio_embeds":
            audio_embeddings = audio_inputs["data"]
            audio_embeddings = [
                audio_embeddings[i].to(device=device, dtype=dtype)
                for i in range(len(audio_embeddings))
            ]
        else:
            audio_embeddings = self.get_audio_hidden_states(
                audio_inputs, chunk_length)[0]
        if audio_embeddings is None or len(audio_embeddings) == 0:
            return vlm_embedding
        audio_bounds = audio_inputs["audio_bounds"]
        if self.config.chunk_input:
            audio_embs = torch.cat(audio_embeddings, dim=0).to(device=device,
                                                               dtype=dtype)
            audio_start_pos = 0
            for bound in audio_bounds:
                audio_len = bound[1] - bound[0]
                vlm_embedding[bound[0]:bound[1]] = audio_embs[
                    audio_start_pos:audio_start_pos + audio_len, :]
                audio_start_pos += audio_len
        else:
            for embs, bound in zip(audio_embeddings, audio_bounds):
                audio_indices = torch.arange(bound[0],
                                             bound[1],
                                             dtype=torch.long).to(device)

                if embs.shape[0] != len(audio_indices):
                    raise ValueError(
                        "Shape mismatch: Trying to assign embeddings "
                        f"of shape {embs.shape} "
                        f"to input indices of length {len(audio_indices)}")
                vlm_embedding[audio_indices] = embs.to(dtype)
        return vlm_embedding

    def _get_audio_bounds(self, input_ids: torch.Tensor,
                          audio_start_id: torch.Tensor,
                          audio_end_id: torch.Tensor) -> torch.Tensor:
        audio_start_tokens, = torch.where(input_ids == audio_start_id[0])
        audio_start_tokens += 1
        audio_end_tokens, = torch.where(input_ids == audio_end_id[0])
        valid_audio_nums = max(len(audio_start_tokens), len(audio_end_tokens))
        return torch.hstack([
            audio_start_tokens[:valid_audio_nums].unsqueeze(-1),
            audio_end_tokens[:valid_audio_nums].unsqueeze(-1)
        ])

    def _parse_and_validate_audio_inputs(
            self, input_ids: torch.Tensor,
            **kwargs: object) -> Tuple[MiniCPMOAudioInputs]:
        audio_features = kwargs.pop("audio_features", [])
        audio_feature_lens = kwargs.pop("audio_feature_lens", [])
        audio_embeds = kwargs.pop("audio_embeds", None)
        audio_start_id = kwargs.pop("audio_start_id", None)
        audio_end_id = kwargs.pop("audio_end_id", None)
        if audio_embeds is not None:
            audio_embeds = [
                audio_embeds[i][j] for i in range(len(audio_embeds))
                for j in range(len(audio_embeds[i]))
            ]
            return MiniCPMOAudioEmbeddingInputs(
                audio_bounds=self._get_audio_bounds(input_ids, audio_start_id,
                                                    audio_end_id),
                data=audio_embeds,
                type="audio_embeds")
        if len(audio_features) > 0:
            audio_features_all = [
                i.permute(1, 0) for audio_feature in audio_features
                for i in audio_feature
            ]
            audio_features = torch.nn.utils.rnn.pad_sequence(
                audio_features_all, batch_first=True,
                padding_value=0.0).permute(0, 2, 1)
            audio_feature_lens = torch.cat(
                [item for item in audio_feature_lens])

            return MiniCPMOAudioFeatureInputs(
                audio_bounds=self._get_audio_bounds(input_ids, audio_start_id,
                                                    audio_end_id),
                data=audio_features,
                audio_feature_lens=audio_feature_lens,
                type="audio_features")
        return None

    def _parse_and_validate_inputs(self, input_ids: torch.Tensor,
                                   **kwargs: object):
        image_inputs = self._parse_and_validate_image_inputs(
            input_ids, **kwargs)
        if not any("audio" in key for key in kwargs):
            return image_inputs, None
        audio_inputs = self._parse_and_validate_audio_inputs(
            input_ids, **kwargs)
        return image_inputs, audio_inputs

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: Any,
    ) -> torch.Tensor:
        if intermediate_tensors is not None:
            vlm_embeddings = None
        else:
            image_inputs, audio_inputs = \
                self._parse_and_validate_inputs(input_ids, **kwargs)
            vlm_embeddings, _ = self.get_embedding_with_vision(
                input_ids, image_inputs)

            if audio_inputs is not None:
                vlm_embeddings = self.get_embedding_with_audios(
                    vlm_embeddings, audio_inputs,
                    self.config.audio_chunk_length)

        # always pass the input via `inputs_embeds`
        # to make sure the computation graph is consistent
        # for `torch.compile` integration
        input_ids = None

        output = self.llm.model(
            input_ids=input_ids,
            positions=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=vlm_embeddings,
        )
        return output
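A small self-contained sketch (not from this commit) of the scatter performed by get_embedding_with_audios() in the chunk_input branch above: each (start, stop) row of audio_bounds marks the placeholder span in the prompt embedding that is overwritten with consecutive rows of the concatenated audio embeddings.

# Toy illustration of the audio_bounds scatter, with hypothetical sizes.
import torch

hidden_size = 4
vlm_embedding = torch.zeros(10, hidden_size)       # 10 prompt positions
audio_embs = torch.arange(5 * hidden_size, dtype=torch.float).view(5, hidden_size)
audio_bounds = torch.tensor([[2, 5], [7, 9]])      # two audio segments: 3 + 2 slots

audio_start_pos = 0
for start, stop in audio_bounds.tolist():
    audio_len = stop - start
    vlm_embedding[start:stop] = audio_embs[audio_start_pos:audio_start_pos + audio_len]
    audio_start_pos += audio_len

# Positions 2-4 and 7-8 now hold audio embeddings; all others keep text embeddings.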
vllm/model_executor/models/minicpmv.py: file diff suppressed because it is too large.
@@ -162,6 +162,7 @@ _MULTIMODAL_MODELS = {
     "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),  # noqa: E501
     "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),  # noqa: E501
     "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
+    "MiniCPMO": ("minicpmo", "MiniCPMO"),
     "MiniCPMV": ("minicpmv", "MiniCPMV"),
     "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
     "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
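A quick hedged check (not part of this commit) that the new architecture name resolves through the registry once this mapping is in place:

# Sanity check: the "MiniCPMO" architecture is now listed by the model registry.
from vllm import ModelRegistry

assert "MiniCPMO" in ModelRegistry.get_supported_archs()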