# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Mapping
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union

import torch
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
from typing_extensions import TypeVar

from vllm.logger import init_logger
from vllm.transformers_utils.processor import cached_processor_from_config
from vllm.utils import get_allowed_kwarg_only_overrides
from vllm.utils.jsontree import JSONTree, json_map_leaves

if TYPE_CHECKING:
    from vllm.config import ModelConfig
    from vllm.multimodal import (MultiModalDataDict, MultiModalPlaceholderDict,
                                 MultiModalRegistry)
    from vllm.sequence import SequenceData
    from vllm.transformers_utils.tokenizer import AnyTokenizer
else:
    # At runtime, alias the type-only imports to `Any` so this module does
    # not pull in heavyweight dependencies outside of type checking.
    ModelConfig = Any
    MultiModalDataDict = Any
    MultiModalPlaceholderDict = Any
    MultiModalRegistry = Any
    SequenceData = Any
    AnyTokenizer = Any

_T = TypeVar("_T")
_C = TypeVar("_C", bound=PretrainedConfig, default=PretrainedConfig)
_P = TypeVar("_P", bound=ProcessorMixin, default=ProcessorMixin)

logger = init_logger(__name__)


@dataclass(frozen=True)
class InputContext:
    """
    Contains information about the model which may be used to
    modify the inputs.
    """

    model_config: ModelConfig
    """The configuration of the model."""

    def get_hf_config(
        self,
        typ: Union[type[_C], tuple[type[_C], ...]] = PretrainedConfig,
        /,
    ) -> _C:
        """
        Get the HuggingFace configuration
        (`transformers.PretrainedConfig`) of the model,
        additionally checking its type.

        Raises:
            TypeError: If the configuration is not of the specified type.
        """
        hf_config = self.model_config.hf_config
        if not isinstance(hf_config, typ):
            raise TypeError("Invalid type of HuggingFace config. "
                            f"Expected type: {typ}, but "
                            f"found type: {type(hf_config)}")

        return hf_config
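
    # Usage sketch (illustrative, not part of the original file;
    # `LlavaConfig` is an assumed `PretrainedConfig` subclass):
    #
    #     ctx = InputContext(model_config)
    #     hf_config = ctx.get_hf_config(LlavaConfig)
    #     # Raises TypeError if the model's config is of a different type.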

    def get_hf_image_processor_config(self) -> dict[str, Any]:
        """
        Get the HuggingFace image processor configuration of the model.
        """
        return self.model_config.hf_image_processor_config

    def get_mm_config(self):
        """
        Get the multimodal config of the model.

        Raises:
            RuntimeError: If the model is not a multimodal model.
        """
        mm_config = self.model_config.multimodal_config
        if mm_config is None:
            raise RuntimeError("Not a multimodal model")

        return mm_config

    def get_hf_processor(
        self,
        typ: Union[type[_P], tuple[type[_P], ...]] = ProcessorMixin,
        /,
        **kwargs: object,
    ) -> _P:
        """
        Get the HuggingFace processor
        (`transformers.ProcessorMixin`) of the model,
        additionally checking its type.

        Raises:
            TypeError: If the processor is not of the specified type.
        """
        return cached_processor_from_config(
            self.model_config,
            processor_cls=typ,
            **kwargs,
        )
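
    # Usage sketch (illustrative, not part of the original file;
    # `LlavaProcessor` is an assumed `ProcessorMixin` subclass):
    #
    #     processor = ctx.get_hf_processor(LlavaProcessor)
    #     # Raises TypeError if the cached processor has a different type.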

    def init_processor(
        self,
        typ: type[_T],
        /,
        **kwargs: object,
    ) -> _T:
        """
        Initialize a HuggingFace-like processor class, merging the
        keyword arguments with those in the model's configuration.
        """
        mm_config = self.model_config.get_multimodal_config()
        base_kwargs = mm_config.mm_processor_kwargs
        if base_kwargs is None:
            base_kwargs = {}

        # Explicitly passed kwargs take precedence over
        # `mm_processor_kwargs` from the model's configuration.
        merged_kwargs = {**base_kwargs, **kwargs}

        return typ(**merged_kwargs)
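
    # Illustrative sketch (assumed names, not part of the original file):
    # with mm_processor_kwargs={"num_crops": 4} in the model config,
    #
    #     processor = ctx.init_processor(MyProcessor, num_crops=16)
    #
    # constructs MyProcessor(num_crops=16), since explicit kwargs win.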


@dataclass(frozen=True)
class InputProcessingContext(InputContext):
    tokenizer: AnyTokenizer
    """The tokenizer used to tokenize the inputs."""

    def get_hf_processor(
        self,
        typ: Union[type[_P], tuple[type[_P], ...]] = ProcessorMixin,
        /,
        **kwargs: object,
    ) -> _P:
        # Same as the base method, but always binds this context's tokenizer.
        return super().get_hf_processor(
            typ,
            tokenizer=self.tokenizer,
            **kwargs,
        )

    def call_hf_processor(
        self,
        hf_processor: ProcessorMixin,
        data: Mapping[str, object],
        kwargs: Mapping[str, object] = {},
    ) -> Union[BatchFeature, JSONTree]:
        """
        Call `hf_processor` on the prompt `data`
        (text, image, audio...) with configurable options `kwargs`.
        """
        assert callable(hf_processor)

        mm_config = self.model_config.get_multimodal_config()
        merged_kwargs = mm_config.merge_mm_processor_kwargs(kwargs)

        # Drop any merged kwargs that the processor's signature cannot accept.
        allowed_kwargs = get_allowed_kwarg_only_overrides(
            hf_processor,
            merged_kwargs,
            requires_kw_only=False,
            allow_var_kwargs=True,
        )

        def maybe_cast_dtype(x):
            # This mimics the behavior of transformers.BatchFeature
            if isinstance(x, torch.Tensor) and x.is_floating_point():
                return x.to(dtype=self.model_config.dtype)
            return x

        try:
            output = hf_processor(**data,
                                  **allowed_kwargs,
                                  return_tensors="pt")

            # this emulates output.to(dtype=self.model_config.dtype)
            if isinstance(output, BatchFeature):
                cast_output = json_map_leaves(maybe_cast_dtype, output.data)
                return BatchFeature(cast_output)

            cast_output = json_map_leaves(maybe_cast_dtype, output)
            logger.warning_once(
                f"{type(hf_processor).__name__} did not return `BatchFeature`. "
                "Make sure to match the behaviour of `ProcessorMixin` when "
                "implementing custom processors.")

            return cast_output
        except Exception as exc:
            msg = (f"Failed to apply {type(hf_processor).__name__} "
                   f"on data={data} with kwargs={allowed_kwargs}")
            raise ValueError(msg) from exc
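
    # Usage sketch (illustrative, not part of the original file; `image`
    # stands in for a PIL image or similar multimodal input):
    #
    #     processor = ctx.get_hf_processor()
    #     feats = ctx.call_hf_processor(
    #         processor,
    #         dict(text="<image>", images=image),
    #         dict(num_crops=16),  # forwarded only if the signature allows it
    #     )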


class DummyData(NamedTuple):
    """
    Dummy data used for profiling.

    Note: This is only used in V0.
    """

    seq_data: SequenceData
    multi_modal_data: Optional[MultiModalDataDict] = None
    multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None


class InputRegistry:
    """
    Note: This is only used in V0.
    """

    def dummy_data_for_profiling(
        self,
        model_config: ModelConfig,
        seq_len: int,
        mm_registry: MultiModalRegistry,
        is_encoder_data: bool = False,
    ) -> DummyData:
        """
        Create dummy data for profiling the memory usage of a model.

        The model is identified by ``model_config``.
        """
        # Avoid circular import
        from vllm.multimodal.cache import processor_only_cache_from_config
        from vllm.sequence import SequenceData

        if not model_config.is_multimodal_model:
            seq_data = SequenceData.from_prompt_token_counts((0, seq_len))
            return DummyData(seq_data=seq_data)

        cache = processor_only_cache_from_config(model_config, mm_registry)

        # Encoder dummy data does not contain multi-modal data
        if is_encoder_data:
            enc_data = mm_registry.get_encoder_dummy_data(model_config,
                                                          seq_len,
                                                          cache=cache)
            seq_data = SequenceData.from_seqs(enc_data.prompt_token_ids)
            return DummyData(seq_data=seq_data)

        dec_data = mm_registry.get_decoder_dummy_data(model_config,
                                                      seq_len,
                                                      cache=cache)

        return DummyData(
            seq_data=SequenceData.from_seqs(dec_data.prompt_token_ids),
            multi_modal_data=dec_data.multi_modal_data.get_data(),
            multi_modal_placeholders=dec_data.multi_modal_placeholders,
        )
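
# Usage sketch (illustrative, not part of the original file):
#
#     registry = InputRegistry()
#     dummy = registry.dummy_data_for_profiling(
#         model_config, seq_len=4096, mm_registry=mm_registry)
#     # dummy.seq_data covers `seq_len` tokens; for multimodal models,
#     # dummy.multi_modal_data carries representative dummy inputs.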