[Multimodal] Consolidate mm inputs into MultiModalFeatureSpec (#23779)
Signed-off-by: sfeng33 <4florafeng@gmail.com>
@@ -64,8 +64,6 @@ def _run_incremental_decode(tokenizer,
     request = EngineCoreRequest("",
                                 prompt_token_ids,
                                 None,
-                                None,
-                                None,
                                 params,
                                 None,
                                 None,
@@ -7,7 +7,8 @@ import pytest
 import torch

 from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
-from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
+from vllm.multimodal.inputs import (MultiModalFeatureSpec,
+                                    MultiModalKwargsItem, PlaceholderRange)
 from vllm.sampling_params import SamplingParams
 from vllm.utils import GiB_bytes, sha256, sha256_cbor_64bit
 from vllm.v1.core.kv_cache_manager import KVCacheManager
@@ -37,17 +38,20 @@ def make_request(
         mm_hashes: Optional[list[str]] = None,
         cache_salt: Optional[str] = None,
 ):
-    if mm_positions is None:
-        mm_kwargs = None
-    else:
-        mm_item = MultiModalKwargsItem.dummy("dummy_m")
-        mm_kwargs = [mm_item] * len(mm_positions)
+    mm_features = []
+    if mm_positions is not None:
+        for j, position in enumerate(mm_positions):
+            identifier = mm_hashes[j] if mm_hashes else f"hash_{j}"
+            mm_feature = MultiModalFeatureSpec(
+                data=MultiModalKwargsItem.dummy("dummy_m"),
+                mm_position=position,
+                identifier=identifier,
+                modality="image")
+            mm_features.append(mm_feature)

     return Request(request_id=request_id,
                    prompt_token_ids=prompt_token_ids,
-                   multi_modal_kwargs=mm_kwargs,
-                   multi_modal_hashes=mm_hashes,
-                   multi_modal_placeholders=mm_positions,
+                   mm_features=mm_features if mm_features else None,
                    sampling_params=SamplingParams(max_tokens=17),
                    pooling_params=None,
                    eos_token_id=100,
@@ -9,7 +9,8 @@ import pytest
 import torch

 from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved
-from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
+from vllm.multimodal.inputs import (MultiModalFeatureSpec,
+                                    MultiModalKwargsItem, PlaceholderRange)
 from vllm.sampling_params import SamplingParams
 from vllm.utils import sha256, sha256_cbor_64bit
 from vllm.v1.core.block_pool import BlockPool
@@ -32,17 +33,20 @@ def make_request(
         prompt_logprobs: Optional[int] = None,
         cache_salt: Optional[str] = None,
 ):
-    if mm_positions is None:
-        mm_kwargs = None
-    else:
-        mm_item = MultiModalKwargsItem.dummy("dummy_m")
-        mm_kwargs = [mm_item] * len(mm_positions)
+    mm_features = []
+    if mm_positions is not None:
+        for j, position in enumerate(mm_positions):
+            identifier = mm_hashes[j] if mm_hashes else f"hash_{j}"
+            mm_feature = MultiModalFeatureSpec(
+                data=MultiModalKwargsItem.dummy("dummy_m"),
+                mm_position=position,
+                identifier=identifier,
+                modality="image")
+            mm_features.append(mm_feature)

     return Request(request_id=request_id,
                    prompt_token_ids=prompt_token_ids,
-                   multi_modal_kwargs=mm_kwargs,
-                   multi_modal_hashes=mm_hashes,
-                   multi_modal_placeholders=mm_positions,
+                   mm_features=mm_features if mm_features else None,
                    sampling_params=SamplingParams(
                        max_tokens=17, prompt_logprobs=prompt_logprobs),
                    pooling_params=None,
@@ -8,7 +8,8 @@ import torch

 from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
                          SchedulerConfig, SpeculativeConfig, VllmConfig)
-from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
+from vllm.multimodal.inputs import (MultiModalFeatureSpec,
+                                    MultiModalKwargsItem, PlaceholderRange)
 from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
 from vllm.v1.core.sched.scheduler import Scheduler
@@ -1308,21 +1309,24 @@ def create_requests_with_priority(
         prompt_logprobs=prompt_logprobs)
     requests = []
     for i in range(num_requests):
+        mm_features = []
         if mm_positions is not None:
             mm_position = mm_positions[i]
-            mm_item = MultiModalKwargsItem.dummy("dummy_m")
-            mm_kwargs = [mm_item] * len(mm_position)
-        else:
-            mm_position = None
-            mm_kwargs = None
+            for j, position in enumerate(mm_position):
+                identifier = f"hash{i}_{j}"
+                mm_feature = MultiModalFeatureSpec(
+                    data=MultiModalKwargsItem.dummy("dummy_m"),
+                    mm_position=position,
+                    identifier=identifier,
+                    modality="image")
+                mm_features.append(mm_feature)
+
         request = Request(
             request_id=f"{i + starting_idx}",
             prompt_token_ids=[i + starting_idx] * num_tokens,
             sampling_params=sampling_params,
             pooling_params=None,
-            multi_modal_kwargs=mm_kwargs,
-            multi_modal_placeholders=mm_position,
-            multi_modal_hashes=None,
+            mm_features=mm_features if mm_features else None,
            eos_token_id=EOS_TOKEN_ID,
             arrival_time=arrival_times[i],
             priority=priorities[i],
@@ -1801,9 +1805,7 @@ def test_schedule_skip_tokenizer_init_structured_output_request():
     request = Request(
         request_id="0",
         prompt_token_ids=[0, 1],
-        multi_modal_kwargs=None,
-        multi_modal_hashes=None,
-        multi_modal_placeholders=None,
+        mm_features=None,
         sampling_params=sampling_params,
         pooling_params=None,
         eos_token_id=EOS_TOKEN_ID,
@@ -6,7 +6,8 @@ import torch

 from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
                          SchedulerConfig, SpeculativeConfig, VllmConfig)
-from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
+from vllm.multimodal.inputs import (MultiModalFeatureSpec,
+                                    MultiModalKwargsItem, PlaceholderRange)
 from vllm.sampling_params import SamplingParams
 from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
                                          init_none_hash)
@@ -139,19 +140,20 @@ def create_requests(
         prompt_logprobs=prompt_logprobs)
     requests = []
     for i in range(num_requests):
+        mm_features = []
         if mm_positions is not None:
             mm_position = mm_positions[i]
-            mm_item = MultiModalKwargsItem.dummy("dummy_m")
-            mm_kwargs = [mm_item] * len(mm_position)
-            # Dummy hash for each mm item should be unique
-            # since encoder cache tracks entries by hash
-            mm_hashes = [
-                "hash" + str(i) + "_" + str(j) for j in range(len(mm_position))
-            ]
-        else:
-            mm_position = None
-            mm_kwargs = None
-            mm_hashes = None
+            for j, position in enumerate(mm_position):
+                # Dummy hash for each mm item should be unique
+                # since encoder cache tracks entries by hash
+                identifier = f"hash{i}_{j}"
+                mm_feature = MultiModalFeatureSpec(
+                    data=MultiModalKwargsItem.dummy("dummy_m"),
+                    mm_position=position,
+                    identifier=identifier,
+                    modality="image")
+                mm_features.append(mm_feature)
+
         prompt_token_ids = ([0] * num_tokens if same_prompt else [i] *
                             num_tokens)
         request = Request(
@@ -159,9 +161,7 @@ def create_requests(
             prompt_token_ids=prompt_token_ids,
             sampling_params=sampling_params,
             pooling_params=None,
-            multi_modal_kwargs=mm_kwargs,
-            multi_modal_placeholders=mm_position,
-            multi_modal_hashes=mm_hashes,
+            mm_features=mm_features if mm_features else None,
             eos_token_id=EOS_TOKEN_ID,
             block_hasher=block_hasher,
         )
@@ -35,9 +35,7 @@ def make_request() -> EngineCoreRequest:
     return EngineCoreRequest(
         request_id=str(uuid.uuid4()),
         prompt_token_ids=PROMPT_TOKENS,
-        mm_kwargs=None,
-        mm_hashes=None,
-        mm_placeholders=None,
+        mm_features=None,
         sampling_params=SamplingParams(),
         pooling_params=None,
         eos_token_id=None,
@@ -52,9 +52,7 @@ def make_request(
     return EngineCoreRequest(
         request_id=str(uuid.uuid4()),
         prompt_token_ids=prompt_tokens_ids,
-        mm_kwargs=None,
-        mm_hashes=None,
-        mm_placeholders=None,
+        mm_features=None,
         sampling_params=params,
         pooling_params=None,
         eos_token_id=None,
@@ -26,16 +26,14 @@ def test_fast_inc_detok_invalid_utf8_err_case():
     prompt_token_ids = [107, 4606, 236787, 107]
     params = SamplingParams(skip_special_tokens=True)
     request = EngineCoreRequest(
-        "test",
-        prompt_token_ids,
-        None,
-        None,
-        None,
-        params,
-        None,
-        None,
-        0.0,
-        None,
+        request_id="test",
+        prompt_token_ids=prompt_token_ids,
+        mm_features=None,
+        sampling_params=params,
+        pooling_params=None,
+        eos_token_id=None,
+        arrival_time=0.0,
+        lora_request=None,
         cache_salt=None,
         data_parallel_rank=None,
     )
@@ -52,11 +52,9 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind,
     requests = [
         EngineCoreRequest(request_id=f"request-{idx}",
                           prompt_token_ids=prompt_tokens,
-                          mm_kwargs=None,
-                          mm_hashes=None,
-                          mm_placeholders=None,
+                          arrival_time=0,
+                          mm_features=None,
                           eos_token_id=None,
-                          arrival_time=0,
                           lora_request=None,
                           cache_salt=None,
                           data_parallel_rank=None,
@@ -401,11 +399,9 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
     requests = [
         EngineCoreRequest(request_id=request_id_list[idx],
                           prompt_token_ids=prompt_tokens,
-                          mm_kwargs=None,
-                          mm_hashes=None,
-                          mm_placeholders=None,
+                          arrival_time=0,
+                          mm_features=None,
                           eos_token_id=None,
-                          arrival_time=0,
                           lora_request=None,
                           cache_salt=None,
                           data_parallel_rank=None,
@@ -566,11 +562,9 @@ def test_stop_token(include_stop_str_in_output: bool,
     request = EngineCoreRequest(
         request_id=request_id,
         prompt_token_ids=prompt_tokens,
-        mm_kwargs=None,
-        mm_hashes=None,
-        mm_placeholders=None,
+        arrival_time=0,
+        mm_features=None,
         eos_token_id=eos_token_id,
-        arrival_time=0,
         lora_request=None,
         cache_salt=None,
         data_parallel_rank=None,
@@ -665,11 +659,9 @@ def test_stop_string(include_stop_str_in_output: bool,
         EngineCoreRequest(
             request_id=request_id_list[idx],
             prompt_token_ids=prompt_tokens,
-            mm_kwargs=None,
-            mm_hashes=None,
-            mm_placeholders=None,
+            arrival_time=0,
+            mm_features=None,
             eos_token_id=None,
-            arrival_time=0,
             lora_request=None,
             cache_salt=None,
             data_parallel_rank=None,
@@ -781,11 +773,9 @@ def test_iteration_stats(dummy_test_vectors):
         EngineCoreRequest(
             request_id=f"request-{idx}",
             prompt_token_ids=prompt_tokens,
-            mm_kwargs=None,
-            mm_hashes=None,
-            mm_placeholders=None,
+            arrival_time=0,
+            mm_features=None,
             eos_token_id=None,
-            arrival_time=0,
             lora_request=None,
             cache_salt=None,
             data_parallel_rank=None,
@@ -162,9 +162,7 @@ def create_request(request_id: int,
         prompt_token_ids=prompt_token_ids,
         sampling_params=sampling_params,
         pooling_params=None,
-        multi_modal_kwargs=None,
-        multi_modal_placeholders=None,
-        multi_modal_hashes=None,
+        mm_features=None,
         eos_token_id=EOS_TOKEN_ID,
         block_hasher=get_request_block_hasher(block_size, hash_fn),
     )
@@ -12,9 +12,9 @@ from vllm.logger import init_logger
 from vllm.utils import GiB_bytes, LRUCache
 from vllm.utils.jsontree import json_map_leaves, json_reduce_leaves

-from .inputs import (MultiModalFieldElem, MultiModalKwargs,
-                     MultiModalKwargsItem, MultiModalKwargsItems,
-                     NestedTensors)
+from .inputs import (MultiModalFeatureSpec, MultiModalFieldElem,
+                     MultiModalKwargs, MultiModalKwargsItem,
+                     MultiModalKwargsItems, NestedTensors)

 if TYPE_CHECKING:
     from vllm.config import ModelConfig, VllmConfig
@@ -418,6 +418,16 @@ class BaseMultiModalReceiverCache(
                                   MultiModalKwargsItem]):
     """The required interface for caches on P1."""

+    def get_and_update_features(
+        self,
+        mm_features: list["MultiModalFeatureSpec"],
+    ) -> list["MultiModalFeatureSpec"]:
+        """Update multimodal features with cached encoder outputs."""
+        for feature in mm_features:
+            feature.data = self.get_and_update_item(feature.data,
+                                                    feature.identifier)
+        return mm_features
+

 class MultiModalReceiverCache(BaseMultiModalReceiverCache):
     """
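Below is a minimal, self-contained sketch (not vLLM's implementation) of the receiver-cache behaviour that get_and_update_features relies on: cached payloads are keyed by identifier, so a feature whose data the sender omitted can be refilled from an earlier occurrence of the same hash. FeatureSpec and ToyReceiverCache are illustrative stand-ins for the real classes.

from dataclasses import dataclass
from typing import Optional


@dataclass
class FeatureSpec:  # stand-in for vllm.multimodal.inputs.MultiModalFeatureSpec
    data: Optional[str]
    modality: str
    identifier: str


class ToyReceiverCache:
    def __init__(self) -> None:
        self._cache: dict[str, str] = {}

    def get_and_update_item(self, data: Optional[str], identifier: str) -> str:
        if data is None:
            return self._cache[identifier]  # hit: restore the cached payload
        self._cache[identifier] = data      # miss: remember the payload
        return data

    def get_and_update_features(
            self, feats: list[FeatureSpec]) -> list[FeatureSpec]:
        for f in feats:
            f.data = self.get_and_update_item(f.data, f.identifier)
        return feats


cache = ToyReceiverCache()
first = FeatureSpec(data="pixel_values_0", modality="image", identifier="h0")
repeat = FeatureSpec(data=None, modality="image", identifier="h0")
cache.get_and_update_features([first])
print(cache.get_and_update_features([repeat])[0].data)  # -> "pixel_values_0"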
@@ -198,6 +198,29 @@ A dictionary containing nested tensors which have been batched via
 """


+@dataclass
+class MultiModalFeatureSpec:
+    """
+    Represents a single multimodal input with its processed data and metadata.
+
+    Used by the V1 engine to track multimodal data through processing and
+    caching. A request containing multiple multimodal items will have one
+    MultiModalFeatureSpec per item.
+    """
+
+    data: Optional["MultiModalKwargsItem"]
+    """Multimodal data for this feature"""
+
+    modality: str
+    """Based on the input, e.g., "image", "audio", "video"."""
+
+    identifier: str
+    """mm_hash or uuid for caching encoder outputs."""
+
+    mm_position: PlaceholderRange
+    """e.g., PlaceholderRange(offset=2, length=336)"""
+
+
 @dataclass
 class MultiModalFieldElem:
     """
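As a usage sketch, a request containing a single image could be described with the new dataclass as follows; the dummy kwargs item and the hash value follow the pattern used in the tests above and are placeholders rather than real model inputs.

from vllm.multimodal.inputs import (MultiModalFeatureSpec,
                                    MultiModalKwargsItem, PlaceholderRange)

# One MultiModalFeatureSpec per multimodal item: the processed kwargs, the
# modality, the cache identifier, and the placeholder span the item occupies
# in the prompt token sequence.
mm_features = [
    MultiModalFeatureSpec(
        data=MultiModalKwargsItem.dummy("image"),  # placeholder payload
        modality="image",
        identifier="mm_hash_0",                    # normally the mm item hash
        mm_position=PlaceholderRange(offset=2, length=336),
    ),
]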
@@ -3,14 +3,13 @@

 import enum
 import time
-from collections.abc import Sequence
 from typing import Any, Optional, Union

 import msgspec
 import torch

 from vllm.lora.request import LoRARequest
-from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
+from vllm.multimodal.inputs import MultiModalFeatureSpec
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams
 from vllm.v1.metrics.stats import SchedulerStats
@@ -48,9 +47,7 @@ class EngineCoreRequest(

     request_id: str
     prompt_token_ids: list[int]
-    mm_kwargs: Optional[Sequence[Optional[MultiModalKwargsItem]]]
-    mm_hashes: Optional[list[str]]
-    mm_placeholders: Optional[list[PlaceholderRange]]
+    mm_features: Optional[list[MultiModalFeatureSpec]]
     sampling_params: Optional[SamplingParams]
     pooling_params: Optional[PoolingParams]
     eos_token_id: Optional[int]
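Conceptually, the consolidation replaces the three parallel per-item fields (mm_kwargs, mm_hashes, mm_placeholders) with one list of per-item specs. A rough sketch of that folding, assuming equal-length legacy lists; zip_legacy_mm_fields is a hypothetical helper for illustration, not part of vLLM.

from typing import Optional

from vllm.multimodal.inputs import (MultiModalFeatureSpec,
                                    MultiModalKwargsItem, PlaceholderRange)


def zip_legacy_mm_fields(
    mm_kwargs: list[Optional[MultiModalKwargsItem]],
    mm_hashes: list[str],
    mm_placeholders: list[PlaceholderRange],
    modality: str = "image",
) -> list[MultiModalFeatureSpec]:
    """Hypothetical helper: fold the three legacy lists into mm_features."""
    assert len(mm_kwargs) == len(mm_hashes) == len(mm_placeholders)
    return [
        MultiModalFeatureSpec(data=data,
                              modality=modality,
                              identifier=mm_hash,
                              mm_position=placeholder)
        for data, mm_hash, placeholder in zip(mm_kwargs, mm_hashes,
                                              mm_placeholders)
    ]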
@@ -434,15 +434,13 @@ class EngineCore:
         This function could be directly used in input processing thread to allow
         request initialization running in parallel with Model forward
         """
-        if request.mm_hashes is not None:
-            assert request.mm_kwargs is not None
-
-            # Note on thread safety: no race condition.
-            # `mm_receiver_cache` is reset at the end of LLMEngine init,
-            # and will only accessed in the input processing thread afterwards.
-            if self.mm_receiver_cache is not None:
-                request.mm_kwargs = self.mm_receiver_cache.get_and_update(
-                    request.mm_kwargs, request.mm_hashes)
+        # Note on thread safety: no race condition.
+        # `mm_receiver_cache` is reset at the end of LLMEngine init,
+        # and will only accessed in the input processing thread afterwards.
+        if self.mm_receiver_cache is not None and request.mm_features:
+            request.mm_features = (
+                self.mm_receiver_cache.get_and_update_features(
+                    request.mm_features))

         req = Request.from_engine_core_request(request,
                                                self.request_block_hasher)
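The new guard can be read as: only touch the receiver cache when it exists and the request actually carries multimodal features, so text-only requests skip the cache path entirely and no assertion over parallel lists is needed. A tiny sketch of that control flow, with receiver_cache as a stand-in object rather than vLLM's cache class.

def maybe_update_features(receiver_cache, mm_features):
    # Mirrors the guard above: None or an empty feature list is left untouched.
    if receiver_cache is not None and mm_features:
        return receiver_cache.get_and_update_features(mm_features)
    return mm_features


print(maybe_update_features(None, []))    # [] -> unchanged, no cache access
print(maybe_update_features(None, None))  # None -> unchanged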
@@ -12,7 +12,7 @@ from vllm.inputs.preprocess import InputPreprocessor
 from vllm.lora.request import LoRARequest
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
 from vllm.multimodal.cache import processor_cache_from_config
-from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
+from vllm.multimodal.inputs import MultiModalFeatureSpec
 from vllm.multimodal.processing import EncDecMultiModalProcessor
 from vllm.multimodal.utils import argsort_mm_positions
 from vllm.pooling_params import PoolingParams
@@ -346,9 +346,8 @@ class Processor:
             pooling_params = params.clone()

         # Multimodal related.
-        sorted_mm_inputs: Optional[list[Optional[MultiModalKwargsItem]]] = None
-        sorted_mm_positions: Optional[list[PlaceholderRange]] = None
-        sorted_mm_hashes: Optional[list[str]] = None
+        mm_features: Optional[list[MultiModalFeatureSpec]] = None

         if decoder_inputs["type"] == "multimodal":
             decoder_mm_inputs = decoder_inputs["mm_kwargs"]
             decoder_mm_positions = decoder_inputs["mm_placeholders"]
@@ -359,25 +358,19 @@ class Processor:
             # in the input sequence.
             sorted_mm_idxs = argsort_mm_positions(decoder_mm_positions)

-            sorted_mm_inputs = [
-                decoder_mm_inputs[modality][idx]
-                for modality, idx in sorted_mm_idxs
-            ]
-            sorted_mm_positions = [
-                decoder_mm_positions[modality][idx]
-                for modality, idx in sorted_mm_idxs
-            ]
-            sorted_mm_hashes = [
-                decoder_mm_hashes[modality][idx]
-                for modality, idx in sorted_mm_idxs
-            ]
+            mm_features = []
+            for modality, idx in sorted_mm_idxs:
+                mm_features.append(
+                    MultiModalFeatureSpec(
+                        data=decoder_mm_inputs[modality][idx],
+                        modality=modality,
+                        identifier=decoder_mm_hashes[modality][idx],
+                        mm_position=decoder_mm_positions[modality][idx]))

         return decoder_inputs.get("prompt"), EngineCoreRequest(
             request_id=request_id,
             prompt_token_ids=decoder_inputs["prompt_token_ids"],
-            mm_kwargs=sorted_mm_inputs,
-            mm_hashes=sorted_mm_hashes,
-            mm_placeholders=sorted_mm_positions,
+            mm_features=mm_features,
             sampling_params=sampling_params,
             pooling_params=pooling_params,
             eos_token_id=eos_token_id,
@@ -6,10 +6,9 @@ import time
 from functools import partial
 from typing import TYPE_CHECKING, Any, Callable, Optional, Union

-from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
+from vllm.multimodal.inputs import MultiModalFeatureSpec
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams
-from vllm.utils import is_list_of
 from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType,
                             EngineCoreRequest, FinishReason)
 from vllm.v1.structured_output.request import StructuredOutputRequest
@@ -26,14 +25,12 @@ class Request:
         self,
         request_id: str,
         prompt_token_ids: list[int],
-        multi_modal_kwargs: Optional[list[MultiModalKwargsItem]],
-        multi_modal_hashes: Optional[list[str]],
-        multi_modal_placeholders: Optional[list[PlaceholderRange]],
         sampling_params: Optional[SamplingParams],
         pooling_params: Optional[PoolingParams],
         eos_token_id: Optional[int],
         client_index: int = 0,
         arrival_time: Optional[float] = None,
+        mm_features: Optional[list[MultiModalFeatureSpec]] = None,
         lora_request: Optional["LoRARequest"] = None,
         structured_output_request: Optional["StructuredOutputRequest"] = None,
         cache_salt: Optional[str] = None,
@@ -89,16 +86,14 @@ class Request:
         self.cache_salt: Optional[str] = cache_salt

         # Multi-modal related
-        self.mm_positions = multi_modal_placeholders or []
-        self.mm_kwargs = multi_modal_kwargs or []
-        self.mm_hashes: list[str] = multi_modal_hashes or []
-        self.num_encoder_inputs = len(self.mm_kwargs)
+        self.mm_features = mm_features or []
+        self.num_encoder_inputs = len(self.mm_features)
         self.has_encoder_inputs = self.num_encoder_inputs > 0
-
-        # Sanity check
-        assert len(self.mm_kwargs) == len(self.mm_positions)
-        if self.mm_hashes:
-            assert len(self.mm_kwargs) == len(self.mm_hashes)
+        # TODO(sfeng33): Remove these legacy fields after clearing out all
+        # references in scheduler and model runner
+        self.mm_positions = [f.mm_position for f in self.mm_features]
+        self.mm_kwargs = [f.data for f in self.mm_features]
+        self.mm_hashes = [f.identifier for f in self.mm_features]

         # Read-only views
         # Prevent directly appending to these lists since
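A toy illustration of the legacy views kept above: the three parallel lists are now plain projections of mm_features, so they cannot drift out of sync with each other. Spec is a stand-in for MultiModalFeatureSpec and the tuple stands in for PlaceholderRange.

from dataclasses import dataclass


@dataclass
class Spec:  # stand-in for MultiModalFeatureSpec
    data: str
    modality: str
    identifier: str
    mm_position: tuple[int, int]  # (offset, length)


features = [
    Spec("img_tensors_0", "image", "hash0_0", (2, 336)),
    Spec("img_tensors_1", "image", "hash0_1", (400, 336)),
]
# The legacy fields are derived views over the single source of truth.
mm_positions = [f.mm_position for f in features]
mm_kwargs = [f.data for f in features]
mm_hashes = [f.identifier for f in features]
assert len(mm_positions) == len(mm_kwargs) == len(mm_hashes) == len(features)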
@@ -126,20 +121,11 @@ class Request:
         cls, request: EngineCoreRequest,
         block_hasher: Optional[Callable[["Request"], list["BlockHash"]]]
     ) -> "Request":
-        if request.mm_kwargs is not None:
-            mm_kwargs_lst = list(request.mm_kwargs)
-            assert is_list_of(mm_kwargs_lst, MultiModalKwargsItem), (
-                "mm_kwargs was not updated in EngineCore.add_request")
-        else:
-            mm_kwargs_lst = None
-
         return cls(
             request_id=request.request_id,
             client_index=request.client_index,
             prompt_token_ids=request.prompt_token_ids,
-            multi_modal_kwargs=mm_kwargs_lst,
-            multi_modal_hashes=request.mm_hashes,
-            multi_modal_placeholders=request.mm_placeholders,
+            mm_features=request.mm_features,
             sampling_params=request.sampling_params,
             pooling_params=request.pooling_params,
             eos_token_id=request.eos_token_id,