Mirror of https://github.com/vllm-project/vllm.git, synced 2025-10-20 14:53:52 +08:00
[Model][VLM] Support Bee-8B Model (#27012)
Signed-off-by: uyzhang <yi.zhang.4096@gmail.com>
Signed-off-by: Yi Zhang <zhangyi970819@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
@@ -634,6 +634,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
|--------------|--------|--------|-------------------|----------------------|---------------------------|
| `AriaForConditionalGeneration` | Aria | T + I<sup>+</sup> | `rhymes-ai/Aria` | | |
| `AyaVisionForConditionalGeneration` | Aya Vision | T + I<sup>+</sup> | `CohereForAI/aya-vision-8b`, `CohereForAI/aya-vision-32b`, etc. | | ✅︎ |
| `BeeForConditionalGeneration` | Bee-8B | T + I<sup>E+</sup> | `Open-Bee/Bee-8B-RL`, `Open-Bee/Bee-8B-SFT` | | ✅︎ |
| `Blip2ForConditionalGeneration` | BLIP-2 | T + I<sup>E</sup> | `Salesforce/blip2-opt-2.7b`, `Salesforce/blip2-opt-6.7b`, etc. | | ✅︎ |
| `ChameleonForConditionalGeneration` | Chameleon | T + I | `facebook/chameleon-7b`, etc. | | ✅︎ |
| `Cohere2VisionForConditionalGeneration` | Command A Vision | T + I<sup>+</sup> | `CohereLabs/command-a-vision-07-2025`, etc. | | ✅︎ |
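For orientation (not part of the diff itself), here is a minimal offline-inference sketch for the new table entry. It assumes the `Open-Bee/Bee-8B-RL` checkpoint and vLLM's standard multimodal `LLM.generate` interface; the prompt string mirrors the `run_bee()` example added later in this commit, and the image path is a placeholder.

```python
# Hypothetical quick-start sketch, not part of this commit.
from PIL import Image

from vllm import LLM, SamplingParams

llm = LLM(
    model="Open-Bee/Bee-8B-RL",
    trust_remote_code=True,
    max_model_len=16384,
    limit_mm_per_prompt={"image": 1},
)

# Prompt format taken from the run_bee() example added below.
prompt = (
    "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
    "<|im_start|>user\n<image>\nWhat is shown in this image?<|im_end|>"
    "<|im_start|>assistant\n<think>\n"
)

outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": Image.open("example.jpg")}},
    SamplingParams(max_tokens=256),
)
print(outputs[0].outputs[0].text)
```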
@@ -90,6 +90,33 @@ def run_aya_vision(questions: list[str], modality: str) -> ModelRequestData:
    )


# Bee-8B
def run_bee(questions: list[str], modality: str) -> ModelRequestData:
    assert modality == "image"
    model_name = "Open-Bee/Bee-8B-RL"

    prompts = [
        (
            f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
            f"<|im_start|>user\n<image>\n{question}<|im_end|>"
            f"<|im_start|>assistant\n<think>\n"
        )
        for question in questions
    ]

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=16384,
        limit_mm_per_prompt={modality: 1},
        trust_remote_code=True,
    )

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
    )


# BLIP-2
def run_blip2(questions: list[str], modality: str) -> ModelRequestData:
    assert modality == "image"

@@ -1708,6 +1735,7 @@ def run_tarsier2(questions: list[str], modality: str) -> ModelRequestData:
model_example_map = {
    "aria": run_aria,
    "aya_vision": run_aya_vision,
    "bee": run_bee,
    "blip-2": run_blip2,
    "chameleon": run_chameleon,
    "dots_ocr": run_dots_ocr,
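With the `"bee"` entry registered in `model_example_map`, the single-image demo should be runnable through the script's existing CLI, presumably something like `python examples/offline_inference/vision_language.py --model-type bee --modality image` (flag names are assumed from the script's argument parser, not shown in this diff).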
@@ -107,6 +107,41 @@ def load_aya_vision(question: str, image_urls: list[str]) -> ModelRequestData:
    )


def load_bee(question: str, image_urls: list[str]) -> ModelRequestData:
    model_name = "Open-Bee/Bee-8B-RL"

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=16384,
        max_num_seqs=16,
        limit_mm_per_prompt={"image": len(image_urls)},
        trust_remote_code=True,
    )

    placeholders = [{"type": "image", "image": url} for url in image_urls]
    messages = [
        {
            "role": "user",
            "content": [
                *placeholders,
                {"type": "text", "text": question},
            ],
        }
    ]

    processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)

    prompt = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
        image_data=[fetch_image(url) for url in image_urls],
    )


def load_command_a_vision(question: str, image_urls: list[str]) -> ModelRequestData:
    model_name = "CohereLabs/command-a-vision-07-2025"

@@ -1215,6 +1250,7 @@ def load_glm4_5v_fp8(question: str, image_urls: list[str]) -> ModelRequestData:
model_example_map = {
    "aria": load_aria,
    "aya_vision": load_aya_vision,
    "bee": load_bee,
    "command_a_vision": load_command_a_vision,
    "deepseek_vl_v2": load_deepseek_vl2,
    "gemma3": load_gemma3,
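As a complement to `load_bee` above, which builds the prompt by hand with `AutoProcessor.apply_chat_template`, a hedged sketch of the same multi-image flow using `LLM.chat`, which applies the chat template internally. The URLs and question are placeholders, and `LLM.chat` with OpenAI-style `image_url` content parts is standard vLLM behaviour rather than something introduced by this commit.

```python
# Hypothetical multi-image chat sketch, not part of this commit.
from vllm import LLM, SamplingParams

llm = LLM(
    model="Open-Bee/Bee-8B-RL",
    trust_remote_code=True,
    max_model_len=16384,
    limit_mm_per_prompt={"image": 2},
)

image_urls = [
    "https://example.com/first.jpg",  # placeholder URLs
    "https://example.com/second.jpg",
]
messages = [
    {
        "role": "user",
        "content": [
            *[{"type": "image_url", "image_url": {"url": url}} for url in image_urls],
            {"type": "text", "text": "What do these images have in common?"},
        ],
    }
]

outputs = llm.chat(messages, SamplingParams(max_tokens=256))
print(outputs[0].outputs[0].text)
```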
@@ -326,6 +326,7 @@ def _test_processing_correctness_one(
    [
        "rhymes-ai/Aria",
        "CohereForAI/aya-vision-8b",
        "Open-Bee/Bee-8B-RL",
        "Salesforce/blip2-opt-2.7b",
        "facebook/chameleon-7b",
        "CohereLabs/command-a-vision-07-2025",
@@ -566,6 +566,10 @@ _MULTIMODAL_EXAMPLE_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria"),
    "AyaVisionForConditionalGeneration": _HfExamplesInfo("CohereForAI/aya-vision-8b"),
    "BeeForConditionalGeneration": _HfExamplesInfo(
        "Open-Bee/Bee-8B-RL",
        trust_remote_code=True,
    ),
    "Blip2ForConditionalGeneration": _HfExamplesInfo(
        "Salesforce/blip2-opt-2.7b",
        extras={"6b": "Salesforce/blip2-opt-6.7b"},
vllm/model_executor/models/bee.py (new file, 157 lines)
@@ -0,0 +1,157 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Mapping

import torch
import torch.nn as nn
from transformers.activations import GELUActivation

from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalDataDict

from .llava_next import (
    LlavaDummyInputsBuilder,
    LlavaNextMultiModalProcessor,
    LlavaNextProcessingInfo,
)
from .llava_onevision import LlavaOnevisionForConditionalGeneration
from .utils import WeightsMapper


class BeeProcessingInfo(LlavaNextProcessingInfo):
    def get_hf_config(self):
        return self.ctx.get_hf_config()

    def get_hf_processor(self, **kwargs: object):
        return self.ctx.get_hf_processor(**kwargs)

    def _get_num_unpadded_features(
        self,
        *,
        original_height: int,
        original_width: int,
        npatches: int,
        num_patch_height: int,
        num_patch_width: int,
    ) -> tuple[int, int]:
        """Override to use correct max_num_patches from vision_aspect_ratio."""
        import math

        current_height = npatches * num_patch_height
        current_width = npatches * num_patch_width

        aspect_ratio = original_width / original_height
        current_aspect_ratio = current_width / current_height

        if aspect_ratio > current_aspect_ratio:
            new_height = int(
                round(original_height * (current_width / original_width), 7)
            )
            padding = (current_height - new_height) // 2
            current_height = current_height - (2 * padding)
        else:
            new_width = int(
                round(original_width * (current_height / original_height), 7)
            )
            padding = (current_width - new_width) // 2
            current_width = current_width - (2 * padding)

        unpadded_features = current_height * current_width
        newline_features = current_height

        # Get max_num_patches from vision_aspect_ratio config
        hf_config = self.get_hf_config()
        vision_aspect_ratio = getattr(hf_config, "vision_aspect_ratio", "anyres_max_9")
        max_num_patches = int(vision_aspect_ratio.replace("anyres_max_", ""))

        ratio = math.sqrt(
            current_height * current_width / (max_num_patches * npatches**2)
        )
        if ratio > 1.1:
            height_factor = int(current_height // ratio)
            width_factor = int(current_width // ratio)
            unpadded_features = height_factor * width_factor
            newline_features = height_factor

        return (unpadded_features, newline_features)


class BeeDummyInputsBuilder(LlavaDummyInputsBuilder[BeeProcessingInfo]):
    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        num_images = mm_counts.get("image", 0)
        image_token = "<image>"

        return image_token * num_images

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        num_images = mm_counts.get("image", 0)

        target_width, target_height = self.info.get_image_size_with_most_features()

        image_overrides = mm_options.get("image") if mm_options else None

        return {
            "image": self._get_dummy_images(
                width=target_width,
                height=target_height,
                num_images=num_images,
                overrides=image_overrides,
            ),
        }


class BeeMultiModalProjector(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.pre_norm = nn.LayerNorm(config.vision_config.hidden_size, eps=1e-06)
        self.linear_1 = nn.Linear(
            config.vision_config.hidden_size,
            config.text_config.hidden_size * 4,
            bias=True,
        )
        self.act = GELUActivation()
        self.linear_2 = nn.Linear(
            config.text_config.hidden_size * 4,
            config.text_config.hidden_size,
            bias=True,
        )

    def forward(self, image_feature: torch.Tensor) -> torch.Tensor:
        image_feature = self.pre_norm(image_feature)
        hidden_states = self.linear_1(image_feature)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)

        return hidden_states


@MULTIMODAL_REGISTRY.register_processor(
    LlavaNextMultiModalProcessor,
    info=BeeProcessingInfo,
    dummy_inputs=BeeDummyInputsBuilder,
)
class BeeForConditionalGeneration(LlavaOnevisionForConditionalGeneration):
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers
            # v4.55
            "model.language_model.": "language_model.model.",
            "model.vision_tower.": "vision_tower.",
            "model.multi_modal_projector.": "multi_modal_projector.",
            "model.image_newline": "image_newline",
            "lm_head.": "language_model.lm_head.",
        }
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        config = vllm_config.model_config.hf_config
        self.multi_modal_projector = BeeMultiModalProjector(config)
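To make the projector's role concrete, a small standalone sketch that only exercises `BeeMultiModalProjector` from the new file above. The hidden sizes are made up for illustration, not Bee-8B's actual config values; the point is the shape transformation from vision-tower features to the language model's hidden size.

```python
# Hypothetical shape check; hidden sizes are illustrative, not Bee-8B's real config.
from types import SimpleNamespace

import torch

from vllm.model_executor.models.bee import BeeMultiModalProjector

config = SimpleNamespace(
    vision_config=SimpleNamespace(hidden_size=1152),  # assumed vision-tower width
    text_config=SimpleNamespace(hidden_size=4096),    # assumed LLM width
)
projector = BeeMultiModalProjector(config)

image_features = torch.randn(2, 576, 1152)  # (batch, image patches, vision hidden size)
projected = projector(image_features)
# LayerNorm -> 4x expansion -> GELU -> down-projection to the text hidden size
print(projected.shape)  # torch.Size([2, 576, 4096])
```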
@@ -247,6 +247,7 @@ _MULTIMODAL_MODELS = {
        "aya_vision",
        "AyaVisionForConditionalGeneration",
    ),
    "BeeForConditionalGeneration": ("bee", "BeeForConditionalGeneration"),
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": (
        "chameleon",