[Core][Multimodal] Track encode cache entries by mm_hash and enable embedding sharing between requests (#22711)

Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: knlnguyen1802 <knlnguyen1802@gmail.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
Chenguang Zheng
2025-08-25 15:41:17 +08:00
committed by GitHub
parent 712d0f88d8
commit d765cf01fe
12 changed files with 365 additions and 154 deletions

View File

@ -0,0 +1,144 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.v1.core.encoder_cache_manager import EncoderCacheManager
# ------------------ Mock Classes ------------------ #
class MockRequest:
def __init__(self, request_id, mm_hashes, token_counts):
self.request_id = request_id
self.mm_hashes = mm_hashes
self._token_counts = token_counts
def get_num_encoder_tokens(self, input_id: int) -> int:
return self._token_counts[input_id]
# ------------------ Unit Tests ------------------ #
def test_basic_allocate_and_reuse():
cache = EncoderCacheManager(cache_size=10)
req = MockRequest("r1", ["imgA"], [4])
assert not cache.check_and_update_cache(req, 0)
assert cache.try_allocate(req, 0, int(1e9))
cache.allocate(req, 0)
assert cache.check_and_update_cache(req, 0)
assert "r1" in cache.cached["imgA"]
assert cache.num_free_slots == 6
# Free twice to bring refcount to 0.
cache.free_encoder_input(req, 0)
cache.free_encoder_input(req, 0)
assert not cache.cached["imgA"]
assert "imgA" in cache.freeable
assert cache.num_freeable_slots == 10
assert cache.num_free_slots == 6
def test_freeing_decreases_refcount_and_moves_to_freeable():
manager = EncoderCacheManager(cache_size=10)
req = MockRequest("req2", ["img3"], [5])
assert manager.try_allocate(req, 0, int(1e9))
manager.allocate(req, 0)
assert len(manager.cached["img3"]) == 1
manager.free_encoder_input(req, 0)
assert not manager.cached["img3"]
assert "img3" in manager.freeable
assert manager.num_freeable_slots == 10
def test_free_request_frees_all_inputs():
manager = EncoderCacheManager(cache_size=10)
req = MockRequest("req3", ["a", "b"], [2, 3])
assert manager.try_allocate(req, 0, int(1e9))
manager.allocate(req, 0)
assert manager.try_allocate(req, 1, int(1e9))
manager.allocate(req, 1)
assert len(manager.cached["a"]) == 1
assert len(manager.cached["b"]) == 1
manager.free(req)
assert not manager.cached["a"]
assert not manager.cached["b"]
assert "a" in manager.freeable
assert "b" in manager.freeable
assert manager.num_freeable_slots == 10
def test_eviction_when_cache_is_full():
manager = EncoderCacheManager(cache_size=10)
req1 = MockRequest("req1", ["x"], [6])
req2 = MockRequest("req2", ["y"], [5])
assert manager.try_allocate(req1, 0, int(1e9))
manager.allocate(req1, 0)
manager.free_encoder_input(req1, 0)
assert manager.try_allocate(req2, 0, int(1e9))
manager.allocate(req2, 0)
# 'x' should have been evicted.
assert "x" not in manager.cached
assert "x" in manager.get_freed_mm_hashes()
def test_get_cached_input_ids():
manager = EncoderCacheManager(cache_size=10)
req = MockRequest("reqX", ["m", "n", "o"], [2, 4, 3])
assert manager.try_allocate(req, 0, int(1e9))
manager.allocate(req, 0)
assert manager.try_allocate(req, 2, int(1e9))
manager.allocate(req, 2)
cached_ids = manager.get_cached_input_ids(req)
assert cached_ids == {0, 2}
def test_has_cache_restores_from_freeable():
manager = EncoderCacheManager(cache_size=10)
req = MockRequest("reqY", ["imgZ"], [4])
assert manager.try_allocate(req, 0, int(1e9))
manager.allocate(req, 0)
manager.free_encoder_input(req, 0)
# Should restore from freeable.
assert manager.check_and_update_cache(req, 0)
assert len(manager.cached["imgZ"]) == 1
assert "imgZ" not in manager.freeable
assert manager.num_freeable_slots == 6
def test_get_freed_mm_hashes_clears_freed_list():
manager = EncoderCacheManager(cache_size=10)
req1 = MockRequest("reqA", ["a"], [5])
req2 = MockRequest("reqB", ["b"], [6])
assert manager.try_allocate(req1, 0, int(1e9))
manager.allocate(req1, 0)
manager.free_encoder_input(req1, 0)
# Should trigger eviction of 'a'.
assert manager.try_allocate(req2, 0, int(1e9))
manager.allocate(req2, 0)
freed = manager.get_freed_mm_hashes()
assert "a" in freed
assert manager.get_freed_mm_hashes() == []

View File

@ -338,7 +338,7 @@ def test_stop_via_update_from_output():
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None)
@ -391,7 +391,7 @@ def test_stop_via_update_from_output():
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
@ -443,7 +443,7 @@ def test_stop_via_update_from_output():
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
@ -490,7 +490,7 @@ def test_stop_via_update_from_output():
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None)

View File

@ -143,7 +143,11 @@ def create_requests(
mm_position = mm_positions[i]
mm_item = MultiModalKwargsItem.dummy("dummy_m")
mm_kwargs = [mm_item] * len(mm_position)
mm_hashes = ["hash"] * len(mm_position)
# Dummy hash for each mm item should be unique
# since encoder cache tracks entries by hash
mm_hashes = [
"hash" + str(i) + "_" + str(j) for j in range(len(mm_position))
]
else:
mm_position = None
mm_kwargs = None

View File

@ -85,7 +85,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
@ -164,7 +164,7 @@ def test_update_states_request_finished(model_runner):
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
finished_req_ids={req_id},
free_encoder_input_ids=[],
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
@ -194,7 +194,7 @@ def test_update_states_request_resumed(model_runner):
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
@ -221,7 +221,7 @@ def test_update_states_request_resumed(model_runner):
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
@ -252,7 +252,7 @@ def test_update_states_no_changes(model_runner):
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
@ -287,7 +287,7 @@ def test_update_states_request_unscheduled(model_runner):
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None,
)

View File

@ -205,6 +205,7 @@ def _construct_cached_request_state(req_id_suffix: int):
pooling_params=None,
mm_kwargs=[],
mm_positions=[],
mm_hashes=[],
block_ids=([], ),
generator=None,
num_computed_tokens=len(output_token_ids),

View File

@ -141,7 +141,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
@ -207,7 +207,7 @@ def test_update_states_request_finished(model_runner, dist_init):
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
finished_req_ids={req_id},
free_encoder_input_ids=[],
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
@ -239,7 +239,7 @@ def test_update_states_request_resumed(model_runner, dist_init):
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
@ -266,7 +266,7 @@ def test_update_states_request_resumed(model_runner, dist_init):
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
@ -347,7 +347,7 @@ def test_update_states_no_changes(model_runner, dist_init):
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
@ -384,7 +384,7 @@ def test_update_states_request_unscheduled(model_runner, dist_init):
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None,
)

View File

@ -1,5 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import OrderedDict
from collections.abc import Mapping
from typing import TYPE_CHECKING
@ -31,34 +33,52 @@ class EncoderCacheManager:
within requests, allowing for fine-grained memory management and enabling
chunked processing of multimodal inputs.
Note that no caching is shared between requests at this time. If the same
input is used across multiple requests, it will be reprocessed for each
request.
The cache shares embeddings of identical multimodal data items
(identified by their hash value) across different requests. Eviction
takes place at allocation time when there is no free space for new
embeddings; the oldest cached embeddings that are not referenced by
any request are evicted first.
Args:
cache_size: Limit the size of the cache, measured by the number of
tokens from the input sequence.
Attributes:
cache_size: Total cache capacity in encoder tokens
num_free_slots: Current available cache capacity in encoder tokens
cached: Mapping from request_id to set of cached input_ids for that
request
freed: List of (request_id, input_id) pairs that were recently freed.
This is cleared after every call to get_freed_ids().
cache_size: Total cache capacity in encoder tokens.
num_free_slots: Current available cache capacity in encoder tokens.
num_freeable_slots: Capacity, in encoder tokens, that is either
already free or can be reclaimed immediately by evicting entries
with zero references.
cached: Mapping from mm_hash to a set of request IDs that currently
reference the cached entry. If the set is empty, the entry exists
but is not referenced by any request and is eligible for
reclamation.
freeable: Ordered mapping from mm_hash to num_tokens for entries
that are no longer referenced by any running request and can be
freed to make space when needed.
freed: List of mm_hash strings that were actually evicted since the
last call to get_freed_mm_hashes(). This list is cleared on return.
"""
def __init__(self, cache_size: int):
self.cache_size = cache_size
self.num_free_slots = cache_size
# req_id -> cached input ids
self.cached: dict[str, set[int]] = {}
# list of [req_id, input_id]
self.freed: list[tuple[str, int]] = []
self.num_freeable_slots = cache_size
def has_cache(self, request: Request, input_id: int) -> bool:
# mm_hash of mm_data => ids of requests that reference the mm_data
self.cached: dict[str, set[str]] = {}
# mm_hash of mm_data => num_encoder_tokens of the mm_data
self.freeable: OrderedDict[str, int] = OrderedDict()
self.freed: list[str] = []
def check_and_update_cache(self, request: Request, input_id: int) -> bool:
"""Check if encoder output for a specific multimodal input is cached.
If the encoder output is cached, update `cached` to add the request id
to the set of request ids that reference the cached encoder output.
If the encoder output was previously not referenced by any request,
update `freeable` and `num_freeable_slots` accordingly.
Args:
request: The request containing the multimodal input
input_id: Index of the multimodal input within the request
@ -66,103 +86,151 @@ class EncoderCacheManager:
Returns:
True if the encoder output for this input is already cached
"""
req_id = request.request_id
return req_id in self.cached and input_id in self.cached[req_id]
mm_hash = request.mm_hashes[input_id]
# Not cached at all
if mm_hash not in self.cached:
return False
def can_allocate(self, request: Request, input_id: int) -> bool:
"""Check if there's sufficient cache space for a multimodal input.
# Cached but currently not referenced by any request
if not self.cached[mm_hash]:
num_tokens = self.freeable.pop(mm_hash)
self.num_freeable_slots -= num_tokens
self.cached[mm_hash].add(request.request_id)
return True
def try_allocate(self, request: Request, input_id: int,
encoder_budget: int) -> bool:
"""Check if there's sufficient cache space for a multimodal input.
If there is, return True and update EncoderCacheManager state.
If there is not enough free space in `num_free_slots` but there is
enough reclaimable space in `num_freeable_slots`, entries will be
evicted from `freeable` (their mm_hash appended to `freed`) until
enough space is available, and then this method returns True.
Older entries are evicted first.
Returns False if the requested number of tokens exceeds the encoder
compute budget, or exceeds the free and reclaimable capacities
combined.
Args:
request: The request containing the multimodal input
input_id: Index of the multimodal input within the request
request: The request containing the multimodal input.
input_id: Index of the multimodal input within the request.
encoder_budget: Remaining encoder compute budget, in tokens, for the
current scheduling step.
Returns:
True if there's enough free cache space to store the encoder output
for this multimodal input
True if there's enough capacity to hold the encoder output for this
input (possibly after reclaiming `freeable` entries); otherwise
False.
Note: This method does not allocate physical memory for the encoder
output; it only updates the state of the EncoderCacheManager.
"""
num_tokens = request.get_num_encoder_tokens(input_id)
return num_tokens <= self.num_free_slots
# Not enough compute budget
if num_tokens > encoder_budget:
return False
# Enough free slots
if num_tokens <= self.num_free_slots:
self.num_free_slots -= num_tokens
self.num_freeable_slots -= num_tokens
return True
# Not enough reclaimable slots
if num_tokens > self.num_freeable_slots:
return False
# Not enough free slots but enough reclaimable slots
# NOTE: Eviction takes place here, but physical memory is not freed
# until model runner is notified by the scheduler output.
while num_tokens > self.num_free_slots:
mm_hash, num_free_token = self.freeable.popitem(last=False)
del self.cached[mm_hash]
self.freed.append(mm_hash)
self.num_free_slots += num_free_token
self.num_free_slots -= num_tokens
self.num_freeable_slots -= num_tokens
return True
def allocate(self, request: Request, input_id: int) -> None:
"""Allocate cache space for a multimodal input's encoder output.
This method reserves cache space for storing the encoder output of
the specified multimodal input. The actual encoder output storage
happens in the model runner, but this method ensures the cache
manager tracks the allocation.
Args:
request: The request containing the multimodal input
input_id: Index of the multimodal input within the request
This reserves cache space for storing the encoder output of the
specified multimodal input. The actual encoder output storage happens in
the model runner; this method updates the manager's bookkeeping.
Note:
This method assumes can_allocate() returned True for the same
request and input_id. It will reduce available cache space.
This method assumes try_allocate() returned True for the same input.
"""
req_id = request.request_id
if req_id not in self.cached:
self.cached[req_id] = set()
self.cached[req_id].add(input_id)
self.num_free_slots -= request.get_num_encoder_tokens(input_id)
# The cache space accounting for this multimodal input has already been
# updated by try_allocate() and must remain non-negative.
assert self.num_free_slots >= 0
assert self.num_freeable_slots >= 0
mm_hash = request.mm_hashes[input_id]
request_id = request.request_id
if mm_hash not in self.cached:
self.cached[mm_hash] = set()
self.cached[mm_hash].add(request_id)
def get_cached_input_ids(self, request: Request) -> set[int]:
"""Get all cached multimodal input IDs for a request.
Args:
request: The request to query
Returns:
Set of input_ids that have cached encoder outputs for this request.
Returns empty set if no inputs are cached for this request.
Returns the set of input IDs whose `mm_hash` exists in the cache map.
This includes entries that are currently unreferenced (and thus present
in `freeable`); for such entries, freeing for this request will be a
no-op.
"""
return self.cached.get(request.request_id, set())
return {
input_id
for input_id in range(len(request.mm_hashes))
if request.mm_hashes[input_id] in self.cached
}
def free_encoder_input(self, request: Request, input_id: int) -> None:
"""Free cache space for a single multimodal input's encoder output.
"""Free the request's reference to the encoder input (`mm_data`)
This method is called when:
- The encoder output has been fully consumed by the decoder and is
no longer needed (e.g., in vision-language models after image
tokens are processed)
- A request is being cancelled or aborted
When the reference set for the corresponding `mm_hash` becomes empty,
the entry is appended to `freeable` and `num_freeable_slots` is
increased by the number of encoder tokens for that input.
Args:
request: The request containing the multimodal input
input_id: Index of the multimodal input to free from cache
The entry is NOT physically freed until capacity is needed (e.g., by
`try_allocate`).
"""
req_id = request.request_id
if req_id not in self.cached:
mm_hash = request.mm_hashes[input_id]
# The mm_hash is not in the cache or its request set is already empty
if not self.cached.get(mm_hash, None):
return
self.cached[req_id].discard(input_id)
if len(self.cached[req_id]) == 0:
del self.cached[req_id]
self.num_free_slots += request.get_num_encoder_tokens(input_id)
self.freed.append((req_id, input_id))
self.cached[mm_hash].discard(req_id)
if not self.cached[mm_hash]:
num_tokens = request.get_num_encoder_tokens(input_id)
self.freeable[mm_hash] = num_tokens
self.num_freeable_slots += num_tokens
def free(self, request: Request) -> None:
"""Free all cached encoder outputs for a request.
"""Free all encoder input cache reference held by *request*.
This method is typically called when a request is finished, cancelled,
or aborted, and all its encoder outputs should be freed from cache.
For each cached input ID, `free_encoder_input` is invoked. The data
stays in memory until eviction is triggered by a future allocation
attempt via `try_allocate`.
Args:
request: The request whose encoder outputs should be freed
Typically called when a request is finished, cancelled, or aborted.
"""
input_ids = self.get_cached_input_ids(request).copy()
for input_id in input_ids:
self.free_encoder_input(request, input_id)
def get_freed_ids(self) -> list[tuple[str, int]]:
def get_freed_mm_hashes(self) -> list[str]:
"""Get and clear the list of recently freed encoder cache entries.
This method returns all encoder cache entries that were freed since
the last call to this method. It's used by the scheduler to notify
workers about which encoder outputs can be removed from their caches.
Returns:
List of (request_id, input_id) tuples that were freed since the
last call. The internal freed list is cleared after this call.
List of mm_hash strings that were actually evicted since the last
call. The scheduler uses this list to notify workers which encoder
outputs can be removed from their caches. The internal list is
cleared after this call.
"""
freed = self.freed
self.freed = []
@ -177,16 +245,11 @@ def compute_encoder_budget(
"""Compute the encoder cache budget based on the model and scheduler
configurations.
Args:
model_config: Model configuration.
scheduler_config: Scheduler configuration.
mm_registry: Provides information about the token cost.
Returns:
- Compute budget for encoder execution, in unit of number of tokens
in the input sequence.
- Space budget for encoder cache size, in unit of number of tokens
in the input sequence.
- Compute budget for encoder execution, measured in number of tokens
from the input sequence.
- Space budget for encoder cache size, measured in number of tokens
from the input sequence.
"""
if mm_registry.supports_multimodal_inputs(model_config):
max_tokens_by_modality = mm_registry \
@ -231,10 +294,10 @@ def compute_mm_encoder_budget(
non-text modality.
Returns:
- Compute budget for encoder execution, in unit of number of tokens
in the input sequence.
- Space budget for encoder cache size, in unit of number of tokens
in the input sequence.
- Compute budget for encoder execution, measured in number of tokens
from the input sequence.
- Space budget for encoder cache size, measured in number of tokens
from the input sequence.
"""
if not max_tokens_by_modality:

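To make the sharing semantics above concrete, here is a minimal usage sketch of the new lifecycle, mirroring the unit tests added in this commit. It relies only on the public methods shown in this file (check_and_update_cache, try_allocate, allocate, free, get_freed_mm_hashes); the SimpleNamespace request object is an illustrative stand-in for vllm.v1.request.Request, not part of vLLM.

from types import SimpleNamespace

from vllm.v1.core.encoder_cache_manager import EncoderCacheManager

manager = EncoderCacheManager(cache_size=8)

def make_request(request_id, mm_hashes, token_counts):
    # Stand-in exposing only the attributes EncoderCacheManager touches.
    return SimpleNamespace(
        request_id=request_id,
        mm_hashes=mm_hashes,
        get_num_encoder_tokens=lambda input_id: token_counts[input_id],
    )

req_a = make_request("req-a", ["img-0"], [4])
req_b = make_request("req-b", ["img-0"], [4])  # same image, different request

# First request: cache miss, so reserve space and record the allocation.
assert not manager.check_and_update_cache(req_a, 0)
assert manager.try_allocate(req_a, 0, encoder_budget=2048)
manager.allocate(req_a, 0)

# Second request referencing the same mm_hash reuses the cached embedding.
assert manager.check_and_update_cache(req_b, 0)

# Dropping both references only marks the entry as freeable; it stays in
# memory and would be reported via get_freed_mm_hashes() only after a later
# try_allocate() evicts it to make room.
manager.free(req_a)
manager.free(req_b)
assert manager.get_freed_mm_hashes() == []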
View File

@ -143,9 +143,9 @@ class SchedulerOutput:
# steps. This is used to notify the workers about the finished requests
# so that they can free the cached states for those requests.
finished_req_ids: set[str]
# list of (req_id, encoder_input_index) tuples.
# Used to free the encoder cache.
free_encoder_input_ids: list[tuple[str, int]]
# list of mm_hash strings associated with the encoder outputs to be
# freed from the encoder cache.
free_encoder_mm_hashes: list[str]
# Dict of request ids to their index within the batch
# for filling the next token bitmask

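For context on how workers consume this field, here is a stand-alone sketch of the scheduler-to-worker contract, consistent with the model runner diffs below: the worker's encoder cache is keyed by mm_hash, so freeing is a single pop per hash. The MiniSchedulerOutput dataclass is a hypothetical stand-in for SchedulerOutput, used only to keep the example self-contained.

from dataclasses import dataclass, field

import torch

@dataclass
class MiniSchedulerOutput:
    # Only the field discussed here; the real SchedulerOutput has many more.
    free_encoder_mm_hashes: list[str] = field(default_factory=list)

# mm_hash -> encoder_output, mirroring the runner-side cache layout.
encoder_cache: dict[str, torch.Tensor] = {"imgA": torch.zeros(4, 16)}

out = MiniSchedulerOutput(free_encoder_mm_hashes=["imgA"])
for mm_hash in out.free_encoder_mm_hashes:
    encoder_cache.pop(mm_hash, None)

assert "imgA" not in encoder_cache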
View File

@ -252,6 +252,7 @@ class Scheduler(SchedulerInterface):
preempted_req = self.running.pop()
self.kv_cache_manager.free(preempted_req)
self.encoder_cache_manager.free(preempted_req)
preempted_req.status = RequestStatus.PREEMPTED
preempted_req.num_computed_tokens = 0
if self.log_stats:
@ -550,7 +551,8 @@ class Scheduler(SchedulerInterface):
# It contains the request IDs that are finished in between
# the previous and the current steps.
finished_req_ids=self.finished_req_ids,
free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(),
free_encoder_mm_hashes=self.encoder_cache_manager.
get_freed_mm_hashes(),
structured_output_request_ids=structured_output_request_ids,
grammar_bitmask=grammar_bitmask,
)
@ -698,7 +700,7 @@ class Scheduler(SchedulerInterface):
# in the decoder's KV cache.
continue
if self.encoder_cache_manager.has_cache(request, i):
if self.encoder_cache_manager.check_and_update_cache(request, i):
# The encoder input is already computed and cached.
continue
@ -712,8 +714,8 @@ class Scheduler(SchedulerInterface):
num_new_tokens = start_pos - num_computed_tokens
break
if (not self.encoder_cache_manager.can_allocate(request, i)
or num_encoder_tokens > encoder_budget):
if not self.encoder_cache_manager.try_allocate(
request, i, encoder_budget):
# The encoder cache is full or the encoder budget is exhausted.
# NOTE(woosuk): We assume that the encoder input tokens should
# be processed altogether, as the encoder usually uses

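Condensed for readability, the per-input decision path in the hunk above now behaves roughly as sketched below. This is an illustrative paraphrase, not the actual scheduler code; encoder_budget and the input index i are assumed to come from the surrounding scheduling loop.

def schedule_encoder_input(encoder_cache_manager, request, i, encoder_budget):
    # Reuse an already cached embedding (possibly restoring it from the
    # freeable list) and register this request as a reference holder.
    if encoder_cache_manager.check_and_update_cache(request, i):
        return "cached"
    # Otherwise try to reserve space; try_allocate() now also checks the
    # compute budget and may evict unreferenced entries to make room.
    if not encoder_cache_manager.try_allocate(request, i, encoder_budget):
        return "skipped"  # cache full or encoder budget exhausted
    encoder_cache_manager.allocate(request, i)
    return "scheduled"

Folding the budget check into try_allocate() keeps eviction and budget accounting in one place, which matches the removal of the separate num_encoder_tokens > encoder_budget check in the hunk above.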
View File

@ -33,6 +33,7 @@ class CachedRequestState:
prompt_token_ids: list[int]
mm_kwargs: list[MultiModalKwargsItem]
mm_positions: list[PlaceholderRange]
mm_hashes: list[str]
sampling_params: Optional[SamplingParams]
pooling_params: Optional[PoolingParams]
generator: Optional[torch.Generator]

View File

@ -176,8 +176,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.attn_groups: list[list[AttentionGroup]] = []
# self.kv_cache_config: KVCacheConfig
# req_id -> (input_id -> encoder_output)
self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
# mm_hash -> encoder_output
self.encoder_cache: dict[str, torch.Tensor] = {}
self.use_aux_hidden_state_outputs = False
# Set up speculative decoding.
@ -436,7 +436,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Remove finished requests from the cached states.
for req_id in scheduler_output.finished_req_ids:
self.requests.pop(req_id, None)
self.encoder_cache.pop(req_id, None)
# Remove the finished requests from the persistent batch.
# NOTE(woosuk): There could be an edge case where finished_req_ids and
# scheduled_req_ids overlap. This happens when a request is aborted and
@ -447,12 +446,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.input_batch.remove_request(req_id)
# Free the cached encoder outputs.
for req_id, input_id in scheduler_output.free_encoder_input_ids:
encoder_outputs = self.encoder_cache.get(req_id)
if encoder_outputs is not None:
encoder_outputs.pop(input_id, None)
if not encoder_outputs:
self.encoder_cache.pop(req_id, None)
for mm_hash in scheduler_output.free_encoder_mm_hashes:
self.encoder_cache.pop(mm_hash, None)
# Remove the unscheduled requests from the persistent batch.
# NOTE(woosuk): The unscheduled requests are either preempted requests
@ -496,6 +491,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
prompt_token_ids=new_req_data.prompt_token_ids,
mm_kwargs=new_req_data.mm_kwargs,
mm_positions=new_req_data.mm_positions,
mm_hashes=new_req_data.mm_hashes,
sampling_params=sampling_params,
pooling_params=pooling_params,
generator=generator,
@ -1161,17 +1157,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
if not scheduled_encoder_inputs:
return
# Batch the multi-modal inputs.
mm_kwargs = list[MultiModalKwargsItem]()
req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
# list of tuple (mm_hash, position_info)
mm_hashes_pos = list[tuple[str, PlaceholderRange]]()
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
req_state = self.requests[req_id]
for mm_input_id in encoder_input_ids:
mm_hash = req_state.mm_hashes[mm_input_id]
mm_kwargs.append(req_state.mm_kwargs[mm_input_id])
req_ids_pos.append(
(req_id, mm_input_id, req_state.mm_positions[mm_input_id]))
mm_hashes_pos.append(
(mm_hash, req_state.mm_positions[mm_input_id]))
# Batch mm inputs as much as we can: if a request in the batch has
# multiple modalities or a different modality than the previous one,
@ -1204,15 +1201,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
for output in curr_group_outputs:
encoder_outputs.append(output)
# Cache the encoder outputs.
for (req_id, input_id, pos_info), output in zip(
req_ids_pos,
encoder_outputs,
):
if req_id not in self.encoder_cache:
self.encoder_cache[req_id] = {}
self.encoder_cache[req_id][input_id] = scatter_mm_placeholders(
# Cache the encoder outputs by mm_hash
for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs):
self.encoder_cache[mm_hash] = scatter_mm_placeholders(
output,
is_embed=pos_info.is_embed,
)
@ -1230,6 +1221,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_computed_tokens = \
req_state.num_computed_tokens + shift_computed_tokens
mm_positions = req_state.mm_positions
mm_hashes = req_state.mm_hashes
for i, pos_info in enumerate(mm_positions):
start_pos = pos_info.offset
num_encoder_tokens = pos_info.length
@ -1249,11 +1241,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
start_idx = max(num_computed_tokens - start_pos, 0)
end_idx = min(
num_computed_tokens - start_pos + num_scheduled_tokens,
num_encoder_tokens)
num_encoder_tokens,
)
assert start_idx < end_idx
assert req_id in self.encoder_cache
assert i in self.encoder_cache[req_id]
encoder_output = self.encoder_cache[req_id][i]
mm_hash = mm_hashes[i]
encoder_output = self.encoder_cache.get(mm_hash, None)
assert encoder_output is not None,\
f"Encoder cache miss for {mm_hash}."
if (is_embed := pos_info.is_embed) is not None:
is_embed = is_embed[start_idx:end_idx]

View File

@ -208,8 +208,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Lazy initialization
self.model: nn.Module # Set after load_model
self.kv_caches: list[torch.Tensor] = []
# req_id -> (input_id -> encoder_output)
self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
# mm_hash -> encoder_output
self.encoder_cache: dict[str, torch.Tensor] = {}
# Request states.
self.requests: dict[str, CachedRequestState] = {}
@ -342,7 +342,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Remove finished requests from the cached states.
for req_id in scheduler_output.finished_req_ids:
self.requests.pop(req_id, None)
self.encoder_cache.pop(req_id, None)
# Remove the finished requests from the persistent batch.
# NOTE(woosuk): There could be an edge case where finished_req_ids and
@ -357,12 +356,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
removed_req_indices.append(req_index)
# Free the cached encoder outputs.
for req_id, input_id in scheduler_output.free_encoder_input_ids:
encoder_outputs = self.encoder_cache.get(req_id)
if encoder_outputs is not None:
encoder_outputs.pop(input_id, None)
if not encoder_outputs:
self.encoder_cache.pop(req_id, None)
for mm_hash in scheduler_output.free_encoder_mm_hashes:
self.encoder_cache.pop(mm_hash, None)
# Remove the unscheduled requests from the persistent batch.
# NOTE(woosuk): The unscheduled requests are either preempted requests
@ -394,6 +389,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
prompt_token_ids=new_req_data.prompt_token_ids,
mm_kwargs=new_req_data.mm_kwargs,
mm_positions=new_req_data.mm_positions,
mm_hashes=new_req_data.mm_hashes,
sampling_params=sampling_params,
pooling_params=None,
generator=None,
@ -845,14 +841,16 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Batch the multi-modal inputs.
mm_kwargs = list[MultiModalKwargsItem]()
req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
# List of tuple (mm_hash, pos_info)
mm_hashes_pos = list[tuple[str, PlaceholderRange]]()
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
req_state = self.requests[req_id]
for mm_input_id in encoder_input_ids:
mm_hash = req_state.mm_hashes[mm_input_id]
mm_kwargs.append(req_state.mm_kwargs[mm_input_id])
req_ids_pos.append(
(req_id, mm_input_id, req_state.mm_positions[mm_input_id]))
mm_hashes_pos.append(
(mm_hash, req_state.mm_positions[mm_input_id]))
# Batch mm inputs as much as we can: if a request in the batch has
# multiple modalities or a different modality than the previous one,
@ -895,15 +893,15 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# NOTE (NickLucche) here we diverge from logic in other runners, as we
# assume to only have whole mm items to process. Hence we avoid the
# intrinsic dynamism that `scatter_mm_placeholders` introduces.
for (req_id, input_id, pos_info), output in zip(
req_ids_pos,
for (mm_hash, pos_info), output in zip(
mm_hashes_pos,
encoder_outputs,
):
if req_id not in self.encoder_cache:
self.encoder_cache[req_id] = {}
assert pos_info.is_embed is None, "Expected all positions to be"\
" contiguous and embeddings."
self.encoder_cache[req_id][input_id] = output
self.encoder_cache[mm_hash] = output
def _gather_mm_embeddings(
self,
@ -916,6 +914,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
req_state = self.requests[req_id]
num_computed_tokens = req_state.num_computed_tokens
mm_positions = req_state.mm_positions
mm_hashes = req_state.mm_hashes
# TODO unroll loop and assume/enforce --disable_chunked_mm_input
# NOTE (NickLucche) here we diverge from logic in other runners, as
# we assume to only have whole mm items to process. Hence we avoid
@ -936,11 +935,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# in the decoder's KV cache.
continue
assert req_id in self.encoder_cache
assert i in self.encoder_cache[req_id]
mm_hash = mm_hashes[i]
encoder_output = self.encoder_cache.get(mm_hash, None)
assert encoder_output is not None,\
f"Encoder cache miss for {mm_hash}."
assert pos_info.is_embed is None, "Expected all positions to"\
" be contiguous and embeddings."
encoder_output = self.encoder_cache[req_id][i]
encoder_output = self.encoder_cache[mm_hash]
mm_embeds.append(encoder_output)
return mm_embeds