Compare commits

...

1 Commit

Author SHA1 Message Date
42018e8d96 Revert "Implicit language-model-only mode via limit-mm-per-prompt (#22299)"
This reverts commit 08b751ba749541259e5450d6371d822fdf769b8a.
2025-08-08 22:51:13 -07:00
16 changed files with 116 additions and 271 deletions
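
The change being reverted (#22299) made vLLM skip loading the visual/audio components and run a multimodal model in text-only mode whenever every supported modality limit was set to 0 via limit-mm-per-prompt; after this revert those components are always initialized. A minimal sketch of the configuration that used to trigger that mode, using the public limit_mm_per_prompt engine argument and the model ID taken from the removed test below (illustrative only, not part of this diff):

    from vllm import LLM

    # With the reverted feature, zero limits for all supported modalities
    # skipped loading the vision tower; after this revert the tower is
    # loaded regardless, and the limits only cap items per prompt.
    llm = LLM(
        model="Qwen/Qwen2.5-VL-3B-Instruct",
        limit_mm_per_prompt={"image": 0, "video": 0},
    )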

View File

@@ -1,38 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""
-Unit tests for MultiModalRegistry.supports_multimodal_inputs and
-Qwen2.5-VL visual component loading behavior.
-"""
-import pytest
-
-from vllm.multimodal import MULTIMODAL_REGISTRY
-
-from ..models.utils import build_model_context
-
-
-@pytest.mark.parametrize(
-    "model_id,limit_mm_per_prompt,expected",
-    [
-        ("Qwen/Qwen2-0.5B-Instruct", {}, False),
-        ("Qwen/Qwen2.5-VL-3B-Instruct", {}, True),
-        ("Qwen/Qwen2.5-VL-3B-Instruct", {
-            "image": 0,
-            "video": 0
-        }, False),
-        ("Qwen/Qwen2.5-VL-3B-Instruct", {
-            "image": 0
-        }, True),
-    ],
-)
-@pytest.mark.core_model
-def test_supports_multimodal_inputs(model_id, limit_mm_per_prompt, expected):
-    """Test supports_multimodal_inputs returns correct boolean for various
-    configs."""
-    ctx = build_model_context(
-        model_id,
-        limit_mm_per_prompt=limit_mm_per_prompt,
-    )
-    assert MULTIMODAL_REGISTRY.supports_multimodal_inputs(
-        ctx.model_config) is expected

View File

@@ -1695,6 +1695,15 @@ class ModelConfig:
         return mm_config.mm_processor_cache_gb > 0
 
+    @property
+    def enable_mm_input_cache(self) -> bool:
+        """Whether the multi-modal input cache should be enabled."""
+        mm_config = self.multimodal_config
+        if mm_config is None:
+            return False
+
+        return mm_config.mm_processor_cache_gb > 0
+
     def get_mm_input_cache_gb(self) -> int:
         mm_config = self.multimodal_config
         if mm_config is None:

View File

@@ -521,22 +521,18 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
             config.projector_hidden_act = "gelu"
 
         # TODO: Optionally initializes this for supporting embeddings.
-        if multimodal_config.get_limit_per_prompt("image"):
-            self.vision_tower = init_vision_tower_for_llava(
-                config,
-                quant_config,
-                require_post_norm=False,
-                prefix=maybe_prefix(prefix, "vision_tower"))
-            self.multi_modal_projector = LlavaMultiModalProjector(
-                vision_hidden_size=config.vision_config.hidden_size,
-                text_hidden_size=config.text_config.hidden_size,
-                projector_hidden_act=config.projector_hidden_act,
-                multimodal_projector_bias=config.multimodal_projector_bias,
-                quant_config=quant_config,
-                prefix=maybe_prefix(prefix, "multi_modal_projector"))
-        else:
-            self.vision_tower = None
-            self.multi_modal_projector = None
+        self.vision_tower = init_vision_tower_for_llava(
+            config,
+            quant_config,
+            require_post_norm=False,
+            prefix=maybe_prefix(prefix, "vision_tower"))
+        self.multi_modal_projector = LlavaMultiModalProjector(
+            vision_hidden_size=config.vision_config.hidden_size,
+            text_hidden_size=config.text_config.hidden_size,
+            projector_hidden_act=config.projector_hidden_act,
+            multimodal_projector_bias=config.multimodal_projector_bias,
+            quant_config=quant_config,
+            prefix=maybe_prefix(prefix, "multi_modal_projector"))
 
         self.language_model = init_vllm_registered_model(
             vllm_config=vllm_config,
@@ -760,11 +756,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        skip_prefixes = []
-        if self.vision_tower is None and self.multi_modal_projector is None:
-            skip_prefixes.extend(["vision_tower.", "multi_modal_projector."])
-
-        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
+        loader = AutoWeightsLoader(self)
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

View File

@@ -428,24 +428,20 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA,
             config.projector_hidden_act = "gelu"
 
         # TODO: Optionally initializes this for supporting embeddings.
-        if multimodal_config.get_limit_per_prompt("image"):
-            self.vision_tower = init_vision_tower_for_llava(
-                config,
-                quant_config,
-                require_post_norm=False,
-                prefix=maybe_prefix(prefix, "vision_tower"))
-            self.multi_modal_projector = Mistral3MultiModalProjector(
-                vision_hidden_size=config.vision_config.hidden_size,
-                text_hidden_size=config.text_config.hidden_size,
-                projector_hidden_act=config.projector_hidden_act,
-                spatial_merge_size=config.spatial_merge_size,
-                patch_size=config.vision_config.patch_size,
-                multimodal_projector_bias=config.multimodal_projector_bias,
-                quant_config=quant_config,
-                prefix=maybe_prefix(prefix, "multi_modal_projector"))
-        else:
-            self.vision_tower = None
-            self.multi_modal_projector = None
+        self.vision_tower = init_vision_tower_for_llava(
+            config,
+            quant_config,
+            require_post_norm=False,
+            prefix=maybe_prefix(prefix, "vision_tower"))
+        self.multi_modal_projector = Mistral3MultiModalProjector(
+            vision_hidden_size=config.vision_config.hidden_size,
+            text_hidden_size=config.text_config.hidden_size,
+            projector_hidden_act=config.projector_hidden_act,
+            spatial_merge_size=config.spatial_merge_size,
+            patch_size=config.vision_config.patch_size,
+            multimodal_projector_bias=config.multimodal_projector_bias,
+            quant_config=quant_config,
+            prefix=maybe_prefix(prefix, "multi_modal_projector"))
 
         self.language_model = init_vllm_registered_model(
             vllm_config=vllm_config,
@@ -615,11 +611,7 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA,
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        skip_prefixes = []
-        if self.vision_tower is None and self.multi_modal_projector is None:
-            skip_prefixes = ["vision_tower.", "multi_modal_projector."]
-
-        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
+        loader = AutoWeightsLoader(self)
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
 
     def get_mm_mapping(self) -> MultiModelKeys:

View File

@@ -737,20 +737,16 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.config = config
         self.quant_config = quant_config
         self.multimodal_config = multimodal_config
-        if multimodal_config.get_limit_per_prompt("image"):
-            self.vision_model = Llama4VisionModel(
-                config.vision_config,
-                None,
-                prefix=maybe_prefix(prefix, "vision_model"),
-                use_data_parallel=self.use_data_parallel,
-            )
-            self.multi_modal_projector = Llama4MultiModalProjector(
-                self.config,
-                None,
-                prefix=maybe_prefix(prefix, "multi_modal_projector"))
-        else:
-            self.vision_model = None
-            self.multi_modal_projector = None
+        self.vision_model = Llama4VisionModel(
+            config.vision_config,
+            None,
+            prefix=maybe_prefix(prefix, "vision_model"),
+            use_data_parallel=self.use_data_parallel,
+        )
+        self.multi_modal_projector = Llama4MultiModalProjector(
+            self.config,
+            None,
+            prefix=maybe_prefix(prefix, "multi_modal_projector"))
         self.language_model = initialize_model(
             vllm_config=vllm_config.with_hf_config(config.text_config,
                                                    ["LlamaForCausalLM"]),
@@ -787,8 +783,6 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
 
     def _process_image_input(
             self, image_input: Llama4ImagePatchInputs) -> MultiModalEmbeddings:
-        assert self.vision_model and self.multi_modal_projector
-
         flat_data = image_input["flat_data"]
         patches_per_image = image_input["patches_per_image"].tolist()
@@ -1054,10 +1048,6 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
         language_model_weights, other_weights = (
             self._separate_and_rename_weights(weights))
 
-        # Skip loading vision model and projector if they're not initialized.
-        if self.vision_model is None and self.multi_modal_projector is None:
-            other_weights = []
-
         # Handle expert scale parameters
         regular_weights, expert_scale_weights, updated_params_from_experts = (
             self._handle_expert_scale_broadcasting(language_model_weights,

View File

@@ -722,24 +722,13 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
                 "exactly same result as the transformers implementation "
                 "in the audio tower part.")
 
-        if multimodal_config.get_limit_per_prompt("audio"):
-            self.audio_tower = Qwen2_5OmniAudioEncoder(
-                thinker_config.audio_config)
-        else:
-            self.audio_tower = None
-
-        if multimodal_config.get_limit_per_prompt(
-                "image") or multimodal_config.get_limit_per_prompt("video"):
-            self.visual = Qwen2_5_VisionTransformer(
-                vision_config=thinker_config.vision_config,
-                norm_eps=getattr(thinker_config.text_config, "rms_norm_eps",
-                                 1e-6),
-                quant_config=quant_config,
-                prefix=maybe_prefix(prefix, "visual"),
-            )
-        else:
-            self.visual = None
+        self.audio_tower = Qwen2_5OmniAudioEncoder(thinker_config.audio_config)
+        self.visual = Qwen2_5_VisionTransformer(
+            vision_config=thinker_config.vision_config,
+            norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6),
+            quant_config=quant_config,
+            prefix=maybe_prefix(prefix, "visual"),
+        )
         self.quant_config = quant_config
 
         self.language_model = init_vllm_registered_model(
             vllm_config=vllm_config,
@@ -897,15 +886,9 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        skip_prefixes = ["talker.", "token2wav."]
-        if self.audio_tower is None:
-            skip_prefixes.extend(["audio_tower."])
-        if self.visual is None:
-            skip_prefixes.extend(["visual."])
-
         loader = AutoWeightsLoader(
             self,
-            skip_prefixes=skip_prefixes,
+            skip_prefixes=["talker.", "token2wav."],
         )
         loaded_weights = loader.load_weights(weights,
                                              mapper=self.hf_to_vllm_mapper)

View File

@@ -843,17 +843,12 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.config = config
         self.multimodal_config = multimodal_config
 
-        if multimodal_config.get_limit_per_prompt("image") or \
-            multimodal_config.get_limit_per_prompt("video"):
-            self.visual = Qwen2_5_VisionTransformer(
-                config.vision_config,
-                norm_eps=getattr(config, "rms_norm_eps", 1e-6),
-                quant_config=self._maybe_ignore_quant_config(
-                    self.quant_config),
-                prefix=maybe_prefix(prefix, "visual"),
-            )
-        else:
-            self.visual = None
+        self.visual = Qwen2_5_VisionTransformer(
+            config.vision_config,
+            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
+            quant_config=self._maybe_ignore_quant_config(self.quant_config),
+            prefix=maybe_prefix(prefix, "visual"),
+        )
 
         self.language_model = init_vllm_registered_model(
             vllm_config=vllm_config,
@@ -1157,10 +1152,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        skip_prefixes = []
-        if self.visual is None:
-            skip_prefixes.extend(["visual."])
-
-        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
+        loader = AutoWeightsLoader(self)
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
 
     def get_mm_mapping(self) -> MultiModelKeys:

View File

@@ -1049,16 +1049,12 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.config = config
         self.multimodal_config = multimodal_config
 
-        if multimodal_config.get_limit_per_prompt("image") or \
-            multimodal_config.get_limit_per_prompt("video"):
-            self.visual = Qwen2VisionTransformer(
-                config.vision_config,
-                norm_eps=getattr(config, "rms_norm_eps", 1e-6),
-                quant_config=self._maybe_ignore_quant_config(quant_config),
-                prefix=maybe_prefix(prefix, "visual"),
-            )
-        else:
-            self.visual = None
+        self.visual = Qwen2VisionTransformer(
+            config.vision_config,
+            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
+            quant_config=self._maybe_ignore_quant_config(quant_config),
+            prefix=maybe_prefix(prefix, "visual"),
+        )
 
         self.language_model = init_vllm_registered_model(
             vllm_config=vllm_config,
@@ -1354,10 +1350,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        skip_prefixes = []
-        if self.visual is None:
-            skip_prefixes.extend(["visual."])
-
-        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
+        loader = AutoWeightsLoader(self)
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
 
     def get_mm_mapping(self) -> MultiModelKeys:
@@ -1452,8 +1445,5 @@ class Tarsier2ForConditionalGeneration(Qwen2VLForConditionalGeneration):
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        skip_prefixes = []
-        if self.visual is None:
-            skip_prefixes.extend(["visual."])
-
-        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
+        loader = AutoWeightsLoader(self)
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

View File

@@ -837,35 +837,27 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.config = config
         self.multimodal_config = multimodal_config
 
-        if multimodal_config.get_limit_per_prompt("image"):
-            self.vision_model = Step3VisionTransformer(config.vision_config,
-                                                       None,
-                                                       prefix=maybe_prefix(
-                                                           prefix,
-                                                           "vision_model"))
-            self.vit_downsampler = nn.Conv2d(
-                config.vision_config.hidden_size,
-                config.vision_config.output_hidden_size,
-                kernel_size=2,
-                stride=config.understand_projector_stride)
-            self.vit_downsampler2 = nn.Conv2d(
-                config.vision_config.output_hidden_size,
-                config.vision_config.output_hidden_size * 2,
-                kernel_size=3,
-                stride=2,
-                padding=1,
-            )
-            self.vit_large_projector = nn.Linear(
-                config.vision_config.output_hidden_size * 2,
-                config.hidden_size,
-                bias=config.projector_bias,
-            )
-        else:
-            self.vision_model = None
-            self.vit_downsampler = None
-            self.vit_downsampler2 = None
-            self.vit_large_projector = None
+        self.vision_model = Step3VisionTransformer(config.vision_config,
+                                                   None,
+                                                   prefix=maybe_prefix(
+                                                       prefix, "vision_model"))
+        self.vit_downsampler = nn.Conv2d(
+            config.vision_config.hidden_size,
+            config.vision_config.output_hidden_size,
+            kernel_size=2,
+            stride=config.understand_projector_stride)
+        self.vit_downsampler2 = nn.Conv2d(
+            config.vision_config.output_hidden_size,
+            config.vision_config.output_hidden_size * 2,
+            kernel_size=3,
+            stride=2,
+            padding=1,
+        )
+        self.vit_large_projector = nn.Linear(
+            config.vision_config.output_hidden_size * 2,
+            config.hidden_size,
+            bias=config.projector_bias,
+        )
         self.language_model = init_vllm_registered_model(
             vllm_config=vllm_config,
             hf_config=config.text_config,
@@ -1054,15 +1046,7 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
         return self.language_model.sample(logits, sampling_metadata)
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
-        skip_prefixes = []
-        if self.vision_model is None and self.vit_large_projector is None:
-            skip_prefixes = [
-                "vision_model.", "vit_downsampler.", "vit_downsampler2.",
-                "vit_large_projector."
-            ]
-
-        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
+        loader = AutoWeightsLoader(self)
         loaded_weights = loader.load_weights(weights,
                                              mapper=self.hf_to_vllm_mapper)
         return loaded_weights

View File

@@ -115,45 +115,6 @@ class MultiModalRegistry:
 
         return True  # Success
 
-    def enable_mm_input_cache(self, model_config: "ModelConfig") -> bool:
-        """Whether the multi-modal input cache should be enabled.
-
-        NOTE: This is put under MultiModalRegistry on purpose to respect
-        text-only mode for multimodal models.
-        """
-        if not self.supports_multimodal_inputs(model_config):
-            return False
-
-        mm_config = model_config.get_multimodal_config()
-        return mm_config.mm_processor_cache_gb > 0
-
-    def supports_multimodal_inputs(self, model_config: "ModelConfig") -> bool:
-        """
-        Checks if the model supports multimodal inputs.
-
-        Returns True if the model is multimodal with any non-zero supported
-        modalities, otherwise returns False, effectively running in
-        text-only mode.
-        """
-        if not model_config.is_multimodal_model:
-            return False
-
-        processor = self.create_processor(model_config, disable_cache=False)
-        supported_modalities = processor.info.get_supported_mm_limits()
-
-        mm_config = model_config.get_multimodal_config()
-        # Check if all supported modalities have limit == 0
-        if all(
-                mm_config.get_limit_per_prompt(modality) == 0
-                for modality in supported_modalities):
-            logger.info_once(
-                "All limits of multimodal modalities supported by the model "
-                "are set to 0, running in text-only mode.")
-            return False
-
-        return True
-
     def get_max_tokens_per_item_by_modality(
         self,
         model_config: "ModelConfig",

View File

@@ -189,7 +189,7 @@ def compute_encoder_budget(
         in the input sequence.
     """
-    if not mm_registry.supports_multimodal_inputs(model_config):
+    if not model_config.is_multimodal_model:
         return 0, 0
 
     # TODO: handle encoder-decoder models once we support them.

View File

@@ -21,7 +21,6 @@ from vllm.distributed import stateless_destroy_torch_distributed_process_group
 from vllm.logger import init_logger
 from vllm.logging_utils.dump_input import dump_engine_exception
 from vllm.lora.request import LoRARequest
-from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.tasks import POOLING_TASKS, SupportedTask
 from vllm.transformers_utils.config import (
     maybe_register_config_serialize_by_value)
@@ -126,7 +125,7 @@ class EngineCore:
         )
 
         self.mm_input_cache_server = MultiModalInputCacheServer(
-            vllm_config.model_config, MULTIMODAL_REGISTRY)
+            vllm_config.model_config)
 
         # Setup batch queue for pipeline parallelism.
         # Batch queue for scheduled batches. This enables us to asynchronously

View File

@@ -3,7 +3,7 @@
 from collections.abc import Sequence
 from typing import TYPE_CHECKING, Optional
 
-from vllm.multimodal import MultiModalKwargs, MultiModalRegistry
+from vllm.multimodal import MultiModalKwargs
 from vllm.multimodal.cache import MultiModalCache, MultiModalCacheItemMetadata
 from vllm.utils import is_list_of
@@ -46,11 +46,10 @@ if TYPE_CHECKING:
 class MultiModalInputCacheClient:
     """Used by P0 to check whether multi-modal kwargs are cached in P1."""
 
-    def __init__(self, model_config: "ModelConfig",
-                 mm_registry: MultiModalRegistry) -> None:
+    def __init__(self, model_config: "ModelConfig") -> None:
         super().__init__()
 
-        self.enabled = mm_registry.enable_mm_input_cache(model_config)
+        self.enabled = model_config.enable_mm_input_cache
         self.mm_cache = MultiModalCache.get_lru_cache(
             model_config.get_mm_input_cache_gb(),
             MultiModalCacheItemMetadata,
@@ -86,11 +85,10 @@ class MultiModalInputCacheClient:
 class MultiModalInputCacheServer:
     """Used by P1 to avoid requiring past multi-modal kwargs from P0."""
 
-    def __init__(self, model_config: "ModelConfig",
-                 mm_registry: MultiModalRegistry) -> None:
+    def __init__(self, model_config: "ModelConfig") -> None:
         super().__init__()
 
-        self.enabled = mm_registry.enable_mm_input_cache(model_config)
+        self.enabled = model_config.enable_mm_input_cache
         self.mm_cache = MultiModalCache.get_lru_cache(
             model_config.get_mm_input_cache_gb(),
             MultiModalKwargs,

View File

@@ -51,7 +51,7 @@ class Processor:
             mm_registry)
 
         self.mm_input_cache_client = MultiModalInputCacheClient(
-            self.model_config, mm_registry)
+            self.model_config)
 
     @property
     def mm_registry(self):

View File

@@ -129,6 +129,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
             cache_config.cache_dtype]
 
+        self.is_multimodal_model = model_config.is_multimodal_model
         self.is_pooling_model = model_config.pooler_config is not None
         self.is_encoder_only_model = False
         self.is_multimodal_raw_input_supported = (
@@ -148,8 +149,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # Multi-modal data support
         self.mm_registry = MULTIMODAL_REGISTRY
         self.uses_mrope = model_config.uses_mrope
-        self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
-            model_config)
 
         # Sampler
         self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)
@@ -331,8 +330,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             self.mm_registry,
             max_model_len=self.max_model_len,
             max_num_reqs=self.max_num_reqs,
-        ) if self.supports_mm_inputs \
-            else None)
+        ) if self.is_multimodal_model else None)
 
         self.reorder_batch_threshold: Optional[int] = None
@@ -1481,14 +1479,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
         # _prepare_inputs may reorder the batch, so we must gather multi
         # modal outputs after that to ensure the correct order
-        if self.supports_mm_inputs:
+        if self.is_multimodal_model:
             # Run the multimodal encoder if any.
             self._execute_mm_encoder(scheduler_output)
             mm_embeds = self._gather_mm_embeddings(scheduler_output)
         else:
             mm_embeds = []
 
-        if self.supports_mm_inputs and get_pp_group().is_first_rank:
+        if self.is_multimodal_model and get_pp_group().is_first_rank:
             # NOTE(woosuk): To unify token ids and soft tokens (vision
             # embeddings), we always use embeddings (rather than token ids)
             # as input to the multimodal model, even when the input is text.
@@ -1819,7 +1817,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         else:
             target_hidden_states = hidden_states[token_indices]
 
         mm_embeds = None
-        if self.supports_mm_inputs:
+        if self.is_multimodal_model:
             mm_embeds = self._gather_mm_embeddings(scheduler_output,
                                                    shift_computed_tokens=1)
@@ -2211,7 +2209,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         with self.maybe_dummy_run_with_lora(self.lora_config,
                                             num_scheduled_tokens):
-            if self.supports_mm_inputs:
+            if self.is_multimodal_model:
                 input_ids = None
                 inputs_embeds = self.inputs_embeds[:num_tokens]
                 model_mm_kwargs = self._dummy_mm_kwargs(num_reqs)
@@ -2419,7 +2417,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
     def profile_run(self) -> None:
         # Profile with multimodal encoder & encoder cache.
-        if self.supports_mm_inputs:
+        if self.is_multimodal_model:
             mm_budget = self.mm_budget
             assert mm_budget is not None

View File

@@ -157,6 +157,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             cache_config.cache_dtype]
         self._hidden_states_dtype = self.dtype
 
+        self.is_multimodal_model = model_config.is_multimodal_model
         self.sliding_window = model_config.get_sliding_window()
         self.block_size = cache_config.block_size
         self.max_model_len = model_config.max_model_len
@@ -192,8 +193,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # Multi-modal data support
         self.mm_registry = MULTIMODAL_REGISTRY
         self.uses_mrope = model_config.uses_mrope
-        self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
-            model_config)
 
         # TODO: Support M-RoPE (e.g, Qwen2-VL)
         assert not self.uses_mrope, "TPU does not support M-RoPE yet."
@@ -294,7 +293,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             self.mm_registry,
             max_model_len=self.max_model_len,
             max_num_reqs=self.max_num_reqs,
-        ) if self.supports_mm_inputs else None)
+        ) if self.is_multimodal_model else None)
 
         if not self.use_spmd:
             self.sample_from_logits_func = torch.compile(
@@ -948,7 +947,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
     def _get_model_inputs(self, input_ids: torch.Tensor,
                           mm_embeds: list[torch.Tensor]):
-        if self.supports_mm_inputs:
+        if self.is_multimodal_model:
             # NOTE(woosuk): To unify token ids and soft tokens (vision
             # embeddings), we always use embeddings (rather than token ids)
             # as input to the multimodal model, even when the input is text.
@@ -980,7 +979,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             return self.kv_connector_no_forward(scheduler_output,
                                                 self.vllm_config)
 
-        if self.supports_mm_inputs:
+        if self.is_multimodal_model:
             # Run the multimodal encoder if any.
             self._execute_mm_encoder(scheduler_output)
             mm_embeds = self._gather_mm_embeddings(scheduler_output)
@@ -1231,7 +1230,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
     @torch.no_grad()
     def _dummy_run(self, num_tokens: int, num_reqs: int,
                    num_blocks: int) -> None:
-        if self.supports_mm_inputs:
+        if self.is_multimodal_model:
             input_ids = None
             inputs_embeds = torch.zeros((num_tokens, self.hidden_size),
                                         dtype=self.dtype,
@@ -1272,7 +1271,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 _num_slices_per_kv_cache_update_block,
             )
 
-        if self.supports_mm_inputs:
+        if self.is_multimodal_model:
             torch._dynamo.mark_dynamic(inputs_embeds, 0)
         else:
             torch._dynamo.mark_dynamic(input_ids, 0)
@@ -1306,7 +1305,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             xm.mark_step()  # Captures metadata updates
 
     def _precompile_mm_encoder(self) -> None:
-        if not self.supports_mm_inputs:
+        if not self.is_multimodal_model:
             return
 
         # Pre-compile MM encoder for all supported data modalities.
@@ -1528,7 +1527,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         num_tokens: int,
     ) -> None:
         # Profile with multimodal encoder & encoder cache.
-        if self.supports_mm_inputs:
+        if self.is_multimodal_model:
             mm_budget = self.mm_budget
             assert mm_budget is not None
@@ -1685,11 +1684,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             get_kv_transfer_group().set_host_xfer_buffer_ops(copy_kv_blocks)
 
     def reset_dynamo_cache(self):
-        # NOTE: We check `is_multimodal_model` instead of `supports_mm_inputs`
-        # since the compiled model object of the language backbone of a
-        # multimodal model needs to be extracted via `get_language_model`.
-        if self.model_config.is_multimodal_model:
+        if self.is_multimodal_model:
             compiled_model = self.model.get_language_model().model
         else:
             compiled_model = self.model.model