[Misc] Automatically resolve HF processor init kwargs (#22005)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Date: 2025-08-01 13:44:10 +08:00
Committed by: GitHub
Parent: ad57f23f6a
Commit: 82de9b9d46
40 changed files with 334 additions and 727 deletions
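Before the per-file changes, the gist of the new behavior: config-time mm_processor_kwargs and inference-time kwargs are merged (inference values win) and then filtered against the signature of the target processor factory, so the per-model keyword plumbing removed below (e.g. the explicit min_dynamic_patch / max_pixels forwarding) is no longer needed. A minimal standalone sketch of that resolution step, with a hypothetical resolve_init_kwargs helper standing in for vLLM's get_allowed_kwarg_only_overrides:

import inspect
from collections.abc import Mapping


def resolve_init_kwargs(
    factory,
    config_kwargs: Mapping[str, object],
    inference_kwargs: Mapping[str, object],
) -> dict[str, object]:
    """Merge config-time and inference-time kwargs (inference wins), then
    drop anything the factory's signature cannot accept."""
    merged = dict(config_kwargs) | dict(inference_kwargs)
    params = inspect.signature(factory).parameters
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return merged
    return {k: v for k, v in merged.items() if k in params}


class DummyProcessor:

    def __init__(self, a: int = 0, b: int = 0) -> None:
        self.a = a
        self.b = b


# The inference-time value overrides the config value; the unknown "c" key
# is dropped instead of raising TypeError inside __init__.
kwargs = resolve_init_kwargs(DummyProcessor, {"a": 1, "c": 9}, {"a": 2})
assert kwargs == {"a": 2}
assert DummyProcessor(**kwargs).a == 2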

View File

@ -449,25 +449,6 @@ def run_smolvlm(questions: list[str], modality: str) -> ModelRequestData:
)
# omni-research/Tarsier-7b
def run_tarsier(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
model_name = "omni-research/Tarsier-7b"
engine_args = EngineArgs(
model=model_name,
trust_remote_code=True,
max_model_len=4096,
limit_mm_per_prompt={modality: 1},
)
prompts = [(f"USER: <image>\n{question} ASSISTANT:") for question in questions]
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)
# Intern-S1
def run_interns1(questions: list[str], modality: str) -> ModelRequestData:
model_name = "internlm/Intern-S1"
@ -1293,6 +1274,25 @@ def run_qwen2_5_omni(questions: list[str], modality: str):
)
# omni-research/Tarsier-7b
def run_tarsier(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
model_name = "omni-research/Tarsier-7b"
engine_args = EngineArgs(
model=model_name,
trust_remote_code=True,
max_model_len=4096,
limit_mm_per_prompt={modality: 1},
)
prompts = [(f"USER: <image>\n{question} ASSISTANT:") for question in questions]
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)
def run_tarsier2(questions: list[str], modality: str) -> ModelRequestData:
model_name = "omni-research/Tarsier2-Recap-7b"

View File

@ -4,8 +4,6 @@ from dataclasses import dataclass
from typing import Optional
import pytest
from packaging.version import Version
from transformers import __version__ as TRANSFORMERS_VERSION
import vllm
from vllm.assets.image import ImageAsset
@ -185,10 +183,6 @@ def test_qwen2vl_lora_beam_search(qwen2vl_lora_files):
current_platform.is_rocm(),
reason="Qwen2.5-VL dependency xformers incompatible with ROCm",
)
@pytest.mark.skipif(
Version(TRANSFORMERS_VERSION) < Version("4.49.0"),
reason="Qwen2.5-VL require transformers version no lower than 4.49.0",
)
def test_qwen25vl_lora(qwen25vl_lora_files):
"""Test Qwen 2.5 VL model with LoRA"""
config = TestConfig(model_path=QWEN25VL_MODEL_PATH,

View File

@ -702,13 +702,38 @@ VLM_TEST_SETTINGS = {
"smolvlm": VLMTestInfo(
models=["HuggingFaceTB/SmolVLM2-2.2B-Instruct"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
prompt_formatter=lambda img_prompt:f"<|im_start|>User:{img_prompt}<end_of_utterance>\nAssistant:", # noqa: E501
prompt_formatter=lambda img_prompt: f"<|im_start|>User:{img_prompt}<end_of_utterance>\nAssistant:", # noqa: E501
img_idx_to_prompt=lambda idx: "<image>",
max_model_len=8192,
max_num_seqs=2,
auto_cls=AutoModelForImageTextToText,
hf_output_post_proc=model_utils.smolvlm_trunc_hf_output,
),
"tarsier": VLMTestInfo(
models=["omni-research/Tarsier-7b"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
prompt_formatter=lambda img_prompt: f"USER: {img_prompt} ASSISTANT:",
max_model_len=4096,
max_num_seqs=2,
auto_cls=AutoModelForImageTextToText,
patch_hf_runner=model_utils.tarsier_patch_hf_runner,
),
"tarsier2": VLMTestInfo(
models=["omni-research/Tarsier2-Recap-7b"],
test_type=(
VLMTestType.IMAGE,
VLMTestType.MULTI_IMAGE,
VLMTestType.VIDEO,
),
prompt_formatter=lambda img_prompt: f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501
img_idx_to_prompt=lambda idx: "<|vision_start|><|image_pad|><|vision_end|>", # noqa: E501
video_idx_to_prompt=lambda idx: "<|vision_start|><|video_pad|><|vision_end|>", # noqa: E501
max_model_len=4096,
max_num_seqs=2,
auto_cls=AutoModelForImageTextToText,
image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
marks=[pytest.mark.skip("Model initialization hangs")],
),
### Tensor parallel / multi-gpu broadcast tests
"chameleon-broadcast": VLMTestInfo(
models=["facebook/chameleon-7b"],

View File

@ -818,3 +818,15 @@ def qwen2_5_omni_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
thinker.get_output_embeddings = lambda: thinker.lm_head
hf_model.model = thinker
return hf_model
def tarsier_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
from vllm.model_executor.models.tarsier import get_vision_encoder_info
vision_encoder_info = get_vision_encoder_info(hf_model.config)
hf_processor = hf_model.processor
if hf_processor.patch_size is None:
hf_processor.patch_size = vision_encoder_info.get_patch_size()
return hf_model

View File

@ -16,7 +16,7 @@ def test_multimodal_processor(model_id):
model_impl="transformers",
)
mm_processor = MULTIMODAL_REGISTRY.create_processor(model_config, )
mm_processor = MULTIMODAL_REGISTRY.create_processor(model_config)
image_pil = ImageAsset('cherry_blossom').pil_image
mm_data = {"image": image_pil}

View File

@ -465,8 +465,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
is_available_online=False),
"UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b", # noqa: E501
trust_remote_code=True),
"TarsierForConditionalGeneration": _HfExamplesInfo("omni-research/Tarsier-7b", # noqa: E501
hf_overrides={"architectures": ["TarsierForConditionalGeneration"]}), # noqa: E501
"TarsierForConditionalGeneration": _HfExamplesInfo("omni-research/Tarsier-7b"), # noqa: E501
"Tarsier2ForConditionalGeneration": _HfExamplesInfo("omni-research/Tarsier2-Recap-7b", # noqa: E501
hf_overrides={"architectures": ["Tarsier2ForConditionalGeneration"]}), # noqa: E501
"VoxtralForConditionalGeneration": _HfExamplesInfo(

View File

@ -2,16 +2,15 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import nullcontext
from types import MethodType
from typing import cast
from typing import Optional, cast
from unittest.mock import MagicMock
import numpy as np
import pytest
import torch
from transformers import ProcessorMixin
from vllm.config import ModelConfig
from vllm.inputs import InputProcessingContext
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargs,
MultiModalKwargsItem,
@ -1013,57 +1012,91 @@ def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid):
)
class _ProcessorProxy:
class DummyProcessor:
def __init__(self, processor: ProcessorMixin) -> None:
def __init__(self, a: int = 0, b: int = 0) -> None:
super().__init__()
self.__processor = processor
def __getattr__(self, key: str):
return getattr(self.__processor, key)
self.a = a
self.b = b
def __call__(
self,
text=None,
images=None,
videos=None,
exists=None,
return_tensors=None,
):
return dict(exists=exists)
a: int = 0,
c: int = 0,
return_tensors: Optional[str] = None,
) -> dict[str, int]:
return dict(a=a, c=c)
@pytest.mark.parametrize("model_id", ["Qwen/Qwen2-VL-2B-Instruct"]) # Dummy
# yapf: disable
@pytest.mark.parametrize("model_id", ["Qwen/Qwen2-VL-2B-Instruct"]) # Dummy
@pytest.mark.parametrize(
("call_kwargs", "expected_kwargs"),
("config_kwargs", "inference_kwargs", "expected_kwargs"),
[
# Should ignore invalid kwargs
({"does_not_exist": 100}, {"exists": None}),
({"exists": 1}, {"exists": 1}),
({"does_not_exist": 100, "exists": 1}, {"exists": 1}),
({"a": 1}, {}, {"a": 1, "b": 0}),
({}, {"a": 1}, {"a": 1, "b": 0}),
# inference_kwargs should take precedence
({"a": 1}, {"a": 2}, {"a": 2, "b": 0}),
# Should ignore extra kwargs
({"a": 1, "c": 1}, {}, {"a": 1, "b": 0}),
({"b": 1, "c": 1}, {}, {"a": 0, "b": 1}),
],
)
# yapf: enable
def test_hf_processor_kwargs(model_id, call_kwargs, expected_kwargs):
model_config = ModelConfig(model_id)
def test_hf_processor_init_kwargs(
model_id,
config_kwargs,
inference_kwargs,
expected_kwargs,
):
# Should not be used since there is nothing to convert to tokens
mock_tokenizer = cast(AnyTokenizer, object())
processor = MULTIMODAL_REGISTRY.create_processor(model_config)
orig_get_hf_processor = processor.info.get_hf_processor
def get_hf_processor(self, **kwargs):
assert kwargs == call_kwargs
return _ProcessorProxy(orig_get_hf_processor())
processor.info.get_hf_processor = MethodType(get_hf_processor,
processor.info)
out_kwargs = processor._call_hf_processor(
prompt="",
mm_data={},
mm_kwargs=call_kwargs,
tok_kwargs={},
ctx = InputProcessingContext(
model_config=ModelConfig(model_id, mm_processor_kwargs=config_kwargs),
tokenizer=mock_tokenizer,
)
assert out_kwargs == expected_kwargs
processor = ctx.get_hf_processor(
DummyProcessor, # type: ignore[arg-type]
**inference_kwargs,
)
for k, v in expected_kwargs.items():
assert getattr(processor, k) == v
# yapf: disable
@pytest.mark.parametrize("model_id", ["Qwen/Qwen2-VL-2B-Instruct"]) # Dummy
@pytest.mark.parametrize(
("config_kwargs", "inference_kwargs", "expected_kwargs"),
[
({"a": 1}, {}, {"a": 1, "c": 0}),
({}, {"a": 1}, {"a": 1, "c": 0}),
# inference_kwargs should take precedence
({"a": 1}, {"a": 2}, {"a": 2, "c": 0}),
# Should ignore extra kwargs
({"a": 1, "c": 1}, {}, {"a": 1, "c": 1}),
({"b": 1, "c": 1}, {}, {"a": 0, "c": 1}),
],
)
# yapf: enable
def test_hf_processor_call_kwargs(
model_id,
config_kwargs,
inference_kwargs,
expected_kwargs,
):
# Should not be used since there is nothing to convert to tokens
mock_tokenizer = cast(AnyTokenizer, object())
ctx = InputProcessingContext(
model_config=ModelConfig(model_id, mm_processor_kwargs=config_kwargs),
tokenizer=mock_tokenizer,
)
processor = ctx.get_hf_processor(DummyProcessor) # type: ignore[arg-type]
result = ctx.call_hf_processor(processor, {}, inference_kwargs)
assert result == expected_kwargs
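The two tests above exercise the same filtering rule against two different targets: get_hf_processor resolves kwargs against the processor's __init__ (hence a/b survive), while call_hf_processor resolves them against __call__ (hence a/c survive). A self-contained sketch of that distinction, using plain inspect in place of vLLM's get_allowed_kwarg_only_overrides:

import inspect
from typing import Optional


class DummyProcessor:
    """Same shape as the test double above: a/b are init kwargs,
    a/c are call kwargs."""

    def __init__(self, a: int = 0, b: int = 0) -> None:
        self.a = a
        self.b = b

    def __call__(self, a: int = 0, c: int = 0,
                 return_tensors: Optional[str] = None) -> dict[str, int]:
        return dict(a=a, c=c)


def allowed_overrides(fn, overrides: dict) -> dict:
    # Keep only the overrides that the callable can actually accept.
    params = inspect.signature(fn).parameters
    return {k: v for k, v in overrides.items() if k in params}


merged = {"a": 1, "b": 1, "c": 1}
assert allowed_overrides(DummyProcessor, merged) == {"a": 1, "b": 1}    # init
assert allowed_overrides(DummyProcessor(), merged) == {"a": 1, "c": 1}  # call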

View File

@ -11,6 +11,7 @@ import textwrap
import uuid
import warnings
from collections import Counter
from collections.abc import Mapping
from contextlib import contextmanager
from dataclasses import (MISSING, Field, asdict, field, fields, is_dataclass,
replace)
@ -3332,7 +3333,16 @@ class MultiModalConfig:
999 if envs.VLLM_USE_V1 else 1,
)
# TODO: Add configs to init vision tower or not.
def merge_mm_processor_kwargs(
self,
inference_kwargs: Mapping[str, object],
) -> dict[str, object]:
"""
Merge the config-time mm_processor_kwargs with the extra keyword
arguments passed during inference; inference-time values take precedence.
"""
kwargs = self.mm_processor_kwargs or {}
return kwargs | dict(inference_kwargs)
@config

View File

@ -11,7 +11,7 @@ from typing_extensions import TypeVar
from vllm.jsontree import JSONTree, json_map_leaves
from vllm.logger import init_logger
from vllm.transformers_utils.processor import cached_processor_from_config
from vllm.utils import resolve_mm_processor_kwargs
from vllm.utils import get_allowed_kwarg_only_overrides
if TYPE_CHECKING:
from vllm.config import ModelConfig
@ -154,14 +154,11 @@ class InputProcessingContext(InputContext):
assert callable(hf_processor)
mm_config = self.model_config.get_multimodal_config()
base_kwargs = mm_config.mm_processor_kwargs
if base_kwargs is None:
base_kwargs = {}
merged_kwargs = mm_config.merge_mm_processor_kwargs(kwargs)
merged_kwargs = resolve_mm_processor_kwargs(
base_kwargs,
kwargs,
allowed_kwargs = get_allowed_kwarg_only_overrides(
hf_processor,
merged_kwargs,
requires_kw_only=False,
allow_var_kwargs=True,
)
@ -173,7 +170,9 @@ class InputProcessingContext(InputContext):
return x
try:
output = hf_processor(**data, **merged_kwargs, return_tensors="pt")
output = hf_processor(**data,
**allowed_kwargs,
return_tensors="pt")
# this emulates output.to(dtype=self.model_config.dtype)
if isinstance(output, BatchFeature):
cast_output = json_map_leaves(maybe_cast_dtype, output.data)
@ -189,7 +188,7 @@ class InputProcessingContext(InputContext):
except Exception as exc:
msg = (f"Failed to apply {type(hf_processor).__name__} "
f"on data={data} with kwargs={merged_kwargs}")
f"on data={data} with kwargs={allowed_kwargs}")
raise ValueError(msg) from exc
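As the comment in this hunk notes, the processor output is cast to the model dtype leaf by leaf rather than via BatchFeature.to. A simplified flat-dict version of that cast (the real code applies maybe_cast_dtype through json_map_leaves and uses self.model_config.dtype rather than a hard-coded dtype):

import torch
from transformers import BatchFeature


def maybe_cast_dtype(x, dtype: torch.dtype = torch.bfloat16):
    # Only floating-point tensors are cast; integer tensors such as
    # input_ids keep their original dtype.
    if isinstance(x, torch.Tensor) and x.is_floating_point():
        return x.to(dtype=dtype)
    return x


output = BatchFeature({
    "input_ids": torch.tensor([[1, 2, 3]]),
    "pixel_values": torch.zeros(1, 3, 8, 8, dtype=torch.float32),
})
cast_data = {k: maybe_cast_dtype(v) for k, v in output.data.items()}
assert cast_data["pixel_values"].dtype == torch.bfloat16
assert cast_data["input_ids"].dtype == torch.int64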

View File

@ -123,16 +123,10 @@ class AyaVisionProcessingInfo(BaseProcessingInfo):
return self.ctx.get_hf_config(AyaVisionConfig)
def get_hf_processor(self, **kwargs: object) -> AyaVisionProcessor:
processor = self.ctx.get_hf_processor(AyaVisionProcessor, **kwargs)
return self.ctx.get_hf_processor(AyaVisionProcessor, **kwargs)
# Temporary workaround since this processor has multiple image tokens
# See https://github.com/huggingface/transformers/issues/38350
processor._check_special_mm_tokens = lambda *args, **kwargs: None
return processor
def get_image_processor(self) -> GotOcr2ImageProcessor:
return self.get_hf_processor().image_processor
def get_image_processor(self, **kwargs: object) -> GotOcr2ImageProcessor:
return self.get_hf_processor(**kwargs).image_processor
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}

View File

@ -214,25 +214,25 @@ class DeepseekVL2MultiModalProcessor(
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
if mm_data:
processed_outputs = self.info.ctx.call_hf_processor(
self.info.get_hf_processor(**mm_kwargs),
dict(prompt=prompt, **mm_data),
dict(**mm_kwargs, **tok_kwargs),
)
pixel_values = processed_outputs["pixel_values"]
# split pixel values into patches corresponding to each image
images_spatial_crop = processed_outputs["images_spatial_crop"]
patches_per_image = [
x.prod().item() + 1 for x in images_spatial_crop
]
pixel_values = pixel_values.split(patches_per_image)
processed_outputs["pixel_values"] = pixel_values
else:
if not mm_data:
tokenizer = self.info.get_tokenizer()
processed_outputs = tokenizer(prompt,
add_special_tokens=True,
return_tensors="pt")
return tokenizer(prompt,
add_special_tokens=True,
return_tensors="pt")
processed_outputs = super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
)
pixel_values = processed_outputs["pixel_values"]
# split pixel values into patches corresponding to each image
images_spatial_crop = processed_outputs["images_spatial_crop"]
patches_per_image = [x.prod().item() + 1 for x in images_spatial_crop]
pixel_values = pixel_values.split(patches_per_image)
processed_outputs["pixel_values"] = pixel_values
return processed_outputs
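For reference, the splitting logic kept here regroups the flat pixel_values into per-image chunks: each image contributes rows * cols local tiles plus, presumably, one global view, which is where the + 1 comes from. A small sketch with illustrative shapes (the crop grids and tile size are made up):

import torch

# Hypothetical spatial crops for two images: a 2x2 grid and a 1x3 grid.
images_spatial_crop = torch.tensor([[2, 2], [1, 3]])
patches_per_image = [x.prod().item() + 1 for x in images_spatial_crop]
assert patches_per_image == [5, 4]

pixel_values = torch.randn(sum(patches_per_image), 3, 384, 384)
per_image = pixel_values.split(patches_per_image)
assert [t.shape[0] for t in per_image] == [5, 4]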

View File

@ -761,12 +761,6 @@ class Florence2LanguageForConditionalGeneration(nn.Module, SupportsV0Only):
class Florence2ProcessingInfo(BaseProcessingInfo):
def get_hf_config(self):
return self.ctx.get_hf_config()
def get_hf_processor(self):
return self.ctx.get_hf_processor()
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1}

View File

@ -83,8 +83,8 @@ class FuyuProcessingInfo(BaseProcessingInfo):
def get_hf_processor(self, **kwargs: object):
return self.ctx.get_hf_processor(FuyuProcessor, **kwargs)
def get_image_processor(self) -> FuyuImageProcessor:
return self.get_hf_processor().image_processor
def get_image_processor(self, **kwargs: object) -> FuyuImageProcessor:
return self.get_hf_processor(**kwargs).image_processor
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1}

View File

@ -809,11 +809,11 @@ class Glm4vProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None, "video": 1}
def get_image_processor(self) -> Glm4vImageProcessor:
return self.get_hf_processor().image_processor
def get_image_processor(self, **kwargs: object) -> Glm4vImageProcessor:
return self.get_hf_processor(**kwargs).image_processor
def get_video_processor(self) -> Glm4vVideoProcessor:
return self.get_hf_processor().video_processor
def get_video_processor(self, **kwargs: object) -> Glm4vVideoProcessor:
return self.get_hf_processor(**kwargs).video_processor
def _get_vision_info(
self,

View File

@ -392,21 +392,7 @@ class H2OVLProcessor(BaseInternVLProcessor):
class H2OVLProcessingInfo(BaseInternVLProcessingInfo):
def get_hf_processor(
self,
*,
min_dynamic_patch: Optional[int] = None,
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
**kwargs: object,
) -> H2OVLProcessor:
if min_dynamic_patch is not None:
kwargs["min_dynamic_patch"] = min_dynamic_patch
if max_dynamic_patch is not None:
kwargs["max_dynamic_patch"] = max_dynamic_patch
if dynamic_image_size is not None:
kwargs["dynamic_image_size"] = dynamic_image_size
def get_hf_processor(self, **kwargs: object) -> H2OVLProcessor:
return self.ctx.init_processor(
H2OVLProcessor,
config=self.get_hf_config(),

View File

@ -25,8 +25,7 @@ import torch
import torch.nn as nn
from timm.layers import LayerNorm, LayerNorm2d
from timm.models.regnet import RegStage
from transformers import (AutoProcessor, BatchFeature, CLIPVisionConfig,
SiglipVisionConfig)
from transformers import BatchFeature, CLIPVisionConfig, SiglipVisionConfig
from transformers.modeling_utils import no_init_weights
from vllm.config import VllmConfig
@ -80,26 +79,9 @@ HCXVisionMultimodalInputs = Union[HCXVisionMultimodalPixelInputs]
class HCXVisionProcessingInfo(BaseProcessingInfo):
def get_hf_config(self):
return self.ctx.get_hf_config()
def get_vision_encoder_info(self):
return get_vision_encoder_info(self.get_hf_config())
def get_hf_processor(
self,
**kwargs: object,
):
processor_cls = type(
AutoProcessor.from_pretrained(
self.ctx.model_config.model,
trust_remote_code=self.ctx.model_config.trust_remote_code,
))
return self.ctx.get_hf_processor(
processor_cls,
**kwargs,
)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None, "video": None}

View File

@ -88,15 +88,7 @@ ImageInputs = Union[Idefics3ImagePixelInputs, Idefics3ImageEmbeddingInputs]
class Idefics3ProcessingInfo(BaseProcessingInfo):
def get_hf_processor(
self,
*,
size: Optional[dict[str, int]] = None,
**kwargs: object,
) -> Idefics3Processor:
if size is not None:
kwargs["size"] = size
def get_hf_processor(self, **kwargs: object) -> Idefics3Processor:
return self.ctx.get_hf_processor(Idefics3Processor, **kwargs)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:

View File

@ -665,14 +665,7 @@ class BaseInternVLProcessingInfo(BaseProcessingInfo):
"""Basic image-only ProcessingInfo for InternVL-style models."""
@abstractmethod
def get_hf_processor(
self,
*,
min_dynamic_patch: Optional[int] = None,
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
**kwargs: object,
) -> BaseInternVLProcessor:
def get_hf_processor(self, **kwargs: object) -> BaseInternVLProcessor:
raise NotImplementedError
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
@ -882,27 +875,12 @@ class InternVLProcessingInfo(BaseInternVLProcessingInfo):
return max(max_frames_per_video, 1)
def get_hf_processor(
self,
*,
min_dynamic_patch: Optional[int] = None,
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
**kwargs: object,
) -> InternVLProcessor:
if min_dynamic_patch is not None:
kwargs["min_dynamic_patch"] = min_dynamic_patch
if max_dynamic_patch is not None:
kwargs["max_dynamic_patch"] = max_dynamic_patch
if dynamic_image_size is not None:
kwargs["dynamic_image_size"] = dynamic_image_size
kwargs["video_token"] = self.get_video_token()
def get_hf_processor(self, **kwargs: object) -> InternVLProcessor:
return self.ctx.init_processor(
InternVLProcessor,
config=self.get_hf_config(),
tokenizer=self.get_tokenizer(),
video_token=self.get_video_token(),
**kwargs,
)

View File

@ -44,8 +44,6 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.config import uses_mrope
from vllm.transformers_utils.processor import (
cached_image_processor_from_config)
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
@ -980,72 +978,8 @@ class KeyeMultiModalDataParser(MultiModalDataParser):
class KeyeProcessingInfo(BaseProcessingInfo):
def get_hf_processor(
self,
*,
min_pixels: Optional[int] = None,
max_pixels: Optional[int] = None,
size: Optional[dict[str, int]] = None,
**kwargs: object,
):
return self.ctx.get_hf_processor(
image_processor=self.get_image_processor(
min_pixels=min_pixels,
max_pixels=max_pixels,
size=size,
),
**kwargs,
)
def _get_image_processor_kwargs(
self,
*,
min_pixels: Optional[int] = None,
max_pixels: Optional[int] = None,
size: Optional[dict[str, int]] = None,
**kwargs: object,
):
if self.ctx.model_config.mm_processor_kwargs:
kwargs.update(self.ctx.model_config.mm_processor_kwargs)
if min_pixels is not None:
kwargs["min_pixels"] = min_pixels
if size is None:
size = {"shortest_edge": min_pixels}
else:
size["shortest_edge"] = min_pixels
if max_pixels is not None:
kwargs["max_pixels"] = max_pixels
if size is None:
size = {"longest_edge": max_pixels}
else:
size["longest_edge"] = max_pixels
if size is not None:
kwargs["size"] = size
return kwargs
def get_image_processor(
self,
*,
min_pixels: Optional[int] = None,
max_pixels: Optional[int] = None,
size: Optional[dict[str, int]] = None,
**kwargs: object,
):
return cached_image_processor_from_config(
self.ctx.model_config,
**self._get_image_processor_kwargs(
min_pixels=min_pixels,
max_pixels=max_pixels,
size=size,
**kwargs,
),
)
def get_image_processor(self, **kwargs: object):
return self.get_hf_processor(**kwargs).image_processor
def get_supported_mm_limits(self, ) -> Mapping[str, Optional[int]]:
return {"image": None, "video": None}
@ -1246,20 +1180,6 @@ class KeyeMultiModalProcessor(BaseMultiModalProcessor[KeyeProcessingInfo]):
def _get_data_parser(self) -> MultiModalDataParser:
return KeyeMultiModalDataParser()
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
mm_kwargs = self.info._get_image_processor_kwargs(**mm_kwargs)
return self.info.ctx.call_hf_processor(
self.info.get_hf_processor(**mm_kwargs),
dict(text=prompt, **mm_data),
dict(**mm_kwargs, **tok_kwargs),
)
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,

View File

@ -8,11 +8,9 @@ from typing import (Final, Literal, Optional, Protocol, TypedDict, TypeVar,
import torch
import torch.nn as nn
from packaging.version import Version
from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig,
PixtralVisionConfig, PretrainedConfig,
SiglipVisionConfig)
from transformers import __version__ as TRANSFORMERS_VERSION
from transformers.models.llava import LlavaProcessor
from transformers.models.pixtral import PixtralProcessor
@ -307,29 +305,14 @@ class PixtralHFMultiModalProcessor(
pixel_values = processed_outputs.get("pixel_values")
if pixel_values is not None:
# Before/after https://github.com/huggingface/transformers/pull/35122
if Version(TRANSFORMERS_VERSION) <= Version("4.48.3"):
images = mm_data["images"]
assert isinstance(images, list)
# Avoid padding since we need the output for each image to be
# independent of other images for the cache to work correctly
image_sizes = processed_outputs["image_sizes"]
assert len(pixel_values) == len(image_sizes)
# Original output: (1, num_images, C, H, W)
# New output: (num_images, C, H, W)
assert (isinstance(pixel_values, list)
and len(pixel_values) == 1)
assert (isinstance(pixel_values[0], list)
and len(pixel_values[0]) == len(images))
processed_outputs["pixel_values"] = pixel_values[0]
else:
# Avoid padding since we need the output for each image to be
# independent of other images for the cache to work correctly
image_sizes = processed_outputs["image_sizes"]
assert len(pixel_values) == len(image_sizes)
processed_outputs["pixel_values"] = [
p[:, :h, :w]
for p, (h, w) in zip(pixel_values, image_sizes)
]
processed_outputs["pixel_values"] = [
p[:, :h, :w] for p, (h, w) in zip(pixel_values, image_sizes)
]
return processed_outputs
@ -784,17 +767,10 @@ class MantisProcessingInfo(LlavaProcessingInfo):
vision_info = self.get_vision_encoder_info()
kwargs.setdefault("patch_size", vision_info.get_patch_size())
if Version(TRANSFORMERS_VERSION) < Version("4.48"):
# BUG: num_additional_image_tokens = 0 but treated as 1,
# so we set vision_feature_select_strategy to None to offset this
kwargs.setdefault("vision_feature_select_strategy", None)
else:
# FIXED: https://github.com/huggingface/transformers/pull/33424/files#diff-6a37acc21efcadaae622b079b2712a131131448ff64262bd219aa346aeec38faL150
kwargs.setdefault(
"vision_feature_select_strategy",
hf_config.vision_feature_select_strategy,
)
kwargs.setdefault(
"vision_feature_select_strategy",
hf_config.vision_feature_select_strategy,
)
return self.ctx.get_hf_processor(LlavaProcessor, **kwargs)

View File

@ -331,10 +331,8 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
return hf_processor
def get_image_processor(self):
hf_processor = self.get_hf_processor()
image_processor = hf_processor.image_processor # type: ignore
return image_processor
def get_image_processor(self, **kwargs: object):
return self.get_hf_processor(**kwargs).image_processor
def get_model_version(self):
return get_version_by_config(self.get_hf_config())

View File

@ -533,7 +533,7 @@ class Mllama4ProcessingInfo(BaseProcessingInfo):
def get_hf_processor(self, **kwargs: object) -> Llama4Processor:
return self.ctx.get_hf_processor(Llama4Processor,
use_fast=True,
use_fast=kwargs.pop("use_fast", True),
**kwargs)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:

View File

@ -137,34 +137,16 @@ class NemotronVLProcessor(InternVLProcessor):
class NemotronVLProcessingInfo(BaseInternVLProcessingInfo):
"""Processing info for Nemotron VL models."""
def get_hf_processor(
self,
*,
min_dynamic_patch: Optional[int] = None,
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
**kwargs: object,
) -> NemotronVLProcessor:
if min_dynamic_patch is not None:
kwargs["min_dynamic_patch"] = min_dynamic_patch
if max_dynamic_patch is not None:
kwargs["max_dynamic_patch"] = max_dynamic_patch
if dynamic_image_size is not None:
kwargs["dynamic_image_size"] = dynamic_image_size
image_processor = self.get_image_processor()
def get_hf_processor(self, **kwargs: object) -> NemotronVLProcessor:
return self.ctx.init_processor(
NemotronVLProcessor,
config=self.get_hf_config(),
tokenizer=self.get_tokenizer(),
image_processor=image_processor,
image_processor=self.get_image_processor(),
**kwargs,
)
def get_image_processor(
self,
**kwargs: object,
):
def get_image_processor(self, **kwargs: object):
return cached_image_processor_from_config(
self.ctx.model_config,
**kwargs,

View File

@ -63,21 +63,7 @@ class NVLMProcessor(BaseInternVLProcessor):
class NVLMProcessingInfo(BaseInternVLProcessingInfo):
def get_hf_processor(
self,
*,
min_dynamic_patch: Optional[int] = None,
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
**kwargs: object,
) -> NVLMProcessor:
if min_dynamic_patch is not None:
kwargs["min_dynamic_patch"] = min_dynamic_patch
if max_dynamic_patch is not None:
kwargs["max_dynamic_patch"] = max_dynamic_patch
if dynamic_image_size is not None:
kwargs["dynamic_image_size"] = dynamic_image_size
def get_hf_processor(self, **kwargs: object) -> NVLMProcessor:
return self.ctx.init_processor(
NVLMProcessor,
config=self.get_hf_config(),

View File

@ -25,7 +25,7 @@ import torch
import torch.nn as nn
from torch import Tensor
from torch.nn.functional import gumbel_softmax, pad, softmax
from transformers import BaseImageProcessor, BatchFeature, PretrainedConfig
from transformers import BatchFeature, PretrainedConfig
from vllm.config import VllmConfig
from vllm.model_executor.layers.linear import ReplicatedLinear
@ -245,11 +245,12 @@ class VisualEmbedding(torch.nn.Embedding):
class OvisProcessingInfo(BaseProcessingInfo):
def get_hf_processor(self, **kwargs):
def get_hf_processor(self, **kwargs: object):
return self.ctx.get_hf_processor(
OvisProcessor,
image_pad_token=self.get_image_pad_token(),
image_segment_len=self.get_image_segment_len(),
**kwargs,
)
def get_image_segment_len(self) -> int:
@ -269,9 +270,6 @@ class OvisProcessingInfo(BaseProcessingInfo):
text_model_type = hf_text_config.model_type
return IMAGE_PAD_TOKEN_MAP.get(text_model_type)
def get_image_processor(self) -> BaseImageProcessor:
return self.get_hf_processor().image_processor # type: ignore
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}

View File

@ -318,17 +318,6 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
class Phi3VProcessingInfo(BaseProcessingInfo):
def get_hf_processor(
self,
*,
num_crops: Optional[int] = None,
**kwargs: object,
) -> ProcessorMixin:
if num_crops is not None:
kwargs["num_crops"] = num_crops
return self.ctx.get_hf_processor(**kwargs)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}

View File

@ -696,19 +696,12 @@ class Phi4MMProcessingInfo(BaseProcessingInfo):
def get_hf_config(self) -> Phi4MultimodalConfig:
return self.ctx.get_hf_config(Phi4MultimodalConfig)
def get_hf_processor(
self,
*,
dynamic_hd: Optional[int] = None,
**kwargs: object,
) -> Phi4MMProcessor:
if dynamic_hd is not None:
kwargs["dynamic_hd"] = dynamic_hd
def get_hf_processor(self, **kwargs: object) -> Phi4MMProcessor:
return self.ctx.get_hf_processor(Phi4MMProcessor, **kwargs)
return self.ctx.get_hf_processor(**kwargs)
def get_feature_extractor(self) -> Phi4MultimodalFeatureExtractor:
return self.get_hf_processor().audio_processor
def get_feature_extractor(
self, **kwargs: object) -> Phi4MultimodalFeatureExtractor:
return self.get_hf_processor(**kwargs).audio_processor
def get_image_processor(
self,
@ -1007,7 +1000,7 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]):
if audio_data:
audio_features = processed_outputs['audio_input_features']
sr = self.info.get_feature_extractor().sampling_rate
sr = self.info.get_feature_extractor(**mm_kwargs).sampling_rate
feature_sizes = [
self.info.get_audio_num_frames(len(audio), sr)
for audio in audio_data
@ -1043,7 +1036,8 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]):
audio_token_id = tokenizer.vocab[tokenizer.audio_token]
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
audio_processor = self.info.get_feature_extractor()
audio_processor = self.info.get_feature_extractor(
**hf_processor_mm_kwargs)
def get_image_replacement_phi4mm(item_idx: int):
images = mm_items.get_items(

View File

@ -459,17 +459,6 @@ def cat_with_pad(tensors, dim, padding_value=0):
class Phi4MMProcessingInfo(BaseProcessingInfo):
def get_hf_processor(
self,
*,
dynamic_hd: Optional[int] = None,
**kwargs: object,
) -> ProcessorMixin:
if dynamic_hd is not None:
kwargs["dynamic_hd"] = dynamic_hd
return self.ctx.get_hf_processor(**kwargs)
@property
def image_tokens(self) -> list[str]:
return [f"<|image_{i+1}|>" for i in range(100)]
@ -487,8 +476,9 @@ class Phi4MMProcessingInfo(BaseProcessingInfo):
image_processor = processor.image_processor
return image_processor.dynamic_hd
def get_feature_extractor(self) -> SequenceFeatureExtractor:
return self.get_hf_processor().audio_processor
def get_feature_extractor(self,
**kwargs: object) -> SequenceFeatureExtractor:
return self.get_hf_processor(**kwargs).audio_processor
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"audio": None, "image": None}
@ -769,7 +759,7 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]):
prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
sr = self.info.get_feature_extractor().sampling_rate
sr = self.info.get_feature_extractor(**mm_kwargs).sampling_rate
if (audio_data := mm_data.get("audios", [])):
mm_data['audios'] = [(data, sr) for data in audio_data]
@ -816,7 +806,8 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]):
) -> Sequence[PromptUpdate]:
image_tokens: list[str] = self.info.image_tokens # type: ignore
audio_tokens: list[str] = self.info.audio_tokens # type: ignore
feature_extractor = self.info.get_feature_extractor()
feature_extractor = self.info.get_feature_extractor(
**hf_processor_mm_kwargs)
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
def get_image_replacement_phi4mm(item_idx: int):

View File

@ -132,50 +132,15 @@ class Qwen2_5OmniThinkerProcessingInfo(Qwen2AudioProcessingInfo,
def get_hf_config(self):
return self.ctx.get_hf_config(Qwen2_5OmniConfig).thinker_config
def get_hf_processor(
self,
*,
sampling_rate: Optional[int] = None,
min_pixels: Optional[int] = None,
max_pixels: Optional[int] = None,
size: Optional[dict[str, int]] = None,
fps: Optional[Union[float, list[float]]] = None,
**kwargs: object,
) -> Qwen2_5OmniProcessor:
if fps is not None:
kwargs["fps"] = fps
# Monkey patch for Transformers v4.53
processor_class = Qwen2_5OmniProcessor
if processor_class.image_processor_class != "AutoImageProcessor":
processor_class.image_processor_class = "AutoImageProcessor"
if processor_class.video_processor_class != "AutoVideoProcessor":
processor_class.video_processor_class = "AutoVideoProcessor"
processor = self.ctx.get_hf_processor(
processor_class,
image_processor=self.get_image_processor(min_pixels=min_pixels,
max_pixels=max_pixels,
size=size,
use_fast=kwargs.get(
"use_fast", True)),
def get_hf_processor(self, **kwargs: object) -> Qwen2_5OmniProcessor:
return self.ctx.get_hf_processor(
Qwen2_5OmniProcessor,
use_fast=kwargs.pop("use_fast", True),
**kwargs,
)
if not hasattr(processor, "audio_token"):
processor.audio_token = "<|AUDIO|>"
if not hasattr(processor, "image_token"):
processor.image_token = "<|IMAGE|>"
if not hasattr(processor, "video_token"):
processor.video_token = "<|VIDEO|>"
return processor
def get_feature_extractor(
self,
*,
sampling_rate: Optional[int] = None,
**kwargs: object,
):
hf_processor = self.get_hf_processor(sampling_rate=sampling_rate)
def get_feature_extractor(self, **kwargs: object):
hf_processor = self.get_hf_processor(**kwargs)
feature_extractor = hf_processor.feature_extractor # type: ignore
assert isinstance(feature_extractor, WhisperFeatureExtractor)
return feature_extractor

View File

@ -780,25 +780,10 @@ class Qwen2_5_VLProcessingInfo(Qwen2VLProcessingInfo):
def get_hf_config(self):
return self.ctx.get_hf_config(Qwen2_5_VLConfig)
def get_hf_processor(
self,
*,
min_pixels: Optional[int] = None,
max_pixels: Optional[int] = None,
size: Optional[dict[str, int]] = None,
fps: Optional[Union[float, list[float]]] = None,
**kwargs: object,
) -> Qwen2_5_VLProcessor:
if fps is not None:
kwargs["fps"] = fps
def get_hf_processor(self, **kwargs: object) -> Qwen2_5_VLProcessor:
return self.ctx.get_hf_processor(
Qwen2_5_VLProcessor,
image_processor=self.get_image_processor(min_pixels=min_pixels,
max_pixels=max_pixels,
size=size,
use_fast=kwargs.get(
"use_fast", True)),
use_fast=kwargs.pop("use_fast", True),
**kwargs,
)

View File

@ -86,22 +86,12 @@ class Qwen2AudioProcessingInfo(BaseProcessingInfo):
def get_hf_config(self):
return self.ctx.get_hf_config(Qwen2AudioConfig)
def get_hf_processor(
self,
*,
# Ignored in initialization
sampling_rate: Optional[int] = None,
**kwargs: object,
) -> Qwen2AudioProcessor:
def get_hf_processor(self, **kwargs: object) -> Qwen2AudioProcessor:
return self.ctx.get_hf_processor(Qwen2AudioProcessor, **kwargs)
def get_feature_extractor(
self,
*,
# Ignored in initialization
sampling_rate: Optional[int] = None,
) -> WhisperFeatureExtractor:
hf_processor = self.get_hf_processor(sampling_rate=sampling_rate)
def get_feature_extractor(self,
**kwargs: object) -> WhisperFeatureExtractor:
hf_processor = self.get_hf_processor(**kwargs)
feature_extractor = hf_processor.feature_extractor # type: ignore
assert isinstance(feature_extractor, WhisperFeatureExtractor)
return feature_extractor

View File

@ -69,8 +69,6 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import _Backend, current_platform
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.config import uses_mrope
from vllm.transformers_utils.processor import (
cached_image_processor_from_config)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
@ -752,73 +750,15 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
def get_hf_config(self):
return self.ctx.get_hf_config(Qwen2VLConfig)
def get_hf_processor(
self,
*,
min_pixels: Optional[int] = None,
max_pixels: Optional[int] = None,
size: Optional[dict[str, int]] = None,
**kwargs: object,
) -> Qwen2VLProcessor:
def get_hf_processor(self, **kwargs: object) -> Qwen2VLProcessor:
return self.ctx.get_hf_processor(
Qwen2VLProcessor,
image_processor=self.get_image_processor(min_pixels=min_pixels,
max_pixels=max_pixels,
size=size,
use_fast=kwargs.get(
"use_fast", True)),
use_fast=kwargs.pop("use_fast", True),
**kwargs,
)
def _get_image_processor_kwargs(
self,
*,
min_pixels: Optional[int] = None,
max_pixels: Optional[int] = None,
size: Optional[dict[str, int]] = None,
**kwargs: object,
):
mm_config = self.ctx.model_config.get_multimodal_config()
if mm_config.mm_processor_kwargs:
kwargs.update(mm_config.mm_processor_kwargs)
if min_pixels is not None:
kwargs["min_pixels"] = min_pixels
if size is None:
size = {"shortest_edge": min_pixels}
else:
size["shortest_edge"] = min_pixels
if max_pixels is not None:
kwargs["max_pixels"] = max_pixels
if size is None:
size = {"longest_edge": max_pixels}
else:
size["longest_edge"] = max_pixels
if size is not None:
kwargs["size"] = size
return kwargs
def get_image_processor(
self,
*,
min_pixels: Optional[int] = None,
max_pixels: Optional[int] = None,
size: Optional[dict[str, int]] = None,
**kwargs: object,
) -> Qwen2VLImageProcessor:
kwargs["use_fast"] = kwargs.get("use_fast", True)
return cached_image_processor_from_config(
self.ctx.model_config,
**self._get_image_processor_kwargs(min_pixels=min_pixels,
max_pixels=max_pixels,
size=size,
**kwargs),
)
def get_image_processor(self, **kwargs: object) -> Qwen2VLImageProcessor:
return self.get_hf_processor(**kwargs).image_processor
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None, "video": None}
@ -1023,20 +963,6 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]
def _get_data_parser(self) -> MultiModalDataParser:
return Qwen2VLMultiModalDataParser()
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
mm_kwargs = self.info._get_image_processor_kwargs(**mm_kwargs)
return self.info.ctx.call_hf_processor(
self.info.get_hf_processor(**mm_kwargs),
dict(text=prompt, **mm_data),
dict(**mm_kwargs, **tok_kwargs),
)
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,

View File

@ -7,9 +7,8 @@
# Copyright (c) 2025 Skywork
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from abc import ABC, abstractmethod
from collections.abc import Iterable, Mapping, Sequence
from typing import Literal, Optional, TypedDict, TypeVar, Union
from typing import Literal, Optional, TypedDict, Union
import torch
import torch.nn as nn
@ -232,7 +231,7 @@ def image_to_pixel_values_skyworkr1v(
return pixel_values
class BaseSkyworkR1VProcessor(ABC):
class SkyworkR1VProcessor:
"""
This model doesn't define its own HF processor,
so we implement our own here.
@ -279,17 +278,18 @@ class BaseSkyworkR1VProcessor(ABC):
self.use_thumbnail: bool = config.use_thumbnail
@property
@abstractmethod
def image_token_id(self) -> int:
raise NotImplementedError
return self.tokenizer.get_vocab()[IMG_CONTEXT]
@abstractmethod
def get_image_repl(
self,
feature_size: int,
num_patches: Optional[int],
) -> PromptUpdateDetails[str]:
raise NotImplementedError
repl_features = IMG_CONTEXT * feature_size
repl_full = IMG_START + repl_features + IMG_END
return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT)
def resolve_min_max_num(
self,
@ -426,35 +426,15 @@ class BaseSkyworkR1VProcessor(ABC):
}
class SkyworkR1VProcessor(BaseSkyworkR1VProcessor):
class SkyworkR1VProcessingInfo(BaseProcessingInfo):
@property
def image_token_id(self) -> int:
return self.tokenizer.get_vocab()[IMG_CONTEXT]
def get_image_repl(
self,
feature_size: int,
num_patches: Optional[int],
) -> PromptUpdateDetails[str]:
repl_features = IMG_CONTEXT * feature_size
repl_full = IMG_START + repl_features + IMG_END
return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT)
class BaseSkyworkR1VProcessingInfo(BaseProcessingInfo):
@abstractmethod
def get_hf_processor(
self,
*,
min_dynamic_patch: Optional[int] = None,
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
**kwargs: object,
) -> BaseSkyworkR1VProcessor:
raise NotImplementedError
def get_hf_processor(self, **kwargs: object) -> SkyworkR1VProcessor:
return self.ctx.init_processor(
SkyworkR1VProcessor,
config=self.get_hf_config(),
tokenizer=self.get_tokenizer(),
**kwargs,
)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
@ -464,7 +444,7 @@ class BaseSkyworkR1VProcessingInfo(BaseProcessingInfo):
*,
image_width: int,
image_height: int,
processor: Optional[BaseSkyworkR1VProcessor],
processor: Optional[SkyworkR1VProcessor],
) -> int:
if processor is None:
processor = self.get_hf_processor()
@ -500,10 +480,8 @@ class BaseSkyworkR1VProcessingInfo(BaseProcessingInfo):
return largest_feature_pinpoint
_I = TypeVar("_I", bound=BaseSkyworkR1VProcessingInfo)
class SkyworkR1VDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
class SkyworkR1VDummyInputsBuilder(
BaseDummyInputsBuilder[SkyworkR1VProcessingInfo]):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
@ -527,7 +505,8 @@ class SkyworkR1VDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
}
class SkyworkR1VMultiModalProcessor(BaseMultiModalProcessor[_I]):
class SkyworkR1VMultiModalProcessor(
BaseMultiModalProcessor[SkyworkR1VProcessingInfo]):
def _call_hf_processor(
self,
@ -617,31 +596,6 @@ class SkyworkR1VMultiModalProcessor(BaseMultiModalProcessor[_I]):
]
class SkyworkR1VProcessingInfo(BaseSkyworkR1VProcessingInfo):
def get_hf_processor(
self,
*,
min_dynamic_patch: Optional[int] = None,
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
**kwargs: object,
) -> SkyworkR1VProcessor:
if min_dynamic_patch is not None:
kwargs["min_dynamic_patch"] = min_dynamic_patch
if max_dynamic_patch is not None:
kwargs["max_dynamic_patch"] = max_dynamic_patch
if dynamic_image_size is not None:
kwargs["dynamic_image_size"] = dynamic_image_size
return self.ctx.init_processor(
SkyworkR1VProcessor,
config=self.get_hf_config(),
tokenizer=self.get_tokenizer(),
**kwargs,
)
@MULTIMODAL_REGISTRY.register_processor(
SkyworkR1VMultiModalProcessor,
info=SkyworkR1VProcessingInfo,

View File

@ -19,15 +19,7 @@ from .idefics3 import Idefics3ProcessingInfo
class SmolVLMProcessingInfo(Idefics3ProcessingInfo):
def get_hf_processor(
self,
*,
max_image_size: Optional[dict[str, int]] = None,
**kwargs: object,
) -> SmolVLMProcessor:
if max_image_size is not None:
kwargs["max_image_size"] = max_image_size
def get_hf_processor(self, **kwargs: object) -> SmolVLMProcessor:
return self.ctx.get_hf_processor(SmolVLMProcessor, **kwargs)
def _get_image_token(

View File

@ -178,13 +178,11 @@ class TarsierProcessingInfo(BaseProcessingInfo):
return get_vision_encoder_info(self.get_hf_config())
def get_hf_processor(self, **kwargs: object) -> TarsierProcessor:
hf_processor = self.ctx.get_hf_processor(TarsierProcessor, **kwargs)
# Patch for patch_size if needed (copied from vLLM LLaVA)
if hasattr(hf_processor,
'patch_size') and hf_processor.patch_size is None:
patch_size = self.get_vision_encoder_info().get_patch_size()
hf_processor.patch_size = patch_size
return hf_processor
vision_info = self.get_vision_encoder_info()
kwargs.setdefault("patch_size", vision_info.get_patch_size())
return self.ctx.get_hf_processor(TarsierProcessor, **kwargs)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}

View File

@ -48,7 +48,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.processor import cached_get_processor
from vllm.utils import is_list_of
from .interfaces import (SupportsLoRA, SupportsMultiModal, SupportsPP,
@ -189,10 +188,6 @@ class MultiModalProcessingInfo(BaseProcessingInfo):
image_tokens = mm_tokens["num_image_tokens"][0]
return image_tokens
def get_hf_processor(self):
processor = cached_get_processor(self.ctx.model_config.model)
return processor
def get_max_image_size(self):
return 10_000, 10_000 # hardcode for arbitrary very large size

View File

@ -71,13 +71,7 @@ UltravoxAudioInputs = Union[UltravoxAudioFeatureInputs,
class UltravoxProcessingInfo(BaseProcessingInfo):
def get_hf_processor(
self,
*,
# Ignored in initialization
sampling_rate: Optional[int] = None,
**kwargs: object,
) -> ProcessorMixin:
def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
config = self.ctx.model_config.hf_config
hf_processor = self.ctx.get_hf_processor(**kwargs)
@ -89,13 +83,9 @@ class UltravoxProcessingInfo(BaseProcessingInfo):
return hf_processor
def get_feature_extractor(
self,
*,
# Ignored in initialization
sampling_rate: Optional[int] = None,
) -> WhisperFeatureExtractor:
hf_processor = self.get_hf_processor(sampling_rate=sampling_rate)
def get_feature_extractor(self,
**kwargs: object) -> WhisperFeatureExtractor:
hf_processor = self.get_hf_processor(**kwargs)
audio_processor = hf_processor.audio_processor # type: ignore
feature_extractor = audio_processor.feature_extractor # type: ignore
assert isinstance(feature_extractor, WhisperFeatureExtractor)
@ -156,7 +146,7 @@ class UltravoxMultiModalProcessor(
audios = mm_data.pop("audios", [])
assert isinstance(audios, list)
feature_extractor = self.info.get_feature_extractor()
feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
mm_kwargs = dict(
**mm_kwargs,
sampling_rate=feature_extractor.sampling_rate,

View File

@ -623,23 +623,22 @@ class WhisperProcessingInfo(BaseProcessingInfo):
def get_hf_config(self) -> WhisperConfig:
return self.ctx.get_hf_config(WhisperConfig)
def get_hf_processor(self,
sampling_rate: Optional[int] = None
) -> WhisperProcessor:
# HACK: Transformers 4.53.0 has issue with whisper tokenizer to
def get_hf_processor(self, **kwargs: object) -> WhisperProcessor:
# HACK: Transformers 4.53.2 has an issue with the whisper tokenizer
# when initializing the processor. We use a monkeypatch to fix it here.
# See: https://github.com/vllm-project/vllm/issues/20224
processor_class = WhisperProcessor
tokenizer_class = ("WhisperTokenizer", "WhisperTokenizerFast")
if processor_class.tokenizer_class != tokenizer_class:
processor_class.tokenizer_class = tokenizer_class
return self.ctx.get_hf_processor(processor_class)
return self.ctx.get_hf_processor(processor_class, **kwargs)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"audio": 1}
def get_feature_extractor(self) -> WhisperFeatureExtractor:
hf_processor = self.get_hf_processor()
def get_feature_extractor(self,
**kwargs: object) -> WhisperFeatureExtractor:
hf_processor = self.get_hf_processor(**kwargs)
feature_extractor = hf_processor.feature_extractor # type: ignore
assert isinstance(feature_extractor, WhisperFeatureExtractor)
return feature_extractor
@ -702,7 +701,7 @@ class WhisperMultiModalProcessor(
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
if mm_data:
feature_extractor = self.info.get_feature_extractor()
feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
mm_data = dict(audio=mm_data.pop("audios"))
mm_kwargs = dict(
**mm_kwargs,

View File

@ -4,9 +4,15 @@
from functools import lru_cache
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from transformers import (AutoFeatureExtractor, AutoImageProcessor,
AutoProcessor)
from transformers.feature_extraction_utils import FeatureExtractionMixin
from transformers.image_processing_utils import BaseImageProcessor
from transformers.processing_utils import ProcessorMixin
from typing_extensions import TypeVar
from vllm.utils import get_allowed_kwarg_only_overrides
if TYPE_CHECKING:
from vllm.config import ModelConfig
@ -33,23 +39,42 @@ class HashableList(list):
return hash(tuple(self))
def _merge_mm_kwargs(model_config: "ModelConfig", **kwargs):
mm_config = model_config.get_multimodal_config()
base_kwargs = mm_config.mm_processor_kwargs
if base_kwargs is None:
base_kwargs = {}
def _get_processor_factory_fn(processor_cls: Union[type, tuple[type, ...]]):
if isinstance(processor_cls, tuple) or processor_cls == ProcessorMixin:
return AutoProcessor.from_pretrained
if hasattr(processor_cls, "from_pretrained"):
return processor_cls.from_pretrained
merged_kwargs = {**base_kwargs, **kwargs}
return processor_cls
def _merge_mm_kwargs(
model_config: "ModelConfig",
processor_cls: Union[type, tuple[type, ...]],
/,
**kwargs,
):
mm_config = model_config.get_multimodal_config()
merged_kwargs = mm_config.merge_mm_processor_kwargs(kwargs)
factory = _get_processor_factory_fn(processor_cls)
allowed_kwargs = get_allowed_kwarg_only_overrides(
factory,
merged_kwargs,
requires_kw_only=False,
allow_var_kwargs=True,
)
# NOTE: A plain Python dict is not hashable and would raise an
# "unhashable type" error when calling `cached_get_processor`,
# so we need to wrap it in a hashable dict.
for key, value in merged_kwargs.items():
for key, value in allowed_kwargs.items():
if isinstance(value, dict):
merged_kwargs[key] = HashableDict(value)
allowed_kwargs[key] = HashableDict(value)
if isinstance(value, list):
merged_kwargs[key] = HashableList(value)
return merged_kwargs
allowed_kwargs[key] = HashableList(value)
return allowed_kwargs
def get_processor(
@ -61,21 +86,29 @@ def get_processor(
**kwargs: Any,
) -> _P:
"""Load a processor for the given model name via HuggingFace."""
# don't put this import at the top level
# it will call torch.cuda.device_count()
from transformers import AutoProcessor
processor_factory = (AutoProcessor if processor_cls == ProcessorMixin or
isinstance(processor_cls, tuple) else processor_cls)
if revision is None:
revision = "main"
try:
processor = processor_factory.from_pretrained(
processor_name,
*args,
revision=revision,
trust_remote_code=trust_remote_code,
**kwargs,
)
if isinstance(processor_cls, tuple) or processor_cls == ProcessorMixin:
processor = AutoProcessor.from_pretrained(
processor_name,
*args,
revision=revision,
trust_remote_code=trust_remote_code,
**kwargs,
)
elif issubclass(processor_cls, ProcessorMixin):
processor = processor_cls.from_pretrained(
processor_name,
*args,
revision=revision,
trust_remote_code=trust_remote_code,
**kwargs,
)
else:
# Processors that are standalone classes unrelated to HF
processor = processor_cls(*args, **kwargs)
except ValueError as e:
# If the error pertains to the processor class not existing or not
# currently being imported, suggest using the --trust-remote-code flag.
@ -112,7 +145,7 @@ def cached_processor_from_config(
revision=model_config.revision,
trust_remote_code=model_config.trust_remote_code,
processor_cls=processor_cls, # type: ignore[arg-type]
**_merge_mm_kwargs(model_config, **kwargs),
**_merge_mm_kwargs(model_config, processor_cls, **kwargs),
)
@ -125,10 +158,6 @@ def get_feature_extractor(
):
"""Load an audio feature extractor for the given model name
via HuggingFace."""
# don't put this import at the top level
# it will call torch.cuda.device_count()
from transformers import AutoFeatureExtractor
from transformers.feature_extraction_utils import FeatureExtractionMixin
try:
feature_extractor = AutoFeatureExtractor.from_pretrained(
processor_name,
@ -164,7 +193,7 @@ def cached_feature_extractor_from_config(
model_config.model,
revision=model_config.revision,
trust_remote_code=model_config.trust_remote_code,
**_merge_mm_kwargs(model_config, **kwargs),
**_merge_mm_kwargs(model_config, AutoFeatureExtractor, **kwargs),
)
@ -176,11 +205,6 @@ def get_image_processor(
**kwargs: Any,
):
"""Load an image processor for the given model name via HuggingFace."""
# don't put this import at the top level
# it will call torch.cuda.device_count()
from transformers import AutoImageProcessor
from transformers.image_processing_utils import BaseImageProcessor
try:
processor = AutoImageProcessor.from_pretrained(
processor_name,
@ -217,5 +241,5 @@ def cached_image_processor_from_config(
model_config.model,
revision=model_config.revision,
trust_remote_code=model_config.trust_remote_code,
**_merge_mm_kwargs(model_config, **kwargs),
**_merge_mm_kwargs(model_config, AutoImageProcessor, **kwargs),
)
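The hashability NOTE in this file exists because cached_get_processor and its siblings are backed by lru_cache, so every kwarg value must be hashable to form the cache key. A standalone illustration of the wrapper trick (the __hash__ of the dict wrapper below is an assumption; only the list variant's implementation appears in this file):

from functools import lru_cache


class HashableDict(dict):
    """dict that can be passed to an lru_cache'd function."""

    def __hash__(self) -> int:
        return hash(frozenset(self.items()))  # assumed implementation


class HashableList(list):
    """list that can be passed to an lru_cache'd function."""

    def __hash__(self) -> int:
        return hash(tuple(self))


@lru_cache
def load_processor(name: str, **kwargs):
    # Stand-in for cached_get_processor; plain dict/list kwargs would
    # raise "TypeError: unhashable type" when the cache key is built.
    return (name, tuple(sorted(kwargs)))


load_processor("some/model",
               size=HashableDict(shortest_edge=336),
               factors=HashableList([0.25, 0.5]))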

View File

@ -2010,49 +2010,6 @@ def supports_kw(
return False
def resolve_mm_processor_kwargs(
init_kwargs: Optional[Mapping[str, object]],
inference_kwargs: Optional[Mapping[str, object]],
callable: Callable[..., object],
*,
requires_kw_only: bool = True,
allow_var_kwargs: bool = False,
) -> dict[str, Any]:
"""Applies filtering to eliminate invalid mm_processor_kwargs, i.e.,
those that are not explicit keywords to the given callable (if one is
given; otherwise no filtering is done), then merges the kwarg dicts,
giving priority to inference_kwargs if there are any collisions.
In the case that no kwarg overrides are provided, returns an empty
dict so that it can still be kwarg expanded into the callable later on.
If allow_var_kwargs=True, allows values that can be expanded into
kwargs as long as they don't cause a naming collision with var_kwargs or
potential positional arguments.
"""
# Filter inference time multimodal processor kwargs provided
runtime_mm_kwargs = get_allowed_kwarg_only_overrides(
callable,
overrides=inference_kwargs,
requires_kw_only=requires_kw_only,
allow_var_kwargs=allow_var_kwargs,
)
# Filter init time multimodal processor kwargs provided
init_mm_kwargs = get_allowed_kwarg_only_overrides(
callable,
overrides=init_kwargs,
requires_kw_only=requires_kw_only,
allow_var_kwargs=allow_var_kwargs,
)
# Merge the final processor kwargs, prioritizing inference
# time values over the initialization time values.
mm_processor_kwargs = {**init_mm_kwargs, **runtime_mm_kwargs}
return mm_processor_kwargs
def get_allowed_kwarg_only_overrides(
callable: Callable[..., object],
overrides: Optional[Mapping[str, object]],