mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[New Model]Donut model (#23229)
Signed-off-by: 汪志鹏 <wangzhipeng628@gmail.com>
This commit is contained in:
@ -615,6 +615,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
|
||||
| `ChameleonForConditionalGeneration` | Chameleon | T + I | `facebook/chameleon-7b`, etc. | | ✅︎ | ✅︎ |
|
||||
| `Cohere2VisionForConditionalGeneration` | Command A Vision | T + I<sup>+</sup> | `CohereLabs/command-a-vision-07-2025`, etc. | | ✅︎ | ✅︎ |
|
||||
| `DeepseekVLV2ForCausalLM`<sup>^</sup> | DeepSeek-VL2 | T + I<sup>+</sup> | `deepseek-ai/deepseek-vl2-tiny`, `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2`, etc. | | ✅︎ | ✅︎ |
|
||||
| `DonutForConditionalGeneration`<sup>^</sup> | Donut | T + I | `ByteDance/Dolphin`, `naver-clova-ix/donut-base-finetuned-docvqa`, etc. | | | |
|
||||
| `Florence2ForConditionalGeneration` | Florence-2 | T + I | `microsoft/Florence-2-base`, `microsoft/Florence-2-large`, etc. | | | |
|
||||
| `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ | ✅︎ |
|
||||
| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I<sup>+</sup> | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | ⚠️ |
|
||||
|
311
examples/offline_inference/dolphin.py
Normal file
311
examples/offline_inference/dolphin.py
Normal file
@ -0,0 +1,311 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import regex as re
|
||||
from PIL import Image
|
||||
from transformers import DonutProcessor
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt
|
||||
from vllm.multimodal.utils import fetch_image
|
||||
|
||||
|
||||
# Copied from https://github.com/bytedance/Dolphin/utils/utils.py
|
||||
@dataclass
|
||||
class ImageDimensions:
|
||||
original_w: int
|
||||
original_h: int
|
||||
padded_w: int
|
||||
padded_h: int
|
||||
|
||||
|
||||
# Copied from https://github.com/bytedance/Dolphin/utils/utils.py
|
||||
def map_to_original_coordinates(
|
||||
x1, y1, x2, y2, dims: ImageDimensions
|
||||
) -> tuple[int, int, int, int]:
|
||||
try:
|
||||
top = (dims.padded_h - dims.original_h) // 2
|
||||
left = (dims.padded_w - dims.original_w) // 2
|
||||
orig_x1 = max(0, x1 - left)
|
||||
orig_y1 = max(0, y1 - top)
|
||||
orig_x2 = min(dims.original_w, x2 - left)
|
||||
orig_y2 = min(dims.original_h, y2 - top)
|
||||
if orig_x2 <= orig_x1:
|
||||
orig_x2 = min(orig_x1 + 1, dims.original_w)
|
||||
if orig_y2 <= orig_y1:
|
||||
orig_y2 = min(orig_y1 + 1, dims.original_h)
|
||||
return int(orig_x1), int(orig_y1), int(orig_x2), int(orig_y2)
|
||||
except Exception as e:
|
||||
print(f"map_to_original_coordinates error: {str(e)}")
|
||||
return 0, 0, min(100, dims.original_w), min(100, dims.original_h)
|
||||
|
||||
|
||||
# Copied from https://github.com/bytedance/Dolphin/utils/utils.py
|
||||
def adjust_box_edges(image, boxes: list[list[float]], max_pixels=15, threshold=0.2):
|
||||
if isinstance(image, str):
|
||||
image = cv2.imread(image)
|
||||
img_h, img_w = image.shape[:2]
|
||||
new_boxes = []
|
||||
for box in boxes:
|
||||
best_box = copy.deepcopy(box)
|
||||
|
||||
def check_edge(img, current_box, i, is_vertical):
|
||||
edge = current_box[i]
|
||||
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||||
_, binary = cv2.threshold(
|
||||
gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU
|
||||
)
|
||||
if is_vertical:
|
||||
line = binary[current_box[1] : current_box[3] + 1, edge]
|
||||
else:
|
||||
line = binary[edge, current_box[0] : current_box[2] + 1]
|
||||
transitions = np.abs(np.diff(line))
|
||||
return np.sum(transitions) / len(transitions)
|
||||
|
||||
edges = [(0, -1, True), (2, 1, True), (1, -1, False), (3, 1, False)]
|
||||
current_box = copy.deepcopy(box)
|
||||
current_box[0] = min(max(current_box[0], 0), img_w - 1)
|
||||
current_box[1] = min(max(current_box[1], 0), img_h - 1)
|
||||
current_box[2] = min(max(current_box[2], 0), img_w - 1)
|
||||
current_box[3] = min(max(current_box[3], 0), img_h - 1)
|
||||
|
||||
for i, direction, is_vertical in edges:
|
||||
best_score = check_edge(image, current_box, i, is_vertical)
|
||||
if best_score <= threshold:
|
||||
continue
|
||||
for step in range(max_pixels):
|
||||
current_box[i] += direction
|
||||
if i == 0 or i == 2:
|
||||
current_box[i] = min(max(current_box[i], 0), img_w - 1)
|
||||
else:
|
||||
current_box[i] = min(max(current_box[i], 0), img_h - 1)
|
||||
score = check_edge(image, current_box, i, is_vertical)
|
||||
if score < best_score:
|
||||
best_score = score
|
||||
best_box = copy.deepcopy(current_box)
|
||||
if score <= threshold:
|
||||
break
|
||||
new_boxes.append(best_box)
|
||||
return new_boxes
|
||||
|
||||
|
||||
# Copied from https://github.com/bytedance/Dolphin/utils/utils.py
|
||||
def process_coordinates(coords, padded_image, dims: ImageDimensions, previous_box=None):
|
||||
try:
|
||||
x1, y1 = int(coords[0] * dims.padded_w), int(coords[1] * dims.padded_h)
|
||||
x2, y2 = int(coords[2] * dims.padded_w), int(coords[3] * dims.padded_h)
|
||||
x1, y1, x2, y2 = (
|
||||
max(0, min(x1, dims.padded_w - 1)),
|
||||
max(0, min(y1, dims.padded_h - 1)),
|
||||
max(0, min(x2, dims.padded_w)),
|
||||
max(0, min(y2, dims.padded_h)),
|
||||
)
|
||||
if x2 <= x1:
|
||||
x2 = min(x1 + 1, dims.padded_w)
|
||||
if y2 <= y1:
|
||||
y2 = min(y1 + 1, dims.padded_h)
|
||||
new_boxes = adjust_box_edges(padded_image, [[x1, y1, x2, y2]])
|
||||
x1, y1, x2, y2 = new_boxes[0]
|
||||
x1, y1, x2, y2 = (
|
||||
max(0, min(x1, dims.padded_w - 1)),
|
||||
max(0, min(y1, dims.padded_h - 1)),
|
||||
max(0, min(x2, dims.padded_w)),
|
||||
max(0, min(y2, dims.padded_h)),
|
||||
)
|
||||
if x2 <= x1:
|
||||
x2 = min(x1 + 1, dims.padded_w)
|
||||
if y2 <= y1:
|
||||
y2 = min(y1 + 1, dims.padded_h)
|
||||
if previous_box is not None:
|
||||
prev_x1, prev_y1, prev_x2, prev_y2 = previous_box
|
||||
if (x1 < prev_x2 and x2 > prev_x1) and (y1 < prev_y2 and y2 > prev_y1):
|
||||
y1 = prev_y2
|
||||
y1 = min(y1, dims.padded_h - 1)
|
||||
if y2 <= y1:
|
||||
y2 = min(y1 + 1, dims.padded_h)
|
||||
new_previous_box = [x1, y1, x2, y2]
|
||||
orig_x1, orig_y1, orig_x2, orig_y2 = map_to_original_coordinates(
|
||||
x1, y1, x2, y2, dims
|
||||
)
|
||||
return x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, new_previous_box
|
||||
except Exception as e:
|
||||
print(f"process_coordinates error: {str(e)}")
|
||||
orig_x1, orig_y1, orig_x2, orig_y2 = (
|
||||
0,
|
||||
0,
|
||||
min(100, dims.original_w),
|
||||
min(100, dims.original_h),
|
||||
)
|
||||
return 0, 0, 100, 100, orig_x1, orig_y1, orig_x2, orig_y2, [0, 0, 100, 100]
|
||||
|
||||
|
||||
# Copied from https://github.com/bytedance/Dolphin/utils/utils.py
|
||||
def prepare_image(image) -> tuple[np.ndarray, ImageDimensions]:
|
||||
try:
|
||||
image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
|
||||
original_h, original_w = image_cv.shape[:2]
|
||||
max_size = max(original_h, original_w)
|
||||
top = (max_size - original_h) // 2
|
||||
bottom = max_size - original_h - top
|
||||
left = (max_size - original_w) // 2
|
||||
right = max_size - original_w - left
|
||||
padded_image = cv2.copyMakeBorder(
|
||||
image_cv, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(0, 0, 0)
|
||||
)
|
||||
padded_h, padded_w = padded_image.shape[:2]
|
||||
dimensions = ImageDimensions(
|
||||
original_w=original_w,
|
||||
original_h=original_h,
|
||||
padded_w=padded_w,
|
||||
padded_h=padded_h,
|
||||
)
|
||||
return padded_image, dimensions
|
||||
except Exception as e:
|
||||
print(f"prepare_image error: {str(e)}")
|
||||
h, w = image.height, image.width
|
||||
dimensions = ImageDimensions(original_w=w, original_h=h, padded_w=w, padded_h=h)
|
||||
return np.zeros((h, w, 3), dtype=np.uint8), dimensions
|
||||
|
||||
|
||||
# Copied from https://github.com/bytedance/Dolphin/utils/utils.py
|
||||
def parse_layout_string(bbox_str):
|
||||
"""Parse layout string using regular expressions"""
|
||||
pattern = r"\[(\d*\.?\d+),\s*(\d*\.?\d+),\s*(\d*\.?\d+),\s*(\d*\.?\d+)\]\s*(\w+)"
|
||||
matches = re.finditer(pattern, bbox_str)
|
||||
|
||||
parsed_results = []
|
||||
for match in matches:
|
||||
coords = [float(match.group(i)) for i in range(1, 5)]
|
||||
label = match.group(5).strip()
|
||||
parsed_results.append((coords, label))
|
||||
|
||||
return parsed_results
|
||||
|
||||
|
||||
model_id = "ByteDance/Dolphin"
|
||||
|
||||
# The input image size for Dolphin is 896 x 896,
|
||||
# and the patch_size is 4 x 4.
|
||||
# Therefore, the initial number of patches is:
|
||||
# Height: 896 / 4 = 224 patches
|
||||
# Width: 896 / 4 = 224 patches
|
||||
|
||||
# The Dolphin model uses a staged downsampling approach,
|
||||
# defined by the "depths": [2, 2, 14, 2] configuration.
|
||||
# Before entering stages 2, 3, and 4, a "Patch Merging" operation is performed,
|
||||
# which halves the feature map's dimensions (dividing both height and width by 2).
|
||||
# Before Stage 2: The size changes from 224 x 224 to (224/2) x (224/2) = 112 x 112.
|
||||
# Before Stage 3: The size changes from 112 x 112 to (112/2) x (112/2) = 56 x 56.
|
||||
# Before Stage 4: The size changes from 56 x 56 to (56/2) x (56/2) = 28 x 28.
|
||||
|
||||
# Because vLLM needs to fill the image features with an encoder_prompt,
|
||||
# and the encoder_prompt will have `<pad>` tokens added when tokenized,
|
||||
# we need to construct an encoder_prompt with a length of 28 x 28 - 1 = 783.
|
||||
encoder_prompt = "".join(["0"] * 783)
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0.0,
|
||||
max_tokens=2048,
|
||||
)
|
||||
|
||||
processor = DonutProcessor.from_pretrained(model_id)
|
||||
llm = LLM(
|
||||
model=model_id,
|
||||
dtype="float16",
|
||||
max_num_seqs=8,
|
||||
hf_overrides={"architectures": ["DonutForConditionalGeneration"]},
|
||||
)
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--image_path", type=str, default=None, help="Path to a local image file."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.image_path:
|
||||
if not os.path.exists(args.image_path):
|
||||
raise FileNotFoundError(f"Error: File not found at {args.image_path}")
|
||||
image = Image.open(args.image_path).convert("RGB")
|
||||
else:
|
||||
image = fetch_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/example-documents/resolve/main/jpeg_images/0.jpg"
|
||||
)
|
||||
|
||||
|
||||
prompt = "Parse the reading order of this document. "
|
||||
decoder_prompt = f"<s>{prompt}<Answer/>"
|
||||
decoder_prompt_tokens = TokensPrompt(
|
||||
prompt_token_ids=processor.tokenizer(decoder_prompt, add_special_tokens=False)[
|
||||
"input_ids"
|
||||
]
|
||||
)
|
||||
enc_dec_prompt = ExplicitEncoderDecoderPrompt(
|
||||
encoder_prompt=TextPrompt(prompt=encoder_prompt, multi_modal_data={"image": image}),
|
||||
decoder_prompt=decoder_prompt_tokens,
|
||||
)
|
||||
layout_outputs = llm.generate(prompts=enc_dec_prompt, sampling_params=sampling_params)
|
||||
layout_result_str = layout_outputs[0].outputs[0].text
|
||||
print(f"Layout analysis output:\n{layout_result_str}")
|
||||
|
||||
padded_image, dims = prepare_image(image)
|
||||
layout_results = parse_layout_string(layout_result_str)
|
||||
text_table_elements = []
|
||||
previous_box = None
|
||||
reading_order = 0
|
||||
for bbox_coords, label in layout_results:
|
||||
if label == "fig":
|
||||
continue
|
||||
try:
|
||||
x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, previous_box = (
|
||||
process_coordinates(bbox_coords, padded_image, dims, previous_box)
|
||||
)
|
||||
cropped = padded_image[y1:y2, x1:x2]
|
||||
if cropped.size > 0 and cropped.shape[0] > 3 and cropped.shape[1] > 3:
|
||||
pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB))
|
||||
prompt_ocr = (
|
||||
"Parse the table in the image. "
|
||||
if label == "tab"
|
||||
else "Read text in the image. "
|
||||
)
|
||||
text_table_elements.append(
|
||||
{
|
||||
"crop": pil_crop,
|
||||
"prompt": prompt_ocr,
|
||||
"reading_order": reading_order,
|
||||
}
|
||||
)
|
||||
reading_order += 1
|
||||
except Exception as e:
|
||||
print(f"Error processing bbox (label: {label}): {str(e)}")
|
||||
continue
|
||||
|
||||
if text_table_elements:
|
||||
batch_prompts = []
|
||||
for elem in text_table_elements:
|
||||
decoder_prompt_str = f"<s>{elem['prompt']}<Answer/>"
|
||||
decoder_prompt_tokens = TokensPrompt(
|
||||
prompt_token_ids=processor.tokenizer(
|
||||
decoder_prompt_str, add_special_tokens=False
|
||||
)["input_ids"]
|
||||
)
|
||||
enc_dec_prompt = ExplicitEncoderDecoderPrompt(
|
||||
encoder_prompt=TextPrompt(
|
||||
prompt=encoder_prompt, multi_modal_data={"image": elem["crop"]}
|
||||
),
|
||||
decoder_prompt=decoder_prompt_tokens,
|
||||
)
|
||||
batch_prompts.append(enc_dec_prompt)
|
||||
batch_outputs = llm.generate(prompts=batch_prompts, sampling_params=sampling_params)
|
||||
for i, output in enumerate(batch_outputs):
|
||||
text_table_elements[i]["text"] = output.outputs[0].text.strip()
|
||||
|
||||
print("------" * 8)
|
||||
text_table_elements.sort(key=lambda x: x["reading_order"])
|
||||
for elem in text_table_elements:
|
||||
print(elem.get("text", ""))
|
@ -13,6 +13,7 @@ from typing import NamedTuple
|
||||
from vllm import LLM, EngineArgs, PromptType, SamplingParams
|
||||
from vllm.assets.audio import AudioAsset
|
||||
from vllm.assets.image import ImageAsset
|
||||
from vllm.multimodal.utils import fetch_image
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
@ -21,6 +22,50 @@ class ModelRequestData(NamedTuple):
|
||||
prompts: Sequence[PromptType]
|
||||
|
||||
|
||||
def run_donut():
|
||||
engine_args = EngineArgs(
|
||||
model="naver-clova-ix/donut-base-finetuned-docvqa",
|
||||
max_num_seqs=2,
|
||||
limit_mm_per_prompt={"image": 1},
|
||||
dtype="float16",
|
||||
hf_overrides={"architectures": ["DonutForConditionalGeneration"]},
|
||||
)
|
||||
|
||||
# The input image size for donut-base-finetuned-docvqa is 2560 x 1920,
|
||||
# and the patch_size is 4 x 4.
|
||||
# Therefore, the initial number of patches is:
|
||||
# Height: 1920 / 4 = 480 patches
|
||||
# Width: 2560 / 4 = 640 patches
|
||||
# The Swin model uses a staged downsampling approach,
|
||||
# defined by the "depths": [2, 2, 14, 2] configuration.
|
||||
# Before entering stages 2, 3, and 4, a "Patch Merging" operation is performed,
|
||||
# which halves the feature map's dimensions (dividing both height and width by 2).
|
||||
# Before Stage 2: The size changes from 480 x 640 to (480/2) x (640/2) = 240 x 320.
|
||||
# Before Stage 3: The size changes from 240 x 320 to (240/2) x (320/2) = 120 x 160.
|
||||
# Before Stage 4: The size changes from 120 x 160 to (120/2) x (160/2) = 60 x 80.
|
||||
# Because vLLM needs to fill the image features with an encoder_prompt,
|
||||
# and the encoder_prompt will have `<pad>` tokens added when tokenized,
|
||||
# we need to construct an encoder_prompt with a length of 60 x 80 - 1 = 4799.
|
||||
prompts = [
|
||||
{
|
||||
"encoder_prompt": {
|
||||
"prompt": "".join(["$"] * 4799),
|
||||
"multi_modal_data": {
|
||||
"image": fetch_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/example-documents/resolve/main/jpeg_images/0.jpg"
|
||||
) # noqa: E501
|
||||
},
|
||||
},
|
||||
"decoder_prompt": "<s_docvqa><s_question>What time is the coffee break?</s_question><s_answer>", # noqa: E501
|
||||
},
|
||||
]
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompts=prompts,
|
||||
)
|
||||
|
||||
|
||||
def run_florence2():
|
||||
engine_args = EngineArgs(
|
||||
model="microsoft/Florence-2-large",
|
||||
@ -118,6 +163,7 @@ def run_whisper():
|
||||
|
||||
|
||||
model_example_map = {
|
||||
"donut": run_donut,
|
||||
"florence2": run_florence2,
|
||||
"mllama": run_mllama,
|
||||
"whisper": run_whisper,
|
||||
|
@ -160,6 +160,7 @@ def _test_processing_correctness(
|
||||
# incorrect token ids. So we need use `add_special_tokens=False` here
|
||||
# to leave bos_token to be added by the processor.
|
||||
_ADD_SPECIAL_TOKENS_OVERRIDES = {
|
||||
"donut": False,
|
||||
"mllama": False,
|
||||
"ovis": False,
|
||||
"ovis2_5": False,
|
||||
@ -270,6 +271,7 @@ def _test_processing_correctness_one(
|
||||
"facebook/chameleon-7b",
|
||||
"CohereLabs/command-a-vision-07-2025",
|
||||
"deepseek-ai/deepseek-vl2-tiny",
|
||||
"naver-clova-ix/donut-base-finetuned-docvqa",
|
||||
"microsoft/Florence-2-base",
|
||||
"adept/fuyu-8b",
|
||||
"google/gemma-3-4b-it",
|
||||
|
@ -513,6 +513,9 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
||||
is_available_online=False,
|
||||
),
|
||||
# [Encoder-decoder]
|
||||
"DonutForConditionalGeneration": _HfExamplesInfo("naver-clova-ix/donut-base-finetuned-docvqa", # noqa: E501
|
||||
hf_overrides={"architectures": ["DonutForConditionalGeneration"], "model_type": "donut"}, # noqa: E501
|
||||
extras={"dolphin": "ByteDance/Dolphin"}), # noqa: E501
|
||||
# Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer
|
||||
# Therefore, we borrow the BartTokenizer from the original Bart model
|
||||
"Florence2ForConditionalGeneration": _HfExamplesInfo("microsoft/Florence-2-base", # noqa: E501
|
||||
|
@ -1822,7 +1822,7 @@ class LLMEngine:
|
||||
assert isinstance(mm_processor, EncDecMultiModalProcessor)
|
||||
|
||||
if mm_processor.pad_dummy_encoder_prompt:
|
||||
return # Skip encoder length check for Whisper
|
||||
return # Skip encoder length check for Whisper and Donut
|
||||
|
||||
if model_config.is_multimodal_model:
|
||||
suggestion = (
|
||||
|
398
vllm/model_executor/models/donut.py
Normal file
398
vllm/model_executor/models/donut.py
Normal file
@ -0,0 +1,398 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import math
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from typing import Literal, Optional, TypedDict, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import BatchFeature, NougatProcessor
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.bart import BartParallelLMHead, MBartDecoder
|
||||
from vllm.model_executor.models.interfaces import (MultiModalEmbeddings,
|
||||
SupportsMultiModal,
|
||||
SupportsV0Only)
|
||||
from vllm.model_executor.models.swin import SwinModel
|
||||
from vllm.model_executor.models.utils import (AutoWeightsLoader,
|
||||
_flatten_embeddings, flatten_bn)
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||
MultiModalKwargsItems)
|
||||
from vllm.multimodal.parse import MultiModalDataItems
|
||||
from vllm.multimodal.processing import (BaseProcessingInfo,
|
||||
EncDecMultiModalProcessor,
|
||||
PromptIndexTargets, PromptInsertion,
|
||||
PromptUpdate)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
|
||||
|
||||
class MBartDecoderWrapper(nn.Module):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.decoder = MBartDecoder(config,
|
||||
cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.decoder")
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.decoder(*args, **kwargs)
|
||||
|
||||
|
||||
class DonutLanguageForConditionalGeneration(nn.Module, SupportsV0Only):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
|
||||
self.config = config
|
||||
self.model = MBartDecoderWrapper(vllm_config=vllm_config,
|
||||
prefix=f"{prefix}.model")
|
||||
embed_scale = math.sqrt(
|
||||
config.d_model) if config.scale_embedding else 1.0
|
||||
|
||||
self.vocab_size = config.vocab_size
|
||||
self.lm_head = BartParallelLMHead(self.vocab_size,
|
||||
config.d_model,
|
||||
embed_scale=embed_scale)
|
||||
|
||||
self.logits_processor = LogitsProcessor(self.vocab_size,
|
||||
config.vocab_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
inputs_embeds: torch.Tensor,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Args:
|
||||
input_ids
|
||||
torch.Tensor of *decoder* input token ids.
|
||||
positions
|
||||
torch.Tensor of *decoder* position indices.
|
||||
Returns:
|
||||
Output torch.Tensor
|
||||
"""
|
||||
|
||||
return self.model(decoder_input_ids=input_ids,
|
||||
decoder_positions=positions,
|
||||
encoder_hidden_states=inputs_embeds)
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
]
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
if "final_logits_bias" in name:
|
||||
continue
|
||||
# if self.config.tie_word_embeddings and "embed_tokens" in name:
|
||||
# continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
|
||||
class DonutImagePixelInputs(TypedDict):
|
||||
type: Literal["pixel_values"]
|
||||
data: torch.Tensor
|
||||
"""Shape: (batch_size, num_channel, height, width)"""
|
||||
|
||||
|
||||
class DonutProcessingInfo(BaseProcessingInfo):
|
||||
|
||||
def get_hf_config(self):
|
||||
return self.ctx.get_hf_config()
|
||||
|
||||
def get_hf_processor(self):
|
||||
return self.ctx.get_hf_processor()
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
return {"image": 1}
|
||||
|
||||
def get_num_image_tokens(self) -> int:
|
||||
return 1
|
||||
|
||||
|
||||
class DonutDummyInputsBuilder(BaseDummyInputsBuilder[DonutProcessingInfo]):
|
||||
|
||||
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
|
||||
return ""
|
||||
|
||||
def get_dummy_mm_data(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> MultiModalDataDict:
|
||||
num_images = mm_counts.get("image", 0)
|
||||
|
||||
target_width, target_height = self.info.get_hf_config(
|
||||
).encoder.image_size
|
||||
|
||||
return {
|
||||
"image":
|
||||
self._get_dummy_images(width=target_width,
|
||||
height=target_height,
|
||||
num_images=num_images)
|
||||
}
|
||||
|
||||
|
||||
class DonutMultiModalProcessor(EncDecMultiModalProcessor[DonutProcessingInfo]):
|
||||
|
||||
def _hf_processor_applies_updates(
|
||||
self,
|
||||
prompt_text: str,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
tokenization_kwargs: Mapping[str, object],
|
||||
) -> bool:
|
||||
return False
|
||||
|
||||
def create_encoder_prompt(
|
||||
self,
|
||||
prompt: Union[str, list[int]],
|
||||
mm_data: MultiModalDataDict,
|
||||
) -> Union[str, list[int]]:
|
||||
return prompt
|
||||
|
||||
def create_decoder_prompt(
|
||||
self,
|
||||
prompt: Union[str, list[int]],
|
||||
mm_data: MultiModalDataDict,
|
||||
) -> Union[str, list[int]]:
|
||||
return prompt
|
||||
|
||||
@property
|
||||
def pad_dummy_encoder_prompt(self) -> bool:
|
||||
return True
|
||||
|
||||
def _call_hf_processor(
|
||||
self,
|
||||
prompt: str,
|
||||
mm_data: Mapping[str, object],
|
||||
mm_kwargs: Mapping[str, object],
|
||||
tok_kwargs: Mapping[str, object],
|
||||
) -> BatchFeature:
|
||||
hf_processor = self.info.get_hf_processor()
|
||||
if mm_data:
|
||||
processed_outputs = super()._call_hf_processor(
|
||||
prompt, mm_data, mm_kwargs, tok_kwargs)
|
||||
if isinstance(hf_processor, NougatProcessor):
|
||||
processed_outputs["input_ids"] = processed_outputs["labels"]
|
||||
else:
|
||||
tokenizer = hf_processor.tokenizer
|
||||
processed_outputs = tokenizer(prompt,
|
||||
add_special_tokens=False,
|
||||
return_tensors="pt")
|
||||
return processed_outputs
|
||||
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
return dict(pixel_values=MultiModalFieldConfig.batched("image"))
|
||||
|
||||
def _get_prompt_updates(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargsItems,
|
||||
) -> Sequence[PromptUpdate]:
|
||||
hf_processor = self.info.get_hf_processor()
|
||||
tokenizer = hf_processor.tokenizer
|
||||
pad_token_id = tokenizer.pad_token_id
|
||||
num_image_tokens = self.info.get_num_image_tokens()
|
||||
image_tokens = [pad_token_id] * num_image_tokens
|
||||
|
||||
return [
|
||||
PromptInsertion(
|
||||
modality="image",
|
||||
target=PromptIndexTargets.start(),
|
||||
insertion=image_tokens,
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(DonutMultiModalProcessor,
|
||||
info=DonutProcessingInfo,
|
||||
dummy_inputs=DonutDummyInputsBuilder)
|
||||
class DonutForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
SupportsV0Only):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
processor_config = vllm_config.model_config.hf_image_processor_config
|
||||
|
||||
self.config = config
|
||||
self.vision_config = config.encoder
|
||||
self.processor_config = processor_config
|
||||
self.encoder = SwinModel(config=config.encoder)
|
||||
|
||||
self.decoder = DonutLanguageForConditionalGeneration(
|
||||
vllm_config=vllm_config.with_hf_config(config.decoder),
|
||||
prefix=f"{prefix}.decoder",
|
||||
)
|
||||
self.pad_token_id = config.pad_token_id
|
||||
|
||||
def _validate_pixel_values(
|
||||
self, data: Union[torch.Tensor, list[torch.Tensor]]
|
||||
) -> Union[torch.Tensor, list[torch.Tensor]]:
|
||||
|
||||
# size = self.processor_config["size"]
|
||||
h, w = self.config.encoder.image_size
|
||||
expected_dims = (3, h, w)
|
||||
|
||||
def _validate_shape(d: torch.Tensor):
|
||||
actual_dims = tuple(d.shape)
|
||||
|
||||
if actual_dims != expected_dims:
|
||||
raise ValueError(
|
||||
"The expected shape of pixel values per batch "
|
||||
f"is {expected_dims}. You supplied {actual_dims}.")
|
||||
|
||||
for d in data:
|
||||
_validate_shape(d)
|
||||
|
||||
return data
|
||||
|
||||
def _parse_and_validate_image_input(self, **kwargs: object):
|
||||
pixel_values: Optional[Union[list[list[torch.Tensor]],
|
||||
list[torch.Tensor],
|
||||
torch.Tensor]] = kwargs.pop(
|
||||
"pixel_values", None)
|
||||
image_embeds: Optional[Union[list[list[torch.Tensor]],
|
||||
list[torch.Tensor],
|
||||
torch.Tensor]] = kwargs.pop(
|
||||
"image_embeds", None)
|
||||
|
||||
if pixel_values is None and image_embeds is None:
|
||||
return None
|
||||
|
||||
if pixel_values is not None and image_embeds is not None:
|
||||
raise ValueError(
|
||||
"Both pixel values and image embeds are provided.")
|
||||
|
||||
if pixel_values is not None:
|
||||
return DonutImagePixelInputs(
|
||||
type="pixel_values",
|
||||
data=self._validate_pixel_values(
|
||||
flatten_bn(pixel_values, concat=True)),
|
||||
)
|
||||
|
||||
if image_embeds is not None:
|
||||
raise NotImplementedError
|
||||
|
||||
raise AssertionError("This line should be unreachable.")
|
||||
|
||||
def _process_image_input(
|
||||
self, image_input: DonutImagePixelInputs) -> torch.Tensor:
|
||||
assert image_input["type"] == "pixel_values"
|
||||
pixel_values = image_input["data"]
|
||||
dtype = next(self.encoder.parameters()).dtype
|
||||
pixel_values = pixel_values.to(dtype)
|
||||
return self.encoder(pixel_values)
|
||||
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.decoder
|
||||
|
||||
def get_multimodal_embeddings(
|
||||
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input is None:
|
||||
return None
|
||||
vision_embeddings = self._process_image_input(image_input)
|
||||
return vision_embeddings
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: MultiModalEmbeddings,
|
||||
) -> torch.Tensor:
|
||||
return _flatten_embeddings(multimodal_embeddings)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
*,
|
||||
encoder_input_ids: torch.Tensor,
|
||||
encoder_positions: torch.Tensor,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Args:
|
||||
input_ids
|
||||
torch.Tensor of *decoder* input token ids.
|
||||
positions
|
||||
torch.Tensor of *decoder* position indices.
|
||||
encoder_input_ids
|
||||
torch.Tensor of *encoder* input token ids.
|
||||
encoder_positions
|
||||
torch.Tensor of *encoder* position indices
|
||||
Returns:
|
||||
Output torch.Tensor
|
||||
"""
|
||||
|
||||
inputs_embeds = None
|
||||
if encoder_input_ids.numel() > 0:
|
||||
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||
inputs_embeds = self.get_input_embeddings(encoder_input_ids,
|
||||
vision_embeddings)
|
||||
|
||||
hidden_states = self.decoder(input_ids,
|
||||
positions,
|
||||
inputs_embeds=inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
return self.decoder.compute_logits(hidden_states, sampling_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
loader = AutoWeightsLoader(self)
|
||||
return loader.load_weights(weights)
|
@ -252,6 +252,7 @@ _MULTIMODAL_MODELS = {
|
||||
"Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"), # noqa: E501
|
||||
"VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"), # noqa: E501
|
||||
# [Encoder-decoder]
|
||||
"DonutForConditionalGeneration": ("donut", "DonutForConditionalGeneration"),
|
||||
"Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501
|
||||
"MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501
|
||||
"Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"), # noqa: E501
|
||||
|
475
vllm/model_executor/models/swin.py
Normal file
475
vllm/model_executor/models/swin.py
Normal file
@ -0,0 +1,475 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import SwinConfig
|
||||
from transformers.models.swin.modeling_swin import SwinEmbeddings
|
||||
from transformers.models.swin.modeling_swin import SwinLayer as HFSwinLayer
|
||||
from transformers.models.swin.modeling_swin import SwinPatchMerging
|
||||
from transformers.pytorch_utils import meshgrid
|
||||
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
|
||||
class SwinSelfAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: SwinConfig,
|
||||
dim: int,
|
||||
num_heads: int,
|
||||
window_size: int,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if dim % num_heads != 0:
|
||||
raise ValueError(
|
||||
f"The hidden size ({dim}) is not a multiple of the number of "
|
||||
f"attention heads ({num_heads})")
|
||||
|
||||
self.num_attention_heads = num_heads
|
||||
self.attention_head_size = int(dim / num_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
self.window_size = (window_size if isinstance(window_size, Iterable)
|
||||
else (window_size, window_size))
|
||||
self.scale = self.attention_head_size**-0.5
|
||||
|
||||
self.relative_position_bias_table = nn.Parameter(
|
||||
torch.zeros(
|
||||
(2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1),
|
||||
num_heads))
|
||||
|
||||
# get pair-wise relative position index for each token inside the window
|
||||
coords_h = torch.arange(self.window_size[0])
|
||||
coords_w = torch.arange(self.window_size[1])
|
||||
coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))
|
||||
coords_flatten = torch.flatten(coords, 1)
|
||||
relative_coords = coords_flatten[:, :, None] - coords_flatten[:,
|
||||
None, :]
|
||||
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
|
||||
relative_coords[:, :, 0] += self.window_size[0] - 1
|
||||
relative_coords[:, :, 1] += self.window_size[1] - 1
|
||||
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
|
||||
relative_position_index = relative_coords.sum(-1)
|
||||
|
||||
self.relative_position_index = nn.Parameter(relative_position_index,
|
||||
requires_grad=False)
|
||||
|
||||
self.qkv = QKVParallelLinear(
|
||||
hidden_size=dim,
|
||||
head_size=self.attention_head_size,
|
||||
total_num_heads=self.num_attention_heads,
|
||||
bias=config.qkv_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv",
|
||||
)
|
||||
|
||||
def transpose_for_scores(self, x):
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads,
|
||||
self.attention_head_size)
|
||||
x = x.view(new_x_shape)
|
||||
return x.permute(0, 2, 1, 3)
|
||||
|
||||
def _get_rel_pos_bias(self) -> torch.Tensor:
|
||||
relative_position_bias = self.relative_position_bias_table[
|
||||
self.relative_position_index.view(-1)]
|
||||
relative_position_bias = relative_position_bias.view(
|
||||
self.window_size[0] * self.window_size[1],
|
||||
self.window_size[0] * self.window_size[1], -1)
|
||||
relative_position_bias = relative_position_bias.permute(
|
||||
2, 0, 1).contiguous()
|
||||
return relative_position_bias.unsqueeze(0)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
batch_size, dim, num_channels = hidden_states.shape
|
||||
|
||||
qkv_output, _ = self.qkv(hidden_states)
|
||||
query_layer, key_layer, value_layer = qkv_output.chunk(3, dim=-1)
|
||||
|
||||
key_layer = self.transpose_for_scores(key_layer)
|
||||
value_layer = self.transpose_for_scores(value_layer)
|
||||
query_layer = self.transpose_for_scores(query_layer)
|
||||
|
||||
attention_scores = self._get_rel_pos_bias()
|
||||
if attention_mask is not None:
|
||||
mask_shape = attention_mask.shape[0]
|
||||
attention_mask_expanded = attention_mask.view(
|
||||
1, mask_shape, 1, dim,
|
||||
dim).expand(batch_size // mask_shape, mask_shape,
|
||||
self.num_attention_heads, dim, dim)
|
||||
attention_scores = attention_scores + \
|
||||
attention_mask_expanded.unsqueeze(
|
||||
1).unsqueeze(0)
|
||||
attention_scores = attention_scores.view(-1,
|
||||
self.num_attention_heads,
|
||||
dim, dim)
|
||||
|
||||
context_layer = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_layer,
|
||||
key_layer,
|
||||
value_layer,
|
||||
attn_mask=attention_scores,
|
||||
dropout_p=0.,
|
||||
)
|
||||
attention_probs = None
|
||||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (
|
||||
self.all_head_size, )
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
|
||||
outputs = (context_layer,
|
||||
attention_probs) if output_attentions else (context_layer, )
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class SwinSelfOutput(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: SwinConfig,
|
||||
dim: int,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.dense = RowParallelLinear(
|
||||
input_size=dim,
|
||||
output_size=dim,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.dense",
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor,
|
||||
input_tensor: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states, _ = self.dense(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SwinAttention(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: SwinConfig,
|
||||
dim: int,
|
||||
num_heads: int,
|
||||
window_size: int,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "") -> None:
|
||||
super().__init__()
|
||||
self.self = SwinSelfAttention(config,
|
||||
dim,
|
||||
num_heads,
|
||||
window_size,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self")
|
||||
self.output = SwinSelfOutput(config,
|
||||
dim,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.output")
|
||||
self.pruned_heads = set()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> tuple[torch.Tensor]:
|
||||
self_outputs = self.self(hidden_states, attention_mask, head_mask,
|
||||
output_attentions)
|
||||
attention_output = self.output(self_outputs[0], hidden_states)
|
||||
outputs = (attention_output, ) + self_outputs[1:]
|
||||
return outputs
|
||||
|
||||
|
||||
class SwinIntermediate(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: SwinConfig,
|
||||
dim: int,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "") -> None:
|
||||
super().__init__()
|
||||
self.dense = ColumnParallelLinear(dim,
|
||||
int(config.mlp_ratio * dim),
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.dense")
|
||||
self.intermediate_act_fn = get_act_fn(config.hidden_act)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states, _ = self.dense(hidden_states)
|
||||
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SwinOutput(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: SwinConfig,
|
||||
dim: int,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "") -> None:
|
||||
super().__init__()
|
||||
self.dense = RowParallelLinear(int(config.mlp_ratio * dim),
|
||||
dim,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.dense")
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states, _ = self.dense(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SwinLayer(HFSwinLayer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: SwinConfig,
|
||||
dim: int,
|
||||
input_resolution: int,
|
||||
num_heads: int,
|
||||
drop_path_rate: float = 0.0,
|
||||
shift_size: int = 0,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__(
|
||||
config=config,
|
||||
dim=dim,
|
||||
input_resolution=input_resolution,
|
||||
num_heads=num_heads,
|
||||
drop_path_rate=drop_path_rate,
|
||||
shift_size=shift_size,
|
||||
)
|
||||
|
||||
self.attention = SwinAttention(config,
|
||||
dim,
|
||||
num_heads,
|
||||
window_size=self.window_size,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attention")
|
||||
self.intermediate = SwinIntermediate(config,
|
||||
dim,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.intermediate")
|
||||
self.output = SwinOutput(config,
|
||||
dim,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.output")
|
||||
|
||||
|
||||
class SwinStage(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: SwinConfig,
|
||||
dim: int,
|
||||
input_resolution: int,
|
||||
depth: int,
|
||||
num_heads: int,
|
||||
drop_path: list[float],
|
||||
downsample: Optional[SwinPatchMerging] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.dim = dim
|
||||
self.blocks = nn.ModuleList([
|
||||
SwinLayer(config=config,
|
||||
dim=dim,
|
||||
input_resolution=input_resolution,
|
||||
num_heads=num_heads,
|
||||
drop_path_rate=drop_path[layer_idx],
|
||||
shift_size=0 if
|
||||
(layer_idx % 2 == 0) else config.window_size // 2,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.blocks.{layer_idx}")
|
||||
for layer_idx in range(depth)
|
||||
])
|
||||
|
||||
# patch merging layer
|
||||
if downsample is not None:
|
||||
self.downsample = downsample(input_resolution,
|
||||
dim=dim,
|
||||
norm_layer=nn.LayerNorm)
|
||||
else:
|
||||
self.downsample = None
|
||||
|
||||
self.pointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
input_dimensions: tuple[int, int],
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
always_partition: Optional[bool] = False,
|
||||
) -> tuple[torch.Tensor]:
|
||||
height, width = input_dimensions
|
||||
for i, layer_module in enumerate(self.blocks):
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
|
||||
layer_outputs = layer_module(hidden_states, input_dimensions,
|
||||
layer_head_mask, output_attentions,
|
||||
always_partition)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
hidden_states_before_downsampling = hidden_states
|
||||
if self.downsample is not None:
|
||||
height_downsampled, width_downsampled = (height + 1) // 2, (width +
|
||||
1) // 2
|
||||
output_dimensions = (height, width, height_downsampled,
|
||||
width_downsampled)
|
||||
hidden_states = self.downsample(hidden_states_before_downsampling,
|
||||
input_dimensions)
|
||||
else:
|
||||
output_dimensions = (height, width, height, width)
|
||||
|
||||
stage_outputs = (hidden_states, hidden_states_before_downsampling,
|
||||
output_dimensions)
|
||||
|
||||
if output_attentions:
|
||||
stage_outputs += layer_outputs[1:]
|
||||
return stage_outputs
|
||||
|
||||
|
||||
class SwinEncoder(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: SwinConfig,
|
||||
grid_size: int,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.num_layers = len(config.depths)
|
||||
self.config = config
|
||||
dpr = [
|
||||
x.item() for x in torch.linspace(
|
||||
0, config.drop_path_rate, sum(config.depths), device="cpu")
|
||||
]
|
||||
self.layers = nn.ModuleList([
|
||||
SwinStage(config=config,
|
||||
dim=int(config.embed_dim * 2**layer_idx),
|
||||
input_resolution=(grid_size[0] // (2**layer_idx),
|
||||
grid_size[1] // (2**layer_idx)),
|
||||
depth=config.depths[layer_idx],
|
||||
num_heads=config.num_heads[layer_idx],
|
||||
drop_path=dpr[sum(config.depths[:layer_idx]
|
||||
):sum(config.depths[:layer_idx + 1])],
|
||||
downsample=SwinPatchMerging if
|
||||
(layer_idx < self.num_layers - 1) else None,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.layers.{layer_idx}")
|
||||
for layer_idx in range(self.num_layers)
|
||||
])
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
input_dimensions: tuple[int, int],
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
always_partition: Optional[bool] = False,
|
||||
) -> tuple[torch.Tensor]:
|
||||
for i, layer_module in enumerate(self.layers):
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
|
||||
layer_outputs = layer_module(hidden_states, input_dimensions,
|
||||
layer_head_mask, output_attentions,
|
||||
always_partition)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
output_dimensions = layer_outputs[2]
|
||||
|
||||
input_dimensions = (output_dimensions[-2], output_dimensions[-1])
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SwinModel(nn.Module):
|
||||
config_class: SwinConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: SwinConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.num_layers = len(config.depths)
|
||||
self.num_features = int(config.embed_dim * 2**(self.num_layers - 1))
|
||||
|
||||
self.embeddings = SwinEmbeddings(config)
|
||||
self.encoder = SwinEncoder(config,
|
||||
self.embeddings.patch_grid,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.encoder")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
) -> tuple[torch.Tensor]:
|
||||
embedding_output, input_dimensions = self.embeddings(pixel_values)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
input_dimensions,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
return encoder_outputs
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
stacked_params_mapping = [
|
||||
("qkv", "query", "q"),
|
||||
("qkv", "key", "k"),
|
||||
("qkv", "value", "v"),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
@ -209,7 +209,7 @@ class MultiModalProfiler(Generic[_I]):
|
||||
if processor.pad_dummy_encoder_prompt:
|
||||
num_tokens_to_pad = max(total_len, seq_len) - total_len
|
||||
encoder_prompt_token_ids.extend([0] * num_tokens_to_pad)
|
||||
# NOTE: Whisper allows total_len > seq_len.
|
||||
# NOTE: Whisper and Donut allows total_len > seq_len.
|
||||
elif total_len > seq_len and not envs.VLLM_USE_V1:
|
||||
# `max_num_batched_tokens` is defined by `SchedulerConfig`
|
||||
logger.warning_once(
|
||||
|
@ -389,7 +389,7 @@ class Processor:
|
||||
assert isinstance(mm_processor, EncDecMultiModalProcessor)
|
||||
|
||||
if mm_processor.pad_dummy_encoder_prompt:
|
||||
return # Skip encoder length check for Whisper
|
||||
return # Skip encoder length check for Whisper and Donut
|
||||
|
||||
if model_config.is_multimodal_model:
|
||||
suggestion = (
|
||||
|
Reference in New Issue
Block a user