[v1] Redo "Support multiple KV cache groups in GPU model runner (#17945)" (#18593)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Author: Chen Zhang
Date: 2025-05-24 00:39:47 +08:00
Committed by: GitHub
Parent: 9520a989df
Commit: 6550114c9c
15 changed files with 469 additions and 203 deletions

View File

@ -19,7 +19,8 @@ from vllm.v1.core.kv_cache_utils import (FreeKVCacheBlockQueue, KVCacheBlock,
hash_request_tokens,
unify_kv_cache_configs)
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheTensor)
KVCacheGroupSpec, KVCacheTensor,
SlidingWindowSpec)
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request
@ -54,12 +55,14 @@ def new_kv_cache_spec(block_size=16,
num_kv_heads=2,
head_size=64,
dtype=torch.float32,
use_mla=False):
use_mla=False,
sliding_window=None):
return FullAttentionSpec(block_size=block_size,
num_kv_heads=num_kv_heads,
head_size=head_size,
dtype=dtype,
use_mla=use_mla)
use_mla=use_mla,
sliding_window=sliding_window)
def test_none_hash(monkeypatch):
@ -492,6 +495,68 @@ def test_unify_kv_cache_configs():
unify_kv_cache_configs(diff_kv_cache_config)
def test_merge_kv_cache_spec():
same_layer_specs = [
new_kv_cache_spec(num_kv_heads=32),
new_kv_cache_spec(num_kv_heads=32),
]
merged_layer_spec = same_layer_specs[0].merge(same_layer_specs)
assert merged_layer_spec.block_size == 16
assert merged_layer_spec.num_kv_heads == 32
assert merged_layer_spec.head_size == 64
assert merged_layer_spec.dtype == torch.float32
assert merged_layer_spec.sliding_window is None
different_layer_specs = [
new_kv_cache_spec(num_kv_heads=32),
new_kv_cache_spec(num_kv_heads=16),
]
with pytest.raises(AssertionError):
different_layer_specs[0].merge(different_layer_specs)
full_spec = new_kv_cache_spec(num_kv_heads=32)
different_type_layer_specs = [
full_spec,
SlidingWindowSpec(
block_size=full_spec.block_size,
num_kv_heads=full_spec.num_kv_heads,
head_size=full_spec.head_size,
dtype=full_spec.dtype,
use_mla=full_spec.use_mla,
sliding_window=1,
),
]
with pytest.raises(AssertionError):
different_type_layer_specs[0].merge(different_type_layer_specs)
with pytest.raises(AssertionError):
different_type_layer_specs[1].merge(different_type_layer_specs)
different_sliding_window_layer_specs = [
new_kv_cache_spec(num_kv_heads=32),
new_kv_cache_spec(num_kv_heads=32, sliding_window=1),
new_kv_cache_spec(num_kv_heads=32, sliding_window=2),
]
with pytest.raises(ValueError):
different_sliding_window_layer_specs[0].merge(
different_sliding_window_layer_specs)
same_sliding_window_layer_specs = [
new_kv_cache_spec(num_kv_heads=32, sliding_window=1),
new_kv_cache_spec(num_kv_heads=32, sliding_window=1),
]
merged_layer_spec = same_sliding_window_layer_specs[0].merge(
same_sliding_window_layer_specs)
assert merged_layer_spec.sliding_window == 1
same_sliding_window_layer_spec_with_none = [
new_kv_cache_spec(num_kv_heads=32, sliding_window=1),
new_kv_cache_spec(num_kv_heads=32, sliding_window=None),
]
merged_layer_spec = same_sliding_window_layer_spec_with_none[0].merge(
same_sliding_window_layer_spec_with_none)
assert merged_layer_spec.sliding_window == 1
@pytest.mark.parametrize(
("model_id", "max_model_len", "want_estimated_max_len"), [
("Qwen/Qwen1.5-7B", 16385, 16384),

View File

@ -84,7 +84,7 @@ def test_prefill(hash_algo):
blocks = manager.allocate_slots(req0, 55,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [1, 2, 3, 4]
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
# Check full block metadata
parent_block_hash = None
@ -107,13 +107,13 @@ def test_prefill(hash_algo):
req1 = make_request("1", common_token_ids + unique_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
assert computed_blocks.get_block_ids() == [1, 2, 3]
assert computed_blocks.get_block_ids() == [[1, 2, 3]]
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req1, num_new_tokens,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [5]
assert blocks.get_block_ids() == [[5]]
for block in computed_blocks.blocks:
assert block.ref_cnt == 2
@ -141,13 +141,13 @@ def test_prefill(hash_algo):
req2 = make_request("2", common_token_ids + unique_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert len(manager.req_to_block_hashes[req2.request_id]) == 3
assert computed_blocks.get_block_ids() == [1, 2, 3]
assert computed_blocks.get_block_ids() == [[1, 2, 3]]
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req2, num_new_tokens,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [6]
assert blocks.get_block_ids() == [[6]]
# Although we only have 6 free blocks, we have 8 blocks in
# the free block queue due to lazy removal.
@ -171,7 +171,7 @@ def test_prefill(hash_algo):
len(computed_blocks.blocks) * 16,
computed_blocks)
# This block ID order also checks the eviction order.
assert blocks.get_block_ids() == [7, 8, 9, 10, 4, 5, 6, 3, 2, 1]
assert blocks.get_block_ids() == [[7, 8, 9, 10, 4, 5, 6, 3, 2, 1]]
assert manager.block_pool.free_block_queue.num_free_blocks == 0
assert manager.block_pool.free_block_queue.free_list_head is None
assert manager.block_pool.free_block_queue.free_list_tail is None
@ -208,7 +208,7 @@ def test_prefill_plp():
blocks = manager.allocate_slots(req0, 55,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [1, 2, 3, 4]
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
req0_block_hashes = [b.block_hash for b in blocks.blocks]
# Check full block metadata
@ -233,13 +233,13 @@ def test_prefill_plp():
req1 = make_request("1", common_token_ids + unique_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
assert computed_blocks.get_block_ids() == [1, 2, 3]
assert computed_blocks.get_block_ids() == [[1, 2, 3]]
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req1, num_new_tokens,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [5]
assert blocks.get_block_ids() == [[5]]
for block in computed_blocks.blocks:
assert block.ref_cnt == 2
@ -277,11 +277,11 @@ def test_prefill_plp():
block_ids = blocks.get_block_ids()
# Duplicate cached blocks have different ids but same hashes vs request #0
assert [b.block_hash for b in blocks.blocks] == req0_block_hashes
assert block_ids != [1, 2, 3, 4]
assert block_ids != [[1, 2, 3, 4]]
# Request #2 block hashes are valid since request #0 hashes are.
# Check block reference counts.
for block_id in block_ids:
for block_id in block_ids[0]:
assert manager.block_pool.blocks[block_id].ref_cnt == 1
manager.free(req2)
@ -307,7 +307,7 @@ def test_decode():
blocks = manager.allocate_slots(req0, 55,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [1, 2, 3, 4]
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
# Append slots without allocating a new block.
req0.num_computed_tokens = 55
@ -379,12 +379,12 @@ def test_evict():
# Touch the first 2 blocks.
req2 = make_request("2", list(range(2 * 16 + 3)))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert computed_blocks.get_block_ids() == [1, 2]
assert computed_blocks.get_block_ids() == [[1, 2]]
assert num_computed_tokens == 2 * 16
blocks = manager.allocate_slots(req2, 3,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [10]
assert blocks.get_block_ids() == [[10]]
assert manager.block_pool.free_block_queue.num_free_blocks == 7
@ -625,7 +625,7 @@ def test_mm_prefix_caching():
blocks = manager.allocate_slots(req0, 59,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [1, 2, 3, 4]
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
req0.num_computed_tokens = 59
# Append slots without allocating a new block.
@ -686,7 +686,7 @@ def test_cache_key_salting():
blocks = manager.allocate_slots(req0, 59,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [1, 2, 3, 4]
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
req0.num_computed_tokens = 59
# Append slots without allocating a new block.
@ -797,7 +797,7 @@ def test_reset_prefix_cache():
all_token_ids = full_block_token_ids + unique_token_ids
req0 = make_request("0", all_token_ids)
blocks = manager.allocate_slots(req0, 55)
assert blocks.get_block_ids() == [1, 2, 3, 4]
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
unique_token_ids = [4] * 7
all_token_ids = full_block_token_ids + unique_token_ids
@ -808,7 +808,7 @@ def test_reset_prefix_cache():
blocks = manager.allocate_slots(req1, 7,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [5]
assert blocks.get_block_ids() == [[5]]
# Failed to reset prefix cache because some blocks are not freed yet.
assert not manager.reset_prefix_cache()

View File

@ -9,9 +9,11 @@ import torch
from vllm.sampling_params import SamplingParams
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheTensor)
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu_input_batch import (BlockTable, CachedRequestState,
InputBatch)
from vllm.v1.worker.block_table import BlockTable, MultiGroupBlockTable
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
VOCAB_SIZE = 1024
NUM_OUTPUT_TOKENS = 20
@ -22,6 +24,27 @@ CUDA_DEVICES = [
MAX_NUM_PROMPT_TOKENS = 64
def get_kv_cache_config() -> KVCacheConfig:
return KVCacheConfig(
num_blocks=10,
tensors={
"layer.0": KVCacheTensor(size=1024),
},
kv_cache_groups=[
KVCacheGroupSpec(
layer_names=["layer.0"],
kv_cache_spec=FullAttentionSpec(
block_size=1,
num_kv_heads=1,
head_size=16,
dtype=torch.float16,
use_mla=False,
),
),
],
)
def _compare_objs(obj1, obj2):
attrs = inspect.getmembers(obj1, lambda a: not (inspect.isroutine(a)))
attr_names = set([
@ -41,6 +64,10 @@ def _compare_objs(obj1, obj2):
elif isinstance(a, np.ndarray):
if np.allclose(a, b):
is_same = True
elif isinstance(a, MultiGroupBlockTable):
for a_i, b_i in zip(a.block_tables, b.block_tables):
_compare_objs(a_i, b_i)
is_same = True
elif isinstance(a, (BlockTable, SamplingMetadata)):
_compare_objs(a, b)
is_same = True # if we make it here must be same
@ -198,7 +225,7 @@ def _construct_cached_request_state(req_id_suffix: int):
sampling_params=_create_sampling_params(),
mm_inputs=[],
mm_positions=[],
block_ids=[],
block_ids=[[]],
generator=None,
num_computed_tokens=len(output_token_ids),
output_token_ids=output_token_ids,
@ -220,11 +247,11 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
input_batch: InputBatch = InputBatch(
max_num_reqs=batch_size,
max_model_len=1024,
max_num_blocks_per_req=10,
max_num_batched_tokens=1024,
device=torch.device(device),
pin_memory=is_pin_memory_available(),
vocab_size=1024,
block_size=1,
)
reqs: list[CachedRequestState] = []
req_id_reqs = {}
@ -310,20 +337,20 @@ def test_swap_states_in_input_batch(device: str, batch_size: int,
input_batch: InputBatch = InputBatch(
max_num_reqs=batch_size,
max_model_len=1024,
max_num_blocks_per_req=10,
max_num_batched_tokens=1024,
device=torch.device(device),
pin_memory=is_pin_memory_available(),
vocab_size=1024,
block_size=1,
)
ref_input_batch: InputBatch = InputBatch(
max_num_reqs=batch_size,
max_model_len=1024,
max_num_blocks_per_req=10,
max_num_batched_tokens=1024,
device=torch.device(device),
pin_memory=is_pin_memory_available(),
vocab_size=1024,
block_size=1,
)
reqs: list[CachedRequestState] = []

View File

@ -1,15 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
import weakref
import pytest
import torch
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
SchedulerConfig, VllmConfig)
from vllm.sampling_params import SamplingParams
from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
SchedulerOutput)
from vllm.v1.kv_cache_interface import FullAttentionSpec
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheTensor)
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
@ -17,13 +18,34 @@ def initialize_kv_cache(runner: GPUModelRunner):
"""
Only perform necessary steps in GPUModelRunner.initialize_kv_cache()
"""
kv_cache_spec = FullAttentionSpec(block_size=16,
num_kv_heads=1,
head_size=64,
dtype=torch.float16,
use_mla=False)
runner.attn_metadata_builder = runner.attn_backend.get_builder_cls()(
weakref.proxy(runner), kv_cache_spec, runner.input_batch.block_table)
kv_cache_config = KVCacheConfig(
num_blocks=10,
tensors={
"layer.0": KVCacheTensor(size=1024),
},
kv_cache_groups=[
KVCacheGroupSpec(
layer_names=["layer.0"],
kv_cache_spec=FullAttentionSpec(
block_size=16,
num_kv_heads=runner.model_config.get_num_kv_heads(
runner.parallel_config),
head_size=runner.model_config.get_head_size(),
dtype=runner.kv_cache_dtype,
use_mla=False,
))
])
runner.kv_cache_config = kv_cache_config
runner.input_batch = InputBatch(
max_num_reqs=runner.max_num_reqs,
max_model_len=runner.max_model_len,
max_num_batched_tokens=runner.max_num_tokens,
device=runner.device,
pin_memory=runner.pin_memory,
vocab_size=runner.model_config.get_vocab_size(),
block_size=kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size,
)
runner.initialize_attn_backend(kv_cache_config)
@pytest.fixture
@ -48,10 +70,12 @@ def model_runner():
swap_space=0,
cache_dtype="auto",
)
parallel_config = ParallelConfig()
vllm_config = VllmConfig(
model_config=model_config,
cache_config=cache_config,
scheduler_config=scheduler_config,
parallel_config=parallel_config,
)
device = "cuda"
@ -73,7 +97,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
mm_hashes=[],
mm_positions=[],
sampling_params=SamplingParams(),
block_ids=[0],
block_ids=[[0]],
num_computed_tokens=0,
lora_request=None,
))
@ -111,13 +135,14 @@ def _is_sampling_metadata_changed(model_runner,
def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
req_index = model_runner.input_batch.req_id_to_index[req_id]
block_table = model_runner.input_batch.block_table
block_table = model_runner.input_batch.block_table[0]
req_state = model_runner.requests[req_id]
if block_table.num_blocks_per_row[req_index] != len(req_state.block_ids):
if block_table.num_blocks_per_row[req_index] != len(
req_state.block_ids[0]):
return False
num_blocks = block_table.num_blocks_per_row[req_index]
return (block_table.block_table_np[req_index, :num_blocks] ==
req_state.block_ids).all()
req_state.block_ids[0]).all()
def test_update_states_new_request(model_runner):
@ -200,7 +225,7 @@ def test_update_states_request_resumed(model_runner):
req_id=req_id,
resumed_from_preemption=False,
new_token_ids=[],
new_block_ids=[],
new_block_ids=[[]],
num_computed_tokens=0,
)

View File

@ -288,7 +288,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
for new_req in scheduler_output.scheduled_new_reqs:
if new_req.req_id in self._requests_need_load:
meta.add_request(token_ids=new_req.prompt_token_ids,
block_ids=new_req.block_ids,
block_ids=new_req.block_ids[0],
block_size=self._block_size,
is_store=False)
total_need_load += 1
@ -299,7 +299,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
# the original prompt tokens.
if not self._found_match_for_request(new_req):
meta.add_request(token_ids=new_req.prompt_token_ids,
block_ids=new_req.block_ids,
block_ids=new_req.block_ids[0],
block_size=self._block_size,
is_store=True)
@ -319,7 +319,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
# NOTE(rob): For resumed req, new_block_ids is all
# of the block_ids for the request.
block_ids = cached_req.new_block_ids
block_ids = cached_req.new_block_ids[0]
meta.add_request(token_ids=token_ids,
block_ids=block_ids,

View File

@ -69,13 +69,13 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
max_model_len = self.runner.model_config.max_model_len
assert max_model_len == 32768,\
"AITER MLA requires max_model_len=32768"
assert self.runner.block_size == 1, "AITER MLA" \
assert self.kv_cache_spec.block_size == 1, "AITER MLA" \
"only supports block size 1."
def _get_paged_kv_tensors(
self, block_table: torch.Tensor,
seq_lens: torch.Tensor) -> tuple[torch.Tensor, ...]:
page_size = self.runner.block_size
page_size = self.kv_cache_spec.block_size
block_table_bounds = (seq_lens + page_size - 1) // page_size
device = self.runner.device

View File

@ -32,9 +32,16 @@ class KVCacheBlocks:
"""Creates a new KVCacheBlocks instance with no blocks."""
return cls([])
def get_block_ids(self) -> list[int]:
"""Converts the KVCacheBlocks instance to a list of block IDs."""
return [block.block_id for block in self.blocks]
def get_block_ids(self) -> list[list[int]]:
"""
Converts the KVCacheBlocks instance to block_ids.
Returns:
list[list[int]]: A two-level list where
* the outer list corresponds to KV cache groups (only 1 group now)
* each inner list contains the block_ids of the blocks in that group
"""
return [[block.block_id for block in self.blocks]]
def get_unhashed_block_ids(self) -> list[int]:
"""Get block_ids of unhashed blocks from KVCacheBlocks instance."""
@ -300,9 +307,9 @@ class KVCacheManager:
self,
request: Request,
num_running_requests: int,
) -> int:
) -> list[int]:
"""Calculate the number of common prefix blocks shared by all requests
in the RUNNING state.
in the RUNNING state for each kv cache group.
The function determines this by selecting any request and iterating
through its blocks. A block is considered a common prefix block if its
@ -332,11 +339,14 @@ class KVCacheManager:
requests in the current step.
Returns:
int: The number of common prefix blocks.
list[int]: The number of common prefix blocks for each kv cache
group.
"""
assert request.status == RequestStatus.RUNNING
return self.single_type_manager.get_num_common_prefix_blocks(
request.request_id, num_running_requests)
return [
self.single_type_manager.get_num_common_prefix_blocks(
request.request_id, num_running_requests)
]
def free_block_hashes(self, request: Request) -> None:
"""Discard the block hashes for the request.
@ -354,10 +364,8 @@ class KVCacheManager:
"""
return self.block_pool.take_events()
def get_block_ids(self, request_id: str) -> list[int]:
def get_block_ids(self, request_id: str) -> list[list[int]]:
"""Get the block ids of a request."""
assert request_id in self.single_type_manager.req_to_blocks
return [
block.block_id
for block in self.single_type_manager.req_to_blocks[request_id]
]
return KVCacheBlocks(self.single_type_manager.req_to_blocks[request_id]
).get_block_ids()
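As an annotation on the shape change above: block IDs are now grouped by KV cache group, so get_block_ids returns a two-level list. A minimal standalone sketch of the new layout (a toy Block class standing in for KVCacheBlock, not the real vLLM types):

# Sketch of the two-level block-ID layout (toy types, illustrative only).
from dataclasses import dataclass

@dataclass
class Block:
    block_id: int

def get_block_ids(blocks: list[Block]) -> list[list[int]]:
    # Outer list: one entry per KV cache group (only one group for now).
    # Inner list: the block IDs of the blocks allocated to that group.
    return [[block.block_id for block in blocks]]

blocks = [Block(1), Block(2), Block(3), Block(4)]
assert get_block_ids(blocks) == [[1, 2, 3, 4]]
# Callers that still assume a single group unwrap group 0 explicitly:
assert get_block_ids(blocks)[0] == [1, 2, 3, 4]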

View File

@ -577,14 +577,12 @@ def create_kv_cache_group_specs(
"""
kv_cache_groups = []
for layer_names_one_group in grouped_layer_names:
layer_spec = kv_cache_spec[layer_names_one_group[0]]
assert all(
kv_cache_spec[layer_name] == layer_spec
for layer_name in layer_names_one_group[1:]), (
"All layers in the same KV cache group must share the same "
"KVCacheSpec.")
layer_specs = [
kv_cache_spec[layer_name] for layer_name in layer_names_one_group
]
merged_layer_spec = layer_specs[0].merge(layer_specs)
kv_cache_groups.append(
KVCacheGroupSpec(layer_names_one_group, layer_spec))
KVCacheGroupSpec(layer_names_one_group, merged_layer_spec))
return kv_cache_groups
@ -683,6 +681,7 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
head_size=spec.head_size,
dtype=spec.dtype,
use_mla=spec.use_mla,
sliding_window=spec.sliding_window,
)

View File

@ -26,7 +26,7 @@ class NewRequestData:
mm_hashes: list[str]
mm_positions: list[PlaceholderRange]
sampling_params: SamplingParams
block_ids: list[int]
block_ids: list[list[int]]
num_computed_tokens: int
lora_request: Optional[LoRARequest]
@ -34,7 +34,7 @@ class NewRequestData:
def from_request(
cls,
request: Request,
block_ids: list[int],
block_ids: list[list[int]],
) -> NewRequestData:
return cls(
req_id=request.request_id,
@ -85,7 +85,7 @@ class CachedRequestData:
# request's block IDs instead of appending to the existing block IDs.
resumed_from_preemption: bool
new_token_ids: list[int]
new_block_ids: list[int]
new_block_ids: list[list[int]]
num_computed_tokens: int
@classmethod
@ -94,7 +94,7 @@ class CachedRequestData:
request: Request,
resumed_from_preemption: bool,
new_token_ids: list[int],
new_block_ids: list[int],
new_block_ids: list[list[int]],
) -> CachedRequestData:
return cls(
req_id=request.request_id,
@ -131,9 +131,9 @@ class SchedulerOutput:
# E.g., if a request has [0, 1], it could mean the vision encoder needs
# to process the request's 0-th and 1-st images in the current step.
scheduled_encoder_inputs: dict[str, list[int]]
# Number of common prefix blocks for all requests.
# Number of common prefix blocks for all requests in each KV cache group.
# This can be used for cascade attention.
num_common_prefix_blocks: int
num_common_prefix_blocks: list[int]
# Request IDs that are finished in between the previous and the current
# steps. This is used to notify the workers about the finished requests

View File

@ -173,7 +173,7 @@ class Scheduler(SchedulerInterface):
# uses structured decoding.
structured_output_request_ids: dict[str, int] = {}
req_to_new_block_ids: dict[str, list[int]] = {}
req_to_new_block_ids: dict[str, list[list[int]]] = {}
num_scheduled_tokens: dict[str, int] = {}
token_budget = self.max_num_scheduled_tokens
# Encoder-related.
@ -486,7 +486,8 @@ class Scheduler(SchedulerInterface):
# Get the longest common prefix among all requests in the running queue.
# This can be potentially used for cascade attention.
num_common_prefix_blocks = 0
num_common_prefix_blocks = [0] * len(
self.kv_cache_config.kv_cache_groups)
if self.running:
any_request = self.running[0]
num_common_prefix_blocks = (
@ -573,7 +574,7 @@ class Scheduler(SchedulerInterface):
request: Request,
num_scheduled_tokens: int,
num_scheduled_spec_tokens: int,
new_block_ids: list[int],
new_block_ids: list[list[int]],
resumed_from_preemption: bool,
) -> CachedRequestData:
# OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
@ -949,7 +950,9 @@ class Scheduler(SchedulerInterface):
"""
if self.connector is None:
return False, None
block_ids = self.kv_cache_manager.get_block_ids(request.request_id)
assert len(self.kv_cache_config.kv_cache_groups
) == 1, "KV connector only supports one KV cache group now"
block_ids = self.kv_cache_manager.get_block_ids(request.request_id)[0]
return self.connector.request_finished(request, block_ids)
def _update_waiting_for_remote_kv(self, request: Request) -> bool:
@ -966,9 +969,10 @@ class Scheduler(SchedulerInterface):
"""
if request.request_id not in self.finished_recving_kv_req_ids:
return False
assert len(self.kv_cache_config.kv_cache_groups
) == 1, "KV connector only supports one KV cache group now"
# Now that the blocks are ready, actually cache them.
block_ids = self.kv_cache_manager.get_block_ids(request.request_id)
block_ids = self.kv_cache_manager.get_block_ids(request.request_id)[0]
num_computed_tokens = len(block_ids) * self.block_size
if num_computed_tokens == request.num_tokens:
num_computed_tokens -= 1
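To make the scheduler-side bookkeeping concrete, here is a hedged sketch with toy values (simplified names, not the real Scheduler): the common-prefix count becomes a per-group list, and the KV-connector path unwraps the single supported group before computing token counts.

# Hedged sketch of the per-group bookkeeping; toy values, illustrative only.
num_kv_cache_groups = 1
block_size = 16

# One common-prefix counter per KV cache group instead of a single int.
num_common_prefix_blocks = [0] * num_kv_cache_groups

# The KV connector still understands only one group, so its callers unwrap group 0.
block_ids_per_group = [[1, 2, 3]]   # e.g. what get_block_ids(request_id) now returns
assert len(block_ids_per_group) == 1, "KV connector only supports one KV cache group now"
block_ids = block_ids_per_group[0]
num_computed_tokens = len(block_ids) * block_size   # 3 blocks * 16 tokens = 48
assert num_computed_tokens == 48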

View File

@ -1,8 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
import copy
from dataclasses import dataclass
from typing import Optional
import torch
from typing_extensions import Self
from vllm.config import VllmConfig
from vllm.logger import init_logger
@ -53,6 +56,16 @@ class KVCacheSpec:
"""
raise NotImplementedError
@classmethod
def merge(cls, specs: list[Self]) -> Self:
"""
Merge a list of KVCacheSpec objects into a single KVCacheSpec object.
"""
assert all(spec.type_id == specs[0].type_id for spec in specs[1:]), (
"All layers in the same KV cache group must share the same "
"type_id.")
return copy.deepcopy(specs[0])
@dataclass
class AttentionSpec(KVCacheSpec):
@ -71,6 +84,16 @@ class AttentionSpec(KVCacheSpec):
@dataclass
class FullAttentionSpec(AttentionSpec):
sliding_window: Optional[int] = None
"""
When the hybrid allocator is disabled and the model contains both full
attention layers and sliding window attention layers, sliding window
attention layers are regarded as full attention in the KV cache manager
(blocks are allocated for all tokens), while still computed as sliding
window attention in the model runner.
In this case, we use FullAttentionSpec and record the sliding window size.
Defaults to None when sliding window attention is not used.
"""
@property
def type_id(self) -> str:
@ -80,6 +103,25 @@ class FullAttentionSpec(AttentionSpec):
max_model_len = vllm_config.model_config.max_model_len
return cdiv(max_model_len, self.block_size) * self.page_size_bytes
@classmethod
def merge(cls, specs: list[Self]) -> Self:
"""
Merge a list of FullAttentionSpec objects into a single
FullAttentionSpec object.
"""
merged_spec = super().merge(specs)
sliding_window = set(spec.sliding_window for spec in specs
if spec.sliding_window is not None)
if len(sliding_window) == 0:
merged_spec.sliding_window = None
elif len(sliding_window) == 1:
merged_spec.sliding_window = sliding_window.pop()
else:
raise ValueError(
"All sliding window layers in the same KV cache group "
"must have the same window size.")
return merged_spec
@dataclass
class SlidingWindowSpec(AttentionSpec):
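A compact standalone re-implementation of the merge rule shown above (a hypothetical Spec dataclass that mirrors the semantics, not the real KVCacheSpec hierarchy): all specs in a group must share a type_id, and sliding_window values must either all be None or agree on a single size.

# Sketch of the merge semantics (toy Spec class, illustrative only).
import copy
from dataclasses import dataclass
from typing import Optional

@dataclass
class Spec:
    type_id: str
    sliding_window: Optional[int] = None

    def merge(self, specs: list["Spec"]) -> "Spec":
        assert all(s.type_id == specs[0].type_id for s in specs[1:]), (
            "All layers in the same KV cache group must share the same type_id.")
        merged = copy.deepcopy(specs[0])
        windows = {s.sliding_window for s in specs if s.sliding_window is not None}
        if len(windows) > 1:
            raise ValueError("All sliding window layers in the same KV cache "
                             "group must have the same window size.")
        merged.sliding_window = windows.pop() if windows else None
        return merged

specs = [Spec("full_attention"), Spec("full_attention", sliding_window=1)]
assert specs[0].merge(specs).sliding_window == 1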

View File

@ -4,6 +4,7 @@ import numpy as np
import torch
from vllm.logger import init_logger
from vllm.utils import cdiv
logger = init_logger(__name__)
@ -96,3 +97,43 @@ class BlockTable:
def get_numpy_array(self) -> np.ndarray:
"""Returns the numpy array of the block table."""
return self.block_table_np
class MultiGroupBlockTable:
"""The BlockTables for each KV cache group."""
def __init__(self, max_num_reqs: int, max_model_len: int,
max_num_batched_tokens: int, pin_memory: bool,
device: torch.device, block_size: int) -> None:
self.block_tables = [
BlockTable(max_num_reqs, cdiv(max_model_len, block_size),
max_num_batched_tokens, pin_memory, device)
]
def append_row(self, block_ids: list[list[int]], row_idx: int) -> None:
for i, block_table in enumerate(self.block_tables):
block_table.append_row(block_ids[i], row_idx)
def add_row(self, block_ids: list[list[int]], row_idx: int) -> None:
for i, block_table in enumerate(self.block_tables):
block_table.add_row(block_ids[i], row_idx)
def move_row(self, src: int, tgt: int) -> None:
for block_table in self.block_tables:
block_table.move_row(src, tgt)
def swap_row(self, src: int, tgt: int) -> None:
for block_table in self.block_tables:
block_table.swap_row(src, tgt)
def commit(self, num_reqs: int) -> None:
for block_table in self.block_tables:
block_table.commit(num_reqs)
def clear(self) -> None:
for block_table in self.block_tables:
block_table.clear()
def __getitem__(self, idx: int) -> "BlockTable":
"""Returns the BlockTable for the i-th KV cache group."""
return self.block_tables[idx]
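MultiGroupBlockTable simply fans each call out to one BlockTable per KV cache group. A stripped-down sketch of that pattern with toy in-memory tables (not the CPU/GPU-backed BlockTable above):

# Toy fan-out wrapper over per-group block tables (illustrative only).
class ToyBlockTable:
    def __init__(self) -> None:
        self.rows: dict[int, list[int]] = {}

    def add_row(self, block_ids: list[int], row_idx: int) -> None:
        self.rows[row_idx] = list(block_ids)

    def append_row(self, block_ids: list[int], row_idx: int) -> None:
        self.rows.setdefault(row_idx, []).extend(block_ids)

class ToyMultiGroupBlockTable:
    def __init__(self, num_groups: int) -> None:
        self.block_tables = [ToyBlockTable() for _ in range(num_groups)]

    def add_row(self, block_ids: list[list[int]], row_idx: int) -> None:
        # block_ids[i] belongs to KV cache group i.
        for i, table in enumerate(self.block_tables):
            table.add_row(block_ids[i], row_idx)

    def __getitem__(self, idx: int) -> ToyBlockTable:
        return self.block_tables[idx]

mgbt = ToyMultiGroupBlockTable(num_groups=1)
mgbt.add_row([[1, 2, 3, 4]], row_idx=0)
assert mgbt[0].rows[0] == [1, 2, 3, 4]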

View File

@ -14,7 +14,7 @@ from vllm.utils import swap_dict_values
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.utils import copy_slice
from vllm.v1.worker.block_table import BlockTable
from vllm.v1.worker.block_table import MultiGroupBlockTable
_SAMPLING_EPS = 1e-5
@ -29,7 +29,7 @@ class CachedRequestState:
sampling_params: SamplingParams
generator: Optional[torch.Generator]
block_ids: list[int]
block_ids: list[list[int]]
num_computed_tokens: int
output_token_ids: list[int]
@ -58,15 +58,14 @@ class InputBatch:
self,
max_num_reqs: int,
max_model_len: int,
max_num_blocks_per_req: int,
max_num_batched_tokens: int,
device: torch.device,
pin_memory: bool,
vocab_size: int,
block_size: int,
):
self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len
self.max_num_blocks_per_req = max_num_blocks_per_req
self.max_num_batched_tokens = max_num_batched_tokens
self.device = device
self.pin_memory = pin_memory
@ -99,12 +98,13 @@ class InputBatch:
self.num_computed_tokens_cpu_tensor.numpy()
# Block table.
self.block_table = BlockTable(
self.block_table = MultiGroupBlockTable(
max_num_reqs=max_num_reqs,
max_num_blocks_per_req=max_num_blocks_per_req,
max_model_len=max_model_len,
max_num_batched_tokens=max_num_batched_tokens,
pin_memory=pin_memory,
device=device,
block_size=block_size,
)
# Sampling-related.

View File

@ -12,6 +12,8 @@ import torch.distributed
import torch.nn as nn
from vllm.attention import AttentionType, get_attn_backend
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadataBuilder)
from vllm.attention.layer import Attention
from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.config import (CompilationLevel, VllmConfig,
@ -32,8 +34,8 @@ from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
GiB_bytes, LayerBlockType, LazyLoader, cdiv,
check_use_alibi, is_pin_memory_available)
GiB_bytes, LazyLoader, cdiv, check_use_alibi,
is_pin_memory_available)
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
@ -51,6 +53,7 @@ from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.spec_decode.utils import is_spec_decode_supported
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.block_table import BlockTable
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
@ -103,59 +106,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
cache_config.cache_dtype]
# NOTE(woosuk): sliding_window is None for models with interleaved
# attention. Use interleaved_sliding_window instead.
self.sliding_window = model_config.get_sliding_window()
self.interleaved_sliding_window = getattr(
model_config.hf_text_config, "interleaved_sliding_window", None)
self.window_size = (self.sliding_window
or self.interleaved_sliding_window)
self.is_multimodal_model = model_config.is_multimodal_model
self.block_size = cache_config.block_size
self.max_model_len = model_config.max_model_len
self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
self.max_num_tokens = scheduler_config.max_num_batched_tokens
self.max_num_reqs = scheduler_config.max_num_seqs
# Model-related.
self.num_attn_layers = model_config.get_num_layers_by_block_type(
parallel_config, LayerBlockType.attention)
self.num_query_heads = model_config.get_num_attention_heads(
parallel_config)
self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
self.head_size = model_config.get_head_size()
self.hidden_size = model_config.get_hidden_size()
self.attention_chunk_size = model_config.attention_chunk_size
self.attn_backend = get_attn_backend(
self.head_size,
self.dtype,
self.kv_cache_dtype,
self.block_size,
self.model_config.is_attention_free,
use_mla=self.model_config.use_mla,
)
if self.attn_backend is None:
error_msg = (
f"Error with get_att_backend: {self.head_size=}, "
f"{self.dtype=}, {self.kv_cache_dtype=}, {self.block_size=}, "
f"{self.model_config.is_attention_free=}, "
f"{self.model_config.use_mla=}")
logger.error(error_msg)
raise NotImplementedError(
"Non-Attention backend is not supported by V1 GPUModelRunner.")
if self.vllm_config.compilation_config.full_cuda_graph:
attn_backend_name = self.attn_backend.__name__
flash_attn_version = get_flash_attn_version()
if attn_backend_name != "FlashAttentionBackend" or \
flash_attn_version != 3:
raise ValueError(
f"full_cuda_graph is only supported with "
f"FA3. Current attention backend is {attn_backend_name}, "
f"FlashAttention version is {flash_attn_version}.")
self.cascade_attn_enabled = not self.model_config.disable_cascade_attn
# Multi-modal data support
@ -177,8 +138,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# self.model: nn.Module # Set after load_model
# Initialize in initialize_kv_cache
self.kv_caches: list[torch.Tensor] = []
self.attn_metadata_builders: list[AttentionMetadataBuilder] = []
self.attn_backends: list[type[AttentionBackend]] = []
# self.kv_cache_config: KVCacheConfig
# self.attn_metadata_builder: type[AttentionMetadataBuilder]
# self.input_batch: InputBatch # Persistent batch.
# req_id -> (input_id -> encoder_output)
self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
@ -207,15 +170,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Request states.
self.requests: dict[str, CachedRequestState] = {}
# Persistent batch.
self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs,
max_model_len=self.max_model_len,
max_num_blocks_per_req=self.max_num_blocks_per_req,
max_num_batched_tokens=self.max_num_tokens,
device=self.device,
pin_memory=self.pin_memory,
vocab_size=model_config.get_vocab_size(),
vocab_size=self.model_config.get_vocab_size(),
block_size=self.cache_config.block_size,
)
self.use_cuda_graph = (self.vllm_config.compilation_config.level
@ -311,6 +274,31 @@ class GPUModelRunner(LoRAModelRunnerMixin):
pin_memory=self.pin_memory)
self.seq_lens_np = self.seq_lens_cpu.numpy()
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> bool:
"""
Update the order of requests in the batch based on the attention
backend's needs. For example, some attention backends (namely MLA) may
want to separate requests based on if the attention computation will be
compute-bound or memory-bound.
Args:
scheduler_output: The scheduler output.
Returns:
True if the batch was reordered, False otherwise.
"""
batch_reordered = self.attn_metadata_builders[0].reorder_batch(
self.input_batch, scheduler_output)
# For models with multiple KV cache groups, the groups should agree on
# the same order of requests. We ensure this by only allowing the first
# group to reorder the batch and asserting that all other groups do not
# reorder the batch.
for i in range(1, len(self.kv_cache_config.kv_cache_groups)):
assert not self.attn_metadata_builders[i].reorder_batch(
self.input_batch, scheduler_output)
return batch_reordered
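The agreement rule can be sketched with hypothetical no-op builders (illustrative only): group 0 decides the request order, and every other group is asserted not to reorder, so the per-group block tables stay row-aligned.

# Sketch of the agreement rule (toy builder class, not the real metadata builders).
class NoReorderBuilder:
    def reorder_batch(self, input_batch, scheduler_output) -> bool:
        return False   # this toy builder never reorders

builders = [NoReorderBuilder(), NoReorderBuilder()]   # one per (assumed) KV cache group
batch_reordered = builders[0].reorder_batch(None, None)
for builder in builders[1:]:
    assert not builder.reorder_batch(None, None)
assert batch_reordered is False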
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
"""Update the cached states and the persistent batch with the scheduler
output.
@ -447,7 +435,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Update the block IDs.
if not req_data.resumed_from_preemption:
# Append the new blocks to the existing block IDs.
req_state.block_ids.extend(req_data.new_block_ids)
for i in range(len(self.kv_cache_config.kv_cache_groups)):
req_state.block_ids[i].extend(req_data.new_block_ids[i])
else:
# The request is resumed from preemption.
# Replace the existing block IDs with the new ones.
@ -505,11 +494,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if removed_req_indices:
self.input_batch.condense(removed_req_indices)
# Some attention backends (namely MLA) may want to separate requests
# based on if the attention computation will be compute-bound or
# memory-bound. This gives them a hook to do that.
batch_reordered = self.attn_metadata_builder.reorder_batch(
self.input_batch, scheduler_output)
batch_reordered = self._may_reorder_batch(scheduler_output)
if batch_changed or batch_reordered:
self.input_batch.refresh_sampling_metadata()
@ -577,21 +562,29 @@ class GPUModelRunner(LoRAModelRunnerMixin):
torch.from_numpy(token_indices),
out=self.input_ids_cpu[:total_num_scheduled_tokens])
# Calculate the slot mapping.
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
# where K is the max_num_blocks_per_req and the block size is 2.
# NOTE(woosuk): We can't simply use `token_indices // block_size` here
# because M (max_model_len) is not necessarily divisible by block_size.
block_table_indices = (req_indices * self.max_num_blocks_per_req +
positions_np // self.block_size)
block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
block_offsets = positions_np % self.block_size
np.add(block_numbers * self.block_size,
block_offsets,
out=self.input_batch.block_table.
slot_mapping_np[:total_num_scheduled_tokens])
# Calculate the slot mapping for each KV cache group.
for kv_cache_group_id, kv_cache_group_spec in enumerate(
self.kv_cache_config.kv_cache_groups):
block_size = kv_cache_group_spec.kv_cache_spec.block_size
block_table: BlockTable = self.input_batch.block_table[
kv_cache_group_id]
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
# where K is the max_num_blocks_per_req and the block size is 2.
# NOTE(woosuk): We can't simply use `token_indices // block_size`
# here because M (max_model_len) is not necessarily divisible by
# block_size.
block_table_indices = (
req_indices * block_table.max_num_blocks_per_req +
positions_np // block_size)
block_table_cpu = block_table.get_cpu_tensor()
block_numbers = block_table_cpu.flatten(
)[block_table_indices].numpy()
block_offsets = positions_np % block_size
np.add(
block_numbers * block_size,
block_offsets,
out=block_table.slot_mapping_np[:total_num_scheduled_tokens])
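The slot-mapping arithmetic in the loop above can be reproduced with a small NumPy example (made-up block numbers; K stands in for max_num_blocks_per_req, and block_size is 2 as in the comment):

# Worked example of the per-group slot mapping; values are illustrative only.
import numpy as np

block_size = 2
K = 4                                  # max blocks per request
block_table = np.array([[10, 11, 12, 13],
                        [20, 21, 22, 23]], dtype=np.int64)

req_indices = np.array([0, 0, 1, 1, 1])   # which request each token belongs to
positions   = np.array([0, 1, 0, 1, 2])   # token position within its request

block_table_indices = req_indices * K + positions // block_size
block_numbers = block_table.flatten()[block_table_indices]
block_offsets = positions % block_size
slot_mapping = block_numbers * block_size + block_offsets
# Tokens 0-1 of req 0 land in block 10 (slots 20, 21);
# tokens 0-2 of req 1 land in blocks 20 and 21 (slots 40, 41, 42).
assert slot_mapping.tolist() == [20, 21, 40, 41, 42]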
# Prepare the attention metadata.
self.query_start_loc_np[0] = 0
@ -633,10 +626,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
attn_metadata: dict[str, FlashAttentionMetadata] = {}
# Prepare the attention metadata for each KV cache group and make layers
# in the same group share the same metadata.
# NOTE(Chen): there is exactly one KV cache group that contains all
# attention layers in the model for now, so the current logic for
# getting attn_metadata is not related to kv_cache_group information.
# Will extend this part to support multiple KV cache groups later.
for kv_cache_group_id, kv_cache_group_spec in enumerate(
self.kv_cache_config.kv_cache_groups):
@ -645,15 +634,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if self.cascade_attn_enabled:
common_prefix_len = self._compute_cascade_attn_prefix_len(
num_scheduled_tokens,
scheduler_output.num_common_prefix_blocks,
scheduler_output.
num_common_prefix_blocks[kv_cache_group_id],
kv_cache_group_spec.kv_cache_spec,
self.attn_metadata_builders[kv_cache_group_id],
)
attn_metadata_i = self.attn_metadata_builder.build(
num_reqs=num_reqs,
num_actual_tokens=total_num_scheduled_tokens,
max_query_len=max_num_scheduled_tokens,
common_prefix_len=common_prefix_len,
common_attn_metadata=common_attn_metadata)
attn_metadata_i = (
self.attn_metadata_builders[kv_cache_group_id].build(
num_reqs=num_reqs,
num_actual_tokens=total_num_scheduled_tokens,
max_query_len=max_num_scheduled_tokens,
common_prefix_len=common_prefix_len,
common_attn_metadata=common_attn_metadata))
for layer_name in kv_cache_group_spec.layer_names:
attn_metadata[layer_name] = attn_metadata_i
@ -691,6 +684,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self,
num_scheduled_tokens: np.ndarray,
num_common_prefix_blocks: int,
kv_cache_spec: KVCacheSpec,
attn_metadata_builder: AttentionMetadataBuilder,
) -> int:
"""Compute the length of the common prefix for cascade attention.
@ -709,7 +704,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
Returns:
int: Length of common prefix in tokens.
"""
common_prefix_len = num_common_prefix_blocks * self.block_size
common_prefix_len = num_common_prefix_blocks * kv_cache_spec.block_size
if common_prefix_len == 0:
# Common case.
return 0
@ -758,15 +753,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
common_prefix_len,
self.input_batch.num_computed_tokens_cpu[:num_reqs].min())
# common_prefix_len should be a multiple of the block size.
common_prefix_len = (common_prefix_len // self.block_size *
self.block_size)
use_cascade = self.attn_metadata_builder.use_cascade_attention(
common_prefix_len = (common_prefix_len // kv_cache_spec.block_size *
kv_cache_spec.block_size)
use_sliding_window = (isinstance(kv_cache_spec, SlidingWindowSpec) or
(isinstance(kv_cache_spec, FullAttentionSpec)
and kv_cache_spec.sliding_window is not None))
assert isinstance(kv_cache_spec, AttentionSpec)
use_cascade = attn_metadata_builder.use_cascade_attention(
common_prefix_len=common_prefix_len,
query_lens=num_scheduled_tokens,
num_query_heads=self.num_query_heads,
num_kv_heads=self.num_kv_heads,
num_kv_heads=kv_cache_spec.num_kv_heads,
use_alibi=self.use_alibi,
use_sliding_window=self.window_size is not None,
use_sliding_window=use_sliding_window,
num_sms=self.num_sms,
)
return common_prefix_len if use_cascade else 0
@ -1661,7 +1660,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
dtype=np.int32)
if skip_attn:
attn_metadata = None
attn_metadata: Optional[dict[str, FlashAttentionMetadata]] = None
else:
query_start_loc = self.query_start_loc[:num_reqs + 1]
seq_lens = self.seq_lens[:num_reqs]
@ -1669,13 +1668,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=query_start_loc, seq_lens=seq_lens)
attn_metadata = self.attn_metadata_builder.build(
num_reqs=num_tokens,
num_actual_tokens=num_tokens,
max_query_len=num_tokens,
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
)
attn_metadata = {}
for kv_cache_group_id, kv_cache_group_spec in enumerate(
self.kv_cache_config.kv_cache_groups):
attn_metadata_i = (
self.attn_metadata_builders[kv_cache_group_id].build(
num_reqs=num_tokens,
num_actual_tokens=num_tokens,
max_query_len=num_tokens,
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
))
for layer_name in kv_cache_group_spec.layer_names:
attn_metadata[layer_name] = attn_metadata_i
with self.maybe_dummy_run_with_lora(self.lora_config,
num_scheduled_tokens):
@ -1909,6 +1914,56 @@ class GPUModelRunner(LoRAModelRunnerMixin):
logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
elapsed_time, cuda_graph_size / (1 << 30))
def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
"""
Initialize the attention backends and attention metadata builders.
"""
assert len(self.attn_backends) == 0 and len(
self.attn_metadata_builders
) == 0, "Attention backends are already initialized"
for i, kv_cache_group_spec in enumerate(
kv_cache_config.kv_cache_groups):
kv_cache_spec = kv_cache_group_spec.kv_cache_spec
if not isinstance(kv_cache_spec, AttentionSpec):
raise NotImplementedError(
"Only AttentionSpec is supported for now.")
attn_backend_i = get_attn_backend(
kv_cache_spec.head_size,
self.dtype,
kv_cache_spec.dtype,
kv_cache_spec.block_size,
self.model_config.is_attention_free,
use_mla=kv_cache_spec.use_mla,
)
if attn_backend_i is None:
error_msg = (
f"Error with get_attn_backend: {kv_cache_spec.head_size=}, "
f"{self.dtype=}, {kv_cache_spec.dtype=}, "
f"{kv_cache_spec.block_size=}, "
f"{self.model_config.is_attention_free=}, "
f"{kv_cache_spec.use_mla=}")
logger.error(error_msg)
raise NotImplementedError(
"Non-Attention backend is not supported by V1 "
"GPUModelRunner.")
if self.vllm_config.compilation_config.full_cuda_graph:
attn_backend_name = attn_backend_i.__name__
flash_attn_version = get_flash_attn_version()
if attn_backend_name != "FlashAttentionBackend" or \
flash_attn_version != 3:
raise ValueError(
f"full_cuda_graph is only supported with "
f"FA3. Current attention backend is "
f"{attn_backend_name}, FlashAttention version is "
f"{flash_attn_version}.")
block_table_i = self.input_batch.block_table[i]
attn_metadata_builder_i = attn_backend_i.get_builder_cls()(
weakref.proxy(self), kv_cache_spec, block_table_i)
self.attn_backends.append(attn_backend_i)
self.attn_metadata_builders.append(attn_metadata_builder_i)
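Putting the pieces together, a toy sketch of the per-group wiring (hypothetical ToyGroup/ToyMetadataBuilder stand-ins, not the real backend classes): one metadata builder is created per KV cache group, and at build time every layer in a group shares that group's metadata object.

# Toy stand-ins showing the per-group wiring (illustrative only).
from dataclasses import dataclass

@dataclass
class ToyGroup:
    layer_names: list[str]
    block_size: int

@dataclass
class ToyMetadataBuilder:
    block_size: int
    def build(self) -> dict:
        return {"block_size": self.block_size}

groups = [ToyGroup(["layer.0", "layer.1"], block_size=16)]   # single group assumed
builders = [ToyMetadataBuilder(g.block_size) for g in groups]  # one builder per group

attn_metadata: dict[str, dict] = {}
for group, builder in zip(groups, builders):
    metadata_i = builder.build()
    for layer_name in group.layer_names:
        attn_metadata[layer_name] = metadata_i               # shared within the group

assert attn_metadata["layer.0"] is attn_metadata["layer.1"]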
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
"""
Initialize KV cache based on `kv_cache_config`.
@ -1921,10 +1976,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
"Hybrid models with more than one KV cache type are not "
"supported yet.")
self.kv_cache_config = kv_cache_config
self.initialize_attn_backend(kv_cache_config)
kv_caches: dict[str, torch.Tensor] = {}
for kv_cache_group in kv_cache_config.kv_cache_groups:
for i, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
kv_cache_spec = kv_cache_group.kv_cache_spec
for layer_name in kv_cache_group.layer_names:
tensor_config = kv_cache_config.tensors[layer_name]
@ -1939,7 +1995,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# the min of all `num_blocks`. Verify it here.
assert num_blocks >= kv_cache_config.num_blocks
if isinstance(kv_cache_spec, AttentionSpec):
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
kv_cache_shape = self.attn_backends[i].get_kv_cache_shape(
num_blocks, kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
dtype = kv_cache_spec.dtype
@ -1959,11 +2015,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if has_kv_transfer_group():
get_kv_transfer_group().register_kv_caches(kv_caches)
self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
weakref.proxy(self),
kv_cache_config.kv_cache_groups[0].kv_cache_spec,
self.input_batch.block_table)
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
"""
Generates the KVCacheSpec by parsing the kv cache format from each

View File

@ -171,19 +171,10 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self.kv_caches: list[torch.Tensor] = []
# req_id -> (input_id -> encoder_output)
self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
# self.input_batch: InputBatch # Persistent batch.
# Request states.
self.requests: dict[str, CachedRequestState] = {}
# Persistent batch.
self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs,
max_model_len=self.max_model_len,
max_num_blocks_per_req=self.max_num_blocks_per_req,
max_num_batched_tokens=self.max_num_tokens,
device=self.device,
pin_memory=self.pin_memory,
vocab_size=self.vocab_size,
)
# Cached torch/numpy tensor
# The pytorch tensor and numpy array share the same buffer.
@ -199,7 +190,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self.block_table_cpu = torch.zeros(
(self.max_num_reqs, self.max_num_blocks_per_req),
dtype=self.input_batch.block_table.get_cpu_tensor().dtype,
dtype=torch.int32,
device="cpu")
self.query_start_loc_cpu = torch.zeros(self.max_num_tokens + 1,
@ -524,12 +515,12 @@ class TPUModelRunner(LoRAModelRunnerMixin):
# NOTE(woosuk): We use torch.index_select instead of np.take here
# because torch.index_select is much faster than np.take for large
# tensors.
block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor()
block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
block_offsets = positions_np % self.block_size
np.add(block_numbers * self.block_size,
block_offsets,
out=self.input_batch.block_table.
out=self.input_batch.block_table[0].
slot_mapping_np[:total_num_scheduled_tokens])
# Prepare the attention metadata.
@ -554,15 +545,15 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self.position_ids = self.positions_cpu[:
padded_total_num_scheduled_tokens].to(
self.device)
self.input_batch.block_table.slot_mapping_cpu[
self.input_batch.block_table[0].slot_mapping_cpu[
total_num_scheduled_tokens:] = _PAD_SLOT_ID
slot_mapping = (
self.input_batch.block_table.
self.input_batch.block_table[0].
slot_mapping_cpu[:padded_total_num_scheduled_tokens].to(
self.device))
block_tables = self.block_table_cpu[:self.max_num_reqs]
block_tables[:num_reqs, :self.max_num_blocks_per_req] = (
self.input_batch.block_table.get_cpu_tensor()[:num_reqs])
self.input_batch.block_table[0].get_cpu_tensor()[:num_reqs])
block_tables = block_tables.to(self.device)
query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs + 1].to(
self.device)
@ -1263,6 +1254,19 @@ class TPUModelRunner(LoRAModelRunnerMixin):
"Hybrid models with more than one KV cache type are not "
"supported yet.")
self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs,
max_model_len=self.max_model_len,
max_num_batched_tokens=self.max_num_tokens,
device=self.device,
pin_memory=self.pin_memory,
vocab_size=self.model_config.get_vocab_size(),
block_size=kv_cache_config.kv_cache_groups[0].kv_cache_spec.
block_size,
)
assert self.block_table_cpu.dtype == self.input_batch.block_table[
0].get_cpu_tensor().dtype
kv_caches: dict[str, torch.Tensor] = {}
for kv_cache_group in kv_cache_config.kv_cache_groups: