[Multimodal] Consolidate mm inputs into MultiModalFeatureSpec (#23779)

Signed-off-by: sfeng33 <4florafeng@gmail.com>
Author: Flora Feng
Date: 2025-08-29 03:36:57 -07:00
Committed by: GitHub
Parent: d9e00dbd1f
Commit: 69f46359dd
16 changed files with 143 additions and 146 deletions
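Summary: the three parallel, index-aligned multimodal fields (mm_kwargs, mm_hashes, mm_placeholders on EngineCoreRequest, and the matching multi_modal_* arguments on Request) are folded into a single mm_features list of MultiModalFeatureSpec entries. For illustration only (not part of the diff), a minimal sketch of the new construction, condensed from the test helpers updated below; the concrete values (request id, token ids, "hash_0") are placeholders:

from vllm.multimodal.inputs import (MultiModalFeatureSpec,
                                    MultiModalKwargsItem, PlaceholderRange)
from vllm.sampling_params import SamplingParams
from vllm.v1.engine import EngineCoreRequest

# One spec per multimodal item replaces the old parallel lists.
mm_features = [
    MultiModalFeatureSpec(
        data=MultiModalKwargsItem.dummy("dummy_m"),  # processed payload
        modality="image",
        identifier="hash_0",                         # key for encoder-output caching
        mm_position=PlaceholderRange(offset=2, length=336),
    )
]

request = EngineCoreRequest(
    request_id="demo",
    prompt_token_ids=[107, 4606, 236787, 107],
    mm_features=mm_features,
    sampling_params=SamplingParams(),
    pooling_params=None,
    eos_token_id=None,
    arrival_time=0.0,
    lora_request=None,
    cache_salt=None,
    data_parallel_rank=None,
)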

View File

@@ -64,8 +64,6 @@ def _run_incremental_decode(tokenizer,
request = EngineCoreRequest("",
prompt_token_ids,
None,
None,
None,
params,
None,
None,

View File

@@ -7,7 +7,8 @@ import pytest
import torch
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
from vllm.multimodal.inputs import (MultiModalFeatureSpec,
MultiModalKwargsItem, PlaceholderRange)
from vllm.sampling_params import SamplingParams
from vllm.utils import GiB_bytes, sha256, sha256_cbor_64bit
from vllm.v1.core.kv_cache_manager import KVCacheManager
@@ -37,17 +38,20 @@ def make_request(
mm_hashes: Optional[list[str]] = None,
cache_salt: Optional[str] = None,
):
if mm_positions is None:
mm_kwargs = None
else:
mm_item = MultiModalKwargsItem.dummy("dummy_m")
mm_kwargs = [mm_item] * len(mm_positions)
mm_features = []
if mm_positions is not None:
for j, position in enumerate(mm_positions):
identifier = mm_hashes[j] if mm_hashes else f"hash_{j}"
mm_feature = MultiModalFeatureSpec(
data=MultiModalKwargsItem.dummy("dummy_m"),
mm_position=position,
identifier=identifier,
modality="image")
mm_features.append(mm_feature)
return Request(request_id=request_id,
prompt_token_ids=prompt_token_ids,
multi_modal_kwargs=mm_kwargs,
multi_modal_hashes=mm_hashes,
multi_modal_placeholders=mm_positions,
mm_features=mm_features if mm_features else None,
sampling_params=SamplingParams(max_tokens=17),
pooling_params=None,
eos_token_id=100,

View File

@@ -9,7 +9,8 @@ import pytest
import torch
from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved
from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
from vllm.multimodal.inputs import (MultiModalFeatureSpec,
MultiModalKwargsItem, PlaceholderRange)
from vllm.sampling_params import SamplingParams
from vllm.utils import sha256, sha256_cbor_64bit
from vllm.v1.core.block_pool import BlockPool
@@ -32,17 +33,20 @@ def make_request(
prompt_logprobs: Optional[int] = None,
cache_salt: Optional[str] = None,
):
if mm_positions is None:
mm_kwargs = None
else:
mm_item = MultiModalKwargsItem.dummy("dummy_m")
mm_kwargs = [mm_item] * len(mm_positions)
mm_features = []
if mm_positions is not None:
for j, position in enumerate(mm_positions):
identifier = mm_hashes[j] if mm_hashes else f"hash_{j}"
mm_feature = MultiModalFeatureSpec(
data=MultiModalKwargsItem.dummy("dummy_m"),
mm_position=position,
identifier=identifier,
modality="image")
mm_features.append(mm_feature)
return Request(request_id=request_id,
prompt_token_ids=prompt_token_ids,
multi_modal_kwargs=mm_kwargs,
multi_modal_hashes=mm_hashes,
multi_modal_placeholders=mm_positions,
mm_features=mm_features if mm_features else None,
sampling_params=SamplingParams(
max_tokens=17, prompt_logprobs=prompt_logprobs),
pooling_params=None,

View File

@@ -8,7 +8,8 @@ import torch
from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
SchedulerConfig, SpeculativeConfig, VllmConfig)
from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
from vllm.multimodal.inputs import (MultiModalFeatureSpec,
MultiModalKwargsItem, PlaceholderRange)
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler
@@ -1308,21 +1309,24 @@ def create_requests_with_priority(
prompt_logprobs=prompt_logprobs)
requests = []
for i in range(num_requests):
mm_features = []
if mm_positions is not None:
mm_position = mm_positions[i]
mm_item = MultiModalKwargsItem.dummy("dummy_m")
mm_kwargs = [mm_item] * len(mm_position)
else:
mm_position = None
mm_kwargs = None
for j, position in enumerate(mm_position):
identifier = f"hash{i}_{j}"
mm_feature = MultiModalFeatureSpec(
data=MultiModalKwargsItem.dummy("dummy_m"),
mm_position=position,
identifier=identifier,
modality="image")
mm_features.append(mm_feature)
request = Request(
request_id=f"{i + starting_idx}",
prompt_token_ids=[i + starting_idx] * num_tokens,
sampling_params=sampling_params,
pooling_params=None,
multi_modal_kwargs=mm_kwargs,
multi_modal_placeholders=mm_position,
multi_modal_hashes=None,
mm_features=mm_features if mm_features else None,
eos_token_id=EOS_TOKEN_ID,
arrival_time=arrival_times[i],
priority=priorities[i],
@@ -1801,9 +1805,7 @@ def test_schedule_skip_tokenizer_init_structured_output_request():
request = Request(
request_id="0",
prompt_token_ids=[0, 1],
multi_modal_kwargs=None,
multi_modal_hashes=None,
multi_modal_placeholders=None,
mm_features=None,
sampling_params=sampling_params,
pooling_params=None,
eos_token_id=EOS_TOKEN_ID,

View File

@@ -6,7 +6,8 @@ import torch
from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
SchedulerConfig, SpeculativeConfig, VllmConfig)
from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
from vllm.multimodal.inputs import (MultiModalFeatureSpec,
MultiModalKwargsItem, PlaceholderRange)
from vllm.sampling_params import SamplingParams
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
init_none_hash)
@@ -139,19 +140,20 @@ def create_requests(
prompt_logprobs=prompt_logprobs)
requests = []
for i in range(num_requests):
mm_features = []
if mm_positions is not None:
mm_position = mm_positions[i]
mm_item = MultiModalKwargsItem.dummy("dummy_m")
mm_kwargs = [mm_item] * len(mm_position)
# Dummy hash for each mm item should be unique
# since encoder cache tracks entries by hash
mm_hashes = [
"hash" + str(i) + "_" + str(j) for j in range(len(mm_position))
]
else:
mm_position = None
mm_kwargs = None
mm_hashes = None
for j, position in enumerate(mm_position):
# Dummy hash for each mm item should be unique
# since encoder cache tracks entries by hash
identifier = f"hash{i}_{j}"
mm_feature = MultiModalFeatureSpec(
data=MultiModalKwargsItem.dummy("dummy_m"),
mm_position=position,
identifier=identifier,
modality="image")
mm_features.append(mm_feature)
prompt_token_ids = ([0] * num_tokens if same_prompt else [i] *
num_tokens)
request = Request(
@@ -159,9 +161,7 @@ def create_requests(
prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params,
pooling_params=None,
multi_modal_kwargs=mm_kwargs,
multi_modal_placeholders=mm_position,
multi_modal_hashes=mm_hashes,
mm_features=mm_features if mm_features else None,
eos_token_id=EOS_TOKEN_ID,
block_hasher=block_hasher,
)

View File

@@ -35,9 +35,7 @@ def make_request() -> EngineCoreRequest:
return EngineCoreRequest(
request_id=str(uuid.uuid4()),
prompt_token_ids=PROMPT_TOKENS,
mm_kwargs=None,
mm_hashes=None,
mm_placeholders=None,
mm_features=None,
sampling_params=SamplingParams(),
pooling_params=None,
eos_token_id=None,

View File

@@ -52,9 +52,7 @@ def make_request(
return EngineCoreRequest(
request_id=str(uuid.uuid4()),
prompt_token_ids=prompt_tokens_ids,
mm_kwargs=None,
mm_hashes=None,
mm_placeholders=None,
mm_features=None,
sampling_params=params,
pooling_params=None,
eos_token_id=None,

View File

@@ -26,16 +26,14 @@ def test_fast_inc_detok_invalid_utf8_err_case():
prompt_token_ids = [107, 4606, 236787, 107]
params = SamplingParams(skip_special_tokens=True)
request = EngineCoreRequest(
"test",
prompt_token_ids,
None,
None,
None,
params,
None,
None,
0.0,
None,
request_id="test",
prompt_token_ids=prompt_token_ids,
mm_features=None,
sampling_params=params,
pooling_params=None,
eos_token_id=None,
arrival_time=0.0,
lora_request=None,
cache_salt=None,
data_parallel_rank=None,
)

View File

@@ -52,11 +52,9 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind,
requests = [
EngineCoreRequest(request_id=f"request-{idx}",
prompt_token_ids=prompt_tokens,
arrival_time=0,
mm_kwargs=None,
mm_hashes=None,
mm_placeholders=None,
mm_features=None,
eos_token_id=None,
arrival_time=0,
lora_request=None,
cache_salt=None,
data_parallel_rank=None,
@@ -401,11 +399,9 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
requests = [
EngineCoreRequest(request_id=request_id_list[idx],
prompt_token_ids=prompt_tokens,
arrival_time=0,
mm_kwargs=None,
mm_hashes=None,
mm_placeholders=None,
mm_features=None,
eos_token_id=None,
arrival_time=0,
lora_request=None,
cache_salt=None,
data_parallel_rank=None,
@@ -566,11 +562,9 @@ def test_stop_token(include_stop_str_in_output: bool,
request = EngineCoreRequest(
request_id=request_id,
prompt_token_ids=prompt_tokens,
arrival_time=0,
mm_kwargs=None,
mm_hashes=None,
mm_placeholders=None,
mm_features=None,
eos_token_id=eos_token_id,
arrival_time=0,
lora_request=None,
cache_salt=None,
data_parallel_rank=None,
@@ -665,11 +659,9 @@ def test_stop_string(include_stop_str_in_output: bool,
EngineCoreRequest(
request_id=request_id_list[idx],
prompt_token_ids=prompt_tokens,
arrival_time=0,
mm_kwargs=None,
mm_hashes=None,
mm_placeholders=None,
mm_features=None,
eos_token_id=None,
arrival_time=0,
lora_request=None,
cache_salt=None,
data_parallel_rank=None,
@@ -781,11 +773,9 @@ def test_iteration_stats(dummy_test_vectors):
EngineCoreRequest(
request_id=f"request-{idx}",
prompt_token_ids=prompt_tokens,
arrival_time=0,
mm_kwargs=None,
mm_hashes=None,
mm_placeholders=None,
mm_features=None,
eos_token_id=None,
arrival_time=0,
lora_request=None,
cache_salt=None,
data_parallel_rank=None,

View File

@@ -162,9 +162,7 @@ def create_request(request_id: int,
prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params,
pooling_params=None,
multi_modal_kwargs=None,
multi_modal_placeholders=None,
multi_modal_hashes=None,
mm_features=None,
eos_token_id=EOS_TOKEN_ID,
block_hasher=get_request_block_hasher(block_size, hash_fn),
)

View File

@@ -12,9 +12,9 @@ from vllm.logger import init_logger
from vllm.utils import GiB_bytes, LRUCache
from vllm.utils.jsontree import json_map_leaves, json_reduce_leaves
from .inputs import (MultiModalFieldElem, MultiModalKwargs,
MultiModalKwargsItem, MultiModalKwargsItems,
NestedTensors)
from .inputs import (MultiModalFeatureSpec, MultiModalFieldElem,
MultiModalKwargs, MultiModalKwargsItem,
MultiModalKwargsItems, NestedTensors)
if TYPE_CHECKING:
from vllm.config import ModelConfig, VllmConfig
@@ -418,6 +418,16 @@ class BaseMultiModalReceiverCache(
MultiModalKwargsItem]):
"""The required interface for caches on P1."""
def get_and_update_features(
self,
mm_features: list["MultiModalFeatureSpec"],
) -> list["MultiModalFeatureSpec"]:
"""Update multimodal features with cached encoder outputs."""
for feature in mm_features:
feature.data = self.get_and_update_item(feature.data,
feature.identifier)
return mm_features
class MultiModalReceiverCache(BaseMultiModalReceiverCache):
"""

View File

@@ -198,6 +198,29 @@ A dictionary containing nested tensors which have been batched via
"""
@dataclass
class MultiModalFeatureSpec:
"""
Represents a single multimodal input with its processed data and metadata.
Used by the V1 engine to track multimodal data through processing and
caching. A request containing multiple multimodal items will have one
MultiModalFeatureSpec per item.
"""
data: Optional["MultiModalKwargsItem"]
"""Multimodal data for this feature"""
modality: str
"""Based on the input, e.g., "image", "audio", "video"."""
identifier: str
"""mm_hash or uuid for caching encoder outputs."""
mm_position: PlaceholderRange
"""e.g., PlaceholderRange(offset=2, length=336)"""
@dataclass
class MultiModalFieldElem:
"""

View File

@@ -3,14 +3,13 @@
import enum
import time
from collections.abc import Sequence
from typing import Any, Optional, Union
import msgspec
import torch
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
from vllm.multimodal.inputs import MultiModalFeatureSpec
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.v1.metrics.stats import SchedulerStats
@@ -48,9 +47,7 @@ class EngineCoreRequest(
request_id: str
prompt_token_ids: list[int]
mm_kwargs: Optional[Sequence[Optional[MultiModalKwargsItem]]]
mm_hashes: Optional[list[str]]
mm_placeholders: Optional[list[PlaceholderRange]]
mm_features: Optional[list[MultiModalFeatureSpec]]
sampling_params: Optional[SamplingParams]
pooling_params: Optional[PoolingParams]
eos_token_id: Optional[int]

View File

@@ -434,15 +434,13 @@ class EngineCore:
This function could be directly used in input processing thread to allow
request initialization running in parallel with Model forward
"""
if request.mm_hashes is not None:
assert request.mm_kwargs is not None
# Note on thread safety: no race condition.
# `mm_receiver_cache` is reset at the end of LLMEngine init,
# and will only be accessed in the input processing thread afterwards.
if self.mm_receiver_cache is not None:
request.mm_kwargs = self.mm_receiver_cache.get_and_update(
request.mm_kwargs, request.mm_hashes)
# Note on thread safety: no race condition.
# `mm_receiver_cache` is reset at the end of LLMEngine init,
# and will only be accessed in the input processing thread afterwards.
if self.mm_receiver_cache is not None and request.mm_features:
request.mm_features = (
self.mm_receiver_cache.get_and_update_features(
request.mm_features))
req = Request.from_engine_core_request(request,
self.request_block_hasher)

View File

@@ -12,7 +12,7 @@ from vllm.inputs.preprocess import InputPreprocessor
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.cache import processor_cache_from_config
from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
from vllm.multimodal.inputs import MultiModalFeatureSpec
from vllm.multimodal.processing import EncDecMultiModalProcessor
from vllm.multimodal.utils import argsort_mm_positions
from vllm.pooling_params import PoolingParams
@@ -346,9 +346,8 @@ class Processor:
pooling_params = params.clone()
# Multimodal related.
sorted_mm_inputs: Optional[list[Optional[MultiModalKwargsItem]]] = None
sorted_mm_positions: Optional[list[PlaceholderRange]] = None
sorted_mm_hashes: Optional[list[str]] = None
mm_features: Optional[list[MultiModalFeatureSpec]] = None
if decoder_inputs["type"] == "multimodal":
decoder_mm_inputs = decoder_inputs["mm_kwargs"]
decoder_mm_positions = decoder_inputs["mm_placeholders"]
@@ -359,25 +358,19 @@
# in the input sequence.
sorted_mm_idxs = argsort_mm_positions(decoder_mm_positions)
sorted_mm_inputs = [
decoder_mm_inputs[modality][idx]
for modality, idx in sorted_mm_idxs
]
sorted_mm_positions = [
decoder_mm_positions[modality][idx]
for modality, idx in sorted_mm_idxs
]
sorted_mm_hashes = [
decoder_mm_hashes[modality][idx]
for modality, idx in sorted_mm_idxs
]
mm_features = []
for modality, idx in sorted_mm_idxs:
mm_features.append(
MultiModalFeatureSpec(
data=decoder_mm_inputs[modality][idx],
modality=modality,
identifier=decoder_mm_hashes[modality][idx],
mm_position=decoder_mm_positions[modality][idx]))
return decoder_inputs.get("prompt"), EngineCoreRequest(
request_id=request_id,
prompt_token_ids=decoder_inputs["prompt_token_ids"],
mm_kwargs=sorted_mm_inputs,
mm_hashes=sorted_mm_hashes,
mm_placeholders=sorted_mm_positions,
mm_features=mm_features,
sampling_params=sampling_params,
pooling_params=pooling_params,
eos_token_id=eos_token_id,

View File

@@ -6,10 +6,9 @@ import time
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
from vllm.multimodal.inputs import MultiModalFeatureSpec
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.utils import is_list_of
from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType,
EngineCoreRequest, FinishReason)
from vllm.v1.structured_output.request import StructuredOutputRequest
@@ -26,14 +25,12 @@ class Request:
self,
request_id: str,
prompt_token_ids: list[int],
multi_modal_kwargs: Optional[list[MultiModalKwargsItem]],
multi_modal_hashes: Optional[list[str]],
multi_modal_placeholders: Optional[list[PlaceholderRange]],
sampling_params: Optional[SamplingParams],
pooling_params: Optional[PoolingParams],
eos_token_id: Optional[int],
client_index: int = 0,
arrival_time: Optional[float] = None,
mm_features: Optional[list[MultiModalFeatureSpec]] = None,
lora_request: Optional["LoRARequest"] = None,
structured_output_request: Optional["StructuredOutputRequest"] = None,
cache_salt: Optional[str] = None,
@@ -89,16 +86,14 @@ class Request:
self.cache_salt: Optional[str] = cache_salt
# Multi-modal related
self.mm_positions = multi_modal_placeholders or []
self.mm_kwargs = multi_modal_kwargs or []
self.mm_hashes: list[str] = multi_modal_hashes or []
self.num_encoder_inputs = len(self.mm_kwargs)
self.mm_features = mm_features or []
self.num_encoder_inputs = len(self.mm_features)
self.has_encoder_inputs = self.num_encoder_inputs > 0
# Sanity check
assert len(self.mm_kwargs) == len(self.mm_positions)
if self.mm_hashes:
assert len(self.mm_kwargs) == len(self.mm_hashes)
# TODO(sfeng33): Remove these legacy fields after clearing out all
# references in scheduler and model runner
self.mm_positions = [f.mm_position for f in self.mm_features]
self.mm_kwargs = [f.data for f in self.mm_features]
self.mm_hashes = [f.identifier for f in self.mm_features]
# Read-only views
# Prevent directly appending to these lists since
@@ -126,20 +121,11 @@ class Request:
cls, request: EngineCoreRequest,
block_hasher: Optional[Callable[["Request"], list["BlockHash"]]]
) -> "Request":
if request.mm_kwargs is not None:
mm_kwargs_lst = list(request.mm_kwargs)
assert is_list_of(mm_kwargs_lst, MultiModalKwargsItem), (
"mm_kwargs was not updated in EngineCore.add_request")
else:
mm_kwargs_lst = None
return cls(
request_id=request.request_id,
client_index=request.client_index,
prompt_token_ids=request.prompt_token_ids,
multi_modal_kwargs=mm_kwargs_lst,
multi_modal_hashes=request.mm_hashes,
multi_modal_placeholders=request.mm_placeholders,
mm_features=request.mm_features,
sampling_params=request.sampling_params,
pooling_params=request.pooling_params,
eos_token_id=request.eos_token_id,