mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Core] Use sha256 bytes instead of BlockHash to reduce GC overhead (#23673)
Signed-off-by: linzebing <linzebing1995@gmail.com>
This commit is contained in:
@ -6,6 +6,8 @@ import msgspec
|
||||
import zmq
|
||||
from msgspec.msgpack import Decoder
|
||||
|
||||
from vllm.v1.core.kv_cache_utils import BlockHash
|
||||
|
||||
|
||||
#
|
||||
# Types copied from vllm.distributed.kv_events
|
||||
@ -22,8 +24,8 @@ class KVCacheEvent(
|
||||
|
||||
|
||||
class BlockStored(KVCacheEvent):
|
||||
block_hashes: list[int]
|
||||
parent_block_hash: Optional[int]
|
||||
block_hashes: list[BlockHash]
|
||||
parent_block_hash: Optional[BlockHash]
|
||||
token_ids: list[int]
|
||||
block_size: int
|
||||
lora_id: Optional[int]
|
||||
@ -31,7 +33,7 @@ class BlockStored(KVCacheEvent):
|
||||
|
||||
|
||||
class BlockRemoved(KVCacheEvent):
|
||||
block_hashes: list[int]
|
||||
block_hashes: list[BlockHash]
|
||||
medium: Optional[str]
|
||||
|
||||
|
||||
|
@ -835,22 +835,20 @@ def test_model_specification(parser_with_config, cli_config_file,
|
||||
|
||||
@pytest.mark.parametrize("input", [(), ("abc", ), (None, ),
|
||||
(None, bool, [1, 2, 3])])
|
||||
@pytest.mark.parametrize("output", [0, 1, 2])
|
||||
def test_sha256(input: tuple, output: int):
|
||||
hash = sha256(input)
|
||||
assert hash is not None
|
||||
assert isinstance(hash, int)
|
||||
assert hash != 0
|
||||
def test_sha256(input: tuple):
|
||||
digest = sha256(input)
|
||||
assert digest is not None
|
||||
assert isinstance(digest, bytes)
|
||||
assert digest != b""
|
||||
|
||||
bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
assert hash == int.from_bytes(hashlib.sha256(bytes).digest(),
|
||||
byteorder="big")
|
||||
input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
assert digest == hashlib.sha256(input_bytes).digest()
|
||||
|
||||
# hashing again, returns the same value
|
||||
assert hash == sha256(input)
|
||||
assert digest == sha256(input)
|
||||
|
||||
# hashing different input, returns different value
|
||||
assert hash != sha256(input + (1, ))
|
||||
assert digest != sha256(input + (1, ))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
@ -6,20 +6,22 @@ from typing import Callable, Optional
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import vllm.v1.core.kv_cache_utils as kv_cache_utils
|
||||
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
|
||||
from vllm.multimodal.inputs import (MultiModalFeatureSpec,
|
||||
MultiModalKwargsItem, PlaceholderRange)
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import GiB_bytes, sha256, sha256_cbor_64bit
|
||||
from vllm.utils import GiB_bytes, sha256, sha256_cbor
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheManager
|
||||
# disable yapf here as it formats differently than isort such that both fail
|
||||
# yapf: disable
|
||||
from vllm.v1.core.kv_cache_utils import (
|
||||
FreeKVCacheBlockQueue, KVCacheBlock, PrefixCachingMetrics,
|
||||
BlockHash, FreeKVCacheBlockQueue, KVCacheBlock, PrefixCachingMetrics,
|
||||
estimate_max_model_len, generate_block_hash_extra_keys,
|
||||
get_kv_cache_config, get_max_concurrency_for_kv_cache_config,
|
||||
get_request_block_hasher, hash_block_tokens, init_none_hash,
|
||||
is_kv_cache_type_uniform, unify_kv_cache_configs)
|
||||
is_kv_cache_type_uniform, make_block_hash_with_group_id,
|
||||
unify_kv_cache_configs)
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheGroupSpec, KVCacheTensor,
|
||||
SlidingWindowSpec)
|
||||
@ -88,7 +90,7 @@ def new_sliding_window_spec(block_size=16,
|
||||
sliding_window=sliding_window)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
|
||||
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
|
||||
def test_none_hash(monkeypatch, hash_fn):
|
||||
import vllm.v1.core.kv_cache_utils
|
||||
|
||||
@ -98,8 +100,8 @@ def test_none_hash(monkeypatch, hash_fn):
|
||||
reloaded_kv_cache_utils = importlib.reload(vllm.v1.core.kv_cache_utils)
|
||||
reloaded_kv_cache_utils.init_none_hash(hash_fn)
|
||||
assert reloaded_kv_cache_utils.NONE_HASH is not None
|
||||
assert isinstance(reloaded_kv_cache_utils.NONE_HASH, int)
|
||||
assert reloaded_kv_cache_utils.NONE_HASH != 0
|
||||
assert isinstance(reloaded_kv_cache_utils.NONE_HASH, bytes)
|
||||
assert reloaded_kv_cache_utils.NONE_HASH != b""
|
||||
|
||||
# case 2: PYTHONHASHSEED is set, use the seed and hash_fn
|
||||
with monkeypatch.context() as m:
|
||||
@ -107,12 +109,11 @@ def test_none_hash(monkeypatch, hash_fn):
|
||||
reloaded_kv_cache_utils = importlib.reload(vllm.v1.core.kv_cache_utils)
|
||||
reloaded_kv_cache_utils.init_none_hash(hash_fn)
|
||||
assert reloaded_kv_cache_utils.NONE_HASH is not None
|
||||
assert isinstance(reloaded_kv_cache_utils.NONE_HASH, int)
|
||||
assert isinstance(reloaded_kv_cache_utils.NONE_HASH, bytes)
|
||||
assert hash_fn('python hash seed') == reloaded_kv_cache_utils.NONE_HASH
|
||||
|
||||
|
||||
def test_kv_cache_block():
|
||||
import vllm.v1.core.kv_cache_utils
|
||||
|
||||
# Test KVCacheBlock initialization
|
||||
block = KVCacheBlock(block_id=0)
|
||||
@ -127,8 +128,7 @@ def test_kv_cache_block():
|
||||
assert block.ref_cnt == 0
|
||||
|
||||
# Test block hash setting and resetting
|
||||
block_hash = vllm.v1.core.kv_cache_utils.BlockHash(hash_value=123,
|
||||
token_ids=(1, 2, 3))
|
||||
block_hash = make_block_hash_with_group_id(BlockHash(b"abc"), 0)
|
||||
block.block_hash = block_hash
|
||||
assert block.block_hash == block_hash
|
||||
|
||||
@ -407,27 +407,23 @@ def test_generate_block_hash_extra_keys_cache_salt():
|
||||
assert next_mm_idx == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
|
||||
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
|
||||
def test_hash_block_tokens(hash_fn):
|
||||
import vllm.v1.core.kv_cache_utils
|
||||
init_none_hash(hash_fn)
|
||||
parent_block_hash = 123
|
||||
parent_block_hash = BlockHash(b"123")
|
||||
curr_block_token_ids = (1, 2, 3)
|
||||
extra_keys = ("key1", "key2")
|
||||
|
||||
block_hash = hash_block_tokens(hash_fn, parent_block_hash,
|
||||
curr_block_token_ids, extra_keys)
|
||||
assert isinstance(block_hash, vllm.v1.core.kv_cache_utils.BlockHash)
|
||||
assert block_hash.hash_value == hash_fn(
|
||||
(parent_block_hash, curr_block_token_ids, extra_keys))
|
||||
assert block_hash.token_ids == curr_block_token_ids
|
||||
assert block_hash.extra_keys == extra_keys
|
||||
expected = hash_fn((parent_block_hash, curr_block_token_ids, extra_keys))
|
||||
assert block_hash == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
|
||||
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
|
||||
def test_request_block_hasher(hash_fn):
|
||||
import vllm.v1.core.kv_cache_utils
|
||||
init_none_hash(hash_fn)
|
||||
kv_cache_utils.init_none_hash(hash_fn)
|
||||
|
||||
request = make_request(
|
||||
request_id="0",
|
||||
prompt_token_ids=[_ for _ in range(6)],
|
||||
@ -442,19 +438,13 @@ def test_request_block_hasher(hash_fn):
|
||||
|
||||
block_hashes = request.block_hashes
|
||||
assert len(block_hashes) == 2
|
||||
assert isinstance(block_hashes[0], vllm.v1.core.kv_cache_utils.BlockHash)
|
||||
assert isinstance(block_hashes[1], vllm.v1.core.kv_cache_utils.BlockHash)
|
||||
|
||||
# Check the first block
|
||||
assert block_hashes[0].token_ids == (0, 1, 2)
|
||||
assert block_hashes[0].extra_keys == ("hash1", )
|
||||
|
||||
# Check the second block
|
||||
assert block_hashes[1].token_ids == (3, 4, 5)
|
||||
assert block_hashes[1].extra_keys == ("hash2", )
|
||||
assert block_hashes[0] == hash_fn(
|
||||
(kv_cache_utils.NONE_HASH, (0, 1, 2), ("hash1", )))
|
||||
assert block_hashes[1] == hash_fn(
|
||||
(block_hashes[0], (3, 4, 5), ("hash2", )))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
|
||||
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
|
||||
def test_hash_tokens_different_mm_input(hash_fn):
|
||||
init_none_hash(hash_fn)
|
||||
|
||||
@ -484,9 +474,9 @@ def test_hash_tokens_different_mm_input(hash_fn):
|
||||
assert block_hashes1[1] != block_hashes2[1]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
|
||||
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
|
||||
def test_hash_request_tokens_no_mm_inputs(hash_fn):
|
||||
init_none_hash(hash_fn)
|
||||
kv_cache_utils.init_none_hash(hash_fn)
|
||||
|
||||
request = make_request(
|
||||
request_id="0",
|
||||
@ -500,10 +490,9 @@ def test_hash_request_tokens_no_mm_inputs(hash_fn):
|
||||
block_hashes = request.block_hashes
|
||||
|
||||
assert len(block_hashes) == 2
|
||||
assert block_hashes[0].token_ids == (0, 1, 2)
|
||||
assert block_hashes[0].extra_keys is None
|
||||
assert block_hashes[1].token_ids == (3, 4, 5)
|
||||
assert block_hashes[1].extra_keys is None
|
||||
assert block_hashes[0] == hash_fn(
|
||||
(kv_cache_utils.NONE_HASH, (0, 1, 2), None))
|
||||
assert block_hashes[1] == hash_fn((block_hashes[0], (3, 4, 5), None))
|
||||
|
||||
|
||||
def test_metrics():
|
||||
|
@ -8,17 +8,19 @@ from typing import Callable, Optional
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import vllm.v1.core.kv_cache_utils as kv_cache_utils
|
||||
from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved
|
||||
from vllm.multimodal.inputs import (MultiModalFeatureSpec,
|
||||
MultiModalKwargsItem, PlaceholderRange)
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import sha256, sha256_cbor_64bit
|
||||
from vllm.utils import sha256, sha256_cbor
|
||||
from vllm.v1.core.block_pool import BlockPool
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
|
||||
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
|
||||
KVCacheBlock,
|
||||
from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
|
||||
get_block_hash, get_group_id,
|
||||
get_request_block_hasher,
|
||||
hash_block_tokens, init_none_hash)
|
||||
hash_block_tokens, init_none_hash,
|
||||
make_block_hash_with_group_id)
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheGroupSpec, SlidingWindowSpec)
|
||||
|
||||
@ -101,8 +103,10 @@ def make_kv_cache_config_hybrid_model(block_size: int,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("hash_algo", ["sha256", "sha256_cbor_64bit", "hash"])
|
||||
def test_prefill(hash_algo):
|
||||
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
|
||||
def test_prefill(hash_fn):
|
||||
init_none_hash(hash_fn)
|
||||
|
||||
block_size = 16
|
||||
manager = KVCacheManager(
|
||||
make_kv_cache_config(block_size, 11),
|
||||
@ -110,10 +114,6 @@ def test_prefill(hash_algo):
|
||||
enable_caching=True,
|
||||
)
|
||||
|
||||
# choose the hash function according to the parameter
|
||||
hash_fn = (sha256_cbor_64bit if hash_algo == "sha256_cbor_64bit" else
|
||||
sha256 if hash_algo == "sha256" else hash)
|
||||
|
||||
# Complete 3 blocks (48 tokens)
|
||||
common_token_ids = [i for i in range(3) for _ in range(16)]
|
||||
|
||||
@ -137,10 +137,12 @@ def test_prefill(hash_algo):
|
||||
block_tokens = tuple(all_token_ids[(block_id - 1) * 16:block_id * 16])
|
||||
block_hash = hash_block_tokens(hash_fn, parent_block_hash,
|
||||
block_tokens)
|
||||
assert manager.block_pool.blocks[
|
||||
block_id].block_hash.block_hash == block_hash
|
||||
blk_hash = manager.block_pool.blocks[block_id].block_hash
|
||||
assert blk_hash is not None
|
||||
assert get_block_hash(blk_hash) == block_hash
|
||||
assert get_group_id(blk_hash) == 0
|
||||
assert manager.block_pool.blocks[block_id].ref_cnt == 1
|
||||
parent_block_hash = block_hash.hash_value
|
||||
parent_block_hash = block_hash
|
||||
|
||||
# Check partial block metadata
|
||||
for block_id in (4, ):
|
||||
@ -233,7 +235,7 @@ def test_prefill_hybrid_model():
|
||||
enable_caching=True,
|
||||
)
|
||||
|
||||
hash_fn = hash
|
||||
hash_fn = sha256
|
||||
|
||||
# Complete 3 blocks (48 tokens)
|
||||
common_token_ids = [i for i in range(3) for _ in range(block_size)]
|
||||
@ -260,11 +262,13 @@ def test_prefill_hybrid_model():
|
||||
block_tokens = tuple(all_token_ids[(length - 1) * 16:length * 16])
|
||||
block_hash = hash_block_tokens(hash_fn, parent_block_hash,
|
||||
block_tokens)
|
||||
for block_id in block_ids:
|
||||
assert manager.block_pool.blocks[
|
||||
block_id].block_hash.block_hash == block_hash
|
||||
for group_id, block_id in enumerate(block_ids):
|
||||
blk_hash = manager.block_pool.blocks[block_id].block_hash
|
||||
assert blk_hash is not None
|
||||
assert get_block_hash(blk_hash) == block_hash
|
||||
assert get_group_id(blk_hash) == group_id
|
||||
assert manager.block_pool.blocks[block_id].ref_cnt == 1
|
||||
parent_block_hash = block_hash.hash_value
|
||||
parent_block_hash = block_hash
|
||||
|
||||
# Check partial block metadata
|
||||
for block_id in (4, 8, 12):
|
||||
@ -298,11 +302,10 @@ def test_prefill_hybrid_model():
|
||||
cached_block_hash_to_block_bak = copy.copy(
|
||||
manager.block_pool.cached_block_hash_to_block)
|
||||
|
||||
def test_partial_request_hit(request_id: str,
|
||||
hash_to_evict: list[BlockHashWithGroupId],
|
||||
def test_partial_request_hit(request_id: str, hash_to_evict: list[bytes],
|
||||
expect_hit_length: int):
|
||||
req = make_request(request_id, common_token_ids + unique_token_ids,
|
||||
block_size, hash)
|
||||
block_size, sha256)
|
||||
for hash_with_group_id in hash_to_evict:
|
||||
manager.block_pool.cached_block_hash_to_block.pop(
|
||||
hash_with_group_id)
|
||||
@ -319,33 +322,32 @@ def test_prefill_hybrid_model():
|
||||
|
||||
# Evict the blocks outside sliding window, does not affect the hit length.
|
||||
test_partial_request_hit("2", [
|
||||
BlockHashWithGroupId(block_hashes[0], 1),
|
||||
BlockHashWithGroupId(block_hashes[0], 2)
|
||||
make_block_hash_with_group_id(block_hashes[0], 1),
|
||||
make_block_hash_with_group_id(block_hashes[0], 2)
|
||||
], 3)
|
||||
|
||||
# Evict the first block of full attention, makes total cache miss.
|
||||
test_partial_request_hit("3", [
|
||||
BlockHashWithGroupId(block_hashes[0], 0),
|
||||
], 0)
|
||||
test_partial_request_hit(
|
||||
"3", [make_block_hash_with_group_id(block_hashes[0], 0)], 0)
|
||||
|
||||
# Evict the last block of all layers, reduces the hit length to 2.
|
||||
test_partial_request_hit("4", [
|
||||
BlockHashWithGroupId(block_hashes[2], 0),
|
||||
BlockHashWithGroupId(block_hashes[2], 1),
|
||||
BlockHashWithGroupId(block_hashes[2], 2),
|
||||
make_block_hash_with_group_id(block_hashes[2], 0),
|
||||
make_block_hash_with_group_id(block_hashes[2], 1),
|
||||
make_block_hash_with_group_id(block_hashes[2], 2),
|
||||
], 2)
|
||||
|
||||
# Evict the last block of full attention, reduces the hit length to 2.
|
||||
test_partial_request_hit("5", [BlockHashWithGroupId(block_hashes[2], 0)],
|
||||
2)
|
||||
test_partial_request_hit(
|
||||
"5", [make_block_hash_with_group_id(block_hashes[2], 0)], 2)
|
||||
|
||||
# Evict the last block of sliding window, reduces the hit length to 2.
|
||||
test_partial_request_hit("6", [BlockHashWithGroupId(block_hashes[2], 1)],
|
||||
2)
|
||||
test_partial_request_hit(
|
||||
"6", [make_block_hash_with_group_id(block_hashes[2], 1)], 2)
|
||||
|
||||
# Evict the last block of sliding window, reduces the hit length to 2.
|
||||
test_partial_request_hit("7", [BlockHashWithGroupId(block_hashes[2], 2)],
|
||||
2)
|
||||
test_partial_request_hit(
|
||||
"7", [make_block_hash_with_group_id(block_hashes[2], 2)], 2)
|
||||
|
||||
# Evict different set of blocks for full attention and sliding window makes
|
||||
# total cache miss.
|
||||
@ -353,9 +355,9 @@ def test_prefill_hybrid_model():
|
||||
# The cache hit length of sliding window is 2 * block_size.
|
||||
# Then it is a cache miss as the two types of layers have different hit lengths.
|
||||
test_partial_request_hit("8", [
|
||||
BlockHashWithGroupId(block_hashes[2], 0),
|
||||
BlockHashWithGroupId(block_hashes[0], 1),
|
||||
BlockHashWithGroupId(block_hashes[0], 2),
|
||||
make_block_hash_with_group_id(block_hashes[2], 0),
|
||||
make_block_hash_with_group_id(block_hashes[0], 1),
|
||||
make_block_hash_with_group_id(block_hashes[0], 2),
|
||||
], 0)
|
||||
|
||||
|
||||
@ -372,8 +374,8 @@ def test_prefill_plp():
|
||||
max_model_len=8192,
|
||||
enable_caching=True,
|
||||
)
|
||||
# the default hash function is hash
|
||||
hash_fn = hash
|
||||
# the default hash function is sha256
|
||||
hash_fn = sha256
|
||||
|
||||
# Complete 3 blocks (48 tokens)
|
||||
common_token_ids = [i for i in range(3) for _ in range(16)]
|
||||
@ -404,10 +406,12 @@ def test_prefill_plp():
|
||||
block_tokens = tuple(all_token_ids[(block_id - 1) * 16:block_id * 16])
|
||||
block_hash = hash_block_tokens(hash_fn, parent_block_hash,
|
||||
block_tokens)
|
||||
assert manager.block_pool.blocks[
|
||||
block_id].block_hash.block_hash == block_hash
|
||||
blk_hash = (manager.block_pool.blocks[block_id].block_hash)
|
||||
assert blk_hash is not None
|
||||
assert get_block_hash(blk_hash) == block_hash
|
||||
assert get_group_id(blk_hash) == 0
|
||||
assert manager.block_pool.blocks[block_id].ref_cnt == 1
|
||||
parent_block_hash = block_hash.hash_value
|
||||
parent_block_hash = block_hash
|
||||
|
||||
# Check partial block metadata
|
||||
for block_id in (4, ):
|
||||
@ -493,7 +497,7 @@ def test_decode():
|
||||
# Incomplete 1 block (7 tokens)
|
||||
unique_token_ids = [3] * 7
|
||||
req0 = make_request("0", common_token_ids + unique_token_ids, block_size,
|
||||
hash)
|
||||
sha256)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
@ -538,7 +542,7 @@ def test_evict():
|
||||
)
|
||||
|
||||
last_token_id = 5 * 16 + 7
|
||||
req0 = make_request("0", list(range(last_token_id)), block_size, hash)
|
||||
req0 = make_request("0", list(range(last_token_id)), block_size, sha256)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
@ -550,7 +554,7 @@ def test_evict():
|
||||
# 3 blocks.
|
||||
req1 = make_request("1", list(range(last_token_id,
|
||||
last_token_id + 3 * 16)), block_size,
|
||||
hash)
|
||||
sha256)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
@ -572,7 +576,7 @@ def test_evict():
|
||||
] == [10, 6, 5, 4, 3, 2, 1, 9, 8, 7]
|
||||
|
||||
# Touch the first 2 blocks.
|
||||
req2 = make_request("2", list(range(2 * 16 + 3)), block_size, hash)
|
||||
req2 = make_request("2", list(range(2 * 16 + 3)), block_size, sha256)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
||||
assert computed_blocks.get_block_ids() == ([1, 2], )
|
||||
assert num_computed_tokens == 2 * 16
|
||||
@ -597,7 +601,7 @@ def test_hash_block_correct_reuse():
|
||||
|
||||
# Allocate 1 block and cache it.
|
||||
num_tokens = block_size * 1
|
||||
req = make_request("0", list(range(num_tokens)), block_size, hash)
|
||||
req = make_request("0", list(range(num_tokens)), block_size, sha256)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
@ -611,7 +615,7 @@ def test_hash_block_correct_reuse():
|
||||
|
||||
# Allocate a new block that's not full, make sure hash info on the
|
||||
# block is cleared.
|
||||
req = make_request("1", list(range(num_tokens - 1)), block_size, hash)
|
||||
req = make_request("1", list(range(num_tokens - 1)), block_size, sha256)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
@ -638,7 +642,7 @@ def test_computed_blocks_not_evicted():
|
||||
|
||||
# Allocate a block and cache it.
|
||||
num_tokens = block_size * 1
|
||||
req0 = make_request("0", list(range(num_tokens)), block_size, hash)
|
||||
req0 = make_request("0", list(range(num_tokens)), block_size, sha256)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
@ -650,7 +654,7 @@ def test_computed_blocks_not_evicted():
|
||||
|
||||
# Allocate another block.
|
||||
req1 = make_request("1", list(range(num_tokens, num_tokens * 2)),
|
||||
block_size, hash)
|
||||
block_size, sha256)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
@ -666,7 +670,7 @@ def test_computed_blocks_not_evicted():
|
||||
|
||||
# Now if we have a cache hit on the first block, we should evict the second
|
||||
# cached block rather than the first one.
|
||||
req2 = make_request("2", list(range(num_tokens * 2)), block_size, hash)
|
||||
req2 = make_request("2", list(range(num_tokens * 2)), block_size, sha256)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
||||
assert len(computed_blocks.blocks[0]) == 1
|
||||
assert computed_blocks.blocks[0][0].block_id == 1
|
||||
@ -691,7 +695,7 @@ def test_basic_prefix_caching_disabled():
|
||||
)
|
||||
|
||||
req1 = make_request("1", list(range(10)), block_size,
|
||||
hash) # 2 blocks and some more
|
||||
sha256) # 2 blocks and some more
|
||||
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||
assert not computed_blocks.blocks[0]
|
||||
@ -706,7 +710,7 @@ def test_basic_prefix_caching_disabled():
|
||||
|
||||
# No caching.
|
||||
req2 = make_request("2", list(range(16)), block_size,
|
||||
hash) # shared prefix
|
||||
sha256) # shared prefix
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
@ -716,7 +720,7 @@ def test_basic_prefix_caching_disabled():
|
||||
assert len(blocks.blocks[0]) == 4
|
||||
|
||||
# New requests should not have any blocks.
|
||||
req3 = make_request("3", list(range(4)), block_size, hash)
|
||||
req3 = make_request("3", list(range(4)), block_size, sha256)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
@ -726,7 +730,7 @@ def test_basic_prefix_caching_disabled():
|
||||
assert not blocks
|
||||
|
||||
|
||||
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
|
||||
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
|
||||
def test_cache_blocks(hash_fn):
|
||||
"""
|
||||
This is a unit test that tests the correctness of the _cache_full_blocks
|
||||
@ -787,7 +791,7 @@ def test_cache_blocks_multi_group():
|
||||
# Block 1/5: [4, 5, 6, 7]
|
||||
# Block 2/6: [8, 9, 10, 11]
|
||||
# Block 3/7: [12, 13]
|
||||
req = make_request("0", list(range(14)), block_size, hash)
|
||||
req = make_request("0", list(range(14)), block_size, sha256)
|
||||
|
||||
# Cache the blocks for group 0.
|
||||
blocks = [KVCacheBlock(block_id=i) for i in range(2)]
|
||||
@ -845,6 +849,8 @@ def test_mm_prefix_caching():
|
||||
"""
|
||||
This tests that the multi-modal prefix caching is correct.
|
||||
"""
|
||||
kv_cache_utils.init_none_hash(sha256)
|
||||
|
||||
block_size = 16
|
||||
manager = KVCacheManager(
|
||||
make_kv_cache_config(block_size, 11),
|
||||
@ -874,23 +880,30 @@ def test_mm_prefix_caching():
|
||||
req0 = make_request("0",
|
||||
all_token_ids,
|
||||
block_size,
|
||||
hash,
|
||||
sha256,
|
||||
mm_positions=mm_positions,
|
||||
mm_hashes=mm_hashes)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
||||
|
||||
# Completed block should have hashes with extra keys.
|
||||
# Completed block should have hashes
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
block_hashes = req0.block_hashes
|
||||
assert len(block_hashes) == 3
|
||||
assert block_hashes[0].extra_keys == ("aaa", )
|
||||
assert block_hashes[1].extra_keys == ("aaa", "bbb")
|
||||
assert block_hashes[2].extra_keys == ("bbb", )
|
||||
assert block_hashes[0] == sha256(
|
||||
(kv_cache_utils.NONE_HASH, tuple(all_token_ids[:block_size]),
|
||||
("aaa", )))
|
||||
assert block_hashes[1] == sha256(
|
||||
(block_hashes[0], tuple(all_token_ids[block_size:block_size * 2]),
|
||||
("aaa", "bbb")))
|
||||
assert block_hashes[2] == sha256(
|
||||
(block_hashes[1], tuple(all_token_ids[block_size * 2:block_size * 3]),
|
||||
("bbb", )))
|
||||
|
||||
blocks = manager.allocate_slots(req0, 59,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks is not None
|
||||
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
|
||||
req0.num_computed_tokens = 59
|
||||
|
||||
@ -901,10 +914,10 @@ def test_mm_prefix_caching():
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert new_blocks is not None and len(new_blocks.blocks[0]) == 0
|
||||
|
||||
# The just completed block should have hashes with extra keys.
|
||||
assert len(block_hashes) == 4
|
||||
assert block_hashes[3].extra_keys == ("ccc", )
|
||||
assert block_hashes[3] == sha256(
|
||||
(block_hashes[2], tuple(all_token_ids[3 * block_size:] + [8] * 5),
|
||||
("ccc", )))
|
||||
|
||||
# Cache hit.
|
||||
unique_token_ids = [-1] * 7 + [200] * 5
|
||||
@ -916,7 +929,7 @@ def test_mm_prefix_caching():
|
||||
req1 = make_request("1",
|
||||
all_token_ids,
|
||||
block_size,
|
||||
hash,
|
||||
sha256,
|
||||
mm_positions=mm_positions,
|
||||
mm_hashes=mm_hashes)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||
@ -929,6 +942,8 @@ def test_cache_key_salting():
|
||||
This tests that cache salts are applied during hashing and the cache
|
||||
is separated as expected.
|
||||
"""
|
||||
kv_cache_utils.init_none_hash(sha256)
|
||||
|
||||
block_size = 16
|
||||
manager = KVCacheManager(
|
||||
make_kv_cache_config(block_size, 11),
|
||||
@ -939,21 +954,26 @@ def test_cache_key_salting():
|
||||
# 3 complete blocks and an incomplete block with 11 tokens.
|
||||
common_token_ids = [i for i in range(3) for _ in range(block_size)]
|
||||
token_ids = common_token_ids + [3] * 11
|
||||
req0 = make_request("0", token_ids, block_size, hash, cache_salt="salt1")
|
||||
req0 = make_request("0", token_ids, block_size, sha256, cache_salt="salt1")
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
||||
|
||||
# Completed block should have hashes with extra keys.
|
||||
# Completed block should have hashes
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
block_hashes = req0.block_hashes
|
||||
assert len(block_hashes) == 3
|
||||
assert block_hashes[0].extra_keys == ("salt1", )
|
||||
assert block_hashes[1].extra_keys is None
|
||||
assert block_hashes[2].extra_keys is None
|
||||
assert block_hashes[0] == sha256(
|
||||
(kv_cache_utils.NONE_HASH, tuple(token_ids[:block_size]), ("salt1", )))
|
||||
assert block_hashes[1] == sha256(
|
||||
(block_hashes[0], tuple(token_ids[block_size:block_size * 2]), None))
|
||||
assert block_hashes[2] == sha256(
|
||||
(block_hashes[1], tuple(token_ids[block_size * 2:block_size * 3]),
|
||||
None))
|
||||
|
||||
blocks = manager.allocate_slots(req0, 59,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks is not None
|
||||
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
|
||||
req0.num_computed_tokens = 59
|
||||
|
||||
@ -964,14 +984,13 @@ def test_cache_key_salting():
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert new_blocks is not None and len(new_blocks.blocks[0]) == 0
|
||||
|
||||
# Now one more block that should not have extra keys.
|
||||
assert len(block_hashes) == 4
|
||||
assert block_hashes[3].extra_keys is None
|
||||
assert block_hashes[3] == sha256(
|
||||
(block_hashes[2], tuple(token_ids[3 * block_size:] + [8] * 5), None))
|
||||
|
||||
# Test cache hit with a new request that has the same salt.
|
||||
token_ids = common_token_ids + [4] * 11
|
||||
req1 = make_request("1", token_ids, block_size, hash, cache_salt="salt1")
|
||||
req1 = make_request("1", token_ids, block_size, sha256, cache_salt="salt1")
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||
# Should match only a prefix of 3 blocks.
|
||||
assert len(computed_blocks.blocks[0]) == 3
|
||||
@ -979,13 +998,19 @@ def test_cache_key_salting():
|
||||
|
||||
# Test cache miss with same content but different salt.
|
||||
token_ids = common_token_ids + [4] * 11
|
||||
req2 = make_request("2", token_ids, block_size, hash, cache_salt="salt2")
|
||||
req2 = make_request("2", token_ids, block_size, sha256, cache_salt="salt2")
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
||||
assert len(computed_blocks.blocks[0]) == 0
|
||||
assert num_computed_tokens == 0
|
||||
block_hashes = req2.block_hashes
|
||||
assert len(block_hashes) == 3
|
||||
assert block_hashes[0].extra_keys == ("salt2", )
|
||||
assert block_hashes[0] == sha256(
|
||||
(kv_cache_utils.NONE_HASH, tuple(token_ids[:block_size]), ("salt2", )))
|
||||
assert block_hashes[1] == sha256(
|
||||
(block_hashes[0], tuple(token_ids[block_size:block_size * 2]), None))
|
||||
assert block_hashes[2] == sha256(
|
||||
(block_hashes[1], tuple(token_ids[block_size * 2:block_size * 3]),
|
||||
None))
|
||||
|
||||
|
||||
def test_prefill_not_enough_free_blocks_with_computed_blocks():
|
||||
@ -1004,7 +1029,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
|
||||
# Complete 3 blocks (48 tokens)
|
||||
# | Common-0 | Common-1 | Common-2 | ... |
|
||||
common_token_ids = [i for i in range(3) for _ in range(16)]
|
||||
req0 = make_request("0", common_token_ids, block_size, hash)
|
||||
req0 = make_request("0", common_token_ids, block_size, sha256)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
@ -1015,7 +1040,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
|
||||
req0.request_id]
|
||||
|
||||
# | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... |
|
||||
req1 = make_request("1", common_token_ids * 2, block_size, hash)
|
||||
req1 = make_request("1", common_token_ids * 2, block_size, sha256)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||
assert computed_blocks.blocks[0] == block_part0
|
||||
assert num_computed_tokens == 3 * 16
|
||||
@ -1032,7 +1057,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
|
||||
|
||||
# | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) |
|
||||
# | Req1-5(F)| Req2-0 | Req2-1 | ... |
|
||||
req2 = make_request("2", [7] * block_size * 2, block_size, hash)
|
||||
req2 = make_request("2", [7] * block_size * 2, block_size, sha256)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
@ -1044,7 +1069,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
|
||||
# but it cannot be allocated due to insufficient free blocks (2).
|
||||
# In this case, the ref_cnt of the computed blocks should not be changed.
|
||||
assert manager.block_pool.free_block_queue.num_free_blocks == 5
|
||||
req3 = make_request("3", common_token_ids * 3, block_size, hash)
|
||||
req3 = make_request("3", common_token_ids * 3, block_size, sha256)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
|
||||
assert computed_blocks.blocks[0] == block_part1
|
||||
assert num_computed_tokens == 6 * 16
|
||||
@ -1069,13 +1094,13 @@ def test_reset_prefix_cache():
|
||||
full_block_token_ids = [i for i in range(3) for _ in range(16)]
|
||||
unique_token_ids = [3] * 7
|
||||
all_token_ids = full_block_token_ids + unique_token_ids
|
||||
req0 = make_request("0", all_token_ids, block_size, hash)
|
||||
req0 = make_request("0", all_token_ids, block_size, sha256)
|
||||
blocks = manager.allocate_slots(req0, 55)
|
||||
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
|
||||
|
||||
unique_token_ids = [4] * 7
|
||||
all_token_ids = full_block_token_ids + unique_token_ids
|
||||
req1 = make_request("1", all_token_ids, block_size, hash)
|
||||
req1 = make_request("1", all_token_ids, block_size, sha256)
|
||||
computed_blocks, _ = manager.get_computed_blocks(req1)
|
||||
assert len(req1.block_hashes) == 3
|
||||
assert len(computed_blocks.blocks[0]) == 3
|
||||
@ -1109,7 +1134,7 @@ def test_prefix_cache_stats_disabled():
|
||||
assert manager.prefix_cache_stats is None
|
||||
|
||||
# Call all functions that check whether log_stats is disabled.
|
||||
req = make_request("0", list(range(16)), block_size, hash)
|
||||
req = make_request("0", list(range(16)), block_size, sha256)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
@ -1124,15 +1149,9 @@ def test_prefix_cache_stats_disabled():
|
||||
|
||||
def test_maybe_evict_cached_block():
|
||||
pool = BlockPool(num_gpu_blocks=4, enable_caching=True)
|
||||
block_hash0 = BlockHashWithGroupId(block_hash=BlockHash(hash_value=10,
|
||||
token_ids=(100, )),
|
||||
group_id=1000)
|
||||
block_hash1 = BlockHashWithGroupId(block_hash=BlockHash(hash_value=20,
|
||||
token_ids=(200, )),
|
||||
group_id=2000)
|
||||
block_hash2 = BlockHashWithGroupId(block_hash=BlockHash(hash_value=30,
|
||||
token_ids=(300, )),
|
||||
group_id=3000)
|
||||
block_hash0 = make_block_hash_with_group_id(BlockHash(b"10"), 1000)
|
||||
block_hash1 = make_block_hash_with_group_id(BlockHash(b"20"), 2000)
|
||||
block_hash2 = make_block_hash_with_group_id(BlockHash(b"30"), 3000)
|
||||
block_hashes = [
|
||||
block_hash0,
|
||||
block_hash1,
|
||||
@ -1206,7 +1225,7 @@ def test_kv_cache_events(blocks_to_cache: int):
|
||||
)
|
||||
|
||||
num_tokens = block_size * blocks_to_cache
|
||||
req0 = make_request("0", list(range(num_tokens)), block_size, hash)
|
||||
req0 = make_request("0", list(range(num_tokens)), block_size, sha256)
|
||||
_ = manager.allocate_slots(req0, num_tokens)
|
||||
events = manager.take_events()
|
||||
|
||||
@ -1222,7 +1241,7 @@ def test_kv_cache_events(blocks_to_cache: int):
|
||||
# Should see block_to_cache number of removed block events and a new block
|
||||
# stored event
|
||||
manager.free(req0)
|
||||
req1 = make_request("1", list(range(num_tokens)), block_size, hash)
|
||||
req1 = make_request("1", list(range(num_tokens)), block_size, sha256)
|
||||
_ = manager.allocate_slots(req1, num_tokens)
|
||||
events = manager.take_events()
|
||||
|
||||
@ -1256,7 +1275,7 @@ def test_eagle_enabled_removes_last_block():
|
||||
|
||||
# Request with 3 full blocks (48 tokens)
|
||||
token_ids = [0] * (3 * block_size)
|
||||
req = make_request("divisible_request", token_ids, block_size, hash)
|
||||
req = make_request("divisible_request", token_ids, block_size, sha256)
|
||||
|
||||
# Prime the cache
|
||||
computed_blocks, _ = manager.get_computed_blocks(req)
|
||||
@ -1266,7 +1285,7 @@ def test_eagle_enabled_removes_last_block():
|
||||
manager.free(req)
|
||||
|
||||
# New request with same tokens + Eagle enabled
|
||||
req_eagle = make_request("eagle_divisible", token_ids, block_size, hash)
|
||||
req_eagle = make_request("eagle_divisible", token_ids, block_size, sha256)
|
||||
computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)
|
||||
|
||||
# Should retain 1 block:
|
||||
@ -1287,7 +1306,7 @@ def test_eagle_with_partial_blocks():
|
||||
)
|
||||
# 2 full blocks + 5 tokens (non-divisible length)
|
||||
token_ids = [0] * (2 * block_size + 5)
|
||||
req = make_request("partial_block_test", token_ids, block_size, hash)
|
||||
req = make_request("partial_block_test", token_ids, block_size, sha256)
|
||||
|
||||
# Prime the cache
|
||||
computed_blocks, _ = manager.get_computed_blocks(req)
|
||||
@ -1297,7 +1316,7 @@ def test_eagle_with_partial_blocks():
|
||||
manager.free(req)
|
||||
|
||||
# New request with Eagle enabled
|
||||
req_eagle = make_request("partial_eagle", token_ids, block_size, hash)
|
||||
req_eagle = make_request("partial_eagle", token_ids, block_size, sha256)
|
||||
computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)
|
||||
# Original match: 2 full blocks → Eagle removes 1 → 1 remaining
|
||||
assert len(computed_blocks.blocks[0]) == 1
|
||||
@ -1328,7 +1347,7 @@ def test_eagle_with_sliding_window():
|
||||
|
||||
# 2 full blocks + 5 tokens (non-divisible length)
|
||||
token_ids = [0] * (2 * block_size + 5)
|
||||
req = make_request("partial_block_test", token_ids, block_size, hash)
|
||||
req = make_request("partial_block_test", token_ids, block_size, sha256)
|
||||
|
||||
# Prime the cache
|
||||
computed_blocks, _ = manager.get_computed_blocks(req)
|
||||
@ -1341,7 +1360,7 @@ def test_eagle_with_sliding_window():
|
||||
manager.free(req)
|
||||
|
||||
# New request with Eagle enabled
|
||||
req_eagle = make_request("partial_eagle", token_ids, block_size, hash)
|
||||
req_eagle = make_request("partial_eagle", token_ids, block_size, sha256)
|
||||
computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)
|
||||
# Original match: 2 full blocks → Eagle removes 1 → 1 remaining
|
||||
assert len(computed_blocks.blocks[0]) == 1
|
||||
@ -1351,11 +1370,11 @@ def test_eagle_with_sliding_window():
|
||||
assert manager.block_pool.get_cached_block(
|
||||
block_hash_first_block, kv_cache_group_ids=[0]) is not None
|
||||
manager.block_pool.cached_block_hash_to_block.pop(
|
||||
BlockHashWithGroupId(block_hash_first_block, 0))
|
||||
make_block_hash_with_group_id(block_hash_first_block, 0))
|
||||
|
||||
# New request
|
||||
req_after_evict = make_request("partial_eagle_after_evict", token_ids,
|
||||
block_size, hash)
|
||||
block_size, sha256)
|
||||
computed_blocks, num_tokens = manager.get_computed_blocks(req_after_evict)
|
||||
# Cache miss. The only hit prefix is [NULL_BLOCK, BLOCK_2] if eagle is
|
||||
# not considered. But after dropping the last matched block due to eagle,
|
||||
|
@ -6,8 +6,8 @@ import random
|
||||
import torch
|
||||
|
||||
from vllm.v1.core.block_pool import BlockPool
|
||||
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
|
||||
KVCacheBlock)
|
||||
from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
|
||||
make_block_hash_with_group_id)
|
||||
from vllm.v1.core.single_type_kv_cache_manager import (
|
||||
ChunkedLocalAttentionManager, SlidingWindowManager)
|
||||
from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
|
||||
@ -44,7 +44,7 @@ def test_chunked_local_attention_possible_cached_prefix():
|
||||
|
||||
def run_one_case(block_is_cached, tail_token, expect_length):
|
||||
block_hash_list = [
|
||||
BlockHash(i, ()) for i in range(len(block_is_cached))
|
||||
BlockHash(str(i).encode()) for i in range(len(block_is_cached))
|
||||
]
|
||||
|
||||
block_pool.cached_block_hash_to_block.clear()
|
||||
@ -53,8 +53,8 @@ def test_chunked_local_attention_possible_cached_prefix():
|
||||
for i, (block_hash,
|
||||
is_cached) in enumerate(zip(block_hash_list, block_is_cached)):
|
||||
if is_cached:
|
||||
block_pool.cached_block_hash_to_block[BlockHashWithGroupId(
|
||||
block_hash, 0)] = {
|
||||
block_pool.cached_block_hash_to_block[
|
||||
make_block_hash_with_group_id(block_hash, 0)] = {
|
||||
i: block_pool.blocks[i + 10],
|
||||
}
|
||||
|
||||
@ -109,7 +109,7 @@ def test_sliding_window_possible_cached_prefix():
|
||||
|
||||
def run_one_case(block_is_cached, expect_length):
|
||||
block_hash_list = [
|
||||
BlockHash(i, ()) for i in range(len(block_is_cached))
|
||||
BlockHash(str(i).encode()) for i in range(len(block_is_cached))
|
||||
]
|
||||
|
||||
block_pool.cached_block_hash_to_block.clear()
|
||||
@ -118,8 +118,8 @@ def test_sliding_window_possible_cached_prefix():
|
||||
for i, (block_hash,
|
||||
is_cached) in enumerate(zip(block_hash_list, block_is_cached)):
|
||||
if is_cached:
|
||||
block_pool.cached_block_hash_to_block[BlockHashWithGroupId(
|
||||
block_hash, 0)] = {
|
||||
block_pool.cached_block_hash_to_block[
|
||||
make_block_hash_with_group_id(block_hash, 0)] = {
|
||||
i: block_pool.blocks[i + 10],
|
||||
}
|
||||
|
||||
|
@ -9,6 +9,7 @@ from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
|
||||
from vllm.multimodal.inputs import (MultiModalFeatureSpec,
|
||||
MultiModalKwargsItem, PlaceholderRange)
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import sha256
|
||||
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
|
||||
init_none_hash)
|
||||
from vllm.v1.core.sched.async_scheduler import AsyncScheduler
|
||||
@ -130,10 +131,10 @@ def create_requests(
|
||||
) -> list[Request]:
|
||||
global _none_hash_initialized
|
||||
if not _none_hash_initialized:
|
||||
init_none_hash(hash)
|
||||
init_none_hash(sha256)
|
||||
_none_hash_initialized = True
|
||||
|
||||
block_hasher = get_request_block_hasher(block_size, hash)
|
||||
block_hasher = get_request_block_hasher(block_size, sha256)
|
||||
sampling_params = SamplingParams(ignore_eos=False,
|
||||
max_tokens=max_tokens,
|
||||
stop_token_ids=stop_token_ids,
|
||||
|
@ -36,18 +36,19 @@ def test_prefix_caching_from_cli():
|
||||
assert vllm_config.cache_config.enable_prefix_caching
|
||||
|
||||
# default hash algorithm is "sha256"
|
||||
assert vllm_config.cache_config.prefix_caching_hash_algo == "builtin"
|
||||
assert vllm_config.cache_config.prefix_caching_hash_algo == "sha256"
|
||||
|
||||
# set hash algorithm to sha256_cbor
|
||||
args = parser.parse_args(["--prefix-caching-hash-algo", "sha256_cbor"])
|
||||
vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
|
||||
assert vllm_config.cache_config.prefix_caching_hash_algo == \
|
||||
"sha256_cbor"
|
||||
|
||||
# set hash algorithm to sha256
|
||||
args = parser.parse_args(["--prefix-caching-hash-algo", "sha256"])
|
||||
vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
|
||||
assert vllm_config.cache_config.prefix_caching_hash_algo == "sha256"
|
||||
|
||||
# set hash algorithm to builtin
|
||||
args = parser.parse_args(["--prefix-caching-hash-algo", "builtin"])
|
||||
vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
|
||||
assert vllm_config.cache_config.prefix_caching_hash_algo == "builtin"
|
||||
|
||||
# an invalid hash algorithm raises an error
|
||||
parser.exit_on_error = False
|
||||
with pytest.raises(ArgumentError):
|
||||
|
@ -13,6 +13,7 @@ from vllm.distributed.kv_transfer.kv_connector.factory import (
|
||||
KVConnectorFactory)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa
|
||||
SharedStorageConnector)
|
||||
from vllm.utils import sha256
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
|
||||
init_none_hash)
|
||||
@ -127,11 +128,11 @@ def create_request(request_id: int,
|
||||
use_all_1s_for_prompt_tokens: bool = False,
|
||||
num_remote_blocks: int = 3,
|
||||
block_size: int = 16,
|
||||
hash_fn: Callable = hash) -> Request:
|
||||
hash_fn: Callable = sha256) -> Request:
|
||||
"""Make dummy request for testing."""
|
||||
global _none_hash_initialized
|
||||
if not _none_hash_initialized:
|
||||
init_none_hash(hash)
|
||||
init_none_hash(hash_fn)
|
||||
_none_hash_initialized = True
|
||||
|
||||
kv_transfer_params: Optional[dict[str, Any]] = None
|
||||
|
@ -24,7 +24,7 @@ logger = init_logger(__name__)
|
||||
BlockSize = Literal[1, 8, 16, 32, 64, 128]
|
||||
CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"]
|
||||
MambaDType = Literal["auto", "float32"]
|
||||
PrefixCachingHashAlgo = Literal["builtin", "sha256", "sha256_cbor_64bit"]
|
||||
PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor"]
|
||||
|
||||
|
||||
@config
|
||||
@ -63,17 +63,12 @@ class CacheConfig:
|
||||
"""Sliding window size for the KV cache. This is primarily set in
|
||||
`ModelConfig` and that value should be manually duplicated here."""
|
||||
enable_prefix_caching: Optional[bool] = None
|
||||
"""Whether to enable prefix caching. Disabled by default for V0. Enabled by
|
||||
default for V1."""
|
||||
prefix_caching_hash_algo: PrefixCachingHashAlgo = "builtin"
|
||||
"""Whether to enable prefix caching. Enabled by default for V1."""
|
||||
prefix_caching_hash_algo: PrefixCachingHashAlgo = "sha256"
|
||||
"""Set the hash algorithm for prefix caching:\n
|
||||
- "builtin" is Python's built-in hash.\n
|
||||
- "sha256" is collision resistant but with certain overheads.
|
||||
This option uses Pickle for object serialization before hashing.\n
|
||||
- "sha256_cbor_64bit" provides a reproducible, cross-language compatible
|
||||
hash. It serializes objects using canonical CBOR and hashes them with
|
||||
SHA-256. The resulting hash consists of the lower 64 bits of the SHA-256
|
||||
digest."""
|
||||
- "sha256" uses Pickle for object serialization before hashing.\n
|
||||
- "sha256_cbor" provides a reproducible, cross-language compatible hash. It
|
||||
serializes objects using canonical CBOR and hashes them with SHA-256."""
|
||||
cpu_offload_gb: float = 0
|
||||
"""The space in GiB to offload to CPU, per GPU. Default is 0, which means
|
||||
no offloading. Intuitively, this argument can be seen as a virtual way to
|
||||
|
@ -16,6 +16,7 @@ import zmq
|
||||
|
||||
from vllm.config.kv_events import KVEventsConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.core.kv_cache_utils import ExternalBlockHash
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -44,8 +45,8 @@ MEDIUM_GPU = "GPU"
|
||||
|
||||
|
||||
class BlockStored(KVCacheEvent):
|
||||
block_hashes: list[int]
|
||||
parent_block_hash: Optional[int]
|
||||
block_hashes: list[ExternalBlockHash]
|
||||
parent_block_hash: Optional[ExternalBlockHash]
|
||||
token_ids: list[int]
|
||||
block_size: int
|
||||
lora_id: Optional[int]
|
||||
@ -53,7 +54,7 @@ class BlockStored(KVCacheEvent):
|
||||
|
||||
|
||||
class BlockRemoved(KVCacheEvent):
|
||||
block_hashes: list[int]
|
||||
block_hashes: list[ExternalBlockHash]
|
||||
medium: Optional[str]
|
||||
|
||||
|
||||
|
@ -1592,20 +1592,12 @@ class EngineArgs:
|
||||
"in low performance due to small KV cache size. Consider "
|
||||
"setting --max-model-len to a smaller value.", max_model_len)
|
||||
|
||||
# if using prefix caching, we must set a hash algo
|
||||
if self.enable_prefix_caching:
|
||||
# Disable prefix caching for multimodal models for VLLM_V0.
|
||||
if model_config.is_multimodal_model:
|
||||
logger.warning(
|
||||
"--enable-prefix-caching is not supported for multimodal "
|
||||
"models in V0 and has been disabled.")
|
||||
self.enable_prefix_caching = False
|
||||
|
||||
# VLLM_V0 only supports builtin hash algo for prefix caching.
|
||||
if self.prefix_caching_hash_algo == "sha256":
|
||||
raise ValueError(
|
||||
"sha256 is not supported for prefix caching in V0 engine. "
|
||||
"Please use 'builtin'.")
|
||||
# Disable prefix caching for multimodal models for VLLM_V0.
|
||||
if self.enable_prefix_caching and model_config.is_multimodal_model:
|
||||
logger.warning(
|
||||
"--enable-prefix-caching is not supported for multimodal "
|
||||
"models in V0 and has been disabled.")
|
||||
self.enable_prefix_caching = False
|
||||
|
||||
# Set max_num_seqs to 256 for VLLM_V0.
|
||||
if self.max_num_seqs is None:
|
||||
|
@ -171,6 +171,7 @@ if TYPE_CHECKING:
|
||||
VLLM_GPT_OSS_USE_CONTAINER_TOOL: bool = False
|
||||
VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False
|
||||
VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False
|
||||
VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES: bool = True
|
||||
|
||||
|
||||
def get_default_cache_root():
|
||||
@ -1215,6 +1216,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
# Add optional custom scopes for profiling, disable to avoid overheads
|
||||
"VLLM_CUSTOM_SCOPES_FOR_PROFILING":
|
||||
lambda: bool(int(os.getenv("VLLM_CUSTOM_SCOPES_FOR_PROFILING", "0"))),
|
||||
|
||||
# Represent block hashes in KV cache events as 64-bit integers instead of
|
||||
# raw bytes. Defaults to True for backward compatibility.
|
||||
"VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES":
|
||||
lambda: bool(int(os.getenv("VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES", "1"))),
|
||||
}
|
||||
|
||||
# --8<-- [end:env-vars-definition]
|
||||
|
@ -3249,7 +3249,7 @@ def check_use_alibi(model_config: ModelConfig) -> bool:
|
||||
and getattr(cfg.attn_config, "alibi", False)))))
|
||||
|
||||
|
||||
def sha256(input) -> int:
|
||||
def sha256(input) -> bytes:
|
||||
"""Hash any picklable Python object using SHA-256.
|
||||
|
||||
The input is serialized using pickle before hashing, which allows
|
||||
@ -3260,16 +3260,15 @@ def sha256(input) -> int:
|
||||
input: Any picklable Python object.
|
||||
|
||||
Returns:
|
||||
An integer representing the SHA-256 hash of the serialized input.
|
||||
Bytes representing the SHA-256 hash of the serialized input.
|
||||
"""
|
||||
input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
return int.from_bytes(hashlib.sha256(input_bytes).digest(),
|
||||
byteorder="big")
|
||||
return hashlib.sha256(input_bytes).digest()
|
||||
|
||||
|
||||
def sha256_cbor_64bit(input) -> int:
|
||||
def sha256_cbor(input) -> bytes:
|
||||
"""
|
||||
Hash objects using CBOR serialization and SHA-256, then truncate to 64bits.
|
||||
Hash objects using CBOR serialization and SHA-256.
|
||||
|
||||
This option is useful for non-Python-dependent serialization and hashing.
|
||||
|
||||
@ -3280,17 +3279,13 @@ def sha256_cbor_64bit(input) -> int:
|
||||
Custom classes must implement CBOR serialization methods.
|
||||
|
||||
Returns:
|
||||
An integer in the range [0, 2^64-1] representing the lower 64 bits
|
||||
of the SHA-256 hash of the CBOR serialized input.
|
||||
Bytes representing the SHA-256 hash of the CBOR serialized input.
|
||||
"""
|
||||
input_bytes = cbor2.dumps(input, canonical=True)
|
||||
full_hash = int.from_bytes(hashlib.sha256(input_bytes).digest(),
|
||||
byteorder="big")
|
||||
|
||||
return full_hash & ((1 << 64) - 1)
|
||||
return hashlib.sha256(input_bytes).digest()
|
||||
|
||||
|
||||
def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], int]:
|
||||
def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]:
|
||||
"""Get a hash function by name, or raise an error if
|
||||
the function is not found.
|
||||
Args:
|
||||
@ -3300,10 +3295,8 @@ def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], int]:
|
||||
"""
|
||||
if hash_fn_name == "sha256":
|
||||
return sha256
|
||||
if hash_fn_name == "sha256_cbor_64bit":
|
||||
return sha256_cbor_64bit
|
||||
if hash_fn_name == "builtin":
|
||||
return hash
|
||||
if hash_fn_name == "sha256_cbor":
|
||||
return sha256_cbor
|
||||
|
||||
raise ValueError(f"Unsupported hash function: {hash_fn_name}")
|
||||
|
||||
|
@ -9,7 +9,11 @@ from vllm.distributed.kv_events import (MEDIUM_GPU, AllBlocksCleared,
|
||||
KVCacheEvent)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
|
||||
FreeKVCacheBlockQueue, KVCacheBlock)
|
||||
ExternalBlockHash,
|
||||
FreeKVCacheBlockQueue, KVCacheBlock,
|
||||
get_block_hash,
|
||||
make_block_hash_with_group_id,
|
||||
maybe_convert_block_hash)
|
||||
from vllm.v1.request import Request
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -84,8 +88,10 @@ class BlockPool:
|
||||
"""
|
||||
cached_blocks = []
|
||||
for group_id in kv_cache_group_ids:
|
||||
block_hash_with_group_id = make_block_hash_with_group_id(
|
||||
block_hash, group_id)
|
||||
cached_blocks_one_group = self.cached_block_hash_to_block.get(
|
||||
BlockHashWithGroupId(block_hash, group_id))
|
||||
block_hash_with_group_id)
|
||||
if not cached_blocks_one_group:
|
||||
return None
|
||||
first_block = next(iter(cached_blocks_one_group.values()))
|
||||
@ -124,28 +130,29 @@ class BlockPool:
|
||||
assert len(request.block_hashes) >= num_full_blocks
|
||||
new_block_hashes = request.block_hashes[num_cached_blocks:]
|
||||
|
||||
new_hashes: Optional[list[int]] = ([] if self.enable_kv_cache_events
|
||||
else None)
|
||||
new_hashes: Optional[list[ExternalBlockHash]] = (
|
||||
[] if self.enable_kv_cache_events else None)
|
||||
for i, blk in enumerate(new_full_blocks):
|
||||
assert blk.block_hash is None
|
||||
block_hash = new_block_hashes[i]
|
||||
|
||||
# Update and add the full block to the cache.
|
||||
block_hash_with_group_id = BlockHashWithGroupId(
|
||||
block_hash_with_group_id = make_block_hash_with_group_id(
|
||||
block_hash, kv_cache_group_id)
|
||||
blk.block_hash = block_hash_with_group_id
|
||||
self.cached_block_hash_to_block[block_hash_with_group_id][
|
||||
blk.block_id] = blk
|
||||
if new_hashes is not None:
|
||||
new_hashes.append(block_hash.hash_value)
|
||||
new_hashes.append(maybe_convert_block_hash(block_hash))
|
||||
|
||||
if self.enable_kv_cache_events:
|
||||
if num_cached_blocks == 0:
|
||||
parent_block_hash = None
|
||||
parent_block_hash: Optional[ExternalBlockHash] = None
|
||||
else:
|
||||
parent_block = blocks[num_cached_blocks - 1]
|
||||
assert parent_block.block_hash is not None
|
||||
parent_block_hash = parent_block.block_hash.get_hash_value()
|
||||
parent_block_hash = maybe_convert_block_hash(
|
||||
get_block_hash(parent_block.block_hash))
|
||||
|
||||
self.kv_event_queue.append(
|
||||
BlockStored(
|
||||
@ -220,7 +227,9 @@ class BlockPool:
|
||||
# we disable hybrid kv cache manager when kv cache event is
|
||||
# enabled, so there is only one group.
|
||||
self.kv_event_queue.append(
|
||||
BlockRemoved(block_hashes=[block_hash.get_hash_value()],
|
||||
BlockRemoved(block_hashes=[
|
||||
maybe_convert_block_hash(get_block_hash(block_hash))
|
||||
],
|
||||
medium=MEDIUM_GPU))
|
||||
return True
|
||||
|
||||
|
@ -6,11 +6,12 @@ import os
|
||||
from collections import defaultdict, deque
|
||||
from collections.abc import Iterable, Sequence
|
||||
from dataclasses import astuple, dataclass
|
||||
from typing import Any, Callable, NamedTuple, Optional
|
||||
from typing import Any, Callable, NewType, Optional, Union
|
||||
|
||||
from vllm import envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import GiB_bytes, cdiv, sha256_cbor_64bit
|
||||
from vllm.utils import GiB_bytes, cdiv, sha256_cbor
|
||||
from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
|
||||
FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheGroupSpec, KVCacheSpec,
|
||||
@ -18,59 +19,78 @@ from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
|
||||
from vllm.v1.metrics.stats import PrefixCacheStats
|
||||
from vllm.v1.request import Request
|
||||
|
||||
logger = init_logger(__name__)
|
||||
# BlockHash represents the hash of a single KV-cache block used for
|
||||
# prefix caching. Treating it as a distinct type from ``bytes`` helps
|
||||
# catch accidental misuse when passing around raw byte strings.
|
||||
BlockHash = NewType("BlockHash", bytes)
|
||||
|
||||
# ``BlockHashWithGroupId`` combines a ``BlockHash`` with its KV cache group ID.
|
||||
# It is represented as raw bytes for compactness and efficiency. The helper
|
||||
# functions below pack/unpack the ``BlockHash`` and group id into/from the key.
|
||||
BlockHashWithGroupId = NewType("BlockHashWithGroupId", bytes)
|
||||
|
||||
# ExternalBlockHash is used for reproducible prefix-cache block hashing.
|
||||
# It's a union of ``bytes`` and ``int`` to keep backward compatibility
|
||||
# after we default block hashing to use sha256 bytes.
|
||||
ExternalBlockHash = Union[bytes, int]
|
||||
|
||||
|
||||
class BlockHash(NamedTuple):
|
||||
"""Hash value of a block (int), the token IDs in the block, and extra keys.
|
||||
We keep a tuple of token IDs and extra keys to reduce the likelihood of
|
||||
hash collisions when the hash value is the same. By using SHA256 however,
|
||||
hash collisions are practically impossible.
|
||||
def make_block_hash_with_group_id(block_hash: BlockHash,
                                  group_id: int) -> BlockHashWithGroupId:
    """Combine a ``BlockHash`` and a KV cache group id into one key.

    The key is the block hash bytes followed by the group id serialized as
    a 4-byte unsigned big-endian integer. A flat byte string is used instead
    of a tuple; both parts can still be recovered by splitting off the
    trailing 4 bytes.

    Raises:
        OverflowError: If ``group_id`` is negative or does not fit in
            4 bytes.
    """
    group_id_suffix = group_id.to_bytes(4, "big", signed=False)
    return BlockHashWithGroupId(block_hash + group_id_suffix)
|
||||
|
||||
|
||||
class BlockHashWithGroupId(NamedTuple):
|
||||
# The hash value for the contents (e.g., token_ids) of a block without group
|
||||
# ID. The value is the same for blocks representing the same tokens but for
|
||||
# different groups.
|
||||
block_hash: BlockHash
|
||||
# The KV cache group ID.
|
||||
group_id: int
|
||||
def get_block_hash(key: BlockHashWithGroupId) -> BlockHash:
    """Return the ``BlockHash`` part of a ``BlockHashWithGroupId`` key.

    The block hash is everything in ``key`` except its trailing 4 bytes.
    """
    hash_part = key[:-4]
    return BlockHash(hash_part)
|
||||
|
||||
def get_hash_value(self) -> int:
|
||||
return self.block_hash.hash_value
|
||||
|
||||
def get_group_id(key: BlockHashWithGroupId) -> int:
    """Return the group id stored in a ``BlockHashWithGroupId`` key.

    The last 4 bytes of ``key`` are decoded as an unsigned big-endian
    integer.
    """
    group_id_bytes = key[-4:]
    return int.from_bytes(group_id_bytes, "big", signed=False)
|
||||
|
||||
|
||||
def maybe_convert_block_hash(hash_bytes: BlockHash) -> ExternalBlockHash:
    """Return ``hash_bytes`` as an int or as bytes, depending on an env flag.

    When ``envs.VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES`` is truthy, the bytes
    are read as a big-endian integer and only the low 64 bits are kept.
    Otherwise the bytes are returned unchanged.
    """
    if envs.VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES:
        low_64_bit_mask = (1 << 64) - 1
        return int.from_bytes(hash_bytes, byteorder="big") & low_64_bit_mask
    return hash_bytes
|
||||
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# The hash seed for the first block of any prefix block sequence.
|
||||
#
|
||||
# We use a random value to avoid hash collisions or PYTHONHASHSEED environment
|
||||
# variable if set such that processes can share the seed if needed.
|
||||
# This aligns with the behavior of Python's hash() function, which also uses
|
||||
# a random seed if PYTHONHASHSEED is not set.
|
||||
# variable if set such that processes can share the seed if needed. This aligns
|
||||
# with the behavior of Python's hash() function, which also uses a random seed
|
||||
# if PYTHONHASHSEED is not set.
|
||||
#
|
||||
# The function `init_none_hash` initializes this variable globally.
|
||||
NONE_HASH: int
|
||||
NONE_HASH: BlockHash
|
||||
|
||||
|
||||
def init_none_hash(hash_fn: Callable[[Any], bytes]):
    """Initialize the module-level ``NONE_HASH`` seed.

    If the ``PYTHONHASHSEED`` environment variable is set, ``NONE_HASH`` is
    ``hash_fn`` applied to its string value. Otherwise it is 32 random bytes
    from ``os.urandom``.

    Args:
        hash_fn: Hash function used to derive the seed from
            ``PYTHONHASHSEED``. A warning is logged when it is
            ``sha256_cbor`` and ``PYTHONHASHSEED`` is unset.
    """
    global NONE_HASH

    hash_seed = os.getenv("PYTHONHASHSEED")
    if hash_seed is None and hash_fn is sha256_cbor:
        # Adjacent string literals are joined verbatim, so every fragment
        # needs its own trailing space to keep the sentences separated.
        logger.warning(
            "PYTHONHASHSEED is not set. This will lead to non-reproducible "
            "block-hashes when using sha256_cbor as the hash function. "
            "Consider setting PYTHONHASHSEED to a fixed value for "
            "reproducibility.")

    if hash_seed is None:
        NONE_HASH = BlockHash(os.urandom(32))
    else:
        NONE_HASH = BlockHash(hash_fn(hash_seed))
|
||||
|
||||
|
||||
class PrefixCachingMetrics:
|
||||
@ -142,8 +162,8 @@ class KVCacheBlock:
|
||||
block_id: int
|
||||
# Reference count.
|
||||
ref_cnt: int = 0
|
||||
# The hash of the block composed of (block hash, tuple of token IDs).
|
||||
# It is only available when the block is full.
|
||||
# The hash key (block hash + group id) of the block, only available
|
||||
# when the block is full and cached.
|
||||
_block_hash: Optional[BlockHashWithGroupId] = None
|
||||
|
||||
# Used to construct a doubly linked list for free blocks.
|
||||
@ -177,7 +197,7 @@ class KVCacheBlock:
|
||||
if self.next_free_block else None)
|
||||
return (f"KVCacheBlock(block_id={self.block_id}, "
|
||||
f"ref_cnt={self.ref_cnt}, "
|
||||
f"_block_hash={self._block_hash}, "
|
||||
f"_block_hash={self._block_hash!r}, "
|
||||
f"prev_free_block={prev_block_id}, "
|
||||
f"next_free_block={next_block_id})")
|
||||
|
||||
@ -517,15 +537,14 @@ def generate_block_hash_extra_keys(
|
||||
|
||||
|
||||
def hash_block_tokens(
|
||||
hash_function: Callable,
|
||||
parent_block_hash: Optional[int],
|
||||
hash_function: Callable[[Any], bytes],
|
||||
parent_block_hash: Optional[BlockHash],
|
||||
curr_block_token_ids: Sequence[int],
|
||||
extra_keys: Optional[tuple[Any, ...]] = None) -> BlockHash:
|
||||
"""Computes a hash value corresponding to the contents of a block and
|
||||
the contents of the preceding block(s). The hash value is used for
|
||||
prefix caching. We use LRU cache for this function to avoid recomputing
|
||||
hash values for the same block contents.
|
||||
|
||||
Args:
|
||||
hash_function: The hash function used to compute block hash.
|
||||
parent_block_hash: The hash of the parent block. None
|
||||
@ -533,7 +552,6 @@ def hash_block_tokens(
|
||||
curr_block_token_ids: A list of token ids in the current
|
||||
block. The current block is assumed to be full.
|
||||
extra_keys: Extra keys for the block.
|
||||
|
||||
Returns:
|
||||
The hash value of the block: the output of ``hash_function`` wrapped in
|
||||
``BlockHash``. This value is used as the hash key of the block.
|
||||
@ -544,26 +562,16 @@ def hash_block_tokens(
|
||||
curr_block_token_ids_tuple = tuple(curr_block_token_ids)
|
||||
return BlockHash(
|
||||
hash_function(
|
||||
(parent_block_hash, curr_block_token_ids_tuple, extra_keys)),
|
||||
curr_block_token_ids_tuple, extra_keys)
|
||||
(parent_block_hash, curr_block_token_ids_tuple, extra_keys)))
|
||||
|
||||
|
||||
def get_request_block_hasher(
|
||||
block_size: int,
|
||||
caching_hash_fn: Callable[[Any],
|
||||
int]) -> Callable[[Request], list[BlockHash]]:
|
||||
caching_hash_fn: Callable[[Any], bytes],
|
||||
) -> Callable[[Request], list[BlockHash]]:
|
||||
"""
|
||||
Returns a function which computes the list of un-computed block hashes
|
||||
of a request.
|
||||
|
||||
Each request holds a list of its block hashes (request.block_hashes).
|
||||
When a request is created, it calls the below function to compute
|
||||
the hashes of all full blocks of the request's initial tokens.
|
||||
The hashes are then stored in request.block_hashes.
|
||||
Later, whenever new tokens are appended to the request, it calls
|
||||
the below function again to compute any new full blocks of tokens.
|
||||
The returned new hashes are appended to request.block_hashes.
|
||||
"""
|
||||
of a request."""
|
||||
|
||||
def request_block_hasher(request: Request) -> list[BlockHash]:
|
||||
start_token_idx = len(request.block_hashes) * block_size
|
||||
@ -577,8 +585,8 @@ def get_request_block_hasher(
|
||||
# last mm input.
|
||||
curr_mm_idx = -1
|
||||
|
||||
prev_block_hash_value = request.block_hashes[-1].hash_value \
|
||||
if request.block_hashes else None
|
||||
prev_block_hash_value = (request.block_hashes[-1]
|
||||
if request.block_hashes else None)
|
||||
new_block_hashes: list[BlockHash] = []
|
||||
while True:
|
||||
end_token_idx = start_token_idx + block_size
|
||||
@ -598,7 +606,7 @@ def get_request_block_hasher(
|
||||
|
||||
new_block_hashes.append(block_hash)
|
||||
start_token_idx += block_size
|
||||
prev_block_hash_value = block_hash.hash_value
|
||||
prev_block_hash_value = block_hash
|
||||
|
||||
return new_block_hashes
|
||||
|
||||
|
Reference in New Issue
Block a user