[Model][VLM] Support Bee-8B Model (#27012)

Signed-off-by: uyzhang <yi.zhang.4096@gmail.com>
Signed-off-by: Yi Zhang <zhangyi970819@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
Yi Zhang authored on 2025-10-20 10:31:26 +08:00, committed by GitHub
parent 8a81d776ce, commit f32bf7582e
7 changed files with 228 additions and 0 deletions
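
With this change, Bee-8B runs through vLLM's offline `LLM` API like the other LLaVA-OneVision-style models. A minimal single-image sketch: the model ID, prompt template, and engine settings are taken from this PR, while the image path, question text, and sampling settings are illustrative assumptions.

```python
from PIL import Image
from vllm import LLM, SamplingParams

# Illustrative local image; any RGB image works.
image = Image.open("example.jpg").convert("RGB")

# Prompt format used by the run_bee example added in this PR.
prompt = (
    "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
    "<|im_start|>user\n<image>\nDescribe this image.<|im_end|>"
    "<|im_start|>assistant\n<think>\n"
)

llm = LLM(
    model="Open-Bee/Bee-8B-RL",
    trust_remote_code=True,
    max_model_len=16384,
    limit_mm_per_prompt={"image": 1},
)

outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": image}},
    SamplingParams(temperature=0.0, max_tokens=256),
)
print(outputs[0].outputs[0].text)
```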

View File

@@ -634,6 +634,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
|--------------|--------|--------|-------------------|----------------------|---------------------------|
| `AriaForConditionalGeneration` | Aria | T + I<sup>+</sup> | `rhymes-ai/Aria` | | |
| `AyaVisionForConditionalGeneration` | Aya Vision | T + I<sup>+</sup> | `CohereForAI/aya-vision-8b`, `CohereForAI/aya-vision-32b`, etc. | | ✅︎ |
| `BeeForConditionalGeneration` | Bee-8B | T + I<sup>E+</sup> | `Open-Bee/Bee-8B-RL`, `Open-Bee/Bee-8B-SFT` | | ✅︎ |
| `Blip2ForConditionalGeneration` | BLIP-2 | T + I<sup>E</sup> | `Salesforce/blip2-opt-2.7b`, `Salesforce/blip2-opt-6.7b`, etc. | | ✅︎ |
| `ChameleonForConditionalGeneration` | Chameleon | T + I | `facebook/chameleon-7b`, etc. | | ✅︎ |
| `Cohere2VisionForConditionalGeneration` | Command A Vision | T + I<sup>+</sup> | `CohereLabs/command-a-vision-07-2025`, etc. | | ✅︎ |

View File

@@ -90,6 +90,33 @@ def run_aya_vision(questions: list[str], modality: str) -> ModelRequestData:
    )


# Bee-8B
def run_bee(questions: list[str], modality: str) -> ModelRequestData:
    assert modality == "image"

    model_name = "Open-Bee/Bee-8B-RL"

    prompts = [
        (
            f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
            f"<|im_start|>user\n<image>\n{question}<|im_end|>"
            f"<|im_start|>assistant\n<think>\n"
        )
        for question in questions
    ]

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=16384,
        limit_mm_per_prompt={modality: 1},
        trust_remote_code=True,
    )

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
    )


# BLIP-2
def run_blip2(questions: list[str], modality: str) -> ModelRequestData:
    assert modality == "image"
@@ -1708,6 +1735,7 @@ def run_tarsier2(questions: list[str], modality: str) -> ModelRequestData:
model_example_map = {
    "aria": run_aria,
    "aya_vision": run_aya_vision,
    "bee": run_bee,
    "blip-2": run_blip2,
    "chameleon": run_chameleon,
    "dots_ocr": run_dots_ocr,

View File

@@ -107,6 +107,41 @@ def load_aya_vision(question: str, image_urls: list[str]) -> ModelRequestData:
    )


def load_bee(question: str, image_urls: list[str]) -> ModelRequestData:
    model_name = "Open-Bee/Bee-8B-RL"

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=16384,
        max_num_seqs=16,
        limit_mm_per_prompt={"image": len(image_urls)},
        trust_remote_code=True,
    )

    placeholders = [{"type": "image", "image": url} for url in image_urls]
    messages = [
        {
            "role": "user",
            "content": [
                *placeholders,
                {"type": "text", "text": question},
            ],
        }
    ]

    processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)

    prompt = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
        image_data=[fetch_image(url) for url in image_urls],
    )


def load_command_a_vision(question: str, image_urls: list[str]) -> ModelRequestData:
    model_name = "CohereLabs/command-a-vision-07-2025"
@@ -1215,6 +1250,7 @@ def load_glm4_5v_fp8(question: str, image_urls: list[str]) -> ModelRequestData:
model_example_map = {
    "aria": load_aria,
    "aya_vision": load_aya_vision,
    "bee": load_bee,
    "command_a_vision": load_command_a_vision,
    "deepseek_vl_v2": load_deepseek_vl2,
    "gemma3": load_gemma3,

View File

@@ -326,6 +326,7 @@ def _test_processing_correctness_one(
    [
        "rhymes-ai/Aria",
        "CohereForAI/aya-vision-8b",
        "Open-Bee/Bee-8B-RL",
        "Salesforce/blip2-opt-2.7b",
        "facebook/chameleon-7b",
        "CohereLabs/command-a-vision-07-2025",

View File

@@ -566,6 +566,10 @@ _MULTIMODAL_EXAMPLE_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria"),
    "AyaVisionForConditionalGeneration": _HfExamplesInfo("CohereForAI/aya-vision-8b"),
    "BeeForConditionalGeneration": _HfExamplesInfo(
        "Open-Bee/Bee-8B-RL",
        trust_remote_code=True,
    ),
    "Blip2ForConditionalGeneration": _HfExamplesInfo(
        "Salesforce/blip2-opt-2.7b",
        extras={"6b": "Salesforce/blip2-opt-6.7b"},

View File

@@ -0,0 +1,157 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Mapping

import torch
import torch.nn as nn
from transformers.activations import GELUActivation

from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalDataDict

from .llava_next import (
    LlavaDummyInputsBuilder,
    LlavaNextMultiModalProcessor,
    LlavaNextProcessingInfo,
)
from .llava_onevision import LlavaOnevisionForConditionalGeneration
from .utils import WeightsMapper


class BeeProcessingInfo(LlavaNextProcessingInfo):
    def get_hf_config(self):
        return self.ctx.get_hf_config()

    def get_hf_processor(self, **kwargs: object):
        return self.ctx.get_hf_processor(**kwargs)

    def _get_num_unpadded_features(
        self,
        *,
        original_height: int,
        original_width: int,
        npatches: int,
        num_patch_height: int,
        num_patch_width: int,
    ) -> tuple[int, int]:
        """Override to use correct max_num_patches from vision_aspect_ratio."""
        import math

        current_height = npatches * num_patch_height
        current_width = npatches * num_patch_width

        aspect_ratio = original_width / original_height
        current_aspect_ratio = current_width / current_height

        if aspect_ratio > current_aspect_ratio:
            new_height = int(
                round(original_height * (current_width / original_width), 7)
            )
            padding = (current_height - new_height) // 2
            current_height = current_height - (2 * padding)
        else:
            new_width = int(
                round(original_width * (current_height / original_height), 7)
            )
            padding = (current_width - new_width) // 2
            current_width = current_width - (2 * padding)

        unpadded_features = current_height * current_width
        newline_features = current_height

        # Get max_num_patches from vision_aspect_ratio config
        hf_config = self.get_hf_config()
        vision_aspect_ratio = getattr(hf_config, "vision_aspect_ratio", "anyres_max_9")
        max_num_patches = int(vision_aspect_ratio.replace("anyres_max_", ""))

        ratio = math.sqrt(
            current_height * current_width / (max_num_patches * npatches**2)
        )
        if ratio > 1.1:
            height_factor = int(current_height // ratio)
            width_factor = int(current_width // ratio)
            unpadded_features = height_factor * width_factor
            newline_features = height_factor

        return (unpadded_features, newline_features)


class BeeDummyInputsBuilder(LlavaDummyInputsBuilder[BeeProcessingInfo]):
    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        num_images = mm_counts.get("image", 0)
        image_token = "<image>"
        return image_token * num_images

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        num_images = mm_counts.get("image", 0)

        target_width, target_height = self.info.get_image_size_with_most_features()

        image_overrides = mm_options.get("image") if mm_options else None

        return {
            "image": self._get_dummy_images(
                width=target_width,
                height=target_height,
                num_images=num_images,
                overrides=image_overrides,
            ),
        }


class BeeMultiModalProjector(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.pre_norm = nn.LayerNorm(config.vision_config.hidden_size, eps=1e-06)
        self.linear_1 = nn.Linear(
            config.vision_config.hidden_size,
            config.text_config.hidden_size * 4,
            bias=True,
        )
        self.act = GELUActivation()
        self.linear_2 = nn.Linear(
            config.text_config.hidden_size * 4,
            config.text_config.hidden_size,
            bias=True,
        )

    def forward(self, image_feature: torch.Tensor) -> torch.Tensor:
        image_feature = self.pre_norm(image_feature)
        hidden_states = self.linear_1(image_feature)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states


@MULTIMODAL_REGISTRY.register_processor(
    LlavaNextMultiModalProcessor,
    info=BeeProcessingInfo,
    dummy_inputs=BeeDummyInputsBuilder,
)
class BeeForConditionalGeneration(LlavaOnevisionForConditionalGeneration):
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.55
            "model.language_model.": "language_model.model.",
            "model.vision_tower.": "vision_tower.",
            "model.multi_modal_projector.": "multi_modal_projector.",
            "model.image_newline": "image_newline",
            "lm_head.": "language_model.lm_head.",
        }
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        config = vllm_config.model_config.hf_config
        self.multi_modal_projector = BeeMultiModalProjector(config)
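
The main architectural change relative to `LlavaOnevisionForConditionalGeneration` is the projector: a pre-LayerNorm followed by a two-layer GELU MLP that expands to 4x the text hidden size before projecting down. A standalone shape check of that module; the hidden sizes and the `vllm.model_executor.models.bee` import path are assumptions inferred from this diff, not values stated in it.

```python
from types import SimpleNamespace

import torch

# Import path assumed from the registry entry ("bee", "BeeForConditionalGeneration").
from vllm.model_executor.models.bee import BeeMultiModalProjector

# Hidden sizes below are illustrative, not taken from the Bee-8B config.
config = SimpleNamespace(
    vision_config=SimpleNamespace(hidden_size=1152),
    text_config=SimpleNamespace(hidden_size=4096),
)

projector = BeeMultiModalProjector(config)
image_features = torch.randn(1, 576, 1152)  # (batch, num_patches, vision_hidden)

out = projector(image_features)
assert out.shape == (1, 576, 4096)  # projected into the text hidden size
```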

View File

@@ -247,6 +247,7 @@ _MULTIMODAL_MODELS = {
        "aya_vision",
        "AyaVisionForConditionalGeneration",
    ),
    "BeeForConditionalGeneration": ("bee", "BeeForConditionalGeneration"),
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": (
        "chameleon",