Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[Core] Remove legacy input mapper/processor from V0 (#15686)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
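In short, this change removes the V0 legacy input mapper/processor: the engine now uses the `InputPreprocessor` output directly, and the registry-based per-model input processor step (and the multi-modal input mapper it fed) is deleted. Below is a minimal, hypothetical sketch of the call-path change using simplified stand-in classes; it is not the actual vLLM engine code, only an illustration of the before/after shape seen in the diff.

```python
class InputPreprocessor:
    """Stand-in for vllm.inputs.preprocess.InputPreprocessor (simplified)."""

    def preprocess(self, prompt, lora_request=None, prompt_adapter_request=None):
        # In vLLM this tokenizes/validates the prompt and, for multi-modal
        # models, applies the merged multi-modal processor.
        return {"type": "token", "prompt": prompt}


class LegacyEngine:
    """Before: preprocess, then apply the registry's per-model input processor."""

    def __init__(self, preprocessor, input_processor):
        self.input_preprocessor = preprocessor
        self.input_processor = input_processor  # created via InputRegistry

    def add_request(self, prompt):
        preprocessed = self.input_preprocessor.preprocess(prompt)
        return self.input_processor(preprocessed)  # extra legacy step


class NewEngine:
    """After: the preprocessor output is the processed input; no extra step."""

    def __init__(self, preprocessor):
        self.input_preprocessor = preprocessor

    def add_request(self, prompt):
        return self.input_preprocessor.preprocess(prompt)
```

The diff below applies this simplification across the scheduler, the sync/async engines, `vllm.inputs`, and `vllm.multimodal`.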
@@ -1596,7 +1596,6 @@ class Scheduler:
multi_modal_placeholders=(
seq_group.multi_modal_placeholders
if scheduler_outputs.num_prefill_groups > 0 else None),
mm_processor_kwargs=seq_group.mm_processor_kwargs,
prompt_adapter_request=seq_group.prompt_adapter_request,
)
else:
@@ -493,12 +493,11 @@ class _AsyncLLMEngine(LLMEngine):
tokenizer = await self.get_tokenizer_async(lora_request)
self._validate_token_prompt(prompt, tokenizer=tokenizer)

preprocessed_inputs = await self.input_preprocessor.preprocess_async(
processed_inputs = await self.input_preprocessor.preprocess_async(
prompt,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
)
processed_inputs = self.input_processor(preprocessed_inputs)

if isinstance(params, SamplingParams) and \
params.guided_decoding is not None:
@@ -29,8 +29,7 @@ from vllm.engine.output_processor.util import create_output_by_sequence_group
from vllm.entrypoints.openai.logits_processors import (
get_logits_processors as get_openai_logits_processors)
from vllm.executor.executor_base import ExecutorBase
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
PromptType, SingletonInputs)
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
from vllm.inputs.parse import is_token_prompt, split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
@@ -213,7 +212,6 @@ class LLMEngine:
log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
input_registry: InputRegistry = INPUT_REGISTRY,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
use_cached_outputs: bool = False,
) -> None:
@@ -274,11 +272,7 @@ class LLMEngine:
self.tokenizer,
mm_registry)

self.input_registry = input_registry
self.input_processor = input_registry.create_input_processor(
self.model_config)

self.model_executor = executor_class(vllm_config=vllm_config, )
self.model_executor = executor_class(vllm_config=vllm_config)

if self.model_config.runner_type != "pooling":
self._initialize_kv_caches()
@@ -762,12 +756,11 @@ class LLMEngine:
prompt,
tokenizer=self.get_tokenizer(lora_request=lora_request))

preprocessed_inputs = self.input_preprocessor.preprocess(
processed_inputs = self.input_preprocessor.preprocess(
prompt,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
)
processed_inputs = self.input_processor(preprocessed_inputs)

self._add_processed_request(
request_id=request_id,
@@ -2,10 +2,9 @@

from .data import (DecoderOnlyInputs, EncoderDecoderInputs,
ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType,
SingletonInputs, SingletonInputsAdapter, SingletonPrompt,
TextPrompt, TokenInputs, TokensPrompt,
build_explicit_enc_dec_prompt, to_enc_dec_tuple_list,
token_inputs, zip_enc_dec_prompts)
SingletonInputs, SingletonPrompt, TextPrompt, TokenInputs,
TokensPrompt, build_explicit_enc_dec_prompt,
to_enc_dec_tuple_list, token_inputs, zip_enc_dec_prompts)
from .registry import (DummyData, InputContext, InputProcessingContext,
InputRegistry)

@@ -27,7 +26,6 @@ __all__ = [
"EncoderDecoderInputs",
"ProcessorInputs",
"SingletonInputs",
"SingletonInputsAdapter",
"build_explicit_enc_dec_prompt",
"to_enc_dec_tuple_list",
"zip_enc_dec_prompts",
@@ -1,17 +1,11 @@
# SPDX-License-Identifier: Apache-2.0

from collections.abc import Iterable
from dataclasses import dataclass
from functools import cached_property
from typing import TYPE_CHECKING, Any, Generic, Literal, Optional, Union, cast

import torch
from typing_extensions import NotRequired, TypedDict, TypeVar, assert_never
from typing_extensions import NotRequired, TypedDict, TypeVar

if TYPE_CHECKING:
from vllm.multimodal import (MultiModalDataDict, MultiModalKwargs,
MultiModalPlaceholderDict)
from vllm.multimodal.inputs import MultiModalInputs
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalInputs


class TextPrompt(TypedDict):
@@ -147,46 +141,11 @@ class TokenInputs(TypedDict):
The original prompt text corresponding to the token IDs, if available.
"""

multi_modal_data: NotRequired["MultiModalDataDict"]
"""
Optional multi-modal data to pass to the model,
if the model supports it.
"""

multi_modal_inputs: NotRequired["MultiModalKwargs"]
"""
Optional multi-modal inputs to pass to the model,
if the model supports it.
"""

multi_modal_placeholders: NotRequired["MultiModalPlaceholderDict"]
"""
Placeholder ranges for the multi-modal data.
"""

multi_modal_hashes: NotRequired[list[str]]
"""
The hashes of the multi-modal data.
"""

mm_processor_kwargs: NotRequired[dict[str, Any]]
"""
Optional multi-modal processor kwargs to be forwarded to the
multimodal input mapper & processor. Note that if multiple modalities
have registered mappers etc for the model being considered, we attempt
to pass the mm_processor_kwargs to each of them.
"""


def token_inputs(
prompt_token_ids: list[int],
token_type_ids: Optional[list[int]] = None,
prompt: Optional[str] = None,
multi_modal_data: Optional["MultiModalDataDict"] = None,
multi_modal_inputs: Optional["MultiModalKwargs"] = None,
multi_modal_hashes: Optional[list[str]] = None,
multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None,
mm_processor_kwargs: Optional[dict[str, Any]] = None,
) -> TokenInputs:
"""Construct :class:`TokenInputs` from optional values."""
inputs = TokenInputs(type="token", prompt_token_ids=prompt_token_ids)
@@ -195,16 +154,6 @@ def token_inputs(
inputs["prompt"] = prompt
if token_type_ids is not None:
inputs["token_type_ids"] = token_type_ids
if multi_modal_data is not None:
inputs["multi_modal_data"] = multi_modal_data
if multi_modal_inputs is not None:
inputs["multi_modal_inputs"] = multi_modal_inputs
if multi_modal_hashes is not None:
inputs["multi_modal_hashes"] = multi_modal_hashes
if multi_modal_placeholders is not None:
inputs["multi_modal_placeholders"] = multi_modal_placeholders
if mm_processor_kwargs is not None:
inputs["mm_processor_kwargs"] = mm_processor_kwargs

return inputs

@@ -237,112 +186,6 @@ A processed :class:`SingletonPrompt` which can be passed to
:class:`vllm.sequence.Sequence`.
"""


@dataclass
class SingletonInputsAdapter:
"""
Unified interface to access the components of :class:`SingletonInputs`.
"""
inputs: SingletonInputs

@cached_property
def prompt(self) -> Optional[str]:
inputs = self.inputs

if inputs["type"] == "token" or inputs["type"] == "multimodal":
return inputs.get("prompt")

assert_never(inputs) # type: ignore[arg-type]

@cached_property
def prompt_token_ids(self) -> list[int]:
inputs = self.inputs

if inputs["type"] == "token" or inputs["type"] == "multimodal":
return inputs.get("prompt_token_ids", [])

assert_never(inputs) # type: ignore[arg-type]

@cached_property
def token_type_ids(self) -> list[int]:
inputs = self.inputs

if inputs["type"] == "token" or inputs["type"] == "multimodal":
return inputs.get("token_type_ids", [])

assert_never(inputs) # type: ignore[arg-type]

@cached_property
def prompt_embeds(self) -> Optional[torch.Tensor]:
inputs = self.inputs

if inputs["type"] == "token" or inputs["type"] == "multimodal":
return None

assert_never(inputs) # type: ignore[arg-type]

@cached_property
def multi_modal_data(self) -> "MultiModalDataDict":
inputs = self.inputs

if inputs["type"] == "token":
return inputs.get("multi_modal_data", {})

if inputs["type"] == "multimodal":
return inputs.get("mm_kwargs", {})

assert_never(inputs) # type: ignore[arg-type]

@cached_property
def multi_modal_inputs(self) -> Union[dict, "MultiModalKwargs"]:
inputs = self.inputs

if inputs["type"] == "token":
return inputs.get("multi_modal_inputs", {})

if inputs["type"] == "multimodal":
return inputs.get("mm_kwargs", {})

assert_never(inputs) # type: ignore[arg-type]

@cached_property
def multi_modal_hashes(self) -> list[str]:
inputs = self.inputs

if inputs["type"] == "token":
return inputs.get("multi_modal_hashes", [])

if inputs["type"] == "multimodal":
# only the case when we use MultiModalInputs
return inputs.get("mm_hashes", []) # type: ignore[return-value]

assert_never(inputs) # type: ignore[arg-type]

@cached_property
def multi_modal_placeholders(self) -> "MultiModalPlaceholderDict":
inputs = self.inputs

if inputs["type"] == "token":
return inputs.get("multi_modal_placeholders", {})

if inputs["type"] == "multimodal":
return inputs.get("mm_placeholders", {})

assert_never(inputs) # type: ignore[arg-type]

@cached_property
def mm_processor_kwargs(self) -> dict[str, Any]:
inputs = self.inputs

if inputs["type"] == "token":
return inputs.get("mm_processor_kwargs", {})

if inputs["type"] == "multimodal":
return {}

assert_never(inputs) # type: ignore[arg-type]


ProcessorInputs = Union[DecoderOnlyInputs, EncoderDecoderInputs]
"""
The inputs to :data:`vllm.inputs.InputProcessor`.
@@ -223,28 +223,6 @@ class InputPreprocessor:
lora_request=lora_request,
add_special_tokens=add_special_tokens)

def _can_process_multimodal(self) -> bool:
model_config = self.model_config

if not model_config.is_multimodal_model:
raise ValueError("Your model does not support multi-modal inputs")

# Interim measure so we can handle models that have yet to be
# updated to use the new multi-modal processor
can_process_multimodal = self.mm_registry.has_processor(model_config)
if not can_process_multimodal:
from vllm.model_executor.models.registry import _VLLM_MODELS
if not any(arch in _VLLM_MODELS
for arch in model_config.architectures):
logger.warning_once(
"Your model uses the legacy input pipeline, which will be "
"removed in an upcoming release. "
"Please upgrade to the new multi-modal processing pipeline "
"(https://docs.vllm.ai/en/latest/design/mm_processing.html)"
)

return can_process_multimodal

def _process_multimodal(
self,
prompt: Union[str, list[int]],
@@ -258,8 +236,7 @@ class InputPreprocessor:
returning the corresponding token IDs and metadata.
"""
# At the moment on model (PrithviGeoSpatialMAE) requires to be
# initialized without a tokenizer while using also multi-modal
# input.
# initialized without a tokenizer while using also multi-modal input
if not self.tokenizer:
tokenizer = object() # Dummy
else:
@@ -285,8 +262,7 @@
) -> MultiModalInputs:
"""Async version of :meth:`_process_multimodal`."""
# At the moment on model (PrithviGeoSpatialMAE) requires to be
# initialized without a tokenizer while using also multi-modal
# input.
# initialized without a tokenizer while using also multi-modal input
if not self.tokenizer:
tokenizer = object() # Dummy
else:
@@ -343,7 +319,7 @@
multi_modal_data = tokens_content.get("multi_modal_data")
mm_processor_kwargs = tokens_content.get("mm_processor_kwargs")

if multi_modal_data is not None and self._can_process_multimodal():
if multi_modal_data is not None:
return self._process_multimodal(
prompt_token_ids,
multi_modal_data,
@@ -355,8 +331,6 @@
return token_inputs(
prompt_token_ids=prompt_token_ids,
token_type_ids=token_type_ids,
multi_modal_data=multi_modal_data,
mm_processor_kwargs=mm_processor_kwargs,
)

if parsed["type"] == "text":
@@ -366,7 +340,7 @@
multi_modal_data = text_content.get("multi_modal_data")
mm_processor_kwargs = text_content.get("mm_processor_kwargs")

if multi_modal_data is not None and self._can_process_multimodal():
if multi_modal_data is not None:
return self._process_multimodal(
prompt_text,
multi_modal_data,
@@ -383,8 +357,6 @@
return token_inputs(
prompt=prompt_text,
prompt_token_ids=prompt_token_ids,
multi_modal_data=multi_modal_data,
mm_processor_kwargs=mm_processor_kwargs,
)

assert_never(parsed)
@@ -417,7 +389,7 @@
multi_modal_data = tokens_content.get("multi_modal_data")
mm_processor_kwargs = tokens_content.get("mm_processor_kwargs")

if multi_modal_data is not None and self._can_process_multimodal():
if multi_modal_data is not None:
return await self._process_multimodal_async(
prompt_token_ids,
multi_modal_data,
@@ -426,11 +398,7 @@
return_mm_hashes=return_mm_hashes,
)

return token_inputs(
prompt_token_ids=prompt_token_ids,
multi_modal_data=multi_modal_data,
mm_processor_kwargs=mm_processor_kwargs,
)
return token_inputs(prompt_token_ids=prompt_token_ids)

if parsed["type"] == "text":
text_content = parsed["content"]
@@ -439,7 +407,7 @@
multi_modal_data = text_content.get("multi_modal_data")
mm_processor_kwargs = text_content.get("mm_processor_kwargs")

if multi_modal_data is not None and self._can_process_multimodal():
if multi_modal_data is not None:
return await self._process_multimodal_async(
prompt_text,
multi_modal_data,
@@ -456,8 +424,6 @@
return token_inputs(
prompt=prompt_text,
prompt_token_ids=prompt_token_ids,
multi_modal_data=multi_modal_data,
mm_processor_kwargs=mm_processor_kwargs,
)

assert_never(parsed)
@@ -594,15 +560,13 @@
decoder_inputs = self._prompt_to_llm_inputs(decoder_input)
# For multimodal model, override decoder prompt from processor
# with explicit decoder prompt.
if self.model_config.is_multimodal_model and (
self._can_process_multimodal()):
if self.model_config.is_multimodal_model:
encoder_inputs, decoder_inputs = (
self._separate_enc_dec_inputs_from_mm_processor_outputs(
encoder_inputs, decoder_inputs))
else:
inputs = self._prompt_to_llm_inputs(prompt)
if self.model_config.is_multimodal_model and (
self._can_process_multimodal()):
if self.model_config.is_multimodal_model:
# Encoder-Decoder Multimodal model
encoder_inputs, decoder_inputs = (
self._separate_enc_dec_inputs_from_mm_processor_outputs(
@@ -637,15 +601,13 @@

# For multimodal model, override decoder prompt from processor
# with explicit decoder prompt.
if self.model_config.is_multimodal_model and (
self._can_process_multimodal()):
if self.model_config.is_multimodal_model:
encoder_inputs, decoder_inputs = (
self._separate_enc_dec_inputs_from_mm_processor_outputs(
encoder_inputs, decoder_inputs))
else:
inputs = await self._prompt_to_llm_inputs_async(prompt)
if self.model_config.is_multimodal_model and (
self._can_process_multimodal()):
if self.model_config.is_multimodal_model:
# Encoder-Decoder Multimodal model
encoder_inputs, decoder_inputs = (
self._separate_enc_dec_inputs_from_mm_processor_outputs(
@@ -1,24 +1,14 @@
# SPDX-License-Identifier: Apache-2.0

import functools
from collections import UserDict
from collections.abc import Mapping
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Callable, NamedTuple, Optional,
Protocol, Union)
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union

from torch import nn
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
from typing_extensions import TypeVar, assert_never
from typing_extensions import TypeVar

from vllm.logger import init_logger
from vllm.transformers_utils.processor import cached_processor_from_config
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides,
resolve_mm_processor_kwargs)

from .data import ProcessorInputs, SingletonInputs
from .parse import split_enc_dec_inputs
from vllm.utils import resolve_mm_processor_kwargs

if TYPE_CHECKING:
from vllm.config import ModelConfig
@@ -26,8 +16,6 @@ if TYPE_CHECKING:
MultiModalRegistry)
from vllm.sequence import SequenceData

logger = init_logger(__name__)

_T = TypeVar("_T")
_C = TypeVar("_C", bound=PretrainedConfig, default=PretrainedConfig)
_P = TypeVar("_P", bound=ProcessorMixin, default=ProcessorMixin)
@@ -172,142 +160,23 @@ class InputProcessingContext(InputContext):
raise RuntimeError(msg) from exc


N = TypeVar("N", bound=type[nn.Module])


class DummyData(NamedTuple):
"""Dummy data used for profiling."""
"""
Dummy data used for profiling.

Note: This is only used in V0.
"""

seq_data: "SequenceData"
multi_modal_data: Optional["MultiModalDataDict"] = None
multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None


class DummyDataFactory(Protocol):

def __call__(
self,
ctx: InputContext,
seq_len: int,
mm_counts: Mapping[str, int],
**mm_processor_kwargs: Any,
) -> DummyData:
"""
Create dummy data to be inputted into the model.

Note:
:data:`InputProcessor` is not applied to the dummy data.

The :code:`mm_processor_kwargs` are overrides provided at
initialization time to values in the config whose values
may affect the number of tokens per instance.
"""
...


class _MultiModalCounts(UserDict[str, int]):
"""
Wraps `mm_counts` for a more informative error message
when attempting to access a plugin that does not exist.
"""

def __getitem__(self, key: str) -> int:
try:
return super().__getitem__(key)
except KeyError as exc:
msg = (f"There is no multi-modal plugin with the key: {key}. "
f"Available keys: {set(self.keys())}")
raise KeyError(msg) from exc


InputProcessor = Callable[[InputContext, ProcessorInputs], ProcessorInputs]
"""Preprocess the inputs to the model."""


class InputRegistry:
"""
A registry to dispatch data processing
according to the target model.
Note: This is only used in V0.
"""

def __init__(self) -> None:
self._dummy_factories_by_model_type = \
ClassRegistry[nn.Module, DummyDataFactory]()
self._dummy_encoder_factories_by_model_type = \
ClassRegistry[nn.Module, DummyDataFactory]()
self._input_processors_by_model_type = \
ClassRegistry[nn.Module, InputProcessor]()

def _default_dummy_data_factory(
self,
ctx: InputContext,
seq_len: int,
mm_counts: Mapping[str, int],
) -> DummyData:
"""
The default dummy data factory represents the longest possible text
that can be inputted to the model.

Note:
:data:`InputProcessor` is not applied to the dummy data.
"""
# Avoid circular import
from vllm.sequence import SequenceData

return DummyData(SequenceData.from_prompt_token_counts((0, seq_len)))

def register_dummy_data(self, factory: DummyDataFactory):
"""
Register a dummy data factory to a model class.

During memory profiling, the provided function is invoked to create
dummy data to be inputted into the model. The resulting memory usage
should be an upper bound of what the model would use at inference time.
"""

def wrapper(model_cls: N) -> N:
if self._dummy_factories_by_model_type.contains(model_cls,
strict=True):
logger.warning(
"Model class %s already has dummy data "
"registered to %s. It is overwritten by the new one.",
model_cls, self)

self._dummy_factories_by_model_type[model_cls] = factory

return model_cls

return wrapper

def _get_dummy_data_factory(self, model_cls: type[nn.Module]):
return self._dummy_factories_by_model_type \
.get(model_cls, self._default_dummy_data_factory)

def register_dummy_encoder_data(self, factory: DummyDataFactory):
"""
Register a dummy encoder data factory to a model class

This is similar to :meth:`~register_dummy_data`, but for encoder input.
"""

def wrapper(model_cls: N) -> N:
if self._dummy_encoder_factories_by_model_type.contains(
model_cls, strict=True):
logger.warning(
"Model class %s already has dummy encoder data "
"registered to %s. It is overwritten by the new one.",
model_cls, self)

self._dummy_encoder_factories_by_model_type[model_cls] = factory

return model_cls

return wrapper

def _get_dummy_encoder_data_factory(self, model_cls: type[nn.Module]):
return self._dummy_encoder_factories_by_model_type \
.get(model_cls, self._default_dummy_data_factory)

def dummy_data_for_profiling(
self,
model_config: "ModelConfig",
@@ -319,169 +188,25 @@ class InputRegistry:
Create dummy data for profiling the memory usage of a model.

The model is identified by ``model_config``.

Note:
This should be called after
:meth:`~MultiModalRegistry.init_mm_limits_per_prompt`.
"""
# Avoid circular import
from vllm.model_executor.model_loader import get_model_architecture
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.profiling import MultiModalProfiler
from vllm.sequence import SequenceData

if mm_registry.has_processor(model_config):
processor = mm_registry.create_processor(model_config,
disable_cache=True)
profiler = MultiModalProfiler(processor)
if not model_config.is_multimodal_model:
seq_data = SequenceData.from_prompt_token_counts((0, seq_len))
return DummyData(seq_data=seq_data)

dummy_data_v1 = (profiler.get_encoder_dummy_data(seq_len)
if is_encoder_data else
profiler.get_decoder_dummy_data(seq_len))
_seq_data = SequenceData.from_seqs(
dummy_data_v1.prompt_token_ids) # type: ignore[attr-defined]
# Encoder dummy data does not contain multi-modal data
if is_encoder_data:
enc_data = mm_registry.get_encoder_dummy_data(
model_config, seq_len)
seq_data = SequenceData.from_seqs(enc_data.prompt_token_ids)
return DummyData(seq_data=seq_data)

dummy_data = DummyData(
seq_data=_seq_data,
multi_modal_data=getattr(dummy_data_v1, "multi_modal_data",
None),
multi_modal_placeholders=getattr(dummy_data_v1,
"multi_modal_placeholders",
None),
)
else:
model_cls, _ = get_model_architecture(model_config)
if is_encoder_data:
dummy_factory = self._get_dummy_encoder_data_factory(model_cls)
else:
dummy_factory = self._get_dummy_data_factory(model_cls)
mm_counts = mm_registry.get_mm_limits_per_prompt(model_config)
mm_processor_kwargs = get_allowed_kwarg_only_overrides(
dummy_factory,
overrides=model_config.mm_processor_kwargs,
requires_kw_only=False,
allow_var_kwargs=True,
)
dec_data = mm_registry.get_decoder_dummy_data(model_config, seq_len)

dummy_data = dummy_factory(InputContext(model_config), seq_len,
_MultiModalCounts(mm_counts),
**mm_processor_kwargs)

# Having more tokens is over-conservative but otherwise fine
num_tokens = dummy_data.seq_data.prompt_token_ids
if len(num_tokens) < seq_len:
if is_encoder_data:
logger.warning_once(
f"Expected at least {seq_len} dummy encoder tokens for "
f"profiling, but found {len(num_tokens)} tokens instead.")
else:
raise AssertionError(
f"Expected at least {seq_len} dummy tokens for profiling, "
f"but found {len(num_tokens)} tokens instead.")

if (dummy_data.multi_modal_data is not None and
not isinstance(dummy_data.multi_modal_data, MultiModalKwargs)):
for k, v in dummy_data.multi_modal_data.items():
num_items = len(v) if isinstance(v, list) else 1
num_expected = mm_counts[k]
assert num_items >= num_expected, (
f"Expected at least {num_expected} dummy '{k}' instances "
f"for profiling, but found {num_items} instances instead.")

return dummy_data

def _default_input_processor(
self,
ctx: InputContext,
inputs: ProcessorInputs,
**kwargs: object,
) -> ProcessorInputs:
"""The default input processor is a no-op."""
return inputs

def register_input_processor(self, processor: InputProcessor):
"""
Register an input processor to a model class.

The provided function is invoked on each input to the model. This
happens before
:meth:`~vllm.multimodal.registry.MultiModalRegistry.map_input`.
"""

def wrapper(model_cls: N) -> N:
if self._input_processors_by_model_type.contains(model_cls,
strict=True):
logger.warning(
"Model class %s already has input processor "
"registered to %s. It is overwritten by the new one.",
model_cls, self)

self._input_processors_by_model_type[model_cls] = processor

return model_cls

return wrapper

def _get_model_input_processor(self, model_cls: type[nn.Module]):
return self._input_processors_by_model_type \
.get(model_cls, self._default_input_processor)

def _ensure_mm_kwargs(
self,
inputs: SingletonInputs,
mm_processor_kwargs: dict[str, Any],
):
if inputs["type"] == "token":
# In case the input processor for that model fails to set it
if "mm_processor_kwargs" not in inputs:
inputs["mm_processor_kwargs"] = mm_processor_kwargs
elif inputs["type"] == "multimodal":
# Be more strict in V2
assert "mm_kwargs" in inputs
else:
assert_never(inputs["type"]) # type: ignore[arg-type]

def process_input(self, model_config: "ModelConfig",
inputs: ProcessorInputs) -> ProcessorInputs:
"""
Apply an input processor to an instance of model inputs.

The model is identified by ``model_config``.
"""
# Avoid circular import
from vllm.model_executor.model_loader import get_model_architecture

model_cls, _ = get_model_architecture(model_config)
processor = self._get_model_input_processor(model_cls)

# Handle multimodal processor kwargs with priority:
# Inference kwargs -> Init kwargs -> {}
# If it's empty, it'll fall back to the default kwarg values
mm_processor_kwargs = resolve_mm_processor_kwargs(
model_config.mm_processor_kwargs,
inputs.get("mm_processor_kwargs", {}), # type: ignore
processor,
requires_kw_only=False,
allow_var_kwargs=True,
return DummyData(
seq_data=SequenceData.from_seqs(dec_data.prompt_token_ids),
multi_modal_data=dec_data.multi_modal_data,
multi_modal_placeholders=dec_data.multi_modal_placeholders,
)

processed_inputs = processor(
InputContext(model_config),
inputs,
**mm_processor_kwargs,
)

encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
if encoder_inputs is not None:
self._ensure_mm_kwargs(encoder_inputs, mm_processor_kwargs)
if decoder_inputs is not None:
self._ensure_mm_kwargs(decoder_inputs, mm_processor_kwargs)

return processed_inputs

def create_input_processor(self, model_config: "ModelConfig"):
"""
Create an input processor (see :meth:`_process_input`) for a
specific model.
"""
return functools.partial(self.process_input, model_config)
@@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0

from .base import MultiModalPlaceholderMap, MultiModalPlugin
from .base import MultiModalPlaceholderMap
from .hasher import MultiModalHashDict, MultiModalHasher
from .inputs import (BatchedTensorInputs, ModalityData, MultiModalDataBuiltins,
MultiModalDataDict, MultiModalKwargs,
@@ -26,7 +25,6 @@ __all__ = [
"MultiModalKwargs",
"MultiModalPlaceholderDict",
"MultiModalPlaceholderMap",
"MultiModalPlugin",
"NestedTensors",
"MULTIMODAL_REGISTRY",
"MultiModalRegistry",
@ -7,11 +7,9 @@ from typing import Literal, Optional
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
from vllm.inputs.registry import InputContext
|
||||
from vllm.utils import PlaceholderModule
|
||||
|
||||
from .base import MediaIO, MultiModalPlugin
|
||||
from .inputs import AudioItem, ModalityData, MultiModalKwargs
|
||||
from .base import MediaIO
|
||||
|
||||
try:
|
||||
import librosa
|
||||
@ -24,25 +22,6 @@ except ImportError:
|
||||
soundfile = PlaceholderModule("soundfile") # type: ignore[assignment]
|
||||
|
||||
|
||||
class AudioPlugin(MultiModalPlugin):
|
||||
"""Plugin for audio data."""
|
||||
|
||||
def get_data_key(self) -> str:
|
||||
return "audio"
|
||||
|
||||
def _default_input_mapper(
|
||||
self,
|
||||
ctx: InputContext,
|
||||
data: ModalityData[AudioItem],
|
||||
**mm_processor_kwargs,
|
||||
) -> MultiModalKwargs:
|
||||
raise NotImplementedError("There is no default audio input mapper")
|
||||
|
||||
def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
|
||||
raise NotImplementedError(
|
||||
"There is no default maximum multimodal tokens")
|
||||
|
||||
|
||||
def resample_audio_librosa(
|
||||
audio: npt.NDArray[np.floating],
|
||||
*,
|
||||
|
@ -1,247 +1,23 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from pathlib import Path
|
||||
from typing import (TYPE_CHECKING, Any, Callable, Generic, NamedTuple,
|
||||
Optional, TypeVar, Union)
|
||||
|
||||
from torch import nn
|
||||
|
||||
from vllm.inputs import InputContext
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides,
|
||||
resolve_mm_processor_kwargs)
|
||||
from typing import TYPE_CHECKING, Generic, NamedTuple, TypeVar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.sequence import SequenceGroupMetadata
|
||||
|
||||
from .inputs import (ModalityData, MultiModalDataDict, MultiModalKwargs,
|
||||
PlaceholderRange)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
MultiModalInputMapper = Callable[[InputContext, ModalityData[object]],
|
||||
MultiModalKwargs]
|
||||
"""
|
||||
Return a dictionary to be passed as keyword arguments to
|
||||
:meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers
|
||||
and processors in HuggingFace Transformers.
|
||||
|
||||
If the data is not supported, throw :exc:`TypeError`.
|
||||
"""
|
||||
|
||||
MultiModalTokensCalc = Union[int, Callable[[InputContext], int]]
|
||||
"""
|
||||
Calculate the maximum number of multimodal tokens input to the language
|
||||
model. This does not include tokens that correspond to the input text.
|
||||
"""
|
||||
from .inputs import MultiModalKwargs, PlaceholderRange
|
||||
|
||||
_T = TypeVar("_T")
|
||||
N = TypeVar("N", bound=type[nn.Module])
|
||||
|
||||
|
||||
class MultiModalPlugin(ABC):
|
||||
"""
|
||||
Base class that defines data processing logic for a specific modality.
|
||||
|
||||
In particular, we adopt a registry pattern to dispatch data processing
|
||||
according to the model being used (considering that different models may
|
||||
process the same data differently). This registry is in turn used by
|
||||
:class:`~MultiModalRegistry` which acts at a higher level
|
||||
(i.e., the modality of the data).
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._input_mappers = ClassRegistry[nn.Module, MultiModalInputMapper]()
|
||||
self._max_mm_tokens = ClassRegistry[nn.Module, MultiModalTokensCalc]()
|
||||
|
||||
@abstractmethod
|
||||
def get_data_key(self) -> str:
|
||||
"""
|
||||
Get the data key corresponding to the modality.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def _default_input_mapper(
|
||||
self,
|
||||
ctx: InputContext,
|
||||
data: ModalityData[Any],
|
||||
**mm_processor_kwargs,
|
||||
) -> MultiModalKwargs:
|
||||
"""
|
||||
Return a dictionary to be passed as keyword arguments to
|
||||
:meth:`~torch.nn.Module.forward`. This is similar in concept to
|
||||
tokenizers and processors in HuggingFace Transformers.
|
||||
|
||||
If the data is not supported, throw :exc:`TypeError`.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def register_input_mapper(
|
||||
self,
|
||||
mapper: Optional[MultiModalInputMapper] = None,
|
||||
):
|
||||
"""
|
||||
Register an input mapper to a model class.
|
||||
|
||||
When the model receives input data that matches the modality served by
|
||||
this plugin (see :meth:`get_data_key`), the provided function is
|
||||
invoked to transform the data into a dictionary of model inputs.
|
||||
|
||||
If `None` is provided, then the default input mapper is used instead.
|
||||
"""
|
||||
|
||||
def wrapper(model_cls: N) -> N:
|
||||
if self._input_mappers.contains(model_cls, strict=True):
|
||||
logger.warning(
|
||||
"Model class %s already has an input mapper "
|
||||
"registered to %s. It is overwritten by the new one.",
|
||||
model_cls,
|
||||
self,
|
||||
)
|
||||
|
||||
self._input_mappers[model_cls] = (mapper
|
||||
or self._default_input_mapper)
|
||||
|
||||
return model_cls
|
||||
|
||||
return wrapper
|
||||
|
||||
def map_input(
|
||||
self,
|
||||
model_config: "ModelConfig",
|
||||
data: ModalityData[Any],
|
||||
mm_processor_kwargs: Optional[dict[str, Any]],
|
||||
) -> MultiModalKwargs:
|
||||
"""
|
||||
Transform the data into a dictionary of model inputs using the
|
||||
input mapper registered for that model.
|
||||
|
||||
The model is identified by ``model_config``.
|
||||
|
||||
Raises:
|
||||
TypeError: If the data type is not supported.
|
||||
"""
|
||||
|
||||
# Avoid circular import
|
||||
from vllm.model_executor.model_loader import get_model_architecture
|
||||
|
||||
model_cls, _ = get_model_architecture(model_config)
|
||||
|
||||
mapper = self._input_mappers.get(model_cls)
|
||||
|
||||
if mapper is None:
|
||||
raise KeyError(f"No input mapper in {self} is registered for "
|
||||
f"model class {model_cls.__name__}.")
|
||||
|
||||
if mm_processor_kwargs is None:
|
||||
mm_processor_kwargs = {}
|
||||
|
||||
# In the case of the default mapper, we have to get resource
|
||||
# processor through its HuggingFace autoclass; since this goes
|
||||
# through **kwargs, we can't inspect it the same way, so we allow
|
||||
# drop mm_processor_kwargs based on signature inspection
|
||||
# if we're using the default mapper.
|
||||
#
|
||||
# This should be safe in general due to the sanitation, since the
|
||||
# transformers resource should filter unused kwargs anyway.
|
||||
uses_default_mapper = mapper == self._default_input_mapper
|
||||
mm_processor_kwargs = resolve_mm_processor_kwargs(
|
||||
model_config.mm_processor_kwargs,
|
||||
mm_processor_kwargs,
|
||||
callable=mapper,
|
||||
allow_var_kwargs=uses_default_mapper,
|
||||
)
|
||||
return mapper(InputContext(model_config), data, **mm_processor_kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
|
||||
"""
|
||||
Calculate the maximum number of tokens, corresponding to a single
|
||||
instance of multimodal data, that are passed to the language model.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def _validate_max_multimodal_tokens(self, max_mm_tokens: int):
|
||||
if max_mm_tokens < 1:
|
||||
raise ValueError("You should set the number of tokens to a "
|
||||
f"positive integer. Found: {max_mm_tokens}")
|
||||
|
||||
def register_max_multimodal_tokens(
|
||||
self,
|
||||
max_mm_tokens: Optional[MultiModalTokensCalc] = None,
|
||||
):
|
||||
"""
|
||||
Register the maximum number of tokens, corresponding to a single
|
||||
instance of multimodal data, that are passed to the language model
|
||||
for a model class.
|
||||
|
||||
If `None` is provided, then the default calculation is used instead.
|
||||
"""
|
||||
|
||||
def wrapper(model_cls: N) -> N:
|
||||
if self._max_mm_tokens.contains(model_cls, strict=True):
|
||||
logger.warning(
|
||||
"Model class %s already calculates maximum number of "
|
||||
"tokens in %s. It is overwritten by the new one.",
|
||||
model_cls,
|
||||
self,
|
||||
)
|
||||
|
||||
if isinstance(max_mm_tokens, int):
|
||||
self._validate_max_multimodal_tokens(max_mm_tokens)
|
||||
|
||||
self._max_mm_tokens[model_cls] = (
|
||||
max_mm_tokens or self._default_max_multimodal_tokens)
|
||||
|
||||
return model_cls
|
||||
|
||||
return wrapper
|
||||
|
||||
def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int:
|
||||
"""
|
||||
Get the maximum number of multi-modal tokens
|
||||
for profiling the memory usage of a model.
|
||||
|
||||
If this registry is not applicable to the model, `0` is returned.
|
||||
|
||||
The model is identified by ``model_config``.
|
||||
"""
|
||||
# Avoid circular import
|
||||
from vllm.model_executor.model_loader import get_model_architecture
|
||||
from vllm.model_executor.models import supports_multimodal
|
||||
|
||||
model_cls, _ = get_model_architecture(model_config)
|
||||
|
||||
if not supports_multimodal(model_cls):
|
||||
return 0
|
||||
|
||||
max_mm_tokens = self._max_mm_tokens.get(model_cls)
|
||||
if max_mm_tokens is None:
|
||||
return 0
|
||||
|
||||
if callable(max_mm_tokens):
|
||||
mm_processor_kwargs = get_allowed_kwarg_only_overrides(
|
||||
max_mm_tokens,
|
||||
overrides=model_config.mm_processor_kwargs,
|
||||
requires_kw_only=False,
|
||||
allow_var_kwargs=True,
|
||||
)
|
||||
max_mm_tokens = max_mm_tokens(InputContext(model_config),
|
||||
**mm_processor_kwargs)
|
||||
|
||||
self._validate_max_multimodal_tokens(max_mm_tokens)
|
||||
|
||||
return max_mm_tokens
|
||||
|
||||
|
||||
class MultiModalPlaceholderMap:
|
||||
"""
|
||||
Relates multi-modal embeddings to their corresponding placeholders.
|
||||
|
||||
Note: This is only used in V0.
|
||||
"""
|
||||
|
||||
class IndexMap(NamedTuple):
|
||||
@ -279,8 +55,7 @@ class MultiModalPlaceholderMap:
|
||||
@classmethod
|
||||
def from_seq_group(
|
||||
cls, seq_group: "SequenceGroupMetadata", positions: range
|
||||
) -> tuple[Optional[MultiModalDataDict], dict[str,
|
||||
"MultiModalPlaceholderMap"]]:
|
||||
) -> tuple[MultiModalKwargs, dict[str, "MultiModalPlaceholderMap"]]:
|
||||
"""
|
||||
Returns the multi-modal items that intersect with the portion of a
|
||||
prompt (``seq_group``) represented by ``positions``, as well as a
|
||||
@ -323,48 +98,24 @@ class MultiModalPlaceholderMap:
|
||||
seq_mm_placeholders = seq_group.multi_modal_placeholders
|
||||
|
||||
if not seq_mm_data or not seq_mm_placeholders:
|
||||
return seq_mm_data, {}
|
||||
return MultiModalKwargs({}), {}
|
||||
|
||||
# For merged processor, we directly use mm_kwargs as mm_data
|
||||
if isinstance(seq_mm_data, MultiModalKwargs):
|
||||
placeholder_maps = dict[str, MultiModalPlaceholderMap]()
|
||||
|
||||
for modality, placeholders in seq_mm_placeholders.items():
|
||||
placeholder_map = MultiModalPlaceholderMap()
|
||||
|
||||
if positions:
|
||||
placeholder_map.append_items_from_seq_group(
|
||||
positions,
|
||||
# Dummy, since we don't care about intersecting items
|
||||
[None] * len(placeholders),
|
||||
placeholders,
|
||||
)
|
||||
|
||||
placeholder_maps[modality] = placeholder_map
|
||||
|
||||
return seq_mm_data, placeholder_maps
|
||||
|
||||
mm_data = {**seq_mm_data}
|
||||
placeholder_maps = defaultdict[str, MultiModalPlaceholderMap](
|
||||
MultiModalPlaceholderMap)
|
||||
placeholder_maps = dict[str, MultiModalPlaceholderMap]()
|
||||
|
||||
for modality, placeholders in seq_mm_placeholders.items():
|
||||
mm_items = mm_data.pop(modality)
|
||||
if not isinstance(mm_items, list):
|
||||
mm_items = [mm_items]
|
||||
placeholder_map = MultiModalPlaceholderMap()
|
||||
|
||||
if positions:
|
||||
intersecting_items = placeholder_maps[modality] \
|
||||
.append_items_from_seq_group(
|
||||
positions,
|
||||
mm_items,
|
||||
placeholders,
|
||||
)
|
||||
placeholder_map.append_items_from_seq_group(
|
||||
positions,
|
||||
# Dummy, since we don't care about intersecting items
|
||||
[None] * len(placeholders),
|
||||
placeholders,
|
||||
)
|
||||
|
||||
if intersecting_items:
|
||||
mm_data[modality] = intersecting_items
|
||||
placeholder_maps[modality] = placeholder_map
|
||||
|
||||
return mm_data, placeholder_maps
|
||||
return seq_mm_data, placeholder_maps
|
||||
|
||||
def append_items_from_seq_group(
|
||||
self,
|
||||
@ -445,8 +196,7 @@ class MultiModalPlaceholderMap:
|
||||
f"The number of source ({len(src_indices)}) and destination "
|
||||
f"indices ({len(dest_indices)}) must be the same.")
|
||||
|
||||
return MultiModalPlaceholderMap.IndexMap(src=src_indices,
|
||||
dest=dest_indices)
|
||||
return self.IndexMap(src=src_indices, dest=dest_indices)
|
||||
|
||||
|
||||
class MediaIO(ABC, Generic[_T]):
|
||||
|
@ -3,89 +3,11 @@
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from vllm.inputs.registry import InputContext
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.processor import cached_get_image_processor
|
||||
from vllm.utils import is_list_of
|
||||
|
||||
from .base import MediaIO, MultiModalPlugin
|
||||
from .inputs import ImageItem, ModalityData, MultiModalKwargs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class ImagePlugin(MultiModalPlugin):
|
||||
"""Plugin for image data."""
|
||||
|
||||
def get_data_key(self) -> str:
|
||||
return "image"
|
||||
|
||||
def _get_hf_image_processor(
|
||||
self,
|
||||
model_config: "ModelConfig",
|
||||
mm_processor_kwargs: Optional[dict[str, Any]] = None,
|
||||
):
|
||||
if mm_processor_kwargs is None:
|
||||
mm_processor_kwargs = {}
|
||||
return cached_get_image_processor(
|
||||
model_config.model,
|
||||
trust_remote_code=model_config.trust_remote_code,
|
||||
**mm_processor_kwargs)
|
||||
|
||||
def _default_input_mapper(
|
||||
self,
|
||||
ctx: InputContext,
|
||||
data: ModalityData[ImageItem],
|
||||
**mm_processor_kwargs,
|
||||
) -> MultiModalKwargs:
|
||||
model_config = ctx.model_config
|
||||
|
||||
# PIL image
|
||||
if isinstance(data, Image.Image) or is_list_of(data, Image.Image):
|
||||
image_processor = self._get_hf_image_processor(
|
||||
model_config,
|
||||
mm_processor_kwargs,
|
||||
)
|
||||
|
||||
if image_processor is None:
|
||||
raise RuntimeError("No HuggingFace processor is available "
|
||||
"to process the image object")
|
||||
try:
|
||||
# NOTE: It may make sense to forward the mm_processor_kwargs
|
||||
# here too. For now, to keep it simple, we only allow it be
|
||||
# used for the initialization call though, just in case the
|
||||
# signatures of the preprocessor initializer don't match
|
||||
# preprocess()
|
||||
batch_data = image_processor \
|
||||
.preprocess(data, return_tensors="pt") \
|
||||
.data
|
||||
except Exception:
|
||||
logger.error(
|
||||
"Failed to process image (%s) with the default mapper. "
|
||||
"This is most likely an edge-case with this model's image "
|
||||
"processor in transformers (type: %s), and not vLLM.",
|
||||
data,
|
||||
type(image_processor).__name__)
|
||||
raise
|
||||
|
||||
return MultiModalKwargs(batch_data)
|
||||
|
||||
# Image embedding
|
||||
elif isinstance(data, torch.Tensor) or is_list_of(data, torch.Tensor):
|
||||
return MultiModalKwargs({"image_embeds": data})
|
||||
|
||||
raise TypeError(f"Invalid image type: {type(data)}")
|
||||
|
||||
def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
|
||||
return 3000
|
||||
from .base import MediaIO
|
||||
|
||||
|
||||
def rescale_image_size(image: Image.Image,
|
||||
|
@ -1,13 +1,10 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import functools
|
||||
import json
|
||||
from collections import UserDict
|
||||
from collections.abc import Mapping, Sequence
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Generic, Optional, Protocol, TypeVar
|
||||
from typing import TYPE_CHECKING, Generic, Optional, Protocol, TypeVar
|
||||
|
||||
import torch.nn as nn
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from vllm.envs import VLLM_MM_INPUT_CACHE_GIB
|
||||
from vllm.inputs import InputProcessingContext
|
||||
@ -16,15 +13,10 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer,
|
||||
cached_tokenizer_from_config)
|
||||
from vllm.utils import ClassRegistry
|
||||
|
||||
from .audio import AudioPlugin
|
||||
from .base import MultiModalInputMapper, MultiModalPlugin, MultiModalTokensCalc
|
||||
from .image import ImagePlugin
|
||||
from .inputs import MultiModalDataDict, MultiModalKwargs, NestedTensors
|
||||
from .processing import (BaseMultiModalProcessor, BaseProcessingInfo,
|
||||
ProcessingCache)
|
||||
from .profiling import (BaseDummyInputsBuilder, DummyDecoderData,
|
||||
DummyEncoderData, MultiModalProfiler)
|
||||
from .video import VideoPlugin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig
|
||||
@ -85,169 +77,23 @@ class _ProcessorFactories(Generic[_I]):
|
||||
return self.processor(info, dummy_inputs_builder, cache=cache)
|
||||
|
||||
|
||||
class _MultiModalLimits(UserDict["ModelConfig", dict[str, int]]):
|
||||
"""
|
||||
Wraps `_limits_by_model` for a more informative error message
|
||||
when attempting to access a model that does not exist.
|
||||
"""
|
||||
|
||||
def __getitem__(self, key: "ModelConfig") -> dict[str, int]:
|
||||
try:
|
||||
return super().__getitem__(key)
|
||||
except KeyError as exc:
|
||||
msg = (f"Cannot find `mm_limits` for model={key.model}. Did you "
|
||||
"forget to call `init_mm_limits_per_prompt`?")
|
||||
raise KeyError(msg) from exc
|
||||
|
||||
|
||||
class MultiModalRegistry:
|
||||
"""
|
||||
A registry that dispatches data processing according to the model.
|
||||
"""
|
||||
|
||||
DEFAULT_PLUGINS = (ImagePlugin(), AudioPlugin(), VideoPlugin())
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
plugins: Sequence[MultiModalPlugin] = DEFAULT_PLUGINS) -> None:
|
||||
self._plugins = {p.get_data_key(): p for p in plugins}
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._processor_factories = ClassRegistry[nn.Module,
|
||||
_ProcessorFactories]()
|
||||
|
||||
# This is used for non-multimodal models
|
||||
self._disabled_limits_per_plugin = {k: 0 for k in self._plugins}
|
||||
|
||||
self._limits_by_model = _MultiModalLimits()
|
||||
|
||||
self._processing_cache = ProcessingCache(VLLM_MM_INPUT_CACHE_GIB)
|
||||
|
||||
def register_plugin(self, plugin: MultiModalPlugin) -> None:
|
||||
"""
|
||||
Register a multi-modal plugin so it can be recognized by vLLM.
|
||||
"""
|
||||
data_type_key = plugin.get_data_key()
|
||||
|
||||
if data_type_key in self._plugins:
|
||||
logger.warning(
|
||||
"A plugin is already registered for data type %s, "
|
||||
"and will be overwritten by the new plugin %s.", data_type_key,
|
||||
plugin)
|
||||
|
||||
self._plugins[data_type_key] = plugin
|
||||
|
||||
def _get_plugin(self, data_type_key: str):
|
||||
plugin = self._plugins.get(data_type_key)
|
||||
if plugin is not None:
|
||||
return plugin
|
||||
|
||||
msg = f"Unknown multi-modal data type: {data_type_key}"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
def register_input_mapper(
|
||||
self,
|
||||
data_type_key: str,
|
||||
mapper: Optional[MultiModalInputMapper] = None,
|
||||
):
|
||||
"""
|
||||
Register an input mapper for a specific modality to a model class.
|
||||
|
||||
See :meth:`MultiModalPlugin.register_input_mapper` for more details.
|
||||
"""
|
||||
return self._get_plugin(data_type_key).register_input_mapper(mapper)
|
||||
|
||||
def register_image_input_mapper(
|
||||
self,
|
||||
mapper: Optional[MultiModalInputMapper] = None,
|
||||
):
|
||||
"""
|
||||
Register an input mapper for image data to a model class.
|
||||
|
||||
See :meth:`MultiModalPlugin.register_input_mapper` for more details.
|
||||
"""
|
||||
return self.register_input_mapper("image", mapper)
|
||||
|
||||
def map_input(
|
||||
self,
|
||||
model_config: "ModelConfig",
|
||||
data: MultiModalDataDict,
|
||||
mm_processor_kwargs: Optional[dict[str, Any]] = None,
|
||||
) -> MultiModalKwargs:
|
||||
"""
|
||||
Apply an input mapper to the data passed to the model.
|
||||
|
||||
The data belonging to each modality is passed to the corresponding
|
||||
plugin which in turn converts the data into into keyword arguments
|
||||
via the input mapper registered for that model.
|
||||
|
||||
See :meth:`MultiModalPlugin.map_input` for more details.
|
||||
|
||||
Note:
|
||||
This should be called after :meth:`init_mm_limits_per_prompt`.
|
||||
"""
|
||||
merged_dict = dict[str, NestedTensors]()
|
||||
|
||||
for data_key, data_value in data.items():
|
||||
plugin = self._get_plugin(data_key)
|
||||
|
||||
num_items = len(data_value) if isinstance(data_value, list) else 1
|
||||
max_items = self._limits_by_model[model_config][data_key]
|
||||
if num_items > max_items:
|
||||
raise ValueError(
|
||||
f"You set '{json.dumps({data_key: max_items})}' (or "
|
||||
"defaulted to 1) in `--limit-mm-per-prompt`, but found "
|
||||
f"{num_items} items in the same prompt.")
|
||||
|
||||
input_dict = plugin.map_input(model_config, data_value,
|
||||
mm_processor_kwargs)
|
||||
for input_key, input_tensor in input_dict.items():
|
||||
if input_key in merged_dict:
|
||||
raise ValueError(f"The input mappers (keys={set(data)}) "
|
||||
f"resulted in a conflicting keyword "
|
||||
f"argument to `forward()`: {input_key}")
|
||||
|
||||
merged_dict[input_key] = input_tensor
|
||||
|
||||
return MultiModalKwargs(merged_dict)
|
||||
|
||||
@deprecated("Legacy input processor/mapper pipeline has been removed. "
|
||||
"Please update your model runner to use "
|
||||
"`seq_group_metadata.multi_modal_data` directly without "
|
||||
"further processing.")
|
||||
def create_input_mapper(self, model_config: "ModelConfig"):
|
||||
"""
|
||||
Create an input mapper (see :meth:`map_input`) for a specific model.
|
||||
"""
|
||||
# NOTE - we currently make the assumption that if a model has multiple
|
||||
# supported modalities, they take the same kwargs. For the default,
|
||||
# this could be an issue in the future if it falls back to two HF
|
||||
# resources and we can't inspect the signature easily since it's
|
||||
# getting initialized through the autoclass.
|
||||
#
|
||||
# If this is a problem in the future, we should revisit it, but since
|
||||
# it potentially introduces a lot of complexity for a currently
|
||||
# uncommon case, we do not for simplicity of both use & implementation
|
||||
return functools.partial(self.map_input, model_config)
|
||||
|
||||
def register_max_multimodal_tokens(
self,
data_type_key: str,
max_mm_tokens: Optional[MultiModalTokensCalc] = None,
):
"""
Register the maximum number of tokens, corresponding to a single
instance of multimodal data belonging to a specific modality, that are
passed to the language model for a model class.
"""
return self._get_plugin(data_type_key) \
.register_max_multimodal_tokens(max_mm_tokens)

def register_max_image_tokens(
self,
max_mm_tokens: Optional[MultiModalTokensCalc] = None,
):
"""
Register the maximum number of image tokens, corresponding to a single
image, that are passed to the language model for a model class.
"""
return self.register_max_multimodal_tokens("image", max_mm_tokens)
return lambda data, mm_processor_kwargs: data

def get_max_tokens_per_item_by_modality(
self,
@ -257,25 +103,22 @@ class MultiModalRegistry:
Get the maximum number of tokens per data item from each modality based
on underlying model configuration.
"""
if self.has_processor(model_config):
processor = self.create_processor(model_config, disable_cache=True)
profiler = MultiModalProfiler(processor)
if not model_config.is_multimodal_model:
return {}

seq_len = model_config.max_model_len
mm_limits = self.get_mm_limits_per_prompt(model_config)
processor = self.create_processor(model_config, disable_cache=True)
profiler = MultiModalProfiler(processor)

return profiler.get_mm_max_tokens(
seq_len,
{
modality: 1
for modality, limit in mm_limits.items() if limit > 0
},
)
seq_len = model_config.max_model_len
mm_limits = self.get_mm_limits_per_prompt(model_config)

return {
key: plugin.get_max_multimodal_tokens(model_config)
for key, plugin in self._plugins.items()
}
return profiler.get_mm_max_tokens(
seq_len,
{
modality: 1
for modality, limit in mm_limits.items() if limit > 0
},
)
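
With the plugin fallback gone, per-item token budgets always come from the registered multi-modal processor via MultiModalProfiler. A hedged usage sketch; `model_config` is assumed to be the ModelConfig of a multimodal model obtained from an existing engine:

    from vllm.multimodal import MULTIMODAL_REGISTRY

    # Returns {} for text-only models, otherwise a mapping such as
    # {"image": N, "video": M} derived by profiling the processor.
    max_tokens_per_item = MULTIMODAL_REGISTRY.get_max_tokens_per_item_by_modality(
        model_config)
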
def get_max_tokens_per_item_by_nonzero_modality(
self,
@ -308,9 +151,6 @@ class MultiModalRegistry:
for profiling the memory usage of a model.

See :meth:`MultiModalPlugin.get_max_multimodal_tokens` for more details.

Note:
This should be called after :meth:`init_mm_limits_per_prompt`.
"""
mm_limits = self.get_mm_limits_per_prompt(model_config)

@ -326,47 +166,18 @@ class MultiModalRegistry:
for profiling the memory usage of a model.

See :meth:`MultiModalPlugin.get_max_multimodal_tokens` for more details.

Note:
This should be called after :meth:`init_mm_limits_per_prompt`.
"""
return sum(self.get_max_tokens_by_modality(model_config).values())

@deprecated("Legacy input processor/mapper pipeline has been removed. "
"Please update your model runner to use "
"`seq_group_metadata.multi_modal_data` directly without "
"further processing.")
def init_mm_limits_per_prompt(
self,
model_config: "ModelConfig",
) -> None:
"""
Initialize the maximum number of multi-modal input instances for each
modality that are allowed per prompt for a model class.
"""
if model_config in self._limits_by_model:
logger.warning(
"`mm_limits` has already been set for model=%s, and will "
"be overwritten by the new values.", model_config.model)

multimodal_config = model_config.multimodal_config
if multimodal_config is None:
limits_per_plugin = self._disabled_limits_per_plugin
else:
config_limits_per_plugin = multimodal_config.limit_per_prompt

extra_keys = config_limits_per_plugin.keys() - self._plugins.keys()
if extra_keys:
logger.warning(
"Detected extra keys in `--limit-mm-per-prompt` which "
"are not registered as multi-modal plugins: %s. "
"They will be ignored.", extra_keys)

# NOTE: Currently the default is set to 1 for each plugin
# TODO: Automatically determine the limits based on budget
# once more models support multi-image inputs
limits_per_plugin = {
key: multimodal_config.get_limit_per_prompt(key)
for key in self._plugins
}

self._limits_by_model[model_config] = limits_per_plugin
pass

def get_mm_limits_per_prompt(
self,
@ -375,16 +186,13 @@ class MultiModalRegistry:
"""
Get the maximum number of multi-modal input instances for each modality
that are allowed per prompt for a model class.

Note:
This should be called after :meth:`init_mm_limits_per_prompt`.
"""
if self.has_processor(model_config):
processor = self.create_processor(model_config, disable_cache=True)
profiler = MultiModalProfiler(processor)
return profiler.get_mm_limits()
if not model_config.is_multimodal_model:
return {}

return self._limits_by_model[model_config]
processor = self.create_processor(model_config, disable_cache=True)
profiler = MultiModalProfiler(processor)
return profiler.get_mm_limits()
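
Since init_mm_limits_per_prompt() is now a deprecated no-op, callers query limits directly and the registry derives them from the processor (subject to --limit-mm-per-prompt). A sketch under the same assumption that `model_config` describes a multimodal model from an existing engine:

    from vllm.multimodal import MULTIMODAL_REGISTRY

    # No prior init_mm_limits_per_prompt() call is needed anymore; text-only
    # models simply yield an empty mapping.
    mm_limits = MULTIMODAL_REGISTRY.get_mm_limits_per_prompt(model_config)
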
def register_processor(
self,
@ -428,14 +236,12 @@ class MultiModalRegistry:
model_cls, _ = get_model_architecture(model_config)
return model_cls

@deprecated("Legacy input processor/mapper pipeline has been removed. "
"Please update your model runner to use "
"`seq_group_metadata.multi_modal_data` directly without "
"further processing.")
def has_processor(self, model_config: "ModelConfig") -> bool:
"""
Test whether a multi-modal processor is defined for a specific model.

See also:
:ref:`mm-processing`
"""
return self._get_model_cls(model_config) in self._processor_factories
return True

def create_processor(
self,
@ -450,6 +256,9 @@ class MultiModalRegistry:
See also:
:ref:`mm-processing`
"""
if not model_config.is_multimodal_model:
raise ValueError(f"{model_config.model} is not a multimodal model")

if tokenizer is None:
tokenizer = cached_tokenizer_from_config(model_config)
if disable_cache is None:
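
has_processor() now always returns True and is deprecated, so create_processor() becomes the single entry point and fails fast for text-only models. A minimal sketch, assuming `model_config` comes from an initialized engine; the tokenizer and cache arguments are left to the defaults resolved in the hunk above:

    from vllm.multimodal import MULTIMODAL_REGISTRY

    try:
        processor = MULTIMODAL_REGISTRY.create_processor(model_config)
    except ValueError:
        # Raised when model_config does not describe a multimodal model.
        processor = None
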
@ -4,80 +4,13 @@ import base64
from functools import partial
from io import BytesIO
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional

import numpy as np
import numpy.typing as npt
from PIL import Image

from vllm.inputs.registry import InputContext
from vllm.logger import init_logger
from vllm.transformers_utils.processor import cached_get_video_processor
from vllm.utils import is_list_of

from .base import MediaIO, ModalityData
from .image import ImageMediaIO, ImagePlugin
from .inputs import MultiModalKwargs, VideoItem

if TYPE_CHECKING:
from vllm.config import ModelConfig

logger = init_logger(__name__)


class VideoPlugin(ImagePlugin):
"""Plugin for video data."""

def get_data_key(self) -> str:
return "video"

def _get_hf_video_processor(
self,
model_config: "ModelConfig",
mm_processor_kwargs: Optional[dict[str, Any]] = None,
):
if mm_processor_kwargs is None:
mm_processor_kwargs = {}
return cached_get_video_processor(
model_config.model,
trust_remote_code=model_config.trust_remote_code,
**mm_processor_kwargs)

def _default_input_mapper(
self,
ctx: InputContext,
data: ModalityData[VideoItem],
**mm_processor_kwargs,
) -> MultiModalKwargs:
model_config = ctx.model_config

if isinstance(data, list) and len(data) == 1:
data = data[0]  # type: ignore

if isinstance(data, np.ndarray) or is_list_of(data, np.ndarray):
video_processor = self._get_hf_video_processor(
model_config,
mm_processor_kwargs,
)
if video_processor is None:
raise RuntimeError("No HuggingFace processor is available "
"to process the video object")
try:
# NOTE: Similar to image; it may be a good idea to filter and
# pass mm_processor_kwargs here too, but for now we don't to
# avoid extra complexity if the initializer and preprocess
# signatures of the processor don't align
batch_data = video_processor(data, return_tensors="pt").data
except Exception:
logger.error("Failed to process video (%s)", data)
raise

return MultiModalKwargs(batch_data)

raise TypeError(f"Invalid video type: {type(data)}")

def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
return 4096
from .base import MediaIO
from .image import ImageMediaIO


def resize_video(frames: npt.NDArray, size: tuple[int, int]) -> npt.NDArray:
@ -14,9 +14,9 @@ from typing import Any, Callable, Optional, Union
import msgspec
import torch

from vllm.inputs import SingletonInputs, SingletonInputsAdapter
from vllm.inputs import SingletonInputs
from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict
from vllm.multimodal import MultiModalKwargs, MultiModalPlaceholderDict
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import RequestOutputKind, SamplingParams
@ -419,7 +419,7 @@ class Sequence:
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None:
self.seq_id = seq_id
self.inputs = SingletonInputsAdapter(inputs)
self.inputs = inputs
self.block_size = block_size
self.eos_token_id = eos_token_id
self.lora_request = lora_request
@ -448,31 +448,29 @@ class Sequence:

@property
def prompt(self) -> Optional[str]:
return self.inputs.prompt
return self.inputs.get("prompt")

@property
def prompt_token_ids(self) -> list[int]:
return self.inputs.prompt_token_ids

@property
def prompt_embeds(self) -> Optional[torch.Tensor]:
return self.inputs.prompt_embeds
return self.inputs["prompt_token_ids"]

@property
def token_type_ids(self) -> list[int]:
return self.inputs.token_type_ids
return self.inputs.get("token_type_ids", [])

@property
def multi_modal_data(self) -> "MultiModalDataDict":
return self.inputs.multi_modal_data
def multi_modal_data(self) -> MultiModalKwargs:
if self.inputs["type"] == "multimodal":
return self.inputs["mm_kwargs"]

return MultiModalKwargs({})

@property
def multi_modal_placeholders(self) -> MultiModalPlaceholderDict:
return self.inputs.multi_modal_placeholders
if self.inputs["type"] == "multimodal":
return self.inputs["mm_placeholders"]

@property
def mm_processor_kwargs(self) -> dict[str, Any]:
return self.inputs.mm_processor_kwargs
return {}
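
Sequence now stores the SingletonInputs dict as-is and reads it with plain key access instead of going through SingletonInputsAdapter. A small sketch of the token-prompt case using the token_inputs() constructor from vllm.inputs; the multimodal keys ("mm_kwargs", "mm_placeholders") follow the properties above:

    from vllm.inputs import token_inputs

    inputs = token_inputs(prompt_token_ids=[1, 2, 3], prompt="Hello")
    assert inputs["type"] == "token"
    assert inputs["prompt_token_ids"] == [1, 2, 3]
    # Sequence.prompt / prompt_token_ids now read these keys directly.
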
@property
def lora_int_id(self) -> int:
@ -723,12 +721,12 @@ class SequenceGroup:
return self.first_seq.token_type_ids

@property
def multi_modal_data(self) -> MultiModalDataDict:
def multi_modal_data(self) -> MultiModalKwargs:
if self.first_seq.multi_modal_data:
return self.first_seq.multi_modal_data
elif self.encoder_seq is not None:
return self.encoder_seq.multi_modal_data
return {}
return MultiModalKwargs({})

@property
def multi_modal_placeholders(self) -> MultiModalPlaceholderDict:
@ -738,14 +736,6 @@ class SequenceGroup:
return self.encoder_seq.multi_modal_placeholders
return {}

@property
def mm_processor_kwargs(self) -> dict[str, Any]:
if self.first_seq.multi_modal_data:
return self.first_seq.mm_processor_kwargs
elif self.encoder_seq is not None:
return self.encoder_seq.mm_processor_kwargs
return {}

@property
def lora_int_id(self) -> int:
return self.lora_request.lora_int_id if self.lora_request else 0
@ -969,12 +959,9 @@ class SequenceGroupMetadata(
computed_block_nums: Optional[list[int]] = None
state: Optional[SequenceGroupState] = msgspec.field(
default_factory=lambda: SequenceGroupState())
# "MultiModalDataDict" types. We have to use Any due to msgspec
# doesn't allow to have union of 2 different dicts.
token_type_ids: Optional[list[int]] = None
multi_modal_data: Optional[Any] = None
multi_modal_data: Optional[MultiModalKwargs] = None
multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
mm_processor_kwargs: Optional[dict[str, Any]] = None
encoder_seq_data: Optional[SequenceData] = None
cross_block_table: Optional[list[int]] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None
@ -208,38 +208,3 @@ def cached_image_processor_from_config(
trust_remote_code=model_config.trust_remote_code,
**_merge_mm_kwargs(model_config, **kwargs),
)


def get_video_processor(
processor_name: str,
*args: Any,
trust_remote_code: bool = False,
**kwargs: Any,
):
"""Load a video processor for the given model name via HuggingFace."""
# don't put this import at the top level
# it will call torch.cuda.device_count()
from transformers.image_processing_utils import BaseImageProcessor

processor = get_processor(
processor_name,
*args,
trust_remote_code=trust_remote_code,
**kwargs,
)

return cast(BaseImageProcessor, processor.video_processor)


cached_get_video_processor = lru_cache(get_video_processor)


def cached_video_processor_from_config(
model_config: "ModelConfig",
**kwargs: Any,
):
return cached_get_video_processor(
model_config.model,
trust_remote_code=model_config.trust_remote_code,
**_merge_mm_kwargs(model_config, **kwargs),
)
@ -22,8 +22,8 @@ from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import supports_lora, supports_multimodal
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalKwargs, MultiModalPlaceholderMap)
from vllm.multimodal import (BatchedTensorInputs, MultiModalKwargs,
MultiModalPlaceholderMap)
from vllm.sequence import (IntermediateTensors, SequenceData,
SequenceGroupMetadata)
from vllm.worker.model_runner_base import (
@ -154,7 +154,6 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
self.sliding_window = self.runner.sliding_window
self.block_size = self.runner.block_size
self.device = self.runner.device
self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
self.enable_lora = self.runner.lora_config is not None
if self.runner.attn_backend is not None:
# spec decode (e.g. Medusa) does not have atten backend
@ -359,22 +358,14 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
computed_len = seq_data.get_num_computed_tokens()
seq_len = self.input_data.seq_lens[-1]

# NOTE: mm_data only includes the subset of multi-modal items that
# NOTE: mm_kwargs only includes the subset of multi-modal items that
# intersect with the current prefill positions.
mm_data, placeholder_maps = MultiModalPlaceholderMap.from_seq_group(
mm_kwargs, placeholder_maps = MultiModalPlaceholderMap.from_seq_group(
seq_group_metadata, range(computed_len, seq_len))

if not mm_data:
if not mm_kwargs:
return

if self.runner.mm_registry.has_processor(self.runner.model_config):
mm_kwargs = mm_data
else:
mm_kwargs = self.multi_modal_input_mapper(
mm_data,
seq_group_metadata.mm_processor_kwargs,
)

# special processing for mrope position deltas.
if self.runner.model_config.uses_mrope:
assert not self.chunked_prefill, \
@ -480,12 +471,6 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
use_mla=self.model_config.use_mla,
) if needs_attn_backend else None

# Multi-modal data support
self.mm_registry = MULTIMODAL_REGISTRY
self.multi_modal_input_mapper = self.mm_registry \
.create_input_mapper(self.model_config)
self.mm_registry.init_mm_limits_per_prompt(self.model_config)

# Lazy initialization.
self.model: nn.Module  # Set after init_Model
# Set after load_model.
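
The same simplification repeats in each runner below: MultiModalPlaceholderMap.from_seq_group() already yields processed MultiModalKwargs for the prefill window, so the has_processor()/input-mapper branch disappears. A condensed sketch; the helper name and its arguments are illustrative stand-ins for the builder state in the hunk above:

    from vllm.multimodal import MultiModalPlaceholderMap
    from vllm.sequence import SequenceGroupMetadata

    def prefill_mm_kwargs(seq_group_metadata: SequenceGroupMetadata,
                          computed_len: int, seq_len: int):
        # Returns ready-to-use kwargs plus placeholder maps for the positions
        # being prefilled; no further per-runner mapping is required.
        mm_kwargs, placeholder_maps = MultiModalPlaceholderMap.from_seq_group(
            seq_group_metadata, range(computed_len, seq_len))
        return mm_kwargs, placeholder_maps
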
@ -100,6 +100,8 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
vllm_config=vllm_config,
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker,
input_registry=input_registry,
mm_registry=mm_registry,
)

# Crash for unsupported encoder/scenarios
@ -45,8 +45,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.sampling_metadata import SequenceGroupToSample
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalKwargs)
from vllm.multimodal import BatchedTensorInputs, MultiModalKwargs
from vllm.sampling_params import SamplingParams
from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
Logprob, SequenceData, SequenceGroupMetadata,
@ -545,10 +544,6 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
]
gc.set_threshold(*requested_gc_thrs)

# Multi-modal data support
self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
.create_input_mapper(self.model_config)

self.skip_warmup = os.environ.get('VLLM_SKIP_WARMUP',
'false').lower() == 'true'

@ -731,9 +726,8 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
# is always the first token in the sequence.
input_positions.append(list(range(context_len, seq_len)))

mm_data = seq_group_metadata.multi_modal_data
if mm_data:
mm_kwargs = self.multi_modal_input_mapper(mm_data)
mm_kwargs = seq_group_metadata.multi_modal_data
if mm_kwargs:
multi_modal_kwargs_list.append(mm_kwargs)

if seq_group_metadata.block_tables is None:
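
Downstream of this change the runners still batch whatever they collected before the forward pass; that part is untouched by the commit. A hedged reminder of the existing pattern, assuming the usual MultiModalKwargs.batch() helper and an accumulator built as in the hunk above:

    from vllm.multimodal import MultiModalKwargs

    multi_modal_input = MultiModalKwargs.batch(multi_modal_kwargs_list)
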
@ -457,7 +457,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.enable_lora = self.runner.lora_config is not None
self.enable_prompt_adapter = (self.runner.prompt_adapter_config
is not None)
self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper

# Attention metadata inputs.
if self.attn_backend is not None:
@ -675,23 +674,15 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup,
seq_group_metadata: SequenceGroupMetadata):
"""If multi-modal data is given, add it to the input."""
# NOTE: mm_data only includes the subset of multi-modal items that
# NOTE: mm_kwargs only includes the subset of multi-modal items that
# intersect with the current prefill positions.
positions = inter_data.input_positions[0]
mm_data, placeholder_maps = MultiModalPlaceholderMap.from_seq_group(
mm_kwargs, placeholder_maps = MultiModalPlaceholderMap.from_seq_group(
seq_group_metadata,
range(positions[0], positions[0] + len(positions)))
if not mm_data:
if not mm_kwargs:
return

if self.runner.mm_registry.has_processor(self.runner.model_config):
mm_kwargs = mm_data
else:
mm_kwargs = self.multi_modal_input_mapper(
mm_data,
seq_group_metadata.mm_processor_kwargs,
)

inter_data.multi_modal_kwargs = mm_kwargs
inter_data.multi_modal_placeholder_maps = placeholder_maps

@ -1085,9 +1076,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
# Multi-modal data support
self.input_registry = input_registry
self.mm_registry = mm_registry
self.multi_modal_input_mapper = mm_registry \
.create_input_mapper(model_config)
self.mm_registry.init_mm_limits_per_prompt(self.model_config)

# Lazy initialization
self.model: nn.Module  # Set after load_model
@ -1327,8 +1315,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):

dummy_data = self.input_registry \
.dummy_data_for_profiling(self.model_config,
seq_len,
self.mm_registry)
seq_len,
self.mm_registry)

seq = SequenceGroupMetadata(
request_id=str(group_id),
@ -15,8 +15,7 @@ from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.neuron import get_neuron_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalKwargs)
from vllm.multimodal import BatchedTensorInputs, MultiModalKwargs
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase
@ -69,11 +68,6 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
self.device = self.device_config.device
self.pin_memory = is_pin_memory_available()

# Multi-modal data support
self.mm_registry = MULTIMODAL_REGISTRY
self.multi_modal_input_mapper = self.mm_registry \
.create_input_mapper(self.model_config)

# Lazy initialization.
self.model: nn.Module  # initialize after load_model.

@ -149,16 +143,8 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
assert len(block_table) == 1
input_block_ids.append(block_table[0])

mm_data = seq_group_metadata.multi_modal_data
if mm_data:
if self.mm_registry.has_processor(self.model_config):
mm_kwargs = mm_data
else:
mm_kwargs = self.multi_modal_input_mapper(
mm_data,
seq_group_metadata.mm_processor_kwargs,
)

mm_kwargs = seq_group_metadata.multi_modal_data
if mm_kwargs:
multi_modal_kwargs_list.append(mm_kwargs)

max_seq_len = max(seq_lens)
@ -188,20 +188,11 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]):
input_positions.extend(list(positions_range))

if seq_group_metadata.multi_modal_data:
# NOTE: mm_data only includes the subset of multi-modal items
# NOTE: mm_kwargs only includes the subset of multi-modal items
# that intersect with the current prefill positions.
mm_data, placeholder_maps = MultiModalPlaceholderMap \
mm_kwargs, placeholder_maps = MultiModalPlaceholderMap \
.from_seq_group(seq_group_metadata, positions_range)

if self.runner.mm_registry.has_processor(
self.runner.model_config):
mm_kwargs = mm_data
else:
mm_kwargs = self.runner.multi_modal_input_mapper(
mm_data,
seq_group_metadata.mm_processor_kwargs,
)

multi_modal_kwargs_list.append(mm_kwargs)

for modality, placeholder_map in placeholder_maps.items():
@ -404,9 +395,6 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
# Multi-modal data support
self.input_registry = input_registry
self.mm_registry = mm_registry
self.multi_modal_input_mapper = mm_registry \
.create_input_mapper(model_config)
self.mm_registry.init_mm_limits_per_prompt(self.model_config)

# Lazy initialization.
self.model: nn.Module  # Set after init_Model