[Misc] Move processing context to multimodal directory (#25548)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

Author: Cyrus Leung
Committed: 2025-09-24 16:15:00 +08:00 (via GitHub)
Parent: 27ec3c78f3
Commit: 6488f3481b

13 changed files with 262 additions and 242 deletions
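In short, this commit retires `InputContext` and moves `InputProcessingContext` from `vllm.inputs` to `vllm.multimodal.processing`, as the diffs below show. A hypothetical compatibility shim (not part of the commit) that downstream code could use across versions might look like:

```python
# Hypothetical shim, not in this commit: prefer the new import path,
# fall back to the pre-change location if it is unavailable.
try:
    from vllm.multimodal.processing import InputProcessingContext
except ImportError:
    from vllm.inputs import InputProcessingContext  # old location
```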

View File

@@ -12,11 +12,11 @@ from mistral_common.protocol.instruct.request import ChatCompletionRequest
from PIL import Image
from vllm.config import ModelConfig
from vllm.inputs import InputProcessingContext
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
from vllm.multimodal.cache import MultiModalProcessorOnlyCache
from vllm.multimodal.inputs import MultiModalInputs
from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.multimodal.processing import (BaseMultiModalProcessor,
InputProcessingContext)
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
cached_tokenizer_from_config,
encode_tokens)

View File

@@ -18,10 +18,10 @@ from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
from vllm.distributed import (cleanup_dist_env_and_memory,
init_distributed_environment,
initialize_model_parallel)
from vllm.inputs import InputProcessingContext
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensorInputs
from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.multimodal.processing import (BaseMultiModalProcessor,
InputProcessingContext)
from vllm.multimodal.utils import group_mm_kwargs_by_modality
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from vllm.utils import is_list_of

View File

@@ -11,8 +11,9 @@ import torch.nn.functional as F
from transformers import PretrainedConfig
from vllm.config import ModelConfig, ModelDType, RunnerOption
from vllm.inputs import InputContext
from vllm.logprobs import Logprob, PromptLogprobs, SampleLogprobs
from vllm.multimodal.processing import InputProcessingContext
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from .registry import HF_EXAMPLE_MODELS
@@ -264,7 +265,7 @@ def build_model_context(
limit_mm_per_prompt: Optional[dict[str, int]] = None,
mm_processor_cache_gb: int = 0,
):
"""Creates an InputContext for a given model.
"""Creates an InputProcessingContext for a given model.
Args:
model_id: ID of the model being considered.
@@ -273,7 +274,7 @@
limit_mm_per_prompt: Multimodal limits.
Returns:
InputContext for the model being considered.
InputProcessingContext for the model being considered.
"""
model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
model_info.check_available_online(on_fail="skip")
@@ -298,7 +299,11 @@
enforce_eager=model_info.enforce_eager,
**model_config_kwargs,
)
return InputContext(model_config)
return InputProcessingContext(
model_config,
tokenizer=cached_tokenizer_from_config(model_config),
)
def check_embeddings_close(
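Since `build_model_context` (defined in this test helper, above) now returns an `InputProcessingContext` that already carries a tokenizer, a test might use it roughly as follows; the model ID is illustrative:

```python
# Sketch of using the updated helper; model ID is an arbitrary example.
from vllm.multimodal.processing import InputProcessingContext

ctx = build_model_context("llava-hf/llava-1.5-7b-hf")
assert isinstance(ctx, InputProcessingContext)

# The context bundles its own tokenizer, so no tokenizer argument is
# needed when fetching the HF processor.
processor = ctx.get_hf_processor()
```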

View File

@@ -8,11 +8,11 @@ import numpy as np
import pytest
from vllm.config import ModelConfig
from vllm.inputs import InputProcessingContext
from vllm.multimodal import MULTIMODAL_REGISTRY
# yapf conflicts with isort for this block
# yapf: disable
from vllm.multimodal.processing import (PlaceholderFeaturesInfo,
from vllm.multimodal.processing import (InputProcessingContext,
PlaceholderFeaturesInfo,
PromptIndexTargets, PromptInsertion,
PromptReplacement, apply_text_matches,
apply_token_matches,

View File

@@ -7,7 +7,6 @@ from .data import (DataPrompt, DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt,
SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt,
build_explicit_enc_dec_prompt, embeds_inputs,
to_enc_dec_tuple_list, token_inputs, zip_enc_dec_prompts)
from .registry import InputContext, InputProcessingContext
__all__ = [
"DataPrompt",
@@ -28,6 +27,4 @@ __all__ = [
"build_explicit_enc_dec_prompt",
"to_enc_dec_tuple_list",
"zip_enc_dec_prompts",
"InputContext",
"InputProcessingContext",
]

View File

@@ -1,206 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from collections.abc import Mapping
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Union
import torch
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
from typing_extensions import TypeVar
from vllm.logger import init_logger
from vllm.transformers_utils.processor import cached_processor_from_config
from vllm.utils import get_allowed_kwarg_only_overrides
from vllm.utils.jsontree import JSONTree, json_map_leaves
if TYPE_CHECKING:
from vllm.config import ModelConfig
from vllm.transformers_utils.tokenizer import AnyTokenizer
else:
ModelConfig = Any
AnyTokenizer = Any
_T = TypeVar("_T")
_C = TypeVar("_C", bound=PretrainedConfig, default=PretrainedConfig)
_P = TypeVar("_P", bound=ProcessorMixin, default=ProcessorMixin)
logger = init_logger(__name__)
@dataclass(frozen=True)
class InputContext:
"""
Contains information about the model which may be used to
modify the inputs.
"""
model_config: ModelConfig
"""The configuration of the model."""
def get_hf_config(
self,
typ: Union[type[_C], tuple[type[_C], ...]] = PretrainedConfig,
/,
) -> _C:
"""
Get the HuggingFace configuration
(`transformers.PretrainedConfig`) of the model,
additionally checking its type.
Raises:
TypeError: If the configuration is not of the specified type.
"""
hf_config = self.model_config.hf_config
if not isinstance(hf_config, typ):
raise TypeError("Invalid type of HuggingFace config. "
f"Expected type: {typ}, but "
f"found type: {type(hf_config)}")
return hf_config
def get_hf_image_processor_config(self) -> dict[str, Any]:
"""
Get the HuggingFace image processor configuration of the model.
"""
return self.model_config.hf_image_processor_config
def get_mm_config(self):
"""
Get the multimodal config of the model.
Raises:
RuntimeError: If the model is not a multimodal model.
"""
mm_config = self.model_config.multimodal_config
if mm_config is None:
raise RuntimeError("Not a multimodal model")
return mm_config
def get_hf_processor(
self,
typ: Union[type[_P], tuple[type[_P], ...]] = ProcessorMixin,
/,
**kwargs: object,
) -> _P:
"""
Get the HuggingFace processor
(`transformers.ProcessorMixin`) of the model,
additionally checking its type.
Raises:
TypeError: If the processor is not of the specified type.
"""
return cached_processor_from_config(
self.model_config,
processor_cls=typ,
**kwargs,
)
def init_processor(
self,
typ: type[_T],
/,
**kwargs: object,
) -> _T:
"""
Initialize a HuggingFace-like processor class, merging the
keyword arguments with those in the model's configuration.
"""
mm_config = self.model_config.get_multimodal_config()
base_kwargs = mm_config.mm_processor_kwargs
if base_kwargs is None:
base_kwargs = {}
merged_kwargs = {**base_kwargs, **kwargs}
return typ(**merged_kwargs)
@dataclass(frozen=True)
class InputProcessingContext(InputContext):
tokenizer: AnyTokenizer
"""The tokenizer used to tokenize the inputs."""
def get_hf_processor(
self,
typ: Union[type[_P], tuple[type[_P], ...]] = ProcessorMixin,
/,
**kwargs: object,
) -> _P:
return super().get_hf_processor(
typ,
tokenizer=self.tokenizer,
**kwargs,
)
def call_hf_processor(
self,
hf_processor: ProcessorMixin,
data: Mapping[str, object],
kwargs: Mapping[str, object] = {},
*,
num_tries: int = 1,
max_tries: int = 5,
) -> Union[BatchFeature, JSONTree]:
"""
Call `hf_processor` on the prompt `data`
(text, image, audio...) with configurable options `kwargs`.
"""
assert callable(hf_processor)
mm_config = self.model_config.get_multimodal_config()
merged_kwargs = mm_config.merge_mm_processor_kwargs(kwargs)
allowed_kwargs = get_allowed_kwarg_only_overrides(
hf_processor,
merged_kwargs,
requires_kw_only=False,
allow_var_kwargs=True,
)
def maybe_cast_dtype(x):
# This mimics the behavior of transformers.BatchFeature
if isinstance(x, torch.Tensor) and x.is_floating_point():
return x.to(dtype=self.model_config.dtype)
return x
try:
output = hf_processor(**data,
**allowed_kwargs,
return_tensors="pt")
# this emulates output.to(dtype=self.model_config.dtype)
if isinstance(output, BatchFeature):
cast_output = json_map_leaves(maybe_cast_dtype, output.data)
return BatchFeature(cast_output)
cast_output = json_map_leaves(maybe_cast_dtype, output)
logger.warning_once(
f"{type(hf_processor).__name__} did not return `BatchFeature`. "
"Make sure to match the behaviour of `ProcessorMixin` when "
"implementing custom processors.")
return cast_output
except Exception as exc:
# See https://github.com/huggingface/tokenizers/issues/537
if (isinstance(exc, RuntimeError) and exc
and exc.args[0] == "Already borrowed"
and num_tries < max_tries):
logger.warning(
"Failed to acquire tokenizer in current thread. "
"Retrying (%d/%d)...", num_tries, max_tries)
time.sleep(0.5)
return self.call_hf_processor(
hf_processor,
data,
kwargs,
num_tries=num_tries + 1,
max_tries=max_tries,
)
msg = (f"Failed to apply {type(hf_processor).__name__} "
f"on data={data} with kwargs={allowed_kwargs}")
raise ValueError(msg) from exc
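The `maybe_cast_dtype` helper in this removed file (preserved as `_postprocess_output` in the new location, further below) emulates `BatchFeature.to(dtype=...)` over arbitrarily nested outputs via `json_map_leaves`. A standalone sketch of the cast rule, with a flat dict standing in for the nested tree:

```python
import torch

def maybe_cast_dtype(x, target_dtype=torch.bfloat16):
    # Mimics transformers.BatchFeature.to(dtype=...): only floating-point
    # tensors are cast; integer tensors (e.g. input_ids) pass through.
    if isinstance(x, torch.Tensor) and x.is_floating_point():
        return x.to(dtype=target_dtype)
    return x

out = {
    "pixel_values": torch.rand(1, 3, 4, 4),   # float32 -> bfloat16
    "input_ids": torch.tensor([[1, 2, 3]]),   # int64, untouched
}
out = {k: maybe_cast_dtype(v) for k, v in out.items()}
assert out["pixel_values"].dtype == torch.bfloat16
assert out["input_ids"].dtype == torch.int64
```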

View File

@@ -29,7 +29,6 @@ from transformers import BatchFeature, CLIPVisionConfig, SiglipVisionConfig
from transformers.modeling_utils import no_init_weights
from vllm.config import VllmConfig
from vllm.inputs import InputProcessingContext
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import BaseMultiModalProcessorCache
@@ -37,8 +36,9 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargsItems)
from vllm.multimodal.parse import ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate)
BaseProcessingInfo,
InputProcessingContext,
PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors

View File

@@ -15,7 +15,6 @@ from transformers.models.llava import LlavaProcessor
from transformers.models.pixtral import PixtralProcessor
from vllm.config import VllmConfig
from vllm.inputs import InputProcessingContext
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
@@ -28,8 +27,10 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate, PromptUpdateDetails)
BaseProcessingInfo,
InputProcessingContext,
PromptReplacement, PromptUpdate,
PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.jsontree import json_map_leaves

View File

@@ -13,7 +13,6 @@ from transformers import (BatchFeature, Mistral3Config, PixtralVisionConfig,
from transformers.models.pixtral import PixtralProcessor
from vllm.config import VllmConfig
from vllm.inputs import InputProcessingContext
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -27,8 +26,10 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate, PromptUpdateDetails)
BaseProcessingInfo,
InputProcessingContext,
PromptReplacement, PromptUpdate,
PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape

View File

@@ -32,7 +32,6 @@ from transformers.models.llama4.image_processing_llama4_fast import (
from vllm.attention.layer import MultiHeadAttention
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs import InputProcessingContext
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
@@ -47,8 +46,10 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate, PromptUpdateDetails)
BaseProcessingInfo,
InputProcessingContext,
PromptReplacement, PromptUpdate,
PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape

View File

@@ -17,7 +17,6 @@ from transformers.processing_utils import ProcessingKwargs, Unpack
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from vllm.config import VllmConfig
from vllm.inputs import InputProcessingContext
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
@@ -29,8 +28,9 @@ from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargsItems
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate)
BaseProcessingInfo,
InputProcessingContext,
PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.jsontree import json_map_leaves

View File

@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping,
@@ -7,18 +8,20 @@ from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping,
from dataclasses import dataclass, field, replace
from enum import Enum
from functools import lru_cache
from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol,
TypeVar, Union, cast)
from typing import (TYPE_CHECKING, Any, Generic, NamedTuple, Optional,
Protocol, Union, cast, overload)
import regex as re
import torch
from typing_extensions import assert_never
from typing_extensions import TypeVar, assert_never
from vllm.inputs import InputProcessingContext
from vllm.logger import init_logger
from vllm.transformers_utils.processor import cached_processor_from_config
from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
encode_tokens)
from vllm.utils import flatten_2d_lists, full_groupby
from vllm.utils import (flatten_2d_lists, full_groupby,
get_allowed_kwarg_only_overrides)
from vllm.utils.jsontree import JSONTree, json_map_leaves
from .hasher import MultiModalHasher
from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
@@ -34,6 +37,8 @@ if TYPE_CHECKING:
from transformers.feature_extraction_utils import BatchFeature
from transformers.processing_utils import ProcessorMixin
from vllm.config import ModelConfig
from .cache import BaseMultiModalProcessorCache
from .profiling import BaseDummyInputsBuilder
@@ -875,6 +880,222 @@ def find_mm_placeholders(
return dict(full_groupby_modality(it))
_T = TypeVar("_T")
_C = TypeVar("_C", bound="PretrainedConfig", default="PretrainedConfig")
_P = TypeVar("_P", bound="ProcessorMixin", default="ProcessorMixin")
@dataclass(frozen=True)
class InputProcessingContext:
"""
Contains information about the model which may be used to
modify the inputs.
"""
model_config: "ModelConfig"
"""The configuration of the model."""
tokenizer: AnyTokenizer
"""The tokenizer used to tokenize the inputs."""
@overload
def get_hf_config(self, /) -> "PretrainedConfig":
...
@overload
def get_hf_config(
self,
typ: Union[type[_C], tuple[type[_C], ...]],
/,
) -> _C:
...
def get_hf_config(
self,
typ: Optional[Union[type[Any], tuple[type[Any], ...]]] = None,
/,
) -> Any:
"""
Get the HuggingFace configuration
(`transformers.PretrainedConfig`) of the model,
additionally checking its type.
Raises:
TypeError: If the configuration is not of the specified type.
"""
if typ is None:
from transformers.configuration_utils import PretrainedConfig
typ = PretrainedConfig
hf_config = self.model_config.hf_config
if not isinstance(hf_config, typ):
raise TypeError("Invalid type of HuggingFace config. "
f"Expected type: {typ}, but "
f"found type: {type(hf_config)}")
return hf_config
def get_hf_image_processor_config(self) -> dict[str, Any]:
"""
Get the HuggingFace image processor configuration of the model.
"""
return self.model_config.hf_image_processor_config
def get_mm_config(self):
"""
Get the multimodal config of the model.
Raises:
RuntimeError: If the model is not a multimodal model.
"""
mm_config = self.model_config.multimodal_config
if mm_config is None:
raise RuntimeError("Not a multimodal model")
return mm_config
@overload
def get_hf_processor(self, /, **kwargs: object) -> "ProcessorMixin":
...
@overload
def get_hf_processor(
self,
typ: Union[type[_P], tuple[type[_P], ...]],
/,
**kwargs: object,
) -> _P:
...
def get_hf_processor(
self,
typ: Optional[Union[type[Any], tuple[type[Any], ...]]] = None,
/,
**kwargs: object,
) -> Any:
"""
Get the HuggingFace processor
(`transformers.ProcessorMixin`) of the model,
additionally checking its type.
Raises:
TypeError: If the processor is not of the specified type.
"""
if typ is None:
from transformers.processing_utils import ProcessorMixin
typ = ProcessorMixin
return cached_processor_from_config(
self.model_config,
processor_cls=typ,
tokenizer=self.tokenizer,
**kwargs,
)
def init_processor(
self,
typ: type[_T],
/,
**kwargs: object,
) -> _T:
"""
Initialize a HuggingFace-like processor class, merging the
keyword arguments with those in the model's configuration.
"""
mm_config = self.model_config.get_multimodal_config()
base_kwargs = mm_config.mm_processor_kwargs
if base_kwargs is None:
base_kwargs = {}
merged_kwargs = {**base_kwargs, **kwargs}
return typ(**merged_kwargs)
def _postprocess_output(
self,
output: JSONTree,
) -> JSONTree:
def _postprocess_one(x: object):
if isinstance(x, torch.Tensor): # noqa: SIM102
# This mimics the behavior of transformers.BatchFeature
if x.is_floating_point():
x = x.to(dtype=self.model_config.dtype)
return x
return json_map_leaves(_postprocess_one, output)
def call_hf_processor(
self,
hf_processor: "ProcessorMixin",
data: Mapping[str, object],
kwargs: Mapping[str, object] = {},
*,
num_tries: int = 1,
max_tries: int = 5,
) -> Union["BatchFeature", JSONTree]:
"""
Call `hf_processor` on the prompt `data`
(text, image, audio...) with configurable options `kwargs`.
"""
assert callable(hf_processor)
mm_config = self.model_config.get_multimodal_config()
merged_kwargs = mm_config.merge_mm_processor_kwargs(kwargs)
allowed_kwargs = get_allowed_kwarg_only_overrides(
hf_processor,
merged_kwargs,
requires_kw_only=False,
allow_var_kwargs=True,
)
try:
output = hf_processor(**data,
**allowed_kwargs,
return_tensors="pt")
except Exception as exc:
# See https://github.com/huggingface/tokenizers/issues/537
if (isinstance(exc, RuntimeError) and exc
and exc.args[0] == "Already borrowed"
and num_tries < max_tries):
logger.warning(
"Failed to acquire tokenizer in current thread. "
"Retrying (%d/%d)...", num_tries, max_tries)
time.sleep(0.5)
return self.call_hf_processor(
hf_processor,
data,
kwargs,
num_tries=num_tries + 1,
max_tries=max_tries,
)
msg = (f"Failed to apply {type(hf_processor).__name__} "
f"on data={data} with kwargs={allowed_kwargs}")
raise ValueError(msg) from exc
# this emulates output.to(dtype=self.model_config.dtype)
from transformers.feature_extraction_utils import BatchFeature
if isinstance(output, BatchFeature):
output_ = self._postprocess_output(output.data)
return BatchFeature(output_)
logger.warning_once(
"%s did not return `BatchFeature`. "
"Make sure to match the behaviour of `ProcessorMixin` when "
"implementing custom processors.",
type(hf_processor).__name__,
)
return self._postprocess_output(output)
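The retry loop in `call_hf_processor` above works around https://github.com/huggingface/tokenizers/issues/537: the Rust tokenizer can raise `RuntimeError("Already borrowed")` under thread contention. A generic distillation of that pattern (function and parameter names here are illustrative, not vLLM's):

```python
import time

def call_with_retry(fn, /, *args, max_tries: int = 5, **kwargs):
    for num_tries in range(1, max_tries + 1):
        try:
            return fn(*args, **kwargs)
        except RuntimeError as exc:
            # Back off briefly on tokenizer borrow contention, then retry;
            # re-raise anything else (or once retries are exhausted).
            if exc.args and exc.args[0] == "Already borrowed" \
                    and num_tries < max_tries:
                time.sleep(0.5)
                continue
            raise
```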
class BaseProcessingInfo:
"""Base class to provide the information necessary for data processing."""

View File

@@ -6,14 +6,14 @@ from typing import TYPE_CHECKING, Generic, Optional, Protocol, TypeVar
import torch.nn as nn
from vllm.inputs import InputProcessingContext
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import (AnyTokenizer,
cached_tokenizer_from_config)
from vllm.utils import ClassRegistry
from .cache import BaseMultiModalProcessorCache
from .processing import BaseMultiModalProcessor, BaseProcessingInfo
from .processing import (BaseMultiModalProcessor, BaseProcessingInfo,
InputProcessingContext)
from .profiling import (BaseDummyInputsBuilder, DummyDecoderData,
DummyEncoderData, MultiModalProfiler)
@@ -41,7 +41,7 @@ class ProcessingInfoFactory(Protocol[_I_co]):
...
class DummyInputsBuilderFactory(Protocol[_I]):
class DummyInputsBuilderFactory(Protocol[_I]): # type: ignore[misc]
"""
Constructs a
[`BaseDummyInputsBuilder`][vllm.multimodal.profiling.BaseDummyInputsBuilder]
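The `# type: ignore[misc]` added to `DummyInputsBuilderFactory` presumably silences mypy's complaint about an invariant TypeVar being used in a `Protocol`. An illustrative standalone sketch of the factory-protocol pattern, written with a covariant TypeVar so it type-checks without the ignore (all names here are made up, not vLLM's):

```python
from typing import Protocol, TypeVar

class ProcessingInfo:
    def __init__(self, ctx: object) -> None:
        self.ctx = ctx

# Covariant TypeVar: a factory of a subtype is a factory of the base.
_I_co = TypeVar("_I_co", bound=ProcessingInfo, covariant=True)

class InfoFactory(Protocol[_I_co]):
    def __call__(self, ctx: object) -> _I_co: ...

# A class object itself satisfies the protocol structurally.
make_info: InfoFactory[ProcessingInfo] = ProcessingInfo
info = make_info(ctx=object())
```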