[Core] Use sha256 bytes instead of BlockHash to reduce GC overhead (#23673)

Signed-off-by: linzebing <linzebing1995@gmail.com>
This commit is contained in:
Zebing Lin
2025-09-09 00:34:37 -04:00
committed by GitHub
parent bba1042c6f
commit 82dfb12e52
15 changed files with 298 additions and 283 deletions

View File

@ -6,6 +6,8 @@ import msgspec
import zmq
from msgspec.msgpack import Decoder
from vllm.v1.core.kv_cache_utils import BlockHash
#
# Types copied from vllm.distributed.kv_events
@ -22,8 +24,8 @@ class KVCacheEvent(
class BlockStored(KVCacheEvent):
block_hashes: list[int]
parent_block_hash: Optional[int]
block_hashes: list[BlockHash]
parent_block_hash: Optional[BlockHash]
token_ids: list[int]
block_size: int
lora_id: Optional[int]
@ -31,7 +33,7 @@ class BlockStored(KVCacheEvent):
class BlockRemoved(KVCacheEvent):
block_hashes: list[int]
block_hashes: list[BlockHash]
medium: Optional[str]

View File

@ -835,22 +835,20 @@ def test_model_specification(parser_with_config, cli_config_file,
@pytest.mark.parametrize("input", [(), ("abc", ), (None, ),
(None, bool, [1, 2, 3])])
@pytest.mark.parametrize("output", [0, 1, 2])
def test_sha256(input: tuple, output: int):
hash = sha256(input)
assert hash is not None
assert isinstance(hash, int)
assert hash != 0
def test_sha256(input: tuple):
digest = sha256(input)
assert digest is not None
assert isinstance(digest, bytes)
assert digest != b""
bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
assert hash == int.from_bytes(hashlib.sha256(bytes).digest(),
byteorder="big")
input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
assert digest == hashlib.sha256(input_bytes).digest()
# hashing again, returns the same value
assert hash == sha256(input)
assert digest == sha256(input)
# hashing different input, returns different value
assert hash != sha256(input + (1, ))
assert digest != sha256(input + (1, ))
@pytest.mark.parametrize(

View File

@ -6,20 +6,22 @@ from typing import Callable, Optional
import pytest
import torch
import vllm.v1.core.kv_cache_utils as kv_cache_utils
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
from vllm.multimodal.inputs import (MultiModalFeatureSpec,
MultiModalKwargsItem, PlaceholderRange)
from vllm.sampling_params import SamplingParams
from vllm.utils import GiB_bytes, sha256, sha256_cbor_64bit
from vllm.utils import GiB_bytes, sha256, sha256_cbor
from vllm.v1.core.kv_cache_manager import KVCacheManager
# disable yapf here as it formats differently than isort such that both fail
# yapf: disable
from vllm.v1.core.kv_cache_utils import (
FreeKVCacheBlockQueue, KVCacheBlock, PrefixCachingMetrics,
BlockHash, FreeKVCacheBlockQueue, KVCacheBlock, PrefixCachingMetrics,
estimate_max_model_len, generate_block_hash_extra_keys,
get_kv_cache_config, get_max_concurrency_for_kv_cache_config,
get_request_block_hasher, hash_block_tokens, init_none_hash,
is_kv_cache_type_uniform, unify_kv_cache_configs)
is_kv_cache_type_uniform, make_block_hash_with_group_id,
unify_kv_cache_configs)
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheTensor,
SlidingWindowSpec)
@ -88,7 +90,7 @@ def new_sliding_window_spec(block_size=16,
sliding_window=sliding_window)
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
def test_none_hash(monkeypatch, hash_fn):
import vllm.v1.core.kv_cache_utils
@ -98,8 +100,8 @@ def test_none_hash(monkeypatch, hash_fn):
reloaded_kv_cache_utils = importlib.reload(vllm.v1.core.kv_cache_utils)
reloaded_kv_cache_utils.init_none_hash(hash_fn)
assert reloaded_kv_cache_utils.NONE_HASH is not None
assert isinstance(reloaded_kv_cache_utils.NONE_HASH, int)
assert reloaded_kv_cache_utils.NONE_HASH != 0
assert isinstance(reloaded_kv_cache_utils.NONE_HASH, bytes)
assert reloaded_kv_cache_utils.NONE_HASH != b""
# case 2: PYTHONHASHSEED is set, use the seed and hash_fn
with monkeypatch.context() as m:
@ -107,12 +109,11 @@ def test_none_hash(monkeypatch, hash_fn):
reloaded_kv_cache_utils = importlib.reload(vllm.v1.core.kv_cache_utils)
reloaded_kv_cache_utils.init_none_hash(hash_fn)
assert reloaded_kv_cache_utils.NONE_HASH is not None
assert isinstance(reloaded_kv_cache_utils.NONE_HASH, int)
assert isinstance(reloaded_kv_cache_utils.NONE_HASH, bytes)
assert hash_fn('python hash seed') == reloaded_kv_cache_utils.NONE_HASH
def test_kv_cache_block():
import vllm.v1.core.kv_cache_utils
# Test KVCacheBlock initialization
block = KVCacheBlock(block_id=0)
@ -127,8 +128,7 @@ def test_kv_cache_block():
assert block.ref_cnt == 0
# Test block hash setting and resetting
block_hash = vllm.v1.core.kv_cache_utils.BlockHash(hash_value=123,
token_ids=(1, 2, 3))
block_hash = make_block_hash_with_group_id(BlockHash(b"abc"), 0)
block.block_hash = block_hash
assert block.block_hash == block_hash
@ -407,27 +407,23 @@ def test_generate_block_hash_extra_keys_cache_salt():
assert next_mm_idx == 1
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
def test_hash_block_tokens(hash_fn):
import vllm.v1.core.kv_cache_utils
init_none_hash(hash_fn)
parent_block_hash = 123
parent_block_hash = BlockHash(b"123")
curr_block_token_ids = (1, 2, 3)
extra_keys = ("key1", "key2")
block_hash = hash_block_tokens(hash_fn, parent_block_hash,
curr_block_token_ids, extra_keys)
assert isinstance(block_hash, vllm.v1.core.kv_cache_utils.BlockHash)
assert block_hash.hash_value == hash_fn(
(parent_block_hash, curr_block_token_ids, extra_keys))
assert block_hash.token_ids == curr_block_token_ids
assert block_hash.extra_keys == extra_keys
expected = hash_fn((parent_block_hash, curr_block_token_ids, extra_keys))
assert block_hash == expected
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
def test_request_block_hasher(hash_fn):
import vllm.v1.core.kv_cache_utils
init_none_hash(hash_fn)
kv_cache_utils.init_none_hash(hash_fn)
request = make_request(
request_id="0",
prompt_token_ids=[_ for _ in range(6)],
@ -442,19 +438,13 @@ def test_request_block_hasher(hash_fn):
block_hashes = request.block_hashes
assert len(block_hashes) == 2
assert isinstance(block_hashes[0], vllm.v1.core.kv_cache_utils.BlockHash)
assert isinstance(block_hashes[1], vllm.v1.core.kv_cache_utils.BlockHash)
# Check the first block
assert block_hashes[0].token_ids == (0, 1, 2)
assert block_hashes[0].extra_keys == ("hash1", )
# Check the second block
assert block_hashes[1].token_ids == (3, 4, 5)
assert block_hashes[1].extra_keys == ("hash2", )
assert block_hashes[0] == hash_fn(
(kv_cache_utils.NONE_HASH, (0, 1, 2), ("hash1", )))
assert block_hashes[1] == hash_fn(
(block_hashes[0], (3, 4, 5), ("hash2", )))
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
def test_hash_tokens_different_mm_input(hash_fn):
init_none_hash(hash_fn)
@ -484,9 +474,9 @@ def test_hash_tokens_different_mm_input(hash_fn):
assert block_hashes1[1] != block_hashes2[1]
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
def test_hash_request_tokens_no_mm_inputs(hash_fn):
init_none_hash(hash_fn)
kv_cache_utils.init_none_hash(hash_fn)
request = make_request(
request_id="0",
@ -500,10 +490,9 @@ def test_hash_request_tokens_no_mm_inputs(hash_fn):
block_hashes = request.block_hashes
assert len(block_hashes) == 2
assert block_hashes[0].token_ids == (0, 1, 2)
assert block_hashes[0].extra_keys is None
assert block_hashes[1].token_ids == (3, 4, 5)
assert block_hashes[1].extra_keys is None
assert block_hashes[0] == hash_fn(
(kv_cache_utils.NONE_HASH, (0, 1, 2), None))
assert block_hashes[1] == hash_fn((block_hashes[0], (3, 4, 5), None))
def test_metrics():

View File

@ -8,17 +8,19 @@ from typing import Callable, Optional
import pytest
import torch
import vllm.v1.core.kv_cache_utils as kv_cache_utils
from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved
from vllm.multimodal.inputs import (MultiModalFeatureSpec,
MultiModalKwargsItem, PlaceholderRange)
from vllm.sampling_params import SamplingParams
from vllm.utils import sha256, sha256_cbor_64bit
from vllm.utils import sha256, sha256_cbor
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
KVCacheBlock,
from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
get_block_hash, get_group_id,
get_request_block_hasher,
hash_block_tokens, init_none_hash)
hash_block_tokens, init_none_hash,
make_block_hash_with_group_id)
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, SlidingWindowSpec)
@ -101,8 +103,10 @@ def make_kv_cache_config_hybrid_model(block_size: int,
)
@pytest.mark.parametrize("hash_algo", ["sha256", "sha256_cbor_64bit", "hash"])
def test_prefill(hash_algo):
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
def test_prefill(hash_fn):
init_none_hash(hash_fn)
block_size = 16
manager = KVCacheManager(
make_kv_cache_config(block_size, 11),
@ -110,10 +114,6 @@ def test_prefill(hash_algo):
enable_caching=True,
)
# choose the hash function according to the parameter
hash_fn = (sha256_cbor_64bit if hash_algo == "sha256_cbor_64bit" else
sha256 if hash_algo == "sha256" else hash)
# Complete 3 blocks (48 tokens)
common_token_ids = [i for i in range(3) for _ in range(16)]
@ -137,10 +137,12 @@ def test_prefill(hash_algo):
block_tokens = tuple(all_token_ids[(block_id - 1) * 16:block_id * 16])
block_hash = hash_block_tokens(hash_fn, parent_block_hash,
block_tokens)
assert manager.block_pool.blocks[
block_id].block_hash.block_hash == block_hash
blk_hash = manager.block_pool.blocks[block_id].block_hash
assert blk_hash is not None
assert get_block_hash(blk_hash) == block_hash
assert get_group_id(blk_hash) == 0
assert manager.block_pool.blocks[block_id].ref_cnt == 1
parent_block_hash = block_hash.hash_value
parent_block_hash = block_hash
# Check partial block metadata
for block_id in (4, ):
@ -233,7 +235,7 @@ def test_prefill_hybrid_model():
enable_caching=True,
)
hash_fn = hash
hash_fn = sha256
# Complete 3 blocks (48 tokens)
common_token_ids = [i for i in range(3) for _ in range(block_size)]
@ -260,11 +262,13 @@ def test_prefill_hybrid_model():
block_tokens = tuple(all_token_ids[(length - 1) * 16:length * 16])
block_hash = hash_block_tokens(hash_fn, parent_block_hash,
block_tokens)
for block_id in block_ids:
assert manager.block_pool.blocks[
block_id].block_hash.block_hash == block_hash
for group_id, block_id in enumerate(block_ids):
blk_hash = manager.block_pool.blocks[block_id].block_hash
assert blk_hash is not None
assert get_block_hash(blk_hash) == block_hash
assert get_group_id(blk_hash) == group_id
assert manager.block_pool.blocks[block_id].ref_cnt == 1
parent_block_hash = block_hash.hash_value
parent_block_hash = block_hash
# Check partial block metadata
for block_id in (4, 8, 12):
@ -298,11 +302,10 @@ def test_prefill_hybrid_model():
cached_block_hash_to_block_bak = copy.copy(
manager.block_pool.cached_block_hash_to_block)
def test_partial_request_hit(request_id: str,
hash_to_evict: list[BlockHashWithGroupId],
def test_partial_request_hit(request_id: str, hash_to_evict: list[bytes],
expect_hit_length: int):
req = make_request(request_id, common_token_ids + unique_token_ids,
block_size, hash)
block_size, sha256)
for hash_with_group_id in hash_to_evict:
manager.block_pool.cached_block_hash_to_block.pop(
hash_with_group_id)
@ -319,33 +322,32 @@ def test_prefill_hybrid_model():
# Evict the blocks outside sliding window, does not affect the hit length.
test_partial_request_hit("2", [
BlockHashWithGroupId(block_hashes[0], 1),
BlockHashWithGroupId(block_hashes[0], 2)
make_block_hash_with_group_id(block_hashes[0], 1),
make_block_hash_with_group_id(block_hashes[0], 2)
], 3)
# Evict the first block of full attention, makes total cache miss.
test_partial_request_hit("3", [
BlockHashWithGroupId(block_hashes[0], 0),
], 0)
test_partial_request_hit(
"3", [make_block_hash_with_group_id(block_hashes[0], 0)], 0)
# Evict the last block of all layers, reduces the hit length to 2.
test_partial_request_hit("4", [
BlockHashWithGroupId(block_hashes[2], 0),
BlockHashWithGroupId(block_hashes[2], 1),
BlockHashWithGroupId(block_hashes[2], 2),
make_block_hash_with_group_id(block_hashes[2], 0),
make_block_hash_with_group_id(block_hashes[2], 1),
make_block_hash_with_group_id(block_hashes[2], 2),
], 2)
# Evict the last block of full attention, reduces the hit length to 2.
test_partial_request_hit("5", [BlockHashWithGroupId(block_hashes[2], 0)],
2)
test_partial_request_hit(
"5", [make_block_hash_with_group_id(block_hashes[2], 0)], 2)
# Evict the last block of sliding window, reduces the hit length to 2.
test_partial_request_hit("6", [BlockHashWithGroupId(block_hashes[2], 1)],
2)
test_partial_request_hit(
"6", [make_block_hash_with_group_id(block_hashes[2], 1)], 2)
# Evict the last block of sliding window, reduces the hit length to 2.
test_partial_request_hit("7", [BlockHashWithGroupId(block_hashes[2], 2)],
2)
test_partial_request_hit(
"7", [make_block_hash_with_group_id(block_hashes[2], 2)], 2)
# Evict different set of blocks for full attention and sliding window makes
# total cache miss.
@ -353,9 +355,9 @@ def test_prefill_hybrid_model():
# The cache hit length of sliding window is 2 * block_size.
# Then it is cache miss as the two type of layers have different hit length.
test_partial_request_hit("8", [
BlockHashWithGroupId(block_hashes[2], 0),
BlockHashWithGroupId(block_hashes[0], 1),
BlockHashWithGroupId(block_hashes[0], 2),
make_block_hash_with_group_id(block_hashes[2], 0),
make_block_hash_with_group_id(block_hashes[0], 1),
make_block_hash_with_group_id(block_hashes[0], 2),
], 0)
@ -372,8 +374,8 @@ def test_prefill_plp():
max_model_len=8192,
enable_caching=True,
)
# the default hash function is hash
hash_fn = hash
# the default hash function is sha256
hash_fn = sha256
# Complete 3 blocks (48 tokens)
common_token_ids = [i for i in range(3) for _ in range(16)]
@ -404,10 +406,12 @@ def test_prefill_plp():
block_tokens = tuple(all_token_ids[(block_id - 1) * 16:block_id * 16])
block_hash = hash_block_tokens(hash_fn, parent_block_hash,
block_tokens)
assert manager.block_pool.blocks[
block_id].block_hash.block_hash == block_hash
blk_hash = (manager.block_pool.blocks[block_id].block_hash)
assert blk_hash is not None
assert get_block_hash(blk_hash) == block_hash
assert get_group_id(blk_hash) == 0
assert manager.block_pool.blocks[block_id].ref_cnt == 1
parent_block_hash = block_hash.hash_value
parent_block_hash = block_hash
# Check partial block metadata
for block_id in (4, ):
@ -493,7 +497,7 @@ def test_decode():
# Incomplete 1 block (7 tokens)
unique_token_ids = [3] * 7
req0 = make_request("0", common_token_ids + unique_token_ids, block_size,
hash)
sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
@ -538,7 +542,7 @@ def test_evict():
)
last_token_id = 5 * 16 + 7
req0 = make_request("0", list(range(last_token_id)), block_size, hash)
req0 = make_request("0", list(range(last_token_id)), block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
@ -550,7 +554,7 @@ def test_evict():
# 3 blocks.
req1 = make_request("1", list(range(last_token_id,
last_token_id + 3 * 16)), block_size,
hash)
sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
@ -572,7 +576,7 @@ def test_evict():
] == [10, 6, 5, 4, 3, 2, 1, 9, 8, 7]
# Touch the first 2 blocks.
req2 = make_request("2", list(range(2 * 16 + 3)), block_size, hash)
req2 = make_request("2", list(range(2 * 16 + 3)), block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert computed_blocks.get_block_ids() == ([1, 2], )
assert num_computed_tokens == 2 * 16
@ -597,7 +601,7 @@ def test_hash_block_correct_reuse():
# Allocate 1 block and cache it.
num_tokens = block_size * 1
req = make_request("0", list(range(num_tokens)), block_size, hash)
req = make_request("0", list(range(num_tokens)), block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
@ -611,7 +615,7 @@ def test_hash_block_correct_reuse():
# Allocate a new block that's not full, make sure hash info on the
# block is cleared.
req = make_request("1", list(range(num_tokens - 1)), block_size, hash)
req = make_request("1", list(range(num_tokens - 1)), block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
@ -638,7 +642,7 @@ def test_computed_blocks_not_evicted():
# Allocate a block and cache it.
num_tokens = block_size * 1
req0 = make_request("0", list(range(num_tokens)), block_size, hash)
req0 = make_request("0", list(range(num_tokens)), block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
@ -650,7 +654,7 @@ def test_computed_blocks_not_evicted():
# Allocate another block.
req1 = make_request("1", list(range(num_tokens, num_tokens * 2)),
block_size, hash)
block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
@ -666,7 +670,7 @@ def test_computed_blocks_not_evicted():
# Now if we have a cache hit on the first block, we should evict the second
# cached block rather than the first one.
req2 = make_request("2", list(range(num_tokens * 2)), block_size, hash)
req2 = make_request("2", list(range(num_tokens * 2)), block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert len(computed_blocks.blocks[0]) == 1
assert computed_blocks.blocks[0][0].block_id == 1
@ -691,7 +695,7 @@ def test_basic_prefix_caching_disabled():
)
req1 = make_request("1", list(range(10)), block_size,
hash) # 2 blocks and some more
sha256) # 2 blocks and some more
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert not computed_blocks.blocks[0]
@ -706,7 +710,7 @@ def test_basic_prefix_caching_disabled():
# No caching.
req2 = make_request("2", list(range(16)), block_size,
hash) # shared prefix
sha256) # shared prefix
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
@ -716,7 +720,7 @@ def test_basic_prefix_caching_disabled():
assert len(blocks.blocks[0]) == 4
# New requests should not have any blocks.
req3 = make_request("3", list(range(4)), block_size, hash)
req3 = make_request("3", list(range(4)), block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
@ -726,7 +730,7 @@ def test_basic_prefix_caching_disabled():
assert not blocks
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
def test_cache_blocks(hash_fn):
"""
This is a unit test that tests the correctness of the _cache_full_blocks
@ -787,7 +791,7 @@ def test_cache_blocks_multi_group():
# Block 1/5: [4, 5, 6, 7]
# Block 2/6: [8, 9, 10, 11]
# Block 3/7: [12, 13]
req = make_request("0", list(range(14)), block_size, hash)
req = make_request("0", list(range(14)), block_size, sha256)
# Cache the blocks for group 0.
blocks = [KVCacheBlock(block_id=i) for i in range(2)]
@ -845,6 +849,8 @@ def test_mm_prefix_caching():
"""
This tests that the multi-modal prefix caching is correct.
"""
kv_cache_utils.init_none_hash(sha256)
block_size = 16
manager = KVCacheManager(
make_kv_cache_config(block_size, 11),
@ -874,23 +880,30 @@ def test_mm_prefix_caching():
req0 = make_request("0",
all_token_ids,
block_size,
hash,
sha256,
mm_positions=mm_positions,
mm_hashes=mm_hashes)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
# Completed block should have hashes with extra keys.
# Completed block should have hashes
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
block_hashes = req0.block_hashes
assert len(block_hashes) == 3
assert block_hashes[0].extra_keys == ("aaa", )
assert block_hashes[1].extra_keys == ("aaa", "bbb")
assert block_hashes[2].extra_keys == ("bbb", )
assert block_hashes[0] == sha256(
(kv_cache_utils.NONE_HASH, tuple(all_token_ids[:block_size]),
("aaa", )))
assert block_hashes[1] == sha256(
(block_hashes[0], tuple(all_token_ids[block_size:block_size * 2]),
("aaa", "bbb")))
assert block_hashes[2] == sha256(
(block_hashes[1], tuple(all_token_ids[block_size * 2:block_size * 3]),
("bbb", )))
blocks = manager.allocate_slots(req0, 59,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks is not None
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
req0.num_computed_tokens = 59
@ -901,10 +914,10 @@ def test_mm_prefix_caching():
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert new_blocks is not None and len(new_blocks.blocks[0]) == 0
# The just completed block should have hashes with extra keys.
assert len(block_hashes) == 4
assert block_hashes[3].extra_keys == ("ccc", )
assert block_hashes[3] == sha256(
(block_hashes[2], tuple(all_token_ids[3 * block_size:] + [8] * 5),
("ccc", )))
# Cache hit.
unique_token_ids = [-1] * 7 + [200] * 5
@ -916,7 +929,7 @@ def test_mm_prefix_caching():
req1 = make_request("1",
all_token_ids,
block_size,
hash,
sha256,
mm_positions=mm_positions,
mm_hashes=mm_hashes)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
@ -929,6 +942,8 @@ def test_cache_key_salting():
This tests that cache salts are applied during hashing and the cache
is separated cache as expected.
"""
kv_cache_utils.init_none_hash(sha256)
block_size = 16
manager = KVCacheManager(
make_kv_cache_config(block_size, 11),
@ -939,21 +954,26 @@ def test_cache_key_salting():
# 3 complete blocks and an incomplete block with 11 tokens.
common_token_ids = [i for i in range(3) for _ in range(block_size)]
token_ids = common_token_ids + [3] * 11
req0 = make_request("0", token_ids, block_size, hash, cache_salt="salt1")
req0 = make_request("0", token_ids, block_size, sha256, cache_salt="salt1")
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
# Completed block should have hashes with extra keys.
# Completed block should have hashes
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
block_hashes = req0.block_hashes
assert len(block_hashes) == 3
assert block_hashes[0].extra_keys == ("salt1", )
assert block_hashes[1].extra_keys is None
assert block_hashes[2].extra_keys is None
assert block_hashes[0] == sha256(
(kv_cache_utils.NONE_HASH, tuple(token_ids[:block_size]), ("salt1", )))
assert block_hashes[1] == sha256(
(block_hashes[0], tuple(token_ids[block_size:block_size * 2]), None))
assert block_hashes[2] == sha256(
(block_hashes[1], tuple(token_ids[block_size * 2:block_size * 3]),
None))
blocks = manager.allocate_slots(req0, 59,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks is not None
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
req0.num_computed_tokens = 59
@ -964,14 +984,13 @@ def test_cache_key_salting():
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert new_blocks is not None and len(new_blocks.blocks[0]) == 0
# Now one more block that should not have extra keys.
assert len(block_hashes) == 4
assert block_hashes[3].extra_keys is None
assert block_hashes[3] == sha256(
(block_hashes[2], tuple(token_ids[3 * block_size:] + [8] * 5), None))
# Test cache hit with a new request that has the same salt.
token_ids = common_token_ids + [4] * 11
req1 = make_request("1", token_ids, block_size, hash, cache_salt="salt1")
req1 = make_request("1", token_ids, block_size, sha256, cache_salt="salt1")
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
# Should match only a prefix of 3 blocks.
assert len(computed_blocks.blocks[0]) == 3
@ -979,13 +998,19 @@ def test_cache_key_salting():
# Test cache miss with same content but different salt.
token_ids = common_token_ids + [4] * 11
req2 = make_request("2", token_ids, block_size, hash, cache_salt="salt2")
req2 = make_request("2", token_ids, block_size, sha256, cache_salt="salt2")
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert len(computed_blocks.blocks[0]) == 0
assert num_computed_tokens == 0
block_hashes = req2.block_hashes
assert len(block_hashes) == 3
assert block_hashes[0].extra_keys == ("salt2", )
assert block_hashes[0] == sha256(
(kv_cache_utils.NONE_HASH, tuple(token_ids[:block_size]), ("salt2", )))
assert block_hashes[1] == sha256(
(block_hashes[0], tuple(token_ids[block_size:block_size * 2]), None))
assert block_hashes[2] == sha256(
(block_hashes[1], tuple(token_ids[block_size * 2:block_size * 3]),
None))
def test_prefill_not_enough_free_blocks_with_computed_blocks():
@ -1004,7 +1029,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
# Complete 3 blocks (48 tokens)
# | Common-0 | Common-1 | Common-2 | ... |
common_token_ids = [i for i in range(3) for _ in range(16)]
req0 = make_request("0", common_token_ids, block_size, hash)
req0 = make_request("0", common_token_ids, block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
@ -1015,7 +1040,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
req0.request_id]
# | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... |
req1 = make_request("1", common_token_ids * 2, block_size, hash)
req1 = make_request("1", common_token_ids * 2, block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert computed_blocks.blocks[0] == block_part0
assert num_computed_tokens == 3 * 16
@ -1032,7 +1057,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
# | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) |
# | Req1-5(F)| Req2-0 | Req2-1 | ... |
req2 = make_request("2", [7] * block_size * 2, block_size, hash)
req2 = make_request("2", [7] * block_size * 2, block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
@ -1044,7 +1069,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
# but it cannot be allocated due to insufficient free blocks (2).
# In this case, the ref_cnt of the computed blocks should not be changed.
assert manager.block_pool.free_block_queue.num_free_blocks == 5
req3 = make_request("3", common_token_ids * 3, block_size, hash)
req3 = make_request("3", common_token_ids * 3, block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
assert computed_blocks.blocks[0] == block_part1
assert num_computed_tokens == 6 * 16
@ -1069,13 +1094,13 @@ def test_reset_prefix_cache():
full_block_token_ids = [i for i in range(3) for _ in range(16)]
unique_token_ids = [3] * 7
all_token_ids = full_block_token_ids + unique_token_ids
req0 = make_request("0", all_token_ids, block_size, hash)
req0 = make_request("0", all_token_ids, block_size, sha256)
blocks = manager.allocate_slots(req0, 55)
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
unique_token_ids = [4] * 7
all_token_ids = full_block_token_ids + unique_token_ids
req1 = make_request("1", all_token_ids, block_size, hash)
req1 = make_request("1", all_token_ids, block_size, sha256)
computed_blocks, _ = manager.get_computed_blocks(req1)
assert len(req1.block_hashes) == 3
assert len(computed_blocks.blocks[0]) == 3
@ -1109,7 +1134,7 @@ def test_prefix_cache_stats_disabled():
assert manager.prefix_cache_stats is None
# Call all functions that check whether log_stats is disabled.
req = make_request("0", list(range(16)), block_size, hash)
req = make_request("0", list(range(16)), block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
@ -1124,15 +1149,9 @@ def test_prefix_cache_stats_disabled():
def test_maybe_evict_cached_block():
pool = BlockPool(num_gpu_blocks=4, enable_caching=True)
block_hash0 = BlockHashWithGroupId(block_hash=BlockHash(hash_value=10,
token_ids=(100, )),
group_id=1000)
block_hash1 = BlockHashWithGroupId(block_hash=BlockHash(hash_value=20,
token_ids=(200, )),
group_id=2000)
block_hash2 = BlockHashWithGroupId(block_hash=BlockHash(hash_value=30,
token_ids=(300, )),
group_id=3000)
block_hash0 = make_block_hash_with_group_id(BlockHash(b"10"), 1000)
block_hash1 = make_block_hash_with_group_id(BlockHash(b"20"), 2000)
block_hash2 = make_block_hash_with_group_id(BlockHash(b"30"), 3000)
block_hashes = [
block_hash0,
block_hash1,
@ -1206,7 +1225,7 @@ def test_kv_cache_events(blocks_to_cache: int):
)
num_tokens = block_size * blocks_to_cache
req0 = make_request("0", list(range(num_tokens)), block_size, hash)
req0 = make_request("0", list(range(num_tokens)), block_size, sha256)
_ = manager.allocate_slots(req0, num_tokens)
events = manager.take_events()
@ -1222,7 +1241,7 @@ def test_kv_cache_events(blocks_to_cache: int):
# Should see block_to_cache number of removed block events and a new block
# stored event
manager.free(req0)
req1 = make_request("1", list(range(num_tokens)), block_size, hash)
req1 = make_request("1", list(range(num_tokens)), block_size, sha256)
_ = manager.allocate_slots(req1, num_tokens)
events = manager.take_events()
@ -1256,7 +1275,7 @@ def test_eagle_enabled_removes_last_block():
# Request with 3 full blocks (48 tokens)
token_ids = [0] * (3 * block_size)
req = make_request("divisible_request", token_ids, block_size, hash)
req = make_request("divisible_request", token_ids, block_size, sha256)
# Prime the cache
computed_blocks, _ = manager.get_computed_blocks(req)
@ -1266,7 +1285,7 @@ def test_eagle_enabled_removes_last_block():
manager.free(req)
# New request with same tokens + Eagle enabled
req_eagle = make_request("eagle_divisible", token_ids, block_size, hash)
req_eagle = make_request("eagle_divisible", token_ids, block_size, sha256)
computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)
# Should retain 1 block:
@ -1287,7 +1306,7 @@ def test_eagle_with_partial_blocks():
)
# 2 full blocks + 5 tokens (non-divisible length)
token_ids = [0] * (2 * block_size + 5)
req = make_request("partial_block_test", token_ids, block_size, hash)
req = make_request("partial_block_test", token_ids, block_size, sha256)
# Prime the cache
computed_blocks, _ = manager.get_computed_blocks(req)
@ -1297,7 +1316,7 @@ def test_eagle_with_partial_blocks():
manager.free(req)
# New request with Eagle enabled
req_eagle = make_request("partial_eagle", token_ids, block_size, hash)
req_eagle = make_request("partial_eagle", token_ids, block_size, sha256)
computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)
# Original match: 2 full blocks → Eagle removes 1 → 1 remaining
assert len(computed_blocks.blocks[0]) == 1
@ -1328,7 +1347,7 @@ def test_eagle_with_sliding_window():
# 2 full blocks + 5 tokens (non-divisible length)
token_ids = [0] * (2 * block_size + 5)
req = make_request("partial_block_test", token_ids, block_size, hash)
req = make_request("partial_block_test", token_ids, block_size, sha256)
# Prime the cache
computed_blocks, _ = manager.get_computed_blocks(req)
@ -1341,7 +1360,7 @@ def test_eagle_with_sliding_window():
manager.free(req)
# New request with Eagle enabled
req_eagle = make_request("partial_eagle", token_ids, block_size, hash)
req_eagle = make_request("partial_eagle", token_ids, block_size, sha256)
computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)
# Original match: 2 full blocks → Eagle removes 1 → 1 remaining
assert len(computed_blocks.blocks[0]) == 1
@ -1351,11 +1370,11 @@ def test_eagle_with_sliding_window():
assert manager.block_pool.get_cached_block(
block_hash_first_block, kv_cache_group_ids=[0]) is not None
manager.block_pool.cached_block_hash_to_block.pop(
BlockHashWithGroupId(block_hash_first_block, 0))
make_block_hash_with_group_id(block_hash_first_block, 0))
# New request
req_after_evict = make_request("partial_eagle_after_evict", token_ids,
block_size, hash)
block_size, sha256)
computed_blocks, num_tokens = manager.get_computed_blocks(req_after_evict)
# Cache miss. The only hit prefix is [NULL_BLOCK, BLOCK_2] if eagle is
# not considered. But after dropping the last matched block due to eagle,

View File

@ -6,8 +6,8 @@ import random
import torch
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
KVCacheBlock)
from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
make_block_hash_with_group_id)
from vllm.v1.core.single_type_kv_cache_manager import (
ChunkedLocalAttentionManager, SlidingWindowManager)
from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
@ -44,7 +44,7 @@ def test_chunked_local_attention_possible_cached_prefix():
def run_one_case(block_is_cached, tail_token, expect_length):
block_hash_list = [
BlockHash(i, ()) for i in range(len(block_is_cached))
BlockHash(str(i).encode()) for i in range(len(block_is_cached))
]
block_pool.cached_block_hash_to_block.clear()
@ -53,8 +53,8 @@ def test_chunked_local_attention_possible_cached_prefix():
for i, (block_hash,
is_cached) in enumerate(zip(block_hash_list, block_is_cached)):
if is_cached:
block_pool.cached_block_hash_to_block[BlockHashWithGroupId(
block_hash, 0)] = {
block_pool.cached_block_hash_to_block[
make_block_hash_with_group_id(block_hash, 0)] = {
i: block_pool.blocks[i + 10],
}
@ -109,7 +109,7 @@ def test_sliding_window_possible_cached_prefix():
def run_one_case(block_is_cached, expect_length):
block_hash_list = [
BlockHash(i, ()) for i in range(len(block_is_cached))
BlockHash(str(i).encode()) for i in range(len(block_is_cached))
]
block_pool.cached_block_hash_to_block.clear()
@ -118,8 +118,8 @@ def test_sliding_window_possible_cached_prefix():
for i, (block_hash,
is_cached) in enumerate(zip(block_hash_list, block_is_cached)):
if is_cached:
block_pool.cached_block_hash_to_block[BlockHashWithGroupId(
block_hash, 0)] = {
block_pool.cached_block_hash_to_block[
make_block_hash_with_group_id(block_hash, 0)] = {
i: block_pool.blocks[i + 10],
}

View File

@ -9,6 +9,7 @@ from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
from vllm.multimodal.inputs import (MultiModalFeatureSpec,
MultiModalKwargsItem, PlaceholderRange)
from vllm.sampling_params import SamplingParams
from vllm.utils import sha256
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
init_none_hash)
from vllm.v1.core.sched.async_scheduler import AsyncScheduler
@ -130,10 +131,10 @@ def create_requests(
) -> list[Request]:
global _none_hash_initialized
if not _none_hash_initialized:
init_none_hash(hash)
init_none_hash(sha256)
_none_hash_initialized = True
block_hasher = get_request_block_hasher(block_size, hash)
block_hasher = get_request_block_hasher(block_size, sha256)
sampling_params = SamplingParams(ignore_eos=False,
max_tokens=max_tokens,
stop_token_ids=stop_token_ids,

View File

@ -36,18 +36,19 @@ def test_prefix_caching_from_cli():
assert vllm_config.cache_config.enable_prefix_caching
# default hash algorithm is "builtin"
assert vllm_config.cache_config.prefix_caching_hash_algo == "builtin"
assert vllm_config.cache_config.prefix_caching_hash_algo == "sha256"
# set hash algorithm to sha256_cbor
args = parser.parse_args(["--prefix-caching-hash-algo", "sha256_cbor"])
vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
assert vllm_config.cache_config.prefix_caching_hash_algo == \
"sha256_cbor"
# set hash algorithm to sha256
args = parser.parse_args(["--prefix-caching-hash-algo", "sha256"])
vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
assert vllm_config.cache_config.prefix_caching_hash_algo == "sha256"
# set hash algorithm to builtin
args = parser.parse_args(["--prefix-caching-hash-algo", "builtin"])
vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
assert vllm_config.cache_config.prefix_caching_hash_algo == "builtin"
# an invalid hash algorithm raises an error
parser.exit_on_error = False
with pytest.raises(ArgumentError):

View File

@ -13,6 +13,7 @@ from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa
SharedStorageConnector)
from vllm.utils import sha256
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
init_none_hash)
@ -127,11 +128,11 @@ def create_request(request_id: int,
use_all_1s_for_prompt_tokens: bool = False,
num_remote_blocks: int = 3,
block_size: int = 16,
hash_fn: Callable = hash) -> Request:
hash_fn: Callable = sha256) -> Request:
"""Make dummy request for testing."""
global _none_hash_initialized
if not _none_hash_initialized:
init_none_hash(hash)
init_none_hash(hash_fn)
_none_hash_initialized = True
kv_transfer_params: Optional[dict[str, Any]] = None

View File

@ -24,7 +24,7 @@ logger = init_logger(__name__)
BlockSize = Literal[1, 8, 16, 32, 64, 128]
CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"]
MambaDType = Literal["auto", "float32"]
PrefixCachingHashAlgo = Literal["builtin", "sha256", "sha256_cbor_64bit"]
PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor"]
@config
@ -63,17 +63,12 @@ class CacheConfig:
"""Sliding window size for the KV cache. This is primarily set in
`ModelConfig` and that value should be manually duplicated here."""
enable_prefix_caching: Optional[bool] = None
"""Whether to enable prefix caching. Disabled by default for V0. Enabled by
default for V1."""
prefix_caching_hash_algo: PrefixCachingHashAlgo = "builtin"
"""Whether to enable prefix caching. Enabled by default for V1."""
prefix_caching_hash_algo: PrefixCachingHashAlgo = "sha256"
"""Set the hash algorithm for prefix caching:\n
- "builtin" is Python's built-in hash.\n
- "sha256" is collision resistant but with certain overheads.
This option uses Pickle for object serialization before hashing.\n
- "sha256_cbor_64bit" provides a reproducible, cross-language compatible
hash. It serializes objects using canonical CBOR and hashes them with
SHA-256. The resulting hash consists of the lower 64 bits of the SHA-256
digest."""
- "sha256" uses Pickle for object serialization before hashing.\n
- "sha256_cbor" provides a reproducible, cross-language compatible hash. It
serializes objects using canonical CBOR and hashes them with SHA-256."""
cpu_offload_gb: float = 0
"""The space in GiB to offload to CPU, per GPU. Default is 0, which means
no offloading. Intuitively, this argument can be seen as a virtual way to

View File

@ -16,6 +16,7 @@ import zmq
from vllm.config.kv_events import KVEventsConfig
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_utils import ExternalBlockHash
logger = init_logger(__name__)
@ -44,8 +45,8 @@ MEDIUM_GPU = "GPU"
class BlockStored(KVCacheEvent):
block_hashes: list[int]
parent_block_hash: Optional[int]
block_hashes: list[ExternalBlockHash]
parent_block_hash: Optional[ExternalBlockHash]
token_ids: list[int]
block_size: int
lora_id: Optional[int]
@ -53,7 +54,7 @@ class BlockStored(KVCacheEvent):
class BlockRemoved(KVCacheEvent):
block_hashes: list[int]
block_hashes: list[ExternalBlockHash]
medium: Optional[str]

View File

@ -1592,20 +1592,12 @@ class EngineArgs:
"in low performance due to small KV cache size. Consider "
"setting --max-model-len to a smaller value.", max_model_len)
# if using prefix caching, we must set a hash algo
if self.enable_prefix_caching:
# Disable prefix caching for multimodal models for VLLM_V0.
if model_config.is_multimodal_model:
logger.warning(
"--enable-prefix-caching is not supported for multimodal "
"models in V0 and has been disabled.")
self.enable_prefix_caching = False
# VLLM_V0 only supports builtin hash algo for prefix caching.
if self.prefix_caching_hash_algo == "sha256":
raise ValueError(
"sha256 is not supported for prefix caching in V0 engine. "
"Please use 'builtin'.")
# Disable prefix caching for multimodal models for VLLM_V0.
if self.enable_prefix_caching and model_config.is_multimodal_model:
logger.warning(
"--enable-prefix-caching is not supported for multimodal "
"models in V0 and has been disabled.")
self.enable_prefix_caching = False
# Set max_num_seqs to 256 for VLLM_V0.
if self.max_num_seqs is None:

View File

@ -171,6 +171,7 @@ if TYPE_CHECKING:
VLLM_GPT_OSS_USE_CONTAINER_TOOL: bool = False
VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False
VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False
VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES: bool = True
def get_default_cache_root():
@ -1215,6 +1216,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
# Add optional custom scopes for profiling, disable to avoid overheads
"VLLM_CUSTOM_SCOPES_FOR_PROFILING":
lambda: bool(int(os.getenv("VLLM_CUSTOM_SCOPES_FOR_PROFILING", "0"))),
# Represent block hashes in KV cache events as 64-bit integers instead of
# raw bytes. Defaults to True for backward compatibility.
"VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES":
lambda: bool(int(os.getenv("VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES", "1"))),
}
# --8<-- [end:env-vars-definition]

View File

@ -3249,7 +3249,7 @@ def check_use_alibi(model_config: ModelConfig) -> bool:
and getattr(cfg.attn_config, "alibi", False)))))
def sha256(input) -> int:
def sha256(input) -> bytes:
"""Hash any picklable Python object using SHA-256.
The input is serialized using pickle before hashing, which allows
@ -3260,16 +3260,15 @@ def sha256(input) -> int:
input: Any picklable Python object.
Returns:
An integer representing the SHA-256 hash of the serialized input.
Bytes representing the SHA-256 hash of the serialized input.
"""
input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
return int.from_bytes(hashlib.sha256(input_bytes).digest(),
byteorder="big")
return hashlib.sha256(input_bytes).digest()
def sha256_cbor_64bit(input) -> int:
def sha256_cbor(input) -> bytes:
"""
Hash objects using CBOR serialization and SHA-256, then truncate to 64bits.
Hash objects using CBOR serialization and SHA-256.
This option is useful for non-Python-dependent serialization and hashing.
@ -3280,17 +3279,13 @@ def sha256_cbor_64bit(input) -> int:
Custom classes must implement CBOR serialization methods.
Returns:
An integer in the range [0, 2^64-1] representing the lower 64 bits
of the SHA-256 hash of the CBOR serialized input.
Bytes representing the SHA-256 hash of the CBOR serialized input.
"""
input_bytes = cbor2.dumps(input, canonical=True)
full_hash = int.from_bytes(hashlib.sha256(input_bytes).digest(),
byteorder="big")
return full_hash & ((1 << 64) - 1)
return hashlib.sha256(input_bytes).digest()
def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], int]:
def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]:
"""Get a hash function by name, or raise an error if
the function is not found.
Args:
@ -3300,10 +3295,8 @@ def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], int]:
"""
if hash_fn_name == "sha256":
return sha256
if hash_fn_name == "sha256_cbor_64bit":
return sha256_cbor_64bit
if hash_fn_name == "builtin":
return hash
if hash_fn_name == "sha256_cbor":
return sha256_cbor
raise ValueError(f"Unsupported hash function: {hash_fn_name}")

View File

@ -9,7 +9,11 @@ from vllm.distributed.kv_events import (MEDIUM_GPU, AllBlocksCleared,
KVCacheEvent)
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
FreeKVCacheBlockQueue, KVCacheBlock)
ExternalBlockHash,
FreeKVCacheBlockQueue, KVCacheBlock,
get_block_hash,
make_block_hash_with_group_id,
maybe_convert_block_hash)
from vllm.v1.request import Request
logger = init_logger(__name__)
@ -84,8 +88,10 @@ class BlockPool:
"""
cached_blocks = []
for group_id in kv_cache_group_ids:
block_hash_with_group_id = make_block_hash_with_group_id(
block_hash, group_id)
cached_blocks_one_group = self.cached_block_hash_to_block.get(
BlockHashWithGroupId(block_hash, group_id))
block_hash_with_group_id)
if not cached_blocks_one_group:
return None
first_block = next(iter(cached_blocks_one_group.values()))
@ -124,28 +130,29 @@ class BlockPool:
assert len(request.block_hashes) >= num_full_blocks
new_block_hashes = request.block_hashes[num_cached_blocks:]
new_hashes: Optional[list[int]] = ([] if self.enable_kv_cache_events
else None)
new_hashes: Optional[list[ExternalBlockHash]] = (
[] if self.enable_kv_cache_events else None)
for i, blk in enumerate(new_full_blocks):
assert blk.block_hash is None
block_hash = new_block_hashes[i]
# Update and added the full block to the cache.
block_hash_with_group_id = BlockHashWithGroupId(
block_hash_with_group_id = make_block_hash_with_group_id(
block_hash, kv_cache_group_id)
blk.block_hash = block_hash_with_group_id
self.cached_block_hash_to_block[block_hash_with_group_id][
blk.block_id] = blk
if new_hashes is not None:
new_hashes.append(block_hash.hash_value)
new_hashes.append(maybe_convert_block_hash(block_hash))
if self.enable_kv_cache_events:
if num_cached_blocks == 0:
parent_block_hash = None
parent_block_hash: Optional[ExternalBlockHash] = None
else:
parent_block = blocks[num_cached_blocks - 1]
assert parent_block.block_hash is not None
parent_block_hash = parent_block.block_hash.get_hash_value()
parent_block_hash = maybe_convert_block_hash(
get_block_hash(parent_block.block_hash))
self.kv_event_queue.append(
BlockStored(
@ -220,7 +227,9 @@ class BlockPool:
# we disable hybrid kv cache manager when kv cache event is
# enabled, so there is only one group.
self.kv_event_queue.append(
BlockRemoved(block_hashes=[block_hash.get_hash_value()],
BlockRemoved(block_hashes=[
maybe_convert_block_hash(get_block_hash(block_hash))
],
medium=MEDIUM_GPU))
return True

View File

@ -6,11 +6,12 @@ import os
from collections import defaultdict, deque
from collections.abc import Iterable, Sequence
from dataclasses import astuple, dataclass
from typing import Any, Callable, NamedTuple, Optional
from typing import Any, Callable, NewType, Optional, Union
from vllm import envs
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils import GiB_bytes, cdiv, sha256_cbor_64bit
from vllm.utils import GiB_bytes, cdiv, sha256_cbor
from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheSpec,
@ -18,59 +19,78 @@ from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request
logger = init_logger(__name__)
# BlockHash represents the hash of a single KV-cache block used for
# prefix caching. Treating it as a distinct type from ``bytes`` helps
# catch accidental misuse when passing around raw byte strings.
BlockHash = NewType("BlockHash", bytes)
# ``BlockHashWithGroupId`` combines a ``BlockHash`` with its KV cache group ID.
# It is represented as raw bytes for compactness and efficiency. The helper
# functions below pack/unpack the ``BlockHash`` and group id into/from the key.
BlockHashWithGroupId = NewType("BlockHashWithGroupId", bytes)
# ExternalBlockHash is used for reproducible prefix-cache block hashing.
# It's a union of ``bytes`` and ``int`` to keep backward compatibility
# after we default block hashing to use sha256 bytes.
ExternalBlockHash = Union[bytes, int]
class BlockHash(NamedTuple):
"""Hash value of a block (int), the token IDs in the block, and extra keys.
We keep a tuple of token IDs and extra keys to reduce the likelihood of
hash collisions when the hash value is the same. By using SHA256 however,
hash collisions are practically impossible.
def make_block_hash_with_group_id(block_hash: BlockHash,
group_id: int) -> BlockHashWithGroupId:
"""Pack a ``BlockHash`` and group id into a ``BlockHashWithGroupId``.
The group id is encoded using 4 bytes in big-endian order and appended to
the block hash bytes. This representation avoids creating tuples while
still allowing us to recover both components when needed.
"""
# Hash value of the block in an integer.
hash_value: int
# Token IDs in the block.
token_ids: tuple[int, ...]
# Extra keys for the block.
extra_keys: Optional[Any] = None
return BlockHashWithGroupId(block_hash +
group_id.to_bytes(4, "big", signed=False))
class BlockHashWithGroupId(NamedTuple):
# The hash value for the contents (e.g., token_ids) of a block without group
# ID. The value is the same for blocks representing the same tokens but for
# different groups.
block_hash: BlockHash
# The KV cache group ID.
group_id: int
def get_block_hash(key: BlockHashWithGroupId) -> BlockHash:
"""Extract the ``BlockHash`` from a ``BlockHashWithGroupId``."""
return BlockHash(key[:-4])
def get_hash_value(self) -> int:
return self.block_hash.hash_value
def get_group_id(key: BlockHashWithGroupId) -> int:
"""Extract the group id from a ``BlockHashWithGroupId``."""
return int.from_bytes(key[-4:], "big", signed=False)
def maybe_convert_block_hash(hash_bytes: BlockHash) -> ExternalBlockHash:
if not envs.VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES:
return hash_bytes
return int.from_bytes(hash_bytes, byteorder="big") & ((1 << 64) - 1)
logger = init_logger(__name__)
# The hash seed for the first block of any prefix block sequence.
#
# We use a random value to avoid hash collisions or PYTHONHASHSEED environment
# variable if set such that processes can share the seed if needed.
# This aligns with the behavior of Python's hash() function, which also uses
# a random seed if PYTHONHASHSEED is not set.
# variable if set such that processes can share the seed if needed. This aligns
# with the behavior of Python's hash() function, which also uses a random seed
# if PYTHONHASHSEED is not set.
#
# The function `init_none_hash` initializes this variable globally.
NONE_HASH: int
NONE_HASH: BlockHash
def init_none_hash(hash_fn: Callable):
def init_none_hash(hash_fn: Callable[[Any], bytes]):
global NONE_HASH
hash_seed = os.getenv("PYTHONHASHSEED")
if hash_seed is None and hash_fn is sha256_cbor_64bit:
if hash_seed is None and hash_fn is sha256_cbor:
logger.warning(
"PYTHONHASHSEED is not set. This will lead to non-reproducible "
"block-hashes when using sha256_cbor_64bit as the hash function."
"block-hashes when using sha256_cbor as the hash function."
"Consider setting PYTHONHASHSEED to a fixed value for "
"reproducibility.")
NONE_HASH = (int.from_bytes(os.urandom(32), byteorder="big")
if hash_seed is None else hash_fn(hash_seed))
if hash_seed is None:
NONE_HASH = BlockHash(os.urandom(32))
else:
NONE_HASH = BlockHash(hash_fn(hash_seed))
class PrefixCachingMetrics:
@ -142,8 +162,8 @@ class KVCacheBlock:
block_id: int
# Reference count.
ref_cnt: int = 0
# The hash of the block composed of (block hash, tuple of token IDs).
# It is only available when the block is full.
# The hash key (block hash + group id) of the block, only available
# when the block is full and cached.
_block_hash: Optional[BlockHashWithGroupId] = None
# Used to construct a doubly linked list for free blocks.
@ -177,7 +197,7 @@ class KVCacheBlock:
if self.next_free_block else None)
return (f"KVCacheBlock(block_id={self.block_id}, "
f"ref_cnt={self.ref_cnt}, "
f"_block_hash={self._block_hash}, "
f"_block_hash={self._block_hash!r}, "
f"prev_free_block={prev_block_id}, "
f"next_free_block={next_block_id})")
@ -517,15 +537,14 @@ def generate_block_hash_extra_keys(
def hash_block_tokens(
hash_function: Callable,
parent_block_hash: Optional[int],
hash_function: Callable[[Any], bytes],
parent_block_hash: Optional[BlockHash],
curr_block_token_ids: Sequence[int],
extra_keys: Optional[tuple[Any, ...]] = None) -> BlockHash:
"""Computes a hash value corresponding to the contents of a block and
the contents of the preceding block(s). The hash value is used for
prefix caching. We use LRU cache for this function to avoid recomputing
hash values for the same block contents.
Args:
hash_function: The hash function used to compute block hash.
parent_block_hash: The hash of the parent block. None
@ -533,7 +552,6 @@ def hash_block_tokens(
curr_block_token_ids: A list of token ids in the current
block. The current block is assumed to be full.
extra_keys: Extra keys for the block.
Returns:
The hash value of the block and the token ids in the block.
The entire tuple is used as the hash key of the block.
@ -544,26 +562,16 @@ def hash_block_tokens(
curr_block_token_ids_tuple = tuple(curr_block_token_ids)
return BlockHash(
hash_function(
(parent_block_hash, curr_block_token_ids_tuple, extra_keys)),
curr_block_token_ids_tuple, extra_keys)
(parent_block_hash, curr_block_token_ids_tuple, extra_keys)))
def get_request_block_hasher(
block_size: int,
caching_hash_fn: Callable[[Any],
int]) -> Callable[[Request], list[BlockHash]]:
caching_hash_fn: Callable[[Any], bytes],
) -> Callable[[Request], list[BlockHash]]:
"""
Returns a function which computes the list of un-computed block hashes
of a request.
Each request holds a list of its block hashes (request.block_hashes).
When a request is created, it calls the below function to compute
the hashes of all full blocks of the request's initial tokens.
The hashes are then stored in request.block_hashes.
Later, whenever new tokens are appended to the request, it calls
the below function again to compute any new full blocks of tokens.
The returned new hashes are appended to request.block_hashes.
"""
of a request."""
def request_block_hasher(request: Request) -> list[BlockHash]:
start_token_idx = len(request.block_hashes) * block_size
@ -577,8 +585,8 @@ def get_request_block_hasher(
# last mm input.
curr_mm_idx = -1
prev_block_hash_value = request.block_hashes[-1].hash_value \
if request.block_hashes else None
prev_block_hash_value = (request.block_hashes[-1]
if request.block_hashes else None)
new_block_hashes: list[BlockHash] = []
while True:
end_token_idx = start_token_idx + block_size
@ -598,7 +606,7 @@ def get_request_block_hasher(
new_block_hashes.append(block_hash)
start_token_idx += block_size
prev_block_hash_value = block_hash.hash_value
prev_block_hash_value = block_hash
return new_block_hashes