[New Model] Donut model (#23229)

Signed-off-by: 汪志鹏 <wangzhipeng628@gmail.com>
Author: 汪志鹏
Date: 2025-08-24 20:52:24 +08:00
Committed by: GitHub
Parent: 5e021b4981
Commit: 416f05929a
11 changed files with 1240 additions and 3 deletions


@@ -615,6 +615,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `ChameleonForConditionalGeneration` | Chameleon | T + I | `facebook/chameleon-7b`, etc. | | ✅︎ | ✅︎ |
| `Cohere2VisionForConditionalGeneration` | Command A Vision | T + I<sup>+</sup> | `CohereLabs/command-a-vision-07-2025`, etc. | | ✅︎ | ✅︎ |
| `DeepseekVLV2ForCausalLM`<sup>^</sup> | DeepSeek-VL2 | T + I<sup>+</sup> | `deepseek-ai/deepseek-vl2-tiny`, `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2`, etc. | | ✅︎ | ✅︎ |
| `DonutForConditionalGeneration`<sup>^</sup> | Donut | T + I | `ByteDance/Dolphin`, `naver-clova-ix/donut-base-finetuned-docvqa`, etc. | | | |
| `Florence2ForConditionalGeneration` | Florence-2 | T + I | `microsoft/Florence-2-base`, `microsoft/Florence-2-large`, etc. | | | |
| `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ | ✅︎ |
| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I<sup>+</sup> | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | ⚠️ |
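For quick reference, here is a minimal offline-inference sketch for the new Donut entry, distilled from the run_donut example added later in this commit (the model name, hf_overrides, prompt format and encoder-prompt length all come from that example; max_tokens is an arbitrary illustrative value):

from vllm import LLM, SamplingParams
from vllm.multimodal.utils import fetch_image

# The upstream checkpoint does not declare this architecture, so it is overridden
# here, mirroring the example added in this commit.
llm = LLM(
    model="naver-clova-ix/donut-base-finetuned-docvqa",
    dtype="float16",
    hf_overrides={"architectures": ["DonutForConditionalGeneration"]},
)
image = fetch_image(
    "https://huggingface.co/datasets/hf-internal-testing/example-documents/resolve/main/jpeg_images/0.jpg"
)
prompt = {
    # 60 x 80 - 1 = 4799 placeholder characters; see the comments in the example below.
    "encoder_prompt": {
        "prompt": "".join(["$"] * 4799),
        "multi_modal_data": {"image": image},
    },
    "decoder_prompt": "<s_docvqa><s_question>What time is the coffee break?</s_question><s_answer>",
}
outputs = llm.generate(prompt, SamplingParams(temperature=0.0, max_tokens=128))
print(outputs[0].outputs[0].text)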


@@ -0,0 +1,311 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import copy
import os
from dataclasses import dataclass
import cv2
import numpy as np
import regex as re
from PIL import Image
from transformers import DonutProcessor
from vllm import LLM, SamplingParams
from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt
from vllm.multimodal.utils import fetch_image
# Copied from https://github.com/bytedance/Dolphin/utils/utils.py
@dataclass
class ImageDimensions:
original_w: int
original_h: int
padded_w: int
padded_h: int
# Copied from https://github.com/bytedance/Dolphin/utils/utils.py
def map_to_original_coordinates(
x1, y1, x2, y2, dims: ImageDimensions
) -> tuple[int, int, int, int]:
try:
top = (dims.padded_h - dims.original_h) // 2
left = (dims.padded_w - dims.original_w) // 2
orig_x1 = max(0, x1 - left)
orig_y1 = max(0, y1 - top)
orig_x2 = min(dims.original_w, x2 - left)
orig_y2 = min(dims.original_h, y2 - top)
if orig_x2 <= orig_x1:
orig_x2 = min(orig_x1 + 1, dims.original_w)
if orig_y2 <= orig_y1:
orig_y2 = min(orig_y1 + 1, dims.original_h)
return int(orig_x1), int(orig_y1), int(orig_x2), int(orig_y2)
except Exception as e:
print(f"map_to_original_coordinates error: {str(e)}")
return 0, 0, min(100, dims.original_w), min(100, dims.original_h)
# Copied from https://github.com/bytedance/Dolphin/utils/utils.py
def adjust_box_edges(image, boxes: list[list[float]], max_pixels=15, threshold=0.2):
if isinstance(image, str):
image = cv2.imread(image)
img_h, img_w = image.shape[:2]
new_boxes = []
for box in boxes:
best_box = copy.deepcopy(box)
def check_edge(img, current_box, i, is_vertical):
edge = current_box[i]
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
_, binary = cv2.threshold(
gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU
)
if is_vertical:
line = binary[current_box[1] : current_box[3] + 1, edge]
else:
line = binary[edge, current_box[0] : current_box[2] + 1]
transitions = np.abs(np.diff(line))
return np.sum(transitions) / len(transitions)
edges = [(0, -1, True), (2, 1, True), (1, -1, False), (3, 1, False)]
current_box = copy.deepcopy(box)
current_box[0] = min(max(current_box[0], 0), img_w - 1)
current_box[1] = min(max(current_box[1], 0), img_h - 1)
current_box[2] = min(max(current_box[2], 0), img_w - 1)
current_box[3] = min(max(current_box[3], 0), img_h - 1)
for i, direction, is_vertical in edges:
best_score = check_edge(image, current_box, i, is_vertical)
if best_score <= threshold:
continue
for step in range(max_pixels):
current_box[i] += direction
if i == 0 or i == 2:
current_box[i] = min(max(current_box[i], 0), img_w - 1)
else:
current_box[i] = min(max(current_box[i], 0), img_h - 1)
score = check_edge(image, current_box, i, is_vertical)
if score < best_score:
best_score = score
best_box = copy.deepcopy(current_box)
if score <= threshold:
break
new_boxes.append(best_box)
return new_boxes
# Copied from https://github.com/bytedance/Dolphin/utils/utils.py
def process_coordinates(coords, padded_image, dims: ImageDimensions, previous_box=None):
try:
x1, y1 = int(coords[0] * dims.padded_w), int(coords[1] * dims.padded_h)
x2, y2 = int(coords[2] * dims.padded_w), int(coords[3] * dims.padded_h)
x1, y1, x2, y2 = (
max(0, min(x1, dims.padded_w - 1)),
max(0, min(y1, dims.padded_h - 1)),
max(0, min(x2, dims.padded_w)),
max(0, min(y2, dims.padded_h)),
)
if x2 <= x1:
x2 = min(x1 + 1, dims.padded_w)
if y2 <= y1:
y2 = min(y1 + 1, dims.padded_h)
new_boxes = adjust_box_edges(padded_image, [[x1, y1, x2, y2]])
x1, y1, x2, y2 = new_boxes[0]
x1, y1, x2, y2 = (
max(0, min(x1, dims.padded_w - 1)),
max(0, min(y1, dims.padded_h - 1)),
max(0, min(x2, dims.padded_w)),
max(0, min(y2, dims.padded_h)),
)
if x2 <= x1:
x2 = min(x1 + 1, dims.padded_w)
if y2 <= y1:
y2 = min(y1 + 1, dims.padded_h)
if previous_box is not None:
prev_x1, prev_y1, prev_x2, prev_y2 = previous_box
if (x1 < prev_x2 and x2 > prev_x1) and (y1 < prev_y2 and y2 > prev_y1):
y1 = prev_y2
y1 = min(y1, dims.padded_h - 1)
if y2 <= y1:
y2 = min(y1 + 1, dims.padded_h)
new_previous_box = [x1, y1, x2, y2]
orig_x1, orig_y1, orig_x2, orig_y2 = map_to_original_coordinates(
x1, y1, x2, y2, dims
)
return x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, new_previous_box
except Exception as e:
print(f"process_coordinates error: {str(e)}")
orig_x1, orig_y1, orig_x2, orig_y2 = (
0,
0,
min(100, dims.original_w),
min(100, dims.original_h),
)
return 0, 0, 100, 100, orig_x1, orig_y1, orig_x2, orig_y2, [0, 0, 100, 100]
# Copied from https://github.com/bytedance/Dolphin/utils/utils.py
def prepare_image(image) -> tuple[np.ndarray, ImageDimensions]:
try:
image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
original_h, original_w = image_cv.shape[:2]
max_size = max(original_h, original_w)
top = (max_size - original_h) // 2
bottom = max_size - original_h - top
left = (max_size - original_w) // 2
right = max_size - original_w - left
padded_image = cv2.copyMakeBorder(
image_cv, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(0, 0, 0)
)
padded_h, padded_w = padded_image.shape[:2]
dimensions = ImageDimensions(
original_w=original_w,
original_h=original_h,
padded_w=padded_w,
padded_h=padded_h,
)
return padded_image, dimensions
except Exception as e:
print(f"prepare_image error: {str(e)}")
h, w = image.height, image.width
dimensions = ImageDimensions(original_w=w, original_h=h, padded_w=w, padded_h=h)
return np.zeros((h, w, 3), dtype=np.uint8), dimensions
# Copied from https://github.com/bytedance/Dolphin/utils/utils.py
def parse_layout_string(bbox_str):
"""Parse layout string using regular expressions"""
pattern = r"\[(\d*\.?\d+),\s*(\d*\.?\d+),\s*(\d*\.?\d+),\s*(\d*\.?\d+)\]\s*(\w+)"
matches = re.finditer(pattern, bbox_str)
parsed_results = []
for match in matches:
coords = [float(match.group(i)) for i in range(1, 5)]
label = match.group(5).strip()
parsed_results.append((coords, label))
return parsed_results
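# Illustrative example of the parsing (labels are placeholders; "fig" and "tab"
# are the ones treated specially further below):
#   parse_layout_string("[0.10, 0.20, 0.90, 0.60] tab [0.12, 0.65, 0.50, 0.90] fig")
#   -> [([0.1, 0.2, 0.9, 0.6], "tab"), ([0.12, 0.65, 0.5, 0.9], "fig")]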
model_id = "ByteDance/Dolphin"
# The input image size for Dolphin is 896 x 896,
# and the patch_size is 4 x 4.
# Therefore, the initial number of patches is:
# Height: 896 / 4 = 224 patches
# Width: 896 / 4 = 224 patches
# The Dolphin model uses a staged downsampling approach,
# defined by the "depths": [2, 2, 14, 2] configuration.
# Before entering stages 2, 3, and 4, a "Patch Merging" operation is performed,
# which halves the feature map's dimensions (dividing both height and width by 2).
# Before Stage 2: The size changes from 224 x 224 to (224/2) x (224/2) = 112 x 112.
# Before Stage 3: The size changes from 112 x 112 to (112/2) x (112/2) = 56 x 56.
# Before Stage 4: The size changes from 56 x 56 to (56/2) x (56/2) = 28 x 28.
# Because vLLM needs to fill the image features with an encoder_prompt,
# and the encoder_prompt will have `<pad>` tokens added when tokenized,
# we need to construct an encoder_prompt with a length of 28 x 28 - 1 = 783.
encoder_prompt = "".join(["0"] * 783)
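# Sanity check of the arithmetic above (illustrative helper, not part of the
# Dolphin pipeline; assumes the values quoted in the comments: image_size=896,
# patch_size=4, four Swin stages):
def _encoder_prompt_len(image_size: int, patch_size: int, num_stages: int) -> int:
    grid = image_size // patch_size       # 896 // 4 = 224
    grid //= 2 ** (num_stages - 1)        # halved before stages 2, 3 and 4 -> 28
    return grid * grid - 1                # one slot is taken by the inserted <pad> token

assert _encoder_prompt_len(896, 4, 4) == len(encoder_prompt) == 783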
sampling_params = SamplingParams(
temperature=0.0,
max_tokens=2048,
)
processor = DonutProcessor.from_pretrained(model_id)
llm = LLM(
model=model_id,
dtype="float16",
max_num_seqs=8,
hf_overrides={"architectures": ["DonutForConditionalGeneration"]},
)
parser = argparse.ArgumentParser()
parser.add_argument(
"--image_path", type=str, default=None, help="Path to a local image file."
)
args = parser.parse_args()
if args.image_path:
if not os.path.exists(args.image_path):
raise FileNotFoundError(f"Error: File not found at {args.image_path}")
image = Image.open(args.image_path).convert("RGB")
else:
image = fetch_image(
"https://huggingface.co/datasets/hf-internal-testing/example-documents/resolve/main/jpeg_images/0.jpg"
)
prompt = "Parse the reading order of this document. "
decoder_prompt = f"<s>{prompt}<Answer/>"
decoder_prompt_tokens = TokensPrompt(
prompt_token_ids=processor.tokenizer(decoder_prompt, add_special_tokens=False)[
"input_ids"
]
)
enc_dec_prompt = ExplicitEncoderDecoderPrompt(
encoder_prompt=TextPrompt(prompt=encoder_prompt, multi_modal_data={"image": image}),
decoder_prompt=decoder_prompt_tokens,
)
layout_outputs = llm.generate(prompts=enc_dec_prompt, sampling_params=sampling_params)
layout_result_str = layout_outputs[0].outputs[0].text
print(f"Layout analysis output:\n{layout_result_str}")
padded_image, dims = prepare_image(image)
layout_results = parse_layout_string(layout_result_str)
text_table_elements = []
previous_box = None
reading_order = 0
for bbox_coords, label in layout_results:
if label == "fig":
continue
try:
x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, previous_box = (
process_coordinates(bbox_coords, padded_image, dims, previous_box)
)
cropped = padded_image[y1:y2, x1:x2]
if cropped.size > 0 and cropped.shape[0] > 3 and cropped.shape[1] > 3:
pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB))
prompt_ocr = (
"Parse the table in the image. "
if label == "tab"
else "Read text in the image. "
)
text_table_elements.append(
{
"crop": pil_crop,
"prompt": prompt_ocr,
"reading_order": reading_order,
}
)
reading_order += 1
except Exception as e:
print(f"Error processing bbox (label: {label}): {str(e)}")
continue
if text_table_elements:
batch_prompts = []
for elem in text_table_elements:
decoder_prompt_str = f"<s>{elem['prompt']}<Answer/>"
decoder_prompt_tokens = TokensPrompt(
prompt_token_ids=processor.tokenizer(
decoder_prompt_str, add_special_tokens=False
)["input_ids"]
)
enc_dec_prompt = ExplicitEncoderDecoderPrompt(
encoder_prompt=TextPrompt(
prompt=encoder_prompt, multi_modal_data={"image": elem["crop"]}
),
decoder_prompt=decoder_prompt_tokens,
)
batch_prompts.append(enc_dec_prompt)
batch_outputs = llm.generate(prompts=batch_prompts, sampling_params=sampling_params)
for i, output in enumerate(batch_outputs):
text_table_elements[i]["text"] = output.outputs[0].text.strip()
print("------" * 8)
text_table_elements.sort(key=lambda x: x["reading_order"])
for elem in text_table_elements:
print(elem.get("text", ""))


@@ -13,6 +13,7 @@ from typing import NamedTuple
from vllm import LLM, EngineArgs, PromptType, SamplingParams
from vllm.assets.audio import AudioAsset
from vllm.assets.image import ImageAsset
from vllm.multimodal.utils import fetch_image
from vllm.utils import FlexibleArgumentParser
@@ -21,6 +22,50 @@ class ModelRequestData(NamedTuple):
prompts: Sequence[PromptType]
def run_donut():
engine_args = EngineArgs(
model="naver-clova-ix/donut-base-finetuned-docvqa",
max_num_seqs=2,
limit_mm_per_prompt={"image": 1},
dtype="float16",
hf_overrides={"architectures": ["DonutForConditionalGeneration"]},
)
# The input image size for donut-base-finetuned-docvqa is 2560 x 1920,
# and the patch_size is 4 x 4.
# Therefore, the initial number of patches is:
# Height: 1920 / 4 = 480 patches
# Width: 2560 / 4 = 640 patches
# The Swin model uses a staged downsampling approach,
# defined by the "depths": [2, 2, 14, 2] configuration.
# Before entering stages 2, 3, and 4, a "Patch Merging" operation is performed,
# which halves the feature map's dimensions (dividing both height and width by 2).
# Before Stage 2: The size changes from 480 x 640 to (480/2) x (640/2) = 240 x 320.
# Before Stage 3: The size changes from 240 x 320 to (240/2) x (320/2) = 120 x 160.
# Before Stage 4: The size changes from 120 x 160 to (120/2) x (160/2) = 60 x 80.
# Because vLLM needs to fill the image features with an encoder_prompt,
# and the encoder_prompt will have `<pad>` tokens added when tokenized,
# we need to construct an encoder_prompt with a length of 60 x 80 - 1 = 4799.
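# Quick arithmetic check (illustrative):
#   (1920 // 4 // 2**3) * (2560 // 4 // 2**3) - 1 = 60 * 80 - 1 = 4799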
prompts = [
{
"encoder_prompt": {
"prompt": "".join(["$"] * 4799),
"multi_modal_data": {
"image": fetch_image(
"https://huggingface.co/datasets/hf-internal-testing/example-documents/resolve/main/jpeg_images/0.jpg"
) # noqa: E501
},
},
"decoder_prompt": "<s_docvqa><s_question>What time is the coffee break?</s_question><s_answer>", # noqa: E501
},
]
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)
def run_florence2():
engine_args = EngineArgs(
model="microsoft/Florence-2-large",
@@ -118,6 +163,7 @@ def run_whisper():
model_example_map = {
"donut": run_donut,
"florence2": run_florence2,
"mllama": run_mllama,
"whisper": run_whisper,


@@ -160,6 +160,7 @@ def _test_processing_correctness(
# incorrect token ids. So we need to use `add_special_tokens=False` here
# to leave bos_token to be added by the processor.
_ADD_SPECIAL_TOKENS_OVERRIDES = {
"donut": False,
"mllama": False,
"ovis": False,
"ovis2_5": False,
@@ -270,6 +271,7 @@ def _test_processing_correctness_one(
"facebook/chameleon-7b",
"CohereLabs/command-a-vision-07-2025",
"deepseek-ai/deepseek-vl2-tiny",
"naver-clova-ix/donut-base-finetuned-docvqa",
"microsoft/Florence-2-base",
"adept/fuyu-8b",
"google/gemma-3-4b-it",


@@ -513,6 +513,9 @@ _MULTIMODAL_EXAMPLE_MODELS = {
is_available_online=False,
),
# [Encoder-decoder]
"DonutForConditionalGeneration": _HfExamplesInfo("naver-clova-ix/donut-base-finetuned-docvqa", # noqa: E501
hf_overrides={"architectures": ["DonutForConditionalGeneration"], "model_type": "donut"}, # noqa: E501
extras={"dolphin": "ByteDance/Dolphin"}), # noqa: E501
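# The upstream Donut/Dolphin checkpoints are published as generic
# VisionEncoderDecoderModel checkpoints, which is why the architecture (and, for
# this test, the model_type) is overridden so vLLM routes them to the Donut
# implementation added in this commit.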
# Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer
# Therefore, we borrow the BartTokenizer from the original Bart model
"Florence2ForConditionalGeneration": _HfExamplesInfo("microsoft/Florence-2-base", # noqa: E501


@@ -1822,7 +1822,7 @@ class LLMEngine:
assert isinstance(mm_processor, EncDecMultiModalProcessor)
if mm_processor.pad_dummy_encoder_prompt:
return # Skip encoder length check for Whisper
return # Skip encoder length check for Whisper and Donut
if model_config.is_multimodal_model:
suggestion = (


@@ -0,0 +1,398 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections.abc import Iterable, Mapping, Sequence
from typing import Literal, Optional, TypedDict, Union
import torch
import torch.nn as nn
from transformers import BatchFeature, NougatProcessor
from vllm.config import VllmConfig
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.bart import BartParallelLMHead, MBartDecoder
from vllm.model_executor.models.interfaces import (MultiModalEmbeddings,
SupportsMultiModal,
SupportsV0Only)
from vllm.model_executor.models.swin import SwinModel
from vllm.model_executor.models.utils import (AutoWeightsLoader,
_flatten_embeddings, flatten_bn)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargsItems)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseProcessingInfo,
EncDecMultiModalProcessor,
PromptIndexTargets, PromptInsertion,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
class MBartDecoderWrapper(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.decoder = MBartDecoder(config,
cache_config,
quant_config=quant_config,
prefix=f"{prefix}.decoder")
def forward(self, *args, **kwargs):
return self.decoder(*args, **kwargs)
class DonutLanguageForConditionalGeneration(nn.Module, SupportsV0Only):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
self.config = config
self.model = MBartDecoderWrapper(vllm_config=vllm_config,
prefix=f"{prefix}.model")
embed_scale = math.sqrt(
config.d_model) if config.scale_embedding else 1.0
self.vocab_size = config.vocab_size
self.lm_head = BartParallelLMHead(self.vocab_size,
config.d_model,
embed_scale=embed_scale)
self.logits_processor = LogitsProcessor(self.vocab_size,
config.vocab_size)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
inputs_embeds: torch.Tensor,
**kwargs,
) -> torch.Tensor:
r"""
Args:
input_ids
torch.Tensor of *decoder* input token ids.
positions
torch.Tensor of *decoder* position indices.
Returns:
Output torch.Tensor
"""
return self.model(decoder_input_ids=input_ids,
decoder_positions=positions,
encoder_hidden_states=inputs_embeds)
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
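# e.g. a checkpoint entry named "...self_attn.q_proj.weight" is renamed to
# "...self_attn.qkv_proj.weight" and loaded into the "q" shard of the fused
# QKV projection; entries that match none of the suffixes fall through to the
# default loader below.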
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
if "final_logits_bias" in name:
continue
# if self.config.tie_word_embeddings and "embed_tokens" in name:
# continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class DonutImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""Shape: (batch_size, num_channel, height, width)"""
class DonutProcessingInfo(BaseProcessingInfo):
def get_hf_config(self):
return self.ctx.get_hf_config()
def get_hf_processor(self):
return self.ctx.get_hf_processor()
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1}
def get_num_image_tokens(self) -> int:
return 1
class DonutDummyInputsBuilder(BaseDummyInputsBuilder[DonutProcessingInfo]):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
return ""
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
target_width, target_height = self.info.get_hf_config(
).encoder.image_size
return {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images)
}
class DonutMultiModalProcessor(EncDecMultiModalProcessor[DonutProcessingInfo]):
def _hf_processor_applies_updates(
self,
prompt_text: str,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
) -> bool:
return False
def create_encoder_prompt(
self,
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
) -> Union[str, list[int]]:
return prompt
def create_decoder_prompt(
self,
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
) -> Union[str, list[int]]:
return prompt
@property
def pad_dummy_encoder_prompt(self) -> bool:
return True
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
hf_processor = self.info.get_hf_processor()
if mm_data:
processed_outputs = super()._call_hf_processor(
prompt, mm_data, mm_kwargs, tok_kwargs)
if isinstance(hf_processor, NougatProcessor):
processed_outputs["input_ids"] = processed_outputs["labels"]
else:
tokenizer = hf_processor.tokenizer
processed_outputs = tokenizer(prompt,
add_special_tokens=False,
return_tensors="pt")
return processed_outputs
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(pixel_values=MultiModalFieldConfig.batched("image"))
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor()
tokenizer = hf_processor.tokenizer
pad_token_id = tokenizer.pad_token_id
num_image_tokens = self.info.get_num_image_tokens()
image_tokens = [pad_token_id] * num_image_tokens
return [
PromptInsertion(
modality="image",
target=PromptIndexTargets.start(),
insertion=image_tokens,
)
]
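# NOTE: DonutMultiModalProcessor inserts exactly one <pad> token at the start of
# the encoder prompt (get_num_image_tokens() == 1), so callers supply an encoder
# prompt whose tokenized length is the number of Swin output features minus one
# (e.g. 28 * 28 - 1 = 783 for Dolphin), as the examples in this commit do.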
@MULTIMODAL_REGISTRY.register_processor(DonutMultiModalProcessor,
info=DonutProcessingInfo,
dummy_inputs=DonutDummyInputsBuilder)
class DonutForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsV0Only):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
processor_config = vllm_config.model_config.hf_image_processor_config
self.config = config
self.vision_config = config.encoder
self.processor_config = processor_config
self.encoder = SwinModel(config=config.encoder)
self.decoder = DonutLanguageForConditionalGeneration(
vllm_config=vllm_config.with_hf_config(config.decoder),
prefix=f"{prefix}.decoder",
)
self.pad_token_id = config.pad_token_id
def _validate_pixel_values(
self, data: Union[torch.Tensor, list[torch.Tensor]]
) -> Union[torch.Tensor, list[torch.Tensor]]:
# size = self.processor_config["size"]
h, w = self.config.encoder.image_size
expected_dims = (3, h, w)
def _validate_shape(d: torch.Tensor):
actual_dims = tuple(d.shape)
if actual_dims != expected_dims:
raise ValueError(
"The expected shape of pixel values per batch "
f"is {expected_dims}. You supplied {actual_dims}.")
for d in data:
_validate_shape(d)
return data
def _parse_and_validate_image_input(self, **kwargs: object):
pixel_values: Optional[Union[list[list[torch.Tensor]],
list[torch.Tensor],
torch.Tensor]] = kwargs.pop(
"pixel_values", None)
image_embeds: Optional[Union[list[list[torch.Tensor]],
list[torch.Tensor],
torch.Tensor]] = kwargs.pop(
"image_embeds", None)
if pixel_values is None and image_embeds is None:
return None
if pixel_values is not None and image_embeds is not None:
raise ValueError(
"Both pixel values and image embeds are provided.")
if pixel_values is not None:
return DonutImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(
flatten_bn(pixel_values, concat=True)),
)
if image_embeds is not None:
raise NotImplementedError
raise AssertionError("This line should be unreachable.")
def _process_image_input(
self, image_input: DonutImagePixelInputs) -> torch.Tensor:
assert image_input["type"] == "pixel_values"
pixel_values = image_input["data"]
dtype = next(self.encoder.parameters()).dtype
pixel_values = pixel_values.to(dtype)
return self.encoder(pixel_values)
def get_language_model(self) -> torch.nn.Module:
return self.decoder
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
vision_embeddings = self._process_image_input(image_input)
return vision_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings,
) -> torch.Tensor:
return _flatten_embeddings(multimodal_embeddings)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
*,
encoder_input_ids: torch.Tensor,
encoder_positions: torch.Tensor,
**kwargs,
) -> torch.Tensor:
r"""
Args:
input_ids
torch.Tensor of *decoder* input token ids.
positions
torch.Tensor of *decoder* position indices.
encoder_input_ids
torch.Tensor of *encoder* input token ids.
encoder_positions
torch.Tensor of *encoder* position indices
Returns:
Output torch.Tensor
"""
inputs_embeds = None
if encoder_input_ids.numel() > 0:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(encoder_input_ids,
vision_embeddings)
hidden_states = self.decoder(input_ids,
positions,
inputs_embeds=inputs_embeds)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
return self.decoder.compute_logits(hidden_states, sampling_metadata)
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)


@@ -252,6 +252,7 @@ _MULTIMODAL_MODELS = {
"Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"), # noqa: E501
"VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"), # noqa: E501
# [Encoder-decoder]
"DonutForConditionalGeneration": ("donut", "DonutForConditionalGeneration"),
"Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501
"MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501
"Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"), # noqa: E501


@@ -0,0 +1,475 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from typing import Optional
import torch
import torch.nn as nn
from transformers import SwinConfig
from transformers.models.swin.modeling_swin import SwinEmbeddings
from transformers.models.swin.modeling_swin import SwinLayer as HFSwinLayer
from transformers.models.swin.modeling_swin import SwinPatchMerging
from transformers.pytorch_utils import meshgrid
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
class SwinSelfAttention(nn.Module):
def __init__(
self,
config: SwinConfig,
dim: int,
num_heads: int,
window_size: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
if dim % num_heads != 0:
raise ValueError(
f"The hidden size ({dim}) is not a multiple of the number of "
f"attention heads ({num_heads})")
self.num_attention_heads = num_heads
self.attention_head_size = int(dim / num_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.window_size = (window_size if isinstance(window_size, Iterable)
else (window_size, window_size))
self.scale = self.attention_head_size**-0.5
self.relative_position_bias_table = nn.Parameter(
torch.zeros(
(2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1),
num_heads))
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))
coords_flatten = torch.flatten(coords, 1)
relative_coords = coords_flatten[:, :, None] - coords_flatten[:,
None, :]
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
relative_coords[:, :, 0] += self.window_size[0] - 1
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1)
self.relative_position_index = nn.Parameter(relative_position_index,
requires_grad=False)
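# Worked example (illustrative) for window_size = (2, 2): the window covers the
# positions (0, 0), (0, 1), (1, 0), (1, 1) and index = (dh + 1) * 3 + (dw + 1),
# giving the 4 x 4 relative_position_index
#   [[4, 3, 1, 0],
#    [5, 4, 2, 1],
#    [7, 6, 4, 3],
#    [8, 7, 5, 4]]
# i.e. 9 distinct offsets indexing the (2*2 - 1) * (2*2 - 1) = 9-row bias table.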
self.qkv = QKVParallelLinear(
hidden_size=dim,
head_size=self.attention_head_size,
total_num_heads=self.num_attention_heads,
bias=config.qkv_bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv",
)
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads,
self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def _get_rel_pos_bias(self) -> torch.Tensor:
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.view(-1)]
relative_position_bias = relative_position_bias.view(
self.window_size[0] * self.window_size[1],
self.window_size[0] * self.window_size[1], -1)
relative_position_bias = relative_position_bias.permute(
2, 0, 1).contiguous()
return relative_position_bias.unsqueeze(0)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
) -> tuple[torch.Tensor, ...]:
batch_size, dim, num_channels = hidden_states.shape
qkv_output, _ = self.qkv(hidden_states)
query_layer, key_layer, value_layer = qkv_output.chunk(3, dim=-1)
key_layer = self.transpose_for_scores(key_layer)
value_layer = self.transpose_for_scores(value_layer)
query_layer = self.transpose_for_scores(query_layer)
attention_scores = self._get_rel_pos_bias()
if attention_mask is not None:
mask_shape = attention_mask.shape[0]
attention_mask_expanded = attention_mask.view(
1, mask_shape, 1, dim,
dim).expand(batch_size // mask_shape, mask_shape,
self.num_attention_heads, dim, dim)
attention_scores = attention_scores + \
attention_mask_expanded.unsqueeze(
1).unsqueeze(0)
attention_scores = attention_scores.view(-1,
self.num_attention_heads,
dim, dim)
context_layer = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
attn_mask=attention_scores,
dropout_p=0.,
)
attention_probs = None
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (
self.all_head_size, )
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer,
attention_probs) if output_attentions else (context_layer, )
return outputs
class SwinSelfOutput(nn.Module):
def __init__(
self,
config: SwinConfig,
dim: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.dense = RowParallelLinear(
input_size=dim,
output_size=dim,
quant_config=quant_config,
prefix=f"{prefix}.dense",
)
def forward(self, hidden_states: torch.Tensor,
input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.dense(hidden_states)
return hidden_states
class SwinAttention(nn.Module):
def __init__(self,
config: SwinConfig,
dim: int,
num_heads: int,
window_size: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "") -> None:
super().__init__()
self.self = SwinSelfAttention(config,
dim,
num_heads,
window_size,
quant_config=quant_config,
prefix=f"{prefix}.self")
self.output = SwinSelfOutput(config,
dim,
quant_config=quant_config,
prefix=f"{prefix}.output")
self.pruned_heads = set()
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
) -> tuple[torch.Tensor]:
self_outputs = self.self(hidden_states, attention_mask, head_mask,
output_attentions)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output, ) + self_outputs[1:]
return outputs
class SwinIntermediate(nn.Module):
def __init__(self,
config: SwinConfig,
dim: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "") -> None:
super().__init__()
self.dense = ColumnParallelLinear(dim,
int(config.mlp_ratio * dim),
quant_config=quant_config,
prefix=f"{prefix}.dense")
self.intermediate_act_fn = get_act_fn(config.hidden_act)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class SwinOutput(nn.Module):
def __init__(self,
config: SwinConfig,
dim: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "") -> None:
super().__init__()
self.dense = RowParallelLinear(int(config.mlp_ratio * dim),
dim,
quant_config=quant_config,
prefix=f"{prefix}.dense")
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.dense(hidden_states)
return hidden_states
class SwinLayer(HFSwinLayer):
def __init__(
self,
config: SwinConfig,
dim: int,
input_resolution: int,
num_heads: int,
drop_path_rate: float = 0.0,
shift_size: int = 0,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__(
config=config,
dim=dim,
input_resolution=input_resolution,
num_heads=num_heads,
drop_path_rate=drop_path_rate,
shift_size=shift_size,
)
self.attention = SwinAttention(config,
dim,
num_heads,
window_size=self.window_size,
quant_config=quant_config,
prefix=f"{prefix}.attention")
self.intermediate = SwinIntermediate(config,
dim,
quant_config=quant_config,
prefix=f"{prefix}.intermediate")
self.output = SwinOutput(config,
dim,
quant_config=quant_config,
prefix=f"{prefix}.output")
class SwinStage(nn.Module):
def __init__(
self,
config: SwinConfig,
dim: int,
input_resolution: int,
depth: int,
num_heads: int,
drop_path: list[float],
downsample: Optional[SwinPatchMerging] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.dim = dim
self.blocks = nn.ModuleList([
SwinLayer(config=config,
dim=dim,
input_resolution=input_resolution,
num_heads=num_heads,
drop_path_rate=drop_path[layer_idx],
shift_size=0 if
(layer_idx % 2 == 0) else config.window_size // 2,
quant_config=quant_config,
prefix=f"{prefix}.blocks.{layer_idx}")
for layer_idx in range(depth)
])
# patch merging layer
if downsample is not None:
self.downsample = downsample(input_resolution,
dim=dim,
norm_layer=nn.LayerNorm)
else:
self.downsample = None
self.pointing = False
def forward(
self,
hidden_states: torch.Tensor,
input_dimensions: tuple[int, int],
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
always_partition: Optional[bool] = False,
) -> tuple[torch.Tensor]:
height, width = input_dimensions
for i, layer_module in enumerate(self.blocks):
layer_head_mask = head_mask[i] if head_mask is not None else None
layer_outputs = layer_module(hidden_states, input_dimensions,
layer_head_mask, output_attentions,
always_partition)
hidden_states = layer_outputs[0]
hidden_states_before_downsampling = hidden_states
if self.downsample is not None:
height_downsampled, width_downsampled = (height + 1) // 2, (width +
1) // 2
output_dimensions = (height, width, height_downsampled,
width_downsampled)
hidden_states = self.downsample(hidden_states_before_downsampling,
input_dimensions)
else:
output_dimensions = (height, width, height, width)
stage_outputs = (hidden_states, hidden_states_before_downsampling,
output_dimensions)
if output_attentions:
stage_outputs += layer_outputs[1:]
return stage_outputs
class SwinEncoder(nn.Module):
def __init__(
self,
config: SwinConfig,
grid_size: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.num_layers = len(config.depths)
self.config = config
dpr = [
x.item() for x in torch.linspace(
0, config.drop_path_rate, sum(config.depths), device="cpu")
]
self.layers = nn.ModuleList([
SwinStage(config=config,
dim=int(config.embed_dim * 2**layer_idx),
input_resolution=(grid_size[0] // (2**layer_idx),
grid_size[1] // (2**layer_idx)),
depth=config.depths[layer_idx],
num_heads=config.num_heads[layer_idx],
drop_path=dpr[sum(config.depths[:layer_idx]
):sum(config.depths[:layer_idx + 1])],
downsample=SwinPatchMerging if
(layer_idx < self.num_layers - 1) else None,
quant_config=quant_config,
prefix=f"{prefix}.layers.{layer_idx}")
for layer_idx in range(self.num_layers)
])
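# For the Donut/Dolphin-style configs discussed in this commit (depths = [2, 2, 14, 2]),
# the drop-path schedule is split as dpr[0:2], dpr[2:4], dpr[4:18] and dpr[18:20],
# the stage width grows as embed_dim * (1, 2, 4, 8), and only the first three
# stages end with a SwinPatchMerging downsample.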
def forward(
self,
hidden_states: torch.Tensor,
input_dimensions: tuple[int, int],
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
always_partition: Optional[bool] = False,
) -> tuple[torch.Tensor]:
for i, layer_module in enumerate(self.layers):
layer_head_mask = head_mask[i] if head_mask is not None else None
layer_outputs = layer_module(hidden_states, input_dimensions,
layer_head_mask, output_attentions,
always_partition)
hidden_states = layer_outputs[0]
output_dimensions = layer_outputs[2]
input_dimensions = (output_dimensions[-2], output_dimensions[-1])
return hidden_states
class SwinModel(nn.Module):
config_class: SwinConfig
def __init__(
self,
config: SwinConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.num_layers = len(config.depths)
self.num_features = int(config.embed_dim * 2**(self.num_layers - 1))
self.embeddings = SwinEmbeddings(config)
self.encoder = SwinEncoder(config,
self.embeddings.patch_grid,
quant_config=quant_config,
prefix=f"{prefix}.encoder")
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
) -> tuple[torch.Tensor]:
embedding_output, input_dimensions = self.embeddings(pixel_values)
encoder_outputs = self.encoder(
embedding_output,
input_dimensions,
head_mask=head_mask,
output_attentions=output_attentions,
)
return encoder_outputs
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
("qkv", "query", "q"),
("qkv", "key", "k"),
("qkv", "value", "v"),
]
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
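As a rough shape walk-through of this encoder (a minimal sketch, assuming a Dolphin/Donut-style configuration with image_size=896, patch_size=4, embed_dim=128 and depths=[2, 2, 14, 2]; only the arithmetic is executed, no model is instantiated):

image_size, patch_size, embed_dim = 896, 4, 128
depths = [2, 2, 14, 2]
grid = image_size // patch_size                  # 224 x 224 patch grid after SwinEmbeddings
dim = embed_dim
for stage, depth in enumerate(depths, start=1):
    print(f"stage {stage}: {depth} blocks, tokens={grid * grid}, dim={dim}")
    if stage < len(depths):                      # SwinPatchMerging before the next stage
        grid //= 2
        dim *= 2
# Under these assumptions the final output is 28 * 28 = 784 tokens of width 1024,
# which DonutForConditionalGeneration feeds to its MBart-style decoder as
# encoder_hidden_states.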


@@ -209,7 +209,7 @@ class MultiModalProfiler(Generic[_I]):
if processor.pad_dummy_encoder_prompt:
num_tokens_to_pad = max(total_len, seq_len) - total_len
encoder_prompt_token_ids.extend([0] * num_tokens_to_pad)
# NOTE: Whisper allows total_len > seq_len.
# NOTE: Whisper and Donut allow total_len > seq_len.
elif total_len > seq_len and not envs.VLLM_USE_V1:
# `max_num_batched_tokens` is defined by `SchedulerConfig`
logger.warning_once(


@@ -389,7 +389,7 @@ class Processor:
assert isinstance(mm_processor, EncDecMultiModalProcessor)
if mm_processor.pad_dummy_encoder_prompt:
return # Skip encoder length check for Whisper
return # Skip encoder length check for Whisper and Donut
if model_config.is_multimodal_model:
suggestion = (