[Core][Multimodal] Track encode cache entries by mm_hash and enable embedding sharing between requests (#22711)
Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: knlnguyen1802 <knlnguyen1802@gmail.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
tests/v1/core/test_encoder_cache_manager.py (new file, 144 lines)
@@ -0,0 +1,144 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm.v1.core.encoder_cache_manager import EncoderCacheManager


# ------------------ Mock Classes ------------------ #
class MockRequest:

    def __init__(self, request_id, mm_hashes, token_counts):
        self.request_id = request_id
        self.mm_hashes = mm_hashes
        self._token_counts = token_counts

    def get_num_encoder_tokens(self, input_id: int) -> int:
        return self._token_counts[input_id]


# ------------------ Unit Tests ------------------ #
def test_basic_allocate_and_reuse():
    cache = EncoderCacheManager(cache_size=10)
    req = MockRequest("r1", ["imgA"], [4])

    assert not cache.check_and_update_cache(req, 0)
    assert cache.try_allocate(req, 0, int(1e9))

    cache.allocate(req, 0)

    assert cache.check_and_update_cache(req, 0)
    assert "r1" in cache.cached["imgA"]
    assert cache.num_free_slots == 6

    # Free twice to bring refcount to 0.
    cache.free_encoder_input(req, 0)
    cache.free_encoder_input(req, 0)

    assert not cache.cached["imgA"]
    assert "imgA" in cache.freeable
    assert cache.num_freeable_slots == 10
    assert cache.num_free_slots == 6


def test_freeing_decreases_refcount_and_moves_to_freeable():
    manager = EncoderCacheManager(cache_size=10)
    req = MockRequest("req2", ["img3"], [5])

    assert manager.try_allocate(req, 0, int(1e9))
    manager.allocate(req, 0)

    assert len(manager.cached["img3"]) == 1

    manager.free_encoder_input(req, 0)

    assert not manager.cached["img3"]
    assert "img3" in manager.freeable
    assert manager.num_freeable_slots == 10


def test_free_request_frees_all_inputs():
    manager = EncoderCacheManager(cache_size=10)
    req = MockRequest("req3", ["a", "b"], [2, 3])

    assert manager.try_allocate(req, 0, int(1e9))
    manager.allocate(req, 0)

    assert manager.try_allocate(req, 1, int(1e9))
    manager.allocate(req, 1)

    assert len(manager.cached["a"]) == 1
    assert len(manager.cached["b"]) == 1

    manager.free(req)

    assert not manager.cached["a"]
    assert not manager.cached["b"]
    assert "a" in manager.freeable
    assert "b" in manager.freeable
    assert manager.num_freeable_slots == 10


def test_eviction_when_cache_is_full():
    manager = EncoderCacheManager(cache_size=10)

    req1 = MockRequest("req1", ["x"], [6])
    req2 = MockRequest("req2", ["y"], [5])

    assert manager.try_allocate(req1, 0, int(1e9))
    manager.allocate(req1, 0)
    manager.free_encoder_input(req1, 0)

    assert manager.try_allocate(req2, 0, int(1e9))
    manager.allocate(req2, 0)

    # 'x' should have been evicted.
    assert "x" not in manager.cached
    assert "x" in manager.get_freed_mm_hashes()


def test_get_cached_input_ids():
    manager = EncoderCacheManager(cache_size=10)
    req = MockRequest("reqX", ["m", "n", "o"], [2, 4, 3])

    assert manager.try_allocate(req, 0, int(1e9))
    manager.allocate(req, 0)

    assert manager.try_allocate(req, 2, int(1e9))
    manager.allocate(req, 2)

    cached_ids = manager.get_cached_input_ids(req)
    assert cached_ids == {0, 2}


def test_has_cache_restores_from_freeable():
    manager = EncoderCacheManager(cache_size=10)
    req = MockRequest("reqY", ["imgZ"], [4])

    assert manager.try_allocate(req, 0, int(1e9))
    manager.allocate(req, 0)

    manager.free_encoder_input(req, 0)

    # Should restore from freeable.
    assert manager.check_and_update_cache(req, 0)
    assert len(manager.cached["imgZ"]) == 1
    assert "imgZ" not in manager.freeable
    assert manager.num_freeable_slots == 6


def test_get_freed_mm_hashes_clears_freed_list():
    manager = EncoderCacheManager(cache_size=10)
    req1 = MockRequest("reqA", ["a"], [5])
    req2 = MockRequest("reqB", ["b"], [6])

    assert manager.try_allocate(req1, 0, int(1e9))
    manager.allocate(req1, 0)
    manager.free_encoder_input(req1, 0)

    # Should trigger eviction of 'a'.
    assert manager.try_allocate(req2, 0, int(1e9))
    manager.allocate(req2, 0)

    freed = manager.get_freed_mm_hashes()
    assert "a" in freed
    assert manager.get_freed_mm_hashes() == []
@@ -338,7 +338,7 @@ def test_stop_via_update_from_output():
        },
        num_common_prefix_blocks=0,
        finished_req_ids=set(),
        free_encoder_input_ids=[],
        free_encoder_mm_hashes=[],
        structured_output_request_ids={},
        grammar_bitmask=None)

@@ -391,7 +391,7 @@ def test_stop_via_update_from_output():
        },
        num_common_prefix_blocks=0,
        finished_req_ids=set(),
        free_encoder_input_ids=[],
        free_encoder_mm_hashes=[],
        structured_output_request_ids={},
        grammar_bitmask=None,
    )

@@ -443,7 +443,7 @@ def test_stop_via_update_from_output():
        },
        num_common_prefix_blocks=0,
        finished_req_ids=set(),
        free_encoder_input_ids=[],
        free_encoder_mm_hashes=[],
        structured_output_request_ids={},
        grammar_bitmask=None,
    )

@@ -490,7 +490,7 @@ def test_stop_via_update_from_output():
        },
        num_common_prefix_blocks=0,
        finished_req_ids=set(),
        free_encoder_input_ids=[],
        free_encoder_mm_hashes=[],
        structured_output_request_ids={},
        grammar_bitmask=None)
@@ -143,7 +143,11 @@ def create_requests(
            mm_position = mm_positions[i]
            mm_item = MultiModalKwargsItem.dummy("dummy_m")
            mm_kwargs = [mm_item] * len(mm_position)
            mm_hashes = ["hash"] * len(mm_position)
            # Dummy hash for each mm item should be unique
            # since encoder cache tracks entries by hash
            mm_hashes = [
                "hash" + str(i) + "_" + str(j) for j in range(len(mm_position))
            ]
        else:
            mm_position = None
            mm_kwargs = None
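The new comment is the behavioral point that matters for this fixture: the encoder cache now keys entries by mm_hash, so giving every dummy item the same "hash" string would silently collapse distinct items into a single cache entry. A small illustrative sketch of what the fixture now produces (values are examples only):

# For request i = 0 with two mm placeholder positions, the fixture yields
# one unique dummy hash per item instead of the shared literal "hash".
i = 0
mm_position = [None, None]  # stand-in for two PlaceholderRange entries
mm_hashes = [
    "hash" + str(i) + "_" + str(j) for j in range(len(mm_position))
]
assert mm_hashes == ["hash0_0", "hash0_1"]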
@@ -85,7 +85,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
        scheduled_encoder_inputs={},
        num_common_prefix_blocks=0,
        finished_req_ids=set(),
        free_encoder_input_ids=[],
        free_encoder_mm_hashes=[],
        structured_output_request_ids={},
        grammar_bitmask=None,
    )

@@ -164,7 +164,7 @@ def test_update_states_request_finished(model_runner):
        scheduled_encoder_inputs={},
        num_common_prefix_blocks=0,
        finished_req_ids={req_id},
        free_encoder_input_ids=[],
        free_encoder_mm_hashes=[],
        structured_output_request_ids={},
        grammar_bitmask=None,
    )

@@ -194,7 +194,7 @@ def test_update_states_request_resumed(model_runner):
        scheduled_encoder_inputs={},
        num_common_prefix_blocks=0,
        finished_req_ids=set(),
        free_encoder_input_ids=[],
        free_encoder_mm_hashes=[],
        structured_output_request_ids={},
        grammar_bitmask=None,
    )

@@ -221,7 +221,7 @@ def test_update_states_request_resumed(model_runner):
        scheduled_encoder_inputs={},
        num_common_prefix_blocks=0,
        finished_req_ids=set(),
        free_encoder_input_ids=[],
        free_encoder_mm_hashes=[],
        structured_output_request_ids={},
        grammar_bitmask=None,
    )

@@ -252,7 +252,7 @@ def test_update_states_no_changes(model_runner):
        scheduled_encoder_inputs={},
        num_common_prefix_blocks=0,
        finished_req_ids=set(),
        free_encoder_input_ids=[],
        free_encoder_mm_hashes=[],
        structured_output_request_ids={},
        grammar_bitmask=None,
    )

@@ -287,7 +287,7 @@ def test_update_states_request_unscheduled(model_runner):
        scheduled_encoder_inputs={},
        num_common_prefix_blocks=0,
        finished_req_ids=set(),
        free_encoder_input_ids=[],
        free_encoder_mm_hashes=[],
        structured_output_request_ids={},
        grammar_bitmask=None,
    )
@@ -205,6 +205,7 @@ def _construct_cached_request_state(req_id_suffix: int):
        pooling_params=None,
        mm_kwargs=[],
        mm_positions=[],
        mm_hashes=[],
        block_ids=([], ),
        generator=None,
        num_computed_tokens=len(output_token_ids),
@@ -141,7 +141,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
        scheduled_encoder_inputs={},
        num_common_prefix_blocks=0,
        finished_req_ids=set(),
        free_encoder_input_ids=[],
        free_encoder_mm_hashes=[],
        structured_output_request_ids={},
        grammar_bitmask=None,
    )

@@ -207,7 +207,7 @@ def test_update_states_request_finished(model_runner, dist_init):
        scheduled_encoder_inputs={},
        num_common_prefix_blocks=0,
        finished_req_ids={req_id},
        free_encoder_input_ids=[],
        free_encoder_mm_hashes=[],
        structured_output_request_ids={},
        grammar_bitmask=None,
    )

@@ -239,7 +239,7 @@ def test_update_states_request_resumed(model_runner, dist_init):
        scheduled_encoder_inputs={},
        num_common_prefix_blocks=0,
        finished_req_ids=set(),
        free_encoder_input_ids=[],
        free_encoder_mm_hashes=[],
        structured_output_request_ids={},
        grammar_bitmask=None,
    )

@@ -266,7 +266,7 @@ def test_update_states_request_resumed(model_runner, dist_init):
        scheduled_encoder_inputs={},
        num_common_prefix_blocks=0,
        finished_req_ids=set(),
        free_encoder_input_ids=[],
        free_encoder_mm_hashes=[],
        structured_output_request_ids={},
        grammar_bitmask=None,
    )

@@ -347,7 +347,7 @@ def test_update_states_no_changes(model_runner, dist_init):
        scheduled_encoder_inputs={},
        num_common_prefix_blocks=0,
        finished_req_ids=set(),
        free_encoder_input_ids=[],
        free_encoder_mm_hashes=[],
        structured_output_request_ids={},
        grammar_bitmask=None,
    )

@@ -384,7 +384,7 @@ def test_update_states_request_unscheduled(model_runner, dist_init):
        scheduled_encoder_inputs={},
        num_common_prefix_blocks=0,
        finished_req_ids=set(),
        free_encoder_input_ids=[],
        free_encoder_mm_hashes=[],
        structured_output_request_ids={},
        grammar_bitmask=None,
    )
@@ -1,5 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections import OrderedDict
from collections.abc import Mapping
from typing import TYPE_CHECKING

@@ -31,34 +33,52 @@ class EncoderCacheManager:
    within requests, allowing for fine-grained memory management and enabling
    chunked processing of multimodal inputs.

    Note that no caching is shared between requests at this time. If the same
    input is used across multiple requests, it will be reprocessed for each
    request.
    The cache shares embeddings of the same multimodal data item (identified
    by its hash value) between different requests, and eviction takes place
    at allocation time when there is no free space for new embeddings.
    The oldest cached embeddings not referenced by any request are evicted
    first.

    Args:
        cache_size: Limit the size of the cache, measured by the number of
            tokens from the input sequence.

    Attributes:
        cache_size: Total cache capacity in encoder tokens
        num_free_slots: Current available cache capacity in encoder tokens
        cached: Mapping from request_id to set of cached input_ids for that
            request
        freed: List of (request_id, input_id) pairs that were recently freed.
            This is cleared after every call to get_freed_ids().
        cache_size: Total cache capacity in encoder tokens.
        num_free_slots: Current available cache capacity in encoder tokens.
        num_freeable_slots: Capacity that can be immediately reclaimed by
            evicting entries with zero references (in encoder tokens).
        cached: Mapping from mm_hash to a set of request IDs that currently
            reference the cached entry. If the set is empty, the entry exists
            but is not referenced by any request and is eligible for
            reclamation.
        freeable: Ordered mapping from mm_hash to num_tokens for entries that
            are no longer referenced by any running request and can be freed
            to make space when needed.
        freed: List of mm_hash strings that were actually evicted since the
            last call to get_freed_mm_hashes(). This list is cleared on return.
    """

    def __init__(self, cache_size: int):
        self.cache_size = cache_size
        self.num_free_slots = cache_size
        # req_id -> cached input ids
        self.cached: dict[str, set[int]] = {}
        # list of [req_id, input_id]
        self.freed: list[tuple[str, int]] = []
        self.num_freeable_slots = cache_size

    def has_cache(self, request: Request, input_id: int) -> bool:
        # mm_hash of mm_data => ids of requests that reference the mm_data
        self.cached: dict[str, set[str]] = {}

        # mm_hash of mm_data => num_encoder_tokens of the mm_data
        self.freeable: OrderedDict[str, int] = OrderedDict()
        self.freed: list[str] = []

    def check_and_update_cache(self, request: Request, input_id: int) -> bool:
        """Check if encoder output for a specific multimodal input is cached.

        If the encoder output is cached, update `cached` to add the request id
        to the set of request ids that reference the cached encoder output.
        If the encoder output was previously not referenced by any request,
        update `freeable` and `num_freeable_slots` accordingly.

        Args:
            request: The request containing the multimodal input
            input_id: Index of the multimodal input within the request
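As a minimal usage sketch of the sharing behavior described in the docstring above (the `_Req` stub below mirrors the MockRequest helper from the new test file and is illustrative, not part of this change):

from vllm.v1.core.encoder_cache_manager import EncoderCacheManager

class _Req:
    """Stub request carrying just what EncoderCacheManager needs."""

    def __init__(self, request_id, mm_hashes, token_counts):
        self.request_id = request_id
        self.mm_hashes = mm_hashes
        self._token_counts = token_counts

    def get_num_encoder_tokens(self, input_id: int) -> int:
        return self._token_counts[input_id]

manager = EncoderCacheManager(cache_size=10)
req_a = _Req("req-a", ["imgA"], [4])
req_b = _Req("req-b", ["imgA"], [4])  # same multimodal item, same hash

# The first request pays for the allocation.
assert manager.try_allocate(req_a, 0, encoder_budget=int(1e9))
manager.allocate(req_a, 0)

# The second request reuses the cached entry; both ids now reference it.
assert manager.check_and_update_cache(req_b, 0)
assert manager.cached["imgA"] == {"req-a", "req-b"}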
@@ -66,103 +86,151 @@ class EncoderCacheManager:
        Returns:
            True if the encoder output for this input is already cached
        """
        req_id = request.request_id
        return req_id in self.cached and input_id in self.cached[req_id]
        mm_hash = request.mm_hashes[input_id]
        # Not cached at all
        if mm_hash not in self.cached:
            return False

    def can_allocate(self, request: Request, input_id: int) -> bool:
        """Check if there's sufficient cache space for a multimodal input.
        # Cached but currently not referenced by any request
        if not self.cached[mm_hash]:
            num_tokens = self.freeable.pop(mm_hash)
            self.num_freeable_slots -= num_tokens

        self.cached[mm_hash].add(request.request_id)
        return True

    def try_allocate(self, request: Request, input_id: int,
                     encoder_budget: int) -> bool:
        """Check if there's sufficient cache space for a multimodal input.
        If there is, return True and update EncoderCacheManager state.

        If there is not enough free space in `num_free_slots` but there is
        enough reclaimable space in `num_freeable_slots`, entries will be
        evicted from `freeable` (their mm_hash appended to `freed`) until
        enough space is available, and then this method returns True.
        Older entries are evicted first.

        Returns False only if the requested number of tokens exceeds both
        the free and reclaimable capacities combined.

        Args:
            request: The request containing the multimodal input
            input_id: Index of the multimodal input within the request
            request: The request containing the multimodal input.
            input_id: Index of the multimodal input within the request.

        Returns:
            True if there's enough free cache space to store the encoder output
            for this multimodal input
            True if there's enough capacity to hold the encoder output for this
            input (possibly after reclaiming `freeable` entries); otherwise
            False.

        Note: This method does not allocate physical memory for the encoder
        output; it only updates the state of the EncoderCacheManager.
        """
        num_tokens = request.get_num_encoder_tokens(input_id)
        return num_tokens <= self.num_free_slots

        # Not enough compute budget
        if num_tokens > encoder_budget:
            return False

        # Enough free slots
        if num_tokens <= self.num_free_slots:
            self.num_free_slots -= num_tokens
            self.num_freeable_slots -= num_tokens
            return True

        # Not enough reclaimable slots
        if num_tokens > self.num_freeable_slots:
            return False

        # Not enough free slots but enough reclaimable slots
        # NOTE: Eviction takes place here, but physical memory is not freed
        # until model runner is notified by the scheduler output.
        while num_tokens > self.num_free_slots:
            mm_hash, num_free_token = self.freeable.popitem(last=False)
            del self.cached[mm_hash]
            self.freed.append(mm_hash)
            self.num_free_slots += num_free_token
        self.num_free_slots -= num_tokens
        self.num_freeable_slots -= num_tokens
        return True

    def allocate(self, request: Request, input_id: int) -> None:
        """Allocate cache space for a multimodal input's encoder output.

        This method reserves cache space for storing the encoder output of
        the specified multimodal input. The actual encoder output storage
        happens in the model runner, but this method ensures the cache
        manager tracks the allocation.

        Args:
            request: The request containing the multimodal input
            input_id: Index of the multimodal input within the request
        This reserves cache space for storing the encoder output of the
        specified multimodal input. The actual encoder output storage happens
        in the model runner; this method updates the manager's bookkeeping.

        Note:
            This method assumes can_allocate() returned True for the same
            request and input_id. It will reduce available cache space.
            This method assumes try_allocate() returned True for the same input.
        """
        req_id = request.request_id
        if req_id not in self.cached:
            self.cached[req_id] = set()
        self.cached[req_id].add(input_id)
        self.num_free_slots -= request.get_num_encoder_tokens(input_id)
        # Encoder cache space budget should be already updated for the
        # multimodal input and non-negative after try_allocate() is called.
        assert self.num_free_slots >= 0
        assert self.num_freeable_slots >= 0

        mm_hash = request.mm_hashes[input_id]
        request_id = request.request_id
        if mm_hash not in self.cached:
            self.cached[mm_hash] = set()

        self.cached[mm_hash].add(request_id)

    def get_cached_input_ids(self, request: Request) -> set[int]:
        """Get all cached multimodal input IDs for a request.

        Args:
            request: The request to query

        Returns:
            Set of input_ids that have cached encoder outputs for this request.
            Returns empty set if no inputs are cached for this request.
            Returns the set of input IDs whose `mm_hash` exists in the cache
            map. This includes entries that are currently unreferenced (and
            thus present in `freeable`); for such entries, freeing for this
            request will be a no-op.
        """
        return self.cached.get(request.request_id, set())
        return {
            input_id
            for input_id in range(len(request.mm_hashes))
            if request.mm_hashes[input_id] in self.cached
        }

    def free_encoder_input(self, request: Request, input_id: int) -> None:
        """Free cache space for a single multimodal input's encoder output.
        """Free the request's reference to the encoder input (`mm_data`).

        This method is called when:
        - The encoder output has been fully consumed by the decoder and is
          no longer needed (e.g., in vision-language models after image
          tokens are processed)
        - A request is being cancelled or aborted
        When the reference set for the corresponding `mm_hash` becomes empty,
        the entry is appended to `freeable` and `num_freeable_slots` is
        increased by the number of encoder tokens for that input.

        Args:
            request: The request containing the multimodal input
            input_id: Index of the multimodal input to free from cache
        The entry is NOT physically freed until capacity is needed (e.g., by
        `try_allocate`).
        """
        req_id = request.request_id
        if req_id not in self.cached:
        mm_hash = request.mm_hashes[input_id]
        # The mm_hash not in cache or the req_id set is empty
        if not self.cached.get(mm_hash, None):
            return

        self.cached[req_id].discard(input_id)
        if len(self.cached[req_id]) == 0:
            del self.cached[req_id]
            self.num_free_slots += request.get_num_encoder_tokens(input_id)
            self.freed.append((req_id, input_id))
        self.cached[mm_hash].discard(req_id)
        if not self.cached[mm_hash]:
            num_tokens = request.get_num_encoder_tokens(input_id)
            self.freeable[mm_hash] = num_tokens
            self.num_freeable_slots += num_tokens

    def free(self, request: Request) -> None:
        """Free all cached encoder outputs for a request.
        """Free all encoder input cache references held by *request*.

        This method is typically called when a request is finished, cancelled,
        or aborted, and all its encoder outputs should be freed from cache.
        For each cached input ID, `free_encoder_input` is invoked.
        The data stays in memory until eviction is triggered by a future
        allocation attempt via `try_allocate`.

        Args:
            request: The request whose encoder outputs should be freed
        Typically called when a request is finished, cancelled, or aborted.
        """
        input_ids = self.get_cached_input_ids(request).copy()
        for input_id in input_ids:
            self.free_encoder_input(request, input_id)

    def get_freed_ids(self) -> list[tuple[str, int]]:
    def get_freed_mm_hashes(self) -> list[str]:
        """Get and clear the list of recently freed encoder cache entries.

        This method returns all encoder cache entries that were freed since
        the last call to this method. It's used by the scheduler to notify
        workers about which encoder outputs can be removed from their caches.

        Returns:
            List of (request_id, input_id) tuples that were freed since the
            last call. The internal freed list is cleared after this call.
            List of mm_hash strings that were actually evicted since the last
            call, used by the scheduler to notify workers which encoder
            outputs can be removed from their caches. The internal list is
            cleared after this call.
        """
        freed = self.freed
        self.freed = []
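To make the free vs. freeable accounting in try_allocate concrete, a worked sketch (reusing the `_Req` stub from the sketch above; the numbers are chosen to force an eviction): freeing an input only marks its entry reclaimable, and the actual eviction, plus the mm_hash being reported, happens inside a later try_allocate call.

manager = EncoderCacheManager(cache_size=10)
old = _Req("r1", ["x"], [6])
new = _Req("r2", ["y"], [5])

assert manager.try_allocate(old, 0, int(1e9))  # free: 10 -> 4, freeable: 10 -> 4
manager.allocate(old, 0)
manager.free_encoder_input(old, 0)             # free stays 4, freeable back to 10

# 5 > 4 free slots, but 5 <= 10 freeable slots, so "x" is evicted here.
assert manager.try_allocate(new, 0, int(1e9))  # free: 4 + 6 - 5 = 5, freeable: 10 - 5 = 5
manager.allocate(new, 0)

assert manager.get_freed_mm_hashes() == ["x"]  # reported once to the scheduler...
assert manager.get_freed_mm_hashes() == []     # ...then the internal list is cleared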
@@ -177,16 +245,11 @@ def compute_encoder_budget(
    """Compute the encoder cache budget based on the model and scheduler
    configurations.

    Args:
        model_config: Model configuration.
        scheduler_config: Scheduler configuration.
        mm_registry: Provides information about the token cost.

    Returns:
        - Compute budget for encoder execution, in unit of number of tokens
          in the input sequence.
        - Space budget for encoder cache size, in unit of number of tokens
          in the input sequence.
        - Compute budget for encoder execution, measured in number of tokens
          from the input sequence.
        - Space budget for encoder cache size, measured in number of tokens
          from the input sequence.
    """
    if mm_registry.supports_multimodal_inputs(model_config):
        max_tokens_by_modality = mm_registry \

@@ -231,10 +294,10 @@ def compute_mm_encoder_budget(
        non-text modality.

    Returns:
        - Compute budget for encoder execution, in unit of number of tokens
          in the input sequence.
        - Space budget for encoder cache size, in unit of number of tokens
          in the input sequence.
        - Compute budget for encoder execution, measured in number of tokens
          from the input sequence.
        - Space budget for encoder cache size, measured in number of tokens
          from the input sequence.
    """

    if not max_tokens_by_modality:
@@ -143,9 +143,9 @@ class SchedulerOutput:
    # steps. This is used to notify the workers about the finished requests
    # so that they can free the cached states for those requests.
    finished_req_ids: set[str]
    # list of (req_id, encoder_input_index) tuples.
    # Used to free the encoder cache.
    free_encoder_input_ids: list[tuple[str, int]]
    # list of mm_hash strings associated with the encoder outputs to be
    # freed from the encoder cache.
    free_encoder_mm_hashes: list[str]

    # Dict of request ids to their index within the batch
    # for filling the next token bitmask
@@ -252,6 +252,7 @@ class Scheduler(SchedulerInterface):
                preempted_req = self.running.pop()

                self.kv_cache_manager.free(preempted_req)
                self.encoder_cache_manager.free(preempted_req)
                preempted_req.status = RequestStatus.PREEMPTED
                preempted_req.num_computed_tokens = 0
                if self.log_stats:

@@ -550,7 +551,8 @@ class Scheduler(SchedulerInterface):
            # It contains the request IDs that are finished in between
            # the previous and the current steps.
            finished_req_ids=self.finished_req_ids,
            free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(),
            free_encoder_mm_hashes=self.encoder_cache_manager.
            get_freed_mm_hashes(),
            structured_output_request_ids=structured_output_request_ids,
            grammar_bitmask=grammar_bitmask,
        )

@@ -698,7 +700,7 @@ class Scheduler(SchedulerInterface):
                # in the decoder's KV cache.
                continue

            if self.encoder_cache_manager.has_cache(request, i):
            if self.encoder_cache_manager.check_and_update_cache(request, i):
                # The encoder input is already computed and cached.
                continue

@@ -712,8 +714,8 @@ class Scheduler(SchedulerInterface):
                    num_new_tokens = start_pos - num_computed_tokens
                break

            if (not self.encoder_cache_manager.can_allocate(request, i)
                    or num_encoder_tokens > encoder_budget):
            if not self.encoder_cache_manager.try_allocate(
                    request, i, encoder_budget):
                # The encoder cache is full or the encoder budget is exhausted.
                # NOTE(woosuk): We assume that the encoder input tokens should
                # be processed altogether, as the encoder usually uses
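Taken together, the scheduler-side changes above amount to the per-input flow sketched below (the helper function and its return values are illustrative; only the EncoderCacheManager method names come from this diff):

def schedule_encoder_input(ecm, request, i, encoder_budget):
    """Sketch of the scheduler's decision for one multimodal input."""
    if ecm.check_and_update_cache(request, i):
        # Embeddings for this mm_hash are already cached; only a reference
        # is added, no encoder work is scheduled.
        return "reuse"
    if not ecm.try_allocate(request, i, encoder_budget):
        # Not enough budget, free space, or reclaimable space.
        return "skip"
    ecm.allocate(request, i)
    return "run_encoder"

# At the end of the step, evicted hashes are handed to the workers via
# SchedulerOutput(free_encoder_mm_hashes=ecm.get_freed_mm_hashes(), ...).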
@@ -33,6 +33,7 @@ class CachedRequestState:
    prompt_token_ids: list[int]
    mm_kwargs: list[MultiModalKwargsItem]
    mm_positions: list[PlaceholderRange]
    mm_hashes: list[str]
    sampling_params: Optional[SamplingParams]
    pooling_params: Optional[PoolingParams]
    generator: Optional[torch.Generator]
@@ -176,8 +176,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        self.attn_groups: list[list[AttentionGroup]] = []
        # self.kv_cache_config: KVCacheConfig

        # req_id -> (input_id -> encoder_output)
        self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
        # mm_hash -> encoder_output
        self.encoder_cache: dict[str, torch.Tensor] = {}

        self.use_aux_hidden_state_outputs = False
        # Set up speculative decoding.

@@ -436,7 +436,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        # Remove finished requests from the cached states.
        for req_id in scheduler_output.finished_req_ids:
            self.requests.pop(req_id, None)
            self.encoder_cache.pop(req_id, None)
        # Remove the finished requests from the persistent batch.
        # NOTE(woosuk): There could be an edge case where finished_req_ids and
        # scheduled_req_ids overlap. This happens when a request is aborted and

@@ -447,12 +446,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            self.input_batch.remove_request(req_id)

        # Free the cached encoder outputs.
        for req_id, input_id in scheduler_output.free_encoder_input_ids:
            encoder_outputs = self.encoder_cache.get(req_id)
            if encoder_outputs is not None:
                encoder_outputs.pop(input_id, None)
                if not encoder_outputs:
                    self.encoder_cache.pop(req_id, None)
        for mm_hash in scheduler_output.free_encoder_mm_hashes:
            self.encoder_cache.pop(mm_hash, None)

        # Remove the unscheduled requests from the persistent batch.
        # NOTE(woosuk): The unscheduled requests are either preempted requests

@@ -496,6 +491,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                prompt_token_ids=new_req_data.prompt_token_ids,
                mm_kwargs=new_req_data.mm_kwargs,
                mm_positions=new_req_data.mm_positions,
                mm_hashes=new_req_data.mm_hashes,
                sampling_params=sampling_params,
                pooling_params=pooling_params,
                generator=generator,

@@ -1161,17 +1157,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
        if not scheduled_encoder_inputs:
            return

        # Batch the multi-modal inputs.
        mm_kwargs = list[MultiModalKwargsItem]()
        req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
        # list of tuple (mm_hash, position_info)
        mm_hashes_pos = list[tuple[str, PlaceholderRange]]()
        for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
            req_state = self.requests[req_id]

            for mm_input_id in encoder_input_ids:
                mm_hash = req_state.mm_hashes[mm_input_id]
                mm_kwargs.append(req_state.mm_kwargs[mm_input_id])
                req_ids_pos.append(
                    (req_id, mm_input_id, req_state.mm_positions[mm_input_id]))
                mm_hashes_pos.append(
                    (mm_hash, req_state.mm_positions[mm_input_id]))

        # Batch mm inputs as much as we can: if a request in the batch has
        # multiple modalities or a different modality than the previous one,

@@ -1204,15 +1201,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            for output in curr_group_outputs:
                encoder_outputs.append(output)

        # Cache the encoder outputs.
        for (req_id, input_id, pos_info), output in zip(
                req_ids_pos,
                encoder_outputs,
        ):
            if req_id not in self.encoder_cache:
                self.encoder_cache[req_id] = {}

            self.encoder_cache[req_id][input_id] = scatter_mm_placeholders(
        # Cache the encoder outputs by mm_hash
        for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs):
            self.encoder_cache[mm_hash] = scatter_mm_placeholders(
                output,
                is_embed=pos_info.is_embed,
            )

@@ -1230,6 +1221,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            num_computed_tokens = \
                req_state.num_computed_tokens + shift_computed_tokens
            mm_positions = req_state.mm_positions
            mm_hashes = req_state.mm_hashes
            for i, pos_info in enumerate(mm_positions):
                start_pos = pos_info.offset
                num_encoder_tokens = pos_info.length

@@ -1249,11 +1241,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                start_idx = max(num_computed_tokens - start_pos, 0)
                end_idx = min(
                    num_computed_tokens - start_pos + num_scheduled_tokens,
                    num_encoder_tokens)
                    num_encoder_tokens,
                )
                assert start_idx < end_idx
                assert req_id in self.encoder_cache
                assert i in self.encoder_cache[req_id]
                encoder_output = self.encoder_cache[req_id][i]

                mm_hash = mm_hashes[i]
                encoder_output = self.encoder_cache.get(mm_hash, None)
                assert encoder_output is not None,\
                    f"Encoder cache miss for {mm_hash}."

                if (is_embed := pos_info.is_embed) is not None:
                    is_embed = is_embed[start_idx:end_idx]
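On the worker side the encoder cache is now a flat mm_hash -> tensor mapping; a condensed sketch of the store/lookup/free cycle that the GPU runner hunks above implement (the tensor shape and hash value are illustrative):

import torch

encoder_cache: dict[str, torch.Tensor] = {}  # mm_hash -> encoder_output

# Store: after running the encoder for a scheduled mm item.
mm_hash = "imgA"
encoder_cache[mm_hash] = torch.zeros(4, 8)   # placeholder for real embeddings

# Lookup: every request whose item has this mm_hash reuses the same tensor.
encoder_output = encoder_cache.get(mm_hash)
assert encoder_output is not None, f"Encoder cache miss for {mm_hash}."

# Free: only when the scheduler reports the hash as evicted.
for freed_hash in ["imgA"]:                  # scheduler_output.free_encoder_mm_hashes
    encoder_cache.pop(freed_hash, None)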
@@ -208,8 +208,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        # Lazy initialization
        self.model: nn.Module  # Set after load_model
        self.kv_caches: list[torch.Tensor] = []
        # req_id -> (input_id -> encoder_output)
        self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
        # mm_hash -> encoder_output
        self.encoder_cache: dict[str, torch.Tensor] = {}

        # Request states.
        self.requests: dict[str, CachedRequestState] = {}

@@ -342,7 +342,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        # Remove finished requests from the cached states.
        for req_id in scheduler_output.finished_req_ids:
            self.requests.pop(req_id, None)
            self.encoder_cache.pop(req_id, None)

        # Remove the finished requests from the persistent batch.
        # NOTE(woosuk): There could be an edge case where finished_req_ids and

@@ -357,12 +356,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            removed_req_indices.append(req_index)

        # Free the cached encoder outputs.
        for req_id, input_id in scheduler_output.free_encoder_input_ids:
            encoder_outputs = self.encoder_cache.get(req_id)
            if encoder_outputs is not None:
                encoder_outputs.pop(input_id, None)
                if not encoder_outputs:
                    self.encoder_cache.pop(req_id, None)
        for mm_hash in scheduler_output.free_encoder_mm_hashes:
            self.encoder_cache.pop(mm_hash, None)

        # Remove the unscheduled requests from the persistent batch.
        # NOTE(woosuk): The unscheduled requests are either preempted requests

@@ -394,6 +389,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                prompt_token_ids=new_req_data.prompt_token_ids,
                mm_kwargs=new_req_data.mm_kwargs,
                mm_positions=new_req_data.mm_positions,
                mm_hashes=new_req_data.mm_hashes,
                sampling_params=sampling_params,
                pooling_params=None,
                generator=None,

@@ -845,14 +841,16 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

        # Batch the multi-modal inputs.
        mm_kwargs = list[MultiModalKwargsItem]()
        req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
        # List of tuple (mm_hash, pos_info)
        mm_hashes_pos = list[tuple[str, PlaceholderRange]]()
        for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
            req_state = self.requests[req_id]

            for mm_input_id in encoder_input_ids:
                mm_hash = req_state.mm_hashes[mm_input_id]
                mm_kwargs.append(req_state.mm_kwargs[mm_input_id])
                req_ids_pos.append(
                    (req_id, mm_input_id, req_state.mm_positions[mm_input_id]))
                mm_hashes_pos.append(
                    (mm_hash, req_state.mm_positions[mm_input_id]))

        # Batch mm inputs as much as we can: if a request in the batch has
        # multiple modalities or a different modality than the previous one,

@@ -895,15 +893,15 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        # NOTE (NickLucche) here we diverge from logic in other runners, as we
        # assume to only have whole mm items to process. Hence we avoid the
        # intrinsic dynamism that `scatter_mm_placeholders` introduces.
        for (req_id, input_id, pos_info), output in zip(
                req_ids_pos,
        for (mm_hash, pos_info), output in zip(
                mm_hashes_pos,
                encoder_outputs,
        ):
            if req_id not in self.encoder_cache:
                self.encoder_cache[req_id] = {}
            assert pos_info.is_embed is None, "Expected all positions to be"\
                " contiguous and embeddings."
            self.encoder_cache[req_id][input_id] = output
            self.encoder_cache[mm_hash] = output

    def _gather_mm_embeddings(
        self,

@@ -916,6 +914,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            req_state = self.requests[req_id]
            num_computed_tokens = req_state.num_computed_tokens
            mm_positions = req_state.mm_positions
            mm_hashes = req_state.mm_hashes
            # TODO unroll loop and assume/enforce --disable_chunked_mm_input
            # NOTE (NickLucche) here we diverge from logic in other runners, as
            # we assume to only have whole mm items to process. Hence we avoid

@@ -936,11 +935,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                # in the decoder's KV cache.
                continue

            assert req_id in self.encoder_cache
            assert i in self.encoder_cache[req_id]
            mm_hash = mm_hashes[i]
            encoder_output = self.encoder_cache.get(mm_hash, None)
            assert encoder_output is not None,\
                f"Encoder cache miss for {mm_hash}."
            assert pos_info.is_embed is None, "Expected all positions to"\
                " be contiguous and embeddings."
            encoder_output = self.encoder_cache[req_id][i]
            encoder_output = self.encoder_cache[mm_hash]
            mm_embeds.append(encoder_output)
        return mm_embeds