Mirror of https://github.com/vllm-project/vllm.git
[Model] Add support for LightOnOCR (#26916)
Signed-off-by: Said Taghadouini <taghadouinisaid@gmail.com>
Signed-off-by: Said Taghadouini <84044788+staghado@users.noreply.github.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
@@ -663,6 +663,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `KeyeForConditionalGeneration` | Keye-VL-8B-Preview | T + I<sup>E+</sup> + V<sup>E+</sup> | `Kwai-Keye/Keye-VL-8B-Preview` | ✅︎ | ✅︎ |
| `KeyeVL1_5ForConditionalGeneration` | Keye-VL-1_5-8B | T + I<sup>E+</sup> + V<sup>E+</sup> | `Kwai-Keye/Keye-VL-1_5-8B` | ✅︎ | ✅︎ |
| `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I<sup>+</sup> | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | | ✅︎ |
| `LightOnOCRForConditionalGeneration` | LightOnOCR-1B | T + I<sup>+</sup> | `lightonai/LightOnOCR-1B`, etc. | ✅︎ | ✅︎ |
| `Llama4ForConditionalGeneration` | Llama 4 | T + I<sup>+</sup> | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. | | ✅︎ |
| `Llama_Nemotron_Nano_VL` | Llama Nemotron Nano VL | T + I<sup>E+</sup> | `nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1` | ✅︎ | ✅︎ |
| `LlavaForConditionalGeneration` | LLaVA-1.5, Pixtral (HF Transformers) | T + I<sup>E+</sup> | `llava-hf/llava-1.5-7b-hf`, `TIGER-Lab/Mantis-8B-siglip-llama3` (see note), `mistral-community/pixtral-12b`, etc. | | ✅︎ |
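Once registered, the new entry can be exercised like any other multimodal model. A minimal smoke-test sketch, assuming a server started with `vllm serve lightonai/LightOnOCR-1B` and the standard OpenAI-compatible client; the endpoint and image URL are placeholders, not part of this commit:

# Hypothetical client-side check of the new table entry.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
response = client.chat.completions.create(
    model="lightonai/LightOnOCR-1B",
    messages=[
        {
            "role": "user",
            "content": [
                # LightOnOCR takes the page image alone; no text part needed.
                {"type": "image_url", "image_url": {"url": "https://example.com/page.png"}},
            ],
        }
    ],
)
print(response.choices[0].message.content)  # the OCR'd page text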
@@ -734,6 +734,26 @@ def run_kimi_vl(questions: list[str], modality: str) -> ModelRequestData:
)


# LightOnOCR
def run_lightonocr(questions: list[str], modality: str) -> ModelRequestData:
    assert modality == "image"

    prompts = [
        "<|im_start|>system<|im_end|>\n<|im_start|>user\n<|image_pad|><|im_end|>\n<|im_start|>assistant\n"
        for _ in questions
    ]

    engine_args = EngineArgs(
        model="lightonai/LightOnOCR-1B",
        limit_mm_per_prompt={modality: 1},
    )

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
    )


def run_llama4(questions: list[str], modality: str) -> ModelRequestData:
    assert modality == "image"
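For reference, a sketch of how a helper like `run_lightonocr` is typically driven offline; the image path and sampling values below are illustrative assumptions, not part of the diff:

# Illustrative driver for run_lightonocr; "page.png" and the sampling
# parameters are assumptions, not from this commit.
from dataclasses import asdict

from PIL import Image
from vllm import LLM, SamplingParams

req_data = run_lightonocr(["ocr"], modality="image")
llm = LLM(**asdict(req_data.engine_args))

image = Image.open("page.png").convert("RGB")
outputs = llm.generate(
    [{"prompt": req_data.prompts[0], "multi_modal_data": {"image": image}}],
    SamplingParams(temperature=0.0, max_tokens=1024),
)
print(outputs[0].outputs[0].text)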
@@ -1709,6 +1729,7 @@ model_example_map = {
    "keye_vl": run_keye_vl,
    "keye_vl1_5": run_keye_vl1_5,
    "kimi_vl": run_kimi_vl,
    "lightonocr": run_lightonocr,
    "llama4": run_llama4,
    "llava": run_llava,
    "llava-next": run_llava_next,
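This map is what the example's CLI dispatches on; a minimal sketch of the lookup, with the question text and model type as placeholders:

# Hypothetical dispatch mirroring the example script's model-type flag.
questions = ["placeholder question"]
model_type = "lightonocr"
if model_type not in model_example_map:
    raise ValueError(f"Unknown model type: {model_type}")
req_data = model_example_map[model_type](questions, "image")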
@@ -652,6 +652,10 @@ _MULTIMODAL_EXAMPLE_MODELS = {
        extras={"thinking": "moonshotai/Kimi-VL-A3B-Thinking"},
        trust_remote_code=True,
    ),
    "LightOnOCRForConditionalGeneration": _HfExamplesInfo(
        "lightonai/LightOnOCR-1B",
        is_available_online=False,
    ),
    "Llama4ForConditionalGeneration": _HfExamplesInfo(
        "meta-llama/Llama-4-Scout-17B-16E-Instruct",
        max_model_len=10240,
vllm/model_executor/models/lightonocr.py (new file, 195 lines)
@@ -0,0 +1,195 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Mapping, Sequence
from typing import TypeVar

import torch
import torch.nn as nn
from transformers import (
    BatchFeature,
    PixtralVisionConfig,
)

from vllm.config import VllmConfig
from vllm.model_executor.models.mistral3 import (
    Mistral3DummyInputsBuilder,
    Mistral3ForConditionalGeneration,
    Mistral3MultiModalProjector,
    Mistral3ProcessingInfo,
    _build_mistral3_info,
    init_vision_tower_for_llava,
)
from vllm.model_executor.models.pixtral import PixtralHFEncoderInfo
from vllm.model_executor.models.utils import (
    AutoWeightsLoader,
    WeightsMapper,
    init_vllm_registered_model,
    maybe_prefix,
)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import BaseMultiModalProcessorCache
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder

_I = TypeVar("_I", bound=Mistral3ProcessingInfo)

class LightOnOCRMultiModalProcessor(BaseMultiModalProcessor[Mistral3ProcessingInfo]):
    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        # NOTE: LightOnOCR does not use break/end tokens, so we remove them here.
        input_ids = processed_outputs.get("input_ids")
        if input_ids is not None:
            processor = self.info.get_hf_processor()
            tokenizer = self.info.get_tokenizer()
            vocab = tokenizer.get_vocab()

            break_id = vocab.get(processor.image_break_token)
            end_id = vocab.get(processor.image_end_token)

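            # Toy illustration (token ids assumed, not real vocab values):
            # with break_id=11, end_id=12 and input_ids=[7, 7, 11, 7, 7, 12],
            # the mask below keeps only the image tokens: [7, 7, 7, 7].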
            # create mask to remove break/end tokens
            keep_mask = ~torch.isin(
                input_ids,
                torch.tensor([break_id, end_id]),
            )

            processed_outputs["input_ids"] = input_ids[keep_mask].unsqueeze(0)
            if "attention_mask" in processed_outputs:
                processed_outputs["attention_mask"] = processed_outputs[
                    "attention_mask"
                ][keep_mask].unsqueeze(0)

        # un-pad pixel_values per-image so caches remain independent.
        pixel_values = processed_outputs.get("pixel_values")
        if pixel_values is not None:
            image_sizes = processed_outputs["image_sizes"]
            assert len(pixel_values) == len(image_sizes)
            processed_outputs["pixel_values"] = [
                p[:, :h, :w] for p, (h, w) in zip(pixel_values, image_sizes)
            ]

        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        hf_config = self.info.get_hf_config()
        image_token_id = hf_config.image_token_index

        assert isinstance(hf_config.vision_config, PixtralVisionConfig)
        encoder_info = PixtralHFEncoderInfo(hf_config)

        def replace(item_idx: int):
            images = mm_items.get_items("image", ImageProcessorItems)
            size = images.get_image_size(item_idx)
            ncols, nrows = encoder_info.get_patch_grid_size(
                image_width=size.width, image_height=size.height
            )
            # break/end tokens are not used in LightOnOCR
            tokens = [image_token_id] * (ncols * nrows)
            return PromptUpdateDetails.select_token_id(tokens, image_token_id)

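        # Toy illustration (grid size assumed): for a 2 x 3 patch grid,
        # `replace` expands the single image placeholder into six copies of
        # image_token_id, with no break/end tokens interleaved.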
        return [
            PromptReplacement(
                modality="image", target=[image_token_id], replacement=replace
            )
        ]


def _build_LightOnOCR_processor(
    info: _I,
    dummy_inputs: BaseDummyInputsBuilder[_I],
    *,
    cache: BaseMultiModalProcessorCache | None = None,
):
    assert isinstance(info, Mistral3ProcessingInfo)
    return LightOnOCRMultiModalProcessor(info, dummy_inputs, cache=cache)


@MULTIMODAL_REGISTRY.register_processor(
    _build_LightOnOCR_processor,
    info=_build_mistral3_info,
    dummy_inputs=Mistral3DummyInputsBuilder,
)
class LightOnOCRForConditionalGeneration(Mistral3ForConditionalGeneration):
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "model.vision_encoder.": "vision_tower.",
            "model.vision_projection.": "multi_modal_projector.",
            "lm_head.": "language_model.lm_head.",
            "model.language_model.": "language_model.model.",
        }
    )
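    # Illustration (hypothetical weight name): a checkpoint tensor named
    # "model.vision_encoder.<w>" is remapped to "vision_tower.<w>" at load time.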

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        nn.Module.__init__(self)
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        self.vision_tower = init_vision_tower_for_llava(
            config,
            quant_config,
            require_post_norm=False,
            prefix=maybe_prefix(prefix, "vision_tower"),
        )

        self.multi_modal_projector = Mistral3MultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act,
            spatial_merge_size=config.spatial_merge_size,
            patch_size=config.vision_config.patch_size,
            multimodal_projector_bias=config.multimodal_projector_bias,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "multi_modal_projector"),
        )

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
@@ -297,6 +297,10 @@ _MULTIMODAL_MODELS = {
    ),
    "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
    "LightOnOCRForConditionalGeneration": (
        "lightonocr",
        "LightOnOCRForConditionalGeneration",
    ),
    "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
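Each registry value is a (module, class) pair that the registry imports lazily from vllm/model_executor/models/. A sketch of that resolution convention for the new entry; this mirrors, not reproduces, the registry's actual loader code:

# Illustrative resolution of the new registry entry.
import importlib

module_name, class_name = ("lightonocr", "LightOnOCRForConditionalGeneration")
module = importlib.import_module(f"vllm.model_executor.models.{module_name}")
model_cls = getattr(module, class_name)
assert model_cls.__name__ == class_name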